[Feature] 添加 MoE 层 latent mode 支持 (#7382)

This commit is contained in:
周周周
2026-04-15 13:57:07 +08:00
committed by GitHub
parent f7a2418ce2
commit 5e54770b2e
5 changed files with 127 additions and 21 deletions
@@ -218,6 +218,8 @@ class MoEMethodBase(QuantMethodBase):
gate: nn.Layer, gate: nn.Layer,
topk_ids_hookfunc: Callable = None, topk_ids_hookfunc: Callable = None,
shared_experts: nn.Layer = None, shared_experts: nn.Layer = None,
fc1_latent_proj: nn.Layer = None,
fc2_latent_proj: nn.Layer = None,
) -> paddle.Tensor: ) -> paddle.Tensor:
""" """
Paddle Cutlass compute Fused MoE. Paddle Cutlass compute Fused MoE.
@@ -237,7 +239,7 @@ class MoEMethodBase(QuantMethodBase):
layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts
) )
else: else:
return self.apply_tp(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc) return self.apply_tp(layer, x, gate, topk_ids_hookfunc)
class UnquantizedFusedMoEMethod(MoEMethodBase): class UnquantizedFusedMoEMethod(MoEMethodBase):
@@ -292,6 +292,8 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
gate: nn.Layer, gate: nn.Layer,
topk_ids_hookfunc: Callable = None, topk_ids_hookfunc: Callable = None,
shared_experts: nn.Layer = None, shared_experts: nn.Layer = None,
fc1_latent_proj: nn.Layer = None,
fc2_latent_proj: nn.Layer = None,
) -> paddle.Tensor: ) -> paddle.Tensor:
""" """
Triton compute Fused MoE. Triton compute Fused MoE.
@@ -681,6 +683,8 @@ class Wfp8Afp8MoEMethod(QuantMethodBase):
gate: nn.Layer, gate: nn.Layer,
topk_ids_hookfunc: Callable = None, topk_ids_hookfunc: Callable = None,
shared_experts: nn.Layer = None, shared_experts: nn.Layer = None,
fc1_latent_proj: nn.Layer = None,
fc2_latent_proj: nn.Layer = None,
) -> paddle.Tensor: ) -> paddle.Tensor:
""" """
Triton compute Fused MoE. Triton compute Fused MoE.
@@ -980,6 +984,8 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
gate: nn.Layer, gate: nn.Layer,
topk_ids_hookfunc: Callable = None, topk_ids_hookfunc: Callable = None,
shared_experts: nn.Layer = None, shared_experts: nn.Layer = None,
fc1_latent_proj: nn.Layer = None,
fc2_latent_proj: nn.Layer = None,
) -> paddle.Tensor: ) -> paddle.Tensor:
""" """
Triton compute Fused MoE. Triton compute Fused MoE.
@@ -1174,6 +1180,9 @@ def python_op_fused_moe_kernel_paddle_infer_meta(
config: dict, config: dict,
quant_config, quant_config,
topk_ids_hookfunc, topk_ids_hookfunc,
layer,
fc1_latent_proj,
fc2_latent_proj,
): ):
token_num = x.shape[0] token_num = x.shape[0]
return paddle.static.MetaTensor(shape=[token_num, hidden_size], dtype=x.dtype) return paddle.static.MetaTensor(shape=[token_num, hidden_size], dtype=x.dtype)
@@ -1211,19 +1220,34 @@ def python_op_fused_moe_kernel_paddle(
config: dict, config: dict,
quant_config, quant_config,
topk_ids_hookfunc, topk_ids_hookfunc,
layer,
fc1_latent_proj,
fc2_latent_proj,
): ):
token_num = x.shape[0] token_num = x.shape[0]
if x.shape[0] == 0: if x.shape[0] == 0:
return paddle.zeros([token_num, hidden_size], dtype=x.dtype) return paddle.zeros([token_num, hidden_size], dtype=x.dtype)
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( if layer.topk_method == "noaux_tc":
gate_out, gate_out, topk_weights, topk_ids = get_moe_scores(
gate_correction_bias, gate_out,
top_k, layer.n_group,
True, # apply_norm_weight layer.topk_group,
False, layer.top_k,
) layer.routed_scaling_factor,
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
)
else:
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
gate_correction_bias,
top_k,
True, # apply_norm_weight
False,
)
if topk_ids_hookfunc is not None: if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_ids) topk_ids_hookfunc(topk_ids=topk_ids)
@@ -1244,6 +1268,9 @@ def python_op_fused_moe_kernel_paddle(
from .triton_moe_kernels import fused_moe_kernel_paddle from .triton_moe_kernels import fused_moe_kernel_paddle
if fc1_latent_proj is not None:
x = fc1_latent_proj(x)
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT: if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, quant_config.weight_block_size[0], False) x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, quant_config.weight_block_size[0], False)
else: else:
@@ -1357,6 +1384,9 @@ def python_op_fused_moe_kernel_paddle(
intermediate_cache3.reshape_([token_num, top_k, hidden_size]) intermediate_cache3.reshape_([token_num, top_k, hidden_size])
out = intermediate_cache3.sum(axis=1) out = intermediate_cache3.sum(axis=1)
if fc2_latent_proj is not None:
out = fc2_latent_proj(out)
return out return out
@@ -1808,6 +1838,8 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
gate: nn.Layer, gate: nn.Layer,
topk_ids_hookfunc: Callable = None, topk_ids_hookfunc: Callable = None,
shared_experts: nn.Layer = None, shared_experts: nn.Layer = None,
fc1_latent_proj: nn.Layer = None,
fc2_latent_proj: nn.Layer = None,
) -> paddle.Tensor: ) -> paddle.Tensor:
""" """
Triton compute Fused MoE. Triton compute Fused MoE.
@@ -1855,4 +1887,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
config, config,
self.quant_config, self.quant_config,
topk_ids_hookfunc, topk_ids_hookfunc,
layer,
fc1_latent_proj,
fc2_latent_proj,
) )
+23 -3
View File
@@ -709,7 +709,13 @@ class FusedMoE(nn.Layer):
return out return out
def forward( def forward(
self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta = None, shared_experts: nn.Layer = None self,
x: paddle.Tensor,
gate: nn.Layer,
forward_meta: ForwardMeta = None,
shared_experts: nn.Layer = None,
fc1_latent_proj: nn.Layer = None,
fc2_latent_proj: nn.Layer = None,
): ):
""" """
Defines the forward computation of the moe layer. Defines the forward computation of the moe layer.
@@ -762,7 +768,13 @@ class FusedMoE(nn.Layer):
) )
else: else:
out = self.forward_normal( out = self.forward_normal(
x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts x,
gate,
forward_meta,
topk_ids_hookfunc,
shared_experts,
fc1_latent_proj,
fc2_latent_proj,
) )
if self.reduce_results and self.tp_size > 1: if self.reduce_results and self.tp_size > 1:
@@ -829,6 +841,8 @@ class FusedMoE(nn.Layer):
forward_meta: ForwardMeta, forward_meta: ForwardMeta,
topk_ids_hookfunc: Callable = None, topk_ids_hookfunc: Callable = None,
shared_experts: nn.Layer = None, shared_experts: nn.Layer = None,
fc1_latent_proj: nn.Layer = None,
fc2_latent_proj: nn.Layer = None,
): ):
""" """
Normal mode of forward. Normal mode of forward.
@@ -842,7 +856,13 @@ class FusedMoE(nn.Layer):
""" """
if current_platform.is_cuda(): if current_platform.is_cuda():
out = self.quant_method.apply( out = self.quant_method.apply(
self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts self,
x,
gate,
topk_ids_hookfunc,
shared_experts,
fc1_latent_proj,
fc2_latent_proj,
) )
else: else:
out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc) out = self.quant_method.apply(self, x, gate, topk_ids_hookfunc=topk_ids_hookfunc)
@@ -509,6 +509,9 @@ class TestFusedMoeTritonBackend:
config, config,
quant_config, quant_config,
hook, hook,
layer,
None,
None,
) )
assert "topk" in captured assert "topk" in captured
@@ -530,6 +533,9 @@ class TestFusedMoeTritonBackend:
config, config,
quant_config, quant_config,
None, None,
layer,
None,
None,
) )
assert meta.shape == [2, layer.hidden_size] assert meta.shape == [2, layer.hidden_size]
+53 -10
View File
@@ -506,10 +506,41 @@ class FuseMoEWrapper(paddle.nn.Layer):
skip_quant=True, skip_quant=True,
weight_dtype="float32", weight_dtype="float32",
) )
self.gating.weight.set_value(paddle.rand(self.gating.weight.shape, dtype=paddle.float32))
self.fc1_latent_proj = None
self.fc2_latent_proj = None
if self.fd_config.model_config.use_latent_moe:
self.fc1_latent_proj = ReplicatedLinear(
fd_config=self.fd_config,
input_size=self.fd_config.model_config.hidden_size,
output_size=self.fd_config.model_config.moe_latent_size,
with_bias=True,
)
self.fc1_latent_proj.weight.set_value(
paddle.zeros(self.fc1_latent_proj.weight.shape).cast(paddle.float8_e4m3fn)
)
self.fc1_latent_proj.bias.set_value(paddle.zeros(self.fc1_latent_proj.bias.shape))
self.fc2_latent_proj = ReplicatedLinear(
fd_config=self.fd_config,
input_size=self.fd_config.model_config.moe_latent_size,
output_size=self.fd_config.model_config.hidden_size,
with_bias=True,
)
self.fc2_latent_proj.weight.set_value(
paddle.zeros(self.fc2_latent_proj.weight.shape).cast(paddle.float8_e4m3fn)
)
self.fc2_latent_proj.bias.set_value(paddle.zeros(self.fc2_latent_proj.bias.shape))
self.fused_moe = FusedMoE( self.fused_moe = FusedMoE(
fd_config=self.fd_config, fd_config=self.fd_config,
hidden_size=self.fd_config.model_config.hidden_size, hidden_size=(
self.fd_config.model_config.moe_latent_size
if self.fd_config.model_config.use_latent_moe
else self.fd_config.model_config.hidden_size
),
moe_intermediate_size=self.fd_config.model_config.moe_intermediate_size, moe_intermediate_size=self.fd_config.model_config.moe_intermediate_size,
num_experts=self.fd_config.model_config.moe_num_experts, num_experts=self.fd_config.model_config.moe_num_experts,
top_k=self.fd_config.model_config.moe_k, top_k=self.fd_config.model_config.moe_k,
@@ -517,8 +548,8 @@ class FuseMoEWrapper(paddle.nn.Layer):
layer_idx=666, layer_idx=666,
weight_key_map=weight_key_map, weight_key_map=weight_key_map,
topk_method="noaux_tc", topk_method="noaux_tc",
topk_group=4, topk_group=0,
n_group=8, n_group=0,
gate_correction_bias=paddle.zeros([self.fd_config.model_config.moe_num_experts], paddle.float32), gate_correction_bias=paddle.zeros([self.fd_config.model_config.moe_num_experts], paddle.float32),
# gate_correction_bias = gate_correction_bias_real_data # gate_correction_bias = gate_correction_bias_real_data
) )
@@ -558,11 +589,20 @@ class FuseMoEWrapper(paddle.nn.Layer):
class TestFusedMoE(unittest.TestCase): class TestFusedMoE(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.architectures = ["Ernie4_5_MoeForCausalLM"] self.architectures = ["Ernie4_5_MoeForCausalLM"]
self.hidden_size = 4096 self.hidden_size = 1536
self.moe_intermediate_size = 2048 self.moe_intermediate_size = 1024
self.moe_num_experts = 64
self.moe_k = 8 self.moe_num_experts = 256
self.num_layers = 2 self.moe_k = 16
self.use_latent_moe = True
self.moe_latent_size = 768
# self.moe_num_experts = 128
# self.moe_k = 8
# self.use_latent_moe = False
# self.moe_latent_size = 768
self.num_layers = 50
self.num_attention_heads = -1 self.num_attention_heads = -1
self.model_config = self.build_model_config() self.model_config = self.build_model_config()
@@ -584,6 +624,8 @@ class TestFusedMoE(unittest.TestCase):
"moe_k": self.moe_k, "moe_k": self.moe_k,
"num_attention_heads": self.num_attention_heads, "num_attention_heads": self.num_attention_heads,
"dtype": "bfloat16", "dtype": "bfloat16",
"use_latent_moe": self.use_latent_moe,
"moe_latent_size": self.moe_latent_size,
} }
tmp_dir = f"./tmpwedfewfef{paddle.distributed.get_rank()}" tmp_dir = f"./tmpwedfewfef{paddle.distributed.get_rank()}"
@@ -635,9 +677,10 @@ class TestFusedMoE(unittest.TestCase):
out = cache_hidden_states + cache_hidden_states out = cache_hidden_states + cache_hidden_states
else: else:
gating = fused_moe[j % real_weight_layers].gating gating = fused_moe[j % real_weight_layers].gating
gating.weight.set_value(paddle.rand(gating.weight.shape, dtype=paddle.float32)) fc1_latent_proj = fused_moe[j % real_weight_layers].fc1_latent_proj
fc2_latent_proj = fused_moe[j % real_weight_layers].fc2_latent_proj
out = fused_moe[j % real_weight_layers].fused_moe( out = fused_moe[j % real_weight_layers].fused_moe(
cache_hidden_states[idx], gating, forward_meta=MockForwardMeta() cache_hidden_states[idx], gating, MockForwardMeta(), None, fc1_latent_proj, fc2_latent_proj
) )
return out return out