MLA clean code (#5979)

This commit is contained in:
周周周
2026-01-10 21:05:00 +08:00
committed by GitHub
parent 62bd92f9ba
commit b8d9daa785
2 changed files with 28 additions and 29 deletions
@@ -124,13 +124,9 @@ class MLAAttentionBackend(AttentionBackend):
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
self.kv_num_heads: int = kv_num_heads
self.num_heads: int = num_heads
self.group_size: int = self.num_heads // self.kv_num_heads
self.head_dim: int = fd_config.model_config.head_dim
self.num_layers: int = fd_config.model_config.num_hidden_layers
self.encoder_block_shape_q: int = encoder_block_shape_q
self.decoder_block_shape_q: int = decoder_block_shape_q
# For Multi Head Latent Attention
self.kv_lora_rank: int = fd_config.model_config.kv_lora_rank
@@ -150,6 +146,8 @@ class MLAAttentionBackend(AttentionBackend):
self.rank, self.device_id = init_rank_and_device_id(fd_config)
self.useless_tensor = paddle.randn([1]).cast("int32")
if self.flash_attn_func is None:
prop = paddle.device.cuda.get_device_properties()
cc = prop.major * 10 + prop.minor
@@ -188,21 +186,21 @@ class MLAAttentionBackend(AttentionBackend):
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.decoder_batch_ids, # decoder_batch_ids_per_ctax
forward_meta.decoder_tile_ids_per_batch, # decoder_chunk_ids_per_ctax_each_batch
forward_meta.decoder_num_blocks_cpu,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
self.useless_tensor, # not used in mla
forward_meta.decoder_num_blocks_device,
forward_meta.decoder_chunk_size_device,
forward_meta.max_len_tensor_cpu,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
self.useless_tensor, # not used in mla
self.useless_tensor, # not used in mla
self.useless_tensor, # not used in mla
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
self.encoder_block_shape_q,
self.decoder_block_shape_q,
self.group_size,
-1, # not need.
-1, # not need.
-1, # not need.
self.block_size,
)
# MLA