[MTP] refactor MTP pre_process (#6358)

This commit is contained in:
周周周
2026-02-09 10:47:15 +08:00
committed by GitHub
parent 18e79dd660
commit 2b4748de4f
24 changed files with 411 additions and 533 deletions
+30 -11
View File
@@ -272,12 +272,20 @@ class InputBatch:
fill_value=max_draft_token_num,
dtype="int32",
)
self.output_cum_offsets = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
self.output_padding_offset = paddle.full(
shape=[max_num_seqs * (max_draft_token_num + 1)],
fill_value=0,
dtype="int32",
)
if current_platform.is_cuda():
self.cu_seqlens_q_output = paddle.full(shape=[max_num_seqs + 1, 1], fill_value=0, dtype="int32")
self.batch_id_per_token_output = paddle.full(
shape=[max_num_seqs * (max_draft_token_num + 1)],
fill_value=0,
dtype="int32",
)
else:
self.output_cum_offsets = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
self.output_padding_offset = paddle.full(
shape=[max_num_seqs * (max_draft_token_num + 1)],
fill_value=0,
dtype="int32",
)
# For V1_KVCACHE_SCHEDULER
self.step_draft_tokens = paddle.full(
shape=[max_num_seqs, max_draft_token_num + 1],
@@ -404,7 +412,10 @@ class InputBatch:
swap_data(self.accept_num, i1, i2)
swap_data(self.draft_tokens, i1, i2)
swap_data(self.actual_draft_token_num, i1, i2)
swap_data(self.output_cum_offsets, i1, i2)
if current_platform.is_cuda():
swap_data(self.cu_seqlens_q_output, i1, i2)
else:
swap_data(self.output_cum_offsets, i1, i2)
swap_data(self.step_draft_tokens, i1, i2)
swap_data(self.step_seq_lens_this_time, i1, i2)
swap_data(self.draft_logits, i1, i2)
@@ -512,8 +523,12 @@ class ProposerInputBatch(InputBatch):
self.stop_flags = paddle.clone(self.target_model_input_batch["stop_flags"])
self.not_need_stop = paddle.to_tensor([False], dtype="bool", place="cpu")
self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"])
self.output_cum_offsets = paddle.clone(self.target_model_input_batch["output_cum_offsets"])
self.output_padding_offset = paddle.clone(self.target_model_input_batch["output_padding_offset"])
if current_platform.is_cuda():
self.cu_seqlens_q_output = paddle.clone(self.target_model_input_batch["cu_seqlens_q_output"])
self.batch_id_per_token_output = paddle.clone(self.target_model_input_batch["batch_id_per_token_output"])
else:
self.output_cum_offsets = paddle.clone(self.target_model_input_batch["output_cum_offsets"])
self.output_padding_offset = paddle.clone(self.target_model_input_batch["output_padding_offset"])
self.ids_remove_padding = paddle.clone(self.target_model_input_batch["ids_remove_padding"])
self.batch_id_per_token = paddle.clone(self.target_model_input_batch["batch_id_per_token"])
self.cu_seqlens_q = paddle.clone(self.target_model_input_batch["cu_seqlens_q"])
@@ -659,8 +674,12 @@ class ProposerInputBatch(InputBatch):
swap_data(self.stop_flags, i1, i2)
swap_data(self.not_need_stop, i1, i2)
swap_data(self.pre_ids, i1, i2)
swap_data(self.output_cum_offsets, i1, i2)
swap_data(self.output_padding_offset, i1, i2)
if current_platform.is_cuda():
swap_data(self.cu_seqlens_q_output, i1, i2)
swap_data(self.batch_id_per_token_output, i1, i2)
else:
swap_data(self.output_cum_offsets, i1, i2)
swap_data(self.output_padding_offset, i1, i2)
swap_data(self.ids_remove_padding, i1, i2)
swap_data(self.batch_id_per_token, i1, i2)
swap_data(self.cu_seqlens_q, i1, i2)