[Speculative Decoding] refactor MTP and optimize spec-decoding postprocess (#6973)

* support new mtp

* refactor(speculate_decoding and mtp): optimize mtp structure logic. Update spec-branch status-process

* fix cuda-graph for spec-decoding

* fix xpu mtp and fix some note

* fix unit tests and optimize notes

* fix model status update in eos-branch
This commit is contained in:
freeliuzc
2026-03-24 10:19:01 +08:00
committed by GitHub
parent c62f6b4ea5
commit e87ce4b8cd
13 changed files with 1401 additions and 1150 deletions
+46 -1
View File
@@ -116,8 +116,10 @@ class MTPProposer(Proposer):
self.pd_disaggregation_mode = fd_config.parallel_config.pd_disaggregation_mode
if current_platform.is_xpu():
self._prepare_inputs = self._prepare_inputs_xpu
self._propose = self._propose_xpu
elif current_platform.is_cuda() or current_platform.is_maca():
self._prepare_inputs = self._prepare_inputs_cuda
self._propose = self._propose_cuda
else:
raise RuntimeError(
@@ -705,10 +707,53 @@ class MTPProposer(Proposer):
else:
return 0
def _prepare_inputs(self, full_hidden_states):
def _prepare_inputs_cuda(self, full_hidden_states):
"""
Prepare MTP inputs
MTP state (seq_lens_decoder, step_idx) is "shadow state":
- Initialized from target model state each round
- Used for MTP forward, but not committed until verify
- No rollback needed since it's always re-initialized
"""
draft_model_preprocess(
self.model_inputs["draft_tokens"],
self.model_inputs["input_ids"],
self.model_inputs["stop_flags"],
self.model_inputs["seq_lens_this_time"],
self.model_inputs["seq_lens_encoder"],
self.model_inputs["seq_lens_decoder"],
self.model_inputs["step_idx"],
self.model_inputs["not_need_stop"],
self.model_inputs["pre_ids"],
self.target_model_inputs["accept_tokens"],
self.target_model_inputs["accept_num"],
self.target_model_inputs["seq_lens_encoder"],
self.target_model_inputs["seq_lens_decoder"],
self.target_model_inputs["step_idx"],
self.target_model_inputs["stop_flags"],
self.model_inputs["max_dec_len"],
self.target_model_inputs["draft_tokens"],
self.num_model_steps,
self.role == "prefill", # is_splitwise_prefill
)
target_hidden_states = eagle_get_hidden_states(
full_hidden_states,
self.model_inputs["seq_lens_this_time"],
self.model_inputs["seq_lens_encoder"],
self.model_inputs["seq_lens_decoder"],
self.model_inputs["stop_flags"],
self.target_model_inputs["accept_num"],
self.target_model_inputs["seq_lens_this_time"],
self.target_model_inputs["seq_lens_encoder"],
self.num_model_steps,
)
self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False)
def _prepare_inputs_xpu(self, full_hidden_states):
use_v1_cache_scheduler = bool(envs.ENABLE_V1_KVCACHE_SCHEDULER)
draft_model_preprocess(
self.model_inputs["draft_tokens"],