[Feature] Supports SWA based on appendattn (#6547)

This commit is contained in:
AIbin
2026-03-01 19:02:08 +08:00
committed by GitHub
parent ea4d10d174
commit 59b578c337
17 changed files with 410 additions and 240 deletions
@@ -61,7 +61,8 @@ void CascadeAppendAttentionC16Kernel(
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window) {
const int sliding_window,
const int sink_size = 0) {
const auto token_num = meta_data.token_nums;
const auto block_size = meta_data.block_size;
const auto num_heads = meta_data.q_num_heads;
@@ -122,7 +123,8 @@ void CascadeAppendAttentionC16Kernel(
is_decoder,
stream,
out,
sliding_window);
sliding_window,
sink_size);
})})})})})})
}
@@ -171,7 +173,8 @@ CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::bfloat16>(
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window);
const int sliding_window,
const int sink_size);
template void
CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::float8_e4m3fn>(
@@ -218,7 +221,8 @@ CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::float8_e4m3fn>(
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window);
const int sliding_window,
const int sink_size);
template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, int8_t>(
const AppendAttnMetaData& meta_data,
@@ -264,7 +268,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, int8_t>(
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window);
const int sliding_window,
const int sink_size);
template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float16>(
const AppendAttnMetaData& meta_data,
@@ -310,7 +315,8 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float16>(
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window);
const int sliding_window,
const int sink_size);
template void
CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float8_e4m3fn>(
@@ -357,7 +363,8 @@ CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float8_e4m3fn>(
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window);
const int sliding_window,
const int sink_size);
template void CascadeAppendAttentionC16Kernel<paddle::float16, int8_t>(
const AppendAttnMetaData& meta_data,
@@ -403,4 +410,5 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, int8_t>(
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window);
const int sliding_window,
const int sink_size);
@@ -61,7 +61,8 @@ void CascadeAppendAttentionC4Kernel(
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window) {
const int sliding_window = 0,
const int sink_size = 0) {
const auto token_num = meta_data.token_nums;
const auto block_size = meta_data.block_size;
const auto num_heads = meta_data.q_num_heads;
@@ -126,7 +127,8 @@ void CascadeAppendAttentionC4Kernel(
is_decoder,
stream,
out,
sliding_window);
sliding_window,
sink_size);
})})})})})})
}
@@ -175,7 +177,8 @@ CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::bfloat16>(
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window);
const int sliding_window,
const int sink_size);
template void
CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::float8_e4m3fn>(
@@ -222,7 +225,8 @@ CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::float8_e4m3fn>(
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window);
const int sliding_window,
const int sink_size);
template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, int8_t>(
const AppendAttnMetaData& meta_data,
@@ -268,7 +272,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, int8_t>(
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window);
const int sliding_window,
const int sink_size);
template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float16>(
const AppendAttnMetaData& meta_data,
@@ -314,7 +319,8 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float16>(
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window);
const int sliding_window,
const int sink_size);
template void
CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float8_e4m3fn>(
@@ -361,7 +367,8 @@ CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float8_e4m3fn>(
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window);
const int sliding_window,
const int sink_size);
template void CascadeAppendAttentionC4Kernel<paddle::float16, int8_t>(
const AppendAttnMetaData& meta_data,
@@ -407,4 +414,5 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, int8_t>(
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window);
const int sliding_window,
const int sink_size);
@@ -62,7 +62,8 @@ void CascadeAppendAttentionC8Kernel(
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window) {
const int sliding_window = 0,
const int sink_size = 0) {
const auto token_num = meta_data.token_nums;
const auto block_size = meta_data.block_size;
const auto num_heads = meta_data.q_num_heads;
@@ -131,7 +132,8 @@ void CascadeAppendAttentionC8Kernel(
is_decoder,
stream,
out,
sliding_window);
sliding_window,
sink_size);
})})})})})})})
}
@@ -173,7 +175,8 @@ CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, false>(
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window);
const int sliding_window,
const int sink_size);
template void
CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, true>(
@@ -213,7 +216,8 @@ CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, true>(
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window);
const int sliding_window,
const int sink_size);
template void
CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m3fn, false>(
@@ -253,7 +257,8 @@ CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m3fn, false>(
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window);
const int sliding_window,
const int sink_size);
template void
CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m3fn, true>(
@@ -293,7 +298,8 @@ CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m3fn, true>(
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window);
const int sliding_window,
const int sink_size);
template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, false>(
const AppendAttnMetaData& meta_data,
@@ -332,7 +338,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, false>(
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window);
const int sliding_window,
const int sink_size);
template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, true>(
const AppendAttnMetaData& meta_data,
@@ -371,7 +378,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, true>(
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window);
const int sliding_window,
const int sink_size);
template void
CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, false>(
@@ -411,7 +419,8 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, false>(
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window);
const int sliding_window,
const int sink_size);
template void
CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, true>(
@@ -451,7 +460,8 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, true>(
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window);
const int sliding_window,
const int sink_size);
template void
CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4m3fn, false>(
@@ -491,7 +501,8 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4m3fn, false>(
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window);
const int sliding_window,
const int sink_size);
template void
CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4m3fn, true>(
@@ -531,7 +542,8 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4m3fn, true>(
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window);
const int sliding_window,
const int sink_size);
template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, false>(
const AppendAttnMetaData& meta_data,
@@ -570,7 +582,8 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, false>(
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window);
const int sliding_window,
const int sink_size);
template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, true>(
const AppendAttnMetaData& meta_data,
@@ -609,4 +622,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, true>(
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window);
const int sliding_window,
const int sink_size);
@@ -1018,7 +1018,8 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
const uint32_t attn_mask_len,
float (*s_frag)[num_frags_z][8],
const int* mask_offset = nullptr,
const int sliding_window = 0) {
const int sliding_window = 0,
const int sink_size = 0) {
const uint32_t tx = threadIdx.x;
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
@@ -1037,7 +1038,8 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
if (sliding_window > 0) {
int swa_part = mask_offset[q_idx * 2 + 1] - sliding_window;
if (swa_part < 0) swa_part = 0;
int sink_part = mask_offset[q_idx * 2] + 128; // sink_size = 128
int sink_part =
mask_offset[q_idx * 2] + sink_size;  // sink region now parameterized via sink_size (was hard-coded 128)
out_of_boundary =
q_idx < qo_len ? (kv_idx >= mask_offset[q_idx * 2 + 1] ||
kv_idx < mask_offset[q_idx * 2] ||
@@ -13,11 +13,11 @@
// limitations under the License.
#pragma once
#include "append_attention_c16_impl.cuh"
#include "append_attention_c4_impl.cuh"
#include "append_attention_c8_impl.cuh"
#include "helper.h"
#include "utils.cuh"
#include "append_attention_c16_impl.cuh"
#include "append_attention_c8_impl.cuh"
#include "append_attention_c4_impl.cuh"
template <typename T, typename OutT>
void CascadeAppendAttentionKernel(
@@ -39,9 +39,8 @@ void CascadeAppendAttentionKernel(
const paddle::optional<paddle::Tensor>&
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -66,160 +65,167 @@ void CascadeAppendAttentionKernel(
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window) {
if (cache_quant_type_str == "none") {
CascadeAppendAttentionC16Kernel<T, OutT>(meta_data,
qkv,
cache_k,
cache_v,
attn_mask,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
shift_bias,
smooth_weight,
sinks,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
num_blocks,
block_shape_q,
max_seq_len,
max_dec_len,
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
is_decoder,
enable_prefill,
stream,
out,
sliding_window);
} else if (cache_quant_type_str == "cache_int8") {
CascadeAppendAttentionC8Kernel<T, OutT, false>(meta_data,
qkv,
cache_k,
cache_v,
attn_mask,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
shift_bias,
smooth_weight,
sinks,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
num_blocks,
block_shape_q,
max_seq_len,
max_dec_len,
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
is_decoder,
enable_prefill,
cache_quant_type_str,
stream,
out,
sliding_window);
} else if (cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
CascadeAppendAttentionC8Kernel<T, OutT, true>(meta_data,
qkv,
cache_k,
cache_v,
attn_mask,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
shift_bias,
smooth_weight,
sinks,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
num_blocks,
block_shape_q,
max_seq_len,
max_dec_len,
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
is_decoder,
enable_prefill,
cache_quant_type_str,
stream,
out,
sliding_window);
} else if (cache_quant_type_str == "cache_int4_zp") {
CascadeAppendAttentionC4Kernel<T, OutT>(meta_data,
qkv,
cache_k,
cache_v,
attn_mask,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
shift_bias,
smooth_weight,
sinks,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
num_blocks,
block_shape_q,
max_seq_len,
max_dec_len,
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
is_decoder,
enable_prefill,
stream,
out,
sliding_window);
} else {
PD_THROW(
"cache_quant_type_str should be one of [none, cache_int8, "
"cache_int4_zp]");
}
const int sliding_window = 0,
const int sink_size = 0) {
if (cache_quant_type_str == "none") {
CascadeAppendAttentionC16Kernel<T, OutT>(meta_data,
qkv,
cache_k,
cache_v,
attn_mask,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
shift_bias,
smooth_weight,
sinks,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
num_blocks,
block_shape_q,
max_seq_len,
max_dec_len,
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
is_decoder,
enable_prefill,
stream,
out,
sliding_window,
sink_size);
} else if (cache_quant_type_str == "cache_int8") {
CascadeAppendAttentionC8Kernel<T, OutT, false>(
meta_data,
qkv,
cache_k,
cache_v,
attn_mask,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
shift_bias,
smooth_weight,
sinks,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
num_blocks,
block_shape_q,
max_seq_len,
max_dec_len,
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
is_decoder,
enable_prefill,
cache_quant_type_str,
stream,
out,
sliding_window,
sink_size);
} else if (cache_quant_type_str == "cache_fp8" or
cache_quant_type_str == "block_wise_fp8") {
CascadeAppendAttentionC8Kernel<T, OutT, true>(meta_data,
qkv,
cache_k,
cache_v,
attn_mask,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
shift_bias,
smooth_weight,
sinks,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
num_blocks,
block_shape_q,
max_seq_len,
max_dec_len,
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
is_decoder,
enable_prefill,
cache_quant_type_str,
stream,
out,
sliding_window,
sink_size);
} else if (cache_quant_type_str == "cache_int4_zp") {
CascadeAppendAttentionC4Kernel<T, OutT>(meta_data,
qkv,
cache_k,
cache_v,
attn_mask,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
shift_bias,
smooth_weight,
sinks,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
num_blocks,
block_shape_q,
max_seq_len,
max_dec_len,
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
is_decoder,
enable_prefill,
stream,
out,
sliding_window,
sink_size);
} else {
PD_THROW(
"cache_quant_type_str should be one of [none, cache_int8, "
"cache_int4_zp]");
}
}
@@ -58,7 +58,8 @@ __global__ void multi_query_append_attention_kernel(
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5,
const int sliding_window = 0) {
const int sliding_window = 0,
const int sink_size = 0) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z;
const uint32_t kv_num_heads = gridDim.z;
@@ -269,7 +270,8 @@ __global__ void multi_query_append_attention_kernel(
-1,
s_frag,
mask_offset_this_seq,
sliding_window);
sliding_window,
sink_size);
}
// update m,d
@@ -463,7 +465,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5,
const uint32_t attn_mask_len = -1,
const int sliding_window = 0) {
const int sliding_window = 0,
const int sink_size = 0) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1");
static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4");
@@ -683,7 +686,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
attn_mask_len,
s_frag,
mask_offset_this_seq,
sliding_window);
sliding_window,
sink_size);
}
// update m,d
@@ -878,7 +882,8 @@ void MultiQueryAppendAttention(
const bool is_decoder,
cudaStream_t &stream,
paddle::Tensor *out,
const int sliding_window) {
const int sliding_window,
const int sink_size = 0) {
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
using OUT_NV_TYPE = typename cascade_attn_type_traits<OutT>::type;
@@ -987,7 +992,8 @@ void MultiQueryAppendAttention(
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
sliding_window);
sliding_window,
sink_size);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
@@ -1054,7 +1060,8 @@ void MultiQueryAppendAttention(
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
sliding_window);
sliding_window,
sink_size);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
constexpr int blockx = HEAD_DIM / vec_size;
@@ -1208,7 +1215,8 @@ void MultiQueryAppendAttention(
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len,
sliding_window);
sliding_window,
sink_size);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (is_decoder) {
@@ -1290,7 +1298,8 @@ void MultiQueryAppendAttention(
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len,
sliding_window);
sliding_window,
sink_size);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
@@ -53,4 +53,5 @@ void MultiQueryAppendAttention(
const bool is_decoder,
cudaStream_t &stream,
paddle::Tensor *out,
const int sliding_window);
const int sliding_window,
const int sink_size);
@@ -64,7 +64,8 @@ __global__ void multi_query_append_attention_c4_kernel(
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5,
const int sliding_window = 0) {
const int sliding_window = 0,
const int sink_size = 0) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
constexpr uint32_t num_vecs_per_head_k =
HEAD_DIM / 2 / num_elems_per_128b<CacheT>();
@@ -353,7 +354,8 @@ __global__ void multi_query_append_attention_c4_kernel(
-1,
s_frag,
mask_offset_this_seq,
sliding_window);
sliding_window,
sink_size);
}
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
@@ -560,7 +562,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5,
const uint32_t attn_mask_len = -1,
const int sliding_window = 0) {
const int sliding_window = 0,
const int sink_size = 0) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
constexpr uint32_t num_vecs_per_head_k =
HEAD_DIM / 2 / num_elems_per_128b<CacheT>();
@@ -852,7 +855,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
attn_mask_len,
s_frag,
mask_offset_this_seq,
sliding_window);
sliding_window,
sink_size);
}
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
@@ -1057,7 +1061,8 @@ void MultiQueryAppendC4Attention(
const bool is_decoder,
cudaStream_t &stream,
paddle::Tensor *out,
const int sliding_window) {
const int sliding_window = 0,
const int sink_size = 0) {
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
using OUT_NV_TYPE = typename cascade_attn_type_traits<OutT>::type;
@@ -1189,7 +1194,8 @@ void MultiQueryAppendC4Attention(
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
sliding_window);
sliding_window,
sink_size);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (ENABLE_PREFILL) {
@@ -1263,7 +1269,8 @@ void MultiQueryAppendC4Attention(
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
sliding_window);
sliding_window,
sink_size);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
@@ -1437,7 +1444,8 @@ void MultiQueryAppendC4Attention(
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len,
sliding_window);
sliding_window,
sink_size);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (is_decoder) {
@@ -1528,7 +1536,8 @@ void MultiQueryAppendC4Attention(
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len,
sliding_window);
sliding_window,
sink_size);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
if (is_decoder) {
@@ -57,4 +57,5 @@ void MultiQueryAppendC4Attention(
const bool is_decoder,
cudaStream_t &stream,
paddle::Tensor *out,
const int sliding_window);
const int sliding_window,
const int sink_size);
@@ -67,7 +67,8 @@ __global__ void multi_query_append_attention_c8_kernel(
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5,
const int sliding_window = 0) {
const int sliding_window = 0,
const int sink_size = 0) {
constexpr uint32_t num_vecs_per_head =
HEAD_DIM / num_elems_per_128b<T>(); // 128 / 8 = 16
constexpr uint32_t num_vecs_per_head_k =
@@ -369,7 +370,8 @@ __global__ void multi_query_append_attention_c8_kernel(
-1,
s_frag,
mask_offset_this_seq,
sliding_window);
sliding_window,
sink_size);
}
// update m,d
@@ -607,7 +609,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5,
const uint32_t attn_mask_len = -1,
const int sliding_window = 0) {
const int sliding_window = 0,
const int sink_size = 0) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
constexpr uint32_t num_vecs_per_head_k =
HEAD_DIM / num_elems_per_128b<CacheT>();
@@ -918,7 +921,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
attn_mask_len,
s_frag,
mask_offset_this_seq,
sliding_window);
sliding_window,
sink_size);
}
// update m,d
@@ -1148,7 +1152,8 @@ void MultiQueryAppendC8Attention(
const bool is_decoder,
cudaStream_t &stream,
paddle::Tensor *out,
const int sliding_window) {
const int sliding_window,
const int sink_size = 0) {
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
using OUT_NV_TYPE = typename cascade_attn_type_traits<OutT>::type;
@@ -1316,7 +1321,8 @@ void MultiQueryAppendC8Attention(
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
sliding_window);
sliding_window,
sink_size);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (ENABLE_PREFILL) {
@@ -1384,7 +1390,8 @@ void MultiQueryAppendC8Attention(
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
sliding_window);
sliding_window,
sink_size);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
constexpr int blockx = HEAD_DIM / vec_size;
@@ -1589,7 +1596,8 @@ void MultiQueryAppendC8Attention(
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len,
sliding_window);
sliding_window,
sink_size);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (is_decoder) {
@@ -1674,7 +1682,8 @@ void MultiQueryAppendC8Attention(
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len,
sliding_window);
sliding_window,
sink_size);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
if (is_decoder) {
@@ -57,4 +57,5 @@ void MultiQueryAppendC8Attention(
const bool is_decoder,
cudaStream_t &stream,
paddle::Tensor *out,
const int sliding_window);
const int sliding_window,
const int sink_size);
@@ -36,7 +36,7 @@
],
"max_instances_per_file": 80,
"file_prefix": "multiquery_attention_c8_",
"function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional<paddle::Tensor> &attn_mask,\n const paddle::Tensor &cache_k_scale,\n const paddle::Tensor &cache_v_scale,\n const paddle::optional<paddle::Tensor> &shift_bias,\n const paddle::optional<paddle::Tensor> &smooth_weight,\n const paddle::optional<paddle::Tensor> &sinks,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out,\n const int sliding_window);\n\n"
"function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional<paddle::Tensor> &attn_mask,\n const paddle::Tensor &cache_k_scale,\n const paddle::Tensor &cache_v_scale,\n const paddle::optional<paddle::Tensor> &shift_bias,\n const paddle::optional<paddle::Tensor> &smooth_weight,\n const paddle::optional<paddle::Tensor> &sinks,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out,\n const int sliding_window,\n const int sink_size);\n\n"
},
"multiquery_attention_c4": {
"name": "multiquery_attention_c4",
@@ -71,7 +71,7 @@
],
"max_instances_per_file": 160,
"file_prefix": "multiquery_attention_c4_",
"function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional<paddle::Tensor> &attn_mask,\n const paddle::Tensor &cache_k_scale,\n const paddle::Tensor &cache_v_scale,\n const paddle::optional<paddle::Tensor> &cache_k_zp,\n const paddle::optional<paddle::Tensor> &cache_v_zp,\n const paddle::optional<paddle::Tensor> &shift_bias,\n const paddle::optional<paddle::Tensor> &smooth_weight,\n const paddle::optional<paddle::Tensor> &sinks,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out,\n const int sliding_window);\n\n"
"function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional<paddle::Tensor> &attn_mask,\n const paddle::Tensor &cache_k_scale,\n const paddle::Tensor &cache_v_scale,\n const paddle::optional<paddle::Tensor> &cache_k_zp,\n const paddle::optional<paddle::Tensor> &cache_v_zp,\n const paddle::optional<paddle::Tensor> &shift_bias,\n const paddle::optional<paddle::Tensor> &smooth_weight,\n const paddle::optional<paddle::Tensor> &sinks,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out,\n const int sliding_window,\n const int sink_size);\n\n"
},
"multiquery_attention_c16": {
"name": "multiquery_attention_c16",
@@ -106,7 +106,7 @@
],
"max_instances_per_file": 160,
"file_prefix": "multiquery_attention_c16_",
"function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional<paddle::Tensor> &attn_mask,\n const paddle::optional<paddle::Tensor> &shift_bias,\n const paddle::optional<paddle::Tensor> &smooth_weight,\n const paddle::optional<paddle::Tensor> &sinks,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out,\n const int sliding_window);\n\n"
"function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional<paddle::Tensor> &attn_mask,\n const paddle::optional<paddle::Tensor> &shift_bias,\n const paddle::optional<paddle::Tensor> &smooth_weight,\n const paddle::optional<paddle::Tensor> &sinks,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out,\n const int sliding_window,\n const int sink_size);\n\n"
},
"multiquery_decoder_attention": {
"name": "multiquery_decoder_attention",