mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 17:11:21 +08:00
[Speculative Decoding] Fix attn_mask_offset for multi-step MTP in mixed and PD-split modes (#5738)
* fix attn_mask_offset in mtp with multi-step and pd-split-mode * fix xpu operator register * update pmtp multi-step mtp strategy in pd-split-mode * add note * fix xpu register
This commit is contained in:
@@ -87,6 +87,8 @@ def draft_model_preprocess_kernel(
|
||||
is_block_step,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
mask_rollback,
|
||||
recompute_token_num,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
@@ -114,6 +116,7 @@ def draft_model_preprocess_kernel(
|
||||
base_model_seq_len_decoder = base_model_seq_lens_decoder[tid]
|
||||
base_model_seq_len_this_time = base_model_seq_lens_this_time[tid]
|
||||
pre_ids_now = pre_ids[tid]
|
||||
recompute_token_num_now = recompute_token_num[tid]
|
||||
|
||||
base_model_draft_tokens_now[1:base_model_draft_tokens_len] = -1
|
||||
|
||||
@@ -156,8 +159,10 @@ def draft_model_preprocess_kernel(
|
||||
step_idx[tid] = base_model_step_idx[tid] - base_model_seq_len_this_time
|
||||
else:
|
||||
# 2: Last base model generated token and first MTP token
|
||||
seq_lens_decoder[tid] -= num_model_step - 1
|
||||
step_idx[tid] -= num_model_step - 1
|
||||
seq_lens_decoder[tid] -= recompute_token_num_now
|
||||
step_idx[tid] -= recompute_token_num_now
|
||||
mask_rollback[tid] += recompute_token_num_now
|
||||
recompute_token_num[tid] = num_model_step - 1
|
||||
|
||||
for i in range(accept_num_now):
|
||||
draft_tokens_now[i] = accept_tokens_now[i]
|
||||
@@ -187,6 +192,8 @@ def DispatchRunner(
|
||||
is_block_step,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
mask_rollback,
|
||||
recompute_token_num,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
@@ -244,6 +251,8 @@ def DispatchRunner(
|
||||
is_block_step,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
mask_rollback,
|
||||
recompute_token_num,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
@@ -273,6 +282,8 @@ def draft_model_preprocess_ref(
|
||||
is_block_step,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
mask_rollback,
|
||||
recompute_token_num,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
@@ -301,6 +312,8 @@ def draft_model_preprocess_ref(
|
||||
is_block_step,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
mask_rollback,
|
||||
recompute_token_num,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
@@ -318,7 +331,7 @@ def draft_model_preprocess_ref(
|
||||
)
|
||||
|
||||
|
||||
class TestDraftModelPreprocess:
|
||||
class TestDraftModelPreprocess(unittest.TestCase):
|
||||
def _run_tests(self):
|
||||
paddle.seed(2022)
|
||||
|
||||
@@ -343,6 +356,8 @@ class TestDraftModelPreprocess:
|
||||
not_need_stop = paddle.zeros([1], dtype="bool").cpu()
|
||||
is_block_step = paddle.zeros([bsz], dtype="bool")
|
||||
batch_drop = paddle.zeros([bsz], dtype="bool")
|
||||
mask_rollback = paddle.zeros([bsz], dtype="int32")
|
||||
recompute_token_num = paddle.zeros([bsz], dtype="int32")
|
||||
|
||||
# Output tensors
|
||||
accept_tokens = paddle.randint(0, 100, [bsz, 100], dtype="int64")
|
||||
@@ -371,6 +386,8 @@ class TestDraftModelPreprocess:
|
||||
is_block_step,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
mask_rollback,
|
||||
recompute_token_num,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
@@ -393,13 +410,8 @@ class TestDraftModelPreprocess:
|
||||
|
||||
def test_draft_model_preprocess(self):
|
||||
results1, results2 = self._run_tests()
|
||||
np.testing.assert_allclose(results1[0], results2[0]) # draft_tokens
|
||||
np.testing.assert_allclose(results1[1], results2[1]) # input_ids
|
||||
np.testing.assert_allclose(results1[2], results2[2]) # stop_flags
|
||||
np.testing.assert_allclose(results1[3], results2[3]) # seq_lens_this_time
|
||||
np.testing.assert_allclose(results1[11], results2[11]) # accept_tokens
|
||||
np.testing.assert_allclose(results1[12], results2[12]) # accept_num
|
||||
np.testing.assert_allclose(results1[7], results2[7]) # not_need_stop
|
||||
for i in range(12):
|
||||
np.testing.assert_equal(results1[i].numpy(), results2[i].numpy())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user