"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import asyncio
import json
import os
import signal
import threading
import time
import traceback
import uuid
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

import uvicorn
import zmq
from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, Response, StreamingResponse
from gunicorn.app.base import BaseApplication
from opentelemetry import trace
from opentelemetry.propagate import extract

import fastdeploy.metrics.trace as tracing
from fastdeploy import envs
from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.async_llm import AsyncLLM
from fastdeploy.engine.engine import LLMEngine
from fastdeploy.engine.expert_service import ExpertService
from fastdeploy.engine.request import ControlRequest
from fastdeploy.entrypoints.chat_utils import load_chat_template
from fastdeploy.entrypoints.engine_client import EngineClient
from fastdeploy.entrypoints.openai.middleware import AuthenticationMiddleware
from fastdeploy.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatRewardRequest,
    CompletionRequest,
    CompletionResponse,
    ControlSchedulerRequest,
    EmbeddingRequest,
    ErrorInfo,
    ErrorResponse,
    ModelList,
)
from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
from fastdeploy.entrypoints.openai.serving_completion import OpenAIServingCompletion
from fastdeploy.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from fastdeploy.entrypoints.openai.serving_models import ModelPath, OpenAIServingModels
from fastdeploy.entrypoints.openai.serving_reward import OpenAIServingReward
from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
from fastdeploy.entrypoints.openai.utils import (
    UVICORN_CONFIG,
    make_arg_parser,
    with_cancellation,
)
from fastdeploy.entrypoints.openai.v1.serving_chat import (
    OpenAIServingChat as OpenAIServingChatV1,
)
from fastdeploy.entrypoints.openai.v1.serving_completion import (
    OpenAIServingCompletion as OpenAIServingCompletionV1,
)
from fastdeploy.envs import environment_variables
from fastdeploy.metrics.metrics import get_filtered_metrics
from fastdeploy.utils import (
    ExceptionHandler,
    FlexibleArgumentParser,
    StatefulSemaphore,
    api_server_logger,
    console_logger,
    get_host_ip,
    get_version_info,
    is_port_available,
    retrive_model_from_server,
)

_tracing_inited = False

parser = make_arg_parser(FlexibleArgumentParser())
args = parser.parse_args()

console_logger.info(f"Number of api-server workers: {args.workers}.")

args.model = retrive_model_from_server(args.model, args.revision)
chat_template = load_chat_template(args.chat_template, args.model)
if args.tool_parser_plugin:
    ToolParserManager.import_tool_parser(args.tool_parser_plugin)
llm_engine = None

MAX_CONCURRENT_CONNECTIONS = (args.max_concurrency + args.workers - 1) // args.workers
connection_semaphore = StatefulSemaphore(MAX_CONCURRENT_CONNECTIONS)
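
# Example sizing (illustrative values only): the ceiling division above splits the
# global --max-concurrency budget across api-server workers, so with
# --max-concurrency 100 and --workers 8 each worker admits
#   (100 + 8 - 1) // 8 == 13
# concurrent requests through its connection_semaphore.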


class StandaloneApplication(BaseApplication):
    def __init__(self, app, options=None):
        self.application = app
        self.options = options or {}
        super().__init__()

    def load_config(self):
        config = {key: value for key, value in self.options.items() if key in self.cfg.settings and value is not None}
        for key, value in config.items():
            self.cfg.set(key.lower(), value)

    def load(self):
        return self.application


def load_engine():
    """
    load engine
    """
    global llm_engine
    if llm_engine is not None:
        return llm_engine

    api_server_logger.info(f"FastDeploy LLM API server starting... {os.getpid()}, port: {args.port}")
    engine_args = EngineArgs.from_cli_args(args)
    if envs.FD_ENABLE_ASYNC_LLM:
        engine = AsyncLLM.from_engine_args(engine_args, pid=args.port)
    else:
        engine = LLMEngine.from_engine_args(engine_args)
    started = False
    if isinstance(engine, AsyncLLM):
        started = asyncio.run(engine.start())
    else:
        started = engine.start(api_server_pid=args.port)
    if not started:
        api_server_logger.error(
            "Failed to initialize FastDeploy LLM engine, service exit now! "
            "Please check the log file for more details."
        )
        return None
    llm_engine = engine
    return engine


def load_data_service():
    """
    load data service
    """
    global llm_engine
    if llm_engine is not None:
        return llm_engine
    api_server_logger.info(f"FastDeploy LLM API server starting... {os.getpid()}, port: {args.port}")
    engine_args = EngineArgs.from_cli_args(args)
    config = engine_args.create_engine_config()
    api_server_logger.info(f"parallel_config: {config.parallel_config}")
    api_server_logger.info(f"local_data_parallel_id: {config.parallel_config.local_data_parallel_id}")
    expert_service = ExpertService(config, config.parallel_config.local_data_parallel_id)
    if not expert_service.start(args.port, config.parallel_config.local_data_parallel_id):
        api_server_logger.error("Failed to initialize FastDeploy LLM expert service, service exit now!")
        return None
    llm_engine = expert_service
    return expert_service


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    async context manager for FastAPI lifespan
    """
    global engine_args
    global _tracing_inited
    import logging

    # Initialize tracing in the worker lifecycle instead of at module import time.
    # This avoids creating grpc/cygrpc state before gunicorn forks workers.
    if not _tracing_inited:
        tracing.process_tracing_init()
        _tracing_inited = True

    uvicorn_access = logging.getLogger("uvicorn.access")
    uvicorn_access.handlers.clear()

    formatter = logging.Formatter(
        "%(levelname)-8s %(asctime)s %(process)-5s %(filename)s[line:%(lineno)d] %(message)s"
    )

    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    uvicorn_access.addHandler(handler)
    uvicorn_access.propagate = False

    if args.tokenizer is None:
        args.tokenizer = args.model
    pid = args.port
    api_server_logger.info(f"api server pid (derived from args.port): {pid}")

    if args.served_model_name is not None:
        served_model_names = args.served_model_name
        verification = True
    else:
        served_model_names = args.model
        verification = False
    model_paths = [ModelPath(name=served_model_names, model_path=args.model, verification=verification)]

    engine_args = EngineArgs.from_cli_args(args, skip_port_check=True)
    fd_config = engine_args.create_engine_config()
    if envs.FD_ENABLE_ASYNC_LLM:
        os.environ["INFERENCE_MSG_QUEUE_ID"] = engine_args.engine_worker_queue_port[engine_args.local_data_parallel_id]
    engine_client = EngineClient(
        pid=pid,
        port=int(os.environ.get("INFERENCE_MSG_QUEUE_ID", "0")),
        fd_config=fd_config,
        workers=args.workers,
        max_logprobs=args.max_logprobs,
    )
    await engine_client.connection_manager.initialize()
    app.state.dynamic_load_weight = args.dynamic_load_weight
    model_handler = OpenAIServingModels(
        model_paths,
        args.max_model_len,
        args.ips,
    )
    app.state.model_handler = model_handler
    global llm_engine
    if envs.FD_ENABLE_ASYNC_LLM:
        await llm_engine.init_connections()
        chat_handler = OpenAIServingChatV1(
            llm_engine,
            fd_config,
            app.state.model_handler,
            pid,
            args.ips,
            args.max_waiting_time,
            chat_template,
            args.enable_mm_output,
            args.tokenizer_base_url,
        )
        completion_handler = OpenAIServingCompletionV1(
            llm_engine,
            fd_config,
            app.state.model_handler,
            pid,
            args.ips,
            args.max_waiting_time,
        )
    else:
        chat_handler = OpenAIServingChat(
            engine_client,
            app.state.model_handler,
            pid,
            args.ips,
            args.max_waiting_time,
            chat_template,
            args.enable_mm_output,
            args.tokenizer_base_url,
        )
        completion_handler = OpenAIServingCompletion(
            engine_client,
            app.state.model_handler,
            pid,
            args.ips,
            args.max_waiting_time,
        )

    embedding_handler = OpenAIServingEmbedding(
        engine_client,
        app.state.model_handler,
        fd_config,
        pid,
        args.ips,
        args.max_waiting_time,
        chat_template,
    )
    reward_handler = OpenAIServingReward(
        engine_client, app.state.model_handler, fd_config, pid, args.ips, args.max_waiting_time, chat_template
    )
    engine_client.create_zmq_client(model=pid, mode=zmq.PUSH)
    engine_client.pid = pid
    app.state.engine_client = engine_client
    app.state.chat_handler = chat_handler
    app.state.completion_handler = completion_handler
    app.state.embedding_handler = embedding_handler
    app.state.reward_handler = reward_handler
    app.state.event_loop = asyncio.get_running_loop()

    if llm_engine is not None and not isinstance(llm_engine, AsyncLLM):
        llm_engine.engine.data_processor = engine_client.data_processor
    yield
    # close zmq
    try:
        if envs.FD_ENABLE_ASYNC_LLM:
            await llm_engine.shutdown()
        await engine_client.connection_manager.close()
        engine_client.zmq_client.close()
        from prometheus_client import multiprocess

        multiprocess.mark_process_dead(os.getpid())
        api_server_logger.info(f"Closing metrics client pid: {pid}")
    except Exception as e:
        api_server_logger.warning(f"exit error: {e}, {str(traceback.format_exc())}")


app = FastAPI(lifespan=lifespan)
app.add_exception_handler(RequestValidationError, ExceptionHandler.handle_request_validation_exception)
app.add_exception_handler(Exception, ExceptionHandler.handle_exception)


env_api_key_func = environment_variables.get("FD_API_KEY")
env_tokens = env_api_key_func() if env_api_key_func else []
if tokens := [key for key in (args.api_key or env_tokens) if key]:
    app.add_middleware(AuthenticationMiddleware, tokens)
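
# Illustrative client sketch when API keys are configured via --api-key or FD_API_KEY.
# The exact header scheme is defined by AuthenticationMiddleware; the Bearer form below
# is an assumption, and host/port/key are example values:
#   import requests
#   requests.get(
#       "http://localhost:8000/v1/models",
#       headers={"Authorization": "Bearer my-secret-key"},
#   )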


@asynccontextmanager
async def connection_manager():
    """
    async context manager for connection manager
    """
    try:
        await asyncio.wait_for(connection_semaphore.acquire(), timeout=0.001)
        yield
    except asyncio.TimeoutError:
        api_server_logger.info(f"Reach max request concurrency, semaphore status: {connection_semaphore.status()}")
        raise HTTPException(
            status_code=429, detail=f"Too many requests, current max concurrency is {args.max_concurrency}"
        )


# TODO: pass the real engine state and query its status via pid
@app.get("/health")
def health(request: Request) -> Response:
    """Health check."""

    status, msg = app.state.engine_client.check_health()
    if not status:
        return Response(content=msg, status_code=404)
    status, msg = app.state.engine_client.is_workers_alive()
    if not status:
        return Response(content=msg, status_code=304)
    return Response(status_code=200)
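
# Illustrative health probe (host/port are example values); the status codes below
# mirror the handler above:
#   import requests
#   resp = requests.get("http://localhost:8000/health")
#   assert resp.status_code == 200  # 404 -> engine unhealthy, 304 -> workers not alive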


@app.get("/load")
async def list_all_routes():
    """
    List all routes whose path starts with /v1.

    Args:
        None

    Returns:
        dict: a dict containing all matching routes, in the following format:
            {
                "routes": [
                    {
                        "path": str,      # route path
                        "methods": list,  # supported HTTP methods, sorted
                        "tags": list      # route tags, defaults to an empty list
                    },
                    ...
                ]
            }
    """
    routes_info = []

    for route in app.routes:
        # Only keep paths that start with /v1
        if route.path.startswith("/v1"):
            methods = sorted(route.methods)
            tags = getattr(route, "tags", []) or []
            routes_info.append({"path": route.path, "methods": methods, "tags": tags})
    return {"routes": routes_info}


@app.api_route("/ping", methods=["GET", "POST"])
def ping(raw_request: Request) -> Response:
    """Ping check. Endpoint required for SageMaker."""
    return health(raw_request)


@app.post("/v1/pause")
async def pause(request: Request) -> Response:
    # TODO: support wait_for_inflight_requests (default False) and clear_cache (default True) arguments
    request_id = f"control-{uuid.uuid4()}"
    control_request = ControlRequest(request_id, "pause")
    control_response = await app.state.engine_client.run_control_method(control_request)
    return control_response.to_api_json_response()


@app.post("/v1/resume")
async def resume(request: Request) -> Response:
    request_id = f"control-{uuid.uuid4()}"
    control_request = ControlRequest(request_id, "resume")
    control_response = await app.state.engine_client.run_control_method(control_request)
    return control_response.to_api_json_response()


@app.get("/v1/is_paused")
async def is_paused(request: Request) -> Response:
    request_id = f"control-{uuid.uuid4()}"
    control_request = ControlRequest(request_id, "is_paused")
    control_response = await app.state.engine_client.run_control_method(control_request)
    return control_response.to_api_json_response()


@app.post("/v1/sleep")
async def sleep(request: Request) -> Response:
    request_id = f"control-{uuid.uuid4()}"
    # Support both a JSON body and query parameters
    if await request.body():
        request_data = await request.json()
    else:
        # Extract query params
        request_data = dict(request.query_params)

    try:
        control_request = ControlRequest(request_id, "sleep", request_data)
    except TypeError as e:
        return JSONResponse(status_code=400, content={"error": "Invalid parameter type", "message": str(e)})

    control_response = await app.state.engine_client.run_control_method(control_request)
    return control_response.to_api_json_response()


@app.post("/v1/wakeup")
async def wakeup(request: Request) -> Response:
    request_id = f"control-{uuid.uuid4()}"
    # Support both a JSON body and query parameters
    if await request.body():
        request_data = await request.json()
    else:
        # Extract query params
        request_data = dict(request.query_params)

    try:
        control_request = ControlRequest(request_id, "wakeup", request_data)
    except TypeError as e:
        return JSONResponse(status_code=400, content={"error": "Invalid parameter type", "message": str(e)})

    control_response = await app.state.engine_client.run_control_method(control_request)
    return control_response.to_api_json_response()
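
# Illustrative control-plane calls (host/port are example values; any extra parameters
# are validated by ControlRequest, so none are shown here):
#   import requests
#   requests.post("http://localhost:8000/v1/sleep", json={})   # JSON-body form
#   requests.post("http://localhost:8000/v1/wakeup")            # empty / query-parameter form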


@app.post("/v1/update_weights")
async def update_weights(request: Request) -> Response:
    request_id = f"control-{uuid.uuid4()}"

    request_data = await request.json() if await request.body() else {}

    args = {}

    # Validate and extract the version parameter
    if "version" in request_data and request_data["version"] is not None:
        if not isinstance(request_data["version"], str):
            return JSONResponse(
                status_code=400, content={"error": "Invalid parameter type", "message": "version must be a string"}
            )
        args["version"] = request_data["version"]

    # Validate and extract the verify_checksum parameter
    if "verify_checksum" in request_data and request_data["verify_checksum"] is not None:
        if not isinstance(request_data["verify_checksum"], bool):
            return JSONResponse(
                status_code=400,
                content={"error": "Invalid parameter type", "message": "verify_checksum must be a boolean"},
            )
        args["verify_checksum"] = request_data["verify_checksum"]

    control_request = ControlRequest(request_id, "update_weights", args)
    control_response = await app.state.engine_client.run_control_method(control_request)
    return control_response.to_api_json_response()
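
# Illustrative request (host/port and the version string are example values; "version"
# and "verify_checksum" are the two fields validated above):
#   import requests
#   requests.post(
#       "http://localhost:8000/v1/update_weights",
#       json={"version": "v2", "verify_checksum": True},
#   )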


@app.post("/v1/abort_requests")
async def abort_requests(request: Request):
    body = await request.json()
    abort_all = body.get("abort_all", False)
    req_ids = body.get("req_ids", None)

    # Parameter validation
    if not abort_all and not req_ids:
        return JSONResponse(status_code=400, content={"error": "must provide abort_all=true or req_ids"})

    control_request = ControlRequest(
        request_id=f"control-{uuid.uuid4()}",
        method="abort_requests",
        args={"abort_all": abort_all, "req_ids": req_ids or []},
    )
    control_response = await app.state.engine_client.run_control_method(control_request)
    return control_response.to_api_json_response()
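
# Illustrative requests (host/port and request ids are example values):
#   import requests
#   requests.post("http://localhost:8000/v1/abort_requests", json={"req_ids": ["chatcmpl-123"]})
#   requests.post("http://localhost:8000/v1/abort_requests", json={"abort_all": True})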


def wrap_streaming_generator(original_generator: AsyncGenerator):
    """
    Wrap an async generator to release the connection semaphore when the generator is finished.
    """

    async def wrapped_generator():
        span = trace.get_current_span()
        if span is not None and span.is_recording():
            last_time = None
            count = 0
            try:
                async for chunk in original_generator:
                    last_time = time.time()
                    # Capture the first chunk
                    if count == 0 and span is not None and span.is_recording():
                        last_time = time.time()
                        span.add_event("first_chunk", {"time": last_time})
                    count += 1
                    yield chunk
            except Exception as e:
                # Capture stream errors
                if span is not None and span.is_recording():
                    span.add_event("stream_error", {"time": time.time(), "error": str(e), "total_chunk": count})
                    span.record_exception(e)
                    span.set_status({"code": "ERROR", "description": str(e)})
                raise
            finally:
                # Capture the last chunk
                if span is not None and span.is_recording() and count > 0:
                    span.add_event("last_chunk", {"time": last_time, "total_chunk": count})
                api_server_logger.debug(f"release: {connection_semaphore.status()}")
                connection_semaphore.release()
        else:
            try:
                async for chunk in original_generator:
                    yield chunk
            finally:
                api_server_logger.debug(f"release: {connection_semaphore.status()}")
                connection_semaphore.release()

    return wrapped_generator


@app.post("/v1/chat/completions")
@with_cancellation
async def create_chat_completion(request: ChatCompletionRequest, req: Request):
    """
    Create a chat completion for the provided prompt and parameters.
    """
    api_server_logger.debug(f"Chat Received request: {request.model_dump_json()}")
    if envs.TRACES_ENABLE:
        if req.headers:
            headers = dict(req.headers)
            trace_context = extract(headers)
            request.trace_context = trace_context
    if app.state.dynamic_load_weight:
        status, msg = app.state.engine_client.is_workers_alive()
        if not status:
            return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)
    try:
        async with connection_manager():
            tracing.label_span(request)
            generator = await app.state.chat_handler.create_chat_completion(request)
            if isinstance(generator, ErrorResponse):
                api_server_logger.debug(f"release: {connection_semaphore.status()}")
                connection_semaphore.release()
                return JSONResponse(content=generator.model_dump(), status_code=500)
            elif isinstance(generator, ChatCompletionResponse):
                api_server_logger.debug(f"release: {connection_semaphore.status()}")
                connection_semaphore.release()
                return JSONResponse(content=generator.model_dump())
            else:
                wrapped_generator = wrap_streaming_generator(generator)
                return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream")

    except HTTPException as e:
        api_server_logger.error(f"Error in chat completion: {str(e)}")
        return JSONResponse(status_code=e.status_code, content={"detail": e.detail})
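
# Illustrative OpenAI-compatible client call (model name, host, and port are example
# values; any name configured via --served-model-name would be used instead):
#   from openai import OpenAI
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
#   resp = client.chat.completions.create(
#       model="default",
#       messages=[{"role": "user", "content": "Hello"}],
#       stream=False,
#   )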


@app.post("/v1/completions")
@with_cancellation
async def create_completion(request: CompletionRequest, req: Request):
    """
    Create a completion for the provided prompt and parameters.
    """
    api_server_logger.info(f"Completion Received request: {request.model_dump_json()}")
    if envs.TRACES_ENABLE:
        if req.headers:
            headers = dict(req.headers)
            trace_context = extract(headers)
            request.trace_context = trace_context
    if app.state.dynamic_load_weight:
        status, msg = app.state.engine_client.is_workers_alive()
        if not status:
            return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)
    try:
        async with connection_manager():
            tracing.label_span(request)
            generator = await app.state.completion_handler.create_completion(request)
            if isinstance(generator, ErrorResponse):
                connection_semaphore.release()
                return JSONResponse(content=generator.model_dump(), status_code=500)
            elif isinstance(generator, CompletionResponse):
                connection_semaphore.release()
                return JSONResponse(content=generator.model_dump())
            else:
                wrapped_generator = wrap_streaming_generator(generator)
                return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream")
    except HTTPException as e:
        return JSONResponse(status_code=e.status_code, content={"detail": e.detail})


@app.get("/v1/models")
async def list_models() -> Response:
    """
    List all available models.
    """
    if app.state.dynamic_load_weight:
        status, msg = app.state.engine_client.is_workers_alive()
        if not status:
            return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)

    models = await app.state.model_handler.list_models()
    if isinstance(models, ErrorResponse):
        return JSONResponse(content=models.model_dump())
    elif isinstance(models, ModelList):
        return JSONResponse(content=models.model_dump())


@app.post("/v1/reward")
async def create_reward(request: ChatRewardRequest):
    """
    Create reward for the input texts.
    """
    if app.state.dynamic_load_weight:
        status, msg = app.state.engine_client.is_workers_alive()
        if not status:
            return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)

    generator = await app.state.reward_handler.create_reward(request)
    return JSONResponse(content=generator.model_dump())


@app.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest):
    """
    Create embeddings for the input texts.
    """
    if app.state.dynamic_load_weight:
        status, msg = app.state.engine_client.is_workers_alive()
        if not status:
            return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)

    generator = await app.state.embedding_handler.create_embedding(request)
    return JSONResponse(content=generator.model_dump())
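
# Illustrative OpenAI-compatible embedding call (model name, host, and port are
# example values):
#   from openai import OpenAI
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
#   emb = client.embeddings.create(model="default", input=["hello world"])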


@app.get("/update_model_weight")
@tracing.trace_span("update_model_weight")
def update_model_weight(request: Request) -> Response:
    """
    update model weight
    """
    if app.state.dynamic_load_weight:
        if envs.FD_ENABLE_V1_UPDATE_WEIGHTS:
            request_id = f"control-{uuid.uuid4()}"
            control_request = ControlRequest(request_id, "wakeup")
            control_response = app.state.engine_client.run_control_method_sync(control_request, app.state.event_loop)
            return control_response.to_api_json_response()
        else:
            status_code, msg = app.state.engine_client.update_model_weight()
            return JSONResponse(content=msg, status_code=status_code)
    else:
        return JSONResponse(content={"error": "Dynamic Load Weight Disabled."}, status_code=404)


@app.get("/clear_load_weight")
@tracing.trace_span("clear_load_weight")
def clear_load_weight(request: Request) -> Response:
    """
    clear model weight
    """
    if app.state.dynamic_load_weight:
        if envs.FD_ENABLE_V1_UPDATE_WEIGHTS:
            request_id = f"control-{uuid.uuid4()}"
            control_request = ControlRequest(request_id, "sleep")
            control_response = app.state.engine_client.run_control_method_sync(control_request, app.state.event_loop)
            return control_response.to_api_json_response()
        else:
            status_code, msg = app.state.engine_client.clear_load_weight()
            return JSONResponse(content=msg, status_code=status_code)
    else:
        return JSONResponse(content={"error": "Dynamic Load Weight Disabled."}, status_code=404)


@app.post("/rearrange_experts")
@tracing.trace_span("rearrange_experts")
async def rearrange_experts(request: Request):
    """
    rearrange experts
    """
    request_dict = await request.json()
    content, status_code = await app.state.engine_client.rearrange_experts(request_dict=request_dict)
    return JSONResponse(content, status_code=status_code)


@app.post("/get_per_expert_tokens_stats")
@tracing.trace_span("get_per_expert_tokens_stats")
async def get_per_expert_tokens_stats(request: Request):
    """
    get per expert tokens stats
    """
    request_dict = await request.json()
    content, status_code = await app.state.engine_client.get_per_expert_tokens_stats(request_dict=request_dict)
    return JSONResponse(content, status_code=status_code)


@app.post("/check_redundant")
@tracing.trace_span("check_redundant")
async def check_redundant(request: Request):
    """
    check redundant
    """
    request_dict = await request.json()
    content, status_code = await app.state.engine_client.check_redundant(request_dict=request_dict)
    return JSONResponse(content, status_code=status_code)


def launch_api_server() -> None:
    """
    Launch the HTTP service.
    """
    if not is_port_available(args.host, args.port):
        raise Exception(f"The parameter `port`:{args.port} is already in use.")

    api_server_logger.info(f"launch Fastdeploy api server... port: {args.port}")
    api_server_logger.info(f"args: {args.__dict__}")
    # fd_start_span("FD_START")

    # set control_socket_disable=True to avoid conflicts when running multiple instances
    options = {
        "bind": f"{args.host}:{args.port}",
        "workers": args.workers,
        "worker_class": "uvicorn.workers.UvicornWorker",
        "loglevel": "info",
        "graceful_timeout": args.timeout_graceful_shutdown,
        "timeout": args.timeout,
        "control_socket_disable": True,
    }

    try:
        StandaloneApplication(app, options).run()
    except Exception as e:
        api_server_logger.error(f"launch sync http server error, {e}, {str(traceback.format_exc())}")


metrics_app = FastAPI()

# Be tolerant to tests that monkeypatch/partially mock args.
_metrics_port = getattr(args, "metrics_port", None)
_main_port = getattr(args, "port", None)

if _metrics_port is None or (_main_port is not None and _metrics_port == _main_port):
    metrics_app = app


@metrics_app.get("/metrics")
@tracing.trace_span("metrics")
async def metrics():
    """
    metrics
    """
    metrics_text = get_filtered_metrics()
    return Response(metrics_text, media_type="text/plain")


@metrics_app.get("/config-info")
@tracing.trace_span("config-info")
def config_info() -> Response:
    """
    Get the current configuration of the API server.
    """
    global llm_engine
    if llm_engine is None:
        return Response("Engine not loaded", status_code=500)
    cfg = llm_engine.cfg

    def process_object(obj):
        if hasattr(obj, "__dict__"):
            return obj.__dict__
        if isinstance(obj, (set, frozenset)):
            return list(obj)
        return str(obj)

    cfg_dict = {k: v for k, v in cfg.__dict__.items()}

    # Version info
    cfg_dict["version_info"] = get_version_info()

    # Chat template
    cfg_dict["chat_template"] = chat_template

    # Server config from args
    cfg_dict["server_config"] = {
        "host": args.host,
        "port": args.port,
        "workers": args.workers,
        "metrics_port": args.metrics_port,
        "controller_port": args.controller_port,
        "max_concurrency": args.max_concurrency,
        "max_waiting_time": args.max_waiting_time,
        "timeout": args.timeout,
        "timeout_graceful_shutdown": args.timeout_graceful_shutdown,
        "served_model_name": args.served_model_name,
        "task": args.task,
        "model_config_name": args.model_config_name,
        "tokenizer_base_url": args.tokenizer_base_url,
        "enable_mm_output": args.enable_mm_output,
        "tool_call_parser": args.tool_call_parser,
        "tool_parser_plugin": args.tool_parser_plugin,
    }

    # GPU info
    try:
        import paddle

        from fastdeploy.platforms import current_platform

        device_info = {}
        device_info["device_type"] = current_platform.device_name
        device_info["device_count"] = paddle.device.cuda.device_count()
        device_ids = str(cfg.parallel_config.device_ids).split(",") if cfg.parallel_config else ["0"]
        first_device = int(device_ids[0].strip()) - 1
        props = paddle.device.cuda.get_device_properties(first_device)
        device_info["device_name"] = props.name
        device_info["device_total_memory"] = props.total_memory
        device_info["device_multi_processor_count"] = props.multi_processor_count
        device_info["device_major"] = props.major
        device_info["device_minor"] = props.minor
        cfg_dict["device_info"] = device_info
    except Exception:
        cfg_dict["device_info"] = None

    env_dict = {k: v() for k, v in environment_variables.items()}
    cfg_dict["env_config"] = env_dict
    result_content = json.dumps(cfg_dict, default=process_object, ensure_ascii=False)
    return Response(result_content, media_type="application/json")


def run_metrics_server():
    """
    run metrics server
    """

    uvicorn.run(metrics_app, host="0.0.0.0", port=args.metrics_port, log_config=UVICORN_CONFIG, log_level="error")


def launch_metrics_server():
    """Run the metrics server in a sub thread."""
    if not is_port_available(args.host, args.metrics_port):
        raise Exception(f"The parameter `metrics_port`:{args.metrics_port} is already in use.")

    metrics_server_thread = threading.Thread(target=run_metrics_server, daemon=True)
    metrics_server_thread.start()
    time.sleep(1)


controller_app = FastAPI()


@controller_app.post("/controller/reset_scheduler")
def reset_scheduler():
    """
    reset scheduler
    """
    global llm_engine

    if llm_engine is None:
        return Response("Engine not loaded", status_code=500)

    llm_engine.engine.clear_data()
    llm_engine.engine.scheduler.reset()
    return Response("Scheduler Reset Successfully", status_code=200)


@controller_app.post("/controller/scheduler")
def control_scheduler(request: ControlSchedulerRequest):
    """
    Control the scheduler behavior with the given parameters.
    """

    content = ErrorResponse(error=ErrorInfo(message="Scheduler updated successfully", code="0"))

    global llm_engine
    if llm_engine is None:
        content.message = "Engine is not loaded"
        content.code = 500
        return JSONResponse(content=content.model_dump(), status_code=500)

    if request.reset:
        llm_engine.engine.clear_data()
        llm_engine.engine.scheduler.reset()

    if request.load_shards_num or request.reallocate_shard:
        if hasattr(llm_engine.engine.scheduler, "update_config") and callable(
            llm_engine.engine.scheduler.update_config
        ):
            llm_engine.engine.scheduler.update_config(
                load_shards_num=request.load_shards_num,
                reallocate=request.reallocate_shard,
            )
        else:
            content.message = "This scheduler doesn't support the `update_config()` method."
            content.code = 400
            return JSONResponse(content=content.model_dump(), status_code=400)

    return JSONResponse(content=content.model_dump(), status_code=200)
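
# Illustrative controller call (host and controller port are example values; the body
# fields mirror the ControlSchedulerRequest fields used above):
#   import requests
#   requests.post(
#       "http://localhost:9000/controller/scheduler",
#       json={"reset": True, "load_shards_num": 2, "reallocate_shard": False},
#   )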


def run_controller_server():
    """
    run controller server
    """
    uvicorn.run(
        controller_app,
        host="0.0.0.0",
        port=args.controller_port,
        log_config=UVICORN_CONFIG,
        log_level="error",
    )


def launch_controller_server():
    """Run the controller server in a sub thread."""
    if args.controller_port < 0:
        return

    if not is_port_available(args.host, args.controller_port):
        raise Exception(f"The parameter `controller_port`:{args.controller_port} is already in use.")

    controller_server_thread = threading.Thread(target=run_controller_server, daemon=True)
    controller_server_thread.start()
    time.sleep(1)


def launch_worker_monitor():
    """
    Detect whether the worker process is alive; if not, stop the API server by signaling this process.
    """

    def _monitor():
        global llm_engine
        while True:
            if hasattr(llm_engine, "worker_proc") and llm_engine.worker_proc.poll() is not None:
                console_logger.error(
                    f"Worker process has died in the background (code={llm_engine.worker_proc.returncode}). API server is forced to stop."
                )
                os.kill(os.getpid(), signal.SIGINT)
                break
            time.sleep(5)

    worker_monitor_thread = threading.Thread(target=_monitor, daemon=True)
    worker_monitor_thread.start()
    time.sleep(1)


def main():
    """Main entry point."""

    if args.ips and get_host_ip() not in args.ips:
        api_server_logger.error(f"Worker IP {get_host_ip()} not in the list of allowed IPs {args.ips}.")
        return

    if args.local_data_parallel_id == 0:
        if not load_engine():
            return
    else:
        if not load_data_service():
            return
    api_server_logger.info("FastDeploy LLM engine initialized!\n")
    if args.metrics_port is not None and args.metrics_port != args.port:
        launch_metrics_server()
        console_logger.info(f"Launching metrics service at http://{args.host}:{args.metrics_port}/metrics")
    else:
        console_logger.info(f"Launching metrics service at http://{args.host}:{args.port}/metrics")
    console_logger.info(f"Launching chat completion service at http://{args.host}:{args.port}/v1/chat/completions")
    console_logger.info(f"Launching completion service at http://{args.host}:{args.port}/v1/completions")

    launch_worker_monitor()
    launch_controller_server()
    launch_api_server()


if __name__ == "__main__":
    main()
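
# Illustrative launch sketch (module path and flag spellings are assumptions derived
# from the argument names used in this file; values are examples only):
#   python -m fastdeploy.entrypoints.openai.api_server \
#       --model <model_path_or_name> --port 8000 --metrics-port 8001 \
#       --workers 1 --max-concurrency 256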