mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Cherry-Pick][Speculative Decoding][BugFix] Fix apply-repeat-times penalty kernel and change default spec verify strategy (#7467) (#7468)
* fix repeat_time kernel and change default spec verify strategy
* fix unit_test
This commit is contained in:
@@ -16,12 +16,12 @@
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__global__ inline void min_length_logits_process(
|
__global__ inline void min_length_logits_process(
|
||||||
T *logits,
|
T* logits,
|
||||||
const int64_t *cur_len,
|
const int64_t* cur_len,
|
||||||
const int64_t *min_len,
|
const int64_t* min_len,
|
||||||
const int64_t *eos_token_id,
|
const int64_t* eos_token_id,
|
||||||
const int *batch_id_per_token_output,
|
const int* batch_id_per_token_output,
|
||||||
const int *cu_seqlens_q_output,
|
const int* cu_seqlens_q_output,
|
||||||
const int64_t token_num,
|
const int64_t token_num,
|
||||||
const int64_t bs,
|
const int64_t bs,
|
||||||
const int64_t length,
|
const int64_t length,
|
||||||
@@ -46,12 +46,12 @@ __global__ inline void min_length_logits_process(
|
|||||||
|
|
||||||
template <>
|
template <>
|
||||||
__global__ inline void min_length_logits_process<half>(
|
__global__ inline void min_length_logits_process<half>(
|
||||||
half *logits,
|
half* logits,
|
||||||
const int64_t *cur_len,
|
const int64_t* cur_len,
|
||||||
const int64_t *min_len,
|
const int64_t* min_len,
|
||||||
const int64_t *eos_token_id,
|
const int64_t* eos_token_id,
|
||||||
const int *batch_id_per_token_output,
|
const int* batch_id_per_token_output,
|
||||||
const int *cu_seqlens_q_output,
|
const int* cu_seqlens_q_output,
|
||||||
const int64_t token_num,
|
const int64_t token_num,
|
||||||
const int64_t bs,
|
const int64_t bs,
|
||||||
const int64_t length,
|
const int64_t length,
|
||||||
@@ -74,11 +74,11 @@ __global__ inline void min_length_logits_process<half>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void update_repeat_times(const int64_t *token_ids_all,
|
__global__ void update_repeat_times(const int64_t* token_ids_all,
|
||||||
const int64_t *prompt_lens,
|
const int64_t* prompt_lens,
|
||||||
const int64_t *cur_len,
|
const int64_t* cur_len,
|
||||||
int *repeat_times,
|
int* repeat_times,
|
||||||
const int *batch_id_per_token_output,
|
const int* batch_id_per_token_output,
|
||||||
const int64_t token_num,
|
const int64_t token_num,
|
||||||
const int64_t bs,
|
const int64_t bs,
|
||||||
const int64_t length,
|
const int64_t length,
|
||||||
@@ -93,9 +93,9 @@ __global__ void update_repeat_times(const int64_t *token_ids_all,
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
int tid = threadIdx.x;
|
int tid = threadIdx.x;
|
||||||
const int64_t *pre_ids_now = token_ids_all + bi * length_id + prompt_lens[bi];
|
const int64_t* pre_ids_now = token_ids_all + bi * length_id + prompt_lens[bi];
|
||||||
int *repeat_times_now = repeat_times + token_idx * length;
|
int* repeat_times_now = repeat_times + token_idx * length;
|
||||||
for (int i = tid; i < length_id; i += blockDim.x) {
|
for (int i = tid; i < cur_len[bi]; i += blockDim.x) {
|
||||||
int64_t id = pre_ids_now[i];
|
int64_t id = pre_ids_now[i];
|
||||||
if (id < 0) break;
|
if (id < 0) break;
|
||||||
atomicAdd(&repeat_times_now[id], 1);
|
atomicAdd(&repeat_times_now[id], 1);
|
||||||
@@ -104,13 +104,13 @@ __global__ void update_repeat_times(const int64_t *token_ids_all,
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__global__ void update_value_by_repeat_times(
|
__global__ void update_value_by_repeat_times(
|
||||||
const int *repeat_times,
|
const int* repeat_times,
|
||||||
const T *penalty_scores,
|
const T* penalty_scores,
|
||||||
const T *frequency_score,
|
const T* frequency_score,
|
||||||
const T *presence_score,
|
const T* presence_score,
|
||||||
const float *temperatures,
|
const float* temperatures,
|
||||||
T *logits,
|
T* logits,
|
||||||
const int *batch_id_per_token_output,
|
const int* batch_id_per_token_output,
|
||||||
const int64_t token_num,
|
const int64_t token_num,
|
||||||
const int64_t bs,
|
const int64_t bs,
|
||||||
const int64_t length,
|
const int64_t length,
|
||||||
@@ -121,8 +121,8 @@ __global__ void update_value_by_repeat_times(
|
|||||||
if (bi < 0) return;
|
if (bi < 0) return;
|
||||||
if (bi >= bs) return;
|
if (bi >= bs) return;
|
||||||
int tid = threadIdx.x;
|
int tid = threadIdx.x;
|
||||||
T *logits_now = logits + token_idx * length;
|
T* logits_now = logits + token_idx * length;
|
||||||
const int *repeat_times_now = repeat_times + token_idx * length;
|
const int* repeat_times_now = repeat_times + token_idx * length;
|
||||||
float alpha = static_cast<float>(penalty_scores[bi]);
|
float alpha = static_cast<float>(penalty_scores[bi]);
|
||||||
float beta = static_cast<float>(frequency_score[bi]);
|
float beta = static_cast<float>(frequency_score[bi]);
|
||||||
float gamma = static_cast<float>(presence_score[bi]);
|
float gamma = static_cast<float>(presence_score[bi]);
|
||||||
@@ -138,10 +138,10 @@ __global__ void update_value_by_repeat_times(
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__global__ void ban_bad_words(T *logits,
|
__global__ void ban_bad_words(T* logits,
|
||||||
const int64_t *bad_tokens,
|
const int64_t* bad_tokens,
|
||||||
const int64_t *bad_tokens_len,
|
const int64_t* bad_tokens_len,
|
||||||
const int *batch_id_per_token_output,
|
const int* batch_id_per_token_output,
|
||||||
const int64_t token_num,
|
const int64_t token_num,
|
||||||
const int64_t bs,
|
const int64_t bs,
|
||||||
const int64_t length,
|
const int64_t length,
|
||||||
@@ -153,8 +153,8 @@ __global__ void ban_bad_words(T *logits,
|
|||||||
if (bi < 0) return;
|
if (bi < 0) return;
|
||||||
if (bi >= bs) return;
|
if (bi >= bs) return;
|
||||||
int tid = threadIdx.x;
|
int tid = threadIdx.x;
|
||||||
T *logits_now = logits + token_idx * length;
|
T* logits_now = logits + token_idx * length;
|
||||||
const int64_t *bad_tokens_now = bad_tokens + bi * bad_words_length;
|
const int64_t* bad_tokens_now = bad_tokens + bi * bad_words_length;
|
||||||
const int32_t bad_token_len =
|
const int32_t bad_token_len =
|
||||||
static_cast<int32_t>(min(bad_tokens_len[bi], bad_words_length));
|
static_cast<int32_t>(min(bad_tokens_len[bi], bad_words_length));
|
||||||
for (int i = tid; i < bad_token_len; i += blockDim.x) {
|
for (int i = tid; i < bad_token_len; i += blockDim.x) {
|
||||||
@@ -166,21 +166,21 @@ __global__ void ban_bad_words(T *logits,
|
|||||||
|
|
||||||
template <paddle::DataType D>
|
template <paddle::DataType D>
|
||||||
void token_penalty_multi_scores_kernel(
|
void token_penalty_multi_scores_kernel(
|
||||||
const paddle::Tensor &token_ids_all,
|
const paddle::Tensor& token_ids_all,
|
||||||
const paddle::Tensor &prompt_lens,
|
const paddle::Tensor& prompt_lens,
|
||||||
const paddle::Tensor &logits,
|
const paddle::Tensor& logits,
|
||||||
const paddle::Tensor &penalty_scores,
|
const paddle::Tensor& penalty_scores,
|
||||||
const paddle::Tensor &frequency_score,
|
const paddle::Tensor& frequency_score,
|
||||||
const paddle::Tensor &presence_score,
|
const paddle::Tensor& presence_score,
|
||||||
const paddle::Tensor &temperatures,
|
const paddle::Tensor& temperatures,
|
||||||
const paddle::Tensor &bad_tokens,
|
const paddle::Tensor& bad_tokens,
|
||||||
const paddle::Tensor &bad_tokens_len,
|
const paddle::Tensor& bad_tokens_len,
|
||||||
const paddle::Tensor &cur_len,
|
const paddle::Tensor& cur_len,
|
||||||
const paddle::Tensor &min_len,
|
const paddle::Tensor& min_len,
|
||||||
const paddle::Tensor &eos_token_id,
|
const paddle::Tensor& eos_token_id,
|
||||||
const paddle::Tensor &seq_lens_this_time,
|
const paddle::Tensor& seq_lens_this_time,
|
||||||
const paddle::Tensor &batch_id_per_token_output,
|
const paddle::Tensor& batch_id_per_token_output,
|
||||||
const paddle::Tensor &cu_seqlens_q_output,
|
const paddle::Tensor& cu_seqlens_q_output,
|
||||||
const int max_seq_len) {
|
const int max_seq_len) {
|
||||||
typedef PDTraits<D> traits_;
|
typedef PDTraits<D> traits_;
|
||||||
typedef typename traits_::DataType DataType_;
|
typedef typename traits_::DataType DataType_;
|
||||||
@@ -198,8 +198,7 @@ void token_penalty_multi_scores_kernel(
|
|||||||
int64_t end_length = eos_token_id.shape()[0];
|
int64_t end_length = eos_token_id.shape()[0];
|
||||||
int block_size = (token_num + 32 - 1) / 32 * 32;
|
int block_size = (token_num + 32 - 1) / 32 * 32;
|
||||||
min_length_logits_process<<<1, block_size, 0, cu_stream>>>(
|
min_length_logits_process<<<1, block_size, 0, cu_stream>>>(
|
||||||
reinterpret_cast<DataType_ *>(
|
reinterpret_cast<DataType_*>(const_cast<data_t*>(logits.data<data_t>())),
|
||||||
const_cast<data_t *>(logits.data<data_t>())),
|
|
||||||
cur_len.data<int64_t>(),
|
cur_len.data<int64_t>(),
|
||||||
min_len.data<int64_t>(),
|
min_len.data<int64_t>(),
|
||||||
eos_token_id.data<int64_t>(),
|
eos_token_id.data<int64_t>(),
|
||||||
@@ -230,15 +229,15 @@ void token_penalty_multi_scores_kernel(
|
|||||||
update_value_by_repeat_times<DataType_>
|
update_value_by_repeat_times<DataType_>
|
||||||
<<<token_num, block_size, 0, cu_stream>>>(
|
<<<token_num, block_size, 0, cu_stream>>>(
|
||||||
repeat_times.data<int>(),
|
repeat_times.data<int>(),
|
||||||
reinterpret_cast<DataType_ *>(
|
reinterpret_cast<DataType_*>(
|
||||||
const_cast<data_t *>(penalty_scores.data<data_t>())),
|
const_cast<data_t*>(penalty_scores.data<data_t>())),
|
||||||
reinterpret_cast<DataType_ *>(
|
reinterpret_cast<DataType_*>(
|
||||||
const_cast<data_t *>(frequency_score.data<data_t>())),
|
const_cast<data_t*>(frequency_score.data<data_t>())),
|
||||||
reinterpret_cast<DataType_ *>(
|
reinterpret_cast<DataType_*>(
|
||||||
const_cast<data_t *>(presence_score.data<data_t>())),
|
const_cast<data_t*>(presence_score.data<data_t>())),
|
||||||
temperatures.data<float>(),
|
temperatures.data<float>(),
|
||||||
reinterpret_cast<DataType_ *>(
|
reinterpret_cast<DataType_*>(
|
||||||
const_cast<data_t *>(logits.data<data_t>())),
|
const_cast<data_t*>(logits.data<data_t>())),
|
||||||
batch_id_per_token_output.data<int>(),
|
batch_id_per_token_output.data<int>(),
|
||||||
token_num,
|
token_num,
|
||||||
bs,
|
bs,
|
||||||
@@ -247,8 +246,7 @@ void token_penalty_multi_scores_kernel(
|
|||||||
block_size = (length_bad_words + 32 - 1) / 32 * 32;
|
block_size = (length_bad_words + 32 - 1) / 32 * 32;
|
||||||
block_size = min(block_size, 512);
|
block_size = min(block_size, 512);
|
||||||
ban_bad_words<DataType_><<<token_num, block_size, 0, cu_stream>>>(
|
ban_bad_words<DataType_><<<token_num, block_size, 0, cu_stream>>>(
|
||||||
reinterpret_cast<DataType_ *>(
|
reinterpret_cast<DataType_*>(const_cast<data_t*>(logits.data<data_t>())),
|
||||||
const_cast<data_t *>(logits.data<data_t>())),
|
|
||||||
bad_tokens.data<int64_t>(),
|
bad_tokens.data<int64_t>(),
|
||||||
bad_tokens_len.data<int64_t>(),
|
bad_tokens_len.data<int64_t>(),
|
||||||
batch_id_per_token_output.data<int>(),
|
batch_id_per_token_output.data<int>(),
|
||||||
@@ -260,21 +258,21 @@ void token_penalty_multi_scores_kernel(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void SpecTokenPenaltyMultiScores(
|
void SpecTokenPenaltyMultiScores(
|
||||||
const paddle::Tensor &token_ids_all,
|
const paddle::Tensor& token_ids_all,
|
||||||
const paddle::Tensor &prompt_lens,
|
const paddle::Tensor& prompt_lens,
|
||||||
const paddle::Tensor &logits,
|
const paddle::Tensor& logits,
|
||||||
const paddle::Tensor &penalty_scores,
|
const paddle::Tensor& penalty_scores,
|
||||||
const paddle::Tensor &frequency_scores,
|
const paddle::Tensor& frequency_scores,
|
||||||
const paddle::Tensor &presence_scores,
|
const paddle::Tensor& presence_scores,
|
||||||
const paddle::Tensor &temperatures,
|
const paddle::Tensor& temperatures,
|
||||||
const paddle::Tensor &bad_tokens,
|
const paddle::Tensor& bad_tokens,
|
||||||
const paddle::Tensor &bad_tokens_len,
|
const paddle::Tensor& bad_tokens_len,
|
||||||
const paddle::Tensor &cur_len,
|
const paddle::Tensor& cur_len,
|
||||||
const paddle::Tensor &min_len,
|
const paddle::Tensor& min_len,
|
||||||
const paddle::Tensor &eos_token_id,
|
const paddle::Tensor& eos_token_id,
|
||||||
const paddle::Tensor &seq_lens_this_time,
|
const paddle::Tensor& seq_lens_this_time,
|
||||||
const paddle::Tensor &batch_id_per_token_output,
|
const paddle::Tensor& batch_id_per_token_output,
|
||||||
const paddle::Tensor &cu_seqlens_q_output,
|
const paddle::Tensor& cu_seqlens_q_output,
|
||||||
const int max_seq_len) {
|
const int max_seq_len) {
|
||||||
switch (logits.type()) {
|
switch (logits.type()) {
|
||||||
case paddle::DataType::BFLOAT16: {
|
case paddle::DataType::BFLOAT16: {
|
||||||
|
|||||||
@@ -774,7 +774,7 @@ class SpeculativeConfig:
|
|||||||
"benchmark_mode": False,
|
"benchmark_mode": False,
|
||||||
"enf_gen_phase_tag": False,
|
"enf_gen_phase_tag": False,
|
||||||
"enable_draft_logprob": False,
|
"enable_draft_logprob": False,
|
||||||
"verify_strategy": "topp",
|
"verify_strategy": "target_match",
|
||||||
"accept_policy": "normal",
|
"accept_policy": "normal",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -97,12 +97,12 @@ def _create_default_sampling_metadata(
|
|||||||
return fake_sampling_metadata
|
return fake_sampling_metadata
|
||||||
|
|
||||||
|
|
||||||
def _create_fd_config(max_model_len, method=None):
|
def _create_fd_config(max_model_len, method=None, verify_strategy="topp"):
|
||||||
model_config: Mock = Mock()
|
model_config: Mock = Mock()
|
||||||
model_config.max_model_len = max_model_len
|
model_config.max_model_len = max_model_len
|
||||||
model_config.architectures = ["test_model"]
|
model_config.architectures = ["test_model"]
|
||||||
model_config.mm_max_tokens_per_item = None
|
model_config.mm_max_tokens_per_item = None
|
||||||
speculative_config = SpeculativeConfig({"method": method} if method else {})
|
speculative_config = SpeculativeConfig({"method": method, "verify_strategy": verify_strategy})
|
||||||
graph_opt_config = GraphOptimizationConfig({})
|
graph_opt_config = GraphOptimizationConfig({})
|
||||||
scheduler_config = SchedulerConfig({})
|
scheduler_config = SchedulerConfig({})
|
||||||
parallel_config = ParallelConfig({})
|
parallel_config = ParallelConfig({})
|
||||||
@@ -187,7 +187,7 @@ def test_speculative_sampler():
|
|||||||
max_draft_token_num = 1
|
max_draft_token_num = 1
|
||||||
|
|
||||||
# Use ngram method for speculative decoding
|
# Use ngram method for speculative decoding
|
||||||
fd_config = _create_fd_config(max_model_len, method="ngram")
|
fd_config = _create_fd_config(max_model_len, method="ngram", verify_strategy="topp")
|
||||||
sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len)
|
sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len)
|
||||||
logits = _create_fake_logits(batch_size * (max_draft_token_num + 1), vocab_size)
|
logits = _create_fake_logits(batch_size * (max_draft_token_num + 1), vocab_size)
|
||||||
share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size)
|
share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size)
|
||||||
@@ -208,7 +208,7 @@ def test_speculative_sampler_logprobs():
|
|||||||
max_draft_token_num = 1
|
max_draft_token_num = 1
|
||||||
|
|
||||||
# Use ngram method for speculative decoding
|
# Use ngram method for speculative decoding
|
||||||
fd_config = _create_fd_config(max_model_len, method="ngram")
|
fd_config = _create_fd_config(max_model_len, method="ngram", verify_strategy="topp")
|
||||||
share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size)
|
share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size)
|
||||||
sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len, max_num_logprobs=0)
|
sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len, max_num_logprobs=0)
|
||||||
sampling_metadata.share_inputs = share_inputs
|
sampling_metadata.share_inputs = share_inputs
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ def update_repeat_times(
|
|||||||
token_ids_all_now = token_ids_all[bi]
|
token_ids_all_now = token_ids_all[bi]
|
||||||
repeat_times_now = repeat_times[token_idx]
|
repeat_times_now = repeat_times[token_idx]
|
||||||
|
|
||||||
for i in range(length_id):
|
for i in range(cur_len[bi]):
|
||||||
id = token_ids_all_now[i]
|
id = token_ids_all_now[i]
|
||||||
if id < 0:
|
if id < 0:
|
||||||
break
|
break
|
||||||
|
|||||||
Reference in New Issue
Block a user