[BugFix] Fix token_penalty kernel (#6069)

* fix token_penalty kernel (add a bad-words token length tensor to the penalty functions and forward it to the penalty ops)

* try to fix xpu

* fix xpu

* fix unit test
This commit is contained in:
freeliuzc
2026-01-28 12:03:05 +08:00
committed by GitHub
parent 85db063da6
commit ce06c6dfb3
13 changed files with 320 additions and 246 deletions
@@ -29,6 +29,7 @@ def apply_penalty_multi_scores(
presence_penalties: paddle.Tensor,
temperature: paddle.Tensor,
bad_words_token_ids: paddle.Tensor,
bad_words_token_len: paddle.Tensor,
step_idx: paddle.Tensor,
min_dec_lens: paddle.Tensor,
eos_token_ids: paddle.Tensor,
@@ -49,6 +50,7 @@ def apply_penalty_multi_scores(
presence_penalties,
temperature,
bad_words_token_ids,
bad_words_token_len,
step_idx,
min_dec_lens,
eos_token_ids,
@@ -167,6 +169,7 @@ def apply_speculative_penalty_multi_scores(
presence_penalties: paddle.Tensor,
temperature: paddle.Tensor,
bad_words_token_ids: paddle.Tensor,
bad_tokens_len: paddle.Tensor,
step_idx: paddle.Tensor,
min_dec_lens: paddle.Tensor,
eos_token_ids: paddle.Tensor,
@@ -182,29 +185,49 @@ def apply_speculative_penalty_multi_scores(
from fastdeploy.model_executor.ops.gpu import (
speculate_get_token_penalty_multi_scores,
)
speculate_get_token_penalty_multi_scores(
pre_token_ids,
logits,
repetition_penalties,
frequency_penalties,
presence_penalties,
temperature,
bad_words_token_ids,
bad_tokens_len,
step_idx,
min_dec_lens,
eos_token_ids,
seq_lens_this_time,
output_padding_offset,
output_cum_offsets,
max_len,
)
elif current_platform.is_xpu():
from fastdeploy.model_executor.ops.xpu import (
speculate_get_token_penalty_multi_scores,
)
speculate_get_token_penalty_multi_scores(
pre_token_ids,
logits,
repetition_penalties,
frequency_penalties,
presence_penalties,
temperature,
bad_words_token_ids,
step_idx,
min_dec_lens,
eos_token_ids,
seq_lens_this_time,
output_padding_offset,
output_cum_offsets,
max_len,
)
else:
raise NotImplementedError
speculate_get_token_penalty_multi_scores(
pre_token_ids,
logits,
repetition_penalties,
frequency_penalties,
presence_penalties,
temperature,
bad_words_token_ids,
step_idx,
min_dec_lens,
eos_token_ids,
seq_lens_this_time,
output_padding_offset,
output_cum_offsets,
max_len,
)
# inplace
return logits