[Feature] Supports SWA based on appendattn (#6547)

AIbin
2026-03-01 19:02:08 +08:00
committed by GitHub
parent ea4d10d174
commit 59b578c337
17 changed files with 410 additions and 240 deletions
@@ -83,11 +83,73 @@ def append_attention(
    causal: bool = True,
    speculate_decoder: bool = False,
    sliding_window: int = 0,
    sink_size: int = 0,
    head_wise_full_hidden: int = 0,
) -> paddle.Tensor:
    """
    append_attention
    """
    if current_platform.is_cuda():
        if sliding_window > 0 and head_wise_full_hidden > 0:
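            # First pass: sliding-window attention (with sink_size) on a copy of qkv,
            # so the second, full-attention pass below still sees the original input.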
            out_swa = append_attention_gpu(
                qkv.clone(),
                key_cache,
                value_cache,
                seq_lens_encoder,
                seq_lens_decoder,
                seq_lens_this_time,
                batch_id_per_token,
                cu_seqlens_q,
                block_tables,
                encoder_batch_ids,
                encoder_tile_ids_per_batch,
                encoder_num_blocks,
                kv_batch_ids,
                kv_tile_ids_per_batch,
                kv_num_blocks,
                decoder_batch_ids,
                decoder_tile_ids_per_batch,
                decoder_num_blocks,
                set_max_lengths,
                rotary_embs,
                attn_mask,
                qkv_bias,
                qkv_scale,
                k_quant_scale,
                v_quant_scale,
                k_dequant_scale,
                v_dequant_scale,
                cache_k_zp,
                cache_v_zp,
                linear_shift,
                linear_smooth,
                mask_offset,
                kv_signal_data,
                q_norm_weight,
                k_norm_weight,
                sinks,
                rms_norm_eps,
                compute_type,
                cache_quant_type,
                use_neox_rotary_style,
                rope_3d,
                max_input_length,
                quant_max_bound,
                quant_min_bound,
                out_linear_in_scale,
                encoder_block_shape_q,
                decoder_block_shape_q,
                max_partition_size,
                encoder_max_partition_size,
                speculate_max_draft_token_num,
                causal,
                speculate_decoder,
                sliding_window,
                sink_size,
            )
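            # Disable windowing so the shared call below computes full attention.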
            sliding_window = 0
            sink_size = 0
        out = append_attention_gpu(
            qkv,
            key_cache,
@@ -142,7 +204,12 @@ def append_attention(
            causal,
            speculate_decoder,
            sliding_window,
            sink_size,
        )
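        # Heads covered by head_wise_full_hidden keep the full-attention result;
        # the remaining heads keep the sliding-window output.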
        if head_wise_full_hidden > 0:
            out_swa[:, :head_wise_full_hidden] = out[:, :head_wise_full_hidden]
            return out_swa
        return out
    else:
        raise NotImplementedError
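
Editor's note: when both sliding_window and head_wise_full_hidden are positive, the wrapper above runs the append-attention kernel twice: first with the sliding window and sink enabled on a cloned qkv, then with the window disabled on the original qkv. It then copies the first head_wise_full_hidden hidden columns (the heads that keep full attention) from the full-attention output into the SWA output. A minimal sketch of that head-wise blend, assuming both outputs are laid out as [token_num, hidden_size]:

```python
# Minimal sketch of the head-wise blend above (illustration only, not the kernel).
# Assumes both outputs are [token_num, hidden_size] and the first
# `head_wise_full_hidden` columns belong to the full-attention heads.
import paddle

token_num, hidden_size = 4, 8
head_wise_full_hidden = 4  # hypothetical split point between full and SWA heads

out_swa = paddle.randn([token_num, hidden_size])   # sliding-window + sink result
out_full = paddle.randn([token_num, hidden_size])  # full-attention result

# Full-attention heads take the non-windowed result; the rest keep the SWA output.
out_swa[:, :head_wise_full_hidden] = out_full[:, :head_wise_full_hidden]
```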
@@ -266,6 +333,7 @@ def append_attention_with_output(
            causal,
            speculate_decoder,
            sliding_window,
            0,
        )
    else:
        raise NotImplementedError
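
Editor's note: only the Python wrappers are shown here, so the exact kernel semantics of sliding_window and sink_size are assumed below to follow the usual sliding-window-attention-with-sinks convention: a decoding query attends to the first sink_size cached tokens plus the most recent sliding_window tokens. A small illustration of which KV positions remain visible under that assumption:

```python
# Illustration of sliding-window masking with attention sinks (assumed semantics,
# not the CUDA kernel): keep the first `sink_size` positions and the most recent
# `sliding_window` positions up to and including the query position `pos`.
def visible_kv_positions(pos: int, sliding_window: int, sink_size: int) -> list[int]:
    recent = set(range(max(0, pos + 1 - sliding_window), pos + 1))
    sinks = set(range(min(sink_size, pos + 1)))
    return sorted(sinks | recent)

print(visible_kv_positions(pos=10, sliding_window=4, sink_size=2))
# -> [0, 1, 7, 8, 9, 10]
```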