Files
FastDeploy/custom_ops/gpu_ops/swap_cache_optimized.cu
T
kevin 7707be8384 [Feature][KVCache] Implement Cache Manager V1 with GPU + CPU Cache Support (1/n) (#7097)
* [Feature][KVCache] Support cache manager v1 architecture

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Update cache manager and related modules

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* chore: update cache_manager and related modules

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: add node to evictable set in complete_swap_to_device

When a node transitions from SWAP_TO_DEVICE to DEVICE via
complete_swap_to_device, it was not being added to the
_evictable_device set. This caused nodes with ref_count=0 to
become "orphaned" - not appearing in any evictable set despite
having cache_status=DEVICE.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* feat: update cache manager v1 and related modules

- Add new cache_manager.py with cache management functionality
- Add radix_tree.py for prefix caching
- Update block_pool.py and metadata.py
- Update request.py and resource_manager_v1.py for scheduling
- Update gpu_model_runner.py for GPU model execution

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* feat(cache): add cache controller v1 implementation

- Add CacheController class for cache management
- Update config.py with cache related configurations
- Refactor gpu_model_runner.py for improved cache handling

* feat(cache_manager): update cache manager v1

* fix(cache_manager): 修复 swap_cache H2D/D2H 方向的 block_ids 逻辑并清理 ForwardMeta

## Motivation

修复 swap_cache_optimized.cu 中 H2D 方向时 src/dst block_ids 使用错误的问题,
并清理 ForwardMeta 中已废弃的 cache_controller 字段。

## Modifications

- fix: swap_cache_optimized.cu 中根据 D2H 模板参数正确选取 src/dst block_ids,
  修复 H2D 方向 src/dst 倒置 bug(同时修复 SwapCachePerLayerImpl 和 SwapCacheAllLayersBatchImpl)
- refactor: cache_manager/v1/__init__.py 将 LayerSwapTimeoutError 导入从
  cache_controller 改为 cache_utils(正确来源)
- refactor: ForwardMeta 移除废弃的 cache_controller 字段
- refactor: gpu_model_runner.py 移除对应的 cache_controller 赋值语句
- test: 新增 tests/cache_manager/v1/test_swap_cache_ops.py 单元测试

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* feat(cache_manager): refactor cache manager v1 and optimize swap ops

## Motivation

对 cache manager v1 进行重构和优化,精简代码结构,提升可维护性。

## Modifications

- 重构 transfer_manager.py,大幅精简代码逻辑
- 优化 swap_cache_optimized.cu GPU 算子实现
- 调整 cache_manager.py、cache_controller.py 逻辑,修复 free_device_blocks 方法缺失问题
- 更新 block_pool.py、cache_utils.py、metadata.py、radix_tree.py
- 精简 gpu_model_runner.py、forward_meta.py、attention.py 中相关调用
- 更新对应单元测试(test_cache_controller、test_swap_cache_ops、test_transfer_manager)
- 调整 config.py 中相关配置项

* [KVCache][MTP] 支持 cache_manager_v1 下的 MTP KV Cache 初始化及多模态 hash

## Motivation

在 enable_cache_manager_v1 路径下,MTP(speculative decode)的 KV Cache 需要由
CacheController 统一管理,以复用 swap/transfer 能力,同时修复多模态场景下 block
hash 未携带 multimodal extra_keys 的问题。

## Modifications

- `cache_controller.py`
  - 新增 `initialize_mtp_kv_cache`:通过 CacheController 初始化 MTP KV Cache,
    并将其注册到 cache_kvs_map,使 transfer_manager 自动覆盖 MTP 层
  - `initialize_host_cache` 中的 num_layers 改为包含 MTP 额外 cache 层数,保证
    Host Cache 也为 MTP 分配足够空间
  - `_free_gpu_cache` 改名为 `free_gpu_cache`(对外可调用)

- `cache_utils.py`
  - 新增 `get_block_hash_extra_keys`:提取单个 block 内的多模态 hash 信息,
    对齐 PrefixCacheManager 的 multimodal extra_keys 逻辑
  - `get_request_block_hasher` 中在 hash_block_tokens 时携带 extra_keys,
    修复多模态场景 prefix cache 命中率不准的问题

- `spec_decode/mtp.py`
  - `update_mtp_block_num` 新增 `skip_cache_init` 参数,避免 v1 cache manager
    路径下重复初始化 MTP KV Cache

- `gpu_model_runner.py`
  - `initialize_kv_cache(v1)` 路径:在主模型 cache 初始化后,调用
    `cache_controller.initialize_mtp_kv_cache` 完成 MTP cache 创建
  - `clear_cache` / `wakeup` / `reset` 等路径:respect `enable_cache_manager_v1`
    标志,跳过重复的 proposer.initialize_kv_cache 调用

## Usage or Command

```bash
# 启动支持 MTP + cache_manager_v1 的推理服务(示例)
bash run.sh
```

* fix(cache_manager): multi-GPU fix, mm hash boundary fix, and remove batch ops

1. Fix CuPy stream/event creation for multi-GPU: wrap all stream operations
   with cp.cuda.Device(device_id) context to ensure streams/events are bound
   to the correct device, preventing cross-device errors in multi-GPU setups.

2. Remove cudaSetDevice from SwapCacheAllLayers (handled by cupy context now).

3. Remove swap_cache_all_layers_batch op: simplified the implementation by
   removing the batch upload variant; all-layer transfers now use the standard
   swap_cache_all_layers with cupy device context.

4. Fix mm hash boundary comparison in get_block_hash_extra_keys: change
   strict less-than (<) to less-than-or-equal (<=) so that multimodal items
   ending exactly at block start are correctly excluded.

5. Extract config fields to KVCacheBase: model_config, cache_config,
   quant_config, parallel_config are now set in the base class __init__ to
   avoid duplication in CacheController and CacheManager subclasses.

6. Translate metadata.py docstrings from Chinese to English for broader
   contributor accessibility.

7. Add test_cache_utils.py: comprehensive unit tests for
   get_block_hash_extra_keys covering all boundary and overlap scenarios.

8. Expand test suite: test_request.py cache fields tests, test_radix_tree.py
   backup candidate tests, test_transfer_manager.py and test_cache_manager.py
   multi-GPU and concurrent operation tests.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix List import and move write_policy normalization to CacheManager

## Motivation

修复两处问题:
1. `fastdeploy/engine/request.py` 中 `List` 未导入导致 pre-commit F821 报错
2. `write_policy` 归一化逻辑(`write_through` → `write_through_selective`)不应放在 `FDConfig`,移至 `CacheManager.__init__` 中,使其只影响 Cache Manager V1 的内部逻辑

## Modifications

- `fastdeploy/engine/request.py`: 在 `typing` 导入中补充 `List`,删除重复的 `CacheSwapMetadata` TYPE_CHECKING 导入,修复 F821/F811
- `fastdeploy/config.py`: 删除 `write_policy` 归一化逻辑
- `fastdeploy/cache_manager/v1/cache_manager.py`: 将归一化逻辑移入 `CacheManager.__init__`

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix pre-commit code style issues

## Motivation

修复 CI pre-commit 代码风格检查失败问题。

## Modifications

- `fastdeploy/engine/common_engine.py`: black 格式化
- `fastdeploy/worker/worker_process.py`: black 格式化 + isort 修复
- `fastdeploy/cache_manager/v1/storage/__init__.py`: isort 修复
- `fastdeploy/worker/gpu_worker.py`: isort 修复

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [Feature][KVCache] update cache_manager_v1 modules

## Motivation

更新 Cache Manager V1 相关模块,完善版权信息、改进模块结构与可维护性。

## Modifications

- `fastdeploy/cache_manager/v1/` 系列模块:补充版权 header,优化代码结构
- `fastdeploy/config.py`:配置项更新
- `fastdeploy/engine/sched/resource_manager_v1.py`:调度相关更新

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [Feature][KVCache] add BatchRequest.from_tasks and refactor worker task parsing

## Motivation

将 worker_process 中重复的 task 解析逻辑收敛到 BatchRequest,减少代码冗余,提升可维护性。

## Modifications

- `fastdeploy/engine/request.py`:新增 `BatchRequest.from_tasks()` 类方法,统一将 task_queue 任务分类为推理请求和控制请求
- `fastdeploy/worker/worker_process.py`:使用 `BatchRequest.from_tasks()` 替代内联解析逻辑,并修复重复的 control_reqs 处理块

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [Feature][KVCache] add NUMA affinity for host cache and skip swap cache tests

## Motivation

优化 Host cache 内存分配的 NUMA 亲和性,减少跨 NUMA 访问延迟;
同时跳过 swap cache ops 测试(当前环境不支持)。

## Modifications

- `fastdeploy/cache_manager/v1/cache_controller.py`:
  - 新增 `_get_numa_node_for_gpu()` 方法,通过 nvidia-smi 或 sysfs 获取 GPU 对应的 NUMA 节点
  - 新增 `_bind_to_closest_numa_node()` 方法,绑定当前线程到 GPU 最近的 NUMA 节点
  - 在 `initialize_host_cache()` 中调用 NUMA 绑定,优化 H2D 传输性能
- `tests/cache_manager/v1/test_swap_cache_ops.py`:跳过所有测试类(`TestSwapCacheAllLayersCorrectness`、`TestSwapCacheAllLayersPerformance`、`TestSwapCacheRandomBlockIndices`)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix unittest failures for cache_manager_v1

三个单测因接口变更或 Mock 方式问题导致失败,需修复。

- tests/distributed/chunked_moe.py:`setup_model_runner` 使用 `__new__` 跳过 `__init__`,补加 `enable_cache_manager_v1 = False`,修复 `AttributeError`
- tests/engine/test_resource_manager.py:`PrefixCacheManager` 为局部导入,`patch` 路径改为定义位置 `fastdeploy.cache_manager.prefix_cache_manager.PrefixCacheManager`
- tests/v1/test_resource_manager_v1.py:`_trigger_preempt` 第四参数已由 `list` 改为 `BatchRequest`,更新测试传参和断言

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] remove debug logging code

