mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
Support setting communication groups in custom_allreduce and the all-to-all\transpose fused operator during the decoding phase. (#5917)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu
|
||||
// adapted from:
|
||||
// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu
|
||||
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
@@ -14,8 +15,8 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
#include "all_reduce.cuh"
|
||||
#include "helper.h"
|
||||
|
||||
// Fake pointer type, must match fptr_t type in ops.h.
|
||||
// We use this type alias to indicate when pointers are passed in as int64_t.
|
||||
@@ -23,8 +24,9 @@ using fptr_t = int64_t;
|
||||
static_assert(sizeof(void*) == sizeof(fptr_t));
|
||||
|
||||
fptr_t init_custom_all_reduce(const std::vector<fptr_t>& fake_ipc_ptrs,
|
||||
paddle::Tensor& rank_data, int64_t rank,
|
||||
bool full_nvlink) {
|
||||
paddle::Tensor& rank_data,
|
||||
int64_t rank,
|
||||
bool full_nvlink) {
|
||||
int world_size = fake_ipc_ptrs.size();
|
||||
if (world_size > 8)
|
||||
throw std::invalid_argument("world size > 8 is not supported");
|
||||
@@ -37,9 +39,71 @@ fptr_t init_custom_all_reduce(const std::vector<fptr_t>& fake_ipc_ptrs,
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
ipc_ptrs[i] = reinterpret_cast<paddle::Signal*>(fake_ipc_ptrs[i]);
|
||||
}
|
||||
return (fptr_t) new paddle::CustomAllreduce(ipc_ptrs, rank_data.data(),
|
||||
rank_data.numel(), rank, world_size,
|
||||
full_nvlink);
|
||||
return (fptr_t) new paddle::CustomAllreduce(ipc_ptrs,
|
||||
rank_data.data(),
|
||||
rank_data.numel(),
|
||||
rank,
|
||||
world_size,
|
||||
full_nvlink);
|
||||
}
|
||||
|
||||
/**
|
||||
* alltoall and transpose in decode.
|
||||
*/
|
||||
void decode_alltoall_transpose(paddle::Tensor& inp,
|
||||
paddle::Tensor& out,
|
||||
fptr_t _fa,
|
||||
fptr_t _reg_buffer,
|
||||
int64_t reg_buffer_sz_bytes) {
|
||||
auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
|
||||
auto stream = inp.stream();
|
||||
|
||||
auto input_size = inp.numel() * 2;
|
||||
auto token_num = inp.shape()[0];
|
||||
auto hidden_size = inp.shape()[1];
|
||||
auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
|
||||
if (reg_buffer) {
|
||||
cudaMemcpyAsync(
|
||||
reg_buffer, inp.data(), input_size, cudaMemcpyDeviceToDevice, stream);
|
||||
} else {
|
||||
reg_buffer = inp.data();
|
||||
}
|
||||
switch (out.dtype()) {
|
||||
case phi::DataType::FLOAT32: {
|
||||
fa->decode_alltoall_transpose<float>(stream,
|
||||
reinterpret_cast<float*>(reg_buffer),
|
||||
reinterpret_cast<float*>(out.data()),
|
||||
token_num,
|
||||
hidden_size,
|
||||
out.numel());
|
||||
break;
|
||||
}
|
||||
case phi::DataType::FLOAT16: {
|
||||
fa->decode_alltoall_transpose<half>(stream,
|
||||
reinterpret_cast<half*>(reg_buffer),
|
||||
reinterpret_cast<half*>(out.data()),
|
||||
token_num,
|
||||
hidden_size,
|
||||
out.numel());
|
||||
break;
|
||||
}
|
||||
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800)
|
||||
case phi::DataType::BFLOAT16: {
|
||||
fa->decode_alltoall_transpose<nv_bfloat16>(
|
||||
stream,
|
||||
reinterpret_cast<nv_bfloat16*>(reg_buffer),
|
||||
reinterpret_cast<nv_bfloat16*>(out.data()),
|
||||
token_num,
|
||||
hidden_size,
|
||||
out.numel());
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"decode_alltoall_transpose only supports float32, float16 and "
|
||||
"bfloat16");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -49,36 +113,43 @@ fptr_t init_custom_all_reduce(const std::vector<fptr_t>& fake_ipc_ptrs,
|
||||
* Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
|
||||
* copied into _reg_buffer.
|
||||
*/
|
||||
void all_reduce(paddle::Tensor& inp, paddle::Tensor& out, fptr_t _fa,
|
||||
fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
|
||||
void all_reduce(paddle::Tensor& inp,
|
||||
paddle::Tensor& out,
|
||||
fptr_t _fa,
|
||||
fptr_t _reg_buffer,
|
||||
int64_t reg_buffer_sz_bytes) {
|
||||
auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
|
||||
auto stream = inp.stream();
|
||||
|
||||
auto input_size = inp.numel() * 2;
|
||||
auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
|
||||
if (reg_buffer) {
|
||||
cudaMemcpyAsync(reg_buffer, inp.data(), input_size,
|
||||
cudaMemcpyDeviceToDevice, stream);
|
||||
cudaMemcpyAsync(
|
||||
reg_buffer, inp.data(), input_size, cudaMemcpyDeviceToDevice, stream);
|
||||
} else {
|
||||
reg_buffer = inp.data();
|
||||
}
|
||||
switch (out.dtype()) {
|
||||
case phi::DataType::FLOAT32: {
|
||||
fa->allreduce<float>(stream, reinterpret_cast<float*>(reg_buffer),
|
||||
fa->allreduce<float>(stream,
|
||||
reinterpret_cast<float*>(reg_buffer),
|
||||
reinterpret_cast<float*>(out.data()),
|
||||
out.numel());
|
||||
break;
|
||||
}
|
||||
case phi::DataType::FLOAT16: {
|
||||
fa->allreduce<half>(stream, reinterpret_cast<half*>(reg_buffer),
|
||||
reinterpret_cast<half*>(out.data()), out.numel());
|
||||
fa->allreduce<half>(stream,
|
||||
reinterpret_cast<half*>(reg_buffer),
|
||||
reinterpret_cast<half*>(out.data()),
|
||||
out.numel());
|
||||
break;
|
||||
}
|
||||
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800)
|
||||
case phi::DataType::BFLOAT16: {
|
||||
fa->allreduce<nv_bfloat16>(
|
||||
stream, reinterpret_cast<nv_bfloat16*>(reg_buffer),
|
||||
reinterpret_cast<nv_bfloat16*>(out.data()), out.numel());
|
||||
fa->allreduce<nv_bfloat16>(stream,
|
||||
reinterpret_cast<nv_bfloat16*>(reg_buffer),
|
||||
reinterpret_cast<nv_bfloat16*>(out.data()),
|
||||
out.numel());
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
@@ -132,11 +203,11 @@ void clear_ipc_handles(fptr_t _fa) {
|
||||
|
||||
std::tuple<fptr_t, paddle::Tensor> allocate_shared_buffer_and_handle(
|
||||
int64_t size) {
|
||||
|
||||
auto device_index = phi::backends::gpu::GetCurrentDeviceId();
|
||||
void* buffer;
|
||||
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
|
||||
auto stream = paddle::GetCurrentCUDAStream(phi::GPUPlace(device_index))->raw_stream();
|
||||
auto stream =
|
||||
paddle::GetCurrentCUDAStream(phi::GPUPlace(device_index))->raw_stream();
|
||||
CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
|
||||
|
||||
// Allocate buffer
|
||||
@@ -148,19 +219,20 @@ std::tuple<fptr_t, paddle::Tensor> allocate_shared_buffer_and_handle(
|
||||
// Create IPC memhandle for the allocated buffer.
|
||||
// Will use it in open_mem_handle.
|
||||
auto handle =
|
||||
paddle::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))}, paddle::DataType::UINT8, paddle::GPUPlace(device_index));
|
||||
CUDACHECK(
|
||||
cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data(), buffer));
|
||||
paddle::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))},
|
||||
paddle::DataType::UINT8,
|
||||
paddle::GPUPlace(device_index));
|
||||
CUDACHECK(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data(), buffer));
|
||||
|
||||
return std::make_tuple(reinterpret_cast<fptr_t>(buffer), handle);
|
||||
}
|
||||
|
||||
|
||||
fptr_t open_mem_handle(paddle::Tensor& mem_handle) {
|
||||
void* ipc_ptr;
|
||||
CUDACHECK(cudaIpcOpenMemHandle(
|
||||
(void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)mem_handle.data()),
|
||||
cudaIpcMemLazyEnablePeerAccess));
|
||||
CUDACHECK(
|
||||
cudaIpcOpenMemHandle((void**)&ipc_ptr,
|
||||
*((const cudaIpcMemHandle_t*)mem_handle.data()),
|
||||
cudaIpcMemLazyEnablePeerAccess));
|
||||
return reinterpret_cast<fptr_t>(ipc_ptr);
|
||||
}
|
||||
|
||||
@@ -168,11 +240,20 @@ void free_shared_buffer(fptr_t buffer) {
|
||||
CUDACHECK(cudaFree(reinterpret_cast<void*>(buffer)));
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(decode_alltoall_transpose)
|
||||
.Inputs({"inp", "out"})
|
||||
.Outputs({"new_out"})
|
||||
.Attrs({"_fa: int64_t",
|
||||
"_reg_buffer: int64_t",
|
||||
"reg_buffer_sz_bytes: int64_t"})
|
||||
.SetInplaceMap({{"out", "new_out"}})
|
||||
.SetKernelFn(PD_KERNEL(decode_alltoall_transpose));
|
||||
|
||||
PD_BUILD_STATIC_OP(all_reduce)
|
||||
.Inputs({"inp",
|
||||
"out"})
|
||||
.Inputs({"inp", "out"})
|
||||
.Outputs({"new_out"})
|
||||
.Attrs({"_fa: int64_t", "_reg_buffer: int64_t", "reg_buffer_sz_bytes: int64_t"})
|
||||
.Attrs({"_fa: int64_t",
|
||||
"_reg_buffer: int64_t",
|
||||
"reg_buffer_sz_bytes: int64_t"})
|
||||
.SetInplaceMap({{"out", "new_out"}})
|
||||
.SetKernelFn(PD_KERNEL(all_reduce));
|
||||
|
||||
@@ -208,12 +208,12 @@ DINLINE void multi_gpu_barrier(const RankSignals& sg,
|
||||
&self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x];
|
||||
if constexpr (need_fence) {
|
||||
st_flag_release(peer_counter_ptr, val);
|
||||
while (ld_flag_acquire(self_counter_ptr) != val)
|
||||
;
|
||||
while (ld_flag_acquire(self_counter_ptr) != val) {
|
||||
}
|
||||
} else {
|
||||
st_flag_volatile(peer_counter_ptr, val);
|
||||
while (ld_flag_volatile(self_counter_ptr) != val)
|
||||
;
|
||||
while (ld_flag_volatile(self_counter_ptr) != val) {
|
||||
}
|
||||
}
|
||||
}
|
||||
if constexpr (is_start || need_fence) __syncthreads();
|
||||
@@ -229,6 +229,38 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
|
||||
return downcast<P>(tmp);
|
||||
}
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1) decode_alltoall_transpose_kernel(
|
||||
RankData* _dp, // [tp_size, m / tp_size, part_hidden_size]
|
||||
RankSignals sg,
|
||||
Signal* self_sg,
|
||||
T* __restrict__ result, // [m / tp_size, part_hidden_size * tp_size]
|
||||
const int rank,
|
||||
const int token_num,
|
||||
const int hidden_size,
|
||||
const int size) {
|
||||
using P = typename packed_t<T>::P;
|
||||
using A = typename packed_t<T>::A;
|
||||
// note: we don't reorder the address so the accumulation order is the same
|
||||
// for all ranks, ensuring bitwise identical results
|
||||
const int hidden_size_p = hidden_size / packed_t<T>::P::size;
|
||||
const int part_hidden_size_p = hidden_size_p / ngpus;
|
||||
const int rank_token_id = token_num / ngpus * rank;
|
||||
auto dp = *_dp;
|
||||
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
|
||||
// alltoall and transpose
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
const int token_idx = idx / hidden_size_p;
|
||||
const int src_token_idx = token_idx + rank_token_id;
|
||||
const int src_rank = (idx % hidden_size_p) / part_hidden_size_p;
|
||||
const int src_idx =
|
||||
src_token_idx * part_hidden_size_p + (idx % part_hidden_size_p);
|
||||
((P*)result)[idx] = ((const P**)&dp.ptrs[0])[src_rank][src_idx];
|
||||
}
|
||||
multi_gpu_barrier<ngpus, false>(sg, self_sg, rank);
|
||||
}
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1)
|
||||
cross_device_reduce_1stage(RankData* _dp,
|
||||
@@ -461,6 +493,87 @@ class CustomAllreduce {
|
||||
graph_unreg_buffers_.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* alltoall and transpose in decode.
|
||||
*/
|
||||
template <typename T>
|
||||
void decode_alltoall_transpose(cudaStream_t stream,
|
||||
T* input,
|
||||
T* output,
|
||||
int token_num,
|
||||
int part_hidden_size,
|
||||
int size,
|
||||
int threads = 512,
|
||||
int block_limit = 36) {
|
||||
auto d = packed_t<T>::P::size;
|
||||
int hidden_size = part_hidden_size * world_size_;
|
||||
if (size % d != 0)
|
||||
throw std::runtime_error(
|
||||
"custom decode_alltoall_transpose currently requires input length to "
|
||||
"be multiple "
|
||||
"of " +
|
||||
std::to_string(d));
|
||||
if (size / d % world_size_ != 0)
|
||||
throw std::runtime_error(
|
||||
"custom decode_alltoall_transpose currently requires input length to "
|
||||
"be multiple "
|
||||
"of " +
|
||||
std::to_string(d) + " and " + std::to_string(world_size_));
|
||||
if (token_num % world_size_ != 0)
|
||||
throw std::runtime_error(
|
||||
"custom decode_alltoall_transpose currently requires input token_num "
|
||||
"to be multiple "
|
||||
"of " +
|
||||
std::to_string(world_size_));
|
||||
if (block_limit > kMaxBlocks)
|
||||
throw std::runtime_error("max supported block limit is " +
|
||||
std::to_string(kMaxBlocks) + ". Got " +
|
||||
std::to_string(block_limit));
|
||||
|
||||
RankData* ptrs;
|
||||
cudaStreamCaptureStatus status;
|
||||
CUDACHECK(cudaStreamIsCapturing(stream, &status));
|
||||
if (status == cudaStreamCaptureStatusActive) {
|
||||
ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
|
||||
graph_unreg_buffers_.push_back(input);
|
||||
} else {
|
||||
auto it = buffers_.find(input);
|
||||
if (it == buffers_.end())
|
||||
throw std::runtime_error(
|
||||
"buffer address " +
|
||||
std::to_string(reinterpret_cast<uint64_t>(input)) +
|
||||
" is not registered!");
|
||||
ptrs = it->second;
|
||||
}
|
||||
|
||||
size /= d;
|
||||
auto bytes = size * sizeof(typename packed_t<T>::P);
|
||||
int blocks = std::min(block_limit, (size + threads - 1) / threads);
|
||||
#define KL(ngpus, name) \
|
||||
name<T, ngpus><<<blocks, threads, 0, stream>>>( \
|
||||
ptrs, sg_, self_sg_, output, rank_, token_num, hidden_size, size);
|
||||
|
||||
#define REDUCE_CASE(ngpus) \
|
||||
case ngpus: { \
|
||||
KL(ngpus, decode_alltoall_transpose_kernel); \
|
||||
break; \
|
||||
}
|
||||
|
||||
switch (world_size_) {
|
||||
REDUCE_CASE(2)
|
||||
REDUCE_CASE(4)
|
||||
REDUCE_CASE(6)
|
||||
REDUCE_CASE(8)
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
|
||||
"gpus = " +
|
||||
std::to_string(world_size_));
|
||||
}
|
||||
#undef REDUCE_CASE
|
||||
#undef KL
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs allreduce, assuming input has already been registered.
|
||||
*
|
||||
|
||||
@@ -1827,7 +1827,6 @@ class FDConfig:
|
||||
elif self.scheduler_config.splitwise_role == "prefill":
|
||||
self.model_config.moe_phase = MoEPhase(phase="prefill")
|
||||
elif self.scheduler_config.splitwise_role == "decode":
|
||||
self._disable_sequence_parallel_moe_if_needed("PD's decode node")
|
||||
self.model_config.moe_phase = MoEPhase(phase="decode")
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -35,13 +35,16 @@ def capture_custom_allreduce():
|
||||
yield
|
||||
|
||||
|
||||
def use_custom_allreduce(custom_all_reduce_max_bytes: int = 8192 * 1024):
|
||||
hcg = fleet.get_hybrid_communicate_group()
|
||||
model_parallel_group = hcg.get_model_parallel_group()
|
||||
def use_custom_allreduce(
|
||||
tp_group: paddle.distributed.communication.group.Group = None, custom_all_reduce_max_bytes: int = 8192 * 1024
|
||||
):
|
||||
if tp_group is None:
|
||||
hcg = fleet.get_hybrid_communicate_group()
|
||||
tp_group = hcg.get_model_parallel_group()
|
||||
global _TP_AR
|
||||
from fastdeploy.distributed.custom_all_reduce import CustomAllreduce
|
||||
|
||||
_TP_AR = CustomAllreduce(model_parallel_group, custom_all_reduce_max_bytes)
|
||||
_TP_AR = CustomAllreduce(tp_group, custom_all_reduce_max_bytes)
|
||||
|
||||
|
||||
def custom_ar_clear_ipc_handles():
|
||||
@@ -86,6 +89,18 @@ try:
|
||||
dist.all_reduce(input_)
|
||||
return input_
|
||||
|
||||
@paddle.jit.marker.unified
|
||||
def decode_alltoall_transpose(
|
||||
input_: paddle.Tensor,
|
||||
out: paddle.Tensor = None,
|
||||
) -> paddle.Tensor:
|
||||
"""alltoall and transpose in decode."""
|
||||
if input_.shape[0] == 0:
|
||||
return input_
|
||||
global _TP_AR
|
||||
input_ = _TP_AR.decode_alltoall_transpose(input_, out)
|
||||
return input_
|
||||
|
||||
except:
|
||||
tensor_model_parallel_all_reduce = None
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from fastdeploy.distributed.custom_all_reduce import cuda_wrapper
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
all_reduce,
|
||||
clear_ipc_handles,
|
||||
decode_alltoall_transpose,
|
||||
dispose,
|
||||
get_graph_buffer_ipc_meta,
|
||||
init_custom_all_reduce,
|
||||
@@ -164,6 +165,23 @@ class CustomAllreduce:
|
||||
all_reduce(inp, out, self._ptr, self.buffer_ptrs[self.rank], self.max_size)
|
||||
return out
|
||||
|
||||
def decode_alltoall_transpose(
|
||||
self,
|
||||
inp: paddle.Tensor,
|
||||
out: paddle.Tensor = None,
|
||||
registered: bool = False,
|
||||
):
|
||||
"""
|
||||
alltoall and transpose in decode.
|
||||
"""
|
||||
if out is None:
|
||||
out = paddle.empty_like(inp)
|
||||
if registered:
|
||||
decode_alltoall_transpose(inp, out, self._ptr, 0, 0)
|
||||
else:
|
||||
decode_alltoall_transpose(inp, out, self._ptr, self.buffer_ptrs[self.rank], self.max_size)
|
||||
return out
|
||||
|
||||
def start_capture(self):
|
||||
"""
|
||||
set CUDA graph flag: True.
|
||||
|
||||
@@ -21,7 +21,10 @@ import paddle
|
||||
from paddle import nn
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
|
||||
from fastdeploy.distributed.communication import (
|
||||
decode_alltoall_transpose,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
|
||||
from fastdeploy.model_executor.utils import (
|
||||
default_weight_loader,
|
||||
@@ -888,15 +891,24 @@ class RowParallelLinear(LinearBase):
|
||||
def all2all_transpose(self, x: paddle.Tensor) -> paddle.Tensor:
|
||||
token_num = x.shape[0]
|
||||
token_num_pad = (token_num + self.tp_size - 1) // self.tp_size * self.tp_size
|
||||
if token_num_pad > token_num:
|
||||
x_new = paddle.zeros([token_num_pad, x.shape[1]], x.dtype)
|
||||
x_new[:token_num, :] = x
|
||||
x = x_new
|
||||
out = paddle.zeros_like(x)
|
||||
paddle.distributed.alltoall(out, x, group=self.tp_group)
|
||||
out.reshape_([self.tp_size, -1, x.shape[1]])
|
||||
out = paddle.transpose(out, [1, 0, 2])
|
||||
out.reshape_([x.shape[0] // self.tp_size, self.input_size])
|
||||
if self.fd_config.scheduler_config.splitwise_role == "decode":
|
||||
if not (token_num_pad > token_num):
|
||||
x_padded = x
|
||||
else:
|
||||
x_padded = paddle.zeros([token_num_pad, x.shape[1]], x.dtype)
|
||||
x_padded[:token_num] = x
|
||||
out = paddle.zeros([token_num_pad // self.tp_size, x.shape[1] * self.tp_size], x.dtype)
|
||||
decode_alltoall_transpose(x_padded, out)
|
||||
else:
|
||||
if token_num_pad > token_num:
|
||||
x_new = paddle.zeros([token_num_pad, x.shape[1]], x.dtype)
|
||||
x_new[:token_num, :] = x
|
||||
x = x_new
|
||||
out = paddle.zeros_like(x)
|
||||
paddle.distributed.alltoall(out, x, group=self.tp_group)
|
||||
out.reshape_([self.tp_size, -1, x.shape[1]])
|
||||
out = paddle.transpose(out, [1, 0, 2])
|
||||
out.reshape_([x.shape[0] // self.tp_size, self.input_size])
|
||||
return out
|
||||
|
||||
def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
|
||||
|
||||
@@ -64,7 +64,7 @@ class DcuWorker(GpuWorker):
|
||||
):
|
||||
from fastdeploy.distributed.communication import use_custom_allreduce
|
||||
|
||||
use_custom_allreduce()
|
||||
use_custom_allreduce(self.fd_config.parallel_config.tp_group)
|
||||
else:
|
||||
raise RuntimeError(f"Not support device type: {self.device_config.device}")
|
||||
|
||||
|
||||
@@ -75,7 +75,7 @@ class GpuWorker(WorkerBase):
|
||||
):
|
||||
from fastdeploy.distributed.communication import use_custom_allreduce
|
||||
|
||||
use_custom_allreduce()
|
||||
use_custom_allreduce(self.fd_config.parallel_config.tp_group)
|
||||
else:
|
||||
raise RuntimeError(f"Not support device type: {self.device_config.device}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user