// Mirror of https://github.com/PaddlePaddle/FastDeploy.git
// Synced 2026-04-23 00:17:25 +08:00, commit ddb06ff83f
// Co-authored-by: gongweibao <gognweibao@baidu.com>
// Original file: 56 lines, 1.9 KiB
#pragma once
#include "utils/machete_collective_builder.cuh"
#include "machete_mainloop.cuh"
namespace cutlass::gemm::collective {

using namespace cute;

// Empty tag type. Passing it as the first template argument of
// MacheteCollectiveBuilder selects the Machete specialization defined below.
struct MacheteKernelTag {};

// MacheteCollectiveBuilder specialization for MacheteKernelTag on
// arch::Sm90 with arch::OpClassTensorOp.
//
// Participation: the final template argument is
// cute::enable_if_t<...>. This specialization therefore only matches when
// KernelScheduleType is one of these schedules:
//   - KernelTmaWarpSpecializedCooperative
//   - KernelTmaWarpSpecializedPingpong
//   - KernelTmaWarpSpecialized
// NOTE(review): this relies on the primary template's last parameter
// defaulting to void. That parameter is declared in
// utils/machete_collective_builder.cuh; confirm there.
//
// Result: CollectiveOp is machete::MacheteCollectiveMma instantiated with
// every builder argument (except the tag / arch / op-class selectors)
// forwarded unchanged and in the same order.
template <class ElementPairA_, class GmemLayoutA_, int AlignmentA,
          class ElementPairB_, class GmemLayoutB_, int AlignmentB,
          class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
          class StageCountType, class KernelScheduleType>
struct MacheteCollectiveBuilder<
    MacheteKernelTag, arch::Sm90, arch::OpClassTensorOp,
    ElementPairA_, GmemLayoutA_, AlignmentA,
    ElementPairB_, GmemLayoutB_, AlignmentB,
    ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType,
    KernelScheduleType,
    cute::enable_if_t<(
        cute::is_same_v<KernelScheduleType,
                        KernelTmaWarpSpecializedCooperative> ||
        cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
        cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized>)>> {
  using CollectiveOp = machete::MacheteCollectiveMma<
      ElementPairA_, GmemLayoutA_, AlignmentA,
      ElementPairB_, GmemLayoutB_, AlignmentB,
      ElementAccumulator, TileShape_MNK, ClusterShape_MNK,
      StageCountType, KernelScheduleType>;
};

}  // namespace cutlass::gemm::collective