mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] Support mtp overlap schedule (#7001)
This commit is contained in:
@@ -71,6 +71,7 @@ class TestMTPProposer(unittest.TestCase):
|
||||
"seq_lens_encoder": paddle.zeros([2, 1], dtype="int32"),
|
||||
"seq_lens_decoder": paddle.zeros([2, 1], dtype="int32"),
|
||||
"prompt_lens": paddle.zeros([2, 1], dtype="int64"),
|
||||
"prompt_lens_cpu": paddle.zeros([2, 1], dtype="int64").pin_memory(),
|
||||
"step_idx": paddle.zeros([2, 1], dtype="int64"),
|
||||
"stop_flags": paddle.zeros([2, 1], dtype="bool"),
|
||||
"token_ids_all": paddle.zeros([2, 2048], dtype="int64"),
|
||||
@@ -379,11 +380,11 @@ class TestMTPProposer(unittest.TestCase):
|
||||
self.assertEqual(proposer.forward_meta.pos_emb_type, "NORMAL")
|
||||
|
||||
# Test exist_prefill
|
||||
proposer.share_inputs = {"seq_lens_encoder": paddle.ones([2, 1], dtype="int32")}
|
||||
proposer.exist_prefill_flag = True
|
||||
result = proposer.exist_prefill()
|
||||
self.assertEqual(result, 1)
|
||||
|
||||
proposer.share_inputs = {"seq_lens_encoder": paddle.zeros([2, 1], dtype="int32")}
|
||||
proposer.exist_prefill_flag = False
|
||||
result = proposer.exist_prefill()
|
||||
self.assertEqual(result, 0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user