[Feature] [KVCache] support attention_store kv cache backend (#5823)

* [feat] support attention_store kv cache backend

* [fix] fix codestyle

* [chore] optimize log

* [fix] fix write storage task

* [fix] fix read storage

* [fix] fix code conflict after merge develop

* [fix] fix cache bytes and read task token ids

* [chore] add model for cache transfer manager

* [chore] add some log

* [chore] remove launched_cache_manager_signal

* [fix] fix write_back_storage_task match_block_num condition

* [fix] fix swap_cost_time

* [ci] fix ci

* Update fastdeploy/engine/sched/resource_manager_v1.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update fastdeploy/cache_manager/cache_transfer_manager.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Yonghua Li
2026-01-22 21:01:23 +08:00
committed by GitHub
parent 3cd0ffe36c
commit 8d27a523e7
17 changed files with 599 additions and 226 deletions
+272 -192
View File
@@ -29,6 +29,7 @@ import paddle
from fastdeploy import envs
from fastdeploy.cache_manager.cache_data import CacheStatus
from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask
from fastdeploy.cache_manager.ops import (
cuda_host_alloc,
cuda_host_free,
@@ -40,7 +41,7 @@ from fastdeploy.cache_manager.ops import (
swap_cache_layout,
unset_data_ipc,
)
from fastdeploy.cache_manager.transfer_factory import MooncakeStore
from fastdeploy.cache_manager.transfer_factory import AttentionStore, MooncakeStore
from fastdeploy.config import SpeculativeConfig
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
from fastdeploy.platforms import current_platform
@@ -58,6 +59,7 @@ def parse_args():
default="mixed",
help="splitwise role, can be decode, prefill or mixed",
)
parser.add_argument("--model_id", type=str, default="default", help="model id")
parser.add_argument("--rank", type=int, default=0, help="local tp rank")
parser.add_argument("--device_id", type=int, default=0, help="device id")
parser.add_argument("--max_model_len", type=int, default=32768, help="max model length")
@@ -109,7 +111,7 @@ def parse_args():
"--kvcache_storage_backend",
type=str,
default=None,
choices=["mooncake", "none"],
choices=["mooncake", "attention_store", "none"],
help="The storage backend for kvcache storage. If not set, storage backend is disabled.",
)
parser.add_argument(
@@ -133,8 +135,6 @@ class CacheTransferManager:
"""
初始化CacheTransferManager
"""
device = args.device_id
rank = args.rank
self.gpu_cache_kvs = {}
self.cpu_cache_kvs = {}
self.gpu_cache_k_tensors = []
@@ -142,11 +142,31 @@ class CacheTransferManager:
self.gpu_cache_scales_k_tensors = []
self.gpu_cache_scales_v_tensors = []
self.speculative_config = SpeculativeConfig(args.speculative_config)
# parse kv cache shape
self.key_cache_shape = [int(i) for i in args.key_cache_shape.split(",")]
self.value_cache_shape = []
if args.value_cache_shape:
self.value_cache_shape = [int(i) for i in args.value_cache_shape.split(",")]
# extract kv cache shape into fields
self.num_gpu_blocks = self.key_cache_shape[0]
self.head_num = self.key_cache_shape[1]
self.block_size = self.key_cache_shape[2]
self.head_dim = self.key_cache_shape[3]
# compute cache bytes
self.cache_dtype = args.cache_dtype
self.cache_bytes = self._get_cache_bytes(self.cache_dtype)
# extract other arg values
self.model_id = args.model_id
self.n_ranks = args.mp_num
self.rank = args.rank
self.device = args.device_id
self.num_layers = args.num_layers
self.ipc_suffix = args.ipc_suffix
self.local_data_parallel_id = args.local_data_parallel_id
self.num_extra_layers = self.speculative_config.num_extra_cache_layer
self.num_extra_layer_gpu_blocks = int(self.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
paddle.set_default_dtype(args.default_dtype)
@@ -158,18 +178,13 @@ class CacheTransferManager:
self.timeout_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)
self.transfer_task_queue = queue.Queue() # 用来接收传输任务
self.tansfer_done_queue = queue.Queue() # 用来告知任务执行完毕
self.n_ranks = args.mp_num
self.rank = rank
self.device = device
self.ipc_suffix = args.ipc_suffix
self.cache_dtype = args.cache_dtype
address = (args.pod_ip, args.cache_queue_port)
self.cache_task_queue = EngineCacheQueue(
address=address,
is_server=False,
num_client=args.mp_num,
client_id=rank,
client_id=self.rank,
local_data_parallel_id=args.local_data_parallel_id,
)
@@ -223,8 +238,22 @@ class CacheTransferManager:
self.storage_backend = MooncakeStore(tp_rank=self.rank)
self._init_storage_buffer(args)
logger.info("Initialized mooncake store successfully")
elif args.kvcache_storage_backend == "attention_store":
logger.info("Start initialize attention store...")
self.storage_backend = AttentionStore(
namespace=self.model_id,
shard_id=self.rank,
shard_num=self.n_ranks,
layer_num=self.num_layers + self.num_extra_layers,
block_token_size=self.block_size,
bytes_per_shard_layer_per_block=self.head_num * self.block_size * self.head_dim * self.cache_bytes,
device_id=self.device,
dp_id=self.local_data_parallel_id,
)
logger.info("Initialized attention store successfully!")
else:
raise NotImplementedError(f"Unsupported storage backend: {args.kvcache_storage_backend}")
self.storage_backend_type = args.kvcache_storage_backend
if args.write_policy not in ["write_through"]:
raise ValueError(f"Invalid write policy: {args.write_policy}")
@@ -246,18 +275,16 @@ class CacheTransferManager:
cache layout: layer_num * [block_num, head_num, block_size, head_dim]
buffer layout: [block_num, layer_num, head_num, block_size, head_dim]
"""
layer_num = args.num_layers + self.num_extra_layers
head_num = self.key_cache_shape[1]
block_size = self.key_cache_shape[2]
head_dim = self.key_cache_shape[3]
block_num = (args.max_model_len + block_size - 1) // block_size
layer_num = self.num_layers + self.num_extra_layers
block_num = (args.max_model_len + self.block_size - 1) // self.block_size
logger.info(
f"Creating cache buffer for storage with shape: "
f"[{block_num}, {layer_num}, {head_num}, {block_size}, {head_dim}]"
f"[{block_num}, {layer_num}, {self.head_num}, {self.block_size}, {self.head_dim}]"
)
self.cache_bytes = self._get_cache_bytes(self.cache_dtype)
self.storage_buffer_stride_bytes = layer_num * head_num * block_size * head_dim * self.cache_bytes
self.storage_buffer_stride_bytes = (
layer_num * self.head_num * self.block_size * self.head_dim * self.cache_bytes
)
total_bytes = block_num * self.storage_buffer_stride_bytes * 2 # key and value
logger.info(f"Creating cpu buffer cache for alllayers: {total_bytes / 1024 ** 3:.2f}GB")
@@ -296,8 +323,8 @@ class CacheTransferManager:
logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.")
set_device(self.device)
for i in range(args.num_layers + self.num_extra_layers):
num_gpu_blocks = self.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
for i in range(self.num_layers + self.num_extra_layers):
num_gpu_blocks = self.num_gpu_blocks if i < self.num_layers else self.num_extra_layer_gpu_blocks
key_name = f"key_caches_{i}_rank{self.rank}.device{self.device}"
val_name = f"value_caches_{i}_rank{self.rank}.device{self.device}"
key_cache_scales_name = f"key_cache_scales_{i}_rank{self.rank}.device{self.device}"
@@ -415,7 +442,7 @@ class CacheTransferManager:
self.v_dst_ptrs = []
self.k_scales_ptrs = []
self.v_scales_ptrs = []
for i in range(args.num_layers + self.num_extra_layers):
for i in range(self.num_layers + self.num_extra_layers):
key_name = f"key_caches_{i}_rank{self.rank}"
val_name = f"value_caches_{i}_rank{self.rank}"
key_cache_scales_name = f"key_cache_scales_{i}_rank{self.rank}"
@@ -446,228 +473,283 @@ class CacheTransferManager:
raise ValueError(f"Unsupported cache dtype: {cache_dtype}")
return cache_bytes
def _storage_exist_block_num(self, k_keys: List[str], v_keys: List[str]):
def _run_read_storage(
self,
task_id: str,
token_ids: List[int],
start_read_block_idx: int,
k_cache_keys: List[str],
v_cache_keys: List[str],
gpu_block_ids: List[int],
cpu_block_ids: List[int],
timeout: float,
):
"""
Given the k_keys and v_keys, get the valid blocks number that
can be prefetched from storage backend.
Read storage data from the given blocks to the corresponding cache tensors on the current rank's GPU.
"""
assert len(k_keys) == len(v_keys), "k_keys and v_keys must have the same length."
result = self.storage_backend.exists(k_keys + v_keys)
# only consider the case when both key and value exist
num = 0
for k, v in zip(k_keys, v_keys):
if result[k] and result[v]:
num += 1
return num
def _run_read_storage(self, k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids):
try:
logger.debug(
f"_run_read_storage, key_hash_keys_num: {len(k_cache_keys)}, "
f"value_hash_keys_num: {len(v_cache_keys)}, gpu_block_ids_num: {len(gpu_block_ids)}, "
f"cpu_block_ids_num: {len(cpu_block_ids)}"
)
if self.storage_backend_type == "mooncake":
block_num = len(gpu_block_ids)
keys = k_cache_keys + v_cache_keys
k_cache_ptrs = [
self.storage_key_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
]
v_cache_ptrs = [
self.storage_value_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
]
kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value
start_time = time.time()
result = self.storage_backend.batch_get(
keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes
)
read_cost_time = time.time() - start_time
block_num = len(gpu_block_ids)
keys = k_cache_keys + v_cache_keys
k_cache_ptrs = [self.storage_key_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids]
v_cache_ptrs = [
self.storage_value_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
]
kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value
start_time = time.time()
result = self.storage_backend.batch_get(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes)
read_cost_time = time.time() - start_time
k_result, v_result = result[:block_num], result[block_num:]
success_block_num = 0
for k, v in zip(k_result, v_result):
if k > 0 and v > 0:
success_block_num += 1
logger.debug(f"_run_read_storage, success_block_num: {success_block_num}")
valid_gpu_block_ids = gpu_block_ids[:success_block_num]
valid_cpu_block_ids = cpu_block_ids[:success_block_num]
k_result, v_result = result[:block_num], result[block_num:]
success_block_num = 0
for k, v in zip(k_result, v_result):
if k > 0 and v > 0:
success_block_num += 1
logger.debug(f"_run_read_storage, success_block_num: {success_block_num}")
valid_gpu_block_ids = gpu_block_ids[:success_block_num]
valid_cpu_block_ids = cpu_block_ids[:success_block_num]
mode = 1 # cpu ==> gpu
start_time = time.time()
swap_cache_layout(
self.gpu_cache_k_tensors,
self.storage_key_read_buffer,
self.key_cache_shape,
valid_gpu_block_ids,
valid_cpu_block_ids,
self.device,
mode,
)
swap_cache_layout(
self.gpu_cache_v_tensors,
self.storage_value_read_buffer,
self.value_cache_shape,
valid_gpu_block_ids,
valid_cpu_block_ids,
self.device,
mode,
)
swap_cost_time = time.time() - start_time
logger.debug(
f"_run_read_storage, swap_cost_time: {swap_cost_time:.6f}s, read_cost_time: {read_cost_time:.6f}s"
)
mode = 1 # cpu ==> gpu
start_time = time.time()
swap_cache_layout(
self.gpu_cache_k_tensors,
self.storage_key_read_buffer,
self.key_cache_shape,
valid_gpu_block_ids,
valid_cpu_block_ids,
self.device,
mode,
)
swap_cache_layout(
self.gpu_cache_v_tensors,
self.storage_value_read_buffer,
self.value_cache_shape,
valid_gpu_block_ids,
valid_cpu_block_ids,
self.device,
mode,
)
swap_cost_time = time.time() - start_time
logger.debug(
f"_run_read_storage, swap_cost_time: {swap_cost_time:.6f}s, read_cost_time: {read_cost_time:.6f}s"
)
elif self.storage_backend_type == "attention_store":
key_cache = []
val_cache = []
for i in range(self.num_layers + self.num_extra_layers):
key_cache.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{self.rank}.device{self.device}"])
val_cache.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{self.rank}.device{self.device}"])
start_time = time.time()
read_block_num = self.storage_backend.read(
task_id, key_cache, val_cache, token_ids, gpu_block_ids, start_read_block_idx, timeout
)
read_cost_time = time.time() - start_time
valid_gpu_block_ids = gpu_block_ids[:read_block_num]
logger.debug(f"_run_read_storage, read_cost_time: {read_cost_time:.6f}s")
return valid_gpu_block_ids
except Exception as e:
logger.error(
f"[rank {self.rank}/{self.n_ranks}] An error occurred in _run_read_storage: "
f"error:{e}, {traceback.format_exc()}"
f"An error occurred in _run_read_storage, " f"error: {e}, traceback:\n{traceback.format_exc()}"
)
raise
def read_storage_task(self, task_id, keys, gpu_block_ids, timeout=0.1):
def read_storage_task(self, task: ReadStorageTask):
"""Read cache from the storage backend to the GPU memory."""
try:
logger.debug(
f"read_storage_task, task id: {task_id}, hash_keys_num: {len(keys)}, "
f"gpu_block_ids_num: {len(gpu_block_ids)}, timeout: {timeout}"
)
k_cache_keys = [f"{key}_key_{self.rank}" for key in keys]
v_cache_keys = [f"{key}_value_{self.rank}" for key in keys]
match_block_num = self._storage_exist_block_num(k_cache_keys, v_cache_keys)
logger.debug(f"read_storage_task, match {match_block_num} blocks from storage for task id: {task_id}")
gpu_block_ids = task.gpu_block_ids.copy()
cpu_block_ids = [i for i in range(len(gpu_block_ids))]
k_cache_keys = [f"{key}_key_{self.rank}" for key in task.keys]
v_cache_keys = [f"{key}_value_{self.rank}" for key in task.keys]
match_block_num = 0
if self.storage_backend_type == "mooncake":
match_block_num = self.storage_backend.query(k_cache_keys, v_cache_keys)
elif self.storage_backend_type == "attention_store":
match_block_num = self.storage_backend.query(
task.task_id, task.token_ids, task.start_read_block_idx, task.timeout
)
logger.info(f"Matched {match_block_num} blocks in cache storage for read task {task.task_id}")
k_cache_keys = k_cache_keys[:match_block_num]
v_cache_keys = v_cache_keys[:match_block_num]
gpu_block_ids = gpu_block_ids[:match_block_num]
cpu_block_ids = [i for i in range(match_block_num)]
cpu_block_ids = cpu_block_ids[:match_block_num]
valid_gpu_block_ids = []
if match_block_num > 0:
# TODO: support timeout with actual block count
try:
valid_gpu_block_ids = self._run_read_storage(
k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids
task.task_id,
task.token_ids[: match_block_num * self.block_size],
task.start_read_block_idx,
k_cache_keys,
v_cache_keys,
gpu_block_ids,
cpu_block_ids,
task.timeout,
)
logger.info(
f"read_storage_task, finish loading {match_block_num} blocks from storage for task {task_id}."
f"Successfully read {len(valid_gpu_block_ids)} blocks from cache storage for task {task.task_id}"
)
except Exception as e:
logger.error(f"[rank {self.rank}/{self.n_ranks}] An error occurred: {task_id} {e}")
logger.error(f"Failed to read cache for task {task.task_id}, error: {e}")
valid_gpu_block_ids = []
result = (CacheStatus.STORAGE2GPU, task_id, keys, valid_gpu_block_ids)
result = (CacheStatus.STORAGE2GPU, task.task_id, task.keys, valid_gpu_block_ids)
self.cache_task_queue.swap_storage_to_gpu_barrier.wait()
self.cache_task_queue.swap_storage_to_gpu_barrier.reset()
self.cache_task_queue.put_transfer_done_signal(result)
logger.debug(f"read_storage_task: put_transfer_done_signal {result}")
logger.info(
f"read_storage_task: put_transfer_done_signal for transfer_task_id {task_id}, "
f"valid block num {len(valid_gpu_block_ids)}"
)
logger.debug(f"read_storage_task: put transfer done signal for {task.task_id}")
except Exception as e:
logger.error(
f"[rank {self.rank}/{self.n_ranks}] An error occurred in read_storage_task: "
f"task_id: {task_id}, error:{e}, {traceback.format_exc()}"
f"An error occurred in read_storage_task: "
f"task_id: {task.task_id}, error:{e}, {traceback.format_exc()}"
)
def _run_write_back_storage(self, k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids):
def _run_write_back_storage(
self,
task_id,
token_ids,
start_write_block_idx,
k_cache_keys,
v_cache_keys,
gpu_block_ids,
cpu_block_ids,
timeout,
):
try:
logger.debug(
f"_run_write_back_storage, k_cache_keys: {k_cache_keys}, v_cache_keys: {v_cache_keys}, "
f"gpu_block_ids: {gpu_block_ids}"
)
key_cache_size = [
self.key_cache_shape[0],
self.key_cache_shape[1],
self.key_cache_shape[2],
self.key_cache_shape[3],
]
if self.storage_backend_type == "mooncake":
key_cache_size = [
self.key_cache_shape[0],
self.key_cache_shape[1],
self.key_cache_shape[2],
self.key_cache_shape[3],
]
mode = 0 # gpu ==> cpu
start_time = time.time()
swap_cache_layout(
self.gpu_cache_k_tensors,
self.storage_key_write_buffer,
key_cache_size,
gpu_block_ids,
cpu_block_ids,
self.device,
mode,
)
swap_cache_layout(
self.gpu_cache_v_tensors,
self.storage_value_write_buffer,
key_cache_size,
gpu_block_ids,
cpu_block_ids,
self.device,
mode,
)
swap_cost_time = time.time() - start_time
mode = 0 # gpu ==> cpu
start_time = time.time()
swap_cache_layout(
self.gpu_cache_k_tensors,
self.storage_key_write_buffer,
key_cache_size,
gpu_block_ids,
cpu_block_ids,
self.device,
mode,
)
swap_cache_layout(
self.gpu_cache_v_tensors,
self.storage_value_write_buffer,
key_cache_size,
gpu_block_ids,
cpu_block_ids,
self.device,
mode,
)
swap_cost_time = time.time() - start_time
block_num = len(gpu_block_ids)
keys = k_cache_keys + v_cache_keys
k_cache_ptrs = [
self.storage_key_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
]
v_cache_ptrs = [
self.storage_value_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
]
kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value
block_num = len(gpu_block_ids)
keys = k_cache_keys + v_cache_keys
k_cache_ptrs = [
self.storage_key_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
]
v_cache_ptrs = [
self.storage_value_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
]
kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value
start_time = time.time()
self.storage_backend.batch_set(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes)
write_cost_time = time.time() - start_time
start_time = time.time()
self.storage_backend.batch_set(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes)
write_cost_time = time.time() - start_time
logger.debug(
f"_run_write_back_storage, swap_cost_time: {swap_cost_time:.6f}s, write_cost_time: {write_cost_time:.6f}s"
)
return block_num
elif self.storage_backend_type == "attention_store":
key_cache = []
val_cache = []
for i in range(self.num_layers + self.num_extra_layers):
key_cache.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{self.rank}.device{self.device}"])
val_cache.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{self.rank}.device{self.device}"])
start_time = time.time()
write_block_num = self.storage_backend.write(
task_id, key_cache, val_cache, token_ids, gpu_block_ids, start_write_block_idx, timeout
)
write_cost_time = time.time() - start_time
logger.debug(f"_run_write_back_storage, write_cost_time: {write_cost_time:.6f}s")
return write_block_num
logger.debug(
f"_run_write_back_storage, swap_cost_time: {swap_cost_time:.6f}s, write_cost_time: {write_cost_time:.6f}s"
)
except Exception as e:
logger.error(
f"[rank {self.rank}/{self.n_ranks}] An error occurred in _run_write_back_storage: "
f"error:{e}, {traceback.format_exc()}"
f"An error occurred in _run_write_back_storage, " f"error: {e}, traceback:\n{traceback.format_exc()}"
)
return 0
def write_back_storage_task(self, task_id, keys, gpu_block_ids, timeout=0.1):
def write_back_storage_task(self, task: WriteStorageTask):
"""
Write cache to the storage backend from the GPU memory.
"""
try:
logger.debug(
f"write cache to storage, keys: {keys}, gpu_block_ids: {gpu_block_ids}, "
f"task_id: {task_id}, timeout: {timeout}"
)
k_cache_keys = [f"{key}_key_{self.rank}" for key in keys]
v_cache_keys = [f"{key}_value_{self.rank}" for key in keys]
match_block_num = self._storage_exist_block_num(k_cache_keys, v_cache_keys)
k_cache_keys = k_cache_keys[match_block_num:]
v_cache_keys = v_cache_keys[match_block_num:]
gpu_block_ids = gpu_block_ids[match_block_num:]
gpu_block_ids = task.gpu_block_ids.copy()
cpu_block_ids = [i for i in range(len(gpu_block_ids))]
k_cache_keys = [f"{key}_key_{self.rank}" for key in task.keys]
v_cache_keys = [f"{key}_value_{self.rank}" for key in task.keys]
if len(k_cache_keys) == 0:
logger.info(f"No uncached keys found for task {task_id}")
match_block_num = 0
if self.storage_backend_type == "mooncake":
match_block_num = self.storage_backend.query(k_cache_keys, v_cache_keys, task.timeout)
elif self.storage_backend_type == "attention_store":
match_block_num = self.storage_backend.query(task.task_id, task.token_ids, 0, task.timeout)
logger.info(f"Matched {match_block_num} blocks in cache storage for write task {task.task_id}")
if match_block_num >= len(k_cache_keys):
logger.info(f"No uncached keys found for task {task.task_id}")
gpu_block_ids = []
else:
try:
k_cache_keys = k_cache_keys[match_block_num:]
v_cache_keys = v_cache_keys[match_block_num:]
gpu_block_ids = gpu_block_ids[match_block_num:]
cpu_block_ids = cpu_block_ids[match_block_num:]
# TODO: support timeout with actual block count
self._run_write_back_storage(k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids)
write_block_num = self._run_write_back_storage(
task.task_id,
task.token_ids,
match_block_num,
k_cache_keys,
v_cache_keys,
gpu_block_ids,
cpu_block_ids,
task.timeout,
)
logger.info(
f"Successfully wrote {write_block_num} blocks to cache storage for task {task.task_id}"
)
except Exception as e:
logger.error(f"Error in write back storage task: {e}")
gpu_block_ids = []
result = (CacheStatus.GPU2STORAGE, task_id, keys, gpu_block_ids)
result = (CacheStatus.GPU2STORAGE, task.task_id, task.keys, gpu_block_ids)
self.cache_task_queue.swap_to_storage_barrier.wait()
if self.rank == 0: # 只有当rank为0时执行同步操作
self.cache_task_queue.swap_to_storage_barrier.reset()
self.cache_task_queue.put_transfer_done_signal(result) # 发送传输完成信号
logger.debug(f"write_back_storage_task: put_transfer_done_signal {result}")
logger.info(f"write_back_storage_task: put_transfer_done_signal for transfer_task_id {task_id}")
except Exception as e:
logger.error(
f"[rank {self.rank}/{self.n_ranks}] An error occurred in write_back_storage_task: "
f"error:{e}, {traceback.format_exc()}"
f"An error occurred in write_back_storage_task, " f"error: {e}, traceback:\n{traceback.format_exc()}"
)
def _do_swap_to_cpu_task(
@@ -759,12 +841,12 @@ class CacheTransferManager:
self.cache_task_queue.barrier1.reset()
if self.cache_task_broadcast_signal.value[0] == 1:
data, read_finish = self.cache_task_queue.get_transfer_task()
logger.debug(f"transfer data: get_transfer_task {data}")
logger.debug(f"do_data_transfer: {data}")
if read_finish:
self.cache_task_broadcast_signal.value[0] = 0
event_type, transfer_task_id = data[0], data[1]
event_type, event_args = data[0], data[1:]
if event_type.value == CacheStatus.SWAP2CPU.value:
swap_node_ids, gpu_block_id, cpu_block_id = data[2:]
transfer_task_id, swap_node_ids, gpu_block_id, cpu_block_id = event_args
self.swap_to_cpu_thread_pool.submit(
self._do_swap_to_cpu_task,
swap_node_ids,
@@ -774,7 +856,7 @@ class CacheTransferManager:
transfer_task_id,
)
elif event_type.value == CacheStatus.SWAP2GPU.value:
swap_node_ids, gpu_block_id, cpu_block_id = data[2:]
transfer_task_id, swap_node_ids, gpu_block_id, cpu_block_id = event_args
self.swap_to_gpu_thread_pool.submit(
self._do_swap_to_gpu_task,
swap_node_ids,
@@ -784,22 +866,16 @@ class CacheTransferManager:
transfer_task_id,
)
elif event_type.value == CacheStatus.STORAGE2GPU.value:
hash_keys, gpu_block_ids, timeout = data[2:]
read_storage_task = event_args[0]
self.read_storage_thread_pool.submit(
self.read_storage_task,
transfer_task_id,
hash_keys,
gpu_block_ids,
timeout,
read_storage_task,
)
elif event_type.value == CacheStatus.GPU2STORAGE.value:
hash_keys, gpu_block_ids, timeout = data[2:]
write_storage_task = event_args[0]
self.write_back_storage_thread_pool.submit(
self.write_back_storage_task,
transfer_task_id,
hash_keys,
gpu_block_ids,
timeout,
write_storage_task,
)
else:
if self.n_ranks > 1:
@@ -1047,7 +1123,11 @@ if __name__ == "__main__":
args = parse_args()
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_tprank{args.rank}.log")
if args.mp_num > 1:
logger = get_logger("cache_transfer", f"cache_transfer_{rank_id}.log")
else:
logger = get_logger("cache_transfer", "cache_transfer.log")
logger.info(f"args: {vars(args)}")
set_device(args.device_id)
try: