[Feature] Support stopping the inference for the corresponding request in the online service after a disconnection request. (#5320)

* request disconnect

* request disconnect

* fix bug

* fix bug--amend

---------

Co-authored-by: root <root@yq01-sys-rpm26xc1knu.yq01.baidu.com>
This commit is contained in:
qwes5s5
2026-01-16 11:46:13 +08:00
committed by GitHub
parent 8f035101ad
commit b2a2e11551
25 changed files with 1339 additions and 63 deletions
@@ -444,33 +444,35 @@ class PrefixCacheManager:
else: else:
return True return True
def allocate_gpu_blocks(self, num_blocks): def allocate_gpu_blocks(self, num_blocks, req_id=None):
""" """
allocate gpu blocks. allocate gpu blocks.
""" """
assert num_blocks <= len( assert num_blocks <= len(
self.gpu_free_block_list self.gpu_free_block_list
), f"gpu free block num: {len(self.gpu_free_block_list)} < needed number {num_blocks}" ), f"gpu free block num: {len(self.gpu_free_block_list)} < needed number {num_blocks}"
logger.debug(f"{req_id} start allocate...")
allocated_block_ids = [heapq.heappop(self.gpu_free_block_list) for i in range(num_blocks)] allocated_block_ids = [heapq.heappop(self.gpu_free_block_list) for i in range(num_blocks)]
logger.info( logger.info(
f"allocate_gpu_blocks: {allocated_block_ids}, len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}" f"req_id:{req_id} allocate_gpu_blocks: {allocated_block_ids}, len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}"
) )
main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list)) main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list))
main_process_metrics.available_gpu_resource.set(self.available_gpu_resource) main_process_metrics.available_gpu_resource.set(self.available_gpu_resource)
return allocated_block_ids return allocated_block_ids
def recycle_gpu_blocks(self, gpu_block_ids): def recycle_gpu_blocks(self, gpu_block_ids, req_id=None):
""" """
recycle gpu blocks. recycle gpu blocks.
""" """
logger.info( logger.info(
f"recycle_gpu_blocks: {gpu_block_ids}, len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}" f"req_id:{req_id} recycle_gpu_blocks: {gpu_block_ids}, len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}"
) )
if isinstance(gpu_block_ids, list): if isinstance(gpu_block_ids, list):
for gpu_block_id in gpu_block_ids: for gpu_block_id in gpu_block_ids:
heapq.heappush(self.gpu_free_block_list, gpu_block_id) heapq.heappush(self.gpu_free_block_list, gpu_block_id)
else: else:
heapq.heappush(self.gpu_free_block_list, gpu_block_ids) heapq.heappush(self.gpu_free_block_list, gpu_block_ids)
logger.debug(f"req_id:{req_id} recycle blocks end")
main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list)) main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list))
main_process_metrics.available_gpu_resource.set(self.available_gpu_resource) main_process_metrics.available_gpu_resource.set(self.available_gpu_resource)
@@ -978,7 +980,7 @@ class PrefixCacheManager:
logger.info(f"release_block_ids: req_id {req_id} leaf_node {leaf_node}") logger.info(f"release_block_ids: req_id {req_id} leaf_node {leaf_node}")
if leaf_node == self.radix_tree_root: if leaf_node == self.radix_tree_root:
self.recycle_gpu_blocks(self.unfilled_req_block_map[req_id]) self.recycle_gpu_blocks(self.unfilled_req_block_map[req_id], req_id)
del self.unfilled_req_block_map[req_id] del self.unfilled_req_block_map[req_id]
return return
+21 -1
View File
@@ -38,7 +38,7 @@ import zmq
from tqdm import tqdm from tqdm import tqdm
import fastdeploy.metrics.trace as tracing import fastdeploy.metrics.trace as tracing
from fastdeploy.engine.request import Request, RequestOutput, RequestType from fastdeploy.engine.request import Request, RequestOutput, RequestStatus, RequestType
from fastdeploy.engine.resource_manager import ResourceManager from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1 from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
from fastdeploy.eplb.utils import init_eplb_signals from fastdeploy.eplb.utils import init_eplb_signals
@@ -721,6 +721,7 @@ class EngineService:
max_num_batched_tokens=self.cfg.scheduler_config.max_num_batched_tokens, max_num_batched_tokens=self.cfg.scheduler_config.max_num_batched_tokens,
batch=num_prefill_batch, batch=num_prefill_batch,
) )
tasks = [task for task in tasks if task.request_id not in self.resource_manager.abort_req_ids_set]
for task in tasks: for task in tasks:
task.metrics.engine_get_req_time = time.time() task.metrics.engine_get_req_time = time.time()
trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", "")) trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", ""))
@@ -787,6 +788,7 @@ class EngineService:
max_num_batched_tokens=max_num_batched_tokens, max_num_batched_tokens=max_num_batched_tokens,
batch=num_prefill_batch, batch=num_prefill_batch,
) )
tasks = [task for task in tasks if task.request_id not in self.resource_manager.abort_req_ids_set]
for task in tasks: for task in tasks:
task.metrics.engine_get_req_time = time.time() task.metrics.engine_get_req_time = time.time()
trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", "")) trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", ""))
@@ -1074,6 +1076,24 @@ class EngineService:
request, insert_task = None, [] request, insert_task = None, []
results: List[Tuple[str, Optional[str]]] = list() results: List[Tuple[str, Optional[str]]] = list()
if data: if data:
status_value = data.get("status", None)
if status_value is not None and status_value == RequestStatus.ABORT.value:
req_id = data["request_id"]
self.llm_logger.info(f"Receive abort request, req_id: {req_id}")
self.resource_manager.abort_req_ids_set.add(req_id)
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if req_id in self.resource_manager.requests:
req = self.resource_manager.requests[req_id]
task = self.resource_manager._prepare_preempt_task(req)
self.engine_worker_queue.put_tasks(([task], self.resource_manager.real_bsz))
self.llm_logger.info(f"put abort task in engine worker queue, req_id: {req_id}")
else:
self.scheduler._recycle(req_id)
self.llm_logger.info(
f"req_id:{req_id} has not been allocated any resources, recycled it in scheduler"
)
self.resource_manager.abort_req_ids_set.remove(req_id)
continue
err_msg = None err_msg = None
try: try:
request = Request.from_dict(data) request = Request.from_dict(data)
+1
View File
@@ -43,6 +43,7 @@ class RequestStatus(Enum):
RUNNING = 1 RUNNING = 1
PREEMPTED = 2 PREEMPTED = 2
FINISHED = 3 FINISHED = 3
ABORT = 4
class RequestType(Enum): class RequestType(Enum):
+1
View File
@@ -58,6 +58,7 @@ class ResourceManager:
self.req_dict = dict() self.req_dict = dict()
# current batch status of the engine # current batch status of the engine
self.real_bsz = 0 self.real_bsz = 0
self.abort_req_ids_set = set()
llm_logger.info(f"{self.info()}") llm_logger.info(f"{self.info()}")
main_process_metrics.max_batch_size.set(max_num_seqs) main_process_metrics.max_batch_size.set(max_num_seqs)
+42 -16
View File
@@ -175,6 +175,7 @@ class ResourceManagerV1(ResourceManager):
self.using_extend_tables_req_id = set() self.using_extend_tables_req_id = set()
self.reuse_block_num_map = dict() self.reuse_block_num_map = dict()
self.abort_req_ids_set = set()
# need block nums # need block nums
need_block_num_data = np.zeros([max_num_seqs], dtype=np.int32) need_block_num_data = np.zeros([max_num_seqs], dtype=np.int32)
@@ -243,6 +244,7 @@ class ResourceManagerV1(ResourceManager):
request = self.requests[request_id] request = self.requests[request_id]
if process_func is not None: if process_func is not None:
process_func(request) process_func(request)
llm_logger.debug(f"self.waiting append request:{request.request_id},req.type:{request.status}")
self.waiting.appendleft(request) self.waiting.appendleft(request)
self.to_be_rescheduled_request_id_set.remove(request_id) self.to_be_rescheduled_request_id_set.remove(request_id)
@@ -640,7 +642,9 @@ class ResourceManagerV1(ResourceManager):
f"schedule decoding task: {request} request.num_total_tokens {request.num_total_tokens} request.num_computed_tokens {request.num_computed_tokens}" f"schedule decoding task: {request} request.num_total_tokens {request.num_total_tokens} request.num_computed_tokens {request.num_computed_tokens}"
) )
request.block_tables.extend( request.block_tables.extend(
self.cache_manager.allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num) self.cache_manager.allocate_gpu_blocks(
self.config.cache_config.enc_dec_block_num, request.request_id
)
) )
# Prepare decoding task # Prepare decoding task
scheduled_reqs.append(self._prepare_decode_task(request)) scheduled_reqs.append(self._prepare_decode_task(request))
@@ -653,7 +657,9 @@ class ResourceManagerV1(ResourceManager):
break break
# Allocation for next decoding blocks # Allocation for next decoding blocks
request.block_tables.extend( request.block_tables.extend(
self.cache_manager.allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num) self.cache_manager.allocate_gpu_blocks(
self.config.cache_config.enc_dec_block_num, request.request_id
)
) )
# Prepare decoding task # Prepare decoding task
scheduled_reqs.append(self._prepare_decode_task(request)) scheduled_reqs.append(self._prepare_decode_task(request))
@@ -668,7 +674,9 @@ class ResourceManagerV1(ResourceManager):
def _allocate_decode_and_extend(): def _allocate_decode_and_extend():
allocate_block_num = self.need_block_num_map[request.request_id].consume() allocate_block_num = self.need_block_num_map[request.request_id].consume()
# Prepare decoding task # Prepare decoding task
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(allocate_block_num)) request.block_tables.extend(
self.cache_manager.allocate_gpu_blocks(allocate_block_num, request.request_id)
)
scheduled_reqs.append(self._prepare_decode_task(request)) scheduled_reqs.append(self._prepare_decode_task(request))
# Prepare extend task # Prepare extend task
@@ -682,7 +690,7 @@ class ResourceManagerV1(ResourceManager):
request.extend_block_tables = request.block_tables[:reuse_block_num] # copy prompt cache request.extend_block_tables = request.block_tables[:reuse_block_num] # copy prompt cache
request.extend_block_tables.extend( request.extend_block_tables.extend(
self.cache_manager.allocate_gpu_blocks(allocate_block_num) self.cache_manager.allocate_gpu_blocks(allocate_block_num, request.request_id)
) )
scheduled_reqs.append( scheduled_reqs.append(
ScheduledExtendBlocksTask( ScheduledExtendBlocksTask(
@@ -733,14 +741,18 @@ class ResourceManagerV1(ResourceManager):
num_new_block = self.get_new_block_nums(request, num_new_tokens) num_new_block = self.get_new_block_nums(request, num_new_tokens)
# Allocate blocks to prefill # Allocate blocks to prefill
if self.cache_manager.can_allocate_gpu_blocks(num_new_block): if self.cache_manager.can_allocate_gpu_blocks(num_new_block):
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block)) request.block_tables.extend(
self.cache_manager.allocate_gpu_blocks(num_new_block, request.request_id)
)
# Prepare prefill task # Prepare prefill task
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens)) scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
else: # Not enough blocks to allocate, trigger preemption else: # Not enough blocks to allocate, trigger preemption
can_schedule = self._trigger_preempt(request, num_new_block, preempted_reqs, scheduled_reqs) can_schedule = self._trigger_preempt(request, num_new_block, preempted_reqs, scheduled_reqs)
if not can_schedule: if not can_schedule:
break break
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block)) request.block_tables.extend(
self.cache_manager.allocate_gpu_blocks(num_new_block, request.request_id)
)
# Prepare prefill task # Prepare prefill task
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens)) scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
token_budget -= num_new_tokens token_budget -= num_new_tokens
@@ -806,7 +818,9 @@ class ResourceManagerV1(ResourceManager):
# Allocate blocks to prefill # Allocate blocks to prefill
if self.cache_manager.can_allocate_gpu_blocks(can_schedule_block_num_threshold): if self.cache_manager.can_allocate_gpu_blocks(can_schedule_block_num_threshold):
if not request.get("skip_allocate", False): if not request.get("skip_allocate", False):
extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(num_new_block) extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(
num_new_block, request.request_id
)
request.block_tables.extend(extra_gpu_block_ids) request.block_tables.extend(extra_gpu_block_ids)
self.waiting.popleft() self.waiting.popleft()
self.running.append(request) self.running.append(request)
@@ -824,6 +838,7 @@ class ResourceManagerV1(ResourceManager):
self.tasks_list[allocated_position] = request self.tasks_list[allocated_position] = request
self.stop_flags[allocated_position] = False self.stop_flags[allocated_position] = False
self.req_dict[request.request_id] = allocated_position self.req_dict[request.request_id] = allocated_position
llm_logger.debug(f"req_id:{request.request_id} allocate pos end")
else: else:
if self.config.cache_config.enable_prefix_caching: if self.config.cache_config.enable_prefix_caching:
self._free_blocks(request) self._free_blocks(request)
@@ -856,7 +871,9 @@ class ResourceManagerV1(ResourceManager):
# Allocate blocks to prefill # Allocate blocks to prefill
if self.cache_manager.can_allocate_gpu_blocks(can_schedule_block_num_threshold): if self.cache_manager.can_allocate_gpu_blocks(can_schedule_block_num_threshold):
if not request.get("skip_allocate", False): if not request.get("skip_allocate", False):
extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(num_new_block) extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(
num_new_block, request.request_id
)
request.block_tables.extend(extra_gpu_block_ids) request.block_tables.extend(extra_gpu_block_ids)
self.waiting.popleft() self.waiting.popleft()
self.running.append(request) self.running.append(request)
@@ -873,7 +890,7 @@ class ResourceManagerV1(ResourceManager):
self._free_blocks(request) self._free_blocks(request)
break break
else: else:
llm_logger.error("Unknown request status type") llm_logger.info(f"Unknown request status type:{request.status}, req_id:{request.request_id}")
for req in skip_requests: for req in skip_requests:
# move waiting request to end of the deque # move waiting request to end of the deque
@@ -1069,6 +1086,7 @@ class ResourceManagerV1(ResourceManager):
def add_request(self, request: Request) -> None: def add_request(self, request: Request) -> None:
with self.lock: with self.lock:
self.apply_async_preprocess(request) self.apply_async_preprocess(request)
llm_logger.debug(f"self.waiting append request:{request.request_id},req.type:{request.status}")
self.waiting.append(request) self.waiting.append(request)
self.requests[request.request_id] = request self.requests[request.request_id] = request
@@ -1120,7 +1138,9 @@ class ResourceManagerV1(ResourceManager):
need_extra_prefill_blocks = need_prealloc_prefill_blocks - request.cache_info[0] need_extra_prefill_blocks = need_prealloc_prefill_blocks - request.cache_info[0]
if self.cache_manager.can_allocate_gpu_blocks(need_extra_prefill_blocks): if self.cache_manager.can_allocate_gpu_blocks(need_extra_prefill_blocks):
extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(need_extra_prefill_blocks) extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(
need_extra_prefill_blocks, request.request_id
)
request.block_tables.extend(extra_gpu_block_ids) request.block_tables.extend(extra_gpu_block_ids)
allocated_position = self.get_available_position() allocated_position = self.get_available_position()
request.idx = allocated_position request.idx = allocated_position
@@ -1135,7 +1155,9 @@ class ResourceManagerV1(ResourceManager):
else: else:
if self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks): if self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks):
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks)) request.block_tables.extend(
self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks, request.request_id)
)
request.num_computed_tokens = 0 request.num_computed_tokens = 0
allocated_position = self.get_available_position() allocated_position = self.get_available_position()
request.idx = allocated_position request.idx = allocated_position
@@ -1169,7 +1191,9 @@ class ResourceManagerV1(ResourceManager):
if not self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks): if not self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks):
return False return False
request.block_tables = self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks) request.block_tables = self.cache_manager.allocate_gpu_blocks(
need_prealloc_prefill_blocks, request.request_id
)
request.num_computed_tokens = request.need_prefill_tokens request.num_computed_tokens = request.need_prefill_tokens
request.disaggregate_info["block_tables"] = request.block_tables request.disaggregate_info["block_tables"] = request.block_tables
allocated_position = self.get_available_position() allocated_position = self.get_available_position()
@@ -1221,16 +1245,18 @@ class ResourceManagerV1(ResourceManager):
def _free_blocks(self, request: Request): def _free_blocks(self, request: Request):
if self.config.cache_config.enable_prefix_caching: if self.config.cache_config.enable_prefix_caching:
self.cache_manager.release_block_ids(request) self.cache_manager.release_block_ids(request)
self.cache_manager.recycle_gpu_blocks(request.block_tables[request.num_cached_blocks :]) self.cache_manager.recycle_gpu_blocks(
request.block_tables[request.num_cached_blocks :], request.request_id
)
else: else:
self.cache_manager.recycle_gpu_blocks(request.block_tables) self.cache_manager.recycle_gpu_blocks(request.block_tables, request.request_id)
request.block_tables = [] request.block_tables = []
if request.request_id in self.using_extend_tables_req_id: if request.request_id in self.using_extend_tables_req_id:
reuse_block_num = self.reuse_block_num_map[request.request_id] reuse_block_num = self.reuse_block_num_map[request.request_id]
self.using_extend_tables_req_id.remove(request.request_id) self.using_extend_tables_req_id.remove(request.request_id)
self.cache_manager.recycle_gpu_blocks(request.extend_block_tables[reuse_block_num:]) self.cache_manager.recycle_gpu_blocks(request.extend_block_tables[reuse_block_num:], request.request_id)
llm_logger.info( llm_logger.info(
f"req {request.request_id} recycle extend blocks {request.extend_block_tables[reuse_block_num:]}" f"req {request.request_id} recycle extend blocks {request.extend_block_tables[reuse_block_num:]}"
) )
@@ -1280,9 +1306,9 @@ class ResourceManagerV1(ResourceManager):
for req in need_postprocess_reqs: for req in need_postprocess_reqs:
try: try:
self._free_blocks(req) self._free_blocks(req)
llm_logger.debug(f"req_id:{req.request_id} free pos:{req.idx}")
except Exception as e: except Exception as e:
llm_logger.warning(f"release block failed {req.request_id}: {e}") llm_logger.warning(f"release block failed {req.request_id}: {e}")
except Exception as e: except Exception as e:
llm_logger.error(f"finish_request err: {e}, {str(traceback.format_exc())}") llm_logger.error(f"finish_request err: {e}, {str(traceback.format_exc())}")
finally: finally:
+26
View File
@@ -16,6 +16,7 @@
import inspect import inspect
import os import os
import re
import time import time
import traceback import traceback
import uuid import uuid
@@ -28,6 +29,7 @@ from filelock import FileLock
import fastdeploy.metrics.trace as tracing import fastdeploy.metrics.trace as tracing
from fastdeploy import envs from fastdeploy import envs
from fastdeploy.config import FDConfig from fastdeploy.config import FDConfig
from fastdeploy.engine.request import RequestStatus
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
from fastdeploy.eplb.utils import RedundantExpertWorkload from fastdeploy.eplb.utils import RedundantExpertWorkload
@@ -844,3 +846,27 @@ class EngineClient:
content = {"code": 0, "msg": "ok", "data": update_weight_from_disk_list} content = {"code": 0, "msg": "ok", "data": update_weight_from_disk_list}
status_code = HTTPStatus.OK status_code = HTTPStatus.OK
return content, status_code return content, status_code
async def abort(self, request_id, n=1) -> None:
if envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE:
api_server_logger.info(f"abort request_id:{request_id}")
if n <= 0:
api_server_logger.warning("Abort function called with non-positive n: %d. No requests aborted.", n)
return
match = re.search(r"_\d+$", request_id)
if match:
prefix = request_id[: match.start()]
else:
api_server_logger.warning(
"request_id format error: %s does not end with _<number>. Using it as prefix.", request_id
)
prefix = request_id
request_ids = [f"{prefix}_{i}" for i in range(n)]
for req_id in request_ids:
data = {
"request_id": req_id,
"status": RequestStatus.ABORT.value,
}
self._send_task(data)
api_server_logger.info("Aborted request(s) %s.", ",".join(request_ids))
+7 -1
View File
@@ -59,7 +59,11 @@ from fastdeploy.entrypoints.openai.serving_embedding import OpenAIServingEmbeddi
from fastdeploy.entrypoints.openai.serving_models import ModelPath, OpenAIServingModels from fastdeploy.entrypoints.openai.serving_models import ModelPath, OpenAIServingModels
from fastdeploy.entrypoints.openai.serving_reward import OpenAIServingReward from fastdeploy.entrypoints.openai.serving_reward import OpenAIServingReward
from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
from fastdeploy.entrypoints.openai.utils import UVICORN_CONFIG, make_arg_parser from fastdeploy.entrypoints.openai.utils import (
UVICORN_CONFIG,
make_arg_parser,
with_cancellation,
)
from fastdeploy.entrypoints.openai.v1.serving_chat import ( from fastdeploy.entrypoints.openai.v1.serving_chat import (
OpenAIServingChat as OpenAIServingChatV1, OpenAIServingChat as OpenAIServingChatV1,
) )
@@ -410,6 +414,7 @@ def wrap_streaming_generator(original_generator: AsyncGenerator):
@app.post("/v1/chat/completions") @app.post("/v1/chat/completions")
@with_cancellation
async def create_chat_completion(request: ChatCompletionRequest, req: Request): async def create_chat_completion(request: ChatCompletionRequest, req: Request):
""" """
Create a chat completion for the provided prompt and parameters. Create a chat completion for the provided prompt and parameters.
@@ -446,6 +451,7 @@ async def create_chat_completion(request: ChatCompletionRequest, req: Request):
@app.post("/v1/completions") @app.post("/v1/completions")
@with_cancellation
async def create_completion(request: CompletionRequest, req: Request): async def create_completion(request: CompletionRequest, req: Request):
""" """
Create a completion for the provided prompt and parameters. Create a completion for the provided prompt and parameters.
+17 -2
View File
@@ -180,6 +180,13 @@ class OpenAIServingChat:
error_msg = f"request[{request_id}]full generator error: {str(e)}, {str(traceback.format_exc())}" error_msg = f"request[{request_id}]full generator error: {str(e)}, {str(traceback.format_exc())}"
api_server_logger.error(error_msg) api_server_logger.error(error_msg)
return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.INTERNAL_ERROR)) return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.INTERNAL_ERROR))
except asyncio.CancelledError as e:
await self.engine_client.abort(f"{request_id}_0", 1 if request.n is None else request.n)
error_msg = f"request[{request_id}_0] client disconnected: {str(e)}, {str(traceback.format_exc())}"
api_server_logger.error(error_msg)
return ErrorResponse(
error=ErrorInfo(message=error_msg, type=ErrorType.INVALID_REQUEST_ERROR, code=ErrorCode.CLIENT_ABORTED)
)
except Exception as e: except Exception as e:
error_msg = ( error_msg = (
f"request[{request_id}] waiting error: {str(e)}, {str(traceback.format_exc())}, " f"request[{request_id}] waiting error: {str(e)}, {str(traceback.format_exc())}, "
@@ -267,7 +274,9 @@ class OpenAIServingChat:
except asyncio.TimeoutError: except asyncio.TimeoutError:
current_waiting_time += 10 current_waiting_time += 10
if current_waiting_time == 300: if current_waiting_time == 300:
status, msg = self.engine_client.check_health(time_interval_threashold=envs.FD_WORKER_ALIVE_TIMEOUT) status, msg = self.engine_client.check_health(
time_interval_threashold=envs.FD_WORKER_ALIVE_TIMEOUT
)
if not status: if not status:
if choices: if choices:
chunk.choices = choices chunk.choices = choices
@@ -506,6 +515,10 @@ class OpenAIServingChat:
) )
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
except asyncio.CancelledError as e:
await self.engine_client.abort(f"{request_id}_0", 1 if request.n is None else request.n)
error_msg = f"request[{request_id}_0] client disconnected: {str(e)}, {str(traceback.format_exc())}"
api_server_logger.error(error_msg)
except Exception as e: except Exception as e:
error_data = self._create_streaming_error_response( error_data = self._create_streaming_error_response(
f"request[{request_id}] generate stream error: {str(e)}, {str(traceback.format_exc())}" f"request[{request_id}] generate stream error: {str(e)}, {str(traceback.format_exc())}"
@@ -577,7 +590,9 @@ class OpenAIServingChat:
except asyncio.TimeoutError: except asyncio.TimeoutError:
current_waiting_time += 10 current_waiting_time += 10
if current_waiting_time == 300: if current_waiting_time == 300:
status, msg = self.engine_client.check_health(time_interval_threashold=envs.FD_WORKER_ALIVE_TIMEOUT) status, msg = self.engine_client.check_health(
time_interval_threashold=envs.FD_WORKER_ALIVE_TIMEOUT
)
if not status: if not status:
raise ValueError(f"Engine is not healthy: {msg}") raise ValueError(f"Engine is not healthy: {msg}")
else: else:
@@ -230,7 +230,13 @@ class OpenAIServingCompletion:
) )
api_server_logger.error(error_msg) api_server_logger.error(error_msg)
return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.INTERNAL_ERROR)) return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.INTERNAL_ERROR))
except asyncio.CancelledError as e:
await self.engine_client.abort(f"{request_id}_0", num_choices)
error_msg = f"request[{request_id}_0] client disconnected: {str(e)}, {str(traceback.format_exc())}"
api_server_logger.error(error_msg)
return ErrorResponse(
error=ErrorInfo(message=error_msg, type=ErrorType.INVALID_REQUEST_ERROR, code=ErrorCode.CLIENT_ABORTED)
)
except Exception as e: except Exception as e:
error_msg = f"OpenAIServingCompletion create_completion error: {e}, {str(traceback.format_exc())}" error_msg = f"OpenAIServingCompletion create_completion error: {e}, {str(traceback.format_exc())}"
api_server_logger.error(error_msg) api_server_logger.error(error_msg)
@@ -285,7 +291,9 @@ class OpenAIServingCompletion:
except asyncio.TimeoutError: except asyncio.TimeoutError:
current_waiting_time += 10 current_waiting_time += 10
if current_waiting_time == 300: if current_waiting_time == 300:
status, msg = self.engine_client.check_health(time_interval_threashold=envs.FD_WORKER_ALIVE_TIMEOUT) status, msg = self.engine_client.check_health(
time_interval_threashold=envs.FD_WORKER_ALIVE_TIMEOUT
)
if not status: if not status:
raise ValueError(f"Engine is not healthy: {msg}") raise ValueError(f"Engine is not healthy: {msg}")
else: else:
@@ -455,7 +463,9 @@ class OpenAIServingCompletion:
except asyncio.TimeoutError: except asyncio.TimeoutError:
current_waiting_time += 10 current_waiting_time += 10
if current_waiting_time == 300: if current_waiting_time == 300:
status, msg = self.engine_client.check_health(time_interval_threashold=envs.FD_WORKER_ALIVE_TIMEOUT) status, msg = self.engine_client.check_health(
time_interval_threashold=envs.FD_WORKER_ALIVE_TIMEOUT
)
if not status: if not status:
raise ValueError(f"Engine is not healthy: {msg}") raise ValueError(f"Engine is not healthy: {msg}")
else: else:
@@ -634,6 +644,10 @@ class OpenAIServingCompletion:
yield f"data: {usage_chunk.model_dump_json(exclude_unset=True)}\n\n" yield f"data: {usage_chunk.model_dump_json(exclude_unset=True)}\n\n"
api_server_logger.info(f"Completion Streaming response last send: {chunk.model_dump_json()}") api_server_logger.info(f"Completion Streaming response last send: {chunk.model_dump_json()}")
except asyncio.CancelledError as e:
await self.engine_client.abort(f"{request_id}_0", num_choices)
error_msg = f"request[{request_id}_0] client disconnected: {str(e)}, {str(traceback.format_exc())}"
api_server_logger.error(error_msg)
except Exception as e: except Exception as e:
api_server_logger.error(f"Error in completion_stream_generator: {e}, {str(traceback.format_exc())}") api_server_logger.error(f"Error in completion_stream_generator: {e}, {str(traceback.format_exc())}")
yield f"data: {ErrorResponse(error=ErrorInfo(message=str(e), code='400', type=ErrorType.INTERNAL_ERROR)).model_dump_json(exclude_unset=True)}\n\n" yield f"data: {ErrorResponse(error=ErrorInfo(message=str(e), code='400', type=ErrorType.INTERNAL_ERROR)).model_dump_json(exclude_unset=True)}\n\n"
+54
View File
@@ -15,6 +15,7 @@
""" """
import asyncio import asyncio
import functools
import heapq import heapq
import random import random
import time import time
@@ -22,6 +23,7 @@ from multiprocessing.reduction import ForkingPickler
import aiozmq import aiozmq
import zmq import zmq
from fastapi import Request
from fastdeploy.engine.args_utils import EngineArgs from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.metrics.metrics import main_process_metrics
@@ -253,3 +255,55 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
return parser return parser
async def listen_for_disconnect(request: Request) -> None:
"""Returns if a disconnect message is received"""
while True:
message = await request.receive()
if message["type"] == "http.disconnect":
break
def with_cancellation(handler_func):
    """Decorator that allows a route handler to be cancelled by client
    disconnections.

    This does _not_ use request.is_disconnected, which does not work with
    middleware. Instead this follows the pattern from
    starlette.StreamingResponse, which simultaneously awaits on two tasks - one
    to wait for an http disconnect message, and the other to do the work that we
    want done. When the first task finishes, the other is cancelled.

    A core assumption of this method is that the body of the request has already
    been read. This is a safe assumption to make for fastapi handlers that have
    already parsed the body of the request into a pydantic model for us.
    This decorator is unsafe to use elsewhere, as it will consume and throw away
    all incoming messages for the request while it looks for a disconnect
    message.

    In the case where a `StreamingResponse` is returned by the handler, this
    wrapper will stop listening for disconnects and instead the response object
    will start listening for disconnects. The response object will only
    correctly listen when the ASGI protocol version used by Uvicorn is less
    than 2.4 (excluding 2.4).

    Args:
        handler_func: The async route handler to wrap.

    Returns:
        An async wrapper with the handler's signature. It returns the handler's
        result, or ``None`` when the client disconnected first (the handler
        task is cancelled in that case).
    """

    # Functools.wraps is required for this wrapper to appear to fastapi as a
    # normal route handler, with the correct request type hinting.
    @functools.wraps(handler_func)
    async def wrapper(*args, **kwargs):
        # The request is either the second positional arg (bound methods pass
        # `self` first) or a keyword argument. Route handlers conventionally
        # name it `raw_request`; `req` is still accepted for backward
        # compatibility with existing callers.
        if len(args) > 1:
            request = args[1]
        elif "raw_request" in kwargs:
            request = kwargs["raw_request"]
        else:
            request = kwargs["req"]

        handler_task = asyncio.create_task(handler_func(*args, **kwargs))
        cancellation_task = asyncio.create_task(listen_for_disconnect(request))

        # Whichever task finishes first wins; the loser is cancelled.
        done, pending = await asyncio.wait([handler_task, cancellation_task], return_when=asyncio.FIRST_COMPLETED)
        for task in pending:
            task.cancel()

        if handler_task in done:
            return handler_task.result()
        # Client disconnected before the handler finished.
        return None

    return wrapper
