Support setting communication groups in custom_allreduce and the all-to-all transpose fused operator during the decoding phase. (#5917)

This commit is contained in:
lzy
2026-01-12 14:09:39 +08:00
committed by GitHub
parent 60ee72f682
commit 223b2f5d86
8 changed files with 288 additions and 50 deletions
+22 -10
View File
@@ -21,7 +21,10 @@ import paddle
from paddle import nn
from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
from fastdeploy.distributed.communication import (
decode_alltoall_transpose,
tensor_model_parallel_all_reduce,
)
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
from fastdeploy.model_executor.utils import (
default_weight_loader,
@@ -888,15 +891,24 @@ class RowParallelLinear(LinearBase):
def all2all_transpose(self, x: paddle.Tensor) -> paddle.Tensor:
    """Exchange and transpose activations across tensor-parallel ranks.

    Pads the token dimension up to a multiple of ``self.tp_size``, performs an
    all-to-all across the TP group, and rearranges the result so each rank
    holds its slice of the full ``self.input_size`` feature dimension.

    In the "decode" splitwise role, the fused ``decode_alltoall_transpose``
    kernel performs the all-to-all and transpose in one step; otherwise the
    generic ``paddle.distributed.alltoall`` path plus explicit reshapes and a
    transpose is used.

    Args:
        x: Input activations of shape ``[token_num, hidden_per_rank]``.
           (Shape assumption inferred from the padding/reshape logic —
           TODO confirm against callers.)

    Returns:
        The gathered/transposed tensor. On the generic path its shape is
        ``[token_num_pad // tp_size, self.input_size]``.
    """
    token_num = x.shape[0]
    # Round token count up to the next multiple of tp_size so the
    # all-to-all splits evenly across ranks.
    token_num_pad = (token_num + self.tp_size - 1) // self.tp_size * self.tp_size
    if self.fd_config.scheduler_config.splitwise_role == "decode":
        # Fused path: pad (if needed), then let the fused kernel do the
        # all-to-all + transpose directly into the preallocated output.
        if token_num_pad > token_num:
            x_padded = paddle.zeros([token_num_pad, x.shape[1]], x.dtype)
            x_padded[:token_num] = x
        else:
            x_padded = x
        out = paddle.zeros(
            [token_num_pad // self.tp_size, x.shape[1] * self.tp_size], x.dtype
        )
        # NOTE(review): decode_alltoall_transpose is assumed to write its
        # result into `out` in place — confirm against the op's definition.
        decode_alltoall_transpose(x_padded, out)
    else:
        # Generic path: zero-pad, all-to-all over the TP group, then
        # reshape/transpose to interleave the per-rank feature chunks.
        if token_num_pad > token_num:
            x_new = paddle.zeros([token_num_pad, x.shape[1]], x.dtype)
            x_new[:token_num, :] = x
            x = x_new
        out = paddle.zeros_like(x)
        paddle.distributed.alltoall(out, x, group=self.tp_group)
        out.reshape_([self.tp_size, -1, x.shape[1]])
        out = paddle.transpose(out, [1, 0, 2])
        out.reshape_([x.shape[0] // self.tp_size, self.input_size])
    return out
def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor: