mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-24 01:29:57 +08:00
[XPU] Fix PD + MTP (#6495)
* fix pd + mtp * fix code style * fix PD + MTP, D get P's first token * add anno for gpu(speculate_update) * update draft insertv1 * fix wrapper & kernel * fix wrapper * fix code style
This commit is contained in:
@@ -26,7 +26,7 @@ from fastdeploy.model_executor.forward_meta import XPUForwardMeta
|
||||
from fastdeploy.model_executor.layers.sample.sampler import Sampler
|
||||
from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
|
||||
from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.worker.output import LogprobsTensors, ModelOutputData
|
||||
from fastdeploy.worker.output import LogprobsTensors, ModelOutputData, SamplerOutput
|
||||
|
||||
if current_platform.is_xpu():
|
||||
from fastdeploy.model_executor.ops.xpu import ( # step_system_cache,; step_reschedule,
|
||||
@@ -49,7 +49,7 @@ if current_platform.is_xpu():
|
||||
speculate_step_paddle,
|
||||
speculate_step_reschedule,
|
||||
speculate_step_system_cache,
|
||||
speculate_update_v3,
|
||||
speculate_update,
|
||||
step_paddle,
|
||||
update_inputs,
|
||||
update_inputs_v1,
|
||||
@@ -408,11 +408,14 @@ def xpu_post_process_normal(
|
||||
|
||||
|
||||
def xpu_post_process_specualate(
|
||||
model_output: ModelOutputData, save_each_rank: bool = False, skip_save_output: bool = False
|
||||
sampler_output: SamplerOutput,
|
||||
model_output: ModelOutputData,
|
||||
share_inputs: Dict[str, paddle.Tensor],
|
||||
save_each_rank: bool = False,
|
||||
skip_save_output: bool = False,
|
||||
):
|
||||
""""""
|
||||
|
||||
# TODO(chenhuan09): support model_output.next_tokens,
|
||||
speculate_set_stop_value_multi_seqs(
|
||||
model_output.accept_tokens,
|
||||
model_output.accept_num,
|
||||
@@ -423,9 +426,10 @@ def xpu_post_process_specualate(
|
||||
model_output.stop_token_ids,
|
||||
model_output.stop_seqs_len,
|
||||
model_output.eos_token_id,
|
||||
model_output.min_tokens,
|
||||
)
|
||||
|
||||
speculate_update_v3(
|
||||
speculate_update(
|
||||
model_output.seq_lens_encoder,
|
||||
model_output.seq_lens_decoder,
|
||||
model_output.not_need_stop,
|
||||
@@ -436,16 +440,24 @@ def xpu_post_process_specualate(
|
||||
model_output.stop_flags,
|
||||
model_output.seq_lens_this_time,
|
||||
model_output.is_block_step,
|
||||
model_output.stop_nums,
|
||||
model_output.mask_rollback,
|
||||
)
|
||||
if not skip_save_output:
|
||||
speculate_save_output(
|
||||
model_output.accept_tokens,
|
||||
model_output.accept_num,
|
||||
model_output.not_need_stop,
|
||||
model_output.mp_rank,
|
||||
save_each_rank, # False
|
||||
)
|
||||
if sampler_output.logprobs_tensors is None:
|
||||
speculate_save_output(
|
||||
model_output.accept_tokens,
|
||||
model_output.accept_num,
|
||||
model_output.not_need_stop,
|
||||
model_output.seq_lens_decoder,
|
||||
model_output.prompt_lens,
|
||||
share_inputs["preempted_idx"],
|
||||
model_output.mp_rank,
|
||||
save_each_rank,
|
||||
bool(envs.ENABLE_V1_KVCACHE_SCHEDULER),
|
||||
)
|
||||
else:
|
||||
# TODO(chenhuan09): support speculate_save_output_topk
|
||||
raise NotImplementedError("Not support speculate_save_output_topk now.")
|
||||
|
||||
speculate_clear_accept_nums(model_output.accept_num, model_output.seq_lens_decoder)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user