[Cherry-Pick] [Feature] support v1 update/clear api for RL (#6761) (#6974)

* [Feature] support v1 update/clear api for RL

* [fix] fix stale control responses when control method timed out

* [chore] remove unused code

* [chore] optimize tags and key_prefix

* [test] fix ci

* [chore] fix code style

* [fix] fix ep control

* [fix] fix ep control for engine cache queue
Yonghua Li
2026-03-25 19:18:35 +08:00
committed by GitHub
parent 49c2310854
commit 35034f91fa
25 changed files with 1665 additions and 328 deletions
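In brief: the engine can now push pause/resume/sleep/wakeup control requests to every cache-transfer rank through the existing EngineCacheQueue, and each rank reports the outcome back over a per-rank FMQ queue (see _init_control and control_task below). A minimal sketch of the receiving side, built only from pieces visible in this diff (ControlResponse, the FMQ queue factory, and the ctrl_c2e_* naming scheme); the "consumer" role string and the async get() are assumptions for illustration:

import asyncio

from fastdeploy.inter_communicator.fmq import FMQ

async def wait_for_control_response(dp_rank: int, tp_rank: int, tp_size: int, cache_queue_port: int):
    # Mirrors the producer name built in _init_control below.
    name = f"ctrl_c2e_rank{tp_rank + tp_size * dp_rank}_{cache_queue_port}"
    ctrl_queue = FMQ().queue(name, "consumer")  # "consumer" role is an assumption
    resp = await ctrl_queue.get()  # assumed async get(), mirroring the async put() in control_task
    return resp  # ControlResponse(request_id, error_code, error_message); 200 = success

# e.g. resp = asyncio.run(wait_for_control_response(0, 0, 4, 8733))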
@@ -15,6 +15,7 @@
"""
import argparse
import asyncio
import concurrent.futures
import gc
import json
@@ -35,7 +36,6 @@ from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask
from fastdeploy.cache_manager.ops import (
cuda_host_alloc,
cuda_host_free,
memory_allocated,
set_data_ipc,
set_device,
share_external_data_,
@@ -49,7 +49,9 @@ from fastdeploy.cache_manager.transfer_factory import (
MooncakeStore,
)
from fastdeploy.config import CacheConfig, SpeculativeConfig
from fastdeploy.engine.request import ControlRequest, ControlResponse
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
from fastdeploy.inter_communicator.fmq import FMQ
from fastdeploy.platforms import current_platform
from fastdeploy.utils import console_logger, get_logger
@@ -96,7 +98,6 @@ def parse_args():
help="engine worker queue port",
)
parser.add_argument("--num_cpu_blocks", type=int, default=4, help="cpu cache block number")
parser.add_argument("--ipc_suffix", type=str, default=None, help="engine pid")
parser.add_argument(
"--protocol",
type=str,
@@ -184,15 +185,17 @@ class CacheTransferManager:
self.key_prefix = ""
# extract other arg values
self.model_path = args.model_path
self.model_id = os.path.basename(args.model_path.rstrip("/"))
self.n_ranks = args.mp_num
self.rank = args.rank
self.device = args.device_id
self.num_layers = args.num_layers
self.ipc_suffix = args.ipc_suffix
self.create_cache_tensor = args.create_cache_tensor
self.local_data_parallel_id = args.local_data_parallel_id
self.num_extra_layers = self.speculative_config.num_extra_cache_layer
self.num_extra_layer_gpu_blocks = int(self.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
self.cache_queue_port = args.cache_queue_port
paddle.set_default_dtype(args.default_dtype)
self.swap_to_cpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
@@ -200,8 +203,10 @@ class CacheTransferManager:
self.read_storage_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.write_back_storage_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.timeout_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)
self.control_task_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.transfer_task_queue = queue.Queue() # queue for receiving transfer tasks
self.tansfer_done_queue = queue.Queue() # queue for reporting finished tasks
self.ctrl_output_queue = None
address = (args.pod_ip, args.cache_queue_port)
self.cache_task_queue = EngineCacheQueue(
@@ -209,7 +214,7 @@ class CacheTransferManager:
is_server=False,
num_client=args.mp_num,
client_id=self.rank,
local_data_parallel_id=args.local_data_parallel_id,
local_data_parallel_id=0,
)
cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
@@ -231,11 +236,12 @@ class CacheTransferManager:
self.num_cpu_blocks = args.num_cpu_blocks
self._init_gpu_cache(args)
self._init_gpu_cache()
if self.num_cpu_blocks > 0:
self._init_cpu_cache(args)
self._init_cpu_cache()
if self.storage_backend_type is not None:
self._init_storage(args)
self._init_control()
cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
self.cache_task_broadcast_signal = IPCSignal(
@@ -287,7 +293,9 @@ class CacheTransferManager:
threading.Thread(target=self.check_cache_status, args=[args], daemon=True).start()
self.is_paused = False # transfer manager state
self.is_sleeping = False
self.inflight = 0 # number of inflight transfer tasks
self.inflight_tasks = {}
cache_transfer_inited_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
self.cache_transfer_inited_signal = IPCSignal(
@@ -299,6 +307,15 @@ class CacheTransferManager:
)
self.cache_transfer_inited_signal.value[self.rank] = 1
def _init_control(self):
dp_rank = self.local_data_parallel_id
tp_rank = self.rank
tp_size = self.n_ranks
cq_port = self.cache_queue_port
name = f"ctrl_c2e_rank{tp_rank+tp_size*dp_rank}_{cq_port}"
self.ctrl_output_queue = FMQ().queue(name, "producer")
logger.info(f"Init control output queue: {name} (producer)")
def _init_storage(self, args):
try:
# TODO: support cache scale for other backend
@@ -354,14 +371,19 @@ class CacheTransferManager:
raise ValueError(f"Invalid write policy: {args.write_policy}")
self.write_policy = args.write_policy
self.key_prefix = ""
version_file_path = os.path.join(args.model_path, "version.yaml")
if os.path.exists(version_file_path):
self.key_prefix = get_key_prefix_from_version(version_file_path)
logger.info(f"The key_prefix of cache storage is {self.key_prefix}")
self._update_key_prefix()
logger.info("Initialize cache storage successfully")
def _update_key_prefix(self):
# use key_prefix to distinguish cache for different version of weight in rl
version_file_path = os.path.join(self.model_path, "version.yaml")
if os.path.exists(version_file_path):
self.key_prefix = get_key_prefix_from_version(version_file_path)
logger.info(f"Update key_prefix of cache storage to {self.key_prefix}")
else:
logger.error(f"version.yaml not found at {version_file_path}")
def _init_storage_buffer(self, args):
"""
Initialize pinned memory buffer that can hold the cache for a longest request
@@ -410,20 +432,20 @@ class CacheTransferManager:
self.storage_value_scale_write_buffer = write_buffer + scale_buffer_total_bytes // 2
self.storage_backend.register_buffer(write_buffer, scale_buffer_total_bytes)
def _init_gpu_cache(self, args):
def _init_gpu_cache(self):
if not args.create_cache_tensor:
logger.info(f"[rank {self.rank}/{self.n_ranks}] Waiting for runners or messagers to create kv cache.")
if not self.create_cache_tensor:
logger.info("Waiting for runners or messagers to create kv cache.")
while self.cache_ready_signal.value[self.rank] != 1:
time.sleep(0.1)
logger.info(f"[rank {self.rank}/{self.n_ranks}] OK! Stop waiting.")
logger.info("OK! Stop waiting.")
if args.cache_dtype == "block_wise_fp8":
if self.cache_dtype == "block_wise_fp8":
cache_type = "uint8"
else:
cache_type = args.cache_dtype
cache_type = self.cache_dtype
logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.")
logger.info("Initializing kv cache for all layers.")
set_device(self.device)
for i in range(self.num_layers + self.num_extra_layers):
# NOTE: num_extra_layer_gpu_blocks is usually equal to num_gpu_blocks
@@ -446,14 +468,12 @@ class CacheTransferManager:
self.value_cache_shape[2],
self.value_cache_shape[3],
]
if args.create_cache_tensor:
logger.info(
f"[rank {self.rank}/{self.n_ranks}] ..creating kv cache for layer {i}: {key_cache_shape} {value_cache_shape}"
)
if self.create_cache_tensor:
logger.info(f"..creating kv cache for layer {i}: {key_cache_shape} {value_cache_shape}")
key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_type)
set_data_ipc(key_cache, key_name)
if args.cache_dtype == "block_wise_fp8":
if self.cache_dtype == "block_wise_fp8":
key_cache_scales = paddle.full(
shape=[num_gpu_blocks, self.key_cache_shape[1], self.key_cache_shape[2]],
fill_value=0,
@@ -464,7 +484,7 @@ class CacheTransferManager:
val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_type)
set_data_ipc(val_cache, val_name)
if args.cache_dtype == "block_wise_fp8":
if self.cache_dtype == "block_wise_fp8":
value_cache_scales = paddle.full(
shape=[num_gpu_blocks, self.value_cache_shape[1], self.value_cache_shape[2]],
fill_value=0,
@@ -472,13 +492,11 @@ class CacheTransferManager:
)
set_data_ipc(value_cache_scales, value_cache_scales_name)
else:
logger.info(
f"[rank {self.rank}/{self.n_ranks}] ..attaching kv cache for layer {i}: {key_cache_shape} {value_cache_shape}"
)
logger.info(f"..attaching kv cache for layer {i}: {key_cache_shape} {value_cache_shape}")
key_cache = paddle.empty(shape=[], dtype=cache_type)
val_cache = paddle.empty(shape=[], dtype=cache_type)
key_cache = share_external_data_(key_cache, key_name, key_cache_shape, True)
if args.cache_dtype == "block_wise_fp8":
if self.cache_dtype == "block_wise_fp8":
key_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
key_cache_scales = share_external_data_(
key_cache_scales,
@@ -488,7 +506,7 @@ class CacheTransferManager:
)
if self.value_cache_shape:
val_cache = share_external_data_(val_cache, val_name, value_cache_shape, True)
if args.cache_dtype == "block_wise_fp8":
if self.cache_dtype == "block_wise_fp8":
value_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
value_cache_scales = share_external_data_(
value_cache_scales,
@@ -499,48 +517,65 @@ class CacheTransferManager:
self.gpu_cache_kvs[key_name] = key_cache
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name])
if args.cache_dtype == "block_wise_fp8":
if self.cache_dtype == "block_wise_fp8":
self.gpu_cache_kvs[key_cache_scales_name] = key_cache_scales
self.gpu_cache_scales_k_tensors.append(self.gpu_cache_kvs[key_cache_scales_name])
if args.value_cache_shape:
if self.value_cache_shape:
self.gpu_cache_kvs[val_name] = val_cache
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[val_name])
if args.cache_dtype == "block_wise_fp8":
if self.cache_dtype == "block_wise_fp8":
self.gpu_cache_kvs[value_cache_scales_name] = value_cache_scales
self.gpu_cache_scales_v_tensors.append(self.gpu_cache_kvs[value_cache_scales_name])
if args.create_cache_tensor:
logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ kv cache is ready!")
if self.create_cache_tensor:
self.cache_ready_signal.value[self.rank] = 1
while np.sum(self.cache_ready_signal.value) != self.n_ranks:
time.sleep(0.1)
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
logger.info(f"[rank {self.rank}/{self.n_ranks}] device :{self.device}")
logger.info(f"[rank {self.rank}/{self.n_ranks}] cache_kv_size_byte : {cache_kv_size_byte}")
logger.info(f"[rank {self.rank}/{self.n_ranks}] done init cache (full) gmem alloc : {memory_allocated()}")
logger.info("GPU KV cache is initialized")
def _init_cpu_cache(self, args):
def _clear_gpu_cache(self):
if self.create_cache_tensor:
logger.debug("Waiting for gpu runner to unlink cuda ipc")
while self.cache_ready_signal.value[self.rank] != 0:
time.sleep(0.1)
logger.debug("Stop waiting! gpu runner has unlinked cuda ipc")
self.gpu_cache_kvs.clear()
self.gpu_cache_k_tensors.clear()
self.gpu_cache_v_tensors.clear()
if hasattr(self, "gpu_cache_scales_k_tensors"):
self.gpu_cache_scales_k_tensors.clear()
if hasattr(self, "gpu_cache_scales_v_tensors"):
self.gpu_cache_scales_v_tensors.clear()
paddle.device.cuda.empty_cache()
else:
for name, tensor in self.gpu_cache_kvs.items():
unset_data_ipc(tensor, name, True, False)
logger.debug("Successfully unlinked gpu caches cuda ipc")
self.cache_ready_signal.value[self.rank] = 0
while np.sum(self.cache_ready_signal.value) != 0:
time.sleep(0.1)
logger.info("All ranks cleared gpu caches")
def _init_cpu_cache(self):
if self.num_cpu_blocks == 0:
return
paddle.set_device("cpu")
key_cache_size = self.key_cache_shape[1] * self.key_cache_shape[2] * self.key_cache_shape[3]
if args.value_cache_shape:
if self.value_cache_shape:
value_cache_size = self.value_cache_shape[1] * self.value_cache_shape[2] * self.value_cache_shape[3]
else:
value_cache_size = 0
cache_item_bytes = CacheConfig.get_cache_bytes(self.cache_dtype)
key_need_to_allocate_bytes = args.num_cpu_blocks * cache_item_bytes * key_cache_size
value_need_to_allocate_bytes = args.num_cpu_blocks * cache_item_bytes * value_cache_size
if args.cache_dtype == "block_wise_fp8":
key_need_to_allocate_bytes = self.num_cpu_blocks * cache_item_bytes * key_cache_size
value_need_to_allocate_bytes = self.num_cpu_blocks * cache_item_bytes * value_cache_size
logger.info("Initializing swap space (cpu cache) for all layers.")
if self.cache_dtype == "block_wise_fp8":
cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
cache_scales_size = self.key_cache_shape[1] * self.key_cache_shape[2]
scales_key_need_to_allocate_bytes = args.num_cpu_blocks * cache_scales.element_size() * cache_scales_size
scales_value_need_to_allocate_bytes = args.num_cpu_blocks * cache_scales.element_size() * cache_scales_size
logger.info(
f"[rank {self.rank}/{self.n_ranks}] ..swap space size : {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB"
)
if args.num_cpu_blocks == 0:
logger.info(f"[rank {self.rank}/{self.n_ranks}] 💡 no swap space (cpu cache) is specified.")
self.swap_space_ready_signal.value[self.rank] = 1
return
logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing swap space (cpu cache) for all layers.")
paddle.set_device("cpu")
scales_key_need_to_allocate_bytes = self.num_cpu_blocks * cache_scales.element_size() * cache_scales_size
scales_value_need_to_allocate_bytes = self.num_cpu_blocks * cache_scales.element_size() * cache_scales_size
self.k_dst_ptrs = []
self.v_dst_ptrs = []
self.k_scales_ptrs = []
@@ -551,22 +586,43 @@ class CacheTransferManager:
key_cache_scales_name = f"key_cache_scales_{i}_rank{self.rank}"
value_cache_scales_name = f"value_cache_scales_{i}_rank{self.rank}"
logger.info(
f"[rank {self.rank}/{self.n_ranks}] ..creating cpu cache for layer {i}: {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB"
f"..creating cpu cache for layer {i}: {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB"
)
self.cpu_cache_kvs[key_name] = cuda_host_alloc(key_need_to_allocate_bytes)
self.k_dst_ptrs.append(self.cpu_cache_kvs[key_name])
if args.cache_dtype == "block_wise_fp8":
if self.cache_dtype == "block_wise_fp8":
self.cpu_cache_kvs[key_cache_scales_name] = cuda_host_alloc(scales_key_need_to_allocate_bytes)
self.k_scales_ptrs.append(self.cpu_cache_kvs[key_cache_scales_name])
if value_need_to_allocate_bytes > 0:
self.cpu_cache_kvs[val_name] = cuda_host_alloc(value_need_to_allocate_bytes)
self.v_dst_ptrs.append(self.cpu_cache_kvs[val_name])
if args.cache_dtype == "block_wise_fp8":
if self.cache_dtype == "block_wise_fp8":
self.cpu_cache_kvs[value_cache_scales_name] = cuda_host_alloc(scales_value_need_to_allocate_bytes)
self.v_scales_ptrs.append(self.cpu_cache_kvs[value_cache_scales_name])
logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ swap space (cpu cache) is ready!")
logger.info("Swap space (cpu cache) is ready!")
self.swap_space_ready_signal.value[self.rank] = 1
while np.sum(self.swap_space_ready_signal.value) != self.n_ranks:
time.sleep(0.1)
logger.info("All ranks init cpu caches")
def _clear_cpu_cache(self):
for ptrs in self.k_dst_ptrs + self.v_dst_ptrs:
cuda_host_free(ptrs)
self.cpu_cache_kvs.clear()
self.k_dst_ptrs.clear()
self.v_dst_ptrs.clear()
if hasattr(self, "k_scales_ptrs"):
self.k_scales_ptrs.clear()
if hasattr(self, "v_scales_ptrs"):
self.v_scales_ptrs.clear()
gc.collect()
self.swap_space_ready_signal.value[self.rank] = 0
while np.sum(self.swap_space_ready_signal.value) != 0:
time.sleep(0.1)
logger.info("All ranks cleared cpu caches")
def _run_read_storage(
self,
task_id: str,
@@ -1024,6 +1080,79 @@ class CacheTransferManager:
logger.debug(f"_do_swap_to_gpu_task: put_transfer_done_signal {result}")
logger.info(f"_do_swap_to_gpu_task: put_transfer_done_signal for transfer_task_id {transfer_task_id}")
def _handle_pause(self):
if self.is_paused:
logger.info("💡 Cache transfer manager is already paused, no need to pause again!")
else:
self.pause()
logger.info("✅ Successfully paused transfer")
return True
def _handle_resume(self):
if not self.is_paused:
logger.info("💡 Cache transfer manager is not paused, no need to resume!")
else:
self.resume()
if self.storage_backend_type is not None:
self._update_key_prefix()
logger.info("✅ Successfully resumed transfer")
return True
def _handle_sleep(self):
if self.is_sleeping:
logger.info("💡 Cache transfer manager is already sleeping, no need to sleep again!")
else:
if self.num_cpu_blocks > 0 and envs.FD_ENABLE_SWAP_SPACE_CLEARING:
self._clear_cpu_cache()
self._clear_gpu_cache()
self.is_sleeping = True
logger.info("✅ Successfully fell asleep (offloaded caches)")
return True
def _handle_wakeup(self):
if not self.is_sleeping:
logger.info("💡 Cache transfer manager is not sleeping, no need to wakeup!")
else:
if self.num_cpu_blocks > 0 and envs.FD_ENABLE_SWAP_SPACE_CLEARING:
self._init_cpu_cache()
self._init_gpu_cache()
self.is_sleeping = False
logger.info("✅ Successfully wakeup (reload caches)")
return True
def control_task(self, task: ControlRequest):
method = task.get_method()
tags = task.args.get("tags", {})
logger.info(f"Received control task: {method}, tags: {tags}")
handlers = {
"pause": self._handle_pause,
"resume": self._handle_resume,
"sleep": self._handle_sleep,
"wakeup": self._handle_wakeup,
}
handler = handlers.get(method)
error_code = 200
error_message = "Success"
if handler:
try:
handler()
except Exception as e:
error_code = 500
error_message = f"Failed to execute {method}: {str(e)}"
logger.error(f"Error in control_task: {traceback.format_exc()}")
else:
error_code = 400
error_message = f"Unknown control method: {method}"
logger.warning(error_message)
self.cache_task_queue.barrier.wait()
resp = ControlResponse(task.request_id, error_code, error_message)
asyncio.run(self.ctrl_output_queue.put(resp))
logger.info(f"Put response into output queue {self.ctrl_output_queue.name}: {resp}")
def check_work_status(self, time_interval_threashold=envs.FD_CACHE_PROC_EXIT_TIMEOUT):
"""
Check the health of the model server by checking whether all workers are alive.
@@ -1043,8 +1172,13 @@ class CacheTransferManager:
return fn(*args)
finally:
self.inflight -= 1
logger.debug(f"submit_task: {fn.__name__} finished, args: {args}, current inflight: {self.inflight}")
self.inflight += 1
thread_pool.submit(inflight_task, task_fn, *args)
logger.debug(
f"submit_task: {task_fn.__name__} submitted to thread pool, args: {args}, current inflight: {self.inflight}"
)
def do_data_transfer(self):
"""
@@ -1055,6 +1189,7 @@ class CacheTransferManager:
max_errors = (
envs.FD_CACHE_PROC_ERROR_COUNT
) # After this many consecutive errors, check if the worker process exists.
is_paused = False
while True:
try:
@@ -1066,18 +1201,18 @@ class CacheTransferManager:
self.cache_task_queue.barrier0.reset()
# Ensure all ranks synchronically do one of the following things:
# (1) If rank#0 is paused, wait for a short time and check out rank#0 status again;
# (2) otherwise, all ranks are allowed to pull tasks from cache task queue
# (1) If rank#0 is paused, wait for inflight tasks to finish first, then only process control tasks;
# (2) otherwise, all ranks are allowed to pull all tasks from cache task queue
if self.cache_task_is_paused_signal.value[0] == 1:
# wait for inflight tasks to finish first
while self.inflight != 0:
time.sleep(0.1)
# mark the current rank as not having inflight tasks
self.cache_task_inflight_signal.value[self.rank] = 0
time.sleep(1)
continue
is_paused = True
else:
self.cache_task_inflight_signal.value[self.rank] = 1
is_paused = False
if self.rank == 0:
if not self.cache_task_queue.empty():
@@ -1088,12 +1223,16 @@ class CacheTransferManager:
self.cache_task_queue.barrier1.reset()
if self.cache_task_broadcast_signal.value[0] == 1:
self.inflight += 1
data, read_finish = self.cache_task_queue.get_transfer_task()
logger.debug(f"do_data_transfer: {data}")
if read_finish:
self.cache_task_broadcast_signal.value[0] = 0
event_type, event_args = data[0], data[1:]
# control task is the only task allowed to execute when loop is paused
if is_paused and event_type.value != CacheStatus.CTRL.value:
continue
if event_type.value == CacheStatus.SWAP2CPU.value:
transfer_task_id, swap_node_ids, gpu_block_id, cpu_block_id = event_args
self.submit_task(
@@ -1130,6 +1269,13 @@ class CacheTransferManager:
self.write_back_storage_task,
write_storage_task,
)
elif event_type.value == CacheStatus.CTRL.value:
control_task = event_args[0]
self.control_task_thread_pool.submit(
self.control_task,
control_task,
)
else:
if self.n_ranks > 1:
self.cache_task_queue.barrier2.wait()
@@ -1292,115 +1438,46 @@ class CacheTransferManager:
# TODO XPU support RL
if unset_data_ipc is None:
return
logger.info("[RL] Launch a thread to clear/restore kv cache when model weights are cleared/updated.")
logger.info(
"check_cache_status: Launch a thread to clear/restore kv cache when model weights are cleared/updated."
)
while True:
# handle cache clearing/restoring
if self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARING:
assert args.splitwise_role == "mixed", "Only mixed mode supports clearing cache."
try:
# wait for inflight transfer tasks to finish and pause transfer manager
# pause transfer
self.pause()
# clear cpu caches
logger.info("[RL] start clearing caches")
logger.debug("[RL] start clearing cpu caches")
# clear caches
logger.info("check_cache_status: start clearing caches")
if self.num_cpu_blocks > 0 and envs.FD_ENABLE_SWAP_SPACE_CLEARING:
paddle.set_device("cpu")
for ptrs in self.k_dst_ptrs + self.v_dst_ptrs:
cuda_host_free(ptrs)
self.cpu_cache_kvs.clear()
self.k_dst_ptrs.clear()
self.v_dst_ptrs.clear()
if self.cache_dtype == "block_wise_fp8":
self.k_scales_ptrs.clear()
self.v_scales_ptrs.clear()
gc.collect()
logger.debug("[RL] successfully cleared cpu caches")
# reset swap_space_ready_signal
self.swap_space_ready_signal.value[self.rank] = 0
while np.sum(self.swap_space_ready_signal.value) != 0:
time.sleep(0.1)
logger.debug("[RL] all ranks cleared cpu caches")
else:
logger.debug("[RL] skip clearing cpu caches")
# clear gpu caches
logger.debug("[RL] start clearing gpu caches")
if args.create_cache_tensor:
logger.info("[RL] waiting for gpu runner to unlink cuda ipc")
while self.cache_ready_signal.value[self.rank] != 0:
time.sleep(0.1)
logger.info("[RL] stop waiting! gpu runner has unlinked cuda ipc")
paddle.set_device(f"gpu:{self.device}")
paddle.set_flags({"FLAGS_selected_gpus": f"{self.device}"})
self.gpu_cache_kvs.clear()
self.gpu_cache_k_tensors.clear()
self.gpu_cache_v_tensors.clear()
if self.cache_dtype == "block_wise_fp8":
self.gpu_cache_scales_k_tensors.clear()
self.gpu_cache_scales_v_tensors.clear()
paddle.device.cuda.empty_cache()
logger.debug("[RL] successfully cleared gpu caches")
else:
for name, tensor in self.gpu_cache_kvs.items():
unset_data_ipc(tensor, name, True, False)
logger.debug("[RL] successfully unlinked gpu caches cuda ipc")
self.cache_ready_signal.value[self.rank] = 0
while np.sum(self.cache_ready_signal.value) != 0:
time.sleep(0.1)
logger.info("[RL] all ranks cleared caches!")
# reset kv_cache_status_signal
self._clear_cpu_cache()
self._clear_gpu_cache()
self.kv_cache_status_signal.value[0] = KVCacheStatus.CLEARED
self._log_memory("after clearing caches")
self._log_memory("check_cache_status: after clearing caches")
except Exception as e:
logger.error(f"[RL] failed to clear caches: {e}")
logger.error(f"check_cache_status: failed to clear caches: {e}")
elif self.kv_cache_status_signal.value[0] == KVCacheStatus.UPDATING:
assert args.splitwise_role == "mixed", "Only mixed mode supports updating cache."
try:
# restore cpu cache
logger.info("[RL] start restoring caches")
logger.debug("[RL] start restoring cpu caches")
logger.info("check_cache_status: start restoring caches")
if self.num_cpu_blocks > 0 and envs.FD_ENABLE_SWAP_SPACE_CLEARING:
self._init_cpu_cache(args)
logger.debug("[RL] successfully restored cpu caches")
while np.sum(self.swap_space_ready_signal.value) != args.mp_num:
time.sleep(0.1)
logger.debug("[RL] all ranks restored cpu caches")
else:
logger.debug("[RL] skip restoring cpu caches")
# restore gpu cache and set cache_ready_signal
logger.debug("[RL] start restoring gpu caches")
self._init_gpu_cache(args)
logger.debug("[RL] successfully restored gpu caches")
self._init_cpu_cache()
self._init_gpu_cache()
# update key prefix for kv cache backend
if self.storage_backend_type is not None:
# use key_prefix to distinguish cache for different version of weight in rl
version_file_path = os.path.join(args.model_path, "version.yaml")
assert os.path.exists(version_file_path), f"version.yaml not found at {version_file_path}"
self.key_prefix = get_key_prefix_from_version(version_file_path)
logger.info(f"Update key_prefix of cache storage to {self.key_prefix}")
# wait for all ranks caches to be ready
while np.sum(self.cache_ready_signal.value) != args.mp_num:
time.sleep(0.1)
logger.info("[RL] all ranks restored caches!")
self._update_key_prefix()
# resume transfer
self.resume()
# set kv_cache_status_signal
self.kv_cache_status_signal.value[0] = KVCacheStatus.NORMAL
self._log_memory("after restoring caches")
self._log_memory("check_cache_status: after restoring caches")
except Exception as e:
logger.error(f"[RL] failed to restore caches: {e}")
logger.error(f"check_cache_status: failed to restore caches: {e}")
time.sleep(0.1)
@@ -1409,11 +1486,11 @@ class CacheTransferManager:
self.cache_task_queue.pause_barrier.wait()
if self.rank == 0:
self.cache_task_queue.pause_barrier.reset()
logger.info("[RL] 🟠 wait for inflight transfer tasks to finish")
logger.info("pause: 🟠 wait for inflight transfer tasks to finish")
self.is_paused = True
while np.sum(self.cache_task_inflight_signal.value) != 0:
time.sleep(0.1)
logger.info("[RL] 🔴 pause transfer manager and stop do transfer tasks")
logger.info("pause: 🔴 pause transfer manager and stop do transfer tasks")
def resume(self):
if self.n_ranks > 1:
@@ -1423,7 +1500,7 @@ class CacheTransferManager:
self.is_paused = False
while np.sum(self.cache_task_inflight_signal.value) != self.n_ranks:
time.sleep(0.1)
logger.info("[RL] 🟢 resume transfer manager and start to do transfer tasks")
logger.info("resume: 🟢 resume transfer manager and start to do transfer tasks")
def _log_memory(self, context: str):
"""Log current GPU memory usage."""
@@ -1432,8 +1509,8 @@ class CacheTransferManager:
curr_alloc = paddle.device.cuda.memory_allocated() / (1024**3)
curr_reserved = paddle.device.cuda.memory_reserved() / (1024**3)
logger.warning(
f"GPU memory usage {context}:"
logger.info(
f"{context}: "
f"max_allocated: {max_alloc:.2f}GB "
f"max_reserved: {max_reserved:.2f}GB "
f"current_allocated: {curr_alloc:.2f}GB "