mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] support chunked moe (#4575)
* [Feature] support chunked moe * update * update * fix and add test * update * fix conflict and modity test * fix fused_moe * fix fused_moe * fix docstring * fix * fix typo * fix test * fix * fix * fix test * fix test
This commit is contained in:
@@ -612,6 +612,7 @@ class FusedMoE(nn.Layer):
|
||||
multi_outs = paddle.zeros([token_num_per_rank * self.tp_size, x.shape[1]], dtype=x.dtype)
|
||||
paddle.distributed.all_gather(multi_outs, out, self.tp_group)
|
||||
out = multi_outs[:token_num, :]
|
||||
|
||||
return out
|
||||
|
||||
def forward(self, x: paddle.Tensor, gate: nn.Layer):
|
||||
@@ -633,9 +634,63 @@ class FusedMoE(nn.Layer):
|
||||
and token_num >= self.tp_size
|
||||
):
|
||||
out = self.forward_split_allgather(x, gate)
|
||||
elif self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.enable_chunked_moe:
|
||||
out = self.forward_chunked_moe(x, gate)
|
||||
else:
|
||||
out = self.quant_method.apply(self, x, gate)
|
||||
out = self.forward_normal(x, gate)
|
||||
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
out = tensor_model_parallel_all_reduce(out, self.tp_group)
|
||||
return out
|
||||
|
||||
def forward_chunked_moe(self, x: paddle.Tensor, gate: nn.Layer):
    """
    Split the input into multiple chunks to reduce the memory usage of MoE.

    Every rank calls ``quant_method.apply`` exactly ``max_moe_num_chunk``
    times; ranks whose input produces fewer real chunks pad the remaining
    calls with an empty tensor. NOTE(review): presumably this keeps the
    collective ops inside the expert kernels in lockstep across ranks —
    confirm against the EP implementation.

    Args:
        x (Tensor): Input tensor to the moe layer.
        gate (nn.Layer): Gate layer used for expert routing.

    Returns:
        Tensor: Output tensor.
    """
    chunk_size = self.fd_config.parallel_config.chunked_moe_size
    token_num = x.shape[0]
    # Zero-token placeholder used for the padding calls below.
    fake_x = paddle.empty(
        shape=[0, self.fd_config.model_config.hidden_size],
        dtype=paddle.get_default_dtype(),
    )
    # Inputs smaller than one chunk (or empty) must still repeat the kernel
    # until the maximum number of chunk inferences has been issued.
    if token_num > chunk_size:  # chunked moe
        x_split_list = paddle.tensor_split(x, self.fd_config.parallel_config.moe_num_chunk, axis=0)
        out_split_list = [None] * self.fd_config.parallel_config.moe_num_chunk

        for i in range(self.fd_config.parallel_config.max_moe_num_chunk):
            if i < self.fd_config.parallel_config.moe_num_chunk:
                out_split_list[i] = self.quant_method.apply(self, x_split_list[i], gate)
            else:
                # Real data is only needed for the first moe_num_chunk calls;
                # the rest are padding up to max_moe_num_chunk.
                self.quant_method.apply(self, fake_x, gate)

        out = paddle.concat(out_split_list, axis=0)
    else:
        # Only one chunk of real data; pad the remaining calls.
        out = self.quant_method.apply(self, x, gate)
        for _ in range(self.fd_config.parallel_config.max_moe_num_chunk - 1):
            self.quant_method.apply(self, fake_x, gate)

    return out
|
||||
def forward_normal(self, x: "paddle.Tensor", gate: "nn.Layer"):
    """
    Normal (un-chunked) mode of forward: run the quantized MoE kernel
    on the whole input in a single call.

    Args:
        x (Tensor): Input tensor to the moe layer.
        gate (nn.Layer): Gate layer used for expert routing.

    Returns:
        Tensor: Output tensor.
    """
    return self.quant_method.apply(self, x, gate)
|
||||
Reference in New Issue
Block a user