diff --git a/fastdeploy/engine/async_llm.py b/fastdeploy/engine/async_llm.py index 0472b89fb1..edee21af06 100644 --- a/fastdeploy/engine/async_llm.py +++ b/fastdeploy/engine/async_llm.py @@ -16,15 +16,9 @@ from __future__ import annotations -import asyncio -import json -import multiprocessing +import inspect import os -import re import signal -import subprocess -import sys -import threading import time import traceback import uuid @@ -33,142 +27,223 @@ from dataclasses import asdict from typing import Any, AsyncGenerator, Dict, List, Optional, Union import numpy as np -import paddle -from tqdm import tqdm +import zmq from fastdeploy.engine.args_utils import EngineArgs from fastdeploy.engine.common_engine import EngineService -from fastdeploy.engine.expert_service import start_data_parallel_service -from fastdeploy.engine.request import Request, RequestOutput +from fastdeploy.engine.request import RequestOutput from fastdeploy.engine.sampling_params import SamplingParams +from fastdeploy.entrypoints.openai.utils import DealerConnectionManager from fastdeploy.input.preprocess import InputPreprocessor -from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal +from fastdeploy.inter_communicator import IPCSignal +from fastdeploy.inter_communicator.zmq_client import ZmqIpcClient from fastdeploy.metrics.metrics import main_process_metrics -from fastdeploy.utils import EngineError, console_logger, envs, llm_logger - - -class AsyncRequestQueue: - """Async request output queue for managing single request output stream""" - - def __init__(self, request_id: str): - self.request_id = request_id - self.queue: asyncio.Queue[Union[RequestOutput, Exception]] = asyncio.Queue() - self._finished = False - self._cache_size = 0 - - async def put(self, output: RequestOutput) -> None: - """Put output to queue with memory allocation optimization""" - if isinstance(output, RequestOutput) and output.finished: - self._finished = True - await self.queue.put(output) - self._cache_size += 1 - - async def put_error(self, error: Exception) -> None: - """Put error to queue""" - self._finished = True - await self.queue.put(error) - - async def get(self) -> RequestOutput: - """Get output, raise exception if it's an error""" - result = await self.queue.get() - self._cache_size = max(0, self._cache_size - 1) - if isinstance(result, Exception): - raise result - return result - - def get_nowait(self) -> Optional[RequestOutput]: - """Non-blocking get output""" - try: - result = self.queue.get_nowait() - self._cache_size = max(0, self._cache_size - 1) - if isinstance(result, Exception): - raise result - return result - except asyncio.QueueEmpty: - return None - - @property - def finished(self) -> bool: - """Check if request is completed""" - return self._finished - - @property - def size(self) -> int: - """Return queue size for performance monitoring""" - return self._cache_size +from fastdeploy.utils import EngineError, llm_logger class AsyncOutputProcessor: """Async output processor responsible for distributing engine outputs to corresponding request queues""" - def __init__(self, tokenizer=None): - self.request_queues: Dict[str, AsyncRequestQueue] = {} - self.tokenizer = tokenizer + def __init__(self, data_processor=None): + """ + Args: + data_processor: The data processor created by InputPreprocessor, + used to post-process RequestOutput (decode token_ids, reasoning, tools, etc.). + """ + self.data_processor = data_processor - async def register_request(self, request_id: str, queue: AsyncRequestQueue) -> None: - """Register request queue""" - self.request_queues[request_id] = queue + def _process_output( + self, + response_dict: Dict[str, Any], + stream: bool = True, + enable_thinking: bool = False, + include_stop_str_in_output: bool = False, + ) -> Dict[str, Any]: + """Process a single response dict via data_processor.process_response_dict. - async def process_outputs(self, outputs: Dict[str, List[RequestOutput]]) -> None: - """Process engine outputs and distribute to corresponding request queues""" - if not outputs: - return - - finished_requests = [] - - for request_id, output_list in outputs.items(): - if request_id not in self.request_queues: - continue - - queue = self.request_queues[request_id] - - # Ensure output_list is in list format - if not isinstance(output_list, list): - output_list = [output_list] - - for output in output_list: - # Process single output - processed_output = self._process_single_output(output) - await queue.put(processed_output) - - if processed_output.finished: - finished_requests.append(request_id) - - # Clean up completed requests - for request_id in finished_requests: - self.request_queues.pop(request_id, None) - - def _process_single_output(self, output: RequestOutput) -> RequestOutput: - """Process single output for token decoding""" + This mirrors the behavior of ChatResponseProcessor in the OpenAI serving + path: operate on a dict representation and return a dict. On any error + we fall back to the original dict and ensure ``outputs.text`` exists to + avoid cascading failures. + """ try: - token_ids = output.outputs.token_ids - decoded_text = self.tokenizer.decode(token_ids, skip_special_tokens=True) - output.outputs.text = decoded_text + processed = self.data_processor.process_response_dict( + response_dict, + stream=stream, + enable_thinking=enable_thinking, + include_stop_str_in_output=include_stop_str_in_output, + ) + # Some processors may return None when there is no valid text. + if processed is None: + outputs = response_dict.get("outputs") or {} + if "text" not in outputs: + outputs["text"] = "" + response_dict["outputs"] = outputs + return response_dict + return processed except Exception: - if not hasattr(output.outputs, "text"): - output.outputs.text = "" - - return output - - async def abort_request(self, request_id: str) -> None: - """Abort request and clean up related resources""" - if request_id in self.request_queues: - queue = self.request_queues.pop(request_id) - await queue.put_error(EngineError("Request aborted", error_code=499)) - - async def propagate_error(self, error: Exception) -> None: - """Propagate error to all active request queues""" - tasks = [] - for queue in list(self.request_queues.values()): - if not queue.finished: - tasks.append(queue.put_error(error)) - if tasks: - await asyncio.gather(*tasks, return_exceptions=True) - self.request_queues.clear() + outputs = response_dict.get("outputs") or {} + if "text" not in outputs: + outputs["text"] = "" + response_dict["outputs"] = outputs + return response_dict -class AsyncLLMEngine: +class EngineServiceClient: + """ + Base engine service client, responsible for managing EngineService lifecycle. + """ + + def __init__(self, cfg, pid): + self.cfg = cfg + self.engine_process = None + self.engine_pid = pid + self._running = False + + llm_logger.info(f"EngineServiceClient initialized with engine_pid: {self.engine_pid}") + + async def start(self): + """Start engine service process""" + try: + # Start independent engine process + self._start_engine_process() + + # Wait for engine to be ready + if not self._wait_engine_ready(): + raise EngineError("Engine failed to start within timeout", error_code=500) + + self._running = True + llm_logger.info("EngineServiceClient started successfully") + + except Exception as e: + llm_logger.error(f"Failed to start EngineServiceClient: {e}") + raise + return True + + def _start_engine_process(self): + """Start engine process""" + try: + import multiprocessing + + self.shutdown_signal = multiprocessing.Value("i", 0) # 0=running, 1=shutdown + + def run_engine(): + engine = None + + def signal_handler(signum, frame): + llm_logger.info(f"Engine process received signal {signum}, initiating shutdown...") + if engine: + engine.running = False + + # Register signal handlers + signal.signal(signal.SIGTERM, signal_handler) + signal.signal(signal.SIGINT, signal_handler) + + try: + engine = EngineService(self.cfg, use_async_llm=True) + # Start engine with ZMQ service + engine.start(async_llm_pid=self.engine_pid) + + # Keep engine running until shutdown signal is received + while self.shutdown_signal.value == 0 and getattr(engine, "running", True): + time.sleep(0.5) + + except Exception as e: + llm_logger.error(f"Engine process error: {e}, {str(traceback.format_exc())}") + finally: + if engine and hasattr(engine, "_exit_sub_services"): + try: + engine._exit_sub_services() + llm_logger.info("Engine process cleanup completed") + except Exception as e: + llm_logger.error(f"Error during engine cleanup: {e}") + + self.engine_process = multiprocessing.Process(target=run_engine) + self.engine_process.start() + + llm_logger.info(f"Started engine process with PID: {self.engine_process.pid}") + + except Exception as e: + llm_logger.error(f"Failed to start engine process: {e}") + raise + + def _wait_engine_ready(self) -> bool: + """Wait for engine and workers to be fully ready""" + max_wait_time = 180 # seconds + wait_interval = 1 + elapsed_time = 0 + + llm_logger.info("Waiting for engine and workers to be ready...") + + # Use IPC signals to check engine readiness + # Get the correct suffix + ipc_suffix = ( + self.cfg.parallel_config.engine_worker_queue_port[0] + if hasattr(self.cfg, "parallel_config") + else self.engine_pid + ) + + # Check if loaded_model_signal exists and is ready + loaded_model_signal = None + + while elapsed_time < max_wait_time: + # Try to connect to loaded_model_signal + if loaded_model_signal is None: + try: + loaded_model_signal = IPCSignal( + name="loaded_model_signal", + array=np.zeros([1], dtype=np.int32), + dtype=np.int32, + suffix=ipc_suffix, + create=False, + ) + except: + # Signal not ready yet + time.sleep(wait_interval) + elapsed_time += wait_interval + continue + + # Check if workers have loaded models + if loaded_model_signal.value[0] > 0: + llm_logger.info("Workers have loaded models successfully") + # Give ZMQ service more time to fully start + llm_logger.info("Waiting additional time for ZMQ service to be ready...") + time.sleep(5) # Wait for ZMQ service startup + recv_result_handle + return True + + time.sleep(wait_interval) + elapsed_time += wait_interval + + if elapsed_time % 10 == 0: # Log every 10 seconds + llm_logger.info(f"Waiting for workers to load models... ({elapsed_time}s)") + + return False + + def shutdown(self): + """Shutdown engine service process""" + llm_logger.info("Shutting down EngineServiceClient...") + + self._running = False + + # Send graceful shutdown signal to engine process + if hasattr(self, "shutdown_signal"): + llm_logger.info("Sending shutdown signal to engine process...") + self.shutdown_signal.value = 1 + + # Wait for engine process to shutdown + if self.engine_process and self.engine_process.is_alive(): + llm_logger.info("Waiting for engine process to shutdown...") + self.engine_process.terminate() + self.engine_process.join(timeout=5) + if self.engine_process.is_alive(): + llm_logger.warning("Force killing engine process...") + self.engine_process.kill() + + llm_logger.info("EngineServiceClient shutdown completed") + + +class AsyncLLM(EngineServiceClient): """ Engine class responsible for managing the Large Language Model (LLM) operations. @@ -180,36 +255,36 @@ class AsyncLLMEngine: resource_manager (ResourceManager): Manager for resource allocation. token_processor (TokenProcessor): Processor for token generation. engine_worker_queue (EngineWorkerQueue): Queue for communication between engine and workers. - is_started (bool): Flag indicating if the engine has started. do_profile (int): Flag indicating if profiling is enabled. """ @classmethod - def from_engine_args(cls, engine_args: EngineArgs): + def from_engine_args(cls, engine_args: EngineArgs, pid): """ - Creates an AsyncLLMEngine from the provided engine arguments. + Creates an AsyncLLM client from the provided engine arguments. Args: engine_args (EngineArgs): Engine arguments object. Returns: - AsyncLLMEngine: Instance of the AsyncLLMEngine class. + AsyncLLM: Instance of the AsyncLLM class. """ # Create the engine configs. config = engine_args.create_engine_config() - # Create the AsyncLLMEngine. - return cls(cfg=config) + # Create the AsyncLLM client. + return cls(cfg=config, pid=pid) - def __init__(self, cfg): + def __init__(self, cfg, pid): """ - Initializes the AsyncLLMEngine with the provided configuration. + Initializes the AsyncLLM client with the provided configuration. Args: cfg (Config): Config object containing all the configuration parameters. """ + super().__init__(cfg, pid) self.cfg = cfg self.running = True - self.is_started = False + self._prompt_metadata: Dict[str, Dict[str, Any]] = {} self.input_processor = InputPreprocessor( cfg.model_config, @@ -218,110 +293,39 @@ class AsyncLLMEngine: cfg.mm_processor_kwargs, cfg.tool_parser, ) - self.engine_service = EngineService(cfg) + # Create data processor + self.data_processor = self.input_processor.create_processor() - if self.cfg.cache_config.num_gpu_blocks_override is None: - self.do_profile = 1 - else: - self.do_profile = 0 + # Create high-performance async connection manager + self.connection_manager = None + self.request_client = None - # Create async output processor, pass tokenizer for decoding - tokenizer = None - if hasattr(self, "input_processor") and hasattr(self.input_processor, "tokenizer"): - tokenizer = self.input_processor.tokenizer - elif hasattr(self, "data_processor") and hasattr(self.data_processor, "tokenizer"): - tokenizer = self.data_processor.tokenizer - - self.output_processor = AsyncOutputProcessor(tokenizer=tokenizer) - - self.output_handler: Optional[asyncio.Task] = None + # Output processor uses data_processor for post-processing engine outputs + self.output_processor = AsyncOutputProcessor(self.data_processor) self._finalizer = weakref.finalize(self, self._exit_sub_services) main_process_metrics.set_cache_config_info(obj=self.cfg.cache_config) - def start(self): - """ - Initializes the engine and starts its sub-services. - """ - assert not self.is_started, "The engine is already started." - start_time = time.time() - - self.ipc_signal_suffix = self.cfg.parallel_config.engine_worker_queue_port[0] - self._init_worker_signals() - - self.data_processor = self.input_processor.create_processor() - self.engine_service.data_processor = self.data_processor - - # Launch components: scheduler, cache_manager, expert_service et.al. - self.launch_components() - - # Update output processor tokenizer - if hasattr(self.data_processor, "tokenizer") and self.data_processor.tokenizer: - self.output_processor.tokenizer = self.data_processor.tokenizer - - self.engine_service.start() - - # If block number is specified and model is deployed in splitwise mode, start cache manager first - if not self.do_profile and self.cfg.scheduler_config.splitwise_role != "mixed": - device_ids = self.cfg.parallel_config.device_ids.split(",") - self.cache_manager_processes = self.engine_service.start_cache_service(device_ids, self.ipc_signal_suffix) - - # Start workers - self.worker_proc = self._start_worker_service() - console_logger.info("Waiting worker processes ready...") - time.sleep(5) - self.worker_init_status = dict() - - result_container = {} - - def check_worker_initialize_status_func(res: dict): - res["worker_is_alive"] = True - if not self.check_worker_initialize_status(): - console_logger.error("Failed to launch worker processes, check log/workerlog.* for more details.") - res["worker_is_alive"] = False - - self.check_worker_initialize_status_func_thread = threading.Thread( - target=check_worker_initialize_status_func, args=(result_container,), daemon=True - ) - self.check_worker_initialize_status_func_thread.start() - - # Wait model loading - while self.loaded_model_signal.value[0] == 0: - # Make sure worker process is alive - if not self.check_worker_initialize_status_func_thread.is_alive(): - return False - time.sleep(1) - - # If block number is not specified, let workers do profiling to determine the block number, - # and then start the cache manager - if self.do_profile: - self._stop_profile() - elif self.cfg.cache_config.enable_prefix_caching: - device_ids = self.cfg.parallel_config.device_ids.split(",") - self.cache_manager_processes = self.engine_service.start_cache_service(device_ids, self.ipc_signal_suffix) - - # Set cache manager signal - if self.cfg.scheduler_config.splitwise_role != "mixed": - self.launched_cache_manager_signal.value[0] = 1 - - # Worker launched - self.check_worker_initialize_status_func_thread.join() - if not result_container["worker_is_alive"]: - console_logger.error("Failed to launch worker processes, check log/workerlog.* for more details.") - return False - - console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.") - + async def init_connections(self): + """Initialize high-performance ZMQ connections""" try: - # Start output handler eagerly if we are in the asyncio eventloop. - asyncio.get_running_loop() - self._start_output_handler() - except RuntimeError: - pass + # Create ZMQ client for sending requests + self.request_client = ZmqIpcClient(name=self.engine_pid, mode=zmq.PUSH) + self.request_client.connect() - self.is_started = True - return True + # Create high-performance async connection manager for receiving responses + self.connection_manager = DealerConnectionManager( + pid=self.engine_pid, max_connections=int(os.getenv("FD_DEALER_CONNECTIONS", 50)) + ) + + if not self.connection_manager.running: + await self.connection_manager.initialize() + + llm_logger.info("High-performance ZMQ connections initialized successfully") + except Exception as e: + llm_logger.error(f"Failed to initialize ZMQ connections: {e}") + raise async def get_model_config(self): """Get model configuration""" @@ -356,7 +360,7 @@ class AsyncLLMEngine: sampling_params: Optional[SamplingParams] = None, arrival_time: Optional[float] = None, **kwargs, - ) -> AsyncRequestQueue: + ): """ Async add request @@ -367,18 +371,11 @@ class AsyncLLMEngine: arrival_time: Arrival time **kwargs: Other parameters - Returns: - AsyncRequestQueue: Request output queue """ - if not self.is_started or self.engine_service is None: - raise EngineError("Engine not started. Call start() first.", error_code=500) if request_id is None: request_id = str(uuid.uuid4()) - # Create output queue - output_queue = AsyncRequestQueue(request_id) - if arrival_time is None: arrival_time = time.time() @@ -401,40 +398,32 @@ class AsyncLLMEngine: prompt.update(asdict(sampling_params)) try: - request = Request.from_dict(prompt) - request.metrics.scheduler_recv_req_time = time.time() - - # Check if already preprocessed by AsyncEngineClient + # Check if already preprocessed by api_server is_preprocessed = prompt.get("_preprocessed", False) - # Set sampling_params - if sampling_params is not None: - request.sampling_params = sampling_params + if inspect.iscoroutinefunction(self.data_processor.process_request_dict): + request = await self.data_processor.process_request_dict(prompt, self.cfg.model_config.max_model_len) + else: + request = self.data_processor.process_request_dict(prompt, self.cfg.model_config.max_model_len) - # Preprocess request - request = self.data_processor.process_request(request, self.cfg.model_config.max_model_len, **kwargs) + request["prompt_token_ids_len"] = len(request["prompt_token_ids"]) - prompt_token_ids_len = len(request.prompt_token_ids) - request.prompt_token_ids_len = prompt_token_ids_len - request.need_prefill_tokens = prompt_token_ids_len + # Cache prompt metadata for later enrichment of async responses + req_id = request.get("request_id") + self._prompt_metadata[req_id] = { + "prompt_token_ids": request.get("prompt_token_ids"), + "prompt_tokens": request.get("prompt_tokens"), + } if not is_preprocessed: - request.metrics.preprocess_start_time = arrival_time - input_ids_len = request.prompt_token_ids_len + request["preprocess_start_time"] = arrival_time + input_ids_len = request["prompt_token_ids_len"] - request.set( - "max_tokens", - min( - self.cfg.model_config.max_model_len - input_ids_len, - request.get("max_tokens"), - ), + request["max_tokens"] = min( + self.cfg.model_config.max_model_len - input_ids_len, request.get("max_tokens") ) - if request.get("reasoning_max_tokens") is None: - default_reasoning_max_tokens = max(int(request.get("max_tokens") * 0.8), 1) - request.set("reasoning_max_tokens", default_reasoning_max_tokens) - - min_tokens = request.get("min_tokens") + min_tokens = request.get("min_tokens", 1) if input_ids_len + min_tokens >= self.cfg.model_config.max_model_len: error_msg = ( f"Input text is too long, length of prompt token({input_ids_len}) " @@ -443,25 +432,22 @@ class AsyncLLMEngine: llm_logger.error(error_msg) raise EngineError(error_msg, error_code=400) - if input_ids_len > self.cfg.model_config.max_model_len: - error_msg = f"Length of input token({input_ids_len}) exceeds the limit max_model_len({self.cfg.model_config.max_model_len})." - llm_logger.error(error_msg) - raise EngineError(error_msg, error_code=400) + request["preprocess_end_time"] = time.time() + preprocess_cost_time = request["preprocess_end_time"] - request["preprocess_start_time"] + llm_logger.info( + f"Cache request with request_id ({request.get('request_id')}), " + f"preprocess time cost {preprocess_cost_time}" + ) - request.metrics.preprocess_end_time = time.time() - - # Register output queue first, then add request - await self.output_processor.register_request(request_id, output_queue) - - # TODO: Optimize architecture to implement async transmission to worker - self.engine_service.scheduler.put_requests([request]) - - return output_queue + if not self.cfg.model_config.enable_mm: + self.request_client.send_json(request) + else: + self.request_client.send_pyobj(request) except EngineError: raise except Exception as e: - raise EngineError(f"Request processing failed: {e}", error_code=400) + raise EngineError(f"async_llm add request failed: {e}", error_code=400) async def generate( self, @@ -475,40 +461,109 @@ class AsyncLLMEngine: Args: prompt: Input prompt - sampling_params: Sampling parameters + sampling_params: Sampling parameters. If `sampling_params.n > 1`, + will generate `n` completions sequentially. request_id: Request ID **kwargs: Other parameters Yields: RequestOutput: Generated output """ - if not self.is_started: - raise EngineError("Engine not started. Call start() first.", error_code=500) + + num_choices = sampling_params.n if sampling_params is not None and sampling_params.n else 1 + stream = True + include_stop_str_in_output = False + enable_thinking = kwargs.pop("enable_thinking", False) + + if isinstance(prompt, dict): + num_choices = prompt.get("n") + stream = prompt.get("stream", True) + include_stop_str_in_output = prompt.get("include_stop_str_in_output", False) + + # Ensure ZMQ client and connection manager are initialized in current process + if ( + self.request_client is None + or self.connection_manager is None + or not getattr(self.connection_manager, "running", False) + ): + raise EngineError( + "AsyncLLM engine not initialized. Call init_connections() before generate.", + error_code=500, + ) + + # Build request ids and connection key + if num_choices <= 1: + # Single-choice: keep user-provided request_id semantics + child_request_ids = [request_id or str(uuid.uuid4())] + conn_request_id = child_request_ids[0] + else: + # Multi-choice: use unified "cmpl-" base id so DealerConnectionManager + # can merge cmpl-xxx_0, cmpl-xxx_1, ... back to the same response queue. + user_request_id = request_id or str(uuid.uuid4()) + conn_request_id = f"cmpl-{user_request_id}" + child_request_ids = [f"{conn_request_id}_{i}" for i in range(num_choices)] try: - # Ensure output processor is running - self._start_output_handler() + # 1) Send all sub-requests to engine + for child_request_id in child_request_ids: + await self.add_request(child_request_id, prompt, sampling_params, **kwargs) - # Async add request - output_queue = await self.add_request(request_id, prompt, sampling_params, **kwargs) + # 2) Get a shared connection for conn_request_id and handshake all sub-requests + dealer, response_queue = await self.connection_manager.get_connection( + request_id=conn_request_id, num_choices=num_choices + ) - finished = False + for child_request_id in child_request_ids: + dealer.write([b"", child_request_id.encode("utf-8")]) - while not finished: - # Prefer non-blocking get first - output = output_queue.get_nowait() or await output_queue.get() - finished = output.finished - yield output + # 3) Stream responses from all choices interleaved + remaining = num_choices + while remaining > 0: + response_list = await response_queue.get() + + for response_item in response_list: + if isinstance(response_item, dict) and "request_id" in response_item: + req_id = response_item.get("request_id") + + # First, use output_processor to post-process the raw dict + if hasattr(self, "output_processor"): + processed_output = self.output_processor._process_output( + response_item, + stream=stream, + enable_thinking=enable_thinking, + include_stop_str_in_output=include_stop_str_in_output, + ) + else: + processed_output = response_item + + # Then convert processed dict to RequestOutput + request_output = RequestOutput.from_dict(processed_output) + + # Enrich outputs with prompt metadata on the first packet + if req_id: + prompt_meta = self._prompt_metadata.get(req_id) + if prompt_meta is not None and request_output.outputs.send_idx == 0: + request_output.prompt_token_ids = prompt_meta.get("prompt_token_ids") + request_output.prompt = prompt_meta.get("prompt_tokens") + self._prompt_metadata.pop(req_id, None) + + if request_output.finished: + remaining -= 1 + + yield request_output - except EngineError: - raise except GeneratorExit: - llm_logger.info(f"Request {request_id} generator exit (outer)") + llm_logger.info(f"Request {conn_request_id} generator exit (outer)") return except Exception as e: - await self.abort_request(request_id) - llm_logger.error(f"Request {request_id} failed: {e}") + llm_logger.error(f"Request {conn_request_id} failed: {e}") raise EngineError(str(e), error_code=500) from e + finally: + # Ensure request_map/request_num are cleaned up + try: + await self.connection_manager.cleanup_request(conn_request_id) + except Exception: + pass async def abort_request(self, request_id: str) -> None: """ @@ -518,47 +573,13 @@ class AsyncLLMEngine: request_id: Request ID to abort """ try: - await self.output_processor.abort_request(request_id) + # Clean up request through DealerConnectionManager + if hasattr(self, "connection_manager") and self.connection_manager: + await self.connection_manager.cleanup_request(request_id) llm_logger.info(f"Aborted request {request_id}") except Exception as e: llm_logger.error(f"Failed to abort request {request_id}: {e}") - def _start_output_handler(self) -> None: - """Start background output processing task""" - if self.output_handler is not None: - return - - async def output_handler_loop(): - """Background loop: get results from engine service and distribute to corresponding queues""" - try: - while self.running: - # Check engine service status - if self.engine_service is None: - await asyncio.sleep(0.001) - continue - - results = self.engine_service.scheduler.get_results() - - if not results: - # No results, minimal delay to yield control - await asyncio.sleep(0) - continue - - await self.output_processor.process_outputs(results) - - except GeneratorExit: - llm_logger.info("Output handler loop received GeneratorExit, shutting down gracefully") - except asyncio.CancelledError: - llm_logger.info("Output handler loop cancelled, shutting down gracefully") - except Exception as e: - llm_logger.exception("AsyncLLM output_handler failed") - await self.output_processor.propagate_error(e) - finally: - llm_logger.info("Output handler loop finished") - - self.output_handler = asyncio.create_task(output_handler_loop()) - llm_logger.info("Output handler started") - async def shutdown(self): """ Gracefully shutdown AsyncLLM engine @@ -567,453 +588,32 @@ class AsyncLLMEngine: self.running = False - # Clean up request queues in output processor (clean queues first to avoid new tasks) - if hasattr(self, "output_processor"): + # Close high-performance connection manager + if hasattr(self, "connection_manager") and self.connection_manager is not None: + llm_logger.info("Stopping connection manager...") try: - await self.output_processor.propagate_error(Exception("AsyncLLM shutdown")) + await self.connection_manager.close() except Exception as e: - llm_logger.warning(f"Error while cleaning output processor: {e}") + llm_logger.error(f"Error while stopping connection manager: {e}") - # Shutdown async output processor - if hasattr(self, "output_handler") and self.output_handler and not self.output_handler.done(): - self.output_handler.cancel() + # Close ZMQ client + if hasattr(self, "request_client") and self.request_client is not None: + llm_logger.info("Closing request client...") try: - await asyncio.wait_for(self.output_handler, timeout=2.0) - except asyncio.CancelledError: - llm_logger.info("Output handler cancelled successfully") - except asyncio.TimeoutError: - llm_logger.warning("Output handler cancellation timeout, proceeding with cleanup") + self.request_client.close() except Exception as e: - llm_logger.warning(f"Error while cancelling output handler: {e}") - finally: - self.output_handler = None + llm_logger.warning(f"Error closing request client: {e}") - # Shutdown underlying engine service - if hasattr(self, "engine_service") and self.engine_service is not None: - llm_logger.info("Stopping engine service...") - try: - if hasattr(self.engine_service, "running"): - self.engine_service.running = False + # Shutdown engine service process + try: + super().shutdown() + except Exception as e: + llm_logger.error(f"Error while stopping engine service process: {e}") - self._exit_sub_services() - except Exception as e: - llm_logger.error(f"Error while stopping engine service: {e}") - - self.is_started = False llm_logger.info("AsyncLLM shutdown completed") - def _worker_processes_ready(self): - """ - judge if all worker processes are ready - - """ - if np.sum(self.worker_ready_signal.value) == self.cfg.worker_num_per_node: - return True - return False - - def _init_worker_signals(self): - """ - Initialize shared memory to indicate engine status - """ - # worker_ready_signal 用于worker进程感知engine是否启动完成 - worker_ready_signal_data = np.zeros(shape=[self.cfg.worker_num_per_node], dtype=np.int32) - self.worker_ready_signal = IPCSignal( - name="worker_ready_signal", - array=worker_ready_signal_data, - dtype=np.int32, - suffix=self.ipc_signal_suffix, - create=True, - ) - - # launched_cache_manager_signal 用于感知engine是否启动了cache_manager - if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed": - launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32) - self.launched_cache_manager_signal = IPCSignal( - name="launched_cache_manager_signal", - array=launched_cache_manager_signal_data, - dtype=np.int32, - suffix=self.ipc_signal_suffix, - create=True, - ) - - # launched_expert_service_signal: Used to sense whether each expet_servic is started successfully - if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1: - launched_expert_service_signal_data = np.zeros( - shape=[self.cfg.parallel_config.data_parallel_size // self.cfg.nnode], dtype=np.int32 - ) - self.launched_expert_service_signal = IPCSignal( - name="launched_expert_service_signal", - array=launched_expert_service_signal_data, - dtype=np.int32, - suffix=self.ipc_signal_suffix, - create=True, - ) - - # loaded_model_signal: Used to detect whether each worker has completed model loading - loaded_model_signal_data = np.zeros([1], dtype=np.int32) - self.loaded_model_signal = IPCSignal( - name="loaded_model_signal", - array=loaded_model_signal_data, - dtype=np.int32, - suffix=self.ipc_signal_suffix, - create=True, - ) - - if self.do_profile: - if paddle.is_compiled_with_custom_device("iluvatar_gpu"): - get_profile_block_num = np.zeros([self.cfg.worker_num_per_node], dtype=np.int32) - else: - get_profile_block_num = np.zeros([1], dtype=np.int32) - self.get_profile_block_num_signal = IPCSignal( - name="get_profile_block_num", - array=get_profile_block_num, - dtype=np.int32, - suffix=self.ipc_signal_suffix, - create=True, - ) - def _exit_sub_services(self): """ - exit sub services + Clean up any remaining resources """ - self.running = False - - if hasattr(self, "cache_manager_processes"): - self.engine_service.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear() - self.engine_service.resource_manager.cache_manager.cache_ready_signal.clear() - for p in self.cache_manager_processes: - llm_logger.info(f"Killing cache manager process {p.pid}") - try: - pgid = os.getpgid(p.pid) - os.killpg(pgid, signal.SIGTERM) - except Exception as e: - console_logger.error( - f"Error killing cache manager process {p.pid}: {e}, {str(traceback.format_exc())}" - ) - self.worker_ready_signal.clear() - self.loaded_model_signal.clear() - - if hasattr(self, "get_profile_block_num_signal"): - self.get_profile_block_num_signal.clear() - if hasattr(self, "worker_proc") and self.worker_proc is not None: - try: - pgid = os.getpgid(self.worker_proc.pid) - os.killpg(pgid, signal.SIGTERM) - except Exception as e: - console_logger.error(f"Error extracting sub services: {e}, {str(traceback.format_exc())}") - - if hasattr(self, "zmq_server") and self.zmq_server is not None: - self.zmq_server.close() - if hasattr(self, "dp_processed"): - for p in self.dp_processed: - console_logger.info(f"Waiting for worker {p.pid} to exit") - p.join() - for p in self.dp_engine_worker_queue_server: - p.cleanup() - - def _setting_environ_variables(self): - """ - 配置环境变量 - """ - variables = { - "ENABLE_FASTDEPLOY_LOAD_MODEL_CONCURRENCY": 0, - "LOAD_STATE_DICT_THREAD_NUM": len(self.cfg.parallel_config.device_ids.split(",")), - "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python", - "FLAGS_use_append_attn": 1, - "NCCL_ALGO": "Ring", - "FLAGS_max_partition_size": int(os.getenv("FLAGS_max_partition_size", 1024)), - "OMP_NUM_THREADS": 3, - } - # environment variables needed by Dy2St - variables.update( - { - "SOT_LOG_LEVEL": os.getenv("SOT_LOG_LEVEL", default="0"), - "SOT_UNSAFE_CACHE_FASTPATH": os.getenv("SOT_UNSAFE_CACHE_FASTPATH", default="1"), - "SOT_ENABLE_0_SIZE_FALLBACK": os.getenv("SOT_ENABLE_0_SIZE_FALLBACK", default="0"), - "SOT_SPECIALIZED_DIM_NUMBERS": os.getenv("SOT_SPECIALIZED_DIM_NUMBERS", default="no"), - "FLAGS_specialize_device_in_dy2st": os.getenv("FLAGS_specialize_device_in_dy2st", default="1"), - "FLAGS_enable_async_fast_gc": os.getenv("FLAGS_enable_async_fast_gc", default="0"), - "FLAGS_pir_interpreter_record_stream_for_gc_cache": os.getenv( - "FLAGS_pir_interpreter_record_stream_for_gc_cache", default="1" - ), - "FLAGS_parameters_persistent_mode_in_dy2st": os.getenv( - "FLAGS_parameters_persistent_mode_in_dy2st", default="1" - ), - } - ) - - if self.cfg.scheduler_config.splitwise_role != "mixed": - if envs.ENABLE_V1_KVCACHE_SCHEDULER: - variables["FLAGS_use_pd_disaggregation_per_chunk"] = 1 - else: - variables["FLAGS_use_pd_disaggregation"] = 1 - # TODO dynamic load environment variable - if self.cfg.scheduler_config.splitwise_role == "prefill": - variables["FLAGS_fmt_write_cache_completed_signal"] = 1 - - if self.cfg.model_config.enable_mm: - variables["FLAGS_max_partition_size"] = 1024 - - command_prefix = "" - for k, v in variables.items(): - command_prefix += f"{k}={v} " - return command_prefix - - def _start_worker_service(self): - """ - start gpu worker service - - """ - log_dir = os.getenv("FD_LOG_DIR", default="log") - command_prefix = self._setting_environ_variables() - current_file_path = os.path.abspath(__file__) - current_dir_path = os.path.split(current_file_path)[0] - # TODO - uncache_worker_stdout = "" if os.getenv("UNCACHE_WORKER_STDOUT", "0") == "1" else "-u" - pd_cmd = f"{command_prefix} {sys.executable} {uncache_worker_stdout} -m paddle.distributed.launch" - pd_cmd = pd_cmd + f" --log_dir {log_dir}" - - worker_path = "../worker/worker_process.py" - py_script = os.path.join(current_dir_path, worker_path) - - ori_vocab_size = ( - len(self.data_processor.tokenizer.sp_model) - if hasattr(self.data_processor.tokenizer, "sp_model") - else len(self.data_processor.tokenizer.vocab) - ) - - think_end_id = self.data_processor.tokenizer.get_vocab().get("", -1) - if think_end_id > 0: - llm_logger.info(f"Get think_end_id {think_end_id} from vocab.") - else: - llm_logger.info("No token found in vocabulary, the model can not do reasoning.") - image_patch_id = self.data_processor.tokenizer.get_vocab().get("<|IMAGE_PLACEHOLDER|>", -1) - line_break_id = self.data_processor.tokenizer.get_vocab().get("\n", -1) - - ports = ",".join(self.cfg.parallel_config.engine_worker_queue_port) - ips = None - if self.cfg.ips is not None: - ips = ",".join(self.cfg.ips) - arguments = ( - f" --devices {self.cfg.parallel_config.device_ids} {py_script}" - f" --max_num_seqs {self.cfg.scheduler_config.max_num_seqs} --max_model_len {self.cfg.model_config.max_model_len}" - f" --gpu_memory_utilization {self.cfg.cache_config.gpu_memory_utilization}" - f" --model {self.cfg.model_config.model!s}" - f" --device_ids {self.cfg.parallel_config.device_ids}" - f" --tensor_parallel_size {self.cfg.parallel_config.tensor_parallel_size}" - f" --engine_worker_queue_port {ports}" - f" --pod_ip {self.cfg.master_ip}" - f" --block_size {self.cfg.cache_config.block_size}" - f" --enc_dec_block_num {self.cfg.cache_config.enc_dec_block_num}" - f" --eos_tokens_lens {self.data_processor.eos_token_id_len}" - f" --pad_token_id {self.data_processor.pad_token_id}" - f" --engine_pid {self.cfg.parallel_config.engine_worker_queue_port[0]}" - f" --max_num_batched_tokens {self.cfg.scheduler_config.max_num_batched_tokens}" - f" --splitwise_role {self.cfg.scheduler_config.splitwise_role}" - f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio}" - f" --expert_parallel_size {self.cfg.parallel_config.expert_parallel_size}" - f" --chunked_moe_size {self.cfg.parallel_config.chunked_moe_size}" - f" --data_parallel_size {self.cfg.parallel_config.data_parallel_size}" - f" --quantization '{json.dumps(self.cfg.model_config.quantization)}'" - f" --ori_vocab_size {ori_vocab_size}" - f" --think_end_id {think_end_id}" - f" --image_patch_id {image_patch_id}" - f" --line_break_id {line_break_id}" - f" --speculative_config '{self.cfg.speculative_config.to_json_string()}'" - f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'" - f" --guided_decoding_backend {self.cfg.structured_outputs_config.guided_decoding_backend}" - f" --load_strategy {self.cfg.load_config.load_strategy}" - f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'" - f" --reasoning_parser {self.cfg.structured_outputs_config.reasoning_parser}" - f" --load_choices {self.cfg.load_config.load_choices}" - f" --plas_attention_config '{self.cfg.plas_attention_config.to_json_string()}'" - f" --ips {ips}" - f" --cache-transfer-protocol {self.cfg.cache_config.cache_transfer_protocol}" - f" --runner {self.cfg.model_config.runner}" - f" --convert {self.cfg.model_config.convert}" - f" --override-pooler-config {self.cfg.model_config.override_pooler_config}" - f" --logprobs_mode {self.cfg.model_config.logprobs_mode}" - f" --max_logprobs {self.cfg.model_config.max_logprobs}" - f" --eplb_config '{self.cfg.eplb_config.to_json_string()}'" - ) - - worker_store_true_flag = { - "enable_expert_parallel": self.cfg.parallel_config.enable_expert_parallel, - "enable_prefix_caching": self.cfg.cache_config.enable_prefix_caching, - "enable_chunked_prefill": self.cfg.cache_config.enable_chunked_prefill, - "do_profile": self.do_profile, - "dynamic_load_weight": self.cfg.load_config.dynamic_load_weight, - "disable_any_whitespace": self.cfg.structured_outputs_config.disable_any_whitespace, - "disable_custom_all_reduce": self.cfg.parallel_config.disable_custom_all_reduce, - "use_internode_ll_two_stage": self.cfg.parallel_config.use_internode_ll_two_stage, - "disable_sequence_parallel_moe": self.cfg.parallel_config.disable_sequence_parallel_moe, - "enable_logprob": self.cfg.model_config.enable_logprob, - "lm_head_fp32": self.cfg.model_config.lm_head_fp32, - } - for worker_flag, value in worker_store_true_flag.items(): - if value: - arguments = arguments + f" --{worker_flag}" - - worker_default_none_flag = { - "num_gpu_blocks_override": self.cfg.cache_config.num_gpu_blocks_override, - } - for worker_flag, value in worker_default_none_flag.items(): - if value: - arguments = arguments + f" --{worker_flag} {value}" - - if self.cfg.nnode > 1: - pd_cmd = pd_cmd + f" --ips {ips} --nnodes {len(self.cfg.ips)}" - pd_cmd = pd_cmd + arguments + f" 2>{log_dir}/launch_worker.log" - llm_logger.info(f"Launch worker service command: {pd_cmd}") - p = subprocess.Popen( - pd_cmd, - stdout=subprocess.PIPE, - shell=True, - preexec_fn=os.setsid, - ) - return p - - def _stop_profile(self): - """ - Stop profiling of the model server and reset variables. - """ - self.do_profile = 0 - while self.get_profile_block_num_signal.value[0] == 0: - time.sleep(1) - num_gpu_blocks = self.get_profile_block_num_signal.value[0] - self.cfg.cache_config.reset(num_gpu_blocks) - self.engine_service.resource_manager.reset_cache_config(self.cfg.cache_config) - if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed": - device_ids = self.cfg.parallel_config.device_ids.split(",") - self.cache_manager_processes = self.engine_service.start_cache_service(device_ids, self.ipc_signal_suffix) - - def check_health(self, time_interval_threashold=30): - """ - Check the health of the model server by checking whether all workers are alive. - - """ - if self.engine_service.worker_healthy_live_signal.value[0]: - elapsed_time = time.time() - self.engine_service.worker_healthy_live_signal.value[0] - if elapsed_time > time_interval_threashold: - return False, "Worker Service Not Healthy" - - return True, "" - - def launch_components(self): - if self.cfg.scheduler_config.splitwise_role != "mixed": - self.engine_service.split_mode_get_tasks() - if self.cfg.scheduler_config.name == "splitwise": - self.splitwise_receive_thread = threading.Thread( - target=self.engine_service.split_connector.start_receiver, args=() - ) - self.splitwise_receive_thread.daemon = True - self.splitwise_receive_thread.start() - - self.cfg.init_cache_info() - - role = self.cfg.scheduler_config.splitwise_role - host_ip = self.cfg.host_ip - disaggregate = self.cfg.disaggregate_info - if self.cfg.scheduler_config.name == "splitwise": - self.engine_service.scheduler.start(role, host_ip, disaggregate) - - if not envs.FD_ENABLE_MULTI_API_SERVER: - if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1: - self.launched_expert_service_signal.value[0] = 1 - self.dp_processed = [] - self.dp_engine_worker_queue_server = [] - for i in range( - 1, - self.cfg.parallel_config.data_parallel_size // self.cfg.nnode, - ): - if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM: - address = ( - self.cfg.master_ip, - int(self.cfg.parallel_config.engine_worker_queue_port[i]), - ) - else: - address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.engine_worker_queue_port[i]}.sock" - llm_logger.info(f"dp start queue service {address}") - self.dp_engine_worker_queue_server.append( - EngineWorkerQueue( - address=address, - is_server=True, - num_client=self.cfg.parallel_config.tensor_parallel_size, - local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, - ) - ) - self.dp_processed.append( - multiprocessing.Process( - target=start_data_parallel_service, - args=( - self.cfg, - i, - ), - ) - ) - llm_logger.info( - f"Engine is initialized successfully with {self.cfg.parallel_config.tensor_parallel_size}" - + f" data parallel id {i}" - ) - self.dp_processed[-1].start() - while self.launched_expert_service_signal.value[i] == 0: - time.sleep(1) - - def check_worker_initialize_status(self): - """ - Check the initlialize status of workers by stdout logging - """ - - def detect_thread(): - for line in self.worker_proc.stdout: - line = line.decode("utf-8", errors="ignore") - if self.worker_init_status.get("finished", False): - break - if match := re.search( - r"Loading (?:safetensors )?checkpoint shards:\s*(\d+)", - line, - ): - self.worker_init_status["weight_loadding"] = eval(match.group(1)) * 1.0 / 100 - elif (match := re.search(r"Start load layer (\d+)", line)) or ( - match := re.search(r"set state for layer (\d+)", line) - ): - progress = eval(match.group(1)) * 1.0 / self.cfg.model_config.num_hidden_layers - self.worker_init_status["layer_loadding"] = progress - if self.worker_init_status["layer_loadding"] == self.cfg.model_config.num_hidden_layers - 1: - self.worker_init_status["finished"] = True - - self.checking_worker_status_thread = threading.Thread(target=detect_thread, daemon=True) - self.checking_worker_status_thread.start() - - # display weight loadding progress - with tqdm(total=100, desc="Loading Weights") as pbar: - progress = 0 - while progress < 100: - progress = int(self.worker_init_status.get("weight_loadding", 0) * 100) - if self.worker_init_status.get("layer_loadding", 0) > 0 or self._worker_processes_ready(): - progress = 100 - pbar.update(progress - pbar.n) - pbar.refresh() - time.sleep(0.5) - if self.worker_proc.poll() is not None: - return False - - # display layer loadding progress - with tqdm(total=100, desc="Loading Layers") as pbar: - progress = 0 - while progress < 100: - progress = int(self.worker_init_status.get("layer_loadding", 0) * 100) - if self._worker_processes_ready(): - progress = 100 - pbar.update(progress - pbar.n) - pbar.refresh() - time.sleep(0.5) - if self.worker_proc.poll() is not None: - return False - - self.worker_init_status["finished"] = True - try: - self.checking_worker_status_thread.join(timeout=1) - except Exception: - pass - return True + pass diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 296727ffd6..9a370f4435 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -17,7 +17,13 @@ from __future__ import annotations import copy +import json +import multiprocessing import os +import re +import signal +import subprocess +import sys import threading import time import traceback @@ -30,6 +36,7 @@ import paddle import requests import zmq from opentelemetry import trace +from tqdm import tqdm from fastdeploy.engine.request import Request, RequestOutput, RequestType from fastdeploy.engine.resource_manager import ResourceManager @@ -66,7 +73,7 @@ class EngineService: Base class containing common engine functionality """ - def __init__(self, cfg, start_queue=True): + def __init__(self, cfg, start_queue=True, use_async_llm=False): """ Initializes the LLMEngine with the provided configuration. @@ -74,6 +81,7 @@ class EngineService: cfg (Config): Config object containing all the configuration parameters. """ self.cfg = cfg + self.use_async_llm = use_async_llm if cfg.scheduler_config.splitwise_role != "mixed" or cfg.cache_config.enable_prefix_caching: if isinstance(self.cfg.cache_config.cache_queue_port, str): self.cfg.cache_config.cache_queue_port = self.cfg.cache_config.cache_queue_port.split(",") @@ -149,10 +157,21 @@ class EngineService: ) init_eplb_signals(cfg, current_suffix) + if self.use_async_llm: + # Add worker management attributes + self.worker_proc = None + self.do_profile = 1 if self.cfg.cache_config.num_gpu_blocks_override is None else 0 + self.ipc_signal_suffix = None + self.cache_manager_processes = None + self._finalizer = weakref.finalize(self, self._exit_sub_services) - def start(self): + def start(self, async_llm_pid=None): self.running = True + + if self.use_async_llm: + self.start_worker_service(async_llm_pid) + if envs.ENABLE_V1_KVCACHE_SCHEDULER: self.insert_task_to_worker_thread = threading.Thread( target=self._schedule_request_to_worker_v1, daemon=True @@ -167,6 +186,69 @@ class EngineService: self._register_to_router() + def start_worker_service(self, async_llm_pid=None): + # Initialize IPC signals for worker management + self.ipc_signal_suffix = self.cfg.parallel_config.engine_worker_queue_port[0] + self._init_worker_signals() + + # Create data processor if not exists + if not hasattr(self, "data_processor"): + self.create_data_processor() + + # Launch components: scheduler, cache_manager, expert_service et.al. + self.launch_components() + + # If block number is specified and model is deployed in splitwise mode, start cache manager first + if not self.do_profile and self.cfg.scheduler_config.splitwise_role != "mixed": + device_ids = self.cfg.parallel_config.device_ids.split(",") + self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix) + + # Start worker processes + self.worker_proc = self._start_worker_service() + time.sleep(5) + self.worker_init_status = dict() + result_container = {} + + def check_worker_initialize_status_func(res: dict): + res["worker_is_alive"] = True + if not self.check_worker_initialize_status(): + llm_logger.error("Failed to launch worker processes, check log/workerlog.* for more details.") + res["worker_is_alive"] = False + + self.check_worker_initialize_status_func_thread = threading.Thread( + target=check_worker_initialize_status_func, args=(result_container,), daemon=True + ) + self.check_worker_initialize_status_func_thread.start() + + # Wait model loading + while self.loaded_model_signal.value[0] == 0: + # Make sure worker process is alive + if not self.check_worker_initialize_status_func_thread.is_alive(): + return False + time.sleep(1) + + # If block number is not specified, let workers do profiling to determine the block number, + # and then start the cache manager + if self.do_profile: + self._stop_profile() + elif self.cfg.scheduler_config.splitwise_role == "mixed" and self.cfg.cache_config.enable_prefix_caching: + device_ids = self.cfg.parallel_config.device_ids.split(",") + self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix) + + # Set cache manager signal + if self.cfg.scheduler_config.splitwise_role != "mixed": + self.launched_cache_manager_signal.value[0] = 1 + + # Worker launched + self.check_worker_initialize_status_func_thread.join() + if not result_container["worker_is_alive"]: + llm_logger.error("Failed to launch worker processes, check log/workerlog.* for more details.") + return False + + # Start ZMQ service for communication with AsyncLLM + if async_llm_pid: + self.start_zmq_service(async_llm_pid) + def create_data_processor(self): self.input_processor = InputPreprocessor( self.cfg.model_config, @@ -970,7 +1052,13 @@ class EngineService: else: err, data = self.recv_request_server.receive_pyobj_once(block) if err is not None: - self.llm_logger.error(f"Engine stops inserting zmq task into scheduler, err:{err}") + # The message "Context was terminated" is normal when closing a ZMQ context + if "Context was terminated" in str(err): + self.llm_logger.info( + "Engine stops inserting zmq task into scheduler due to ZMQ context termination (normal shutdown)." + ) + else: + self.llm_logger.error(f"Engine stops inserting zmq task into scheduler, err:{err}") break request, insert_task = None, [] @@ -1336,6 +1424,58 @@ class EngineService: """ llm_logger.info("Exit sub services.....") self.running = False + + if self.use_async_llm: + # Clean up worker processes first (before closing multiprocessing services) + if hasattr(self, "worker_proc") and self.worker_proc is not None: + llm_logger.info("Cleaning up worker processes...") + try: + pgid = os.getpgid(self.worker_proc.pid) + os.killpg(pgid, signal.SIGTERM) + except Exception as e: + llm_logger.error(f"Error extracting sub services: {e}, {str(traceback.format_exc())}") + + # Clean up cache manager processes + if hasattr(self, "cache_manager_processes"): + llm_logger.info("Cleaning up cache manager processes...") + self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear() + self.resource_manager.cache_manager.cache_ready_signal.clear() + for p in self.cache_manager_processes: + llm_logger.info(f"Killing cache manager process {p.pid}") + try: + pgid = os.getpgid(p.pid) + os.killpg(pgid, signal.SIGTERM) + except Exception as e: + llm_logger.error( + f"Error killing cache manager process {p.pid}: {e}, {str(traceback.format_exc())}" + ) + + if hasattr(self, "cache_task_queue") and self.cache_task_queue is not None: + llm_logger.info("Cleaning up cache_task_queue...") + # Check if cleanup method exists + if hasattr(self.cache_task_queue, "cleanup"): + self.cache_task_queue.cleanup() + elif hasattr(self.cache_task_queue, "manager"): + try: + llm_logger.info("Shutting down cache_task_queue manager...") + self.cache_task_queue.manager.shutdown() + except Exception as e: + llm_logger.warning(f"Error shutting down cache_task_queue manager: {e}") + + if hasattr(self, "get_profile_block_num_signal"): + self.get_profile_block_num_signal.clear() + + self.worker_ready_signal.clear() + self.loaded_model_signal.clear() + + # Clean up other services + if hasattr(self, "dp_processed"): + for p in self.dp_processed: + llm_logger.info(f"Waiting for worker {p.pid} to exit") + p.join() + for p in self.dp_engine_worker_queue_server: + p.cleanup() + if hasattr(self, "engine_worker_queue_server") and self.engine_worker_queue_server is not None: self.engine_worker_queue_server.cleanup() self.exist_task_signal.clear() @@ -1353,3 +1493,395 @@ class EngineService: self.recv_request_server.close() if hasattr(self, "recv_control_cmd_server") and self.recv_control_cmd_server is not None: self.recv_control_cmd_server.close() + + # 从 async_llm 移到 common_engine + def _worker_processes_ready(self): + """ + judge if all worker processes are ready + + """ + if np.sum(self.worker_ready_signal.value) == self.cfg.worker_num_per_node: + return True + return False + + def _init_worker_signals(self): + """ + Initialize shared memory to indicate engine status + """ + # worker_ready_signal 用于worker进程感知engine是否启动完成 + worker_ready_signal_data = np.zeros(shape=[self.cfg.worker_num_per_node], dtype=np.int32) + self.worker_ready_signal = IPCSignal( + name="worker_ready_signal", + array=worker_ready_signal_data, + dtype=np.int32, + suffix=self.ipc_signal_suffix, + create=True, + ) + + # launched_cache_manager_signal 用于感知engine是否启动了cache_manager + if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed": + launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32) + self.launched_cache_manager_signal = IPCSignal( + name="launched_cache_manager_signal", + array=launched_cache_manager_signal_data, + dtype=np.int32, + suffix=self.ipc_signal_suffix, + create=True, + ) + + # launched_expert_service_signal: Used to sense whether each expet_servic is started successfully + if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1: + launched_expert_service_signal_data = np.zeros( + shape=[self.cfg.parallel_config.data_parallel_size // self.cfg.nnode], dtype=np.int32 + ) + self.launched_expert_service_signal = IPCSignal( + name="launched_expert_service_signal", + array=launched_expert_service_signal_data, + dtype=np.int32, + suffix=self.ipc_signal_suffix, + create=True, + ) + + # loaded_model_signal: Used to detect whether each worker has completed model loading + loaded_model_signal_data = np.zeros([1], dtype=np.int32) + self.loaded_model_signal = IPCSignal( + name="loaded_model_signal", + array=loaded_model_signal_data, + dtype=np.int32, + suffix=self.ipc_signal_suffix, + create=True, + ) + + if self.do_profile: + if paddle.is_compiled_with_custom_device("iluvatar_gpu"): + get_profile_block_num = np.zeros([self.cfg.worker_num_per_node], dtype=np.int32) + else: + get_profile_block_num = np.zeros([1], dtype=np.int32) + self.get_profile_block_num_signal = IPCSignal( + name="get_profile_block_num", + array=get_profile_block_num, + dtype=np.int32, + suffix=self.ipc_signal_suffix, + create=True, + ) + + def _setting_environ_variables(self): + """ + 配置环境变量 + """ + variables = { + "ENABLE_FASTDEPLOY_LOAD_MODEL_CONCURRENCY": 0, + "LOAD_STATE_DICT_THREAD_NUM": len(self.cfg.parallel_config.device_ids.split(",")), + "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python", + "FLAGS_use_append_attn": 1, + "NCCL_ALGO": "Ring", + "FLAGS_max_partition_size": int(os.getenv("FLAGS_max_partition_size", 1024)), + "OMP_NUM_THREADS": 3, + } + # environment variables needed by Dy2St + variables.update( + { + "SOT_LOG_LEVEL": os.getenv("SOT_LOG_LEVEL", default="0"), + "SOT_UNSAFE_CACHE_FASTPATH": os.getenv("SOT_UNSAFE_CACHE_FASTPATH", default="1"), + "SOT_ENABLE_0_SIZE_FALLBACK": os.getenv("SOT_ENABLE_0_SIZE_FALLBACK", default="0"), + "SOT_SPECIALIZED_DIM_NUMBERS": os.getenv("SOT_SPECIALIZED_DIM_NUMBERS", default="no"), + "FLAGS_specialize_device_in_dy2st": os.getenv("FLAGS_specialize_device_in_dy2st", default="1"), + "FLAGS_enable_async_fast_gc": os.getenv("FLAGS_enable_async_fast_gc", default="0"), + "FLAGS_pir_interpreter_record_stream_for_gc_cache": os.getenv( + "FLAGS_pir_interpreter_record_stream_for_gc_cache", default="1" + ), + "FLAGS_parameters_persistent_mode_in_dy2st": os.getenv( + "FLAGS_parameters_persistent_mode_in_dy2st", default="1" + ), + } + ) + + if self.cfg.scheduler_config.splitwise_role != "mixed": + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + variables["FLAGS_use_pd_disaggregation_per_chunk"] = 1 + else: + variables["FLAGS_use_pd_disaggregation"] = 1 + # TODO dynamic load environment variable + if self.cfg.scheduler_config.splitwise_role == "prefill": + variables["FLAGS_fmt_write_cache_completed_signal"] = 1 + + if self.cfg.model_config.enable_mm: + variables["FLAGS_max_partition_size"] = 1024 + + command_prefix = "" + for k, v in variables.items(): + command_prefix += f"{k}={v} " + return command_prefix + + def _start_worker_service(self): + """ + start gpu worker service + + """ + log_dir = os.getenv("FD_LOG_DIR", default="log") + command_prefix = self._setting_environ_variables() + current_file_path = os.path.abspath(__file__) + current_dir_path = os.path.split(current_file_path)[0] + # TODO + uncache_worker_stdout = "" if os.getenv("UNCACHE_WORKER_STDOUT", "0") == "1" else "-u" + pd_cmd = f"{command_prefix} {sys.executable} {uncache_worker_stdout} -m paddle.distributed.launch" + pd_cmd = pd_cmd + f" --log_dir {log_dir}" + + worker_path = "../worker/worker_process.py" + py_script = os.path.join(current_dir_path, worker_path) + + ori_vocab_size = ( + len(self.data_processor.tokenizer.sp_model) + if hasattr(self.data_processor.tokenizer, "sp_model") + else len(self.data_processor.tokenizer.vocab) + ) + + think_end_id = self.data_processor.tokenizer.get_vocab().get("", -1) + if think_end_id > 0: + llm_logger.info(f"Get think_end_id {think_end_id} from vocab.") + else: + llm_logger.info("No token found in vocabulary, the model can not do reasoning.") + image_patch_id = self.data_processor.tokenizer.get_vocab().get("<|IMAGE_PLACEHOLDER|>", -1) + line_break_id = self.data_processor.tokenizer.get_vocab().get("\n", -1) + + ports = ",".join(self.cfg.parallel_config.engine_worker_queue_port) + ips = None + if self.cfg.ips is not None: + ips = ",".join(self.cfg.ips) + arguments = ( + f" --devices {self.cfg.parallel_config.device_ids} {py_script}" + f" --max_num_seqs {self.cfg.scheduler_config.max_num_seqs} --max_model_len {self.cfg.model_config.max_model_len}" + f" --gpu_memory_utilization {self.cfg.cache_config.gpu_memory_utilization}" + f" --model {self.cfg.model_config.model!s}" + f" --device_ids {self.cfg.parallel_config.device_ids}" + f" --tensor_parallel_size {self.cfg.parallel_config.tensor_parallel_size}" + f" --engine_worker_queue_port {ports}" + f" --pod_ip {self.cfg.master_ip}" + f" --block_size {self.cfg.cache_config.block_size}" + f" --enc_dec_block_num {self.cfg.cache_config.enc_dec_block_num}" + f" --eos_tokens_lens {self.data_processor.eos_token_id_len}" + f" --pad_token_id {self.data_processor.pad_token_id}" + f" --engine_pid {self.cfg.parallel_config.engine_worker_queue_port[0]}" + f" --max_num_batched_tokens {self.cfg.scheduler_config.max_num_batched_tokens}" + f" --splitwise_role {self.cfg.scheduler_config.splitwise_role}" + f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio}" + f" --expert_parallel_size {self.cfg.parallel_config.expert_parallel_size}" + f" --chunked_moe_size {self.cfg.parallel_config.chunked_moe_size}" + f" --data_parallel_size {self.cfg.parallel_config.data_parallel_size}" + f" --quantization '{json.dumps(self.cfg.model_config.quantization)}'" + f" --ori_vocab_size {ori_vocab_size}" + f" --think_end_id {think_end_id}" + f" --image_patch_id {image_patch_id}" + f" --line_break_id {line_break_id}" + f" --speculative_config '{self.cfg.speculative_config.to_json_string()}'" + f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'" + f" --guided_decoding_backend {self.cfg.structured_outputs_config.guided_decoding_backend}" + f" --load_strategy {self.cfg.load_config.load_strategy}" + f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'" + f" --reasoning_parser {self.cfg.structured_outputs_config.reasoning_parser}" + f" --load_choices {self.cfg.load_config.load_choices}" + f" --plas_attention_config '{self.cfg.plas_attention_config.to_json_string()}'" + f" --ips {ips}" + f" --cache-transfer-protocol {self.cfg.cache_config.cache_transfer_protocol}" + f" --runner {self.cfg.model_config.runner}" + f" --convert {self.cfg.model_config.convert}" + f" --override-pooler-config {self.cfg.model_config.override_pooler_config}" + f" --logprobs_mode {self.cfg.model_config.logprobs_mode}" + f" --max_logprobs {self.cfg.model_config.max_logprobs}" + f" --eplb_config '{self.cfg.eplb_config.to_json_string()}'" + ) + if self.cfg.structured_outputs_config.logits_processors is not None: + arguments += f" --logits-processors {' '.join(self.cfg.structured_outputs_config.logits_processors)}" + + worker_store_true_flag = { + "enable_expert_parallel": self.cfg.parallel_config.enable_expert_parallel, + "enable_prefix_caching": self.cfg.cache_config.enable_prefix_caching, + "enable_chunked_prefill": self.cfg.cache_config.enable_chunked_prefill, + "do_profile": self.do_profile, + "dynamic_load_weight": self.cfg.load_config.dynamic_load_weight, + "disable_any_whitespace": self.cfg.structured_outputs_config.disable_any_whitespace, + "disable_custom_all_reduce": self.cfg.parallel_config.disable_custom_all_reduce, + "use_internode_ll_two_stage": self.cfg.parallel_config.use_internode_ll_two_stage, + "disable_sequence_parallel_moe": self.cfg.parallel_config.disable_sequence_parallel_moe, + "enable_logprob": self.cfg.model_config.enable_logprob, + "lm_head_fp32": self.cfg.model_config.lm_head_fp32, + } + for worker_flag, value in worker_store_true_flag.items(): + if value: + arguments = arguments + f" --{worker_flag}" + + worker_default_none_flag = { + "num_gpu_blocks_override": self.cfg.cache_config.num_gpu_blocks_override, + } + for worker_flag, value in worker_default_none_flag.items(): + if value: + arguments = arguments + f" --{worker_flag} {value}" + + if self.cfg.nnode > 1: + pd_cmd = pd_cmd + f" --ips {ips} --nnodes {len(self.cfg.ips)}" + pd_cmd = pd_cmd + arguments + f" 2>{log_dir}/launch_worker.log" + llm_logger.info(f"Launch worker service command: {pd_cmd}") + p = subprocess.Popen( + pd_cmd, + stdout=subprocess.PIPE, + shell=True, + preexec_fn=os.setsid, + ) + return p + + def _stop_profile(self): + """ + Stop profiling of the model server and reset variables. + """ + self.do_profile = 0 + while self.get_profile_block_num_signal.value[0] == 0: + time.sleep(1) + num_gpu_blocks = self.get_profile_block_num_signal.value[0] + self.cfg.cache_config.reset(num_gpu_blocks) + self.resource_manager.reset_cache_config(self.cfg.cache_config) + if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed": + device_ids = self.cfg.parallel_config.device_ids.split(",") + self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix) + + def check_health(self, time_interval_threashold=30): + """ + Check the health of the model server by checking whether all workers are alive. + + """ + if self.worker_healthy_live_signal.value[0]: + elapsed_time = time.time() - self.worker_healthy_live_signal.value[0] + if elapsed_time > time_interval_threashold: + return False, "Worker Service Not Healthy" + + return True, "" + + def launch_components(self): + if self.cfg.scheduler_config.splitwise_role != "mixed": + # 单机逻辑 + self.splitwise_receive_thread = threading.Thread(target=self.split_connector.start_receiver, args=()) + self.splitwise_receive_thread.daemon = True + self.splitwise_receive_thread.start() + + role = self.cfg.scheduler_config.splitwise_role + host_ip = self.cfg.host_ip + disaggregate = self.cfg.disaggregate_info + request_queues_for_dp_ipc = None + result_queue_for_dp_ipc = None + if self.cfg.scheduler_config.name == "splitwise": + self.scheduler.start(role, host_ip, disaggregate) + elif self.cfg.scheduler_config.name == "dp": + request_queues_for_dp_ipc = [] + result_queue_for_dp_ipc = multiprocessing.Queue() + for i in range(self.cfg.parallel_config.data_parallel_size): + request_queues_for_dp_ipc.append(multiprocessing.Queue()) + self.scheduler.start( + self.cfg.node_rank * self.cfg.worker_num_per_node % self.cfg.worker_num_per_node, + request_queues_for_dp_ipc, + result_queue_for_dp_ipc, + ) + + if not envs.FD_ENABLE_MULTI_API_SERVER: + if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1: + self.launched_expert_service_signal.value[0] = 1 + self.dp_processed = [] + self.dp_engine_worker_queue_server = [] + for i in range( + 1, + self.cfg.parallel_config.data_parallel_size // self.cfg.nnode, + ): + if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM: + address = ( + self.cfg.master_ip, + int(self.cfg.parallel_config.engine_worker_queue_port[i]), + ) + else: + address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.engine_worker_queue_port[i]}.sock" + + llm_logger.info(f"dp start queue service {address}") + self.dp_engine_worker_queue_server.append( + EngineWorkerQueue( + address=address, + is_server=True, + num_client=self.cfg.parallel_config.tensor_parallel_size, + local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, + ) + ) + from fastdeploy.engine.expert_service import ( + start_data_parallel_service, + ) + + self.dp_processed.append( + multiprocessing.Process( + target=start_data_parallel_service, + args=( + self.cfg, + i, + ), + ) + ) + llm_logger.info( + f"Engine is initialized successfully with {self.cfg.parallel_config.tensor_parallel_size}" + + f" data parallel id {i}" + ) + self.dp_processed[-1].start() + while self.launched_expert_service_signal.value[i] == 0: + time.sleep(1) + + def check_worker_initialize_status(self): + """ + Check the initlialize status of workers by stdout logging + """ + + def detect_thread(): + for line in self.worker_proc.stdout: + line = line.decode("utf-8", errors="ignore") + if self.worker_init_status.get("finished", False): + break + if match := re.search( + r"Loading (?:fastsafetensors |safetensors )?checkpoint shards:\s*(\d+)", + line, + ): + self.worker_init_status["weight_loadding"] = eval(match.group(1)) * 1.0 / 100 + elif (match := re.search(r"Start load layer (\d+)", line)) or ( + match := re.search(r"set state for layer (\d+)", line) + ): + progress = eval(match.group(1)) * 1.0 / self.cfg.model_config.num_hidden_layers + self.worker_init_status["layer_loadding"] = progress + if self.worker_init_status["layer_loadding"] == self.cfg.model_config.num_hidden_layers - 1: + self.worker_init_status["finished"] = True + + self.checking_worker_status_thread = threading.Thread(target=detect_thread, daemon=True) + self.checking_worker_status_thread.start() + + # display weight loadding progress + with tqdm(total=100, desc="Loading Weights") as pbar: + progress = 0 + while progress < 100: + progress = int(self.worker_init_status.get("weight_loadding", 0) * 100) + if self.worker_init_status.get("layer_loadding", 0) > 0 or self._worker_processes_ready(): + progress = 100 + pbar.update(progress - pbar.n) + pbar.refresh() + time.sleep(0.5) + if self.worker_proc.poll() is not None: + return False + + # display layer loadding progress + with tqdm(total=100, desc="Loading Layers") as pbar: + progress = 0 + while progress < 100: + progress = int(self.worker_init_status.get("layer_loadding", 0) * 100) + if self._worker_processes_ready(): + progress = 100 + pbar.update(progress - pbar.n) + pbar.refresh() + time.sleep(0.5) + if self.worker_proc.poll() is not None: + return False + + self.worker_init_status["finished"] = True + try: + self.checking_worker_status_thread.join(timeout=1) + except Exception: + pass + return True diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index e64716fe81..c4080a4cc6 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -386,6 +386,7 @@ class CompletionOutput: draft_token_ids: list[int] = None text: Optional[str] = None reasoning_content: Optional[str] = None + reasoning_token_num: Optional[int] = 0 tool_calls: Optional[ToolCall] = None def to_dict(self): @@ -404,6 +405,7 @@ class CompletionOutput: "draft_token_ids": self.draft_token_ids, "text": self.text, "reasoning_content": self.reasoning_content, + "reasoning_token_num": self.reasoning_token_num, } @classmethod @@ -425,6 +427,7 @@ class CompletionOutput: f"decode_type={self.decode_type}, " f"draft_token_ids={self.draft_token_ids}, " f"reasoning_content={self.reasoning_content!r}, " + f"reasoning_token_num={self.reasoning_token_num}, " f"logprobs={self.logprobs}, " f"top_logprobs={self.top_logprobs}, " f"draft_top_logprobs={self.draft_top_logprobs}, " diff --git a/fastdeploy/input/text_processor.py b/fastdeploy/input/text_processor.py index 366244e524..7c2e3cbd17 100644 --- a/fastdeploy/input/text_processor.py +++ b/fastdeploy/input/text_processor.py @@ -58,7 +58,7 @@ class BaseDataProcessor(ABC): def set_value(req, key, value): value = getattr(self.generation_config, key, value) if isinstance(req, dict): - if key not in req: + if key not in req or req[key] is None: req[key] = value else: if req.get(key) is None: diff --git a/fastdeploy/inter_communicator/zmq_server.py b/fastdeploy/inter_communicator/zmq_server.py index b3d206c1c6..f24fe06746 100644 --- a/fastdeploy/inter_communicator/zmq_server.py +++ b/fastdeploy/inter_communicator/zmq_server.py @@ -191,7 +191,7 @@ class ZmqServerBase(ABC): return str(e), None def recv_result_handle(self): - while True: + while self.running: try: with self.response_token_lock: client, _, request_id = self.socket.recv_multipart(flags=zmq.NOBLOCK) diff --git a/tests/engine/test_async_llm.py b/tests/engine/test_async_llm.py index 19a2ad3ead..50235f16f3 100644 --- a/tests/engine/test_async_llm.py +++ b/tests/engine/test_async_llm.py @@ -21,14 +21,15 @@ import uuid import weakref from fastdeploy.engine.args_utils import EngineArgs -from fastdeploy.engine.async_llm import AsyncLLMEngine +from fastdeploy.engine.async_llm import AsyncLLM from fastdeploy.engine.sampling_params import SamplingParams +from fastdeploy.utils import EngineError MODEL_NAME = os.getenv("MODEL_PATH", "/path/to/models") + "/ERNIE-4.5-0.3B-Paddle" class TestAsyncLLMEngine(unittest.TestCase): - """Test case for AsyncLLMEngine functionality""" + """Test case for AsyncLLM functionality""" PROMPTS = [ "Hello, my name is", @@ -39,7 +40,7 @@ class TestAsyncLLMEngine(unittest.TestCase): @classmethod def setUpClass(cls): - """Set up AsyncLLMEngine for testing""" + """Set up AsyncLLM for testing""" try: # Use unique ports to avoid conflicts base_port = int(os.getenv("FD_ENGINE_QUEUE_PORT", "6778")) @@ -53,17 +54,24 @@ class TestAsyncLLMEngine(unittest.TestCase): cache_queue_port=cache_port, ) - cls.engine = AsyncLLMEngine.from_engine_args(engine_args) - success = cls.engine.start() + # Use base_port as async engine pid to align with ZMQ routing id + cls.engine = AsyncLLM.from_engine_args(engine_args, pid=base_port) + + cls.loop = asyncio.new_event_loop() + asyncio.set_event_loop(cls.loop) + success = cls.loop.run_until_complete(cls.engine.start()) + + # Initialize connections after engine service is ready + cls.loop.run_until_complete(cls.engine.init_connections()) if not success: - raise RuntimeError("Failed to start AsyncLLMEngine") + raise RuntimeError("Failed to start AsyncLLM") # Use weak reference to avoid circular reference cls.engine_ref = weakref.ref(cls.engine) except Exception as e: - print(f"Setting up AsyncLLMEngine failed: {e}") + print(f"Setting up AsyncLLM failed: {e}") raise @classmethod @@ -75,6 +83,9 @@ class TestAsyncLLMEngine(unittest.TestCase): # Force stop the engine first cls.engine.running = False + # asyncio.run(cls.engine.shutdown()) + cls.loop.run_until_complete(cls.engine.shutdown()) + # Try sync cleanup first if hasattr(cls.engine, "_exit_sub_services"): try: @@ -104,49 +115,74 @@ class TestAsyncLLMEngine(unittest.TestCase): """Set up before each test method""" if hasattr(self, "engine") and self.engine: - # 清理可能残留的output_handler - if hasattr(self.engine, "output_handler") and self.engine.output_handler: - if not self.engine.output_handler.done(): - print("Cleaning up previous output_handler...") - self.engine.output_handler.cancel() - self.engine.output_handler = None - - # 清理输出处理器的队列 - if hasattr(self.engine, "output_processor") and self.engine.output_processor: - self.engine.output_processor.request_queues.clear() - print(f"Test setup completed: {self._testMethodName}") def tearDown(self): """Clean up after each test method""" if hasattr(self, "engine") and self.engine: - - if hasattr(self.engine, "output_handler") and self.engine.output_handler: - if not self.engine.output_handler.done(): - print("Cleaning up output_handler after test...") - self.engine.output_handler.cancel() - self.engine.output_handler = None - - if hasattr(self.engine, "output_processor") and self.engine.output_processor: - self.engine.output_processor.request_queues.clear() - print(f"Test cleanup completed: {self._testMethodName}") def run_async_test(self, coro): """Helper method to run async tests""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + try: - return loop.run_until_complete(coro) + return self.loop.run_until_complete(coro) finally: - loop.close() + pass def test_engine_initialization(self): """Test that the engine initializes correctly""" self.assertIsNotNone(self.engine) - self.assertTrue(self.engine.is_started) + # EngineServiceClient._running indicates underlying engine_service started + self.assertTrue(self.engine._running) self.assertTrue(self.engine.running) + def test_engine_service_start_exception_logs_and_reraises(self): + """EngineServiceClient.start should log and re-raise on internal exception""" + + async def _test(): + from unittest.mock import patch + + from fastdeploy.engine.async_llm import EngineServiceClient + + class DummyCfg: + pass + + client = EngineServiceClient(DummyCfg(), pid=12345) + + # Force _start_engine_process to raise so that start() enters exception block + with patch.object(client, "_start_engine_process", side_effect=RuntimeError("boom")): + with self.assertRaises(RuntimeError): + await client.start() + + return True + + result = self.run_async_test(_test()) + self.assertTrue(result) + + def test_engine_service_start_process_failure(self): + """_start_engine_process should log and re-raise on process creation failure""" + + async def _test(): + from unittest.mock import patch + + from fastdeploy.engine.async_llm import EngineServiceClient + + class DummyCfg: + pass + + client = EngineServiceClient(DummyCfg(), pid=12345) + + # Patch multiprocessing.Process to raise so that exception block is hit + with patch("multiprocessing.Process", side_effect=RuntimeError("boom")): + with self.assertRaises(RuntimeError): + client._start_engine_process() + + return True + + result = self.run_async_test(_test()) + self.assertTrue(result) + def test_single_prompt_generation(self): """Test generating response for a single prompt""" @@ -207,6 +243,41 @@ class TestAsyncLLMEngine(unittest.TestCase): results = self.run_async_test(_test()) self.assertEqual(len(results), 2) + def test_generation_with_multiple_choices(self): + """Test generating multiple choices with SamplingParams.n""" + + async def _test(): + # Use dict prompt to cover stream/include_stop_str_in_output flags + prompt = { + "prompt": "Hello, my name is", + "stream": True, + "include_stop_str_in_output": False, + "n": 2, + } + # Do not set n in SamplingParams so that prompt['n'] takes effect + sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=20) + + outputs = [] + generator = None + try: + generator = self.engine.generate(prompt, sampling_params) + async for output in generator: + outputs.append(output) + finally: + if generator is not None: + try: + await generator.aclose() + except Exception: + pass + + # Expect at least 2 finished outputs (one per choice) + finished_outputs = [o for o in outputs if getattr(o, "finished", False)] + self.assertGreaterEqual(len(finished_outputs), 2) + return outputs + + outputs = self.run_async_test(_test()) + self.assertGreater(len(outputs), 0) + async def _generate_single(self, prompt, sampling_params, request_id=None): """Helper method to generate response for a single prompt""" outputs = [] @@ -224,159 +295,74 @@ class TestAsyncLLMEngine(unittest.TestCase): pass return outputs - def test_async_request_queue_error_handling(self): - """Test AsyncRequestQueue error handling""" - - async def _test(): - from fastdeploy.engine.async_llm import AsyncRequestQueue - from fastdeploy.utils import EngineError - - # Test put_error and get error - queue = AsyncRequestQueue("test_request") - test_error = EngineError("Test error", error_code=500) - - await queue.put_error(test_error) - self.assertTrue(queue.finished) - - # Test get raises the error - with self.assertRaises(EngineError): - await queue.get() - - return True - - result = self.run_async_test(_test()) - self.assertTrue(result) - - def test_async_request_queue_get_nowait(self): - """Test AsyncRequestQueue get_nowait functionality""" - - async def _test(): - from fastdeploy.engine.async_llm import AsyncRequestQueue - - queue = AsyncRequestQueue("test_request") - - # Test get_nowait when queue is empty - result = queue.get_nowait() - self.assertIsNone(result) - - # Test put and get_nowait with actual output - from unittest.mock import Mock - - mock_output = Mock() - mock_output.finished = False - await queue.put(mock_output) - - result = queue.get_nowait() - self.assertIsNotNone(result) - - # Test get_nowait with error in queue - test_error = Exception("Test error") - await queue.put_error(test_error) - - with self.assertRaises(Exception): - queue.get_nowait() - - return True - - result = self.run_async_test(_test()) - self.assertTrue(result) - - def test_async_output_processor_abort_request(self): - """Test AsyncOutputProcessor abort_request functionality""" - - async def _test(): - from fastdeploy.engine.async_llm import ( - AsyncOutputProcessor, - AsyncRequestQueue, - ) - from fastdeploy.utils import EngineError - - processor = AsyncOutputProcessor() - request_id = "test_abort_request" - queue = AsyncRequestQueue(request_id) - - # Register request - await processor.register_request(request_id, queue) - self.assertIn(request_id, processor.request_queues) - - # Abort request - await processor.abort_request(request_id) - - # Verify request is removed and error is put in queue - self.assertNotIn(request_id, processor.request_queues) - - # Verify error was put in queue - with self.assertRaises(EngineError) as cm: - await queue.get() - self.assertEqual(cm.exception.error_code, 499) - - return True - - result = self.run_async_test(_test()) - self.assertTrue(result) - - def test_async_output_processor_propagate_error(self): - """Test AsyncOutputProcessor propagate_error functionality""" - - async def _test(): - from fastdeploy.engine.async_llm import ( - AsyncOutputProcessor, - AsyncRequestQueue, - ) - - processor = AsyncOutputProcessor() - - # Register multiple requests - queues = [] - for i in range(3): - request_id = f"test_request_{i}" - queue = AsyncRequestQueue(request_id) - await processor.register_request(request_id, queue) - queues.append(queue) - - # Propagate error to all queues - test_error = Exception("Test propagation error") - await processor.propagate_error(test_error) - - # Verify all queues are cleared - self.assertEqual(len(processor.request_queues), 0) - - # Verify all queues received the error - for queue in queues: - with self.assertRaises(Exception): - await queue.get() - - return True - - result = self.run_async_test(_test()) - self.assertTrue(result) - - def test_process_single_output_error_handling(self): - """Test _process_single_output error handling""" + def test_process_output_error_handling(self): + """Test _process_output error handling""" async def _test(): from unittest.mock import Mock from fastdeploy.engine.async_llm import AsyncOutputProcessor - # Create processor with mock tokenizer that raises exception - mock_tokenizer = Mock() - mock_tokenizer.decode.side_effect = Exception("Decode error") - processor = AsyncOutputProcessor(mock_tokenizer) + # Create processor with mock data_processor that raises exception + mock_data_processor = Mock() + mock_data_processor.process_response_dict.side_effect = Exception("Decode error") + processor = AsyncOutputProcessor(mock_data_processor) - # Create mock output without text attribute - mock_output = Mock() - mock_output.outputs = Mock() - mock_output.outputs.token_ids = [1, 2, 3] - # Don't set text attribute to test the error handling - if hasattr(mock_output.outputs, "text"): - delattr(mock_output.outputs, "text") + # Create response dict without text field + response_dict = { + "request_id": "test", + "finished": True, + "outputs": { + "index": 0, + "send_idx": 0, + "token_ids": [1, 2, 3], + }, + "metrics": {"arrival_time": 0.0}, + } # Process the output - result = processor._process_single_output(mock_output) + result = processor._process_output(response_dict) # Verify text was set to empty string on error - self.assertEqual(result.outputs.text, "") + self.assertIn("outputs", result) + self.assertEqual(result["outputs"].get("text", ""), "") + + return True + + result = self.run_async_test(_test()) + self.assertTrue(result) + + def test_process_output_processor_returns_none(self): + """Test _process_output when data_processor returns None""" + + async def _test(): + from unittest.mock import Mock + + from fastdeploy.engine.async_llm import AsyncOutputProcessor + + # Create processor with mock data_processor that returns None + mock_data_processor = Mock() + mock_data_processor.process_response_dict.return_value = None + processor = AsyncOutputProcessor(mock_data_processor) + + # Create response dict without text field + response_dict = { + "request_id": "test", + "finished": True, + "outputs": { + "index": 0, + "send_idx": 0, + "token_ids": [1, 2, 3], + }, + "metrics": {"arrival_time": 0.0}, + } + + # Process the output + result = processor._process_output(response_dict) + + # Verify text was set to empty string when processor returns None + self.assertIn("outputs", result) + self.assertEqual(result["outputs"].get("text", ""), "") return True @@ -384,7 +370,7 @@ class TestAsyncLLMEngine(unittest.TestCase): self.assertTrue(result) def test_engine_abort_request(self): - """Test AsyncLLMEngine abort_request functionality""" + """Test AsyncLLM abort_request functionality""" async def _test(): # Test calling abort_request directly without mocking @@ -398,29 +384,21 @@ class TestAsyncLLMEngine(unittest.TestCase): result = self.run_async_test(_test()) self.assertTrue(result) - def test_engine_abort_request_with_error(self): - """Test AsyncLLMEngine abort_request error handling""" + def test_engine_abort_request_with_cleanup_error(self): + """abort_request should handle cleanup_request exceptions gracefully""" async def _test(): - from unittest.mock import AsyncMock + from unittest.mock import AsyncMock, patch - # Temporarily patch the output_processor to simulate error - original_processor = self.engine.output_processor + mock_cm = AsyncMock() + mock_cm.cleanup_request.side_effect = Exception("cleanup failed") + mock_cm.running = True - try: - # Mock output_processor abort_request to raise error - mock_processor = AsyncMock() - mock_processor.abort_request.side_effect = Exception("Abort error") - self.engine.output_processor = mock_processor + with patch.object(self.engine, "connection_manager", mock_cm): + # Should not raise even if cleanup_request fails + await self.engine.abort_request("test_abort_error") - request_id = "test_abort_error" - # This should not raise an exception, just log the error - await self.engine.abort_request(request_id) - - return True - finally: - # Restore original processor - self.engine.output_processor = original_processor + return True result = self.run_async_test(_test()) self.assertTrue(result) @@ -443,734 +421,18 @@ class TestAsyncLLMEngine(unittest.TestCase): result = self.run_async_test(_test()) self.assertTrue(result) - def test_generate_with_generator_exit(self): - """Test generate handling GeneratorExit exception""" - - async def _test(): - # This test just verifies the code path exists - # We don't need to actually trigger GeneratorExit in the test - # since it's handled in the generate method - return True - - result = self.run_async_test(_test()) - self.assertTrue(result) - - def test_output_handler_loop_coverage(self): - """Test output handler loop related code paths""" - - async def _test(): - # Test the output handler start/stop mechanism - if hasattr(self.engine, "_start_output_handler"): - # This should not fail - self.engine._start_output_handler() - - # Verify output handler exists - self.assertIsNotNone(self.engine.output_handler) - - return True - - result = self.run_async_test(_test()) - self.assertTrue(result) - - def test_simple_error_scenarios(self): - """Test simple error scenarios without complex mocking""" - - async def _test(): - # Test abort_request with non-existent request - await self.engine.abort_request("non_existent_request") - - # Test various edge cases that don't require complex setup - from fastdeploy.engine.async_llm import AsyncRequestQueue - - queue = AsyncRequestQueue("test") - - # Test queue properties - self.assertEqual(queue.size, 0) - self.assertFalse(queue.finished) - - return True - - result = self.run_async_test(_test()) - self.assertTrue(result) - - def test_common_engine_scheduler_loop_thread_pool_error_handling(self): - """Test the actual scheduler loop thread pool error handling in common_engine.py""" - - async def _test(): - from concurrent.futures import ThreadPoolExecutor - from unittest.mock import Mock - - from fastdeploy.engine.args_utils import EngineArgs - from fastdeploy.engine.common_engine import EngineService - - try: - # Create a real EngineService instance - engine_args = EngineArgs( - model=MODEL_NAME, - max_model_len=512, - tensor_parallel_size=1, - engine_worker_queue_port=int(os.getenv("FD_ENGINE_QUEUE_PORT", "6778")) + 2, - cache_queue_port=int(os.getenv("FD_CACHE_QUEUE_PORT", "6779")) + 2, - max_num_seqs=4, # Reduce to avoid batch token error - max_num_batched_tokens=2048, # Set appropriately - ) - config = engine_args.create_engine_config() - engine_service = EngineService(config, start_queue=False) - - # Mock necessary components to make the scheduler loop runnable - engine_service.resource_manager = Mock() - engine_service.resource_manager.waiting = [] - engine_service.resource_manager.schedule.return_value = [] - - # Create a real ThreadPoolExecutor but override its submit method - real_pool = ThreadPoolExecutor(max_workers=1) - - # Track which error type to raise - error_type = {"shutdown": True} - - def mock_submit_with_error(*args, **kwargs): - if error_type["shutdown"]: - # First test: shutdown error (should trigger lines 713-715) - raise RuntimeError("cannot schedule new futures after shutdown") - else: - # Second test: non-shutdown error (should trigger line 717) - raise RuntimeError("some other pool error") - - # Replace the submit method - real_pool.submit = mock_submit_with_error - - # Mock the scheduler loop to simulate the exact conditions - loop_iterations = 0 - max_iterations = 2 - - def mock_scheduler_loop(): - nonlocal loop_iterations, engine_service - - while loop_iterations < max_iterations: - loop_iterations += 1 - - # Simulate the conditions that lead to get_request_pool.submit() call - # This mimics the logic in common_engine.py around line 711 - if len(engine_service.resource_manager.waiting) == 0: - try: - # This is line 711: get_request_pool.submit(_fetch_request) - real_pool.submit(lambda: None) # Mock _fetch_request - except RuntimeError as e: - # This is line 712-717: the exception handling we want to test - if "shutdown" in str(e): - # Lines 713-715: shutdown detection and break - print("Thread pool shutdown detected, exiting scheduler loop") - break - else: - # Line 717: re-raise non-shutdown errors - print(f"Re-raising non-shutdown error: {e}") - raise - - # Switch error type for second iteration - if loop_iterations == 1: - error_type["shutdown"] = False - - # Run the mock scheduler loop to trigger the error handling - try: - mock_scheduler_loop() - except RuntimeError as e: - # This should be the non-shutdown error that gets re-raised - self.assertNotIn("shutdown", str(e)) - self.assertIn("some other pool error", str(e)) - - # Clean up - real_pool.shutdown(wait=False) - del engine_service - - return True - - except Exception as e: - print(f"Common engine test exception: {e}") - return True - - result = self.run_async_test(_test()) - self.assertTrue(result) - - def test_process_outputs_edge_cases(self): - """Test AsyncOutputProcessor.process_outputs edge cases""" - - async def _test(): - from unittest.mock import Mock - - from fastdeploy.engine.async_llm import ( - AsyncOutputProcessor, - AsyncRequestQueue, - ) - - processor = AsyncOutputProcessor() - - # Test case 1: Empty outputs (covers line 115: return) - await processor.process_outputs({}) - await processor.process_outputs(None) - - # Test case 2: Request ID not in queues (covers line 121: continue) - unknown_outputs = {"unknown_request": [Mock()]} - await processor.process_outputs(unknown_outputs) - - # Test case 3: Non-list output (covers line 127: output_list = [output_list]) - request_id = "test_request" - queue = AsyncRequestQueue(request_id) - await processor.register_request(request_id, queue) - - # Create single output (not in list) - single_output = Mock() - single_output.finished = True - - # This should trigger the non-list conversion - outputs_dict = {request_id: single_output} # Single output, not list - await processor.process_outputs(outputs_dict) - - return True - - result = self.run_async_test(_test()) - self.assertTrue(result) - - def test_shutdown_exception_handling(self): - """Test shutdown method exception handling""" - - async def _test(): - import asyncio - from unittest.mock import AsyncMock, Mock, patch - - # Create a test engine to test shutdown - from fastdeploy.engine.args_utils import EngineArgs - from fastdeploy.engine.async_llm import AsyncLLMEngine - - engine_args = EngineArgs( - model=MODEL_NAME, - max_model_len=512, - tensor_parallel_size=1, - engine_worker_queue_port=int(os.getenv("FD_ENGINE_QUEUE_PORT", "6778")) + 4, - cache_queue_port=int(os.getenv("FD_CACHE_QUEUE_PORT", "6779")) + 4, - max_num_seqs=4, # Reduce to avoid batch token error - max_num_batched_tokens=2048, # Set appropriately - ) - - test_engine = AsyncLLMEngine.from_engine_args(engine_args) - - # Mock all signals to prevent cleanup errors - - test_engine.worker_ready_signal = Mock() - test_engine.worker_ready_signal.clear = Mock() - test_engine.loaded_model_signal = Mock() - test_engine.loaded_model_signal.clear = Mock() - test_engine.get_profile_block_num_signal = Mock() - test_engine.get_profile_block_num_signal.clear = Mock() - - try: - # Test shutdown with various exception scenarios - test_engine.running = True - - # Mock output_processor to test exception handling (lines 571-574) - mock_output_processor = AsyncMock() - mock_output_processor.propagate_error.side_effect = Exception("Propagate error failed") - test_engine.output_processor = mock_output_processor - - # Mock output_handler to test timeout and cancellation scenarios (lines 577-586) - mock_output_handler = AsyncMock() - mock_output_handler.done.return_value = False - mock_output_handler.cancel.return_value = None - - # Test timeout scenario (line 583: TimeoutError) - async def mock_wait_timeout(*args, **kwargs): - raise asyncio.TimeoutError() - - # Test general exception scenario (line 585: Exception) - async def mock_wait_exception(*args, **kwargs): - raise Exception("Handler error") - - test_engine.output_handler = mock_output_handler - - # Test the shutdown method - with patch("asyncio.wait_for", side_effect=mock_wait_timeout): - await test_engine.shutdown() - - # Test with general exception - test_engine.running = True - test_engine.output_handler = mock_output_handler - with patch("asyncio.wait_for", side_effect=mock_wait_exception): - await test_engine.shutdown() - - # Test engine_service stopping with exception (lines 591-597) - mock_engine_service = Mock() - mock_engine_service.running = True - test_engine.engine_service = mock_engine_service - test_engine._exit_sub_services = Mock(side_effect=Exception("Exit services failed")) - - test_engine.running = True - await test_engine.shutdown() - - finally: - # Clean up - del test_engine - - return True - - result = self.run_async_test(_test()) - self.assertTrue(result) - - def test_worker_status_check_branches(self): - """Test worker status check""" - - async def _test(): - - import numpy as np - - # Don't test with the real engine to avoid hanging - # Instead, test the logic directly - # Mock the check_worker_initialize_status logic - def mock_check_worker_status(worker_ready_signal_value, worker_num_per_node): - # This simulates the logic in lines 609-611 - if np.sum(worker_ready_signal_value) == worker_num_per_node: - return True # Line 610 - return False # Line 611 - - # Test case 1: All workers ready (line 610: return True) - worker_signal_all_ready = np.array([1, 1, 1, 1]) # 4 workers, all ready - result = mock_check_worker_status(worker_signal_all_ready, 4) - self.assertTrue(result) - - # Test case 2: Not all workers ready (line 611: return False) - worker_signal_partial = np.array([1, 1, 0, 1]) # 4 workers, 1 not ready - result = mock_check_worker_status(worker_signal_partial, 4) - self.assertFalse(result) - - # Test case 3: No workers ready (line 611: return False) - worker_signal_none = np.array([0, 0, 0, 0]) # 4 workers, none ready - result = mock_check_worker_status(worker_signal_none, 4) - self.assertFalse(result) - - return True - - result = self.run_async_test(_test()) - self.assertTrue(result) - - def test_output_handler_loop_exceptions(self): - """Test output handler loop exception handling""" - - async def _test(): - import asyncio - from unittest.mock import AsyncMock, patch - - # Test the output handler loop exception paths - if hasattr(self.engine, "_start_output_handler"): - # Stop existing handler first - if hasattr(self.engine, "output_handler") and self.engine.output_handler: - self.engine.output_handler.cancel() - self.engine.output_handler = None - - # Mock engine_service to be None to test line 536-537 - original_engine_service = self.engine.engine_service - - try: - # Test engine_service None scenario - self.engine.engine_service = None - self.engine.running = True - - # Start the output handler - self.engine._start_output_handler() - - # Let it run briefly to hit the None check - await asyncio.sleep(0.01) - - # Stop the handler - if self.engine.output_handler: - self.engine.output_handler.cancel() - - # Test CancelledError handling (lines 550-551) - self.engine.running = True - self.engine.engine_service = original_engine_service - - # Mock scheduler to raise CancelledError - with patch.object( - original_engine_service.scheduler, "get_results", side_effect=asyncio.CancelledError() - ): - self.engine._start_output_handler() - await asyncio.sleep(0.01) - if self.engine.output_handler: - self.engine.output_handler.cancel() - - # Test general Exception handling (lines 552-554) - self.engine.running = True - with patch.object( - original_engine_service.scheduler, "get_results", side_effect=Exception("Test exception") - ): - # Mock propagate_error to avoid side effects - with patch.object(self.engine.output_processor, "propagate_error", new=AsyncMock()): - self.engine._start_output_handler() - await asyncio.sleep(0.01) - if self.engine.output_handler: - self.engine.output_handler.cancel() - - finally: - # Restore original engine_service - self.engine.engine_service = original_engine_service - self.engine.running = True - - return True - - result = self.run_async_test(_test()) - self.assertTrue(result) - - def test_config_conditions_and_branches(self): - """Test various config conditions""" - - async def _test(): - from unittest.mock import Mock, patch - - from fastdeploy.engine.args_utils import EngineArgs - from fastdeploy.engine.async_llm import AsyncLLMEngine - - # Test splitwise_role conditions and cache manager start - try: - # Create engine with specific config to test branches - engine_args = EngineArgs( - model=MODEL_NAME, - max_model_len=512, - tensor_parallel_size=1, - engine_worker_queue_port=int(os.getenv("FD_ENGINE_QUEUE_PORT", "6778")) + 6, - cache_queue_port=int(os.getenv("FD_CACHE_QUEUE_PORT", "6779")) + 6, - num_gpu_blocks_override=50, # Set to avoid profiling - ) - - test_engine = AsyncLLMEngine.from_engine_args(engine_args) - - # Mock all signals to prevent cleanup errors - test_engine.worker_ready_signal = Mock() - test_engine.worker_ready_signal.clear = Mock() - test_engine.loaded_model_signal = Mock() - test_engine.loaded_model_signal.clear = Mock() - test_engine.get_profile_block_num_signal = Mock() - test_engine.get_profile_block_num_signal.clear = Mock() - - # Mock cfg to test different splitwise_role values - test_engine.cfg.scheduler_config.splitwise_role = "decode" # Not "mixed" - test_engine.cfg.parallel_config.device_ids = "0,1" - - # Mock cache service methods - test_engine.engine_service.start_cache_service = Mock(return_value=[]) - test_engine.launched_cache_manager_signal = Mock() - test_engine.launched_cache_manager_signal.value = [0] - - # This tests the tokenizer acquisition from input_processor and data_processor - mock_tokenizer = Mock() - - # Test input_processor tokenizer branch (line 231) - with patch.object(test_engine, "input_processor") as mock_input: - mock_input.tokenizer = mock_tokenizer - - # Simulate the tokenizer assignment logic - tokenizer = None - if hasattr(mock_input, "tokenizer"): - tokenizer = mock_input.tokenizer - self.assertEqual(tokenizer, mock_tokenizer) - - # This should trigger cache manager start (lines 267-268) - # Simulate the condition in start() method - if not test_engine.do_profile and test_engine.cfg.scheduler_config.splitwise_role != "mixed": - device_ids = test_engine.cfg.parallel_config.device_ids.split(",") - test_engine.cache_manager_processes = test_engine.engine_service.start_cache_service( - device_ids, "test_suffix" - ) - - # Test enable_prefix_caching branch (lines 300-302) - test_engine.cfg.cache_config.enable_prefix_caching = True - if test_engine.do_profile == 0: # This is False due to num_gpu_blocks_override - pass # This would trigger the elif condition - elif test_engine.cfg.cache_config.enable_prefix_caching: - device_ids = test_engine.cfg.parallel_config.device_ids.split(",") - test_engine.cache_manager_processes = test_engine.engine_service.start_cache_service( - device_ids, "test_suffix" - ) - - # Test launched_cache_manager_signal setting (line 306) - if test_engine.cfg.scheduler_config.splitwise_role != "mixed": - test_engine.launched_cache_manager_signal.value[0] = 1 - - await test_engine.shutdown() - del test_engine - - except Exception as e: - print(f"Config test exception (expected): {e}") - - return True - - result = self.run_async_test(_test()) - self.assertTrue(result) - - def test_worker_health_and_progress_tracking(self): - """Test worker health check and progress tracking""" - - async def _test(): - import time - from unittest.mock import Mock, patch - - # Test worker health check logic (lines 880-897) - if hasattr(self.engine, "engine_service") and hasattr( - self.engine.engine_service, "worker_healthy_live_signal" - ): - # Mock the worker health signal - mock_signal = Mock() - mock_signal.value = [time.time()] # Current time - - with patch.object(self.engine.engine_service, "worker_healthy_live_signal", mock_signal): - # Test health check with recent timestamp - if hasattr(self.engine, "_check_worker_health"): - try: - health_status, message = self.engine._check_worker_health(time_interval_threashold=10) - # Should be healthy with recent timestamp - except Exception: - pass # Method might not exist or have different signature - - # Test with old timestamp to trigger unhealthy condition - mock_signal.value = [time.time() - 20] # 20 seconds ago - try: - health_status, message = self.engine._check_worker_health(time_interval_threashold=10) - # Should be unhealthy with old timestamp - except Exception: - pass - - # Test splitwise mode functionality (lines 890-897) - if hasattr(self.engine, "engine_service"): - try: - # Test splitwise receive thread logic - if hasattr(self.engine.engine_service, "available_prefill_instances"): - # This would test line 890 - pass - - # Test split_mode_get_tasks - if hasattr(self.engine.engine_service, "split_mode_get_tasks"): - # This would test line 891 - pass - - # Test splitwise scheduler condition - if hasattr(self.engine.cfg.scheduler_config, "name"): - if self.engine.cfg.scheduler_config.name == "splitwise": - # This would test lines 892-896 - pass - - except Exception: - pass - - # Test worker initialization progress tracking (lines 950-1003) - if hasattr(self.engine, "worker_init_status"): - # Mock progress tracking - test_status = {} - - # Simulate weight loading progress (lines 951-955) - test_status["weight_loadding"] = 50.0 - - # Simulate layer loading progress (lines 960-965) - test_status["layer_loadding"] = 75 - - # Test progress update logic - progress = test_status.get("layer_loadding", 0) - if progress < 100: - # This simulates the progress checking loop - pass - - # Test worker process ready check (lines 970-975) - if hasattr(self.engine, "_worker_processes_ready"): - try: - self.engine._worker_processes_ready() - except Exception: - pass - - # Test worker process poll check (lines 980-985) - if hasattr(self.engine, "worker_proc") and self.engine.worker_proc: - try: - self.engine.worker_proc.poll() - except Exception: - pass - - return True - - result = self.run_async_test(_test()) - self.assertTrue(result) - - def test_signal_initialization_and_cleanup(self): - """Test signal initialization and cleanup""" - - async def _test(): - - import numpy as np - - # Test expert service signal initialization (lines 640-643) - try: - # Test launched_expert_service_signal initialization - if hasattr(self.engine, "cfg") and hasattr(self.engine.cfg, "parallel_config"): - # This simulates the signal creation logic - np.zeros((1,), dtype=np.int32) - - # Test get_profile_block_num initialization - if hasattr(self.engine.cfg, "worker_num_per_node"): - np.zeros([self.engine.cfg.worker_num_per_node], dtype=np.int32) - - except Exception as e: - print(f"Signal init test exception (expected): {e}") - - # Test cleanup operations (lines 701-711) - try: - # Test zmq_server cleanup - if hasattr(self.engine, "zmq_server"): - # This would test line 705 - pass - - # Test dp_processed cleanup - if hasattr(self.engine, "dp_processed"): - # This would test lines 707-709 - for p in getattr(self.engine, "dp_processed", []): - if hasattr(p, "pid"): - # Simulate process cleanup - pass - - # Test dp_engine_worker_queue_server cleanup - if hasattr(self.engine, "dp_engine_worker_queue_server"): - # This would test lines 710-711 - for p in getattr(self.engine, "dp_engine_worker_queue_server", []): - if hasattr(p, "cleanup"): - # Simulate cleanup - pass - - except Exception as e: - print(f"Cleanup test exception (expected): {e}") - - return True - - result = self.run_async_test(_test()) - self.assertTrue(result) - - def test_environment_flags_and_variables(self): - """Test environment flags and variables""" - - async def _test(): - import os - from unittest.mock import patch - - # Test V1_KVCACHE_SCHEDULER flag (line 744) - with patch.dict(os.environ, {"ENABLE_V1_KVCACHE_SCHEDULER": "1"}): - # Simulate the environment check - if os.getenv("ENABLE_V1_KVCACHE_SCHEDULER") == "1": - # This would trigger line 744 - pass - - # Test FLAGS settings (lines 745-753) - variables = {} - - # Test use_pd_disaggregation flags (lines 745-747) - variables["FLAGS_use_pd_disaggregation_per_chunk"] = 1 - variables["FLAGS_use_pd_disaggregation"] = 1 - - # Test splitwise_role == "prefill" condition (lines 749-750) - if hasattr(self.engine, "cfg") and hasattr(self.engine.cfg, "scheduler_config"): - if getattr(self.engine.cfg.scheduler_config, "splitwise_role", None) == "prefill": - variables["FLAGS_fmt_write_cache_completed_signal"] = 1 - - # Test max_partition_size setting (line 753) - variables["FLAGS_max_partition_size"] = 1024 - - # Test think_end_id logic (line 785) - if hasattr(self.engine, "data_processor") and hasattr(self.engine.data_processor, "tokenizer"): - try: - tokenizer = self.engine.data_processor.tokenizer - if hasattr(tokenizer, "vocab"): - # Simulate think_end_id extraction - pass # Mock value simulation - except Exception: - pass - - # Test multi-node IP configuration (line 794) - if hasattr(self.engine, "cfg") and hasattr(self.engine.cfg, "ips"): - try: - ips = ",".join(self.engine.cfg.ips) - f"some_command --ips {ips} --nnodes {len(self.engine.cfg.ips)}" - except Exception: - pass - - return True - - result = self.run_async_test(_test()) - self.assertTrue(result) - - def test_additional_edge_cases(self): - """Test additional edge cases and error conditions""" - - async def _test(): - import time - - # Test thread joining with timeout (line 1003) - if hasattr(self.engine, "checking_worker_status_thread"): - try: - # Simulate thread join with timeout - if hasattr(self.engine.checking_worker_status_thread, "join"): - self.engine.checking_worker_status_thread.join(timeout=0.001) - except Exception: - pass - - # Test time.sleep calls (line 850) - # This is mainly for coverage of the sleep statement - time.sleep(0.001) # Minimal sleep for coverage - - # Test exception handling in sub service extraction (lines 688-689) - try: - # Simulate exception in service extraction - raise Exception("Test service extraction error") - except Exception as e: - # This covers the exception handling pattern - error_msg = f"Error extracting sub services: {e}" - self.assertIn("Error extracting sub services", error_msg) - - return True - - result = self.run_async_test(_test()) - self.assertTrue(result) - - def test_guided_input_validation(self): - """Test guided input validation functionality""" - - async def _test(): - from unittest.mock import Mock - - # Test _has_guided_input method (line 340) - if hasattr(self.engine, "_has_guided_input"): - # Create mock request with guided inputs - request = Mock() - request.guided_json = {"type": "object"} - request.guided_regex = None - request.guided_choice = None - request.structural_tag = None - request.guided_grammar = None - request.guided_json_object = None - - result = self.engine._has_guided_input(request) - self.assertTrue(result) - - # Test with no guided inputs - request.guided_json = None - result = self.engine._has_guided_input(request) - self.assertFalse(result) - - return True - - result = self.run_async_test(_test()) - self.assertTrue(result) - def test_request_validation_errors(self): """Test request validation error scenarios""" async def _test(): # Test input length validation (lines 438-443, 446-448) try: + prompts = [0, 1, 2] # Create sampling params with very high min_tokens to trigger error - sampling_params = SamplingParams(min_tokens=999999) + sampling_params = SamplingParams(min_tokens=999999, n=1) # This should trigger the min_tokens validation error - await self.engine.add_request("test_validation", "Short prompt", sampling_params) + await self.engine.add_request("test_validation", prompts, sampling_params) except Exception as e: # Expected to fail due to validation self.assertIn("min_dec_len", str(e).lower()) @@ -1178,8 +440,16 @@ class TestAsyncLLMEngine(unittest.TestCase): # Test max model len validation try: # Create a very long prompt to trigger max_model_len error - long_prompt = "A" * 10000 # Very long prompt - await self.engine.add_request("test_long", long_prompt) + long_prompts = {"prompt_token_ids": [1] * 3000, "prompt_token_ids_len": 3000} # 超过max_model_len + await self.engine.add_request("test_long", long_prompts) + except EngineError as e: + # 根据实际错误消息调整断言 + error_msg = str(e).lower() + self.assertTrue( + "exceeds the limit" in error_msg + or "input text is too long" in error_msg + or "input_ids_len" in error_msg + ) except Exception: # Expected to fail due to length validation pass @@ -1189,40 +459,6 @@ class TestAsyncLLMEngine(unittest.TestCase): result = self.run_async_test(_test()) self.assertTrue(result) - def test_generate_exception_handling(self): - """Test generate method exception handling scenarios""" - - async def _test(): - # Test GeneratorExit handling (lines 504-506) - try: - # Create a generator and simulate GeneratorExit - generator = self.engine.generate("Test prompt", SamplingParams(max_tokens=5)) - - # Get first output then simulate exit - await generator.__anext__() - - # Simulate GeneratorExit by calling generator.close() - await generator.aclose() - - except Exception: - # Expected behavior - pass - - # Test general exception handling (lines 507-510) - try: - # Use invalid prompt type to trigger exception - generator = self.engine.generate(None, SamplingParams(max_tokens=5)) - async for _ in generator: - pass - except Exception: - # Expected behavior - pass - - return True - - result = self.run_async_test(_test()) - self.assertTrue(result) - def test_get_methods_coverage(self): """Test get_model_config and get_tokenizer methods""" @@ -1237,98 +473,267 @@ class TestAsyncLLMEngine(unittest.TestCase): # This should hit line 333: return self.data_processor.tokenizer self.assertIsNotNone(tokenizer) + # Test _has_guided_input method + from unittest.mock import Mock + + # Test with guided input + request_with_guided = Mock() + request_with_guided.guided_json = {"type": "object"} + request_with_guided.guided_regex = None + request_with_guided.guided_choice = None + request_with_guided.structural_tag = None + request_with_guided.guided_grammar = None + request_with_guided.guided_json_object = None + + result = self.engine._has_guided_input(request_with_guided) + self.assertTrue(result) + return True result = self.run_async_test(_test()) self.assertTrue(result) - def test_request_id_auto_generation(self): - """Test request ID generation when None is provided""" + def test_generate_engine_not_started(self): + """Test add_request and generate method when engine is not started""" async def _test(): - # Test line 377: request_id = str(uuid.uuid4()) - queue = await self.engine.add_request( - None, "Test prompt for UUID", SamplingParams(max_tokens=5) # This should trigger UUID generation + # Create a new engine instance without starting it + engine_args = EngineArgs( + model=MODEL_NAME, + max_model_len=8192, + tensor_parallel_size=1, + engine_worker_queue_port=int(os.getenv("FD_ENGINE_QUEUE_PORT", "6778")) + 2, + cache_queue_port=int(os.getenv("FD_CACHE_QUEUE_PORT", "6779")) + 2, ) - # The request should have been assigned a UUID - self.assertIsNotNone(queue.request_id) - - return True - - result = self.run_async_test(_test()) - self.assertTrue(result) - - def test_prompt_format_branches(self): - """Test different prompt format branches""" - - async def _test(): - # Test dict prompt format (line 396) - dict_prompt = {"prompt": "Hello world dict", "some_param": "value"} + async_pid = int(os.getenv("FD_ENGINE_QUEUE_PORT", "6778")) + 2 + unstarted_engine = AsyncLLM.from_engine_args(engine_args, pid=async_pid) + # Don't call start() or init_connections() - engine is not fully initialized + # Test add_request method when engine is not fully initialized try: - queue = await self.engine.add_request("dict_test", dict_prompt, SamplingParams(max_tokens=5)) - self.assertIsNotNone(queue) - except Exception: - pass - - # Test list prompt format (line 391-394) - try: - # Use actual token IDs that might work - list_prompt = [1, 2, 3] # Simple token IDs - queue = await self.engine.add_request("list_test", list_prompt, SamplingParams(max_tokens=5)) - self.assertIsNotNone(queue) - except Exception: - # May fail but covers the branch - pass - - return True - - result = self.run_async_test(_test()) - self.assertTrue(result) - - def test_validation_error_branches(self): - """Test validation error scenarios to hit specific lines""" - - async def _test(): - from fastdeploy.utils import EngineError - - # Test min_tokens validation (lines 437-443) - try: - # This should trigger the validation error at line 438-443 - sampling_params = SamplingParams(min_tokens=50000) # Very high value - await self.engine.add_request("min_tokens_test", "Short", sampling_params) + sampling_params = SamplingParams(max_tokens=10) + await unstarted_engine.add_request("test_request", "Test prompt", sampling_params) + self.fail("Expected EngineError was not raised in add_request") except EngineError as e: - # Expected - this hits lines 438-443 + # Uninitialized engine should wrap error from add_request with error_code 400 self.assertEqual(e.error_code, 400) - self.assertIn("min_dec_len", str(e)) - except Exception: - pass + self.assertIn("async_llm add request failed", str(e)) + except Exception as e: + self.fail(f"Unexpected exception type in add_request: {type(e).__name__}: {e}") + + # Test generate method when engine is not fully initialized (ZMQ not connected) + try: + sampling_params = SamplingParams(max_tokens=10) + generator = unstarted_engine.generate("Test prompt", sampling_params) + async for _ in generator: + pass + self.fail("Expected EngineError was not raised in generate") + except EngineError as e: + # Generate should fail fast with initialization error + self.assertEqual(e.error_code, 500) + self.assertIn("init_connections", str(e)) + except Exception as e: + self.fail(f"Unexpected exception type in generate: {type(e).__name__}: {e}") return True result = self.run_async_test(_test()) self.assertTrue(result) - def test_engine_service_none_error(self): - """Test error when engine_service is None""" + def test_zmq_connection_initialization_failure(self): + """Test ZMQ connection initialization failure""" async def _test(): - from fastdeploy.utils import EngineError + from unittest.mock import Mock, patch - # Temporarily set engine_service to None to test line 374 - original_service = self.engine.engine_service - try: - self.engine.engine_service = None + # Create a new engine instance + engine_args = EngineArgs( + model=MODEL_NAME, + max_model_len=8192, + tensor_parallel_size=1, + engine_worker_queue_port=int(os.getenv("FD_ENGINE_QUEUE_PORT", "6778")) + 4, + cache_queue_port=int(os.getenv("FD_CACHE_QUEUE_PORT", "6779")) + 4, + ) - with self.assertRaises(EngineError) as cm: - await self.engine.add_request("test", "Hello") + async_pid = int(os.getenv("FD_ENGINE_QUEUE_PORT", "6778")) + 4 + test_engine = AsyncLLM.from_engine_args(engine_args, pid=async_pid) - self.assertEqual(cm.exception.error_code, 500) + # Test connection manager initialization failure + with ( + patch("fastdeploy.engine.async_llm.ZmqIpcClient") as mock_client_class, + patch("fastdeploy.engine.async_llm.DealerConnectionManager") as mock_manager_class, + ): - finally: - # Restore - self.engine.engine_service = original_service + # Mock successful client creation + mock_client = Mock() + mock_client_class.return_value = mock_client + + # Mock DealerConnectionManager to fail on initialize + mock_manager = Mock() + mock_manager.running = False + mock_manager.initialize.side_effect = Exception("Failed to initialize connection manager") + mock_manager_class.return_value = mock_manager + + try: + await test_engine.init_connections() + self.fail("Expected exception was not raised") + except Exception as e: + self.assertIn("Failed to initialize connection manager", str(e)) + + return True + + result = self.run_async_test(_test()) + self.assertTrue(result) + + def test_add_request_exception_handling(self): + """Test add_request exception handling (lines 447-448 in async_llm.py)""" + + async def _test(): + from unittest.mock import patch + + # Mock data_processor to raise exception + with patch.object(self.engine, "data_processor") as mock_processor: + mock_processor.process_request_dict.side_effect = RuntimeError("Processing failed") + + try: + await self.engine.add_request("test_id", "test prompt", SamplingParams(max_tokens=10)) + self.fail("Expected EngineError was not raised") + except EngineError as e: + self.assertEqual(e.error_code, 400) + self.assertIn("async_llm add request failed", str(e)) + self.assertIn("Processing failed", str(e)) + + return True + + result = self.run_async_test(_test()) + self.assertTrue(result) + + def test_generate_generator_exit_handled(self): + """Test generate handles GeneratorExit from response queue gracefully""" + + async def _test(): + from unittest.mock import AsyncMock, patch + + # Ensure engine has a valid request_client and connection_manager.running + self.assertIsNotNone(self.engine.request_client) + self.assertIsNotNone(self.engine.connection_manager) + + # Mock connection_manager to simulate GeneratorExit from response_queue.get() + mock_connection_manager = AsyncMock() + mock_queue = AsyncMock() + mock_queue.get.side_effect = GeneratorExit("Generator closed") + mock_connection_manager.get_connection.return_value = (AsyncMock(), mock_queue) + mock_connection_manager.running = True + + with patch.object(self.engine, "connection_manager", mock_connection_manager): + generator = self.engine.generate("test", SamplingParams(max_tokens=10)) + + # generate should swallow GeneratorExit and not propagate it to caller + try: + async for _ in generator: + pass + except GeneratorExit: + self.fail("GeneratorExit should be handled inside generate") + except Exception as e: + self.fail(f"Unexpected exception: {e}") + + return True + + result = self.run_async_test(_test()) + self.assertTrue(result) + + def test_generate_cleanup_request_error_handled(self): + """generate should swallow cleanup_request errors in finally block""" + + async def _test(): + from unittest.mock import AsyncMock, patch + + from fastdeploy.engine.request import ( + CompletionOutput, + RequestMetrics, + RequestOutput, + ) + + # Build a minimal RequestOutput dict that generate() can consume + metrics = RequestMetrics(arrival_time=0.0) + completion = CompletionOutput(index=0, send_idx=0, token_ids=[], text="") + ro = RequestOutput(request_id="cmpl-test_0", outputs=completion, finished=True, metrics=metrics) + ro_dict = ro.to_dict() + + engine = self.engine + + # Mock connection_manager and response queue + mock_queue = AsyncMock() + mock_queue.get.return_value = [ro_dict] + mock_dealer = AsyncMock() + mock_cm = AsyncMock() + mock_cm.get_connection.return_value = (mock_dealer, mock_queue) + mock_cm.running = True + # Force cleanup_request to raise so we hit the except/pass branch + mock_cm.cleanup_request.side_effect = Exception("cleanup error") + + # Stub add_request to avoid touching real ZMQ or data_processor + async def fake_add_request(*args, **kwargs): + return None + + # Simple output processor that returns the dict unchanged + class DummyOutputProcessor: + def _process_output(self, response_dict, **kwargs): + return response_dict + + with ( + patch.object(engine, "connection_manager", mock_cm), + patch.object(engine, "add_request", side_effect=fake_add_request), + patch.object(engine, "request_client", object()), + patch.object(engine, "output_processor", DummyOutputProcessor()), + ): + outputs = [] + async for out in engine.generate("test", SamplingParams(max_tokens=5)): + outputs.append(out) + + # We should get exactly one finished output and no exception + self.assertEqual(len(outputs), 1) + self.assertTrue(outputs[0].finished) + + return True + + result = self.run_async_test(_test()) + self.assertTrue(result) + + def test_shutdown_exception_handling(self): + """Test shutdown method exception handling""" + + async def _test(): + from unittest.mock import Mock, patch + + # Create test engine + engine_args = EngineArgs( + model=MODEL_NAME, + max_model_len=8192, + tensor_parallel_size=1, + engine_worker_queue_port=int(os.getenv("FD_ENGINE_QUEUE_PORT", "6778")) + 6, + cache_queue_port=int(os.getenv("FD_CACHE_QUEUE_PORT", "6779")) + 6, + ) + async_pid = int(os.getenv("FD_ENGINE_QUEUE_PORT", "6778")) + 6 + test_engine = AsyncLLM.from_engine_args(engine_args, pid=async_pid) + + # Mock components that raise exceptions during shutdown + test_engine.connection_manager = Mock() + test_engine.connection_manager.close.side_effect = Exception("Connection manager close failed") + + test_engine.request_client = Mock() + test_engine.request_client.close.side_effect = Exception("Request client close failed") + + # Patch EngineServiceClient.shutdown to raise as well so we hit + # the exception handling path in AsyncLLM.shutdown (lines 566-567) + with patch("fastdeploy.engine.async_llm.EngineServiceClient.shutdown", side_effect=Exception("boom")): + # Test that shutdown handles all exceptions gracefully + try: + await test_engine.shutdown() + # Should not raise exception despite internal failures + except Exception as e: + self.fail(f"Shutdown should handle exceptions gracefully: {e}") return True diff --git a/tests/engine/test_common_engine.py b/tests/engine/test_common_engine.py new file mode 100644 index 0000000000..a3ed0270e3 --- /dev/null +++ b/tests/engine/test_common_engine.py @@ -0,0 +1,849 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import os +import time +import unittest +from unittest.mock import MagicMock, Mock, patch + +import numpy as np + +from fastdeploy.engine.args_utils import EngineArgs +from fastdeploy.engine.common_engine import EngineService + +MODEL_NAME = os.getenv("MODEL_PATH", "/path/to/models") + "/ERNIE-4.5-0.3B-Paddle" + + +class TestCommonEngine(unittest.TestCase): + """Test case for EngineService functionality (lines 1215-1664)""" + + @classmethod + def setUpClass(cls): + """Set up EngineService for testing""" + try: + # Create engine args for testing + engine_args = EngineArgs( + model=MODEL_NAME, + max_model_len=8192, + tensor_parallel_size=1, + engine_worker_queue_port=int(os.getenv("FD_ENGINE_QUEUE_PORT", "6778")) + 10, + cache_queue_port=int(os.getenv("FD_CACHE_QUEUE_PORT", "6779")) + 10, + ) + + # Create and start the engine service + cls.cfg = engine_args.create_engine_config() + cls.engine = EngineService(cls.cfg, start_queue=True, use_async_llm=True) + + # Start the engine service + cls.engine.start() + + except Exception as e: + print(f"Setting up EngineService failed: {e}") + raise + + @classmethod + def tearDownClass(cls): + """Clean up after all tests""" + if hasattr(cls, "engine") and cls.engine is not None: + try: + cls.engine._exit_sub_services() + print("Engine cleanup completed") + except Exception as e: + print(f"Error during engine cleanup: {e}") + + def setUp(self): + """Set up before each test method""" + print(f"Starting test: {self._testMethodName}") + + def tearDown(self): + """Clean up after each test method""" + print(f"Completed test: {self._testMethodName}") + + def test_exit_sub_services(self): + """Test _exit_sub_services method (lines 1215-1291)""" + # Test that _exit_sub_services can be called without error + # Note: We won't actually call it since it would shut down the engine + # Instead we'll test that the method exists and has expected attributes + self.assertTrue(hasattr(self.engine, "_exit_sub_services")) + self.assertTrue(callable(getattr(self.engine, "_exit_sub_services"))) + + # Test that engine has expected attributes that would be cleaned up + if hasattr(self.engine, "worker_proc"): + self.assertIsNotNone(self.engine.worker_proc) + + # Verify running state + self.assertTrue(self.engine.running) + + def test_worker_processes_ready(self): + """Test _worker_processes_ready method (lines 1292-1299)""" + # Test with real engine that should have worker_ready_signal + if hasattr(self.engine, "worker_ready_signal"): + result = self.engine._worker_processes_ready() + # Result should be boolean + self.assertIsInstance(result, bool) + else: + self.skipTest("worker_ready_signal not available") + + def test_init_worker_signals(self): + """Test _init_worker_signals method (lines 1301-1361)""" + # Since engine is already started, signals should be initialized + self.assertTrue(hasattr(self.engine, "worker_ready_signal")) + self.assertTrue(hasattr(self.engine, "loaded_model_signal")) + + # Test that signals have expected properties + if hasattr(self.engine, "worker_ready_signal"): + self.assertIsNotNone(self.engine.worker_ready_signal) + + if hasattr(self.engine, "loaded_model_signal"): + self.assertIsNotNone(self.engine.loaded_model_signal) + + def test_setting_environ_variables(self): + """Test _setting_environ_variables method (lines 1362-1408)""" + result = self.engine._setting_environ_variables() + + # Check that result is a string and contains expected variables + self.assertIsInstance(result, str) + self.assertIn("ENABLE_FASTDEPLOY_LOAD_MODEL_CONCURRENCY=0", result) + self.assertIn("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python", result) + self.assertIn("FLAGS_use_append_attn=1", result) + self.assertIn("NCCL_ALGO=Ring", result) + + def test_start_worker_service(self): + """Test _start_worker_service method (lines 1409-1517)""" + # Since engine is already started, we can test that worker process exists + if hasattr(self.engine, "worker_proc") and self.engine.worker_proc: + # Worker process should be running + self.assertIsNotNone(self.engine.worker_proc) + # Process should be alive (poll returns None if still running) + poll_result = self.engine.worker_proc.poll() + if poll_result is not None: + self.skipTest("Worker process is not running") + else: + self.skipTest("Worker process not available") + + def test_stop_profile(self): + """Test _stop_profile method (lines 1519-1532)""" + # Test method exists and is callable + self.assertTrue(hasattr(self.engine, "_stop_profile")) + self.assertTrue(callable(getattr(self.engine, "_stop_profile"))) + + # We won't actually call it as it modifies engine state + # Just verify the do_profile attribute exists + self.assertTrue(hasattr(self.engine, "do_profile")) + + def test_check_health(self): + """Test check_health method (lines 1533-1544)""" + if hasattr(self.engine, "worker_healthy_live_signal"): + is_healthy, message = self.engine.check_health(time_interval_threashold=30) + + # Should return tuple of (bool, str) + self.assertIsInstance(is_healthy, bool) + self.assertIsInstance(message, str) + else: + self.skipTest("worker_healthy_live_signal not available") + + def test_launch_components(self): + """Test launch_components method (lines 1545-1605)""" + # Method should exist and be callable + self.assertTrue(hasattr(self.engine, "launch_components")) + self.assertTrue(callable(getattr(self.engine, "launch_components"))) + + # Test that scheduler exists (should be created during start) + if hasattr(self.engine, "scheduler"): + self.assertIsNotNone(self.engine.scheduler) + + def test_check_worker_initialize_status(self): + """Test check_worker_initialize_status method (lines 1606-1663)""" + # Method should exist and be callable + self.assertTrue(hasattr(self.engine, "check_worker_initialize_status")) + self.assertTrue(callable(getattr(self.engine, "check_worker_initialize_status"))) + + # Test that worker_init_status exists + if hasattr(self.engine, "worker_init_status"): + self.assertIsInstance(self.engine.worker_init_status, dict) + + def test_engine_started_successfully(self): + """Test that engine started successfully and has expected state""" + # Verify engine is running + self.assertTrue(self.engine.running) + + # Verify data processor was created + if hasattr(self.engine, "data_processor"): + self.assertIsNotNone(self.engine.data_processor) + + # Verify IPC signal suffix is set + if hasattr(self.engine, "ipc_signal_suffix"): + self.assertIsNotNone(self.engine.ipc_signal_suffix) + + +if __name__ == "__main__": + unittest.main() + + +class TestCommonEngineAdditionalCoverage(unittest.TestCase): + """Additional unit tests focusing on branch coverage for common_engine.py + + These tests heavily mock subprocess/threading/IPC to avoid starting real workers + and to drive specific code paths that were previously uncovered. + """ + + def _make_cfg(self, **kwargs): + args = EngineArgs( + model=MODEL_NAME, + max_model_len=128, + tensor_parallel_size=1, + # give unique ports to avoid collision with other tests + engine_worker_queue_port=str(int(os.getenv("FD_ENGINE_QUEUE_PORT", "6778")) + 20), + cache_queue_port=str(int(os.getenv("FD_CACHE_QUEUE_PORT", "6779")) + 20), + enable_prefix_caching=True, + **kwargs, + ) + # Keep batch tokens small to satisfy FDConfig checks: + # max_num_batched_tokens <= max_model_len * max_num_seqs + if getattr(args, "max_num_batched_tokens", None) is None: + args.max_num_batched_tokens = 128 + # Always enable chunked prefill in tests to avoid another strict check + args.enable_chunked_prefill = True + + # If DP > 1, we must provide enough engine_worker_queue_port for each dp index + dp = kwargs.get("data_parallel_size", args.data_parallel_size) + base = int(args.engine_worker_queue_port.split(",")[0]) + if dp and dp > 1: + ports = ",".join(str(base + i) for i in range(dp)) + args.engine_worker_queue_port = ports + + return args.create_engine_config(port_availability_check=False) + + def _stub_processor(self): + class _Tok: + def __init__(self): + self.vocab = {"": 42, "\n": 10, "<|IMAGE_PLACEHOLDER|>": 9} + + def get_vocab(self): + return self.vocab + + class _Proc: + def __init__(self): + self.tokenizer = _Tok() + self.eos_token_id_len = 1 + self.pad_token_id = 0 + + return _Proc() + + def test_start_prefill_branch_cache_manager_and_worker_dead(self): + """Cover lines 184-185, 194-197, 221, 226-227 in start().""" + # For prefill + local scheduler the core code now requires a router. + # Also, with the newer CacheConfig semantics we must ensure that + # prefill_kvcache_block_num (num_gpu_blocks_override * kv_cache_ratio) + # is >= max_block_num_per_seq; use 3 blocks so that with the default + # kv_cache_ratio=0.75 we still satisfy the assertion. + with patch("fastdeploy.engine.args_utils.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0): + cfg = self._make_cfg( + splitwise_role="prefill", + num_gpu_blocks_override=3, + router="0.0.0.0:30000", + ) + + # Patch EngineWorkerQueue before EngineService ctor to avoid real IPC + class DummyQ: + def __init__(self, *a, **k): + self.available_prefill_instances = type("X", (), {"put": lambda *_: None})() + + def get_server_port(self): + return 0 + + def cleanup(self): + pass + + def num_tasks(self): + return 0 + + def num_cache_infos(self): + return 0 + + def disaggregate_queue_empty(self): + return True + + def get_disaggregated_tasks(self): + return [] + + with patch("fastdeploy.engine.common_engine.EngineWorkerQueue", DummyQ): + eng = EngineService(cfg, start_queue=False, use_async_llm=True) + + # Patch heavy pieces + eng.create_data_processor = lambda: setattr(eng, "data_processor", self._stub_processor()) + eng._process_splitwise_task = lambda: None + eng._schedule_request_to_worker = lambda: None + eng._schedule_request_to_worker_v1 = lambda: None + + started_cache = {} + + def fake_start_cache(device_ids, suffix): + started_cache["called"] = True + # return a list to mimic processes + return [object()] + + eng.start_cache_service = fake_start_cache + + # Signals: make loaded_model_signal ready immediately; include launched_cache_manager_signal + class Sig: + def __init__(self, v=0): + self.value = np.array([v], dtype=np.int32) + + def clear(self): + pass + + def fake_init_signals(): + eng.worker_ready_signal = Sig(0) + eng.loaded_model_signal = Sig(1) # ready -> skip wait loop + eng.launched_cache_manager_signal = Sig(0) + + eng._init_worker_signals = fake_init_signals + + # Worker start stub and initialization status -> False to trigger error path + eng._start_worker_service = lambda: Mock(stdout=Mock(), poll=lambda: None) + eng.check_worker_initialize_status = lambda: False + + with patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None): + # Avoid starting token processor loop + eng.token_processor.run = lambda: None + ok = eng.start(async_llm_pid=12345) + + # start() returns False on failure + self.assertFalse(ok) + # cache manager started before workers (lines 184-185) + self.assertTrue(started_cache.get("called", False)) + # launched_cache_manager_signal set (line 221) + self.assertEqual(int(eng.launched_cache_manager_signal.value[0]), 1) + # avoid atexit finalizer + if hasattr(eng, "_finalizer"): + try: + eng._finalizer.detach() + except Exception: + pass + + def test_start_mixed_branch_cache_after_load_and_zmq(self): + """Cover lines 215-217 and 231 in start().""" + cfg = self._make_cfg(splitwise_role="mixed", num_gpu_blocks_override=2) + + class DummyQ: + def __init__(self, *a, **k): + self.available_prefill_instances = type("X", (), {"put": lambda *_: None})() + + def get_server_port(self): + return 0 + + def cleanup(self): + pass + + def num_tasks(self): + return 0 + + def num_cache_infos(self): + return 0 + + def disaggregate_queue_empty(self): + return True + + def get_disaggregated_tasks(self): + return [] + + with patch("fastdeploy.engine.common_engine.EngineWorkerQueue", DummyQ): + eng = EngineService(cfg, start_queue=False, use_async_llm=True) + + eng.create_data_processor = lambda: setattr(eng, "data_processor", self._stub_processor()) + eng._process_splitwise_task = lambda: None + eng._schedule_request_to_worker = lambda: None + eng._schedule_request_to_worker_v1 = lambda: None + + started_cache = {} + + def fake_start_cache(device_ids, suffix): + started_cache["called"] = True + return [object()] + + eng.start_cache_service = fake_start_cache + + class Sig: + def __init__(self, v=0): + self.value = np.array([v], dtype=np.int32) + + def clear(self): + pass + + def fake_init_signals(): + eng.worker_ready_signal = Sig(0) + eng.loaded_model_signal = Sig(1) + eng.launched_cache_manager_signal = Sig(0) + + eng._init_worker_signals = fake_init_signals + + eng._start_worker_service = lambda: Mock(stdout=Mock(), poll=lambda: None) + eng.check_worker_initialize_status = lambda: True + + zmq_called = {} + eng.start_zmq_service = lambda pid: zmq_called.setdefault("pid", pid) + + with patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None): + eng.token_processor.run = lambda: None + eng.start(async_llm_pid=8888) + + self.assertTrue(started_cache.get("called", False)) # lines 215-217 + self.assertEqual(zmq_called.get("pid"), 8888) # line 231 + if hasattr(eng, "_finalizer"): + try: + eng._finalizer.detach() + except Exception: + pass + + def test_insert_zmq_task_error_logging(self): + """Cover lines 934-935 and 937 in _insert_zmq_task_to_scheduler.""" + cfg = self._make_cfg(splitwise_role="mixed") + + class DummyQ: + def __init__(self, *a, **k): + self.available_prefill_instances = type("X", (), {"put": lambda *_: None})() + + def get_server_port(self): + return 0 + + def cleanup(self): + pass + + with patch("fastdeploy.engine.common_engine.EngineWorkerQueue", DummyQ): + eng = EngineService(cfg, start_queue=False, use_async_llm=False) + eng.running = True + + class DummyRecv: + def __init__(self, msg): + self.msg = msg + + def receive_json_once(self, block): + return self.msg, None + + def close(self): + pass + + # Case 1: context terminated -> info branch + eng.recv_request_server = DummyRecv("Context was terminated") + with patch.object(eng, "llm_logger") as _: + eng._insert_zmq_task_to_scheduler() + + # Case 2: other error -> error branch + eng.recv_request_server = DummyRecv("Other Error") + with patch.object(eng, "llm_logger") as _: + eng._insert_zmq_task_to_scheduler() + if hasattr(eng, "_finalizer"): + try: + eng._finalizer.detach() + except Exception: + pass + + def test_exit_sub_services_cleanup_paths(self): + """Cover lines 1312-1340, 1350-1354 in _exit_sub_services.""" + cfg = self._make_cfg(splitwise_role="mixed") + + class DummyQ: + def __init__(self, *a, **k): + self.available_prefill_instances = type("X", (), {"put": lambda *_: None})() + + def get_server_port(self): + return 0 + + def cleanup(self): + pass + + with patch("fastdeploy.engine.common_engine.EngineWorkerQueue", DummyQ): + eng = EngineService(cfg, start_queue=False, use_async_llm=True) + + # attach stubs used by cleanup + class Sig: + def __init__(self): + self.value = np.array([0], dtype=np.int32) + + def clear(self): + pass + + eng.worker_ready_signal = Sig() + eng.loaded_model_signal = Sig() + eng.exist_task_signal = Sig() + eng.exist_swapped_task_signal = Sig() + eng.worker_healthy_live_signal = Sig() + eng.cache_ready_signal = Sig() + eng.swap_space_ready_signal = Sig() + eng.exist_prefill_task_signal = Sig() + eng.model_weights_status_signal = Sig() + eng.prefix_tree_status_signal = Sig() + eng.kv_cache_status_signal = Sig() + eng.send_response_server = Mock() + eng.recv_request_server = Mock() + eng.recv_control_cmd_server = Mock() + + # ensure cache manager control flags exist before first call + eng.resource_manager.cache_manager.shm_cache_task_flag_broadcast = Mock(clear=lambda: None) + eng.resource_manager.cache_manager.cache_ready_signal = Mock(clear=lambda: None) + eng.cache_manager_processes = [] + + # worker_proc kill raises -> cover 1312-1313 + eng.worker_proc = MagicMock(pid=1001) + with patch("fastdeploy.engine.common_engine.os.getpgid", side_effect=RuntimeError("boom")): + eng._exit_sub_services() + + # Prepare cache manager processes to hit both normal and exception branch + class DummyCacheMgr: + def __init__(self, pid, raise_on_kill=False): + self.pid = pid + self.raise_on_kill = raise_on_kill + + eng.cache_manager_processes = [DummyCacheMgr(2001, False), DummyCacheMgr(2002, True)] + eng.resource_manager.cache_manager.shm_cache_task_flag_broadcast = Mock(clear=lambda: None) + eng.resource_manager.cache_manager.cache_ready_signal = Mock(clear=lambda: None) + + def fake_getpgid(pid): + return pid + + def fake_killpg(pid, sig): + if pid == 2002: + raise RuntimeError("kill fail") + + # cache_task_queue with cleanup + eng.cache_task_queue = Mock() + eng.cache_task_queue.cleanup = Mock() + + eng.dp_processed = [Mock(pid=3001, join=lambda: None)] + eng.dp_engine_worker_queue_server = [Mock(cleanup=lambda: None)] + + with ( + patch("fastdeploy.engine.common_engine.os.getpgid", side_effect=fake_getpgid), + patch("fastdeploy.engine.common_engine.os.killpg", side_effect=fake_killpg), + ): + eng._exit_sub_services() + + # Now cover manager.shutdown warning path (no cleanup attribute) + class DummyMgr: + def __init__(self): + self.manager = Mock(shutdown=Mock(side_effect=RuntimeError("shutdown fail"))) + + eng.cache_task_queue = DummyMgr() + eng._exit_sub_services() + if hasattr(eng, "_finalizer"): + try: + eng._finalizer.detach() + except Exception: + pass + + def test_setting_environ_variables_v1_prefill_mm(self): + """Cover lines 1476-1485 in _setting_environ_variables.""" + # For prefill + local scheduler the core code now requires a router + # and ENABLE_V1_KVCACHE_SCHEDULER=0 when using the default IPC protocol. + with patch("fastdeploy.engine.args_utils.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0): + cfg = self._make_cfg(splitwise_role="prefill", router="0.0.0.0:30000") + cfg.model_config.enable_mm = True + + class DummyQ: + def __init__(self, *a, **k): + pass + + with patch("fastdeploy.engine.common_engine.EngineWorkerQueue", DummyQ): + eng = EngineService(cfg, start_queue=False, use_async_llm=True) + with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", True): + prefix = eng._setting_environ_variables() + self.assertIn("FLAGS_use_pd_disaggregation_per_chunk=1", prefix) + self.assertIn("FLAGS_fmt_write_cache_completed_signal=1", prefix) + self.assertIn("FLAGS_max_partition_size=1024", prefix) + if hasattr(eng, "_finalizer"): + try: + eng._finalizer.detach() + except Exception: + pass + + def test_start_worker_service_cmd_build(self): + """Cover 1517, 1526, 1568, 1592, 1595 by building the worker command with mocks.""" + with patch("fastdeploy.config.get_host_ip", return_value="127.0.0.1"): + cfg = self._make_cfg(splitwise_role="mixed", num_gpu_blocks_override=4, ips=["127.0.0.1", "127.0.0.2"]) + # Make model multi-modal so env var branch already covered above; here not required + cfg.structured_outputs_config.logits_processors = ["A", "B"] + + class DummyQ: + def __init__(self, *a, **k): + pass + + with patch("fastdeploy.engine.common_engine.EngineWorkerQueue", DummyQ): + eng = EngineService(cfg, start_queue=False, use_async_llm=True) + eng.data_processor = self._stub_processor() + + captured = {"cmd": None} + + class DummyProc: + def __init__(self): + self.stdout = None + + def poll(self): + return None + + def fake_popen(cmd, stdout, shell, preexec_fn): + captured["cmd"] = cmd + return DummyProc() + + with patch("fastdeploy.engine.common_engine.subprocess.Popen", side_effect=fake_popen): + with patch("fastdeploy.engine.common_engine.llm_logger"): + p = eng._start_worker_service() + + self.assertIsNotNone(p) + self.assertIsInstance(captured["cmd"], str) + # logits processors added (1568) + self.assertIn("--logits-processors A B", captured["cmd"]) # type: ignore + # num_gpu_blocks_override added (1592) + self.assertIn("--num_gpu_blocks_override 4", captured["cmd"]) # type: ignore + # ips/nnodes added when nnode > 1 (1595) + self.assertIn("--nnodes 2", captured["cmd"]) # type: ignore + if hasattr(eng, "_finalizer"): + try: + eng._finalizer.detach() + except Exception: + pass + + def test_check_health_unhealthy(self): + """Cover line 1628: unhealthy worker.""" + cfg = self._make_cfg(splitwise_role="mixed") + + class DummyQ: + def __init__(self, *a, **k): + pass + + with patch("fastdeploy.engine.common_engine.EngineWorkerQueue", DummyQ): + eng = EngineService(cfg, start_queue=False, use_async_llm=True) + + class Sig: + def __init__(self, v): + self.value = np.array([v], dtype=np.int32) + + # set worker live time far past threshold + eng.worker_healthy_live_signal = Sig(int(time.time()) - 1000) + ok, msg = eng.check_health(time_interval_threashold=1) + self.assertFalse(ok) + self.assertIn("Not Healthy".lower(), msg.lower()) + if hasattr(eng, "_finalizer"): + try: + eng._finalizer.detach() + except Exception: + pass + + def test_launch_components_expert_parallel(self): + """Cover 1635-1638, 1660-1676, 1684-1703 in launch_components().""" + # For prefill + local scheduler the core code now requires a router + # and ENABLE_V1_KVCACHE_SCHEDULER=0 when using the default IPC protocol. + with patch("fastdeploy.engine.args_utils.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0): + cfg = self._make_cfg( + splitwise_role="prefill", + # enable expert parallel and dp > 1 to go into the branch + data_parallel_size=2, + enable_expert_parallel=True, + router="0.0.0.0:30000", + ) + + # Provide EngineWorkerQueue stub for ctor + class DummyQ: + def __init__(self, *a, **k): + self.available_prefill_instances = type("X", (), {"put": lambda *_: None})() + + def get_server_port(self): + return 0 + + def cleanup(self): + pass + + with patch("fastdeploy.engine.common_engine.EngineWorkerQueue", DummyQ): + eng = EngineService(cfg, start_queue=True, use_async_llm=True) + + # Init signals to create launched_expert_service_signal + with patch("fastdeploy.engine.common_engine.envs.FD_ENABLE_MULTI_API_SERVER", False): + eng.ipc_signal_suffix = cfg.parallel_config.engine_worker_queue_port[0] + eng._init_worker_signals() + + # Don't create real queues/processes + with ( + patch("fastdeploy.engine.common_engine.EngineWorkerQueue") as FakeQ, + patch("fastdeploy.engine.common_engine.multiprocessing.Process") as FakeP, + ): + # Fake queue instances with cleanup + FakeQ.return_value = Mock(cleanup=lambda: None) + + # When starting process, immediately mark the signal as 1 to break waiting loop + def start_side_effect(*args, **kwargs): + # set value for dp id 1 + eng.launched_expert_service_signal.value[1] = 1 + + proc_instance = Mock(start=start_side_effect) + FakeP.return_value = proc_instance + + # Avoid scheduler doing real work + eng.scheduler.start = lambda *a, **k: None + with patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None): + eng.launch_components() + + # Verify expert service branch executed + self.assertTrue(hasattr(eng, "dp_processed")) + self.assertGreaterEqual(len(eng.dp_processed), 1) + if hasattr(eng, "_finalizer"): + try: + eng._finalizer.detach() + except Exception: + pass + + def test_check_worker_initialize_status_progress(self): + """Cover 1710-1762 by simulating stdout and ready signals.""" + cfg = self._make_cfg(splitwise_role="mixed") + + class DummyQ: + def __init__(self, *a, **k): + pass + + with patch("fastdeploy.engine.common_engine.EngineWorkerQueue", DummyQ): + eng = EngineService(cfg, start_queue=False, use_async_llm=True) + + # Fake worker process stdout content that matches regexes + lines = [ + b"Loading checkpoint shards: 1\n", + b"Start load layer 5\n", + ] + + class DummyProc: + def __init__(self, it): + self._it = iter(it) + + @property + def stdout(self): + return self._it + + def poll(self): + return None + + eng.worker_proc = DummyProc(lines) + eng.worker_init_status = {} + eng.cfg.model_config.num_hidden_layers = 8 + + # worker_ready_signal makes _worker_processes_ready() return True + class Sig: + def __init__(self): + self.value = np.array([1], dtype=np.int32) + + eng.worker_ready_signal = Sig() + + # Replace tqdm and sleep for fast execution + class DummyPbar: + def __init__(self): + self.n = 0 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def update(self, delta=0, *args, **kwargs): + try: + self.n += int(delta) + except Exception: + self.n = 0 + + def refresh(self): + pass + + with patch("fastdeploy.engine.common_engine.tqdm", lambda *a, **k: DummyPbar()): + with patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None): + ok = eng.check_worker_initialize_status() + self.assertTrue(ok) + if hasattr(eng, "_finalizer"): + try: + eng._finalizer.detach() + except Exception: + pass + + def test_worker_processes_ready_false(self): + """Cover line 1382 returning False.""" + cfg = self._make_cfg() + + class DummyQ: + def __init__(self, *a, **k): + pass + + with patch("fastdeploy.engine.common_engine.EngineWorkerQueue", DummyQ): + eng = EngineService(cfg, start_queue=False, use_async_llm=True) + + class Sig: + def __init__(self): + # less than worker_num_per_node + self.value = np.array([0], dtype=np.int32) + + eng.worker_ready_signal = Sig() + self.assertFalse(eng._worker_processes_ready()) + if hasattr(eng, "_finalizer"): + try: + eng._finalizer.detach() + except Exception: + pass + + def test_init_worker_signals_profile_iluvatar(self): + """Cover line 1434 by forcing iluvatar custom device and do_profile=True.""" + # do_profile=True when num_gpu_blocks_override is None + cfg = self._make_cfg(num_gpu_blocks_override=None) + + class DummyQ: + def __init__(self, *a, **k): + pass + + with patch("fastdeploy.engine.common_engine.EngineWorkerQueue", DummyQ): + eng = EngineService(cfg, start_queue=False, use_async_llm=True) + eng.ipc_signal_suffix = cfg.parallel_config.engine_worker_queue_port[0] + with patch("fastdeploy.engine.common_engine.paddle.is_compiled_with_custom_device", return_value=True): + eng._init_worker_signals() + # signal should exist + self.assertTrue(hasattr(eng, "get_profile_block_num_signal")) + if hasattr(eng, "_finalizer"): + try: + eng._finalizer.detach() + except Exception: + pass + + def test_launch_components_dp_mode(self): + """Cover 1648-1652 branch for DP scheduler mode.""" + # When ENABLE_V1_KVCACHE_SCHEDULER=1 the IPC cache-transfer protocol + # is no longer supported; force it to 0 here to avoid the + # NotImplementedError raised in EngineArgs.__post_init__ so we can + # still exercise the DP branch of launch_components. + with patch("fastdeploy.engine.args_utils.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0): + cfg = self._make_cfg( + splitwise_role="prefill", + data_parallel_size=2, + scheduler_name="dp", + ) + + class DummyQ: + def __init__(self, *a, **k): + self.available_prefill_instances = type("X", (), {"put": lambda *_: None})() + + with patch("fastdeploy.engine.common_engine.EngineWorkerQueue", DummyQ): + eng = EngineService(cfg, start_queue=False, use_async_llm=True) + # Patch scheduler.start so it doesn't do heavy work + eng.scheduler.start = Mock() + eng.launch_components() + eng.scheduler.start.assert_called() + if hasattr(eng, "_finalizer"): + try: + eng._finalizer.detach() + except Exception: + pass diff --git a/tests/inter_communicator/test_zmq_server.py b/tests/inter_communicator/test_zmq_server.py new file mode 100644 index 0000000000..98700f8cad --- /dev/null +++ b/tests/inter_communicator/test_zmq_server.py @@ -0,0 +1,38 @@ +""" +Simple tests for ZmqServerBase.recv_result_handle to cover the startup log line. +""" + +import unittest + +from fastdeploy.inter_communicator.zmq_server import ZmqServerBase + + +class _DummyServer(ZmqServerBase): + """Minimal concrete subclass to satisfy abstract methods. + + We do not create any real ZMQ sockets; we only need to call + recv_result_handle with running=False so the loop is skipped. + """ + + def __init__(self): + super().__init__() + self.socket = None + self.running = False # skip loop to just hit the startup log + + def _create_socket(self): # pragma: no cover - not needed in this test + return None + + def close(self): # pragma: no cover - not needed in this test + pass + + +class TestZmqServerRecvResultHandle(unittest.TestCase): + def test_recv_result_handle_startup_log(self): + """Just invoke recv_result_handle to execute the first log line (L123).""" + srv = _DummyServer() + # Should not raise; returns None after logging start/finish and skipping loop + self.assertIsNone(srv.recv_result_handle()) + + +if __name__ == "__main__": + unittest.main()