mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Others] Fix PD reorder for MTP (#6792)
* fix pd reorder in mtp * add ut * update * fix mtp
This commit is contained in:
@@ -456,13 +456,17 @@ class MTPProposer(Proposer):
|
||||
}
|
||||
)
|
||||
|
||||
def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
|
||||
def insert_tasks_v1(
|
||||
self, req_dicts: List[Request], num_running_requests: int, target_model_index_to_batch_id: dict = {}
|
||||
):
|
||||
|
||||
if "caches" not in self.model_inputs:
|
||||
self.initialize_kv_cache()
|
||||
req_len = len(req_dicts)
|
||||
self.model_inputs["num_running_requests"] = num_running_requests
|
||||
self.model_inputs["running_requests_ids"] = range(num_running_requests)
|
||||
if target_model_index_to_batch_id:
|
||||
self.model_inputs.index_to_batch_id = dict(target_model_index_to_batch_id)
|
||||
for i in range(req_len):
|
||||
request = req_dicts[i]
|
||||
logger.debug(f"{i}th request-{request.request_id}: {request}")
|
||||
@@ -962,9 +966,8 @@ class MTPProposer(Proposer):
|
||||
recover_model_output_map = recover_batch_index_for_output(
|
||||
self.model_inputs,
|
||||
self.model_inputs.index_to_batch_id,
|
||||
self.model_inputs.enable_pd_reorder[
|
||||
"batch_token_num", "cu_batch_token_offset", "seq_lens_decoder", "prompt_lens"
|
||||
],
|
||||
self.model_inputs.enable_pd_reorder,
|
||||
["batch_token_num", "cu_batch_token_offset", "seq_lens_decoder", "prompt_lens"],
|
||||
)
|
||||
speculate_save_output_topk(
|
||||
sampler_output.sampled_token_ids,
|
||||
@@ -1081,7 +1084,8 @@ class MTPProposer(Proposer):
|
||||
recover_model_output_map = recover_batch_index_for_output(
|
||||
self.model_inputs,
|
||||
self.model_inputs.index_to_batch_id,
|
||||
self.model_inputs.enable_pd_reorder["batch_token_num", "cu_batch_token_offset"],
|
||||
self.model_inputs.enable_pd_reorder,
|
||||
["batch_token_num", "cu_batch_token_offset"],
|
||||
)
|
||||
speculate_save_output_topk(
|
||||
sampler_output.sampled_token_ids,
|
||||
@@ -1244,11 +1248,11 @@ class MTPProposer(Proposer):
|
||||
raise NotImplementedError
|
||||
return cache_type
|
||||
|
||||
def reorder_inputs(self):
|
||||
def reorder_inputs(self, target_model_input_batch):
|
||||
"""
|
||||
Reorder inputs to split prefill and decode.
|
||||
"""
|
||||
reorder_split_prefill_and_decode_form_index_to_batch_id(self.model_inputs)
|
||||
reorder_split_prefill_and_decode_form_index_to_batch_id(self.model_inputs, target_model_input_batch)
|
||||
|
||||
def _share_external_data(self, cache, cache_name, cache_shape):
|
||||
if current_platform.is_xpu():
|
||||
|
||||
Reference in New Issue
Block a user