[RL] support moe-topk use topk_reduce_func (#7218)

* support moe-topk use topk_reduce_func

* fix ep error

* fix ut

* fix ut
This commit is contained in:
JYChen
2026-04-09 11:01:03 +08:00
committed by GitHub
parent 7005404ce3
commit 43ace7af25
7 changed files with 66 additions and 112 deletions
@@ -509,6 +509,7 @@ class EPRunner:
expert_in_rank_num_list=expert_in_rank_num_list, expert_in_rank_num_list=expert_in_rank_num_list,
tokens_per_expert_stats_list=tokens_per_expert_stats_list, tokens_per_expert_stats_list=tokens_per_expert_stats_list,
redundant_ep_rank_num_plus_one=layer.fd_config.eplb_config.redundant_experts_num + 1, redundant_ep_rank_num_plus_one=layer.fd_config.eplb_config.redundant_experts_num + 1,
topk_reduce_func=getattr(layer, "topk_reduce_func", None),
) )
else: else:
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_redundant_topk_select( topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_redundant_topk_select(
@@ -534,6 +535,7 @@ class EPRunner:
layer.routed_scaling_factor, layer.routed_scaling_factor,
layer.gate_correction_bias, layer.gate_correction_bias,
getattr(layer, "renormalize", True), getattr(layer, "renormalize", True),
topk_reduce_func=getattr(layer, "topk_reduce_func", None),
) )
else: else:
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
@@ -285,6 +285,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
layer.routed_scaling_factor, layer.routed_scaling_factor,
layer.gate_correction_bias, layer.gate_correction_bias,
getattr(layer, "renormalize", True), getattr(layer, "renormalize", True),
topk_reduce_func=getattr(layer, "topk_reduce_func", None),
) )
( (
permute_input, permute_input,
@@ -207,67 +207,6 @@ def m_grouped_fp8_gemm_nt_contiguous_custom_python_op(
return ffn_out return ffn_out
def moe_topk_select(
    gating_output: paddle.Tensor,
    n_group: int,
    topk_group: int,
    top_k: int,
    routed_scaling_factor: float,
    e_score_correction_bias: paddle.Tensor,
    renormalize: bool = False,
):
    """
    Select top-k experts per token using paddle's PHI topk API.

    Args:
        gating_output: gate output logits, shape [seq_len, n_experts].
        n_group: number of expert groups (values <= 0 are treated as 1).
        topk_group: number of top-scoring groups to keep (values <= 0 are treated as 1).
        top_k: number of experts selected per token.
        routed_scaling_factor: multiplier applied to the final weights (skipped when falsy).
        e_score_correction_bias: bias added to scores for selection only; may be None.
        renormalize: when True, rescale the selected weights to sum to 1.

    Returns:
        topk_weights: selection weights, shape [seq_len, top_k].
        topk_ids: selected expert indices, shape [seq_len, top_k].
    """
    # Sigmoid gate probabilities; the final combine weights are drawn from these
    # (without bias), while selection may use bias-corrected scores.
    scores = paddle.nn.functional.sigmoid(gating_output)

    if e_score_correction_bias is not None:
        selection_scores = scores + e_score_correction_bias
    else:
        selection_scores = scores

    # Guard against non-positive group configuration.
    n_group = max(n_group, 1)
    topk_group = max(topk_group, 1)

    # Group-limited routing: only experts in the best `topk_group` groups stay eligible.
    if n_group > 1 and topk_group < n_group:
        num_tokens, num_experts = selection_scores.shape
        # Score each group by the sum of its two best experts.
        grouped = selection_scores.reshape([num_tokens, n_group, -1])
        group_scores = grouped.topk(2, axis=-1)[0].sum(axis=-1)  # [seq_len, n_group]
        # Indices of the winning groups per token.
        winners = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1]  # [seq_len, topk_group]
        # One-hot the winners and collapse the topk_group axis into a per-group keep mask.
        keep_group = paddle.sum(
            paddle.nn.functional.one_hot(winners, num_classes=n_group).cast(group_scores.dtype),
            axis=1,  # sum over topk_group dimension -> [seq_len, n_group]
        )
        # Broadcast the per-group mask down to individual experts.
        keep_expert = (
            keep_group.unsqueeze(-1)
            .expand([num_tokens, n_group, num_experts // n_group])
            .reshape([num_tokens, -1])
        )  # [seq_len, n_experts]
        # Experts outside the winning groups can never be selected.
        selection_scores = selection_scores.masked_fill(~keep_expert.astype(paddle.bool), float("-inf"))

    # Pick experts on the (possibly masked) biased scores ...
    _, topk_ids = paddle.topk(selection_scores, top_k, axis=-1)
    # ... but gather combine weights from the unbiased probabilities.
    topk_weights = paddle.index_sample(scores, topk_ids)

    if renormalize:
        # Normalize selected weights to sum to 1; clip avoids divide-by-zero.
        topk_weights = topk_weights / paddle.clip(topk_weights.sum(-1, keepdim=True), min=1e-12)
    if routed_scaling_factor:
        topk_weights = topk_weights * routed_scaling_factor
    return topk_weights, topk_ids
class DeepGemmFusedMoeMethod(MoEMethodBase): class DeepGemmFusedMoeMethod(MoEMethodBase):
""" """
DeepGemmFusedMoeMethod is a class that implements the MoEMethodBase interface for DeepGemm backend. DeepGemmFusedMoeMethod is a class that implements the MoEMethodBase interface for DeepGemm backend.
@@ -403,22 +342,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
hidden_size = x.shape[1] hidden_size = x.shape[1]
# 1. Select topk experts and weights # 1. Select topk experts and weights
if ( topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
fastdeploy.envs.FD_USE_PHI_MOE_TOPK
and layer.redundant_table_manger is None
and layer.topk_method == "noaux_tc"
):
topk_weights, topk_idx = moe_topk_select(
gate_out,
layer.n_group,
layer.topk_group,
layer.top_k,
layer.routed_scaling_factor,
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
)
else:
topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
if topk_ids_hookfunc is not None: if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_idx) topk_ids_hookfunc(topk_ids=topk_idx)
@@ -820,28 +744,16 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
gate_out = gate_out.cast("float32") gate_out = gate_out.cast("float32")
if layer.topk_method == "noaux_tc": if layer.topk_method == "noaux_tc":
_, topk_weights, topk_ids = fastdeploy.model_executor.layers.moe.moe.get_moe_scores(
if not fastdeploy.envs.FD_USE_PHI_MOE_TOPK: gate_out,
_, topk_weights, topk_ids = fastdeploy.model_executor.layers.moe.moe.get_moe_scores( layer.n_group,
gate_out, layer.topk_group,
layer.n_group, layer.top_k,
layer.topk_group, layer.routed_scaling_factor,
layer.top_k, layer.gate_correction_bias,
layer.routed_scaling_factor, getattr(layer, "renormalize", True),
layer.gate_correction_bias, topk_reduce_func=getattr(layer, "topk_reduce_func", None),
getattr(layer, "renormalize", True), )
)
else:
topk_weights, topk_ids = moe_topk_select(
gate_out,
layer.n_group,
layer.topk_group,
layer.top_k,
layer.routed_scaling_factor,
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
)
else: else:
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out, gate_out,
@@ -90,6 +90,7 @@ def get_moe_scores(
expert_in_rank_num_list: paddle.Tensor = None, expert_in_rank_num_list: paddle.Tensor = None,
tokens_per_expert_stats_list: paddle.Tensor = None, tokens_per_expert_stats_list: paddle.Tensor = None,
redundant_ep_rank_num_plus_one: int = 1, redundant_ep_rank_num_plus_one: int = 1,
topk_reduce_func: Callable = lambda x: x.sum(axis=-1, keepdim=True) + 1e-20,
) -> paddle.Tensor: ) -> paddle.Tensor:
""" """
compute moe scores using e_score_correction_bias. compute moe scores using e_score_correction_bias.
@@ -97,6 +98,14 @@ def get_moe_scores(
scores = paddle.nn.functional.sigmoid(gating_output) scores = paddle.nn.functional.sigmoid(gating_output)
assert e_score_correction_bias is not None, "e_score_correction_bias is none!" assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
scores_with_bias = scores + e_score_correction_bias scores_with_bias = scores + e_score_correction_bias
if envs.FD_USE_PHI_MOE_TOPK:
# calculate renormalize and routed_scaling_factor value outside the noaux_tc
original_renormalize = renormalize
original_routed_scaling_factor = routed_scaling_factor
renormalize = False
routed_scaling_factor = 1.0
if expert_id_to_ep_rank_array is None: if expert_id_to_ep_rank_array is None:
scores, topk_values, topk_idx = noaux_tc( scores, topk_values, topk_idx = noaux_tc(
scores, scores,
@@ -123,6 +132,16 @@ def get_moe_scores(
routed_scaling_factor, routed_scaling_factor,
redundant_ep_rank_num_plus_one, redundant_ep_rank_num_plus_one,
) )
if envs.FD_USE_PHI_MOE_TOPK:
if original_renormalize:
if topk_reduce_func is not None:
topk_values = topk_values / topk_reduce_func(topk_values)
else:
# use the default sum + epsilon
topk_values = topk_values / (topk_values.sum(axis=-1, keepdim=True) + 1e-20)
if original_routed_scaling_factor != 1.0:
topk_values *= original_routed_scaling_factor
return scores, topk_values, topk_idx return scores, topk_values, topk_idx
@@ -152,6 +171,8 @@ class FusedMoE(nn.Layer):
with_bias: bool = False, with_bias: bool = False,
activation="swiglu", activation="swiglu",
model_format: Optional[str] = None, model_format: Optional[str] = None,
topk_reduce_func: Callable = lambda x: x.sum(axis=-1, keepdim=True)
+ 1e-20, # only used when FD_USE_PHI_MOE_TOPK=1, default is same as noaux_tc kernel
): ):
""" """
Initialize the Moe layer with given parameters. Initialize the Moe layer with given parameters.
@@ -197,6 +218,7 @@ class FusedMoE(nn.Layer):
self.moe_tag = moe_tag self.moe_tag = moe_tag
self.with_bias = with_bias self.with_bias = with_bias
self.activation = activation self.activation = activation
self.topk_reduce_func = topk_reduce_func
if self.ep_size > 1: if self.ep_size > 1:
expert_id_offset = expert_id_offset + self.ep_rank * self.num_local_experts expert_id_offset = expert_id_offset + self.ep_rank * self.num_local_experts
@@ -180,6 +180,7 @@ class Glm4Moe(nn.Layer):
layer_idx=layer_id, layer_idx=layer_id,
gate_correction_bias=self.gate.e_score_correction_bias, gate_correction_bias=self.gate.e_score_correction_bias,
weight_key_map=weight_key_map, weight_key_map=weight_key_map,
topk_reduce_func=lambda x: x.sum(axis=-1, keepdim=True) + 1e-20,
) )
if self.n_shared_experts > 0: if self.n_shared_experts > 0:
@@ -388,7 +388,9 @@ class TestFusedMoeCutlassBackend:
np.testing.assert_allclose(out.numpy(), np.full((1, 2), 5.0)) np.testing.assert_allclose(out.numpy(), np.full((1, 2), 5.0))
def test_apply_tp_with_dispatch_and_reduce(self, monkeypatch): def test_apply_tp_with_dispatch_and_reduce(self, monkeypatch):
def fake_get_moe_scores(gate_out, n_group, topk_group, top_k, routed_scaling_factor, bias, renormalize): def fake_get_moe_scores(
gate_out, n_group, topk_group, top_k, routed_scaling_factor, bias, renormalize, topk_reduce_func=None
):
return gate_out, paddle.to_tensor([[0.6, 0.4]]), paddle.to_tensor([[0, 1]]) return gate_out, paddle.to_tensor([[0.6, 0.4]]), paddle.to_tensor([[0, 1]])
def fake_dispatch(*args, **kwargs): def fake_dispatch(*args, **kwargs):
+26 -12
View File
@@ -1,10 +1,22 @@
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest import unittest
from unittest import mock
import paddle import paddle
from fastdeploy.model_executor.layers.moe.fused_moe_deepgemm_backend import (
moe_topk_select,
)
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
@@ -135,15 +147,17 @@ class TestMoeRouting(unittest.TestCase):
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
) )
topk_values, topk_idx = moe_topk_select( with mock.patch.dict("os.environ", {"FD_USE_PHI_MOE_TOPK": "1"}):
gating_output=gating_output, new_score, topk_values, topk_idx = get_moe_scores(
n_group=n_group, gating_output=gating_output,
topk_group=topk_group, n_group=n_group,
top_k=top_k, topk_group=topk_group,
routed_scaling_factor=routed_scaling_factor, top_k=top_k,
e_score_correction_bias=e_score_correction_bias, routed_scaling_factor=routed_scaling_factor,
renormalize=renormalize, e_score_correction_bias=e_score_correction_bias,
) renormalize=renormalize,
topk_reduce_func=lambda x: x.sum(axis=-1, keepdim=True) + 1e-20,
)
equal_topk_value = paddle.allclose(topk_values, ref_topk_values, atol=1e-03, rtol=1e-03).item() equal_topk_value = paddle.allclose(topk_values, ref_topk_values, atol=1e-03, rtol=1e-03).item()
equal_topk_ids = paddle.allclose( equal_topk_ids = paddle.allclose(