[Feature] support chunked moe (#4575)

* [Feature] support chunked moe

* update

* update

* fix and add test

* update

* fix conflict and modify test

* fix fused_moe

* fix fused_moe

* fix docstring

* fix

* fix typo

* fix test

* fix

* fix

* fix test

* fix test
This commit is contained in:
Longzhi Wang
2025-12-01 15:17:18 +08:00
committed by GitHub
parent 6f42c37359
commit add524d80c
10 changed files with 405 additions and 5 deletions
+56 -1
View File
@@ -612,6 +612,7 @@ class FusedMoE(nn.Layer):
multi_outs = paddle.zeros([token_num_per_rank * self.tp_size, x.shape[1]], dtype=x.dtype)
paddle.distributed.all_gather(multi_outs, out, self.tp_group)
out = multi_outs[:token_num, :]
return out
def forward(self, x: paddle.Tensor, gate: nn.Layer):
@@ -633,9 +634,63 @@ class FusedMoE(nn.Layer):
and token_num >= self.tp_size
):
out = self.forward_split_allgather(x, gate)
elif self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.enable_chunked_moe:
out = self.forward_chunked_moe(x, gate)
else:
out = self.quant_method.apply(self, x, gate)
out = self.forward_normal(x, gate)
if self.reduce_results and self.tp_size > 1:
out = tensor_model_parallel_all_reduce(out, self.tp_group)
return out
def forward_chunked_moe(self, x: paddle.Tensor, gate: nn.Layer):
    """
    Split the input into multiple chunks to reduce the peak memory usage of MoE.

    Every call site runs ``self.quant_method.apply`` exactly
    ``max_moe_num_chunk`` times, padding with an empty "fake" input when this
    rank has fewer real chunks.
    # NOTE(review): presumably this keeps expert-parallel ranks in lockstep on
    # the collective inside ``apply`` — confirm against quant_method impls.

    Args:
        x (Tensor): Input tensor to the moe layer.
        gate (nn.Layer): Gating network used to route tokens to experts.

    Returns:
        Tensor: Output tensor.
    """
    # Hoist repeated config attribute lookups.
    parallel_cfg = self.fd_config.parallel_config
    chunk_size = parallel_cfg.chunked_moe_size
    num_chunk = parallel_cfg.moe_num_chunk
    max_num_chunk = parallel_cfg.max_moe_num_chunk
    token_num = x.shape[0]
    # Zero-token placeholder used for the padding calls below.
    fake_x = paddle.empty(
        shape=[0, self.fd_config.model_config.hidden_size],
        dtype=paddle.get_default_dtype(),
    )
    # Inputs smaller than one chunk (or empty) still must issue
    # max_moe_num_chunk MoE invocations in total.
    if token_num > chunk_size:  # chunked moe
        x_split_list = paddle.tensor_split(x, num_chunk, axis=0)
        out_split_list = [None] * num_chunk
        for i in range(max_num_chunk):
            if i < num_chunk:
                out_split_list[i] = self.quant_method.apply(self, x_split_list[i], gate)
            else:
                # Padding call with fake (empty) data so the total number of
                # apply() invocations reaches max_moe_num_chunk.
                self.quant_method.apply(self, fake_x, gate)
        out = paddle.concat(out_split_list, axis=0)
    else:
        # Only one chunk of real data: run it once, then pad with fake
        # inputs until max_moe_num_chunk calls have been issued.
        out = self.quant_method.apply(self, x, gate)
        for _ in range(max_num_chunk - 1):
            self.quant_method.apply(self, fake_x, gate)
    return out
def forward_normal(self, x: paddle.Tensor, gate: nn.Layer):
    """
    Normal (unchunked) forward: apply the quantized MoE kernel to the whole input.

    Args:
        x (Tensor): Input tensor to the moe layer.
        gate (nn.Layer): Gating network used to route tokens to experts.

    Returns:
        Tensor: Output tensor.
    """
    out = self.quant_method.apply(self, x, gate)
    return out