[BugFix] Refine the preparation of cpu and storage cache (#5777)

* Refine the preparation of cpu and storage cache

* fix error

* fix error

* up

* fix

* up docs

* fix unittest

* remove debug info
This commit is contained in:
jc
2026-01-05 10:13:30 +08:00
committed by GitHub
parent 95257c1dbd
commit e911ac2ce7
10 changed files with 156 additions and 149 deletions
@@ -112,7 +112,7 @@ class PrefixCacheManager:
self.req_leaf_map = {} # {request_id: leaf node}
self.leaf_req_map = defaultdict(set)
self.unfilled_req_block_map = defaultdict(list)
self.cache_info = {} # {request_id: (last_match_node, num_cached_tokens)}
self.req_to_radix_tree_info = {} # {request_id: (last_match_node, num_cached_tokens_in_radix_tree)}
self.executor_pool = ThreadPoolExecutor(max_workers=1)
self.free_gpu_executor_pool = ThreadPoolExecutor(max_workers=1)
@@ -634,7 +634,7 @@ class PrefixCacheManager:
"""
try:
req_id = task.request_id
last_node, num_cached_tokens = self.cache_info[req_id]
last_node, num_cached_tokens = self.req_to_radix_tree_info[req_id]
can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size
if req_id in self.leaf_req_map[last_node]: # delete old leaf record, update later
self.leaf_req_map[last_node].remove(req_id)
@@ -653,8 +653,8 @@ class PrefixCacheManager:
)
self.req_leaf_map[req_id] = leaf_node
self.leaf_req_map[leaf_node].add(req_id)
self.cache_info[req_id] = [leaf_node, can_cache_computed_tokens]
task.cached_block_num = can_cache_computed_tokens // block_size
self.req_to_radix_tree_info[req_id] = [leaf_node, can_cache_computed_tokens]
task.num_cached_blocks = can_cache_computed_tokens // block_size
except Exception as e:
logger.error(f"update_cache_blocks, error: {type(e)} {e}, {str(traceback.format_exc())}")
raise e
@@ -674,14 +674,14 @@ class PrefixCacheManager:
break
return False, 0
def request_match_blocks(self, task, block_size, *args):
def request_match_blocks(self, task: Request, block_size, *args):
"""
get match blocks info for a task.
Match and fetch cache for a task.
This is a synchronous interface. If CPU-to-GPU data transfer occurs,
it will block until synchronization completes.
Callers requiring asynchronous behavior should invoke this via a thread pool.
Note: This function may allocate GPU blocks for matched CPU Cache
Note: This function may allocate GPU blocks for matched CPU Cache and Storage Cache
Parameters:
- task: Request object
@@ -689,15 +689,17 @@ class PrefixCacheManager:
Returns:
- common_block_ids: List of matched shared blocks
- unique_block_ids: List of exclusively allocated blocks
- match_token_num: Number of matched tokens
- metrics: Dictionary of metrics
"""
with self.request_release_lock:
try:
hit_info = {
"gpu_cache_blocks": 0,
"cpu_cache_blocks": 0,
metrics = {
"gpu_match_token_num": 0,
"cpu_match_token_num": 0,
"storage_match_token_num": 0,
"cpu_cache_prepare_time": 0,
"storage_cache_prepare_time": 0,
}
self.metrics.req_count += 1
if isinstance(task.prompt_token_ids, np.ndarray):
@@ -706,7 +708,8 @@ class PrefixCacheManager:
prompt_token_ids = task.prompt_token_ids
req_id = task.request_id
logger.info(f"request_match_blocks: start to process req {req_id}")
input_token_num = len(prompt_token_ids + task.output_token_ids)
input_token_ids = prompt_token_ids + task.output_token_ids
input_token_num = len(input_token_ids)
common_block_ids = []
# 1. match block
(
@@ -721,7 +724,7 @@ class PrefixCacheManager:
# update matched node info
self._update_matched_node_info(req_id, match_block_node, current_time=time.time())
# 2. prepare cache: allocate gpu cache for matched cpu blocks, wait for data transfer to complete
# 2. prepare cpu cache: allocate gpu cache for matched cpu blocks, wait for data transfer to complete
gpu_recv_block_ids = []
match_cpu_blocks_num = len(match_cpu_block_ids)
if self.can_allocate_gpu_blocks(num_blocks=match_cpu_blocks_num):
@@ -731,6 +734,7 @@ class PrefixCacheManager:
)
gpu_recv_block_ids = self.allocate_gpu_blocks(match_cpu_blocks_num)
if len(gpu_recv_block_ids) > 0:
start_time = time.time()
self._prepare_cpu_cache(
req_id=req_id,
swap_node_ids=swap_node_ids,
@@ -738,81 +742,94 @@ class PrefixCacheManager:
match_cpu_block_ids=match_cpu_block_ids,
cpu_recv_block_ids=[],
)
cost_time = time.time() - start_time
metrics["cpu_cache_prepare_time"] = cost_time
else:
raise Exception(
"request_match_blocks: Not enough GPU memory to allocate cache for matched CPU Cache"
)
# 3. update metrics
matched_token_num = gpu_match_token_num + cpu_match_token_num
common_block_ids = match_gpu_block_ids + gpu_recv_block_ids
if matched_token_num > 0:
# 3. match and prefetch cache from storage
match_token_num = gpu_match_token_num + cpu_match_token_num
no_match_token_num = input_token_num - match_token_num
no_match_block_num = (no_match_token_num + block_size - 1) // block_size
gpu_recv_storage_block_ids = []
storage_match_token_num = 0
match_storage_block_ids = []
if self.kvcache_storage_backend and no_match_token_num >= block_size:
if not self.can_allocate_gpu_blocks(num_blocks=no_match_block_num):
raise Exception(
"request_match_blocks: Not enough GPU memory to allocate cache for matched Storage Cache"
)
logger.debug(
f"request_match_blocks: req_id {req_id}, allocate {no_match_block_num} block to receive storage cache"
)
gpu_recv_storage_block_ids = self.allocate_gpu_blocks(no_match_block_num)
prefix_block_key = [] if match_block_node.hash_value is None else [match_block_node.hash_value]
cur_token_idx = match_token_num
no_match_block_keys = []
while cur_token_idx <= input_token_num - block_size:
cur_block_token_ids = input_token_ids[cur_token_idx : cur_token_idx + block_size]
cur_block_key = get_hash_str(cur_block_token_ids, prefix_block_key)
no_match_block_keys.append(cur_block_key)
cur_token_idx += block_size
prefix_block_key = [cur_block_key]
logger.info(
f"start prefetch cache from storage, req_id: {req_id}, block num: {len(no_match_block_keys)}"
)
start_time = time.time()
storage_matched_block_ids = self.issue_prefetch_storage_task(
req_id, no_match_block_keys, gpu_recv_storage_block_ids
)
storage_matched_block_num = len(storage_matched_block_ids)
storage_match_token_num = storage_matched_block_num * block_size
cost_time = time.time() - start_time
metrics["storage_cache_prepare_time"] = cost_time
logger.info(
f"finish prefetch cache from storage, req_id: {req_id}, "
f"matched block num: {storage_matched_block_num}, cost_time:{cost_time:.6f}s"
)
match_storage_block_ids = gpu_recv_storage_block_ids[:storage_matched_block_num]
self.recycle_gpu_blocks(gpu_recv_storage_block_ids[storage_matched_block_num:])
# 4. update metrics
match_token_num = gpu_match_token_num + cpu_match_token_num + storage_match_token_num
common_block_ids = match_gpu_block_ids + gpu_recv_block_ids + match_storage_block_ids
if match_token_num > 0:
self.metrics.hit_req_count += 1
self.metrics.calculate_hit_metrics(
req_id,
cpu_match_token_num,
gpu_match_token_num,
storage_match_token_num,
input_token_num,
)
hit_info["gpu_cache_blocks"] = len(match_gpu_block_ids)
hit_info["cpu_cache_blocks"] = len(match_cpu_block_ids)
hit_info["gpu_match_token_num"] = gpu_match_token_num
hit_info["cpu_match_token_num"] = cpu_match_token_num
metrics["gpu_match_token_num"] = gpu_match_token_num
metrics["cpu_match_token_num"] = cpu_match_token_num
metrics["storage_match_token_num"] = storage_match_token_num
self.metrics._update_history_hit_metrics()
if self.metrics.req_count % 10000 == 0:
self.metrics.reset_metrics()
logger.info(f"request_match_blocks: req_id {req_id}, matched_block_ids {common_block_ids}")
logger.debug(f"request_match_blocks: req_id {req_id}, matched_block_ids_num {len(common_block_ids)}")
logger.debug(f"request_match_blocks: req_id {req_id}, matched_block_ids {common_block_ids}")
# set leaf node temporarily, then update it in update_cache_blocks
self.req_leaf_map[req_id] = match_block_node
self.leaf_req_map[match_block_node].add(req_id)
# record request cache info
self.cache_info[req_id] = [match_block_node, len(common_block_ids) * block_size]
task.cached_block_num = len(common_block_ids)
return common_block_ids, matched_token_num, hit_info
# record request cache info in radix tree, note that the block ids for receiving storage cache
# are recorded into radix tree in update_cache_blocks
self.req_to_radix_tree_info[req_id] = [match_block_node, gpu_match_token_num + cpu_match_token_num]
task.num_cached_blocks = len(common_block_ids)
return common_block_ids, match_token_num, metrics
except Exception as e:
logger.error(f"request_match_blocks: request_block_ids: error: {type(e)} {e}")
raise e
def request_match_storage_blocks(self, request, extra_gpu_block_ids):
"""
Match and fetch the cached blocks from the storage backend for the given request.
# TODO: merge this function into request_match_blocks
args:
request: The request to be processed
extra_gpu_block_ids: A list of GPU block IDs to be used for fetching the cache
returns:
matched_block_ids: A list of block IDs that prefetched cache from storage
"""
if self.kvcache_storage_backend is None:
return []
req_id = request.request_id
input_ids = request.prompt_token_ids
block_size = self.cache_config.block_size
prefix_block_key = []
num_cached_tokens = 0
if req_id in self.cache_info:
last_node, num_cached_tokens = self.cache_info[req_id]
prefix_block_key = [] if last_node.hash_value is None else [last_node.hash_value]
block_keys = []
current_tokens = num_cached_tokens
while current_tokens <= len(input_ids) - block_size:
cur_block_key = get_hash_str(input_ids[current_tokens : current_tokens + block_size], prefix_block_key)
block_keys.append(cur_block_key)
current_tokens += block_size
prefix_block_key = [cur_block_key]
logger.info(f"start prefetch cache from storage, req_id: {req_id}, block num: {len(block_keys)}")
matched_block_ids = self.issue_prefetch_storage_task(req_id, block_keys, extra_gpu_block_ids)
logger.info(
f"finish prefetch cache from storage, req_id: {req_id}, matched block num: {len(matched_block_ids)}"
)
return matched_block_ids
def request_block_ids(self, task, block_size, dec_token_num, *args):
"""
Allocate blocks for a task.
@@ -898,6 +915,7 @@ class PrefixCacheManager:
req_id,
cpu_match_token_num,
gpu_match_token_num,
0,
input_token_num,
)
hit_info["gpu_cache_blocks"] = gpu_match_token_num // block_size
@@ -946,8 +964,8 @@ class PrefixCacheManager:
keys.append(node.hash_value)
node = node.parent
if req_id in self.cache_info:
del self.cache_info[req_id]
if req_id in self.req_to_radix_tree_info:
del self.req_to_radix_tree_info[req_id]
logger.info(f"release_block_ids: req_id {req_id} leaf_node {leaf_node}")
@@ -1187,6 +1205,7 @@ class PrefixCacheManager:
while True:
if len(self.gpu_lru_leaf_heap) == 0:
logger.info("free_block_ids_async: no more gpu leaf node available.")
break
if total_gpu_free_count >= need_block_num:
break
@@ -1239,6 +1258,9 @@ class PrefixCacheManager:
):
heapq.heappush(self.gpu_lru_leaf_heap, node)
self.gpu_lru_leaf_set.add(node)
logger.info(
f"free_block_ids_async: need_block_num {need_block_num}, free_block_num {total_gpu_free_count}."
)
# swap cache to cpu
if hash_value_gpu_block_ids_map:
@@ -1628,7 +1650,7 @@ class PrefixCacheManager:
match_token_num = match_token_num + block_size
current_match_node = child
# record request cache info
self.cache_info[req_id] = [child, match_token_num]
self.req_to_radix_tree_info[req_id] = [child, match_token_num]
else:
break
@@ -1952,7 +1974,7 @@ class PrefixCacheManager:
self.req_leaf_map.clear()
self.leaf_req_map.clear()
self.unfilled_req_block_map.clear()
self.cache_info.clear()
self.req_to_radix_tree_info.clear()
# reset gpu cache data structure
self.gpu_lru_leaf_heap.clear()