Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2026-04-23 00:17:25 +08:00)
[Speculative Decoding] Unify Spec and non-spec branch (#6685)
* optimize spec-inference architecture
* delete debug log
* optimize spec_method usage && fix unit_test
* add claude unit-test skill
* fix some ugly bug
* enhance robustness and bounds check
* unify method & spec_method to method to avoid bug
* activate CI
* fix unit test
* Unify logprobs computation for naive and speculative decoding, fix CUDA kernel
* fix logprob bug && optimize verify kernel
* fix exist_decode() judge
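The thrust of the change, sketched below in plain Python: naive (non-speculative) decoding is treated as the degenerate speculative case, so both modes flow through post_process_specualate and a single fused update kernel instead of maintaining parallel spec/non-spec branches. All names and shapes in the sketch are illustrative assumptions, not FastDeploy's actual API; in particular, the one-token-per-step reading of naive mode is inferred from the is_naive_mode flag introduced in the hunks below.

import numpy as np

def shared_post_process(accept_tokens: np.ndarray, accept_num: np.ndarray, is_naive_mode: bool):
    # Speculative mode: a verify kernel has already filled accept_tokens/accept_num
    # and done EOS handling. Naive mode: there is no drafter, so each step accepts
    # exactly one token and EOS detection is deferred to the shared update kernel.
    if is_naive_mode:
        assert (accept_num == 1).all(), "naive decoding emits one token per step"
    # From here on, one code path serves both modes, keyed only on the arrays.
    return [accept_tokens[b, : int(accept_num[b])].tolist() for b in range(accept_num.shape[0])]

# The same path handles a speculative batch (3 and 1 accepted tokens) ...
print(shared_post_process(np.array([[5, 9, 2], [7, 0, 0]]), np.array([3, 1]), is_naive_mode=False))
# ... and a naive batch (always exactly 1 token per sequence).
print(shared_post_process(np.array([[5, 0, 0], [7, 0, 0]]), np.array([1, 1]), is_naive_mode=True))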
@@ -66,14 +66,13 @@ elif current_platform.is_maca():
        speculate_save_output,
        speculate_save_output_topk,
        speculate_set_stop_value_multi_seqs,
        speculate_set_value_by_flags_and_idx,
        speculate_step_paddle,
        speculate_step_reschedule,
        speculate_step_system_cache,
        speculate_update,
        step_paddle,
        step_reschedule,
        step_system_cache,
        unified_update_model_status,
        update_inputs,
        update_inputs_v1,
    )
@@ -88,11 +87,10 @@ else:
        speculate_pre_process,
        speculate_save_output,
        speculate_save_output_topk,
        speculate_set_value_by_flags_and_idx,
        speculate_step_paddle,
        speculate_step_system_cache,
        speculate_update,
        speculate_set_stop_value_multi_seqs,
        unified_update_model_status,
        step_paddle,
        step_system_cache,
        update_inputs,
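Both platform branches now import the same kernel set, including speculate_save_output_topk; per the commit message, top-k logprob output is computed the same way for naive and speculative decoding. A rough pure-Python stand-in for that idea (an assumption for illustration, not the CUDA kernel's logic); naive decoding is simply the accept_num == 1 case:

import numpy as np

def gather_topk_logprobs(logits: np.ndarray, accept_num: np.ndarray, k: int = 5):
    # logits: [batch, max_accepted_tokens, vocab]; normalize to log-probabilities.
    logprobs = logits - np.logaddexp.reduce(logits, axis=-1, keepdims=True)
    out = []
    for bid, n in enumerate(accept_num):
        steps = []
        for i in range(int(n)):  # n == 1 for naive decoding, n >= 1 when drafts are accepted
            topk = np.argsort(logprobs[bid, i])[::-1][:k]
            steps.append(list(zip(topk.tolist(), np.round(logprobs[bid, i, topk], 4).tolist())))
        out.append(steps)
    return out

# One routine serves both: a spec sequence with 2 accepted tokens, a naive one with 1.
print(gather_topk_logprobs(np.random.randn(2, 2, 16), np.array([2, 1]), k=3))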
@@ -425,6 +423,8 @@ def post_process_specualate(
     think_end_id: int = -1,
     splitwise_role_is_decode: bool = False,
     enable_entropy: bool = False,
+    is_naive_mode: bool = False,
+    prefill_one_step_stop: bool = False,
 ):
     if think_end_id > 0:
         speculate_limit_thinking_content_length(
@@ -457,18 +457,30 @@ def post_process_specualate(
     if enable_entropy:
         speculate_calculate_logits_entropy(sampler_output.logits, share_inputs, sampling_metadata.temperature)
 
-    speculate_update(
-        model_output.seq_lens_encoder,
-        model_output.seq_lens_decoder,
-        model_output.not_need_stop,
-        model_output.draft_tokens,
-        model_output.actual_draft_token_num,
-        model_output.accept_tokens,
-        model_output.accept_num,
-        model_output.stop_flags,
-        model_output.seq_lens_this_time,
-        model_output.is_block_step,
-        model_output.mask_rollback,
-    )
+    # Unified state update: merges speculate_update + speculate_set_value_by_flags_and_idx
+    # into a single kernel launch. For MTP/ngram paths, verify_draft_tokens has already
+    # handled EOS/max_dec_len detection (replacing tokens + updating step_idx), so
+    # unified_update_model_status acts as a no-op for those checks. For naive mode
+    # (which skips verify), this kernel handles EOS/max_dec_len detection.
+    unified_update_model_status(
+        model_output.seq_lens_encoder,  # seq_lens_encoder
+        model_output.seq_lens_decoder,  # seq_lens_decoder
+        model_output.not_need_stop,  # has_running_seqs
+        model_output.draft_tokens,  # step_input_ids
+        model_output.actual_draft_token_num,  # adaptive_step_input_len
+        model_output.accept_tokens,  # step_output_ids (read-write)
+        model_output.accept_num,  # step_output_len (read-write)
+        model_output.stop_flags,  # stop_flags (read-write)
+        model_output.seq_lens_this_time,  # seq_lens_this_time
+        model_output.is_block_step,  # is_paused
+        model_output.mask_rollback,  # mask_rollback
+        model_output.token_ids_all,  # token_ids_all
+        model_output.prompt_lens,  # prompt_lens
+        model_output.step_idx,  # step_idx (read-write)
+        model_output.eos_token_id,  # end_tokens
+        model_output.max_dec_len,  # max_dec_len
+        is_naive_mode,  # is_naive_mode
+        prefill_one_step_stop,  # prefill_one_step_stop
+    )
 
     if not skip_save_output:
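To make the fused semantics concrete, here is a pure-Python reference sketch of what the argument comments above describe, one sequence at a time. The control flow is reconstructed from those comments and the flag names; it is an assumption for illustration, not the CUDA source (in particular the prefill_one_step_stop handling is inferred from its name):

import numpy as np

def unified_update_model_status_ref(
    stop_flags, step_idx, seq_lens_decoder, accept_tokens, accept_num,
    token_ids_all, prompt_lens, end_tokens, max_dec_len,
    is_naive_mode, prefill_one_step_stop,
):
    for bid in range(stop_flags.shape[0]):
        if stop_flags[bid] or accept_num[bid] == 0:
            continue
        n = int(accept_num[bid])
        # Scatter accepted tokens into the full history: the write formerly done
        # by the separate speculate_set_value_by_flags_and_idx launch.
        start = int(prompt_lens[bid]) + int(step_idx[bid])
        token_ids_all[bid, start : start + n] = accept_tokens[bid, :n]
        step_idx[bid] += n
        seq_lens_decoder[bid] += n
        # EOS/max_dec_len checks run here only on the naive path; for MTP/ngram,
        # verify_draft_tokens already did them, so this block is a no-op there.
        if is_naive_mode:
            if accept_tokens[bid, n - 1] in end_tokens or step_idx[bid] >= max_dec_len[bid]:
                stop_flags[bid] = True
        if prefill_one_step_stop:
            stop_flags[bid] = True  # splitwise prefill role: hand off after one step

# Tiny demo: sequence 0 accepts one ordinary token, sequence 1 hits EOS (id 2).
stop, step, dec = np.array([False, False]), np.array([0, 0]), np.array([3, 3])
hist = np.zeros((2, 8), dtype=int)
unified_update_model_status_ref(
    stop, step, dec, np.array([[11], [2]]), np.array([1, 1]), hist,
    prompt_lens=np.array([3, 3]), end_tokens=np.array([2]),
    max_dec_len=np.array([8, 8]), is_naive_mode=True, prefill_one_step_stop=False,
)
print(stop)  # [False  True]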
@@ -522,20 +534,6 @@ def post_process_specualate(
             save_each_rank,
         )
-
-    # Update token_ids_all through accept tokens
-
-    speculate_set_value_by_flags_and_idx(
-        model_output.token_ids_all,
-        model_output.prompt_lens,
-        model_output.accept_tokens,
-        model_output.accept_num,
-        model_output.stop_flags,
-        model_output.seq_lens_this_time,
-        model_output.seq_lens_encoder,
-        model_output.seq_lens_decoder,
-        model_output.step_idx,
-    )
 
 
 def post_process(
     sampler_or_pooler_output: Union[SamplerOutput, PoolerOutput],
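For reference, the deleted standalone scatter can be read roughly as the stand-in below (reconstructed from the argument list above; an assumption, not the CUDA source). Because the old pipeline ran it after speculate_update had already advanced step_idx, the write position backs up by accept_num, whereas the fused sketch above writes before advancing. Folding this write into unified_update_model_status saves one kernel launch per decoding step:

def speculate_set_value_by_flags_and_idx_ref(
    token_ids_all, prompt_lens, accept_tokens, accept_num,
    stop_flags, seq_lens_this_time, seq_lens_encoder, seq_lens_decoder, step_idx,
):
    # Write each running sequence's accepted tokens into its row of the full
    # token history, just after the prompt at the current decoding position.
    # seq_lens_decoder is kept only to mirror the kernel's signature.
    for bid in range(stop_flags.shape[0]):
        if stop_flags[bid] or seq_lens_this_time[bid] == 0:
            continue
        if seq_lens_encoder[bid] > 0:
            continue  # sequence still in prefill; nothing to scatter yet
        n = int(accept_num[bid])
        start = int(prompt_lens[bid]) + int(step_idx[bid]) - n  # step_idx already advanced
        token_ids_all[bid, start : start + n] = accept_tokens[bid, :n]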
@@ -550,6 +548,8 @@ def post_process(
     think_end_id: int = -1,
     splitwise_role_is_decode: bool = False,
     enable_entropy: bool = False,
+    is_naive_mode: bool = False,
+    prefill_one_step_stop: bool = False,
 ) -> None:
     """Post-processing steps after completing a single token generation."""
 
@@ -575,6 +575,8 @@ def post_process(
             think_end_id,
             splitwise_role_is_decode,
             enable_entropy,
+            is_naive_mode,
+            prefill_one_step_stop,
         )
     else:
         post_process_normal(