[Speculative Decoding] Optimize draft logprob (#5842)

* optimize draft logprob

* fix ut
This commit is contained in:
GoldPancake
2025-12-31 13:35:56 +08:00
committed by GitHub
parent 9e45ef7ca9
commit 4e10ae5d99
6 changed files with 25 additions and 9 deletions
+2
View File
@@ -740,6 +740,8 @@ class SpeculativeConfig:
self.num_extra_cache_layer = 0 self.num_extra_cache_layer = 0
self.enable_draft_logprob: bool = False
for key, value in args.items(): for key, value in args.items():
if hasattr(self, key): if hasattr(self, key):
setattr(self, key, value) setattr(self, key, value)
@@ -660,7 +660,11 @@ class SpeculativeSampler(nn.Layer):
top_p_logprob = None top_p_logprob = None
top_p_token_mask = None top_p_token_mask = None
if top_p_normalized_logprobs is not None and share_inputs is not None: if (
top_p_normalized_logprobs is not None
and share_inputs is not None
and sampling_metadata.top_p_normalized_logprobs_flag
):
real_token_top_p = ( real_token_top_p = (
sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(batch_token_num).unsqueeze(1) sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(batch_token_num).unsqueeze(1)
) )
@@ -837,9 +841,9 @@ class SpeculativeSampler(nn.Layer):
logprobs_tensors = None logprobs_tensors = None
token_ids = share_inputs["accept_tokens"] token_ids = share_inputs["accept_tokens"]
if num_logprobs is not None: if num_logprobs is not None:
token_ids = paddle.concat( idx = paddle.arange(share_inputs["accept_tokens"].shape[1], dtype="int32")
[share_inputs["accept_tokens"][i, : share_inputs["accept_num"][i]] for i in range(real_bsz)] mask = idx < share_inputs["accept_num"].unsqueeze(1)
) token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask)
logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids) logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
sampler_output = SamplerOutput( sampler_output = SamplerOutput(
@@ -950,6 +954,7 @@ class MTPSampler(nn.Layer):
else: else:
raise NotImplementedError raise NotImplementedError
self.logprobs_mode = fd_config.model_config.logprobs_mode self.logprobs_mode = fd_config.model_config.logprobs_mode
self.enable_draft_logprob = fd_config.speculative_config.enable_draft_logprob
def pre_process(self, skip_idx_list: List[int] = []): def pre_process(self, skip_idx_list: List[int] = []):
"""pre process before running""" """pre process before running"""
@@ -1002,7 +1007,11 @@ class MTPSampler(nn.Layer):
top_p_logprob = None top_p_logprob = None
top_p_token_mask = None top_p_token_mask = None
if top_p_normalized_logprobs is not None and share_inputs is not None: if (
top_p_normalized_logprobs is not None
and share_inputs is not None
and sampling_metadata.top_p_normalized_logprobs_flag
):
real_token_top_p = ( real_token_top_p = (
sampling_metadata.top_p[:real_bsz] sampling_metadata.top_p[:real_bsz]
.squeeze(1) .squeeze(1)
@@ -1079,7 +1088,7 @@ class MTPSampler(nn.Layer):
""" """ """ """
num_logprobs = sampling_metadata.max_num_logprobs num_logprobs = sampling_metadata.max_num_logprobs
real_bsz = share_inputs["seq_lens_this_time"].shape[0] real_bsz = share_inputs["seq_lens_this_time"].shape[0]
if num_logprobs is not None and share_inputs["substep"] == 0: if self.enable_draft_logprob and num_logprobs is not None and share_inputs["substep"] == 0:
real_token_num = share_inputs["batch_token_num"][:real_bsz].sum() real_token_num = share_inputs["batch_token_num"][:real_bsz].sum()
if self.logprobs_mode == "raw_logprobs": if self.logprobs_mode == "raw_logprobs":
raw_logprobs = self.compute_logprobs( raw_logprobs = self.compute_logprobs(
@@ -1110,7 +1119,7 @@ class MTPSampler(nn.Layer):
token_ids = None token_ids = None
logprobs_tensors = None logprobs_tensors = None
if num_logprobs is not None and share_inputs["substep"] == 0: if self.enable_draft_logprob and num_logprobs is not None and share_inputs["substep"] == 0:
token_ids = paddle.empty(real_token_num, dtype="int64") token_ids = paddle.empty(real_token_num, dtype="int64")
speculate_insert_first_token( speculate_insert_first_token(
token_ids, token_ids,
+2 -1
View File
@@ -77,6 +77,7 @@ class TokenProcessor:
self.speculative_decoding = self.cfg.speculative_config.method is not None self.speculative_decoding = self.cfg.speculative_config.method is not None
self.use_logprobs = self.cfg.model_config.enable_logprob self.use_logprobs = self.cfg.model_config.enable_logprob
self.enable_draft_logprob = self.cfg.speculative_config.enable_draft_logprob
if self.speculative_decoding: if self.speculative_decoding:
if self.use_logprobs: if self.use_logprobs:
@@ -446,7 +447,7 @@ class TokenProcessor:
batch_result (list): batch results batch_result (list): batch results
""" """
try: try:
if self.cfg.speculative_config.method and self.use_logprobs: if self.cfg.speculative_config.method and self.use_logprobs and self.enable_draft_logprob:
if mtype == 3: # target if mtype == 3: # target
finished_batch_result, unfinished_batch_result = [], [] finished_batch_result, unfinished_batch_result = [], []
for r in batch_result: for r in batch_result:
+2 -1
View File
@@ -94,6 +94,7 @@ class MTPProposer(Proposer):
self.mtp_strategy = self.speculative_config.mtp_strategy self.mtp_strategy = self.speculative_config.mtp_strategy
self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps
self.enable_logprob = self.model_config.enable_logprob self.enable_logprob = self.model_config.enable_logprob
self.enable_draft_logprob = self.speculative_config.enable_draft_logprob
# [mixed, prefill, decoder] # [mixed, prefill, decoder]
self.role = self.scheduler_config.splitwise_role self.role = self.scheduler_config.splitwise_role
@@ -943,7 +944,7 @@ class MTPProposer(Proposer):
# 4. Compute logits, Sample # 4. Compute logits, Sample
logits = self.model.compute_logits(hidden_states) logits = self.model.compute_logits(hidden_states)
if self.enable_logprob and substep == 0: if self.enable_logprob and self.enable_draft_logprob and substep == 0:
first_token_logits = self.model.compute_logits(self.model_inputs["first_token_hidden_states"]) first_token_logits = self.model.compute_logits(self.model_inputs["first_token_hidden_states"])
speculate_get_logits( speculate_get_logits(
@@ -122,6 +122,7 @@ class TestTokenProcessorProcessBatchOutput(unittest.TestCase):
cfg.speculative_config.method = "mtp" if speculative_decoding else None cfg.speculative_config.method = "mtp" if speculative_decoding else None
cfg.speculative_config.num_speculative_tokens = 1 cfg.speculative_config.num_speculative_tokens = 1
cfg.model_config.enable_logprob = use_logprobs cfg.model_config.enable_logprob = use_logprobs
cfg.speculative_config.enable_draft_logprob = True
processor = TokenProcessor.__new__(TokenProcessor) processor = TokenProcessor.__new__(TokenProcessor)
processor.cfg = cfg processor.cfg = cfg
@@ -139,6 +140,7 @@ class TestTokenProcessorProcessBatchOutput(unittest.TestCase):
processor.number_of_output_tokens = 0 processor.number_of_output_tokens = 0
processor.prefill_result_status = {} processor.prefill_result_status = {}
processor.use_logprobs = use_logprobs processor.use_logprobs = use_logprobs
processor.enable_draft_logprob = cfg.speculative_config.enable_draft_logprob
processor.num_draft_tokens = 0 processor.num_draft_tokens = 0
processor.num_accepted_tokens = 0 processor.num_accepted_tokens = 0
processor.num_emitted_tokens = 0 processor.num_emitted_tokens = 0
+1
View File
@@ -53,6 +53,7 @@ class _DummyCfg:
self.speculative_config = types.SimpleNamespace( self.speculative_config = types.SimpleNamespace(
method=speculative_method, method=speculative_method,
num_speculative_tokens=2, num_speculative_tokens=2,
enable_draft_logprob=True,
) )
self.model_config = types.SimpleNamespace(enable_logprob=enable_logprob) self.model_config = types.SimpleNamespace(enable_logprob=enable_logprob)
self.scheduler_config = types.SimpleNamespace(name="default", splitwise_role="decode") self.scheduler_config = types.SimpleNamespace(name="default", splitwise_role="decode")