[BugFix] fix cache transfer tasks failure after cache cleared (#6202)

* [fix] fix cache transfer tasks failure after cache cleared

* [fix] fix submit_task

* [fix] fix cache manager hang when clearing prefix cache

* [fix] fix list_proxy has no clear method

* [fix] fix barrier

* [fix] add barrier0

* [fix] add cache_task_is_paused_signal

* [fix] fix condition

* [fix] fix cache transfer  sync and delay prefix cache tree clearing

* [fix] fix typo

* [chore] polish code

* [fix] revert only rank0 write kv_cache_status_signal

* [fix] fix thread pool and prefix cache manager hang

* [fix] add timeout for task_swapping_event

* [fix] tolerate prefix cache manager error while prefix tree is cleared

* [chore] add more log

* [fix] fix test_prefix_cache_manager

* [fix] fix prefix_cache_status_signal usage
This commit is contained in:
Yonghua Li
2026-02-08 15:33:56 +08:00
committed by GitHub
parent d6b3c722c1
commit 5ac5ecd0b0
7 changed files with 390 additions and 115 deletions
@@ -120,6 +120,7 @@ class PrefixCacheManager:
self.free_gpu_executor_pool = ThreadPoolExecutor(max_workers=1)
self.free_cpu_executor_pool = ThreadPoolExecutor(max_workers=1)
self.gpu_free_task_future = None
self.cpu_free_future = None
self.cache_status_lock = Lock()
logger.info(
@@ -192,6 +193,22 @@ class PrefixCacheManager:
create=True,
)
self.cache_task_is_paused_signal = IPCSignal(
name="cache_task_is_paused",
array=np.zeros([1], dtype=np.int32),
dtype=np.int32,
suffix=engine_worker_queue_port,
create=True,
)
self.cache_task_inflight_signal = IPCSignal(
name="cache_task_inflight",
array=np.zeros([tensor_parallel_size], dtype=np.int32),
dtype=np.int32,
suffix=engine_worker_queue_port,
create=True,
)
self.cache_task_queue = EngineCacheQueue(
address=(pod_ip, cache_config.local_cache_queue_port),
authkey=b"cache_queue_service",
@@ -547,7 +564,12 @@ class PrefixCacheManager:
"""
sync swap task
"""
self.task_swapping_event[transfer_task_id].wait()
while True:
flag = self.task_swapping_event[transfer_task_id].wait(timeout=0.1)
if flag or self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL:
if not flag:
logger.info(f"swap task timeout because prefix tree status is not normal: {transfer_task_id}")
break
del self.task_swapping_event[transfer_task_id]
def _check_validity(self, req_id, match_gpu_blocks_num, expected_block_num):
@@ -674,8 +696,13 @@ class PrefixCacheManager:
self.req_to_radix_tree_info[req_id] = [leaf_node, can_cache_computed_tokens]
task.num_cached_blocks = can_cache_computed_tokens // block_size
except Exception as e:
logger.error(f"update_cache_blocks, error: {type(e)} {e}, {str(traceback.format_exc())}")
raise e
if self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL:
logger.warning(
f"update_cache_blocks: an error occured while prefix tree status is not normal, ignore it. {e}"
)
else:
logger.error(f"update_cache_blocks, error: {type(e)} {e}, {str(traceback.format_exc())}")
raise e
def is_chunked_mm_input(self, mm_inputs, matched_token_num):
"""
@@ -857,8 +884,13 @@ class PrefixCacheManager:
task.num_cached_blocks = len(common_block_ids)
return common_block_ids, match_token_num, metrics
except Exception as e:
logger.error(f"request_match_blocks: request_block_ids: error: {type(e)} {e}")
raise e
if self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL:
logger.warning(
f"request_match_blocks: an error occured while prefix tree status is not normal, ignore it. {e}"
)
else:
logger.error(f"request_match_blocks: request_block_ids: error: {type(e)} {e}")
raise e
def request_block_ids(self, task, block_size, dec_token_num, *args):
"""
@@ -959,8 +991,13 @@ class PrefixCacheManager:
)
return common_block_ids, unique_block_ids, hit_info
except Exception as e:
logger.error(f"request_block_ids: error: {type(e)} {e}, {str(traceback.format_exc())}")
raise e
if self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL:
logger.warning(
f"request_block_ids: an error occured while prefix tree status is not normal, ignore it. {e}"
)
else:
logger.error(f"request_block_ids: error: {type(e)} {e}, {str(traceback.format_exc())}")
raise e
def release_block_ids_async(self, task):
"""
@@ -1015,8 +1052,13 @@ class PrefixCacheManager:
)
return
except Exception as e:
logger.error(f"release_block_ids: error: {type(e)} {e}, {str(traceback.format_exc())}")
raise e
if self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL:
logger.warning(
f"release_block_ids: an error occured while prefix tree status is not normal, ignore it. {e}"
)
else:
logger.error(f"release_block_ids: error: {type(e)} {e}, {str(traceback.format_exc())}")
raise e
def write_cache_to_storage(self, request: Request):
"""
@@ -1140,8 +1182,13 @@ class PrefixCacheManager:
else:
break
except Exception as e:
logger.error(f"free_nodes_directly: error: {type(e)} {e}")
raise e
if self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL:
logger.warning(
f"free_nodes_directly: an error occured while prefix tree status is not normal, ignore it. {e}"
)
else:
logger.error(f"free_nodes_directly: error: {type(e)} {e}")
raise e
def _handle_free_gpu_node_without_cpu(self, node):
"""
@@ -1312,15 +1359,17 @@ class PrefixCacheManager:
# swap cache to cpu
if hash_value_gpu_block_ids_map:
cpu_free_future = None
self.cpu_free_future = None
if total_gpu_free_count > len(self.cpu_free_block_list):
cpu_free_count = total_gpu_free_count
if cpu_free_count < need_block_num:
cpu_free_count = need_block_num
cpu_free_future = self.free_cpu_executor_pool.submit(self.free_cpu_block_ids, cpu_free_count)
self.cpu_free_future = self.free_cpu_executor_pool.submit(
self.free_cpu_block_ids, cpu_free_count
)
self.gpu_free_task_future = self.free_gpu_executor_pool.submit(
self._evict_cache_async,
cpu_free_future,
self.cpu_free_future,
total_gpu_free_count,
hash_value_gpu_block_ids_map,
hash_value_block_ids_map,
@@ -1331,8 +1380,13 @@ class PrefixCacheManager:
else:
self.gpu_free_task_future = None
except Exception as e:
logger.error(f"free_block_ids_async: error: {type(e)} {e}, {str(traceback.format_exc())}")
raise e
if self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL:
logger.warning(
f"free_block_ids_async: an error occured while prefix tree status is not normal, ignore it. {e}"
)
else:
logger.error(f"free_block_ids_async: error: {type(e)} {e}, {str(traceback.format_exc())}")
raise e
def free_cpu_block_ids(self, need_block_num):
"""
@@ -2004,25 +2058,40 @@ class PrefixCacheManager:
+ f"task_cpu_block_id {task_cpu_block_id} event_type {event_type} done"
)
except Exception as e:
logger.warning(f"recv_data_transfer_result: error: {e}, {str(traceback.format_exc())}")
raise e
if self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL:
logger.warning(
f"recv_data_transfer_result: an error occured while prefix tree status is not normal, ignore it. {e}"
)
else:
logger.error(f"recv_data_transfer_result: {str(traceback.format_exc())}")
raise e
def reset(self):
"""
Reset the RadixTree.
"""
logger.info(f"wait for cache_task_inflight_signal to reset {self.cache_task_inflight_signal.value}")
while np.sum(self.cache_task_inflight_signal.value) != 0:
time.sleep(0.1)
if len(self.node_map) == 0:
return
logger.info("wait for recv_data_transfer_result done")
while not self.cache_task_queue.result_queue_empty():
time.sleep(0.1)
logger.info("Resetting the RadixTree!")
logger.info(f"Resetting the RadixTree! node_map len {len(self.node_map)}")
# wait for swap tasks to finish
logger.info("waiting for cpu_free_future to finish")
if self.cpu_free_future is not None:
self.cpu_free_future.result()
self.cpu_free_future = None
logger.info("reset cpu_free_future")
logger.info("waiting for gpu_free_task_future to finish")
if self.gpu_free_task_future is not None:
self.gpu_free_task_future.result()
self.gpu_free_task_future = None
for event in list(self.task_swapping_event.values()):
event.wait()
self.gpu_free_task_future = None
logger.info("reset gpu_free_task_future")
self.task_swapping_event.clear()
# clear node map