[Feature] Supports SWA based on appendattn (#6547)
@@ -83,11 +83,73 @@ def append_attention(
    causal: bool = True,
    speculate_decoder: bool = False,
    sliding_window: int = 0,
    sink_size: int = 0,
    head_wise_full_hidden: int = 0,
) -> paddle.Tensor:
    """
    append_attention
    """
    if current_platform.is_cuda():
        if sliding_window > 0 and head_wise_full_hidden > 0:
            out_swa = append_attention_gpu(
                qkv.clone(),
                key_cache,
                value_cache,
                seq_lens_encoder,
                seq_lens_decoder,
                seq_lens_this_time,
                batch_id_per_token,
                cu_seqlens_q,
                block_tables,
                encoder_batch_ids,
                encoder_tile_ids_per_batch,
                encoder_num_blocks,
                kv_batch_ids,
                kv_tile_ids_per_batch,
                kv_num_blocks,
                decoder_batch_ids,
                decoder_tile_ids_per_batch,
                decoder_num_blocks,
                set_max_lengths,
                rotary_embs,
                attn_mask,
                qkv_bias,
                qkv_scale,
                k_quant_scale,
                v_quant_scale,
                k_dequant_scale,
                v_dequant_scale,
                cache_k_zp,
                cache_v_zp,
                linear_shift,
                linear_smooth,
                mask_offset,
                kv_signal_data,
                q_norm_weight,
                k_norm_weight,
                sinks,
                rms_norm_eps,
                compute_type,
                cache_quant_type,
                use_neox_rotary_style,
                rope_3d,
                max_input_length,
                quant_max_bound,
                quant_min_bound,
                out_linear_in_scale,
                encoder_block_shape_q,
                decoder_block_shape_q,
                max_partition_size,
                encoder_max_partition_size,
                speculate_max_draft_token_num,
                causal,
                speculate_decoder,
                sliding_window,
                sink_size,
            )
            sliding_window = 0
            sink_size = 0

        out = append_attention_gpu(
            qkv,
            key_cache,

@@ -142,7 +204,12 @@ def append_attention(
            causal,
            speculate_decoder,
            sliding_window,
            sink_size,
        )
        if head_wise_full_hidden > 0:
            out_swa[:, :head_wise_full_hidden] = out[:, :head_wise_full_hidden]
            return out_swa

        return out
    else:
        raise NotImplementedError
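
The control flow above is the heart of the change: when both `sliding_window` and `head_wise_full_hidden` are positive, the kernel runs twice. The first pass runs with the window enabled on `qkv.clone()` (presumably cloned because the kernel updates `qkv` in place), then `sliding_window` and `sink_size` are zeroed so the second pass computes full attention, and the leading `head_wise_full_hidden` output columns are finally taken from the full pass. Below is a minimal NumPy sketch of that dispatch; `toy_attention` and `append_attention_like` are hypothetical stand-ins for `append_attention_gpu`, and the window/sink semantics (each query sees the last `sliding_window` keys plus the first `sink_size` sink keys) are an assumption in the StreamingLLM style, not something this diff states.

import numpy as np

def toy_attention(q, k, v, sliding_window=0, sink_size=0):
    """Single-head causal attention over [seq, dim] arrays.

    sliding_window > 0 limits each query to the last `sliding_window`
    keys, except the first `sink_size` "sink" keys, which stay visible
    (assumed StreamingLLM-style semantics for these two knobs).
    """
    seq, dim = q.shape
    scores = q @ k.T / np.sqrt(dim)
    qi = np.arange(seq)[:, None]
    ki = np.arange(seq)[None, :]
    mask = ki <= qi  # causal
    if sliding_window > 0:
        in_window = ki > qi - sliding_window
        is_sink = ki < sink_size
        mask &= in_window | is_sink
    scores = np.where(mask, scores, -np.inf)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ v

def append_attention_like(qkv, sliding_window, sink_size, head_wise_full_hidden):
    """Mirror of the diff's dispatch for a [seq, 3*dim] fused qkv buffer."""
    dim = qkv.shape[1] // 3
    q, k, v = qkv[:, :dim], qkv[:, dim:2 * dim], qkv[:, 2 * dim:]
    if sliding_window > 0 and head_wise_full_hidden > 0:
        # Pass 1: windowed attention; pass 2: window disabled (full).
        out_swa = toy_attention(q, k, v, sliding_window, sink_size)
        out = toy_attention(q, k, v)
        # Heads mapped to the leading hidden columns keep full attention.
        out_swa[:, :head_wise_full_hidden] = out[:, :head_wise_full_hidden]
        return out_swa
    return toy_attention(q, k, v, sliding_window, sink_size)
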
@@ -266,6 +333,7 @@ def append_attention_with_output(
            causal,
            speculate_decoder,
            sliding_window,
            0,
        )
    else:
        raise NotImplementedError
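
The last hunk threads `sliding_window` through `append_attention_with_output` as well, apparently with the following positional argument (`sink_size`, by analogy with the first function's signature) pinned to 0. For intuition about what the two knobs do to the attention mask, here is a quick check using the same construction as the sketch above, again under the assumed StreamingLLM-style semantics:

import numpy as np

# Mask produced by causal + sliding_window=3 + sink_size=1 over seq=8
# (1 = key column visible to that query row). Row 7 sees keys
# {0, 5, 6, 7}: the sink token plus the last three positions.
seq, sliding_window, sink_size = 8, 3, 1
qi = np.arange(seq)[:, None]
ki = np.arange(seq)[None, :]
mask = (ki <= qi) & ((ki > qi - sliding_window) | (ki < sink_size))
print(mask.astype(int))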