mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
950366e58d
* [Feature] Register to router with version info for PD disaggregation Add RegisterManager for PD (Prefill-Decode) disaggregated deployment: - All instances (Prefill/Decode) register to Router with heartbeat - Prefill instances fetch Decode instance list from Router - Prefill instances establish eager RDMA connections to Decode instances - Register info includes: host_ip, port, role, version, is_paused, connected_decodes Changes: - Add RegisterManager class for managing PD registration and RDMA connections - Add version field to ModelConfig for model version tracking - Add connected_decodes to register_info for tracking connected Decode instances - Add FD_ENABLE_PD_RDMA_EAGER_CONNECT environment variable Test fixes: - Add None checks for load_config in FDConfig.__init__ - Add version attribute to test mock model configs Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * refine * remove test --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
516 lines
22 KiB
Python
516 lines
22 KiB
Python
"""
|
|
Async Router server for FastDeploy.
|
|
Handles client requests and manages prefill/decode/mixed instances.
|
|
This module references the router implementations of sglang and vllm.
|
|
"""
|
|
|
|
import asyncio
|
|
import copy
|
|
import json
|
|
import os
|
|
import random
|
|
import traceback
|
|
from dataclasses import dataclass
|
|
from itertools import chain
|
|
from typing import Dict, List, Optional
|
|
from uuid import uuid4
|
|
|
|
import aiohttp
|
|
import uvicorn
|
|
from fastapi import FastAPI, HTTPException
|
|
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
|
|
|
from fastdeploy.router.utils import (
|
|
InstanceInfo,
|
|
InstanceRole,
|
|
check_service_health_async,
|
|
)
|
|
from fastdeploy.utils import FlexibleArgumentParser
|
|
from fastdeploy.utils import router_logger as logger
|
|
|
|
# Module-level ASGI application; the route handlers below are registered on it
# and the Router object is attached to app.state at startup.
app = FastAPI()
|
|
|
|
|
|
@dataclass
class RouterArgs:
    """Command-line configuration for the router server."""

    # Host address to bind the router server
    host: str = "0.0.0.0"
    # Port to bind the router server.
    port: int = 9000
    # Router uses splitwise deployment
    splitwise: bool = False
    # Request timeout in seconds
    request_timeout_secs: int = 1800

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        """Attach the router's CLI options to *parser* and return it."""
        option_specs = (
            (
                "--host",
                {"type": str, "default": RouterArgs.host, "help": "Host address to bind the router server."},
            ),
            (
                "--port",
                {"type": int, "default": RouterArgs.port, "help": "Port number to bind the router server"},
            ),
            (
                "--splitwise",
                {"action": "store_true", "default": RouterArgs.splitwise, "help": "Router uses splitwise deployment"},
            ),
            (
                "--request-timeout-secs",
                {
                    "type": int,
                    "default": RouterArgs.request_timeout_secs,
                    "help": "Request timeout in seconds",
                },
            ),
        )
        for flag, options in option_specs:
            parser.add_argument(flag, **options)
        return parser
|
|
|
|
|
|
class Router:
    """
    Router class that handles requests from client and
    collects prefill/decode/mixed instances
    """

    def __init__(self, args):
        """Store CLI args and set up empty per-role instance registries."""
        self.args = args
        self.host, self.port = args.host, args.port
        self.splitwise = args.splitwise
        self.timeout = args.request_timeout_secs

        # Registered instances, grouped by role; guarded by self.lock.
        self.decode_servers = []
        self.prefill_servers = []
        self.mixed_servers = []
        self.lock = asyncio.Lock()  # async-safe lock
        logger.info("Router started at http://{}:{}".format(self.host, self.port))
