[Speculative Decoding][BugFix] Fix the kernel that applies the repeat-times penalty and change the default spec verify strategy (#7467)

* Fix the repeat_times kernel (update_repeat_times) and change the default spec verify strategy

* Fix the unit test
This commit is contained in:
freeliuzc
2026-04-18 00:38:01 +08:00
committed by GitHub
parent df3b4e12f4
commit 22a4f6019d
4 changed files with 81 additions and 83 deletions
@@ -16,12 +16,12 @@
template <typename T>
__global__ inline void min_length_logits_process(
T *logits,
const int64_t *cur_len,
const int64_t *min_len,
const int64_t *eos_token_id,
const int *batch_id_per_token_output,
const int *cu_seqlens_q_output,
T* logits,
const int64_t* cur_len,
const int64_t* min_len,
const int64_t* eos_token_id,
const int* batch_id_per_token_output,
const int* cu_seqlens_q_output,
const int64_t token_num,
const int64_t bs,
const int64_t length,
@@ -46,12 +46,12 @@ __global__ inline void min_length_logits_process(
template <>
__global__ inline void min_length_logits_process<half>(
half *logits,
const int64_t *cur_len,
const int64_t *min_len,
const int64_t *eos_token_id,
const int *batch_id_per_token_output,
const int *cu_seqlens_q_output,
half* logits,
const int64_t* cur_len,
const int64_t* min_len,
const int64_t* eos_token_id,
const int* batch_id_per_token_output,
const int* cu_seqlens_q_output,
const int64_t token_num,
const int64_t bs,
const int64_t length,
@@ -74,11 +74,11 @@ __global__ inline void min_length_logits_process<half>(
}
}
__global__ void update_repeat_times(const int64_t *token_ids_all,
const int64_t *prompt_lens,
const int64_t *cur_len,
int *repeat_times,
const int *batch_id_per_token_output,
__global__ void update_repeat_times(const int64_t* token_ids_all,
const int64_t* prompt_lens,
const int64_t* cur_len,
int* repeat_times,
const int* batch_id_per_token_output,
const int64_t token_num,
const int64_t bs,
const int64_t length,
@@ -93,9 +93,9 @@ __global__ void update_repeat_times(const int64_t *token_ids_all,
return;
}
int tid = threadIdx.x;
const int64_t *pre_ids_now = token_ids_all + bi * length_id + prompt_lens[bi];
int *repeat_times_now = repeat_times + token_idx * length;
for (int i = tid; i < length_id; i += blockDim.x) {
const int64_t* pre_ids_now = token_ids_all + bi * length_id + prompt_lens[bi];
int* repeat_times_now = repeat_times + token_idx * length;
for (int i = tid; i < cur_len[bi]; i += blockDim.x) {
int64_t id = pre_ids_now[i];
if (id < 0) break;
atomicAdd(&repeat_times_now[id], 1);
@@ -104,13 +104,13 @@ __global__ void update_repeat_times(const int64_t *token_ids_all,
template <typename T>
__global__ void update_value_by_repeat_times(
const int *repeat_times,
const T *penalty_scores,
const T *frequency_score,
const T *presence_score,
const float *temperatures,
T *logits,
const int *batch_id_per_token_output,
const int* repeat_times,
const T* penalty_scores,
const T* frequency_score,
const T* presence_score,
const float* temperatures,
T* logits,
const int* batch_id_per_token_output,
const int64_t token_num,
const int64_t bs,
const int64_t length,
@@ -121,8 +121,8 @@ __global__ void update_value_by_repeat_times(
if (bi < 0) return;
if (bi >= bs) return;
int tid = threadIdx.x;
T *logits_now = logits + token_idx * length;
const int *repeat_times_now = repeat_times + token_idx * length;
T* logits_now = logits + token_idx * length;
const int* repeat_times_now = repeat_times + token_idx * length;
float alpha = static_cast<float>(penalty_scores[bi]);
float beta = static_cast<float>(frequency_score[bi]);
float gamma = static_cast<float>(presence_score[bi]);
@@ -138,10 +138,10 @@ __global__ void update_value_by_repeat_times(
}
template <typename T>
__global__ void ban_bad_words(T *logits,
const int64_t *bad_tokens,
const int64_t *bad_tokens_len,
const int *batch_id_per_token_output,
__global__ void ban_bad_words(T* logits,
const int64_t* bad_tokens,
const int64_t* bad_tokens_len,
const int* batch_id_per_token_output,
const int64_t token_num,
const int64_t bs,
const int64_t length,
@@ -153,8 +153,8 @@ __global__ void ban_bad_words(T *logits,
if (bi < 0) return;
if (bi >= bs) return;
int tid = threadIdx.x;
T *logits_now = logits + token_idx * length;
const int64_t *bad_tokens_now = bad_tokens + bi * bad_words_length;
T* logits_now = logits + token_idx * length;
const int64_t* bad_tokens_now = bad_tokens + bi * bad_words_length;
const int32_t bad_token_len =
static_cast<int32_t>(min(bad_tokens_len[bi], bad_words_length));
for (int i = tid; i < bad_token_len; i += blockDim.x) {
@@ -166,21 +166,21 @@ __global__ void ban_bad_words(T *logits,
template <paddle::DataType D>
void token_penalty_multi_scores_kernel(
const paddle::Tensor &token_ids_all,
const paddle::Tensor &prompt_lens,
const paddle::Tensor &logits,
const paddle::Tensor &penalty_scores,
const paddle::Tensor &frequency_score,
const paddle::Tensor &presence_score,
const paddle::Tensor &temperatures,
const paddle::Tensor &bad_tokens,
const paddle::Tensor &bad_tokens_len,
const paddle::Tensor &cur_len,
const paddle::Tensor &min_len,
const paddle::Tensor &eos_token_id,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &batch_id_per_token_output,
const paddle::Tensor &cu_seqlens_q_output,
const paddle::Tensor& token_ids_all,
const paddle::Tensor& prompt_lens,
const paddle::Tensor& logits,
const paddle::Tensor& penalty_scores,
const paddle::Tensor& frequency_score,
const paddle::Tensor& presence_score,
const paddle::Tensor& temperatures,
const paddle::Tensor& bad_tokens,
const paddle::Tensor& bad_tokens_len,
const paddle::Tensor& cur_len,
const paddle::Tensor& min_len,
const paddle::Tensor& eos_token_id,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& batch_id_per_token_output,
const paddle::Tensor& cu_seqlens_q_output,
const int max_seq_len) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
@@ -198,8 +198,7 @@ void token_penalty_multi_scores_kernel(
int64_t end_length = eos_token_id.shape()[0];
int block_size = (token_num + 32 - 1) / 32 * 32;
min_length_logits_process<<<1, block_size, 0, cu_stream>>>(
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(logits.data<data_t>())),
reinterpret_cast<DataType_*>(const_cast<data_t*>(logits.data<data_t>())),
cur_len.data<int64_t>(),
min_len.data<int64_t>(),
eos_token_id.data<int64_t>(),
@@ -230,15 +229,15 @@ void token_penalty_multi_scores_kernel(
update_value_by_repeat_times<DataType_>
<<<token_num, block_size, 0, cu_stream>>>(
repeat_times.data<int>(),
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(penalty_scores.data<data_t>())),
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(frequency_score.data<data_t>())),
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(presence_score.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(penalty_scores.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(frequency_score.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(presence_score.data<data_t>())),
temperatures.data<float>(),
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(logits.data<data_t>())),
reinterpret_cast<DataType_*>(
const_cast<data_t*>(logits.data<data_t>())),
batch_id_per_token_output.data<int>(),
token_num,
bs,
@@ -247,8 +246,7 @@ void token_penalty_multi_scores_kernel(
block_size = (length_bad_words + 32 - 1) / 32 * 32;
block_size = min(block_size, 512);
ban_bad_words<DataType_><<<token_num, block_size, 0, cu_stream>>>(
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(logits.data<data_t>())),
reinterpret_cast<DataType_*>(const_cast<data_t*>(logits.data<data_t>())),
bad_tokens.data<int64_t>(),
bad_tokens_len.data<int64_t>(),
batch_id_per_token_output.data<int>(),
@@ -260,21 +258,21 @@ void token_penalty_multi_scores_kernel(
}
void SpecTokenPenaltyMultiScores(
const paddle::Tensor &token_ids_all,
const paddle::Tensor &prompt_lens,
const paddle::Tensor &logits,
const paddle::Tensor &penalty_scores,
const paddle::Tensor &frequency_scores,
const paddle::Tensor &presence_scores,
const paddle::Tensor &temperatures,
const paddle::Tensor &bad_tokens,
const paddle::Tensor &bad_tokens_len,
const paddle::Tensor &cur_len,
const paddle::Tensor &min_len,
const paddle::Tensor &eos_token_id,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &batch_id_per_token_output,
const paddle::Tensor &cu_seqlens_q_output,
const paddle::Tensor& token_ids_all,
const paddle::Tensor& prompt_lens,
const paddle::Tensor& logits,
const paddle::Tensor& penalty_scores,
const paddle::Tensor& frequency_scores,
const paddle::Tensor& presence_scores,
const paddle::Tensor& temperatures,
const paddle::Tensor& bad_tokens,
const paddle::Tensor& bad_tokens_len,
const paddle::Tensor& cur_len,
const paddle::Tensor& min_len,
const paddle::Tensor& eos_token_id,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& batch_id_per_token_output,
const paddle::Tensor& cu_seqlens_q_output,
const int max_seq_len) {
switch (logits.type()) {
case paddle::DataType::BFLOAT16: {