[Feature] Support mtp overlap schedule (#7001)

This commit is contained in:
sunxin
2026-04-01 14:24:26 +08:00
committed by GitHub
parent c6f0c5c3a6
commit c29e86fc9d
23 changed files with 215 additions and 138 deletions
@@ -137,6 +137,7 @@ __global__ void apply_token_enforce_generation_scores_kernel(
int tid = threadIdx.x;
const int bs_idx = batch_id_per_token_output[token_idx];
if (bs_idx < 0) return;
const int query_start_token_idx = cu_seqlens_q_output[bs_idx];
bool is_batch_first_token = (token_idx == query_start_token_idx);
+1
View File
@@ -62,6 +62,7 @@ __global__ void RebuildAppendPaddingKernel(T *output_data,
i += gridDim.x * blockDim.x * VecSize) {
const int out_token_id = i / dim_embed;
const int bi = batch_id_per_token_output[out_token_id];
if (bi < 0) continue;
if (seq_len_this_time[bi] == 0) continue;
if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue;
@@ -176,8 +176,6 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
int pre_ids_len = pre_ids.shape()[1];
auto cu_stream = seq_lens_this_time.stream();
int target_model_draft_tokens_len = target_model_draft_tokens.shape()[1];
auto not_need_stop_gpu =
not_need_stop.copy_to(seq_lens_this_time.place(), false);
draft_model_preprocess_kernel<kBlockSize><<<1, kBlockSize, 0, cu_stream>>>(
const_cast<int64_t*>(draft_tokens.data<int64_t>()),
@@ -187,7 +185,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()),
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
const_cast<bool*>(not_need_stop.data<bool>()),
const_cast<int64_t*>(pre_ids.data<int64_t>()),
accept_tokens.data<int64_t>(),
accept_num.data<int>(),
@@ -205,10 +203,6 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
target_model_draft_tokens_len,
pre_ids_len,
is_splitwise_prefill);
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
PD_BUILD_STATIC_OP(draft_model_preprocess)
@@ -123,8 +123,6 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
auto seq_lens_this_time_shape = seq_lens_this_time.shape();
auto cu_stream = seq_lens_this_time.stream();
const int real_bsz = seq_lens_this_time_shape[0];
auto not_need_stop_gpu =
not_need_stop.copy_to(seq_lens_this_time.place(), false);
const int end_ids_len = end_ids.shape()[0];
const int max_draft_token = draft_tokens.shape()[1];
const int pre_id_length = pre_ids.shape()[1];
@@ -149,7 +147,7 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
const_cast<int64_t*>(step_idx.data<int64_t>()),
cu_seqlens_q_output.data<int>(),
const_cast<bool*>(stop_flags.data<bool>()),
not_need_stop_gpu.data<bool>(),
const_cast<bool*>(not_need_stop.data<bool>()),
max_dec_len.data<int64_t>(),
end_ids.data<int64_t>(),
const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
@@ -161,11 +159,6 @@ void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
max_seq_len,
substep,
prefill_one_step_stop);
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
bool* not_need_stop_data = const_cast<bool*>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
PD_BUILD_STATIC_OP(draft_model_update)
@@ -30,6 +30,7 @@ __global__ inline void min_length_logits_process(
const int token_idx = threadIdx.x;
if (token_idx >= token_num) return;
const int bi = batch_id_per_token_output[token_idx];
if (bi < 0) return;
if (bi >= bs) return;
const int query_start_token_idx = cu_seqlens_q_output[bi];
@@ -59,6 +60,7 @@ __global__ inline void min_length_logits_process<half>(
const int token_idx = threadIdx.x;
if (token_idx >= token_num) return;
const int bi = batch_id_per_token_output[token_idx];
if (bi < 0) return;
if (bi >= bs) return;
const int query_start_token_idx = cu_seqlens_q_output[bi];
@@ -85,6 +87,7 @@ __global__ void update_repeat_times(const int64_t *token_ids_all,
const int token_idx = blockIdx.x;
if (token_idx >= token_num) return;
const int bi = batch_id_per_token_output[token_idx];
if (bi < 0) return;
if (bi >= bs) return;
if (cur_len[bi] < 0) {
return;
@@ -115,6 +118,7 @@ __global__ void update_value_by_repeat_times(
const int token_idx = blockIdx.x;
if (token_idx >= token_num) return;
const int bi = batch_id_per_token_output[token_idx];
if (bi < 0) return;
if (bi >= bs) return;
int tid = threadIdx.x;
T *logits_now = logits + token_idx * length;
@@ -146,7 +150,7 @@ __global__ void ban_bad_words(T *logits,
const int token_idx = blockIdx.x;
if (token_idx >= token_num) return;
const int bi = batch_id_per_token_output[token_idx];
if (bi < 0) return;
if (bi >= bs) return;
int tid = threadIdx.x;
T *logits_now = logits + token_idx * length;
@@ -146,10 +146,10 @@ std::vector<paddle::Tensor> SpeculatePreProcess(
const int bsz = seq_len.shape()[0];
const int max_seq_len = input_ids_shape[1];
const int token_num_data = cpu_token_num;
auto ids_remove_padding = paddle::empty(
{token_num_data}, paddle::DataType::INT64, input_ids.place());
auto batch_id_per_token = paddle::empty(
{token_num_data}, paddle::DataType::INT32, input_ids.place());
auto ids_remove_padding = paddle::full(
{token_num_data}, 2, paddle::DataType::INT64, input_ids.place());
auto batch_id_per_token = paddle::full(
{token_num_data}, -1, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_q =
paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_k =
@@ -170,9 +170,10 @@ std::vector<paddle::Tensor> SpeculatePreProcess(
auto cu_seq_lens_q_output =
paddle::empty({bsz + 1}, paddle::DataType::INT32, input_ids.place());
auto batch_id_per_token_output =
paddle::empty({bsz * max_draft_tokens_per_batch},
paddle::DataType::INT32,
input_ids.place());
paddle::full({bsz * max_draft_tokens_per_batch},
-1,
paddle::DataType::INT32,
input_ids.place());
auto real_output_token_num =
paddle::empty({1}, paddle::DataType::INT32, input_ids.place());
@@ -134,7 +134,6 @@ void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
prefill_one_step_stop = true;
}
}
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
speculate_schedula_cache<BlockSize>
<<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
draft_tokens.data<int64_t>(),
@@ -150,7 +149,7 @@ void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<bool *>(is_block_step.data<bool>()),
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
const_cast<bool *>(not_need_stop.data<bool>()),
real_bsz,
max_bsz,
max_next_step_tokens,
@@ -159,11 +158,6 @@ void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
block_size,
block_num_per_seq,
prefill_one_step_stop);
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), true);
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
PD_BUILD_STATIC_OP(speculate_schedule_cache)
@@ -419,6 +419,7 @@ __global__ void KeMatrixTopPBeamTopKFt(
const int lane = tid % 32;
const int token_id = blockIdx.x;
const int bid = batch_id_per_token_output[token_id];
if (bid < 0) return;
int top_num = TopPBeamTopK;
float top_p_value = static_cast<float>(top_ps[bid]);
@@ -190,13 +190,11 @@ void UnifiedUpdateModelStatus(const paddle::Tensor &seq_lens_encoder,
constexpr int BlockSize = 1024;
// has_running_seqs is CPU tensor, need to copy to GPU first
auto has_running_seqs_gpu =
has_running_seqs.copy_to(seq_lens_this_time.place(), false);
unified_update_model_status_kernel<BlockSize>
<<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<bool *>(has_running_seqs_gpu.data<bool>()),
const_cast<bool *>(has_running_seqs.data<bool>()),
const_cast<int64_t *>(step_input_ids.data<int64_t>()),
const_cast<int64_t *>(step_output_ids.data<int64_t>()),
const_cast<int *>(step_output_len.data<int>()),
@@ -213,11 +211,6 @@ void UnifiedUpdateModelStatus(const paddle::Tensor &seq_lens_encoder,
max_step_tokens,
max_model_len,
num_end_tokens);
// Copy result back to CPU
auto has_running_seqs_cpu =
has_running_seqs_gpu.copy_to(has_running_seqs.place(), false);
bool *out_data = const_cast<bool *>(has_running_seqs.data<bool>());
out_data[0] = has_running_seqs_cpu.data<bool>()[0];
}
PD_BUILD_STATIC_OP(unified_update_model_status)
+1 -1
View File
@@ -573,7 +573,7 @@ class EngineArgs:
self.enable_prefix_caching = False
if (
not current_platform.is_cuda()
or self.speculative_config is not None
or (self.speculative_config is not None and self.enable_logprob)
or self.splitwise_role == "prefill"
or self.dynamic_load_weight
):
@@ -160,6 +160,8 @@ class ForwardMeta:
position_ids: Optional[paddle.Tensor] = None
real_bsz: int = 0
def clear_caches(self):
"""Safely clean up the caches"""
if self.caches:
@@ -155,12 +155,15 @@ class CudaGraphPiecewiseBackend:
def __call__(self, **kwargs) -> List[paddle.Tensor] | paddle.Tensor:
# Get real shape (total num tokens)
ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
real_shape = ids_remove_padding.shape[0]
if self.speculative_decoding and all(self.real_bsz_to_captured_size.values()):
seq_lens_this_time: paddle.Tensor = kwargs["forward_meta"].seq_lens_this_time
num_running_requests = int((seq_lens_this_time.flatten() > 0).sum().item())
real_bsz = kwargs["forward_meta"].real_bsz
num_running_requests = real_bsz if real_bsz > 0 else int((seq_lens_this_time.flatten() > 0).sum().item())
num_running_requests = max(1, num_running_requests)
real_shape = self.real_bsz_to_captured_size[num_running_requests]
else:
ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
real_shape = ids_remove_padding.shape[0]
exist_prefill = kwargs["forward_meta"].exist_prefill
# Static split graph mode: use Static + CUDAGraph for prefill/mixed phase
static_cudagraph_for_prefill = exist_prefill and not self.full_cuda_graph and self.dy2st
@@ -123,7 +123,7 @@ def gather_logprobs(
indices = token_ids
top_logprobs = token_logprobs
return LogprobsTensors(indices, top_logprobs, token_ranks)
return LogprobsTensors(indices.cpu(), top_logprobs.cpu(), token_ranks.cpu())
def build_output_logprobs(
@@ -1041,7 +1041,7 @@ class SpeculativeSampler(nn.Layer):
)
sampler_output.logprobs_tensors = logprobs_tensors
if cu_batch_token_offset is not None:
sampler_output.cu_batch_token_offset = cu_batch_token_offset
sampler_output.cu_batch_token_offset = cu_batch_token_offset.cpu()
return sampler_output
def forward_xpu(
@@ -437,8 +437,6 @@ def post_process_specualate(
model_output: ModelOutputData,
share_inputs: InputBatch,
sampling_metadata: SamplingMetadata,
save_each_rank: bool = False,
skip_save_output: bool = False,
think_end_id: int = -1,
splitwise_role_is_decode: bool = False,
enable_entropy: bool = False,
@@ -508,7 +506,7 @@ def post_process_specualate(
unified_update_model_status(
model_output.seq_lens_encoder, # seq_lens_encoder
model_output.seq_lens_decoder, # seq_lens_decoder
model_output.not_need_stop, # has_running_seqs
model_output.not_need_stop_device, # has_running_seqs
model_output.draft_tokens, # step_input_ids
model_output.accept_tokens, # step_output_ids (read-write)
model_output.accept_num, # step_output_len (read-write)
@@ -522,24 +520,35 @@ def post_process_specualate(
model_output.max_dec_len, # max_dec_len
)
def save_output_specualate(
sampler_output: SamplerOutput,
model_output: ModelOutputData,
share_inputs: InputBatch,
save_each_rank: bool = False,
skip_save_output: bool = False,
):
if not skip_save_output:
if sampler_output.logprobs_tensors is None:
recover_model_output_map = recover_batch_index_for_output(
model_output,
recover_share_inputs = recover_batch_index_for_output(
share_inputs,
model_output.index_to_batch_id,
model_output.enable_pd_reorder,
["accept_tokens", "accept_num", "seq_lens_decoder", "prompt_lens"],
)
recover_share_inputs = recover_batch_index_for_output(
share_inputs, model_output.index_to_batch_id, model_output.enable_pd_reorder, ["preempted_idx"]
[
"accept_tokens_cpu",
"accept_num_cpu",
"seq_lens_decoder_cpu",
"prompt_lens_cpu",
"last_preempted_idx",
],
)
speculate_save_output(
recover_model_output_map["accept_tokens"],
recover_model_output_map["accept_num"],
recover_share_inputs["accept_tokens_cpu"],
recover_share_inputs["accept_num_cpu"],
model_output.not_need_stop,
recover_model_output_map["seq_lens_decoder"],
recover_model_output_map["prompt_lens"],
recover_share_inputs["preempted_idx"],
recover_share_inputs["seq_lens_decoder_cpu"],
recover_share_inputs["prompt_lens_cpu"],
recover_share_inputs["last_preempted_idx"],
model_output.mp_rank,
save_each_rank,
bool(envs.ENABLE_V1_KVCACHE_SCHEDULER),
@@ -548,30 +557,35 @@ def post_process_specualate(
recover_batch_index_for_sampler_output(
sampler_output, model_output.index_to_batch_id, model_output.enable_pd_reorder
)
recover_model_output_map = recover_batch_index_for_output(
model_output,
recover_share_inputs = recover_batch_index_for_output(
share_inputs,
model_output.index_to_batch_id,
model_output.enable_pd_reorder,
["seq_lens_decoder", "prompt_lens"],
)
recover_share_inputs = recover_batch_index_for_output(
share_inputs, model_output.index_to_batch_id, model_output.enable_pd_reorder, ["preempted_idx"]
[
"sampled_token_ids",
"accept_tokens_cpu",
"accept_num_cpu",
"seq_lens_decoder_cpu",
"prompt_lens_cpu",
"last_preempted_idx",
],
)
speculate_save_output_topk(
sampler_output.sampled_token_ids,
recover_share_inputs["sampled_token_ids"],
sampler_output.logprobs_tensors.logprob_token_ids,
sampler_output.logprobs_tensors.logprobs,
sampler_output.logprobs_tensors.selected_token_ranks,
sampler_output.token_num_per_batch,
recover_share_inputs["accept_num_cpu"],
sampler_output.cu_batch_token_offset,
model_output.not_need_stop,
recover_model_output_map["seq_lens_decoder"],
recover_model_output_map["prompt_lens"],
recover_share_inputs["preempted_idx"],
recover_share_inputs["seq_lens_decoder_cpu"],
recover_share_inputs["prompt_lens_cpu"],
recover_share_inputs["last_preempted_idx"],
3, # mtype
model_output.mp_rank,
save_each_rank,
)
share_inputs["last_preempted_idx"][:] = 0
def post_process(
@@ -609,13 +623,12 @@ def post_process(
model_output,
share_inputs,
sampling_metadata,
save_each_rank,
skip_save_output,
think_end_id,
splitwise_role_is_decode,
enable_entropy,
routing_replay_manager,
)
share_inputs["last_preempted_idx"].copy_(share_inputs["preempted_idx"])
else:
post_process_normal(
sampler_or_pooler_output,
+43 -17
View File
@@ -71,7 +71,7 @@ else:
set_data_ipc,
unset_data_ipc,
)
from fastdeploy.model_executor.pre_and_post_process import pre_process
from fastdeploy.model_executor.pre_and_post_process import async_set_value, pre_process
from fastdeploy.worker.input_batch import (
ProposerInputBatch,
@@ -143,6 +143,7 @@ class MTPProposer(Proposer):
# Forward meta store the global meta information of the forward
self.forward_meta = None
self.exist_prefill_flag = False
def _update_mtp_config(self, main_model):
"""
@@ -499,6 +500,7 @@ class MTPProposer(Proposer):
self.model_inputs["batch_drop"][idx : idx + 1] = False
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length
self.exist_prefill_flag = True
self.model_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index
self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length
self.model_inputs["step_idx"][idx : idx + 1] = (
@@ -521,6 +523,7 @@ class MTPProposer(Proposer):
self.fd_config.scheduler_config.splitwise_role == "decode"
): # In PD, we continue to decode after P generates first token
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
self.exist_prefill_flag = False
self.model_inputs["recompute_token_num"][idx : idx + 1] = 0
self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length + 1
# NOTE(liuzichang):
@@ -531,9 +534,14 @@ class MTPProposer(Proposer):
encoder_block_num = len(request.block_tables)
self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
self.model_inputs["block_tables"][idx : idx + 1, :] = -1
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
request.block_tables, dtype="int32"
)
if current_platform.is_cuda():
async_set_value(
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num], request.block_tables
)
else:
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
request.block_tables, dtype="int32"
)
# if self.model_inputs["is_block_step"][idx]: # has tasks to continue to decode
# has_decode_task = True
# continue
@@ -631,7 +639,6 @@ class MTPProposer(Proposer):
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
request.get("block_tables"), dtype="int32"
)
self.model_inputs["not_need_stop"][0] = True
self.model_inputs.seq_lens_this_time = self.model_inputs["seq_lens_this_time_buffer"]
def _initialize_forward_meta(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, substep: int = 0):
@@ -706,10 +713,7 @@ class MTPProposer(Proposer):
"""
check whether prefill stage exist
"""
if np.any(self.share_inputs["seq_lens_encoder"].numpy() > 0):
return 1
else:
return 0
return self.exist_prefill_flag
def _prepare_inputs_cuda(self, full_hidden_states):
"""
@@ -729,7 +733,7 @@ class MTPProposer(Proposer):
self.model_inputs["seq_lens_encoder"],
self.model_inputs["seq_lens_decoder"],
self.model_inputs["step_idx"],
self.model_inputs["not_need_stop"],
self.model_inputs["not_need_stop_device"],
self.model_inputs["pre_ids"],
self.target_model_inputs["accept_tokens"],
self.target_model_inputs["accept_num"],
@@ -822,7 +826,11 @@ class MTPProposer(Proposer):
else self.model_inputs["output_cum_offsets"]
),
self.model_inputs["stop_flags"],
self.model_inputs["not_need_stop"],
(
self.model_inputs["not_need_stop_device"]
if current_platform.is_cuda()
else self.model_inputs["not_need_stop"]
),
self.model_inputs["max_dec_len"],
self.model_inputs["eos_token_id"],
self.model_inputs["base_model_draft_tokens"],
@@ -858,18 +866,30 @@ class MTPProposer(Proposer):
self.model_inputs["step_idx"],
)
def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, real_bsz: int = 0):
"""
Main process for MTP inference.
Args:
step_use_cudagraph: bool
Whether to use cuda graph. Use the target model flag to avoid hanging problems with EP.
"""
is_blocking = (
(not self.fd_config.scheduler_config.enable_overlap_schedule)
or is_dummy_run
or self.exist_prefill()
or real_bsz == 0
)
for substep in range(self.num_model_steps):
if self.model_inputs["not_need_stop"]:
if is_blocking:
token_num_cpu = self.model_inputs["seq_lens_this_time"].numpy().sum().item()
else:
if substep == 0:
token_num_cpu = real_bsz * (self.max_draft_token_num + 1)
else:
token_num_cpu = real_bsz
if token_num_cpu > 0:
self.model_inputs["substep"] = substep
# Remove padding
token_num_cpu = self.model_inputs["seq_lens_this_time"].numpy().sum().item()
(
ids_remove_padding,
batch_id_per_token,
@@ -918,6 +938,7 @@ class MTPProposer(Proposer):
step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run, substep=substep
)
self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False)
self.forward_meta.real_bsz = real_bsz
# Padding inputs for cuda graph
self.padding_cudagraph_inputs()
@@ -1034,8 +1055,9 @@ class MTPProposer(Proposer):
else:
if hasattr(self.model, "empty_input_forward") and not is_dummy_run:
self.model.empty_input_forward(forward_meta=self.forward_meta)
self.exist_prefill_flag = False
def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, real_bsz: int = 0):
"""
Main process for MTP inference.
Args:
@@ -1241,11 +1263,15 @@ class MTPProposer(Proposer):
self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()
def _run_impl(
self, full_hidden_states: paddle.Tensor, step_use_cudagraph: bool = False, is_dummy_run: bool = False
self,
full_hidden_states: paddle.Tensor,
step_use_cudagraph: bool = False,
is_dummy_run: bool = False,
real_bsz: int = 0,
):
"""Execute Draft Model"""
self._prepare_inputs(full_hidden_states)
self._propose(step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run)
self._propose(step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run, real_bsz=real_bsz)
self._update_status()
if self.hybrid_mode:
self._extend_draft_token_with_ngram_match()
+69 -38
View File
@@ -97,6 +97,7 @@ from fastdeploy.model_executor.pre_and_post_process import (
pre_process,
rebuild_padding,
save_output_normal,
save_output_specualate,
)
from fastdeploy.output.pooler import PoolerOutput
from fastdeploy.worker.model_runner_base import (
@@ -270,9 +271,8 @@ class GPUModelRunner(ModelRunnerBase):
# Cached token count for next batch prediction in overlap scheduling.
# Used to avoid synchronization overhead when preparing inputs for the next batch.
self._cached_launch_token_num = -1
self.enable_overlap_schedule = fd_config.scheduler_config.enable_overlap_schedule and (
not self.speculative_decoding
)
self._cached_real_bsz = -1
self.enable_overlap_schedule = fd_config.scheduler_config.enable_overlap_schedule
if self.enable_overlap_schedule:
logger.info("Using overlap schedule")
self.current_launch_token_num = 0
@@ -305,7 +305,7 @@ class GPUModelRunner(ModelRunnerBase):
return ((seq_lens_decoder > 0) & ~stop_flags).any().cpu().numpy().item()
def _resolve_current_launch_token_num(
self, cached_token_num: int, token_num_event, is_dummy_or_profile_run: bool
self, cached_token_num: int, cached_real_bsz: int, token_num_event, is_dummy_or_profile_run: bool
) -> int:
"""
Resolve token count for current batch.
@@ -322,10 +322,12 @@ class GPUModelRunner(ModelRunnerBase):
or (not self.enable_overlap_schedule)
or self.exist_prefill()
or cached_token_num <= 0
or cached_real_bsz <= 0
):
token_num_event.synchronize()
return self.share_inputs["seq_lens_this_time_cpu"].numpy().sum().item()
return cached_token_num
seq_lens_this_time_cpu = self.share_inputs["seq_lens_this_time_cpu"].numpy()
return seq_lens_this_time_cpu.sum().item(), (seq_lens_this_time_cpu > 0).sum().item()
return cached_token_num, cached_real_bsz
def _predict_next_launch_token_num(self) -> int:
"""
@@ -338,11 +340,15 @@ class GPUModelRunner(ModelRunnerBase):
Returns -1 if prediction is not applicable (non-overlap or prefill exists).
"""
if self.exist_prefill():
return -1
return (
self.share_inputs["seq_lens_this_time_cpu"].numpy().sum().item()
+ self.share_inputs["is_block_step_cpu"].numpy().sum().item()
return -1, -1
seq_lens_this_time_cpu = self.share_inputs["seq_lens_this_time_cpu"].numpy()
is_block_step_cpu = self.share_inputs["is_block_step_cpu"].numpy()
next_real_bsz = (seq_lens_this_time_cpu > 0).sum().item() + (is_block_step_cpu > 0).sum().item()
token_num_one_step = (self.speculative_config.num_speculative_tokens + 1) if self.speculative_decoding else 1
next_launch_token_num = (
seq_lens_this_time_cpu.sum().item() + is_block_step_cpu.sum().item() * token_num_one_step
)
return next_launch_token_num, next_real_bsz
def only_prefill(self):
"""
@@ -1112,7 +1118,7 @@ class GPUModelRunner(ModelRunnerBase):
)
self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"]
def _prepare_inputs(self, cached_token_num=-1, is_dummy_or_profile_run=False) -> None:
def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_profile_run=False) -> None:
"""Prepare the model inputs"""
if self.enable_mm and self.share_inputs["image_features_list"] is not None:
tensor_feats = [t for t in self.share_inputs["image_features_list"] if isinstance(t, paddle.Tensor)]
@@ -1160,7 +1166,9 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["is_block_step_cpu"].copy_(self.share_inputs["is_block_step"], False)
token_num_event = paddle.device.cuda.create_event()
token_num_event.record()
token_num = self._resolve_current_launch_token_num(cached_token_num, token_num_event, is_dummy_or_profile_run)
token_num, real_bsz = self._resolve_current_launch_token_num(
cached_token_num, cached_real_bsz, token_num_event, is_dummy_or_profile_run
)
(
ids_remove_padding,
batch_id_per_token,
@@ -1195,6 +1203,7 @@ class GPUModelRunner(ModelRunnerBase):
# Initialize forward meta data
self.initialize_forward_meta(is_dummy_or_profile_run=is_dummy_or_profile_run)
self.forward_meta.real_bsz = real_bsz
# Get sampling metadata
self.sampling_metadata = SamplingMetadata(
@@ -2052,7 +2061,8 @@ class GPUModelRunner(ModelRunnerBase):
) -> None:
model_inputs, p_done_idxs, _ = self._preprocess(model_forward_batch, num_running_requests)
model_output = self._execute(model_inputs)
if model_output is None or self.share_inputs["seq_lens_this_time_cpu"].numpy().sum().item() <= 0:
real_bsz = (self.share_inputs["seq_lens_this_time_cpu"].numpy() > 0).sum().item()
if model_output is None or real_bsz <= 0:
if (
self.fd_config.speculative_config.method == SpecMethod.MTP
and hasattr(self.proposer.model, "empty_input_forward")
@@ -2061,9 +2071,9 @@ class GPUModelRunner(ModelRunnerBase):
self._execute_empty_mtp_input(self.forward_meta)
return
model_output_data, sampler_output, post_process_event = self._postprocess(
model_output, p_done_idxs, model_forward_batch, num_running_requests
model_output, p_done_idxs, model_forward_batch, num_running_requests, real_bsz
)
if model_output_data is not None and not self.speculative_decoding:
if model_output_data is not None:
# synchronizes the async DtoH copies of sampled_token_ids.
post_process_event.synchronize()
self._save_model_output(model_output_data, sampler_output)
@@ -2075,7 +2085,7 @@ class GPUModelRunner(ModelRunnerBase):
) -> None:
# preprocess and execute model (current batch)
model_inputs, p_done_idxs, token_num_event = self._preprocess(
model_forward_batch, num_running_requests, self._cached_launch_token_num
model_forward_batch, num_running_requests, self._cached_launch_token_num, self._cached_real_bsz
)
model_output = self._execute(model_inputs)
# save output (last batch)
@@ -2091,10 +2101,11 @@ class GPUModelRunner(ModelRunnerBase):
# synchronizes the async DtoH copies of seq_lens_this_time_cpu and is_block_step_cpu,
# ensuring that the token count for the current batch is ready to be computed and reused in the subsequent batch.
token_num_event.synchronize()
next_launch_token_num = self._predict_next_launch_token_num()
if self.share_inputs["seq_lens_this_time_cpu"].numpy().sum().item() > 0 and model_output is not None:
next_launch_token_num, next_real_bsz = self._predict_next_launch_token_num()
real_bsz = (self.share_inputs["seq_lens_this_time_cpu"].numpy() > 0).sum().item()
if real_bsz > 0 and model_output is not None:
model_output_data, sampler_output, post_process_event = self._postprocess(
model_output, p_done_idxs, model_forward_batch, num_running_requests
model_output, p_done_idxs, model_forward_batch, num_running_requests, real_bsz
)
self._cached_model_output_data = model_output_data
self._cached_sampler_output = sampler_output
@@ -2104,12 +2115,14 @@ class GPUModelRunner(ModelRunnerBase):
self._cached_sampler_output = None
self._cached_post_process_event = None
self._cached_launch_token_num = next_launch_token_num
self._cached_real_bsz = next_real_bsz
def _preprocess(
self,
model_forward_batch: Optional[List[Request]] = None,
num_running_requests: int = None,
cached_token_num: int = -1,
cached_real_bsz: int = -1,
) -> None:
if self.deterministic_logger is not None:
self.deterministic_logger.log_batch_start(model_forward_batch)
@@ -2118,7 +2131,7 @@ class GPUModelRunner(ModelRunnerBase):
self._process_reorder()
# Prepare inputs of model and sampler.
current_launch_token_num, token_num_event = self._prepare_inputs(cached_token_num)
current_launch_token_num, token_num_event = self._prepare_inputs(cached_token_num, cached_real_bsz)
self.current_launch_token_num = current_launch_token_num
# NOTE(sunxin):
@@ -2170,12 +2183,13 @@ class GPUModelRunner(ModelRunnerBase):
p_done_idxs: List[int],
model_forward_batch: Optional[List[Request]] = None,
num_running_requests: int = None,
real_bsz: int = 0,
) -> None:
if self.speculative_decoding:
self.output_token_num_event.synchronize()
real_num = int(self._real_output_token_num_host)
real_batch_id_per_token_output = self.share_inputs["batch_id_per_token_output"][:real_num]
real_output_token_num = int(self._real_output_token_num_host)
real_batch_id_per_token_output = self.share_inputs["batch_id_per_token_output"][:real_output_token_num]
prompt_logprobs_list = self._get_prompt_logprobs_list(model_output)
if self.is_pooling_model:
@@ -2305,7 +2319,7 @@ class GPUModelRunner(ModelRunnerBase):
self.sampling_metadata,
self.model_config.max_model_len,
self.share_inputs,
int(self._real_output_token_num_host),
real_output_token_num,
self.increment_value,
)
if self.parallel_config.tensor_parallel_size > 1:
@@ -2388,6 +2402,17 @@ class GPUModelRunner(ModelRunnerBase):
if self.guided_backend is not None and sampler_output is not None:
self.sampler.post_process(sampler_output.sampled_token_ids)
# 5.1. Async cpy
post_process_event = paddle.device.cuda.create_event()
# if not self.speculative_decoding:
self.share_inputs["sampled_token_ids"].copy_(sampler_output.sampled_token_ids, False)
if self.speculative_decoding:
self.share_inputs["accept_tokens_cpu"].copy_(self.share_inputs["accept_tokens"], False)
self.share_inputs["accept_num_cpu"].copy_(self.share_inputs["accept_num"], False)
self.share_inputs["seq_lens_decoder_cpu"].copy_(self.share_inputs["seq_lens_decoder"], False)
self.share_inputs["prompt_lens_cpu"].copy_(self.share_inputs["prompt_lens"], False)
post_process_event.record()
# 6. Speculative decode -- proposer run (method="naive" has proposer=None, skip)
# For naive mode: seq_lens_this_time is already reset to 1 inside
# unified_update_model_status kernel. For MTP/Ngram, the proposer
@@ -2396,7 +2421,9 @@ class GPUModelRunner(ModelRunnerBase):
if self.speculative_decoding and self.proposer is not None:
if self.spec_method == SpecMethod.MTP:
self.proposer.run(
full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph
full_hidden_states=model_output,
step_use_cudagraph=self.forward_meta.step_use_cudagraph,
real_bsz=real_bsz,
)
elif self.spec_method == SpecMethod.NAIVE:
pass
@@ -2422,17 +2449,11 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["accept_num"],
self.share_inputs["accept_tokens"],
self.share_inputs["is_block_step"],
self.share_inputs["not_need_stop"],
self.share_inputs["not_need_stop_device"],
self.cache_config.block_size,
self.speculative_config.num_speculative_tokens,
)
# 8. Async cpy
post_process_event = paddle.device.cuda.create_event()
if not self.speculative_decoding:
self.share_inputs["sampled_token_ids"].copy_(sampler_output.sampled_token_ids, False)
post_process_event.record()
self.exist_prefill_flag = False
return model_output_data, sampler_output, post_process_event
@@ -2441,13 +2462,23 @@ class GPUModelRunner(ModelRunnerBase):
model_output_data,
sampler_output,
):
save_output_normal(
model_output=model_output_data,
sampler_output=sampler_output,
share_inputs=self.share_inputs,
async_output_queue=self.async_output_queue,
save_each_rank=self.parallel_config.use_ep,
)
if self.speculative_decoding:
skip_save_output = self.spec_method == SpecMethod.MTP and self.scheduler_config.splitwise_role == "prefill"
save_output_specualate(
sampler_output=sampler_output,
model_output=model_output_data,
share_inputs=self.share_inputs,
save_each_rank=self.parallel_config.use_ep,
skip_save_output=skip_save_output,
)
else:
save_output_normal(
model_output=model_output_data,
sampler_output=sampler_output,
share_inputs=self.share_inputs,
async_output_queue=self.async_output_queue,
save_each_rank=self.parallel_config.use_ep,
)
def _pool(self, hidden_states: paddle.Tensor, num_running_requests: int) -> Optional[ModelRunnerOutput]:
num_scheduled_tokens = int(self.share_inputs["seq_lens_this_time"][:num_running_requests].sum())
+24
View File
@@ -316,6 +316,15 @@ class InputBatch:
dtype="float32",
)
self.cu_batch_token_offset = paddle.full(shape=[max_num_seqs + 1], fill_value=0, dtype="int32")
# For mtp overlap
self.seq_lens_decoder_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32").pin_memory()
self.prompt_lens_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int64").pin_memory()
self.accept_tokens_cpu = paddle.full(
shape=[max_num_seqs, max_draft_token_num + 1],
fill_value=0,
dtype="int64",
).pin_memory()
self.accept_num_cpu = paddle.full(shape=[max_num_seqs], fill_value=0, dtype="int32").pin_memory()
if self.enable_mm:
head_dim = self.model_config.head_dim
if (
@@ -435,6 +444,10 @@ class InputBatch:
swap_data(self.step_seq_lens_this_time, i1, i2)
swap_data(self.draft_logits, i1, i2)
swap_data(self.cu_batch_token_offset, i1, i2)
swap_data(self.seq_lens_decoder_cpu, i1, i2)
swap_data(self.prompt_lens_cpu, i1, i2)
swap_data(self.accept_tokens_cpu, i1, i2)
swap_data(self.accept_num_cpu, i1, i2)
if self.enable_mm:
if self.image_features_list is not None:
@@ -623,6 +636,15 @@ class InputBatch:
fill_paddle_tensor(self, "step_seq_lens_this_time", 0)
fill_paddle_tensor(self, "draft_logits", -1)
fill_paddle_tensor(self, "cu_batch_token_offset", 0)
# for mtp overlap
self.prompt_lens_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int64").pin_memory()
self.seq_lens_decoder_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32").pin_memory()
self.accept_num_cpu = paddle.full(shape=[max_num_seqs], fill_value=0, dtype="int32").pin_memory()
self.accept_tokens_cpu = paddle.full(
shape=[max_num_seqs, max_draft_token_num + 1],
fill_value=0,
dtype="int64",
).pin_memory()
# Reset multimodal related tensors
if self.enable_mm:
@@ -697,6 +719,7 @@ class ProposerInputBatch(InputBatch):
self.step_idx = paddle.clone(self.target_model_input_batch["step_idx"])
self.stop_flags = paddle.clone(self.target_model_input_batch["stop_flags"])
self.not_need_stop = paddle.to_tensor([False], dtype="bool", place="cpu")
self.not_need_stop_device = paddle.to_tensor([False], dtype="bool")
if current_platform.is_cuda():
self.cu_seqlens_q_output = paddle.clone(self.target_model_input_batch["cu_seqlens_q_output"])
self.batch_id_per_token_output = paddle.clone(self.target_model_input_batch["batch_id_per_token_output"])
@@ -1085,6 +1108,7 @@ def _recover_tensor(recover_tensor, index_to_batch_id_list):
"""
sort_len = len(index_to_batch_id_list)
if isinstance(recover_tensor.place, paddle.CUDAPinnedPlace):
recover_tensor = recover_tensor.cpu()
recover_res_tensor = paddle.empty_like(recover_tensor, device="cpu")
else:
recover_res_tensor = paddle.empty_like(recover_tensor)
@@ -137,7 +137,7 @@ class TestDraftModelPreprocess(unittest.TestCase):
seq_lens_encoder = paddle.randint(0, input_ids_len, [bsz], dtype="int32")
seq_lens_decoder = paddle.randint(0, input_ids_len, [bsz], dtype="int32")
step_idx = paddle.randint(0, 100, [bsz], dtype="int64")
not_need_stop = paddle.zeros([1], dtype="bool").cpu() # must be CPU: kernel writes back via raw pointer
not_need_stop = paddle.zeros([1], dtype="bool")
pre_ids = input_ids.clone()
accept_tokens = paddle.randint(0, 100, [bsz, 100], dtype="int64")
@@ -230,7 +230,7 @@ class TestDraftModelPreprocess(unittest.TestCase):
seq_lens_encoder = paddle.zeros([bsz], dtype="int32") # all decode for simplicity
seq_lens_decoder = paddle.randint(max_draft_token + 1, 100, [bsz], dtype="int32")
step_idx = paddle.randint(0, 100, [bsz], dtype="int64")
not_need_stop = paddle.zeros([1], dtype="bool").cpu() # must be CPU: kernel writes back via raw pointer
not_need_stop = paddle.zeros([1], dtype="bool")
pre_ids = input_ids.clone()
accept_tokens = paddle.randint(0, 100, [bsz, 100], dtype="int64")
+1 -3
View File
@@ -64,13 +64,11 @@ CPU_PLACE = paddle.CPUPlace()
def to_paddle_inputs(inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Convert numpy dict → paddle tensors. not_need_stop stays on CPU."""
"""Convert numpy dict → paddle tensors."""
paddle_inputs = {}
for k, v in inputs.items():
if isinstance(v, (int, bool, float)):
paddle_inputs[k] = v
elif k == "not_need_stop":
paddle_inputs[k] = paddle.to_tensor(v, place=CPU_PLACE)
elif v is not None:
paddle_inputs[k] = paddle.to_tensor(v, place=CUDA_PLACE)
return paddle_inputs
@@ -138,7 +138,7 @@ class TestSpeculateScheduleCache(unittest.TestCase):
self.is_block_step = paddle.zeros((self.real_bsz,), dtype=paddle.bool)
# not_need_stop lives on CPU in the caller; the kernel copies to device internally
self.not_need_stop = paddle.zeros((1,), dtype=paddle.bool).cpu()
self.not_need_stop = paddle.zeros((1,), dtype=paddle.bool)
# Choose threshold so with: bid0 triggers, bid1 already stopped, padding (5-3)=2 -> stop_sum = 1+1+2 = 4
@@ -53,14 +53,11 @@ CPU_PLACE = paddle.CPUPlace()
def to_paddle_inputs(inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Convert numpy dict → paddle tensors. has_running_seqs goes to CPU."""
"""Convert numpy dict → paddle tensors."""
paddle_inputs = {}
for k, v in inputs.items():
if isinstance(v, (int, bool, float, str)):
paddle_inputs[k] = v
elif k == "has_running_seqs":
# Kernel host function: has_running_seqs.copy_to(GPU) → kernel → copy_to(CPU)
paddle_inputs[k] = paddle.to_tensor(v, place=CPU_PLACE)
elif v is not None:
paddle_inputs[k] = paddle.to_tensor(v, place=CUDA_PLACE)
else:
+3 -2
View File
@@ -71,6 +71,7 @@ class TestMTPProposer(unittest.TestCase):
"seq_lens_encoder": paddle.zeros([2, 1], dtype="int32"),
"seq_lens_decoder": paddle.zeros([2, 1], dtype="int32"),
"prompt_lens": paddle.zeros([2, 1], dtype="int64"),
"prompt_lens_cpu": paddle.zeros([2, 1], dtype="int64").pin_memory(),
"step_idx": paddle.zeros([2, 1], dtype="int64"),
"stop_flags": paddle.zeros([2, 1], dtype="bool"),
"token_ids_all": paddle.zeros([2, 2048], dtype="int64"),
@@ -379,11 +380,11 @@ class TestMTPProposer(unittest.TestCase):
self.assertEqual(proposer.forward_meta.pos_emb_type, "NORMAL")
# Test exist_prefill
proposer.share_inputs = {"seq_lens_encoder": paddle.ones([2, 1], dtype="int32")}
proposer.exist_prefill_flag = True
result = proposer.exist_prefill()
self.assertEqual(result, 1)
proposer.share_inputs = {"seq_lens_encoder": paddle.zeros([2, 1], dtype="int32")}
proposer.exist_prefill_flag = False
result = proposer.exist_prefill()
self.assertEqual(result, 0)