[XPU] Fix PD + MTP (#6495)

* fix pd + mtp

* fix code style

* fix PD + MTP, D get P's first token

* add anno for gpu(speculate_update)

* update draft insertv1

* fix wrapper & kernel

* fix wrapper

* fix code style
This commit is contained in:
cmcamdy
2026-02-27 19:07:35 +08:00
committed by GitHub
parent 12f754ef38
commit 13447279aa
21 changed files with 1186 additions and 123 deletions
@@ -26,7 +26,7 @@ from fastdeploy.model_executor.forward_meta import XPUForwardMeta
from fastdeploy.model_executor.layers.sample.sampler import Sampler
from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
from fastdeploy.platforms import current_platform
from fastdeploy.worker.output import LogprobsTensors, ModelOutputData
from fastdeploy.worker.output import LogprobsTensors, ModelOutputData, SamplerOutput
if current_platform.is_xpu():
from fastdeploy.model_executor.ops.xpu import ( # step_system_cache,; step_reschedule,
@@ -49,7 +49,7 @@ if current_platform.is_xpu():
speculate_step_paddle,
speculate_step_reschedule,
speculate_step_system_cache,
speculate_update_v3,
speculate_update,
step_paddle,
update_inputs,
update_inputs_v1,
@@ -408,11 +408,14 @@ def xpu_post_process_normal(
def xpu_post_process_specualate(
model_output: ModelOutputData, save_each_rank: bool = False, skip_save_output: bool = False
sampler_output: SamplerOutput,
model_output: ModelOutputData,
share_inputs: Dict[str, paddle.Tensor],
save_each_rank: bool = False,
skip_save_output: bool = False,
):
""""""
# TODO(chenhuan09): support model_output.next_tokens,
speculate_set_stop_value_multi_seqs(
model_output.accept_tokens,
model_output.accept_num,
@@ -423,9 +426,10 @@ def xpu_post_process_specualate(
model_output.stop_token_ids,
model_output.stop_seqs_len,
model_output.eos_token_id,
model_output.min_tokens,
)
speculate_update_v3(
speculate_update(
model_output.seq_lens_encoder,
model_output.seq_lens_decoder,
model_output.not_need_stop,
@@ -436,16 +440,24 @@ def xpu_post_process_specualate(
model_output.stop_flags,
model_output.seq_lens_this_time,
model_output.is_block_step,
model_output.stop_nums,
model_output.mask_rollback,
)
if not skip_save_output:
speculate_save_output(
model_output.accept_tokens,
model_output.accept_num,
model_output.not_need_stop,
model_output.mp_rank,
save_each_rank, # False
)
if sampler_output.logprobs_tensors is None:
speculate_save_output(
model_output.accept_tokens,
model_output.accept_num,
model_output.not_need_stop,
model_output.seq_lens_decoder,
model_output.prompt_lens,
share_inputs["preempted_idx"],
model_output.mp_rank,
save_each_rank,
bool(envs.ENABLE_V1_KVCACHE_SCHEDULER),
)
else:
# TODO(chenhuan09): support speculate_save_output_topk
raise NotImplementedError("Not support speculate_save_output_topk now.")
speculate_clear_accept_nums(model_output.accept_num, model_output.seq_lens_decoder)