[Optimization] Fuse get_max_len and get_kv_max_len (#4369)

* Optimize split_q_block

* Fuse max_lens and max_kv_len into a single pass
Author: Sunny-bot1
Date: 2025-10-13 20:35:00 +08:00
Committed by: GitHub
Parent: 425205b03c
Commit: a751d977bc
15 changed files with 29 additions and 116 deletions
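
The substance of the change: the standalone max_len_kv tensor is gone. The pass that computes the per-step maxima now also reduces the decoder KV length and writes it into slot 8 of set_max_lengths, so the attention path reads one fused buffer instead of two. Below is a minimal sketch of the fusion idea, not the repository's actual kernel: the kernel name, block size, slot assignments other than index 8, and the exact definition of the KV length are all assumptions for illustration.

  #include <cub/cub.cuh>

  // Sketch: one block reduces every per-batch length and writes all of the
  // "max len" scalars, including the decoder KV maximum that used to live
  // in the separate max_len_kv tensor, into the single set_max_lengths buffer.
  __global__ void FusedGetMaxLenKernel(const int* seq_lens_encoder,
                                       const int* seq_lens_decoder,
                                       const int* seq_lens_this_time,
                                       int* set_max_lengths,  // >= 9 ints
                                       int bsz) {
    typedef cub::BlockReduce<int, 256> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp;
    int enc = 0, kv = 0;
    for (int i = threadIdx.x; i < bsz; i += blockDim.x) {
      enc = max(enc, seq_lens_encoder[i]);
      kv = max(kv, seq_lens_decoder[i] + seq_lens_this_time[i]);  // assumed KV definition
    }
    int max_enc = BlockReduce(temp).Reduce(enc, cub::Max());
    __syncthreads();  // temp storage is reused for the second reduction
    int max_kv = BlockReduce(temp).Reduce(kv, cub::Max());
    if (threadIdx.x == 0) {
      set_max_lengths[0] = max_enc;  // assumed slot
      set_max_lengths[8] = max_kv;   // max_kv_len_this_time, read by the host below
    }
  }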
+4 -15
@@ -59,7 +59,6 @@ void AppendAttentionKernel(
     const paddle::Tensor& decoder_tile_ids_per_batch,
     const paddle::Tensor& decoder_num_blocks,
     const paddle::Tensor& set_max_lengths,
-    const paddle::Tensor& max_len_kv,
     paddle::Tensor& fmha_out,
     const paddle::optional<paddle::Tensor>& rotary_embs,
     const paddle::optional<paddle::Tensor>& attn_mask,
@@ -103,6 +102,7 @@ void AppendAttentionKernel(
   int max_dec_len_this_time = set_max_lengths.data<int>()[2];
   int max_enc_dec_len_this_time = set_max_lengths.data<int>()[3];
   int max_just_dec_len_this_time = set_max_lengths.data<int>()[4];
+  int max_kv_len_this_time = set_max_lengths.data<int>()[8];
   auto main_stream = qkv.stream();
   static cudaEvent_t main_event;
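
On the host side the fused value is just one more indexed load from the already-populated set_max_lengths buffer. A sketch of the read pattern; only slots 2, 3, 4, and 8 are visible in this diff, so treat the rest of the layout as unknown:

  const int* max_lens = set_max_lengths.data<int>();
  int max_dec_len_this_time      = max_lens[2];
  int max_enc_dec_len_this_time  = max_lens[3];
  int max_just_dec_len_this_time = max_lens[4];
  int max_kv_len_this_time       = max_lens[8];  // replaces max_len_kv.data<int>()[0]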
@@ -245,7 +245,6 @@ void AppendAttentionKernel(
   if (max_just_dec_len_this_time > 0) {
     int decoder_num_blocks_data = decoder_num_blocks.data<int>()[0];
-    int max_len_kv_data = max_len_kv.data<int>()[0];
     cudaStream_t exec_stream;
     if (max_enc_len_this_time > 0) {
@@ -371,20 +370,20 @@ void AppendAttentionKernel(
       case paddle::DataType::INT8:{
         int8_t tmp;
         dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
-            decoder_block_shape_q, max_len_kv_data, !speculate_decoder, !speculate_decoder, exec_stream);
+            decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream);
         break;
       }
       case paddle::DataType::FLOAT8_E4M3FN:{
         phi::dtype::float8_e4m3fn tmp;
         dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
-            decoder_block_shape_q, max_len_kv_data, !speculate_decoder, !speculate_decoder, exec_stream);
+            decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream);
         break;
       }
     }
   } else {
     data_t tmp;
     dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
-        decoder_block_shape_q, max_len_kv_data, !speculate_decoder, !speculate_decoder, exec_stream);
+        decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream);
   }
   if (max_enc_len_this_time > 0) {
     cudaEventRecord(decoder_event, exec_stream);
@@ -413,7 +412,6 @@ std::vector<paddle::Tensor> AppendAttention(
     const paddle::Tensor& decoder_tile_ids_per_batch,
     const paddle::Tensor& decoder_num_blocks,
     const paddle::Tensor& set_max_lengths,
-    const paddle::Tensor& max_len_kv,
     const paddle::optional<paddle::Tensor>& rotary_embs,
     const paddle::optional<paddle::Tensor>& attn_mask,
     const paddle::optional<paddle::Tensor>& qkv_bias,
@@ -539,7 +537,6 @@ std::vector<paddle::Tensor> AppendAttention(
       decoder_tile_ids_per_batch,
       decoder_num_blocks,
       set_max_lengths,
-      max_len_kv,
       fmha_out,
       rotary_embs,
       attn_mask,
@@ -616,7 +613,6 @@ void AppendAttentionWithOutput(
     const paddle::Tensor& decoder_tile_ids_per_batch,
     const paddle::Tensor& decoder_num_blocks,
     const paddle::Tensor& set_max_lengths,
-    const paddle::Tensor& max_len_kv,
     paddle::Tensor& fmha_out,
     const paddle::optional<paddle::Tensor>& rotary_embs,
     const paddle::optional<paddle::Tensor>& attn_mask,
@@ -695,7 +691,6 @@ void AppendAttentionWithOutput(
       decoder_tile_ids_per_batch,
      decoder_num_blocks,
       set_max_lengths,
-      max_len_kv,
       fmha_out,
       rotary_embs,
       attn_mask,
@@ -784,7 +779,6 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
     const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
     const std::vector<int64_t>& decoder_num_blocks_shape,
     const std::vector<int64_t>& set_max_lengths_shape,
-    const std::vector<int64_t>& max_len_kv_shape,
     const paddle::optional<std::vector<int64_t>>& rotary_embs_shape,
     const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
     const paddle::optional<std::vector<int64_t>>& qkv_bias_shape,
@@ -848,7 +842,6 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
     const paddle::DataType& decoder_tile_ids_per_batch_dtype,
     const paddle::DataType& decoder_num_blocks_dtype,
     const paddle::DataType& set_max_lengths_dtype,
-    const paddle::DataType& max_len_kv_dtype,
     const paddle::optional<paddle::DataType>& rotary_embs_dtype,
     const paddle::optional<paddle::DataType>& attn_mask_dtype,
     const paddle::optional<paddle::DataType>& qkv_bias_dtype,
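
These InferShape/InferDtype parameter lists mirror the op's input list positionally, so max_len_kv has to disappear here as well as in the registrations below. A schematic of that contract in the Paddle custom-op API, with the argument lists abbreviated:

  PD_BUILD_STATIC_OP(append_attention)
      .Inputs({/* ..., */ "set_max_lengths" /* no "max_len_kv" */ /* , ... */})
      .SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))   // arity must match Inputs
      .SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionInferDtype));  // arity must match Inputs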
@@ -930,7 +923,6 @@ std::vector<std::vector<int64_t>> AppendAttentionWithOutputInferShape(
     const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
     const std::vector<int64_t>& decoder_num_blocks_shape,
     const std::vector<int64_t>& set_max_lengths_shape,
-    const std::vector<int64_t>& max_len_kv_shape,
     const std::vector<int64_t>& fmha_out_shape,
     const paddle::optional<std::vector<int64_t>>& rotary_embs_shape,
     const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
@@ -987,7 +979,6 @@ std::vector<paddle::DataType> AppendAttentionWithOutputInferDtype(
     const paddle::DataType& decoder_tile_ids_per_batch_dtype,
     const paddle::DataType& decoder_num_blocks_dtype,
     const paddle::DataType& set_max_lengths_dtype,
-    const paddle::DataType& max_len_kv_dtype,
     const paddle::DataType& fmha_out_dtype,
     const paddle::optional<paddle::DataType>& rotary_embs_dtype,
     const paddle::optional<paddle::DataType>& attn_mask_dtype,
@@ -1046,7 +1037,6 @@ PD_BUILD_STATIC_OP(append_attention)
         "decoder_tile_ids_per_batch",
         "decoder_num_blocks",
         "set_max_lengths",
-        "max_len_kv",
         paddle::Optional("rotary_embs"),
         paddle::Optional("attn_mask"),
         paddle::Optional("qkv_bias"),
@@ -1105,7 +1095,6 @@ PD_BUILD_STATIC_OP(append_attention_with_output)
         "decoder_tile_ids_per_batch",
         "decoder_num_blocks",
         "set_max_lengths",
-        "max_len_kv",
         "fmha_out",
         paddle::Optional("rotary_embs"),
         paddle::Optional("attn_mask"),
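
Call sites change mechanically: whatever used to pass the extra tensor simply drops it. A before/after sketch with the surrounding arguments elided (the real op takes many more):

  // before: a separate tensor carried the decoder KV maximum
  // AppendAttention(/* ... */ set_max_lengths, max_len_kv, rotary_embs /* ... */);
  // after: the value rides along in set_max_lengths[8]
  AppendAttention(/* ... */ set_max_lengths, rotary_embs /* ... */);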