[Speculative Decoding] refactor MTP and optimize spec-decoding postprocess (#6973)

* support new mtp

* refactor(speculate_decoding and mtp): optimize mtp structure logic. Update spec-branch status-process

* fix cuda-graph for spec-decoding

* fix xpu mtp and fix some notes

* fix unittest and optimize notes

* fix model status update in eos-branch
This commit is contained in:
freeliuzc
2026-03-24 10:19:01 +08:00
committed by GitHub
parent c62f6b4ea5
commit e87ce4b8cd
13 changed files with 1401 additions and 1150 deletions
@@ -442,8 +442,6 @@ def post_process_specualate(
think_end_id: int = -1,
splitwise_role_is_decode: bool = False,
enable_entropy: bool = False,
is_naive_mode: bool = False,
prefill_one_step_stop: bool = False,
routing_replay_manager: RoutingReplayManager = None,
):
if think_end_id > 0:
@@ -502,16 +500,16 @@ def post_process_specualate(
)
# Unified state update: merges speculate_update + speculate_set_value_by_flags_and_idx
# into a single kernel launch. For MTP/ngram paths, verify_draft_tokens has already
# handled EOS/max_dec_len detection (replacing tokens + updating step_idx), so
# unified_update_model_status acts as a no-op for those checks. For naive mode
# (which skips verify), this kernel handles EOS/max_dec_len detection.
# into a single kernel launch. Handles EOS detection, max_dec_len truncation, step_idx
# advancement, token_ids_all history write, and stop_flags/not_need_stop update for all
# paths (MTP, ngram, naive). Note: verify_draft_tokens intentionally does NOT write back
# step_idx (it is read-only in that kernel); step_idx is always updated here.
unified_update_model_status(
model_output.seq_lens_encoder, # seq_lens_encoder
model_output.seq_lens_decoder, # seq_lens_decoder
model_output.not_need_stop, # has_running_seqs
model_output.draft_tokens, # step_input_ids
model_output.actual_draft_token_num, # adaptive_step_input_len
model_output.accept_tokens, # step_output_ids (read-write)
model_output.accept_num, # step_output_len (read-write)
model_output.stop_flags, # stop_flags (read-write)
@@ -523,8 +521,6 @@ def post_process_specualate(
model_output.step_idx, # step_idx (read-write)
model_output.eos_token_id, # end_tokens
model_output.max_dec_len, # max_dec_len
is_naive_mode, # is_naive_mode
prefill_one_step_stop, # prefill_one_step_stop
)
if not skip_save_output:
@@ -592,8 +588,6 @@ def post_process(
think_end_id: int = -1,
splitwise_role_is_decode: bool = False,
enable_entropy: bool = False,
is_naive_mode: bool = False,
prefill_one_step_stop: bool = False,
routing_replay_manager: RoutingReplayManager = None,
) -> None:
"""Post-processing steps after completing a single token generation."""
@@ -621,8 +615,6 @@ def post_process(
think_end_id,
splitwise_role_is_decode,
enable_entropy,
is_naive_mode,
prefill_one_step_stop,
routing_replay_manager,
)
else: