Co-authored-by: gongweibao <gognweibao@baidu.com>
This commit is contained in:
gongweibao
2026-03-04 21:55:31 +08:00
committed by GitHub
parent 5c8f5184d9
commit ddb06ff83f
306 changed files with 40627 additions and 34418 deletions
@@ -37,8 +37,12 @@ struct IlvBlkLayoutAuto {};
// The contract here is that the `TiledMma` determined below matches the one
// ultimately used in the kernel. (this is also why the other element types are
// required along with the kernel schedule)
template <typename ElementA_, typename ElementB_, typename ElementConvert_,
typename AccumulatorT, class LayoutB, class KernelSchedule,
template <typename ElementA_,
typename ElementB_,
typename ElementConvert_,
typename AccumulatorT,
class LayoutB,
class KernelSchedule,
typename IlvBlkLayout_ = IlvBlkLayoutAuto>
// clang-format on
struct PrepackedLayoutBTemplate {
@@ -63,8 +67,9 @@ struct PrepackedLayoutBTemplate {
std::is_same_v<IlvBlkLayout_, IlvBlkLayoutAuto>,
std::conditional_t<
should_interleave,
decltype(get_interleaved_blk_layout<
ElementB, sizeof_bits_v<ElementConvert_>, 32>()),
decltype(get_interleaved_blk_layout<ElementB,
sizeof_bits_v<ElementConvert_>,
32>()),
void>,
IlvBlkLayout_>;
@@ -85,8 +90,8 @@ struct PrepackedLayoutBTemplate {
// registers.
// The _128 here doesn't actually impact the shape of the stored tile directly
// but may impact the op selected by rs_op_selector
using GemmTileShape = decltype(make_shape(size<0>(PPBlockShape_NK{}), _128{},
size<1>(PPBlockShape_NK{})));
using GemmTileShape = decltype(make_shape(
size<0>(PPBlockShape_NK{}), _128{}, size<1>(PPBlockShape_NK{})));
static constexpr cute::GMMA::Major GmmaMajorB =
gmma_rs_tag_to_major_B<LayoutB>();
@@ -95,11 +100,16 @@ struct PrepackedLayoutBTemplate {
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
using AtomLayoutMNK = cute::conditional_t<
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperative>,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
Layout<Shape<_2, _1, _1>>,
Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<ElementMma, ElementMma, ElementAccumulator,
GemmTileShape, GMMA::Major::K, GmmaMajorB>(),
cute::GMMA::rs_op_selector<ElementMma,
ElementMma,
ElementAccumulator,
GemmTileShape,
GMMA::Major::K,
GmmaMajorB>(),
AtomLayoutMNK{}));
// Prepacked block, (athrid, val) -> (N,K)
@@ -163,9 +173,9 @@ struct PrepackedLayoutBTemplate {
constexpr auto block_layout = ppblock_TV_to_offset();
// (BlocksN, BlocksK, L)
auto blocks_shape =
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
[](auto x, auto y) { return x / y; });
auto blocks_shape = cute::transform(shape_mkl,
append(PPBlockShape_NK{}, _1{}),
[](auto x, auto y) { return x / y; });
// ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
auto result = make_layout(
@@ -198,9 +208,9 @@ struct PrepackedLayoutBTemplate {
constexpr auto block_layout = ppblock_ilvd_NK_to_offset();
// (BlocksN, BlocksK, L)
auto blocks_shape =
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
[](auto x, auto y) { return x / y; });
auto blocks_shape = cute::transform(shape_mkl,
append(PPBlockShape_NK{}, _1{}),
[](auto x, auto y) { return x / y; });
// ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
auto result = make_layout(
@@ -217,9 +227,9 @@ struct PrepackedLayoutBTemplate {
template <typename Shape_NKL>
CUTE_HOST_DEVICE static constexpr auto bNbKL_to_offset(Shape_NKL shape_mkl) {
// (BlocksN, BlocksK, L)
auto blocks_shape =
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
[](auto x, auto y) { return x / y; });
auto blocks_shape = cute::transform(shape_mkl,
append(PPBlockShape_NK{}, _1{}),
[](auto x, auto y) { return x / y; });
auto stride = size(PPBlockShape_NK{});
// (BlocksN, BlocksK, L) -> (storage_idx)