|
|
|
async def register_instance(self, instance_info_dict: dict):
|
|
"""Register an instance asynchronously"""
|
|
try:
|
|
inst_info = InstanceInfo.from_dict(instance_info_dict)
|
|
except Exception as e:
|
|
logger.error(f"register instance failed: {e}")
|
|
raise
|
|
|
|
if (self.splitwise and inst_info.role == InstanceRole.MIXED) or (
|
|
not self.splitwise and inst_info.role != InstanceRole.MIXED
|
|
):
|
|
raise ValueError(f"Invalid instance role: {inst_info.role}, splitwise: {self.splitwise}")
|
|
|
|
if not await check_service_health_async(inst_info.url()):
|
|
raise RuntimeError(f"Instance {inst_info} is not healthy")
|
|
|
|
async with self.lock:
|
|
instance_key = inst_info.get_key()
|
|
|
|
if inst_info.role == InstanceRole.MIXED:
|
|
self._update_or_add_instance(self.mixed_servers, inst_info, instance_key, "mixed")
|
|
elif inst_info.role == InstanceRole.PREFILL:
|
|
self._update_or_add_instance(self.prefill_servers, inst_info, instance_key, "prefill")
|
|
elif inst_info.role == InstanceRole.DECODE:
|
|
self._update_or_add_instance(self.decode_servers, inst_info, instance_key, "decode")
|
|
|
|
def _update_or_add_instance(self, server_list: List, inst_info: InstanceInfo, key: str, role_name: str):
|
|
"""Update existing instance or add new one based on key (host_ip:port)."""
|
|
for i, existing in enumerate(server_list):
|
|
if existing.get_key() == key:
|
|
if existing != inst_info:
|
|
server_list[i] = inst_info
|
|
logger.info(f"Updated {role_name} instance, key: {key}, inst_info: {inst_info}")
|
|
return
|
|
|
|
server_list.append(inst_info)
|
|
logger.info(f"Register {role_name} instance success: {inst_info}, total {role_name}: {len(server_list)}")
|
|
|
|
async def registered_number(self):
|
|
"""Get number of registered instances"""
|
|
return {
|
|
"mixed": len(self.mixed_servers),
|
|
"prefill": len(self.prefill_servers),
|
|
"decode": len(self.decode_servers),
|
|
}
|
|
|
|
async def get_decode_instances(self, version: Optional[str] = None) -> List[Dict]:
|
|
"""Get all registered decode instances, optionally filtered by version"""
|
|
async with self.lock:
|
|
instances = self.decode_servers
|
|
if version is not None:
|
|
instances = [inst for inst in instances if inst.version == version]
|
|
return [inst.to_dict() for inst in instances]
|
|
|
|
async def select_pd(self):
|
|
"""Select one prefill and one decode server"""
|
|
async with self.lock:
|
|
if not self.prefill_servers:
|
|
raise RuntimeError(f"No prefill servers available (decode={len(self.decode_servers)})")
|
|
if not self.decode_servers:
|
|
raise RuntimeError(f"No decode servers available (prefill={len(self.prefill_servers)})")
|
|
pidx = random.randint(0, len(self.prefill_servers) - 1)
|
|
didx = random.randint(0, len(self.decode_servers) - 1)
|
|
return self.prefill_servers[pidx], self.decode_servers[didx]
|
|
|
|
async def select_mixed(self):
|
|
"""Select one mixed server"""
|
|
async with self.lock:
|
|
if not self.mixed_servers:
|
|
raise RuntimeError(f"No mixed servers available. Registered mixed servers: {len(self.mixed_servers)}")
|
|
idx = random.randint(0, len(self.mixed_servers) - 1)
|
|
return self.mixed_servers[idx]
|
|
|
|
async def handle_request(self, request_data: dict, endpoint_name: str):
|
|
if self.splitwise:
|
|
return await self.handle_splitwise_request(request_data, endpoint_name)
|
|
else:
|
|
return await self.handle_mixed_request(request_data, endpoint_name)
|
|
|
|
async def handle_mixed_request(self, request_data: dict, endpoint_name: str):
|
|
logger.debug(f"Received request: {request_data}")
|
|
mixed_server = await self.select_mixed()
|
|
|
|
if request_data.get("stream", False):
|
|
if request_data.get("divided_stream", int(os.environ.get("DIVIDED_STREAM", "0")) == 1):
|
|
return await self._divided_generate_stream(request_data, [mixed_server.url()], endpoint=endpoint_name)
|
|
else:
|
|
return await self._generate_stream(request_data, [mixed_server.url()], endpoint=endpoint_name)
|
|
else:
|
|
return await self._generate(request_data, [mixed_server.url()], endpoint=endpoint_name)
|
|
|
|
async def handle_splitwise_request(self, request_data: dict, endpoint_name: str):
|
|
logger.debug(f"Received request: {request_data}")
|
|
prefill_server, decode_server = await self.select_pd()
|
|
logger.debug(f"Selected prefill server: {prefill_server}")
|
|
logger.debug(f"Selected decode server: {decode_server}")
|
|
|
|
if prefill_server.tp_size != decode_server.tp_size and decode_server.tp_size != 1:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail="The tp_size of prefill and decode should be equal or the tp_size of decode is 1",
|
|
)
|
|
|
|
# TODO: unify the disaggregate_info in server and remove redundancy params
|
|
is_same_node = prefill_server.host_ip == decode_server.host_ip
|
|
is_support_ipc = "ipc" in prefill_server.transfer_protocol and "ipc" in decode_server.transfer_protocol
|
|
is_same_tp_size = prefill_server.tp_size == decode_server.tp_size
|
|
use_ipc = is_same_node and is_support_ipc and is_same_tp_size
|
|
|
|
disaggregate_info = {
|
|
"prefill_ip": prefill_server.host_ip,
|
|
"decode_ip": decode_server.host_ip,
|
|
"prefill_connector_port": prefill_server.connector_port,
|
|
"decode_connector_port": decode_server.connector_port,
|
|
"decode_device_ids": decode_server.device_ids,
|
|
"decode_rdma_ports": decode_server.rdma_ports,
|
|
"transfer_protocol": "ipc" if use_ipc else "rdma",
|
|
"decode_tp_size": decode_server.tp_size,
|
|
}
|
|
|
|
modified_request = request_data.copy()
|
|
modified_request["disaggregate_info"] = disaggregate_info
|
|
if "request_id" not in modified_request:
|
|
modified_request["request_id"] = str(uuid4())
|
|
|
|
logger.debug(f"Modified request: {modified_request}")
|
|
|
|
if request_data.get("stream", False):
|
|
return await self._generate_stream(
|
|
modified_request, [prefill_server.url(), decode_server.url()], endpoint=endpoint_name
|
|
)
|
|
else:
|
|
return await self._generate(
|
|
modified_request, [prefill_server.url(), decode_server.url()], endpoint=endpoint_name
|
|
)
|
|
|
|
async def _generate(
|
|
self, modified_request, urls, return_result_url_index=-1, endpoint="v1/chat/completions"
|
|
) -> ORJSONResponse:
|
|
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session:
|
|
tasks = [session.post(f"{url}/{endpoint}", json=modified_request) for url in urls]
|
|
results = await asyncio.gather(*tasks)
|
|
ret_json = await results[return_result_url_index].json()
|
|
return ORJSONResponse(content=ret_json, status_code=results[return_result_url_index].status)
|
|
|
|
async def _generate_stream(
|
|
self, modified_request, urls, return_result_url_index=-1, endpoint="v1/chat/completions"
|
|
):
|
|
async def stream_results():
|
|
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session:
|
|
tasks = [session.post(f"{url}/{endpoint}", json=modified_request) for url in urls]
|
|
results = await asyncio.gather(*tasks)
|
|
|
|
AIOHTTP_STREAM_READ_CHUNK_SIZE = 1024 * 64 # prevent aiohttp's "Chunk too big" error
|
|
async for chunk in results[return_result_url_index].content.iter_chunked(
|
|
AIOHTTP_STREAM_READ_CHUNK_SIZE
|
|
):
|
|
logger.debug(f"receive response chunk: {chunk}")
|
|
yield chunk
|
|
|
|
return StreamingResponse(stream_results(), media_type="text/event-stream")
|
|
|
|
    async def _divided_generate_stream(
        self,
        modified_request,
        urls,
        return_result_url_index=-1,
        endpoint="v1/chat/completions",
    ):
        """
        NOTE: Used for debugging, not used in production

        Streams a completion by repeatedly re-issuing the request in rounds of
        at most ``step_max_tokens`` tokens, feeding the accumulated output ids
        back as the prompt of the next round.
        """

        async def stream_results():
            total_max_tokens = modified_request.get("max_tokens", 0)
            step_max_tokens = modified_request.get("step_max_tokens", 10)
            timeout = aiohttp.ClientTimeout(total=self.timeout)

            round_idx = -1
            generated_tokens = 0
            input_ids = []   # prompt token ids captured from the first round
            output_ids = []  # completion token ids accumulated across rounds

            async with aiohttp.ClientSession(timeout=timeout) as session:
                while generated_tokens < total_max_tokens:
                    round_idx += 1
                    remain_tokens = total_max_tokens - generated_tokens
                    cur_max_tokens = min(step_max_tokens, remain_tokens)
                    is_last_round = remain_tokens <= step_max_tokens

                    # Each round replays the request with the concatenated
                    # prompt+output ids; only round 0 uses the chat template.
                    cur_request = copy.deepcopy(modified_request)
                    cur_request["max_tokens"] = cur_max_tokens
                    cur_request["return_token_ids"] = True
                    cur_request["max_streaming_response_tokens"] = 1
                    if round_idx == 0:
                        cur_request["disable_chat_template"] = False
                    else:
                        cur_request["messages"] = []
                        cur_request["prompt_token_ids"] = input_ids + output_ids
                        cur_request["disable_chat_template"] = True

                    logger.debug(f"_divided_generate_stream, cur_request={cur_request}")

                    resp = await session.post(
                        f"{urls[return_result_url_index]}/{endpoint}",
                        json=cur_request,
                    )
                    if resp.status != 200:
                        text = await resp.text()
                        raise RuntimeError(f"Request failed: {resp.status}, body={text}")

                    # Parse the SSE byte stream; events are separated by "\n\n".
                    buffer = b""
                    chunk_idx = -1
                    is_real_finished = False
                    async for raw_chunk in resp.content.iter_chunked(64 * 1024):
                        try:
                            buffer += raw_chunk

                            while b"\n\n" in buffer:
                                event_bytes, buffer = buffer.split(b"\n\n", 1)
                                event_str = event_bytes.decode("utf-8")

                                for chunk in event_str.splitlines():
                                    logger.debug(f"receive response chunk: {chunk}")
                                    if not chunk:
                                        continue

                                    chunk_idx += 1
                                    # NOTE(review): the first chunk of every round
                                    # after the first is dropped — presumably a
                                    # duplicate/bookkeeping chunk from the server;
                                    # confirm against the server's SSE contract.
                                    if round_idx > 0 and chunk_idx == 0:
                                        continue

                                    assert chunk.startswith("data: "), f"Invalid response chunk: {chunk}"
                                    if chunk.startswith("data: [DONE]"):
                                        # Forward [DONE] only when generation has
                                        # truly finished, not between rounds.
                                        if is_real_finished:
                                            yield chunk + "\n\n"
                                    else:
                                        payload = json.loads(chunk[5:])
                                        choices = payload.get("choices", [])
                                        if not choices:
                                            continue
                                        delta = payload["choices"][0]["delta"]
                                        finish_reason = payload["choices"][0].get("finish_reason")

                                        # Capture prompt ids once, from the first
                                        # chunk that carries them.
                                        if not input_ids and len(delta["prompt_token_ids"]) > 0:
                                            input_ids = delta["prompt_token_ids"]

                                        # "length" only counts as finished in the
                                        # last round; earlier rounds hit the
                                        # per-round cap by design.
                                        if finish_reason == "stop" or (is_last_round and finish_reason == "length"):
                                            is_real_finished = True

                                        token_ids = delta.get("completion_token_ids")
                                        if (
                                            token_ids
                                            and isinstance(token_ids, list)
                                            and (finish_reason is None or is_real_finished)
                                        ):
                                            output_ids.extend(token_ids)
                                            generated_tokens += len(token_ids)

                                        if finish_reason is None or is_real_finished:
                                            yield chunk + "\n\n"

                        except Exception as e:
                            # Best-effort: log malformed chunks and keep reading.
                            logger.error(
                                f"Error decoding response chunk: {raw_chunk}, round_idx: {round_idx}, "
                                f"chunk_idx: {chunk_idx}, error: {e}, traceback:{traceback.format_exc()}"
                            )
                            pass

                    if not is_real_finished:
                        # Sanity check on per-round token accounting; the -1
                        # mirrors the skipped first chunk of follow-up rounds.
                        # NOTE(review): assumes every unfinished round yields
                        # exactly step_max_tokens tokens — confirm server side.
                        expected_tokens = (step_max_tokens - 1) * (round_idx + 1)
                        if generated_tokens != expected_tokens:
                            err_msg = (
                                f"Generated tokens mismatch: generated_tokens is {generated_tokens}, "
                                f"expected is {expected_tokens}"
                            )
                            logger.error(err_msg)
                            raise RuntimeError(err_msg)

                    if is_real_finished:
                        break

        return StreamingResponse(
            stream_results(),
            media_type="text/event-stream",
        )
