mirror of https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
7707be8384
* [Feature][KVCache] Support cache manager v1 architecture
  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Update cache manager and related modules
  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* chore: update cache_manager and related modules
  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: add node to evictable set in complete_swap_to_device

  When a node transitions from SWAP_TO_DEVICE to DEVICE via complete_swap_to_device, it was not being added to the _evictable_device set. This caused nodes with ref_count=0 to become "orphaned": absent from every evictable set despite having cache_status=DEVICE.

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* feat: update cache manager v1 and related modules

  - Add new cache_manager.py with cache management functionality
  - Add radix_tree.py for prefix caching
  - Update block_pool.py and metadata.py
  - Update request.py and resource_manager_v1.py for scheduling
  - Update gpu_model_runner.py for GPU model execution

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* feat(cache): add cache controller v1 implementation

  - Add CacheController class for cache management
  - Update config.py with cache-related configurations
  - Refactor gpu_model_runner.py for improved cache handling

* feat(cache_manager): update cache manager v1

* fix(cache_manager): fix src/dst block_ids logic for the swap_cache H2D/D2H directions and clean up ForwardMeta

  ## Motivation
  Fix swap_cache_optimized.cu using the wrong src/dst block_ids in the H2D direction, and remove the deprecated cache_controller field from ForwardMeta.

  ## Modifications
  - fix: swap_cache_optimized.cu now selects src/dst block_ids according to the D2H template parameter, fixing the inverted src/dst in the H2D direction (in both SwapCachePerLayerImpl and SwapCacheAllLayersBatchImpl)
  - refactor: cache_manager/v1/__init__.py imports LayerSwapTimeoutError from cache_utils (its correct origin) instead of cache_controller
  - refactor: remove the deprecated cache_controller field from ForwardMeta
  - refactor: remove the corresponding cache_controller assignment in gpu_model_runner.py
  - test: add unit tests in tests/cache_manager/v1/test_swap_cache_ops.py

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* feat(cache_manager): refactor cache manager v1 and optimize swap ops

  ## Motivation
  Refactor and optimize cache manager v1, streamlining the code structure and improving maintainability.

  ## Modifications
  - Refactor transfer_manager.py, substantially simplifying its logic
  - Optimize the swap_cache_optimized.cu GPU kernel implementation
  - Adjust cache_manager.py and cache_controller.py logic; fix the missing free_device_blocks method
  - Update block_pool.py, cache_utils.py, metadata.py and radix_tree.py
  - Streamline the related calls in gpu_model_runner.py, forward_meta.py and attention.py
  - Update the corresponding unit tests (test_cache_controller, test_swap_cache_ops, test_transfer_manager)
  - Adjust the related configuration items in config.py

* [KVCache][MTP] Support MTP KV Cache initialization and multimodal hashing under cache_manager_v1

  ## Motivation
  On the enable_cache_manager_v1 path, the MTP (speculative decode) KV Cache should be managed uniformly by CacheController so it can reuse the swap/transfer machinery; this also fixes block hashes not carrying multimodal extra_keys in multimodal scenarios.

  ## Modifications
  - `cache_controller.py`
    - Add `initialize_mtp_kv_cache`: initialize the MTP KV Cache through CacheController and register it in cache_kvs_map, so transfer_manager automatically covers the MTP layers
    - `num_layers` in `initialize_host_cache` now includes the extra MTP cache layers, so the host cache also reserves enough space for MTP
    - Rename `_free_gpu_cache` to `free_gpu_cache` (externally callable)
  - `cache_utils.py`
    - Add `get_block_hash_extra_keys`: extract the multimodal hash information within a single block, matching PrefixCacheManager's multimodal extra_keys logic
    - `get_request_block_hasher` now passes extra_keys to hash_block_tokens, fixing inaccurate prefix-cache hit rates in multimodal scenarios
  - `spec_decode/mtp.py`
    - `update_mtp_block_num` gains a `skip_cache_init` parameter to avoid re-initializing the MTP KV Cache on the v1 cache manager path
  - `gpu_model_runner.py`
    - `initialize_kv_cache` (v1) path: after the main model cache is initialized, call `cache_controller.initialize_mtp_kv_cache` to create the MTP cache
    - `clear_cache` / `wakeup` / `reset` paths: respect the `enable_cache_manager_v1` flag and skip the duplicate proposer.initialize_kv_cache call

  ## Usage or Command
  ```bash
  # Launch an inference service with MTP + cache_manager_v1 (example)
  bash run.sh
  ```

* fix(cache_manager): multi-GPU fix, mm hash boundary fix, and remove batch ops

  1. Fix CuPy stream/event creation for multi-GPU: wrap all stream operations with a cp.cuda.Device(device_id) context to ensure streams/events are bound to the correct device, preventing cross-device errors in multi-GPU setups.
  2. Remove cudaSetDevice from SwapCacheAllLayers (now handled by the cupy context).
  3. Remove the swap_cache_all_layers_batch op: simplify the implementation by removing the batch upload variant; all-layer transfers now use the standard swap_cache_all_layers with the cupy device context.
  4. Fix the mm hash boundary comparison in get_block_hash_extra_keys: change strict less-than (<) to less-than-or-equal (<=) so that multimodal items ending exactly at a block start are correctly excluded.
  5. Extract config fields to KVCacheBase: model_config, cache_config, quant_config and parallel_config are now set in the base class __init__ to avoid duplication in the CacheController and CacheManager subclasses.
  6. Translate the metadata.py docstrings from Chinese to English for broader contributor accessibility.
  7. Add test_cache_utils.py: comprehensive unit tests for get_block_hash_extra_keys covering all boundary and overlap scenarios.
  8. Expand the test suite: test_request.py cache-field tests, test_radix_tree.py backup-candidate tests, and multi-GPU and concurrent-operation tests in test_transfer_manager.py and test_cache_manager.py.

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix List import and move write_policy normalization to CacheManager

  ## Motivation
  Fix two issues:
  1. `List` is not imported in `fastdeploy/engine/request.py`, causing a pre-commit F821 error
  2. The `write_policy` normalization logic (`write_through` → `write_through_selective`) does not belong in `FDConfig`; move it into `CacheManager.__init__` so it only affects the internals of Cache Manager V1

  ## Modifications
  - `fastdeploy/engine/request.py`: add `List` to the `typing` import and remove the duplicate `CacheSwapMetadata` TYPE_CHECKING import, fixing F821/F811
  - `fastdeploy/config.py`: remove the `write_policy` normalization logic
  - `fastdeploy/cache_manager/v1/cache_manager.py`: move the normalization logic into `CacheManager.__init__`

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix pre-commit code style issues

  ## Motivation
  Fix the failing CI pre-commit code style checks.

  ## Modifications
  - `fastdeploy/engine/common_engine.py`: black formatting
  - `fastdeploy/worker/worker_process.py`: black formatting + isort fix
  - `fastdeploy/cache_manager/v1/storage/__init__.py`: isort fix
  - `fastdeploy/worker/gpu_worker.py`: isort fix

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [Feature][KVCache] update cache_manager_v1 modules

  ## Motivation
  Update the Cache Manager V1 modules: complete the copyright information and improve module structure and maintainability.

  ## Modifications
  - `fastdeploy/cache_manager/v1/` modules: add copyright headers, improve code structure
  - `fastdeploy/config.py`: configuration updates
  - `fastdeploy/engine/sched/resource_manager_v1.py`: scheduling-related updates

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [Feature][KVCache] add BatchRequest.from_tasks and refactor worker task parsing

  ## Motivation
  Consolidate the duplicated task-parsing logic in worker_process into BatchRequest, reducing redundancy and improving maintainability.

  ## Modifications
  - `fastdeploy/engine/request.py`: add the `BatchRequest.from_tasks()` classmethod, which uniformly splits task_queue tasks into inference requests and control requests
  - `fastdeploy/worker/worker_process.py`: use `BatchRequest.from_tasks()` in place of the inline parsing logic, and fix the duplicated control_reqs handling block

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [Feature][KVCache] add NUMA affinity for host cache and skip swap cache tests

  ## Motivation
  Improve the NUMA affinity of host cache memory allocation to reduce cross-NUMA access latency; also skip the swap cache ops tests (unsupported in the current environment).

  ## Modifications
  - `fastdeploy/cache_manager/v1/cache_controller.py`:
    - Add a `_get_numa_node_for_gpu()` method that obtains the GPU's NUMA node via nvidia-smi or sysfs
    - Add a `_bind_to_closest_numa_node()` method that binds the current thread to the NUMA node closest to the GPU
    - Call the NUMA binding in `initialize_host_cache()` to improve H2D transfer performance
  - `tests/cache_manager/v1/test_swap_cache_ops.py`: skip all test classes (`TestSwapCacheAllLayersCorrectness`, `TestSwapCacheAllLayersPerformance`, `TestSwapCacheRandomBlockIndices`)

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix unittest failures for cache_manager_v1

  Three unit tests failed due to interface changes or mocking issues and needed fixing.

  - tests/distributed/chunked_moe.py: `setup_model_runner` uses `__new__` to skip `__init__`; add `enable_cache_manager_v1 = False` to fix the `AttributeError`
  - tests/engine/test_resource_manager.py: `PrefixCacheManager` is imported locally, so change the `patch` path to its definition site, `fastdeploy.cache_manager.prefix_cache_manager.PrefixCacheManager`
  - tests/v1/test_resource_manager_v1.py: the fourth parameter of `_trigger_preempt` changed from `list` to `BatchRequest`; update the test arguments and assertions

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] remove debug logging code

  ## Modifications
  - fastdeploy/engine/request.py: remove the debug logger and the debug logging in prompt_hashes
  - fastdeploy/worker/worker_process.py: remove the debug import and print statements in `__main__`

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix cupy device id caching and pickle for _match_result

  ## Motivation
  Fix two bugs:
  1. Calling `cp.cuda.runtime.getDevice()` on every call in `transfer_manager.py` is fragile; cache it as an instance variable at initialization so subsequent operations use a consistent device ID.
  2. `__getstate__` in `request.py` does not skip `_match_result`, whose BlockNode tree contains parent/child circular references that trigger a `RecursionError` during pickling; also add `__setstate__` to ensure the fields are restored to safe defaults after unpickling.

  ## Modifications
  - `transfer_manager.py`: call `cp.cuda.runtime.getDevice()` at initialization and cache it in `self._cupy_device_id`; subsequent `with cp.cuda.Device(...)` blocks and log messages use the cached value.
  - `request.py`:
    - Add `_match_result` to the `_SKIP_KEYS` skip set in `__getstate__`, avoiding pickle failures from the circular references.
    - Add `__setstate__`, restoring `_block_hasher` and `_match_result` to `None` after unpickling.

  ## Usage or Command

* fix(test): fix unit test errors for _trigger_preempt and wakeup with MTP

  - Use BatchRequest instead of list in test_trigger_preempt_records_tasks
  - Add the missing enable_cache_manager_v1 attr in TestSleepWakeupBehavior._make_runner

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix gpu_free_block_list returning wrong block IDs

  ## Motivation
  The compatibility property `gpu_free_block_list` mistakenly used `list(range(N))`, passing the return value of `available_blocks()` to `range()` as an integer, so it returned the fake list `[0, 1, ..., N-1]` rather than the real free block IDs.

  ## Modifications
  - `cache_manager/v1/cache_manager.py`: change `list(range(self._device_pool.available_blocks()))` to `list(self._device_pool.available_blocks())`

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix gpu_free_block_list returning an int, causing a TypeError

  ## Motivation
  The `gpu_free_block_list` property calls `BlockPool.available_blocks()`, which returns an int (the number of free blocks); wrapping an int in `list()` raises `TypeError: 'int' object is not iterable`.

  ## Modifications
  Change `list(self._device_pool.available_blocks())` to `list(self._device_pool._free_blocks)` to return the list of free block indices directly.

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [KVCache][CacheManager] adapt pause/sleep/free_cache operations for the V1 CacheManager

  ## Motivation
  The V1 CacheManager introduces a new reset_cache() interface; the pause and sleep operations must be adapted, and free_cache needs an optional clear_storage parameter.

  ## Modifications
  - cache_controller.py: free_cache gains a clear_storage parameter (default False); `_clear_storage()` is only called when clear_storage=True, avoiding unnecessary storage clearing
  - common_engine.py: when ENABLE_V1_KVCACHE_MANAGER is set, the pause and sleep operations use cache_manager.reset_cache() instead of the old reset() and pause_transfer logic
  - gpu_model_runner.py: during sleep, only clear the MTP cache when not using the V1 cache manager

  ## Usage or Command
  ```bash
  # Launch the service (V1 CacheManager)
  python -m fastdeploy.entrypoints.openai.api_server \
      --enable-v1-kvcache-manager \
      ...
  ```

* [BugFix][KVCache] fix missing enable_cache_manager_v1 in test mocks and remove unused select_blocks_for_backup

  - Remove the unused `select_blocks_for_backup` method from radix_tree.py
  - Fix the `match_prefix` default param `skip_storage=True` and the log order in cache_manager.py
  - Sync test_gpu_model_runner.py with upstream/develop (add TestInsertTasksV1SplitwiseSuffix)
  - Add `enable_cache_manager_v1=False` to all mock runners to fix the AttributeError in CI

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] simplify _free_blocks in ResourceManagerV1 for non-v1 path

  Remove the redundant prefix_caching branch in the else path; always call recycle_gpu_blocks with the full block_tables for the non-cache-manager-v1 case.

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [KVCache][Optimization][BugFix] fix and optimize block_pool, cache_manager, transfer_manager, request

  ## Motivation
  Fix several code quality issues in cache_manager v1, improving performance and eliminating latent type-inconsistency bugs.

  ## Modifications
  1. **block_pool.py**: `BlockPool.allocate` replaces the per-block pop loop with a slice + bulk set.update, eliminating Python loop overhead (O(n) → O(k), batch operations at the C level)
  2. **cache_manager.py**: `match_prefix` writes an empty `MatchResult()` before the early return when prefix caching is disabled, so callers never crash dereferencing `_match_result=None`
  3. **transfer_manager.py**: `_build_device_layer_indices` also resets the four layer index lists when `_cache_kvs_map` is empty, preventing stale tensors from being used by the swap ops
  4. **request.py**: `BatchRequest.append_swap_metadata` / `append_evict_metadata` now construct `CacheSwapMetadata` with `src_type`/`dst_type` as `CacheLevel` enums rather than strings, matching the declared field types; add the `CacheLevel` import; correct the return type annotation of the `match_result` property to `Optional[MatchResult]`
  5. **resource_manager_v1.py**: downgrade the `_allocate_gpu_blocks` log from `INFO` to `DEBUG`, eliminating log noise on the hot scheduling path
  6. **tests/engine/test_request.py**: update the `src_type`/`dst_type` assertions to `CacheLevel` enum values and add the `CacheLevel` import

  ## Usage or Command
  Unit tests:
  ```bash
  source .venv/py310/bin/activate
  cd baidu/FastDeploy
  python -m pytest tests/cache_manager/v1/test_cache_manager.py -v
  python -m pytest tests/cache_manager/v1/test_transfer_manager.py -v
  python -m pytest tests/engine/test_request.py -v
  ```

* [BugFix][KVCache] Fix BlockPool.allocate returns all blocks when num_blocks=0

  ## Motivation
  When `allocate(num_blocks=0)` is called, Python's negative-index trap causes a severe bug: `-0 == 0`, so `self._free_blocks[-0:]` is equivalent to `self._free_blocks[0:]` and returns (and clears) the entire free block list instead of returning an empty list.

  ## Modifications
  Add an early check for `num_blocks == 0` in `BlockPool.allocate` that returns `[]` directly, avoiding the negative-index trap.

  ## Usage or Command
  ```bash
  # Run the related unit tests to verify the fix
  python -m pytest tests/cache_manager/v1/test_cache_manager.py -vv -s
  ```

* [KVCache][Test] add unit tests for cache_manager v1 modules

  ## Motivation
  Complete the unit test coverage of the cache_manager/v1 modules, ensuring the core methods have full test protection.

  ## Modifications
  Add or extend the following test files; all 326 cases pass:
  - tests/cache_manager/v1/test_block_pool.py (new): covers BlockPool.get_metadata/set_metadata/resize, DeviceBlockPool/HostBlockPool
  - tests/cache_manager/v1/test_metadata.py (new): covers BlockNode, RadixTreeStats, MatchResult, CacheSwapMetadata, AsyncTaskHandler
  - tests/cache_manager/v1/test_cache_utils.py (extended): add hash_block_tokens, get_request_block_hasher, LayerDoneCounter time tracking and internal helper methods
  - tests/cache_manager/v1/test_radix_tree.py (extended): add a dedicated TestCompleteSwapToDevice test class (6 cases)
  - tests/cache_manager/v1/test_cache_manager.py (extended): add offload_to_host, load_from_host, the pending backup series, prepare_prefetch_metadata
  - tests/cache_manager/v1/test_transfer_manager.py (extended): add the _swap_single_layer validation path, sync_input/output_stream, record_input_stream_event

  ## Usage or Command
  ```bash
  # Run all newly added unit tests
  source .venv/py310/bin/activate
  python -m pytest tests/cache_manager/v1/test_block_pool.py \
      tests/cache_manager/v1/test_metadata.py \
      tests/cache_manager/v1/test_cache_utils.py \
      tests/cache_manager/v1/test_radix_tree.py \
      tests/cache_manager/v1/test_cache_manager.py \
      tests/cache_manager/v1/test_transfer_manager.py -v
  # Expected result: 326 passed
  ```

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
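Two of the `block_pool.py` fixes above (the `-0` slicing trap and the slice + bulk `set.update` allocation) can be illustrated together. The following is a minimal sketch, not the actual FastDeploy implementation: the `_free_blocks`/`allocate` names come from the commit message, while the `_allocated` field and the `ValueError` path are assumptions for the sake of a self-contained example.

```python
from typing import List, Set


class BlockPool:
    """Minimal sketch of a free-block pool (hypothetical fields)."""

    def __init__(self, num_blocks: int) -> None:
        self._free_blocks: List[int] = list(range(num_blocks))
        self._allocated: Set[int] = set()

    def allocate(self, num_blocks: int) -> List[int]:
        # Guard against the Python negative-index trap: -0 == 0, so
        # self._free_blocks[-0:] would return (and clear) the ENTIRE
        # free list instead of an empty list.
        if num_blocks == 0:
            return []
        if num_blocks > len(self._free_blocks):
            raise ValueError("not enough free blocks")
        # Slice + bulk set.update instead of a per-block pop loop:
        # both operations run at the C level, avoiding Python iteration.
        blocks = self._free_blocks[-num_blocks:]
        del self._free_blocks[-num_blocks:]
        self._allocated.update(blocks)
        return blocks
```

Without the `num_blocks == 0` guard, `pool.allocate(0)` would drain the whole pool, which is exactly the bug the "returns all blocks when num_blocks=0" commit fixes.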
2709 lines
129 KiB
Python
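The mm-hash boundary fix (strict `<` changed to `<=` in get_block_hash_extra_keys) amounts to a half-open interval-overlap rule. The sketch below illustrates that rule only; the function name, signature, and item representation are assumptions, not the real FastDeploy API.

```python
from typing import List, Tuple


def items_overlapping_block(
    mm_items: List[Tuple[int, int]], block_start: int, block_end: int
) -> List[Tuple[int, int]]:
    """Return the multimodal items overlapping the token range [block_start, block_end).

    Each item is a (start, end) half-open token span. Per the boundary fix,
    an item is excluded when end <= block_start, so an item ending exactly
    at the block start is correctly excluded (strict < would keep it).
    """
    return [
        (start, end)
        for start, end in mm_items
        if not (end <= block_start or start >= block_end)
    ]
```

For a block spanning tokens [4, 8), an image occupying tokens [0, 4) ends exactly at the block start and contributes nothing to that block's hash.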
"""
|
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|
from __future__ import annotations

import asyncio
import collections
import copy
import json
import multiprocessing
import os
import re
import signal
import subprocess
import sys
import threading
import time
import traceback
import weakref
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import paddle
import zmq
from tqdm import tqdm

import fastdeploy.metrics.trace as tracing
from fastdeploy.cache_manager.cache_data import CacheStatus
from fastdeploy.config import FDConfig
from fastdeploy.engine.register_manager import RegisterManager
from fastdeploy.engine.request import (
    CompletionOutput,
    ControlRequest,
    ControlResponse,
    Request,
    RequestMetrics,
    RequestOutput,
    RequestStatus,
    RequestType,
)
from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
from fastdeploy.engine.sched.scheduler_metrics_logger import SchedulerMetricsLogger
from fastdeploy.eplb.utils import init_eplb_signals
from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.inter_communicator import (
    EngineCacheQueue,
    EngineWorkerQueue,
    IPCSignal,
    ZmqIpcServer,
    ZmqTcpServer,
)
from fastdeploy.inter_communicator.fmq import FMQ
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.model_executor.guided_decoding import schema_checker
from fastdeploy.plugins.token_processor import load_token_processor_plugins
from fastdeploy.spec_decode import SpecMethod
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
from fastdeploy.trace.constants import LoggingEventName
from fastdeploy.trace.trace_logger import print as trace_print
from fastdeploy.utils import EngineError, console_logger, envs, get_logger, llm_logger

try:
    TokenProcessor = load_token_processor_plugins()
    llm_logger.info(f"TokenProcessor plugin {TokenProcessor} loaded")
except:
    from fastdeploy.output.token_processor import TokenProcessor


def _read_latest_worker_traceback(log_dir: str) -> Optional[str]:
    """Read the latest traceback from the workerlog.* files."""

    try:
        candidates = sorted(Path(log_dir).glob("workerlog.*"), key=lambda path: path.stat().st_mtime, reverse=True)
    except OSError:
        return None

    for path in candidates:
        try:
            content = path.read_text(encoding="utf-8", errors="ignore")
        except OSError:
            continue

        marker = "Traceback (most recent call last):"
        start = content.rfind(marker)
        if start != -1:
            return content[start:].strip()

    return None


def _format_worker_launch_failure_message(log_dir: str) -> str:
    """Format the worker launch failure message, including traceback information."""
    message = "Failed to launch worker processes, check log/workerlog.* for more details."
    traceback_text = _read_latest_worker_traceback(log_dir)
    if traceback_text:
        return f"{message}\n{traceback_text}"
    return message


class EngineService:
    """
    Base class containing common engine functionality
    """

    def __init__(self, cfg: FDConfig, start_queue=True, use_async_llm=False):
        """
        Initializes the LLMEngine with the provided configuration.

        Args:
            cfg (Config): Config object containing all the configuration parameters.
        """
        self.cfg = cfg
        self.use_async_llm = use_async_llm

        if self.cfg.parallel_config.data_parallel_size > 1:
            self.llm_logger = get_logger(
                "fastdeploy", f"fastdeploy_dprank{self.cfg.parallel_config.local_data_parallel_id}.log"
            )
        else:
            self.llm_logger = llm_logger

        self.is_paused = False  # pause request generation
        self._pause_cond = threading.Condition()

        self._ctrl_output_queues = {}
        self._ctrl_response_mailboxes = collections.defaultdict(collections.OrderedDict)
        tp_size = cfg.parallel_config.tensor_parallel_size
        dp_index = cfg.parallel_config.local_data_parallel_id
        for tp_rank in range(tp_size):
            # create worker control response queue
            engine_worker_queue_port = self.cfg.parallel_config.local_engine_worker_queue_port
            name = f"ctrl_w2e_rank{tp_rank+tp_size*dp_index}_{engine_worker_queue_port}"
            self.llm_logger.info(f"Init Worker Control Output Queue: {name} (consumer)")
            self._ctrl_output_queues[name] = FMQ().queue(name, "consumer")

            # create cache control response queue
            if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
                engine_cache_queue_port = self.cfg.cache_config.local_cache_queue_port
                name = f"ctrl_c2e_rank{tp_rank+tp_size*dp_index}_{engine_cache_queue_port}"
                self.llm_logger.info(f"Init Cache Control Output Queue: {name} (consumer)")
                self._ctrl_output_queues[name] = FMQ().queue(name, "consumer")

        self.scheduler = cfg.scheduler_config.scheduler()
        self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"

        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
            self.llm_logger.info("Use V1 KVCache Scheduler")
            self.resource_manager = ResourceManagerV1(
                cfg.scheduler_config.max_num_seqs,
                cfg,
                cfg.parallel_config.tensor_parallel_size,
                cfg.scheduler_config.splitwise_role,
                cfg.parallel_config.local_data_parallel_id,
            )
        else:
            self.llm_logger.info("Use V0 KVCache Scheduler")
            self.resource_manager = ResourceManager(
                cfg.scheduler_config.max_num_seqs,
                cfg,
                cfg.parallel_config.tensor_parallel_size,
                cfg.scheduler_config.splitwise_role,
                cfg.parallel_config.local_data_parallel_id,
            )

        self.start_worker_queue_service(start_queue)

        os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.cfg.parallel_config.local_engine_worker_queue_port)
        self.llm_logger.info(f"INFERENCE_MSG_QUEUE_ID: {str(self.cfg.parallel_config.local_engine_worker_queue_port)}")

        self.split_connector = SplitwiseConnector(cfg, self.engine_worker_queue, self.resource_manager)
        self.token_processor = TokenProcessor(
            cfg=cfg,
            cached_generated_tokens=self.scheduler,
            engine_worker_queue=self.engine_worker_queue,
            split_connector=self.split_connector,
        )
        self.token_processor.set_resource_manager(self.resource_manager)

        self.scheduler_metrics_logger = SchedulerMetricsLogger(
            enabled=True,
            dp_rank=self.cfg.parallel_config.local_data_parallel_id,
        )
        self.resource_manager.scheduler_metrics_logger = self.scheduler_metrics_logger
        self.token_processor.set_scheduler_metrics_logger(self.scheduler_metrics_logger)

        self.partial_chunked_tokens = [0] * (self.cfg.max_num_partial_prefills + 1)
        for idx in range(1, self.cfg.max_num_partial_prefills + 1):
            self.partial_chunked_tokens[idx] = (
                (self.cfg.scheduler_config.max_num_batched_tokens // idx)
                // self.cfg.cache_config.block_size
                * self.cfg.cache_config.block_size
            )

        self.bos_client = None
        self.mm_max_tokens_per_item = None
        self.guided_decoding_checker = None
        if self.cfg.structured_outputs_config.guided_decoding_backend != "off":
            self.guided_decoding_checker = schema_checker(
                self.cfg.structured_outputs_config.guided_decoding_backend,
                disable_any_whitespace=self.cfg.structured_outputs_config.disable_any_whitespace,
            )
        self._init_worker_monitor_signals()

        # Initialize RegisterManager
        self._register_manager = RegisterManager(
            cfg=self.cfg,
            engine_worker_queue=self.engine_worker_queue,
            get_is_paused=self._get_is_paused_safe,
        )

        if self.cfg.eplb_config.enable_eplb:
            current_suffix = self.cfg.parallel_config.local_engine_worker_queue_port
            init_eplb_signals(cfg, current_suffix)

        if self.use_async_llm:
            # Add worker management attributes
            self.worker_proc = None
            self.do_profile = 1 if self.cfg.cache_config.num_gpu_blocks_override is None else 0
            self.ipc_signal_suffix = None
            self.cache_manager_processes = None

        if envs.ENABLE_V1_KVCACHE_MANAGER:
            from fastdeploy.cache_manager.v1.cache_utils import get_request_block_hasher

            self._block_hasher = get_request_block_hasher(block_size=self.cfg.cache_config.block_size)

        self._finalizer = weakref.finalize(self, self._exit_sub_services)

    def start(self, async_llm_pid=None):
        self.running = True
        console_logger.debug("Start engineService...")

        if self.use_async_llm:
            self.start_worker_service(async_llm_pid)

        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
            self.insert_task_to_worker_thread = threading.Thread(
                target=self._schedule_request_to_worker_v1, daemon=True
            )
        else:
            self.insert_task_to_worker_thread = threading.Thread(target=self._schedule_request_to_worker, daemon=True)
        self.insert_task_to_worker_thread.start()
        self.token_processor.tasks_queue = self.engine_worker_queue
        self.token_processor.run()
        if self.cfg.scheduler_config.splitwise_role == "decode":
            self._decode_process_splitwise_requests()

        self._register_manager.start()

    def start_worker_service(self, async_llm_pid=None):
        # Initialize IPC signals for worker management
        self.ipc_signal_suffix = self.cfg.parallel_config.engine_worker_queue_port[0]
        self._init_worker_signals()

        # Create data processor if not exists
        if not hasattr(self, "data_processor"):
            self.create_data_processor()

        # Launch components: scheduler, cache_manager, expert_service, etc.
        self.launch_components()

        # If block number is specified and model is deployed in splitwise mode, start cache manager first
        if (
            not self.do_profile
            and self.cfg.scheduler_config.splitwise_role != "mixed"
            and not envs.ENABLE_V1_KVCACHE_MANAGER
        ):
            device_ids = self.cfg.parallel_config.device_ids.split(",")
            self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix)

        # Start worker processes
        self.worker_proc = self._start_worker_service()
        time.sleep(5)
        self.worker_init_status = dict()
        result_container = {}

        def check_worker_initialize_status_func(res: dict):
            res["worker_is_alive"] = True
            if not self.check_worker_initialize_status():
                self.llm_logger.error(_format_worker_launch_failure_message(envs.FD_LOG_DIR))
                res["worker_is_alive"] = False

        self.check_worker_initialize_status_func_thread = threading.Thread(
            target=check_worker_initialize_status_func, args=(result_container,), daemon=True
        )
        self.check_worker_initialize_status_func_thread.start()

        # Wait model loading
        while self.loaded_model_signal.value[0] == 0:
            # Make sure worker process is alive
            if not self.check_worker_initialize_status_func_thread.is_alive():
                return False
            time.sleep(1)

        # If block number is not specified, let workers do profiling to determine the block number,
        # and then start the cache manager
        if self.do_profile:
            self._stop_profile()
        elif (
            self.cfg.scheduler_config.splitwise_role == "mixed"
            and self.cfg.cache_config.enable_prefix_caching
            and not envs.ENABLE_V1_KVCACHE_MANAGER
        ):
            device_ids = self.cfg.parallel_config.device_ids.split(",")
            self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix)

        # Worker launched
        self.check_worker_initialize_status_func_thread.join()
        if not result_container["worker_is_alive"]:
            self.llm_logger.error(_format_worker_launch_failure_message(envs.FD_LOG_DIR))
            return False

        # Start ZMQ service for communication with AsyncLLM
        if async_llm_pid:
            self.start_zmq_service(async_llm_pid)

    def create_data_processor(self):
        self.input_processor = InputPreprocessor(
            self.cfg.model_config,
            self.cfg.structured_outputs_config.reasoning_parser,
            self.cfg.limit_mm_per_prompt,
            self.cfg.mm_processor_kwargs,
            self.cfg.tool_parser,
            enable_mm_runtime=self.cfg.enable_mm_runtime,
        )
        self.data_processor = self.input_processor.create_processor()
        self.mm_max_tokens_per_item = self.data_processor.get_mm_max_tokens_per_item(
            self.cfg.model_config.max_model_len
        )
        if self.mm_max_tokens_per_item is not None:
            max_chunk_tokens = self.cfg.get_max_chunk_tokens(self.mm_max_tokens_per_item)
            self.cfg.cache_config.postprocess(max_chunk_tokens, self.cfg.scheduler_config.max_num_seqs)

    def _init_worker_monitor_signals(self):
        # exist_task_signal lets each worker process detect whether there are new tasks to process
        current_suffix = self.cfg.parallel_config.local_engine_worker_queue_port
        self.llm_logger.info(f"current_suffix: {current_suffix}")
        exist_task_signal_data = np.zeros([1], dtype=np.int32)
        self.exist_task_signal = IPCSignal(
            name="exist_task_signal",
            array=exist_task_signal_data,
            dtype=np.int32,
            suffix=current_suffix,
            create=True,
        )

        # exist_swapped_task_signal lets the engine detect whether the worker has swapped tasks
        exist_swapped_task_signal_data = np.zeros([1], dtype=np.int32)
        self.exist_swapped_task_signal = IPCSignal(
            name="exist_swapped_task_signal",
            array=exist_swapped_task_signal_data,
            dtype=np.int32,
            suffix=current_suffix,
            create=True,
        )

        # exist_prefill_task_signal lets each worker process detect whether to run prefill
        exist_prefill_task_signal_data = np.zeros([1], dtype=np.int32)
        self.exist_prefill_task_signal = IPCSignal(
            name="exist_prefill_task_signal",
            array=exist_prefill_task_signal_data,
            dtype=np.int32,
            suffix=current_suffix,
            create=True,
        )

        engine_forward_signal_data = np.zeros([1], dtype=np.int32)
        self.engine_forward_signal = IPCSignal(
            name="engine_forward_signal",
            array=engine_forward_signal_data,
            dtype=np.int32,
            suffix=current_suffix,
            create=True,
        )

        # worker_live_signal lets the engine detect whether each worker process is alive, recording the time of every step
        worker_healthy_live_recorded_time_array = np.zeros(
            shape=[min(self.cfg.worker_num_per_node, self.cfg.parallel_config.tensor_parallel_size)], dtype=np.int32
        )
        self.worker_healthy_live_signal = IPCSignal(
            name="worker_healthy_live_signal",
            array=worker_healthy_live_recorded_time_array,
            dtype=np.int32,
            suffix=current_suffix,
            create=True,
        )

        cache_ready_signal_data = np.zeros(shape=[self.cfg.parallel_config.tensor_parallel_size], dtype=np.int32)
        self.cache_ready_signal = IPCSignal(
            name="cache_ready_signal",
            array=cache_ready_signal_data,
            dtype=np.int32,
            suffix=current_suffix,
            create=True,
        )

        swap_space_ready_signal_data = np.zeros(shape=[self.cfg.parallel_config.tensor_parallel_size], dtype=np.int32)
        self.swap_space_ready_signal = IPCSignal(
            name="swap_space_ready_signal",
            array=swap_space_ready_signal_data,
            dtype=np.int32,
            suffix=current_suffix,
            create=True,
        )

        cache_transfer_inited_signal_data = np.zeros(
            shape=[self.cfg.parallel_config.tensor_parallel_size], dtype=np.int32
        )
        self.cache_transfer_inited_signal = IPCSignal(
            name="cache_transfer_inited_signal",
            array=cache_transfer_inited_signal_data,
            dtype=np.int32,
            suffix=current_suffix,
            create=True,
        )

        model_weights_status = np.zeros([1], dtype=np.int32)
        self.model_weights_status_signal = IPCSignal(
            name="model_weights_status",
            array=model_weights_status,
            dtype=np.int32,
            suffix=current_suffix,
            create=True,
        )

        prefix_tree_status = np.zeros([1], dtype=np.int32)
        self.prefix_tree_status_signal = IPCSignal(
            name="prefix_tree_status",
            array=prefix_tree_status,
            dtype=np.int32,
            suffix=current_suffix,
            create=True,
        )

        kv_cache_status = np.zeros([1], dtype=np.int32)
        self.kv_cache_status_signal = IPCSignal(
            name="kv_cache_status",
            array=kv_cache_status,
            dtype=np.int32,
            suffix=current_suffix,
            create=True,
        )

    def start_worker_queue_service(self, start_queue):
        """
        Start the queue service for engine-worker communication.
        """
        if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
            address = (self.cfg.master_ip, self.cfg.parallel_config.local_engine_worker_queue_port)
        else:
            address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.local_engine_worker_queue_port}.sock"

        if self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0":
            if start_queue:
                self.llm_logger.info(f"Starting engine worker queue server service at {address}")
                self.engine_worker_queue_server = EngineWorkerQueue(
                    address=address,
                    is_server=True,
                    num_client=self.cfg.parallel_config.tensor_parallel_size,
                    local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
                )
                # Dynamically update the port value if an anonymous port was used.
                if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
                    self.cfg.parallel_config.local_engine_worker_queue_port = (
                        self.engine_worker_queue_server.get_server_port()
                    )
                    address = (
                        self.cfg.master_ip,
                        self.cfg.parallel_config.local_engine_worker_queue_port,
                    )

            if not envs.ENABLE_V1_KVCACHE_MANAGER:
                if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
                    self.llm_logger.info(
                        f"Starting engine cache queue server service at {self.cfg.cache_config.local_cache_queue_port}"
                    )
                    self.cache_task_queue = EngineCacheQueue(
                        address=(self.cfg.master_ip, self.cfg.cache_config.local_cache_queue_port),
                        authkey=b"cache_queue_service",
                        is_server=True,
                        num_client=self.cfg.parallel_config.tensor_parallel_size,
                        client_id=-1,
                        local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
                    )
                    self.cfg.cache_config.local_cache_queue_port = self.cache_task_queue.get_server_port()

        self.engine_worker_queue = EngineWorkerQueue(
            address=address,
            is_server=False,
            num_client=self.cfg.parallel_config.tensor_parallel_size,
            client_id=0,
            local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
            local_data_parallel_id=self.cfg.parallel_config.local_data_parallel_id,
        )
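The "anonymous port" handling above follows the standard pattern: bind to port 0 so the OS picks a free port, then read the real port back (here via `get_server_port()`) and publish it into the config. In isolation:

```python
# Bind to port 0 and recover the OS-assigned port number.
import socket

srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
srv.bind(("127.0.0.1", 0))          # port 0 = let the OS choose a free port
actual_port = srv.getsockname()[1]  # the port peers must be told about
srv.close()
```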
    def insert_tasks(self, tasks: List[Request], current_id=-1):
        """
        Allocate resources and insert tasks into the engine.
        Used by the v0 kvcache scheduler.
        """
        if not isinstance(tasks, list):
            tasks = [tasks]

        self.resource_manager.check_and_free_block_tables()

        need_delete_tasks = []
        for task in tasks:
            rid = task.request_id.split("_")[0]
            trace_carrier = task.trace_carrier
            if trace_carrier:
                tracing.trace_set_proc_propagate_context(rid, trace_carrier)
                task.trace_carrier = tracing.trace_get_proc_propagate_context(rid)
            if self.cfg.scheduler_config.splitwise_role == "prefill":
                status, msg = self.split_connector.check_decode_allocated(task)
                if status:
                    task.metrics.ask_decode_resource_finish_time = time.time()
                else:
                    self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
                    self.scheduler.put_results(
                        [
                            RequestOutput(
                                request_id=task.request_id,
                                finished=True,
                                error_code=500,
                                error_msg=msg,
                            )
                        ]
                    )
                    need_delete_tasks.append(task)
                    continue
        for tmp_task in need_delete_tasks:
            tasks.remove(tmp_task)

        for item in tasks:
            trace_print(LoggingEventName.RESOURCE_ALLOCATE_START, item.request_id, getattr(item, "user", ""))

        available_batch = np.sum(self.resource_manager.stop_flags)
        if len(tasks) > available_batch:
            self.llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.")
            self.llm_logger.error("The exceeded part will be ignored!")
            tasks = tasks[:available_batch]

        req_ids = [t.request_id for t in tasks]

        tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks)

        if not tasks:
            error_msg = f"The resources required by the request exceed the limit, request id={req_ids}."
            self.llm_logger.error(error_msg)
            raise EngineError(error_msg, error_code=500)

        self.token_processor.number_of_tasks += len(tasks)

        is_decode = False
        is_prefill = False
        for i in range(len(tasks)):
            if tasks[i].disaggregate_info is not None:
                if self.cfg.scheduler_config.splitwise_role == "decode":
                    is_decode = True
                else:
                    is_prefill = True
            self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len

        if self.cfg.scheduler_config.splitwise_role == "prefill":
            self.split_connector.send_cache_info_to_messager(tasks, current_id)
        elif self.cfg.scheduler_config.splitwise_role == "decode":
            self.split_connector.send_cache_info_to_prefill(tasks)

        if not is_decode:
            self.llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
            for task in tasks:
                if not getattr(task, "has_been_preempted_before", False):
                    task.metrics.inference_start_time = time.time()
                    tracing.trace_report_span(
                        tracing.TraceSpanName.SCHEDULE,
                        task.request_id.split("_")[0],
                        int(task.metrics.scheduler_recv_req_time * 1e9),
                        int(task.metrics.inference_start_time * 1e9),
                        thread_finish_flag=True,
                    )
                    trace_print(LoggingEventName.RESOURCE_ALLOCATE_END, task.request_id, getattr(task, "user", ""))
                    trace_print(LoggingEventName.REQUEST_SCHEDULE_END, task.request_id, getattr(task, "user", ""))
                    trace_print(LoggingEventName.INFERENCE_START, task.request_id, getattr(task, "user", ""))
                else:
                    trace_print(
                        LoggingEventName.RESCHEDULED_INFERENCE_START, task.request_id, getattr(task, "user", "")
                    )
        if not is_prefill:
            if not self.cfg.enable_mm_runtime:
                self.update_requests_chunk_size(tasks)
            else:
                self.update_mm_requests_chunk_size(tasks)
        self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
        return True
    def _insert_prefilled_requests(self, request_outputs: List[RequestOutput]):
        """
        On the decode instance, insert prefilled requests into the engine worker queue.
        Used by the v0 kvcache scheduler.
        Args:
            request_outputs: a list of RequestOutput sent by the prefill instance
        """
        to_infer_reqs = []
        for req_out in request_outputs:
            slot_idx = self.resource_manager.req_dict[req_out.request_id]
            del self.resource_manager.req_dict[req_out.request_id]
            cur_req = self.resource_manager.tasks_list[slot_idx]

            if envs.FD_ENABLE_INTERNAL_ADAPTER:
                if not req_out.outputs.token_ids:
                    # The first token is eos in prefill; just recycle the resource and continue.
                    self.resource_manager.stop_flags[slot_idx] = True
                    self.resource_manager.tasks_list[slot_idx] = None
                    self.resource_manager._recycle_block_tables(cur_req)
                    if req_out.request_id in self.token_processor.tokens_counter:
                        del self.token_processor.tokens_counter[req_out.request_id]
                    self.llm_logger.warning(f"{req_out.request_id} need not decode after first token")
                    continue

            cur_req.prompt_token_ids[0] = req_out.outputs.token_ids[0]
            cur_req.num_cached_tokens = req_out.num_cached_tokens
            req_out.metrics.decode_recv_req_time = cur_req.metrics.decode_recv_req_time
            req_out.metrics.decode_preallocate_req_time = cur_req.metrics.decode_preallocate_req_time
            cur_req.metrics = req_out.metrics
            cur_req.metrics.decode_inference_start_time = time.time()
            if (
                self.cfg.speculative_config.method == SpecMethod.MTP
                and self.cfg.scheduler_config.splitwise_role == "decode"
            ):
                cur_req.draft_token_ids = copy.deepcopy(req_out.outputs.draft_token_ids)

            if req_out.error_code != 200:
                self.resource_manager.stop_flags[slot_idx] = True
                self.resource_manager.tasks_list[slot_idx] = None
                self.resource_manager._recycle_block_tables(cur_req)
                if req_out.request_id in self.token_processor.tokens_counter:
                    del self.token_processor.tokens_counter[req_out.request_id]
                self.scheduler.put_results([req_out])
                self.llm_logger.warning(
                    f"{req_out.request_id} prefill failed with msg:{req_out.error_msg}, recycle resource."
                )
                continue

            self.token_processor.tokens_counter[req_out.request_id] = 1
            to_infer_reqs.append(cur_req)

        if to_infer_reqs:
            self.engine_worker_queue.put_tasks((to_infer_reqs, self.resource_manager.real_bsz))
            self.llm_logger.debug(f"put requests to engine worker queue, task:{to_infer_reqs}")
        return True
    def task_is_finished(self, index):
        """
        Check whether the task at the given slot index is finished.
        """
        assert index < len(self.resource_manager.stop_flags)
        return self.resource_manager.stop_flags[index]

    def all_tasks_finished(self):
        """
        Check whether all tasks are finished.
        """
        return np.sum(self.resource_manager.stop_flags) == len(self.resource_manager.stop_flags)
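The two helpers above rely on one piece of bookkeeping: `stop_flags` is a fixed-size per-slot array where a truthy flag marks a free (finished) slot, so summing it yields the available batch. In isolation:

```python
# Per-slot flag bookkeeping: True = slot finished/free.
import numpy as np

stop_flags = np.array([True, False, True, True])
available_batch = int(np.sum(stop_flags))               # number of free slots
all_finished = np.sum(stop_flags) == len(stop_flags)    # every slot free?
```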
    def update_requests_chunk_size(self, requests):
        """
        Update each request's chunk size info for chunked prefill.
        """

        def update_tokens(idx, chunk_size, update_chunk=False):
            nonlocal remain_batched_tokens, chunk_request_num
            if update_chunk:
                requests_chunk[idx][-1] += chunk_size
            else:
                requests_chunk[idx].append(chunk_size)
            remain_batched_tokens -= chunk_size
            current_request_size[idx] -= chunk_size
            if current_request_size[idx] <= 0:
                chunk_request_num -= 1

        if not self.cfg.cache_config.enable_chunked_prefill or len(requests) == 0:
            return

        current_request_size = [request.prompt_token_ids_len for request in requests]
        requests_chunk = [[] for _ in range(len(requests))]
        chunk_request_num = len(current_request_size)
        while chunk_request_num >= 1:
            remain_batched_tokens = self.cfg.scheduler_config.max_num_batched_tokens
            for idx in range(len(current_request_size)):
                if current_request_size[idx] <= 0:
                    continue
                chunk_size = min(
                    current_request_size[idx],
                    self.partial_chunked_tokens[chunk_request_num],
                )
                update_tokens(idx, chunk_size)

            while remain_batched_tokens >= self.cfg.cache_config.block_size:
                # When max_num_batched_tokens still has leftover capacity, give it to the shortest requests first.
                waiting_requests = [input_lens for input_lens in current_request_size if input_lens > 0]
                if len(waiting_requests) == 0:
                    break

                available_tokens = (
                    remain_batched_tokens // self.cfg.cache_config.block_size * self.cfg.cache_config.block_size
                )
                append_idx = current_request_size.index(min(waiting_requests))
                chunk_size = min(
                    current_request_size[append_idx],
                    self.partial_chunked_tokens[chunk_request_num],
                    available_tokens,
                )
                update_tokens(append_idx, chunk_size, update_chunk=True)

        for idx in range(len(requests)):
            requests[idx].set("prefill_chunk_info", requests_chunk[idx])
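The chunking loop above can be read in isolation as a standalone function: each pass spends up to `max_num_batched_tokens` by giving every still-pending request its per-request budget, then tops up the shortest pending request with the leftover capacity rounded down to `block_size`. This sketch replaces the engine config with plain arguments; `partial_chunked_tokens[n]` is assumed to be the per-request chunk budget when `n` requests still have tokens left:

```python
def split_prefill_chunks(prompt_lens, max_num_batched_tokens, block_size, partial_chunked_tokens):
    """Return, per request, the list of prefill chunk sizes (sketch of the scheme above)."""
    remaining = list(prompt_lens)
    chunks = [[] for _ in prompt_lens]
    active = len(remaining)
    while active >= 1:
        budget = max_num_batched_tokens
        # Round 1: every active request gets up to its per-request budget.
        for i in range(len(remaining)):
            if remaining[i] <= 0:
                continue
            size = min(remaining[i], partial_chunked_tokens[active])
            chunks[i].append(size)
            budget -= size
            remaining[i] -= size
            if remaining[i] <= 0:
                active -= 1
        # Round 2: leftover budget (rounded down to block_size) tops up the shortest request.
        while budget >= block_size:
            waiting = [x for x in remaining if x > 0]
            if not waiting:
                break
            available = budget // block_size * block_size
            i = remaining.index(min(waiting))
            size = min(remaining[i], partial_chunked_tokens[active], available)
            chunks[i][-1] += size           # grow this round's last chunk in place
            budget -= size
            remaining[i] -= size
            if remaining[i] <= 0:
                active -= 1
    return chunks
```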
    def update_mm_requests_chunk_size(self, requests):
        """
        Update each multimodal request's chunk size info.
        """
        if not self.cfg.cache_config.enable_chunked_prefill or len(requests) == 0:
            return

        for request in requests:
            inputs = request.multimodal_inputs
            # Handle requests that carry neither images nor videos.
            if inputs["images"] is None:
                inputs["image_type_ids"] = np.array([], dtype="int32")
                inputs["grid_thw"] = np.array([], dtype="int64")
                inputs["images"] = np.array([], dtype="uint8")
            input_ids = paddle.to_tensor(inputs["input_ids"], dtype="int64")
            image_type_ids = paddle.to_tensor(inputs["image_type_ids"], dtype="int32")
            image_mask = input_ids == self.data_processor.image_patch_id
            image_token_sum = paddle.full(shape=[len(input_ids) + 1], fill_value=0, dtype="int32")
            image_token_sum[1:] = paddle.cumsum(image_mask.cast("int32"), dtype="int32")
            grid_thw = []
            for one in inputs["grid_thw"]:
                if one[0] == 1:
                    grid_thw.append(one)
                else:
                    grid_thw.extend([[2, one[1], one[2]]] * (one[0] // 2))
            grid_thw = paddle.to_tensor(grid_thw, dtype="int64")

            from fastdeploy.model_executor.ops.gpu import get_mm_split_fuse

            chunk_image_num, chunk_seq_len = get_mm_split_fuse(
                input_ids,
                image_type_ids,
                image_token_sum,
                grid_thw,
                self.data_processor.image_patch_id,
                len(grid_thw),
                0,
                len(input_ids),
                0,
                self.partial_chunked_tokens[1],
                2048,
            )

            grid_thw = grid_thw.numpy().reshape([-1, 3])
            num_chunks = len(chunk_image_num)
            chunks_info = []
            input_ids_st, image_type_ids_st, grid_thw_st, patch_st = 0, 0, 0, 0
            for idx in range(num_chunks):
                chunk_input_ids = inputs["input_ids"][input_ids_st : input_ids_st + chunk_seq_len[idx]]
                chunk_token_type_ids = inputs["token_type_ids"][input_ids_st : input_ids_st + chunk_seq_len[idx]]
                actual_image_num = np.sum(grid_thw[grid_thw_st : grid_thw_st + chunk_image_num[idx], 0])
                chunk_image_type_ids = inputs["image_type_ids"][
                    image_type_ids_st : image_type_ids_st + actual_image_num
                ]
                chunk_grid_thw = grid_thw[grid_thw_st : grid_thw_st + chunk_image_num[idx]]
                chunk_patch_num = np.sum(np.prod(chunk_grid_thw, axis=1))
                chunk_images = inputs["images"][patch_st : patch_st + chunk_patch_num]
                chunk_position_ids = inputs["position_ids"][input_ids_st : input_ids_st + chunk_seq_len[idx]]

                chunks_info.append(
                    {
                        "input_ids": chunk_input_ids,
                        "token_type_ids": chunk_token_type_ids,
                        "image_type_ids": (chunk_image_type_ids if chunk_image_type_ids.shape[0] else None),
                        "grid_thw": (chunk_grid_thw if chunk_grid_thw.shape[0] else None),
                        "images": (chunk_images if chunk_images.shape[0] else None),
                        "position_ids": chunk_position_ids,
                    }
                )

                input_ids_st += chunk_seq_len[idx]
                image_type_ids_st += actual_image_num
                grid_thw_st += chunk_image_num[idx]
                patch_st += chunk_patch_num
            request.set("prefill_chunk_info", chunks_info)
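The `grid_thw` normalization inside the method above splits video grids with `T > 1` frames into `T // 2` two-frame grids, so every entry passed to `get_mm_split_fuse` carries at most two temporal frames; single-frame image grids pass through unchanged. A minimal sketch of just that step (the original builds plain lists before converting to a paddle tensor):

```python
def normalize_grid_thw(grid_thw):
    """Split multi-frame [t, h, w] grids into two-frame grids (sketch of the loop above)."""
    out = []
    for t, h, w in grid_thw:
        if t == 1:
            out.append([t, h, w])            # image: keep as-is
        else:
            out.extend([[2, h, w]] * (t // 2))  # video: t frames -> t // 2 pairs
    return out
```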
    def _schedule_request_to_worker(self):
        """
        Insert tasks into the engine thread; monitor the scheduler request queue.
        If the engine has resources available, insert the task into the engine.
        """
        tracing.trace_set_thread_info("Scheduler Task to Work")
        current_id = 0
        while getattr(self, "running", True):
            try:
                if self.resource_manager.available_batch() == 0:
                    time.sleep(0.001)
                    continue
                if self.engine_worker_queue.exist_tasks():
                    time.sleep(0.001)
                    continue
                if hasattr(self, "exist_prefill_task_signal") and self.exist_prefill_task_signal.value[0] > 0:
                    if (
                        self.cfg.scheduler_config.splitwise_role == "mixed"
                        or self.split_connector.has_splitwise_tasks()
                    ):
                        time.sleep(0.005)
                        continue
                if self.engine_worker_queue.num_cache_infos() > 0:
                    time.sleep(0.001)
                    continue
                if len(self.split_connector.current_request_ids) > 0:
                    time.sleep(0.001)
                    continue

                num_prefill_batch = min(
                    int(self.resource_manager.available_batch()),
                    self.cfg.max_prefill_batch,
                )

                self.resource_manager.check_and_free_block_tables()
                tasks = self.scheduler.get_requests(
                    available_blocks=self.resource_manager.available_block_num(),
                    block_size=self.cfg.cache_config.block_size,
                    reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
                    max_num_batched_tokens=self.cfg.scheduler_config.max_num_batched_tokens,
                    batch=num_prefill_batch,
                )
                for task in tasks:
                    task.metrics.engine_get_req_time = time.time()
                    trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", ""))
                if len(tasks) == 0:
                    time.sleep(0.001)
                    continue
                if self.cfg.scheduler_config.splitwise_role == "decode":
                    # TODO: refine the scheduler to remove this limitation.
                    # Decode will process and schedule the requests sent by prefill to the engine,
                    # so the same request sent by the decode api server is ignored here.
                    continue

                self.llm_logger.debug(f"get tasks from scheduler: {tasks}")
                if self.cfg.scheduler_config.splitwise_role != "mixed":
                    for task in tasks:
                        task.metrics.ask_decode_resource_start_time = time.time()
                    self.split_connector.send_splitwise_tasks(tasks, current_id)

                insert_successful = self.insert_tasks(tasks, current_id)
                if insert_successful:
                    current_id = current_id + 1
                else:
                    continue

                main_process_metrics.num_requests_waiting.dec(len(tasks))
                main_process_metrics.num_requests_running.inc(len(tasks))
            except Exception as e:
                err_msg = f"Error happened while inserting task to engine: {e}, {traceback.format_exc()!s}."
                self.llm_logger.error(err_msg)
    def _schedule_request_to_worker_v1(self):
        """
        Insert tasks into the worker with scheduler v1 (ENABLE_V1_KVCACHE_SCHEDULER=1).
        """
        tracing.trace_set_thread_info("Scheduler Task to Work")
        get_request_pool = ThreadPoolExecutor(max_workers=1)
        is_fetching = False

        def _fetch_request():
            try:
                with self._pause_cond:
                    self._pause_cond.wait_for(lambda: not self.is_paused)
                nonlocal is_fetching
                num_prefill_batch = min(
                    int(self.resource_manager.available_batch()),
                    self.cfg.max_prefill_batch,
                )

                if self.cfg.scheduler_config.splitwise_role != "mixed":
                    max_num_batched_tokens = self.cfg.scheduler_config.max_num_batched_tokens
                else:
                    max_num_batched_tokens = self.cfg.model_config.max_model_len

                available_blocks = self.cfg.cache_config.max_block_num_per_seq
                tasks = self.scheduler.get_requests(
                    available_blocks=available_blocks,
                    block_size=self.cfg.cache_config.block_size,
                    reserved_output_blocks=0,  # self.cfg.cache_config.enc_dec_block_num
                    max_num_batched_tokens=max_num_batched_tokens,
                    batch=num_prefill_batch,
                )
                for task in tasks:
                    task.metrics.engine_get_req_time = time.time()
                    trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", ""))

                    # cache_manager_v1 sets the block hasher on the request
                    if hasattr(self, "_block_hasher"):
                        task.set_block_hasher(self._block_hasher)

                if self.cfg.scheduler_config.splitwise_role == "decode":
                    # TODO: refine the scheduler to remove this limitation.
                    # Decode will process and schedule the requests sent by prefill to the engine,
                    # so the same request sent by the decode api server is ignored here.
                    is_fetching = False
                    return

                if tasks:
                    self.llm_logger.debug(
                        f"Engine has fetched tasks from {self.scheduler.__class__.__name__}: {[task.request_id for task in tasks]}"
                    )

                if self.cfg.scheduler_config.splitwise_role == "prefill":
                    for task in tasks:
                        # start async preprocess
                        self.resource_manager.apply_async_preprocess(task)
                    need_delete_tasks = []
                    if envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES:
                        for task in tasks:
                            # make sure block ids can be allocated in P
                            while not self.resource_manager.preallocate_resource_in_p(task):
                                time.sleep(0.005)
                            self.llm_logger.debug(
                                f"P has allocated resources and then asks D resource for request: {task.request_id}"
                            )
                            task.metrics.ask_decode_resource_start_time = time.time()
                            while True:
                                self.split_connector.send_splitwise_tasks([task], task.idx)
                                status, msg = self.split_connector.check_decode_allocated(task)
                                if not status:
                                    self.llm_logger.warning(
                                        f"D failed to allocate resource for request {task.request_id}, try again."
                                    )
                                    time.sleep(0.05)
                                else:
                                    task.metrics.ask_decode_resource_finish_time = time.time()
                                    break
                            self.llm_logger.debug(f"D has allocated resource for request: {task.request_id}")
                    else:
                        for task in tasks:
                            # make sure block ids can be allocated in P
                            while not self.resource_manager.preallocate_resource_in_p(task):
                                time.sleep(0.005)

                            self.llm_logger.debug(
                                f"P has allocated resources and then asks D resource for req_id: {task.request_id}"
                            )
                            task.metrics.ask_decode_resource_start_time = time.time()
                            self.split_connector.send_splitwise_tasks([task], task.idx)

                        for task in tasks:
                            # make sure block ids are fetched from D
                            status, msg = self.split_connector.check_decode_allocated(task)
                            task.metrics.ask_decode_resource_finish_time = time.time()
                            if not status:
                                self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
                                self.scheduler.put_results(
                                    [
                                        RequestOutput(
                                            request_id=task.request_id,
                                            finished=True,
                                            error_code=500,
                                            error_msg=msg,
                                        )
                                    ]
                                )
                                need_delete_tasks.append(task)
                                continue
                    for tmp_task in need_delete_tasks:
                        tasks.remove(tmp_task)
                        # release resource in P
                        self.resource_manager.pre_recycle_resource(tmp_task.request_id)

                    # send cache info to the cache messager
                    if tasks:
                        need_check_req_ids = [task.request_id for task in tasks]
                        self.split_connector.send_cache_info_to_messager(tasks, 0)
                        # ensure cache tasks have been sent to the cache_messager
                        finished_ids, delete_tasks_list = [], []
                        while need_check_req_ids:
                            finished_ids.extend(self.engine_worker_queue.get_finished_add_cache_task_req())
                            self.llm_logger.debug(
                                f"P has successfully sent cache infos to cache messager for requests: {finished_ids}"
                            )
                            if finished_ids:
                                for task in tasks:
                                    result = self.resource_manager.waiting_async_process(task)
                                    if result is None:
                                        self.scheduler.put_results(
                                            [
                                                RequestOutput(
                                                    request_id=task.request_id,
                                                    finished=True,
                                                    error_code=task.error_code,
                                                    error_msg=task.error_message,
                                                )
                                            ]
                                        )
                                        need_check_req_ids.remove(task.request_id)
                                        delete_tasks_list.append(task)
                                    elif result is False:
                                        if task.request_id in finished_ids:
                                            need_check_req_ids.remove(task.request_id)
                                            finished_ids.remove(task.request_id)
                            else:
                                time.sleep(0.001)

                        for tmp_task in delete_tasks_list:
                            tasks.remove(tmp_task)
                            # release resource in P
                            self.resource_manager.pre_recycle_resource(tmp_task.request_id)

                # Fetch requests and add them to the scheduling queue
                if tasks:
                    for task in tasks:
                        task.metrics.add_req_to_resource_manager_time = time.time()
                        trace_print(
                            LoggingEventName.RESOURCE_ALLOCATE_START, task.request_id, getattr(task, "user", "")
                        )
                    if self.cfg.scheduler_config.splitwise_role == "prefill":
                        self.resource_manager.add_request_in_p(tasks)
                        self.llm_logger.info(
                            f"P add requests into running queue: {[task.request_id for task in tasks]}"
                        )
                    else:
                        for task in tasks:
                            self.resource_manager.add_request(task)
                is_fetching = False
            except Exception as e:
                self.llm_logger.error(f"fetching request error {e} {str(traceback.format_exc())}")
                is_fetching = False

        while self.running:
            with self._pause_cond:
                self._pause_cond.wait_for(lambda: not self.is_paused)
            try:
                if not is_fetching:
                    # Check if the thread pool is still available to avoid submitting tasks to a shut-down pool.
                    try:
                        is_fetching = True
                        get_request_pool.submit(_fetch_request)
                    except RuntimeError as e:
                        if "shutdown" in str(e):
                            self.llm_logger.info("Thread pool shutdown detected, exiting scheduler loop")
                            break
                        else:
                            raise
                if self.cfg.scheduler_config.splitwise_role != "mixed":
                    # Keep preprocessing incoming requests and accumulating them in the queue while the
                    # forward pass has not finished. Once the forward pass finishes, these accumulated
                    # requests can be scheduled in larger, more efficient batches.
                    if self.engine_worker_queue.exist_tasks() or self.engine_forward_signal.value[0] != 0:
                        time.sleep(0.001)
                        continue
                else:
                    # In mixed mode. TODO: optimize cache swap to decouple it from the scheduler.
                    if self.engine_worker_queue.exist_tasks():
                        time.sleep(0.001)
                        continue

                if hasattr(self.resource_manager, "scheduler_unhandled_request_num"):
                    self.resource_manager.scheduler_unhandled_request_num = self._get_scheduler_unhandled_request_num()
                # 2. Schedule requests
                batch_request, error_tasks = self.resource_manager.schedule()

                # 3. Send to engine
                if len(batch_request) > 0:
                    if self.cfg.scheduler_config.splitwise_role == "decode":
                        for task in batch_request:
                            if task.task_type == RequestType.PREEMPTED:
                                msg = f"{task.request_id} decode not enough blocks, need to be rescheduled."
                                self.llm_logger.error(msg)
                                self.scheduler.put_results(
                                    [
                                        RequestOutput(
                                            request_id=task.request_id,
                                            finished=True,
                                            error_code=500,
                                            error_msg=msg,
                                        )
                                    ]
                                )
                    self.resource_manager.get_real_bsz()
                    for task in batch_request:
                        if task.task_type == RequestType.PREFILL:
                            rid = task.request_id.split("_")[0]
                            if isinstance(task, Request) and task.has_been_preempted_before:
                                trace_print(
                                    LoggingEventName.RESCHEDULED_INFERENCE_START,
                                    task.request_id,
                                    getattr(task, "user", ""),
                                )
                            else:
                                trace_carrier = task.trace_carrier
                                tracing.trace_set_proc_propagate_context(rid, trace_carrier)
                                trace_carrier = tracing.trace_get_proc_propagate_context(rid)
                                task.trace_carrier = trace_carrier
                                tracing.trace_report_span(
                                    tracing.TraceSpanName.SCHEDULE,
                                    rid,
                                    int(task.metrics.scheduler_recv_req_time * 1e9),
                                    int(time.time() * 1e9),
                                    thread_finish_flag=True,
                                )
                                trace_print(
                                    LoggingEventName.RESOURCE_ALLOCATE_END, task.request_id, getattr(task, "user", "")
                                )
                                trace_print(
                                    LoggingEventName.REQUEST_SCHEDULE_END, task.request_id, getattr(task, "user", "")
                                )
                                trace_print(
                                    LoggingEventName.INFERENCE_START, task.request_id, getattr(task, "user", "")
                                )
                            if isinstance(task, Request):
                                if self.cfg.scheduler_config.splitwise_role == "decode":
                                    task.metrics.decode_inference_start_time = time.time()
                                elif not task.has_been_preempted_before:
                                    task.metrics.inference_start_time = time.time()
                    self.engine_worker_queue.put_tasks((batch_request, self.resource_manager.real_bsz))
                else:
                    # When there are no actual tasks to schedule, send an empty task batch to EP workers
                    # so the EP workers' task-sync barrier does not hang.
                    if self.cfg.parallel_config.enable_expert_parallel:
                        self.engine_worker_queue.put_tasks(
                            (batch_request, self.resource_manager.real_bsz)
                        )  # Empty (as idle tasks for ep)

                # 4. Respond to error tasks
                if error_tasks:
                    for request_id, failed in error_tasks:
                        if failed is None:
                            self.llm_logger.warning(f"Request {request_id} has no error, skip sending error response.")
                            continue
                        self._send_error_response(request_id, failed)

                if len(batch_request) <= 0 and not error_tasks:
                    time.sleep(0.005)

            except RuntimeError as e:
                raise e
            except Exception as e:
                err_msg = "Error happened while inserting task to engine: {}, {}.".format(e, str(traceback.format_exc()))
                self.llm_logger.error(err_msg)
    def _get_scheduler_unhandled_request_num(self) -> int:
        """
        Get the scheduler-level pending request count when supported.
        """
        get_unhandled = getattr(self.scheduler, "get_unhandled_request_num", None)
        if not callable(get_unhandled):
            return 0
        try:
            unhandled = int(get_unhandled())
        except Exception as e:
            self.llm_logger.debug(f"Failed to get scheduler unhandled request num: {e}")
            return 0
        return max(unhandled, 0)
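The method above uses a defensive-lookup pattern worth noting: probe for an optional method with `getattr`, verify it is callable, and fall back to a safe default when it is missing or raises. The same pattern in isolation, with a hypothetical counter object for illustration:

```python
def safe_count(obj, default=0):
    """Read an optional, possibly-failing counter method, clamped to >= 0."""
    fn = getattr(obj, "get_unhandled_request_num", None)
    if not callable(fn):
        return default
    try:
        return max(int(fn()), 0)
    except Exception:
        return default

class NoCounter:
    pass  # does not implement the optional method

class FiveCounter:
    def get_unhandled_request_num(self):
        return 5
```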
    def start_zmq_service(self, api_server_pid=None):
        if api_server_pid is None:
            return
        self.api_server_pid = api_server_pid
        if envs.FD_ENABLE_INTERNAL_ADAPTER:
            self.recv_request_server = ZmqTcpServer(port=envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT, mode=zmq.PULL)
            self.send_response_server = ZmqTcpServer(port=envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT, mode=zmq.ROUTER)
            self.internal_adapter = InternalAdapter(
                cfg=self.cfg, engine=self, dp_rank=self.cfg.parallel_config.local_data_parallel_id
            )
            # ROUTER mode: need to receive client handles
            self.recv_result_handle_thread = threading.Thread(
                target=self.send_response_server.recv_result_handle, daemon=True
            )
            self.recv_result_handle_thread.start()
        else:
            self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL)
            if envs.ZMQ_SEND_BATCH_DATA:
                # PUSH mode: batch send, no need to receive client handles
                self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PUSH)
                # Mapping from request_id to worker_pid for routing batch responses
                self.request_worker_map = {}
            else:
                # ROUTER mode: per-query send, need to receive client handles
                self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER)
                self.recv_result_handle_thread = threading.Thread(
                    target=self.send_response_server.recv_result_handle, daemon=True
                )
                self.recv_result_handle_thread.start()
        time.sleep(3)
        self.insert_task_to_scheduler_thread = threading.Thread(target=self._insert_zmq_task_to_scheduler, daemon=True)
        self.insert_task_to_scheduler_thread.start()

        self.receive_output_thread = threading.Thread(target=self._zmq_send_generated_tokens, daemon=True)
        self.receive_output_thread.start()
    def _insert_zmq_task_to_scheduler(self):
        tracing.trace_set_thread_info("Insert Task to Scheduler")
        added_requests: Dict[str, int] = dict()
        if envs.FD_ENABLE_INTERNAL_ADAPTER:
            if self.cfg.scheduler_config.splitwise_role == "decode":
                return

        while self.running:
            try:
                block = len(added_requests) == 0
                if not self.cfg.enable_mm_runtime:
                    err, data = self.recv_request_server.receive_json_once(block)
                else:
                    err, data = self.recv_request_server.receive_pyobj_once(block)
                if err is not None:
                    # The message "Context was terminated" is normal when closing a ZMQ context
                    if "Context was terminated" in str(err):
                        self.llm_logger.info(
                            "Engine stops inserting zmq task into scheduler due to ZMQ context termination (normal shutdown)."
                        )
                    else:
                        self.llm_logger.error(f"Engine stops inserting zmq task into scheduler, err:{err}")
                    if envs.FD_ENABLE_INTERNAL_ADAPTER:
                        self.recv_request_server = ZmqTcpServer(
                            port=envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT, mode=zmq.PULL
                        )
                    else:
                        self.recv_request_server = ZmqIpcServer(name=self.api_server_pid, mode=zmq.PULL)
                    continue

                # Extract zmq_worker_pid for per-worker PUSH routing.
                # Only needed when ZMQ_SEND_BATCH_DATA=True AND not using the internal adapter,
                # because FD_ENABLE_INTERNAL_ADAPTER uses ROUTER (worker_pid is irrelevant).
                worker_pid = None
                if envs.ZMQ_SEND_BATCH_DATA and not envs.FD_ENABLE_INTERNAL_ADAPTER:
                    worker_pid = data["zmq_worker_pid"]

                if ControlRequest.is_control_request(data):
                    try:  # TODO: run control requests asynchronously so they do not block request generation
                        if worker_pid is not None:
                            self.request_worker_map[data.get("request_id")] = worker_pid
                        control_req = ControlRequest.from_dict(data)
                        self.run_control_method(control_req)
                    except Exception as e:
                        self.llm_logger.error(
                            f"Failed to process control request {data.get('request_id')}: "
                            f"{e}, {traceback.format_exc()}"
                        )
                    continue

                request, insert_task = data, []
                results: List[Tuple[str, Optional[str]]] = list()
                if data:
                    # Store the worker_pid mapping for normal/abort requests
                    if worker_pid is not None:
                        req_id_for_map = data.get("request_id")
                        if req_id_for_map:
                            self.request_worker_map[req_id_for_map] = worker_pid
                    status_value = data.get("status", None)
                    if status_value is not None and status_value == RequestStatus.ABORT.value:
                        req_id = data["request_id"]
                        self.llm_logger.info(f"Receive abort request, req_id: {req_id}")
                        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
                            self.resource_manager.add_abort_req_ids(req_id)
                        continue
                    err_msg = None
                    try:
                        request = Request.from_dict(data)

                        request.metrics.scheduler_recv_req_time = time.time()
                        main_process_metrics.requests_number.inc()
                        trace_carrier = data.get("trace_carrier")
                        if trace_carrier:
                            request_id = data["request_id"].split("_")[0]
                            tracing.trace_set_proc_propagate_context(request_id, trace_carrier)
                        trace_print(LoggingEventName.PREPROCESSING_END, data["request_id"], data.get("user", ""))
                        trace_print(LoggingEventName.REQUEST_SCHEDULE_START, data["request_id"], data.get("user", ""))
                        trace_print(LoggingEventName.REQUEST_QUEUE_START, data["request_id"], data.get("user", ""))
                        self.llm_logger.debug(f"Receive request from api server: {request}")

                        if self.is_paused:
                            self.llm_logger.warning(f"Engine is paused, drop request: {request}")
                            self._send_error_response(
                                request.request_id,
                                "Request is aborted since LLM Engine is paused.",
                                worker_pid=worker_pid,
                            )
                            continue
                    except Exception as e:
                        self.llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}")
                        err_msg = str(e)
                        results.append((data["request_id"], err_msg))

                    if self.guided_decoding_checker is not None and err_msg is None:
                        request, err_msg = self.guided_decoding_checker.schema_format(request)
                        if err_msg is not None:
                            self.llm_logger.error(f"Receive request error: {err_msg}")
                            results.append((request.request_id, err_msg))

                    if err_msg is None:
                        insert_task.append(request)

                response = self.scheduler.put_requests(insert_task)
                results.extend(response)

                if request:
                    if request.request_id not in added_requests:
                        added_requests[request.request_id] = 0
                    added_requests[request.request_id] += 1

                for request_id, failed in results:
                    if request_id in added_requests:
                        added_requests[request_id] -= 1
                        if added_requests[request_id] == 0:
                            added_requests.pop(request_id)

                    if failed is None:
                        main_process_metrics.num_requests_waiting.inc(1)
                        continue

                    self._send_error_response(request_id, failed)
            except Exception as e:
                self.llm_logger.error(
                    f"Error happened while receiving new request from zmq, details={e}, "
                    f"traceback={traceback.format_exc()}"
                )
    def run_control_method(self, control_req: ControlRequest):
        """
        Execute a control method: process the control request and return a response.

        This method is responsible for handling control requests, calling the corresponding
        handler function based on the method name in the request. If the method doesn't exist
        or is not callable, it returns an error response; otherwise it executes the method and
        returns a success response.

        Args:
            control_req (ControlRequest): Control request object containing the request ID,
                method name and parameters.

        Returns:
            None: No return value, sends a ControlResponse through send_response_server.
        """
        method = control_req.get_method()
        request_id = control_req.request_id

        # Look up worker_pid for routing the control response
        worker_pid = None
        if envs.ZMQ_SEND_BATCH_DATA and hasattr(self, "request_worker_map"):
            worker_pid = self.request_worker_map.pop(request_id, None)

        try:
            self.llm_logger.info(f"Start to run control method {method}: {request_id}")

            handler_name = f"_control_{method}"
            handler = getattr(self, handler_name, None)
            if handler is None or not callable(handler):
                error_result = ControlResponse(request_id, 400, f"unknown control method:{method}")
                self.llm_logger.error(str(error_result))
                data = [[error_result]] if envs.ZMQ_SEND_BATCH_DATA else [error_result]
                self.send_response_server.send_response(request_id, data, worker_pid=worker_pid)
                return

            result = handler(control_req)
            self.llm_logger.info(f"Successfully run control method {method}: {request_id} {result}")
            succ_result = ControlResponse(request_id, 200, "Success", result)
            data = [[succ_result]] if envs.ZMQ_SEND_BATCH_DATA else [succ_result]
            self.send_response_server.send_response(request_id, data, worker_pid=worker_pid)

        except Exception as e:
            error_msg = f"Failed to run control method {method}: {request_id} {str(e)}"
            self.llm_logger.error(f"{error_msg}\n{traceback.format_exc()}")
            error_result = ControlResponse(request_id, 500, error_msg)
            data = [[error_result]] if envs.ZMQ_SEND_BATCH_DATA else [error_result]
            self.send_response_server.send_response(request_id, data, worker_pid=worker_pid)

    def _control_pause(self, control_request: ControlRequest):
        """Pauses the LLM engine and aborts all running/inflight requests.

        Args:
            control_request: The control request containing the pause command

        Raises:
            Exception: If pause is not supported in the current configuration
            Exception: If engine worker queue cleanup times out

        Returns:
            None
        """

        if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
            raise Exception("pause is only supported with ENABLE_V1_KVCACHE_SCHEDULER")
        if self.cfg.scheduler_config.name != "local":
            raise Exception(f"pause is only supported with the local scheduler, current: {self.cfg.scheduler_config.name}")

        self.llm_logger.info("Start to pause request generation.")

        with self._pause_cond:
            if self.is_paused:
                self.llm_logger.info("Engine is already paused, no need to pause again.")
                return
            self.is_paused = True

        self.llm_logger.info("Abort running requests.")

        self.resource_manager.log_status()
        # Preempt all running requests; preempted requests will be appended to the ResourceManager.waiting queue
        timeout, count = 60, 0
        while self.engine_worker_queue.exist_tasks():
            time.sleep(0.001)
            count += 1
            if count >= timeout * 1000:
                break
        if count >= timeout * 1000:
            error_msg = f"Emptying engine worker queue timed out after {timeout} seconds, worker may have hung!"
            self.llm_logger.error(error_msg)
            raise Exception(error_msg)
        running_reqs = self.resource_manager.preempted_all()
        if len(running_reqs) > 0:
            self.llm_logger.info(f"Total {len(running_reqs)} requests need to be aborted.")
            self.resource_manager.get_real_bsz()
            self.engine_worker_queue.put_tasks((running_reqs, self.resource_manager.real_bsz))
        self.resource_manager.wait_worker_inflight_requests_finish(timeout=60)
        # self.engine_worker_queue.clear_data()
        self.token_processor.clear_data()
        self.resource_manager.log_status()

        # Abort inflight requests back to the user
        inflight_requests = self.scheduler.get_inflight_requests()
        self.llm_logger.info(f"Abort inflight requests (total {len(inflight_requests)}).")
        for req in inflight_requests:
            self._send_error_response(req.request_id, "Request is aborted since engine is paused.")
        self.scheduler.reset()

        if envs.ENABLE_V1_KVCACHE_MANAGER:
            self.resource_manager.cache_manager.reset_cache()
        else:
            # Pause cache transfer
            if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
                self.llm_logger.info("Start to pause cache transfer.")
                pause_transfer_request = ControlRequest(
                    request_id=f"{control_request.request_id}_pause_transfer", method="pause"
                )
                self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, pause_transfer_request))
                # Wait for cache_transfer responses
                asyncio.run(
                    self._wait_for_control_responses(
                        f"{pause_transfer_request.request_id}", 60, executors=["cache_transfer"]
                    )
                )
                self.llm_logger.info("Successfully paused cache transfer.")

            self.resource_manager.cache_manager.reset()
        self.llm_logger.info("Successfully paused request generation.")
        return None

    def _control_resume(self, control_request: ControlRequest) -> Optional[dict]:
        """Control function for resuming request generation.

        This method resumes the paused request generation process by setting the pause flag
        and notifying all waiting threads. It logs the start and end of the resume operation.

        Args:
            control_request: Control request object containing resume operation information
        """
        self.llm_logger.info("Start to resume request generation.")
        with self._pause_cond:
            if not self.is_paused:
                self.llm_logger.info("Engine is not paused, no need to resume.")
                return None
            self.is_paused = False
            self._pause_cond.notify_all()

        # resume cache transfer
        if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
            self.llm_logger.info("Start to resume cache transfer.")
            resume_transfer_request = ControlRequest(
                request_id=f"{control_request.request_id}_resume_transfer", method="resume"
            )
            self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, resume_transfer_request))
            # Wait for cache_transfer responses
            asyncio.run(
                self._wait_for_control_responses(resume_transfer_request.request_id, 60, executors=["cache_transfer"])
            )
            self.llm_logger.info("Successfully resumed cache transfer.")

        self.llm_logger.info("Successfully resumed request generation.")
        return None

    def _control_is_paused(self, control_request: ControlRequest) -> dict:
        """
        Check if the LLM engine is in a paused state.

        Args:
            control_request: Control request object.

        Returns:
            dict: Dictionary containing pause status information, {'is_paused': bool}
        """
        self.llm_logger.info(f"LLM Engine request generation is paused: {self.is_paused}")
        with self._pause_cond:
            return {"is_paused": self.is_paused}

    def _get_is_paused_safe(self) -> bool:
        """Thread-safe getter for the is_paused state, used by RegisterManager."""
        with self._pause_cond:
            return self.is_paused

    def _control_update_weights(self, control_request: ControlRequest) -> Optional[dict]:
        """Update model weights.

        Args:
            control_request: Control request object containing parameters for weight updates

        Returns:
            Optional[dict]: Returns the result dictionary if the update succeeds, None otherwise

        Raises:
            Exception: Raised when the engine is not in a paused state
        """
        self.llm_logger.info("Update Model Weights")
        with self._pause_cond:
            if self.is_paused is False:
                error_msg = "Pause the LLM Engine first before updating weights"
                self.llm_logger.error(error_msg)
                raise Exception(error_msg)
        responses = self._call_worker(control_request, 60)

        if responses:
            new_version = None
            for resp in responses:
                # Expect each worker response to be a dict-like object
                if isinstance(resp, dict) and "version" in resp:
                    new_version = resp.get("version")
                    self.llm_logger.info(f"Update Weights Version in Config: {new_version}")
                    break
            if new_version is not None:
                self.cfg.model_config.version = new_version

        if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
            self.llm_logger.info("Start to update cache-transfer metadata after weight update.")
            update_cache_request = ControlRequest(
                request_id=f"{control_request.request_id}_update_weights",
                method="update_weights",
                args=copy.deepcopy(control_request.args),
            )
            self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, update_cache_request))
            asyncio.run(
                self._wait_for_control_responses(update_cache_request.request_id, 60, executors=["cache_transfer"])
            )
            self.llm_logger.info("Successfully updated cache-transfer metadata after weight update.")

        return responses

    def _control_abort_requests(self, control_req: ControlRequest):
        if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
            raise Exception("abort_requests is only supported with ENABLE_V1_KVCACHE_SCHEDULER")
        args = control_req.get_args()
        abort_all = args.get("abort_all", False)
        req_ids = args.get("req_ids", [])
        matched_input_ids = set()
        now_reqs = list(set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys()))

        # Step 1: Determine the target request list
        if abort_all:
            # all requests in running + waiting
            target_req_ids = now_reqs
        else:
            # filter out requests that actually exist
            target_req_ids = []
            for rid in req_ids:
                if rid in now_reqs:
                    target_req_ids.append(rid)
                    matched_input_ids.add(rid)
                elif f"{rid}_0" in now_reqs:
                    target_req_ids.append(f"{rid}_0")
                    matched_input_ids.add(rid)

        if not target_req_ids:
            return {"aborted": [], "not_found": req_ids if not abort_all else []}

        # Step 2: Collect partial results
        aborted_info = []
        results = []
        for req_id in target_req_ids:
            request = self.resource_manager.requests.get(req_id)
            if request is None:
                scheduled_req = self.scheduler.requests.get(req_id)
                if scheduled_req is None:
                    continue
                request = scheduled_req.raw

            partial_token_ids = list(request.output_token_ids)

            # Construct a finished response with partial results
            now = time.time()
            abort_metrics = RequestMetrics(
                arrival_time=request.metrics.arrival_time if request.metrics else now,
                inference_start_time=request.metrics.inference_start_time if request.metrics else now,
                engine_recv_latest_token_time=now,
                engine_recv_first_token_time=request.metrics.engine_recv_first_token_time if request.metrics else now,
                request_start_time=request.metrics.arrival_time if request.metrics else now,
            )
            result = RequestOutput(
                request_id=req_id,
                finished=True,
                outputs=CompletionOutput(
                    index=0,
                    send_idx=len(partial_token_ids),
                    token_ids=[self.data_processor.eos_token_ids[0]],
                ),
                metrics=abort_metrics,
                error_code=200,
                error_msg="Aborted",
            )
            results.append(result)
            aborted_info.append(
                {
                    "request_id": req_id,
                    "output_token_count": len(partial_token_ids),
                }
            )

        # Step 3: Execute the abort: add all requests to waiting_abort_req_id_set
        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
            for req_id in target_req_ids:
                self.resource_manager.add_abort_req_ids(req_id)
                time.sleep(0.0001)
            if self.cfg.scheduler_config.splitwise_role != "prefill":
                self._wait_abort_complete(target_req_ids)

        # Add results to the scheduler; the engine has a thread calling get_results,
        # which then cleans up and calls send_response to send them to the client.
        # When the client disconnects, send_response will automatically ignore them.
        if self.cfg.scheduler_config.splitwise_role != "prefill":
            try:
                # self.send_response_server.send_response(req_id, [result])
                self.scheduler.put_results(results)
            except Exception:
                pass  # client may have disconnected

        not_found = [rid for rid in req_ids if rid not in matched_input_ids] if not abort_all else []

        return {"aborted": aborted_info, "not_found": not_found}

    def _wait_abort_complete(self, target_req_ids, stall_timeout=1):
        """
        Wait for all abort requests to complete.
        - Keep monitoring as long as `remaining` is not empty, which means cleanup is not done yet
        - If no progress is made within stall_timeout seconds, force-clean requests stuck in
          to_be_aborted_req_id_set, reset the progress state if any were cleaned, then continue monitoring
        """
        target_set = set(target_req_ids)
        prev_remaining_count = len(target_set)
        last_progress_time = time.time()
        remaining = target_set & self.resource_manager.get_reqs_in_aborting()
        while remaining:
            remaining = target_set & self.resource_manager.get_reqs_in_aborting()
            if not remaining:
                self.llm_logger.info(f"all {len(target_set)} abort reqs cleaned")
                return

            current_count = len(remaining)
            if current_count < prev_remaining_count:
                # progress made: recycle_abort_task was called
                self.llm_logger.info(f"abort progress: {prev_remaining_count} -> {current_count}")
                last_progress_time = time.time()
                prev_remaining_count = current_count

            if time.time() - last_progress_time > stall_timeout:
                # no progress within the timeout: only clean up requests stuck in to_be_aborted
                # (the worker has not returned -9 yet)
                stuck = remaining & self.resource_manager.to_be_aborted_req_id_set
                if stuck:
                    self.llm_logger.warning(
                        f"no abort progress for {stall_timeout}s, "
                        f"force cleanup {len(stuck)} stuck requests (in to_be_aborted)"
                    )
                    for req_id in list(stuck):
                        self.llm_logger.warning(f"force cleanup stuck req_id:{req_id}")
                        self.resource_manager.recycle_abort_task(req_id)
                    # reset progress state
                    last_progress_time = time.time()
                    prev_remaining_count = current_count - len(stuck)
                # else: the remaining requests are all in waiting_abort_req_id_set, waiting for the natural flow

            time.sleep(0.005)

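The monitoring pattern above (poll while work remains, track progress, and force-clean stuck items once no progress is seen for `stall_timeout` seconds) can be sketched standalone. This is a minimal illustration of the pattern, not FastDeploy code; `poll_until_done`, `get_remaining`, and `force_clean` are hypothetical names.

```python
import time


def poll_until_done(get_remaining, force_clean, stall_timeout=0.05, interval=0.005):
    """Poll until get_remaining() is empty; force-clean everything left
    once no progress has been observed for stall_timeout seconds."""
    prev = None
    last_progress = time.time()
    while True:
        remaining = get_remaining()
        if not remaining:
            return
        # Any drop in the remaining count counts as progress.
        if prev is None or len(remaining) < prev:
            last_progress = time.time()
            prev = len(remaining)
        # Stalled: force-clean the stuck items and reset the progress clock.
        if time.time() - last_progress > stall_timeout:
            for item in list(remaining):
                force_clean(item)
            last_progress = time.time()
        time.sleep(interval)


# Usage: nothing makes natural progress here, so the stall path fires
# and force-cleans both items.
pending = {"req_a", "req_b"}
poll_until_done(lambda: pending, lambda r: pending.discard(r))
```

The real method is more conservative: it only force-cleans the subset stuck in `to_be_aborted_req_id_set`, leaving requests that are still flowing through the normal abort path untouched.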
    def _parse_tags(self, control_request: ControlRequest):
        """
        Parse tags from a control request.
        """
        allowed_tags = ["weight", "kv_cache"]
        tags = control_request.args.get("tags", None)
        if tags is None:
            tags = ",".join(allowed_tags)
            control_request.args["tags"] = tags
            self.llm_logger.info(
                f"Detected empty tags of request {control_request.request_id}, defaulting to tags: {tags}"
            )
        elif isinstance(tags, list):
            tags = ",".join(tags)

        for tag in tags.split(","):
            if tag not in allowed_tags:
                raise ValueError(f"Unsupported tag [{tag}] in [{tags}], expected one of {allowed_tags}")

        return tags

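The normalization done by `_parse_tags` (default missing tags to all allowed values, join lists into a comma-separated string, then validate every entry) can be shown as a free function. `parse_tags` below is an illustrative sketch, not a FastDeploy API.

```python
def parse_tags(tags, allowed=("weight", "kv_cache")):
    """Normalize tags to a comma-separated string and validate each entry."""
    if tags is None:
        # No tags supplied: default to everything that is allowed.
        tags = ",".join(allowed)
    elif isinstance(tags, list):
        # Lists are joined into the canonical comma-separated form.
        tags = ",".join(tags)
    for tag in tags.split(","):
        if tag not in allowed:
            raise ValueError(f"Unsupported tag [{tag}] in [{tags}], expected one of {list(allowed)}")
    return tags


# Usage
assert parse_tags(None) == "weight,kv_cache"
assert parse_tags(["weight"]) == "weight"
```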
    def _control_sleep(self, control_request: ControlRequest):
        """
        Offload GPU memory occupied by certain parts, e.g. weights or KV cache.

        Args:
            control_request: Control request object containing parameters for offloading memory
                tags: list of tags to offload, supported values: ["weight", "kv_cache"]

        TODO: support different levels of offloading, to provide the option of releasing memory
        permanently instead of merely offloading it to CPU memory as is done now.
        """
        # Args check
        tags = self._parse_tags(control_request)
        control_request.args["tags"] = tags

        # Make sure the LLM engine is paused.
        self.llm_logger.warning(
            "Implicitly pause LLM engine before sleeping. This behavior will be deprecated in future versions. "
            "Please explicitly request to /pause the engine before /sleep."
        )
        self._control_pause(None)

        # Determine which executors are needed for the sleep command
        executors = set()
        if "weight" in tags:
            executors.add("worker")
        if "kv_cache" in tags:
            executors.add("worker")
            if envs.ENABLE_V1_KVCACHE_MANAGER:
                if self.cfg.cache_config.enable_prefix_caching:
                    self.resource_manager.cache_manager.reset_cache()
            else:
                if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
                    executors.add("cache_transfer")
                if self.cfg.cache_config.enable_prefix_caching:
                    self.resource_manager.cache_manager.reset()

        # Dispatch the sleep request to the executors
        self.llm_logger.info(f"Dispatch sleep request to executors: {list(executors)}")
        self._dispatch_control_request(control_request, executors)
        return asyncio.run(self._wait_for_control_responses(control_request.request_id, 60, executors=executors))

    def _control_wakeup(self, control_request: ControlRequest):
        """
        Reload GPU memory occupation that was previously offloaded for certain parts, e.g. weights or KV cache.

        Args:
            control_request: Control request object containing parameters for reloading memory
                tags: list of tags to reload, supported values: ["weight", "kv_cache"]
        """
        # Args check
        tags = self._parse_tags(control_request)
        control_request.args["tags"] = tags

        # Determine which executors are needed for the wakeup command
        executors = set()
        if "weight" in tags:
            executors.add("worker")
        if "kv_cache" in tags:
            executors.add("worker")
            if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
                executors.add("cache_transfer")

        # Dispatch the wakeup request to the executors
        self.llm_logger.info(f"Dispatch wakeup request to executors: {list(executors)}")
        self._dispatch_control_request(control_request, executors)
        result = asyncio.run(self._wait_for_control_responses(control_request.request_id, 300, executors=executors))

        # Resume the engine after wakeup
        self._control_resume(None)

        return result

    def _dispatch_control_request(self, control_request: ControlRequest, executors: List[str]):
        """
        Dispatch control requests to workers, cache managers or the engine itself.

        Args:
            control_request: the ControlRequest to dispatch
            executors: list of executor names, e.g. ["worker", "cache_transfer"]
        """
        if "worker" in executors:
            self.engine_worker_queue.put_tasks(([control_request], 1))
        if "cache_transfer" in executors:
            if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
                self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, control_request))
        return

    async def _wait_for_control_responses(self, request_id: str, timeout: int, executors: List[str] = None):
        """Wait for matching control responses from the selected executor queues.

        This helper selects the control-response queues that belong to the requested
        executors, then waits for all of them concurrently. Each queue gets a local
        waiter that keeps reading until it sees the target request ID, stashing stale
        responses into that queue's mailbox.

        Args:
            request_id: The control request ID that all returned responses must match.
            timeout: Global timeout budget in seconds for the full multi-queue wait.
            executors: Executor groups to wait for, for example `["worker"]` or
                `["worker", "cache_transfer"]`. If `None`, waits for all control
                response queues.

        Returns:
            A list of `response.result` values collected from all matched
            `ControlResponse` objects. If no queue is selected, returns `None`.

        Raises:
            Exception: If the overall wait times out, or if any queue reports a non-200
                control response or fails while waiting.
        """

        def select_control_queues(executors: List[str] = None):
            """Select control response queues by executors."""
            if executors is None:
                return self._ctrl_output_queues
            else:
                queues = {}
                for k, v in self._ctrl_output_queues.items():
                    if "w2e" in k and "worker" in executors:
                        queues[k] = v
                    elif "c2e" in k and "cache_transfer" in executors:
                        queues[k] = v
                return queues

        async def wait_one(queue_name: str, queue):
            """Wait until one queue returns a response for the current request_id."""
            mailbox = self._ctrl_response_mailboxes[queue_name]
            # Reuse a previously stashed response for this request before touching FMQ again.
            cached_response = mailbox.pop(request_id, None)
            if cached_response is not None:
                self.llm_logger.info(f"Returning cached control response from {queue_name}.")
                return cached_response

            while True:
                msg = await queue.get()

                # Return if the response matches the control request
                response: ControlResponse = msg.payload
                if response.request_id == request_id:
                    self.llm_logger.info(f"Returning new control response from {queue_name}.")
                    return response

                # Stash late responses from other control requests so they do not consume the
                # current request's only read chance on this queue.
                mailbox[response.request_id] = response
                self.llm_logger.info(
                    f"Stashed old control response from {queue_name}. "
                    f"Expected request {request_id}, got request {response.request_id}"
                )

        # Select only the control response queues that belong to the requested executors.
        queues = select_control_queues(executors)
        if not queues:
            self.llm_logger.info(f"No queues to wait for, executors: {executors}")
            return
        self.llm_logger.info(f"Waiting for control responses from {len(queues)} queues: {list(queues.keys())}")

        # Each queue gets its own waiter, which will stash stale responses until it finds the
        # target request ID for this control request.
        tasks = {name: asyncio.create_task(wait_one(name, queue)) for name, queue in queues.items()}
        done, pending = await asyncio.wait(tasks.values(), timeout=timeout)
        if pending:
            pending_names = [name for name, task in tasks.items() if task in pending]
            done_names = [name for name, task in tasks.items() if task in done]
            self.llm_logger.error(
                f"Control request {request_id} execution timeout. "
                f"Pending queues: {pending_names}, completed queues: {done_names}."
            )
            # Stop unfinished queue waiters so they do not outlive the control request.
            for task in pending:
                task.cancel()
            await asyncio.gather(*pending, return_exceptions=True)
            raise Exception(f"Control request {request_id} timed out after {timeout}s")

        # Collect the results from all completed queues.
        responses = []
        for name, task in tasks.items():
            try:
                response = task.result()
            except Exception as e:
                self.llm_logger.error(f"Waiting for control response from {name} failed: {repr(e)}")
                raise

            if response.error_code != 200:
                raise Exception(f"Error response from {name}: {response.error_message}")
            responses.append(response.result)

        return responses

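The concurrent wait pattern above (one waiter per queue, a per-queue mailbox for stale responses, a global timeout that cancels unfinished waiters) can be sketched in isolation with plain `asyncio.Queue` objects. The names `wait_all` and `demo`, and the `(request_id, payload)` message shape, are illustrative assumptions, not the FastDeploy message format.

```python
import asyncio


async def wait_all(queues, request_id, timeout):
    """Wait for a response matching request_id on every queue concurrently."""
    mailboxes = {name: {} for name in queues}

    async def wait_one(name, q):
        # Reuse a previously stashed response for this request, if any.
        cached = mailboxes[name].pop(request_id, None)
        if cached is not None:
            return cached
        while True:
            rid, payload = await q.get()
            if rid == request_id:
                return payload
            # Stash unrelated responses so they are not lost.
            mailboxes[name][rid] = payload

    tasks = {n: asyncio.create_task(wait_one(n, q)) for n, q in queues.items()}
    _done, pending = await asyncio.wait(tasks.values(), timeout=timeout)
    if pending:
        # Cancel waiters that missed the deadline so they do not leak.
        for t in pending:
            t.cancel()
        await asyncio.gather(*pending, return_exceptions=True)
        raise TimeoutError(f"{request_id} timed out after {timeout}s")
    return [t.result() for t in tasks.values()]


async def demo():
    q1, q2 = asyncio.Queue(), asyncio.Queue()
    q1.put_nowait(("stale", "old"))       # stale response gets stashed, not returned
    q1.put_nowait(("r1", "ok-worker"))
    q2.put_nowait(("r1", "ok-cache"))
    return await wait_all({"w2e": q1, "c2e": q2}, "r1", timeout=1)


result = asyncio.run(demo())
```

The mailbox is what makes this safe when two control requests overlap on the same queue: a late response for request A must not consume request B's only read.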
    def _call_worker(self, control_request: ControlRequest, timeout: int):
        request_id = control_request.request_id
        self.engine_worker_queue.put_tasks(([control_request], 1))
        # Use a single asyncio.run() to concurrently wait for all worker responses.
        return asyncio.run(self._wait_for_control_responses(request_id, timeout, executors=["worker"]))

    def _send_error_response(self, request_id, error_msg, error_code: int = 500, worker_pid=None):
        self.llm_logger.error(
            f"Send error response to client, request_id: {request_id}, error_msg: {error_msg}, error_code: {error_code}"
        )
        error_result = RequestOutput(
            request_id=request_id,
            finished=True,
            error_code=error_code,
            error_msg=error_msg,
        )
        # Look up worker_pid from the mapping if not provided
        if worker_pid is None and envs.ZMQ_SEND_BATCH_DATA and hasattr(self, "request_worker_map"):
            worker_pid = self.request_worker_map.pop(request_id, None)
        # Since the request is not in the scheduler,
        # send the result via zmq directly
        if envs.FD_ENABLE_INTERNAL_ADAPTER:
            self.send_response_server.send_response(None, [[error_result]])
        elif envs.ZMQ_SEND_BATCH_DATA:
            self.send_response_server.send_response(None, [[error_result]], worker_pid=worker_pid)
        else:
            self.send_response_server.send_response(request_id, [error_result])

    def _decode_token(self, token_ids, req_id, is_end):
        delta_text = ""
        if envs.FD_ENABLE_RETURN_TEXT:
            delta_text, cum_tokens, _ = self.data_processor.ids2tokens(token_ids, req_id)
            if delta_text != "":
                prefix_offset = self.data_processor.decode_status[req_id][0]
                read_offset = self.data_processor.decode_status[req_id][1]
                token_ids = cum_tokens[prefix_offset:read_offset]
            else:
                token_ids = []

            if is_end and delta_text == "" and len(cum_tokens) > 0:
                read_offset = self.data_processor.decode_status[req_id][1]
                token_ids = cum_tokens[read_offset:]

            if is_end:
                del self.data_processor.decode_status[req_id]
        return delta_text, token_ids

    def _zmq_send_generated_tokens(self):
        """
        Receive generated outputs from the scheduler and send them to clients over zmq.
        """
        while self.running:
            try:
                results = self.scheduler.get_results()
                if len(results) == 0:
                    time.sleep(0.005)
                    continue
                if envs.FD_ENABLE_INTERNAL_ADAPTER:
                    new_contents = []
                    for step_batch_results in results:
                        new_step_contents = []
                        for content in step_batch_results:
                            if isinstance(content, RequestOutput) and content.outputs is not None:
                                decode_type = content.outputs.decode_type
                                delta_text = ""
                                if decode_type == 0:
                                    delta_text, token_ids = self._decode_token(
                                        token_ids=content.outputs.token_ids,
                                        req_id=content.request_id,
                                        is_end=content.finished,
                                    )
                                else:
                                    token_ids = content.outputs.token_ids
                                if len(token_ids):
                                    content.outputs.token_ids = token_ids
                                    content.outputs.text = delta_text
                                    new_step_contents.append(content)
                                elif content.finished:
                                    new_step_contents.append(content)
                                else:
                                    self.llm_logger.warning(
                                        f"current tokens need to accumulate, req_id: {content.request_id} {content.outputs.token_ids}"
                                    )
                            else:
                                new_step_contents.append(content)
                        if new_step_contents:
                            new_contents.append(new_step_contents)
                    if new_contents:
                        self.send_response_server.send_response(None, new_contents)

                else:
                    worker_batches = collections.defaultdict(list)
                    for request_id, contents in results.items():
                        new_contents = []
                        for content in contents:
                            if isinstance(content, RequestOutput) and content.outputs is not None:
                                decode_type = content.outputs.decode_type
                                delta_text = ""
                                if decode_type == 0:
                                    delta_text, token_ids = self._decode_token(
                                        token_ids=content.outputs.token_ids,
                                        req_id=request_id,
                                        is_end=content.finished,
                                    )
                                else:
                                    token_ids = content.outputs.token_ids
                                if len(token_ids):
                                    content.outputs.token_ids = token_ids
                                    content.outputs.text = delta_text
                                    new_contents.append(content)
                                elif content.finished:
                                    new_contents.append(content)
                                else:
                                    self.llm_logger.warning(
                                        f"current tokens need to accumulate, req_id: {request_id} {content.outputs.token_ids}"
                                    )
                            else:
                                new_contents.append(content)
                        if new_contents:
                            if envs.ZMQ_SEND_BATCH_DATA:
                                wpid = self.request_worker_map.get(request_id)
                                worker_batches[wpid].append(new_contents)
                                is_finished = any(getattr(c, "finished", False) for c in new_contents)
                                if is_finished:
                                    self.request_worker_map.pop(request_id, None)
                            else:
                                self.send_response_server.send_response(request_id, new_contents)
                    if envs.ZMQ_SEND_BATCH_DATA:
                        for wpid, batch_data in worker_batches.items():
                            if batch_data:
                                self.send_response_server.send_response(None, batch_data, worker_pid=wpid)
            except Exception as e:
                self.llm_logger.error(f"Unexpected error happened: {e}, {traceback.format_exc()!s}")

def _decode_process_splitwise_requests(self):
|
|
"""
|
|
Decode processes requests from engine worker queue, which are sent by prefill.
|
|
TODO: merge this function to the schedule function in resource manager
|
|
"""
|
|
allocate_resource_requests: list[Request] = []
|
|
prefilled_request_ouputs: list[RequestOutput] = []
|
|
|
|
def _fetch_requests():
|
|
if self.engine_worker_queue.disaggregate_queue_empty():
|
|
return
|
|
|
|
items = self.engine_worker_queue.get_disaggregated_tasks()
|
|
for item in items:
|
|
tasks = item[1]
|
|
if isinstance(tasks[0], Request):
|
|
self.llm_logger.debug(
|
|
f"D has received tasks to preallocate resource for tasks: {[task.request_id for task in tasks]}"
|
|
)
|
|
for task in tasks:
|
|
task.metrics.decode_recv_req_time = time.time()
|
|
allocate_resource_requests.extend(tasks)
|
|
elif isinstance(tasks[0], RequestOutput):
|
|
self.llm_logger.debug(
|
|
f"D has received tasks to process prefilled tasks: {[task.request_id for task in tasks]}"
|
|
)
|
|
if not isinstance(tasks, list):
|
|
tasks = [tasks]
|
|
for task in tasks:
|
|
task.finished = False
|
|
task.metrics.decode_recv_first_token_time = time.time()
|
|
prefilled_request_ouputs.extend(tasks)
|
|
|
|
        def _process_allocate_resource_requests():
            processed_indices = []
            for idx, task in enumerate(allocate_resource_requests):
                is_success = False

                if envs.ENABLE_V1_KVCACHE_SCHEDULER:
                    if self.resource_manager.preallocate_resource_in_d(task):
                        task.metrics.decode_preallocate_req_time = time.time()
                        self.llm_logger.info(f"Resource available, processing task {task.request_id}")
                        self.split_connector.send_cache_info_to_prefill([task])
                        self.llm_logger.debug(f"D has successfully sent cache infos for task {task.request_id}")
                        processed_indices.append(idx)
                        is_success = True
                else:
                    if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
                        self.llm_logger.debug(f"D Resource available, processing task {task.request_id}")
                        self.insert_tasks([task])
                        task.metrics.decode_preallocate_req_time = time.time()
                        processed_indices.append(idx)
                        is_success = True

                if not is_success:
                    if not self.enable_decode_cache_task:
                        task.error_msg = "Not enough resources"
                        self.split_connector.send_cache_info_to_prefill([task])
                        self.llm_logger.warning(
                            f"D failed to allocate resources for task {task.request_id}, notified prefill"
                        )
                        processed_indices.append(idx)
                    else:
                        self.llm_logger.debug(f"Still waiting for resources {task.request_id}")
                        break

            for idx in sorted(processed_indices, reverse=True):
                allocate_resource_requests.pop(idx)

        def _process_prefilled_requests():
            nonlocal prefilled_request_ouputs
            ready_request_outputs = []
            waiting_request_outputs = []

            for req_output in prefilled_request_ouputs:
                if hasattr(self.scheduler, "has_request") and not self.scheduler.has_request(req_output.request_id):
                    # Ensure the api_server and scheduler in decode have
                    # received the request sent by the client.
                    waiting_request_outputs.append(req_output)
                    continue
                req_output.finished = False
                ready_request_outputs.append(req_output)
                self.llm_logger.debug(f"there are enough resources for prefilled request: {req_output.request_id}")

            prefilled_request_ouputs = waiting_request_outputs
            if self.cfg.splitwise_version == "v1":
                # decode returns the first token to the client
                self.scheduler.put_results(ready_request_outputs)

            if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
                self._insert_prefilled_requests(ready_request_outputs)
            else:
                for req_output in ready_request_outputs:
                    request_id = req_output.request_id
                    if envs.FD_ENABLE_INTERNAL_ADAPTER and not req_output.outputs.token_ids:
                        # The first token is EOS in prefill; just recycle the resource and continue.
                        self.llm_logger.warning(f"{request_id} need not decode after first token")
                        self.resource_manager.pre_recycle_resource(request_id)
                        if request_id in self.token_processor.tokens_counter:
                            del self.token_processor.tokens_counter[request_id]
                        req_output.finished = True
                        self.scheduler.put_results([req_output])
                        continue
                    if req_output.error_code != 200:
                        self.llm_logger.warning(
                            f"{request_id} prefill failed with msg:{req_output.error_msg}, recycle resource."
                        )
                        self.resource_manager.pre_recycle_resource(request_id)
                        if request_id in self.token_processor.tokens_counter:
                            del self.token_processor.tokens_counter[request_id]
                        self.scheduler.put_results([req_output])
                        continue
                    self.token_processor.tokens_counter[request_id] = 1
                    if envs.FD_ENABLE_INTERNAL_ADAPTER:  # first token sent by D instance
                        self.scheduler.put_results([req_output])
                    self.resource_manager.add_prefilled_request(req_output)
                    self.llm_logger.info(f"D has successfully added prefilled request, {request_id}")

        def decode_loop():
            while self.running:
                try:
                    _fetch_requests()
                    _process_allocate_resource_requests()
                    _process_prefilled_requests()
                    time.sleep(0.001)
                except Exception as e:
                    self.llm_logger.error(
                        f"Error in main loop of decode_process_splitwise_requests: {e}, {traceback.format_exc()}"
                    )
                    time.sleep(0.01)

        threading.Thread(target=decode_loop, daemon=True).start()

    def start_cache_service(self, device_ids, ipc_signal_suffix):
        console_logger.debug("Start cache manager...")
        return self.resource_manager.cache_manager.launch_cache_manager(
            cache_config=self.cfg.cache_config,
            tensor_parallel_size=self.cfg.parallel_config.tensor_parallel_size,
            device_ids=device_ids,
            pod_ip=self.cfg.master_ip,
            engine_worker_queue_port=self.cfg.parallel_config.local_engine_worker_queue_port,
            ipc_suffix=ipc_signal_suffix,
            create_cache_tensor=False,
        )

    def check_and_free_block_tables(self):
        self.resource_manager.check_and_free_block_tables()

    def clear_data(self):
        try:
            self.llm_logger.info("Clear Data: Start")
            self.token_processor.clear_data()
            self.engine_worker_queue.clear_data()
            if hasattr(self, "cache_task_queue"):
                self.cache_task_queue.clear_transfer_task()
            self.send_response_server.req_dict.clear()
            self.recv_request_server.req_dict.clear()
            # Clean up the request-to-worker mapping (batch mode)
            if envs.ZMQ_SEND_BATCH_DATA and hasattr(self, "request_worker_map"):
                self.request_worker_map.clear()
            self.llm_logger.info("Clear Data: Successfully")
            return True
        except Exception as e:
            self.llm_logger.error(f"Clear data error: {e}")
            return False

    def _exit_sub_services(self):
        """
        Exit sub services.
        """
        self.llm_logger.info("Exit sub services.....")
        self.running = False

        if self.use_async_llm:
            # Clean up worker processes first (before closing multiprocessing services)
            if hasattr(self, "worker_proc") and self.worker_proc is not None:
                self.llm_logger.info("Cleaning up worker processes...")
                try:
                    pgid = os.getpgid(self.worker_proc.pid)
                    os.killpg(pgid, signal.SIGTERM)
                except Exception as e:
                    self.llm_logger.error(f"Error killing worker processes: {e}, {str(traceback.format_exc())}")

            # Clean up cache manager processes
            if hasattr(self, "cache_manager_processes"):
                self.llm_logger.info("Cleaning up cache manager processes...")
                self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
                self.resource_manager.cache_manager.cache_ready_signal.clear()
                for p in self.cache_manager_processes:
                    self.llm_logger.info(f"Killing cache manager process {p.pid}")
                    try:
                        pgid = os.getpgid(p.pid)
                        os.killpg(pgid, signal.SIGTERM)
                    except Exception as e:
                        self.llm_logger.error(
                            f"Error killing cache manager process {p.pid}: {e}, {str(traceback.format_exc())}"
                        )

            if hasattr(self, "cache_task_queue") and self.cache_task_queue is not None:
                self.llm_logger.info("Cleaning up cache_task_queue...")
                # Check whether a cleanup method exists
                if hasattr(self.cache_task_queue, "cleanup"):
                    self.cache_task_queue.cleanup()
                elif hasattr(self.cache_task_queue, "manager"):
                    try:
                        self.llm_logger.info("Shutting down cache_task_queue manager...")
                        self.cache_task_queue.manager.shutdown()
                    except Exception as e:
                        self.llm_logger.warning(f"Error shutting down cache_task_queue manager: {e}")

            if hasattr(self, "get_profile_block_num_signal"):
                self.get_profile_block_num_signal.clear()

            self.worker_ready_signal.clear()
            self.loaded_model_signal.clear()

        # Clean up other services
        if hasattr(self, "dp_processed"):
            for p in self.dp_processed:
                self.llm_logger.info(f"Waiting for worker {p.pid} to exit")
                p.join()
            for p in self.dp_engine_worker_queue_server:
                p.cleanup()

        if hasattr(self, "engine_worker_queue_server") and self.engine_worker_queue_server is not None:
            self.engine_worker_queue_server.cleanup()
        self.exist_task_signal.clear()
        self.exist_swapped_task_signal.clear()
        self.worker_healthy_live_signal.clear()
        self.cache_ready_signal.clear()
        self.swap_space_ready_signal.clear()
        self.cache_transfer_inited_signal.clear()
        self.exist_prefill_task_signal.clear()
        self.model_weights_status_signal.clear()
        self.prefix_tree_status_signal.clear()
        self.kv_cache_status_signal.clear()
        if hasattr(self, "send_response_server") and self.send_response_server is not None:
            self.send_response_server.close()
        if hasattr(self, "recv_request_server") and self.recv_request_server is not None:
            self.recv_request_server.close()
        if hasattr(self, "recv_control_cmd_server") and self.recv_control_cmd_server is not None:
            self.recv_control_cmd_server.close()

    # Moved from async_llm to common_engine
    def _worker_processes_ready(self):
        """
        Judge whether all worker processes are ready.
        """
        return bool(np.sum(self.worker_ready_signal.value) == self.cfg.worker_num_per_node)

    def _init_worker_signals(self):
        """
        Initialize shared memory to indicate engine status
        """
        # worker_ready_signal: used by worker processes to detect whether the engine has finished starting
        worker_ready_signal_data = np.zeros(shape=[self.cfg.worker_num_per_node], dtype=np.int32)
        self.worker_ready_signal = IPCSignal(
            name="worker_ready_signal",
            array=worker_ready_signal_data,
            dtype=np.int32,
            suffix=self.ipc_signal_suffix,
            create=True,
        )

        # launched_cache_manager_signal: used to detect whether the engine has launched the cache_manager
        if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
            launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
            self.launched_cache_manager_signal = IPCSignal(
                name="launched_cache_manager_signal",
                array=launched_cache_manager_signal_data,
                dtype=np.int32,
                suffix=self.ipc_signal_suffix,
                create=True,
            )

        # launched_expert_service_signal: used to sense whether each expert_service started successfully
        if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
            launched_expert_service_signal_data = np.zeros(
                shape=[self.cfg.parallel_config.data_parallel_size // self.cfg.nnode], dtype=np.int32
            )
            self.launched_expert_service_signal = IPCSignal(
                name="launched_expert_service_signal",
                array=launched_expert_service_signal_data,
                dtype=np.int32,
                suffix=self.ipc_signal_suffix,
                create=True,
            )

        # loaded_model_signal: used to detect whether each worker has completed model loading
        loaded_model_signal_data = np.zeros([1], dtype=np.int32)
        self.loaded_model_signal = IPCSignal(
            name="loaded_model_signal",
            array=loaded_model_signal_data,
            dtype=np.int32,
            suffix=self.ipc_signal_suffix,
            create=True,
        )

        if self.do_profile:
            if paddle.is_compiled_with_custom_device("iluvatar_gpu"):
                get_profile_block_num = np.zeros([self.cfg.worker_num_per_node], dtype=np.int32)
            else:
                get_profile_block_num = np.zeros([1], dtype=np.int32)
            self.get_profile_block_num_signal = IPCSignal(
                name="get_profile_block_num",
                array=get_profile_block_num,
                dtype=np.int32,
                suffix=self.ipc_signal_suffix,
                create=True,
            )

    def _setting_environ_variables(self):
        """
        Configure environment variables.
        """
        variables = {
            "ENABLE_FASTDEPLOY_LOAD_MODEL_CONCURRENCY": 0,
            "LOAD_STATE_DICT_THREAD_NUM": len(self.cfg.parallel_config.device_ids.split(",")),
            "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python",
            "FLAGS_use_append_attn": 1,
            "NCCL_ALGO": "Ring",
            "FLAGS_max_partition_size": int(os.getenv("FLAGS_max_partition_size", 1024)),
            "OMP_NUM_THREADS": 3,
        }
        # environment variables needed by Dy2St
        variables.update(
            {
                "SOT_LOG_LEVEL": os.getenv("SOT_LOG_LEVEL", default="0"),
                "SOT_UNSAFE_CACHE_FASTPATH": os.getenv("SOT_UNSAFE_CACHE_FASTPATH", default="1"),
                "SOT_ENABLE_0_SIZE_FALLBACK": os.getenv("SOT_ENABLE_0_SIZE_FALLBACK", default="0"),
                "SOT_SPECIALIZED_DIM_NUMBERS": os.getenv("SOT_SPECIALIZED_DIM_NUMBERS", default="no"),
                "SOT_ENABLE_COMPILE_TIME_LIMIT": os.getenv("SOT_ENABLE_COMPILE_TIME_LIMIT", default="0"),
                "FLAGS_specialize_device_in_dy2st": os.getenv("FLAGS_specialize_device_in_dy2st", default="1"),
                "FLAGS_enable_async_fast_gc": os.getenv("FLAGS_enable_async_fast_gc", default="0"),
                "FLAGS_pir_interpreter_record_stream_for_gc_cache": os.getenv(
                    "FLAGS_pir_interpreter_record_stream_for_gc_cache", default="1"
                ),
                "FLAGS_parameters_persistent_mode_in_dy2st": os.getenv(
                    "FLAGS_parameters_persistent_mode_in_dy2st", default="1"
                ),
            }
        )

        if self.cfg.scheduler_config.splitwise_role != "mixed":
            if envs.ENABLE_V1_KVCACHE_SCHEDULER:
                variables["FLAGS_use_pd_disaggregation_per_chunk"] = 1
            else:
                variables["FLAGS_use_pd_disaggregation"] = 1
            # TODO: load environment variables dynamically
            if self.cfg.scheduler_config.splitwise_role == "prefill":
                variables["FLAGS_fmt_write_cache_completed_signal"] = 1

        if self.cfg.enable_mm_runtime:
            variables["FLAGS_max_partition_size"] = 1024

        command_prefix = ""
        for k, v in variables.items():
            command_prefix += f"{k}={v} "
        return command_prefix

    def _start_worker_service(self):
        """
        Start the GPU worker service.
        """
        log_dir = os.getenv("FD_LOG_DIR", default="log")
        command_prefix = self._setting_environ_variables()
        current_file_path = os.path.abspath(__file__)
        current_dir_path = os.path.split(current_file_path)[0]
        # TODO
        uncache_worker_stdout = "" if os.getenv("UNCACHE_WORKER_STDOUT", "0") == "1" else "-u"
        pd_cmd = f"{command_prefix} {sys.executable} {uncache_worker_stdout} -m paddle.distributed.launch"
        pd_cmd = pd_cmd + f" --log_dir {log_dir}"

        worker_path = "../worker/worker_process.py"
        py_script = os.path.join(current_dir_path, worker_path)

        ori_vocab_size = (
            len(self.data_processor.tokenizer.sp_model)
            if hasattr(self.data_processor.tokenizer, "sp_model")
            else len(self.data_processor.tokenizer.vocab)
        )

        think_start_id = self.data_processor.tokenizer.get_vocab().get("<think>", -1)
        if think_start_id >= 0:
            self.llm_logger.info(f"Get think_start_id {think_start_id} from vocab.")
        else:
            self.llm_logger.info("No <think> token found in vocabulary, the model cannot do reasoning.")
        think_end_id = self.data_processor.tokenizer.get_vocab().get("</think>", -1)
        if think_end_id >= 0:
            self.llm_logger.info(f"Get think_end_id {think_end_id} from vocab.")
        else:
            self.llm_logger.info("No </think> token found in vocabulary, the model cannot do reasoning.")
        image_patch_id = self.data_processor.tokenizer.get_vocab().get("<|IMAGE_PLACEHOLDER|>", -1)
        line_break_id = self.data_processor.tokenizer.get_vocab().get("\n", -1)
        if line_break_id < 0:
            line_break_ids = self.data_processor.tokenizer.encode("\n", add_special_tokens=False)
            if isinstance(line_break_ids, dict):
                line_break_ids = line_break_ids.get("input_ids")
            elif hasattr(line_break_ids, "input_ids"):
                line_break_ids = line_break_ids.input_ids
            if line_break_ids:
                if isinstance(line_break_ids, (list, tuple)):
                    first = line_break_ids[0]
                    if isinstance(first, (list, tuple)):
                        line_break_id = int(first[0]) if first else -1
                    else:
                        line_break_id = int(first)
                else:
                    line_break_id = int(line_break_ids)
        if line_break_id >= 0:
            self.llm_logger.info(f"Get line_break_id {line_break_id} from tokenizer.")

        ports = ",".join(map(str, self.cfg.parallel_config.engine_worker_queue_port))
        ips = None
        if self.cfg.ips is not None:
            ips = ",".join(self.cfg.ips)
        arguments = (
            f" --devices {self.cfg.parallel_config.device_ids} {py_script}"
            f" --max_num_seqs {self.cfg.scheduler_config.max_num_seqs} --max_model_len {self.cfg.model_config.max_model_len}"
            f" --gpu_memory_utilization {self.cfg.cache_config.gpu_memory_utilization}"
            f" --model {self.cfg.model_config.model!s}"
            f" --device_ids {self.cfg.parallel_config.device_ids}"
            f" --tensor_parallel_size {self.cfg.parallel_config.tensor_parallel_size}"
            f" --engine_worker_queue_port {ports}"
            f" --pod_ip {self.cfg.master_ip}"
            f" --block_size {self.cfg.cache_config.block_size}"
            f" --enc_dec_block_num {self.cfg.cache_config.enc_dec_block_num}"
            f" --eos_tokens_lens {self.data_processor.eos_token_id_len}"
            f" --pad_token_id {self.data_processor.pad_token_id}"
            f" --engine_pid {self.cfg.parallel_config.engine_worker_queue_port[0]}"
            f" --max_num_batched_tokens {self.cfg.scheduler_config.max_num_batched_tokens}"
            f" --splitwise_role {self.cfg.scheduler_config.splitwise_role}"
            f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio}"
            f" --expert_parallel_size {self.cfg.parallel_config.expert_parallel_size}"
            f" --chunked_moe_size {self.cfg.parallel_config.chunked_moe_size}"
            f" --data_parallel_size {self.cfg.parallel_config.data_parallel_size}"
            f" --quantization '{json.dumps(self.cfg.model_config.quantization)}'"
            f" --ori_vocab_size {ori_vocab_size}"
            f" --think_start_id {think_start_id}"
            f" --think_end_id {think_end_id}"
            f" --image_patch_id {image_patch_id}"
            f" --line_break_id {line_break_id}"
            f" --speculative_config '{self.cfg.speculative_config.to_json_string()}'"
            f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'"
            f" --guided_decoding_backend {self.cfg.structured_outputs_config.guided_decoding_backend}"
            f" --load_strategy {self.cfg.load_config.load_strategy}"
            f" --rsync_config '{json.dumps(self.cfg.load_config.rsync_config)}'"
            f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
            f" --reasoning_parser {self.cfg.structured_outputs_config.reasoning_parser}"
            f" --load_choices {self.cfg.load_config.load_choices}"
            f" --model_loader_extra_config '{json.dumps(self.cfg.load_config.model_loader_extra_config)}'"
            f" --plas_attention_config '{self.cfg.plas_attention_config.to_json_string()}'"
            f" --ips {ips}"
            f" --cache-transfer-protocol {self.cfg.cache_config.cache_transfer_protocol}"
            f" --runner {self.cfg.model_config.runner}"
            f" --convert {self.cfg.model_config.convert}"
            f" --override-pooler-config {self.cfg.model_config.override_pooler_config}"
            f" --logprobs_mode {self.cfg.model_config.logprobs_mode}"
            f" --max_logprobs {self.cfg.model_config.max_logprobs}"
            f" --eplb_config '{self.cfg.eplb_config.to_json_string()}'"
            f" --num_cpu_blocks {self.cfg.cache_config.num_cpu_blocks}"
            f" --deploy_modality {self.cfg.deploy_modality.value}"
        )
        if self.cfg.structured_outputs_config.logits_processors is not None:
            arguments += f" --logits-processors {' '.join(self.cfg.structured_outputs_config.logits_processors)}"
        if self.mm_max_tokens_per_item is not None:
            arguments += f" --mm_max_tokens_per_item '{json.dumps(self.mm_max_tokens_per_item)}'"

        worker_store_true_flag = {
            "enable_expert_parallel": self.cfg.parallel_config.enable_expert_parallel,
            "enable_prefix_caching": self.cfg.cache_config.enable_prefix_caching,
            "enable_chunked_prefill": self.cfg.cache_config.enable_chunked_prefill,
            "do_profile": self.do_profile,
            "dynamic_load_weight": self.cfg.load_config.dynamic_load_weight,
            "disable_any_whitespace": self.cfg.structured_outputs_config.disable_any_whitespace,
            "disable_custom_all_reduce": self.cfg.parallel_config.disable_custom_all_reduce,
            "use_internode_ll_two_stage": self.cfg.parallel_config.use_internode_ll_two_stage,
            "disable_sequence_parallel_moe": self.cfg.parallel_config.disable_sequence_parallel_moe,
            "enable_logprob": self.cfg.model_config.enable_logprob,
            "lm_head_fp32": self.cfg.model_config.lm_head_fp32,
            "moe_gate_fp32": self.cfg.model_config.moe_gate_fp32,
            "enable_entropy": self.cfg.model_config.enable_entropy,
            "enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
            "enable_flashinfer_allreduce_fusion": self.cfg.parallel_config.enable_flashinfer_allreduce_fusion,
        }
        for worker_flag, value in worker_store_true_flag.items():
            if value:
                arguments = arguments + f" --{worker_flag}"

        worker_default_none_flag = {
            "num_gpu_blocks_override": self.cfg.cache_config.num_gpu_blocks_override,
            "kvcache_storage_backend": self.cfg.cache_config.kvcache_storage_backend,
        }
        for worker_flag, value in worker_default_none_flag.items():
            if value:
                arguments = arguments + f" --{worker_flag} {value}"

        if self.cfg.nnode > 1:
            pd_cmd = pd_cmd + f" --ips {ips} --nnodes {len(self.cfg.ips)}"
        pd_cmd = pd_cmd + arguments + f" 2>{log_dir}/launch_worker.log"
        self.llm_logger.info(f"Launch worker service command: {pd_cmd}")
        p = subprocess.Popen(
            pd_cmd,
            stdout=subprocess.PIPE,
            shell=True,
            preexec_fn=os.setsid,
        )
        return p

    def _stop_profile(self):
        """
        Stop profiling of the model server and reset variables.
        """
        self.do_profile = 0
        while self.get_profile_block_num_signal.value[0] == 0:
            if hasattr(self, "worker_proc") and self.worker_proc is not None:
                if self.worker_proc.poll() is not None:
                    raise RuntimeError("Worker process failed to start. Please check log/workerlog.* for details.")
            time.sleep(1)
        num_gpu_blocks = self.get_profile_block_num_signal.value[0]
        self.cfg.cache_config.reset(num_gpu_blocks)
        self.resource_manager.reset_cache_config(self.cfg.cache_config)
        if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
            if envs.ENABLE_V1_KVCACHE_MANAGER:
                return
            device_ids = self.cfg.parallel_config.device_ids.split(",")
            self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix)

    def check_health(self, time_interval_threshold=30):
        """
        Check the health of the model server by checking whether all workers are alive.
        """
        if self.worker_healthy_live_signal.value[0]:
            elapsed_time = time.time() - self.worker_healthy_live_signal.value[0]
            if elapsed_time > time_interval_threshold:
                return False, "Worker Service Not Healthy"

        return True, ""

    def launch_components(self):
        if self.cfg.scheduler_config.splitwise_role != "mixed":
            # single-machine logic
            self.splitwise_receive_thread = threading.Thread(target=self.split_connector.start_receiver, args=())
            self.splitwise_receive_thread.daemon = True
            self.splitwise_receive_thread.start()

        role = self.cfg.scheduler_config.splitwise_role
        host_ip = self.cfg.host_ip
        if self.cfg.scheduler_config.name == "splitwise":
            self.scheduler.start(role, host_ip, self.cfg.register_info)
        elif self.cfg.scheduler_config.name == "dp":
            self.scheduler.start(
                self.cfg.node_rank * self.cfg.worker_num_per_node % self.cfg.worker_num_per_node,
            )

        if not envs.FD_ENABLE_MULTI_API_SERVER:
            if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
                self.launched_expert_service_signal.value[0] = 1
                self.dp_processed = []
                self.dp_engine_worker_queue_server = []
                for i in range(
                    1,
                    self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
                ):
                    if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
                        address = (
                            self.cfg.master_ip,
                            int(self.cfg.parallel_config.engine_worker_queue_port[i]),
                        )
                    else:
                        address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.engine_worker_queue_port[i]}.sock"

                    self.llm_logger.info(f"dp start queue service {address}")
                    self.dp_engine_worker_queue_server.append(
                        EngineWorkerQueue(
                            address=address,
                            is_server=True,
                            num_client=self.cfg.parallel_config.tensor_parallel_size,
                            local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
                        )
                    )
                    from fastdeploy.engine.expert_service import (
                        start_data_parallel_service,
                    )

                    self.dp_processed.append(
                        multiprocessing.Process(
                            target=start_data_parallel_service,
                            args=(
                                self.cfg,
                                i,
                            ),
                        )
                    )
                    self.llm_logger.info(
                        f"Engine is initialized successfully with tensor parallel size "
                        f"{self.cfg.parallel_config.tensor_parallel_size}, data parallel id {i}"
                    )
                    self.dp_processed[-1].start()
                    while self.launched_expert_service_signal.value[i] == 0:
                        time.sleep(1)

    def check_worker_initialize_status(self):
        """
        Check the initialization status of workers via stdout logging.
        """

        def detect_thread():
            for line in self.worker_proc.stdout:
                line = line.decode("utf-8", errors="ignore")
                if self.worker_init_status.get("finished", False):
                    break
                if match := re.search(
                    r"Loading (?:fastsafetensors |safetensors )?checkpoint shards:\s*(\d+)",
                    line,
                ):
                    self.worker_init_status["weight_loadding"] = int(match.group(1)) / 100
                elif (match := re.search(r"Start load layer (\d+)", line)) or (
                    match := re.search(r"set state for layer (\d+)", line)
                ):
                    layer_idx = int(match.group(1))
                    self.worker_init_status["layer_loadding"] = layer_idx / self.cfg.model_config.num_hidden_layers
                    if layer_idx == self.cfg.model_config.num_hidden_layers - 1:
                        self.worker_init_status["finished"] = True

        self.checking_worker_status_thread = threading.Thread(target=detect_thread, daemon=True)
        self.checking_worker_status_thread.start()

        # display weight loading progress
        with tqdm(total=100, desc="Loading Weights") as pbar:
            progress = 0
            while progress < 100:
                progress = int(self.worker_init_status.get("weight_loadding", 0) * 100)
                if self.worker_init_status.get("layer_loadding", 0) > 0 or self._worker_processes_ready():
                    progress = 100
                pbar.update(progress - pbar.n)
                pbar.refresh()
                time.sleep(0.5)
                if self.worker_proc.poll() is not None:
                    return False

        # display layer loading progress
        with tqdm(total=100, desc="Loading Layers") as pbar:
            progress = 0
            while progress < 100:
                progress = int(self.worker_init_status.get("layer_loadding", 0) * 100)
                if self._worker_processes_ready():
                    progress = 100
                pbar.update(progress - pbar.n)
                pbar.refresh()
                time.sleep(0.5)
                if self.worker_proc.poll() is not None:
                    return False

        self.worker_init_status["finished"] = True
        try:
            self.checking_worker_status_thread.join(timeout=1)
        except Exception:
            pass
        return True