[Others] Remove useless code (#5404)

This commit is contained in:
周周周
2025-12-08 13:59:46 +08:00
committed by GitHub
parent 3066a0c34b
commit 2aea8a3a60
8 changed files with 139 additions and 166 deletions
@@ -206,20 +206,9 @@ class AppendAttentionBackend(AttentionBackend):
Calculate kv cache shape
"""
key_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
key_cache_shape = [
max_num_blocks,
self.kv_num_heads,
self.block_size,
self.head_dim // 2,
]
value_cache_shape = [
max_num_blocks,
self.kv_num_heads,
self.block_size,
self.head_dim // 2,
]
key_cache_shape[-1] = self.head_dim // 2
value_cache_shape = key_cache_shape
return key_cache_shape, value_cache_shape
def forward_mixed(
@@ -63,13 +63,7 @@ class FlashAttentionMetadata(AttentionMetadata):
FlashAttentionMetadata
"""
rotary_embs: Optional[paddle.Tensor] = None
block_tables: Optional[paddle.Tensor] = None
cu_seqlens_q: paddle.Tensor = None
cu_seqlens_k: paddle.Tensor = None
max_seqlen_q: int = 0
max_seqlen_k: int = 0
pre_cache_batch_ids = None
pre_cache_tile_ids_per_batch = None
@@ -83,7 +77,6 @@ class FlashAttentionMetadata(AttentionMetadata):
_fuse_kernel_compute_dtype: str = "bf16"
_dtype: paddle.dtype = paddle.bfloat16
max_len_tensor_cpu: paddle.Tensor = None
max_len_tensor_cpu_decoder: paddle.Tensor = None
@@ -133,9 +126,6 @@ class FlashAttentionBackend(AttentionBackend):
self.start_layer_index: int = fd_config.model_config.start_layer_index
if fd_config.parallel_config.expert_parallel_rank is None:
fd_config.parallel_config.expert_parallel_rank = 0
self.rank, self.device_id = init_rank_and_device_id(fd_config)
if self.flash_attn_func is None:
@@ -154,7 +144,8 @@ class FlashAttentionBackend(AttentionBackend):
"The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", "32768"))
# Note(ZKK): here must be consistent with append_attn_backend.py
self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 1024))
self.zero_seq_enc_lens_for_decode = paddle.zeros(
shape=[fd_config.scheduler_config.max_num_seqs, 1], dtype=paddle.int32
)
@@ -172,27 +163,13 @@ class FlashAttentionBackend(AttentionBackend):
Calculate kv cache shape
"""
key_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
key_cache_shape = [
max_num_blocks,
self.kv_num_heads,
self.block_size,
self.head_dim // 2,
]
value_cache_shape = [
max_num_blocks,
self.kv_num_heads,
self.block_size,
self.head_dim // 2,
]
key_cache_shape[-1] = self.head_dim // 2
value_cache_shape = key_cache_shape
return key_cache_shape, value_cache_shape
def init_attention_metadata(self, forward_meta: ForwardMeta):
metadata = FlashAttentionMetadata()
metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
metadata.rotary_embs = forward_meta.rotary_embs
metadata.block_tables = forward_meta.block_tables
get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
@@ -215,18 +192,20 @@ class FlashAttentionBackend(AttentionBackend):
self.block_size,
)
(
metadata.cu_seqlens_k,
metadata.pre_cache_batch_ids,
metadata.pre_cache_tile_ids_per_batch,
metadata.pre_cache_num_blocks_cpu,
metadata.kv_token_num_cpu,
) = pre_cache_len_concat(
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.max_len_tensor_cpu[2],
self.block_size,
)
if forward_meta.max_len_tensor_cpu[1] > 0:
(
metadata.cu_seqlens_k,
metadata.pre_cache_batch_ids,
metadata.pre_cache_tile_ids_per_batch,
metadata.pre_cache_num_blocks_cpu,
metadata.kv_token_num_cpu,
) = pre_cache_len_concat(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.max_len_tensor_cpu[2],
self.block_size,
)
# pd_disaggregation
metadata.kv_signal_data_list = [None] * self.num_layers
@@ -251,8 +230,7 @@ class FlashAttentionBackend(AttentionBackend):
elif metadata._dtype == "float32":
metadata._fuse_kernel_compute_dtype = "fp32"
metadata.max_len_tensor_cpu = forward_meta.max_len_tensor_cpu
metadata.max_len_tensor_cpu_decoder = paddle.clone(metadata.max_len_tensor_cpu)
metadata.max_len_tensor_cpu_decoder = paddle.clone(forward_meta.max_len_tensor_cpu)
metadata.max_len_tensor_cpu_decoder[1] = 0
self.attention_metadata = metadata
@@ -276,19 +254,21 @@ class FlashAttentionBackend(AttentionBackend):
layer.layer_id + self.start_layer_index,
)
if metadata.max_len_tensor_cpu[1] > 0:
use_fa_do_prefill = forward_meta.max_len_tensor_cpu[1].item() > 0
if use_fa_do_prefill:
q, k, v, _ = gqa_rope_write_cache(
qkv,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
metadata.cu_seqlens_q,
forward_meta.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.rotary_embs,
forward_meta.rotary_embs,
forward_meta.seq_lens_this_time,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.batch_id_per_token,
metadata.block_tables,
forward_meta.block_tables,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
@@ -315,7 +295,7 @@ class FlashAttentionBackend(AttentionBackend):
q,
k,
v,
metadata.cu_seqlens_q,
forward_meta.cu_seqlens_q,
metadata.cu_seqlens_k,
max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
@@ -327,23 +307,23 @@ class FlashAttentionBackend(AttentionBackend):
qkv,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
self.zero_seq_enc_lens_for_decode,
self.zero_seq_enc_lens_for_decode if use_fa_do_prefill else forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
forward_meta.block_tables,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
forward_meta.decoder_batch_ids, # from buffer
forward_meta.decoder_tile_ids_per_batch, # from buffer
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
metadata.max_len_tensor_cpu_decoder,
metadata.rotary_embs,
metadata.max_len_tensor_cpu_decoder if use_fa_do_prefill else forward_meta.max_len_tensor_cpu,
forward_meta.rotary_embs,
forward_meta.attn_mask,
layer.qkv_bias,
layer.qkv_scale,
@@ -378,7 +358,7 @@ class FlashAttentionBackend(AttentionBackend):
self.speculative_method is not None,
)
if metadata.max_len_tensor_cpu[1] > 0:
if use_fa_do_prefill:
merge_prefill_decode_output(
res_encoder,
res_decoder,
@@ -24,6 +24,7 @@ from fastdeploy.platforms import current_platform
def pre_cache_len_concat(
seq_lens_encoder: paddle.Tensor,
seq_lens_decoder: paddle.Tensor,
seq_lens_this_time: paddle.Tensor,
max_dec_len: int = 0,
@@ -32,7 +33,7 @@ def pre_cache_len_concat(
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import pre_cache_len_concat
out = pre_cache_len_concat(seq_lens_decoder, seq_lens_this_time, max_dec_len, block_size)
out = pre_cache_len_concat(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, max_dec_len, block_size)
return out
else:
raise NotImplementedError