Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2026-04-22 16:07:51 +08:00

[PD Disaggregation] Support PD deployment of DeepSeekv3. (#5251)

* Support deepseekv3 cache transfer for PD deploy
* clean some log info

---------

Co-authored-by: K11OntheBoat <“ruianmaidanglao@163.com”>
@@ -13,22 +13,24 @@
// limitations under the License.

#pragma once

#include "helper.h"
#include "mla_cache_kernel.cuh"
#include "helper.h"
#include "remote_cache_kv_ipc.h"

template <paddle::DataType T>
std::vector<paddle::Tensor> PrefillMLAWriteCache(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor& kv_nope,
    const paddle::Tensor& kv_pe,
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const int max_seq_len,
    cudaStream_t& stream,
    paddle::Tensor* kv_cache) {
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor& kv_nope,
    const paddle::Tensor& kv_pe,
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::optional<paddle::Tensor>& kv_signal_data,
    const int max_seq_len,
    cudaStream_t& stream,
    paddle::Tensor* kv_cache) {
  typedef PDTraits<T> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;
@@ -50,8 +52,10 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(

  prefill_absorb_cache_kernel<DataType_, PackSize>
      <<<grid_size, blocksize, 0, stream>>>(
          reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
          reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
          reinterpret_cast<DataType_*>(
              const_cast<data_t*>(kv_nope.data<data_t>())),
          reinterpret_cast<DataType_*>(
              const_cast<data_t*>(kv_pe.data<data_t>())),
          reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
          block_tables.data<int>(),
          batch_id_per_token.data<int>(),
@@ -65,6 +69,33 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
          pe_size,
          block_size,
          elem_nums);

  const char* fmt_write_cache_completed_signal_str =
      std::getenv("FLAGS_fmt_write_cache_completed_signal");
  const char* FLAGS_use_pd_disaggregation_per_chunk =
      std::getenv("FLAGS_use_pd_disaggregation_per_chunk");

  if (fmt_write_cache_completed_signal_str &&
      (std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 ||
       std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) {
    if (FLAGS_use_pd_disaggregation_per_chunk &&
        (std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 ||
         std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) {
      cudaLaunchHostFunc(
          stream,
          &(RemoteCacheKvIpc::
                save_cache_kv_complete_signal_layerwise_per_query),
          (void*)nullptr);
    } else {
      if (kv_signal_data) {
        cudaLaunchHostFunc(
            stream,
            &RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise,
            (void*)(const_cast<int64_t*>(
                kv_signal_data.get().data<int64_t>())));
      }
    }
  }
  return {};
}

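Note on the added signal path: the completion notice is not raised on the CPU right after the launch call; it is enqueued on the same CUDA stream as the cache-write kernel via cudaLaunchHostFunc, so it only fires once the write has actually finished on the device. A minimal, self-contained sketch of that pattern follows (hypothetical callback and flag names, not the FastDeploy RemoteCacheKvIpc implementation; only the CUDA runtime API is assumed):

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical host-side callback standing in for the layerwise "KV written" signal.
static void CUDART_CB notify_layer_written(void* user_data) {
  auto* layer_flag = static_cast<int64_t*>(user_data);
  *layer_flag = 1;  // runs on a CPU thread after all prior work on the stream
}

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  int64_t layer_flag = 0;

  // ... enqueue the cache-write kernel on `stream` here ...

  // The callback is ordered after everything already enqueued on the stream.
  cudaLaunchHostFunc(stream, notify_layer_written, &layer_flag);
  cudaStreamSynchronize(stream);
  std::printf("layer written: %lld\n", static_cast<long long>(layer_flag));

  cudaStreamDestroy(stream);
  return 0;
}

The per-chunk branch above passes a null user pointer and lets the callback consult its own bookkeeping, while the per-layer branch hands the signal tensor's data pointer straight to the callback.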
@@ -77,6 +108,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::optional<paddle::Tensor>& kv_signal_data,
    const std::string& cache_quant_type_str,
    const int max_seq_len) {
  cudaStream_t stream = kv_pe.stream();
@@ -85,7 +117,8 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
  const auto& kv_pe_dims = kv_pe.dims();
  const auto& kv_cache_dims = kv_cache.dims();
  meta_data.kv_num_heads = kv_cache_dims[1];
  const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
  const auto nope_size =
      kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
  meta_data.token_nums = kv_nope_dims[0];
  meta_data.head_dims = kv_cache_dims[3];
  meta_data.head_dims_v = nope_size;
@@ -95,30 +128,34 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
  meta_data.batch_size = seq_lens_decoder.dims()[0];
  switch (kv_pe.dtype()) {
    case paddle::DataType::BFLOAT16: {
      return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
          kv_nope,
          kv_pe,
          seq_lens,
          seq_lens_decoder,
          batch_id_per_token,
          cu_seqlens_q,
          block_tables,
          max_seq_len,
          stream,
          const_cast<paddle::Tensor*>(&kv_cache));
      return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(
          meta_data,
          kv_nope,
          kv_pe,
          seq_lens,
          seq_lens_decoder,
          batch_id_per_token,
          cu_seqlens_q,
          block_tables,
          kv_signal_data,
          max_seq_len,
          stream,
          const_cast<paddle::Tensor*>(&kv_cache));
    }
    case paddle::DataType::FLOAT16: {
      return PrefillMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
          kv_nope,
          kv_pe,
          seq_lens,
          seq_lens_decoder,
          batch_id_per_token,
          cu_seqlens_q,
          block_tables,
          max_seq_len,
          stream,
          const_cast<paddle::Tensor*>(&kv_cache));
      return PrefillMLAWriteCache<paddle::DataType::FLOAT16>(
          meta_data,
          kv_nope,
          kv_pe,
          seq_lens,
          seq_lens_decoder,
          batch_id_per_token,
          cu_seqlens_q,
          block_tables,
          kv_signal_data,
          max_seq_len,
          stream,
          const_cast<paddle::Tensor*>(&kv_cache));
    }
  }
  return {};
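For orientation, the dispatcher derives the MLA cache geometry purely from tensor shapes: kv_num_heads comes from dims[1] of the paged cache, head_dims from dims[3], and the nope width is the packed kv_nope row divided by the head count. A small worked example with illustrative, assumed DeepSeek-style sizes (the concrete numbers and the full cache layout are assumptions, not read from this file):

#include <array>
#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative shapes only: kv_nope is [token_num, kv_num_heads * nope_size];
  // the paged cache is presumed [max_block_num, kv_num_heads, block_size, head_dim].
  const std::array<int64_t, 2> kv_nope_dims = {8, 512};
  const std::array<int64_t, 4> kv_cache_dims = {1024, 1, 64, 576};

  const int64_t kv_num_heads = kv_cache_dims[1];
  const int64_t nope_size = kv_nope_dims.back() / kv_num_heads;  // 512
  const int64_t token_nums = kv_nope_dims[0];
  const int64_t head_dims = kv_cache_dims[3];  // nope + rope widths combined
  const int64_t head_dims_v = nope_size;       // values reuse only the nope part

  std::printf("tokens=%lld heads=%lld nope=%lld head_dim=%lld head_dim_v=%lld\n",
              static_cast<long long>(token_nums),
              static_cast<long long>(kv_num_heads),
              static_cast<long long>(nope_size),
              static_cast<long long>(head_dims),
              static_cast<long long>(head_dims_v));
  return 0;
}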
@@ -126,18 +163,18 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(

template <paddle::DataType T>
std::vector<paddle::Tensor> DecodeMLAWriteCache(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor& kv_nope,
    const paddle::Tensor& kv_pe,
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const int max_seq_len,
    const bool speculate_decoder,
    cudaStream_t& stream,
    paddle::Tensor* kv_cache) {
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor& kv_nope,
    const paddle::Tensor& kv_pe,
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const int max_seq_len,
    const bool speculate_decoder,
    cudaStream_t& stream,
    paddle::Tensor* kv_cache) {
  typedef PDTraits<T> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;
@@ -154,15 +191,16 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
  const int blocksize = 128;
  int grid_size = 1;

  if (speculate_decoder) {
    const uint32_t elem_nums = token_num * kv_num_heads * all_size;
    const int pack_num = elem_nums / PackSize;
    GetNumBlocks<128>(pack_num, &grid_size);
    speculate_decode_absorb_cache_kernel<DataType_, PackSize>
        <<<grid_size, blocksize, 0, stream>>>(
            reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
            reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
            reinterpret_cast<DataType_*>(
                const_cast<data_t*>(kv_nope.data<data_t>())),
            reinterpret_cast<DataType_*>(
                const_cast<data_t*>(kv_pe.data<data_t>())),
            reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
            block_tables.data<int>(),
            batch_id_per_token.data<int>(),
@@ -182,8 +220,10 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
    GetNumBlocks<128>(pack_num, &grid_size);
    decode_absorb_cache_kernel<DataType_, PackSize>
        <<<grid_size, blocksize, 0, stream>>>(
            reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_nope.data<data_t>())),
            reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
            reinterpret_cast<DataType_*>(
                const_cast<data_t*>(kv_nope.data<data_t>())),
            reinterpret_cast<DataType_*>(
                const_cast<data_t*>(kv_pe.data<data_t>())),
            reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
            block_tables.data<int>(),
            cu_seqlens_q.data<int>(),
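Both decode paths launch a flat, vectorized copy over the compressed KV entries, so the grid size follows directly from the element count. A rough sketch of that launch-size arithmetic (get_num_blocks is a hypothetical stand-in for the GetNumBlocks<128> helper used above; all sizes are illustrative):

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for GetNumBlocks<128>: one block per 128 packs, capped.
static void get_num_blocks(int64_t pack_num, int* grid_size) {
  constexpr int kThreadsPerBlock = 128;
  constexpr int64_t kMaxBlocks = 65535;
  const int64_t blocks = (pack_num + kThreadsPerBlock - 1) / kThreadsPerBlock;
  *grid_size = static_cast<int>(std::min(std::max<int64_t>(blocks, 1), kMaxBlocks));
}

int main() {
  const int64_t token_num = 4, kv_num_heads = 1;
  const int64_t nope_size = 512, pe_size = 64;  // illustrative MLA widths
  const int64_t all_size = nope_size + pe_size;
  constexpr int PackSize = 8;                   // elements handled per thread

  const int64_t elem_nums = token_num * kv_num_heads * all_size;
  const int64_t pack_num = elem_nums / PackSize;
  int grid_size = 1;
  get_num_blocks(pack_num, &grid_size);
  std::printf("elem_nums=%lld pack_num=%lld grid=%d\n",
              static_cast<long long>(elem_nums),
              static_cast<long long>(pack_num), grid_size);
  return 0;
}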
@@ -218,7 +258,8 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
  const auto& kv_pe_dims = kv_pe.dims();
  const auto& kv_cache_dims = kv_cache.dims();
  meta_data.kv_num_heads = kv_cache_dims[1];
  const auto nope_size = kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
  const auto nope_size =
      kv_nope_dims[kv_nope_dims.size() - 1] / meta_data.kv_num_heads;
  meta_data.token_nums = kv_nope_dims[0];
  meta_data.head_dims = kv_cache_dims[3];
  meta_data.head_dims_v = nope_size;
@@ -228,38 +269,39 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
  meta_data.batch_size = seq_lens_encoder.dims()[0];
  switch (kv_pe.dtype()) {
    case paddle::DataType::BFLOAT16: {
      return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
          kv_nope,
          kv_pe,
          seq_lens,
          seq_lens_encoder,
          batch_id_per_token,
          cu_seqlens_q,
          block_tables,
          max_seq_len,
          speculate_decoder,
          stream,
          const_cast<paddle::Tensor*>(&kv_cache));
      return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(
          meta_data,
          kv_nope,
          kv_pe,
          seq_lens,
          seq_lens_encoder,
          batch_id_per_token,
          cu_seqlens_q,
          block_tables,
          max_seq_len,
          speculate_decoder,
          stream,
          const_cast<paddle::Tensor*>(&kv_cache));
    }
    case paddle::DataType::FLOAT16: {
      return DecodeMLAWriteCache<paddle::DataType::FLOAT16>(meta_data,
          kv_nope,
          kv_pe,
          seq_lens,
          seq_lens_encoder,
          batch_id_per_token,
          cu_seqlens_q,
          block_tables,
          max_seq_len,
          speculate_decoder,
          stream,
          const_cast<paddle::Tensor*>(&kv_cache));
      return DecodeMLAWriteCache<paddle::DataType::FLOAT16>(
          meta_data,
          kv_nope,
          kv_pe,
          seq_lens,
          seq_lens_encoder,
          batch_id_per_token,
          cu_seqlens_q,
          block_tables,
          max_seq_len,
          speculate_decoder,
          stream,
          const_cast<paddle::Tensor*>(&kv_cache));
    }
  }
  return {};
}

PD_BUILD_STATIC_OP(prefill_mla_write_cache)
    .Inputs({"kv_nope",
             "kv_pe",
@@ -268,11 +310,11 @@ PD_BUILD_STATIC_OP(prefill_mla_write_cache)
             "seq_lens_decoder",
             "batch_id_per_token",
             "cu_seqlens_q",
             "block_tables"})
             "block_tables",
             paddle::Optional("kv_signal_data")})
    .Outputs({"kv_cache_out"})
    .SetInplaceMap({{"kv_cache", "kv_cache_out"}})
    .Attrs({"cache_quant_type_str: std::string",
            "max_seq_len: int"})
    .Attrs({"cache_quant_type_str: std::string", "max_seq_len: int"})
    .SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel));

PD_BUILD_STATIC_OP(decode_mla_write_cache)

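The signal tensor is registered as an optional op input, so existing call sites that never pass it keep working and the kernel only touches it after an explicit presence check. A minimal sketch of the same optional-input pattern on a toy op (hypothetical names; the standard Paddle custom-op API is assumed):

#include "paddle/extension.h"

// Toy kernel: forwards `x`; the optional `signal` is read only when present.
std::vector<paddle::Tensor> ToyForward(
    const paddle::Tensor& x,
    const paddle::optional<paddle::Tensor>& signal) {
  if (signal) {
    // Safe to dereference only inside this branch.
    const int64_t* signal_ptr = signal.get().data<int64_t>();
    (void)signal_ptr;  // a real op would hand this to its completion callback
  }
  return {x};
}

PD_BUILD_STATIC_OP(toy_forward)
    .Inputs({"x", paddle::Optional("signal")})
    .Outputs({"out"})
    .SetKernelFn(PD_KERNEL(ToyForward));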
@@ -527,6 +527,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::optional<paddle::Tensor>& kv_signal_data,
    const std::string& cache_quant_type_str,
    const int max_seq_len);

@@ -13,8 +13,8 @@
|
||||
// limitations under the License.
|
||||
|
||||
/*
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
|
||||
* Dao. Licensed under the BSD 3-Clause.
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar,
|
||||
* Pradeep Ramani, Tri Dao. Licensed under the BSD 3-Clause.
|
||||
*
|
||||
* Modified by the FlashInfer team.
|
||||
*/
|
||||
@@ -39,8 +39,8 @@
|
||||
#include "epilogue.cuh"
|
||||
#include "helper.h"
|
||||
#include "kernel_traits.cuh"
|
||||
#include "mainloop_mma.cuh"
|
||||
#include "mainloop_load.cuh"
|
||||
#include "mainloop_mma.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
#ifdef DEBUG_MLA
|
||||
@@ -52,76 +52,91 @@ namespace mla_attn {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename DTypeQ_, typename DTypeKV_, typename DTypeO_, typename IdType_>
|
||||
template <typename DTypeQ_,
|
||||
typename DTypeKV_,
|
||||
typename DTypeO_,
|
||||
typename IdType_>
|
||||
struct Params {
|
||||
using DTypeQ = DTypeQ_;
|
||||
using DTypeKV = DTypeKV_;
|
||||
using DTypeO = DTypeO_;
|
||||
using IdType = IdType_;
|
||||
using DTypeQ = DTypeQ_;
|
||||
using DTypeKV = DTypeKV_;
|
||||
using DTypeO = DTypeO_;
|
||||
using IdType = IdType_;
|
||||
|
||||
alignas(16) DTypeQ *Q; // [token_num, head_num, dim_head]
|
||||
alignas(16) DTypeKV *KV; // [max_block_num, block_size, dim_head]
|
||||
alignas(16) DTypeO *O; // [token_num, head_num, dim_head]
|
||||
alignas(16) DTypeO *O_tmp; // [max_num_chunks, bsz, head_num, dim_head]
|
||||
alignas(16) float *m; // [max_num_chunks, bsz * max_draft_token_num * head_num]
|
||||
alignas(16) float *d; // [max_num_chunks, bsz * max_draft_token_num * head_num]
|
||||
alignas(16) DTypeQ *Q; // [token_num, head_num, dim_head]
|
||||
alignas(16) DTypeKV *KV; // [max_block_num, block_size, dim_head]
|
||||
alignas(16) DTypeO *O; // [token_num, head_num, dim_head]
|
||||
alignas(16) DTypeO *O_tmp; // [max_num_chunks, bsz, head_num, dim_head]
|
||||
alignas(
|
||||
16) float *m; // [max_num_chunks, bsz * max_draft_token_num * head_num]
|
||||
alignas(
|
||||
16) float *d; // [max_num_chunks, bsz * max_draft_token_num * head_num]
|
||||
|
||||
alignas(16) IdType *block_tables;
|
||||
alignas(16) IdType *seq_lens_this_time;
|
||||
alignas(16) IdType *seq_lens_decoder;
|
||||
alignas(16) IdType *cumsum_q_seqlens;
|
||||
alignas(16) IdType *batch_id_per_token;
|
||||
alignas(16) IdType *block_tables;
|
||||
alignas(16) IdType *seq_lens_this_time;
|
||||
alignas(16) IdType *seq_lens_decoder;
|
||||
alignas(16) IdType *cumsum_q_seqlens;
|
||||
alignas(16) IdType *batch_id_per_token;
|
||||
|
||||
alignas(16) IdType *batch_ids;
|
||||
alignas(16) IdType *tile_ids_per_batch;
|
||||
alignas(16) IdType *num_blocks_x;
|
||||
alignas(16) IdType *chunk_size_device;
|
||||
alignas(16) IdType *batch_ids;
|
||||
alignas(16) IdType *tile_ids_per_batch;
|
||||
alignas(16) IdType *num_blocks_x;
|
||||
alignas(16) IdType *chunk_size_device;
|
||||
|
||||
uint32_t q_stride_bsz;
|
||||
uint32_t q_stride_head_num;
|
||||
uint32_t q_stride_bsz;
|
||||
uint32_t q_stride_head_num;
|
||||
|
||||
uint32_t kv_stride_block_num;
|
||||
uint32_t kv_stride_block_size;
|
||||
uint32_t kv_stride_block_num;
|
||||
uint32_t kv_stride_block_size;
|
||||
|
||||
uint32_t o_stride_bsz;
|
||||
uint32_t o_stride_head_num;
|
||||
uint32_t o_stride_bsz;
|
||||
uint32_t o_stride_head_num;
|
||||
|
||||
int bsz;
|
||||
int token_num;
|
||||
int max_block_num;
|
||||
int max_block_num_per_seq;
|
||||
int q_num_head;
|
||||
int qk_head_dim;
|
||||
int vo_head_dim;
|
||||
int block_size;
|
||||
int max_draft_token_num;
|
||||
int chunk_num;
|
||||
int bsz;
|
||||
int token_num;
|
||||
int max_block_num;
|
||||
int max_block_num_per_seq;
|
||||
int q_num_head;
|
||||
int qk_head_dim;
|
||||
int vo_head_dim;
|
||||
int block_size;
|
||||
int max_draft_token_num;
|
||||
int chunk_num;
|
||||
|
||||
float sm_scale;
|
||||
float sm_scale;
|
||||
};
|
||||
|
||||
#define DISPATCH_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
|
||||
if (group_size == 8) { \
|
||||
constexpr size_t GROUP_SIZE = 8; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 16) { \
|
||||
constexpr size_t GROUP_SIZE = 16; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 64) { \
|
||||
constexpr size_t GROUP_SIZE = 64; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
PD_THROW("not support the group_size: ", group_size); \
|
||||
return cudaErrorNotSupported; \
|
||||
#define DISPATCH_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
|
||||
if (group_size == 8) { \
|
||||
constexpr size_t GROUP_SIZE = 8; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 16) { \
|
||||
constexpr size_t GROUP_SIZE = 16; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 64) { \
|
||||
constexpr size_t GROUP_SIZE = 64; \
|
||||
__VA_ARGS__ \
|
||||
} else if (group_size == 128) { \
|
||||
constexpr size_t GROUP_SIZE = 128; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
PD_THROW("not support the group_size: ", group_size); \
|
||||
return cudaErrorNotSupported; \
|
||||
}
|
||||
|
||||
template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL, int SM_COUNT = 132, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
|
||||
__global__ void __launch_bounds__(Ktraits::NUM_WARPS * cutlass::NumThreadsPerWarp, 1)
|
||||
MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
||||
typename CollectiveMainloop::Params const mainloop_params,
|
||||
CUTE_GRID_CONSTANT
|
||||
typename CollectiveEpilogue::Params const epilogue_params) {
|
||||
|
||||
template <typename CollectiveMainloop,
|
||||
typename CollectiveEpilogue,
|
||||
typename Ktraits,
|
||||
bool CAUSAL,
|
||||
int SM_COUNT = 132,
|
||||
bool USE_REG_EALLOC = false,
|
||||
bool USE_FIXED_BLOCK = true>
|
||||
__global__ void __launch_bounds__(
|
||||
Ktraits::NUM_WARPS *cutlass::NumThreadsPerWarp, 1)
|
||||
MLAWithKVCacheKernel(
|
||||
CUTE_GRID_CONSTANT
|
||||
typename CollectiveMainloop::Params const mainloop_params,
|
||||
CUTE_GRID_CONSTANT
|
||||
typename CollectiveEpilogue::Params const epilogue_params) {
|
||||
using DTypeQ = typename Ktraits::DTypeQ;
|
||||
using DTypeKV = typename Ktraits::DTypeKV;
|
||||
using DTypeO = typename Ktraits::DTypeO;
|
||||
@@ -147,7 +162,8 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
||||
using PipelineStateQ = typename MainloopPipelineQ::PipelineState;
|
||||
|
||||
extern __shared__ char shared_memory[];
|
||||
auto& shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
|
||||
auto &shared_storage =
|
||||
*reinterpret_cast<typename Ktraits::SharedStorage *>(shared_memory);
|
||||
|
||||
int const lane_predicate = cute::elect_one_sync();
|
||||
int const warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
@@ -158,12 +174,14 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
||||
}
|
||||
|
||||
// Obtain warp index
|
||||
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
|
||||
int const warp_group_thread_idx =
|
||||
threadIdx.x % cutlass::NumThreadsPerWarpGroup;
|
||||
|
||||
PipelineParams pipeline_params;
|
||||
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
||||
pipeline_params.role = warp_group_idx == 0 ? MainloopPipeline::ThreadCategory::Producer
|
||||
: MainloopPipeline::ThreadCategory::Consumer;
|
||||
pipeline_params.role = warp_group_idx == 0
|
||||
? MainloopPipeline::ThreadCategory::Producer
|
||||
: MainloopPipeline::ThreadCategory::Consumer;
|
||||
if constexpr (use_tma_load_kv) {
|
||||
pipeline_params.is_leader = warp_group_thread_idx == 0;
|
||||
pipeline_params.num_consumers = NUM_MMA_THREADS;
|
||||
@@ -173,17 +191,20 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
||||
}
|
||||
|
||||
PipelineParamsQ pipeline_params_q;
|
||||
pipeline_params_q.role = warp_group_idx == 0 ? MainloopPipelineQ::ThreadCategory::Producer
|
||||
: MainloopPipelineQ::ThreadCategory::Consumer;
|
||||
pipeline_params_q.role = warp_group_idx == 0
|
||||
? MainloopPipelineQ::ThreadCategory::Producer
|
||||
: MainloopPipelineQ::ThreadCategory::Consumer;
|
||||
pipeline_params_q.producer_arv_count = NUM_COPY_THREADS;
|
||||
pipeline_params_q.consumer_arv_count = cutlass::NumThreadsPerWarpGroup; // just one wg qk
|
||||
|
||||
pipeline_params_q.consumer_arv_count =
|
||||
cutlass::NumThreadsPerWarpGroup; // just one wg qk
|
||||
|
||||
MainloopPipelineQ pipeline_q(shared_storage.pipeline_q, pipeline_params_q);
|
||||
MainloopPipeline pipeline_kv = [&] {
|
||||
if constexpr (use_tma_load_kv) {
|
||||
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesKV;
|
||||
return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params,
|
||||
pipeline_params.transaction_bytes =
|
||||
CollectiveMainloop::TmaTransactionBytesKV;
|
||||
return MainloopPipeline(shared_storage.pipeline_kv,
|
||||
pipeline_params,
|
||||
/*cluster_shape=*/Shape<_1, _1, _1>{});
|
||||
} else {
|
||||
return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params);
|
||||
@@ -196,191 +217,217 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
||||
|
||||
if (warp_group_idx == 0) {
|
||||
// producer
|
||||
if constexpr(USE_REG_EALLOC) {
|
||||
if constexpr (USE_REG_EALLOC) {
|
||||
cutlass::arch::warpgroup_reg_dealloc<72>();
|
||||
}
|
||||
const uint32_t warp_idx_in_warpgroup = __shfl_sync(0xffffffff, warp_idx % 4, 0);
|
||||
const uint32_t warp_idx_in_warpgroup =
|
||||
__shfl_sync(0xffffffff, warp_idx % 4, 0);
|
||||
|
||||
PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipelineQ>();
|
||||
PipelineState smem_pipe_write_kv = cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
PipelineStateQ smem_pipe_write_q =
|
||||
cutlass::make_producer_start_state<MainloopPipelineQ>();
|
||||
PipelineState smem_pipe_write_kv =
|
||||
cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
|
||||
const int bid = mainloop_params.batch_ids[i];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[i];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int seq_len_decoder_now =
|
||||
mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||
cutlass::arch::NamedBarrier::sync(
|
||||
Ktraits::NUM_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||
|
||||
// load Q
|
||||
collective_mainloop.load_q(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_write_q,
|
||||
shared_storage,
|
||||
threadIdx.x,
|
||||
bid);
|
||||
collective_mainloop.load_q(mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_write_q,
|
||||
shared_storage,
|
||||
threadIdx.x,
|
||||
bid);
|
||||
|
||||
if constexpr (!use_tma_load_kv) {
|
||||
// load kv
|
||||
collective_mainloop.load_kv(
|
||||
mainloop_params,
|
||||
pipeline_kv,
|
||||
smem_pipe_write_kv,
|
||||
shared_storage,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
tile_id
|
||||
);
|
||||
collective_mainloop.load_kv(mainloop_params,
|
||||
pipeline_kv,
|
||||
smem_pipe_write_kv,
|
||||
shared_storage,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
tile_id);
|
||||
} else {
|
||||
if (warp_idx_in_warpgroup == 0) {
|
||||
// load kv tma
|
||||
collective_mainloop.load_kv_tma(
|
||||
mainloop_params,
|
||||
pipeline_kv,
|
||||
smem_pipe_write_kv,
|
||||
shared_storage,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
tile_id
|
||||
);
|
||||
collective_mainloop.load_kv_tma(mainloop_params,
|
||||
pipeline_kv,
|
||||
smem_pipe_write_kv,
|
||||
shared_storage,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
tile_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// consumer
|
||||
if constexpr(USE_REG_EALLOC) {
|
||||
if constexpr (USE_REG_EALLOC) {
|
||||
cutlass::arch::warpgroup_reg_alloc<216>();
|
||||
}
|
||||
PipelineStateQ smem_pipe_read_q;
|
||||
PipelineState smem_pipe_read_kv;
|
||||
|
||||
typename Ktraits::TiledMmaPVSS tiled_mma_pv;
|
||||
Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));
|
||||
Tensor tOrO =
|
||||
partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));
|
||||
|
||||
auto attention_updater = OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(mainloop_params.sm_scale);
|
||||
auto attention_updater =
|
||||
OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(
|
||||
mainloop_params.sm_scale);
|
||||
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
|
||||
clear(tOrO);
|
||||
clear(attention_updater.scores_scale);
|
||||
const int bid = mainloop_params.batch_ids[i];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[i];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int seq_len_decoder_now =
|
||||
mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||
cutlass::arch::NamedBarrier::sync(
|
||||
Ktraits::NUM_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||
|
||||
if constexpr (BLOCK_SHAPE_KV == 64) {
|
||||
mma_f16<Ktraits, CAUSAL>(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_read_q,
|
||||
pipeline_kv,
|
||||
smem_pipe_read_kv,
|
||||
tOrO,
|
||||
attention_updater,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
seq_len_now,
|
||||
tile_id,
|
||||
shared_storage);
|
||||
mma_f16<Ktraits, CAUSAL>(mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_read_q,
|
||||
pipeline_kv,
|
||||
smem_pipe_read_kv,
|
||||
tOrO,
|
||||
attention_updater,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
seq_len_now,
|
||||
tile_id,
|
||||
shared_storage);
|
||||
} else if (BLOCK_SHAPE_KV == 32) {
|
||||
mma_f16_two_stages<Ktraits, CAUSAL>(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_read_q,
|
||||
pipeline_kv,
|
||||
smem_pipe_read_kv,
|
||||
tOrO,
|
||||
attention_updater,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
seq_len_now,
|
||||
tile_id,
|
||||
shared_storage);
|
||||
mma_f16_two_stages<Ktraits, CAUSAL>(mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_read_q,
|
||||
pipeline_kv,
|
||||
smem_pipe_read_kv,
|
||||
tOrO,
|
||||
attention_updater,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
seq_len_now,
|
||||
tile_id,
|
||||
shared_storage);
|
||||
}
|
||||
|
||||
collective_epilogue.store(
|
||||
epilogue_params,
|
||||
tOrO,
|
||||
attention_updater.get_lse(),
|
||||
shared_storage,
|
||||
tiled_mma_pv,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
mainloop_params.bsz,
|
||||
seq_len_now,
|
||||
start_token_idx,
|
||||
tile_id,
|
||||
seq_len_decoder_now,
|
||||
chunk_size,
|
||||
mainloop_params.max_draft_token_num,
|
||||
mainloop_params.o_stride_bsz);
|
||||
}
|
||||
collective_epilogue.store(epilogue_params,
|
||||
tOrO,
|
||||
attention_updater.get_lse(),
|
||||
shared_storage,
|
||||
tiled_mma_pv,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
mainloop_params.bsz,
|
||||
seq_len_now,
|
||||
start_token_idx,
|
||||
tile_id,
|
||||
seq_len_decoder_now,
|
||||
chunk_size,
|
||||
mainloop_params.max_draft_token_num,
|
||||
mainloop_params.o_stride_bsz);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
|
||||
cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
|
||||
cudaStream_t stream) {
|
||||
template <typename KernelTraits,
|
||||
bool CAUSAL,
|
||||
typename Params,
|
||||
bool USE_REG_EALLOC = false,
|
||||
bool USE_FIXED_BLOCK = true>
|
||||
cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(
|
||||
Params ¶ms, cudaStream_t stream) {
|
||||
using DTypeQ = typename KernelTraits::DTypeQ;
|
||||
using DTypeKV = typename KernelTraits::DTypeKV;
|
||||
using DTypeO = typename KernelTraits::DTypeO;
|
||||
using IdType = typename KernelTraits::IdType;
|
||||
using NV_TYPE = typename KernelTraits::NV_TYPE;
|
||||
|
||||
using CollectiveMainloop =
|
||||
CollectiveMainloop<KernelTraits, CAUSAL>;
|
||||
using CollectiveMainloop = CollectiveMainloop<KernelTraits, CAUSAL>;
|
||||
using CollectiveEpilogue = CollectiveEpilogue<KernelTraits>;
|
||||
|
||||
typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments({
|
||||
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.qk_head_dim), make_stride(params.qk_head_dim, _1{})), // layout q
|
||||
make_layout(make_shape(params.block_size, params.qk_head_dim, params.max_block_num), make_stride(params.qk_head_dim, _1{}, params.block_size * params.qk_head_dim)),
|
||||
make_layout(make_shape(params.chunk_num, params.bsz * params.max_draft_token_num * params.q_num_head), make_stride(params.bsz * params.max_draft_token_num * params.q_num_head, _1{})),
|
||||
params.Q,
|
||||
params.KV,
|
||||
params.m,
|
||||
params.d,
|
||||
params.block_tables,
|
||||
params.seq_lens_this_time,
|
||||
params.seq_lens_decoder,
|
||||
params.cumsum_q_seqlens,
|
||||
params.batch_ids,
|
||||
params.tile_ids_per_batch,
|
||||
params.num_blocks_x,
|
||||
params.chunk_size_device,
|
||||
params.sm_scale,
|
||||
params.bsz,
|
||||
params.max_block_num,
|
||||
params.max_block_num_per_seq,
|
||||
params.q_stride_bsz,
|
||||
params.q_stride_head_num,
|
||||
params.kv_stride_block_num,
|
||||
params.kv_stride_block_size,
|
||||
params.o_stride_bsz,
|
||||
params.o_stride_head_num,
|
||||
params.chunk_num,
|
||||
params.max_draft_token_num
|
||||
});
|
||||
typename CollectiveEpilogue::Params epilogue_params = CollectiveEpilogue::to_underlying_arguments_ntma({
|
||||
params.O,
|
||||
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{})), // layout O
|
||||
params.O_tmp,
|
||||
make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim), make_stride(params.vo_head_dim, _1{})) // layout O_tmp
|
||||
});
|
||||
typename CollectiveMainloop::Params mainloop_params =
|
||||
CollectiveMainloop::to_underlying_arguments(
|
||||
{make_layout(
|
||||
make_shape(KernelTraits::BLOCK_SHAPE_Q, params.qk_head_dim),
|
||||
make_stride(params.qk_head_dim, _1{})), // layout q
|
||||
make_layout(
|
||||
make_shape(
|
||||
params.block_size, params.qk_head_dim, params.max_block_num),
|
||||
make_stride(params.qk_head_dim,
|
||||
_1{},
|
||||
params.block_size * params.qk_head_dim)),
|
||||
make_layout(make_shape(params.chunk_num,
|
||||
params.bsz * params.max_draft_token_num *
|
||||
params.q_num_head),
|
||||
make_stride(params.bsz * params.max_draft_token_num *
|
||||
params.q_num_head,
|
||||
_1{})),
|
||||
params.Q,
|
||||
params.KV,
|
||||
params.m,
|
||||
params.d,
|
||||
params.block_tables,
|
||||
params.seq_lens_this_time,
|
||||
params.seq_lens_decoder,
|
||||
params.cumsum_q_seqlens,
|
||||
params.batch_ids,
|
||||
params.tile_ids_per_batch,
|
||||
params.num_blocks_x,
|
||||
params.chunk_size_device,
|
||||
params.sm_scale,
|
||||
params.bsz,
|
||||
params.max_block_num,
|
||||
params.max_block_num_per_seq,
|
||||
params.q_stride_bsz,
|
||||
params.q_stride_head_num,
|
||||
params.kv_stride_block_num,
|
||||
params.kv_stride_block_size,
|
||||
params.o_stride_bsz,
|
||||
params.o_stride_head_num,
|
||||
params.chunk_num,
|
||||
params.max_draft_token_num});
|
||||
typename CollectiveEpilogue::Params epilogue_params =
|
||||
CollectiveEpilogue::to_underlying_arguments_ntma({
|
||||
params.O,
|
||||
make_layout(
|
||||
make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim),
|
||||
make_stride(params.vo_head_dim, _1{})), // layout O
|
||||
params.O_tmp,
|
||||
make_layout(
|
||||
make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim),
|
||||
make_stride(params.vo_head_dim, _1{})) // layout O_tmp
|
||||
});
|
||||
|
||||
// Get the ptr to kernel function.
|
||||
auto kernel =
|
||||
MLAWithKVCacheKernel<CollectiveMainloop, CollectiveEpilogue, KernelTraits, CAUSAL, 132>;
|
||||
auto kernel = MLAWithKVCacheKernel<CollectiveMainloop,
|
||||
CollectiveEpilogue,
|
||||
KernelTraits,
|
||||
CAUSAL,
|
||||
132>;
|
||||
int smem_size = sizeof(typename KernelTraits::SharedStorage);
|
||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int multiprocessor_count;
|
||||
cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
|
||||
cudaDeviceGetAttribute(
|
||||
&multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
|
||||
int act_blocks_per_sm;
|
||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&act_blocks_per_sm, kernel, KernelTraits::NUM_WARPS * 32, smem_size);
|
||||
@@ -390,15 +437,15 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
|
||||
dim3 grid_dims = {multiprocessor_count, 1, 1};
|
||||
static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32;
|
||||
dim3 block_dims(ctaSize, 1, 1);
|
||||
kernel<<<grid_dims, block_dims, smem_size, stream>>>(
|
||||
mainloop_params, epilogue_params
|
||||
);
|
||||
kernel<<<grid_dims, block_dims, smem_size, stream>>>(mainloop_params,
|
||||
epilogue_params);
|
||||
if (params.chunk_num > 1) {
|
||||
constexpr int vec_size = 16 / sizeof(DTypeO);
|
||||
constexpr int merge_block_size = 256;
|
||||
constexpr int blockx = KernelTraits::HEAD_DIM_VO / vec_size;
|
||||
constexpr int blocky = (merge_block_size + blockx - 1) / blockx;
|
||||
dim3 grids_merge(multiprocessor_count, params.q_num_head); // 128k is too large
|
||||
dim3 grids_merge(multiprocessor_count,
|
||||
params.q_num_head); // 128k is too large
|
||||
dim3 blocks_merge(blockx, blocky);
|
||||
merge_multi_chunks_kernel<NV_TYPE,
|
||||
vec_size,
|
||||
@@ -423,28 +470,35 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
|
||||
return cudaSuccess;
|
||||
}
|
||||
|
||||
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename NV_TYPE, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
|
||||
cudaError_t BatchMLAWithPagedKVCacheDispatched(Params& params, cudaStream_t stream) {
|
||||
template <uint32_t HEAD_DIM_QK,
|
||||
uint32_t HEAD_DIM_VO,
|
||||
typename NV_TYPE,
|
||||
typename Params,
|
||||
bool USE_REG_EALLOC = false,
|
||||
bool USE_FIXED_BLOCK = true>
|
||||
cudaError_t BatchMLAWithPagedKVCacheDispatched(Params ¶ms,
|
||||
cudaStream_t stream) {
|
||||
constexpr bool CAUSAL = true;
|
||||
if constexpr (HEAD_DIM_QK == 576) {
|
||||
DISPATCH_GROUP_SIZE(params.q_num_head, GROUP_SIZE,
|
||||
BatchMLAWithPagedKVCacheKernelTraitsDispatched<
|
||||
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/true,
|
||||
HEAD_DIM_QK,
|
||||
HEAD_DIM_VO,
|
||||
GROUP_SIZE,
|
||||
/*BLOCK_SHAPE_Q_=*/64,
|
||||
/*BLOCK_SHAPE_KV_=*/64,
|
||||
/*NUM_STAGES_=*/2,
|
||||
typename Params::DTypeQ,
|
||||
typename Params::DTypeKV,
|
||||
typename Params::DTypeO,
|
||||
typename Params::IdType,
|
||||
NV_TYPE>,
|
||||
CAUSAL,
|
||||
Params,
|
||||
USE_REG_EALLOC,
|
||||
USE_FIXED_BLOCK>(params, stream);)
|
||||
DISPATCH_GROUP_SIZE(params.q_num_head,
|
||||
GROUP_SIZE,
|
||||
BatchMLAWithPagedKVCacheKernelTraitsDispatched<
|
||||
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/true,
|
||||
HEAD_DIM_QK,
|
||||
HEAD_DIM_VO,
|
||||
GROUP_SIZE,
|
||||
/*BLOCK_SHAPE_Q_=*/64,
|
||||
/*BLOCK_SHAPE_KV_=*/64,
|
||||
/*NUM_STAGES_=*/2,
|
||||
typename Params::DTypeQ,
|
||||
typename Params::DTypeKV,
|
||||
typename Params::DTypeO,
|
||||
typename Params::IdType,
|
||||
NV_TYPE>,
|
||||
CAUSAL,
|
||||
Params,
|
||||
USE_REG_EALLOC,
|
||||
USE_FIXED_BLOCK>(params, stream);)
|
||||
} else {
|
||||
return cudaErrorNotSupported;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
pkill -9 -f python
pkill -9 -f fastdeploy
pkill -9 -f gunicorn
pkill -9 -f redis-server
# Kill redis-server if you need.
#pkill -9 -f redis-server

sleep 1

@@ -159,16 +159,22 @@ class CacheMessager:
|
||||
cache_v = []
|
||||
self.messager = {}
|
||||
for layer_idx in range(self.num_layers):
|
||||
# value cache
|
||||
val_cache_key = f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"
|
||||
if val_cache_key in self.gpu_cache_kvs:
|
||||
val_cache = self.gpu_cache_kvs[val_cache_key]
|
||||
cache_v.append(val_cache)
|
||||
if paddle.is_compiled_with_xpu():
|
||||
cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
|
||||
else:
|
||||
cache_v_ptr_list.append(val_cache.data_ptr())
|
||||
# key cache
|
||||
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
|
||||
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
|
||||
cache_k.append(key_cache)
|
||||
cache_v.append(val_cache)
|
||||
if paddle.is_compiled_with_xpu():
|
||||
cache_k_ptr_list.append(get_peer_mem_addr(key_cache.data_ptr()))
|
||||
cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
|
||||
else:
|
||||
cache_k_ptr_list.append(key_cache.data_ptr())
|
||||
cache_v_ptr_list.append(val_cache.data_ptr())
|
||||
cache_k_ptr_list = np.array(cache_k_ptr_list)
|
||||
cache_v_ptr_list = np.array(cache_v_ptr_list)
|
||||
|
||||
@@ -198,7 +204,6 @@ class CacheMessager:
|
||||
|
||||
elif protocol == "rdma":
|
||||
logger.info(f"splitwise_role rdma: {self.splitwise_role}, rank: {self.rank}, gpu_id: {gpu_id}")
|
||||
|
||||
self.messager[protocol] = RDMACommManager(
|
||||
splitwise_role,
|
||||
rank,
|
||||
@@ -460,16 +465,22 @@ class CacheMessagerV1:
|
||||
cache_v = []
|
||||
self.messager = {}
|
||||
for layer_idx in range(self.num_layers):
|
||||
# value cache
|
||||
val_cache_key = f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"
|
||||
if val_cache_key in self.gpu_cache_kvs:
|
||||
val_cache = self.gpu_cache_kvs[val_cache_key]
|
||||
cache_v.append(val_cache)
|
||||
if paddle.is_compiled_with_xpu():
|
||||
cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
|
||||
else:
|
||||
cache_v_ptr_list.append(val_cache.data_ptr())
|
||||
# key cache
|
||||
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
|
||||
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
|
||||
cache_k.append(key_cache)
|
||||
cache_v.append(val_cache)
|
||||
if paddle.is_compiled_with_xpu():
|
||||
cache_k_ptr_list.append(get_peer_mem_addr(key_cache.data_ptr()))
|
||||
cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
|
||||
else:
|
||||
cache_k_ptr_list.append(key_cache.data_ptr())
|
||||
cache_v_ptr_list.append(val_cache.data_ptr())
|
||||
cache_k_ptr_list = np.array(cache_k_ptr_list)
|
||||
cache_v_ptr_list = np.array(cache_v_ptr_list)
|
||||
|
||||
|
||||
@@ -245,6 +245,15 @@ class PrefixCacheManager:
|
||||
log_dir = envs.FD_LOG_DIR
|
||||
cache_manager_processes = []
|
||||
visible_devices = get_all_visible_devices()
|
||||
|
||||
val_cache_arg_str = ""
|
||||
if val_cache_shape:
|
||||
if isinstance(val_cache_shape, list):
|
||||
val_shape_str = ",".join(map(str, val_cache_shape))
|
||||
else:
|
||||
val_shape_str = str(val_cache_shape)
|
||||
val_cache_arg_str = f" --value_cache_shape {val_shape_str}"
|
||||
|
||||
for i in range(tensor_parallel_size):
|
||||
launch_cmd = (
|
||||
"FLAGS_allocator_strategy=auto_growth "
|
||||
@@ -259,7 +268,7 @@ class PrefixCacheManager:
|
||||
+ f" --mp_num {tensor_parallel_size}"
|
||||
+ f" --cache_dtype {cache_config.cache_dtype}"
|
||||
+ f" --key_cache_shape {key_cache_shape}"
|
||||
+ f" --value_cache_shape {val_cache_shape}"
|
||||
+ val_cache_arg_str
|
||||
+ f" --cache_queue_port {cache_config.cache_queue_port}"
|
||||
+ f" --enable_splitwise {int(self.enable_splitwise)}"
|
||||
+ f" --pod_ip {pod_ip}"
|
||||
@@ -332,6 +341,15 @@ class PrefixCacheManager:
|
||||
log_dir = envs.FD_LOG_DIR
|
||||
cache_messager_processes = []
|
||||
visible_devices = get_all_visible_devices()
|
||||
|
||||
val_cache_arg_str = ""
|
||||
if value_cache_shape:
|
||||
if isinstance(value_cache_shape, list):
|
||||
val_shape_str = ",".join(map(str, value_cache_shape))
|
||||
else:
|
||||
val_shape_str = str(value_cache_shape)
|
||||
val_cache_arg_str = f" --value_cache_shape {val_shape_str}"
|
||||
|
||||
for i in range(tensor_parallel_size):
|
||||
launch_cmd = (
|
||||
"FLAGS_allocator_strategy=auto_growth "
|
||||
@@ -345,7 +363,7 @@ class PrefixCacheManager:
|
||||
+ f" --mp_num {tensor_parallel_size}"
|
||||
+ f" --cache_dtype {cache_config.cache_dtype}"
|
||||
+ f" --key_cache_shape {key_cache_shape}"
|
||||
+ f" --value_cache_shape {value_cache_shape}"
|
||||
+ val_cache_arg_str
|
||||
+ f" --pod_ip {pod_ip}"
|
||||
+ f" --cache_queue_port {cache_config.cache_queue_port}"
|
||||
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
|
||||
|
||||
@@ -198,8 +198,8 @@ int get_port_info(struct ibv_context* Context,
int parse_port_ib_info();

// Memory region exchange
bool client_exchange_mr(struct RdmaContext* ctx);
bool server_exchange_mr(struct RdmaContext* ctx);
bool client_exchange_mr(struct RdmaContext* ctx, bool has_value_cache);
bool server_exchange_mr(struct RdmaContext* ctx, bool has_value_cache);
bool server_send_memory_region(struct RdmaContext* ctx,
                               void* local_mr,
                               int byte_num);

@@ -149,6 +149,7 @@ class RDMACommunicator {
  struct ibv_pd* g_pd = NULL;          // fd
  int RDMACommunicator_status;         // Communicator status flag
  bool start_client_listener = false;  // Client listener flag
  bool has_value_cache_;               // MLA does not have value cache.
};

#endif  // KVCACHE_RDMA_H

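With the extra has_value_cache flag on both exchange functions, the client and server must build the same per-layer message sequence, otherwise the vectors sent and received would no longer line up; MLA deployments simply skip the value entries on both sides. A self-contained sketch of that symmetric plan (hypothetical types and names; the real code exchanges MR addresses and rkeys over the RDMA connection):

#include <cstdio>
#include <string>
#include <vector>

// Sketch: the ordered list of per-layer payloads both peers agree to exchange.
struct ExchangePlan {
  std::vector<std::string> messages;
};

static ExchangePlan build_mr_exchange_plan(int layer_num, bool has_value_cache) {
  ExchangePlan plan;
  plan.messages.push_back("key_ptrs[" + std::to_string(layer_num) + "]");
  plan.messages.push_back("key_rkeys[" + std::to_string(layer_num) + "]");
  if (has_value_cache) {  // key-only (MLA) connections omit these two
    plan.messages.push_back("val_ptrs[" + std::to_string(layer_num) + "]");
    plan.messages.push_back("val_rkeys[" + std::to_string(layer_num) + "]");
  }
  return plan;
}

int main() {
  for (bool has_value_cache : {true, false}) {
    const ExchangePlan plan = build_mr_exchange_plan(61, has_value_cache);
    std::printf("has_value_cache=%d -> %zu messages per side\n",
                static_cast<int>(has_value_cache), plan.messages.size());
  }
  return 0;
}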
+40
-19
@@ -712,8 +712,8 @@ bool exchange_mr_vector(struct RdmaContext *ctx,
|
||||
* @param ctx The RDMA context
|
||||
* @return true on success, false on failure
|
||||
*/
|
||||
bool client_exchange_mr(struct RdmaContext *ctx) {
|
||||
LOGD("verb client exchange mr: start");
|
||||
bool client_exchange_mr(struct RdmaContext *ctx, bool has_value_cache) {
|
||||
LOGD("verb client exchange mr: start. has_value_cache=%d", has_value_cache);
|
||||
|
||||
if (ctx->conn.layer_number <= 0) {
|
||||
ERR("Invalid layer number: %d", ctx->conn.layer_number);
|
||||
@@ -723,19 +723,27 @@ bool client_exchange_mr(struct RdmaContext *ctx) {
|
||||
auto layer_num = ctx->conn.layer_number;
|
||||
std::vector<void *> key_ptrs(layer_num);
|
||||
std::vector<uint32_t> key_rkeys(layer_num);
|
||||
std::vector<void *> val_ptrs(layer_num);
|
||||
std::vector<uint32_t> val_rkeys(layer_num);
|
||||
std::vector<void *> val_ptrs;
|
||||
std::vector<uint32_t> val_rkeys;
|
||||
if (has_value_cache) {
|
||||
val_ptrs.resize(layer_num);
|
||||
val_rkeys.resize(layer_num);
|
||||
}
|
||||
|
||||
if (!exchange_mr_vector(ctx, key_ptrs, true)) return false;
|
||||
if (!exchange_mr_vector(ctx, key_rkeys, true)) return false;
|
||||
if (!exchange_mr_vector(ctx, val_ptrs, true)) return false;
|
||||
if (!exchange_mr_vector(ctx, val_rkeys, true)) return false;
|
||||
if (has_value_cache) {
|
||||
if (!exchange_mr_vector(ctx, val_ptrs, true)) return false;
|
||||
if (!exchange_mr_vector(ctx, val_rkeys, true)) return false;
|
||||
}
|
||||
|
||||
for (int i = 0; i < layer_num; ++i) {
|
||||
ctx->conn.write_cache_key_remote_ptr_list.push_back(key_ptrs[i]);
|
||||
ctx->conn.write_cache_key_remote_rkey_list.push_back(key_rkeys[i]);
|
||||
ctx->conn.write_cache_value_remote_ptr_list.push_back(val_ptrs[i]);
|
||||
ctx->conn.write_cache_value_remote_rkey_list.push_back(val_rkeys[i]);
|
||||
if (has_value_cache) {
|
||||
ctx->conn.write_cache_value_remote_ptr_list.push_back(val_ptrs[i]);
|
||||
ctx->conn.write_cache_value_remote_rkey_list.push_back(val_rkeys[i]);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@@ -746,8 +754,8 @@ bool client_exchange_mr(struct RdmaContext *ctx) {
|
||||
* @param ctx The RDMA context
|
||||
* @return true on success, false on failure
|
||||
*/
|
||||
bool server_exchange_mr(struct RdmaContext *ctx) {
|
||||
LOGD("verbs server exchange mr: start");
|
||||
bool server_exchange_mr(struct RdmaContext *ctx, bool has_value_cache) {
|
||||
LOGD("verbs server exchange mr: start. has_value_cache=%d", has_value_cache);
|
||||
|
||||
if (ctx->conn.layer_number <= 0) {
|
||||
ERR("Invalid layer number: %d", ctx->conn.layer_number);
|
||||
@@ -759,8 +767,16 @@ bool server_exchange_mr(struct RdmaContext *ctx) {
|
||||
auto &val_mrs = ctx->conn.write_cache_value_server_mr_list;
|
||||
|
||||
// Verify that server memory regions are properly initialized
|
||||
if (key_mrs.size() != layer_num || val_mrs.size() != layer_num) {
|
||||
ERR("server write cache memory region size error");
|
||||
if (key_mrs.size() != layer_num) {
|
||||
ERR("server write cache KEY memory region size error: %zu vs %d",
|
||||
key_mrs.size(),
|
||||
layer_num);
|
||||
return false;
|
||||
}
|
||||
if (has_value_cache && val_mrs.size() != layer_num) {
|
||||
ERR("server write cache VALUE memory region size error: %zu vs %d",
|
||||
val_mrs.size(),
|
||||
layer_num);
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -772,22 +788,27 @@ bool server_exchange_mr(struct RdmaContext *ctx) {
|
||||
|
||||
send_key_ptrs.reserve(layer_num);
|
||||
send_key_rkeys.reserve(layer_num);
|
||||
send_val_ptrs.reserve(layer_num);
|
||||
send_val_rkeys.reserve(layer_num);
|
||||
if (has_value_cache) {
|
||||
send_val_ptrs.reserve(layer_num);
|
||||
send_val_rkeys.reserve(layer_num);
|
||||
}
|
||||
|
||||
// Collect memory region information from local MRs
|
||||
for (int i = 0; i < layer_num; ++i) {
|
||||
send_key_ptrs.push_back(reinterpret_cast<uint64_t>(key_mrs[i]->addr));
|
||||
send_key_rkeys.push_back(key_mrs[i]->rkey);
|
||||
send_val_ptrs.push_back(reinterpret_cast<uint64_t>(val_mrs[i]->addr));
|
||||
send_val_rkeys.push_back(val_mrs[i]->rkey);
|
||||
if (has_value_cache) {
|
||||
send_val_ptrs.push_back(reinterpret_cast<uint64_t>(val_mrs[i]->addr));
|
||||
send_val_rkeys.push_back(val_mrs[i]->rkey);
|
||||
}
|
||||
}
|
||||
|
||||
// Send all vectors to client
|
||||
if (!exchange_mr_vector(ctx, send_key_ptrs, false)) return false;
|
||||
if (!exchange_mr_vector(ctx, send_key_rkeys, false)) return false;
|
||||
if (!exchange_mr_vector(ctx, send_val_ptrs, false)) return false;
|
||||
if (!exchange_mr_vector(ctx, send_val_rkeys, false)) return false;
|
||||
if (has_value_cache) {
|
||||
if (!exchange_mr_vector(ctx, send_val_ptrs, false)) return false;
|
||||
if (!exchange_mr_vector(ctx, send_val_rkeys, false)) return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -78,6 +78,18 @@ RDMACommunicator::RDMACommunicator(std::string& role,
    throw std::runtime_error("Invalid layer number");
  }

  if (local_cache_value_ptr_layer_head_.empty()) {
    has_value_cache_ = false;
    WARN(
        "Value Cache is empty (Maybe MLA Model). RDMA will run in Key-Only "
        "mode.");
  } else {
    has_value_cache_ = true;
    if (local_cache_value_ptr_layer_head_.size() != layer_number) {
      throw std::runtime_error("Key and Value cache layer number mismatch!");
    }
  }

  // Step 2: Setup cache vectors and pointers
  resize_vectors();
  assign_pointers();
@@ -100,7 +112,6 @@ RDMACommunicator::RDMACommunicator(std::string& role,
|
||||
});
|
||||
server_thread.detach();
|
||||
}
|
||||
|
||||
RDMACommunicator_status = 1;
|
||||
INFO("RDMA communicator initialized successfully");
|
||||
} catch (const std::exception& e) {
|
||||
@@ -119,7 +130,9 @@ void RDMACommunicator::resize_vectors() {
|
||||
}
|
||||
|
||||
local_cache_key_ptr_per_layer.resize(layer_number);
|
||||
local_cache_value_ptr_per_layer.resize(layer_number);
|
||||
if (has_value_cache_) {
|
||||
local_cache_value_ptr_per_layer.resize(layer_number);
|
||||
}
|
||||
}
|
||||
|
||||
void RDMACommunicator::assign_pointers() {
|
||||
@@ -131,15 +144,19 @@ void RDMACommunicator::assign_pointers() {
|
||||
// Assign pointers for each layer and block
|
||||
for (int layer_idx = 0; layer_idx < layer_number; ++layer_idx) {
|
||||
// Validate layer head pointers
|
||||
if (local_cache_key_ptr_layer_head_[layer_idx] == 0 ||
|
||||
local_cache_value_ptr_layer_head_[layer_idx] == 0) {
|
||||
if (local_cache_key_ptr_layer_head_[layer_idx] == 0) {
|
||||
throw std::runtime_error("Invalid cache pointer for layer " +
|
||||
std::to_string(layer_idx));
|
||||
}
|
||||
|
||||
// Resize block vectors for current layer
|
||||
local_cache_key_ptr_per_layer[layer_idx].resize(block_number);
|
||||
local_cache_value_ptr_per_layer[layer_idx].resize(block_number);
|
||||
|
||||
if (has_value_cache_) {
|
||||
if (local_cache_value_ptr_layer_head_[layer_idx] == 0) {
|
||||
throw std::runtime_error("Invalid VALUE cache pointer for layer " +
|
||||
std::to_string(layer_idx));
|
||||
}
|
||||
local_cache_value_ptr_per_layer[layer_idx].resize(block_number);
|
||||
}
|
||||
|
||||
// Calculate and assign block pointers
|
||||
for (int block_idx = 0; block_idx < block_number; ++block_idx) {
|
||||
@@ -147,9 +164,12 @@ void RDMACommunicator::assign_pointers() {
|
||||
reinterpret_cast<void*>(local_cache_key_ptr_layer_head_[layer_idx] +
|
||||
block_idx * block_size_byte);
|
||||
|
||||
local_cache_value_ptr_per_layer[layer_idx][block_idx] =
|
||||
reinterpret_cast<void*>(local_cache_value_ptr_layer_head_[layer_idx] +
|
||||
block_idx * block_size_byte);
|
||||
if (has_value_cache_) {
|
||||
local_cache_value_ptr_per_layer[layer_idx][block_idx] =
|
||||
reinterpret_cast<void*>(
|
||||
local_cache_value_ptr_layer_head_[layer_idx] +
|
||||
block_idx * block_size_byte);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -347,7 +367,7 @@ int RDMACommunicator::start_server(int sport, int sgid_idx, int gpu_index) {
|
||||
continue;
|
||||
}
|
||||
|
||||
server_exchange_mr(ctx);
|
||||
server_exchange_mr(ctx, has_value_cache_);
|
||||
} else {
|
||||
auto ctx_iter = connectionContexts.find(event_fd);
|
||||
if (ctx_iter == connectionContexts.end()) {
|
||||
@@ -435,18 +455,33 @@ bool RDMACommunicator::deregister_memory_regions(struct RdmaContext* ctx) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int layer_idx = 0; layer_idx < layer_number; layer_idx++) {
|
||||
if (!write_mr_key_list.empty() && !write_mr_value_list.empty()) {
|
||||
if (ibv_dereg_mr(write_mr_key_list[layer_idx])) {
|
||||
ERR("Failed to deregister memory region: write_mr_key_list, layer %d",
|
||||
layer_idx);
|
||||
}
|
||||
if (ibv_dereg_mr(write_mr_value_list[layer_idx])) {
|
||||
ERR("Failed to deregister memory region: write_mr_value_list, layer %d",
|
||||
layer_idx);
|
||||
if (!write_mr_key_list.empty()) {
|
||||
for (int layer_idx = 0; layer_idx < layer_number; layer_idx++) {
|
||||
if (write_mr_key_list[layer_idx]) {
|
||||
if (ibv_dereg_mr(write_mr_key_list[layer_idx])) {
|
||||
ERR("Failed to deregister memory region: write_mr_key_list, layer %d",
|
||||
layer_idx);
|
||||
}
|
||||
write_mr_key_list[layer_idx] = nullptr;
|
||||
}
|
||||
}
|
||||
write_mr_key_list.clear();
|
||||
}
|
||||
|
||||
if (!write_mr_value_list.empty()) {
|
||||
for (int layer_idx = 0; layer_idx < layer_number; layer_idx++) {
|
||||
if (write_mr_value_list[layer_idx]) {
|
||||
if (ibv_dereg_mr(write_mr_value_list[layer_idx])) {
|
||||
ERR("Failed to deregister memory region: write_mr_value_list, layer "
|
||||
"%d",
|
||||
layer_idx);
|
||||
}
|
||||
write_mr_value_list[layer_idx] = nullptr;
|
||||
}
|
||||
}
|
||||
write_mr_value_list.clear();
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -548,7 +583,7 @@ int RDMACommunicator::connect(const std::string& dst_ip,
|
||||
ERR("Couldn't getexchange port infodestinations");
|
||||
return static_cast<int>(ConnStatus::kError);
|
||||
} else {
|
||||
client_exchange_mr(ctx);
|
||||
client_exchange_mr(ctx, has_value_cache_);
|
||||
}
|
||||
|
||||
// Allocate RDMA read and register read buffers
|
||||
@@ -735,15 +770,17 @@ bool RDMACommunicator::client_mr_register_per_layer(RdmaContext* ctx) {
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
|
||||
if (!write_mr_key_list.empty() || !write_mr_value_list.empty()) {
|
||||
if (!write_mr_key_list.empty()) {
|
||||
WARN("Memory regions already registered");
|
||||
return true;
|
||||
}
|
||||
|
||||
const size_t list_size = layer_number;
|
||||
write_mr_key_list.resize(list_size, nullptr);
|
||||
write_mr_value_list.resize(list_size, nullptr);
|
||||
|
||||
if (has_value_cache_) {
|
||||
write_mr_value_list.resize(list_size, nullptr);
|
||||
}
|
||||
|
||||
const uint32_t access_flags =
|
||||
IBV_ACCESS_LOCAL_WRITE |
|
||||
@@ -753,8 +790,6 @@ bool RDMACommunicator::client_mr_register_per_layer(RdmaContext* ctx) {
|
||||
|
||||
for (int i = 0; i < static_cast<int>(list_size); ++i) {
|
||||
void* key_ptr = reinterpret_cast<void*>(local_cache_key_ptr_layer_head_[i]);
|
||||
void* val_ptr =
|
||||
reinterpret_cast<void*>(local_cache_value_ptr_layer_head_[i]);
|
||||
size_t size = static_cast<size_t>(block_size_byte) * block_number;
|
||||
|
||||
write_mr_key_list[i] =
|
||||
@@ -765,13 +800,18 @@ bool RDMACommunicator::client_mr_register_per_layer(RdmaContext* ctx) {
|
||||
access_flags);
|
||||
if (!write_mr_key_list[i]) goto fail;
|
||||
|
||||
write_mr_value_list[i] =
|
||||
register_memory_region(ctx->pd,
|
||||
val_ptr,
|
||||
size,
|
||||
"client_value_" + std::to_string(i),
|
||||
access_flags);
|
||||
if (!write_mr_value_list[i]) goto fail;
|
||||
if (has_value_cache_) {
|
||||
void* val_ptr =
|
||||
reinterpret_cast<void*>(local_cache_value_ptr_layer_head_[i]);
|
||||
|
||||
write_mr_value_list[i] =
|
||||
register_memory_region(ctx->pd,
|
||||
val_ptr,
|
||||
size,
|
||||
"client_value_" + std::to_string(i),
|
||||
access_flags);
|
||||
if (!write_mr_value_list[i]) goto fail;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
@@ -812,8 +852,6 @@ bool RDMACommunicator::server_mr_register_per_layer(RdmaContext* ctx) {
|
||||
|
||||
for (int i = 0; i < layer_number; ++i) {
|
||||
void* key_ptr = reinterpret_cast<void*>(local_cache_key_ptr_layer_head_[i]);
|
||||
void* val_ptr =
|
||||
reinterpret_cast<void*>(local_cache_value_ptr_layer_head_[i]);
|
||||
size_t size = static_cast<size_t>(block_size_byte) * block_number;
|
||||
|
||||
struct ibv_mr* key_mr = register_memory_region(
|
||||
@@ -822,21 +860,25 @@ bool RDMACommunicator::server_mr_register_per_layer(RdmaContext* ctx) {
|
||||
ERR("Failed to register key MR at layer %d", i);
|
||||
goto fail;
|
||||
}
|
||||
|
||||
struct ibv_mr* value_mr = register_memory_region(
|
||||
ctx->pd, val_ptr, size, "value_" + std::to_string(i), access_flags);
|
||||
if (!value_mr) {
|
||||
ERR("Failed to register value MR at layer %d", i);
|
||||
ibv_dereg_mr(key_mr);
|
||||
goto fail;
|
||||
}
|
||||
|
||||
write_cache_key_server_mr_list.push_back(key_mr);
|
||||
write_cache_value_server_mr_list.push_back(value_mr);
|
||||
|
||||
if (has_value_cache_) {
|
||||
void* val_ptr =
|
||||
reinterpret_cast<void*>(local_cache_value_ptr_layer_head_[i]);
|
||||
struct ibv_mr* value_mr = register_memory_region(
|
||||
ctx->pd, val_ptr, size, "value_" + std::to_string(i), access_flags);
|
||||
if (!value_mr) {
|
||||
ERR("Failed to register value MR at layer %d", i);
|
||||
ibv_dereg_mr(key_mr);
|
||||
goto fail;
|
||||
}
|
||||
write_cache_value_server_mr_list.push_back(value_mr);
|
||||
}
|
||||
}
|
||||
|
||||
ctx->conn.write_cache_key_server_mr_list = write_cache_key_server_mr_list;
|
||||
ctx->conn.write_cache_value_server_mr_list = write_cache_value_server_mr_list;
|
||||
|
||||
return true;
|
||||
|
||||
fail:
|
||||
@@ -899,8 +941,12 @@ int RDMACommunicator::write_cache(const std::string& ip,

  uint32_t cache_key_rkey =
      ctx->conn.write_cache_key_remote_rkey_list[layer_idx];
  uint32_t cache_value_rkey =
      ctx->conn.write_cache_value_remote_rkey_list[layer_idx];

  uint32_t cache_value_rkey = 0;
  if (has_value_cache_) {
    cache_value_rkey = ctx->conn.write_cache_value_remote_rkey_list[layer_idx];
  }

  uint32_t crc_cache_key_rkey, crc_cache_value_rkey;
  bool pd_tp_size_is_same = prefill_tp_size == ctx->conn.decode_tp_size;
  uint64_t offset_in_block =
@@ -914,15 +960,19 @@ int RDMACommunicator::write_cache(const std::string& ip,
    cache_key_remote_addr[block_index] = (uint64_t(
        char_ptr + remote_block_ids[block_index] * total_block_size_byte +
        offset_in_block));
    char_ptr = static_cast<char*>(
        ctx->conn.write_cache_value_remote_ptr_list[layer_idx]);
    cache_value_remote_addr[block_index] = (uint64_t(
        char_ptr + remote_block_ids[block_index] * total_block_size_byte +
        offset_in_block));

    if (has_value_cache_) {
      char_ptr = static_cast<char*>(
          ctx->conn.write_cache_value_remote_ptr_list[layer_idx]);
      cache_value_remote_addr[block_index] = (uint64_t(
          char_ptr + remote_block_ids[block_index] * total_block_size_byte +
          offset_in_block));
    }
  }

  ctx->conn.wc_target_count = 0;
  for (int i = 0; i < 2; ++i) {
  int loop_count = has_value_cache_ ? 2 : 1;
  for (int i = 0; i < loop_count; ++i) {
    bool is_key = (i == 0);
    uint32_t rkey = (is_key ? cache_key_rkey : cache_value_rkey);
    std::vector<uint64_t>& remote_addr =
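When there is no separate value cache, the write path collapses to a single pass: the value rkey defaults to 0 and the request loop runs once instead of twice. A rough Python sketch of the same control flow, with illustrative names:

# Sketch only: how the write loop degenerates to a key-only pass.
def plan_write_passes(has_value_cache, key_rkey, value_rkey_list=None, layer_idx=0):
    # Without a value cache there is no remote rkey to look up; keep a dummy 0.
    value_rkey = value_rkey_list[layer_idx] if has_value_cache else 0
    loop_count = 2 if has_value_cache else 1
    passes = []
    for i in range(loop_count):
        is_key = (i == 0)
        passes.append(("key" if is_key else "value", key_rkey if is_key else value_rkey))
    return passes

print(plan_write_passes(False, key_rkey=0x11))             # [('key', 17)]
print(plan_write_passes(True, 0x11, [0x22], layer_idx=0))  # key pass + value pass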
@@ -1038,6 +1088,10 @@ void RDMACommunicator::prepare_write_requests(
    bool is_key,
    std::vector<uint64_t>& remote_addr,
    uint32_t rkey) {
  if (!is_key) {
    assert(!write_mr_value_list.empty() &&
           "Trying to process Value Cache but it is empty!");
  }
  auto block_num = local_block_ids.size();

  for (size_t i = 0; i < block_num; ++i) {
@@ -40,11 +40,10 @@ class RDMACommManager:
        try:
            import rdma_comm
        except:
            logger.error(
            raise RuntimeError(
                "The installation of the RDMA library failed."
                "Confirm whether your network card supports RDMA transmission."
            )
            return
        self.messager = rdma_comm.RDMACommunicator(
            splitwise_role,
            gpu_id,
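A failed rdma_comm import is now fatal instead of being logged and swallowed, so a node that cannot do RDMA refuses to come up half-configured without a messager. The resulting guard looks roughly like this (the sketch narrows the bare except to ImportError; the diff keeps the bare except):

try:
    import rdma_comm  # optional extension, present only on RDMA-capable machines
except ImportError as err:
    raise RuntimeError(
        "The installation of the RDMA library failed."
        "Confirm whether your network card supports RDMA transmission."
    ) from err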
@@ -755,8 +755,9 @@ class LLMEngine:
                        local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
                    )
                )
                ctx = multiprocessing.get_context("spawn")
                self.dp_processed.append(
                    multiprocessing.Process(
                    ctx.Process(
                        target=start_data_parallel_service,
                        args=(
                            self.cfg,
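Data-parallel engine workers are now created from an explicit spawn context instead of the interpreter's default start method, which avoids forking a parent process that may already hold CUDA state. A minimal, self-contained sketch of the same pattern (the service function and its args are made up):

import multiprocessing

def start_data_parallel_service(cfg, rank):
    # Placeholder for the real per-rank service entry point.
    print("DP worker", rank, "started with cfg:", cfg)

if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    dp_processed = [
        ctx.Process(target=start_data_parallel_service, args=({"tp_size": 1}, rank))
        for rank in range(2)
    ]
    for p in dp_processed:
        p.start()
    for p in dp_processed:
        p.join()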
@@ -205,7 +205,6 @@ class MLAAttentionBackend(AttentionBackend):
            self.group_size,
            self.block_size,
        )

        # MLA
        metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1]
        metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]
@@ -279,6 +278,7 @@ class MLAAttentionBackend(AttentionBackend):
            forward_meta.batch_id_per_token,
            forward_meta.cu_seqlens_q,
            metadata.block_tables,
            metadata.kv_signal_data_list[layer.layer_id],
            "none",
            getattr(forward_meta, "max_input_length", -1),
        )
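The new trailing argument threads a per-layer signal buffer into the cache-write op, so a PD-disaggregated consumer can be told, layer by layer, when that layer's KV has been written. A loose sketch of the plumbing, assuming kv_signal_data_list is indexed by layer id and may hold None when signaling is off:

# Sketch only: names mirror the metadata field in the diff, the op is a dummy.
class Metadata:
    def __init__(self, num_layers):
        self.kv_signal_data_list = [None] * num_layers

def fake_write_cache_op(*tensor_args, kv_signal_data=None):
    return "signal attached" if kv_signal_data is not None else "no signal"

meta = Metadata(num_layers=2)
meta.kv_signal_data_list[1] = object()  # pretend layer 1 carries a signal tensor
for layer_id in range(2):
    print(layer_id, fake_write_cache_op("kv_nope", "kv_pe", kv_signal_data=meta.kv_signal_data_list[layer_id]))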
@@ -422,10 +422,10 @@ class MLAAttentionBackend(AttentionBackend):
            forward_meta.batch_id_per_token,
            forward_meta.cu_seqlens_q,
            metadata.block_tables,
            metadata.kv_signal_data_list[layer.layer_id],
            "none",
            self.max_seq_len,
        )

        # FA
        fmha_out = self.flash_attn_func(
            q,
@@ -307,6 +307,7 @@ class MetaxMLAAttentionBackend(AttentionBackend):
            forward_meta.batch_id_per_token,
            forward_meta.cu_seqlens_q,
            metadata.block_tables,
            metadata.kv_signal_data_list[layer.layer_id],
            "none",
            getattr(forward_meta, "max_input_length", -1),
        )
@@ -258,10 +258,10 @@ class FusedMoE(nn.Layer):
        else:
            SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}

        if not param._is_initialized():
            param.initialize()
        if not (expert_id - self.expert_id_offset >= 0 and expert_id - self.expert_id_offset < self.num_local_experts):
            return
        if not param._is_initialized():
            param.initialize()
        weight_need_transpose = getattr(param, "weight_need_transpose", False)
        if shard_id is None:
            # 1.gate up fused in disk
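The loader now checks that the expert falls inside this rank's local slice before materializing the parameter, so weights for experts owned by other ranks are skipped without allocating anything. The check itself is a half-open interval test:

def is_local_expert(expert_id: int, expert_id_offset: int, num_local_experts: int) -> bool:
    # True only if expert_id lies in [expert_id_offset, expert_id_offset + num_local_experts).
    local_idx = expert_id - expert_id_offset
    return 0 <= local_idx < num_local_experts

assert is_local_expert(5, expert_id_offset=4, num_local_experts=4)       # local slot 1
assert not is_local_expert(12, expert_id_offset=4, num_local_experts=4)  # belongs to another rank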
@@ -341,6 +341,7 @@ class DeepseekV3MLAAttention(nn.Layer):

        # NOTE: (changwenbin) qkv_a_proj horizontal fusion
        qkv_a_out = self.qkv_a_proj_with_mqa(hidden_states)

        query, compressed_kv, key_pe = qkv_a_out.split(
            [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], axis=-1
        )
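Because the query, compressed-KV and rotary-key projections are fused into one qkv_a matmul, the output has to be split back into three parts along the last axis by their rank sizes. A numpy stand-in for the same split (sizes are illustrative, not the real DeepSeek-V3 dims; numpy's split takes cumulative indices where paddle's takes section sizes):

import numpy as np

q_lora_rank, kv_lora_rank, qk_rope_head_dim = 8, 16, 4  # toy sizes
qkv_a_out = np.zeros((2, q_lora_rank + kv_lora_rank + qk_rope_head_dim))

query, compressed_kv, key_pe = np.split(
    qkv_a_out, [q_lora_rank, q_lora_rank + kv_lora_rank], axis=-1
)
print(query.shape, compressed_kv.shape, key_pe.shape)  # (2, 8) (2, 16) (2, 4)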
@@ -399,6 +400,7 @@ class DeepseekV3MLAAttention(nn.Layer):
                self.num_attention_heads_tp * (self.kv_lora_rank + self.qk_rope_head_dim),
            ]
        )

        fmha_out_decode = self.mla_attn(
            q=q_input,
            k=None,
@@ -418,6 +420,7 @@ class DeepseekV3MLAAttention(nn.Layer):
                .transpose([1, 0, 2])
                .reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
            )

            if fmha_out is None:
                fmha_out = fmha_out_decode
            else:
@@ -515,6 +518,7 @@ class DeepSeekV3DecoderLayer(nn.Layer):

        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        return hidden_states, residual

@@ -674,7 +678,6 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
        process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()), self.fd_config)
        for loaded_weight_name, loaded_weight in weights_iterator:
            loaded_weight_name = loaded_weight_name.replace("deepseek_v3", "model")

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in loaded_weight_name:
                    continue
@@ -741,6 +744,20 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
        )
        return position_ids, mask_encoder_batch

    def empty_input_forward(self):
        """
        empty_input_forward
        """
        fake_hidden_states = paddle.empty(
            shape=[1, self.fd_config.model_config.hidden_size],
            dtype=paddle.get_default_dtype(),
        )
        for i in range(
            self.fd_config.model_config.first_k_dense_replace,
            self.fd_config.model_config.num_hidden_layers,
        ):
            self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate)

    def forward(
        self,
        ids_remove_padding: paddle.Tensor,
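empty_input_forward pushes a single fake hidden state through the expert blocks of every MoE layer, skipping the dense leading layers below first_k_dense_replace, which have no experts. Presumably this keeps expert-parallel ranks participating in dispatch even when this rank has no real tokens in a PD/DP deployment. A self-contained sketch of the idea, with dummy layers standing in for the paddle modules:

# Sketch only: FakeMoELayer stands in for self.model.layers[i].mlp.
class FakeMoELayer:
    def experts(self, hidden_states, gate):
        # Stand-in for expert dispatch, which may involve all-to-all communication.
        return [[v * 0.0 for v in row] for row in hidden_states]

def empty_input_forward(layers, hidden_size, first_k_dense_replace):
    fake_hidden_states = [[0.0] * hidden_size]  # one token's worth of zeros
    for i in range(first_k_dense_replace, len(layers)):
        layers[i].experts(fake_hidden_states, gate=None)

empty_input_forward([FakeMoELayer() for _ in range(4)], hidden_size=8, first_k_dense_replace=1)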
@@ -2328,7 +2328,6 @@ class GPUModelRunner(ModelRunnerBase):
            self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
            group=self.parallel_config.tp_group,
        )

        # 5. Post Process
        model_output_data = ModelOutputData(
            next_tokens=self.share_inputs["next_tokens"],