fix eb5 mtp(mix) (#6800)

This commit is contained in:
cmcamdy
2026-03-13 17:36:57 +08:00
committed by GitHub
parent 8c1a2827d3
commit 7591e0d6bc
5 changed files with 55 additions and 5 deletions
@@ -1068,7 +1068,7 @@ class SpeculativeSampler(nn.Layer):
sampling_metadata.top_p,
sampling_metadata.top_k,
sampling_metadata.seed,
share_inputs["seq_lens_this_time"],
paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]),
paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]),
)
_, sampled_token_ids = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, topp_seed=topp_seed)