[XPU] Speculative Decoding with PD (#5856)

* [XPU] Speculative Decoding with PD

* fix post process

* share kv cache sender

* support speculate decoding step system cache

* support speculate decoding step system cache

---------

Co-authored-by: root <root@gajl-bbc-onlinec-com-1512108.gajl.baidu.com>
This commit is contained in:
cmcamdy
2026-01-05 17:31:03 +08:00
committed by GitHub
parent ac39c0f887
commit 690d4bcdb0
5 changed files with 193 additions and 74 deletions
@@ -21,6 +21,7 @@ import numpy as np
import paddle
from fastdeploy import envs
from fastdeploy.config import SpeculativeConfig
from fastdeploy.model_executor.forward_meta import XPUForwardMeta
from fastdeploy.model_executor.layers.sample.sampler import Sampler
from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
@@ -28,7 +29,7 @@ from fastdeploy.platforms import current_platform
from fastdeploy.worker.output import LogprobsTensors, ModelOutputData
if current_platform.is_xpu():
from fastdeploy.model_executor.ops.xpu import (
from fastdeploy.model_executor.ops.xpu import ( # step_system_cache,; step_reschedule,
adjust_batch,
gather_next_token,
get_infer_param,
@@ -45,11 +46,14 @@ if current_platform.is_xpu():
speculate_save_output,
speculate_set_value_by_flags_and_idx,
speculate_step_paddle,
speculate_step_reschedule,
speculate_step_system_cache,
speculate_update_v3,
step_paddle,
update_inputs,
update_inputs_v1,
)
DISABLE_RECOVER = envs.FD_DISABLED_RECOVER == "1"
def _build_stream_transfer_data(
@@ -232,7 +236,6 @@ def xpu_process_output(
""" """
output_padding_offset = share_inputs.get("output_padding_offset", None)
hiddden_states = gather_next_token(
forward_output,
cum_offsets,
@@ -434,42 +437,101 @@ def step_xpu(
share_inputs: Dict[str, paddle.Tensor],
block_size: int,
enc_dec_block_num: int,
speculative_decoding: bool,
max_draft_token_num: int,
speculative_config: SpeculativeConfig,
enable_prefix_caching: bool = False,
) -> None:
"""
TODO(chenhuan09): support PD
"""
if speculative_decoding:
speculate_step_paddle(
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
share_inputs["step_seq_lens_encoder"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["block_tables"],
share_inputs["encoder_block_lens"],
share_inputs["is_block_step"],
share_inputs["step_block_list"],
share_inputs["step_lens"],
share_inputs["recover_block_list"],
share_inputs["recover_lens"],
share_inputs["need_block_list"],
share_inputs["need_block_len"],
share_inputs["used_list_len"],
share_inputs["free_list"],
share_inputs["free_list_len"],
share_inputs["input_ids"],
share_inputs["pre_ids"],
share_inputs["step_idx"],
share_inputs["next_tokens"],
share_inputs["first_token_ids"],
share_inputs["accept_num"],
block_size,
enc_dec_block_num,
max_draft_token_num,
)
if speculative_config.method is not None:
if DISABLE_RECOVER:
speculate_step_reschedule(
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
share_inputs["step_seq_lens_encoder"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["block_tables"],
share_inputs["encoder_block_lens"],
share_inputs["is_block_step"],
share_inputs["step_block_list"],
share_inputs["step_lens"],
share_inputs["recover_block_list"],
share_inputs["recover_lens"],
share_inputs["need_block_list"],
share_inputs["need_block_len"],
share_inputs["used_list_len"],
share_inputs["free_list"],
share_inputs["free_list_len"],
share_inputs["input_ids"],
share_inputs["pre_ids"],
share_inputs["step_idx"],
share_inputs["next_tokens"],
share_inputs["first_token_ids"],
share_inputs["accept_num"],
block_size,
enc_dec_block_num,
speculative_config.num_speculative_tokens,
)
else:
if enable_prefix_caching:
speculate_step_system_cache(
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
share_inputs["step_seq_lens_encoder"],
share_inputs["step_seq_lens_decoder"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["block_tables"],
share_inputs["encoder_block_lens"],
share_inputs["is_block_step"],
share_inputs["step_block_list"],
share_inputs["step_lens"],
share_inputs["recover_block_list"],
share_inputs["recover_lens"],
share_inputs["need_block_list"],
share_inputs["need_block_len"],
share_inputs["used_list_len"],
share_inputs["free_list"],
share_inputs["free_list_len"],
share_inputs["input_ids"],
share_inputs["pre_ids"],
share_inputs["step_idx"],
share_inputs["next_tokens"],
share_inputs["first_token_ids"],
share_inputs["accept_num"],
block_size,
enc_dec_block_num,
speculative_config.num_speculative_tokens,
)
else:
speculate_step_paddle(
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
share_inputs["step_seq_lens_encoder"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["block_tables"],
share_inputs["encoder_block_lens"],
share_inputs["is_block_step"],
share_inputs["step_block_list"],
share_inputs["step_lens"],
share_inputs["recover_block_list"],
share_inputs["recover_lens"],
share_inputs["need_block_list"],
share_inputs["need_block_len"],
share_inputs["used_list_len"],
share_inputs["free_list"],
share_inputs["free_list_len"],
share_inputs["input_ids"],
share_inputs["pre_ids"],
share_inputs["step_idx"],
share_inputs["next_tokens"],
share_inputs["first_token_ids"],
share_inputs["accept_num"],
block_size,
enc_dec_block_num,
speculative_config.num_speculative_tokens,
)
else:
# TODO(chenhuan09): add step system cache/reschedule support
step_paddle(
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],