DSA clean code (#6827)

This commit is contained in:
周周周
2026-03-13 16:39:47 +08:00
committed by GitHub
parent 49fe68a518
commit 8c1a2827d3
4 changed files with 2 additions and 249 deletions
@@ -46,13 +46,6 @@ std::vector<paddle::Tensor> PrefillDSMLAWriteCacheFP8(
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& slot_mapping,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
typedef PDTraits<T> traits_;
@@ -90,34 +83,6 @@ std::vector<paddle::Tensor> PrefillDSMLAWriteCacheFP8(
kv_lora_rank,
pe_dim,
block_size);
// Handle PD disaggregation signal
const char* fmt_write_cache_completed_signal_str =
std::getenv("FLAGS_fmt_write_cache_completed_signal");
const char* FLAGS_use_pd_disaggregation_per_chunk =
std::getenv("FLAGS_use_pd_disaggregation_per_chunk");
if (fmt_write_cache_completed_signal_str &&
(std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 ||
std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) {
if (FLAGS_use_pd_disaggregation_per_chunk &&
(std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 ||
std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) {
cudaLaunchHostFunc(
stream,
&(RemoteCacheKvIpc::
save_cache_kv_complete_signal_layerwise_per_query),
(void*)nullptr);
} else {
if (kv_signal_data) {
cudaLaunchHostFunc(
stream,
&RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise,
(void*)(const_cast<int64_t*>(
kv_signal_data.get().data<int64_t>())));
}
}
}
return {};
}
@@ -130,13 +95,6 @@ std::vector<paddle::Tensor> DecodeDSMLAWriteCacheFP8(
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& slot_mapping,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const int max_seq_len,
const bool speculate_decoder,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
typedef PDTraits<T> traits_;
@@ -187,14 +145,7 @@ std::vector<paddle::Tensor> PrefillDSMLAWriteCache(
const paddle::Tensor& kv_nope,
const paddle::Tensor& kv_pe,
const paddle::Tensor& slot_mapping,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const float* scale,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* kv_cache) {
typedef PDTraits<T> traits_;
@@ -232,33 +183,6 @@ std::vector<paddle::Tensor> PrefillDSMLAWriteCache(
block_size,
scale);
// Handle PD disaggregation signal
const char* fmt_write_cache_completed_signal_str =
std::getenv("FLAGS_fmt_write_cache_completed_signal");
const char* FLAGS_use_pd_disaggregation_per_chunk =
std::getenv("FLAGS_use_pd_disaggregation_per_chunk");
if (fmt_write_cache_completed_signal_str &&
(std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 ||
std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) {
if (FLAGS_use_pd_disaggregation_per_chunk &&
(std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 ||
std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) {
cudaLaunchHostFunc(
stream,
&(RemoteCacheKvIpc::
save_cache_kv_complete_signal_layerwise_per_query),
(void*)nullptr);
} else {
if (kv_signal_data) {
cudaLaunchHostFunc(
stream,
&RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise,
(void*)(const_cast<int64_t*>(
kv_signal_data.get().data<int64_t>())));
}
}
}
return {};
}
@@ -372,15 +296,8 @@ std::vector<paddle::Tensor> DSMLAWriteCacheKernel(
const paddle::Tensor& kv_pe,
const paddle::Tensor& kv_cache,
const paddle::Tensor& slot_mapping,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& scale,
const std::string& cache_quant_type_str,
const int max_seq_len,
const bool is_prefill) {
cudaStream_t stream = kv_pe.stream();
AppendAttnMetaData meta_data;
@@ -395,9 +312,7 @@ std::vector<paddle::Tensor> DSMLAWriteCacheKernel(
meta_data.token_nums = kv_nope_dims[0];
meta_data.head_dims = kv_cache_dims[3];
meta_data.head_dims_v = nope_size;
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = kv_cache_dims[2];
meta_data.batch_size = seq_lens_decoder.dims()[0];
const float* scale_ptr = scale ? scale.get().data<float>() : nullptr;
@@ -411,13 +326,6 @@ std::vector<paddle::Tensor> DSMLAWriteCacheKernel(
kv_nope,
kv_pe,
slot_mapping,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
kv_signal_data,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
} else {
@@ -426,13 +334,6 @@ std::vector<paddle::Tensor> DSMLAWriteCacheKernel(
kv_nope,
kv_pe,
slot_mapping,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
false,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
@@ -444,13 +345,6 @@ std::vector<paddle::Tensor> DSMLAWriteCacheKernel(
kv_nope,
kv_pe,
slot_mapping,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
kv_signal_data,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
} else {
@@ -459,13 +353,6 @@ std::vector<paddle::Tensor> DSMLAWriteCacheKernel(
kv_nope,
kv_pe,
slot_mapping,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
false,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
@@ -482,14 +369,7 @@ std::vector<paddle::Tensor> DSMLAWriteCacheKernel(
kv_nope,
kv_pe,
slot_mapping,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
kv_signal_data,
scale_ptr,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
@@ -499,14 +379,7 @@ std::vector<paddle::Tensor> DSMLAWriteCacheKernel(
kv_nope,
kv_pe,
slot_mapping,
seq_lens,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
kv_signal_data,
scale_ptr,
max_seq_len,
stream,
const_cast<paddle::Tensor*>(&kv_cache));
}
@@ -588,18 +461,10 @@ PD_BUILD_STATIC_OP(ds_mla_write_cache)
"kv_pe",
"kv_cache",
"slot_mapping",
"seq_lens",
"seq_lens_decoder",
"batch_id_per_token",
"cu_seqlens_q",
"block_tables",
paddle::Optional("kv_signal_data"),
paddle::Optional("scale")})
.Outputs({"kv_cache_out"})
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
.Attrs({"cache_quant_type_str: std::string",
"max_seq_len: int",
"is_prefill: bool"})
.Attrs({"cache_quant_type_str: std::string", "is_prefill: bool"})
.SetKernelFn(PD_KERNEL(DSMLAWriteCacheKernel));
PD_BUILD_STATIC_OP(indexer_k_quant_and_cache)
-7
View File
@@ -1215,15 +1215,8 @@ std::vector<paddle::Tensor> DSMLAWriteCacheKernel(
const paddle::Tensor& kv_pe,
const paddle::Tensor& kv_cache,
const paddle::Tensor& slot_mapping,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& scale,
const std::string& cache_quant_type_str,
const int max_seq_len,
const bool is_prefill);
std::vector<paddle::Tensor> IndexerKQuantAndCacheKernel(
@@ -380,15 +380,8 @@ class DSAAttentionBackend(AttentionBackend):
k_pe,
latent_cache,
metadata.slot_mapping,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
None,
scale.cast(paddle.float32),
"fp8_ds_mla",
self.max_seq_len,
True,
)
@@ -419,15 +412,8 @@ class DSAAttentionBackend(AttentionBackend):
k_pe,
latent_cache,
metadata.slot_mapping,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_decoder,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
None,
scale.cast(paddle.float32),
"fp8_ds_mla",
self.max_seq_len,
False,
)
+1 -92
View File
@@ -143,15 +143,8 @@ class TestBasicPrefill(BaseDSMLAWriteCacheTest):
tensors["kv_pe"],
tensors["kv_cache"],
tensors["slot_mapping"],
tensors["seq_lens"],
tensors["seq_lens_decoder"],
tensors["batch_id_per_token"],
tensors["cu_seqlens_q"],
tensors["block_tables"],
None, # kv_signal_data
tensors["scale"],
"fp8_ds_mla",
tensors["max_seq_len"],
True, # is_prefill
)
@@ -173,15 +166,8 @@ class TestBasicDecode(BaseDSMLAWriteCacheTest):
tensors["kv_pe"],
tensors["kv_cache"],
tensors["slot_mapping"],
tensors["seq_lens"],
tensors["seq_lens_decoder"],
tensors["batch_id_per_token"],
tensors["cu_seqlens_q"],
tensors["block_tables"],
None,
tensors["scale"],
"fp8_ds_mla",
tensors["max_seq_len"],
False, # is_prefill
)
@@ -205,15 +191,8 @@ class TestSingleToken(BaseDSMLAWriteCacheTest):
tensors["kv_pe"],
tensors["kv_cache"],
tensors["slot_mapping"],
tensors["seq_lens"],
tensors["seq_lens_decoder"],
tensors["batch_id_per_token"],
tensors["cu_seqlens_q"],
tensors["block_tables"],
None,
tensors["scale"],
"fp8_ds_mla",
tensors["max_seq_len"],
True,
)
@@ -232,15 +211,8 @@ class TestLargeBatch(BaseDSMLAWriteCacheTest):
tensors["kv_pe"],
tensors["kv_cache"],
tensors["slot_mapping"],
tensors["seq_lens"],
tensors["seq_lens_decoder"],
tensors["batch_id_per_token"],
tensors["cu_seqlens_q"],
tensors["block_tables"],
None,
tensors["scale"],
"fp8_ds_mla",
tensors["max_seq_len"],
True,
)
@@ -261,15 +233,8 @@ class TestUnalignedTokens(BaseDSMLAWriteCacheTest):
tensors["kv_pe"],
tensors["kv_cache"],
tensors["slot_mapping"],
tensors["seq_lens"],
tensors["seq_lens_decoder"],
tensors["batch_id_per_token"],
tensors["cu_seqlens_q"],
tensors["block_tables"],
None,
tensors["scale"],
"fp8_ds_mla",
tensors["max_seq_len"],
True,
)
@@ -291,15 +256,8 @@ class TestQuantTypeFp8DsMla(BaseDSMLAWriteCacheTest):
tensors["kv_pe"],
tensors["kv_cache"],
tensors["slot_mapping"],
tensors["seq_lens"],
tensors["seq_lens_decoder"],
tensors["batch_id_per_token"],
tensors["cu_seqlens_q"],
tensors["block_tables"],
None,
tensors["scale"],
"fp8_ds_mla", # the quantization type primarily under test
tensors["max_seq_len"],
True,
)
@@ -321,15 +279,8 @@ class TestQuantTypeNone(BaseDSMLAWriteCacheTest):
tensors["kv_pe"],
tensors["kv_cache"],
tensors["slot_mapping"],
tensors["seq_lens"],
tensors["seq_lens_decoder"],
tensors["batch_id_per_token"],
tensors["cu_seqlens_q"],
tensors["block_tables"],
None,
None, # scale may be None when no quantization is applied
"none",
tensors["max_seq_len"],
True,
)
self.assertIsNotNone(result)
@@ -353,15 +304,8 @@ class TestWithoutScale(BaseDSMLAWriteCacheTest):
tensors["kv_pe"],
tensors["kv_cache"],
tensors["slot_mapping"],
tensors["seq_lens"],
tensors["seq_lens_decoder"],
tensors["batch_id_per_token"],
tensors["cu_seqlens_q"],
tensors["block_tables"],
None, # kv_signal_data
None, # scale = None
None,
"fp8_ds_mla",
tensors["max_seq_len"],
True,
)
@@ -380,15 +324,8 @@ class TestWithoutKvSignalData(BaseDSMLAWriteCacheTest):
tensors["kv_pe"],
tensors["kv_cache"],
tensors["slot_mapping"],
tensors["seq_lens"],
tensors["seq_lens_decoder"],
tensors["batch_id_per_token"],
tensors["cu_seqlens_q"],
tensors["block_tables"],
None, # kv_signal_data = None
tensors["scale"],
"fp8_ds_mla",
tensors["max_seq_len"],
True,
)
@@ -410,15 +347,8 @@ class TestBfloat16Input(BaseDSMLAWriteCacheTest):
tensors["kv_pe"],
tensors["kv_cache"],
tensors["slot_mapping"],
tensors["seq_lens"],
tensors["seq_lens_decoder"],
tensors["batch_id_per_token"],
tensors["cu_seqlens_q"],
tensors["block_tables"],
None,
tensors["scale"],
"fp8_ds_mla",
tensors["max_seq_len"],
True,
)
@@ -438,15 +368,8 @@ class TestFloat16Input(BaseDSMLAWriteCacheTest):
tensors["kv_pe"],
tensors["kv_cache"],
tensors["slot_mapping"],
tensors["seq_lens"],
tensors["seq_lens_decoder"],
tensors["batch_id_per_token"],
tensors["cu_seqlens_q"],
tensors["block_tables"],
None,
tensors["scale"],
"fp8_ds_mla",
tensors["max_seq_len"],
True,
)
self.assertIsNotNone(result)
@@ -471,15 +394,8 @@ class TestDSMLAWriteCachePerformance(BaseDSMLAWriteCacheTest):
tensors["kv_pe"],
tensors["kv_cache"],
tensors["slot_mapping"],
tensors["seq_lens"],
tensors["seq_lens_decoder"],
tensors["batch_id_per_token"],
tensors["cu_seqlens_q"],
tensors["block_tables"],
None,
tensors["scale"],
"fp8_ds_mla",
tensors["max_seq_len"],
True,
)
@@ -495,15 +411,8 @@ class TestDSMLAWriteCachePerformance(BaseDSMLAWriteCacheTest):
tensors["kv_pe"],
tensors["kv_cache"],
tensors["slot_mapping"],
tensors["seq_lens"],
tensors["seq_lens_decoder"],
tensors["batch_id_per_token"],
tensors["cu_seqlens_q"],
tensors["block_tables"],
None,
tensors["scale"],
"fp8_ds_mla",
tensors["max_seq_len"],
True,
)