diff --git a/custom_ops/xpu_ops/src/ops/block_attn.cc b/custom_ops/xpu_ops/src/ops/block_attn.cc index 51e4b5df84..86c8693302 100644 --- a/custom_ops/xpu_ops/src/ops/block_attn.cc +++ b/custom_ops/xpu_ops/src/ops/block_attn.cc @@ -79,6 +79,14 @@ std::vector BlockAttnKernel( const paddle::Tensor& decoder_context_len_cache_cpu, const paddle::Tensor& decoder_batch_map_cpu, const paddle::Tensor& prefix_len_cpu, + const paddle::Tensor& encoder_seq_lod, + const paddle::Tensor& decoder_seq_lod, + const paddle::Tensor& encoder_kv_lod, + const paddle::Tensor& encoder_batch_map, + const paddle::Tensor& decoder_context_len, + const paddle::Tensor& decoder_context_len_cache, + const paddle::Tensor& decoder_batch_map, + const paddle::Tensor& prefix_len, const paddle::optional& k_scales, const paddle::optional& v_scales, const paddle::optional& k_scales_inv, @@ -225,14 +233,14 @@ std::vector BlockAttnKernel( vsl.usual_lod_vp = { const_cast(encoder_seq_lod_cpu.data()), enc_batch + 1, - nullptr}; + const_cast(encoder_seq_lod.data())}; vsl.kv_lod_vp = {const_cast(encoder_kv_lod_cpu.data()), enc_batch + 1, - nullptr}; + const_cast(encoder_kv_lod.data())}; vsl.slot_mapping_vp = { const_cast(encoder_batch_map_cpu.data()), enc_batch, - nullptr}; // real batch + const_cast(encoder_batch_map.data())}; // real batch param.max_valid_seqlen = max_enc_len; param.max_kv_valid_seqlen = max_kv_len; // setting for prefix cache @@ -252,7 +260,7 @@ std::vector BlockAttnKernel( baidu::xpu::api::VectorParam prefix_lens_vp{ const_cast(prefix_len_cpu.data()), enc_batch, - nullptr}; + const_cast(prefix_len.data())}; float* fake_perhead_scale = nullptr; if (is_cache_int8 && has_zp && is_prefix_cache) { @@ -550,20 +558,28 @@ std::vector BlockAttnKernel( api::VectorParam decoder_context_len_vp = { const_cast(decoder_context_len_cpu.data()), dec_batch, - nullptr}; // use for speculative_attention_decoder seq_len in - // MTP + const_cast( + decoder_context_len + .data())}; // use for speculative_attention_decoder + // seq_len in MTP api::VectorParam decoder_context_len_cache_vp = { const_cast(decoder_context_len_cache_cpu.data()), dec_batch, - nullptr}; // use for split rope enc as prefix cache len in MTP + const_cast( + decoder_context_len_cache + .data())}; // use for split rope enc as prefix cache + // len in MTP api::VectorParam decoder_batch_map_vp = { const_cast(decoder_batch_map_cpu.data()), dec_batch, - nullptr}; // real batch + const_cast( + decoder_batch_map.data())}; // real batch api::VectorParam decoder_seq_lod_vp = { const_cast(decoder_seq_lod_cpu.data()), dec_batch + 1, - nullptr}; // use for split rope enc as lod in MTP + const_cast( + decoder_seq_lod + .data())}; // use for split rope enc as lod in MTP // rope + cache int ret = 0; @@ -771,11 +787,12 @@ std::vector BlockAttnKernel( vsl.usual_lod_vp = { const_cast(decoder_context_len_cpu.data()), dec_batch, - nullptr}; + const_cast(decoder_context_len.data())}; vsl.slot_mapping_vp = { const_cast(decoder_batch_map_cpu.data()), dec_batch, - nullptr}; // real batch + const_cast( + decoder_batch_map.data())}; // real batch xftblock::Tensor q_buf( rt_guard, KV_BUF_TYPE, {total_dec_len, hidden_dim}, false, false); @@ -1013,6 +1030,14 @@ std::vector BlockAttn( const paddle::Tensor& decoder_context_len_cache_cpu, const paddle::Tensor& decoder_batch_map_cpu, const paddle::Tensor& prefix_len_cpu, + const paddle::Tensor& encoder_seq_lod, + const paddle::Tensor& decoder_seq_lod, + const paddle::Tensor& encoder_kv_lod, + const paddle::Tensor& encoder_batch_map, + const paddle::Tensor& decoder_context_len, + const paddle::Tensor& decoder_context_len_cache, + const paddle::Tensor& decoder_batch_map, + const paddle::Tensor& prefix_len, const paddle::optional& k_scales, const paddle::optional& v_scales, const paddle::optional& k_scales_inv, @@ -1044,6 +1069,14 @@ std::vector BlockAttn( decoder_context_len_cache_cpu, \ decoder_batch_map_cpu, \ prefix_len_cpu, \ + encoder_seq_lod, \ + decoder_seq_lod, \ + encoder_kv_lod, \ + encoder_batch_map, \ + decoder_context_len, \ + decoder_context_len_cache, \ + decoder_batch_map, \ + prefix_len, \ k_scales, \ v_scales, \ k_scales_inv, \ @@ -1112,6 +1145,14 @@ PD_BUILD_STATIC_OP(block_attn) "decoder_context_len_cache_cpu", "decoder_batch_map_cpu", "prefix_len_cpu", + "encoder_seq_lod", + "decoder_seq_lod", + "encoder_kv_lod", + "encoder_batch_map", + "decoder_context_len", + "decoder_context_len_cache", + "decoder_batch_map", + "prefix_len", paddle::Optional("k_scales"), paddle::Optional("v_scales"), paddle::Optional("k_scales_inv"), diff --git a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc index c9d3641752..10b5e0bd11 100644 --- a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc +++ b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc @@ -75,6 +75,14 @@ std::vector BlockAttn( const paddle::Tensor& decoder_context_len_cache_cpu, const paddle::Tensor& decoder_batch_map_cpu, const paddle::Tensor& prefix_len_cpu, + const paddle::Tensor& encoder_seq_lod_xpu, + const paddle::Tensor& decoder_seq_lod_xpu, + const paddle::Tensor& encoder_kv_lod_xpu, + const paddle::Tensor& encoder_batch_map_xpu, + const paddle::Tensor& decoder_context_len_xpu, + const paddle::Tensor& decoder_context_len_cache_xpu, + const paddle::Tensor& decoder_batch_map_xpu, + const paddle::Tensor& prefix_len_xpu, const paddle::optional& k_scales, const paddle::optional& v_scales, const paddle::optional& k_scales_inv, @@ -634,6 +642,14 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("decoder_context_len_cache_cpu"), py::arg("decoder_batch_map_cpu"), py::arg("prefix_len_cpu"), + py::arg("encoder_seq_lod_xpu"), + py::arg("decoder_seq_lod_xpu"), + py::arg("encoder_kv_lod_xpu"), + py::arg("encoder_batch_map_xpu"), + py::arg("decoder_context_len_xpu"), + py::arg("decoder_context_len_cache_xpu"), + py::arg("decoder_batch_map_xpu"), + py::arg("prefix_len_xpu"), py::arg("k_scales"), py::arg("v_scales"), py::arg("k_scales_inv"), diff --git a/fastdeploy/model_executor/layers/backends/xpu/attention.py b/fastdeploy/model_executor/layers/backends/xpu/attention.py index 8b9481c9e9..cf9c0be90d 100644 --- a/fastdeploy/model_executor/layers/backends/xpu/attention.py +++ b/fastdeploy/model_executor/layers/backends/xpu/attention.py @@ -203,6 +203,14 @@ class XPUAttentionBackend(AttentionBackend): forward_meta.decoder_context_len_cache_cpu, forward_meta.decoder_batch_map_cpu, forward_meta.prefix_len_cpu, + forward_meta.encoder_seq_lod, + forward_meta.decoder_seq_lod, + forward_meta.encoder_kv_lod, + forward_meta.encoder_batch_map, + forward_meta.decoder_context_len, + forward_meta.decoder_context_len_cache, + forward_meta.decoder_batch_map, + forward_meta.prefix_len, cache_k_scale, cache_v_scale, cache_k_out_scale,