mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] Support mtp overlap schedule (#7001)
This commit is contained in:
@@ -71,7 +71,7 @@ else:
|
||||
set_data_ipc,
|
||||
unset_data_ipc,
|
||||
)
|
||||
from fastdeploy.model_executor.pre_and_post_process import pre_process
|
||||
from fastdeploy.model_executor.pre_and_post_process import async_set_value, pre_process
|
||||
|
||||
from fastdeploy.worker.input_batch import (
|
||||
ProposerInputBatch,
|
||||
@@ -143,6 +143,7 @@ class MTPProposer(Proposer):
|
||||
|
||||
# Forward meta stores the global meta information of the forward pass
|
||||
self.forward_meta = None
|
||||
self.exist_prefill_flag = False
|
||||
|
||||
def _update_mtp_config(self, main_model):
|
||||
"""
|
||||
@@ -499,6 +500,7 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["batch_drop"][idx : idx + 1] = False
|
||||
|
||||
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length
|
||||
self.exist_prefill_flag = True
|
||||
self.model_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index
|
||||
self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length
|
||||
self.model_inputs["step_idx"][idx : idx + 1] = (
|
||||
@@ -521,6 +523,7 @@ class MTPProposer(Proposer):
|
||||
self.fd_config.scheduler_config.splitwise_role == "decode"
|
||||
): # In PD, we continue to decode after P generates first token
|
||||
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
|
||||
self.exist_prefill_flag = False
|
||||
self.model_inputs["recompute_token_num"][idx : idx + 1] = 0
|
||||
self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length + 1
|
||||
# NOTE(liuzichang):
|
||||
@@ -531,9 +534,14 @@ class MTPProposer(Proposer):
|
||||
encoder_block_num = len(request.block_tables)
|
||||
self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
|
||||
self.model_inputs["block_tables"][idx : idx + 1, :] = -1
|
||||
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
|
||||
request.block_tables, dtype="int32"
|
||||
)
|
||||
if current_platform.is_cuda():
|
||||
async_set_value(
|
||||
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num], request.block_tables
|
||||
)
|
||||
else:
|
||||
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
|
||||
request.block_tables, dtype="int32"
|
||||
)
|
||||
# if self.model_inputs["is_block_step"][idx]: # has tasks to continue to decode
|
||||
# has_decode_task = True
|
||||
# continue
|
||||
@@ -631,7 +639,6 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
|
||||
request.get("block_tables"), dtype="int32"
|
||||
)
|
||||
self.model_inputs["not_need_stop"][0] = True
|
||||
self.model_inputs.seq_lens_this_time = self.model_inputs["seq_lens_this_time_buffer"]
|
||||
|
||||
def _initialize_forward_meta(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, substep: int = 0):
|
||||
@@ -706,10 +713,7 @@ class MTPProposer(Proposer):
|
||||
"""
|
||||
Check whether a prefill stage exists
|
||||
"""
|
||||
if np.any(self.share_inputs["seq_lens_encoder"].numpy() > 0):
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
return self.exist_prefill_flag
|
||||
|
||||
def _prepare_inputs_cuda(self, full_hidden_states):
|
||||
"""
|
||||
@@ -729,7 +733,7 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["seq_lens_encoder"],
|
||||
self.model_inputs["seq_lens_decoder"],
|
||||
self.model_inputs["step_idx"],
|
||||
self.model_inputs["not_need_stop"],
|
||||
self.model_inputs["not_need_stop_device"],
|
||||
self.model_inputs["pre_ids"],
|
||||
self.target_model_inputs["accept_tokens"],
|
||||
self.target_model_inputs["accept_num"],
|
||||
@@ -822,7 +826,11 @@ class MTPProposer(Proposer):
|
||||
else self.model_inputs["output_cum_offsets"]
|
||||
),
|
||||
self.model_inputs["stop_flags"],
|
||||
self.model_inputs["not_need_stop"],
|
||||
(
|
||||
self.model_inputs["not_need_stop_device"]
|
||||
if current_platform.is_cuda()
|
||||
else self.model_inputs["not_need_stop"]
|
||||
),
|
||||
self.model_inputs["max_dec_len"],
|
||||
self.model_inputs["eos_token_id"],
|
||||
self.model_inputs["base_model_draft_tokens"],
|
||||
@@ -858,18 +866,30 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["step_idx"],
|
||||
)
|
||||
|
||||
def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
|
||||
def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, real_bsz: int = 0):
|
||||
"""
|
||||
Main process for MTP inference.
|
||||
Args:
|
||||
step_use_cudagraph: bool
|
||||
Whether to use cuda graph. Use the target model flag to avoid hanging problems with EP.
|
||||
"""
|
||||
is_blocking = (
|
||||
(not self.fd_config.scheduler_config.enable_overlap_schedule)
|
||||
or is_dummy_run
|
||||
or self.exist_prefill()
|
||||
or real_bsz == 0
|
||||
)
|
||||
for substep in range(self.num_model_steps):
|
||||
if self.model_inputs["not_need_stop"]:
|
||||
if is_blocking:
|
||||
token_num_cpu = self.model_inputs["seq_lens_this_time"].numpy().sum().item()
|
||||
else:
|
||||
if substep == 0:
|
||||
token_num_cpu = real_bsz * (self.max_draft_token_num + 1)
|
||||
else:
|
||||
token_num_cpu = real_bsz
|
||||
if token_num_cpu > 0:
|
||||
self.model_inputs["substep"] = substep
|
||||
# Remove padding
|
||||
token_num_cpu = self.model_inputs["seq_lens_this_time"].numpy().sum().item()
|
||||
(
|
||||
ids_remove_padding,
|
||||
batch_id_per_token,
|
||||
@@ -918,6 +938,7 @@ class MTPProposer(Proposer):
|
||||
step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run, substep=substep
|
||||
)
|
||||
self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False)
|
||||
self.forward_meta.real_bsz = real_bsz
|
||||
|
||||
# Padding inputs for cuda graph
|
||||
self.padding_cudagraph_inputs()
|
||||
@@ -1034,8 +1055,9 @@ class MTPProposer(Proposer):
|
||||
else:
|
||||
if hasattr(self.model, "empty_input_forward") and not is_dummy_run:
|
||||
self.model.empty_input_forward(forward_meta=self.forward_meta)
|
||||
self.exist_prefill_flag = False
|
||||
|
||||
def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
|
||||
def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, real_bsz: int = 0):
|
||||
"""
|
||||
Main process for MTP inference.
|
||||
Args:
|
||||
@@ -1241,11 +1263,15 @@ class MTPProposer(Proposer):
|
||||
self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()
|
||||
|
||||
def _run_impl(
|
||||
self, full_hidden_states: paddle.Tensor, step_use_cudagraph: bool = False, is_dummy_run: bool = False
|
||||
self,
|
||||
full_hidden_states: paddle.Tensor,
|
||||
step_use_cudagraph: bool = False,
|
||||
is_dummy_run: bool = False,
|
||||
real_bsz: int = 0,
|
||||
):
|
||||
"""Execute Draft Model"""
|
||||
self._prepare_inputs(full_hidden_states)
|
||||
self._propose(step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run)
|
||||
self._propose(step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run, real_bsz=real_bsz)
|
||||
self._update_status()
|
||||
if self.hybrid_mode:
|
||||
self._extend_draft_token_with_ngram_match()
|
||||
|
||||
Reference in New Issue
Block a user