mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[XPU] Support CudaGraph(add block attn cuda_graph support) (#6116)
* add block attn cuda_graph support
This commit is contained in:
@@ -79,6 +79,14 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
const paddle::Tensor& decoder_context_len_cache_cpu,
|
||||
const paddle::Tensor& decoder_batch_map_cpu,
|
||||
const paddle::Tensor& prefix_len_cpu,
|
||||
const paddle::Tensor& encoder_seq_lod,
|
||||
const paddle::Tensor& decoder_seq_lod,
|
||||
const paddle::Tensor& encoder_kv_lod,
|
||||
const paddle::Tensor& encoder_batch_map,
|
||||
const paddle::Tensor& decoder_context_len,
|
||||
const paddle::Tensor& decoder_context_len_cache,
|
||||
const paddle::Tensor& decoder_batch_map,
|
||||
const paddle::Tensor& prefix_len,
|
||||
const paddle::optional<paddle::Tensor>& k_scales,
|
||||
const paddle::optional<paddle::Tensor>& v_scales,
|
||||
const paddle::optional<paddle::Tensor>& k_scales_inv,
|
||||
@@ -225,14 +233,14 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
vsl.usual_lod_vp = {
|
||||
const_cast<int32_t*>(encoder_seq_lod_cpu.data<int32_t>()),
|
||||
enc_batch + 1,
|
||||
nullptr};
|
||||
const_cast<int32_t*>(encoder_seq_lod.data<int32_t>())};
|
||||
vsl.kv_lod_vp = {const_cast<int32_t*>(encoder_kv_lod_cpu.data<int32_t>()),
|
||||
enc_batch + 1,
|
||||
nullptr};
|
||||
const_cast<int32_t*>(encoder_kv_lod.data<int32_t>())};
|
||||
vsl.slot_mapping_vp = {
|
||||
const_cast<int32_t*>(encoder_batch_map_cpu.data<int32_t>()),
|
||||
enc_batch,
|
||||
nullptr}; // real batch
|
||||
const_cast<int32_t*>(encoder_batch_map.data<int32_t>())}; // real batch
|
||||
param.max_valid_seqlen = max_enc_len;
|
||||
param.max_kv_valid_seqlen = max_kv_len;
|
||||
// setting for prefix cache
|
||||
@@ -252,7 +260,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
baidu::xpu::api::VectorParam<int32_t> prefix_lens_vp{
|
||||
const_cast<int32_t*>(prefix_len_cpu.data<int32_t>()),
|
||||
enc_batch,
|
||||
nullptr};
|
||||
const_cast<int32_t*>(prefix_len.data<int32_t>())};
|
||||
|
||||
float* fake_perhead_scale = nullptr;
|
||||
if (is_cache_int8 && has_zp && is_prefix_cache) {
|
||||
@@ -550,20 +558,28 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
api::VectorParam<int32_t> decoder_context_len_vp = {
|
||||
const_cast<int32_t*>(decoder_context_len_cpu.data<int32_t>()),
|
||||
dec_batch,
|
||||
nullptr}; // use for speculative_attention_decoder seq_len in
|
||||
// MTP
|
||||
const_cast<int32_t*>(
|
||||
decoder_context_len
|
||||
.data<int32_t>())}; // use for speculative_attention_decoder
|
||||
// seq_len in MTP
|
||||
api::VectorParam<int32_t> decoder_context_len_cache_vp = {
|
||||
const_cast<int32_t*>(decoder_context_len_cache_cpu.data<int32_t>()),
|
||||
dec_batch,
|
||||
nullptr}; // use for split rope enc as prefix cache len in MTP
|
||||
const_cast<int32_t*>(
|
||||
decoder_context_len_cache
|
||||
.data<int32_t>())}; // use for split rope enc as prefix cache
|
||||
// len in MTP
|
||||
api::VectorParam<int32_t> decoder_batch_map_vp = {
|
||||
const_cast<int32_t*>(decoder_batch_map_cpu.data<int32_t>()),
|
||||
dec_batch,
|
||||
nullptr}; // real batch
|
||||
const_cast<int32_t*>(
|
||||
decoder_batch_map.data<int32_t>())}; // real batch
|
||||
api::VectorParam<int32_t> decoder_seq_lod_vp = {
|
||||
const_cast<int32_t*>(decoder_seq_lod_cpu.data<int32_t>()),
|
||||
dec_batch + 1,
|
||||
nullptr}; // use for split rope enc as lod in MTP
|
||||
const_cast<int32_t*>(
|
||||
decoder_seq_lod
|
||||
.data<int32_t>())}; // use for split rope enc as lod in MTP
|
||||
|
||||
// rope + cache
|
||||
int ret = 0;
|
||||
@@ -771,11 +787,12 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
vsl.usual_lod_vp = {
|
||||
const_cast<int32_t*>(decoder_context_len_cpu.data<int32_t>()),
|
||||
dec_batch,
|
||||
nullptr};
|
||||
const_cast<int32_t*>(decoder_context_len.data<int32_t>())};
|
||||
vsl.slot_mapping_vp = {
|
||||
const_cast<int32_t*>(decoder_batch_map_cpu.data<int32_t>()),
|
||||
dec_batch,
|
||||
nullptr}; // real batch
|
||||
const_cast<int32_t*>(
|
||||
decoder_batch_map.data<int32_t>())}; // real batch
|
||||
|
||||
xftblock::Tensor q_buf(
|
||||
rt_guard, KV_BUF_TYPE, {total_dec_len, hidden_dim}, false, false);
|
||||
@@ -1013,6 +1030,14 @@ std::vector<paddle::Tensor> BlockAttn(
|
||||
const paddle::Tensor& decoder_context_len_cache_cpu,
|
||||
const paddle::Tensor& decoder_batch_map_cpu,
|
||||
const paddle::Tensor& prefix_len_cpu,
|
||||
const paddle::Tensor& encoder_seq_lod,
|
||||
const paddle::Tensor& decoder_seq_lod,
|
||||
const paddle::Tensor& encoder_kv_lod,
|
||||
const paddle::Tensor& encoder_batch_map,
|
||||
const paddle::Tensor& decoder_context_len,
|
||||
const paddle::Tensor& decoder_context_len_cache,
|
||||
const paddle::Tensor& decoder_batch_map,
|
||||
const paddle::Tensor& prefix_len,
|
||||
const paddle::optional<paddle::Tensor>& k_scales,
|
||||
const paddle::optional<paddle::Tensor>& v_scales,
|
||||
const paddle::optional<paddle::Tensor>& k_scales_inv,
|
||||
@@ -1044,6 +1069,14 @@ std::vector<paddle::Tensor> BlockAttn(
|
||||
decoder_context_len_cache_cpu, \
|
||||
decoder_batch_map_cpu, \
|
||||
prefix_len_cpu, \
|
||||
encoder_seq_lod, \
|
||||
decoder_seq_lod, \
|
||||
encoder_kv_lod, \
|
||||
encoder_batch_map, \
|
||||
decoder_context_len, \
|
||||
decoder_context_len_cache, \
|
||||
decoder_batch_map, \
|
||||
prefix_len, \
|
||||
k_scales, \
|
||||
v_scales, \
|
||||
k_scales_inv, \
|
||||
@@ -1112,6 +1145,14 @@ PD_BUILD_STATIC_OP(block_attn)
|
||||
"decoder_context_len_cache_cpu",
|
||||
"decoder_batch_map_cpu",
|
||||
"prefix_len_cpu",
|
||||
"encoder_seq_lod",
|
||||
"decoder_seq_lod",
|
||||
"encoder_kv_lod",
|
||||
"encoder_batch_map",
|
||||
"decoder_context_len",
|
||||
"decoder_context_len_cache",
|
||||
"decoder_batch_map",
|
||||
"prefix_len",
|
||||
paddle::Optional("k_scales"),
|
||||
paddle::Optional("v_scales"),
|
||||
paddle::Optional("k_scales_inv"),
|
||||
|
||||
@@ -75,6 +75,14 @@ std::vector<paddle::Tensor> BlockAttn(
|
||||
const paddle::Tensor& decoder_context_len_cache_cpu,
|
||||
const paddle::Tensor& decoder_batch_map_cpu,
|
||||
const paddle::Tensor& prefix_len_cpu,
|
||||
const paddle::Tensor& encoder_seq_lod_xpu,
|
||||
const paddle::Tensor& decoder_seq_lod_xpu,
|
||||
const paddle::Tensor& encoder_kv_lod_xpu,
|
||||
const paddle::Tensor& encoder_batch_map_xpu,
|
||||
const paddle::Tensor& decoder_context_len_xpu,
|
||||
const paddle::Tensor& decoder_context_len_cache_xpu,
|
||||
const paddle::Tensor& decoder_batch_map_xpu,
|
||||
const paddle::Tensor& prefix_len_xpu,
|
||||
const paddle::optional<paddle::Tensor>& k_scales,
|
||||
const paddle::optional<paddle::Tensor>& v_scales,
|
||||
const paddle::optional<paddle::Tensor>& k_scales_inv,
|
||||
@@ -634,6 +642,14 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
py::arg("decoder_context_len_cache_cpu"),
|
||||
py::arg("decoder_batch_map_cpu"),
|
||||
py::arg("prefix_len_cpu"),
|
||||
py::arg("encoder_seq_lod_xpu"),
|
||||
py::arg("decoder_seq_lod_xpu"),
|
||||
py::arg("encoder_kv_lod_xpu"),
|
||||
py::arg("encoder_batch_map_xpu"),
|
||||
py::arg("decoder_context_len_xpu"),
|
||||
py::arg("decoder_context_len_cache_xpu"),
|
||||
py::arg("decoder_batch_map_xpu"),
|
||||
py::arg("prefix_len_xpu"),
|
||||
py::arg("k_scales"),
|
||||
py::arg("v_scales"),
|
||||
py::arg("k_scales_inv"),
|
||||
|
||||
Reference in New Issue
Block a user