mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[XPU] ep+tp all2all (#4836)
This commit is contained in:
@@ -137,6 +137,7 @@ class FusedMoE(nn.Layer):
|
||||
self.ep_size = fd_config.parallel_config.expert_parallel_size
|
||||
self.ep_rank = fd_config.parallel_config.expert_parallel_rank
|
||||
self.tp_group = fd_config.parallel_config.tp_group
|
||||
self.ep_tp_strategy = self.fd_config.parallel_config.ep_tp_strategy
|
||||
# NOTE(Zhenyu Li): just supports tp_size = 1 when ep_size > 1 in MOE now.
|
||||
if self.ep_size > 1:
|
||||
self.tp_size = 1
|
||||
@@ -612,7 +613,7 @@ class FusedMoE(nn.Layer):
|
||||
"""
|
||||
token_num = x.shape[0]
|
||||
tp_size = self.fd_config.parallel_config.tensor_parallel_size
|
||||
if self.ep_size > 1 and tp_size > 1 and token_num >= tp_size:
|
||||
if self.ep_size > 1 and tp_size > 1 and self.ep_tp_strategy == "all_reduce" and token_num >= tp_size:
|
||||
out = self.forward_split_allgather(x, gate)
|
||||
else:
|
||||
out = self.quant_method.apply(self, x, gate)
|
||||
|
||||
Reference in New Issue
Block a user