mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] Support KV Cache Storage (#5571)
* Support Mooncake Store * up * up * add op * fix conflict * fix error * up for comments * avoid thread lock * up * fix unittest * fix unittest * remove debug info * consider tp_size > 1 * add default rdma_nics * add utils * up * fix error --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
@@ -22,6 +22,7 @@ import queue
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
@@ -36,8 +37,10 @@ from fastdeploy.cache_manager.ops import (
|
||||
set_device,
|
||||
share_external_data_,
|
||||
swap_cache_all_layers,
|
||||
swap_cache_layout,
|
||||
unset_data_ipc,
|
||||
)
|
||||
from fastdeploy.cache_manager.transfer_factory import MooncakeStore
|
||||
from fastdeploy.config import SpeculativeConfig
|
||||
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
|
||||
from fastdeploy.platforms import current_platform
|
||||
@@ -55,8 +58,9 @@ def parse_args():
|
||||
default="mixed",
|
||||
help="splitwise role, can be decode, prefill or mixed",
|
||||
)
|
||||
parser.add_argument("--rank", type=int, default=0, help="current rank")
|
||||
parser.add_argument("--rank", type=int, default=0, help="local tp rank")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="device id")
|
||||
parser.add_argument("--max_model_len", type=int, default=32768, help="max model length")
|
||||
parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
|
||||
parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
|
||||
parser.add_argument(
|
||||
@@ -101,6 +105,20 @@ def parse_args():
|
||||
help="speculative config",
|
||||
)
|
||||
parser.add_argument("--create_cache_tensor", action="store_true")
|
||||
parser.add_argument(
|
||||
"--kvcache_storage_backend",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["mooncake", "none"],
|
||||
help="The storage backend for kvcache storage. If not set, storage backend is disabled.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--write_policy",
|
||||
type=str,
|
||||
choices=["write_through"],
|
||||
default="write_through",
|
||||
help="KVCache write policy",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
@@ -135,6 +153,9 @@ class CacheTransferManager:
|
||||
|
||||
self.swap_to_cpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
self.swap_to_gpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
self.read_storage_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
self.write_back_storage_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
self.timeout_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)
|
||||
self.transfer_task_queue = queue.Queue() # 用来接收传输任务
|
||||
self.tansfer_done_queue = queue.Queue() # 用来告知任务执行完毕
|
||||
self.n_ranks = args.mp_num
|
||||
@@ -194,8 +215,54 @@ class CacheTransferManager:
|
||||
suffix=args.engine_worker_queue_port,
|
||||
create=False,
|
||||
)
|
||||
|
||||
if args.kvcache_storage_backend is None or args.kvcache_storage_backend == "none":
|
||||
self.storage_backend = None
|
||||
elif args.kvcache_storage_backend == "mooncake":
|
||||
logger.info("Start initialize mooncake store...")
|
||||
self.storage_backend = MooncakeStore(tp_rank=self.rank)
|
||||
self._init_storage_buffer(args)
|
||||
logger.info("Initialized mooncake store successfully")
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported storage backend: {args.kvcache_storage_backend}")
|
||||
|
||||
if args.write_policy not in ["write_through"]:
|
||||
raise ValueError(f"Invalid write policy: {args.write_policy}")
|
||||
self.write_policy = args.write_policy
|
||||
|
||||
threading.Thread(target=self.clear_or_update_caches, args=[args], daemon=True).start()
|
||||
|
||||
def _init_storage_buffer(self, args):
|
||||
"""
|
||||
Initialize pinned memory buffer that can hold the cache for a longest request
|
||||
cache layout: layer_num * [block_num, head_num, block_size, head_dim]
|
||||
buffer layout: [block_num, layer_num, head_num, block_size, head_dim]
|
||||
"""
|
||||
layer_num = args.num_layers + self.num_extra_layers
|
||||
head_num = self.key_cache_shape[1]
|
||||
block_size = self.key_cache_shape[2]
|
||||
head_dim = self.key_cache_shape[3]
|
||||
block_num = (args.max_model_len + block_size - 1) // block_size
|
||||
logger.info(
|
||||
f"Creating cache buffer for storage with shape: "
|
||||
f"[{block_num}, {layer_num}, {head_num}, {block_size}, {head_dim}]"
|
||||
)
|
||||
|
||||
self.cache_bytes = self._get_cache_bytes(self.cache_dtype)
|
||||
self.storage_buffer_stride_bytes = layer_num * head_num * block_size * head_dim * self.cache_bytes
|
||||
total_bytes = block_num * self.storage_buffer_stride_bytes * 2 # key and value
|
||||
|
||||
logger.info(f"Creating cpu buffer cache for alllayers: {total_bytes / 1024 ** 3:.2f}GB")
|
||||
read_buffer = cuda_host_alloc(total_bytes)
|
||||
self.storage_key_read_buffer = read_buffer
|
||||
self.storage_value_read_buffer = read_buffer + total_bytes // 2
|
||||
self.storage_backend.register_buffer(read_buffer, total_bytes)
|
||||
|
||||
write_buffer = cuda_host_alloc(total_bytes)
|
||||
self.storage_key_write_buffer = write_buffer
|
||||
self.storage_value_write_buffer = write_buffer + total_bytes // 2
|
||||
self.storage_backend.register_buffer(write_buffer, total_bytes)
|
||||
|
||||
def _init_gpu_cache(self, args):
|
||||
|
||||
try:
|
||||
@@ -319,12 +386,7 @@ class CacheTransferManager:
|
||||
value_cache_size = self.value_cache_shape[1] * self.value_cache_shape[2] * self.value_cache_shape[3]
|
||||
else:
|
||||
value_cache_size = 0
|
||||
if args.cache_dtype == "bfloat16":
|
||||
cache_bytes = 2
|
||||
elif args.cache_dtype == "uint8" or args.cache_dtype == "block_wise_fp8":
|
||||
cache_bytes = 1
|
||||
else:
|
||||
raise ValueError(f"Unsupported cache dtype: {args.cache_dtype}")
|
||||
cache_bytes = self._get_cache_bytes(self.cache_dtype)
|
||||
key_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * key_cache_size
|
||||
value_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * value_cache_size
|
||||
if args.cache_dtype == "block_wise_fp8":
|
||||
@@ -367,6 +429,222 @@ class CacheTransferManager:
|
||||
logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ swap space (cpu cache) is ready!")
|
||||
self.swap_space_ready_signal.value[self.rank] = 1
|
||||
|
||||
def _get_cache_bytes(self, cache_dtype):
|
||||
if cache_dtype == "bfloat16":
|
||||
cache_bytes = 2
|
||||
elif cache_dtype in ["uint8", "block_wise_fp8"]:
|
||||
cache_bytes = 1
|
||||
else:
|
||||
raise ValueError(f"Unsupported cache dtype: {cache_dtype}")
|
||||
return cache_bytes
|
||||
|
||||
def _storage_exist_block_num(self, k_keys: List[str], v_keys: List[str]):
|
||||
"""
|
||||
Given the k_keys and v_keys, get the valid blocks number that
|
||||
can be prefetched from storage backend.
|
||||
"""
|
||||
assert len(k_keys) == len(v_keys), "k_keys and v_keys must have the same length."
|
||||
result = self.storage_backend.exists(k_keys + v_keys)
|
||||
|
||||
# only consider the case when both key and value exist
|
||||
num = 0
|
||||
for k, v in zip(k_keys, v_keys):
|
||||
if result[k] and result[v]:
|
||||
num += 1
|
||||
return num
|
||||
|
||||
def _run_read_storage(self, k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids):
|
||||
try:
|
||||
logger.debug(
|
||||
f"_run_read_storage, key_hash_keys: {k_cache_keys}, "
|
||||
f"value_hash_keys: {v_cache_keys}, gpu_block_ids: {gpu_block_ids}"
|
||||
)
|
||||
|
||||
block_num = len(gpu_block_ids)
|
||||
keys = k_cache_keys + v_cache_keys
|
||||
k_cache_ptrs = [self.storage_key_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids]
|
||||
v_cache_ptrs = [
|
||||
self.storage_value_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
|
||||
]
|
||||
kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
|
||||
kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value
|
||||
result = self.storage_backend.batch_get(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes)
|
||||
|
||||
k_result, v_result = result[:block_num], result[block_num:]
|
||||
success_block_num = 0
|
||||
for k, v in zip(k_result, v_result):
|
||||
if k > 0 and v > 0:
|
||||
success_block_num += 1
|
||||
logger.debug(f"_run_read_storage, success_block_num: {success_block_num}")
|
||||
valid_gpu_block_ids = gpu_block_ids[:success_block_num]
|
||||
valid_cpu_block_ids = cpu_block_ids[:success_block_num]
|
||||
|
||||
mode = 1 # cpu ==> gpu
|
||||
swap_cache_layout(
|
||||
self.gpu_cache_k_tensors,
|
||||
self.storage_key_read_buffer,
|
||||
self.key_cache_shape,
|
||||
valid_gpu_block_ids,
|
||||
valid_cpu_block_ids,
|
||||
self.device,
|
||||
mode,
|
||||
)
|
||||
swap_cache_layout(
|
||||
self.gpu_cache_v_tensors,
|
||||
self.storage_value_read_buffer,
|
||||
self.value_cache_shape,
|
||||
valid_gpu_block_ids,
|
||||
valid_cpu_block_ids,
|
||||
self.device,
|
||||
mode,
|
||||
)
|
||||
|
||||
return valid_gpu_block_ids
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[rank {self.rank}/{self.n_ranks}] An error occurred in _run_read_storage: "
|
||||
f"error:{e}, {traceback.format_exc()}"
|
||||
)
|
||||
raise
|
||||
|
||||
def read_storage_task(self, task_id, keys, gpu_block_ids, timeout=0.1):
|
||||
"""Read cache from the storage backend to the GPU memory."""
|
||||
try:
|
||||
logger.debug(
|
||||
f"read_storage_task, task id: {task_id}, hash_keys: {keys}, "
|
||||
f"gpu_block_ids: {gpu_block_ids}, timeout: {timeout}"
|
||||
)
|
||||
k_cache_keys = [f"{key}_key_{self.rank}" for key in keys]
|
||||
v_cache_keys = [f"{key}_value_{self.rank}" for key in keys]
|
||||
match_block_num = self._storage_exist_block_num(k_cache_keys, v_cache_keys)
|
||||
logger.debug(f"read_storage_task, match {match_block_num} blocks from storage for task id: {task_id}")
|
||||
|
||||
k_cache_keys = k_cache_keys[:match_block_num]
|
||||
v_cache_keys = v_cache_keys[:match_block_num]
|
||||
gpu_block_ids = gpu_block_ids[:match_block_num]
|
||||
cpu_block_ids = [i for i in range(match_block_num)]
|
||||
valid_gpu_block_ids = []
|
||||
|
||||
if match_block_num > 0:
|
||||
# TODO: support timeout with actual block count
|
||||
try:
|
||||
valid_gpu_block_ids = self._run_read_storage(
|
||||
k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids
|
||||
)
|
||||
logger.info(
|
||||
f"read_storage_task, finish loading {match_block_num} blocks from storage for task {task_id}."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[rank {self.rank}/{self.n_ranks}] An error occurred: {task_id} {e}")
|
||||
valid_gpu_block_ids = []
|
||||
|
||||
result = (CacheStatus.STORAGE2GPU, task_id, keys, valid_gpu_block_ids)
|
||||
self.cache_task_queue.swap_storage_to_gpu_barrier.wait()
|
||||
self.cache_task_queue.swap_storage_to_gpu_barrier.reset()
|
||||
self.cache_task_queue.put_transfer_done_signal(result)
|
||||
logger.debug(f"read_storage_task: put_transfer_done_signal {result}")
|
||||
logger.info(
|
||||
f"read_storage_task: put_transfer_done_signal for transfer_task_id {task_id}, "
|
||||
f"valid block num {len(valid_gpu_block_ids)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[rank {self.rank}/{self.n_ranks}] An error occurred in read_storage_task: "
|
||||
f"task_id: {task_id}, error:{e}, {traceback.format_exc()}"
|
||||
)
|
||||
|
||||
def _run_write_back_storage(self, k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids):
|
||||
try:
|
||||
logger.debug(
|
||||
f"_run_write_back_storage, k_cache_keys: {k_cache_keys}, v_cache_keys: {v_cache_keys}, "
|
||||
f"gpu_block_ids: {gpu_block_ids}"
|
||||
)
|
||||
key_cache_size = [
|
||||
self.key_cache_shape[0],
|
||||
self.key_cache_shape[1],
|
||||
self.key_cache_shape[2],
|
||||
self.key_cache_shape[3],
|
||||
]
|
||||
mode = 0 # gpu ==> cpu
|
||||
swap_cache_layout(
|
||||
self.gpu_cache_k_tensors,
|
||||
self.storage_key_write_buffer,
|
||||
key_cache_size,
|
||||
gpu_block_ids,
|
||||
cpu_block_ids,
|
||||
self.device,
|
||||
mode,
|
||||
)
|
||||
swap_cache_layout(
|
||||
self.gpu_cache_v_tensors,
|
||||
self.storage_value_write_buffer,
|
||||
key_cache_size,
|
||||
gpu_block_ids,
|
||||
cpu_block_ids,
|
||||
self.device,
|
||||
mode,
|
||||
)
|
||||
|
||||
block_num = len(gpu_block_ids)
|
||||
keys = k_cache_keys + v_cache_keys
|
||||
k_cache_ptrs = [
|
||||
self.storage_key_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
|
||||
]
|
||||
v_cache_ptrs = [
|
||||
self.storage_value_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
|
||||
]
|
||||
kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
|
||||
kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value
|
||||
self.storage_backend.batch_set(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[rank {self.rank}/{self.n_ranks}] An error occurred in _run_write_back_storage: "
|
||||
f"error:{e}, {traceback.format_exc()}"
|
||||
)
|
||||
|
||||
def write_back_storage_task(self, task_id, keys, gpu_block_ids, timeout=0.1):
|
||||
"""
|
||||
Write cache to the storage backend from the GPU memory.
|
||||
"""
|
||||
try:
|
||||
logger.debug(
|
||||
f"write cache to storage, keys: {keys}, gpu_block_ids: {gpu_block_ids}, "
|
||||
f"task_id: {task_id}, timeout: {timeout}"
|
||||
)
|
||||
|
||||
k_cache_keys = [f"{key}_key_{self.rank}" for key in keys]
|
||||
v_cache_keys = [f"{key}_value_{self.rank}" for key in keys]
|
||||
match_block_num = self._storage_exist_block_num(k_cache_keys, v_cache_keys)
|
||||
|
||||
k_cache_keys = k_cache_keys[match_block_num:]
|
||||
v_cache_keys = v_cache_keys[match_block_num:]
|
||||
gpu_block_ids = gpu_block_ids[match_block_num:]
|
||||
cpu_block_ids = [i for i in range(len(gpu_block_ids))]
|
||||
|
||||
if len(k_cache_keys) == 0:
|
||||
logger.info(f"No uncached keys found for task {task_id}")
|
||||
gpu_block_ids = []
|
||||
else:
|
||||
try:
|
||||
# TODO: support timeout with actual block count
|
||||
self._run_write_back_storage(k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in write back storage task: {e}")
|
||||
gpu_block_ids = []
|
||||
|
||||
result = (CacheStatus.GPU2STORAGE, task_id, keys, gpu_block_ids)
|
||||
self.cache_task_queue.swap_to_storage_barrier.wait()
|
||||
if self.rank == 0: # 只有当rank为0时执行同步操作
|
||||
self.cache_task_queue.swap_to_storage_barrier.reset()
|
||||
self.cache_task_queue.put_transfer_done_signal(result) # 发送传输完成信号
|
||||
logger.debug(f"write_back_storage_task: put_transfer_done_signal {result}")
|
||||
logger.info(f"write_back_storage_task: put_transfer_done_signal for transfer_task_id {task_id}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[rank {self.rank}/{self.n_ranks}] An error occurred in write_back_storage_task: "
|
||||
f"error:{e}, {traceback.format_exc()}"
|
||||
)
|
||||
|
||||
def _do_swap_to_cpu_task(
|
||||
self,
|
||||
swap_node_ids,
|
||||
@@ -459,14 +737,9 @@ class CacheTransferManager:
|
||||
logger.debug(f"transfer data: get_transfer_task {data}")
|
||||
if read_finish:
|
||||
self.cache_task_broadcast_signal.value[0] = 0
|
||||
(
|
||||
swap_node_ids,
|
||||
gpu_block_id,
|
||||
cpu_block_id,
|
||||
event_type,
|
||||
transfer_task_id,
|
||||
) = data
|
||||
event_type, transfer_task_id = data[0], data[1]
|
||||
if event_type.value == CacheStatus.SWAP2CPU.value:
|
||||
swap_node_ids, gpu_block_id, cpu_block_id = data[2:]
|
||||
self.swap_to_cpu_thread_pool.submit(
|
||||
self._do_swap_to_cpu_task,
|
||||
swap_node_ids,
|
||||
@@ -475,7 +748,8 @@ class CacheTransferManager:
|
||||
event_type,
|
||||
transfer_task_id,
|
||||
)
|
||||
else:
|
||||
elif event_type.value == CacheStatus.SWAP2GPU.value:
|
||||
swap_node_ids, gpu_block_id, cpu_block_id = data[2:]
|
||||
self.swap_to_gpu_thread_pool.submit(
|
||||
self._do_swap_to_gpu_task,
|
||||
swap_node_ids,
|
||||
@@ -484,6 +758,24 @@ class CacheTransferManager:
|
||||
event_type,
|
||||
transfer_task_id,
|
||||
)
|
||||
elif event_type.value == CacheStatus.STORAGE2GPU.value:
|
||||
hash_keys, gpu_block_ids, timeout = data[2:]
|
||||
self.read_storage_thread_pool.submit(
|
||||
self.read_storage_task,
|
||||
transfer_task_id,
|
||||
hash_keys,
|
||||
gpu_block_ids,
|
||||
timeout,
|
||||
)
|
||||
elif event_type.value == CacheStatus.GPU2STORAGE.value:
|
||||
hash_keys, gpu_block_ids, timeout = data[2:]
|
||||
self.write_back_storage_thread_pool.submit(
|
||||
self.write_back_storage_task,
|
||||
transfer_task_id,
|
||||
hash_keys,
|
||||
gpu_block_ids,
|
||||
timeout,
|
||||
)
|
||||
else:
|
||||
if self.n_ranks > 1:
|
||||
self.cache_task_queue.barrier2.wait()
|
||||
@@ -635,11 +927,11 @@ class CacheTransferManager:
|
||||
+ f"transfer {len(gpu_block_ids)} blocks done elapsed_time {elasped_time:.4f}"
|
||||
)
|
||||
return (
|
||||
event_type,
|
||||
transfer_task_id,
|
||||
swap_node_ids,
|
||||
task_gpu_block_id,
|
||||
task_cpu_block_id,
|
||||
event_type,
|
||||
transfer_task_id,
|
||||
)
|
||||
|
||||
def clear_or_update_caches(self, args):
|
||||
@@ -738,9 +1030,7 @@ def main():
|
||||
"""
|
||||
启动cache manager
|
||||
"""
|
||||
|
||||
cache_manager = CacheTransferManager(args)
|
||||
|
||||
cache_manager.do_data_transfer()
|
||||
|
||||
|
||||
@@ -749,5 +1039,10 @@ if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
|
||||
logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_tprank{args.rank}.log")
|
||||
logger.info(f"args: {vars(args)}")
|
||||
set_device(args.device_id)
|
||||
main()
|
||||
try:
|
||||
main()
|
||||
except Exception as e:
|
||||
logger.error(f"cache_transfer_manager failed with error: {e}, traceback: {traceback.format_exc()}")
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user