mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 08:21:53 +08:00
[Optimization] refactor(chat_handler,completion_handler): extract base classes and use AsyncLLM (#5195)
* [Optimization] refactor(chat_handler,completion_handler): extract base classes and use AsyncLLM * [Optimization] refactor(chat_handler,completion_handler): rename class
This commit is contained in:
@@ -25,13 +25,14 @@ from typing import Any, ClassVar, Generic, Optional, TypeVar, Union
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import override

from fastdeploy.engine.request import PoolingRequestOutput, RequestOutput
from fastdeploy.engine.request import RequestOutput
from fastdeploy.entrypoints.openai.protocol import (
    ErrorInfo,
    ErrorResponse,
    InvalidParameterException,
)
from fastdeploy.utils import ErrorCode, ErrorType, api_server_logger
from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
from fastdeploy.utils import ErrorCode, ErrorType, StatefulSemaphore, api_server_logger

RequestT = TypeVar("RequestT")

@@ -46,8 +47,6 @@ class ServeContext(
    request_id: str
    created_time: int = Field(default_factory=lambda: int(time.time()))
    preprocess_requests: Optional[list[dict]] = None
    request_output: Optional[Union[RequestOutput, PoolingRequestOutput]] = None

    # `protected_namespaces` resolves Pydantic v2's warning
    # on conflict with protected namespace "model_"
    model_config = ConfigDict(
@@ -62,8 +61,7 @@ class OpenAIServing(ABC, Generic[RequestT]):
    Base pipeline for OpenAI-style serving implementations
    """

    def __init__(self, engine_client, models, cfg, pid, ips, max_waiting_time):
        self.engine_client = engine_client
    def __init__(self, models, cfg, pid, ips, max_waiting_time):
        self.models = models
        self.cfg = cfg
        self.pid = pid
@@ -77,12 +75,18 @@ class OpenAIServing(ABC, Generic[RequestT]):
            self.master_ip = ips.split(",")[0]
        else:
            self.master_ip = "0.0.0.0"

        self.__semaphore = None
        api_server_logger.info(f"master ip: {self.master_ip}")

    def _get_semaphore(self) -> StatefulSemaphore:
        if self.__semaphore is None:
            workers = 1
            self.__semaphore = StatefulSemaphore((FD_SUPPORT_MAX_CONNECTIONS + workers - 1) // workers)
        return self.__semaphore

    def _check_master(self) -> bool:
        """Check if current node is master"""
        return self.engine_client.is_master
        return True

    def _check_supported_model(self, model_name: str) -> tuple[bool, str]:
        """Check if model is supported and return adjusted model name"""
@@ -97,11 +101,11 @@ class OpenAIServing(ABC, Generic[RequestT]):
    async def _acquire_semaphore(self, request_id: str) -> bool:
        """Acquire engine client semaphore with timeout"""
        try:
            api_server_logger.info(f"Acquire request:{request_id} status:{self.engine_client.semaphore.status()}")
            api_server_logger.info(f"Acquire request:{request_id} status:{self._get_semaphore().status()}")
            if self.max_waiting_time < 0:
                await self.engine_client.semaphore.acquire()
                await self._get_semaphore().acquire()
            else:
                await asyncio.wait_for(self.engine_client.semaphore.acquire(), timeout=self.max_waiting_time)
                await asyncio.wait_for(self._get_semaphore().acquire(), timeout=self.max_waiting_time)
            return True
        except asyncio.TimeoutError:
            self._release_semaphore(request_id)
@@ -111,8 +115,8 @@ class OpenAIServing(ABC, Generic[RequestT]):

    def _release_semaphore(self, request_id: str) -> None:
        """Release engine client semaphore"""
        self.engine_client.semaphore.release()
        api_server_logger.info(f"Release request:{request_id} status:{self.engine_client.semaphore.status()}")
        self._get_semaphore().release()
        api_server_logger.info(f"Release request:{request_id} status:{self._get_semaphore().status()}")

    def _create_error_response(
        self,
@@ -126,8 +130,9 @@ class OpenAIServing(ABC, Generic[RequestT]):
        api_server_logger.error(message)
        return ErrorResponse(error=ErrorInfo(message=message, type=error_type, code=code, param=param))

    def _generate_request_id(self, user: Optional[str] = None) -> str:
    def _generate_request_id(self, request: RequestT) -> str:
        """Generate a unique request ID"""
        user = getattr(request, "user", None)
        if user is not None:
            return f"{self.request_id_prefix}-{user}-{uuid.uuid4()}"
        return f"{self.request_id_prefix}-{uuid.uuid4()}"
@@ -142,13 +147,13 @@ class OpenAIServing(ABC, Generic[RequestT]):
        pass

    @abstractmethod
    async def _prepare_generators(self, ctx: ServeContext) -> Any:
    async def _prepare_generators(self, ctx: ServeContext) -> AsyncGenerator[Any]:
        """Process engine response into final format"""
        # 此函数是一个异步方法,用于处理引擎响应并将其转换为最终格式
        pass

    @abstractmethod
    def _build_response(self, ctx: ServeContext) -> Any:
    def _build_response(self, ctx: ServeContext, request_output: dict | RequestOutput) -> Any:
        """Generate the final response object"""
        pass

@@ -185,7 +190,8 @@ class OpenAIServing(ABC, Generic[RequestT]):
            # Step 1.3: Validate request
            self._validate_request(ctx)

            request_id = self._generate_request_id(getattr(request, "user", None))
            request_id = self._generate_request_id(request)
            ctx.request_id = request_id
            api_server_logger.info(f"Initialize request {request_id}: {request}")

            # Step 2: Semaphore acquisition
@@ -201,8 +207,7 @@ class OpenAIServing(ABC, Generic[RequestT]):

            # Step 5: Final response build
            async for request_output in generators:
                ctx.request_output = request_output
                yield self._build_response(ctx)
                yield self._build_response(ctx, request_output)

        except InvalidParameterException as e:
            traceback.print_exc()
@@ -220,8 +225,9 @@ class ZmqOpenAIServing(OpenAIServing):
    """

    def __init__(self, engine_client, models, cfg, pid, ips, max_waiting_time, chat_template):
        super().__init__(engine_client, models, cfg, pid, ips, max_waiting_time)
        super().__init__(models, cfg, pid, ips, max_waiting_time)
        self.chat_template = chat_template
        self.engine_client = engine_client

    def _request_to_dict(self, ctx: ServeContext):
        request = ctx.request
@@ -286,3 +292,33 @@ class ZmqOpenAIServing(OpenAIServing):
            raise ValueError(f"Error processing response: {str(e)}")
        finally:
            await self.engine_client.connection_manager.cleanup_request(request_id)

    @override
    def _get_semaphore(self):
        return self.engine_client.semaphore

    @override
    async def _acquire_semaphore(self, request_id: str) -> bool:
        """Acquire engine client semaphore with timeout"""
        try:
            api_server_logger.info(f"Acquire request:{request_id} status:{self._get_semaphore().status()}")
            if self.max_waiting_time < 0:
                await self._get_semaphore().acquire()
            else:
                await asyncio.wait_for(self._get_semaphore().acquire(), timeout=self.max_waiting_time)
            return True
        except asyncio.TimeoutError:
            self._release_semaphore(request_id)
            error_msg = f"Request waiting timeout, request:{request_id} max waiting time:{self.max_waiting_time}"
            api_server_logger.error(error_msg)
            return False

    @override
    def _release_semaphore(self, request_id: str) -> None:
        """Release engine client semaphore"""
        self._get_semaphore().release()
        api_server_logger.info(f"Release request:{request_id} status:{self._get_semaphore().status()}")

    @override
    def _check_master(self) -> bool:
        return self.engine_client.is_master

Reference in New Issue
Block a user