[Feature] Support KV Cache Storage (#5571)

* Support Mooncake Store

* up

* up

* add op

* fix conflict

* fix error

* up for comments

* avoid thread lock

* up

* fix unittest

* fix unittest

* remove debug info

* consider tp_size > 1

* add default rdma_nics

* add utils

* up

* fix error

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
Juncai
2025-12-25 16:30:35 +08:00
committed by GitHub
parent be3be4913a
commit 412867fd99
27 changed files with 1672 additions and 195 deletions
@@ -22,6 +22,7 @@ import queue
import threading
import time
import traceback
from typing import List
import numpy as np
import paddle
@@ -36,8 +37,10 @@ from fastdeploy.cache_manager.ops import (
set_device,
share_external_data_,
swap_cache_all_layers,
swap_cache_layout,
unset_data_ipc,
)
from fastdeploy.cache_manager.transfer_factory import MooncakeStore
from fastdeploy.config import SpeculativeConfig
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
from fastdeploy.platforms import current_platform
@@ -55,8 +58,9 @@ def parse_args():
default="mixed",
help="splitwise role, can be decode, prefill or mixed",
)
parser.add_argument("--rank", type=int, default=0, help="current rank")
parser.add_argument("--rank", type=int, default=0, help="local tp rank")
parser.add_argument("--device_id", type=int, default=0, help="device id")
parser.add_argument("--max_model_len", type=int, default=32768, help="max model length")
parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
parser.add_argument(
@@ -101,6 +105,20 @@ def parse_args():
help="speculative config",
)
parser.add_argument("--create_cache_tensor", action="store_true")
parser.add_argument(
"--kvcache_storage_backend",
type=str,
default=None,
choices=["mooncake", "none"],
help="The storage backend for kvcache storage. If not set, storage backend is disabled.",
)
parser.add_argument(
"--write_policy",
type=str,
choices=["write_through"],
default="write_through",
help="KVCache write policy",
)
args = parser.parse_args()
return args
@@ -135,6 +153,9 @@ class CacheTransferManager:
self.swap_to_cpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.swap_to_gpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.read_storage_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.write_back_storage_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.timeout_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)
self.transfer_task_queue = queue.Queue() # 用来接收传输任务
self.tansfer_done_queue = queue.Queue() # 用来告知任务执行完毕
self.n_ranks = args.mp_num
@@ -194,8 +215,54 @@ class CacheTransferManager:
suffix=args.engine_worker_queue_port,
create=False,
)
if args.kvcache_storage_backend is None or args.kvcache_storage_backend == "none":
self.storage_backend = None
elif args.kvcache_storage_backend == "mooncake":
logger.info("Start initialize mooncake store...")
self.storage_backend = MooncakeStore(tp_rank=self.rank)
self._init_storage_buffer(args)
logger.info("Initialized mooncake store successfully")
else:
raise NotImplementedError(f"Unsupported storage backend: {args.kvcache_storage_backend}")
if args.write_policy not in ["write_through"]:
raise ValueError(f"Invalid write policy: {args.write_policy}")
self.write_policy = args.write_policy
threading.Thread(target=self.clear_or_update_caches, args=[args], daemon=True).start()
def _init_storage_buffer(self, args):
    """
    Initialize pinned host (CPU) staging buffers used to move KV cache
    blocks between GPU memory and the storage backend.

    cache layout:  layer_num * [block_num, head_num, block_size, head_dim]
    buffer layout: [block_num, layer_num, head_num, block_size, head_dim]

    Two independent buffers are allocated — one for reads from storage and
    one for write-backs — each split in half: the first half stages key
    blocks and the second half value blocks. Both are registered with the
    storage backend so it can transfer directly on them.

    Args:
        args: parsed command-line arguments; uses ``num_layers`` and
            ``max_model_len``.
    """
    layer_num = args.num_layers + self.num_extra_layers
    # key_cache_shape is [block_num, head_num, block_size, head_dim]
    head_num = self.key_cache_shape[1]
    block_size = self.key_cache_shape[2]
    head_dim = self.key_cache_shape[3]
    # Enough blocks to hold the longest possible request (ceil division).
    block_num = (args.max_model_len + block_size - 1) // block_size
    logger.info(
        f"Creating cache buffer for storage with shape: "
        f"[{block_num}, {layer_num}, {head_num}, {block_size}, {head_dim}]"
    )
    self.cache_bytes = self._get_cache_bytes(self.cache_dtype)
    # Bytes occupied by one block across all layers (per key or per value).
    self.storage_buffer_stride_bytes = layer_num * head_num * block_size * head_dim * self.cache_bytes
    total_bytes = block_num * self.storage_buffer_stride_bytes * 2  # key and value
    logger.info(f"Creating cpu buffer cache for all layers: {total_bytes / 1024 ** 3:.2f}GB")
    self.storage_key_read_buffer, self.storage_value_read_buffer = self._alloc_storage_buffer(total_bytes)
    self.storage_key_write_buffer, self.storage_value_write_buffer = self._alloc_storage_buffer(total_bytes)

