diff --git a/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu b/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu index cb89cf79a2..ac7006d2c6 100644 --- a/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu +++ b/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu @@ -1,4 +1,5 @@ -// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu +// adapted from: +// https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu // Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. // @@ -14,8 +15,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "helper.h" #include "all_reduce.cuh" +#include "helper.h" // Fake pointer type, must match fptr_t type in ops.h. // We use this type alias to indicate when pointers are passed in as int64_t. @@ -23,8 +24,9 @@ using fptr_t = int64_t; static_assert(sizeof(void*) == sizeof(fptr_t)); fptr_t init_custom_all_reduce(const std::vector& fake_ipc_ptrs, - paddle::Tensor& rank_data, int64_t rank, - bool full_nvlink) { + paddle::Tensor& rank_data, + int64_t rank, + bool full_nvlink) { int world_size = fake_ipc_ptrs.size(); if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported"); @@ -37,9 +39,71 @@ fptr_t init_custom_all_reduce(const std::vector& fake_ipc_ptrs, for (int i = 0; i < world_size; i++) { ipc_ptrs[i] = reinterpret_cast(fake_ipc_ptrs[i]); } - return (fptr_t) new paddle::CustomAllreduce(ipc_ptrs, rank_data.data(), - rank_data.numel(), rank, world_size, - full_nvlink); + return (fptr_t) new paddle::CustomAllreduce(ipc_ptrs, + rank_data.data(), + rank_data.numel(), + rank, + world_size, + full_nvlink); +} + +/** + * alltoall and transpose in decode. + */ +void decode_alltoall_transpose(paddle::Tensor& inp, + paddle::Tensor& out, + fptr_t _fa, + fptr_t _reg_buffer, + int64_t reg_buffer_sz_bytes) { + auto fa = reinterpret_cast(_fa); + auto stream = inp.stream(); + + auto input_size = inp.numel() * 2; + auto token_num = inp.shape()[0]; + auto hidden_size = inp.shape()[1]; + auto reg_buffer = reinterpret_cast(_reg_buffer); + if (reg_buffer) { + cudaMemcpyAsync( + reg_buffer, inp.data(), input_size, cudaMemcpyDeviceToDevice, stream); + } else { + reg_buffer = inp.data(); + } + switch (out.dtype()) { + case phi::DataType::FLOAT32: { + fa->decode_alltoall_transpose(stream, + reinterpret_cast(reg_buffer), + reinterpret_cast(out.data()), + token_num, + hidden_size, + out.numel()); + break; + } + case phi::DataType::FLOAT16: { + fa->decode_alltoall_transpose(stream, + reinterpret_cast(reg_buffer), + reinterpret_cast(out.data()), + token_num, + hidden_size, + out.numel()); + break; + } +#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800) + case phi::DataType::BFLOAT16: { + fa->decode_alltoall_transpose( + stream, + reinterpret_cast(reg_buffer), + reinterpret_cast(out.data()), + token_num, + hidden_size, + out.numel()); + break; + } +#endif + default: + throw std::runtime_error( + "decode_alltoall_transpose only supports float32, float16 and " + "bfloat16"); + } } /** @@ -49,36 +113,43 @@ fptr_t init_custom_all_reduce(const std::vector& fake_ipc_ptrs, * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first * copied into _reg_buffer. */ -void all_reduce(paddle::Tensor& inp, paddle::Tensor& out, fptr_t _fa, - fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) { +void all_reduce(paddle::Tensor& inp, + paddle::Tensor& out, + fptr_t _fa, + fptr_t _reg_buffer, + int64_t reg_buffer_sz_bytes) { auto fa = reinterpret_cast(_fa); auto stream = inp.stream(); auto input_size = inp.numel() * 2; auto reg_buffer = reinterpret_cast(_reg_buffer); if (reg_buffer) { - cudaMemcpyAsync(reg_buffer, inp.data(), input_size, - cudaMemcpyDeviceToDevice, stream); + cudaMemcpyAsync( + reg_buffer, inp.data(), input_size, cudaMemcpyDeviceToDevice, stream); } else { reg_buffer = inp.data(); } switch (out.dtype()) { case phi::DataType::FLOAT32: { - fa->allreduce(stream, reinterpret_cast(reg_buffer), + fa->allreduce(stream, + reinterpret_cast(reg_buffer), reinterpret_cast(out.data()), out.numel()); break; } case phi::DataType::FLOAT16: { - fa->allreduce(stream, reinterpret_cast(reg_buffer), - reinterpret_cast(out.data()), out.numel()); + fa->allreduce(stream, + reinterpret_cast(reg_buffer), + reinterpret_cast(out.data()), + out.numel()); break; } #if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800) case phi::DataType::BFLOAT16: { - fa->allreduce( - stream, reinterpret_cast(reg_buffer), - reinterpret_cast(out.data()), out.numel()); + fa->allreduce(stream, + reinterpret_cast(reg_buffer), + reinterpret_cast(out.data()), + out.numel()); break; } #endif @@ -132,11 +203,11 @@ void clear_ipc_handles(fptr_t _fa) { std::tuple allocate_shared_buffer_and_handle( int64_t size) { - auto device_index = phi::backends::gpu::GetCurrentDeviceId(); void* buffer; cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed; - auto stream = paddle::GetCurrentCUDAStream(phi::GPUPlace(device_index))->raw_stream(); + auto stream = + paddle::GetCurrentCUDAStream(phi::GPUPlace(device_index))->raw_stream(); CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode)); // Allocate buffer @@ -148,19 +219,20 @@ std::tuple allocate_shared_buffer_and_handle( // Create IPC memhandle for the allocated buffer. // Will use it in open_mem_handle. auto handle = - paddle::empty({static_cast(sizeof(cudaIpcMemHandle_t))}, paddle::DataType::UINT8, paddle::GPUPlace(device_index)); - CUDACHECK( - cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data(), buffer)); + paddle::empty({static_cast(sizeof(cudaIpcMemHandle_t))}, + paddle::DataType::UINT8, + paddle::GPUPlace(device_index)); + CUDACHECK(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data(), buffer)); return std::make_tuple(reinterpret_cast(buffer), handle); } - fptr_t open_mem_handle(paddle::Tensor& mem_handle) { void* ipc_ptr; - CUDACHECK(cudaIpcOpenMemHandle( - (void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)mem_handle.data()), - cudaIpcMemLazyEnablePeerAccess)); + CUDACHECK( + cudaIpcOpenMemHandle((void**)&ipc_ptr, + *((const cudaIpcMemHandle_t*)mem_handle.data()), + cudaIpcMemLazyEnablePeerAccess)); return reinterpret_cast(ipc_ptr); } @@ -168,11 +240,20 @@ void free_shared_buffer(fptr_t buffer) { CUDACHECK(cudaFree(reinterpret_cast(buffer))); } +PD_BUILD_STATIC_OP(decode_alltoall_transpose) + .Inputs({"inp", "out"}) + .Outputs({"new_out"}) + .Attrs({"_fa: int64_t", + "_reg_buffer: int64_t", + "reg_buffer_sz_bytes: int64_t"}) + .SetInplaceMap({{"out", "new_out"}}) + .SetKernelFn(PD_KERNEL(decode_alltoall_transpose)); PD_BUILD_STATIC_OP(all_reduce) - .Inputs({"inp", - "out"}) + .Inputs({"inp", "out"}) .Outputs({"new_out"}) - .Attrs({"_fa: int64_t", "_reg_buffer: int64_t", "reg_buffer_sz_bytes: int64_t"}) + .Attrs({"_fa: int64_t", + "_reg_buffer: int64_t", + "reg_buffer_sz_bytes: int64_t"}) .SetInplaceMap({{"out", "new_out"}}) .SetKernelFn(PD_KERNEL(all_reduce)); diff --git a/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cuh b/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cuh index b17ece5903..744d0576f5 100644 --- a/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cuh +++ b/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cuh @@ -208,12 +208,12 @@ DINLINE void multi_gpu_barrier(const RankSignals& sg, &self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x]; if constexpr (need_fence) { st_flag_release(peer_counter_ptr, val); - while (ld_flag_acquire(self_counter_ptr) != val) - ; + while (ld_flag_acquire(self_counter_ptr) != val) { + } } else { st_flag_volatile(peer_counter_ptr, val); - while (ld_flag_volatile(self_counter_ptr) != val) - ; + while (ld_flag_volatile(self_counter_ptr) != val) { + } } } if constexpr (is_start || need_fence) __syncthreads(); @@ -229,6 +229,38 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) { return downcast

