mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature]support swa & sink Based on appendattn (#6410)
* support swa & sink Based on appendattn
This commit is contained in:
@@ -1034,10 +1034,21 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
|
||||
8 * (reg_id / 4) + reg_id % 2;
|
||||
bool out_of_boundary;
|
||||
if (mask_offset) {
|
||||
out_of_boundary = q_idx < qo_len
|
||||
? (kv_idx >= mask_offset[q_idx * 2 + 1] ||
|
||||
kv_idx < mask_offset[q_idx * 2])
|
||||
: true;
|
||||
if (sliding_window > 0) {
|
||||
int swa_part = mask_offset[q_idx * 2 + 1] - sliding_window;
|
||||
if (swa_part < 0) swa_part = 0;
|
||||
int sink_part = mask_offset[q_idx * 2] + 128; // sink_size = 128
|
||||
out_of_boundary =
|
||||
q_idx < qo_len ? (kv_idx >= mask_offset[q_idx * 2 + 1] ||
|
||||
kv_idx < mask_offset[q_idx * 2] ||
|
||||
(kv_idx >= sink_part && kv_idx < swa_part))
|
||||
: true;
|
||||
} else {
|
||||
out_of_boundary = q_idx < qo_len
|
||||
? (kv_idx >= mask_offset[q_idx * 2 + 1] ||
|
||||
kv_idx < mask_offset[q_idx * 2])
|
||||
: true;
|
||||
}
|
||||
} else if (sliding_window > 0) {
|
||||
bool out_of_window = int(kv_idx) <= (int)kv_len + (int)q_idx -
|
||||
(int)qo_len -
|
||||
|
||||
Reference in New Issue
Block a user