[Others] remove template NUM_EXPERTS_PER_RANK in permute_x_fp8_kernel (#5620)

This commit is contained in:
周周周
2026-01-04 11:21:15 +08:00
committed by GitHub
parent f732d7d2ad
commit e3957a5ebc
3 changed files with 23 additions and 16 deletions
@@ -51,12 +51,15 @@ std::vector<paddle::Tensor> count_tokens_per_expert_func(
auto stream = topk_ids.stream();
using scalar_t = int64_t;
CUDA_CHECK(cudaGetLastError());
cuda_kernel<<<1, 1024, num_experts * sizeof(int32_t), stream>>>(
topk_ids.data<scalar_t>(),
token_nums_per_expert.data<int32_t>(),
token_nums_per_expert.data<int32_t>() + num_experts,
topk_ids_numel,
num_experts);
CUDA_CHECK(cudaGetLastError());
return {token_nums_per_expert};
}
@@ -797,7 +797,7 @@ PD_BUILD_STATIC_OP(ep_moe_expert_dispatch)
.SetInferShapeFn(PD_INFER_SHAPE(EPMoeExpertDispatchInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(EPMoeExpertDispatchInferDtype));
template <typename T, int NUM_EXPERTS_PER_RANK = 8>
template <typename T>
__global__ void permute_x_fp8_kernel(
const T* src_x,
const float* scale,
@@ -805,6 +805,7 @@ __global__ void permute_x_fp8_kernel(
const float* topk_weights,
const int* token_nums_per_expert,
const int* token_nums_per_expert_padded,
const int num_experts_per_rank,
const int moe_topk,
const int num_rows,
const int token_nums_this_rank,
@@ -824,11 +825,13 @@ __global__ void permute_x_fp8_kernel(
constexpr int vec_size = sizeof(int4) / sizeof(T);
constexpr int scale_vec_size = sizeof(int4) / sizeof(float);
__shared__ int write_idx; // cumsum start idx
__shared__ int token_nums_per_expert_cum[NUM_EXPERTS_PER_RANK];
// num_experts_per_rank size array.
extern __shared__ int token_nums_per_expert_cum[];
if (tid == 0) {
int sum_now = 0;
int sum_now_padded = 0;
for (int i = 0; i < NUM_EXPERTS_PER_RANK; i++) {
for (int i = 0; i < num_experts_per_rank; i++) {
sum_now += token_nums_per_expert[i];
sum_now_padded += token_nums_per_expert_padded[i];
token_nums_per_expert_cum[i] = sum_now_padded;
@@ -843,14 +846,14 @@ __global__ void permute_x_fp8_kernel(
const int hidden_size_scale = hidden_size / 128;
const int hidden_size_scale_int4 = hidden_size_scale / scale_vec_size;
const int token_nums_feed_to_ffn =
token_nums_per_expert_cum[NUM_EXPERTS_PER_RANK - 1];
token_nums_per_expert_cum[num_experts_per_rank - 1];
// prmt
for (int64_t s_token_idx = src_token_idx;
s_token_idx < token_nums_feed_to_ffn;
s_token_idx += gridDim.x) {
// the m_indices[s_token_idx] must be a value `i` in [0,
// NUM_EXPERTS_PER_RANK); here we parallelly find the `i` we want.
for (int i = threadIdx.x; i < NUM_EXPERTS_PER_RANK; i += blockDim.x) {
// num_experts_per_rank); here we parallelly find the `i` we want.
for (int i = threadIdx.x; i < num_experts_per_rank; i += blockDim.x) {
const int start_idx = i == 0 ? 0 : token_nums_per_expert_cum[i - 1];
const int end_idx = token_nums_per_expert_cum[i];
if (s_token_idx >= start_idx && s_token_idx < end_idx) {
@@ -883,7 +886,7 @@ __global__ void permute_x_fp8_kernel(
dst_weights[dst_token_idx] =
topk_weights[s_token_idx * moe_topk + expert_idx];
// m_indices[dst_token_idx] = expert_now; // not need?
dst_indices[s_token_idx * NUM_EXPERTS_PER_RANK + expert_now] =
dst_indices[s_token_idx * num_experts_per_rank + expert_now] =
expert_now;
// cp x
for (int64_t v_id = tid; v_id < hidden_size_int4; v_id += blockDim.x) {
@@ -932,17 +935,15 @@ void EPMoeDispatchFP8Kernel(const paddle::Tensor& input,
auto place = input.place();
// const int gridx = min(132 * 8, num_rows);
const int gridx = 132 * 8;
DISPATCH_NUM_EXPERTS_PER_RANK(
num_experts_per_rank,
NUM_EXPERTS_PER_RANK,
permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, NUM_EXPERTS_PER_RANK>
<<<gridx, 512, 0, stream>>>(
permute_x_fp8_kernel<phi::dtype::float8_e4m3fn>
<<<gridx, 512, num_experts_per_rank * sizeof(int32_t), stream>>>(
input.data<phi::dtype::float8_e4m3fn>(),
scale.data<float>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
token_nums_per_expert_padded.data<int>(),
num_experts_per_rank,
moe_topk,
num_rows,
token_nums_this_rank,
@@ -956,7 +957,8 @@ void EPMoeDispatchFP8Kernel(const paddle::Tensor& input,
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
token_nums_per_expert_padded_cumsum->data<int64_t>(),
m_indices->data<int>());)
m_indices->data<int>());
CUDA_CHECK(cudaGetLastError());
}
std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(