* [Feature] support v1 update/clear api for RL
* [fix] fix stale control responses when control method timed out
* [chore] remove unused code
* [chore] optimize tags and key_prefix
* [test] fix ci
* [chore] fix code style
* [fix] fix ep control
* [fix] fix ep control for engine cache queue
@@ -15,6 +15,7 @@
 """
 import argparse
+import asyncio
 import concurrent.futures
 import gc
 import json
@@ -35,7 +36,6 @@ from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask
 from fastdeploy.cache_manager.ops import (
     cuda_host_alloc,
     cuda_host_free,
-    memory_allocated,
     set_data_ipc,
     set_device,
     share_external_data_,
@@ -49,7 +49,9 @@ from fastdeploy.cache_manager.transfer_factory import (
     MooncakeStore,
 )
 from fastdeploy.config import CacheConfig, SpeculativeConfig
+from fastdeploy.engine.request import ControlRequest, ControlResponse
 from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
+from fastdeploy.inter_communicator.fmq import FMQ
 from fastdeploy.platforms import current_platform
 from fastdeploy.utils import console_logger, get_logger

@@ -96,7 +98,6 @@ def parse_args():
         help="engine worker queue port",
     )
     parser.add_argument("--num_cpu_blocks", type=int, default=4, help="cpu cache block number")
     parser.add_argument("--ipc_suffix", type=str, default=None, help="engine pid")
     parser.add_argument(
         "--protocol",
         type=str,
@@ -184,15 +185,17 @@ class CacheTransferManager:
         self.key_prefix = ""

         # extract other arg values
+        self.model_path = args.model_path
         self.model_id = os.path.basename(args.model_path.rstrip("/"))
         self.n_ranks = args.mp_num
         self.rank = args.rank
         self.device = args.device_id
         self.num_layers = args.num_layers
         self.ipc_suffix = args.ipc_suffix
+        self.create_cache_tensor = args.create_cache_tensor
+        self.local_data_parallel_id = args.local_data_parallel_id
         self.num_extra_layers = self.speculative_config.num_extra_cache_layer
         self.num_extra_layer_gpu_blocks = int(self.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
+        self.cache_queue_port = args.cache_queue_port
         paddle.set_default_dtype(args.default_dtype)

         self.swap_to_cpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
@@ -200,8 +203,10 @@ class CacheTransferManager:
         self.read_storage_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
         self.write_back_storage_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
         self.timeout_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)
+        self.control_task_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
         self.transfer_task_queue = queue.Queue()  # receives transfer tasks
         self.tansfer_done_queue = queue.Queue()  # signals that a task has finished
+        self.ctrl_output_queue = None

         address = (args.pod_ip, args.cache_queue_port)
         self.cache_task_queue = EngineCacheQueue(
@@ -209,7 +214,7 @@ class CacheTransferManager:
             is_server=False,
             num_client=args.mp_num,
             client_id=self.rank,
-            local_data_parallel_id=args.local_data_parallel_id,
+            local_data_parallel_id=0,
         )

         cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
@@ -231,11 +236,12 @@ class CacheTransferManager:

         self.num_cpu_blocks = args.num_cpu_blocks

-        self._init_gpu_cache(args)
+        self._init_gpu_cache()
         if self.num_cpu_blocks > 0:
-            self._init_cpu_cache(args)
+            self._init_cpu_cache()
         if self.storage_backend_type is not None:
             self._init_storage(args)
+        self._init_control()

         cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
         self.cache_task_broadcast_signal = IPCSignal(
@@ -287,7 +293,9 @@ class CacheTransferManager:
         threading.Thread(target=self.check_cache_status, args=[args], daemon=True).start()

         self.is_paused = False  # transfer manager state
+        self.is_sleeping = False
         self.inflight = 0  # number of inflight transfer tasks
+        self.inflight_tasks = {}

         cache_transfer_inited_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
         self.cache_transfer_inited_signal = IPCSignal(
@@ -299,6 +307,15 @@ class CacheTransferManager:
         )
         self.cache_transfer_inited_signal.value[self.rank] = 1

+    def _init_control(self):
+        dp_rank = self.local_data_parallel_id
+        tp_rank = self.rank
+        tp_size = self.n_ranks
+        cq_port = self.cache_queue_port
+        name = f"ctrl_c2e_rank{tp_rank + tp_size * dp_rank}_{cq_port}"
+        self.ctrl_output_queue = FMQ().queue(name, "producer")
+        logger.info(f"Init control output queue: {name} (producer)")

     def _init_storage(self, args):
         try:
             # TODO: support cache scale for other backend
@@ -354,14 +371,19 @@ class CacheTransferManager:
             raise ValueError(f"Invalid write policy: {args.write_policy}")
         self.write_policy = args.write_policy

-        self.key_prefix = ""
-        version_file_path = os.path.join(args.model_path, "version.yaml")
-        if os.path.exists(version_file_path):
-            self.key_prefix = get_key_prefix_from_version(version_file_path)
-        logger.info(f"The key_prefix of cache storage is {self.key_prefix}")
+        self._update_key_prefix()

         logger.info("Initialize cache storage successfully")

+    def _update_key_prefix(self):
+        # use key_prefix to distinguish caches for different versions of weights in rl
+        version_file_path = os.path.join(self.model_path, "version.yaml")
+        if os.path.exists(version_file_path):
+            self.key_prefix = get_key_prefix_from_version(version_file_path)
+            logger.info(f"Update key_prefix of cache storage to {self.key_prefix}")
+        else:
+            logger.error(f"version.yaml not found at {version_file_path}")

     def _init_storage_buffer(self, args):
         """
         Initialize pinned memory buffer that can hold the cache for the longest request
@@ -410,20 +432,20 @@ class CacheTransferManager:
         self.storage_value_scale_write_buffer = write_buffer + scale_buffer_total_bytes // 2
         self.storage_backend.register_buffer(write_buffer, scale_buffer_total_bytes)

-    def _init_gpu_cache(self, args):
+    def _init_gpu_cache(self):

-        if not args.create_cache_tensor:
-            logger.info(f"[rank {self.rank}/{self.n_ranks}] Waiting for runners or messagers to create kv cache.")
+        if not self.create_cache_tensor:
+            logger.info("Waiting for runners or messagers to create kv cache.")
             while self.cache_ready_signal.value[self.rank] != 1:
                 time.sleep(0.1)
-            logger.info(f"[rank {self.rank}/{self.n_ranks}] OK! Stop waiting.")
+            logger.info("OK! Stop waiting.")

-        if args.cache_dtype == "block_wise_fp8":
+        if self.cache_dtype == "block_wise_fp8":
             cache_type = "uint8"
         else:
-            cache_type = args.cache_dtype
+            cache_type = self.cache_dtype

-        logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.")
+        logger.info("Initializing kv cache for all layers.")
         set_device(self.device)
         for i in range(self.num_layers + self.num_extra_layers):
             # NOTE: num_extra_layer_gpu_blocks is usually equal to num_gpu_blocks
@@ -446,14 +468,12 @@ class CacheTransferManager:
                 self.value_cache_shape[2],
                 self.value_cache_shape[3],
             ]
-            if args.create_cache_tensor:
-                logger.info(
-                    f"[rank {self.rank}/{self.n_ranks}] ..creating kv cache for layer {i}: {key_cache_shape} {value_cache_shape}"
-                )
+            if self.create_cache_tensor:
+                logger.info(f"..creating kv cache for layer {i}: {key_cache_shape} {value_cache_shape}")
                 key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_type)
                 set_data_ipc(key_cache, key_name)

-                if args.cache_dtype == "block_wise_fp8":
+                if self.cache_dtype == "block_wise_fp8":
                     key_cache_scales = paddle.full(
                         shape=[num_gpu_blocks, self.key_cache_shape[1], self.key_cache_shape[2]],
                         fill_value=0,
@@ -464,7 +484,7 @@ class CacheTransferManager:
                 val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_type)
                 set_data_ipc(val_cache, val_name)

-                if args.cache_dtype == "block_wise_fp8":
+                if self.cache_dtype == "block_wise_fp8":
                     value_cache_scales = paddle.full(
                         shape=[num_gpu_blocks, self.value_cache_shape[1], self.value_cache_shape[2]],
                         fill_value=0,
@@ -472,13 +492,11 @@ class CacheTransferManager:
                     )
                     set_data_ipc(value_cache_scales, value_cache_scales_name)
             else:
-                logger.info(
-                    f"[rank {self.rank}/{self.n_ranks}] ..attaching kv cache for layer {i}: {key_cache_shape} {value_cache_shape}"
-                )
+                logger.info(f"..attaching kv cache for layer {i}: {key_cache_shape} {value_cache_shape}")
                 key_cache = paddle.empty(shape=[], dtype=cache_type)
                 val_cache = paddle.empty(shape=[], dtype=cache_type)
                 key_cache = share_external_data_(key_cache, key_name, key_cache_shape, True)
-                if args.cache_dtype == "block_wise_fp8":
+                if self.cache_dtype == "block_wise_fp8":
                     key_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
                     key_cache_scales = share_external_data_(
                         key_cache_scales,
@@ -488,7 +506,7 @@ class CacheTransferManager:
                     )
                 if self.value_cache_shape:
                     val_cache = share_external_data_(val_cache, val_name, value_cache_shape, True)
-                    if args.cache_dtype == "block_wise_fp8":
+                    if self.cache_dtype == "block_wise_fp8":
                         value_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
                         value_cache_scales = share_external_data_(
                             value_cache_scales,
@@ -499,48 +517,65 @@ class CacheTransferManager:

             self.gpu_cache_kvs[key_name] = key_cache
             self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name])
-            if args.cache_dtype == "block_wise_fp8":
+            if self.cache_dtype == "block_wise_fp8":
                 self.gpu_cache_kvs[key_cache_scales_name] = key_cache_scales
                 self.gpu_cache_scales_k_tensors.append(self.gpu_cache_kvs[key_cache_scales_name])
-            if args.value_cache_shape:
+            if self.value_cache_shape:
                 self.gpu_cache_kvs[val_name] = val_cache
                 self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[val_name])
-                if args.cache_dtype == "block_wise_fp8":
+                if self.cache_dtype == "block_wise_fp8":
                     self.gpu_cache_kvs[value_cache_scales_name] = value_cache_scales
                     self.gpu_cache_scales_v_tensors.append(self.gpu_cache_kvs[value_cache_scales_name])

-        if args.create_cache_tensor:
-            logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ kv cache is ready!")
+        if self.create_cache_tensor:
             self.cache_ready_signal.value[self.rank] = 1
             while np.sum(self.cache_ready_signal.value) != self.n_ranks:
                 time.sleep(0.1)

-        cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
-        logger.info(f"[rank {self.rank}/{self.n_ranks}] device :{self.device}")
-        logger.info(f"[rank {self.rank}/{self.n_ranks}] cache_kv_size_byte : {cache_kv_size_byte}")
-        logger.info(f"[rank {self.rank}/{self.n_ranks}] done init cache (full) gmem alloc : {memory_allocated()}")
+        logger.info("GPU KV cache is initialized")

-    def _init_cpu_cache(self, args):
+    def _clear_gpu_cache(self):
+        if self.create_cache_tensor:
+            logger.debug("Waiting for gpu runner to unlink cuda ipc")
+            while self.cache_ready_signal.value[self.rank] != 0:
+                time.sleep(0.1)
+            logger.debug("Stop waiting! gpu runner has unlinked cuda ipc")
+            self.gpu_cache_kvs.clear()
+            self.gpu_cache_k_tensors.clear()
+            self.gpu_cache_v_tensors.clear()
+            if hasattr(self, "gpu_cache_scales_k_tensors"):
+                self.gpu_cache_scales_k_tensors.clear()
+            if hasattr(self, "gpu_cache_scales_v_tensors"):
+                self.gpu_cache_scales_v_tensors.clear()
+            paddle.device.cuda.empty_cache()
+        else:
+            for name, tensor in self.gpu_cache_kvs.items():
+                unset_data_ipc(tensor, name, True, False)
+            logger.debug("Successfully unlinked gpu caches cuda ipc")
+        self.cache_ready_signal.value[self.rank] = 0

+        while np.sum(self.cache_ready_signal.value) != 0:
+            time.sleep(0.1)
+        logger.info("All ranks cleared gpu caches")

+    def _init_cpu_cache(self):
+        if self.num_cpu_blocks == 0:
+            return
+        paddle.set_device("cpu")
         key_cache_size = self.key_cache_shape[1] * self.key_cache_shape[2] * self.key_cache_shape[3]
-        if args.value_cache_shape:
+        if self.value_cache_shape:
             value_cache_size = self.value_cache_shape[1] * self.value_cache_shape[2] * self.value_cache_shape[3]
         else:
             value_cache_size = 0
         cache_item_bytes = CacheConfig.get_cache_bytes(self.cache_dtype)
-        key_need_to_allocate_bytes = args.num_cpu_blocks * cache_item_bytes * key_cache_size
-        value_need_to_allocate_bytes = args.num_cpu_blocks * cache_item_bytes * value_cache_size
-        if args.cache_dtype == "block_wise_fp8":
+        key_need_to_allocate_bytes = self.num_cpu_blocks * cache_item_bytes * key_cache_size
+        value_need_to_allocate_bytes = self.num_cpu_blocks * cache_item_bytes * value_cache_size
+        logger.info("Initializing swap space (cpu cache) for all layers.")
+        if self.cache_dtype == "block_wise_fp8":
             cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
             cache_scales_size = self.key_cache_shape[1] * self.key_cache_shape[2]
-            scales_key_need_to_allocate_bytes = args.num_cpu_blocks * cache_scales.element_size() * cache_scales_size
-            scales_value_need_to_allocate_bytes = args.num_cpu_blocks * cache_scales.element_size() * cache_scales_size
-        logger.info(
-            f"[rank {self.rank}/{self.n_ranks}] ..swap space size : {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB"
-        )
-        if args.num_cpu_blocks == 0:
-            logger.info(f"[rank {self.rank}/{self.n_ranks}] 💡 no swap space (cpu cache) is specified.")
-            self.swap_space_ready_signal.value[self.rank] = 1
-            return
-        logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing swap space (cpu cache) for all layers.")
-        paddle.set_device("cpu")
+            scales_key_need_to_allocate_bytes = self.num_cpu_blocks * cache_scales.element_size() * cache_scales_size
+            scales_value_need_to_allocate_bytes = self.num_cpu_blocks * cache_scales.element_size() * cache_scales_size
         self.k_dst_ptrs = []
         self.v_dst_ptrs = []
         self.k_scales_ptrs = []
@@ -551,22 +586,43 @@ class CacheTransferManager:
             key_cache_scales_name = f"key_cache_scales_{i}_rank{self.rank}"
             value_cache_scales_name = f"value_cache_scales_{i}_rank{self.rank}"
             logger.info(
-                f"[rank {self.rank}/{self.n_ranks}] ..creating cpu cache for layer {i}: {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB"
+                f"..creating cpu cache for layer {i}: {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB"
             )
             self.cpu_cache_kvs[key_name] = cuda_host_alloc(key_need_to_allocate_bytes)
             self.k_dst_ptrs.append(self.cpu_cache_kvs[key_name])
-            if args.cache_dtype == "block_wise_fp8":
+            if self.cache_dtype == "block_wise_fp8":
                 self.cpu_cache_kvs[key_cache_scales_name] = cuda_host_alloc(scales_key_need_to_allocate_bytes)
                 self.k_scales_ptrs.append(self.cpu_cache_kvs[key_cache_scales_name])
             if value_need_to_allocate_bytes > 0:
                 self.cpu_cache_kvs[val_name] = cuda_host_alloc(value_need_to_allocate_bytes)
                 self.v_dst_ptrs.append(self.cpu_cache_kvs[val_name])
-                if args.cache_dtype == "block_wise_fp8":
+                if self.cache_dtype == "block_wise_fp8":
                     self.cpu_cache_kvs[value_cache_scales_name] = cuda_host_alloc(scales_value_need_to_allocate_bytes)
                     self.v_scales_ptrs.append(self.cpu_cache_kvs[value_cache_scales_name])
-        logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ swap space (cpu cache) is ready!")
+        logger.info("Swap space (cpu cache) is ready!")
         self.swap_space_ready_signal.value[self.rank] = 1

+        while np.sum(self.swap_space_ready_signal.value) != self.n_ranks:
+            time.sleep(0.1)
+        logger.info("All ranks init cpu caches")

+    def _clear_cpu_cache(self):
+        for ptrs in self.k_dst_ptrs + self.v_dst_ptrs:
+            cuda_host_free(ptrs)
+        self.cpu_cache_kvs.clear()
+        self.k_dst_ptrs.clear()
+        self.v_dst_ptrs.clear()
+        if hasattr(self, "k_scales_ptrs"):
+            self.k_scales_ptrs.clear()
+        if hasattr(self, "v_scales_ptrs"):
+            self.v_scales_ptrs.clear()
+        gc.collect()
+        self.swap_space_ready_signal.value[self.rank] = 0

+        while np.sum(self.swap_space_ready_signal.value) != 0:
+            time.sleep(0.1)
+        logger.info("All ranks cleared cpu caches")

     def _run_read_storage(
         self,
         task_id: str,
@@ -1024,6 +1080,79 @@ class CacheTransferManager:
             logger.debug(f"_do_swap_to_gpu_task: put_transfer_done_signal {result}")
         logger.info(f"_do_swap_to_gpu_task: put_transfer_done_signal for transfer_task_id {transfer_task_id}")

+    def _handle_pause(self):
+        if self.is_paused:
+            logger.info("💡 Cache transfer manager is already paused, no need to pause again!")
+        else:
+            self.pause()
+            logger.info("✅ Successfully paused transfer")
+        return True

+    def _handle_resume(self):
+        if not self.is_paused:
+            logger.info("💡 Cache transfer manager is not paused, no need to resume!")
+        else:
+            self.resume()
+            if self.storage_backend_type is not None:
+                self._update_key_prefix()
+            logger.info("✅ Successfully resumed transfer")
+        return True

+    def _handle_sleep(self):
+        if self.is_sleeping:
+            logger.info("💡 Cache transfer manager is already sleeping, no need to sleep again!")
+        else:
+            if self.num_cpu_blocks > 0 and envs.FD_ENABLE_SWAP_SPACE_CLEARING:
+                self._clear_cpu_cache()
+            self._clear_gpu_cache()
+            self.is_sleeping = True
+            logger.info("✅ Successfully fell asleep (offloaded caches)")
+        return True

+    def _handle_wakeup(self):
+        if not self.is_sleeping:
+            logger.info("💡 Cache transfer manager is not sleeping, no need to wake up!")
+        else:
+            if self.num_cpu_blocks > 0 and envs.FD_ENABLE_SWAP_SPACE_CLEARING:
+                self._init_cpu_cache()
+            self._init_gpu_cache()
+            self.is_sleeping = False
+            logger.info("✅ Successfully woke up (reloaded caches)")
+        return True
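Each handler is an idempotent guard around one state transition (is_paused for pause/resume, is_sleeping for sleep/wakeup), so a repeated or out-of-order control request degrades to a log line instead of corrupting state. A usage sketch, with `mgr` standing for a CacheTransferManager instance:

    mgr._handle_sleep()   # offloads cpu/gpu caches, is_sleeping -> True
    mgr._handle_sleep()   # no-op: already sleeping, only logs
    mgr._handle_wakeup()  # re-initializes caches, is_sleeping -> False
    mgr._handle_wakeup()  # no-op: not sleeping, only logs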

+    def control_task(self, task: ControlRequest):
+        method = task.get_method()
+        tags = task.args.get("tags", {})
+        logger.info(f"Received control task: {method}, tags: {tags}")

+        handlers = {
+            "pause": self._handle_pause,
+            "resume": self._handle_resume,
+            "sleep": self._handle_sleep,
+            "wakeup": self._handle_wakeup,
+        }

+        handler = handlers.get(method)
+        error_code = 200
+        error_message = "Success"

+        if handler:
+            try:
+                handler()
+            except Exception as e:
+                error_code = 500
+                error_message = f"Failed to execute {method}: {str(e)}"
+                logger.error(f"Error in control_task: {traceback.format_exc()}")
+        else:
+            error_code = 400
+            error_message = f"Unknown control method: {method}"
+            logger.warning(error_message)

+        self.cache_task_queue.barrier.wait()
+        resp = ControlResponse(task.request_id, error_code, error_message)
+        asyncio.run(self.ctrl_output_queue.put(resp))
+        logger.info(f"Put response into output queue {self.ctrl_output_queue.name}: {resp}")

     def check_work_status(self, time_interval_threashold=envs.FD_CACHE_PROC_EXIT_TIMEOUT):
         """
         Check the health of the model server by checking whether all workers are alive.
@@ -1043,8 +1172,13 @@ class CacheTransferManager:
                 return fn(*args)
             finally:
                 self.inflight -= 1
+                logger.debug(f"submit_task: {fn.__name__} finished, args: {args}, current inflight: {self.inflight}")

         self.inflight += 1
         thread_pool.submit(inflight_task, task_fn, *args)
+        logger.debug(
+            f"submit_task: {task_fn.__name__} submitted to thread pool, args: {args}, current inflight: {self.inflight}"
+        )
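submit_task wraps every pool submission so the inflight counter is bumped before the task is queued and decremented in a finally block, which is what lets pause() wait on the inflight signals knowing nothing is mid-transfer. The accounting pattern in isolation (module-level state used purely for illustration):

    import concurrent.futures

    inflight = 0
    pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)

    def submit(fn, *args):
        global inflight

        def wrapper(*a):
            global inflight
            try:
                return fn(*a)
            finally:
                inflight -= 1  # always runs, even if fn raises

        inflight += 1  # count before submission so a pause never races the task
        return pool.submit(wrapper, *args)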

     def do_data_transfer(self):
         """
@@ -1055,6 +1189,7 @@ class CacheTransferManager:
         max_errors = (
             envs.FD_CACHE_PROC_ERROR_COUNT
         )  # After this many consecutive errors, check if the worker process exists.
+        is_paused = False

         while True:
             try:
@@ -1066,18 +1201,18 @@ class CacheTransferManager:
                     self.cache_task_queue.barrier0.reset()

                 # Ensure all ranks synchronously do one of the following things:
-                # (1) If rank#0 is paused, wait for a short time and check rank#0 status again;
-                # (2) otherwise, all ranks are allowed to pull tasks from cache task queue
+                # (1) If rank#0 is paused, wait for inflight tasks to finish first, then only process control tasks;
+                # (2) otherwise, all ranks are allowed to pull all tasks from cache task queue
                 if self.cache_task_is_paused_signal.value[0] == 1:
                     # wait for inflight tasks to finish first
                     while self.inflight != 0:
                         time.sleep(0.1)
                     # mark the current rank as not having inflight tasks
                     self.cache_task_inflight_signal.value[self.rank] = 0
-                    time.sleep(1)
-                    continue
+                    is_paused = True
+                else:
+                    self.cache_task_inflight_signal.value[self.rank] = 1
+                    is_paused = False

                 if self.rank == 0:
                     if not self.cache_task_queue.empty():
@@ -1088,12 +1223,16 @@ class CacheTransferManager:
                         self.cache_task_queue.barrier1.reset()

                 if self.cache_task_broadcast_signal.value[0] == 1:
-                    self.inflight += 1
                     data, read_finish = self.cache_task_queue.get_transfer_task()
                     logger.debug(f"do_data_transfer: {data}")
                     if read_finish:
                         self.cache_task_broadcast_signal.value[0] = 0
+                    event_type, event_args = data[0], data[1:]

+                    # control task is the only task allowed to execute when loop is paused
+                    if is_paused and event_type.value != CacheStatus.CTRL.value:
+                        continue

                     if event_type.value == CacheStatus.SWAP2CPU.value:
                         transfer_task_id, swap_node_ids, gpu_block_id, cpu_block_id = event_args
                         self.submit_task(
@@ -1130,6 +1269,13 @@ class CacheTransferManager:
                             self.write_back_storage_task,
                             write_storage_task,
                         )
+                    elif event_type.value == CacheStatus.CTRL.value:
+                        control_task = event_args[0]
+                        self.control_task_thread_pool.submit(
+                            self.control_task,
+                            control_task,
+                        )

                 else:
                     if self.n_ranks > 1:
                         self.cache_task_queue.barrier2.wait()
@@ -1292,115 +1438,46 @@ class CacheTransferManager:
         # TODO XPU support RL
         if unset_data_ipc is None:
             return
-        logger.info("[RL] Launch a thread to clear/restore kv cache when model weights are cleared/updated.")
+        logger.info(
+            "check_cache_status: Launch a thread to clear/restore kv cache when model weights are cleared/updated."
+        )
         while True:
             # handle cache clearing/restoring
             if self.kv_cache_status_signal.value[0] == KVCacheStatus.CLEARING:
                 assert args.splitwise_role == "mixed", "Only mixed mode supports clearing cache."
                 try:
-                    # wait for inflight transfer tasks to finish and pause transfer manager
+                    # pause transfer
                     self.pause()

-                    # clear cpu caches
-                    logger.info("[RL] start clearing caches")
-                    logger.debug("[RL] start clearing cpu caches")
+                    # clear caches
+                    logger.info("check_cache_status: start clearing caches")
                     if self.num_cpu_blocks > 0 and envs.FD_ENABLE_SWAP_SPACE_CLEARING:
-                        paddle.set_device("cpu")
-                        for ptrs in self.k_dst_ptrs + self.v_dst_ptrs:
-                            cuda_host_free(ptrs)
-                        self.cpu_cache_kvs.clear()
-                        self.k_dst_ptrs.clear()
-                        self.v_dst_ptrs.clear()
-                        if self.cache_dtype == "block_wise_fp8":
-                            self.k_scales_ptrs.clear()
-                            self.v_scales_ptrs.clear()
-                        gc.collect()
-                        logger.debug("[RL] successfully cleared cpu caches")
-                        # reset swap_space_ready_signal
-                        self.swap_space_ready_signal.value[self.rank] = 0
-                        while np.sum(self.swap_space_ready_signal.value) != 0:
-                            time.sleep(0.1)
-                        logger.debug("[RL] all ranks cleared cpu caches")
-                    else:
-                        logger.debug("[RL] skip clearing cpu caches")
-
-                    # clear gpu caches
-                    logger.debug("[RL] start clearing gpu caches")
-                    if args.create_cache_tensor:
-                        logger.info("[RL] waiting for gpu runner to unlink cuda ipc")
-                        while self.cache_ready_signal.value[self.rank] != 0:
-                            time.sleep(0.1)
-                        logger.info("[RL] stop waiting! gpu runner has unlinked cuda ipc")
-                        paddle.set_device(f"gpu:{self.device}")
-                        paddle.set_flags({"FLAGS_selected_gpus": f"{self.device}"})
-                        self.gpu_cache_kvs.clear()
-                        self.gpu_cache_k_tensors.clear()
-                        self.gpu_cache_v_tensors.clear()
-                        if self.cache_dtype == "block_wise_fp8":
-                            self.gpu_cache_scales_k_tensors.clear()
-                            self.gpu_cache_scales_v_tensors.clear()
-                        paddle.device.cuda.empty_cache()
-                        logger.debug("[RL] successfully cleared gpu caches")
-                    else:
-                        for name, tensor in self.gpu_cache_kvs.items():
-                            unset_data_ipc(tensor, name, True, False)
-                        logger.debug("[RL] successfully unlinked gpu caches cuda ipc")
-                    self.cache_ready_signal.value[self.rank] = 0
-
-                    while np.sum(self.cache_ready_signal.value) != 0:
-                        time.sleep(0.1)
-                    logger.info("[RL] all ranks cleared caches!")
-
-                    # reset kv_cache_status_signal
+                        self._clear_cpu_cache()
+                    self._clear_gpu_cache()
                     self.kv_cache_status_signal.value[0] = KVCacheStatus.CLEARED

-                    self._log_memory("after clearing caches")
+                    self._log_memory("check_cache_status: after clearing caches")
                 except Exception as e:
-                    logger.error(f"[RL] failed to clear caches: {e}")
+                    logger.error(f"check_cache_status: failed to clear caches: {e}")

             elif self.kv_cache_status_signal.value[0] == KVCacheStatus.UPDATING:
                 assert args.splitwise_role == "mixed", "Only mixed mode supports updating cache."
                 try:
-                    # restore cpu cache
-                    logger.info("[RL] start restoring caches")
-                    logger.debug("[RL] start restoring cpu caches")
+                    logger.info("check_cache_status: start restoring caches")
                     if self.num_cpu_blocks > 0 and envs.FD_ENABLE_SWAP_SPACE_CLEARING:
-                        self._init_cpu_cache(args)
-                        logger.debug("[RL] successfully restored cpu caches")
-                        while np.sum(self.swap_space_ready_signal.value) != args.mp_num:
-                            time.sleep(0.1)
-                        logger.debug("[RL] all ranks restored cpu caches")
-                    else:
-                        logger.debug("[RL] skip restoring cpu caches")
-
-                    # restore gpu cache and set cache_ready_signal
-                    logger.debug("[RL] start restoring gpu caches")
-                    self._init_gpu_cache(args)
-                    logger.debug("[RL] successfully restored gpu caches")
+                        self._init_cpu_cache()
+                    self._init_gpu_cache()

                     # update key prefix for kv cache backend
                     if self.storage_backend_type is not None:
-                        # use key_prefix to distinguish cache for different version of weight in rl
-                        version_file_path = os.path.join(args.model_path, "version.yaml")
-                        assert os.path.exists(version_file_path), f"version.yaml not found at {version_file_path}"
-                        self.key_prefix = get_key_prefix_from_version(version_file_path)
-                        logger.info(f"Update key_prefix of cache storage to {self.key_prefix}")
-
-                    # wait for all ranks caches to be ready
-                    while np.sum(self.cache_ready_signal.value) != args.mp_num:
-                        time.sleep(0.1)
-                    logger.info("[RL] all ranks restored caches!")
+                        self._update_key_prefix()

                     # resume transfer
                     self.resume()

                     # set kv_cache_status_signal
                     self.kv_cache_status_signal.value[0] = KVCacheStatus.NORMAL

-                    self._log_memory("after restoring caches")
+                    self._log_memory("check_cache_status: after restoring caches")

                 except Exception as e:
-                    logger.error(f"[RL] failed to restore caches: {e}")
+                    logger.error(f"check_cache_status: failed to restore caches: {e}")

             time.sleep(0.1)
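The whole loop is a shared-memory state machine: a requester flips kv_cache_status_signal to CLEARING or UPDATING, this thread performs the work on every rank, then acknowledges by writing CLEARED or NORMAL. The handshake in miniature, with the requester side hypothetical:

    # Requester (e.g. the engine during an RL weight update):
    kv_cache_status_signal.value[0] = KVCacheStatus.CLEARING
    while kv_cache_status_signal.value[0] != KVCacheStatus.CLEARED:
        time.sleep(0.1)  # wait for every cache rank to acknowledge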

@@ -1409,11 +1486,11 @@ class CacheTransferManager:
             self.cache_task_queue.pause_barrier.wait()
             if self.rank == 0:
                 self.cache_task_queue.pause_barrier.reset()
-        logger.info("[RL] 🟠 wait for inflight transfer tasks to finish")
+        logger.info("pause: 🟠 wait for inflight transfer tasks to finish")
         self.is_paused = True
         while np.sum(self.cache_task_inflight_signal.value) != 0:
             time.sleep(0.1)
-        logger.info("[RL] 🔴 pause transfer manager and stop doing transfer tasks")
+        logger.info("pause: 🔴 pause transfer manager and stop doing transfer tasks")

     def resume(self):
         if self.n_ranks > 1:
@@ -1423,7 +1500,7 @@ class CacheTransferManager:
         self.is_paused = False
         while np.sum(self.cache_task_inflight_signal.value) != self.n_ranks:
             time.sleep(0.1)
-        logger.info("[RL] 🟢 resume transfer manager and start to do transfer tasks")
+        logger.info("resume: 🟢 resume transfer manager and start to do transfer tasks")

     def _log_memory(self, context: str):
         """Log current GPU memory usage."""
@@ -1432,8 +1509,8 @@ class CacheTransferManager:
         curr_alloc = paddle.device.cuda.memory_allocated() / (1024**3)
         curr_reserved = paddle.device.cuda.memory_reserved() / (1024**3)

-        logger.warning(
-            f"GPU memory usage {context}:"
+        logger.info(
+            f"{context}: "
             f"max_allocated: {max_alloc:.2f}GB "
             f"max_reserved: {max_reserved:.2f}GB "
             f"current_allocated: {curr_alloc:.2f}GB "