|
|
|
|
    async def monitor_instance_health(self, interval_secs: float = 5.0):
        """
        Continuously check the health of prefill, decode, and mixed instances and remove unhealthy ones.

        Runs forever; intended to be scheduled once as a background task.
        One full sweep is performed roughly every ``interval_secs`` seconds.
        """
        while True:
            try:
                prefill_to_remove = []
                decode_to_remove = []
                mixed_to_remove = []

                async with aiohttp.ClientSession() as session:
                    # Build one GET /health coroutine per registered instance.
                    # NOTE(review): the server lists are read here without
                    # holding self.lock — safe only because asyncio is
                    # single-threaded and each list is snapshotted eagerly.
                    prefill_tasks = [(inst, session.get(f"{inst.url()}/health")) for inst in self.prefill_servers]
                    decode_tasks = [(inst, session.get(f"{inst.url()}/health")) for inst in self.decode_servers]
                    mixed_tasks = [(inst, session.get(f"{inst.url()}/health")) for inst in self.mixed_servers]

                    # Await the checks one at a time (sequential, not gathered).
                    all_tasks = prefill_tasks + decode_tasks + mixed_tasks
                    for inst, coro in all_tasks:
                        try:
                            resp = await coro
                            if resp.status != 200:
                                logger.warning(f"Instance {inst.url()} unhealthy: {resp.status}")
                                # Membership tests re-check the live lists in
                                # case they changed while we were awaiting.
                                if inst in self.prefill_servers:
                                    prefill_to_remove.append(inst)
                                elif inst in self.decode_servers:
                                    decode_to_remove.append(inst)
                                elif inst in self.mixed_servers:
                                    mixed_to_remove.append(inst)
                        except Exception as e:
                            # A connection failure counts as unhealthy too.
                            logger.warning(f"Instance {inst.url()} check failed: {e}")
                            if inst in self.prefill_servers:
                                prefill_to_remove.append(inst)
                            elif inst in self.decode_servers:
                                decode_to_remove.append(inst)
                            elif inst in self.mixed_servers:
                                mixed_to_remove.append(inst)

                # remove unhealthy instances under lock
                async with self.lock:
                    if prefill_to_remove:
                        for inst in prefill_to_remove:
                            self.prefill_servers.remove(inst)
                            logger.info(f"Removed unhealthy prefill instance: {inst.url()}")
                    if decode_to_remove:
                        for inst in decode_to_remove:
                            self.decode_servers.remove(inst)
                            logger.info(f"Removed unhealthy decode instance: {inst.url()}")
                    if mixed_to_remove:
                        for inst in mixed_to_remove:
                            self.mixed_servers.remove(inst)
                            logger.info(f"Removed unhealthy mixed instance: {inst.url()}")

                await asyncio.sleep(interval_secs)

                prefill_instances = [inst.url() for inst in self.prefill_servers]
                decode_instances = [inst.url() for inst in self.decode_servers]
                mixed_instance = [inst.url() for inst in self.mixed_servers]
                logger.debug(
                    f"Healthy prefill instances: {prefill_instances}, "
                    f"Healthy decode instances: {decode_instances}, "
                    f"Healthy mixed instance: {mixed_instance}"
                )

            except Exception as e:
                # Never let the monitor die; log and retry on the next sweep.
                # NOTE(review): if a sweep raises before the sleep above, the
                # loop retries immediately with no backoff.
                logger.exception(f"Failed to monitor instance health: {e}")