(tmp); } +template +__global__ void __launch_bounds__(512, 1) decode_alltoall_transpose_kernel( + RankData* _dp, // [tp_size, m / tp_size, part_hidden_size] + RankSignals sg, + Signal* self_sg, + T* __restrict__ result, // [m / tp_size, part_hidden_size * tp_size] + const int rank, + const int token_num, + const int hidden_size, + const int size) { + using P = typename packed_t::P; + using A = typename packed_t::A; + // note: we don't reorder the address so the accumulation order is the same + // for all ranks, ensuring bitwise identical results + const int hidden_size_p = hidden_size / packed_t::P::size; + const int part_hidden_size_p = hidden_size_p / ngpus; + const int rank_token_id = token_num / ngpus * rank; + auto dp = *_dp; + multi_gpu_barrier(sg, self_sg, rank); + // alltoall and transpose + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; + idx += gridDim.x * blockDim.x) { + const int token_idx = idx / hidden_size_p; + const int src_token_idx = token_idx + rank_token_id; + const int src_rank = (idx % hidden_size_p) / part_hidden_size_p; + const int src_idx = + src_token_idx * part_hidden_size_p + (idx % part_hidden_size_p); + ((P*)result)[idx] = ((const P**)&dp.ptrs[0])[src_rank][src_idx]; + } + multi_gpu_barrier(sg, self_sg, rank); +} + template __global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(RankData* _dp, @@ -461,6 +493,87 @@ class CustomAllreduce { graph_unreg_buffers_.clear(); } + /** + * alltoall and transpose in decode. + */ + template + void decode_alltoall_transpose(cudaStream_t stream, + T* input, + T* output, + int token_num, + int part_hidden_size, + int size, + int threads = 512, + int block_limit = 36) { + auto d = packed_t::P::size; + int hidden_size = part_hidden_size * world_size_; + if (size % d != 0) + throw std::runtime_error( + "custom decode_alltoall_transpose currently requires input length to " + "be multiple " + "of " + + std::to_string(d)); + if (size / d % world_size_ != 0) + throw std::runtime_error( + "custom decode_alltoall_transpose currently requires input length to " + "be multiple " + "of " + + std::to_string(d) + " and " + std::to_string(world_size_)); + if (token_num % world_size_ != 0) + throw std::runtime_error( + "custom decode_alltoall_transpose currently requires input token_num " + "to be multiple " + "of " + + std::to_string(world_size_)); + if (block_limit > kMaxBlocks) + throw std::runtime_error("max supported block limit is " + + std::to_string(kMaxBlocks) + ". Got " + + std::to_string(block_limit)); + + RankData* ptrs; + cudaStreamCaptureStatus status; + CUDACHECK(cudaStreamIsCapturing(stream, &status)); + if (status == cudaStreamCaptureStatusActive) { + ptrs = d_rank_data_base_ + graph_unreg_buffers_.size(); + graph_unreg_buffers_.push_back(input); + } else { + auto it = buffers_.find(input); + if (it == buffers_.end()) + throw std::runtime_error( + "buffer address " + + std::to_string(reinterpret_cast(input)) + + " is not registered!"); + ptrs = it->second; + } + + size /= d; + auto bytes = size * sizeof(typename packed_t::P); + int blocks = std::min(block_limit, (size + threads - 1) / threads); +#define KL(ngpus, name) \ + name<<>>( \ + ptrs, sg_, self_sg_, output, rank_, token_num, hidden_size, size); + +#define REDUCE_CASE(ngpus) \ + case ngpus: { \ + KL(ngpus, decode_alltoall_transpose_kernel); \ + break; \ + } + + switch (world_size_) { + REDUCE_CASE(2) + REDUCE_CASE(4) + REDUCE_CASE(6) + REDUCE_CASE(8) + default: + throw std::runtime_error( + "custom allreduce only supports num gpus in (2,4,6,8). Actual num " + "gpus = " + + std::to_string(world_size_)); + } +#undef REDUCE_CASE +#undef KL + } + /** * Performs allreduce, assuming input has already been registered. * diff --git a/fastdeploy/config.py b/fastdeploy/config.py index d5058d6b3c..8d5eaf33a5 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1827,7 +1827,6 @@ class FDConfig: elif self.scheduler_config.splitwise_role == "prefill": self.model_config.moe_phase = MoEPhase(phase="prefill") elif self.scheduler_config.splitwise_role == "decode": - self._disable_sequence_parallel_moe_if_needed("PD's decode node") self.model_config.moe_phase = MoEPhase(phase="decode") else: raise NotImplementedError diff --git a/fastdeploy/distributed/communication.py b/fastdeploy/distributed/communication.py index 039f545c32..d5b72eebda 100644 --- a/fastdeploy/distributed/communication.py +++ b/fastdeploy/distributed/communication.py @@ -35,13 +35,16 @@ def capture_custom_allreduce(): yield -def use_custom_allreduce(custom_all_reduce_max_bytes: int = 8192 * 1024): - hcg = fleet.get_hybrid_communicate_group() - model_parallel_group = hcg.get_model_parallel_group() +def use_custom_allreduce( + tp_group: paddle.distributed.communication.group.Group = None, custom_all_reduce_max_bytes: int = 8192 * 1024 +): + if tp_group is None: + hcg = fleet.get_hybrid_communicate_group() + tp_group = hcg.get_model_parallel_group() global _TP_AR from fastdeploy.distributed.custom_all_reduce import CustomAllreduce - _TP_AR = CustomAllreduce(model_parallel_group, custom_all_reduce_max_bytes) + _TP_AR = CustomAllreduce(tp_group, custom_all_reduce_max_bytes) def custom_ar_clear_ipc_handles(): @@ -86,6 +89,18 @@ try: dist.all_reduce(input_) return input_ + @paddle.jit.marker.unified + def decode_alltoall_transpose( + input_: paddle.Tensor, + out: paddle.Tensor = None, + ) -> paddle.Tensor: + """alltoall and transpose in decode.""" + if input_.shape[0] == 0: + return input_ + global _TP_AR + input_ = _TP_AR.decode_alltoall_transpose(input_, out) + return input_ + except: tensor_model_parallel_all_reduce = None diff --git a/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py b/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py index 0c9be796ce..4c081271a1 100644 --- a/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py +++ b/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py @@ -26,6 +26,7 @@ from fastdeploy.distributed.custom_all_reduce import cuda_wrapper from fastdeploy.model_executor.ops.gpu import ( all_reduce, clear_ipc_handles, + decode_alltoall_transpose, dispose, get_graph_buffer_ipc_meta, init_custom_all_reduce, @@ -164,6 +165,23 @@ class CustomAllreduce: all_reduce(inp, out, self._ptr, self.buffer_ptrs[self.rank], self.max_size) return out + def decode_alltoall_transpose( + self, + inp: paddle.Tensor, + out: paddle.Tensor = None, + registered: bool = False, + ): + """ + alltoall and transpose in decode. + """ + if out is None: + out = paddle.empty_like(inp) + if registered: + decode_alltoall_transpose(inp, out, self._ptr, 0, 0) + else: + decode_alltoall_transpose(inp, out, self._ptr, self.buffer_ptrs[self.rank], self.max_size) + return out + def start_capture(self): """ set CUDA graph flag: True. diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index 582432f76b..6aadb2d3ba 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -21,7 +21,10 @@ import paddle from paddle import nn from fastdeploy.config import FDConfig -from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce +from fastdeploy.distributed.communication import ( + decode_alltoall_transpose, + tensor_model_parallel_all_reduce, +) from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase from fastdeploy.model_executor.utils import ( default_weight_loader, @@ -888,15 +891,24 @@ class RowParallelLinear(LinearBase): def all2all_transpose(self, x: paddle.Tensor) -> paddle.Tensor: token_num = x.shape[0] token_num_pad = (token_num + self.tp_size - 1) // self.tp_size * self.tp_size - if token_num_pad > token_num: - x_new = paddle.zeros([token_num_pad, x.shape[1]], x.dtype) - x_new[:token_num, :] = x - x = x_new - out = paddle.zeros_like(x) - paddle.distributed.alltoall(out, x, group=self.tp_group) - out.reshape_([self.tp_size, -1, x.shape[1]]) - out = paddle.transpose(out, [1, 0, 2]) - out.reshape_([x.shape[0] // self.tp_size, self.input_size]) + if self.fd_config.scheduler_config.splitwise_role == "decode": + if not (token_num_pad > token_num): + x_padded = x + else: + x_padded = paddle.zeros([token_num_pad, x.shape[1]], x.dtype) + x_padded[:token_num] = x + out = paddle.zeros([token_num_pad // self.tp_size, x.shape[1] * self.tp_size], x.dtype) + decode_alltoall_transpose(x_padded, out) + else: + if token_num_pad > token_num: + x_new = paddle.zeros([token_num_pad, x.shape[1]], x.dtype) + x_new[:token_num, :] = x + x = x_new + out = paddle.zeros_like(x) + paddle.distributed.alltoall(out, x, group=self.tp_group) + out.reshape_([self.tp_size, -1, x.shape[1]]) + out = paddle.transpose(out, [1, 0, 2]) + out.reshape_([x.shape[0] // self.tp_size, self.input_size]) return out def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor: diff --git a/fastdeploy/worker/dcu_worker.py b/fastdeploy/worker/dcu_worker.py index 9a9b3eebe5..6d17da11af 100644 --- a/fastdeploy/worker/dcu_worker.py +++ b/fastdeploy/worker/dcu_worker.py @@ -64,7 +64,7 @@ class DcuWorker(GpuWorker): ): from fastdeploy.distributed.communication import use_custom_allreduce - use_custom_allreduce() + use_custom_allreduce(self.fd_config.parallel_config.tp_group) else: raise RuntimeError(f"Not support device type: {self.device_config.device}") diff --git a/fastdeploy/worker/gpu_worker.py b/fastdeploy/worker/gpu_worker.py index 0d57ccf250..70fc146b6f 100644 --- a/fastdeploy/worker/gpu_worker.py +++ b/fastdeploy/worker/gpu_worker.py @@ -75,7 +75,7 @@ class GpuWorker(WorkerBase): ): from fastdeploy.distributed.communication import use_custom_allreduce - use_custom_allreduce() + use_custom_allreduce(self.fd_config.parallel_config.tp_group) else: raise RuntimeError(f"Not support device type: {self.device_config.device}")