[Speculative Decoding] Fix attn_mask_offset for multi-step MTP in mixed and PD-split modes (#5738)

* fix attn_mask_offset in MTP with multi-step and PD-split mode

* fix XPU operator registration

* update pmtp multi-step MTP strategy in PD-split mode

* add note

* fix XPU registration
Author: freeliuzc
Committed: 2025-12-25 17:54:59 +08:00 (committed by GitHub)
Commit: 9018ccf74e
Parent: 7247dc5f3a
6 changed files with 227 additions and 172 deletions
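Note (not part of the diff): the two tensors threaded through draft_model_preprocess below are plain per-request int32 buffers. A minimal sketch of their allocation, with their intended meaning inferred from the commit message and the hunks that follow; `bsz` is an illustrative batch size and the comments are interpretation, not documentation from the source:

import paddle

bsz = 8  # illustrative batch size

# recompute_token_num[i]: how many draft tokens request i has to roll back
#                         before the next multi-step MTP round (scheduled as
#                         num_model_step - 1 when a rollback is armed).
# mask_rollback[i]:       accumulated rollback amount, used to keep the
#                         attention-mask offset aligned with the reduced
#                         seq_lens_decoder in mixed and PD-split modes.
recompute_token_num = paddle.zeros([bsz], dtype="int32")
mask_rollback = paddle.zeros([bsz], dtype="int32")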
@@ -87,6 +87,8 @@ def draft_model_preprocess_kernel(
     is_block_step,
     batch_drop,
     pre_ids,
+    mask_rollback,
+    recompute_token_num,
     accept_tokens,
     accept_num,
     base_model_seq_lens_this_time,
@@ -114,6 +116,7 @@ def draft_model_preprocess_kernel(
         base_model_seq_len_decoder = base_model_seq_lens_decoder[tid]
         base_model_seq_len_this_time = base_model_seq_lens_this_time[tid]
         pre_ids_now = pre_ids[tid]
+        recompute_token_num_now = recompute_token_num[tid]
         base_model_draft_tokens_now[1:base_model_draft_tokens_len] = -1
@@ -156,8 +159,10 @@ def draft_model_preprocess_kernel(
                 step_idx[tid] = base_model_step_idx[tid] - base_model_seq_len_this_time
             else:
                 # 2: Last base model generated token and first MTP token
-                seq_lens_decoder[tid] -= num_model_step - 1
-                step_idx[tid] -= num_model_step - 1
+                seq_lens_decoder[tid] -= recompute_token_num_now
+                step_idx[tid] -= recompute_token_num_now
+                mask_rollback[tid] += recompute_token_num_now
+                recompute_token_num[tid] = num_model_step - 1
             for i in range(accept_num_now):
                 draft_tokens_now[i] = accept_tokens_now[i]
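The hunk above is the core behavioral change: the rollback is no longer a fixed num_model_step - 1 but the amount recorded during the previous step, and that amount is also accumulated into mask_rollback so the attention-mask offset can be rolled back by the same distance. A hedged restatement as a standalone Python function (the name rollback_draft_state is hypothetical; names of the buffers follow the diff and the surrounding control flow is simplified), followed by a hypothetical two-round walkthrough:

def rollback_draft_state(tid, seq_lens_decoder, step_idx, mask_rollback,
                         recompute_token_num, num_model_step):
    # Roll back by what was actually scheduled last step, not by a constant.
    recompute_token_num_now = int(recompute_token_num[tid])
    seq_lens_decoder[tid] -= recompute_token_num_now
    step_idx[tid] -= recompute_token_num_now
    # Record how far we rolled back so the attention mask offset can follow.
    mask_rollback[tid] += recompute_token_num_now
    # Arm the rollback for the next round of multi-step drafting.
    recompute_token_num[tid] = num_model_step - 1


# Hypothetical walkthrough for one request with num_model_step = 3.
seq_lens_decoder, step_idx = [100], [10]
mask_rollback, recompute_token_num = [0], [0]

rollback_draft_state(0, seq_lens_decoder, step_idx, mask_rollback, recompute_token_num, 3)
# Round 1: nothing scheduled yet, lengths untouched; 2 tokens are now scheduled.
assert (seq_lens_decoder[0], step_idx[0], mask_rollback[0], recompute_token_num[0]) == (100, 10, 0, 2)

rollback_draft_state(0, seq_lens_decoder, step_idx, mask_rollback, recompute_token_num, 3)
# Round 2: the 2 scheduled tokens are rolled back and recorded in mask_rollback.
assert (seq_lens_decoder[0], step_idx[0], mask_rollback[0], recompute_token_num[0]) == (98, 8, 2, 2)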
@@ -187,6 +192,8 @@ def DispatchRunner(
     is_block_step,
     batch_drop,
     pre_ids,
+    mask_rollback,
+    recompute_token_num,
     accept_tokens,
     accept_num,
     base_model_seq_lens_this_time,
@@ -244,6 +251,8 @@ def DispatchRunner(
         is_block_step,
         batch_drop,
         pre_ids,
+        mask_rollback,
+        recompute_token_num,
         accept_tokens,
         accept_num,
         base_model_seq_lens_this_time,
@@ -273,6 +282,8 @@ def draft_model_preprocess_ref(
     is_block_step,
     batch_drop,
     pre_ids,
+    mask_rollback,
+    recompute_token_num,
     accept_tokens,
     accept_num,
     base_model_seq_lens_this_time,
@@ -301,6 +312,8 @@ def draft_model_preprocess_ref(
         is_block_step,
         batch_drop,
         pre_ids,
+        mask_rollback,
+        recompute_token_num,
         accept_tokens,
         accept_num,
         base_model_seq_lens_this_time,
@@ -318,7 +331,7 @@ def draft_model_preprocess_ref(
     )
-class TestDraftModelPreprocess:
+class TestDraftModelPreprocess(unittest.TestCase):
     def _run_tests(self):
         paddle.seed(2022)
@@ -343,6 +356,8 @@ class TestDraftModelPreprocess:
         not_need_stop = paddle.zeros([1], dtype="bool").cpu()
         is_block_step = paddle.zeros([bsz], dtype="bool")
         batch_drop = paddle.zeros([bsz], dtype="bool")
+        mask_rollback = paddle.zeros([bsz], dtype="int32")
+        recompute_token_num = paddle.zeros([bsz], dtype="int32")
         # Output tensors
         accept_tokens = paddle.randint(0, 100, [bsz, 100], dtype="int64")
@@ -371,6 +386,8 @@ class TestDraftModelPreprocess:
             is_block_step,
             batch_drop,
             pre_ids,
+            mask_rollback,
+            recompute_token_num,
             accept_tokens,
             accept_num,
             base_model_seq_lens_this_time,
@@ -393,13 +410,8 @@ class TestDraftModelPreprocess:
     def test_draft_model_preprocess(self):
         results1, results2 = self._run_tests()
-        np.testing.assert_allclose(results1[0], results2[0])  # draft_tokens
-        np.testing.assert_allclose(results1[1], results2[1])  # input_ids
-        np.testing.assert_allclose(results1[2], results2[2])  # stop_flags
-        np.testing.assert_allclose(results1[3], results2[3])  # seq_lens_this_time
-        np.testing.assert_allclose(results1[11], results2[11])  # accept_tokens
-        np.testing.assert_allclose(results1[12], results2[12])  # accept_num
-        np.testing.assert_allclose(results1[7], results2[7])  # not_need_stop
+        for i in range(12):
+            np.testing.assert_equal(results1[i].numpy(), results2[i].numpy())
 if __name__ == "__main__":
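Side note on the updated assertions (an observation, not stated in the source): the compared outputs are integer and boolean tensors (token ids, lengths, flags), so exact element-wise equality is appropriate and the looser assert_allclose tolerance is unnecessary; the loop also widens coverage from a hand-picked subset to the first 12 outputs. A minimal standalone equivalent of the new check, assuming results1 and results2 are lists of Paddle tensors (the helper name assert_outputs_match is hypothetical):

import numpy as np

def assert_outputs_match(results1, results2, num_outputs=12):
    # Exact element-wise comparison of the first `num_outputs` output tensors.
    for i in range(num_outputs):
        np.testing.assert_equal(results1[i].numpy(), results2[i].numpy())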