// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

// De-permutes and combines per-expert FFN outputs back into per-token rows.
//
// For every token, each of its TopK routed rows is fetched from `ffn_out`
// (indexed as [expert][max_tokens_per_expert][hidden]), multiplied by the
// matching entry of `topk_weights`, staged in dynamic shared memory, and then
// the TopK staged rows are summed element-wise and written to
// `out[token_idx * hidden ...]`.
//
// Launch layout:
//   * Warp `threadIdx.x / 32` handles routed slot `wid` of the current token,
//     so the block must be launched with exactly 32 * TopK threads.
//   * Blocks walk tokens starting at `blockIdx.x` with a step of `gridDim.x`.
//   * Dynamic shared memory must hold TopK * hidden elements of T.
//
// Parameters:
//   out                        destination, written at token_idx * hidden.
//   ffn_out                    expert outputs, read at
//                              (expert * max_tokens_per_expert + slot) * hidden.
//   permute_indices_per_token  per (token, k): row slot inside the expert.
//   topk_idx                   per (token, k): expert id.
//   topk_weights               per (token, k): float scale applied to the row.
//   token_num                  number of tokens to produce.
//   num_vecs                   number of VecSize-wide vectors processed per
//                              row; only num_vecs * VecSize elements of each
//                              row are read and written.
//   hidden                     row stride (in elements) of ffn_out, out and
//                              the shared staging rows.
//   max_tokens_per_expert      second dimension of ffn_out.
//
// NOTE(review): the template parameter list was not legible in the reviewed
// text and is reconstructed from how T, VecSize and TopK are used below;
// confirm the parameter order against the launch sites.
// NOTE(review): vector Load/Store presumably require the addressed rows to be
// aligned to VecSize * sizeof(T) bytes -- confirm against helper.h.
template <typename T, int VecSize, int TopK>
__global__ void MoEDeepGEMMDePermuteKernel(
    T* out,
    const T* ffn_out,
    const int* permute_indices_per_token,
    const int64_t* topk_idx,
    const float* topk_weights,
    const int token_num,
    const int num_vecs,
    const int hidden,
    const int max_tokens_per_expert) {
  AlignedVector<T, VecSize> in_vec;
  AlignedVector<T, VecSize> acc_vec[TopK];

  const int bid = blockIdx.x;
  const int wid = threadIdx.x / 32;
  const int tid = threadIdx.x % 32;

  extern __shared__ char shm[];
  // TopK * hidden
  T* shm_hidden = reinterpret_cast<T*>(shm);

  for (int token_idx = bid; token_idx < token_num; token_idx += gridDim.x) {
    // Index of this warp's (token, k) routing entry. Computed in 64 bits so
    // large token counts cannot wrap.
    const int64_t route_idx = static_cast<int64_t>(token_idx) * TopK + wid;
    const int64_t src_expert_id = topk_idx[route_idx];
    const int64_t src_expert_token = permute_indices_per_token[route_idx];
    const float weight = topk_weights[route_idx];

    // Global element offsets are formed in int64_t: the product
    // expert * max_tokens_per_expert * hidden can exceed the 32-bit range.
    const T* src_row =
        ffn_out +
        (src_expert_id * max_tokens_per_expert + src_expert_token) * hidden;
    T* staged_row = shm_hidden + wid * hidden;

    // Phase 1: each warp scales its routed row and stages it in shared memory.
    for (int hidden_vec_id = tid; hidden_vec_id < num_vecs;
         hidden_vec_id += 32) {
      Load<T, VecSize>(src_row + hidden_vec_id * VecSize, &in_vec);
#pragma unroll
      for (int i = 0; i < VecSize; i++) {
        in_vec[i] *= weight;
      }
      Store<T, VecSize>(in_vec, staged_row + hidden_vec_id * VecSize);
    }
    // All TopK rows must be staged before any thread starts reducing them.
    __syncthreads();

    // Phase 2: the whole block sums the TopK staged rows column-wise.
    T* dst_row = out + static_cast<int64_t>(token_idx) * hidden;
    for (int hidden_vec_id = threadIdx.x; hidden_vec_id < num_vecs;
         hidden_vec_id += blockDim.x) {
#pragma unroll
      for (int topk_id = 0; topk_id < TopK; topk_id++) {
        Load<T, VecSize>(
            shm_hidden + topk_id * hidden + hidden_vec_id * VecSize,
            &acc_vec[topk_id]);
      }
#pragma unroll
      for (int i = 0; i < VecSize; i++) {
#pragma unroll
        for (int topk_id = 1; topk_id < TopK; topk_id++) {
          acc_vec[0][i] += acc_vec[topk_id][i];
        }
      }
      Store<T, VecSize>(acc_vec[0], dst_row + hidden_vec_id * VecSize);
    }
    // The staging buffer is reused for the next token handled by this block.
    // Without this barrier a warp could start overwriting its staged row
    // while other warps are still reading it in the reduction above. The loop
    // bounds depend only on blockIdx/gridDim, so every thread reaches it.
    __syncthreads();
  }
}

template std::vector MoEDeepGEMMDePermuteDispatch(
    const paddle::Tensor& ffn_out,  // [num_experts, max_tokens_per_expert, hidden]
    const paddle::Tensor& permute_indices_per_token,  // [token_num, topk}]
    const paddle::Tensor& topk_idx,
    const paddle::Tensor& topk_weights) {
  typedef PDTraits traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  const int token_num = permute_indices_per_token.shape()[0];
  const int max_tokens_per_expert = ffn_out.shape()[1];
  const int hidden = ffn_out.shape()[2];
  const int topk = permute_indices_per_token.shape()[1];

  auto place = ffn_out.place();
  auto stream = ffn_out.stream();
  auto out = GetEmptyTensor({token_num, hidden}, ffn_out.dtype(), place);

  constexpr int VecSize = 16 / sizeof(data_t);
  int blocks = 32 * topk;
  int grids = min(132 * 4, token_num);
  int num_vecs = hidden / VecSize;

  assert(blocks <= 1024);
  int dyn_smem_size = 0;

  switch (topk) {
    case 4:
      dyn_smem_size = topk * hidden * sizeof(DataType_);
      if (dyn_smem_size >= (48 << 10)) {
        cudaFuncSetAttribute(MoEDeepGEMMDePermuteKernel,
                             cudaFuncAttributeMaxDynamicSharedMemorySize,
                             dyn_smem_size);
      }
      MoEDeepGEMMDePermuteKernel <<>>(
          reinterpret_cast(out.data()),
          reinterpret_cast(ffn_out.data()),
          permute_indices_per_token.data(),
          topk_idx.data(),
          topk_weights.data(),
          token_num,
          num_vecs,
          hidden,
          max_tokens_per_expert);
      break;
    case 8:
      dyn_smem_size = topk * hidden * sizeof(DataType_);
      if (dyn_smem_size >= (48 << 10)) {
        cudaFuncSetAttribute(MoEDeepGEMMDePermuteKernel,
                             cudaFuncAttributeMaxDynamicSharedMemorySize,
                             dyn_smem_size);
      }
      MoEDeepGEMMDePermuteKernel <<