Mirror of https://github.com/PaddlePaddle/FastDeploy.git, last synced 2026-04-23 08:21:53 +08:00.
* Fix MTP logprob hang when the output includes a stop_seq
This commit is contained in:
@@ -839,15 +839,15 @@ class SpeculativeSampler(nn.Layer):
|
||||
raw_logprobs = target_logits.clone()
|
||||
|
||||
logprobs_tensors = None
|
||||
token_ids = share_inputs["accept_tokens"]
|
||||
if num_logprobs is not None:
|
||||
token_ids = share_inputs["accept_tokens"]
|
||||
idx = paddle.arange(share_inputs["accept_tokens"].shape[1], dtype="int32")
|
||||
mask = idx < share_inputs["accept_num"].unsqueeze(1)
|
||||
token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask)
|
||||
logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
|
||||
|
||||
sampler_output = SamplerOutput(
|
||||
sampled_token_ids=token_ids,
|
||||
sampled_token_ids=share_inputs["accept_tokens"],
|
||||
logprobs_tensors=logprobs_tensors,
|
||||
token_num_per_batch=share_inputs["accept_num"],
|
||||
cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
|
||||
|
||||
Reference in New Issue
Block a user