mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[XPU] Support CudaGraph(add block attn cuda_graph support) (#6116)
* add block attn cuda_graph support
This commit is contained in:
@@ -79,6 +79,14 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
|||||||
const paddle::Tensor& decoder_context_len_cache_cpu,
|
const paddle::Tensor& decoder_context_len_cache_cpu,
|
||||||
const paddle::Tensor& decoder_batch_map_cpu,
|
const paddle::Tensor& decoder_batch_map_cpu,
|
||||||
const paddle::Tensor& prefix_len_cpu,
|
const paddle::Tensor& prefix_len_cpu,
|
||||||
|
const paddle::Tensor& encoder_seq_lod,
|
||||||
|
const paddle::Tensor& decoder_seq_lod,
|
||||||
|
const paddle::Tensor& encoder_kv_lod,
|
||||||
|
const paddle::Tensor& encoder_batch_map,
|
||||||
|
const paddle::Tensor& decoder_context_len,
|
||||||
|
const paddle::Tensor& decoder_context_len_cache,
|
||||||
|
const paddle::Tensor& decoder_batch_map,
|
||||||
|
const paddle::Tensor& prefix_len,
|
||||||
const paddle::optional<paddle::Tensor>& k_scales,
|
const paddle::optional<paddle::Tensor>& k_scales,
|
||||||
const paddle::optional<paddle::Tensor>& v_scales,
|
const paddle::optional<paddle::Tensor>& v_scales,
|
||||||
const paddle::optional<paddle::Tensor>& k_scales_inv,
|
const paddle::optional<paddle::Tensor>& k_scales_inv,
|
||||||
@@ -225,14 +233,14 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
|||||||
vsl.usual_lod_vp = {
|
vsl.usual_lod_vp = {
|
||||||
const_cast<int32_t*>(encoder_seq_lod_cpu.data<int32_t>()),
|
const_cast<int32_t*>(encoder_seq_lod_cpu.data<int32_t>()),
|
||||||
enc_batch + 1,
|
enc_batch + 1,
|
||||||
nullptr};
|
const_cast<int32_t*>(encoder_seq_lod.data<int32_t>())};
|
||||||
vsl.kv_lod_vp = {const_cast<int32_t*>(encoder_kv_lod_cpu.data<int32_t>()),
|
vsl.kv_lod_vp = {const_cast<int32_t*>(encoder_kv_lod_cpu.data<int32_t>()),
|
||||||
enc_batch + 1,
|
enc_batch + 1,
|
||||||
nullptr};
|
const_cast<int32_t*>(encoder_kv_lod.data<int32_t>())};
|
||||||
vsl.slot_mapping_vp = {
|
vsl.slot_mapping_vp = {
|
||||||
const_cast<int32_t*>(encoder_batch_map_cpu.data<int32_t>()),
|
const_cast<int32_t*>(encoder_batch_map_cpu.data<int32_t>()),
|
||||||
enc_batch,
|
enc_batch,
|
||||||
nullptr}; // real batch
|
const_cast<int32_t*>(encoder_batch_map.data<int32_t>())}; // real batch
|
||||||
param.max_valid_seqlen = max_enc_len;
|
param.max_valid_seqlen = max_enc_len;
|
||||||
param.max_kv_valid_seqlen = max_kv_len;
|
param.max_kv_valid_seqlen = max_kv_len;
|
||||||
// setting for prefix cache
|
// setting for prefix cache
|
||||||
@@ -252,7 +260,7 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
|||||||
baidu::xpu::api::VectorParam<int32_t> prefix_lens_vp{
|
baidu::xpu::api::VectorParam<int32_t> prefix_lens_vp{
|
||||||
const_cast<int32_t*>(prefix_len_cpu.data<int32_t>()),
|
const_cast<int32_t*>(prefix_len_cpu.data<int32_t>()),
|
||||||
enc_batch,
|
enc_batch,
|
||||||
nullptr};
|
const_cast<int32_t*>(prefix_len.data<int32_t>())};
|
||||||
|
|
||||||
float* fake_perhead_scale = nullptr;
|
float* fake_perhead_scale = nullptr;
|
||||||
if (is_cache_int8 && has_zp && is_prefix_cache) {
|
if (is_cache_int8 && has_zp && is_prefix_cache) {
|
||||||
@@ -550,20 +558,28 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
|||||||
api::VectorParam<int32_t> decoder_context_len_vp = {
|
api::VectorParam<int32_t> decoder_context_len_vp = {
|
||||||
const_cast<int32_t*>(decoder_context_len_cpu.data<int32_t>()),
|
const_cast<int32_t*>(decoder_context_len_cpu.data<int32_t>()),
|
||||||
dec_batch,
|
dec_batch,
|
||||||
nullptr}; // use for speculative_attention_decoder seq_len in
|
const_cast<int32_t*>(
|
||||||
// MTP
|
decoder_context_len
|
||||||
|
.data<int32_t>())}; // use for speculative_attention_decoder
|
||||||
|
// seq_len in MTP
|
||||||
api::VectorParam<int32_t> decoder_context_len_cache_vp = {
|
api::VectorParam<int32_t> decoder_context_len_cache_vp = {
|
||||||
const_cast<int32_t*>(decoder_context_len_cache_cpu.data<int32_t>()),
|
const_cast<int32_t*>(decoder_context_len_cache_cpu.data<int32_t>()),
|
||||||
dec_batch,
|
dec_batch,
|
||||||
nullptr}; // use for split rope enc as prefix cache len in MTP
|
const_cast<int32_t*>(
|
||||||
|
decoder_context_len_cache
|
||||||
|
.data<int32_t>())}; // use for split rope enc as prefix cache
|
||||||
|
// len in MTP
|
||||||
api::VectorParam<int32_t> decoder_batch_map_vp = {
|
api::VectorParam<int32_t> decoder_batch_map_vp = {
|
||||||
const_cast<int32_t*>(decoder_batch_map_cpu.data<int32_t>()),
|
const_cast<int32_t*>(decoder_batch_map_cpu.data<int32_t>()),
|
||||||
dec_batch,
|
dec_batch,
|
||||||
nullptr}; // real batch
|
const_cast<int32_t*>(
|
||||||
|
decoder_batch_map.data<int32_t>())}; // real batch
|
||||||
api::VectorParam<int32_t> decoder_seq_lod_vp = {
|
api::VectorParam<int32_t> decoder_seq_lod_vp = {
|
||||||
const_cast<int32_t*>(decoder_seq_lod_cpu.data<int32_t>()),
|
const_cast<int32_t*>(decoder_seq_lod_cpu.data<int32_t>()),
|
||||||
dec_batch + 1,
|
dec_batch + 1,
|
||||||
nullptr}; // use for split rope enc as lod in MTP
|
const_cast<int32_t*>(
|
||||||
|
decoder_seq_lod
|
||||||
|
.data<int32_t>())}; // use for split rope enc as lod in MTP
|
||||||
|
|
||||||
// rope + cache
|
// rope + cache
|
||||||
int ret = 0;
|
int ret = 0;
|
||||||
@@ -771,11 +787,12 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
|||||||
vsl.usual_lod_vp = {
|
vsl.usual_lod_vp = {
|
||||||
const_cast<int32_t*>(decoder_context_len_cpu.data<int32_t>()),
|
const_cast<int32_t*>(decoder_context_len_cpu.data<int32_t>()),
|
||||||
dec_batch,
|
dec_batch,
|
||||||
nullptr};
|
const_cast<int32_t*>(decoder_context_len.data<int32_t>())};
|
||||||
vsl.slot_mapping_vp = {
|
vsl.slot_mapping_vp = {
|
||||||
const_cast<int32_t*>(decoder_batch_map_cpu.data<int32_t>()),
|
const_cast<int32_t*>(decoder_batch_map_cpu.data<int32_t>()),
|
||||||
dec_batch,
|
dec_batch,
|
||||||
nullptr}; // real batch
|
const_cast<int32_t*>(
|
||||||
|
decoder_batch_map.data<int32_t>())}; // real batch
|
||||||
|
|
||||||
xftblock::Tensor q_buf(
|
xftblock::Tensor q_buf(
|
||||||
rt_guard, KV_BUF_TYPE, {total_dec_len, hidden_dim}, false, false);
|
rt_guard, KV_BUF_TYPE, {total_dec_len, hidden_dim}, false, false);
|
||||||
@@ -1013,6 +1030,14 @@ std::vector<paddle::Tensor> BlockAttn(
|
|||||||
const paddle::Tensor& decoder_context_len_cache_cpu,
|
const paddle::Tensor& decoder_context_len_cache_cpu,
|
||||||
const paddle::Tensor& decoder_batch_map_cpu,
|
const paddle::Tensor& decoder_batch_map_cpu,
|
||||||
const paddle::Tensor& prefix_len_cpu,
|
const paddle::Tensor& prefix_len_cpu,
|
||||||
|
const paddle::Tensor& encoder_seq_lod,
|
||||||
|
const paddle::Tensor& decoder_seq_lod,
|
||||||
|
const paddle::Tensor& encoder_kv_lod,
|
||||||
|
const paddle::Tensor& encoder_batch_map,
|
||||||
|
const paddle::Tensor& decoder_context_len,
|
||||||
|
const paddle::Tensor& decoder_context_len_cache,
|
||||||
|
const paddle::Tensor& decoder_batch_map,
|
||||||
|
const paddle::Tensor& prefix_len,
|
||||||
const paddle::optional<paddle::Tensor>& k_scales,
|
const paddle::optional<paddle::Tensor>& k_scales,
|
||||||
const paddle::optional<paddle::Tensor>& v_scales,
|
const paddle::optional<paddle::Tensor>& v_scales,
|
||||||
const paddle::optional<paddle::Tensor>& k_scales_inv,
|
const paddle::optional<paddle::Tensor>& k_scales_inv,
|
||||||
@@ -1044,6 +1069,14 @@ std::vector<paddle::Tensor> BlockAttn(
|
|||||||
decoder_context_len_cache_cpu, \
|
decoder_context_len_cache_cpu, \
|
||||||
decoder_batch_map_cpu, \
|
decoder_batch_map_cpu, \
|
||||||
prefix_len_cpu, \
|
prefix_len_cpu, \
|
||||||
|
encoder_seq_lod, \
|
||||||
|
decoder_seq_lod, \
|
||||||
|
encoder_kv_lod, \
|
||||||
|
encoder_batch_map, \
|
||||||
|
decoder_context_len, \
|
||||||
|
decoder_context_len_cache, \
|
||||||
|
decoder_batch_map, \
|
||||||
|
prefix_len, \
|
||||||
k_scales, \
|
k_scales, \
|
||||||
v_scales, \
|
v_scales, \
|
||||||
k_scales_inv, \
|
k_scales_inv, \
|
||||||
@@ -1112,6 +1145,14 @@ PD_BUILD_STATIC_OP(block_attn)
|
|||||||
"decoder_context_len_cache_cpu",
|
"decoder_context_len_cache_cpu",
|
||||||
"decoder_batch_map_cpu",
|
"decoder_batch_map_cpu",
|
||||||
"prefix_len_cpu",
|
"prefix_len_cpu",
|
||||||
|
"encoder_seq_lod",
|
||||||
|
"decoder_seq_lod",
|
||||||
|
"encoder_kv_lod",
|
||||||
|
"encoder_batch_map",
|
||||||
|
"decoder_context_len",
|
||||||
|
"decoder_context_len_cache",
|
||||||
|
"decoder_batch_map",
|
||||||
|
"prefix_len",
|
||||||
paddle::Optional("k_scales"),
|
paddle::Optional("k_scales"),
|
||||||
paddle::Optional("v_scales"),
|
paddle::Optional("v_scales"),
|
||||||
paddle::Optional("k_scales_inv"),
|
paddle::Optional("k_scales_inv"),
|
||||||
|
|||||||
@@ -75,6 +75,14 @@ std::vector<paddle::Tensor> BlockAttn(
|
|||||||
const paddle::Tensor& decoder_context_len_cache_cpu,
|
const paddle::Tensor& decoder_context_len_cache_cpu,
|
||||||
const paddle::Tensor& decoder_batch_map_cpu,
|
const paddle::Tensor& decoder_batch_map_cpu,
|
||||||
const paddle::Tensor& prefix_len_cpu,
|
const paddle::Tensor& prefix_len_cpu,
|
||||||
|
const paddle::Tensor& encoder_seq_lod_xpu,
|
||||||
|
const paddle::Tensor& decoder_seq_lod_xpu,
|
||||||
|
const paddle::Tensor& encoder_kv_lod_xpu,
|
||||||
|
const paddle::Tensor& encoder_batch_map_xpu,
|
||||||
|
const paddle::Tensor& decoder_context_len_xpu,
|
||||||
|
const paddle::Tensor& decoder_context_len_cache_xpu,
|
||||||
|
const paddle::Tensor& decoder_batch_map_xpu,
|
||||||
|
const paddle::Tensor& prefix_len_xpu,
|
||||||
const paddle::optional<paddle::Tensor>& k_scales,
|
const paddle::optional<paddle::Tensor>& k_scales,
|
||||||
const paddle::optional<paddle::Tensor>& v_scales,
|
const paddle::optional<paddle::Tensor>& v_scales,
|
||||||
const paddle::optional<paddle::Tensor>& k_scales_inv,
|
const paddle::optional<paddle::Tensor>& k_scales_inv,
|
||||||
@@ -634,6 +642,14 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
|||||||
py::arg("decoder_context_len_cache_cpu"),
|
py::arg("decoder_context_len_cache_cpu"),
|
||||||
py::arg("decoder_batch_map_cpu"),
|
py::arg("decoder_batch_map_cpu"),
|
||||||
py::arg("prefix_len_cpu"),
|
py::arg("prefix_len_cpu"),
|
||||||
|
py::arg("encoder_seq_lod_xpu"),
|
||||||
|
py::arg("decoder_seq_lod_xpu"),
|
||||||
|
py::arg("encoder_kv_lod_xpu"),
|
||||||
|
py::arg("encoder_batch_map_xpu"),
|
||||||
|
py::arg("decoder_context_len_xpu"),
|
||||||
|
py::arg("decoder_context_len_cache_xpu"),
|
||||||
|
py::arg("decoder_batch_map_xpu"),
|
||||||
|
py::arg("prefix_len_xpu"),
|
||||||
py::arg("k_scales"),
|
py::arg("k_scales"),
|
||||||
py::arg("v_scales"),
|
py::arg("v_scales"),
|
||||||
py::arg("k_scales_inv"),
|
py::arg("k_scales_inv"),
|
||||||
|
|||||||
@@ -203,6 +203,14 @@ class XPUAttentionBackend(AttentionBackend):
|
|||||||
forward_meta.decoder_context_len_cache_cpu,
|
forward_meta.decoder_context_len_cache_cpu,
|
||||||
forward_meta.decoder_batch_map_cpu,
|
forward_meta.decoder_batch_map_cpu,
|
||||||
forward_meta.prefix_len_cpu,
|
forward_meta.prefix_len_cpu,
|
||||||
|
forward_meta.encoder_seq_lod,
|
||||||
|
forward_meta.decoder_seq_lod,
|
||||||
|
forward_meta.encoder_kv_lod,
|
||||||
|
forward_meta.encoder_batch_map,
|
||||||
|
forward_meta.decoder_context_len,
|
||||||
|
forward_meta.decoder_context_len_cache,
|
||||||
|
forward_meta.decoder_batch_map,
|
||||||
|
forward_meta.prefix_len,
|
||||||
cache_k_scale,
|
cache_k_scale,
|
||||||
cache_v_scale,
|
cache_v_scale,
|
||||||
cache_k_out_scale,
|
cache_k_out_scale,
|
||||||
|
|||||||
Reference in New Issue
Block a user