mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Optimization] Fuse get_max_len and get_kv_max_len (#4369)
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
* opt split_q_block * fuse max_lens and max kv_len
This commit is contained in:
@@ -59,7 +59,6 @@ void AppendAttentionKernel(
|
||||
const paddle::Tensor& decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& decoder_num_blocks,
|
||||
const paddle::Tensor& set_max_lengths,
|
||||
const paddle::Tensor& max_len_kv,
|
||||
paddle::Tensor& fmha_out,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
@@ -103,6 +102,7 @@ void AppendAttentionKernel(
|
||||
int max_dec_len_this_time = set_max_lengths.data<int>()[2];
|
||||
int max_enc_dec_len_this_time = set_max_lengths.data<int>()[3];
|
||||
int max_just_dec_len_this_time = set_max_lengths.data<int>()[4];
|
||||
int max_kv_len_this_time = set_max_lengths.data<int>()[8];
|
||||
|
||||
auto main_stream = qkv.stream();
|
||||
static cudaEvent_t main_event;
|
||||
@@ -245,7 +245,6 @@ void AppendAttentionKernel(
|
||||
|
||||
if (max_just_dec_len_this_time > 0) {
|
||||
int decoder_num_blocks_data = decoder_num_blocks.data<int>()[0];
|
||||
int max_len_kv_data = max_len_kv.data<int>()[0];
|
||||
|
||||
cudaStream_t exec_stream;
|
||||
if (max_enc_len_this_time > 0) {
|
||||
@@ -371,20 +370,20 @@ void AppendAttentionKernel(
|
||||
case paddle::DataType::INT8:{
|
||||
int8_t tmp;
|
||||
dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
|
||||
decoder_block_shape_q, max_len_kv_data, !speculate_decoder, !speculate_decoder, exec_stream);
|
||||
decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream);
|
||||
break;
|
||||
}
|
||||
case paddle::DataType::FLOAT8_E4M3FN:{
|
||||
phi::dtype::float8_e4m3fn tmp;
|
||||
dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
|
||||
decoder_block_shape_q, max_len_kv_data, !speculate_decoder, !speculate_decoder, exec_stream);
|
||||
decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream);
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
data_t tmp;
|
||||
dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
|
||||
decoder_block_shape_q, max_len_kv_data, !speculate_decoder, !speculate_decoder, exec_stream);
|
||||
decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream);
|
||||
}
|
||||
if (max_enc_len_this_time > 0) {
|
||||
cudaEventRecord(decoder_event, exec_stream);
|
||||
@@ -413,7 +412,6 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
const paddle::Tensor& decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& decoder_num_blocks,
|
||||
const paddle::Tensor& set_max_lengths,
|
||||
const paddle::Tensor& max_len_kv,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& qkv_bias,
|
||||
@@ -539,7 +537,6 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
decoder_tile_ids_per_batch,
|
||||
decoder_num_blocks,
|
||||
set_max_lengths,
|
||||
max_len_kv,
|
||||
fmha_out,
|
||||
rotary_embs,
|
||||
attn_mask,
|
||||
@@ -616,7 +613,6 @@ void AppendAttentionWithOutput(
|
||||
const paddle::Tensor& decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& decoder_num_blocks,
|
||||
const paddle::Tensor& set_max_lengths,
|
||||
const paddle::Tensor& max_len_kv,
|
||||
paddle::Tensor& fmha_out,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
@@ -695,7 +691,6 @@ void AppendAttentionWithOutput(
|
||||
decoder_tile_ids_per_batch,
|
||||
decoder_num_blocks,
|
||||
set_max_lengths,
|
||||
max_len_kv,
|
||||
fmha_out,
|
||||
rotary_embs,
|
||||
attn_mask,
|
||||
@@ -784,7 +779,6 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
|
||||
const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
|
||||
const std::vector<int64_t>& decoder_num_blocks_shape,
|
||||
const std::vector<int64_t>& set_max_lengths_shape,
|
||||
const std::vector<int64_t>& max_len_kv_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& rotary_embs_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& qkv_bias_shape,
|
||||
@@ -848,7 +842,6 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
|
||||
const paddle::DataType& decoder_tile_ids_per_batch_dtype,
|
||||
const paddle::DataType& decoder_num_blocks_dtype,
|
||||
const paddle::DataType& set_max_lengths_dtype,
|
||||
const paddle::DataType& max_len_kv_dtype,
|
||||
const paddle::optional<paddle::DataType>& rotary_embs_dtype,
|
||||
const paddle::optional<paddle::DataType>& attn_mask_dtype,
|
||||
const paddle::optional<paddle::DataType>& qkv_bias_dtype,
|
||||
@@ -930,7 +923,6 @@ std::vector<std::vector<int64_t>> AppendAttentionWithOutputInferShape(
|
||||
const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
|
||||
const std::vector<int64_t>& decoder_num_blocks_shape,
|
||||
const std::vector<int64_t>& set_max_lengths_shape,
|
||||
const std::vector<int64_t>& max_len_kv_shape,
|
||||
const std::vector<int64_t>& fmha_out_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& rotary_embs_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
|
||||
@@ -987,7 +979,6 @@ std::vector<paddle::DataType> AppendAttentionWithOutputInferDtype(
|
||||
const paddle::DataType& decoder_tile_ids_per_batch_dtype,
|
||||
const paddle::DataType& decoder_num_blocks_dtype,
|
||||
const paddle::DataType& set_max_lengths_dtype,
|
||||
const paddle::DataType& max_len_kv_dtype,
|
||||
const paddle::DataType& fmha_out_dtype,
|
||||
const paddle::optional<paddle::DataType>& rotary_embs_dtype,
|
||||
const paddle::optional<paddle::DataType>& attn_mask_dtype,
|
||||
@@ -1046,7 +1037,6 @@ PD_BUILD_STATIC_OP(append_attention)
|
||||
"decoder_tile_ids_per_batch",
|
||||
"decoder_num_blocks",
|
||||
"set_max_lengths",
|
||||
"max_len_kv",
|
||||
paddle::Optional("rotary_embs"),
|
||||
paddle::Optional("attn_mask"),
|
||||
paddle::Optional("qkv_bias"),
|
||||
@@ -1105,7 +1095,6 @@ PD_BUILD_STATIC_OP(append_attention_with_output)
|
||||
"decoder_tile_ids_per_batch",
|
||||
"decoder_num_blocks",
|
||||
"set_max_lengths",
|
||||
"max_len_kv",
|
||||
"fmha_out",
|
||||
paddle::Optional("rotary_embs"),
|
||||
paddle::Optional("attn_mask"),
|
||||
|
||||
Reference in New Issue
Block a user