mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[XPU] Speculative Decoding with PD (#5856)
* [XPU] Speculative Decoding with PD
* fix post process
* share kv cache sender
* support speculate decoding step system cache
* support speculate decoding step system cache

---------

Co-authored-by: root <root@gajl-bbc-onlinec-com-1512108.gajl.baidu.com>
This commit is contained in:
@@ -21,6 +21,7 @@ import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.config import SpeculativeConfig
|
||||
from fastdeploy.model_executor.forward_meta import XPUForwardMeta
|
||||
from fastdeploy.model_executor.layers.sample.sampler import Sampler
|
||||
from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
|
||||
@@ -28,7 +29,7 @@ from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.worker.output import LogprobsTensors, ModelOutputData
|
||||
|
||||
if current_platform.is_xpu():
|
||||
from fastdeploy.model_executor.ops.xpu import (
|
||||
from fastdeploy.model_executor.ops.xpu import ( # step_system_cache,; step_reschedule,
|
||||
adjust_batch,
|
||||
gather_next_token,
|
||||
get_infer_param,
|
||||
@@ -45,11 +46,14 @@ if current_platform.is_xpu():
|
||||
speculate_save_output,
|
||||
speculate_set_value_by_flags_and_idx,
|
||||
speculate_step_paddle,
|
||||
speculate_step_reschedule,
|
||||
speculate_step_system_cache,
|
||||
speculate_update_v3,
|
||||
step_paddle,
|
||||
update_inputs,
|
||||
update_inputs_v1,
|
||||
)
|
||||
DISABLE_RECOVER = envs.FD_DISABLED_RECOVER == "1"
|
||||
|
||||
|
||||
def _build_stream_transfer_data(
|
||||
@@ -232,7 +236,6 @@ def xpu_process_output(
|
||||
""" """
|
||||
|
||||
output_padding_offset = share_inputs.get("output_padding_offset", None)
|
||||
|
||||
hiddden_states = gather_next_token(
|
||||
forward_output,
|
||||
cum_offsets,
|
||||
@@ -434,42 +437,101 @@ def step_xpu(
|
||||
share_inputs: Dict[str, paddle.Tensor],
|
||||
block_size: int,
|
||||
enc_dec_block_num: int,
|
||||
speculative_decoding: bool,
|
||||
max_draft_token_num: int,
|
||||
speculative_config: SpeculativeConfig,
|
||||
enable_prefix_caching: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
TODO(chenhuan09): support PD
|
||||
"""
|
||||
if speculative_decoding:
|
||||
speculate_step_paddle(
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["step_seq_lens_encoder"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["seq_lens_decoder"],
|
||||
share_inputs["block_tables"],
|
||||
share_inputs["encoder_block_lens"],
|
||||
share_inputs["is_block_step"],
|
||||
share_inputs["step_block_list"],
|
||||
share_inputs["step_lens"],
|
||||
share_inputs["recover_block_list"],
|
||||
share_inputs["recover_lens"],
|
||||
share_inputs["need_block_list"],
|
||||
share_inputs["need_block_len"],
|
||||
share_inputs["used_list_len"],
|
||||
share_inputs["free_list"],
|
||||
share_inputs["free_list_len"],
|
||||
share_inputs["input_ids"],
|
||||
share_inputs["pre_ids"],
|
||||
share_inputs["step_idx"],
|
||||
share_inputs["next_tokens"],
|
||||
share_inputs["first_token_ids"],
|
||||
share_inputs["accept_num"],
|
||||
block_size,
|
||||
enc_dec_block_num,
|
||||
max_draft_token_num,
|
||||
)
|
||||
if speculative_config.method is not None:
|
||||
if DISABLE_RECOVER:
|
||||
speculate_step_reschedule(
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["step_seq_lens_encoder"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["seq_lens_decoder"],
|
||||
share_inputs["block_tables"],
|
||||
share_inputs["encoder_block_lens"],
|
||||
share_inputs["is_block_step"],
|
||||
share_inputs["step_block_list"],
|
||||
share_inputs["step_lens"],
|
||||
share_inputs["recover_block_list"],
|
||||
share_inputs["recover_lens"],
|
||||
share_inputs["need_block_list"],
|
||||
share_inputs["need_block_len"],
|
||||
share_inputs["used_list_len"],
|
||||
share_inputs["free_list"],
|
||||
share_inputs["free_list_len"],
|
||||
share_inputs["input_ids"],
|
||||
share_inputs["pre_ids"],
|
||||
share_inputs["step_idx"],
|
||||
share_inputs["next_tokens"],
|
||||
share_inputs["first_token_ids"],
|
||||
share_inputs["accept_num"],
|
||||
block_size,
|
||||
enc_dec_block_num,
|
||||
speculative_config.num_speculative_tokens,
|
||||
)
|
||||
else:
|
||||
if enable_prefix_caching:
|
||||
speculate_step_system_cache(
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["step_seq_lens_encoder"],
|
||||
share_inputs["step_seq_lens_decoder"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["seq_lens_decoder"],
|
||||
share_inputs["block_tables"],
|
||||
share_inputs["encoder_block_lens"],
|
||||
share_inputs["is_block_step"],
|
||||
share_inputs["step_block_list"],
|
||||
share_inputs["step_lens"],
|
||||
share_inputs["recover_block_list"],
|
||||
share_inputs["recover_lens"],
|
||||
share_inputs["need_block_list"],
|
||||
share_inputs["need_block_len"],
|
||||
share_inputs["used_list_len"],
|
||||
share_inputs["free_list"],
|
||||
share_inputs["free_list_len"],
|
||||
share_inputs["input_ids"],
|
||||
share_inputs["pre_ids"],
|
||||
share_inputs["step_idx"],
|
||||
share_inputs["next_tokens"],
|
||||
share_inputs["first_token_ids"],
|
||||
share_inputs["accept_num"],
|
||||
block_size,
|
||||
enc_dec_block_num,
|
||||
speculative_config.num_speculative_tokens,
|
||||
)
|
||||
else:
|
||||
speculate_step_paddle(
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["step_seq_lens_encoder"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
share_inputs["seq_lens_decoder"],
|
||||
share_inputs["block_tables"],
|
||||
share_inputs["encoder_block_lens"],
|
||||
share_inputs["is_block_step"],
|
||||
share_inputs["step_block_list"],
|
||||
share_inputs["step_lens"],
|
||||
share_inputs["recover_block_list"],
|
||||
share_inputs["recover_lens"],
|
||||
share_inputs["need_block_list"],
|
||||
share_inputs["need_block_len"],
|
||||
share_inputs["used_list_len"],
|
||||
share_inputs["free_list"],
|
||||
share_inputs["free_list_len"],
|
||||
share_inputs["input_ids"],
|
||||
share_inputs["pre_ids"],
|
||||
share_inputs["step_idx"],
|
||||
share_inputs["next_tokens"],
|
||||
share_inputs["first_token_ids"],
|
||||
share_inputs["accept_num"],
|
||||
block_size,
|
||||
enc_dec_block_num,
|
||||
speculative_config.num_speculative_tokens,
|
||||
)
|
||||
else:
|
||||
# TODO(chenhuan09): add step system cache/reschedule support
|
||||
step_paddle(
|
||||
share_inputs["stop_flags"],
|
||||
share_inputs["seq_lens_this_time"],
|
||||
|
||||
Reference in New Issue
Block a user