// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): in this copy of the file every `template <...>` parameter
// list and most template argument lists appear to have been stripped (e.g.
// `template struct SharedStorage`, `Shape, Int, Int>`, `Copy_Atom;`), so the
// code below does not compile as shown. Restore them from the upstream
// source before relying on it. Comments below that depend on the missing
// arguments say "elided" / "presumably".

#include "cute/algorithm/copy.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"

// NOTE(review): if this file is used as a header, this namespace-scope
// using-directive injects all of `cute` into every includer.
using namespace cute;

// Storage block that groups the kernel's operand, scale, output and
// pipeline buffers. It is presumably placed in shared memory by the kernel
// (members are named smem_*); confirm at the use site.
//
// The operand buffers (smem_a, smem_b, smem_scale) and the output buffer
// (smem_c) are members of one union, so they alias each other. Writing
// smem_c overwrites the A/B/scale contents, so every read of those buffers
// must be finished (and synchronized) before smem_c is written.
//
// NOTE(review): the unnamed `struct { ... };` members are anonymous
// structs. These are a compiler extension (GCC/Clang/MSVC/NVCC), not
// standard C++ (only anonymous unions are standard), and -pedantic builds
// warn about them.
template struct SharedStorage {
  union {
    // Mainloop operand buffers (element types and extents elided here).
    struct {
      cute::array_aligned> smem_a;
      cute::array_aligned> smem_b;
      cute::array_aligned> smem_scale;
    };
    // Output staging buffer; shares storage with the operand buffers above.
    cute::array_aligned> smem_c;
  };
  // The pipeline's storage sits outside the union, so smem_c never aliases
  // it.
  struct {
    typename cutlass::PipelineTmaAsync::SharedStorage pipeline;
  };
};

