mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 17:11:21 +08:00
[OP]Unify MoE op with moe_permute path for bf16 GLM (#7164)
This commit is contained in:
@@ -28,7 +28,11 @@ from ..utils import get_tensor, group_wise_int4_weight_quantize, pack, rotate_mo
|
||||
from .fused_moe_backend_base import UnquantizedFusedMoEMethod
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
count_tokens_per_expert_func,
|
||||
moe_expert_dispatch,
|
||||
moe_expert_reduce,
|
||||
)
|
||||
|
||||
try:
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
@@ -126,6 +130,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||
# 1. Select topk experts and weights
|
||||
topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
|
||||
# 2. EP Dispatch
|
||||
dispatch_kwargs = {"expert_alignment": 128} if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE else {}
|
||||
(
|
||||
recv_x,
|
||||
recv_topk_idx,
|
||||
@@ -133,7 +138,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||
recv_num_tokens_per_expert_list,
|
||||
handle,
|
||||
event,
|
||||
) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights)
|
||||
) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights, **dispatch_kwargs)
|
||||
|
||||
if topk_ids_hookfunc is not None:
|
||||
topk_ids_hookfunc(topk_ids=topk_idx)
|
||||
@@ -146,54 +151,91 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||
# 3. Compute ffn
|
||||
if token_all_num > 0:
|
||||
logger.debug(f"token_all_num {token_all_num}")
|
||||
(
|
||||
permute_input,
|
||||
permute_indices_per_token,
|
||||
recv_num_tokens_per_expert_list_cumsum,
|
||||
dst_weights,
|
||||
dst_indices,
|
||||
cumsum_idx_gpu,
|
||||
expert_idx_per_token,
|
||||
dequant_scale,
|
||||
) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch(
|
||||
recv_x,
|
||||
recv_topk_idx,
|
||||
recv_topk_weights,
|
||||
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
|
||||
recv_num_tokens_per_expert_list,
|
||||
token_all_num,
|
||||
self.moe_quant_type,
|
||||
)
|
||||
if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
|
||||
# only w4a8 and w4afp8 need expert_idx_per_token
|
||||
# Other need not this tensor, so we make it None.
|
||||
expert_idx_per_token = None
|
||||
|
||||
if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16":
|
||||
# --- moe_permute / moe_unpermute path ---
|
||||
recv_topk_idx_i32 = recv_topk_idx.astype(paddle.int32)
|
||||
(permute_input, permute_indices_per_token, dst_weights, _scale_out) = paddle.nn.functional.moe_permute(
|
||||
hidden_states=recv_x,
|
||||
scale=None,
|
||||
expert_routemap_topk=recv_topk_idx_i32,
|
||||
expert_prob_topk=recv_topk_weights,
|
||||
num_experts=layer.num_local_experts,
|
||||
tokens_per_expert=[],
|
||||
padding_alignment=128,
|
||||
override_buffer_size=token_all_num,
|
||||
)
|
||||
|
||||
token_nums_per_expert_cumsum = count_tokens_per_expert_func(
|
||||
recv_topk_idx, layer.num_local_experts, True
|
||||
)[2].cast(paddle.int64)
|
||||
ffn_out = self.compute_ffn(
|
||||
layer,
|
||||
permute_input,
|
||||
token_nums_per_expert_cumsum,
|
||||
None,
|
||||
False,
|
||||
-1,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
tmp_ffn_out, _out_probs = paddle.nn.functional.moe_unpermute(
|
||||
hidden_states_unzipped=ffn_out,
|
||||
zipped_expertwise_rowmap=permute_indices_per_token,
|
||||
expert_routemap_topk=recv_topk_idx_i32,
|
||||
token_prob_unzipped=dst_weights,
|
||||
total_zipped_tokens=recv_x.shape[0],
|
||||
num_experts=layer.num_local_experts,
|
||||
using_weighted_combine=True,
|
||||
)
|
||||
else:
|
||||
expert_idx_per_token = expert_idx_per_token.cast("int64")
|
||||
# --- original ep_moe_expert_dispatch / combine path ---
|
||||
(
|
||||
permute_input,
|
||||
permute_indices_per_token,
|
||||
recv_num_tokens_per_expert_list_cumsum,
|
||||
dst_weights,
|
||||
dst_indices,
|
||||
cumsum_idx_gpu,
|
||||
expert_idx_per_token,
|
||||
dequant_scale,
|
||||
) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch(
|
||||
recv_x,
|
||||
recv_topk_idx,
|
||||
recv_topk_weights,
|
||||
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
|
||||
recv_num_tokens_per_expert_list,
|
||||
token_all_num,
|
||||
self.moe_quant_type,
|
||||
)
|
||||
if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
|
||||
expert_idx_per_token = None
|
||||
else:
|
||||
expert_idx_per_token = expert_idx_per_token.cast("int64")
|
||||
|
||||
if hasattr(layer, "up_gate_proj_in_scale"):
|
||||
dequant_scale = None
|
||||
if hasattr(layer, "up_gate_proj_in_scale"):
|
||||
dequant_scale = None
|
||||
|
||||
ffn_out = self.compute_ffn(
|
||||
layer,
|
||||
permute_input,
|
||||
recv_num_tokens_per_expert_list_cumsum,
|
||||
expert_idx_per_token,
|
||||
False,
|
||||
-1,
|
||||
dequant_scale,
|
||||
)
|
||||
ffn_out = self.compute_ffn(
|
||||
layer,
|
||||
permute_input,
|
||||
recv_num_tokens_per_expert_list_cumsum,
|
||||
expert_idx_per_token,
|
||||
False,
|
||||
-1,
|
||||
dequant_scale,
|
||||
)
|
||||
|
||||
# prmt back per rank
|
||||
tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
|
||||
ffn_out,
|
||||
dst_weights,
|
||||
permute_indices_per_token,
|
||||
dst_indices,
|
||||
None, # down_proj_bias,
|
||||
False, # norm_topk_prob
|
||||
1.0,
|
||||
)
|
||||
tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
|
||||
ffn_out,
|
||||
dst_weights,
|
||||
permute_indices_per_token,
|
||||
dst_indices,
|
||||
None, # down_proj_bias,
|
||||
False, # norm_topk_prob
|
||||
1.0,
|
||||
)
|
||||
else:
|
||||
tmp_ffn_out = recv_x
|
||||
|
||||
@@ -276,6 +318,69 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||
"""
|
||||
gate_out = gate(x)
|
||||
gate_out = gate_out.cast("float32")
|
||||
if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16":
|
||||
if layer.topk_method == "noaux_tc":
|
||||
gate_out, topk_weights, topk_idx = get_moe_scores(
|
||||
gate_out,
|
||||
layer.n_group,
|
||||
layer.topk_group,
|
||||
layer.top_k,
|
||||
layer.routed_scaling_factor,
|
||||
layer.gate_correction_bias,
|
||||
getattr(layer, "renormalize", True),
|
||||
)
|
||||
else:
|
||||
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
layer.top_k,
|
||||
True, # apply_norm_weight
|
||||
False,
|
||||
)
|
||||
topk_idx_i32 = topk_idx.astype(paddle.int32)
|
||||
override_buffer_size = x.shape[0] * layer.top_k + layer.num_experts * (128 - 1)
|
||||
(permute_input, permute_indices_per_token, dst_weights, _scale_out) = ( # zipped_expertwise_rowmap
|
||||
paddle.nn.functional.moe_permute(
|
||||
hidden_states=x,
|
||||
scale=None,
|
||||
expert_routemap_topk=topk_idx_i32,
|
||||
expert_prob_topk=topk_weights,
|
||||
num_experts=layer.num_experts,
|
||||
tokens_per_expert=[],
|
||||
padding_alignment=128,
|
||||
override_buffer_size=override_buffer_size,
|
||||
)
|
||||
)
|
||||
|
||||
# Row 2 of count_tokens_per_expert_func is the prefix sum token_nums_per_expert.
|
||||
token_nums_per_expert_cumsum = count_tokens_per_expert_func(topk_idx, layer.num_experts, True)[2].cast(
|
||||
paddle.int64
|
||||
)
|
||||
if topk_ids_hookfunc is not None:
|
||||
topk_ids_hookfunc(topk_ids=topk_idx)
|
||||
|
||||
ffn_out = self.compute_ffn(
|
||||
layer,
|
||||
permute_input,
|
||||
token_nums_per_expert_cumsum,
|
||||
None, # expert_idx_per_token not needed for w16a16 without bias
|
||||
False,
|
||||
-1,
|
||||
None, # dequant_scale
|
||||
None, # max_tokens_per_expert
|
||||
)
|
||||
|
||||
fused_moe_out, _out_probs = paddle.nn.functional.moe_unpermute(
|
||||
hidden_states_unzipped=ffn_out,
|
||||
zipped_expertwise_rowmap=permute_indices_per_token,
|
||||
expert_routemap_topk=topk_idx_i32,
|
||||
token_prob_unzipped=dst_weights,
|
||||
total_zipped_tokens=x.shape[0],
|
||||
num_experts=layer.num_experts,
|
||||
using_weighted_combine=True,
|
||||
)
|
||||
return fused_moe_out
|
||||
|
||||
if layer.topk_method == "noaux_tc":
|
||||
gate_out, topk_weights, topk_idx = get_moe_scores(
|
||||
gate_out,
|
||||
@@ -287,6 +392,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||
getattr(layer, "renormalize", True),
|
||||
topk_reduce_func=getattr(layer, "topk_reduce_func", None),
|
||||
)
|
||||
|
||||
(
|
||||
permute_input,
|
||||
token_nums_per_expert,
|
||||
@@ -341,7 +447,6 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||
expert_idx_per_token = None
|
||||
else:
|
||||
expert_idx_per_token = expert_idx_per_token.cast("int64")
|
||||
|
||||
ffn_out = self.compute_ffn(
|
||||
layer,
|
||||
permute_input,
|
||||
@@ -363,7 +468,6 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||
norm_topk_prob=False if layer.topk_method == "noaux_tc" else True,
|
||||
routed_scaling_factor=1.0,
|
||||
)
|
||||
|
||||
return fused_moe_out
|
||||
|
||||
|
||||
|
||||
@@ -521,7 +521,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
|
||||
)
|
||||
else:
|
||||
token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
|
||||
token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts, False)
|
||||
(
|
||||
permute_input,
|
||||
permute_scale,
|
||||
@@ -805,7 +805,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
|
||||
)
|
||||
else:
|
||||
tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts)
|
||||
tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts, False)
|
||||
(
|
||||
permute_input,
|
||||
permute_scale,
|
||||
|
||||
Reference in New Issue
Block a user