""" # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License" # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ import asyncio import functools import heapq import os import random import time import traceback from multiprocessing.reduction import ForkingPickler import aiozmq import msgpack import zmq from fastapi import Request import fastdeploy.envs as envs from fastdeploy.engine.args_utils import EngineArgs from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.metrics.stats import ZMQMetricsStats from fastdeploy.utils import FlexibleArgumentParser, api_server_logger UVICORN_CONFIG = { "version": 1, "disable_existing_loggers": False, "formatters": { "custom": { "()": "colorlog.ColoredFormatter", "format": "%(log_color)s%(levelname)-8s %(asctime)s %(process)-5s %(filename)s[line:%(lineno)d] %(message)s%(reset)s", "datefmt": "%Y-%m-%d %H:%M:%S", # Timestamp format "log_colors": { "DEBUG": "cyan", "INFO": "green", "WARNING": "yellow", "ERROR": "red", "CRITICAL": "red,bg_white", }, } }, "handlers": { "default": { "class": "colorlog.StreamHandler", "stream": "ext://sys.stderr", "formatter": "custom", }, }, "loggers": { "uvicorn": { "level": "INFO", "handlers": ["default"], "propagate": False, }, "uvicorn.error": { "level": "INFO", "handlers": ["default"], "propagate": False, }, "uvicorn.access": { "level": "INFO", "handlers": ["default"], "propagate": False, "formatter": "custom", }, }, } class DealerConnectionManager: """ Manager for dealer connections, supporting multiplexing and connection reuse. When ZMQ_SEND_BATCH_DATA=1: uses PULL client for batch response reception. When ZMQ_SEND_BATCH_DATA=0: uses DEALER connections with ROUTER for per-query response. """ def __init__(self, pid, max_connections=10): self.pid = pid self.request_map = {} # request_id -> response_queue self.lock = asyncio.Lock() self.running = False if envs.ZMQ_SEND_BATCH_DATA: # Batch mode: PULL client and dispatcher task self.pull_client = None self.dispatcher_task = None else: # Per-query mode: DEALER connections self.max_connections = max(max_connections, 10) self.connections = [] self.connection_load = [] self.connection_heap = [] self.request_num = {} # request_id -> num_choices self.connection_tasks = [] async def initialize(self): """initialize all connections""" self.running = True if envs.ZMQ_SEND_BATCH_DATA: # Create PULL client for batch response reception # Each worker binds on a unique address to avoid PUSH round-robin try: self.worker_pid = os.getpid() self.pull_ipc_path = f"/dev/shm/response_{self.pid}_w{self.worker_pid}.pull" self.pull_client = await aiozmq.create_zmq_stream(zmq.PULL, bind=f"ipc://{self.pull_ipc_path}") self.pull_metrics_address = self.pull_client.transport.getsockopt(zmq.LAST_ENDPOINT) # Start dispatcher task self.dispatcher_task = asyncio.create_task(self._dispatch_batch_responses()) api_server_logger.info( f"Started PULL client (bind) for batch response, pid {self.pid}, worker {self.worker_pid}" ) except Exception as e: api_server_logger.error(f"Failed to create PULL client: {str(e)}") # Reset running flag and propagate error to avoid hanging requests in batch mode self.running = False raise RuntimeError( f"Failed to initialize PULL client for batch response " f"(pid={self.pid}, worker={getattr(self, 'worker_pid', 'unknown')})" ) from e else: for index in range(self.max_connections): await self._add_connection(index) api_server_logger.info(f"Started {self.max_connections} connections, pid {self.pid}") async def _add_connection(self, index): """create a new connection and start listening task""" try: dealer = await aiozmq.create_zmq_stream( zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc", ) async with self.lock: self.connections.append(dealer) self.connection_load.append(0) heapq.heappush(self.connection_heap, (0, index)) # start listening task = asyncio.create_task(self._listen_connection(dealer, index)) self.connection_tasks.append(task) return True except Exception as e: api_server_logger.error(f"Failed to create dealer: {str(e)}") return False async def _listen_connection(self, dealer, conn_index): """ listen for messages from the dealer connection """ while self.running: try: raw_data = await dealer.read() response = ForkingPickler.loads(raw_data[-1]) _zmq_metrics_stats = ZMQMetricsStats() _zmq_metrics_stats.msg_recv_total += 1 if "zmq_send_time" in response: _zmq_metrics_stats.zmq_latency = time.perf_counter() - response["zmq_send_time"] address = dealer.transport.getsockopt(zmq.LAST_ENDPOINT) main_process_metrics.record_zmq_stats(_zmq_metrics_stats, address) request_id = response[-1]["request_id"] if request_id[:4] in ["cmpl", "embd"]: request_id = request_id.rsplit("_", 1)[0] elif "reward" == request_id[:6]: request_id = request_id.rsplit("_", 1)[0] elif "chatcmpl" == request_id[:8]: request_id = request_id.rsplit("_", 1)[0] async with self.lock: if request_id in self.request_map: await self.request_map[request_id].put(response) if response[-1]["finished"]: self.request_num[request_id] -= 1 if self.request_num[request_id] == 0: self._update_load(conn_index, -1) else: api_server_logger.warning( f"request_id {request_id} not in request_map, available keys: {list(self.request_map.keys())}" ) except Exception as e: api_server_logger.error(f"Listener error: {str(e)}\n{traceback.format_exc()}") break api_server_logger.info(f"Listener loop ended for conn_index {conn_index}") def _update_load(self, conn_index, delta): """Update connection load and maintain the heap""" self.connection_load[conn_index] += delta heapq.heapify(self.connection_heap) # For Debugging purposes if random.random() < 0.01: min_load = self.connection_heap[0][0] if self.connection_heap else 0 max_load = max(self.connection_load) if self.connection_load else 0 api_server_logger.debug(f"Connection load update: min={min_load}, max={max_load}") def _get_least_loaded_connection(self): """ Get the least loaded connection """ if not self.connection_heap: return None load, conn_index = self.connection_heap[0] self._update_load(conn_index, 1) return self.connections[conn_index] async def _dispatch_batch_responses(self): """ Receive batch responses and dispatch to corresponding request queues. batch_data format: [[output, ...], [output, ...], ...] """ consecutive_errors = 0 max_consecutive_errors = 5 while self.running: try: raw_data = await self.pull_client.read() # Deserialize batch response batch_data = msgpack.unpackb(raw_data[-1]) # Record metrics _zmq_metrics_stats = ZMQMetricsStats() _zmq_metrics_stats.msg_recv_total += 1 main_process_metrics.record_zmq_stats(_zmq_metrics_stats, self.pull_metrics_address) # Parse req_id from outputs and dispatch in a single pass for outputs in batch_data: last_output = outputs[-1] req_id = last_output["request_id"] if isinstance(last_output, dict) else last_output.request_id if req_id.startswith(("cmpl", "embd", "reward", "chatcmpl")): req_id = req_id.rsplit("_", 1)[0] queue = self.request_map.get(req_id) if queue is not None: queue.put_nowait(outputs) consecutive_errors = 0 except (ConnectionError, OSError) as e: if self.running: api_server_logger.error(f"Dispatcher: connection lost, exiting: {e}") break except Exception as e: consecutive_errors += 1 if self.running: api_server_logger.error( f"Dispatcher error ({consecutive_errors}/{max_consecutive_errors}): " f"{e}\n{traceback.format_exc()}" ) if consecutive_errors >= max_consecutive_errors: api_server_logger.error(f"Dispatcher: {max_consecutive_errors} consecutive errors, exiting") break async def get_connection(self, request_id, num_choices=1): """get a connection for the request""" response_queue = asyncio.Queue() if envs.ZMQ_SEND_BATCH_DATA: async with self.lock: self.request_map[request_id] = response_queue return None, response_queue else: async with self.lock: self.request_map[request_id] = response_queue self.request_num[request_id] = num_choices dealer = self._get_least_loaded_connection() if not dealer: raise RuntimeError("No available connections") return dealer, response_queue async def cleanup_request(self, request_id): """ clean up the request after it is finished """ try: async with self.lock: self.request_map.pop(request_id, None) if not envs.ZMQ_SEND_BATCH_DATA: self.request_num.pop(request_id, None) except asyncio.CancelledError: # If cancelled during lock acquisition, try cleanup without lock self.request_map.pop(request_id, None) if not envs.ZMQ_SEND_BATCH_DATA: self.request_num.pop(request_id, None) raise async def close(self): """ close all connections and tasks """ self.running = False if envs.ZMQ_SEND_BATCH_DATA: # Cancel dispatcher task if self.dispatcher_task: self.dispatcher_task.cancel() # Close PULL client if self.pull_client: try: self.pull_client.close() except: pass # Clean up IPC file created by bind pull_ipc_path = getattr(self, "pull_ipc_path", None) if pull_ipc_path: try: if os.path.exists(pull_ipc_path): os.remove(pull_ipc_path) except OSError: pass else: for task in self.connection_tasks: task.cancel() async with self.lock: for dealer in self.connections: try: dealer.close() except: pass self.connections.clear() self.connection_load.clear() # Clear request map self.request_map.clear() api_server_logger.info("All connections and tasks closed") def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: _is_multi_server = os.environ.get("FD_ENABLE_MULTI_API_SERVER") == "1" parser.add_argument("--port", default=8000, type=int, help="port to the http server") parser.add_argument("--host", default="0.0.0.0", type=str, help="host to the http server") parser.add_argument("--workers", default=1 if _is_multi_server else 4, type=int, help="number of workers") parser.add_argument("--metrics-port", default=None, type=int, help="port for metrics server") parser.add_argument("--controller-port", default=-1, type=int, help="port for controller server") parser.add_argument( "--max-waiting-time", default=-1, type=int, help="max waiting time for connection, if set value -1 means no waiting time limit", ) parser.add_argument( "--max-concurrency", default=512 if _is_multi_server else 2048, type=int, help="max concurrency" ) parser.add_argument( "--enable-mm-output", action="store_true", help="Enable 'multimodal_content' field in response output. " ) parser.add_argument( "--timeout-graceful-shutdown", default=0, type=int, help="timeout for graceful shutdown in seconds (used by gunicorn).Setting it to 0 has the effect of infinite timeouts by disabling timeouts for all workers entirely.", ) parser.add_argument( "--timeout", default=0, type=int, help="Workers silent for more than this many seconds are killed and restarted.Value is a positive number or 0. Setting it to 0 has the effect of infinite timeouts by disabling timeouts for all workers entirely.", ) parser.add_argument("--api-key", type=str, action="append", help="API_KEY required for service authentication") parser = EngineArgs.add_cli_args(parser) return parser async def listen_for_disconnect(request: Request) -> None: """Returns if a disconnect message is received""" while True: message = await request.receive() if message["type"] == "http.disconnect": break def with_cancellation(handler_func): """Decorator that allows a route handler to be cancelled by client disconnections. This does _not_ use request.is_disconnected, which does not work with middleware. Instead this follows the pattern from starlette.StreamingResponse, which simultaneously awaits on two tasks- one to wait for an http disconnect message, and the other to do the work that we want done. When the first task finishes, the other is cancelled. A core assumption of this method is that the body of the request has already been read. This is a safe assumption to make for fastapi handlers that have already parsed the body of the request into a pydantic model for us. This decorator is unsafe to use elsewhere, as it will consume and throw away all incoming messages for the request while it looks for a disconnect message. In the case where a `StreamingResponse` is returned by the handler, this wrapper will stop listening for disconnects and instead the response object will start listening for disconnects.The response object will only correctly listen when the ASGI protocol version used by Uvicorn is less than 2.4(Excluding 2.4). """ # Functools.wraps is required for this wrapper to appear to fastapi as a # normal route handler, with the correct request type hinting. @functools.wraps(handler_func) async def wrapper(*args, **kwargs): # The request is either the second positional arg or `raw_request` request = args[1] if len(args) > 1 else kwargs["req"] handler_task = asyncio.create_task(handler_func(*args, **kwargs)) cancellation_task = asyncio.create_task(listen_for_disconnect(request)) done, pending = await asyncio.wait([handler_task, cancellation_task], return_when=asyncio.FIRST_COMPLETED) for task in pending: task.cancel() if handler_task in done: return handler_task.result() return None return wrapper