mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
a7f52c300d
* [Feature] support v1 update/clear api for RL * [fix] fix execute_model and add sleep/wakeup api * [fix] fix mtp and key_prefix * [chore] move _update_key_prefix to resume method * [fix] make the interface safe to call multiple times * [fix] fix some tiny bugs * [chore] make small changes against pr review * [docs] add docs for weight update * [test] add some tests and update docs * [style] fix code style check * [test] fix ci * [fix] fix stale control responses when control method timed out * [chore] remove unused code * [chore] fix code style * [chore] optimize tags and key_prefix * [test] fix ci * [chore] fix code style * [test] fix ci * [fix] fix ep control * [fix] fix ep control for engine cache queue
430 lines
17 KiB
Python
430 lines
17 KiB
Python
"""
|
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|
import asyncio
|
|
import functools
|
|
import heapq
|
|
import os
|
|
import random
|
|
import time
|
|
import traceback
|
|
from multiprocessing.reduction import ForkingPickler
|
|
|
|
import aiozmq
|
|
import msgpack
|
|
import zmq
|
|
from fastapi import Request
|
|
|
|
import fastdeploy.envs as envs
|
|
from fastdeploy.engine.args_utils import EngineArgs
|
|
from fastdeploy.metrics.metrics import main_process_metrics
|
|
from fastdeploy.metrics.stats import ZMQMetricsStats
|
|
from fastdeploy.utils import FlexibleArgumentParser, api_server_logger
|
|
|
|
# Logging configuration handed to uvicorn (logging.config.dictConfig schema, version 1).
# Routes uvicorn's own loggers through a colorized stderr handler so the server
# logs share one format with the rest of the process.
UVICORN_CONFIG = {
    "version": 1,
    "disable_existing_loggers": False,  # keep loggers configured elsewhere alive
    "formatters": {
        "custom": {
            # Factory key understood by dictConfig: instantiate colorlog's formatter.
            "()": "colorlog.ColoredFormatter",
            "format": "%(log_color)s%(levelname)-8s %(asctime)s %(process)-5s %(filename)s[line:%(lineno)d] %(message)s%(reset)s",
            "datefmt": "%Y-%m-%d %H:%M:%S",  # Timestamp format
            # Per-level colors applied via %(log_color)s above.
            "log_colors": {
                "DEBUG": "cyan",
                "INFO": "green",
                "WARNING": "yellow",
                "ERROR": "red",
                "CRITICAL": "red,bg_white",
            },
        }
    },
    "handlers": {
        "default": {
            "class": "colorlog.StreamHandler",
            "stream": "ext://sys.stderr",
            "formatter": "custom",
        },
    },
    # All three uvicorn loggers are pinned to the colored handler and do not
    # propagate to the root logger (avoids duplicate lines).
    "loggers": {
        "uvicorn": {
            "level": "INFO",
            "handlers": ["default"],
            "propagate": False,
        },
        "uvicorn.error": {
            "level": "INFO",
            "handlers": ["default"],
            "propagate": False,
        },
        "uvicorn.access": {
            "level": "INFO",
            "handlers": ["default"],
            "propagate": False,
            # NOTE(review): dictConfig does not define "formatter" on a logger
            # entry — presumably intended for the handler; verify and clean up.
            "formatter": "custom",
        },
    },
}
|
|
|
|
|
|
class DealerConnectionManager:
    """
    Manager for dealer connections, supporting multiplexing and connection reuse.

    When ZMQ_SEND_BATCH_DATA=1: uses a single PULL client for batch response
    reception, plus a dispatcher task that routes each batch entry to the
    queue registered for its request id.
    When ZMQ_SEND_BATCH_DATA=0: uses a pool of DEALER connections talking to a
    ROUTER socket, picking the least-loaded connection per request.
    """

    # request_id prefixes that carry a trailing "_<choice>" suffix which must
    # be stripped before looking the id up in request_map.  Shared by the
    # per-query and batch receive paths so both strip identically.
    _SUFFIXED_PREFIXES = ("cmpl", "embd", "reward", "chatcmpl")

    def __init__(self, pid, max_connections=10):
        """
        Args:
            pid: api-server process id, used to derive the IPC socket paths.
            max_connections: dealer pool size in per-query mode (minimum 10).
        """
        self.pid = pid
        self.request_map = {}  # request_id -> asyncio.Queue of responses
        self.lock = asyncio.Lock()
        self.running = False

        if envs.ZMQ_SEND_BATCH_DATA:
            # Batch mode: PULL client and dispatcher task
            self.pull_client = None
            self.dispatcher_task = None
        else:
            # Per-query mode: DEALER connections
            self.max_connections = max(max_connections, 10)
            self.connections = []  # dealer streams, indexed by connection id
            self.connection_load = []  # outstanding requests per connection
            self.connection_heap = []  # (load, index) min-heap for selection
            self.request_num = {}  # request_id -> remaining unfinished choices
            self.connection_tasks = []  # listener tasks, one per dealer

    @classmethod
    def _normalize_request_id(cls, request_id):
        """Strip the trailing "_<choice>" suffix from known request-id kinds.

        Responses come back tagged "<request_id>_<choice_index>" for
        completion/embedding/reward/chat ids, while request_map is keyed by
        the bare request id.  Other ids are returned unchanged.
        """
        if request_id.startswith(cls._SUFFIXED_PREFIXES):
            return request_id.rsplit("_", 1)[0]
        return request_id

    async def initialize(self):
        """Initialize all connections (or the PULL client in batch mode)."""
        self.running = True

        if envs.ZMQ_SEND_BATCH_DATA:
            # Create PULL client for batch response reception.
            # Each worker binds on a unique address to avoid PUSH round-robin.
            try:
                self.worker_pid = os.getpid()
                self.pull_ipc_path = f"/dev/shm/response_{self.pid}_w{self.worker_pid}.pull"
                self.pull_client = await aiozmq.create_zmq_stream(zmq.PULL, bind=f"ipc://{self.pull_ipc_path}")
                self.pull_metrics_address = self.pull_client.transport.getsockopt(zmq.LAST_ENDPOINT)
                # Start dispatcher task
                self.dispatcher_task = asyncio.create_task(self._dispatch_batch_responses())
                api_server_logger.info(
                    f"Started PULL client (bind) for batch response, pid {self.pid}, worker {self.worker_pid}"
                )
            except Exception as e:
                api_server_logger.error(f"Failed to create PULL client: {str(e)}")
                # Reset running flag and propagate error to avoid hanging requests in batch mode
                self.running = False
                raise RuntimeError(
                    f"Failed to initialize PULL client for batch response "
                    f"(pid={self.pid}, worker={getattr(self, 'worker_pid', 'unknown')})"
                ) from e
        else:
            for index in range(self.max_connections):
                await self._add_connection(index)
            api_server_logger.info(f"Started {self.max_connections} connections, pid {self.pid}")

    async def _add_connection(self, index):
        """Create a new DEALER connection and start its listening task.

        Returns True on success, False if the stream could not be created.
        """
        try:
            dealer = await aiozmq.create_zmq_stream(
                zmq.DEALER,
                connect=f"ipc:///dev/shm/router_{self.pid}.ipc",
            )
            async with self.lock:
                self.connections.append(dealer)
                self.connection_load.append(0)
                heapq.heappush(self.connection_heap, (0, index))

            # start listening
            task = asyncio.create_task(self._listen_connection(dealer, index))
            self.connection_tasks.append(task)
            return True
        except Exception as e:
            api_server_logger.error(f"Failed to create dealer: {str(e)}")
            return False

    async def _listen_connection(self, dealer, conn_index):
        """
        Listen for messages from one dealer connection and route them to the
        per-request queue; decrement the connection load once all choices of
        a request have finished.
        """
        while self.running:
            try:
                raw_data = await dealer.read()
                response = ForkingPickler.loads(raw_data[-1])
                _zmq_metrics_stats = ZMQMetricsStats()
                _zmq_metrics_stats.msg_recv_total += 1
                if "zmq_send_time" in response:
                    _zmq_metrics_stats.zmq_latency = time.perf_counter() - response["zmq_send_time"]
                address = dealer.transport.getsockopt(zmq.LAST_ENDPOINT)
                main_process_metrics.record_zmq_stats(_zmq_metrics_stats, address)

                # Strip the per-choice suffix the same way the batch path does.
                request_id = self._normalize_request_id(response[-1]["request_id"])
                async with self.lock:
                    if request_id in self.request_map:
                        await self.request_map[request_id].put(response)
                        if response[-1]["finished"]:
                            self.request_num[request_id] -= 1
                            if self.request_num[request_id] == 0:
                                self._update_load(conn_index, -1)
                    else:
                        api_server_logger.warning(
                            f"request_id {request_id} not in request_map, available keys: {list(self.request_map.keys())}"
                        )
            except Exception as e:
                api_server_logger.error(f"Listener error: {str(e)}\n{traceback.format_exc()}")
                break
        api_server_logger.info(f"Listener loop ended for conn_index {conn_index}")

    def _update_load(self, conn_index, delta):
        """Update a connection's load and rebuild the selection heap.

        The heap stores immutable (load, index) tuples, so mutating
        connection_load and merely re-heapifying the old entries would leave
        the heap permanently stale (always selecting the same connection).
        Rebuild the entries from the authoritative load list instead.
        """
        self.connection_load[conn_index] += delta
        self.connection_heap = [(load, index) for index, load in enumerate(self.connection_load)]
        heapq.heapify(self.connection_heap)

        # For Debugging purposes
        if random.random() < 0.01:
            min_load = self.connection_heap[0][0] if self.connection_heap else 0
            max_load = max(self.connection_load) if self.connection_load else 0
            api_server_logger.debug(f"Connection load update: min={min_load}, max={max_load}")

    def _get_least_loaded_connection(self):
        """
        Get the least loaded connection (or None if the pool is empty) and
        bump its load by one.
        """
        if not self.connection_heap:
            return None

        load, conn_index = self.connection_heap[0]
        self._update_load(conn_index, 1)

        return self.connections[conn_index]

    async def _dispatch_batch_responses(self):
        """
        Receive batch responses and dispatch to corresponding request queues.
        batch_data format: [[output, ...], [output, ...], ...]

        Exits on connection loss or after max_consecutive_errors failures.
        """
        consecutive_errors = 0
        max_consecutive_errors = 5
        while self.running:
            try:
                raw_data = await self.pull_client.read()
                # Deserialize batch response
                batch_data = msgpack.unpackb(raw_data[-1])

                # Record metrics
                _zmq_metrics_stats = ZMQMetricsStats()
                _zmq_metrics_stats.msg_recv_total += 1
                main_process_metrics.record_zmq_stats(_zmq_metrics_stats, self.pull_metrics_address)

                # Parse req_id from outputs and dispatch in a single pass
                for outputs in batch_data:
                    last_output = outputs[-1]
                    req_id = last_output["request_id"] if isinstance(last_output, dict) else last_output.request_id
                    req_id = self._normalize_request_id(req_id)
                    queue = self.request_map.get(req_id)
                    if queue is not None:
                        queue.put_nowait(outputs)

                consecutive_errors = 0

            except (ConnectionError, OSError) as e:
                if self.running:
                    api_server_logger.error(f"Dispatcher: connection lost, exiting: {e}")
                break
            except Exception as e:
                consecutive_errors += 1
                if self.running:
                    api_server_logger.error(
                        f"Dispatcher error ({consecutive_errors}/{max_consecutive_errors}): "
                        f"{e}\n{traceback.format_exc()}"
                    )
                if consecutive_errors >= max_consecutive_errors:
                    api_server_logger.error(f"Dispatcher: {max_consecutive_errors} consecutive errors, exiting")
                    break

    async def get_connection(self, request_id, num_choices=1):
        """Register a response queue for request_id and return
        (dealer_or_None, queue).  The dealer is None in batch mode.

        Raises:
            RuntimeError: per-query mode with no available connections.
        """

        response_queue = asyncio.Queue()

        if envs.ZMQ_SEND_BATCH_DATA:
            async with self.lock:
                self.request_map[request_id] = response_queue
            return None, response_queue
        else:
            async with self.lock:
                self.request_map[request_id] = response_queue
                self.request_num[request_id] = num_choices
                dealer = self._get_least_loaded_connection()
                if not dealer:
                    raise RuntimeError("No available connections")
            return dealer, response_queue

    async def cleanup_request(self, request_id):
        """
        Clean up the bookkeeping for a request after it is finished.
        Safe to call for unknown ids and safe under cancellation.
        """
        try:
            async with self.lock:
                self.request_map.pop(request_id, None)
                if not envs.ZMQ_SEND_BATCH_DATA:
                    self.request_num.pop(request_id, None)
        except asyncio.CancelledError:
            # If cancelled during lock acquisition, try cleanup without lock
            self.request_map.pop(request_id, None)
            if not envs.ZMQ_SEND_BATCH_DATA:
                self.request_num.pop(request_id, None)
            raise

    async def close(self):
        """
        Close all connections and tasks; idempotent best-effort shutdown.
        """
        self.running = False

        if envs.ZMQ_SEND_BATCH_DATA:
            # Cancel dispatcher task
            if self.dispatcher_task:
                self.dispatcher_task.cancel()

            # Close PULL client
            if self.pull_client:
                try:
                    self.pull_client.close()
                except Exception:
                    pass

            # Clean up IPC file created by bind
            pull_ipc_path = getattr(self, "pull_ipc_path", None)
            if pull_ipc_path:
                try:
                    if os.path.exists(pull_ipc_path):
                        os.remove(pull_ipc_path)
                except OSError:
                    pass
        else:
            for task in self.connection_tasks:
                task.cancel()

            async with self.lock:
                for dealer in self.connections:
                    try:
                        dealer.close()
                    except Exception:
                        pass
                self.connections.clear()
                self.connection_load.clear()
                # Also drop selection/bookkeeping state so a stale heap entry
                # can never index into the (now empty) connection list.
                self.connection_heap.clear()
                self.connection_tasks.clear()
                self.request_num.clear()

        # Clear request map
        self.request_map.clear()

        api_server_logger.info("All connections and tasks closed")
|
|
|
|
|
|
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
    """Register the API-server command line options on *parser*.

    Engine-level options from EngineArgs are appended afterwards, and the
    same parser object is returned so the call can be chained.
    """
    # HTTP endpoint and worker topology.
    parser.add_argument("--port", default=8000, type=int, help="port to the http server")
    parser.add_argument("--host", default="0.0.0.0", type=str, help="host to the http server")
    parser.add_argument("--workers", default=1, type=int, help="number of workers")
    parser.add_argument("--metrics-port", default=None, type=int, help="port for metrics server")
    parser.add_argument("--controller-port", default=-1, type=int, help="port for controller server")

    # Admission control.
    parser.add_argument(
        "--max-waiting-time",
        default=-1,
        type=int,
        help="max waiting time for connection, if set value -1 means no waiting time limit",
    )
    parser.add_argument("--max-concurrency", default=512, type=int, help="max concurrency")

    # Response payload options.
    parser.add_argument(
        "--enable-mm-output", action="store_true", help="Enable 'multimodal_content' field in response output. "
    )

    # Gunicorn worker lifecycle knobs.
    parser.add_argument(
        "--timeout-graceful-shutdown",
        default=0,
        type=int,
        help="timeout for graceful shutdown in seconds (used by gunicorn).Setting it to 0 has the effect of infinite timeouts by disabling timeouts for all workers entirely.",
    )
    parser.add_argument(
        "--timeout",
        default=0,
        type=int,
        help="Workers silent for more than this many seconds are killed and restarted.Value is a positive number or 0. Setting it to 0 has the effect of infinite timeouts by disabling timeouts for all workers entirely.",
    )

    # Authentication.
    parser.add_argument("--api-key", type=str, action="append", help="API_KEY required for service authentication")

    # Engine options come last so they appear after the server options.
    return EngineArgs.add_cli_args(parser)
|
|
|
|
|
|
async def listen_for_disconnect(request: Request) -> None:
    """Return once the client side of *request* disconnects."""
    # Drain incoming ASGI messages until the disconnect event shows up.
    while True:
        event = await request.receive()
        if event["type"] == "http.disconnect":
            return
|
|
|
|
|
|
def with_cancellation(handler_func):
    """Decorator that aborts a route handler when the client disconnects.

    This does _not_ rely on request.is_disconnected, which does not work with
    middleware.  Following the pattern of starlette.StreamingResponse, it runs
    two tasks side by side — one awaiting an http disconnect message, one
    running the actual handler — and cancels whichever is still pending when
    the first completes.

    A core assumption is that the request body has already been read, which
    holds for fastapi handlers whose body was parsed into a pydantic model.
    Using this decorator elsewhere is unsafe: it consumes and discards all
    incoming messages for the request while watching for a disconnect.

    When the handler returns a `StreamingResponse`, this wrapper stops
    listening for disconnects and the response object takes over; the
    response only listens correctly when Uvicorn's ASGI protocol version is
    below 2.4 (exclusive).
    """

    # functools.wraps keeps the original signature visible to fastapi so the
    # route is registered with the correct request type hinting.
    @functools.wraps(handler_func)
    async def wrapper(*args, **kwargs):
        # The Request arrives as either the second positional argument or the
        # `req` keyword argument.
        raw_request = kwargs["req"] if len(args) <= 1 else args[1]

        work_task = asyncio.create_task(handler_func(*args, **kwargs))
        disconnect_task = asyncio.create_task(listen_for_disconnect(raw_request))

        finished, unfinished = await asyncio.wait(
            [work_task, disconnect_task], return_when=asyncio.FIRST_COMPLETED
        )
        for leftover in unfinished:
            leftover.cancel()

        # Client disconnected first: surface None instead of the handler result.
        if work_task not in finished:
            return None
        return work_task.result()

    return wrapper
|