[Feature] Support mtp overlap schedule (#7001)

This commit is contained in:
sunxin
2026-04-01 14:24:26 +08:00
committed by GitHub
parent c6f0c5c3a6
commit c29e86fc9d
23 changed files with 215 additions and 138 deletions
+3 -2
View File
@@ -71,6 +71,7 @@ class TestMTPProposer(unittest.TestCase):
"seq_lens_encoder": paddle.zeros([2, 1], dtype="int32"),
"seq_lens_decoder": paddle.zeros([2, 1], dtype="int32"),
"prompt_lens": paddle.zeros([2, 1], dtype="int64"),
"prompt_lens_cpu": paddle.zeros([2, 1], dtype="int64").pin_memory(),
"step_idx": paddle.zeros([2, 1], dtype="int64"),
"stop_flags": paddle.zeros([2, 1], dtype="bool"),
"token_ids_all": paddle.zeros([2, 2048], dtype="int64"),
@@ -379,11 +380,11 @@ class TestMTPProposer(unittest.TestCase):
self.assertEqual(proposer.forward_meta.pos_emb_type, "NORMAL")
# Test exist_prefill
proposer.share_inputs = {"seq_lens_encoder": paddle.ones([2, 1], dtype="int32")}
proposer.exist_prefill_flag = True
result = proposer.exist_prefill()
self.assertEqual(result, 1)
proposer.share_inputs = {"seq_lens_encoder": paddle.zeros([2, 1], dtype="int32")}
proposer.exist_prefill_flag = False
result = proposer.exist_prefill()
self.assertEqual(result, 0)