def _alloc_storage_buffer(self, total_bytes):
    """Allocate one pinned host buffer of *total_bytes*, register it with the
    storage backend, and return (key_half_ptr, value_half_ptr)."""
    buffer = cuda_host_alloc(total_bytes)
    self.storage_backend.register_buffer(buffer, total_bytes)
    # Key blocks live in the first half, value blocks in the second half.
    return buffer, buffer + total_bytes // 2
def _init_gpu_cache(self, args):
try:
@@ -319,12 +386,7 @@ class CacheTransferManager:
value_cache_size = self.value_cache_shape[1] * self.value_cache_shape[2] * self.value_cache_shape[3]
else:
value_cache_size = 0
if args.cache_dtype == "bfloat16":
cache_bytes = 2
elif args.cache_dtype == "uint8" or args.cache_dtype == "block_wise_fp8":
cache_bytes = 1
else:
raise ValueError(f"Unsupported cache dtype: {args.cache_dtype}")
cache_bytes = self._get_cache_bytes(self.cache_dtype)
key_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * key_cache_size
value_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * value_cache_size
if args.cache_dtype == "block_wise_fp8":
@@ -367,6 +429,222 @@ class CacheTransferManager:
logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ swap space (cpu cache) is ready!")
self.swap_space_ready_signal.value[self.rank] = 1
def _get_cache_bytes(self, cache_dtype):
if cache_dtype == "bfloat16":
cache_bytes = 2
elif cache_dtype in ["uint8", "block_wise_fp8"]:
cache_bytes = 1
else:
raise ValueError(f"Unsupported cache dtype: {cache_dtype}")
return cache_bytes
def _storage_exist_block_num(self, k_keys: List[str], v_keys: List[str]):
"""
Given the k_keys and v_keys, get the valid blocks number that
can be prefetched from storage backend.
"""
assert len(k_keys) == len(v_keys), "k_keys and v_keys must have the same length."
result = self.storage_backend.exists(k_keys + v_keys)
# only consider the case when both key and value exist
num = 0
for k, v in zip(k_keys, v_keys):
if result[k] and result[v]:
num += 1
return num
def _run_read_storage(self, k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids):
    """
    Fetch key/value cache blocks from the storage backend into the pinned
    read buffers, then copy the successfully fetched blocks to GPU memory.

    Args:
        k_cache_keys: storage keys for the key-cache blocks.
        v_cache_keys: storage keys for the value-cache blocks (same length).
        gpu_block_ids: destination GPU block ids, aligned with the keys.
        cpu_block_ids: staging slot indices into the pinned read buffers.

    Returns:
        The GPU block ids that were actually filled from storage.

    Raises:
        Re-raises any exception after logging it.
    """
    try:
        logger.debug(
            f"_run_read_storage, key_hash_keys: {k_cache_keys}, "
            f"value_hash_keys: {v_cache_keys}, gpu_block_ids: {gpu_block_ids}"
        )
        block_num = len(gpu_block_ids)
        # Keys are ordered [all key-cache keys, all value-cache keys]; the
        # pointer and size lists below follow the same ordering.
        keys = k_cache_keys + v_cache_keys
        k_cache_ptrs = [self.storage_key_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids]
        v_cache_ptrs = [
            self.storage_value_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
        ]
        kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
        kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2  # key and value
        result = self.storage_backend.batch_get(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes)
        # batch_get returns one entry per key; a value > 0 is treated as a
        # successful fetch for that block half.
        k_result, v_result = result[:block_num], result[block_num:]
        success_block_num = 0
        for k, v in zip(k_result, v_result):
            if k > 0 and v > 0:
                success_block_num += 1
        logger.debug(f"_run_read_storage, success_block_num: {success_block_num}")
        # NOTE(review): taking a prefix assumes successes always form a
        # leading run of the block list — confirm failures cannot be sparse.
        valid_gpu_block_ids = gpu_block_ids[:success_block_num]
        valid_cpu_block_ids = cpu_block_ids[:success_block_num]
        mode = 1  # cpu ==> gpu
        swap_cache_layout(
            self.gpu_cache_k_tensors,
            self.storage_key_read_buffer,
            self.key_cache_shape,
            valid_gpu_block_ids,
            valid_cpu_block_ids,
            self.device,
            mode,
        )
        swap_cache_layout(
            self.gpu_cache_v_tensors,
            self.storage_value_read_buffer,
            self.value_cache_shape,
            valid_gpu_block_ids,
            valid_cpu_block_ids,
            self.device,
            mode,
        )
        return valid_gpu_block_ids
    except Exception as e:
        logger.error(
            f"[rank {self.rank}/{self.n_ranks}] An error occurred in _run_read_storage: "
            f"error:{e}, {traceback.format_exc()}"
        )
        raise
