[Models] Add forward_meta to moe models' forward function (#5138)

* [Models] Add forward_meta to moe models' forward function

* fix missing param

* fix

* fix

* fix forward_meta

* fix test and remove chunked MoE-related fields from config

* fix test

* fix

* fix
This commit is contained in:
Longzhi Wang
2025-12-04 13:26:58 +08:00
committed by GitHub
parent f5bdb36e9b
commit 5cd17fd662
21 changed files with 131 additions and 87 deletions
-2
View File
@@ -549,8 +549,6 @@ class ParallelConfig:
self.enable_expert_parallel = False self.enable_expert_parallel = False
self.enable_chunked_moe = False self.enable_chunked_moe = False
self.chunked_moe_size = 256 self.chunked_moe_size = 256
self.max_moe_num_chunk = 1
self.moe_num_chunk = 1
self.local_data_parallel_id = 0 self.local_data_parallel_id = 0
# Engine worker queue port # Engine worker queue port
@@ -143,6 +143,10 @@ class ForwardMeta:
# Flag of profile run # Flag of profile run
is_dummy_or_profile_run: bool = False is_dummy_or_profile_run: bool = False
# chunked MoE related
moe_num_chunk: int = 1
max_moe_num_chunk: int = 1
def clear_caches(self): def clear_caches(self):
"""Safely clean up the caches""" """Safely clean up the caches"""
if self.caches: if self.caches:
+9 -8
View File
@@ -25,6 +25,7 @@ from fastdeploy.distributed.communication import (
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
tensor_model_parallel_all_reduce_custom, tensor_model_parallel_all_reduce_custom,
) )
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.layers.utils import get_tensor from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.utils import h2d_copy, slice_fn from fastdeploy.model_executor.utils import h2d_copy, slice_fn
from fastdeploy.platforms import current_platform from fastdeploy.platforms import current_platform
@@ -621,7 +622,7 @@ class FusedMoE(nn.Layer):
return out return out
def forward(self, x: paddle.Tensor, gate: nn.Layer): def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta):
""" """
Defines the forward computation of the moe layer. Defines the forward computation of the moe layer.
@@ -641,7 +642,7 @@ class FusedMoE(nn.Layer):
): ):
out = self.forward_split_allgather(x, gate) out = self.forward_split_allgather(x, gate)
elif self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.enable_chunked_moe: elif self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.enable_chunked_moe:
out = self.forward_chunked_moe(x, gate) out = self.forward_chunked_moe(x, gate, forward_meta)
else: else:
out = self.forward_normal(x, gate) out = self.forward_normal(x, gate)
@@ -652,7 +653,7 @@ class FusedMoE(nn.Layer):
out = tensor_model_parallel_all_reduce(out, self.tp_group) out = tensor_model_parallel_all_reduce(out, self.tp_group)
return out return out
def forward_chunked_moe(self, x: paddle.Tensor, gate: nn.Layer): def forward_chunked_moe(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta):
""" """
Split input to multi chunk to reduce the memory usage of moe. Split input to multi chunk to reduce the memory usage of moe.
@@ -671,11 +672,11 @@ class FusedMoE(nn.Layer):
# input size that are less than a chunk, less than the max size data or empty input # input size that are less than a chunk, less than the max size data or empty input
# need to be repeated until the max chunk data infer MOE finished. # need to be repeated until the max chunk data infer MOE finished.
if token_num > chunk_size: # chunked moe if token_num > chunk_size: # chunked moe
x_split_list = paddle.tensor_split(x, self.fd_config.parallel_config.moe_num_chunk, axis=0) x_split_list = paddle.tensor_split(x, forward_meta.moe_num_chunk, axis=0)
out_split_list = [None] * self.fd_config.parallel_config.moe_num_chunk out_split_list = [None] * forward_meta.moe_num_chunk
for i in range(self.fd_config.parallel_config.max_moe_num_chunk): for i in range(forward_meta.max_moe_num_chunk):
if i < self.fd_config.parallel_config.moe_num_chunk: if i < forward_meta.moe_num_chunk:
out_split_list[i] = self.quant_method.apply(self, x_split_list[i], gate) out_split_list[i] = self.quant_method.apply(self, x_split_list[i], gate)
else: else:
# just need to use real data to infer max_moe_num_chunk times. # just need to use real data to infer max_moe_num_chunk times.
@@ -685,7 +686,7 @@ class FusedMoE(nn.Layer):
else: else:
# when only one chunk, just need to use real data to infer once. # when only one chunk, just need to use real data to infer once.
out = self.quant_method.apply(self, x, gate) out = self.quant_method.apply(self, x, gate)
for i in range(self.fd_config.parallel_config.max_moe_num_chunk - 1): for i in range(forward_meta.max_moe_num_chunk - 1):
self.quant_method.apply(self, fake_x, gate) self.quant_method.apply(self, fake_x, gate)
return out return out
@@ -104,7 +104,7 @@ class DeepSeekV3MLP(nn.Layer):
self.up_gate_proj.load_state_dict(state_dict) self.up_gate_proj.load_state_dict(state_dict)
self.down_proj.load_state_dict(state_dict) self.down_proj.load_state_dict(state_dict)
def forward(self, x): def forward(self, x, forward_meta=None):
""" """ """ """
gate_up_out = self.up_gate_proj(x) gate_up_out = self.up_gate_proj(x)
act_out = self.act_fn(gate_up_out) act_out = self.act_fn(gate_up_out)
@@ -187,10 +187,10 @@ class DeepSeekV3MoE(nn.Layer):
self.experts.load_state_dict(state_dict) self.experts.load_state_dict(state_dict)
self.shared_experts.load_state_dict(state_dict) self.shared_experts.load_state_dict(state_dict)
def forward(self, hidden_states: paddle.Tensor): def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta):
""" """ """ """
shared_experts_out = self.shared_experts(hidden_states) shared_experts_out = self.shared_experts(hidden_states)
moe_out = self.experts(hidden_states, self.gate) moe_out = self.experts(hidden_states, self.gate, forward_meta)
moe_out = moe_out + shared_experts_out moe_out = moe_out + shared_experts_out
# We do to TP all reduce after the sum of experts. # We do to TP all reduce after the sum of experts.
if self.tp_size > 1: if self.tp_size > 1:
@@ -517,8 +517,7 @@ class DeepSeekV3DecoderLayer(nn.Layer):
hidden_states = self.self_attn(forward_meta, hidden_states, position_ids, mask_encoder_batch) hidden_states = self.self_attn(forward_meta, hidden_states, position_ids, mask_encoder_batch)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states, forward_meta)
return hidden_states, residual return hidden_states, residual
@@ -744,7 +743,7 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
) )
return position_ids, mask_encoder_batch return position_ids, mask_encoder_batch
def empty_input_forward(self): def empty_input_forward(self, forward_meta):
""" """
empty_input_forward empty_input_forward
""" """
@@ -756,7 +755,7 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
self.fd_config.model_config.first_k_dense_replace, self.fd_config.model_config.first_k_dense_replace,
self.fd_config.model_config.num_hidden_layers, self.fd_config.model_config.num_hidden_layers,
): ):
self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate) self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate, forward_meta)
def forward( def forward(
self, self,
@@ -94,7 +94,7 @@ class Ernie4_5_MLP(nn.Layer):
self.up_gate_proj.load_state_dict(state_dict) self.up_gate_proj.load_state_dict(state_dict)
self.down_proj.load_state_dict(state_dict) self.down_proj.load_state_dict(state_dict)
def forward(self, hidden_states: paddle.Tensor): def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta = None):
gate_up_out = self.up_gate_proj(hidden_states) gate_up_out = self.up_gate_proj(hidden_states)
act_out = self.act_fn(gate_up_out) act_out = self.act_fn(gate_up_out)
down_out = self.down_proj(act_out) down_out = self.down_proj(act_out)
@@ -213,8 +213,16 @@ class Ernie4_5_MoE(nn.Layer):
def update_state_dict(self, state_dict): def update_state_dict(self, state_dict):
self.experts.load_state_dict(state_dict, True) self.experts.load_state_dict(state_dict, True)
def forward(self, hidden_states: paddle.Tensor): def forward(
out = self.experts(hidden_states, self.gate) self,
hidden_states: paddle.Tensor,
forward_meta: ForwardMeta,
):
out = self.experts(
x=hidden_states,
gate=self.gate,
forward_meta=forward_meta,
)
if self.num_shared_experts > 0: if self.num_shared_experts > 0:
s_x = self.shared_experts(hidden_states) s_x = self.shared_experts(hidden_states)
out = out + s_x out = out + s_x
@@ -344,7 +352,10 @@ class Ernie4_5_DecoderLayer(nn.Layer):
residual, residual,
) )
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(
hidden_states=hidden_states,
forward_meta=forward_meta,
)
return hidden_states, residual return hidden_states, residual
@@ -611,7 +622,7 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
return logits return logits
def empty_input_forward(self): def empty_input_forward(self, forward_meta):
""" """
empty_input_forward empty_input_forward
""" """
@@ -623,7 +634,7 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
self.fd_config.model_config.moe_layer_start_index, self.fd_config.model_config.moe_layer_start_index,
self.fd_config.model_config.num_hidden_layers, self.fd_config.model_config.num_hidden_layers,
): ):
self.ernie.layers[i].mlp.experts(fake_hidden_states, self.ernie.layers[i].mlp.gate) self.ernie.layers[i].mlp.experts(fake_hidden_states, self.ernie.layers[i].mlp.gate, forward_meta)
def forward( def forward(
self, self,
@@ -436,7 +436,7 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
return logits return logits
def empty_input_forward(self): def empty_input_forward(self, forward_meta):
""" """
empty_input_forward empty_input_forward
""" """
@@ -448,7 +448,7 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
self.fd_config.model_config.moe_layer_start_index, self.fd_config.model_config.moe_layer_start_index,
self.fd_config.model_config.num_hidden_layers, self.fd_config.model_config.num_hidden_layers,
): ):
self.ernie.layers[i].mlp.fused_moe(fake_hidden_states) self.ernie.layers[i].mlp.fused_moe(hidden_states=fake_hidden_states, forward_meta=forward_meta)
def forward( def forward(
self, self,
@@ -170,8 +170,8 @@ class Ernie4_5_VLMoeBlock(nn.Layer):
model_format="", model_format="",
) )
def forward(self, hidden_states: paddle.Tensor): def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta):
out = self.experts(hidden_states, self.gate) out = self.experts(hidden_states, self.gate, forward_meta)
return out return out
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict):
@@ -270,7 +270,7 @@ class Ernie4_5_VLMoE(nn.Layer):
if self.num_shared_experts > 0: if self.num_shared_experts > 0:
self.shared_experts.load_state_dict(state_dict) self.shared_experts.load_state_dict(state_dict)
def forward(self, hidden_states: paddle.Tensor, vl_moe_meta: VLMoEMeta): def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta, vl_moe_meta: VLMoEMeta):
if self.num_shared_experts > 0: if self.num_shared_experts > 0:
shared_experts_out = self.shared_experts(hidden_states) shared_experts_out = self.shared_experts(hidden_states)
hidden_states, text_input, image_input = text_image_gather_scatter( hidden_states, text_input, image_input = text_image_gather_scatter(
@@ -282,8 +282,8 @@ class Ernie4_5_VLMoE(nn.Layer):
vl_moe_meta.image_index, vl_moe_meta.image_index,
True, True,
) )
text_out = self.text_fused_moe(text_input) text_out = self.text_fused_moe(text_input, forward_meta)
image_out = self.image_fused_moe(image_input) image_out = self.image_fused_moe(image_input, forward_meta)
hidden_states, _, _ = text_image_gather_scatter( hidden_states, _, _ = text_image_gather_scatter(
hidden_states, hidden_states,
text_out, text_out,
@@ -395,9 +395,9 @@ class Ernie4_5_VLDecoderLayer(nn.Layer):
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
if isinstance(self.mlp, Ernie4_5_VLMoE): if isinstance(self.mlp, Ernie4_5_VLMoE):
hidden_states = self.mlp(hidden_states, vl_moe_meta) hidden_states = self.mlp(hidden_states, forward_meta, vl_moe_meta)
else: else:
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states, forward_meta)
return hidden_states, residual return hidden_states, residual
@@ -754,7 +754,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
return logits return logits
def empty_input_forward(self): def empty_input_forward(self, forward_meta):
""" """
empty_input_forward empty_input_forward
""" """
@@ -766,8 +766,8 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
self.fd_config.model_config.moe_layer_start_index, self.fd_config.model_config.moe_layer_start_index,
self.fd_config.model_config.num_hidden_layers, self.fd_config.model_config.num_hidden_layers,
): ):
self.ernie.layers[i].mlp.text_fused_moe(fake_hidden_states) self.ernie.layers[i].mlp.text_fused_moe(fake_hidden_states, forward_meta)
self.ernie.layers[i].mlp.image_fused_moe(fake_hidden_states) self.ernie.layers[i].mlp.image_fused_moe(fake_hidden_states, forward_meta)
def get_input_embeddings( def get_input_embeddings(
self, self,
+9 -6
View File
@@ -85,7 +85,7 @@ class Glm4MoeMLP(nn.Layer):
act_method=fd_config.model_config.hidden_act, act_method=fd_config.model_config.hidden_act,
) )
def forward(self, x): def forward(self, x, forward_meta=None):
""" """ """ """
gate_up_out = self.up_gate_proj(x) gate_up_out = self.up_gate_proj(x)
act_out = self.act_fn(gate_up_out) act_out = self.act_fn(gate_up_out)
@@ -161,9 +161,9 @@ class Glm4Moe(nn.Layer):
reduce_results=False, reduce_results=False,
) )
def forward(self, x): def forward(self, x, forward_meta):
shared_experts_out = self.shared_experts(x) shared_experts_out = self.shared_experts(x)
out = self.experts(x, self.gate) out = self.experts(x, self.gate, forward_meta)
out = out + shared_experts_out out = out + shared_experts_out
# We do to TP all reduce after the sum of experts. # We do to TP all reduce after the sum of experts.
if self.tensor_parallel_size > 1: if self.tensor_parallel_size > 1:
@@ -306,7 +306,10 @@ class Glm4MoeDecoderLayer(nn.Layer):
# Fully Connected # Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(
hidden_states,
forward_meta,
)
return hidden_states, residual return hidden_states, residual
@@ -494,7 +497,7 @@ class Glm4MoeForCausalLM(ModelForCasualLM):
return logits return logits
def empty_input_forward(self): def empty_input_forward(self, forward_meta):
""" """
empty_input_forward empty_input_forward
""" """
@@ -506,7 +509,7 @@ class Glm4MoeForCausalLM(ModelForCasualLM):
self.fd_config.model_config.first_k_dense_replace, self.fd_config.model_config.first_k_dense_replace,
self.fd_config.model_config.num_hidden_layers, self.fd_config.model_config.num_hidden_layers,
): ):
self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate) self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate, forward_meta)
def forward( def forward(
self, self,
+3 -3
View File
@@ -124,8 +124,8 @@ class GptOssMoe(nn.Layer):
model_format="", model_format="",
) )
def forward(self, hidden_states: paddle.Tensor): def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta):
expert_output = self.experts(hidden_states, self.router) expert_output = self.experts(hidden_states, self.router, forward_meta)
return expert_output return expert_output
@@ -173,7 +173,7 @@ class GptOssDecoderLayer(nn.Layer):
# Fully Connected # Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states, forward_meta)
return hidden_states, residual return hidden_states, residual
+2 -2
View File
@@ -89,7 +89,7 @@ class Qwen2MLP(nn.Layer):
self.up_gate_proj.load_state_dict(state_dict) self.up_gate_proj.load_state_dict(state_dict)
self.down_proj.load_state_dict(state_dict) self.down_proj.load_state_dict(state_dict)
def forward(self, x): def forward(self, x, forward_meta):
""" """ """ """
gate_up_out = self.up_gate_proj(x) gate_up_out = self.up_gate_proj(x)
act_out = self.act_fn(gate_up_out) act_out = self.act_fn(gate_up_out)
@@ -205,7 +205,7 @@ class Qwen2DecoderLayer(nn.Layer):
# Fully Connected # Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states, forward_meta)
return hidden_states, residual return hidden_states, residual
+6 -6
View File
@@ -79,8 +79,8 @@ class Qwen3MoeBlock(nn.Layer):
weight_dtype="float32", weight_dtype="float32",
) )
def forward(self, x): def forward(self, x, forward_meta):
return self.experts(x, self.gate) return self.experts(x, self.gate, forward_meta)
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict):
""" """ """ """
@@ -125,7 +125,7 @@ class Qwen3MLP(nn.Layer):
self.up_gate_proj.load_state_dict(state_dict) self.up_gate_proj.load_state_dict(state_dict)
self.down_proj.load_state_dict(state_dict) self.down_proj.load_state_dict(state_dict)
def forward(self, x): def forward(self, x, forward_meta):
""" """ """ """
gate_up_out = self.up_gate_proj(x) gate_up_out = self.up_gate_proj(x)
act_out = self.act_fn(gate_up_out) act_out = self.act_fn(gate_up_out)
@@ -204,7 +204,7 @@ class Qwen3DecoderLayer(nn.Layer):
# Fully Connected # Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states, forward_meta)
return hidden_states, residual return hidden_states, residual
@@ -416,7 +416,7 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
return logits return logits
def empty_input_forward(self): def empty_input_forward(self, forward_meta):
""" """
empty_input_forward empty_input_forward
""" """
@@ -428,7 +428,7 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
self.fd_config.model_config.moe_layer_start_index, self.fd_config.model_config.moe_layer_start_index,
self.fd_config.model_config.num_hidden_layers, self.fd_config.model_config.num_hidden_layers,
): ):
self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate) self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate, forward_meta)
def forward( def forward(
self, self,
+2 -2
View File
@@ -988,7 +988,7 @@ class MTPProposer(Proposer):
self._get_self_hidden_states(hidden_states) self._get_self_hidden_states(hidden_states)
else: else:
if hasattr(self.model, "empty_input_forward"): if hasattr(self.model, "empty_input_forward"):
self.model.empty_input_forward() self.model.empty_input_forward(forward_meta=self.forward_meta)
def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False): def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
""" """
@@ -1078,7 +1078,7 @@ class MTPProposer(Proposer):
self._get_self_hidden_states(hidden_states) self._get_self_hidden_states(hidden_states)
else: else:
if hasattr(self.model, "empty_input_forward"): if hasattr(self.model, "empty_input_forward"):
self.model.empty_input_forward() self.model.empty_input_forward(self.forward_meta)
def _get_self_hidden_states(self, hidden_states): def _get_self_hidden_states(self, hidden_states):
target_hidden_states = eagle_get_self_hidden_states( target_hidden_states = eagle_get_self_hidden_states(
+3 -3
View File
@@ -971,7 +971,7 @@ class GCUModelRunner(ModelRunnerBase):
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode, # This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
# when there is data on other runner, the current runner is required to execute part of the model. # when there is data on other runner, the current runner is required to execute part of the model.
if not self.not_need_stop(): if not self.not_need_stop():
self._execute_empty_input() self._execute_empty_input(self.forward_meta)
return None return None
# 1. Prepare inputs of model and sampler. # 1. Prepare inputs of model and sampler.
@@ -1088,14 +1088,14 @@ class GCUModelRunner(ModelRunnerBase):
self.seq_lens_this_time_buffer.copy_(self.share_inputs["seq_lens_this_time"], False) self.seq_lens_this_time_buffer.copy_(self.share_inputs["seq_lens_this_time"], False)
return None return None
def _execute_empty_input(self) -> None: def _execute_empty_input(self, forward_meta) -> None:
""" """
In certain scenarios, such as during EP, In certain scenarios, such as during EP,
the runner needs to execute partial modules of the model without input data. the runner needs to execute partial modules of the model without input data.
This requires the model to implement the `empty_input_forward` method. This requires the model to implement the `empty_input_forward` method.
""" """
if hasattr(self.model, "empty_input_forward"): if hasattr(self.model, "empty_input_forward"):
self.model.empty_input_forward() self.model.empty_input_forward(forward_meta)
else: else:
raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward") raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward")
+7 -7
View File
@@ -275,11 +275,11 @@ class GPUModelRunner(ModelRunnerBase):
token_num = self.share_inputs["ids_remove_padding"].shape[0] token_num = self.share_inputs["ids_remove_padding"].shape[0]
if token_num > chunk_size: if token_num > chunk_size:
self.fd_config.parallel_config.moe_num_chunk = (token_num + chunk_size - 1) // chunk_size self.forward_meta.moe_num_chunk = (token_num + chunk_size - 1) // chunk_size
else: else:
self.fd_config.parallel_config.moe_num_chunk = 1 self.forward_meta.moe_num_chunk = 1
dist_status_obj.moe_num_chunk = self.fd_config.parallel_config.moe_num_chunk dist_status_obj.moe_num_chunk = self.forward_meta.moe_num_chunk
# only ep need to collect and sync distributed status # only ep need to collect and sync distributed status
if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed": if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
@@ -1448,7 +1448,7 @@ class GPUModelRunner(ModelRunnerBase):
if_only_decode = dist_status.if_only_decode if_only_decode = dist_status.if_only_decode
if self.fd_config.parallel_config.enable_chunked_moe: if self.fd_config.parallel_config.enable_chunked_moe:
self.fd_config.parallel_config.max_moe_num_chunk = dist_status.max_moe_num_chunk self.forward_meta.max_moe_num_chunk = dist_status.max_moe_num_chunk
only_decode_use_cudagraph = self.use_cudagraph and if_only_decode only_decode_use_cudagraph = self.use_cudagraph and if_only_decode
@@ -2202,7 +2202,7 @@ class GPUModelRunner(ModelRunnerBase):
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode, # This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
# when there is data on other runner, the current runner is required to execute part of the model. # when there is data on other runner, the current runner is required to execute part of the model.
if not self.not_need_stop(): if not self.not_need_stop():
self._execute_empty_input() self._execute_empty_input(self.forward_meta)
return None return None
# 2. Padding inputs for cuda graph # 2. Padding inputs for cuda graph
@@ -2473,14 +2473,14 @@ class GPUModelRunner(ModelRunnerBase):
return pooler_output return pooler_output
def _execute_empty_input(self) -> None: def _execute_empty_input(self, forward_meta) -> None:
""" """
In certain scenarios, such as during EP, In certain scenarios, such as during EP,
the runner needs to execute partial modules of the model without input data. the runner needs to execute partial modules of the model without input data.
This requires the model to implement the `empty_input_forward` method. This requires the model to implement the `empty_input_forward` method.
""" """
if hasattr(self.model, "empty_input_forward"): if hasattr(self.model, "empty_input_forward"):
self.model.empty_input_forward() self.model.empty_input_forward(forward_meta)
else: else:
raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward") raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward")
+2 -2
View File
@@ -1361,14 +1361,14 @@ class HPUModelRunner(ModelRunnerBase):
self.prof.step() self.prof.step()
return None return None
def _execute_empty_input(self) -> None: def _execute_empty_input(self, forward_meta) -> None:
""" """
In certain scenarios, such as during EP, In certain scenarios, such as during EP,
the runner needs to execute partial modules of the model without input data. the runner needs to execute partial modules of the model without input data.
This requires the model to implement the `empty_input_forward` method. This requires the model to implement the `empty_input_forward` method.
""" """
if hasattr(self.model, "empty_input_forward"): if hasattr(self.model, "empty_input_forward"):
self.model.empty_input_forward() self.model.empty_input_forward(forward_meta)
else: else:
raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward") raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward")
+3 -3
View File
@@ -1812,7 +1812,7 @@ class MetaxModelRunner(ModelRunnerBase):
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode, # This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
# when there is data on other runner, the current runner is required to execute part of the model. # when there is data on other runner, the current runner is required to execute part of the model.
if not self.not_need_stop(): if not self.not_need_stop():
self._execute_empty_input() self._execute_empty_input(self.forward_meta)
return None return None
# 2. Padding inputs for cuda graph # 2. Padding inputs for cuda graph
@@ -1998,14 +1998,14 @@ class MetaxModelRunner(ModelRunnerBase):
) )
return None return None
def _execute_empty_input(self) -> None: def _execute_empty_input(self, forward_meta) -> None:
""" """
In certain scenarios, such as during EP, In certain scenarios, such as during EP,
the runner needs to execute partial modules of the model without input data. the runner needs to execute partial modules of the model without input data.
This requires the model to implement the `empty_input_forward` method. This requires the model to implement the `empty_input_forward` method.
""" """
if hasattr(self.model, "empty_input_forward"): if hasattr(self.model, "empty_input_forward"):
self.model.empty_input_forward() self.model.empty_input_forward(forward_meta)
else: else:
raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward") raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward")
+3 -3
View File
@@ -1227,7 +1227,7 @@ class XPUModelRunner(ModelRunnerBase):
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode, # This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
# when there is data on other runner, the current runner is required to execute part of the model. # when there is data on other runner, the current runner is required to execute part of the model.
if not self.not_need_stop() and not is_dummy_run: if not self.not_need_stop() and not is_dummy_run:
self._execute_empty_input() self._execute_empty_input(self.forward_meta)
return None return None
# 2. Padding inputs for cuda grph # 2. Padding inputs for cuda grph
@@ -1323,14 +1323,14 @@ class XPUModelRunner(ModelRunnerBase):
destroy_kv_signal_sender(self.kv_signal_sender) destroy_kv_signal_sender(self.kv_signal_sender)
return None return None
def _execute_empty_input(self) -> None: def _execute_empty_input(self, forward_meta) -> None:
""" """
In certain scenarios, such as during EP, In certain scenarios, such as during EP,
the runner needs to execute partial modules of the model without input data. the runner needs to execute partial modules of the model without input data.
This requires the model to implement the `empty_input_forward` method. This requires the model to implement the `empty_input_forward` method.
""" """
if hasattr(self.model, "empty_input_forward"): if hasattr(self.model, "empty_input_forward"):
self.model.empty_input_forward() self.model.empty_input_forward(forward_meta)
else: else:
raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward") raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward")
+16 -11
View File
@@ -28,6 +28,13 @@ class MockStructuredOutputsConfig:
logits_processors = [] logits_processors = []
class MockForwardMeta:
def __init__(self):
# chunked MoE related.
self.moe_num_chunk = 1
self.max_moe_num_chunk = 1
class MockModelConfig: class MockModelConfig:
max_model_len = 10 max_model_len = 10
pad_token_id = 0 pad_token_id = 0
@@ -60,8 +67,6 @@ class MockFDConfig:
enable_expert_parallel = True enable_expert_parallel = True
enable_chunked_moe = True enable_chunked_moe = True
chunked_moe_size = 2 chunked_moe_size = 2
max_moe_num_chunk = 1
moe_num_chunk = 1
use_ep = True use_ep = True
use_sequence_parallel_moe = False use_sequence_parallel_moe = False
@@ -148,19 +153,19 @@ class TestChunkedMoE(unittest.TestCase):
def run_model_runner(self): def run_model_runner(self):
self.model_runner.initialize_forward_meta() self.model_runner.initialize_forward_meta()
assert self.model_runner.fd_config.parallel_config.max_moe_num_chunk == 5, ( assert self.model_runner.forward_meta.max_moe_num_chunk == 5, (
f"chunk size is 2, max token_num is 10, max_moe_num_chunk should be 5, " f"chunk size is 2, max token_num is 10, max_moe_num_chunk should be 5, "
f"but got {self.model_runner.fd_config.parallel_config.max_moe_num_chunk}" f"but got {self.model_runner.forward_meta.max_moe_num_chunk }"
) )
if dist.get_rank() == 0: if dist.get_rank() == 0:
assert self.model_runner.fd_config.parallel_config.moe_num_chunk == 5, ( assert self.model_runner.forward_meta.moe_num_chunk == 5, (
f"chunk size is 2, token_num is 10, moe_num_chunk in rank 0 should be 5" f"chunk size is 2, token_num is 10, moe_num_chunk in rank 0 should be 5, "
f"but got {self.model_runner.fd_config.parallel_config.moe_num_chunk}" f"but got {self.model_runner.forward_meta.moe_num_chunk}"
) )
else: else:
assert self.model_runner.fd_config.parallel_config.moe_num_chunk == 1, ( assert self.model_runner.forward_meta.moe_num_chunk == 1, (
f"chunk size is 2, token_num is 1, moe_num_chunk in rank 1 should be 1" f"chunk size is 2, token_num is 1, moe_num_chunk in rank 1 should be 1, "
f", but got {self.model_runner.fd_config.parallel_config.moe_num_chunk}" f"but got {self.model_runner.forward_meta.moe_num_chunk}"
) )
def run_fused_moe(self): def run_fused_moe(self):
@@ -170,7 +175,7 @@ class TestChunkedMoE(unittest.TestCase):
else: else:
x = paddle.ones([1]) x = paddle.ones([1])
out = self.fused_moe.forward(x, gate) out = self.fused_moe.forward(x, gate, MockForwardMeta())
assert out.shape == x.shape assert out.shape == x.shape
def test_case(self): def test_case(self):
+9 -2
View File
@@ -44,6 +44,13 @@ if "nvidia graphics device" in paddle.device.cuda.get_device_name().lower():
os.environ.setdefault("DG_NVCC_OVERRIDE_CPP_STANDARD", "17") os.environ.setdefault("DG_NVCC_OVERRIDE_CPP_STANDARD", "17")
class MockForwardMeta:
def __init__(self):
# chunked MoE related.
self.moe_num_chunk = 1
self.max_moe_num_chunk = 1
class FFNWrapper(paddle.nn.Layer): class FFNWrapper(paddle.nn.Layer):
def __init__(self, model_config: ModelConfig): def __init__(self, model_config: ModelConfig):
super().__init__() super().__init__()
@@ -134,7 +141,7 @@ class TestFusedMoE(unittest.TestCase):
init_distributed_environment() init_distributed_environment()
ffn = FFNWrapper(self.model_config) ffn = FFNWrapper(self.model_config)
forward_meta = MockForwardMeta()
moe_cuda_graphs = [None] * 100 moe_cuda_graphs = [None] * 100
cache_hidden_states = [None] * 100 cache_hidden_states = [None] * 100
test_token_nums = [10, 20, 40, 60, 80, 100, 128, 160, 192, 256, 4096, 4096 * 4] test_token_nums = [10, 20, 40, 60, 80, 100, 128, 160, 192, 256, 4096, 4096 * 4]
@@ -147,7 +154,7 @@ class TestFusedMoE(unittest.TestCase):
num_layers = self.num_layers num_layers = self.num_layers
for _ in range(num_layers): for _ in range(num_layers):
out = ffn.ffn(cache_hidden_states[idx]) out = ffn.ffn(cache_hidden_states[idx], forward_meta=forward_meta)
moe_cuda_graphs[idx].capture_end() moe_cuda_graphs[idx].capture_end()
+10 -1
View File
@@ -432,6 +432,13 @@ gate_correction_bias_real_data = paddle.to_tensor(
) )
class MockForwardMeta:
def __init__(self):
# chunked MoE related.
self.moe_num_chunk = 1
self.max_moe_num_chunk = 1
class FuseMoEWrapper(paddle.nn.Layer): class FuseMoEWrapper(paddle.nn.Layer):
def __init__( def __init__(
self, self,
@@ -607,7 +614,9 @@ class TestFusedMoE(unittest.TestCase):
def fake_model_run(): def fake_model_run():
for j in range(num_layers): for j in range(num_layers):
out = fused_moe[j % real_weight_layers].fused_moe(cache_hidden_states[idx], gating) out = fused_moe[j % real_weight_layers].fused_moe(
cache_hidden_states[idx], gating, forward_meta=MockForwardMeta()
)
return out return out
+8 -1
View File
@@ -29,6 +29,13 @@ from fastdeploy.config import (
) )
class MockForwardMeta:
def __init__(self):
# chunked MoE related.
self.moe_num_chunk = 1
self.max_moe_num_chunk = 1
class FakeModelConfig: class FakeModelConfig:
def __init__(self): def __init__(self):
self.hidden_size = 768 self.hidden_size = 768
@@ -85,7 +92,7 @@ class OpPerformanceTester:
def _fake_model_run(self, x): def _fake_model_run(self, x):
for j in range(self.num_layers): for j in range(self.num_layers):
if self.gate: if self.gate:
out = self.op_fn(x, self.gate) out = self.op_fn(x, self.gate, forward_meta=MockForwardMeta())
else: else:
out = self.op_fn(x) out = self.op_fn(x)
return out return out