mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 09:44:10 +08:00
@@ -34,9 +34,15 @@ using namespace cute;
|
||||
// register. Since the primary use case we want to support is Y = XW^t where
|
||||
// W is quantized, in this situation our right-hand operand is quantized so
|
||||
// we compute the transpose to move it to the left-hand side.
|
||||
template <typename ElementA_, typename ElementB_, typename ElementD_,
|
||||
typename AccumulatorT, typename GroupScaleT, typename GroupZeroT,
|
||||
typename ChannelScaleT, typename TokenScaleT, class KernelSchedule,
|
||||
template <typename ElementA_,
|
||||
typename ElementB_,
|
||||
typename ElementD_,
|
||||
typename AccumulatorT,
|
||||
typename GroupScaleT,
|
||||
typename GroupZeroT,
|
||||
typename ChannelScaleT,
|
||||
typename TokenScaleT,
|
||||
class KernelSchedule,
|
||||
typename ScheduleConfig>
|
||||
struct MacheteKernelTemplate {
|
||||
static constexpr bool with_C = false; // not ever used
|
||||
@@ -97,9 +103,12 @@ struct MacheteKernelTemplate {
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
using PrepackedLayoutB =
|
||||
PrepackedLayoutBTemplate<ElementA_, ElementB_, ElementConvertGroup,
|
||||
AccumulatorT, LayoutA_Transpose, KernelSchedule>;
|
||||
using PrepackedLayoutB = PrepackedLayoutBTemplate<ElementA_,
|
||||
ElementB_,
|
||||
ElementConvertGroup,
|
||||
AccumulatorT,
|
||||
LayoutA_Transpose,
|
||||
KernelSchedule>;
|
||||
|
||||
static int constexpr TileShapeK =
|
||||
128 * 8 / cutlass::sizeof_bits<MmaType>::value;
|
||||
@@ -123,9 +132,8 @@ struct MacheteKernelTemplate {
|
||||
"Currently token and channel scales (if present) must be the same type");
|
||||
|
||||
// Currently only supports float scales
|
||||
using ChTokScalesEpilogue =
|
||||
typename fastdeploy::c3x::ScaledEpilogue<ElementAccumulator, ElementD,
|
||||
TileShape>;
|
||||
using ChTokScalesEpilogue = typename fastdeploy::c3x::
|
||||
ScaledEpilogue<ElementAccumulator, ElementD, TileShape>;
|
||||
static_assert((with_channel_scales || with_token_scales) ||
|
||||
(std::is_same_v<ElementSChannel, float> &&
|
||||
std::is_same_v<ElementSToken, float>),
|
||||
@@ -143,23 +151,45 @@ struct MacheteKernelTemplate {
|
||||
// EVTCompute
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
|
||||
ElementAccumulator, ElementSChannel, ElementC, LayoutC_Transpose,
|
||||
AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, EpilogueSchedule,
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
TileShape,
|
||||
ClusterShape,
|
||||
EpilogueTileType,
|
||||
ElementAccumulator,
|
||||
ElementSChannel,
|
||||
ElementC,
|
||||
LayoutC_Transpose,
|
||||
AlignmentC,
|
||||
ElementD,
|
||||
LayoutD_Transpose,
|
||||
AlignmentD,
|
||||
EpilogueSchedule,
|
||||
EVTCompute>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::MacheteCollectiveBuilder<
|
||||
cutlass::gemm::collective::MacheteKernelTag, ArchTag, OperatorClass,
|
||||
BTypeTuple, PrepackedLayoutB, AlignmentB, ElementA, LayoutA_Transpose,
|
||||
AlignmentA, ElementAccumulator, TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::MacheteKernelTag,
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
BTypeTuple,
|
||||
PrepackedLayoutB,
|
||||
AlignmentB,
|
||||
ElementA,
|
||||
LayoutA_Transpose,
|
||||
AlignmentA,
|
||||
ElementAccumulator,
|
||||
TileShape,
|
||||
ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, // Indicates ProblemShape
|
||||
CollectiveMainloop, CollectiveEpilogue, TileScheduler>;
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
TileScheduler>;
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
// stride_B is unused (since B is prepacked), but still required by cutlass
|
||||
@@ -194,9 +224,7 @@ struct MacheteKernelTemplate {
|
||||
int64_t numel_S_channel = maybe_ch_scales ? maybe_ch_scales->numel() : 0;
|
||||
int64_t numel_S_token = maybe_tok_scales ? maybe_tok_scales->numel() : 0;
|
||||
|
||||
auto unwrap = [](auto const& t) {
|
||||
return t ? t->data() : nullptr;
|
||||
};
|
||||
auto unwrap = [](auto const& t) { return t ? t->data() : nullptr; };
|
||||
auto A_ptr = static_cast<ElementA const*>(A.data());
|
||||
auto B_ptr = static_cast<ElementB const*>(B.data());
|
||||
auto D_ptr = static_cast<ElementD*>(D.data());
|
||||
@@ -218,7 +246,7 @@ struct MacheteKernelTemplate {
|
||||
if constexpr (with_group_scales) {
|
||||
PD_CHECK(S_group_ptr && layout_S_group);
|
||||
PD_CHECK((size<0>(*layout_S_group) == scale_k &&
|
||||
size<1>(*layout_S_group) == N));
|
||||
size<1>(*layout_S_group) == N));
|
||||
} else {
|
||||
PD_CHECK(!S_group_ptr, "Scales not supported");
|
||||
}
|
||||
@@ -226,9 +254,9 @@ struct MacheteKernelTemplate {
|
||||
if constexpr (with_group_zeropoints) {
|
||||
PD_CHECK(Z_group_ptr && layout_Z_group);
|
||||
PD_CHECK((size<0>(*layout_Z_group) == scale_k &&
|
||||
size<1>(*layout_Z_group) == N));
|
||||
size<1>(*layout_Z_group) == N));
|
||||
PD_CHECK(layout_S_group && *layout_Z_group == *layout_S_group,
|
||||
"Scales and zeros must have the same layout");
|
||||
"Scales and zeros must have the same layout");
|
||||
} else {
|
||||
PD_CHECK(!Z_group_ptr, "Zeropoints not supported");
|
||||
}
|
||||
@@ -263,14 +291,23 @@ struct MacheteKernelTemplate {
|
||||
|
||||
if constexpr (with_group_scales && with_group_zeropoints) {
|
||||
auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride();
|
||||
mainloop_arguments = MainloopArguments{
|
||||
B_ptr, _StrideB{}, A_ptr, stride_At,
|
||||
S_group_ptr, stride_S_group, group_size, Z_group_ptr};
|
||||
mainloop_arguments = MainloopArguments{B_ptr,
|
||||
_StrideB{},
|
||||
A_ptr,
|
||||
stride_At,
|
||||
S_group_ptr,
|
||||
stride_S_group,
|
||||
group_size,
|
||||
Z_group_ptr};
|
||||
} else if constexpr (with_group_scales) {
|
||||
auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride();
|
||||
mainloop_arguments =
|
||||
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At,
|
||||
S_group_ptr, stride_S_group, group_size};
|
||||
mainloop_arguments = MainloopArguments{B_ptr,
|
||||
_StrideB{},
|
||||
A_ptr,
|
||||
stride_At,
|
||||
S_group_ptr,
|
||||
stride_S_group,
|
||||
group_size};
|
||||
} else {
|
||||
mainloop_arguments =
|
||||
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At};
|
||||
@@ -295,7 +332,7 @@ struct MacheteKernelTemplate {
|
||||
|
||||
cutlass::Status status = gemm_op.initialize(args, workspace, stream);
|
||||
PD_CHECK(status == cutlass::Status::kSuccess,
|
||||
"Machete kernel failed to initialize workspace");
|
||||
"Machete kernel failed to initialize workspace");
|
||||
|
||||
status = gemm_op.run(stream);
|
||||
PD_CHECK(status == cutlass::Status::kSuccess, "Machete kernel failed");
|
||||
|
||||
Reference in New Issue
Block a user