mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Speculative Decoding] refactor MTP and optimize spec-decoding postprocess (#6973)
* support new mtp * refactor(speculate_decoding and mtp): optimize mtp structure logic. Update spec-branch status-process * fix cuda-graph for spec-decoding * fix xpu mtp and fix some notes * fix unittest and optimize notes * fix model status update in eos-branch
This commit is contained in:
@@ -116,8 +116,10 @@ class MTPProposer(Proposer):
|
||||
self.pd_disaggregation_mode = fd_config.parallel_config.pd_disaggregation_mode
|
||||
|
||||
if current_platform.is_xpu():
|
||||
self._prepare_inputs = self._prepare_inputs_xpu
|
||||
self._propose = self._propose_xpu
|
||||
elif current_platform.is_cuda() or current_platform.is_maca():
|
||||
self._prepare_inputs = self._prepare_inputs_cuda
|
||||
self._propose = self._propose_cuda
|
||||
else:
|
||||
raise RuntimeError(
|
||||
@@ -705,10 +707,53 @@ class MTPProposer(Proposer):
|
||||
else:
|
||||
return 0
|
||||
|
||||
def _prepare_inputs(self, full_hidden_states):
|
||||
def _prepare_inputs_cuda(self, full_hidden_states):
|
||||
"""
|
||||
Prepare MTP inputs
|
||||
|
||||
MTP state (seq_lens_decoder, step_idx) is "shadow state":
|
||||
- Initialized from target model state each round
|
||||
- Used for MTP forward, but not committed until verify
|
||||
- No rollback needed since it's always re-initialized
|
||||
"""
|
||||
|
||||
draft_model_preprocess(
|
||||
self.model_inputs["draft_tokens"],
|
||||
self.model_inputs["input_ids"],
|
||||
self.model_inputs["stop_flags"],
|
||||
self.model_inputs["seq_lens_this_time"],
|
||||
self.model_inputs["seq_lens_encoder"],
|
||||
self.model_inputs["seq_lens_decoder"],
|
||||
self.model_inputs["step_idx"],
|
||||
self.model_inputs["not_need_stop"],
|
||||
self.model_inputs["pre_ids"],
|
||||
self.target_model_inputs["accept_tokens"],
|
||||
self.target_model_inputs["accept_num"],
|
||||
self.target_model_inputs["seq_lens_encoder"],
|
||||
self.target_model_inputs["seq_lens_decoder"],
|
||||
self.target_model_inputs["step_idx"],
|
||||
self.target_model_inputs["stop_flags"],
|
||||
self.model_inputs["max_dec_len"],
|
||||
self.target_model_inputs["draft_tokens"],
|
||||
self.num_model_steps,
|
||||
self.role == "prefill", # is_splitwise_prefill
|
||||
)
|
||||
|
||||
target_hidden_states = eagle_get_hidden_states(
|
||||
full_hidden_states,
|
||||
self.model_inputs["seq_lens_this_time"],
|
||||
self.model_inputs["seq_lens_encoder"],
|
||||
self.model_inputs["seq_lens_decoder"],
|
||||
self.model_inputs["stop_flags"],
|
||||
self.target_model_inputs["accept_num"],
|
||||
self.target_model_inputs["seq_lens_this_time"],
|
||||
self.target_model_inputs["seq_lens_encoder"],
|
||||
self.num_model_steps,
|
||||
)
|
||||
|
||||
self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False)
|
||||
|
||||
def _prepare_inputs_xpu(self, full_hidden_states):
|
||||
use_v1_cache_scheduler = bool(envs.ENABLE_V1_KVCACHE_SCHEDULER)
|
||||
draft_model_preprocess(
|
||||
self.model_inputs["draft_tokens"],
|
||||
|
||||
Reference in New Issue
Block a user