mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
【Optim】Optimize grid dimensions using max_tokens_per_expert for MoE models (#6007)
* update w4afp8
* build.sh ok
* support cuda_graph
* fix
* add test
* fix max_tokens_per_expert
* >=70
* fix
* compute_max_tokens_from_prefix_sum in w4afp8
* compute_max_tokens use cub
This commit is contained in:
@@ -41,7 +41,8 @@ void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(
|
||||
const float * input_dequant_scale,
|
||||
const int64_t *tokens,
|
||||
const int64_t max_tokens,
|
||||
cudaStream_t stream);
|
||||
cudaStream_t stream,
|
||||
const int64_t *max_tokens_per_expert = nullptr);
|
||||
"""
|
||||
|
||||
gemm_template_cu_head = """
|
||||
@@ -59,7 +60,8 @@ void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(
|
||||
const float * input_dequant_scale,
|
||||
const int64_t *tokens,
|
||||
const int64_t max_tokens,
|
||||
cudaStream_t stream) {{
|
||||
cudaStream_t stream,
|
||||
const int64_t *max_tokens_per_expert) {{
|
||||
|
||||
constexpr static int M = {M};
|
||||
constexpr static int K = {K};
|
||||
@@ -81,7 +83,7 @@ void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(
|
||||
{cutlass_type}>;
|
||||
run_gemm<cutlass::float_e4m3_t, {cutlass_type},
|
||||
Kernel_traits, M, K, EXPERTS, TokenPackSize, kGroupSize>
|
||||
(weight, input, out, weight_scale, input_dequant_scale, tokens, max_tokens, stream);
|
||||
(weight, input, out, weight_scale, input_dequant_scale, tokens, max_tokens, stream,max_tokens_per_expert);
|
||||
}}
|
||||
"""
|
||||
|
||||
@@ -116,6 +118,10 @@ gemm_case = [
|
||||
[7168, 3584, 7, 16384, 128], # num_max_dispatch_tokens_per_rank=256
|
||||
[7168, 7168, 7, 20480, 128], # num_max_dispatch_tokens_per_rank=320
|
||||
[7168, 3584, 7, 20480, 128], # num_max_dispatch_tokens_per_rank=320
|
||||
[3072, 1536, 128, 0, 128],
|
||||
[1536, 1536, 128, 0, 128],
|
||||
[1536, 768, 128, 0, 128],
|
||||
[768, 1536, 128, 0, 128],
|
||||
]
|
||||
|
||||
dtype = ["BF16"]
|
||||
|
||||
Reference in New Issue
Block a user