mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[XPU] Refactor pre process (#6993)
* [XPU] support speculate_pre_process
* merge develop
* fix code style
* fix mtp, support cu_seqlens_q_output
* fix mtp, support cu_seqlens_q_output
* fix test

---------

Co-authored-by: lizan1999 <lizan03@baidu.com>
This commit is contained in:
@@ -1114,6 +1114,8 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
self.cache_config.block_size,
|
||||
self.speculative_config.num_speculative_tokens if self.speculative_decoding else 0,
|
||||
)
|
||||
|
||||
# TODO(chenhuan): support cached_token_num
|
||||
self.forward_meta = xpu_pre_process(
|
||||
self.share_inputs["input_ids"],
|
||||
self.share_inputs["seq_lens_this_time"],
|
||||
@@ -1577,9 +1579,7 @@ class XPUModelRunner(ModelRunnerBase):
|
||||
)
|
||||
if self.use_cudagraph:
|
||||
model_output = model_output[: self.real_token_num]
|
||||
hidden_states = xpu_process_output(
|
||||
model_output, self.share_inputs["cum_offsets"], self.forward_meta, self.share_inputs
|
||||
)
|
||||
hidden_states = xpu_process_output(model_output, self.forward_meta, self.share_inputs)
|
||||
# 4. Compute logits, Sample
|
||||
logits = self.model.compute_logits(hidden_states)
|
||||
sampler_output = None
|
||||
|
||||
Reference in New Issue
Block a user