[MTP] refactor MTP pre_process (#6358)
@@ -750,7 +750,14 @@ class MTPProposer(Proposer):
            self.model_inputs["seq_lens_encoder"],
            self.model_inputs["seq_lens_decoder"],
            self.model_inputs["step_idx"],
            self.model_inputs["output_cum_offsets"],
            # Note(ZKK):
            # I strongly advise the XPU maintainers to drop the `output_cum_offsets`
            # name in the XPU backend, as done in my PR
            # https://github.com/PaddlePaddle/FastDeploy/pull/6358
            (
                self.model_inputs["cu_seqlens_q_output"]
                if current_platform.is_cuda()
                else self.model_inputs["output_cum_offsets"]
            ),
            self.model_inputs["stop_flags"],
            self.model_inputs["not_need_stop"],
            self.model_inputs["max_dec_len"],
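The CUDA branch above passes `cu_seqlens_q_output` where other backends still receive `output_cum_offsets`. As a rough illustration only (not code from this PR), a cumulative sequence-length tensor of this kind is conventionally an exclusive prefix sum over per-request output token counts; `build_cu_seqlens` below is a hypothetical helper sketched under that assumption:

    import paddle

    def build_cu_seqlens(seq_lens: paddle.Tensor) -> paddle.Tensor:
        """seq_lens: int32 tensor of shape [batch], output tokens per request."""
        zero = paddle.zeros([1], dtype=seq_lens.dtype)
        # Exclusive prefix sum: cu[i]:cu[i+1] indexes request i's tokens
        # in the packed (unpadded) output buffer.
        return paddle.concat([zero, paddle.cumsum(seq_lens, axis=0)])

    # Example: three requests producing 1, 3 and 2 output tokens.
    print(build_cu_seqlens(paddle.to_tensor([1, 3, 2], dtype="int32")).numpy())  # [0 1 4 6]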
@@ -805,8 +812,8 @@ class MTPProposer(Proposer):
            batch_id_per_token,
            cu_seqlens_q,
            cu_seqlens_k,
            output_cum_offsets,
            output_padding_offset,
            cu_seqlens_q_output,
            batch_id_per_token_output,
        ) = pre_process(
            token_num_cpu,
            self.model_inputs["input_ids"],
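For orientation, `batch_id_per_token_output` is the kind of mapping that records, for each token in the packed output buffer, which request it belongs to. A hypothetical helper (name and construction assumed, not taken from pre_process) that builds such a mapping:

    import paddle

    def build_batch_id_per_token(seq_lens: paddle.Tensor) -> paddle.Tensor:
        """seq_lens: int32 tensor [batch]; returns one batch id per packed token."""
        ids = [
            paddle.full([int(n)], i, dtype="int32")
            for i, n in enumerate(seq_lens.numpy())
        ]
        return paddle.concat(ids)

    print(build_batch_id_per_token(paddle.to_tensor([1, 3, 2], dtype="int32")).numpy())
    # [0 1 1 1 2 2]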
@@ -841,8 +848,8 @@ class MTPProposer(Proposer):
        self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)

        # For speculative decoding
        self.model_inputs["output_cum_offsets"].copy_(output_cum_offsets, False)
        self.model_inputs["output_padding_offset"].copy_(output_padding_offset, False)
        self.model_inputs["cu_seqlens_q_output"].copy_(cu_seqlens_q_output, False)
        self.model_inputs["batch_id_per_token_output"].copy_(batch_id_per_token_output, False)

        # Initialize forward meta data
        self._initialize_forward_meta(
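The buffers returned by pre_process are written into preallocated `model_inputs` tensors in place rather than rebinding the dict entries, mirroring the existing cu_seqlens_q/cu_seqlens_k updates. A minimal sketch of that pattern, assuming the `Tensor.copy_(src, blocking)` usage seen above, where False requests a non-blocking copy:

    import paddle

    model_inputs = {"cu_seqlens_q_output": paddle.zeros([4], dtype="int32")}
    new_values = paddle.to_tensor([0, 1, 4, 6], dtype="int32")

    # Overwrite the preallocated buffer in place; downstream code keeps
    # holding the same tensor object across decoding steps.
    model_inputs["cu_seqlens_q_output"].copy_(new_values, False)
    print(model_inputs["cu_seqlens_q_output"].numpy())  # [0 1 4 6]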
@@ -891,8 +898,8 @@ class MTPProposer(Proposer):
            self.model_inputs["seq_lens_this_time"],
            self.model_inputs["seq_lens_decoder"],
            self.model_inputs["seq_lens_encoder"],
            self.model_inputs["output_padding_offset"],
            self.model_config.max_model_len,
            self.model_inputs["batch_id_per_token_output"],
            self.model_inputs["cu_seqlens_q_output"],
            self.model_inputs["first_token_hidden_states"],
            self.enable_logprob if substep == 0 else False,
        )
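The final argument gates log-prob computation so it is only requested on the first speculative substep; later draft substeps pass False. A trivial illustration of that gating (semantics assumed, helper name hypothetical):

    def want_logprob(enable_logprob: bool, substep: int) -> bool:
        # Only the first substep of a speculative step reports log probs.
        return enable_logprob if substep == 0 else False

    print([want_logprob(True, s) for s in range(3)])  # [True, False, False]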