[TSP] Support qwen3 moe tsp + cudagraph (#4871)
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled

* support qwen3_moe tsp mode

* fix

* fix

* update

* update

* update

* fix

* support external_rmsnorm

* update

* fix
This commit is contained in:
Yuanle Liu
2025-11-10 23:37:51 +08:00
committed by GitHub
parent fb2eb403ab
commit 3dc0ffa46d
28 changed files with 173 additions and 273 deletions
+10 -9
View File
@@ -137,7 +137,6 @@ class FusedMoE(nn.Layer):
self.ep_size = fd_config.parallel_config.expert_parallel_size
self.ep_rank = fd_config.parallel_config.expert_parallel_rank
self.tp_group = fd_config.parallel_config.tp_group
self.ep_tp_strategy = self.fd_config.parallel_config.ep_tp_strategy
# NOTE(Zhenyu Li): just supports tp_size = 1 when ep_size > 1 in MOE now.
if self.ep_size > 1:
self.tp_size = 1
@@ -582,20 +581,18 @@ class FusedMoE(nn.Layer):
Forward split allgather function.
"""
token_num = x.shape[0]
tp_size = self.fd_config.parallel_config.tensor_parallel_size
tp_rank = self.fd_config.parallel_config.tensor_parallel_rank
token_num_per_rank = (token_num + tp_size - 1) // tp_size
token_num_per_rank = (token_num + self.tp_size - 1) // self.tp_size
# AllGather will hang when the data shapes on multi-ranks are different!
part_x = paddle.zeros(shape=[token_num_per_rank, x.shape[1]], dtype=x.dtype)
start_offset = tp_rank * token_num_per_rank
end_offset = (tp_rank + 1) * token_num_per_rank
start_offset = self.tp_rank * token_num_per_rank
end_offset = (self.tp_rank + 1) * token_num_per_rank
if start_offset >= token_num:
start_offset = token_num
if end_offset > token_num:
end_offset = token_num
part_x[: (end_offset - start_offset), :] = x[start_offset:end_offset, :]
out = self.quant_method.apply(self, part_x, gate)
multi_outs = paddle.zeros([token_num_per_rank * tp_size, x.shape[1]], dtype=x.dtype)
multi_outs = paddle.zeros([token_num_per_rank * self.tp_size, x.shape[1]], dtype=x.dtype)
paddle.distributed.all_gather(multi_outs, out, self.tp_group)
out = multi_outs[:token_num, :]
return out
@@ -612,8 +609,12 @@ class FusedMoE(nn.Layer):
"""
token_num = x.shape[0]
tp_size = self.fd_config.parallel_config.tensor_parallel_size
if self.ep_size > 1 and tp_size > 1 and self.ep_tp_strategy == "all_reduce" and token_num >= tp_size:
if (
self.ep_size > 1
and self.tp_size > 1
and (not self.fd_config.parallel_config.use_sequence_parallel_moe)
and token_num >= self.tp_size
):
out = self.forward_split_allgather(x, gate)
else:
out = self.quant_method.apply(self, x, gate)