+3
View File
@@ -158,6 +158,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
# "Enable FP8 calibration on HPU" # "Enable FP8 calibration on HPU"
"FD_HPU_MEASUREMENT_MODE": lambda: os.getenv("FD_HPU_MEASUREMENT_MODE", "0"), "FD_HPU_MEASUREMENT_MODE": lambda: os.getenv("FD_HPU_MEASUREMENT_MODE", "0"),
"FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS": lambda: int(os.getenv("FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS", "30")), "FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS": lambda: int(os.getenv("FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS", "30")),
"FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE": lambda: int(
os.getenv("FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", "1")
),
# Whether to collect user information # Whether to collect user information
"DO_NOT_TRACK": lambda: (os.getenv("DO_NOT_TRACK", "0")) == "1", "DO_NOT_TRACK": lambda: (os.getenv("DO_NOT_TRACK", "0")) == "1",
# Usage stats server url # Usage stats server url
+48 -2
View File
@@ -248,9 +248,24 @@ class TokenProcessor:
task: Request = self.resource_manager.tasks_list[i] task: Request = self.resource_manager.tasks_list[i]
task_id = task.request_id task_id = task.request_id
token_ids = stream_data.tokens # numpy.array token_ids = stream_data.tokens # numpy.array
if token_ids is not None and token_ids[-1] <= 0: if token_ids is not None and token_ids[-1] <= 0:
if task_id in self.resource_manager.abort_req_ids_set:
if (
envs.ENABLE_V1_KVCACHE_SCHEDULER and token_ids[-1] == PREEMPTED_TOKEN_ID
) or not envs.ENABLE_V1_KVCACHE_SCHEDULER:
llm_logger.info(f"Aborted task {task_id} received negative token. Recycling.")
self.resource_manager.abort_req_ids_set.remove(task_id)
self._recycle_resources(task_id, i, task)
llm_logger.info(f"{task_id} received negative token. Recycle end.")
abort_res = RequestOutput(
request_id=task_id,
finished=True,
error_code=499,
error_msg=f"Your request with request_id:{task_id} is aborted.",
)
batch_result.append(abort_res)
continue
if envs.ENABLE_V1_KVCACHE_SCHEDULER: if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if ( if (
task_id in self.resource_manager.to_be_rescheduled_request_id_set task_id in self.resource_manager.to_be_rescheduled_request_id_set
@@ -711,6 +726,19 @@ class TokenProcessor:
if accept_num[i] == PREEMPTED_TOKEN_ID: # in MTP, meas preemption has happend in worker if accept_num[i] == PREEMPTED_TOKEN_ID: # in MTP, meas preemption has happend in worker
llm_logger.info(f"sync preemption for request_id {task_id} done.") llm_logger.info(f"sync preemption for request_id {task_id} done.")
if envs.ENABLE_V1_KVCACHE_SCHEDULER: if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if task_id in self.resource_manager.abort_req_ids_set:
llm_logger.info(f"Aborted task {task_id} received negative token. Recycling.")
self.resource_manager.abort_req_ids_set.remove(task_id)
self._recycle_resources(task_id, i, task)
llm_logger.info(f"{task_id} received negative token. Recycle end.")
abort_res = RequestOutput(
request_id=task_id,
finished=True,
error_code=499,
error_msg=f"Your request with request_id:{task_id} is aborted.",
)
batch_result.append(abort_res)
continue
if task_id in self.resource_manager.to_be_rescheduled_request_id_set: if task_id in self.resource_manager.to_be_rescheduled_request_id_set:
self.resource_manager.reschedule_preempt_task(task_id) self.resource_manager.reschedule_preempt_task(task_id)
continue continue
@@ -737,6 +765,22 @@ class TokenProcessor:
if recovery_stop: if recovery_stop:
llm_logger.info(f"recovery stop signal found at task {task_id}") llm_logger.info(f"recovery stop signal found at task {task_id}")
if not recovery_stop and token_id < 0: if not recovery_stop and token_id < 0:
if task_id in self.resource_manager.abort_req_ids_set:
if (
envs.ENABLE_V1_KVCACHE_SCHEDULER and token_id == PREEMPTED_TOKEN_ID
) or not envs.ENABLE_V1_KVCACHE_SCHEDULER:
llm_logger.info(f"Aborted task {task_id} received negative token. Recycling.")
self.resource_manager.abort_req_ids_set.remove(task_id)
self._recycle_resources(task_id, i, task)
llm_logger.info(f"{task_id} received negative token. Recycle end.")
abort_res = RequestOutput(
request_id=task_id,
finished=True,
error_code=499,
error_msg=f"Your request with request_id:{task_id} is aborted.",
)
batch_result.append(abort_res)
continue
if envs.ENABLE_V1_KVCACHE_SCHEDULER: if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if ( if (
task_id in self.resource_manager.to_be_rescheduled_request_id_set task_id in self.resource_manager.to_be_rescheduled_request_id_set
@@ -760,6 +804,7 @@ class TokenProcessor:
task.metrics.record_recv_first_token() task.metrics.record_recv_first_token()
task.metrics.cal_cost_time() task.metrics.cal_cost_time()
metrics = copy.copy(task.metrics) metrics = copy.copy(task.metrics)
llm_logger.info(f"task:{task.request_id} start recode first token")
self._record_first_token_metrics(task, current_time) self._record_first_token_metrics(task, current_time)
tracing.trace_report_span( tracing.trace_report_span(
@@ -845,7 +890,6 @@ class TokenProcessor:
result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids]) result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids])
result.outputs.top_logprobs.logprobs.extend([topk_logprobs]) result.outputs.top_logprobs.logprobs.extend([topk_logprobs])
result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank]) result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank])
if token_id in task.eos_token_ids or is_prefill or recovery_stop: if token_id in task.eos_token_ids or is_prefill or recovery_stop:
result.finished = True result.finished = True
trace_carrier = tracing.trace_get_proc_propagate_context(rid=rid) trace_carrier = tracing.trace_get_proc_propagate_context(rid=rid)
@@ -872,7 +916,9 @@ class TokenProcessor:
self._compute_speculative_status(result) self._compute_speculative_status(result)
if not is_prefill: if not is_prefill:
self._record_completion_metrics(task, current_time) self._record_completion_metrics(task, current_time)
llm_logger.info(f"task {task_id} received eos token. Recycling.")
self._recycle_resources(task_id, i, task, result, is_prefill) self._recycle_resources(task_id, i, task, result, is_prefill)
llm_logger.info(f"eos token {task_id} Recycle end.")
break break
llm_logger.debug(f"get response from infer: {result}") llm_logger.debug(f"get response from infer: {result}")
+1
View File
@@ -220,6 +220,7 @@ class ErrorCode(str, Enum):
CONNECTION_ERROR = "connection_error" CONNECTION_ERROR = "connection_error"
MISSING_REQUIRED_PARAMETER = "missing_required_parameter" MISSING_REQUIRED_PARAMETER = "missing_required_parameter"
INTERNAL_ERROR = "internal_error" INTERNAL_ERROR = "internal_error"
CLIENT_ABORTED = "client_aborted"
class ColoredFormatter(logging.Formatter): class ColoredFormatter(logging.Formatter):
+1 -1
View File
@@ -8,7 +8,7 @@ aiozmq
openai openai
tqdm tqdm
pynvml pynvml
uvicorn==0.29.0 uvicorn>=0.38.0
fastapi fastapi
paddleformers @ https://paddle-qa.bj.bcebos.com/ernie/paddleformers-0.4.0.post20251222-py3-none-any.whl paddleformers @ https://paddle-qa.bj.bcebos.com/ernie/paddleformers-0.4.0.post20251222-py3-none-any.whl
redis redis
+1 -1
View File
@@ -8,7 +8,7 @@ aiozmq
openai>=1.93.0 openai>=1.93.0
tqdm tqdm
pynvml pynvml
uvicorn==0.29.0 uvicorn>=0.38.0
fastapi fastapi
paddleformers @ https://paddle-qa.bj.bcebos.com/ernie/paddleformers-0.4.0.post20251222-py3-none-any.whl paddleformers @ https://paddle-qa.bj.bcebos.com/ernie/paddleformers-0.4.0.post20251222-py3-none-any.whl
redis redis
+1 -1
View File
@@ -8,7 +8,7 @@ aiozmq
openai>=1.93.0 openai>=1.93.0
tqdm tqdm
pynvml pynvml
uvicorn==0.29.0 uvicorn>=0.38.0
fastapi fastapi
# if paddleformers version > 0.3.2, metax triton will be replaced by the newest triton. # if paddleformers version > 0.3.2, metax triton will be replaced by the newest triton.
paddleformers==0.3.2 paddleformers==0.3.2
+14 -23
View File
@@ -14,7 +14,6 @@
import asyncio import asyncio
import os import os
import re
import shutil import shutil
import signal import signal
import subprocess import subprocess
@@ -48,7 +47,7 @@ def setup_and_run_server():
- Tears down server after all tests finish - Tears down server after all tests finish
""" """
print("Pre-test port cleanup...") print("Pre-test port cleanup...")
FD_CONTROLLER_PORT = int(os.getenv("FD_CONTROLLER_PORT", 8633)) FD_CONTROLLER_PORT = int(os.getenv("FD_CONTROLLER_PORT", 8333))
clean_ports([FD_API_PORT, FD_ENGINE_QUEUE_PORT, FD_METRICS_PORT, FD_CACHE_QUEUE_PORT, FD_CONTROLLER_PORT]) clean_ports([FD_API_PORT, FD_ENGINE_QUEUE_PORT, FD_METRICS_PORT, FD_CACHE_QUEUE_PORT, FD_CONTROLLER_PORT])
env = os.environ.copy() env = os.environ.copy()
@@ -173,9 +172,12 @@ def parse_prometheus_to_dict(metrics_text: str):
value = float(line.split("}")[1].strip()) value = float(line.split("}")[1].strip())
# 解析 labels # 解析 labels
# 用正则取出所有 key 和 value(去掉外层引号) labels = {}
pairs = re.findall(r'(\w+)="([^"]*)"', labels_str) for kv in labels_str.split(","):
labels = {k: v for k, v in pairs} if "=" not in kv:
continue
k, v = kv.split("=")
labels[k] = v.strip('"')
# 存储 # 存储
if metric_name not in result: if metric_name not in result:
@@ -212,7 +214,6 @@ def test_metrics_with_clear_and_reset():
""" """
Test the metrics monitoring endpoint. Test the metrics monitoring endpoint.
""" """
FD_CONTROLLER_PORT = int(os.getenv("FD_CONTROLLER_PORT", 8633))
metrics_url = f"http://0.0.0.0:{FD_METRICS_PORT}/metrics" metrics_url = f"http://0.0.0.0:{FD_METRICS_PORT}/metrics"
async_concurrency(n=10) async_concurrency(n=10)
@@ -229,23 +230,13 @@ def test_metrics_with_clear_and_reset():
running = metrics["fastdeploy:num_requests_running"] running = metrics["fastdeploy:num_requests_running"]
waiting = metrics["fastdeploy:num_requests_waiting"] waiting = metrics["fastdeploy:num_requests_waiting"]
print("ASSERT clear_load_weight后非0 running:", running, "waiting:", waiting) print(
assert running != 0 or waiting != 0, "Expected running/waiting to be non-zero" "ASSERT after the clear_load_weight operation, the value is 0 (Request interruption stopped inference, and related requests were cleared):",
running,
# ===== reset_scheduler ===== "waiting:",
reset_url = f"http://0.0.0.0:{FD_CONTROLLER_PORT}/controller/reset_scheduler" waiting,
print("Calling reset_scheduler...") )
r = requests.post(reset_url, json={"reset": True}, timeout=30) assert running == 0 and waiting == 0, "Expected both running and waiting to be 0 after clear_load_weight"
assert r.status_code == 200, f"reset_scheduler failed: {r.status_code}"
metrics = get_metrics_dict(metrics_url)
running = metrics["fastdeploy:num_requests_running"]
waiting = metrics["fastdeploy:num_requests_waiting"]
print("ASSERT reset_scheduler后为0 running:", running, "waiting:", waiting)
# Temporarily disable this assertion. The running/waiting states are not strictly
# guaranteed to reach zero in the current workflow, so we skip this check for now.
# assert running == 0 and waiting == 0, "Expected running/waiting to be zero"
if __name__ == "__main__": if __name__ == "__main__":
+16 -5
View File
@@ -83,6 +83,7 @@ def _reload_api_server(args):
fake_envs_mod.TRACES_EXPORTER = "console" fake_envs_mod.TRACES_EXPORTER = "console"
fake_envs_mod.EXPORTER_OTLP_ENDPOINT = "" fake_envs_mod.EXPORTER_OTLP_ENDPOINT = ""
fake_envs_mod.EXPORTER_OTLP_HEADERS = "" fake_envs_mod.EXPORTER_OTLP_HEADERS = ""
fake_envs_mod.FD_SUPPORT_MAX_CONNECTIONS = 1024
fake_envs_mod.environment_variables = _FakeEnvVars() fake_envs_mod.environment_variables = _FakeEnvVars()
# Save original sys.argv and replace with minimal valid args to avoid parse errors # Save original sys.argv and replace with minimal valid args to avoid parse errors
@@ -397,26 +398,36 @@ async def test_chat_and_completion_routes():
chat_handler = MagicMock() chat_handler = MagicMock()
chat_handler.create_chat_completion = AsyncMock(return_value=error_resp) chat_handler.create_chat_completion = AsyncMock(return_value=error_resp)
api_server.app.state.chat_handler = chat_handler api_server.app.state.chat_handler = chat_handler
assert (await api_server.create_chat_completion(body, fake_req)).status_code == 500
with patch("fastdeploy.entrypoints.openai.api_server.connection_manager", return_value=DummyCM()):
assert (await api_server.create_chat_completion(body, fake_req)).status_code == 500
success_resp = ChatCompletionResponse(id="1", model="m", choices=[], usage=UsageInfo()) success_resp = ChatCompletionResponse(id="1", model="m", choices=[], usage=UsageInfo())
api_server.app.state.chat_handler.create_chat_completion = AsyncMock(return_value=success_resp) api_server.app.state.chat_handler.create_chat_completion = AsyncMock(return_value=success_resp)
assert (await api_server.create_chat_completion(body, fake_req)).status_code == 200
with patch("fastdeploy.entrypoints.openai.api_server.connection_manager", return_value=DummyCM()):
assert (await api_server.create_chat_completion(body, fake_req)).status_code == 200
async def stream_gen(): async def stream_gen():
yield "data" yield "data"
api_server.app.state.chat_handler.create_chat_completion = AsyncMock(return_value=stream_gen()) api_server.app.state.chat_handler.create_chat_completion = AsyncMock(return_value=stream_gen())
assert isinstance(await api_server.create_chat_completion(body, fake_req), api_server.StreamingResponse)
with patch("fastdeploy.entrypoints.openai.api_server.connection_manager", return_value=DummyCM()):
assert isinstance(await api_server.create_chat_completion(body, fake_req), api_server.StreamingResponse)
# Completion handler # Completion handler
completion_handler = MagicMock() completion_handler = MagicMock()
completion_handler.create_completion = AsyncMock(return_value=error_resp) completion_handler.create_completion = AsyncMock(return_value=error_resp)
api_server.app.state.completion_handler = completion_handler api_server.app.state.completion_handler = completion_handler
assert (await api_server.create_completion(body, fake_req)).status_code == 500
with patch("fastdeploy.entrypoints.openai.api_server.connection_manager", return_value=DummyCM()):
assert (await api_server.create_completion(body, fake_req)).status_code == 500
api_server.app.state.completion_handler.create_completion = AsyncMock(return_value=success_resp) api_server.app.state.completion_handler.create_completion = AsyncMock(return_value=success_resp)
assert (await api_server.create_completion(body, fake_req)).status_code == 200
with patch("fastdeploy.entrypoints.openai.api_server.connection_manager", return_value=DummyCM()):
assert (await api_server.create_completion(body, fake_req)).status_code == 200
# HTTPException handling # HTTPException handling
class RaiseHTTP: class RaiseHTTP:
@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
""" """
import asyncio
import json import json
import unittest import unittest
from unittest.mock import AsyncMock, MagicMock, Mock, patch from unittest.mock import AsyncMock, MagicMock, Mock, patch
@@ -1079,6 +1080,127 @@ class TestOpenAIServingCompletion(unittest.IsolatedAsyncioTestCase):
# logprobs should be None when not requested # logprobs should be None when not requested
self.assertIsNone(choice.logprobs) self.assertIsNone(choice.logprobs)
    async def test_create_chat_completion_cancelled_error(self):
        """Verify create_chat_completion handles asyncio.CancelledError itself.

        A client disconnect cancels the handler task, so the cancellation
        surfaces inside ``create_chat_completion``. The handler must catch it
        (not re-raise) and still abort the request on the engine client.
        """
        # Create mock request (non-streaming path)
        request = ChatCompletionRequest(messages=[{"role": "user", "content": "Hello"}], stream=False)

        # Mock the semaphore so concurrency limiting is a no-op
        self.chat_completion_handler.engine_client.semaphore = MagicMock()
        self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=True)
        self.chat_completion_handler.engine_client.semaphore.release = MagicMock()

        # Mock the model weight status check (False == weights not cleared)
        self.chat_completion_handler.engine_client.check_model_weight_status = Mock(return_value=False)

        # Mock format_and_add_data to raise CancelledError, simulating a
        # disconnect that happens while the request is being submitted
        self.chat_completion_handler.engine_client.format_and_add_data = AsyncMock(
            side_effect=asyncio.CancelledError("Test cancellation during data formatting")
        )

        # Mock the abort method that should be called when CancelledError occurs
        self.chat_completion_handler.engine_client.abort = AsyncMock()

        # Execute and verify that CancelledError is handled properly.
        # The CancelledError should be caught and handled, not re-raised.
        try:
            await self.chat_completion_handler.create_chat_completion(request)
        except asyncio.CancelledError:
            # This should not happen as CancelledError should be caught and handled
            self.fail("CancelledError should be caught and handled, not re-raised")

        # Verify abort was called despite the cancellation
        self.chat_completion_handler.engine_client.abort.assert_called_once()
    async def test_chat_completion_stream_generator_cancelled_error(self):
        """Verify chat_completion_stream_generator survives CancelledError.

        The generator yields one normal chunk, then the response processor
        raises ``asyncio.CancelledError`` mid-stream. The generator must end
        gracefully (no re-raise) and still run cleanup + abort.
        """
        # Create mock request (streaming path)
        request = ChatCompletionRequest(messages=[{"role": "user", "content": "Hello"}], stream=True)
        request_id = "test_cancel_request"
        model_name = "test_model"
        prompt_token_ids = [1, 2, 3]
        prompt_tokens = "Hello world"

        # Mock the connection manager
        mock_dealer = MagicMock()
        mock_response_queue = AsyncMock()

        # Mock get_connection to return normally
        self.chat_completion_handler.engine_client.connection_manager.get_connection = AsyncMock(
            return_value=(mock_dealer, mock_response_queue)
        )

        # Mock the semaphore so concurrency limiting is a no-op
        self.chat_completion_handler.engine_client.semaphore = MagicMock()
        self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=True)
        self.chat_completion_handler.engine_client.semaphore.release = MagicMock()

        # Mock the model weight status check (False == weights not cleared)
        self.chat_completion_handler.engine_client.check_model_weight_status = Mock(return_value=False)

        # Mock the response processor to raise CancelledError during processing
        mock_response_processor = MagicMock()
        mock_response_processor.enable_multimodal_content.return_value = False

        async def mock_async_generator_with_cancel():
            # Simulate one normal response first so the stream produces output
            yield {
                "request_id": f"{request_id}_0",
                "error_code": 200,
                "metrics": {
                    "first_token_time": 1234567890,
                    "inference_start_time": 1234567880,
                    "arrival_time": 1234567890,
                    "request_start_time": 1234567870,
                },
                "prompt_logprobs": None,
                "outputs": {
                    "token_ids": [5],
                    "text": "Hi",
                    "top_logprobs": None,
                    "draft_top_logprobs": None,
                    "multipart": [{"type": "text", "text": "Hi"}],
                },
                "finished": False,
                "num_cached_tokens": 0,
                "num_input_image_tokens": 0,
                "num_input_video_tokens": 0,
            }
            # Then raise CancelledError mid-stream
            raise asyncio.CancelledError("Test cancellation during streaming")

        mock_response_processor.process_response_chat.return_value = mock_async_generator_with_cancel()

        # Mock the cleanup method expected to run after cancellation
        self.chat_completion_handler.engine_client.connection_manager.cleanup_request = AsyncMock()

        # Mock the abort method that should be called when CancelledError occurs
        self.chat_completion_handler.engine_client.abort = AsyncMock()

        with patch(
            "fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor", return_value=mock_response_processor
        ):
            # Execute the generator and verify CancelledError handling.
            # The CancelledError should be caught and handled, not re-raised.
            chunks = []
            try:
                async for chunk in self.chat_completion_handler.chat_completion_stream_generator(
                    request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens=100
                ):
                    chunks.append(chunk)
            except asyncio.CancelledError:
                # This should not happen as CancelledError should be caught and handled
                self.fail("CancelledError should be caught and handled, not re-raised")

            # Should have received at least one chunk before cancellation
            self.assertGreaterEqual(len(chunks), 1)
            self.assertIsNotNone(chunks[0])

        # Verify cleanup and abort were called despite the cancellation
        self.chat_completion_handler.engine_client.connection_manager.cleanup_request.assert_called_once()
        self.chat_completion_handler.engine_client.abort.assert_called_once()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
