support_lastnorm_gather_split_r2.4 (#5925)

* support_lastnorm_gather_split_r2.4

* support_lastnorm_gather_split_r2.4v1

* support_lastnorm_gather_split_r2.4v2
This commit is contained in:
xiaoluomi
2026-01-09 11:29:59 +08:00
committed by GitHub
parent 741a01562b
commit f12b7a7a19
9 changed files with 31 additions and 8 deletions
+2 -2
View File
@@ -1012,7 +1012,7 @@ class MTPProposer(Proposer):
)
# 4. Compute logits, Sample
-logits = self.model.compute_logits(hidden_states)
+logits = self.model.compute_logits(hidden_states, forward_meta=self.forward_meta)
if self.enable_logprob and self.enable_draft_logprob and substep == 0:
first_token_logits = self.model.compute_logits(self.model_inputs["first_token_hidden_states"])
@@ -1125,7 +1125,7 @@ class MTPProposer(Proposer):
model_output, self.model_inputs["cum_offsets"], self.forward_meta, self.model_inputs
)
# 4. Compute logits, Sample
-logits = self.model.compute_logits(hidden_states)
+logits = self.model.compute_logits(hidden_states, forward_meta=self.forward_meta)
sampled_token_ids, sampler_output = self.sampler(
logits,
self.sampling_metadata,