mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[BugFix] fix cache transfer tasks failure after cache cleared (#6202)
* [fix] fix cache transfer tasks failure after cache cleared * [fix] fix submit_task * [fix] fix cache manager hang when clearing prefix cache * [fix] fix list_proxy has no clear method * [fix] fix barrier * [fix] add barrier0 * [fix] add cache_task_is_paused_signal * [fix] fix condition * [fix] fix cache transfer sync and delay prefix cache tree clearing * [fix] fix typo * [chore] polish code * [fix] revert only rank0 write kv_cache_status_signal * [fix] fix thread pool and prefix cache manager hang * [fix] add timeout for task_swapping_event * [fix] tolerate prefix cache manager error while prefix tree is cleared * [chore] add more log * [fix] fix test_prefix_cache_manager * [fix] fix prefix_cache_status_signal usage
This commit is contained in:
@@ -120,6 +120,7 @@ class PrefixCacheManager:
|
||||
self.free_gpu_executor_pool = ThreadPoolExecutor(max_workers=1)
|
||||
self.free_cpu_executor_pool = ThreadPoolExecutor(max_workers=1)
|
||||
self.gpu_free_task_future = None
|
||||
self.cpu_free_future = None
|
||||
self.cache_status_lock = Lock()
|
||||
|
||||
logger.info(
|
||||
@@ -192,6 +193,22 @@ class PrefixCacheManager:
|
||||
create=True,
|
||||
)
|
||||
|
||||
self.cache_task_is_paused_signal = IPCSignal(
|
||||
name="cache_task_is_paused",
|
||||
array=np.zeros([1], dtype=np.int32),
|
||||
dtype=np.int32,
|
||||
suffix=engine_worker_queue_port,
|
||||
create=True,
|
||||
)
|
||||
|
||||
self.cache_task_inflight_signal = IPCSignal(
|
||||
name="cache_task_inflight",
|
||||
array=np.zeros([tensor_parallel_size], dtype=np.int32),
|
||||
dtype=np.int32,
|
||||
suffix=engine_worker_queue_port,
|
||||
create=True,
|
||||
)
|
||||
|
||||
self.cache_task_queue = EngineCacheQueue(
|
||||
address=(pod_ip, cache_config.local_cache_queue_port),
|
||||
authkey=b"cache_queue_service",
|
||||
@@ -547,7 +564,12 @@ class PrefixCacheManager:
|
||||
"""
|
||||
sync swap task
|
||||
"""
|
||||
self.task_swapping_event[transfer_task_id].wait()
|
||||
while True:
|
||||
flag = self.task_swapping_event[transfer_task_id].wait(timeout=0.1)
|
||||
if flag or self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL:
|
||||
if not flag:
|
||||
logger.info(f"swap task timeout because prefix tree status is not normal: {transfer_task_id}")
|
||||
break
|
||||
del self.task_swapping_event[transfer_task_id]
|
||||
|
||||
def _check_validity(self, req_id, match_gpu_blocks_num, expected_block_num):
|
||||
@@ -674,8 +696,13 @@ class PrefixCacheManager:
|
||||
self.req_to_radix_tree_info[req_id] = [leaf_node, can_cache_computed_tokens]
|
||||
task.num_cached_blocks = can_cache_computed_tokens // block_size
|
||||
except Exception as e:
|
||||
logger.error(f"update_cache_blocks, error: {type(e)} {e}, {str(traceback.format_exc())}")
|
||||
raise e
|
||||
if self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL:
|
||||
logger.warning(
|
||||
f"update_cache_blocks: an error occured while prefix tree status is not normal, ignore it. {e}"
|
||||
)
|
||||
else:
|
||||
logger.error(f"update_cache_blocks, error: {type(e)} {e}, {str(traceback.format_exc())}")
|
||||
raise e
|
||||
|
||||
def is_chunked_mm_input(self, mm_inputs, matched_token_num):
|
||||
"""
|
||||
@@ -857,8 +884,13 @@ class PrefixCacheManager:
|
||||
task.num_cached_blocks = len(common_block_ids)
|
||||
return common_block_ids, match_token_num, metrics
|
||||
except Exception as e:
|
||||
logger.error(f"request_match_blocks: request_block_ids: error: {type(e)} {e}")
|
||||
raise e
|
||||
if self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL:
|
||||
logger.warning(
|
||||
f"request_match_blocks: an error occured while prefix tree status is not normal, ignore it. {e}"
|
||||
)
|
||||
else:
|
||||
logger.error(f"request_match_blocks: request_block_ids: error: {type(e)} {e}")
|
||||
raise e
|
||||
|
||||
def request_block_ids(self, task, block_size, dec_token_num, *args):
|
||||
"""
|
||||
@@ -959,8 +991,13 @@ class PrefixCacheManager:
|
||||
)
|
||||
return common_block_ids, unique_block_ids, hit_info
|
||||
except Exception as e:
|
||||
logger.error(f"request_block_ids: error: {type(e)} {e}, {str(traceback.format_exc())}")
|
||||
raise e
|
||||
if self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL:
|
||||
logger.warning(
|
||||
f"request_block_ids: an error occured while prefix tree status is not normal, ignore it. {e}"
|
||||
)
|
||||
else:
|
||||
logger.error(f"request_block_ids: error: {type(e)} {e}, {str(traceback.format_exc())}")
|
||||
raise e
|
||||
|
||||
def release_block_ids_async(self, task):
|
||||
"""
|
||||
@@ -1015,8 +1052,13 @@ class PrefixCacheManager:
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"release_block_ids: error: {type(e)} {e}, {str(traceback.format_exc())}")
|
||||
raise e
|
||||
if self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL:
|
||||
logger.warning(
|
||||
f"release_block_ids: an error occured while prefix tree status is not normal, ignore it. {e}"
|
||||
)
|
||||
else:
|
||||
logger.error(f"release_block_ids: error: {type(e)} {e}, {str(traceback.format_exc())}")
|
||||
raise e
|
||||
|
||||
def write_cache_to_storage(self, request: Request):
|
||||
"""
|
||||
@@ -1140,8 +1182,13 @@ class PrefixCacheManager:
|
||||
else:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"free_nodes_directly: error: {type(e)} {e}")
|
||||
raise e
|
||||
if self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL:
|
||||
logger.warning(
|
||||
f"free_nodes_directly: an error occured while prefix tree status is not normal, ignore it. {e}"
|
||||
)
|
||||
else:
|
||||
logger.error(f"free_nodes_directly: error: {type(e)} {e}")
|
||||
raise e
|
||||
|
||||
def _handle_free_gpu_node_without_cpu(self, node):
|
||||
"""
|
||||
@@ -1312,15 +1359,17 @@ class PrefixCacheManager:
|
||||
|
||||
# swap cache to cpu
|
||||
if hash_value_gpu_block_ids_map:
|
||||
cpu_free_future = None
|
||||
self.cpu_free_future = None
|
||||
if total_gpu_free_count > len(self.cpu_free_block_list):
|
||||
cpu_free_count = total_gpu_free_count
|
||||
if cpu_free_count < need_block_num:
|
||||
cpu_free_count = need_block_num
|
||||
cpu_free_future = self.free_cpu_executor_pool.submit(self.free_cpu_block_ids, cpu_free_count)
|
||||
self.cpu_free_future = self.free_cpu_executor_pool.submit(
|
||||
self.free_cpu_block_ids, cpu_free_count
|
||||
)
|
||||
self.gpu_free_task_future = self.free_gpu_executor_pool.submit(
|
||||
self._evict_cache_async,
|
||||
cpu_free_future,
|
||||
self.cpu_free_future,
|
||||
total_gpu_free_count,
|
||||
hash_value_gpu_block_ids_map,
|
||||
hash_value_block_ids_map,
|
||||
@@ -1331,8 +1380,13 @@ class PrefixCacheManager:
|
||||
else:
|
||||
self.gpu_free_task_future = None
|
||||
except Exception as e:
|
||||
logger.error(f"free_block_ids_async: error: {type(e)} {e}, {str(traceback.format_exc())}")
|
||||
raise e
|
||||
if self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL:
|
||||
logger.warning(
|
||||
f"free_block_ids_async: an error occured while prefix tree status is not normal, ignore it. {e}"
|
||||
)
|
||||
else:
|
||||
logger.error(f"free_block_ids_async: error: {type(e)} {e}, {str(traceback.format_exc())}")
|
||||
raise e
|
||||
|
||||
def free_cpu_block_ids(self, need_block_num):
|
||||
"""
|
||||
@@ -2004,25 +2058,40 @@ class PrefixCacheManager:
|
||||
+ f"task_cpu_block_id {task_cpu_block_id} event_type {event_type} done"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"recv_data_transfer_result: error: {e}, {str(traceback.format_exc())}")
|
||||
raise e
|
||||
if self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL:
|
||||
logger.warning(
|
||||
f"recv_data_transfer_result: an error occured while prefix tree status is not normal, ignore it. {e}"
|
||||
)
|
||||
else:
|
||||
logger.error(f"recv_data_transfer_result: {str(traceback.format_exc())}")
|
||||
raise e
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the RadixTree.
|
||||
"""
|
||||
logger.info(f"wait for cache_task_inflight_signal to reset {self.cache_task_inflight_signal.value}")
|
||||
while np.sum(self.cache_task_inflight_signal.value) != 0:
|
||||
time.sleep(0.1)
|
||||
|
||||
if len(self.node_map) == 0:
|
||||
return
|
||||
logger.info("wait for recv_data_transfer_result done")
|
||||
while not self.cache_task_queue.result_queue_empty():
|
||||
time.sleep(0.1)
|
||||
|
||||
logger.info("Resetting the RadixTree!")
|
||||
logger.info(f"Resetting the RadixTree! node_map len {len(self.node_map)}")
|
||||
|
||||
# wait for swap tasks to finish
|
||||
logger.info("waiting for cpu_free_future to finish")
|
||||
if self.cpu_free_future is not None:
|
||||
self.cpu_free_future.result()
|
||||
self.cpu_free_future = None
|
||||
logger.info("reset cpu_free_future")
|
||||
|
||||
logger.info("waiting for gpu_free_task_future to finish")
|
||||
if self.gpu_free_task_future is not None:
|
||||
self.gpu_free_task_future.result()
|
||||
self.gpu_free_task_future = None
|
||||
for event in list(self.task_swapping_event.values()):
|
||||
event.wait()
|
||||
self.gpu_free_task_future = None
|
||||
logger.info("reset gpu_free_task_future")
|
||||
|
||||
self.task_swapping_event.clear()
|
||||
|
||||
# clear node map
|
||||
|
||||
Reference in New Issue
Block a user