""" """
import asyncio
import unittest import unittest
from typing import List from typing import List
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
@@ -520,7 +521,7 @@ class TestOpenAIServingCompletion(unittest.IsolatedAsyncioTestCase):
results.append(result) results.append(result)
# Verify results # Verify results
self.assertTrue(len(results) > 0) self.assertGreater(len(results), 0)
# Check that the first response contains prompt_logprobs # Check that the first response contains prompt_logprobs
self.assertIn("prompt_logprobs", results[0]) self.assertIn("prompt_logprobs", results[0])
@@ -1216,6 +1217,118 @@ class TestOpenAIServingCompletion(unittest.IsolatedAsyncioTestCase):
self.assertIsNone(result.choices[0].prompt_logprobs) self.assertIsNone(result.choices[0].prompt_logprobs)
self.assertIsNone(result.choices[0].logprobs) self.assertIsNone(result.choices[0].logprobs)
    async def test_create_completion_cancelled_error(self):
        """Verify create_completion converts CancelledError into an error response.

        When the submission path is cancelled (client disconnect), the handler
        must return a ``client_aborted`` error response rather than propagating
        the cancellation, and must abort the request on the engine client.
        """
        # Mock the engine client and its dependencies
        mock_engine_client = Mock()
        mock_engine_client.is_master = True
        mock_engine_client.semaphore = Mock()
        mock_engine_client.semaphore.acquire = AsyncMock()
        mock_engine_client.semaphore.release = Mock()
        mock_engine_client.format_and_add_data = AsyncMock()
        mock_engine_client.format_and_add_data.return_value = [1, 2, 3]
        mock_engine_client.abort = AsyncMock()

        # Create serving completion instance
        serving_completion = OpenAIServingCompletion(mock_engine_client, None, "pid", None, 360)

        # Create mock request (single prompt, non-streaming)
        mock_request = Mock()
        mock_request.prompt = "Hello, world!"
        mock_request.prompt_token_ids = None
        mock_request.stream = False
        mock_request.n = 1
        mock_request.request_id = None
        mock_request.user = None

        # Mock to_dict_for_infer method to return a proper dict
        def mock_to_dict_for_infer(request_id_idx, prompt):
            return {"prompt": prompt, "request_id": request_id_idx, "prompt_tokens": 10, "max_tokens": 100}

        mock_request.to_dict_for_infer = mock_to_dict_for_infer

        # Mock format_and_add_data to raise CancelledError (overrides the
        # return_value configured above)
        mock_engine_client.format_and_add_data.side_effect = asyncio.CancelledError("Test cancellation")

        # Call method and expect CancelledError to be handled
        result = await serving_completion.create_completion(mock_request)

        # Verify that error was handled properly: a client_aborted error
        # response is returned instead of the cancellation propagating
        self.assertIsNotNone(result)
        self.assertTrue(hasattr(result, "error"))
        self.assertEqual(result.error.code, "client_aborted")
        self.assertIn("client disconnected", result.error.message)

        # Verify that abort was called
        mock_engine_client.abort.assert_called_once()
    async def test_completion_stream_generator_cancelled_error(self):
        """Verify completion_stream_generator ends gracefully on CancelledError.

        ``response_queue.get`` raises ``asyncio.CancelledError`` (simulating a
        disconnect while waiting for output). The stream must still terminate
        with a ``[DONE]`` sentinel and abort the per-choice request id.
        """
        # Mock the engine client and its dependencies
        mock_engine_client = Mock()
        mock_engine_client.semaphore = Mock()
        mock_engine_client.semaphore.acquire = AsyncMock()
        mock_engine_client.semaphore.release = Mock()
        mock_engine_client.connection_manager = AsyncMock()
        mock_engine_client.data_processor = Mock()
        mock_engine_client.ori_vocab_size = 1000
        mock_engine_client.check_model_weight_status.return_value = False
        mock_engine_client.check_health.return_value = (True, "Healthy")
        mock_engine_client.abort = AsyncMock()

        # Mock the data_processor methods
        mock_engine_client.data_processor.process_response_dict = Mock()

        # Mock connection manager get_connection method
        mock_dealer = Mock()
        mock_dealer.write = Mock()
        mock_response_queue = AsyncMock()
        # Make response_queue.get raise CancelledError (disconnect mid-wait)
        mock_response_queue.get.side_effect = asyncio.CancelledError("Test cancellation")
        mock_engine_client.connection_manager.get_connection.return_value = (mock_dealer, mock_response_queue)

        # Create serving completion instance
        serving_completion = OpenAIServingCompletion(mock_engine_client, None, "pid", None, 360)

        # Create mock request with the attributes the generator reads
        mock_request = Mock()
        mock_request.prompt_logprobs = None
        mock_request.logprobs = None
        mock_request.include_draft_logprobs = False
        mock_request.return_token_ids = True
        mock_request.include_stop_str_in_output = False
        mock_request.max_streaming_response_tokens = 1
        mock_request.max_tokens = None
        mock_request.stream_options = Mock()
        mock_request.stream_options.include_usage = False
        mock_request.n = 1
        mock_request.echo = False

        # Call the method and collect results
        result_generator = serving_completion.completion_stream_generator(
            request=mock_request,
            num_choices=1,
            request_id="test_request",
            created_time=1234567890,
            model_name="test_model",
            prompt_batched_token_ids=[[1, 2, 3]],
            prompt_tokens_list=["hello", "world"],
            max_tokens_list=[100],
        )

        # Collect results - should handle CancelledError gracefully
        results = []
        async for result in result_generator:
            results.append(result)

        # Verify that abort was called with the per-choice request id
        # ("test_request_0") and choice count 1
        mock_engine_client.abort.assert_called_once_with("test_request_0", 1)

        # Verify that generator ends gracefully (should have [DONE] message)
        self.assertTrue(len(results) > 0)
        self.assertTrue(any("[DONE]" in result for result in results))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
