[Others] Fix PD reorder for MTP (#6792)

* Fix PD (prefill/decode) reorder in MTP

* Add unit test

* Update

* Fix MTP
This commit is contained in:
bukejiyu
2026-03-23 21:10:22 +08:00
committed by GitHub
parent 1b276e62d4
commit c62f6b4ea5
5 changed files with 61 additions and 55 deletions
+11 -7
View File
@@ -456,13 +456,17 @@ class MTPProposer(Proposer):
}
)
def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
def insert_tasks_v1(
self, req_dicts: List[Request], num_running_requests: int, target_model_index_to_batch_id: dict = {}
):
if "caches" not in self.model_inputs:
self.initialize_kv_cache()
req_len = len(req_dicts)
self.model_inputs["num_running_requests"] = num_running_requests
self.model_inputs["running_requests_ids"] = range(num_running_requests)
if target_model_index_to_batch_id:
self.model_inputs.index_to_batch_id = dict(target_model_index_to_batch_id)
for i in range(req_len):
request = req_dicts[i]
logger.debug(f"{i}th request-{request.request_id}: {request}")
@@ -962,9 +966,8 @@ class MTPProposer(Proposer):
recover_model_output_map = recover_batch_index_for_output(
self.model_inputs,
self.model_inputs.index_to_batch_id,
self.model_inputs.enable_pd_reorder[
"batch_token_num", "cu_batch_token_offset", "seq_lens_decoder", "prompt_lens"
],
self.model_inputs.enable_pd_reorder,
["batch_token_num", "cu_batch_token_offset", "seq_lens_decoder", "prompt_lens"],
)
speculate_save_output_topk(
sampler_output.sampled_token_ids,
@@ -1081,7 +1084,8 @@ class MTPProposer(Proposer):
recover_model_output_map = recover_batch_index_for_output(
self.model_inputs,
self.model_inputs.index_to_batch_id,
self.model_inputs.enable_pd_reorder["batch_token_num", "cu_batch_token_offset"],
self.model_inputs.enable_pd_reorder,
["batch_token_num", "cu_batch_token_offset"],
)
speculate_save_output_topk(
sampler_output.sampled_token_ids,
@@ -1244,11 +1248,11 @@ class MTPProposer(Proposer):
raise NotImplementedError
return cache_type
def reorder_inputs(self):
def reorder_inputs(self, target_model_input_batch):
"""
Reorder inputs to split prefill and decode.
"""
reorder_split_prefill_and_decode_form_index_to_batch_id(self.model_inputs)
reorder_split_prefill_and_decode_form_index_to_batch_id(self.model_inputs, target_model_input_batch)
def _share_external_data(self, cache, cache_name, cache_shape):
if current_platform.is_xpu():