[Feature] Supports SWA based on appendattn (#6547)

This commit is contained in:
AIbin
2026-03-01 19:02:08 +08:00
committed by GitHub
parent ea4d10d174
commit 59b578c337
17 changed files with 410 additions and 240 deletions
@@ -1018,7 +1018,8 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
const uint32_t attn_mask_len,
float (*s_frag)[num_frags_z][8],
const int* mask_offset = nullptr,
const int sliding_window = 0) {
const int sliding_window = 0,
const int sink_size = 0) {
const uint32_t tx = threadIdx.x;
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
@@ -1037,7 +1038,8 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
if (sliding_window > 0) {
int swa_part = mask_offset[q_idx * 2 + 1] - sliding_window;
if (swa_part < 0) swa_part = 0;
int sink_part = mask_offset[q_idx * 2] + 128; // sink_size = 128
int sink_part =
mask_offset[q_idx * 2] + sink_size; // sink_size = 128
out_of_boundary =
q_idx < qo_len ? (kv_idx >= mask_offset[q_idx * 2 + 1] ||
kv_idx < mask_offset[q_idx * 2] ||