support ep+tp at op layer (#4688)

This commit is contained in:
zhupengyang
2025-11-05 11:15:57 +08:00
committed by GitHub
parent 937eb3c6ed
commit 2fd254e5b7
8 changed files with 138 additions and 105 deletions
+29 -1
View File
@@ -564,6 +564,29 @@ class FusedMoE(nn.Layer):
else:
self.quant_method.process_loaded_weights(self, state_dict)
def forward_split_allgather(self, x: paddle.Tensor, gate: nn.Layer):
"""
Forward split allgather function.
"""
token_num = x.shape[0]
tp_size = self.fd_config.parallel_config.tensor_parallel_size
tp_rank = self.fd_config.parallel_config.tensor_parallel_rank
token_num_per_rank = (token_num + tp_size - 1) // tp_size
# AllGather will hang when the data shapes on multi-ranks are different!
part_x = paddle.zeros(shape=[token_num_per_rank, x.shape[1]], dtype=x.dtype)
start_offset = tp_rank * token_num_per_rank
end_offset = (tp_rank + 1) * token_num_per_rank
if start_offset >= token_num:
start_offset = token_num
if end_offset > token_num:
end_offset = token_num
part_x[: (end_offset - start_offset), :] = x[start_offset:end_offset, :]
out = self.quant_method.apply(self, part_x, gate)
multi_outs = paddle.zeros([token_num_per_rank * tp_size, x.shape[1]], dtype=x.dtype)
paddle.distributed.all_gather(multi_outs, out, self.tp_group)
out = multi_outs[:token_num, :]
return out
def forward(self, x: paddle.Tensor, gate: nn.Layer):
"""
Defines the forward computation of the moe layer.
@@ -575,5 +598,10 @@ class FusedMoE(nn.Layer):
Tensor: Output tensor.s
"""
out = self.quant_method.apply(self, x, gate)
token_num = x.shape[0]
tp_size = self.fd_config.parallel_config.tensor_parallel_size
if self.ep_size > 1 and tp_size > 1 and token_num >= tp_size:
out = self.forward_split_allgather(x, gate)
else:
out = self.quant_method.apply(self, x, gate)
return out