## Modifications

- fastdeploy/engine/request.py:删除调试用 logger 及 prompt_hashes 中的 debug 日志
- fastdeploy/worker/worker_process.py:删除 __main__ 中的调试 import 和 print 语句

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix cupy device id caching and pickle for _match_result

## Motivation

修复两个 bug:
1. `transfer_manager.py` 中每次调用 `cp.cuda.runtime.getDevice()` 存在隐患,应在初始化时缓存为实例变量,保证后续操作使用一致的设备 ID。
2. `request.py` 的 `__getstate__` 未跳过 `_match_result`,该字段包含 BlockNode 树的父子循环引用,pickle 时会触发 `RecursionError`;同时补充 `__setstate__` 确保 unpickle 后字段恢复为安全默认值。

## Modifications

- `transfer_manager.py`:初始化时调用 `cp.cuda.runtime.getDevice()` 并缓存到 `self._cupy_device_id`,后续 `with cp.cuda.Device(...)` 和日志均使用该缓存值。
- `request.py`:
  - `__getstate__` 中将 `_match_result` 加入跳过集合 `_SKIP_KEYS`,避免循环引用导致 pickle 失败。
  - 新增 `__setstate__`,unpickle 后将 `_block_hasher` 和 `_match_result` 恢复为 `None`。

## Usage or Command

* fix(test): fix unit test errors for _trigger_preempt and wakeup with MTP

- Use BatchRequest instead of list in test_trigger_preempt_records_tasks
- Add missing enable_cache_manager_v1 attr in TestSleepWakeupBehavior._make_runner

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] fix gpu_free_block_list returning wrong block IDs

## Motivation

`gpu_free_block_list` 的兼容 property 中误用了 `list(range(N))`,
将 `available_blocks()` 的返回值当作整数传给 `range()`,
导致返回 `[0, 1, ..., N-1]` 的假列表,而非真实的空闲 block ID。

## Modifications

- `cache_manager/v1/cache_manager.py`:将 `list(range(self._device_pool.available_blocks()))` 改为 `list(self._device_pool.available_blocks())`

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] 修复 gpu_free_block_list 返回 int 导致 TypeError

## Motivation

gpu_free_block_list 属性中调用 BlockPool.available_blocks(),
该方法返回 int(空闲块数量),用 list() 包装 int 会触发
TypeError: 'int' object is not iterable。

## Modifications

将 list(self._device_pool.available_blocks()) 改为
list(self._device_pool._free_blocks),直接返回空闲块索引列表。

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [KVCache][CacheManager] 适配 V1 CacheManager 的 pause/sleep/free_cache 操作

