[RL] Support GLM MTP RL Model (#6267)

This commit is contained in:
GoldPancake
2026-02-04 20:14:35 +08:00
committed by GitHub
parent 765df94e6c
commit 183b8d325a
10 changed files with 308 additions and 33 deletions
@@ -193,9 +193,6 @@ class FlashAttentionBackend(AttentionBackend):
elif metadata._dtype == "float32":
metadata._fuse_kernel_compute_dtype = "fp32"
metadata.max_len_tensor_cpu_decoder = paddle.clone(forward_meta.max_len_tensor_cpu)
metadata.max_len_tensor_cpu_decoder[1] = 0
forward_meta.attention_metadata = metadata
def forward_mixed(
@@ -241,6 +238,10 @@ class FlashAttentionBackend(AttentionBackend):
)
if forward_meta.max_len_tensor_cpu[1].item() > 0:
metadata.max_len_tensor_cpu_decoder = paddle.clone(forward_meta.max_len_tensor_cpu)
metadata.max_len_tensor_cpu_decoder[1] = 0
(
metadata.cu_seqlens_k,
metadata.pre_cache_batch_ids,
@@ -309,7 +310,7 @@ class FlashAttentionBackend(AttentionBackend):
qkv,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
self.zero_seq_enc_lens_for_decode if use_fa_do_prefill else forward_meta.seq_lens_encoder,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.batch_id_per_token,