@@ -0,0 +1,356 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import asyncio
import unittest
from unittest.mock import MagicMock, patch
from fastapi import Request
from fastapi.responses import StreamingResponse
from fastdeploy.entrypoints.openai.utils import with_cancellation
class TestWithCancellation(unittest.TestCase):
"""Test cases for with_cancellation decorator"""
def setUp(self):
"""Set up test fixtures"""
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
def tearDown(self):
"""Clean up test fixtures"""
self.loop.close()
@patch("fastdeploy.entrypoints.openai.utils.listen_for_disconnect")
def test_normal_execution(self, mock_listen_disconnect):
"""Test that handler executes normally when no disconnect occurs"""
# Setup mock request
mock_request = MagicMock(spec=Request)
# Create a mock that returns a coroutine that never completes
async def never_disconnect(request):
await asyncio.Future() # This will never complete
mock_listen_disconnect.side_effect = never_disconnect
# Setup handler function
@with_cancellation
async def test_handler(self, raw_request):
await asyncio.sleep(0.01) # Simulate some work
return "test_result"
# Run the test - need to pass self as first arg
result = self.loop.run_until_complete(test_handler(None, mock_request))
# Verify results
self.assertEqual(result, "test_result")
mock_listen_disconnect.assert_called_once_with(mock_request)
@patch("fastdeploy.entrypoints.openai.utils.listen_for_disconnect")
def test_client_disconnect(self, mock_listen_disconnect):
"""Test that handler is cancelled when client disconnects"""
# Setup mock request
mock_request = MagicMock(spec=Request)
# Create a future that will complete (disconnect) after a short delay
disconnect_future = asyncio.Future()
self.loop.call_later(0.05, disconnect_future.set_result, None)
mock_listen_disconnect.return_value = disconnect_future
# Setup handler function that takes longer than disconnect
handler_called = False
@with_cancellation
async def test_handler(self, raw_request):
nonlocal handler_called
try:
await asyncio.sleep(0.1) # Simulate work
handler_called = True
return "should_not_reach_here"
except asyncio.CancelledError:
# Handler should be cancelled
raise
# Run the test - need to pass self as first arg
result = self.loop.run_until_complete(test_handler(None, mock_request))
# Verify results
self.assertIsNone(result) # Should return None when cancelled
self.assertFalse(handler_called) # Handler should not complete
mock_listen_disconnect.assert_called_once_with(mock_request)
@patch("fastdeploy.entrypoints.openai.utils.listen_for_disconnect")
def test_handler_with_args_kwargs(self, mock_listen_disconnect):
"""Test that decorator properly handles handler with args and kwargs"""
# Setup mock request
mock_request = MagicMock(spec=Request)
# Create a mock that returns a coroutine that never completes
async def never_disconnect(request):
await asyncio.Future() # This will never complete
mock_listen_disconnect.side_effect = never_disconnect
# Setup handler function with multiple arguments
@with_cancellation
async def test_handler(arg1, arg2, raw_request, kwarg1=None, kwarg2=None):
await asyncio.sleep(0.01)
return {"arg1": arg1, "arg2": arg2, "kwarg1": kwarg1, "kwarg2": kwarg2, "request": raw_request}
# Run the test with both positional and keyword arguments
result = self.loop.run_until_complete(
test_handler("value1", "value2", mock_request, kwarg1="kwvalue1", kwarg2="kwvalue2")
)
# Verify results
expected = {
"arg1": "value1",
"arg2": "value2",
"kwarg1": "kwvalue1",
"kwarg2": "kwvalue2",
"request": mock_request,
}
self.assertEqual(result, expected)
@patch("fastdeploy.entrypoints.openai.utils.listen_for_disconnect")
def test_handler_returns_streaming_response(self, mock_listen_disconnect):
"""Test that decorator handles StreamingResponse correctly"""
# Setup mock request
mock_request = MagicMock(spec=Request)
# Create a mock that returns a coroutine that never completes
async def never_disconnect(request):
await asyncio.Future() # This will never complete
mock_listen_disconnect.side_effect = never_disconnect
# Setup handler that returns StreamingResponse
@with_cancellation
async def test_handler(self, raw_request):
async def generate():
yield "chunk1"
yield "chunk2"
return StreamingResponse(generate())
# Run the test - need to pass self as first arg
result = self.loop.run_until_complete(test_handler(None, mock_request))
# Verify results
self.assertIsInstance(result, StreamingResponse)
mock_listen_disconnect.assert_called_once_with(mock_request)
@patch("fastdeploy.entrypoints.openai.utils.listen_for_disconnect")
def test_handler_exception_propagation(self, mock_listen_disconnect):
"""Test that exceptions from handler are properly propagated"""
# Setup mock request
mock_request = MagicMock(spec=Request)
# Create a mock that returns a coroutine that never completes
async def never_disconnect(request):
await asyncio.Future() # This will never complete
mock_listen_disconnect.side_effect = never_disconnect
# Setup handler that raises an exception
@with_cancellation
async def test_handler(self, raw_request):
await asyncio.sleep(0.01)
raise ValueError("Test exception")
# Run the test and expect exception - need to pass self as first arg
with self.assertRaises(ValueError) as context:
self.loop.run_until_complete(test_handler(None, mock_request))
self.assertEqual(str(context.exception), "Test exception")
mock_listen_disconnect.assert_called_once_with(mock_request)
@patch("fastdeploy.entrypoints.openai.utils.listen_for_disconnect")
def test_concurrent_cancellation_and_completion(self, mock_listen_disconnect):
"""Test edge case where cancellation and completion happen simultaneously"""
# Setup mock request
mock_request = MagicMock(spec=Request)
# Create futures that complete at roughly the same time
disconnect_future = asyncio.Future()
handler_future = asyncio.Future()
# Set both futures to complete almost simultaneously
self.loop.call_later(0.05, lambda: disconnect_future.set_result(None))
self.loop.call_later(0.05, lambda: handler_future.set_result("completed"))
mock_listen_disconnect.return_value = disconnect_future
@with_cancellation
async def test_handler(self, raw_request):
return await handler_future
# Run the test - need to pass self as first arg
result = self.loop.run_until_complete(test_handler(None, mock_request))
# The result depends on which task completes first
# This test ensures the decorator handles this edge case gracefully
self.assertIn(result, [None, "completed"])
def test_wrapper_preserves_function_metadata(self):
    """Test that the wrapper preserves the original function's metadata."""

    def original_handler(raw_request):
        """Original handler docstring"""
        pass

    # Apply the decorator directly (no @ syntax) so we keep a handle on
    # both the original and the wrapped callable.
    decorated_handler = with_cancellation(original_handler)

    # functools.wraps-style behavior: name and docstring carried over,
    # and __wrapped__ points back at the original.
    self.assertEqual("original_handler", decorated_handler.__name__)
    self.assertEqual("Original handler docstring", decorated_handler.__doc__)
    self.assertTrue(hasattr(decorated_handler, "__wrapped__"))
