[BugFix] fix cache transfer task failures after cache is cleared (#6202)

* [fix] fix cache transfer tasks failure after cache cleared

* [fix] fix submit_task

* [fix] fix cache manager hang when clearing prefix cache

* [fix] fix list_proxy has no clear method

* [fix] fix barrier

* [fix] add barrier0

* [fix] add cache_task_is_paused_signal

* [fix] fix condition

* [fix] fix cache transfer sync and delay prefix cache tree clearing

* [fix] fix typo

* [chore] polish code

* [fix] revert only rank0 write kv_cache_status_signal

* [fix] fix thread pool and prefix cache manager hang

* [fix] add timeout for task_swapping_event

* [fix] tolerate prefix cache manager error while prefix tree is cleared

* [chore] add more log

* [fix] fix test_prefix_cache_manager

* [fix] fix prefix_cache_status_signal usage
Yonghua Li
2026-02-08 15:33:56 +08:00
committed by GitHub
parent d6b3c722c1
commit 5ac5ecd0b0
7 changed files with 390 additions and 115 deletions
@@ -241,6 +241,25 @@ class CacheTransferManager:
create=False,
)
# NOTE: `cache_task_is_paused_signal` indicates if do_data_transfer thread
# of the FIRST rank (rank#0) has received a pause signal
self.cache_task_is_paused_signal = IPCSignal(
name="cache_task_is_paused",
array=np.zeros([1], dtype=np.int32),
dtype=np.int32,
suffix=args.engine_worker_queue_port,
create=False,
)
# NOTE: `cache_task_inflight_signal[rank]` is 1 while that rank's do_data_transfer
# thread is still pulling tasks, and 0 once it has finished its remaining tasks and paused
self.cache_task_inflight_signal = IPCSignal(
name="cache_task_inflight",
array=np.zeros([self.n_ranks], dtype=np.int32),
dtype=np.int32,
suffix=args.engine_worker_queue_port,
create=False,
)
max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
array_size = min(max_chips_per_node, args.mp_num)
worker_healthy_live_array = np.zeros(shape=[array_size], dtype=np.int32)
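The two signals above split the pause handshake into a control bit and a per-rank acknowledgement: rank#0 publishes its pause state through `cache_task_is_paused_signal`, and every rank reports through its own slot of `cache_task_inflight_signal` whether its transfer loop is still active. A minimal sketch of that read/write pattern, with plain numpy arrays standing in for the IPCSignal shared buffers and illustrative helper names:
import numpy as np

N_RANKS = 4  # assumed rank count for the sketch

# Stand-ins for the shared buffers created above; in the real manager these
# live behind IPCSignal and are visible to every rank's process.
cache_task_is_paused = np.zeros([1], dtype=np.int32)        # written by rank#0 only
cache_task_inflight = np.zeros([N_RANKS], dtype=np.int32)   # one slot per rank

def publish_pause_state(rank: int, is_paused: bool) -> None:
    # Rank#0 mirrors its local pause flag so every rank branches the same way.
    if rank == 0:
        cache_task_is_paused[0] = 1 if is_paused else 0

def report_loop_state(rank: int, active: bool) -> None:
    # 1 while the rank is pulling tasks, 0 once it has drained and paused.
    cache_task_inflight[rank] = 1 if active else 0

def all_ranks_drained() -> bool:
    return int(np.sum(cache_task_inflight)) == 0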
@@ -262,6 +281,9 @@ class CacheTransferManager:
)
threading.Thread(target=self.check_cache_status, args=[args], daemon=True).start()
self.is_paused = False # transfer manager state
self.inflight = 0 # number of inflight transfer tasks
cache_transfer_inited_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
self.cache_transfer_inited_signal = IPCSignal(
name="cache_transfer_inited_signal",
@@ -1022,6 +1044,16 @@ class CacheTransferManager:
return True, ""
def submit_task(self, thread_pool: concurrent.futures.ThreadPoolExecutor, task_fn, *args):
def inflight_task(fn, *args):
try:
return fn(*args)
finally:
self.inflight -= 1
thread_pool.submit(inflight_task, task_fn, *args)
def do_data_transfer(self):
"""
do data transfer task
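submit_task exists so the manager can tell when all outstanding pool work has drained: the counter is incremented in the do_data_transfer loop right before a task is handed to submit_task (next hunk), and decremented in a finally block so the count goes back down even if the task raises. A self-contained sketch of the same pattern, with an explicit lock added because this version is not tied to a single submitting thread (names are illustrative):
import concurrent.futures
import threading
import time

class InflightPool:
    # Thread pool wrapper that tracks how many submitted tasks are still running.

    def __init__(self, max_workers: int = 4):
        self._pool = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
        self._inflight = 0
        self._lock = threading.Lock()

    def submit(self, fn, *args):
        with self._lock:
            self._inflight += 1

        def tracked(fn, *args):
            try:
                return fn(*args)
            finally:
                # Decrement even when the task raises, so drain() never waits forever.
                with self._lock:
                    self._inflight -= 1

        return self._pool.submit(tracked, fn, *args)

    def drain(self, poll: float = 0.1) -> None:
        # Block until every submitted task has finished.
        while True:
            with self._lock:
                if self._inflight == 0:
                    return
            time.sleep(poll)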
@@ -1034,6 +1066,27 @@ class CacheTransferManager:
while True:
try:
if self.rank == 0:
self.cache_task_is_paused_signal.value[0] = 1 if self.is_paused else 0
if self.n_ranks > 1:
self.cache_task_queue.barrier0.wait()
if self.rank == 0:
self.cache_task_queue.barrier0.reset()
# Ensure all ranks take the same branch in lockstep:
# (1) if rank#0 is paused, wait briefly and then re-check rank#0's status;
# (2) otherwise, all ranks are allowed to pull tasks from the cache task queue
if self.cache_task_is_paused_signal.value[0] == 1:
# wait for inflight tasks to finish first
while self.inflight != 0:
time.sleep(0.1)
# mark the current rank as not having inflight tasks
self.cache_task_inflight_signal.value[self.rank] = 0
time.sleep(1)
continue
else:
self.cache_task_inflight_signal.value[self.rank] = 1
if self.rank == 0:
if not self.cache_task_queue.empty():
self.cache_task_broadcast_signal.value[0] = 1
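barrier0 is what keeps the pause check consistent across ranks: every rank reaches the same loop iteration before reading rank#0's published flag, so either all ranks skip the round or all of them go on to pull from the task queue; without it, a straggler could still dequeue work while rank#0 is already paused and clearing caches. A toy sketch of that lockstep gate, using threads for ranks and threading.Barrier (which resets itself, unlike barrier0, which rank#0 resets explicitly):
import threading
import time

N_RANKS = 4
paused = [0]                       # stand-in for cache_task_is_paused_signal
gate = threading.Barrier(N_RANKS)  # stand-in for cache_task_queue.barrier0

def transfer_loop(rank: int, rounds: int = 20):
    for _ in range(rounds):
        gate.wait()            # every rank enters the iteration together
        if paused[0] == 1:     # the flag is read only after the barrier, so all
            time.sleep(0.05)   # ranks either skip this round or all proceed
            continue
        time.sleep(0.01)       # ... pull and dispatch the next transfer task here ...

threads = [threading.Thread(target=transfer_loop, args=(r,)) for r in range(N_RANKS)]
for t in threads:
    t.start()
paused[0] = 1                  # flip the pause flag; all ranks start idling together
for t in threads:
    t.join()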
@@ -1041,7 +1094,9 @@ class CacheTransferManager:
self.cache_task_queue.barrier1.wait()
if self.rank == 0:
self.cache_task_queue.barrier1.reset()
if self.cache_task_broadcast_signal.value[0] == 1:
self.inflight += 1
data, read_finish = self.cache_task_queue.get_transfer_task()
logger.debug(f"do_data_transfer: {data}")
if read_finish:
@@ -1049,7 +1104,8 @@ class CacheTransferManager:
event_type, event_args = data[0], data[1:]
if event_type.value == CacheStatus.SWAP2CPU.value:
transfer_task_id, swap_node_ids, gpu_block_id, cpu_block_id = event_args
self.swap_to_cpu_thread_pool.submit(
self.submit_task(
self.swap_to_cpu_thread_pool,
self._do_swap_to_cpu_task,
swap_node_ids,
gpu_block_id,
@@ -1059,7 +1115,8 @@ class CacheTransferManager:
)
elif event_type.value == CacheStatus.SWAP2GPU.value:
transfer_task_id, swap_node_ids, gpu_block_id, cpu_block_id = event_args
self.swap_to_gpu_thread_pool.submit(
self.submit_task(
self.swap_to_gpu_thread_pool,
self._do_swap_to_gpu_task,
swap_node_ids,
gpu_block_id,
@@ -1069,13 +1126,15 @@ class CacheTransferManager:
)
elif event_type.value == CacheStatus.STORAGE2GPU.value:
read_storage_task = event_args[0]
self.read_storage_thread_pool.submit(
self.submit_task(
self.read_storage_thread_pool,
self.read_storage_task,
read_storage_task,
)
elif event_type.value == CacheStatus.GPU2STORAGE.value:
write_storage_task = event_args[0]
self.write_back_storage_thread_pool.submit(
self.submit_task(
self.write_back_storage_thread_pool,
self.write_back_storage_task,
write_storage_task,
)
@@ -1247,6 +1306,9 @@ class CacheTransferManager:
if self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARING:
assert args.splitwise_role == "mixed", "Only mixed mode supports clearing cache."
try:
# pause the transfer manager: stop pulling new tasks and wait for inflight transfer tasks to finish
self.pause()
# clear cpu caches
logger.info("[RL] start clearing caches")
logger.debug("[RL] start clearing cpu caches")
@@ -1336,15 +1398,40 @@ class CacheTransferManager:
time.sleep(0.1)
logger.info("[RL] all ranks restored caches!")
# resume transfer
self.resume()
# set kv_cache_status_signal
self.kv_cache_status_signal.value[0] = KVCacheStatus.NORMAL
self._log_memory("after restoring caches")
except Exception as e:
logger.error(f"[RL] failed to restore caches: {e}")
time.sleep(0.1)
def pause(self):
if self.n_ranks > 1:
self.cache_task_queue.pause_barrier.wait()
if self.rank == 0:
self.cache_task_queue.pause_barrier.reset()
logger.info("[RL] 🟠 wait for inflight transfer tasks to finish")
self.is_paused = True
while np.sum(self.cache_task_inflight_signal.value) != 0:
time.sleep(0.1)
logger.info("[RL] 🔴 pause transfer manager and stop do transfer tasks")
def resume(self):
if self.n_ranks > 1:
self.cache_task_queue.resume_barrier.wait()
if self.rank == 0:
self.cache_task_queue.resume_barrier.reset()
self.is_paused = False
while np.sum(self.cache_task_inflight_signal.value) != self.n_ranks:
time.sleep(0.1)
logger.info("[RL] 🟢 resume transfer manager and start to do transfer tasks")
def _log_memory(self, context: str):
"""Log current GPU memory usage."""
max_alloc = paddle.device.cuda.max_memory_allocated() / (1024**3)
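Putting the pieces together, the clearing path in check_cache_status now brackets the cache clear/restore with the two methods above: pause() drains every rank before anything is freed, and resume() only returns once every rank reports active again, after which the status signal goes back to NORMAL. A condensed sketch of that ordering (the clear/restore helpers are placeholders, error handling omitted):
def on_kv_cache_clearing(manager):
    if manager.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARING:
        manager.pause()                    # stop pulling tasks and drain inflight work on every rank
        clear_caches()                     # placeholder for the CPU/GPU clearing work
        wait_until_all_ranks_restored()    # placeholder; mirrors the restore loop above
        manager.resume()                   # returns once every rank reports active again
        manager.kv_cache_status_signal.value[0] = KVCacheStatus.NORMAL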