[Speculative Decoding] Fix attn_mask_offset for multi-step MTP in mixed and PD-split modes (#5738)

* fix attn_mask_offset in mtp with multi-step and pd-split-mode

* fix xpu operator register

* update pmtp multi-step mtp strategy in pd-split mode

* add note

* fix xpu register
This commit is contained in:
freeliuzc
2025-12-25 17:54:59 +08:00
committed by GitHub
parent 7247dc5f3a
commit 9018ccf74e
6 changed files with 227 additions and 172 deletions
+11 -1
View File
@@ -493,6 +493,12 @@ class MTPProposer(Proposer):
shape=[self.max_num_seqs + 1], fill_value=0, dtype="int32"
)
self.model_inputs["mask_rollback"] = paddle.full([self.max_num_seqs, 1], 0, dtype="int32")
# NOTE(liuzichang): In speculative decoding, accepted tokens' KV cache is recomputed
# using the target model's hidden states.
self.model_inputs["recompute_token_num"] = paddle.full(
[self.max_num_seqs, 1], self.num_model_steps - 1, dtype="int32"
)
# attn_mask
if self.enable_mm:
self.model_inputs["attn_mask_offsets"] = paddle.full(
@@ -562,7 +568,9 @@ class MTPProposer(Proposer):
self.fd_config.scheduler_config.splitwise_role == "decode"
): # In PD, we continue to decode after P generates first token
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
# P-D split need rollback one step
self.model_inputs["recompute_token_num"][idx : idx + 1] = 0
# NOTE(liuzichang):
# extra 1 : P-D split need rollback one step
self.model_inputs["mask_rollback"][idx : idx + 1] = 1
# has_prefill_task = True
@@ -758,6 +766,8 @@ class MTPProposer(Proposer):
self.model_inputs["batch_drop"],
self.model_inputs["is_block_step"],
self.model_inputs["pre_ids"],
self.model_inputs["mask_rollback"],
self.model_inputs["recompute_token_num"],
self.target_model_inputs["accept_tokens"],
self.target_model_inputs["accept_num"],
self.target_model_inputs["seq_lens_this_time"],