Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2026-04-23 00:17:25 +08:00
[Models] Add forward_meta to moe models' forward function (#5138)
* [Models] Add forward_meta to moe models' forward function
* fix missing param
* fix forward_meta
* fix test and remove chunked-MoE-related fields from config
* fix test
* misc fixes
This commit is contained in:
@@ -25,6 +25,7 @@ from fastdeploy.distributed.communication import (
     tensor_model_parallel_all_reduce,
     tensor_model_parallel_all_reduce_custom,
 )
+from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.layers.utils import get_tensor
 from fastdeploy.model_executor.utils import h2d_copy, slice_fn
 from fastdeploy.platforms import current_platform
@@ -621,7 +622,7 @@ class FusedMoE(nn.Layer):

         return out

-    def forward(self, x: paddle.Tensor, gate: nn.Layer):
+    def forward(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta):
         """
         Defines the forward computation of the moe layer.

@@ -641,7 +642,7 @@ class FusedMoE(nn.Layer):
         ):
             out = self.forward_split_allgather(x, gate)
         elif self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.enable_chunked_moe:
-            out = self.forward_chunked_moe(x, gate)
+            out = self.forward_chunked_moe(x, gate, forward_meta)
         else:
             out = self.forward_normal(x, gate)

@@ -652,7 +653,7 @@ class FusedMoE(nn.Layer):
         out = tensor_model_parallel_all_reduce(out, self.tp_group)
         return out

-    def forward_chunked_moe(self, x: paddle.Tensor, gate: nn.Layer):
+    def forward_chunked_moe(self, x: paddle.Tensor, gate: nn.Layer, forward_meta: ForwardMeta):
         """
         Split input to multi chunk to reduce the memory usage of moe.

@@ -671,11 +672,11 @@ class FusedMoE(nn.Layer):
         # input size that are less than a chunk, less than the max size data or empty input
         # need to be repeated until the max chunk data infer MOE finished.
         if token_num > chunk_size:  # chunked moe
-            x_split_list = paddle.tensor_split(x, self.fd_config.parallel_config.moe_num_chunk, axis=0)
-            out_split_list = [None] * self.fd_config.parallel_config.moe_num_chunk
+            x_split_list = paddle.tensor_split(x, forward_meta.moe_num_chunk, axis=0)
+            out_split_list = [None] * forward_meta.moe_num_chunk

-            for i in range(self.fd_config.parallel_config.max_moe_num_chunk):
-                if i < self.fd_config.parallel_config.moe_num_chunk:
+            for i in range(forward_meta.max_moe_num_chunk):
+                if i < forward_meta.moe_num_chunk:
                     out_split_list[i] = self.quant_method.apply(self, x_split_list[i], gate)
                 else:
                     # just need to use real data to infer max_moe_num_chunk times.
@@ -685,7 +686,7 @@ class FusedMoE(nn.Layer):
         else:
             # when only one chunk, just need to use real data to infer once.
             out = self.quant_method.apply(self, x, gate)
-            for i in range(self.fd_config.parallel_config.max_moe_num_chunk - 1):
+            for i in range(forward_meta.max_moe_num_chunk - 1):
                 self.quant_method.apply(self, fake_x, gate)

         return out
Reference in New Issue
Block a user