diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 6892996290..a5079f5e66 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -1014,7 +1014,9 @@ class MTPProposer(Proposer): # 4. Compute logits, Sample logits = self.model.compute_logits(hidden_states, forward_meta=self.forward_meta) if self.enable_logprob and self.enable_draft_logprob and substep == 0: - first_token_logits = self.model.compute_logits(self.model_inputs["first_token_hidden_states"]) + first_token_logits = self.model.compute_logits( + self.model_inputs["first_token_hidden_states"], forward_meta=self.forward_meta + ) speculate_get_logits( self.model_inputs["draft_logits"],