[BugFix] fix cache transfer tasks failure after cache cleared (#6202)

* [fix] fix cache transfer tasks failure after cache cleared

* [fix] fix submit_task

* [fix] fix cache manager hang when clearing prefix cache

* [fix] fix list_proxy has no clear method

* [fix] fix barrier

* [fix] add barrier0

* [fix] add cache_task_is_paused_signal

* [fix] fix condition

* [fix] fix cache transfer  sync and delay prefix cache tree clearing

* [fix] fix typo

* [chore] polish code

* [fix] revert only rank0 write kv_cache_status_signal

* [fix] fix thread pool and prefix cache manager hang

* [fix] add timeout for task_swapping_event

* [fix] tolerate prefix cache manager error while prefix tree is cleared

* [chore] add more log

* [fix] fix test_prefix_cache_manager

* [fix] fix prefix_cache_status_signal usage
This commit is contained in:
Yonghua Li
2026-02-08 15:33:56 +08:00
committed by GitHub
parent d6b3c722c1
commit 5ac5ecd0b0
7 changed files with 390 additions and 115 deletions
+138 -83
View File
@@ -566,57 +566,85 @@ class EngineClient:
2 : worker update finish and notify client
"""
with self.clear_update_lock:
if self.enable_prefix_caching:
skip_action = False
return_code = None
return_body = {}
# model_weights_status_signal: CLEARED -> UPDATING -> NORMAL
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL:
skip_action = True
return_code = 200
return_body = {**self.data_parallel_info, "msg": "model weight is updated"}
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.UPDATING:
skip_action = True
return_code = 400
return_body = {**self.data_parallel_info, "msg": "worker is updating model weight already"}
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARING:
skip_action = True
return_code = 403
return_body = {**self.data_parallel_info, "msg": "worker is clearing model weight, cannot update now"}
if not skip_action:
self.model_weights_status_signal.value[0] = ModelWeightsStatus.UPDATING
api_server_logger.info(
f"[RL] >>> start updating model weight (weight status: {self.model_weights_status_signal.value[0]})"
if not self.enable_cache_transfer
else f"[RL] >>> start updating model weight (weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]})"
)
while timeout >= 0:
api_server_logger.info(
f"[RL] ... weight status: {self.model_weights_status_signal.value[0]}"
if not self.enable_cache_transfer
else f"[RL] ... weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]}"
)
weight_updated = self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL
cache_updated = self.kv_cache_status_signal.value[0] == KVCacheStatus.NORMAL
if weight_updated and (not self.enable_cache_transfer or cache_updated):
break
time.sleep(1)
timeout -= 1
if timeout < 0:
return_code = 404
return_body = {**self.data_parallel_info, "msg": "update model weight timeout"}
else:
api_server_logger.info(
f"[RL] <<< finish updating model weight (weight status: {self.model_weights_status_signal.value[0]})"
if not self.enable_cache_transfer
else f"[RL] <<< finish updating model weight (weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]})"
)
else:
api_server_logger.info(
f"[RL] !!! skip updating model weight for the following reason: {return_body.get('msg')}"
)
if timeout >= 0 and self.enable_prefix_caching:
# prefix_tree_status_signal: CLEARED -> UPDATING -> NORMAL
if self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.CLEARED:
self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.UPDATING
api_server_logger.info(
f">>> start updating prefix tree (status: {self.prefix_tree_status_signal.value[0]})"
f"[RL] >>> start updating prefix tree (status: {self.prefix_tree_status_signal.value[0]})"
)
while timeout >= 0 and self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL:
api_server_logger.info(f"... prefix tree status: {self.prefix_tree_status_signal.value[0]}")
api_server_logger.info(
f"[RL] ... prefix tree status: {self.prefix_tree_status_signal.value[0]}"
)
time.sleep(1)
timeout -= 1
if timeout < 0:
return 404, {**self.data_parallel_info, "msg": "update prefix tree timeout"}
api_server_logger.info(
f"<<< finish updating prefix tree (status: {self.prefix_tree_status_signal.value[0]})"
)
return_code = 404
return_body = {**self.data_parallel_info, "msg": "update prefix tree timeout"}
else:
api_server_logger.info(
f"[RL] <<< finish updating prefix tree (status: {self.prefix_tree_status_signal.value[0]})"
)
# model_weights_status_signal: CLEARED -> UPDATING -> NORMAL
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL:
return 200, {**self.data_parallel_info, "msg": "model weight is updated"}
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.UPDATING:
return 400, {**self.data_parallel_info, "msg": "worker is updating model weight already"}
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARING:
return 403, {**self.data_parallel_info, "msg": "worker is clearing model weight, cannot update now"}
self.model_weights_status_signal.value[0] = ModelWeightsStatus.UPDATING
api_server_logger.info(
f">>> start updating model weight (weight status: {self.model_weights_status_signal.value[0]})"
if not self.enable_cache_transfer
else f">>> start updating model weight (weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]})"
)
while timeout >= 0:
api_server_logger.info(
f"... weight status: {self.model_weights_status_signal.value[0]}"
if not self.enable_cache_transfer
else f"... weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]}"
)
weight_updated = self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL
cache_updated = self.kv_cache_status_signal.value[0] == KVCacheStatus.NORMAL
if weight_updated and (not self.enable_cache_transfer or cache_updated):
break
time.sleep(1)
timeout -= 1
if timeout < 0:
return 404, {**self.data_parallel_info, "msg": "update model weight timeout"}
api_server_logger.info(
f"<<< finish updating model weight (weight status: {self.model_weights_status_signal.value[0]})"
if not self.enable_cache_transfer
else f"<<< finish updating model weight (weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]})"
)
return 200, {**self.data_parallel_info, "msg": "update model weight successfully"}
if return_code:
if return_code == 404:
api_server_logger.error("[RL] ??? updating model weight time out")
return return_code, return_body
else:
return 200, {**self.data_parallel_info, "msg": "update model weight successfully"}
def clear_load_weight(self, timeout=300):
"""
@@ -626,57 +654,84 @@ class EngineClient:
"""
with self.clear_update_lock:
if self.enable_prefix_caching:
skip_action = False
return_code = None
return_body = {}
# model_weights_status_signal: NORMAL -> CLEARING -> CLEARED
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED:
skip_action = True
return_code = 200
return_body = {**self.data_parallel_info, "msg": "model weight is cleared"}
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARING:
skip_action = True
return_code = 400
return_body = {**self.data_parallel_info, "msg": "worker is clearing model weight already"}
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.UPDATING:
skip_action = True
return_code = 403
return_body = {**self.data_parallel_info, "msg": "worker is updating model weight, cannot clear now"}
if not skip_action:
self.model_weights_status_signal.value[0] = ModelWeightsStatus.CLEARING
api_server_logger.info(
f"[RL] >>> start clearing model weight (weight status: {self.model_weights_status_signal.value[0]}"
if not self.enable_cache_transfer
else f"[RL] >>> start clearing model weight (weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]})"
)
while timeout >= 0:
api_server_logger.info(
f"[RL] ... weight status: {self.model_weights_status_signal.value[0]}"
if not self.enable_cache_transfer
else f"[RL] ... weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]}"
)
weight_cleared = self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED
cache_cleared = self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARED
if weight_cleared and (not self.enable_cache_transfer or cache_cleared):
break
time.sleep(1)
timeout -= 1
if timeout < 0:
return_code = 404
return_body = {**self.data_parallel_info, "msg": "clear model weight timeout"}
else:
api_server_logger.info(
f"[RL] <<< finish clearing model weight (weight status: {self.model_weights_status_signal.value[0]})"
if not self.enable_cache_transfer
else f"[RL] <<< finish clearing model weight (weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]})"
)
else:
api_server_logger.info(
f"[RL] !!! skip clearing model weight for the following reason: {return_body.get('msg')}"
)
if timeout >= 0 and self.enable_prefix_caching:
# prefix_tree_status_signal: NORMAL -> CLEARING -> CLEARED
if self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.NORMAL:
self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.CLEARING
api_server_logger.info(
f">>> start clearing prefix tree (status: {self.prefix_tree_status_signal.value[0]})"
f"[RL] >>> start clearing prefix tree (status: {self.prefix_tree_status_signal.value[0]})"
)
while timeout >= 0 and self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.CLEARED:
api_server_logger.info(f"... prefix tree status: {self.prefix_tree_status_signal.value[0]}")
api_server_logger.info(
f"[RL] ... prefix tree status: {self.prefix_tree_status_signal.value[0]}"
)
time.sleep(1)
timeout -= 1
if timeout < 0:
return 404, {**self.data_parallel_info, "msg": "clear prefix tree timeout"}
api_server_logger.info(
f"<<< finish clearing prefix tree (status: {self.prefix_tree_status_signal.value[0]})"
)
# model_weights_status_signal: NORMAL -> CLEARING -> CLEARED
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED:
return 200, {**self.data_parallel_info, "msg": "model weight is cleared"}
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARING:
return 400, {**self.data_parallel_info, "msg": "worker is clearing model weight already"}
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.UPDATING:
return 403, {**self.data_parallel_info, "msg": "worker is updating model weight, cannot clear now"}
self.model_weights_status_signal.value[0] = ModelWeightsStatus.CLEARING
api_server_logger.info(
f">>> start clearing model weight (weight status: {self.model_weights_status_signal.value[0]}"
if not self.enable_cache_transfer
else f">>> start clearing model weight (weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]})"
)
while timeout >= 0:
api_server_logger.info(
f"... weight status: {self.model_weights_status_signal.value[0]}"
if not self.enable_cache_transfer
else f"... weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]}"
)
weight_cleared = self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED
cache_cleared = self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARED
if weight_cleared and (not self.enable_cache_transfer or cache_cleared):
break
time.sleep(1)
timeout -= 1
if timeout < 0:
return 404, {**self.data_parallel_info, "msg": "clear model weight timeout"}
api_server_logger.info(
f"<<< finish clearing model weight (weight status: {self.model_weights_status_signal.value[0]})"
if not self.enable_cache_transfer
else f"<<< finish clearing model weight (weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]})"
)
return 200, {**self.data_parallel_info, "msg": "clear model weight successfully"}
return_code = 404
return_body = {**self.data_parallel_info, "msg": "clear prefix tree timeout"}
else:
api_server_logger.info(
f"[RL] <<< finish clearing prefix tree (status: {self.prefix_tree_status_signal.value[0]})"
)
if return_code:
if return_code == 404:
api_server_logger.error("[RL] ??? clearing model weight time out")
return return_code, return_body
else:
return 200, {**self.data_parallel_info, "msg": "clear model weight successfully"}
def check_model_weight_status(self):
return self.model_weights_status_signal.value[0] < 0