Files
FastDeploy/fastdeploy/engine/common_engine.py
T
kevin 7707be8384 [Feature][KVCache] Implement Cache Manager V1 with GPU + CPU Cache Support (1/n) (#7097)
* [Feature][KVCache] Support cache manager v1 architecture

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Update cache manager and related modules

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* chore: update cache_manager and related modules

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: add node to evictable set in complete_swap_to_device

When a node transitions from SWAP_TO_DEVICE to DEVICE via
complete_swap_to_device, it was not being added to the
_evictable_device set. This caused nodes with ref_count=0 to
become "orphaned" - not appearing in any evictable set despite
having cache_status=DEVICE.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* feat: update cache manager v1 and related modules

- Add new cache_manager.py with cache management functionality
- Add radix_tree.py for prefix caching
- Update block_pool.py and metadata.py
- Update request.py and resource_manager_v1.py for scheduling
- Update gpu_model_runner.py for GPU model execution

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* feat(cache): add cache controller v1 implementation

- Add CacheController class for cache management
- Update config.py with cache related configurations
- Refactor gpu_model_runner.py for improved cache handling

* feat(cache_manager): update cache manager v1

* fix(cache_manager): 修复 swap_cache H2D/D2H 方向的 block_ids 逻辑并清理 ForwardMeta

## Motivation

修复 swap_cache_optimized.cu 中 H2D 方向时 src/dst block_ids 使用错误的问题,
并清理 ForwardMeta 中已废弃的 cache_controller 字段。

## Modifications

- fix: swap_cache_optimized.cu 中根据 D2H 模板参数正确选取 src/dst block_ids,
  修复 H2D 方向 src/dst 倒置 bug(同时修复 SwapCachePerLayerImpl 和 SwapCacheAllLayersBatchImpl)
- refactor: cache_manager/v1/__init__.py 将 LayerSwapTimeoutError 导入从
  cache_controller 改为 cache_utils(正确来源)
- refactor: ForwardMeta 移除废弃的 cache_controller 字段
- refactor: gpu_model_runner.py 移除对应的 cache_controller 赋值语句
- test: 新增 tests/cache_manager/v1/test_swap_cache_ops.py 单元测试

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* feat(cache_manager): refactor cache manager v1 and optimize swap ops

## Motivation

对 cache manager v1 进行重构和优化,精简代码结构,提升可维护性。

## Modifications

- 重构 transfer_manager.py,大幅精简代码逻辑
- 优化 swap_cache_optimized.cu GPU 算子实现
- 调整 cache_manager.py、cache_controller.py 逻辑,修复 free_device_blocks 方法缺失问题
- 更新 block_pool.py、cache_utils.py、metadata.py、radix_tree.py
- 精简 gpu_model_runner.py、forward_meta.py、attention.py 中相关调用
- 更新对应单元测试(test_cache_controller、test_swap_cache_ops、test_transfer_manager)
- 调整 config.py 中相关配置项

* [KVCache][MTP] 支持 cache_manager_v1 下的 MTP KV Cache 初始化及多模态 hash

## Motivation

在 enable_cache_manager_v1 路径下,MTP(speculative decode)的 KV Cache 需要由
CacheController 统一管理,以复用 swap/transfer 能力,同时修复多模态场景下 block
hash 未携带 multimodal extra_keys 的问题。

## Modifications

- `cache_controller.py`
  - 新增 `initialize_mtp_kv_cache`:通过 CacheController 初始化 MTP KV Cache,
    并将其注册到 cache_kvs_map,使 transfer_manager 自动覆盖 MTP 层
  - `initialize_host_cache` 中的 num_layers 改为包含 MTP 额外 cache 层数,保证
    Host Cache 也为 MTP 分配足够空间
  - `_free_gpu_cache` 改名为 `free_gpu_cache`(对外可调用)

- `cache_utils.py`
  - 新增 `get_block_hash_extra_keys`:提取单个 block 内的多模态 hash 信息,
    对齐 PrefixCacheManager 的 multimodal extra_keys 逻辑
  - `get_request_block_hasher` 中在 hash_block_tokens 时携带 extra_keys,
    修复多模态场景 prefix cache 命中率不准的问题

- `spec_decode/mtp.py`
  - `update_mtp_block_num` 新增 `skip_cache_init` 参数,避免 v1 cache manager
    路径下重复初始化 MTP KV Cache

- `gpu_model_runner.py`
  - `initialize_kv_cache(v1)` 路径:在主模型 cache 初始化后,调用
    `cache_controller.initialize_mtp_kv_cache` 完成 MTP cache 创建
  - `clear_cache` / `wakeup` / `reset` 等路径:respect `enable_cache_manager_v1`
    标志,跳过重复的 proposer.initialize_kv_cache 调用

## Usage or Command

```bash
# 启动支持 MTP + cache_manager_v1 的推理服务(示例)
bash run.sh
```

* fix(cache_manager): multi-GPU fix, mm hash boundary fix, and remove batch ops

1. Fix CuPy stream/event creation for multi-GPU: wrap all stream operations
   with cp.cuda.Device(device_id) context to ensure streams/events are bound
   to the correct device, preventing cross-device errors in multi-GPU setups.

2. Remove cudaSetDevice from SwapCacheAllLayers (handled by cupy context now).

3. Remove swap_cache_all_layers_batch op: simplified the implementation by
   removing the batch upload variant; all-layer transfers now use the standard
   swap_cache_all_layers with cupy device context.

4. Fix mm hash boundary comparison in get_block_hash_extra_keys: change
   strict less-than (<) to less-than-or-equal (<=) so that multimodal items
   ending exactly at block start are correctly excluded.

5. Extract config fields to KVCacheBase: model_config, cache_config,
   quant_config, parallel_config are now set in the base class __init__ to
   avoid duplication in CacheController and CacheManager subclasses.

6. Translate metadata.py docstrings from Chinese to English for broader
   contributor accessibility.

7. Add test_cache_utils.py: comprehensive unit tests for
   get_block_hash_extra_keys covering all boundary and overlap scenarios.

8. Expand test suite: test_request.py cache fields tests, test_radix_tree.py
   backup candidate tests, test_transfer_manager.py and test_cache_manager.py
   multi-GPU and concurrent operation tests.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix List import and move write_policy normalization to CacheManager

## Motivation

修复两处问题:
1. `fastdeploy/engine/request.py` 中 `List` 未导入导致 pre-commit F821 报错
2. `write_policy` 归一化逻辑(`write_through` → `write_through_selective`)不应放在 `FDConfig`,移至 `CacheManager.__init__` 中,使其只影响 Cache Manager V1 的内部逻辑

## Modifications

- `fastdeploy/engine/request.py`: 在 `typing` 导入中补充 `List`,删除重复的 `CacheSwapMetadata` TYPE_CHECKING 导入,修复 F821/F811
- `fastdeploy/config.py`: 删除 `write_policy` 归一化逻辑
- `fastdeploy/cache_manager/v1/cache_manager.py`: 将归一化逻辑移入 `CacheManager.__init__`

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix pre-commit code style issues

## Motivation

修复 CI pre-commit 代码风格检查失败问题。

## Modifications

- `fastdeploy/engine/common_engine.py`: black 格式化
- `fastdeploy/worker/worker_process.py`: black 格式化 + isort 修复
- `fastdeploy/cache_manager/v1/storage/__init__.py`: isort 修复
- `fastdeploy/worker/gpu_worker.py`: isort 修复

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [Feature][KVCache] update cache_manager_v1 modules

## Motivation

更新 Cache Manager V1 相关模块,完善版权信息、改进模块结构与可维护性。

## Modifications

- `fastdeploy/cache_manager/v1/` 系列模块:补充版权 header,优化代码结构
- `fastdeploy/config.py`:配置项更新
- `fastdeploy/engine/sched/resource_manager_v1.py`:调度相关更新

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [Feature][KVCache] add BatchRequest.from_tasks and refactor worker task parsing

## Motivation

将 worker_process 中重复的 task 解析逻辑收敛到 BatchRequest,减少代码冗余,提升可维护性。

## Modifications

- `fastdeploy/engine/request.py`:新增 `BatchRequest.from_tasks()` 类方法,统一将 task_queue 任务分类为推理请求和控制请求
- `fastdeploy/worker/worker_process.py`:使用 `BatchRequest.from_tasks()` 替代内联解析逻辑,并修复重复的 control_reqs 处理块

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [Feature][KVCache] add NUMA affinity for host cache and skip swap cache tests

## Motivation

优化 Host cache 内存分配的 NUMA 亲和性,减少跨 NUMA 访问延迟;
同时跳过 swap cache ops 测试(当前环境不支持)。

## Modifications

- `fastdeploy/cache_manager/v1/cache_controller.py`:
  - 新增 `_get_numa_node_for_gpu()` 方法,通过 nvidia-smi 或 sysfs 获取 GPU 对应的 NUMA 节点
  - 新增 `_bind_to_closest_numa_node()` 方法,绑定当前线程到 GPU 最近的 NUMA 节点
  - 在 `initialize_host_cache()` 中调用 NUMA 绑定,优化 H2D 传输性能
- `tests/cache_manager/v1/test_swap_cache_ops.py`:跳过所有测试类(`TestSwapCacheAllLayersCorrectness`、`TestSwapCacheAllLayersPerformance`、`TestSwapCacheRandomBlockIndices`)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix unittest failures for cache_manager_v1

三个单测因接口变更或 Mock 方式问题导致失败,需修复。

- tests/distributed/chunked_moe.py:`setup_model_runner` 使用 `__new__` 跳过 `__init__`,补加 `enable_cache_manager_v1 = False`,修复 `AttributeError`
- tests/engine/test_resource_manager.py:`PrefixCacheManager` 为局部导入,`patch` 路径改为定义位置 `fastdeploy.cache_manager.prefix_cache_manager.PrefixCacheManager`
- tests/v1/test_resource_manager_v1.py:`_trigger_preempt` 第四参数已由 `list` 改为 `BatchRequest`,更新测试传参和断言

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] remove debug logging code

## Modifications

- fastdeploy/engine/request.py:删除调试用 logger 及 prompt_hashes 中的 debug 日志
- fastdeploy/worker/worker_process.py:删除 __main__ 中的调试 import 和 print 语句

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix cupy device id caching and pickle for _match_result

## Motivation

修复两个 bug:
1. `transfer_manager.py` 中每次调用 `cp.cuda.runtime.getDevice()` 存在隐患,应在初始化时缓存为实例变量,保证后续操作使用一致的设备 ID。
2. `request.py` 的 `__getstate__` 未跳过 `_match_result`,该字段包含 BlockNode 树的父子循环引用,pickle 时会触发 `RecursionError`;同时补充 `__setstate__` 确保 unpickle 后字段恢复为安全默认值。

## Modifications

- `transfer_manager.py`:初始化时调用 `cp.cuda.runtime.getDevice()` 并缓存到 `self._cupy_device_id`,后续 `with cp.cuda.Device(...)` 和日志均使用该缓存值。
- `request.py`:
  - `__getstate__` 中将 `_match_result` 加入跳过集合 `_SKIP_KEYS`,避免循环引用导致 pickle 失败。
  - 新增 `__setstate__`,unpickle 后将 `_block_hasher` 和 `_match_result` 恢复为 `None`。

## Usage or Command

* fix(test): fix unit test errors for _trigger_preempt and wakeup with MTP

- Use BatchRequest instead of list in test_trigger_preempt_records_tasks
- Add missing enable_cache_manager_v1 attr in TestSleepWakeupBehavior._make_runner

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix gpu_free_block_list returning wrong block IDs

## Motivation

`gpu_free_block_list` 的兼容 property 中误用了 `list(range(N))`,
将 `available_blocks()` 的返回值当作整数传给 `range()`,
导致返回 `[0, 1, ..., N-1]` 的假列表,而非真实的空闲 block ID。

## Modifications

- `cache_manager/v1/cache_manager.py`:将 `list(range(self._device_pool.available_blocks()))` 改为 `list(self._device_pool.available_blocks())`

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] 修复 gpu_free_block_list 返回 int 导致 TypeError

## Motivation

gpu_free_block_list 属性中调用 BlockPool.available_blocks(),
该方法返回 int(空闲块数量),用 list() 包装 int 会触发
TypeError: 'int' object is not iterable。

## Modifications

将 list(self._device_pool.available_blocks()) 改为
list(self._device_pool._free_blocks),直接返回空闲块索引列表。

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [KVCache][CacheManager] 适配 V1 CacheManager 的 pause/sleep/free_cache 操作

## Motivation

V1 CacheManager 引入了新的 reset_cache() 接口,pause 和 sleep 操作需要适配,
同时 free_cache 需要支持可选的 clear_storage 参数。

## Modifications

- cache_controller.py: free_cache 新增 clear_storage 参数(默认 False),
  仅当 clear_storage=True 时才调用 _clear_storage(),避免不必要的 storage 清空
- common_engine.py: pause 和 sleep 操作中,当 ENABLE_V1_KVCACHE_MANAGER 时
  使用 cache_manager.reset_cache() 替代旧的 reset() 和 pause_transfer 逻辑
- gpu_model_runner.py: sleep 时仅在非 V1 cache manager 下执行 MTP cache 清除

## Usage or Command

# 启动服务(V1 CacheManager)
python -m fastdeploy.entrypoints.openai.api_server \
  --enable-v1-kvcache-manager \
  ...

* [BugFix][KVCache] fix missing enable_cache_manager_v1 in test mocks and remove unused select_blocks_for_backup

- Remove unused `select_blocks_for_backup` method from radix_tree.py
- Fix `match_prefix` default param `skip_storage=True` and log order in cache_manager.py
- Sync test_gpu_model_runner.py with upstream/develop (add TestInsertTasksV1SplitwiseSuffix)
- Add `enable_cache_manager_v1=False` to all mock runners to fix AttributeError in CI

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] simplify _free_blocks in ResourceManagerV1 for non-v1 path

Remove redundant prefix_caching branch in else path; always call
recycle_gpu_blocks with full block_tables for non-cache-manager-v1 case.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [KVCache][Optimization][BugFix] fix and optimize block_pool, cache_manager, transfer_manager, request

## Motivation

