[BugFix] Fix token_penalty kernel (#6069)

* fix token_penalty kernel

* try to fix xpu

* fix xpu

* fix unit test
This commit is contained in:
freeliuzc
2026-01-28 12:03:05 +08:00
committed by GitHub
parent 85db063da6
commit ce06c6dfb3
13 changed files with 320 additions and 246 deletions
@@ -512,6 +512,7 @@ class Sampler(nn.Layer):
sampling_metadata.presence_penalties,
sampling_metadata.temperature,
sampling_metadata.bad_words_token_ids,
sampling_metadata.bad_words_token_len,
sampling_metadata.step_idx,
sampling_metadata.min_dec_lens,
sampling_metadata.eos_token_ids,
@@ -752,6 +753,7 @@ class SpeculativeSampler(nn.Layer):
sampling_metadata.presence_penalties,
sampling_metadata.temperature,
sampling_metadata.bad_words_token_ids,
sampling_metadata.bad_words_token_len,
sampling_metadata.step_idx,
sampling_metadata.min_dec_lens,
sampling_metadata.eos_token_ids,
@@ -896,6 +898,7 @@ class SpeculativeSampler(nn.Layer):
sampling_metadata.presence_penalties,
sampling_metadata.temperature,
sampling_metadata.bad_words_token_ids,
sampling_metadata.bad_words_token_len,
sampling_metadata.step_idx,
sampling_metadata.min_dec_lens,
sampling_metadata.eos_token_ids,
@@ -1126,6 +1129,7 @@ class MTPSampler(nn.Layer):
sampling_metadata.presence_penalties,
sampling_metadata.temperature,
sampling_metadata.bad_words_token_ids,
sampling_metadata.bad_words_token_len,
sampling_metadata.step_idx,
sampling_metadata.min_dec_lens,
sampling_metadata.eos_token_ids,
@@ -1178,6 +1182,7 @@ class MTPSampler(nn.Layer):
sampling_metadata.presence_penalties,
sampling_metadata.temperature,
sampling_metadata.bad_words_token_ids,
sampling_metadata.bad_words_token_len,
sampling_metadata.step_idx,
sampling_metadata.min_dec_lens,
sampling_metadata.eos_token_ids,