[Optimization] refactor(chat_handler,completion_handler): extract base classes and use AsyncLLM (#5195)

* [Optimization] refactor(chat_handler,completion_handler): extract base classes and use AsyncLLM

* [Optimization] refactor(chat_handler,completion_handler): rename class
This commit is contained in:
memoryCoderC
2025-12-25 16:28:15 +08:00
committed by GitHub
parent 8fc789bb3f
commit be3be4913a
19 changed files with 3601 additions and 66 deletions
+56 -20
View File
@@ -25,13 +25,14 @@ from typing import Any, ClassVar, Generic, Optional, TypeVar, Union
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import override
from fastdeploy.engine.request import PoolingRequestOutput, RequestOutput
from fastdeploy.engine.request import RequestOutput
from fastdeploy.entrypoints.openai.protocol import (
ErrorInfo,
ErrorResponse,
InvalidParameterException,
)
from fastdeploy.utils import ErrorCode, ErrorType, api_server_logger
from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
from fastdeploy.utils import ErrorCode, ErrorType, StatefulSemaphore, api_server_logger
RequestT = TypeVar("RequestT")
@@ -46,8 +47,6 @@ class ServeContext(
request_id: str
created_time: int = Field(default_factory=lambda: int(time.time()))
preprocess_requests: Optional[list[dict]] = None
request_output: Optional[Union[RequestOutput, PoolingRequestOutput]] = None
# `protected_namespaces` resolves Pydantic v2's warning
# on conflict with protected namespace "model_"
model_config = ConfigDict(
@@ -62,8 +61,7 @@ class OpenAIServing(ABC, Generic[RequestT]):
Base pipeline for OpenAI-style serving implementations
"""
def __init__(self, engine_client, models, cfg, pid, ips, max_waiting_time):
self.engine_client = engine_client
def __init__(self, models, cfg, pid, ips, max_waiting_time):
self.models = models
self.cfg = cfg
self.pid = pid
@@ -77,12 +75,18 @@ class OpenAIServing(ABC, Generic[RequestT]):
self.master_ip = ips.split(",")[0]
else:
self.master_ip = "0.0.0.0"
self.__semaphore = None
api_server_logger.info(f"master ip: {self.master_ip}")
def _get_semaphore(self) -> StatefulSemaphore:
if self.__semaphore is None:
workers = 1
self.__semaphore = StatefulSemaphore((FD_SUPPORT_MAX_CONNECTIONS + workers - 1) // workers)
return self.__semaphore
def _check_master(self) -> bool:
"""Check if current node is master"""
return self.engine_client.is_master
return True
def _check_supported_model(self, model_name: str) -> tuple[bool, str]:
"""Check if model is supported and return adjusted model name"""
@@ -97,11 +101,11 @@ class OpenAIServing(ABC, Generic[RequestT]):
async def _acquire_semaphore(self, request_id: str) -> bool:
"""Acquire engine client semaphore with timeout"""
try:
api_server_logger.info(f"Acquire request:{request_id} status:{self.engine_client.semaphore.status()}")
api_server_logger.info(f"Acquire request:{request_id} status:{self._get_semaphore().status()}")
if self.max_waiting_time < 0:
await self.engine_client.semaphore.acquire()
await self._get_semaphore().acquire()
else:
await asyncio.wait_for(self.engine_client.semaphore.acquire(), timeout=self.max_waiting_time)
await asyncio.wait_for(self._get_semaphore().acquire(), timeout=self.max_waiting_time)
return True
except asyncio.TimeoutError:
self._release_semaphore(request_id)
@@ -111,8 +115,8 @@ class OpenAIServing(ABC, Generic[RequestT]):
def _release_semaphore(self, request_id: str) -> None:
"""Release engine client semaphore"""
self.engine_client.semaphore.release()
api_server_logger.info(f"Release request:{request_id} status:{self.engine_client.semaphore.status()}")
self._get_semaphore().release()
api_server_logger.info(f"Release request:{request_id} status:{self._get_semaphore().status()}")
def _create_error_response(
self,
@@ -126,8 +130,9 @@ class OpenAIServing(ABC, Generic[RequestT]):
api_server_logger.error(message)
return ErrorResponse(error=ErrorInfo(message=message, type=error_type, code=code, param=param))
def _generate_request_id(self, user: Optional[str] = None) -> str:
def _generate_request_id(self, request: RequestT) -> str:
"""Generate a unique request ID"""
user = getattr(request, "user", None)
if user is not None:
return f"{self.request_id_prefix}-{user}-{uuid.uuid4()}"
return f"{self.request_id_prefix}-{uuid.uuid4()}"
@@ -142,13 +147,13 @@ class OpenAIServing(ABC, Generic[RequestT]):
pass
@abstractmethod
async def _prepare_generators(self, ctx: ServeContext) -> Any:
async def _prepare_generators(self, ctx: ServeContext) -> AsyncGenerator[Any]:
"""Process engine response into final format"""
# This function is an async method that processes the engine response and converts it into the final format
pass
@abstractmethod
def _build_response(self, ctx: ServeContext) -> Any:
def _build_response(self, ctx: ServeContext, request_output: dict | RequestOutput) -> Any:
"""Generate the final response object"""
pass
@@ -185,7 +190,8 @@ class OpenAIServing(ABC, Generic[RequestT]):
# Step 1.3: Validate request
self._validate_request(ctx)
request_id = self._generate_request_id(getattr(request, "user", None))
request_id = self._generate_request_id(request)
ctx.request_id = request_id
api_server_logger.info(f"Initialize request {request_id}: {request}")
# Step 2: Semaphore acquisition
@@ -201,8 +207,7 @@ class OpenAIServing(ABC, Generic[RequestT]):
# Step 5: Final response build
async for request_output in generators:
ctx.request_output = request_output
yield self._build_response(ctx)
yield self._build_response(ctx, request_output)
except InvalidParameterException as e:
traceback.print_exc()
@@ -220,8 +225,9 @@ class ZmqOpenAIServing(OpenAIServing):
"""
def __init__(self, engine_client, models, cfg, pid, ips, max_waiting_time, chat_template):
super().__init__(engine_client, models, cfg, pid, ips, max_waiting_time)
super().__init__(models, cfg, pid, ips, max_waiting_time)
self.chat_template = chat_template
self.engine_client = engine_client
def _request_to_dict(self, ctx: ServeContext):
request = ctx.request
@@ -286,3 +292,33 @@ class ZmqOpenAIServing(OpenAIServing):
raise ValueError(f"Error processing response: {str(e)}")
finally:
await self.engine_client.connection_manager.cleanup_request(request_id)
@override
def _get_semaphore(self):
return self.engine_client.semaphore
@override
async def _acquire_semaphore(self, request_id: str) -> bool:
"""Acquire engine client semaphore with timeout"""
try:
api_server_logger.info(f"Acquire request:{request_id} status:{self._get_semaphore().status()}")
if self.max_waiting_time < 0:
await self._get_semaphore().acquire()
else:
await asyncio.wait_for(self._get_semaphore().acquire(), timeout=self.max_waiting_time)
return True
except asyncio.TimeoutError:
self._release_semaphore(request_id)
error_msg = f"Request waiting timeout, request:{request_id} max waiting time:{self.max_waiting_time}"
api_server_logger.error(error_msg)
return False
@override
def _release_semaphore(self, request_id: str) -> None:
"""Release engine client semaphore"""
self._get_semaphore().release()
api_server_logger.info(f"Release request:{request_id} status:{self._get_semaphore().status()}")
@override
def _check_master(self) -> bool:
return self.engine_client.is_master