|
|
|
|
|
|
@app.post("/register")
async def register(instance_info_dict: dict):
    """Register prefill/decode/mixed servers"""
    router = app.state.router
    try:
        await router.register_instance(instance_info_dict)
    except Exception as e:
        # Surface any registration failure to the client as a 400.
        raise HTTPException(status_code=400, detail=str(e))
    return {"status": "success"}
|
|
|
|
|
|
@app.get("/registered_number")
async def registered_number():
    """Get the number of registered prefill/decode/mixed servers"""
    router = app.state.router
    return await router.registered_number()
|
|
|
|
|
|
@app.get("/decode_instances")
async def decode_instances(version: Optional[str] = None):
    """Get all registered decode instances, optionally filtered by version"""
    router = app.state.router
    return await router.get_decode_instances(version)
|
|
|
|
|
|
@app.post("/v1/chat/completions")
async def create_chat_completion(request_data: dict):
    """OpenAI-compatible chat completions entry point."""
    router = app.state.router
    return await router.handle_request(request_data, "v1/chat/completions")
|
|
|
|
|
|
@app.post("/v1/completions")
async def create_completion(request_data: dict):
    """OpenAI-compatible completions entry point."""
    router = app.state.router
    return await router.handle_request(request_data, "v1/completions")
|
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Basic health check"""
    # Liveness of the router process only; downstream instances are probed by
    # /health_generate instead.
    return Response(status_code=200)
|
|
|
|
|
|
@app.get("/health_generate")
async def health_generate():
    """Check all prefill and decode servers are healthy"""
    router = app.state.router
    servers = chain(router.prefill_servers, router.decode_servers)
    async with aiohttp.ClientSession() as session:
        checks = [session.get(f"{server.url()}/health") for server in servers]
        for pending in asyncio.as_completed(checks):
            resp = await pending
            # Unhealthy instances are only logged; the endpoint still returns 200.
            if resp.status != 200:
                logger.warning(f"Server {resp.url} not healthy: {resp.status}")
    return Response(status_code=200)
|
|
|
|
|
|
def launch_router(router_args: RouterArgs):
    # Stash the args on app.state so the startup hook can build the Router
    # once the event loop exists.
    app.state.router_args = router_args
    print(f"Starting router with args: {router_args}")

    # NOTE(review): @app.on_event is deprecated in recent FastAPI in favor of
    # lifespan handlers; kept as-is to preserve behavior.
    @app.on_event("startup")
    async def startup_event():
        app.state.router = Router(app.state.router_args)
        # Fire-and-forget background task that evicts unhealthy instances.
        asyncio.create_task(app.state.router.monitor_instance_health(interval_secs=5))

    uvicorn.run(app, host=router_args.host, port=int(router_args.port))
|