[Models] Add forward_meta to moe models' forward function (#5138)

* [Models] Add forward_meta to moe models' forward function

* fix missing param

* fix

* fix

* fix forward_meta

* fix test and remove chunked MoE releated in config

* fix test

* fix

* fix
This commit is contained in:
Longzhi Wang
2025-12-04 13:26:58 +08:00
committed by GitHub
parent f5bdb36e9b
commit 5cd17fd662
21 changed files with 131 additions and 87 deletions
@@ -104,7 +104,7 @@ class DeepSeekV3MLP(nn.Layer):
self.up_gate_proj.load_state_dict(state_dict)
self.down_proj.load_state_dict(state_dict)
def forward(self, x):
def forward(self, x, forward_meta=None):
""" """
gate_up_out = self.up_gate_proj(x)
act_out = self.act_fn(gate_up_out)
@@ -187,10 +187,10 @@ class DeepSeekV3MoE(nn.Layer):
self.experts.load_state_dict(state_dict)
self.shared_experts.load_state_dict(state_dict)
def forward(self, hidden_states: paddle.Tensor):
def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta):
""" """
shared_experts_out = self.shared_experts(hidden_states)
moe_out = self.experts(hidden_states, self.gate)
moe_out = self.experts(hidden_states, self.gate, forward_meta)
moe_out = moe_out + shared_experts_out
# We do to TP all reduce after the sum of experts.
if self.tp_size > 1:
@@ -517,8 +517,7 @@ class DeepSeekV3DecoderLayer(nn.Layer):
hidden_states = self.self_attn(forward_meta, hidden_states, position_ids, mask_encoder_batch)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
hidden_states = self.mlp(hidden_states, forward_meta)
return hidden_states, residual
@@ -744,7 +743,7 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
)
return position_ids, mask_encoder_batch
def empty_input_forward(self):
def empty_input_forward(self, forward_meta):
"""
empty_input_forward
"""
@@ -756,7 +755,7 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
self.fd_config.model_config.first_k_dense_replace,
self.fd_config.model_config.num_hidden_layers,
):
self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate)
self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate, forward_meta)
def forward(
self,