[Speculative Decoding] Optimize draft logprob (#5842)

* optimize draft logprob

* fix ut
This commit is contained in:
GoldPancake
2025-12-31 13:35:56 +08:00
committed by GitHub
parent 9e45ef7ca9
commit 4e10ae5d99
6 changed files with 25 additions and 9 deletions
@@ -660,7 +660,11 @@ class SpeculativeSampler(nn.Layer):
top_p_logprob = None
top_p_token_mask = None
if top_p_normalized_logprobs is not None and share_inputs is not None:
if (
top_p_normalized_logprobs is not None
and share_inputs is not None
and sampling_metadata.top_p_normalized_logprobs_flag
):
real_token_top_p = (
sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(batch_token_num).unsqueeze(1)
)
@@ -837,9 +841,9 @@ class SpeculativeSampler(nn.Layer):
logprobs_tensors = None
token_ids = share_inputs["accept_tokens"]
if num_logprobs is not None:
token_ids = paddle.concat(
[share_inputs["accept_tokens"][i, : share_inputs["accept_num"][i]] for i in range(real_bsz)]
)
idx = paddle.arange(share_inputs["accept_tokens"].shape[1], dtype="int32")
mask = idx < share_inputs["accept_num"].unsqueeze(1)
token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask)
logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
sampler_output = SamplerOutput(
@@ -950,6 +954,7 @@ class MTPSampler(nn.Layer):
else:
raise NotImplementedError
self.logprobs_mode = fd_config.model_config.logprobs_mode
self.enable_draft_logprob = fd_config.speculative_config.enable_draft_logprob
def pre_process(self, skip_idx_list: List[int] = []):
"""pre process before running"""
@@ -1002,7 +1007,11 @@ class MTPSampler(nn.Layer):
top_p_logprob = None
top_p_token_mask = None
if top_p_normalized_logprobs is not None and share_inputs is not None:
if (
top_p_normalized_logprobs is not None
and share_inputs is not None
and sampling_metadata.top_p_normalized_logprobs_flag
):
real_token_top_p = (
sampling_metadata.top_p[:real_bsz]
.squeeze(1)
@@ -1079,7 +1088,7 @@ class MTPSampler(nn.Layer):
""" """
num_logprobs = sampling_metadata.max_num_logprobs
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
if num_logprobs is not None and share_inputs["substep"] == 0:
if self.enable_draft_logprob and num_logprobs is not None and share_inputs["substep"] == 0:
real_token_num = share_inputs["batch_token_num"][:real_bsz].sum()
if self.logprobs_mode == "raw_logprobs":
raw_logprobs = self.compute_logprobs(
@@ -1110,7 +1119,7 @@ class MTPSampler(nn.Layer):
token_ids = None
logprobs_tensors = None
if num_logprobs is not None and share_inputs["substep"] == 0:
if self.enable_draft_logprob and num_logprobs is not None and share_inputs["substep"] == 0:
token_ids = paddle.empty(real_token_num, dtype="int64")
speculate_insert_first_token(
token_ids,