def read_storage_task(self, task_id, keys, gpu_block_ids, timeout=0.1):
    """Read cache from the storage backend to the GPU memory.

    Args:
        task_id: identifier of the transfer task, echoed back in the result.
        keys: block hash keys; per-rank key/value storage keys are derived
            from them.
        gpu_block_ids: destination GPU block ids, aligned with *keys*.
        timeout: currently only logged — see TODO below.

    Side effects:
        Waits on and resets the storage-to-GPU barrier, then publishes a
        (STORAGE2GPU, task_id, keys, valid_gpu_block_ids) done signal.
        Never raises; errors are logged instead.
    """
    try:
        logger.debug(
            f"read_storage_task, task id: {task_id}, hash_keys: {keys}, "
            f"gpu_block_ids: {gpu_block_ids}, timeout: {timeout}"
        )
        # Storage keys are sharded per tensor-parallel rank.
        k_cache_keys = [f"{key}_key_{self.rank}" for key in keys]
        v_cache_keys = [f"{key}_value_{self.rank}" for key in keys]
        match_block_num = self._storage_exist_block_num(k_cache_keys, v_cache_keys)
        logger.debug(f"read_storage_task, match {match_block_num} blocks from storage for task id: {task_id}")
        # NOTE(review): truncating to a prefix assumes matched blocks are
        # contiguous from the start, while _storage_exist_block_num counts
        # matches anywhere in the list — confirm this invariant holds.
        k_cache_keys = k_cache_keys[:match_block_num]
        v_cache_keys = v_cache_keys[:match_block_num]
        gpu_block_ids = gpu_block_ids[:match_block_num]
        # Staging slots in the pinned read buffer are used sequentially.
        cpu_block_ids = [i for i in range(match_block_num)]
        valid_gpu_block_ids = []
        if match_block_num > 0:
            # TODO: support timeout with actual block count
            try:
                valid_gpu_block_ids = self._run_read_storage(
                    k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids
                )
                logger.info(
                    f"read_storage_task, finish loading {match_block_num} blocks from storage for task {task_id}."
                )
            except Exception as e:
                # Best-effort: report zero loaded blocks rather than failing
                # the whole task.
                logger.error(f"[rank {self.rank}/{self.n_ranks}] An error occurred: {task_id} {e}")
                valid_gpu_block_ids = []
        result = (CacheStatus.STORAGE2GPU, task_id, keys, valid_gpu_block_ids)
        # Synchronize all ranks before resetting the barrier and publishing.
        self.cache_task_queue.swap_storage_to_gpu_barrier.wait()
        self.cache_task_queue.swap_storage_to_gpu_barrier.reset()
        self.cache_task_queue.put_transfer_done_signal(result)
        logger.debug(f"read_storage_task: put_transfer_done_signal {result}")
        logger.info(
            f"read_storage_task: put_transfer_done_signal for transfer_task_id {task_id}, "
            f"valid block num {len(valid_gpu_block_ids)}"
        )
    except Exception as e:
        logger.error(
            f"[rank {self.rank}/{self.n_ranks}] An error occurred in read_storage_task: "
            f"task_id: {task_id}, error:{e}, {traceback.format_exc()}"
        )
