mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 01:29:57 +08:00
+2
-1
@@ -47,7 +47,8 @@
|
||||
// breaks when moving scales to the CPU.
|
||||
//
|
||||
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp
|
||||
// adapted from:
|
||||
// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -47,7 +47,8 @@
|
||||
// breaks when moving scales to the CPU.
|
||||
//
|
||||
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp
|
||||
// adapted from:
|
||||
// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -47,7 +47,8 @@
|
||||
// breaks when moving scales to the CPU.
|
||||
//
|
||||
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp
|
||||
// adapted from:
|
||||
// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
|
||||
// adapted from:
|
||||
// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -24,31 +25,41 @@ using namespace cute;
|
||||
*/
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBase {
|
||||
protected:
|
||||
protected:
|
||||
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
|
||||
|
||||
template <typename T>
|
||||
using ColOrScalarLoad =
|
||||
cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
OutputTileThreadMap,
|
||||
T,
|
||||
Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using RowOrScalarLoad =
|
||||
cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
OutputTileThreadMap,
|
||||
T,
|
||||
Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
OutputTileThreadMap,
|
||||
T,
|
||||
Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
OutputTileThreadMap,
|
||||
T,
|
||||
Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using RowOrZeroLoad =
|
||||
cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
OutputTileThreadMap,
|
||||
T,
|
||||
Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
// This utility function constructs the arguments for the load descriptors
|
||||
// from a tensor. It can handle both row and column, as well as row/column or
|
||||
@@ -56,15 +67,11 @@ protected:
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(paddle::Tensor const &tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto *data_ptr = static_cast<T *>(const_cast<void *>(
|
||||
tensor.data()));
|
||||
if constexpr (std::is_same_v<Descriptor,
|
||||
ColOrScalarLoad<T>> ||
|
||||
std::is_same_v<Descriptor,
|
||||
RowOrScalarLoad<T>>) {
|
||||
auto *data_ptr = static_cast<T *>(const_cast<void *>(tensor.data()));
|
||||
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
|
||||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
|
||||
return Arguments{data_ptr, tensor.numel() != 1};
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
// it would technically work but no use case as data_ptr is never nullptr
|
||||
static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
|
||||
return Arguments{data_ptr};
|
||||
@@ -102,24 +109,28 @@ protected:
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogue
|
||||
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
private:
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::multiplies,
|
||||
float,
|
||||
float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, ElementD, float,
|
||||
cutlass::multiplies,
|
||||
ElementD,
|
||||
float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
@@ -146,26 +157,30 @@ public:
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBias
|
||||
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
protected:
|
||||
protected:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowLoad<ElementD>;
|
||||
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::multiplies,
|
||||
float,
|
||||
float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::multiply_add,
|
||||
ElementD,
|
||||
float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
|
||||
EVTCompute0, Bias>;
|
||||
public:
|
||||
using EVTCompute = cutlass::epilogue::threadblock::
|
||||
Sm80EVT<Compute1, ScaleA, EVTCompute0, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
static ArgumentType prepare_args(paddle::Tensor const &a_scales,
|
||||
paddle::Tensor const &b_scales,
|
||||
@@ -190,7 +205,7 @@ public:
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBiasAzp
|
||||
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
private:
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
@@ -202,35 +217,40 @@ private:
|
||||
|
||||
// Compute float(accum - azp_adj), both operands are int32_t
|
||||
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::minus, float, int32_t,
|
||||
cutlass::minus,
|
||||
float,
|
||||
int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAzp =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;
|
||||
|
||||
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::multiplies,
|
||||
float,
|
||||
float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeScaleB =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
|
||||
EVTComputeAzp>;
|
||||
using EVTComputeScaleB = cutlass::epilogue::threadblock::
|
||||
Sm80EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::multiply_add,
|
||||
ElementD,
|
||||
float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
|
||||
EVTComputeScaleB, Bias>;
|
||||
public:
|
||||
using EVTCompute = cutlass::epilogue::threadblock::
|
||||
Sm80EVT<ComputeScaleBiasA, ScaleA, EVTComputeScaleB, Bias>;
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType
|
||||
prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales,
|
||||
paddle::Tensor const &azp_adj,
|
||||
paddle::optional<paddle::Tensor> const &bias) {
|
||||
static ArgumentType prepare_args(
|
||||
paddle::Tensor const &a_scales,
|
||||
paddle::Tensor const &b_scales,
|
||||
paddle::Tensor const &azp_adj,
|
||||
paddle::optional<paddle::Tensor> const &bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@@ -257,7 +277,7 @@ public:
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBiasAzpToken
|
||||
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
private:
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
@@ -272,7 +292,9 @@ private:
|
||||
|
||||
// Compute azp * azp_adj
|
||||
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, int32_t, int32_t,
|
||||
cutlass::multiplies,
|
||||
int32_t,
|
||||
int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAzp =
|
||||
@@ -280,35 +302,41 @@ private:
|
||||
|
||||
// Compute float(accum - azp*azp_adj), all operands are int32_t
|
||||
using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::minus, float, int32_t,
|
||||
cutlass::minus,
|
||||
float,
|
||||
int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAcc =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::multiplies,
|
||||
float,
|
||||
float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeScaleB =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
|
||||
EVTComputeAcc>;
|
||||
using EVTComputeScaleB = cutlass::epilogue::threadblock::
|
||||
Sm80EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::multiply_add,
|
||||
ElementD,
|
||||
float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
|
||||
EVTComputeScaleB, Bias>;
|
||||
public:
|
||||
using EVTCompute = cutlass::epilogue::threadblock::
|
||||
Sm80EVT<ComputeScaleBiasA, ScaleA, EVTComputeScaleB, Bias>;
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType
|
||||
prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales,
|
||||
paddle::Tensor const &azp_adj, paddle::Tensor const &azp,
|
||||
paddle::optional<paddle::Tensor> const &bias) {
|
||||
static ArgumentType prepare_args(
|
||||
paddle::Tensor const &a_scales,
|
||||
paddle::Tensor const &b_scales,
|
||||
paddle::Tensor const &azp_adj,
|
||||
paddle::Tensor const &azp,
|
||||
paddle::optional<paddle::Tensor> const &bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@@ -324,4 +352,4 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
}; // namespace fastdeploy::c2x
|
||||
}; // namespace fastdeploy::c2x
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
|
||||
// adapted from:
|
||||
// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -24,24 +25,28 @@ namespace fastdeploy::c3x {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename T> struct identity {
|
||||
template <typename T>
|
||||
struct identity {
|
||||
CUTLASS_HOST_DEVICE
|
||||
T operator()(T lhs) const { return lhs; }
|
||||
};
|
||||
|
||||
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||
struct TrivialEpilogue {
|
||||
private:
|
||||
private:
|
||||
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
||||
using Compute = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::epilogue::thread::Identity, ElementD, ElementAcc,
|
||||
cutlass::epilogue::thread::Identity,
|
||||
ElementD,
|
||||
ElementAcc,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
public:
|
||||
using EVTCompute = cutlass::epilogue::fusion::Sm90EVT<Compute, Accum>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
template <typename... Args> static ArgumentType prepare_args(Args... args) {
|
||||
template <typename... Args>
|
||||
static ArgumentType prepare_args(Args... args) {
|
||||
return {};
|
||||
}
|
||||
};
|
||||
@@ -52,38 +57,60 @@ public:
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||
struct ScaledEpilogueBase {
|
||||
protected:
|
||||
protected:
|
||||
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
||||
|
||||
template <typename T>
|
||||
using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
|
||||
0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
0 /*Stages*/,
|
||||
TileShape,
|
||||
T,
|
||||
Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
|
||||
0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
0 /*Stages*/,
|
||||
TileShape,
|
||||
T,
|
||||
Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
// Don't want to support nullptr by default
|
||||
template <typename T, bool EnableNullPtr = false>
|
||||
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
|
||||
0 /*Stages*/, TileShape, T, T, Stride<Int<1>, Int<0>, Int<0>>,
|
||||
128 / sizeof_bits_v<T>, EnableNullPtr>;
|
||||
0 /*Stages*/,
|
||||
TileShape,
|
||||
T,
|
||||
T,
|
||||
Stride<Int<1>, Int<0>, Int<0>>,
|
||||
128 / sizeof_bits_v<T>,
|
||||
EnableNullPtr>;
|
||||
|
||||
// Don't want to support nullptr by default
|
||||
template <typename T, bool EnableNullPtr = false>
|
||||
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
||||
0 /*Stages*/, TileShape, T, T, Stride<Int<0>, Int<1>, Int<0>>,
|
||||
128 / sizeof_bits_v<T>, EnableNullPtr>;
|
||||
0 /*Stages*/,
|
||||
TileShape,
|
||||
T,
|
||||
T,
|
||||
Stride<Int<0>, Int<1>, Int<0>>,
|
||||
128 / sizeof_bits_v<T>,
|
||||
EnableNullPtr>;
|
||||
|
||||
template <typename T>
|
||||
using ColOrScalarLoadArray =
|
||||
cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray<
|
||||
0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
0 /*Stages*/,
|
||||
TileShape,
|
||||
T,
|
||||
Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using RowOrScalarLoadArray =
|
||||
cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray<
|
||||
0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
0 /*Stages*/,
|
||||
TileShape,
|
||||
T,
|
||||
Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
// This utility function constructs the arguments for the load descriptors
|
||||
// from a tensor. It can handle both row and column, as well as row/column or
|
||||
@@ -142,24 +169,28 @@ protected:
|
||||
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||
struct ScaledEpilogue
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
|
||||
private:
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::multiplies,
|
||||
float,
|
||||
float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, ElementD, float,
|
||||
cutlass::multiplies,
|
||||
ElementD,
|
||||
float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
@@ -186,7 +217,7 @@ public:
|
||||
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||
struct ScaledEpilogueBias
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
|
||||
private:
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
@@ -194,17 +225,21 @@ private:
|
||||
using Bias = typename SUPER::template RowLoad<ElementD>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::multiplies,
|
||||
float,
|
||||
float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::multiply_add,
|
||||
ElementD,
|
||||
float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
|
||||
|
||||
@@ -229,7 +264,7 @@ public:
|
||||
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||
struct ScaledEpilogueColumnBias
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
|
||||
private:
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
@@ -237,17 +272,21 @@ private:
|
||||
using Bias = typename SUPER::template ColLoad<ElementD>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::multiplies,
|
||||
float,
|
||||
float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::multiply_add,
|
||||
ElementD,
|
||||
float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
|
||||
|
||||
@@ -275,7 +314,7 @@ public:
|
||||
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||
struct ScaledEpilogueBiasAzp
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
|
||||
private:
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
@@ -287,33 +326,39 @@ private:
|
||||
|
||||
// Compute float(accum - azp_adj), both operands are int32_t
|
||||
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::minus, float, int32_t,
|
||||
cutlass::minus,
|
||||
float,
|
||||
int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAzp =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Accum, AzpWithAdj>;
|
||||
|
||||
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::multiplies,
|
||||
float,
|
||||
float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeScaleB =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::multiply_add,
|
||||
ElementD,
|
||||
float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
|
||||
EVTComputeScaleB, Bias>;
|
||||
public:
|
||||
using EVTCompute = cutlass::epilogue::fusion::
|
||||
Sm90EVT<ComputeScaleBiasA, ScaleA, EVTComputeScaleB, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType
|
||||
prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales,
|
||||
paddle::Tensor const &azp_adj,
|
||||
paddle::optional<paddle::Tensor> const &bias) {
|
||||
static ArgumentType prepare_args(
|
||||
paddle::Tensor const &a_scales,
|
||||
paddle::Tensor const &b_scales,
|
||||
paddle::Tensor const &azp_adj,
|
||||
paddle::optional<paddle::Tensor> const &bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@@ -340,7 +385,7 @@ public:
|
||||
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||
struct ScaledEpilogueBiasAzpToken
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
|
||||
private:
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
@@ -355,7 +400,9 @@ private:
|
||||
|
||||
// Compute azp * azp_adj
|
||||
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, int32_t, int32_t,
|
||||
cutlass::multiplies,
|
||||
int32_t,
|
||||
int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAzp =
|
||||
@@ -363,33 +410,40 @@ private:
|
||||
|
||||
// Compute float(accum - azp*azp_adj), all operands are int32_t
|
||||
using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::minus, float, int32_t,
|
||||
cutlass::minus,
|
||||
float,
|
||||
int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAcc =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeAcc, Accum, EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::multiplies,
|
||||
float,
|
||||
float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeScaleB =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::multiply_add,
|
||||
ElementD,
|
||||
float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
|
||||
EVTComputeScaleB, Bias>;
|
||||
public:
|
||||
using EVTCompute = cutlass::epilogue::fusion::
|
||||
Sm90EVT<ComputeScaleBiasA, ScaleA, EVTComputeScaleB, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType
|
||||
prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales,
|
||||
paddle::Tensor const &azp_adj, paddle::Tensor const &azp,
|
||||
paddle::optional<paddle::Tensor> const &bias) {
|
||||
static ArgumentType prepare_args(
|
||||
paddle::Tensor const &a_scales,
|
||||
paddle::Tensor const &b_scales,
|
||||
paddle::Tensor const &azp_adj,
|
||||
paddle::Tensor const &azp,
|
||||
paddle::optional<paddle::Tensor> const &bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@@ -414,24 +468,28 @@ public:
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
struct ScaledEpilogueArray
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
||||
private:
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoadArray<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoadArray<float>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::multiplies,
|
||||
float,
|
||||
float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, ElementD, float,
|
||||
cutlass::multiplies,
|
||||
ElementD,
|
||||
float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
@@ -441,7 +499,8 @@ public:
|
||||
|
||||
static ArgumentType prepare_args(float const *const *a_scales_ptr,
|
||||
float const *const *b_scales_ptr,
|
||||
bool a_col_broadcast, bool b_row_broadcast) {
|
||||
bool a_col_broadcast,
|
||||
bool b_row_broadcast) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleAArray, float>(
|
||||
a_scales_ptr, a_col_broadcast);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleBArray, float>(
|
||||
@@ -452,4 +511,4 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
}; // namespace fastdeploy::c3x
|
||||
}; // namespace fastdeploy::c3x
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
@@ -18,18 +18,20 @@
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Functor performing linear combination with a maximum operation used by epilogues.
|
||||
\brief Functor performing linear combination with a maximum operation used by
|
||||
epilogues.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
@@ -46,60 +48,53 @@
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace epilogue
|
||||
{
|
||||
namespace thread
|
||||
{
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace thread {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
__forceinline__ __device__ float copysignf_pos(float a, float b)
|
||||
{
|
||||
float r;
|
||||
r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));
|
||||
return r;
|
||||
__forceinline__ __device__ float copysignf_pos(float a, float b) {
|
||||
float r;
|
||||
r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));
|
||||
return r;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ float tanh_opt(float x)
|
||||
{
|
||||
__forceinline__ __device__ float tanh_opt(float x) {
|
||||
#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750)
|
||||
float const exp_val = -1.f * fabs(2 * x);
|
||||
return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
|
||||
float const exp_val = -1.f * fabs(2 * x);
|
||||
return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
|
||||
#else
|
||||
return fast_tanh(x);
|
||||
return fast_tanh(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template <>
|
||||
struct GELU_taylor<float>
|
||||
{
|
||||
static bool const kIsHeavy = true;
|
||||
struct GELU_taylor<float> {
|
||||
static bool const kIsHeavy = true;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
float operator()(float const& z) const
|
||||
{
|
||||
CUTLASS_DEVICE
|
||||
float operator()(float const& z) const {
|
||||
float k0 = float(0.7978845608028654);
|
||||
float k1 = float(0.044715);
|
||||
|
||||
float k0 = float(0.7978845608028654);
|
||||
float k1 = float(0.044715);
|
||||
return float(
|
||||
cutlass::constants::half<float>() * z *
|
||||
(cutlass::constants::one<float>() +
|
||||
tanh_opt(k0 * z * (cutlass::constants::one<float>() + k1 * z * z))));
|
||||
}
|
||||
|
||||
return float(cutlass::constants::half<float>() * z
|
||||
* (cutlass::constants::one<float>() + tanh_opt(k0 * z * (cutlass::constants::one<float>() + k1 * z * z))));
|
||||
}
|
||||
using Params = LinearCombinationGenericParams<float>;
|
||||
|
||||
using Params = LinearCombinationGenericParams<float>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
float operator()(float const& scalar, Params const& params_) const
|
||||
{
|
||||
return this->operator()(scalar);
|
||||
}
|
||||
CUTLASS_DEVICE
|
||||
float operator()(float const& scalar, Params const& params_) const {
|
||||
return this->operator()(scalar);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace thread
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
} // namespace thread
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
+289
-279
@@ -1,12 +1,12 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
@@ -18,20 +18,23 @@
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one scaling factor per row, and one per column.
|
||||
\brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one
|
||||
scaling factor per row, and one per column.
|
||||
|
||||
original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h
|
||||
original file:
|
||||
3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h
|
||||
|
||||
*/
|
||||
|
||||
@@ -46,305 +49,312 @@
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "common/quantization.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace epilogue
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
template <typename ThreadblockShape_, int ThreadCount, typename ScaleTileIterator_, typename OutputTileIterator_,
|
||||
typename ElementAccumulator_, typename ElementCompute_, typename ElementwiseFunctor_, bool UseMasking_ = false>
|
||||
class EpilogueVisitorPerRowPerCol
|
||||
{
|
||||
public:
|
||||
using ThreadblockShape = ThreadblockShape_;
|
||||
static int const kThreadCount = ThreadCount;
|
||||
template <typename ThreadblockShape_,
|
||||
int ThreadCount,
|
||||
typename ScaleTileIterator_,
|
||||
typename OutputTileIterator_,
|
||||
typename ElementAccumulator_,
|
||||
typename ElementCompute_,
|
||||
typename ElementwiseFunctor_,
|
||||
bool UseMasking_ = false>
|
||||
class EpilogueVisitorPerRowPerCol {
|
||||
public:
|
||||
using ThreadblockShape = ThreadblockShape_;
|
||||
static int const kThreadCount = ThreadCount;
|
||||
|
||||
using ScaleTileIterator = ScaleTileIterator_;
|
||||
using OutputTileIterator = OutputTileIterator_;
|
||||
using ElementwiseFunctor = ElementwiseFunctor_;
|
||||
using ScaleTileIterator = ScaleTileIterator_;
|
||||
using OutputTileIterator = OutputTileIterator_;
|
||||
using ElementwiseFunctor = ElementwiseFunctor_;
|
||||
|
||||
static int const kIterations = OutputTileIterator::kIterations;
|
||||
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
|
||||
static int const kIterations = OutputTileIterator::kIterations;
|
||||
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
using ElementOutput = typename OutputTileIterator::Element;
|
||||
using LayoutOutput = cutlass::layout::RowMajor;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
using ElementOutput = typename OutputTileIterator::Element;
|
||||
using LayoutOutput = cutlass::layout::RowMajor;
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
|
||||
using AlphaScaleElementType = typename ScaleTileIterator::Element;
|
||||
using AlphaScaleElementType = typename ScaleTileIterator::Element;
|
||||
|
||||
using ElementCompute = ElementCompute_;
|
||||
using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
using ComputeFragment = Array<ElementCompute_, kElementsPerAccess>;
|
||||
using OutputVector = Array<ElementOutput, kElementsPerAccess>;
|
||||
using ElementCompute = ElementCompute_;
|
||||
using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
|
||||
using ComputeFragment = Array<ElementCompute_, kElementsPerAccess>;
|
||||
using OutputVector = Array<ElementOutput, kElementsPerAccess>;
|
||||
|
||||
static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
|
||||
static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1);
|
||||
static int const kThreadsPerRow =
|
||||
OutputTileIterator::ThreadMap::Detail::kAccessWidth;
|
||||
static bool const kHasMultiStepsInRow =
|
||||
(OutputTileIterator::ThreadMap::Iterations::kColumn > 1);
|
||||
|
||||
/// Argument structure
|
||||
struct Arguments
|
||||
{
|
||||
/// Argument structure
|
||||
struct Arguments {
|
||||
typename ElementwiseFunctor::Params elementwise;
|
||||
int64_t batch_stride_alpha;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
typename ElementwiseFunctor::Params elementwise;
|
||||
int64_t batch_stride_alpha;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
Arguments() : batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
Arguments()
|
||||
: batch_stride_alpha(0)
|
||||
, batch_stride_C(0)
|
||||
, batch_stride_D(0)
|
||||
{
|
||||
}
|
||||
Arguments(typename ElementwiseFunctor::Params elementwise_)
|
||||
: elementwise(elementwise_),
|
||||
batch_stride_alpha(0),
|
||||
batch_stride_C(0),
|
||||
batch_stride_D(0) {}
|
||||
|
||||
Arguments(typename ElementwiseFunctor::Params elementwise_)
|
||||
: elementwise(elementwise_)
|
||||
, batch_stride_alpha(0)
|
||||
, batch_stride_C(0)
|
||||
, batch_stride_D(0)
|
||||
{
|
||||
}
|
||||
Arguments(typename ElementwiseFunctor::Params elementwise_,
|
||||
int64_t batch_stride_alpha_,
|
||||
int64_t batch_stride_C_,
|
||||
int64_t batch_stride_D_)
|
||||
: elementwise(elementwise_),
|
||||
batch_stride_alpha(batch_stride_alpha_),
|
||||
batch_stride_C(batch_stride_C_),
|
||||
batch_stride_D(batch_stride_D_) {}
|
||||
};
|
||||
|
||||
Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_,
|
||||
int64_t batch_stride_C_, int64_t batch_stride_D_)
|
||||
: elementwise(elementwise_)
|
||||
, batch_stride_alpha(batch_stride_alpha_)
|
||||
, batch_stride_C(batch_stride_C_)
|
||||
, batch_stride_D(batch_stride_D_)
|
||||
{
|
||||
}
|
||||
};
|
||||
struct Params {
|
||||
typename ElementwiseFunctor::Params elementwise;
|
||||
int64_t batch_stride_alpha;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
|
||||
struct Params
|
||||
{
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() {}
|
||||
|
||||
typename ElementwiseFunctor::Params elementwise;
|
||||
int64_t batch_stride_alpha;
|
||||
int64_t batch_stride_C;
|
||||
int64_t batch_stride_D;
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const& args)
|
||||
: elementwise(args.elementwise),
|
||||
batch_stride_alpha(args.batch_stride_alpha),
|
||||
batch_stride_C(args.batch_stride_C),
|
||||
batch_stride_D(args.batch_stride_D) {}
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() {}
|
||||
/// Shared storage
|
||||
struct SharedStorage {};
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const& args)
|
||||
: elementwise(args.elementwise)
|
||||
, batch_stride_alpha(args.batch_stride_alpha)
|
||||
, batch_stride_C(args.batch_stride_C)
|
||||
, batch_stride_D(args.batch_stride_D)
|
||||
{
|
||||
}
|
||||
};
|
||||
private:
|
||||
Params const& params_;
|
||||
SharedStorage& shared_storage_;
|
||||
MatrixCoord extent_;
|
||||
MatrixCoord extent_real_;
|
||||
ElementwiseFunctor elementwise_;
|
||||
|
||||
/// Shared storage
|
||||
struct SharedStorage
|
||||
{
|
||||
};
|
||||
bool const per_token_quant_;
|
||||
bool const per_channel_quant_;
|
||||
|
||||
private:
|
||||
Params const& params_;
|
||||
SharedStorage& shared_storage_;
|
||||
MatrixCoord extent_;
|
||||
MatrixCoord extent_real_;
|
||||
ElementwiseFunctor elementwise_;
|
||||
AlphaScaleElementType* ptr_alpha_row_;
|
||||
AlphaScaleElementType* ptr_alpha_col_;
|
||||
ScaleTileIterator iterator_alpha_col_;
|
||||
OutputTileIterator iterator_C_;
|
||||
OutputTileIterator iterator_D_;
|
||||
|
||||
bool const per_token_quant_;
|
||||
bool const per_channel_quant_;
|
||||
AlphaScaleElementType element_alpha_row_ = 1.0f;
|
||||
AlphaScaleElementType element_alpha_col_ = 1.0f;
|
||||
typename ScaleTileIterator::Fragment fragment_alpha_col_;
|
||||
typename OutputTileIterator::Fragment fragment_C_;
|
||||
typename OutputTileIterator::Fragment fragment_D_;
|
||||
|
||||
AlphaScaleElementType* ptr_alpha_row_;
|
||||
AlphaScaleElementType* ptr_alpha_col_;
|
||||
ScaleTileIterator iterator_alpha_col_;
|
||||
OutputTileIterator iterator_C_;
|
||||
OutputTileIterator iterator_D_;
|
||||
ElementAccumulator beta_;
|
||||
|
||||
AlphaScaleElementType element_alpha_row_ = 1.0f;
|
||||
AlphaScaleElementType element_alpha_col_ = 1.0f;
|
||||
typename ScaleTileIterator::Fragment fragment_alpha_col_;
|
||||
typename OutputTileIterator::Fragment fragment_C_;
|
||||
typename OutputTileIterator::Fragment fragment_D_;
|
||||
int column_offset_;
|
||||
|
||||
ElementAccumulator beta_;
|
||||
MatrixCoord thread_offset_;
|
||||
|
||||
int column_offset_;
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
EpilogueVisitorPerRowPerCol(
|
||||
Params const& params,
|
||||
SharedStorage& shared_storage,
|
||||
cutlass::MatrixCoord const& problem_size,
|
||||
int thread_idx,
|
||||
int warp_idx,
|
||||
int lane_idx,
|
||||
typename ScaleTileIterator::Params params_alpha_col,
|
||||
typename OutputTileIterator::Params params_C,
|
||||
typename OutputTileIterator::Params params_D,
|
||||
common::QuantMode quant_option,
|
||||
AlphaScaleElementType* ptr_alpha_row,
|
||||
AlphaScaleElementType* ptr_alpha_col,
|
||||
typename OutputTileIterator::Element* ptr_C,
|
||||
typename OutputTileIterator::Element* ptr_D,
|
||||
cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0,
|
||||
0),
|
||||
int column_offset = 0,
|
||||
cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0,
|
||||
0))
|
||||
: params_(params),
|
||||
shared_storage_(shared_storage),
|
||||
extent_(problem_size),
|
||||
elementwise_(params.elementwise),
|
||||
per_token_quant_(quant_option.hasPerTokenScaling()),
|
||||
per_channel_quant_(quant_option.hasPerChannelScaling()),
|
||||
ptr_alpha_row_(ptr_alpha_row),
|
||||
ptr_alpha_col_(ptr_alpha_col),
|
||||
iterator_alpha_col_(params_alpha_col,
|
||||
ptr_alpha_col,
|
||||
problem_size,
|
||||
thread_idx,
|
||||
threadblock_offset),
|
||||
iterator_C_(
|
||||
params_C, ptr_C, problem_size, thread_idx, threadblock_offset),
|
||||
iterator_D_(
|
||||
params_D, ptr_D, problem_size, thread_idx, threadblock_offset),
|
||||
extent_real_(problem_size_real) {
|
||||
beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr
|
||||
: params.elementwise.beta);
|
||||
|
||||
MatrixCoord thread_offset_;
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage,
|
||||
cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx,
|
||||
typename ScaleTileIterator::Params params_alpha_col, typename OutputTileIterator::Params params_C,
|
||||
typename OutputTileIterator::Params params_D, common::QuantMode quant_option, AlphaScaleElementType* ptr_alpha_row,
|
||||
AlphaScaleElementType* ptr_alpha_col, typename OutputTileIterator::Element* ptr_C,
|
||||
typename OutputTileIterator::Element* ptr_D,
|
||||
cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), int column_offset = 0,
|
||||
cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0))
|
||||
: params_(params)
|
||||
, shared_storage_(shared_storage)
|
||||
, extent_(problem_size)
|
||||
, elementwise_(params.elementwise)
|
||||
, per_token_quant_(quant_option.hasPerTokenScaling())
|
||||
, per_channel_quant_(quant_option.hasPerChannelScaling())
|
||||
, ptr_alpha_row_(ptr_alpha_row)
|
||||
, ptr_alpha_col_(ptr_alpha_col)
|
||||
, iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset)
|
||||
, iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset)
|
||||
, iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset)
|
||||
, extent_real_(problem_size_real)
|
||||
{
|
||||
beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta);
|
||||
|
||||
if (beta_ == ElementAccumulator())
|
||||
{
|
||||
iterator_C_.clear_mask();
|
||||
}
|
||||
|
||||
if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr))
|
||||
{
|
||||
element_alpha_col_ = *ptr_alpha_col_;
|
||||
}
|
||||
|
||||
if (!per_token_quant_ && (ptr_alpha_row_ != nullptr))
|
||||
{
|
||||
element_alpha_row_ = *ptr_alpha_row_;
|
||||
}
|
||||
if (beta_ == ElementAccumulator()) {
|
||||
iterator_C_.clear_mask();
|
||||
}
|
||||
|
||||
/// Helper to indicate split-K behavior
|
||||
CUTLASS_DEVICE
|
||||
void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
|
||||
int split_k_slices)
|
||||
{ ///< Total number of split-K slices
|
||||
if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) {
|
||||
element_alpha_col_ = *ptr_alpha_col_;
|
||||
}
|
||||
|
||||
/// Called to set the batch index
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx)
|
||||
{
|
||||
iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha);
|
||||
iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
|
||||
iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
|
||||
if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) {
|
||||
element_alpha_row_ = *ptr_alpha_row_;
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to indicate split-K behavior
|
||||
CUTLASS_DEVICE
|
||||
void set_k_partition(
|
||||
int split_k_index, ///< Index of this threadblock within split-K
|
||||
///< partitioned scheme
|
||||
int split_k_slices) { ///< Total number of split-K slices
|
||||
}
|
||||
|
||||
/// Called to set the batch index
|
||||
CUTLASS_DEVICE
|
||||
void set_batch_index(int batch_idx) {
|
||||
iterator_alpha_col_.add_pointer_offset(batch_idx *
|
||||
params_.batch_stride_alpha);
|
||||
iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
|
||||
iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
|
||||
}
|
||||
|
||||
/// Called at the start of the epilogue just before iterating over accumulator
|
||||
/// slices
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue() {
|
||||
if (per_channel_quant_) {
|
||||
iterator_alpha_col_.load(fragment_alpha_col_);
|
||||
}
|
||||
}
|
||||
|
||||
/// Called at the start of one step before starting accumulator exchange
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx) {
|
||||
fragment_D_.clear();
|
||||
fragment_C_.clear();
|
||||
|
||||
if (elementwise_.kScale !=
|
||||
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
|
||||
iterator_C_.load(fragment_C_);
|
||||
++iterator_C_;
|
||||
}
|
||||
}
|
||||
|
||||
/// Called at the start of a row
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx) {
|
||||
// load alpha_row in begin_step only when per token(row) scaling is used
|
||||
if (per_token_quant_) {
|
||||
int thread_offset_row =
|
||||
iterator_D_.thread_start_row() +
|
||||
OutputTileIterator::ThreadMap::iteration_offset(row_idx).row();
|
||||
|
||||
arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
|
||||
element_alpha_row_,
|
||||
ptr_alpha_row_ + thread_offset_row,
|
||||
thread_offset_row < extent_.row());
|
||||
}
|
||||
}
|
||||
|
||||
/// Called after accumulators have been exchanged for each accumulator vector
|
||||
CUTLASS_DEVICE
|
||||
void visit(int iter_idx,
|
||||
int row_idx,
|
||||
int column_idx,
|
||||
int frag_idx,
|
||||
AccumulatorFragment const& accum) {
|
||||
NumericArrayConverter<ElementCompute,
|
||||
ElementAccumulator,
|
||||
kElementsPerAccess>
|
||||
source_converter;
|
||||
|
||||
ComputeFragment result = source_converter(accum);
|
||||
if (per_channel_quant_) {
|
||||
ComputeFragment alpha_col =
|
||||
reinterpret_cast<ComputeFragment*>(&fragment_alpha_col_)[column_idx];
|
||||
result = per_token_channel_scale_accumulator_(
|
||||
result, alpha_col, element_alpha_row_);
|
||||
} else {
|
||||
result = per_token_scale_accumulator_(
|
||||
result, element_alpha_col_, element_alpha_row_);
|
||||
}
|
||||
|
||||
/// Called at the start of the epilogue just before iterating over accumulator slices
|
||||
CUTLASS_DEVICE
|
||||
void begin_epilogue()
|
||||
{
|
||||
if (per_channel_quant_)
|
||||
{
|
||||
iterator_alpha_col_.load(fragment_alpha_col_);
|
||||
}
|
||||
// Convert to the output
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess>
|
||||
output_converter;
|
||||
OutputVector& output =
|
||||
reinterpret_cast<OutputVector*>(&fragment_D_)[frag_idx];
|
||||
output = output_converter(result);
|
||||
}
|
||||
|
||||
/// Called at the end of a row
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) {}
|
||||
|
||||
/// Called after all accumulator elements have been visited
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx) {
|
||||
iterator_D_.store(fragment_D_);
|
||||
++iterator_D_;
|
||||
}
|
||||
|
||||
/// Called after all steps have been completed
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() {}
|
||||
|
||||
private:
|
||||
CUTLASS_DEVICE
|
||||
ComputeFragment per_token_channel_scale_accumulator_(
|
||||
ComputeFragment const& accum,
|
||||
ComputeFragment const& scale_col,
|
||||
AlphaScaleElementType const& scale_row) {
|
||||
ComputeFragment result;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < ComputeFragment::kElements; ++i) {
|
||||
result[i] = accum[i] * (scale_col[i] * scale_row);
|
||||
}
|
||||
|
||||
/// Called at the start of one step before starting accumulator exchange
|
||||
CUTLASS_DEVICE
|
||||
void begin_step(int step_idx)
|
||||
{
|
||||
fragment_D_.clear();
|
||||
fragment_C_.clear();
|
||||
return result;
|
||||
}
|
||||
|
||||
if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling)
|
||||
{
|
||||
iterator_C_.load(fragment_C_);
|
||||
++iterator_C_;
|
||||
}
|
||||
CUTLASS_DEVICE
|
||||
ComputeFragment per_token_scale_accumulator_(
|
||||
ComputeFragment const& accum,
|
||||
AlphaScaleElementType const& scale_col,
|
||||
AlphaScaleElementType const& scale_row) {
|
||||
ComputeFragment result;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < ComputeFragment::kElements; ++i) {
|
||||
result[i] = accum[i] * (scale_col * scale_row);
|
||||
}
|
||||
|
||||
/// Called at the start of a row
|
||||
CUTLASS_DEVICE
|
||||
void begin_row(int row_idx)
|
||||
{
|
||||
// load alpha_row in begin_step only when per token(row) scaling is used
|
||||
if (per_token_quant_)
|
||||
{
|
||||
int thread_offset_row
|
||||
= iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row();
|
||||
|
||||
arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
|
||||
element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row());
|
||||
}
|
||||
}
|
||||
|
||||
/// Called after accumulators have been exchanged for each accumulator vector
|
||||
CUTLASS_DEVICE
|
||||
void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum)
|
||||
{
|
||||
|
||||
NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess> source_converter;
|
||||
|
||||
ComputeFragment result = source_converter(accum);
|
||||
if (per_channel_quant_)
|
||||
{
|
||||
ComputeFragment alpha_col = reinterpret_cast<ComputeFragment*>(&fragment_alpha_col_)[column_idx];
|
||||
result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_);
|
||||
}
|
||||
else
|
||||
{
|
||||
result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_);
|
||||
}
|
||||
|
||||
// Convert to the output
|
||||
NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> output_converter;
|
||||
OutputVector& output = reinterpret_cast<OutputVector*>(&fragment_D_)[frag_idx];
|
||||
output = output_converter(result);
|
||||
}
|
||||
|
||||
/// Called at the end of a row
|
||||
CUTLASS_DEVICE
|
||||
void end_row(int row_idx) {}
|
||||
|
||||
/// Called after all accumulator elements have been visited
|
||||
CUTLASS_DEVICE
|
||||
void end_step(int step_idx)
|
||||
{
|
||||
|
||||
iterator_D_.store(fragment_D_);
|
||||
++iterator_D_;
|
||||
}
|
||||
|
||||
/// Called after all steps have been completed
|
||||
CUTLASS_DEVICE
|
||||
void end_epilogue() {}
|
||||
|
||||
private:
|
||||
CUTLASS_DEVICE
|
||||
ComputeFragment per_token_channel_scale_accumulator_(
|
||||
ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row)
|
||||
{
|
||||
|
||||
ComputeFragment result;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < ComputeFragment::kElements; ++i)
|
||||
{
|
||||
result[i] = accum[i] * (scale_col[i] * scale_row);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ComputeFragment per_token_scale_accumulator_(
|
||||
ComputeFragment const& accum, AlphaScaleElementType const& scale_col, AlphaScaleElementType const& scale_row)
|
||||
{
|
||||
|
||||
ComputeFragment result;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < ComputeFragment::kElements; ++i)
|
||||
{
|
||||
result[i] = accum[i] * (scale_col * scale_row);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
+163
-158
@@ -1,12 +1,12 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
@@ -18,23 +18,26 @@
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
|
||||
|
||||
The epilogue rearranges the result of a matrix product through shared memory to match canonical
|
||||
tensor layouts in global memory. Epilogues support conversion and reduction operations.
|
||||
The epilogue rearranges the result of a matrix product through shared memory
|
||||
to match canonical tensor layouts in global memory. Epilogues support
|
||||
conversion and reduction operations.
|
||||
|
||||
original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h
|
||||
original file:
|
||||
3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h
|
||||
|
||||
*/
|
||||
|
||||
@@ -80,35 +83,45 @@
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace epilogue
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
namespace cutlass {
|
||||
namespace epilogue {
|
||||
namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail
|
||||
{
|
||||
namespace detail {
|
||||
|
||||
/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts.
|
||||
template <typename ThreadblockShape, typename WarpShape, typename InstructionShape, typename ThreadMap>
|
||||
struct DefaultIteratorsTensorOp<cutlass::bfloat16_t, int32_t, 8, ThreadblockShape, WarpShape, InstructionShape,
|
||||
ThreadMap>
|
||||
{
|
||||
using WarpTileIterator
|
||||
= cutlass::epilogue::warp::TileIteratorTensorOpMixed<WarpShape, InstructionShape, int32_t, 32, 16, 8, 8>;
|
||||
/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared
|
||||
/// memory bank conflicts.
|
||||
template <typename ThreadblockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename ThreadMap>
|
||||
struct DefaultIteratorsTensorOp<cutlass::bfloat16_t,
|
||||
int32_t,
|
||||
8,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
ThreadMap> {
|
||||
using WarpTileIterator =
|
||||
cutlass::epilogue::warp::TileIteratorTensorOpMixed<WarpShape,
|
||||
InstructionShape,
|
||||
int32_t,
|
||||
32,
|
||||
16,
|
||||
8,
|
||||
8>;
|
||||
|
||||
using SharedLoadIterator
|
||||
= cutlass::epilogue::threadblock::SharedLoadIteratorMixed<ThreadMap, int32_t, 32, 16, 8, 8>;
|
||||
using SharedLoadIterator = cutlass::epilogue::threadblock::
|
||||
SharedLoadIteratorMixed<ThreadMap, int32_t, 32, 16, 8, 8>;
|
||||
|
||||
static int const kFragmentsPerIteration = 2;
|
||||
static int const kFragmentsPerIteration = 2;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace detail
|
||||
} // namespace detail
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -116,167 +129,159 @@ struct DefaultIteratorsTensorOp<cutlass::bfloat16_t, int32_t, 8, ThreadblockShap
|
||||
///
|
||||
/// Satisfies: ReadableTileIterator
|
||||
///
|
||||
template <typename ThreadMap_ ///< Thread map (concept: OutputTileThreadMap)
|
||||
>
|
||||
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8>
|
||||
{
|
||||
public:
|
||||
using ThreadMap = ThreadMap_;
|
||||
using Shape = typename ThreadMap::Shape;
|
||||
template <typename ThreadMap_ ///< Thread map (concept: OutputTileThreadMap)
|
||||
>
|
||||
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8> {
|
||||
public:
|
||||
using ThreadMap = ThreadMap_;
|
||||
using Shape = typename ThreadMap::Shape;
|
||||
|
||||
using Element = int32_t;
|
||||
using Element = int32_t;
|
||||
|
||||
using Layout = layout::RowMajor;
|
||||
using TensorRef = TensorRef<Element, Layout>;
|
||||
using ConstTensorRef = typename TensorRef::ConstTensorRef;
|
||||
using Layout = layout::RowMajor;
|
||||
using TensorRef = TensorRef<Element, Layout>;
|
||||
using ConstTensorRef = typename TensorRef::ConstTensorRef;
|
||||
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
using TensorCoord = MatrixCoord;
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
using TensorCoord = MatrixCoord;
|
||||
|
||||
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
|
||||
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
|
||||
|
||||
static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value / 8;
|
||||
static int const kAlignment =
|
||||
ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value / 8;
|
||||
|
||||
static int const kThreads = ThreadMap::kThreads;
|
||||
static int const kThreads = ThreadMap::kThreads;
|
||||
|
||||
/// Fragment object
|
||||
using Fragment = Array<Element,
|
||||
ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow * ThreadMap::Iterations::kGroup
|
||||
* ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>;
|
||||
/// Fragment object
|
||||
using Fragment =
|
||||
Array<Element,
|
||||
ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow *
|
||||
ThreadMap::Iterations::kGroup *
|
||||
ThreadMap::Iterations::kCluster *
|
||||
ThreadMap::kElementsPerAccess>;
|
||||
|
||||
/// Memory access size
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess, kAlignment>;
|
||||
/// Memory access size
|
||||
using AccessType =
|
||||
AlignedArray<Element, ThreadMap::kElementsPerAccess, kAlignment>;
|
||||
|
||||
/// Vector type used for SMEM loads
|
||||
using LoadType = AlignedArray<Element, const_min(128 / sizeof_bits<Element>::value, ThreadMap::kElementsPerAccess),
|
||||
const_min(16, kAlignment)>;
|
||||
/// Vector type used for SMEM loads
|
||||
using LoadType = AlignedArray<Element,
|
||||
const_min(128 / sizeof_bits<Element>::value,
|
||||
ThreadMap::kElementsPerAccess),
|
||||
const_min(16, kAlignment)>;
|
||||
|
||||
static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements;
|
||||
static int const kLoadsPerAccess =
|
||||
AccessType::kElements / LoadType::kElements;
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Byte-level pointer
|
||||
LoadType const* pointers_[kLoadsPerAccess];
|
||||
/// Byte-level pointer
|
||||
LoadType const* pointers_[kLoadsPerAccess];
|
||||
|
||||
/// Stride along adjacent rows in units of LoadType
|
||||
int stride_;
|
||||
/// Stride along adjacent rows in units of LoadType
|
||||
int stride_;
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_DEVICE
|
||||
SharedLoadIteratorMixed(TensorRef ref, int thread_idx)
|
||||
: stride_((ref.stride(0) / LoadType::kElements))
|
||||
{
|
||||
/// Constructor
|
||||
CUTLASS_DEVICE
|
||||
SharedLoadIteratorMixed(TensorRef ref, int thread_idx)
|
||||
: stride_((ref.stride(0) / LoadType::kElements)) {
|
||||
TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx);
|
||||
|
||||
TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx);
|
||||
// Initialize pointers
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kLoadsPerAccess; ++i) {
|
||||
pointers_[i] = reinterpret_cast<LoadType const*>(ref.data());
|
||||
|
||||
// Initialize pointers
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kLoadsPerAccess; ++i)
|
||||
{
|
||||
pointers_[i] = reinterpret_cast<LoadType const*>(ref.data());
|
||||
int col_idx =
|
||||
(thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess;
|
||||
int bank_offset =
|
||||
(col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess;
|
||||
|
||||
int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess;
|
||||
int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess;
|
||||
col_idx += (bank_offset + i) % kLoadsPerAccess;
|
||||
|
||||
col_idx += (bank_offset + i) % kLoadsPerAccess;
|
||||
|
||||
pointers_[i] += thread_offset.row() * stride_ + col_idx;
|
||||
}
|
||||
pointers_[i] += thread_offset.row() * stride_ + col_idx;
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
CUTLASS_HOST_DEVICE
|
||||
void add_pointer_offset(LongIndex pointer_offset)
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kLoadsPerAccess; ++i)
|
||||
{
|
||||
pointers_[i] += pointer_offset / LoadType::kElements;
|
||||
}
|
||||
/// Adds a pointer offset in units of Element
|
||||
CUTLASS_HOST_DEVICE
|
||||
void add_pointer_offset(LongIndex pointer_offset) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kLoadsPerAccess; ++i) {
|
||||
pointers_[i] += pointer_offset / LoadType::kElements;
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void add_tile_offset(TensorCoord const& offset)
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kLoadsPerAccess; ++i)
|
||||
{
|
||||
pointers_[i]
|
||||
+= offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements;
|
||||
}
|
||||
CUTLASS_DEVICE
|
||||
void add_tile_offset(TensorCoord const& offset) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kLoadsPerAccess; ++i) {
|
||||
pointers_[i] += offset.row() * Shape::kRow * stride_ +
|
||||
offset.column() * Shape::kColumn / LoadType::kElements;
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory
|
||||
CUTLASS_DEVICE
|
||||
void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const
|
||||
{
|
||||
|
||||
/// Loads a fragment from memory
|
||||
CUTLASS_DEVICE
|
||||
void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
|
||||
++cluster) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster)
|
||||
{
|
||||
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
|
||||
int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_ +
|
||||
group * ThreadMap::Delta::kGroup * stride_ +
|
||||
cluster * ThreadMap::Delta::kCluster * stride_ +
|
||||
pointer_offset / LoadType::kElements;
|
||||
|
||||
int frag_row_idx =
|
||||
(row + ThreadMap::Iterations::kRow *
|
||||
(group + ThreadMap::Iterations::kGroup * cluster));
|
||||
|
||||
LoadType* frag_ptr = reinterpret_cast<LoadType*>(&frag);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int column = 0; column < ThreadMap::Iterations::kColumn;
|
||||
++column) {
|
||||
int frag_idx =
|
||||
frag_row_idx * ThreadMap::Iterations::kColumn + column;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group)
|
||||
{
|
||||
for (int v = 0; v < kLoadsPerAccess; ++v) {
|
||||
int vector_idx = (column * ThreadMap::Delta::kColumn /
|
||||
kElementsPerAccess * kLoadsPerAccess);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row)
|
||||
{
|
||||
LoadType const* memory_pointer = pointers_[v] + row_ptr_offset;
|
||||
|
||||
int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_
|
||||
+ group * ThreadMap::Delta::kGroup * stride_ + cluster * ThreadMap::Delta::kCluster * stride_
|
||||
+ pointer_offset / LoadType::kElements;
|
||||
|
||||
int frag_row_idx
|
||||
= (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
|
||||
|
||||
LoadType* frag_ptr = reinterpret_cast<LoadType*>(&frag);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column)
|
||||
{
|
||||
|
||||
int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < kLoadsPerAccess; ++v)
|
||||
{
|
||||
|
||||
int vector_idx
|
||||
= (column * ThreadMap::Delta::kColumn / kElementsPerAccess * kLoadsPerAccess);
|
||||
|
||||
LoadType const* memory_pointer = pointers_[v] + row_ptr_offset;
|
||||
|
||||
frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
frag_ptr[frag_idx * kLoadsPerAccess + v] =
|
||||
memory_pointer[vector_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads a fragment
|
||||
CUTLASS_DEVICE
|
||||
void load(Fragment& frag) const
|
||||
{
|
||||
|
||||
load_with_pointer_offset(frag, 0);
|
||||
}
|
||||
/// Loads a fragment
|
||||
CUTLASS_DEVICE
|
||||
void load(Fragment& frag) const { load_with_pointer_offset(frag, 0); }
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
} // namespace threadblock
|
||||
} // namespace epilogue
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
Reference in New Issue
Block a user