Co-authored-by: gongweibao <gognweibao@baidu.com>
This commit is contained in:
gongweibao
2026-03-04 21:55:31 +08:00
committed by GitHub
parent 5c8f5184d9
commit ddb06ff83f
306 changed files with 40627 additions and 34418 deletions
@@ -23,129 +23,160 @@
using namespace cute;
// Shared-memory storage for the kernel.
//
// The A / E / B staging buffers are only needed while the mainloop is
// producing accumulators, and the C buffer is only needed afterwards for the
// epilogue store, so they alias each other through an anonymous union to cut
// the total smem footprint. The pipeline barrier storage lives outside the
// union because it must stay valid for the whole kernel.
//
// NOTE(review): this block appeared twice in the pasted diff (pre- and
// post-clang-format); the duplicate declarations have been collapsed into the
// single intended definition.
template <int kStages,
          class GemmType,
          class OutputType,
          class SmemLayoutA,
          class SmemLayoutE,
          class SmemLayoutB,
          class SmemLayoutC>
struct SharedStorage {
  union {
    // Mainloop operands: A tiles, metadata E (uint32_t words — presumably
    // sparse-MMA metadata; confirm against the producer), and B tiles.
    struct {
      cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutA>> smem_a;
      cute::array_aligned<uint32_t, cute::cosize_v<SmemLayoutE>> smem_e;
      cute::array_aligned<GemmType, cute::cosize_v<SmemLayoutB>> smem_b;
    };
    // Epilogue output tile; reuses the operand storage after the MMAs finish.
    cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutC>> smem_c;
  };
  // TMA producer/consumer pipeline barriers — must not alias the data buffers.
  struct {
    typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline;
  };
};
template<int kBlockM_, int kBlockN_, int kBlockK_,
int kNWarps_, int kStages_,
int kTiles_, int M_,
int TokenPackSize_,
int TAIL_N_ = 0,
int kClusterM_ = 1,
typename elem_type=cutlass::float_e4m3_t,
typename OutputType = cutlass::bfloat16_t>
template <int kBlockM_,
int kBlockN_,
int kBlockK_,
int kNWarps_,
int kStages_,
int kTiles_,
int M_,
int TokenPackSize_,
int TAIL_N_ = 0,
int kClusterM_ = 1,
typename elem_type = cutlass::float_e4m3_t,
typename OutputType = cutlass::bfloat16_t>
struct Kernel_traits {
using Element = elem_type;
using ElementAccum = float;
using ElementOutput = OutputType;
static_assert(cutlass::sizeof_bits_v<Element> == 8);
using Element = elem_type;
using ElementAccum = float;
using ElementOutput = OutputType;
static_assert(cutlass::sizeof_bits_v<Element> == 8);
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;
static_assert(kNWarps_ == 12);
static_assert(kNWarps_ == 12);
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kBlockK = kBlockK_;
static constexpr int kTiles = kTiles_;
static constexpr int TokenPackSize = TokenPackSize_;
static constexpr int TAIL_N = TAIL_N_;
static constexpr int M = M_;
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kBlockK = kBlockK_;
static constexpr int kTiles = kTiles_;
static constexpr int TokenPackSize = TokenPackSize_;
static constexpr int TAIL_N = TAIL_N_;
static constexpr int M = M_;
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kBlockK>>;
using TileShape_MNK_TAIL = Shape<Int<kBlockM>, Int<TAIL_N>, Int<kBlockK>>;
static constexpr int kClusterM = kClusterM_;
using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kBlockK>>;
using TileShape_MNK_TAIL = Shape<Int<kBlockM>, Int<TAIL_N>, Int<kBlockK>>;
static constexpr int kClusterM = kClusterM_;
using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;
static constexpr int kStages = kStages_;
static_assert(kStages > 1);
static constexpr int kStages = kStages_;
static_assert(kStages > 1);
using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
AtomLayoutMNK{}));
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::
ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
AtomLayoutMNK{}));
using Mma = decltype(cute::GMMA::ss_op_selector_sparse<Element, Element, ElementAccum, TileShape_MNK>());
using Mma = decltype(cute::GMMA::ss_op_selector_sparse<Element,
Element,
ElementAccum,
TileShape_MNK>());
using Mma_TAIL = decltype(cute::GMMA::ss_op_selector_sparse<Element, Element, ElementAccum, TileShape_MNK_TAIL>());
using Mma_TAIL =
decltype(cute::GMMA::ss_op_selector_sparse<Element,
Element,
ElementAccum,
TileShape_MNK_TAIL>());
using SmemLayoutAtomA = decltype(
cutlass::gemm::collective::detail::rs_smem_selector<
GMMA::Major::K, Element, Int<kBlockM / 2>, Int<kBlockK>>());
using SmemLayoutAtomA =
decltype(cutlass::gemm::collective::detail::rs_smem_selector<
GMMA::Major::K,
Element,
Int<kBlockM / 2>,
Int<kBlockK>>());
using SmemLayoutA = decltype(
tile_to_shape(SmemLayoutAtomA{},
make_shape(Int<kBlockM / 2>{}, Int<kBlockK>{}, Int<kStages>{})));
using SmemLayoutA = decltype(tile_to_shape(
SmemLayoutAtomA{},
make_shape(Int<kBlockM / 2>{}, Int<kBlockK>{}, Int<kStages>{})));
using SmemLayoutAtomB = decltype(
cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutAtomB =
decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K,
Element,
decltype(cute::get<1>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutB = decltype(
tile_to_shape(SmemLayoutAtomB{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using SmemLayoutB =
decltype(tile_to_shape(SmemLayoutAtomB{},
make_shape(shape<1>(TileShape_MNK{}),
shape<2>(TileShape_MNK{}),
Int<kStages>{})));
using SmemLayoutAtomB_TAIL = decltype(
cutlass::gemm::collective::detail::rs_smem_selector<
GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK_TAIL{})),
decltype(cute::get<2>(TileShape_MNK_TAIL{}))>());
using SmemLayoutAtomB_TAIL =
decltype(cutlass::gemm::collective::detail::rs_smem_selector<
GMMA::Major::K,
Element,
decltype(cute::get<1>(TileShape_MNK_TAIL{})),
decltype(cute::get<2>(TileShape_MNK_TAIL{}))>());
using SmemLayoutB_TAIL = decltype(
tile_to_shape(SmemLayoutAtomB_TAIL{},
make_shape(
shape<1>(TileShape_MNK_TAIL{}),
shape<2>(TileShape_MNK_TAIL{}),
Int<kStages>{})
));
using SmemLayoutAtomC = decltype(
cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K, ElementOutput,
decltype(cute::get<0>(TileShape_MNK{})),
decltype(cute::get<1>(TileShape_MNK{}))>());
using SmemLayoutB_TAIL =
decltype(tile_to_shape(SmemLayoutAtomB_TAIL{},
make_shape(shape<1>(TileShape_MNK_TAIL{}),
shape<2>(TileShape_MNK_TAIL{}),
Int<kStages>{})));
using SmemLayoutAtomC =
decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K,
ElementOutput,
decltype(cute::get<0>(TileShape_MNK{})),
decltype(cute::get<1>(TileShape_MNK{}))>());
using SmemLayoutC = decltype(tile_to_shape(SmemLayoutAtomC{}, select<0, 1>(TileShape_MNK{})));
using SmemLayoutC =
decltype(tile_to_shape(SmemLayoutAtomC{}, select<0, 1>(TileShape_MNK{})));
using SmemLayoutE = Layout<Shape<Int<NumMmaThreads>, Int<kBlockK / 64>, Int<kStages>>>;
using SmemLayoutE =
Layout<Shape<Int<NumMmaThreads>, Int<kBlockK / 64>, Int<kStages>>>;
using SharedStorage = SharedStorage<
kStages, Element, ElementOutput, SmemLayoutA, SmemLayoutE, SmemLayoutB, SmemLayoutC>;
using SharedStorage = SharedStorage<kStages,
Element,
ElementOutput,
SmemLayoutA,
SmemLayoutE,
SmemLayoutB,
SmemLayoutC>;
using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
using PipelineState = typename cutlass::PipelineState<kStages>;
using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
using PipelineState = typename cutlass::PipelineState<kStages>;
static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<OutputType>);
static constexpr int kNumThreadsPerRow = kBlockN / kNumVecElem;
static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
using TiledCopyCAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, OutputType>;
using TiledCopyCThrLayout = decltype(cute::make_layout(
cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
LayoutRight{}));
using TiledCopyCValLayout = decltype(cute::make_layout(
cute::make_shape(_1{}, Int<kNumVecElem>{}),
LayoutRight{}));
using TiledCopyC = decltype(make_tiled_copy(
TiledCopyCAtom{},
TiledCopyCThrLayout{}, // Thr layout
TiledCopyCValLayout{} // Val layout
));
static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<OutputType>);
static constexpr int kNumThreadsPerRow = kBlockN / kNumVecElem;
static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
using TiledCopyCAtom =
cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, OutputType>;
using TiledCopyCThrLayout = decltype(cute::make_layout(
cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
LayoutRight{}));
using TiledCopyCValLayout = decltype(cute::make_layout(
cute::make_shape(_1{}, Int<kNumVecElem>{}), LayoutRight{}));
using TiledCopyC =
decltype(make_tiled_copy(TiledCopyCAtom{},
TiledCopyCThrLayout{}, // Thr layout
TiledCopyCValLayout{} // Val layout
));
};
@@ -29,438 +29,483 @@
using namespace cute;
template <typename Ktraits>
struct CollectiveMainloopFwd {
using Element = typename Ktraits::Element;
using ElementOutput = typename Ktraits::ElementOutput;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
using ElementAccum = typename Ktraits::ElementAccum;
using Element = typename Ktraits::Element;
using ElementOutput = typename Ktraits::ElementOutput;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
using ElementAccum = typename Ktraits::ElementAccum;
static constexpr int kStages = Ktraits::kStages;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kBlockN = Ktraits::kBlockN;
static constexpr int kBlockK = Ktraits::kBlockK;
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int kTiles = Ktraits::kTiles;
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{});
static constexpr int TokenPackSize = Ktraits::TokenPackSize;
static constexpr int M = Ktraits::M;
static constexpr int kStages = Ktraits::kStages;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kBlockN = Ktraits::kBlockN;
static constexpr int kBlockK = Ktraits::kBlockK;
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int kTiles = Ktraits::kTiles;
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{});
static constexpr int TokenPackSize = Ktraits::TokenPackSize;
static constexpr int M = Ktraits::M;
using GmemTiledCopy = cute::SM90_TMA_LOAD;
using GmemTiledCopyStore = cute::SM90_TMA_STORE;
using SmemLayoutA = typename Ktraits::SmemLayoutA;
using SmemLayoutB = typename Ktraits::SmemLayoutB;
using SmemLayoutC = typename Ktraits::SmemLayoutC;
using SmemLayoutE = typename Ktraits::SmemLayoutE;
using SmemLayoutB_TAIL = typename Ktraits::SmemLayoutB_TAIL;
using GmemTiledCopy = cute::SM90_TMA_LOAD;
using GmemTiledCopyStore = cute::SM90_TMA_STORE;
using ShapeT = cute::Shape<int64_t, int64_t, int64_t>;
using StrideT = cute::Shape<int64_t, _1, int64_t>;
using LayoutT = cute::Layout<ShapeT, StrideT>;
using SmemLayoutA = typename Ktraits::SmemLayoutA;
using SmemLayoutB = typename Ktraits::SmemLayoutB;
using SmemLayoutC = typename Ktraits::SmemLayoutC;
using SmemLayoutE = typename Ktraits::SmemLayoutE;
using SmemLayoutB_TAIL = typename Ktraits::SmemLayoutB_TAIL;
using WShapeT = cute::Shape<int64_t, int64_t, int64_t, int64_t, int64_t>;
using WStrideT = cute::Shape<int64_t, _1, int64_t, int64_t, int64_t>;
using WLayoutT = cute::Layout<WShapeT, WStrideT>;
using ShapeT = cute::Shape<int64_t, int64_t, int64_t>;
using StrideT = cute::Shape<int64_t, _1, int64_t>;
using LayoutT = cute::Layout<ShapeT, StrideT>;
using EShapeT = cute::Shape<int64_t, int64_t, int64_t, int64_t, int64_t>;
using EStrideT = cute::Shape<_1, int64_t, int64_t, int64_t, int64_t>;
using ELayoutT = cute::Layout<EShapeT, EStrideT>;
using WShapeT = cute::Shape<int64_t, int64_t, int64_t, int64_t, int64_t>;
using WStrideT = cute::Shape<int64_t, _1, int64_t, int64_t, int64_t>;
using WLayoutT = cute::Layout<WShapeT, WStrideT>;
using TMA_A = decltype(make_tma_copy(
GmemTiledCopy{},
make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)),
WShapeT{},
WStrideT{}),
SmemLayoutA{}(_, _, _0{}),
select<0, 1>(Shape<Int<kBlockM / 2>, Int<kBlockK>>{}),
size<0>(ClusterShape{})));
using EShapeT = cute::Shape<int64_t, int64_t, int64_t, int64_t, int64_t>;
using EStrideT = cute::Shape<_1, int64_t, int64_t, int64_t, int64_t>;
using ELayoutT = cute::Layout<EShapeT, EStrideT>;
using TMA_B = decltype(make_tma_copy(
GmemTiledCopy{},
make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)),
ShapeT{},
StrideT{}),
take<0, 2>(SmemLayoutB{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{})));
using TMA_A = decltype(make_tma_copy(
using TMA_E = decltype(make_tma_copy(
GmemTiledCopy{},
make_tensor(make_gmem_ptr(static_cast<uint32_t const*>(nullptr)),
EShapeT{},
EStrideT{}),
SmemLayoutE{}(_, _, _0{}),
select<0, 1>(Shape<Int<NumMmaThreads>, Int<kBlockK / 64>>{}),
size<0>(ClusterShape{})));
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
static constexpr uint32_t TmaTransactionBytesA = static_cast<uint32_t>(
size(take<0, 2>(SmemLayoutA{})) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr uint32_t TmaTransactionBytesB = static_cast<uint32_t>(
size(take<0, 2>(SmemLayoutB{})) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr uint32_t TmaTransactionBytesE = static_cast<uint32_t>(
size(take<0, 2>(SmemLayoutE{})) * cutlass::sizeof_bits_v<int> / 8);
struct Arguments {
Element const* ptr_A;
WLayoutT layout_A;
uint32_t const* ptr_E;
ELayoutT layout_E;
Element const* ptr_B;
LayoutT layout_B;
ElementOutput* ptr_C;
LayoutT layout_C;
const int* tokens;
const float* weight_scale;
};
struct Params {
WLayoutT layout_A;
ELayoutT layout_E;
LayoutT layout_B;
TMA_A tma_load_A;
TMA_E tma_load_E;
TMA_B tma_load_B;
const int* tokens;
const float* weight_scale;
ElementOutput* ptr_C;
};
Params static to_underlying_arguments(Arguments const& args) {
Tensor mA = make_tensor(make_gmem_ptr(args.ptr_A), args.layout_A);
TMA_A tma_load_A =
make_tma_copy(GmemTiledCopy{},
mA,
SmemLayoutA{}(_, _, _0{}),
select<0, 1>(Shape<Int<kBlockM / 2>, Int<kBlockK>>{}),
size<0>(ClusterShape{}));
Tensor mE = make_tensor(make_gmem_ptr(args.ptr_E), args.layout_E);
TMA_E tma_load_E = make_tma_copy(
GmemTiledCopy{},
make_tensor(
make_gmem_ptr(static_cast<Element const*>(nullptr)),
WShapeT{},
WStrideT{}
),
SmemLayoutA{}(_, _, _0{}),
select<0, 1>(Shape<Int<kBlockM / 2>, Int<kBlockK>>{}),
size<0>(ClusterShape{})));
using TMA_B = decltype(make_tma_copy(
GmemTiledCopy{},
make_tensor(
make_gmem_ptr(static_cast<Element const*>(nullptr)),
ShapeT{},
StrideT{}
),
take<0, 2>(SmemLayoutB{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{})));
using TMA_E = decltype(make_tma_copy(
GmemTiledCopy{},
make_tensor(
make_gmem_ptr(static_cast<uint32_t const*>(nullptr)),
EShapeT{},
EStrideT{}
),
mE,
SmemLayoutE{}(_, _, _0{}),
select<0, 1>(Shape<Int<NumMmaThreads>, Int<kBlockK / 64>>{}),
size<0>(ClusterShape{})));
size<0>(ClusterShape{}));
Tensor mB = make_tensor(make_gmem_ptr(args.ptr_B), args.layout_B);
TMA_B tma_load_B = make_tma_copy(GmemTiledCopy{},
mB,
SmemLayoutB{}(_, _, _0{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{}));
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
return {args.layout_A,
args.layout_E,
args.layout_B,
tma_load_A,
tma_load_E,
tma_load_B,
args.tokens,
args.weight_scale,
args.ptr_C};
}
static constexpr uint32_t TmaTransactionBytesA = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutA{})) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr uint32_t TmaTransactionBytesB = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutB{})) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr uint32_t TmaTransactionBytesE = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutE{})) * cutlass::sizeof_bits_v<int> / 8);
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params) {
cute::prefetch_tma_descriptor(
mainloop_params.tma_load_A.get_tma_descriptor());
cute::prefetch_tma_descriptor(
mainloop_params.tma_load_B.get_tma_descriptor());
cute::prefetch_tma_descriptor(
mainloop_params.tma_load_E.get_tma_descriptor());
}
struct Arguments {
Element const* ptr_A;
WLayoutT layout_A;
uint32_t const* ptr_E;
ELayoutT layout_E;
Element const* ptr_B;
LayoutT layout_B;
ElementOutput * ptr_C;
LayoutT layout_C;
const int *tokens;
const float *weight_scale;
};
template <int CUR_N, typename SharedStorage>
CUTLASS_DEVICE void store(Params const& mainloop_params,
float* acc_s,
SharedStorage& shared_storage,
const int pre_fix_tokens,
const int tokens,
const float* weight_scale,
const int bidm,
const int bidn,
const int bidb,
const int tidx) {
typename Ktraits::TiledMma tiled_mma;
using packHalf = typename PackedHalf<ElementOutput>::Type;
Tensor tOrO_out = make_tensor<ElementOutput>(
partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{}))
.layout());
struct Params {
WLayoutT layout_A;
ELayoutT layout_E;
LayoutT layout_B;
TMA_A tma_load_A;
TMA_E tma_load_E;
TMA_B tma_load_B;
const int *tokens;
const float *weight_scale;
ElementOutput * ptr_C;
};
Params static
to_underlying_arguments(Arguments const& args) {
Tensor mA = make_tensor(make_gmem_ptr(args.ptr_A), args.layout_A);
TMA_A tma_load_A = make_tma_copy(
GmemTiledCopy{},
mA,
SmemLayoutA{}(_, _, _0{}),
select<0, 1>(Shape<Int<kBlockM / 2>, Int<kBlockK>>{}),
size<0>(ClusterShape{}));
Tensor mE = make_tensor(make_gmem_ptr(args.ptr_E), args.layout_E);
TMA_E tma_load_E = make_tma_copy(
GmemTiledCopy{},
mE,
SmemLayoutE{}(_, _, _0{}),
select<0, 1>(Shape<Int<NumMmaThreads>, Int<kBlockK / 64>>{}),
size<0>(ClusterShape{}));
Tensor mB = make_tensor(make_gmem_ptr(args.ptr_B), args.layout_B);
TMA_B tma_load_B = make_tma_copy(
GmemTiledCopy{},
mB,
SmemLayoutB{}(_, _, _0{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{}));
return {args.layout_A, args.layout_E, args.layout_B,
tma_load_A, tma_load_E, tma_load_B,
args.tokens, args.weight_scale, args.ptr_C};
#pragma unroll
for (int i = 0; i < size(tOrO_out); i += 4) {
acc_s[i] *= weight_scale[0];
acc_s[i + 1] *= weight_scale[0];
acc_s[i + 2] *= weight_scale[1];
acc_s[i + 3] *= weight_scale[1];
*reinterpret_cast<packHalf*>(&tOrO_out[i]) =
packHalf(acc_s[i], acc_s[i + 2]);
*reinterpret_cast<packHalf*>(&tOrO_out[i + 2]) =
packHalf(acc_s[i + 1], acc_s[i + 3]);
}
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params) {
cute::prefetch_tma_descriptor(mainloop_params.tma_load_A.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_B.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_E.get_tma_descriptor());
uint16_t* smem_c =
reinterpret_cast<uint16_t*>(shared_storage.smem_c.data());
uint32_t* reg_data = reinterpret_cast<uint32_t*>(tOrO_out.data());
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
constexpr int k_copy_times = CUR_N / 16;
#pragma unroll
for (int i = 0; i < k_copy_times; i++) {
uint32_t smem_ptr = cast_smem_ptr_to_uint(
reinterpret_cast<uint128_t*>(smem_c + i * 16 * 128) + tidx);
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
asm volatile(
"stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, "
"%4};\n" ::"r"(smem_ptr),
"r"(reg_data[4 * i + 0]),
"r"(reg_data[4 * i + 2]),
"r"(reg_data[4 * i + 1]),
"r"(reg_data[4 * i + 3]));
#endif
}
template <int CUR_N, typename SharedStorage>
CUTLASS_DEVICE void
store(Params const& mainloop_params,
float * acc_s,
SharedStorage& shared_storage,
const int pre_fix_tokens,
const int tokens,
const float * weight_scale,
const int bidm,
const int bidn,
const int bidb,
const int tidx) {
typename Ktraits::TiledMma tiled_mma;
using packHalf = typename PackedHalf<ElementOutput>::Type;
Tensor tOrO_out = make_tensor<ElementOutput>(partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{})).layout());
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
const int batch_idx =
TokenPackSize == 0 ? pre_fix_tokens * M : bidb * M * TokenPackSize;
ElementOutput* store_c = mainloop_params.ptr_C + batch_idx +
bidn * (M * kBlockN) + bidm * kBlockM;
#pragma unroll
for (int i = 0; i < size(tOrO_out); i+=4) {
acc_s[i] *= weight_scale[0];
acc_s[i + 1] *= weight_scale[0];
acc_s[i + 2] *= weight_scale[1];
acc_s[i + 3] *= weight_scale[1];
*reinterpret_cast<packHalf*>(&tOrO_out[i]) = packHalf(acc_s[i], acc_s[i + 2]);
*reinterpret_cast<packHalf*>(&tOrO_out[i + 2]) = packHalf(acc_s[i + 1], acc_s[i + 3]);
}
const int reamin_tokens = tokens - bidn * kBlockN;
uint16_t *smem_c = reinterpret_cast<uint16_t *>(shared_storage.smem_c.data());
const int col = tidx % 2;
uint32_t * reg_data = reinterpret_cast<uint32_t*>(tOrO_out.data());
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
constexpr int k_copy_times = CUR_N / 16;
#pragma unroll
for (int i = 0; i < k_copy_times; i++) {
uint32_t smem_ptr = cast_smem_ptr_to_uint(reinterpret_cast<uint128_t*>(smem_c + i * 16 * 128) + tidx);
#if defined(CUTE_ARCH_STSM_SM90_ENABLED)
asm volatile (
"stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
:: "r"(smem_ptr), "r"(reg_data[4 * i + 0]), "r"(reg_data[4 * i + 2]), "r"(reg_data[4 * i + 1]), "r"(reg_data[4 * i + 3]));
#endif
}
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
const int batch_idx = TokenPackSize == 0 ? pre_fix_tokens * M : bidb * M * TokenPackSize;
ElementOutput * store_c = mainloop_params.ptr_C + batch_idx + bidn * (M * kBlockN) + bidm * kBlockM;
const int reamin_tokens = tokens - bidn * kBlockN;
const int col = tidx % 2;
constexpr int kPackSize = 16 / sizeof(ElementOutput);
constexpr int kNumVecElem = kBlockM / kPackSize;
constexpr int copy_len = CUR_N * kNumVecElem;
#pragma unroll
for (int idx = tidx; idx < copy_len; idx += NumMmaThreads) {
const int idx_div2 = idx / 2;
const int store_idx = idx_div2 / 128 * 128 + idx_div2 % 8 * 16 + idx_div2 % 128 / 16 + idx_div2 % 16 / 8 * 8;
const int store_global_idx = store_idx * 2 + col;
const int row = store_global_idx / kNumVecElem;
const int col = store_global_idx % kNumVecElem;
if (row >= reamin_tokens) {
continue;
}
const int offset = row * (M / kPackSize) + col;
reinterpret_cast<uint4*>(store_c)[offset] = reinterpret_cast<uint4*>(smem_c)[idx];
}
constexpr int kPackSize = 16 / sizeof(ElementOutput);
constexpr int kNumVecElem = kBlockM / kPackSize;
constexpr int copy_len = CUR_N * kNumVecElem;
#pragma unroll
for (int idx = tidx; idx < copy_len; idx += NumMmaThreads) {
const int idx_div2 = idx / 2;
const int store_idx = idx_div2 / 128 * 128 + idx_div2 % 8 * 16 +
idx_div2 % 128 / 16 + idx_div2 % 16 / 8 * 8;
const int store_global_idx = store_idx * 2 + col;
const int row = store_global_idx / kNumVecElem;
const int col = store_global_idx % kNumVecElem;
if (row >= reamin_tokens) {
continue;
}
const int offset = row * (M / kPackSize) + col;
reinterpret_cast<uint4*>(store_c)[offset] =
reinterpret_cast<uint4*>(smem_c)[idx];
}
}
template <typename MTensor>
CUTLASS_DEVICE auto get_local_packed_tensor(
const MTensor &mB,
const int tokens,
const int bidn) const {
template <typename MTensor>
CUTLASS_DEVICE auto get_local_packed_tensor(const MTensor& mB,
const int tokens,
const int bidn) const {
auto mB_this_batch = make_tensor(
mB.data(),
make_layout(cute::make_shape(tokens, size<1>(mB)), mB.stride()));
return local_tile(
mB_this_batch, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
}
auto mB_this_batch = make_tensor(
mB.data(),
make_layout(
cute::make_shape(tokens, size<1>(mB)),
mB.stride()
));
return local_tile(mB_this_batch, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
}
template <typename MTensor>
CUTLASS_DEVICE auto get_local_no_packed_tensor(const MTensor& mB,
const int pre_fix_token,
const int actual_token,
const int bidn) const {
auto g_offset = local_tile(mB(_, _, 0),
cute::make_shape(1, size<1>(mB)),
make_coord(pre_fix_token, _0{}));
template <typename MTensor>
CUTLASS_DEVICE auto get_local_no_packed_tensor(
const MTensor &mB,
const int pre_fix_token,
const int actual_token,
const int bidn) const {
auto g_tensor =
make_tensor(g_offset.data(),
make_layout(cute::make_shape(actual_token, size<1>(mB)),
g_offset.stride()));
auto g_offset = local_tile(
mB(_, _, 0),
cute::make_shape(1, size<1>(mB)),
make_coord(pre_fix_token, _0{}));
Tensor gB = local_tile(
g_tensor, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
auto g_tensor = make_tensor(
g_offset.data(),
make_layout(
cute::make_shape(actual_token, size<1>(mB)),
g_offset.stride()
));
return gB;
}
Tensor gB = local_tile(g_tensor, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
template <typename SharedStorage>
CUTLASS_DEVICE void load(Params const& mainloop_params,
MainloopPipeline pipeline,
PipelineState& smem_pipe_write,
SharedStorage& shared_storage,
const int pre_fix_tokens,
const int tokens,
const int bidm,
const int bidn,
const int bidb,
const int tidx) {
Tensor sA =
make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
Tensor sB =
make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{});
Tensor sE =
make_tensor(make_smem_ptr(shared_storage.smem_e.data()), SmemLayoutE{});
return gB;
}
Tensor mA = mainloop_params.tma_load_A.get_tma_tensor(
mainloop_params.layout_A.shape());
Tensor mB = mainloop_params.tma_load_B.get_tma_tensor(
mainloop_params.layout_B.shape());
Tensor mE = mainloop_params.tma_load_E.get_tma_tensor(
mainloop_params.layout_E.shape());
Tensor gA =
local_tile(mA(_, _, _, bidm, bidb),
select<0, 1>(Shape<Int<kBlockM / 2>, Int<kBlockK>>{}),
make_coord(0, 0, _));
template <typename SharedStorage>
CUTLASS_DEVICE void
load(Params const& mainloop_params,
MainloopPipeline pipeline,
PipelineState& smem_pipe_write,
SharedStorage &shared_storage,
const int pre_fix_tokens,
const int tokens,
const int bidm,
const int bidn,
const int bidb,
const int tidx) {
Tensor gE =
local_tile(mE(_, _, _, bidm, bidb),
select<0, 1>(Shape<Int<NumMmaThreads>, Int<kBlockK / 64>>{}),
make_coord(0, 0));
Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{});
Tensor sE = make_tensor(make_smem_ptr(shared_storage.smem_e.data()), SmemLayoutE{});
auto [tAgA, tAsA] = tma_partition(mainloop_params.tma_load_A,
_0{},
Layout<ClusterShape>{},
group_modes<0, 2>(sA),
group_modes<0, 2>(gA));
Tensor mA = mainloop_params.tma_load_A.get_tma_tensor(mainloop_params.layout_A.shape());
Tensor mB = mainloop_params.tma_load_B.get_tma_tensor(mainloop_params.layout_B.shape());
Tensor mE = mainloop_params.tma_load_E.get_tma_tensor(mainloop_params.layout_E.shape());
auto [tEgE, tEsE] = tma_partition(mainloop_params.tma_load_E,
_0{},
Layout<ClusterShape>{},
group_modes<0, 2>(sE),
group_modes<0, 2>(gE));
Tensor gA = local_tile(mA(_, _, _, bidm, bidb), select<0, 1>(Shape<Int<kBlockM / 2>, Int<kBlockK>>{}), make_coord(0,0,_));
int lane_predicate = cute::elect_one_sync();
int warp_idx_in_warpgroup =
__shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
Tensor gE = local_tile(mE(_, _, _, bidm, bidb), select<0, 1>(Shape<Int<NumMmaThreads>, Int<kBlockK / 64>>{}), make_coord(0, 0));
if constexpr (TokenPackSize == 0) {
Tensor gB = get_local_no_packed_tensor(mB, pre_fix_tokens, tokens, bidn);
auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B,
_0{},
Layout<ClusterShape>{},
group_modes<0, 2>(sB),
group_modes<0, 2>(gB));
auto [tAgA, tAsA] = tma_partition(mainloop_params.tma_load_A, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sA), group_modes<0, 2>(gA));
auto [tEgE, tEsE] = tma_partition(mainloop_params.tma_load_E, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sE), group_modes<0, 2>(gE));
int lane_predicate = cute::elect_one_sync();
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
if constexpr (TokenPackSize == 0) {
Tensor gB = get_local_no_packed_tensor(
mB,
pre_fix_tokens,
tokens,
bidn);
auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sB), group_modes<0, 2>(gB));
const int kIters = kTiles / kStages;
if (tidx == 0) {
#pragma unroll
for (int kiter = 0; kiter < kIters; ++kiter) {
#pragma unroll
for (int s = 0; s < kStages; s++) {
const int i = kiter * kStages + s;
pipeline.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tAgA(_, i), tAsA(_, s));
copy(mainloop_params.tma_load_E.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tEgE(_, i), tEsE(_, s));
copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tBgB(_, i), tBsB(_, s));
++smem_pipe_write;
}
}
#pragma unroll
for (int i = kIters * kStages; i < kTiles; ++i) {
pipeline.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tAgA(_, i), tAsA(_, smem_pipe_write.index()));
copy(mainloop_params.tma_load_E.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tEgE(_, i), tEsE(_, smem_pipe_write.index()));
copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tBgB(_, i), tBsB(_, smem_pipe_write.index()));
++smem_pipe_write;
}
}
} else {
auto mB_this_batch = make_tensor(
mB(_, _, bidb).data(),
make_layout(
cute::make_shape(tokens, size<1>(mB)),
mB.stride()
));
Tensor gB = local_tile(mB_this_batch, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B, _0{}, Layout<ClusterShape>{}, group_modes<0, 2>(sB), group_modes<0, 2>(gB));
const int kIters = kTiles / kStages;
if (tidx == 0) {
#pragma unroll
for (int kiter = 0; kiter < kIters; ++kiter) {
#pragma unroll
for (int s = 0; s < kStages; s++) {
const int i = kiter * kStages + s;
pipeline.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tAgA(_, i), tAsA(_, s));
copy(mainloop_params.tma_load_E.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tEgE(_, i), tEsE(_, s));
copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tBgB(_, i), tBsB(_, s));
++smem_pipe_write;
}
}
#pragma unroll
for (int i = kIters * kStages; i < kTiles; ++i) {
pipeline.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tAgA(_, i), tAsA(_, smem_pipe_write.index()));
copy(mainloop_params.tma_load_E.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tEgE(_, i), tEsE(_, smem_pipe_write.index()));
copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0),
tBgB(_, i), tBsB(_, smem_pipe_write.index()));
++smem_pipe_write;
}
}
}
}
template <int CUR_N, typename SharedStorage>
CUTLASS_DEVICE void
mma(Params const& mainloop_params,
MainloopPipeline pipeline,
PipelineState& smem_pipe_read,
SharedStorage& shared_storage,
float *acc_s,
const int tidx) {
using sMemBLayout = std::conditional_t<
CUR_N == kBlockN,
SmemLayoutB,
SmemLayoutB_TAIL
>;
using Mma = std::conditional_t<
CUR_N == kBlockN,
typename Ktraits::Mma,
typename Ktraits::Mma_TAIL
>;
Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), sMemBLayout{});
Tensor sE = make_tensor(make_smem_ptr(shared_storage.smem_e.data()), SmemLayoutE{});
const int wg_idx = tidx / 128;
const int wg_offset = wg_idx * 64;
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
};
constexpr int E_STEP = kBlockK / 64 * NumMmaThreads;
constexpr int B_STEPS = CUR_N == 0 ? 1 : (kBlockN / CUR_N);
const int kIters = kTiles / kStages;
#pragma unroll
const int kIters = kTiles / kStages;
if (tidx == 0) {
#pragma unroll
for (int kiter = 0; kiter < kIters; ++kiter) {
#pragma unroll
for (int s = 0; s < kStages; s++) {
consumer_wait(pipeline, smem_pipe_read);
gemm<Mma, kBlockK, NumMmaThreads>(
sA(_, _, s).data().get().get() + wg_offset,
sB(_, _, s * B_STEPS).data().get().get(),
acc_s,
shared_storage.smem_e.data() + s * E_STEP + tidx);
pipeline.consumer_release(smem_pipe_read);
++smem_pipe_read;
}
#pragma unroll
for (int s = 0; s < kStages; s++) {
const int i = kiter * kStages + s;
pipeline.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_A.with(
*pipeline.producer_get_barrier(smem_pipe_write), 0),
tAgA(_, i),
tAsA(_, s));
copy(mainloop_params.tma_load_E.with(
*pipeline.producer_get_barrier(smem_pipe_write), 0),
tEgE(_, i),
tEsE(_, s));
copy(mainloop_params.tma_load_B.with(
*pipeline.producer_get_barrier(smem_pipe_write), 0),
tBgB(_, i),
tBsB(_, s));
++smem_pipe_write;
}
}
#pragma unroll
for (int i = 0; i < kTiles % kStages; ++i) {
consumer_wait(pipeline, smem_pipe_read);
gemm<Mma, kBlockK, NumMmaThreads>(
sA(_, _, i).data().get().get() + wg_offset,
sB(_, _, i * B_STEPS).data().get().get(),
acc_s,
shared_storage.smem_e.data() + i * E_STEP + tidx);
pipeline.consumer_release(smem_pipe_read);
++smem_pipe_read;
#pragma unroll
for (int i = kIters * kStages; i < kTiles; ++i) {
pipeline.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_A.with(
*pipeline.producer_get_barrier(smem_pipe_write), 0),
tAgA(_, i),
tAsA(_, smem_pipe_write.index()));
copy(mainloop_params.tma_load_E.with(
*pipeline.producer_get_barrier(smem_pipe_write), 0),
tEgE(_, i),
tEsE(_, smem_pipe_write.index()));
copy(mainloop_params.tma_load_B.with(
*pipeline.producer_get_barrier(smem_pipe_write), 0),
tBgB(_, i),
tBsB(_, smem_pipe_write.index()));
++smem_pipe_write;
}
}
} else {
auto mB_this_batch = make_tensor(
mB(_, _, bidb).data(),
make_layout(cute::make_shape(tokens, size<1>(mB)), mB.stride()));
Tensor gB = local_tile(
mB_this_batch, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _));
auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B,
_0{},
Layout<ClusterShape>{},
group_modes<0, 2>(sB),
group_modes<0, 2>(gB));
const int kIters = kTiles / kStages;
if (tidx == 0) {
#pragma unroll
for (int kiter = 0; kiter < kIters; ++kiter) {
#pragma unroll
for (int s = 0; s < kStages; s++) {
const int i = kiter * kStages + s;
pipeline.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_A.with(
*pipeline.producer_get_barrier(smem_pipe_write), 0),
tAgA(_, i),
tAsA(_, s));
copy(mainloop_params.tma_load_E.with(
*pipeline.producer_get_barrier(smem_pipe_write), 0),
tEgE(_, i),
tEsE(_, s));
copy(mainloop_params.tma_load_B.with(
*pipeline.producer_get_barrier(smem_pipe_write), 0),
tBgB(_, i),
tBsB(_, s));
++smem_pipe_write;
}
}
#pragma unroll
for (int i = kIters * kStages; i < kTiles; ++i) {
pipeline.producer_acquire(smem_pipe_write);
copy(mainloop_params.tma_load_A.with(
*pipeline.producer_get_barrier(smem_pipe_write), 0),
tAgA(_, i),
tAsA(_, smem_pipe_write.index()));
copy(mainloop_params.tma_load_E.with(
*pipeline.producer_get_barrier(smem_pipe_write), 0),
tEgE(_, i),
tEsE(_, smem_pipe_write.index()));
copy(mainloop_params.tma_load_B.with(
*pipeline.producer_get_barrier(smem_pipe_write), 0),
tBgB(_, i),
tBsB(_, smem_pipe_write.index()));
++smem_pipe_write;
}
}
}
}
// Consumer (MMA) side of the persistent pipeline: drains the kStages-deep
// shared-memory ring buffer filled by load(), issuing one sparse warpgroup
// GEMM per K-tile and accumulating into acc_s.
//
// CUR_N selects between the full-width (kBlockN) and tail (TAIL_N) tile
// shapes; everything else about the schedule is identical.
// NOTE(review): tidx here is expected to be the consumer-relative thread
// index (the kernel passes tidx - NumCopyThreads) — confirm at call sites.
template <int CUR_N, typename SharedStorage>
CUTLASS_DEVICE void mma(Params const& mainloop_params,
                        MainloopPipeline pipeline,
                        PipelineState& smem_pipe_read,
                        SharedStorage& shared_storage,
                        float* acc_s,
                        const int tidx) {
  // Tail tiles use a narrower B smem layout and a matching MMA atom.
  using sMemBLayout =
      std::conditional_t<CUR_N == kBlockN, SmemLayoutB, SmemLayoutB_TAIL>;
  using Mma = std::conditional_t<CUR_N == kBlockN,
                                 typename Ktraits::Mma,
                                 typename Ktraits::Mma_TAIL>;
  // Shared-memory views over the operand ring buffers (one slice per stage).
  Tensor sA =
      make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{});
  Tensor sB =
      make_tensor(make_smem_ptr(shared_storage.smem_b.data()), sMemBLayout{});
  Tensor sE =
      make_tensor(make_smem_ptr(shared_storage.smem_e.data()), SmemLayoutE{});
  // Each 128-thread warp group works on its own 64-element slice of A.
  const int wg_idx = tidx / 128;
  const int wg_offset = wg_idx * 64;
  // Block until the producer has committed the stage we are about to read.
  auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
    auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
    pipeline.consumer_wait(smem_pipe_read, barrier_token);
  };
  // Per-stage stride of the packed sparsity metadata (one word per thread
  // per 64-wide K chunk).
  constexpr int E_STEP = kBlockK / 64 * NumMmaThreads;
  // Column scale so tail tiles index the narrower B layout correctly.
  constexpr int B_STEPS = CUR_N == 0 ? 1 : (kBlockN / CUR_N);
  const int kIters = kTiles / kStages;