class TestListenForDisconnect(unittest.TestCase):
    """Test cases for listen_for_disconnect function"""

    def setUp(self):
        """Set up test fixtures"""
        # Fresh, private event loop per test so pending coroutines from one
        # test cannot bleed into the next; closed in tearDown.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def tearDown(self):
        """Clean up test fixtures"""
        self.loop.close()

    def test_listen_for_disconnect_normal_flow(self):
        """Test that listen_for_disconnect waits for disconnect message"""
        from fastdeploy.entrypoints.openai.utils import listen_for_disconnect

        # Setup mock request
        mock_request = MagicMock()
        receive_call_count = 0
        # Create mock messages - normal messages followed by disconnect
        messages = [
            {"type": "http.request", "body": b"some data"},
            {"type": "http.request", "body": b"more data"},
            {"type": "http.disconnect"},  # This should break the loop
        ]
        # Setup receive method to return messages sequentially
        receive_iter = iter(messages)

        async def mock_receive():
            nonlocal receive_call_count
            receive_call_count += 1
            try:
                return next(receive_iter)
            except StopIteration:
                # After all messages, return a message that keeps it waiting.
                # If listen_for_disconnect failed to stop on http.disconnect,
                # this would hang the test rather than pass silently.
                await asyncio.Future()  # Never completes

        mock_request.receive = mock_receive
        # Run the function in the event loop
        self.loop.run_until_complete(listen_for_disconnect(mock_request))
        # Verify that receive was called multiple times (all three messages
        # consumed before the disconnect ended the wait).
        self.assertGreaterEqual(receive_call_count, 3)

    def test_listen_for_disconnect_immediate_disconnect(self):
        """Test that listen_for_disconnect returns immediately on disconnect"""
        from fastdeploy.entrypoints.openai.utils import listen_for_disconnect

        # Setup mock request
        mock_request = MagicMock()
        receive_called = False

        # Setup receive to return disconnect immediately (as a coroutine)
        async def mock_receive():
            nonlocal receive_called
            receive_called = True
            return {"type": "http.disconnect"}

        mock_request.receive = mock_receive
        # Run the function in the event loop
        self.loop.run_until_complete(listen_for_disconnect(mock_request))
        # Verify that receive was called exactly once
        self.assertTrue(receive_called)

    def test_listen_for_disconnect_timeout(self):
        """Test that listen_for_disconnect can be cancelled with timeout"""
        from fastdeploy.entrypoints.openai.utils import listen_for_disconnect

        # Setup mock request
        mock_request = MagicMock()

        # Setup receive to never return disconnect
        async def mock_receive():
            await asyncio.Future()  # Never completes

        mock_request.receive = mock_receive
        # Run with timeout to test cancellation — proves the coroutine is
        # cancellable while blocked in receive().
        with self.assertRaises(asyncio.TimeoutError):
            self.loop.run_until_complete(asyncio.wait_for(listen_for_disconnect(mock_request), timeout=0.01))

    def test_listen_for_disconnect_various_message_types(self):
        """Test that listen_for_disconnect ignores non-disconnect messages"""
        from fastdeploy.entrypoints.openai.utils import listen_for_disconnect

        # Setup mock request
        mock_request = MagicMock()
        receive_call_count = 0
        # Create various non-disconnect messages; only http.disconnect may
        # terminate the listener.
        messages = [
            {"type": "http.request", "body": b"data"},
            {"type": "http.response", "status": 200},
            {"type": "websocket.connect"},
            {"type": "http.request", "body": b"more data"},
            {"type": "http.disconnect"},  # Final disconnect
        ]
        # Setup receive method
        receive_iter = iter(messages)

        async def mock_receive():
            nonlocal receive_call_count
            receive_call_count += 1
            try:
                return next(receive_iter)
            except StopIteration:
                await asyncio.Future()

        mock_request.receive = mock_receive
        # Run the function in the event loop
        self.loop.run_until_complete(listen_for_disconnect(mock_request))
        # Verify all messages were processed (four ignored + one disconnect)
        self.assertEqual(receive_call_count, 5)
