Co-authored-by: gongweibao <gognweibao@baidu.com>
This commit is contained in:
gongweibao
2026-03-04 21:55:31 +08:00
committed by GitHub
parent 5c8f5184d9
commit ddb06ff83f
306 changed files with 40627 additions and 34418 deletions
@@ -34,9 +34,15 @@ using namespace cute;
// register. Since the primary use case we want to support is Y = XW^t where
// W is quantized, in this situation or right-hand operand is quantized so
// we compute the transpose to move it to the left-hand side.
template <typename ElementA_, typename ElementB_, typename ElementD_,
typename AccumulatorT, typename GroupScaleT, typename GroupZeroT,
typename ChannelScaleT, typename TokenScaleT, class KernelSchedule,
template <typename ElementA_,
typename ElementB_,
typename ElementD_,
typename AccumulatorT,
typename GroupScaleT,
typename GroupZeroT,
typename ChannelScaleT,
typename TokenScaleT,
class KernelSchedule,
typename ScheduleConfig>
struct MacheteKernelTemplate {
static constexpr bool with_C = false; // not ever used
@@ -97,9 +103,12 @@ struct MacheteKernelTemplate {
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using PrepackedLayoutB =
PrepackedLayoutBTemplate<ElementA_, ElementB_, ElementConvertGroup,
AccumulatorT, LayoutA_Transpose, KernelSchedule>;
using PrepackedLayoutB = PrepackedLayoutBTemplate<ElementA_,
ElementB_,
ElementConvertGroup,
AccumulatorT,
LayoutA_Transpose,
KernelSchedule>;
static int constexpr TileShapeK =
128 * 8 / cutlass::sizeof_bits<MmaType>::value;
@@ -123,9 +132,8 @@ struct MacheteKernelTemplate {
"Currently token and channel scales (if present) must be the same type");
// Currently only supports float scales
using ChTokScalesEpilogue =
typename fastdeploy::c3x::ScaledEpilogue<ElementAccumulator, ElementD,
TileShape>;
using ChTokScalesEpilogue = typename fastdeploy::c3x::
ScaledEpilogue<ElementAccumulator, ElementD, TileShape>;
static_assert((with_channel_scales || with_token_scales) ||
(std::is_same_v<ElementSChannel, float> &&
std::is_same_v<ElementSToken, float>),
@@ -143,23 +151,45 @@ struct MacheteKernelTemplate {
// EVTCompute
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
ElementAccumulator, ElementSChannel, ElementC, LayoutC_Transpose,
AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, EpilogueSchedule,
ArchTag,
OperatorClass,
TileShape,
ClusterShape,
EpilogueTileType,
ElementAccumulator,
ElementSChannel,
ElementC,
LayoutC_Transpose,
AlignmentC,
ElementD,
LayoutD_Transpose,
AlignmentD,
EpilogueSchedule,
EVTCompute>::CollectiveOp;
using CollectiveMainloop =
typename cutlass::gemm::collective::MacheteCollectiveBuilder<
cutlass::gemm::collective::MacheteKernelTag, ArchTag, OperatorClass,
BTypeTuple, PrepackedLayoutB, AlignmentB, ElementA, LayoutA_Transpose,
AlignmentA, ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::MacheteKernelTag,
ArchTag,
OperatorClass,
BTypeTuple,
PrepackedLayoutB,
AlignmentB,
ElementA,
LayoutA_Transpose,
AlignmentA,
ElementAccumulator,
TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, // Indicates ProblemShape
CollectiveMainloop, CollectiveEpilogue, TileScheduler>;
CollectiveMainloop,
CollectiveEpilogue,
TileScheduler>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// stride_B is unused (since B is prepacked), but still required by cutlass
@@ -194,9 +224,7 @@ struct MacheteKernelTemplate {
int64_t numel_S_channel = maybe_ch_scales ? maybe_ch_scales->numel() : 0;
int64_t numel_S_token = maybe_tok_scales ? maybe_tok_scales->numel() : 0;
auto unwrap = [](auto const& t) {
return t ? t->data() : nullptr;
};
auto unwrap = [](auto const& t) { return t ? t->data() : nullptr; };
auto A_ptr = static_cast<ElementA const*>(A.data());
auto B_ptr = static_cast<ElementB const*>(B.data());
auto D_ptr = static_cast<ElementD*>(D.data());
@@ -218,7 +246,7 @@ struct MacheteKernelTemplate {
if constexpr (with_group_scales) {
PD_CHECK(S_group_ptr && layout_S_group);
PD_CHECK((size<0>(*layout_S_group) == scale_k &&
size<1>(*layout_S_group) == N));
size<1>(*layout_S_group) == N));
} else {
PD_CHECK(!S_group_ptr, "Scales not supported");
}
@@ -226,9 +254,9 @@ struct MacheteKernelTemplate {
if constexpr (with_group_zeropoints) {
PD_CHECK(Z_group_ptr && layout_Z_group);
PD_CHECK((size<0>(*layout_Z_group) == scale_k &&
size<1>(*layout_Z_group) == N));
size<1>(*layout_Z_group) == N));
PD_CHECK(layout_S_group && *layout_Z_group == *layout_S_group,
"Scales and zeros must have the same layout");
"Scales and zeros must have the same layout");
} else {
PD_CHECK(!Z_group_ptr, "Zeropoints not supported");
}
@@ -263,14 +291,23 @@ struct MacheteKernelTemplate {
if constexpr (with_group_scales && with_group_zeropoints) {
auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride();
mainloop_arguments = MainloopArguments{
B_ptr, _StrideB{}, A_ptr, stride_At,
S_group_ptr, stride_S_group, group_size, Z_group_ptr};
mainloop_arguments = MainloopArguments{B_ptr,
_StrideB{},
A_ptr,
stride_At,
S_group_ptr,
stride_S_group,
group_size,
Z_group_ptr};
} else if constexpr (with_group_scales) {
auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride();
mainloop_arguments =
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At,
S_group_ptr, stride_S_group, group_size};
mainloop_arguments = MainloopArguments{B_ptr,
_StrideB{},
A_ptr,
stride_At,
S_group_ptr,
stride_S_group,
group_size};
} else {
mainloop_arguments =
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At};
@@ -295,7 +332,7 @@ struct MacheteKernelTemplate {
cutlass::Status status = gemm_op.initialize(args, workspace, stream);
PD_CHECK(status == cutlass::Status::kSuccess,
"Machete kernel failed to initialize workspace");
"Machete kernel failed to initialize workspace");
status = gemm_op.run(stream);
PD_CHECK(status == cutlass::Status::kSuccess, "Machete kernel failed");