mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[RL] support moe-topk use topk_reduce_func (#7218)
* support moe-topk use topk_reduce_func * fix ep error * fix ut * fix ut
This commit is contained in:
@@ -509,6 +509,7 @@ class EPRunner:
|
|||||||
expert_in_rank_num_list=expert_in_rank_num_list,
|
expert_in_rank_num_list=expert_in_rank_num_list,
|
||||||
tokens_per_expert_stats_list=tokens_per_expert_stats_list,
|
tokens_per_expert_stats_list=tokens_per_expert_stats_list,
|
||||||
redundant_ep_rank_num_plus_one=layer.fd_config.eplb_config.redundant_experts_num + 1,
|
redundant_ep_rank_num_plus_one=layer.fd_config.eplb_config.redundant_experts_num + 1,
|
||||||
|
topk_reduce_func=getattr(layer, "topk_reduce_func", None),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_redundant_topk_select(
|
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_redundant_topk_select(
|
||||||
@@ -534,6 +535,7 @@ class EPRunner:
|
|||||||
layer.routed_scaling_factor,
|
layer.routed_scaling_factor,
|
||||||
layer.gate_correction_bias,
|
layer.gate_correction_bias,
|
||||||
getattr(layer, "renormalize", True),
|
getattr(layer, "renormalize", True),
|
||||||
|
topk_reduce_func=getattr(layer, "topk_reduce_func", None),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||||
|
|||||||
@@ -285,6 +285,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
layer.routed_scaling_factor,
|
layer.routed_scaling_factor,
|
||||||
layer.gate_correction_bias,
|
layer.gate_correction_bias,
|
||||||
getattr(layer, "renormalize", True),
|
getattr(layer, "renormalize", True),
|
||||||
|
topk_reduce_func=getattr(layer, "topk_reduce_func", None),
|
||||||
)
|
)
|
||||||
(
|
(
|
||||||
permute_input,
|
permute_input,
|
||||||
|
|||||||
@@ -207,67 +207,6 @@ def m_grouped_fp8_gemm_nt_contiguous_custom_python_op(
|
|||||||
return ffn_out
|
return ffn_out
|
||||||
|
|
||||||
|
|
||||||
def moe_topk_select(
|
|
||||||
gating_output: paddle.Tensor,
|
|
||||||
n_group: int,
|
|
||||||
topk_group: int,
|
|
||||||
top_k: int,
|
|
||||||
routed_scaling_factor: float,
|
|
||||||
e_score_correction_bias: paddle.Tensor,
|
|
||||||
renormalize: bool = False,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Topk selection using paddle PHI topk API.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
gating_output: gate output logits, shape [seq_len, n_experts]
|
|
||||||
n_group: number of expert groups
|
|
||||||
topk_group: number of top-k groups to select
|
|
||||||
top_k: number of top experts per token
|
|
||||||
routed_scaling_factor: scaling factor for routed experts
|
|
||||||
e_score_correction_bias: bias for expert selection
|
|
||||||
renormalize: whether to renormalize topk probabilities
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
topk_weights: normalized topk probabilities, shape [seq_len, top_k]
|
|
||||||
topk_ids: topk expert indices, shape [seq_len, top_k]
|
|
||||||
"""
|
|
||||||
# compute gate probs via sigmoid
|
|
||||||
gate_probs = paddle.nn.functional.sigmoid(gating_output)
|
|
||||||
# probs_for_choice includes correction bias for topk selection
|
|
||||||
probs_for_choice = gate_probs + e_score_correction_bias if e_score_correction_bias is not None else gate_probs
|
|
||||||
# group-based topk selection
|
|
||||||
n_group = n_group if n_group > 0 else 1
|
|
||||||
topk_group = topk_group if topk_group > 0 else 1
|
|
||||||
if n_group > 1 and topk_group < n_group:
|
|
||||||
seq_length, n_experts = probs_for_choice.shape
|
|
||||||
group_scores = (
|
|
||||||
probs_for_choice.reshape([seq_length, n_group, -1]).topk(2, axis=-1)[0].sum(axis=-1)
|
|
||||||
) # [seq_len, n_group]
|
|
||||||
group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1] # [seq_len, topk_group]
|
|
||||||
group_mask = paddle.sum(
|
|
||||||
paddle.nn.functional.one_hot(group_idx, num_classes=n_group).cast(group_scores.dtype),
|
|
||||||
axis=1, # Sum over topk_group dimension -> [seq_len, n_group]
|
|
||||||
)
|
|
||||||
score_mask = (
|
|
||||||
group_mask.unsqueeze(-1).expand([seq_length, n_group, n_experts // n_group]).reshape([seq_length, -1])
|
|
||||||
) # [seq_len, n_experts]
|
|
||||||
probs_for_choice = probs_for_choice.masked_fill(~score_mask.astype(paddle.bool), float("-inf"))
|
|
||||||
|
|
||||||
_, topk_ids = paddle.topk(probs_for_choice, top_k, axis=-1)
|
|
||||||
topk_weights = paddle.index_sample(gate_probs, topk_ids)
|
|
||||||
|
|
||||||
# normalize combine weights
|
|
||||||
if renormalize:
|
|
||||||
topk_weights = topk_weights / paddle.clip(topk_weights.sum(-1, keepdim=True), min=1e-12)
|
|
||||||
|
|
||||||
# apply routed scaling factor
|
|
||||||
if routed_scaling_factor:
|
|
||||||
topk_weights = topk_weights * routed_scaling_factor
|
|
||||||
|
|
||||||
return topk_weights, topk_ids
|
|
||||||
|
|
||||||
|
|
||||||
class DeepGemmFusedMoeMethod(MoEMethodBase):
|
class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||||
"""
|
"""
|
||||||
DeepGemmFusedMoeMethod is a class that implements the MoEMethodBase interface for DeepGemm backend.
|
DeepGemmFusedMoeMethod is a class that implements the MoEMethodBase interface for DeepGemm backend.
|
||||||
@@ -403,22 +342,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
|||||||
hidden_size = x.shape[1]
|
hidden_size = x.shape[1]
|
||||||
|
|
||||||
# 1. Select topk experts and weights
|
# 1. Select topk experts and weights
|
||||||
if (
|
topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
|
||||||
fastdeploy.envs.FD_USE_PHI_MOE_TOPK
|
|
||||||
and layer.redundant_table_manger is None
|
|
||||||
and layer.topk_method == "noaux_tc"
|
|
||||||
):
|
|
||||||
topk_weights, topk_idx = moe_topk_select(
|
|
||||||
gate_out,
|
|
||||||
layer.n_group,
|
|
||||||
layer.topk_group,
|
|
||||||
layer.top_k,
|
|
||||||
layer.routed_scaling_factor,
|
|
||||||
layer.gate_correction_bias,
|
|
||||||
getattr(layer, "renormalize", True),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
|
|
||||||
|
|
||||||
if topk_ids_hookfunc is not None:
|
if topk_ids_hookfunc is not None:
|
||||||
topk_ids_hookfunc(topk_ids=topk_idx)
|
topk_ids_hookfunc(topk_ids=topk_idx)
|
||||||
@@ -820,28 +744,16 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
|||||||
gate_out = gate_out.cast("float32")
|
gate_out = gate_out.cast("float32")
|
||||||
|
|
||||||
if layer.topk_method == "noaux_tc":
|
if layer.topk_method == "noaux_tc":
|
||||||
|
_, topk_weights, topk_ids = fastdeploy.model_executor.layers.moe.moe.get_moe_scores(
|
||||||
if not fastdeploy.envs.FD_USE_PHI_MOE_TOPK:
|
gate_out,
|
||||||
_, topk_weights, topk_ids = fastdeploy.model_executor.layers.moe.moe.get_moe_scores(
|
layer.n_group,
|
||||||
gate_out,
|
layer.topk_group,
|
||||||
layer.n_group,
|
layer.top_k,
|
||||||
layer.topk_group,
|
layer.routed_scaling_factor,
|
||||||
layer.top_k,
|
layer.gate_correction_bias,
|
||||||
layer.routed_scaling_factor,
|
getattr(layer, "renormalize", True),
|
||||||
layer.gate_correction_bias,
|
topk_reduce_func=getattr(layer, "topk_reduce_func", None),
|
||||||
getattr(layer, "renormalize", True),
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
topk_weights, topk_ids = moe_topk_select(
|
|
||||||
gate_out,
|
|
||||||
layer.n_group,
|
|
||||||
layer.topk_group,
|
|
||||||
layer.top_k,
|
|
||||||
layer.routed_scaling_factor,
|
|
||||||
layer.gate_correction_bias,
|
|
||||||
getattr(layer, "renormalize", True),
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||||
gate_out,
|
gate_out,
|
||||||
|
|||||||
@@ -90,6 +90,7 @@ def get_moe_scores(
|
|||||||
expert_in_rank_num_list: paddle.Tensor = None,
|
expert_in_rank_num_list: paddle.Tensor = None,
|
||||||
tokens_per_expert_stats_list: paddle.Tensor = None,
|
tokens_per_expert_stats_list: paddle.Tensor = None,
|
||||||
redundant_ep_rank_num_plus_one: int = 1,
|
redundant_ep_rank_num_plus_one: int = 1,
|
||||||
|
topk_reduce_func: Callable = lambda x: x.sum(axis=-1, keepdim=True) + 1e-20,
|
||||||
) -> paddle.Tensor:
|
) -> paddle.Tensor:
|
||||||
"""
|
"""
|
||||||
compute moe scores using e_score_correction_bias.
|
compute moe scores using e_score_correction_bias.
|
||||||
@@ -97,6 +98,14 @@ def get_moe_scores(
|
|||||||
scores = paddle.nn.functional.sigmoid(gating_output)
|
scores = paddle.nn.functional.sigmoid(gating_output)
|
||||||
assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
|
assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
|
||||||
scores_with_bias = scores + e_score_correction_bias
|
scores_with_bias = scores + e_score_correction_bias
|
||||||
|
|
||||||
|
if envs.FD_USE_PHI_MOE_TOPK:
|
||||||
|
# calculate renormalize and routed_scaling_factor value outside the noaux_tc
|
||||||
|
original_renormalize = renormalize
|
||||||
|
original_routed_scaling_factor = routed_scaling_factor
|
||||||
|
renormalize = False
|
||||||
|
routed_scaling_factor = 1.0
|
||||||
|
|
||||||
if expert_id_to_ep_rank_array is None:
|
if expert_id_to_ep_rank_array is None:
|
||||||
scores, topk_values, topk_idx = noaux_tc(
|
scores, topk_values, topk_idx = noaux_tc(
|
||||||
scores,
|
scores,
|
||||||
@@ -123,6 +132,16 @@ def get_moe_scores(
|
|||||||
routed_scaling_factor,
|
routed_scaling_factor,
|
||||||
redundant_ep_rank_num_plus_one,
|
redundant_ep_rank_num_plus_one,
|
||||||
)
|
)
|
||||||
|
if envs.FD_USE_PHI_MOE_TOPK:
|
||||||
|
if original_renormalize:
|
||||||
|
if topk_reduce_func is not None:
|
||||||
|
topk_values = topk_values / topk_reduce_func(topk_values)
|
||||||
|
else:
|
||||||
|
# 使用默认的 sum + epsilon
|
||||||
|
topk_values = topk_values / (topk_values.sum(axis=-1, keepdim=True) + 1e-20)
|
||||||
|
|
||||||
|
if original_routed_scaling_factor != 1.0:
|
||||||
|
topk_values *= original_routed_scaling_factor
|
||||||
return scores, topk_values, topk_idx
|
return scores, topk_values, topk_idx
|
||||||
|
|
||||||
|
|
||||||
@@ -152,6 +171,8 @@ class FusedMoE(nn.Layer):
|
|||||||
with_bias: bool = False,
|
with_bias: bool = False,
|
||||||
activation="swiglu",
|
activation="swiglu",
|
||||||
model_format: Optional[str] = None,
|
model_format: Optional[str] = None,
|
||||||
|
topk_reduce_func: Callable = lambda x: x.sum(axis=-1, keepdim=True)
|
||||||
|
+ 1e-20, # only used when FD_USE_PHI_MOE_TOPK=1, default is same as noaux_tc kernel
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the Moe layer with given parameters.
|
Initialize the Moe layer with given parameters.
|
||||||
@@ -197,6 +218,7 @@ class FusedMoE(nn.Layer):
|
|||||||
self.moe_tag = moe_tag
|
self.moe_tag = moe_tag
|
||||||
self.with_bias = with_bias
|
self.with_bias = with_bias
|
||||||
self.activation = activation
|
self.activation = activation
|
||||||
|
self.topk_reduce_func = topk_reduce_func
|
||||||
|
|
||||||
if self.ep_size > 1:
|
if self.ep_size > 1:
|
||||||
expert_id_offset = expert_id_offset + self.ep_rank * self.num_local_experts
|
expert_id_offset = expert_id_offset + self.ep_rank * self.num_local_experts
|
||||||
|
|||||||
@@ -180,6 +180,7 @@ class Glm4Moe(nn.Layer):
|
|||||||
layer_idx=layer_id,
|
layer_idx=layer_id,
|
||||||
gate_correction_bias=self.gate.e_score_correction_bias,
|
gate_correction_bias=self.gate.e_score_correction_bias,
|
||||||
weight_key_map=weight_key_map,
|
weight_key_map=weight_key_map,
|
||||||
|
topk_reduce_func=lambda x: x.sum(axis=-1, keepdim=True) + 1e-20,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.n_shared_experts > 0:
|
if self.n_shared_experts > 0:
|
||||||
|
|||||||
@@ -388,7 +388,9 @@ class TestFusedMoeCutlassBackend:
|
|||||||
np.testing.assert_allclose(out.numpy(), np.full((1, 2), 5.0))
|
np.testing.assert_allclose(out.numpy(), np.full((1, 2), 5.0))
|
||||||
|
|
||||||
def test_apply_tp_with_dispatch_and_reduce(self, monkeypatch):
|
def test_apply_tp_with_dispatch_and_reduce(self, monkeypatch):
|
||||||
def fake_get_moe_scores(gate_out, n_group, topk_group, top_k, routed_scaling_factor, bias, renormalize):
|
def fake_get_moe_scores(
|
||||||
|
gate_out, n_group, topk_group, top_k, routed_scaling_factor, bias, renormalize, topk_reduce_func=None
|
||||||
|
):
|
||||||
return gate_out, paddle.to_tensor([[0.6, 0.4]]), paddle.to_tensor([[0, 1]])
|
return gate_out, paddle.to_tensor([[0.6, 0.4]]), paddle.to_tensor([[0, 1]])
|
||||||
|
|
||||||
def fake_dispatch(*args, **kwargs):
|
def fake_dispatch(*args, **kwargs):
|
||||||
|
|||||||
@@ -1,10 +1,22 @@
|
|||||||
|
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
|
|
||||||
from fastdeploy.model_executor.layers.moe.fused_moe_deepgemm_backend import (
|
|
||||||
moe_topk_select,
|
|
||||||
)
|
|
||||||
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
|
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
|
||||||
|
|
||||||
|
|
||||||
@@ -135,15 +147,17 @@ class TestMoeRouting(unittest.TestCase):
|
|||||||
e_score_correction_bias=e_score_correction_bias,
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
topk_values, topk_idx = moe_topk_select(
|
with mock.patch.dict("os.environ", {"FD_USE_PHI_MOE_TOPK": "1"}):
|
||||||
gating_output=gating_output,
|
new_score, topk_values, topk_idx = get_moe_scores(
|
||||||
n_group=n_group,
|
gating_output=gating_output,
|
||||||
topk_group=topk_group,
|
n_group=n_group,
|
||||||
top_k=top_k,
|
topk_group=topk_group,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
top_k=top_k,
|
||||||
e_score_correction_bias=e_score_correction_bias,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
renormalize=renormalize,
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
)
|
renormalize=renormalize,
|
||||||
|
topk_reduce_func=lambda x: x.sum(axis=-1, keepdim=True) + 1e-20,
|
||||||
|
)
|
||||||
|
|
||||||
equal_topk_value = paddle.allclose(topk_values, ref_topk_values, atol=1e-03, rtol=1e-03).item()
|
equal_topk_value = paddle.allclose(topk_values, ref_topk_values, atol=1e-03, rtol=1e-03).item()
|
||||||
equal_topk_ids = paddle.allclose(
|
equal_topk_ids = paddle.allclose(
|
||||||
|
|||||||
Reference in New Issue
Block a user