fix eb5 mtp(mix) (#6800)

This commit is contained in:
cmcamdy
2026-03-13 17:36:57 +08:00
committed by GitHub
parent 8c1a2827d3
commit 7591e0d6bc
5 changed files with 55 additions and 5 deletions
@@ -264,6 +264,10 @@ class XPUForwardMeta(ForwardMeta):
# for pd_disaggregation
kv_signal_sender: Optional[paddle.Tensor] = None
hidden_states: Optional[paddle.Tensor] = None
is_draft: bool = False
def copy_from(self, other: "XPUForwardMeta", skip_keys: Optional[list] = None):
"""
Synchronize attributes from another XPUForwardMeta object
@@ -190,7 +190,8 @@ class XPUAttentionBackend(AttentionBackend):
else:
q_norm_weight = None
k_norm_weight = None
# The draft model does not use rope3d for now
use_rope3d = self.rope_3d and not forward_meta.is_draft
res = block_attn(
qkv,
forward_meta.caches[2 * layer.layer_id],
@@ -229,7 +230,7 @@ class XPUAttentionBackend(AttentionBackend):
metadata.kv_signal_data_list[layer.layer_id],
forward_meta.kv_signal_sender,
layer.use_neox_rotary_style,
self.rope_3d,
use_rope3d,
)
return res
@@ -1068,7 +1068,7 @@ class SpeculativeSampler(nn.Layer):
sampling_metadata.top_p,
sampling_metadata.top_k,
sampling_metadata.seed,
share_inputs["seq_lens_this_time"],
paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]),
paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]),
)
_, sampled_token_ids = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, topp_seed=topp_seed)