Files
FastDeploy/custom_ops/gpu_ops/machete/utils/machete_collective_builder.cuh
T
gongweibao ddb06ff83f init (#6642)
Co-authored-by: gongweibao <gognweibao@baidu.com>
2026-03-04 21:55:31 +08:00

85 lines
2.9 KiB
Plaintext

// adapted from:
// https://github.com/vllm-project/vllm/blob/main/csrc/cutlass_extensions/vllm_collective_builder.cuh
#pragma once
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
namespace cutlass::gemm::collective {
using namespace cute;
//
// MacheteCollectiveBuilder is a wrapper around CollectiveBuilder that allows
// for custom kernel tags, so custom collectives can be built without touching
// the CUTLASS library headers. Passing `CutlassKernelTag` as the kernel tag
// falls back to the standard CUTLASS collective builder.
//
// Kernel tag that selects the default CUTLASS collective builder, i.e. an
// unmodified CUTLASS collective. When it is passed as the `KernelTag` argument
// of MacheteCollectiveBuilder, the request is forwarded unchanged to CUTLASS's
// CollectiveBuilder.
struct CutlassKernelTag {
  // Intentionally empty: the type carries no data and only selects a
  // specialization.
};
// Primary template: selected only when no partial specialization matches the
// given KernelTag / parameter combination, in which case instantiation is a
// compile-time error. The assertion condition depends on a template parameter,
// so it is evaluated only when this template is actually instantiated, not
// when the header is parsed.
//
// All parameters after KernelTag (except Enable) are the same, in the same
// order, as CUTLASS's CollectiveBuilder; see the CUTLASS documentation for
// their meaning. `Enable` defaults to void and is unused here — presumably a
// hook for enable_if-constrained specializations (TODO confirm).
template <class KernelTag,
          class ArchTag, class OpClass,
          class ElementA, class GmemLayoutA, int AlignmentA,
          class ElementB, class GmemLayoutB, int AlignmentB,
          class ElementAccumulator,
          class TileShape_MNK, class ClusterShape_MNK,
          class StageCountType, class KernelScheduleType,
          class Enable = void>
struct MacheteCollectiveBuilder {
  static_assert(0 == sizeof(ElementA),
                "Could not build a collective for given parameters.");
};
template <class ArchTag,
class OpClass,
class ElementA,
class GmemLayoutA,
int AlignmentA,
class ElementB,
class GmemLayoutB,
int AlignmentB,
class ElementAccumulator,
class TileShape_MNK,
class ClusterShape_MNK,
class StageCountType,
class KernelScheduleType>
struct MacheteCollectiveBuilder<CutlassKernelTag,
ArchTag,
OpClass,
ElementA,
GmemLayoutA,
AlignmentA,
ElementB,
GmemLayoutB,
AlignmentB,
ElementAccumulator,
TileShape_MNK,
ClusterShape_MNK,
StageCountType,
KernelScheduleType> {
using CollectiveOp =
typename CollectiveBuilder<ArchTag,
OpClass,
ElementA,
GmemLayoutA,
AlignmentA,
ElementB,
GmemLayoutB,
AlignmentB,
ElementAccumulator,
TileShape_MNK,
ClusterShape_MNK,
StageCountType,
KernelScheduleType>::CollectiveOp;
};
}; // namespace cutlass::gemm::collective