mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Models] Add forward_meta to moe models' forward function (#5138)
* [Models] Add forward_meta to moe models' forward function
* fix missing param
* fix
* fix
* fix forward_meta
* fix test and remove chunked MoE related in config
* fix test
* fix
* fix
This commit is contained in:
@@ -549,8 +549,6 @@ class ParallelConfig:
|
|||||||
self.enable_expert_parallel = False
|
self.enable_expert_parallel = False
|
||||||
self.enable_chunked_moe = False
|
self.enable_chunked_moe = False
|
||||||
self.chunked_moe_size = 256
|
self.chunked_moe_size = 256
|
||||||
self.max_moe_num_chunk = 1
|
|
||||||
self.moe_num_chunk = 1
|
|
||||||
|
|
||||||
self.local_data_parallel_id = 0
|
self.local_data_parallel_id = 0
|
||||||
# Engine worker queue port
|
# Engine worker queue port
|
||||||
|
|||||||
@@ -143,6 +143,10 @@ class ForwardMeta:
|
|||||||
# Flag of profile run
|
# Flag of profile run
|
||||||
is_dummy_or_profile_run: bool = False
|
is_dummy_or_profile_run: bool = False
|
||||||
|
|
||||||
|
# chunked MoE related
|
||||||
|
moe_num_chunk: int = 1
|
||||||
|
max_moe_num_chunk: int = 1
|
||||||
|
|
||||||
def clear_caches(self):
|
def clear_caches(self):
|
||||||
"""Safely clean up the caches"""
|
"""Safely clean up the caches"""
|
||||||
if self.caches:
|
if self.caches:
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from fastdeploy.distributed.communication import (
|
|||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
tensor_model_parallel_all_reduce_custom,
|
tensor_model_parallel_all_reduce_custom,
|
||||||
)
|
)
|
||||||
|
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||||
from fastdeploy.model_executor.utils import h2d_copy, slice_fn
|
from fastdeploy.model_executor.utils import h2d_copy, slice_fn
|
||||||
from fastdeploy.platforms import current_platform
|
from fastdeploy.platforms import current_platform
|
||||||
@@ -621,7 +622,7 @@ class FusedMoE(nn.Layer):
|
|||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def forward(self, x: paddle.Tensor, gate: nn.Layer):
|
def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta):
|
||||||
"""
|
"""
|
||||||
Defines the forward computation of the moe layer.
|
Defines the forward computation of the moe layer.
|
||||||
|
|
||||||
@@ -641,7 +642,7 @@ class FusedMoE(nn.Layer):
|
|||||||
):
|
):
|
||||||
out = self.forward_split_allgather(x, gate)
|
out = self.forward_split_allgather(x, gate)
|
||||||
elif self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.enable_chunked_moe:
|
elif self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.enable_chunked_moe:
|
||||||
out = self.forward_chunked_moe(x, gate)
|
out = self.forward_chunked_moe(x, gate, forward_meta)
|
||||||
else:
|
else:
|
||||||
out = self.forward_normal(x, gate)
|
out = self.forward_normal(x, gate)
|
||||||
|
|
||||||
@@ -652,7 +653,7 @@ class FusedMoE(nn.Layer):
|
|||||||
out = tensor_model_parallel_all_reduce(out, self.tp_group)
|
out = tensor_model_parallel_all_reduce(out, self.tp_group)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def forward_chunked_moe(self, x: paddle.Tensor, gate: nn.Layer):
|
def forward_chunked_moe(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta):
|
||||||
"""
|
"""
|
||||||
Split input to multi chunk to reduce the memory usage of moe.
|
Split input to multi chunk to reduce the memory usage of moe.
|
||||||
|
|
||||||
@@ -671,11 +672,11 @@ class FusedMoE(nn.Layer):
|
|||||||
# input size that are less than a chunk, less than the max size data or empty input
|
# input size that are less than a chunk, less than the max size data or empty input
|
||||||
# need to be repeated until the max chunk data infer MOE finished.
|
# need to be repeated until the max chunk data infer MOE finished.
|
||||||
if token_num > chunk_size: # chunked moe
|
if token_num > chunk_size: # chunked moe
|
||||||
x_split_list = paddle.tensor_split(x, self.fd_config.parallel_config.moe_num_chunk, axis=0)
|
x_split_list = paddle.tensor_split(x, forward_meta.moe_num_chunk, axis=0)
|
||||||
out_split_list = [None] * self.fd_config.parallel_config.moe_num_chunk
|
out_split_list = [None] * forward_meta.moe_num_chunk
|
||||||
|
|
||||||
for i in range(self.fd_config.parallel_config.max_moe_num_chunk):
|
for i in range(forward_meta.max_moe_num_chunk):
|
||||||
if i < self.fd_config.parallel_config.moe_num_chunk:
|
if i < forward_meta.moe_num_chunk:
|
||||||
out_split_list[i] = self.quant_method.apply(self, x_split_list[i], gate)
|
out_split_list[i] = self.quant_method.apply(self, x_split_list[i], gate)
|
||||||
else:
|
else:
|
||||||
# just need to use real data to infer max_moe_num_chunk times.
|
# just need to use real data to infer max_moe_num_chunk times.
|
||||||
@@ -685,7 +686,7 @@ class FusedMoE(nn.Layer):
|
|||||||
else:
|
else:
|
||||||
# when only one chunk, just need to use real data to infer once.
|
# when only one chunk, just need to use real data to infer once.
|
||||||
out = self.quant_method.apply(self, x, gate)
|
out = self.quant_method.apply(self, x, gate)
|
||||||
for i in range(self.fd_config.parallel_config.max_moe_num_chunk - 1):
|
for i in range(forward_meta.max_moe_num_chunk - 1):
|
||||||
self.quant_method.apply(self, fake_x, gate)
|
self.quant_method.apply(self, fake_x, gate)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ class DeepSeekV3MLP(nn.Layer):
|
|||||||
self.up_gate_proj.load_state_dict(state_dict)
|
self.up_gate_proj.load_state_dict(state_dict)
|
||||||
self.down_proj.load_state_dict(state_dict)
|
self.down_proj.load_state_dict(state_dict)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, forward_meta=None):
|
||||||
""" """
|
""" """
|
||||||
gate_up_out = self.up_gate_proj(x)
|
gate_up_out = self.up_gate_proj(x)
|
||||||
act_out = self.act_fn(gate_up_out)
|
act_out = self.act_fn(gate_up_out)
|
||||||
@@ -187,10 +187,10 @@ class DeepSeekV3MoE(nn.Layer):
|
|||||||
self.experts.load_state_dict(state_dict)
|
self.experts.load_state_dict(state_dict)
|
||||||
self.shared_experts.load_state_dict(state_dict)
|
self.shared_experts.load_state_dict(state_dict)
|
||||||
|
|
||||||
def forward(self, hidden_states: paddle.Tensor):
|
def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta):
|
||||||
""" """
|
""" """
|
||||||
shared_experts_out = self.shared_experts(hidden_states)
|
shared_experts_out = self.shared_experts(hidden_states)
|
||||||
moe_out = self.experts(hidden_states, self.gate)
|
moe_out = self.experts(hidden_states, self.gate, forward_meta)
|
||||||
moe_out = moe_out + shared_experts_out
|
moe_out = moe_out + shared_experts_out
|
||||||
# We do to TP all reduce after the sum of experts.
|
# We do to TP all reduce after the sum of experts.
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
@@ -517,8 +517,7 @@ class DeepSeekV3DecoderLayer(nn.Layer):
|
|||||||
hidden_states = self.self_attn(forward_meta, hidden_states, position_ids, mask_encoder_batch)
|
hidden_states = self.self_attn(forward_meta, hidden_states, position_ids, mask_encoder_batch)
|
||||||
|
|
||||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||||
hidden_states = self.mlp(hidden_states)
|
hidden_states = self.mlp(hidden_states, forward_meta)
|
||||||
|
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
@@ -744,7 +743,7 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
|
|||||||
)
|
)
|
||||||
return position_ids, mask_encoder_batch
|
return position_ids, mask_encoder_batch
|
||||||
|
|
||||||
def empty_input_forward(self):
|
def empty_input_forward(self, forward_meta):
|
||||||
"""
|
"""
|
||||||
empty_input_forward
|
empty_input_forward
|
||||||
"""
|
"""
|
||||||
@@ -756,7 +755,7 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
|
|||||||
self.fd_config.model_config.first_k_dense_replace,
|
self.fd_config.model_config.first_k_dense_replace,
|
||||||
self.fd_config.model_config.num_hidden_layers,
|
self.fd_config.model_config.num_hidden_layers,
|
||||||
):
|
):
|
||||||
self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate)
|
self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate, forward_meta)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ class Ernie4_5_MLP(nn.Layer):
|
|||||||
self.up_gate_proj.load_state_dict(state_dict)
|
self.up_gate_proj.load_state_dict(state_dict)
|
||||||
self.down_proj.load_state_dict(state_dict)
|
self.down_proj.load_state_dict(state_dict)
|
||||||
|
|
||||||
def forward(self, hidden_states: paddle.Tensor):
|
def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta = None):
|
||||||
gate_up_out = self.up_gate_proj(hidden_states)
|
gate_up_out = self.up_gate_proj(hidden_states)
|
||||||
act_out = self.act_fn(gate_up_out)
|
act_out = self.act_fn(gate_up_out)
|
||||||
down_out = self.down_proj(act_out)
|
down_out = self.down_proj(act_out)
|
||||||
@@ -213,8 +213,16 @@ class Ernie4_5_MoE(nn.Layer):
|
|||||||
def update_state_dict(self, state_dict):
|
def update_state_dict(self, state_dict):
|
||||||
self.experts.load_state_dict(state_dict, True)
|
self.experts.load_state_dict(state_dict, True)
|
||||||
|
|
||||||
def forward(self, hidden_states: paddle.Tensor):
|
def forward(
|
||||||
out = self.experts(hidden_states, self.gate)
|
self,
|
||||||
|
hidden_states: paddle.Tensor,
|
||||||
|
forward_meta: ForwardMeta,
|
||||||
|
):
|
||||||
|
out = self.experts(
|
||||||
|
x=hidden_states,
|
||||||
|
gate=self.gate,
|
||||||
|
forward_meta=forward_meta,
|
||||||
|
)
|
||||||
if self.num_shared_experts > 0:
|
if self.num_shared_experts > 0:
|
||||||
s_x = self.shared_experts(hidden_states)
|
s_x = self.shared_experts(hidden_states)
|
||||||
out = out + s_x
|
out = out + s_x
|
||||||
@@ -344,7 +352,10 @@ class Ernie4_5_DecoderLayer(nn.Layer):
|
|||||||
residual,
|
residual,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = self.mlp(hidden_states)
|
hidden_states = self.mlp(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
forward_meta=forward_meta,
|
||||||
|
)
|
||||||
|
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
@@ -611,7 +622,7 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
|
|||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def empty_input_forward(self):
|
def empty_input_forward(self, forward_meta):
|
||||||
"""
|
"""
|
||||||
empty_input_forward
|
empty_input_forward
|
||||||
"""
|
"""
|
||||||
@@ -623,7 +634,7 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
|
|||||||
self.fd_config.model_config.moe_layer_start_index,
|
self.fd_config.model_config.moe_layer_start_index,
|
||||||
self.fd_config.model_config.num_hidden_layers,
|
self.fd_config.model_config.num_hidden_layers,
|
||||||
):
|
):
|
||||||
self.ernie.layers[i].mlp.experts(fake_hidden_states, self.ernie.layers[i].mlp.gate)
|
self.ernie.layers[i].mlp.experts(fake_hidden_states, self.ernie.layers[i].mlp.gate, forward_meta)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -436,7 +436,7 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
|
|||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def empty_input_forward(self):
|
def empty_input_forward(self, forward_meta):
|
||||||
"""
|
"""
|
||||||
empty_input_forward
|
empty_input_forward
|
||||||
"""
|
"""
|
||||||
@@ -448,7 +448,7 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
|
|||||||
self.fd_config.model_config.moe_layer_start_index,
|
self.fd_config.model_config.moe_layer_start_index,
|
||||||
self.fd_config.model_config.num_hidden_layers,
|
self.fd_config.model_config.num_hidden_layers,
|
||||||
):
|
):
|
||||||
self.ernie.layers[i].mlp.fused_moe(fake_hidden_states)
|
self.ernie.layers[i].mlp.fused_moe(hidden_states=fake_hidden_states, forward_meta=forward_meta)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -170,8 +170,8 @@ class Ernie4_5_VLMoeBlock(nn.Layer):
|
|||||||
model_format="",
|
model_format="",
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states: paddle.Tensor):
|
def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta):
|
||||||
out = self.experts(hidden_states, self.gate)
|
out = self.experts(hidden_states, self.gate, forward_meta)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def load_state_dict(self, state_dict):
|
def load_state_dict(self, state_dict):
|
||||||
@@ -270,7 +270,7 @@ class Ernie4_5_VLMoE(nn.Layer):
|
|||||||
if self.num_shared_experts > 0:
|
if self.num_shared_experts > 0:
|
||||||
self.shared_experts.load_state_dict(state_dict)
|
self.shared_experts.load_state_dict(state_dict)
|
||||||
|
|
||||||
def forward(self, hidden_states: paddle.Tensor, vl_moe_meta: VLMoEMeta):
|
def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta, vl_moe_meta: VLMoEMeta):
|
||||||
if self.num_shared_experts > 0:
|
if self.num_shared_experts > 0:
|
||||||
shared_experts_out = self.shared_experts(hidden_states)
|
shared_experts_out = self.shared_experts(hidden_states)
|
||||||
hidden_states, text_input, image_input = text_image_gather_scatter(
|
hidden_states, text_input, image_input = text_image_gather_scatter(
|
||||||
@@ -282,8 +282,8 @@ class Ernie4_5_VLMoE(nn.Layer):
|
|||||||
vl_moe_meta.image_index,
|
vl_moe_meta.image_index,
|
||||||
True,
|
True,
|
||||||
)
|
)
|
||||||
text_out = self.text_fused_moe(text_input)
|
text_out = self.text_fused_moe(text_input, forward_meta)
|
||||||
image_out = self.image_fused_moe(image_input)
|
image_out = self.image_fused_moe(image_input, forward_meta)
|
||||||
hidden_states, _, _ = text_image_gather_scatter(
|
hidden_states, _, _ = text_image_gather_scatter(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
text_out,
|
text_out,
|
||||||
@@ -395,9 +395,9 @@ class Ernie4_5_VLDecoderLayer(nn.Layer):
|
|||||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
if isinstance(self.mlp, Ernie4_5_VLMoE):
|
if isinstance(self.mlp, Ernie4_5_VLMoE):
|
||||||
hidden_states = self.mlp(hidden_states, vl_moe_meta)
|
hidden_states = self.mlp(hidden_states, forward_meta, vl_moe_meta)
|
||||||
else:
|
else:
|
||||||
hidden_states = self.mlp(hidden_states)
|
hidden_states = self.mlp(hidden_states, forward_meta)
|
||||||
|
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
@@ -754,7 +754,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
|
|||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def empty_input_forward(self):
|
def empty_input_forward(self, forward_meta):
|
||||||
"""
|
"""
|
||||||
empty_input_forward
|
empty_input_forward
|
||||||
"""
|
"""
|
||||||
@@ -766,8 +766,8 @@ class Ernie4_5_VLMoeForConditionalGeneration(ModelForCasualLM):
|
|||||||
self.fd_config.model_config.moe_layer_start_index,
|
self.fd_config.model_config.moe_layer_start_index,
|
||||||
self.fd_config.model_config.num_hidden_layers,
|
self.fd_config.model_config.num_hidden_layers,
|
||||||
):
|
):
|
||||||
self.ernie.layers[i].mlp.text_fused_moe(fake_hidden_states)
|
self.ernie.layers[i].mlp.text_fused_moe(fake_hidden_states, forward_meta)
|
||||||
self.ernie.layers[i].mlp.image_fused_moe(fake_hidden_states)
|
self.ernie.layers[i].mlp.image_fused_moe(fake_hidden_states, forward_meta)
|
||||||
|
|
||||||
def get_input_embeddings(
|
def get_input_embeddings(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ class Glm4MoeMLP(nn.Layer):
|
|||||||
act_method=fd_config.model_config.hidden_act,
|
act_method=fd_config.model_config.hidden_act,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, forward_meta=None):
|
||||||
""" """
|
""" """
|
||||||
gate_up_out = self.up_gate_proj(x)
|
gate_up_out = self.up_gate_proj(x)
|
||||||
act_out = self.act_fn(gate_up_out)
|
act_out = self.act_fn(gate_up_out)
|
||||||
@@ -161,9 +161,9 @@ class Glm4Moe(nn.Layer):
|
|||||||
reduce_results=False,
|
reduce_results=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, forward_meta):
|
||||||
shared_experts_out = self.shared_experts(x)
|
shared_experts_out = self.shared_experts(x)
|
||||||
out = self.experts(x, self.gate)
|
out = self.experts(x, self.gate, forward_meta)
|
||||||
out = out + shared_experts_out
|
out = out + shared_experts_out
|
||||||
# We do to TP all reduce after the sum of experts.
|
# We do to TP all reduce after the sum of experts.
|
||||||
if self.tensor_parallel_size > 1:
|
if self.tensor_parallel_size > 1:
|
||||||
@@ -306,7 +306,10 @@ class Glm4MoeDecoderLayer(nn.Layer):
|
|||||||
# Fully Connected
|
# Fully Connected
|
||||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
hidden_states = self.mlp(hidden_states)
|
hidden_states = self.mlp(
|
||||||
|
hidden_states,
|
||||||
|
forward_meta,
|
||||||
|
)
|
||||||
|
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
@@ -494,7 +497,7 @@ class Glm4MoeForCausalLM(ModelForCasualLM):
|
|||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def empty_input_forward(self):
|
def empty_input_forward(self, forward_meta):
|
||||||
"""
|
"""
|
||||||
empty_input_forward
|
empty_input_forward
|
||||||
"""
|
"""
|
||||||
@@ -506,7 +509,7 @@ class Glm4MoeForCausalLM(ModelForCasualLM):
|
|||||||
self.fd_config.model_config.first_k_dense_replace,
|
self.fd_config.model_config.first_k_dense_replace,
|
||||||
self.fd_config.model_config.num_hidden_layers,
|
self.fd_config.model_config.num_hidden_layers,
|
||||||
):
|
):
|
||||||
self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate)
|
self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate, forward_meta)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -124,8 +124,8 @@ class GptOssMoe(nn.Layer):
|
|||||||
model_format="",
|
model_format="",
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states: paddle.Tensor):
|
def forward(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta):
|
||||||
expert_output = self.experts(hidden_states, self.router)
|
expert_output = self.experts(hidden_states, self.router, forward_meta)
|
||||||
return expert_output
|
return expert_output
|
||||||
|
|
||||||
|
|
||||||
@@ -173,7 +173,7 @@ class GptOssDecoderLayer(nn.Layer):
|
|||||||
# Fully Connected
|
# Fully Connected
|
||||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
hidden_states = self.mlp(hidden_states)
|
hidden_states = self.mlp(hidden_states, forward_meta)
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ class Qwen2MLP(nn.Layer):
|
|||||||
self.up_gate_proj.load_state_dict(state_dict)
|
self.up_gate_proj.load_state_dict(state_dict)
|
||||||
self.down_proj.load_state_dict(state_dict)
|
self.down_proj.load_state_dict(state_dict)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, forward_meta):
|
||||||
""" """
|
""" """
|
||||||
gate_up_out = self.up_gate_proj(x)
|
gate_up_out = self.up_gate_proj(x)
|
||||||
act_out = self.act_fn(gate_up_out)
|
act_out = self.act_fn(gate_up_out)
|
||||||
@@ -205,7 +205,7 @@ class Qwen2DecoderLayer(nn.Layer):
|
|||||||
# Fully Connected
|
# Fully Connected
|
||||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
hidden_states = self.mlp(hidden_states)
|
hidden_states = self.mlp(hidden_states, forward_meta)
|
||||||
|
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|||||||
@@ -79,8 +79,8 @@ class Qwen3MoeBlock(nn.Layer):
|
|||||||
weight_dtype="float32",
|
weight_dtype="float32",
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, forward_meta):
|
||||||
return self.experts(x, self.gate)
|
return self.experts(x, self.gate, forward_meta)
|
||||||
|
|
||||||
def load_state_dict(self, state_dict):
|
def load_state_dict(self, state_dict):
|
||||||
""" """
|
""" """
|
||||||
@@ -125,7 +125,7 @@ class Qwen3MLP(nn.Layer):
|
|||||||
self.up_gate_proj.load_state_dict(state_dict)
|
self.up_gate_proj.load_state_dict(state_dict)
|
||||||
self.down_proj.load_state_dict(state_dict)
|
self.down_proj.load_state_dict(state_dict)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, forward_meta):
|
||||||
""" """
|
""" """
|
||||||
gate_up_out = self.up_gate_proj(x)
|
gate_up_out = self.up_gate_proj(x)
|
||||||
act_out = self.act_fn(gate_up_out)
|
act_out = self.act_fn(gate_up_out)
|
||||||
@@ -204,7 +204,7 @@ class Qwen3DecoderLayer(nn.Layer):
|
|||||||
# Fully Connected
|
# Fully Connected
|
||||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
hidden_states = self.mlp(hidden_states)
|
hidden_states = self.mlp(hidden_states, forward_meta)
|
||||||
|
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
@@ -416,7 +416,7 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
|
|||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def empty_input_forward(self):
|
def empty_input_forward(self, forward_meta):
|
||||||
"""
|
"""
|
||||||
empty_input_forward
|
empty_input_forward
|
||||||
"""
|
"""
|
||||||
@@ -428,7 +428,7 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
|
|||||||
self.fd_config.model_config.moe_layer_start_index,
|
self.fd_config.model_config.moe_layer_start_index,
|
||||||
self.fd_config.model_config.num_hidden_layers,
|
self.fd_config.model_config.num_hidden_layers,
|
||||||
):
|
):
|
||||||
self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate)
|
self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate, forward_meta)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -988,7 +988,7 @@ class MTPProposer(Proposer):
|
|||||||
self._get_self_hidden_states(hidden_states)
|
self._get_self_hidden_states(hidden_states)
|
||||||
else:
|
else:
|
||||||
if hasattr(self.model, "empty_input_forward"):
|
if hasattr(self.model, "empty_input_forward"):
|
||||||
self.model.empty_input_forward()
|
self.model.empty_input_forward(forward_meta=self.forward_meta)
|
||||||
|
|
||||||
def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
|
def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
|
||||||
"""
|
"""
|
||||||
@@ -1078,7 +1078,7 @@ class MTPProposer(Proposer):
|
|||||||
self._get_self_hidden_states(hidden_states)
|
self._get_self_hidden_states(hidden_states)
|
||||||
else:
|
else:
|
||||||
if hasattr(self.model, "empty_input_forward"):
|
if hasattr(self.model, "empty_input_forward"):
|
||||||
self.model.empty_input_forward()
|
self.model.empty_input_forward(self.forward_meta)
|
||||||
|
|
||||||
def _get_self_hidden_states(self, hidden_states):
|
def _get_self_hidden_states(self, hidden_states):
|
||||||
target_hidden_states = eagle_get_self_hidden_states(
|
target_hidden_states = eagle_get_self_hidden_states(
|
||||||
|
|||||||
@@ -971,7 +971,7 @@ class GCUModelRunner(ModelRunnerBase):
|
|||||||
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
|
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
|
||||||
# when there is data on other runner, the current runner is required to execute part of the model.
|
# when there is data on other runner, the current runner is required to execute part of the model.
|
||||||
if not self.not_need_stop():
|
if not self.not_need_stop():
|
||||||
self._execute_empty_input()
|
self._execute_empty_input(self.forward_meta)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 1. Prepare inputs of model and sampler.
|
# 1. Prepare inputs of model and sampler.
|
||||||
@@ -1088,14 +1088,14 @@ class GCUModelRunner(ModelRunnerBase):
|
|||||||
self.seq_lens_this_time_buffer.copy_(self.share_inputs["seq_lens_this_time"], False)
|
self.seq_lens_this_time_buffer.copy_(self.share_inputs["seq_lens_this_time"], False)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _execute_empty_input(self) -> None:
|
def _execute_empty_input(self, forward_meta) -> None:
|
||||||
"""
|
"""
|
||||||
In certain scenarios, such as during EP,
|
In certain scenarios, such as during EP,
|
||||||
the runner needs to execute partial modules of the model without input data.
|
the runner needs to execute partial modules of the model without input data.
|
||||||
This requires the model to implement the `empty_input_forward` method.
|
This requires the model to implement the `empty_input_forward` method.
|
||||||
"""
|
"""
|
||||||
if hasattr(self.model, "empty_input_forward"):
|
if hasattr(self.model, "empty_input_forward"):
|
||||||
self.model.empty_input_forward()
|
self.model.empty_input_forward(forward_meta)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward")
|
raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward")
|
||||||
|
|
||||||
|
|||||||
@@ -275,11 +275,11 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
token_num = self.share_inputs["ids_remove_padding"].shape[0]
|
token_num = self.share_inputs["ids_remove_padding"].shape[0]
|
||||||
|
|
||||||
if token_num > chunk_size:
|
if token_num > chunk_size:
|
||||||
self.fd_config.parallel_config.moe_num_chunk = (token_num + chunk_size - 1) // chunk_size
|
self.forward_meta.moe_num_chunk = (token_num + chunk_size - 1) // chunk_size
|
||||||
else:
|
else:
|
||||||
self.fd_config.parallel_config.moe_num_chunk = 1
|
self.forward_meta.moe_num_chunk = 1
|
||||||
|
|
||||||
dist_status_obj.moe_num_chunk = self.fd_config.parallel_config.moe_num_chunk
|
dist_status_obj.moe_num_chunk = self.forward_meta.moe_num_chunk
|
||||||
|
|
||||||
# only ep need to collect and sync distributed status
|
# only ep need to collect and sync distributed status
|
||||||
if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
|
if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
|
||||||
@@ -1448,7 +1448,7 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
|
|
||||||
if_only_decode = dist_status.if_only_decode
|
if_only_decode = dist_status.if_only_decode
|
||||||
if self.fd_config.parallel_config.enable_chunked_moe:
|
if self.fd_config.parallel_config.enable_chunked_moe:
|
||||||
self.fd_config.parallel_config.max_moe_num_chunk = dist_status.max_moe_num_chunk
|
self.forward_meta.max_moe_num_chunk = dist_status.max_moe_num_chunk
|
||||||
|
|
||||||
only_decode_use_cudagraph = self.use_cudagraph and if_only_decode
|
only_decode_use_cudagraph = self.use_cudagraph and if_only_decode
|
||||||
|
|
||||||
@@ -2202,7 +2202,7 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
|
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
|
||||||
# when there is data on other runner, the current runner is required to execute part of the model.
|
# when there is data on other runner, the current runner is required to execute part of the model.
|
||||||
if not self.not_need_stop():
|
if not self.not_need_stop():
|
||||||
self._execute_empty_input()
|
self._execute_empty_input(self.forward_meta)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 2. Padding inputs for cuda graph
|
# 2. Padding inputs for cuda graph
|
||||||
@@ -2473,14 +2473,14 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
|
|
||||||
return pooler_output
|
return pooler_output
|
||||||
|
|
||||||
def _execute_empty_input(self) -> None:
|
def _execute_empty_input(self, forward_meta) -> None:
|
||||||
"""
|
"""
|
||||||
In certain scenarios, such as during EP,
|
In certain scenarios, such as during EP,
|
||||||
the runner needs to execute partial modules of the model without input data.
|
the runner needs to execute partial modules of the model without input data.
|
||||||
This requires the model to implement the `empty_input_forward` method.
|
This requires the model to implement the `empty_input_forward` method.
|
||||||
"""
|
"""
|
||||||
if hasattr(self.model, "empty_input_forward"):
|
if hasattr(self.model, "empty_input_forward"):
|
||||||
self.model.empty_input_forward()
|
self.model.empty_input_forward(forward_meta)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward")
|
raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward")
|
||||||
|
|
||||||
|
|||||||
@@ -1361,14 +1361,14 @@ class HPUModelRunner(ModelRunnerBase):
|
|||||||
self.prof.step()
|
self.prof.step()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _execute_empty_input(self) -> None:
|
def _execute_empty_input(self, forward_meta) -> None:
|
||||||
"""
|
"""
|
||||||
In certain scenarios, such as during EP,
|
In certain scenarios, such as during EP,
|
||||||
the runner needs to execute partial modules of the model without input data.
|
the runner needs to execute partial modules of the model without input data.
|
||||||
This requires the model to implement the `empty_input_forward` method.
|
This requires the model to implement the `empty_input_forward` method.
|
||||||
"""
|
"""
|
||||||
if hasattr(self.model, "empty_input_forward"):
|
if hasattr(self.model, "empty_input_forward"):
|
||||||
self.model.empty_input_forward()
|
self.model.empty_input_forward(forward_meta)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward")
|
raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward")
|
||||||
|
|
||||||
|
|||||||
@@ -1812,7 +1812,7 @@ class MetaxModelRunner(ModelRunnerBase):
|
|||||||
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
|
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
|
||||||
# when there is data on other runner, the current runner is required to execute part of the model.
|
# when there is data on other runner, the current runner is required to execute part of the model.
|
||||||
if not self.not_need_stop():
|
if not self.not_need_stop():
|
||||||
self._execute_empty_input()
|
self._execute_empty_input(self.forward_meta)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 2. Padding inputs for cuda graph
|
# 2. Padding inputs for cuda graph
|
||||||
@@ -1998,14 +1998,14 @@ class MetaxModelRunner(ModelRunnerBase):
|
|||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _execute_empty_input(self) -> None:
|
def _execute_empty_input(self, forward_meta) -> None:
|
||||||
"""
|
"""
|
||||||
In certain scenarios, such as during EP,
|
In certain scenarios, such as during EP,
|
||||||
the runner needs to execute partial modules of the model without input data.
|
the runner needs to execute partial modules of the model without input data.
|
||||||
This requires the model to implement the `empty_input_forward` method.
|
This requires the model to implement the `empty_input_forward` method.
|
||||||
"""
|
"""
|
||||||
if hasattr(self.model, "empty_input_forward"):
|
if hasattr(self.model, "empty_input_forward"):
|
||||||
self.model.empty_input_forward()
|
self.model.empty_input_forward(forward_meta)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward")
|
raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward")
|
||||||
|
|
||||||
|
|||||||
@@ -1227,7 +1227,7 @@ class XPUModelRunner(ModelRunnerBase):
|
|||||||
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
|
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
|
||||||
# when there is data on other runner, the current runner is required to execute part of the model.
|
# when there is data on other runner, the current runner is required to execute part of the model.
|
||||||
if not self.not_need_stop() and not is_dummy_run:
|
if not self.not_need_stop() and not is_dummy_run:
|
||||||
self._execute_empty_input()
|
self._execute_empty_input(self.forward_meta)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 2. Padding inputs for cuda grph
|
# 2. Padding inputs for cuda grph
|
||||||
@@ -1323,14 +1323,14 @@ class XPUModelRunner(ModelRunnerBase):
|
|||||||
destroy_kv_signal_sender(self.kv_signal_sender)
|
destroy_kv_signal_sender(self.kv_signal_sender)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _execute_empty_input(self) -> None:
|
def _execute_empty_input(self, forward_meta) -> None:
|
||||||
"""
|
"""
|
||||||
In certain scenarios, such as during EP,
|
In certain scenarios, such as during EP,
|
||||||
the runner needs to execute partial modules of the model without input data.
|
the runner needs to execute partial modules of the model without input data.
|
||||||
This requires the model to implement the `empty_input_forward` method.
|
This requires the model to implement the `empty_input_forward` method.
|
||||||
"""
|
"""
|
||||||
if hasattr(self.model, "empty_input_forward"):
|
if hasattr(self.model, "empty_input_forward"):
|
||||||
self.model.empty_input_forward()
|
self.model.empty_input_forward(forward_meta)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward")
|
raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward")
|
||||||
|
|
||||||
|
|||||||
@@ -28,6 +28,13 @@ class MockStructuredOutputsConfig:
|
|||||||
logits_processors = []
|
logits_processors = []
|
||||||
|
|
||||||
|
|
||||||
|
class MockForwardMeta:
|
||||||
|
def __init__(self):
|
||||||
|
# chunked MoE related.
|
||||||
|
self.moe_num_chunk = 1
|
||||||
|
self.max_moe_num_chunk = 1
|
||||||
|
|
||||||
|
|
||||||
class MockModelConfig:
|
class MockModelConfig:
|
||||||
max_model_len = 10
|
max_model_len = 10
|
||||||
pad_token_id = 0
|
pad_token_id = 0
|
||||||
@@ -60,8 +67,6 @@ class MockFDConfig:
|
|||||||
enable_expert_parallel = True
|
enable_expert_parallel = True
|
||||||
enable_chunked_moe = True
|
enable_chunked_moe = True
|
||||||
chunked_moe_size = 2
|
chunked_moe_size = 2
|
||||||
max_moe_num_chunk = 1
|
|
||||||
moe_num_chunk = 1
|
|
||||||
use_ep = True
|
use_ep = True
|
||||||
use_sequence_parallel_moe = False
|
use_sequence_parallel_moe = False
|
||||||
|
|
||||||
@@ -148,19 +153,19 @@ class TestChunkedMoE(unittest.TestCase):
|
|||||||
def run_model_runner(self):
|
def run_model_runner(self):
|
||||||
self.model_runner.initialize_forward_meta()
|
self.model_runner.initialize_forward_meta()
|
||||||
|
|
||||||
assert self.model_runner.fd_config.parallel_config.max_moe_num_chunk == 5, (
|
assert self.model_runner.forward_meta.max_moe_num_chunk == 5, (
|
||||||
f"chunk size is 2, max token_num is 10, max_moe_num_chunk should be 5, "
|
f"chunk size is 2, max token_num is 10, max_moe_num_chunk should be 5, "
|
||||||
f"but got {self.model_runner.fd_config.parallel_config.max_moe_num_chunk}"
|
f"but got {self.model_runner.forward_meta.max_moe_num_chunk }"
|
||||||
)
|
)
|
||||||
if dist.get_rank() == 0:
|
if dist.get_rank() == 0:
|
||||||
assert self.model_runner.fd_config.parallel_config.moe_num_chunk == 5, (
|
assert self.model_runner.forward_meta.moe_num_chunk == 5, (
|
||||||
f"chunk size is 2, token_num is 10, moe_num_chunk in rank 0 should be 5"
|
f"chunk size is 2, token_num is 10, moe_num_chunk in rank 0 should be 5, "
|
||||||
f"but got {self.model_runner.fd_config.parallel_config.moe_num_chunk}"
|
f"but got {self.model_runner.forward_meta.moe_num_chunk}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert self.model_runner.fd_config.parallel_config.moe_num_chunk == 1, (
|
assert self.model_runner.forward_meta.moe_num_chunk == 1, (
|
||||||
f"chunk size is 2, token_num is 1, moe_num_chunk in rank 1 should be 1"
|
f"chunk size is 2, token_num is 1, moe_num_chunk in rank 1 should be 1, "
|
||||||
f", but got {self.model_runner.fd_config.parallel_config.moe_num_chunk}"
|
f"but got {self.model_runner.forward_meta.moe_num_chunk}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def run_fused_moe(self):
|
def run_fused_moe(self):
|
||||||
@@ -170,7 +175,7 @@ class TestChunkedMoE(unittest.TestCase):
|
|||||||
else:
|
else:
|
||||||
x = paddle.ones([1])
|
x = paddle.ones([1])
|
||||||
|
|
||||||
out = self.fused_moe.forward(x, gate)
|
out = self.fused_moe.forward(x, gate, MockForwardMeta())
|
||||||
assert out.shape == x.shape
|
assert out.shape == x.shape
|
||||||
|
|
||||||
def test_case(self):
|
def test_case(self):
|
||||||
|
|||||||
@@ -44,6 +44,13 @@ if "nvidia graphics device" in paddle.device.cuda.get_device_name().lower():
|
|||||||
os.environ.setdefault("DG_NVCC_OVERRIDE_CPP_STANDARD", "17")
|
os.environ.setdefault("DG_NVCC_OVERRIDE_CPP_STANDARD", "17")
|
||||||
|
|
||||||
|
|
||||||
|
class MockForwardMeta:
|
||||||
|
def __init__(self):
|
||||||
|
# chunked MoE related.
|
||||||
|
self.moe_num_chunk = 1
|
||||||
|
self.max_moe_num_chunk = 1
|
||||||
|
|
||||||
|
|
||||||
class FFNWrapper(paddle.nn.Layer):
|
class FFNWrapper(paddle.nn.Layer):
|
||||||
def __init__(self, model_config: ModelConfig):
|
def __init__(self, model_config: ModelConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -134,7 +141,7 @@ class TestFusedMoE(unittest.TestCase):
|
|||||||
init_distributed_environment()
|
init_distributed_environment()
|
||||||
|
|
||||||
ffn = FFNWrapper(self.model_config)
|
ffn = FFNWrapper(self.model_config)
|
||||||
|
forward_meta = MockForwardMeta()
|
||||||
moe_cuda_graphs = [None] * 100
|
moe_cuda_graphs = [None] * 100
|
||||||
cache_hidden_states = [None] * 100
|
cache_hidden_states = [None] * 100
|
||||||
test_token_nums = [10, 20, 40, 60, 80, 100, 128, 160, 192, 256, 4096, 4096 * 4]
|
test_token_nums = [10, 20, 40, 60, 80, 100, 128, 160, 192, 256, 4096, 4096 * 4]
|
||||||
@@ -147,7 +154,7 @@ class TestFusedMoE(unittest.TestCase):
|
|||||||
|
|
||||||
num_layers = self.num_layers
|
num_layers = self.num_layers
|
||||||
for _ in range(num_layers):
|
for _ in range(num_layers):
|
||||||
out = ffn.ffn(cache_hidden_states[idx])
|
out = ffn.ffn(cache_hidden_states[idx], forward_meta=forward_meta)
|
||||||
|
|
||||||
moe_cuda_graphs[idx].capture_end()
|
moe_cuda_graphs[idx].capture_end()
|
||||||
|
|
||||||
|
|||||||
@@ -432,6 +432,13 @@ gate_correction_bias_real_data = paddle.to_tensor(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MockForwardMeta:
|
||||||
|
def __init__(self):
|
||||||
|
# chunked MoE related.
|
||||||
|
self.moe_num_chunk = 1
|
||||||
|
self.max_moe_num_chunk = 1
|
||||||
|
|
||||||
|
|
||||||
class FuseMoEWrapper(paddle.nn.Layer):
|
class FuseMoEWrapper(paddle.nn.Layer):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -607,7 +614,9 @@ class TestFusedMoE(unittest.TestCase):
|
|||||||
|
|
||||||
def fake_model_run():
|
def fake_model_run():
|
||||||
for j in range(num_layers):
|
for j in range(num_layers):
|
||||||
out = fused_moe[j % real_weight_layers].fused_moe(cache_hidden_states[idx], gating)
|
out = fused_moe[j % real_weight_layers].fused_moe(
|
||||||
|
cache_hidden_states[idx], gating, forward_meta=MockForwardMeta()
|
||||||
|
)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|||||||
+8
-1
@@ -29,6 +29,13 @@ from fastdeploy.config import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MockForwardMeta:
|
||||||
|
def __init__(self):
|
||||||
|
# chunked MoE related.
|
||||||
|
self.moe_num_chunk = 1
|
||||||
|
self.max_moe_num_chunk = 1
|
||||||
|
|
||||||
|
|
||||||
class FakeModelConfig:
|
class FakeModelConfig:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.hidden_size = 768
|
self.hidden_size = 768
|
||||||
@@ -85,7 +92,7 @@ class OpPerformanceTester:
|
|||||||
def _fake_model_run(self, x):
|
def _fake_model_run(self, x):
|
||||||
for j in range(self.num_layers):
|
for j in range(self.num_layers):
|
||||||
if self.gate:
|
if self.gate:
|
||||||
out = self.op_fn(x, self.gate)
|
out = self.op_fn(x, self.gate, forward_meta=MockForwardMeta())
|
||||||
else:
|
else:
|
||||||
out = self.op_fn(x)
|
out = self.op_fn(x)
|
||||||
return out
|
return out
|
||||||
|
|||||||
Reference in New Issue
Block a user