## Motivation

V1 CacheManager 引入了新的 reset_cache() 接口,pause 和 sleep 操作需要适配,
同时 free_cache 需要支持可选的 clear_storage 参数。

## Modifications

- cache_controller.py: free_cache 新增 clear_storage 参数(默认 False),
  仅当 clear_storage=True 时才调用 _clear_storage(),避免不必要的 storage 清空
- common_engine.py: pause 和 sleep 操作中,当 ENABLE_V1_KVCACHE_MANAGER 时
  使用 cache_manager.reset_cache() 替代旧的 reset() 和 pause_transfer 逻辑
- gpu_model_runner.py: sleep 时仅在非 V1 cache manager 下执行 MTP cache 清除

## Usage or Command

# 启动服务(V1 CacheManager)
python -m fastdeploy.entrypoints.openai.api_server \
  --enable-v1-kvcache-manager \
  ...

* [BugFix][KVCache] fix missing enable_cache_manager_v1 in test mocks and remove unused select_blocks_for_backup

- Remove unused `select_blocks_for_backup` method from radix_tree.py
- Fix `match_prefix` default param `skip_storage=True` and log order in cache_manager.py
- Sync test_gpu_model_runner.py with upstream/develop (add TestInsertTasksV1SplitwiseSuffix)
- Add `enable_cache_manager_v1=False` to all mock runners to fix AttributeError in CI

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [BugFix][KVCache] simplify _free_blocks in ResourceManagerV1 for non-v1 path

Remove redundant prefix_caching branch in else path; always call
recycle_gpu_blocks with full block_tables for non-cache-manager-v1 case.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [KVCache][Optimization][BugFix] fix and optimize block_pool, cache_manager, transfer_manager, request

## Motivation

修复 cache_manager v1 中若干代码质量问题,提升性能并消除潜在的类型不一致 Bug。

## Modifications

1. **block_pool.py**:`BlockPool.allocate` 将逐个 pop 循环替换为切片 + 批量 set.update,消除 Python 循环开销,O(n) → O(k)(C 层批量操作)
2. **cache_manager.py**:`match_prefix` 在 prefix caching 关闭时提前 return 前写入空 `MatchResult()`,避免调用方解引用 `_match_result=None` 崩溃
3. **transfer_manager.py**:`_build_device_layer_indices` 在 `_cache_kvs_map` 为空时也重置四个层索引列表,防止残留旧 tensor 被 swap 算子使用
4. **request.py**:`BatchRequest.append_swap_metadata` / `append_evict_metadata` 构造 `CacheSwapMetadata` 时将 `src_type`/`dst_type` 从字符串改为 `CacheLevel` 枚举,与字段类型声明一致;补充 `CacheLevel` 导入;`match_result` 属性返回类型标注修正为 `Optional[MatchResult]`
5. **resource_manager_v1.py**:`_allocate_gpu_blocks` 日志从 `INFO` 降级为 `DEBUG`,消除高频调度路径的日志噪音
6. **tests/engine/test_request.py**:同步更新 `src_type`/`dst_type` 断言为 `CacheLevel` 枚举值,补充 `CacheLevel` 导入

## Usage or Command

单元测试:
```bash
source .venv/py310/bin/activate
cd baidu/FastDeploy
python -m pytest tests/cache_manager/v1/test_cache_manager.py -v
python -m pytest tests/cache_manager/v1/test_transfer_manager.py -v
python -m pytest tests/engine/test_request.py -v
```

* [BugFix][KVCache] Fix BlockPool.allocate returns all blocks when num_blocks=0

## Motivation

当 `allocate(num_blocks=0)` 被调用时,Python 负索引陷阱导致严重错误:
`-0 == 0`,所以 `self._free_blocks[-0:]` 等价于 `self._free_blocks[0:]`,
会返回并清空整个空闲块列表,而非返回空列表。

## Modifications

在 `BlockPool.allocate` 中增加对 `num_blocks == 0` 的提前判断,直接返回 `[]`,
避免触发 Python 负索引陷阱。

## Usage or Command

```bash
# 运行相关单元测试验证修复
python -m pytest tests/cache_manager/v1/test_cache_manager.py -vv -s
```

