[Cherry-Pick] [KVCache] launch cache transfer processes only if hierarchical cache or kv cache storage is enabled (#5871) (#5859)

* [fix] temporarily forbid cpu cache in update/clear api

* [fix] stop launching cache transfer manager unless hierarchical cache is enabled
This commit is contained in:
Yonghua Li
2026-01-06 11:05:45 +08:00
committed by GitHub
parent 0f008b8bd1
commit f3ebd64446
2 changed files with 89 additions and 89 deletions
@@ -254,37 +254,38 @@ class PrefixCacheManager:
val_shape_str = str(val_cache_shape)
val_cache_arg_str = f" --value_cache_shape {val_shape_str}"
for i in range(tensor_parallel_size):
launch_cmd = (
"FLAGS_allocator_strategy=auto_growth "
+ visible_devices
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
+ f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}"
+ f" {sys.executable} {py_path}"
+ f" --device_id {int(device_ids[i])}"
+ f" --rank {i}"
+ f" --splitwise_role {self.splitwise_role}"
+ f" --num_layers {cache_config.model_cfg.num_hidden_layers}"
+ f" --mp_num {tensor_parallel_size}"
+ f" --cache_dtype {cache_config.cache_dtype}"
+ f" --key_cache_shape {key_cache_shape}"
+ val_cache_arg_str
+ f" --cache_queue_port {cache_config.cache_queue_port}"
+ f" --enable_splitwise {int(self.enable_splitwise)}"
+ f" --pod_ip {pod_ip}"
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
+ f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
+ f" --engine_pid {pid_suffix}"
+ f" --default_dtype '{self.config.model_config.dtype}'"
+ f" --protocol {cache_config.cache_transfer_protocol}"
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
+ (" --create_cache_tensor" if create_cache_tensor else "")
+ f" >{log_dir}/launch_cache_transfer_manager_tprank{i}.log 2>&1"
)
logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
if self.cache_config.enable_hierarchical_cache:
for i in range(tensor_parallel_size):
launch_cmd = (
"FLAGS_allocator_strategy=auto_growth "
+ visible_devices
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
+ f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}"
+ f" {sys.executable} {py_path}"
+ f" --device_id {int(device_ids[i])}"
+ f" --rank {i}"
+ f" --splitwise_role {self.splitwise_role}"
+ f" --num_layers {cache_config.model_cfg.num_hidden_layers}"
+ f" --mp_num {tensor_parallel_size}"
+ f" --cache_dtype {cache_config.cache_dtype}"
+ f" --key_cache_shape {key_cache_shape}"
+ val_cache_arg_str
+ f" --cache_queue_port {cache_config.cache_queue_port}"
+ f" --enable_splitwise {int(self.enable_splitwise)}"
+ f" --pod_ip {pod_ip}"
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
+ f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
+ f" --engine_pid {pid_suffix}"
+ f" --default_dtype '{self.config.model_config.dtype}'"
+ f" --protocol {cache_config.cache_transfer_protocol}"
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
+ (" --create_cache_tensor" if create_cache_tensor else "")
+ f" >{log_dir}/launch_cache_transfer_manager_tprank{i}.log 2>&1"
)
logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
logger.info("PrefixCacheManager is waiting for kv cache to be initialized.")
while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
@@ -294,13 +295,14 @@ class PrefixCacheManager:
while np.sum(self.swap_space_ready_signal.value) != tensor_parallel_size:
time.sleep(1)
exit_code = cache_manager_processes[-1].poll()
if exit_code is None:
logger.info("Launch cache transfer manager successful")
else:
logger.info(
"Launch cache transfer manager failed, see launch_cache_transfer_manager.log for more information"
)
if cache_manager_processes:
exit_code = cache_manager_processes[-1].poll()
if exit_code is None:
logger.info("Launch cache transfer manager successful")
else:
logger.info(
"Launch cache transfer manager failed, see launch_cache_transfer_manager.log for more information"
)
# Start additional threads
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
+49 -51
View File
@@ -33,7 +33,6 @@ from fastdeploy.eplb.utils import RedundantExpertWorkload
from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.inter_communicator import (
IPCSignal,
KVCacheStatus,
ModelWeightsStatus,
PrefixTreeStatus,
RearrangeExpertStatus,
@@ -548,6 +547,28 @@ class EngineClient:
2 : worker update finish and notify client
"""
with self.clear_update_lock:
if self.fd_config.cache_config.enable_hierarchical_cache:
return False, "hierarchical cache updating is not supported"
# if self.enable_prefix_caching or self.enable_splitwise:
# # kv_cache_status_signal: CLEARED -> UPDATING -> NORMAL
# if self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARED:
# self.kv_cache_status_signal.value[0] = KVCacheStatus.UPDATING
# api_server_logger.info(f"Start to update kv cache {self.kv_cache_status_signal.value[0]}")
# while self.kv_cache_status_signal.value[0] != KVCacheStatus.NORMAL:
# api_server_logger.info(f"..updating kv cache {self.kv_cache_status_signal.value[0]}")
# time.sleep(1)
if self.enable_prefix_caching:
# prefix_tree_status_signal: CLEARED -> UPDATING -> NORMAL
if self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.CLEARED:
self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.UPDATING
api_server_logger.info(f"Start to update prefix tree {self.prefix_tree_status_signal.value[0]}")
while self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL:
api_server_logger.info(f"..updating prefix tree {self.prefix_tree_status_signal.value[0]}")
time.sleep(1)
# model_weights_status_signal: CLEARED -> UPDATING -> NORMAL
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL:
return True, ""
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.UPDATING:
@@ -556,34 +577,13 @@ class EngineClient:
return False, "worker is clearing model weight, cannot update now"
self.model_weights_status_signal.value[0] = ModelWeightsStatus.UPDATING
if self.enable_prefix_caching or self.enable_splitwise:
self.kv_cache_status_signal.value[0] = KVCacheStatus.UPDATING
if self.enable_prefix_caching:
self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.UPDATING
api_server_logger.info(f"start update model weight {self.model_weights_status_signal.value}")
all_updated = False
while timeout >= 0 and not all_updated:
api_server_logger.info(
f"Updating model weights.. "
f"model_weights_status: {self.model_weights_status_signal.value[0]}, "
f"prefix_tree_status: {self.prefix_tree_status_signal.value[0]}, "
f"kv_cache_status: {self.kv_cache_status_signal.value[0]} "
)
weight_updated = self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL
cache_updated = self.kv_cache_status_signal.value[0] == KVCacheStatus.NORMAL
prefix_updated = self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.NORMAL
if self.enable_prefix_caching or self.enable_splitwise:
if self.enable_prefix_caching:
all_updated = weight_updated and cache_updated and prefix_updated
else:
all_updated = weight_updated and cache_updated
else:
all_updated = weight_updated
api_server_logger.info(f"Start to update model weight {self.model_weights_status_signal.value[0]}")
while timeout >= 0 and self.model_weights_status_signal.value[0] != ModelWeightsStatus.NORMAL:
api_server_logger.info(f"..updating model weights {self.model_weights_status_signal.value[0]}")
time.sleep(1)
timeout -= 1
if timeout < 0:
return False, "Update model weight timeout"
time.sleep(1)
return True, ""
def clear_load_weight(self, timeout=300):
@@ -594,6 +594,27 @@ class EngineClient:
"""
with self.clear_update_lock:
if self.fd_config.cache_config.enable_hierarchical_cache:
return False, "hierarchical cache clearing is not supported"
# if self.enable_prefix_caching or self.enable_splitwise:
# # kv_cache_status_signal: NORMAL -> CLEARING -> CLEARED
# if self.kv_cache_status_signal.value[0] == KVCacheStatus.NORMAL:
# self.kv_cache_status_signal.value[0] = KVCacheStatus.CLEARING
# api_server_logger.info(f"Start to clear kv cache {self.kv_cache_status_signal.value[0]}")
# while self.kv_cache_status_signal.value[0] != KVCacheStatus.CLEARED:
# api_server_logger.info(f"..clearing kv cache {self.kv_cache_status_signal.value[0]}")
# time.sleep(1)
if self.enable_prefix_caching:
# prefix_tree_status_signal: NORMAL -> CLEARING -> CLEARED
if self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.NORMAL:
self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.CLEARING
api_server_logger.info(f"Start to clear prefix tree {self.prefix_tree_status_signal.value[0]}")
while self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.CLEARED:
api_server_logger.info(f"..clearing prefix tree {self.prefix_tree_status_signal.value[0]}")
time.sleep(1)
# model_weights_status_signal: NORMAL -> CLEARING -> CLEARED
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED:
return True, ""
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARING:
@@ -602,36 +623,13 @@ class EngineClient:
return False, "worker is updating model weight, cannot clear now"
self.model_weights_status_signal.value[0] = ModelWeightsStatus.CLEARING
if self.enable_prefix_caching or self.enable_splitwise:
self.kv_cache_status_signal.value[0] = KVCacheStatus.CLEARING
if self.enable_prefix_caching:
self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.CLEARING
api_server_logger.info(f"start clear model weight {self.model_weights_status_signal.value}")
all_cleared = False
while timeout >= 0 and not all_cleared:
api_server_logger.info(
f"Clearing model weights.. "
f"model_weights_status: {self.model_weights_status_signal.value[0]}, "
f"prefix_tree_status: {self.prefix_tree_status_signal.value[0]}, "
f"kv_cache_status: {self.kv_cache_status_signal.value[0]} "
)
weight_cleared = self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED
cache_cleared = self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARED
prefix_cleared = self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.CLEARED
if self.enable_prefix_caching or self.enable_splitwise:
if self.enable_prefix_caching:
all_cleared = weight_cleared and cache_cleared and prefix_cleared
else:
all_cleared = weight_cleared and cache_cleared
else:
all_cleared = weight_cleared
api_server_logger.info(f"Start to clear model weight {self.model_weights_status_signal.value[0]}")
while timeout >= 0 and self.model_weights_status_signal.value[0] != ModelWeightsStatus.CLEARED:
api_server_logger.info(f"..clearing model weights {self.model_weights_status_signal.value[0]}")
time.sleep(1)
timeout -= 1
if timeout < 0:
return False, "Clear model weight timeout"
time.sleep(1)
return True, ""
def check_model_weight_status(self):