mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
Support setting communication groups in custom_allreduce and the all-to-all/transpose fused operator during the decoding phase. (#5917)
This commit is contained in:
@@ -26,6 +26,7 @@ from fastdeploy.distributed.custom_all_reduce import cuda_wrapper
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
all_reduce,
|
||||
clear_ipc_handles,
|
||||
decode_alltoall_transpose,
|
||||
dispose,
|
||||
get_graph_buffer_ipc_meta,
|
||||
init_custom_all_reduce,
|
||||
@@ -164,6 +165,23 @@ class CustomAllreduce:
|
||||
all_reduce(inp, out, self._ptr, self.buffer_ptrs[self.rank], self.max_size)
|
||||
return out
|
||||
|
||||
def decode_alltoall_transpose(
    self,
    inp: paddle.Tensor,
    out: paddle.Tensor = None,
    registered: bool = False,
):
    """
    Run the fused all-to-all + transpose exchange used during the
    decoding phase, backed by the custom IPC communication buffers.

    Args:
        inp: Input tensor to exchange across ranks.
        out: Optional pre-allocated output tensor. When ``None``, a
            tensor with the same shape and dtype as ``inp`` is
            allocated via ``paddle.empty_like``.
        registered: When ``True``, no staging buffer pointer/size is
            passed to the underlying op (both are 0) — presumably the
            input already resides in a registered IPC buffer; verify
            against the custom op's contract. When ``False``, this
            rank's IPC staging buffer and its maximum size are passed,
            mirroring the ``all_reduce`` wrapper above.

    Returns:
        The output tensor holding the exchanged/transposed result.
    """
    if out is None:
        out = paddle.empty_like(inp)
    if registered:
        # Registered path: buffer pointer and size are 0 — the op is
        # expected to use already-registered memory directly.
        decode_alltoall_transpose(inp, out, self._ptr, 0, 0)
    else:
        # Unregistered path: stage through this rank's IPC buffer,
        # bounded by max_size (same pattern as all_reduce).
        decode_alltoall_transpose(inp, out, self._ptr, self.buffer_ptrs[self.rank], self.max_size)
    return out
|
||||
|
||||
def start_capture(self):
|
||||
"""
|
||||
set CUDA graph flag: True.
|
||||
|
||||
Reference in New Issue
Block a user