【Optim】Optimize grid dimensions using max_tokens_per_expert for MoE models (#6007)

* update w4afp8

* build.sh ok

* support cuda_graph

* fix

* add test

* fix max_tokens_per_expert

* >=70

* fix

* compute_max_tokens_from_prefix_sum in w4afp8

* compute_max_tokens uses cub
This commit is contained in:
lizexu123
2026-01-15 19:18:42 +08:00
committed by GitHub
parent b0fc9cadb5
commit 6619298b50
9 changed files with 514 additions and 29 deletions
@@ -41,7 +41,8 @@ void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(
const float * input_dequant_scale,
const int64_t *tokens,
const int64_t max_tokens,
cudaStream_t stream);
cudaStream_t stream,
const int64_t *max_tokens_per_expert = nullptr);
"""
gemm_template_cu_head = """
@@ -59,7 +60,8 @@ void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(
const float * input_dequant_scale,
const int64_t *tokens,
const int64_t max_tokens,
cudaStream_t stream) {{
cudaStream_t stream,
const int64_t *max_tokens_per_expert) {{
constexpr static int M = {M};
constexpr static int K = {K};
@@ -81,7 +83,7 @@ void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(
{cutlass_type}>;
run_gemm<cutlass::float_e4m3_t, {cutlass_type},
Kernel_traits, M, K, EXPERTS, TokenPackSize, kGroupSize>
(weight, input, out, weight_scale, input_dequant_scale, tokens, max_tokens, stream);
(weight, input, out, weight_scale, input_dequant_scale, tokens, max_tokens, stream,max_tokens_per_expert);
}}
"""
@@ -116,6 +118,10 @@ gemm_case = [
[7168, 3584, 7, 16384, 128], # num_max_dispatch_tokens_per_rank=256
[7168, 7168, 7, 20480, 128], # num_max_dispatch_tokens_per_rank=320
[7168, 3584, 7, 20480, 128], # num_max_dispatch_tokens_per_rank=320
[3072, 1536, 128, 0, 128],
[1536, 1536, 128, 0, 128],
[1536, 768, 128, 0, 128],
[768, 1536, 128, 0, 128],
]
dtype = ["BF16"]