mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[BugFix] fix cache transfer tasks failure after cache cleared (#6202)
* [fix] fix cache transfer tasks failure after cache cleared * [fix] fix submit_task * [fix] fix cache manager hang when clearing prefix cache * [fix] fix list_proxy has no clear method * [fix] fix barrier * [fix] add barrier0 * [fix] add cache_task_is_paused_signal * [fix] fix condition * [fix] fix cache transfer sync and delay prefix cache tree clearing * [fix] fix typo * [chore] polish code * [fix] revert only rank0 write kv_cache_status_signal * [fix] fix thread pool and prefix cache manager hang * [fix] add timeout for task_swapping_event * [fix] tolerate prefix cache manager error while prefix tree is cleared * [chore] add more log * [fix] fix test_prefix_cache_manager * [fix] fix prefix_cache_status_signal usage
This commit is contained in:
@@ -566,57 +566,85 @@ class EngineClient:
|
||||
2 : worker update finish and notify client
|
||||
"""
|
||||
with self.clear_update_lock:
|
||||
if self.enable_prefix_caching:
|
||||
|
||||
skip_action = False
|
||||
return_code = None
|
||||
return_body = {}
|
||||
|
||||
# model_weights_status_signal: CLEARED -> UPDATING -> NORMAL
|
||||
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL:
|
||||
skip_action = True
|
||||
return_code = 200
|
||||
return_body = {**self.data_parallel_info, "msg": "model weight is updated"}
|
||||
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.UPDATING:
|
||||
skip_action = True
|
||||
return_code = 400
|
||||
return_body = {**self.data_parallel_info, "msg": "worker is updating model weight already"}
|
||||
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARING:
|
||||
skip_action = True
|
||||
return_code = 403
|
||||
return_body = {**self.data_parallel_info, "msg": "worker is clearing model weight, cannot update now"}
|
||||
|
||||
if not skip_action:
|
||||
self.model_weights_status_signal.value[0] = ModelWeightsStatus.UPDATING
|
||||
api_server_logger.info(
|
||||
f"[RL] >>> start updating model weight (weight status: {self.model_weights_status_signal.value[0]})"
|
||||
if not self.enable_cache_transfer
|
||||
else f"[RL] >>> start updating model weight (weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]})"
|
||||
)
|
||||
while timeout >= 0:
|
||||
api_server_logger.info(
|
||||
f"[RL] ... weight status: {self.model_weights_status_signal.value[0]}"
|
||||
if not self.enable_cache_transfer
|
||||
else f"[RL] ... weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]}"
|
||||
)
|
||||
weight_updated = self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL
|
||||
cache_updated = self.kv_cache_status_signal.value[0] == KVCacheStatus.NORMAL
|
||||
if weight_updated and (not self.enable_cache_transfer or cache_updated):
|
||||
break
|
||||
time.sleep(1)
|
||||
timeout -= 1
|
||||
if timeout < 0:
|
||||
return_code = 404
|
||||
return_body = {**self.data_parallel_info, "msg": "update model weight timeout"}
|
||||
else:
|
||||
api_server_logger.info(
|
||||
f"[RL] <<< finish updating model weight (weight status: {self.model_weights_status_signal.value[0]})"
|
||||
if not self.enable_cache_transfer
|
||||
else f"[RL] <<< finish updating model weight (weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]})"
|
||||
)
|
||||
else:
|
||||
api_server_logger.info(
|
||||
f"[RL] !!! skip updating model weight for the following reason: {return_body.get('msg')}"
|
||||
)
|
||||
|
||||
if timeout >= 0 and self.enable_prefix_caching:
|
||||
# prefix_tree_status_signal: CLEARED -> UPDATING -> NORMAL
|
||||
if self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.CLEARED:
|
||||
self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.UPDATING
|
||||
api_server_logger.info(
|
||||
f">>> start updating prefix tree (status: {self.prefix_tree_status_signal.value[0]})"
|
||||
f"[RL] >>> start updating prefix tree (status: {self.prefix_tree_status_signal.value[0]})"
|
||||
)
|
||||
while timeout >= 0 and self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL:
|
||||
api_server_logger.info(f"... prefix tree status: {self.prefix_tree_status_signal.value[0]}")
|
||||
api_server_logger.info(
|
||||
f"[RL] ... prefix tree status: {self.prefix_tree_status_signal.value[0]}"
|
||||
)
|
||||
time.sleep(1)
|
||||
timeout -= 1
|
||||
if timeout < 0:
|
||||
return 404, {**self.data_parallel_info, "msg": "update prefix tree timeout"}
|
||||
api_server_logger.info(
|
||||
f"<<< finish updating prefix tree (status: {self.prefix_tree_status_signal.value[0]})"
|
||||
)
|
||||
return_code = 404
|
||||
return_body = {**self.data_parallel_info, "msg": "update prefix tree timeout"}
|
||||
else:
|
||||
api_server_logger.info(
|
||||
f"[RL] <<< finish updating prefix tree (status: {self.prefix_tree_status_signal.value[0]})"
|
||||
)
|
||||
|
||||
# model_weights_status_signal: CLEARED -> UPDATING -> NORMAL
|
||||
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL:
|
||||
return 200, {**self.data_parallel_info, "msg": "model weight is updated"}
|
||||
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.UPDATING:
|
||||
return 400, {**self.data_parallel_info, "msg": "worker is updating model weight already"}
|
||||
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARING:
|
||||
return 403, {**self.data_parallel_info, "msg": "worker is clearing model weight, cannot update now"}
|
||||
|
||||
self.model_weights_status_signal.value[0] = ModelWeightsStatus.UPDATING
|
||||
api_server_logger.info(
|
||||
f">>> start updating model weight (weight status: {self.model_weights_status_signal.value[0]})"
|
||||
if not self.enable_cache_transfer
|
||||
else f">>> start updating model weight (weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]})"
|
||||
)
|
||||
while timeout >= 0:
|
||||
api_server_logger.info(
|
||||
f"... weight status: {self.model_weights_status_signal.value[0]}"
|
||||
if not self.enable_cache_transfer
|
||||
else f"... weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]}"
|
||||
)
|
||||
weight_updated = self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL
|
||||
cache_updated = self.kv_cache_status_signal.value[0] == KVCacheStatus.NORMAL
|
||||
if weight_updated and (not self.enable_cache_transfer or cache_updated):
|
||||
break
|
||||
time.sleep(1)
|
||||
timeout -= 1
|
||||
if timeout < 0:
|
||||
return 404, {**self.data_parallel_info, "msg": "update model weight timeout"}
|
||||
api_server_logger.info(
|
||||
f"<<< finish updating model weight (weight status: {self.model_weights_status_signal.value[0]})"
|
||||
if not self.enable_cache_transfer
|
||||
else f"<<< finish updating model weight (weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]})"
|
||||
)
|
||||
return 200, {**self.data_parallel_info, "msg": "update model weight successfully"}
|
||||
if return_code:
|
||||
if return_code == 404:
|
||||
api_server_logger.error("[RL] ??? updating model weight time out")
|
||||
return return_code, return_body
|
||||
else:
|
||||
return 200, {**self.data_parallel_info, "msg": "update model weight successfully"}
|
||||
|
||||
def clear_load_weight(self, timeout=300):
|
||||
"""
|
||||
@@ -626,57 +654,84 @@ class EngineClient:
|
||||
"""
|
||||
|
||||
with self.clear_update_lock:
|
||||
if self.enable_prefix_caching:
|
||||
|
||||
skip_action = False
|
||||
return_code = None
|
||||
return_body = {}
|
||||
|
||||
# model_weights_status_signal: NORMAL -> CLEARING -> CLEARED
|
||||
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED:
|
||||
skip_action = True
|
||||
return_code = 200
|
||||
return_body = {**self.data_parallel_info, "msg": "model weight is cleared"}
|
||||
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARING:
|
||||
skip_action = True
|
||||
return_code = 400
|
||||
return_body = {**self.data_parallel_info, "msg": "worker is clearing model weight already"}
|
||||
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.UPDATING:
|
||||
skip_action = True
|
||||
return_code = 403
|
||||
return_body = {**self.data_parallel_info, "msg": "worker is updating model weight, cannot clear now"}
|
||||
|
||||
if not skip_action:
|
||||
self.model_weights_status_signal.value[0] = ModelWeightsStatus.CLEARING
|
||||
api_server_logger.info(
|
||||
f"[RL] >>> start clearing model weight (weight status: {self.model_weights_status_signal.value[0]}"
|
||||
if not self.enable_cache_transfer
|
||||
else f"[RL] >>> start clearing model weight (weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]})"
|
||||
)
|
||||
while timeout >= 0:
|
||||
api_server_logger.info(
|
||||
f"[RL] ... weight status: {self.model_weights_status_signal.value[0]}"
|
||||
if not self.enable_cache_transfer
|
||||
else f"[RL] ... weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]}"
|
||||
)
|
||||
weight_cleared = self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED
|
||||
cache_cleared = self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARED
|
||||
if weight_cleared and (not self.enable_cache_transfer or cache_cleared):
|
||||
break
|
||||
time.sleep(1)
|
||||
timeout -= 1
|
||||
if timeout < 0:
|
||||
return_code = 404
|
||||
return_body = {**self.data_parallel_info, "msg": "clear model weight timeout"}
|
||||
else:
|
||||
api_server_logger.info(
|
||||
f"[RL] <<< finish clearing model weight (weight status: {self.model_weights_status_signal.value[0]})"
|
||||
if not self.enable_cache_transfer
|
||||
else f"[RL] <<< finish clearing model weight (weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]})"
|
||||
)
|
||||
else:
|
||||
api_server_logger.info(
|
||||
f"[RL] !!! skip clearing model weight for the following reason: {return_body.get('msg')}"
|
||||
)
|
||||
|
||||
if timeout >= 0 and self.enable_prefix_caching:
|
||||
# prefix_tree_status_signal: NORMAL -> CLEARING -> CLEARED
|
||||
if self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.NORMAL:
|
||||
self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.CLEARING
|
||||
api_server_logger.info(
|
||||
f">>> start clearing prefix tree (status: {self.prefix_tree_status_signal.value[0]})"
|
||||
f"[RL] >>> start clearing prefix tree (status: {self.prefix_tree_status_signal.value[0]})"
|
||||
)
|
||||
while timeout >= 0 and self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.CLEARED:
|
||||
api_server_logger.info(f"... prefix tree status: {self.prefix_tree_status_signal.value[0]}")
|
||||
api_server_logger.info(
|
||||
f"[RL] ... prefix tree status: {self.prefix_tree_status_signal.value[0]}"
|
||||
)
|
||||
time.sleep(1)
|
||||
timeout -= 1
|
||||
if timeout < 0:
|
||||
return 404, {**self.data_parallel_info, "msg": "clear prefix tree timeout"}
|
||||
api_server_logger.info(
|
||||
f"<<< finish clearing prefix tree (status: {self.prefix_tree_status_signal.value[0]})"
|
||||
)
|
||||
|
||||
# model_weights_status_signal: NORMAL -> CLEARING -> CLEARED
|
||||
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED:
|
||||
return 200, {**self.data_parallel_info, "msg": "model weight is cleared"}
|
||||
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARING:
|
||||
return 400, {**self.data_parallel_info, "msg": "worker is clearing model weight already"}
|
||||
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.UPDATING:
|
||||
return 403, {**self.data_parallel_info, "msg": "worker is updating model weight, cannot clear now"}
|
||||
|
||||
self.model_weights_status_signal.value[0] = ModelWeightsStatus.CLEARING
|
||||
api_server_logger.info(
|
||||
f">>> start clearing model weight (weight status: {self.model_weights_status_signal.value[0]}"
|
||||
if not self.enable_cache_transfer
|
||||
else f">>> start clearing model weight (weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]})"
|
||||
)
|
||||
while timeout >= 0:
|
||||
api_server_logger.info(
|
||||
f"... weight status: {self.model_weights_status_signal.value[0]}"
|
||||
if not self.enable_cache_transfer
|
||||
else f"... weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]}"
|
||||
)
|
||||
weight_cleared = self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED
|
||||
cache_cleared = self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARED
|
||||
if weight_cleared and (not self.enable_cache_transfer or cache_cleared):
|
||||
break
|
||||
time.sleep(1)
|
||||
timeout -= 1
|
||||
if timeout < 0:
|
||||
return 404, {**self.data_parallel_info, "msg": "clear model weight timeout"}
|
||||
api_server_logger.info(
|
||||
f"<<< finish clearing model weight (weight status: {self.model_weights_status_signal.value[0]})"
|
||||
if not self.enable_cache_transfer
|
||||
else f"<<< finish clearing model weight (weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]})"
|
||||
)
|
||||
return 200, {**self.data_parallel_info, "msg": "clear model weight successfully"}
|
||||
return_code = 404
|
||||
return_body = {**self.data_parallel_info, "msg": "clear prefix tree timeout"}
|
||||
else:
|
||||
api_server_logger.info(
|
||||
f"[RL] <<< finish clearing prefix tree (status: {self.prefix_tree_status_signal.value[0]})"
|
||||
)
|
||||
if return_code:
|
||||
if return_code == 404:
|
||||
api_server_logger.error("[RL] ??? clearing model weight time out")
|
||||
return return_code, return_body
|
||||
else:
|
||||
return 200, {**self.data_parallel_info, "msg": "clear model weight successfully"}
|
||||
|
||||
def check_model_weight_status(self):
|
||||
return self.model_weights_status_signal.value[0] < 0
|
||||
|
||||
Reference in New Issue
Block a user