[Optimization] Update ZMQ server (#6735)

* add batch zmq send response

* update

* Revert "update"

This reverts commit 0234a25b47.

* update

* remove lock

* fix unit test

* add unit test

* add unit test

* pre commit

* add unit test

* fix unit test

* add unit test

* fix worker>1

* update zmq_worker_pid

* fix unit test

* fix unit test

* fix unit test

* add unit test

* fix unit test

* fix first token time

* fix logprobs

* add unit test

* op

* remove debug log

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
luukunn
2026-03-19 21:53:16 +08:00
committed by GitHub
parent 9148562ed0
commit c3d8db85c4
18 changed files with 2739 additions and 133 deletions
+72 -14
View File
@@ -17,6 +17,7 @@
from __future__ import annotations
import asyncio
import collections
import copy
import json
import multiprocessing
@@ -1117,13 +1118,25 @@ class EngineService:
self.internal_adapter = InternalAdapter(
cfg=self.cfg, engine=self, dp_rank=self.cfg.parallel_config.local_data_parallel_id
)
# ROUTER mode: need to receive client handles
self.recv_result_handle_thread = threading.Thread(
target=self.send_response_server.recv_result_handle, daemon=True
)
self.recv_result_handle_thread.start()
else:
self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL)
self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER)
self.recv_result_handle_thread = threading.Thread(
target=self.send_response_server.recv_result_handle, daemon=True
)
self.recv_result_handle_thread.start()
if envs.ZMQ_SEND_BATCH_DATA:
# PUSH mode: batch send, no need to receive client handles
self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PUSH)
# Mapping from request_id to worker_pid for routing batch responses
self.request_worker_map = {}
else:
# ROUTER mode: per-query send, need to receive client handles
self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER)
self.recv_result_handle_thread = threading.Thread(
target=self.send_response_server.recv_result_handle, daemon=True
)
self.recv_result_handle_thread.start()
time.sleep(3)
self.insert_task_to_scheduler_thread = threading.Thread(target=self._insert_zmq_task_to_scheduler, daemon=True)
self.insert_task_to_scheduler_thread.start()
@@ -1161,8 +1174,17 @@ class EngineService:
self.recv_request_server = ZmqIpcServer(name=self.api_server_pid, mode=zmq.PULL)
continue
# Extract zmq_worker_pid for per-worker PUSH routing.
# Only needed when ZMQ_SEND_BATCH_DATA=True AND not using internal adapter,
# because FD_ENABLE_INTERNAL_ADAPTER uses ROUTER (worker_pid is irrelevant).
worker_pid = None
if envs.ZMQ_SEND_BATCH_DATA and not envs.FD_ENABLE_INTERNAL_ADAPTER:
worker_pid = data["zmq_worker_pid"]
if ControlRequest.is_control_request(data):
try: # todo: run control request async, do not block request generation
if worker_pid is not None:
self.request_worker_map[data.get("request_id")] = worker_pid
control_req = ControlRequest.from_dict(data)
self.run_control_method(control_req)
except Exception as e:
@@ -1175,6 +1197,11 @@ class EngineService:
request, insert_task = data, []
results: List[Tuple[str, Optional[str]]] = list()
if data:
# Store worker_pid mapping for normal/abort requests
if worker_pid is not None:
req_id_for_map = data.get("request_id")
if req_id_for_map:
self.request_worker_map[req_id_for_map] = worker_pid
status_value = data.get("status", None)
if status_value is not None and status_value == RequestStatus.ABORT.value:
req_id = data["request_id"]
@@ -1200,7 +1227,9 @@ class EngineService:
if self.is_paused:
self.llm_logger.warning(f"Engine is paused, drop request: {request}")
self._send_error_response(
request.request_id, "Request is aborted since LLM Engine is paused."
request.request_id,
"Request is aborted since LLM Engine is paused.",
worker_pid=worker_pid,
)
continue
except Exception as e:
@@ -1261,6 +1290,11 @@ class EngineService:
method = control_req.get_method()
request_id = control_req.request_id
# Look up worker_pid for routing control response
worker_pid = None
if envs.ZMQ_SEND_BATCH_DATA and hasattr(self, "request_worker_map"):
worker_pid = self.request_worker_map.pop(request_id, None)
try:
self.llm_logger.info(f"START run control method {request_id}: {method}")
@@ -1269,19 +1303,22 @@ class EngineService:
if handler is None or not callable(handler):
error_result = ControlResponse(request_id, 400, f"unknown control method:{method}")
self.llm_logger.error(str(error_result))
self.send_response_server.send_response(request_id, [error_result])
data = [[error_result]] if envs.ZMQ_SEND_BATCH_DATA else [error_result]
self.send_response_server.send_response(request_id, data, worker_pid=worker_pid)
return
result = handler(control_req)
self.llm_logger.info(f"SUCCESS run control method {method}.")
succ_result = ControlResponse(request_id, 200, "Success", result)
self.send_response_server.send_response(request_id, [succ_result])
data = [[succ_result]] if envs.ZMQ_SEND_BATCH_DATA else [succ_result]
self.send_response_server.send_response(request_id, data, worker_pid=worker_pid)
except Exception as e:
error_msg = f"Failed run control method {method}: {str(e)}"
self.llm_logger.error(f"{error_msg}\n{traceback.format_exc()}")
error_result = ControlResponse(request_id, 500, error_msg)
self.send_response_server.send_response(request_id, [error_result])
data = [[error_result]] if envs.ZMQ_SEND_BATCH_DATA else [error_result]
self.send_response_server.send_response(request_id, data, worker_pid=worker_pid)
def _control_pause(self, control_request: ControlRequest):
"""Pauses the LLM engine and aborts all running/inflight requests.
@@ -1454,7 +1491,7 @@ class EngineService:
# Use a single asyncio.run() to concurrently wait for all worker responses.
return asyncio.run(self._wait_all_control_responses(request_id, timeout))
def _send_error_response(self, request_id, error_msg, error_code: int = 500):
def _send_error_response(self, request_id, error_msg, error_code: int = 500, worker_pid=None):
self.llm_logger.error(
f"Send error response to client, request_id: {request_id}, error_msg: {error_msg}, error_code: {error_code}"
)
@@ -1464,10 +1501,15 @@ class EngineService:
error_code=error_code,
error_msg=error_msg,
)
# Look up worker_pid from mapping if not provided
if worker_pid is None and envs.ZMQ_SEND_BATCH_DATA and hasattr(self, "request_worker_map"):
worker_pid = self.request_worker_map.pop(request_id, None)
# Since the request is not in scheduler
# Send result by zmq directly
if envs.FD_ENABLE_INTERNAL_ADAPTER:
self.send_response_server.send_response(None, [[error_result]])
elif envs.ZMQ_SEND_BATCH_DATA:
self.send_response_server.send_response(None, [[error_result]], worker_pid=worker_pid)
else:
self.send_response_server.send_response(request_id, [error_result])
@@ -1529,6 +1571,7 @@ class EngineService:
self.send_response_server.send_response(None, new_contents)
else:
worker_batches = collections.defaultdict(list)
for request_id, contents in results.items():
new_contents = []
for content in contents:
@@ -1537,7 +1580,9 @@ class EngineService:
delta_text = ""
if decode_type == 0:
delta_text, token_ids = self._decode_token(
token_ids=content.outputs.token_ids, req_id=request_id, is_end=content.finished
token_ids=content.outputs.token_ids,
req_id=request_id,
is_end=content.finished,
)
else:
token_ids = content.outputs.token_ids
@@ -1553,9 +1598,19 @@ class EngineService:
)
else:
new_contents.append(content)
if len(new_contents):
self.llm_logger.debug(f"Send response for request id: {request_id}")
self.send_response_server.send_response(request_id, new_contents)
if new_contents:
if envs.ZMQ_SEND_BATCH_DATA:
wpid = self.request_worker_map.get(request_id)
worker_batches[wpid].append(new_contents)
is_finished = any(getattr(c, "finished", False) for c in new_contents)
if is_finished:
self.request_worker_map.pop(request_id, None)
else:
self.send_response_server.send_response(request_id, new_contents)
if envs.ZMQ_SEND_BATCH_DATA:
for wpid, batch_data in worker_batches.items():
if batch_data:
self.send_response_server.send_response(None, batch_data, worker_pid=wpid)
except Exception as e:
self.llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
@@ -1714,6 +1769,9 @@ class EngineService:
self.cache_task_queue.clear_transfer_task()
self.send_response_server.req_dict.clear()
self.recv_request_server.req_dict.clear()
# Clean up worker_pid mapping (batch mode)
if envs.ZMQ_SEND_BATCH_DATA and hasattr(self, "request_worker_map"):
self.request_worker_map.clear()
self.llm_logger.info("Clear Data: Successfully")
return True
except Exception as e: