diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_get_token_penalty_multi_scores.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_get_token_penalty_multi_scores.cu index ca5d8353c3..022c39bfb6 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_get_token_penalty_multi_scores.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_get_token_penalty_multi_scores.cu @@ -16,12 +16,12 @@ template __global__ inline void min_length_logits_process( - T *logits, - const int64_t *cur_len, - const int64_t *min_len, - const int64_t *eos_token_id, - const int *batch_id_per_token_output, - const int *cu_seqlens_q_output, + T* logits, + const int64_t* cur_len, + const int64_t* min_len, + const int64_t* eos_token_id, + const int* batch_id_per_token_output, + const int* cu_seqlens_q_output, const int64_t token_num, const int64_t bs, const int64_t length, @@ -46,12 +46,12 @@ __global__ inline void min_length_logits_process( template <> __global__ inline void min_length_logits_process( - half *logits, - const int64_t *cur_len, - const int64_t *min_len, - const int64_t *eos_token_id, - const int *batch_id_per_token_output, - const int *cu_seqlens_q_output, + half* logits, + const int64_t* cur_len, + const int64_t* min_len, + const int64_t* eos_token_id, + const int* batch_id_per_token_output, + const int* cu_seqlens_q_output, const int64_t token_num, const int64_t bs, const int64_t length, @@ -74,11 +74,11 @@ __global__ inline void min_length_logits_process( } } -__global__ void update_repeat_times(const int64_t *token_ids_all, - const int64_t *prompt_lens, - const int64_t *cur_len, - int *repeat_times, - const int *batch_id_per_token_output, +__global__ void update_repeat_times(const int64_t* token_ids_all, + const int64_t* prompt_lens, + const int64_t* cur_len, + int* repeat_times, + const int* batch_id_per_token_output, const int64_t token_num, const int64_t bs, const int64_t length, @@ -93,9 +93,9 @@ __global__ void update_repeat_times(const int64_t *token_ids_all, return; } int tid = threadIdx.x; - const int64_t *pre_ids_now = token_ids_all + bi * length_id + prompt_lens[bi]; - int *repeat_times_now = repeat_times + token_idx * length; - for (int i = tid; i < length_id; i += blockDim.x) { + const int64_t* pre_ids_now = token_ids_all + bi * length_id + prompt_lens[bi]; + int* repeat_times_now = repeat_times + token_idx * length; + for (int i = tid; i < cur_len[bi]; i += blockDim.x) { int64_t id = pre_ids_now[i]; if (id < 0) break; atomicAdd(&repeat_times_now[id], 1); @@ -104,13 +104,13 @@ __global__ void update_repeat_times(const int64_t *token_ids_all, template __global__ void update_value_by_repeat_times( - const int *repeat_times, - const T *penalty_scores, - const T *frequency_score, - const T *presence_score, - const float *temperatures, - T *logits, - const int *batch_id_per_token_output, + const int* repeat_times, + const T* penalty_scores, + const T* frequency_score, + const T* presence_score, + const float* temperatures, + T* logits, + const int* batch_id_per_token_output, const int64_t token_num, const int64_t bs, const int64_t length, @@ -121,8 +121,8 @@ __global__ void update_value_by_repeat_times( if (bi < 0) return; if (bi >= bs) return; int tid = threadIdx.x; - T *logits_now = logits + token_idx * length; - const int *repeat_times_now = repeat_times + token_idx * length; + T* logits_now = logits + token_idx * length; + const int* repeat_times_now = repeat_times + token_idx * length; float alpha = static_cast(penalty_scores[bi]); float beta = static_cast(frequency_score[bi]); float gamma = static_cast(presence_score[bi]); @@ -138,10 +138,10 @@ __global__ void update_value_by_repeat_times( } template -__global__ void ban_bad_words(T *logits, - const int64_t *bad_tokens, - const int64_t *bad_tokens_len, - const int *batch_id_per_token_output, +__global__ void ban_bad_words(T* logits, + const int64_t* bad_tokens, + const int64_t* bad_tokens_len, + const int* batch_id_per_token_output, const int64_t token_num, const int64_t bs, const int64_t length, @@ -153,8 +153,8 @@ __global__ void ban_bad_words(T *logits, if (bi < 0) return; if (bi >= bs) return; int tid = threadIdx.x; - T *logits_now = logits + token_idx * length; - const int64_t *bad_tokens_now = bad_tokens + bi * bad_words_length; + T* logits_now = logits + token_idx * length; + const int64_t* bad_tokens_now = bad_tokens + bi * bad_words_length; const int32_t bad_token_len = static_cast(min(bad_tokens_len[bi], bad_words_length)); for (int i = tid; i < bad_token_len; i += blockDim.x) { @@ -166,21 +166,21 @@ __global__ void ban_bad_words(T *logits, template void token_penalty_multi_scores_kernel( - const paddle::Tensor &token_ids_all, - const paddle::Tensor &prompt_lens, - const paddle::Tensor &logits, - const paddle::Tensor &penalty_scores, - const paddle::Tensor &frequency_score, - const paddle::Tensor &presence_score, - const paddle::Tensor &temperatures, - const paddle::Tensor &bad_tokens, - const paddle::Tensor &bad_tokens_len, - const paddle::Tensor &cur_len, - const paddle::Tensor &min_len, - const paddle::Tensor &eos_token_id, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &batch_id_per_token_output, - const paddle::Tensor &cu_seqlens_q_output, + const paddle::Tensor& token_ids_all, + const paddle::Tensor& prompt_lens, + const paddle::Tensor& logits, + const paddle::Tensor& penalty_scores, + const paddle::Tensor& frequency_score, + const paddle::Tensor& presence_score, + const paddle::Tensor& temperatures, + const paddle::Tensor& bad_tokens, + const paddle::Tensor& bad_tokens_len, + const paddle::Tensor& cur_len, + const paddle::Tensor& min_len, + const paddle::Tensor& eos_token_id, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token_output, + const paddle::Tensor& cu_seqlens_q_output, const int max_seq_len) { typedef PDTraits traits_; typedef typename traits_::DataType DataType_; @@ -198,8 +198,7 @@ void token_penalty_multi_scores_kernel( int64_t end_length = eos_token_id.shape()[0]; int block_size = (token_num + 32 - 1) / 32 * 32; min_length_logits_process<<<1, block_size, 0, cu_stream>>>( - reinterpret_cast( - const_cast(logits.data())), + reinterpret_cast(const_cast(logits.data())), cur_len.data(), min_len.data(), eos_token_id.data(), @@ -230,15 +229,15 @@ void token_penalty_multi_scores_kernel( update_value_by_repeat_times <<>>( repeat_times.data(), - reinterpret_cast( - const_cast(penalty_scores.data())), - reinterpret_cast( - const_cast(frequency_score.data())), - reinterpret_cast( - const_cast(presence_score.data())), + reinterpret_cast( + const_cast(penalty_scores.data())), + reinterpret_cast( + const_cast(frequency_score.data())), + reinterpret_cast( + const_cast(presence_score.data())), temperatures.data(), - reinterpret_cast( - const_cast(logits.data())), + reinterpret_cast( + const_cast(logits.data())), batch_id_per_token_output.data(), token_num, bs, @@ -247,8 +246,7 @@ void token_penalty_multi_scores_kernel( block_size = (length_bad_words + 32 - 1) / 32 * 32; block_size = min(block_size, 512); ban_bad_words<<>>( - reinterpret_cast( - const_cast(logits.data())), + reinterpret_cast(const_cast(logits.data())), bad_tokens.data(), bad_tokens_len.data(), batch_id_per_token_output.data(), @@ -260,21 +258,21 @@ void token_penalty_multi_scores_kernel( } void SpecTokenPenaltyMultiScores( - const paddle::Tensor &token_ids_all, - const paddle::Tensor &prompt_lens, - const paddle::Tensor &logits, - const paddle::Tensor &penalty_scores, - const paddle::Tensor &frequency_scores, - const paddle::Tensor &presence_scores, - const paddle::Tensor &temperatures, - const paddle::Tensor &bad_tokens, - const paddle::Tensor &bad_tokens_len, - const paddle::Tensor &cur_len, - const paddle::Tensor &min_len, - const paddle::Tensor &eos_token_id, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &batch_id_per_token_output, - const paddle::Tensor &cu_seqlens_q_output, + const paddle::Tensor& token_ids_all, + const paddle::Tensor& prompt_lens, + const paddle::Tensor& logits, + const paddle::Tensor& penalty_scores, + const paddle::Tensor& frequency_scores, + const paddle::Tensor& presence_scores, + const paddle::Tensor& temperatures, + const paddle::Tensor& bad_tokens, + const paddle::Tensor& bad_tokens_len, + const paddle::Tensor& cur_len, + const paddle::Tensor& min_len, + const paddle::Tensor& eos_token_id, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token_output, + const paddle::Tensor& cu_seqlens_q_output, const int max_seq_len) { switch (logits.type()) { case paddle::DataType::BFLOAT16: { diff --git a/fastdeploy/config.py b/fastdeploy/config.py index e9a213f436..96ab4dc2a9 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -775,7 +775,7 @@ class SpeculativeConfig: "benchmark_mode": False, "enf_gen_phase_tag": False, "enable_draft_logprob": False, - "verify_strategy": "topp", + "verify_strategy": "target_match", "accept_policy": "normal", } diff --git a/tests/layers/test_speculative_sampler.py b/tests/layers/test_speculative_sampler.py index ef75fe5d4e..227e73db5a 100644 --- a/tests/layers/test_speculative_sampler.py +++ b/tests/layers/test_speculative_sampler.py @@ -97,12 +97,12 @@ def _create_default_sampling_metadata( return fake_sampling_metadata -def _create_fd_config(max_model_len, method=None): +def _create_fd_config(max_model_len, method=None, verify_strategy="topp"): model_config: Mock = Mock() model_config.max_model_len = max_model_len model_config.architectures = ["test_model"] model_config.mm_max_tokens_per_item = None - speculative_config = SpeculativeConfig({"method": method} if method else {}) + speculative_config = SpeculativeConfig({"method": method, "verify_strategy": verify_strategy}) graph_opt_config = GraphOptimizationConfig({}) scheduler_config = SchedulerConfig({}) parallel_config = ParallelConfig({}) @@ -187,7 +187,7 @@ def test_speculative_sampler(): max_draft_token_num = 1 # Use ngram method for speculative decoding - fd_config = _create_fd_config(max_model_len, method="ngram") + fd_config = _create_fd_config(max_model_len, method="ngram", verify_strategy="topp") sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len) logits = _create_fake_logits(batch_size * (max_draft_token_num + 1), vocab_size) share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size) @@ -208,7 +208,7 @@ def test_speculative_sampler_logprobs(): max_draft_token_num = 1 # Use ngram method for speculative decoding - fd_config = _create_fd_config(max_model_len, method="ngram") + fd_config = _create_fd_config(max_model_len, method="ngram", verify_strategy="topp") share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size) sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len, max_num_logprobs=0) sampling_metadata.share_inputs = share_inputs diff --git a/tests/operators/test_speculate_get_token_penalty_multi_scores.py b/tests/operators/test_speculate_get_token_penalty_multi_scores.py index 845f666ee7..61efdbf270 100644 --- a/tests/operators/test_speculate_get_token_penalty_multi_scores.py +++ b/tests/operators/test_speculate_get_token_penalty_multi_scores.py @@ -61,7 +61,7 @@ def update_repeat_times( token_ids_all_now = token_ids_all[bi] repeat_times_now = repeat_times[token_idx] - for i in range(length_id): + for i in range(cur_len[bi]): id = token_ids_all_now[i] if id < 0: break