[BugFix] Refine the preparation of CPU and storage cache (#5777)

* Refine the preparation of CPU and storage cache

* fix error

* fix error

* up

* fix

* up docs

* fix unittest

* remove debug info
This commit is contained in:
jc
2026-01-05 10:13:30 +08:00
committed by GitHub
parent 95257c1dbd
commit e911ac2ce7
10 changed files with 156 additions and 149 deletions
+20 -63
View File
@@ -16,7 +16,6 @@
import copy
import threading
import time
import traceback
from collections import deque
from collections.abc import Iterable
@@ -276,7 +275,7 @@ class ResourceManagerV1(ResourceManager):
llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
else:
self._free_blocks(preempted_req)
preempted_req.cached_block_num = 0
preempted_req.num_cached_blocks = 0
self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
preempted_reqs.append(preempted_req)
@@ -650,7 +649,9 @@ class ResourceManagerV1(ResourceManager):
break
else: # need to prefill
llm_logger.debug(
f"scheduler prefill task: {request} request.need_prefill_tokens {request.need_prefill_tokens} request.num_computed_tokens {request.num_computed_tokens}"
f"scheduler prefill task in running queue: {request.request_id}, "
f"request.need_prefill_tokens {request.need_prefill_tokens},"
f"request.num_computed_tokens {request.num_computed_tokens}"
)
num_new_tokens = self._get_num_new_tokens(request, token_budget)
num_new_block = self.get_new_block_nums(request, num_new_tokens)
@@ -703,7 +704,10 @@ class ResourceManagerV1(ResourceManager):
self._update_mm_hashes(request)
# Enable prefix caching
if self.config.cache_config.enable_prefix_caching:
if self.cache_manager.num_cpu_blocks > 0:
if (
self.cache_manager.num_cpu_blocks > 0
or self.config.cache_config.kvcache_storage_backend
):
if not self.cache_manager.can_allocate_gpu_blocks(
(request.need_prefill_tokens + self.config.cache_config.block_size - 1)
// self.config.cache_config.block_size
@@ -721,14 +725,6 @@ class ResourceManagerV1(ResourceManager):
if not request.get("skip_allocate", False):
extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(num_new_block)
request.block_tables.extend(extra_gpu_block_ids)
if (
self.config.cache_config.enable_prefix_caching
and self.config.cache_config.kvcache_storage_backend
and num_new_tokens >= self.config.cache_config.block_size
):
matched_block_ids = self.get_storage_cached_blocks(request, extra_gpu_block_ids)
num_new_tokens -= len(matched_block_ids) * self.config.cache_config.block_size
self.waiting.popleft()
self.running.append(request)
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
@@ -754,7 +750,10 @@ class ResourceManagerV1(ResourceManager):
request.num_total_tokens
) # Before preempted task rescheduled, preempted task has been sent to engine, no more tokens are output, here num_total_tokens should be static and correct
if self.config.cache_config.enable_prefix_caching:
if self.cache_manager.num_cpu_blocks > 0:
if (
self.cache_manager.num_cpu_blocks > 0
or self.config.cache_config.kvcache_storage_backend
):
if not self.cache_manager.can_allocate_gpu_blocks(
(request.need_prefill_tokens + self.config.cache_config.block_size - 1)
// self.config.cache_config.block_size
@@ -772,14 +771,6 @@ class ResourceManagerV1(ResourceManager):
if not request.get("skip_allocate", False):
extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(num_new_block)
request.block_tables.extend(extra_gpu_block_ids)
if (
self.config.cache_config.enable_prefix_caching
and self.config.cache_config.kvcache_storage_backend
and num_new_tokens >= self.config.cache_config.block_size
):
matched_block_ids = self.get_storage_cached_blocks(request, extra_gpu_block_ids)
num_new_tokens -= len(matched_block_ids) * self.config.cache_config.block_size
self.waiting.popleft()
self.running.append(request)
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
@@ -918,11 +909,10 @@ class ResourceManagerV1(ResourceManager):
def get_prefix_cached_blocks(self, request: Request):
"""
set prefix cached information for the given request
Match and fetch cache for a task.
"""
try:
cache_prepare_time = time.time()
(common_block_ids, matched_token_num, hit_info) = self.cache_manager.request_match_blocks(
(common_block_ids, matched_token_num, metrics) = self.cache_manager.request_match_blocks(
request, self.config.cache_config.block_size
)
@@ -933,8 +923,11 @@ class ResourceManagerV1(ResourceManager):
)
request.num_cached_tokens = matched_token_num
request.metrics.gpu_cache_token_num = hit_info["gpu_match_token_num"]
request.metrics.cpu_cache_token_num = hit_info["cpu_match_token_num"]
request.metrics.gpu_cache_token_num = metrics["gpu_match_token_num"]
request.metrics.cpu_cache_token_num = metrics["cpu_match_token_num"]
request.metrics.storage_cache_token_num = metrics["storage_match_token_num"]
request.metrics.cpu_cache_prepare_time = metrics["cpu_cache_prepare_time"]
request.metrics.storage_cache_prepare_time = metrics["storage_cache_prepare_time"]
request.cache_info = [matched_block_num, no_cache_block_num]
request.block_tables = common_block_ids
request.skip_allocate = False
@@ -949,45 +942,11 @@ class ResourceManagerV1(ResourceManager):
request.skip_allocate = True
else:
request.num_computed_tokens = matched_token_num
request.metrics.gpu_cpu_cache_prepare_time = time.time() - cache_prepare_time
return True
except Exception as e:
llm_logger.error(f"prefix match blocks error: {e}, {str(traceback.format_exc())} waiting reschedule...")
return False
def get_storage_cached_blocks(self, request: Request, extra_gpu_block_ids: list = []):
"""
Match and prefetch the cached blocks from the storage backend.
TODO: merge this function into get_prefix_cached_blocks
"""
try:
tic = time.time()
req_id = request.request_id
llm_logger.debug(f"get_storage_cached_blocks start process req {req_id}")
matched_block_ids = self.cache_manager.request_match_storage_blocks(request, extra_gpu_block_ids)
llm_logger.debug(
f"matched {len(matched_block_ids)} blocks from storage for req_id:{req_id}, "
f"cost_time: {time.time() - tic:.6f}s"
)
matched_token_num = len(matched_block_ids) * self.config.cache_config.block_size
request.metrics.storage_cache_token_num = matched_token_num
request.num_computed_tokens += matched_token_num
if request.num_computed_tokens == request.need_prefill_tokens:
request.num_computed_tokens = request.num_computed_tokens - self.config.cache_config.block_size
request.metrics.storage_cache_prepare_time = time.time() - tic
request.cache_info[0] += len(matched_block_ids) # matched_block_num
request.cache_info[1] -= len(matched_block_ids) # no_cache_block_num
main_process_metrics.prefix_cache_token_num.inc(matched_token_num)
# TODO: main_process_metrics.prefix_storage_cache_token_num.inc(matched_token_num)
return matched_block_ids
except Exception as e:
llm_logger.error(
f"get_storage_cached_blocks process req {req_id}, error: {e}, {str(traceback.format_exc())} "
)
return []
def add_request(self, request: Request) -> None:
with self.lock:
self.apply_async_preprocess(request)
@@ -1043,8 +1002,6 @@ class ResourceManagerV1(ResourceManager):
need_extra_prefill_blocks = need_prealloc_prefill_blocks - request.cache_info[0]
if self.cache_manager.can_allocate_gpu_blocks(need_extra_prefill_blocks):
extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(need_extra_prefill_blocks)
if self.config.cache_config.enable_prefix_caching:
self.get_storage_cached_blocks(request, extra_gpu_block_ids)
request.block_tables.extend(extra_gpu_block_ids)
allocated_position = self.get_available_position()
request.idx = allocated_position
@@ -1143,7 +1100,7 @@ class ResourceManagerV1(ResourceManager):
def _free_blocks(self, request: Request):
if self.config.cache_config.enable_prefix_caching:
self.cache_manager.release_block_ids(request)
self.cache_manager.recycle_gpu_blocks(request.block_tables[request.cached_block_num :])
self.cache_manager.recycle_gpu_blocks(request.block_tables[request.num_cached_blocks :])
else:
self.cache_manager.recycle_gpu_blocks(request.block_tables)
request.block_tables = []