mirror of https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Cherry-Pick] [KVCache] launch cache transfer processes only if hierarchical cache or kv cache storage is enabled (#5871) (#5859)
* [fix] temporarily forbid cpu cache in update/clear api
* [fix] stop launching cache transfer manager unless hierarchical cache is enabled
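The net effect of the first hunk below is that the per-rank cache transfer manager processes are only spawned when a cache tier below GPU memory actually exists. A minimal sketch of the guard: only enable_hierarchical_cache appears in this diff, and enable_kv_cache_storage is a hypothetical stand-in for the kv cache storage switch named in the title.

def should_launch_cache_transfer(cache_config) -> bool:
    # Hypothetical helper mirroring the new launch condition: skip the
    # transfer processes when there is nothing to transfer to.
    return bool(
        getattr(cache_config, "enable_hierarchical_cache", False)
        or getattr(cache_config, "enable_kv_cache_storage", False)  # assumed name
    )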
@@ -254,37 +254,38 @@ class PrefixCacheManager:
         val_shape_str = str(val_cache_shape)
         val_cache_arg_str = f" --value_cache_shape {val_shape_str}"

-        for i in range(tensor_parallel_size):
-            launch_cmd = (
-                "FLAGS_allocator_strategy=auto_growth "
-                + visible_devices
-                + " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
-                + f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}"
-                + f" {sys.executable} {py_path}"
-                + f" --device_id {int(device_ids[i])}"
-                + f" --rank {i}"
-                + f" --splitwise_role {self.splitwise_role}"
-                + f" --num_layers {cache_config.model_cfg.num_hidden_layers}"
-                + f" --mp_num {tensor_parallel_size}"
-                + f" --cache_dtype {cache_config.cache_dtype}"
-                + f" --key_cache_shape {key_cache_shape}"
-                + val_cache_arg_str
-                + f" --cache_queue_port {cache_config.cache_queue_port}"
-                + f" --enable_splitwise {int(self.enable_splitwise)}"
-                + f" --pod_ip {pod_ip}"
-                + f" --engine_worker_queue_port {engine_worker_queue_port}"
-                + f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
-                + f" --engine_pid {pid_suffix}"
-                + f" --default_dtype '{self.config.model_config.dtype}'"
-                + f" --protocol {cache_config.cache_transfer_protocol}"
-                + f" --local_data_parallel_id {self.local_data_parallel_id}"
-                + f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
-                + f" --speculative_config '{self.speculative_config.to_json_string()}'"
-                + (" --create_cache_tensor" if create_cache_tensor else "")
-                + f" >{log_dir}/launch_cache_transfer_manager_tprank{i}.log 2>&1"
-            )
-            logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
-            cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
+        if self.cache_config.enable_hierarchical_cache:
+            for i in range(tensor_parallel_size):
+                launch_cmd = (
+                    "FLAGS_allocator_strategy=auto_growth "
+                    + visible_devices
+                    + " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
+                    + f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}"
+                    + f" {sys.executable} {py_path}"
+                    + f" --device_id {int(device_ids[i])}"
+                    + f" --rank {i}"
+                    + f" --splitwise_role {self.splitwise_role}"
+                    + f" --num_layers {cache_config.model_cfg.num_hidden_layers}"
+                    + f" --mp_num {tensor_parallel_size}"
+                    + f" --cache_dtype {cache_config.cache_dtype}"
+                    + f" --key_cache_shape {key_cache_shape}"
+                    + val_cache_arg_str
+                    + f" --cache_queue_port {cache_config.cache_queue_port}"
+                    + f" --enable_splitwise {int(self.enable_splitwise)}"
+                    + f" --pod_ip {pod_ip}"
+                    + f" --engine_worker_queue_port {engine_worker_queue_port}"
+                    + f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
+                    + f" --engine_pid {pid_suffix}"
+                    + f" --default_dtype '{self.config.model_config.dtype}'"
+                    + f" --protocol {cache_config.cache_transfer_protocol}"
+                    + f" --local_data_parallel_id {self.local_data_parallel_id}"
+                    + f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
+                    + f" --speculative_config '{self.speculative_config.to_json_string()}'"
+                    + (" --create_cache_tensor" if create_cache_tensor else "")
+                    + f" >{log_dir}/launch_cache_transfer_manager_tprank{i}.log 2>&1"
+                )
+                logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
+                cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))

         logger.info("PrefixCacheManager is waiting for kv cache to be initialized.")
         while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
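For readers unfamiliar with the launch pattern above: environment variables are prefixed to the command string, stdout/stderr are redirected to a per-rank log, and preexec_fn=os.setsid puts each child in its own session so the whole process group can be signalled on shutdown. A self-contained sketch of just that pattern; worker.py and CUDA_VISIBLE_DEVICES here are illustrative placeholders, not FastDeploy's actual entry point:

import os
import subprocess
import sys

os.makedirs("logs", exist_ok=True)
processes = []
for rank in range(2):  # stand-in for tensor_parallel_size
    cmd = (
        f"CUDA_VISIBLE_DEVICES={rank}"                  # env prefix, like visible_devices
        + f" {sys.executable} worker.py --rank {rank}"  # placeholder entry point
        + f" > logs/worker_rank{rank}.log 2>&1"         # per-rank log redirection
    )
    # shell=True is required for the env prefix and the redirection;
    # os.setsid lets a later os.killpg tear down the child and its children.
    processes.append(subprocess.Popen(cmd, shell=True, preexec_fn=os.setsid))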
@@ -294,13 +295,14 @@ class PrefixCacheManager:
             while np.sum(self.swap_space_ready_signal.value) != tensor_parallel_size:
                 time.sleep(1)

-        exit_code = cache_manager_processes[-1].poll()
-        if exit_code is None:
-            logger.info("Launch cache transfer manager successful")
-        else:
-            logger.info(
-                "Launch cache transfer manager failed, see launch_cache_transfer_manager.log for more information"
-            )
+        if cache_manager_processes:
+            exit_code = cache_manager_processes[-1].poll()
+            if exit_code is None:
+                logger.info("Launch cache transfer manager successful")
+            else:
+                logger.info(
+                    "Launch cache transfer manager failed, see launch_cache_transfer_manager.log for more information"
+                )

         # Start additional threads
         if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
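The guard added in this hunk follows directly from the previous one: now that the launch loop is conditional, cache_manager_processes can be empty, and the old cache_manager_processes[-1] would raise an IndexError. A short standalone illustration of the Popen.poll() semantics the check relies on (poll() returns None while the child is alive, and its exit code afterwards):

import subprocess

procs = []  # may stay empty when no cache transfer manager was launched
procs.append(subprocess.Popen(["sleep", "5"]))  # illustrative long-running child

if procs:  # without this guard, procs[-1] on an empty list raises IndexError
    exit_code = procs[-1].poll()
    if exit_code is None:
        print("last process is still running")
    else:
        print(f"last process exited with code {exit_code}")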
@@ -33,7 +33,6 @@ from fastdeploy.eplb.utils import RedundantExpertWorkload
 from fastdeploy.input.preprocess import InputPreprocessor
 from fastdeploy.inter_communicator import (
     IPCSignal,
-    KVCacheStatus,
     ModelWeightsStatus,
     PrefixTreeStatus,
     RearrangeExpertStatus,
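The statuses imported here back cross-process handshakes: the API server and the worker share a small numpy-visible buffer and communicate by writing status values into it. FastDeploy's IPCSignal constructor is not shown in this diff, so the following is only a rough emulation of the idea using the standard library; the name and encoding are assumptions:

import numpy as np
from multiprocessing import shared_memory

# One int32 slot visible to every process that attaches by name.
shm = shared_memory.SharedMemory(create=True, size=4, name="demo_status_signal")
status = np.ndarray((1,), dtype=np.int32, buffer=shm.buf)

NORMAL, UPDATING, CLEARING, CLEARED = 0, 1, 2, 3  # assumed encoding

status[0] = UPDATING  # client side: request an update
# A worker attaching to "demo_status_signal" would do the work and then
# acknowledge by writing NORMAL back; the client polls until it sees that.
status[0] = NORMAL    # simulated acknowledgement

shm.close()
shm.unlink()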
@@ -548,6 +547,28 @@ class EngineClient:
         2 : worker update finish and notify client
         """
         with self.clear_update_lock:
+            if self.fd_config.cache_config.enable_hierarchical_cache:
+                return False, "hierarchical cache updating is not supported"
+
+            # if self.enable_prefix_caching or self.enable_splitwise:
+            #     # kv_cache_status_signal: CLEARED -> UPDATING -> NORMAL
+            #     if self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARED:
+            #         self.kv_cache_status_signal.value[0] = KVCacheStatus.UPDATING
+            #         api_server_logger.info(f"Start to update kv cache {self.kv_cache_status_signal.value[0]}")
+            #     while self.kv_cache_status_signal.value[0] != KVCacheStatus.NORMAL:
+            #         api_server_logger.info(f"..updating kv cache {self.kv_cache_status_signal.value[0]}")
+            #         time.sleep(1)
+
+            if self.enable_prefix_caching:
+                # prefix_tree_status_signal: CLEARED -> UPDATING -> NORMAL
+                if self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.CLEARED:
+                    self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.UPDATING
+                    api_server_logger.info(f"Start to update prefix tree {self.prefix_tree_status_signal.value[0]}")
+                while self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL:
+                    api_server_logger.info(f"..updating prefix tree {self.prefix_tree_status_signal.value[0]}")
+                    time.sleep(1)
+
+            # model_weights_status_signal: CLEARED -> UPDATING -> NORMAL
             if self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL:
                 return True, ""
             if self.model_weights_status_signal.value[0] == ModelWeightsStatus.UPDATING:
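The prefix tree handshake added above walks a small state machine (CLEARED -> UPDATING -> NORMAL) and blocks, without a timeout, until the worker reports NORMAL. A sketch of that loop with assumed enum values; the real constants live in fastdeploy.inter_communicator and may differ:

import time
from enum import IntEnum

class PrefixTreeStatus(IntEnum):  # names from the diff, values assumed
    NORMAL = 0
    UPDATING = 1
    CLEARING = 2
    CLEARED = 3

status = [PrefixTreeStatus.CLEARED]  # stand-in for prefix_tree_status_signal.value

if status[0] == PrefixTreeStatus.CLEARED:
    status[0] = PrefixTreeStatus.UPDATING  # ask the worker to rebuild the tree
while status[0] != PrefixTreeStatus.NORMAL:
    time.sleep(1)
    status[0] = PrefixTreeStatus.NORMAL  # in reality the worker flips this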
@@ -556,34 +577,13 @@ class EngineClient:
                 return False, "worker is clearing model weight, cannot update now"

             self.model_weights_status_signal.value[0] = ModelWeightsStatus.UPDATING
-            if self.enable_prefix_caching or self.enable_splitwise:
-                self.kv_cache_status_signal.value[0] = KVCacheStatus.UPDATING
-                if self.enable_prefix_caching:
-                    self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.UPDATING
-            api_server_logger.info(f"start update model weight {self.model_weights_status_signal.value}")
-            all_updated = False
-            while timeout >= 0 and not all_updated:
-                api_server_logger.info(
-                    f"Updating model weights.. "
-                    f"model_weights_status: {self.model_weights_status_signal.value[0]}, "
-                    f"prefix_tree_status: {self.prefix_tree_status_signal.value[0]}, "
-                    f"kv_cache_status: {self.kv_cache_status_signal.value[0]} "
-                )
-                weight_updated = self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL
-                cache_updated = self.kv_cache_status_signal.value[0] == KVCacheStatus.NORMAL
-                prefix_updated = self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.NORMAL
-                if self.enable_prefix_caching or self.enable_splitwise:
-                    if self.enable_prefix_caching:
-                        all_updated = weight_updated and cache_updated and prefix_updated
-                    else:
-                        all_updated = weight_updated and cache_updated
-                else:
-                    all_updated = weight_updated
+            api_server_logger.info(f"Start to update model weight {self.model_weights_status_signal.value[0]}")
+            while timeout >= 0 and self.model_weights_status_signal.value[0] != ModelWeightsStatus.NORMAL:
+                api_server_logger.info(f"..updating model weights {self.model_weights_status_signal.value[0]}")
                 time.sleep(1)
                 timeout -= 1
             if timeout < 0:
                 return False, "Update model weight timeout"
             time.sleep(1)
             return True, ""

     def clear_load_weight(self, timeout=300):
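The rewrite above collapses the old per-flag all_updated bookkeeping into a single bounded poll on the model weight status; the clear path below gets the same treatment with CLEARED as the target. The shared shape can be expressed as a generic helper, hypothetical and not part of FastDeploy:

import time
from typing import Callable

def wait_for_status(read: Callable[[], int], target: int, timeout: int) -> bool:
    """Poll read() once per second until it returns target, or until
    timeout seconds have elapsed (mirrors the 'Update/Clear model weight
    timeout' branches in the diff)."""
    while timeout >= 0:
        if read() == target:
            return True
        time.sleep(1)
        timeout -= 1
    return False

# e.g. wait_for_status(lambda: signal.value[0], ModelWeightsStatus.NORMAL, 300)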
@@ -594,6 +594,27 @@ class EngineClient:
         """

         with self.clear_update_lock:
+            if self.fd_config.cache_config.enable_hierarchical_cache:
+                return False, "hierarchical cache clearing is not supported"
+            # if self.enable_prefix_caching or self.enable_splitwise:
+            #     # kv_cache_status_signal: NORMAL -> CLEARING -> CLEARED
+            #     if self.kv_cache_status_signal.value[0] == KVCacheStatus.NORMAL:
+            #         self.kv_cache_status_signal.value[0] = KVCacheStatus.CLEARING
+            #         api_server_logger.info(f"Start to clear kv cache {self.kv_cache_status_signal.value[0]}")
+            #     while self.kv_cache_status_signal.value[0] != KVCacheStatus.CLEARED:
+            #         api_server_logger.info(f"..clearing kv cache {self.kv_cache_status_signal.value[0]}")
+            #         time.sleep(1)
+
+            if self.enable_prefix_caching:
+                # prefix_tree_status_signal: NORMAL -> CLEARING -> CLEARED
+                if self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.NORMAL:
+                    self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.CLEARING
+                    api_server_logger.info(f"Start to clear prefix tree {self.prefix_tree_status_signal.value[0]}")
+                while self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.CLEARED:
+                    api_server_logger.info(f"..clearing prefix tree {self.prefix_tree_status_signal.value[0]}")
+                    time.sleep(1)
+
+            # model_weights_status_signal: NORMAL -> CLEARING -> CLEARED
             if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED:
                 return True, ""
             if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARING:
@@ -602,36 +623,13 @@ class EngineClient:
                 return False, "worker is updating model weight, cannot clear now"

             self.model_weights_status_signal.value[0] = ModelWeightsStatus.CLEARING
-            if self.enable_prefix_caching or self.enable_splitwise:
-                self.kv_cache_status_signal.value[0] = KVCacheStatus.CLEARING
-                if self.enable_prefix_caching:
-                    self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.CLEARING
-
-            api_server_logger.info(f"start clear model weight {self.model_weights_status_signal.value}")
-            all_cleared = False
-            while timeout >= 0 and not all_cleared:
-                api_server_logger.info(
-                    f"Clearing model weights.. "
-                    f"model_weights_status: {self.model_weights_status_signal.value[0]}, "
-                    f"prefix_tree_status: {self.prefix_tree_status_signal.value[0]}, "
-                    f"kv_cache_status: {self.kv_cache_status_signal.value[0]} "
-                )
-                weight_cleared = self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED
-                cache_cleared = self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARED
-                prefix_cleared = self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.CLEARED
-                if self.enable_prefix_caching or self.enable_splitwise:
-                    if self.enable_prefix_caching:
-                        all_cleared = weight_cleared and cache_cleared and prefix_cleared
-                    else:
-                        all_cleared = weight_cleared and cache_cleared
-                else:
-                    all_cleared = weight_cleared
+            api_server_logger.info(f"Start to clear model weight {self.model_weights_status_signal.value[0]}")
+            while timeout >= 0 and self.model_weights_status_signal.value[0] != ModelWeightsStatus.CLEARED:
+                api_server_logger.info(f"..clearing model weights {self.model_weights_status_signal.value[0]}")
                 time.sleep(1)
                 timeout -= 1

             if timeout < 0:
                 return False, "Clear model weight timeout"
             time.sleep(1)
             return True, ""

     def check_model_weight_status(self):