if __name__ == "__main__":
unittest.main()
+257
View File
@@ -0,0 +1,257 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import asyncio
import unittest
from unittest.mock import MagicMock, patch
from fastdeploy.engine.request import RequestStatus
from fastdeploy.entrypoints.engine_client import EngineClient
class TestEngineClientAbort(unittest.TestCase):
    """Test cases for EngineClient.abort method"""

    def setUp(self):
        """Set up test fixtures"""
        # Private event loop per test; abort() is a coroutine.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        # Create a mock FDConfig with just the attributes EngineClient reads
        # during construction. MagicMock auto-creates any others on access.
        self.mock_fd_config = MagicMock()
        self.mock_fd_config.parallel_config.tensor_parallel_size = 1
        self.mock_fd_config.model_config.enable_mm = False
        self.mock_fd_config.model_config.max_model_len = 2048
        self.mock_fd_config.model_config.enable_logprob = True
        self.mock_fd_config.cache_config.enable_prefix_caching = False
        self.mock_fd_config.scheduler_config.splitwise_role = "mixed"
        self.mock_fd_config.limit_mm_per_prompt = 5
        self.mock_fd_config.eplb_config.enable_eplb = False
        self.mock_fd_config.structured_outputs_config.reasoning_parser = None
        self.mock_fd_config.mm_processor_kwargs = {}
        self.mock_fd_config.tool_parser = None
        self.mock_fd_config.cache_config.max_processor_cache = 0
        # Create EngineClient instance with its heavyweight collaborators
        # patched out so __init__ does not touch IPC, sockets, or the
        # filesystem during unit tests.
        with patch("fastdeploy.entrypoints.engine_client.InputPreprocessor"):
            with patch("fastdeploy.entrypoints.engine_client.IPCSignal"):
                with patch("fastdeploy.entrypoints.engine_client.StatefulSemaphore"):
                    with patch("fastdeploy.entrypoints.engine_client.DealerConnectionManager"):
                        with patch("fastdeploy.entrypoints.engine_client.FileLock"):
                            self.engine_client = EngineClient(
                                pid=12345, port=8000, fd_config=self.mock_fd_config, workers=1
                            )

    def tearDown(self):
        """Clean up test fixtures"""
        self.loop.close()

    @patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", True)
    @patch.object(EngineClient, "_send_task")
    def test_abort_single_request(self, mock_send_task):
        """Test aborting a single request"""
        request_id = "test_request"
        # Run the abort method
        self.loop.run_until_complete(self.engine_client.abort(request_id, n=1))
        # Verify _send_task was called with correct data: the "_0" suffix
        # is appended because abort fans out one task per generated sample.
        expected_data = {
            "request_id": "test_request_0",
            "status": RequestStatus.ABORT.value,
        }
        mock_send_task.assert_called_once_with(expected_data)

    @patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", True)
    @patch.object(EngineClient, "_send_task")
    def test_abort_multiple_requests(self, mock_send_task):
        """Test aborting multiple requests"""
        request_id = "test_request"
        n = 3
        # Run the abort method
        self.loop.run_until_complete(self.engine_client.abort(request_id, n=n))
        # Verify _send_task was called correct number of times
        self.assertEqual(mock_send_task.call_count, n)
        # Verify each call had correct request_id (suffixes _0.._2)
        expected_calls = [
            ({"request_id": "test_request_0", "status": RequestStatus.ABORT.value},),
            ({"request_id": "test_request_1", "status": RequestStatus.ABORT.value},),
            ({"request_id": "test_request_2", "status": RequestStatus.ABORT.value},),
        ]
        actual_calls = [call.args for call in mock_send_task.call_args_list]
        self.assertEqual(actual_calls, expected_calls)

    @patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", True)
    @patch.object(EngineClient, "_send_task")
    def test_abort_with_existing_suffix(self, mock_send_task):
        """Test aborting request that already has _number suffix"""
        request_id = "test_request_123_2"
        n = 2
        # Run the abort method
        self.loop.run_until_complete(self.engine_client.abort(request_id, n=n))
        # Verify _send_task was called correct number of times
        self.assertEqual(mock_send_task.call_count, n)
        # Verify each call had correct request_id (should use prefix before existing suffix:
        # the trailing "_2" is stripped before re-suffixing with _0/_1).
        expected_calls = [
            ({"request_id": "test_request_123_0", "status": RequestStatus.ABORT.value},),
            ({"request_id": "test_request_123_1", "status": RequestStatus.ABORT.value},),
        ]
        actual_calls = [call.args for call in mock_send_task.call_args_list]
        self.assertEqual(actual_calls, expected_calls)

    @patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", True)
    @patch.object(EngineClient, "_send_task")
    def test_abort_with_no_suffix(self, mock_send_task):
        """Test aborting request without _number suffix"""
        request_id = "test_request_without_suffix"
        n = 2
        # Run the abort method
        self.loop.run_until_complete(self.engine_client.abort(request_id, n=n))
        # Verify _send_task was called correct number of times
        self.assertEqual(mock_send_task.call_count, n)
        # Verify each call had correct request_id (should use full request_id as prefix)
        expected_calls = [
            ({"request_id": "test_request_without_suffix_0", "status": RequestStatus.ABORT.value},),
            ({"request_id": "test_request_without_suffix_1", "status": RequestStatus.ABORT.value},),
        ]
        actual_calls = [call.args for call in mock_send_task.call_args_list]
        self.assertEqual(actual_calls, expected_calls)

    @patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", True)
    @patch.object(EngineClient, "_send_task")
    def test_abort_with_zero_n(self, mock_send_task):
        """Test aborting with n=0 should not send any requests"""
        request_id = "test_request_123"
        # Run the abort method
        self.loop.run_until_complete(self.engine_client.abort(request_id, n=0))
        # Verify _send_task was not called
        mock_send_task.assert_not_called()

    @patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", True)
    @patch.object(EngineClient, "_send_task")
    def test_abort_with_negative_n(self, mock_send_task):
        """Test aborting with negative n should not send any requests"""
        request_id = "test_request_123"
        # Run the abort method
        self.loop.run_until_complete(self.engine_client.abort(request_id, n=-1))
        # Verify _send_task was not called
        mock_send_task.assert_not_called()

    @patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", False)
    @patch.object(EngineClient, "_send_task")
    def test_abort_when_feature_disabled(self, mock_send_task):
        """Test abort when FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE is False"""
        request_id = "test_request_123"
        # Run the abort method; with the env flag off, abort must be a no-op.
        self.loop.run_until_complete(self.engine_client.abort(request_id, n=1))
        # Verify _send_task was not called
        mock_send_task.assert_not_called()

    @patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", True)
    @patch.object(EngineClient, "_send_task")
    def test_abort_request_id_regex_parsing(self, mock_send_task):
        """Test that request_id regex parsing works correctly for various formats"""
        # (input_request_id, expected prefix after stripping a trailing _<digits>)
        test_cases = [
            ("simple_request", "simple_request"),
            ("request_with_underscores", "request_with_underscores"),
            ("request_123", "request"),
            ("request_123_456", "request_123"),
            ("request_0", "request"),
            ("complex_name_123_456_789", "complex_name_123_456"),
        ]
        for input_request_id, expected_prefix in test_cases:
            with self.subTest(input_request_id=input_request_id):
                mock_send_task.reset_mock()
                # Run the abort method
                self.loop.run_until_complete(self.engine_client.abort(input_request_id, n=1))
                # Verify _send_task was called with correct prefix
                expected_data = {
                    "request_id": f"{expected_prefix}_0",
                    "status": RequestStatus.ABORT.value,
                }
                mock_send_task.assert_called_once_with(expected_data)

    @patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", True)
    @patch("fastdeploy.entrypoints.engine_client.api_server_logger")
    @patch.object(EngineClient, "_send_task")
    def test_abort_logging(self, mock_send_task, mock_logger):
        """Test that abort method logs correctly"""
        request_id = "test_request"
        n = 2
        # Run the abort method
        self.loop.run_until_complete(self.engine_client.abort(request_id, n=n))
        # Verify info log was called twice
        self.assertEqual(mock_logger.info.call_count, 2)
        # Verify the first log message (abort start)
        first_call = mock_logger.info.call_args_list[0]
        self.assertEqual(first_call[0][0], "abort request_id:test_request")
        # Verify the second log message (abort completion with request IDs);
        # the logger uses lazy %-formatting, so message and args are separate.
        second_call = mock_logger.info.call_args_list[1]
        expected_log_message = "Aborted request(s) %s."
        self.assertEqual(second_call[0][0], expected_log_message)
        self.assertEqual(second_call[0][1], "test_request_0,test_request_1")

    @patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", True)
    @patch("fastdeploy.entrypoints.engine_client.api_server_logger")
    @patch.object(EngineClient, "_send_task")
    def test_abort_warning_logging_for_invalid_format(self, mock_send_task, mock_logger):
        """Test that abort method logs warning for invalid request_id format"""
        request_id = "invalid_format_no_suffix"  # This should actually not trigger warning
        n = 1
        # Run the abort method
        self.loop.run_until_complete(self.engine_client.abort(request_id, n=n))
        # Verify warning log was called (this case might not actually trigger warning)
        # The warning is only triggered when regex doesn't match, but our test case has valid format
        # Let's test with a case that should trigger warning
        mock_logger.reset_mock()
        # This should trigger warning because it doesn't end with _number
        # NOTE(review): "just_a_string" also has no _<digits> suffix, same as the
        # first id — presumably a different regex branch fires here; confirm
        # against EngineClient.abort's pattern.
        request_id_no_suffix = "just_a_string"
        self.loop.run_until_complete(self.engine_client.abort(request_id_no_suffix, n=1))
        # Should have logged warning about format error
        mock_logger.warning.assert_called()
