mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 08:21:53 +08:00
[RL] Support GLM MTP RL Model (#6267)
This commit is contained in:
@@ -193,9 +193,6 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
elif metadata._dtype == "float32":
|
||||
metadata._fuse_kernel_compute_dtype = "fp32"
|
||||
|
||||
metadata.max_len_tensor_cpu_decoder = paddle.clone(forward_meta.max_len_tensor_cpu)
|
||||
metadata.max_len_tensor_cpu_decoder[1] = 0
|
||||
|
||||
forward_meta.attention_metadata = metadata
|
||||
|
||||
def forward_mixed(
|
||||
@@ -241,6 +238,10 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
)
|
||||
|
||||
if forward_meta.max_len_tensor_cpu[1].item() > 0:
|
||||
|
||||
metadata.max_len_tensor_cpu_decoder = paddle.clone(forward_meta.max_len_tensor_cpu)
|
||||
metadata.max_len_tensor_cpu_decoder[1] = 0
|
||||
|
||||
(
|
||||
metadata.cu_seqlens_k,
|
||||
metadata.pre_cache_batch_ids,
|
||||
@@ -309,7 +310,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
qkv,
|
||||
forward_meta.caches[2 * layer.layer_id],
|
||||
forward_meta.caches[2 * layer.layer_id + 1],
|
||||
self.zero_seq_enc_lens_for_decode if use_fa_do_prefill else forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
forward_meta.batch_id_per_token,
|
||||
|
||||
Reference in New Issue
Block a user