[BugFix] Fix token_penalty kernel (#6069)

* fix token_penalty kernel

* try to fix xpu

* fix xpu

* fix unit test
This commit is contained in:
freeliuzc
2026-01-28 12:03:05 +08:00
committed by GitHub
parent 85db063da6
commit ce06c6dfb3
13 changed files with 320 additions and 246 deletions
+2
View File
@@ -507,6 +507,7 @@ class MTPProposer(Proposer):
self.model_inputs["min_dec_len"] = self.target_model_inputs["min_dec_len"]
self.model_inputs["bad_tokens"] = self.target_model_inputs["bad_tokens"]
self.model_inputs["bad_tokens_len"] = self.target_model_inputs["bad_tokens_len"]
# Integrate the updated results in model forward
self.model_inputs["base_model_draft_tokens"] = self.target_model_inputs["draft_tokens"]
@@ -1000,6 +1001,7 @@ class MTPProposer(Proposer):
repetition_penalties=self.model_inputs["penalty_score"],
min_dec_lens=self.model_inputs["min_dec_len"],
bad_words_token_ids=self.model_inputs["bad_tokens"],
bad_words_token_len=self.model_inputs["bad_tokens_len"],
eos_token_ids=self.model_inputs["eos_token_id"],
max_num_logprobs=20 if self.enable_logprob else None,
temp_scaled_logprobs=self.model_inputs["temp_scaled_logprobs"],