[Speculative Decoding] Unify Spec and non-spec branch (#6685)

* optimize spec-inference architecture

* delete debug log

* optimize spec_method usage and fix unit test

* add claude unit-test skill

* fix some ugly bugs

* enhance robustness and bounds check

* unify method & spec_method to method to avoid bug

* activate CI

* fix unit test

* Unify logprobs computation for naive and speculative decoding, fix CUDA kernel

* fix logprob bug and optimize verify kernel

* fix exist_decode() check
This commit is contained in:
freeliuzc
2026-03-11 14:58:44 +08:00
committed by GitHub
parent b6190de557
commit cf7934a4b2
41 changed files with 3428 additions and 392 deletions
@@ -66,14 +66,13 @@ elif current_platform.is_maca():
speculate_save_output,
speculate_save_output_topk,
speculate_set_stop_value_multi_seqs,
speculate_set_value_by_flags_and_idx,
speculate_step_paddle,
speculate_step_reschedule,
speculate_step_system_cache,
speculate_update,
step_paddle,
step_reschedule,
step_system_cache,
unified_update_model_status,
update_inputs,
update_inputs_v1,
)
@@ -88,11 +87,10 @@ else:
speculate_pre_process,
speculate_save_output,
speculate_save_output_topk,
speculate_set_value_by_flags_and_idx,
speculate_step_paddle,
speculate_step_system_cache,
speculate_update,
speculate_set_stop_value_multi_seqs,
unified_update_model_status,
step_paddle,
step_system_cache,
update_inputs,
@@ -425,6 +423,8 @@ def post_process_specualate(
think_end_id: int = -1,
splitwise_role_is_decode: bool = False,
enable_entropy: bool = False,
is_naive_mode: bool = False,
prefill_one_step_stop: bool = False,
):
if think_end_id > 0:
speculate_limit_thinking_content_length(
@@ -457,18 +457,30 @@ def post_process_specualate(
if enable_entropy:
speculate_calculate_logits_entropy(sampler_output.logits, share_inputs, sampling_metadata.temperature)
speculate_update(
model_output.seq_lens_encoder,
model_output.seq_lens_decoder,
model_output.not_need_stop,
model_output.draft_tokens,
model_output.actual_draft_token_num,
model_output.accept_tokens,
model_output.accept_num,
model_output.stop_flags,
model_output.seq_lens_this_time,
model_output.is_block_step,
model_output.mask_rollback,
# Unified state update: merges speculate_update + speculate_set_value_by_flags_and_idx
# into a single kernel launch. For MTP/ngram paths, verify_draft_tokens has already
# handled EOS/max_dec_len detection (replacing tokens + updating step_idx), so
# unified_update_model_status acts as a no-op for those checks. For naive mode
# (which skips verify), this kernel handles EOS/max_dec_len detection.
unified_update_model_status(
model_output.seq_lens_encoder, # seq_lens_encoder
model_output.seq_lens_decoder, # seq_lens_decoder
model_output.not_need_stop, # has_running_seqs
model_output.draft_tokens, # step_input_ids
model_output.actual_draft_token_num, # adaptive_step_input_len
model_output.accept_tokens, # step_output_ids (read-write)
model_output.accept_num, # step_output_len (read-write)
model_output.stop_flags, # stop_flags (read-write)
model_output.seq_lens_this_time, # seq_lens_this_time
model_output.is_block_step, # is_paused
model_output.mask_rollback, # mask_rollback
model_output.token_ids_all, # token_ids_all
model_output.prompt_lens, # prompt_lens
model_output.step_idx, # step_idx (read-write)
model_output.eos_token_id, # end_tokens
model_output.max_dec_len, # max_dec_len
is_naive_mode, # is_naive_mode
prefill_one_step_stop, # prefill_one_step_stop
)
if not skip_save_output:
@@ -522,20 +534,6 @@ def post_process_specualate(
save_each_rank,
)
# Update token_ids_all through accept tokens
speculate_set_value_by_flags_and_idx(
model_output.token_ids_all,
model_output.prompt_lens,
model_output.accept_tokens,
model_output.accept_num,
model_output.stop_flags,
model_output.seq_lens_this_time,
model_output.seq_lens_encoder,
model_output.seq_lens_decoder,
model_output.step_idx,
)
def post_process(
sampler_or_pooler_output: Union[SamplerOutput, PoolerOutput],
@@ -550,6 +548,8 @@ def post_process(
think_end_id: int = -1,
splitwise_role_is_decode: bool = False,
enable_entropy: bool = False,
is_naive_mode: bool = False,
prefill_one_step_stop: bool = False,
) -> None:
"""Post-processing steps after completing a single token generation."""
@@ -575,6 +575,8 @@ def post_process(
think_end_id,
splitwise_role_is_decode,
enable_entropy,
is_naive_mode,
prefill_one_step_stop,
)
else:
post_process_normal(