def _run_write_back_storage(self, k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids):
    """Copy GPU cache blocks into the pinned write buffers, then push them
    to the storage backend via batch_set.

    Errors are logged and swallowed; the caller decides how to report a
    failed write-back.
    """
    try:
        logger.debug(
            f"_run_write_back_storage, k_cache_keys: {k_cache_keys}, v_cache_keys: {v_cache_keys}, "
            f"gpu_block_ids: {gpu_block_ids}"
        )
        cache_shape = [self.key_cache_shape[axis] for axis in range(4)]
        direction = 0  # gpu ==> cpu
        # NOTE(review): the key-cache shape is reused for the value swap —
        # confirm key and value cache shapes are always identical here.
        for gpu_tensors, host_buffer in (
            (self.gpu_cache_k_tensors, self.storage_key_write_buffer),
            (self.gpu_cache_v_tensors, self.storage_value_write_buffer),
        ):
            swap_cache_layout(
                gpu_tensors,
                host_buffer,
                cache_shape,
                gpu_block_ids,
                cpu_block_ids,
                self.device,
                direction,
            )
        num_blocks = len(gpu_block_ids)
        stride = self.storage_buffer_stride_bytes
        key_ptrs = [self.storage_key_write_buffer + slot * stride for slot in cpu_block_ids]
        value_ptrs = [self.storage_value_write_buffer + slot * stride for slot in cpu_block_ids]
        self.storage_backend.batch_set(
            k_cache_keys + v_cache_keys,
            target_locations=key_ptrs + value_ptrs,
            target_sizes=[stride] * num_blocks * 2,  # key and value
        )
    except Exception as e:
        logger.error(
            f"[rank {self.rank}/{self.n_ranks}] An error occurred in _run_write_back_storage: "
            f"error:{e}, {traceback.format_exc()}"
        )
def write_back_storage_task(self, task_id, keys, gpu_block_ids, timeout=0.1):
    """
    Write cache to the storage backend from the GPU memory.

    Only blocks whose keys are not already present in storage are written.
    Publishes a (GPU2STORAGE, task_id, keys, gpu_block_ids) done signal when
    finished; never raises — errors are logged.

    Args:
        task_id: identifier of the transfer task.
        keys: block hash keys; per-rank key/value storage keys are derived.
        gpu_block_ids: source GPU block ids, aligned with *keys*.
        timeout: currently only logged — see TODO below.
    """
    try:
        logger.debug(
            f"write cache to storage, keys: {keys}, gpu_block_ids: {gpu_block_ids}, "
            f"task_id: {task_id}, timeout: {timeout}"
        )
        # Storage keys are sharded per tensor-parallel rank.
        k_cache_keys = [f"{key}_key_{self.rank}" for key in keys]
        v_cache_keys = [f"{key}_value_{self.rank}" for key in keys]
        match_block_num = self._storage_exist_block_num(k_cache_keys, v_cache_keys)
        # NOTE(review): dropping the first match_block_num entries assumes
        # already-stored blocks always form a leading run, while
        # _storage_exist_block_num counts matches anywhere — confirm.
        k_cache_keys = k_cache_keys[match_block_num:]
        v_cache_keys = v_cache_keys[match_block_num:]
        gpu_block_ids = gpu_block_ids[match_block_num:]
        # Staging slots in the pinned write buffer are used sequentially.
        cpu_block_ids = [i for i in range(len(gpu_block_ids))]
        if len(k_cache_keys) == 0:
            logger.info(f"No uncached keys found for task {task_id}")
            gpu_block_ids = []
        else:
            try:
                # TODO: support timeout with actual block count
                self._run_write_back_storage(k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids)
            except Exception as e:
                # Best-effort: report zero written blocks instead of failing.
                logger.error(f"Error in write back storage task: {e}")
                gpu_block_ids = []
        result = (CacheStatus.GPU2STORAGE, task_id, keys, gpu_block_ids)
        self.cache_task_queue.swap_to_storage_barrier.wait()
        if self.rank == 0:  # only rank 0 performs the barrier reset
            self.cache_task_queue.swap_to_storage_barrier.reset()
        self.cache_task_queue.put_transfer_done_signal(result)  # signal that the transfer is done
        logger.debug(f"write_back_storage_task: put_transfer_done_signal {result}")
        logger.info(f"write_back_storage_task: put_transfer_done_signal for transfer_task_id {task_id}")
    except Exception as e:
        logger.error(
            f"[rank {self.rank}/{self.n_ranks}] An error occurred in write_back_storage_task: "
            f"error:{e}, {traceback.format_exc()}"
        )
