[Optimization] Update ZMQ server (#6735)

* add batch zmq send response

* update

* Revert "update"

This reverts commit 0234a25b47.

* update

* remove lock

* fix unit test

* add unit test

* add unit test

* pre commit

* add unit test

* fix unit test

* add unit test

* fix worker>1

* update zmq_worker_pid

* fix unit test

* fix unit test

* fix unit test

* add unit test

* fix unit test

* fix first token time

* fix logprobs

* add unit test

* op

* remove debug log

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
luukunn
2026-03-19 21:53:16 +08:00
committed by GitHub
parent 9148562ed0
commit c3d8db85c4
18 changed files with 2739 additions and 133 deletions
+145 -30
View File
@@ -17,14 +17,18 @@
import asyncio
import functools
import heapq
import os
import random
import time
import traceback
from multiprocessing.reduction import ForkingPickler
import aiozmq
import msgpack
import zmq
from fastapi import Request
import fastdeploy.envs as envs
from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.metrics.stats import ZMQMetricsStats
@@ -77,27 +81,59 @@ UVICORN_CONFIG = {
class DealerConnectionManager:
"""
Manager for dealer connections, supporting multiplexing and connection reuse
Manager for dealer connections, supporting multiplexing and connection reuse.
When ZMQ_SEND_BATCH_DATA=1: uses PULL client for batch response reception.
When ZMQ_SEND_BATCH_DATA=0: uses DEALER connections with ROUTER for per-query response.
"""
def __init__(self, pid, max_connections=10):
    """Initialize the manager.

    The visible span interleaves the pre-change and post-change diff lines;
    this is the reconstructed post-change initializer. In batch mode
    (ZMQ_SEND_BATCH_DATA) only a PULL client and a dispatcher task are
    needed; in per-query mode a pool of DEALER connections is prepared.

    Args:
        pid: identifier used to derive IPC socket paths.
        max_connections: requested DEALER pool size (per-query mode only);
            clamped to a minimum of 10.
    """
    self.pid = pid
    self.request_map = {}  # request_id -> response_queue
    self.lock = asyncio.Lock()
    self.running = False
    if envs.ZMQ_SEND_BATCH_DATA:
        # Batch mode: PULL client and dispatcher task, created lazily
        # in initialize().
        self.pull_client = None
        self.dispatcher_task = None
    else:
        # Per-query mode: DEALER connections.
        self.max_connections = max(max_connections, 10)
        self.connections = []
        self.connection_load = []
        self.connection_heap = []
        self.request_num = {}  # request_id -> num_choices
        self.connection_tasks = []
async def initialize(self):
    """Initialize transport for response reception.

    Batch mode: bind a PULL socket on a per-worker IPC path (each worker
    binds its own address so an upstream PUSH does not round-robin across
    workers) and start the dispatcher task. Per-query mode: create the
    DEALER connection pool.

    Raises:
        RuntimeError: in batch mode, if the PULL client cannot be created
            (running is reset to False so callers do not hang).
    """
    self.running = True
    if envs.ZMQ_SEND_BATCH_DATA:
        # Create PULL client for batch response reception.
        # Each worker binds on a unique address to avoid PUSH round-robin.
        try:
            self.worker_pid = os.getpid()
            self.pull_ipc_path = f"/dev/shm/response_{self.pid}_w{self.worker_pid}.pull"
            self.pull_client = await aiozmq.create_zmq_stream(zmq.PULL, bind=f"ipc://{self.pull_ipc_path}")
            # Actual bound endpoint, used as the metrics label.
            self.pull_metrics_address = self.pull_client.transport.getsockopt(zmq.LAST_ENDPOINT)
            # Start dispatcher task
            self.dispatcher_task = asyncio.create_task(self._dispatch_batch_responses())
            api_server_logger.info(
                f"Started PULL client (bind) for batch response, pid {self.pid}, worker {self.worker_pid}"
            )
        except Exception as e:
            api_server_logger.error(f"Failed to create PULL client: {str(e)}")
            # Reset running flag and propagate error to avoid hanging requests in batch mode
            self.running = False
            raise RuntimeError(
                f"Failed to initialize PULL client for batch response "
                f"(pid={self.pid}, worker={getattr(self, 'worker_pid', 'unknown')})"
            ) from e
    else:
        for index in range(self.max_connections):
            await self._add_connection(index)
        api_server_logger.info(f"Started {self.max_connections} connections, pid {self.pid}")
async def _add_connection(self, index):
"""create a new connection and start listening task"""
@@ -175,28 +211,84 @@ class DealerConnectionManager:
return self.connections[conn_index]
async def _dispatch_batch_responses(self):
    """
    Receive batch responses and dispatch to corresponding request queues.
    batch_data format: [[output, ...], [output, ...], ...]

    Runs as a long-lived task while ``self.running`` is True. Exits on
    connection loss, or after ``max_consecutive_errors`` consecutive
    decode/dispatch failures to avoid spinning on a poisoned stream.
    """
    consecutive_errors = 0
    max_consecutive_errors = 5
    while self.running:
        try:
            raw_data = await self.pull_client.read()
            # Deserialize batch response; raw_data is a multipart message,
            # the payload is in the last frame.
            batch_data = msgpack.unpackb(raw_data[-1])
            # Record metrics
            _zmq_metrics_stats = ZMQMetricsStats()
            _zmq_metrics_stats.msg_recv_total += 1
            main_process_metrics.record_zmq_stats(_zmq_metrics_stats, self.pull_metrics_address)
            # Parse req_id from outputs and dispatch in a single pass
            for outputs in batch_data:
                # Outputs may be plain dicts or objects with a request_id attribute.
                last_output = outputs[-1]
                req_id = last_output["request_id"] if isinstance(last_output, dict) else last_output.request_id
                # Strip the trailing "_<n>" choice suffix so all choices of a
                # request share one queue — presumably matching the key used
                # in get_connection(); verify against callers.
                if req_id.startswith(("cmpl", "embd", "reward", "chatcmpl")):
                    req_id = req_id.rsplit("_", 1)[0]
                queue = self.request_map.get(req_id)
                # Unknown req_id (already cleaned up) is silently dropped.
                if queue is not None:
                    queue.put_nowait(outputs)
            consecutive_errors = 0
        except (ConnectionError, OSError) as e:
            # Transport-level failure: stop dispatching entirely.
            if self.running:
                api_server_logger.error(f"Dispatcher: connection lost, exiting: {e}")
            break
        except Exception as e:
            consecutive_errors += 1
            if self.running:
                api_server_logger.error(
                    f"Dispatcher error ({consecutive_errors}/{max_consecutive_errors}): "
                    f"{e}\n{traceback.format_exc()}"
                )
            if consecutive_errors >= max_consecutive_errors:
                api_server_logger.error(f"Dispatcher: {max_consecutive_errors} consecutive errors, exiting")
                break
async def get_connection(self, request_id, num_choices=1):
    """Register a request and return a transport for its responses.

    The visible span interleaves the removed and added diff lines; this is
    the reconstructed post-change method. In batch mode no dealer is needed
    (responses arrive via the shared PULL dispatcher), so the first element
    of the returned tuple is None.

    Args:
        request_id: key under which responses are queued.
        num_choices: number of choices expected (per-query mode only).

    Returns:
        (dealer_or_None, response_queue)

    Raises:
        RuntimeError: per-query mode only, when no connection is available.
    """
    response_queue = asyncio.Queue()
    if envs.ZMQ_SEND_BATCH_DATA:
        async with self.lock:
            self.request_map[request_id] = response_queue
        return None, response_queue
    else:
        async with self.lock:
            self.request_map[request_id] = response_queue
            self.request_num[request_id] = num_choices
            dealer = self._get_least_loaded_connection()
            if not dealer:
                raise RuntimeError("No available connections")
            return dealer, response_queue
async def cleanup_request(self, request_id):
    """
    Clean up the request after it is finished.

    Uses dict.pop with a default so double-cleanup is harmless. If the task
    is cancelled while waiting for the lock, a best-effort lock-free cleanup
    is still performed before re-raising CancelledError.
    """
    try:
        async with self.lock:
            self.request_map.pop(request_id, None)
            if not envs.ZMQ_SEND_BATCH_DATA:
                # request_num only exists in per-query (DEALER) mode.
                self.request_num.pop(request_id, None)
    except asyncio.CancelledError:
        # If cancelled during lock acquisition, try cleanup without lock
        self.request_map.pop(request_id, None)
        if not envs.ZMQ_SEND_BATCH_DATA:
            self.request_num.pop(request_id, None)
        raise
async def close(self):
    """
    Close all connections and background tasks.

    Batch mode: cancel the dispatcher, close the PULL client and remove the
    IPC socket file created by bind. Per-query mode: cancel listener tasks
    and close every DEALER connection. Pending request queues are dropped.
    Reconstructed from interleaved diff lines; bare ``except:`` replaced with
    ``except Exception:`` so task cancellation is not swallowed.
    """
    self.running = False
    if envs.ZMQ_SEND_BATCH_DATA:
        # Cancel dispatcher task
        if self.dispatcher_task:
            self.dispatcher_task.cancel()
        # Close PULL client
        if self.pull_client:
            try:
                self.pull_client.close()
            except Exception:
                pass
        self.request_map.clear()
        # Clean up IPC file created by bind; getattr guards against
        # initialize() having failed before setting the path.
        pull_ipc_path = getattr(self, "pull_ipc_path", None)
        if pull_ipc_path:
            try:
                if os.path.exists(pull_ipc_path):
                    os.remove(pull_ipc_path)
            except OSError:
                pass
    else:
        for task in self.connection_tasks:
            task.cancel()
        async with self.lock:
            for dealer in self.connections:
                try:
                    dealer.close()
                except Exception:
                    pass
            self.connections.clear()
            self.connection_load.clear()
        # Clear request map
        self.request_map.clear()
    api_server_logger.info("All connections and tasks closed")