[Speculative Decoding] Optimize draft logprob (#5842)
* optimize draft logprob
* fix ut
@@ -740,6 +740,8 @@ class SpeculativeConfig:
         self.num_extra_cache_layer = 0
 
+        self.enable_draft_logprob: bool = False
+
         for key, value in args.items():
             if hasattr(self, key):
                 setattr(self, key, value)
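The new flag piggybacks on the generic `args.items()` loop above, so it can be toggled like any other `SpeculativeConfig` field. A minimal standalone sketch of that constructor pattern (simplified stand-in, not the real class):

    class _SpecConfigSketch:
        """Simplified stand-in for SpeculativeConfig's constructor pattern."""

        def __init__(self, **args):
            self.num_extra_cache_layer = 0
            self.enable_draft_logprob: bool = False  # draft logprobs are off by default
            # Any known field passed as a kwarg overrides its default.
            for key, value in args.items():
                if hasattr(self, key):
                    setattr(self, key, value)

    cfg = _SpecConfigSketch(enable_draft_logprob=True)
    assert cfg.enable_draft_logprob is True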
@@ -660,7 +660,11 @@ class SpeculativeSampler(nn.Layer):
         top_p_logprob = None
         top_p_token_mask = None
 
-        if top_p_normalized_logprobs is not None and share_inputs is not None:
+        if (
+            top_p_normalized_logprobs is not None
+            and share_inputs is not None
+            and sampling_metadata.top_p_normalized_logprobs_flag
+        ):
             real_token_top_p = (
                 sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(batch_token_num).unsqueeze(1)
             )
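For context on the `real_token_top_p` expression kept by this hunk: `top_p` is stored per request, while the logprobs are laid out per token, so `repeat_interleave(batch_token_num)` stretches each request's value across its tokens. A toy illustration with made-up shapes (the real tensors come from `sampling_metadata` and `share_inputs`):

    import paddle

    top_p = paddle.to_tensor([[0.7], [0.9]])     # [real_bsz, 1], one value per request
    batch_token_num = paddle.to_tensor([3, 2])   # tokens produced by each request
    # Stretch each request's top_p across its tokens, then restore the column shape.
    real_token_top_p = top_p.squeeze(1).repeat_interleave(batch_token_num).unsqueeze(1)
    print(real_token_top_p.squeeze(1).tolist())  # [0.7, 0.7, 0.7, 0.9, 0.9]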
@@ -837,9 +841,9 @@ class SpeculativeSampler(nn.Layer):
         logprobs_tensors = None
         token_ids = share_inputs["accept_tokens"]
         if num_logprobs is not None:
-            token_ids = paddle.concat(
-                [share_inputs["accept_tokens"][i, : share_inputs["accept_num"][i]] for i in range(real_bsz)]
-            )
+            idx = paddle.arange(share_inputs["accept_tokens"].shape[1], dtype="int32")
+            mask = idx < share_inputs["accept_num"].unsqueeze(1)
+            token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask)
             logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
 
         sampler_output = SamplerOutput(
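This hunk is the core of the optimization: the old path sliced `accept_tokens` once per request in a Python loop and concatenated the pieces, while the new path builds one broadcasted boolean mask and gathers everything in a single `masked_select`. A toy equivalence check with made-up values:

    import paddle

    accept_tokens = paddle.to_tensor([[11, 12, 13], [21, 22, 23]], dtype="int64")
    accept_num = paddle.to_tensor([2, 3], dtype="int32")  # accepted tokens per request
    real_bsz = 2

    # Old path: one slice per request, then a concat.
    loop_ids = paddle.concat([accept_tokens[i, : int(accept_num[i])] for i in range(real_bsz)])

    # New path: broadcast [T] positions against [B, 1] counts, gather once.
    idx = paddle.arange(accept_tokens.shape[1], dtype="int32")
    mask = idx < accept_num.unsqueeze(1)
    vec_ids = paddle.masked_select(accept_tokens, mask)

    assert (loop_ids == vec_ids).all()  # both yield [11, 12, 21, 22, 23]

`masked_select` flattens in row-major order, which matches the request-then-token order the old loop produced.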
@@ -950,6 +954,7 @@ class MTPSampler(nn.Layer):
         else:
             raise NotImplementedError
         self.logprobs_mode = fd_config.model_config.logprobs_mode
+        self.enable_draft_logprob = fd_config.speculative_config.enable_draft_logprob
 
     def pre_process(self, skip_idx_list: List[int] = []):
         """pre process before running"""
@@ -1002,7 +1007,11 @@ class MTPSampler(nn.Layer):
         top_p_logprob = None
         top_p_token_mask = None
 
-        if top_p_normalized_logprobs is not None and share_inputs is not None:
+        if (
+            top_p_normalized_logprobs is not None
+            and share_inputs is not None
+            and sampling_metadata.top_p_normalized_logprobs_flag
+        ):
             real_token_top_p = (
                 sampling_metadata.top_p[:real_bsz]
                 .squeeze(1)
@@ -1079,7 +1088,7 @@ class MTPSampler(nn.Layer):
         """ """
         num_logprobs = sampling_metadata.max_num_logprobs
         real_bsz = share_inputs["seq_lens_this_time"].shape[0]
-        if num_logprobs is not None and share_inputs["substep"] == 0:
+        if self.enable_draft_logprob and num_logprobs is not None and share_inputs["substep"] == 0:
             real_token_num = share_inputs["batch_token_num"][:real_bsz].sum()
             if self.logprobs_mode == "raw_logprobs":
                 raw_logprobs = self.compute_logprobs(
@@ -1110,7 +1119,7 @@ class MTPSampler(nn.Layer):
 
         token_ids = None
         logprobs_tensors = None
-        if num_logprobs is not None and share_inputs["substep"] == 0:
+        if self.enable_draft_logprob and num_logprobs is not None and share_inputs["substep"] == 0:
             token_ids = paddle.empty(real_token_num, dtype="int64")
             speculate_insert_first_token(
                 token_ids,
@@ -77,6 +77,7 @@ class TokenProcessor:
 
         self.speculative_decoding = self.cfg.speculative_config.method is not None
         self.use_logprobs = self.cfg.model_config.enable_logprob
+        self.enable_draft_logprob = self.cfg.speculative_config.enable_draft_logprob
 
         if self.speculative_decoding:
             if self.use_logprobs:
@@ -446,7 +447,7 @@ class TokenProcessor:
             batch_result (list): batch results
         """
         try:
-            if self.cfg.speculative_config.method and self.use_logprobs:
+            if self.cfg.speculative_config.method and self.use_logprobs and self.enable_draft_logprob:
                 if mtype == 3:  # target
                     finished_batch_result, unfinished_batch_result = [], []
                     for r in batch_result:
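The `mtype == 3` branch separates results from the target model into finished and unfinished lists; the loop body is cut off in this hunk. A plausible sketch of such a partition, assuming each result exposes a `finished` flag (hypothetical attribute, not confirmed by the diff):

    from typing import Any, List, Tuple

    def split_by_finished(batch_result: List[Any]) -> Tuple[List[Any], List[Any]]:
        # Hypothetical: the real criterion inside TokenProcessor is not shown here.
        finished, unfinished = [], []
        for r in batch_result:
            (finished if getattr(r, "finished", False) else unfinished).append(r)
        return finished, unfinished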
@@ -94,6 +94,7 @@ class MTPProposer(Proposer):
         self.mtp_strategy = self.speculative_config.mtp_strategy
         self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps
         self.enable_logprob = self.model_config.enable_logprob
+        self.enable_draft_logprob = self.speculative_config.enable_draft_logprob
 
         # [mixed, prefill, decoder]
         self.role = self.scheduler_config.splitwise_role
@@ -943,7 +944,7 @@ class MTPProposer(Proposer):
 
         # 4. Compute logits, Sample
         logits = self.model.compute_logits(hidden_states)
-        if self.enable_logprob and substep == 0:
+        if self.enable_logprob and self.enable_draft_logprob and substep == 0:
             first_token_logits = self.model.compute_logits(self.model_inputs["first_token_hidden_states"])
 
             speculate_get_logits(
@@ -122,6 +122,7 @@ class TestTokenProcessorProcessBatchOutput(unittest.TestCase):
         cfg.speculative_config.method = "mtp" if speculative_decoding else None
         cfg.speculative_config.num_speculative_tokens = 1
         cfg.model_config.enable_logprob = use_logprobs
+        cfg.speculative_config.enable_draft_logprob = True
 
         processor = TokenProcessor.__new__(TokenProcessor)
         processor.cfg = cfg
@@ -139,6 +140,7 @@ class TestTokenProcessorProcessBatchOutput(unittest.TestCase):
         processor.number_of_output_tokens = 0
         processor.prefill_result_status = {}
         processor.use_logprobs = use_logprobs
+        processor.enable_draft_logprob = cfg.speculative_config.enable_draft_logprob
         processor.num_draft_tokens = 0
         processor.num_accepted_tokens = 0
         processor.num_emitted_tokens = 0
@@ -53,6 +53,7 @@ class _DummyCfg:
         self.speculative_config = types.SimpleNamespace(
             method=speculative_method,
             num_speculative_tokens=2,
+            enable_draft_logprob=True,
         )
         self.model_config = types.SimpleNamespace(enable_logprob=enable_logprob)
         self.scheduler_config = types.SimpleNamespace(name="default", splitwise_role="decode")
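Taken together, `enable_draft_logprob` threads from `SpeculativeConfig` through `MTPSampler`, `MTPProposer`, and `TokenProcessor`, and draft logprobs are only computed when target logprobs are enabled, the new flag is set, and the proposer is on its first substep. A small sketch of the combined gate, using `SimpleNamespace` stand-ins in the style of the `_DummyCfg` fixture above (helper name is illustrative, not from the diff):

    import types

    cfg = types.SimpleNamespace(
        speculative_config=types.SimpleNamespace(method="mtp", enable_draft_logprob=True),
        model_config=types.SimpleNamespace(enable_logprob=True),
    )

    def should_compute_draft_logprobs(cfg, substep: int) -> bool:
        # Mirrors the guards added in MTPProposer and MTPSampler above.
        return (
            cfg.model_config.enable_logprob
            and cfg.speculative_config.enable_draft_logprob
            and substep == 0
        )

    assert should_compute_draft_logprobs(cfg, substep=0)
    assert not should_compute_draft_logprobs(cfg, substep=1)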