def _do_swap_to_cpu_task(
self,
swap_node_ids,
@@ -459,14 +737,9 @@ class CacheTransferManager:
logger.debug(f"transfer data: get_transfer_task {data}")
if read_finish:
self.cache_task_broadcast_signal.value[0] = 0
(
swap_node_ids,
gpu_block_id,
cpu_block_id,
event_type,
transfer_task_id,
) = data
event_type, transfer_task_id = data[0], data[1]
if event_type.value == CacheStatus.SWAP2CPU.value:
swap_node_ids, gpu_block_id, cpu_block_id = data[2:]
self.swap_to_cpu_thread_pool.submit(
self._do_swap_to_cpu_task,
swap_node_ids,
@@ -475,7 +748,8 @@ class CacheTransferManager:
event_type,
transfer_task_id,
)
else:
elif event_type.value == CacheStatus.SWAP2GPU.value:
swap_node_ids, gpu_block_id, cpu_block_id = data[2:]
self.swap_to_gpu_thread_pool.submit(
self._do_swap_to_gpu_task,
swap_node_ids,
@@ -484,6 +758,24 @@ class CacheTransferManager:
event_type,
transfer_task_id,
)
elif event_type.value == CacheStatus.STORAGE2GPU.value:
hash_keys, gpu_block_ids, timeout = data[2:]
self.read_storage_thread_pool.submit(
self.read_storage_task,
transfer_task_id,
hash_keys,
gpu_block_ids,
timeout,
)
elif event_type.value == CacheStatus.GPU2STORAGE.value:
hash_keys, gpu_block_ids, timeout = data[2:]
self.write_back_storage_thread_pool.submit(
self.write_back_storage_task,
transfer_task_id,
hash_keys,
gpu_block_ids,
timeout,
)
else:
if self.n_ranks > 1:
self.cache_task_queue.barrier2.wait()
@@ -635,11 +927,11 @@ class CacheTransferManager:
+ f"transfer {len(gpu_block_ids)} blocks done elapsed_time {elasped_time:.4f}"
)
return (
event_type,
transfer_task_id,
swap_node_ids,
task_gpu_block_id,
task_cpu_block_id,
event_type,
transfer_task_id,
)
def clear_or_update_caches(self, args):
@@ -738,9 +1030,7 @@ def main():
"""
启动cache manager
"""
cache_manager = CacheTransferManager(args)
cache_manager.do_data_transfer()
@@ -749,5 +1039,10 @@ if __name__ == "__main__":
args = parse_args()
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_tprank{args.rank}.log")
logger.info(f"args: {vars(args)}")
set_device(args.device_id)
main()
try:
main()
except Exception as e:
logger.error(f"cache_transfer_manager failed with error: {e}, traceback: {traceback.format_exc()}")
raise