* [KVCache][Test] add unit tests for cache_manager v1 modules

## Motivation

补全 cache_manager/v1 各模块的单测覆盖,确保核心方法有完整的测试保障。

## Modifications

新增/补充以下测试文件,全部 326 个用例通过:

- tests/cache_manager/v1/test_block_pool.py(新建)
  覆盖 BlockPool.get_metadata/set_metadata/resize、DeviceBlockPool/HostBlockPool
- tests/cache_manager/v1/test_metadata.py(新建)
  覆盖 BlockNode、RadixTreeStats、MatchResult、CacheSwapMetadata、AsyncTaskHandler
- tests/cache_manager/v1/test_cache_utils.py(补充)
  新增 hash_block_tokens、get_request_block_hasher、LayerDoneCounter 时间追踪及内部辅助方法
- tests/cache_manager/v1/test_radix_tree.py(补充)
  新增 TestCompleteSwapToDevice 专项测试类(6 个用例)
- tests/cache_manager/v1/test_cache_manager.py(补充)
  新增 offload_to_host、load_from_host、pending backup 系列、prepare_prefetch_metadata
- tests/cache_manager/v1/test_transfer_manager.py(补充)
  新增 _swap_single_layer 校验路径、sync_input/output_stream、record_input_stream_event

## Usage or Command

```bash
# 运行所有新增单测
source .venv/py310/bin/activate
python -m pytest tests/cache_manager/v1/test_block_pool.py \
  tests/cache_manager/v1/test_metadata.py \
  tests/cache_manager/v1/test_cache_utils.py \
  tests/cache_manager/v1/test_radix_tree.py \
  tests/cache_manager/v1/test_cache_manager.py \
  tests/cache_manager/v1/test_transfer_manager.py -v
# 期望结果:326 passed
```

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
2026-04-21 14:39:00 +08:00

401 lines
17 KiB
Plaintext

// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/**
* @file swap_cache_optimized.cu
* @brief Optimized KV cache swap operators using warp-level parallelism.
*
* This file implements high-performance operators for KV cache transfer
* between GPU and CPU pinned memory:
*
* swap_cache_per_layer: Single-layer transfer (sync, backward compatible)
* swap_cache_per_layer_async: Single-layer transfer (async, no cudaStreamSync)
*
* Key optimizations vs original:
* 1. Consecutive block fast path: detects consecutive block ID runs and uses
* cudaMemcpyAsync instead of warp kernel (avoids kernel launch overhead).
* 2. Async variant: swap_cache_per_layer_async omits cudaStreamSynchronize,
* enabling true async pipelining when called on a dedicated cupy stream.
* 3. Warp-level PTX: non-temporal load/store for non-consecutive blocks to
* avoid L2 cache pollution.
*/
#include "cuda_multiprocess.h"
#include "helper.h"
#include "paddle/extension.h"
#include <cstdint>
#include <vector>
// ============================================================================
// Device Functions: Warp-Level Parallel Transfer
// ============================================================================
/**
* @brief Warp-level parallel data transfer function.
*
* Uses PTX inline assembly for optimized memory access:
* - ld.global.nc.b64: Non-cacheable load (avoids L2 cache pollution)
* - st.global.cg.b64: Cache-globing store (optimizes write performance)
*
* @param lane_id Thread lane ID within the warp (0-WARP_SIZE-1)
* @param src_addr Source memory address
* @param dst_addr Destination memory address
* @param item_size_bytes Size of the item in bytes (must be 8-byte aligned)
*/
__device__ __forceinline__ void transfer_item_warp(int32_t lane_id,
const void* src_addr,
void* dst_addr,
int64_t item_size_bytes) {
const uint64_t* __restrict__ src = static_cast<const uint64_t*>(src_addr);
uint64_t* __restrict__ dst = static_cast<uint64_t*>(dst_addr);
const int total_chunks = item_size_bytes / sizeof(uint64_t);
#pragma unroll
for (int j = lane_id; j < total_chunks; j += WARP_SIZE) {
uint64_t tmp;
#ifdef PADDLE_WITH_HIP
// ROCm/HIP path using built-in nontemporal operations
tmp = __builtin_nontemporal_load(src + j);
__builtin_nontemporal_store(tmp, dst + j);
#else
// NVIDIA CUDA path using PTX inline assembly
asm volatile("ld.global.nc.b64 %0,[%1];"
: "=l"(tmp)
: "l"(src + j)
: "memory");
asm volatile("st.global.cg.b64 [%0],%1;" ::"l"(dst + j), "l"(tmp)
: "memory");
#endif
}
}
// ============================================================================
// Kernels
// ============================================================================
/**
* @brief CUDA kernel for single-layer KV cache transfer (non-consecutive path).
*
* Each warp processes one block using warp-level parallel PTX loads/stores.
* Used only when block IDs are non-consecutive; consecutive runs are handled
* by cudaMemcpyAsync in the host-side fast path.
*
* @tparam D2H true = Device->Host (evict), false = Host->Device (load)
*/
template <bool D2H>
__global__ void swap_cache_per_layer_kernel(
const void* __restrict__ src_ptr,
void* __restrict__ dst_ptr,
const int64_t* __restrict__ src_block_ids,
const int64_t* __restrict__ dst_block_ids,
int64_t num_blocks,
int64_t item_size_bytes) {
int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
int32_t lane_id = tid % WARP_SIZE;
int32_t warp_id = tid / WARP_SIZE;
if (warp_id >= num_blocks) return;
int64_t src_block_id = src_block_ids[warp_id];
int64_t dst_block_id = dst_block_ids[warp_id];
const char* src_now =
static_cast<const char*>(src_ptr) + src_block_id * item_size_bytes;
char* dst_now = static_cast<char*>(dst_ptr) + dst_block_id * item_size_bytes;
transfer_item_warp(lane_id, src_now, dst_now, item_size_bytes);
}
// ============================================================================
// Helper: Consecutive Block Fast Path
// ============================================================================
/**
* @brief Transfer a single layer using consecutive-block detection.
*
* Scans src/dst block ID pairs for consecutive runs. For each run, issues
* a single cudaMemcpyAsync (like swap_cache_all_layers). Non-consecutive
* blocks are batched and handled by the warp kernel.
*
* @tparam D2H true = Device->Host, false = Host->Device
* @param src_ptr Source base pointer (GPU or CPU depending on D2H)
* @param dst_ptr Destination base pointer
* @param src_block_ids Host vector of source block IDs
* @param dst_block_ids Host vector of destination block IDs
* @param num_blocks Number of blocks to transfer
* @param item_size_bytes Bytes per block
* @param stream CUDA stream
*/
template <bool D2H>
void TransferSingleLayerWithFastPath(const void* src_ptr,
void* dst_ptr,
const std::vector<int64_t>& src_block_ids,
const std::vector<int64_t>& dst_block_ids,
int64_t num_blocks,
int64_t item_size_bytes,
cudaStream_t stream) {
// --- Pass 1: handle consecutive runs with cudaMemcpyAsync ---
// Collect indices of non-consecutive blocks for the kernel fallback.
std::vector<int64_t> nc_src, nc_dst;
const cudaMemcpyKind kind =
D2H ? cudaMemcpyDeviceToHost : cudaMemcpyHostToDevice;
int64_t run_start = 0;
for (int64_t i = 1; i <= num_blocks; ++i) {
bool end_of_run = (i == num_blocks) ||
(src_block_ids[i] != src_block_ids[i - 1] + 1) ||
(dst_block_ids[i] != dst_block_ids[i - 1] + 1);
if (!end_of_run) continue;
int64_t run_len = i - run_start;
if (run_len > 1) {
// Consecutive run: merge into a single cudaMemcpyAsync
const char* src_run = static_cast<const char*>(src_ptr) +
src_block_ids[run_start] * item_size_bytes;
char* dst_run = static_cast<char*>(dst_ptr) +
dst_block_ids[run_start] * item_size_bytes;
checkCudaErrors(cudaMemcpyAsync(
dst_run, src_run, run_len * item_size_bytes, kind, stream));
} else {
// Single non-consecutive block: defer to warp kernel
nc_src.push_back(src_block_ids[run_start]);
nc_dst.push_back(dst_block_ids[run_start]);
}
run_start = i;
}
// --- Pass 2: warp kernel for remaining non-consecutive blocks ---
if (!nc_src.empty()) {
int64_t nc_count = static_cast<int64_t>(nc_src.size());
int64_t *d_src, *d_dst;
checkCudaErrors(
cudaMallocAsync(&d_src, nc_count * sizeof(int64_t), stream));
checkCudaErrors(
cudaMallocAsync(&d_dst, nc_count * sizeof(int64_t), stream));
checkCudaErrors(cudaMemcpyAsync(d_src,
nc_src.data(),
nc_count * sizeof(int64_t),
cudaMemcpyHostToDevice,
stream));
checkCudaErrors(cudaMemcpyAsync(d_dst,
nc_dst.data(),
nc_count * sizeof(int64_t),
cudaMemcpyHostToDevice,
stream));
constexpr int kWarpsPerBlock = 4;
const int threads_per_block = kWarpsPerBlock * WARP_SIZE;
const int grid =
(static_cast<int>(nc_count) + kWarpsPerBlock - 1) / kWarpsPerBlock;
swap_cache_per_layer_kernel<D2H><<<grid, threads_per_block, 0, stream>>>(
src_ptr, dst_ptr, d_src, d_dst, nc_count, item_size_bytes);
checkCudaErrors(cudaFreeAsync(d_src, stream));
checkCudaErrors(cudaFreeAsync(d_dst, stream));
}
}
// ============================================================================
// Implementation: Single Layer
// ============================================================================
/**
* @brief Core implementation for single-layer KV cache transfer.
*
* @param do_sync If true, calls cudaStreamSynchronize at end (sync op).
* Set to false for the async variant.
*/
template <paddle::DataType D, bool D2H>
void SwapCachePerLayerImpl(const paddle::Tensor& cache_gpu,
int64_t cache_cpu_ptr,
int64_t max_block_num_cpu,
const std::vector<int64_t>& swap_block_ids_gpu,
const std::vector<int64_t>& swap_block_ids_cpu,
cudaStream_t stream,
bool do_sync) {
typedef typename PDTraits<D>::DataType DataType_;
typedef typename PDTraits<D>::data_t data_t;
auto cache_shape = cache_gpu.shape();
const int64_t max_block_num_gpu = cache_shape[0];
const int64_t num_heads = cache_shape[1];
const int64_t block_size = cache_shape[2];
const int64_t head_dim = cache_shape.size() == 4 ? cache_shape[3] : 1;
const int64_t item_size_bytes =
num_heads * block_size * head_dim * sizeof(DataType_);
const int64_t num_blocks = swap_block_ids_gpu.size();
if (num_blocks == 0) return;
// Validate block IDs
for (size_t i = 0; i < swap_block_ids_gpu.size(); ++i) {
if (swap_block_ids_gpu[i] < 0 ||
swap_block_ids_gpu[i] >= max_block_num_gpu) {
PD_THROW("Invalid swap_block_ids_gpu at index " + std::to_string(i) +
": " + std::to_string(swap_block_ids_gpu[i]) +
" out of range [0, " + std::to_string(max_block_num_gpu) + ")");
}
if (swap_block_ids_cpu[i] < 0 ||
swap_block_ids_cpu[i] >= max_block_num_cpu) {
PD_THROW("Invalid swap_block_ids_cpu at index " + std::to_string(i) +
": " + std::to_string(swap_block_ids_cpu[i]) +
" out of range [0, " + std::to_string(max_block_num_cpu) + ")");
}
}
// D2H: src=GPU, dst=CPU; H2D: src=CPU, dst=GPU
const auto& src_block_ids = D2H ? swap_block_ids_gpu : swap_block_ids_cpu;
const auto& dst_block_ids = D2H ? swap_block_ids_cpu : swap_block_ids_gpu;
const void* src_ptr;
void* dst_ptr;
if (D2H) {
src_ptr = cache_gpu.data<data_t>();
dst_ptr = reinterpret_cast<void*>(cache_cpu_ptr);
} else {
src_ptr = reinterpret_cast<const void*>(cache_cpu_ptr);
dst_ptr = const_cast<data_t*>(cache_gpu.data<data_t>());
}
TransferSingleLayerWithFastPath<D2H>(src_ptr,
dst_ptr,
src_block_ids,
dst_block_ids,
num_blocks,
item_size_bytes,
stream);
if (do_sync) {
checkCudaErrors(cudaStreamSynchronize(stream));
}
}
// ============================================================================
// Operator Registration
// ============================================================================
// Operator Entry Points
// ============================================================================
// Helper macro to dispatch dtype and direction for SwapCachePerLayerImpl
#define DISPATCH_PER_LAYER(DTYPE, MODE, DO_SYNC, ...) \
switch (DTYPE) { \
case paddle::DataType::BFLOAT16: \
if ((MODE) == 0) \
SwapCachePerLayerImpl<paddle::DataType::BFLOAT16, true>(__VA_ARGS__, \
DO_SYNC); \
else \
SwapCachePerLayerImpl<paddle::DataType::BFLOAT16, false>(__VA_ARGS__, \
DO_SYNC); \
break; \
case paddle::DataType::FLOAT16: \
if ((MODE) == 0) \
SwapCachePerLayerImpl<paddle::DataType::FLOAT16, true>(__VA_ARGS__, \
DO_SYNC); \
else \
SwapCachePerLayerImpl<paddle::DataType::FLOAT16, false>(__VA_ARGS__, \
DO_SYNC); \
break; \
case paddle::DataType::UINT8: \
if ((MODE) == 0) \
SwapCachePerLayerImpl<paddle::DataType::UINT8, true>(__VA_ARGS__, \
DO_SYNC); \
else \
SwapCachePerLayerImpl<paddle::DataType::UINT8, false>(__VA_ARGS__, \
DO_SYNC); \
break; \
default: \
PD_THROW("Unsupported data type for swap_cache_per_layer."); \
}
/**
* @brief Single-layer KV cache swap (synchronous, backward compatible).
*/
void SwapCachePerLayer(const paddle::Tensor& cache_gpu,
int64_t cache_cpu_ptr,
int64_t max_block_num_cpu,
const std::vector<int64_t>& swap_block_ids_gpu,
const std::vector<int64_t>& swap_block_ids_cpu,
int rank,
int mode) {
auto stream = cache_gpu.stream();
DISPATCH_PER_LAYER(cache_gpu.dtype(),
mode,
/*do_sync=*/true,
cache_gpu,
cache_cpu_ptr,
max_block_num_cpu,
swap_block_ids_gpu,
swap_block_ids_cpu,
stream);
}
/**
* @brief Single-layer KV cache swap (async, no cudaStreamSynchronize).
*
* Designed for use inside a cupy stream context. Completion is tracked
* by the caller via CUDA events (record_input_stream_event).
*/
void SwapCachePerLayerAsync(const paddle::Tensor& cache_gpu,
int64_t cache_cpu_ptr,
int64_t max_block_num_cpu,
const std::vector<int64_t>& swap_block_ids_gpu,
const std::vector<int64_t>& swap_block_ids_cpu,
int rank,
int mode) {
auto stream = cache_gpu.stream();
DISPATCH_PER_LAYER(cache_gpu.dtype(),
mode,
/*do_sync=*/false,
cache_gpu,
cache_cpu_ptr,
max_block_num_cpu,
swap_block_ids_gpu,
swap_block_ids_cpu,
stream);
}
// ============================================================================
// Operator Registration
// ============================================================================
PD_BUILD_STATIC_OP(swap_cache_per_layer)
.Inputs({"cache_gpu"})
.Attrs({
"cache_cpu_ptr: int64_t",
"max_block_num_cpu: int64_t",
"swap_block_ids_gpu: std::vector<int64_t>",
"swap_block_ids_cpu: std::vector<int64_t>",
"rank: int",
"mode: int",
})
.Outputs({"cache_dst_out"})
.SetInplaceMap({{"cache_gpu", "cache_dst_out"}})
.SetKernelFn(PD_KERNEL(SwapCachePerLayer));
PD_BUILD_STATIC_OP(swap_cache_per_layer_async)
.Inputs({"cache_gpu"})
.Attrs({
"cache_cpu_ptr: int64_t",
"max_block_num_cpu: int64_t",
"swap_block_ids_gpu: std::vector<int64_t>",
"swap_block_ids_cpu: std::vector<int64_t>",
"rank: int",
"mode: int",
})
.Outputs({"cache_dst_out"})
.SetInplaceMap({{"cache_gpu", "cache_dst_out"}})
.SetKernelFn(PD_KERNEL(SwapCachePerLayerAsync));