mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 08:21:53 +08:00
[Feature]Supports SWA based on appendattn (#6547)
This commit is contained in:
@@ -1018,7 +1018,8 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
|
||||
const uint32_t attn_mask_len,
|
||||
float (*s_frag)[num_frags_z][8],
|
||||
const int* mask_offset = nullptr,
|
||||
const int sliding_window = 0) {
|
||||
const int sliding_window = 0,
|
||||
const int sink_size = 0) {
|
||||
const uint32_t tx = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
|
||||
@@ -1037,7 +1038,8 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
|
||||
if (sliding_window > 0) {
|
||||
int swa_part = mask_offset[q_idx * 2 + 1] - sliding_window;
|
||||
if (swa_part < 0) swa_part = 0;
|
||||
int sink_part = mask_offset[q_idx * 2] + 128; // sink_size = 128
|
||||
int sink_part =
|
||||
mask_offset[q_idx * 2] + sink_size; // sink_size = 128
|
||||
out_of_boundary =
|
||||
q_idx < qo_len ? (kv_idx >= mask_offset[q_idx * 2 + 1] ||
|
||||
kv_idx < mask_offset[q_idx * 2] ||
|
||||
|
||||
Reference in New Issue
Block a user