mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Speculative Decoding] refactor MTP and optimize spec-decoding postprocess (#6973)
* support new mtp * refactor(speculate_decoding and mtp): optimize mtp structure logic. Update spec-branch status-process * fix cuda-graph for spec-decoding * fix xpu mtp and fix some notes * fix unittest and optimize notes * fix model status update in eos-branch
This commit is contained in:
@@ -442,8 +442,6 @@ def post_process_specualate(
|
||||
think_end_id: int = -1,
|
||||
splitwise_role_is_decode: bool = False,
|
||||
enable_entropy: bool = False,
|
||||
is_naive_mode: bool = False,
|
||||
prefill_one_step_stop: bool = False,
|
||||
routing_replay_manager: RoutingReplayManager = None,
|
||||
):
|
||||
if think_end_id > 0:
|
||||
@@ -502,16 +500,16 @@ def post_process_specualate(
|
||||
)
|
||||
|
||||
# Unified state update: merges speculate_update + speculate_set_value_by_flags_and_idx
|
||||
# into a single kernel launch. For MTP/ngram paths, verify_draft_tokens has already
|
||||
# handled EOS/max_dec_len detection (replacing tokens + updating step_idx), so
|
||||
# unified_update_model_status acts as a no-op for those checks. For naive mode
|
||||
# (which skips verify), this kernel handles EOS/max_dec_len detection.
|
||||
# into a single kernel launch. Handles EOS detection, max_dec_len truncation, step_idx
|
||||
# advancement, token_ids_all history write, and stop_flags/not_need_stop update for all
|
||||
# paths (MTP, ngram, naive). Note: verify_draft_tokens intentionally does NOT write back
|
||||
# step_idx (it is read-only in that kernel); step_idx is always updated here.
|
||||
|
||||
unified_update_model_status(
|
||||
model_output.seq_lens_encoder, # seq_lens_encoder
|
||||
model_output.seq_lens_decoder, # seq_lens_decoder
|
||||
model_output.not_need_stop, # has_running_seqs
|
||||
model_output.draft_tokens, # step_input_ids
|
||||
model_output.actual_draft_token_num, # adaptive_step_input_len
|
||||
model_output.accept_tokens, # step_output_ids (read-write)
|
||||
model_output.accept_num, # step_output_len (read-write)
|
||||
model_output.stop_flags, # stop_flags (read-write)
|
||||
@@ -523,8 +521,6 @@ def post_process_specualate(
|
||||
model_output.step_idx, # step_idx (read-write)
|
||||
model_output.eos_token_id, # end_tokens
|
||||
model_output.max_dec_len, # max_dec_len
|
||||
is_naive_mode, # is_naive_mode
|
||||
prefill_one_step_stop, # prefill_one_step_stop
|
||||
)
|
||||
|
||||
if not skip_save_output:
|
||||
@@ -592,8 +588,6 @@ def post_process(
|
||||
think_end_id: int = -1,
|
||||
splitwise_role_is_decode: bool = False,
|
||||
enable_entropy: bool = False,
|
||||
is_naive_mode: bool = False,
|
||||
prefill_one_step_stop: bool = False,
|
||||
routing_replay_manager: RoutingReplayManager = None,
|
||||
) -> None:
|
||||
"""Post-processing steps after completing a single token generation."""
|
||||
@@ -621,8 +615,6 @@ def post_process(
|
||||
think_end_id,
|
||||
splitwise_role_is_decode,
|
||||
enable_entropy,
|
||||
is_naive_mode,
|
||||
prefill_one_step_stop,
|
||||
routing_replay_manager,
|
||||
)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user