From 5e54770b2ef659095e8e147294071661d43bcebd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=91=A8=E5=91=A8=E5=91=A8?= <39978853+zhoutianzi666@users.noreply.github.com>
Date: Wed, 15 Apr 2026 13:57:07 +0800
Subject: [PATCH] [Feature] Add latent mode support to the MoE layer (#7382)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../layers/moe/fused_moe_backend_base.py      |  4 +-
 .../layers/moe/fused_moe_triton_backend.py    | 49 ++++++++++++---
 fastdeploy/model_executor/layers/moe/moe.py   | 26 +++++++-
 tests/layers/test_fused_moe_triton_backend.py |  6 ++
 tests/layers/test_fusedmoe.py                 | 63 ++++++++++++++++---
 5 files changed, 127 insertions(+), 21 deletions(-)

diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py
index f1cf9ce578..c2c61cd53c 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py
@@ -218,6 +218,8 @@ class MoEMethodBase(QuantMethodBase):
         gate: nn.Layer,
         topk_ids_hookfunc: Callable = None,
         shared_experts: nn.Layer = None,
+        fc1_latent_proj: nn.Layer = None,
+        fc2_latent_proj: nn.Layer = None,
     ) -> paddle.Tensor:
         """
         Paddle Cutlass compute Fused MoE.
@@ -237,7 +239,7 @@ class MoEMethodBase(QuantMethodBase):
                 layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
             )
         else:
-            return self.apply_tp(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
+            return self.apply_tp(layer, x, gate, topk_ids_hookfunc)


 class UnquantizedFusedMoEMethod(MoEMethodBase):
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
index 56f6e6dd42..be7c69a04d 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py
@@ -292,6 +292,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
         gate: nn.Layer,
         topk_ids_hookfunc: Callable = None,
         shared_experts: nn.Layer = None,
+        fc1_latent_proj: nn.Layer = None,
+        fc2_latent_proj: nn.Layer = None,
     ) -> paddle.Tensor:
         """
         Triton compute Fused MoE.
@@ -681,6 +683,8 @@ class Wfp8Afp8MoEMethod(QuantMethodBase):
         gate: nn.Layer,
         topk_ids_hookfunc: Callable = None,
         shared_experts: nn.Layer = None,
+        fc1_latent_proj: nn.Layer = None,
+        fc2_latent_proj: nn.Layer = None,
     ) -> paddle.Tensor:
         """
         Triton compute Fused MoE.
@@ -980,6 +984,8 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
         gate: nn.Layer,
         topk_ids_hookfunc: Callable = None,
         shared_experts: nn.Layer = None,
+        fc1_latent_proj: nn.Layer = None,
+        fc2_latent_proj: nn.Layer = None,
     ) -> paddle.Tensor:
         """
         Triton compute Fused MoE.
@@ -1174,6 +1180,9 @@ def python_op_fused_moe_kernel_paddle_infer_meta(
     config: dict,
     quant_config,
     topk_ids_hookfunc,
+    layer,
+    fc1_latent_proj,
+    fc2_latent_proj,
 ):
     token_num = x.shape[0]
     return paddle.static.MetaTensor(shape=[token_num, hidden_size], dtype=x.dtype)
@@ -1211,19 +1220,34 @@ def python_op_fused_moe_kernel_paddle(
     config: dict,
     quant_config,
     topk_ids_hookfunc,
+    layer,
+    fc1_latent_proj,
+    fc2_latent_proj,
 ):
     token_num = x.shape[0]
     if x.shape[0] == 0:
         return paddle.zeros([token_num, hidden_size], dtype=x.dtype)

-    topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
-        gate_out,
-        gate_correction_bias,
-        top_k,
-        True,  # apply_norm_weight
-        False,
-    )
+    if layer.topk_method == "noaux_tc":
+        gate_out, topk_weights, topk_ids = get_moe_scores(
+            gate_out,
+            layer.n_group,
+            layer.topk_group,
+            layer.top_k,
+            layer.routed_scaling_factor,
+            layer.gate_correction_bias,
+            getattr(layer, "renormalize", True),
+        )
+    else:
+        topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
+            gate_out,
+            gate_correction_bias,
+            top_k,
+            True,  # apply_norm_weight
+            False,
+        )
+
     if topk_ids_hookfunc is not None:
         topk_ids_hookfunc(topk_ids=topk_ids)

@@ -1244,6 +1268,9 @@ def python_op_fused_moe_kernel_paddle(

     from .triton_moe_kernels import fused_moe_kernel_paddle

+    if fc1_latent_proj is not None:
+        x = fc1_latent_proj(x)
+
     if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
         x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, quant_config.weight_block_size[0], False)
     else:
@@ -1357,6 +1384,9 @@ def python_op_fused_moe_kernel_paddle(
     intermediate_cache3.reshape_([token_num, top_k, hidden_size])
     out = intermediate_cache3.sum(axis=1)

+    if fc2_latent_proj is not None:
+        out = fc2_latent_proj(out)
+
     return out


@@ -1808,6 +1838,8 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
         gate: nn.Layer,
         topk_ids_hookfunc: Callable = None,
         shared_experts: nn.Layer = None,
+        fc1_latent_proj: nn.Layer = None,
+        fc2_latent_proj: nn.Layer = None,
     ) -> paddle.Tensor:
         """
         Triton compute Fused MoE.
@@ -1855,4 +1887,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
             config,
             self.quant_config,
             topk_ids_hookfunc,
+            layer,
+            fc1_latent_proj,
+            fc2_latent_proj,
         )
diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py
index 120a4ecc30..4aa23f6793 100644
--- a/fastdeploy/model_executor/layers/moe/moe.py
+++ b/fastdeploy/model_executor/layers/moe/moe.py
@@ -709,7 +709,13 @@ class FusedMoE(nn.Layer):
         return out

     def forward(
-        self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta = None, shared_experts: nn.Layer = None
+        self,
+        x: paddle.Tensor,
+        gate: nn.Layer,
+        forward_meta: ForwardMeta = None,
+        shared_experts: nn.Layer = None,
+        fc1_latent_proj: nn.Layer = None,
+        fc2_latent_proj: nn.Layer = None,
     ):
         """
         Defines the forward computation of the moe layer.
@@ -762,7 +768,13 @@ class FusedMoE(nn.Layer):
             )
         else:
             out = self.forward_normal(
-                x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
+                x,
+                gate,
+                forward_meta,
+                topk_ids_hookfunc,
+                shared_experts,
+                fc1_latent_proj,
+                fc2_latent_proj,
             )

         if self.reduce_results and self.tp_size > 1:
@@ -829,6 +841,8 @@ class FusedMoE(nn.Layer):
         forward_meta: ForwardMeta,
         topk_ids_hookfunc: Callable = None,
         shared_experts: nn.Layer = None,
+        fc1_latent_proj: nn.Layer = None,
+        fc2_latent_proj: nn.Layer = None,
     ):
         """
         Normal mode of forward.
@@ -842,7 +856,13 @@ class FusedMoE(nn.Layer):
         """
         if current_platform.is_cuda():
             out = self.quant_method.apply(
-                self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
+                self,
+                x,
+                gate,
+                topk_ids_hookfunc,
+                shared_experts,
+                fc1_latent_proj,
+                fc2_latent_proj,
             )
         else:
             out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
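Taken together, the kernel and `FusedMoE.forward` changes above thread two optional projections around the expert computation: `fc1_latent_proj` maps the hidden states down to `moe_latent_size` right before the expert kernels, and `fc2_latent_proj` maps the top-k weighted expert output back to the model width. The NumPy sketch below only illustrates that dataflow under simplifying assumptions (softmax routing, a SwiGLU expert, no quantization); it is not the FastDeploy implementation, and all weight names and shapes are illustrative.

```python
import numpy as np

def silu(v):
    return v / (1.0 + np.exp(-v))

def latent_moe_reference(x, fc1_w, fc2_w, gate_w, up_gate_w, down_w, top_k):
    """x: [tokens, hidden]; fc1_w: [hidden, latent]; fc2_w: [latent, hidden];
    gate_w: [hidden, experts]; up_gate_w: [experts, latent, 2*inter];
    down_w: [experts, inter, latent]."""
    # Routing on the original hidden states; a plain softmax + top-k stands in
    # for moe_topk_select / get_moe_scores.
    logits = x @ gate_w
    probs = np.exp(logits - logits.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)
    topk_ids = np.argsort(-probs, axis=-1)[:, :top_k]
    topk_w = np.take_along_axis(probs, topk_ids, axis=-1)
    topk_w /= topk_w.sum(-1, keepdims=True)  # apply_norm_weight=True

    lat = x @ fc1_w  # fc1_latent_proj: hidden -> latent, applied before the experts
    out_lat = np.zeros((x.shape[0], fc2_w.shape[0]))
    for t in range(x.shape[0]):
        for k in range(top_k):
            e = topk_ids[t, k]
            h = lat[t] @ up_gate_w[e]          # fused gate/up projection (assumed)
            gate_part, up_part = np.split(h, 2)
            h = silu(gate_part) * up_part      # SwiGLU expert (assumed)
            out_lat[t] += topk_w[t, k] * (h @ down_w[e])
    return out_lat @ fc2_w  # fc2_latent_proj: latent -> hidden, applied after the top-k sum

# Tiny smoke run with random weights.
rng = np.random.default_rng(0)
tokens, hidden, latent, inter, experts, top_k = 4, 16, 8, 32, 6, 2
out = latent_moe_reference(
    rng.normal(size=(tokens, hidden)),
    rng.normal(size=(hidden, latent)),
    rng.normal(size=(latent, hidden)),
    rng.normal(size=(hidden, experts)),
    rng.normal(size=(experts, latent, 2 * inter)),
    rng.normal(size=(experts, inter, latent)),
    top_k,
)
print(out.shape)  # (4, 16): back at the model's hidden width
```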
diff --git a/tests/layers/test_fused_moe_triton_backend.py b/tests/layers/test_fused_moe_triton_backend.py
index 7dacbbe390..6e8bf95c0e 100644
--- a/tests/layers/test_fused_moe_triton_backend.py
+++ b/tests/layers/test_fused_moe_triton_backend.py
@@ -509,6 +509,9 @@ class TestFusedMoeTritonBackend:
             config,
             quant_config,
             hook,
+            layer,
+            None,
+            None,
         )

         assert "topk" in captured
@@ -530,6 +533,9 @@ class TestFusedMoeTritonBackend:
             config,
             quant_config,
             None,
+            layer,
+            None,
+            None,
         )

         assert meta.shape == [2, layer.hidden_size]
diff --git a/tests/layers/test_fusedmoe.py b/tests/layers/test_fusedmoe.py
index 5c8e74e0f6..69dbaffbf8 100644
--- a/tests/layers/test_fusedmoe.py
+++ b/tests/layers/test_fusedmoe.py
@@ -506,10 +506,41 @@ class FuseMoEWrapper(paddle.nn.Layer):
             skip_quant=True,
             weight_dtype="float32",
         )
+        self.gating.weight.set_value(paddle.rand(self.gating.weight.shape, dtype=paddle.float32))
+
+        self.fc1_latent_proj = None
+        self.fc2_latent_proj = None
+
+        if self.fd_config.model_config.use_latent_moe:
+            self.fc1_latent_proj = ReplicatedLinear(
+                fd_config=self.fd_config,
+                input_size=self.fd_config.model_config.hidden_size,
+                output_size=self.fd_config.model_config.moe_latent_size,
+                with_bias=True,
+            )
+            self.fc1_latent_proj.weight.set_value(
+                paddle.zeros(self.fc1_latent_proj.weight.shape).cast(paddle.float8_e4m3fn)
+            )
+            self.fc1_latent_proj.bias.set_value(paddle.zeros(self.fc1_latent_proj.bias.shape))
+
+            self.fc2_latent_proj = ReplicatedLinear(
+                fd_config=self.fd_config,
+                input_size=self.fd_config.model_config.moe_latent_size,
+                output_size=self.fd_config.model_config.hidden_size,
+                with_bias=True,
+            )
+            self.fc2_latent_proj.weight.set_value(
+                paddle.zeros(self.fc2_latent_proj.weight.shape).cast(paddle.float8_e4m3fn)
+            )
+            self.fc2_latent_proj.bias.set_value(paddle.zeros(self.fc2_latent_proj.bias.shape))

         self.fused_moe = FusedMoE(
             fd_config=self.fd_config,
-            hidden_size=self.fd_config.model_config.hidden_size,
+            hidden_size=(
+                self.fd_config.model_config.moe_latent_size
+                if self.fd_config.model_config.use_latent_moe
+                else self.fd_config.model_config.hidden_size
+            ),
             moe_intermediate_size=self.fd_config.model_config.moe_intermediate_size,
             num_experts=self.fd_config.model_config.moe_num_experts,
             top_k=self.fd_config.model_config.moe_k,
@@ -517,8 +548,8 @@ class FuseMoEWrapper(paddle.nn.Layer):
             layer_idx=666,
             weight_key_map=weight_key_map,
             topk_method="noaux_tc",
-            topk_group=4,
-            n_group=8,
+            topk_group=0,
+            n_group=0,
             gate_correction_bias=paddle.zeros([self.fd_config.model_config.moe_num_experts], paddle.float32),
             # gate_correction_bias = gate_correction_bias_real_data
         )
@@ -558,11 +589,20 @@ class FuseMoEWrapper(paddle.nn.Layer):
 class TestFusedMoE(unittest.TestCase):
     def setUp(self) -> None:
         self.architectures = ["Ernie4_5_MoeForCausalLM"]
-        self.hidden_size = 4096
-        self.moe_intermediate_size = 2048
-        self.moe_num_experts = 64
-        self.moe_k = 8
-        self.num_layers = 2
+        self.hidden_size = 1536
+        self.moe_intermediate_size = 1024
+
+        self.moe_num_experts = 256
+        self.moe_k = 16
+        self.use_latent_moe = True
+        self.moe_latent_size = 768
+
+        # self.moe_num_experts = 128
+        # self.moe_k = 8
+        # self.use_latent_moe = False
+        # self.moe_latent_size = 768
+
+        self.num_layers = 50
         self.num_attention_heads = -1

         self.model_config = self.build_model_config()
@@ -584,6 +624,8 @@ class TestFusedMoE(unittest.TestCase):
             "moe_k": self.moe_k,
             "num_attention_heads": self.num_attention_heads,
             "dtype": "bfloat16",
+            "use_latent_moe": self.use_latent_moe,
+            "moe_latent_size": self.moe_latent_size,
         }

         tmp_dir = f"./tmpwedfewfef{paddle.distributed.get_rank()}"
@@ -635,9 +677,10 @@ class TestFusedMoE(unittest.TestCase):
             out = cache_hidden_states + cache_hidden_states
         else:
             gating = fused_moe[j % real_weight_layers].gating
-            gating.weight.set_value(paddle.rand(gating.weight.shape, dtype=paddle.float32))
+            fc1_latent_proj = fused_moe[j % real_weight_layers].fc1_latent_proj
+            fc2_latent_proj = fused_moe[j % real_weight_layers].fc2_latent_proj
             out = fused_moe[j % real_weight_layers].fused_moe(
-                cache_hidden_states[idx], gating, forward_meta=MockForwardMeta()
+                cache_hidden_states[idx], gating, MockForwardMeta(), None, fc1_latent_proj, fc2_latent_proj
             )
         return out
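For reference, the test wires the projections the same way a model layer would. The sketch below illustrates that calling convention only; it is not code from this patch. It assumes `fused_moe` is an already-constructed `FusedMoE` layer (built with `hidden_size=moe_latent_size`, as the test does in latent mode) and `gate` its routing gate, with plain `paddle.nn.Linear` layers standing in for the `ReplicatedLinear` projections built in `FuseMoEWrapper`.

```python
from paddle import nn

class LatentMoEBlock(nn.Layer):
    """Illustrative wrapper around an existing FusedMoE layer (hypothetical)."""

    def __init__(self, fused_moe, gate, hidden_size: int, moe_latent_size: int):
        super().__init__()
        self.fused_moe = fused_moe  # assumed: FusedMoE built with hidden_size=moe_latent_size
        self.gate = gate
        # Stand-ins for the ReplicatedLinear projections used in the test.
        self.fc1_latent_proj = nn.Linear(hidden_size, moe_latent_size)
        self.fc2_latent_proj = nn.Linear(moe_latent_size, hidden_size)

    def forward(self, x, forward_meta=None):
        # New positional signature: (x, gate, forward_meta, shared_experts,
        # fc1_latent_proj, fc2_latent_proj). In the block-wise FP8 Triton path
        # above, fc1 is applied before the expert kernels and fc2 after the
        # top-k weighted sum.
        return self.fused_moe(
            x, self.gate, forward_meta, None, self.fc1_latent_proj, self.fc2_latent_proj
        )
```

Note that in the hunks shown here only `BlockWiseFP8MoEMethod` forwards the two projections into the kernel; the other backends accept the new arguments but do not yet use them.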