[Speculative Decoding][BugFix] Fix apply repeat times penalty kernel and change spec default verify strategy (#7467)

* fix repeat_time kernel and change default spec verify strategy

* fix unit_test
freeliuzc authored on 2026-04-18 00:38:01 +08:00; committed by GitHub
parent df3b4e12f4
commit 22a4f6019d
4 changed files with 81 additions and 83 deletions
@@ -97,12 +97,12 @@ def _create_default_sampling_metadata(
     return fake_sampling_metadata
 
-def _create_fd_config(max_model_len, method=None):
+def _create_fd_config(max_model_len, method=None, verify_strategy="topp"):
     model_config: Mock = Mock()
     model_config.max_model_len = max_model_len
     model_config.architectures = ["test_model"]
     model_config.mm_max_tokens_per_item = None
-    speculative_config = SpeculativeConfig({"method": method} if method else {})
+    speculative_config = SpeculativeConfig({"method": method, "verify_strategy": verify_strategy})
     graph_opt_config = GraphOptimizationConfig({})
     scheduler_config = SchedulerConfig({})
     parallel_config = ParallelConfig({})
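For context, here is a minimal sketch of what a top-p ("topp") verify step in speculative decoding can look like. Everything in it is an assumption for illustration: the name topp_verify, its signature, and the nucleus check are hypothetical, not FastDeploy's batched GPU kernel.

    import numpy as np

    def topp_verify(draft_tokens, target_probs, top_p=0.8, seed=0):
        # Hypothetical helper, not FastDeploy's kernel: accept each draft
        # token only if it falls inside the target model's top-p nucleus,
        # resampling from the nucleus and stopping at the first miss.
        rng = np.random.default_rng(seed)
        out = []
        for tok, probs in zip(draft_tokens, target_probs):
            order = np.argsort(probs)[::-1]            # ids by descending prob
            cum = np.cumsum(probs[order])
            nucleus = order[: int(np.searchsorted(cum, top_p)) + 1]
            if tok in nucleus:
                out.append(int(tok))                   # draft token accepted
            else:
                p = probs[nucleus] / probs[nucleus].sum()
                out.append(int(rng.choice(nucleus, p=p)))
                break                                  # reject: stop verifying
        return out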
@@ -187,7 +187,7 @@ def test_speculative_sampler():
     max_draft_token_num = 1
     # Use ngram method for speculative decoding
-    fd_config = _create_fd_config(max_model_len, method="ngram")
+    fd_config = _create_fd_config(max_model_len, method="ngram", verify_strategy="topp")
     sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len)
     logits = _create_fake_logits(batch_size * (max_draft_token_num + 1), vocab_size)
     share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size)
@@ -208,7 +208,7 @@ def test_speculative_sampler_logprobs():
     max_draft_token_num = 1
     # Use ngram method for speculative decoding
-    fd_config = _create_fd_config(max_model_len, method="ngram")
+    fd_config = _create_fd_config(max_model_len, method="ngram", verify_strategy="topp")
     share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size)
     sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len, max_num_logprobs=0)
     sampling_metadata.share_inputs = share_inputs
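Note that both sampler tests now pin verify_strategy="topp" explicitly. Since this commit changes the default spec verify strategy, passing the value here presumably keeps the tests exercising the top-p verification path regardless of what the new default is.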
@@ -61,7 +61,7 @@ def update_repeat_times(
     token_ids_all_now = token_ids_all[bi]
     repeat_times_now = repeat_times[token_idx]
-    for i in range(length_id):
+    for i in range(cur_len[bi]):
         id = token_ids_all_now[i]
         if id < 0:
             break
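This one-line change is the penalty-kernel fix named in the title: the scan over a sequence's token history was bounded by length_id rather than the per-sequence cur_len[bi], so for shorter sequences the loop could read past the valid tokens and count padding or stale ids into the repeat counts. Below is a standalone sketch of the corrected counting logic; the batch loop, array shapes, and the function name count_repeat_times are assumptions for illustration, not the real kernel.

    import numpy as np

    def count_repeat_times(token_ids_all, cur_len, vocab_size):
        # Illustrative re-implementation, not the real kernel: tally how
        # often each vocab id occurs in each sequence's valid prefix.
        bsz = token_ids_all.shape[0]
        repeat_times = np.zeros((bsz, vocab_size), dtype=np.int64)
        for bi in range(bsz):
            for i in range(cur_len[bi]):   # per-sequence bound (the fix)
                tid = int(token_ids_all[bi, i])
                if tid < 0:                # sentinel id: stop early, as in the kernel
                    break
                repeat_times[bi, tid] += 1
        return repeat_times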