mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
MLA clean code (#5979)
This commit is contained in:
@@ -124,13 +124,9 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
|
||||
self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
|
||||
|
||||
self.kv_num_heads: int = kv_num_heads
|
||||
self.num_heads: int = num_heads
|
||||
self.group_size: int = self.num_heads // self.kv_num_heads
|
||||
self.head_dim: int = fd_config.model_config.head_dim
|
||||
self.num_layers: int = fd_config.model_config.num_hidden_layers
|
||||
self.encoder_block_shape_q: int = encoder_block_shape_q
|
||||
self.decoder_block_shape_q: int = decoder_block_shape_q
|
||||
|
||||
# For Multi Head Latent Attention
|
||||
self.kv_lora_rank: int = fd_config.model_config.kv_lora_rank
|
||||
@@ -150,6 +146,8 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
|
||||
self.rank, self.device_id = init_rank_and_device_id(fd_config)
|
||||
|
||||
self.useless_tensor = paddle.randn([1]).cast("int32")
|
||||
|
||||
if self.flash_attn_func is None:
|
||||
prop = paddle.device.cuda.get_device_properties()
|
||||
cc = prop.major * 10 + prop.minor
|
||||
@@ -188,21 +186,21 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.decoder_batch_ids, # decoder_batch_ids_per_ctax
|
||||
forward_meta.decoder_tile_ids_per_batch, # decoder_chunk_ids_per_ctax_each_batch
|
||||
forward_meta.decoder_num_blocks_cpu,
|
||||
forward_meta.decoder_batch_ids,
|
||||
forward_meta.decoder_tile_ids_per_batch,
|
||||
self.useless_tensor, # not used in mla
|
||||
forward_meta.decoder_num_blocks_device,
|
||||
forward_meta.decoder_chunk_size_device,
|
||||
forward_meta.max_len_tensor_cpu,
|
||||
forward_meta.encoder_batch_ids,
|
||||
forward_meta.encoder_tile_ids_per_batch,
|
||||
forward_meta.encoder_num_blocks_x_cpu,
|
||||
self.useless_tensor, # not used in mla
|
||||
self.useless_tensor, # not used in mla
|
||||
self.useless_tensor, # not used in mla
|
||||
forward_meta.kv_batch_ids,
|
||||
forward_meta.kv_tile_ids_per_batch,
|
||||
forward_meta.kv_num_blocks_x_cpu,
|
||||
self.encoder_block_shape_q,
|
||||
self.decoder_block_shape_q,
|
||||
self.group_size,
|
||||
-1, # not need.
|
||||
-1, # not need.
|
||||
-1, # not need.
|
||||
self.block_size,
|
||||
)
|
||||
# MLA
|
||||
|
||||
Reference in New Issue
Block a user