[Feature]Supports SWA based on appendattn (#6547)

This commit is contained in:
AIbin
2026-03-01 19:02:08 +08:00
committed by GitHub
parent ea4d10d174
commit 59b578c337
17 changed files with 410 additions and 240 deletions
@@ -13,11 +13,11 @@
// limitations under the License.
#pragma once
#include "append_attention_c16_impl.cuh"
#include "append_attention_c4_impl.cuh"
#include "append_attention_c8_impl.cuh"
#include "helper.h"
#include "utils.cuh"
#include "append_attention_c16_impl.cuh"
#include "append_attention_c8_impl.cuh"
#include "append_attention_c4_impl.cuh"
template <typename T, typename OutT>
void CascadeAppendAttentionKernel(
@@ -39,9 +39,8 @@ void CascadeAppendAttentionKernel(
const paddle::optional<paddle::Tensor>&
shift_bias, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>&
sinks, // [num_heads]
smooth_weight, // [num_kv_heads, head_dim]
const paddle::optional<paddle::Tensor>& sinks, // [num_heads]
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
@@ -66,160 +65,167 @@ void CascadeAppendAttentionKernel(
const bool enable_prefill,
cudaStream_t& stream,
paddle::Tensor* out,
const int sliding_window) {
if (cache_quant_type_str == "none") {
CascadeAppendAttentionC16Kernel<T, OutT>(meta_data,
qkv,
cache_k,
cache_v,
attn_mask,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
shift_bias,
smooth_weight,
sinks,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
num_blocks,
block_shape_q,
max_seq_len,
max_dec_len,
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
is_decoder,
enable_prefill,
stream,
out,
sliding_window);
} else if (cache_quant_type_str == "cache_int8") {
CascadeAppendAttentionC8Kernel<T, OutT, false>(meta_data,
qkv,
cache_k,
cache_v,
attn_mask,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
shift_bias,
smooth_weight,
sinks,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
num_blocks,
block_shape_q,
max_seq_len,
max_dec_len,
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
is_decoder,
enable_prefill,
cache_quant_type_str,
stream,
out,
sliding_window);
} else if (cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
CascadeAppendAttentionC8Kernel<T, OutT, true>(meta_data,
qkv,
cache_k,
cache_v,
attn_mask,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
shift_bias,
smooth_weight,
sinks,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
num_blocks,
block_shape_q,
max_seq_len,
max_dec_len,
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
is_decoder,
enable_prefill,
cache_quant_type_str,
stream,
out,
sliding_window);
} else if (cache_quant_type_str == "cache_int4_zp") {
CascadeAppendAttentionC4Kernel<T, OutT>(meta_data,
qkv,
cache_k,
cache_v,
attn_mask,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
shift_bias,
smooth_weight,
sinks,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
num_blocks,
block_shape_q,
max_seq_len,
max_dec_len,
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
is_decoder,
enable_prefill,
stream,
out,
sliding_window);
} else {
PD_THROW(
"cache_quant_type_str should be one of [none, cache_int8, "
"cache_int4_zp]");
}
const int sliding_window = 0,
const int sink_size = 0) {
if (cache_quant_type_str == "none") {
CascadeAppendAttentionC16Kernel<T, OutT>(meta_data,
qkv,
cache_k,
cache_v,
attn_mask,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
shift_bias,
smooth_weight,
sinks,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
num_blocks,
block_shape_q,
max_seq_len,
max_dec_len,
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
is_decoder,
enable_prefill,
stream,
out,
sliding_window,
sink_size);
} else if (cache_quant_type_str == "cache_int8") {
CascadeAppendAttentionC8Kernel<T, OutT, false>(
meta_data,
qkv,
cache_k,
cache_v,
attn_mask,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
shift_bias,
smooth_weight,
sinks,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
num_blocks,
block_shape_q,
max_seq_len,
max_dec_len,
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
is_decoder,
enable_prefill,
cache_quant_type_str,
stream,
out,
sliding_window,
sink_size);
} else if (cache_quant_type_str == "cache_fp8" or
cache_quant_type_str == "block_wise_fp8") {
CascadeAppendAttentionC8Kernel<T, OutT, true>(meta_data,
qkv,
cache_k,
cache_v,
attn_mask,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
shift_bias,
smooth_weight,
sinks,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
num_blocks,
block_shape_q,
max_seq_len,
max_dec_len,
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
is_decoder,
enable_prefill,
cache_quant_type_str,
stream,
out,
sliding_window,
sink_size);
} else if (cache_quant_type_str == "cache_int4_zp") {
CascadeAppendAttentionC4Kernel<T, OutT>(meta_data,
qkv,
cache_k,
cache_v,
attn_mask,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
shift_bias,
smooth_weight,
sinks,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
num_blocks,
block_shape_q,
max_seq_len,
max_dec_len,
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
is_decoder,
enable_prefill,
stream,
out,
sliding_window,
sink_size);
} else {
PD_THROW(
"cache_quant_type_str should be one of [none, cache_int8, "
"cache_int4_zp]");
}
}