[Bugfix] Fix mtp logprob hang problem when include stop_seq (#5927)

* fix mtp logprob hang when include stop_seq
This commit is contained in:
GoldPancake
2026-01-08 14:21:24 +08:00
committed by GitHub
parent dc170e3005
commit a1fc4e249e
2 changed files with 7 additions and 4 deletions
@@ -839,15 +839,15 @@ class SpeculativeSampler(nn.Layer):
raw_logprobs = target_logits.clone()
logprobs_tensors = None
token_ids = share_inputs["accept_tokens"]
if num_logprobs is not None:
token_ids = share_inputs["accept_tokens"]
idx = paddle.arange(share_inputs["accept_tokens"].shape[1], dtype="int32")
mask = idx < share_inputs["accept_num"].unsqueeze(1)
token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask)
logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
sampler_output = SamplerOutput(
sampled_token_ids=token_ids,
sampled_token_ids=share_inputs["accept_tokens"],
logprobs_tensors=logprobs_tensors,
token_num_per_batch=share_inputs["accept_num"],
cu_batch_token_offset=share_inputs["cu_batch_token_offset"],