#pragma unroll
  for (int kiter = 0; kiter < kIters; ++kiter) {
#pragma unroll
    for (int s = 0; s < kStages; s++) {
      consumer_wait(pipeline, smem_pipe_read);
      gemm<Mma, kBlockK, NumMmaThreads>(
          sA(_, _, s).data().get().get() + wg_offset,
          sB(_, _, s * B_STEPS).data().get().get(),
          acc_s,
          shared_storage.smem_e.data() + s * E_STEP + tidx);
      // Hand the stage back to the producer before advancing.
      pipeline.consumer_release(smem_pipe_read);
      ++smem_pipe_read;
    }
  }
  // Remainder tiles when kTiles is not a multiple of kStages. After the main
  // loop the pipeline index is back at 0, so stage i lines up with
  // smem_pipe_read.index().
#pragma unroll
  for (int i = 0; i < kTiles % kStages; ++i) {
    consumer_wait(pipeline, smem_pipe_read);
    gemm<Mma, kBlockK, NumMmaThreads>(
        sA(_, _, i).data().get().get() + wg_offset,
        sB(_, _, i * B_STEPS).data().get().get(),
        acc_s,
        shared_storage.smem_e.data() + i * E_STEP + tidx);
    pipeline.consumer_release(smem_pipe_read);
    ++smem_pipe_read;
  }
}
};
@@ -32,69 +32,71 @@
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
using namespace cute;
template<typename T>
template <typename T>
struct PackedHalf;
template<>
template <>
struct PackedHalf<cutlass::half_t> {
using Type = __half2;
using Type = __half2;
};
template<>
template <>
struct PackedHalf<cutlass::bfloat16_t> {
using Type = nv_bfloat162;
using Type = nv_bfloat162;
};
// Builds a GMMA shared-memory operand descriptor from a generic smem pointer.
// Descriptor fields encode shared-memory addresses/offsets in 16-byte units,
// hence the >> 4 shifts.
// (Deduplicated: the original span interleaved two copies of this function.)
template <class PointerType>
__device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr,
                                         int layout_type,
                                         int leading_byte_offset = 0,
                                         int stride_byte_offset = 1024) {
  GmmaDescriptor desc;
  auto uint_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  desc.bitfield.start_address_ = uint_ptr >> 4;
  desc.bitfield.layout_type_ = layout_type;
  desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
  desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
  desc.bitfield.base_offset_ = 0;
  return desc;
}
// Issues one sparse WGMMA instruction, expanding the float accumulator
// registers d[Idx]... from the index sequence. e carries the packed 2:4
// sparsity metadata word for this instruction.
// (Deduplicated: the original span interleaved two copies of this function.)
template <typename Mma, size_t... Idx>
__forceinline__ __device__ static void gemm(uint64_t const& desc_a,
                                            uint64_t const& desc_b,
                                            float* d,
                                            const uint32_t e,
                                            std::index_sequence<Idx...>) {
  Mma::fma(desc_a, desc_b, d[Idx]..., e, GMMA::ScaleOut::One);
}
// Runs a full kBlockM x CUR_N x kBlockK sparse warpgroup GEMM from shared
// memory: one WGMMA per 64-wide K chunk, each with its own operand
// descriptors and metadata word.
// (Deduplicated: the original span interleaved two copies of this function,
// duplicating the signature, acc_num, the arrive/commit calls and the loop.)
//
// Selected element-pair -> packed hex metadata nibble:
//   01 -> 4, 02 -> 8, 03 -> 12, 12 -> 9, 13 -> 13, 23 -> 14
template <typename Mma, int kBlockK, int NumMmaThreads, typename T>
__forceinline__ __device__ void gemm(const T* sA,
                                     const T* sB,
                                     float* acc_c,
                                     const uint32_t* E) {
  constexpr int acc_num = sizeof(Mma::CRegisters) / sizeof(float);
  warpgroup_arrive();
#pragma unroll
  for (int i = 0; i < kBlockK / 64; i++) {
    GmmaDescriptor a_desc = make_smem_desc(sA + i * 32, 1, 0, 1024);
    GmmaDescriptor b_desc = make_smem_desc(sB + i * 64, 1, 0, 1024);
    gemm<Mma>(a_desc,
              b_desc,
              acc_c,
              E[i * NumMmaThreads],
              std::make_index_sequence<acc_num>{});
  }
  warpgroup_commit_batch();
  warpgroup_wait<0>();
}
@@ -27,283 +27,307 @@
#include "mainloop_fwd.h"
// Persistent warp-specialized sparse W8A8 GEMM kernel.
// Warp group 0 is the TMA producer (load); the remaining warp groups are
// MMA consumers (mma + store). One CTA computes one (kBlockM x kBlockN)
// output tile of batch bidb.
// (Deduplicated: the original span interleaved two formatted copies of this
// kernel, duplicating declarations and statements.)
template <typename Ktraits>
void __global__ __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp,
                                  1)
    w8a8_sparse_gemm_kernel(
        CUTE_GRID_CONSTANT
        typename CollectiveMainloopFwd<Ktraits>::Params const mainloop_params) {
  using Element = typename Ktraits::Element;
  static_assert(cutlass::sizeof_bits_v<Element> == 8);

  using TileShape_MNK = typename Ktraits::TileShape_MNK;
  using ClusterShape = typename Ktraits::ClusterShape_MNK;

  static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{});
  static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
  static constexpr int TokenPackSize = Ktraits::TokenPackSize;
  static constexpr int kBlockM = Ktraits::kBlockM;
  static constexpr int kBlockN = Ktraits::kBlockN;
  static constexpr int TAIL_N = Ktraits::TAIL_N;
  static constexpr int M = Ktraits::M;

  using CollectiveMainloop = CollectiveMainloopFwd<Ktraits>;

  using MainloopPipeline = typename Ktraits::MainloopPipeline;
  using PipelineParams = typename MainloopPipeline::Params;
  using PipelineState = typename MainloopPipeline::PipelineState;

  extern __shared__ char shared_memory[];
  auto &shared_storage =
      *reinterpret_cast<typename Ktraits::SharedStorage *>(shared_memory);

  int const lane_predicate = cute::elect_one_sync();
  int const warp_idx = cutlass::canonical_warp_idx_sync();

  // One thread per grid prefetches the TMA descriptors into the cache.
  if (warp_idx == 0 && lane_predicate) {
    CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
  }

  // Obtain warp index
  int const warp_group_thread_idx =
      threadIdx.x % cutlass::NumThreadsPerWarpGroup;

  PipelineParams pipeline_params;
  // Each pipeline stage carries one A, E and B tile worth of TMA bytes.
  pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesA +
                                      CollectiveMainloop::TmaTransactionBytesE +
                                      CollectiveMainloop::TmaTransactionBytesB;
  int warp_group_idx = cutlass::canonical_warp_group_idx();
  pipeline_params.role = warp_group_idx == 0
                             ? MainloopPipeline::ThreadCategory::Producer
                             : MainloopPipeline::ThreadCategory::Consumer;
  pipeline_params.is_leader = warp_group_thread_idx == 0;
  pipeline_params.num_consumers = NumMmaThreads;

  MainloopPipeline pipeline(
      shared_storage.pipeline, pipeline_params, ClusterShape{});
  pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesB;

  CollectiveMainloop collective_mainloop;

  // Make sure the shared-memory barriers are initialized before any thread
  // (or cluster peer) starts using the pipeline.
  if constexpr (size(ClusterShape{}) > 1) {
    cute::cluster_arrive_relaxed();
    cute::cluster_wait();
  } else {
    __syncthreads();
  }

  const int bidm = blockIdx.x;
  const int bidn = blockIdx.y;
  const int bidb = blockIdx.z;
  const int tidx = threadIdx.x;

  // TokenPackSize == 0: tokens[] is a prefix sum over groups; otherwise it
  // holds the per-group token count directly.
  const int pre_fix_tokens =
      TokenPackSize == 0 ? mainloop_params.tokens[bidb] : 0;
  const int tokens = TokenPackSize == 0
                         ? mainloop_params.tokens[bidb + 1] - pre_fix_tokens
                         : mainloop_params.tokens[bidb];

  // Nothing to do for tiles entirely past this batch's tokens.
  if (bidn * kBlockN >= tokens) {
    return;
  }

  if (warp_group_idx == 0) {
    // Producer warp group: shrink register footprint and feed the pipeline.
    cutlass::arch::warpgroup_reg_dealloc<40>();
    PipelineState smem_pipe_write =
        cutlass::make_producer_start_state<MainloopPipeline>();
    collective_mainloop.load(mainloop_params,
                             pipeline,
                             smem_pipe_write,
                             shared_storage,
                             pre_fix_tokens,
                             tokens,
                             bidm,
                             bidn,
                             bidb,
                             tidx);
  } else {
    // Consumer warp groups: grow register budget for the accumulators.
    cutlass::arch::warpgroup_reg_alloc<232>();
    PipelineState smem_pipe_read;

    constexpr int acc_num =
        sizeof(typename Ktraits::Mma::CRegisters) / sizeof(float);
    float acc_s[acc_num];
#pragma unroll
    for (int i = 0; i < acc_num; ++i) {
      acc_s[i] = 0.0f;
    }

    const int reamin_tokens = tokens - bidn * kBlockN;
    const int mma_tidx = tidx - NumCopyThreads;
    const float2 weight_scale = reinterpret_cast<const float2 *>(
        mainloop_params.weight_scale + bidb * M + bidm * kBlockM)[mma_tidx / 4];

    // The last partial tile (fewer than kBlockN tokens) uses the TAIL_N
    // specialization; full tiles use kBlockN.
    if (TAIL_N > 0 && reamin_tokens < kBlockN) {
      collective_mainloop.mma<TAIL_N>(mainloop_params,
                                      pipeline,
                                      smem_pipe_read,
                                      shared_storage,
                                      acc_s,
                                      mma_tidx);
      collective_mainloop.store<TAIL_N>(
          mainloop_params,
          acc_s,
          shared_storage,
          pre_fix_tokens,
          tokens,
          reinterpret_cast<const float *>(&weight_scale),
          bidm,
          bidn,
          bidb,
          mma_tidx);
    } else {
      collective_mainloop.mma<kBlockN>(mainloop_params,
                                       pipeline,
                                       smem_pipe_read,
                                       shared_storage,
                                       acc_s,
                                       mma_tidx);
      collective_mainloop.store<kBlockN>(
          mainloop_params,
          acc_s,
          shared_storage,
          pre_fix_tokens,
          tokens,
          reinterpret_cast<const float *>(&weight_scale),
          bidm,
          bidn,
          bidb,
          mma_tidx);
    }
  }
}
// Row-major (Rows x Cols) global-memory layout replicated over Batch, with
// contiguous batches of stride Rows * Cols.
// (Deduplicated: the original span contained two return statements, the
// second unreachable, left over from an unresolved diff.)
template <int Batch>
auto get_gmem_layout(int Rows, int Cols) {
  return make_layout(make_shape(static_cast<int64_t>(Rows),
                                static_cast<int64_t>(Cols),
                                static_cast<int64_t>(Batch)),
                     make_stride(static_cast<int64_t>(Cols),
                                 cute::_1{},
                                 static_cast<int64_t>(Rows * Cols)));
}
// Global-memory layout for the tiled weight tensor: (Rows x Cols) tiles
// arranged as k_nums x m_nums per batch, contiguous innermost.
// (Deduplicated: the original span contained two return statements, the
// second unreachable, left over from an unresolved diff.)
template <int Batch>
auto get_weight_gmem_layout(int m_nums, int k_nums, int Rows, int Cols) {
  return make_layout(
      make_shape(static_cast<int64_t>(Rows),
                 static_cast<int64_t>(Cols),
                 static_cast<int64_t>(k_nums),
                 static_cast<int64_t>(m_nums),
                 static_cast<int64_t>(Batch)),
      make_stride(static_cast<int64_t>(Cols),
                  cute::_1{},
                  static_cast<int64_t>(Rows * Cols),
                  static_cast<int64_t>(Rows * Cols * k_nums),
                  static_cast<int64_t>(Rows * Cols * k_nums * m_nums)));
}
// Global-memory layout for the packed sparsity metadata E: Cols words per
// thread group, ks_in x ks K-chunks, ms M-tiles per batch.
// NOTE(review): the batch stride is ms * ks * Cols * 2 (not * ks_in) —
// presumably ks_in is always 2 here; confirm against the producer.
// (Deduplicated: the original span contained two return statements, the
// second unreachable, left over from an unresolved diff.)
template <int Batch>
auto get_gmem_e_layout(int ms, int ks, int ks_in, int Cols) {
  return make_layout(make_shape(static_cast<int64_t>(Cols),
                                static_cast<int64_t>(ks_in),
                                static_cast<int64_t>(ks),
                                static_cast<int64_t>(ms),
                                static_cast<int64_t>(Batch)),
                     make_stride(cute::_1{},
                                 static_cast<int64_t>(Cols),
                                 static_cast<int64_t>(ks_in * Cols),
                                 static_cast<int64_t>(ks * ks_in * Cols),
                                 static_cast<int64_t>(ms * ks * Cols * 2)));
}
template <typename InputType, typename OutputType, typename Kernel_traits, int M, int K, int Batch, int kPackTokenSize>
void run_gemm(
const InputType * A,
const uint32_t *E,
const InputType * B,
OutputType * C,
const float *weight_scale,
const int *tokens_idx,
const int max_tokens,
cudaStream_t stream) {
template <typename InputType,
typename OutputType,
typename Kernel_traits,
int M,
int K,
int Batch,
int kPackTokenSize>
void run_gemm(const InputType *A,
const uint32_t *E,
const InputType *B,
OutputType *C,
const float *weight_scale,
const int *tokens_idx,
const int max_tokens,
cudaStream_t stream) {
using ElementOutput = typename Kernel_traits::ElementOutput;
using Element = typename Kernel_traits::Element;
using CollectiveMainloop = CollectiveMainloopFwd<Kernel_traits>;
using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
constexpr int NumMmaThreads = Kernel_traits::NumMmaThreads;
constexpr int kBlockK = Kernel_traits::kBlockK;
constexpr int kBlockM = Kernel_traits::kBlockM;
using ElementOutput = typename Kernel_traits::ElementOutput;
using Element = typename Kernel_traits::Element;
using CollectiveMainloop = CollectiveMainloopFwd<Kernel_traits>;
using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
constexpr int NumMmaThreads = Kernel_traits::NumMmaThreads;
constexpr int kBlockK = Kernel_traits::kBlockK;
constexpr int kBlockM = Kernel_traits::kBlockM;
static_assert(M % Kernel_traits::kBlockM == 0);
constexpr int M_nums = M / Kernel_traits::kBlockM;
const int N_nums =
(max_tokens + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
static_assert(M % Kernel_traits::kBlockM == 0);
constexpr int M_nums = M / Kernel_traits::kBlockM;
const int N_nums = (max_tokens + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
constexpr int kTiles = Kernel_traits::kTiles;
constexpr int kTiles = Kernel_traits::kTiles;
typename CollectiveMainloop::Params mainloop_params =
CollectiveMainloop::to_underlying_arguments({
static_cast<Element const *>(A),
get_weight_gmem_layout<Batch>(M_nums, kTiles, kBlockM / 2, kBlockK),
static_cast<uint32_t const *>(E),
get_gmem_e_layout<Batch>(M_nums, kTiles, kBlockK / 64, NumMmaThreads),
static_cast<Element const *>(B),
get_gmem_layout<Batch>(
kPackTokenSize == 0 ? max_tokens * Batch : kPackTokenSize, K),
static_cast<ElementOutput *>(C),
get_gmem_layout<Batch>(
M, kPackTokenSize == 0 ? max_tokens : kPackTokenSize),
tokens_idx,
weight_scale,
});
typename CollectiveMainloop::Params mainloop_params =
CollectiveMainloop::to_underlying_arguments({
static_cast<Element const*>(A),
get_weight_gmem_layout<Batch>(M_nums, kTiles, kBlockM / 2, kBlockK),
static_cast<uint32_t const*>(E),
get_gmem_e_layout<Batch>(M_nums, kTiles, kBlockK / 64, NumMmaThreads),
static_cast<Element const*>(B),
get_gmem_layout<Batch>(kPackTokenSize == 0 ? max_tokens * Batch : kPackTokenSize, K),
static_cast<ElementOutput*>(C),
get_gmem_layout<Batch>(M, kPackTokenSize == 0 ? max_tokens : kPackTokenSize),
tokens_idx,
weight_scale,
});
void *kernel;
kernel = (void *)w8a8_sparse_gemm_kernel<Kernel_traits>;
void *kernel;
kernel = (void *)w8a8_sparse_gemm_kernel<Kernel_traits>;
int smem_size = sizeof(typename Kernel_traits::SharedStorage);
int smem_size = sizeof(typename Kernel_traits::SharedStorage);
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
dim3 grid_dims;
grid_dims.x = M_nums;
grid_dims.y = N_nums;
grid_dims.z = Batch;
static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
dim3 block_dims(ctaSize);
dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
cutlass::launch_kernel_on_cluster(
launch_params, kernel, mainloop_params);
dim3 grid_dims;
grid_dims.x = M_nums;
grid_dims.y = N_nums;
grid_dims.z = Batch;
static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
dim3 block_dims(ctaSize);
dim3 cluster_dims(size<0>(ClusterShape{}),
size<1>(ClusterShape{}),
size<2>(ClusterShape{}));
cutlass::ClusterLaunchParams launch_params{
grid_dims, block_dims, cluster_dims, smem_size, stream};
cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params);
}
template <typename InputType, typename OutputType, int M, int K, int Batch, int kPackTokenSize>
void w8a8_sparse_gemm(
const InputType * A,
const uint32_t * E,
const InputType * B,
OutputType * C,
const float *weight_scale,
const int *tokens_idx,
const int max_tokens,
cudaStream_t stream) {
constexpr static int kBlockM = 128;
constexpr static int kBlockK = 128;
constexpr static int kNWarps = 4 + kBlockM / 16;
constexpr static int kStages = 5;
constexpr int kCluster = 1;
static_assert(K % kBlockK == 0);
constexpr int kTiles = K / kBlockK;
const int max_tokens_pack16 = (max_tokens + 31) / 32 * 32;
template <typename InputType,
typename OutputType,
int M,
int K,
int Batch,
int kPackTokenSize>
void w8a8_sparse_gemm(const InputType *A,
const uint32_t *E,
const InputType *B,
OutputType *C,
const float *weight_scale,
const int *tokens_idx,
const int max_tokens,
cudaStream_t stream) {
constexpr static int kBlockM = 128;
constexpr static int kBlockK = 128;
constexpr static int kNWarps = 4 + kBlockM / 16;
constexpr static int kStages = 5;
constexpr int kCluster = 1;
static_assert(K % kBlockK == 0);
constexpr int kTiles = K / kBlockK;
const int max_tokens_pack16 = (max_tokens + 31) / 32 * 32;
using Kernel_traits = Kernel_traits<kBlockM, 256, kBlockK, kNWarps, kStages, kTiles, M, kPackTokenSize, 0, kCluster, InputType, OutputType>;
run_gemm<InputType, OutputType, Kernel_traits, M, K, Batch, kPackTokenSize>(A, E, B, C, weight_scale, tokens_idx, max_tokens_pack16, stream);
using Kernel_traits = Kernel_traits<kBlockM,
256,
kBlockK,
kNWarps,
kStages,
kTiles,
M,
kPackTokenSize,
0,
kCluster,
InputType,
OutputType>;
run_gemm<InputType, OutputType, Kernel_traits, M, K, Batch, kPackTokenSize>(
A, E, B, C, weight_scale, tokens_idx, max_tokens_pack16, stream);
}
@@ -21,92 +21,93 @@
#include "wfp8Afp8_sparse_gemm_template.h"
// Dispatches the FP8 sparse GEMM to a compile-time configuration via
// SPARSE_GEMM_SWITCH_BF16. Shrinks kBlockN to the padded token count for
// small batches so no tail specialization is needed.
// (Deduplicated: the original span interleaved two formatted copies of this
// function, redeclaring max_tokens_pack32/kBlockN/TailN.)
template <typename OutputType>
void DisPatchWFp8AFp8Gemm(const cutlass::float_e4m3_t* input,
                          const uint32_t* sparse_idx,
                          const cutlass::float_e4m3_t* weight,
                          const int* tokens,
                          const float* weight_scale,
                          OutputType* out,
                          const int token_padding_size,
                          const int max_tokens,
                          const int batch_size,
                          const int M,
                          const int K,
                          cudaStream_t stream) {
  const int max_tokens_pack32 = (max_tokens + 31) / 32 * 32;

  int kBlockN = 256;
  int TailN = max_tokens_pack32 % kBlockN;
  if (max_tokens < 256) {
    kBlockN = max_tokens_pack32;
    TailN = 0;
  }
  if constexpr (std::is_same_v<OutputType, cutlass::bfloat16_t>) {
    SPARSE_GEMM_SWITCH_BF16(M,
                            K,
                            batch_size,
                            token_padding_size,
                            kBlockN,
                            TailN,
                            weight,
                            sparse_idx,
                            input,
                            out,
                            weight_scale,
                            tokens,
                            max_tokens,
                            stream)
  } else {
    PD_THROW("Only supported dtype in ['BFLOAT16'].");
  }
}
// Paddle custom-op entry: validates dtypes, derives (batch, M, K) from the
// packed weight tensor (K is stored halved because of 2:4 sparsity) and
// dispatches to the templated implementation.
// (Deduplicated: the original span interleaved two formatted copies of this
// function, duplicating the body statements.)
void WFp8AFp8Gemm(const paddle::Tensor& input,
                  const paddle::Tensor& sparse_idx,
                  const paddle::Tensor& weight,
                  const paddle::Tensor&
                      tokens,  // If tokenpadding=0, this tensor represents the
                               // prefix sum of tensors, otherwise it represents
                               // the number of tokens in each group
                  const paddle::Tensor& weight_scale,
                  const paddle::Tensor& out,
                  const int token_padding_size,
                  const int max_tokens,
                  const bool is_bfloat16) {
  const int batch_size = weight.dims()[0];
  const int M = weight.dims()[1];
  // Weight stores K/2 elements per row (2:4 sparse), so the logical K is x2.
  const int K = weight.dims()[2] * 2;

  if (input.dtype() != paddle::DataType::FLOAT8_E4M3FN) {
    PD_THROW("Only supported dtype in ['FLOAT8_E4M3FN'].");
  }

  if (is_bfloat16) {
    DisPatchWFp8AFp8Gemm(
        reinterpret_cast<const cutlass::float_e4m3_t*>(
            input.data<phi::dtype::float8_e4m3fn>()),
        reinterpret_cast<const uint32_t*>(sparse_idx.data<int32_t>()),
        reinterpret_cast<const cutlass::float_e4m3_t*>(
            weight.data<phi::dtype::float8_e4m3fn>()),
        tokens.data<int>(),
        weight_scale.data<float>(),
        reinterpret_cast<cutlass::bfloat16_t*>(
            const_cast<phi::dtype::bfloat16*>(
                out.data<phi::dtype::bfloat16>())),
        token_padding_size,
        max_tokens,
        batch_size,
        M,
        K,
        input.stream());
  } else {
    PD_THROW("Only supported dtype in ['BFLOAT16'].");
  }
}
// Static op registration; "ffn_out" is updated in place and aliased to "out".
// (Deduplicated: the original span contained two .Inputs and two .Attrs
// chains left over from an unresolved diff.)
PD_BUILD_STATIC_OP(wfp8afp8_sparse_gemm)
    .Inputs(
        {"input", "sparse_idx", "weight", "tokens", "weight_scale", "ffn_out"})
    .Outputs({"out"})
    .SetInplaceMap({{"ffn_out", "out"}})
    .Attrs({"token_padding_size: int", "max_tokens: int", "is_bfloat16: bool"})
    .SetKernelFn(PD_KERNEL(WFp8AFp8Gemm));
@@ -19,78 +19,95 @@
#include "helper.h"
#include "paddle/extension.h"
// Packs per-group sparsity selector indices into 32-bit metadata words.
//
// Each byte of E_src is an index in [0, 5] naming which pair of elements is
// kept out of a group of four; select_idx maps it to the 4-bit hardware code.
// Eight consecutive codes are packed per output word, lowest k in the lowest
// nibble. Layout: Batch x M rows of K/4 selector bytes -> K/32 words per row.
//
// Selected element pair -> hex metadata nibble:
//   01 -> 4, 02 -> 8, 03 -> 12, 12 -> 9, 13 -> 13, 23 -> 14
//
// (Deduplicated: the original span contained a truncated old copy of this
// function fused with the reformatted one.)
void pack_E(const uint8_t *E_src,
            int32_t *E_dst,
            const int M,
            const int K,
            const int Batch) {
  const int ld1 = K / 4;      // selector bytes per row
  const int ld2 = K / 4 / 8;  // packed words per row
  const uint8_t select_idx[6] = {14, 13, 9, 12, 8, 4};
  for (int b = 0; b < Batch; ++b) {
    for (int m = 0; m < M; ++m) {
      for (int k = 0; k < ld1; k += 8) {
        uint32_t dst = 0;
        // Accumulate nibbles from the highest selector down, shifting left
        // so E_src[k + k2] lands in bits [4*k2, 4*k2 + 3].
        for (int k2 = 7; k2 > 0; --k2) {
          dst |= select_idx[E_src[b * M * ld1 + m * ld1 + k + k2]];
          dst <<= 4;
        }
        dst |= select_idx[E_src[b * M * ld1 + m * ld1 + k]];
        E_dst[b * M * ld2 + m * ld2 + k / 8] = static_cast<int32_t>(dst);
      }
    }
  }
}
// Permutes packed metadata words within each 128x128 tile (512 int32 words)
// into the order the kernel's shared-memory E layout expects. Every word of
// the tile is written exactly once.
//
// NOTE(review): "peruate" is presumably a typo for "permute"; renaming would
// break callers, so the name is kept.
// (Deduplicated: the original span contained a truncated old copy of this
// function fused with the reformatted one.)
void peruate_E(const int32_t *E_src,
               int32_t *E_dst,
               const int M,
               const int K,
               const int Batch) {
  const int m_nums = M / 128;
  const int k_nums = K / 128;
  for (int b = 0; b < Batch; ++b) {
    for (int m = 0; m < m_nums; ++m) {
      for (int k = 0; k < k_nums; ++k) {
        // Base offset of this 512-word tile (same for src and dst).
        const int dst_idx =
            b * m_nums * k_nums * 512 + m * k_nums * 512 + k * 512;
        for (int i = 0; i < 8; ++i) {
          // First half of the tile: interleave word pairs {0,1} and {32,33}.
          for (int j = 0; j < 8; ++j) {
            E_dst[dst_idx + 0 + j * 32 + 4 * i] =
                E_src[dst_idx + 0 + j * 64 + 4 * i];
            E_dst[dst_idx + 2 + j * 32 + 4 * i] =
                E_src[dst_idx + 1 + j * 64 + 4 * i];
            E_dst[dst_idx + 1 + j * 32 + 4 * i] =
                E_src[dst_idx + 32 + j * 64 + 4 * i];
            E_dst[dst_idx + 3 + j * 32 + 4 * i] =
                E_src[dst_idx + 33 + j * 64 + 4 * i];
          }
          // Second half: same interleave, sourced from pairs {2,3}/{34,35}.
          for (int j = 0; j < 8; ++j) {
            E_dst[dst_idx + 256 + j * 32 + 4 * i] =
                E_src[dst_idx + 2 + j * 64 + 4 * i];
            E_dst[dst_idx + 258 + j * 32 + 4 * i] =
                E_src[dst_idx + 3 + j * 64 + 4 * i];
            E_dst[dst_idx + 257 + j * 32 + 4 * i] =
                E_src[dst_idx + 34 + j * 64 + 4 * i];
            E_dst[dst_idx + 259 + j * 32 + 4 * i] =
                E_src[dst_idx + 35 + j * 64 + 4 * i];
          }
        }
      }
    }
  }
}
// Converts 2:4 sparsity selection indices (one uint8 per 4-element group in
// `weight`) into packed, tile-permuted int32 metadata for the sparse fp8
// GEMM kernel. Returns a tensor of shape [batch_size, M, K / 32] (INT32).
// Assumes M and K are multiples of 128 (required by peruate_E).
// NOTE(review): pack_E/peruate_E iterate over weight.data<>() on the host;
// this presumes the tensors are in CPU-accessible memory — confirm.
std::vector<paddle::Tensor> WFp8AFp8GemmSparseIdxConvert(
    const paddle::Tensor &weight,
    const int batch_size,
    const int M,
    const int K) {
  paddle::Tensor weight_temp = paddle::empty(
      {batch_size, M, K / 32}, paddle::DataType::INT32, weight.place());
  paddle::Tensor weight_new = paddle::empty(
      {batch_size, M, K / 32}, paddle::DataType::INT32, weight.place());
  // Stage 1: pack eight 4-bit metadata codes into each int32 word.
  pack_E(weight.data<uint8_t>(), weight_temp.data<int32_t>(), M, K, batch_size);
  // Stage 2: permute the packed words into the kernel's tile layout.
  peruate_E(weight_temp.data<int32_t>(),
            weight_new.data<int32_t>(),
            M,
            K,
            batch_size);
  return {weight_new};
}
// Registers the conversion as a static Paddle custom op. The original chunk
// chained `.Attrs` twice (stale formatting residue); a single attribute list
// is kept here.
PD_BUILD_STATIC_OP(wfp8afp8_gemm_sparse_idx_convert)
    .Inputs({"weight"})
    .Outputs({"converted_weight"})
    .Attrs({"batch: int", "M: int", "K: int"})
    .SetKernelFn(PD_KERNEL(WFp8AFp8GemmSparseIdxConvert));