mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 17:11:21 +08:00
@@ -37,8 +37,12 @@ struct IlvBlkLayoutAuto {};
|
||||
// The contract here is that the `TiledMma` determined below matches the one
|
||||
// ultimately used in the kernel. (this is also why the other element types are
|
||||
// required along with the kernel schedule)
|
||||
template <typename ElementA_, typename ElementB_, typename ElementConvert_,
|
||||
typename AccumulatorT, class LayoutB, class KernelSchedule,
|
||||
template <typename ElementA_,
|
||||
typename ElementB_,
|
||||
typename ElementConvert_,
|
||||
typename AccumulatorT,
|
||||
class LayoutB,
|
||||
class KernelSchedule,
|
||||
typename IlvBlkLayout_ = IlvBlkLayoutAuto>
|
||||
// clang-format on
|
||||
struct PrepackedLayoutBTemplate {
|
||||
@@ -63,8 +67,9 @@ struct PrepackedLayoutBTemplate {
|
||||
std::is_same_v<IlvBlkLayout_, IlvBlkLayoutAuto>,
|
||||
std::conditional_t<
|
||||
should_interleave,
|
||||
decltype(get_interleaved_blk_layout<
|
||||
ElementB, sizeof_bits_v<ElementConvert_>, 32>()),
|
||||
decltype(get_interleaved_blk_layout<ElementB,
|
||||
sizeof_bits_v<ElementConvert_>,
|
||||
32>()),
|
||||
void>,
|
||||
IlvBlkLayout_>;
|
||||
|
||||
@@ -85,8 +90,8 @@ struct PrepackedLayoutBTemplate {
|
||||
// registers.
|
||||
// The _128 here doesn't actually impact the shape of the stored tile directly
|
||||
// but may impact the op selected by rs_op_selector
|
||||
using GemmTileShape = decltype(make_shape(size<0>(PPBlockShape_NK{}), _128{},
|
||||
size<1>(PPBlockShape_NK{})));
|
||||
using GemmTileShape = decltype(make_shape(
|
||||
size<0>(PPBlockShape_NK{}), _128{}, size<1>(PPBlockShape_NK{})));
|
||||
|
||||
static constexpr cute::GMMA::Major GmmaMajorB =
|
||||
gmma_rs_tag_to_major_B<LayoutB>();
|
||||
@@ -95,11 +100,16 @@ struct PrepackedLayoutBTemplate {
|
||||
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
|
||||
using AtomLayoutMNK = cute::conditional_t<
|
||||
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperative>,
|
||||
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
|
||||
Layout<Shape<_2, _1, _1>>,
|
||||
Layout<Shape<_1, _1, _1>>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(
|
||||
cute::GMMA::rs_op_selector<ElementMma, ElementMma, ElementAccumulator,
|
||||
GemmTileShape, GMMA::Major::K, GmmaMajorB>(),
|
||||
cute::GMMA::rs_op_selector<ElementMma,
|
||||
ElementMma,
|
||||
ElementAccumulator,
|
||||
GemmTileShape,
|
||||
GMMA::Major::K,
|
||||
GmmaMajorB>(),
|
||||
AtomLayoutMNK{}));
|
||||
|
||||
// Prepacked block, (athrid, val) -> (N,K)
|
||||
@@ -163,9 +173,9 @@ struct PrepackedLayoutBTemplate {
|
||||
constexpr auto block_layout = ppblock_TV_to_offset();
|
||||
|
||||
// (BlocksN, BlocksK, L)
|
||||
auto blocks_shape =
|
||||
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
|
||||
[](auto x, auto y) { return x / y; });
|
||||
auto blocks_shape = cute::transform(shape_mkl,
|
||||
append(PPBlockShape_NK{}, _1{}),
|
||||
[](auto x, auto y) { return x / y; });
|
||||
|
||||
// ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
|
||||
auto result = make_layout(
|
||||
@@ -198,9 +208,9 @@ struct PrepackedLayoutBTemplate {
|
||||
constexpr auto block_layout = ppblock_ilvd_NK_to_offset();
|
||||
|
||||
// (BlocksN, BlocksK, L)
|
||||
auto blocks_shape =
|
||||
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
|
||||
[](auto x, auto y) { return x / y; });
|
||||
auto blocks_shape = cute::transform(shape_mkl,
|
||||
append(PPBlockShape_NK{}, _1{}),
|
||||
[](auto x, auto y) { return x / y; });
|
||||
|
||||
// ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
|
||||
auto result = make_layout(
|
||||
@@ -217,9 +227,9 @@ struct PrepackedLayoutBTemplate {
|
||||
template <typename Shape_NKL>
|
||||
CUTE_HOST_DEVICE static constexpr auto bNbKL_to_offset(Shape_NKL shape_mkl) {
|
||||
// (BlocksN, BlocksK, L)
|
||||
auto blocks_shape =
|
||||
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
|
||||
[](auto x, auto y) { return x / y; });
|
||||
auto blocks_shape = cute::transform(shape_mkl,
|
||||
append(PPBlockShape_NK{}, _1{}),
|
||||
[](auto x, auto y) { return x / y; });
|
||||
auto stride = size(PPBlockShape_NK{});
|
||||
|
||||
// (BlocksN, BlocksK, L) -> (storage_idx)
|
||||
|
||||
Reference in New Issue
Block a user