[PD Disaggregation] Support PD deployment of DeepSeekv3. (#5251)

* Support deepseekv3 cache transfer for PD deploy

* clean some log info

---------

Co-authored-by: K11OntheBoat <“ruianmaidanglao@163.com”>
This commit is contained in:
K11OntheBoat
2025-12-02 14:11:50 +08:00
committed by GitHub
parent 117980dd4e
commit 2e1680838f
17 changed files with 620 additions and 400 deletions
@@ -13,22 +13,24 @@
// limitations under the License.
#pragma once
#include "helper.h"
#include "mla_cache_kernel.cuh"
#include "helper.h"
#include "remote_cache_kv_ipc.h"
template <paddle::DataType T>
std::vector<paddle::Tensor> PrefillMLAWriteCache(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
const AppendAttnMetaData& meta_data,
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
@@ -50,8 +52,10 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
prefill_absorb_cache_kernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
@@ -65,6 +69,33 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
pe_size,
block_size,
elem_nums);
const char* fmt_write_cache_completed_signal_str =
std::getenv("FLAGS_fmt_write_cache_completed_signal");
const char* FLAGS_use_pd_disaggregation_per_chunk =
std::getenv("FLAGS_use_pd_disaggregation_per_chunk");
if (fmt_write_cache_completed_signal_str &&
(std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 ||
std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) {
if (FLAGS_use_pd_disaggregation_per_chunk &&
(std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 ||
std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) {
cudaLaunchHostFunc(
stream,
&(RemoteCacheKvIpc::
save_cache_kv_complete_signal_layerwise_per_query),
(void*)nullptr);
} else {
if (kv_signal_data) {
cudaLaunchHostFunc(
stream,
&RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise,
(void*)(const_cast<int64_t*>(
kv_signal_data.get().data<int64_t>())));
}
}
}
return {};
}
@@ -77,6 +108,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const std::string& cache_quant_type_str,
const int max_seq_len) {
cudaStream_t stream = kv_pe.stream();
@@ -85,7 +117,8 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
const auto& kv_pe_dims = kv_pe.dims();
const auto& kv_cache_dims = kv_cache.dims();
meta_data.kv_num_heads = kv_cache_dims[1];
const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
const auto nope_size =
kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
meta_data.token_nums = kv_nope_dims[0];
meta_data.head_dims = kv_cache_dims[3];
meta_data.head_dims_v = nope_size;
@@ -95,30 +128,34 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
meta_data.batch_size = seq_lens_decoder.dims()[0];
switch (kv_pe.dtype()) {
case paddle::DataType::BFLOAT16: {
return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(
meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
kv_signal_data,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
case paddle::DataType::FLOAT16: {
return PrefillMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
return PrefillMLAWriteCache<paddle::DataType::FLOAT16>(
meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
kv_signal_data,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
}
return {};
@@ -126,18 +163,18 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
template <paddle::DataType T>
std::vector<paddle::Tensor> DecodeMLAWriteCache(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const int max_seq_len,
const bool speculate_decoder,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
const AppendAttnMetaData& meta_data,
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const int max_seq_len,
const bool speculate_decoder,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
@@ -154,15 +191,16 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
const int blocksize = 128;
int grid_size = 1;
if (speculate_decoder) {
const uint32_t elem_nums = token_num * kv_num_heads * all_size;
const int pack_num = elem_nums / PackSize;
GetNumBlocks<128>(pack_num, &grid_size);
speculate_decode_absorb_cache_kernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
@@ -182,8 +220,10 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
GetNumBlocks<128>(pack_num, &grid_size);
decode_absorb_cache_kernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_nope.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
@@ -218,7 +258,8 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const auto& kv_pe_dims = kv_pe.dims();
const auto& kv_cache_dims = kv_cache.dims();
meta_data.kv_num_heads = kv_cache_dims[1];
const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
const auto nope_size =
kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
meta_data.token_nums = kv_nope_dims[0];
meta_data.head_dims = kv_cache_dims[3];
meta_data.head_dims_v = nope_size;
@@ -228,38 +269,39 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
meta_data.batch_size = seq_lens_encoder.dims()[0];
switch (kv_pe.dtype()) {
case paddle::DataType::BFLOAT16: {
return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
speculate_decoder,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(
meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
speculate_decoder,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
case paddle::DataType::FLOAT16: {
return DecodeMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
speculate_decoder,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
return DecodeMLAWriteCache<paddle::DataType::FLOAT16>(
meta_data,
kv_nope,
kv_pe,
seq_lens,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
speculate_decoder,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
}
return {};
}
PD_BUILD_STATIC_OP(prefill_mla_write_cache)
.Inputs({"kv_nope",
"kv_pe",
@@ -268,11 +310,11 @@ PD_BUILD_STATIC_OP(prefill_mla_write_cache)
"seq_lens_decoder",
"batch_id_per_token",
"cu_seqlens_q",
"block_tables"})
"block_tables",
paddle::Optional("kv_signal_data")})
.Outputs({"kv_cache_out"})
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
.Attrs({"cache_quant_type_str: std::string",
"max_seq_len: int"})
.Attrs({"cache_quant_type_str: std::string", "max_seq_len: int"})
.SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel));
PD_BUILD_STATIC_OP(decode_mla_write_cache)
+1
View File
@@ -527,6 +527,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const std::string& cache_quant_type_str,
const int max_seq_len);
+276 -222
View File
@@ -13,8 +13,8 @@
// limitations under the License.
/*
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
* Dao. Licensed under the BSD 3-Clause.
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar,
* Pradeep Ramani, Tri Dao. Licensed under the BSD 3-Clause.
*
* Modified by the FlashInfer team.
*/
@@ -39,8 +39,8 @@
#include "epilogue.cuh"
#include "helper.h"
#include "kernel_traits.cuh"
#include "mainloop_mma.cuh"
#include "mainloop_load.cuh"
#include "mainloop_mma.cuh"
#include "utils.cuh"
#ifdef DEBUG_MLA
@@ -52,76 +52,91 @@ namespace mla_attn {
using namespace cute;
template <typename DTypeQ_, typename DTypeKV_, typename DTypeO_, typename IdType_>
template <typename DTypeQ_,
typename DTypeKV_,
typename DTypeO_,
typename IdType_>
struct Params {
using DTypeQ = DTypeQ_;
using DTypeKV = DTypeKV_;
using DTypeO = DTypeO_;
using IdType = IdType_;
using DTypeQ = DTypeQ_;
using DTypeKV = DTypeKV_;
using DTypeO = DTypeO_;
using IdType = IdType_;
alignas(16) DTypeQ *Q; // [token_num, head_num, dim_head]
alignas(16) DTypeKV *KV; // [max_block_num, block_size, dim_head]
alignas(16) DTypeO *O; // [token_num, head_num, dim_head]
alignas(16) DTypeO *O_tmp; // [max_num_chunks, bsz, head_num, dim_head]
alignas(16) float *m; // [max_num_chunks, bsz * max_draft_token_num * head_num]
alignas(16) float *d; // [max_num_chunks, bsz * max_draft_token_num * head_num]
alignas(16) DTypeQ *Q; // [token_num, head_num, dim_head]
alignas(16) DTypeKV *KV; // [max_block_num, block_size, dim_head]
alignas(16) DTypeO *O; // [token_num, head_num, dim_head]
alignas(16) DTypeO *O_tmp; // [max_num_chunks, bsz, head_num, dim_head]
alignas(
16) float *m; // [max_num_chunks, bsz * max_draft_token_num * head_num]
alignas(
16) float *d; // [max_num_chunks, bsz * max_draft_token_num * head_num]
alignas(16) IdType *block_tables;
alignas(16) IdType *seq_lens_this_time;
alignas(16) IdType *seq_lens_decoder;
alignas(16) IdType *cumsum_q_seqlens;
alignas(16) IdType *batch_id_per_token;
alignas(16) IdType *block_tables;
alignas(16) IdType *seq_lens_this_time;
alignas(16) IdType *seq_lens_decoder;
alignas(16) IdType *cumsum_q_seqlens;
alignas(16) IdType *batch_id_per_token;
alignas(16) IdType *batch_ids;
alignas(16) IdType *tile_ids_per_batch;
alignas(16) IdType *num_blocks_x;
alignas(16) IdType *chunk_size_device;
alignas(16) IdType *batch_ids;
alignas(16) IdType *tile_ids_per_batch;
alignas(16) IdType *num_blocks_x;
alignas(16) IdType *chunk_size_device;
uint32_t q_stride_bsz;
uint32_t q_stride_head_num;
uint32_t q_stride_bsz;
uint32_t q_stride_head_num;
uint32_t kv_stride_block_num;
uint32_t kv_stride_block_size;
uint32_t kv_stride_block_num;
uint32_t kv_stride_block_size;
uint32_t o_stride_bsz;
uint32_t o_stride_head_num;
uint32_t o_stride_bsz;
uint32_t o_stride_head_num;
int bsz;
int token_num;
int max_block_num;
int max_block_num_per_seq;
int q_num_head;
int qk_head_dim;
int vo_head_dim;
int block_size;
int max_draft_token_num;
int chunk_num;
int bsz;
int token_num;
int max_block_num;
int max_block_num_per_seq;
int q_num_head;
int qk_head_dim;
int vo_head_dim;
int block_size;
int max_draft_token_num;
int chunk_num;
float sm_scale;
float sm_scale;
};
#define DISPATCH_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
if (group_size == 8) { \
constexpr size_t GROUP_SIZE = 8; \
__VA_ARGS__ \
} else if (group_size == 16) { \
constexpr size_t GROUP_SIZE = 16; \
__VA_ARGS__ \
} else if (group_size == 64) { \
constexpr size_t GROUP_SIZE = 64; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the group_size: ", group_size); \
return cudaErrorNotSupported; \
#define DISPATCH_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
if (group_size == 8) { \
constexpr size_t GROUP_SIZE = 8; \
__VA_ARGS__ \
} else if (group_size == 16) { \
constexpr size_t GROUP_SIZE = 16; \
__VA_ARGS__ \
} else if (group_size == 64) { \
constexpr size_t GROUP_SIZE = 64; \
__VA_ARGS__ \
} else if (group_size == 128) { \
constexpr size_t GROUP_SIZE = 128; \
__VA_ARGS__ \
} else { \
PD_THROW("not support the group_size: ", group_size); \
return cudaErrorNotSupported; \
}
template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL, int SM_COUNT = 132, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
__global__ void __launch_bounds__(Ktraits::NUM_WARPS * cutlass::NumThreadsPerWarp, 1)
MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
typename CollectiveMainloop::Params const mainloop_params,
CUTE_GRID_CONSTANT
typename CollectiveEpilogue::Params const epilogue_params) {
template <typename CollectiveMainloop,
typename CollectiveEpilogue,
typename Ktraits,
bool CAUSAL,
int SM_COUNT = 132,
bool USE_REG_EALLOC = false,
bool USE_FIXED_BLOCK = true>
__global__ void __launch_bounds__(
Ktraits::NUM_WARPS *cutlass::NumThreadsPerWarp, 1)
MLAWithKVCacheKernel(
CUTE_GRID_CONSTANT
typename CollectiveMainloop::Params const mainloop_params,
CUTE_GRID_CONSTANT
typename CollectiveEpilogue::Params const epilogue_params) {
using DTypeQ = typename Ktraits::DTypeQ;
using DTypeKV = typename Ktraits::DTypeKV;
using DTypeO = typename Ktraits::DTypeO;
@@ -147,7 +162,8 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
using PipelineStateQ = typename MainloopPipelineQ::PipelineState;
extern __shared__ char shared_memory[];
auto& shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
auto &shared_storage =
*reinterpret_cast<typename Ktraits::SharedStorage *>(shared_memory);
int const lane_predicate = cute::elect_one_sync();
int const warp_idx = cutlass::canonical_warp_idx_sync();
@@ -158,12 +174,14 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
}
// Obtain warp index
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
int const warp_group_thread_idx =
threadIdx.x % cutlass::NumThreadsPerWarpGroup;
PipelineParams pipeline_params;
int warp_group_idx = cutlass::canonical_warp_group_idx();
pipeline_params.role = warp_group_idx == 0 ? MainloopPipeline::ThreadCategory::Producer
: MainloopPipeline::ThreadCategory::Consumer;
pipeline_params.role = warp_group_idx == 0
? MainloopPipeline::ThreadCategory::Producer
: MainloopPipeline::ThreadCategory::Consumer;
if constexpr (use_tma_load_kv) {
pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = NUM_MMA_THREADS;
@@ -173,17 +191,20 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
}
PipelineParamsQ pipeline_params_q;
pipeline_params_q.role = warp_group_idx == 0 ? MainloopPipelineQ::ThreadCategory::Producer
: MainloopPipelineQ::ThreadCategory::Consumer;
pipeline_params_q.role = warp_group_idx == 0
? MainloopPipelineQ::ThreadCategory::Producer
: MainloopPipelineQ::ThreadCategory::Consumer;
pipeline_params_q.producer_arv_count = NUM_COPY_THREADS;
pipeline_params_q.consumer_arv_count = cutlass::NumThreadsPerWarpGroup; // just one wg qk
pipeline_params_q.consumer_arv_count =
cutlass::NumThreadsPerWarpGroup; // just one wg qk
MainloopPipelineQ pipeline_q(shared_storage.pipeline_q, pipeline_params_q);
MainloopPipeline pipeline_kv = [&] {
if constexpr (use_tma_load_kv) {
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesKV;
return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params,
pipeline_params.transaction_bytes =
CollectiveMainloop::TmaTransactionBytesKV;
return MainloopPipeline(shared_storage.pipeline_kv,
pipeline_params,
/*cluster_shape=*/Shape<_1, _1, _1>{});
} else {
return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params);
@@ -196,191 +217,217 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
if (warp_group_idx == 0) {
// producer
if constexpr(USE_REG_EALLOC) {
if constexpr (USE_REG_EALLOC) {
cutlass::arch::warpgroup_reg_dealloc<72>();
}
const uint32_t warp_idx_in_warpgroup = __shfl_sync(0xffffffff, warp_idx % 4, 0);
const uint32_t warp_idx_in_warpgroup =
__shfl_sync(0xffffffff, warp_idx % 4, 0);
PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipelineQ>();
PipelineState smem_pipe_write_kv = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineStateQ smem_pipe_write_q =
cutlass::make_producer_start_state<MainloopPipelineQ>();
PipelineState smem_pipe_write_kv =
cutlass::make_producer_start_state<MainloopPipeline>();
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
const int bid = mainloop_params.batch_ids[i];
const int tile_id = mainloop_params.tile_ids_per_batch[i];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int seq_len_decoder_now =
mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
cutlass::arch::NamedBarrier::sync(
Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
// load Q
collective_mainloop.load_q(
mainloop_params,
pipeline_q,
smem_pipe_write_q,
shared_storage,
threadIdx.x,
bid);
collective_mainloop.load_q(mainloop_params,
pipeline_q,
smem_pipe_write_q,
shared_storage,
threadIdx.x,
bid);
if constexpr (!use_tma_load_kv) {
// load kv
collective_mainloop.load_kv(
mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id
);
collective_mainloop.load_kv(mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id);
} else {
if (warp_idx_in_warpgroup == 0) {
// load kv tma
collective_mainloop.load_kv_tma(
mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id
);
collective_mainloop.load_kv_tma(mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id);
}
}
}
} else {
// consumer
if constexpr(USE_REG_EALLOC) {
if constexpr (USE_REG_EALLOC) {
cutlass::arch::warpgroup_reg_alloc<216>();
}
PipelineStateQ smem_pipe_read_q;
PipelineState smem_pipe_read_kv;
typename Ktraits::TiledMmaPVSS tiled_mma_pv;
Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));
Tensor tOrO =
partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));
auto attention_updater = OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(mainloop_params.sm_scale);
auto attention_updater =
OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(
mainloop_params.sm_scale);
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
clear(tOrO);
clear(attention_updater.scores_scale);
const int bid = mainloop_params.batch_ids[i];
const int tile_id = mainloop_params.tile_ids_per_batch[i];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int seq_len_decoder_now =
mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
cutlass::arch::NamedBarrier::sync(
Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
if constexpr (BLOCK_SHAPE_KV == 64) {
mma_f16<Ktraits, CAUSAL>(
mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
mma_f16<Ktraits, CAUSAL>(mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
} else if (BLOCK_SHAPE_KV == 32) {
mma_f16_two_stages<Ktraits, CAUSAL>(
mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
mma_f16_two_stages<Ktraits, CAUSAL>(mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
}
collective_epilogue.store(
epilogue_params,
tOrO,
attention_updater.get_lse(),
shared_storage,
tiled_mma_pv,
threadIdx.x - NUM_COPY_THREADS,
bid,
mainloop_params.bsz,
seq_len_now,
start_token_idx,
tile_id,
seq_len_decoder_now,
chunk_size,
mainloop_params.max_draft_token_num,
mainloop_params.o_stride_bsz);
}
collective_epilogue.store(epilogue_params,
tOrO,
attention_updater.get_lse(),
shared_storage,
tiled_mma_pv,
threadIdx.x - NUM_COPY_THREADS,
bid,
mainloop_params.bsz,
seq_len_now,
start_token_idx,
tile_id,
seq_len_decoder_now,
chunk_size,
mainloop_params.max_draft_token_num,
mainloop_params.o_stride_bsz);
}
}
}
template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
cudaStream_t stream) {
template <typename KernelTraits,
bool CAUSAL,
typename Params,
bool USE_REG_EALLOC = false,
bool USE_FIXED_BLOCK = true>
cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(
Params &params, cudaStream_t stream) {
using DTypeQ = typename KernelTraits::DTypeQ;
using DTypeKV = typename KernelTraits::DTypeKV;
using DTypeO = typename KernelTraits::DTypeO;
using IdType = typename KernelTraits::IdType;
using NV_TYPE = typename KernelTraits::NV_TYPE;
using CollectiveMainloop =
CollectiveMainloop<KernelTraits, CAUSAL>;
using CollectiveMainloop = CollectiveMainloop<KernelTraits, CAUSAL>;
using CollectiveEpilogue = CollectiveEpilogue<KernelTraits>;
typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments({
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.qk_head_dim), make_stride(params.qk_head_dim, _1{})), // layout q
make_layout(make_shape(params.block_size, params.qk_head_dim, params.max_block_num), make_stride(params.qk_head_dim, _1{}, params.block_size * params.qk_head_dim)),
make_layout(make_shape(params.chunk_num, params.bsz * params.max_draft_token_num * params.q_num_head), make_stride(params.bsz * params.max_draft_token_num * params.q_num_head, _1{})),
params.Q,
params.KV,
params.m,
params.d,
params.block_tables,
params.seq_lens_this_time,
params.seq_lens_decoder,
params.cumsum_q_seqlens,
params.batch_ids,
params.tile_ids_per_batch,
params.num_blocks_x,
params.chunk_size_device,
params.sm_scale,
params.bsz,
params.max_block_num,
params.max_block_num_per_seq,
params.q_stride_bsz,
params.q_stride_head_num,
params.kv_stride_block_num,
params.kv_stride_block_size,
params.o_stride_bsz,
params.o_stride_head_num,
params.chunk_num,
params.max_draft_token_num
});
typename CollectiveEpilogue::Params epilogue_params = CollectiveEpilogue::to_underlying_arguments_ntma({
params.O,
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{})), // layout O
params.O_tmp,
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{})) // layout O_tmp
});
typename CollectiveMainloop::Params mainloop_params =
CollectiveMainloop::to_underlying_arguments(
{make_layout(
make_shape(KernelTraits::BLOCK_SHAPE_Q, params.qk_head_dim),
make_stride(params.qk_head_dim, _1{})), // layout q
make_layout(
make_shape(
params.block_size, params.qk_head_dim, params.max_block_num),
make_stride(params.qk_head_dim,
_1{},
params.block_size * params.qk_head_dim)),
make_layout(make_shape(params.chunk_num,
params.bsz * params.max_draft_token_num *
params.q_num_head),
make_stride(params.bsz * params.max_draft_token_num *
params.q_num_head,
_1{})),
params.Q,
params.KV,
params.m,
params.d,
params.block_tables,
params.seq_lens_this_time,
params.seq_lens_decoder,
params.cumsum_q_seqlens,
params.batch_ids,
params.tile_ids_per_batch,
params.num_blocks_x,
params.chunk_size_device,
params.sm_scale,
params.bsz,
params.max_block_num,
params.max_block_num_per_seq,
params.q_stride_bsz,
params.q_stride_head_num,
params.kv_stride_block_num,
params.kv_stride_block_size,
params.o_stride_bsz,
params.o_stride_head_num,
params.chunk_num,
params.max_draft_token_num});
typename CollectiveEpilogue::Params epilogue_params =
CollectiveEpilogue::to_underlying_arguments_ntma({
params.O,
make_layout(
make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim),
make_stride(params.vo_head_dim, _1{})), // layout O
params.O_tmp,
make_layout(
make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim),
make_stride(params.vo_head_dim, _1{})) // layout O_tmp
});
// Get the ptr to kernel function.
auto kernel =
MLAWithKVCacheKernel<CollectiveMainloop, CollectiveEpilogue, KernelTraits, CAUSAL, 132>;
auto kernel = MLAWithKVCacheKernel<CollectiveMainloop,
CollectiveEpilogue,
KernelTraits,
CAUSAL,
132>;
int smem_size = sizeof(typename KernelTraits::SharedStorage);
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
int device;
cudaGetDevice(&device);
int multiprocessor_count;
cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
cudaDeviceGetAttribute(
&multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
int act_blocks_per_sm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&act_blocks_per_sm, kernel, KernelTraits::NUM_WARPS * 32, smem_size);
@@ -390,15 +437,15 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
dim3 grid_dims = {multiprocessor_count, 1, 1};
static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32;
dim3 block_dims(ctaSize, 1, 1);
kernel<<<grid_dims, block_dims, smem_size, stream>>>(
mainloop_params, epilogue_params
);
kernel<<<grid_dims, block_dims, smem_size, stream>>>(mainloop_params,
epilogue_params);
if (params.chunk_num > 1) {
constexpr int vec_size = 16 / sizeof(DTypeO);
constexpr int merge_block_size = 256;
constexpr int blockx = KernelTraits::HEAD_DIM_VO / vec_size;
constexpr int blocky = (merge_block_size + blockx - 1) / blockx;
dim3 grids_merge(multiprocessor_count, params.q_num_head); // 128k is too large
dim3 grids_merge(multiprocessor_count,
params.q_num_head); // 128k is too large
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_kernel<NV_TYPE,
vec_size,
@@ -423,28 +470,35 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
return cudaSuccess;
}
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename NV_TYPE, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
cudaError_t BatchMLAWithPagedKVCacheDispatched(Params& params, cudaStream_t stream) {
template <uint32_t HEAD_DIM_QK,
uint32_t HEAD_DIM_VO,
typename NV_TYPE,
typename Params,
bool USE_REG_EALLOC = false,
bool USE_FIXED_BLOCK = true>
cudaError_t BatchMLAWithPagedKVCacheDispatched(Params &params,
cudaStream_t stream) {
constexpr bool CAUSAL = true;
if constexpr (HEAD_DIM_QK == 576) {
DISPATCH_GROUP_SIZE(params.q_num_head, GROUP_SIZE,
BatchMLAWithPagedKVCacheKernelTraitsDispatched<
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/true,
HEAD_DIM_QK,
HEAD_DIM_VO,
GROUP_SIZE,
/*BLOCK_SHAPE_Q_=*/64,
/*BLOCK_SHAPE_KV_=*/64,
/*NUM_STAGES_=*/2,
typename Params::DTypeQ,
typename Params::DTypeKV,
typename Params::DTypeO,
typename Params::IdType,
NV_TYPE>,
CAUSAL,
Params,
USE_REG_EALLOC,
USE_FIXED_BLOCK>(params, stream);)
DISPATCH_GROUP_SIZE(params.q_num_head,
GROUP_SIZE,
BatchMLAWithPagedKVCacheKernelTraitsDispatched<
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/true,
HEAD_DIM_QK,
HEAD_DIM_VO,
GROUP_SIZE,
/*BLOCK_SHAPE_Q_=*/64,
/*BLOCK_SHAPE_KV_=*/64,
/*NUM_STAGES_=*/2,
typename Params::DTypeQ,
typename Params::DTypeKV,
typename Params::DTypeO,
typename Params::IdType,
NV_TYPE>,
CAUSAL,
Params,
USE_REG_EALLOC,
USE_FIXED_BLOCK>(params, stream);)
} else {
return cudaErrorNotSupported;
}
+2 -1
View File
@@ -1,6 +1,7 @@
pkill -9 -f python
pkill -9 -f fastdeploy
pkill -9 -f gunicorn
pkill -9 -f redis-server
# Kill redis-server if you need.
#pkill -9 -f redis-server
sleep 1
+20 -9
View File
@@ -159,16 +159,22 @@ class CacheMessager:
cache_v = []
self.messager = {}
for layer_idx in range(self.num_layers):
# value cache
val_cache_key = f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"
if val_cache_key in self.gpu_cache_kvs:
val_cache = self.gpu_cache_kvs[val_cache_key]
cache_v.append(val_cache)
if paddle.is_compiled_with_xpu():
cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
else:
cache_v_ptr_list.append(val_cache.data_ptr())
# key cache
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
cache_k.append(key_cache)
cache_v.append(val_cache)
if paddle.is_compiled_with_xpu():
cache_k_ptr_list.append(get_peer_mem_addr(key_cache.data_ptr()))
cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
else:
cache_k_ptr_list.append(key_cache.data_ptr())
cache_v_ptr_list.append(val_cache.data_ptr())
cache_k_ptr_list = np.array(cache_k_ptr_list)
cache_v_ptr_list = np.array(cache_v_ptr_list)
@@ -198,7 +204,6 @@ class CacheMessager:
elif protocol == "rdma":
logger.info(f"splitwise_role rdma: {self.splitwise_role}, rank: {self.rank}, gpu_id: {gpu_id}")
self.messager[protocol] = RDMACommManager(
splitwise_role,
rank,
@@ -460,16 +465,22 @@ class CacheMessagerV1:
cache_v = []
self.messager = {}
for layer_idx in range(self.num_layers):
# value cache
val_cache_key = f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"
if val_cache_key in self.gpu_cache_kvs:
val_cache = self.gpu_cache_kvs[val_cache_key]
cache_v.append(val_cache)
if paddle.is_compiled_with_xpu():
cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
else:
cache_v_ptr_list.append(val_cache.data_ptr())
# key cache
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
cache_k.append(key_cache)
cache_v.append(val_cache)
if paddle.is_compiled_with_xpu():
cache_k_ptr_list.append(get_peer_mem_addr(key_cache.data_ptr()))
cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
else:
cache_k_ptr_list.append(key_cache.data_ptr())
cache_v_ptr_list.append(val_cache.data_ptr())
cache_k_ptr_list = np.array(cache_k_ptr_list)
cache_v_ptr_list = np.array(cache_v_ptr_list)
@@ -245,6 +245,15 @@ class PrefixCacheManager:
log_dir = envs.FD_LOG_DIR
cache_manager_processes = []
visible_devices = get_all_visible_devices()
val_cache_arg_str = ""
if val_cache_shape:
if isinstance(val_cache_shape, list):
val_shape_str = ",".join(map(str, val_cache_shape))
else:
val_shape_str = str(val_cache_shape)
val_cache_arg_str = f" --value_cache_shape {val_shape_str}"
for i in range(tensor_parallel_size):
launch_cmd = (
"FLAGS_allocator_strategy=auto_growth "
@@ -259,7 +268,7 @@ class PrefixCacheManager:
+ f" --mp_num {tensor_parallel_size}"
+ f" --cache_dtype {cache_config.cache_dtype}"
+ f" --key_cache_shape {key_cache_shape}"
+ f" --value_cache_shape {val_cache_shape}"
+ val_cache_arg_str
+ f" --cache_queue_port {cache_config.cache_queue_port}"
+ f" --enable_splitwise {int(self.enable_splitwise)}"
+ f" --pod_ip {pod_ip}"
@@ -332,6 +341,15 @@ class PrefixCacheManager:
log_dir = envs.FD_LOG_DIR
cache_messager_processes = []
visible_devices = get_all_visible_devices()
val_cache_arg_str = ""
if value_cache_shape:
if isinstance(value_cache_shape, list):
val_shape_str = ",".join(map(str, value_cache_shape))
else:
val_shape_str = str(value_cache_shape)
val_cache_arg_str = f" --value_cache_shape {val_shape_str}"
for i in range(tensor_parallel_size):
launch_cmd = (
"FLAGS_allocator_strategy=auto_growth "
@@ -345,7 +363,7 @@ class PrefixCacheManager:
+ f" --mp_num {tensor_parallel_size}"
+ f" --cache_dtype {cache_config.cache_dtype}"
+ f" --key_cache_shape {key_cache_shape}"
+ f" --value_cache_shape {value_cache_shape}"
+ val_cache_arg_str
+ f" --pod_ip {pod_ip}"
+ f" --cache_queue_port {cache_config.cache_queue_port}"
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
@@ -198,8 +198,8 @@ int get_port_info(struct ibv_context* Context,
int parse_port_ib_info();
// Memory region exchange
bool client_exchange_mr(struct RdmaContext* ctx);
bool server_exchange_mr(struct RdmaContext* ctx);
bool client_exchange_mr(struct RdmaContext* ctx, bool has_value_cache);
bool server_exchange_mr(struct RdmaContext* ctx, bool has_value_cache);
bool server_send_memory_region(struct RdmaContext* ctx,
void* local_mr,
int byte_num);
@@ -149,6 +149,7 @@ class RDMACommunicator {
struct ibv_pd* g_pd = NULL; // fd
int RDMACommunicator_status; // Communicator status flag
bool start_client_listener = false; // Client listener flag
bool has_value_cache_; // MLA does not have value cache.
};
#endif // KVCACHE_RDMA_H
@@ -712,8 +712,8 @@ bool exchange_mr_vector(struct RdmaContext *ctx,
* @param ctx The RDMA context
* @return true on success, false on failure
*/
bool client_exchange_mr(struct RdmaContext *ctx) {
LOGD("verb client exchange mr: start");
bool client_exchange_mr(struct RdmaContext *ctx, bool has_value_cache) {
LOGD("verb client exchange mr: start. has_value_cache=%d", has_value_cache);
if (ctx->conn.layer_number <= 0) {
ERR("Invalid layer number: %d", ctx->conn.layer_number);
@@ -723,19 +723,27 @@ bool client_exchange_mr(struct RdmaContext *ctx) {
auto layer_num = ctx->conn.layer_number;
std::vector<void *> key_ptrs(layer_num);
std::vector<uint32_t> key_rkeys(layer_num);
std::vector<void *> val_ptrs(layer_num);
std::vector<uint32_t> val_rkeys(layer_num);
std::vector<void *> val_ptrs;
std::vector<uint32_t> val_rkeys;
if (has_value_cache) {
val_ptrs.resize(layer_num);
val_rkeys.resize(layer_num);
}
if (!exchange_mr_vector(ctx, key_ptrs, true)) return false;
if (!exchange_mr_vector(ctx, key_rkeys, true)) return false;
if (!exchange_mr_vector(ctx, val_ptrs, true)) return false;
if (!exchange_mr_vector(ctx, val_rkeys, true)) return false;
if (has_value_cache) {
if (!exchange_mr_vector(ctx, val_ptrs, true)) return false;
if (!exchange_mr_vector(ctx, val_rkeys, true)) return false;
}
for (int i = 0; i < layer_num; ++i) {
ctx->conn.write_cache_key_remote_ptr_list.push_back(key_ptrs[i]);
ctx->conn.write_cache_key_remote_rkey_list.push_back(key_rkeys[i]);
ctx->conn.write_cache_value_remote_ptr_list.push_back(val_ptrs[i]);
ctx->conn.write_cache_value_remote_rkey_list.push_back(val_rkeys[i]);
if (has_value_cache) {
ctx->conn.write_cache_value_remote_ptr_list.push_back(val_ptrs[i]);
ctx->conn.write_cache_value_remote_rkey_list.push_back(val_rkeys[i]);
}
}
return true;
}
@@ -746,8 +754,8 @@ bool client_exchange_mr(struct RdmaContext *ctx) {
* @param ctx The RDMA context
* @return true on success, false on failure
*/
bool server_exchange_mr(struct RdmaContext *ctx) {
LOGD("verbs server exchange mr: start");
bool server_exchange_mr(struct RdmaContext *ctx, bool has_value_cache) {
LOGD("verbs server exchange mr: start. has_value_cache=%d", has_value_cache);
if (ctx->conn.layer_number <= 0) {
ERR("Invalid layer number: %d", ctx->conn.layer_number);
@@ -759,8 +767,16 @@ bool server_exchange_mr(struct RdmaContext *ctx) {
auto &val_mrs = ctx->conn.write_cache_value_server_mr_list;
// Verify that server memory regions are properly initialized
if (key_mrs.size() != layer_num || val_mrs.size() != layer_num) {
ERR("server write cache memory region size error");
if (key_mrs.size() != layer_num) {
ERR("server write cache KEY memory region size error: %zu vs %d",
key_mrs.size(),
layer_num);
return false;
}
if (has_value_cache && val_mrs.size() != layer_num) {
ERR("server write cache VALUE memory region size error: %zu vs %d",
val_mrs.size(),
layer_num);
return false;
}
@@ -772,22 +788,27 @@ bool server_exchange_mr(struct RdmaContext *ctx) {
send_key_ptrs.reserve(layer_num);
send_key_rkeys.reserve(layer_num);
send_val_ptrs.reserve(layer_num);
send_val_rkeys.reserve(layer_num);
if (has_value_cache) {
send_val_ptrs.reserve(layer_num);
send_val_rkeys.reserve(layer_num);
}
// Collect memory region information from local MRs
for (int i = 0; i < layer_num; ++i) {
send_key_ptrs.push_back(reinterpret_cast<uint64_t>(key_mrs[i]->addr));
send_key_rkeys.push_back(key_mrs[i]->rkey);
send_val_ptrs.push_back(reinterpret_cast<uint64_t>(val_mrs[i]->addr));
send_val_rkeys.push_back(val_mrs[i]->rkey);
if (has_value_cache) {
send_val_ptrs.push_back(reinterpret_cast<uint64_t>(val_mrs[i]->addr));
send_val_rkeys.push_back(val_mrs[i]->rkey);
}
}
// Send all vectors to client
if (!exchange_mr_vector(ctx, send_key_ptrs, false)) return false;
if (!exchange_mr_vector(ctx, send_key_rkeys, false)) return false;
if (!exchange_mr_vector(ctx, send_val_ptrs, false)) return false;
if (!exchange_mr_vector(ctx, send_val_rkeys, false)) return false;
if (has_value_cache) {
if (!exchange_mr_vector(ctx, send_val_ptrs, false)) return false;
if (!exchange_mr_vector(ctx, send_val_rkeys, false)) return false;
}
return true;
}
@@ -78,6 +78,18 @@ RDMACommunicator::RDMACommunicator(std::string& role,
throw std::runtime_error("Invalid layer number");
}
if (local_cache_value_ptr_layer_head_.empty()) {
has_value_cache_ = false;
WARN(
"Value Cache is empty (Maybe MLA Model). RDMA will run in Key-Only "
"mode.");
} else {
has_value_cache_ = true;
if (local_cache_value_ptr_layer_head_.size() != layer_number) {
throw std::runtime_error("Key and Value cache layer number mismatch!");
}
}
// Step 2: Setup cache vectors and pointers
resize_vectors();
assign_pointers();
@@ -100,7 +112,6 @@ RDMACommunicator::RDMACommunicator(std::string& role,
});
server_thread.detach();
}
RDMACommunicator_status = 1;
INFO("RDMA communicator initialized successfully");
} catch (const std::exception& e) {
@@ -119,7 +130,9 @@ void RDMACommunicator::resize_vectors() {
}
local_cache_key_ptr_per_layer.resize(layer_number);
local_cache_value_ptr_per_layer.resize(layer_number);
if (has_value_cache_) {
local_cache_value_ptr_per_layer.resize(layer_number);
}
}
void RDMACommunicator::assign_pointers() {
@@ -131,15 +144,19 @@ void RDMACommunicator::assign_pointers() {
// Assign pointers for each layer and block
for (int layer_idx = 0; layer_idx < layer_number; ++layer_idx) {
// Validate layer head pointers
if (local_cache_key_ptr_layer_head_[layer_idx] == 0 ||
local_cache_value_ptr_layer_head_[layer_idx] == 0) {
if (local_cache_key_ptr_layer_head_[layer_idx] == 0) {
throw std::runtime_error("Invalid cache pointer for layer " +
std::to_string(layer_idx));
}
// Resize block vectors for current layer
local_cache_key_ptr_per_layer[layer_idx].resize(block_number);
local_cache_value_ptr_per_layer[layer_idx].resize(block_number);
if (has_value_cache_) {
if (local_cache_value_ptr_layer_head_[layer_idx] == 0) {
throw std::runtime_error("Invalid VALUE cache pointer for layer " +
std::to_string(layer_idx));
}
local_cache_value_ptr_per_layer[layer_idx].resize(block_number);
}
// Calculate and assign block pointers
for (int block_idx = 0; block_idx < block_number; ++block_idx) {
@@ -147,9 +164,12 @@ void RDMACommunicator::assign_pointers() {
reinterpret_cast<void*>(local_cache_key_ptr_layer_head_[layer_idx] +
block_idx * block_size_byte);
local_cache_value_ptr_per_layer[layer_idx][block_idx] =
reinterpret_cast<void*>(local_cache_value_ptr_layer_head_[layer_idx] +
block_idx * block_size_byte);
if (has_value_cache_) {
local_cache_value_ptr_per_layer[layer_idx][block_idx] =
reinterpret_cast<void*>(
local_cache_value_ptr_layer_head_[layer_idx] +
block_idx * block_size_byte);
}
}
}
}
@@ -347,7 +367,7 @@ int RDMACommunicator::start_server(int sport, int sgid_idx, int gpu_index) {
continue;
}
server_exchange_mr(ctx);
server_exchange_mr(ctx, has_value_cache_);
} else {
auto ctx_iter = connectionContexts.find(event_fd);
if (ctx_iter == connectionContexts.end()) {
@@ -435,18 +455,33 @@ bool RDMACommunicator::deregister_memory_regions(struct RdmaContext* ctx) {
return false;
}
for (int layer_idx = 0; layer_idx < layer_number; layer_idx++) {
if (!write_mr_key_list.empty() && !write_mr_value_list.empty()) {
if (ibv_dereg_mr(write_mr_key_list[layer_idx])) {
ERR("Failed to deregister memory region: write_mr_key_list, layer %d",
layer_idx);
}
if (ibv_dereg_mr(write_mr_value_list[layer_idx])) {
ERR("Failed to deregister memory region: write_mr_value_list, layer %d",
layer_idx);
if (!write_mr_key_list.empty()) {
for (int layer_idx = 0; layer_idx < layer_number; layer_idx++) {
if (write_mr_key_list[layer_idx]) {
if (ibv_dereg_mr(write_mr_key_list[layer_idx])) {
ERR("Failed to deregister memory region: write_mr_key_list, layer %d",
layer_idx);
}
write_mr_key_list[layer_idx] = nullptr;
}
}
write_mr_key_list.clear();
}
if (!write_mr_value_list.empty()) {
for (int layer_idx = 0; layer_idx < layer_number; layer_idx++) {
if (write_mr_value_list[layer_idx]) {
if (ibv_dereg_mr(write_mr_value_list[layer_idx])) {
ERR("Failed to deregister memory region: write_mr_value_list, layer "
"%d",
layer_idx);
}
write_mr_value_list[layer_idx] = nullptr;
}
}
write_mr_value_list.clear();
}
return true;
}
@@ -548,7 +583,7 @@ int RDMACommunicator::connect(const std::string& dst_ip,
ERR("Couldn't getexchange port infodestinations");
return static_cast<int>(ConnStatus::kError);
} else {
client_exchange_mr(ctx);
client_exchange_mr(ctx, has_value_cache_);
}
// Allocate RDMA read and register read buffers
@@ -735,15 +770,17 @@ bool RDMACommunicator::client_mr_register_per_layer(RdmaContext* ctx) {
}
std::lock_guard<std::mutex> lock(mutex_);
if (!write_mr_key_list.empty() || !write_mr_value_list.empty()) {
if (!write_mr_key_list.empty()) {
WARN("Memory regions already registered");
return true;
}
const size_t list_size = layer_number;
write_mr_key_list.resize(list_size, nullptr);
write_mr_value_list.resize(list_size, nullptr);
if (has_value_cache_) {
write_mr_value_list.resize(list_size, nullptr);
}
const uint32_t access_flags =
IBV_ACCESS_LOCAL_WRITE |
@@ -753,8 +790,6 @@ bool RDMACommunicator::client_mr_register_per_layer(RdmaContext* ctx) {
for (int i = 0; i < static_cast<int>(list_size); ++i) {
void* key_ptr = reinterpret_cast<void*>(local_cache_key_ptr_layer_head_[i]);
void* val_ptr =
reinterpret_cast<void*>(local_cache_value_ptr_layer_head_[i]);
size_t size = static_cast<size_t>(block_size_byte) * block_number;
write_mr_key_list[i] =
@@ -765,13 +800,18 @@ bool RDMACommunicator::client_mr_register_per_layer(RdmaContext* ctx) {
access_flags);
if (!write_mr_key_list[i]) goto fail;
write_mr_value_list[i] =
register_memory_region(ctx->pd,
val_ptr,
size,
"client_value_" + std::to_string(i),
access_flags);
if (!write_mr_value_list[i]) goto fail;
if (has_value_cache_) {
void* val_ptr =
reinterpret_cast<void*>(local_cache_value_ptr_layer_head_[i]);
write_mr_value_list[i] =
register_memory_region(ctx->pd,
val_ptr,
size,
"client_value_" + std::to_string(i),
access_flags);
if (!write_mr_value_list[i]) goto fail;
}
}
return true;
@@ -812,8 +852,6 @@ bool RDMACommunicator::server_mr_register_per_layer(RdmaContext* ctx) {
for (int i = 0; i < layer_number; ++i) {
void* key_ptr = reinterpret_cast<void*>(local_cache_key_ptr_layer_head_[i]);
void* val_ptr =
reinterpret_cast<void*>(local_cache_value_ptr_layer_head_[i]);
size_t size = static_cast<size_t>(block_size_byte) * block_number;
struct ibv_mr* key_mr = register_memory_region(
@@ -822,21 +860,25 @@ bool RDMACommunicator::server_mr_register_per_layer(RdmaContext* ctx) {
ERR("Failed to register key MR at layer %d", i);
goto fail;
}
struct ibv_mr* value_mr = register_memory_region(
ctx->pd, val_ptr, size, "value_" + std::to_string(i), access_flags);
if (!value_mr) {
ERR("Failed to register value MR at layer %d", i);
ibv_dereg_mr(key_mr);
goto fail;
}
write_cache_key_server_mr_list.push_back(key_mr);
write_cache_value_server_mr_list.push_back(value_mr);
if (has_value_cache_) {
void* val_ptr =
reinterpret_cast<void*>(local_cache_value_ptr_layer_head_[i]);
struct ibv_mr* value_mr = register_memory_region(
ctx->pd, val_ptr, size, "value_" + std::to_string(i), access_flags);
if (!value_mr) {
ERR("Failed to register value MR at layer %d", i);
ibv_dereg_mr(key_mr);
goto fail;
}
write_cache_value_server_mr_list.push_back(value_mr);
}
}
ctx->conn.write_cache_key_server_mr_list = write_cache_key_server_mr_list;
ctx->conn.write_cache_value_server_mr_list = write_cache_value_server_mr_list;
return true;
fail:
@@ -899,8 +941,12 @@ int RDMACommunicator::write_cache(const std::string& ip,
uint32_t cache_key_rkey =
ctx->conn.write_cache_key_remote_rkey_list[layer_idx];
uint32_t cache_value_rkey =
ctx->conn.write_cache_value_remote_rkey_list[layer_idx];
uint32_t cache_value_rkey = 0;
if (has_value_cache_) {
cache_value_rkey = ctx->conn.write_cache_value_remote_rkey_list[layer_idx];
}
uint32_t crc_cache_key_rkey, crc_cache_value_rkey;
bool pd_tp_size_is_same = prefill_tp_size == ctx->conn.decode_tp_size;
uint64_t offset_in_block =
@@ -914,15 +960,19 @@ int RDMACommunicator::write_cache(const std::string& ip,
cache_key_remote_addr[block_index] = (uint64_t(
char_ptr + remote_block_ids[block_index] * total_block_size_byte +
offset_in_block));
char_ptr = static_cast<char*>(
ctx->conn.write_cache_value_remote_ptr_list[layer_idx]);
cache_value_remote_addr[block_index] = (uint64_t(
char_ptr + remote_block_ids[block_index] * total_block_size_byte +
offset_in_block));
if (has_value_cache_) {
char_ptr = static_cast<char*>(
ctx->conn.write_cache_value_remote_ptr_list[layer_idx]);
cache_value_remote_addr[block_index] = (uint64_t(
char_ptr + remote_block_ids[block_index] * total_block_size_byte +
offset_in_block));
}
}
ctx->conn.wc_target_count = 0;
for (int i = 0; i < 2; ++i) {
int loop_count = has_value_cache_ ? 2 : 1;
for (int i = 0; i < loop_count; ++i) {
bool is_key = (i == 0);
uint32_t rkey = (is_key ? cache_key_rkey : cache_value_rkey);
std::vector<uint64_t>& remote_addr =
@@ -1038,6 +1088,10 @@ void RDMACommunicator::prepare_write_requests(
bool is_key,
std::vector<uint64_t>& remote_addr,
uint32_t rkey) {
if (!is_key) {
assert(!write_mr_value_list.empty() &&
"Trying to process Value Cache but it is empty!");
}
auto block_num = local_block_ids.size();
for (size_t i = 0; i < block_num; ++i) {
@@ -40,11 +40,10 @@ class RDMACommManager:
try:
import rdma_comm
except:
logger.error(
raise RuntimeError(
"The installation of the RDMA library failed."
"Confirm whether your network card supports RDMA transmission."
)
return
self.messager = rdma_comm.RDMACommunicator(
splitwise_role,
gpu_id,
+2 -1
View File
@@ -755,8 +755,9 @@ class LLMEngine:
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
)
)
ctx = multiprocessing.get_context("spawn")
self.dp_processed.append(
multiprocessing.Process(
ctx.Process(
target=start_data_parallel_service,
args=(
self.cfg,
@@ -205,7 +205,6 @@ class MLAAttentionBackend(AttentionBackend):
self.group_size,
self.block_size,
)
# MLA
metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1]
metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]
@@ -279,6 +278,7 @@ class MLAAttentionBackend(AttentionBackend):
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
metadata.kv_signal_data_list[layer.layer_id],
"none",
getattr(forward_meta, "max_input_length", -1),
)
@@ -422,10 +422,10 @@ class MLAAttentionBackend(AttentionBackend):
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
metadata.kv_signal_data_list[layer.layer_id],
"none",
self.max_seq_len,
)
# FA
fmha_out = self.flash_attn_func(
q,
@@ -307,6 +307,7 @@ class MetaxMLAAttentionBackend(AttentionBackend):
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
metadata.kv_signal_data_list[layer.layer_id],
"none",
getattr(forward_meta, "max_input_length", -1),
)
+2 -2
View File
@@ -258,10 +258,10 @@ class FusedMoE(nn.Layer):
else:
SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
if not param._is_initialized():
param.initialize()
if not (expert_id - self.expert_id_offset >= 0 and expert_id - self.expert_id_offset < self.num_local_experts):
return
if not param._is_initialized():
param.initialize()
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if shard_id is None:
# 1.gate up fused in disk
@@ -341,6 +341,7 @@ class DeepseekV3MLAAttention(nn.Layer):
# NOTE: (changwenbin) qkv_a_proj horizontal fusion
qkv_a_out = self.qkv_a_proj_with_mqa(hidden_states)
query, compressed_kv, key_pe = qkv_a_out.split(
[self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], axis=-1
)
@@ -399,6 +400,7 @@ class DeepseekV3MLAAttention(nn.Layer):
self.num_attention_heads_tp * (self.kv_lora_rank + self.qk_rope_head_dim),
]
)
fmha_out_decode = self.mla_attn(
q=q_input,
k=None,
@@ -418,6 +420,7 @@ class DeepseekV3MLAAttention(nn.Layer):
.transpose([1, 0, 2])
.reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
)
if fmha_out is None:
fmha_out = fmha_out_decode
else:
@@ -515,6 +518,7 @@ class DeepSeekV3DecoderLayer(nn.Layer):
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@@ -674,7 +678,6 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()), self.fd_config)
for loaded_weight_name, loaded_weight in weights_iterator:
loaded_weight_name = loaded_weight_name.replace("deepseek_v3", "model")
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in loaded_weight_name:
continue
@@ -741,6 +744,20 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
)
return position_ids, mask_encoder_batch
def empty_input_forward(self):
"""
empty_input_forward
"""
fake_hidden_states = paddle.empty(
shape=[1, self.fd_config.model_config.hidden_size],
dtype=paddle.get_default_dtype(),
)
for i in range(
self.fd_config.model_config.first_k_dense_replace,
self.fd_config.model_config.num_hidden_layers,
):
self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate)
def forward(
self,
ids_remove_padding: paddle.Tensor,
-1
View File
@@ -2328,7 +2328,6 @@ class GPUModelRunner(ModelRunnerBase):
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
group=self.parallel_config.tp_group,
)
# 5. Post Process
model_output_data = ModelOutputData(
next_tokens=self.share_inputs["next_tokens"],