[Others] Fix PD reorder for MTP (#6792)

* Fix PD reorder in MTP

* Add unit test (UT)

* update

* Fix MTP
This commit is contained in:
bukejiyu
2026-03-23 21:10:22 +08:00
committed by GitHub
parent 1b276e62d4
commit c62f6b4ea5
5 changed files with 61 additions and 55 deletions
+7 -3
View File
@@ -19,6 +19,7 @@ import socket
import subprocess
import time
import traceback
from functools import partial
from multiprocessing import Process, Queue
import pytest
@@ -54,10 +55,11 @@ def print_logs():
print(f"\n===== {log_file} end =====\n")
def run_with_timeout(target, args, timeout=60 * 5):
def run_with_timeout(target, args=(), kwargs={}, timeout=60 * 5):
clear_logs()
result_queue = Queue()
p = Process(target=target, args=(*args, result_queue))
wrapped_target = partial(target, result_queue=result_queue)
p = Process(target=wrapped_target, args=args, kwargs=kwargs)
p.start()
p.join(timeout)
if p.is_alive():
@@ -86,7 +88,8 @@ def form_model_get_output_topp0(
quantization,
load_choices,
prompts,
result_queue,
speculative_config={},
result_queue=None,
):
try:
with fd_runner(
@@ -96,6 +99,7 @@ def form_model_get_output_topp0(
max_model_len=max_model_len,
load_choices=load_choices,
quantization=quantization,
speculative_config=speculative_config,
) as fd_model:
fd_outputs = fd_model.generate_topp0(prompts, max_tokens=max_tokens)
result_queue.put(fd_outputs)