if __name__ == "__main__":
unittest.main()
+157 -1
View File
@@ -17,7 +17,7 @@
import random import random
import time import time
import unittest import unittest
from unittest.mock import Mock from unittest.mock import MagicMock, Mock, patch
import paddle import paddle
@@ -34,6 +34,19 @@ class MockConfig:
class SpeculativeConfig: class SpeculativeConfig:
method = None method = None
num_speculative_tokens = 1
num_model_steps = 1
max_candidate_len = 5
verify_window = 2
max_ngram_size = 5
min_ngram_size = 2
model = None
quantization = None
num_gpu_block_expand_ratio = 1
model_type = "main"
benchmark_mode = False
num_extra_cache_layer = 0
mtp_strategy = "default"
class ModelConfig: class ModelConfig:
enable_logprob = False enable_logprob = False
@@ -91,6 +104,8 @@ class MockResourceManager:
self.stop_flags = [False] self.stop_flags = [False]
self.tasks_list = [MockTask()] self.tasks_list = [MockTask()]
self.to_be_rescheduled_request_id_set = set() self.to_be_rescheduled_request_id_set = set()
self.abort_req_ids_set = set()
self.req_dict = {}
def info(self): def info(self):
return "Mock resource manager info" return "Mock resource manager info"
@@ -241,6 +256,147 @@ class TestTokenProcessorProcessBatchOutput(unittest.TestCase):
assert len(request_output.outputs.draft_top_logprobs[1][0]) == K + 1 assert len(request_output.outputs.draft_top_logprobs[1][0]) == K + 1
assert len(request_output.outputs.draft_top_logprobs[2]) == accept_num[i] assert len(request_output.outputs.draft_top_logprobs[2]) == accept_num[i]
def test_process_batch_output_aborted_task_negative_token_speculative_decoding(self):
    """Test aborted task receiving negative token triggers recycling in speculative decoding mode"""
    processor = self.setup_token_processor(speculative_decoding=True, use_logprobs=True)
    # Set up task as aborted
    task_id = "test_aborted_request"
    task = processor.resource_manager.tasks_list[0]
    task.request_id = task_id
    processor.resource_manager.abort_req_ids_set = {task_id}
    # Add the task to req_dict to prevent _recycle_aborted_task from processing it early
    # Use a larger batch to avoid the early recycling condition
    processor.resource_manager.req_dict[task_id] = 0  # batch_id = 0
    # Mock _recycle_resources to track if it's called
    processor._recycle_resources = MagicMock()
    # Set up output tokens with negative token.
    # Header layout appears to be [stop_flag, msg_type, batch_size, accept_num...]
    # — inferred from the indices below; confirm against TokenProcessor.
    # stop_flag
    processor.output_tokens[0, 0].set_tensor(paddle.to_tensor(2))
    # mtype target = 3
    processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3))
    # batch = 2 (so batch_id=0 is < batch_size-1=1)
    processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(2))
    # Set accept_num = PREEMPTED_TOKEN_ID (-9) for first task to trigger abort logic
    processor.output_tokens[3, 0].set_tensor(paddle.to_tensor(-9))
    processor.output_tokens[4, 0].set_tensor(paddle.to_tensor(1))
    # Add second task to tasks_list
    task2 = MockTask()
    task2.request_id = "test_request_2"
    processor.resource_manager.tasks_list = [task, task2]
    processor.resource_manager.stop_flags = [False, False]
    # Update tokens_counter to include both tasks
    processor.tokens_counter[task_id] = 0
    processor.tokens_counter[task2.request_id] = 0
    # Mock llm_logger to capture the log message and envs.ENABLE_V1_KVCACHE_SCHEDULER
    with (
        patch("fastdeploy.output.token_processor.llm_logger") as mock_logger,
        patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0),
    ):
        # Call the method
        processor._process_batch_output()
        # In speculative decoding mode, when accept_num[i] == PREEMPTED_TOKEN_ID,
        # the code logs "sync preemption" and continues without triggering abort recycling
        # This is the expected behavior for speculative decoding mode
        mock_logger.info.assert_any_call(f"sync preemption for request_id {task_id} done.")
        # Verify that _recycle_resources was NOT called for the aborted task
        # (it may be called for other tasks like test_request_2 if they receive EOS tokens)
        for call in processor._recycle_resources.call_args_list:
            self.assertNotEqual(
                call[0][0], task_id, f"_recycle_resources should not be called for aborted task {task_id}"
            )
        # Verify that the task is still in abort_req_ids_set
        self.assertIn(task_id, processor.resource_manager.abort_req_ids_set)
