[Feature] Support KV Cache Storage (#5571)

* Support Mooncake Store

* up

* up

* add op

* fix conflict

* fix error

* up for comments

* avoid thread lock

* up

* fix unittest

* fix unittest

* remove debug info

* consider tp_size > 1

* add default rdma_nics

* add utils

* up

* fix error

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
Juncai
2025-12-25 16:30:35 +08:00
committed by GitHub
parent be3be4913a
commit 412867fd99
27 changed files with 1672 additions and 195 deletions
+209 -73
View File
@@ -14,10 +14,8 @@
# limitations under the License.
"""
import hashlib
import heapq
import os
import pickle
import subprocess
import sys
import threading
@@ -34,9 +32,10 @@ from fastdeploy import envs
from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
from fastdeploy.cache_manager.cache_metrics import CacheMetrics
from fastdeploy.cache_manager.ops import get_all_visible_devices
from fastdeploy.engine.request import Request
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import get_logger
from fastdeploy.utils import get_hash_str, get_logger
logger = get_logger("prefix_cache_manager", "cache_manager.log")
@@ -65,6 +64,7 @@ class PrefixCacheManager:
self.enable_splitwise = 0
self.splitwise_role = splitwise_role
self.config = config
self.tensor_parallel_size = tensor_parallel_size
self.cache_config = config.cache_config
self.speculative_config = config.speculative_config
self.local_data_parallel_id = local_data_parallel_id
@@ -89,6 +89,13 @@ class PrefixCacheManager:
self.radix_tree_root = BlockNode(-1, [], 0, 0, -1, 0, None, None, None)
# params for cache storage
self.kvcache_storage_backend = self.cache_config.kvcache_storage_backend
self.write_policy = self.cache_config.write_policy
self.task_write_back_event = {}
self.task_prefetch_event = {}
self.storage_prefetch_block_ids = {}
# gpu cache data structure
self.gpu_lru_leaf_heap = []
self.gpu_lru_leaf_set = set()
@@ -105,7 +112,7 @@ class PrefixCacheManager:
self.req_leaf_map = {} # {request_id: leaf node}
self.leaf_req_map = defaultdict(set)
self.unfilled_req_block_map = defaultdict(list)
self.cache_info = {}
self.cache_info = {} # {request_id: (last_match_node, num_cached_tokens)}
self.executor_pool = ThreadPoolExecutor(max_workers=1)
self.free_gpu_executor_pool = ThreadPoolExecutor(max_workers=1)
@@ -253,6 +260,10 @@ class PrefixCacheManager:
else:
val_shape_str = str(val_cache_shape)
val_cache_arg_str = f" --value_cache_shape {val_shape_str}"
if cache_config.kvcache_storage_backend:
kvcache_storage_backend_str = cache_config.kvcache_storage_backend
else:
kvcache_storage_backend_str = "none"
for i in range(tensor_parallel_size):
launch_cmd = (
@@ -281,7 +292,10 @@ class PrefixCacheManager:
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
+ f" --default_dtype '{self.config.model_config.dtype}'"
+ (" --create_cache_tensor" if create_cache_tensor else "")
+ f" >{log_dir}/launch_cache_transfer_manager_tprank{i}.log 2>&1"
+ f" --kvcache_storage_backend {kvcache_storage_backend_str}"
+ f" --write_policy {cache_config.write_policy}"
+ f" --max_model_len {self.config.model_config.max_model_len}"
+ f" >{log_dir}/launch_cache_transfer_manager_{int(device_ids[i])}.log 2>&1"
)
logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
@@ -290,7 +304,7 @@ class PrefixCacheManager:
while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
time.sleep(1)
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
if self.num_cpu_blocks > 0:
while np.sum(self.swap_space_ready_signal.value) != tensor_parallel_size:
time.sleep(1)
@@ -303,7 +317,7 @@ class PrefixCacheManager:
)
# Start additional threads
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
if cache_config.kvcache_storage_backend or self.num_cpu_blocks > 0:
logger.info("Enable hierarchical cache.")
threading.Thread(target=self.recv_data_transfer_result).start()
if cache_config.enable_prefix_caching:
@@ -505,13 +519,7 @@ class PrefixCacheManager:
self.task_swapping_event[transfer_task_id] = Event()
self.cache_task_queue.put_transfer_task(
(
swap_node_ids,
gpu_block_ids,
cpu_block_ids,
event_type,
transfer_task_id,
)
(event_type, transfer_task_id, swap_node_ids, gpu_block_ids, cpu_block_ids)
)
if is_sync:
self.sync_swap_task(transfer_task_id)
@@ -629,6 +637,10 @@ class PrefixCacheManager:
can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size
if req_id in self.leaf_req_map[last_node]: # delete old leaf record, update later
self.leaf_req_map[last_node].remove(req_id)
logger.debug(
f"update_cache_blocks: req_id {req_id}, num_cached_tokens {num_cached_tokens}, "
f"can_cache_computed_tokens {can_cache_computed_tokens}"
)
with self.request_release_lock:
leaf_node = self.mm_build_path(
@@ -640,7 +652,7 @@ class PrefixCacheManager:
)
self.req_leaf_map[req_id] = leaf_node
self.leaf_req_map[leaf_node].add(req_id)
self.cache_info[req_id] = (leaf_node, can_cache_computed_tokens)
self.cache_info[req_id] = [leaf_node, can_cache_computed_tokens]
task.cached_block_num = can_cache_computed_tokens // block_size
except Exception as e:
logger.error(f"update_cache_blocks, error: {type(e)} {e}, {str(traceback.format_exc())}")
@@ -692,7 +704,7 @@ class PrefixCacheManager:
else:
prompt_token_ids = task.prompt_token_ids
req_id = task.request_id
logger.info(f"request_match_blocks: start to allocate blocks for req_id {req_id}")
logger.info(f"request_match_blocks: start to process req {req_id}")
input_token_num = len(prompt_token_ids + task.output_token_ids)
common_block_ids = []
# 1. match block
@@ -708,12 +720,14 @@ class PrefixCacheManager:
# update matched node info
self._update_matched_node_info(req_id, match_block_node, current_time=time.time())
# 2. prepare cache
# allocate gpu cache for matched cpu blocks
# 2. prepare cache: allocate gpu cache for matched cpu blocks, wait for data transfer to complete
gpu_recv_block_ids = []
match_cpu_blocks_num = len(match_cpu_block_ids)
if self.can_allocate_gpu_blocks(num_blocks=match_cpu_blocks_num):
if match_cpu_blocks_num > 0:
logger.debug(
f"request_match_blocks: req_id {req_id}, allocate {match_cpu_blocks_num} block to receive cpu cache"
)
gpu_recv_block_ids = self.allocate_gpu_blocks(match_cpu_blocks_num)
if len(gpu_recv_block_ids) > 0:
self._prepare_cpu_cache(
@@ -746,20 +760,58 @@ class PrefixCacheManager:
self.metrics._update_history_hit_metrics()
if self.metrics.req_count % 10000 == 0:
self.metrics.reset_metrics()
logger.info(
f"request_match_blocks: request block for req_id {req_id}: common_block_ids {common_block_ids}"
)
logger.info(f"request_match_blocks: req_id {req_id}, matched_block_ids {common_block_ids}")
# set leaf node temporarily, then update it in update_cache_blocks
self.req_leaf_map[req_id] = match_block_node
self.leaf_req_map[match_block_node].add(req_id)
# record request cache info
self.cache_info[req_id] = (match_block_node, len(common_block_ids) * block_size)
self.cache_info[req_id] = [match_block_node, len(common_block_ids) * block_size]
task.cached_block_num = len(common_block_ids)
return common_block_ids, matched_token_num, hit_info
except Exception as e:
logger.error(f"request_match_blocks: request_block_ids: error: {type(e)} {e}")
raise e
def request_match_storage_blocks(self, request, extra_gpu_block_ids):
    """
    Match and fetch the cached blocks from the storage backend for the given request.
    # TODO: merge this function into request_match_blocks

    args:
        request: The request to be processed
        extra_gpu_block_ids: A list of GPU block IDs to be used for fetching the cache
    returns:
        matched_block_ids: A list of block IDs that prefetched cache from storage
    """
    if self.kvcache_storage_backend is None:
        return []

    request_id = request.request_id
    token_ids = request.prompt_token_ids
    blk_size = self.cache_config.block_size

    # Seed the rolling prefix hash from whatever this request already has cached,
    # so storage keys line up with the radix-tree keys built elsewhere.
    prefix_key = []
    cached_token_count = 0
    if request_id in self.cache_info:
        last_node, cached_token_count = self.cache_info[request_id]
        if last_node.hash_value is not None:
            prefix_key = [last_node.hash_value]

    # Hash every remaining *full* block after the cached prefix; each block's key
    # chains on the previous one, so identical prefixes yield identical keys.
    block_keys = []
    offset = cached_token_count
    last_start = len(token_ids) - blk_size
    while offset <= last_start:
        key = get_hash_str(token_ids[offset : offset + blk_size], prefix_key)
        block_keys.append(key)
        prefix_key = [key]
        offset += blk_size

    logger.info(f"start prefetch cache from storage, req_id: {request_id}, block num: {len(block_keys)}")
    matched_block_ids = self.issue_prefetch_storage_task(request_id, block_keys, extra_gpu_block_ids)
    logger.info(
        f"finish prefetch cache from storage, req_id: {request_id}, matched block num: {len(matched_block_ids)}"
    )
    return matched_block_ids
def request_block_ids(self, task, block_size, dec_token_num, *args):
"""
Allocate blocks for a task.
@@ -806,10 +858,7 @@ class PrefixCacheManager:
current_time = time.time()
self._update_matched_node_info(req_id, match_block_node, current_time)
# 2. prepare cache
(
gpu_recv_block_ids,
gpu_extra_block_ids,
) = self._prepare_cache(
(gpu_recv_block_ids, gpu_extra_block_ids) = self._prepare_cache(
req_id,
input_ids,
block_size,
@@ -829,7 +878,6 @@ class PrefixCacheManager:
gpu_build_path_block_ids = []
gpu_build_path_block_ids = gpu_extra_block_ids
leaf_node = self.build_path(
req_id,
current_time,
@@ -883,6 +931,7 @@ class PrefixCacheManager:
with self.request_release_lock:
try:
req_id = task.request_id
keys = []
leaf_node = self.req_leaf_map.pop(req_id)
if leaf_node in self.leaf_req_map:
self.leaf_req_map[leaf_node].remove(req_id)
@@ -893,6 +942,7 @@ class PrefixCacheManager:
if req_id in node.req_id_set:
node.req_id_set.remove(req_id)
node.decrement_shared_count()
keys.append(node.hash_value)
node = node.parent
if req_id in self.cache_info:
@@ -919,6 +969,78 @@ class PrefixCacheManager:
logger.error(f"release_block_ids: error: {type(e)} {e}, {str(traceback.format_exc())}")
raise e
def write_cache_to_storage(self, request: Request):
    """
    For finished request, write cache to storage.
    NOTE: this function does not modify the global params
    """
    if self.kvcache_storage_backend is None:
        return

    request_id = request.request_id

    # Walk from the request's leaf node up to the radix-tree root, collecting
    # each cached block's hash key; reverse so keys run root -> leaf (prefix order).
    hash_keys = []
    cursor = self.req_leaf_map[request_id]
    while cursor != self.radix_tree_root:
        hash_keys.append(cursor.hash_value)
        cursor = cursor.parent
    hash_keys.reverse()

    if not hash_keys:
        return

    # One GPU block per hashed block, taken from the front of the block table.
    gpu_blocks = request.block_tables[: len(hash_keys)]
    logger.info(f"start write cache back to storage, req_id: {request_id}, block num: {len(hash_keys)}")
    started_at = time.time()
    self.issue_write_back_storage_task(
        req_id=request_id, hash_keys=hash_keys, gpu_block_ids=gpu_blocks, is_sync=True
    )
    cost_time = time.time() - started_at
    logger.info(f"finish write cache back to storage, req_id: {request_id}, cost_time: {cost_time:.6f}s")
def issue_write_back_storage_task(self, req_id, hash_keys, gpu_block_ids, is_sync=True, timeout=0.5):
    """
    Enqueue a GPU->storage write-back task for the given request's cached blocks.

    Args:
        req_id: request identifier the task belongs to
        hash_keys: content hash key for each block to persist
        gpu_block_ids: GPU block id holding each block's data (same length as hash_keys)
        is_sync: when True, block until the transfer manager reports completion
        timeout: per-task timeout forwarded to the transfer manager

    Raises:
        ValueError: if hash_keys and gpu_block_ids differ in length
    """
    if self.kvcache_storage_backend is None:
        return

    n_keys, n_blocks = len(hash_keys), len(gpu_block_ids)
    if n_keys != n_blocks:
        message = f"write_back_storage error: hash_keys({n_keys}) != gpu_block_ids({n_blocks})"
        logger.error(message)
        raise ValueError(message)

    # Completion is signalled through this event by the transfer-result thread.
    self.task_write_back_event[req_id] = Event()
    self.cache_task_queue.put_transfer_task((CacheStatus.GPU2STORAGE, req_id, hash_keys, gpu_block_ids, timeout))
    if is_sync:
        self.wait_write_storage_task(req_id)
def wait_write_storage_task(self, req_id):
    """
    Block until the pending write-back task for *req_id* finishes, then drop
    its completion event. No-op when nothing is pending for the request.

    The event must stay registered in task_write_back_event while waiting:
    the transfer-result thread looks it up there to signal completion, so it
    is removed only after the wait returns.
    """
    if req_id not in self.task_write_back_event:
        return
    self.task_write_back_event[req_id].wait()
    del self.task_write_back_event[req_id]
def issue_prefetch_storage_task(self, req_id, hash_keys, gpu_block_ids, is_sync=True, timeout=0.5):
    """
    Issue a storage->GPU prefetch task for the given block hash keys.

    Args:
        req_id: request identifier the task belongs to
        hash_keys: content hash keys of the blocks to fetch from storage
        gpu_block_ids: GPU block ids to receive the fetched data
        is_sync: when True, block until the prefetch result is reported
        timeout: per-task timeout forwarded to the transfer manager

    Returns:
        List of GPU block ids actually filled from storage (empty when
        is_sync is False or the storage backend is disabled).
    """
    # Guard against a disabled storage backend, mirroring
    # issue_write_back_storage_task: without it a STORAGE2GPU task would be
    # queued that no worker answers, deadlocking the is_sync wait below.
    if self.kvcache_storage_backend is None:
        return []
    storage_block_ids = []
    # Completion is signalled through this event by recv_data_transfer_result.
    self.task_prefetch_event[req_id] = Event()
    # issue task to cache_transfer_manager
    self.cache_task_queue.put_transfer_task((CacheStatus.STORAGE2GPU, req_id, hash_keys, gpu_block_ids, timeout))
    if is_sync:
        storage_block_ids = self.wait_prefetch_storage_task(req_id)
    return storage_block_ids
def wait_prefetch_storage_task(self, req_id):
    """
    Block until the prefetch task for *req_id* completes and return the GPU
    block ids that were filled from storage.

    Returns None when no prefetch task is registered. The event stays in
    task_prefetch_event while waiting, because the transfer-result thread
    signals completion through that entry; bookkeeping is torn down only
    after the wait has returned.
    """
    event = self.task_prefetch_event.get(req_id)
    if event is None:
        return None
    event.wait()
    del self.task_prefetch_event[req_id]
    return self.storage_prefetch_block_ids.pop(req_id)
def free_nodes_directly(self, node):
with self.request_release_lock:
try:
@@ -1069,10 +1191,7 @@ class PrefixCacheManager:
break
node = heapq.heappop(self.gpu_lru_leaf_heap)
self.gpu_lru_leaf_set.remove(node)
if (
not self.cache_config.enable_hierarchical_cache
or self.cache_config.num_cpu_blocks < need_block_num
):
if self.cache_config.num_cpu_blocks < need_block_num:
if node.shared_count == 0 and node.is_gpu_leaf_node: # 直接回收
self._handle_free_gpu_node_without_cpu(node)
total_gpu_free_count += 1
@@ -1195,12 +1314,6 @@ class PrefixCacheManager:
)
return total_cpu_free_count
def cal_block_hash(self, block):
"""
calculate hash value of a block
"""
return hash(tuple(block))
def get_block_hash_extra_keys(self, request, start_idx, end_idx, mm_idx):
"""
Retrieves additional hash keys for block identification.
@@ -1260,16 +1373,6 @@ class PrefixCacheManager:
hash_keys.append(mm_inputs["mm_hashes"][img_idx])
return len(mm_inputs["mm_positions"]) - 1, hash_keys
def hash_block_features(self, input_ids, extra_keys: list = []):
"""
calculate hash value of a block with additional keys
Args:
input_ids: Input token IDs
extra_keys: Additional keys for block identification
"""
return hashlib.sha256(pickle.dumps((input_ids, extra_keys))).hexdigest()
def _revert_match_blocks(
self,
request,
@@ -1363,6 +1466,7 @@ class PrefixCacheManager:
matche_nodes = []
has_modified_gpu_lru_leaf_heap = False
has_modified_cpu_lru_leaf_heap = False
prefix_block_key = []
with self.cache_status_lock:
while match_token_num < total_token_num:
@@ -1376,7 +1480,10 @@ class PrefixCacheManager:
end_idx=match_token_num + block_size,
mm_idx=mm_idx,
)
hash_value = self.hash_block_features(token_block, extra_keys)
prefix_block_key.extend(extra_keys)
hash_value = get_hash_str(token_block, prefix_block_key)
prefix_block_key = [hash_value]
if hash_value in current_match_node.children:
child = current_match_node.children[hash_value]
matche_nodes.append(child)
@@ -1476,6 +1583,7 @@ class PrefixCacheManager:
matche_nodes = []
has_modified_gpu_lru_leaf_heap = False
has_modified_cpu_lru_leaf_heap = False
prefix_block_key = []
with self.cache_status_lock:
while match_token_num < total_token_num:
@@ -1483,7 +1591,8 @@ class PrefixCacheManager:
token_num = len(token_block)
if token_num != block_size:
break
hash_value = self.cal_block_hash(token_block)
hash_value = get_hash_str(token_block, prefix_block_key)
prefix_block_key = [hash_value]
if hash_value in current_match_node.children:
child = current_match_node.children[hash_value]
matche_nodes.append(child)
@@ -1515,6 +1624,8 @@ class PrefixCacheManager:
swap_node_ids.append(child.node_id)
match_token_num = match_token_num + block_size
current_match_node = child
# record request cache info
self.cache_info[req_id] = [child, match_token_num]
else:
break
@@ -1577,8 +1688,10 @@ class PrefixCacheManager:
has_unfilled_block = False
current_time = time.time()
input_hash_value = self.hash_block_features(input_ids)
input_hash_value = get_hash_str(input_ids)
gpu_block_ids = request.block_tables[num_cached_tokens // block_size :].copy()
prefix_block_key = [] if last_node.hash_value is None else [last_node.hash_value]
for i in range(num_cached_tokens, can_cache_computed_tokens, block_size):
current_block = input_ids[i : i + block_size]
current_block_size = len(current_block) # 最后一个block可能没填满
@@ -1591,7 +1704,9 @@ class PrefixCacheManager:
end_idx=i + block_size,
mm_idx=mm_idx,
)
hash_value = self.hash_block_features(current_block, extra_keys)
prefix_block_key.extend(extra_keys)
hash_value = get_hash_str(current_block, prefix_block_key)
prefix_block_key = [hash_value]
allocated_block_id = gpu_block_ids.pop(0)
node_id = self.node_id_pool.pop()
unique_node_ids.append(node_id)
@@ -1651,7 +1766,7 @@ class PrefixCacheManager:
gpu_block_ids = gpu_block_ids.copy()
node = last_node
reverved_dec_block_ids = []
input_hash_value = self.cal_block_hash(input_ids)
input_hash_value = get_hash_str(input_ids)
token_num = len(left_input_ids)
if token_num == 0:
@@ -1663,6 +1778,7 @@ class PrefixCacheManager:
unique_node_ids = []
new_last_node = last_node
has_unfilled_block = False
prefix_block_key = [] if last_node.hash_value is None else [last_node.hash_value]
for i in range(0, token_num, block_size):
current_block = left_input_ids[i : i + block_size]
@@ -1670,7 +1786,8 @@ class PrefixCacheManager:
if current_block_size != block_size:
has_unfilled_block = True
else:
hash_value = self.cal_block_hash(current_block)
hash_value = get_hash_str(current_block, prefix_block_key)
prefix_block_key = [hash_value]
allocated_block_id = gpu_block_ids.pop(0)
node_id = self.node_id_pool.pop()
unique_node_ids.append(node_id)
@@ -1764,28 +1881,47 @@ class PrefixCacheManager:
if data is None:
time.sleep(0.001)
continue
(
swap_node_ids,
task_gpu_block_id,
task_cpu_block_id,
event_type,
transfer_task_id,
) = data
length = len(task_gpu_block_id)
for i in range(length):
self._handle_swap_result(
swap_node_ids[i],
task_gpu_block_id[i],
task_cpu_block_id[i],
event_type = data[0]
if event_type.value == CacheStatus.STORAGE2GPU.value:
logger.info(f"recv_data_transfer_result: {data}")
task_id, hash_keys, block_ids = data[1:]
if task_id not in self.storage_prefetch_block_ids:
self.storage_prefetch_block_ids[task_id] = []
saved_block_ids = self.storage_prefetch_block_ids[task_id]
saved_block_ids.append(block_ids)
if len(saved_block_ids) == self.tensor_parallel_size:
self.storage_prefetch_block_ids[task_id] = min(saved_block_ids, key=len)
if task_id in self.task_prefetch_event:
self.task_prefetch_event[task_id].set()
elif event_type.value == CacheStatus.GPU2STORAGE.value:
logger.info(f"recv_data_transfer_result: {data}")
task_id, hash_keys, block_ids = data[1:]
if task_id in self.task_write_back_event:
self.task_write_back_event[task_id].set()
else:
(
event_type,
transfer_task_id,
swap_node_ids,
task_gpu_block_id,
task_cpu_block_id,
) = data
length = len(task_gpu_block_id)
for i in range(length):
self._handle_swap_result(
swap_node_ids[i],
task_gpu_block_id[i],
task_cpu_block_id[i],
event_type,
)
if transfer_task_id in self.task_swapping_event:
self.task_swapping_event[transfer_task_id].set()
logger.info(
f"recv_data_transfer_result: transfer_task_id {transfer_task_id}: "
+ f"task_node_ids {swap_node_ids} task_gpu_block_id {task_gpu_block_id} "
+ f"task_cpu_block_id {task_cpu_block_id} event_type {event_type} done"
)
if transfer_task_id in self.task_swapping_event:
self.task_swapping_event[transfer_task_id].set()
logger.info(
f"recv_data_transfer_result: transfer_task_id {transfer_task_id}: "
+ f"task_node_ids {swap_node_ids} task_gpu_block_id {task_gpu_block_id} "
+ f"task_cpu_block_id {task_cpu_block_id} event_type {event_type} done"
)
except Exception as e:
logger.warning(f"recv_data_transfer_result: error: {e}, {str(traceback.format_exc())}")
raise e