[BugFix] fix cache transfer manager updating/clearing (#5930)

* [fix] fix cache transfer manager updating/clearing

* [fix] fix code style

* [fix] fix config

* [fix] fix engine client

* [fix] let worker update kv cache status signal

* [fix] update worker process

* [fix] fix clear/update for the case where the comm group is shut down

* [fix] update dynamic weight manager

* [fix] fix port

* [fix] add num_cpu_blocks arg for async_llm and remove unnecessary waiting
This commit is contained in:
Yonghua Li
2026-01-13 21:09:29 +08:00
committed by GitHub
parent 6da06abc17
commit 456637002d
8 changed files with 165 additions and 74 deletions
@@ -178,7 +178,7 @@ class CacheTransferManager:
name="cache_ready_signal",
array=cache_ready_signal_data,
dtype=np.int32,
suffix=self.ipc_suffix,
suffix=args.engine_worker_queue_port,
create=False,
)
swap_space_ready_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
@@ -186,7 +186,7 @@ class CacheTransferManager:
name="swap_space_ready_signal",
array=swap_space_ready_data,
dtype=np.int32,
suffix=self.ipc_suffix,
suffix=args.engine_worker_queue_port,
create=False,
)
@@ -201,7 +201,7 @@ class CacheTransferManager:
name="cache_task_broadcast_signal",
array=cache_task_broadcast_data,
dtype=np.int32,
suffix=args.ipc_suffix,
suffix=args.engine_worker_queue_port,
create=False,
)
@@ -230,7 +230,15 @@ class CacheTransferManager:
raise ValueError(f"Invalid write policy: {args.write_policy}")
self.write_policy = args.write_policy
threading.Thread(target=self.clear_or_update_caches, args=[args], daemon=True).start()
# Initialize update/clear signals for RL
self.kv_cache_status_signal = IPCSignal(
name="kv_cache_status",
array=np.zeros([1], dtype=np.int32),
dtype=np.int32,
suffix=args.engine_worker_queue_port,
create=False,
)
threading.Thread(target=self.check_cache_status, args=[args], daemon=True).start()
def _init_storage_buffer(self, args):
"""
@@ -951,29 +959,19 @@ class CacheTransferManager:
task_cpu_block_id,
)
def clear_or_update_caches(self, args):
def check_cache_status(self, args):
# TODO XPU support RL
if unset_data_ipc is None:
return
logger.info("Start a thread to clear/restore kv cache when model weights are cleared/updated.")
logger.info(f"FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}")
kv_cache_status = np.zeros([1], dtype=np.int32)
kv_cache_status_signal = IPCSignal(
name="kv_cache_status",
array=kv_cache_status,
dtype=np.int32,
suffix=self.ipc_suffix,
create=False,
)
while True:
if kv_cache_status_signal.value[0] == KVCacheStatus.CLEARING:
# handle cache clearing/restoring
if self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARING:
assert args.splitwise_role == "mixed", "Only mixed mode supports clearing cache."
try:
logger.info(
f"[rank {self.rank}/{self.n_ranks}] Start clearing caches {self.cache_ready_signal.value}"
)
logger.info(f"Start clearing caches {self.cache_ready_signal.value}")
# clear cpu caches
if envs.FD_ENABLE_SWAP_SPACE_CLEARING:
if self.num_cpu_blocks > 0 and envs.FD_ENABLE_SWAP_SPACE_CLEARING:
paddle.set_device("cpu")
for ptrs in self.k_dst_ptrs + self.v_dst_ptrs:
cuda_host_free(ptrs)
@@ -996,49 +994,43 @@ class CacheTransferManager:
# reset cache_ready_signal
self.cache_ready_signal.value[self.rank] = 0
logger.info(
f"[rank {self.rank}/{self.n_ranks}] Finish clearing caches {self.cache_ready_signal.value}"
)
logger.info(f"Finish clearing caches {self.cache_ready_signal.value}")
# wait for all ranks caches to be cleared
if np.sum(self.cache_ready_signal.value) != 0:
time.sleep(0.1)
# reset kv_cache_status_signal
kv_cache_status_signal.value[0] = KVCacheStatus.CLEARED
logger.info("All ranks finish clearing caches")
self.kv_cache_status_signal.value[0] = KVCacheStatus.CLEARED
logger.info(f"All ranks finish clearing caches {self.cache_ready_signal.value}")
except Exception as e:
logger.error(f"[rank {self.rank}/{self.n_ranks}] Failed to clear caches: {e}")
logger.error(f"Failed to clear caches: {e}")
elif kv_cache_status_signal.value[0] == KVCacheStatus.UPDATING:
elif self.kv_cache_status_signal.value[0] == KVCacheStatus.UPDATING:
assert args.splitwise_role == "mixed", "Only mixed mode supports updating cache."
try:
logger.info(
f"[rank {self.rank}/{self.n_ranks}] Start restoring caches {self.cache_ready_signal.value}"
)
logger.info(f"Start restoring caches {self.cache_ready_signal.value}")
# restore cpu cache
if envs.FD_ENABLE_SWAP_SPACE_CLEARING:
if self.num_cpu_blocks > 0 and envs.FD_ENABLE_SWAP_SPACE_CLEARING:
self._init_cpu_cache(args)
while np.sum(self.swap_space_ready_signal.value) != args.mp_num:
time.sleep(0.1)
# restore gpu cache and set cache_ready_signal
self._init_gpu_cache(args)
logger.info(
f"[rank {self.rank}/{self.n_ranks}] Finish restoring caches {self.cache_ready_signal.value}"
)
logger.info(f"Finish restoring caches {self.cache_ready_signal.value}")
# wait for all ranks caches to be ready
while np.sum(self.cache_ready_signal.value) != args.mp_num:
time.sleep(0.1)
# set kv_cache_status_signal
logger.info("All ranks finish restoring caches")
kv_cache_status_signal.value[0] = KVCacheStatus.NORMAL
logger.info(f"All ranks finish restoring caches {self.cache_ready_signal.value}")
self.kv_cache_status_signal.value[0] = KVCacheStatus.NORMAL
except Exception as e:
logger.error(f"[rank {self.rank}/{self.n_ranks}] Failed to restore caches: {e}")
logger.error(f"Failed to restore caches: {e}")
time.sleep(0.1)
+7 -4
View File
@@ -1335,6 +1335,7 @@ class CacheConfig:
self.disable_chunked_mm_input = False
self.kvcache_storage_backend = None
self.write_policy = None
self.num_cpu_blocks = None
for key, value in args.items():
if hasattr(self, key):
@@ -1380,10 +1381,12 @@ class CacheConfig:
* byte_size
)
if self.swap_space is None:
self.num_cpu_blocks = 0
else:
self.num_cpu_blocks = int(self.swap_space * 1024**3 / self.bytes_per_block)
if self.num_cpu_blocks is None:
if self.swap_space is None:
self.num_cpu_blocks = 0
else:
self.num_cpu_blocks = int(self.swap_space * 1024**3 / self.bytes_per_block)
self._verify_args()
def metrics_info(self):
+1
View File
@@ -1707,6 +1707,7 @@ class EngineService:
f" --logprobs_mode {self.cfg.model_config.logprobs_mode}"
f" --max_logprobs {self.cfg.model_config.max_logprobs}"
f" --eplb_config '{self.cfg.eplb_config.to_json_string()}'"
f" --num_cpu_blocks {self.cfg.cache_config.num_cpu_blocks}"
)
if self.cfg.structured_outputs_config.logits_processors is not None:
arguments += f" --logits-processors {' '.join(self.cfg.structured_outputs_config.logits_processors)}"
+1
View File
@@ -574,6 +574,7 @@ class LLMEngine:
f" --max_logprobs {self.cfg.model_config.max_logprobs}"
f" --eplb_config '{self.cfg.eplb_config.to_json_string()}'"
f" --routing_replay_config '{self.cfg.routing_replay_config.to_json_string()}'"
f" --num_cpu_blocks {self.cfg.cache_config.num_cpu_blocks}"
)
if self.cfg.structured_outputs_config.logits_processors is not None:
arguments += f" --logits-processors {' '.join(self.cfg.structured_outputs_config.logits_processors)}"
+66 -18
View File
@@ -34,6 +34,7 @@ from fastdeploy.eplb.utils import RedundantExpertWorkload
from fastdeploy.input.preprocess import InputPreprocessor
from fastdeploy.inter_communicator import (
IPCSignal,
KVCacheStatus,
ModelWeightsStatus,
PrefixTreeStatus,
RearrangeExpertStatus,
@@ -79,6 +80,9 @@ class EngineClient:
)
self.max_model_len = self.fd_config.model_config.max_model_len
self.enable_prefix_caching = self.fd_config.cache_config.enable_prefix_caching
self.enable_cache_transfer = (
self.fd_config.cache_config.swap_space or self.fd_config.cache_config.kvcache_storage_backend
)
self.enable_splitwise = self.fd_config.scheduler_config.splitwise_role != "mixed"
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
@@ -528,17 +532,22 @@ class EngineClient:
2 : worker update finish and notify client
"""
with self.clear_update_lock:
if self.fd_config.cache_config.swap_space:
return False, "hierarchical cache updating is not supported"
if self.enable_prefix_caching:
# prefix_tree_status_signal: CLEARED -> UPDATING -> NORMAL
if self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.CLEARED:
self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.UPDATING
api_server_logger.info(f"Start to update prefix tree {self.prefix_tree_status_signal.value[0]}")
while self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL:
api_server_logger.info(f"..updating prefix tree {self.prefix_tree_status_signal.value[0]}")
api_server_logger.info(
f">>> start updating prefix tree (status: {self.prefix_tree_status_signal.value[0]})"
)
while timeout >= 0 and self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL:
api_server_logger.info(f"... prefix tree status: {self.prefix_tree_status_signal.value[0]}")
time.sleep(1)
timeout -= 1
if timeout < 0:
return False, "Update prefix tree timeout"
api_server_logger.info(
f"<<< finish updating prefix tree (status: {self.prefix_tree_status_signal.value[0]})"
)
# model_weights_status_signal: CLEARED -> UPDATING -> NORMAL
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL:
@@ -549,13 +558,30 @@ class EngineClient:
return False, "worker is clearing model weight, cannot update now"
self.model_weights_status_signal.value[0] = ModelWeightsStatus.UPDATING
api_server_logger.info(f"Start to update model weight {self.model_weights_status_signal.value[0]}")
while timeout >= 0 and self.model_weights_status_signal.value[0] != ModelWeightsStatus.NORMAL:
api_server_logger.info(f"..updating model weights {self.model_weights_status_signal.value[0]}")
api_server_logger.info(
f">>> start updating model weight (weight status: {self.model_weights_status_signal.value[0]})"
if not self.enable_cache_transfer
else f">>> start updating model weight (weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]})"
)
while timeout >= 0:
api_server_logger.info(
f"... weight status: {self.model_weights_status_signal.value[0]}"
if not self.enable_cache_transfer
else f"... weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]}"
)
weight_updated = self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL
cache_updated = self.kv_cache_status_signal.value[0] == KVCacheStatus.NORMAL
if weight_updated and (not self.enable_cache_transfer or cache_updated):
break
time.sleep(1)
timeout -= 1
if timeout < 0:
return False, "Update model weight timeout"
api_server_logger.info(
f"<<< finish updating model weight (weight status: {self.model_weights_status_signal.value[0]}"
if not self.enable_cache_transfer
else f"<<< finish updating model weight (weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]})"
)
return True, ""
def clear_load_weight(self, timeout=300):
@@ -566,17 +592,22 @@ class EngineClient:
"""
with self.clear_update_lock:
if self.fd_config.cache_config.swap_space:
return False, "hierarchical cache clearing is not supported"
if self.enable_prefix_caching:
# prefix_tree_status_signal: NORMAL -> CLEARING -> CLEARED
if self.prefix_tree_status_signal.value[0] == PrefixTreeStatus.NORMAL:
self.prefix_tree_status_signal.value[0] = PrefixTreeStatus.CLEARING
api_server_logger.info(f"Start to clear prefix tree {self.prefix_tree_status_signal.value[0]}")
while self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.CLEARED:
api_server_logger.info(f"..clearing prefix tree {self.prefix_tree_status_signal.value[0]}")
api_server_logger.info(
f">>> start clearing prefix tree (status: {self.prefix_tree_status_signal.value[0]})"
)
while timeout >= 0 and self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.CLEARED:
api_server_logger.info(f"... prefix tree status: {self.prefix_tree_status_signal.value[0]}")
time.sleep(1)
timeout -= 1
if timeout < 0:
return False, "Clear prefix tree timeout"
api_server_logger.info(
f"<<< finish clearing prefix tree (status: {self.prefix_tree_status_signal.value[0]})"
)
# model_weights_status_signal: NORMAL -> CLEARING -> CLEARED
if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED:
@@ -587,13 +618,30 @@ class EngineClient:
return False, "worker is updating model weight, cannot clear now"
self.model_weights_status_signal.value[0] = ModelWeightsStatus.CLEARING
api_server_logger.info(f"Start to clear model weight {self.model_weights_status_signal.value[0]}")
while timeout >= 0 and self.model_weights_status_signal.value[0] != ModelWeightsStatus.CLEARED:
api_server_logger.info(f"..clearing model weights {self.model_weights_status_signal.value[0]}")
api_server_logger.info(
f">>> start clearing model weight (weight status: {self.model_weights_status_signal.value[0]}"
if not self.enable_cache_transfer
else f">>> start clearing model weight (weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]})"
)
while timeout >= 0:
api_server_logger.info(
f"... weight status: {self.model_weights_status_signal.value[0]}"
if not self.enable_cache_transfer
else f"... weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]}"
)
weight_cleared = self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED
cache_cleared = self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARED
if weight_cleared and (not self.enable_cache_transfer or cache_cleared):
break
time.sleep(1)
timeout -= 1
if timeout < 0:
return False, "Clear model weight timeout"
api_server_logger.info(
f"<<< finish clearing model weight (weight status: {self.model_weights_status_signal.value[0]}"
if not self.enable_cache_transfer
else f"<<< finish clearing model weight (weight status: {self.model_weights_status_signal.value[0]} cache status: {self.kv_cache_status_signal.value[0]})"
)
return True, ""
def check_model_weight_status(self):
+17 -7
View File
@@ -24,7 +24,7 @@ import paddle
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.inter_communicator import ModelWeightsStatus
from fastdeploy.inter_communicator import KVCacheStatus, ModelWeightsStatus
class DynamicWeightManager:
@@ -258,18 +258,25 @@ class DynamicWeightManager:
value[self.rank] = status
@staticmethod
def check_model_weights_status(model_weights_status, model_runner, pid, block):
def check_model_weights_status(model_weights_status, kv_cache_status, model_runner, pid, block):
"""
check model weights status
A function to handle the state of model weights, check the model weights state,
and perform corresponding operations as needed.
- model_weights_status (`IPCSignal`): The signal indicating the status of model weights.
- kv_cache_status (`IPCSignal`): The signal indicating the status of key-value cache.
- model_runner (`ModelRunnerBase`): The model runner instance.
- block (`bool`): Block mode keeps the worker process blocked in the status-check loop,
avoiding communication operations in the worker event loop.
"""
# logger.info(f"dynamic weight manager is check model weights status! {model_weights_status.value[0]}")
logger.info(f"dynamic weight manager is check model weights status! {model_weights_status.value[0]}")
while model_weights_status.value[0] != ModelWeightsStatus.NORMAL and (
block or model_weights_status.value[0] != ModelWeightsStatus.CLEARED
):
# 如果为 block 模式,那么循环不会退出,直到权重更新、通信组重建
# 如果为非 block 模式,那么循环在权重更新或清理后均会退出
if model_weights_status.value[0] == ModelWeightsStatus.UPDATING:
logger.info("infer engine stopped! start to load new checkpoint...")
if kv_cache_status:
kv_cache_status.value[0] = KVCacheStatus.UPDATING
model_runner.clear_requests()
model_runner.update_parameters(pid)
while model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
@@ -277,9 +284,12 @@ class DynamicWeightManager:
logger.info("finished loading new checkpoint")
elif model_weights_status.value[0] == ModelWeightsStatus.CLEARING:
logger.info("infer engine stopped! start to clear checkpoint...")
if kv_cache_status:
kv_cache_status.value[0] = KVCacheStatus.CLEARING
model_runner.clear_requests()
model_runner.clear_parameters(pid)
while model_weights_status.value[0] != ModelWeightsStatus.CLEARED:
time.sleep(0.01)
logger.info("finished clearing checkpoint")
time.sleep(0.01)
else:
time.sleep(0.01)
+44 -9
View File
@@ -230,6 +230,16 @@ class PaddleDisWorkerProc:
create=False,
)
# init kv_cache_status
kv_cache_status_data = np.zeros(shape=[1], dtype=np.int32)
self.kv_cache_status = IPCSignal(
name="kv_cache_status",
array=kv_cache_status_data,
dtype=np.int32,
suffix=self.parallel_config.local_engine_worker_queue_port,
create=False,
)
# init exist_task_signal
workers_exist_task = np.zeros([1], dtype=np.int32)
self.exist_task_signal = IPCSignal(
@@ -421,8 +431,7 @@ class PaddleDisWorkerProc:
self._run_eplb(tp_rank)
if self.fd_config.load_config.dynamic_load_weight:
if self.model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
self.model_weights_signal[0] = int(self.model_weights_status.value[0])
self.model_weights_signal[0] = int(self.model_weights_status.value[0])
if self.ranks > 1:
self.model_weights_signal[0] = self._broadcast_model_weights_signal(src=0, group=None)
@@ -455,8 +464,10 @@ class PaddleDisWorkerProc:
)
self.model_weights_status.value[0] = self.model_weights_signal[0]
self.kv_cache_status.value[0] = self.model_weights_signal[0]
DynamicWeightManager.check_model_weights_status(
self.model_weights_status,
self.kv_cache_status if self.fd_config.cache_config.num_cpu_blocks > 0 else None,
# model_weights_signal
self.worker.model_runner,
self.parallel_config.local_engine_worker_queue_port,
@@ -464,14 +475,31 @@ class PaddleDisWorkerProc:
)
logger.info(f"current task queue data: {self.task_queue.num_tasks()}")
self.task_queue.clear_data()
self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")
# 只有不关闭通信组时,清理权重后需要额外等待(否则信号量会同步混乱)
if not self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle:
while self.model_weights_status.value[0] == ModelWeightsStatus.CLEARED:
time.sleep(0.01)
continue
if self.model_weights_signal[0] == ModelWeightsStatus.UPDATING:
logger.info(
f"Rank: {self.local_rank} has updated parameters. {self.model_weights_status.value[0]}"
)
self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
elif self.model_weights_signal[0] == ModelWeightsStatus.CLEARING:
logger.info(
f"Rank: {self.local_rank} has cleared parameters. {self.model_weights_status.value[0]}"
)
# 如果清理权重后不关闭通信组,那么将推理进程统一阻塞在下面的循环中,否则信号量可能同步混乱;直到下次权重更新时唤醒
if not self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle:
if self.ranks > 1: # 所有 Rank 同时入睡,监听下次的更新信号
paddle.distributed.barrier()
while self.model_weights_signal[0] != ModelWeightsStatus.UPDATING:
self.model_weights_signal[0] = self.model_weights_status.value[0]
if self.ranks > 1:
self.model_weights_signal[0] = self._broadcast_model_weights_signal(
src=0, group=None
)
time.sleep(1)
self.model_weights_status.value[0] = (
ModelWeightsStatus.UPDATING
) # 所有 Rank 已同步唤醒,启动权重更新流程
continue
if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
logger.info(f"Rank: {self.local_rank} Detected new requests.")
@@ -906,6 +934,13 @@ def parse_args():
help="Enable output of token-level entropy.",
)
parser.add_argument(
"--num_cpu_blocks",
type=int,
default=0,
help="Number of cpu blocks.",
)
args = parser.parse_args()
return args
+1
View File
@@ -38,6 +38,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
--gpu-memory-utilization 0.9 \
--model "$MODEL_PATH" \
--no-shutdown-comm-group-if-worker-idle \
--swap-space 10 \
--load-strategy ipc_snapshot \
--dynamic-load-weight &