[Feature] Support mtp overlap schedule (#7001)

Author: sunxin
Date: 2026-04-01 14:24:26 +08:00 (committed by GitHub)
Parent: c6f0c5c3a6
Commit: c29e86fc9d
23 changed files with 215 additions and 138 deletions
@@ -160,6 +160,8 @@ class ForwardMeta:
position_ids: Optional[paddle.Tensor] = None
real_bsz: int = 0
def clear_caches(self):
"""Safely clean up the caches"""
if self.caches:
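Note: the hunk above adds a real_bsz field to ForwardMeta so later stages can read the number of running requests directly instead of recounting it from seq_lens_this_time. A minimal sketch of how such a field could be filled in, assuming a plain dataclass and a hypothetical prepare helper (not the repo's actual code):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ForwardMetaSketch:
        # field names follow the diff; other fields omitted
        position_ids: Optional[object] = None
        real_bsz: int = 0

    def prepare_meta(seq_lens_this_time):
        meta = ForwardMetaSketch()
        # real batch size = requests that still have tokens to process this step
        meta.real_bsz = sum(1 for n in seq_lens_this_time if n > 0)
        return meta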
@@ -155,12 +155,15 @@ class CudaGraphPiecewiseBackend:
def __call__(self, **kwargs) -> List[paddle.Tensor] | paddle.Tensor:
# Get real shape (total num tokens)
ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
real_shape = ids_remove_padding.shape[0]
if self.speculative_decoding and all(self.real_bsz_to_captured_size.values()):
seq_lens_this_time: paddle.Tensor = kwargs["forward_meta"].seq_lens_this_time
num_running_requests = int((seq_lens_this_time.flatten() > 0).sum().item())
real_bsz = kwargs["forward_meta"].real_bsz
num_running_requests = real_bsz if real_bsz > 0 else int((seq_lens_this_time.flatten() > 0).sum().item())
num_running_requests = max(1, num_running_requests)
real_shape = self.real_bsz_to_captured_size[num_running_requests]
else:
ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
real_shape = ids_remove_padding.shape[0]
exist_prefill = kwargs["forward_meta"].exist_prefill
# Static split graph mode: use Static + CUDAGraph for prefill/mixed phase
static_cudagraph_for_prefill = exist_prefill and not self.full_cuda_graph and self.dy2st
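Note: under speculative decoding the backend now takes forward_meta.real_bsz when it is set (falling back to counting non-zero seq_lens_this_time, clamped to at least 1) and looks up the replay shape through real_bsz_to_captured_size. A sketch of that bucketing, assuming the mapping sends every batch size to the smallest captured size that can hold it (hypothetical builder, not the repo's code):

    def build_real_bsz_to_captured_size(captured_sizes, max_bsz):
        # e.g. graphs captured for batch sizes 1, 2, 4, 8; a real batch of 3
        # replays the graph captured for 4 and pads the unused slot
        captured_sizes = sorted(captured_sizes)
        mapping = {}
        for bsz in range(1, max_bsz + 1):
            mapping[bsz] = next((c for c in captured_sizes if c >= bsz), captured_sizes[-1])
        return mapping

    mapping = build_real_bsz_to_captured_size([1, 2, 4, 8], max_bsz=8)
    assert mapping[3] == 4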
@@ -123,7 +123,7 @@ def gather_logprobs(
indices = token_ids
top_logprobs = token_logprobs
return LogprobsTensors(indices, top_logprobs, token_ranks)
return LogprobsTensors(indices.cpu(), top_logprobs.cpu(), token_ranks.cpu())
def build_output_logprobs(
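Note: gather_logprobs now returns host copies of the logprob tensors, so the overlapped schedule can hand them straight to the output-saving step without holding on to GPU memory. A minimal illustration of the device-to-host hop, assuming ordinary Paddle tensors (Tensor.cpu() makes a host copy):

    import paddle

    def to_host(indices, top_logprobs, token_ranks):
        # copied once at the end of sampling; the save path afterwards
        # only touches host memory
        return indices.cpu(), top_logprobs.cpu(), token_ranks.cpu()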
@@ -1041,7 +1041,7 @@ class SpeculativeSampler(nn.Layer):
)
sampler_output.logprobs_tensors = logprobs_tensors
if cu_batch_token_offset is not None:
sampler_output.cu_batch_token_offset = cu_batch_token_offset
sampler_output.cu_batch_token_offset = cu_batch_token_offset.cpu()
return sampler_output
def forward_xpu(
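Note: cu_batch_token_offset is likewise moved to CPU. Assuming it is the usual exclusive prefix sum of per-request accepted-token counts, the host-side save can use it to slice the flattened token buffer; a toy example with made-up numbers:

    # accepted tokens per request -> offsets into the flattened token buffer
    accept_num = [2, 1, 3]
    cu_batch_token_offset = [0]
    for n in accept_num:
        cu_batch_token_offset.append(cu_batch_token_offset[-1] + n)
    # tokens of request i live in flat_tokens[offset[i]:offset[i + 1]]
    print(cu_batch_token_offset)  # [0, 2, 3, 6]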
@@ -437,8 +437,6 @@ def post_process_specualate(
model_output: ModelOutputData,
share_inputs: InputBatch,
sampling_metadata: SamplingMetadata,
save_each_rank: bool = False,
skip_save_output: bool = False,
think_end_id: int = -1,
splitwise_role_is_decode: bool = False,
enable_entropy: bool = False,
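Note: save_each_rank and skip_save_output disappear from post_process_specualate's signature; per the hunks below, output saving is split into a separate save_output_specualate step that keeps those flags. A rough sketch of the resulting split, with hypothetical _sketch names and elided bodies:

    def post_process_specualate_sketch(model_output, share_inputs, think_end_id=-1):
        ...  # update accept_tokens / seq_lens / stop flags in place; no output I/O

    def save_output_specualate_sketch(sampler_output, model_output, share_inputs,
                                      save_each_rank=False, skip_save_output=False):
        if skip_save_output:
            return
        ...  # reorder CPU buffers and push accepted tokens to the output queue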
@@ -508,7 +506,7 @@ def post_process_specualate(
unified_update_model_status(
model_output.seq_lens_encoder, # seq_lens_encoder
model_output.seq_lens_decoder, # seq_lens_decoder
model_output.not_need_stop, # has_running_seqs
model_output.not_need_stop_device, # has_running_seqs
model_output.draft_tokens, # step_input_ids
model_output.accept_tokens, # step_output_ids (read-write)
model_output.accept_num, # step_output_len (read-write)
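Note: unified_update_model_status now receives not_need_stop_device, which by its name is a device-resident copy of the stop flag, so the status-update kernel can read and write it without a host round trip every step. A small sketch of keeping a host flag plus a device mirror, assuming Paddle tensors (hypothetical names):

    import paddle

    # host copy the scheduler polls, device copy the update kernel works on
    not_need_stop = paddle.to_tensor([True], place=paddle.CPUPlace())
    not_need_stop_device = paddle.to_tensor([True])  # lives on the default (GPU) place
    # ... the step-update kernel flips not_need_stop_device in place ...
    not_need_stop = not_need_stop_device.cpu()       # refresh the host view once per step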
@@ -522,24 +520,35 @@ def post_process_specualate(
model_output.max_dec_len, # max_dec_len
)
def save_output_specualate(
sampler_output: SamplerOutput,
model_output: ModelOutputData,
share_inputs: InputBatch,
save_each_rank: bool = False,
skip_save_output: bool = False,
):
if not skip_save_output:
if sampler_output.logprobs_tensors is None:
recover_model_output_map = recover_batch_index_for_output(
model_output,
recover_share_inputs = recover_batch_index_for_output(
share_inputs,
model_output.index_to_batch_id,
model_output.enable_pd_reorder,
["accept_tokens", "accept_num", "seq_lens_decoder", "prompt_lens"],
)
recover_share_inputs = recover_batch_index_for_output(
share_inputs, model_output.index_to_batch_id, model_output.enable_pd_reorder, ["preempted_idx"]
[
"accept_tokens_cpu",
"accept_num_cpu",
"seq_lens_decoder_cpu",
"prompt_lens_cpu",
"last_preempted_idx",
],
)
speculate_save_output(
recover_model_output_map["accept_tokens"],
recover_model_output_map["accept_num"],
recover_share_inputs["accept_tokens_cpu"],
recover_share_inputs["accept_num_cpu"],
model_output.not_need_stop,
recover_model_output_map["seq_lens_decoder"],
recover_model_output_map["prompt_lens"],
recover_share_inputs["preempted_idx"],
recover_share_inputs["seq_lens_decoder_cpu"],
recover_share_inputs["prompt_lens_cpu"],
recover_share_inputs["last_preempted_idx"],
model_output.mp_rank,
save_each_rank,
bool(envs.ENABLE_V1_KVCACHE_SCHEDULER),
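Note: the no-logprobs branch now pulls everything speculate_save_output needs from CPU-side share_inputs buffers (accept_tokens_cpu, accept_num_cpu, seq_lens_decoder_cpu, prompt_lens_cpu, last_preempted_idx) and restores the original request order first. A sketch of what such a recover step might do, assuming index_to_batch_id maps a compacted slot back to its original batch id (hypothetical helper, plain Python lists for clarity):

    def recover_batch_index(share_inputs, index_to_batch_id, keys):
        # scatter each compacted buffer back to its original batch positions
        recovered = {}
        for key in keys:
            src = share_inputs[key]
            dst = [None] * len(src)
            for slot, batch_id in enumerate(index_to_batch_id):
                dst[batch_id] = src[slot]
            recovered[key] = dst
        return recovered

    share_inputs = {"accept_num_cpu": [3, 1, 2]}
    out = recover_batch_index(share_inputs, [2, 0, 1], ["accept_num_cpu"])
    print(out["accept_num_cpu"])  # [1, 2, 3]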
@@ -548,30 +557,35 @@ def post_process_specualate(
recover_batch_index_for_sampler_output(
sampler_output, model_output.index_to_batch_id, model_output.enable_pd_reorder
)
recover_model_output_map = recover_batch_index_for_output(
model_output,
recover_share_inputs = recover_batch_index_for_output(
share_inputs,
model_output.index_to_batch_id,
model_output.enable_pd_reorder,
["seq_lens_decoder", "prompt_lens"],
)
recover_share_inputs = recover_batch_index_for_output(
share_inputs, model_output.index_to_batch_id, model_output.enable_pd_reorder, ["preempted_idx"]
[
"sampled_token_ids",
"accept_tokens_cpu",
"accept_num_cpu",
"seq_lens_decoder_cpu",
"prompt_lens_cpu",
"last_preempted_idx",
],
)
speculate_save_output_topk(
sampler_output.sampled_token_ids,
recover_share_inputs["sampled_token_ids"],
sampler_output.logprobs_tensors.logprob_token_ids,
sampler_output.logprobs_tensors.logprobs,
sampler_output.logprobs_tensors.selected_token_ranks,
sampler_output.token_num_per_batch,
recover_share_inputs["accept_num_cpu"],
sampler_output.cu_batch_token_offset,
model_output.not_need_stop,
recover_model_output_map["seq_lens_decoder"],
recover_model_output_map["prompt_lens"],
recover_share_inputs["preempted_idx"],
recover_share_inputs["seq_lens_decoder_cpu"],
recover_share_inputs["prompt_lens_cpu"],
recover_share_inputs["last_preempted_idx"],
3, # mtype
model_output.mp_rank,
save_each_rank,
)
share_inputs["last_preempted_idx"][:] = 0
def post_process(
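Note: after the top-k save, last_preempted_idx is reset to zero here, and the post_process hunk below re-snapshots it from preempted_idx right after post_process_specualate; presumably this keeps the value the (possibly overlapped) save reads stable while the scheduler may already be preempting requests for the next step. A toy timeline of that snapshot-then-reset pattern (plain Python, illustrative only):

    preempted_idx = [0, 0, 1, 0]        # live buffer the scheduler may rewrite at any time
    last_preempted_idx = [0, 0, 0, 0]   # frozen copy the save path reads

    # step N: snapshot right after post-processing
    last_preempted_idx[:] = preempted_idx
    # ... the overlapped save for step N reads last_preempted_idx ...
    # once the save has consumed it, clear the snapshot
    last_preempted_idx[:] = [0] * len(last_preempted_idx)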
@@ -609,13 +623,12 @@ def post_process(
model_output,
share_inputs,
sampling_metadata,
save_each_rank,
skip_save_output,
think_end_id,
splitwise_role_is_decode,
enable_entropy,
routing_replay_manager,
)
share_inputs["last_preempted_idx"].copy_(share_inputs["preempted_idx"])
else:
post_process_normal(
sampler_or_pooler_output,
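Note: taken together, the hunks separate the speculative post-processing (device-side state updates) from the output save (host-side buffers in share_inputs), which is what makes an overlapped schedule possible: the save for step N can run while step N + 1 is already executing. One way such an overlap can be organized, purely as an illustration (whether FastDeploy uses a worker thread, a separate stream, or its async output queue is not shown in these hunks):

    import threading

    def run_overlapped(num_steps, forward_and_postprocess, save_output):
        pending = None
        for step in range(num_steps):
            outputs = forward_and_postprocess(step)   # GPU-heavy part of step N
            if pending is not None:
                pending.join()                        # previous save must be done
            pending = threading.Thread(target=save_output, args=(outputs,))
            pending.start()                           # save of step N overlaps step N + 1
        if pending is not None:
            pending.join()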