def test_process_batch_output_aborted_task_negative_token_normal_mode(self):
    """Test aborted task receiving negative token triggers recycling in normal mode"""
    processor = self.setup_token_processor(speculative_decoding=False, use_logprobs=False)
    # Set up task as aborted
    task_id = "test_aborted_request"
    task = processor.resource_manager.tasks_list[0]
    task.request_id = task_id
    processor.resource_manager.abort_req_ids_set = {task_id}
    # Add the task to req_dict to prevent _recycle_aborted_task from processing it early
    # batch_id should be < batch_size - 1 to avoid early recycling
    processor.resource_manager.req_dict[task_id] = (
        0  # batch_id = 0, batch_size = 1, so 0 < 0 is false, but 0 >= 0 is true
    )
    # Actually, let's use a larger batch to avoid the early recycling condition:
    # re-create output_tokens sized for the non-speculative layout.
    processor.output_tokens = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64")
    # Mock _recycle_resources to track if it's called
    processor._recycle_resources = MagicMock()
    # Set up output tokens with negative token
    # batch = 2 (so batch_id=0 is < batch_size-1=1)
    processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(2))
    # Set negative token for first task (batch_id=0)
    processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(-1))
    # Set positive token for second task (batch_id=1)
    processor.output_tokens[3, 0].set_tensor(paddle.to_tensor(100))
    # Add second task to tasks_list
    task2 = MockTask()
    task2.request_id = "test_request_2"
    processor.resource_manager.tasks_list = [task, task2]
    processor.resource_manager.stop_flags = [False, False]
    # Update tokens_counter to include both tasks
    processor.tokens_counter[task_id] = 0
    processor.tokens_counter[task2.request_id] = 0
    # Mock llm_logger to capture the log message and envs.ENABLE_V1_KVCACHE_SCHEDULER
    with (
        patch("fastdeploy.output.token_processor.llm_logger") as mock_logger,
        patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0),
    ):
        # Call the method
        processor._process_batch_output()
        # Verify the recycling logic was triggered: the aborted task must be
        # recycled and removed from the abort set.
        mock_logger.info.assert_any_call(f"Aborted task {task_id} received negative token. Recycling.")
        processor._recycle_resources.assert_called_once_with(task_id, 0, task)
        self.assertNotIn(task_id, processor.resource_manager.abort_req_ids_set)
def test_process_batch_output_non_aborted_task_negative_token(self):
    """Test non-aborted task receiving negative token does not trigger recycling"""
    processor = self.setup_token_processor(speculative_decoding=False, use_logprobs=False)
    # Set up task as not aborted
    task_id = "test_normal_request"
    task = processor.resource_manager.tasks_list[0]
    task.request_id = task_id
    processor.resource_manager.abort_req_ids_set = set()  # Empty set
    # Mock _recycle_resources to track if it's called
    processor._recycle_resources = MagicMock()
    # Set up output tokens with negative token
    # batch = 1
    processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(1))
    # Set negative token
    processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(-1))
    # Patch llm_logger and envs.ENABLE_V1_KVCACHE_SCHEDULER; the logger mock
    # is not inspected here, so no `as` alias is needed.
    with (
        patch("fastdeploy.output.token_processor.llm_logger"),
        patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0),
    ):
        # Call the method
        processor._process_batch_output()
        # Verify the recycling logic was NOT triggered.
        # When a non-aborted task receives a negative token, the code just
        # continues without logging or recycling.
        processor._recycle_resources.assert_not_called()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main(verbosity=2, buffer=False) unittest.main(verbosity=2, buffer=False)
@@ -153,6 +153,60 @@ class TestTokenProcessorLogprobs(unittest.TestCase):
self.assertEqual(len(result), 0) self.assertEqual(len(result), 0)
def test_process_batch_output_use_zmq_aborted_task_negative_token(self):
    """Test aborted task receiving negative token triggers recycling logic"""
    # Set up task as aborted
    task_id = "test_aborted_request"
    self.task_mock.request_id = task_id
    self.processor.resource_manager.abort_req_ids_set = {task_id}
    # Create stream data with negative token
    stream_data = MagicMock()
    stream_data.tokens = np.array([1, 2, -1])  # Last token is negative
    stream_data.batch_id = 0
    # Mock _recycle_resources to track if it's called
    self.processor._recycle_resources = MagicMock()
    # Mock the llm_logger module and envs.ENABLE_V1_KVCACHE_SCHEDULER
    with (
        patch("fastdeploy.output.token_processor.llm_logger") as mock_logger,
        patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0),
    ):
        # Call the method
        result = self.processor._process_batch_output_use_zmq([stream_data])
        # Verify the recycling logic was triggered and the task removed from
        # the abort set.
        mock_logger.info.assert_any_call(f"Aborted task {task_id} received negative token. Recycling.")
        self.processor._recycle_resources.assert_called_once_with(task_id, 0, self.task_mock)
        self.assertNotIn(task_id, self.processor.resource_manager.abort_req_ids_set)
        # The zmq path surfaces the abort to the caller as a finished result
        # with HTTP-style code 499 (client closed request).
        self.assertEqual(len(result), 1)  # Should return abort result
        self.assertEqual(result[0].finished, True)
        self.assertEqual(result[0].error_code, 499)
        self.assertIn("aborted", result[0].error_msg.lower())
def test_process_batch_output_use_zmq_non_aborted_task_negative_token(self):
    """Test non-aborted task receiving negative token does not trigger recycling"""
    # Set up task as not aborted
    task_id = "test_normal_request"
    self.task_mock.request_id = task_id
    self.processor.resource_manager.abort_req_ids_set = set()  # Empty set
    # Create stream data with negative token
    stream_data = MagicMock()
    stream_data.tokens = np.array([1, 2, -1])  # Last token is negative
    stream_data.batch_id = 0
    # Mock _recycle_resources to track if it's called
    self.processor._recycle_resources = MagicMock()
    # Call the method
    self.processor._process_batch_output_use_zmq([stream_data])
    # Verify recycling logic was NOT triggered
    self.processor._recycle_resources.assert_not_called()
    # NOTE(review): this assumes self.processor.llm_logger is a mock attribute
    # set up by the fixture — confirm against this class's setUp.
    self.processor.llm_logger.info.assert_not_called()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
+1
View File
@@ -74,6 +74,7 @@ class _DummyResourceManager:
self.req_dict = {} self.req_dict = {}
self.requests = {} self.requests = {}
self.to_be_rescheduled_request_id_set = set() self.to_be_rescheduled_request_id_set = set()
self.abort_req_ids_set = set()
self.recycled = [] self.recycled = []
self.cached_tasks = [] self.cached_tasks = []
self.cleared = False self.cleared = False