[XPU] Support CudaGraph(add block attn cuda_graph support) (#6116)

* add block attn cuda_graph support
This commit is contained in:
yinwei
2026-01-20 19:33:11 +08:00
committed by GitHub
parent 00a6a73431
commit 51a8a2ed57
3 changed files with 76 additions and 11 deletions
+52 -11
View File
@@ -79,6 +79,14 @@ std::vector&lt;paddle::Tensor&gt; BlockAttnKernel(
     const paddle::Tensor& decoder_context_len_cache_cpu,
     const paddle::Tensor& decoder_batch_map_cpu,
     const paddle::Tensor& prefix_len_cpu,
+    const paddle::Tensor& encoder_seq_lod,
+    const paddle::Tensor& decoder_seq_lod,
+    const paddle::Tensor& encoder_kv_lod,
+    const paddle::Tensor& encoder_batch_map,
+    const paddle::Tensor& decoder_context_len,
+    const paddle::Tensor& decoder_context_len_cache,
+    const paddle::Tensor& decoder_batch_map,
+    const paddle::Tensor& prefix_len,
     const paddle::optional&lt;paddle::Tensor&gt;& k_scales,
     const paddle::optional&lt;paddle::Tensor&gt;& v_scales,
     const paddle::optional&lt;paddle::Tensor&gt;& k_scales_inv,
@@ -225,14 +233,14 @@ std::vector&lt;paddle::Tensor&gt; BlockAttnKernel(
     vsl.usual_lod_vp = {
         const_cast&lt;int32_t*&gt;(encoder_seq_lod_cpu.data&lt;int32_t&gt;()),
         enc_batch + 1,
-        nullptr};
+        const_cast&lt;int32_t*&gt;(encoder_seq_lod.data&lt;int32_t&gt;())};
     vsl.kv_lod_vp = {const_cast&lt;int32_t*&gt;(encoder_kv_lod_cpu.data&lt;int32_t&gt;()),
                      enc_batch + 1,
-                     nullptr};
+                     const_cast&lt;int32_t*&gt;(encoder_kv_lod.data&lt;int32_t&gt;())};
     vsl.slot_mapping_vp = {
         const_cast&lt;int32_t*&gt;(encoder_batch_map_cpu.data&lt;int32_t&gt;()),
         enc_batch,
-        nullptr};  // real batch
+        const_cast&lt;int32_t*&gt;(encoder_batch_map.data&lt;int32_t&gt;())};  // real batch
     param.max_valid_seqlen = max_enc_len;
     param.max_kv_valid_seqlen = max_kv_len;
     // setting for prefix cache
@@ -252,7 +260,7 @@ std::vector&lt;paddle::Tensor&gt; BlockAttnKernel(
     baidu::xpu::api::VectorParam&lt;int32_t&gt; prefix_lens_vp{
         const_cast&lt;int32_t*&gt;(prefix_len_cpu.data&lt;int32_t&gt;()),
         enc_batch,
-        nullptr};
+        const_cast&lt;int32_t*&gt;(prefix_len.data&lt;int32_t&gt;())};

     float* fake_perhead_scale = nullptr;
     if (is_cache_int8 && has_zp && is_prefix_cache) {
@@ -550,20 +558,28 @@ std::vector&lt;paddle::Tensor&gt; BlockAttnKernel(
     api::VectorParam&lt;int32_t&gt; decoder_context_len_vp = {
         const_cast&lt;int32_t*&gt;(decoder_context_len_cpu.data&lt;int32_t&gt;()),
         dec_batch,
-        nullptr};  // use for speculative_attention_decoder seq_len in
-                   // MTP
+        const_cast&lt;int32_t*&gt;(
+            decoder_context_len
+                .data&lt;int32_t&gt;())};  // use for speculative_attention_decoder
+                                     // seq_len in MTP
     api::VectorParam&lt;int32_t&gt; decoder_context_len_cache_vp = {
         const_cast&lt;int32_t*&gt;(decoder_context_len_cache_cpu.data&lt;int32_t&gt;()),
         dec_batch,
-        nullptr};  // use for split rope enc as prefix cache len in MTP
+        const_cast&lt;int32_t*&gt;(
+            decoder_context_len_cache
+                .data&lt;int32_t&gt;())};  // use for split rope enc as prefix cache
+                                     // len in MTP
     api::VectorParam&lt;int32_t&gt; decoder_batch_map_vp = {
         const_cast&lt;int32_t*&gt;(decoder_batch_map_cpu.data&lt;int32_t&gt;()),
         dec_batch,
-        nullptr};  // real batch
+        const_cast&lt;int32_t*&gt;(
+            decoder_batch_map.data&lt;int32_t&gt;())};  // real batch
     api::VectorParam&lt;int32_t&gt; decoder_seq_lod_vp = {
         const_cast&lt;int32_t*&gt;(decoder_seq_lod_cpu.data&lt;int32_t&gt;()),
         dec_batch + 1,
-        nullptr};  // use for split rope enc as lod in MTP
+        const_cast&lt;int32_t*&gt;(
+            decoder_seq_lod
+                .data&lt;int32_t&gt;())};  // use for split rope enc as lod in MTP

     // rope + cache
     int ret = 0;
