mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Speculative Decoding] Fix attn_mask_offset for multi-step MTP in mixed and PD-split modes (#5738)
* fix attn_mask_offset in mtp with multi-step and pd-split-mode * fix xpu operator register * update pmtp multi-step mtp strategy in pd-split-mode * add note * fix xpu register
This commit is contained in:
@@ -493,6 +493,12 @@ class MTPProposer(Proposer):
|
||||
shape=[self.max_num_seqs + 1], fill_value=0, dtype="int32"
|
||||
)
|
||||
self.model_inputs["mask_rollback"] = paddle.full([self.max_num_seqs, 1], 0, dtype="int32")
|
||||
# NOTE(liuzichang): In speculative decoding, accepted tokens' KV cache is recomputed
|
||||
# using the target model's hidden states.
|
||||
self.model_inputs["recompute_token_num"] = paddle.full(
|
||||
[self.max_num_seqs, 1], self.num_model_steps - 1, dtype="int32"
|
||||
)
|
||||
|
||||
# attn_mask
|
||||
if self.enable_mm:
|
||||
self.model_inputs["attn_mask_offsets"] = paddle.full(
|
||||
@@ -562,7 +568,9 @@ class MTPProposer(Proposer):
|
||||
self.fd_config.scheduler_config.splitwise_role == "decode"
|
||||
): # In PD, we continue to decode after P generates first token
|
||||
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
|
||||
# P-D split need rollback one step
|
||||
self.model_inputs["recompute_token_num"][idx : idx + 1] = 0
|
||||
# NOTE(liuzichang):
|
||||
# extra 1 : P-D split need rollback one step
|
||||
self.model_inputs["mask_rollback"][idx : idx + 1] = 1
|
||||
|
||||
# has_prefill_task = True
|
||||
@@ -758,6 +766,8 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["batch_drop"],
|
||||
self.model_inputs["is_block_step"],
|
||||
self.model_inputs["pre_ids"],
|
||||
self.model_inputs["mask_rollback"],
|
||||
self.model_inputs["recompute_token_num"],
|
||||
self.target_model_inputs["accept_tokens"],
|
||||
self.target_model_inputs["accept_num"],
|
||||
self.target_model_inputs["seq_lens_this_time"],
|
||||
|
||||
Reference in New Issue
Block a user