[Speculative Decoding] Add MTP logprob support for PD disaggregation (#7442)

* support mtp logprob in pd

* fix

* fix

* fix

* fix xpu bugs
This commit is contained in:
GoldPancake
2026-04-17 21:37:38 +08:00
committed by GitHub
parent 3b9d6c60d3
commit df3b4e12f4
7 changed files with 389 additions and 78 deletions
+20 -18
View File
@@ -65,7 +65,6 @@ else:
eagle_get_self_hidden_states,
eagle_gather_hidden_states,
hybrid_mtp_ngram,
mtp_save_first_token,
mtp_step_paddle,
share_external_data,
speculate_get_logits,
@@ -840,23 +839,26 @@ class MTPProposer(Proposer):
)
if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0:
skip_save = bool(int(envs.ENABLE_V1_KVCACHE_SCHEDULER))
recover_model_output_map = recover_batch_index_for_output(
self.model_inputs,
self.model_inputs.index_to_batch_id,
self.model_inputs.enable_pd_reorder,
["base_model_draft_tokens", "seq_lens_decoder", "prompt_lens", "step_idx"],
)
mtp_save_first_token(
recover_model_output_map["base_model_draft_tokens"],
self.model_inputs["not_need_stop"],
recover_model_output_map["seq_lens_decoder"],
recover_model_output_map["prompt_lens"],
recover_model_output_map["step_idx"],
self.local_rank,
self.parallel_config.use_ep,
skip_save,
)
if current_platform.is_xpu():
# Note(wangyanpeng): mtp_save_first_token for GPU platforms has been moved to model_runner.
# Only XPU platform is retained here.
skip_save = bool(int(envs.ENABLE_V1_KVCACHE_SCHEDULER))
recover_model_output_map = recover_batch_index_for_output(
self.model_inputs,
self.model_inputs.index_to_batch_id,
self.model_inputs.enable_pd_reorder,
["base_model_draft_tokens", "seq_lens_decoder", "prompt_lens", "step_idx"],
)
mtp_save_first_token(
recover_model_output_map["base_model_draft_tokens"],
self.model_inputs["not_need_stop"],
recover_model_output_map["seq_lens_decoder"],
recover_model_output_map["prompt_lens"],
recover_model_output_map["step_idx"],
self.local_rank,
self.parallel_config.use_ep,
skip_save,
)
# Ensure only save first token once.
paddle.assign(
paddle.where(