mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
DSA clean code (#6827)
This commit is contained in:
@@ -46,13 +46,6 @@ std::vector<paddle::Tensor> PrefillDSMLAWriteCacheFP8(
|
||||
const paddle::Tensor& kv_nope,
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& slot_mapping,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* kv_cache) {
|
||||
typedef PDTraits<T> traits_;
|
||||
@@ -90,34 +83,6 @@ std::vector<paddle::Tensor> PrefillDSMLAWriteCacheFP8(
|
||||
kv_lora_rank,
|
||||
pe_dim,
|
||||
block_size);
|
||||
|
||||
// Handle PD disaggregation signal
|
||||
const char* fmt_write_cache_completed_signal_str =
|
||||
std::getenv("FLAGS_fmt_write_cache_completed_signal");
|
||||
const char* FLAGS_use_pd_disaggregation_per_chunk =
|
||||
std::getenv("FLAGS_use_pd_disaggregation_per_chunk");
|
||||
|
||||
if (fmt_write_cache_completed_signal_str &&
|
||||
(std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 ||
|
||||
std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) {
|
||||
if (FLAGS_use_pd_disaggregation_per_chunk &&
|
||||
(std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 ||
|
||||
std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) {
|
||||
cudaLaunchHostFunc(
|
||||
stream,
|
||||
&(RemoteCacheKvIpc::
|
||||
save_cache_kv_complete_signal_layerwise_per_query),
|
||||
(void*)nullptr);
|
||||
} else {
|
||||
if (kv_signal_data) {
|
||||
cudaLaunchHostFunc(
|
||||
stream,
|
||||
&RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise,
|
||||
(void*)(const_cast<int64_t*>(
|
||||
kv_signal_data.get().data<int64_t>())));
|
||||
}
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
@@ -130,13 +95,6 @@ std::vector<paddle::Tensor> DecodeDSMLAWriteCacheFP8(
|
||||
const paddle::Tensor& kv_nope,
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& slot_mapping,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const int max_seq_len,
|
||||
const bool speculate_decoder,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* kv_cache) {
|
||||
typedef PDTraits<T> traits_;
|
||||
@@ -187,14 +145,7 @@ std::vector<paddle::Tensor> PrefillDSMLAWriteCache(
|
||||
const paddle::Tensor& kv_nope,
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& slot_mapping,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||
const float* scale,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* kv_cache) {
|
||||
typedef PDTraits<T> traits_;
|
||||
@@ -232,33 +183,6 @@ std::vector<paddle::Tensor> PrefillDSMLAWriteCache(
|
||||
block_size,
|
||||
scale);
|
||||
|
||||
// Handle PD disaggregation signal
|
||||
const char* fmt_write_cache_completed_signal_str =
|
||||
std::getenv("FLAGS_fmt_write_cache_completed_signal");
|
||||
const char* FLAGS_use_pd_disaggregation_per_chunk =
|
||||
std::getenv("FLAGS_use_pd_disaggregation_per_chunk");
|
||||
|
||||
if (fmt_write_cache_completed_signal_str &&
|
||||
(std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 ||
|
||||
std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) {
|
||||
if (FLAGS_use_pd_disaggregation_per_chunk &&
|
||||
(std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 ||
|
||||
std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) {
|
||||
cudaLaunchHostFunc(
|
||||
stream,
|
||||
&(RemoteCacheKvIpc::
|
||||
save_cache_kv_complete_signal_layerwise_per_query),
|
||||
(void*)nullptr);
|
||||
} else {
|
||||
if (kv_signal_data) {
|
||||
cudaLaunchHostFunc(
|
||||
stream,
|
||||
&RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise,
|
||||
(void*)(const_cast<int64_t*>(
|
||||
kv_signal_data.get().data<int64_t>())));
|
||||
}
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
@@ -372,15 +296,8 @@ std::vector<paddle::Tensor> DSMLAWriteCacheKernel(
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& kv_cache,
|
||||
const paddle::Tensor& slot_mapping,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||
const paddle::optional<paddle::Tensor>& scale,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int max_seq_len,
|
||||
const bool is_prefill) {
|
||||
cudaStream_t stream = kv_pe.stream();
|
||||
AppendAttnMetaData meta_data;
|
||||
@@ -395,9 +312,7 @@ std::vector<paddle::Tensor> DSMLAWriteCacheKernel(
|
||||
meta_data.token_nums = kv_nope_dims[0];
|
||||
meta_data.head_dims = kv_cache_dims[3];
|
||||
meta_data.head_dims_v = nope_size;
|
||||
meta_data.max_blocks_per_seq = block_tables.dims()[1];
|
||||
meta_data.block_size = kv_cache_dims[2];
|
||||
meta_data.batch_size = seq_lens_decoder.dims()[0];
|
||||
|
||||
const float* scale_ptr = scale ? scale.get().data<float>() : nullptr;
|
||||
|
||||
@@ -411,13 +326,6 @@ std::vector<paddle::Tensor> DSMLAWriteCacheKernel(
|
||||
kv_nope,
|
||||
kv_pe,
|
||||
slot_mapping,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
kv_signal_data,
|
||||
max_seq_len,
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&kv_cache));
|
||||
} else {
|
||||
@@ -426,13 +334,6 @@ std::vector<paddle::Tensor> DSMLAWriteCacheKernel(
|
||||
kv_nope,
|
||||
kv_pe,
|
||||
slot_mapping,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
max_seq_len,
|
||||
false,
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&kv_cache));
|
||||
}
|
||||
@@ -444,13 +345,6 @@ std::vector<paddle::Tensor> DSMLAWriteCacheKernel(
|
||||
kv_nope,
|
||||
kv_pe,
|
||||
slot_mapping,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
kv_signal_data,
|
||||
max_seq_len,
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&kv_cache));
|
||||
} else {
|
||||
@@ -459,13 +353,6 @@ std::vector<paddle::Tensor> DSMLAWriteCacheKernel(
|
||||
kv_nope,
|
||||
kv_pe,
|
||||
slot_mapping,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
max_seq_len,
|
||||
false,
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&kv_cache));
|
||||
}
|
||||
@@ -482,14 +369,7 @@ std::vector<paddle::Tensor> DSMLAWriteCacheKernel(
|
||||
kv_nope,
|
||||
kv_pe,
|
||||
slot_mapping,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
kv_signal_data,
|
||||
scale_ptr,
|
||||
max_seq_len,
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&kv_cache));
|
||||
}
|
||||
@@ -499,14 +379,7 @@ std::vector<paddle::Tensor> DSMLAWriteCacheKernel(
|
||||
kv_nope,
|
||||
kv_pe,
|
||||
slot_mapping,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
kv_signal_data,
|
||||
scale_ptr,
|
||||
max_seq_len,
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&kv_cache));
|
||||
}
|
||||
@@ -588,18 +461,10 @@ PD_BUILD_STATIC_OP(ds_mla_write_cache)
|
||||
"kv_pe",
|
||||
"kv_cache",
|
||||
"slot_mapping",
|
||||
"seq_lens",
|
||||
"seq_lens_decoder",
|
||||
"batch_id_per_token",
|
||||
"cu_seqlens_q",
|
||||
"block_tables",
|
||||
paddle::Optional("kv_signal_data"),
|
||||
paddle::Optional("scale")})
|
||||
.Outputs({"kv_cache_out"})
|
||||
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
|
||||
.Attrs({"cache_quant_type_str: std::string",
|
||||
"max_seq_len: int",
|
||||
"is_prefill: bool"})
|
||||
.Attrs({"cache_quant_type_str: std::string", "is_prefill: bool"})
|
||||
.SetKernelFn(PD_KERNEL(DSMLAWriteCacheKernel));
|
||||
|
||||
PD_BUILD_STATIC_OP(indexer_k_quant_and_cache)
|
||||
|
||||
@@ -1215,15 +1215,8 @@ std::vector<paddle::Tensor> DSMLAWriteCacheKernel(
|
||||
const paddle::Tensor& kv_pe,
|
||||
const paddle::Tensor& kv_cache,
|
||||
const paddle::Tensor& slot_mapping,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||
const paddle::optional<paddle::Tensor>& scale,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int max_seq_len,
|
||||
const bool is_prefill);
|
||||
|
||||
std::vector<paddle::Tensor> IndexerKQuantAndCacheKernel(
|
||||
|
||||
@@ -380,15 +380,8 @@ class DSAAttentionBackend(AttentionBackend):
|
||||
k_pe,
|
||||
latent_cache,
|
||||
metadata.slot_mapping,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.batch_id_per_token,
|
||||
forward_meta.cu_seqlens_q,
|
||||
metadata.block_tables,
|
||||
None,
|
||||
scale.cast(paddle.float32),
|
||||
"fp8_ds_mla",
|
||||
self.max_seq_len,
|
||||
True,
|
||||
)
|
||||
|
||||
@@ -419,15 +412,8 @@ class DSAAttentionBackend(AttentionBackend):
|
||||
k_pe,
|
||||
latent_cache,
|
||||
metadata.slot_mapping,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.batch_id_per_token,
|
||||
forward_meta.cu_seqlens_q,
|
||||
metadata.block_tables,
|
||||
None,
|
||||
scale.cast(paddle.float32),
|
||||
"fp8_ds_mla",
|
||||
self.max_seq_len,
|
||||
False,
|
||||
)
|
||||
|
||||
|
||||
@@ -143,15 +143,8 @@ class TestBasicPrefill(BaseDSMLAWriteCacheTest):
|
||||
tensors["kv_pe"],
|
||||
tensors["kv_cache"],
|
||||
tensors["slot_mapping"],
|
||||
tensors["seq_lens"],
|
||||
tensors["seq_lens_decoder"],
|
||||
tensors["batch_id_per_token"],
|
||||
tensors["cu_seqlens_q"],
|
||||
tensors["block_tables"],
|
||||
None, # kv_signal_data
|
||||
tensors["scale"],
|
||||
"fp8_ds_mla",
|
||||
tensors["max_seq_len"],
|
||||
True, # is_prefill
|
||||
)
|
||||
|
||||
@@ -173,15 +166,8 @@ class TestBasicDecode(BaseDSMLAWriteCacheTest):
|
||||
tensors["kv_pe"],
|
||||
tensors["kv_cache"],
|
||||
tensors["slot_mapping"],
|
||||
tensors["seq_lens"],
|
||||
tensors["seq_lens_decoder"],
|
||||
tensors["batch_id_per_token"],
|
||||
tensors["cu_seqlens_q"],
|
||||
tensors["block_tables"],
|
||||
None,
|
||||
tensors["scale"],
|
||||
"fp8_ds_mla",
|
||||
tensors["max_seq_len"],
|
||||
False, # is_prefill
|
||||
)
|
||||
|
||||
@@ -205,15 +191,8 @@ class TestSingleToken(BaseDSMLAWriteCacheTest):
|
||||
tensors["kv_pe"],
|
||||
tensors["kv_cache"],
|
||||
tensors["slot_mapping"],
|
||||
tensors["seq_lens"],
|
||||
tensors["seq_lens_decoder"],
|
||||
tensors["batch_id_per_token"],
|
||||
tensors["cu_seqlens_q"],
|
||||
tensors["block_tables"],
|
||||
None,
|
||||
tensors["scale"],
|
||||
"fp8_ds_mla",
|
||||
tensors["max_seq_len"],
|
||||
True,
|
||||
)
|
||||
|
||||
@@ -232,15 +211,8 @@ class TestLargeBatch(BaseDSMLAWriteCacheTest):
|
||||
tensors["kv_pe"],
|
||||
tensors["kv_cache"],
|
||||
tensors["slot_mapping"],
|
||||
tensors["seq_lens"],
|
||||
tensors["seq_lens_decoder"],
|
||||
tensors["batch_id_per_token"],
|
||||
tensors["cu_seqlens_q"],
|
||||
tensors["block_tables"],
|
||||
None,
|
||||
tensors["scale"],
|
||||
"fp8_ds_mla",
|
||||
tensors["max_seq_len"],
|
||||
True,
|
||||
)
|
||||
|
||||
@@ -261,15 +233,8 @@ class TestUnalignedTokens(BaseDSMLAWriteCacheTest):
|
||||
tensors["kv_pe"],
|
||||
tensors["kv_cache"],
|
||||
tensors["slot_mapping"],
|
||||
tensors["seq_lens"],
|
||||
tensors["seq_lens_decoder"],
|
||||
tensors["batch_id_per_token"],
|
||||
tensors["cu_seqlens_q"],
|
||||
tensors["block_tables"],
|
||||
None,
|
||||
tensors["scale"],
|
||||
"fp8_ds_mla",
|
||||
tensors["max_seq_len"],
|
||||
True,
|
||||
)
|
||||
|
||||
@@ -291,15 +256,8 @@ class TestQuantTypeFp8DsMla(BaseDSMLAWriteCacheTest):
|
||||
tensors["kv_pe"],
|
||||
tensors["kv_cache"],
|
||||
tensors["slot_mapping"],
|
||||
tensors["seq_lens"],
|
||||
tensors["seq_lens_decoder"],
|
||||
tensors["batch_id_per_token"],
|
||||
tensors["cu_seqlens_q"],
|
||||
tensors["block_tables"],
|
||||
None,
|
||||
tensors["scale"],
|
||||
"fp8_ds_mla", # 主要测试的量化类型
|
||||
tensors["max_seq_len"],
|
||||
True,
|
||||
)
|
||||
|
||||
@@ -321,15 +279,8 @@ class TestQuantTypeNone(BaseDSMLAWriteCacheTest):
|
||||
tensors["kv_pe"],
|
||||
tensors["kv_cache"],
|
||||
tensors["slot_mapping"],
|
||||
tensors["seq_lens"],
|
||||
tensors["seq_lens_decoder"],
|
||||
tensors["batch_id_per_token"],
|
||||
tensors["cu_seqlens_q"],
|
||||
tensors["block_tables"],
|
||||
None,
|
||||
None, # scale 在无量化时可为 None
|
||||
"none",
|
||||
tensors["max_seq_len"],
|
||||
True,
|
||||
)
|
||||
self.assertIsNotNone(result)
|
||||
@@ -353,15 +304,8 @@ class TestWithoutScale(BaseDSMLAWriteCacheTest):
|
||||
tensors["kv_pe"],
|
||||
tensors["kv_cache"],
|
||||
tensors["slot_mapping"],
|
||||
tensors["seq_lens"],
|
||||
tensors["seq_lens_decoder"],
|
||||
tensors["batch_id_per_token"],
|
||||
tensors["cu_seqlens_q"],
|
||||
tensors["block_tables"],
|
||||
None, # kv_signal_data
|
||||
None, # scale = None
|
||||
None,
|
||||
"fp8_ds_mla",
|
||||
tensors["max_seq_len"],
|
||||
True,
|
||||
)
|
||||
|
||||
@@ -380,15 +324,8 @@ class TestWithoutKvSignalData(BaseDSMLAWriteCacheTest):
|
||||
tensors["kv_pe"],
|
||||
tensors["kv_cache"],
|
||||
tensors["slot_mapping"],
|
||||
tensors["seq_lens"],
|
||||
tensors["seq_lens_decoder"],
|
||||
tensors["batch_id_per_token"],
|
||||
tensors["cu_seqlens_q"],
|
||||
tensors["block_tables"],
|
||||
None, # kv_signal_data = None
|
||||
tensors["scale"],
|
||||
"fp8_ds_mla",
|
||||
tensors["max_seq_len"],
|
||||
True,
|
||||
)
|
||||
|
||||
@@ -410,15 +347,8 @@ class TestBfloat16Input(BaseDSMLAWriteCacheTest):
|
||||
tensors["kv_pe"],
|
||||
tensors["kv_cache"],
|
||||
tensors["slot_mapping"],
|
||||
tensors["seq_lens"],
|
||||
tensors["seq_lens_decoder"],
|
||||
tensors["batch_id_per_token"],
|
||||
tensors["cu_seqlens_q"],
|
||||
tensors["block_tables"],
|
||||
None,
|
||||
tensors["scale"],
|
||||
"fp8_ds_mla",
|
||||
tensors["max_seq_len"],
|
||||
True,
|
||||
)
|
||||
|
||||
@@ -438,15 +368,8 @@ class TestFloat16Input(BaseDSMLAWriteCacheTest):
|
||||
tensors["kv_pe"],
|
||||
tensors["kv_cache"],
|
||||
tensors["slot_mapping"],
|
||||
tensors["seq_lens"],
|
||||
tensors["seq_lens_decoder"],
|
||||
tensors["batch_id_per_token"],
|
||||
tensors["cu_seqlens_q"],
|
||||
tensors["block_tables"],
|
||||
None,
|
||||
tensors["scale"],
|
||||
"fp8_ds_mla",
|
||||
tensors["max_seq_len"],
|
||||
True,
|
||||
)
|
||||
self.assertIsNotNone(result)
|
||||
@@ -471,15 +394,8 @@ class TestDSMLAWriteCachePerformance(BaseDSMLAWriteCacheTest):
|
||||
tensors["kv_pe"],
|
||||
tensors["kv_cache"],
|
||||
tensors["slot_mapping"],
|
||||
tensors["seq_lens"],
|
||||
tensors["seq_lens_decoder"],
|
||||
tensors["batch_id_per_token"],
|
||||
tensors["cu_seqlens_q"],
|
||||
tensors["block_tables"],
|
||||
None,
|
||||
tensors["scale"],
|
||||
"fp8_ds_mla",
|
||||
tensors["max_seq_len"],
|
||||
True,
|
||||
)
|
||||
|
||||
@@ -495,15 +411,8 @@ class TestDSMLAWriteCachePerformance(BaseDSMLAWriteCacheTest):
|
||||
tensors["kv_pe"],
|
||||
tensors["kv_cache"],
|
||||
tensors["slot_mapping"],
|
||||
tensors["seq_lens"],
|
||||
tensors["seq_lens_decoder"],
|
||||
tensors["batch_id_per_token"],
|
||||
tensors["cu_seqlens_q"],
|
||||
tensors["block_tables"],
|
||||
None,
|
||||
tensors["scale"],
|
||||
"fp8_ds_mla",
|
||||
tensors["max_seq_len"],
|
||||
True,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user