[Speculative Decoding] Support mtp super ultra overlap in pd-split mode with insert_task overlap (#7323)

* support mtp overlap in pd-split mode with insert_task overlap
This commit is contained in:
freeliuzc
2026-04-13 19:41:17 +08:00
committed by GitHub
parent 5ddd1af756
commit 31e2a8bbad
6 changed files with 351 additions and 122 deletions
+36 -30
View File
@@ -49,7 +49,10 @@ if current_platform.is_xpu():
share_external_data,
update_attn_mask_offsets,
)
# temporary solution
from fastdeploy.model_executor.xpu_pre_and_post_process import (
async_set_value,
xpu_pre_process,
xpu_process_output,
)
@@ -483,28 +486,32 @@ class MTPProposer(Proposer):
input_ids = request.prompt_token_ids + request.output_token_ids
self.model_inputs["input_ids_len"][idx] = length - 1
self.model_inputs["pre_ids"][idx : idx + 1] = -1
async_set_value(self.model_inputs["pre_ids"][idx : idx + 1], -1)
self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs["input_ids"][
idx : idx + 1, 1:length
]
self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = self.target_model_inputs[
"input_ids"
][idx : idx + 1, 1:length].cpu()
# TODO: replace input_ids_cpu with token_all_ids
if getattr(self, "hybrid_mode", False) and "input_ids_cpu" in self.model_inputs:
self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = self.target_model_inputs[
"input_ids"
][idx : idx + 1, 1:length].cpu()
encoder_block_num = len(request.block_tables)
self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
self.model_inputs["block_tables"][idx : idx + 1, :] = -1
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
request.block_tables, dtype="int32"
async_set_value(self.model_inputs["encoder_block_lens"][idx : idx + 1], encoder_block_num)
async_set_value(self.model_inputs["block_tables"][idx : idx + 1, :], -1)
async_set_value(
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num], request.block_tables
)
self.model_inputs["stop_flags"][idx : idx + 1] = False
self.model_inputs["batch_drop"][idx : idx + 1] = False
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length
async_set_value(self.model_inputs["stop_flags"][idx : idx + 1], False)
async_set_value(self.model_inputs["batch_drop"][idx : idx + 1], False)
async_set_value(self.model_inputs["seq_lens_encoder"][idx : idx + 1], length)
self.exist_prefill_flag = True
self.model_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index
self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length
self.model_inputs["step_idx"][idx : idx + 1] = (
len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
async_set_value(self.model_inputs["seq_lens_decoder"][idx : idx + 1], prefill_start_index)
async_set_value(self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1], length)
async_set_value(
self.model_inputs["step_idx"][idx : idx + 1],
len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0,
)
if self.use_attn_mask_offset:
inputs = request.multimodal_inputs
@@ -522,18 +529,19 @@ class MTPProposer(Proposer):
if (
self.fd_config.scheduler_config.splitwise_role == "decode"
): # In PD, we continue to decode after P generates first token
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
async_set_value(self.model_inputs["seq_lens_encoder"][idx : idx + 1], 0)
self.exist_prefill_flag = False
self.model_inputs["recompute_token_num"][idx : idx + 1] = 0
self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length + 1
async_set_value(self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1], length + 1)
# NOTE(liuzichang):
# extra 1: the P-D split needs to roll back one step
self.model_inputs["mask_rollback"][idx : idx + 1] = 1
async_set_value(self.model_inputs["recompute_token_num"][idx : idx + 1], 0)
async_set_value(self.model_inputs["mask_rollback"][idx : idx + 1], 1)
# has_prefill_task = True
elif request.task_type.value == RequestType.DECODE.value: # decode task
encoder_block_num = len(request.block_tables)
self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
self.model_inputs["block_tables"][idx : idx + 1, :] = -1
async_set_value(self.model_inputs["encoder_block_lens"][idx : idx + 1], encoder_block_num)
async_set_value(self.model_inputs["block_tables"][idx : idx + 1, :], -1)
if current_platform.is_cuda():
async_set_value(
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num], request.block_tables
@@ -542,16 +550,13 @@ class MTPProposer(Proposer):
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
request.block_tables, dtype="int32"
)
# if self.model_inputs["is_block_step"][idx]: # has tasks to continue to decode
# has_decode_task = True
# continue
else:
self.model_inputs["block_tables"][idx : idx + 1, :] = -1
self.model_inputs["stop_flags"][idx : idx + 1] = True
self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = 0
self.model_inputs["seq_lens_decoder"][idx : idx + 1] = 0
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
self.model_inputs["is_block_step"][idx : idx + 1] = False
async_set_value(self.model_inputs["block_tables"][idx : idx + 1, :], -1)
async_set_value(self.model_inputs["stop_flags"][idx : idx + 1], True)
async_set_value(self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1], 0)
async_set_value(self.model_inputs["seq_lens_decoder"][idx : idx + 1], 0)
async_set_value(self.model_inputs["seq_lens_encoder"][idx : idx + 1], 0)
async_set_value(self.model_inputs["is_block_step"][idx : idx + 1], False)
continue
# TODO(liuzichang): Solve splitwise-p bug to restore
@@ -1233,6 +1238,7 @@ class MTPProposer(Proposer):
)
def _extend_draft_token_with_ngram_match(self):
# TODO: replace with gpu tensor
hybrid_mtp_ngram(
self.model_inputs["input_ids_cpu"].cuda(),
self.model_inputs["input_ids_len"].cuda(),