// Compile-time configuration for a kernel with three GEMMs. There is a
// single A shared-memory layout (SmemLayoutA) and one B layout per GEMM,
// with N extents kBlockN1 / kBlockN2 / kBlockN3.
//
// The traits bundle:
//   - element types and thread partitioning,
//   - tile and cluster shapes,
//   - one tiled MMA per GEMM,
//   - shared-memory layouts and copy atoms,
//   - the TMA pipeline types,
//   - the tiled copy used for the output.
template struct Kernel_traits {
  // Operand and output element types come straight from the template
  // parameters elem_type / OutputType.
  using Element = elem_type;
  using ElementOutput = OutputType;
  // Accumulator type selected at compile time. The condition and both
  // alternatives are elided in this copy.
  using ElementAccum = typename std::conditional_t;
  // Requires an 8-bit element type (argument elided; presumably Element).
  static_assert(cutlass::sizeof_bits_v == 8);

  // Thread partitioning. One warp group (cutlass::NumThreadsPerWarpGroup,
  // i.e. 128 threads) is reserved as producer threads; all remaining threads
  // are MMA threads. With the allowed 12 or 16 warps this gives 384 or 512
  // threads in total, of which 256 or 384 are MMA threads.
  static constexpr int kNWarps = kNWarps_;
  static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
  static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
  static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;
  static_assert(kNWarps_ == 12 || kNWarps_ == 16);

  // Tile extents: a single kBlockM and kBlockK, and a separate N extent for
  // each of the three GEMMs.
  static constexpr int kBlockM = kBlockM_;
  static constexpr int kBlockN1 = kBlockN1_;
  static constexpr int kBlockN2 = kBlockN2_;
  static constexpr int kBlockN3 = kBlockN3_;
  static constexpr int kBlockK = kBlockK_;
  // Problem-level constants forwarded from the template parameters. Their
  // exact meaning is defined by the kernel that consumes these traits.
  static constexpr int kTiles = kTiles_;
  static constexpr int TokenPackSize = TokenPackSize_;
  static constexpr int M = M_;
  static constexpr int K = K_;
  // Presumably the number of K elements that share one weight scale —
  // TODO confirm.
  static constexpr int WeightScaleGroup = WeightScaleGroup_;

  // (M, N, K) tile shapes for GEMMs 1-3. Arguments are elided; presumably
  // Int<kBlockM>, Int<kBlockN{1,2,3}>, Int<kBlockK>.
  using TileShape_MNK1 = Shape, Int, Int>;
  using TileShape_MNK2 = Shape, Int, Int>;
  using TileShape_MNK3 = Shape, Int, Int>;

  // The thread-block cluster extends along M only (N and K extents are _1).
  static constexpr int kClusterM = kClusterM_;
  using ClusterShape_MNK = Shape, _1, _1>;

  // Stage count; at least two stages are required.
  static constexpr int kStages = kStages_;
  static_assert(kStages > 1);

  // MMA atoms are replicated along M only (N and K modes are _1). The M
  // extent is elided in this copy.
  using AtomLayoutMNK = Layout, _1, _1>>;

  // One tiled MMA per GEMM, each built from cute::GMMA::rs_op_selector
  // (GMMA with the A operand sourced from registers and B from shared
  // memory). Selector arguments are elided.
  using TiledMma1 = decltype(cute::make_tiled_mma(
      cute::GMMA::rs_op_selector(), AtomLayoutMNK{}));
  using TiledMma2 = decltype(cute::make_tiled_mma(
      cute::GMMA::rs_op_selector(), AtomLayoutMNK{}));
  using TiledMma3 = decltype(cute::make_tiled_mma(
      cute::GMMA::rs_op_selector(), AtomLayoutMNK{}));

  // Shared-memory layouts. Each is K-major, built from a CUTLASS
  // rs_smem_selector atom and tiled to a rank-3 shape with tile_to_shape.
  // The third mode is elided; presumably it is the stage count.
  using SmemLayoutAtomA =
      decltype(cutlass::gemm::collective::detail::rs_smem_selector<
               GMMA::Major::K, Element, Int, Int>());
  using SmemLayoutA = decltype(tile_to_shape(
      SmemLayoutAtomA{}, make_shape(Int{}, Int{}, Int{})));

  // B layouts for GEMMs 1-3: (N, K) from the matching tile shape, plus the
  // elided third mode.
  using SmemLayoutAtomB1 =
      decltype(cutlass::gemm::collective::detail::rs_smem_selector<
               GMMA::Major::K, Element,
               decltype(cute::get<1>(TileShape_MNK1{})),
               decltype(cute::get<2>(TileShape_MNK1{}))>());
  using SmemLayoutB1 = decltype(tile_to_shape(
      SmemLayoutAtomB1{},
      make_shape(shape<1>(TileShape_MNK1{}), shape<2>(TileShape_MNK1{}),
                 Int{})));
  using SmemLayoutAtomB2 =
      decltype(cutlass::gemm::collective::detail::rs_smem_selector<
               GMMA::Major::K, Element,
               decltype(cute::get<1>(TileShape_MNK2{})),
               decltype(cute::get<2>(TileShape_MNK2{}))>());
  using SmemLayoutB2 = decltype(tile_to_shape(
      SmemLayoutAtomB2{},
      make_shape(shape<1>(TileShape_MNK2{}), shape<2>(TileShape_MNK2{}),
                 Int{})));
  using SmemLayoutAtomB3 =
      decltype(cutlass::gemm::collective::detail::rs_smem_selector<
               GMMA::Major::K, Element,
               decltype(cute::get<1>(TileShape_MNK3{})),
               decltype(cute::get<2>(TileShape_MNK3{}))>());
  using SmemLayoutB3 = decltype(tile_to_shape(
      SmemLayoutAtomB3{},
      make_shape(shape<1>(TileShape_MNK3{}), shape<2>(TileShape_MNK3{}),
                 Int{})));

  // Output (C) staging layout, sized from GEMM 1's (M, N) tile only.
  // NOTE(review): TiledCopyC below is likewise derived from kBlockN1. If
  // GEMM 2/3 results are also staged or copied through these, their N
  // extents must not exceed kBlockN1 — confirm.
  using SmemLayoutAtomC =
      decltype(cutlass::gemm::collective::detail::rs_smem_selector<
               GMMA::Major::K, ElementOutput,
               decltype(cute::get<0>(TileShape_MNK1{})),
               decltype(cute::get<1>(TileShape_MNK1{}))>());
  using SmemLayoutC =
      decltype(tile_to_shape(SmemLayoutAtomC{}, select<0, 1>(TileShape_MNK1{})));

  // Shared-memory copy atoms and the scale layout (arguments elided).
  using SmemCopyAtomAB = Copy_Atom;
  using SmemCopyAtomC = Copy_Atom;
  using SmemLayoutScale = Layout, Int>>;

  // Instantiation of the namespace-scope SharedStorage template above
  // (arguments elided).
  // NOTE(review): a member alias that reuses the unqualified name of the
  // template it refers to can be rejected by GCC ("changes meaning of
  // 'SharedStorage'"). Consider writing `::SharedStorage<...>` on the
  // right-hand side.
  using SharedStorage = SharedStorage;

  // Pipeline types for the TMA-fed pipeline (stage argument elided).
  using MainloopPipeline = typename cutlass::PipelineTmaAsync;
  using PipelineState = typename cutlass::PipelineState;

  // Output tiled copy geometry:
  //   - each thread moves kNumVecElem contiguous elements, i.e. one 128-bit
  //     vector (assuming the elided sizeof_bits_v argument is OutputType);
  //   - kNumThreadsPerRow threads cover one kBlockN1-wide row.
  // If kBlockN1 < kNumVecElem, kNumThreadsPerRow is 0 and kNumRows below
  // fails to compile (division by zero in a constant expression).
  static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v);
  static constexpr int kNumThreadsPerRow = kBlockN1 / kNumVecElem;
  // NOTE(review): this divisibility check is disabled, and kNumRows uses
  // integer division. If NumMmaThreads is not a multiple of
  // kNumThreadsPerRow, the thread layout (presumably kNumRows x
  // kNumThreadsPerRow) covers fewer than NumMmaThreads threads. Confirm
  // that every supported configuration divides evenly.
  // static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
  static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;

  // Tiled copy for the output. It is built from:
  //   - a copy atom over OutputType (first atom argument elided);
  //   - a row-major (LayoutRight) thread layout;
  //   - a 1 x N per-thread value layout.
  // Extents are elided; presumably (kNumRows, kNumThreadsPerRow) and
  // (1, kNumVecElem).
  using TiledCopyCAtom = cute::Copy_Atom, OutputType>;
  using TiledCopyCThrLayout = decltype(cute::make_layout(
      cute::make_shape(Int{}, Int{}), LayoutRight{}));
  using TiledCopyCValLayout = decltype(cute::make_layout(
      cute::make_shape(_1{}, Int{}), LayoutRight{}));
  using TiledCopyC = decltype(make_tiled_copy(
      TiledCopyCAtom{},
      TiledCopyCThrLayout{},  // Thr layout
      TiledCopyCValLayout{}   // Val layout
      ));
};