@@ -771,11 +787,12 @@ std::vector&lt;paddle::Tensor&gt; BlockAttnKernel(
     vsl.usual_lod_vp = {
         const_cast&lt;int32_t*&gt;(decoder_context_len_cpu.data&lt;int32_t&gt;()),
         dec_batch,
-        nullptr};
+        const_cast&lt;int32_t*&gt;(decoder_context_len.data&lt;int32_t&gt;())};
     vsl.slot_mapping_vp = {
         const_cast&lt;int32_t*&gt;(decoder_batch_map_cpu.data&lt;int32_t&gt;()),
         dec_batch,
-        nullptr};  // real batch
+        const_cast&lt;int32_t*&gt;(
+            decoder_batch_map.data&lt;int32_t&gt;())};  // real batch

     xftblock::Tensor q_buf(
         rt_guard, KV_BUF_TYPE, {total_dec_len, hidden_dim}, false, false);
@@ -1013,6 +1030,14 @@ std::vector&lt;paddle::Tensor&gt; BlockAttn(
     const paddle::Tensor& decoder_context_len_cache_cpu,
     const paddle::Tensor& decoder_batch_map_cpu,
     const paddle::Tensor& prefix_len_cpu,
+    const paddle::Tensor& encoder_seq_lod,
+    const paddle::Tensor& decoder_seq_lod,
+    const paddle::Tensor& encoder_kv_lod,
+    const paddle::Tensor& encoder_batch_map,
+    const paddle::Tensor& decoder_context_len,
+    const paddle::Tensor& decoder_context_len_cache,
+    const paddle::Tensor& decoder_batch_map,
+    const paddle::Tensor& prefix_len,
     const paddle::optional&lt;paddle::Tensor&gt;& k_scales,
     const paddle::optional&lt;paddle::Tensor&gt;& v_scales,
     const paddle::optional&lt;paddle::Tensor&gt;& k_scales_inv,
@@ -1044,6 +1069,14 @@ std::vector&lt;paddle::Tensor&gt; BlockAttn(
       decoder_context_len_cache_cpu,                                  \
       decoder_batch_map_cpu,                                          \
       prefix_len_cpu,                                                 \
+      encoder_seq_lod,                                                \
+      decoder_seq_lod,                                                \
+      encoder_kv_lod,                                                 \
+      encoder_batch_map,                                              \
+      decoder_context_len,                                            \
+      decoder_context_len_cache,                                      \
+      decoder_batch_map,                                              \
+      prefix_len,                                                     \
       k_scales,                                                       \
       v_scales,                                                       \
       k_scales_inv,                                                   \
@@ -1112,6 +1145,14 @@ PD_BUILD_STATIC_OP(block_attn)
             "decoder_context_len_cache_cpu",
             "decoder_batch_map_cpu",
             "prefix_len_cpu",
+            "encoder_seq_lod",
+            "decoder_seq_lod",
+            "encoder_kv_lod",
+            "encoder_batch_map",
+            "decoder_context_len",
+            "decoder_context_len_cache",
+            "decoder_batch_map",
+            "prefix_len",
             paddle::Optional("k_scales"),
             paddle::Optional("v_scales"),
             paddle::Optional("k_scales_inv"),
@@ -75,6 +75,14 @@ std::vector&lt;paddle::Tensor&gt; BlockAttn(
     const paddle::Tensor& decoder_context_len_cache_cpu,
     const paddle::Tensor& decoder_batch_map_cpu,
     const paddle::Tensor& prefix_len_cpu,
+    const paddle::Tensor& encoder_seq_lod_xpu,
+    const paddle::Tensor& decoder_seq_lod_xpu,
+    const paddle::Tensor& encoder_kv_lod_xpu,
+    const paddle::Tensor& encoder_batch_map_xpu,
+    const paddle::Tensor& decoder_context_len_xpu,
+    const paddle::Tensor& decoder_context_len_cache_xpu,
+    const paddle::Tensor& decoder_batch_map_xpu,
+    const paddle::Tensor& prefix_len_xpu,
     const paddle::optional&lt;paddle::Tensor&gt;& k_scales,
     const paddle::optional&lt;paddle::Tensor&gt;& v_scales,
     const paddle::optional&lt;paddle::Tensor&gt;& k_scales_inv,
@@ -634,6 +642,14 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
         py::arg("decoder_context_len_cache_cpu"),
         py::arg("decoder_batch_map_cpu"),
         py::arg("prefix_len_cpu"),
+        py::arg("encoder_seq_lod_xpu"),
+        py::arg("decoder_seq_lod_xpu"),
+        py::arg("encoder_kv_lod_xpu"),
+        py::arg("encoder_batch_map_xpu"),
+        py::arg("decoder_context_len_xpu"),
+        py::arg("decoder_context_len_cache_xpu"),
+        py::arg("decoder_batch_map_xpu"),
+        py::arg("prefix_len_xpu"),
         py::arg("k_scales"),
         py::arg("v_scales"),
         py::arg("k_scales_inv"),
@@ -203,6 +203,14 @@ class XPUAttentionBackend(AttentionBackend):
             forward_meta.decoder_context_len_cache_cpu,
             forward_meta.decoder_batch_map_cpu,
             forward_meta.prefix_len_cpu,
+            forward_meta.encoder_seq_lod,
+            forward_meta.decoder_seq_lod,
+            forward_meta.encoder_kv_lod,
+            forward_meta.encoder_batch_map,
+            forward_meta.decoder_context_len,
+            forward_meta.decoder_context_len_cache,
+            forward_meta.decoder_batch_map,
+            forward_meta.prefix_len,
             cache_k_scale,
             cache_v_scale,
             cache_k_out_scale,