[Speculative Decoding] Optimize draft logprob (#5842)

* optimize draft logprob

* fix ut
This commit is contained in:
GoldPancake
2025-12-31 13:35:56 +08:00
committed by GitHub
parent 9e45ef7ca9
commit 4e10ae5d99
6 changed files with 25 additions and 9 deletions
+2 -1
View File
@@ -94,6 +94,7 @@ class MTPProposer(Proposer):
self.mtp_strategy = self.speculative_config.mtp_strategy
self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps
self.enable_logprob = self.model_config.enable_logprob
self.enable_draft_logprob = self.speculative_config.enable_draft_logprob
# [mixed, prefill, decoder]
self.role = self.scheduler_config.splitwise_role
@@ -943,7 +944,7 @@ class MTPProposer(Proposer):
# 4. Compute logits, Sample
logits = self.model.compute_logits(hidden_states)
if self.enable_logprob and substep == 0:
if self.enable_logprob and self.enable_draft_logprob and substep == 0:
first_token_logits = self.model.compute_logits(self.model_inputs["first_token_hidden_states"])
speculate_get_logits(