[BugFix] move cache creation back to cache transfer process and adapt clear/update (#6144)

* [fix] move cache creation back to cache transfer process

* [fix] fix clear cache

* [chore] change some log level

* [fix] fix clear cache

* [fix] fix clear cache for blockwisefp8 and mtp

* [fix] fix c8

* [fix] fix clear_mtp_cache args

* [chore] update cache_transfer_manager

* [fix] fix update mtp cache
This commit is contained in:
Yonghua Li
2026-01-24 21:59:13 +08:00
committed by GitHub
parent 976203cf60
commit 833d00e2d7
5 changed files with 158 additions and 71 deletions
@@ -300,16 +300,6 @@ class CacheTransferManager:
def _init_gpu_cache(self, args):
try:
assert not args.create_cache_tensor
except:
logger.warn(
f"In current implementation, cache transfer manager do not create cache tensors at all, "
f"meaning create_cache_tensor should be False, while we got {args.create_cache_tensor}. "
f"Cache tensor creation will occur in: 1) model runner in case of mixed deployment; "
f"or 2) cache messager in case of disaggregation deployment. "
f"Please check the codes and make sure they work correctly."
)
if not args.create_cache_tensor:
logger.info(f"[rank {self.rank}/{self.n_ranks}] Waiting for runners or messagers to create kv cache.")
while self.cache_ready_signal.value[self.rank] != 1:
@@ -1039,14 +1029,15 @@ class CacheTransferManager:
# TODO XPU support RL
if unset_data_ipc is None:
return
logger.info("Start a thread to clear/restore kv cache when model weights are cleared/updated.")
logger.info("[RL] Launch a thread to clear/restore kv cache when model weights are cleared/updated.")
while True:
# handle cache clearing/restoring
if self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARING:
assert args.splitwise_role == "mixed", "Only mixed mode supports clearing cache."
try:
logger.info(f"Start clearing caches {self.cache_ready_signal.value}")
# clear cpu caches
logger.info("[RL] start clearing caches")
logger.debug("[RL] start clearing cpu caches")
if self.num_cpu_blocks > 0 and envs.FD_ENABLE_SWAP_SPACE_CLEARING:
paddle.set_device("cpu")
for ptrs in self.k_dst_ptrs + self.v_dst_ptrs:
@@ -1054,62 +1045,102 @@ class CacheTransferManager:
self.cpu_cache_kvs.clear()
self.k_dst_ptrs.clear()
self.v_dst_ptrs.clear()
if self.cache_dtype == "block_wise_fp8":
self.k_scales_ptrs.clear()
self.v_scales_ptrs.clear()
gc.collect()
logger.debug("[RL] successfully cleared cpu caches")
# reset swap_space_ready_signal
self.swap_space_ready_signal.value[self.rank] = 0
while np.sum(self.swap_space_ready_signal.value) != 0:
time.sleep(0.1)
logger.debug("[RL] all ranks cleared cpu caches")
else:
logger.debug("[RL] skip clearing cpu caches")
# clear gpu caches
set_device(self.device)
for name, tensor in self.gpu_cache_kvs.items():
unset_data_ipc(tensor, name, True, False)
self.gpu_cache_kvs.clear()
self.gpu_cache_k_tensors.clear()
self.gpu_cache_v_tensors.clear()
logger.debug("[RL] start clearing gpu caches")
if args.create_cache_tensor:
logger.info("[RL] waiting for gpu runner to unlink cuda ipc")
while self.cache_ready_signal.value[self.rank] != 0:
time.sleep(0.1)
logger.info("[RL] stop waiting! gpu runner has unlinked cuda ipc")
paddle.set_device(f"gpu:{self.device}")
self.gpu_cache_kvs.clear()
self.gpu_cache_k_tensors.clear()
self.gpu_cache_v_tensors.clear()
if self.cache_dtype == "block_wise_fp8":
self.gpu_cache_scales_k_tensors.clear()
self.gpu_cache_scales_v_tensors.clear()
paddle.device.cuda.empty_cache()
logger.debug("[RL] successfully cleared gpu caches")
else:
for name, tensor in self.gpu_cache_kvs.items():
unset_data_ipc(tensor, name, True, False)
logger.debug("[RL] successfully unlinked gpu caches cuda ipc")
self.cache_ready_signal.value[self.rank] = 0
# reset cache_ready_signal
self.cache_ready_signal.value[self.rank] = 0
logger.info(f"Finish clearing caches {self.cache_ready_signal.value}")
# wait for all ranks caches to be cleared
if np.sum(self.cache_ready_signal.value) != 0:
while np.sum(self.cache_ready_signal.value) != 0:
time.sleep(0.1)
logger.info("[RL] all ranks cleared caches!")
# reset kv_cache_status_signal
self.kv_cache_status_signal.value[0] = KVCacheStatus.CLEARED
logger.info(f"All ranks finish clearing caches {self.cache_ready_signal.value}")
self._log_memory("after clearing caches")
except Exception as e:
logger.error(f"Failed to clear caches: {e}")
logger.error(f"[RL] failed to clear caches: {e}")
elif self.kv_cache_status_signal.value[0] == KVCacheStatus.UPDATING:
assert args.splitwise_role == "mixed", "Only mixed mode supports updating cache."
try:
logger.info(f"Start restoring caches {self.cache_ready_signal.value}")
# restore cpu cache
logger.info("[RL] start restoring caches")
logger.debug("[RL] start restoring cpu caches")
if self.num_cpu_blocks > 0 and envs.FD_ENABLE_SWAP_SPACE_CLEARING:
self._init_cpu_cache(args)
logger.debug("[RL] successfully restored cpu caches")
while np.sum(self.swap_space_ready_signal.value) != args.mp_num:
time.sleep(0.1)
logger.debug("[RL] all ranks restored cpu caches")
else:
logger.debug("[RL] skip restoring cpu caches")
# restore gpu cache and set cache_ready_signal
logger.debug("[RL] start restoring gpu caches")
self._init_gpu_cache(args)
logger.info(f"Finish restoring caches {self.cache_ready_signal.value}")
logger.debug("[RL] successfully restored gpu caches")
# wait for all ranks caches to be ready
while np.sum(self.cache_ready_signal.value) != args.mp_num:
time.sleep(0.1)
logger.info("[RL] all ranks restored caches!")
# set kv_cache_status_signal
logger.info(f"All ranks finish restoring caches {self.cache_ready_signal.value}")
self.kv_cache_status_signal.value[0] = KVCacheStatus.NORMAL
self._log_memory("after restoring caches")
except Exception as e:
logger.error(f"Failed to restore caches: {e}")
logger.error(f"[RL] failed to restore caches: {e}")
time.sleep(0.1)
def _log_memory(self, context: str):
"""Log current GPU memory usage."""
max_alloc = paddle.device.cuda.max_memory_allocated() / (1024**3)
max_reserved = paddle.device.cuda.max_memory_reserved() / (1024**3)
curr_alloc = paddle.device.cuda.memory_allocated() / (1024**3)
curr_reserved = paddle.device.cuda.memory_reserved() / (1024**3)
logger.warning(
f"GPU memory usage {context}:"
f"max_allocated: {max_alloc:.2f}GB "
f"max_reserved: {max_reserved:.2f}GB "
f"current_allocated: {curr_alloc:.2f}GB "
f"current_reserved: {curr_reserved:.2f}GB"
)
def main():
"""