[RL] support moe-topk use topk_reduce_func (#7218)

* support moe-topk use topk_reduce_func

* fix ep error

* fix ut

* fix ut
This commit is contained in:
JYChen
2026-04-09 11:01:03 +08:00
committed by GitHub
parent 7005404ce3
commit 43ace7af25
7 changed files with 66 additions and 112 deletions
@@ -388,7 +388,9 @@ class TestFusedMoeCutlassBackend:
np.testing.assert_allclose(out.numpy(), np.full((1, 2), 5.0))
def test_apply_tp_with_dispatch_and_reduce(self, monkeypatch):
def fake_get_moe_scores(gate_out, n_group, topk_group, top_k, routed_scaling_factor, bias, renormalize):
def fake_get_moe_scores(
gate_out, n_group, topk_group, top_k, routed_scaling_factor, bias, renormalize, topk_reduce_func=None
):
return gate_out, paddle.to_tensor([[0.6, 0.4]]), paddle.to_tensor([[0, 1]])
def fake_dispatch(*args, **kwargs):
+26 -12
View File
@@ -1,10 +1,22 @@
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from unittest import mock
import paddle
from fastdeploy.model_executor.layers.moe.fused_moe_deepgemm_backend import (
moe_topk_select,
)
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
@@ -135,15 +147,17 @@ class TestMoeRouting(unittest.TestCase):
e_score_correction_bias=e_score_correction_bias,
)
topk_values, topk_idx = moe_topk_select(
gating_output=gating_output,
n_group=n_group,
topk_group=topk_group,
top_k=top_k,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
renormalize=renormalize,
)
with mock.patch.dict("os.environ", {"FD_USE_PHI_MOE_TOPK": "1"}):
new_score, topk_values, topk_idx = get_moe_scores(
gating_output=gating_output,
n_group=n_group,
topk_group=topk_group,
top_k=top_k,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
renormalize=renormalize,
topk_reduce_func=lambda x: x.sum(axis=-1, keepdim=True) + 1e-20,
)
equal_topk_value = paddle.allclose(topk_values, ref_topk_values, atol=1e-03, rtol=1e-03).item()
equal_topk_ids = paddle.allclose(