[MTP] refactor MTP pre_process (#6358)

This commit is contained in:
周周周
2026-02-09 10:47:15 +08:00
committed by GitHub
parent 18e79dd660
commit 2b4748de4f
24 changed files with 411 additions and 533 deletions
+14 -7
View File
@@ -750,7 +750,14 @@ class MTPProposer(Proposer):
self.model_inputs["seq_lens_encoder"],
self.model_inputs["seq_lens_decoder"],
self.model_inputs["step_idx"],
self.model_inputs["output_cum_offsets"],
# Note(ZKK):
# TODO: rename `output_cum_offsets` to `cu_seqlens_q_output` in the XPU backend
# to match the CUDA naming, as done in https://github.com/PaddlePaddle/FastDeploy/pull/6358
(
self.model_inputs["cu_seqlens_q_output"]
if current_platform.is_cuda()
else self.model_inputs["output_cum_offsets"]
),
self.model_inputs["stop_flags"],
self.model_inputs["not_need_stop"],
self.model_inputs["max_dec_len"],
@@ -805,8 +812,8 @@ class MTPProposer(Proposer):
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k,
output_cum_offsets,
output_padding_offset,
cu_seqlens_q_output,
batch_id_per_token_output,
) = pre_process(
token_num_cpu,
self.model_inputs["input_ids"],
@@ -841,8 +848,8 @@ class MTPProposer(Proposer):
self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
# For speculative decoding
self.model_inputs["output_cum_offsets"].copy_(output_cum_offsets, False)
self.model_inputs["output_padding_offset"].copy_(output_padding_offset, False)
self.model_inputs["cu_seqlens_q_output"].copy_(cu_seqlens_q_output, False)
self.model_inputs["batch_id_per_token_output"].copy_(batch_id_per_token_output, False)
# Initialize forward meta data
self._initialize_forward_meta(
@@ -891,8 +898,8 @@ class MTPProposer(Proposer):
self.model_inputs["seq_lens_this_time"],
self.model_inputs["seq_lens_decoder"],
self.model_inputs["seq_lens_encoder"],
self.model_inputs["output_padding_offset"],
self.model_config.max_model_len,
self.model_inputs["batch_id_per_token_output"],
self.model_inputs["cu_seqlens_q_output"],
self.model_inputs["first_token_hidden_states"],
self.enable_logprob if substep == 0 else False,
)