mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 08:21:53 +08:00
[Speculative Decoding] Optimize draft logprob (#5842)
* optimize draft logprob * fix ut
This commit is contained in:
@@ -94,6 +94,7 @@ class MTPProposer(Proposer):
|
||||
self.mtp_strategy = self.speculative_config.mtp_strategy
|
||||
self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps
|
||||
self.enable_logprob = self.model_config.enable_logprob
|
||||
self.enable_draft_logprob = self.speculative_config.enable_draft_logprob
|
||||
|
||||
# [mixed, prefill, decoder]
|
||||
self.role = self.scheduler_config.splitwise_role
|
||||
@@ -943,7 +944,7 @@ class MTPProposer(Proposer):
|
||||
|
||||
# 4. Compute logits, Sample
|
||||
logits = self.model.compute_logits(hidden_states)
|
||||
if self.enable_logprob and substep == 0:
|
||||
if self.enable_logprob and self.enable_draft_logprob and substep == 0:
|
||||
first_token_logits = self.model.compute_logits(self.model_inputs["first_token_hidden_states"])
|
||||
|
||||
speculate_get_logits(
|
||||
|
||||
Reference in New Issue
Block a user