[BugFix] Fix batch_size derivation and relax shape checks in SM90 flash_mask_attn (#7210)

* [BugFix] fix_flash_mask_attn_sm90

* [BugFix] Fix batch_size derivation and relax shape checks in SM90 flash_mask_attn
Author: xiaoxiaohehe001
Date: 2026-04-09 11:05:10 +08:00 (committed via GitHub)
Parent: 43ace7af25
Commit: 51efe27d76
2 changed files with 1 addition and 2 deletions
@@ -309,7 +309,7 @@ class FlashMaskAttentionBackend(AttentionBackend):
             q,
             k,
             v,
-            forward_meta.cu_seqlens_q,
+            forward_meta.cu_seqlens_q[: forward_meta.attn_cu_seqlens_k.shape[0]],
             forward_meta.attn_cu_seqlens_k,
             forward_meta.seq_lens_encoder,
             res_encoder,
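The one-line change above can be illustrated with a minimal sketch (names and buffer layout are assumptions for illustration, not taken from the repository): varlen attention kernels take cumulative-sequence-length arrays of length batch_size + 1, one offset per sequence boundary. If `cu_seqlens_q` is allocated for a larger maximum batch while `attn_cu_seqlens_k` holds exactly batch_size + 1 entries, passing the full query buffer makes the kernel derive the wrong batch size; slicing it to the key array's length keeps the two consistent.

```python
import numpy as np

def align_cu_seqlens(cu_seqlens_q: np.ndarray,
                     attn_cu_seqlens_k: np.ndarray) -> np.ndarray:
    """Slice the (possibly padded) query offsets down to the key offsets'
    length, so both arrays describe the same number of sequences.
    Hypothetical helper mirroring the slice in the diff."""
    return cu_seqlens_q[: attn_cu_seqlens_k.shape[0]]

# Example: 2 real sequences of lengths 3 and 5, but the query-offset
# buffer was sized for a maximum batch of 4 (trailing entries repeat
# the final cumulative total).
cu_seqlens_q = np.array([0, 3, 8, 8, 8], dtype=np.int32)  # padded buffer
attn_cu_seqlens_k = np.array([0, 3, 8], dtype=np.int32)   # batch_size + 1 = 3

aligned = align_cu_seqlens(cu_seqlens_q, attn_cu_seqlens_k)
print(aligned)  # [0 3 8]  -> batch_size derived as len(aligned) - 1 == 2
```

With the aligned arrays, the kernel's shape check `cu_seqlens_q.shape == cu_seqlens_k.shape` holds and the batch size is derived from the true number of sequences rather than the padded buffer length.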