mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Speculative Decoding] Support mtp super ultra overlap in pd-split mode with insert_task overlap (#7323)
* support mtp overlap in pd-split mode with insert_task overlap
This commit is contained in:
@@ -49,7 +49,10 @@ if current_platform.is_xpu():
|
||||
share_external_data,
|
||||
update_attn_mask_offsets,
|
||||
)
|
||||
|
||||
# temporary solution
|
||||
from fastdeploy.model_executor.xpu_pre_and_post_process import (
|
||||
async_set_value,
|
||||
xpu_pre_process,
|
||||
xpu_process_output,
|
||||
)
|
||||
@@ -483,28 +486,32 @@ class MTPProposer(Proposer):
|
||||
input_ids = request.prompt_token_ids + request.output_token_ids
|
||||
|
||||
self.model_inputs["input_ids_len"][idx] = length - 1
|
||||
self.model_inputs["pre_ids"][idx : idx + 1] = -1
|
||||
async_set_value(self.model_inputs["pre_ids"][idx : idx + 1], -1)
|
||||
self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs["input_ids"][
|
||||
idx : idx + 1, 1:length
|
||||
]
|
||||
self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = self.target_model_inputs[
|
||||
"input_ids"
|
||||
][idx : idx + 1, 1:length].cpu()
|
||||
# TODO: use token_all_ids replace with input_ids_cpu
|
||||
if getattr(self, "hybrid_mode", False) and "input_ids_cpu" in self.model_inputs:
|
||||
self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = self.target_model_inputs[
|
||||
"input_ids"
|
||||
][idx : idx + 1, 1:length].cpu()
|
||||
encoder_block_num = len(request.block_tables)
|
||||
self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
|
||||
self.model_inputs["block_tables"][idx : idx + 1, :] = -1
|
||||
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
|
||||
request.block_tables, dtype="int32"
|
||||
async_set_value(self.model_inputs["encoder_block_lens"][idx : idx + 1], encoder_block_num)
|
||||
async_set_value(self.model_inputs["block_tables"][idx : idx + 1, :], -1)
|
||||
async_set_value(
|
||||
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num], request.block_tables
|
||||
)
|
||||
self.model_inputs["stop_flags"][idx : idx + 1] = False
|
||||
self.model_inputs["batch_drop"][idx : idx + 1] = False
|
||||
|
||||
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length
|
||||
async_set_value(self.model_inputs["stop_flags"][idx : idx + 1], False)
|
||||
async_set_value(self.model_inputs["batch_drop"][idx : idx + 1], False)
|
||||
|
||||
async_set_value(self.model_inputs["seq_lens_encoder"][idx : idx + 1], length)
|
||||
self.exist_prefill_flag = True
|
||||
self.model_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index
|
||||
self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length
|
||||
self.model_inputs["step_idx"][idx : idx + 1] = (
|
||||
len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
|
||||
async_set_value(self.model_inputs["seq_lens_decoder"][idx : idx + 1], prefill_start_index)
|
||||
async_set_value(self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1], length)
|
||||
async_set_value(
|
||||
self.model_inputs["step_idx"][idx : idx + 1],
|
||||
len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0,
|
||||
)
|
||||
if self.use_attn_mask_offset:
|
||||
inputs = request.multimodal_inputs
|
||||
@@ -522,18 +529,19 @@ class MTPProposer(Proposer):
|
||||
if (
|
||||
self.fd_config.scheduler_config.splitwise_role == "decode"
|
||||
): # In PD, we continue to decode after P generates first token
|
||||
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
|
||||
async_set_value(self.model_inputs["seq_lens_encoder"][idx : idx + 1], 0)
|
||||
self.exist_prefill_flag = False
|
||||
self.model_inputs["recompute_token_num"][idx : idx + 1] = 0
|
||||
self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length + 1
|
||||
async_set_value(self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1], length + 1)
|
||||
# NOTE(liuzichang):
|
||||
# extra 1 : P-D split need rollback one step
|
||||
self.model_inputs["mask_rollback"][idx : idx + 1] = 1
|
||||
|
||||
async_set_value(self.model_inputs["recompute_token_num"][idx : idx + 1], 0)
|
||||
async_set_value(self.model_inputs["mask_rollback"][idx : idx + 1], 1)
|
||||
# has_prefill_task = True
|
||||
elif request.task_type.value == RequestType.DECODE.value: # decode task
|
||||
encoder_block_num = len(request.block_tables)
|
||||
self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
|
||||
self.model_inputs["block_tables"][idx : idx + 1, :] = -1
|
||||
async_set_value(self.model_inputs["encoder_block_lens"][idx : idx + 1], encoder_block_num)
|
||||
async_set_value(self.model_inputs["block_tables"][idx : idx + 1, :], -1)
|
||||
if current_platform.is_cuda():
|
||||
async_set_value(
|
||||
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num], request.block_tables
|
||||
@@ -542,16 +550,13 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
|
||||
request.block_tables, dtype="int32"
|
||||
)
|
||||
# if self.model_inputs["is_block_step"][idx]: # has tasks to continue to decode
|
||||
# has_decode_task = True
|
||||
# continue
|
||||
else:
|
||||
self.model_inputs["block_tables"][idx : idx + 1, :] = -1
|
||||
self.model_inputs["stop_flags"][idx : idx + 1] = True
|
||||
self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = 0
|
||||
self.model_inputs["seq_lens_decoder"][idx : idx + 1] = 0
|
||||
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
|
||||
self.model_inputs["is_block_step"][idx : idx + 1] = False
|
||||
async_set_value(self.model_inputs["block_tables"][idx : idx + 1, :], -1)
|
||||
async_set_value(self.model_inputs["stop_flags"][idx : idx + 1], True)
|
||||
async_set_value(self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1], 0)
|
||||
async_set_value(self.model_inputs["seq_lens_decoder"][idx : idx + 1], 0)
|
||||
async_set_value(self.model_inputs["seq_lens_encoder"][idx : idx + 1], 0)
|
||||
async_set_value(self.model_inputs["is_block_step"][idx : idx + 1], False)
|
||||
continue
|
||||
|
||||
# TODO(liuzichang): Solve splitewise-p bug to restore
|
||||
@@ -1233,6 +1238,7 @@ class MTPProposer(Proposer):
|
||||
)
|
||||
|
||||
def _extend_draft_token_with_ngram_match(self):
|
||||
# TODO: replace with gpu tensor
|
||||
hybrid_mtp_ngram(
|
||||
self.model_inputs["input_ids_cpu"].cuda(),
|
||||
self.model_inputs["input_ids_len"].cuda(),
|
||||
|
||||
Reference in New Issue
Block a user