mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[BugFix] fix cache transfer tasks failure after cache cleared (#6202)
* [fix] fix cache transfer tasks failure after cache cleared * [fix] fix submit_task * [fix] fix cache manager hang when clearing prefix cache * [fix] fix list_proxy has no clear method * [fix] fix barrier * [fix] add barrier0 * [fix] add cache_task_is_paused_signal * [fix] fix condition * [fix] fix cache transfer sync and delay prefix cache tree clearing * [fix] fix typo * [chore] polish code * [fix] revert only rank0 write kv_cache_status_signal * [fix] fix thread pool and prefix cache manager hang * [fix] add timeout for task_swapping_event * [fix] tolerate prefix cache manager error while prefix tree is cleared * [chore] add more log * [fix] fix test_prefix_cache_manager * [fix] fix prefix_cache_status_signal usage
This commit is contained in:
@@ -241,6 +241,25 @@ class CacheTransferManager:
|
||||
create=False,
|
||||
)
|
||||
|
||||
# NOTE: `cache_task_is_paused_signal` indicates if do_data_transfer thread
|
||||
# of the FIRST rank (rank#0) has received a pause signal
|
||||
self.cache_task_is_paused_signal = IPCSignal(
|
||||
name="cache_task_is_paused",
|
||||
array=np.zeros([1], dtype=np.int32),
|
||||
dtype=np.int32,
|
||||
suffix=args.engine_worker_queue_port,
|
||||
create=False,
|
||||
)
|
||||
# NOTE: `cache_task_inflight_signal` indicates if do_data_transfer thread
|
||||
# of each rank has finished remaining tasks and finally paused
|
||||
self.cache_task_inflight_signal = IPCSignal(
|
||||
name="cache_task_inflight",
|
||||
array=np.zeros([self.n_ranks], dtype=np.int32),
|
||||
dtype=np.int32,
|
||||
suffix=args.engine_worker_queue_port,
|
||||
create=False,
|
||||
)
|
||||
|
||||
max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
|
||||
array_size = min(max_chips_per_node, args.mp_num)
|
||||
worker_healthy_live_array = np.zeros(shape=[array_size], dtype=np.int32)
|
||||
@@ -262,6 +281,9 @@ class CacheTransferManager:
|
||||
)
|
||||
threading.Thread(target=self.check_cache_status, args=[args], daemon=True).start()
|
||||
|
||||
self.is_paused = False # transfer manager state
|
||||
self.inflight = 0 # number of inflight transfer tasks
|
||||
|
||||
cache_transfer_inited_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
|
||||
self.cache_transfer_inited_signal = IPCSignal(
|
||||
name="cache_transfer_inited_signal",
|
||||
@@ -1022,6 +1044,16 @@ class CacheTransferManager:
|
||||
|
||||
return True, ""
|
||||
|
||||
def submit_task(self, thread_pool: concurrent.futures.ThreadPoolExecutor, task_fn, *args):
|
||||
|
||||
def inflight_task(fn, *args):
|
||||
try:
|
||||
return fn(*args)
|
||||
finally:
|
||||
self.inflight -= 1
|
||||
|
||||
thread_pool.submit(inflight_task, task_fn, *args)
|
||||
|
||||
def do_data_transfer(self):
|
||||
"""
|
||||
do data transfer task
|
||||
@@ -1034,6 +1066,27 @@ class CacheTransferManager:
|
||||
|
||||
while True:
|
||||
try:
|
||||
if self.rank == 0:
|
||||
self.cache_task_is_paused_signal.value[0] = 1 if self.is_paused else 0
|
||||
if self.n_ranks > 1:
|
||||
self.cache_task_queue.barrier0.wait()
|
||||
if self.rank == 0:
|
||||
self.cache_task_queue.barrier0.reset()
|
||||
|
||||
# Ensure all ranks synchronically do one of the following things:
|
||||
# (1) If rank#0 is paused, wait for a short time and check out rank#0 status again;
|
||||
# (2) otherwise, all ranks are allowed to pull tasks from cache task queue
|
||||
if self.cache_task_is_paused_signal.value[0] == 1:
|
||||
# wait for inflight tasks to finish first
|
||||
while self.inflight != 0:
|
||||
time.sleep(0.1)
|
||||
# mark the current rank as not having inflight tasks
|
||||
self.cache_task_inflight_signal.value[self.rank] = 0
|
||||
time.sleep(1)
|
||||
continue
|
||||
else:
|
||||
self.cache_task_inflight_signal.value[self.rank] = 1
|
||||
|
||||
if self.rank == 0:
|
||||
if not self.cache_task_queue.empty():
|
||||
self.cache_task_broadcast_signal.value[0] = 1
|
||||
@@ -1041,7 +1094,9 @@ class CacheTransferManager:
|
||||
self.cache_task_queue.barrier1.wait()
|
||||
if self.rank == 0:
|
||||
self.cache_task_queue.barrier1.reset()
|
||||
|
||||
if self.cache_task_broadcast_signal.value[0] == 1:
|
||||
self.inflight += 1
|
||||
data, read_finish = self.cache_task_queue.get_transfer_task()
|
||||
logger.debug(f"do_data_transfer: {data}")
|
||||
if read_finish:
|
||||
@@ -1049,7 +1104,8 @@ class CacheTransferManager:
|
||||
event_type, event_args = data[0], data[1:]
|
||||
if event_type.value == CacheStatus.SWAP2CPU.value:
|
||||
transfer_task_id, swap_node_ids, gpu_block_id, cpu_block_id = event_args
|
||||
self.swap_to_cpu_thread_pool.submit(
|
||||
self.submit_task(
|
||||
self.swap_to_cpu_thread_pool,
|
||||
self._do_swap_to_cpu_task,
|
||||
swap_node_ids,
|
||||
gpu_block_id,
|
||||
@@ -1059,7 +1115,8 @@ class CacheTransferManager:
|
||||
)
|
||||
elif event_type.value == CacheStatus.SWAP2GPU.value:
|
||||
transfer_task_id, swap_node_ids, gpu_block_id, cpu_block_id = event_args
|
||||
self.swap_to_gpu_thread_pool.submit(
|
||||
self.submit_task(
|
||||
self.swap_to_gpu_thread_pool,
|
||||
self._do_swap_to_gpu_task,
|
||||
swap_node_ids,
|
||||
gpu_block_id,
|
||||
@@ -1069,13 +1126,15 @@ class CacheTransferManager:
|
||||
)
|
||||
elif event_type.value == CacheStatus.STORAGE2GPU.value:
|
||||
read_storage_task = event_args[0]
|
||||
self.read_storage_thread_pool.submit(
|
||||
self.submit_task(
|
||||
self.read_storage_thread_pool,
|
||||
self.read_storage_task,
|
||||
read_storage_task,
|
||||
)
|
||||
elif event_type.value == CacheStatus.GPU2STORAGE.value:
|
||||
write_storage_task = event_args[0]
|
||||
self.write_back_storage_thread_pool.submit(
|
||||
self.submit_task(
|
||||
self.write_back_storage_thread_pool,
|
||||
self.write_back_storage_task,
|
||||
write_storage_task,
|
||||
)
|
||||
@@ -1247,6 +1306,9 @@ class CacheTransferManager:
|
||||
if self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARING:
|
||||
assert args.splitwise_role == "mixed", "Only mixed mode supports clearing cache."
|
||||
try:
|
||||
# wait for inflight transfer tasks to finish and pause transfer manager
|
||||
self.pause()
|
||||
|
||||
# clear cpu caches
|
||||
logger.info("[RL] start clearing caches")
|
||||
logger.debug("[RL] start clearing cpu caches")
|
||||
@@ -1336,15 +1398,40 @@ class CacheTransferManager:
|
||||
time.sleep(0.1)
|
||||
logger.info("[RL] all ranks restored caches!")
|
||||
|
||||
# resume transfer
|
||||
self.resume()
|
||||
|
||||
# set kv_cache_status_signal
|
||||
self.kv_cache_status_signal.value[0] = KVCacheStatus.NORMAL
|
||||
|
||||
self._log_memory("after restoring caches")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[RL] failed to restore caches: {e}")
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
def pause(self):
|
||||
if self.n_ranks > 1:
|
||||
self.cache_task_queue.pause_barrier.wait()
|
||||
if self.rank == 0:
|
||||
self.cache_task_queue.pause_barrier.reset()
|
||||
logger.info("[RL] 🟠 wait for inflight transfer tasks to finish")
|
||||
self.is_paused = True
|
||||
while np.sum(self.cache_task_inflight_signal.value) != 0:
|
||||
time.sleep(0.1)
|
||||
logger.info("[RL] 🔴 pause transfer manager and stop do transfer tasks")
|
||||
|
||||
def resume(self):
|
||||
if self.n_ranks > 1:
|
||||
self.cache_task_queue.resume_barrier.wait()
|
||||
if self.rank == 0:
|
||||
self.cache_task_queue.resume_barrier.reset()
|
||||
self.is_paused = False
|
||||
while np.sum(self.cache_task_inflight_signal.value) != self.n_ranks:
|
||||
time.sleep(0.1)
|
||||
logger.info("[RL] 🟢 resume transfer manager and start to do transfer tasks")
|
||||
|
||||
def _log_memory(self, context: str):
|
||||
"""Log current GPU memory usage."""
|
||||
max_alloc = paddle.device.cuda.max_memory_allocated() / (1024**3)
|
||||
|
||||
Reference in New Issue
Block a user