FastDeploy/fastdeploy/cache_manager/v1/cache_controller.py
kevin 7707be8384 [Feature][KVCache] Implement Cache Manager V1 with GPU + CPU Cache Support (1/n) (#7097)
* [Feature][KVCache] Support cache manager v1 architecture

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Update cache manager and related modules

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* chore: update cache_manager and related modules

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: add node to evictable set in complete_swap_to_device

When a node transitions from SWAP_TO_DEVICE to DEVICE via
complete_swap_to_device, it was not being added to the
_evictable_device set. This caused nodes with ref_count=0 to
become "orphaned" - not appearing in any evictable set despite
having cache_status=DEVICE.
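
The invariant behind this fix can be sketched in a few lines. `Node`, `CacheStatus`, and `TinyTree` are hypothetical stand-ins, not the real radix tree classes; only `complete_swap_to_device`, `_evictable_device`, `ref_count`, and `cache_status` are names taken from the commit message.

```python
from dataclasses import dataclass
from enum import Enum, auto


class CacheStatus(Enum):
    SWAP_TO_DEVICE = auto()
    DEVICE = auto()


@dataclass(eq=False)  # eq=False keeps identity hashing, so nodes can live in a set
class Node:
    node_id: int
    ref_count: int = 0
    cache_status: CacheStatus = CacheStatus.SWAP_TO_DEVICE


class TinyTree:
    """Hypothetical miniature of the tree's evictable-set bookkeeping."""

    def __init__(self) -> None:
        self._evictable_device = set()

    def complete_swap_to_device(self, node: Node) -> None:
        node.cache_status = CacheStatus.DEVICE
        # The step the fix adds: an unreferenced DEVICE node must be
        # registered as evictable, otherwise it is orphaned.
        if node.ref_count == 0:
            self._evictable_device.add(node)
```

A node that is still referenced (`ref_count > 0`) stays out of the set in this sketch; when the real implementation re-adds it is not shown in this commit message.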

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* feat: update cache manager v1 and related modules

- Add new cache_manager.py with cache management functionality
- Add radix_tree.py for prefix caching
- Update block_pool.py and metadata.py
- Update request.py and resource_manager_v1.py for scheduling
- Update gpu_model_runner.py for GPU model execution

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* feat(cache): add cache controller v1 implementation

- Add CacheController class for cache management
- Update config.py with cache related configurations
- Refactor gpu_model_runner.py for improved cache handling

* feat(cache_manager): update cache manager v1

* fix(cache_manager): fix block_ids logic for the swap_cache H2D/D2H directions and clean up ForwardMeta

## Motivation

Fix the incorrect use of src/dst block_ids in the H2D direction in swap_cache_optimized.cu,
and remove the deprecated cache_controller field from ForwardMeta.

## Modifications

- fix: in swap_cache_optimized.cu, select src/dst block_ids according to the D2H template parameter,
  fixing the swapped src/dst bug in the H2D direction (fixed in both SwapCachePerLayerImpl and SwapCacheAllLayersBatchImpl)
- refactor: cache_manager/v1/__init__.py now imports LayerSwapTimeoutError from
  cache_utils (its correct source) instead of cache_controller
- refactor: remove the deprecated cache_controller field from ForwardMeta
- refactor: remove the corresponding cache_controller assignment in gpu_model_runner.py
- test: add unit tests in tests/cache_manager/v1/test_swap_cache_ops.py

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* feat(cache_manager): refactor cache manager v1 and optimize swap ops

## Motivation

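Refactor and optimize cache manager v1 to simplify the code structure and improve maintainability.

## Modifications

- Refactor transfer_manager.py, substantially simplifying its logic
- Optimize the GPU operator implementation in swap_cache_optimized.cu
- Adjust the logic in cache_manager.py and cache_controller.py, and fix the missing free_device_blocks method
- Update block_pool.py, cache_utils.py, metadata.py, radix_tree.py
- Simplify the related calls in gpu_model_runner.py, forward_meta.py, attention.py
- Update the corresponding unit tests (test_cache_controller, test_swap_cache_ops, test_transfer_manager)
- Adjust the related options in config.py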

* [KVCache][MTP] Support MTP KV Cache initialization and multimodal hash under cache_manager_v1

## Motivation

On the enable_cache_manager_v1 path, the MTP (speculative decode) KV Cache needs to be
managed by CacheController so that it can reuse the swap/transfer capabilities. This change
also fixes block hashes not carrying multimodal extra_keys in multimodal scenarios.

## Modifications

- `cache_controller.py`
  - Add `initialize_mtp_kv_cache`: initializes the MTP KV Cache through CacheController
    and registers it in cache_kvs_map, so that transfer_manager automatically covers the MTP layers
  - `initialize_host_cache`: num_layers now includes the extra MTP cache layers, so that
    the Host Cache also allocates enough space for MTP
  - Rename `_free_gpu_cache` to `free_gpu_cache` (callable from outside)

- `cache_utils.py`
  - Add `get_block_hash_extra_keys`: extracts the multimodal hash information within a single block,
    aligned with the multimodal extra_keys logic of PrefixCacheManager
  - `get_request_block_hasher`: pass extra_keys to hash_block_tokens,
    fixing inaccurate prefix cache hit rates in multimodal scenarios

- `spec_decode/mtp.py`
  - `update_mtp_block_num`: add a `skip_cache_init` parameter to avoid initializing the
    MTP KV Cache twice on the v1 cache manager path

- `gpu_model_runner.py`
  - `initialize_kv_cache(v1)` path: after the main model cache is initialized, call
    `cache_controller.initialize_mtp_kv_cache` to create the MTP cache
  - `clear_cache` / `wakeup` / `reset` and similar paths: respect the `enable_cache_manager_v1`
    flag and skip the redundant proposer.initialize_kv_cache call

## Usage or Command

```bash
# Start an inference service with MTP + cache_manager_v1 (example)
bash run.sh
```

* fix(cache_manager): multi-GPU fix, mm hash boundary fix, and remove batch ops

1. Fix CuPy stream/event creation for multi-GPU: wrap all stream operations
   with cp.cuda.Device(device_id) context to ensure streams/events are bound
   to the correct device, preventing cross-device errors in multi-GPU setups.

2. Remove cudaSetDevice from SwapCacheAllLayers (handled by cupy context now).

3. Remove swap_cache_all_layers_batch op: simplified the implementation by
   removing the batch upload variant; all-layer transfers now use the standard
   swap_cache_all_layers with cupy device context.

4. Fix mm hash boundary comparison in get_block_hash_extra_keys: change
   strict less-than (<) to less-than-or-equal (<=) so that multimodal items
   ending exactly at block start are correctly excluded.

5. Extract config fields to KVCacheBase: model_config, cache_config,
   quant_config, parallel_config are now set in the base class __init__ to
   avoid duplication in CacheController and CacheManager subclasses.

6. Translate metadata.py docstrings from Chinese to English for broader
   contributor accessibility.

7. Add test_cache_utils.py: comprehensive unit tests for
   get_block_hash_extra_keys covering all boundary and overlap scenarios.

8. Expand test suite: test_request.py cache fields tests, test_radix_tree.py
   backup candidate tests, test_transfer_manager.py and test_cache_manager.py
   multi-GPU and concurrent operation tests.
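
Item 4 above (the boundary comparison in `get_block_hash_extra_keys`) can be illustrated with a small overlap check. This is a sketch under the assumption that both the multimodal item and the block are half-open token ranges `[start, end)`; the function name and signature are hypothetical and do not reproduce the real helper.

```python
def item_overlaps_block(item_start: int, item_end: int, block_start: int, block_end: int) -> bool:
    """Return True if [item_start, item_end) and [block_start, block_end) share a token."""
    # With half-open ranges, an item whose end equals block_start contributes
    # no token to the block, so the comparison must be <= rather than <.
    if item_end <= block_start:
        return False
    # Symmetric case: the item starts at or after the end of the block.
    if item_start >= block_end:
        return False
    return True
```

With a strict `<` in the first comparison, an item covering tokens 0..3 would be treated as overlapping a block that starts at token 4, and its hash would be mixed into that block's extra keys.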

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix List import and move write_policy normalization to CacheManager

## Motivation

Fix two issues:
1. `List` was not imported in `fastdeploy/engine/request.py`, causing a pre-commit F821 error
2. The `write_policy` normalization logic (`write_through` → `write_through_selective`) should not live in `FDConfig`; it is moved into `CacheManager.__init__` so that it only affects the internal logic of Cache Manager V1

## Modifications

- `fastdeploy/engine/request.py`: add `List` to the `typing` import and remove the duplicate `CacheSwapMetadata` TYPE_CHECKING import, fixing F821/F811
- `fastdeploy/config.py`: remove the `write_policy` normalization logic
- `fastdeploy/cache_manager/v1/cache_manager.py`: move the normalization logic into `CacheManager.__init__`
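
A minimal sketch of the normalization described above, assuming it is a plain alias lookup. The helper name and the alias table are hypothetical; the real logic lives inside `CacheManager.__init__` and is not shown here.

```python
from typing import Optional

# Assumed alias table: only the write_through mapping is stated in the commit message.
_WRITE_POLICY_ALIASES = {"write_through": "write_through_selective"}


def normalize_write_policy(policy: Optional[str]) -> Optional[str]:
    """Map an aliased policy name to its internal name; leave other values untouched."""
    return _WRITE_POLICY_ALIASES.get(policy, policy)
```

Keeping this mapping inside the cache manager rather than in the global config means that code outside Cache Manager V1 still sees the value the user supplied.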

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix pre-commit code style issues

## Motivation

Fix the CI pre-commit code style check failures.

## Modifications

- `fastdeploy/engine/common_engine.py`: black formatting
- `fastdeploy/worker/worker_process.py`: black formatting + isort fix
- `fastdeploy/cache_manager/v1/storage/__init__.py`: isort fix
- `fastdeploy/worker/gpu_worker.py`: isort fix

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [Feature][KVCache] update cache_manager_v1 modules

## Motivation

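Update the Cache Manager V1 modules: add copyright information and improve module structure and maintainability.

## Modifications

- `fastdeploy/cache_manager/v1/` modules: add copyright headers and tidy the code structure
- `fastdeploy/config.py`: update configuration options
- `fastdeploy/engine/sched/resource_manager_v1.py`: scheduling-related updates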

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [Feature][KVCache] add BatchRequest.from_tasks and refactor worker task parsing

## Motivation

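Consolidate the duplicated task parsing logic in worker_process into BatchRequest, reducing redundancy and improving maintainability.

## Modifications

- `fastdeploy/engine/request.py`: add the `BatchRequest.from_tasks()` class method, which classifies task_queue tasks into inference requests and control requests in one place
- `fastdeploy/worker/worker_process.py`: use `BatchRequest.from_tasks()` instead of the inline parsing logic, and fix the duplicated control_reqs handling block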

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [Feature][KVCache] add NUMA affinity for host cache and skip swap cache tests

## Motivation

Improve the NUMA affinity of Host cache memory allocation to reduce cross-NUMA access latency;
also skip the swap cache ops tests (not supported in the current environment).

## Modifications

- `fastdeploy/cache_manager/v1/cache_controller.py`:
  - Add `_get_numa_node_for_gpu()`, which looks up the NUMA node of a GPU via nvidia-smi or sysfs
  - Add `_bind_to_closest_numa_node()`, which binds the current thread to the NUMA node closest to the GPU
  - Call the NUMA binding in `initialize_host_cache()` to improve H2D transfer performance
- `tests/cache_manager/v1/test_swap_cache_ops.py`: skip all test classes (`TestSwapCacheAllLayersCorrectness`, `TestSwapCacheAllLayersPerformance`, `TestSwapCacheRandomBlockIndices`)
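
The nvidia-smi lookup comes down to parsing one line of text. The sketch below isolates that step so it can be checked without a GPU. The function name is hypothetical, and the `NUMA IDs of closest CPU:` prefix is the output format that `_get_numa_node_for_gpu()` in this file expects, not something verified here against nvidia-smi itself.

```python
def parse_closest_numa_node(output: str) -> int:
    """Return the first NUMA node ID found in the text, or -1 if none can be parsed."""
    prefix = "NUMA IDs of closest CPU:"
    line = output.strip()
    if not line.startswith(prefix):
        return -1
    numa_str = line[len(prefix):].strip()
    # The value may be a single ID ("0"), a list ("0,1") or a range ("0-1");
    # in every case the first ID is used.
    first = numa_str.split(",")[0].split("-")[0].strip()
    return int(first) if first.isdigit() else -1
```

Returning `-1` for anything unparseable mirrors the method's own convention, which lets the caller fall through to the sysfs lookups.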

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix unittest failures for cache_manager_v1

Three unit tests failed because of interface changes or mocking issues and need to be fixed.

- tests/distributed/chunked_moe.py: `setup_model_runner` uses `__new__` to bypass `__init__`; set `enable_cache_manager_v1 = False` to fix the `AttributeError`
- tests/engine/test_resource_manager.py: `PrefixCacheManager` is imported locally, so the `patch` path is changed to where it is defined, `fastdeploy.cache_manager.prefix_cache_manager.PrefixCacheManager`
- tests/v1/test_resource_manager_v1.py: the fourth argument of `_trigger_preempt` changed from `list` to `BatchRequest`; update the test arguments and assertions

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] remove debug logging code

## Modifications

- fastdeploy/engine/request.py: remove the debugging logger and the debug logs in prompt_hashes
- fastdeploy/worker/worker_process.py: remove the debugging import and print statements in __main__

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix cupy device id caching and pickle for _match_result

## Motivation

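Fix two bugs:
1. Calling `cp.cuda.runtime.getDevice()` on every use in `transfer_manager.py` is fragile; the device ID should be cached in an instance variable at initialization so that later operations use a consistent device ID.
2. `__getstate__` in `request.py` did not skip `_match_result`. That field holds a BlockNode tree with parent/child circular references, which triggers `RecursionError` during pickling. `__setstate__` is also added so that the fields are restored to safe defaults after unpickling.

## Modifications

- `transfer_manager.py`: call `cp.cuda.runtime.getDevice()` at initialization and cache the result in `self._cupy_device_id`; subsequent `with cp.cuda.Device(...)` blocks and log messages use the cached value.
- `request.py`:
  - `__getstate__` adds `_match_result` to the skip set `_SKIP_KEYS`, preventing pickle failures caused by circular references.
  - Add `__setstate__`, which resets `_block_hasher` and `_match_result` to `None` after unpickling.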
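
The skip-and-restore pattern can be sketched as follows. The class is a hypothetical stand-in for the real `Request`, which has many more fields; only `_SKIP_KEYS`, `_block_hasher`, and `_match_result` are names taken from the commit message.

```python
import pickle


class Request:
    # Fields that must never be serialized (process-local or hard to pickle).
    _SKIP_KEYS = {"_block_hasher", "_match_result"}

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self._block_hasher = lambda tokens: hash(tuple(tokens))  # lambdas cannot be pickled
        self._match_result = object()  # stands in for the BlockNode tree

    def __getstate__(self) -> dict:
        # Serialize everything except the skipped fields.
        return {k: v for k, v in self.__dict__.items() if k not in self._SKIP_KEYS}

    def __setstate__(self, state: dict) -> None:
        self.__dict__.update(state)
        # Restore skipped fields to safe defaults so attribute access still works.
        self._block_hasher = None
        self._match_result = None


def roundtrip(req: Request) -> Request:
    """Pickle and unpickle a request, as happens when it crosses a process boundary."""
    return pickle.loads(pickle.dumps(req))
```

After `roundtrip(req)`, the copy keeps `request_id` while `_block_hasher` and `_match_result` are `None`, so any code that reads them must be prepared to rebuild them.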

* fix(test): fix unit test errors for _trigger_preempt and wakeup with MTP

- Use BatchRequest instead of list in test_trigger_preempt_records_tasks
- Add missing enable_cache_manager_v1 attr in TestSleepWakeupBehavior._make_runner

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix gpu_free_block_list returning wrong block IDs

## Motivation

The compatibility property `gpu_free_block_list` mistakenly used `list(range(N))`,
passing the return value of `available_blocks()` to `range()` as an integer.
As a result it returned a fake list `[0, 1, ..., N-1]` instead of the real free block IDs.

## Modifications

- `cache_manager/v1/cache_manager.py`: change `list(range(self._device_pool.available_blocks()))` to `list(self._device_pool.available_blocks())`

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] Fix TypeError caused by gpu_free_block_list receiving an int

## Motivation

The gpu_free_block_list property calls BlockPool.available_blocks(),
which returns an int (the number of free blocks). Wrapping an int in list() raises
TypeError: 'int' object is not iterable.

## Modifications

Change list(self._device_pool.available_blocks()) to
list(self._device_pool._free_blocks), which returns the list of free block indices directly.
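
The two `gpu_free_block_list` fixes are easier to follow side by side. The pool below is a hypothetical miniature that hands out blocks from the front of its free list; the real `BlockPool` is not shown in this file, and only `available_blocks()` and `_free_blocks` are names taken from the commit messages.

```python
from typing import List


class TinyBlockPool:
    """Hypothetical miniature pool: tracks free block IDs in a list."""

    def __init__(self, num_blocks: int) -> None:
        self._free_blocks: List[int] = list(range(num_blocks))

    def allocate(self, n: int) -> List[int]:
        taken, self._free_blocks = self._free_blocks[:n], self._free_blocks[n:]
        return taken

    def available_blocks(self) -> int:
        return len(self._free_blocks)  # a count, not the IDs


def free_list_first_attempt(pool: TinyBlockPool) -> List[int]:
    return list(range(pool.available_blocks()))  # fabricates IDs 0..N-1


def free_list_second_attempt(pool: TinyBlockPool) -> List[int]:
    return list(pool.available_blocks())  # raises TypeError: int is not iterable


def free_list_fixed(pool: TinyBlockPool) -> List[int]:
    return list(pool._free_blocks)  # copies the real free block IDs
```

Once blocks 0 and 1 of a six-block pool are allocated, the first attempt reports `[0, 1, 2, 3]`, including two blocks that are in use, while the fixed version reports `[2, 3, 4, 5]`.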

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [KVCache][CacheManager] Adapt pause/sleep/free_cache operations to the V1 CacheManager

## Motivation

The V1 CacheManager introduces a new reset_cache() interface, so the pause and sleep operations need to be adapted;
free_cache also needs to support an optional clear_storage parameter.

## Modifications

- cache_controller.py: free_cache gains a clear_storage parameter (default False);
  _clear_storage() is called only when clear_storage=True, avoiding unnecessary storage clearing
- common_engine.py: in the pause and sleep operations, when ENABLE_V1_KVCACHE_MANAGER is set,
  use cache_manager.reset_cache() instead of the old reset() and pause_transfer logic
- gpu_model_runner.py: on sleep, clear the MTP cache only when the V1 cache manager is not in use

## Usage or Command

# Start the service (V1 CacheManager)
python -m fastdeploy.entrypoints.openai.api_server \
  --enable-v1-kvcache-manager \
  ...

* [BugFix][KVCache] fix missing enable_cache_manager_v1 in test mocks and remove unused select_blocks_for_backup

- Remove unused `select_blocks_for_backup` method from radix_tree.py
- Fix `match_prefix` default param `skip_storage=True` and log order in cache_manager.py
- Sync test_gpu_model_runner.py with upstream/develop (add TestInsertTasksV1SplitwiseSuffix)
- Add `enable_cache_manager_v1=False` to all mock runners to fix AttributeError in CI

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] simplify _free_blocks in ResourceManagerV1 for non-v1 path

Remove redundant prefix_caching branch in else path; always call
recycle_gpu_blocks with full block_tables for non-cache-manager-v1 case.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [KVCache][Optimization][BugFix] fix and optimize block_pool, cache_manager, transfer_manager, request

## Motivation

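Fix several code quality issues in cache_manager v1, improve performance, and eliminate potential type-inconsistency bugs.

## Modifications

1. **block_pool.py**: `BlockPool.allocate` replaces the per-element pop loop with slicing plus a bulk set.update, removing Python loop overhead, O(n) → O(k) (bulk operations at the C level)
2. **cache_manager.py**: when prefix caching is disabled, `match_prefix` writes an empty `MatchResult()` before returning early, so callers do not crash when dereferencing `_match_result=None`
3. **transfer_manager.py**: `_build_device_layer_indices` also resets the four layer index lists when `_cache_kvs_map` is empty, preventing stale tensors from being used by the swap operators
4. **request.py**: `BatchRequest.append_swap_metadata` / `append_evict_metadata` now pass `src_type`/`dst_type` as `CacheLevel` enum values instead of strings when constructing `CacheSwapMetadata`, matching the declared field types; add the `CacheLevel` import; correct the return annotation of the `match_result` property to `Optional[MatchResult]`
5. **resource_manager_v1.py**: downgrade the `_allocate_gpu_blocks` log from `INFO` to `DEBUG`, removing log noise on the high-frequency scheduling path
6. **tests/engine/test_request.py**: update the `src_type`/`dst_type` assertions to `CacheLevel` enum values and add the `CacheLevel` import

## Usage or Command

Unit tests: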
```bash
source .venv/py310/bin/activate
cd baidu/FastDeploy
python -m pytest tests/cache_manager/v1/test_cache_manager.py -v
python -m pytest tests/cache_manager/v1/test_transfer_manager.py -v
python -m pytest tests/engine/test_request.py -v
```

* [BugFix][KVCache] Fix BlockPool.allocate returns all blocks when num_blocks=0

## Motivation

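When `allocate(num_blocks=0)` is called, a Python negative-index pitfall causes a serious error:
`-0 == 0`, so `self._free_blocks[-0:]` is equivalent to `self._free_blocks[0:]`,
which returns and empties the entire free block list instead of returning an empty list.

## Modifications

Add an early check for `num_blocks == 0` in `BlockPool.allocate` that returns `[]` directly,
avoiding the negative-index pitfall.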
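
The pitfall is easy to reproduce in isolation. The two functions below are a sketch and not the real `BlockPool.allocate`: they assume, as the description above implies, that allocation takes blocks from the tail of the free list with a negative slice.

```python
from typing import List


def allocate_buggy(free_blocks: List[int], num_blocks: int) -> List[int]:
    # For num_blocks == 0 this slice is free_blocks[-0:], i.e. free_blocks[0:],
    # which selects (and then deletes) every free block.
    allocated = free_blocks[-num_blocks:]
    del free_blocks[-num_blocks:]
    return allocated


def allocate_fixed(free_blocks: List[int], num_blocks: int) -> List[int]:
    # Early return: asking for zero blocks must not touch the free list.
    if num_blocks == 0:
        return []
    allocated = free_blocks[-num_blocks:]
    del free_blocks[-num_blocks:]
    return allocated
```

Calling `allocate_buggy([0, 1, 2, 3], 0)` hands back all four blocks and leaves the pool empty, whereas `allocate_fixed` returns `[]` and leaves the list intact.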

## Usage or Command

```bash
# Run the related unit tests to verify the fix
python -m pytest tests/cache_manager/v1/test_cache_manager.py -vv -s
```

* [KVCache][Test] add unit tests for cache_manager v1 modules

## Motivation

Fill in unit test coverage for the cache_manager/v1 modules so that the core methods are fully tested.

## Modifications

The following test files are added or extended; all 326 cases pass:

- tests/cache_manager/v1/test_block_pool.py (new)
  Covers BlockPool.get_metadata/set_metadata/resize, DeviceBlockPool/HostBlockPool
- tests/cache_manager/v1/test_metadata.py (new)
  Covers BlockNode, RadixTreeStats, MatchResult, CacheSwapMetadata, AsyncTaskHandler
- tests/cache_manager/v1/test_cache_utils.py (extended)
  Adds hash_block_tokens, get_request_block_hasher, LayerDoneCounter time tracking, and internal helper methods
- tests/cache_manager/v1/test_radix_tree.py (extended)
  Adds a dedicated TestCompleteSwapToDevice test class (6 cases)
- tests/cache_manager/v1/test_cache_manager.py (extended)
  Adds offload_to_host, load_from_host, the pending backup series, prepare_prefetch_metadata
- tests/cache_manager/v1/test_transfer_manager.py (extended)
  Adds the _swap_single_layer validation path, sync_input/output_stream, record_input_stream_event

## Usage or Command

```bash
# Run all newly added unit tests
source .venv/py310/bin/activate
python -m pytest tests/cache_manager/v1/test_block_pool.py \
  tests/cache_manager/v1/test_metadata.py \
  tests/cache_manager/v1/test_cache_utils.py \
  tests/cache_manager/v1/test_radix_tree.py \
  tests/cache_manager/v1/test_cache_manager.py \
  tests/cache_manager/v1/test_transfer_manager.py -v
# Expected result: 326 passed
```

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
2026-04-21 14:39:00 +08:00

1090 lines
42 KiB
Python

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import ctypes
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import paddle
from paddleformers.utils.log import logger

if TYPE_CHECKING:
    from fastdeploy.config import FDConfig

# Import ops for CPU cache allocation
from fastdeploy.cache_manager.ops import cuda_host_alloc, cuda_host_free

from .base import KVCacheBase
from .cache_utils import LayerDoneCounter
from .metadata import (
    AsyncTaskHandler,
    CacheLevel,
    CacheSwapMetadata,
    PDTransferMetadata,
    StorageMetadata,
    TransferResult,
)
from .transfer_manager import CacheTransferManager


class CacheController(KVCacheBase):
    """
    Cache Controller for Worker process.

    Inherits KVCacheBase, handles transfer tasks by block index only, does NOT manage BlockPool.
    BlockPool is managed by CacheManager. CacheController only executes transfers
    based on block IDs provided by Scheduler.

    All transfer methods are async - they submit tasks and return immediately,
    returning an AsyncTaskHandler for the caller to track completion.

    Three-level cache hierarchy:
        Level 1: Device (GPU) - Fastest access, directly used for inference
        Level 2: Host (CPU) - Medium speed, needs to be loaded to Device
        Level 3: Storage - Slowest, needs to be fetched to Host first

    Attributes:
        transfer_manager: CacheTransferManager instance.
        layer_counter: LayerDoneCounter instance.
        num_layers: Total number of model layers.
    """

    def __init__(self, config: "FDConfig", local_rank: int, device_id: int):
        """
        Initialize the Cache Controller.

        Args:
            config: FDConfig instance containing all fastdeploy configuration
            local_rank: Local rank of this worker, used when naming cache tensors.
            device_id: CUDA device ID, used for cache naming and NUMA binding.
        """
        super().__init__(config)
        self._num_layers = self.model_config.num_hidden_layers
        self._local_rank = local_rank
        self._device_id = device_id

        # cache_kvs_map: stores created kv cache tensors by name
        self.cache_kvs_map: Dict[str, Any] = {}
        # host_cache_kvs_map: stores Host (pinned memory) kv cache tensors by name for swap space
        self.host_cache_kvs_map: Dict[str, Any] = {}

        # Thread safety
        self._lock = threading.RLock()
        # Thread pool executor for async operations
        self._executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="cache_transfer")
        # Initialize transfer manager
        self._transfer_manager = CacheTransferManager(config, local_rank, device_id)

        # Note: LayerDoneCounter is no longer a singleton
        # Each submit_swap_tasks call creates a new LayerDoneCounter instance
        self._layer_done_counter = None
        # Pending evict LayerDoneCounters for write_back mode ordering
        self._pending_evict_counters: List["LayerDoneCounter"] = []
        self._initialized = True
        # NUMA binding flag
        self._numa_bound = False

    @property
    def write_policy(self) -> Optional[str]:
        """Get the write policy for cache operations."""
        if self.cache_config and hasattr(self.cache_config, "write_policy"):
            return self.cache_config.write_policy
        return None

    def _should_wait_for_swap_out(self) -> bool:
        """
        Determine if swap-out operations should wait synchronously.

        Returns:
            True if write_policy is 'write_back', otherwise False.
        """
        return self.write_policy == "write_back"

    def submit_swap_tasks(
        self,
        evict_metadata: Optional["CacheSwapMetadata"],
        swap_in_metadata: Optional["CacheSwapMetadata"],
    ) -> Optional["LayerDoneCounter"]:
        """
        Submit evict and swap-in tasks with proper synchronization.

        Logic:
            1. Before submitting evict, wait for existing pending evict counters to complete
            2. write_back: Wait for evict to complete before submitting swap-in
            3. Other policies: Submit both evict and swap-in immediately

        Args:
            evict_metadata: CacheSwapMetadata for device-to-host eviction (can be None)
            swap_in_metadata: CacheSwapMetadata for host-to-device swap-in (can be None)

        Returns:
            LayerDoneCounter for swap-in task, or None if no swap-in metadata provided.
        """
        # Step 1: Wait for existing pending evict counters before submitting new evict
        self._wait_for_pending_evict_counters()

        # Step 2: Submit evict task if provided
        # Note: evict returns LayerDoneCounter but we don't wait on it layer-by-layer
        # (except in write_back mode where we wait synchronously via wait_all)
        if evict_metadata is not None:
            evict_counter = self.evict_device_to_host(evict_metadata)
            self._pending_evict_counters.append(evict_counter)

        # Step 3: For write_back, wait for evict to complete before submitting swap-in
        if self._should_wait_for_swap_out():
            self._wait_for_pending_evict_counters()

        # Step 4: Submit swap-in task if provided
        # Returns LayerDoneCounter for tracking layer completion
        if swap_in_metadata is not None:
            self._layer_done_counter = self.load_host_to_device(swap_in_metadata)
            return self._layer_done_counter
        return None

    def _wait_for_pending_evict_counters(self) -> None:
        """
        Wait for all pending evict counters to complete.

        This is called before submitting new evict tasks to ensure proper ordering.
        Uses LayerDoneCounter.wait_all() for efficient waiting.
        """
        if not self._pending_evict_counters:
            return
        evict_wait_start = time.time()
        evict_length = len(self._pending_evict_counters)
        for counter in self._pending_evict_counters:
            counter.wait_all()
        self._pending_evict_counters.clear()
        evict_wait_ms = (time.time() - evict_wait_start) * 1000
        if evict_wait_ms > 0.1:
            logger.info(f"cache evict wait time: {evict_wait_ms:.2f}ms, {evict_length} pending evictions")

    # ============ Properties ============
    @property
    def transfer_manager(self) -> CacheTransferManager:
        """Get the transfer manager."""
        return self._transfer_manager

    @property
    def swap_layer_done_counter(self) -> Optional["LayerDoneCounter"]:
        """Get the layer done counter for layer swap."""
        return self._layer_done_counter

    # ============ Helper Methods ============
    def _get_kv_cache_quant_type(self) -> Optional[str]:
        """Get KV cache quantization type."""
        if (
            self.quant_config
            and hasattr(self.quant_config, "kv_cache_quant_type")
            and self.quant_config.kv_cache_quant_type is not None
        ):
            return self.quant_config.kv_cache_quant_type
        return None

    def _is_fp8_quantization(self, quant_type: Optional[str] = None) -> bool:
        """Check if using fp8 quantization."""
        if quant_type is None:
            quant_type = self._get_kv_cache_quant_type()
        return quant_type == "block_wise_fp8"

    def _get_cache_names(self, layer_idx: int) -> Dict[str, str]:
        """
        Generate cache names for a layer.

        Args:
            layer_idx: Layer index.

        Returns:
            Dictionary with cache names: {
                "key": "key_caches_{layer}_rank{rank}.device{device}",
                "value": "value_caches_{layer}_rank{rank}.device{device}",
                "key_scale": "key_cache_scales_{layer}_rank{rank}.device{device}",
                "value_scale": "value_cache_scales_{layer}_rank{rank}.device{device}",
            }
        """
        local_rank = self._local_rank % self.parallel_config.tensor_parallel_size
        return {
            "key": f"key_caches_{layer_idx}_rank{local_rank}.device{self._device_id}",
            "value": f"value_caches_{layer_idx}_rank{local_rank}.device{self._device_id}",
            "key_scale": f"key_cache_scales_{layer_idx}_rank{local_rank}.device{self._device_id}",
            "value_scale": f"value_cache_scales_{layer_idx}_rank{local_rank}.device{self._device_id}",
        }

    # ============ KV Cache Management ============
    def get_kv_caches(self) -> Optional[Dict[str, Any]]:
        """
        Get the current KV Cache tensor dictionary.

        Returns:
            KV Cache tensor dictionary (empty if the cache has not been initialized).
        """
        with self._lock:
            return self.cache_kvs_map

    def initialize_kv_cache(
        self,
        attn_backend: Any,
        num_gpu_blocks: int,
    ) -> List[Any]:
        """
        Initialize KV Cache tensors.

        Create KV Cache tensors on GPU for storing attention Key and Value.

        Args:
            attn_backend: Attention backend instance for getting kv cache shape.
            num_gpu_blocks: Maximum number of blocks on GPU.

        Returns:
            cache_kvs_list: KV Cache tensor list in [key_cache_layer0, value_cache_layer0, ...] order.
        """
        # Get kv cache quantization type
        kv_cache_quant_type = self._get_kv_cache_quant_type()

        # Get kv cache shape
        key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape(
            max_num_blocks=num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
        )

        # Get scale shape for block_wise_fp8 quantization
        kv_cache_scale_shape = None
        if self._is_fp8_quantization(kv_cache_quant_type):
            kv_cache_scale_shape = [key_cache_shape[0], key_cache_shape[1], key_cache_shape[2]]

        logger.info(f"Initializing kv cache for all layers. num_layers={self._num_layers}")
        cache_kvs_list = []
        for i in range(self._num_layers):
            # Generate cache names
            cache_names = self._get_cache_names(i)
            logger.info(f"..creating kv cache for layer {i}: key:{key_cache_shape}, value:{value_cache_shape}")

            # Create key cache and value cache
            key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=self.model_config.dtype)
            self.cache_kvs_map[cache_names["key"]] = key_cache
            val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=self.model_config.dtype)
            self.cache_kvs_map[cache_names["value"]] = val_cache
            cache_kvs_list.extend([key_cache, val_cache])

            # Create scale caches for block_wise_fp8 quantization
            if self._is_fp8_quantization(kv_cache_quant_type) and kv_cache_scale_shape:
                key_cache_scales = paddle.full(
                    shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
                )
                val_cache_scales = paddle.full(
                    shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
                )
                self.cache_kvs_map[cache_names["key_scale"]] = key_cache_scales
                self.cache_kvs_map[cache_names["value_scale"]] = val_cache_scales
                cache_kvs_list.extend([key_cache_scales, val_cache_scales])

        paddle.device.cuda.empty_cache()
        logger.info("kv cache is initialized!")

        # Share cache_kvs_map with transfer manager for data transfer operations
        self._transfer_manager.set_cache_kvs_map(self.cache_kvs_map)
        # Initialize host cache
        self.initialize_host_cache(attn_backend)
        return cache_kvs_list

    def initialize_mtp_kv_cache(
        self,
        attn_backend: Any,
        num_gpu_blocks: int,
        num_mtp_layers: int,
        layer_offset: int,
    ) -> List[Any]:
        """
        Initialize MTP (speculative decode) KV Cache tensors.

        MTP cache layers use indices [layer_offset, layer_offset + num_mtp_layers),
        so they share the same cache_kvs_map namespace as the main model cache but
        with non-overlapping layer indices. All subsequent transfer operations
        via CacheController automatically cover MTP layers as well because they
        live in the same cache_kvs_map.

        Args:
            attn_backend: MTP attention backend instance (proposer.attn_backends[0]).
            num_gpu_blocks: Number of GPU blocks for MTP (already expanded by ratio).
            num_mtp_layers: Number of MTP model layers (proposer.model_config.num_hidden_layers).
            layer_offset: Starting layer index, equals main model num_hidden_layers.

        Returns:
            cache_kvs_list: KV Cache tensor list in [key_layer0, val_layer0, ...] order.
        """
        kv_cache_quant_type = self._get_kv_cache_quant_type()
        key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape(
            max_num_blocks=num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
        )
        kv_cache_scale_shape = None
        if self._is_fp8_quantization(kv_cache_quant_type):
            kv_cache_scale_shape = [key_cache_shape[0], key_cache_shape[1], key_cache_shape[2]]

        logger.info(
            f"[CacheController] Initializing MTP kv cache for {num_mtp_layers} layers "
            f"(layer_offset={layer_offset}, num_gpu_blocks={num_gpu_blocks})."
        )
        cache_kvs_list = []
        for i in range(layer_offset, layer_offset + num_mtp_layers):
            cache_names = self._get_cache_names(i)
            key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=self.model_config.dtype)
            self.cache_kvs_map[cache_names["key"]] = key_cache
            val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=self.model_config.dtype)
            self.cache_kvs_map[cache_names["value"]] = val_cache
            cache_kvs_list.extend([key_cache, val_cache])

            if self._is_fp8_quantization(kv_cache_quant_type) and kv_cache_scale_shape:
                key_cache_scales = paddle.full(
                    shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
                )
                val_cache_scales = paddle.full(
                    shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype()
                )
                self.cache_kvs_map[cache_names["key_scale"]] = key_cache_scales
                self.cache_kvs_map[cache_names["value_scale"]] = val_cache_scales
                cache_kvs_list.extend([key_cache_scales, val_cache_scales])

        paddle.device.cuda.empty_cache()
        logger.info("[CacheController] MTP kv cache initialized!")

        # Refresh transfer manager so it sees the full map (main + MTP layers)
        self._transfer_manager.set_cache_kvs_map(self.cache_kvs_map)
        return cache_kvs_list

    def _get_numa_node_for_gpu(self, device_id: int) -> int:
        """
        Get the NUMA node closest to the specified GPU device.

        Tries multiple methods in order:
            1. nvidia-smi topo -C -i <gpu_id> (fastest and most reliable)
            2. /sys/class/nvidia-gpu/ (direct sysfs)
            3. /sys/bus/pci/devices/ (fallback)

        Args:
            device_id: CUDA device ID.

        Returns:
            NUMA node index, or -1 if cannot be determined.
        """
        try:
            # Method 1: Use nvidia-smi topo -C -i (fastest, SGLang-style)
            # This directly outputs the NUMA ID for the specific GPU
            try:
                import subprocess

                result = subprocess.run(
                    ["nvidia-smi", "topo", "-C", "-i", str(device_id)], capture_output=True, text=True, timeout=5
                )
                if result.returncode == 0:
                    output_line = result.stdout.strip()
                    prefix = "NUMA IDs of closest CPU:"
                    if output_line.startswith(prefix):
                        numa_str = output_line[len(prefix) :].strip()
                        # Handle comma-separated or range values (e.g., "0" or "0,1" or "0-1")
                        if numa_str:
                            # Take the first NUMA node if multiple are listed
                            first_numa = numa_str.split(",")[0].split("-")[0].strip()
                            if first_numa.isdigit():
                                return int(first_numa)
            except Exception as e:
                # Covers a missing nvidia-smi binary (FileNotFoundError) and timeouts
                # (subprocess.TimeoutExpired); both are subclasses of Exception.
                logger.debug(f"[CacheController] nvidia-smi topo -C method failed: {e}")

            # Method 2: Try to read from /sys filesystem
            sys_path = f"/sys/class/nvidia-gpu/nvidia{device_id}/device/numa_node"
            if os.path.exists(sys_path):
                with open(sys_path, "r") as f:
                    return int(f.read().strip())

            # Method 3: Fallback - check all NVIDIA PCI devices
            # Note: this returns the NUMA node of the first NVIDIA PCI device found,
            # which is not necessarily the device identified by device_id.
            import glob

            numa_paths = glob.glob("/sys/bus/pci/devices/*/numa_node")
            for path in numa_paths:
                vendor_path = path.replace("numa_node", "vendor")
                if os.path.exists(vendor_path):
                    with open(vendor_path, "r") as f:
                        vendor = f.read().strip()
                    if vendor == "0x10de":  # NVIDIA vendor ID
                        with open(path, "r") as f:
                            return int(f.read().strip())
            return -1
        except Exception as e:
            logger.debug(f"[CacheController] Failed to get NUMA node for GPU {device_id}: {e}")
            return -1

    def _bind_to_closest_numa_node(self) -> bool:
        """
        Bind current thread and memory allocation to the NUMA node closest to the GPU.

        This should be called before allocating host memory to ensure the memory
        is allocated on the NUMA node local to the GPU, reducing cross-NUMA access
        latency during H2D transfers.

        Returns:
            True if binding was successful, False otherwise.
        """
        if self._numa_bound:
            return True
        try:
            # Load libnuma
            try:
                libnuma = ctypes.CDLL("libnuma.so.1")
            except OSError:
                try:
                    libnuma = ctypes.CDLL("libnuma.so")
                except OSError:
                    logger.warning("[CacheController] libnuma not found, NUMA binding skipped")
                    return False

            # Check if NUMA is available
            if libnuma.numa_available() < 0:
                logger.warning("[CacheController] NUMA is not available on this system")
                return False

            # Get NUMA node for current GPU
            numa_node = self._get_numa_node_for_gpu(self._device_id)
            if numa_node < 0:
                logger.warning(f"[CacheController] Could not determine NUMA node for GPU {self._device_id}")
                return False

            # Bind current thread to specific NUMA node
            # numa_run_on_node binds the current thread to run on the specified node
            result = libnuma.numa_run_on_node(numa_node)
            if result < 0:
                logger.warning(f"[CacheController] numa_run_on_node({numa_node}) failed")
                return False

            # Set memory allocation preference to the specified NUMA node
            # This affects subsequent memory allocations (including cudaHostAlloc)
            libnuma.numa_set_preferred(numa_node)
            self._numa_bound = True
            logger.info(
                f"[CacheController] NUMA binding successful: GPU {self._device_id} bound to NUMA node {numa_node}"
            )
            return True
        except Exception as e:
            logger.warning(f"[CacheController] NUMA binding failed: {e}")
            return False

def initialize_host_cache(
self,
attn_backend: Any,
) -> Dict[str, Any]:
"""
Initialize Host (Pinned Memory) KV Cache.
Use cuda_host_alloc to allocate pinned memory for fast Host-Device data transfer.
Called during initialization to create Host-side swap space.
Args:
attn_backend: Attention backend instance for getting kv cache shape.
Returns:
host_cache_kvs_map: Host KV Cache pointer dictionary, indexed by name.
"""
num_host_blocks = self.cache_config.num_cpu_blocks
if num_host_blocks == 0:
logger.info("[CacheController] No swap space (Host cache) specified, skipping initialization.")
return self.host_cache_kvs_map
if len(self.host_cache_kvs_map) > 0:
return self.host_cache_kvs_map
# Step 0: Bind to closest NUMA node before allocating host memory
# This ensures subsequent cuda_host_alloc allocations are on the local NUMA node
if not self._numa_bound:
self._bind_to_closest_numa_node()
# Get kv cache quantization type
kv_cache_quant_type = self._get_kv_cache_quant_type()
# Get kv cache shape (pass num_host_blocks as max_num_blocks for host cache)
key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape(
max_num_blocks=num_host_blocks, kv_cache_quant_type=kv_cache_quant_type
)
# Calculate cache sizes (elements per block per layer)
key_cache_size = key_cache_shape[1] * key_cache_shape[2] * key_cache_shape[3]
if value_cache_shape:
value_cache_size = value_cache_shape[1] * value_cache_shape[2] * value_cache_shape[3]
else:
value_cache_size = 0
# Get cache dtype and bytes per element
cache_dtype = self.cache_config.cache_dtype
cache_item_bytes = self.cache_config.get_cache_bytes(cache_dtype)
# Calculate total bytes to allocate
key_need_to_allocate_bytes = num_host_blocks * cache_item_bytes * key_cache_size
value_need_to_allocate_bytes = num_host_blocks * cache_item_bytes * value_cache_size
# Calculate scale sizes for block_wise_fp8 quantization
scales_key_need_to_allocate_bytes = 0
scales_value_need_to_allocate_bytes = 0
cache_scale_shape = None
if self._is_fp8_quantization(kv_cache_quant_type):
cache_scales_size = key_cache_shape[1] * key_cache_shape[2]
# Scale tensor uses default dtype (float32)
scale_bytes = 4 # float32
scales_key_need_to_allocate_bytes = num_host_blocks * scale_bytes * cache_scales_size
scales_value_need_to_allocate_bytes = num_host_blocks * scale_bytes * cache_scales_size
cache_scale_shape = [num_host_blocks, key_cache_shape[1], key_cache_shape[2]]
num_layers = self._num_layers + self.config.speculative_config.num_extra_cache_layer
per_layer_size_gb = (key_need_to_allocate_bytes + value_need_to_allocate_bytes) / (1024**3)
actual_alloc_gb = per_layer_size_gb * num_layers
logger.info(
f"[CacheController] Allocating host swap space: {actual_alloc_gb:.2f}GB "
f"({per_layer_size_gb:.2f}GB per layer x {num_layers} layers), "
f"num_host_blocks: {num_host_blocks}"
)
logger.info(f"[CacheController] Initializing swap space (Host cache) for {num_layers} layers.")
# Allocate Host cache for each layer
for i in range(num_layers):
# Generate cache names
cache_names = self._get_cache_names(i)
logger.info(
f"[CacheController] Creating Host cache for layer {i}: "
f"key={(key_need_to_allocate_bytes / 1024 ** 3):.2f}GB, "
f"value={(value_need_to_allocate_bytes / 1024 ** 3):.2f}GB"
)
# Allocate key cache using cuda_host_alloc (pinned memory)
self.host_cache_kvs_map[cache_names["key"]] = cuda_host_alloc(key_need_to_allocate_bytes)
# Allocate scale cache for block_wise_fp8 quantization
if self._is_fp8_quantization(kv_cache_quant_type):
self.host_cache_kvs_map[cache_names["key_scale"]] = cuda_host_alloc(scales_key_need_to_allocate_bytes)
# Allocate value cache if needed
if value_need_to_allocate_bytes > 0:
self.host_cache_kvs_map[cache_names["value"]] = cuda_host_alloc(value_need_to_allocate_bytes)
if self._is_fp8_quantization(kv_cache_quant_type):
self.host_cache_kvs_map[cache_names["value_scale"]] = cuda_host_alloc(
scales_value_need_to_allocate_bytes
)
logger.info(f"[CacheController] Swap space (Host cache) is ready for {num_layers} layers!")
# Store shapes for later use
self._host_key_cache_shape = [num_host_blocks] + list(key_cache_shape[1:])
self._host_value_cache_shape = ([num_host_blocks] + list(value_cache_shape[1:])) if value_cache_shape else None
self._host_cache_scale_shape = cache_scale_shape
self._num_host_blocks = num_host_blocks
# Share host_cache_kvs_map with transfer manager
self._transfer_manager.set_host_cache_kvs_map(self.host_cache_kvs_map)
return self.host_cache_kvs_map
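# Illustrative sketch (not part of CacheController): the per-layer byte counts
# computed above follow num_blocks * bytes_per_element * elements_per_block,
# and the fp8 scale buffers use float32 over the first two per-block
# dimensions. The shape values below are made-up examples, and
# `per_layer_bytes` / `per_layer_scale_bytes` are hypothetical helper names.

```python
from math import prod
from typing import Optional, Sequence


def per_layer_bytes(num_blocks: int, shape: Optional[Sequence[int]], item_bytes: int) -> int:
    """Bytes for one layer's cache; shape[0] is ignored in favour of num_blocks."""
    if not shape:
        return 0  # mirrors the "no value cache" case
    # For a 4-D shape this equals shape[1] * shape[2] * shape[3].
    return num_blocks * item_bytes * prod(shape[1:])


def per_layer_scale_bytes(num_blocks: int, shape: Sequence[int], scale_item_bytes: int = 4) -> int:
    """Bytes for one layer's scale buffer: float32 over shape[1] * shape[2]."""
    return num_blocks * scale_item_bytes * shape[1] * shape[2]


# Assumed example layout: [blocks, kv_heads, block_size, head_dim], 2-byte elements.
key_shape = [64, 8, 64, 128]
key_bytes = per_layer_bytes(64, key_shape, item_bytes=2)
value_bytes = per_layer_bytes(64, None, item_bytes=2)
scale_bytes = per_layer_scale_bytes(64, key_shape)
```

# With these assumed numbers, key_bytes is 8 MiB and scale_bytes is 128 KiB per layer.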
def get_host_cache_kvs_map(self) -> Dict[str, Any]:
"""
Get the Host KV Cache pointer dictionary.
Returns:
Host KV Cache pointer dictionary, empty dict if not initialized.
"""
return self.host_cache_kvs_map
# ============ Worker Methods ============
def _submit_swap_task(
self,
meta: CacheSwapMetadata,
src_location: CacheLevel,
dst_location: CacheLevel,
transfer_fn_all: callable,
transfer_fn_layer: callable,
force_all_layers: bool = False,
) -> LayerDoneCounter:
"""
Submit a single swap transfer task (internal method).
Creates a LayerDoneCounter for tracking layer completion.
The counter is returned to the caller for later waiting.
H2D (load) always uses layer-by-layer mode for compute-transfer overlap.
D2H (evict) always uses all-layers mode via _output_stream (fire-and-forget).
Args:
meta: CacheSwapMetadata containing src_block_ids and dst_block_ids.
src_location: Source cache level (CacheLevel.HOST or CacheLevel.DEVICE).
dst_location: Destination cache level (CacheLevel.DEVICE or CacheLevel.HOST).
transfer_fn_all: All-layer transfer function, signature (src_ids, dst_ids) -> bool.
transfer_fn_layer: Layer-by-layer transfer function, signature (layer_indices, on_layer_complete, src_ids, dst_ids) -> bool.
force_all_layers: If True, always use all-layers mode (used for D2H evict).
Returns:
LayerDoneCounter instance for tracking layer completion.
"""
# Create LayerDoneCounter for this transfer (independent sync primitive)
layer_counter = LayerDoneCounter(self._num_layers)
src_block_ids = meta.src_block_ids
dst_block_ids = meta.dst_block_ids
if not src_block_ids or not dst_block_ids:
logger.info(f"[SwapTask] skip: empty block_ids src={src_block_ids}, dst={dst_block_ids}")
meta.success = False
meta.error_message = "Empty block IDs in CacheSwapMetadata"
return layer_counter
layers_to_transfer = list(range(self._num_layers))
def _on_layer_complete(layer_idx: int) -> None:
"""Callback called after each layer's H2D kernel is submitted to input_stream.
Records a CUDA event on input_stream so that wait_for_layer() can
synchronize on the actual transfer stream (cross-stream dependency).
"""
# Record event on _input_stream so wait_for_layer() waits for the real H2D transfer.
# Must use input_stream (not Paddle default stream) to capture the correct dependency.
stream_event = self._transfer_manager.record_input_stream_event()
if stream_event is not None:
layer_counter.set_layer_event(layer_idx, stream_event)
# Mark layer done (adds to _completed_layers, unblocks polling fallback)
layer_counter.mark_layer_done(layer_idx)
def _do_transfer():
try:
start_time = time.time()
if force_all_layers:
success = transfer_fn_all(src_block_ids, dst_block_ids)
elapsed = time.time() - start_time
if success:
# For H2D transfers: record event on _input_stream so that
# wait_all() synchronizes on the actual transfer stream, not
# Paddle's default stream. set_layer_event must be called
# before mark_all_done() so wait_all()'s loop finds the event.
if dst_location == CacheLevel.DEVICE:
stream_event = self._transfer_manager.record_input_stream_event()
if stream_event is not None:
layer_counter.set_layer_event(self._num_layers - 1, stream_event)
# Mark all layers done at once
layer_counter.mark_all_done()
result = TransferResult(
src_block_ids=src_block_ids,
dst_block_ids=dst_block_ids,
src_type=src_location,
dst_type=dst_location,
success=success,
error_message=(
None if success else f"All-layer {src_location.value}->{dst_location.value} transfer failed"
),
)
logger.debug(
f"[SwapTask] all_layers {src_location.value}->{dst_location.value} "
f"{'success' if success else 'FAILED'} "
f"src={src_block_ids} dst={dst_block_ids} elapsed={elapsed*1000:.3f}ms"
)
else:
success = transfer_fn_layer(
layers_to_transfer,
_on_layer_complete,
src_block_ids,
dst_block_ids,
)
elapsed = time.time() - start_time
result = TransferResult(
src_block_ids=src_block_ids,
dst_block_ids=dst_block_ids,
src_type=src_location,
dst_type=dst_location,
success=success,
error_message=(
None
if success
else f"Layer-by-layer {src_location.value}->{dst_location.value} transfer failed"
),
)
logger.debug(
f"[SwapTask] layer_by_layer {src_location.value}->{dst_location.value} "
f"{'success' if success else 'FAILED'} "
f"src={src_block_ids} dst={dst_block_ids} elapsed={elapsed*1000:.3f}ms"
)
# Update metadata with result
meta.success = result.success
meta.error_message = result.error_message
except Exception as e:
import traceback
logger.error(
f"[SwapTask] {src_location.value}->{dst_location.value} "
f"EXCEPTION: {e}\n{traceback.format_exc()}"
)
meta.success = False
meta.error_message = str(e)
finally:
# Cleanup CUDA events when transfer is complete
layer_counter.cleanup()
self._executor.submit(_do_transfer)
return layer_counter
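# Illustrative sketch (not part of CacheController): the submit-and-return
# pattern above hands the caller a counter immediately while a worker thread
# marks layers done through a callback. `SimpleLayerCounter` is a hypothetical
# stand-in for LayerDoneCounter built on threading.Event; it ignores CUDA
# events and streams entirely, and `fake_layer_transfer` stands in for the
# real layer-by-layer transfer function.

```python
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List


class SimpleLayerCounter:
    """One Event per layer; wait_for_layer blocks until that layer is marked."""

    def __init__(self, num_layers: int) -> None:
        self._events = [threading.Event() for _ in range(num_layers)]

    def mark_layer_done(self, layer_idx: int) -> None:
        self._events[layer_idx].set()

    def wait_for_layer(self, layer_idx: int, timeout: float = 5.0) -> bool:
        # Returns False if the layer was not marked done within the timeout.
        return self._events[layer_idx].wait(timeout)


def fake_layer_transfer(layers: List[int], on_layer_complete: Callable[[int], None]) -> bool:
    """Pretend to copy each layer, invoking the callback after each one."""
    for idx in layers:
        on_layer_complete(idx)
    return True


def submit(executor: ThreadPoolExecutor, num_layers: int) -> SimpleLayerCounter:
    """Queue the transfer and return the counter without waiting for it."""
    counter = SimpleLayerCounter(num_layers)
    executor.submit(fake_layer_transfer, list(range(num_layers)), counter.mark_layer_done)
    return counter


with ThreadPoolExecutor(max_workers=1) as pool:
    counter = submit(pool, 4)
    # A consumer can start using layer 0 as soon as it is done, before layer 3.
    waited = [counter.wait_for_layer(i) for i in range(4)]
```

# Returning the counter before the work finishes is what lets the consumer
# overlap per-layer compute with the remaining transfers.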
def load_host_to_device(
self,
swap_metadata: CacheSwapMetadata,
) -> LayerDoneCounter:
"""
Load host cache to device (async).
Creates an async transfer task and returns LayerDoneCounter
for tracking layer completion.
Args:
swap_metadata: CacheSwapMetadata containing:
- src_block_ids: Source host block IDs
- dst_block_ids: Destination device block IDs
Returns:
LayerDoneCounter for tracking layer completion.
"""
layer_counter = self._submit_swap_task(
meta=swap_metadata,
src_location=CacheLevel.HOST,
dst_location=CacheLevel.DEVICE,
transfer_fn_all=None,
transfer_fn_layer=lambda layer_indices, on_layer_complete, src_ids, dst_ids: self._transfer_manager.load_layers_to_device_async(
layer_indices=layer_indices,
host_block_ids=src_ids,
device_block_ids=dst_ids,
on_layer_complete=on_layer_complete,
),
)
return layer_counter
def evict_device_to_host(
self,
swap_metadata: CacheSwapMetadata,
) -> LayerDoneCounter:
"""
Evict device cache to host (async).
Creates an async transfer task and returns LayerDoneCounter
for tracking layer completion.
Args:
swap_metadata: CacheSwapMetadata containing:
- src_block_ids: Source device block IDs
- dst_block_ids: Destination host block IDs
Returns:
LayerDoneCounter for tracking layer completion.
"""
layer_counter = self._submit_swap_task(
meta=swap_metadata,
src_location=CacheLevel.DEVICE,
dst_location=CacheLevel.HOST,
transfer_fn_all=lambda src_ids, dst_ids: self._transfer_manager.evict_to_host_async(src_ids, dst_ids),
transfer_fn_layer=None,
force_all_layers=True, # Eviction always uses output_stream for all-layers async transfer
)
return layer_counter
def prefetch_from_storage(
self,
metadata: StorageMetadata,
) -> AsyncTaskHandler:
"""
Prefetch storage cache to host (async).
When the Scheduler matches cache in storage, the Worker uses this method
to pull data from storage to host.
Args:
metadata: Storage transfer metadata, containing:
- hash_values: Hash values to fetch
- block_ids: Destination host block IDs (pre-allocated by Scheduler)
- Other storage-specific parameters
Returns:
AsyncTaskHandler for tracking the async transfer task.
"""
handler = AsyncTaskHandler()
# TODO: Implement storage prefetch logic
handler.set_error("Storage prefetch not implemented yet")
return handler
def backup_device_to_storage(
self,
device_block_ids: List[int],
metadata: StorageMetadata,
) -> AsyncTaskHandler:
"""
Backup device cache to storage (async).
Backup KV cache from device memory to external storage
for reuse by subsequent requests.
Args:
device_block_ids: Device block IDs to backup.
metadata: Storage transfer metadata.
Returns:
AsyncTaskHandler for tracking the async transfer task.
"""
handler = AsyncTaskHandler()
# TODO: Implement storage backup logic
handler.set_error("Storage backup not implemented yet")
return handler
def backup_host_to_storage(
self,
host_block_ids: List[int],
metadata: StorageMetadata,
) -> AsyncTaskHandler:
"""
Backup host cache to storage (async).
Backup KV cache from host memory to external storage.
Args:
host_block_ids: Host block IDs to backup.
metadata: Storage transfer metadata.
Returns:
AsyncTaskHandler for tracking the async transfer task.
"""
handler = AsyncTaskHandler()
# TODO: Implement storage backup logic
handler.set_error("Storage backup not implemented yet")
return handler
def send_to_node(
self,
metadata: PDTransferMetadata,
) -> AsyncTaskHandler:
"""
Send cache to another node (PD separation, async).
In the PD (prefill/decode) separation architecture, the P (prefill) node
uses this method to send KV cache to the D (decode) node.
Args:
metadata: PD transfer metadata, containing:
- target_node_id: Target node identifier
- block_ids: Block IDs to transfer
- Other transfer-specific parameters
Returns:
AsyncTaskHandler for tracking the async transfer task.
"""
handler = AsyncTaskHandler()
# TODO: Implement PD separation transfer logic
handler.set_error("PD transfer not implemented yet")
return handler
def wait_for_transfer_from_node(
self,
metadata: PDTransferMetadata,
) -> AsyncTaskHandler:
"""
Wait for cache transfer from another node (PD separation, async).
In the PD (prefill/decode) separation architecture, the D (decode) node
uses this method to wait for the P (prefill) node to send KV cache.
Args:
metadata: PD transfer metadata, containing:
- source_node_id: Source node identifier
- block_ids: Block IDs to receive
- Other transfer-specific parameters
Returns:
AsyncTaskHandler for tracking the async transfer task.
"""
handler = AsyncTaskHandler()
# TODO: Implement PD separation transfer wait logic
handler.set_error("PD transfer not implemented yet")
return handler
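# Illustrative sketch (not part of CacheController): the storage and PD
# methods above return a handler that is already in an error state. The
# AsyncTaskHandler API is not shown in this section, so this sketch swaps in
# the standard library's concurrent.futures.Future to show how a caller could
# observe such a pre-failed handle. `not_implemented_future` is a hypothetical
# helper name.

```python
from concurrent.futures import Future


def not_implemented_future(message: str) -> Future:
    """Return a Future that is already completed with NotImplementedError."""
    fut: Future = Future()
    fut.set_exception(NotImplementedError(message))
    return fut


fut = not_implemented_future("Storage prefetch not implemented yet")

# The caller can check completion and inspect the error without blocking.
is_done = fut.done()
error = fut.exception()
```

# Calling fut.result() here would raise the stored NotImplementedError.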
# ============ Public Interface Implementation ============
def reset_cache(self) -> bool:
"""
Reset cache state (clear content only, do NOT free storage).
This method only clears the transfer state:
- Clears pending evict counters
It does NOT free any memory (GPU memory or CPU pinned memory) or external storage.
Use free_cache() to release storage resources.
Returns:
True if successful, False otherwise.
"""
try:
with self._lock:
# Clear pending evict counters
self._pending_evict_counters.clear()
return True
except Exception:
return False
def free_cache(self, clear_storage: bool = False) -> bool:
"""
Free cache storage (GPU memory and CPU pinned memory; optionally external storage).
This releases the underlying storage resources rather than just clearing content.
Use this when shutting down or when cache resources must be fully released.
Args:
clear_storage: If True, also clear the storage connector's cache.
Returns:
True if successful, False otherwise.
"""
try:
# First reset transfer state
self.reset_cache()
# Free GPU cache
self.free_gpu_cache()
# Free CPU cache (pinned memory)
self._free_host_cache()
# Clear storage
if clear_storage:
self._clear_storage()
return True
except Exception:
return False
def free_gpu_cache(self) -> None:
"""Free GPU cache tensors stored in cache_kvs_map."""
if not hasattr(self, "cache_kvs_map") or not self.cache_kvs_map:
return
logger.info(f"[CacheController] Freeing GPU cache memory, {len(self.cache_kvs_map)} tensors.")
self.cache_kvs_map.clear()
paddle.device.cuda.empty_cache()
logger.info("[CacheController] GPU cache memory released.")
def _clear_storage(self) -> None:
"""Clear storage connector cache."""
storage_connector = getattr(self._transfer_manager, "_storage_connector", None)
if not storage_connector:
return
try:
if hasattr(storage_connector, "clear") and callable(storage_connector.clear):
count = storage_connector.clear()
logger.info(f"[CacheController] Cleared {count} entries from storage.")
elif hasattr(storage_connector, "disconnect") and callable(storage_connector.disconnect):
storage_connector.disconnect()
logger.info("[CacheController] Storage connector disconnected.")
except Exception as e:
logger.warning(f"[CacheController] Failed to clear storage: {e}")
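# Illustrative sketch (not part of CacheController): _clear_storage above
# prefers a connector's clear() and falls back to disconnect() when clear()
# is absent. The same duck-typed fallback in isolation, using hypothetical
# names (`clear_or_disconnect`, `OnlyDisconnect`):

```python
from typing import Any, Optional


def clear_or_disconnect(connector: Any) -> Optional[str]:
    """Call clear() if available, else disconnect(); report which one ran."""
    if connector is None:
        return None
    clear = getattr(connector, "clear", None)
    if callable(clear):
        clear()
        return "cleared"
    disconnect = getattr(connector, "disconnect", None)
    if callable(disconnect):
        disconnect()
        return "disconnected"
    return None


class OnlyDisconnect:
    """Connector that exposes disconnect() but no clear()."""

    def __init__(self) -> None:
        self.closed = False

    def disconnect(self) -> None:
        self.closed = True


conn = OnlyDisconnect()
action = clear_or_disconnect(conn)
```

# getattr with a default folds the hasattr check and the attribute lookup into one step.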
# ============ Statistics Methods ============
def get_stats(self) -> Dict[str, Any]:
"""Get controller statistics."""
with self._lock:
return {
"initialized": self._initialized,
"num_layers": self._num_layers,
"pending_evict_counters": len(self._pending_evict_counters),
"transfer_manager": self._transfer_manager.get_stats(),
}
def start(self) -> None:
"""Start the transfer manager."""
self._transfer_manager.start()
def stop(self) -> None:
"""Stop the transfer manager and shutdown thread pool."""
self._transfer_manager.stop()
# Shutdown thread pool executor
self._executor.shutdown(wait=False)
def __del__(self) -> None:
"""Destructor to release pinned host memory."""
try:
self._free_host_cache()
except Exception:
pass
def _free_host_cache(self) -> None:
"""Free pinned host memory allocated for swap space."""
if not hasattr(self, "host_cache_kvs_map"):
return
if not self.host_cache_kvs_map:
return
logger.info(f"[CacheController] Freeing host cache memory, {len(self.host_cache_kvs_map)} buffers.")
for name, ptr in list(self.host_cache_kvs_map.items()):
if ptr != 0:
try:
cuda_host_free(ptr)
except Exception as e:
logger.warning(f"[CacheController] Failed to free host cache {name}: {e}")
self.host_cache_kvs_map.clear()
logger.info("[CacheController] Host cache memory released.")
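# Illustrative sketch (not part of CacheController): _free_host_cache above
# frees every non-null pointer, tolerates individual failures, and then clears
# the map so a second call is a no-op. The loop below shows that behaviour
# with a stubbed free function in place of cuda_host_free; `free_all` and
# `fake_free` are hypothetical names and the pointer values are made up.

```python
from typing import Callable, Dict, List


def free_all(buffers: Dict[str, int], free_fn: Callable[[int], None]) -> List[str]:
    """Free each non-zero pointer; return names whose free call raised."""
    failed: List[str] = []
    for name, ptr in list(buffers.items()):
        if ptr == 0:
            continue  # null pointer: nothing to free
        try:
            free_fn(ptr)
        except Exception:
            failed.append(name)
    # Clearing the map makes repeated calls harmless.
    buffers.clear()
    return failed


freed: List[int] = []


def fake_free(ptr: int) -> None:
    """Stub that records freed pointers and fails for one specific value."""
    if ptr == 0xBAD:
        raise RuntimeError("simulated free failure")
    freed.append(ptr)


buffers = {"key_0": 0x1000, "value_0": 0, "key_1": 0xBAD}
failed = free_all(buffers, fake_free)
second_call = free_all(buffers, fake_free)
```

# Because the map is emptied after the first pass, calling the free routine
# again from a destructor does not double-free.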