mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Optimization] Update ZMQ server (#6735)
* add batch zmq send response
* update
* Revert "update"
This reverts commit 0234a25b47.
* update
* remove lock
* fix unit test
* add unit test
* add unit test
* pre commit
* add unit test
* fix unit test
* add unit test
* fix worker>1
* update zmq_worker_pid
* fix unit test
* fix unit test
* fix unit test
* add unit test
* fix unit test
* fix first token time
* fix logprobs
* add unit test
* op
* remove debug log
---------
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
@@ -17,14 +17,18 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import heapq
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
from multiprocessing.reduction import ForkingPickler
|
||||
|
||||
import aiozmq
|
||||
import msgpack
|
||||
import zmq
|
||||
from fastapi import Request
|
||||
|
||||
import fastdeploy.envs as envs
|
||||
from fastdeploy.engine.args_utils import EngineArgs
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.metrics.stats import ZMQMetricsStats
|
||||
@@ -77,27 +81,59 @@ UVICORN_CONFIG = {
|
||||
|
||||
class DealerConnectionManager:
|
||||
"""
|
||||
Manager for dealer connections, supporting multiplexing and connection reuse
|
||||
Manager for dealer connections, supporting multiplexing and connection reuse.
|
||||
When ZMQ_SEND_BATCH_DATA=1: uses PULL client for batch response reception.
|
||||
When ZMQ_SEND_BATCH_DATA=0: uses DEALER connections with ROUTER for per-query response.
|
||||
"""
|
||||
|
||||
def __init__(self, pid, max_connections=10):
|
||||
self.pid = pid
|
||||
self.max_connections = max(max_connections, 10)
|
||||
self.connections = []
|
||||
self.connection_load = []
|
||||
self.connection_heap = []
|
||||
self.request_map = {} # request_id -> response_queue
|
||||
self.request_num = {} # request_id -> num_choices
|
||||
self.lock = asyncio.Lock()
|
||||
self.connection_tasks = []
|
||||
self.running = False
|
||||
|
||||
if envs.ZMQ_SEND_BATCH_DATA:
|
||||
# Batch mode: PULL client and dispatcher task
|
||||
self.pull_client = None
|
||||
self.dispatcher_task = None
|
||||
else:
|
||||
# Per-query mode: DEALER connections
|
||||
self.max_connections = max(max_connections, 10)
|
||||
self.connections = []
|
||||
self.connection_load = []
|
||||
self.connection_heap = []
|
||||
self.request_num = {} # request_id -> num_choices
|
||||
self.connection_tasks = []
|
||||
|
||||
async def initialize(self):
|
||||
"""initialize all connections"""
|
||||
self.running = True
|
||||
for index in range(self.max_connections):
|
||||
await self._add_connection(index)
|
||||
api_server_logger.info(f"Started {self.max_connections} connections, pid {self.pid}")
|
||||
|
||||
if envs.ZMQ_SEND_BATCH_DATA:
|
||||
# Create PULL client for batch response reception
|
||||
# Each worker binds on a unique address to avoid PUSH round-robin
|
||||
try:
|
||||
self.worker_pid = os.getpid()
|
||||
self.pull_ipc_path = f"/dev/shm/response_{self.pid}_w{self.worker_pid}.pull"
|
||||
self.pull_client = await aiozmq.create_zmq_stream(zmq.PULL, bind=f"ipc://{self.pull_ipc_path}")
|
||||
self.pull_metrics_address = self.pull_client.transport.getsockopt(zmq.LAST_ENDPOINT)
|
||||
# Start dispatcher task
|
||||
self.dispatcher_task = asyncio.create_task(self._dispatch_batch_responses())
|
||||
api_server_logger.info(
|
||||
f"Started PULL client (bind) for batch response, pid {self.pid}, worker {self.worker_pid}"
|
||||
)
|
||||
except Exception as e:
|
||||
api_server_logger.error(f"Failed to create PULL client: {str(e)}")
|
||||
# Reset running flag and propagate error to avoid hanging requests in batch mode
|
||||
self.running = False
|
||||
raise RuntimeError(
|
||||
f"Failed to initialize PULL client for batch response "
|
||||
f"(pid={self.pid}, worker={getattr(self, 'worker_pid', 'unknown')})"
|
||||
) from e
|
||||
else:
|
||||
for index in range(self.max_connections):
|
||||
await self._add_connection(index)
|
||||
api_server_logger.info(f"Started {self.max_connections} connections, pid {self.pid}")
|
||||
|
||||
async def _add_connection(self, index):
|
||||
"""create a new connection and start listening task"""
|
||||
@@ -175,28 +211,84 @@ class DealerConnectionManager:
|
||||
|
||||
return self.connections[conn_index]
|
||||
|
||||
async def _dispatch_batch_responses(self):
|
||||
"""
|
||||
Receive batch responses and dispatch to corresponding request queues.
|
||||
batch_data format: [[output, ...], [output, ...], ...]
|
||||
"""
|
||||
consecutive_errors = 0
|
||||
max_consecutive_errors = 5
|
||||
while self.running:
|
||||
try:
|
||||
raw_data = await self.pull_client.read()
|
||||
# Deserialize batch response
|
||||
batch_data = msgpack.unpackb(raw_data[-1])
|
||||
|
||||
# Record metrics
|
||||
_zmq_metrics_stats = ZMQMetricsStats()
|
||||
_zmq_metrics_stats.msg_recv_total += 1
|
||||
main_process_metrics.record_zmq_stats(_zmq_metrics_stats, self.pull_metrics_address)
|
||||
|
||||
# Parse req_id from outputs and dispatch in a single pass
|
||||
for outputs in batch_data:
|
||||
last_output = outputs[-1]
|
||||
req_id = last_output["request_id"] if isinstance(last_output, dict) else last_output.request_id
|
||||
if req_id.startswith(("cmpl", "embd", "reward", "chatcmpl")):
|
||||
req_id = req_id.rsplit("_", 1)[0]
|
||||
queue = self.request_map.get(req_id)
|
||||
if queue is not None:
|
||||
queue.put_nowait(outputs)
|
||||
|
||||
consecutive_errors = 0
|
||||
|
||||
except (ConnectionError, OSError) as e:
|
||||
if self.running:
|
||||
api_server_logger.error(f"Dispatcher: connection lost, exiting: {e}")
|
||||
break
|
||||
except Exception as e:
|
||||
consecutive_errors += 1
|
||||
if self.running:
|
||||
api_server_logger.error(
|
||||
f"Dispatcher error ({consecutive_errors}/{max_consecutive_errors}): "
|
||||
f"{e}\n{traceback.format_exc()}"
|
||||
)
|
||||
if consecutive_errors >= max_consecutive_errors:
|
||||
api_server_logger.error(f"Dispatcher: {max_consecutive_errors} consecutive errors, exiting")
|
||||
break
|
||||
|
||||
async def get_connection(self, request_id, num_choices=1):
|
||||
"""get a connection for the request"""
|
||||
|
||||
response_queue = asyncio.Queue()
|
||||
|
||||
async with self.lock:
|
||||
self.request_map[request_id] = response_queue
|
||||
self.request_num[request_id] = num_choices
|
||||
dealer = self._get_least_loaded_connection()
|
||||
if not dealer:
|
||||
raise RuntimeError("No available connections")
|
||||
|
||||
return dealer, response_queue
|
||||
if envs.ZMQ_SEND_BATCH_DATA:
|
||||
async with self.lock:
|
||||
self.request_map[request_id] = response_queue
|
||||
return None, response_queue
|
||||
else:
|
||||
async with self.lock:
|
||||
self.request_map[request_id] = response_queue
|
||||
self.request_num[request_id] = num_choices
|
||||
dealer = self._get_least_loaded_connection()
|
||||
if not dealer:
|
||||
raise RuntimeError("No available connections")
|
||||
return dealer, response_queue
|
||||
|
||||
async def cleanup_request(self, request_id):
|
||||
"""
|
||||
clean up the request after it is finished
|
||||
"""
|
||||
async with self.lock:
|
||||
if request_id in self.request_map:
|
||||
del self.request_map[request_id]
|
||||
del self.request_num[request_id]
|
||||
try:
|
||||
async with self.lock:
|
||||
self.request_map.pop(request_id, None)
|
||||
if not envs.ZMQ_SEND_BATCH_DATA:
|
||||
self.request_num.pop(request_id, None)
|
||||
except asyncio.CancelledError:
|
||||
# If cancelled during lock acquisition, try cleanup without lock
|
||||
self.request_map.pop(request_id, None)
|
||||
if not envs.ZMQ_SEND_BATCH_DATA:
|
||||
self.request_num.pop(request_id, None)
|
||||
raise
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
@@ -204,18 +296,41 @@ class DealerConnectionManager:
|
||||
"""
|
||||
self.running = False
|
||||
|
||||
for task in self.connection_tasks:
|
||||
task.cancel()
|
||||
if envs.ZMQ_SEND_BATCH_DATA:
|
||||
# Cancel dispatcher task
|
||||
if self.dispatcher_task:
|
||||
self.dispatcher_task.cancel()
|
||||
|
||||
async with self.lock:
|
||||
for dealer in self.connections:
|
||||
# Close PULL client
|
||||
if self.pull_client:
|
||||
try:
|
||||
dealer.close()
|
||||
self.pull_client.close()
|
||||
except:
|
||||
pass
|
||||
self.connections.clear()
|
||||
self.connection_load.clear()
|
||||
self.request_map.clear()
|
||||
|
||||
# Clean up IPC file created by bind
|
||||
pull_ipc_path = getattr(self, "pull_ipc_path", None)
|
||||
if pull_ipc_path:
|
||||
try:
|
||||
if os.path.exists(pull_ipc_path):
|
||||
os.remove(pull_ipc_path)
|
||||
except OSError:
|
||||
pass
|
||||
else:
|
||||
for task in self.connection_tasks:
|
||||
task.cancel()
|
||||
|
||||
async with self.lock:
|
||||
for dealer in self.connections:
|
||||
try:
|
||||
dealer.close()
|
||||
except:
|
||||
pass
|
||||
self.connections.clear()
|
||||
self.connection_load.clear()
|
||||
|
||||
# Clear request map
|
||||
self.request_map.clear()
|
||||
|
||||
api_server_logger.info("All connections and tasks closed")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user