[BugFix] fix cache transfer tasks failure after cache cleared (#6202)

* [fix] fix cache transfer tasks failure after cache cleared

* [fix] fix submit_task

* [fix] fix cache manager hang when clearing prefix cache

* [fix] fix list_proxy has no clear method

* [fix] fix barrier

* [fix] add barrier0

* [fix] add cache_task_is_paused_signal

* [fix] fix condition

* [fix] fix cache transfer  sync and delay prefix cache tree clearing

* [fix] fix typo

* [chore] polish code

* [fix] revert only rank0 write kv_cache_status_signal

* [fix] fix thread pool and prefix cache manager hang

* [fix] add timeout for task_swapping_event

* [fix] tolerate prefix cache manager error while prefix tree is cleared

* [chore] add more log

* [fix] fix test_prefix_cache_manager

* [fix] fix prefix_cache_status_signal usage
This commit is contained in:
Yonghua Li
2026-02-08 15:33:56 +08:00
committed by GitHub
parent d6b3c722c1
commit 5ac5ecd0b0
7 changed files with 390 additions and 115 deletions
@@ -87,9 +87,16 @@ class EngineCacheQueue:
]
# Initialize barriers
self.barrier0_init = [threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)]
self.barrier1_init = [threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)]
self.barrier2_init = [threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)]
self.barrier3_init = [threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)]
self.pause_barrier_init = [
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
]
self.resume_barrier_init = [
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
]
self.swap_to_cpu_barrier1_init = [
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
]
@@ -135,9 +142,12 @@ class EngineCacheQueue:
callable=lambda idx: self.transfer_task_done_lock_init[idx],
proxytype=AcquirerProxy,
)
QueueManager.register("get_barrier0", callable=lambda idx: self.barrier0_init[idx])
QueueManager.register("get_barrier1", callable=lambda idx: self.barrier1_init[idx])
QueueManager.register("get_barrier2", callable=lambda idx: self.barrier2_init[idx])
QueueManager.register("get_barrier3", callable=lambda idx: self.barrier3_init[idx])
QueueManager.register("get_pause_barrier", callable=lambda idx: self.pause_barrier_init[idx])
QueueManager.register("get_resume_barrier", callable=lambda idx: self.resume_barrier_init[idx])
QueueManager.register(
"get_swap_to_cpu_barrier1",
callable=lambda idx: self.swap_to_cpu_barrier1_init[idx],
@@ -181,9 +191,12 @@ class EngineCacheQueue:
QueueManager.register("get_cache_sync_value")
QueueManager.register("get_transfer_task_lock")
QueueManager.register("get_transfer_task_done_lock")
QueueManager.register("get_barrier0")
QueueManager.register("get_barrier1")
QueueManager.register("get_barrier2")
QueueManager.register("get_barrier3")
QueueManager.register("get_pause_barrier")
QueueManager.register("get_resume_barrier")
QueueManager.register("get_swap_to_cpu_barrier1")
QueueManager.register("get_swap_to_cpu_barrier2")
QueueManager.register("get_swap_to_gpu_barrier1")
@@ -202,9 +215,12 @@ class EngineCacheQueue:
self.task_done_lock = self.manager.get_transfer_task_done_lock(self.local_data_parallel_id)
# Get barrier proxies
self.barrier0 = self.manager.get_barrier0(self.local_data_parallel_id)
self.barrier1 = self.manager.get_barrier1(self.local_data_parallel_id)
self.barrier2 = self.manager.get_barrier2(self.local_data_parallel_id)
self.barrier3 = self.manager.get_barrier3(self.local_data_parallel_id)
self.pause_barrier = self.manager.get_pause_barrier(self.local_data_parallel_id)
self.resume_barrier = self.manager.get_resume_barrier(self.local_data_parallel_id)
self.swap_to_cpu_barrier1 = self.manager.get_swap_to_cpu_barrier1(self.local_data_parallel_id)
self.swap_to_cpu_barrier2 = self.manager.get_swap_to_cpu_barrier2(self.local_data_parallel_id)
self.swap_to_gpu_barrier1 = self.manager.get_swap_to_gpu_barrier1(self.local_data_parallel_id)
@@ -281,6 +297,19 @@ class EngineCacheQueue:
self.task_lock.release()
return data, read_finish
def clear_transfer_task(self):
self.task_lock.acquire()
if 0 < self.task_sync_value.get() < self.total_num:
self.task_lock.release()
while 0 < self.task_sync_value.get() < self.total_num:
time.sleep(0.001)
self.task_lock.acquire()
self.task_sync_value.set(0)
while len(self.transfer_task_queue) > 0:
self.transfer_task_queue.pop(0)
logger.info("clear_transfer_task: done")
self.task_lock.release()
def put_transfer_done_signal(self, item):
"""
put swap result
@@ -311,3 +340,13 @@ class EngineCacheQueue:
except Exception as e:
logger.error(f"empty function meets error: {e}, {str(traceback.format_exc())}")
raise e
def result_queue_empty(self):
"""
check if result queue is empty
"""
try:
return len(self.tansfer_done_queue) == 0
except Exception as e:
logger.error(f"result_queue_empty function meets error: {e}, {str(traceback.format_exc())}")
raise e