[Feature] use phi permute/unpermute & rm swiglu (#6361)

* tp: text output works correctly

* eb5 mini on B-series GPUs: text output works correctly

* eb5mini ep on B-series GPUs: text output works correctly

* default use phi moe op

* stash

* tp on H-series GPUs works correctly

* ep ok

* rm debug

* rm debug tool

* rm del ffn_out

* rm swiglu

* add envs to swiglu

* merge dev

* fix ci baseline

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix ci baseline 2

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
fxyfxy777
2026-03-12 17:01:57 +08:00
committed by GitHub
parent a3d7979711
commit 250ce40b40
18 changed files with 187 additions and 112 deletions
@@ -151,7 +151,8 @@ def m_grouped_fp8_gemm_nt_contiguous_custom_python_op(
(permute_input.shape[0], layer_added_weight_attrs_0.shape[1]),
dtype=paddle.bfloat16,
)
if disable_ue8m0_cast:
# if disable_ue8m0_cast:
if permute_scale.strides[0] != 1:
permute_scale = permute_scale.transpose([1, 0]).contiguous()
permute_scale = permute_scale.transpose([1, 0])
# disable_ue8m0_cast is False for SM100
@@ -487,31 +488,52 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
elif token_all_num > 0:
logger.debug(f"token_all_num {token_all_num}")
token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE:
recv_topk_idx = recv_topk_idx.astype(paddle.int32)
(
permute_input,
permute_indices_per_token, # == zipped_expertwise_rowmap
dst_weights,
permute_scale,
m_indices,
) = paddle.nn.functional.moe_permute(
hidden_states=recv_x,
scale=recv_x_scale,
expert_routemap_topk=recv_topk_idx,
expert_prob_topk=recv_topk_weights,
num_experts=layer.num_local_experts,
tokens_per_expert=[],
padding_alignment=128,
return_expert_indices=True,
override_buffer_size=token_all_num,
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
)
else:
token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
(
permute_input,
permute_scale,
permute_indices_per_token,
recv_num_tokens_per_expert_list_cumsum,
recv_num_tokens_per_expert_list_padded_cumsum,
dst_weights,
dst_indices,
cumsum_idx_gpu,
m_indices,
) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch_fp8(
recv_x,
recv_x_scale,
recv_topk_idx,
recv_topk_weights,
token_nums_this_rank[0],
token_nums_this_rank[1],
True, # use_in_ep
token_all_num,
)
(
permute_input,
permute_scale,
permute_indices_per_token,
_,
_,
dst_weights,
dst_indices,
_,
m_indices,
) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch_fp8(
recv_x_value,
recv_x_scale,
recv_topk_idx,
recv_topk_weights,
token_nums_this_rank[0],
token_nums_this_rank[1],
True, # use_in_ep
token_all_num,
)
assert permute_input.shape[0] == token_all_num
if not self.quant_config.deepgemm_scale_ue8m0:
if permute_scale.strides[0] != 1:
permute_scale = permute_scale.transpose([1, 0]).contiguous().transpose([1, 0])
# up_gate_proj
@@ -553,20 +575,30 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
ffn_out,
m_indices,
)
if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE:
tmp_ffn_out, out_probs = paddle.nn.functional.moe_unpermute(
hidden_states_unzipped=ffn_out,
zipped_expertwise_rowmap=permute_indices_per_token,
expert_routemap_topk=recv_topk_idx,
token_prob_unzipped=dst_weights,
total_zipped_tokens=recv_x.shape[0],
num_experts=layer.num_local_experts,
using_weighted_combine=True,
)
# prmt back per rank
tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
ffn_out,
dst_weights,
permute_indices_per_token,
dst_indices,
None, # down_proj_bias
False, # norm_topk_prob
1.0,
)
else:
# prmt back per rank
tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
ffn_out,
dst_weights,
permute_indices_per_token,
dst_indices,
None, # down_proj_bias
False, # norm_topk_prob
1.0,
)
else:
tmp_ffn_out = paddle.empty([0, hidden_size], paddle.bfloat16)
# 5. EP combine
event = deep_ep.Buffer.capture()
if self.ep_prefill_runner.num_worst_tokens <= 0:
@@ -697,14 +729,11 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_ids)
tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts)
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
recv_x, recv_x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(
x, 128, self.quant_config.deepgemm_scale_ue8m0
)
else:
recv_x, recv_x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
x,
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
@@ -717,26 +746,49 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
else recv_x_scale.T[: recv_x.shape[0]]
)
(
permute_input,
permute_scale,
permute_indices_per_token,
recv_num_tokens_per_expert_list_cumsum,
recv_num_tokens_per_expert_list_padded_cumsum,
dst_weights,
dst_indices,
cumsum_idx_gpu,
m_indices,
) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch_fp8(
recv_x,
recv_x_scale,
topk_ids,
topk_weights,
tmp[0],
tmp[1],
False, # use_in_ep
-1,
)
if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE:
topk_ids = topk_ids.astype(paddle.int32)
override_buffer_size = recv_x.shape[0] * layer.top_k + layer.num_experts * (128 - 1)
(
permute_input,
permute_indices_per_token, # == zipped_expertwise_rowmap
dst_weights,
permute_scale,
m_indices,
) = paddle.nn.functional.moe_permute(
hidden_states=recv_x,
scale=recv_x_scale,
expert_routemap_topk=topk_ids,
expert_prob_topk=topk_weights,
num_experts=layer.num_experts,
tokens_per_expert=[],
padding_alignment=128,
return_expert_indices=True,
override_buffer_size=override_buffer_size,
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
)
else:
tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts)
(
permute_input,
permute_scale,
permute_indices_per_token,
recv_num_tokens_per_expert_list_cumsum,
recv_num_tokens_per_expert_list_padded_cumsum,
dst_weights,
dst_indices,
cumsum_idx_gpu,
m_indices,
) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch_fp8(
recv_x,
recv_x_scale,
topk_ids,
topk_weights,
tmp[0],
tmp[1],
False, # use_in_ep
-1,
)
ffn_out = m_grouped_fp8_gemm_nt_contiguous_custom_python_op(
permute_input,
@@ -751,14 +803,24 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
)
# prmt back per rank
tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
ffn_out,
dst_weights,
permute_indices_per_token,
dst_indices,
None,
False, # norm_topk_prob
1.0,
)
if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE:
tmp_ffn_out, out_probs = paddle.nn.functional.moe_unpermute(
hidden_states_unzipped=ffn_out,
zipped_expertwise_rowmap=permute_indices_per_token,
expert_routemap_topk=topk_ids,
token_prob_unzipped=dst_weights,
total_zipped_tokens=recv_x.shape[0],
num_experts=layer.num_experts,
using_weighted_combine=True,
)
else:
tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
ffn_out,
dst_weights,
permute_indices_per_token,
dst_indices,
None,
False, # norm_topk_prob
1.0,
)
return tmp_ffn_out