diff --git a/custom_ops/gpu_ops/append_attn/ds_mla_cache_kernel.cu b/custom_ops/gpu_ops/append_attn/ds_mla_cache_kernel.cu index 1e5c6ab731..6e1fbf6846 100644 --- a/custom_ops/gpu_ops/append_attn/ds_mla_cache_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/ds_mla_cache_kernel.cu @@ -46,13 +46,6 @@ std::vector PrefillDSMLAWriteCacheFP8( const paddle::Tensor& kv_nope, const paddle::Tensor& kv_pe, const paddle::Tensor& slot_mapping, - const paddle::Tensor& seq_lens, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_tables, - const paddle::optional& kv_signal_data, - const int max_seq_len, cudaStream_t& stream, paddle::Tensor* kv_cache) { typedef PDTraits traits_; @@ -90,34 +83,6 @@ std::vector PrefillDSMLAWriteCacheFP8( kv_lora_rank, pe_dim, block_size); - - // Handle PD disaggregation signal - const char* fmt_write_cache_completed_signal_str = - std::getenv("FLAGS_fmt_write_cache_completed_signal"); - const char* FLAGS_use_pd_disaggregation_per_chunk = - std::getenv("FLAGS_use_pd_disaggregation_per_chunk"); - - if (fmt_write_cache_completed_signal_str && - (std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 || - std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) { - if (FLAGS_use_pd_disaggregation_per_chunk && - (std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 || - std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) { - cudaLaunchHostFunc( - stream, - &(RemoteCacheKvIpc:: - save_cache_kv_complete_signal_layerwise_per_query), - (void*)nullptr); - } else { - if (kv_signal_data) { - cudaLaunchHostFunc( - stream, - &RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise, - (void*)(const_cast( - kv_signal_data.get().data()))); - } - } - } return {}; } @@ -130,13 +95,6 @@ std::vector DecodeDSMLAWriteCacheFP8( const paddle::Tensor& kv_nope, const paddle::Tensor& kv_pe, const paddle::Tensor& slot_mapping, - const paddle::Tensor& seq_lens, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_tables, - const int max_seq_len, - const bool speculate_decoder, cudaStream_t& stream, paddle::Tensor* kv_cache) { typedef PDTraits traits_; @@ -187,14 +145,7 @@ std::vector PrefillDSMLAWriteCache( const paddle::Tensor& kv_nope, const paddle::Tensor& kv_pe, const paddle::Tensor& slot_mapping, - const paddle::Tensor& seq_lens, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_tables, - const paddle::optional& kv_signal_data, const float* scale, - const int max_seq_len, cudaStream_t& stream, paddle::Tensor* kv_cache) { typedef PDTraits traits_; @@ -232,33 +183,6 @@ std::vector PrefillDSMLAWriteCache( block_size, scale); - // Handle PD disaggregation signal - const char* fmt_write_cache_completed_signal_str = - std::getenv("FLAGS_fmt_write_cache_completed_signal"); - const char* FLAGS_use_pd_disaggregation_per_chunk = - std::getenv("FLAGS_use_pd_disaggregation_per_chunk"); - - if (fmt_write_cache_completed_signal_str && - (std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 || - std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) { - if (FLAGS_use_pd_disaggregation_per_chunk && - (std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 || - std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) { - cudaLaunchHostFunc( - stream, - &(RemoteCacheKvIpc:: - save_cache_kv_complete_signal_layerwise_per_query), - (void*)nullptr); - } else { - if (kv_signal_data) { - cudaLaunchHostFunc( - stream, - &RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise, - (void*)(const_cast( - kv_signal_data.get().data()))); - } - } - } return {}; } @@ -372,15 +296,8 @@ std::vector DSMLAWriteCacheKernel( const paddle::Tensor& kv_pe, const paddle::Tensor& kv_cache, const paddle::Tensor& slot_mapping, - const paddle::Tensor& seq_lens, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_tables, - const paddle::optional& kv_signal_data, const paddle::optional& scale, const std::string& cache_quant_type_str, - const int max_seq_len, const bool is_prefill) { cudaStream_t stream = kv_pe.stream(); AppendAttnMetaData meta_data; @@ -395,9 +312,7 @@ std::vector DSMLAWriteCacheKernel( meta_data.token_nums = kv_nope_dims[0]; meta_data.head_dims = kv_cache_dims[3]; meta_data.head_dims_v = nope_size; - meta_data.max_blocks_per_seq = block_tables.dims()[1]; meta_data.block_size = kv_cache_dims[2]; - meta_data.batch_size = seq_lens_decoder.dims()[0]; const float* scale_ptr = scale ? scale.get().data() : nullptr; @@ -411,13 +326,6 @@ std::vector DSMLAWriteCacheKernel( kv_nope, kv_pe, slot_mapping, - seq_lens, - seq_lens_decoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - kv_signal_data, - max_seq_len, stream, const_cast(&kv_cache)); } else { @@ -426,13 +334,6 @@ std::vector DSMLAWriteCacheKernel( kv_nope, kv_pe, slot_mapping, - seq_lens, - seq_lens_decoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - max_seq_len, - false, stream, const_cast(&kv_cache)); } @@ -444,13 +345,6 @@ std::vector DSMLAWriteCacheKernel( kv_nope, kv_pe, slot_mapping, - seq_lens, - seq_lens_decoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - kv_signal_data, - max_seq_len, stream, const_cast(&kv_cache)); } else { @@ -459,13 +353,6 @@ std::vector DSMLAWriteCacheKernel( kv_nope, kv_pe, slot_mapping, - seq_lens, - seq_lens_decoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - max_seq_len, - false, stream, const_cast(&kv_cache)); } @@ -482,14 +369,7 @@ std::vector DSMLAWriteCacheKernel( kv_nope, kv_pe, slot_mapping, - seq_lens, - seq_lens_decoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - kv_signal_data, scale_ptr, - max_seq_len, stream, const_cast(&kv_cache)); } @@ -499,14 +379,7 @@ std::vector DSMLAWriteCacheKernel( kv_nope, kv_pe, slot_mapping, - seq_lens, - seq_lens_decoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - kv_signal_data, scale_ptr, - max_seq_len, stream, const_cast(&kv_cache)); } @@ -588,18 +461,10 @@ PD_BUILD_STATIC_OP(ds_mla_write_cache) "kv_pe", "kv_cache", "slot_mapping", - "seq_lens", - "seq_lens_decoder", - "batch_id_per_token", - "cu_seqlens_q", - "block_tables", - paddle::Optional("kv_signal_data"), paddle::Optional("scale")}) .Outputs({"kv_cache_out"}) .SetInplaceMap({{"kv_cache", "kv_cache_out"}}) - .Attrs({"cache_quant_type_str: std::string", - "max_seq_len: int", - "is_prefill: bool"}) + .Attrs({"cache_quant_type_str: std::string", "is_prefill: bool"}) .SetKernelFn(PD_KERNEL(DSMLAWriteCacheKernel)); PD_BUILD_STATIC_OP(indexer_k_quant_and_cache) diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index cfac8abfd2..f6345b6afd 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -1215,15 +1215,8 @@ std::vector DSMLAWriteCacheKernel( const paddle::Tensor& kv_pe, const paddle::Tensor& kv_cache, const paddle::Tensor& slot_mapping, - const paddle::Tensor& seq_lens, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& block_tables, - const paddle::optional& kv_signal_data, const paddle::optional& scale, const std::string& cache_quant_type_str, - const int max_seq_len, const bool is_prefill); std::vector IndexerKQuantAndCacheKernel( diff --git a/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py b/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py index b38dcbab0a..1e82f4f61e 100644 --- a/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py @@ -380,15 +380,8 @@ class DSAAttentionBackend(AttentionBackend): k_pe, latent_cache, metadata.slot_mapping, - forward_meta.seq_lens_encoder, - forward_meta.seq_lens_decoder, - forward_meta.batch_id_per_token, - forward_meta.cu_seqlens_q, - metadata.block_tables, - None, scale.cast(paddle.float32), "fp8_ds_mla", - self.max_seq_len, True, ) @@ -419,15 +412,8 @@ class DSAAttentionBackend(AttentionBackend): k_pe, latent_cache, metadata.slot_mapping, - forward_meta.seq_lens_decoder, - forward_meta.seq_lens_decoder, - forward_meta.batch_id_per_token, - forward_meta.cu_seqlens_q, - metadata.block_tables, - None, scale.cast(paddle.float32), "fp8_ds_mla", - self.max_seq_len, False, ) diff --git a/tests/operators/test_dsmla_writecache.py b/tests/operators/test_dsmla_writecache.py index c75c6221d3..20932ee4d3 100644 --- a/tests/operators/test_dsmla_writecache.py +++ b/tests/operators/test_dsmla_writecache.py @@ -143,15 +143,8 @@ class TestBasicPrefill(BaseDSMLAWriteCacheTest): tensors["kv_pe"], tensors["kv_cache"], tensors["slot_mapping"], - tensors["seq_lens"], - tensors["seq_lens_decoder"], - tensors["batch_id_per_token"], - tensors["cu_seqlens_q"], - tensors["block_tables"], - None, # kv_signal_data tensors["scale"], "fp8_ds_mla", - tensors["max_seq_len"], True, # is_prefill ) @@ -173,15 +166,8 @@ class TestBasicDecode(BaseDSMLAWriteCacheTest): tensors["kv_pe"], tensors["kv_cache"], tensors["slot_mapping"], - tensors["seq_lens"], - tensors["seq_lens_decoder"], - tensors["batch_id_per_token"], - tensors["cu_seqlens_q"], - tensors["block_tables"], - None, tensors["scale"], "fp8_ds_mla", - tensors["max_seq_len"], False, # is_prefill ) @@ -205,15 +191,8 @@ class TestSingleToken(BaseDSMLAWriteCacheTest): tensors["kv_pe"], tensors["kv_cache"], tensors["slot_mapping"], - tensors["seq_lens"], - tensors["seq_lens_decoder"], - tensors["batch_id_per_token"], - tensors["cu_seqlens_q"], - tensors["block_tables"], - None, tensors["scale"], "fp8_ds_mla", - tensors["max_seq_len"], True, ) @@ -232,15 +211,8 @@ class TestLargeBatch(BaseDSMLAWriteCacheTest): tensors["kv_pe"], tensors["kv_cache"], tensors["slot_mapping"], - tensors["seq_lens"], - tensors["seq_lens_decoder"], - tensors["batch_id_per_token"], - tensors["cu_seqlens_q"], - tensors["block_tables"], - None, tensors["scale"], "fp8_ds_mla", - tensors["max_seq_len"], True, ) @@ -261,15 +233,8 @@ class TestUnalignedTokens(BaseDSMLAWriteCacheTest): tensors["kv_pe"], tensors["kv_cache"], tensors["slot_mapping"], - tensors["seq_lens"], - tensors["seq_lens_decoder"], - tensors["batch_id_per_token"], - tensors["cu_seqlens_q"], - tensors["block_tables"], - None, tensors["scale"], "fp8_ds_mla", - tensors["max_seq_len"], True, ) @@ -291,15 +256,8 @@ class TestQuantTypeFp8DsMla(BaseDSMLAWriteCacheTest): tensors["kv_pe"], tensors["kv_cache"], tensors["slot_mapping"], - tensors["seq_lens"], - tensors["seq_lens_decoder"], - tensors["batch_id_per_token"], - tensors["cu_seqlens_q"], - tensors["block_tables"], - None, tensors["scale"], "fp8_ds_mla", # 主要测试的量化类型 - tensors["max_seq_len"], True, ) @@ -321,15 +279,8 @@ class TestQuantTypeNone(BaseDSMLAWriteCacheTest): tensors["kv_pe"], tensors["kv_cache"], tensors["slot_mapping"], - tensors["seq_lens"], - tensors["seq_lens_decoder"], - tensors["batch_id_per_token"], - tensors["cu_seqlens_q"], - tensors["block_tables"], - None, None, # scale 在无量化时可为 None "none", - tensors["max_seq_len"], True, ) self.assertIsNotNone(result) @@ -353,15 +304,8 @@ class TestWithoutScale(BaseDSMLAWriteCacheTest): tensors["kv_pe"], tensors["kv_cache"], tensors["slot_mapping"], - tensors["seq_lens"], - tensors["seq_lens_decoder"], - tensors["batch_id_per_token"], - tensors["cu_seqlens_q"], - tensors["block_tables"], - None, # kv_signal_data - None, # scale = None + None, "fp8_ds_mla", - tensors["max_seq_len"], True, ) @@ -380,15 +324,8 @@ class TestWithoutKvSignalData(BaseDSMLAWriteCacheTest): tensors["kv_pe"], tensors["kv_cache"], tensors["slot_mapping"], - tensors["seq_lens"], - tensors["seq_lens_decoder"], - tensors["batch_id_per_token"], - tensors["cu_seqlens_q"], - tensors["block_tables"], - None, # kv_signal_data = None tensors["scale"], "fp8_ds_mla", - tensors["max_seq_len"], True, ) @@ -410,15 +347,8 @@ class TestBfloat16Input(BaseDSMLAWriteCacheTest): tensors["kv_pe"], tensors["kv_cache"], tensors["slot_mapping"], - tensors["seq_lens"], - tensors["seq_lens_decoder"], - tensors["batch_id_per_token"], - tensors["cu_seqlens_q"], - tensors["block_tables"], - None, tensors["scale"], "fp8_ds_mla", - tensors["max_seq_len"], True, ) @@ -438,15 +368,8 @@ class TestFloat16Input(BaseDSMLAWriteCacheTest): tensors["kv_pe"], tensors["kv_cache"], tensors["slot_mapping"], - tensors["seq_lens"], - tensors["seq_lens_decoder"], - tensors["batch_id_per_token"], - tensors["cu_seqlens_q"], - tensors["block_tables"], - None, tensors["scale"], "fp8_ds_mla", - tensors["max_seq_len"], True, ) self.assertIsNotNone(result) @@ -471,15 +394,8 @@ class TestDSMLAWriteCachePerformance(BaseDSMLAWriteCacheTest): tensors["kv_pe"], tensors["kv_cache"], tensors["slot_mapping"], - tensors["seq_lens"], - tensors["seq_lens_decoder"], - tensors["batch_id_per_token"], - tensors["cu_seqlens_q"], - tensors["block_tables"], - None, tensors["scale"], "fp8_ds_mla", - tensors["max_seq_len"], True, ) @@ -495,15 +411,8 @@ class TestDSMLAWriteCachePerformance(BaseDSMLAWriteCacheTest): tensors["kv_pe"], tensors["kv_cache"], tensors["slot_mapping"], - tensors["seq_lens"], - tensors["seq_lens_decoder"], - tensors["batch_id_per_token"], - tensors["cu_seqlens_q"], - tensors["block_tables"], - None, tensors["scale"], "fp8_ds_mla", - tensors["max_seq_len"], True, )