From 51efe27d76a1895281d9dc3d934a3eb1053f8aa8 Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 <49090790+xiaoxiaohehe001@users.noreply.github.com> Date: Thu, 9 Apr 2026 11:05:10 +0800 Subject: [PATCH] [BugFix] Fix batch_size derivation and relax shape checks in SM90 flash_mask_attn (#7210) * [BugFix] fix_flash_mask_attn_sm90 * [BugFix] fix_flash_mask_attn_sm90 * [BugFix] Fix batch_size derivation and relax shape checks in SM90 flash_mask_attn * [BugFix] Fix batch_size derivation and relax shape checks in SM90 flash_mask_attn --- custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu | 1 - .../model_executor/layers/attention/flash_mask_attn_backend.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu b/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu index b0ca5e2c0c..dce65b9727 100644 --- a/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu +++ b/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu @@ -54,7 +54,6 @@ void DispatchFlashAttentionMask(const paddle::Tensor& q_input, PADDLE_ENFORCE(k_token_num == v_input.dims()[0], "Unmatched shape"); PADDLE_ENFORCE(head_dim == 128, "Unmatched shape"); PADDLE_ENFORCE(batch_size > 0, "Unmatched shape"); - PADDLE_ENFORCE(batch_size == seq_len_encoder.dims()[0], "Unmatched shape"); PADDLE_ENFORCE(batch_size == cu_seq_k.dims()[0] - 1, "Unmatched shape"); constexpr int kBlockM = 128; diff --git a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py index 6fb0a5c124..bdb018ec26 100644 --- a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py @@ -309,7 +309,7 @@ class FlashMaskAttentionBackend(AttentionBackend): q, k, v, - forward_meta.cu_seqlens_q, + forward_meta.cu_seqlens_q[: forward_meta.attn_cu_seqlens_k.shape[0]], forward_meta.attn_cu_seqlens_k, forward_meta.seq_lens_encoder, res_encoder,