[OP]Unify MoE op with moe_permute path for bf16 GLM (#7164)

This commit is contained in:
fxyfxy777
2026-04-09 16:17:56 +08:00
committed by GitHub
parent 33682c6749
commit 39ff38aba1
5 changed files with 444 additions and 69 deletions
@@ -28,7 +28,11 @@ from ..utils import get_tensor, group_wise_int4_weight_quantize, pack, rotate_mo
from .fused_moe_backend_base import UnquantizedFusedMoEMethod
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce
from fastdeploy.model_executor.ops.gpu import (
count_tokens_per_expert_func,
moe_expert_dispatch,
moe_expert_reduce,
)
try:
from fastdeploy.model_executor.ops.gpu import (
@@ -126,6 +130,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
# 1. Select topk experts and weights
topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
# 2. EP Dispatch
dispatch_kwargs = {"expert_alignment": 128} if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE else {}
(
recv_x,
recv_topk_idx,
@@ -133,7 +138,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
recv_num_tokens_per_expert_list,
handle,
event,
) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights)
) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights, **dispatch_kwargs)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_idx)
@@ -146,54 +151,91 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
# 3. Compute ffn
if token_all_num > 0:
logger.debug(f"token_all_num {token_all_num}")
(
permute_input,
permute_indices_per_token,
recv_num_tokens_per_expert_list_cumsum,
dst_weights,
dst_indices,
cumsum_idx_gpu,
expert_idx_per_token,
dequant_scale,
) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch(
recv_x,
recv_topk_idx,
recv_topk_weights,
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
recv_num_tokens_per_expert_list,
token_all_num,
self.moe_quant_type,
)
if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
# only w4a8 and w4afp8 need expert_idx_per_token
# Other need not this tensor, so we make it None.
expert_idx_per_token = None
if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16":
# --- moe_permute / moe_unpermute path ---
recv_topk_idx_i32 = recv_topk_idx.astype(paddle.int32)
(permute_input, permute_indices_per_token, dst_weights, _scale_out) = paddle.nn.functional.moe_permute(
hidden_states=recv_x,
scale=None,
expert_routemap_topk=recv_topk_idx_i32,
expert_prob_topk=recv_topk_weights,
num_experts=layer.num_local_experts,
tokens_per_expert=[],
padding_alignment=128,
override_buffer_size=token_all_num,
)
token_nums_per_expert_cumsum = count_tokens_per_expert_func(
recv_topk_idx, layer.num_local_experts, True
)[2].cast(paddle.int64)
ffn_out = self.compute_ffn(
layer,
permute_input,
token_nums_per_expert_cumsum,
None,
False,
-1,
None,
None,
)
tmp_ffn_out, _out_probs = paddle.nn.functional.moe_unpermute(
hidden_states_unzipped=ffn_out,
zipped_expertwise_rowmap=permute_indices_per_token,
expert_routemap_topk=recv_topk_idx_i32,
token_prob_unzipped=dst_weights,
total_zipped_tokens=recv_x.shape[0],
num_experts=layer.num_local_experts,
using_weighted_combine=True,
)
else:
expert_idx_per_token = expert_idx_per_token.cast("int64")
# --- original ep_moe_expert_dispatch / combine path ---
(
permute_input,
permute_indices_per_token,
recv_num_tokens_per_expert_list_cumsum,
dst_weights,
dst_indices,
cumsum_idx_gpu,
expert_idx_per_token,
dequant_scale,
) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch(
recv_x,
recv_topk_idx,
recv_topk_weights,
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
recv_num_tokens_per_expert_list,
token_all_num,
self.moe_quant_type,
)
if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
expert_idx_per_token = None
else:
expert_idx_per_token = expert_idx_per_token.cast("int64")
if hasattr(layer, "up_gate_proj_in_scale"):
dequant_scale = None
if hasattr(layer, "up_gate_proj_in_scale"):
dequant_scale = None
ffn_out = self.compute_ffn(
layer,
permute_input,
recv_num_tokens_per_expert_list_cumsum,
expert_idx_per_token,
False,
-1,
dequant_scale,
)
ffn_out = self.compute_ffn(
layer,
permute_input,
recv_num_tokens_per_expert_list_cumsum,
expert_idx_per_token,
False,
-1,
dequant_scale,
)
# prmt back per rank
tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
ffn_out,
dst_weights,
permute_indices_per_token,
dst_indices,
None, # down_proj_bias,
False, # norm_topk_prob
1.0,
)
tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
ffn_out,
dst_weights,
permute_indices_per_token,
dst_indices,
None, # down_proj_bias,
False, # norm_topk_prob
1.0,
)
else:
tmp_ffn_out = recv_x
@@ -276,6 +318,69 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
"""
gate_out = gate(x)
gate_out = gate_out.cast("float32")
if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16":
if layer.topk_method == "noaux_tc":
gate_out, topk_weights, topk_idx = get_moe_scores(
gate_out,
layer.n_group,
layer.topk_group,
layer.top_k,
layer.routed_scaling_factor,
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
)
else:
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
layer.gate_correction_bias,
layer.top_k,
True, # apply_norm_weight
False,
)
topk_idx_i32 = topk_idx.astype(paddle.int32)
override_buffer_size = x.shape[0] * layer.top_k + layer.num_experts * (128 - 1)
(permute_input, permute_indices_per_token, dst_weights, _scale_out) = ( # zipped_expertwise_rowmap
paddle.nn.functional.moe_permute(
hidden_states=x,
scale=None,
expert_routemap_topk=topk_idx_i32,
expert_prob_topk=topk_weights,
num_experts=layer.num_experts,
tokens_per_expert=[],
padding_alignment=128,
override_buffer_size=override_buffer_size,
)
)
# Row 2 of count_tokens_per_expert_func is the prefix sum token_nums_per_expert.
token_nums_per_expert_cumsum = count_tokens_per_expert_func(topk_idx, layer.num_experts, True)[2].cast(
paddle.int64
)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_idx)
ffn_out = self.compute_ffn(
layer,
permute_input,
token_nums_per_expert_cumsum,
None, # expert_idx_per_token not needed for w16a16 without bias
False,
-1,
None, # dequant_scale
None, # max_tokens_per_expert
)
fused_moe_out, _out_probs = paddle.nn.functional.moe_unpermute(
hidden_states_unzipped=ffn_out,
zipped_expertwise_rowmap=permute_indices_per_token,
expert_routemap_topk=topk_idx_i32,
token_prob_unzipped=dst_weights,
total_zipped_tokens=x.shape[0],
num_experts=layer.num_experts,
using_weighted_combine=True,
)
return fused_moe_out
if layer.topk_method == "noaux_tc":
gate_out, topk_weights, topk_idx = get_moe_scores(
gate_out,
@@ -287,6 +392,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
getattr(layer, "renormalize", True),
topk_reduce_func=getattr(layer, "topk_reduce_func", None),
)
(
permute_input,
token_nums_per_expert,
@@ -341,7 +447,6 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
expert_idx_per_token = None
else:
expert_idx_per_token = expert_idx_per_token.cast("int64")
ffn_out = self.compute_ffn(
layer,
permute_input,
@@ -363,7 +468,6 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
norm_topk_prob=False if layer.topk_method == "noaux_tc" else True,
routed_scaling_factor=1.0,
)
return fused_moe_out
@@ -521,7 +521,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
)
else:
token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts, False)
(
permute_input,
permute_scale,
@@ -805,7 +805,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
)
else:
tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts)
tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts, False)
(
permute_input,
permute_scale,