[Speculative Decoding] Add MTP logprob support for PD disaggregation (#7442)

* support mtp logprob in pd

* fix

* fix

* fix

* fix xpu bugs
This commit is contained in:
GoldPancake
2026-04-17 21:37:38 +08:00
committed by GitHub
parent 3b9d6c60d3
commit df3b4e12f4
7 changed files with 389 additions and 78 deletions
@@ -22,9 +22,14 @@ import paddle
from fastdeploy import envs
from fastdeploy.config import SpeculativeConfig
from fastdeploy.model_executor.ops.gpu import (
mtp_save_first_token,
mtp_save_first_token_with_topk,
)
from fastdeploy.platforms import current_platform
from fastdeploy.worker.input_batch import (
InputBatch,
ProposerInputBatch,
recover_batch_index_for_output,
recover_batch_index_for_sampler_output,
)
@@ -522,10 +527,76 @@ def save_output_specualate(
sampler_output: SamplerOutput,
model_output: ModelOutputData,
share_inputs: InputBatch,
proposer_share_inputs: ProposerInputBatch,
local_rank: int,
tensor_parallel_rank: int,
save_each_rank: bool = False,
skip_save_output: bool = False,
is_mtp_prefill: bool = False,
):
if not skip_save_output:
if is_mtp_prefill:
if tensor_parallel_rank == 0:
skip_chunk_prefill = bool(int(envs.ENABLE_V1_KVCACHE_SCHEDULER))
if sampler_output.logprobs_tensors is None:
recover_proposer_share_inputs_map = recover_batch_index_for_output(
proposer_share_inputs,
proposer_share_inputs.index_to_batch_id,
proposer_share_inputs.enable_pd_reorder,
[
"base_model_draft_tokens",
"seq_lens_decoder",
"prompt_lens",
"step_idx",
],
)
mtp_save_first_token(
recover_proposer_share_inputs_map["base_model_draft_tokens"],
proposer_share_inputs["not_need_stop"],
recover_proposer_share_inputs_map["seq_lens_decoder"],
recover_proposer_share_inputs_map["prompt_lens"],
recover_proposer_share_inputs_map["step_idx"],
local_rank,
save_each_rank,
skip_chunk_prefill,
)
else:
recover_share_inputs_map = recover_batch_index_for_output(
share_inputs,
model_output.index_to_batch_id,
model_output.enable_pd_reorder,
[
"sampled_token_ids",
"accept_tokens_cpu",
"accept_num_cpu",
"seq_lens_decoder_cpu",
"prompt_lens_cpu",
"last_preempted_idx",
],
)
recover_batch_index_for_sampler_output(
sampler_output, model_output.index_to_batch_id, model_output.enable_pd_reorder
)
recover_proposer_share_inputs_map = recover_batch_index_for_output(
proposer_share_inputs,
proposer_share_inputs.index_to_batch_id,
proposer_share_inputs.enable_pd_reorder,
["base_model_draft_tokens"],
)
mtp_save_first_token_with_topk(
recover_proposer_share_inputs_map["base_model_draft_tokens"],
sampler_output.logprobs_tensors.logprob_token_ids,
sampler_output.logprobs_tensors.logprobs,
sampler_output.logprobs_tensors.selected_token_ranks,
recover_share_inputs_map["accept_num_cpu"],
sampler_output.cu_batch_token_offset,
model_output.not_need_stop,
recover_share_inputs_map["seq_lens_decoder_cpu"],
recover_share_inputs_map["prompt_lens_cpu"],
recover_share_inputs_map["last_preempted_idx"],
3, # mtype
model_output.mp_rank,
save_each_rank,
)
else:
if sampler_output.logprobs_tensors is None:
recover_share_inputs = recover_batch_index_for_output(
share_inputs,
+20 -18
View File
@@ -65,7 +65,6 @@ else:
eagle_get_self_hidden_states,
eagle_gather_hidden_states,
hybrid_mtp_ngram,
mtp_save_first_token,
mtp_step_paddle,
share_external_data,
speculate_get_logits,
@@ -840,23 +839,26 @@ class MTPProposer(Proposer):
)
if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0:
skip_save = bool(int(envs.ENABLE_V1_KVCACHE_SCHEDULER))
recover_model_output_map = recover_batch_index_for_output(
self.model_inputs,
self.model_inputs.index_to_batch_id,
self.model_inputs.enable_pd_reorder,
["base_model_draft_tokens", "seq_lens_decoder", "prompt_lens", "step_idx"],
)
mtp_save_first_token(
recover_model_output_map["base_model_draft_tokens"],
self.model_inputs["not_need_stop"],
recover_model_output_map["seq_lens_decoder"],
recover_model_output_map["prompt_lens"],
recover_model_output_map["step_idx"],
self.local_rank,
self.parallel_config.use_ep,
skip_save,
)
if current_platform.is_xpu():
# Note(wangyanpeng): mtp_save_first_token for GPU platforms has been moved to model_runner.
# Only XPU platform is retained here.
skip_save = bool(int(envs.ENABLE_V1_KVCACHE_SCHEDULER))
recover_model_output_map = recover_batch_index_for_output(
self.model_inputs,
self.model_inputs.index_to_batch_id,
self.model_inputs.enable_pd_reorder,
["base_model_draft_tokens", "seq_lens_decoder", "prompt_lens", "step_idx"],
)
mtp_save_first_token(
recover_model_output_map["base_model_draft_tokens"],
self.model_inputs["not_need_stop"],
recover_model_output_map["seq_lens_decoder"],
recover_model_output_map["prompt_lens"],
recover_model_output_map["step_idx"],
self.local_rank,
self.parallel_config.use_ep,
skip_save,
)
# Ensure only save first token once.
paddle.assign(
paddle.where(
+6 -2
View File
@@ -2552,13 +2552,17 @@ class GPUModelRunner(ModelRunnerBase):
sampler_output,
):
if self.speculative_decoding:
skip_save_output = self.spec_method == SpecMethod.MTP and self.scheduler_config.splitwise_role == "prefill"
save_output_specualate(
sampler_output=sampler_output,
model_output=model_output_data,
share_inputs=self.share_inputs,
proposer_share_inputs=self.proposer.model_inputs,
local_rank=self.local_rank,
tensor_parallel_rank=self.parallel_config.tensor_parallel_rank,
save_each_rank=self.parallel_config.use_ep,
skip_save_output=skip_save_output,
is_mtp_prefill=(
self.spec_method == SpecMethod.MTP and self.scheduler_config.splitwise_role == "prefill"
),
)
else:
save_output_normal(