mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-22 16:07:51 +08:00
[Speculative Decoding][BugFix] Fix apply repeat times penalty kernel and change spec default verify strategy (#7467)
* fix repeat_time kernel and change default spec verify strategy
* fix unit_test
This commit is contained in:
@@ -97,12 +97,12 @@ def _create_default_sampling_metadata(
|
||||
return fake_sampling_metadata
|
||||
|
||||
|
||||
def _create_fd_config(max_model_len, method=None):
|
||||
def _create_fd_config(max_model_len, method=None, verify_strategy="topp"):
|
||||
model_config: Mock = Mock()
|
||||
model_config.max_model_len = max_model_len
|
||||
model_config.architectures = ["test_model"]
|
||||
model_config.mm_max_tokens_per_item = None
|
||||
speculative_config = SpeculativeConfig({"method": method} if method else {})
|
||||
speculative_config = SpeculativeConfig({"method": method, "verify_strategy": verify_strategy})
|
||||
graph_opt_config = GraphOptimizationConfig({})
|
||||
scheduler_config = SchedulerConfig({})
|
||||
parallel_config = ParallelConfig({})
|
||||
@@ -187,7 +187,7 @@ def test_speculative_sampler():
|
||||
max_draft_token_num = 1
|
||||
|
||||
# Use ngram method for speculative decoding
|
||||
fd_config = _create_fd_config(max_model_len, method="ngram")
|
||||
fd_config = _create_fd_config(max_model_len, method="ngram", verify_strategy="topp")
|
||||
sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len)
|
||||
logits = _create_fake_logits(batch_size * (max_draft_token_num + 1), vocab_size)
|
||||
share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size)
|
||||
@@ -208,7 +208,7 @@ def test_speculative_sampler_logprobs():
|
||||
max_draft_token_num = 1
|
||||
|
||||
# Use ngram method for speculative decoding
|
||||
fd_config = _create_fd_config(max_model_len, method="ngram")
|
||||
fd_config = _create_fd_config(max_model_len, method="ngram", verify_strategy="topp")
|
||||
share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size)
|
||||
sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len, max_num_logprobs=0)
|
||||
sampling_metadata.share_inputs = share_inputs
|
||||
|
||||
@@ -61,7 +61,7 @@ def update_repeat_times(
|
||||
token_ids_all_now = token_ids_all[bi]
|
||||
repeat_times_now = repeat_times[token_idx]
|
||||
|
||||
for i in range(length_id):
|
||||
for i in range(cur_len[bi]):
|
||||
id = token_ids_all_now[i]
|
||||
if id < 0:
|
||||
break
|
||||
|
||||
Reference in New Issue
Block a user