""" # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License" # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ from __future__ import annotations import asyncio import collections import copy import json import multiprocessing import os import re import signal import subprocess import sys import threading import time import traceback import weakref from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Dict, List, Optional, Tuple import numpy as np import paddle import zmq from tqdm import tqdm import fastdeploy.metrics.trace as tracing from fastdeploy.cache_manager.cache_data import CacheStatus from fastdeploy.config import FDConfig from fastdeploy.engine.register_manager import RegisterManager from fastdeploy.engine.request import ( CompletionOutput, ControlRequest, ControlResponse, Request, RequestMetrics, RequestOutput, RequestStatus, RequestType, ) from fastdeploy.engine.resource_manager import ResourceManager from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1 from fastdeploy.engine.sched.scheduler_metrics_logger import SchedulerMetricsLogger from fastdeploy.eplb.utils import init_eplb_signals from fastdeploy.input.preprocess import InputPreprocessor from fastdeploy.inter_communicator import ( EngineCacheQueue, EngineWorkerQueue, IPCLock, IPCSignal, ZmqIpcServer, ZmqTcpServer, ) from fastdeploy.inter_communicator.fmq import FMQ from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.model_executor.guided_decoding import schema_checker from fastdeploy.plugins.token_processor import load_token_processor_plugins from fastdeploy.spec_decode import SpecMethod from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector from fastdeploy.trace.constants import LoggingEventName from fastdeploy.trace.trace_logger import print as trace_print from fastdeploy.utils import EngineError, console_logger, envs, get_logger, llm_logger try: TokenProcessor = load_token_processor_plugins() llm_logger.info(f"TokenProcessor plugin {TokenProcessor} loaded") except: from fastdeploy.output.token_processor import TokenProcessor def _read_latest_worker_traceback(log_dir: str) -> Optional[str]: """读取 workerlog.* 文件中的最新 traceback。""" try: candidates = sorted(Path(log_dir).glob("workerlog.*"), key=lambda path: path.stat().st_mtime, reverse=True) except OSError: return None for path in candidates: try: content = path.read_text(encoding="utf-8", errors="ignore") except OSError: continue marker = "Traceback (most recent call last):" start = content.rfind(marker) if start != -1: return content[start:].strip() return None def _format_worker_launch_failure_message(log_dir: str) -> str: """格式化 worker 启动失败的错误消息,包含 traceback 信息。""" message = "Failed to launch worker processes, check log/workerlog.* for more details." 
traceback_text = _read_latest_worker_traceback(log_dir) if traceback_text: return f"{message}\n{traceback_text}" return message class EngineService: """ Base class containing common engine functionality """ def __init__(self, cfg: FDConfig, start_queue=True, use_async_llm=False): """ Initializes the LLMEngine with the provided configuration. Args: cfg (Config): Config object containing all the configuration parameters. """ self.cfg = cfg self.use_async_llm = use_async_llm if self.cfg.parallel_config.data_parallel_size > 1: self.llm_logger = get_logger( "fastdeploy", f"fastdeploy_dprank{self.cfg.parallel_config.local_data_parallel_id}.log" ) else: self.llm_logger = llm_logger self.is_paused = False # pause request generation self._pause_cond = threading.Condition() self._ctrl_output_queues = {} self._ctrl_response_mailboxes = collections.defaultdict(collections.OrderedDict) tp_size = cfg.parallel_config.tensor_parallel_size dp_index = cfg.parallel_config.local_data_parallel_id for tp_rank in range(tp_size): # create worker control response queue engine_worker_queue_port = self.cfg.parallel_config.local_engine_worker_queue_port name = f"ctrl_w2e_rank{tp_rank+tp_size*dp_index}_{engine_worker_queue_port}" self.llm_logger.info(f"Init Worker Control Output Queue: {name} (consumer)") self._ctrl_output_queues[name] = FMQ().queue(name, "consumer") # create cache control response queue if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend: engine_cache_queue_port = self.cfg.cache_config.local_cache_queue_port name = f"ctrl_c2e_rank{tp_rank+tp_size*dp_index}_{engine_cache_queue_port}" self.llm_logger.info(f"Init Cache Control Output Queue: {name} (consumer)") self._ctrl_output_queues[name] = FMQ().queue(name, "consumer") self.scheduler = cfg.scheduler_config.scheduler() self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1" if envs.ENABLE_V1_KVCACHE_SCHEDULER: self.llm_logger.info("Use V1 KVCache Scheduler") self.resource_manager = ResourceManagerV1( cfg.scheduler_config.max_num_seqs, cfg, cfg.parallel_config.tensor_parallel_size, cfg.scheduler_config.splitwise_role, cfg.parallel_config.local_data_parallel_id, ) else: self.llm_logger.info("Use V0 KVCache Scheduler") self.resource_manager = ResourceManager( cfg.scheduler_config.max_num_seqs, cfg, cfg.parallel_config.tensor_parallel_size, cfg.scheduler_config.splitwise_role, cfg.parallel_config.local_data_parallel_id, ) self.start_worker_queue_service(start_queue) os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.cfg.parallel_config.local_engine_worker_queue_port) self.llm_logger.info(f"INFERENCE_MSG_QUEUE_ID: {str(self.cfg.parallel_config.local_engine_worker_queue_port)}") self.split_connector = SplitwiseConnector(cfg, self.engine_worker_queue, self.resource_manager) self.token_processor = TokenProcessor( cfg=cfg, cached_generated_tokens=self.scheduler, engine_worker_queue=self.engine_worker_queue, split_connector=self.split_connector, ) self.token_processor.set_resource_manager(self.resource_manager) self.scheduler_metrics_logger = SchedulerMetricsLogger( enabled=True, dp_rank=self.cfg.parallel_config.local_data_parallel_id, ) self.resource_manager.scheduler_metrics_logger = self.scheduler_metrics_logger self.token_processor.set_scheduler_metrics_logger(self.scheduler_metrics_logger) self.partial_chunked_tokens = [0] * (self.cfg.max_num_partial_prefills + 1) for idx in range(1, self.cfg.max_num_partial_prefills + 1): self.partial_chunked_tokens[idx] = ( (self.cfg.scheduler_config.max_num_batched_tokens 
// idx) // self.cfg.cache_config.block_size * self.cfg.cache_config.block_size ) self.bos_client = None self.mm_max_tokens_per_item = None self.guided_decoding_checker = None if self.cfg.structured_outputs_config.guided_decoding_backend != "off": self.guided_decoding_checker = schema_checker( self.cfg.structured_outputs_config.guided_decoding_backend, disable_any_whitespace=self.cfg.structured_outputs_config.disable_any_whitespace, ) self._init_worker_monitor_signals() # Pass the GPU KV cache lock to cache_manager for mutual exclusion # between the CPU transfer process and the worker process. self.resource_manager.cache_manager.gpu_cache_lock = self.gpu_cache_lock # Initialize RegisterManager self._register_manager = RegisterManager( cfg=self.cfg, engine_worker_queue=self.engine_worker_queue, get_is_paused=self._get_is_paused_safe, ) if self.cfg.eplb_config.enable_eplb: current_suffix = self.cfg.parallel_config.local_engine_worker_queue_port init_eplb_signals(cfg, current_suffix) if self.use_async_llm: # Add worker management attributes self.worker_proc = None self.do_profile = 1 if self.cfg.cache_config.num_gpu_blocks_override is None else 0 self.ipc_signal_suffix = None self.cache_manager_processes = None self._finalizer = weakref.finalize(self, self._exit_sub_services) def start(self, async_llm_pid=None): self.running = True console_logger.debug("Start engineService...") if self.use_async_llm: self.start_worker_service(async_llm_pid) if envs.ENABLE_V1_KVCACHE_SCHEDULER: self.insert_task_to_worker_thread = threading.Thread( target=self._schedule_request_to_worker_v1, daemon=True ) else: self.insert_task_to_worker_thread = threading.Thread(target=self._schedule_request_to_worker, daemon=True) self.insert_task_to_worker_thread.start() self.token_processor.tasks_queue = self.engine_worker_queue self.token_processor.run() if self.cfg.scheduler_config.splitwise_role == "decode": self._decode_process_splitwise_requests() self._register_manager.start() def start_worker_service(self, async_llm_pid=None): # Initialize IPC signals for worker management self.ipc_signal_suffix = self.cfg.parallel_config.engine_worker_queue_port[0] self._init_worker_signals() # Create data processor if not exists if not hasattr(self, "data_processor"): self.create_data_processor() # Launch components: scheduler, cache_manager, expert_service et.al. 
        self.launch_components()

        # If block number is specified and model is deployed in splitwise mode, start cache manager first
        if not self.do_profile and self.cfg.scheduler_config.splitwise_role != "mixed":
            device_ids = self.cfg.parallel_config.device_ids.split(",")
            self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix)

        # Start worker processes
        self.worker_proc = self._start_worker_service()
        time.sleep(5)
        self.worker_init_status = dict()
        result_container = {}

        def check_worker_initialize_status_func(res: dict):
            res["worker_is_alive"] = True
            if not self.check_worker_initialize_status():
                self.llm_logger.error(_format_worker_launch_failure_message(envs.FD_LOG_DIR))
                res["worker_is_alive"] = False

        self.check_worker_initialize_status_func_thread = threading.Thread(
            target=check_worker_initialize_status_func, args=(result_container,), daemon=True
        )
        self.check_worker_initialize_status_func_thread.start()

        # Wait model loading
        while self.loaded_model_signal.value[0] == 0:
            # Make sure worker process is alive
            if not self.check_worker_initialize_status_func_thread.is_alive():
                return False
            time.sleep(1)

        # If block number is not specified, let workers do profiling to determine the block number,
        # and then start the cache manager
        if self.do_profile:
            self._stop_profile()
        elif self.cfg.scheduler_config.splitwise_role == "mixed" and self.cfg.cache_config.enable_prefix_caching:
            device_ids = self.cfg.parallel_config.device_ids.split(",")
            self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix)

        # Worker launched
        self.check_worker_initialize_status_func_thread.join()
        if not result_container["worker_is_alive"]:
            self.llm_logger.error(_format_worker_launch_failure_message(envs.FD_LOG_DIR))
            return False

        # Start ZMQ service for communication with AsyncLLM
        if async_llm_pid:
            self.start_zmq_service(async_llm_pid)

    def create_data_processor(self):
        self.input_processor = InputPreprocessor(
            self.cfg.model_config,
            self.cfg.structured_outputs_config.reasoning_parser,
            self.cfg.limit_mm_per_prompt,
            self.cfg.mm_processor_kwargs,
            self.cfg.tool_parser,
        )
        self.data_processor = self.input_processor.create_processor()
        self.mm_max_tokens_per_item = self.data_processor.get_mm_max_tokens_per_item(
            self.cfg.model_config.max_model_len
        )
        if self.mm_max_tokens_per_item is not None:
            max_chunk_tokens = self.cfg.get_max_chunk_tokens(self.mm_max_tokens_per_item)
            self.cfg.cache_config.postprocess(max_chunk_tokens, self.cfg.scheduler_config.max_num_seqs)

    def _init_worker_monitor_signals(self):
        # exist_task_signal lets each worker process detect whether there is a new task to handle
        current_suffix = self.cfg.parallel_config.local_engine_worker_queue_port
        self.llm_logger.info(f"current_suffix: {current_suffix}")
        exist_task_signal_data = np.zeros([1], dtype=np.int32)
        self.exist_task_signal = IPCSignal(
            name="exist_task_signal",
            array=exist_task_signal_data,
            dtype=np.int32,
            suffix=current_suffix,
            create=True,
        )

        # exist_swapped_task_signal lets the engine detect whether the worker holds swapped tasks
        exist_swapped_task_signal_data = np.zeros([1], dtype=np.int32)
        self.exist_swapped_task_signal = IPCSignal(
            name="exist_swapped_task_signal",
            array=exist_swapped_task_signal_data,
            dtype=np.int32,
            suffix=current_suffix,
            create=True,
        )

        # exist_prefill_task_signal lets each worker process detect whether a prefill pass should run
        exist_prefill_task_signal_data = np.zeros([1], dtype=np.int32)
        self.exist_prefill_task_signal = IPCSignal(
            name="exist_prefill_task_signal",
            array=exist_prefill_task_signal_data,
            dtype=np.int32,
            suffix=current_suffix,
            create=True,
        )

        engine_forward_signal_data = np.zeros([1], dtype=np.int32)
        self.engine_forward_signal = IPCSignal(
            name="engine_forward_signal",
            array=engine_forward_signal_data,
            dtype=np.int32,
            suffix=current_suffix,
            create=True,
        )

        # worker_healthy_live_signal lets the engine detect whether each worker process is alive;
        # it records the timestamp of every step
        worker_healthy_live_recorded_time_array = np.zeros(
            shape=[min(self.cfg.worker_num_per_node, self.cfg.parallel_config.tensor_parallel_size)], dtype=np.int32
        )
        self.worker_healthy_live_signal = IPCSignal(
            name="worker_healthy_live_signal",
            array=worker_healthy_live_recorded_time_array,
            dtype=np.int32,
            suffix=current_suffix,
            create=True,
        )

        cache_ready_signal_data = np.zeros(shape=[self.cfg.parallel_config.tensor_parallel_size], dtype=np.int32)
        self.cache_ready_signal = IPCSignal(
            name="cache_ready_signal",
            array=cache_ready_signal_data,
            dtype=np.int32,
            suffix=current_suffix,
            create=True,
        )

        swap_space_ready_signal_data = np.zeros(shape=[self.cfg.parallel_config.tensor_parallel_size], dtype=np.int32)
        self.swap_space_ready_signal = IPCSignal(
            name="swap_space_ready_signal",
            array=swap_space_ready_signal_data,
            dtype=np.int32,
            suffix=current_suffix,
            create=True,
        )

        cache_transfer_inited_signal_data = np.zeros(
            shape=[self.cfg.parallel_config.tensor_parallel_size], dtype=np.int32
        )
        self.cache_transfer_inited_signal = IPCSignal(
            name="cache_transfer_inited_signal",
            array=cache_transfer_inited_signal_data,
            dtype=np.int32,
            suffix=current_suffix,
            create=True,
        )

        model_weights_status = np.zeros([1], dtype=np.int32)
        self.model_weights_status_signal = IPCSignal(
            name="model_weights_status",
            array=model_weights_status,
            dtype=np.int32,
            suffix=current_suffix,
            create=True,
        )

        prefix_tree_status = np.zeros([1], dtype=np.int32)
        self.prefix_tree_status_signal = IPCSignal(
            name="prefix_tree_status",
            array=prefix_tree_status,
            dtype=np.int32,
            suffix=current_suffix,
            create=True,
        )

        kv_cache_status = np.zeros([1], dtype=np.int32)
        self.kv_cache_status_signal = IPCSignal(
            name="kv_cache_status",
            array=kv_cache_status,
            dtype=np.int32,
            suffix=current_suffix,
            create=True,
        )

        # gpu_cache_lock: file-based lock for mutual exclusion between worker
        # and CPU transfer when accessing GPU KV cache.
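        # The cache manager hands this lock to the CPU transfer process (see __init__),
        # so both sides serialize their access to the GPU KV cache through the same
        # name/suffix pair.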
self.gpu_cache_lock = IPCLock( name="gpu_cache_lock", suffix=current_suffix, create=True, ) def start_worker_queue_service(self, start_queue): """ start queue service for engine worker communication """ if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM: address = (self.cfg.master_ip, self.cfg.parallel_config.local_engine_worker_queue_port) else: address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.local_engine_worker_queue_port}.sock" if self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0": if start_queue: self.llm_logger.info(f"Starting engine worker queue server service at {address}") self.engine_worker_queue_server = EngineWorkerQueue( address=address, is_server=True, num_client=self.cfg.parallel_config.tensor_parallel_size, local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, ) # Dynamically updates the port value if an anonymous port is used if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM: self.cfg.parallel_config.local_engine_worker_queue_port = ( self.engine_worker_queue_server.get_server_port() ) address = ( self.cfg.master_ip, self.cfg.parallel_config.local_engine_worker_queue_port, ) if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed": self.llm_logger.info( f"Starting engine cache queue server service at {self.cfg.cache_config.local_cache_queue_port}" ) self.cache_task_queue = EngineCacheQueue( address=(self.cfg.master_ip, self.cfg.cache_config.local_cache_queue_port), authkey=b"cache_queue_service", is_server=True, num_client=self.cfg.parallel_config.tensor_parallel_size, client_id=-1, local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, ) self.cfg.cache_config.local_cache_queue_port = self.cache_task_queue.get_server_port() self.engine_worker_queue = EngineWorkerQueue( address=address, is_server=False, num_client=self.cfg.parallel_config.tensor_parallel_size, client_id=0, local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, local_data_parallel_id=self.cfg.parallel_config.local_data_parallel_id, ) def insert_tasks(self, tasks: List[Request], current_id=-1): """ Allocate resource and insert tasks to engine. Used in v0_kvcache_scheduler. 
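
        Args:
            tasks: A single Request or a list of Request objects to schedule.
            current_id: Insertion id forwarded to the splitwise cache messager
                when running in the "prefill" role; defaults to -1.

        Returns:
            bool: True once the tasks are handed to the engine worker queue.

        Raises:
            EngineError: If resources cannot be allocated for the given tasks.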
""" if not isinstance(tasks, list): tasks = [tasks] self.resource_manager.check_and_free_block_tables() need_delete_tasks = [] for task in tasks: rid = task.request_id.split("_")[0] trace_carrier = task.trace_carrier if trace_carrier: tracing.trace_set_proc_propagate_context(rid, trace_carrier) task.trace_carrier = tracing.trace_get_proc_propagate_context(rid) if self.cfg.scheduler_config.splitwise_role == "prefill": status, msg = self.split_connector.check_decode_allocated(task) if status: task.metrics.ask_decode_resource_finish_time = time.time() else: self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.") self.scheduler.put_results( [ RequestOutput( request_id=task.request_id, finished=True, error_code=500, error_msg=msg, ) ] ) need_delete_tasks.append(task) continue for tmp_task in need_delete_tasks: tasks.remove(tmp_task) for item in tasks: trace_print(LoggingEventName.RESOURCE_ALLOCATE_START, item.request_id, getattr(item, "user", "")) available_batch = np.sum(self.resource_manager.stop_flags) if len(tasks) > available_batch: self.llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.") self.llm_logger.error("The exceeded part will be ignored!") tasks = tasks[:available_batch] req_ids = [t.request_id for t in tasks] tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks) if not tasks: error_msg = f"The request required resources is exceed the limit, request id={req_ids}." self.llm_logger.error(error_msg) raise EngineError(error_msg, error_code=500) return False self.token_processor.number_of_tasks += len(tasks) is_decode = False is_prefill = False for i in range(len(tasks)): if tasks[i].disaggregate_info is not None: if self.cfg.scheduler_config.splitwise_role == "decode": is_decode = True else: is_prefill = True self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len if self.cfg.scheduler_config.splitwise_role == "prefill": self.split_connector.send_cache_info_to_messager(tasks, current_id) elif self.cfg.scheduler_config.splitwise_role == "decode": self.split_connector.send_cache_info_to_prefill(tasks) if not is_decode: self.llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}") for task in tasks: if not getattr(task, "has_been_preempted_before", False): task.metrics.inference_start_time = time.time() tracing.trace_report_span( tracing.TraceSpanName.SCHEDULE, task.request_id.split("_")[0], int(task.metrics.scheduler_recv_req_time * 1e9), int(task.metrics.inference_start_time * 1e9), thread_finish_flag=True, ) trace_print(LoggingEventName.RESOURCE_ALLOCATE_END, task.request_id, getattr(task, "user", "")) trace_print(LoggingEventName.REQUEST_SCHEDULE_END, task.request_id, getattr(task, "user", "")) trace_print(LoggingEventName.INFERENCE_START, task.request_id, getattr(task, "user", "")) else: trace_print( LoggingEventName.RESCHEDULED_INFERENCE_START, task.request_id, getattr(task, "user", "") ) if not is_prefill: if not self.cfg.model_config.enable_mm: self.update_requests_chunk_size(tasks) else: self.update_mm_requests_chunk_size(tasks) self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz)) return True def _insert_prefilled_requests(self, request_outputs: List[RequestOutput]): """ Decode insert prefilled requests into engine worker queue. Used in v0_kvcache_scheduler. 
Args: request_outputs: a list of RequestOutput sent by prefill instance """ to_infer_reqs = [] for req_out in request_outputs: solt_idx = self.resource_manager.req_dict[req_out.request_id] del self.resource_manager.req_dict[req_out.request_id] cur_req = self.resource_manager.tasks_list[solt_idx] if envs.FD_ENABLE_INTERNAL_ADAPTER: if not req_out.outputs.token_ids: # first token is eos in Prefill, just recycle resource and continue self.resource_manager.stop_flags[solt_idx] = True self.resource_manager.tasks_list[solt_idx] = None self.resource_manager._recycle_block_tables(cur_req) if req_out.request_id in self.token_processor.tokens_counter: del self.token_processor.tokens_counter[req_out.request_id] self.llm_logger.warning(f"{req_out.request_id} need not decode after first token") continue cur_req.prompt_token_ids[0] = req_out.outputs.token_ids[0] cur_req.num_cached_tokens = req_out.num_cached_tokens req_out.metrics.decode_recv_req_time = cur_req.metrics.decode_recv_req_time req_out.metrics.decode_preallocate_req_time = cur_req.metrics.decode_preallocate_req_time cur_req.metrics = req_out.metrics cur_req.metrics.decode_inference_start_time = time.time() if ( self.cfg.speculative_config.method == SpecMethod.MTP and self.cfg.scheduler_config.splitwise_role == "decode" ): cur_req.draft_token_ids = copy.deepcopy(req_out.outputs.draft_token_ids) if req_out.error_code != 200: self.resource_manager.stop_flags[solt_idx] = True self.resource_manager.tasks_list[solt_idx] = None self.resource_manager._recycle_block_tables(cur_req) if req_out.request_id in self.token_processor.tokens_counter: del self.token_processor.tokens_counter[req_out.request_id] self.scheduler.put_results([req_out]) self.llm_logger.warning( f"{req_out.request_id} prefill failed with msg:{req_out.error_msg}, recycle resource." 
) continue self.token_processor.tokens_counter[req_out.request_id] = 1 to_infer_reqs.append(cur_req) if to_infer_reqs: self.engine_worker_queue.put_tasks((to_infer_reqs, self.resource_manager.real_bsz)) self.llm_logger.debug(f"put requests to engine worker queue, task:{to_infer_reqs}") return True def task_is_finished(self, index): """ judge if the task is finished """ assert index < len(self.resource_manager.stop_flags) return self.resource_manager.stop_flags[index] def all_tasks_finished(self): """ judge if all tasks are finished """ return np.sum(self.resource_manager.stop_flags) == len(self.resource_manager.stop_flags) def update_requests_chunk_size(self, requests): """ update each request's chunk size info """ def update_tokens(idx, chunk_size, update_chunk=False): nonlocal remain_batched_tokens, chunk_request_num if update_chunk: requests_chunk[idx][-1] += chunk_size else: requests_chunk[idx].append(chunk_size) remain_batched_tokens -= chunk_size current_request_size[idx] -= chunk_size if current_request_size[idx] <= 0: chunk_request_num -= 1 if not self.cfg.cache_config.enable_chunked_prefill or len(requests) == 0: return current_request_size = [request.prompt_token_ids_len for request in requests] requests_chunk = [[] for _ in range(len(requests))] chunk_request_num = len(current_request_size) while chunk_request_num >= 1: remain_batched_tokens = self.cfg.scheduler_config.max_num_batched_tokens for idx in range(len(current_request_size)): if current_request_size[idx] <= 0: continue chunk_size = min( current_request_size[idx], self.partial_chunked_tokens[chunk_request_num], ) update_tokens(idx, chunk_size) while remain_batched_tokens >= self.cfg.cache_config.block_size: # 当前 max_num_batched_tokens 还有剩余时,优先分配给较短的请求 waiting_requests = [input_lens for input_lens in current_request_size if input_lens > 0] if len(waiting_requests) == 0: break available_tokens = ( remain_batched_tokens // self.cfg.cache_config.block_size * self.cfg.cache_config.block_size ) append_idx = current_request_size.index(min(waiting_requests)) chunk_size = min( current_request_size[append_idx], self.partial_chunked_tokens[chunk_request_num], available_tokens, ) update_tokens(append_idx, chunk_size, update_chunk=True) for idx in range(len(requests)): requests[idx].set("prefill_chunk_info", requests_chunk[idx]) def update_mm_requests_chunk_size(self, requests): """ update each multimodal request's chunk size info """ if not self.cfg.cache_config.enable_chunked_prefill or len(requests) == 0: return for request in requests: inputs = request.multimodal_inputs # 兼容没有图片和视频的情况 if inputs["images"] is None: inputs["image_type_ids"] = np.array([], dtype="int32") inputs["grid_thw"] = np.array([], dtype="int64") inputs["images"] = np.array([], dtype="uint8") input_ids = paddle.to_tensor(inputs["input_ids"], dtype="int64") image_type_ids = paddle.to_tensor(inputs["image_type_ids"], dtype="int32") image_mask = input_ids == self.data_processor.image_patch_id image_token_sum = paddle.full(shape=[len(input_ids) + 1], fill_value=0, dtype="int32") image_token_sum[1:] = paddle.cumsum(image_mask.cast("int32"), dtype="int32") grid_thw = [] for one in inputs["grid_thw"]: if one[0] == 1: grid_thw.append(one) else: grid_thw.extend([[2, one[1], one[2]]] * (one[0] // 2)) grid_thw = paddle.to_tensor(grid_thw, dtype="int64") from fastdeploy.model_executor.ops.gpu import get_mm_split_fuse chunk_image_num, chunk_seq_len = get_mm_split_fuse( input_ids, image_type_ids, image_token_sum, grid_thw, self.data_processor.image_patch_id, len(grid_thw), 
0, len(input_ids), 0, self.partial_chunked_tokens[1], 2048, ) grid_thw = grid_thw.numpy().reshape([-1, 3]) num_chunks = len(chunk_image_num) chunks_info = [] input_ids_st, image_type_ids_st, grid_thw_st, patch_st = 0, 0, 0, 0 for idx in range(num_chunks): chunk_input_ids = inputs["input_ids"][input_ids_st : input_ids_st + chunk_seq_len[idx]] chunk_token_type_ids = inputs["token_type_ids"][input_ids_st : input_ids_st + chunk_seq_len[idx]] actual_image_num = np.sum(grid_thw[grid_thw_st : grid_thw_st + chunk_image_num[idx], 0]) chunk_image_type_ids = inputs["image_type_ids"][ image_type_ids_st : image_type_ids_st + actual_image_num ] chunk_grid_thw = grid_thw[grid_thw_st : grid_thw_st + chunk_image_num[idx]] chunk_patch_num = np.sum(np.prod(chunk_grid_thw, axis=1)) chunk_images = inputs["images"][patch_st : patch_st + chunk_patch_num] chunk_position_ids = inputs["position_ids"][input_ids_st : input_ids_st + chunk_seq_len[idx]] chunks_info.append( { "input_ids": chunk_input_ids, "token_type_ids": chunk_token_type_ids, "image_type_ids": (chunk_image_type_ids if chunk_image_type_ids.shape[0] else None), "grid_thw": (chunk_grid_thw if chunk_grid_thw.shape[0] else None), "images": (chunk_images if chunk_images.shape[0] else None), "position_ids": chunk_position_ids, } ) input_ids_st += chunk_seq_len[idx] image_type_ids_st += actual_image_num grid_thw_st += chunk_image_num[idx] patch_st += chunk_patch_num request.set("prefill_chunk_info", chunks_info) def _schedule_request_to_worker(self): """ Insert task to engine thread, monitor scheduler request queue. if the engine has resource, insert task to engine """ tracing.trace_set_thread_info("Scheduler Task to Work") current_id = 0 while getattr(self, "running", True): try: if self.resource_manager.available_batch() == 0: time.sleep(0.001) continue if self.engine_worker_queue.exist_tasks(): time.sleep(0.001) continue if hasattr(self, "exist_prefill_task_signal") and self.exist_prefill_task_signal.value[0] > 0: if ( self.cfg.scheduler_config.splitwise_role == "mixed" or self.split_connector.has_splitwise_tasks() ): time.sleep(0.005) continue if self.engine_worker_queue.num_cache_infos() > 0: time.sleep(0.001) continue if len(self.split_connector.current_request_ids) > 0: time.sleep(0.001) continue num_prefill_batch = min( int(self.resource_manager.available_batch()), self.cfg.max_prefill_batch, ) self.resource_manager.check_and_free_block_tables() tasks = self.scheduler.get_requests( available_blocks=self.resource_manager.available_block_num(), block_size=self.cfg.cache_config.block_size, reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num, max_num_batched_tokens=self.cfg.scheduler_config.max_num_batched_tokens, batch=num_prefill_batch, ) for task in tasks: task.metrics.engine_get_req_time = time.time() trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", "")) if len(tasks) == 0: time.sleep(0.001) continue if self.cfg.scheduler_config.splitwise_role == "decode": # TODO: refine scheduler to remove this limitation # Decode will process and schedule the request sent by prefill to engine, # so the same request sent by the decode api server will be ignored continue self.llm_logger.debug(f"get tasks from scheduler: {tasks}") if self.cfg.scheduler_config.splitwise_role != "mixed": for task in tasks: task.metrics.ask_decode_resource_start_time = time.time() self.split_connector.send_splitwise_tasks(tasks, current_id) insert_successful = self.insert_tasks(tasks, current_id) if insert_successful: current_id = 
current_id + 1 else: continue main_process_metrics.num_requests_waiting.dec(len(tasks)) main_process_metrics.num_requests_running.inc(len(tasks)) except Exception as e: err_msg = f"Error happened while insert task to engine: {e}, {traceback.format_exc()!s}." self.llm_logger.error(err_msg) def _schedule_request_to_worker_v1(self): """ Insert tasks to worker with scheduler v1 (ENABLE_V1_KVCACHE_SCHEDULER=1). """ tracing.trace_set_thread_info("Scheduler Task to Work") get_request_pool = ThreadPoolExecutor(max_workers=1) is_fetching = False def _fetch_request(): try: with self._pause_cond: self._pause_cond.wait_for(lambda: not self.is_paused) nonlocal is_fetching num_prefill_batch = min( int(self.resource_manager.available_batch()), self.cfg.max_prefill_batch, ) if self.cfg.scheduler_config.splitwise_role != "mixed": max_num_batched_tokens = self.cfg.scheduler_config.max_num_batched_tokens else: max_num_batched_tokens = self.cfg.model_config.max_model_len available_blocks = self.cfg.cache_config.max_block_num_per_seq tasks = self.scheduler.get_requests( available_blocks=available_blocks, block_size=self.cfg.cache_config.block_size, reserved_output_blocks=0, # self.cfg.cache_config.enc_dec_block_num max_num_batched_tokens=max_num_batched_tokens, batch=num_prefill_batch, ) for task in tasks: task.metrics.engine_get_req_time = time.time() trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", "")) if self.cfg.scheduler_config.splitwise_role == "decode": # TODO: refine scheduler to remove this limitation # Decode will process and schedule the request sent by prefill to engine, # so the same request sent by the decode api server will be ignored is_fetching = False return if tasks: self.llm_logger.debug( f"Engine has fetched tasks from {self.scheduler.__class__.__name__}: {[task.request_id for task in tasks]}" ) if self.cfg.scheduler_config.splitwise_role == "prefill": for task in tasks: # start async preprocess self.resource_manager.apply_async_preprocess(task) need_delete_tasks = [] if envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES: for task in tasks: # assure can allocate block ids in P while not self.resource_manager.preallocate_resource_in_p(task): time.sleep(0.005) self.llm_logger.debug( f"P has allocated resources and then ask D resource for request: {task.request_id}" ) task.metrics.ask_decode_resource_start_time = time.time() while True: self.split_connector.send_splitwise_tasks([task], task.idx) status, msg = self.split_connector.check_decode_allocated(task) if not status: self.llm_logger.warning( f"D failed to allocate resource for request {task.request_id}, try again." 
) time.sleep(0.05) else: task.metrics.ask_decode_resource_finish_time = time.time() break self.llm_logger.debug(f"D has allocated resource for request: {task.request_id}") else: for task in tasks: # assure can allocate block ids in P while not self.resource_manager.preallocate_resource_in_p(task): time.sleep(0.005) self.llm_logger.debug( f"P has allocated resources and then ask D resource for req_id: {task.request_id}" ) task.metrics.ask_decode_resource_start_time = time.time() self.split_connector.send_splitwise_tasks([task], task.idx) for task in tasks: # assure fetch block ids from D status, msg = self.split_connector.check_decode_allocated(task) task.metrics.ask_decode_resource_finish_time = time.time() if not status: self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.") self.scheduler.put_results( [ RequestOutput( request_id=task.request_id, finished=True, error_code=500, error_msg=msg, ) ] ) need_delete_tasks.append(task) continue for tmp_task in need_delete_tasks: tasks.remove(tmp_task) # release resource in P self.resource_manager.pre_recycle_resource(tmp_task.request_id) # to send cache info to cache messager if tasks: need_check_req_ids = [task.request_id for task in tasks] self.split_connector.send_cache_info_to_messager(tasks, 0) # ensure cache tasks has sent to cache_messager need_check_req_ids = [task.request_id for task in tasks] finished_ids, delete_tasks_list = [], [] while need_check_req_ids: finished_ids.extend(self.engine_worker_queue.get_finished_add_cache_task_req()) self.llm_logger.debug( f"P has successfully sent cache infos to cache messager for requests: {finished_ids}" ) if finished_ids: for task in tasks: result = self.resource_manager.waiting_async_process(task) if result is None: self.scheduler.put_results( [ RequestOutput( request_id=task.request_id, finished=True, error_code=task.error_code, error_msg=task.error_message, ) ] ) need_check_req_ids.remove(task.request_id) delete_tasks_list.append(task) elif result is False: if task.request_id in finished_ids: need_check_req_ids.remove(task.request_id) finished_ids.remove(task.request_id) else: time.sleep(0.001) for tmp_task in delete_tasks_list: tasks.remove(tmp_task) # release resource in P self.resource_manager.pre_recycle_resource(tmp_task.request_id) # Fetch requests and add them to the scheduling queue if tasks: for task in tasks: task.metrics.add_req_to_resource_manager_time = time.time() trace_print( LoggingEventName.RESOURCE_ALLOCATE_START, task.request_id, getattr(task, "user", "") ) if self.cfg.scheduler_config.splitwise_role == "prefill": self.resource_manager.add_request_in_p(tasks) self.llm_logger.info( f"P add requests into running queue: {[task.request_id for task in tasks]}" ) else: for task in tasks: self.resource_manager.add_request(task) is_fetching = False except Exception as e: self.llm_logger.error(f"fetching request error {e} {str(traceback.format_exc())}") is_fetching = False while self.running: with self._pause_cond: self._pause_cond.wait_for(lambda: not self.is_paused) try: if not is_fetching: # Check if the thread pool is still available to avoid submitting tasks to a shutdown thread pool. try: is_fetching = True get_request_pool.submit(_fetch_request) except RuntimeError as e: if "shutdown" in str(e): self.llm_logger.info("Thread pool shutdown detected, exiting scheduler loop") break else: raise if self.cfg.scheduler_config.splitwise_role != "mixed": # Continue preprocessing incoming requests and accumulating them in the queue when forward pass not finished. 
# Once the forward pass finishes, these accumulated requests can be scheduled in larger, # more efficient batches. if self.engine_worker_queue.exist_tasks() or self.engine_forward_signal.value[0] != 0: time.sleep(0.001) continue else: # In mixed, todo: optimze cache swap, to decouple swap from scheduler if self.engine_worker_queue.exist_tasks(): time.sleep(0.001) continue if hasattr(self.resource_manager, "scheduler_unhandled_request_num"): self.resource_manager.scheduler_unhandled_request_num = self._get_scheduler_unhandled_request_num() # 2. Schedule requests tasks, error_tasks = self.resource_manager.schedule() # 3. Send to engine if tasks: if self.cfg.scheduler_config.splitwise_role == "decode": for task in tasks: if task.task_type == RequestType.PREEMPTED: msg = f"{task.request_id} decode not enough blocks, need to be rescheduled." self.llm_logger.error(msg) self.scheduler.put_results( [ RequestOutput( request_id=task.request_id, finished=True, error_code=500, error_msg=msg, ) ] ) self.resource_manager.get_real_bsz() for task in tasks: if task.task_type == RequestType.PREFILL: rid = task.request_id.split("_")[0] if isinstance(task, Request) and task.has_been_preempted_before: trace_print( LoggingEventName.RESCHEDULED_INFERENCE_START, task.request_id, getattr(task, "user", ""), ) else: trace_carrier = task.trace_carrier tracing.trace_set_proc_propagate_context(rid, trace_carrier) trace_carrier = tracing.trace_get_proc_propagate_context(rid) task.trace_carrier = trace_carrier tracing.trace_report_span( tracing.TraceSpanName.SCHEDULE, rid, int(task.metrics.scheduler_recv_req_time * 1e9), int(time.time() * 1e9), thread_finish_flag=True, ) trace_print( LoggingEventName.RESOURCE_ALLOCATE_END, task.request_id, getattr(task, "user", "") ) trace_print( LoggingEventName.REQUEST_SCHEDULE_END, task.request_id, getattr(task, "user", "") ) trace_print( LoggingEventName.INFERENCE_START, task.request_id, getattr(task, "user", "") ) if isinstance(task, Request): if self.cfg.scheduler_config.splitwise_role == "decode": task.metrics.decode_inference_start_time = time.time() elif not task.has_been_preempted_before: task.metrics.inference_start_time = time.time() self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz)) else: # When there are no actual tasks to schedule, send an empty task batch to EP workers. # This helps EP workers barrier for syncing tasks not hang. if self.cfg.parallel_config.enable_expert_parallel: self.engine_worker_queue.put_tasks( ([], self.resource_manager.real_bsz) ) # Empty (as idle tasks for ep) # 4. Response error tasks if error_tasks: for request_id, failed in error_tasks: if failed is None: self.llm_logger.warning(f"Request {request_id} has no error, skip sending error response.") continue self._send_error_response(request_id, failed) if not tasks and not error_tasks: time.sleep(0.005) except RuntimeError as e: raise e except Exception as e: err_msg = "Error happened while insert task to engine: {}, {}.".format(e, str(traceback.format_exc())) self.llm_logger.error(err_msg) def _get_scheduler_unhandled_request_num(self) -> int: """ Get scheduler-level pending request count when supported. 
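
        Returns:
            int: The scheduler's unhandled request count, clamped to be non-negative.
                Returns 0 if the scheduler does not expose `get_unhandled_request_num`
                or the lookup fails.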
""" get_unhandled = getattr(self.scheduler, "get_unhandled_request_num", None) if not callable(get_unhandled): return 0 try: unhandled = int(get_unhandled()) except Exception as e: self.llm_logger.debug(f"Failed to get scheduler unhandled request num: {e}") return 0 return max(unhandled, 0) def start_zmq_service(self, api_server_pid=None): if api_server_pid is None: return self.api_server_pid = api_server_pid if envs.FD_ENABLE_INTERNAL_ADAPTER: self.recv_request_server = ZmqTcpServer(port=envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT, mode=zmq.PULL) self.send_response_server = ZmqTcpServer(port=envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT, mode=zmq.ROUTER) self.internal_adapter = InternalAdapter( cfg=self.cfg, engine=self, dp_rank=self.cfg.parallel_config.local_data_parallel_id ) # ROUTER mode: need to receive client handles self.recv_result_handle_thread = threading.Thread( target=self.send_response_server.recv_result_handle, daemon=True ) self.recv_result_handle_thread.start() else: self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL) if envs.ZMQ_SEND_BATCH_DATA: # PUSH mode: batch send, no need to receive client handles self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PUSH) # Mapping from request_id to worker_pid for routing batch responses self.request_worker_map = {} else: # ROUTER mode: per-query send, need to receive client handles self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER) self.recv_result_handle_thread = threading.Thread( target=self.send_response_server.recv_result_handle, daemon=True ) self.recv_result_handle_thread.start() time.sleep(3) self.insert_task_to_scheduler_thread = threading.Thread(target=self._insert_zmq_task_to_scheduler, daemon=True) self.insert_task_to_scheduler_thread.start() self.receive_output_thread = threading.Thread(target=self._zmq_send_generated_tokens, daemon=True) self.receive_output_thread.start() def _insert_zmq_task_to_scheduler(self): tracing.trace_set_thread_info("Insert Task to Scheduler") added_requests: Dict[str, int] = dict() if envs.FD_ENABLE_INTERNAL_ADAPTER: if self.cfg.scheduler_config.splitwise_role == "decode": return while self.running: try: block = True if len(added_requests) == 0 else False if not self.cfg.model_config.enable_mm: err, data = self.recv_request_server.receive_json_once(block) else: err, data = self.recv_request_server.receive_pyobj_once(block) if err is not None: # The message "Context was terminated" is normal when closing a ZMQ context if "Context was terminated" in str(err): self.llm_logger.info( "Engine stops inserting zmq task into scheduler due to ZMQ context termination (normal shutdown)." ) else: self.llm_logger.error(f"Engine stops inserting zmq task into scheduler, err:{err}") if envs.FD_ENABLE_INTERNAL_ADAPTER: self.recv_request_server = ZmqTcpServer( port=envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT, mode=zmq.PULL ) else: self.recv_request_server = ZmqIpcServer(name=self.api_server_pid, mode=zmq.PULL) continue # Extract zmq_worker_pid for per-worker PUSH routing. # Only needed when ZMQ_SEND_BATCH_DATA=True AND not using internal adapter, # because FD_ENABLE_INTERNAL_ADAPTER uses ROUTER (worker_pid is irrelevant). 
worker_pid = None if envs.ZMQ_SEND_BATCH_DATA and not envs.FD_ENABLE_INTERNAL_ADAPTER: worker_pid = data["zmq_worker_pid"] if ControlRequest.is_control_request(data): try: # todo: run control request async, do not block request generation if worker_pid is not None: self.request_worker_map[data.get("request_id")] = worker_pid control_req = ControlRequest.from_dict(data) self.run_control_method(control_req) except Exception as e: self.llm_logger.error( f"Failed to process control request {data.get('request_id')}: " f"{e}, {traceback.format_exc()}" ) continue request, insert_task = data, [] results: List[Tuple[str, Optional[str]]] = list() if data: # Store worker_pid mapping for normal/abort requests if worker_pid is not None: req_id_for_map = data.get("request_id") if req_id_for_map: self.request_worker_map[req_id_for_map] = worker_pid status_value = data.get("status", None) if status_value is not None and status_value == RequestStatus.ABORT.value: req_id = data["request_id"] self.llm_logger.info(f"Receive abort request, req_id: {req_id}") if envs.ENABLE_V1_KVCACHE_SCHEDULER: self.resource_manager.add_abort_req_ids(req_id) continue err_msg = None try: request = Request.from_dict(data) request.metrics.scheduler_recv_req_time = time.time() main_process_metrics.requests_number.inc() trace_carrier = data.get("trace_carrier") if trace_carrier: request_id = data["request_id"].split("_")[0] tracing.trace_set_proc_propagate_context(request_id, trace_carrier) trace_print(LoggingEventName.PREPROCESSING_END, data["request_id"], data.get("user", "")) trace_print(LoggingEventName.REQUEST_SCHEDULE_START, data["request_id"], data.get("user", "")) trace_print(LoggingEventName.REQUEST_QUEUE_START, data["request_id"], data.get("user", "")) self.llm_logger.debug(f"Receive request from api server: {request}") if self.is_paused: self.llm_logger.warning(f"Engine is paused, drop request: {request}") self._send_error_response( request.request_id, "Request is aborted since LLM Engine is paused.", worker_pid=worker_pid, ) continue except Exception as e: self.llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}") err_msg = str(e) results.append((data["request_id"], err_msg)) if self.guided_decoding_checker is not None and err_msg is None: request, err_msg = self.guided_decoding_checker.schema_format(request) if err_msg is not None: self.llm_logger.error(f"Receive request error: {err_msg}") results.append((request.request_id, err_msg)) if err_msg is None: insert_task.append(request) response = self.scheduler.put_requests(insert_task) results.extend(response) if request: if request.request_id not in added_requests: added_requests[request.request_id] = 0 added_requests[request.request_id] += 1 for request_id, failed in results: if request_id in added_requests: added_requests[request_id] -= 1 if added_requests[request_id] == 0: added_requests.pop(request_id) if failed is None: main_process_metrics.num_requests_waiting.inc(1) continue self._send_error_response(request_id, failed) except Exception as e: self.llm_logger.error( f"Error happened while receiving new request from zmq, details={e}, " f"traceback={traceback.format_exc()}" ) def run_control_method(self, control_req: ControlRequest): """ Execute control method, process control request and return response. This method is responsible for handling control requests, calling the corresponding handler function based on the method name in the request. 
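        Handlers are resolved by the `_control_<method>` naming convention, so a request
        whose method is `pause` is dispatched to `self._control_pause`.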
If the method doesn't exist or is not callable, it returns an error response; otherwise executes the method and returns a success response. Args: control_req (ControlRequest): Control request object containing request ID, method name and parameters. Returns: None: No return value, sends ControlResponse through send_response_server. """ method = control_req.get_method() request_id = control_req.request_id # Look up worker_pid for routing control response worker_pid = None if envs.ZMQ_SEND_BATCH_DATA and hasattr(self, "request_worker_map"): worker_pid = self.request_worker_map.pop(request_id, None) try: self.llm_logger.info(f"Start to run control method {method}: {request_id}") handler_name = f"_control_{method}" handler = getattr(self, handler_name, None) if handler is None or not callable(handler): error_result = ControlResponse(request_id, 400, f"unknown control method:{method}") self.llm_logger.error(str(error_result)) data = [[error_result]] if envs.ZMQ_SEND_BATCH_DATA else [error_result] self.send_response_server.send_response(request_id, data, worker_pid=worker_pid) return result = handler(control_req) self.llm_logger.info(f"Successfully run control method {method}: {request_id} {result}") succ_result = ControlResponse(request_id, 200, "Success", result) data = [[succ_result]] if envs.ZMQ_SEND_BATCH_DATA else [succ_result] self.send_response_server.send_response(request_id, data, worker_pid=worker_pid) except Exception as e: error_msg = f"Failed to run control method {method}: {request_id} {str(e)}" self.llm_logger.error(f"{error_msg}\n{traceback.format_exc()}") error_result = ControlResponse(request_id, 500, error_msg) data = [[error_result]] if envs.ZMQ_SEND_BATCH_DATA else [error_result] self.send_response_server.send_response(request_id, data, worker_pid=worker_pid) def _control_pause(self, control_request: ControlRequest): """Pauses the LLM engine and aborts all running/inflight requests. Args: control_request: The control request containing pause command Raises: Exception: If pause is not supported in current configuration Exception: If engine worker queue cleanup times out Returns: None """ if not envs.ENABLE_V1_KVCACHE_SCHEDULER: raise Exception("pause only supported in ENABLE_V1_KVCACHE_SCHEDULER") if self.cfg.scheduler_config.name != "local": raise Exception(f"pause only supported in local scheduler, current {self.cfg.scheduler_config.name}") self.llm_logger.info("Start to pause request generation.") with self._pause_cond: if self.is_paused: self.llm_logger.info("Engine is already paused, no need to pause again.") return self.is_paused = True self.llm_logger.info("Abort running requests.") self.resource_manager.log_status() # preempted all running reqs. preempted reqs will be append to ResourceManager.waiting queue timeout, count = 60, 0 while self.engine_worker_queue.exist_tasks(): time.sleep(0.001) count += 1 if count >= timeout * 1000: break if count >= timeout * 1000: error_msg = f"Emptying engine worker queue timed out after {timeout} seconds, worker may hanged!" 
self.llm_logger.error(error_msg) raise Exception(error_msg) running_reqs = self.resource_manager.preempted_all() if len(running_reqs) > 0: self.llm_logger.info(f"Total {len(running_reqs)} requests need to be aborted.") self.resource_manager.get_real_bsz() self.engine_worker_queue.put_tasks((running_reqs, self.resource_manager.real_bsz)) self.resource_manager.wait_worker_inflight_requests_finish(timeout=60) # self.engine_worker_queue.clear_data() self.token_processor.clear_data() self.resource_manager.log_status() # abort inflight requests to user inflight_requests = self.scheduler.get_inflight_requests() self.llm_logger.info(f"Abort inflight requests (total {len(inflight_requests)}).") for req in inflight_requests: self._send_error_response(req.request_id, "Request is aborted since engine is paused.") self.scheduler.reset() # pause cache transfer if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend: self.llm_logger.info("Start to pause cache transfer.") pause_transfer_request = ControlRequest( request_id=f"{control_request.request_id}_pause_transfer", method="pause" ) self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, pause_transfer_request)) # Wait for cache_transfer responses asyncio.run( self._wait_for_control_responses( f"{pause_transfer_request.request_id}", 60, executors=["cache_transfer"] ) ) self.llm_logger.info("Successfully paused cache transfer.") self.resource_manager.cache_manager.reset() self.llm_logger.info("Successfully paused request generation.") return None def _control_resume(self, control_request: ControlRequest) -> Optional[dict]: """Control function for resuming request generation. This method resumes the paused request generation process by setting the pause flag and notifying all waiting threads. It logs the start and end of the resume operation. Args: control_request: Control request object containing resume operation information """ self.llm_logger.info("Start to resume request generation.") with self._pause_cond: if not self.is_paused: self.llm_logger.info("Engine is not paused, no need to resume.") return None self.is_paused = False self._pause_cond.notify_all() # resume cache transfer if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend: self.llm_logger.info("Start to resume cache transfer.") resume_transfer_request = ControlRequest( request_id=f"{control_request.request_id}_resume_transfer", method="resume" ) self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, resume_transfer_request)) # Wait for cache_transfer responses asyncio.run( self._wait_for_control_responses(resume_transfer_request.request_id, 60, executors=["cache_transfer"]) ) self.llm_logger.info("Successfully resumed cache transfer.") self.llm_logger.info("Successfully resumed request generation.") return None def _control_is_paused(self, control_request: ControlRequest) -> bool: """ Check if the LLM engine is in paused state. Args: control_request: Control request object. 
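                Not used by this handler; accepted only to keep the common
                `_control_*` handler signature.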
Returns: dict: Dictionary containing pause status information, {'is_paused': bool} """ self.llm_logger.info(f"LLM Engine request generation is paused: {self.is_paused}") with self._pause_cond: return {"is_paused": self.is_paused} def _get_is_paused_safe(self) -> bool: """Thread-safe getter for is_paused state, used by RegisterManager.""" with self._pause_cond: return self.is_paused def _control_update_weights(self, control_request: ControlRequest) -> Optional[dict]: """Update model weights Args: control_request: Control request object containing parameters for weight updates Returns: Optional[dict]: Returns the result dictionary if update succeeds, None otherwise Raises: Exception: Raised when the engine is not in paused state """ self.llm_logger.info("Update Model Weights") with self._pause_cond: if self.is_paused is False: error_msg = "Pause LLM Engine first before calling updating weights" self.llm_logger.error(error_msg) raise Exception(error_msg) responses = self._call_worker(control_request, 60) if responses: new_version = None for resp in responses: # Expect each worker response to be a dict-like object if isinstance(resp, dict) and "version" in resp: new_version = resp.get("version") self.llm_logger.info(f"Update Weights Version in Config: {new_version}") break if new_version is not None: self.cfg.model_config.version = new_version if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend: self.llm_logger.info("Start to update cache-transfer metadata after weight update.") update_cache_request = ControlRequest( request_id=f"{control_request.request_id}_update_weights", method="update_weights", args=copy.deepcopy(control_request.args), ) self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, update_cache_request)) asyncio.run( self._wait_for_control_responses(update_cache_request.request_id, 60, executors=["cache_transfer"]) ) self.llm_logger.info("Successfully updated cache-transfer metadata after weight update.") return responses def _control_abort_requests(self, control_req: ControlRequest): if not envs.ENABLE_V1_KVCACHE_SCHEDULER: raise Exception("abort_requests only supported in ENABLE_V1_KVCACHE_SCHEDULER") args = control_req.get_args() abort_all = args.get("abort_all", False) req_ids = args.get("req_ids", []) matched_input_ids = set() now_reqs = list(set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys())) # Step 1: Determine target request list if abort_all: # all requests in running + waiting target_req_ids = now_reqs else: # filter out requests that actually exist target_req_ids = [] for rid in req_ids: if rid in now_reqs: target_req_ids.append(rid) matched_input_ids.add(rid) elif f"{rid}_0" in now_reqs: target_req_ids.append(f"{rid}_0") matched_input_ids.add(rid) if not target_req_ids: return {"aborted": [], "not_found": req_ids if not abort_all else []} # Step 2: Collect partial results aborted_info = [] results = [] for req_id in target_req_ids: request = self.resource_manager.requests.get(req_id) if request is None: scheduled_req = self.scheduler.requests.get(req_id) if scheduled_req is None: continue request = scheduled_req.raw partial_token_ids = list(request.output_token_ids) # Construct finished response with partial results now = time.time() abort_metrics = RequestMetrics( arrival_time=request.metrics.arrival_time if request.metrics else now, inference_start_time=request.metrics.inference_start_time if request.metrics else now, engine_recv_latest_token_time=now, 
engine_recv_first_token_time=request.metrics.engine_recv_first_token_time if request.metrics else now, request_start_time=request.metrics.arrival_time if request.metrics else now, ) result = RequestOutput( request_id=req_id, finished=True, outputs=CompletionOutput( index=0, send_idx=len(partial_token_ids), token_ids=[self.data_processor.eos_token_ids[0]], ), metrics=abort_metrics, error_code=200, error_msg="Aborted", ) results.append(result) aborted_info.append( { "request_id": req_id, "output_token_count": len(partial_token_ids), } ) # Step 3: Execute abort — add all requests to waiting_abort_req_id_set if envs.ENABLE_V1_KVCACHE_SCHEDULER: for req_id in target_req_ids: self.resource_manager.add_abort_req_ids(req_id) time.sleep(0.0001) if self.cfg.scheduler_config.splitwise_role != "prefill": self._wait_abort_complete(target_req_ids) # Add results to scheduler, engine will have a thread calling get_results, # then cleanup and call send_response to send to client. # When client disconnects, send_response will automatically ignore if self.cfg.scheduler_config.splitwise_role != "prefill": try: # self.send_response_server.send_response(req_id, [result]) self.scheduler.put_results(results) except Exception: pass # client may have disconnected not_found = [rid for rid in req_ids if rid not in matched_input_ids] if not abort_all else [] return {"aborted": aborted_info, "not_found": not_found} def _wait_abort_complete(self, target_req_ids, stall_timeout=1): """ Wait for all abort requests to complete. - Keep monitoring as long as remaining is not empty, which means cleanup is not done yet - If no progress within stall_timeout seconds, force cleanup requests stuck in to_be_aborted_req_id_set, reset progress state if any, then continue monitoring """ target_set = set(target_req_ids) prev_remaining_count = len(target_set) last_progress_time = time.time() remaining = target_set & self.resource_manager.get_reqs_in_aborting() while remaining: remaining = target_set & self.resource_manager.get_reqs_in_aborting() if not remaining: self.llm_logger.info(f"all {len(target_set)} abort reqs cleaned") return current_count = len(remaining) if current_count < prev_remaining_count: # progress made: recycle_abort_task was called self.llm_logger.info(f"abort progress: {prev_remaining_count} -> {current_count}") last_progress_time = time.time() prev_remaining_count = current_count if time.time() - last_progress_time > stall_timeout: # no progress timeout: only cleanup requests stuck in to_be_aborted (worker hasn't returned -9) stuck = remaining & self.resource_manager.to_be_aborted_req_id_set if stuck: self.llm_logger.warning( f"no abort progress for {stall_timeout}s, " f"force cleanup {len(stuck)} stuck requests (in to_be_aborted)" ) for req_id in list(stuck): self.llm_logger.warning(f"force cleanup stuck req_id:{req_id}") self.resource_manager.recycle_abort_task(req_id) # reset progress state last_progress_time = time.time() prev_remaining_count = current_count - len(stuck) # else: remaining are all in waiting_abort_req_id_set, waiting for natural flow time.sleep(0.005) def _parse_tags(self, control_request: ControlRequest): """ Parse tags from control request. 
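
        For example, args `{"tags": ["weight", "kv_cache"]}` is normalized to the string
        `"weight,kv_cache"`; a missing `tags` entry defaults to all allowed tags, and a
        ValueError is raised for any tag outside `["weight", "kv_cache"]`.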
""" allowed_tags = ["weight", "kv_cache"] tags = control_request.args.get("tags", None) if tags is None: tags = ",".join(allowed_tags) control_request.args["tags"] = tags self.llm_logger.info( f"Detected empty tags of request {control_request.request_id}, defaulting to tags: {tags}" ) elif isinstance(tags, list): tags = ",".join(tags) for tag in tags.split(","): if tag not in allowed_tags: raise ValueError(f"Unsupported tag [{tag}] in [{tags}], expected one of {allowed_tags}") return tags def _control_sleep(self, control_request: ControlRequest): """ Offload gpu memory occupation for certain parts, e.g. weight, cache. Args: control_request: Control request object containing parameters for offloading memory tags: list of tags to offload, supported values: ["weight", "cache"] TODO: support different level of offloading, to provide options for release memory forever or merely offloading to cpu memory for now. """ # Args check tags = self._parse_tags(control_request) control_request.args["tags"] = tags # Make sure llm engine is paused. self.llm_logger.warning( "Implicitly pause LLM engine before sleeping. This behavior will be deprecated in future versions. " "Please explicitly request to /pause the engine before /sleep." ) self._control_pause(None) # Determine which executors are needed for the sleep command executors = set() if "weight" in tags: executors.add("worker") if "kv_cache" in tags: executors.add("worker") if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend: executors.add("cache_transfer") if self.cfg.cache_config.enable_prefix_caching: self.resource_manager.cache_manager.reset() # Dispatch sleep request to executors self.llm_logger.info(f"Dispatch sleep request to executors: {list(executors)}") self._dispatch_control_request(control_request, executors) return asyncio.run(self._wait_for_control_responses(control_request.request_id, 60, executors=executors)) def _control_wakeup(self, control_request: ControlRequest): """ Reload offloaded gpu memory occupation for certain parts, e.g. weight, cache. Args: control_request: Control request object containing parameters for reloading memory tags: list of tags to reload, supported values: ["weight", "kv_cache"] """ # Args check tags = self._parse_tags(control_request) control_request.args["tags"] = tags # Determine which executors are needed for the wakeup command executors = set() if "weight" in tags: executors.add("worker") if "kv_cache" in tags: executors.add("worker") if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend: executors.add("cache_transfer") # Dispatch wakeup request to executors self.llm_logger.info(f"Dispatch wakeup request to executors: {list(executors)}") self._dispatch_control_request(control_request, executors) result = asyncio.run(self._wait_for_control_responses(control_request.request_id, 300, executors=executors)) # Resume the engine after wakeup self._control_resume(None) return result def _dispatch_control_request(self, control_request: ControlRequest, executors: List[str]): """ Dispatch control requests to workers, cache managers or engine itself. 
Args: control_request: ControlRequest executors: List """ if "worker" in executors: self.engine_worker_queue.put_tasks(([control_request], 1)) if "cache_transfer" in executors: if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend: self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, control_request)) return async def _wait_for_control_responses(self, request_id: str, timeout: int, executors: List[str] = None): """Wait for matching control responses from the selected executor queues. This helper selects the control-response queues that belong to the requested executors, then waits for all of them concurrently. Each queue gets a local waiter that keeps reading until it sees the target request ID and stashes stale responses into that queue's mailbox. Args: request_id: The control request ID that all returned responses must match. timeout: Global timeout budget in seconds for the full multi-queue wait. executors: Executor groups to wait for, for example `["worker"]` or `["worker", "cache_transfer"]`. If `None`, waits for all control response queues. Returns: A list of `response.result` values collected from all matched `ControlResponse` objects. If no queue is selected, returns `None`. Raises: Exception: If the overall wait times out, or if any queue reports a non-200 control response or fails while waiting. """ def select_control_queues(executors: List[str] = None): """Select control response queues by executors.""" if executors is None: return self._ctrl_output_queues else: queues = {} for k, v in self._ctrl_output_queues.items(): if "w2e" in k and "worker" in executors: queues[k] = v elif "c2e" in k and "cache_transfer" in executors: queues[k] = v return queues async def wait_one(queue_name: str, queue): """Wait until one queue returns a response for the current request_id.""" mailbox = self._ctrl_response_mailboxes[queue_name] # Reuse a previously stashed response for this request before touching FMQ again. cached_response = mailbox.pop(request_id, None) if cached_response is not None: self.llm_logger.info(f"Returning cached control response from {queue_name}.") return cached_response while True: msg = await queue.get() # Return if the response matches the control request response: ControlResponse = msg.payload if response.request_id == request_id: self.llm_logger.info(f"Returning new control response from {queue_name}.") return response # Stash late responses from other control requests so they do not consume the # current request's only read chance on this queue. mailbox[response.request_id] = response self.llm_logger.info( f"Stashed old control response from {queue_name}. " f"Expected request {request_id}, got request {response.request_id}" ) # Select only the control response queues that belong to the requested executors. queues = select_control_queues(executors) if not queues: self.llm_logger.info(f"No queues to wait for, executors: {executors}") return self.llm_logger.info(f"Waiting for control responses from {len(queues)} queues: {list(queues.keys())}") # Each queue gets its own waiter, which will stash stale responses until it finds the # target request ID for this control request. 
tasks = {name: asyncio.create_task(wait_one(name, queue)) for name, queue in queues.items()} done, pending = await asyncio.wait(tasks.values(), timeout=timeout) if pending: pending_names = [name for name, task in tasks.items() if task in pending] done_names = [name for name, task in tasks.items() if task in done] self.llm_logger.error( f"Control request {request_id} execution timeout. " f"Pending queues: {pending_names}, completed queues: {done_names}." ) # Stop unfinished queue waiters so they do not outlive the control request. for task in pending: task.cancel() await asyncio.gather(*pending, return_exceptions=True) raise Exception(f"Control request {request_id} timed out after {timeout}s") # Collect the results from all completed queues. responses = [] for name, task in tasks.items(): try: response = task.result() except Exception as e: self.llm_logger.error(f"Waiting for control response from {name} failed: {repr(e)}") raise if response.error_code != 200: raise Exception(f"Error response from {name}: {response.error_message}") responses.append(response.result) return responses def _call_worker(self, control_request: ControlRequest, timeout: int): request_id = control_request.request_id self.engine_worker_queue.put_tasks(([control_request], 1)) # Use a single asyncio.run() to concurrently wait for all worker responses. return asyncio.run(self._wait_for_control_responses(request_id, timeout, executors=["worker"])) def _send_error_response(self, request_id, error_msg, error_code: int = 500, worker_pid=None): self.llm_logger.error( f"Send error response to client, request_id: {request_id}, error_msg: {error_msg}, error_code: {error_code}" ) error_result = RequestOutput( request_id=request_id, finished=True, error_code=error_code, error_msg=error_msg, ) # Look up worker_pid from mapping if not provided if worker_pid is None and envs.ZMQ_SEND_BATCH_DATA and hasattr(self, "request_worker_map"): worker_pid = self.request_worker_map.pop(request_id, None) # Since the request is not in scheduler # Send result by zmq directly if envs.FD_ENABLE_INTERNAL_ADAPTER: self.send_response_server.send_response(None, [[error_result]]) elif envs.ZMQ_SEND_BATCH_DATA: self.send_response_server.send_response(None, [[error_result]], worker_pid=worker_pid) else: self.send_response_server.send_response(request_id, [error_result]) def _decode_token(self, token_ids, req_id, is_end): delta_text = "" if envs.FD_ENABLE_RETURN_TEXT: delta_text, cum_tokens, _ = self.data_processor.ids2tokens(token_ids, req_id) if delta_text != "": prefix_offset = self.data_processor.decode_status[req_id][0] read_offset = self.data_processor.decode_status[req_id][1] token_ids = cum_tokens[prefix_offset:read_offset] else: token_ids = [] if is_end: del self.data_processor.decode_status[req_id] return delta_text, token_ids def _zmq_send_generated_tokens(self): """ Receive output for zmq """ while self.running: try: results = self.scheduler.get_results() if len(results) == 0: time.sleep(0.005) continue if envs.FD_ENABLE_INTERNAL_ADAPTER: new_contents = [] for step_batch_results in results: new_step_contents = [] for content in step_batch_results: if isinstance(content, RequestOutput) and content.outputs is not None: decode_type = content.outputs.decode_type delta_text = "" if decode_type == 0: delta_text, token_ids = self._decode_token( token_ids=content.outputs.token_ids, req_id=content.request_id, is_end=content.finished, ) else: token_ids = content.outputs.token_ids if len(token_ids): content.outputs.token_ids = token_ids 
                                    content.outputs.text = delta_text
                                    new_step_contents.append(content)
                                elif content.finished:
                                    new_step_contents.append(content)
                                else:
                                    self.llm_logger.warning(
                                        f"current tokens need to accumulate, req_id: {content.request_id} {content.outputs.token_ids}"
                                    )
                            else:
                                new_step_contents.append(content)
                        if new_step_contents:
                            new_contents.append(new_step_contents)
                    if new_contents:
                        self.send_response_server.send_response(None, new_contents)
                else:
                    worker_batches = collections.defaultdict(list)
                    for request_id, contents in results.items():
                        new_contents = []
                        for content in contents:
                            if isinstance(content, RequestOutput) and content.outputs is not None:
                                decode_type = content.outputs.decode_type
                                delta_text = ""
                                if decode_type == 0:
                                    delta_text, token_ids = self._decode_token(
                                        token_ids=content.outputs.token_ids,
                                        req_id=request_id,
                                        is_end=content.finished,
                                    )
                                else:
                                    token_ids = content.outputs.token_ids
                                if len(token_ids):
                                    content.outputs.token_ids = token_ids
                                    content.outputs.text = delta_text
                                    new_contents.append(content)
                                elif content.finished:
                                    new_contents.append(content)
                                else:
                                    self.llm_logger.warning(
                                        f"current tokens need to accumulate, req_id: {request_id} {content.outputs.token_ids}"
                                    )
                            else:
                                new_contents.append(content)
                        if new_contents:
                            if envs.ZMQ_SEND_BATCH_DATA:
                                wpid = self.request_worker_map.get(request_id)
                                worker_batches[wpid].append(new_contents)
                                is_finished = any(getattr(c, "finished", False) for c in new_contents)
                                if is_finished:
                                    self.request_worker_map.pop(request_id, None)
                            else:
                                self.send_response_server.send_response(request_id, new_contents)
                    if envs.ZMQ_SEND_BATCH_DATA:
                        for wpid, batch_data in worker_batches.items():
                            if batch_data:
                                self.send_response_server.send_response(None, batch_data, worker_pid=wpid)
            except Exception as e:
                self.llm_logger.error(f"Unexpected error happened: {e}, {traceback.format_exc()!s}")

    def _decode_process_splitwise_requests(self):
        """
        Decode processes requests from the engine worker queue, which are sent by prefill.
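        A background daemon thread drains the disaggregated queue: incoming Request objects trigger
        decode-side resource pre-allocation and a cache-info reply to prefill, while incoming RequestOutput
        objects are held until the local scheduler knows the request and are then handed over as prefilled
        requests.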
TODO: merge this function to the schedule function in resource manager """ allocate_resource_requests: list[Request] = [] prefilled_request_ouputs: list[RequestOutput] = [] def _fetch_requests(): if self.engine_worker_queue.disaggregate_queue_empty(): return items = self.engine_worker_queue.get_disaggregated_tasks() for item in items: tasks = item[1] if isinstance(tasks[0], Request): self.llm_logger.debug( f"D has received tasks to preallocate resource for tasks: {[task.request_id for task in tasks]}" ) for task in tasks: task.metrics.decode_recv_req_time = time.time() allocate_resource_requests.extend(tasks) elif isinstance(tasks[0], RequestOutput): self.llm_logger.debug( f"D has received tasks to process prefilled tasks: {[task.request_id for task in tasks]}" ) if not isinstance(tasks, list): tasks = [tasks] for task in tasks: task.finished = False task.metrics.decode_recv_first_token_time = time.time() prefilled_request_ouputs.extend(tasks) def _process_allocate_resource_requests(): processed_indices = [] for idx, task in enumerate(allocate_resource_requests): is_success = False if envs.ENABLE_V1_KVCACHE_SCHEDULER: if self.resource_manager.preallocate_resource_in_d(task): task.metrics.decode_preallocate_req_time = time.time() self.llm_logger.info(f"Resource available, processing task {task.request_id}") self.split_connector.send_cache_info_to_prefill([task]) self.llm_logger.debug(f"D has successfully sent cache infos for task {task.request_id}") processed_indices.append(idx) is_success = True else: if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len): self.llm_logger.debug(f"D Resource available, processing task {task.request_id}") self.insert_tasks([task]) task.metrics.decode_preallocate_req_time = time.time() processed_indices.append(idx) is_success = True if not is_success: if not self.enable_decode_cache_task: task.error_msg = "Not enough resources" self.split_connector.send_cache_info_to_prefill([task]) self.llm_logger.warning(f"D has failed to send cache infos for task {task.request_id}") processed_indices.append(idx) else: self.llm_logger.debug(f"Still waiting for resources {task.request_id}") break for idx in sorted(processed_indices, reverse=True): allocate_resource_requests.pop(idx) def _process_prefilled_requests(): nonlocal prefilled_request_ouputs ready_request_outputs = [] waiting_request_outputs = [] for req_output in prefilled_request_ouputs: if hasattr(self.scheduler, "has_request") and not self.scheduler.has_request(req_output.request_id): # ensure the api_server and scheduler in decode have # received the request sent by the client waiting_request_outputs.append(req_output) continue req_output.finished = False ready_request_outputs.append(req_output) self.llm_logger.debug(f"there are enough resource for prefilled request: {req_output.request_id}") prefilled_request_ouputs = waiting_request_outputs if self.cfg.splitwise_version == "v1": # decode return first token to client self.scheduler.put_results(ready_request_outputs) if not envs.ENABLE_V1_KVCACHE_SCHEDULER: self._insert_prefilled_requests(ready_request_outputs) else: for req_output in ready_request_outputs: request_id = req_output.request_id if envs.FD_ENABLE_INTERNAL_ADAPTER and not req_output.outputs.token_ids: # first token is eos in Prefill, just recycle resource and continue self.llm_logger.warning(f"{request_id} need not decode after first token") self.resource_manager.pre_recycle_resource(request_id) if request_id in self.token_processor.tokens_counter: del 
self.token_processor.tokens_counter[request_id] req_output.finished = True self.scheduler.put_results([req_output]) continue if req_output.error_code != 200: self.llm_logger.warning( f"{request_id} prefill failed with msg:{req_output.error_msg}, recycle resource." ) self.resource_manager.pre_recycle_resource(request_id) if request_id in self.token_processor.tokens_counter: del self.token_processor.tokens_counter[request_id] self.scheduler.put_results([req_output]) continue self.token_processor.tokens_counter[request_id] = 1 if envs.FD_ENABLE_INTERNAL_ADAPTER: # first token sent by D instance self.scheduler.put_results([req_output]) self.resource_manager.add_prefilled_request(req_output) self.llm_logger.info(f"D has successfully added prefilled request, {request_id}") def decode_loop(): while self.running: try: _fetch_requests() _process_allocate_resource_requests() _process_prefilled_requests() time.sleep(0.001) except Exception as e: self.llm_logger.error( f"Error in main loop of decode_process_splitwise_requests: " f"{e}, {traceback.format_exc()}" ) time.sleep(0.01) threading.Thread(target=decode_loop, daemon=True).start() def start_cache_service(self, device_ids, ipc_signal_suffix): console_logger.debug("Start cache manager...") return self.resource_manager.cache_manager.launch_cache_manager( cache_config=self.cfg.cache_config, tensor_parallel_size=self.cfg.parallel_config.tensor_parallel_size, device_ids=device_ids, pod_ip=self.cfg.master_ip, engine_worker_queue_port=self.cfg.parallel_config.local_engine_worker_queue_port, ipc_suffix=ipc_signal_suffix, create_cache_tensor=False, ) def check_and_free_block_tables(self): self.resource_manager.check_and_free_block_tables() def clear_data(self): try: self.llm_logger.info("Clear Data: Start") self.token_processor.clear_data() self.engine_worker_queue.clear_data() if hasattr(self, "cache_task_queue"): self.cache_task_queue.clear_transfer_task() self.send_response_server.req_dict.clear() self.recv_request_server.req_dict.clear() # Clean up worker_pid mapping (batch mode) if envs.ZMQ_SEND_BATCH_DATA and hasattr(self, "request_worker_map"): self.request_worker_map.clear() self.llm_logger.info("Clear Data: Successfully") return True except Exception as e: self.llm_logger.error(f"Clear data error: {e}") return False def _exit_sub_services(self): """ exit sub services """ self.llm_logger.info("Exit sub services.....") self.running = False if self.use_async_llm: # Clean up worker processes first (before closing multiprocessing services) if hasattr(self, "worker_proc") and self.worker_proc is not None: self.llm_logger.info("Cleaning up worker processes...") try: pgid = os.getpgid(self.worker_proc.pid) os.killpg(pgid, signal.SIGTERM) except Exception as e: self.llm_logger.error(f"Error extracting sub services: {e}, {str(traceback.format_exc())}") # Clean up cache manager processes if hasattr(self, "cache_manager_processes"): self.llm_logger.info("Cleaning up cache manager processes...") self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear() self.resource_manager.cache_manager.cache_ready_signal.clear() for p in self.cache_manager_processes: self.llm_logger.info(f"Killing cache manager process {p.pid}") try: pgid = os.getpgid(p.pid) os.killpg(pgid, signal.SIGTERM) except Exception as e: self.llm_logger.error( f"Error killing cache manager process {p.pid}: {e}, {str(traceback.format_exc())}" ) if hasattr(self, "cache_task_queue") and self.cache_task_queue is not None: self.llm_logger.info("Cleaning up cache_task_queue...") # Check if 
                # cleanup method exists
                if hasattr(self.cache_task_queue, "cleanup"):
                    self.cache_task_queue.cleanup()
                elif hasattr(self.cache_task_queue, "manager"):
                    try:
                        self.llm_logger.info("Shutting down cache_task_queue manager...")
                        self.cache_task_queue.manager.shutdown()
                    except Exception as e:
                        self.llm_logger.warning(f"Error shutting down cache_task_queue manager: {e}")
            if hasattr(self, "get_profile_block_num_signal"):
                self.get_profile_block_num_signal.clear()
            self.worker_ready_signal.clear()
            self.loaded_model_signal.clear()

        # Clean up other services
        if hasattr(self, "dp_processed"):
            for p in self.dp_processed:
                self.llm_logger.info(f"Waiting for worker {p.pid} to exit")
                p.join()
            for p in self.dp_engine_worker_queue_server:
                p.cleanup()
        if hasattr(self, "engine_worker_queue_server") and self.engine_worker_queue_server is not None:
            self.engine_worker_queue_server.cleanup()
        self.exist_task_signal.clear()
        self.exist_swapped_task_signal.clear()
        self.worker_healthy_live_signal.clear()
        self.cache_ready_signal.clear()
        self.swap_space_ready_signal.clear()
        self.cache_transfer_inited_signal.clear()
        self.exist_prefill_task_signal.clear()
        self.model_weights_status_signal.clear()
        self.prefix_tree_status_signal.clear()
        self.kv_cache_status_signal.clear()
        if hasattr(self, "send_response_server") and self.send_response_server is not None:
            self.send_response_server.close()
        if hasattr(self, "recv_request_server") and self.recv_request_server is not None:
            self.recv_request_server.close()
        if hasattr(self, "recv_control_cmd_server") and self.recv_control_cmd_server is not None:
            self.recv_control_cmd_server.close()

    # Moved from async_llm to common_engine
    def _worker_processes_ready(self):
        """
        Judge whether all worker processes are ready.
        """
        if np.sum(self.worker_ready_signal.value) == self.cfg.worker_num_per_node:
            return True
        return False

    def _init_worker_signals(self):
        """
        Initialize shared memory to indicate engine status
        """
        # worker_ready_signal: lets worker processes detect whether the engine has finished starting up
        worker_ready_signal_data = np.zeros(shape=[self.cfg.worker_num_per_node], dtype=np.int32)
        self.worker_ready_signal = IPCSignal(
            name="worker_ready_signal",
            array=worker_ready_signal_data,
            dtype=np.int32,
            suffix=self.ipc_signal_suffix,
            create=True,
        )

        # launched_cache_manager_signal: used to detect whether the engine has launched the cache_manager
        if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
            launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
            self.launched_cache_manager_signal = IPCSignal(
                name="launched_cache_manager_signal",
                array=launched_cache_manager_signal_data,
                dtype=np.int32,
                suffix=self.ipc_signal_suffix,
                create=True,
            )

        # launched_expert_service_signal: Used to sense whether each expert_service is started successfully
        if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
            launched_expert_service_signal_data = np.zeros(
                shape=[self.cfg.parallel_config.data_parallel_size // self.cfg.nnode], dtype=np.int32
            )
            self.launched_expert_service_signal = IPCSignal(
                name="launched_expert_service_signal",
                array=launched_expert_service_signal_data,
                dtype=np.int32,
                suffix=self.ipc_signal_suffix,
                create=True,
            )

        # loaded_model_signal: Used to detect whether each worker has completed model loading
        loaded_model_signal_data = np.zeros([1], dtype=np.int32)
        self.loaded_model_signal = IPCSignal(
            name="loaded_model_signal",
            array=loaded_model_signal_data,
            dtype=np.int32,
            suffix=self.ipc_signal_suffix,
            create=True,
        )
        if self.do_profile:
            if paddle.is_compiled_with_custom_device("iluvatar_gpu"):
                get_profile_block_num = np.zeros([self.cfg.worker_num_per_node], dtype=np.int32)
            else:
                get_profile_block_num = np.zeros([1], dtype=np.int32)
            self.get_profile_block_num_signal = IPCSignal(
                name="get_profile_block_num",
                array=get_profile_block_num,
                dtype=np.int32,
                suffix=self.ipc_signal_suffix,
                create=True,
            )

    def _setting_environ_variables(self):
        """
        Configure environment variables.
        """
        variables = {
            "ENABLE_FASTDEPLOY_LOAD_MODEL_CONCURRENCY": 0,
            "LOAD_STATE_DICT_THREAD_NUM": len(self.cfg.parallel_config.device_ids.split(",")),
            "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python",
            "FLAGS_use_append_attn": 1,
            "NCCL_ALGO": "Ring",
            "FLAGS_max_partition_size": int(os.getenv("FLAGS_max_partition_size", 1024)),
            "OMP_NUM_THREADS": 3,
        }

        # environment variables needed by Dy2St
        variables.update(
            {
                "SOT_LOG_LEVEL": os.getenv("SOT_LOG_LEVEL", default="0"),
                "SOT_UNSAFE_CACHE_FASTPATH": os.getenv("SOT_UNSAFE_CACHE_FASTPATH", default="1"),
                "SOT_ENABLE_0_SIZE_FALLBACK": os.getenv("SOT_ENABLE_0_SIZE_FALLBACK", default="0"),
                "SOT_SPECIALIZED_DIM_NUMBERS": os.getenv("SOT_SPECIALIZED_DIM_NUMBERS", default="no"),
                "SOT_ENABLE_COMPILE_TIME_LIMIT": os.getenv("SOT_ENABLE_COMPILE_TIME_LIMIT", default="0"),
                "FLAGS_specialize_device_in_dy2st": os.getenv("FLAGS_specialize_device_in_dy2st", default="1"),
                "FLAGS_enable_async_fast_gc": os.getenv("FLAGS_enable_async_fast_gc", default="0"),
                "FLAGS_pir_interpreter_record_stream_for_gc_cache": os.getenv(
                    "FLAGS_pir_interpreter_record_stream_for_gc_cache", default="1"
                ),
                "FLAGS_parameters_persistent_mode_in_dy2st": os.getenv(
                    "FLAGS_parameters_persistent_mode_in_dy2st", default="1"
                ),
            }
        )

        if self.cfg.scheduler_config.splitwise_role != "mixed":
            if envs.ENABLE_V1_KVCACHE_SCHEDULER:
                variables["FLAGS_use_pd_disaggregation_per_chunk"] = 1
            else:
                variables["FLAGS_use_pd_disaggregation"] = 1
            # TODO dynamic load environment variable
            if self.cfg.scheduler_config.splitwise_role == "prefill":
                variables["FLAGS_fmt_write_cache_completed_signal"] = 1

        if self.cfg.model_config.enable_mm:
            variables["FLAGS_max_partition_size"] = 1024

        command_prefix = ""
        for k, v in variables.items():
            command_prefix += f"{k}={v} "
        return command_prefix

    def _start_worker_service(self):
        """
        Start the GPU worker service.
        """
        log_dir = os.getenv("FD_LOG_DIR", default="log")
        command_prefix = self._setting_environ_variables()
        current_file_path = os.path.abspath(__file__)
        current_dir_path = os.path.split(current_file_path)[0]
        # TODO
        uncache_worker_stdout = "" if os.getenv("UNCACHE_WORKER_STDOUT", "0") == "1" else "-u"
        pd_cmd = f"{command_prefix} {sys.executable} {uncache_worker_stdout} -m paddle.distributed.launch"
        pd_cmd = pd_cmd + f" --log_dir {log_dir}"
        worker_path = "../worker/worker_process.py"
        py_script = os.path.join(current_dir_path, worker_path)

        ori_vocab_size = (
            len(self.data_processor.tokenizer.sp_model)
            if hasattr(self.data_processor.tokenizer, "sp_model")
            else len(self.data_processor.tokenizer.vocab)
        )

        think_start_id = self.data_processor.tokenizer.get_vocab().get("<think>", -1)
        if think_start_id >= 0:
            self.llm_logger.info(f"Get think_start_id {think_start_id} from vocab.")
        else:
            self.llm_logger.info("No <think> token found in vocabulary, the model cannot do reasoning.")
        think_end_id = self.data_processor.tokenizer.get_vocab().get("</think>", -1)
        if think_end_id >= 0:
            self.llm_logger.info(f"Get think_end_id {think_end_id} from vocab.")
        else:
            self.llm_logger.info("No </think> token found in vocabulary, the model cannot do reasoning.")
        image_patch_id = self.data_processor.tokenizer.get_vocab().get("<|IMAGE_PLACEHOLDER|>", -1)
        line_break_id =
self.data_processor.tokenizer.get_vocab().get("\n", -1) if line_break_id < 0: line_break_ids = self.data_processor.tokenizer.encode("\n", add_special_tokens=False) if isinstance(line_break_ids, dict): line_break_ids = line_break_ids.get("input_ids") elif hasattr(line_break_ids, "input_ids"): line_break_ids = line_break_ids.input_ids if line_break_ids: if isinstance(line_break_ids, (list, tuple)): first = line_break_ids[0] if isinstance(first, (list, tuple)): line_break_id = int(first[0]) if first else -1 else: line_break_id = int(first) else: line_break_id = int(line_break_ids) if line_break_id >= 0: self.llm_logger.info(f"Get line_break_id {line_break_id} from tokenizer.") ports = ",".join(map(str, self.cfg.parallel_config.engine_worker_queue_port)) ips = None if self.cfg.ips is not None: ips = ",".join(self.cfg.ips) arguments = ( f" --devices {self.cfg.parallel_config.device_ids} {py_script}" f" --max_num_seqs {self.cfg.scheduler_config.max_num_seqs} --max_model_len {self.cfg.model_config.max_model_len}" f" --gpu_memory_utilization {self.cfg.cache_config.gpu_memory_utilization}" f" --model {self.cfg.model_config.model!s}" f" --device_ids {self.cfg.parallel_config.device_ids}" f" --tensor_parallel_size {self.cfg.parallel_config.tensor_parallel_size}" f" --engine_worker_queue_port {ports}" f" --pod_ip {self.cfg.master_ip}" f" --block_size {self.cfg.cache_config.block_size}" f" --enc_dec_block_num {self.cfg.cache_config.enc_dec_block_num}" f" --eos_tokens_lens {self.data_processor.eos_token_id_len}" f" --pad_token_id {self.data_processor.pad_token_id}" f" --engine_pid {self.cfg.parallel_config.engine_worker_queue_port[0]}" f" --max_num_batched_tokens {self.cfg.scheduler_config.max_num_batched_tokens}" f" --splitwise_role {self.cfg.scheduler_config.splitwise_role}" f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio}" f" --expert_parallel_size {self.cfg.parallel_config.expert_parallel_size}" f" --chunked_moe_size {self.cfg.parallel_config.chunked_moe_size}" f" --data_parallel_size {self.cfg.parallel_config.data_parallel_size}" f" --quantization '{json.dumps(self.cfg.model_config.quantization)}'" f" --ori_vocab_size {ori_vocab_size}" f" --think_start_id {think_start_id}" f" --think_end_id {think_end_id}" f" --image_patch_id {image_patch_id}" f" --line_break_id {line_break_id}" f" --speculative_config '{self.cfg.speculative_config.to_json_string()}'" f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'" f" --guided_decoding_backend {self.cfg.structured_outputs_config.guided_decoding_backend}" f" --load_strategy {self.cfg.load_config.load_strategy}" f" --rsync_config '{json.dumps(self.cfg.load_config.rsync_config)}'" f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'" f" --reasoning_parser {self.cfg.structured_outputs_config.reasoning_parser}" f" --load_choices {self.cfg.load_config.load_choices}" f" --plas_attention_config '{self.cfg.plas_attention_config.to_json_string()}'" f" --ips {ips}" f" --cache-transfer-protocol {self.cfg.cache_config.cache_transfer_protocol}" f" --runner {self.cfg.model_config.runner}" f" --convert {self.cfg.model_config.convert}" f" --override-pooler-config {self.cfg.model_config.override_pooler_config}" f" --logprobs_mode {self.cfg.model_config.logprobs_mode}" f" --max_logprobs {self.cfg.model_config.max_logprobs}" f" --eplb_config '{self.cfg.eplb_config.to_json_string()}'" f" --num_cpu_blocks {self.cfg.cache_config.num_cpu_blocks}" f" --deploy_modality {self.cfg.deploy_modality.value}" ) if 
        self.cfg.structured_outputs_config.logits_processors is not None:
            arguments += f" --logits-processors {' '.join(self.cfg.structured_outputs_config.logits_processors)}"
        if self.mm_max_tokens_per_item is not None:
            arguments += f" --mm_max_tokens_per_item '{json.dumps(self.mm_max_tokens_per_item)}'"

        worker_store_true_flag = {
            "enable_expert_parallel": self.cfg.parallel_config.enable_expert_parallel,
            "enable_prefix_caching": self.cfg.cache_config.enable_prefix_caching,
            "enable_chunked_prefill": self.cfg.cache_config.enable_chunked_prefill,
            "do_profile": self.do_profile,
            "dynamic_load_weight": self.cfg.load_config.dynamic_load_weight,
            "disable_any_whitespace": self.cfg.structured_outputs_config.disable_any_whitespace,
            "disable_custom_all_reduce": self.cfg.parallel_config.disable_custom_all_reduce,
            "use_internode_ll_two_stage": self.cfg.parallel_config.use_internode_ll_two_stage,
            "disable_sequence_parallel_moe": self.cfg.parallel_config.disable_sequence_parallel_moe,
            "enable_logprob": self.cfg.model_config.enable_logprob,
            "lm_head_fp32": self.cfg.model_config.lm_head_fp32,
            "moe_gate_fp32": self.cfg.model_config.moe_gate_fp32,
            "enable_entropy": self.cfg.model_config.enable_entropy,
            "enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
        }
        for worker_flag, value in worker_store_true_flag.items():
            if value:
                arguments = arguments + f" --{worker_flag}"

        worker_default_none_flag = {
            "num_gpu_blocks_override": self.cfg.cache_config.num_gpu_blocks_override,
            "kvcache_storage_backend": self.cfg.cache_config.kvcache_storage_backend,
        }
        for worker_flag, value in worker_default_none_flag.items():
            if value:
                arguments = arguments + f" --{worker_flag} {value}"

        if self.cfg.nnode > 1:
            pd_cmd = pd_cmd + f" --ips {ips} --nnodes {len(self.cfg.ips)}"
        pd_cmd = pd_cmd + arguments + f" 2>{log_dir}/launch_worker.log"
        self.llm_logger.info(f"Launch worker service command: {pd_cmd}")
        p = subprocess.Popen(
            pd_cmd,
            stdout=subprocess.PIPE,
            shell=True,
            preexec_fn=os.setsid,
        )
        return p

    def _stop_profile(self):
        """
        Stop profiling of the model server and reset variables.
        """
        self.do_profile = 0
        while self.get_profile_block_num_signal.value[0] == 0:
            if hasattr(self, "worker_proc") and self.worker_proc is not None:
                if self.worker_proc.poll() is not None:
                    raise RuntimeError(
                        "Worker process failed to start. Please check log/workerlog.* for details."
                    )
            time.sleep(1)
        num_gpu_blocks = self.get_profile_block_num_signal.value[0]
        self.cfg.cache_config.reset(num_gpu_blocks)
        self.resource_manager.reset_cache_config(self.cfg.cache_config)
        if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
            device_ids = self.cfg.parallel_config.device_ids.split(",")
            self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix)

    def check_health(self, time_interval_threashold=30):
        """
        Check the health of the model server by checking whether all workers are alive.
""" if self.worker_healthy_live_signal.value[0]: elapsed_time = time.time() - self.worker_healthy_live_signal.value[0] if elapsed_time > time_interval_threashold: return False, "Worker Service Not Healthy" return True, "" def launch_components(self): if self.cfg.scheduler_config.splitwise_role != "mixed": # 单机逻辑 self.splitwise_receive_thread = threading.Thread(target=self.split_connector.start_receiver, args=()) self.splitwise_receive_thread.daemon = True self.splitwise_receive_thread.start() role = self.cfg.scheduler_config.splitwise_role host_ip = self.cfg.host_ip if self.cfg.scheduler_config.name == "splitwise": self.scheduler.start(role, host_ip, self.cfg.register_info) elif self.cfg.scheduler_config.name == "dp": self.scheduler.start( self.cfg.node_rank * self.cfg.worker_num_per_node % self.cfg.worker_num_per_node, ) if not envs.FD_ENABLE_MULTI_API_SERVER: if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1: self.launched_expert_service_signal.value[0] = 1 self.dp_processed = [] self.dp_engine_worker_queue_server = [] for i in range( 1, self.cfg.parallel_config.data_parallel_size // self.cfg.nnode, ): if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM: address = ( self.cfg.master_ip, int(self.cfg.parallel_config.engine_worker_queue_port[i]), ) else: address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.engine_worker_queue_port[i]}.sock" self.llm_logger.info(f"dp start queue service {address}") self.dp_engine_worker_queue_server.append( EngineWorkerQueue( address=address, is_server=True, num_client=self.cfg.parallel_config.tensor_parallel_size, local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, ) ) from fastdeploy.engine.expert_service import ( start_data_parallel_service, ) self.dp_processed.append( multiprocessing.Process( target=start_data_parallel_service, args=( self.cfg, i, ), ) ) self.llm_logger.info( f"Engine is initialized successfully with {self.cfg.parallel_config.tensor_parallel_size}" + f" data parallel id {i}" ) self.dp_processed[-1].start() while self.launched_expert_service_signal.value[i] == 0: time.sleep(1) def check_worker_initialize_status(self): """ Check the initlialize status of workers by stdout logging """ def detect_thread(): for line in self.worker_proc.stdout: line = line.decode("utf-8", errors="ignore") if self.worker_init_status.get("finished", False): break if match := re.search( r"Loading (?:fastsafetensors |safetensors )?checkpoint shards:\s*(\d+)", line, ): self.worker_init_status["weight_loadding"] = eval(match.group(1)) * 1.0 / 100 elif (match := re.search(r"Start load layer (\d+)", line)) or ( match := re.search(r"set state for layer (\d+)", line) ): progress = eval(match.group(1)) * 1.0 / self.cfg.model_config.num_hidden_layers self.worker_init_status["layer_loadding"] = progress if self.worker_init_status["layer_loadding"] == self.cfg.model_config.num_hidden_layers - 1: self.worker_init_status["finished"] = True self.checking_worker_status_thread = threading.Thread(target=detect_thread, daemon=True) self.checking_worker_status_thread.start() # display weight loadding progress with tqdm(total=100, desc="Loading Weights") as pbar: progress = 0 while progress < 100: progress = int(self.worker_init_status.get("weight_loadding", 0) * 100) if self.worker_init_status.get("layer_loadding", 0) > 0 or self._worker_processes_ready(): progress = 100 pbar.update(progress - pbar.n) pbar.refresh() time.sleep(0.5) if self.worker_proc.poll() is not None: return False # display layer loadding 
progress with tqdm(total=100, desc="Loading Layers") as pbar: progress = 0 while progress < 100: progress = int(self.worker_init_status.get("layer_loadding", 0) * 100) if self._worker_processes_ready(): progress = 100 pbar.update(progress - pbar.n) pbar.refresh() time.sleep(0.5) if self.worker_proc.poll() is not None: return False self.worker_init_status["finished"] = True try: self.checking_worker_status_thread.join(timeout=1) except Exception: pass return True
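    # --- Illustrative sketch only (not executed by the engine) ---------------------------------
    # A minimal example of how the launch helpers above are meant to fit together, assuming an
    # already-constructed EngineService instance named `engine`; the variable names here are
    # hypothetical and the real call sites live in the engine startup path outside this section:
    #
    #     engine.worker_proc = engine._start_worker_service()   # spawn paddle.distributed.launch
    #     if not engine.check_worker_initialize_status():       # stream worker stdout into tqdm bars
    #         raise RuntimeError("worker processes exited while loading the model")
    #     if engine.do_profile:
    #         engine._stop_profile()                            # read profiled block num, start cache service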