[Others] remove template NUM_EXPERTS_PER_RANK in permute_x_fp8_kernel (#5620)

This commit is contained in:
周周周
2026-01-04 11:21:15 +08:00
committed by GitHub
parent f732d7d2ad
commit e3957a5ebc
3 changed files with 23 additions and 16 deletions
@@ -51,12 +51,15 @@ std::vector<paddle::Tensor> count_tokens_per_expert_func(
auto stream = topk_ids.stream();
using scalar_t = int64_t;
CUDA_CHECK(cudaGetLastError());
cuda_kernel<<<1, 1024, num_experts * sizeof(int32_t), stream>>>(
topk_ids.data<scalar_t>(),
token_nums_per_expert.data<int32_t>(),
token_nums_per_expert.data<int32_t>() + num_experts,
topk_ids_numel,
num_experts);
CUDA_CHECK(cudaGetLastError());
return {token_nums_per_expert};
}