[Feature] Support mtp overlap schedule (#7001)

Author: sunxin
Date: 2026-04-01 14:24:26 +08:00
Committed by: GitHub
Parent: c6f0c5c3a6
Commit: c29e86fc9d

23 changed files with 215 additions and 138 deletions
@@ -316,6 +316,15 @@ class InputBatch:
dtype="float32",
)
self.cu_batch_token_offset = paddle.full(shape=[max_num_seqs + 1], fill_value=0, dtype="int32")
# For mtp overlap
self.seq_lens_decoder_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32").pin_memory()
self.prompt_lens_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int64").pin_memory()
self.accept_tokens_cpu = paddle.full(
shape=[max_num_seqs, max_draft_token_num + 1],
fill_value=0,
dtype="int64",
).pin_memory()
self.accept_num_cpu = paddle.full(shape=[max_num_seqs], fill_value=0, dtype="int32").pin_memory()
if self.enable_mm:
head_dim = self.model_config.head_dim
if (
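The four `*_cpu` buffers added here live in CUDA pinned (page-locked) host memory, so device-to-host copies of the per-request lengths and accepted draft tokens can overlap with the next scheduling step. A minimal allocation sketch of the same pattern (the standalone variable names are illustrative, not the `InputBatch` fields themselves):

```python
import paddle

max_num_seqs = 8            # illustrative batch capacity
max_draft_token_num = 3     # illustrative MTP draft length

# Pinned (page-locked) host buffers: transfers between the GPU and pinned
# memory avoid an extra staging copy, which is what makes overlapping the
# copy-back with the next step's scheduling cheap.
seq_lens_decoder_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32").pin_memory()
prompt_lens_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int64").pin_memory()
accept_num_cpu = paddle.full([max_num_seqs], 0, dtype="int32").pin_memory()
accept_tokens_cpu = paddle.full(
    [max_num_seqs, max_draft_token_num + 1], 0, dtype="int64"
).pin_memory()

# Pinned tensors report CUDAPinnedPlace rather than CPUPlace, which is why
# _recover_tensor below special-cases them.
print(accept_tokens_cpu.place)
```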
@@ -435,6 +444,10 @@ class InputBatch:
swap_data(self.step_seq_lens_this_time, i1, i2)
swap_data(self.draft_logits, i1, i2)
swap_data(self.cu_batch_token_offset, i1, i2)
swap_data(self.seq_lens_decoder_cpu, i1, i2)
swap_data(self.prompt_lens_cpu, i1, i2)
swap_data(self.accept_tokens_cpu, i1, i2)
swap_data(self.accept_num_cpu, i1, i2)
if self.enable_mm:
if self.image_features_list is not None:
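When two requests trade batch slots, `swap_data` is now also applied to the pinned host mirrors so their rows stay aligned with the device-side tensors; otherwise the overlapped schedule would read stale lengths or accepted tokens for the swapped requests. The helper itself is outside this hunk; a row-swap sketch of what such a helper typically does (the name and body here are assumptions):

```python
import paddle

def swap_data(tensor: paddle.Tensor, i1: int, i2: int) -> None:
    """Hypothetical sketch: exchange the rows for batch slots i1 and i2 in place."""
    tmp = tensor[i1].clone()   # keep a copy of row i1
    tensor[i1] = tensor[i2]    # move row i2 into slot i1
    tensor[i2] = tmp           # put the saved row into slot i2
```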
@@ -623,6 +636,15 @@ class InputBatch:
fill_paddle_tensor(self, "step_seq_lens_this_time", 0)
fill_paddle_tensor(self, "draft_logits", -1)
fill_paddle_tensor(self, "cu_batch_token_offset", 0)
# for mtp overlap
self.prompt_lens_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int64").pin_memory()
self.seq_lens_decoder_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32").pin_memory()
self.accept_num_cpu = paddle.full(shape=[max_num_seqs], fill_value=0, dtype="int32").pin_memory()
self.accept_tokens_cpu = paddle.full(
shape=[max_num_seqs, max_draft_token_num + 1],
fill_value=0,
dtype="int64",
).pin_memory()
# Reset multimodal related tensors
if self.enable_mm:
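On reset the pinned buffers are re-created with their zero fill values rather than cleared through `fill_paddle_tensor`, which guarantees they keep their pinned placement. An in-place alternative, shown only to illustrate what the reset amounts to (assumes the buffers already exist as `InputBatch` attributes):

```python
import paddle

def clear_mtp_overlap_buffers(batch) -> None:
    """Illustrative in-place reset of the pinned MTP-overlap buffers."""
    for name in ("seq_lens_decoder_cpu", "prompt_lens_cpu",
                 "accept_tokens_cpu", "accept_num_cpu"):
        buf = getattr(batch, name)
        buf[:] = 0   # zero the contents without changing the tensor's placement
```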
@@ -697,6 +719,7 @@ class ProposerInputBatch(InputBatch):
self.step_idx = paddle.clone(self.target_model_input_batch["step_idx"])
self.stop_flags = paddle.clone(self.target_model_input_batch["stop_flags"])
self.not_need_stop = paddle.to_tensor([False], dtype="bool", place="cpu")
self.not_need_stop_device = paddle.to_tensor([False], dtype="bool")
if current_platform.is_cuda():
self.cu_seqlens_q_output = paddle.clone(self.target_model_input_batch["cu_seqlens_q_output"])
self.batch_id_per_token_output = paddle.clone(self.target_model_input_batch["batch_id_per_token_output"])
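The proposer keeps the stop condition twice: `not_need_stop` on the CPU for the Python-side scheduler and a new `not_need_stop_device` mirror that device-side code can update without a host round-trip per token. A sketch of the pairing, where the sync helper is an assumption for illustration:

```python
import paddle

# CPU flag polled by the scheduler; device mirror written on the GPU.
not_need_stop = paddle.to_tensor([False], dtype="bool", place="cpu")
not_need_stop_device = paddle.to_tensor([False], dtype="bool")

def refresh_stop_flag():
    """Hypothetical helper: pull the device flag back to the CPU copy once per
    step, instead of synchronizing after every kernel launch."""
    paddle.assign(not_need_stop_device.cpu(), not_need_stop)
```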
@@ -1085,6 +1108,7 @@ def _recover_tensor(recover_tensor, index_to_batch_id_list):
"""
sort_len = len(index_to_batch_id_list)
if isinstance(recover_tensor.place, paddle.CUDAPinnedPlace):
recover_tensor = recover_tensor.cpu()
recover_res_tensor = paddle.empty_like(recover_tensor, device="cpu")
else:
recover_res_tensor = paddle.empty_like(recover_tensor)
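Because the new buffers live in pinned memory, `_recover_tensor` now moves them to plain CPU memory first and rebuilds the reordered result there. A self-contained sketch of the whole recovery step, where the gathering loop at the end is an assumption about the code that follows this hunk:

```python
import paddle

def _recover_tensor_sketch(recover_tensor, index_to_batch_id_list):
    """Illustrative version of the pinned-memory handling shown in the diff."""
    if isinstance(recover_tensor.place, paddle.CUDAPinnedPlace):
        # Pinned host tensors are copied to regular CPU memory and the
        # reordered result is built on the CPU as well.
        recover_tensor = recover_tensor.cpu()
        recover_res_tensor = paddle.empty_like(recover_tensor, device="cpu")
    else:
        recover_res_tensor = paddle.empty_like(recover_tensor)

    # Assumed gather: row i of the result comes from the batch slot that
    # request i occupied before reordering.
    for i, batch_id in enumerate(index_to_batch_id_list):
        recover_res_tensor[i] = recover_tensor[batch_id]
    return recover_res_tensor
```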