[Others] Fix PD reorder for MTP (#6792)

* fix pd reorder in mtp

* add ut

* update

* fix mtp
bukejiyu authored 2026-03-23 21:10:22 +08:00 (committed by GitHub)
parent 1b276e62d4
commit c62f6b4ea5
5 changed files with 61 additions and 55 deletions
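In short, this change stops the MTP proposer from sharing (and then mutating) the target model's request-to-slot mapping during prefill/decode (PD) reordering: the proposer is now seeded with a copy of the mapping in `insert_tasks_v1` and actively re-aligns its batch to the target's layout in `reorder_inputs`. A minimal sketch of the aliasing pitfall the diff appears to address (the dict names are illustrative stand-ins, not FastDeploy's actual objects):

```python
# Hypothetical illustration: "target_mapping" stands in for the target model's
# index_to_batch_id dict; the real objects live on InputBatch instances.
target_mapping = {"req-0": 0, "req-1": 1}

# Before this commit: init_share_inputs grabbed the dict via getattr, so the
# proposer held a *reference* and any mutation was visible on both sides.
proposer_shared = target_mapping

# After this commit: insert_tasks_v1 snapshots it with dict(...), so the
# proposer's copy evolves independently until it is explicitly re-seeded.
proposer_copy = dict(target_mapping)

target_mapping["req-2"] = 2
assert "req-2" in proposer_shared      # aliased: the change leaks through
assert "req-2" not in proposer_copy    # copied: the proposer is unaffected
```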
+11 -7
@@ -456,13 +456,17 @@ class MTPProposer(Proposer):
         }
     )
-    def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
+    def insert_tasks_v1(
+        self, req_dicts: List[Request], num_running_requests: int, target_model_index_to_batch_id: dict = {}
+    ):
         if "caches" not in self.model_inputs:
             self.initialize_kv_cache()
         req_len = len(req_dicts)
         self.model_inputs["num_running_requests"] = num_running_requests
         self.model_inputs["running_requests_ids"] = range(num_running_requests)
+        if target_model_index_to_batch_id:
+            self.model_inputs.index_to_batch_id = dict(target_model_index_to_batch_id)
         for i in range(req_len):
             request = req_dicts[i]
             logger.debug(f"{i}th request-{request.request_id}: {request}")
@@ -962,9 +966,8 @@ class MTPProposer(Proposer):
         recover_model_output_map = recover_batch_index_for_output(
             self.model_inputs,
             self.model_inputs.index_to_batch_id,
-            self.model_inputs.enable_pd_reorder[
-                "batch_token_num", "cu_batch_token_offset", "seq_lens_decoder", "prompt_lens"
-            ],
+            self.model_inputs.enable_pd_reorder,
+            ["batch_token_num", "cu_batch_token_offset", "seq_lens_decoder", "prompt_lens"],
         )
         speculate_save_output_topk(
             sampler_output.sampled_token_ids,
@@ -1081,7 +1084,8 @@ class MTPProposer(Proposer):
         recover_model_output_map = recover_batch_index_for_output(
             self.model_inputs,
             self.model_inputs.index_to_batch_id,
-            self.model_inputs.enable_pd_reorder["batch_token_num", "cu_batch_token_offset"],
+            self.model_inputs.enable_pd_reorder,
+            ["batch_token_num", "cu_batch_token_offset"],
         )
         speculate_save_output_topk(
             sampler_output.sampled_token_ids,
@@ -1244,11 +1248,11 @@ class MTPProposer(Proposer):
             raise NotImplementedError
         return cache_type

-    def reorder_inputs(self):
+    def reorder_inputs(self, target_model_input_batch):
         """
         Reorder inputs to split prefill and decode.
         """
-        reorder_split_prefill_and_decode_form_index_to_batch_id(self.model_inputs)
+        reorder_split_prefill_and_decode_form_index_to_batch_id(self.model_inputs, target_model_input_batch)

     def _share_external_data(self, cache, cache_name, cache_shape):
         if current_platform.is_xpu():
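The two `recover_batch_index_for_output` fixes above correct a malformed call: the old code subscripted `enable_pd_reorder` (a flag that `init_share_inputs` defaults to `False`, i.e. presumably a bool) with a tuple of field names, while the new code passes the flag and the field list as separate arguments. A hedged sketch of why the old expression could not work:

```python
# Hypothetical repro of the malformed call: a bool flag cannot be subscripted,
# and even on a dict, obj["a", "b"] looks up the single tuple key ("a", "b")
# rather than selecting two fields.
enable_pd_reorder = False  # init_share_inputs defaults this flag to False

try:
    enable_pd_reorder["batch_token_num", "cu_batch_token_offset"]
except TypeError as exc:
    print(exc)  # 'bool' object is not subscriptable

# Fixed shape of the call, per the diff: flag and field names are separate args.
# recover_batch_index_for_output(model_inputs, index_to_batch_id,
#                                enable_pd_reorder,
#                                ["batch_token_num", "cu_batch_token_offset"])
```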
+2 -2
@@ -982,7 +982,7 @@ class GPUModelRunner(ModelRunnerBase):
         self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"][:num_running_requests]
         if self.spec_method == SpecMethod.MTP:
-            self.proposer.insert_tasks_v1(req_dicts, num_running_requests)
+            self.proposer.insert_tasks_v1(req_dicts, num_running_requests, self.share_inputs.index_to_batch_id)

     def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
         raise NotImplementedError("GPUs only support KVCACHE SCHEDULER V1 in versions 2.6 and above.")
@@ -1226,7 +1226,7 @@ class GPUModelRunner(ModelRunnerBase):
         reorder_split_prefill_and_decode(input_batch=self.share_inputs)
         if self.speculative_decoding:
             if self.spec_method == SpecMethod.MTP:
-                self.proposer.reorder_inputs()
+                self.proposer.reorder_inputs(self.share_inputs.index_to_batch_id)

     def load_model(self) -> None:
         """load or download model"""
+23 -43
@@ -429,7 +429,7 @@ class InputBatch:
         swap_data(self.step_seq_lens_this_time, i1, i2)
         swap_data(self.draft_logits, i1, i2)
         swap_data(self.cu_batch_token_offset, i1, i2)
+        swap_data(self.stop_flags, i1, i2)
         if self.enable_mm:
             if self.image_features_list is not None:
                 self.image_features_list[i1], self.image_features_list[i2] = (
@@ -675,7 +675,6 @@ class ProposerInputBatch(InputBatch):
     def init_share_inputs(self):
         # share with targe model
         self.enable_pd_reorder = getattr(self.target_model_input_batch, "enable_pd_reorder", False)
-        self.index_to_batch_id = getattr(self.target_model_input_batch, "index_to_batch_id", {})
         self.block_tables = paddle.clone(self.target_model_input_batch["block_tables"])
         self.input_ids = paddle.clone(self.target_model_input_batch["input_ids"])
@@ -852,6 +851,7 @@ class ProposerInputBatch(InputBatch):
             tensor[idx1] = tensor[idx2].clone()
             tensor[idx2] = temp
+        self.index_to_batch_id[i1], self.index_to_batch_id[i2] = self.index_to_batch_id[i2], self.index_to_batch_id[i1]
         swap_data(self.block_tables, i1, i2)
         swap_data(self.input_ids, i1, i2)
         swap_data(self.input_ids_cpu, i1, i2)
@@ -859,48 +859,14 @@ class ProposerInputBatch(InputBatch):
         swap_data(self.seq_lens_encoder, i1, i2)
         swap_data(self.seq_lens_decoder, i1, i2)
         swap_data(self.step_idx, i1, i2)
-        swap_data(self.stop_flags, i1, i2)
-        swap_data(self.not_need_stop, i1, i2)
-        swap_data(self.pre_ids, i1, i2)
-        if current_platform.is_cuda():
-            swap_data(self.cu_seqlens_q_output, i1, i2)
-            swap_data(self.batch_id_per_token_output, i1, i2)
-            swap_data(self.token_ids_all, i1, i2)
-        else:
-            swap_data(self.output_cum_offsets, i1, i2)
-            swap_data(self.output_padding_offset, i1, i2)
-        swap_data(self.ids_remove_padding, i1, i2)
-        swap_data(self.batch_id_per_token, i1, i2)
-        swap_data(self.cu_seqlens_q, i1, i2)
-        swap_data(self.cu_seqlens_k, i1, i2)
-        swap_data(self.target_hidden_states, i1, i2)
-        swap_data(self.draft_tokens, i1, i2)
-        swap_data(self.encoder_block_lens, i1, i2)
-        swap_data(self.is_block_step, i1, i2)
-        swap_data(self.batch_drop, i1, i2)
-        swap_data(self.used_list_len, i1, i2)
-        if self.num_model_steps > 1:
-            swap_data(self.last_seq_lens_this_time, i1, i2)
-        swap_data(self.input_ids_len, i1, i2)
-        swap_data(self.first_token_hidden_states, i1, i2)
-        swap_data(self.batch_token_num, i1, i2)
-        swap_data(self.next_token_num, i1, i2)
-        swap_data(self.cu_batch_token_offset, i1, i2)
-        swap_data(self.cu_next_token_offset, i1, i2)
-        swap_data(self.mask_rollback, i1, i2)
-        swap_data(self.recompute_token_num, i1, i2)
-        if self.enable_mm:
-            swap_data(self.attn_mask_offsets, i1, i2)
-            swap_data(self.attn_mask_offsets_full, i1, i2)
-            swap_data(self.attn_mask_offsets_decoder, i1, i2)
-            swap_data(self.decode_states, i1, i2)

     def reset_model_inputs(self) -> None:
         """
@@ -1042,14 +1008,28 @@ class ProposerInputBatch(InputBatch):
             logger.error(f"Resetting model inputs failed, skipping reset, error message is {e}")

-def reorder_split_prefill_and_decode_form_index_to_batch_id(input_batch: InputBatch):
-    swapped = set()
-    for i, target in input_batch.index_to_batch_id.items():
-        if i in swapped or target in swapped or i == target:
+def reorder_split_prefill_and_decode_form_index_to_batch_id(input_batch: InputBatch, target_model_input_batch: dict):
+    mtp_index_2_mtp_id = {v: k for k, v in input_batch.index_to_batch_id.items()}
+    for target_model_id in target_model_input_batch:
+        target_model_index = target_model_input_batch[target_model_id]
+        if input_batch.index_to_batch_id[target_model_id] == target_model_index:
             continue
-        input_batch.swap_states(i, target)
-        swapped.add(i)
-        swapped.add(target)
+        mtp_id = mtp_index_2_mtp_id[target_model_index]
+        v1 = input_batch.index_to_batch_id[target_model_id]
+        v2 = input_batch.index_to_batch_id[mtp_id]
+        input_batch.swap_states(target_model_id, mtp_id)
+        # update mapping
+        mtp_index_2_mtp_id[v1] = mtp_id
+        mtp_index_2_mtp_id[v2] = target_model_id
+    keys_to_remove = input_batch.index_to_batch_id.keys() - target_model_input_batch.keys()
+    for key in keys_to_remove:
+        del input_batch.index_to_batch_id[key]
+        for k, v in mtp_index_2_mtp_id.items():
+            if v == key:
+                del mtp_index_2_mtp_id[k]
+                break

def reorder_split_prefill_and_decode(input_batch: InputBatch):
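The rewritten `reorder_split_prefill_and_decode_form_index_to_batch_id` drives the proposer's batch toward the target model's layout: it inverts its own mapping, swaps rows until each request sits in the slot the target assigned, keeps the inverse map current after every swap (the old `swapped`-set version skipped already-touched indices and could leave overlapping swap chains unresolved), and finally prunes requests the target no longer tracks. A self-contained sketch of the same idea, with plain lists and dicts standing in for `InputBatch` (all names illustrative):

```python
# Toy stand-ins: `rows` plays the role of the proposer's per-request tensors,
# `id_to_slot` its index_to_batch_id, and `target_id_to_slot` the mapping
# handed down from the target model.
def align_proposer_to_target(rows, id_to_slot, target_id_to_slot):
    slot_to_id = {slot: rid for rid, slot in id_to_slot.items()}
    for rid, want in target_id_to_slot.items():
        have = id_to_slot[rid]
        if have == want:
            continue
        other = slot_to_id[want]                    # request currently in `want`
        rows[have], rows[want] = rows[want], rows[have]
        id_to_slot[rid], id_to_slot[other] = want, have
        slot_to_id[want], slot_to_id[have] = rid, other  # keep inverse map fresh
    # drop requests the target batch no longer contains
    for rid in id_to_slot.keys() - target_id_to_slot.keys():
        del slot_to_id[id_to_slot.pop(rid)]
    return rows

rows = ["state_a", "state_b", "state_c"]
id_to_slot = {"a": 0, "b": 1, "c": 2}
target = {"a": 1, "b": 0}          # target put the decode request first
print(align_proposer_to_target(rows, id_to_slot, target))
# -> ['state_b', 'state_a', 'state_c'], with id_to_slot == {'a': 1, 'b': 0}
```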
+18
@@ -106,3 +106,21 @@ def test_model_against_baseline(
             prompts,
         ),
     )
+    mtp_model_path = os.path.join(model_path, "mtp")
+    speculative_config = {"method": "mtp", "num_speculative_tokens": 1, "model": mtp_model_path}
+    _ = run_with_timeout(
+        target=form_model_get_output_topp0,
+        kwargs={
+            "fd_runner": fd_runner,
+            "model_path": model_path,
+            "tensor_parallel_size": tensor_parallel_size,
+            "max_num_seqs": max_num_seqs,
+            "max_model_len": max_model_len,
+            "max_tokens": max_tokens,
+            "quantization": quantization,
+            "load_choices": "dummy",
+            "prompts": prompts,
+            "speculative_config": speculative_config,
+        },
+    )
+7 -3
@@ -19,6 +19,7 @@ import socket
 import subprocess
 import time
 import traceback
+from functools import partial
 from multiprocessing import Process, Queue

 import pytest
@@ -54,10 +55,11 @@ def print_logs():
             print(f"\n===== {log_file} end =====\n")

-def run_with_timeout(target, args, timeout=60 * 5):
+def run_with_timeout(target, args=(), kwargs={}, timeout=60 * 5):
     clear_logs()
     result_queue = Queue()
-    p = Process(target=target, args=(*args, result_queue))
+    wrapped_target = partial(target, result_queue=result_queue)
+    p = Process(target=wrapped_target, args=args, kwargs=kwargs)
     p.start()
     p.join(timeout)
     if p.is_alive():
@@ -86,7 +88,8 @@ def form_model_get_output_topp0(
     quantization,
     load_choices,
     prompts,
-    result_queue,
+    speculative_config={},
+    result_queue=None,
 ):
     try:
         with fd_runner(
@@ -96,6 +99,7 @@
             max_model_len=max_model_len,
             load_choices=load_choices,
             quantization=quantization,
+            speculative_config=speculative_config,
         ) as fd_model:
             fd_outputs = fd_model.generate_topp0(prompts, max_tokens=max_tokens)
             result_queue.put(fd_outputs)
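The `run_with_timeout` refactor above binds `result_queue` with `functools.partial`, so arbitrary positional and keyword arguments now pass through `Process` untouched; that is what lets the new MTP test supply its `kwargs` dict. A minimal runnable sketch of the same pattern (toy target function; the helper body mirrors the diff but drops the log handling):

```python
from functools import partial
from multiprocessing import Process, Queue

def run_with_timeout(target, args=(), kwargs=None, timeout=60 * 5):
    result_queue = Queue()
    # Bind the queue as a keyword argument so callers never have to thread it
    # through their own args tuple.
    wrapped = partial(target, result_queue=result_queue)
    p = Process(target=wrapped, args=args, kwargs=kwargs or {})
    p.start()
    p.join(timeout)
    if p.is_alive():
        p.terminate()
        raise TimeoutError(f"target did not finish within {timeout}s")
    return result_queue.get()

def square(x, result_queue=None):
    result_queue.put(x * x)

if __name__ == "__main__":
    print(run_with_timeout(square, kwargs={"x": 7}, timeout=10))  # 49
```

Because `partial` of a module-level function is picklable, this also works on platforms where multiprocessing spawns rather than forks.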