[RL] support moe-topk use topk_reduce_func (#7218)

* support moe-topk use topk_reduce_func

* fix ep error

* fix ut

* fix ut
This commit is contained in:
JYChen
2026-04-09 11:01:03 +08:00
committed by GitHub
parent 7005404ce3
commit 43ace7af25
7 changed files with 66 additions and 112 deletions
@@ -90,6 +90,7 @@ def get_moe_scores(
expert_in_rank_num_list: paddle.Tensor = None,
tokens_per_expert_stats_list: paddle.Tensor = None,
redundant_ep_rank_num_plus_one: int = 1,
topk_reduce_func: Callable = lambda x: x.sum(axis=-1, keepdim=True) + 1e-20,
) -> paddle.Tensor:
"""
compute moe scores using e_score_correction_bias.
@@ -97,6 +98,14 @@ def get_moe_scores(
scores = paddle.nn.functional.sigmoid(gating_output)
assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
scores_with_bias = scores + e_score_correction_bias
if envs.FD_USE_PHI_MOE_TOPK:
# calculate renormalize and routed_scaling_factor value outside the noaux_tc
original_renormalize = renormalize
original_routed_scaling_factor = routed_scaling_factor
renormalize = False
routed_scaling_factor = 1.0
if expert_id_to_ep_rank_array is None:
scores, topk_values, topk_idx = noaux_tc(
scores,
@@ -123,6 +132,16 @@ def get_moe_scores(
routed_scaling_factor,
redundant_ep_rank_num_plus_one,
)
if envs.FD_USE_PHI_MOE_TOPK:
if original_renormalize:
if topk_reduce_func is not None:
topk_values = topk_values / topk_reduce_func(topk_values)
else:
# fall back to the default reduction: sum + epsilon
topk_values = topk_values / (topk_values.sum(axis=-1, keepdim=True) + 1e-20)
if original_routed_scaling_factor != 1.0:
topk_values *= original_routed_scaling_factor
return scores, topk_values, topk_idx
@@ -152,6 +171,8 @@ class FusedMoE(nn.Layer):
with_bias: bool = False,
activation="swiglu",
model_format: Optional[str] = None,
topk_reduce_func: Callable = lambda x: x.sum(axis=-1, keepdim=True)
+ 1e-20, # only used when FD_USE_PHI_MOE_TOPK=1, default is same as noaux_tc kernel
):
"""
Initialize the Moe layer with given parameters.
@@ -197,6 +218,7 @@ class FusedMoE(nn.Layer):
self.moe_tag = moe_tag
self.with_bias = with_bias
self.activation = activation
self.topk_reduce_func = topk_reduce_func
if self.ep_size > 1:
expert_id_offset = expert_id_offset + self.ep_rank * self.num_local_experts