[BugFix] Fix batch_size derivation and relax shape checks in SM90 flash_mask_attn (#7210)

* [BugFix] fix_flash_mask_attn_sm90

* [BugFix] Fix batch_size derivation and relax shape checks in SM90 flash_mask_attn
xiaoxiaohehe001
2026-04-09 11:05:10 +08:00
committed by GitHub
parent 43ace7af25
commit 51efe27d76
2 changed files with 1 addition and 2 deletions
@@ -54,7 +54,6 @@ void DispatchFlashAttentionMask(const paddle::Tensor& q_input,
PADDLE_ENFORCE(k_token_num == v_input.dims()[0], "Unmatched shape");
PADDLE_ENFORCE(head_dim == 128, "Unmatched shape");
PADDLE_ENFORCE(batch_size > 0, "Unmatched shape");
PADDLE_ENFORCE(batch_size == seq_len_encoder.dims()[0], "Unmatched shape");
PADDLE_ENFORCE(batch_size == cu_seq_k.dims()[0] - 1, "Unmatched shape");
constexpr int kBlockM = 128;