[Cherry-Pick][Speculative Decoding][BugFix] Fix apply repeat times penalty kernel and change spec default verify strategy(#7467) (#7468)

* fix repeat_time kernel and change default spec verify strategy

* fix unit_test
This commit is contained in:
freeliuzc
2026-04-18 00:07:34 +08:00
committed by GitHub
parent 650d1e49aa
commit 56b761de3f
4 changed files with 81 additions and 83 deletions
@@ -16,12 +16,12 @@
template <typename T> template <typename T>
__global__ inline void min_length_logits_process( __global__ inline void min_length_logits_process(
T *logits, T* logits,
const int64_t *cur_len, const int64_t* cur_len,
const int64_t *min_len, const int64_t* min_len,
const int64_t *eos_token_id, const int64_t* eos_token_id,
const int *batch_id_per_token_output, const int* batch_id_per_token_output,
const int *cu_seqlens_q_output, const int* cu_seqlens_q_output,
const int64_t token_num, const int64_t token_num,
const int64_t bs, const int64_t bs,
const int64_t length, const int64_t length,
@@ -46,12 +46,12 @@ __global__ inline void min_length_logits_process(
template <> template <>
__global__ inline void min_length_logits_process<half>( __global__ inline void min_length_logits_process<half>(
half *logits, half* logits,
const int64_t *cur_len, const int64_t* cur_len,
const int64_t *min_len, const int64_t* min_len,
const int64_t *eos_token_id, const int64_t* eos_token_id,
const int *batch_id_per_token_output, const int* batch_id_per_token_output,
const int *cu_seqlens_q_output, const int* cu_seqlens_q_output,
const int64_t token_num, const int64_t token_num,
const int64_t bs, const int64_t bs,
const int64_t length, const int64_t length,
@@ -74,11 +74,11 @@ __global__ inline void min_length_logits_process<half>(
} }
} }
__global__ void update_repeat_times(const int64_t *token_ids_all, __global__ void update_repeat_times(const int64_t* token_ids_all,
const int64_t *prompt_lens, const int64_t* prompt_lens,
const int64_t *cur_len, const int64_t* cur_len,
int *repeat_times, int* repeat_times,
const int *batch_id_per_token_output, const int* batch_id_per_token_output,
const int64_t token_num, const int64_t token_num,
const int64_t bs, const int64_t bs,
const int64_t length, const int64_t length,
@@ -93,9 +93,9 @@ __global__ void update_repeat_times(const int64_t *token_ids_all,
return; return;
} }
int tid = threadIdx.x; int tid = threadIdx.x;
const int64_t *pre_ids_now = token_ids_all + bi * length_id + prompt_lens[bi]; const int64_t* pre_ids_now = token_ids_all + bi * length_id + prompt_lens[bi];
int *repeat_times_now = repeat_times + token_idx * length; int* repeat_times_now = repeat_times + token_idx * length;
for (int i = tid; i < length_id; i += blockDim.x) { for (int i = tid; i < cur_len[bi]; i += blockDim.x) {
int64_t id = pre_ids_now[i]; int64_t id = pre_ids_now[i];
if (id < 0) break; if (id < 0) break;
atomicAdd(&repeat_times_now[id], 1); atomicAdd(&repeat_times_now[id], 1);
@@ -104,13 +104,13 @@ __global__ void update_repeat_times(const int64_t *token_ids_all,
template <typename T> template <typename T>
__global__ void update_value_by_repeat_times( __global__ void update_value_by_repeat_times(
const int *repeat_times, const int* repeat_times,
const T *penalty_scores, const T* penalty_scores,
const T *frequency_score, const T* frequency_score,
const T *presence_score, const T* presence_score,
const float *temperatures, const float* temperatures,
T *logits, T* logits,
const int *batch_id_per_token_output, const int* batch_id_per_token_output,
const int64_t token_num, const int64_t token_num,
const int64_t bs, const int64_t bs,
const int64_t length, const int64_t length,
@@ -121,8 +121,8 @@ __global__ void update_value_by_repeat_times(
if (bi < 0) return; if (bi < 0) return;
if (bi >= bs) return; if (bi >= bs) return;
int tid = threadIdx.x; int tid = threadIdx.x;
T *logits_now = logits + token_idx * length; T* logits_now = logits + token_idx * length;
const int *repeat_times_now = repeat_times + token_idx * length; const int* repeat_times_now = repeat_times + token_idx * length;
float alpha = static_cast<float>(penalty_scores[bi]); float alpha = static_cast<float>(penalty_scores[bi]);
float beta = static_cast<float>(frequency_score[bi]); float beta = static_cast<float>(frequency_score[bi]);
float gamma = static_cast<float>(presence_score[bi]); float gamma = static_cast<float>(presence_score[bi]);
@@ -138,10 +138,10 @@ __global__ void update_value_by_repeat_times(
} }
template <typename T> template <typename T>
__global__ void ban_bad_words(T *logits, __global__ void ban_bad_words(T* logits,
const int64_t *bad_tokens, const int64_t* bad_tokens,
const int64_t *bad_tokens_len, const int64_t* bad_tokens_len,
const int *batch_id_per_token_output, const int* batch_id_per_token_output,
const int64_t token_num, const int64_t token_num,
const int64_t bs, const int64_t bs,
const int64_t length, const int64_t length,
@@ -153,8 +153,8 @@ __global__ void ban_bad_words(T *logits,
if (bi < 0) return; if (bi < 0) return;
if (bi >= bs) return; if (bi >= bs) return;
int tid = threadIdx.x; int tid = threadIdx.x;
T *logits_now = logits + token_idx * length; T* logits_now = logits + token_idx * length;
const int64_t *bad_tokens_now = bad_tokens + bi * bad_words_length; const int64_t* bad_tokens_now = bad_tokens + bi * bad_words_length;
const int32_t bad_token_len = const int32_t bad_token_len =
static_cast<int32_t>(min(bad_tokens_len[bi], bad_words_length)); static_cast<int32_t>(min(bad_tokens_len[bi], bad_words_length));
for (int i = tid; i < bad_token_len; i += blockDim.x) { for (int i = tid; i < bad_token_len; i += blockDim.x) {
@@ -166,21 +166,21 @@ __global__ void ban_bad_words(T *logits,
template <paddle::DataType D> template <paddle::DataType D>
void token_penalty_multi_scores_kernel( void token_penalty_multi_scores_kernel(
const paddle::Tensor &token_ids_all, const paddle::Tensor& token_ids_all,
const paddle::Tensor &prompt_lens, const paddle::Tensor& prompt_lens,
const paddle::Tensor &logits, const paddle::Tensor& logits,
const paddle::Tensor &penalty_scores, const paddle::Tensor& penalty_scores,
const paddle::Tensor &frequency_score, const paddle::Tensor& frequency_score,
const paddle::Tensor &presence_score, const paddle::Tensor& presence_score,
const paddle::Tensor &temperatures, const paddle::Tensor& temperatures,
const paddle::Tensor &bad_tokens, const paddle::Tensor& bad_tokens,
const paddle::Tensor &bad_tokens_len, const paddle::Tensor& bad_tokens_len,
const paddle::Tensor &cur_len, const paddle::Tensor& cur_len,
const paddle::Tensor &min_len, const paddle::Tensor& min_len,
const paddle::Tensor &eos_token_id, const paddle::Tensor& eos_token_id,
const paddle::Tensor &seq_lens_this_time, const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor &batch_id_per_token_output, const paddle::Tensor& batch_id_per_token_output,
const paddle::Tensor &cu_seqlens_q_output, const paddle::Tensor& cu_seqlens_q_output,
const int max_seq_len) { const int max_seq_len) {
typedef PDTraits<D> traits_; typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_; typedef typename traits_::DataType DataType_;
@@ -198,8 +198,7 @@ void token_penalty_multi_scores_kernel(
int64_t end_length = eos_token_id.shape()[0]; int64_t end_length = eos_token_id.shape()[0];
int block_size = (token_num + 32 - 1) / 32 * 32; int block_size = (token_num + 32 - 1) / 32 * 32;
min_length_logits_process<<<1, block_size, 0, cu_stream>>>( min_length_logits_process<<<1, block_size, 0, cu_stream>>>(
reinterpret_cast<DataType_ *>( reinterpret_cast<DataType_*>(const_cast<data_t*>(logits.data<data_t>())),
const_cast<data_t *>(logits.data<data_t>())),
cur_len.data<int64_t>(), cur_len.data<int64_t>(),
min_len.data<int64_t>(), min_len.data<int64_t>(),
eos_token_id.data<int64_t>(), eos_token_id.data<int64_t>(),
@@ -230,15 +229,15 @@ void token_penalty_multi_scores_kernel(
update_value_by_repeat_times<DataType_> update_value_by_repeat_times<DataType_>
<<<token_num, block_size, 0, cu_stream>>>( <<<token_num, block_size, 0, cu_stream>>>(
repeat_times.data<int>(), repeat_times.data<int>(),
reinterpret_cast<DataType_ *>( reinterpret_cast<DataType_*>(
const_cast<data_t *>(penalty_scores.data<data_t>())), const_cast<data_t*>(penalty_scores.data<data_t>())),
reinterpret_cast<DataType_ *>( reinterpret_cast<DataType_*>(
const_cast<data_t *>(frequency_score.data<data_t>())), const_cast<data_t*>(frequency_score.data<data_t>())),
reinterpret_cast<DataType_ *>( reinterpret_cast<DataType_*>(
const_cast<data_t *>(presence_score.data<data_t>())), const_cast<data_t*>(presence_score.data<data_t>())),
temperatures.data<float>(), temperatures.data<float>(),
reinterpret_cast<DataType_ *>( reinterpret_cast<DataType_*>(
const_cast<data_t *>(logits.data<data_t>())), const_cast<data_t*>(logits.data<data_t>())),
batch_id_per_token_output.data<int>(), batch_id_per_token_output.data<int>(),
token_num, token_num,
bs, bs,
@@ -247,8 +246,7 @@ void token_penalty_multi_scores_kernel(
block_size = (length_bad_words + 32 - 1) / 32 * 32; block_size = (length_bad_words + 32 - 1) / 32 * 32;
block_size = min(block_size, 512); block_size = min(block_size, 512);
ban_bad_words<DataType_><<<token_num, block_size, 0, cu_stream>>>( ban_bad_words<DataType_><<<token_num, block_size, 0, cu_stream>>>(
reinterpret_cast<DataType_ *>( reinterpret_cast<DataType_*>(const_cast<data_t*>(logits.data<data_t>())),
const_cast<data_t *>(logits.data<data_t>())),
bad_tokens.data<int64_t>(), bad_tokens.data<int64_t>(),
bad_tokens_len.data<int64_t>(), bad_tokens_len.data<int64_t>(),
batch_id_per_token_output.data<int>(), batch_id_per_token_output.data<int>(),
@@ -260,21 +258,21 @@ void token_penalty_multi_scores_kernel(
} }
void SpecTokenPenaltyMultiScores( void SpecTokenPenaltyMultiScores(
const paddle::Tensor &token_ids_all, const paddle::Tensor& token_ids_all,
const paddle::Tensor &prompt_lens, const paddle::Tensor& prompt_lens,
const paddle::Tensor &logits, const paddle::Tensor& logits,
const paddle::Tensor &penalty_scores, const paddle::Tensor& penalty_scores,
const paddle::Tensor &frequency_scores, const paddle::Tensor& frequency_scores,
const paddle::Tensor &presence_scores, const paddle::Tensor& presence_scores,
const paddle::Tensor &temperatures, const paddle::Tensor& temperatures,
const paddle::Tensor &bad_tokens, const paddle::Tensor& bad_tokens,
const paddle::Tensor &bad_tokens_len, const paddle::Tensor& bad_tokens_len,
const paddle::Tensor &cur_len, const paddle::Tensor& cur_len,
const paddle::Tensor &min_len, const paddle::Tensor& min_len,
const paddle::Tensor &eos_token_id, const paddle::Tensor& eos_token_id,
const paddle::Tensor &seq_lens_this_time, const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor &batch_id_per_token_output, const paddle::Tensor& batch_id_per_token_output,
const paddle::Tensor &cu_seqlens_q_output, const paddle::Tensor& cu_seqlens_q_output,
const int max_seq_len) { const int max_seq_len) {
switch (logits.type()) { switch (logits.type()) {
case paddle::DataType::BFLOAT16: { case paddle::DataType::BFLOAT16: {
+1 -1
View File
@@ -774,7 +774,7 @@ class SpeculativeConfig:
"benchmark_mode": False, "benchmark_mode": False,
"enf_gen_phase_tag": False, "enf_gen_phase_tag": False,
"enable_draft_logprob": False, "enable_draft_logprob": False,
"verify_strategy": "topp", "verify_strategy": "target_match",
"accept_policy": "normal", "accept_policy": "normal",
} }
+4 -4
View File
@@ -97,12 +97,12 @@ def _create_default_sampling_metadata(
return fake_sampling_metadata return fake_sampling_metadata
def _create_fd_config(max_model_len, method=None): def _create_fd_config(max_model_len, method=None, verify_strategy="topp"):
model_config: Mock = Mock() model_config: Mock = Mock()
model_config.max_model_len = max_model_len model_config.max_model_len = max_model_len
model_config.architectures = ["test_model"] model_config.architectures = ["test_model"]
model_config.mm_max_tokens_per_item = None model_config.mm_max_tokens_per_item = None
speculative_config = SpeculativeConfig({"method": method} if method else {}) speculative_config = SpeculativeConfig({"method": method, "verify_strategy": verify_strategy})
graph_opt_config = GraphOptimizationConfig({}) graph_opt_config = GraphOptimizationConfig({})
scheduler_config = SchedulerConfig({}) scheduler_config = SchedulerConfig({})
parallel_config = ParallelConfig({}) parallel_config = ParallelConfig({})
@@ -187,7 +187,7 @@ def test_speculative_sampler():
max_draft_token_num = 1 max_draft_token_num = 1
# Use ngram method for speculative decoding # Use ngram method for speculative decoding
fd_config = _create_fd_config(max_model_len, method="ngram") fd_config = _create_fd_config(max_model_len, method="ngram", verify_strategy="topp")
sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len) sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len)
logits = _create_fake_logits(batch_size * (max_draft_token_num + 1), vocab_size) logits = _create_fake_logits(batch_size * (max_draft_token_num + 1), vocab_size)
share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size) share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size)
@@ -208,7 +208,7 @@ def test_speculative_sampler_logprobs():
max_draft_token_num = 1 max_draft_token_num = 1
# Use ngram method for speculative decoding # Use ngram method for speculative decoding
fd_config = _create_fd_config(max_model_len, method="ngram") fd_config = _create_fd_config(max_model_len, method="ngram", verify_strategy="topp")
share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size) share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size)
sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len, max_num_logprobs=0) sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len, max_num_logprobs=0)
sampling_metadata.share_inputs = share_inputs sampling_metadata.share_inputs = share_inputs
@@ -61,7 +61,7 @@ def update_repeat_times(
token_ids_all_now = token_ids_all[bi] token_ids_all_now = token_ids_all[bi]
repeat_times_now = repeat_times[token_idx] repeat_times_now = repeat_times[token_idx]
for i in range(length_id): for i in range(cur_len[bi]):
id = token_ids_all_now[i] id = token_ids_all_now[i]
if id < 0: if id < 0:
break break