diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index 8336d8a8f5..f5b81c0271 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -444,33 +444,35 @@ class PrefixCacheManager: else: return True - def allocate_gpu_blocks(self, num_blocks): + def allocate_gpu_blocks(self, num_blocks, req_id=None): """ allocate gpu blocks. """ assert num_blocks <= len( self.gpu_free_block_list ), f"gpu free block num: {len(self.gpu_free_block_list)} < needed number {num_blocks}" + logger.debug(f"{req_id} start allocate...") allocated_block_ids = [heapq.heappop(self.gpu_free_block_list) for i in range(num_blocks)] logger.info( - f"allocate_gpu_blocks: {allocated_block_ids}, len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}" + f"req_id:{req_id} allocate_gpu_blocks: {allocated_block_ids}, len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}" ) main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list)) main_process_metrics.available_gpu_resource.set(self.available_gpu_resource) return allocated_block_ids - def recycle_gpu_blocks(self, gpu_block_ids): + def recycle_gpu_blocks(self, gpu_block_ids, req_id=None): """ recycle gpu blocks. 
""" logger.info( - f"recycle_gpu_blocks: {gpu_block_ids}, len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}" + f"req_id:{req_id} recycle_gpu_blocks: {gpu_block_ids}, len(self.gpu_free_block_list) {len(self.gpu_free_block_list)}" ) if isinstance(gpu_block_ids, list): for gpu_block_id in gpu_block_ids: heapq.heappush(self.gpu_free_block_list, gpu_block_id) else: heapq.heappush(self.gpu_free_block_list, gpu_block_ids) + logger.debug(f"req_id:{req_id} recycle blocks end") main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list)) main_process_metrics.available_gpu_resource.set(self.available_gpu_resource) @@ -978,7 +980,7 @@ class PrefixCacheManager: logger.info(f"release_block_ids: req_id {req_id} leaf_node {leaf_node}") if leaf_node == self.radix_tree_root: - self.recycle_gpu_blocks(self.unfilled_req_block_map[req_id]) + self.recycle_gpu_blocks(self.unfilled_req_block_map[req_id], req_id) del self.unfilled_req_block_map[req_id] return diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index a75a4f768d..99e0dfd1c6 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -38,7 +38,7 @@ import zmq from tqdm import tqdm import fastdeploy.metrics.trace as tracing -from fastdeploy.engine.request import Request, RequestOutput, RequestType +from fastdeploy.engine.request import Request, RequestOutput, RequestStatus, RequestType from fastdeploy.engine.resource_manager import ResourceManager from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1 from fastdeploy.eplb.utils import init_eplb_signals @@ -721,6 +721,7 @@ class EngineService: max_num_batched_tokens=self.cfg.scheduler_config.max_num_batched_tokens, batch=num_prefill_batch, ) + tasks = [task for task in tasks if task.request_id not in self.resource_manager.abort_req_ids_set] for task in tasks: task.metrics.engine_get_req_time = time.time() trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, 
getattr(task, "user", "")) @@ -787,6 +788,7 @@ class EngineService: max_num_batched_tokens=max_num_batched_tokens, batch=num_prefill_batch, ) + tasks = [task for task in tasks if task.request_id not in self.resource_manager.abort_req_ids_set] for task in tasks: task.metrics.engine_get_req_time = time.time() trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", "")) @@ -1074,6 +1076,24 @@ class EngineService: request, insert_task = None, [] results: List[Tuple[str, Optional[str]]] = list() if data: + status_value = data.get("status", None) + if status_value is not None and status_value == RequestStatus.ABORT.value: + req_id = data["request_id"] + self.llm_logger.info(f"Receive abort request, req_id: {req_id}") + self.resource_manager.abort_req_ids_set.add(req_id) + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + if req_id in self.resource_manager.requests: + req = self.resource_manager.requests[req_id] + task = self.resource_manager._prepare_preempt_task(req) + self.engine_worker_queue.put_tasks(([task], self.resource_manager.real_bsz)) + self.llm_logger.info(f"put abort task in engine worker queue, req_id: {req_id}") + else: + self.scheduler._recycle(req_id) + self.llm_logger.info( + f"req_id:{req_id} has not been allocated any resources, recycled it in scheduler" + ) + self.resource_manager.abort_req_ids_set.remove(req_id) + continue err_msg = None try: request = Request.from_dict(data) diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 2c2f0f6c20..34a3ad2845 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -43,6 +43,7 @@ class RequestStatus(Enum): RUNNING = 1 PREEMPTED = 2 FINISHED = 3 + ABORT = 4 class RequestType(Enum): diff --git a/fastdeploy/engine/resource_manager.py b/fastdeploy/engine/resource_manager.py index 4388e01dd8..b7dbb61c6b 100644 --- a/fastdeploy/engine/resource_manager.py +++ b/fastdeploy/engine/resource_manager.py @@ -58,6 +58,7 @@ class ResourceManager: 
self.req_dict = dict() # current batch status of the engine self.real_bsz = 0 + self.abort_req_ids_set = set() llm_logger.info(f"{self.info()}") main_process_metrics.max_batch_size.set(max_num_seqs) diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 2dd8bb1100..d55cf06c00 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -175,6 +175,7 @@ class ResourceManagerV1(ResourceManager): self.using_extend_tables_req_id = set() self.reuse_block_num_map = dict() + self.abort_req_ids_set = set() # need block nums need_block_num_data = np.zeros([max_num_seqs], dtype=np.int32) @@ -243,6 +244,7 @@ class ResourceManagerV1(ResourceManager): request = self.requests[request_id] if process_func is not None: process_func(request) + llm_logger.debug(f"self.waiting append request:{request.request_id},req.type:{request.status}") self.waiting.appendleft(request) self.to_be_rescheduled_request_id_set.remove(request_id) @@ -640,7 +642,9 @@ class ResourceManagerV1(ResourceManager): f"schedule decoding task: {request} request.num_total_tokens {request.num_total_tokens} request.num_computed_tokens {request.num_computed_tokens}" ) request.block_tables.extend( - self.cache_manager.allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num) + self.cache_manager.allocate_gpu_blocks( + self.config.cache_config.enc_dec_block_num, request.request_id + ) ) # Prepare decoding task scheduled_reqs.append(self._prepare_decode_task(request)) @@ -653,7 +657,9 @@ class ResourceManagerV1(ResourceManager): break # Allocation for next decoding blocks request.block_tables.extend( - self.cache_manager.allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num) + self.cache_manager.allocate_gpu_blocks( + self.config.cache_config.enc_dec_block_num, request.request_id + ) ) # Prepare decoding task scheduled_reqs.append(self._prepare_decode_task(request)) @@ -668,7 +674,9 @@ class 
ResourceManagerV1(ResourceManager): def _allocate_decode_and_extend(): allocate_block_num = self.need_block_num_map[request.request_id].consume() # Prepare decoding task - request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(allocate_block_num)) + request.block_tables.extend( + self.cache_manager.allocate_gpu_blocks(allocate_block_num, request.request_id) + ) scheduled_reqs.append(self._prepare_decode_task(request)) # Prepare extend task @@ -682,7 +690,7 @@ class ResourceManagerV1(ResourceManager): request.extend_block_tables = request.block_tables[:reuse_block_num] # copy prompt cache request.extend_block_tables.extend( - self.cache_manager.allocate_gpu_blocks(allocate_block_num) + self.cache_manager.allocate_gpu_blocks(allocate_block_num, request.request_id) ) scheduled_reqs.append( ScheduledExtendBlocksTask( @@ -733,14 +741,18 @@ class ResourceManagerV1(ResourceManager): num_new_block = self.get_new_block_nums(request, num_new_tokens) # Allocate blocks to prefill if self.cache_manager.can_allocate_gpu_blocks(num_new_block): - request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block)) + request.block_tables.extend( + self.cache_manager.allocate_gpu_blocks(num_new_block, request.request_id) + ) # Prepare prefill task scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens)) else: # Not enough blocks to allocate, trigger preemption can_schedule = self._trigger_preempt(request, num_new_block, preempted_reqs, scheduled_reqs) if not can_schedule: break - request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block)) + request.block_tables.extend( + self.cache_manager.allocate_gpu_blocks(num_new_block, request.request_id) + ) # Prepare prefill task scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens)) token_budget -= num_new_tokens @@ -806,7 +818,9 @@ class ResourceManagerV1(ResourceManager): # Allocate blocks to prefill if 
self.cache_manager.can_allocate_gpu_blocks(can_schedule_block_num_threshold): if not request.get("skip_allocate", False): - extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(num_new_block) + extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks( + num_new_block, request.request_id + ) request.block_tables.extend(extra_gpu_block_ids) self.waiting.popleft() self.running.append(request) @@ -824,6 +838,7 @@ class ResourceManagerV1(ResourceManager): self.tasks_list[allocated_position] = request self.stop_flags[allocated_position] = False self.req_dict[request.request_id] = allocated_position + llm_logger.debug(f"req_id:{request.request_id} allocate pos end") else: if self.config.cache_config.enable_prefix_caching: self._free_blocks(request) @@ -856,7 +871,9 @@ class ResourceManagerV1(ResourceManager): # Allocate blocks to prefill if self.cache_manager.can_allocate_gpu_blocks(can_schedule_block_num_threshold): if not request.get("skip_allocate", False): - extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(num_new_block) + extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks( + num_new_block, request.request_id + ) request.block_tables.extend(extra_gpu_block_ids) self.waiting.popleft() self.running.append(request) @@ -873,7 +890,7 @@ class ResourceManagerV1(ResourceManager): self._free_blocks(request) break else: - llm_logger.error("Unknown request status type") + llm_logger.info(f"Unknown request status type:{request.status}, req_id:{request.request_id}") for req in skip_requests: # move waiting request to end of the deque @@ -1069,6 +1086,7 @@ class ResourceManagerV1(ResourceManager): def add_request(self, request: Request) -> None: with self.lock: self.apply_async_preprocess(request) + llm_logger.debug(f"self.waiting append request:{request.request_id},req.type:{request.status}") self.waiting.append(request) self.requests[request.request_id] = request @@ -1120,7 +1138,9 @@ class ResourceManagerV1(ResourceManager): 
need_extra_prefill_blocks = need_prealloc_prefill_blocks - request.cache_info[0] if self.cache_manager.can_allocate_gpu_blocks(need_extra_prefill_blocks): - extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(need_extra_prefill_blocks) + extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks( + need_extra_prefill_blocks, request.request_id + ) request.block_tables.extend(extra_gpu_block_ids) allocated_position = self.get_available_position() request.idx = allocated_position @@ -1135,7 +1155,9 @@ class ResourceManagerV1(ResourceManager): else: if self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks): - request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks)) + request.block_tables.extend( + self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks, request.request_id) + ) request.num_computed_tokens = 0 allocated_position = self.get_available_position() request.idx = allocated_position @@ -1169,7 +1191,9 @@ class ResourceManagerV1(ResourceManager): if not self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks): return False - request.block_tables = self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks) + request.block_tables = self.cache_manager.allocate_gpu_blocks( + need_prealloc_prefill_blocks, request.request_id + ) request.num_computed_tokens = request.need_prefill_tokens request.disaggregate_info["block_tables"] = request.block_tables allocated_position = self.get_available_position() @@ -1221,16 +1245,18 @@ class ResourceManagerV1(ResourceManager): def _free_blocks(self, request: Request): if self.config.cache_config.enable_prefix_caching: self.cache_manager.release_block_ids(request) - self.cache_manager.recycle_gpu_blocks(request.block_tables[request.num_cached_blocks :]) + self.cache_manager.recycle_gpu_blocks( + request.block_tables[request.num_cached_blocks :], request.request_id + ) else: - 
self.cache_manager.recycle_gpu_blocks(request.block_tables) + self.cache_manager.recycle_gpu_blocks(request.block_tables, request.request_id) request.block_tables = [] if request.request_id in self.using_extend_tables_req_id: reuse_block_num = self.reuse_block_num_map[request.request_id] self.using_extend_tables_req_id.remove(request.request_id) - self.cache_manager.recycle_gpu_blocks(request.extend_block_tables[reuse_block_num:]) + self.cache_manager.recycle_gpu_blocks(request.extend_block_tables[reuse_block_num:], request.request_id) llm_logger.info( f"req {request.request_id} recycle extend blocks {request.extend_block_tables[reuse_block_num:]}" ) @@ -1280,9 +1306,9 @@ class ResourceManagerV1(ResourceManager): for req in need_postprocess_reqs: try: self._free_blocks(req) + llm_logger.debug(f"req_id:{req.request_id} free pos:{req.idx}") except Exception as e: llm_logger.warning(f"release block failed {req.request_id}: {e}") - except Exception as e: llm_logger.error(f"finish_request err: {e}, {str(traceback.format_exc())}") finally: diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index 7c6d5a3267..0a0c61da39 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -16,6 +16,7 @@ import inspect import os +import re import time import traceback import uuid @@ -28,6 +29,7 @@ from filelock import FileLock import fastdeploy.metrics.trace as tracing from fastdeploy import envs from fastdeploy.config import FDConfig +from fastdeploy.engine.request import RequestStatus from fastdeploy.entrypoints.openai.utils import DealerConnectionManager from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS from fastdeploy.eplb.utils import RedundantExpertWorkload @@ -844,3 +846,27 @@ class EngineClient: content = {"code": 0, "msg": "ok", "data": update_weight_from_disk_list} status_code = HTTPStatus.OK return content, status_code + + async def abort(self, request_id, n=1) -> None: + if 
envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE: + api_server_logger.info(f"abort request_id:{request_id}") + if n <= 0: + api_server_logger.warning("Abort function called with non-positive n: %d. No requests aborted.", n) + return + match = re.search(r"_\d+$", request_id) + if match: + prefix = request_id[: match.start()] + else: + api_server_logger.warning( + "request_id format error: %s does not end with _. Using it as prefix.", request_id + ) + prefix = request_id + request_ids = [f"{prefix}_{i}" for i in range(n)] + for req_id in request_ids: + data = { + "request_id": req_id, + "status": RequestStatus.ABORT.value, + } + self._send_task(data) + + api_server_logger.info("Aborted request(s) %s.", ",".join(request_ids)) diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index 13879f72a8..fed19e8c81 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -59,7 +59,11 @@ from fastdeploy.entrypoints.openai.serving_embedding import OpenAIServingEmbeddi from fastdeploy.entrypoints.openai.serving_models import ModelPath, OpenAIServingModels from fastdeploy.entrypoints.openai.serving_reward import OpenAIServingReward from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager -from fastdeploy.entrypoints.openai.utils import UVICORN_CONFIG, make_arg_parser +from fastdeploy.entrypoints.openai.utils import ( + UVICORN_CONFIG, + make_arg_parser, + with_cancellation, +) from fastdeploy.entrypoints.openai.v1.serving_chat import ( OpenAIServingChat as OpenAIServingChatV1, ) @@ -410,6 +414,7 @@ def wrap_streaming_generator(original_generator: AsyncGenerator): @app.post("/v1/chat/completions") +@with_cancellation async def create_chat_completion(request: ChatCompletionRequest, req: Request): """ Create a chat completion for the provided prompt and parameters. 
@@ -446,6 +451,7 @@ async def create_chat_completion(request: ChatCompletionRequest, req: Request): @app.post("/v1/completions") +@with_cancellation async def create_completion(request: CompletionRequest, req: Request): """ Create a completion for the provided prompt and parameters. diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index fa3a0f74d5..ec00b1562c 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -180,6 +180,13 @@ class OpenAIServingChat: error_msg = f"request[{request_id}]full generator error: {str(e)}, {str(traceback.format_exc())}" api_server_logger.error(error_msg) return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.INTERNAL_ERROR)) + except asyncio.CancelledError as e: + await self.engine_client.abort(f"{request_id}_0", 1 if request.n is None else request.n) + error_msg = f"request[{request_id}_0] client disconnected: {str(e)}, {str(traceback.format_exc())}" + api_server_logger.error(error_msg) + return ErrorResponse( + error=ErrorInfo(message=error_msg, type=ErrorType.INVALID_REQUEST_ERROR, code=ErrorCode.CLIENT_ABORTED) + ) except Exception as e: error_msg = ( f"request[{request_id}] waiting error: {str(e)}, {str(traceback.format_exc())}, " @@ -267,7 +274,9 @@ class OpenAIServingChat: except asyncio.TimeoutError: current_waiting_time += 10 if current_waiting_time == 300: - status, msg = self.engine_client.check_health(time_interval_threashold=envs.FD_WORKER_ALIVE_TIMEOUT) + status, msg = self.engine_client.check_health( + time_interval_threashold=envs.FD_WORKER_ALIVE_TIMEOUT + ) if not status: if choices: chunk.choices = choices @@ -506,6 +515,10 @@ class OpenAIServingChat: ) yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" + except asyncio.CancelledError as e: + await self.engine_client.abort(f"{request_id}_0", 1 if request.n is None else request.n) + error_msg = f"request[{request_id}_0] 
client disconnected: {str(e)}, {str(traceback.format_exc())}" + api_server_logger.error(error_msg) except Exception as e: error_data = self._create_streaming_error_response( f"request[{request_id}] generate stream error: {str(e)}, {str(traceback.format_exc())}" @@ -577,7 +590,9 @@ class OpenAIServingChat: except asyncio.TimeoutError: current_waiting_time += 10 if current_waiting_time == 300: - status, msg = self.engine_client.check_health(time_interval_threashold=envs.FD_WORKER_ALIVE_TIMEOUT) + status, msg = self.engine_client.check_health( + time_interval_threashold=envs.FD_WORKER_ALIVE_TIMEOUT + ) if not status: raise ValueError(f"Engine is not healthy: {msg}") else: diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index 9f211335b4..082439822e 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -230,7 +230,13 @@ class OpenAIServingCompletion: ) api_server_logger.error(error_msg) return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.INTERNAL_ERROR)) - + except asyncio.CancelledError as e: + await self.engine_client.abort(f"{request_id}_0", num_choices) + error_msg = f"request[{request_id}_0] client disconnected: {str(e)}, {str(traceback.format_exc())}" + api_server_logger.error(error_msg) + return ErrorResponse( + error=ErrorInfo(message=error_msg, type=ErrorType.INVALID_REQUEST_ERROR, code=ErrorCode.CLIENT_ABORTED) + ) except Exception as e: error_msg = f"OpenAIServingCompletion create_completion error: {e}, {str(traceback.format_exc())}" api_server_logger.error(error_msg) @@ -285,7 +291,9 @@ class OpenAIServingCompletion: except asyncio.TimeoutError: current_waiting_time += 10 if current_waiting_time == 300: - status, msg = self.engine_client.check_health(time_interval_threashold=envs.FD_WORKER_ALIVE_TIMEOUT) + status, msg = self.engine_client.check_health( + 
time_interval_threashold=envs.FD_WORKER_ALIVE_TIMEOUT + ) if not status: raise ValueError(f"Engine is not healthy: {msg}") else: @@ -455,7 +463,9 @@ class OpenAIServingCompletion: except asyncio.TimeoutError: current_waiting_time += 10 if current_waiting_time == 300: - status, msg = self.engine_client.check_health(time_interval_threashold=envs.FD_WORKER_ALIVE_TIMEOUT) + status, msg = self.engine_client.check_health( + time_interval_threashold=envs.FD_WORKER_ALIVE_TIMEOUT + ) if not status: raise ValueError(f"Engine is not healthy: {msg}") else: @@ -634,6 +644,10 @@ class OpenAIServingCompletion: yield f"data: {usage_chunk.model_dump_json(exclude_unset=True)}\n\n" api_server_logger.info(f"Completion Streaming response last send: {chunk.model_dump_json()}") + except asyncio.CancelledError as e: + await self.engine_client.abort(f"{request_id}_0", num_choices) + error_msg = f"request[{request_id}_0] client disconnected: {str(e)}, {str(traceback.format_exc())}" + api_server_logger.error(error_msg) except Exception as e: api_server_logger.error(f"Error in completion_stream_generator: {e}, {str(traceback.format_exc())}") yield f"data: {ErrorResponse(error=ErrorInfo(message=str(e), code='400', type=ErrorType.INTERNAL_ERROR)).model_dump_json(exclude_unset=True)}\n\n" diff --git a/fastdeploy/entrypoints/openai/utils.py b/fastdeploy/entrypoints/openai/utils.py index 9a7fe239a9..451934cb60 100644 --- a/fastdeploy/entrypoints/openai/utils.py +++ b/fastdeploy/entrypoints/openai/utils.py @@ -15,6 +15,7 @@ """ import asyncio +import functools import heapq import random import time @@ -22,6 +23,7 @@ from multiprocessing.reduction import ForkingPickler import aiozmq import zmq +from fastapi import Request from fastdeploy.engine.args_utils import EngineArgs from fastdeploy.metrics.metrics import main_process_metrics @@ -253,3 +255,55 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser = EngineArgs.add_cli_args(parser) return parser + + +async def 
listen_for_disconnect(request: Request) -> None: + """Returns if a disconnect message is received""" + while True: + message = await request.receive() + if message["type"] == "http.disconnect": + break + + +def with_cancellation(handler_func): + """Decorator that allows a route handler to be cancelled by client + disconnections. + + This does _not_ use request.is_disconnected, which does not work with + middleware. Instead this follows the pattern from + starlette.StreamingResponse, which simultaneously awaits on two tasks- one + to wait for an http disconnect message, and the other to do the work that we + want done. When the first task finishes, the other is cancelled. + + A core assumption of this method is that the body of the request has already + been read. This is a safe assumption to make for fastapi handlers that have + already parsed the body of the request into a pydantic model for us. + This decorator is unsafe to use elsewhere, as it will consume and throw away + all incoming messages for the request while it looks for a disconnect + message. + + In the case where a `StreamingResponse` is returned by the handler, this + wrapper will stop listening for disconnects and instead the response object + will start listening for disconnects.The response object will only correctly + listen when the ASGI protocol version used by Uvicorn is less than 2.4(Excluding 2.4). + """ + + # Functools.wraps is required for this wrapper to appear to fastapi as a + # normal route handler, with the correct request type hinting. 
+    @functools.wraps(handler_func) +    async def wrapper(*args, **kwargs): +        # The request is either the second positional arg or the `req` keyword arg +        request = args[1] if len(args) > 1 else kwargs["req"] + +        handler_task = asyncio.create_task(handler_func(*args, **kwargs)) +        cancellation_task = asyncio.create_task(listen_for_disconnect(request)) + +        done, pending = await asyncio.wait([handler_task, cancellation_task], return_when=asyncio.FIRST_COMPLETED) +        for task in pending: +            task.cancel() + +        if handler_task in done: +            return handler_task.result() +        return None + +    return wrapper diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 9612b90529..9cba6cbe51 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -158,6 +158,9 @@ environment_variables: dict[str, Callable[[], Any]] = { # "Enable FP8 calibration on HPU" "FD_HPU_MEASUREMENT_MODE": lambda: os.getenv("FD_HPU_MEASUREMENT_MODE", "0"), "FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS": lambda: int(os.getenv("FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS", "30")), +    "FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE": lambda: int( +        os.getenv("FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", "1") +    ), # Whether to collect user information "DO_NOT_TRACK": lambda: (os.getenv("DO_NOT_TRACK", "0")) == "1", # Usage stats server url diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index dcafb4c322..15c5a4e2ba 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -248,9 +248,24 @@ class TokenProcessor: task: Request = self.resource_manager.tasks_list[i] task_id = task.request_id - token_ids = stream_data.tokens # numpy.array if token_ids is not None and token_ids[-1] <= 0: + if task_id in self.resource_manager.abort_req_ids_set: + if ( + envs.ENABLE_V1_KVCACHE_SCHEDULER and token_ids[-1] == PREEMPTED_TOKEN_ID + ) or not envs.ENABLE_V1_KVCACHE_SCHEDULER: + llm_logger.info(f"Aborted task {task_id} received negative token. 
Recycling.") + self.resource_manager.abort_req_ids_set.remove(task_id) + self._recycle_resources(task_id, i, task) + llm_logger.info(f"{task_id} received negative token. Recycle end.") + abort_res = RequestOutput( + request_id=task_id, + finished=True, + error_code=499, + error_msg=f"Your request with request_id:{task_id} is aborted.", + ) + batch_result.append(abort_res) + continue if envs.ENABLE_V1_KVCACHE_SCHEDULER: if ( task_id in self.resource_manager.to_be_rescheduled_request_id_set @@ -711,6 +726,19 @@ class TokenProcessor: if accept_num[i] == PREEMPTED_TOKEN_ID: # in MTP, meas preemption has happend in worker llm_logger.info(f"sync preemption for request_id {task_id} done.") if envs.ENABLE_V1_KVCACHE_SCHEDULER: + if task_id in self.resource_manager.abort_req_ids_set: + llm_logger.info(f"Aborted task {task_id} received negative token. Recycling.") + self.resource_manager.abort_req_ids_set.remove(task_id) + self._recycle_resources(task_id, i, task) + llm_logger.info(f"{task_id} received negative token. Recycle end.") + abort_res = RequestOutput( + request_id=task_id, + finished=True, + error_code=499, + error_msg=f"Your request with request_id:{task_id} is aborted.", + ) + batch_result.append(abort_res) + continue if task_id in self.resource_manager.to_be_rescheduled_request_id_set: self.resource_manager.reschedule_preempt_task(task_id) continue @@ -737,6 +765,22 @@ class TokenProcessor: if recovery_stop: llm_logger.info(f"recovery stop signal found at task {task_id}") if not recovery_stop and token_id < 0: + if task_id in self.resource_manager.abort_req_ids_set: + if ( + envs.ENABLE_V1_KVCACHE_SCHEDULER and token_id == PREEMPTED_TOKEN_ID + ) or not envs.ENABLE_V1_KVCACHE_SCHEDULER: + llm_logger.info(f"Aborted task {task_id} received negative token. Recycling.") + self.resource_manager.abort_req_ids_set.remove(task_id) + self._recycle_resources(task_id, i, task) + llm_logger.info(f"{task_id} received negative token. 
Recycle end.") + abort_res = RequestOutput( + request_id=task_id, + finished=True, + error_code=499, + error_msg=f"Your request with request_id:{task_id} is aborted.", + ) + batch_result.append(abort_res) + continue if envs.ENABLE_V1_KVCACHE_SCHEDULER: if ( task_id in self.resource_manager.to_be_rescheduled_request_id_set @@ -760,6 +804,7 @@ class TokenProcessor: task.metrics.record_recv_first_token() task.metrics.cal_cost_time() metrics = copy.copy(task.metrics) + llm_logger.info(f"task:{task.request_id} start record first token") self._record_first_token_metrics(task, current_time) tracing.trace_report_span( @@ -845,7 +890,6 @@ class TokenProcessor: result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids]) result.outputs.top_logprobs.logprobs.extend([topk_logprobs]) result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank]) - if token_id in task.eos_token_ids or is_prefill or recovery_stop: result.finished = True trace_carrier = tracing.trace_get_proc_propagate_context(rid=rid) @@ -872,7 +916,9 @@ class TokenProcessor: self._compute_speculative_status(result) if not is_prefill: self._record_completion_metrics(task, current_time) + llm_logger.info(f"task {task_id} received eos token. 
Recycling.") self._recycle_resources(task_id, i, task, result, is_prefill) + llm_logger.info(f"eos token {task_id} Recycle end.") break llm_logger.debug(f"get response from infer: {result}") diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py index 2796b931c2..bca191442c 100644 --- a/fastdeploy/utils.py +++ b/fastdeploy/utils.py @@ -220,6 +220,7 @@ class ErrorCode(str, Enum): CONNECTION_ERROR = "connection_error" MISSING_REQUIRED_PARAMETER = "missing_required_parameter" INTERNAL_ERROR = "internal_error" + CLIENT_ABORTED = "client_aborted" class ColoredFormatter(logging.Formatter): diff --git a/requirements_dcu.txt b/requirements_dcu.txt index dc756ba490..803b6cff26 100644 --- a/requirements_dcu.txt +++ b/requirements_dcu.txt @@ -8,7 +8,7 @@ aiozmq openai tqdm pynvml -uvicorn==0.29.0 +uvicorn>=0.38.0 fastapi paddleformers @ https://paddle-qa.bj.bcebos.com/ernie/paddleformers-0.4.0.post20251222-py3-none-any.whl redis diff --git a/requirements_iluvatar.txt b/requirements_iluvatar.txt index abeef5b228..4ed3341820 100644 --- a/requirements_iluvatar.txt +++ b/requirements_iluvatar.txt @@ -8,7 +8,7 @@ aiozmq openai>=1.93.0 tqdm pynvml -uvicorn==0.29.0 +uvicorn>=0.38.0 fastapi paddleformers @ https://paddle-qa.bj.bcebos.com/ernie/paddleformers-0.4.0.post20251222-py3-none-any.whl redis diff --git a/requirements_metaxgpu.txt b/requirements_metaxgpu.txt index ad06c803cd..eaece2b673 100644 --- a/requirements_metaxgpu.txt +++ b/requirements_metaxgpu.txt @@ -8,7 +8,7 @@ aiozmq openai>=1.93.0 tqdm pynvml -uvicorn==0.29.0 +uvicorn>=0.38.0 fastapi # if paddleformers version > 0.3.2, metax triton will be replaced by the newest triton. 
paddleformers==0.3.2 diff --git a/tests/ci_use/metrics/test_metrics.py b/tests/ci_use/metrics/test_metrics.py index 0fb46c54f4..ea0265a86a 100644 --- a/tests/ci_use/metrics/test_metrics.py +++ b/tests/ci_use/metrics/test_metrics.py @@ -14,7 +14,6 @@ import asyncio import os -import re import shutil import signal import subprocess @@ -48,7 +47,7 @@ def setup_and_run_server(): - Tears down server after all tests finish """ print("Pre-test port cleanup...") -    FD_CONTROLLER_PORT = int(os.getenv("FD_CONTROLLER_PORT", 8633)) +    FD_CONTROLLER_PORT = int(os.getenv("FD_CONTROLLER_PORT", 8333)) clean_ports([FD_API_PORT, FD_ENGINE_QUEUE_PORT, FD_METRICS_PORT, FD_CACHE_QUEUE_PORT, FD_CONTROLLER_PORT]) env = os.environ.copy() @@ -173,9 +172,12 @@ def parse_prometheus_to_dict(metrics_text: str): value = float(line.split("}")[1].strip()) # 解析 labels -            # 用正则取出所有 key 和 value(去掉外层引号) -            pairs = re.findall(r'(\w+)="([^"]*)"', labels_str) -            labels = {k: v for k, v in pairs} +            labels = {} +            for kv in labels_str.split(","): +                if "=" not in kv: +                    continue +                k, v = kv.split("=", 1) +                labels[k] = v.strip('"') # 存储 if metric_name not in result: @@ -212,7 +214,6 @@ def test_metrics_with_clear_and_reset(): """ Test the metrics monitoring endpoint. 
""" - FD_CONTROLLER_PORT = int(os.getenv("FD_CONTROLLER_PORT", 8633)) metrics_url = f"http://0.0.0.0:{FD_METRICS_PORT}/metrics" async_concurrency(n=10) @@ -229,23 +230,13 @@ def test_metrics_with_clear_and_reset(): running = metrics["fastdeploy:num_requests_running"] waiting = metrics["fastdeploy:num_requests_waiting"] - print("ASSERT clear_load_weight后非0 running:", running, "waiting:", waiting) - assert running != 0 or waiting != 0, "Expected running/waiting to be non-zero" - - # ===== reset_scheduler ===== - reset_url = f"http://0.0.0.0:{FD_CONTROLLER_PORT}/controller/reset_scheduler" - print("Calling reset_scheduler...") - r = requests.post(reset_url, json={"reset": True}, timeout=30) - assert r.status_code == 200, f"reset_scheduler failed: {r.status_code}" - - metrics = get_metrics_dict(metrics_url) - running = metrics["fastdeploy:num_requests_running"] - waiting = metrics["fastdeploy:num_requests_waiting"] - - print("ASSERT reset_scheduler后为0 running:", running, "waiting:", waiting) - # Temporarily disable this assertion. The running/waiting states are not strictly - # guaranteed to reach zero in the current workflow, so we skip this check for now. 
- # assert running == 0 and waiting == 0, "Expected running/waiting to be zero" + print( + "ASSERT after the clear_load_weight operation, the value is 0 (Request interruption stopped inference, and related requests were cleared):", + running, + "waiting:", + waiting, + ) + assert running == 0 and waiting == 0, "Expected both running and waiting to be 0 after clear_load_weight" if __name__ == "__main__": diff --git a/tests/entrypoints/openai/test_api_server.py b/tests/entrypoints/openai/test_api_server.py index 40d8da93f8..f491d60295 100644 --- a/tests/entrypoints/openai/test_api_server.py +++ b/tests/entrypoints/openai/test_api_server.py @@ -83,6 +83,7 @@ def _reload_api_server(args): fake_envs_mod.TRACES_EXPORTER = "console" fake_envs_mod.EXPORTER_OTLP_ENDPOINT = "" fake_envs_mod.EXPORTER_OTLP_HEADERS = "" + fake_envs_mod.FD_SUPPORT_MAX_CONNECTIONS = 1024 fake_envs_mod.environment_variables = _FakeEnvVars() # Save original sys.argv and replace with minimal valid args to avoid parse errors @@ -397,26 +398,36 @@ async def test_chat_and_completion_routes(): chat_handler = MagicMock() chat_handler.create_chat_completion = AsyncMock(return_value=error_resp) api_server.app.state.chat_handler = chat_handler - assert (await api_server.create_chat_completion(body, fake_req)).status_code == 500 + + with patch("fastdeploy.entrypoints.openai.api_server.connection_manager", return_value=DummyCM()): + assert (await api_server.create_chat_completion(body, fake_req)).status_code == 500 success_resp = ChatCompletionResponse(id="1", model="m", choices=[], usage=UsageInfo()) api_server.app.state.chat_handler.create_chat_completion = AsyncMock(return_value=success_resp) - assert (await api_server.create_chat_completion(body, fake_req)).status_code == 200 + + with patch("fastdeploy.entrypoints.openai.api_server.connection_manager", return_value=DummyCM()): + assert (await api_server.create_chat_completion(body, fake_req)).status_code == 200 async def stream_gen(): yield "data" 
api_server.app.state.chat_handler.create_chat_completion = AsyncMock(return_value=stream_gen()) - assert isinstance(await api_server.create_chat_completion(body, fake_req), api_server.StreamingResponse) + + with patch("fastdeploy.entrypoints.openai.api_server.connection_manager", return_value=DummyCM()): + assert isinstance(await api_server.create_chat_completion(body, fake_req), api_server.StreamingResponse) # Completion handler completion_handler = MagicMock() completion_handler.create_completion = AsyncMock(return_value=error_resp) api_server.app.state.completion_handler = completion_handler - assert (await api_server.create_completion(body, fake_req)).status_code == 500 + + with patch("fastdeploy.entrypoints.openai.api_server.connection_manager", return_value=DummyCM()): + assert (await api_server.create_completion(body, fake_req)).status_code == 500 api_server.app.state.completion_handler.create_completion = AsyncMock(return_value=success_resp) - assert (await api_server.create_completion(body, fake_req)).status_code == 200 + + with patch("fastdeploy.entrypoints.openai.api_server.connection_manager", return_value=DummyCM()): + assert (await api_server.create_completion(body, fake_req)).status_code == 200 # HTTPException handling class RaiseHTTP: diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 40b4a3f4f0..2236f22af1 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -14,6 +14,7 @@ # limitations under the License. 
""" +import asyncio import json import unittest from unittest.mock import AsyncMock, MagicMock, Mock, patch @@ -1079,6 +1080,127 @@ class TestOpenAIServingCompletion(unittest.IsolatedAsyncioTestCase): # logprobs should be None when not requested self.assertIsNone(choice.logprobs) + async def test_create_chat_completion_cancelled_error(self): + """Test asyncio.CancelledError handling in create_chat_completion method""" + # Create mock request + request = ChatCompletionRequest(messages=[{"role": "user", "content": "Hello"}], stream=False) + + # Mock the semaphore + self.chat_completion_handler.engine_client.semaphore = MagicMock() + self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=True) + self.chat_completion_handler.engine_client.semaphore.release = MagicMock() + + # Mock the model weight status check + self.chat_completion_handler.engine_client.check_model_weight_status = Mock(return_value=False) + + # Mock format_and_add_data to raise CancelledError + self.chat_completion_handler.engine_client.format_and_add_data = AsyncMock( + side_effect=asyncio.CancelledError("Test cancellation during data formatting") + ) + + # Mock the abort method that should be called when CancelledError occurs + self.chat_completion_handler.engine_client.abort = AsyncMock() + + # Execute and verify that CancelledError is handled properly + # The CancelledError should be caught and handled, not re-raised + try: + await self.chat_completion_handler.create_chat_completion(request) + except asyncio.CancelledError: + # This should not happen as CancelledError should be caught and handled + self.fail("CancelledError should be caught and handled, not re-raised") + + # Verify abort was called despite the cancellation + self.chat_completion_handler.engine_client.abort.assert_called_once() + + async def test_chat_completion_stream_generator_cancelled_error(self): + """Test asyncio.CancelledError handling in chat_completion_stream_generator method""" + # Create 
mock request + request = ChatCompletionRequest(messages=[{"role": "user", "content": "Hello"}], stream=True) + + request_id = "test_cancel_request" + model_name = "test_model" + prompt_token_ids = [1, 2, 3] + prompt_tokens = "Hello world" + + # Mock the connection manager + mock_dealer = MagicMock() + mock_response_queue = AsyncMock() + + # Mock get_connection to return normally + self.chat_completion_handler.engine_client.connection_manager.get_connection = AsyncMock( + return_value=(mock_dealer, mock_response_queue) + ) + + # Mock the semaphore + self.chat_completion_handler.engine_client.semaphore = MagicMock() + self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=True) + self.chat_completion_handler.engine_client.semaphore.release = MagicMock() + + # Mock the model weight status check + self.chat_completion_handler.engine_client.check_model_weight_status = Mock(return_value=False) + + # Mock the response processor to raise CancelledError during processing + mock_response_processor = MagicMock() + mock_response_processor.enable_multimodal_content.return_value = False + + async def mock_async_generator_with_cancel(): + # Simulate some normal response first + yield { + "request_id": f"{request_id}_0", + "error_code": 200, + "metrics": { + "first_token_time": 1234567890, + "inference_start_time": 1234567880, + "arrival_time": 1234567890, + "request_start_time": 1234567870, + }, + "prompt_logprobs": None, + "outputs": { + "token_ids": [5], + "text": "Hi", + "top_logprobs": None, + "draft_top_logprobs": None, + "multipart": [{"type": "text", "text": "Hi"}], + }, + "finished": False, + "num_cached_tokens": 0, + "num_input_image_tokens": 0, + "num_input_video_tokens": 0, + } + # Then raise CancelledError + raise asyncio.CancelledError("Test cancellation during streaming") + + mock_response_processor.process_response_chat.return_value = mock_async_generator_with_cancel() + + # Mock the cleanup method + 
self.chat_completion_handler.engine_client.connection_manager.cleanup_request = AsyncMock() + + # Mock the abort method that should be called when CancelledError occurs + self.chat_completion_handler.engine_client.abort = AsyncMock() + + with patch( + "fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor", return_value=mock_response_processor + ): + # Execute the generator and verify CancelledError handling + # The CancelledError should be caught and handled, not re-raised + chunks = [] + try: + async for chunk in self.chat_completion_handler.chat_completion_stream_generator( + request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens=100 + ): + chunks.append(chunk) + except asyncio.CancelledError: + # This should not happen as CancelledError should be caught and handled + self.fail("CancelledError should be caught and handled, not re-raised") + + # Should have received at least one chunk before cancellation + self.assertGreaterEqual(len(chunks), 1) + self.assertIsNotNone(chunks[0]) + + # Verify cleanup and abort were called despite the cancellation + self.chat_completion_handler.engine_client.connection_manager.cleanup_request.assert_called_once() + self.chat_completion_handler.engine_client.abort.assert_called_once() + if __name__ == "__main__": unittest.main() diff --git a/tests/entrypoints/openai/test_serving_completion.py b/tests/entrypoints/openai/test_serving_completion.py index 761213d1d5..814f505e63 100644 --- a/tests/entrypoints/openai/test_serving_completion.py +++ b/tests/entrypoints/openai/test_serving_completion.py @@ -14,6 +14,7 @@ # limitations under the License. 
""" +import asyncio import unittest from typing import List from unittest.mock import AsyncMock, Mock, patch @@ -520,7 +521,7 @@ class TestOpenAIServingCompletion(unittest.IsolatedAsyncioTestCase): results.append(result) # Verify results - self.assertTrue(len(results) > 0) + self.assertGreater(len(results), 0) # Check that the first response contains prompt_logprobs self.assertIn("prompt_logprobs", results[0]) @@ -1216,6 +1217,118 @@ class TestOpenAIServingCompletion(unittest.IsolatedAsyncioTestCase): self.assertIsNone(result.choices[0].prompt_logprobs) self.assertIsNone(result.choices[0].logprobs) + async def test_create_completion_cancelled_error(self): + """Test create_completion with asyncio.CancelledError""" + # Mock the engine client and its dependencies + mock_engine_client = Mock() + mock_engine_client.is_master = True + mock_engine_client.semaphore = Mock() + mock_engine_client.semaphore.acquire = AsyncMock() + mock_engine_client.semaphore.release = Mock() + mock_engine_client.format_and_add_data = AsyncMock() + mock_engine_client.format_and_add_data.return_value = [1, 2, 3] + mock_engine_client.abort = AsyncMock() + + # Create serving completion instance + serving_completion = OpenAIServingCompletion(mock_engine_client, None, "pid", None, 360) + + # Create mock request + mock_request = Mock() + mock_request.prompt = "Hello, world!" 
+ mock_request.prompt_token_ids = None + mock_request.stream = False + mock_request.n = 1 + mock_request.request_id = None + mock_request.user = None + + # Mock to_dict_for_infer method to return a proper dict + def mock_to_dict_for_infer(request_id_idx, prompt): + return {"prompt": prompt, "request_id": request_id_idx, "prompt_tokens": 10, "max_tokens": 100} + + mock_request.to_dict_for_infer = mock_to_dict_for_infer + + # Mock format_and_add_data to raise CancelledError + mock_engine_client.format_and_add_data.side_effect = asyncio.CancelledError("Test cancellation") + + # Call method and expect CancelledError to be handled + result = await serving_completion.create_completion(mock_request) + + # Verify that error was handled properly + self.assertIsNotNone(result) + self.assertTrue(hasattr(result, "error")) + self.assertEqual(result.error.code, "client_aborted") + self.assertIn("client disconnected", result.error.message) + + # Verify that abort was called + mock_engine_client.abort.assert_called_once() + + async def test_completion_stream_generator_cancelled_error(self): + """Test completion_stream_generator with asyncio.CancelledError""" + # Mock the engine client and its dependencies + mock_engine_client = Mock() + mock_engine_client.semaphore = Mock() + mock_engine_client.semaphore.acquire = AsyncMock() + mock_engine_client.semaphore.release = Mock() + mock_engine_client.connection_manager = AsyncMock() + mock_engine_client.data_processor = Mock() + mock_engine_client.ori_vocab_size = 1000 + mock_engine_client.check_model_weight_status.return_value = False + mock_engine_client.check_health.return_value = (True, "Healthy") + mock_engine_client.abort = AsyncMock() + + # Mock the data_processor methods + mock_engine_client.data_processor.process_response_dict = Mock() + + # Mock connection manager get_connection method + mock_dealer = Mock() + mock_dealer.write = Mock() + mock_response_queue = AsyncMock() + + # Make response_queue.get raise CancelledError + 
mock_response_queue.get.side_effect = asyncio.CancelledError("Test cancellation") + mock_engine_client.connection_manager.get_connection.return_value = (mock_dealer, mock_response_queue) + + # Create serving completion instance + serving_completion = OpenAIServingCompletion(mock_engine_client, None, "pid", None, 360) + + # Create mock request + mock_request = Mock() + mock_request.prompt_logprobs = None + mock_request.logprobs = None + mock_request.include_draft_logprobs = False + mock_request.return_token_ids = True + mock_request.include_stop_str_in_output = False + mock_request.max_streaming_response_tokens = 1 + mock_request.max_tokens = None + mock_request.stream_options = Mock() + mock_request.stream_options.include_usage = False + mock_request.n = 1 + mock_request.echo = False + + # Call the method and collect results + result_generator = serving_completion.completion_stream_generator( + request=mock_request, + num_choices=1, + request_id="test_request", + created_time=1234567890, + model_name="test_model", + prompt_batched_token_ids=[[1, 2, 3]], + prompt_tokens_list=["hello", "world"], + max_tokens_list=[100], + ) + + # Collect results - should handle CancelledError gracefully + results = [] + async for result in result_generator: + results.append(result) + + # Verify that abort was called + mock_engine_client.abort.assert_called_once_with("test_request_0", 1) + + # Verify that generator ends gracefully (should have [DONE] message) + self.assertTrue(len(results) > 0) + self.assertTrue(any("[DONE]" in result for result in results)) + if __name__ == "__main__": unittest.main() diff --git a/tests/entrypoints/openai/test_with_cancellation.py b/tests/entrypoints/openai/test_with_cancellation.py new file mode 100644 index 0000000000..b7496aa254 --- /dev/null +++ b/tests/entrypoints/openai/test_with_cancellation.py @@ -0,0 +1,356 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import asyncio +import unittest +from unittest.mock import MagicMock, patch + +from fastapi import Request +from fastapi.responses import StreamingResponse + +from fastdeploy.entrypoints.openai.utils import with_cancellation + + +class TestWithCancellation(unittest.TestCase): + """Test cases for with_cancellation decorator""" + + def setUp(self): + """Set up test fixtures""" + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def tearDown(self): + """Clean up test fixtures""" + self.loop.close() + + @patch("fastdeploy.entrypoints.openai.utils.listen_for_disconnect") + def test_normal_execution(self, mock_listen_disconnect): + """Test that handler executes normally when no disconnect occurs""" + + # Setup mock request + mock_request = MagicMock(spec=Request) + + # Create a mock that returns a coroutine that never completes + async def never_disconnect(request): + await asyncio.Future() # This will never complete + + mock_listen_disconnect.side_effect = never_disconnect + + # Setup handler function + @with_cancellation + async def test_handler(self, raw_request): + await asyncio.sleep(0.01) # Simulate some work + return "test_result" + + # Run the test - need to pass self as first arg + result = self.loop.run_until_complete(test_handler(None, mock_request)) + + # Verify results + self.assertEqual(result, "test_result") + mock_listen_disconnect.assert_called_once_with(mock_request) + + 
@patch("fastdeploy.entrypoints.openai.utils.listen_for_disconnect") + def test_client_disconnect(self, mock_listen_disconnect): + """Test that handler is cancelled when client disconnects""" + + # Setup mock request + mock_request = MagicMock(spec=Request) + + # Create a future that will complete (disconnect) after a short delay + disconnect_future = asyncio.Future() + self.loop.call_later(0.05, disconnect_future.set_result, None) + mock_listen_disconnect.return_value = disconnect_future + + # Setup handler function that takes longer than disconnect + handler_called = False + + @with_cancellation + async def test_handler(self, raw_request): + nonlocal handler_called + try: + await asyncio.sleep(0.1) # Simulate work + handler_called = True + return "should_not_reach_here" + except asyncio.CancelledError: + # Handler should be cancelled + raise + + # Run the test - need to pass self as first arg + result = self.loop.run_until_complete(test_handler(None, mock_request)) + + # Verify results + self.assertIsNone(result) # Should return None when cancelled + self.assertFalse(handler_called) # Handler should not complete + mock_listen_disconnect.assert_called_once_with(mock_request) + + @patch("fastdeploy.entrypoints.openai.utils.listen_for_disconnect") + def test_handler_with_args_kwargs(self, mock_listen_disconnect): + """Test that decorator properly handles handler with args and kwargs""" + + # Setup mock request + mock_request = MagicMock(spec=Request) + + # Create a mock that returns a coroutine that never completes + async def never_disconnect(request): + await asyncio.Future() # This will never complete + + mock_listen_disconnect.side_effect = never_disconnect + + # Setup handler function with multiple arguments + @with_cancellation + async def test_handler(arg1, arg2, raw_request, kwarg1=None, kwarg2=None): + await asyncio.sleep(0.01) + return {"arg1": arg1, "arg2": arg2, "kwarg1": kwarg1, "kwarg2": kwarg2, "request": raw_request} + + # Run the test with both 
positional and keyword arguments + result = self.loop.run_until_complete( + test_handler("value1", "value2", mock_request, kwarg1="kwvalue1", kwarg2="kwvalue2") + ) + + # Verify results + expected = { + "arg1": "value1", + "arg2": "value2", + "kwarg1": "kwvalue1", + "kwarg2": "kwvalue2", + "request": mock_request, + } + self.assertEqual(result, expected) + + @patch("fastdeploy.entrypoints.openai.utils.listen_for_disconnect") + def test_handler_returns_streaming_response(self, mock_listen_disconnect): + """Test that decorator handles StreamingResponse correctly""" + + # Setup mock request + mock_request = MagicMock(spec=Request) + + # Create a mock that returns a coroutine that never completes + async def never_disconnect(request): + await asyncio.Future() # This will never complete + + mock_listen_disconnect.side_effect = never_disconnect + + # Setup handler that returns StreamingResponse + @with_cancellation + async def test_handler(self, raw_request): + async def generate(): + yield "chunk1" + yield "chunk2" + + return StreamingResponse(generate()) + + # Run the test - need to pass self as first arg + result = self.loop.run_until_complete(test_handler(None, mock_request)) + + # Verify results + self.assertIsInstance(result, StreamingResponse) + mock_listen_disconnect.assert_called_once_with(mock_request) + + @patch("fastdeploy.entrypoints.openai.utils.listen_for_disconnect") + def test_handler_exception_propagation(self, mock_listen_disconnect): + """Test that exceptions from handler are properly propagated""" + + # Setup mock request + mock_request = MagicMock(spec=Request) + + # Create a mock that returns a coroutine that never completes + async def never_disconnect(request): + await asyncio.Future() # This will never complete + + mock_listen_disconnect.side_effect = never_disconnect + + # Setup handler that raises an exception + @with_cancellation + async def test_handler(self, raw_request): + await asyncio.sleep(0.01) + raise ValueError("Test exception") + + 
# Run the test and expect exception - need to pass self as first arg + with self.assertRaises(ValueError) as context: + self.loop.run_until_complete(test_handler(None, mock_request)) + + self.assertEqual(str(context.exception), "Test exception") + mock_listen_disconnect.assert_called_once_with(mock_request) + + @patch("fastdeploy.entrypoints.openai.utils.listen_for_disconnect") + def test_concurrent_cancellation_and_completion(self, mock_listen_disconnect): + """Test edge case where cancellation and completion happen simultaneously""" + + # Setup mock request + mock_request = MagicMock(spec=Request) + + # Create futures that complete at roughly the same time + disconnect_future = asyncio.Future() + handler_future = asyncio.Future() + + # Set both futures to complete almost simultaneously + self.loop.call_later(0.05, lambda: disconnect_future.set_result(None)) + self.loop.call_later(0.05, lambda: handler_future.set_result("completed")) + + mock_listen_disconnect.return_value = disconnect_future + + @with_cancellation + async def test_handler(self, raw_request): + return await handler_future + + # Run the test - need to pass self as first arg + result = self.loop.run_until_complete(test_handler(None, mock_request)) + + # The result depends on which task completes first + # This test ensures the decorator handles this edge case gracefully + self.assertIn(result, [None, "completed"]) + + def test_wrapper_preserves_function_metadata(self): + """Test that the wrapper preserves the original function's metadata""" + + def original_handler(raw_request): + """Original handler docstring""" + pass + + # Apply decorator + decorated_handler = with_cancellation(original_handler) + + # Verify metadata is preserved + self.assertEqual(decorated_handler.__name__, "original_handler") + self.assertEqual(decorated_handler.__doc__, "Original handler docstring") + self.assertTrue(hasattr(decorated_handler, "__wrapped__")) + + +class TestListenForDisconnect(unittest.TestCase): + """Test 
cases for listen_for_disconnect function""" + + def setUp(self): + """Set up test fixtures""" + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def tearDown(self): + """Clean up test fixtures""" + self.loop.close() + + def test_listen_for_disconnect_normal_flow(self): + """Test that listen_for_disconnect waits for disconnect message""" + from fastdeploy.entrypoints.openai.utils import listen_for_disconnect + + # Setup mock request + mock_request = MagicMock() + receive_call_count = 0 + + # Create mock messages - normal messages followed by disconnect + messages = [ + {"type": "http.request", "body": b"some data"}, + {"type": "http.request", "body": b"more data"}, + {"type": "http.disconnect"}, # This should break the loop + ] + + # Setup receive method to return messages sequentially + receive_iter = iter(messages) + + async def mock_receive(): + nonlocal receive_call_count + receive_call_count += 1 + try: + return next(receive_iter) + except StopIteration: + # After all messages, return a message that keeps it waiting + await asyncio.Future() # Never completes + + mock_request.receive = mock_receive + + # Run the function in the event loop + self.loop.run_until_complete(listen_for_disconnect(mock_request)) + + # Verify that receive was called multiple times + self.assertGreaterEqual(receive_call_count, 3) + + def test_listen_for_disconnect_immediate_disconnect(self): + """Test that listen_for_disconnect returns immediately on disconnect""" + from fastdeploy.entrypoints.openai.utils import listen_for_disconnect + + # Setup mock request + mock_request = MagicMock() + receive_called = False + + # Setup receive to return disconnect immediately (as a coroutine) + async def mock_receive(): + nonlocal receive_called + receive_called = True + return {"type": "http.disconnect"} + + mock_request.receive = mock_receive + + # Run the function in the event loop + self.loop.run_until_complete(listen_for_disconnect(mock_request)) + + # Verify that 
receive was called exactly once + self.assertTrue(receive_called) + + def test_listen_for_disconnect_timeout(self): + """Test that listen_for_disconnect can be cancelled with timeout""" + from fastdeploy.entrypoints.openai.utils import listen_for_disconnect + + # Setup mock request + mock_request = MagicMock() + + # Setup receive to never return disconnect + async def mock_receive(): + await asyncio.Future() # Never completes + + mock_request.receive = mock_receive + + # Run with timeout to test cancellation + with self.assertRaises(asyncio.TimeoutError): + self.loop.run_until_complete(asyncio.wait_for(listen_for_disconnect(mock_request), timeout=0.01)) + + def test_listen_for_disconnect_various_message_types(self): + """Test that listen_for_disconnect ignores non-disconnect messages""" + from fastdeploy.entrypoints.openai.utils import listen_for_disconnect + + # Setup mock request + mock_request = MagicMock() + receive_call_count = 0 + + # Create various non-disconnect messages + messages = [ + {"type": "http.request", "body": b"data"}, + {"type": "http.response", "status": 200}, + {"type": "websocket.connect"}, + {"type": "http.request", "body": b"more data"}, + {"type": "http.disconnect"}, # Final disconnect + ] + + # Setup receive method + receive_iter = iter(messages) + + async def mock_receive(): + nonlocal receive_call_count + receive_call_count += 1 + try: + return next(receive_iter) + except StopIteration: + await asyncio.Future() + + mock_request.receive = mock_receive + + # Run the function in the event loop + self.loop.run_until_complete(listen_for_disconnect(mock_request)) + + # Verify all messages were processed + self.assertEqual(receive_call_count, 5) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/entrypoints/test_abort.py b/tests/entrypoints/test_abort.py new file mode 100644 index 0000000000..e378ff814e --- /dev/null +++ b/tests/entrypoints/test_abort.py @@ -0,0 +1,257 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import asyncio +import unittest +from unittest.mock import MagicMock, patch + +from fastdeploy.engine.request import RequestStatus +from fastdeploy.entrypoints.engine_client import EngineClient + + +class TestEngineClientAbort(unittest.TestCase): + """Test cases for EngineClient.abort method""" + + def setUp(self): + """Set up test fixtures""" + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + # Create a mock FDConfig + self.mock_fd_config = MagicMock() + self.mock_fd_config.parallel_config.tensor_parallel_size = 1 + self.mock_fd_config.model_config.enable_mm = False + self.mock_fd_config.model_config.max_model_len = 2048 + self.mock_fd_config.model_config.enable_logprob = True + self.mock_fd_config.cache_config.enable_prefix_caching = False + self.mock_fd_config.scheduler_config.splitwise_role = "mixed" + self.mock_fd_config.limit_mm_per_prompt = 5 + self.mock_fd_config.eplb_config.enable_eplb = False + self.mock_fd_config.structured_outputs_config.reasoning_parser = None + self.mock_fd_config.mm_processor_kwargs = {} + self.mock_fd_config.tool_parser = None + self.mock_fd_config.cache_config.max_processor_cache = 0 + + # Create EngineClient instance + with patch("fastdeploy.entrypoints.engine_client.InputPreprocessor"): + with patch("fastdeploy.entrypoints.engine_client.IPCSignal"): + with patch("fastdeploy.entrypoints.engine_client.StatefulSemaphore"): + with 
patch("fastdeploy.entrypoints.engine_client.DealerConnectionManager"): + with patch("fastdeploy.entrypoints.engine_client.FileLock"): + self.engine_client = EngineClient( + pid=12345, port=8000, fd_config=self.mock_fd_config, workers=1 + ) + + def tearDown(self): + """Clean up test fixtures""" + self.loop.close() + + @patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", True) + @patch.object(EngineClient, "_send_task") + def test_abort_single_request(self, mock_send_task): + """Test aborting a single request""" + request_id = "test_request" + + # Run the abort method + self.loop.run_until_complete(self.engine_client.abort(request_id, n=1)) + + # Verify _send_task was called with correct data + expected_data = { + "request_id": "test_request_0", + "status": RequestStatus.ABORT.value, + } + mock_send_task.assert_called_once_with(expected_data) + + @patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", True) + @patch.object(EngineClient, "_send_task") + def test_abort_multiple_requests(self, mock_send_task): + """Test aborting multiple requests""" + request_id = "test_request" + n = 3 + + # Run the abort method + self.loop.run_until_complete(self.engine_client.abort(request_id, n=n)) + + # Verify _send_task was called correct number of times + self.assertEqual(mock_send_task.call_count, n) + + # Verify each call had correct request_id + expected_calls = [ + ({"request_id": "test_request_0", "status": RequestStatus.ABORT.value},), + ({"request_id": "test_request_1", "status": RequestStatus.ABORT.value},), + ({"request_id": "test_request_2", "status": RequestStatus.ABORT.value},), + ] + + actual_calls = [call.args for call in mock_send_task.call_args_list] + self.assertEqual(actual_calls, expected_calls) + + @patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", True) + @patch.object(EngineClient, "_send_task") + def 
test_abort_with_existing_suffix(self, mock_send_task): + """Test aborting request that already has _number suffix""" + request_id = "test_request_123_2" + n = 2 + + # Run the abort method + self.loop.run_until_complete(self.engine_client.abort(request_id, n=n)) + + # Verify _send_task was called correct number of times + self.assertEqual(mock_send_task.call_count, n) + + # Verify each call had correct request_id (should use prefix before existing suffix) + expected_calls = [ + ({"request_id": "test_request_123_0", "status": RequestStatus.ABORT.value},), + ({"request_id": "test_request_123_1", "status": RequestStatus.ABORT.value},), + ] + + actual_calls = [call.args for call in mock_send_task.call_args_list] + self.assertEqual(actual_calls, expected_calls) + + @patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", True) + @patch.object(EngineClient, "_send_task") + def test_abort_with_no_suffix(self, mock_send_task): + """Test aborting request without _number suffix""" + request_id = "test_request_without_suffix" + n = 2 + + # Run the abort method + self.loop.run_until_complete(self.engine_client.abort(request_id, n=n)) + + # Verify _send_task was called correct number of times + self.assertEqual(mock_send_task.call_count, n) + + # Verify each call had correct request_id (should use full request_id as prefix) + expected_calls = [ + ({"request_id": "test_request_without_suffix_0", "status": RequestStatus.ABORT.value},), + ({"request_id": "test_request_without_suffix_1", "status": RequestStatus.ABORT.value},), + ] + + actual_calls = [call.args for call in mock_send_task.call_args_list] + self.assertEqual(actual_calls, expected_calls) + + @patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", True) + @patch.object(EngineClient, "_send_task") + def test_abort_with_zero_n(self, mock_send_task): + """Test aborting with n=0 should not send any requests""" + request_id = "test_request_123" + + 
# Run the abort method + self.loop.run_until_complete(self.engine_client.abort(request_id, n=0)) + + # Verify _send_task was not called + mock_send_task.assert_not_called() + + @patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", True) + @patch.object(EngineClient, "_send_task") + def test_abort_with_negative_n(self, mock_send_task): + """Test aborting with negative n should not send any requests""" + request_id = "test_request_123" + + # Run the abort method + self.loop.run_until_complete(self.engine_client.abort(request_id, n=-1)) + + # Verify _send_task was not called + mock_send_task.assert_not_called() + + @patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", False) + @patch.object(EngineClient, "_send_task") + def test_abort_when_feature_disabled(self, mock_send_task): + """Test abort when FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE is False""" + request_id = "test_request_123" + + # Run the abort method + self.loop.run_until_complete(self.engine_client.abort(request_id, n=1)) + + # Verify _send_task was not called + mock_send_task.assert_not_called() + + @patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", True) + @patch.object(EngineClient, "_send_task") + def test_abort_request_id_regex_parsing(self, mock_send_task): + """Test that request_id regex parsing works correctly for various formats""" + test_cases = [ + ("simple_request", "simple_request"), + ("request_with_underscores", "request_with_underscores"), + ("request_123", "request"), + ("request_123_456", "request_123"), + ("request_0", "request"), + ("complex_name_123_456_789", "complex_name_123_456"), + ] + + for input_request_id, expected_prefix in test_cases: + with self.subTest(input_request_id=input_request_id): + mock_send_task.reset_mock() + + # Run the abort method + self.loop.run_until_complete(self.engine_client.abort(input_request_id, n=1)) + + # Verify _send_task 
was called with correct prefix + expected_data = { + "request_id": f"{expected_prefix}_0", + "status": RequestStatus.ABORT.value, + } + mock_send_task.assert_called_once_with(expected_data) + + @patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", True) + @patch("fastdeploy.entrypoints.engine_client.api_server_logger") + @patch.object(EngineClient, "_send_task") + def test_abort_logging(self, mock_send_task, mock_logger): + """Test that abort method logs correctly""" + request_id = "test_request" + n = 2 + + # Run the abort method + self.loop.run_until_complete(self.engine_client.abort(request_id, n=n)) + + # Verify info log was called twice + self.assertEqual(mock_logger.info.call_count, 2) + + # Verify the first log message (abort start) + first_call = mock_logger.info.call_args_list[0] + self.assertEqual(first_call[0][0], "abort request_id:test_request") + + # Verify the second log message (abort completion with request IDs) + second_call = mock_logger.info.call_args_list[1] + expected_log_message = "Aborted request(s) %s." 
+ self.assertEqual(second_call[0][0], expected_log_message) + self.assertEqual(second_call[0][1], "test_request_0,test_request_1") + + @patch("fastdeploy.entrypoints.engine_client.envs.FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", True) + @patch("fastdeploy.entrypoints.engine_client.api_server_logger") + @patch.object(EngineClient, "_send_task") + def test_abort_warning_logging_for_invalid_format(self, mock_send_task, mock_logger): + """Test that abort method logs warning for invalid request_id format""" + request_id = "invalid_format_no_suffix" # This should actually not trigger warning + n = 1 + + # Run the abort method + self.loop.run_until_complete(self.engine_client.abort(request_id, n=n)) + + # Verify warning log was called (this case might not actually trigger warning) + # The warning is only triggered when regex doesn't match, but our test case has valid format + # Let's test with a case that should trigger warning + mock_logger.reset_mock() + + # This should trigger warning because it doesn't end with _number + request_id_no_suffix = "just_a_string" + self.loop.run_until_complete(self.engine_client.abort(request_id_no_suffix, n=1)) + + # Should have logged warning about format error + mock_logger.warning.assert_called() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/output/test_process_batch_output.py b/tests/output/test_process_batch_output.py index fe0e05e3b0..96e58bd521 100644 --- a/tests/output/test_process_batch_output.py +++ b/tests/output/test_process_batch_output.py @@ -17,7 +17,7 @@ import random import time import unittest -from unittest.mock import Mock +from unittest.mock import MagicMock, Mock, patch import paddle @@ -34,6 +34,19 @@ class MockConfig: class SpeculativeConfig: method = None + num_speculative_tokens = 1 + num_model_steps = 1 + max_candidate_len = 5 + verify_window = 2 + max_ngram_size = 5 + min_ngram_size = 2 + model = None + quantization = None + num_gpu_block_expand_ratio = 1 + model_type = "main" + 
benchmark_mode = False + num_extra_cache_layer = 0 + mtp_strategy = "default" class ModelConfig: enable_logprob = False @@ -91,6 +104,8 @@ class MockResourceManager: self.stop_flags = [False] self.tasks_list = [MockTask()] self.to_be_rescheduled_request_id_set = set() + self.abort_req_ids_set = set() + self.req_dict = {} def info(self): return "Mock resource manager info" @@ -241,6 +256,147 @@ class TestTokenProcessorProcessBatchOutput(unittest.TestCase): assert len(request_output.outputs.draft_top_logprobs[1][0]) == K + 1 assert len(request_output.outputs.draft_top_logprobs[2]) == accept_num[i] + def test_process_batch_output_aborted_task_negative_token_speculative_decoding(self): + """Test aborted task receiving negative token triggers recycling in speculative decoding mode""" + processor = self.setup_token_processor(speculative_decoding=True, use_logprobs=True) + + # Set up task as aborted + task_id = "test_aborted_request" + task = processor.resource_manager.tasks_list[0] + task.request_id = task_id + processor.resource_manager.abort_req_ids_set = {task_id} + + # Add the task to req_dict to prevent _recycle_aborted_task from processing it early + # Use a larger batch to avoid the early recycling condition + processor.resource_manager.req_dict[task_id] = 0 # batch_id = 0 + + # Mock _recycle_resources to track if it's called + processor._recycle_resources = MagicMock() + + # Set up output tokens with negative token + # stop_flag + processor.output_tokens[0, 0].set_tensor(paddle.to_tensor(2)) + # mtype target = 3 + processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3)) + # batch = 2 (so batch_id=0 is < batch_size-1=1) + processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(2)) + # Set accept_num = PREEMPTED_TOKEN_ID (-9) for first task to trigger abort logic + processor.output_tokens[3, 0].set_tensor(paddle.to_tensor(-9)) + processor.output_tokens[4, 0].set_tensor(paddle.to_tensor(1)) + + # Add second task to tasks_list + task2 = MockTask() + 
task2.request_id = "test_request_2" + processor.resource_manager.tasks_list = [task, task2] + processor.resource_manager.stop_flags = [False, False] + # Update tokens_counter to include both tasks + processor.tokens_counter[task_id] = 0 + processor.tokens_counter[task2.request_id] = 0 + + # Mock llm_logger to capture the log message and envs.ENABLE_V1_KVCACHE_SCHEDULER + with ( + patch("fastdeploy.output.token_processor.llm_logger") as mock_logger, + patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0), + ): + # Call the method + processor._process_batch_output() + + # In speculative decoding mode, when accept_num[i] == PREEMPTED_TOKEN_ID, + # the code logs "sync preemption" and continues without triggering abort recycling + # This is the expected behavior for speculative decoding mode + mock_logger.info.assert_any_call(f"sync preemption for request_id {task_id} done.") + # Verify that _recycle_resources was NOT called for the aborted task + # (it may be called for other tasks like test_request_2 if they receive EOS tokens) + for call in processor._recycle_resources.call_args_list: + self.assertNotEqual( + call[0][0], task_id, f"_recycle_resources should not be called for aborted task {task_id}" + ) + # Verify that the task is still in abort_req_ids_set + self.assertIn(task_id, processor.resource_manager.abort_req_ids_set) + + def test_process_batch_output_aborted_task_negative_token_normal_mode(self): + """Test aborted task receiving negative token triggers recycling in normal mode""" + processor = self.setup_token_processor(speculative_decoding=False, use_logprobs=False) + + # Set up task as aborted + task_id = "test_aborted_request" + task = processor.resource_manager.tasks_list[0] + task.request_id = task_id + processor.resource_manager.abort_req_ids_set = {task_id} + + # Add the task to req_dict to prevent _recycle_aborted_task from processing it early + # batch_id should be < batch_size - 1 to avoid early recycling + 
processor.resource_manager.req_dict[task_id] = ( + 0 # batch_id = 0, batch_size = 1, so 0 < 0 is false, but 0 >= 0 is true + ) + # Actually, let's use a larger batch to avoid the early recycling condition + processor.output_tokens = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64") + + # Mock _recycle_resources to track if it's called + processor._recycle_resources = MagicMock() + + # Set up output tokens with negative token + # batch = 2 (so batch_id=0 is < batch_size-1=1) + processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(2)) + # Set negative token for first task (batch_id=0) + processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(-1)) + # Set positive token for second task (batch_id=1) + processor.output_tokens[3, 0].set_tensor(paddle.to_tensor(100)) + + # Add second task to tasks_list + task2 = MockTask() + task2.request_id = "test_request_2" + processor.resource_manager.tasks_list = [task, task2] + processor.resource_manager.stop_flags = [False, False] + # Update tokens_counter to include both tasks + processor.tokens_counter[task_id] = 0 + processor.tokens_counter[task2.request_id] = 0 + + # Mock llm_logger to capture the log message and envs.ENABLE_V1_KVCACHE_SCHEDULER + with ( + patch("fastdeploy.output.token_processor.llm_logger") as mock_logger, + patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0), + ): + # Call the method + processor._process_batch_output() + + # Verify the recycling logic was triggered + mock_logger.info.assert_any_call(f"Aborted task {task_id} received negative token. 
Recycling.") + processor._recycle_resources.assert_called_once_with(task_id, 0, task) + self.assertNotIn(task_id, processor.resource_manager.abort_req_ids_set) + + def test_process_batch_output_non_aborted_task_negative_token(self): + """Test non-aborted task receiving negative token does not trigger recycling""" + processor = self.setup_token_processor(speculative_decoding=False, use_logprobs=False) + + # Set up task as not aborted + task_id = "test_normal_request" + task = processor.resource_manager.tasks_list[0] + task.request_id = task_id + processor.resource_manager.abort_req_ids_set = set() # Empty set + + # Mock _recycle_resources to track if it's called + processor._recycle_resources = MagicMock() + + # Set up output tokens with negative token + # batch = 1 + processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(1)) + # Set negative token + processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(-1)) + + # Mock llm_logger to capture the log message and envs.ENABLE_V1_KVCACHE_SCHEDULER + with ( + patch("fastdeploy.output.token_processor.llm_logger") as mock_logger, + patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0), + ): + # Call the method + processor._process_batch_output() + print(mock_logger) + # Verify the recycling logic was NOT triggered + # When a non-aborted task receives a negative token, the code just continues + # without logging or recycling + processor._recycle_resources.assert_not_called() + if __name__ == "__main__": unittest.main(verbosity=2, buffer=False) diff --git a/tests/output/test_process_batch_output_use_zmq.py b/tests/output/test_process_batch_output_use_zmq.py index 2d4c2efadd..eb9ea9fc11 100644 --- a/tests/output/test_process_batch_output_use_zmq.py +++ b/tests/output/test_process_batch_output_use_zmq.py @@ -153,6 +153,60 @@ class TestTokenProcessorLogprobs(unittest.TestCase): self.assertEqual(len(result), 0) + def test_process_batch_output_use_zmq_aborted_task_negative_token(self): + """Test 
aborted task receiving negative token triggers recycling logic""" + # Set up task as aborted + task_id = "test_aborted_request" + self.task_mock.request_id = task_id + self.processor.resource_manager.abort_req_ids_set = {task_id} + + # Create stream data with negative token + stream_data = MagicMock() + stream_data.tokens = np.array([1, 2, -1]) # Last token is negative + stream_data.batch_id = 0 + + # Mock _recycle_resources to track if it's called + self.processor._recycle_resources = MagicMock() + + # Mock the llm_logger module and envs.ENABLE_V1_KVCACHE_SCHEDULER + with ( + patch("fastdeploy.output.token_processor.llm_logger") as mock_logger, + patch("fastdeploy.output.token_processor.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0), + ): + # Call the method + result = self.processor._process_batch_output_use_zmq([stream_data]) + + # Verify the recycling logic was triggered + mock_logger.info.assert_any_call(f"Aborted task {task_id} received negative token. Recycling.") + self.processor._recycle_resources.assert_called_once_with(task_id, 0, self.task_mock) + self.assertNotIn(task_id, self.processor.resource_manager.abort_req_ids_set) + self.assertEqual(len(result), 1) # Should return abort result + self.assertEqual(result[0].finished, True) + self.assertEqual(result[0].error_code, 499) + self.assertIn("aborted", result[0].error_msg.lower()) + + def test_process_batch_output_use_zmq_non_aborted_task_negative_token(self): + """Test non-aborted task receiving negative token does not trigger recycling""" + # Set up task as not aborted + task_id = "test_normal_request" + self.task_mock.request_id = task_id + self.processor.resource_manager.abort_req_ids_set = set() # Empty set + + # Create stream data with negative token + stream_data = MagicMock() + stream_data.tokens = np.array([1, 2, -1]) # Last token is negative + stream_data.batch_id = 0 + + # Mock _recycle_resources to track if it's called + self.processor._recycle_resources = MagicMock() + + # Call the method + 
self.processor._process_batch_output_use_zmq([stream_data]) + + # Verify recycling logic was NOT triggered + self.processor._recycle_resources.assert_not_called() + self.processor.llm_logger.info.assert_not_called() + if __name__ == "__main__": unittest.main() diff --git a/tests/output/test_token_processor.py b/tests/output/test_token_processor.py index 96c0182acf..2d284bb7bf 100644 --- a/tests/output/test_token_processor.py +++ b/tests/output/test_token_processor.py @@ -74,6 +74,7 @@ class _DummyResourceManager: self.req_dict = {} self.requests = {} self.to_be_rescheduled_request_id_set = set() + self.abort_req_ids_set = set() self.recycled = [] self.cached_tasks = [] self.cleared = False