Files
FastDeploy/fastdeploy/entrypoints/openai/utils.py
T
Yonghua Li a7f52c300d [Feature] support v1 update/clear api for RL (#6761)
* [Feature] support v1 update/clear api for RL

* [fix] fix execute_model and add sleep/wakeup api

* [fix] fix mtp and key_prefix

* [chore] move _update_key_prefix to resume method

* [fix] make the interface safe to call multiple times

* [fix] fix some tiny bugs

* [chore] make small changes against pr review

* [docs] add docs for weight update

* [test] add some tests and update docs

* [style] fix code style check

* [test] fix ci

* [fix] fix stale control responses when control method timed out

* [chore] remove unused code

* [chore] fix code style

* [chore] optimize tags and key_prefix

* [test] fix ci

* [chore] fix code style

* [test] fix ci

* [fix] fix ep control

* [fix] fix ep control for engine cache queue
2026-03-25 19:18:46 +08:00

430 lines
17 KiB
Python

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import asyncio
import functools
import heapq
import os
import random
import time
import traceback
from multiprocessing.reduction import ForkingPickler
import aiozmq
import msgpack
import zmq
from fastapi import Request
import fastdeploy.envs as envs
from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.metrics.stats import ZMQMetricsStats
from fastdeploy.utils import FlexibleArgumentParser, api_server_logger
# Logging configuration handed to uvicorn (logging.config.dictConfig schema, version 1).
# Routes uvicorn's own loggers through a colorlog formatter writing to stderr.
UVICORN_CONFIG = {
    "version": 1,
    "disable_existing_loggers": False,  # keep loggers configured elsewhere alive
    "formatters": {
        "custom": {
            # "()" is the dictConfig convention for naming the formatter factory class.
            "()": "colorlog.ColoredFormatter",
            "format": "%(log_color)s%(levelname)-8s %(asctime)s %(process)-5s %(filename)s[line:%(lineno)d] %(message)s%(reset)s",
            "datefmt": "%Y-%m-%d %H:%M:%S",  # Timestamp format
            "log_colors": {
                "DEBUG": "cyan",
                "INFO": "green",
                "WARNING": "yellow",
                "ERROR": "red",
                "CRITICAL": "red,bg_white",
            },
        }
    },
    "handlers": {
        "default": {
            "class": "colorlog.StreamHandler",
            "stream": "ext://sys.stderr",
            "formatter": "custom",
        },
    },
    "loggers": {
        "uvicorn": {
            "level": "INFO",
            "handlers": ["default"],
            "propagate": False,
        },
        "uvicorn.error": {
            "level": "INFO",
            "handlers": ["default"],
            "propagate": False,
        },
        "uvicorn.access": {
            "level": "INFO",
            "handlers": ["default"],
            "propagate": False,
            # NOTE(review): "formatter" is not a recognized key on a logger in
            # dictConfig (formatters attach to handlers) — presumably ignored;
            # confirm before relying on it.
            "formatter": "custom",
        },
    },
}
class DealerConnectionManager:
    """
    Manager for dealer connections, supporting multiplexing and connection reuse.

    When ZMQ_SEND_BATCH_DATA=1: uses PULL client for batch response reception.
    When ZMQ_SEND_BATCH_DATA=0: uses DEALER connections with ROUTER for per-query response.
    """

    def __init__(self, pid, max_connections=10):
        """
        Args:
            pid: id of the owning server process; used to build the IPC socket paths.
            max_connections: requested DEALER pool size (per-query mode only);
                a minimum of 10 is enforced.
        """
        self.pid = pid
        self.request_map = {}  # request_id -> response_queue (asyncio.Queue)
        self.lock = asyncio.Lock()  # guards request_map / request_num / connection state
        self.running = False
        if envs.ZMQ_SEND_BATCH_DATA:
            # Batch mode: PULL client and dispatcher task
            self.pull_client = None
            self.dispatcher_task = None
        else:
            # Per-query mode: DEALER connections with least-loaded selection
            self.max_connections = max(max_connections, 10)
            self.connections = []  # index -> dealer stream
            self.connection_load = []  # index -> in-flight request count
            self.connection_heap = []  # (load, index) entries for least-loaded lookup
            self.request_num = {}  # request_id -> num_choices
            self.connection_tasks = []  # one listener task per connection

    async def initialize(self):
        """initialize all connections

        Raises:
            RuntimeError: batch mode only, when the PULL socket cannot be created.
        """
        self.running = True
        if envs.ZMQ_SEND_BATCH_DATA:
            # Create PULL client for batch response reception
            # Each worker binds on a unique address to avoid PUSH round-robin
            try:
                self.worker_pid = os.getpid()
                self.pull_ipc_path = f"/dev/shm/response_{self.pid}_w{self.worker_pid}.pull"
                self.pull_client = await aiozmq.create_zmq_stream(zmq.PULL, bind=f"ipc://{self.pull_ipc_path}")
                # Endpoint string doubles as the label used when recording ZMQ metrics.
                self.pull_metrics_address = self.pull_client.transport.getsockopt(zmq.LAST_ENDPOINT)
                # Start dispatcher task
                self.dispatcher_task = asyncio.create_task(self._dispatch_batch_responses())
                api_server_logger.info(
                    f"Started PULL client (bind) for batch response, pid {self.pid}, worker {self.worker_pid}"
                )
            except Exception as e:
                api_server_logger.error(f"Failed to create PULL client: {str(e)}")
                # Reset running flag and propagate error to avoid hanging requests in batch mode
                self.running = False
                raise RuntimeError(
                    f"Failed to initialize PULL client for batch response "
                    f"(pid={self.pid}, worker={getattr(self, 'worker_pid', 'unknown')})"
                ) from e
        else:
            for index in range(self.max_connections):
                await self._add_connection(index)
            api_server_logger.info(f"Started {self.max_connections} connections, pid {self.pid}")

    async def _add_connection(self, index):
        """create a new connection and start listening task

        Returns:
            bool: True on success, False when the dealer stream could not be created.
        """
        try:
            dealer = await aiozmq.create_zmq_stream(
                zmq.DEALER,
                connect=f"ipc:///dev/shm/router_{self.pid}.ipc",
            )
            async with self.lock:
                self.connections.append(dealer)
                self.connection_load.append(0)
                heapq.heappush(self.connection_heap, (0, index))
            # start listening
            task = asyncio.create_task(self._listen_connection(dealer, index))
            self.connection_tasks.append(task)
            return True
        except Exception as e:
            api_server_logger.error(f"Failed to create dealer: {str(e)}")
            return False

    async def _listen_connection(self, dealer, conn_index):
        """
        listen for messages from the dealer connection

        Loops until ``self.running`` is cleared; any exception while reading or
        dispatching terminates this listener (the loop breaks rather than retries).
        """
        while self.running:
            try:
                raw_data = await dealer.read()
                # NOTE(review): plain unpickling of the wire payload — acceptable only
                # because the peer is a trusted local process over IPC; never expose this.
                response = ForkingPickler.loads(raw_data[-1])
                _zmq_metrics_stats = ZMQMetricsStats()
                _zmq_metrics_stats.msg_recv_total += 1
                if "zmq_send_time" in response:
                    _zmq_metrics_stats.zmq_latency = time.perf_counter() - response["zmq_send_time"]
                address = dealer.transport.getsockopt(zmq.LAST_ENDPOINT)
                main_process_metrics.record_zmq_stats(_zmq_metrics_stats, address)
                request_id = response[-1]["request_id"]
                # Per-choice ids carry a trailing "_<index>" suffix on the base request id;
                # strip it so all choices of one request share a single response queue.
                if request_id[:4] in ["cmpl", "embd"]:
                    request_id = request_id.rsplit("_", 1)[0]
                elif "reward" == request_id[:6]:
                    request_id = request_id.rsplit("_", 1)[0]
                elif "chatcmpl" == request_id[:8]:
                    request_id = request_id.rsplit("_", 1)[0]
                async with self.lock:
                    if request_id in self.request_map:
                        await self.request_map[request_id].put(response)
                        # A finished choice decrements the outstanding-choice count;
                        # once every choice is done this connection sheds one unit of load.
                        if response[-1]["finished"]:
                            self.request_num[request_id] -= 1
                            if self.request_num[request_id] == 0:
                                self._update_load(conn_index, -1)
                    else:
                        api_server_logger.warning(
                            f"request_id {request_id} not in request_map, available keys: {list(self.request_map.keys())}"
                        )
            except Exception as e:
                api_server_logger.error(f"Listener error: {str(e)}\n{traceback.format_exc()}")
                break
        api_server_logger.info(f"Listener loop ended for conn_index {conn_index}")

    def _update_load(self, conn_index, delta):
        """Update connection load and maintain the heap"""
        self.connection_load[conn_index] += delta
        # NOTE(review): heapify only reorders the existing (load, index) tuples; the
        # stored load values are never refreshed from connection_load, so the heap
        # may not reflect current loads — verify this is the intended behavior.
        heapq.heapify(self.connection_heap)
        # For Debugging purposes
        if random.random() < 0.01:
            min_load = self.connection_heap[0][0] if self.connection_heap else 0
            max_load = max(self.connection_load) if self.connection_load else 0
            api_server_logger.debug(f"Connection load update: min={min_load}, max={max_load}")

    def _get_least_loaded_connection(self):
        """
        Get the least loaded connection

        Peeks the top of the heap, bumps that connection's load by one, and
        returns the corresponding dealer stream (None when the pool is empty).
        """
        if not self.connection_heap:
            return None
        load, conn_index = self.connection_heap[0]
        self._update_load(conn_index, 1)
        return self.connections[conn_index]

    async def _dispatch_batch_responses(self):
        """
        Receive batch responses and dispatch to corresponding request queues.
        batch_data format: [[output, ...], [output, ...], ...]

        Exits on connection loss or after max_consecutive_errors failures in a row.
        """
        consecutive_errors = 0
        max_consecutive_errors = 5
        while self.running:
            try:
                raw_data = await self.pull_client.read()
                # Deserialize batch response
                batch_data = msgpack.unpackb(raw_data[-1])
                # Record metrics
                _zmq_metrics_stats = ZMQMetricsStats()
                _zmq_metrics_stats.msg_recv_total += 1
                main_process_metrics.record_zmq_stats(_zmq_metrics_stats, self.pull_metrics_address)
                # Parse req_id from outputs and dispatch in a single pass
                for outputs in batch_data:
                    last_output = outputs[-1]
                    # Outputs may arrive as dicts (msgpack) or objects — support both.
                    req_id = last_output["request_id"] if isinstance(last_output, dict) else last_output.request_id
                    # Strip the per-choice "_<index>" suffix to recover the base request id.
                    if req_id.startswith(("cmpl", "embd", "reward", "chatcmpl")):
                        req_id = req_id.rsplit("_", 1)[0]
                    # Lock-free lookup: the event loop is single-threaded, so a plain
                    # dict.get cannot race with the locked writers here.
                    queue = self.request_map.get(req_id)
                    if queue is not None:
                        queue.put_nowait(outputs)
                consecutive_errors = 0
            except (ConnectionError, OSError) as e:
                if self.running:
                    api_server_logger.error(f"Dispatcher: connection lost, exiting: {e}")
                break
            except Exception as e:
                consecutive_errors += 1
                if self.running:
                    api_server_logger.error(
                        f"Dispatcher error ({consecutive_errors}/{max_consecutive_errors}): "
                        f"{e}\n{traceback.format_exc()}"
                    )
                if consecutive_errors >= max_consecutive_errors:
                    api_server_logger.error(f"Dispatcher: {max_consecutive_errors} consecutive errors, exiting")
                    break

    async def get_connection(self, request_id, num_choices=1):
        """get a connection for the request

        Registers a fresh response queue for ``request_id``; in per-query mode
        also records the expected choice count and picks the least-loaded dealer.

        Returns:
            (dealer, response_queue) — dealer is None in batch mode.

        Raises:
            RuntimeError: per-query mode when no connection is available.
        """
        response_queue = asyncio.Queue()
        if envs.ZMQ_SEND_BATCH_DATA:
            async with self.lock:
                self.request_map[request_id] = response_queue
            return None, response_queue
        else:
            async with self.lock:
                self.request_map[request_id] = response_queue
                self.request_num[request_id] = num_choices
                dealer = self._get_least_loaded_connection()
                if not dealer:
                    raise RuntimeError("No available connections")
            return dealer, response_queue

    async def cleanup_request(self, request_id):
        """
        clean up the request after it is finished
        """
        try:
            async with self.lock:
                self.request_map.pop(request_id, None)
                if not envs.ZMQ_SEND_BATCH_DATA:
                    self.request_num.pop(request_id, None)
        except asyncio.CancelledError:
            # If cancelled during lock acquisition, try cleanup without lock
            self.request_map.pop(request_id, None)
            if not envs.ZMQ_SEND_BATCH_DATA:
                self.request_num.pop(request_id, None)
            raise

    async def close(self):
        """
        close all connections and tasks

        Safe to call once at shutdown; clears the request map in both modes.
        """
        self.running = False
        if envs.ZMQ_SEND_BATCH_DATA:
            # Cancel dispatcher task
            if self.dispatcher_task:
                self.dispatcher_task.cancel()
            # Close PULL client
            if self.pull_client:
                try:
                    self.pull_client.close()
                except:  # noqa: E722  NOTE(review): bare except — narrow to Exception if possible
                    pass
            # Clean up IPC file created by bind
            pull_ipc_path = getattr(self, "pull_ipc_path", None)
            if pull_ipc_path:
                try:
                    if os.path.exists(pull_ipc_path):
                        os.remove(pull_ipc_path)
                except OSError:
                    pass
        else:
            for task in self.connection_tasks:
                task.cancel()
            async with self.lock:
                for dealer in self.connections:
                    try:
                        dealer.close()
                    except:  # noqa: E722  NOTE(review): bare except — narrow to Exception if possible
                        pass
                self.connections.clear()
                self.connection_load.clear()
        # Clear request map
        self.request_map.clear()
        api_server_logger.info("All connections and tasks closed")
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
    """Register the HTTP-server CLI options on *parser*, then append engine args.

    Args:
        parser: the argument parser to extend in place.

    Returns:
        The same parser, after EngineArgs has added its own CLI arguments.
    """
    # Options are declared as (flag, kwargs) pairs and registered in order,
    # keeping the declaration compact and easy to scan.
    option_table = [
        ("--port", {"default": 8000, "type": int, "help": "port to the http server"}),
        ("--host", {"default": "0.0.0.0", "type": str, "help": "host to the http server"}),
        ("--workers", {"default": 1, "type": int, "help": "number of workers"}),
        ("--metrics-port", {"default": None, "type": int, "help": "port for metrics server"}),
        ("--controller-port", {"default": -1, "type": int, "help": "port for controller server"}),
        (
            "--max-waiting-time",
            {
                "default": -1,
                "type": int,
                "help": "max waiting time for connection, if set value -1 means no waiting time limit",
            },
        ),
        ("--max-concurrency", {"default": 512, "type": int, "help": "max concurrency"}),
        (
            "--enable-mm-output",
            {"action": "store_true", "help": "Enable 'multimodal_content' field in response output. "},
        ),
        (
            "--timeout-graceful-shutdown",
            {
                "default": 0,
                "type": int,
                "help": "timeout for graceful shutdown in seconds (used by gunicorn).Setting it to 0 has the effect of infinite timeouts by disabling timeouts for all workers entirely.",
            },
        ),
        (
            "--timeout",
            {
                "default": 0,
                "type": int,
                "help": "Workers silent for more than this many seconds are killed and restarted.Value is a positive number or 0. Setting it to 0 has the effect of infinite timeouts by disabling timeouts for all workers entirely.",
            },
        ),
        (
            "--api-key",
            {"type": str, "action": "append", "help": "API_KEY required for service authentication"},
        ),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    # Engine-level arguments are appended last so server flags appear first in --help.
    return EngineArgs.add_cli_args(parser)
async def listen_for_disconnect(request: Request) -> None:
    """Block until the client sends an ``http.disconnect`` ASGI message."""
    # Consume (and discard) incoming ASGI messages until the disconnect arrives.
    while (await request.receive())["type"] != "http.disconnect":
        pass
def with_cancellation(handler_func):
    """Decorator that allows a route handler to be cancelled by client
    disconnections.

    This does _not_ use request.is_disconnected, which does not work with
    middleware. Instead this follows the pattern from
    starlette.StreamingResponse: two tasks are awaited concurrently — one
    watching for an http disconnect message, the other running the handler —
    and whichever loses the race is cancelled.

    A core assumption of this method is that the body of the request has already
    been read. This is a safe assumption to make for fastapi handlers that have
    already parsed the body of the request into a pydantic model for us.
    This decorator is unsafe to use elsewhere, as it will consume and throw away
    all incoming messages for the request while it looks for a disconnect
    message.

    In the case where a `StreamingResponse` is returned by the handler, this
    wrapper will stop listening for disconnects and instead the response object
    will start listening for disconnects. The response object will only correctly
    listen when the ASGI protocol version used by Uvicorn is less than 2.4
    (excluding 2.4).
    """

    # functools.wraps makes the wrapper look like a normal route handler to
    # fastapi, preserving the original signature and type hints.
    @functools.wraps(handler_func)
    async def wrapper(*args, **kwargs):
        # The Request object is either the second positional argument or the
        # keyword argument named "req".
        request = kwargs["req"] if len(args) <= 1 else args[1]
        work = asyncio.ensure_future(handler_func(*args, **kwargs))
        watchdog = asyncio.ensure_future(listen_for_disconnect(request))
        finished, unfinished = await asyncio.wait({work, watchdog}, return_when=asyncio.FIRST_COMPLETED)
        for leftover in unfinished:
            leftover.cancel()
        # If the handler won the race, surface its result (or exception);
        # otherwise the client disconnected and there is nothing to return.
        return work.result() if work in finished else None

    return wrapper