make append_attn supports mask_offset (#3138)

* make append_attn supports mask_offset

* add unittest
This commit is contained in:
lzy
2025-08-14 18:40:55 +08:00
committed by GitHub
parent 6031f9a5f5
commit 1e06b9fa6d
10 changed files with 88 additions and 20 deletions
@@ -43,6 +43,7 @@ __global__ void multi_query_append_attention_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -141,6 +142,7 @@ __global__ void multi_query_append_attention_kernel(
} else {
o_base_ptr_int8 = out + o_offset;
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -179,7 +181,7 @@ __global__ void multi_query_append_attention_kernel(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: chunk_len) /
: mask_offset ? 0 : chunk_len) /
(num_frags_z * 16);
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
8 * (tid / 16) + tid % 8, (tid % 16) / 8);
@@ -250,7 +252,8 @@ __global__ void multi_query_append_attention_kernel(
q_len,
kv_len,
chunk_end,
s_frag);
s_frag,
mask_offset_this_seq);
}
// update m,d
@@ -406,6 +409,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -502,7 +506,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
tid % 8 * num_elems_per_128b<T>();
}
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -543,7 +547,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: chunk_len) /
: mask_offset ? 0 : chunk_len) /
(NUM_WARP_KV * num_frags_z * 16);
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -616,7 +620,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
q_len,
kv_len,
chunk_end,
s_frag);
s_frag,
mask_offset_this_seq);
}
// update m,d
@@ -882,6 +887,7 @@ void MultiQueryAppendAttention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -939,6 +945,7 @@ void MultiQueryAppendAttention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1103,6 +1110,7 @@ void MultiQueryAppendAttention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1171,6 +1179,7 @@ void MultiQueryAppendAttention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -48,6 +48,7 @@ __global__ void multi_query_append_attention_c4_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -172,6 +173,7 @@ __global__ void multi_query_append_attention_c4_kernel(
} else {
o_base_ptr_int8 = out + o_offset;
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -248,7 +250,7 @@ __global__ void multi_query_append_attention_c4_kernel(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: chunk_len) /
: mask_offset ? 0 : chunk_len) /
(num_frags_z * 16);
uint32_t k_smem_offset_r =
@@ -338,7 +340,8 @@ __global__ void multi_query_append_attention_c4_kernel(
q_len,
kv_len,
chunk_end,
s_frag);
s_frag,
mask_offset_this_seq);
}
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
@@ -505,6 +508,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -627,7 +631,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
tid % 8 * num_elems_per_128b<T>();
}
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -706,7 +710,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: chunk_len) /
: mask_offset ? 0 : chunk_len) /
(NUM_WARP_KV * num_frags_z * 16);
uint32_t k_smem_offset_r =
@@ -793,7 +797,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
q_len,
kv_len,
chunk_end,
s_frag);
s_frag,
mask_offset_this_seq);
}
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
@@ -1088,6 +1093,7 @@ void MultiQueryAppendC4Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1151,6 +1157,7 @@ void MultiQueryAppendC4Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1335,6 +1342,7 @@ void MultiQueryAppendC4Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1411,6 +1419,7 @@ void MultiQueryAppendC4Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -48,6 +48,7 @@ __global__ void multi_query_append_attention_c8_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -179,6 +180,7 @@ __global__ void multi_query_append_attention_c8_kernel(
} else {
o_base_ptr_int8 = out + o_offset;
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -216,7 +218,7 @@ __global__ void multi_query_append_attention_c8_kernel(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: chunk_len) /
: mask_offset ? 0 : chunk_len) /
(num_frags_z * 16);
uint32_t k_smem_offset_r =
@@ -305,7 +307,8 @@ __global__ void multi_query_append_attention_c8_kernel(
q_len,
kv_len,
chunk_end,
s_frag);
s_frag,
mask_offset_this_seq);
}
// update m,d
@@ -474,6 +477,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -601,7 +605,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
tid % 8 * num_elems_per_128b<T>();
}
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -642,7 +646,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: chunk_len) /
: mask_offset ? 0 : chunk_len) /
(NUM_WARP_KV * num_frags_z * 16);
uint32_t k_smem_offset_r =
@@ -733,7 +737,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
q_len,
kv_len,
chunk_end,
s_frag);
s_frag,
mask_offset_this_seq);
}
// update m,d
@@ -1054,6 +1059,7 @@ void MultiQueryAppendC8Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1111,6 +1117,7 @@ void MultiQueryAppendC8Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1318,6 +1325,7 @@ void MultiQueryAppendC8Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1388,6 +1396,7 @@ void MultiQueryAppendC8Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -910,7 +910,8 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
const uint32_t qo_len,
const uint32_t kv_len,
const uint32_t chunk_end,
float (*s_frag)[num_frags_z][8]) {
float (*s_frag)[num_frags_z][8],
const int *mask_offset = nullptr) {
const uint32_t tx = threadIdx.x;
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
@@ -924,10 +925,15 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
group_size,
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +
8 * (reg_id / 4) + reg_id % 2;
const bool out_of_boundary =
(causal
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
bool out_of_boundary;
if (mask_offset) {
out_of_boundary = q_idx < qo_len ? (kv_idx > mask_offset[q_idx]) : true;
} else {
out_of_boundary =
(causal
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
}
if constexpr (std::is_same<T, half>::value) {
s_frag[fx][fz][reg_id] =
out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id];
+4
View File
@@ -27,6 +27,7 @@ struct AppendAttnMetaData {
int head_dims;
int head_dims_v;
int max_blocks_per_seq;
const int *mask_offset = nullptr;
};
__forceinline__ __host__ __device__ int div_up(int a, int b) {
@@ -477,6 +478,9 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
if (causal) { \
constexpr bool CAUSAL = true; \
__VA_ARGS__ \
} else { \
constexpr bool CAUSAL = false; \
__VA_ARGS__ \
}
#define DISPATCH_ENABLE_PREFILL(enable_prefill, ENABLE_PREFILL, ...) \