mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-05-06 15:40:33 +08:00
[BugFix] Fix token_penalty kernel (#6069)
* fix token_penalty kernel
* try to fix xpu
* fix xpu
* fix unit test
This commit is contained in:
@@ -43,6 +43,8 @@ class SamplingMetadata:
|
||||
step_idx: paddle.Tensor
|
||||
|
||||
top_p: paddle.Tensor
|
||||
# only GPU used
|
||||
bad_words_token_len: Optional[paddle.Tensor] = None
|
||||
top_k: Optional[paddle.Tensor] = None
|
||||
top_k_list: Optional[list] = None
|
||||
min_p: Optional[paddle.Tensor] = None
|
||||
|
||||
@@ -29,6 +29,7 @@ def apply_penalty_multi_scores(
|
||||
presence_penalties: paddle.Tensor,
|
||||
temperature: paddle.Tensor,
|
||||
bad_words_token_ids: paddle.Tensor,
|
||||
bad_words_token_len: paddle.Tensor,
|
||||
step_idx: paddle.Tensor,
|
||||
min_dec_lens: paddle.Tensor,
|
||||
eos_token_ids: paddle.Tensor,
|
||||
@@ -49,6 +50,7 @@ def apply_penalty_multi_scores(
|
||||
presence_penalties,
|
||||
temperature,
|
||||
bad_words_token_ids,
|
||||
bad_words_token_len,
|
||||
step_idx,
|
||||
min_dec_lens,
|
||||
eos_token_ids,
|
||||
@@ -167,6 +169,7 @@ def apply_speculative_penalty_multi_scores(
|
||||
presence_penalties: paddle.Tensor,
|
||||
temperature: paddle.Tensor,
|
||||
bad_words_token_ids: paddle.Tensor,
|
||||
bad_tokens_len: paddle.Tensor,
|
||||
step_idx: paddle.Tensor,
|
||||
min_dec_lens: paddle.Tensor,
|
||||
eos_token_ids: paddle.Tensor,
|
||||
@@ -182,29 +185,49 @@ def apply_speculative_penalty_multi_scores(
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
speculate_get_token_penalty_multi_scores,
|
||||
)
|
||||
|
||||
speculate_get_token_penalty_multi_scores(
|
||||
pre_token_ids,
|
||||
logits,
|
||||
repetition_penalties,
|
||||
frequency_penalties,
|
||||
presence_penalties,
|
||||
temperature,
|
||||
bad_words_token_ids,
|
||||
bad_tokens_len,
|
||||
step_idx,
|
||||
min_dec_lens,
|
||||
eos_token_ids,
|
||||
seq_lens_this_time,
|
||||
output_padding_offset,
|
||||
output_cum_offsets,
|
||||
max_len,
|
||||
)
|
||||
elif current_platform.is_xpu():
|
||||
from fastdeploy.model_executor.ops.xpu import (
|
||||
speculate_get_token_penalty_multi_scores,
|
||||
)
|
||||
|
||||
speculate_get_token_penalty_multi_scores(
|
||||
pre_token_ids,
|
||||
logits,
|
||||
repetition_penalties,
|
||||
frequency_penalties,
|
||||
presence_penalties,
|
||||
temperature,
|
||||
bad_words_token_ids,
|
||||
step_idx,
|
||||
min_dec_lens,
|
||||
eos_token_ids,
|
||||
seq_lens_this_time,
|
||||
output_padding_offset,
|
||||
output_cum_offsets,
|
||||
max_len,
|
||||
)
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
speculate_get_token_penalty_multi_scores(
|
||||
pre_token_ids,
|
||||
logits,
|
||||
repetition_penalties,
|
||||
frequency_penalties,
|
||||
presence_penalties,
|
||||
temperature,
|
||||
bad_words_token_ids,
|
||||
step_idx,
|
||||
min_dec_lens,
|
||||
eos_token_ids,
|
||||
seq_lens_this_time,
|
||||
output_padding_offset,
|
||||
output_cum_offsets,
|
||||
max_len,
|
||||
)
|
||||
|
||||
# inplace
|
||||
return logits
|
||||
|
||||
|
||||
@@ -512,6 +512,7 @@ class Sampler(nn.Layer):
|
||||
sampling_metadata.presence_penalties,
|
||||
sampling_metadata.temperature,
|
||||
sampling_metadata.bad_words_token_ids,
|
||||
sampling_metadata.bad_words_token_len,
|
||||
sampling_metadata.step_idx,
|
||||
sampling_metadata.min_dec_lens,
|
||||
sampling_metadata.eos_token_ids,
|
||||
@@ -752,6 +753,7 @@ class SpeculativeSampler(nn.Layer):
|
||||
sampling_metadata.presence_penalties,
|
||||
sampling_metadata.temperature,
|
||||
sampling_metadata.bad_words_token_ids,
|
||||
sampling_metadata.bad_words_token_len,
|
||||
sampling_metadata.step_idx,
|
||||
sampling_metadata.min_dec_lens,
|
||||
sampling_metadata.eos_token_ids,
|
||||
@@ -896,6 +898,7 @@ class SpeculativeSampler(nn.Layer):
|
||||
sampling_metadata.presence_penalties,
|
||||
sampling_metadata.temperature,
|
||||
sampling_metadata.bad_words_token_ids,
|
||||
sampling_metadata.bad_words_token_len,
|
||||
sampling_metadata.step_idx,
|
||||
sampling_metadata.min_dec_lens,
|
||||
sampling_metadata.eos_token_ids,
|
||||
@@ -1126,6 +1129,7 @@ class MTPSampler(nn.Layer):
|
||||
sampling_metadata.presence_penalties,
|
||||
sampling_metadata.temperature,
|
||||
sampling_metadata.bad_words_token_ids,
|
||||
sampling_metadata.bad_words_token_len,
|
||||
sampling_metadata.step_idx,
|
||||
sampling_metadata.min_dec_lens,
|
||||
sampling_metadata.eos_token_ids,
|
||||
@@ -1178,6 +1182,7 @@ class MTPSampler(nn.Layer):
|
||||
sampling_metadata.presence_penalties,
|
||||
sampling_metadata.temperature,
|
||||
sampling_metadata.bad_words_token_ids,
|
||||
sampling_metadata.bad_words_token_len,
|
||||
sampling_metadata.step_idx,
|
||||
sampling_metadata.min_dec_lens,
|
||||
sampling_metadata.eos_token_ids,
|
||||
|
||||
Reference in New Issue
Block a user