FastDeploy/fastdeploy/engine/sched/resource_manager_v1.py
kevin 7707be8384 [Feature][KVCache] Implement Cache Manager V1 with GPU + CPU Cache Support (1/n) (#7097)
* [Feature][KVCache] Support cache manager v1 architecture

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Update cache manager and related modules

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* chore: update cache_manager and related modules

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: add node to evictable set in complete_swap_to_device

When a node transitions from SWAP_TO_DEVICE to DEVICE via
complete_swap_to_device, it was not being added to the
_evictable_device set. This caused nodes with ref_count=0 to
become "orphaned" - not appearing in any evictable set despite
having cache_status=DEVICE.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* feat: update cache manager v1 and related modules

- Add new cache_manager.py with cache management functionality
- Add radix_tree.py for prefix caching
- Update block_pool.py and metadata.py
- Update request.py and resource_manager_v1.py for scheduling
- Update gpu_model_runner.py for GPU model execution

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* feat(cache): add cache controller v1 implementation

- Add CacheController class for cache management
- Update config.py with cache related configurations
- Refactor gpu_model_runner.py for improved cache handling

* feat(cache_manager): update cache manager v1

* fix(cache_manager): fix block_ids logic for swap_cache H2D/D2H directions and clean up ForwardMeta

## Motivation

Fix the incorrect use of src/dst block_ids in the H2D direction in swap_cache_optimized.cu,
and remove the deprecated cache_controller field from ForwardMeta.

## Modifications

- fix: in swap_cache_optimized.cu, select src/dst block_ids correctly based on the D2H template parameter,
  fixing the swapped src/dst bug in the H2D direction (fixed in both SwapCachePerLayerImpl and SwapCacheAllLayersBatchImpl)
- refactor: cache_manager/v1/__init__.py now imports LayerSwapTimeoutError from
  cache_utils (the correct source) instead of cache_controller
- refactor: remove the deprecated cache_controller field from ForwardMeta
- refactor: remove the corresponding cache_controller assignment in gpu_model_runner.py
- test: add unit tests in tests/cache_manager/v1/test_swap_cache_ops.py

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* feat(cache_manager): refactor cache manager v1 and optimize swap ops

## Motivation

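Refactor and optimize cache manager v1, streamlining the code structure and improving maintainability.

## Modifications

- Refactor transfer_manager.py, substantially simplifying its logic
- Optimize the GPU operator implementation in swap_cache_optimized.cu
- Adjust the logic in cache_manager.py and cache_controller.py; fix the missing free_device_blocks method
- Update block_pool.py, cache_utils.py, metadata.py, radix_tree.py
- Simplify the related calls in gpu_model_runner.py, forward_meta.py, attention.py
- Update the corresponding unit tests (test_cache_controller, test_swap_cache_ops, test_transfer_manager)
- Adjust the related options in config.py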

* [KVCache][MTP] Support MTP KV Cache initialization and multimodal hash under cache_manager_v1

## Motivation

On the enable_cache_manager_v1 path, the MTP (speculative decode) KV Cache needs to be
managed by CacheController so that it can reuse the swap/transfer capabilities. This change
also fixes block hashes not carrying multimodal extra_keys in multimodal scenarios.

## Modifications

- `cache_controller.py`
  - Add `initialize_mtp_kv_cache`: initialize the MTP KV Cache through CacheController
    and register it in cache_kvs_map so that transfer_manager covers the MTP layers automatically
  - `initialize_host_cache`: num_layers now includes the extra MTP cache layers, so the
    Host Cache also allocates enough space for MTP
  - Rename `_free_gpu_cache` to `free_gpu_cache` (callable from outside)

- `cache_utils.py`
  - Add `get_block_hash_extra_keys`: extract the multimodal hash information within a single block,
    aligned with the multimodal extra_keys logic of PrefixCacheManager
  - `get_request_block_hasher`: pass extra_keys when calling hash_block_tokens,
    fixing inaccurate prefix cache hit rates in multimodal scenarios

- `spec_decode/mtp.py`
  - `update_mtp_block_num`: add a `skip_cache_init` parameter to avoid initializing the
    MTP KV Cache twice on the v1 cache manager path

- `gpu_model_runner.py`
  - `initialize_kv_cache(v1)` path: after the main model cache is initialized, call
    `cache_controller.initialize_mtp_kv_cache` to create the MTP cache
  - `clear_cache` / `wakeup` / `reset` paths: respect the `enable_cache_manager_v1`
    flag and skip the redundant proposer.initialize_kv_cache call

## Usage or Command

```bash
# Start an inference service with MTP + cache_manager_v1 (example)
bash run.sh
```

* fix(cache_manager): multi-GPU fix, mm hash boundary fix, and remove batch ops

1. Fix CuPy stream/event creation for multi-GPU: wrap all stream operations
   with cp.cuda.Device(device_id) context to ensure streams/events are bound
   to the correct device, preventing cross-device errors in multi-GPU setups.

2. Remove cudaSetDevice from SwapCacheAllLayers (handled by cupy context now).

3. Remove swap_cache_all_layers_batch op: simplified the implementation by
   removing the batch upload variant; all-layer transfers now use the standard
   swap_cache_all_layers with cupy device context.

4. Fix mm hash boundary comparison in get_block_hash_extra_keys: change
   strict less-than (<) to less-than-or-equal (<=) so that multimodal items
   ending exactly at block start are correctly excluded.

5. Extract config fields to KVCacheBase: model_config, cache_config,
   quant_config, parallel_config are now set in the base class __init__ to
   avoid duplication in CacheController and CacheManager subclasses.

6. Translate metadata.py docstrings from Chinese to English for broader
   contributor accessibility.

7. Add test_cache_utils.py: comprehensive unit tests for
   get_block_hash_extra_keys covering all boundary and overlap scenarios.

8. Expand test suite: test_request.py cache fields tests, test_radix_tree.py
   backup candidate tests, test_transfer_manager.py and test_cache_manager.py
   multi-GPU and concurrent operation tests.
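
The boundary rule in item 4 can be illustrated with a small sketch. It assumes a multimodal item occupies the half-open token range `[offset, offset + length)` and a block occupies `[block_start, block_end)`. The helper name `item_overlaps_block` is hypothetical and is not the signature of `get_block_hash_extra_keys`.

```python
def item_overlaps_block(offset: int, length: int, block_start: int, block_end: int) -> bool:
    """Return True if item [offset, offset + length) overlaps block [block_start, block_end)."""
    item_end = offset + length
    # An item that ends exactly at block_start contributes no tokens to the block,
    # so the exclusion test uses <= rather than <.
    if item_end <= block_start:
        return False
    # An item that starts at or after block_end lies entirely after the block.
    if offset >= block_end:
        return False
    return True
```

With a strict `<` in the first test, an item covering tokens 0..63 would wrongly be treated as part of the block that starts at token 64.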

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix List import and move write_policy normalization to CacheManager

## Motivation

Fix two issues:
1. `List` was not imported in `fastdeploy/engine/request.py`, causing a pre-commit F821 error
2. The `write_policy` normalization logic (`write_through` → `write_through_selective`) should not live in `FDConfig`; it moves to `CacheManager.__init__` so that it only affects the internal logic of Cache Manager V1

## Modifications

- `fastdeploy/engine/request.py`: add `List` to the `typing` import and remove the duplicate `CacheSwapMetadata` TYPE_CHECKING import, fixing F821/F811
- `fastdeploy/config.py`: remove the `write_policy` normalization logic
- `fastdeploy/cache_manager/v1/cache_manager.py`: move the normalization logic into `CacheManager.__init__`

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix pre-commit code style issues

## Motivation

Fix the failing CI pre-commit code style checks.

## Modifications

- `fastdeploy/engine/common_engine.py`: black formatting
- `fastdeploy/worker/worker_process.py`: black formatting + isort fix
- `fastdeploy/cache_manager/v1/storage/__init__.py`: isort fix
- `fastdeploy/worker/gpu_worker.py`: isort fix

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [Feature][KVCache] update cache_manager_v1 modules

## Motivation

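Update the Cache Manager V1 modules: complete the copyright information and improve module structure and maintainability.

## Modifications

- `fastdeploy/cache_manager/v1/` modules: add copyright headers and improve code structure
- `fastdeploy/config.py`: update configuration options
- `fastdeploy/engine/sched/resource_manager_v1.py`: scheduling-related updates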

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [Feature][KVCache] add BatchRequest.from_tasks and refactor worker task parsing

## Motivation

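Consolidate the duplicated task parsing logic in worker_process into BatchRequest to reduce redundancy and improve maintainability.

## Modifications

- `fastdeploy/engine/request.py`: add the `BatchRequest.from_tasks()` class method, which classifies task_queue tasks into inference requests and control requests
- `fastdeploy/worker/worker_process.py`: use `BatchRequest.from_tasks()` in place of the inline parsing logic, and fix the duplicated control_reqs handling block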

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [Feature][KVCache] add NUMA affinity for host cache and skip swap cache tests

## Motivation

Improve the NUMA affinity of Host cache memory allocation to reduce cross-NUMA access latency;
also skip the swap cache ops tests (not supported in the current environment).

## Modifications

- `fastdeploy/cache_manager/v1/cache_controller.py`:
  - Add `_get_numa_node_for_gpu()`, which obtains the NUMA node of a GPU via nvidia-smi or sysfs
  - Add `_bind_to_closest_numa_node()`, which binds the current thread to the NUMA node closest to the GPU
  - Call the NUMA binding in `initialize_host_cache()` to improve H2D transfer performance
- `tests/cache_manager/v1/test_swap_cache_ops.py`: skip all test classes (`TestSwapCacheAllLayersCorrectness`, `TestSwapCacheAllLayersPerformance`, `TestSwapCacheRandomBlockIndices`)
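
The sysfs lookup mentioned above can be sketched as follows. This is a minimal illustration, not the code of `_get_numa_node_for_gpu()`: the function name `read_numa_node` and its `sysfs_root` parameter are hypothetical, and the caller is assumed to pass the PCI address in the form sysfs uses (for example `0000:3b:00.0`). Linux exposes the value in the `numa_node` file of each PCI device directory.

```python
from pathlib import Path
from typing import Optional


def read_numa_node(pci_bus_id: str, sysfs_root: str = "/sys/bus/pci/devices") -> Optional[int]:
    """Return the NUMA node of a PCI device, or None if it is unknown."""
    # sysfs directory names use lowercase hex digits.
    node_file = Path(sysfs_root) / pci_bus_id.lower() / "numa_node"
    try:
        node = int(node_file.read_text().strip())
    except (OSError, ValueError):
        # Missing file, unreadable file, or unexpected content.
        return None
    # The kernel reports -1 when no NUMA affinity is known for the device.
    return node if node >= 0 else None
```

Returning `None` rather than raising lets the caller fall back to an unbound allocation when the platform reports no affinity.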

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix unittest failures for cache_manager_v1

Three unit tests failed because of interface changes or mocking issues and needed fixes.

- tests/distributed/chunked_moe.py: `setup_model_runner` uses `__new__` to bypass `__init__`; add `enable_cache_manager_v1 = False` to fix the `AttributeError`
- tests/engine/test_resource_manager.py: `PrefixCacheManager` is imported locally, so the `patch` target changes to where it is defined, `fastdeploy.cache_manager.prefix_cache_manager.PrefixCacheManager`
- tests/v1/test_resource_manager_v1.py: the fourth parameter of `_trigger_preempt` changed from `list` to `BatchRequest`; update the test arguments and assertions

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] remove debug logging code

## Modifications

- fastdeploy/engine/request.py: remove the debugging logger and the debug log in prompt_hashes
- fastdeploy/worker/worker_process.py: remove the debugging import and print statements in __main__

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix cupy device id caching and pickle for _match_result

## Motivation

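Fix two bugs:
1. Calling `cp.cuda.runtime.getDevice()` on every use in `transfer_manager.py` is fragile; the device ID should be cached as an instance variable at initialization so that later operations use a consistent device ID.
2. `__getstate__` in `request.py` did not skip `_match_result`. That field holds a BlockNode tree with parent-child circular references, which triggers `RecursionError` during pickling. `__setstate__` is also added so that the fields are restored to safe defaults after unpickling.

## Modifications

- `transfer_manager.py`: call `cp.cuda.runtime.getDevice()` at initialization and cache it in `self._cupy_device_id`; later `with cp.cuda.Device(...)` blocks and logs use the cached value.
- `request.py`:
  - Add `_match_result` to the skip set `_SKIP_KEYS` in `__getstate__` to avoid pickle failures caused by circular references.
  - Add `__setstate__`, which restores `_block_hasher` and `_match_result` to `None` after unpickling.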
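
The skip-and-restore pattern described for `request.py` can be sketched as below. The classes `TreeNode` and `Req` are hypothetical stand-ins, not the FastDeploy `BlockNode` or `Request`; only the field names `_block_hasher`, `_match_result` and `_SKIP_KEYS` are taken from the commit text.

```python
import pickle


class TreeNode:
    """Hypothetical tree node with a parent back-reference."""

    def __init__(self, parent=None):
        self.parent = parent
        self.children = []
        if parent is not None:
            parent.children.append(self)


class Req:
    # Runtime-only fields that are left out of the pickled state.
    _SKIP_KEYS = {"_block_hasher", "_match_result"}

    def __init__(self, request_id):
        self.request_id = request_id
        self._block_hasher = lambda tokens: hash(tuple(tokens))  # lambdas cannot be pickled
        self._match_result = TreeNode(TreeNode())  # references tree nodes

    def __getstate__(self):
        # Copy everything except the skipped runtime fields.
        return {k: v for k, v in self.__dict__.items() if k not in self._SKIP_KEYS}

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Restore skipped fields to safe defaults after unpickling.
        self._block_hasher = None
        self._match_result = None


def roundtrip(req):
    """Serialize and deserialize a request through pickle."""
    return pickle.loads(pickle.dumps(req))
```

When run as a script, `roundtrip(Req("req-0"))` yields an object whose `request_id` is preserved and whose `_block_hasher` and `_match_result` are `None`, so the receiving side has to rebuild them before use.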

## Usage or Command

* fix(test): fix unit test errors for _trigger_preempt and wakeup with MTP

- Use BatchRequest instead of list in test_trigger_preempt_records_tasks
- Add missing enable_cache_manager_v1 attr in TestSleepWakeupBehavior._make_runner

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix gpu_free_block_list returning wrong block IDs

## Motivation

The compatibility property `gpu_free_block_list` mistakenly used `list(range(N))`,
passing the return value of `available_blocks()` to `range()` as an integer,
so it returned a fake list `[0, 1, ..., N-1]` instead of the real free block IDs.

## Modifications

- `cache_manager/v1/cache_manager.py`: change `list(range(self._device_pool.available_blocks()))` to `list(self._device_pool.available_blocks())`

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] Fix TypeError caused by gpu_free_block_list receiving an int

## Motivation

The gpu_free_block_list property calls BlockPool.available_blocks(),
which returns an int (the number of free blocks). Wrapping an int in list() raises
TypeError: 'int' object is not iterable.

## Modifications

Change list(self._device_pool.available_blocks()) to
list(self._device_pool._free_blocks), which returns the list of free block indices directly.
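
The two successive `gpu_free_block_list` bugs can be reproduced with a toy pool. `TinyBlockPool` is a hypothetical stand-in and not the FastDeploy `BlockPool`; it only mirrors the detail stated above, that `available_blocks()` returns a count.

```python
class TinyBlockPool:
    """Hypothetical stand-in for a block pool; not the FastDeploy BlockPool."""

    def __init__(self, num_blocks: int):
        self._free_blocks = list(range(num_blocks))

    def allocate(self, n: int) -> list[int]:
        # Hand out the first n free IDs and keep the rest.
        taken, self._free_blocks = self._free_blocks[:n], self._free_blocks[n:]
        return taken

    def available_blocks(self) -> int:
        # Returns a count, not the IDs themselves.
        return len(self._free_blocks)


pool = TinyBlockPool(6)
pool.allocate(2)  # blocks 0 and 1 are now in use

fake_ids = list(range(pool.available_blocks()))  # first bug: [0, 1, 2, 3]
real_ids = list(pool._free_blocks)               # fix: [2, 3, 4, 5]

try:
    list(pool.available_blocks())                # second bug: list() of an int
    error = None
except TypeError as exc:
    error = str(exc)
```

The first variant silently reports IDs 0 and 1 as free although they are allocated, which is harder to notice than the `TypeError` raised by the second.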

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [KVCache][CacheManager] Adapt pause/sleep/free_cache operations to the V1 CacheManager

## Motivation

The V1 CacheManager introduces a new reset_cache() interface, so the pause and sleep operations need to be adapted,
and free_cache needs to support an optional clear_storage parameter.

## Modifications

- cache_controller.py: free_cache gains a clear_storage parameter (default False);
  _clear_storage() is called only when clear_storage=True, avoiding unnecessary storage clearing
- common_engine.py: in the pause and sleep operations, when ENABLE_V1_KVCACHE_MANAGER is set,
  use cache_manager.reset_cache() instead of the old reset() and pause_transfer logic
- gpu_model_runner.py: on sleep, clear the MTP cache only when the V1 cache manager is not in use

## Usage or Command

# Start the service (V1 CacheManager)
python -m fastdeploy.entrypoints.openai.api_server \
  --enable-v1-kvcache-manager \
  ...

* [BugFix][KVCache] fix missing enable_cache_manager_v1 in test mocks and remove unused select_blocks_for_backup

- Remove unused `select_blocks_for_backup` method from radix_tree.py
- Fix `match_prefix` default param `skip_storage=True` and log order in cache_manager.py
- Sync test_gpu_model_runner.py with upstream/develop (add TestInsertTasksV1SplitwiseSuffix)
- Add `enable_cache_manager_v1=False` to all mock runners to fix AttributeError in CI

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] simplify _free_blocks in ResourceManagerV1 for non-v1 path

Remove redundant prefix_caching branch in else path; always call
recycle_gpu_blocks with full block_tables for non-cache-manager-v1 case.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [KVCache][Optimization][BugFix] fix and optimize block_pool, cache_manager, transfer_manager, request

## Motivation

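Fix several code quality issues in cache_manager v1, improve performance, and eliminate potential type-inconsistency bugs.

## Modifications

1. **block_pool.py**: `BlockPool.allocate` replaces the per-item pop loop with slicing plus a batched set.update, removing Python loop overhead, O(n) → O(k) (batched at the C level)
2. **cache_manager.py**: when prefix caching is disabled, `match_prefix` writes an empty `MatchResult()` before returning early, so callers do not crash when dereferencing `_match_result=None`
3. **transfer_manager.py**: `_build_device_layer_indices` also resets the four layer index lists when `_cache_kvs_map` is empty, preventing stale tensors from being used by the swap operators
4. **request.py**: `BatchRequest.append_swap_metadata` / `append_evict_metadata` now pass `src_type`/`dst_type` as `CacheLevel` enum values instead of strings when constructing `CacheSwapMetadata`, matching the declared field types; add the `CacheLevel` import; correct the return type annotation of the `match_result` property to `Optional[MatchResult]`
5. **resource_manager_v1.py**: lower the `_allocate_gpu_blocks` log from `INFO` to `DEBUG` to remove log noise on the high-frequency scheduling path
6. **tests/engine/test_request.py**: update the `src_type`/`dst_type` assertions to `CacheLevel` enum values and add the `CacheLevel` import

## Usage or Command

Unit tests: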
```bash
source .venv/py310/bin/activate
cd baidu/FastDeploy
python -m pytest tests/cache_manager/v1/test_cache_manager.py -v
python -m pytest tests/cache_manager/v1/test_transfer_manager.py -v
python -m pytest tests/engine/test_request.py -v
```

* [BugFix][KVCache] Fix BlockPool.allocate returns all blocks when num_blocks=0

## Motivation

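When `allocate(num_blocks=0)` is called, a Python negative-index pitfall causes a serious error:
`-0 == 0`, so `self._free_blocks[-0:]` is equivalent to `self._free_blocks[0:]`,
which returns and drains the entire free block list instead of returning an empty list.

## Modifications

Add an early check for `num_blocks == 0` in `BlockPool.allocate` that returns `[]` directly,
avoiding the negative-index pitfall.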
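
The pitfall can be reproduced on a plain list. The two functions below are hypothetical sketches of a tail-slice allocator, assuming the free list is an ordinary Python list; they are not the actual `BlockPool.allocate`.

```python
def allocate_unguarded(free_blocks: list[int], num_blocks: int) -> list[int]:
    """Take num_blocks IDs from the tail; wrong when num_blocks == 0."""
    taken = free_blocks[-num_blocks:]
    del free_blocks[-num_blocks:]
    return taken


def allocate_guarded(free_blocks: list[int], num_blocks: int) -> list[int]:
    """Same tail slice, with an early return for the zero case."""
    if num_blocks == 0:
        return []
    taken = free_blocks[-num_blocks:]
    del free_blocks[-num_blocks:]
    return taken


pool_a = [10, 11, 12, 13]
leaked = allocate_unguarded(pool_a, 0)  # -0 == 0, so this slices [0:]

pool_b = [10, 11, 12, 13]
nothing = allocate_guarded(pool_b, 0)   # pool_b is left untouched
```

After these calls `leaked` holds all four IDs and `pool_a` is empty, while `nothing` is `[]` and `pool_b` still holds its four IDs.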

## Usage or Command

```bash
# Run the related unit tests to verify the fix
python -m pytest tests/cache_manager/v1/test_cache_manager.py -vv -s
```

* [KVCache][Test] add unit tests for cache_manager v1 modules

## Motivation

Complete unit test coverage for the cache_manager/v1 modules so that the core methods are fully tested.

## Modifications

Add or extend the following test files; all 326 cases pass:

- tests/cache_manager/v1/test_block_pool.py (new)
  Covers BlockPool.get_metadata/set_metadata/resize, DeviceBlockPool/HostBlockPool
- tests/cache_manager/v1/test_metadata.py (new)
  Covers BlockNode, RadixTreeStats, MatchResult, CacheSwapMetadata, AsyncTaskHandler
- tests/cache_manager/v1/test_cache_utils.py (extended)
  Adds hash_block_tokens, get_request_block_hasher, LayerDoneCounter time tracking, and internal helper methods
- tests/cache_manager/v1/test_radix_tree.py (extended)
  Adds the dedicated TestCompleteSwapToDevice test class (6 cases)
- tests/cache_manager/v1/test_cache_manager.py (extended)
  Adds offload_to_host, load_from_host, the pending backup series, prepare_prefetch_metadata
- tests/cache_manager/v1/test_transfer_manager.py (extended)
  Adds the _swap_single_layer validation path, sync_input/output_stream, record_input_stream_event

## Usage or Command

```bash
# Run all newly added unit tests
source .venv/py310/bin/activate
python -m pytest tests/cache_manager/v1/test_block_pool.py \
  tests/cache_manager/v1/test_metadata.py \
  tests/cache_manager/v1/test_cache_utils.py \
  tests/cache_manager/v1/test_radix_tree.py \
  tests/cache_manager/v1/test_cache_manager.py \
  tests/cache_manager/v1/test_transfer_manager.py -v
# Expected result: 326 passed
```

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
2026-04-21 14:39:00 +08:00

1733 lines
83 KiB
Python

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import copy
import threading
import time
import traceback
from collections import deque
from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import List, Union
import numpy as np
import paddle
from fastdeploy import envs
from fastdeploy.cache_manager.multimodal_cache_manager import (
    EncoderCacheManager,
    ProcessorCacheManager,
)
from fastdeploy.cache_manager.v1.metadata import CacheSwapMetadata
from fastdeploy.config import ErnieArchitectures
from fastdeploy.engine.request import (
    BatchRequest,
    ImagePosition,
    Request,
    RequestOutput,
    RequestStatus,
    RequestType,
)
from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.input.utils import IDS_TYPE_FLAG
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.multimodal.hasher import MultimodalHasher
from fastdeploy.platforms import current_platform
from fastdeploy.spec_decode import SpecMethod
from fastdeploy.trace.constants import LoggingEventName
from fastdeploy.trace.trace_logger import print as trace_print
from fastdeploy.utils import download_from_bos, init_bos_client, llm_logger


@dataclass
class ScheduledTaskBase:
    """
    Base class for scheduled tasks.
    """

    idx: int
    request_id: str
    task_type: RequestType = RequestType.DECODE
    cache_swap_metadata: list[CacheSwapMetadata] = field(default_factory=list)
    cache_evict_metadata: list[CacheSwapMetadata] = field(default_factory=list)

    def pop_cache_swap_metadata(self) -> list[CacheSwapMetadata]:
        result = self.cache_swap_metadata
        self.cache_swap_metadata = []
        return result

    def pop_cache_evict_metadata(self) -> list[CacheSwapMetadata]:
        result = self.cache_evict_metadata
        self.cache_evict_metadata = []
        return result


@dataclass
class ScheduledDecodeTask(ScheduledTaskBase):
    """
    Task for allocating new blocks to decode.
    """

    block_tables: list[int] = field(default_factory=list)


@dataclass
class ScheduledPreemptTask(ScheduledTaskBase):
    """
    Task for terminating inference to recycle resources.
    """

    task_type: RequestType = RequestType.PREEMPTED


@dataclass
class ScheduledExtendBlocksTask(ScheduledTaskBase):
    """
    Task for allocating new blocks to extend.
    """

    task_type: RequestType = RequestType.EXTEND
    extend_block_tables: list[int] = field(default_factory=list)


@dataclass
class ScheduledAbortTask(ScheduledTaskBase):
    """Task for aborting a request."""

    task_type: RequestType = RequestType.ABORT


class SignalConsumer:
    """
    A class that consumes a signal value up to a specified limit.

    This class maintains an internal signal value and allows controlled consumption
    of that signal. The signal can be watched at any time, but can only be consumed
    a limited number of times before being reset to zero.
    """

    def __init__(self, signal, consume_limit):
        """
        Initialize the SignalConsumer with a signal value and consumption limit.

        Args:
            signal: The initial signal value to be consumed.
            consume_limit (int): The maximum number of times the signal can be consumed
                before being reset to 0. Must be a positive integer.

        Raises:
            AssertionError: If consume_limit is not greater than 0.
        """
        assert consume_limit > 0
        self._signal = signal
        self._consume_limit = consume_limit

    def watch(self):
        """
        Get the current signal value without consuming it.

        This method allows reading the signal value any number of times without
        affecting the consumption limit or the signal value itself.

        Returns:
            The current signal value.
        """
        return self._signal

    def consume(self):
        """
        Consume the signal value, decrementing the consumption limit.

        This method returns the current signal value and decrements the consumption
        counter. When the consumption limit reaches zero, the signal is automatically
        reset to 0. The consumption happens in a finally block to ensure the limit is
        decremented even if an exception occurs while processing the signal.

        Returns:
            The current signal value before consumption.

        Note:
            After the consumption limit is reached, this method will continue to
            return 0 on subsequent calls.
        """
        try:
            return self._signal
        finally:
            if self._consume_limit > 0:
                self._consume_limit -= 1
                if self._consume_limit == 0:
                    self._signal = 0


class ResourceManagerV1(ResourceManager):
    """
    Resource manager for scheduler v1.
    In scheduler v1, all gpu blocks are managed by PrefixCacheManager.
    Tasks sent to the worker are divided into 3 types: PREFILL, DECODE and PREEMPTED.
    For a prefill task, the worker runs one inference step and then stops for this query if not all prompt tokens are computed.
    For a decode task, the worker continues to decode until the allocated blocks are exhausted.
    For a preempted task, the worker resets all inputs to terminate the inference.
    """

    def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, local_data_parallel_id=0):
        super(ResourceManagerV1, self).__init__(
            max_num_seqs, config, tensor_parallel_size, splitwise_role, local_data_parallel_id
        )
        # req_id -> Request
        self.config = config
        self.requests: dict[str, Request] = {}
        # Priority queues for requests.
        self.waiting: deque[Request] = deque()
        self.running: list[Request] = []
        self.preallocated_reqs: dict[str, Request] = {}
        self.enable_max_prefill = envs.FD_ENABLE_MAX_PREFILL
        self.finish_execution_pool = ThreadPoolExecutor(max_workers=1)
        self.lock = threading.Lock()
        self.to_be_rescheduled_request_id_set = set()
        main_process_metrics.max_batch_size.set(max_num_seqs)
        self.using_extend_tables_req_id = set()
        self.reuse_block_num_map = dict()
        self.abort_req_ids_set = set()
        self.waiting_abort_req_id_set = set()
        self.to_be_aborted_req_id_set = set()
        # need block nums
        need_block_num_data = np.zeros([max_num_seqs], dtype=np.int32)
        self.need_block_num_signal = IPCSignal(
            name="need_block_num_signal",
            array=need_block_num_data,
            dtype=np.int32,
            suffix=self.config.parallel_config.local_engine_worker_queue_port,
            create=True,
        )
        self.need_block_num_map = dict()
        self.encoder_cache = None
        if config.enable_mm_runtime and config.cache_config.max_encoder_cache > 0:
            self.encoder_cache = EncoderCacheManager(config.cache_config.max_encoder_cache)
        self.processor_cache = None
        if config.enable_mm_runtime and config.cache_config.max_processor_cache > 0:
            max_processor_cache_in_bytes = int(config.cache_config.max_processor_cache * 1024 * 1024 * 1024)
            self.processor_cache = ProcessorCacheManager(max_processor_cache_in_bytes)
        self.bos_client = None
        self.async_preprocess_pool = ThreadPoolExecutor(max_workers=4)
        self.init_reserve_output_block_num = (
            envs.FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL
        )  # int
        self.decay_output_block_num = (
            envs.FD_RESERVE_DECAY_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL
        )  # float
        self.min_reserve_output_block_num = (
            envs.FD_RESERVE_MIN_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL
        )  # int
        self.current_reserve_output_block_num = self.init_reserve_output_block_num
        self.current_reserve_output_block_num_float = self.init_reserve_output_block_num
        self.can_relax_prefill_strategy = True
        # Scheduler-side requests that have not been moved into resource manager waiting queue yet.
        self.scheduler_unhandled_request_num = 0

    def allocated_slots(self, request: Request):
        return len(request.block_tables) * self.config.cache_config.block_size

    def get_new_block_nums(self, request: Request, num_new_tokens: int):
        block_num = (
            request.num_computed_tokens + num_new_tokens + self.config.cache_config.block_size - 1
        ) // self.config.cache_config.block_size - len(request.block_tables)
        if self.config.speculative_config.method is not None:
            block_num = min(block_num + 1, self.config.cache_config.max_block_num_per_seq)
        else:
            block_num = min(block_num, self.config.cache_config.max_block_num_per_seq)
        return block_num

    def _prepare_prefill_task(self, request, new_token_num):
        request.prefill_start_index = request.num_computed_tokens
        request.prefill_end_index = request.num_computed_tokens + new_token_num
        request.task_type = RequestType.PREFILL
        return request

    def _prepare_decode_task(self, request):
        return ScheduledDecodeTask(
            idx=request.idx,
            request_id=request.request_id,
            block_tables=request.block_tables,
            cache_swap_metadata=request.pop_cache_swap_metadata(),
            cache_evict_metadata=request.pop_cache_evict_metadata(),
        )

    def _prepare_preempt_task(self, request):
        return ScheduledPreemptTask(
            idx=request.idx,
            request_id=request.request_id,
            cache_swap_metadata=request.pop_cache_swap_metadata(),
            cache_evict_metadata=request.pop_cache_evict_metadata(),
        )

    def _prepare_abort_task(self, request):
        return ScheduledAbortTask(
            idx=request.idx,
            request_id=request.request_id,
            cache_swap_metadata=request.pop_cache_swap_metadata(),
            cache_evict_metadata=request.pop_cache_evict_metadata(),
        )

    def reschedule_preempt_task(self, request_id, process_func=None):
        with self.lock:
            llm_logger.debug(f"reschedule {request_id} into waiting queue")
            if request_id in self.to_be_rescheduled_request_id_set and request_id in self.requests:
                request = self.requests[request_id]
                request.has_been_preempted_before = True
                request.metrics.preempted_count += 1
                if process_func is not None:
                    process_func(request)
                llm_logger.debug(f"self.waiting append request:{request.request_id},req.type:{request.status}")
                self.waiting.appendleft(request)
                self.to_be_rescheduled_request_id_set.remove(request_id)

    def recycle_abort_task(self, request_id):
        with self.lock:
            if request_id in self.to_be_aborted_req_id_set and request_id in self.requests:
                request = self.requests[request_id]
                self.tasks_list[request.idx] = None  # clear the slot
                self.stop_flags[request.idx] = True  # set the stop flag
                del self.requests[request_id]
                del self.req_dict[request_id]
                self.to_be_aborted_req_id_set.remove(request_id)
                self.update_metrics()

    def _trigger_abort(self, request_id, batch_request):
        if request_id in self.requests:
            abort_request = self.requests[request_id]
            abort_request.status = RequestStatus.PREEMPTED
            abort_request.num_computed_tokens = 0
            self._free_blocks(abort_request)  # release the KV cache blocks
            abort_request.cached_block_num = 0
            batch_request.add_request(self._prepare_abort_task(abort_request))
            self.to_be_aborted_req_id_set.add(request_id)
            self.waiting_abort_req_id_set.remove(request_id)

    def _info_each_block(self):
        """
        print each req block
        """
        for req in self.running:
            llm_logger.debug(
                f"req idx {req.idx} occupy {len(req.block_tables)} block_tables and {len(req.extend_block_tables)} extend_block_tables"
            )

    def _can_preempt(self):
        """
        cannot preempt request which use extend block
        """
        for req in self.running:
            if not req.use_extend_tables:
                return True
        return False

    def preempted_all(self):
        with self.lock:
            preempted_reqs = []
            for i in range(len(self.running)):
                req = self.running.pop()
                # txt2image: req.use_extend_tables is True, req can not be preempted. txt2image is not used in RL.
                if req.use_extend_tables:
                    self.running.insert(0, req)
                    continue
                req.status = RequestStatus.PREEMPTED
                req.num_computed_tokens = 0
                self._free_blocks(req)
                req.cached_block_num = 0
                self.to_be_rescheduled_request_id_set.add(req.request_id)
                trace_print(LoggingEventName.PREEMPTED, req.request_id, getattr(req, "user", ""))
                preempted_reqs.append(self._prepare_preempt_task(req))
            return preempted_reqs

    def wait_worker_inflight_requests_finish(self, timeout=60):
        count = 0
        while count < timeout * 1000:
            # wait ongoing running and rescheduled requests finished in worker
            running_reqs_count = len(self.to_be_rescheduled_request_id_set) + len(self.running)
            if running_reqs_count == 0:
                break
            count += 1
            time.sleep(0.001)
        if count >= timeout * 1000:
            llm_logger.info(
                f"wait_inflight_requests_finish timeout after {timeout} seconds, "
                f"still {len(self.to_be_rescheduled_request_id_set)} requests running"
            )

    def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, batch_request):
        """
        If the request cannot be scheduled, preempt the running request one by one until it can be scheduled. Last in, first out.
        """
        can_schedule = False
        while self._can_preempt():
            if not self.cache_manager.can_allocate_gpu_blocks(num_new_blocks):
                preempted_req = self.running.pop()
                if preempted_req.use_extend_tables:
                    self.running.insert(0, preempted_req)
                    continue
                preempted_req.status = RequestStatus.PREEMPTED
                preempted_req.num_computed_tokens = 0
                if self.config.scheduler_config.splitwise_role == "decode":
                    self.tasks_list[preempted_req.idx] = None
                    self.stop_flags[preempted_req.idx] = True
                    if preempted_req.request_id in self.requests:
                        del self.requests[preempted_req.request_id]
                    if preempted_req.request_id in self.req_dict:
                        del self.req_dict[preempted_req.request_id]
                    if envs.FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST:
                        if self.config.cache_config.kvcache_storage_backend:
                            self.cache_manager.write_cache_to_storage_decode(preempted_req)
                    self._free_blocks(preempted_req)
                    llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
                else:
                    if envs.FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST:
                        if self.config.cache_config.kvcache_storage_backend:
                            self.cache_manager.write_cache_to_storage(preempted_req)
                    self._free_blocks(preempted_req)
                    preempted_req.num_cached_blocks = 0
                    self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
                    trace_print(
                        LoggingEventName.PREEMPTED, preempted_req.request_id, getattr(preempted_req, "user", "")
                    )
                    llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
                preempted_reqs.append(preempted_req)
                batch_request.add_request(self._prepare_preempt_task(preempted_req))
                llm_logger.debug(
                    f"preempt {preempted_req.request_id} in idx {preempted_req.idx} with generated ids {preempted_req.output_token_ids}"
                )
                llm_logger.debug(self.info())
                self._info_each_block()
                if preempted_req == request:
                    # No more request to preempt.
                    can_schedule = False
                    break
            else:
                # The request can be scheduled.
                can_schedule = True
                break
        self.current_reserve_output_block_num = self.init_reserve_output_block_num
        self.current_reserve_output_block_num_float = self.init_reserve_output_block_num
        self.can_relax_prefill_strategy = False
        return can_schedule

    def _get_can_schedule_prefill_threshold_block(self, num_chunk_new_block):
        if self.can_relax_prefill_strategy:
            can_schedule_block_num_threshold = num_chunk_new_block
        else:
            can_schedule_block_num_threshold = (
                num_chunk_new_block + len(self.running) * self.current_reserve_output_block_num
            )
        if self.config.speculative_config.method is not None:
            can_schedule_block_num_threshold = min(
                can_schedule_block_num_threshold + 1, self.config.cache_config.max_block_num_per_seq
            )
        return can_schedule_block_num_threshold

    def _update_mm_hashes(self, request):
        if request.multimodal_inputs is None:
            return
        inputs = request.multimodal_inputs
        if (
            inputs.get("images", None) is not None
            and inputs.get("image_patch_id", None) is not None
            and inputs.get("grid_thw", None) is not None
            and len(inputs["grid_thw"]) != 0
        ):
            grid_thw = []
            new_mm_positions, new_mm_hashes = [], []
            image_st = 0
            for idx, one in enumerate(inputs["grid_thw"]):
                t, h, w = one[0], one[1], one[2]
                if t == 1:
                    grid_thw.append(one)
                    new_mm_positions.append(inputs["mm_positions"][idx])
                    new_mm_hashes.append(inputs["mm_hashes"][idx])
                    image_st += h * w
                else:
                    grid_thw.extend([[2, h, w]] * (t // 2))
                    token_st = inputs["mm_positions"][idx].offset
                    for _ in range(t // 2):
                        mm_num_token = inputs["mm_num_token_func"](grid_thw=[2, h, w])
                        new_mm_positions.append(ImagePosition(token_st, mm_num_token))
                        # videos are split into patches every 2 frames, need to rehash
                        new_mm_hashes.append(
                            MultimodalHasher.hash_features(inputs["images"][image_st : image_st + 2 * h * w])
                        )
                        image_st += 2 * h * w
                        token_st += mm_num_token
            inputs["mm_positions"] = new_mm_positions
            inputs["mm_hashes"] = new_mm_hashes
        elif inputs.get("mm_positions", None) is None or inputs.get("mm_hashes", None) is None:
            inputs["mm_positions"] = []
            inputs["mm_hashes"] = []

    def _is_mm_request(self, request):
        inputs = request.multimodal_inputs
        if inputs is None or len(inputs) == 0:
            return False
        if (
            (inputs.get("video_feature_urls") is not None and len(inputs["video_feature_urls"]) > 0)
            or (inputs.get("image_feature_urls") is not None and len(inputs["image_feature_urls"]) > 0)
            or (inputs.get("audio_feature_urls") is not None and len(inputs["audio_feature_urls"]) > 0)
        ):
            return True
        elif (
            inputs.get("images", None) is not None
            and inputs.get("image_patch_id", None) is not None
            and inputs.get("grid_thw", None) is not None
        ):
            return True
        return False
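def revert_chunked_mm_input(self, mm_inputs, matched_token_num):
"""
Roll matched_token_num back so that a prefix-cache match never ends inside a multimodal item.
If the match ends strictly inside an item, it is rounded down to the block boundary at or before that item's offset.
For example, with block_size=64 and an item at offset=100, length=50, matched_token_num=120 becomes (100 // 64) * 64 = 64.
"""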
if mm_inputs is None or "mm_positions" not in mm_inputs or len(mm_inputs["mm_positions"]) == 0:
return matched_token_num
position_idx = len(mm_inputs["mm_positions"]) - 1
while matched_token_num > 0 and position_idx >= 0:
position = mm_inputs["mm_positions"][position_idx]
if position.offset < matched_token_num < position.offset + position.length:
matched_token_num = (
position.offset // self.config.cache_config.block_size
) * self.config.cache_config.block_size
position_idx -= 1
elif matched_token_num <= position.offset:
position_idx -= 1
elif matched_token_num >= position.offset + position.length:
break
else:
llm_logger.error(
f"revert_chunked_mm_input error, matched_token_num:{matched_token_num} position:{position}, {mm_inputs['mm_positions']}"
)
break
return matched_token_num
def _get_num_new_tokens(self, request, token_budget):
# TODO: set condition to new _get_num_new_tokens
num_new_tokens = request.need_prefill_tokens - request.num_computed_tokens
assert num_new_tokens > 0, (
f"Request {request.request_id} has no remaining tokens: "
f"need_prefill={request.need_prefill_tokens}, computed={request.num_computed_tokens}"
)
num_new_tokens = min(num_new_tokens, token_budget)
# Deterministic mode: align chunk boundaries to split_kv_size
# This ensures batch-invariant attention by making each chunk
# a multiple of the split-KV block size (default 16)
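# Worked example (illustrative values): with split_kv_size=16, current_pos=20,
# need_prefill_tokens=100 and token_budget=30, the next boundary is 32
# (12 tokens away) and the furthest boundary within budget is
# (20 + 30) // 16 * 16 = 48, so 28 tokens are scheduled. With token_budget=10
# the boundary at 32 is out of reach and 0 is returned, which defers the
# request to a later iteration.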
if envs.FD_DETERMINISTIC_MODE:
split_kv_size = envs.FD_DETERMINISTIC_SPLIT_KV_SIZE
current_pos = request.num_computed_tokens
remaining_tokens = request.need_prefill_tokens - current_pos
# Case 1: Final chunk - no alignment needed
if remaining_tokens < split_kv_size:
aligned_end = current_pos + remaining_tokens
else:
# Case 2: Need to align to split_kv_size boundary
# Calculate next boundary position
next_boundary = ((current_pos + split_kv_size - 1) // split_kv_size) * split_kv_size
tokens_to_boundary = next_boundary - current_pos
# Not enough budget to reach the next boundary: defer to next iteration
if token_budget < tokens_to_boundary:
return 0
# Align to as many full boundaries as budget allows
aligned_end = ((current_pos + token_budget) // split_kv_size) * split_kv_size
num_new_tokens = aligned_end - current_pos
# Don't exceed the original budget or remaining tokens
num_new_tokens = min(
num_new_tokens, token_budget, request.need_prefill_tokens - request.num_computed_tokens
)
if (
current_platform.is_intel_hpu()
and request.need_prefill_tokens - request.num_computed_tokens > token_budget
and token_budget > self.config.cache_config.block_size
):
num_new_tokens = token_budget // self.config.cache_config.block_size * self.config.cache_config.block_size
request.with_image = False
if not self.config.enable_mm_runtime:
return num_new_tokens
inputs = request.multimodal_inputs
if (
inputs is not None
and inputs.get("patch_idx", None) is not None
and inputs.get("patch_map", None) is not None
):
pre_end_idx = request.num_computed_tokens
new_end_idx = pre_end_idx + num_new_tokens
prompt_token_ids_len = len(request.prompt_token_ids)
if not inputs.get("tts", False):
assert prompt_token_ids_len == len(inputs["patch_idx"]), (
prompt_token_ids_len,
len(inputs["patch_idx"]),
)
def _compute_audio_prefix_count(end_idx, end_patch_idx):
audio_prefix_count = 0
pre_patch_end_idx = 0
for patch_idx in range(end_patch_idx + 1):
patch_map = inputs["patch_map"][patch_idx]
modal_id = patch_map["modal_id"]
if modal_id == IDS_TYPE_FLAG["audio"]:
if patch_idx != end_patch_idx:
audio_prefix_count += patch_map["end_idx"] - pre_patch_end_idx
else:
audio_prefix_count += end_idx - pre_patch_end_idx
pre_patch_end_idx = patch_map["end_idx"]
return audio_prefix_count
# start
if pre_end_idx >= prompt_token_ids_len:
start_patch_idx = inputs["patch_idx"][-1]
else:
start_patch_idx = inputs["patch_idx"][pre_end_idx]
if (
pre_end_idx > 0
and request.prompt_token_ids[pre_end_idx]
in [
inputs["image_patch_id"],
inputs["video_patch_id"],
inputs["audio_patch_id"],
]
and request.prompt_token_ids[pre_end_idx] != request.prompt_token_ids[pre_end_idx - 1]
):
# It just hit the starting position of the image / video / audio
start_patch_idx -= 1
start_patch_map = inputs["patch_map"][start_patch_idx]
request.image_start = start_patch_map["image_num"]
request.video_start = start_patch_map["video_num"]
request.audio_start = _compute_audio_prefix_count(pre_end_idx, start_patch_idx)
# end
if new_end_idx >= prompt_token_ids_len:
end_patch_idx = inputs["patch_idx"][-1]
else:
end_patch_idx = inputs["patch_idx"][new_end_idx]
if request.prompt_token_ids[new_end_idx] in [
inputs["image_end_id"],
inputs["video_end_id"],
inputs["audio_end_id"],
]:
end_patch_idx -= 1
end_patch_map = inputs["patch_map"][end_patch_idx]
end_modal_id = end_patch_map["modal_id"]
if end_modal_id == IDS_TYPE_FLAG["image"]:
new_end_idx = end_patch_map["end_idx"]  # end position of the current modality
if end_modal_id == IDS_TYPE_FLAG["video"] and "can_split_idx_list" in inputs:
can_split_idx_list = inputs["can_split_idx_list"]
for i in range(len(can_split_idx_list)):
if can_split_idx_list[i] >= new_end_idx:
new_end_idx = can_split_idx_list[i]
break
num_new_tokens = new_end_idx - pre_end_idx
request.image_end = end_patch_map["image_num"]
request.video_end = end_patch_map["video_num"]
request.audio_end = _compute_audio_prefix_count(new_end_idx, end_patch_idx)
elif (
inputs is not None
and inputs.get("images", None) is not None
and inputs.get("image_patch_id", None) is not None
and inputs.get("grid_thw", None) is not None
):
input_ids_lst = request.prompt_token_ids + request.output_token_ids
input_ids = paddle.to_tensor(input_ids_lst, dtype="int64")
image_patch_id = inputs["image_patch_id"]
if request.multimodal_img_boundaries is None:
grid_thw = []
for idx, one in enumerate(inputs["grid_thw"]):
t, h, w = one[0], one[1], one[2]
if t == 1:
grid_thw.append(one)
else:
grid_thw.extend([[2, h, w]] * (t // 2))
if current_platform.is_xpu():
from fastdeploy.model_executor.ops.xpu import get_img_boundaries
elif current_platform.is_iluvatar():
from fastdeploy.model_executor.ops.iluvatar import (
get_img_boundaries,
)
else:
from fastdeploy.model_executor.ops.gpu import get_img_boundaries
mm_num_token = inputs["mm_num_token_func"](grid_thw=grid_thw)
mm_num_token = paddle.to_tensor(mm_num_token, dtype="int64")
request.multimodal_img_boundaries = get_img_boundaries(
task_input_ids=input_ids, mm_num_token=mm_num_token, image_patch_id=image_patch_id
).numpy()
grid_thw = np.array(grid_thw).reshape([-1, 3])
inputs["grid_thw"] = grid_thw
grid_thw = inputs["grid_thw"]
img_boundaries_idx = request.multimodal_img_boundaries[0]
img_num_per_boundary = request.multimodal_img_boundaries[1]
ori_prompt_len = img_boundaries_idx[-1].item()
pre_end_idx = request.num_computed_tokens
new_end_idx = pre_end_idx + num_new_tokens
if new_end_idx < ori_prompt_len and input_ids[new_end_idx - 1] == image_patch_id:
boundary_idx = np.searchsorted(img_boundaries_idx, new_end_idx, side="left").item()
if boundary_idx == len(img_boundaries_idx):
new_end_idx = ori_prompt_len
else:
new_end_idx = img_boundaries_idx[boundary_idx].item()
elif new_end_idx >= ori_prompt_len and paddle.sum(input_ids[pre_end_idx:new_end_idx] == image_patch_id):
new_end_idx = ori_prompt_len
num_new_tokens = new_end_idx - pre_end_idx
image_mask = input_ids[pre_end_idx:new_end_idx] == image_patch_id
request.with_image = image_mask.any()
if request.with_image:
pre_boundary_idx = np.searchsorted(img_boundaries_idx, pre_end_idx, side="left").item()
if pre_boundary_idx == len(img_boundaries_idx):
request.num_image_start = img_num_per_boundary[-1]
else:
pre_boundary_idx = (
pre_boundary_idx
if pre_end_idx == img_boundaries_idx[pre_boundary_idx]
else pre_boundary_idx - 1
)
request.num_image_start = img_num_per_boundary[pre_boundary_idx]
new_boundary_idx = np.searchsorted(img_boundaries_idx, new_end_idx, side="left").item()
if new_boundary_idx == len(img_boundaries_idx):
request.num_image_end = img_num_per_boundary[-1]
else:
new_boundary_idx = (
new_boundary_idx
if new_end_idx == img_boundaries_idx[new_boundary_idx]
else new_boundary_idx - 1
)
request.num_image_end = img_num_per_boundary[new_boundary_idx]
request.image_type_ids_start = np.sum(grid_thw[: request.num_image_start, 0])
request.image_type_ids_end = np.sum(grid_thw[: request.num_image_end, 0])
request.image_start = np.sum(np.prod(grid_thw[: request.num_image_start], axis=1))
request.image_end = np.sum(np.prod(grid_thw[: request.num_image_end], axis=1))
if self.encoder_cache:
cur_mm_hashes = inputs["mm_hashes"][request.num_image_start : request.num_image_end]
cur_mm_positions = inputs["mm_positions"][request.num_image_start : request.num_image_end]
request.evict_mm_hashes = self.encoder_cache.apply_cache(cur_mm_hashes, cur_mm_positions)
# Compatible with scenarios without images and videos.
return num_new_tokens
def exist_mm_prefill(self, batch_request):
for request in batch_request:
if request.task_type == RequestType.PREFILL and self._is_mm_request(request):
return True
return False
def add_abort_req_ids(self, req_ids):
with self.lock:
if isinstance(req_ids, list):
self.waiting_abort_req_id_set.update(req_ids)
else:
self.waiting_abort_req_id_set.add(req_ids)
def cache_output_tokens(self, request):
if (
self.config.cache_config.enable_prefix_caching
and self.config.cache_config.enable_output_caching
and self.config.scheduler_config.splitwise_role != "decode"
):
with self.lock:
if request.num_computed_tokens >= request.need_prefill_tokens: # request is decoding
self.cache_manager.cache_output_blocks(request, self.config.cache_config.block_size)
def schedule(self):
"""
Try to pull a batch of requests from the waiting queue and schedule them.
"""
def get_enough_request(request, batch_request):
return (
ErnieArchitectures.is_ernie5_arch(self.config.model_config.architectures)
and self._is_mm_request(request)
and self.exist_mm_prefill(batch_request)
)
with self.lock:
preempted_reqs: list[Request] = []
error_reqs: list[tuple[str, str]] = []
tokens_per_seq = (
(self.config.speculative_config.num_speculative_tokens + 1)
if self.config.speculative_config is not None
else 1
)
token_budget = (
self.config.scheduler_config.max_num_batched_tokens
- self.config.scheduler_config.max_num_seqs * tokens_per_seq
)
# Temporary workaround to avoid a negative token_budget
token_budget = max(token_budget, min(self.config.scheduler_config.max_num_batched_tokens, 512))
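# Worked example (illustrative values): max_num_batched_tokens=8192,
# max_num_seqs=256 and tokens_per_seq=1 give 8192 - 256 = 7936. With
# max_num_batched_tokens=256 the difference is 0, so the floor
# min(256, 512) = 256 applies instead.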
need_abort_requests = [] # users trigger abortion
batch_request = BatchRequest()
# First, schedule the RUNNING requests.
req_index = 0
num_decoding_req_nums = 0
while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index]
need_block_num = self.need_block_num_signal.value[request.idx]
if need_block_num != 0:
self.need_block_num_map[request.request_id] = SignalConsumer(need_block_num, 1)
self.need_block_num_signal.value[request.idx] = 0
if request.num_computed_tokens >= request.need_prefill_tokens: # to be decoding
if (
self.config.scheduler_config.splitwise_role == "prefill"
): # do not need to schedule for decoding
req_index += 1
continue
if request.num_total_tokens > request.need_prefill_tokens: # has generated tokens
request.num_computed_tokens = request.num_total_tokens - 1
if request.request_id in self.waiting_abort_req_id_set:
self._trigger_abort(request.request_id, batch_request)
req_index += 1
need_abort_requests.append(request)
continue
if (
self.allocated_slots(request) - request.num_total_tokens
<= self.config.cache_config.prealloc_dec_block_slot_num_threshold
):
# Allocation for next decoding blocks
if self.cache_manager.can_allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num):
llm_logger.debug(
f"schedule decoding task: {request} request.num_total_tokens {request.num_total_tokens} request.num_computed_tokens {request.num_computed_tokens}"
)
request.block_tables.extend(
self._allocate_gpu_blocks(request, self.config.cache_config.enc_dec_block_num)
)
# Prepare decoding task
batch_request.add_request(self._prepare_decode_task(request))
else:
# Not enough blocks to allocate, trigger preemption
can_schedule = self._trigger_preempt(
request, self.config.cache_config.enc_dec_block_num, preempted_reqs, batch_request
)
if not can_schedule:
break
# Allocation for next decoding blocks
request.block_tables.extend(
self._allocate_gpu_blocks(request, self.config.cache_config.enc_dec_block_num)
)
# Prepare decoding task
batch_request.add_request(self._prepare_decode_task(request))
num_decoding_req_nums += 1
token_budget -= 1
if (
request.use_extend_tables
and request.request_id not in self.using_extend_tables_req_id
and self.need_block_num_map[request.request_id].watch() > 0
):
def _allocate_decode_and_extend():
allocate_block_num = self.need_block_num_map[request.request_id].consume()
# Prepare decoding task
request.block_tables.extend(self._allocate_gpu_blocks(request, allocate_block_num))
batch_request.add_request(self._prepare_decode_task(request))
# Prepare extend task
reuse_block_num = request.num_total_tokens // self.config.cache_config.block_size
llm_logger.info(
f"req {request.request_id} at batch id {request.idx} with reuse_block_num {reuse_block_num} is going to enable extend tables, "
f"need_block_num {allocate_block_num}"
)
self.using_extend_tables_req_id.add(request.request_id)
self.reuse_block_num_map[request.request_id] = reuse_block_num
request.extend_block_tables = request.block_tables[:reuse_block_num] # copy prompt cache
request.extend_block_tables.extend(self._allocate_gpu_blocks(request, allocate_block_num))
batch_request.add_request(
ScheduledExtendBlocksTask(
idx=request.idx,
request_id=request.request_id,
extend_block_tables=request.extend_block_tables,
cache_swap_metadata=request.pop_cache_swap_metadata(),
cache_evict_metadata=request.pop_cache_evict_metadata(),
)
)
llm_logger.debug(f"extend blocks is {request.extend_block_tables}")
if self.cache_manager.can_allocate_gpu_blocks(
2 * self.need_block_num_map[request.request_id].watch()
):
_allocate_decode_and_extend()
else:
llm_logger.info(
f"{request.idx} needs {2 * self.need_block_num_map[request.request_id].watch()} blocks to use extend tables but not enough are free, ready to preempt"
)
can_schedule = self._trigger_preempt(
request,
2 * self.need_block_num_map[request.request_id].watch(),
preempted_reqs,
batch_request,
)
if can_schedule:
_allocate_decode_and_extend()
else:
break
else: # need to prefill
llm_logger.debug(
f"scheduler prefill task in running queue: {request.request_id}, "
f"request.need_prefill_tokens {request.need_prefill_tokens}, "
f"request.num_computed_tokens {request.num_computed_tokens}"
)
if (
current_platform.is_intel_hpu()
and request.need_prefill_tokens - request.num_computed_tokens
>= self.config.cache_config.block_size
and token_budget < self.config.cache_config.block_size
):
req_index += 1
continue
if get_enough_request(request, batch_request):
req_index += 1
continue
num_new_tokens = self._get_num_new_tokens(request, token_budget)
if num_new_tokens == 0:
req_index += 1
continue
num_new_block = self.get_new_block_nums(request, num_new_tokens)
# Allocate blocks to prefill
if self.cache_manager.can_allocate_gpu_blocks(num_new_block):
request.block_tables.extend(self._allocate_gpu_blocks(request, num_new_block))
# Prepare prefill task
batch_request.add_request(self._prepare_prefill_task(request, num_new_tokens))
else: # Not enough blocks to allocate, trigger preemption
can_schedule = self._trigger_preempt(request, num_new_block, preempted_reqs, batch_request)
if not can_schedule:
break
request.block_tables.extend(self._allocate_gpu_blocks(request, num_new_block))
# Prepare prefill task
batch_request.add_request(self._prepare_prefill_task(request, num_new_tokens))
token_budget -= num_new_tokens
request.num_computed_tokens += num_new_tokens
if (
self.config.cache_config.enable_prefix_caching
and self.config.scheduler_config.splitwise_role != "decode"
and self.config.scheduler_config.splitwise_role != "prefill"
and not self.enable_cache_manager_v1
):
self.cache_manager.update_cache_blocks(
request, self.config.cache_config.block_size, request.num_computed_tokens
)
req_index += 1
# remove requests to be aborted from running list
for request in need_abort_requests:
self.running.remove(request)
# Second, schedule the WAITING requests.
if not preempted_reqs:
skip_requests: list[Request] = []
while self.waiting and token_budget > 0:
if (
len(self.running)
+ len(self.to_be_rescheduled_request_id_set)
+ len(self.to_be_aborted_req_id_set)
+ sum([req.status == RequestStatus.PREEMPTED for req in self.waiting])
>= self.max_num_seqs
):
break
request = self.waiting[0]
if get_enough_request(request, batch_request):
break
if request.status == RequestStatus.WAITING:
result = self.waiting_async_process(request)
if result is None:
error_reqs.append((request.request_id, request.error_message))
self.waiting.popleft()
continue
elif result is True:
# skip current request, try next request
skip_requests.append(request)
self.waiting.popleft()
continue
self._update_mm_hashes(request)
# Enable prefix caching
if self.config.cache_config.enable_prefix_caching:
if not self.enable_cache_manager_v1:
if (
self.cache_manager.num_cpu_blocks > 0
or self.config.cache_config.kvcache_storage_backend
):
if not self.cache_manager.can_allocate_gpu_blocks(
(request.need_prefill_tokens + self.config.cache_config.block_size - 1)
// self.config.cache_config.block_size
):  # avoid a deadlock: matching in the hierarchical cache may itself allocate GPU blocks
break
success = self.get_prefix_cached_blocks(request)
if not success:
self._free_blocks(request)
break
if (
current_platform.is_intel_hpu()
and request.need_prefill_tokens - request.num_computed_tokens
>= self.config.cache_config.block_size
and token_budget < self.config.cache_config.block_size
):
continue
# Allocate blocks for the tokens that do not hit the cache
if envs.FD_DISABLE_CHUNKED_PREFILL:
# Disable chunk prefill
if token_budget < request.need_prefill_tokens:
break
num_new_tokens = self._get_num_new_tokens(request, token_budget)
if num_new_tokens == 0:
if self.config.cache_config.enable_prefix_caching:
self._free_blocks(request)
skip_requests.append(request)
self.waiting.popleft()
continue
num_new_block = self.get_new_block_nums(request, num_new_tokens)
llm_logger.debug(
f"request.request_id {request.request_id} num_new_block {num_new_block}, request.need_prefill_tokens {request.need_prefill_tokens}, request.num_computed_tokens {request.num_computed_tokens}, token_budget {token_budget}"
)
can_schedule_block_num_threshold = self._get_can_schedule_prefill_threshold_block(
num_new_block
)
# Allocate blocks to prefill
if self.cache_manager.can_allocate_gpu_blocks(can_schedule_block_num_threshold):
if num_new_block > 0:
extra_gpu_block_ids = self._allocate_gpu_blocks(request, num_new_block)
request.block_tables.extend(extra_gpu_block_ids)
self.waiting.popleft()
self.running.append(request)
batch_request.add_request(self._prepare_prefill_task(request, num_new_tokens))
token_budget -= num_new_tokens
request.num_computed_tokens += num_new_tokens
if (
self.config.cache_config.enable_prefix_caching
and self.config.scheduler_config.splitwise_role != "decode"
and not self.enable_cache_manager_v1
):
self.cache_manager.update_cache_blocks(
request, self.config.cache_config.block_size, request.num_computed_tokens
)
request.status = RequestStatus.RUNNING
if self.config.scheduler_config.splitwise_role == "mixed":
allocated_position = self.get_available_position()
request.idx = allocated_position
self.tasks_list[allocated_position] = request
self.stop_flags[allocated_position] = False
self.req_dict[request.request_id] = allocated_position
llm_logger.debug(f"req_id:{request.request_id} allocate pos end")
else:
if self.config.cache_config.enable_prefix_caching:
self._free_blocks(request)
break
elif request.status == RequestStatus.PREEMPTED:
request.need_prefill_tokens = (
request.num_total_tokens
)  # The preemption has already been sent to the engine, so no more tokens are produced before rescheduling; num_total_tokens is stable and correct here
if (
self.config.cache_config.enable_prefix_caching
and self.config.scheduler_config.splitwise_role != "decode"
):
if not self.enable_cache_manager_v1:
if (
self.cache_manager.num_cpu_blocks > 0
or self.config.cache_config.kvcache_storage_backend
):
if not self.cache_manager.can_allocate_gpu_blocks(
(request.need_prefill_tokens + self.config.cache_config.block_size - 1)
// self.config.cache_config.block_size
):  # avoid a deadlock: matching in the hierarchical cache may itself allocate GPU blocks
break
success = self.get_prefix_cached_blocks(request)
if not success:
self._free_blocks(request)
break
# Allocate blocks for the tokens that do not hit the cache
if envs.FD_DISABLE_CHUNKED_PREFILL:
# Disable chunk prefill
if token_budget < request.need_prefill_tokens:
break
num_new_tokens = self._get_num_new_tokens(request, token_budget)
if num_new_tokens == 0:
if self.config.cache_config.enable_prefix_caching:
self._free_blocks(request)
skip_requests.append(request)
self.waiting.popleft()
continue
num_new_block = self.get_new_block_nums(request, num_new_tokens)
can_schedule_block_num_threshold = self._get_can_schedule_prefill_threshold_block(
num_new_block
)
# Allocate blocks to prefill
if self.cache_manager.can_allocate_gpu_blocks(can_schedule_block_num_threshold):
if num_new_block > 0:
extra_gpu_block_ids = self._allocate_gpu_blocks(request, num_new_block)
request.block_tables.extend(extra_gpu_block_ids)
self.waiting.popleft()
self.running.append(request)
batch_request.add_request(self._prepare_prefill_task(request, num_new_tokens))
token_budget -= num_new_tokens
request.num_computed_tokens += num_new_tokens
if (
self.config.cache_config.enable_prefix_caching
and self.config.scheduler_config.splitwise_role != "decode"
and not self.enable_cache_manager_v1
):
self.cache_manager.update_cache_blocks(
request, self.config.cache_config.block_size, request.num_computed_tokens
)
request.status = RequestStatus.RUNNING
else:
if self.config.cache_config.enable_prefix_caching:
self._free_blocks(request)
break
else:
llm_logger.info(f"Unknown request status type:{request.status}, req_id:{request.request_id}")
for req in skip_requests:
# move waiting request to end of the deque
self.waiting.append(req)
if len(batch_request) > 0:
llm_logger.debug(f"scheduled_reqs: {batch_request}")
self.current_reserve_output_block_num_float -= self.decay_output_block_num
self.current_reserve_output_block_num = max(
int(self.current_reserve_output_block_num_float),
self.min_reserve_output_block_num,
0,
)
if self.current_reserve_output_block_num == 0:
self.can_relax_prefill_strategy = True
self._log_console_scheduler_metrics(batch_request)
self.update_metrics()
# Issue pending backup tasks to batch_request
# This handles write_through_selective policy by attaching backup tasks
# to the batch request, which will be processed by the worker
if self.enable_cache_manager_v1 and len(batch_request) > 0:
evict_metadata = self.cache_manager.issue_pending_backup_to_batch_request()
if evict_metadata:
batch_request.append_evict_metadata([evict_metadata])
if self.enable_cache_manager_v1:
self.cache_manager.check_and_add_pending_backup()
return batch_request, error_reqs
def waiting_async_process(self, request: Request) -> "bool | None":
"""
Check if async preprocessing is complete for a request.
Args:
request: The request to check
Returns:
None: If an error occurred during preprocessing
True: If preprocessing is still in progress (request should be skipped)
False: If preprocessing is complete (request can be scheduled)
"""
for future in request.async_process_futures:
if future.done():
if request.get("error_message") is not None:
return None
else:
return True
request.async_process_futures = []
return False
def apply_async_preprocess(self, request: Request) -> None:
request.async_process_futures.append(self.async_preprocess_pool.submit(self._download_features, request))
def _has_features_info(self, task):
inputs = task.multimodal_inputs
if inputs is None or len(inputs) == 0:
return False
if (
(inputs.get("video_feature_urls") is not None and len(inputs["video_feature_urls"]) > 0)
or (inputs.get("image_feature_urls") is not None and len(inputs["image_feature_urls"]) > 0)
or (inputs.get("audio_feature_urls") is not None and len(inputs["audio_feature_urls"]) > 0)
):
return True
return False
def _download_features(self, request: Request) -> None:
"""
Download multimodal features from BOS.
Note:
1. This function adds the downloaded features to request.multimodal_inputs.
2. This function may update request.error_message and request.error_code.
Args:
request (Request): request object
"""
def download_bos_features(bos_client, features_urls):
result_list = []
for status, feature in download_from_bos(bos_client, features_urls, retry=1):
if status:
start_download_time = time.time()
if isinstance(feature, np.ndarray):
feature_info = f"type=np.ndarray, shape={feature.shape}, dtype={feature.dtype}"
elif isinstance(feature, list):
feature_info = f"type=list, len={len(feature)}"
else:
feature_info = f"type={type(feature).__name__}"
elapsed_time = round((time.time() - start_download_time) * 1000, 2)
llm_logger.info(
f"request {request.request_id} async download feature success: {feature_info}, "
f"elapsed time: {elapsed_time} ms"
)
result_list.append(feature)
else:
error_msg = f"request {request.request_id} download features error: {feature}"
llm_logger.error(error_msg)
return error_msg
return result_list
if not self._has_features_info(request):
return None
if self.bos_client is None:
try:
self.bos_client = init_bos_client()
except Exception as e:
error_msg = f"request {request.request_id} init bos client error: {str(e)}"
llm_logger.error(error_msg)
request.error_message = error_msg
request.error_code = 540
return None
inputs = request.multimodal_inputs
if inputs.get("video_feature_urls") is not None and len(inputs["video_feature_urls"]) > 0:
result = download_bos_features(self.bos_client, inputs["video_feature_urls"])
if isinstance(result, str): # download error
request.error_message = result
request.error_code = 530
return None
inputs["video_features"] = result
if inputs.get("image_feature_urls") is not None and len(inputs["image_feature_urls"]) > 0:
result = download_bos_features(self.bos_client, inputs["image_feature_urls"])
if isinstance(result, str): # download error
request.error_message = result
request.error_code = 530
return None
inputs["image_features"] = result
if inputs.get("audio_feature_urls") is not None and len(inputs["audio_feature_urls"]) > 0:
result = download_bos_features(self.bos_client, inputs["audio_feature_urls"])
if isinstance(result, str): # download error
request.error_message = result
request.error_code = 530
return None
inputs["audio_features"] = result
def get_reqs_in_aborting(self):
return self.waiting_abort_req_id_set | self.to_be_aborted_req_id_set
def get_available_position(self) -> int:
position = 0
while position < self.max_num_seqs:
if self.stop_flags[position] is True:
return position
position += 1
raise RuntimeError("No available position for new request")
def get_real_bsz(self) -> int:
for i in range(self.max_num_seqs - 1, -1, -1):
if not self.stop_flags[i]:
self.real_bsz = i + 1
break
return self.real_bsz
def _allocate_gpu_blocks(self, request: Request, num_blocks: int) -> List[int]:
llm_logger.debug(f"[allocate_gpu_blocks] request_id={request.request_id}, num_blocks={num_blocks}")
if self.enable_cache_manager_v1:
return self.cache_manager.allocate_gpu_blocks(request, num_blocks)
else:
return self.cache_manager.allocate_gpu_blocks(num_blocks, request.request_id)
def _request_match_blocks(self, request: Request, skip_storage: bool = True):
"""
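Match cached prefix blocks for a request, using cache manager v1 or the legacy cache manager.
Returns (common_block_ids, matched_token_num, metrics) and records [matched, uncached] block counts in request.cache_info.
The v1 storage-prefetch path (skip_storage=False) is not implemented yet.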
"""
if self.enable_cache_manager_v1:
self.cache_manager.match_prefix(request, skip_storage)
match_result = request.match_result
if skip_storage:
common_block_ids = match_result.device_block_ids
matched_token_num = match_result.total_matched_blocks * self.config.cache_config.block_size
metrics = {
"gpu_match_token_num": match_result.matched_device_nums * self.config.cache_config.block_size,
"cpu_match_token_num": match_result.matched_host_nums * self.config.cache_config.block_size,
"storage_match_token_num": match_result.matched_storage_nums * self.config.cache_config.block_size,
"match_gpu_block_ids": common_block_ids,
"gpu_recv_block_ids": [],
"match_storage_block_ids": [],
"cpu_cache_prepare_time": 0,
"storage_cache_prepare_time": 0,
}
no_cache_block_num = (
request.need_prefill_tokens - matched_token_num + self.config.cache_config.block_size - 1
) // self.config.cache_config.block_size
request.cache_info = [len(common_block_ids), no_cache_block_num]
return (common_block_ids, matched_token_num, metrics)
else:
# Prefetch cache from storage
pass
else:
(common_block_ids, matched_token_num, metrics) = self.cache_manager.request_match_blocks(
request, self.config.cache_config.block_size
)
matched_block_num = len(common_block_ids)
no_cache_block_num = self.cache_manager.get_required_block_num(
request.need_prefill_tokens - matched_token_num,
self.config.cache_config.block_size,
)
request.cache_info = [matched_block_num, no_cache_block_num]
return (common_block_ids, matched_token_num, metrics)
def get_prefix_cached_blocks(self, request: Request):
"""
Match and fetch cache for a task.
"""
try:
(common_block_ids, matched_token_num, metrics) = self._request_match_blocks(
request  # skip_storage uses its default value True
)
request.block_tables = common_block_ids
request.num_cached_tokens = matched_token_num
if self.config.cache_config.disable_chunked_mm_input:
if matched_token_num == request.need_prefill_tokens:
matched_token_num = matched_token_num - self.config.cache_config.block_size
request.num_computed_tokens = self.revert_chunked_mm_input(
request.multimodal_inputs, matched_token_num
)
else:
if matched_token_num == request.need_prefill_tokens:
request.num_computed_tokens = matched_token_num - self.config.cache_config.block_size
else:
request.num_computed_tokens = matched_token_num
if request.num_cached_tokens != request.num_computed_tokens:
revert_tokens_num = request.num_cached_tokens - request.num_computed_tokens
llm_logger.info(
f"request {request.request_id} num_cached_tokens: {request.num_cached_tokens}, revert_tokens_num: {revert_tokens_num}"
)
revert_block_idx = len(common_block_ids) - revert_tokens_num // self.config.cache_config.block_size - 1
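# Worked example (illustrative values): with 10 matched blocks, block_size=64
# and revert_tokens_num=128, revert_block_idx is 10 - 2 - 1 = 7, so the loop
# below visits blocks 9 and 8 and subtracts one block of tokens from the
# metric of the tier each block was matched in.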
for block_idx in range(len(common_block_ids) - 1, revert_block_idx, -1):
if common_block_ids[block_idx] in metrics["match_gpu_block_ids"]:
metrics["gpu_match_token_num"] -= self.config.cache_config.block_size
elif common_block_ids[block_idx] in metrics["gpu_recv_block_ids"]:
metrics["cpu_match_token_num"] -= self.config.cache_config.block_size
elif common_block_ids[block_idx] in metrics["match_storage_block_ids"]:
metrics["storage_match_token_num"] -= self.config.cache_config.block_size
if not request.has_been_preempted_before:
# NOTE: Do not log or report metrics for cache hit rate when request is being rescheduled
request.metrics.gpu_cache_token_num = metrics["gpu_match_token_num"]
request.metrics.cpu_cache_token_num = metrics["cpu_match_token_num"]
request.metrics.storage_cache_token_num = metrics["storage_match_token_num"]
request.metrics.cpu_cache_prepare_time = metrics["cpu_cache_prepare_time"]
request.metrics.storage_cache_prepare_time = metrics["storage_cache_prepare_time"]
main_process_metrics.prefix_cache_token_num.inc(request.num_computed_tokens)
main_process_metrics.prefix_gpu_cache_token_num.inc(request.metrics.gpu_cache_token_num)
main_process_metrics.prefix_cpu_cache_token_num.inc(request.metrics.cpu_cache_token_num)
return True
except Exception as e:
llm_logger.error(f"prefix match blocks error: {e}, {str(traceback.format_exc())} waiting reschedule...")
return False
def add_request(self, request: Request) -> None:
with self.lock:
self.apply_async_preprocess(request)
llm_logger.debug(f"self.waiting append request:{request.request_id},req.type:{request.status}")
self.waiting.append(request)
self.requests[request.request_id] = request
def pre_recycle_resource(self, request_id: str):
"""
Recycle resources in P or D before the request finishes, due to an unexpected error.
"""
with self.lock:
if request_id not in self.requests:
return
req = self.requests[request_id]
self.tasks_list[req.idx] = None
self.stop_flags[req.idx] = True
self._free_blocks(req)
del self.requests[request_id]
if request_id in self.req_dict:
del self.req_dict[request_id]
def add_request_in_p(self, requests: list[Request]):
with self.lock:
for request in requests:
self.running.append(request)
def preallocate_resource_in_p(self, request: Request):
"""
In P/D disaggregated deployment, preallocate resources for P.
If resources can be allocated, allocate them and return True.
Otherwise, return False.
"""
assert self.config.scheduler_config.splitwise_role == "prefill", "Only P instance can call this method"
with self.lock:
if self.available_batch() == 0:
return False
request.need_prefill_tokens = len(request.prompt_token_ids)
need_prealloc_prefill_blocks = (
request.need_prefill_tokens + self.config.cache_config.block_size - 1
) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num  # account for MTP by adding enc_dec_block_num
if self.config.cache_config.enable_prefix_caching:
# Enable prefix caching
if self.cache_manager.num_cpu_blocks > 0 or self.config.cache_config.kvcache_storage_backend:
if not self.cache_manager.can_allocate_gpu_blocks(
(request.need_prefill_tokens + self.config.cache_config.block_size - 1)
// self.config.cache_config.block_size
):  # check up front so that block allocation during hierarchical-cache matching cannot cause a deadlock
return False
success = self.get_prefix_cached_blocks(request)
if not success:
self._free_blocks(request)
return False
need_extra_prefill_blocks = need_prealloc_prefill_blocks - request.cache_info[0]
if self.cache_manager.can_allocate_gpu_blocks(need_extra_prefill_blocks):
extra_gpu_block_ids = self._allocate_gpu_blocks(request, need_extra_prefill_blocks)
request.block_tables.extend(extra_gpu_block_ids)
allocated_position = self.get_available_position()
request.idx = allocated_position
self.tasks_list[request.idx] = request
self.stop_flags[request.idx] = False
self.requests[request.request_id] = request
self.req_dict[request.request_id] = allocated_position
self.cache_manager.update_cache_blocks(
request, self.config.cache_config.block_size, request.need_prefill_tokens
)
return True
else:
self._free_blocks(request)
return False
else:
if self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks):
request.block_tables.extend(self._allocate_gpu_blocks(request, need_prealloc_prefill_blocks))
request.num_computed_tokens = 0
allocated_position = self.get_available_position()
request.idx = allocated_position
self.tasks_list[request.idx] = request
self.stop_flags[request.idx] = False
self.requests[request.request_id] = request
self.req_dict[request.request_id] = allocated_position
return True
return False
def preallocate_resource_in_d(self, request: Request):
"""
In P/D disaggregated deployment, D should preallocate resources for P.
If resources can be allocated, allocate them and return True.
Otherwise, return False.
"""
assert self.config.scheduler_config.splitwise_role == "decode", "Only D instance can call this method"
request.need_prefill_tokens = len(request.prompt_token_ids)
need_prealloc_prefill_blocks = (
request.need_prefill_tokens + self.config.cache_config.block_size - 1
) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num
with self.lock:
if len(self.waiting) > 0:
return False
if self.available_batch() == 0:
return False
total_need_blocks = self._get_can_schedule_prefill_threshold_block(need_prealloc_prefill_blocks)
if not self.cache_manager.can_allocate_gpu_blocks(total_need_blocks):
return False
request.block_tables = self._allocate_gpu_blocks(request, need_prealloc_prefill_blocks)
request.num_computed_tokens = request.need_prefill_tokens
request.disaggregate_info["block_tables"] = request.block_tables
allocated_position = self.get_available_position()
request.idx = allocated_position
self.tasks_list[request.idx] = request
self.stop_flags[request.idx] = False
self.requests[request.request_id] = request
self.req_dict[request.request_id] = allocated_position
return True
def has_resource_for_prefilled_req(self, request_id: str):
"""
Check whether there are enough slots and GPU resources for the prefilled request,
whose cache is saved in a CPU buffer.
"""
with self.lock:
assert self.config.scheduler_config.splitwise_role == "decode", "Only D instance can call this method"
assert request_id in self.preallocated_reqs, "request_id must be in preallocate"
need_blocks_num = len(self.preallocated_reqs[request_id].disaggregate_info["block_tables"])
return self.available_batch() > 0 and self.cache_manager.can_allocate_gpu_blocks(need_blocks_num)
def add_prefilled_request(self, request_output: RequestOutput):
"""
In P/D disaggregated deployment, D should continue decoding after receiving the first token and cache from P.
NOTE: GPU resources should be checked in advance to ensure they are sufficient for the prefilled request.
"""
with self.lock:
assert self.config.scheduler_config.splitwise_role == "decode", "Only D instance can call this method"
if request_output.request_id not in self.requests:
llm_logger.error(f"Request {request_output.request_id} not found in requests")
return
request = self.requests[request_output.request_id]
# update request and insert to running
request.output_token_ids.append(request_output.outputs.token_ids[0])
request.num_cached_tokens = request_output.num_cached_tokens
if (
self.config.speculative_config.method == SpecMethod.MTP
and self.config.scheduler_config.splitwise_role == "decode"
):
request.draft_token_ids = copy.deepcopy(request_output.outputs.draft_token_ids)
request.need_prefill_tokens = len(request.prompt_token_ids) + 1
request_output.metrics.decode_recv_req_time = request.metrics.decode_recv_req_time
request_output.metrics.decode_preallocate_req_time = request.metrics.decode_preallocate_req_time
request.metrics = copy.deepcopy(request_output.metrics)
request.metrics.decode_inference_start_time = time.time()
request.metrics.update_decoder_start_time()
self.running.append(request)
def _free_blocks(self, request: Request):
if self.enable_cache_manager_v1:
self.cache_manager.request_finish(request)
elif (
self.config.cache_config.enable_prefix_caching and self.config.scheduler_config.splitwise_role != "decode"
):
self.cache_manager.release_block_ids(request)
self.cache_manager.recycle_gpu_blocks(
request.block_tables[request.num_cached_blocks :], request.request_id
)
else:
self.cache_manager.recycle_gpu_blocks(request.block_tables, request.request_id)
request.block_tables = []
if request.request_id in self.using_extend_tables_req_id:
reuse_block_num = self.reuse_block_num_map[request.request_id]
self.using_extend_tables_req_id.remove(request.request_id)
self.cache_manager.recycle_gpu_blocks(request.extend_block_tables[reuse_block_num:], request.request_id)
llm_logger.info(
f"req {request.request_id} recycle extend blocks {request.extend_block_tables[reuse_block_num:]}"
)
request.extend_block_tables = []
del self.reuse_block_num_map[request.request_id]
del self.need_block_num_map[request.request_id]
def finish_requests_async(self, request_ids: Union[str, Iterable[str]]):
return self.finish_execution_pool.submit(self.finish_requests, request_ids)
def finish_requests(self, request_ids: Union[str, Iterable[str]]):
llm_logger.info(f"recycle resources for requests: {request_ids}")
self.update_metrics(verbose=True)
try:
if isinstance(request_ids, str):
request_ids = (request_ids,)
else:
request_ids = set(request_ids)
need_postprocess_reqs = []
with self.lock:
for req_id in request_ids:
request = self.requests.get(req_id)
if request is None:
llm_logger.error(f"invalid request id: {req_id} self.requests: {self.requests}")
continue
if request in self.waiting:
llm_logger.error(f"request {request.request_id} scheduled into waiting list, after finished")
continue
if request in self.running:
llm_logger.info(f"finish running request: {req_id}")
self.running.remove(request)
request.status = RequestStatus.FINISHED
need_postprocess_reqs.append(request)
if request.request_id in self.to_be_rescheduled_request_id_set:
# finished after being preempted; blocks have already been recycled.
llm_logger.info(f"finish preempted request: {req_id}")
self.to_be_rescheduled_request_id_set.remove(request.request_id)
self.tasks_list[request.idx] = None
self.stop_flags[request.idx] = True
del self.requests[req_id]
if req_id in self.req_dict:
del self.req_dict[req_id]
# Do not block the main thread here
# Write cache to storage if kvcache_storage_backend is enabled
for req in need_postprocess_reqs:
if self.config.scheduler_config.splitwise_role == "decode":
# D instance uses simplified write method (does not rely on Radix Tree)
self.cache_manager.write_cache_to_storage_decode(req)
else:
# P instance / Mixed instance uses standard write method (relies on Radix Tree)
self.cache_manager.write_cache_to_storage(req)
with self.lock:
for req in need_postprocess_reqs:
try:
self._free_blocks(req)
llm_logger.debug(f"req_id:{req.request_id} free pos:{req.idx}")
except Exception as e:
llm_logger.warning(f"release block failed {req.request_id}: {e}")
except Exception as e:
llm_logger.error(f"finish_request err: {e}, {str(traceback.format_exc())}")
finally:
self.update_metrics(verbose=True)
def clear_data(self):
self.waiting: deque[Request] = deque()
self.to_be_rescheduled_request_id_set = set()
self.update_metrics(verbose=True)
def update_metrics(self, verbose=False):
# Update metrics
num_tasks = sum(1 for task in self.tasks_list if task)
blocks_used_by_tasks = set()
for task in self.tasks_list:
if task is not None:
blocks_used_by_tasks.update(task.block_tables)
main_process_metrics.available_gpu_block_num.set(self.total_block_number() - len(blocks_used_by_tasks))
main_process_metrics.batch_size.set(self.max_num_seqs - self.available_batch())
main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
main_process_metrics.num_requests_running.set(len(self.running))
main_process_metrics.num_requests_waiting.set(num_tasks - len(self.running))
if verbose:
llm_logger.info(f"update metrics: running={len(self.running)}, waiting={num_tasks - len(self.running)}")
def log_status(self):
llm_logger.info(
f"ResourceManagerV1( "
f"waiting={len(self.waiting)}, "
f"running={len(self.running)}, "
f"preempted={len(self.to_be_rescheduled_request_id_set)}, "
f"tasks_list={self.tasks_list}, "
f"stop_flags={self.stop_flags}, "
f"req_dict={self.req_dict}, "
f"requests={self.requests}, "
f")"
)
def _log_console_scheduler_metrics(self, batch_request: BatchRequest) -> None:
if not (
hasattr(self, "scheduler_metrics_logger")
and self.scheduler_metrics_logger is not None
and envs.FD_CONSOLE_SCHEDULER_METRICS
):
return
total_blocks = self.total_block_number()
free_blocks = self.available_block_num()
used_blocks = max(total_blocks - free_blocks, 0)
tokens_used = used_blocks * self.config.cache_config.block_size
token_usage = used_blocks / total_blocks if total_blocks > 0 else 0.0
running_cnt = len(self.running)
scheduler_queue_cnt = max(int(getattr(self, "scheduler_unhandled_request_num", 0) or 0), 0)
queue_cnt = len(self.waiting) + scheduler_queue_cnt
prefill_reqs = [r for r in batch_request if isinstance(r, Request) and r.task_type == RequestType.PREFILL]
has_decode = any(getattr(r, "task_type", None) == RequestType.DECODE for r in batch_request)
self.scheduler_metrics_logger.log_prefill_batch(
prefill_reqs=prefill_reqs,
running_cnt=running_cnt,
queue_cnt=queue_cnt,
tokens_used=tokens_used,
token_usage=token_usage,
)
if has_decode:
has_prefill = len(prefill_reqs) > 0
graph_opt_cfg = self.config.graph_opt_config
use_cudagraph_cfg = bool(getattr(graph_opt_cfg, "use_cudagraph", False))
graph_opt_level = int(getattr(graph_opt_cfg, "graph_opt_level", 0) or 0)
full_cuda_graph = bool(getattr(graph_opt_cfg, "full_cuda_graph", True))
cudagraph_only_prefill = bool(getattr(graph_opt_cfg, "cudagraph_only_prefill", False))
use_decode_cudagraph = (
has_decode
and use_cudagraph_cfg
and (
# Reference PR https://github.com/PaddlePaddle/FastDeploy/pull/6196
# Static split graph mode: Prefill+Mixed and Decode can use CUDAGraph.
(graph_opt_level > 0 and not full_cuda_graph)
# Dynamic / static-full modes: decode-only can use CUDAGraph.
or (not has_prefill and not cudagraph_only_prefill)
)
)
self.scheduler_metrics_logger.log_decode_batch(
running_cnt=running_cnt,
queue_cnt=queue_cnt,
tokens_used=tokens_used,
token_usage=token_usage,
use_cudagraph=use_decode_cudagraph,
)