修复 cache_manager v1 中若干代码质量问题,提升性能并消除潜在的类型不一致 Bug。

## Modifications

1. **block_pool.py**:`BlockPool.allocate` 将逐个 pop 循环替换为切片 + 批量 set.update,消除 Python 循环开销,O(n) → O(k)(C 层批量操作)
2. **cache_manager.py**:`match_prefix` 在 prefix caching 关闭时提前 return 前写入空 `MatchResult()`,避免调用方解引用 `_match_result=None` 崩溃
3. **transfer_manager.py**:`_build_device_layer_indices` 在 `_cache_kvs_map` 为空时也重置四个层索引列表,防止残留旧 tensor 被 swap 算子使用
4. **request.py**:`BatchRequest.append_swap_metadata` / `append_evict_metadata` 构造 `CacheSwapMetadata` 时将 `src_type`/`dst_type` 从字符串改为 `CacheLevel` 枚举,与字段类型声明一致;补充 `CacheLevel` 导入;`match_result` 属性返回类型标注修正为 `Optional[MatchResult]`
5. **resource_manager_v1.py**:`_allocate_gpu_blocks` 日志从 `INFO` 降级为 `DEBUG`,消除高频调度路径的日志噪音
6. **tests/engine/test_request.py**:同步更新 `src_type`/`dst_type` 断言为 `CacheLevel` 枚举值,补充 `CacheLevel` 导入

## Usage or Command

单元测试:
```bash
source .venv/py310/bin/activate
cd baidu/FastDeploy
python -m pytest tests/cache_manager/v1/test_cache_manager.py -v
python -m pytest tests/cache_manager/v1/test_transfer_manager.py -v
python -m pytest tests/engine/test_request.py -v
```

* [BugFix][KVCache] Fix BlockPool.allocate returns all blocks when num_blocks=0

## Motivation

当 `allocate(num_blocks=0)` 被调用时,Python 负索引陷阱导致严重错误:
`-0 == 0`,所以 `self._free_blocks[-0:]` 等价于 `self._free_blocks[0:]`,
会返回并清空整个空闲块列表,而非返回空列表。

## Modifications

在 `BlockPool.allocate` 中增加对 `num_blocks == 0` 的提前判断,直接返回 `[]`,
避免触发 Python 负索引陷阱。

## Usage or Command

```bash
# 运行相关单元测试验证修复
python -m pytest tests/cache_manager/v1/test_cache_manager.py -vv -s
```

* [KVCache][Test] add unit tests for cache_manager v1 modules

## Motivation

补全 cache_manager/v1 各模块的单测覆盖,确保核心方法有完整的测试保障。

## Modifications

新增/补充以下测试文件,全部 326 个用例通过:

- tests/cache_manager/v1/test_block_pool.py(新建)
  覆盖 BlockPool.get_metadata/set_metadata/resize、DeviceBlockPool/HostBlockPool
- tests/cache_manager/v1/test_metadata.py(新建)
  覆盖 BlockNode、RadixTreeStats、MatchResult、CacheSwapMetadata、AsyncTaskHandler
- tests/cache_manager/v1/test_cache_utils.py(补充)
  新增 hash_block_tokens、get_request_block_hasher、LayerDoneCounter 时间追踪及内部辅助方法
- tests/cache_manager/v1/test_radix_tree.py(补充)
  新增 TestCompleteSwapToDevice 专项测试类(6 个用例)
- tests/cache_manager/v1/test_cache_manager.py(补充)
  新增 offload_to_host、load_from_host、pending backup 系列、prepare_prefetch_metadata
- tests/cache_manager/v1/test_transfer_manager.py(补充)
  新增 _swap_single_layer 校验路径、sync_input/output_stream、record_input_stream_event

## Usage or Command

```bash
# 运行所有新增单测
source .venv/py310/bin/activate
python -m pytest tests/cache_manager/v1/test_block_pool.py \
  tests/cache_manager/v1/test_metadata.py \
  tests/cache_manager/v1/test_cache_utils.py \
  tests/cache_manager/v1/test_radix_tree.py \
  tests/cache_manager/v1/test_cache_manager.py \
  tests/cache_manager/v1/test_transfer_manager.py -v
# 期望结果:326 passed
```

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
2026-04-21 14:39:00 +08:00

