mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Others] Fix PD reorder for MTP (#6792)
* Fix PD reorder in MTP
* Add unit test
* Update
* Fix MTP
This commit is contained in:
@@ -456,13 +456,17 @@ class MTPProposer(Proposer):
|
||||
}
|
||||
)
|
||||
|
||||
def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
|
||||
def insert_tasks_v1(
|
||||
self, req_dicts: List[Request], num_running_requests: int, target_model_index_to_batch_id: dict = {}
|
||||
):
|
||||
|
||||
if "caches" not in self.model_inputs:
|
||||
self.initialize_kv_cache()
|
||||
req_len = len(req_dicts)
|
||||
self.model_inputs["num_running_requests"] = num_running_requests
|
||||
self.model_inputs["running_requests_ids"] = range(num_running_requests)
|
||||
if target_model_index_to_batch_id:
|
||||
self.model_inputs.index_to_batch_id = dict(target_model_index_to_batch_id)
|
||||
for i in range(req_len):
|
||||
request = req_dicts[i]
|
||||
logger.debug(f"{i}th request-{request.request_id}: {request}")
|
||||
@@ -962,9 +966,8 @@ class MTPProposer(Proposer):
|
||||
recover_model_output_map = recover_batch_index_for_output(
|
||||
self.model_inputs,
|
||||
self.model_inputs.index_to_batch_id,
|
||||
self.model_inputs.enable_pd_reorder[
|
||||
"batch_token_num", "cu_batch_token_offset", "seq_lens_decoder", "prompt_lens"
|
||||
],
|
||||
self.model_inputs.enable_pd_reorder,
|
||||
["batch_token_num", "cu_batch_token_offset", "seq_lens_decoder", "prompt_lens"],
|
||||
)
|
||||
speculate_save_output_topk(
|
||||
sampler_output.sampled_token_ids,
|
||||
@@ -1081,7 +1084,8 @@ class MTPProposer(Proposer):
|
||||
recover_model_output_map = recover_batch_index_for_output(
|
||||
self.model_inputs,
|
||||
self.model_inputs.index_to_batch_id,
|
||||
self.model_inputs.enable_pd_reorder["batch_token_num", "cu_batch_token_offset"],
|
||||
self.model_inputs.enable_pd_reorder,
|
||||
["batch_token_num", "cu_batch_token_offset"],
|
||||
)
|
||||
speculate_save_output_topk(
|
||||
sampler_output.sampled_token_ids,
|
||||
@@ -1244,11 +1248,11 @@ class MTPProposer(Proposer):
|
||||
raise NotImplementedError
|
||||
return cache_type
|
||||
|
||||
def reorder_inputs(self, target_model_input_batch):
    """
    Reorder the proposer's inputs to split prefill and decode.

    Args:
        target_model_input_batch: Forwarded unchanged to
            ``reorder_split_prefill_and_decode_form_index_to_batch_id``;
            the model runner passes the target model's
            ``index_to_batch_id`` mapping here.
    """
    # The proposer's own batch is reordered in place by the helper.
    reorder_split_prefill_and_decode_form_index_to_batch_id(
        input_batch=self.model_inputs,
        target_model_input_batch=target_model_input_batch,
    )
|
||||
|
||||
def _share_external_data(self, cache, cache_name, cache_shape):
|
||||
if current_platform.is_xpu():
|
||||
|
||||
@@ -982,7 +982,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
|
||||
self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"][:num_running_requests]
|
||||
if self.spec_method == SpecMethod.MTP:
|
||||
self.proposer.insert_tasks_v1(req_dicts, num_running_requests)
|
||||
self.proposer.insert_tasks_v1(req_dicts, num_running_requests, self.share_inputs.index_to_batch_id)
|
||||
|
||||
def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
|
||||
raise NotImplementedError("GPUs only support KVCACHE SCHEDULER V1 in versions 2.6 and above.")
|
||||
@@ -1226,7 +1226,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
reorder_split_prefill_and_decode(input_batch=self.share_inputs)
|
||||
if self.speculative_decoding:
|
||||
if self.spec_method == SpecMethod.MTP:
|
||||
self.proposer.reorder_inputs()
|
||||
self.proposer.reorder_inputs(self.share_inputs.index_to_batch_id)
|
||||
|
||||
def load_model(self) -> None:
|
||||
"""load or download model"""
|
||||
|
||||
@@ -429,7 +429,7 @@ class InputBatch:
|
||||
swap_data(self.step_seq_lens_this_time, i1, i2)
|
||||
swap_data(self.draft_logits, i1, i2)
|
||||
swap_data(self.cu_batch_token_offset, i1, i2)
|
||||
swap_data(self.stop_flags, i1, i2)
|
||||
|
||||
if self.enable_mm:
|
||||
if self.image_features_list is not None:
|
||||
self.image_features_list[i1], self.image_features_list[i2] = (
|
||||
@@ -675,7 +675,6 @@ class ProposerInputBatch(InputBatch):
|
||||
def init_share_inputs(self):
|
||||
# share with target model
|
||||
self.enable_pd_reorder = getattr(self.target_model_input_batch, "enable_pd_reorder", False)
|
||||
self.index_to_batch_id = getattr(self.target_model_input_batch, "index_to_batch_id", {})
|
||||
|
||||
self.block_tables = paddle.clone(self.target_model_input_batch["block_tables"])
|
||||
self.input_ids = paddle.clone(self.target_model_input_batch["input_ids"])
|
||||
@@ -852,6 +851,7 @@ class ProposerInputBatch(InputBatch):
|
||||
tensor[idx1] = tensor[idx2].clone()
|
||||
tensor[idx2] = temp
|
||||
|
||||
self.index_to_batch_id[i1], self.index_to_batch_id[i2] = self.index_to_batch_id[i2], self.index_to_batch_id[i1]
|
||||
swap_data(self.block_tables, i1, i2)
|
||||
swap_data(self.input_ids, i1, i2)
|
||||
swap_data(self.input_ids_cpu, i1, i2)
|
||||
@@ -859,48 +859,14 @@ class ProposerInputBatch(InputBatch):
|
||||
swap_data(self.seq_lens_encoder, i1, i2)
|
||||
swap_data(self.seq_lens_decoder, i1, i2)
|
||||
swap_data(self.step_idx, i1, i2)
|
||||
swap_data(self.stop_flags, i1, i2)
|
||||
swap_data(self.not_need_stop, i1, i2)
|
||||
swap_data(self.pre_ids, i1, i2)
|
||||
if current_platform.is_cuda():
|
||||
swap_data(self.cu_seqlens_q_output, i1, i2)
|
||||
swap_data(self.batch_id_per_token_output, i1, i2)
|
||||
swap_data(self.token_ids_all, i1, i2)
|
||||
else:
|
||||
swap_data(self.output_cum_offsets, i1, i2)
|
||||
swap_data(self.output_padding_offset, i1, i2)
|
||||
swap_data(self.ids_remove_padding, i1, i2)
|
||||
swap_data(self.batch_id_per_token, i1, i2)
|
||||
swap_data(self.cu_seqlens_q, i1, i2)
|
||||
swap_data(self.cu_seqlens_k, i1, i2)
|
||||
|
||||
swap_data(self.target_hidden_states, i1, i2)
|
||||
|
||||
swap_data(self.draft_tokens, i1, i2)
|
||||
swap_data(self.encoder_block_lens, i1, i2)
|
||||
|
||||
swap_data(self.is_block_step, i1, i2)
|
||||
swap_data(self.batch_drop, i1, i2)
|
||||
swap_data(self.used_list_len, i1, i2)
|
||||
|
||||
if self.num_model_steps > 1:
|
||||
swap_data(self.last_seq_lens_this_time, i1, i2)
|
||||
|
||||
swap_data(self.input_ids_len, i1, i2)
|
||||
swap_data(self.first_token_hidden_states, i1, i2)
|
||||
|
||||
swap_data(self.batch_token_num, i1, i2)
|
||||
swap_data(self.next_token_num, i1, i2)
|
||||
swap_data(self.cu_batch_token_offset, i1, i2)
|
||||
swap_data(self.cu_next_token_offset, i1, i2)
|
||||
swap_data(self.mask_rollback, i1, i2)
|
||||
swap_data(self.recompute_token_num, i1, i2)
|
||||
|
||||
if self.enable_mm:
|
||||
swap_data(self.attn_mask_offsets, i1, i2)
|
||||
swap_data(self.attn_mask_offsets_full, i1, i2)
|
||||
swap_data(self.attn_mask_offsets_decoder, i1, i2)
|
||||
swap_data(self.decode_states, i1, i2)
|
||||
|
||||
def reset_model_inputs(self) -> None:
|
||||
"""
|
||||
@@ -1042,14 +1008,28 @@ class ProposerInputBatch(InputBatch):
|
||||
logger.error(f"Resetting model inputs failed, skipping reset, error message is {e}")
|
||||
|
||||
|
||||
def reorder_split_prefill_and_decode_form_index_to_batch_id(input_batch: InputBatch, target_model_input_batch: dict):
    """
    Reorder ``input_batch`` in place so that its ``index_to_batch_id``
    mapping agrees with the one in ``target_model_input_batch``.

    Args:
        input_batch: Batch to reorder. It must expose an
            ``index_to_batch_id`` dict and a ``swap_states(i1, i2)`` method.
            The reverse-map bookkeeping below assumes ``swap_states`` also
            swaps the two ``index_to_batch_id`` entries.
        target_model_input_batch: Mapping with the same key/value meaning as
            ``input_batch.index_to_batch_id``, describing the desired layout.

    Raises:
        KeyError: If a key of ``target_model_input_batch`` is missing from
            ``input_batch.index_to_batch_id``, or a wanted value is not
            currently held by any key of it.
    """
    # Reverse lookup: value currently stored -> key that holds it.
    holder_of_value = {value: key for key, value in input_batch.index_to_batch_id.items()}

    for target_key, wanted_value in target_model_input_batch.items():
        if input_batch.index_to_batch_id[target_key] == wanted_value:
            # Already in the desired position.
            continue

        # Key whose state currently carries the value wanted at target_key.
        source_key = holder_of_value[wanted_value]
        # Capture both values before the swap so the reverse map can be fixed up.
        value_at_target = input_batch.index_to_batch_id[target_key]
        value_at_source = input_batch.index_to_batch_id[source_key]

        input_batch.swap_states(target_key, source_key)

        # After the swap the two values have traded holders.
        holder_of_value[value_at_target] = source_key
        holder_of_value[value_at_source] = target_key

    # Drop entries the target mapping no longer contains.
    stale_keys = input_batch.index_to_batch_id.keys() - target_model_input_batch.keys()
    for stale_key in stale_keys:
        del input_batch.index_to_batch_id[stale_key]
|
||||
|
||||
|
||||
def reorder_split_prefill_and_decode(input_batch: InputBatch):
|
||||
|
||||
@@ -106,3 +106,21 @@ def test_model_against_baseline(
|
||||
prompts,
|
||||
),
|
||||
)
|
||||
|
||||
mtp_model_path = os.path.join(model_path, "mtp")
|
||||
speculative_config = {"method": "mtp", "num_speculative_tokens": 1, "model": mtp_model_path}
|
||||
_ = run_with_timeout(
|
||||
target=form_model_get_output_topp0,
|
||||
kwargs={
|
||||
"fd_runner": fd_runner,
|
||||
"model_path": model_path,
|
||||
"tensor_parallel_size": tensor_parallel_size,
|
||||
"max_num_seqs": max_num_seqs,
|
||||
"max_model_len": max_model_len,
|
||||
"max_tokens": max_tokens,
|
||||
"quantization": quantization,
|
||||
"load_choices": "dummy",
|
||||
"prompts": prompts,
|
||||
"speculative_config": speculative_config,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -19,6 +19,7 @@ import socket
|
||||
import subprocess
|
||||
import time
|
||||
import traceback
|
||||
from functools import partial
|
||||
from multiprocessing import Process, Queue
|
||||
|
||||
import pytest
|
||||
@@ -54,10 +55,11 @@ def print_logs():
|
||||
print(f"\n===== {log_file} end =====\n")
|
||||
|
||||
|
||||
def run_with_timeout(target, args, timeout=60 * 5):
|
||||
def run_with_timeout(target, args=(), kwargs={}, timeout=60 * 5):
|
||||
clear_logs()
|
||||
result_queue = Queue()
|
||||
p = Process(target=target, args=(*args, result_queue))
|
||||
wrapped_target = partial(target, result_queue=result_queue)
|
||||
p = Process(target=wrapped_target, args=args, kwargs=kwargs)
|
||||
p.start()
|
||||
p.join(timeout)
|
||||
if p.is_alive():
|
||||
@@ -86,7 +88,8 @@ def form_model_get_output_topp0(
|
||||
quantization,
|
||||
load_choices,
|
||||
prompts,
|
||||
result_queue,
|
||||
speculative_config={},
|
||||
result_queue=None,
|
||||
):
|
||||
try:
|
||||
with fd_runner(
|
||||
@@ -96,6 +99,7 @@ def form_model_get_output_topp0(
|
||||
max_model_len=max_model_len,
|
||||
load_choices=load_choices,
|
||||
quantization=quantization,
|
||||
speculative_config=speculative_config,
|
||||
) as fd_model:
|
||||
fd_outputs = fd_model.generate_topp0(prompts, max_tokens=max_tokens)
|
||||
result_queue.put(fd_outputs)
|
||||
|
||||
Reference in New Issue
Block a user