[Metax] optimize flash mla (#4915)

This commit is contained in:
xiaozude
2025-11-12 16:43:46 +08:00
committed by GitHub
parent 9d9f5df8d0
commit c45b3ccb52
5 changed files with 37 additions and 38 deletions
@@ -728,7 +728,7 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
seq_lens_decoder = forward_meta.seq_lens_decoder
seq_lens_this_time = forward_meta.seq_lens_this_time
- current_total_tokens = paddle.sum(seq_lens_this_time)
+ current_total_tokens = forward_meta.ids_remove_padding.shape[0]
position_ids = self.position_ids_buffer[:current_total_tokens]
mask_encoder_batch = self.mask_encoder_batch_buffer[:current_total_tokens]