2709 lines
129 KiB
Python

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import asyncio
import collections
import copy
import json
import multiprocessing
import os
import re
import signal
import subprocess
import sys
import threading
import time
import traceback
import weakref
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
import paddle
import zmq
from tqdm import tqdm
import fastdeploy.metrics.trace as tracing
from fastdeploy.cache_manager.cache_data import CacheStatus
from fastdeploy.config import FDConfig
from fastdeploy.engine.register_manager import RegisterManager
from fastdeploy.engine.request import (
CompletionOutput,
ControlRequest,
ControlResponse,
Request,
RequestMetrics,
RequestOutput,
RequestStatus,
RequestType,
)
from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
from fastdeploy.engine.sched.scheduler_metrics_logger import SchedulerMetricsLogger
from fastdeploy.eplb.utils import init_eplb_signals
from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.inter_communicator import (
EngineCacheQueue,
EngineWorkerQueue,
IPCSignal,
ZmqIpcServer,
ZmqTcpServer,
)
from fastdeploy.inter_communicator.fmq import FMQ
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.model_executor.guided_decoding import schema_checker
from fastdeploy.plugins.token_processor import load_token_processor_plugins
from fastdeploy.spec_decode import SpecMethod
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
from fastdeploy.trace.constants import LoggingEventName
from fastdeploy.trace.trace_logger import print as trace_print
from fastdeploy.utils import EngineError, console_logger, envs, get_logger, llm_logger
try:
TokenProcessor = load_token_processor_plugins()
llm_logger.info(f"TokenProcessor plugin {TokenProcessor} loaded")
except:
from fastdeploy.output.token_processor import TokenProcessor
def _read_latest_worker_traceback(log_dir: str) -> Optional[str]:
"""读取 workerlog.* 文件中的最新 traceback。"""
try:
candidates = sorted(Path(log_dir).glob("workerlog.*"), key=lambda path: path.stat().st_mtime, reverse=True)
except OSError:
return None
for path in candidates:
try:
content = path.read_text(encoding="utf-8", errors="ignore")
except OSError:
continue
marker = "Traceback (most recent call last):"
start = content.rfind(marker)
if start != -1:
return content[start:].strip()
return None
def _format_worker_launch_failure_message(log_dir: str) -> str:
"""格式化 worker 启动失败的错误消息,包含 traceback 信息。"""
message = "Failed to launch worker processes, check log/workerlog.* for more details."
traceback_text = _read_latest_worker_traceback(log_dir)
if traceback_text:
return f"{message}\n{traceback_text}"
return message
class EngineService:
"""
Base class containing common engine functionality
"""
def __init__(self, cfg: FDConfig, start_queue=True, use_async_llm=False):
"""
Initializes the LLMEngine with the provided configuration.
Args:
cfg (Config): Config object containing all the configuration parameters.
"""
self.cfg = cfg
self.use_async_llm = use_async_llm
if self.cfg.parallel_config.data_parallel_size > 1:
self.llm_logger = get_logger(
"fastdeploy", f"fastdeploy_dprank{self.cfg.parallel_config.local_data_parallel_id}.log"
)
else:
self.llm_logger = llm_logger
self.is_paused = False # pause request generation
self._pause_cond = threading.Condition()
self._ctrl_output_queues = {}
self._ctrl_response_mailboxes = collections.defaultdict(collections.OrderedDict)
tp_size = cfg.parallel_config.tensor_parallel_size
dp_index = cfg.parallel_config.local_data_parallel_id
for tp_rank in range(tp_size):
# create worker control response queue
engine_worker_queue_port = self.cfg.parallel_config.local_engine_worker_queue_port
name = f"ctrl_w2e_rank{tp_rank+tp_size*dp_index}_{engine_worker_queue_port}"
self.llm_logger.info(f"Init Worker Control Output Queue: {name} (consumer)")
self._ctrl_output_queues[name] = FMQ().queue(name, "consumer")
# create cache control response queue
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
engine_cache_queue_port = self.cfg.cache_config.local_cache_queue_port
name = f"ctrl_c2e_rank{tp_rank+tp_size*dp_index}_{engine_cache_queue_port}"
self.llm_logger.info(f"Init Cache Control Output Queue: {name} (consumer)")
self._ctrl_output_queues[name] = FMQ().queue(name, "consumer")
self.scheduler = cfg.scheduler_config.scheduler()
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.llm_logger.info("Use V1 KVCache Scheduler")
self.resource_manager = ResourceManagerV1(
cfg.scheduler_config.max_num_seqs,
cfg,
cfg.parallel_config.tensor_parallel_size,
cfg.scheduler_config.splitwise_role,
cfg.parallel_config.local_data_parallel_id,
)
else:
self.llm_logger.info("Use V0 KVCache Scheduler")
self.resource_manager = ResourceManager(
cfg.scheduler_config.max_num_seqs,
cfg,
cfg.parallel_config.tensor_parallel_size,
cfg.scheduler_config.splitwise_role,
cfg.parallel_config.local_data_parallel_id,
)
self.start_worker_queue_service(start_queue)
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.cfg.parallel_config.local_engine_worker_queue_port)
self.llm_logger.info(f"INFERENCE_MSG_QUEUE_ID: {str(self.cfg.parallel_config.local_engine_worker_queue_port)}")
self.split_connector = SplitwiseConnector(cfg, self.engine_worker_queue, self.resource_manager)
self.token_processor = TokenProcessor(
cfg=cfg,
cached_generated_tokens=self.scheduler,
engine_worker_queue=self.engine_worker_queue,
split_connector=self.split_connector,
)
self.token_processor.set_resource_manager(self.resource_manager)
self.scheduler_metrics_logger = SchedulerMetricsLogger(
enabled=True,
dp_rank=self.cfg.parallel_config.local_data_parallel_id,
)
self.resource_manager.scheduler_metrics_logger = self.scheduler_metrics_logger
self.token_processor.set_scheduler_metrics_logger(self.scheduler_metrics_logger)
self.partial_chunked_tokens = [0] * (self.cfg.max_num_partial_prefills + 1)
for idx in range(1, self.cfg.max_num_partial_prefills + 1):
self.partial_chunked_tokens[idx] = (
(self.cfg.scheduler_config.max_num_batched_tokens // idx)
// self.cfg.cache_config.block_size
* self.cfg.cache_config.block_size
)
self.bos_client = None
self.mm_max_tokens_per_item = None
self.guided_decoding_checker = None
if self.cfg.structured_outputs_config.guided_decoding_backend != "off":
self.guided_decoding_checker = schema_checker(
self.cfg.structured_outputs_config.guided_decoding_backend,
disable_any_whitespace=self.cfg.structured_outputs_config.disable_any_whitespace,
)
self._init_worker_monitor_signals()
# Initialize RegisterManager
self._register_manager = RegisterManager(
cfg=self.cfg,
engine_worker_queue=self.engine_worker_queue,
get_is_paused=self._get_is_paused_safe,
)
if self.cfg.eplb_config.enable_eplb:
current_suffix = self.cfg.parallel_config.local_engine_worker_queue_port
init_eplb_signals(cfg, current_suffix)
if self.use_async_llm:
# Add worker management attributes
self.worker_proc = None
self.do_profile = 1 if self.cfg.cache_config.num_gpu_blocks_override is None else 0
self.ipc_signal_suffix = None
self.cache_manager_processes = None
if envs.ENABLE_V1_KVCACHE_MANAGER:
from fastdeploy.cache_manager.v1.cache_utils import get_request_block_hasher
self._block_hasher = get_request_block_hasher(block_size=self.cfg.cache_config.block_size)
self._finalizer = weakref.finalize(self, self._exit_sub_services)
def start(self, async_llm_pid=None):
self.running = True
console_logger.debug("Start engineService...")
if self.use_async_llm:
self.start_worker_service(async_llm_pid)
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.insert_task_to_worker_thread = threading.Thread(
target=self._schedule_request_to_worker_v1, daemon=True
)
else:
self.insert_task_to_worker_thread = threading.Thread(target=self._schedule_request_to_worker, daemon=True)
self.insert_task_to_worker_thread.start()
self.token_processor.tasks_queue = self.engine_worker_queue
self.token_processor.run()
if self.cfg.scheduler_config.splitwise_role == "decode":
self._decode_process_splitwise_requests()
self._register_manager.start()
def start_worker_service(self, async_llm_pid=None):
# Initialize IPC signals for worker management
self.ipc_signal_suffix = self.cfg.parallel_config.engine_worker_queue_port[0]
self._init_worker_signals()
# Create data processor if not exists
if not hasattr(self, "data_processor"):
self.create_data_processor()
# Launch components: scheduler, cache_manager, expert_service et.al.
self.launch_components()
# If block number is specified and model is deployed in splitwise mode, start cache manager first
if (
not self.do_profile
and self.cfg.scheduler_config.splitwise_role != "mixed"
and not envs.ENABLE_V1_KVCACHE_MANAGER
):
device_ids = self.cfg.parallel_config.device_ids.split(",")
self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix)
# Start worker processes
self.worker_proc = self._start_worker_service()
time.sleep(5)
self.worker_init_status = dict()
result_container = {}
def check_worker_initialize_status_func(res: dict):
res["worker_is_alive"] = True
if not self.check_worker_initialize_status():
self.llm_logger.error(_format_worker_launch_failure_message(envs.FD_LOG_DIR))
res["worker_is_alive"] = False
self.check_worker_initialize_status_func_thread = threading.Thread(
target=check_worker_initialize_status_func, args=(result_container,), daemon=True
)
self.check_worker_initialize_status_func_thread.start()
# Wait model loading
while self.loaded_model_signal.value[0] == 0:
# Make sure worker process is alive
if not self.check_worker_initialize_status_func_thread.is_alive():
return False
time.sleep(1)
# If block number is not specified, let workers do profiling to determine the block number,
# and then start the cache manager
if self.do_profile:
self._stop_profile()
elif (
self.cfg.scheduler_config.splitwise_role == "mixed"
and self.cfg.cache_config.enable_prefix_caching
and not envs.ENABLE_V1_KVCACHE_MANAGER
):
device_ids = self.cfg.parallel_config.device_ids.split(",")
self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix)
# Worker launched
self.check_worker_initialize_status_func_thread.join()
if not result_container["worker_is_alive"]:
self.llm_logger.error(_format_worker_launch_failure_message(envs.FD_LOG_DIR))
return False
# Start ZMQ service for communication with AsyncLLM
if async_llm_pid:
self.start_zmq_service(async_llm_pid)
def create_data_processor(self):
self.input_processor = InputPreprocessor(
self.cfg.model_config,
self.cfg.structured_outputs_config.reasoning_parser,
self.cfg.limit_mm_per_prompt,
self.cfg.mm_processor_kwargs,
self.cfg.tool_parser,
enable_mm_runtime=self.cfg.enable_mm_runtime,
)
self.data_processor = self.input_processor.create_processor()
self.mm_max_tokens_per_item = self.data_processor.get_mm_max_tokens_per_item(
self.cfg.model_config.max_model_len
)
if self.mm_max_tokens_per_item is not None:
max_chunk_tokens = self.cfg.get_max_chunk_tokens(self.mm_max_tokens_per_item)
self.cfg.cache_config.postprocess(max_chunk_tokens, self.cfg.scheduler_config.max_num_seqs)
def _init_worker_monitor_signals(self): # exist_task_signal 用于各worker进程感知是否有新Task需要处理
current_suffix = self.cfg.parallel_config.local_engine_worker_queue_port
self.llm_logger.info(f"current_suffix: {current_suffix}")
exist_task_signal_data = np.zeros([1], dtype=np.int32)
self.exist_task_signal = IPCSignal(
name="exist_task_signal",
array=exist_task_signal_data,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
# exist_swapped_task_signal 用于engine感知worker中是否存在swapped task
exist_swapped_task_signal_data = np.zeros([1], dtype=np.int32)
self.exist_swapped_task_signal = IPCSignal(
name="exist_swapped_task_signal",
array=exist_swapped_task_signal_data,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
# exist_prefill_task_signal 用于各worker进程感知是否进行prefill
exist_prefill_task_signal_data = np.zeros([1], dtype=np.int32)
self.exist_prefill_task_signal = IPCSignal(
name="exist_prefill_task_signal",
array=exist_prefill_task_signal_data,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
engine_forward_signal_data = np.zeros([1], dtype=np.int32)
self.engine_forward_signal = IPCSignal(
name="engine_forward_signal",
array=engine_forward_signal_data,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
# worker_live_signal 用于engine感知各worker进程是否存活,记录每个step 时间
worker_healthy_live_recorded_time_array = np.zeros(
shape=[min(self.cfg.worker_num_per_node, self.cfg.parallel_config.tensor_parallel_size)], dtype=np.int32
)
self.worker_healthy_live_signal = IPCSignal(
name="worker_healthy_live_signal",
array=worker_healthy_live_recorded_time_array,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
cache_ready_signal_data = np.zeros(shape=[self.cfg.parallel_config.tensor_parallel_size], dtype=np.int32)
self.cache_ready_signal = IPCSignal(
name="cache_ready_signal",
array=cache_ready_signal_data,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
swap_space_ready_signal_data = np.zeros(shape=[self.cfg.parallel_config.tensor_parallel_size], dtype=np.int32)
self.swap_space_ready_signal = IPCSignal(
name="swap_space_ready_signal",
array=swap_space_ready_signal_data,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
cache_transfer_inited_signal_data = np.zeros(
shape=[self.cfg.parallel_config.tensor_parallel_size], dtype=np.int32
)
self.cache_transfer_inited_signal = IPCSignal(
name="cache_transfer_inited_signal",
array=cache_transfer_inited_signal_data,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
model_weights_status = np.zeros([1], dtype=np.int32)
self.model_weights_status_signal = IPCSignal(
name="model_weights_status",
array=model_weights_status,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
prefix_tree_status = np.zeros([1], dtype=np.int32)
self.prefix_tree_status_signal = IPCSignal(
name="prefix_tree_status",
array=prefix_tree_status,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
kv_cache_status = np.zeros([1], dtype=np.int32)
self.kv_cache_status_signal = IPCSignal(
name="kv_cache_status",
array=kv_cache_status,
dtype=np.int32,
suffix=current_suffix,
create=True,
)
def start_worker_queue_service(self, start_queue):
"""
start queue service for engine worker communication
"""
if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
address = (self.cfg.master_ip, self.cfg.parallel_config.local_engine_worker_queue_port)
else:
address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.local_engine_worker_queue_port}.sock"
if self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0":
if start_queue:
self.llm_logger.info(f"Starting engine worker queue server service at {address}")
self.engine_worker_queue_server = EngineWorkerQueue(
address=address,
is_server=True,
num_client=self.cfg.parallel_config.tensor_parallel_size,
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
)
# Dynamically updates the port value if an anonymous port is used
if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
self.cfg.parallel_config.local_engine_worker_queue_port = (
self.engine_worker_queue_server.get_server_port()
)
address = (
self.cfg.master_ip,
self.cfg.parallel_config.local_engine_worker_queue_port,
)
if not envs.ENABLE_V1_KVCACHE_MANAGER:
if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
self.llm_logger.info(
f"Starting engine cache queue server service at {self.cfg.cache_config.local_cache_queue_port}"
)
self.cache_task_queue = EngineCacheQueue(
address=(self.cfg.master_ip, self.cfg.cache_config.local_cache_queue_port),
authkey=b"cache_queue_service",
is_server=True,
num_client=self.cfg.parallel_config.tensor_parallel_size,
client_id=-1,
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
)
self.cfg.cache_config.local_cache_queue_port = self.cache_task_queue.get_server_port()
self.engine_worker_queue = EngineWorkerQueue(
address=address,
is_server=False,
num_client=self.cfg.parallel_config.tensor_parallel_size,
client_id=0,
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
local_data_parallel_id=self.cfg.parallel_config.local_data_parallel_id,
)
def insert_tasks(self, tasks: List[Request], current_id=-1):
"""
Allocate resource and insert tasks to engine.
Used in v0_kvcache_scheduler.
"""
if not isinstance(tasks, list):
tasks = [tasks]
self.resource_manager.check_and_free_block_tables()
need_delete_tasks = []
for task in tasks:
rid = task.request_id.split("_")[0]
trace_carrier = task.trace_carrier
if trace_carrier:
tracing.trace_set_proc_propagate_context(rid, trace_carrier)
task.trace_carrier = tracing.trace_get_proc_propagate_context(rid)
if self.cfg.scheduler_config.splitwise_role == "prefill":
status, msg = self.split_connector.check_decode_allocated(task)
if status:
task.metrics.ask_decode_resource_finish_time = time.time()
else:
self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=500,
error_msg=msg,
)
]
)
need_delete_tasks.append(task)
continue
for tmp_task in need_delete_tasks:
tasks.remove(tmp_task)
for item in tasks:
trace_print(LoggingEventName.RESOURCE_ALLOCATE_START, item.request_id, getattr(item, "user", ""))
available_batch = np.sum(self.resource_manager.stop_flags)
if len(tasks) > available_batch:
self.llm_logger.error(f"Inserting batch:{len(tasks)} exceeds the available batch:{available_batch}.")
self.llm_logger.error("The exceeded part will be ignored!")
tasks = tasks[:available_batch]
req_ids = [t.request_id for t in tasks]
tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks)
if not tasks:
error_msg = f"The request required resources is exceed the limit, request id={req_ids}."
self.llm_logger.error(error_msg)
raise EngineError(error_msg, error_code=500)
return False
self.token_processor.number_of_tasks += len(tasks)
is_decode = False
is_prefill = False
for i in range(len(tasks)):
if tasks[i].disaggregate_info is not None:
if self.cfg.scheduler_config.splitwise_role == "decode":
is_decode = True
else:
is_prefill = True
self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len
if self.cfg.scheduler_config.splitwise_role == "prefill":
self.split_connector.send_cache_info_to_messager(tasks, current_id)
elif self.cfg.scheduler_config.splitwise_role == "decode":
self.split_connector.send_cache_info_to_prefill(tasks)
if not is_decode:
self.llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
for task in tasks:
if not getattr(task, "has_been_preempted_before", False):
task.metrics.inference_start_time = time.time()
tracing.trace_report_span(
tracing.TraceSpanName.SCHEDULE,
task.request_id.split("_")[0],
int(task.metrics.scheduler_recv_req_time * 1e9),
int(task.metrics.inference_start_time * 1e9),
thread_finish_flag=True,
)
trace_print(LoggingEventName.RESOURCE_ALLOCATE_END, task.request_id, getattr(task, "user", ""))
trace_print(LoggingEventName.REQUEST_SCHEDULE_END, task.request_id, getattr(task, "user", ""))
trace_print(LoggingEventName.INFERENCE_START, task.request_id, getattr(task, "user", ""))
else:
trace_print(
LoggingEventName.RESCHEDULED_INFERENCE_START, task.request_id, getattr(task, "user", "")
)
if not is_prefill:
if not self.cfg.enable_mm_runtime:
self.update_requests_chunk_size(tasks)
else:
self.update_mm_requests_chunk_size(tasks)
self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
return True
def _insert_prefilled_requests(self, request_outputs: List[RequestOutput]):
"""
Decode insert prefilled requests into engine worker queue.
Used in v0_kvcache_scheduler.
Args:
request_outputs: a list of RequestOutput sent by prefill instance
"""
to_infer_reqs = []
for req_out in request_outputs:
solt_idx = self.resource_manager.req_dict[req_out.request_id]
del self.resource_manager.req_dict[req_out.request_id]
cur_req = self.resource_manager.tasks_list[solt_idx]
if envs.FD_ENABLE_INTERNAL_ADAPTER:
if not req_out.outputs.token_ids: # first token is eos in Prefill, just recycle resource and continue
self.resource_manager.stop_flags[solt_idx] = True
self.resource_manager.tasks_list[solt_idx] = None
self.resource_manager._recycle_block_tables(cur_req)
if req_out.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[req_out.request_id]
self.llm_logger.warning(f"{req_out.request_id} need not decode after first token")
continue
cur_req.prompt_token_ids[0] = req_out.outputs.token_ids[0]
cur_req.num_cached_tokens = req_out.num_cached_tokens
req_out.metrics.decode_recv_req_time = cur_req.metrics.decode_recv_req_time
req_out.metrics.decode_preallocate_req_time = cur_req.metrics.decode_preallocate_req_time
cur_req.metrics = req_out.metrics
cur_req.metrics.decode_inference_start_time = time.time()
if (
self.cfg.speculative_config.method == SpecMethod.MTP
and self.cfg.scheduler_config.splitwise_role == "decode"
):
cur_req.draft_token_ids = copy.deepcopy(req_out.outputs.draft_token_ids)
if req_out.error_code != 200:
self.resource_manager.stop_flags[solt_idx] = True
self.resource_manager.tasks_list[solt_idx] = None
self.resource_manager._recycle_block_tables(cur_req)
if req_out.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[req_out.request_id]
self.scheduler.put_results([req_out])
self.llm_logger.warning(
f"{req_out.request_id} prefill failed with msg:{req_out.error_msg}, recycle resource."
)
continue
self.token_processor.tokens_counter[req_out.request_id] = 1
to_infer_reqs.append(cur_req)
if to_infer_reqs:
self.engine_worker_queue.put_tasks((to_infer_reqs, self.resource_manager.real_bsz))
self.llm_logger.debug(f"put requests to engine worker queue, task:{to_infer_reqs}")
return True
def task_is_finished(self, index):
"""
judge if the task is finished
"""
assert index < len(self.resource_manager.stop_flags)
return self.resource_manager.stop_flags[index]
def all_tasks_finished(self):
"""
judge if all tasks are finished
"""
return np.sum(self.resource_manager.stop_flags) == len(self.resource_manager.stop_flags)
def update_requests_chunk_size(self, requests):
"""
update each request's chunk size info
"""
def update_tokens(idx, chunk_size, update_chunk=False):
nonlocal remain_batched_tokens, chunk_request_num
if update_chunk:
requests_chunk[idx][-1] += chunk_size
else:
requests_chunk[idx].append(chunk_size)
remain_batched_tokens -= chunk_size
current_request_size[idx] -= chunk_size
if current_request_size[idx] <= 0:
chunk_request_num -= 1
if not self.cfg.cache_config.enable_chunked_prefill or len(requests) == 0:
return
current_request_size = [request.prompt_token_ids_len for request in requests]
requests_chunk = [[] for _ in range(len(requests))]
chunk_request_num = len(current_request_size)
while chunk_request_num >= 1:
remain_batched_tokens = self.cfg.scheduler_config.max_num_batched_tokens
for idx in range(len(current_request_size)):
if current_request_size[idx] <= 0:
continue
chunk_size = min(
current_request_size[idx],
self.partial_chunked_tokens[chunk_request_num],
)
update_tokens(idx, chunk_size)
while remain_batched_tokens >= self.cfg.cache_config.block_size:
# 当前 max_num_batched_tokens 还有剩余时,优先分配给较短的请求
waiting_requests = [input_lens for input_lens in current_request_size if input_lens > 0]
if len(waiting_requests) == 0:
break
available_tokens = (
remain_batched_tokens // self.cfg.cache_config.block_size * self.cfg.cache_config.block_size
)
append_idx = current_request_size.index(min(waiting_requests))
chunk_size = min(
current_request_size[append_idx],
self.partial_chunked_tokens[chunk_request_num],
available_tokens,
)
update_tokens(append_idx, chunk_size, update_chunk=True)
for idx in range(len(requests)):
requests[idx].set("prefill_chunk_info", requests_chunk[idx])
def update_mm_requests_chunk_size(self, requests):
"""
update each multimodal request's chunk size info
"""
if not self.cfg.cache_config.enable_chunked_prefill or len(requests) == 0:
return
for request in requests:
inputs = request.multimodal_inputs
# 兼容没有图片和视频的情况
if inputs["images"] is None:
inputs["image_type_ids"] = np.array([], dtype="int32")
inputs["grid_thw"] = np.array([], dtype="int64")
inputs["images"] = np.array([], dtype="uint8")
input_ids = paddle.to_tensor(inputs["input_ids"], dtype="int64")
image_type_ids = paddle.to_tensor(inputs["image_type_ids"], dtype="int32")
image_mask = input_ids == self.data_processor.image_patch_id
image_token_sum = paddle.full(shape=[len(input_ids) + 1], fill_value=0, dtype="int32")
image_token_sum[1:] = paddle.cumsum(image_mask.cast("int32"), dtype="int32")
grid_thw = []
for one in inputs["grid_thw"]:
if one[0] == 1:
grid_thw.append(one)
else:
grid_thw.extend([[2, one[1], one[2]]] * (one[0] // 2))
grid_thw = paddle.to_tensor(grid_thw, dtype="int64")
from fastdeploy.model_executor.ops.gpu import get_mm_split_fuse
chunk_image_num, chunk_seq_len = get_mm_split_fuse(
input_ids,
image_type_ids,
image_token_sum,
grid_thw,
self.data_processor.image_patch_id,
len(grid_thw),
0,
len(input_ids),
0,
self.partial_chunked_tokens[1],
2048,
)
grid_thw = grid_thw.numpy().reshape([-1, 3])
num_chunks = len(chunk_image_num)
chunks_info = []
input_ids_st, image_type_ids_st, grid_thw_st, patch_st = 0, 0, 0, 0
for idx in range(num_chunks):
chunk_input_ids = inputs["input_ids"][input_ids_st : input_ids_st + chunk_seq_len[idx]]
chunk_token_type_ids = inputs["token_type_ids"][input_ids_st : input_ids_st + chunk_seq_len[idx]]
actual_image_num = np.sum(grid_thw[grid_thw_st : grid_thw_st + chunk_image_num[idx], 0])
chunk_image_type_ids = inputs["image_type_ids"][
image_type_ids_st : image_type_ids_st + actual_image_num
]
chunk_grid_thw = grid_thw[grid_thw_st : grid_thw_st + chunk_image_num[idx]]
chunk_patch_num = np.sum(np.prod(chunk_grid_thw, axis=1))
chunk_images = inputs["images"][patch_st : patch_st + chunk_patch_num]
chunk_position_ids = inputs["position_ids"][input_ids_st : input_ids_st + chunk_seq_len[idx]]
chunks_info.append(
{
"input_ids": chunk_input_ids,
"token_type_ids": chunk_token_type_ids,
"image_type_ids": (chunk_image_type_ids if chunk_image_type_ids.shape[0] else None),
"grid_thw": (chunk_grid_thw if chunk_grid_thw.shape[0] else None),
"images": (chunk_images if chunk_images.shape[0] else None),
"position_ids": chunk_position_ids,
}
)
input_ids_st += chunk_seq_len[idx]
image_type_ids_st += actual_image_num
grid_thw_st += chunk_image_num[idx]
patch_st += chunk_patch_num
request.set("prefill_chunk_info", chunks_info)
def _schedule_request_to_worker(self):
"""
Insert task to engine thread, monitor scheduler request queue.
if the engine has resource, insert task to engine
"""
tracing.trace_set_thread_info("Scheduler Task to Work")
current_id = 0
while getattr(self, "running", True):
try:
if self.resource_manager.available_batch() == 0:
time.sleep(0.001)
continue
if self.engine_worker_queue.exist_tasks():
time.sleep(0.001)
continue
if hasattr(self, "exist_prefill_task_signal") and self.exist_prefill_task_signal.value[0] > 0:
if (
self.cfg.scheduler_config.splitwise_role == "mixed"
or self.split_connector.has_splitwise_tasks()
):
time.sleep(0.005)
continue
if self.engine_worker_queue.num_cache_infos() > 0:
time.sleep(0.001)
continue
if len(self.split_connector.current_request_ids) > 0:
time.sleep(0.001)
continue
num_prefill_batch = min(
int(self.resource_manager.available_batch()),
self.cfg.max_prefill_batch,
)
self.resource_manager.check_and_free_block_tables()
tasks = self.scheduler.get_requests(
available_blocks=self.resource_manager.available_block_num(),
block_size=self.cfg.cache_config.block_size,
reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
max_num_batched_tokens=self.cfg.scheduler_config.max_num_batched_tokens,
batch=num_prefill_batch,
)
for task in tasks:
task.metrics.engine_get_req_time = time.time()
trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", ""))
if len(tasks) == 0:
time.sleep(0.001)
continue
if self.cfg.scheduler_config.splitwise_role == "decode":
# TODO: refine scheduler to remove this limitation
# Decode will process and schedule the request sent by prefill to engine,
# so the same request sent by the decode api server will be ignored
continue
self.llm_logger.debug(f"get tasks from scheduler: {tasks}")
if self.cfg.scheduler_config.splitwise_role != "mixed":
for task in tasks:
task.metrics.ask_decode_resource_start_time = time.time()
self.split_connector.send_splitwise_tasks(tasks, current_id)
insert_successful = self.insert_tasks(tasks, current_id)
if insert_successful:
current_id = current_id + 1
else:
continue
main_process_metrics.num_requests_waiting.dec(len(tasks))
main_process_metrics.num_requests_running.inc(len(tasks))
except Exception as e:
err_msg = f"Error happened while insert task to engine: {e}, {traceback.format_exc()!s}."
self.llm_logger.error(err_msg)
def _schedule_request_to_worker_v1(self):
"""
Insert tasks to worker with scheduler v1 (ENABLE_V1_KVCACHE_SCHEDULER=1).
"""
tracing.trace_set_thread_info("Scheduler Task to Work")
get_request_pool = ThreadPoolExecutor(max_workers=1)
is_fetching = False
def _fetch_request():
try:
with self._pause_cond:
self._pause_cond.wait_for(lambda: not self.is_paused)
nonlocal is_fetching
num_prefill_batch = min(
int(self.resource_manager.available_batch()),
self.cfg.max_prefill_batch,
)
if self.cfg.scheduler_config.splitwise_role != "mixed":
max_num_batched_tokens = self.cfg.scheduler_config.max_num_batched_tokens
else:
max_num_batched_tokens = self.cfg.model_config.max_model_len
available_blocks = self.cfg.cache_config.max_block_num_per_seq
tasks = self.scheduler.get_requests(
available_blocks=available_blocks,
block_size=self.cfg.cache_config.block_size,
reserved_output_blocks=0, # self.cfg.cache_config.enc_dec_block_num
max_num_batched_tokens=max_num_batched_tokens,
batch=num_prefill_batch,
)
for task in tasks:
task.metrics.engine_get_req_time = time.time()
trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", ""))
# cache_manager_v1 set block_hasher to request
if hasattr(self, "_block_hasher"):
task.set_block_hasher(self._block_hasher)
if self.cfg.scheduler_config.splitwise_role == "decode":
# TODO: refine scheduler to remove this limitation
# Decode will process and schedule the request sent by prefill to engine,
# so the same request sent by the decode api server will be ignored
is_fetching = False
return
if tasks:
self.llm_logger.debug(
f"Engine has fetched tasks from {self.scheduler.__class__.__name__}: {[task.request_id for task in tasks]}"
)
if self.cfg.scheduler_config.splitwise_role == "prefill":
for task in tasks:
# start async preprocess
self.resource_manager.apply_async_preprocess(task)
need_delete_tasks = []
if envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES:
for task in tasks:
# assure can allocate block ids in P
while not self.resource_manager.preallocate_resource_in_p(task):
time.sleep(0.005)
self.llm_logger.debug(
f"P has allocated resources and then ask D resource for request: {task.request_id}"
)
task.metrics.ask_decode_resource_start_time = time.time()
while True:
self.split_connector.send_splitwise_tasks([task], task.idx)
status, msg = self.split_connector.check_decode_allocated(task)
if not status:
self.llm_logger.warning(
f"D failed to allocate resource for request {task.request_id}, try again."
)
time.sleep(0.05)
else:
task.metrics.ask_decode_resource_finish_time = time.time()
break
self.llm_logger.debug(f"D has allocated resource for request: {task.request_id}")
else:
for task in tasks:
# assure can allocate block ids in P
while not self.resource_manager.preallocate_resource_in_p(task):
time.sleep(0.005)
self.llm_logger.debug(
f"P has allocated resources and then ask D resource for req_id: {task.request_id}"
)
task.metrics.ask_decode_resource_start_time = time.time()
self.split_connector.send_splitwise_tasks([task], task.idx)
for task in tasks:
# assure fetch block ids from D
status, msg = self.split_connector.check_decode_allocated(task)
task.metrics.ask_decode_resource_finish_time = time.time()
if not status:
self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=500,
error_msg=msg,
)
]
)
need_delete_tasks.append(task)
continue
for tmp_task in need_delete_tasks:
tasks.remove(tmp_task)
# release resource in P
self.resource_manager.pre_recycle_resource(tmp_task.request_id)
# to send cache info to cache messager
if tasks:
need_check_req_ids = [task.request_id for task in tasks]
self.split_connector.send_cache_info_to_messager(tasks, 0)
# ensure cache tasks has sent to cache_messager
need_check_req_ids = [task.request_id for task in tasks]
finished_ids, delete_tasks_list = [], []
while need_check_req_ids:
finished_ids.extend(self.engine_worker_queue.get_finished_add_cache_task_req())
self.llm_logger.debug(
f"P has successfully sent cache infos to cache messager for requests: {finished_ids}"
)
if finished_ids:
for task in tasks:
result = self.resource_manager.waiting_async_process(task)
if result is None:
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=task.error_code,
error_msg=task.error_message,
)
]
)
need_check_req_ids.remove(task.request_id)
delete_tasks_list.append(task)
elif result is False:
if task.request_id in finished_ids:
need_check_req_ids.remove(task.request_id)
finished_ids.remove(task.request_id)
else:
time.sleep(0.001)
for tmp_task in delete_tasks_list:
tasks.remove(tmp_task)
# release resource in P
self.resource_manager.pre_recycle_resource(tmp_task.request_id)
# Fetch requests and add them to the scheduling queue
if tasks:
for task in tasks:
task.metrics.add_req_to_resource_manager_time = time.time()
trace_print(
LoggingEventName.RESOURCE_ALLOCATE_START, task.request_id, getattr(task, "user", "")
)
if self.cfg.scheduler_config.splitwise_role == "prefill":
self.resource_manager.add_request_in_p(tasks)
self.llm_logger.info(
f"P add requests into running queue: {[task.request_id for task in tasks]}"
)
else:
for task in tasks:
self.resource_manager.add_request(task)
is_fetching = False
except Exception as e:
self.llm_logger.error(f"fetching request error {e} {str(traceback.format_exc())}")
is_fetching = False
while self.running:
with self._pause_cond:
self._pause_cond.wait_for(lambda: not self.is_paused)
try:
if not is_fetching:
# Check if the thread pool is still available to avoid submitting tasks to a shutdown thread pool.
try:
is_fetching = True
get_request_pool.submit(_fetch_request)
except RuntimeError as e:
if "shutdown" in str(e):
self.llm_logger.info("Thread pool shutdown detected, exiting scheduler loop")
break
else:
raise
if self.cfg.scheduler_config.splitwise_role != "mixed":
# Continue preprocessing incoming requests and accumulating them in the queue when forward pass not finished.
# Once the forward pass finishes, these accumulated requests can be scheduled in larger,
# more efficient batches.
if self.engine_worker_queue.exist_tasks() or self.engine_forward_signal.value[0] != 0:
time.sleep(0.001)
continue
else:
# In mixed, todo: optimze cache swap, to decouple swap from scheduler
if self.engine_worker_queue.exist_tasks():
time.sleep(0.001)
continue
if hasattr(self.resource_manager, "scheduler_unhandled_request_num"):
self.resource_manager.scheduler_unhandled_request_num = self._get_scheduler_unhandled_request_num()
# 2. Schedule requests
batch_request, error_tasks = self.resource_manager.schedule()
# 3. Send to engine
if len(batch_request) > 0:
if self.cfg.scheduler_config.splitwise_role == "decode":
for task in batch_request:
if task.task_type == RequestType.PREEMPTED:
msg = f"{task.request_id} decode not enough blocks, need to be rescheduled."
self.llm_logger.error(msg)
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=500,
error_msg=msg,
)
]
)
self.resource_manager.get_real_bsz()
for task in batch_request:
if task.task_type == RequestType.PREFILL:
rid = task.request_id.split("_")[0]
if isinstance(task, Request) and task.has_been_preempted_before:
trace_print(
LoggingEventName.RESCHEDULED_INFERENCE_START,
task.request_id,
getattr(task, "user", ""),
)
else:
trace_carrier = task.trace_carrier
tracing.trace_set_proc_propagate_context(rid, trace_carrier)
trace_carrier = tracing.trace_get_proc_propagate_context(rid)
task.trace_carrier = trace_carrier
tracing.trace_report_span(
tracing.TraceSpanName.SCHEDULE,
rid,
int(task.metrics.scheduler_recv_req_time * 1e9),
int(time.time() * 1e9),
thread_finish_flag=True,
)
trace_print(
LoggingEventName.RESOURCE_ALLOCATE_END, task.request_id, getattr(task, "user", "")
)
trace_print(
LoggingEventName.REQUEST_SCHEDULE_END, task.request_id, getattr(task, "user", "")
)
trace_print(
LoggingEventName.INFERENCE_START, task.request_id, getattr(task, "user", "")
)
if isinstance(task, Request):
if self.cfg.scheduler_config.splitwise_role == "decode":
task.metrics.decode_inference_start_time = time.time()
elif not task.has_been_preempted_before:
task.metrics.inference_start_time = time.time()
self.engine_worker_queue.put_tasks((batch_request, self.resource_manager.real_bsz))
else:
# When there are no actual tasks to schedule, send an empty task batch to EP workers.
# This helps EP workers barrier for syncing tasks not hang.
if self.cfg.parallel_config.enable_expert_parallel:
self.engine_worker_queue.put_tasks(
(batch_request, self.resource_manager.real_bsz)
) # Empty (as idle tasks for ep)
# 4. Response error tasks
if error_tasks:
for request_id, failed in error_tasks:
if failed is None:
self.llm_logger.warning(f"Request {request_id} has no error, skip sending error response.")
continue
self._send_error_response(request_id, failed)
if len(batch_request) <= 0 and not error_tasks:
time.sleep(0.005)
except RuntimeError as e:
raise e
except Exception as e:
err_msg = "Error happened while insert task to engine: {}, {}.".format(e, str(traceback.format_exc()))
self.llm_logger.error(err_msg)
def _get_scheduler_unhandled_request_num(self) -> int:
"""
Get scheduler-level pending request count when supported.
"""
get_unhandled = getattr(self.scheduler, "get_unhandled_request_num", None)
if not callable(get_unhandled):
return 0
try:
unhandled = int(get_unhandled())
except Exception as e:
self.llm_logger.debug(f"Failed to get scheduler unhandled request num: {e}")
return 0
return max(unhandled, 0)
def start_zmq_service(self, api_server_pid=None):
if api_server_pid is None:
return
self.api_server_pid = api_server_pid
if envs.FD_ENABLE_INTERNAL_ADAPTER:
self.recv_request_server = ZmqTcpServer(port=envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT, mode=zmq.PULL)
self.send_response_server = ZmqTcpServer(port=envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT, mode=zmq.ROUTER)
self.internal_adapter = InternalAdapter(
cfg=self.cfg, engine=self, dp_rank=self.cfg.parallel_config.local_data_parallel_id
)
# ROUTER mode: need to receive client handles
self.recv_result_handle_thread = threading.Thread(
target=self.send_response_server.recv_result_handle, daemon=True
)
self.recv_result_handle_thread.start()
else:
self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL)
if envs.ZMQ_SEND_BATCH_DATA:
# PUSH mode: batch send, no need to receive client handles
self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PUSH)
# Mapping from request_id to worker_pid for routing batch responses
self.request_worker_map = {}
else:
# ROUTER mode: per-query send, need to receive client handles
self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER)
self.recv_result_handle_thread = threading.Thread(
target=self.send_response_server.recv_result_handle, daemon=True
)
self.recv_result_handle_thread.start()
time.sleep(3)
self.insert_task_to_scheduler_thread = threading.Thread(target=self._insert_zmq_task_to_scheduler, daemon=True)
self.insert_task_to_scheduler_thread.start()
self.receive_output_thread = threading.Thread(target=self._zmq_send_generated_tokens, daemon=True)
self.receive_output_thread.start()
def _insert_zmq_task_to_scheduler(self):
tracing.trace_set_thread_info("Insert Task to Scheduler")
added_requests: Dict[str, int] = dict()
if envs.FD_ENABLE_INTERNAL_ADAPTER:
if self.cfg.scheduler_config.splitwise_role == "decode":
return
while self.running:
try:
block = True if len(added_requests) == 0 else False
if not self.cfg.enable_mm_runtime:
err, data = self.recv_request_server.receive_json_once(block)
else:
err, data = self.recv_request_server.receive_pyobj_once(block)
if err is not None:
# The message "Context was terminated" is normal when closing a ZMQ context
if "Context was terminated" in str(err):
self.llm_logger.info(
"Engine stops inserting zmq task into scheduler due to ZMQ context termination (normal shutdown)."
)
else:
self.llm_logger.error(f"Engine stops inserting zmq task into scheduler, err:{err}")
if envs.FD_ENABLE_INTERNAL_ADAPTER:
self.recv_request_server = ZmqTcpServer(
port=envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT, mode=zmq.PULL
)
else:
self.recv_request_server = ZmqIpcServer(name=self.api_server_pid, mode=zmq.PULL)
continue
# Extract zmq_worker_pid for per-worker PUSH routing.
# Only needed when ZMQ_SEND_BATCH_DATA=True AND not using internal adapter,
# because FD_ENABLE_INTERNAL_ADAPTER uses ROUTER (worker_pid is irrelevant).
worker_pid = None
if envs.ZMQ_SEND_BATCH_DATA and not envs.FD_ENABLE_INTERNAL_ADAPTER:
worker_pid = data["zmq_worker_pid"]
if ControlRequest.is_control_request(data):
try: # todo: run control request async, do not block request generation
if worker_pid is not None:
self.request_worker_map[data.get("request_id")] = worker_pid
control_req = ControlRequest.from_dict(data)
self.run_control_method(control_req)
except Exception as e:
self.llm_logger.error(
f"Failed to process control request {data.get('request_id')}: "
f"{e}, {traceback.format_exc()}"
)
continue
request, insert_task = data, []
results: List[Tuple[str, Optional[str]]] = list()
if data:
# Store worker_pid mapping for normal/abort requests
if worker_pid is not None:
req_id_for_map = data.get("request_id")
if req_id_for_map:
self.request_worker_map[req_id_for_map] = worker_pid
status_value = data.get("status", None)
if status_value is not None and status_value == RequestStatus.ABORT.value:
req_id = data["request_id"]
self.llm_logger.info(f"Receive abort request, req_id: {req_id}")
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.resource_manager.add_abort_req_ids(req_id)
continue
err_msg = None
try:
request = Request.from_dict(data)
request.metrics.scheduler_recv_req_time = time.time()
main_process_metrics.requests_number.inc()
trace_carrier = data.get("trace_carrier")
if trace_carrier:
request_id = data["request_id"].split("_")[0]
tracing.trace_set_proc_propagate_context(request_id, trace_carrier)
trace_print(LoggingEventName.PREPROCESSING_END, data["request_id"], data.get("user", ""))
trace_print(LoggingEventName.REQUEST_SCHEDULE_START, data["request_id"], data.get("user", ""))
trace_print(LoggingEventName.REQUEST_QUEUE_START, data["request_id"], data.get("user", ""))
self.llm_logger.debug(f"Receive request from api server: {request}")
if self.is_paused:
self.llm_logger.warning(f"Engine is paused, drop request: {request}")
self._send_error_response(
request.request_id,
"Request is aborted since LLM Engine is paused.",
worker_pid=worker_pid,
)
continue
except Exception as e:
self.llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}")
err_msg = str(e)
results.append((data["request_id"], err_msg))
if self.guided_decoding_checker is not None and err_msg is None:
request, err_msg = self.guided_decoding_checker.schema_format(request)
if err_msg is not None:
self.llm_logger.error(f"Receive request error: {err_msg}")
results.append((request.request_id, err_msg))
if err_msg is None:
insert_task.append(request)
response = self.scheduler.put_requests(insert_task)
results.extend(response)
if request:
if request.request_id not in added_requests:
added_requests[request.request_id] = 0
added_requests[request.request_id] += 1
for request_id, failed in results:
if request_id in added_requests:
added_requests[request_id] -= 1
if added_requests[request_id] == 0:
added_requests.pop(request_id)
if failed is None:
main_process_metrics.num_requests_waiting.inc(1)
continue
self._send_error_response(request_id, failed)
except Exception as e:
self.llm_logger.error(
f"Error happened while receiving new request from zmq, details={e}, "
f"traceback={traceback.format_exc()}"
)
def run_control_method(self, control_req: ControlRequest):
"""
Execute control method, process control request and return response.
This method is responsible for handling control requests, calling the corresponding
handler function based on the method name in the request. If the method doesn't exist
or is not callable, it returns an error response; otherwise executes the method and
returns a success response.
Args:
control_req (ControlRequest): Control request object containing request ID,
method name and parameters.
Returns:
None: No return value, sends ControlResponse through send_response_server.
"""
method = control_req.get_method()
request_id = control_req.request_id
# Look up worker_pid for routing control response
worker_pid = None
if envs.ZMQ_SEND_BATCH_DATA and hasattr(self, "request_worker_map"):
worker_pid = self.request_worker_map.pop(request_id, None)
try:
self.llm_logger.info(f"Start to run control method {method}: {request_id}")
handler_name = f"_control_{method}"
handler = getattr(self, handler_name, None)
if handler is None or not callable(handler):
error_result = ControlResponse(request_id, 400, f"unknown control method:{method}")
self.llm_logger.error(str(error_result))
data = [[error_result]] if envs.ZMQ_SEND_BATCH_DATA else [error_result]
self.send_response_server.send_response(request_id, data, worker_pid=worker_pid)
return
result = handler(control_req)
self.llm_logger.info(f"Successfully run control method {method}: {request_id} {result}")
succ_result = ControlResponse(request_id, 200, "Success", result)
data = [[succ_result]] if envs.ZMQ_SEND_BATCH_DATA else [succ_result]
self.send_response_server.send_response(request_id, data, worker_pid=worker_pid)
except Exception as e:
error_msg = f"Failed to run control method {method}: {request_id} {str(e)}"
self.llm_logger.error(f"{error_msg}\n{traceback.format_exc()}")
error_result = ControlResponse(request_id, 500, error_msg)
data = [[error_result]] if envs.ZMQ_SEND_BATCH_DATA else [error_result]
self.send_response_server.send_response(request_id, data, worker_pid=worker_pid)
def _control_pause(self, control_request: ControlRequest):
"""Pauses the LLM engine and aborts all running/inflight requests.
Args:
control_request: The control request containing pause command
Raises:
Exception: If pause is not supported in current configuration
Exception: If engine worker queue cleanup times out
Returns:
None
"""
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
raise Exception("pause only supported in ENABLE_V1_KVCACHE_SCHEDULER")
if self.cfg.scheduler_config.name != "local":
raise Exception(f"pause only supported in local scheduler, current {self.cfg.scheduler_config.name}")
self.llm_logger.info("Start to pause request generation.")
with self._pause_cond:
if self.is_paused:
self.llm_logger.info("Engine is already paused, no need to pause again.")
return
self.is_paused = True
self.llm_logger.info("Abort running requests.")
self.resource_manager.log_status()
# preempted all running reqs. preempted reqs will be append to ResourceManager.waiting queue
timeout, count = 60, 0
while self.engine_worker_queue.exist_tasks():
time.sleep(0.001)
count += 1
if count >= timeout * 1000:
break
if count >= timeout * 1000:
error_msg = f"Emptying engine worker queue timed out after {timeout} seconds, worker may hanged!"
self.llm_logger.error(error_msg)
raise Exception(error_msg)
running_reqs = self.resource_manager.preempted_all()
if len(running_reqs) > 0:
self.llm_logger.info(f"Total {len(running_reqs)} requests need to be aborted.")
self.resource_manager.get_real_bsz()
self.engine_worker_queue.put_tasks((running_reqs, self.resource_manager.real_bsz))
self.resource_manager.wait_worker_inflight_requests_finish(timeout=60)
# self.engine_worker_queue.clear_data()
self.token_processor.clear_data()
self.resource_manager.log_status()
# abort inflight requests to user
inflight_requests = self.scheduler.get_inflight_requests()
self.llm_logger.info(f"Abort inflight requests (total {len(inflight_requests)}).")
for req in inflight_requests:
self._send_error_response(req.request_id, "Request is aborted since engine is paused.")
self.scheduler.reset()
if envs.ENABLE_V1_KVCACHE_MANAGER:
self.resource_manager.cache_manager.reset_cache()
else:
# pause cache transfer
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
self.llm_logger.info("Start to pause cache transfer.")
pause_transfer_request = ControlRequest(
request_id=f"{control_request.request_id}_pause_transfer", method="pause"
)
self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, pause_transfer_request))
# Wait for cache_transfer responses
asyncio.run(
self._wait_for_control_responses(
f"{pause_transfer_request.request_id}", 60, executors=["cache_transfer"]
)
)
self.llm_logger.info("Successfully paused cache transfer.")
self.resource_manager.cache_manager.reset()
self.llm_logger.info("Successfully paused request generation.")
return None
def _control_resume(self, control_request: ControlRequest) -> Optional[dict]:
"""Control function for resuming request generation.
This method resumes the paused request generation process by setting the pause flag
and notifying all waiting threads. It logs the start and end of the resume operation.
Args:
control_request: Control request object containing resume operation information
"""
self.llm_logger.info("Start to resume request generation.")
with self._pause_cond:
if not self.is_paused:
self.llm_logger.info("Engine is not paused, no need to resume.")
return None
self.is_paused = False
self._pause_cond.notify_all()
# resume cache transfer
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
self.llm_logger.info("Start to resume cache transfer.")
resume_transfer_request = ControlRequest(
request_id=f"{control_request.request_id}_resume_transfer", method="resume"
)
self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, resume_transfer_request))
# Wait for cache_transfer responses
asyncio.run(
self._wait_for_control_responses(resume_transfer_request.request_id, 60, executors=["cache_transfer"])
)
self.llm_logger.info("Successfully resumed cache transfer.")
self.llm_logger.info("Successfully resumed request generation.")
return None
def _control_is_paused(self, control_request: ControlRequest) -> bool:
"""
Check if the LLM engine is in paused state.
Args:
control_request: Control request object.
Returns:
dict: Dictionary containing pause status information, {'is_paused': bool}
"""
self.llm_logger.info(f"LLM Engine request generation is paused: {self.is_paused}")
with self._pause_cond:
return {"is_paused": self.is_paused}
def _get_is_paused_safe(self) -> bool:
"""Thread-safe getter for is_paused state, used by RegisterManager."""
with self._pause_cond:
return self.is_paused
def _control_update_weights(self, control_request: ControlRequest) -> Optional[dict]:
"""Update model weights
Args:
control_request: Control request object containing parameters for weight updates
Returns:
Optional[dict]: Returns the result dictionary if update succeeds, None otherwise
Raises:
Exception: Raised when the engine is not in paused state
"""
self.llm_logger.info("Update Model Weights")
with self._pause_cond:
if self.is_paused is False:
error_msg = "Pause LLM Engine first before calling updating weights"
self.llm_logger.error(error_msg)
raise Exception(error_msg)
responses = self._call_worker(control_request, 60)
if responses:
new_version = None
for resp in responses:
# Expect each worker response to be a dict-like object
if isinstance(resp, dict) and "version" in resp:
new_version = resp.get("version")
self.llm_logger.info(f"Update Weights Version in Config: {new_version}")
break
if new_version is not None:
self.cfg.model_config.version = new_version
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
self.llm_logger.info("Start to update cache-transfer metadata after weight update.")
update_cache_request = ControlRequest(
request_id=f"{control_request.request_id}_update_weights",
method="update_weights",
args=copy.deepcopy(control_request.args),
)
self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, update_cache_request))
asyncio.run(
self._wait_for_control_responses(update_cache_request.request_id, 60, executors=["cache_transfer"])
)
self.llm_logger.info("Successfully updated cache-transfer metadata after weight update.")
return responses
def _control_abort_requests(self, control_req: ControlRequest):
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
raise Exception("abort_requests only supported in ENABLE_V1_KVCACHE_SCHEDULER")
args = control_req.get_args()
abort_all = args.get("abort_all", False)
req_ids = args.get("req_ids", [])
matched_input_ids = set()
now_reqs = list(set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys()))
# Step 1: Determine target request list
if abort_all:
# all requests in running + waiting
target_req_ids = now_reqs
else:
# filter out requests that actually exist
target_req_ids = []
for rid in req_ids:
if rid in now_reqs:
target_req_ids.append(rid)
matched_input_ids.add(rid)
elif f"{rid}_0" in now_reqs:
target_req_ids.append(f"{rid}_0")
matched_input_ids.add(rid)
if not target_req_ids:
return {"aborted": [], "not_found": req_ids if not abort_all else []}
# Step 2: Collect partial results
aborted_info = []
results = []
for req_id in target_req_ids:
request = self.resource_manager.requests.get(req_id)
if request is None:
scheduled_req = self.scheduler.requests.get(req_id)
if scheduled_req is None:
continue
request = scheduled_req.raw
partial_token_ids = list(request.output_token_ids)
# Construct finished response with partial results
now = time.time()
abort_metrics = RequestMetrics(
arrival_time=request.metrics.arrival_time if request.metrics else now,
inference_start_time=request.metrics.inference_start_time if request.metrics else now,
engine_recv_latest_token_time=now,
engine_recv_first_token_time=request.metrics.engine_recv_first_token_time if request.metrics else now,
request_start_time=request.metrics.arrival_time if request.metrics else now,
)
result = RequestOutput(
request_id=req_id,
finished=True,
outputs=CompletionOutput(
index=0,
send_idx=len(partial_token_ids),
token_ids=[self.data_processor.eos_token_ids[0]],
),
metrics=abort_metrics,
error_code=200,
error_msg="Aborted",
)
results.append(result)
aborted_info.append(
{
"request_id": req_id,
"output_token_count": len(partial_token_ids),
}
)
# Step 3: Execute abort — add all requests to waiting_abort_req_id_set
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
for req_id in target_req_ids:
self.resource_manager.add_abort_req_ids(req_id)
time.sleep(0.0001)
if self.cfg.scheduler_config.splitwise_role != "prefill":
self._wait_abort_complete(target_req_ids)
# Add results to scheduler, engine will have a thread calling get_results,
# then cleanup and call send_response to send to client.
# When client disconnects, send_response will automatically ignore
if self.cfg.scheduler_config.splitwise_role != "prefill":
try:
# self.send_response_server.send_response(req_id, [result])
self.scheduler.put_results(results)
except Exception:
pass # client may have disconnected
not_found = [rid for rid in req_ids if rid not in matched_input_ids] if not abort_all else []
return {"aborted": aborted_info, "not_found": not_found}
def _wait_abort_complete(self, target_req_ids, stall_timeout=1):
"""
Wait for all abort requests to complete.
- Keep monitoring as long as remaining is not empty, which means cleanup is not done yet
- If no progress within stall_timeout seconds, force cleanup requests stuck in to_be_aborted_req_id_set,
reset progress state if any, then continue monitoring
"""
target_set = set(target_req_ids)
prev_remaining_count = len(target_set)
last_progress_time = time.time()
remaining = target_set & self.resource_manager.get_reqs_in_aborting()
while remaining:
remaining = target_set & self.resource_manager.get_reqs_in_aborting()
if not remaining:
self.llm_logger.info(f"all {len(target_set)} abort reqs cleaned")
return
current_count = len(remaining)
if current_count < prev_remaining_count:
# progress made: recycle_abort_task was called
self.llm_logger.info(f"abort progress: {prev_remaining_count} -> {current_count}")
last_progress_time = time.time()
prev_remaining_count = current_count
if time.time() - last_progress_time > stall_timeout:
# no progress timeout: only cleanup requests stuck in to_be_aborted (worker hasn't returned -9)
stuck = remaining & self.resource_manager.to_be_aborted_req_id_set
if stuck:
self.llm_logger.warning(
f"no abort progress for {stall_timeout}s, "
f"force cleanup {len(stuck)} stuck requests (in to_be_aborted)"
)
for req_id in list(stuck):
self.llm_logger.warning(f"force cleanup stuck req_id:{req_id}")
self.resource_manager.recycle_abort_task(req_id)
# reset progress state
last_progress_time = time.time()
prev_remaining_count = current_count - len(stuck)
# else: remaining are all in waiting_abort_req_id_set, waiting for natural flow
time.sleep(0.005)
def _parse_tags(self, control_request: ControlRequest):
"""
Parse tags from control request.
"""
allowed_tags = ["weight", "kv_cache"]
tags = control_request.args.get("tags", None)
if tags is None:
tags = ",".join(allowed_tags)
control_request.args["tags"] = tags
self.llm_logger.info(
f"Detected empty tags of request {control_request.request_id}, defaulting to tags: {tags}"
)
elif isinstance(tags, list):
tags = ",".join(tags)
for tag in tags.split(","):
if tag not in allowed_tags:
raise ValueError(f"Unsupported tag [{tag}] in [{tags}], expected one of {allowed_tags}")
return tags
def _control_sleep(self, control_request: ControlRequest):
"""
Offload gpu memory occupation for certain parts, e.g. weight, cache.
Args:
control_request: Control request object containing parameters for offloading memory
tags: list of tags to offload, supported values: ["weight", "cache"]
TODO: support different level of offloading, to provide options for release memory forever
or merely offloading to cpu memory for now.
"""
# Args check
tags = self._parse_tags(control_request)
control_request.args["tags"] = tags
# Make sure llm engine is paused.
self.llm_logger.warning(
"Implicitly pause LLM engine before sleeping. This behavior will be deprecated in future versions. "
"Please explicitly request to /pause the engine before /sleep."
)
self._control_pause(None)
# Determine which executors are needed for the sleep command
executors = set()
if "weight" in tags:
executors.add("worker")
if "kv_cache" in tags:
executors.add("worker")
if envs.ENABLE_V1_KVCACHE_MANAGER:
if self.cfg.cache_config.enable_prefix_caching:
self.resource_manager.cache_manager.reset_cache()
else:
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
executors.add("cache_transfer")
if self.cfg.cache_config.enable_prefix_caching:
self.resource_manager.cache_manager.reset()
# Dispatch sleep request to executors
self.llm_logger.info(f"Dispatch sleep request to executors: {list(executors)}")
self._dispatch_control_request(control_request, executors)
return asyncio.run(self._wait_for_control_responses(control_request.request_id, 60, executors=executors))
def _control_wakeup(self, control_request: ControlRequest):
"""
Reload offloaded gpu memory occupation for certain parts, e.g. weight, cache.
Args:
control_request: Control request object containing parameters for reloading memory
tags: list of tags to reload, supported values: ["weight", "kv_cache"]
"""
# Args check
tags = self._parse_tags(control_request)
control_request.args["tags"] = tags
# Determine which executors are needed for the wakeup command
executors = set()
if "weight" in tags:
executors.add("worker")
if "kv_cache" in tags:
executors.add("worker")
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
executors.add("cache_transfer")
# Dispatch wakeup request to executors
self.llm_logger.info(f"Dispatch wakeup request to executors: {list(executors)}")
self._dispatch_control_request(control_request, executors)
result = asyncio.run(self._wait_for_control_responses(control_request.request_id, 300, executors=executors))
# Resume the engine after wakeup
self._control_resume(None)
return result
def _dispatch_control_request(self, control_request: ControlRequest, executors: List[str]):
"""
Dispatch control requests to workers, cache managers or engine itself.
Args:
control_request: ControlRequest
executors: List
"""
if "worker" in executors:
self.engine_worker_queue.put_tasks(([control_request], 1))
if "cache_transfer" in executors:
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, control_request))
return
async def _wait_for_control_responses(self, request_id: str, timeout: int, executors: List[str] = None):
"""Wait for matching control responses from the selected executor queues.
This helper selects the control-response queues that belong to the requested
executors, then waits for all of them concurrently. Each queue gets a local
waiter that keeps reading until it sees the target request ID and stashes stale
responses into that queue's mailbox.
Args:
request_id: The control request ID that all returned responses must match.
timeout: Global timeout budget in seconds for the full multi-queue wait.
executors: Executor groups to wait for, for example `["worker"]` or
`["worker", "cache_transfer"]`. If `None`, waits for all control
response queues.
Returns:
A list of `response.result` values collected from all matched
`ControlResponse` objects. If no queue is selected, returns `None`.
Raises:
Exception: If the overall wait times out, or if any queue reports a non-200
control response or fails while waiting.
"""
def select_control_queues(executors: List[str] = None):
"""Select control response queues by executors."""
if executors is None:
return self._ctrl_output_queues
else:
queues = {}
for k, v in self._ctrl_output_queues.items():
if "w2e" in k and "worker" in executors:
queues[k] = v
elif "c2e" in k and "cache_transfer" in executors:
queues[k] = v
return queues
async def wait_one(queue_name: str, queue):
"""Wait until one queue returns a response for the current request_id."""
mailbox = self._ctrl_response_mailboxes[queue_name]
# Reuse a previously stashed response for this request before touching FMQ again.
cached_response = mailbox.pop(request_id, None)
if cached_response is not None:
self.llm_logger.info(f"Returning cached control response from {queue_name}.")
return cached_response
while True:
msg = await queue.get()
# Return if the response matches the control request
response: ControlResponse = msg.payload
if response.request_id == request_id:
self.llm_logger.info(f"Returning new control response from {queue_name}.")
return response
# Stash late responses from other control requests so they do not consume the
# current request's only read chance on this queue.
mailbox[response.request_id] = response
self.llm_logger.info(
f"Stashed old control response from {queue_name}. "
f"Expected request {request_id}, got request {response.request_id}"
)
# Select only the control response queues that belong to the requested executors.
queues = select_control_queues(executors)
if not queues:
self.llm_logger.info(f"No queues to wait for, executors: {executors}")
return
self.llm_logger.info(f"Waiting for control responses from {len(queues)} queues: {list(queues.keys())}")
# Each queue gets its own waiter, which will stash stale responses until it finds the
# target request ID for this control request.
tasks = {name: asyncio.create_task(wait_one(name, queue)) for name, queue in queues.items()}
done, pending = await asyncio.wait(tasks.values(), timeout=timeout)
if pending:
pending_names = [name for name, task in tasks.items() if task in pending]
done_names = [name for name, task in tasks.items() if task in done]
self.llm_logger.error(
f"Control request {request_id} execution timeout. "
f"Pending queues: {pending_names}, completed queues: {done_names}."
)
# Stop unfinished queue waiters so they do not outlive the control request.
for task in pending:
task.cancel()
await asyncio.gather(*pending, return_exceptions=True)
raise Exception(f"Control request {request_id} timed out after {timeout}s")
# Collect the results from all completed queues.
responses = []
for name, task in tasks.items():
try:
response = task.result()
except Exception as e:
self.llm_logger.error(f"Waiting for control response from {name} failed: {repr(e)}")
raise
if response.error_code != 200:
raise Exception(f"Error response from {name}: {response.error_message}")
responses.append(response.result)
return responses
def _call_worker(self, control_request: ControlRequest, timeout: int):
request_id = control_request.request_id
self.engine_worker_queue.put_tasks(([control_request], 1))
# Use a single asyncio.run() to concurrently wait for all worker responses.
return asyncio.run(self._wait_for_control_responses(request_id, timeout, executors=["worker"]))
def _send_error_response(self, request_id, error_msg, error_code: int = 500, worker_pid=None):
self.llm_logger.error(
f"Send error response to client, request_id: {request_id}, error_msg: {error_msg}, error_code: {error_code}"
)
error_result = RequestOutput(
request_id=request_id,
finished=True,
error_code=error_code,
error_msg=error_msg,
)
# Look up worker_pid from mapping if not provided
if worker_pid is None and envs.ZMQ_SEND_BATCH_DATA and hasattr(self, "request_worker_map"):
worker_pid = self.request_worker_map.pop(request_id, None)
# Since the request is not in scheduler
# Send result by zmq directly
if envs.FD_ENABLE_INTERNAL_ADAPTER:
self.send_response_server.send_response(None, [[error_result]])
elif envs.ZMQ_SEND_BATCH_DATA:
self.send_response_server.send_response(None, [[error_result]], worker_pid=worker_pid)
else:
self.send_response_server.send_response(request_id, [error_result])
def _decode_token(self, token_ids, req_id, is_end):
delta_text = ""
if envs.FD_ENABLE_RETURN_TEXT:
delta_text, cum_tokens, _ = self.data_processor.ids2tokens(token_ids, req_id)
if delta_text != "":
prefix_offset = self.data_processor.decode_status[req_id][0]
read_offset = self.data_processor.decode_status[req_id][1]
token_ids = cum_tokens[prefix_offset:read_offset]
else:
token_ids = []
if is_end and delta_text == "" and len(cum_tokens) > 0:
read_offset = self.data_processor.decode_status[req_id][1]
token_ids = cum_tokens[read_offset:]
if is_end:
del self.data_processor.decode_status[req_id]
return delta_text, token_ids
def _zmq_send_generated_tokens(self):
"""
Receive output for zmq
"""
while self.running:
try:
results = self.scheduler.get_results()
if len(results) == 0:
time.sleep(0.005)
continue
if envs.FD_ENABLE_INTERNAL_ADAPTER:
new_contents = []
for step_batch_results in results:
new_step_contents = []
for content in step_batch_results:
if isinstance(content, RequestOutput) and content.outputs is not None:
decode_type = content.outputs.decode_type
delta_text = ""
if decode_type == 0:
delta_text, token_ids = self._decode_token(
token_ids=content.outputs.token_ids,
req_id=content.request_id,
is_end=content.finished,
)
else:
token_ids = content.outputs.token_ids
if len(token_ids):
content.outputs.token_ids = token_ids
content.outputs.text = delta_text
new_step_contents.append(content)
elif content.finished:
new_step_contents.append(content)
else:
self.llm_logger.warning(
f"current tokens need to accumulate, req_id: {content.request_id} {content.outputs.token_ids}"
)
else:
new_step_contents.append(content)
if new_step_contents:
new_contents.append(new_step_contents)
if new_contents:
self.send_response_server.send_response(None, new_contents)
else:
worker_batches = collections.defaultdict(list)
for request_id, contents in results.items():
new_contents = []
for content in contents:
if isinstance(content, RequestOutput) and content.outputs is not None:
decode_type = content.outputs.decode_type
delta_text = ""
if decode_type == 0:
delta_text, token_ids = self._decode_token(
token_ids=content.outputs.token_ids,
req_id=request_id,
is_end=content.finished,
)
else:
token_ids = content.outputs.token_ids
if len(token_ids):
content.outputs.token_ids = token_ids
content.outputs.text = delta_text
new_contents.append(content)
elif content.finished:
new_contents.append(content)
else:
self.llm_logger.warning(
f"current tokens need to accumulate, req_id: {request_id} {content.outputs.token_ids}"
)
else:
new_contents.append(content)
if new_contents:
if envs.ZMQ_SEND_BATCH_DATA:
wpid = self.request_worker_map.get(request_id)
worker_batches[wpid].append(new_contents)
is_finished = any(getattr(c, "finished", False) for c in new_contents)
if is_finished:
self.request_worker_map.pop(request_id, None)
else:
self.send_response_server.send_response(request_id, new_contents)
if envs.ZMQ_SEND_BATCH_DATA:
for wpid, batch_data in worker_batches.items():
if batch_data:
self.send_response_server.send_response(None, batch_data, worker_pid=wpid)
except Exception as e:
self.llm_logger.error(f"Unexpected error happend: {e}, {traceback.format_exc()!s}")
def _decode_process_splitwise_requests(self):
"""
Decode processes requests from engine worker queue, which are sent by prefill.
TODO: merge this function to the schedule function in resource manager
"""
allocate_resource_requests: list[Request] = []
prefilled_request_ouputs: list[RequestOutput] = []
def _fetch_requests():
if self.engine_worker_queue.disaggregate_queue_empty():
return
items = self.engine_worker_queue.get_disaggregated_tasks()
for item in items:
tasks = item[1]
if isinstance(tasks[0], Request):
self.llm_logger.debug(
f"D has received tasks to preallocate resource for tasks: {[task.request_id for task in tasks]}"
)
for task in tasks:
task.metrics.decode_recv_req_time = time.time()
allocate_resource_requests.extend(tasks)
elif isinstance(tasks[0], RequestOutput):
self.llm_logger.debug(
f"D has received tasks to process prefilled tasks: {[task.request_id for task in tasks]}"
)
if not isinstance(tasks, list):
tasks = [tasks]
for task in tasks:
task.finished = False
task.metrics.decode_recv_first_token_time = time.time()
prefilled_request_ouputs.extend(tasks)
def _process_allocate_resource_requests():
processed_indices = []
for idx, task in enumerate(allocate_resource_requests):
is_success = False
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if self.resource_manager.preallocate_resource_in_d(task):
task.metrics.decode_preallocate_req_time = time.time()
self.llm_logger.info(f"Resource available, processing task {task.request_id}")
self.split_connector.send_cache_info_to_prefill([task])
self.llm_logger.debug(f"D has successfully sent cache infos for task {task.request_id}")
processed_indices.append(idx)
is_success = True
else:
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
self.llm_logger.debug(f"D Resource available, processing task {task.request_id}")
self.insert_tasks([task])
task.metrics.decode_preallocate_req_time = time.time()
processed_indices.append(idx)
is_success = True
if not is_success:
if not self.enable_decode_cache_task:
task.error_msg = "Not enough resources"
self.split_connector.send_cache_info_to_prefill([task])
self.llm_logger.warning(f"D has failed to send cache infos for task {task.request_id}")
processed_indices.append(idx)
else:
self.llm_logger.debug(f"Still waiting for resources {task.request_id}")
break
for idx in sorted(processed_indices, reverse=True):
allocate_resource_requests.pop(idx)
def _process_prefilled_requests():
nonlocal prefilled_request_ouputs
ready_request_outputs = []
waiting_request_outputs = []
for req_output in prefilled_request_ouputs:
if hasattr(self.scheduler, "has_request") and not self.scheduler.has_request(req_output.request_id):
# ensure the api_server and scheduler in decode have
# received the request sent by the client
waiting_request_outputs.append(req_output)
continue
req_output.finished = False
ready_request_outputs.append(req_output)
self.llm_logger.debug(f"there are enough resource for prefilled request: {req_output.request_id}")
prefilled_request_ouputs = waiting_request_outputs
if self.cfg.splitwise_version == "v1":
# decode return first token to client
self.scheduler.put_results(ready_request_outputs)
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
self._insert_prefilled_requests(ready_request_outputs)
else:
for req_output in ready_request_outputs:
request_id = req_output.request_id
if envs.FD_ENABLE_INTERNAL_ADAPTER and not req_output.outputs.token_ids:
# first token is eos in Prefill, just recycle resource and continue
self.llm_logger.warning(f"{request_id} need not decode after first token")
self.resource_manager.pre_recycle_resource(request_id)
if request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[request_id]
req_output.finished = True
self.scheduler.put_results([req_output])
continue
if req_output.error_code != 200:
self.llm_logger.warning(
f"{request_id} prefill failed with msg:{req_output.error_msg}, recycle resource."
)
self.resource_manager.pre_recycle_resource(request_id)
if request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[request_id]
self.scheduler.put_results([req_output])
continue
self.token_processor.tokens_counter[request_id] = 1
if envs.FD_ENABLE_INTERNAL_ADAPTER: # first token sent by D instance
self.scheduler.put_results([req_output])
self.resource_manager.add_prefilled_request(req_output)
self.llm_logger.info(f"D has successfully added prefilled request, {request_id}")
def decode_loop():
while self.running:
try:
_fetch_requests()
_process_allocate_resource_requests()
_process_prefilled_requests()
time.sleep(0.001)
except Exception as e:
self.llm_logger.error(
f"Error in main loop of decode_process_splitwise_requests: " f"{e}, {traceback.format_exc()}"
)
time.sleep(0.01)
threading.Thread(target=decode_loop, daemon=True).start()
def start_cache_service(self, device_ids, ipc_signal_suffix):
console_logger.debug("Start cache manager...")
return self.resource_manager.cache_manager.launch_cache_manager(
cache_config=self.cfg.cache_config,
tensor_parallel_size=self.cfg.parallel_config.tensor_parallel_size,
device_ids=device_ids,
pod_ip=self.cfg.master_ip,
engine_worker_queue_port=self.cfg.parallel_config.local_engine_worker_queue_port,
ipc_suffix=ipc_signal_suffix,
create_cache_tensor=False,
)
def check_and_free_block_tables(self):
self.resource_manager.check_and_free_block_tables()
def clear_data(self):
try:
self.llm_logger.info("Clear Data: Start")
self.token_processor.clear_data()
self.engine_worker_queue.clear_data()
if hasattr(self, "cache_task_queue"):
self.cache_task_queue.clear_transfer_task()
self.send_response_server.req_dict.clear()
self.recv_request_server.req_dict.clear()
# Clean up worker_pid mapping (batch mode)
if envs.ZMQ_SEND_BATCH_DATA and hasattr(self, "request_worker_map"):
self.request_worker_map.clear()
self.llm_logger.info("Clear Data: Successfully")
return True
except Exception as e:
self.llm_logger.error(f"Clear data error: {e}")
return False
def _exit_sub_services(self):
"""
exit sub services
"""
self.llm_logger.info("Exit sub services.....")
self.running = False
if self.use_async_llm:
# Clean up worker processes first (before closing multiprocessing services)
if hasattr(self, "worker_proc") and self.worker_proc is not None:
self.llm_logger.info("Cleaning up worker processes...")
try:
pgid = os.getpgid(self.worker_proc.pid)
os.killpg(pgid, signal.SIGTERM)
except Exception as e:
self.llm_logger.error(f"Error extracting sub services: {e}, {str(traceback.format_exc())}")
# Clean up cache manager processes
if hasattr(self, "cache_manager_processes"):
self.llm_logger.info("Cleaning up cache manager processes...")
self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
self.resource_manager.cache_manager.cache_ready_signal.clear()
for p in self.cache_manager_processes:
self.llm_logger.info(f"Killing cache manager process {p.pid}")
try:
pgid = os.getpgid(p.pid)
os.killpg(pgid, signal.SIGTERM)
except Exception as e:
self.llm_logger.error(
f"Error killing cache manager process {p.pid}: {e}, {str(traceback.format_exc())}"
)
if hasattr(self, "cache_task_queue") and self.cache_task_queue is not None:
self.llm_logger.info("Cleaning up cache_task_queue...")
# Check if cleanup method exists
if hasattr(self.cache_task_queue, "cleanup"):
self.cache_task_queue.cleanup()
elif hasattr(self.cache_task_queue, "manager"):
try:
self.llm_logger.info("Shutting down cache_task_queue manager...")
self.cache_task_queue.manager.shutdown()
except Exception as e:
self.llm_logger.warning(f"Error shutting down cache_task_queue manager: {e}")
if hasattr(self, "get_profile_block_num_signal"):
self.get_profile_block_num_signal.clear()
self.worker_ready_signal.clear()
self.loaded_model_signal.clear()
# Clean up other services
if hasattr(self, "dp_processed"):
for p in self.dp_processed:
self.llm_logger.info(f"Waiting for worker {p.pid} to exit")
p.join()
for p in self.dp_engine_worker_queue_server:
p.cleanup()
if hasattr(self, "engine_worker_queue_server") and self.engine_worker_queue_server is not None:
self.engine_worker_queue_server.cleanup()
self.exist_task_signal.clear()
self.exist_swapped_task_signal.clear()
self.worker_healthy_live_signal.clear()
self.cache_ready_signal.clear()
self.swap_space_ready_signal.clear()
self.cache_transfer_inited_signal.clear()
self.exist_prefill_task_signal.clear()
self.model_weights_status_signal.clear()
self.prefix_tree_status_signal.clear()
self.kv_cache_status_signal.clear()
if hasattr(self, "send_response_server") and self.send_response_server is not None:
self.send_response_server.close()
if hasattr(self, "recv_request_server") and self.recv_request_server is not None:
self.recv_request_server.close()
if hasattr(self, "recv_control_cmd_server") and self.recv_control_cmd_server is not None:
self.recv_control_cmd_server.close()
# 从 async_llm 移到 common_engine
def _worker_processes_ready(self):
"""
judge if all worker processes are ready
"""
if np.sum(self.worker_ready_signal.value) == self.cfg.worker_num_per_node:
return True
return False
def _init_worker_signals(self):
"""
Initialize shared memory to indicate engine status
"""
# worker_ready_signal 用于worker进程感知engine是否启动完成
worker_ready_signal_data = np.zeros(shape=[self.cfg.worker_num_per_node], dtype=np.int32)
self.worker_ready_signal = IPCSignal(
name="worker_ready_signal",
array=worker_ready_signal_data,
dtype=np.int32,
suffix=self.ipc_signal_suffix,
create=True,
)
# launched_cache_manager_signal 用于感知engine是否启动了cache_manager
if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
self.launched_cache_manager_signal = IPCSignal(
name="launched_cache_manager_signal",
array=launched_cache_manager_signal_data,
dtype=np.int32,
suffix=self.ipc_signal_suffix,
create=True,
)
# launched_expert_service_signal: Used to sense whether each expert_service is started successfully
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
launched_expert_service_signal_data = np.zeros(
shape=[self.cfg.parallel_config.data_parallel_size // self.cfg.nnode], dtype=np.int32
)
self.launched_expert_service_signal = IPCSignal(
name="launched_expert_service_signal",
array=launched_expert_service_signal_data,
dtype=np.int32,
suffix=self.ipc_signal_suffix,
create=True,
)
# loaded_model_signal: Used to detect whether each worker has completed model loading
loaded_model_signal_data = np.zeros([1], dtype=np.int32)
self.loaded_model_signal = IPCSignal(
name="loaded_model_signal",
array=loaded_model_signal_data,
dtype=np.int32,
suffix=self.ipc_signal_suffix,
create=True,
)
if self.do_profile:
if paddle.is_compiled_with_custom_device("iluvatar_gpu"):
get_profile_block_num = np.zeros([self.cfg.worker_num_per_node], dtype=np.int32)
else:
get_profile_block_num = np.zeros([1], dtype=np.int32)
self.get_profile_block_num_signal = IPCSignal(
name="get_profile_block_num",
array=get_profile_block_num,
dtype=np.int32,
suffix=self.ipc_signal_suffix,
create=True,
)
def _setting_environ_variables(self):
"""
配置环境变量
"""
variables = {
"ENABLE_FASTDEPLOY_LOAD_MODEL_CONCURRENCY": 0,
"LOAD_STATE_DICT_THREAD_NUM": len(self.cfg.parallel_config.device_ids.split(",")),
"PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python",
"FLAGS_use_append_attn": 1,
"NCCL_ALGO": "Ring",
"FLAGS_max_partition_size": int(os.getenv("FLAGS_max_partition_size", 1024)),
"OMP_NUM_THREADS": 3,
}
# environment variables needed by Dy2St
variables.update(
{
"SOT_LOG_LEVEL": os.getenv("SOT_LOG_LEVEL", default="0"),
"SOT_UNSAFE_CACHE_FASTPATH": os.getenv("SOT_UNSAFE_CACHE_FASTPATH", default="1"),
"SOT_ENABLE_0_SIZE_FALLBACK": os.getenv("SOT_ENABLE_0_SIZE_FALLBACK", default="0"),
"SOT_SPECIALIZED_DIM_NUMBERS": os.getenv("SOT_SPECIALIZED_DIM_NUMBERS", default="no"),
"SOT_ENABLE_COMPILE_TIME_LIMIT": os.getenv("SOT_ENABLE_COMPILE_TIME_LIMIT", default="0"),
"FLAGS_specialize_device_in_dy2st": os.getenv("FLAGS_specialize_device_in_dy2st", default="1"),
"FLAGS_enable_async_fast_gc": os.getenv("FLAGS_enable_async_fast_gc", default="0"),
"FLAGS_pir_interpreter_record_stream_for_gc_cache": os.getenv(
"FLAGS_pir_interpreter_record_stream_for_gc_cache", default="1"
),
"FLAGS_parameters_persistent_mode_in_dy2st": os.getenv(
"FLAGS_parameters_persistent_mode_in_dy2st", default="1"
),
}
)
if self.cfg.scheduler_config.splitwise_role != "mixed":
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
variables["FLAGS_use_pd_disaggregation_per_chunk"] = 1
else:
variables["FLAGS_use_pd_disaggregation"] = 1
# TODO dynamic load environment variable
if self.cfg.scheduler_config.splitwise_role == "prefill":
variables["FLAGS_fmt_write_cache_completed_signal"] = 1
if self.cfg.enable_mm_runtime:
variables["FLAGS_max_partition_size"] = 1024
command_prefix = ""
for k, v in variables.items():
command_prefix += f"{k}={v} "
return command_prefix
def _start_worker_service(self):
"""
start gpu worker service
"""
log_dir = os.getenv("FD_LOG_DIR", default="log")
command_prefix = self._setting_environ_variables()
current_file_path = os.path.abspath(__file__)
current_dir_path = os.path.split(current_file_path)[0]
# TODO
uncache_worker_stdout = "" if os.getenv("UNCACHE_WORKER_STDOUT", "0") == "1" else "-u"
pd_cmd = f"{command_prefix} {sys.executable} {uncache_worker_stdout} -m paddle.distributed.launch"
pd_cmd = pd_cmd + f" --log_dir {log_dir}"
worker_path = "../worker/worker_process.py"
py_script = os.path.join(current_dir_path, worker_path)
ori_vocab_size = (
len(self.data_processor.tokenizer.sp_model)
if hasattr(self.data_processor.tokenizer, "sp_model")
else len(self.data_processor.tokenizer.vocab)
)
think_start_id = self.data_processor.tokenizer.get_vocab().get("<think>", -1)
if think_start_id >= 0:
self.llm_logger.info(f"Get think_start_id {think_start_id} from vocab.")
else:
self.llm_logger.info("No <think> token found in vocabulary, the model can not do reasoning.")
think_end_id = self.data_processor.tokenizer.get_vocab().get("</think>", -1)
if think_end_id >= 0:
self.llm_logger.info(f"Get think_end_id {think_end_id} from vocab.")
else:
self.llm_logger.info("No </think> token found in vocabulary, the model can not do reasoning.")
image_patch_id = self.data_processor.tokenizer.get_vocab().get("<|IMAGE_PLACEHOLDER|>", -1)
line_break_id = self.data_processor.tokenizer.get_vocab().get("\n", -1)
if line_break_id < 0:
line_break_ids = self.data_processor.tokenizer.encode("\n", add_special_tokens=False)
if isinstance(line_break_ids, dict):
line_break_ids = line_break_ids.get("input_ids")
elif hasattr(line_break_ids, "input_ids"):
line_break_ids = line_break_ids.input_ids
if line_break_ids:
if isinstance(line_break_ids, (list, tuple)):
first = line_break_ids[0]
if isinstance(first, (list, tuple)):
line_break_id = int(first[0]) if first else -1
else:
line_break_id = int(first)
else:
line_break_id = int(line_break_ids)
if line_break_id >= 0:
self.llm_logger.info(f"Get line_break_id {line_break_id} from tokenizer.")
ports = ",".join(map(str, self.cfg.parallel_config.engine_worker_queue_port))
ips = None
if self.cfg.ips is not None:
ips = ",".join(self.cfg.ips)
arguments = (
f" --devices {self.cfg.parallel_config.device_ids} {py_script}"
f" --max_num_seqs {self.cfg.scheduler_config.max_num_seqs} --max_model_len {self.cfg.model_config.max_model_len}"
f" --gpu_memory_utilization {self.cfg.cache_config.gpu_memory_utilization}"
f" --model {self.cfg.model_config.model!s}"
f" --device_ids {self.cfg.parallel_config.device_ids}"
f" --tensor_parallel_size {self.cfg.parallel_config.tensor_parallel_size}"
f" --engine_worker_queue_port {ports}"
f" --pod_ip {self.cfg.master_ip}"
f" --block_size {self.cfg.cache_config.block_size}"
f" --enc_dec_block_num {self.cfg.cache_config.enc_dec_block_num}"
f" --eos_tokens_lens {self.data_processor.eos_token_id_len}"
f" --pad_token_id {self.data_processor.pad_token_id}"
f" --engine_pid {self.cfg.parallel_config.engine_worker_queue_port[0]}"
f" --max_num_batched_tokens {self.cfg.scheduler_config.max_num_batched_tokens}"
f" --splitwise_role {self.cfg.scheduler_config.splitwise_role}"
f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio}"
f" --expert_parallel_size {self.cfg.parallel_config.expert_parallel_size}"
f" --chunked_moe_size {self.cfg.parallel_config.chunked_moe_size}"
f" --data_parallel_size {self.cfg.parallel_config.data_parallel_size}"
f" --quantization '{json.dumps(self.cfg.model_config.quantization)}'"
f" --ori_vocab_size {ori_vocab_size}"
f" --think_start_id {think_start_id}"
f" --think_end_id {think_end_id}"
f" --image_patch_id {image_patch_id}"
f" --line_break_id {line_break_id}"
f" --speculative_config '{self.cfg.speculative_config.to_json_string()}'"
f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'"
f" --guided_decoding_backend {self.cfg.structured_outputs_config.guided_decoding_backend}"
f" --load_strategy {self.cfg.load_config.load_strategy}"
f" --rsync_config '{json.dumps(self.cfg.load_config.rsync_config)}'"
f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
f" --reasoning_parser {self.cfg.structured_outputs_config.reasoning_parser}"
f" --load_choices {self.cfg.load_config.load_choices}"
f" --model_loader_extra_config '{json.dumps(self.cfg.load_config.model_loader_extra_config)}'"
f" --plas_attention_config '{self.cfg.plas_attention_config.to_json_string()}'"
f" --ips {ips}"
f" --cache-transfer-protocol {self.cfg.cache_config.cache_transfer_protocol}"
f" --runner {self.cfg.model_config.runner}"
f" --convert {self.cfg.model_config.convert}"
f" --override-pooler-config {self.cfg.model_config.override_pooler_config}"
f" --logprobs_mode {self.cfg.model_config.logprobs_mode}"
f" --max_logprobs {self.cfg.model_config.max_logprobs}"
f" --eplb_config '{self.cfg.eplb_config.to_json_string()}'"
f" --num_cpu_blocks {self.cfg.cache_config.num_cpu_blocks}"
f" --deploy_modality {self.cfg.deploy_modality.value}"
)
if self.cfg.structured_outputs_config.logits_processors is not None:
arguments += f" --logits-processors {' '.join(self.cfg.structured_outputs_config.logits_processors)}"
if self.mm_max_tokens_per_item is not None:
arguments += f" --mm_max_tokens_per_item '{json.dumps(self.mm_max_tokens_per_item)}'"
worker_store_true_flag = {
"enable_expert_parallel": self.cfg.parallel_config.enable_expert_parallel,
"enable_prefix_caching": self.cfg.cache_config.enable_prefix_caching,
"enable_chunked_prefill": self.cfg.cache_config.enable_chunked_prefill,
"do_profile": self.do_profile,
"dynamic_load_weight": self.cfg.load_config.dynamic_load_weight,
"disable_any_whitespace": self.cfg.structured_outputs_config.disable_any_whitespace,
"disable_custom_all_reduce": self.cfg.parallel_config.disable_custom_all_reduce,
"use_internode_ll_two_stage": self.cfg.parallel_config.use_internode_ll_two_stage,
"disable_sequence_parallel_moe": self.cfg.parallel_config.disable_sequence_parallel_moe,
"enable_logprob": self.cfg.model_config.enable_logprob,
"lm_head_fp32": self.cfg.model_config.lm_head_fp32,
"moe_gate_fp32": self.cfg.model_config.moe_gate_fp32,
"enable_entropy": self.cfg.model_config.enable_entropy,
"enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
"enable_flashinfer_allreduce_fusion": self.cfg.parallel_config.enable_flashinfer_allreduce_fusion,
}
for worker_flag, value in worker_store_true_flag.items():
if value:
arguments = arguments + f" --{worker_flag}"
worker_default_none_flag = {
"num_gpu_blocks_override": self.cfg.cache_config.num_gpu_blocks_override,
"kvcache_storage_backend": self.cfg.cache_config.kvcache_storage_backend,
}
for worker_flag, value in worker_default_none_flag.items():
if value:
arguments = arguments + f" --{worker_flag} {value}"
if self.cfg.nnode > 1:
pd_cmd = pd_cmd + f" --ips {ips} --nnodes {len(self.cfg.ips)}"
pd_cmd = pd_cmd + arguments + f" 2>{log_dir}/launch_worker.log"
self.llm_logger.info(f"Launch worker service command: {pd_cmd}")
p = subprocess.Popen(
pd_cmd,
stdout=subprocess.PIPE,
shell=True,
preexec_fn=os.setsid,
)
return p
def _stop_profile(self):
"""
Stop profiling of the model server and reset variables.
"""
self.do_profile = 0
while self.get_profile_block_num_signal.value[0] == 0:
if hasattr(self, "worker_proc") and self.worker_proc is not None:
if self.worker_proc.poll() is not None:
raise RuntimeError("Worker process failed to start." "Please check log/workerlog.* for details.")
time.sleep(1)
num_gpu_blocks = self.get_profile_block_num_signal.value[0]
self.cfg.cache_config.reset(num_gpu_blocks)
self.resource_manager.reset_cache_config(self.cfg.cache_config)
if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
if envs.ENABLE_V1_KVCACHE_MANAGER:
return
device_ids = self.cfg.parallel_config.device_ids.split(",")
self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix)
def check_health(self, time_interval_threashold=30):
"""
Check the health of the model server by checking whether all workers are alive.
"""
if self.worker_healthy_live_signal.value[0]:
elapsed_time = time.time() - self.worker_healthy_live_signal.value[0]
if elapsed_time > time_interval_threashold:
return False, "Worker Service Not Healthy"
return True, ""
def launch_components(self):
if self.cfg.scheduler_config.splitwise_role != "mixed":
# 单机逻辑
self.splitwise_receive_thread = threading.Thread(target=self.split_connector.start_receiver, args=())
self.splitwise_receive_thread.daemon = True
self.splitwise_receive_thread.start()
role = self.cfg.scheduler_config.splitwise_role
host_ip = self.cfg.host_ip
if self.cfg.scheduler_config.name == "splitwise":
self.scheduler.start(role, host_ip, self.cfg.register_info)
elif self.cfg.scheduler_config.name == "dp":
self.scheduler.start(
self.cfg.node_rank * self.cfg.worker_num_per_node % self.cfg.worker_num_per_node,
)
if not envs.FD_ENABLE_MULTI_API_SERVER:
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
self.launched_expert_service_signal.value[0] = 1
self.dp_processed = []
self.dp_engine_worker_queue_server = []
for i in range(
1,
self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
):
if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
address = (
self.cfg.master_ip,
int(self.cfg.parallel_config.engine_worker_queue_port[i]),
)
else:
address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.engine_worker_queue_port[i]}.sock"
self.llm_logger.info(f"dp start queue service {address}")
self.dp_engine_worker_queue_server.append(
EngineWorkerQueue(
address=address,
is_server=True,
num_client=self.cfg.parallel_config.tensor_parallel_size,
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
)
)
from fastdeploy.engine.expert_service import (
start_data_parallel_service,
)
self.dp_processed.append(
multiprocessing.Process(
target=start_data_parallel_service,
args=(
self.cfg,
i,
),
)
)
self.llm_logger.info(
f"Engine is initialized successfully with {self.cfg.parallel_config.tensor_parallel_size}"
+ f" data parallel id {i}"
)
self.dp_processed[-1].start()
while self.launched_expert_service_signal.value[i] == 0:
time.sleep(1)
def check_worker_initialize_status(self):
"""
Check the initlialize status of workers by stdout logging
"""
def detect_thread():
for line in self.worker_proc.stdout:
line = line.decode("utf-8", errors="ignore")
if self.worker_init_status.get("finished", False):
break
if match := re.search(
r"Loading (?:fastsafetensors |safetensors )?checkpoint shards:\s*(\d+)",
line,
):
self.worker_init_status["weight_loadding"] = eval(match.group(1)) * 1.0 / 100
elif (match := re.search(r"Start load layer (\d+)", line)) or (
match := re.search(r"set state for layer (\d+)", line)
):
progress = eval(match.group(1)) * 1.0 / self.cfg.model_config.num_hidden_layers
self.worker_init_status["layer_loadding"] = progress
if self.worker_init_status["layer_loadding"] == self.cfg.model_config.num_hidden_layers - 1:
self.worker_init_status["finished"] = True
self.checking_worker_status_thread = threading.Thread(target=detect_thread, daemon=True)
self.checking_worker_status_thread.start()
# display weight loadding progress
with tqdm(total=100, desc="Loading Weights") as pbar:
progress = 0
while progress < 100:
progress = int(self.worker_init_status.get("weight_loadding", 0) * 100)
if self.worker_init_status.get("layer_loadding", 0) > 0 or self._worker_processes_ready():
progress = 100
pbar.update(progress - pbar.n)
pbar.refresh()
time.sleep(0.5)
if self.worker_proc.poll() is not None:
return False
# display layer loadding progress
with tqdm(total=100, desc="Loading Layers") as pbar:
progress = 0
while progress < 100:
progress = int(self.worker_init_status.get("layer_loadding", 0) * 100)
if self._worker_processes_ready():
progress = 100
pbar.update(progress - pbar.n)
pbar.refresh()
time.sleep(0.5)
if self.worker_proc.poll() is not None:
return False
self.worker_init_status["finished"] = True
try:
self.checking_worker_status_thread.join(timeout=1)
except Exception:
pass
return True