mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Speculative Decoding] Optimize draft logprob (#5842)
* optimize draft logprob * fix ut
This commit is contained in:
@@ -660,7 +660,11 @@ class SpeculativeSampler(nn.Layer):
|
||||
top_p_logprob = None
|
||||
top_p_token_mask = None
|
||||
|
||||
if top_p_normalized_logprobs is not None and share_inputs is not None:
|
||||
if (
|
||||
top_p_normalized_logprobs is not None
|
||||
and share_inputs is not None
|
||||
and sampling_metadata.top_p_normalized_logprobs_flag
|
||||
):
|
||||
real_token_top_p = (
|
||||
sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(batch_token_num).unsqueeze(1)
|
||||
)
|
||||
@@ -837,9 +841,9 @@ class SpeculativeSampler(nn.Layer):
|
||||
logprobs_tensors = None
|
||||
token_ids = share_inputs["accept_tokens"]
|
||||
if num_logprobs is not None:
|
||||
token_ids = paddle.concat(
|
||||
[share_inputs["accept_tokens"][i, : share_inputs["accept_num"][i]] for i in range(real_bsz)]
|
||||
)
|
||||
idx = paddle.arange(share_inputs["accept_tokens"].shape[1], dtype="int32")
|
||||
mask = idx < share_inputs["accept_num"].unsqueeze(1)
|
||||
token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask)
|
||||
logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
|
||||
|
||||
sampler_output = SamplerOutput(
|
||||
@@ -950,6 +954,7 @@ class MTPSampler(nn.Layer):
|
||||
else:
|
||||
raise NotImplementedError
|
||||
self.logprobs_mode = fd_config.model_config.logprobs_mode
|
||||
self.enable_draft_logprob = fd_config.speculative_config.enable_draft_logprob
|
||||
|
||||
def pre_process(self, skip_idx_list: List[int] = []):
|
||||
"""pre process before running"""
|
||||
@@ -1002,7 +1007,11 @@ class MTPSampler(nn.Layer):
|
||||
top_p_logprob = None
|
||||
top_p_token_mask = None
|
||||
|
||||
if top_p_normalized_logprobs is not None and share_inputs is not None:
|
||||
if (
|
||||
top_p_normalized_logprobs is not None
|
||||
and share_inputs is not None
|
||||
and sampling_metadata.top_p_normalized_logprobs_flag
|
||||
):
|
||||
real_token_top_p = (
|
||||
sampling_metadata.top_p[:real_bsz]
|
||||
.squeeze(1)
|
||||
@@ -1079,7 +1088,7 @@ class MTPSampler(nn.Layer):
|
||||
""" """
|
||||
num_logprobs = sampling_metadata.max_num_logprobs
|
||||
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
|
||||
if num_logprobs is not None and share_inputs["substep"] == 0:
|
||||
if self.enable_draft_logprob and num_logprobs is not None and share_inputs["substep"] == 0:
|
||||
real_token_num = share_inputs["batch_token_num"][:real_bsz].sum()
|
||||
if self.logprobs_mode == "raw_logprobs":
|
||||
raw_logprobs = self.compute_logprobs(
|
||||
@@ -1110,7 +1119,7 @@ class MTPSampler(nn.Layer):
|
||||
|
||||
token_ids = None
|
||||
logprobs_tensors = None
|
||||
if num_logprobs is not None and share_inputs["substep"] == 0:
|
||||
if self.enable_draft_logprob and num_logprobs is not None and share_inputs["substep"] == 0:
|
||||
token_ids = paddle.empty(real_token_num, dtype="int64")
|
||||
speculate_insert_first_token(
|
||||
token_ids,
|
||||
|
||||
Reference in New Issue
Block a user