[Feature][KVCache] Implement Cache Manager V1 with GPU + CPU Cache Support (1/n) (#7097)

* [Feature][KVCache] Support cache manager v1 architecture

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Update cache manager and related modules

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* chore: update cache_manager and related modules

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: add node to evictable set in complete_swap_to_device

When a node transitions from SWAP_TO_DEVICE to DEVICE via
complete_swap_to_device, it was not being added to the
_evictable_device set. This caused nodes with ref_count=0 to
become "orphaned" - not appearing in any evictable set despite
having cache_status=DEVICE.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* feat: update cache manager v1 and related modules

- Add new cache_manager.py with cache management functionality
- Add radix_tree.py for prefix caching
- Update block_pool.py and metadata.py
- Update request.py and resource_manager_v1.py for scheduling
- Update gpu_model_runner.py for GPU model execution

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* feat(cache): add cache controller v1 implementation

- Add CacheController class for cache management
- Update config.py with cache related configurations
- Refactor gpu_model_runner.py for improved cache handling

* feat(cache_manager): update cache manager v1

* fix(cache_manager): fix swap_cache H2D/D2H block_ids selection and clean up ForwardMeta

## Motivation

Fix the bug in swap_cache_optimized.cu where the wrong src/dst block_ids were
used in the H2D direction, and remove the deprecated cache_controller field
from ForwardMeta.

## Modifications

- fix: in swap_cache_optimized.cu, select src/dst block_ids according to the D2H
  template parameter, fixing the swapped src/dst bug in the H2D direction (fixes
  both SwapCachePerLayerImpl and SwapCacheAllLayersBatchImpl)
- refactor: cache_manager/v1/__init__.py imports LayerSwapTimeoutError from
  cache_utils (its actual home) instead of cache_controller
- refactor: remove the deprecated cache_controller field from ForwardMeta
- refactor: remove the corresponding cache_controller assignment in gpu_model_runner.py
- test: add unit tests in tests/cache_manager/v1/test_swap_cache_ops.py

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* feat(cache_manager): refactor cache manager v1 and optimize swap ops

## Motivation

Refactor and optimize cache manager v1, streamlining the code structure and improving maintainability.

## Modifications

- Refactor transfer_manager.py, substantially simplifying its logic
- Optimize the swap_cache_optimized.cu GPU kernel implementation
- Adjust cache_manager.py and cache_controller.py logic; fix the missing free_device_blocks method
- Update block_pool.py, cache_utils.py, metadata.py, radix_tree.py
- Trim the related call sites in gpu_model_runner.py, forward_meta.py, attention.py
- Update the corresponding unit tests (test_cache_controller, test_swap_cache_ops, test_transfer_manager)
- Adjust the related configuration options in config.py

* [KVCache][MTP] support MTP KV Cache initialization and multimodal hashing under cache_manager_v1

## Motivation

On the enable_cache_manager_v1 path, the MTP (speculative decode) KV Cache
should be managed centrally by the CacheController so it can reuse the
swap/transfer machinery. This also fixes block hashes not carrying multimodal
extra_keys in multimodal scenarios.

## Modifications

- `cache_controller.py`
  - Add `initialize_mtp_kv_cache`: initialize the MTP KV Cache through the
    CacheController and register it in cache_kvs_map so the transfer_manager
    automatically covers the MTP layers
  - `initialize_host_cache` now includes the extra MTP cache layers in
    num_layers, ensuring the host cache also reserves enough space for MTP
  - Rename `_free_gpu_cache` to `free_gpu_cache` (now externally callable)

- `cache_utils.py`
  - Add `get_block_hash_extra_keys`: extract the multimodal hash information
    within a single block, matching PrefixCacheManager's multimodal extra_keys
    logic
  - `get_request_block_hasher` now passes extra_keys to hash_block_tokens,
    fixing inaccurate prefix cache hit rates in multimodal scenarios

- `spec_decode/mtp.py`
  - `update_mtp_block_num` gains a `skip_cache_init` parameter to avoid
    re-initializing the MTP KV Cache on the v1 cache manager path

- `gpu_model_runner.py`
  - `initialize_kv_cache(v1)` path: after the main model cache is initialized,
    call `cache_controller.initialize_mtp_kv_cache` to create the MTP cache
  - `clear_cache` / `wakeup` / `reset` paths: respect the
    `enable_cache_manager_v1` flag and skip the redundant
    proposer.initialize_kv_cache call

## Usage or Command

```bash
# Launch an inference service with MTP + cache_manager_v1 (example)
bash run.sh
```

* fix(cache_manager): multi-GPU fix, mm hash boundary fix, and remove batch ops

1. Fix CuPy stream/event creation for multi-GPU: wrap all stream operations
   with cp.cuda.Device(device_id) context to ensure streams/events are bound
   to the correct device, preventing cross-device errors in multi-GPU setups.

2. Remove cudaSetDevice from SwapCacheAllLayers (now handled by the CuPy device
   context).

3. Remove the swap_cache_all_layers_batch op: simplify the implementation by
   removing the batch upload variant; all-layer transfers now use the standard
   swap_cache_all_layers with the CuPy device context.

4. Fix mm hash boundary comparison in get_block_hash_extra_keys: change
   strict less-than (<) to less-than-or-equal (<=) so that multimodal items
   ending exactly at block start are correctly excluded.

5. Extract config fields to KVCacheBase: model_config, cache_config,
   quant_config, parallel_config are now set in the base class __init__ to
   avoid duplication in CacheController and CacheManager subclasses.

6. Translate metadata.py docstrings from Chinese to English for broader
   contributor accessibility.

7. Add test_cache_utils.py: comprehensive unit tests for
   get_block_hash_extra_keys covering all boundary and overlap scenarios.

8. Expand test suite: test_request.py cache fields tests, test_radix_tree.py
   backup candidate tests, test_transfer_manager.py and test_cache_manager.py
   multi-GPU and concurrent operation tests.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix List import and move write_policy normalization to CacheManager

## Motivation

Fix two issues:
1. `List` was not imported in `fastdeploy/engine/request.py`, causing a pre-commit F821 error
2. The `write_policy` normalization (`write_through` → `write_through_selective`) does not belong in `FDConfig`; move it into `CacheManager.__init__` so it only affects Cache Manager V1 internals

## Modifications

- `fastdeploy/engine/request.py`: add `List` to the `typing` imports and drop the duplicated `CacheSwapMetadata` TYPE_CHECKING import, fixing F821/F811
- `fastdeploy/config.py`: remove the `write_policy` normalization logic
- `fastdeploy/cache_manager/v1/cache_manager.py`: move the normalization into `CacheManager.__init__`

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix pre-commit code style issues

## Motivation

Fix CI pre-commit code style check failures.

## Modifications

- `fastdeploy/engine/common_engine.py`: black formatting
- `fastdeploy/worker/worker_process.py`: black formatting + isort fixes
- `fastdeploy/cache_manager/v1/storage/__init__.py`: isort fixes
- `fastdeploy/worker/gpu_worker.py`: isort fixes

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [Feature][KVCache] update cache_manager_v1 modules

## Motivation

Update the Cache Manager V1 modules: complete the copyright headers and improve module structure and maintainability.

## Modifications

- `fastdeploy/cache_manager/v1/` modules: add copyright headers and tidy up the code structure
- `fastdeploy/config.py`: configuration updates
- `fastdeploy/engine/sched/resource_manager_v1.py`: scheduling-related updates

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [Feature][KVCache] add BatchRequest.from_tasks and refactor worker task parsing

## Motivation

Consolidate the duplicated task-parsing logic in worker_process into BatchRequest, reducing redundancy and improving maintainability.

## Modifications

- `fastdeploy/engine/request.py`: add a `BatchRequest.from_tasks()` classmethod that splits task_queue items into inference requests and control requests
- `fastdeploy/worker/worker_process.py`: replace the inline parsing logic with `BatchRequest.from_tasks()` and fix the duplicated control_reqs handling block

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [Feature][KVCache] add NUMA affinity for host cache and skip swap cache tests

## Motivation

Improve the NUMA affinity of host cache memory allocation to reduce cross-NUMA
access latency; also skip the swap cache ops tests (unsupported in the current
environment).

## Modifications

- `fastdeploy/cache_manager/v1/cache_controller.py`:
  - Add `_get_numa_node_for_gpu()`: resolve the GPU's NUMA node via nvidia-smi or sysfs
  - Add `_bind_to_closest_numa_node()`: bind the current thread to the NUMA node closest to the GPU
  - Call the NUMA binding in `initialize_host_cache()` to improve H2D transfer performance
- `tests/cache_manager/v1/test_swap_cache_ops.py`: skip all test classes (`TestSwapCacheAllLayersCorrectness`, `TestSwapCacheAllLayersPerformance`, `TestSwapCacheRandomBlockIndices`)
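A minimal sketch of the sysfs fallback path described above (the file layout follows the Linux sysfs convention; the helper names are illustrative, not the actual cache_controller implementation):

```python
import os

def get_numa_node_for_gpu(pci_bus_id: str) -> int:
    """Read a GPU's NUMA node from sysfs; -1 means unknown or single-node."""
    path = f"/sys/bus/pci/devices/{pci_bus_id.lower()}/numa_node"
    try:
        with open(path) as f:
            return int(f.read().strip())
    except (OSError, ValueError):
        return -1

def parse_cpulist(cpulist: str) -> set:
    """Expand a kernel cpulist string like '0-3,8,10-11' into a set of CPU ids."""
    cpus = set()
    for part in cpulist.split(","):
        if "-" in part:
            lo, hi = part.split("-")
            cpus.update(range(int(lo), int(hi) + 1))
        elif part:
            cpus.add(int(part))
    return cpus

def bind_to_numa_node(node: int) -> None:
    """Pin the current process's threads to the CPUs of one NUMA node."""
    if node < 0:
        return  # unknown node: leave affinity untouched
    with open(f"/sys/devices/system/node/node{node}/cpulist") as f:
        os.sched_setaffinity(0, parse_cpulist(f.read().strip()))
```

Binding before allocating the pinned host cache makes the allocation land on the NUMA node closest to the GPU, which is what speeds up later H2D copies.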

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix unittest failures for cache_manager_v1

Three unit tests failed due to interface changes or mocking issues:

- tests/distributed/chunked_moe.py: `setup_model_runner` uses `__new__` to skip `__init__`; add `enable_cache_manager_v1 = False` to fix the `AttributeError`
- tests/engine/test_resource_manager.py: `PrefixCacheManager` is imported locally, so the `patch` path now targets its definition site, `fastdeploy.cache_manager.prefix_cache_manager.PrefixCacheManager`
- tests/v1/test_resource_manager_v1.py: the fourth parameter of `_trigger_preempt` changed from `list` to `BatchRequest`; update the test arguments and assertions

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] remove debug logging code

## Modifications

- fastdeploy/engine/request.py: remove the debug logger and the debug logging in prompt_hashes
- fastdeploy/worker/worker_process.py: remove the debug import and print statements in __main__

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix cupy device id caching and pickle for _match_result

## Motivation

Fix two bugs:
1. `transfer_manager.py` called `cp.cuda.runtime.getDevice()` on every operation, which is fragile; cache the device ID as an instance variable at initialization so all later operations use a consistent device ID.
2. `__getstate__` in `request.py` did not skip `_match_result`, whose BlockNode tree contains parent/child reference cycles that trigger a `RecursionError` during pickling; also add `__setstate__` so the fields are restored to safe defaults after unpickling.

## Modifications

- `transfer_manager.py`: call `cp.cuda.runtime.getDevice()` once at initialization and cache it in `self._cupy_device_id`; subsequent `with cp.cuda.Device(...)` blocks and log messages use the cached value.
- `request.py`:
  - Add `_match_result` to the `_SKIP_KEYS` skip set in `__getstate__`, preventing pickle failures from the reference cycle.
  - Add `__setstate__` that resets `_block_hasher` and `_match_result` to `None` after unpickling.
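A minimal sketch of the `__getstate__`/`__setstate__` pattern described above (the surrounding `Request` fields are simplified; only the skip-and-restore logic mirrors the commit):

```python
import pickle

# Fields dropped before pickling: they may hold unpicklable callables or
# cyclic BlockNode trees.
_SKIP_KEYS = {"_block_hasher", "_match_result"}

class Request:
    def __init__(self):
        self.request_id = "req-0"
        self._block_hasher = None
        self._match_result = None

    def __getstate__(self):
        return {k: v for k, v in self.__dict__.items() if k not in _SKIP_KEYS}

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Restore skipped fields to safe defaults after unpickling.
        self._block_hasher = None
        self._match_result = None

req = Request()
req._block_hasher = lambda tokens: hash(tuple(tokens))  # lambdas can't pickle
node = {"parent": None}
node["child"] = {"parent": node}  # stand-in for the cyclic BlockNode tree
req._match_result = node

clone = pickle.loads(pickle.dumps(req))  # would fail without __getstate__
assert clone.request_id == "req-0"
assert clone._block_hasher is None and clone._match_result is None
```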


* fix(test): fix unit test errors for _trigger_preempt and wakeup with MTP

- Use BatchRequest instead of list in test_trigger_preempt_records_tasks
- Add missing enable_cache_manager_v1 attr in TestSleepWakeupBehavior._make_runner

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix gpu_free_block_list returning wrong block IDs

## Motivation

The compatibility property `gpu_free_block_list` mistakenly used `list(range(N))`,
passing the return value of `available_blocks()` to `range()` as an integer.
It therefore returned the synthetic list `[0, 1, ..., N-1]` instead of the
real free block IDs.

## Modifications

- `cache_manager/v1/cache_manager.py`: change `list(range(self._device_pool.available_blocks()))` to `list(self._device_pool.available_blocks())`

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix gpu_free_block_list returning an int and raising TypeError

## Motivation

The gpu_free_block_list property calls BlockPool.available_blocks(), which
returns an int (the number of free blocks); wrapping an int in list() raises
TypeError: 'int' object is not iterable.

## Modifications

Change list(self._device_pool.available_blocks()) to
list(self._device_pool._free_blocks), returning the free block index list directly.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [KVCache][CacheManager] adapt pause/sleep/free_cache to the V1 CacheManager

## Motivation

The V1 CacheManager introduces a new reset_cache() interface that the pause
and sleep operations must adopt; free_cache also needs an optional
clear_storage parameter.

## Modifications

- cache_controller.py: free_cache gains a clear_storage parameter (default False);
  _clear_storage() is only called when clear_storage=True, avoiding unnecessary
  storage wipes
- common_engine.py: in the pause and sleep operations, use
  cache_manager.reset_cache() instead of the old reset() and pause_transfer
  logic when ENABLE_V1_KVCACHE_MANAGER is set
- gpu_model_runner.py: on sleep, only clear the MTP cache when the V1 cache
  manager is not in use

## Usage or Command

```bash
# Launch the service (V1 CacheManager)
python -m fastdeploy.entrypoints.openai.api_server \
  --enable-v1-kvcache-manager \
  ...
```

* [BugFix][KVCache] fix missing enable_cache_manager_v1 in test mocks and remove unused select_blocks_for_backup

- Remove unused `select_blocks_for_backup` method from radix_tree.py
- Fix `match_prefix` default param `skip_storage=True` and log order in cache_manager.py
- Sync test_gpu_model_runner.py with upstream/develop (add TestInsertTasksV1SplitwiseSuffix)
- Add `enable_cache_manager_v1=False` to all mock runners to fix AttributeError in CI

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] simplify _free_blocks in ResourceManagerV1 for non-v1 path

Remove redundant prefix_caching branch in else path; always call
recycle_gpu_blocks with full block_tables for non-cache-manager-v1 case.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [KVCache][Optimization][BugFix] fix and optimize block_pool, cache_manager, transfer_manager, request

## Motivation

Fix several code-quality issues in cache_manager v1, improving performance and eliminating a latent type-inconsistency bug.

## Modifications

1. **block_pool.py**: `BlockPool.allocate` replaces the per-element pop loop with a slice plus a batched `set.update`, eliminating Python loop overhead: O(n) interpreted iterations become O(k) batched C-level operations
2. **cache_manager.py**: when prefix caching is disabled, `match_prefix` writes an empty `MatchResult()` before the early return, so callers never crash dereferencing `_match_result=None`
3. **transfer_manager.py**: `_build_device_layer_indices` also resets the four layer-index lists when `_cache_kvs_map` is empty, preventing stale tensors from being used by the swap kernels
4. **request.py**: `BatchRequest.append_swap_metadata` / `append_evict_metadata` now pass `src_type`/`dst_type` as `CacheLevel` enums instead of strings when constructing `CacheSwapMetadata`, matching the declared field types; add the `CacheLevel` import; fix the `match_result` property return annotation to `Optional[MatchResult]`
5. **resource_manager_v1.py**: demote the `_allocate_gpu_blocks` log from `INFO` to `DEBUG`, removing noise on the hot scheduling path
6. **tests/engine/test_request.py**: update the `src_type`/`dst_type` assertions to `CacheLevel` enum values and add the `CacheLevel` import

## Usage or Command

Run the unit tests:
```bash
source .venv/py310/bin/activate
cd baidu/FastDeploy
python -m pytest tests/cache_manager/v1/test_cache_manager.py -v
python -m pytest tests/cache_manager/v1/test_transfer_manager.py -v
python -m pytest tests/engine/test_request.py -v
```

* [BugFix][KVCache] Fix BlockPool.allocate returning all blocks when num_blocks=0

## Motivation

Calling `allocate(num_blocks=0)` hit Python's negative-index trap: since
`-0 == 0`, `self._free_blocks[-0:]` is equivalent to `self._free_blocks[0:]`,
so the call returned and cleared the entire free block list instead of
returning an empty list.

## Modifications

Add an early check for `num_blocks == 0` in `BlockPool.allocate` that returns
`[]` directly, sidestepping the negative-index trap.
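A minimal demonstration of the trap and the guard (illustrative, not the actual BlockPool code):

```python
free_blocks = [0, 1, 2, 3]

# -0 == 0, so "take the last zero elements" actually takes everything.
assert free_blocks[-0:] == [0, 1, 2, 3]

def allocate(pool, num_blocks):
    """Sketch of the guarded allocation: return [] early for zero requests."""
    if num_blocks == 0:
        return []
    allocated = pool[-num_blocks:]  # take from the tail of the free list
    del pool[-num_blocks:]
    return allocated

assert allocate(free_blocks, 0) == []      # pool left untouched
assert free_blocks == [0, 1, 2, 3]
assert allocate(free_blocks, 2) == [2, 3]
assert free_blocks == [0, 1]
```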

## Usage or Command

```bash
# Run the relevant unit tests to verify the fix
python -m pytest tests/cache_manager/v1/test_cache_manager.py -vv -s
```

* [KVCache][Test] add unit tests for cache_manager v1 modules

## Motivation

Complete the unit test coverage of the cache_manager/v1 modules so the core methods have a solid test safety net.

## Modifications

Added or extended the following test files; all 326 cases pass:

- tests/cache_manager/v1/test_block_pool.py (new)
  Covers BlockPool.get_metadata/set_metadata/resize, DeviceBlockPool/HostBlockPool
- tests/cache_manager/v1/test_metadata.py (new)
  Covers BlockNode, RadixTreeStats, MatchResult, CacheSwapMetadata, AsyncTaskHandler
- tests/cache_manager/v1/test_cache_utils.py (extended)
  Adds hash_block_tokens, get_request_block_hasher, LayerDoneCounter time tracking, and internal helper methods
- tests/cache_manager/v1/test_radix_tree.py (extended)
  Adds the TestCompleteSwapToDevice test class (6 cases)
- tests/cache_manager/v1/test_cache_manager.py (extended)
  Adds offload_to_host, load_from_host, the pending-backup series, prepare_prefetch_metadata
- tests/cache_manager/v1/test_transfer_manager.py (extended)
  Adds the _swap_single_layer validation path, sync_input/output_stream, record_input_stream_event

## Usage or Command

```bash
# Run all the new unit tests
source .venv/py310/bin/activate
python -m pytest tests/cache_manager/v1/test_block_pool.py \
  tests/cache_manager/v1/test_metadata.py \
  tests/cache_manager/v1/test_cache_utils.py \
  tests/cache_manager/v1/test_radix_tree.py \
  tests/cache_manager/v1/test_cache_manager.py \
  tests/cache_manager/v1/test_transfer_manager.py -v
# Expected result: 326 passed
```

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Author: kevin
Date: 2026-04-21 14:39:00 +08:00
Committed by: GitHub
Parent: e4a4573080
Commit: 7707be8384
54 changed files with 14422 additions and 231 deletions
@@ -23,6 +23,12 @@ from fastdeploy.utils import llm_logger as logger
try:
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import (
swap_cache_per_layer,  # per-layer KV cache swap-in op (synchronous)
)
from fastdeploy.model_executor.ops.gpu import (
swap_cache_per_layer_async,  # per-layer KV cache swap-in op (asynchronous, no forced sync)
)
from fastdeploy.model_executor.ops.gpu import (
cuda_host_alloc,
cuda_host_free,
@@ -43,6 +49,12 @@ try:
raise RuntimeError("CUDA no need of get_peer_mem_addr!")
elif current_platform.is_maca():
from fastdeploy.model_executor.ops.gpu import (
swap_cache_per_layer,  # per-layer KV cache swap-in op (synchronous)
)
from fastdeploy.model_executor.ops.gpu import (
swap_cache_per_layer_async,  # per-layer KV cache swap-in op (asynchronous, no forced sync)
)
from fastdeploy.model_executor.ops.gpu import ( # get_output_kv_signal,; ipc_sent_key_value_cache_by_remote_ptr_block_sync,
cuda_host_alloc,
cuda_host_free,
@@ -89,6 +101,12 @@ try:
def ipc_sent_key_value_cache_by_remote_ptr_block_sync(*args, **kwargs):
raise RuntimeError("XPU No ipc_sent_key_value_cache_by_remote_ptr UNIMPLENENTED")
def swap_cache_per_layer(*args, **kwargs): # 单层 KV cache 换入算子(同步)
raise RuntimeError("XPU swap_cache_per_layer UNIMPLENENTED")
def swap_cache_per_layer_async(*args, **kwargs): # 单层 KV cache 换入算子(异步)
raise RuntimeError("XPU swap_cache_per_layer_async UNIMPLENENTED")
else:
raise RuntimeError("Prefix cache ops only supported CUDA nor XPU platform ")
@@ -128,6 +146,8 @@ except Exception as e:
set_data_ipc = None
share_external_data_ = None
swap_cache_all_layers = None
swap_cache_per_layer = None  # per-layer KV cache swap-in op (synchronous)
swap_cache_per_layer_async = None  # per-layer KV cache swap-in op (asynchronous)
unset_data_ipc = None
set_device = None
memory_allocated = None
@@ -146,6 +166,8 @@ __all__ = [
"set_data_ipc",
"share_external_data_",
"swap_cache_all_layers",
"swap_cache_per_layer", # 单层 KV cache 换入算子(同步)
"swap_cache_per_layer_async", # 单层 KV cache 换入算子(异步,无强制 sync)
"unset_data_ipc", # XPU是 None
"set_device",
"memory_allocated",
@@ -0,0 +1,71 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from .base import KVCacheBase
from .cache_controller import CacheController
from .cache_manager import CacheManager
from .cache_utils import LayerDoneCounter, LayerSwapTimeoutError
from .metadata import (
AsyncTaskHandler,
BlockNode,
CacheBlockMetadata,
CacheStatus,
MatchResult,
PDTransferMetadata,
StorageConfig,
StorageMetadata,
StorageType,
TransferConfig,
TransferResult,
TransferStatus,
TransferTask,
TransferType,
)
from .storage import create_storage_connector, create_storage_scheduler
from .transfer import create_transfer_connector
from .transfer_manager import CacheTransferManager
__all__ = [
# Base classes
"KVCacheBase",
# Managers
"CacheManager",
"CacheController",
"CacheTransferManager",
# Exceptions
"LayerSwapTimeoutError",
# Utils
"LayerDoneCounter",
# Metadata
"CacheBlockMetadata",
"BlockNode",
"CacheStatus",
"TransferTask",
"TransferStatus",
"TransferConfig",
"TransferResult",
"AsyncTaskHandler",
"MatchResult",
"StorageMetadata",
"PDTransferMetadata",
"StorageConfig",
"StorageType",
"TransferType",
# Factory functions
"create_storage_scheduler",
"create_storage_connector",
"create_transfer_connector",
]
@@ -0,0 +1,80 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from fastdeploy.config import FDConfig
class KVCacheBase(ABC):
"""
Abstract base class for KV cache management.
This class defines the common interface for cache management operations.
Subclasses (CacheManager and CacheController) implement specific behaviors
based on their roles in the system.
CacheManager (Scheduler process):
- Manages DeviceBlockPool and HostBlockPool
- Handles block allocation and release
- Coordinates storage operations via StorageScheduler
CacheController (Worker process):
- Manages cache transfer operations
- Handles layer-by-layer transfer synchronization
- Coordinates cross-node transfer via TransferConnector
"""
def __init__(self, config: "FDConfig"):
"""
Initialize the KV cache base.
Args:
config: FDConfig instance containing all fastdeploy configuration
"""
self.config = config
# Extract configuration from FDConfig
self.model_config = config.model_config
self.cache_config = config.cache_config
self.quant_config = config.quant_config
self.parallel_config = config.parallel_config
self._initialized = False
@abstractmethod
def reset_cache(self) -> bool:
"""
Reset the cache state.
This method should be implemented by subclasses to reset their
specific cache state (e.g., clear block pools, reset transfer state).
Returns:
True if reset was successful, False otherwise
"""
pass
def is_initialized(self) -> bool:
"""
Check if the cache has been initialized.
Returns:
True if initialized, False otherwise
"""
return self._initialized
@@ -0,0 +1,251 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import threading
import traceback
from abc import ABC
from typing import Any, Dict, List, Optional
from fastdeploy.utils import get_logger
from .metadata import CacheBlockMetadata
logger = get_logger("block_pool", "cache_manager.log")
class BlockPool(ABC):
"""
Abstract base class for block pool management.
"""
def __init__(
self,
num_blocks: int,
block_size: int,
):
"""
Initialize the block pool.
Args:
num_blocks: Total number of blocks in the pool
block_size: Size of each block in bytes
"""
self.num_blocks = num_blocks
self.block_size = block_size
self._lock = threading.RLock()
# Track free and used blocks
self._free_blocks: List[int] = list(range(num_blocks))
self._used_blocks: set = set()
# Block metadata
self._metadata: Dict[int, CacheBlockMetadata] = {}
def allocate(self, num_blocks: int) -> Optional[List[int]]:
"""
Allocate blocks from the pool.
Args:
num_blocks: Number of blocks to allocate
Returns:
List of allocated block indices if successful, None if not enough blocks
"""
with self._lock:
if num_blocks == 0:
return []
if num_blocks > len(self._free_blocks):
logger.warning(
f"BlockPool.allocate failed: not enough blocks, "
f"requested={num_blocks}, available={len(self._free_blocks)}"
)
return None
allocated = self._free_blocks[-num_blocks:]
del self._free_blocks[-num_blocks:]
self._used_blocks.update(allocated)
return allocated
def release(self, block_indices: List[int]) -> None:
"""
Release blocks back to the pool.
Args:
block_indices: List of block indices to release
"""
with self._lock:
for idx in block_indices:
if idx in self._used_blocks:
self._used_blocks.remove(idx)
self._free_blocks.append(idx)
# Clear metadata
self._metadata.pop(idx, None)
else:
logger.error(
f"BlockPool.release: block_id={idx} NOT in used_blocks! "
f"request_blocks={block_indices}, "
f"is_in_free_blocks={idx in self._free_blocks}, "
f"is_valid_block_id={0 <= idx < self.num_blocks}"
)
logger.error(f"BlockPool.release callstack:\n{traceback.format_exc()}")
def get_metadata(self, block_idx: int) -> Optional[CacheBlockMetadata]:
"""
Get metadata for a block.
Args:
block_idx: Block index
Returns:
Block metadata or None if not found
"""
return self._metadata.get(block_idx)
def set_metadata(
self,
block_idx: int,
metadata: CacheBlockMetadata,
) -> None:
"""
Set metadata for a block.
Args:
block_idx: Block index
metadata: Block metadata to set
"""
self._metadata[block_idx] = metadata
def available_blocks(self) -> int:
"""Get number of available blocks."""
return len(self._free_blocks)
def used_blocks(self) -> int:
"""Get number of used blocks."""
return len(self._used_blocks)
def reset(self) -> None:
"""Reset the block pool."""
with self._lock:
self._free_blocks = list(range(self.num_blocks))
self._used_blocks.clear()
self._metadata.clear()
def resize(self, new_num_blocks: int) -> bool:
"""
Resize the block pool.
Supports both expansion and shrinking. Shrinking will fail if
there are more used blocks than the new size.
Args:
new_num_blocks: New total number of blocks
Returns:
True if resize was successful, False otherwise
"""
with self._lock:
current_used = len(self._used_blocks)
# Cannot shrink below currently used blocks
if new_num_blocks < current_used:
return False
old_num_blocks = self.num_blocks
self.num_blocks = new_num_blocks
if new_num_blocks > old_num_blocks:
# Expansion: add new free blocks
new_blocks = list(range(old_num_blocks, new_num_blocks))
self._free_blocks.extend(new_blocks)
elif new_num_blocks < old_num_blocks:
# Shrinking: remove free blocks beyond new size
blocks_to_keep = set(range(new_num_blocks))
self._free_blocks = [b for b in self._free_blocks if b in blocks_to_keep]
# Clean up metadata for removed blocks
for block_id in range(new_num_blocks, old_num_blocks):
self._metadata.pop(block_id, None)
return True
def get_stats(self) -> Dict[str, Any]:
"""Get pool statistics."""
return {
"num_blocks": self.num_blocks,
"block_size": self.block_size,
"available": len(self._free_blocks),
"used": len(self._used_blocks),
}
class DeviceBlockPool(BlockPool):
"""
GPU device memory block pool.
Manages KV cache blocks on GPU memory.
Does not track per-device blocks - device affinity is handled elsewhere.
"""
def __init__(
self,
num_blocks: int,
block_size: int,
):
"""
Initialize the device block pool.
Args:
num_blocks: Total number of blocks in the pool
block_size: Size of each block in bytes
"""
super().__init__(num_blocks, block_size)
def get_stats(self) -> Dict[str, Any]:
"""Get device pool statistics."""
stats = super().get_stats()
return stats
class HostBlockPool(BlockPool):
"""
CPU host memory block pool.
Manages KV cache blocks on CPU memory (pinned memory for fast GPU transfer).
"""
def __init__(
self,
num_blocks: int,
block_size: int,
use_pinned_memory: bool = True,
):
"""
Initialize the host block pool.
Args:
num_blocks: Total number of blocks
block_size: Size of each block in bytes
use_pinned_memory: Whether to use pinned (page-locked) memory
"""
super().__init__(num_blocks, block_size)
self.use_pinned_memory = use_pinned_memory
def get_stats(self) -> Dict[str, Any]:
"""Get host pool statistics."""
stats = super().get_stats()
stats["use_pinned_memory"] = self.use_pinned_memory
return stats
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,628 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import hashlib
import pickle
import threading
import time
from typing import Any, Callable, Dict, List, Optional, Sequence, Set
from paddleformers.utils.log import logger
class LayerDoneCounter:
"""
Independent synchronization primitive for tracking layer completion of a single transfer.
Used in compute-transfer overlap scenarios:
- Each LayerDoneCounter instance tracks layer completion for one transfer task.
- Uses CUDA Events for efficient waiting (no polling).
- Thread-safe.
Attributes:
_num_layers: Total number of layers.
_lock: Thread lock.
_completed_layers: Set of completed layer indices.
_callbacks: List of layer-completion callbacks.
_cuda_events: CUDA event per layer.
_layer_complete_times: Mapping of layer index to completion time.
_wait_count: Count of active waiters.
"""
def __init__(self, num_layers: int):
"""
Initialize the layer done counter.
Args:
num_layers: Total number of layers to track
"""
self._num_layers = num_layers
self._lock = threading.RLock()
self._completed_layers: Set[int] = set()
self._callbacks: List[Callable[[int], None]] = []
self._start_time: float = time.time()
# ============ CUDA Events for efficient waiting (no polling) ============
# Initialized to None; set by set_layer_event() after kernel submission to transfer stream.
# None means no event recorded yet for that layer (must fall back to polling).
self._cuda_events: List[Any] = [None] * num_layers
self._layer_complete_times: Dict[int, float] = {}
# ============ Reference count for active waiters (prevents premature cleanup) ============
self._wait_count: int = 0
def get_num_layers(self) -> int:
"""Get the total number of layers."""
return self._num_layers
# ============ Mark Methods (called by transfer thread) ============
def set_layer_event(self, layer_idx: int, cuda_event: Any) -> None:
"""
Set the CUDA event for a specific layer (used for cross-stream synchronization).
Called by transfer thread after submitting a layer's kernel to a non-default
stream (e.g., input_stream), so that wait_for_layer() can correctly synchronize
on the actual stream where the transfer runs.
Args:
layer_idx: Index of the layer
cuda_event: CUDA event recorded on the transfer stream after kernel submission
"""
with self._lock:
if 0 <= layer_idx < len(self._cuda_events):
self._cuda_events[layer_idx] = cuda_event
def mark_layer_done(self, layer_idx: int, cuda_event: Any = None) -> bool:
"""
Mark a layer as completed.
Args:
layer_idx: Index of the completed layer
cuda_event: Optional CUDA event to record completion
Returns:
True if this was the last layer, False otherwise
"""
with self._lock:
if layer_idx in self._completed_layers:
logger.warning(f"[mark_layer_done] layer {layer_idx} already marked done")
return len(self._completed_layers) >= self._num_layers
self._completed_layers.add(layer_idx)
self._layer_complete_times[layer_idx] = time.time()
# Record CUDA event if provided
if cuda_event is not None:
try:
cuda_event.record()
except Exception as e:
logger.warning(f"Failed to record CUDA event for layer {layer_idx}: {e}")
# Execute callbacks for this layer
for callback in self._callbacks:
try:
callback(layer_idx)
except Exception:
pass
return len(self._completed_layers) >= self._num_layers
def mark_all_done(self, cuda_event: Any = None) -> bool:
"""
Mark all layers as completed at once (used for D2H all-layers evict mode).
Args:
cuda_event: Optional CUDA event to record completion
Returns:
True (always returns True since all layers are marked done)
"""
with self._lock:
now = time.time()
self._completed_layers = set(range(self._num_layers))
self._layer_complete_times = {i: now for i in range(self._num_layers)}
# Record CUDA event if provided
if cuda_event is not None:
try:
cuda_event.record()
except Exception as e:
logger.warning(f"Failed to record CUDA event: {e}")
# Execute all callbacks (call with -1 to indicate all layers done)
for callback in self._callbacks:
try:
callback(-1)
except Exception:
pass
return True
# ============ Query Methods ============
def is_layer_done(self, layer_idx: int) -> bool:
"""
Check if a specific layer is completed.
Args:
layer_idx: Index of the layer to check
Returns:
True if the layer is completed, False otherwise
"""
with self._lock:
return layer_idx in self._completed_layers
def is_all_done(self) -> bool:
"""
Check if all layers are completed.
Returns:
True if all layers are completed, False otherwise
"""
with self._lock:
return len(self._completed_layers) >= self._num_layers
def get_completed_count(self) -> int:
"""
Get the number of completed layers.
Returns:
Number of completed layers
"""
with self._lock:
return len(self._completed_layers)
def get_pending_layers(self) -> List[int]:
"""
Get list of pending layer indices.
Returns:
List of pending layer indices
"""
with self._lock:
return [i for i in range(self._num_layers) if i not in self._completed_layers]
# ============ Wait Methods (called by forward thread) ============
def wait_for_layer(self, layer_idx: int, timeout: Optional[float] = None) -> bool:
"""
Wait for a specific layer to complete (CUDA Event synchronization).
Always synchronizes the CUDA event before returning to guarantee the GPU
transfer has actually completed, not just that the kernel was submitted.
The fast path that only checked is_layer_done() was unsafe because
mark_layer_done() is called immediately after kernel submission (async),
before the GPU has finished the transfer.
Args:
layer_idx: Index of the layer to wait for
timeout: Maximum wait time in seconds (default: 1s)
Returns:
True if layer completed
Raises:
LayerSwapTimeoutError: If timeout occurs before layer completes
"""
self._increment_wait_count()
try:
start_time = time.time()
timeout = timeout if timeout is not None else 1.0
while True:
# Always try CUDA event sync first: set_layer_event() is called before
# mark_layer_done(), so once is_layer_done() is True the event is present.
cuda_event = self._cuda_events[layer_idx] if layer_idx < len(self._cuda_events) else None
if cuda_event is not None:
try:
cuda_event.synchronize()
return True
except Exception as e:
logger.warning(f"CUDA event sync failed for layer {layer_idx}: {e}")
# Event sync failed; fall through to is_layer_done check
# No event yet (or sync failed): check software state as fallback
# (covers non-cupy scenarios where events are never set)
if self.is_layer_done(layer_idx):
return True
elapsed = time.time() - start_time
if elapsed >= timeout:
logger.error(f"[WaitForLayer] layer={layer_idx} TIMEOUT after {elapsed:.2f}s")
raise LayerSwapTimeoutError(f"Layer swap timeout: layer={layer_idx}, elapsed={elapsed:.2f}s")
time.sleep(0.001)
finally:
self._decrement_wait_count()
def wait_all(self, timeout: Optional[float] = None) -> bool:
"""
Wait for all layers to complete (used for D2H all-layers evict mode).
Always synchronizes _cuda_events[-1] (set by set_layer_event for the last layer)
before returning, for the same reason as wait_for_layer.
Args:
timeout: Maximum wait time in seconds (default: 300s)
Returns:
True if all layers completed
Raises:
LayerSwapTimeoutError: If timeout occurs
"""
self._increment_wait_count()
try:
start_time = time.time()
timeout = timeout if timeout is not None else 300.0
while True:
# _cuda_events[-1] is set by set_layer_event(num_layers-1, ...) before mark_all_done()
last_event = self._cuda_events[-1] if self._cuda_events else None
if last_event is not None:
try:
last_event.synchronize()
return True
except Exception as e:
logger.warning(f"CUDA event sync failed for wait_all: {e}")
# No event yet (or sync failed): check software state as fallback
if self.is_all_done():
return True
elapsed = time.time() - start_time
if elapsed >= timeout:
logger.error(f"[wait_all] TIMEOUT after {elapsed:.2f}s")
raise LayerSwapTimeoutError(f"wait_all timeout: elapsed={elapsed:.2f}s")
time.sleep(0.001)
finally:
self._decrement_wait_count()
# ============ Callback Methods ============
def register_callback(self, callback: Callable[[int], None]) -> None:
"""
Register a callback to be called when each layer completes.
Args:
callback: Function to call with layer index when completed
"""
with self._lock:
self._callbacks.append(callback)
# ============ Internal Helper Methods ============
def _increment_wait_count(self) -> None:
"""Increment the wait count."""
with self._lock:
self._wait_count += 1
def _decrement_wait_count(self) -> None:
"""Decrement the wait count."""
with self._lock:
if self._wait_count > 0:
self._wait_count -= 1
def _should_cleanup(self) -> bool:
"""Check if cleanup is safe (no active waiters and all done)."""
with self._lock:
return self._wait_count == 0 and self.is_all_done()
# ============ Time Tracking Methods ============
def get_layer_complete_time(self, layer_idx: int) -> Optional[float]:
"""
Get the completion time for a specific layer.
Args:
layer_idx: Index of the layer
Returns:
Completion time as Unix timestamp, or None if not completed
"""
with self._lock:
return self._layer_complete_times.get(layer_idx)
def get_layer_wait_time(self, layer_idx: int) -> Optional[float]:
"""
Get the time from transfer start to layer completion.
Args:
layer_idx: Index of the layer
Returns:
Time in seconds, or None if not completed
"""
with self._lock:
complete_time = self._layer_complete_times.get(layer_idx)
if complete_time is None:
return None
return complete_time - self._start_time
def get_all_layer_times(self) -> Dict[int, float]:
"""
Get completion times for all layers.
Returns:
Dictionary mapping layer_idx to completion time
"""
with self._lock:
return self._layer_complete_times.copy()
def get_elapsed_time(self) -> float:
"""
Get elapsed time since transfer start.
Returns:
Elapsed time in seconds
"""
return time.time() - self._start_time
def get_stats(self) -> Dict:
"""
Get current statistics.
Returns:
Dictionary with statistics
"""
with self._lock:
return {
"num_layers": self._num_layers,
"completed_layers": len(self._completed_layers),
"pending_layers": self._num_layers - len(self._completed_layers),
"wait_count": self._wait_count,
}
# ============ Cleanup Methods ============
def cleanup(self) -> None:
"""
Explicit cleanup method to release CUDA events.
Called when the transfer is complete and no more waiting is needed.
"""
with self._lock:
# Check if safe to cleanup
if self._wait_count > 0:
return
# Clear CUDA events
self._cuda_events.clear()
def __del__(self) -> None:
"""
Destructor to ensure CUDA events are released.
Note: This is a fallback. For explicit cleanup, call cleanup() method.
"""
try:
if self._cuda_events:
self._cuda_events.clear()
except Exception:
pass # Ignore errors during destruction
class LayerSwapTimeoutError(Exception):
"""Exception raised when layer swap operation times out."""
pass
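The wait loop above polls software state on a 1 ms interval and raises `LayerSwapTimeoutError` once the deadline passes, with CUDA events as the authoritative fast path when present. A minimal, standalone sketch of the same completion-tracking pattern (the enclosing tracker class name is outside this hunk, so `MiniLayerTracker` is a hypothetical stand-in; CUDA-event synchronization is omitted and only the software-state fallback is shown):

```python
import threading
import time

class MiniLayerTracker:
    """Condensed sketch of the layer-completion tracking pattern above."""

    def __init__(self, num_layers: int):
        self._num_layers = num_layers
        self._completed = set()
        self._lock = threading.Lock()

    def mark_layer_done(self, layer_idx: int) -> bool:
        # Called by the transfer thread after submitting the copy for a layer.
        with self._lock:
            self._completed.add(layer_idx)
            return len(self._completed) >= self._num_layers

    def is_layer_done(self, layer_idx: int) -> bool:
        with self._lock:
            return layer_idx in self._completed

    def wait_for_layer(self, layer_idx: int, timeout: float = 1.0) -> bool:
        # Called by the forward thread before attending over the layer's cache.
        start = time.time()
        while True:
            if self.is_layer_done(layer_idx):
                return True
            if time.time() - start >= timeout:
                raise TimeoutError(f"layer {layer_idx} swap timed out")
            time.sleep(0.001)  # same 1 ms poll interval as the original

tracker = MiniLayerTracker(num_layers=2)
t = threading.Thread(target=lambda: (time.sleep(0.01), tracker.mark_layer_done(0)))
t.start()
assert tracker.wait_for_layer(0) is True  # blocks until layer 0 is marked done
t.join()
```

Note that, as the docstring of `wait_for_layer` stresses, this software-state check alone is only safe when no asynchronous GPU copy is involved; the real implementation synchronizes the per-layer CUDA event first.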
# ============ Block Hash Computation ============
def hash_block_tokens(
token_ids: Sequence[int],
parent_block_hash: str | None = None,
extra_keys: Any = None,
) -> str:
"""
Compute hash value for a single block.
Reference: vLLM's hash_block_tokens implementation using chained hash:
hash = SHA256((parent_block_hash, token_ids_tuple, extra_keys))
Args:
token_ids: Token IDs of the current block.
parent_block_hash: Hash of the parent block (chained hash).
extra_keys: Additional keys (e.g., multimodal info, LoRA).
Returns:
Computed block hash as hex string.
"""
if parent_block_hash is None:
parent_block_hash = ""
value = (parent_block_hash, tuple(token_ids), extra_keys)
return hashlib.sha256(pickle.dumps(value)).hexdigest()
def get_block_hash_extra_keys(
request: Any,
start_idx: int,
end_idx: int,
mm_idx: int,
) -> tuple:
"""
Retrieve additional hash keys for a block based on multimodal information.
Mirrors the logic from prefix_cache_manager.PrefixCacheManager.get_block_hash_extra_keys.
For each block [start_idx, end_idx), scans the multimodal positions starting
from mm_idx and collects hashes of any multimodal items that overlap with the block.
Args:
request: Request object. Must expose a ``multimodal_inputs`` attribute which
is either None or a dict with keys:
- ``mm_positions``: list of objects with ``.offset`` and ``.length``
- ``mm_hashes``: list of hash strings, one per multimodal item
start_idx: Token index of the block start (inclusive).
end_idx: Token index of the block end (exclusive).
mm_idx: Index into mm_positions / mm_hashes to start scanning from
(avoids re-scanning already-processed items).
Returns:
(next_mm_idx, hash_keys):
next_mm_idx: updated mm_idx for the next block.
hash_keys : list of multimodal hash strings that fall within this block.
"""
hash_keys: List[str] = []
mm_inputs = getattr(request, "multimodal_inputs", None)
if (
mm_inputs is None
or "mm_positions" not in mm_inputs
or "mm_hashes" not in mm_inputs
or len(mm_inputs["mm_positions"]) == 0
):
return mm_idx, hash_keys
mm_positions = mm_inputs["mm_positions"]
mm_hashes = mm_inputs["mm_hashes"]
# Fast exit: last multimodal item ends before this block starts
if mm_positions[-1].offset + mm_positions[-1].length <= start_idx:
return mm_idx, hash_keys
for img_idx in range(mm_idx, len(mm_positions)):
image_offset = mm_positions[img_idx].offset
image_length = mm_positions[img_idx].length
        if image_offset + image_length <= start_idx:
            # Multimodal item ends before the block starts: skip it
            continue
        elif image_offset >= end_idx:
            # Multimodal item starts after the block ends: stop scanning
            return img_idx, hash_keys
        elif image_offset + image_length > end_idx:
            # Multimodal item spans beyond the block end: include its hash, then stop at this item
            hash_keys.append(mm_hashes[img_idx])
            return img_idx, hash_keys
        else:
            # Multimodal item is fully contained within the block
            hash_keys.append(mm_hashes[img_idx])
return len(mm_positions) - 1, hash_keys
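The overlap classification above has three outcomes per multimodal item: ends before the block (skip), starts after the block (stop), or overlaps it (collect the hash, stopping if it spills past the block end). A condensed, self-contained sketch of the same logic, using `(offset, length)` tuples in place of the position objects (unlike the original, this sketch advances the scan index past the end once all items are consumed):

```python
from typing import List, Tuple

def scan_mm_overlaps(positions, hashes, start_idx, end_idx, mm_idx) -> Tuple[int, List[str]]:
    """Collect hashes of multimodal items overlapping block [start_idx, end_idx)."""
    keys: List[str] = []
    for i in range(mm_idx, len(positions)):
        off, length = positions[i]
        if off + length <= start_idx:
            continue                # item ends before the block: skip
        if off >= end_idx:
            return i, keys          # item starts after the block: stop
        keys.append(hashes[i])      # item overlaps the block: collect hash
        if off + length > end_idx:
            return i, keys          # item spans past the block: stop at this item
    return len(positions), keys

# One image covering tokens [2, 10), scanned with block_size = 4:
positions, hashes = [(2, 8)], ["img0"]
assert scan_mm_overlaps(positions, hashes, 0, 4, 0) == (0, ["img0"])   # spans past block 0
assert scan_mm_overlaps(positions, hashes, 4, 8, 0) == (0, ["img0"])   # spans past block 1
```

Because the spanning item is re-reported for every block it overlaps, a long image contributes its hash to each block it touches, which is what makes the chained block hashes multimodal-aware.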
def get_request_block_hasher(
block_size: int,
) -> Callable[[Any], List[str]]:
"""
Factory function: returns a block hash calculator bound to block_size.
The returned function computes hashes for new complete blocks in a request.
Computation logic:
1. Get all token IDs (prompt + output)
2. Determine starting position based on existing block_hashes count
3. Compute hashes for new complete blocks (chained hash, with multimodal extra_keys)
Usage:
# Create hasher at service startup
block_hasher = get_request_block_hasher(block_size=64)
# Use in Request.prompt_hashes property
new_hashes = block_hasher(self)
self._prompt_hashes.extend(new_hashes)
Args:
block_size: Number of tokens per block.
Returns:
A function that takes a request and returns a list of newly computed
block hashes.
"""
def request_block_hasher(request: Any) -> List[str]:
"""
Compute hashes for uncomputed complete blocks in a request.
Args:
request: Request object with the following attributes:
- prompt_token_ids: Input token IDs.
- _prompt_hashes: List of existing block hashes (private attr).
- output_token_ids: Output token IDs (optional).
- multimodal_inputs (optional): Multimodal info dict with
``mm_positions`` and ``mm_hashes``.
Returns:
List of newly computed block hashes (only new complete blocks).
"""
# Get prompt token IDs
prompt_ids = request.prompt_token_ids
if hasattr(prompt_ids, "tolist"):
prompt_ids = prompt_ids.tolist()
if prompt_ids is None:
prompt_ids = []
# Get output token IDs
output_ids = getattr(request, "output_token_ids", [])
if hasattr(output_ids, "tolist"):
output_ids = output_ids.tolist()
if output_ids is None:
output_ids = []
# Combine all token IDs
all_token_ids = list(prompt_ids) + list(output_ids)
num_tokens = len(all_token_ids)
# Get existing block hashes
existing_hashes = getattr(request, "_prompt_hashes", [])
if existing_hashes is None:
existing_hashes = []
# Calculate starting position (skip already computed blocks)
start_token_idx = len(existing_hashes) * block_size
# Return empty if no new complete blocks
if start_token_idx + block_size > num_tokens:
return []
new_block_hashes: List[str] = []
prev_block_hash = existing_hashes[-1] if existing_hashes else None
# mm_idx tracks which multimodal item to scan from, avoiding redundant iteration
mm_idx = 0
# Compute hashes for new complete blocks
while True:
end_token_idx = start_token_idx + block_size
if end_token_idx > num_tokens:
break
# Get tokens for current block
block_tokens = all_token_ids[start_token_idx:end_token_idx]
# Collect multimodal extra_keys for this block
mm_idx, extra_keys = get_block_hash_extra_keys(
request=request,
start_idx=start_token_idx,
end_idx=end_token_idx,
mm_idx=mm_idx,
)
extra_keys_value = tuple(extra_keys) if extra_keys else None
# Compute hash (chained hash)
block_hash = hash_block_tokens(block_tokens, prev_block_hash, extra_keys_value)
new_block_hashes.append(block_hash)
# Update state
start_token_idx += block_size
prev_block_hash = block_hash
return new_block_hashes
return request_block_hasher
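The chained-hash scheme above means each block's hash commits to the entire prefix, not just its own tokens. A short, self-contained demonstration (it re-states `hash_block_tokens` from this file so it runs standalone; `block_size = 4` is an arbitrary illustrative value):

```python
import hashlib
import pickle

def hash_block_tokens(token_ids, parent_block_hash=None, extra_keys=None) -> str:
    # Same scheme as above: SHA256((parent_hash, token_tuple, extra_keys)).
    if parent_block_hash is None:
        parent_block_hash = ""
    value = (parent_block_hash, tuple(token_ids), extra_keys)
    return hashlib.sha256(pickle.dumps(value)).hexdigest()

block_size = 4
tokens = list(range(10))  # 10 tokens -> 2 complete blocks, 2 leftover tokens unhashed

hashes, prev = [], None
for start in range(0, len(tokens) - block_size + 1, block_size):
    prev = hash_block_tokens(tokens[start:start + block_size], prev)
    hashes.append(prev)

assert len(hashes) == 2  # only complete blocks are hashed

# Chaining: altering block 0 changes the parent hash, so block 1's hash changes too,
# even though block 1's own tokens are identical.
alt_parent = hash_block_tokens([99, 1, 2, 3])
assert hash_block_tokens(tokens[4:8], alt_parent) != hashes[1]
```

This is what makes prefix matching by hash safe: two requests can share a cached block only if every preceding block also matches.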
@@ -0,0 +1,590 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Any, Dict, List, Optional
class TransferStatus(Enum):
"""Status of a transfer task."""
PENDING = "pending"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class StorageType(Enum):
"""Supported storage backend types."""
MOONCAKE = "mooncake"
ATTNSTORE = "attnstore"
LOCAL = "local"
class TransferType(Enum):
"""Supported transfer mechanism types."""
RDMA = "rdma"
IPC = "ipc"
class CacheLevel(Enum):
"""Cache hierarchy levels for transfer operations."""
DEVICE = "device"
HOST = "host"
STORAGE = "storage"
class CacheStatus(Enum):
"""Cache status enum representing the current location and state of a BlockNode.
Attributes:
DEVICE: Block is in device (GPU) memory, ready for use. Can be matched.
HOST: Block is in host (CPU) memory, needs to be loaded to device. Can be matched.
SWAP_TO_HOST: Block is being evicted from device to host. Cannot be matched.
SWAP_TO_DEVICE: Block is being loaded from host to device.
LOADING_FROM_STORAGE: Block is being loaded from storage.
        DELETING: Block is being deleted (removed from host, or dropped directly when there is no host cache). Cannot be matched.
"""
DEVICE = auto()
HOST = auto()
SWAP_TO_HOST = auto()
SWAP_TO_DEVICE = auto()
DELETING = auto()
LOADING_FROM_STORAGE = auto()
@dataclass
class RadixTreeStats:
"""
Snapshot of RadixTree statistics.
    Encapsulates all state counters for monitoring and statistics.
    Captured as a snapshot so that all fields reflect a consistent point in time.
Attributes:
node_count: Total number of nodes in the tree.
evictable_device_count: GPU nodes available for eviction (ref_count==0, status==DEVICE).
evictable_host_count: CPU nodes available for deletion (ref_count==0, status==HOST).
"""
node_count: int = 0
evictable_device_count: int = 0
evictable_host_count: int = 0
@property
def evictable_count(self) -> int:
"""Total evictable nodes count."""
return self.evictable_device_count + self.evictable_host_count
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization."""
return {
"node_count": self.node_count,
"evictable_device_count": self.evictable_device_count,
"evictable_host_count": self.evictable_host_count,
"evictable_count": self.evictable_count,
}
@dataclass
class CacheBlockMetadata:
"""
Metadata for a cache block.
Attributes:
block_id: Unique identifier for the block
device_id: GPU device ID where the block resides
block_size: Size of the block in bytes
ref_count: Reference count for the block
is_pinned: Whether the block is pinned in memory
layer_indices: List of layer indices stored in this block
token_count: Number of tokens in this block
hash_value: Hash value for the block content
last_access_time: Last access timestamp
"""
block_id: int
device_id: int
block_size: int
ref_count: int = 0
is_pinned: bool = False
layer_indices: List[int] = field(default_factory=list)
token_count: int = 0
hash_value: Optional[str] = None
last_access_time: float = 0.0
@dataclass
class TransferTask:
"""
Represents a cache transfer task.
Attributes:
task_id: Unique identifier for the task
src_location: Source location (device/host/storage/remote)
dst_location: Destination location
block_indices: List of block indices to transfer
layer_indices: List of layer indices to transfer
status: Current status of the task
priority: Task priority (lower is higher priority)
created_time: Task creation timestamp
started_time: Task start timestamp
completed_time: Task completion timestamp
error_message: Error message if task failed
metadata: Additional task metadata
"""
task_id: str
src_location: str
dst_location: str
block_indices: List[int] = field(default_factory=list)
layer_indices: List[int] = field(default_factory=list)
status: TransferStatus = TransferStatus.PENDING
priority: int = 0
created_time: float = 0.0
started_time: Optional[float] = None
completed_time: Optional[float] = None
error_message: Optional[str] = None
metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class StorageConfig:
"""
Configuration for storage backend.
Attributes:
storage_type: Type of storage backend
storage_path: Base path for storage
max_size_bytes: Maximum storage size in bytes
enable_compression: Whether to enable compression
compression_algorithm: Compression algorithm to use
connection_timeout: Connection timeout in seconds
read_timeout: Read timeout in seconds
write_timeout: Write timeout in seconds
extra_config: Additional backend-specific configuration
"""
storage_type: StorageType = StorageType.MOONCAKE
storage_path: str = ""
max_size_bytes: int = 0
enable_compression: bool = False
compression_algorithm: str = "lz4"
connection_timeout: float = 30.0
read_timeout: float = 60.0
write_timeout: float = 60.0
extra_config: Dict[str, Any] = field(default_factory=dict)
@dataclass
class TransferConfig:
"""
Configuration for transfer mechanism.
Attributes:
transfer_type: Type of transfer mechanism
enable_async: Whether to enable async transfer
max_concurrent_transfers: Maximum concurrent transfer tasks
buffer_size: Buffer size for transfer in bytes
enable_checksum: Whether to enable checksum verification
retry_count: Number of retries on failure
retry_delay: Delay between retries in seconds
extra_config: Additional transfer-specific configuration
"""
transfer_type: TransferType = TransferType.RDMA
enable_async: bool = True
max_concurrent_transfers: int = 4
buffer_size: int = 1024 * 1024 # 1MB
enable_checksum: bool = True
retry_count: int = 3
retry_delay: float = 1.0
extra_config: Dict[str, Any] = field(default_factory=dict)
@dataclass
class BlockNode:
"""
Node in the block management tree.
Represents a node in the radix tree or block allocation structure,
tracking block relationships and reference counts.
Attributes:
node_id: Globally unique identifier for this node (UUID)
block_id: Block identifier (may be reused across device/host)
parent: Parent BlockNode reference (None for root)
children: Dict mapping hash values to child BlockNodes (for radix tree)
children_ids: List of child block IDs
        ref_count: Number of references to this block (defaults to 0 on creation)
        token_count: Number of tokens stored in this block
        hash_value: Hash value for prefix matching
        cache_status: Current cache status (any CacheStatus value)
last_access_time: Last access timestamp (defaults to current time on creation)
backuped: Whether this block has a backup on host memory
host_block_id: Host block ID where the backup is stored (if backuped=True)
"""
node_id: str = field(default_factory=lambda: str(uuid.uuid4()))
block_id: int = 0
parent: Optional["BlockNode"] = None
children: Dict[str, "BlockNode"] = field(default_factory=dict)
children_ids: List[int] = field(default_factory=list)
ref_count: int = 0
token_count: int = 0
hash_value: Optional[str] = None
cache_status: CacheStatus = CacheStatus.DEVICE
last_access_time: float = field(default_factory=time.time)
# Backup-related fields
backuped: bool = False # Whether a backup exists on host memory
host_block_id: Optional[int] = None # Host block ID where the backup is stored
    hit_count: int = 1  # incremented on each prefix hit; triggers a host backup once it reaches the threshold
def __post_init__(self):
"""Initialize instance with current time if last_access_time not set."""
if self.last_access_time == 0.0:
self.last_access_time = time.time()
def add_child(self, child_id: int) -> None:
"""Add a child block ID."""
if child_id not in self.children_ids:
self.children_ids.append(child_id)
def remove_child(self, child_id: int) -> bool:
"""Remove a child block ID. Returns True if removed."""
if child_id in self.children_ids:
self.children_ids.remove(child_id)
return True
return False
def increment_ref(self) -> int:
"""Increment reference count and return new count."""
self.ref_count += 1
return self.ref_count
def decrement_ref(self) -> int:
"""Decrement reference count and return new count."""
if self.ref_count > 0:
self.ref_count -= 1
return self.ref_count
def touch(self) -> None:
"""
Update last_access_time to current time.
This method should be called whenever the block is accessed
to track access recency for eviction policies.
"""
self.last_access_time = time.time()
def update_access(self, delta_ref: int = 0) -> None:
"""
Update reference count and last_access_time.
Args:
delta_ref: Change in reference count (positive to increment, negative to decrement)
"""
if delta_ref > 0:
self.ref_count += delta_ref
elif delta_ref < 0:
self.ref_count = max(0, self.ref_count + delta_ref)
self.touch()
def is_leaf(self) -> bool:
"""Check if this is a leaf node (no children)."""
return len(self.children_ids) == 0 and len(self.children) == 0
def is_root(self) -> bool:
"""Check if this is a root node (no parent)."""
return self.parent is None
def is_on_device(self) -> bool:
"""Check if block is on device (GPU) memory."""
return self.cache_status == CacheStatus.DEVICE
def is_on_host(self) -> bool:
"""Check if block is on host (CPU) memory."""
return self.cache_status == CacheStatus.HOST
def is_swapping(self) -> bool:
"""Check if block is currently being swapped or deleted."""
return self.cache_status in (
CacheStatus.SWAP_TO_HOST,
CacheStatus.SWAP_TO_DEVICE,
CacheStatus.DELETING,
)
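Together, `ref_count` and `cache_status` determine evictability: per the RadixTreeStats docstring above, a node is an eviction candidate only when it is unreferenced (`ref_count == 0`) and settled on a tier (DEVICE or HOST), never while a swap or delete is in flight. A condensed illustration (`Node` and `is_evictable` below are illustrative stand-ins, not part of the module):

```python
import time
from dataclasses import dataclass, field
from enum import Enum, auto

class CacheStatus(Enum):
    DEVICE = auto()
    HOST = auto()
    SWAP_TO_HOST = auto()

@dataclass
class Node:
    # Just the BlockNode fields that eviction eligibility depends on.
    block_id: int = 0
    ref_count: int = 0
    cache_status: CacheStatus = CacheStatus.DEVICE
    last_access_time: float = field(default_factory=time.time)

def is_evictable(n: Node) -> bool:
    # Unreferenced AND settled on a tier: in-flight swaps are never evicted.
    return n.ref_count == 0 and n.cache_status in (CacheStatus.DEVICE, CacheStatus.HOST)

n = Node(block_id=7, ref_count=1)
assert not is_evictable(n)                   # still referenced by a request
n.ref_count -= 1
assert is_evictable(n)                       # now a device-eviction candidate
n.cache_status = CacheStatus.SWAP_TO_HOST
assert not is_evictable(n)                   # mid-swap: excluded until the swap completes
```

This is also why `complete_swap_to_device` must re-add the node to the evictable set (the bug fixed in this PR): a node that lands back in DEVICE with `ref_count == 0` would otherwise satisfy the predicate yet appear in no evictable set.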
@dataclass
class MatchResult:
"""
Three-level cache prefix match result.
Contains matched nodes from Device, Host, and Storage levels.
Attributes:
storage_nodes: List of matched BlockNodes in Storage.
device_nodes: List of matched BlockNodes in Device.
host_nodes: List of matched BlockNodes in Host.
"""
device_nodes: List["BlockNode"] = field(default_factory=list)
host_nodes: List["BlockNode"] = field(default_factory=list)
storage_nodes: List["BlockNode"] = field(default_factory=list)
uncached_block_ids: List[int] = field(default_factory=list)
@property
def device_block_ids(self) -> List[int]:
"""Get list of matched device block IDs."""
return [node.block_id for node in self.device_nodes]
@property
def total_matched_blocks(self) -> int:
"""Get total number of matched device blocks."""
return self.matched_device_nums + self.matched_host_nums + self.matched_storage_nums
@property
def matched_device_nums(self) -> int:
"""Get total number of matched device blocks."""
return len(self.device_nodes)
@property
def matched_host_nums(self) -> int:
"""Get total number of matched host blocks."""
return len(self.host_nodes)
@property
def matched_storage_nums(self) -> int:
"""Get total number of matched storage hashes."""
return len(self.storage_nodes)
@dataclass
class StorageMetadata:
"""
Base metadata for storage transfer operations.
Encapsulates all information for storage load/evict operations.
Different storage implementations can extend this class with additional fields.
Attributes:
hash_values: List of hash values to transfer.
block_ids: Target/source host block IDs (pre-allocated by Scheduler).
direction: Transfer direction ("load" from storage, "evict" to storage).
storage_type: Storage type ("mooncake", "attnstore", "rdma", etc.).
endpoint: Storage service endpoint address.
timeout: Operation timeout in seconds.
layer_num: Number of layers to transfer (for layer-by-layer transfer).
extra_params: Storage-specific extra parameters.
"""
hash_values: List[str] = field(default_factory=list)
block_ids: List[int] = field(default_factory=list)
direction: str = "load"
storage_type: str = "mooncake"
endpoint: Optional[str] = None
timeout: float = 30.0
layer_num: int = 0
extra_params: Dict[str, Any] = field(default_factory=dict)
@dataclass
class PDTransferMetadata:
"""
Base metadata for PD separation transfer operations.
Encapsulates all information for cross-node transfer in PD separation architecture.
Different transfer mechanisms (RDMA, IPC) can extend this class with additional fields.
Attributes:
source_node_id: Source node identifier (P node ID).
target_node_id: Target node identifier (D node ID).
block_ids: List of block IDs to transfer.
layer_num: Total number of model layers (for layer-by-layer transfer sync).
timeout: Operation timeout in seconds.
extra_params: Transfer-specific extra parameters.
"""
source_node_id: str = ""
target_node_id: str = ""
block_ids: List[int] = field(default_factory=list)
layer_num: int = 0
timeout: float = 30.0
extra_params: Dict[str, Any] = field(default_factory=dict)
@dataclass
class CacheSwapMetadata:
"""
Metadata for cache transfer operations.
Encapsulates the mapping between source and destination block IDs
for Host↔Device, Storage→Host, and other transfer operations.
Attributes:
src_block_ids: Source block IDs (transfer origin).
dst_block_ids: Destination block IDs (transfer target).
src_type: Source cache level (CacheLevel.DEVICE/HOST/STORAGE).
dst_type: Destination cache level (CacheLevel.DEVICE/HOST/STORAGE).
hash_values: Corresponding hash values (used for storage-related operations).
success: Whether the transfer succeeded.
error_message: Error message if transfer failed.
async_handler: Async task handler for tracking the swap task execution state.
"""
src_block_ids: List[int] = field(default_factory=list)
dst_block_ids: List[int] = field(default_factory=list)
src_type: Optional[CacheLevel] = None
dst_type: Optional[CacheLevel] = None
hash_values: List[str] = field(default_factory=list)
success: bool = False
error_message: Optional[str] = None
async_handler: Optional["AsyncTaskHandler"] = None
def is_success(self) -> bool:
"""Return whether the transfer succeeded."""
return self.success
@property
def mapping(self) -> Dict[int, int]:
"""Get the src -> dst block ID mapping dict."""
if not self.success:
return {}
return dict(zip(self.src_block_ids, self.dst_block_ids))
@dataclass
class TransferResult:
"""
Cache transfer operation result.
Encapsulates the mapping between source and destination block IDs
for Host↔Device, Storage→Host, and other transfer operations.
Attributes:
src_block_ids: Source block IDs (transfer origin).
dst_block_ids: Destination block IDs (transfer target).
src_type: Source cache level (CacheLevel.DEVICE/HOST/STORAGE).
dst_type: Destination cache level (CacheLevel.DEVICE/HOST/STORAGE).
success: Whether the transfer succeeded.
error_message: Error message if transfer failed.
"""
src_block_ids: List[int] = field(default_factory=list)
dst_block_ids: List[int] = field(default_factory=list)
src_type: Optional[CacheLevel] = None
dst_type: Optional[CacheLevel] = None
success: bool = True
error_message: Optional[str] = None
@dataclass
class AsyncTaskHandler:
"""
Async task handler.
Used for submitting and tracking the state of async tasks.
External callers use this handler to check whether a task has completed.
Attributes:
task_id: Unique task identifier.
is_completed: Whether the task has completed.
result: Task result (available after completion).
error: Task error message (if failed).
"""
task_id: str = field(default_factory=lambda: str(uuid.uuid4()))
is_completed: bool = False
result: Optional[Any] = None
error: Optional[str] = None
_event: Any = field(default=None, repr=False)
def __post_init__(self):
"""Initialize event for synchronization."""
import threading
object.__setattr__(self, "_event", threading.Event())
def wait(self, timeout: Optional[float] = None) -> bool:
"""
Wait for the task to complete.
Args:
timeout: Maximum wait time in seconds. None means wait indefinitely.
Returns:
True if completed, False if timed out.
"""
return self._event.wait(timeout=timeout)
def cancel(self) -> bool:
"""
Cancel the task.
Returns:
True if successfully cancelled, False otherwise.
"""
if self.is_completed:
return False
self.error = "Task cancelled"
self.is_completed = True
self._event.set()
return True
def get_result(self) -> Any:
"""
Get the task result (blocking).
Returns:
Task result.
Raises:
RuntimeError: If the task failed or was cancelled.
"""
self._event.wait()
if self.error:
raise RuntimeError(self.error)
return self.result
def set_result(self, result: Any) -> None:
"""
Set the task result and mark as completed.
Args:
result: Task result.
"""
self.result = result
self.is_completed = True
self._event.set()
def set_error(self, error: str) -> None:
"""
Set the error message and mark as completed.
Args:
error: Error message.
"""
self.error = error
self.is_completed = True
self._event.set()
@@ -0,0 +1,697 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import heapq
import threading
from typing import Dict, List, Optional, Tuple
from fastdeploy.utils import get_logger
from .metadata import BlockNode, CacheStatus, RadixTreeStats
logger = get_logger("radix_tree", "cache_manager.log")
class RadixTree:
"""
Radix tree for efficient prefix matching in KV cache.
Used to find matching prefixes across different sequences,
enabling KV cache reuse for shared prefixes.
Uses separate min-heaps for DEVICE and HOST evictable nodes with true deletion,
ensuring heap contents are always consistent with the evictable set.
API Usage Guidelines
====================
1. Reference Count Management (CRITICAL)
-----------------------------------------
The reference count (ref_count) determines whether a node can be evicted.
A node is evictable ONLY when ref_count == 0.
IMPORTANT: You MUST pair increment_ref_nodes() and decrement_ref_nodes() calls:
- After insert(): nodes have ref_count >= 1, NOT evictable
- After decrement_ref_nodes(): ref_count decreases, may become evictable
- After increment_ref_nodes(): ref_count increases, removed from evictable set
WARNING: Unbalanced ref_count management can cause:
- Memory leaks: nodes never become evictable (ref_count > 0 forever)
- Premature eviction: nodes evicted while still in use (ref_count == 0)
Example:
nodes, wasted_ids = tree.insert(blocks) # ref_count = 1, wasted_ids may be non-empty if nodes were reused
if wasted_ids:
# Release wasted block_ids that were not used due to node reuse
release_blocks(wasted_ids)
# ... use the nodes ...
tree.decrement_ref_nodes(nodes) # ref_count = 0, now evictable
# Do NOT use nodes after decrement - they may be evicted!
2. Eviction Operation Order
---------------------------
The correct eviction order is:
DEVICE -> HOST -> Storage
Step 1: evict_device_to_host() - Move DEVICE nodes to HOST
- Input: num_blocks, host_block_ids (pre-allocated)
- Output: released device block_ids
- Nodes transition: DEVICE -> HOST (still in tree)
Step 2: evict_host_nodes() - Remove HOST nodes permanently
- Input: num_blocks
- Output: evicted host block_ids
- Nodes removed from tree completely
WARNING: Do NOT call evict_host_nodes() before evict_device_to_host() for
the same nodes - this will fail since nodes are still in DEVICE state.
3. Atomicity Guarantee
----------------------
All eviction methods provide atomic operation:
- Pre-check: verify enough evictable nodes exist
- If pre-check fails, return None immediately (no partial eviction)
- If success, all requested blocks are processed
Check return value:
- None: Not enough evictable blocks, operation failed
- Empty list: num_blocks == 0, nothing to do
- List of block_ids: Success
4. Thread Safety
----------------
All public methods are thread-safe using RLock.
However, be careful with the following pattern:
WARNING: Do NOT hold references to nodes across method calls:
# DANGEROUS - node may be evicted by another thread
nodes = tree.find_prefix(hashes)
# ... some operation without lock ...
tree.increment_ref_nodes(nodes) # nodes may already be evicted!
Instead, use the returned nodes immediately:
nodes = tree.find_prefix(hashes)
tree.increment_ref_nodes(nodes) # Safe: immediate operation
5. Node Lifecycle
-----------------
Node states and valid transitions:
[New] --insert()--> DEVICE (ref_count >= 1)
DEVICE --decrement_ref()--> DEVICE (ref_count == 0, evictable)
DEVICE --evict_device_to_host()--> HOST (ref_count == 0)
HOST --evict_host_nodes()--> [Deleted from tree]
HOST --swap_to_device()--> SWAP_TO_DEVICE
SWAP_TO_DEVICE --complete_swap_to_device()--> DEVICE
WARNING: Once a node's ref_count becomes 0, it can be evicted at any time.
Do NOT access or modify a node after decrementing its ref_count unless
you increment it first.
6. Common Pitfalls
------------------
a) Forgetting to decrement ref_count after use:
-> Memory leak, blocks never released
b) Decrementing ref_count multiple times:
-> ref_count becomes negative, undefined behavior
c) Using nodes after decrement_ref_nodes():
-> Nodes may be evicted, accessing invalid memory
d) Evicting nodes with ref_count > 0:
-> Not possible, eviction methods skip non-zero ref_count nodes
e) Calling find_prefix() on DELETING/SWAP_TO_HOST nodes:
-> These states are skipped, prefix match stops at these nodes
"""
def __init__(
self,
enable_host_cache: bool = False,
write_policy: str = "write_through",
):
"""
Initialize the radix tree.
Args:
enable_host_cache: If True, evict() moves nodes to HOST state
instead of removing them from tree.
write_policy: Write policy for backup to lower tier.
- "write_through": Every matched node triggers backup check
- "write_through_selective": Only nodes with hit_count >= threshold trigger backup
- "write_back": Backup only when evicted (not implemented yet)
"""
self._root = BlockNode()
self._lock = threading.RLock()
self._node_count = 1 # Root node
self._enable_host_cache = enable_host_cache
self._write_policy = write_policy
# Use dict for O(1) add/remove instead of heap's O(n) removal
# Format: {node_id: (last_access_time, node)}
self._evictable_device: Dict[str, Tuple[float, BlockNode]] = {}
self._evictable_host: Dict[str, Tuple[float, BlockNode]] = {}
def insert(
self,
blocks: List[Tuple[str, int]],
cache_status: CacheStatus = CacheStatus.DEVICE,
start_node: Optional[BlockNode] = None,
) -> Tuple[List[BlockNode], List[int]]:
"""
Insert a sequence of blocks into the tree.
Args:
blocks: List of (block_hash, block_id) tuples.
Each tuple represents a complete block.
cache_status: Initial cache status for new nodes.
Defaults to DEVICE.
start_node: Node to start insertion from. If None, starts from root.
Used for incremental insertion after prefix match.
Returns:
Tuple of (result_nodes, wasted_block_ids):
- result_nodes: List of inserted or updated BlockNode objects.
- wasted_block_ids: List of block_ids that were not used due to
node reuse (should be released by caller).
"""
result_nodes = []
wasted_block_ids = []
if not blocks:
return result_nodes, wasted_block_ids
with self._lock:
node = self._root if start_node is None else start_node
for i, (block_hash, block_id) in enumerate(blocks):
if block_hash not in node.children:
# Create new BlockNode with block_id, parent, and hash_value
new_node = BlockNode(
block_id=block_id,
parent=node,
hash_value=block_hash,
cache_status=cache_status,
)
node.children[block_hash] = new_node
self._node_count += 1
else:
# Node already exists for this hash - the new block_id is wasted
existing_node = node.children[block_hash]
if existing_node.block_id != block_id:
# Track the wasted block_id for caller to release
wasted_block_ids.append(block_id)
node = node.children[block_hash]
# Increment ref and update evictable status
node.increment_ref()
# If node in evictable, remove it from evictable dict
if node.cache_status == CacheStatus.DEVICE and node.node_id in self._evictable_device:
del self._evictable_device[node.node_id]
elif node.cache_status == CacheStatus.HOST and node.node_id in self._evictable_host:
del self._evictable_host[node.node_id]
result_nodes.append(node)
return result_nodes, wasted_block_ids
def find_prefix(
self,
block_hashes: List[str],
) -> List[BlockNode]:
"""
Find the longest matching prefix.
Args:
block_hashes: List of block hash values to match.
Returns:
List of matched BlockNode objects in order.
Empty list if no match found.
"""
matched_nodes = []
with self._lock:
node = self._root
for i, block_hash in enumerate(block_hashes):
if block_hash not in node.children:
break
node = node.children[block_hash]
if node.cache_status in (CacheStatus.DELETING, CacheStatus.SWAP_TO_HOST):
break
node.touch()
matched_nodes.append(node)
return matched_nodes
def increment_ref_nodes(self, nodes: List[BlockNode]) -> None:
"""
Increment reference count for a list of nodes.
Removes nodes from evictable set (no longer available for eviction).
Also updates last_access_time for each node.
Args:
nodes: List of BlockNode objects to increment ref_count.
"""
if not nodes:
return
with self._lock:
for node in nodes:
node.increment_ref()
node.hit_count += 1
node.touch()
self._remove_from_evictable(node)
def decrement_ref_nodes(self, nodes: List[BlockNode]) -> None:
"""
Decrement reference count for a list of nodes.
When ref_count becomes 0, the node is added to evictable heap
and becomes available for eviction. Also updates last_access_time.
Args:
nodes: List of BlockNode objects to decrement ref_count.
"""
if not nodes:
return
with self._lock:
for node in nodes:
old_ref = node.ref_count
node.decrement_ref()
node.touch()
# If ref_count goes from 1 to 0, add to evictable
if old_ref == 1 and node.ref_count == 0:
self._add_to_evictable(node)
def reset(self) -> None:
"""
Reset the tree to initial state.
Clears all nodes except root, evictable tracking, and node mappings.
"""
with self._lock:
self._root = BlockNode()
self._node_count = 1
self._evictable_device.clear()
self._evictable_host.clear()
def get_stats(self) -> RadixTreeStats:
"""
Get tree statistics snapshot.
Returns a snapshot of all tree statistics. Using a snapshot ensures
consistent values across all fields in a single call.
Returns:
RadixTreeStats containing all tree statistics.
"""
return RadixTreeStats(
node_count=self._node_count,
evictable_device_count=len(self._evictable_device),
evictable_host_count=len(self._evictable_host),
)
def node_count(self) -> int:
"""Get total number of nodes in the tree."""
return self._node_count
def evict_host_nodes(
self,
num_blocks: int,
) -> Optional[List[int]]:
"""
Evict HOST nodes from the tree.
Removes HOST nodes permanently and returns their block_ids.
Args:
num_blocks: Number of HOST blocks to evict
Returns:
List of evicted host block_ids, or None if not enough
evictable HOST blocks.
"""
if num_blocks == 0:
return []
with self._lock:
if len(self._evictable_host) < num_blocks:
return None
nodes = self._get_lru_nodes(self._evictable_host, num_blocks)
evicted_block_ids = []
for node in nodes:
self._remove_node_from_tree(node)
evicted_block_ids.append(node.block_id)
logger.debug(
f"evict_host_nodes: evicted={evicted_block_ids}, " f"remaining_host={len(self._evictable_host)}"
)
return evicted_block_ids
def _get_lru_nodes(
self,
evictable_dict: Dict[str, Tuple[float, BlockNode]],
num_blocks: int,
) -> List[BlockNode]:
"""
Get the coldest (LRU) nodes from an evictable dict.
Args:
evictable_dict: The evictable dict to get nodes from (_evictable_device or _evictable_host).
num_blocks: Number of nodes to get.
Returns:
List of BlockNode objects in LRU order (coldest first).
"""
if num_blocks <= 0 or not evictable_dict:
return []
smallest = heapq.nsmallest(
min(num_blocks, len(evictable_dict)), evictable_dict.items(), key=lambda item: item[1][0]
)
nodes = [node for _, (_, node) in smallest]
for node_id, _ in smallest:
del evictable_dict[node_id]
return nodes
def evict_device_nodes(
self,
num_blocks: int,
) -> Optional[List[int]]:
"""
Evict DEVICE nodes from the tree directly.
Removes DEVICE nodes permanently without moving to HOST.
This is used when host cache is disabled.
Args:
num_blocks: Number of DEVICE blocks to evict.
Returns:
List of evicted device block_ids, or None if not enough
evictable DEVICE blocks.
"""
if num_blocks == 0:
return []
with self._lock:
if len(self._evictable_device) < num_blocks:
return None
nodes = self._get_lru_nodes(self._evictable_device, num_blocks)
evicted_block_ids = []
for node in nodes:
self._remove_node_from_tree(node)
evicted_block_ids.append(node.block_id)
logger.debug(
f"evict_device_nodes: evicted={evicted_block_ids}, " f"remaining_device={len(self._evictable_device)}"
)
return evicted_block_ids
def evict_device_to_host(
self,
num_blocks: int,
host_block_ids: List[int],
) -> Optional[List[int]]:
"""
Evict DEVICE nodes to host memory.
Changes node status from DEVICE to HOST and updates block_id
to the provided host_block_ids.
Args:
num_blocks: Number of DEVICE blocks to evict
host_block_ids: Pre-allocated host block IDs to use
Returns:
List of released device block_ids, or None if not enough
evictable DEVICE blocks.
"""
if num_blocks == 0:
return []
if len(host_block_ids) < num_blocks:
return None
released_block_ids = []
with self._lock:
if len(self._evictable_device) < num_blocks:
return None
nodes = self._get_lru_nodes(self._evictable_device, num_blocks)
for i, node in enumerate(nodes):
# Save the original device block_id
original_block_id = node.block_id
new_host_block_id = host_block_ids[i]
# Update status and block_id
node.cache_status = CacheStatus.HOST
node.block_id = new_host_block_id
node.touch()
# Add to host evictable dict
self._evictable_host[node.node_id] = (node.last_access_time, node)
released_block_ids.append(original_block_id)
logger.debug(
f"evict_device_to_host: released_device={released_block_ids} -> host={host_block_ids[:len(released_block_ids)]}, "
f"evictable_device={len(self._evictable_device)}, evictable_host={len(self._evictable_host)}"
)
return released_block_ids
def _add_to_evictable(self, node: BlockNode) -> None:
"""
Add a node to the appropriate evictable dict based on cache status.
"""
if node.cache_status == CacheStatus.DEVICE:
if node.node_id not in self._evictable_device:
self._evictable_device[node.node_id] = (node.last_access_time, node)
elif node.cache_status == CacheStatus.HOST:
if node.node_id not in self._evictable_host:
self._evictable_host[node.node_id] = (node.last_access_time, node)
def _remove_from_evictable(self, node: BlockNode) -> None:
"""
Remove a node from evictable tracking (O(1) deletion from dict).
"""
if node.cache_status == CacheStatus.DEVICE and node.node_id in self._evictable_device:
del self._evictable_device[node.node_id]
elif node.cache_status == CacheStatus.HOST and node.node_id in self._evictable_host:
del self._evictable_host[node.node_id]
def _remove_node_from_tree(self, node: BlockNode) -> None:
"""
Remove a single node from the tree permanently.
Args:
node: Node to remove
"""
if node.parent is None:
return # Cannot remove root
# Remove from parent's children
if node.hash_value and node.hash_value in node.parent.children:
del node.parent.children[node.hash_value]
self._node_count -= 1
def swap_to_device(
self,
nodes: List[BlockNode],
gpu_block_ids: List[int],
) -> List[int]:
"""
Swap CPU blocks to device.
Changes node status to SWAP_TO_DEVICE and updates block_id to GPU block ID.
This is used when loading host blocks back to device memory.
Args:
nodes: List of BlockNode objects on host to swap to device.
Caller guarantees all nodes are on HOST.
gpu_block_ids: Corresponding GPU block IDs
Returns:
List of original host block_ids
"""
if len(nodes) != len(gpu_block_ids):
return []
original_block_ids = []
with self._lock:
for node, gpu_block_id in zip(nodes, gpu_block_ids):
# Save the original host block_id
original_block_ids.append(node.block_id)
# Remove from evictable before changing status
self._remove_from_evictable(node)
# Update status to SWAP_TO_DEVICE and block_id to GPU block ID
node.cache_status = CacheStatus.SWAP_TO_DEVICE
node.block_id = gpu_block_id
node.touch()
return original_block_ids
def complete_swap_to_device(
self,
nodes: List[BlockNode],
) -> List[int]:
"""
Complete the swap to device operation.
Changes node status from SWAP_TO_DEVICE to DEVICE.
This should be called after the actual data transfer is complete.
Args:
nodes: List of BlockNode objects that were swapped to device
Returns:
List of GPU block_ids
"""
gpu_block_ids = []
with self._lock:
for node in nodes:
# Update status to DEVICE
node.cache_status = CacheStatus.DEVICE
node.touch()
# If no request holds the node, register it as evictable again;
# otherwise a DEVICE node with ref_count == 0 would be orphaned
# (absent from every evictable set).
if node.ref_count == 0:
self._add_to_evictable(node)
gpu_block_ids.append(node.block_id)
return gpu_block_ids
def backup_blocks(
self,
nodes: List[BlockNode],
host_block_ids: List[int],
) -> List[int]:
"""
Mark blocks as backed up and record their host block IDs.
This method marks the given nodes as backuped and stores the
host block IDs. It does NOT perform the actual data transfer -
that should be done by the caller via cache_evict_metadata.
Args:
nodes: List of BlockNode objects to backup
host_block_ids: Corresponding host block IDs for the backup
Returns:
List of device block IDs that were marked as backuped
"""
if len(nodes) != len(host_block_ids):
return []
backed_up_ids = []
with self._lock:
for node, host_block_id in zip(nodes, host_block_ids):
node.backuped = True
node.host_block_id = host_block_id
backed_up_ids.append(node.block_id)
return backed_up_ids
def get_candidates_for_backup(self, threshold: int, pending_block_ids: Optional[List[int]] = None) -> List[BlockNode]:
"""
Get nodes that are candidates for backup based on write_through_selective policy.
Returns evictable device nodes that:
1. Have hit_count >= threshold
2. Are not already backed up
Args:
threshold: Minimum hit_count required for backup candidacy.
pending_block_ids: Block IDs already in the pending backup queue,
used to avoid duplicate scheduling. Defaults to no pending blocks.
Returns:
List of BlockNode objects that are candidates for backup,
sorted by LRU (coldest first).
"""
if self._write_policy != "write_through_selective":
return []
candidates = []
# Use a set for O(1) membership tests; also avoids the mutable
# default argument pitfall of a shared [] default.
pending = set(pending_block_ids or [])
with self._lock:
for _, (_, node) in self._evictable_device.items():
if not node.backuped and node.hit_count >= threshold and node.block_id not in pending:
candidates.append(node)
# Sort by LRU (oldest last_access_time first)
candidates.sort(key=lambda n: n.last_access_time)
return candidates
def evict_nodes_selective(
self,
num_blocks: int,
) -> List[int]:
"""
Evict device nodes with write_through_selective optimization.
First selects the coldest (LRU) nodes, then categorizes them:
- without_backup: Release directly (cold data, no transfer needed)
- with_backup: Update metadata to HOST (data already in host)
Args:
num_blocks: Number of blocks to evict
Returns:
List of released device block IDs
"""
if num_blocks <= 0:
return []
with self._lock:
if len(self._evictable_device) < num_blocks:
return []
# Get LRU nodes first (this pops them from _evictable_device)
nodes = self._get_lru_nodes(self._evictable_device, num_blocks)
released_device_ids = []
for node in nodes:
if node.backuped:
released_device_ids.append(node.block_id)
node.cache_status = CacheStatus.HOST
node.block_id = node.host_block_id
node.touch()
# Move to host evictable
self._evictable_host[node.node_id] = (node.last_access_time, node)
else:
self._remove_node_from_tree(node)
released_device_ids.append(node.block_id)
return released_device_ids
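The LRU selection used throughout `RadixTree` (an evictable dict with O(1) add/remove, plus `heapq.nsmallest` to pop the coldest entries) can be exercised in isolation. A minimal, self-contained sketch of the same pattern; `_Node` and the sample timestamps are illustrative stand-ins, not from the module:

```python
import heapq
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class _Node:
    node_id: str
    block_id: int
    last_access_time: float


def get_lru_nodes(
    evictable: Dict[str, Tuple[float, _Node]], num_blocks: int
) -> List[_Node]:
    """Pop the num_blocks coldest nodes, mirroring RadixTree._get_lru_nodes."""
    if num_blocks <= 0 or not evictable:
        return []
    # nsmallest on last_access_time: oldest (coldest) entries first.
    smallest = heapq.nsmallest(
        min(num_blocks, len(evictable)), evictable.items(), key=lambda item: item[1][0]
    )
    nodes = [node for _, (_, node) in smallest]
    for node_id, _ in smallest:
        del evictable[node_id]  # true deletion keeps dict and tree consistent
    return nodes


evictable = {
    "a": (3.0, _Node("a", 10, 3.0)),
    "b": (1.0, _Node("b", 11, 1.0)),
    "c": (2.0, _Node("c", 12, 2.0)),
}
coldest = get_lru_nodes(evictable, 2)
print([n.node_id for n in coldest])  # ['b', 'c']
print(sorted(evictable))             # ['a']
```

Compared with a single heap plus lazy deletion, the dict keeps eviction bookkeeping exact (no stale entries), at the cost of an O(n log k) scan per eviction batch.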
@@ -0,0 +1,232 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import TYPE_CHECKING, Any, Dict, Optional
if TYPE_CHECKING:
from fastdeploy.config import CacheConfig
from ..metadata import StorageType
from .base import StorageConnector, StorageScheduler
def create_storage_scheduler(
config: Any,
) -> Optional[StorageScheduler]:
"""
Create a StorageScheduler instance based on configuration.
This is a factory function that creates the appropriate StorageScheduler
based on the storage backend type specified in the configuration.
Args:
config: Configuration object, can be:
- CacheConfig: FastDeploy configuration object
- Dict: Dictionary with 'storage_type' and backend-specific settings
- StorageConfig: StorageConfig dataclass instance
Returns:
StorageScheduler instance if successful, None otherwise
Example:
# Using CacheConfig
scheduler = create_storage_scheduler(fd_config)
# Using dict config
config = {
'storage_type': 'mooncake',
'server_addr': 'localhost:8080',
'namespace': 'kv_cache',
}
scheduler = create_storage_scheduler(config)
"""
if config.kvcache_storage_backend is None:
return None
scheduler: Optional[StorageScheduler] = None
# Create scheduler based on storage type
if config.kvcache_storage_backend == "mooncake":
from .mooncake.connector import MooncakeStorageScheduler
scheduler = MooncakeStorageScheduler(config)
elif config.kvcache_storage_backend == "attention_store":
from .attnstore.connector import AttnStoreScheduler
scheduler = AttnStoreScheduler(config)
else:
raise ValueError(
f"Unsupported storage type: {config.kvcache_storage_backend}. "
f"Supported types: mooncake, attention_store"
)
# Attempt connection. A failed connect() is not fatal: the scheduler is
# still returned (is_connected() reports False) so the caller can retry
# connect() later.
if scheduler is not None:
scheduler.connect()
return scheduler
def create_storage_connector(
config: Any,
) -> Optional[StorageConnector]:
"""
Create a StorageConnector instance based on configuration.
This is a factory function that creates the appropriate StorageConnector
based on the storage backend type specified in the configuration.
Args:
config: Configuration object, can be:
- CacheConfig: FastDeploy configuration object
- Dict: Dictionary with 'storage_type' and backend-specific settings
- StorageConfig: StorageConfig dataclass instance
Returns:
StorageConnector instance if successful, None otherwise
Example:
# Using CacheConfig
connector = create_storage_connector(fd_config)
# Using dict config
config = {
'storage_type': 'mooncake',
'server_addr': 'localhost:8080',
'buffer_size': 1024 * 1024,
}
connector = create_storage_connector(config)
"""
if config.kvcache_storage_backend is None:
return None
connector: Optional[StorageConnector] = None
# Create connector based on storage type
if config.kvcache_storage_backend == "mooncake":
from .mooncake.connector import MooncakeStorageConnector
connector = MooncakeStorageConnector(config)
elif config.kvcache_storage_backend == "attention_store":
from .attnstore.connector import AttnStoreConnector
connector = AttnStoreConnector(config)
else:
raise ValueError(
f"Unsupported storage type: {config.kvcache_storage_backend}. "
f"Supported types: mooncake, attention_store"
)
# Attempt connection. A failed connect() is not fatal: the connector is
# still returned (is_connected() reports False) so the caller can retry
# connect() later.
if connector is not None:
connector.connect()
return connector
def _parse_storage_config(config: Any) -> tuple[Optional[str], dict[str, Any]]:
"""
Parse storage configuration from various input types.
Args:
config: Configuration object (CacheConfig, Dict, or StorageConfig)
Returns:
Tuple of (storage_type, backend_config)
"""
storage_type = None
backend_config: Dict[str, Any] = {}
# Handle CacheConfig
if hasattr(config, "cache_config") and config.cache_config is not None:
cache_config = config.cache_config
# Get storage type from cache_config
if hasattr(cache_config, "kvcache_storage_backend"):
storage_backend = cache_config.kvcache_storage_backend
if storage_backend:
storage_type = _normalize_storage_type(storage_backend)
# Extract backend-specific configuration
if hasattr(cache_config, "kvcache_storage_config"):
backend_config = cache_config.kvcache_storage_config or {}
# Handle dict config
elif isinstance(config, dict):
if "storage_type" in config:
storage_type = _normalize_storage_type(config["storage_type"])
# Copy other keys as backend config
backend_config = {k: v for k, v in config.items() if k != "storage_type"}
elif "kvcache_storage_backend" in config:
storage_type = _normalize_storage_type(config["kvcache_storage_backend"])
backend_config = config.get("kvcache_storage_config", {})
# Handle StorageConfig dataclass
elif hasattr(config, "storage_type"):
storage_type = config.storage_type
backend_config = {
"storage_path": getattr(config, "storage_path", ""),
"max_size_bytes": getattr(config, "max_size_bytes", 0),
"enable_compression": getattr(config, "enable_compression", False),
"compression_algorithm": getattr(config, "compression_algorithm", "lz4"),
"connection_timeout": getattr(config, "connection_timeout", 30.0),
"read_timeout": getattr(config, "read_timeout", 60.0),
"write_timeout": getattr(config, "write_timeout", 60.0),
"extra_config": getattr(config, "extra_config", {}),
}
return storage_type, backend_config
def _normalize_storage_type(storage_type: Any) -> Optional[str]:
"""
Normalize storage type to lowercase string.
Args:
storage_type: Storage type (enum, string, etc.)
Returns:
Normalized storage type string
"""
if storage_type is None:
return None
# Handle enum
if isinstance(storage_type, StorageType):
return storage_type.value
# Handle string
if isinstance(storage_type, str):
return storage_type.lower()
# Handle other types
return str(storage_type).lower()
__all__ = [
"StorageScheduler",
"StorageConnector",
"create_storage_scheduler",
"create_storage_connector",
]
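`_normalize_storage_type` accepts an enum, a string, or any other value and folds them to a lowercase string. The same normalization can be sketched stand-alone; the `StorageType` enum below is a stand-in for `..metadata.StorageType`, with illustrative members:

```python
from enum import Enum
from typing import Any, Optional


class StorageType(Enum):
    """Stand-in for ..metadata.StorageType."""
    MOONCAKE = "mooncake"
    ATTENTION_STORE = "attention_store"


def normalize_storage_type(storage_type: Any) -> Optional[str]:
    """Normalize enum / string / other values to a lowercase string."""
    if storage_type is None:
        return None
    if isinstance(storage_type, StorageType):
        return storage_type.value  # enum carries the canonical name
    if isinstance(storage_type, str):
        return storage_type.lower()
    return str(storage_type).lower()  # last-resort stringification


print(normalize_storage_type(StorageType.MOONCAKE))  # mooncake
print(normalize_storage_type("Mooncake"))            # mooncake
print(normalize_storage_type(None))                  # None
```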
@@ -0,0 +1,22 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from .connector import AttnStoreConnector, AttnStoreScheduler
__all__ = [
"AttnStoreScheduler",
"AttnStoreConnector",
]
@@ -0,0 +1,140 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Any, Dict, List, Optional
from ..base import StorageConnector, StorageScheduler
class AttnStoreScheduler(StorageScheduler):
"""
AttnStore scheduler for Scheduler process.
Provides query operations for AttnStore system.
"""
def __init__(self, config: Optional[Dict[str, Any]] = None):
"""
Initialize AttnStore scheduler.
Args:
config: Configuration with keys:
- store_path: Base path for AttnStore
- cache_size: Cache size in bytes
"""
super().__init__(config)
def connect(self) -> bool:
"""Connect to AttnStore."""
try:
# Placeholder implementation
self._connected = True
return True
except Exception:
self._connected = False
return False
def disconnect(self) -> None:
"""Disconnect from AttnStore."""
self._connected = False
def exists(self, key: str) -> bool:
"""Check if key exists in AttnStore."""
if not self._connected:
return False
# Placeholder implementation
return False
def query(self, keys: List[str]) -> Dict[str, bool]:
"""Query multiple keys for existence."""
if not self._connected:
return {k: False for k in keys}
# Placeholder implementation
return {k: False for k in keys}
def get_metadata(self, key: str) -> Optional[Dict[str, Any]]:
"""Get metadata for a key."""
if not self._connected:
return None
# Placeholder implementation
return None
def list_keys(self, prefix: str = "") -> List[str]:
"""List keys with a given prefix."""
if not self._connected:
return []
# Placeholder implementation
return []
class AttnStoreConnector(StorageConnector):
"""
AttnStore connector for Worker process.
Provides data transfer operations for AttnStore system.
"""
def __init__(self, config: Optional[Dict[str, Any]] = None):
"""
Initialize AttnStore connector.
Args:
config: Configuration with keys:
- store_path: Base path for AttnStore
- transfer_threads: Number of transfer threads
"""
super().__init__(config)
def connect(self) -> bool:
"""Connect to AttnStore."""
try:
self._connected = True
return True
except Exception:
self._connected = False
return False
def disconnect(self) -> None:
"""Disconnect from AttnStore."""
self._connected = False
def get(self, key: str, dst_buffer: Any) -> bool:
"""Get data from AttnStore."""
if not self._connected:
return False
# Placeholder implementation
return False
def set(self, key: str, src_buffer: Any, size: int) -> bool:
"""Set data in AttnStore."""
if not self._connected:
return False
# Placeholder implementation
return False
def delete(self, key: str) -> bool:
"""Delete data from AttnStore."""
if not self._connected:
return False
# Placeholder implementation
return False
def clear(self, prefix: str = "") -> int:
"""Clear data from AttnStore."""
if not self._connected:
return 0
# Placeholder implementation
return 0
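Every placeholder method above repeats the same `if not self._connected` guard. One way to factor that out, once the real backend lands, is a small decorator; a sketch under the assumption that such a refactor is wanted (`requires_connection` and `Demo` are illustrative, not part of the module):

```python
from functools import wraps
from typing import Any, Callable


def requires_connection(default: Any) -> Callable:
    """Short-circuit a method to `default` when the object is not connected."""
    def deco(fn: Callable) -> Callable:
        @wraps(fn)
        def wrapper(self, *args, **kwargs):
            if not getattr(self, "_connected", False):
                # Callable defaults (e.g. dict, list) produce a fresh value.
                return default() if callable(default) else default
            return fn(self, *args, **kwargs)
        return wrapper
    return deco


class Demo:
    def __init__(self):
        self._connected = False

    @requires_connection(default=False)
    def exists(self, key: str) -> bool:
        return True  # pretend the backend lookup succeeded


d = Demo()
print(d.exists("k"))  # False: guard fires while disconnected
d._connected = True
print(d.exists("k"))  # True
```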
@@ -0,0 +1,218 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import threading
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
class StorageScheduler(ABC):
"""
Abstract base class for storage scheduler operations.
Used by CacheManager (Scheduler process) to query storage
existence and metadata without performing actual data transfer.
"""
def __init__(self, config: Optional[Dict[str, Any]] = None):
"""
Initialize the storage scheduler.
Args:
config: Storage configuration
"""
self.config = config or {}
self._lock = threading.RLock()
self._connected = False
@abstractmethod
def connect(self) -> bool:
"""
Connect to the storage backend.
Returns:
True if connection was successful
"""
pass
@abstractmethod
def disconnect(self) -> None:
"""Disconnect from the storage backend."""
pass
@abstractmethod
def exists(self, key: str) -> bool:
"""
Check if a key exists in storage.
Args:
key: Storage key to check
Returns:
True if key exists
"""
pass
@abstractmethod
def query(self, keys: List[str]) -> Dict[str, bool]:
"""
Query multiple keys for existence.
Args:
keys: List of keys to query
Returns:
Dictionary mapping keys to existence status
"""
pass
@abstractmethod
def get_metadata(self, key: str) -> Optional[Dict[str, Any]]:
"""
Get metadata for a key.
Args:
key: Storage key
Returns:
Metadata dictionary or None if not found
"""
pass
@abstractmethod
def list_keys(self, prefix: str = "") -> List[str]:
"""
List keys with a given prefix.
Args:
prefix: Key prefix to filter
Returns:
List of matching keys
"""
pass
def is_connected(self) -> bool:
"""Check if connected to storage."""
return self._connected
def get_stats(self) -> Dict[str, Any]:
"""Get storage statistics."""
return {
"connected": self._connected,
"config": self.config,
}
class StorageConnector(ABC):
"""
Abstract base class for storage connector operations.
Used by CacheController (Worker process) to perform actual
data transfer operations with the storage backend.
"""
def __init__(self, config: Optional[Dict[str, Any]] = None):
"""
Initialize the storage connector.
Args:
config: Storage configuration
"""
self.config = config or {}
self._lock = threading.RLock()
self._connected = False
@abstractmethod
def connect(self) -> bool:
"""
Connect to the storage backend.
Returns:
True if connection was successful
"""
pass
@abstractmethod
def disconnect(self) -> None:
"""Disconnect from the storage backend."""
pass
@abstractmethod
def get(self, key: str, dst_buffer: Any) -> bool:
"""
Get data from storage.
Args:
key: Storage key
dst_buffer: Destination buffer to write data
Returns:
True if get was successful
"""
pass
@abstractmethod
def set(self, key: str, src_buffer: Any, size: int) -> bool:
"""
Set data in storage.
Args:
key: Storage key
src_buffer: Source buffer to read data from
size: Size of data in bytes
Returns:
True if set was successful
"""
pass
@abstractmethod
def delete(self, key: str) -> bool:
"""
Delete data from storage.
Args:
key: Storage key to delete
Returns:
True if deletion was successful
"""
pass
@abstractmethod
def clear(self, prefix: str = "") -> int:
"""
Clear data from storage.
Args:
prefix: Key prefix to clear (empty for all)
Returns:
Number of keys cleared
"""
pass
def is_connected(self) -> bool:
"""Check if connected to storage."""
return self._connected
def get_stats(self) -> Dict[str, Any]:
"""Get connector statistics."""
return {
"connected": self._connected,
"config": self.config,
}
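The `StorageConnector` ABC can be satisfied by a trivial in-memory backend, which is handy for unit-testing callers without a real storage service. A self-contained sketch; the stand-in ABC mirrors the class above, and `InMemoryConnector` is illustrative, not part of the module:

```python
import threading
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional


class StorageConnector(ABC):
    """Minimal stand-in for the ABC defined above (same method set)."""

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        self.config = config or {}
        self._lock = threading.RLock()
        self._connected = False

    @abstractmethod
    def connect(self) -> bool: ...
    @abstractmethod
    def disconnect(self) -> None: ...
    @abstractmethod
    def get(self, key: str, dst_buffer: Any) -> bool: ...
    @abstractmethod
    def set(self, key: str, src_buffer: Any, size: int) -> bool: ...
    @abstractmethod
    def delete(self, key: str) -> bool: ...
    @abstractmethod
    def clear(self, prefix: str = "") -> int: ...


class InMemoryConnector(StorageConnector):
    """Dict-backed connector: get() copies stored bytes into dst_buffer."""

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        super().__init__(config)
        self._store: Dict[str, bytes] = {}

    def connect(self) -> bool:
        self._connected = True
        return True

    def disconnect(self) -> None:
        self._connected = False

    def get(self, key: str, dst_buffer: bytearray) -> bool:
        with self._lock:
            data = self._store.get(key)
            if data is None:
                return False
            dst_buffer[: len(data)] = data
            return True

    def set(self, key: str, src_buffer: bytes, size: int) -> bool:
        with self._lock:
            self._store[key] = bytes(src_buffer[:size])
            return True

    def delete(self, key: str) -> bool:
        with self._lock:
            return self._store.pop(key, None) is not None

    def clear(self, prefix: str = "") -> int:
        with self._lock:
            keys = [k for k in self._store if k.startswith(prefix)]
            for k in keys:
                del self._store[k]
            return len(keys)


conn = InMemoryConnector()
conn.connect()
conn.set("kv/0", b"abcd", 4)
buf = bytearray(4)
ok = conn.get("kv/0", buf)
cleared = conn.clear("kv/")
print(ok, bytes(buf), cleared)  # True b'abcd' 1
```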
@@ -0,0 +1,22 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from .connector import MooncakeStorageConnector, MooncakeStorageScheduler
__all__ = [
"MooncakeStorageScheduler",
"MooncakeStorageConnector",
]
@@ -0,0 +1,168 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Any, Dict, List, Optional
from ..base import StorageConnector, StorageScheduler
class MooncakeStorageScheduler(StorageScheduler):
"""
Mooncake storage scheduler for Scheduler process.
Provides query operations for Mooncake distributed storage.
"""
def __init__(self, config: Optional[Dict[str, Any]] = None):
"""
Initialize Mooncake storage scheduler.
Args:
config: Configuration with keys:
- server_addr: Mooncake server address
- namespace: Storage namespace
- timeout: Connection timeout
"""
super().__init__(config)
self._client = None
def connect(self) -> bool:
"""Connect to Mooncake storage."""
try:
# Initialize Mooncake client
# This would be implemented with actual Mooncake SDK
# import mooncake
# self._client = mooncake.Client(**self.config)
self._connected = True
return True
except Exception:
self._connected = False
return False
def disconnect(self) -> None:
"""Disconnect from Mooncake storage."""
self._client = None
self._connected = False
def exists(self, key: str) -> bool:
"""Check if key exists in Mooncake storage."""
if not self._connected or self._client is None:
return False
# Placeholder implementation
# return self._client.exists(key)
return False
def query(self, keys: List[str]) -> Dict[str, bool]:
"""Query multiple keys for existence."""
if not self._connected or self._client is None:
return {k: False for k in keys}
# Placeholder implementation
# return self._client.batch_exists(keys)
return {k: False for k in keys}
def get_metadata(self, key: str) -> Optional[Dict[str, Any]]:
"""Get metadata for a key."""
if not self._connected or self._client is None:
return None
# Placeholder implementation
# return self._client.get_metadata(key)
return None
def list_keys(self, prefix: str = "") -> List[str]:
"""List keys with a given prefix."""
if not self._connected or self._client is None:
return []
# Placeholder implementation
# return self._client.list_keys(prefix)
return []
class MooncakeStorageConnector(StorageConnector):
"""
Mooncake storage connector for Worker process.
Provides data transfer operations for Mooncake distributed storage.
"""
def __init__(self, config: Optional[Dict[str, Any]] = None):
"""
Initialize Mooncake storage connector.
Args:
config: Configuration with keys:
- server_addr: Mooncake server address
- namespace: Storage namespace
- transfer_timeout: Transfer timeout
- buffer_size: Transfer buffer size
"""
super().__init__(config)
self._client = None
def connect(self) -> bool:
"""Connect to Mooncake storage."""
try:
# Initialize Mooncake client
# This would be implemented with actual Mooncake SDK
self._connected = True
return True
except Exception:
self._connected = False
return False
def disconnect(self) -> None:
"""Disconnect from Mooncake storage."""
self._client = None
self._connected = False
def get(self, key: str, dst_buffer: Any) -> bool:
"""Get data from Mooncake storage."""
if not self._connected or self._client is None:
return False
# Placeholder implementation
# return self._client.get(key, dst_buffer)
return False
def set(self, key: str, src_buffer: Any, size: int) -> bool:
"""Set data in Mooncake storage."""
if not self._connected or self._client is None:
return False
# Placeholder implementation
# return self._client.set(key, src_buffer, size)
return False
def delete(self, key: str) -> bool:
"""Delete data from Mooncake storage."""
if not self._connected or self._client is None:
return False
# Placeholder implementation
# return self._client.delete(key)
return False
def clear(self, prefix: str = "") -> int:
"""Clear data from Mooncake storage."""
if not self._connected or self._client is None:
return 0
# Placeholder implementation
# return self._client.clear(prefix)
return 0
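The `query()` contract above (all-False when disconnected, per-key booleans otherwise) can be exercised against a stand-in client. `FakeClient` and its `batch_exists` method below are illustrative placeholders for the real Mooncake SDK, which is only sketched in comments in this module:

```python
from typing import Dict, List


class FakeClient:
    """Stand-in for a Mooncake client; not the real SDK."""

    def __init__(self, stored):
        self._stored = set(stored)

    def batch_exists(self, keys: List[str]) -> Dict[str, bool]:
        return {k: k in self._stored for k in keys}


def query(client, connected: bool, keys: List[str]) -> Dict[str, bool]:
    # Mirrors MooncakeStorageScheduler.query: degrade to all-False
    # when there is no live connection, otherwise ask the backend.
    if not connected or client is None:
        return {k: False for k in keys}
    return client.batch_exists(keys)


client = FakeClient(["blk/0", "blk/2"])
print(query(client, True, ["blk/0", "blk/1", "blk/2"]))
print(query(None, False, ["blk/0"]))
```

The all-False fallback keeps the scheduler side non-fatal when storage is unreachable: a missing connection merely looks like a cache miss.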
@@ -0,0 +1,176 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Any, Dict, Optional
from .base import TransferConnector
def create_transfer_connector(
config: Any,
) -> Optional[TransferConnector]:
"""
Create a TransferConnector instance based on configuration.
This is a factory function that creates the appropriate TransferConnector
based on the transfer backend type specified in the configuration.
Args:
config: Configuration object, can be:
- CacheConfig: FastDeploy configuration object
- Dict: Dictionary with 'transfer_type' and backend-specific settings
Returns:
TransferConnector instance if successful, None otherwise
Example:
# Using CacheConfig
connector = create_transfer_connector(fd_config)
# Using dict config
config = {
'transfer_type': 'rdma',
'device': 'mlx5_0',
'port': 1,
}
connector = create_transfer_connector(config)
"""
transfer_type = _get_transfer_type(config)
if transfer_type is None:
return None
connector: Optional[TransferConnector] = None
# Create connector based on transfer type
if transfer_type == "rdma":
from .rdma.connector import RDMAConnector
connector = RDMAConnector(_get_backend_config(config))
elif transfer_type == "ipc":
from .ipc.connector import IPCConnector
connector = IPCConnector(_get_backend_config(config))
else:
raise ValueError(f"Unsupported transfer type: {transfer_type}. Supported types: rdma, ipc")
# Attempt connection
if connector is not None:
if not connector.connect():
# Log warning but still return the connector
pass
return connector
def _get_transfer_type(config: Any) -> Optional[str]:
"""
Get transfer type from configuration.
Args:
config: Configuration object
Returns:
Transfer type string or None
"""
# Handle CacheConfig (from FDConfig)
if hasattr(config, "kvcache_transfer_backend"):
transfer_backend = config.kvcache_transfer_backend
if transfer_backend:
return _normalize_transfer_type(transfer_backend)
# Handle dict config
if isinstance(config, dict):
if "transfer_type" in config:
return _normalize_transfer_type(config["transfer_type"])
elif "kvcache_transfer_backend" in config:
return _normalize_transfer_type(config["kvcache_transfer_backend"])
# Handle object with cache_config attribute
if hasattr(config, "cache_config") and config.cache_config is not None:
cache_config = config.cache_config
if hasattr(cache_config, "kvcache_transfer_backend"):
transfer_backend = cache_config.kvcache_transfer_backend
if transfer_backend:
return _normalize_transfer_type(transfer_backend)
return None
def _get_backend_config(config: Any) -> Dict[str, Any]:
"""
Extract backend-specific configuration.
Args:
config: Configuration object
Returns:
Dictionary with backend configuration
"""
backend_config: Dict[str, Any] = {}
# Handle CacheConfig
if hasattr(config, "kvcache_transfer_config"):
backend_config = config.kvcache_transfer_config or {}
# Handle dict config
elif isinstance(config, dict):
if "transfer_config" in config:
backend_config = config["transfer_config"]
elif "kvcache_transfer_config" in config:
backend_config = config["kvcache_transfer_config"]
else:
# Copy all keys except transfer_type
backend_config = {
k: v for k, v in config.items() if k not in ("transfer_type", "kvcache_transfer_backend")
}
# Handle object with cache_config attribute
if hasattr(config, "cache_config") and config.cache_config is not None:
cache_config = config.cache_config
if hasattr(cache_config, "kvcache_transfer_config"):
backend_config = cache_config.kvcache_transfer_config or {}
return backend_config
def _normalize_transfer_type(transfer_type: Any) -> Optional[str]:
"""
Normalize transfer type to lowercase string.
Args:
transfer_type: Transfer type (enum, string, etc.)
Returns:
Normalized transfer type string
"""
if transfer_type is None:
return None
# Handle string
if isinstance(transfer_type, str):
return transfer_type.lower()
# Handle enum-like objects carrying a string value
value = getattr(transfer_type, "value", None)
if isinstance(value, str):
return value.lower()
# Fallback: stringify
return str(transfer_type).lower()
__all__ = [
"TransferConnector",
"create_transfer_connector",
]
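The factory's resolution order (attribute, then dict keys, then nested `cache_config`) can be sketched standalone. The functions below mirror `_get_transfer_type` and `_normalize_transfer_type` for illustration only; they are not imports from the module:

```python
from typing import Any, Optional


def normalize(transfer_type: Any) -> Optional[str]:
    # Mirrors _normalize_transfer_type: lowercase string form.
    if transfer_type is None:
        return None
    if isinstance(transfer_type, str):
        return transfer_type.lower()
    return str(transfer_type).lower()


def resolve_transfer_type(config: Any) -> Optional[str]:
    # Mirrors _get_transfer_type's lookup order:
    # 1) config.kvcache_transfer_backend attribute
    # 2) dict keys 'transfer_type' / 'kvcache_transfer_backend'
    # 3) nested config.cache_config.kvcache_transfer_backend
    backend = getattr(config, "kvcache_transfer_backend", None)
    if backend:
        return normalize(backend)
    if isinstance(config, dict):
        if "transfer_type" in config:
            return normalize(config["transfer_type"])
        if "kvcache_transfer_backend" in config:
            return normalize(config["kvcache_transfer_backend"])
    cache_config = getattr(config, "cache_config", None)
    if cache_config is not None:
        backend = getattr(cache_config, "kvcache_transfer_backend", None)
        if backend:
            return normalize(backend)
    return None


print(resolve_transfer_type({"transfer_type": "RDMA", "device": "mlx5_0"}))  # rdma
print(resolve_transfer_type({}))  # None
```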
@@ -0,0 +1,194 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import threading
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
class TransferConnector(ABC):
"""
Abstract base class for transfer connector operations.
Used by CacheController (Worker process) to perform cross-node
and cross-process data transfer operations.
"""
def __init__(self, config: Optional[Dict[str, Any]] = None):
"""
Initialize the transfer connector.
Args:
config: Transfer configuration
"""
self.config = config or {}
self._lock = threading.RLock()
self._connected = False
@abstractmethod
def connect(self) -> bool:
"""
Connect to the transfer backend.
Returns:
True if connection was successful
"""
pass
@abstractmethod
def disconnect(self) -> None:
"""Disconnect from the transfer backend."""
pass
@abstractmethod
def send(
self,
dst_addr: str,
src_buffer: Any,
size: int,
dst_offset: int = 0,
) -> bool:
"""
Send data to a remote destination.
Args:
dst_addr: Destination address
src_buffer: Source buffer to read data from
size: Size of data in bytes
dst_offset: Offset at destination
Returns:
True if send was successful
"""
pass
@abstractmethod
def recv(
self,
src_addr: str,
dst_buffer: Any,
size: int,
src_offset: int = 0,
) -> bool:
"""
Receive data from a remote source.
Args:
src_addr: Source address
dst_buffer: Destination buffer to write data
size: Size of data in bytes
src_offset: Offset at source
Returns:
True if receive was successful
"""
pass
@abstractmethod
def send_async(
self,
dst_addr: str,
src_buffer: Any,
size: int,
dst_offset: int = 0,
) -> Any:
"""
Asynchronously send data to a remote destination.
Args:
dst_addr: Destination address
src_buffer: Source buffer to read data from
size: Size of data in bytes
dst_offset: Offset at destination
Returns:
Handle for tracking the async operation
"""
pass
@abstractmethod
def recv_async(
self,
src_addr: str,
dst_buffer: Any,
size: int,
src_offset: int = 0,
) -> Any:
"""
Asynchronously receive data from a remote source.
Args:
src_addr: Source address
dst_buffer: Destination buffer to write data
size: Size of data in bytes
src_offset: Offset at source
Returns:
Handle for tracking the async operation
"""
pass
@abstractmethod
def wait(self, handle: Any, timeout: float = -1) -> bool:
"""
Wait for an async operation to complete.
Args:
handle: Handle from send_async or recv_async
timeout: Timeout in seconds (-1 for infinite)
Returns:
True if operation completed successfully
"""
pass
@abstractmethod
def register_buffer(self, buffer: Any, addr: str) -> bool:
"""
Register a buffer for RDMA operations.
Args:
buffer: Buffer to register
addr: Address to associate with buffer
Returns:
True if registration was successful
"""
pass
@abstractmethod
def unregister_buffer(self, addr: str) -> bool:
"""
Unregister a buffer.
Args:
addr: Address of buffer to unregister
Returns:
True if unregistration was successful
"""
pass
def is_connected(self) -> bool:
"""Check if connected to transfer backend."""
return self._connected
def get_stats(self) -> Dict[str, Any]:
"""Get connector statistics."""
return {
"connected": self._connected,
"config": self.config,
}
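A minimal concrete implementation helps show what the abstract contract requires. `LoopbackConnector` below is a hypothetical in-memory test double, not part of FastDeploy: buffers are plain bytearrays registered under string addresses, every transfer completes synchronously, and `wait()` just inspects the handle:

```python
import threading
from typing import Any, Dict, Optional


class LoopbackConnector:
    """In-memory stand-in implementing the TransferConnector surface."""

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        self.config = config or {}
        self._lock = threading.RLock()
        self._connected = False
        self._buffers: Dict[str, bytearray] = {}

    def connect(self) -> bool:
        self._connected = True
        return True

    def disconnect(self) -> None:
        self._buffers.clear()
        self._connected = False

    def register_buffer(self, buffer: bytearray, addr: str) -> bool:
        with self._lock:
            self._buffers[addr] = buffer
        return True

    def send(self, dst_addr: str, src_buffer: bytes, size: int, dst_offset: int = 0) -> bool:
        if not self._connected or dst_addr not in self._buffers:
            return False
        self._buffers[dst_addr][dst_offset:dst_offset + size] = src_buffer[:size]
        return True

    def recv(self, src_addr: str, dst_buffer: bytearray, size: int, src_offset: int = 0) -> bool:
        if not self._connected or src_addr not in self._buffers:
            return False
        dst_buffer[:size] = self._buffers[src_addr][src_offset:src_offset + size]
        return True

    def wait(self, handle: Any, timeout: float = -1) -> bool:
        # Loopback transfers are already complete; report the handle's result.
        return bool(handle and handle.get("success"))


conn = LoopbackConnector()
conn.connect()
conn.register_buffer(bytearray(8), "blk0")
conn.send("blk0", b"kv-cache", 8)
out = bytearray(8)
conn.recv("blk0", out, 8)
print(bytes(out))  # b'kv-cache'
```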
@@ -0,0 +1,21 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from .connector import IPCConnector
__all__ = [
"IPCConnector",
]
@@ -0,0 +1,201 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import mmap
import os
from typing import Any, Dict, Optional
from ..base import TransferConnector
class IPCConnector(TransferConnector):
"""
IPC connector for cross-process transfer on same node.
Uses shared memory for efficient data transfer between
processes on the same machine.
"""
def __init__(self, config: Optional[Dict[str, Any]] = None):
"""
Initialize IPC connector.
Args:
config: Configuration with keys:
- shm_path: Shared memory path prefix
- buffer_size: Default buffer size
- max_buffers: Maximum number of buffers
"""
super().__init__(config)
self._shm_buffers: Dict[str, mmap.mmap] = {}
self._shm_paths: Dict[str, str] = {}
def connect(self) -> bool:
"""Connect to IPC backend."""
try:
self._connected = True
return True
except Exception:
self._connected = False
return False
def disconnect(self) -> None:
"""Disconnect from IPC backend."""
# Clean up shared memory
for shm in self._shm_buffers.values():
try:
shm.close()
except Exception:
pass
# Remove shared memory files
for path in self._shm_paths.values():
try:
os.unlink(path)
except Exception:
pass
self._shm_buffers.clear()
self._shm_paths.clear()
self._connected = False
def send(
self,
dst_addr: str,
src_buffer: Any,
size: int,
dst_offset: int = 0,
) -> bool:
"""Send data via shared memory."""
if not self._connected:
return False
if dst_addr not in self._shm_buffers:
return False
try:
shm = self._shm_buffers[dst_addr]
shm.seek(dst_offset)
shm.write(src_buffer[:size])
return True
except Exception:
return False
def recv(
self,
src_addr: str,
dst_buffer: Any,
size: int,
src_offset: int = 0,
) -> bool:
"""Receive data via shared memory."""
if not self._connected:
return False
if src_addr not in self._shm_buffers:
return False
try:
shm = self._shm_buffers[src_addr]
shm.seek(src_offset)
data = shm.read(size)
dst_buffer[:size] = data
return True
except Exception:
return False
def send_async(
self,
dst_addr: str,
src_buffer: Any,
size: int,
dst_offset: int = 0,
) -> Any:
"""Asynchronously send data via shared memory."""
# For shared memory, async is similar to sync
success = self.send(dst_addr, src_buffer, size, dst_offset)
return {"success": success, "addr": dst_addr}
def recv_async(
self,
src_addr: str,
dst_buffer: Any,
size: int,
src_offset: int = 0,
) -> Any:
"""Asynchronously receive data via shared memory."""
# For shared memory, async is similar to sync
success = self.recv(src_addr, dst_buffer, size, src_offset)
return {"success": success, "addr": src_addr}
def wait(self, handle: Any, timeout: float = -1) -> bool:
"""Wait for IPC operation completion."""
if handle is None:
return False
return handle.get("success", False)
def register_buffer(self, buffer: Any, addr: str) -> bool:
"""Register a shared memory buffer."""
if not self._connected:
return False
try:
# Create shared memory file
shm_path = f"/dev/shm/kv_cache_{addr}"
shm_fd = os.open(shm_path, os.O_CREAT | os.O_RDWR, 0o666)
# Size the file
buffer_size = len(buffer) if hasattr(buffer, "__len__") else self.config.get("buffer_size", 1024 * 1024)
os.ftruncate(shm_fd, buffer_size)
# Map the file
shm = mmap.mmap(shm_fd, buffer_size)
os.close(shm_fd)
self._shm_buffers[addr] = shm
self._shm_paths[addr] = shm_path
return True
except Exception:
return False
def unregister_buffer(self, addr: str) -> bool:
"""Unregister a shared memory buffer."""
if addr not in self._shm_buffers:
return False
try:
self._shm_buffers[addr].close()
del self._shm_buffers[addr]
if addr in self._shm_paths:
os.unlink(self._shm_paths[addr])
del self._shm_paths[addr]
return True
except Exception:
return False
def get_stats(self) -> Dict[str, Any]:
"""Get IPC connector statistics."""
stats = super().get_stats()
stats.update(
{
"registered_buffers": len(self._shm_buffers),
"buffer_addresses": list(self._shm_buffers.keys()),
}
)
return stats
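The register/send/recv/unregister lifecycle above boils down to a file-backed `mmap` round trip. The sketch below demonstrates that mechanism standalone; a temp file stands in for the `/dev/shm` path the connector uses, so the sketch also runs on non-Linux hosts:

```python
import mmap
import os
import tempfile

# Round-trip through a file-backed mmap, the same mechanism
# register_buffer/send/recv use above.
fd, path = tempfile.mkstemp(prefix="kv_cache_demo_")
try:
    os.ftruncate(fd, 4096)          # size the region, as register_buffer does
    shm = mmap.mmap(fd, 4096)       # map it into this process
    os.close(fd)

    shm.seek(128)                   # send(): write at dst_offset
    shm.write(b"layer0-block7")

    shm.seek(128)                   # recv(): read back from src_offset
    data = shm.read(len(b"layer0-block7"))
    print(data)                     # b'layer0-block7'

    shm.close()
finally:
    os.unlink(path)                 # unregister_buffer(): remove the backing file
```

A second process mapping the same path would see the same bytes, which is what makes this usable for cross-process transfer on one node.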
@@ -0,0 +1,21 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from .connector import RDMAConnector
__all__ = [
"RDMAConnector",
]
@@ -0,0 +1,173 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Any, Dict, Optional
from ..base import TransferConnector
class RDMAConnector(TransferConnector):
"""
RDMA connector for high-performance cross-node transfer.
Uses RDMA for zero-copy, low-latency data transfer between
nodes in PD separation deployments.
"""
def __init__(self, config: Optional[Dict[str, Any]] = None):
"""
Initialize RDMA connector.
Args:
config: Configuration with keys:
- device: RDMA device name
- port: RDMA port
- max_wr: Maximum work requests
- buffer_size: Buffer size for transfers
"""
super().__init__(config)
self._pd = None # Protection domain
self._cq = None # Completion queue
self._qp = None # Queue pair
self._mr = None # Memory region
self._buffers: Dict[str, Any] = {}
def connect(self) -> bool:
"""Connect to RDMA backend."""
try:
# Initialize RDMA resources
# This would be implemented with actual RDMA libraries
# import pyverbs
# self._pd = pyverbs.PD(...)
# self._cq = pyverbs.CQ(...)
# self._qp = pyverbs.QP(...)
self._connected = True
return True
except Exception:
self._connected = False
return False
def disconnect(self) -> None:
"""Disconnect from RDMA backend."""
self._buffers.clear()
self._mr = None
self._qp = None
self._cq = None
self._pd = None
self._connected = False
def send(
self,
dst_addr: str,
src_buffer: Any,
size: int,
dst_offset: int = 0,
) -> bool:
"""Send data via RDMA write."""
if not self._connected:
return False
# Placeholder implementation
# This would use RDMA write operations
# self._qp.post_send(...)
# self._cq.poll()
return False
def recv(
self,
src_addr: str,
dst_buffer: Any,
size: int,
src_offset: int = 0,
) -> bool:
"""Receive data via RDMA read."""
if not self._connected:
return False
# Placeholder implementation
# This would use RDMA read operations
# self._qp.post_recv(...)
# self._cq.poll()
return False
def send_async(
self,
dst_addr: str,
src_buffer: Any,
size: int,
dst_offset: int = 0,
) -> Any:
"""Asynchronously send data via RDMA."""
if not self._connected:
return None
# Placeholder implementation
# Return a work request handle
return None
def recv_async(
self,
src_addr: str,
dst_buffer: Any,
size: int,
src_offset: int = 0,
) -> Any:
"""Asynchronously receive data via RDMA."""
if not self._connected:
return None
# Placeholder implementation
# Return a work request handle
return None
def wait(self, handle: Any, timeout: float = -1) -> bool:
"""Wait for RDMA operation completion."""
if not self._connected:
return False
# Placeholder implementation
# Poll completion queue for the work request
return False
def register_buffer(self, buffer: Any, addr: str) -> bool:
"""Register a buffer for RDMA operations."""
if not self._connected:
return False
try:
# Register memory region for RDMA
# self._mr = pyverbs.MR(self._pd, buffer, ...)
self._buffers[addr] = buffer
return True
except Exception:
return False
def unregister_buffer(self, addr: str) -> bool:
"""Unregister a buffer."""
if addr in self._buffers:
del self._buffers[addr]
return True
return False
def get_stats(self) -> Dict[str, Any]:
"""Get RDMA connector statistics."""
stats = super().get_stats()
stats.update(
{
"registered_buffers": len(self._buffers),
}
)
return stats
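The `wait(handle, timeout=-1)` contract (seconds, with -1 meaning wait forever) is still a placeholder above, so here is a hedged sketch of the polling loop a real implementation might use. `poll` is a hypothetical stand-in for checking the completion queue for a posted work request:

```python
import time
from typing import Callable


def wait_for_completion(poll: Callable[[], bool], timeout: float = -1, interval: float = 0.001) -> bool:
    """Poll a completion source until done; timeout < 0 waits forever."""
    deadline = None if timeout < 0 else time.monotonic() + timeout
    while True:
        if poll():
            return True
        if deadline is not None and time.monotonic() >= deadline:
            return False
        time.sleep(interval)  # back off briefly between polls


# A fake completion that becomes ready on the third poll.
state = {"calls": 0}

def fake_poll() -> bool:
    state["calls"] += 1
    return state["calls"] >= 3

print(wait_for_completion(fake_poll, timeout=1.0))       # True
print(wait_for_completion(lambda: False, timeout=0.01))  # False
```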
@@ -0,0 +1,666 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import threading
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import paddle
from paddleformers.utils.log import logger
# Import cupy for independent CUDA stream management
try:
import cupy as cp
_HAS_CUPY = True
except ImportError:
_HAS_CUPY = False
logger.warning("cupy not available, falling back to synchronous transfers")
# Import ops for cache swap
from fastdeploy.cache_manager.ops import (
swap_cache_all_layers,
swap_cache_per_layer, # sync fallback (used when cupy is not available)
swap_cache_per_layer_async, # async per-layer op (no cudaStreamSynchronize)
)
from fastdeploy.cache_manager.v1.storage import create_storage_connector
from fastdeploy.cache_manager.v1.transfer import create_transfer_connector
if TYPE_CHECKING:
from fastdeploy.config import FDConfig
class CacheTransferManager:
"""
KV Cache Transfer Manager.
H2D (load): layer-by-layer on _input_stream, overlaps with forward compute.
D2H (evict): all-layers on _output_stream, fire-and-forget.
Data organization:
1. Name-indexed storage (_cache_kvs_map, _host_cache_kvs_map): for building layer indices
2. Layer-indexed storage (_device_key_caches, etc.): passed to swap operators
Attributes:
config: FDConfig instance.
"""
def __init__(
self,
config: "FDConfig",
local_rank: int = 0,
device_id: int = 0,
):
"""
Initialize the transfer manager.
Args:
config: FDConfig instance.
local_rank: Local rank for tensor parallel.
device_id: Device ID.
"""
self.config = config
self.cache_config = config.cache_config
self.quant_config = config.quant_config
self._local_rank = local_rank
self._device_id = device_id
self._num_layers = config.model_config.num_hidden_layers
self._cache_dtype = config.cache_config.cache_dtype
self._num_host_blocks = self.cache_config.num_cpu_blocks or 0
self._lock = threading.RLock()
# ============ Async Transfer Streams (cupy-based) ============
# Two independent CUDA streams for fully async transfer
# _input_stream: H2D transfer (load to device, layer-by-layer)
# _output_stream: D2H transfer (evict to host, all-layers)
# They run in parallel without waiting for each other
# Using cupy to avoid affecting Paddle's internal stream state
if _HAS_CUPY and paddle.is_compiled_with_cuda():
self._cupy_device_id = cp.cuda.runtime.getDevice()
logger.info(
f"[TransferManager] Creating streams: local_rank={self._local_rank}, device_id={self._device_id}, "
f"cupy_device_id={self._cupy_device_id}"
)
with cp.cuda.Device(self._cupy_device_id):
self._input_stream = cp.cuda.Stream(non_blocking=True)
self._output_stream = cp.cuda.Stream(non_blocking=True)
logger.info(
f"[TransferManager] Using cupy streams: input={self._input_stream.ptr:#x}, output={self._output_stream.ptr:#x}"
)
else:
self._input_stream = None
self._output_stream = None
logger.warning("[TransferManager] cupy not available, async transfers disabled")
# ============ KV Cache Data Storage ============
# Name-indexed storage (used to build layer-indexed structures below)
self._cache_kvs_map: Dict[str, Any] = {}
self._host_cache_kvs_map: Dict[str, Any] = {}
# Layer-indexed lists (for all-layer transfers, compatible with swap_cache_all_layers operator)
# Device cache tensors per layer (GPU)
self._device_key_caches: List[Any] = [] # key cache per layer
self._device_value_caches: List[Any] = [] # value cache per layer
self._device_key_scales: List[Any] = [] # key scales (fp8)
self._device_value_scales: List[Any] = [] # value scales (fp8)
# Host cache pointers per layer (CPU pinned memory)
self._host_key_ptrs: List[int] = [] # key host pointers
self._host_value_ptrs: List[int] = [] # value host pointers
self._host_key_scales_ptrs: List[int] = [] # key scale pointers (fp8)
self._host_value_scales_ptrs: List[int] = [] # value scale pointers (fp8)
# ============ Connectors (for future use) ============
self._storage_connector = create_storage_connector(self.cache_config)
self._transfer_connector = create_transfer_connector(self.cache_config)
# ============ Cache Map Setters ============
@property
def cache_kvs_map(self) -> Dict[str, Any]:
return self._cache_kvs_map
def set_cache_kvs_map(self, cache_kvs_map: Dict[str, Any]) -> None:
"""
Share the KV cache tensor map from CacheController.
Args:
cache_kvs_map: Dictionary mapping cache names to tensors.
Format: {
"key_caches_{layer_id}_rank{rank}.device{device}": paddle.Tensor,
"value_caches_{layer_id}_rank{rank}.device{device}": paddle.Tensor,
"key_cache_scales_{layer_id}_rank{rank}.device{device}": paddle.Tensor, # fp8
"value_cache_scales_{layer_id}_rank{rank}.device{device}": paddle.Tensor, # fp8
...
}
"""
with self._lock:
self._cache_kvs_map = cache_kvs_map
self._build_device_layer_indices()
def _build_device_layer_indices(self) -> None:
"""Build layer-indexed Device cache lists from _cache_kvs_map."""
# Reset first; an empty map simply leaves the layer lists empty
self._device_key_caches = []
self._device_value_caches = []
self._device_key_scales = []
self._device_value_scales = []
if not self._cache_kvs_map:
return
for layer_idx in range(self._num_layers):
key_name = f"key_caches_{layer_idx}_rank{self._local_rank}.device{self._device_id}"
val_name = f"value_caches_{layer_idx}_rank{self._local_rank}.device{self._device_id}"
key_scale_name = f"key_cache_scales_{layer_idx}_rank{self._local_rank}.device{self._device_id}"
val_scale_name = f"value_cache_scales_{layer_idx}_rank{self._local_rank}.device{self._device_id}"
self._device_key_caches.append(self._cache_kvs_map.get(key_name))
self._device_value_caches.append(self._cache_kvs_map.get(val_name))
if self._is_fp8_quantization():
self._device_key_scales.append(self._cache_kvs_map.get(key_scale_name))
self._device_value_scales.append(self._cache_kvs_map.get(val_scale_name))
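The name-keyed to layer-indexed rebuild above can be exercised standalone. Stand-in strings replace the paddle tensors here; the key format matches the one documented in `set_cache_kvs_map`:

```python
# Standalone sketch of _build_device_layer_indices: rebuild per-layer
# lists from the name-keyed map (stand-in strings instead of tensors).
num_layers, rank, device = 3, 0, 0
cache_kvs_map = {}
for layer in range(num_layers):
    cache_kvs_map[f"key_caches_{layer}_rank{rank}.device{device}"] = f"K{layer}"
    cache_kvs_map[f"value_caches_{layer}_rank{rank}.device{device}"] = f"V{layer}"

key_caches = [
    cache_kvs_map.get(f"key_caches_{i}_rank{rank}.device{device}")
    for i in range(num_layers)
]
value_caches = [
    cache_kvs_map.get(f"value_caches_{i}_rank{rank}.device{device}")
    for i in range(num_layers)
]
print(key_caches)    # ['K0', 'K1', 'K2']
print(value_caches)  # ['V0', 'V1', 'V2']
```

Missing names come back as `None` via `.get()`, so a layer with no registered tensor is visible as a hole in the list rather than an exception.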
@property
def host_cache_kvs_map(self) -> Dict[str, Any]:
return self._host_cache_kvs_map
def set_host_cache_kvs_map(self, host_cache_kvs_map: Dict[str, Any]) -> None:
"""
Share the Host KV cache tensor map from CacheController.
Args:
host_cache_kvs_map: Dictionary mapping cache names to Host pointers (int).
Format: {
"key_caches_{layer_id}_rank{rank}.device{device}": pointer (int),
...
}
"""
with self._lock:
self._host_cache_kvs_map = host_cache_kvs_map
self._build_host_layer_indices()
def _build_host_layer_indices(self) -> None:
"""Build layer-indexed Host pointer lists from _host_cache_kvs_map."""
# Reset first so a cleared host map does not leave stale pointers
self._host_key_ptrs = []
self._host_value_ptrs = []
self._host_key_scales_ptrs = []
self._host_value_scales_ptrs = []
if self._num_host_blocks <= 0 or self._num_layers == 0 or not self._host_cache_kvs_map:
return
for layer_idx in range(self._num_layers):
key_name = f"key_caches_{layer_idx}_rank{self._local_rank}.device{self._device_id}"
val_name = f"value_caches_{layer_idx}_rank{self._local_rank}.device{self._device_id}"
key_scale_name = f"key_cache_scales_{layer_idx}_rank{self._local_rank}.device{self._device_id}"
val_scale_name = f"value_cache_scales_{layer_idx}_rank{self._local_rank}.device{self._device_id}"
self._host_key_ptrs.append(self._host_cache_kvs_map.get(key_name, 0))
self._host_value_ptrs.append(self._host_cache_kvs_map.get(val_name, 0))
if self._is_fp8_quantization():
self._host_key_scales_ptrs.append(self._host_cache_kvs_map.get(key_scale_name, 0))
self._host_value_scales_ptrs.append(self._host_cache_kvs_map.get(val_scale_name, 0))
# ============ Metadata Properties ============
def _get_kv_cache_quant_type(self) -> Optional[str]:
"""Get KV cache quantization type."""
if (
self.quant_config
and hasattr(self.quant_config, "kv_cache_quant_type")
and self.quant_config.kv_cache_quant_type is not None
):
return self.quant_config.kv_cache_quant_type
return None
def _is_fp8_quantization(self, quant_type: Optional[str] = None) -> bool:
"""Check if using fp8 quantization."""
if quant_type is None:
quant_type = self._get_kv_cache_quant_type()
return quant_type == "block_wise_fp8"
@property
def num_layers(self) -> int:
return self._num_layers
@property
def local_rank(self) -> int:
return self._local_rank
@property
def device_id(self) -> int:
return self._device_id
@property
def cache_dtype(self) -> str:
return self._cache_dtype
@property
def has_cache_scale(self) -> bool:
"""Check if cache has scale tensors (fp8)."""
return self._is_fp8_quantization()
@property
def num_host_blocks(self) -> int:
return self._num_host_blocks
# ============ Layer Indexed Access ============
def get_device_key_cache(self, layer_idx: int) -> Optional[Any]:
"""Get Device key cache tensor for a specific layer."""
if 0 <= layer_idx < len(self._device_key_caches):
return self._device_key_caches[layer_idx]
return None
def get_device_value_cache(self, layer_idx: int) -> Optional[Any]:
"""Get Device value cache tensor for a specific layer."""
if 0 <= layer_idx < len(self._device_value_caches):
return self._device_value_caches[layer_idx]
return None
def get_host_key_ptr(self, layer_idx: int) -> int:
"""Get Host key cache pointer for a specific layer."""
if self._num_host_blocks <= 0:
return 0
if 0 <= layer_idx < len(self._host_key_ptrs):
return self._host_key_ptrs[layer_idx]
return 0
def get_host_value_ptr(self, layer_idx: int) -> int:
"""Get Host value cache pointer for a specific layer."""
if self._num_host_blocks <= 0:
return 0
if 0 <= layer_idx < len(self._host_value_ptrs):
return self._host_value_ptrs[layer_idx]
return 0
# ============ Internal Sync Fallbacks (used when cupy not available) ============
def _swap_all_layers(
self,
device_block_ids: List[int],
host_block_ids: List[int],
mode: int,
) -> bool:
"""
Synchronous all-layer transfer fallback (used when cupy streams unavailable).
Args:
device_block_ids: Device block IDs to swap.
host_block_ids: Host block IDs to swap.
mode: 0=DeviceHost (evict), 1=HostDevice (load).
"""
if self._num_host_blocks <= 0:
return False
try:
swap_cache_all_layers(
self._device_key_caches,
self._host_key_ptrs,
self._num_host_blocks,
device_block_ids,
host_block_ids,
self._device_id,
mode,
)
swap_cache_all_layers(
self._device_value_caches,
self._host_value_ptrs,
self._num_host_blocks,
device_block_ids,
host_block_ids,
self._device_id,
mode,
)
if self._is_fp8_quantization() and self._device_key_scales and self._host_key_scales_ptrs and self._device_value_scales and self._host_value_scales_ptrs:
swap_cache_all_layers(
self._device_key_scales,
self._host_key_scales_ptrs,
self._num_host_blocks,
device_block_ids,
host_block_ids,
self._device_id,
mode,
)
swap_cache_all_layers(
self._device_value_scales,
self._host_value_scales_ptrs,
self._num_host_blocks,
device_block_ids,
host_block_ids,
self._device_id,
mode,
)
return True
except Exception:
import traceback
traceback.print_exc()
return False
def _swap_single_layer(
self,
layer_idx: int,
device_block_ids: List[int],
host_block_ids: List[int],
mode: int,
) -> bool:
"""
Synchronous single-layer transfer fallback (used when cupy streams unavailable).
Args:
layer_idx: Layer index to transfer.
device_block_ids: Device block IDs to swap.
host_block_ids: Host block IDs to swap.
mode: 0=DeviceHost (evict), 1=HostDevice (load).
"""
if self._num_host_blocks <= 0:
return False
if not device_block_ids or not host_block_ids:
return False
if len(device_block_ids) != len(host_block_ids):
return False
try:
key_cache = self.get_device_key_cache(layer_idx)
value_cache = self.get_device_value_cache(layer_idx)
if key_cache is None or value_cache is None:
return False
key_ptr = self.get_host_key_ptr(layer_idx)
value_ptr = self.get_host_value_ptr(layer_idx)
if key_ptr == 0 or value_ptr == 0:
return False
swap_cache_per_layer(
key_cache,
key_ptr,
self._num_host_blocks,
device_block_ids,
host_block_ids,
self._device_id,
mode,
)
swap_cache_per_layer(
value_cache,
value_ptr,
self._num_host_blocks,
device_block_ids,
host_block_ids,
self._device_id,
mode,
)
return True
except Exception:
import traceback
traceback.print_exc()
return False
# ============ Async Transfer Methods ============
def _swap_all_layers_async(
self,
device_block_ids: List[int],
host_block_ids: List[int],
mode: int,
) -> bool:
"""
Async all-layer transfer on dedicated stream.
D2H uses _output_stream (fire-and-forget).
H2D uses _input_stream (but H2D always goes through _swap_single_layer_async).
Falls back to _swap_all_layers if cupy not available.
Args:
device_block_ids: Device block IDs to swap.
host_block_ids: Host block IDs to swap.
mode: 0=DeviceHost (evict), 1=HostDevice (load).
"""
if self._num_host_blocks <= 0:
return False
if self._input_stream is None or self._output_stream is None:
return self._swap_all_layers(device_block_ids, host_block_ids, mode)
stream = self._output_stream if mode == 0 else self._input_stream
try:
logger.debug(
f"[TransferManager] _swap_all_layers_async: local_rank={self._local_rank}, device_id={self._device_id}, "
f"cupy_device_id={self._cupy_device_id}, stream_device={stream.device_id}, mode={mode}"
)
with cp.cuda.Device(self._cupy_device_id):
with stream:
swap_cache_all_layers(
self._device_key_caches,
self._host_key_ptrs,
self._num_host_blocks,
device_block_ids,
host_block_ids,
self._device_id,
mode,
)
swap_cache_all_layers(
self._device_value_caches,
self._host_value_ptrs,
self._num_host_blocks,
device_block_ids,
host_block_ids,
self._device_id,
mode,
)
if self._is_fp8_quantization() and self._device_key_scales and self._host_key_scales_ptrs and self._device_value_scales and self._host_value_scales_ptrs:
swap_cache_all_layers(
self._device_key_scales,
self._host_key_scales_ptrs,
self._num_host_blocks,
device_block_ids,
host_block_ids,
self._device_id,
mode,
)
swap_cache_all_layers(
self._device_value_scales,
self._host_value_scales_ptrs,
self._num_host_blocks,
device_block_ids,
host_block_ids,
self._device_id,
mode,
)
return True
except Exception:
import traceback
traceback.print_exc()
return False
def _swap_single_layer_async(
self,
layer_idx: int,
device_block_ids: List[int],
host_block_ids: List[int],
mode: int,
) -> bool:
"""
Async single-layer transfer on _input_stream (H2D) or _output_stream (D2H).
Falls back to _swap_single_layer if cupy not available.
Args:
layer_idx: Layer index to transfer.
device_block_ids: Device block IDs to swap.
host_block_ids: Host block IDs to swap.
mode: 0=DeviceHost (evict), 1=HostDevice (load).
"""
if self._num_host_blocks <= 0:
return False
if self._input_stream is None or self._output_stream is None:
return self._swap_single_layer(layer_idx, device_block_ids, host_block_ids, mode)
stream = self._output_stream if mode == 0 else self._input_stream
key_cache = self.get_device_key_cache(layer_idx)
value_cache = self.get_device_value_cache(layer_idx)
if key_cache is None or value_cache is None:
return False
key_ptr = self.get_host_key_ptr(layer_idx)
value_ptr = self.get_host_value_ptr(layer_idx)
if key_ptr == 0 or value_ptr == 0:
return False
try:
with cp.cuda.Device(self._cupy_device_id):
with stream:
swap_cache_per_layer_async(
key_cache,
key_ptr,
self._num_host_blocks,
device_block_ids,
host_block_ids,
self._device_id,
mode,
)
swap_cache_per_layer_async(
value_cache,
value_ptr,
self._num_host_blocks,
device_block_ids,
host_block_ids,
self._device_id,
mode,
)
return True
except Exception:
import traceback
traceback.print_exc()
return False
# ============ Public Async API ============
def evict_to_host_async(
self,
device_block_ids: List[int],
host_block_ids: List[int],
) -> bool:
"""
Async evict all layers of KV Cache from Device to Host (D2H).
Runs on _output_stream, fire-and-forget.
Args:
device_block_ids: Device block IDs to evict.
host_block_ids: Host block IDs to receive.
"""
return self._swap_all_layers_async(device_block_ids, host_block_ids, mode=0)
def load_layers_to_device_async(
self,
layer_indices: List[int],
host_block_ids: List[int],
device_block_ids: List[int],
on_layer_complete: Optional[callable] = None,
) -> bool:
"""
Async load KV Cache from Host to Device layer-by-layer (H2D).
Each layer runs on _input_stream. Overlaps with forward compute:
the callback is invoked after each layer's kernel is submitted so
the forward thread can start using that layer's data once the event fires.
Args:
layer_indices: Layer indices to load.
host_block_ids: Host block IDs to load from.
device_block_ids: Device block IDs to receive.
on_layer_complete: Optional callback(layer_idx) after each layer is submitted.
"""
if self._num_host_blocks <= 0:
return False
all_success = True
for layer_idx in layer_indices:
success = self._swap_single_layer_async(layer_idx, device_block_ids, host_block_ids, mode=1)
if not success:
all_success = False
if on_layer_complete is not None:
try:
on_layer_complete(layer_idx)
except Exception:
pass
return all_success
# ============ Stream Utilities ============
def sync_input_stream(self):
"""Wait for all pending _input_stream (H2D) transfers to complete."""
if self._input_stream is not None:
self._input_stream.synchronize()
def sync_output_stream(self):
"""Wait for all pending _output_stream (D2H) transfers to complete."""
if self._output_stream is not None:
self._output_stream.synchronize()
def record_input_stream_event(self) -> Any:
"""
Record a CUDA event on _input_stream and return it.
Used by _on_layer_complete callback in CacheController so that
LayerDoneCounter.wait_for_layer() can synchronize on the actual
H2D transfer stream rather than Paddle's default stream.
Returns:
cupy.cuda.Event if cupy streams are available, else None.
"""
if not _HAS_CUPY or self._input_stream is None:
return None
try:
with cp.cuda.Device(self._cupy_device_id):
event = cp.cuda.Event()
with self._input_stream:
event.record()
return event
except Exception as e:
logger.warning(f"[TransferManager] Failed to record input_stream event: {e}")
return None
def get_stats(self) -> Dict[str, Any]:
"""Get transfer manager statistics."""
return {
"num_layers": self._num_layers,
"local_rank": self._local_rank,
"device_id": self._device_id,
"cache_dtype": self._cache_dtype,
"num_host_blocks": self._num_host_blocks,
"has_device_cache": len(self._device_key_caches) > 0,
"has_host_cache": len(self._host_key_ptrs) > 0,
"is_fp8": self._is_fp8_quantization(),
}
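The public async API above is built for compute/transfer overlap: eviction (D2H) is fire-and-forget on the output stream, while loads (H2D) run layer by layer so the forward pass can consume layer k while layer k+1 is still in flight. A minimal, framework-free sketch of that per-layer handshake, using `threading.Event` in place of CUDA events — the names `LayerDoneCounter.mark_done` and `load_layers` are illustrative stand-ins, not FastDeploy API:

```python
import threading
from typing import Callable, List, Optional


class LayerDoneCounter:
    """Tracks which layers have finished loading; forward threads wait per layer."""

    def __init__(self, num_layers: int):
        self._events = [threading.Event() for _ in range(num_layers)]

    def mark_done(self, layer_idx: int) -> None:
        self._events[layer_idx].set()

    def wait_for_layer(self, layer_idx: int, timeout: Optional[float] = None) -> bool:
        return self._events[layer_idx].wait(timeout)


def load_layers(
    layer_indices: List[int],
    do_transfer: Callable[[int], None],
    on_layer_complete: Callable[[int], None],
) -> None:
    # Mirrors load_layers_to_device_async: submit each layer's transfer,
    # then fire the per-layer completion callback so consumers can proceed.
    for layer_idx in layer_indices:
        do_transfer(layer_idx)
        on_layer_complete(layer_idx)


counter = LayerDoneCounter(num_layers=4)
loaded = []
load_layers(
    [0, 1, 2, 3],
    do_transfer=loaded.append,          # stand-in for the H2D kernel launch
    on_layer_complete=counter.mark_done,
)
print(loaded)  # [0, 1, 2, 3]
```

In the real implementation the per-layer event is a CUDA event recorded on `_input_stream` (see `record_input_stream_event`), so waiting synchronizes with the actual transfer stream rather than a Python thread.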
+30 -1
@@ -1610,7 +1610,8 @@ class CacheConfig:
self.enable_output_caching = False
self.disable_chunked_mm_input = False
self.kvcache_storage_backend = None
-        self.write_policy = None
+        self.write_policy = "write_through_selective"
+        self.write_through_threshold = 2
self.num_cpu_blocks = None
self.use_mla_cache = envs.FD_ATTENTION_BACKEND == "MLA_ATTN"
@@ -1618,6 +1619,10 @@ class CacheConfig:
if hasattr(self, key):
setattr(self, key, value)
# ENABLE_V1_KVCACHE_MANAGER=0 uses the old cache_transfer_manager subprocess which only supports write_through.
if not envs.ENABLE_V1_KVCACHE_MANAGER:
self.write_policy = "write_through"
self.cache_queue_port = parse_ports(self.cache_queue_port)
self.rdma_comm_ports = parse_ports(self.rdma_comm_ports)
self.pd_comm_port = parse_ports(self.pd_comm_port)
@@ -1673,6 +1678,15 @@ class CacheConfig:
if self.kv_cache_ratio > 1.0:
raise ValueError("KV cache ratio must be less than 1.0. Got " f"{self.kv_cache_ratio}.")
if envs.ENABLE_V1_KVCACHE_MANAGER:
allowed_write_policies = ["write_through_selective", "write_back", "write_through"]
else:
allowed_write_policies = ["write_through"]
if self.write_policy not in allowed_write_policies:
raise ValueError(
f"Invalid write_policy: {self.write_policy!r}. " f"Expected one of {allowed_write_policies}."
)
def postprocess(self, num_total_tokens, number_of_tasks):
"""
calculate block num
@@ -2143,6 +2157,21 @@ class FDConfig:
"Static Graph does not support to be started together with RL Training, and automatically switch to dynamic graph!"
)
# Layer-by-layer swap (H2D) is always incompatible with CUDA Graph prefill capture.
# Force only decode to use CUDA Graph when host cache is configured.
if (
self.cache_config is not None
and self.cache_config.num_cpu_blocks
and self.graph_opt_config.cudagraph_only_prefill
):
original_value = self.graph_opt_config.cudagraph_only_prefill
self.graph_opt_config.cudagraph_only_prefill = False
logger.warning(
f"[CacheConfig] Layer-by-layer swap-in is incompatible "
f"with CUDA Graph prefill capture. Forcing cudagraph_only_prefill=False "
f"(only decode will use CUDA Graph). Original cudagraph_only_prefill={original_value}"
)
if (
not current_platform.is_cuda()
and not current_platform.is_maca()
+14 -3
@@ -250,9 +250,13 @@ class EngineArgs:
"""
The storage backend for kvcache storage. If set, it will use the kvcache storage backend.
"""
-    write_policy: str = "write_through"
+    write_policy: str = "write_through_selective"
"""
-    The policy of write cache to storage.
+    The policy of write cache to storage. Options: write_through (alias for write_through_selective with threshold=1), write_through_selective, write_back.
"""
write_through_threshold: int = 2
"""
The threshold of hit count for write_through_selective policy. Only effective when write_policy is write_through_selective.
"""
# System configuration parameters
@@ -1168,11 +1172,18 @@ class EngineArgs:
cache_group.add_argument(
"--write-policy",
type=str,
-            choices=["write_through"],
+            choices=["write_through", "write_through_selective", "write_back"],
default=EngineArgs.write_policy,
help="KVCache write policy",
)
cache_group.add_argument(
"--write-through-threshold",
type=int,
default=EngineArgs.write_through_threshold,
help="Hit count threshold for write_through_selective policy. Only effective when write_policy is write_through_selective.",
)
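The help text above pins down the intended semantics: `write_through` behaves like `write_through_selective` with a threshold of 1, and the threshold gates on hit count. A hedged sketch of that gating decision — the function name and exact policy logic here are illustrative assumptions, not the actual cache manager v1 code:

```python
def should_write_to_host(write_policy: str, hit_count: int, threshold: int = 2) -> bool:
    """Illustrative gate for D2H cache writes; the real policy lives in cache manager v1."""
    if write_policy == "write_through":
        # Documented as an alias for write_through_selective with threshold=1.
        return hit_count >= 1
    if write_policy == "write_through_selective":
        # Only blocks hot enough (hit_count >= threshold) are worth a host copy.
        return hit_count >= threshold
    if write_policy == "write_back":
        # Write-back defers the copy until eviction time.
        return False
    raise ValueError(f"Invalid write_policy: {write_policy!r}")


print(should_write_to_host("write_through", hit_count=1))            # True
print(should_write_to_host("write_through_selective", hit_count=1))  # False
print(should_write_to_host("write_through_selective", hit_count=2))  # True
```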
# Cluster system parameters group
system_group = parser.add_argument_group("System Configuration")
system_group.add_argument(
+67 -40
@@ -236,6 +236,11 @@ class EngineService:
self.ipc_signal_suffix = None
self.cache_manager_processes = None
if envs.ENABLE_V1_KVCACHE_MANAGER:
from fastdeploy.cache_manager.v1.cache_utils import get_request_block_hasher
self._block_hasher = get_request_block_hasher(block_size=self.cfg.cache_config.block_size)
self._finalizer = weakref.finalize(self, self._exit_sub_services)
def start(self, async_llm_pid=None):
@@ -272,7 +277,11 @@ class EngineService:
self.launch_components()
# If block number is specified and model is deployed in splitwise mode, start cache manager first
-        if not self.do_profile and self.cfg.scheduler_config.splitwise_role != "mixed":
+        if (
+            not self.do_profile
+            and self.cfg.scheduler_config.splitwise_role != "mixed"
+            and not envs.ENABLE_V1_KVCACHE_MANAGER
+        ):
device_ids = self.cfg.parallel_config.device_ids.split(",")
self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix)
@@ -304,7 +313,11 @@ class EngineService:
# and then start the cache manager
if self.do_profile:
self._stop_profile()
-        elif self.cfg.scheduler_config.splitwise_role == "mixed" and self.cfg.cache_config.enable_prefix_caching:
+        elif (
+            self.cfg.scheduler_config.splitwise_role == "mixed"
+            and self.cfg.cache_config.enable_prefix_caching
+            and not envs.ENABLE_V1_KVCACHE_MANAGER
+        ):
device_ids = self.cfg.parallel_config.device_ids.split(",")
self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix)
@@ -472,19 +485,20 @@ class EngineService:
self.cfg.parallel_config.local_engine_worker_queue_port,
)
-        if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
-            self.llm_logger.info(
-                f"Starting engine cache queue server service at {self.cfg.cache_config.local_cache_queue_port}"
-            )
-            self.cache_task_queue = EngineCacheQueue(
-                address=(self.cfg.master_ip, self.cfg.cache_config.local_cache_queue_port),
-                authkey=b"cache_queue_service",
-                is_server=True,
-                num_client=self.cfg.parallel_config.tensor_parallel_size,
-                client_id=-1,
-                local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
-            )
-            self.cfg.cache_config.local_cache_queue_port = self.cache_task_queue.get_server_port()
+        if not envs.ENABLE_V1_KVCACHE_MANAGER:
+            if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
+                self.llm_logger.info(
+                    f"Starting engine cache queue server service at {self.cfg.cache_config.local_cache_queue_port}"
+                )
+                self.cache_task_queue = EngineCacheQueue(
+                    address=(self.cfg.master_ip, self.cfg.cache_config.local_cache_queue_port),
+                    authkey=b"cache_queue_service",
+                    is_server=True,
+                    num_client=self.cfg.parallel_config.tensor_parallel_size,
+                    client_id=-1,
+                    local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
+                )
+                self.cfg.cache_config.local_cache_queue_port = self.cache_task_queue.get_server_port()
self.engine_worker_queue = EngineWorkerQueue(
address=address,
@@ -900,6 +914,10 @@ class EngineService:
task.metrics.engine_get_req_time = time.time()
trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", ""))
# cache_manager_v1 set block_hasher to request
if hasattr(self, "_block_hasher"):
task.set_block_hasher(self._block_hasher)
if self.cfg.scheduler_config.splitwise_role == "decode":
# TODO: refine scheduler to remove this limitation
# Decode will process and schedule the request sent by prefill to engine,
@@ -1064,12 +1082,12 @@ class EngineService:
if hasattr(self.resource_manager, "scheduler_unhandled_request_num"):
self.resource_manager.scheduler_unhandled_request_num = self._get_scheduler_unhandled_request_num()
# 2. Schedule requests
-        tasks, error_tasks = self.resource_manager.schedule()
+        batch_request, error_tasks = self.resource_manager.schedule()
        # 3. Send to engine
-        if tasks:
+        if len(batch_request) > 0:
            if self.cfg.scheduler_config.splitwise_role == "decode":
-                for task in tasks:
+                for task in batch_request:
if task.task_type == RequestType.PREEMPTED:
msg = f"{task.request_id} decode not enough blocks, need to be rescheduled."
self.llm_logger.error(msg)
@@ -1084,7 +1102,7 @@ class EngineService:
]
)
self.resource_manager.get_real_bsz()
-            for task in tasks:
+            for task in batch_request:
if task.task_type == RequestType.PREFILL:
rid = task.request_id.split("_")[0]
if isinstance(task, Request) and task.has_been_preempted_before:
@@ -1119,13 +1137,13 @@ class EngineService:
task.metrics.decode_inference_start_time = time.time()
elif not task.has_been_preempted_before:
task.metrics.inference_start_time = time.time()
-            self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
+            self.engine_worker_queue.put_tasks((batch_request, self.resource_manager.real_bsz))
else:
# When there are no actual tasks to schedule, send an empty task batch to EP workers.
# This keeps the EP workers' task-sync barrier from hanging.
if self.cfg.parallel_config.enable_expert_parallel:
self.engine_worker_queue.put_tasks(
-                    ([], self.resource_manager.real_bsz)
+                    (batch_request, self.resource_manager.real_bsz)
) # Empty (as idle tasks for ep)
# 4. Response error tasks
@@ -1136,7 +1154,7 @@ class EngineService:
continue
self._send_error_response(request_id, failed)
-        if not tasks and not error_tasks:
+        if len(batch_request) <= 0 and not error_tasks:
time.sleep(0.005)
except RuntimeError as e:
@@ -1428,22 +1446,25 @@ class EngineService:
self._send_error_response(req.request_id, "Request is aborted since engine is paused.")
self.scheduler.reset()
-        # pause cache transfer
-        if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
-            self.llm_logger.info("Start to pause cache transfer.")
-            pause_transfer_request = ControlRequest(
-                request_id=f"{control_request.request_id}_pause_transfer", method="pause"
-            )
-            self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, pause_transfer_request))
-            # Wait for cache_transfer responses
-            asyncio.run(
-                self._wait_for_control_responses(
-                    f"{pause_transfer_request.request_id}", 60, executors=["cache_transfer"]
-                )
-            )
-            self.llm_logger.info("Successfully paused cache transfer.")
-        self.resource_manager.cache_manager.reset()
+        if envs.ENABLE_V1_KVCACHE_MANAGER:
+            self.resource_manager.cache_manager.reset_cache()
+        else:
+            # pause cache transfer
+            if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
+                self.llm_logger.info("Start to pause cache transfer.")
+                pause_transfer_request = ControlRequest(
+                    request_id=f"{control_request.request_id}_pause_transfer", method="pause"
+                )
+                self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, pause_transfer_request))
+                # Wait for cache_transfer responses
+                asyncio.run(
+                    self._wait_for_control_responses(
+                        f"{pause_transfer_request.request_id}", 60, executors=["cache_transfer"]
+                    )
+                )
+                self.llm_logger.info("Successfully paused cache transfer.")
+            self.resource_manager.cache_manager.reset()
self.llm_logger.info("Successfully paused request generation.")
return None
@@ -1726,10 +1747,14 @@ class EngineService:
executors.add("worker")
if "kv_cache" in tags:
executors.add("worker")
-            if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
-                executors.add("cache_transfer")
-            if self.cfg.cache_config.enable_prefix_caching:
-                self.resource_manager.cache_manager.reset()
+            if envs.ENABLE_V1_KVCACHE_MANAGER:
+                if self.cfg.cache_config.enable_prefix_caching:
+                    self.resource_manager.cache_manager.reset_cache()
+            else:
+                if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
+                    executors.add("cache_transfer")
+                if self.cfg.cache_config.enable_prefix_caching:
+                    self.resource_manager.cache_manager.reset()
# Dispatch sleep request to executors
self.llm_logger.info(f"Dispatch sleep request to executors: {list(executors)}")
@@ -2543,6 +2568,8 @@ class EngineService:
self.cfg.cache_config.reset(num_gpu_blocks)
self.resource_manager.reset_cache_config(self.cfg.cache_config)
if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
if envs.ENABLE_V1_KVCACHE_MANAGER:
return
device_ids = self.cfg.parallel_config.device_ids.split(",")
self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix)
+2 -2
@@ -186,7 +186,7 @@ class LLMEngine:
if not self._stop_profile():
return False
elif self.cfg.scheduler_config.splitwise_role == "mixed" and self.cfg.cache_config.enable_prefix_caching:
-            if not current_platform.is_intel_hpu():
+            if not current_platform.is_intel_hpu() and not envs.ENABLE_V1_KVCACHE_MANAGER:
device_ids = self.cfg.parallel_config.device_ids.split(",")
self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
@@ -799,7 +799,7 @@ class LLMEngine:
self.cfg.cache_config.reset(num_gpu_blocks)
self.engine.resource_manager.reset_cache_config(self.cfg.cache_config)
if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
-            if not current_platform.is_intel_hpu():
+            if not current_platform.is_intel_hpu() and not envs.ENABLE_V1_KVCACHE_MANAGER:
device_ids = self.cfg.parallel_config.device_ids.split(",")
self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
return True
+186 -7
@@ -21,16 +21,20 @@ import time
import traceback
from dataclasses import asdict, dataclass, fields
from enum import Enum
- from typing import Any, Dict, Generic, Optional
+ from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional
from typing import TypeVar as TypingTypeVar
from typing import Union
if TYPE_CHECKING:
from fastdeploy.cache_manager.v1.metadata import MatchResult
import numpy as np
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing_extensions import TypeVar
from fastdeploy import envs
from fastdeploy.cache_manager.v1.metadata import CacheLevel, CacheSwapMetadata
from fastdeploy.engine.pooling_params import PoolingParams
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.openai.protocol import (
@@ -134,6 +138,8 @@ class Request:
# from PoolingRequest
add_special_tokens: Optional[bool] = False,
zmq_worker_pid: Optional[int] = None,
# block hasher for dynamic hash computation
block_hasher: Optional[callable] = None,
) -> None:
self.request_id = request_id
self.prompt = prompt
@@ -147,11 +153,18 @@ class Request:
self.tools = tools
# model specific token ids: end of sentence token ids
self.eos_token_ids = eos_token_ids
-        self.num_cached_tokens = 0
-        self.num_cached_blocks = 0
        self.disable_chat_template = disable_chat_template
        self.disaggregate_info = disaggregate_info
+        # prefix caching related
+        self.num_cached_tokens = 0
+        self.num_cached_blocks = 0
+        self._prompt_hashes: list[str] = []
+        self._block_hasher = block_hasher
+        self._match_result: Optional[MatchResult] = None
+        self.cache_swap_metadata: list[CacheSwapMetadata] = []
+        self.cache_evict_metadata: list[CacheSwapMetadata] = []
# speculative method in disaggregate-mode
self.draft_token_ids = draft_token_ids
@@ -224,6 +237,38 @@ class Request:
self.add_special_tokens = add_special_tokens
self.zmq_worker_pid = zmq_worker_pid
@property
def prompt_hashes(self) -> list[str]:
"""
Dynamically get prompt_hashes, automatically computing new block hashes.
When accessing this property, it checks if there are new complete blocks
that need hash computation, and if so, computes and appends them.
"""
if self._block_hasher is not None:
new_hashes = self._block_hasher(self)
if new_hashes:
self._prompt_hashes.extend(new_hashes)
return self._prompt_hashes
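`prompt_hashes` relies on a block hasher that the engine injects via `set_block_hasher` (`get_request_block_hasher` from cache_manager v1's cache_utils). The real hasher takes the request object; the simplified stand-in below works on raw token IDs to show the two properties that matter: only newly completed blocks are hashed on each call, and each block hash chains its predecessor so a hash match implies the whole prefix matches. All names and the hashing scheme here are illustrative assumptions:

```python
import hashlib
from typing import Callable, List


def make_block_hasher(block_size: int) -> Callable[[List[int], List[str]], List[str]]:
    """Return a hasher yielding hashes only for newly completed prompt blocks.

    Illustrative stand-in for get_request_block_hasher: each block hash chains
    the previous block's hash, so equal hashes imply equal full prefixes.
    """

    def hash_new_blocks(token_ids: List[int], existing: List[str]) -> List[str]:
        new_hashes = []
        parent = existing[-1] if existing else ""
        start = len(existing) * block_size  # skip blocks already hashed
        while start + block_size <= len(token_ids):
            block = token_ids[start : start + block_size]
            digest = hashlib.sha256(
                (parent + ":" + ",".join(map(str, block))).encode()
            ).hexdigest()
            new_hashes.append(digest)
            parent = digest
            start += block_size
        return new_hashes

    return hash_new_blocks


hasher = make_block_hasher(block_size=4)
hashes = hasher([1, 2, 3, 4, 5, 6, 7, 8, 9], [])   # two full blocks, one partial token
print(len(hashes))  # 2
# Incremental call: only the newly completed third block is hashed.
hashes += hasher([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], hashes)
print(len(hashes))  # 3
```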
@property
def match_result(self) -> Optional[MatchResult]:
return self._match_result
def set_block_hasher(self, block_hasher: callable):
"""Set the block hasher for dynamic hash computation."""
self._block_hasher = block_hasher
def pop_cache_swap_metadata(self) -> list[CacheSwapMetadata]:
result = self.cache_swap_metadata
self.cache_swap_metadata = []
return result
def pop_cache_evict_metadata(self) -> list[CacheSwapMetadata]:
result = self.cache_evict_metadata
self.cache_evict_metadata = []
return result
@classmethod
def _process_guided_json(cls, r: T):
guided_json_object = None
@@ -413,17 +458,30 @@ class Request:
Custom getstate method for pickle support.
Handles unpicklable attributes by filtering them from __dict__.
"""
-        # Create a filtered dictionary without problematic attributes
+        # Attributes that cannot or need not be pickled for cross-process transfer.
+        # _block_hasher: closure/callable, not picklable.
+        # _match_result: contains BlockNode tree with parent<->children circular
+        #   references, which causes RecursionError during pickling.
+        # async_process_futures: asyncio futures, not picklable.
+        _SKIP_KEYS = {"_block_hasher", "_match_result"}
        filtered_dict = {}
        for key, value in self.__dict__.items():
-            # Skip attributes that are known to contain unpicklable objects
-            if key == "async_process_futures":
+            if key in _SKIP_KEYS:
+                continue
+            elif key == "async_process_futures":
filtered_dict[key] = []
else:
filtered_dict[key] = value
return filtered_dict
def __setstate__(self, state):
self.__dict__.update(state)
# Restore fields that were excluded from pickling with safe defaults.
if "_block_hasher" not in self.__dict__:
self._block_hasher = None
if "_match_result" not in self.__dict__:
self._match_result = None
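The `__getstate__`/`__setstate__` pair above implements a common pattern: drop attributes that cannot cross a process boundary, then restore them with safe defaults on the receiving side. A self-contained illustration of the same pattern (a generic class, not the actual Request):

```python
import pickle


class Message:
    _SKIP_KEYS = {"_callback"}  # unpicklable members to drop before serialization

    def __init__(self, payload, callback=None):
        self.payload = payload
        self._callback = callback  # e.g. a closure: not picklable

    def __getstate__(self):
        # Filter out attributes known to break pickling.
        return {k: v for k, v in self.__dict__.items() if k not in self._SKIP_KEYS}

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Restore skipped fields with safe defaults.
        self.__dict__.setdefault("_callback", None)


msg = Message("hello", callback=lambda: None)
clone = pickle.loads(pickle.dumps(msg))  # plain pickle would fail on the lambda
print(clone.payload, clone._callback)  # hello None
```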
def __eq__(self, other):
"""
EQ operator.
@@ -553,6 +611,127 @@ class Request:
return hasattr(self, key)
class BatchRequest:
def __init__(self):
self.requests: list[Request] = []
self.cache_swap_metadata: Optional[CacheSwapMetadata] = None
self.cache_evict_metadata: Optional[CacheSwapMetadata] = None
def add_request(self, request):
if hasattr(request, "cache_swap_metadata") and request.cache_swap_metadata:
self.append_swap_metadata(request.pop_cache_swap_metadata())
request.cache_swap_metadata = []
if hasattr(request, "cache_evict_metadata") and request.cache_evict_metadata:
self.append_evict_metadata(request.pop_cache_evict_metadata())
request.cache_evict_metadata = []
self.requests.append(request)
def append_swap_metadata(self, metadata: List[CacheSwapMetadata]):
for meta in metadata:
if self.cache_swap_metadata:
self.cache_swap_metadata.src_block_ids.extend(meta.src_block_ids)
self.cache_swap_metadata.dst_block_ids.extend(meta.dst_block_ids)
self.cache_swap_metadata.hash_values.extend(meta.hash_values)
else:
self.cache_swap_metadata = CacheSwapMetadata(
src_block_ids=meta.src_block_ids,
dst_block_ids=meta.dst_block_ids,
src_type=CacheLevel.HOST,
dst_type=CacheLevel.DEVICE,
hash_values=meta.hash_values,
)
def append_evict_metadata(self, metadata: List[CacheSwapMetadata]):
for meta in metadata:
if self.cache_evict_metadata:
self.cache_evict_metadata.src_block_ids.extend(meta.src_block_ids)
self.cache_evict_metadata.dst_block_ids.extend(meta.dst_block_ids)
self.cache_evict_metadata.hash_values.extend(meta.hash_values)
else:
self.cache_evict_metadata = CacheSwapMetadata(
src_block_ids=meta.src_block_ids,
dst_block_ids=meta.dst_block_ids,
src_type=CacheLevel.DEVICE,
dst_type=CacheLevel.HOST,
hash_values=meta.hash_values,
)
def __repr__(self):
requests_repr = repr(self.requests)
return f"BatchRequest(requests={requests_repr}, swap_metadata={self.cache_swap_metadata}, evict_metadata={self.cache_evict_metadata})"
def __getstate__(self):
state = self.__dict__.copy()
state["requests"] = [req.__getstate__() if hasattr(req, "__getstate__") else req for req in state["requests"]]
return state
def __setstate__(self, state):
self.__dict__.update(state)
restored_requests = []
for req_data in self.requests:
if isinstance(req_data, dict):
req = Request.__new__(Request)
req.__dict__.update(req_data)
restored_requests.append(req)
else:
restored_requests.append(req_data)
self.requests = restored_requests
def __iter__(self):
for req in self.requests:
yield req
def __getitem__(self, index):
return self.requests[index]
def __len__(self):
return len(self.requests)
def append(self, batch_request: "BatchRequest"):
self.requests.extend(batch_request.requests)
if batch_request.cache_swap_metadata:
self.append_swap_metadata([batch_request.cache_swap_metadata])
if batch_request.cache_evict_metadata:
self.append_evict_metadata([batch_request.cache_evict_metadata])
def extend(self, batch_requests: list["BatchRequest"]):
for br in batch_requests:
self.append(br)
@classmethod
def from_tasks(cls, tasks: list) -> tuple["BatchRequest", list, int]:
"""Classify tasks from the engine worker queue into inference requests and control requests.
Args:
tasks: List of (payload, real_bsz) tuples from task_queue.get_tasks().
payload is one of: BatchRequest, List[Request], or [ControlRequest].
Returns:
(batch_request, control_reqs, max_occupied_batch_index)
- batch_request: merged BatchRequest containing all inference requests
- control_reqs: list of ControlRequest objects
- max_occupied_batch_index: real_bsz of the last inference task batch
"""
batch_request = cls()
control_reqs = []
max_occupied_batch_index = 0
for payload, bsz in tasks:
if len(payload) > 0 and isinstance(payload[0], ControlRequest):
control_reqs.append(payload[0])
else:
max_occupied_batch_index = int(bsz)
if isinstance(payload, cls):
batch_request.append(payload)
else:
for req in payload:
batch_request.add_request(req)
return batch_request, control_reqs, max_occupied_batch_index
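`BatchRequest.append_swap_metadata` and `append_evict_metadata` coalesce per-request transfer lists into one batch-level `CacheSwapMetadata`, so the worker can issue a single batched copy instead of one per request. A runnable sketch of that coalescing, using a simplified stand-in dataclass (`SwapMeta` is not the real `CacheSwapMetadata`):

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class SwapMeta:  # simplified stand-in for CacheSwapMetadata
    src_block_ids: List[int] = field(default_factory=list)
    dst_block_ids: List[int] = field(default_factory=list)


def merge_swap_metadata(per_request: List[SwapMeta]) -> Optional[SwapMeta]:
    """Coalesce per-request swap metadata into one batch-level transfer,
    mirroring what BatchRequest.append_swap_metadata does."""
    merged: Optional[SwapMeta] = None
    for meta in per_request:
        if merged is None:
            merged = SwapMeta(list(meta.src_block_ids), list(meta.dst_block_ids))
        else:
            merged.src_block_ids.extend(meta.src_block_ids)
            merged.dst_block_ids.extend(meta.dst_block_ids)
    return merged


merged = merge_swap_metadata([SwapMeta([0, 1], [10, 11]), SwapMeta([5], [12])])
print(merged.src_block_ids, merged.dst_block_ids)  # [0, 1, 5] [10, 11, 12]
```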
class ControlRequest:
"""A generic control request that supports method and args for control operations.
+12 -2
@@ -20,7 +20,7 @@ import time
import numpy as np
- from fastdeploy.cache_manager.prefix_cache_manager import PrefixCacheManager
+ from fastdeploy import envs
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import llm_logger
@@ -53,7 +53,17 @@ class ResourceManager:
self.max_num_seqs = max_num_seqs
self.stop_flags = [True] * max_num_seqs # flag set to true if the slot has not been taken
self.enable_prefix_cache = config.cache_config.enable_prefix_caching
-        self.cache_manager = PrefixCacheManager(config, tensor_parallel_size, splitwise_role, local_data_parallel_id)
+        self.enable_cache_manager_v1 = envs.ENABLE_V1_KVCACHE_MANAGER
+        if self.enable_cache_manager_v1:
+            from fastdeploy.cache_manager.v1 import CacheManager
+            self.cache_manager = CacheManager(config)
+        else:
+            from fastdeploy.cache_manager.prefix_cache_manager import PrefixCacheManager
+            self.cache_manager = PrefixCacheManager(
+                config, tensor_parallel_size, splitwise_role, local_data_parallel_id
+            )
self.tasks_list = [None] * max_num_seqs # task slots
self.req_dict = dict()
# current batch status of the engine
+189 -111
@@ -21,8 +21,8 @@ import traceback
from collections import deque
from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor
- from dataclasses import dataclass
- from typing import Union
+ from dataclasses import dataclass, field
+ from typing import List, Union
import numpy as np
import paddle
@@ -32,8 +32,10 @@ from fastdeploy.cache_manager.multimodal_cache_manager import (
EncoderCacheManager,
ProcessorCacheManager,
)
from fastdeploy.cache_manager.v1.metadata import CacheSwapMetadata
from fastdeploy.config import ErnieArchitectures
from fastdeploy.engine.request import (
BatchRequest,
ImagePosition,
Request,
RequestOutput,
@@ -53,46 +55,61 @@ from fastdeploy.utils import download_from_bos, init_bos_client, llm_logger
@dataclass
class ScheduledDecodeTask:
class ScheduledTaskBase:
"""
Task for Scheduled.
"""
idx: int
request_id: str
task_type: RequestType = RequestType.DECODE
cache_swap_metadata: list[CacheSwapMetadata] = field(default_factory=list)
cache_evict_metadata: list[CacheSwapMetadata] = field(default_factory=list)
def pop_cache_swap_metadata(self) -> list[CacheSwapMetadata]:
result = self.cache_swap_metadata
self.cache_swap_metadata = []
return result
def pop_cache_evict_metadata(self) -> list[CacheSwapMetadata]:
result = self.cache_evict_metadata
self.cache_evict_metadata = []
return result
@dataclass
class ScheduledDecodeTask(ScheduledTaskBase):
"""
Task for allocating new blocks to decode.
"""
idx: int
request_id: str
block_tables: list[int]
task_type: RequestType = RequestType.DECODE
block_tables: list[int] = field(default_factory=list)
@dataclass
class ScheduledPreemptTask:
class ScheduledPreemptTask(ScheduledTaskBase):
"""
Task for terminating inference to recycle resource.
"""
idx: int
request_id: str
task_type: RequestType = RequestType.PREEMPTED
@dataclass
class ScheduledExtendBlocksTask:
class ScheduledExtendBlocksTask(ScheduledTaskBase):
"""
Task for allocating new blocks to extend.
"""
idx: int
request_id: str
extend_block_tables: list[int]
task_type: RequestType = RequestType.EXTEND
extend_block_tables: list[int] = field(default_factory=list)
@dataclass
class ScheduledAbortTask:
class ScheduledAbortTask(ScheduledTaskBase):
"""Task for aborting a request and skipping further inference."""
idx: int
request_id: str
task_type: RequestType = RequestType.ABORT
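The `pop_*` helpers on `ScheduledTaskBase` implement take-once semantics: the swap/evict metadata is handed to exactly one scheduled task and then cleared, so a later task for the same request cannot replay the same cache transfer. A minimal standalone sketch of that pattern (names follow the diff; the `CacheSwapMetadata` entries are stubbed as plain strings here):

```python
from dataclasses import dataclass, field


@dataclass
class ScheduledTaskBase:
    idx: int
    request_id: str
    cache_swap_metadata: list = field(default_factory=list)

    def pop_cache_swap_metadata(self) -> list:
        # Hand the metadata off exactly once; later calls see an empty list.
        result = self.cache_swap_metadata
        self.cache_swap_metadata = []
        return result


task = ScheduledTaskBase(idx=0, request_id="req-0", cache_swap_metadata=["swap-meta"])
first = task.pop_cache_swap_metadata()
second = task.pop_cache_swap_metadata()
```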
@@ -243,6 +260,7 @@ class ResourceManagerV1(ResourceManager):
block_num = min(block_num + 1, self.config.cache_config.max_block_num_per_seq)
else:
block_num = min(block_num, self.config.cache_config.max_block_num_per_seq)
return block_num
def _prepare_prefill_task(self, request, new_token_num):
@@ -252,13 +270,29 @@ class ResourceManagerV1(ResourceManager):
return request
def _prepare_decode_task(self, request):
return ScheduledDecodeTask(idx=request.idx, request_id=request.request_id, block_tables=request.block_tables)
return ScheduledDecodeTask(
idx=request.idx,
request_id=request.request_id,
block_tables=request.block_tables,
cache_swap_metadata=request.pop_cache_swap_metadata(),
cache_evict_metadata=request.pop_cache_evict_metadata(),
)
def _prepare_preempt_task(self, request):
return ScheduledPreemptTask(idx=request.idx, request_id=request.request_id)
return ScheduledPreemptTask(
idx=request.idx,
request_id=request.request_id,
cache_swap_metadata=request.pop_cache_swap_metadata(),
cache_evict_metadata=request.pop_cache_evict_metadata(),
)
def _prepare_abort_task(self, request):
return ScheduledAbortTask(idx=request.idx, request_id=request.request_id)
return ScheduledAbortTask(
idx=request.idx,
request_id=request.request_id,
cache_swap_metadata=request.pop_cache_swap_metadata(),
cache_evict_metadata=request.pop_cache_evict_metadata(),
)
def reschedule_preempt_task(self, request_id, process_func=None):
with self.lock:
@@ -284,14 +318,14 @@ class ResourceManagerV1(ResourceManager):
self.to_be_aborted_req_id_set.remove(request_id)
self.update_metrics()
def _trigger_abort(self, request_id, scheduled_reqs):
def _trigger_abort(self, request_id, batch_request):
if request_id in self.requests:
abort_request = self.requests[request_id]
abort_request.status = RequestStatus.PREEMPTED
abort_request.num_computed_tokens = 0
self._free_blocks(abort_request)  # free KV cache blocks
abort_request.cached_block_num = 0
scheduled_reqs.append(self._prepare_abort_task(abort_request))
batch_request.add_request(self._prepare_abort_task(abort_request))
self.to_be_aborted_req_id_set.add(request_id)
self.waiting_abort_req_id_set.remove(request_id)
@@ -347,7 +381,7 @@ class ResourceManagerV1(ResourceManager):
f"still {len(self.to_be_rescheduled_request_id_set)} requests running"
)
def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_reqs):
def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, batch_request):
"""
If the request cannot be scheduled, preempt the running request one by one until it can be scheduled. Last in, first out.
"""
@@ -384,7 +418,7 @@ class ResourceManagerV1(ResourceManager):
)
llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
preempted_reqs.append(preempted_req)
scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
batch_request.add_request(self._prepare_preempt_task(preempted_req))
llm_logger.debug(
f"preempt {preempted_req.request_id} in idx {preempted_req.idx} with generated ids {preempted_req.output_token_ids}"
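`_trigger_preempt` walks the running requests last-in-first-out, recycling each victim's KV blocks until the new allocation fits. A hedged sketch of that loop with a toy block pool standing in for `cache_manager` (all names below are illustrative, not the real API):

```python
def trigger_preempt(running, free_blocks, num_new_blocks):
    """Preempt running requests LIFO until num_new_blocks fit; returns (ok, preempted)."""
    preempted = []
    while free_blocks < num_new_blocks and running:
        victim = running.pop()           # last in, first out
        free_blocks += victim["blocks"]  # recycle the victim's KV blocks
        preempted.append(victim["id"])
    return free_blocks >= num_new_blocks, preempted


running = [{"id": "r1", "blocks": 4}, {"id": "r2", "blocks": 2}, {"id": "r3", "blocks": 3}]
ok, preempted = trigger_preempt(running, free_blocks=1, num_new_blocks=5)
```

The newest requests are preempted first (`r3`, then `r2`), so long-running requests keep their progress.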
@@ -723,18 +757,12 @@ class ResourceManagerV1(ResourceManager):
# Compatible with scenarios without images and videos.
return num_new_tokens
def exist_mm_prefill(self, scheduled_reqs):
for request in scheduled_reqs:
def exist_mm_prefill(self, batch_request):
for request in batch_request:
if request.task_type == RequestType.PREFILL and self._is_mm_request(request):
return True
return False
def exist_prefill(self, scheduled_reqs):
for request in scheduled_reqs:
if request.task_type == RequestType.PREFILL:
return True
return False
def add_abort_req_ids(self, req_ids):
with self.lock:
if isinstance(req_ids, list):
@@ -757,15 +785,14 @@ class ResourceManagerV1(ResourceManager):
Try to pull a batch of requests from the waiting queue and schedule them.
"""
def get_enough_request(request, scheduled_reqs):
def get_enough_request(request, batch_request):
return (
ErnieArchitectures.is_ernie5_arch(self.config.model_config.architectures)
and self._is_mm_request(request)
and self.exist_mm_prefill(scheduled_reqs)
and self.exist_mm_prefill(batch_request)
)
with self.lock:
scheduled_reqs: list[Request] = []
preempted_reqs: list[Request] = []
error_reqs: list[tuple[str, str]] = []
tokens_per_seq = (
@@ -780,6 +807,7 @@ class ResourceManagerV1(ResourceManager):
# temporary solution to avoid a negative token_budget
token_budget = max(token_budget, min(self.config.scheduler_config.max_num_batched_tokens, 512))
need_abort_requests = [] # users trigger abortion
batch_request = BatchRequest()
# First, schedule the RUNNING requests.
req_index = 0
@@ -801,7 +829,7 @@ class ResourceManagerV1(ResourceManager):
request.num_computed_tokens = request.num_total_tokens - 1
if request.request_id in self.waiting_abort_req_id_set:
self._trigger_abort(request.request_id, scheduled_reqs)
self._trigger_abort(request.request_id, batch_request)
req_index += 1
need_abort_requests.append(request)
continue
@@ -816,27 +844,23 @@ class ResourceManagerV1(ResourceManager):
f"schedule decoding task: {request} request.num_total_tokens {request.num_total_tokens} request.num_computed_tokens {request.num_computed_tokens}"
)
request.block_tables.extend(
self.cache_manager.allocate_gpu_blocks(
self.config.cache_config.enc_dec_block_num, request.request_id
)
self._allocate_gpu_blocks(request, self.config.cache_config.enc_dec_block_num)
)
# Prepare decoding task
scheduled_reqs.append(self._prepare_decode_task(request))
batch_request.add_request(self._prepare_decode_task(request))
else:
# Not enough blocks to allocate, trigger preemption
can_schedule = self._trigger_preempt(
request, self.config.cache_config.enc_dec_block_num, preempted_reqs, scheduled_reqs
request, self.config.cache_config.enc_dec_block_num, preempted_reqs, batch_request
)
if not can_schedule:
break
# Allocation for next decoding blocks
request.block_tables.extend(
self.cache_manager.allocate_gpu_blocks(
self.config.cache_config.enc_dec_block_num, request.request_id
)
self._allocate_gpu_blocks(request, self.config.cache_config.enc_dec_block_num)
)
# Prepare decoding task
scheduled_reqs.append(self._prepare_decode_task(request))
batch_request.add_request(self._prepare_decode_task(request))
num_decoding_req_nums += 1
token_budget -= 1
if (
@@ -848,10 +872,8 @@ class ResourceManagerV1(ResourceManager):
def _allocate_decode_and_extend():
allocate_block_num = self.need_block_num_map[request.request_id].consume()
# Prepare decoding task
request.block_tables.extend(
self.cache_manager.allocate_gpu_blocks(allocate_block_num, request.request_id)
)
scheduled_reqs.append(self._prepare_decode_task(request))
request.block_tables.extend(self._allocate_gpu_blocks(request, allocate_block_num))
batch_request.add_request(self._prepare_decode_task(request))
# Prepare extend task
reuse_block_num = request.num_total_tokens // self.config.cache_config.block_size
@@ -863,14 +885,14 @@ class ResourceManagerV1(ResourceManager):
self.reuse_block_num_map[request.request_id] = reuse_block_num
request.extend_block_tables = request.block_tables[:reuse_block_num] # copy prompt cache
request.extend_block_tables.extend(
self.cache_manager.allocate_gpu_blocks(allocate_block_num, request.request_id)
)
scheduled_reqs.append(
request.extend_block_tables.extend(self._allocate_gpu_blocks(request, allocate_block_num))
batch_request.add_request(
ScheduledExtendBlocksTask(
idx=request.idx,
request_id=request.request_id,
extend_block_tables=request.extend_block_tables,
cache_swap_metadata=request.pop_cache_swap_metadata(),
cache_evict_metadata=request.pop_cache_evict_metadata(),
)
)
llm_logger.debug(f"extend blocks is {request.extend_block_tables}")
@@ -887,7 +909,7 @@ class ResourceManagerV1(ResourceManager):
request,
2 * self.need_block_num_map[request.request_id].watch(),
preempted_reqs,
scheduled_reqs,
batch_request,
)
if can_schedule:
@@ -908,7 +930,7 @@ class ResourceManagerV1(ResourceManager):
):
req_index += 1
continue
if get_enough_request(request, scheduled_reqs):
if get_enough_request(request, batch_request):
req_index += 1
continue
num_new_tokens = self._get_num_new_tokens(request, token_budget)
@@ -918,26 +940,23 @@ class ResourceManagerV1(ResourceManager):
num_new_block = self.get_new_block_nums(request, num_new_tokens)
# Allocate blocks to prefill
if self.cache_manager.can_allocate_gpu_blocks(num_new_block):
request.block_tables.extend(
self.cache_manager.allocate_gpu_blocks(num_new_block, request.request_id)
)
request.block_tables.extend(self._allocate_gpu_blocks(request, num_new_block))
# Prepare prefill task
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
batch_request.add_request(self._prepare_prefill_task(request, num_new_tokens))
else: # Not enough blocks to allocate, trigger preemption
can_schedule = self._trigger_preempt(request, num_new_block, preempted_reqs, scheduled_reqs)
can_schedule = self._trigger_preempt(request, num_new_block, preempted_reqs, batch_request)
if not can_schedule:
break
request.block_tables.extend(
self.cache_manager.allocate_gpu_blocks(num_new_block, request.request_id)
)
request.block_tables.extend(self._allocate_gpu_blocks(request, num_new_block))
# Prepare prefill task
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
batch_request.add_request(self._prepare_prefill_task(request, num_new_tokens))
token_budget -= num_new_tokens
request.num_computed_tokens += num_new_tokens
if (
self.config.cache_config.enable_prefix_caching
and self.config.scheduler_config.splitwise_role != "decode"
and self.config.scheduler_config.splitwise_role != "prefill"
and not self.enable_cache_manager_v1
):
self.cache_manager.update_cache_blocks(
request, self.config.cache_config.block_size, request.num_computed_tokens
@@ -962,7 +981,7 @@ class ResourceManagerV1(ResourceManager):
break
request = self.waiting[0]
if get_enough_request(request, scheduled_reqs):
if get_enough_request(request, batch_request):
break
if request.status == RequestStatus.WAITING:
result = self.waiting_async_process(request)
@@ -979,15 +998,16 @@ class ResourceManagerV1(ResourceManager):
self._update_mm_hashes(request)
# Enable prefix caching
if self.config.cache_config.enable_prefix_caching:
if (
self.cache_manager.num_cpu_blocks > 0
or self.config.cache_config.kvcache_storage_backend
):
if not self.cache_manager.can_allocate_gpu_blocks(
(request.need_prefill_tokens + self.config.cache_config.block_size - 1)
// self.config.cache_config.block_size
): # to prevent block allocation for matching in hierarchical cache and cause dead lock
break
if not self.enable_cache_manager_v1:
if (
self.cache_manager.num_cpu_blocks > 0
or self.config.cache_config.kvcache_storage_backend
):
if not self.cache_manager.can_allocate_gpu_blocks(
(request.need_prefill_tokens + self.config.cache_config.block_size - 1)
// self.config.cache_config.block_size
):  # prevent block allocation during hierarchical-cache matching from causing a deadlock
break
success = self.get_prefix_cached_blocks(request)
if not success:
self._free_blocks(request)
@@ -1013,24 +1033,27 @@ class ResourceManagerV1(ResourceManager):
self.waiting.popleft()
continue
num_new_block = self.get_new_block_nums(request, num_new_tokens)
llm_logger.debug(
f"request.request_id {request.request_id} num_new_block {num_new_block}, request.need_prefill_tokens {request.need_prefill_tokens}, request.num_computed_tokens {request.num_computed_tokens}, token_budget {token_budget}"
)
can_schedule_block_num_threshold = self._get_can_schedule_prefill_threshold_block(
num_new_block
)
# Allocate blocks to prefill
if self.cache_manager.can_allocate_gpu_blocks(can_schedule_block_num_threshold):
if num_new_block > 0:
extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(
num_new_block, request.request_id
)
extra_gpu_block_ids = self._allocate_gpu_blocks(request, num_new_block)
request.block_tables.extend(extra_gpu_block_ids)
self.waiting.popleft()
self.running.append(request)
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
batch_request.add_request(self._prepare_prefill_task(request, num_new_tokens))
token_budget -= num_new_tokens
request.num_computed_tokens += num_new_tokens
if (
self.config.cache_config.enable_prefix_caching
and self.config.scheduler_config.splitwise_role != "decode"
and not self.enable_cache_manager_v1
):
self.cache_manager.update_cache_blocks(
request, self.config.cache_config.block_size, request.num_computed_tokens
@@ -1055,15 +1078,16 @@ class ResourceManagerV1(ResourceManager):
self.config.cache_config.enable_prefix_caching
and self.config.scheduler_config.splitwise_role != "decode"
):
if (
self.cache_manager.num_cpu_blocks > 0
or self.config.cache_config.kvcache_storage_backend
):
if not self.cache_manager.can_allocate_gpu_blocks(
(request.need_prefill_tokens + self.config.cache_config.block_size - 1)
// self.config.cache_config.block_size
): # to prevent block allocation for matching in hierarchical cache and cause dead lock
break
if not self.enable_cache_manager_v1:
if (
self.cache_manager.num_cpu_blocks > 0
or self.config.cache_config.kvcache_storage_backend
):
if not self.cache_manager.can_allocate_gpu_blocks(
(request.need_prefill_tokens + self.config.cache_config.block_size - 1)
// self.config.cache_config.block_size
):  # prevent block allocation during hierarchical-cache matching from causing a deadlock
break
success = self.get_prefix_cached_blocks(request)
if not success:
self._free_blocks(request)
@@ -1088,18 +1112,17 @@ class ResourceManagerV1(ResourceManager):
# Allocate blocks to prefill
if self.cache_manager.can_allocate_gpu_blocks(can_schedule_block_num_threshold):
if num_new_block > 0:
extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(
num_new_block, request.request_id
)
extra_gpu_block_ids = self._allocate_gpu_blocks(request, num_new_block)
request.block_tables.extend(extra_gpu_block_ids)
self.waiting.popleft()
self.running.append(request)
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
batch_request.add_request(self._prepare_prefill_task(request, num_new_tokens))
token_budget -= num_new_tokens
request.num_computed_tokens += num_new_tokens
if (
self.config.cache_config.enable_prefix_caching
and self.config.scheduler_config.splitwise_role != "decode"
and not self.enable_cache_manager_v1
):
self.cache_manager.update_cache_blocks(
request, self.config.cache_config.block_size, request.num_computed_tokens
@@ -1116,8 +1139,8 @@ class ResourceManagerV1(ResourceManager):
# move waiting request to end of the deque
self.waiting.append(req)
if scheduled_reqs:
llm_logger.debug(f"schedued_reqs: {scheduled_reqs}")
if len(batch_request) > 0:
llm_logger.debug(f"scheduled_reqs: {batch_request}")
self.current_reserve_output_block_num_float -= self.decay_output_block_num
self.current_reserve_output_block_num = max(
int(self.current_reserve_output_block_num_float),
@@ -1127,11 +1150,22 @@ class ResourceManagerV1(ResourceManager):
if self.current_reserve_output_block_num == 0:
self.can_relax_prefill_strategy = True
self._log_console_scheduler_metrics(scheduled_reqs)
self._log_console_scheduler_metrics(batch_request)
self.update_metrics()
return scheduled_reqs, error_reqs
# Issue pending backup tasks to batch_request
# This handles write_through_selective policy by attaching backup tasks
# to the batch request, which will be processed by the worker
if self.enable_cache_manager_v1 and len(batch_request) > 0:
evict_metadata = self.cache_manager.issue_pending_backup_to_batch_request()
if evict_metadata:
batch_request.append_evict_metadata([evict_metadata])
if self.enable_cache_manager_v1:
self.cache_manager.check_and_add_pending_backup()
return batch_request, error_reqs
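The diff switches the scheduler output from a plain `scheduled_reqs` list to a `BatchRequest` container. Judging from its usage in this diff (`add_request`, `len()`, iteration, indexing in `insert_tasks_v1`, `append_evict_metadata`, and batch-level `cache_swap_metadata`/`cache_evict_metadata` reads), a minimal container could look like the sketch below; the real `fastdeploy.engine.request.BatchRequest` may differ in detail:

```python
from types import SimpleNamespace


class BatchRequest:
    """Container for one scheduling step's tasks plus aggregated cache metadata (illustrative)."""

    def __init__(self):
        self._tasks = []
        self.cache_swap_metadata = []
        self.cache_evict_metadata = []

    def add_request(self, task):
        # Aggregate per-task metadata into batch-level lists for the worker.
        self.cache_swap_metadata.extend(getattr(task, "cache_swap_metadata", []))
        self.cache_evict_metadata.extend(getattr(task, "cache_evict_metadata", []))
        self._tasks.append(task)

    def append_evict_metadata(self, metadata_list):
        self.cache_evict_metadata.extend(metadata_list)

    def __len__(self):
        return len(self._tasks)

    def __iter__(self):
        return iter(self._tasks)

    def __getitem__(self, i):
        return self._tasks[i]


batch = BatchRequest()
batch.add_request(SimpleNamespace(request_id="r0", cache_swap_metadata=["m0"], cache_evict_metadata=[]))
batch.append_evict_metadata(["evict-0"])
```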
def waiting_async_process(self, request: Request) -> None:
"""
@@ -1257,11 +1291,45 @@ class ResourceManagerV1(ResourceManager):
break
return self.real_bsz
def get_prefix_cached_blocks(self, request: Request):
def _allocate_gpu_blocks(self, request: Request, num_blocks: int) -> List[int]:
llm_logger.debug(f"[allocate_gpu_blocks] request_id={request.request_id}, num_blocks={num_blocks}")
if self.enable_cache_manager_v1:
return self.cache_manager.allocate_gpu_blocks(request, num_blocks)
else:
return self.cache_manager.allocate_gpu_blocks(num_blocks, request.request_id)
def _request_match_blocks(self, request: Request, skip_storage: bool = True):
"""
Match and fetch cache for a task.
Prefix cache manager v1 matches blocks for the request and returns common_block_ids.
"""
try:
if self.enable_cache_manager_v1:
self.cache_manager.match_prefix(request, skip_storage)
match_result = request.match_result
if skip_storage:
common_block_ids = match_result.device_block_ids
matched_token_num = match_result.total_matched_blocks * self.config.cache_config.block_size
metrics = {
"gpu_match_token_num": match_result.matched_device_nums * self.config.cache_config.block_size,
"cpu_match_token_num": match_result.matched_host_nums * self.config.cache_config.block_size,
"storage_match_token_num": match_result.matched_storage_nums * self.config.cache_config.block_size,
"match_gpu_block_ids": common_block_ids,
"gpu_recv_block_ids": [],
"match_storage_block_ids": [],
"cpu_cache_prepare_time": 0,
"storage_cache_prepare_time": 0,
}
no_cache_block_num = (
request.need_prefill_tokens - matched_token_num + self.config.cache_config.block_size - 1
) // self.config.cache_config.block_size
request.cache_info = [len(common_block_ids), no_cache_block_num]
return (common_block_ids, matched_token_num, metrics)
else:
# Prefetch cache from storage
pass
else:
(common_block_ids, matched_token_num, metrics) = self.cache_manager.request_match_blocks(
request, self.config.cache_config.block_size
)
@@ -1273,6 +1341,18 @@ class ResourceManagerV1(ResourceManager):
)
request.cache_info = [matched_block_num, no_cache_block_num]
return (common_block_ids, matched_token_num, metrics)
def get_prefix_cached_blocks(self, request: Request):
"""
Match and fetch cache for a task.
"""
try:
(common_block_ids, matched_token_num, metrics) = self._request_match_blocks(
request  # skip_storage defaults to True
)
request.block_tables = common_block_ids
request.num_cached_tokens = matched_token_num
if self.config.cache_config.disable_chunked_mm_input:
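The matched-token accounting in `_request_match_blocks` is plain block arithmetic: matched tokens are whole matched blocks times `block_size`, and the uncached remainder is rounded up to whole blocks with the usual ceiling-division idiom. A worked example with illustrative numbers:

```python
block_size = 64
need_prefill_tokens = 1000
total_matched_blocks = 6

# Tokens served from prefix cache: whole blocks only.
matched_token_num = total_matched_blocks * block_size  # 6 * 64 = 384

# Blocks still needed for the uncached tokens, rounded up:
# ceil((1000 - 384) / 64) == (616 + 63) // 64 == 10
no_cache_block_num = (need_prefill_tokens - matched_token_num + block_size - 1) // block_size
```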
@@ -1375,9 +1455,7 @@ class ResourceManagerV1(ResourceManager):
need_extra_prefill_blocks = need_prealloc_prefill_blocks - request.cache_info[0]
if self.cache_manager.can_allocate_gpu_blocks(need_extra_prefill_blocks):
extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(
need_extra_prefill_blocks, request.request_id
)
extra_gpu_block_ids = self._allocate_gpu_blocks(request, need_extra_prefill_blocks)
request.block_tables.extend(extra_gpu_block_ids)
allocated_position = self.get_available_position()
request.idx = allocated_position
@@ -1397,9 +1475,7 @@ class ResourceManagerV1(ResourceManager):
else:
if self.cache_manager.can_allocate_gpu_blocks(need_prealloc_prefill_blocks):
request.block_tables.extend(
self.cache_manager.allocate_gpu_blocks(need_prealloc_prefill_blocks, request.request_id)
)
request.block_tables.extend(self._allocate_gpu_blocks(request, need_prealloc_prefill_blocks))
request.num_computed_tokens = 0
allocated_position = self.get_available_position()
request.idx = allocated_position
@@ -1432,9 +1508,7 @@ class ResourceManagerV1(ResourceManager):
if not self.cache_manager.can_allocate_gpu_blocks(total_need_blocks):
return False
request.block_tables = self.cache_manager.allocate_gpu_blocks(
need_prealloc_prefill_blocks, request.request_id
)
request.block_tables = self._allocate_gpu_blocks(request, need_prealloc_prefill_blocks)
request.num_computed_tokens = request.need_prefill_tokens
request.disaggregate_info["block_tables"] = request.block_tables
allocated_position = self.get_available_position()
@@ -1486,7 +1560,11 @@ class ResourceManagerV1(ResourceManager):
self.running.append(request)
def _free_blocks(self, request: Request):
if self.config.cache_config.enable_prefix_caching and self.config.scheduler_config.splitwise_role != "decode":
if self.enable_cache_manager_v1:
self.cache_manager.request_finish(request)
elif (
self.config.cache_config.enable_prefix_caching and self.config.scheduler_config.splitwise_role != "decode"
):
self.cache_manager.release_block_ids(request)
self.cache_manager.recycle_gpu_blocks(
request.block_tables[request.num_cached_blocks :], request.request_id
@@ -1600,7 +1678,7 @@ class ResourceManagerV1(ResourceManager):
f")"
)
def _log_console_scheduler_metrics(self, scheduled_reqs: list[Request | ScheduledDecodeTask]) -> None:
def _log_console_scheduler_metrics(self, batch_request: BatchRequest) -> None:
if not (
hasattr(self, "scheduler_metrics_logger")
and self.scheduler_metrics_logger is not None
@@ -1617,8 +1695,8 @@ class ResourceManagerV1(ResourceManager):
scheduler_queue_cnt = max(int(getattr(self, "scheduler_unhandled_request_num", 0) or 0), 0)
queue_cnt = len(self.waiting) + scheduler_queue_cnt
prefill_reqs = [r for r in scheduled_reqs if isinstance(r, Request) and r.task_type == RequestType.PREFILL]
has_decode = any(getattr(r, "task_type", None) == RequestType.DECODE for r in scheduled_reqs)
prefill_reqs = [r for r in batch_request if isinstance(r, Request) and r.task_type == RequestType.PREFILL]
has_decode = any(getattr(r, "task_type", None) == RequestType.DECODE for r in batch_request)
self.scheduler_metrics_logger.log_prefill_batch(
prefill_reqs=prefill_reqs,
@@ -269,6 +269,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_SiluAndMul_USE_PHI_SWIGLU": lambda: bool(int(os.getenv("FD_SiluAndMul_USE_PHI_SWIGLU", "0"))),
# Whether to enable FP8 quantization with pow2scale.
"FD_FP8_QUANT_WITH_POW2SCALE": lambda: bool(int(os.getenv("FD_FP8_QUANT_WITH_POW2SCALE", "0"))),
# enable kv cache manager v1
"ENABLE_V1_KVCACHE_MANAGER": lambda: int(os.getenv("ENABLE_V1_KVCACHE_MANAGER", "0")),
}
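`environment_variables` maps flag names to zero-argument lambdas, so each flag is read lazily from the process environment at access time rather than at import time. Note the new `ENABLE_V1_KVCACHE_MANAGER` entry yields an `int`, unlike the `bool(int(...))` entries above it. A tiny sketch of the pattern:

```python
import os

# Lazy flag registry: values are callables, evaluated on access.
environment_variables = {
    "ENABLE_V1_KVCACHE_MANAGER": lambda: int(os.getenv("ENABLE_V1_KVCACHE_MANAGER", "0")),
}

default_value = environment_variables["ENABLE_V1_KVCACHE_MANAGER"]()

# Changing the environment after registry creation is still picked up.
os.environ["ENABLE_V1_KVCACHE_MANAGER"] = "1"
enabled = environment_variables["ENABLE_V1_KVCACHE_MANAGER"]()
```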
@@ -17,7 +17,7 @@
import logging
from dataclasses import dataclass, fields
from enum import IntEnum, auto
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional
import paddle
@@ -149,6 +149,10 @@ class ForwardMeta:
# Routing Replay table buffer
routing_replay_table: Optional[paddle.Tensor] = None
# ============ V1 KVCACHE Manager: Swap-in waiting info ============
# LayerDoneCounter for layer-by-layer swap waiting (set by submit_swap_tasks return value)
layer_done_counter: Optional[Any] = None
# chunked MoE related
moe_num_chunk: int = 1
max_moe_num_chunk: int = 1
@@ -272,6 +272,11 @@ class Attention(nn.Layer):
compressed_kv: optional compressed key-value cache (for MLA)
k_pe: optional key positional encoding (for MLA)
"""
# ============ V1 KVCACHE Manager: Layer-by-layer swap wait ============
# Wait for swap-in of current layer before using cache
if forward_meta.layer_done_counter is not None:
forward_meta.layer_done_counter.wait_for_layer(self.layer_id)
return forward_meta.attn_backend.forward(
q,
k,
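The wait in `Attention.forward` above gates each layer's attention on that layer's KV swap-in having landed on device. One way such a `layer_done_counter` could work is one event per layer, signaled by the swap worker as each layer's H2D copy completes; a hedged sketch, not the actual `CacheController` implementation:

```python
import threading


class LayerDoneCounter:
    """Per-layer completion gate for layer-by-layer KV swap-in (illustrative)."""

    def __init__(self, num_layers: int):
        self._events = [threading.Event() for _ in range(num_layers)]

    def mark_layer_done(self, layer_id: int) -> None:
        # Called by the swap thread once layer `layer_id`'s H2D copy finishes.
        self._events[layer_id].set()

    def wait_for_layer(self, layer_id: int, timeout=None) -> bool:
        # Called by the compute thread before attention reads that layer's cache;
        # returns False if the timeout expires before the layer is ready.
        return self._events[layer_id].wait(timeout)


counter = LayerDoneCounter(num_layers=2)
counter.mark_layer_done(0)
ready = counter.wait_for_layer(0, timeout=0.1)
not_ready = counter.wait_for_layer(1, timeout=0.01)
```

Layer-granular waiting lets compute on early layers overlap with swap-in of later layers instead of blocking on the whole transfer.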
@@ -1044,6 +1044,7 @@ class TokenProcessor:
envs.ENABLE_V1_KVCACHE_SCHEDULER
and self.cfg.cache_config.enable_prefix_caching
and self.cfg.cache_config.enable_output_caching
and not envs.ENABLE_V1_KVCACHE_MANAGER
):
self.resource_manager.cache_output_tokens(
task
@@ -438,13 +438,20 @@ class MTPProposer(Proposer):
if self.forward_meta is not None:
del self.forward_meta.caches
def update_mtp_block_num(self, num_gpu_blocks) -> None:
def update_mtp_block_num(self, num_gpu_blocks, skip_cache_init: bool = False) -> None:
"""
Update MTP block num by theoretical calculation
Args:
num_gpu_blocks: Main model GPU block count.
skip_cache_init: When True, skip internal initialize_kv_cache call.
Set this when the caller (e.g. gpu_model_runner with enable_cache_manager_v1)
has already re-created MTP cache via cache_controller.
"""
# Reset block table and kv cache with global block num
self.main_model_num_gpu_blocks = num_gpu_blocks
self.initialize_kv_cache(main_model_num_blocks=self.main_model_num_gpu_blocks)
if not skip_cache_init:
self.initialize_kv_cache(main_model_num_blocks=self.main_model_num_gpu_blocks)
# Reset free list
free_list = list(
@@ -29,7 +29,7 @@ from paddleformers.utils.log import logger
from fastdeploy.config import PREEMPTED_TOKEN_ID, FDConfig
from fastdeploy.engine.pooling_params import PoolingParams
from fastdeploy.engine.request import ImagePosition, Request, RequestType
from fastdeploy.engine.request import BatchRequest, ImagePosition, Request, RequestType
from fastdeploy.model_executor.graph_optimization.utils import (
profile_run_guard,
sot_warmup_guard,
@@ -91,6 +91,7 @@ else:
import zmq
from fastdeploy import envs
from fastdeploy.cache_manager.v1 import CacheController
from fastdeploy.engine.tasks import PoolingTask
from fastdeploy.input.image_processors.adaptive_processor import AdaptiveImageProcessor
from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient
@@ -272,6 +273,19 @@ class GPUModelRunner(ModelRunnerBase):
create=False,
)
# NOTE:(changwenbin) Determine whether it is Multi-Head Latent Attention,
# To rationalize the allocation of kvcache.
self.mla_cache = envs.FD_ATTENTION_BACKEND == "MLA_ATTN"
self.dsa_cache = envs.FD_ATTENTION_BACKEND == "DSA_ATTN"
self.enable_cache_manager_v1 = envs.ENABLE_V1_KVCACHE_MANAGER
if self.enable_cache_manager_v1:
self.cache_controller = CacheController(
fd_config,
self.local_rank,
self.device_id,
)
# for overlap
self._cached_model_output_data = None
self._cached_sampler_output = None
@@ -725,7 +739,7 @@ class GPUModelRunner(ModelRunnerBase):
)
return feature_positions
def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = None):
def insert_tasks_v1(self, req_dicts: BatchRequest, num_running_requests: int = None):
"""
Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1
req_dicts: A BatchRequest of scheduled tasks
@@ -742,6 +756,13 @@ class GPUModelRunner(ModelRunnerBase):
"position_ids_offset": [0],
"max_tokens_lst": [],
}
if self.enable_cache_manager_v1:
# submit_swap_tasks handles:
# 1. Waiting for pending evict handlers before submitting new evict
# 2. write_back policy: waiting for evict to complete before submitting swap-in
# 3. Adding handlers to pending lists appropriately
self.cache_controller.submit_swap_tasks(req_dicts.cache_evict_metadata, req_dicts.cache_swap_metadata)
for i in range(req_len):
request = req_dicts[i]
idx = self.share_inputs.get_index_by_batch_id(request.idx)
@@ -1423,10 +1444,35 @@ class GPUModelRunner(ModelRunnerBase):
self.forward_meta.is_zero_size = self.forward_meta.ids_remove_padding.shape[0] == 0
self.forward_meta.exist_prefill = self.exist_prefill()
# ============ V1 KVCACHE Manager: Swap-in waiting config ============
if self.enable_cache_manager_v1:
self.forward_meta.layer_done_counter = self.cache_controller.swap_layer_done_counter
else:
self.forward_meta.layer_done_counter = None
def initialize_kv_cache(self, profile: bool = False) -> None:
"""
Initialize kv cache
"""
if self.enable_cache_manager_v1:
self.share_inputs["caches"] = self.cache_controller.initialize_kv_cache(
attn_backend=self.attn_backends[0],
num_gpu_blocks=self.num_gpu_blocks,
)
self.cache_kvs_map = self.cache_controller.get_kv_caches()
if self.spec_method == SpecMethod.MTP:
mtp_num_blocks = int(self.num_gpu_blocks * self.proposer.speculative_config.num_gpu_block_expand_ratio)
mtp_cache_list = self.cache_controller.initialize_mtp_kv_cache(
attn_backend=self.proposer.attn_backends[0],
num_gpu_blocks=mtp_num_blocks,
num_mtp_layers=self.proposer.model_config.num_hidden_layers,
layer_offset=self.proposer.num_main_model_layers,
)
self.proposer.num_gpu_blocks = mtp_num_blocks
self.proposer.cache_kvs_map = self.cache_controller.get_kv_caches()
self.proposer.model_inputs["caches"] = mtp_cache_list
return
# cache_kvs = {}
max_block_num = self.num_gpu_blocks
@@ -1434,13 +1480,6 @@ class GPUModelRunner(ModelRunnerBase):
cache_type = self.model_config.dtype
kv_cache_quant_type = None
# NOTE:(changwenbin) Determine whether it is Multi-Head Latent Attention,
# To rationalize the allocation of kvcache.
from fastdeploy import envs
self.mla_cache = envs.FD_ATTENTION_BACKEND == "MLA_ATTN"
self.dsa_cache = envs.FD_ATTENTION_BACKEND == "DSA_ATTN"
if (
self.quant_config
and hasattr(self.quant_config, "kv_cache_quant_type")
@@ -2245,15 +2284,16 @@ class GPUModelRunner(ModelRunnerBase):
return model_inputs, p_done_idxs, token_num_event
def _execute(self, model_inputs: Dict[str, paddle.Tensor]) -> None:
model_output = None
if model_inputs is not None and len(model_inputs) > 0:
model_output = self.model(
model_inputs,
self.forward_meta,
)
if self.use_cudagraph:
model_output = model_output[: self.real_token_num]
else:
model_output = None
return model_output
def _postprocess(
@@ -2639,7 +2679,8 @@ class GPUModelRunner(ModelRunnerBase):
self.num_gpu_blocks = self.cache_config.total_block_num
self.initialize_kv_cache(profile=True)
if self.spec_method == SpecMethod.MTP:
self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks, profile=True)
if not self.enable_cache_manager_v1:
self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks, profile=True)
# 1. Profile with multimodal encoder & encoder cache
@@ -2686,7 +2727,7 @@ class GPUModelRunner(ModelRunnerBase):
)
if self.spec_method == SpecMethod.MTP:
self.proposer.update_mtp_block_num(num_gpu_blocks)
self.proposer.update_mtp_block_num(num_gpu_blocks, skip_cache_init=self.enable_cache_manager_v1)
def cal_theortical_kvcache(self):
"""
@@ -2749,17 +2790,21 @@ class GPUModelRunner(ModelRunnerBase):
def clear_cache(self, profile=False):
"""Clear cached data from shared inputs and forward metadata"""
create_cache_tensor = profile or not (
self.fd_config.cache_config.num_cpu_blocks > 0
or self.fd_config.cache_config.kvcache_storage_backend
or self.fd_config.scheduler_config.splitwise_role != "mixed"
)
local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
if self.enable_cache_manager_v1:
self.cache_controller.free_gpu_cache()
else:
create_cache_tensor = profile or not (
self.fd_config.cache_config.num_cpu_blocks > 0
or self.fd_config.cache_config.kvcache_storage_backend
or self.fd_config.scheduler_config.splitwise_role != "mixed"
)
local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
if not create_cache_tensor:
for name, tensor in self.cache_kvs_map.items():
unset_data_ipc(tensor, name, True, False)
self.cache_ready_signal.value[local_rank] = 0
if not create_cache_tensor:
for name, tensor in self.cache_kvs_map.items():
unset_data_ipc(tensor, name, True, False)
self.cache_ready_signal.value[local_rank] = 0
self.cache_kvs_map.clear()
self.share_inputs.pop("caches", None)
if self.forward_meta is not None:
@@ -2806,7 +2851,8 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs.reset_share_inputs()
if self.spec_method == SpecMethod.MTP:
self.proposer.model_inputs.reset_model_inputs()
self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks)
if not self.enable_cache_manager_v1:
self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks)
self.initialize_kv_cache()
# Recapture CUDAGraph
if self.use_cudagraph:
@@ -2843,7 +2889,7 @@ class GPUModelRunner(ModelRunnerBase):
if self.is_kvcache_sleeping:
logger.info("GPU model runner's kv cache is already sleeping, no need to sleep again!")
return
if self.spec_method == SpecMethod.MTP:
if self.spec_method == SpecMethod.MTP and not self.enable_cache_manager_v1:
self.proposer.clear_mtp_cache()
self.clear_cache()
self.is_kvcache_sleeping = True
@@ -2875,7 +2921,8 @@ class GPUModelRunner(ModelRunnerBase):
logger.info("GPU model runner's kv cache is not sleeping, no need to wakeup!")
return
if self.spec_method == SpecMethod.MTP:
self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks)
if not self.enable_cache_manager_v1:
self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks)
self.initialize_kv_cache()
self.is_kvcache_sleeping = False
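The sleep and wakeup paths above are both guarded by the `is_kvcache_sleeping` flag, so repeated calls in either direction are no-ops. A hedged sketch of that idempotency pattern (a hypothetical toy class, not the actual runner; the event list is only there to make the transitions observable):

```python
class KVCacheSleeper:
    """Toy model of the sleep/wakeup idempotency guard in the diff above."""

    def __init__(self):
        self.is_sleeping = False
        self.events = []  # records only the real state transitions

    def sleep(self):
        if self.is_sleeping:
            return  # already sleeping: mirrors the early-return above
        self.events.append("clear_cache")
        self.is_sleeping = True

    def wakeup(self):
        if not self.is_sleeping:
            return  # not sleeping: nothing to wake up
        self.events.append("initialize_kv_cache")
        self.is_sleeping = False
```

The guard matters because sleep tears down cache tensors and wakeup reallocates them; running either twice would double-free or double-allocate.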
@@ -24,7 +24,7 @@ from paddle import nn
 from fastdeploy import envs
 from fastdeploy.config import FDConfig
-from fastdeploy.engine.request import Request
+from fastdeploy.engine.request import BatchRequest, Request
 from fastdeploy.plugins.model_runner import load_model_runner_plugins
 from fastdeploy.usage.usage_lib import report_usage_stats
 from fastdeploy.utils import get_logger, set_random_seed
@@ -209,7 +209,7 @@ class GpuWorker(WorkerBase):
         output = self.model_runner.execute_model(model_forward_batch, num_running_request)
         return output

-    def preprocess_new_task(self, req_dicts: List[Request], num_running_requests: int) -> None:
+    def preprocess_new_task(self, req_dicts: BatchRequest, num_running_requests: int) -> None:
         """Process new requests and then start the decode loop
         TODO(gongshaotian):The scheduler should schedule the handling of prefill,
         and workers and modelrunners should not perceive it.
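The new signature takes a `BatchRequest` instead of `List[Request]`. The call sites in this diff rely only on `len()`, iteration, and per-request `request_id`/`task_type` access, so a minimal stand-in could look like the sketch below (a hypothetical reconstruction; the real `BatchRequest` lives in `fastdeploy.engine.request` and almost certainly carries more state):

```python
from dataclasses import dataclass, field
from typing import Iterator, List


@dataclass
class Request:
    # Only the fields this diff actually touches.
    request_id: str
    task_type: str = "PREFILL"


@dataclass
class BatchRequest:
    """Iterable, sized container of Requests, as the call sites require."""

    requests: List[Request] = field(default_factory=list)

    def __len__(self) -> int:
        return len(self.requests)

    def __iter__(self) -> Iterator[Request]:
        return iter(self.requests)
```

Wrapping the list in a dedicated type lets the scheduler attach batch-level metadata later without changing the worker interface again.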
@@ -49,7 +49,12 @@ from fastdeploy.config import (
     SpeculativeConfig,
     StructuredOutputsConfig,
 )
-from fastdeploy.engine.request import ControlRequest, ControlResponse, RequestType
+from fastdeploy.engine.request import (
+    BatchRequest,
+    ControlRequest,
+    ControlResponse,
+    RequestType,
+)
 from fastdeploy.eplb.async_expert_loader import (
     MODEL_MAIN_NAME,
     REARRANGE_EXPERT_MAGIC_NUM,
@@ -549,39 +554,27 @@ class PaddleDisWorkerProc:
             if self.parallel_config.use_ep and self.scheduler_config.splitwise_role == "prefill":
                 paddle.distributed.barrier(self.parallel_config.ep_group)
-            req_dicts, control_reqs = [], []
-            assert (
-                len(tasks) > 0
-            ), f"task_queue.get_tasks() should contain at least one tuple, [([req1, ...] ,real_bsz)], but got len(tasks)={len(tasks)}"
-            # In EP + DP prefill, empty task ([]) is delived in worker to barrier. For empty task, just skip and continue.
-            # tasks[0] contains two part, ([req1, ...] ,real_bsz)
-            # tasks[0][0] is [req1, ...]
-            # if empty batch is delived, eval(tasks[0][0]) should be False ([]),
-            # if batch with requests is delived, eval(tasks[0][0]) should be True, then to be processed as below.
-            if tasks[0][0]:
-                for req_dict, bsz in tasks:
-                    if len(req_dict) > 0 and isinstance(req_dict[0], ControlRequest):
-                        control_reqs.append(req_dict[0])
-                    else:
-                        max_occupied_batch_index = int(bsz)
-                        req_dicts.extend(req_dict)
-            # todo: run control request async
-            if len(control_reqs) > 0:
-                logger.info(f"Rank: {self.local_rank} received {len(control_reqs)} control request.")
-                for control_req in control_reqs:
-                    if self.parallel_config.use_ep:
-                        self.cached_control_reqs.append(control_req)
-                        logger.info(f"Rank: {self.local_rank} cached ep control request: {control_req}")
-                    else:
-                        self.run_control_method(control_req)
-                        self._tp_barrier_wait() if tp_size > 1 else None
+            batch_request, control_reqs, max_occupied_batch_index = BatchRequest.from_tasks(tasks)
+            if len(control_reqs) > 0:
+                logger.info(f"Rank: {self.local_rank} received {len(control_reqs)} control request.")
+                for control_req in control_reqs:
+                    if self.parallel_config.use_ep:
+                        self.cached_control_reqs.append(control_req)
+                        logger.info(f"Rank: {self.local_rank} cached ep control request: {control_req}")
+                    else:
+                        self.run_control_method(control_req)
+                        self._tp_barrier_wait() if tp_size > 1 else None
-            if len(req_dicts) > 0:
+            if len(batch_request) > 0:
                 # Count prefill requests in current batch
-                num_prefill_requests = sum(1 for req in req_dicts if req.task_type == RequestType.PREFILL)
-                num_scheduled_requests = len(req_dicts)
-                scheduled_request_ids = [req.request_id for req in req_dicts]
+                num_prefill_requests = sum(1 for req in batch_request if req.task_type == RequestType.PREFILL)
+                num_scheduled_requests = len(batch_request)
+                scheduled_request_ids = [req.request_id for req in batch_request]
                 logger.info(
                     f"Rank: {self.local_rank}, num_prefill_requests: {num_prefill_requests}, "
                     f"max_occupied_batch_index: {max_occupied_batch_index}, "
@@ -590,7 +583,7 @@ class PaddleDisWorkerProc:
                 )
                 # Process prefill inputs
-                self.worker.preprocess_new_task(req_dicts, max_occupied_batch_index)
+                self.worker.preprocess_new_task(batch_request, max_occupied_batch_index)
             else:
                 if self.scheduler_config.splitwise_role == "prefill":
                     if tp_size > 1:
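`BatchRequest.from_tasks` collapses the deleted parsing loop into one call. Based solely on the removed code, a sketch of the behavior such a classmethod would need to reproduce (a hypothetical reconstruction, not the actual implementation; the real version presumably returns a `BatchRequest` rather than a plain list, and the `ControlRequest` stub below stands in for `fastdeploy.engine.request.ControlRequest`):

```python
class ControlRequest:  # stand-in for fastdeploy.engine.request.ControlRequest
    pass


def from_tasks(tasks):
    """Split raw task tuples into (requests, control requests, max batch index).

    Each task is ([req1, ...], real_bsz). A batch whose first element is a
    ControlRequest is routed to control_reqs; otherwise its requests are
    collected and the occupied batch index is updated, mirroring the
    removed loop.
    """
    requests, control_reqs, max_occupied_batch_index = [], [], 0
    # In EP + DP prefill, an empty first batch acts as a barrier: skip it.
    if tasks and tasks[0][0]:
        for req_list, bsz in tasks:
            if req_list and isinstance(req_list[0], ControlRequest):
                control_reqs.append(req_list[0])
            else:
                max_occupied_batch_index = int(bsz)
                requests.extend(req_list)
    return requests, control_reqs, max_occupied_batch_index
```

Centralizing this in a classmethod means the worker loop no longer needs to know the on-queue tuple layout, which is what lets the later call sites operate on `batch_request` uniformly.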