[Feature] Support mtp overlap schedule (#7001)

This commit is contained in:
sunxin
2026-04-01 14:24:26 +08:00
committed by GitHub
parent c6f0c5c3a6
commit c29e86fc9d
23 changed files with 215 additions and 138 deletions
+43 -17
View File
@@ -71,7 +71,7 @@ else:
set_data_ipc,
unset_data_ipc,
)
from fastdeploy.model_executor.pre_and_post_process import pre_process
from fastdeploy.model_executor.pre_and_post_process import async_set_value, pre_process
from fastdeploy.worker.input_batch import (
ProposerInputBatch,
@@ -143,6 +143,7 @@ class MTPProposer(Proposer):
# Forward meta store the global meta information of the forward
self.forward_meta = None
self.exist_prefill_flag = False
def _update_mtp_config(self, main_model):
"""
@@ -499,6 +500,7 @@ class MTPProposer(Proposer):
self.model_inputs["batch_drop"][idx : idx + 1] = False
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length
self.exist_prefill_flag = True
self.model_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index
self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length
self.model_inputs["step_idx"][idx : idx + 1] = (
@@ -521,6 +523,7 @@ class MTPProposer(Proposer):
self.fd_config.scheduler_config.splitwise_role == "decode"
): # In PD, we continue to decode after P generates first token
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
self.exist_prefill_flag = False
self.model_inputs["recompute_token_num"][idx : idx + 1] = 0
self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length + 1
# NOTE(liuzichang):
@@ -531,9 +534,14 @@ class MTPProposer(Proposer):
encoder_block_num = len(request.block_tables)
self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
self.model_inputs["block_tables"][idx : idx + 1, :] = -1
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
request.block_tables, dtype="int32"
)
if current_platform.is_cuda():
async_set_value(
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num], request.block_tables
)
else:
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
request.block_tables, dtype="int32"
)
# if self.model_inputs["is_block_step"][idx]: # has tasks to continue to decode
# has_decode_task = True
# continue
@@ -631,7 +639,6 @@ class MTPProposer(Proposer):
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
request.get("block_tables"), dtype="int32"
)
self.model_inputs["not_need_stop"][0] = True
self.model_inputs.seq_lens_this_time = self.model_inputs["seq_lens_this_time_buffer"]
def _initialize_forward_meta(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, substep: int = 0):
@@ -706,10 +713,7 @@ class MTPProposer(Proposer):
"""
check whether prefill stage exist
"""
if np.any(self.share_inputs["seq_lens_encoder"].numpy() > 0):
return 1
else:
return 0
return self.exist_prefill_flag
def _prepare_inputs_cuda(self, full_hidden_states):
"""
@@ -729,7 +733,7 @@ class MTPProposer(Proposer):
self.model_inputs["seq_lens_encoder"],
self.model_inputs["seq_lens_decoder"],
self.model_inputs["step_idx"],
self.model_inputs["not_need_stop"],
self.model_inputs["not_need_stop_device"],
self.model_inputs["pre_ids"],
self.target_model_inputs["accept_tokens"],
self.target_model_inputs["accept_num"],
@@ -822,7 +826,11 @@ class MTPProposer(Proposer):
else self.model_inputs["output_cum_offsets"]
),
self.model_inputs["stop_flags"],
self.model_inputs["not_need_stop"],
(
self.model_inputs["not_need_stop_device"]
if current_platform.is_cuda()
else self.model_inputs["not_need_stop"]
),
self.model_inputs["max_dec_len"],
self.model_inputs["eos_token_id"],
self.model_inputs["base_model_draft_tokens"],
@@ -858,18 +866,30 @@ class MTPProposer(Proposer):
self.model_inputs["step_idx"],
)
def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, real_bsz: int = 0):
"""
Main process for MTP inference.
Args:
step_use_cudagraph: bool
Whether to use cuda graph. Use the target model flag to avoid hanging problems with EP.
"""
is_blocking = (
(not self.fd_config.scheduler_config.enable_overlap_schedule)
or is_dummy_run
or self.exist_prefill()
or real_bsz == 0
)
for substep in range(self.num_model_steps):
if self.model_inputs["not_need_stop"]:
if is_blocking:
token_num_cpu = self.model_inputs["seq_lens_this_time"].numpy().sum().item()
else:
if substep == 0:
token_num_cpu = real_bsz * (self.max_draft_token_num + 1)
else:
token_num_cpu = real_bsz
if token_num_cpu > 0:
self.model_inputs["substep"] = substep
# Remove padding
token_num_cpu = self.model_inputs["seq_lens_this_time"].numpy().sum().item()
(
ids_remove_padding,
batch_id_per_token,
@@ -918,6 +938,7 @@ class MTPProposer(Proposer):
step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run, substep=substep
)
self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False)
self.forward_meta.real_bsz = real_bsz
# Padding inputs for cuda graph
self.padding_cudagraph_inputs()
@@ -1034,8 +1055,9 @@ class MTPProposer(Proposer):
else:
if hasattr(self.model, "empty_input_forward") and not is_dummy_run:
self.model.empty_input_forward(forward_meta=self.forward_meta)
self.exist_prefill_flag = False
def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, real_bsz: int = 0):
"""
Main process for MTP inference.
Args:
@@ -1241,11 +1263,15 @@ class MTPProposer(Proposer):
self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()
def _run_impl(
self, full_hidden_states: paddle.Tensor, step_use_cudagraph: bool = False, is_dummy_run: bool = False
self,
full_hidden_states: paddle.Tensor,
step_use_cudagraph: bool = False,
is_dummy_run: bool = False,
real_bsz: int = 0,
):
"""Execute Draft Model"""
self._prepare_inputs(full_hidden_states)
self._propose(step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run)
self._propose(step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run, real_bsz=real_bsz)
self._update_status()
if self.hybrid_mode:
self._extend_draft_token_with_ngram_match()