mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2026-04-23 00:17:25 +08:00)
[BugFix] Fix batch_size derivation and relax shape checks in SM90 flash_mask_attn (#7210)
* [BugFix] fix_flash_mask_attn_sm90
* [BugFix] fix_flash_mask_attn_sm90
* [BugFix] Fix batch_size derivation and relax shape checks in SM90 flash_mask_attn
* [BugFix] Fix batch_size derivation and relax shape checks in SM90 flash_mask_attn
@@ -54,7 +54,6 @@ void DispatchFlashAttentionMask(const paddle::Tensor& q_input,
     PADDLE_ENFORCE(k_token_num == v_input.dims()[0], "Unmatched shape");
     PADDLE_ENFORCE(head_dim == 128, "Unmatched shape");
     PADDLE_ENFORCE(batch_size > 0, "Unmatched shape");
     PADDLE_ENFORCE(batch_size == seq_len_encoder.dims()[0], "Unmatched shape");
     PADDLE_ENFORCE(batch_size == cu_seq_k.dims()[0] - 1, "Unmatched shape");

     constexpr int kBlockM = 128;