mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-25 01:55:45 +08:00
[Feature] [KVCache] support attention_store kv cache backend (#5823)
* [feat] support attention_store kv cache backend
* [fix] fix codestyle
* [chore] optimize log
* [fix] fix write storage task
* [fix] fix read storage
* [fix] fix code conflict after merge develop
* [fix] fix cache bytes and read task token ids
* [chore] add model for cache transfer manager
* [chore] add some log
* [chore] remove launched_cache_manager_signal
* [fix] fix write_back_storage_task match_block_num condition
* [fix] fix swap_cost_time
* [ci] fix ci
* Update fastdeploy/engine/sched/resource_manager_v1.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update fastdeploy/cache_manager/cache_transfer_manager.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---------
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -29,6 +29,7 @@ import paddle
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.cache_manager.cache_data import CacheStatus
|
||||
from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask
|
||||
from fastdeploy.cache_manager.ops import (
|
||||
cuda_host_alloc,
|
||||
cuda_host_free,
|
||||
@@ -40,7 +41,7 @@ from fastdeploy.cache_manager.ops import (
|
||||
swap_cache_layout,
|
||||
unset_data_ipc,
|
||||
)
|
||||
from fastdeploy.cache_manager.transfer_factory import MooncakeStore
|
||||
from fastdeploy.cache_manager.transfer_factory import AttentionStore, MooncakeStore
|
||||
from fastdeploy.config import SpeculativeConfig
|
||||
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
|
||||
from fastdeploy.platforms import current_platform
|
||||
@@ -58,6 +59,7 @@ def parse_args():
|
||||
default="mixed",
|
||||
help="splitwise role, can be decode, prefill or mixed",
|
||||
)
|
||||
parser.add_argument("--model_id", type=str, default="default", help="model id")
|
||||
parser.add_argument("--rank", type=int, default=0, help="local tp rank")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="device id")
|
||||
parser.add_argument("--max_model_len", type=int, default=32768, help="max model length")
|
||||
@@ -109,7 +111,7 @@ def parse_args():
|
||||
"--kvcache_storage_backend",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["mooncake", "none"],
|
||||
choices=["mooncake", "attention_store", "none"],
|
||||
help="The storage backend for kvcache storage. If not set, storage backend is disabled.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -133,8 +135,6 @@ class CacheTransferManager:
|
||||
"""
|
||||
初始化CacheTransferManager
|
||||
"""
|
||||
device = args.device_id
|
||||
rank = args.rank
|
||||
self.gpu_cache_kvs = {}
|
||||
self.cpu_cache_kvs = {}
|
||||
self.gpu_cache_k_tensors = []
|
||||
@@ -142,11 +142,31 @@ class CacheTransferManager:
|
||||
self.gpu_cache_scales_k_tensors = []
|
||||
self.gpu_cache_scales_v_tensors = []
|
||||
self.speculative_config = SpeculativeConfig(args.speculative_config)
|
||||
|
||||
# parse kv cache shape
|
||||
self.key_cache_shape = [int(i) for i in args.key_cache_shape.split(",")]
|
||||
self.value_cache_shape = []
|
||||
if args.value_cache_shape:
|
||||
self.value_cache_shape = [int(i) for i in args.value_cache_shape.split(",")]
|
||||
|
||||
# extract kv cache shape into fields
|
||||
self.num_gpu_blocks = self.key_cache_shape[0]
|
||||
self.head_num = self.key_cache_shape[1]
|
||||
self.block_size = self.key_cache_shape[2]
|
||||
self.head_dim = self.key_cache_shape[3]
|
||||
|
||||
# compute cache bytes
|
||||
self.cache_dtype = args.cache_dtype
|
||||
self.cache_bytes = self._get_cache_bytes(self.cache_dtype)
|
||||
|
||||
# extract other arg values
|
||||
self.model_id = args.model_id
|
||||
self.n_ranks = args.mp_num
|
||||
self.rank = args.rank
|
||||
self.device = args.device_id
|
||||
self.num_layers = args.num_layers
|
||||
self.ipc_suffix = args.ipc_suffix
|
||||
self.local_data_parallel_id = args.local_data_parallel_id
|
||||
self.num_extra_layers = self.speculative_config.num_extra_cache_layer
|
||||
self.num_extra_layer_gpu_blocks = int(self.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
|
||||
paddle.set_default_dtype(args.default_dtype)
|
||||
@@ -158,18 +178,13 @@ class CacheTransferManager:
|
||||
self.timeout_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)
|
||||
self.transfer_task_queue = queue.Queue() # 用来接收传输任务
|
||||
self.tansfer_done_queue = queue.Queue() # 用来告知任务执行完毕
|
||||
self.n_ranks = args.mp_num
|
||||
self.rank = rank
|
||||
self.device = device
|
||||
self.ipc_suffix = args.ipc_suffix
|
||||
self.cache_dtype = args.cache_dtype
|
||||
|
||||
address = (args.pod_ip, args.cache_queue_port)
|
||||
self.cache_task_queue = EngineCacheQueue(
|
||||
address=address,
|
||||
is_server=False,
|
||||
num_client=args.mp_num,
|
||||
client_id=rank,
|
||||
client_id=self.rank,
|
||||
local_data_parallel_id=args.local_data_parallel_id,
|
||||
)
|
||||
|
||||
@@ -223,8 +238,22 @@ class CacheTransferManager:
|
||||
self.storage_backend = MooncakeStore(tp_rank=self.rank)
|
||||
self._init_storage_buffer(args)
|
||||
logger.info("Initialized mooncake store successfully")
|
||||
elif args.kvcache_storage_backend == "attention_store":
|
||||
logger.info("Start initialize attention store...")
|
||||
self.storage_backend = AttentionStore(
|
||||
namespace=self.model_id,
|
||||
shard_id=self.rank,
|
||||
shard_num=self.n_ranks,
|
||||
layer_num=self.num_layers + self.num_extra_layers,
|
||||
block_token_size=self.block_size,
|
||||
bytes_per_shard_layer_per_block=self.head_num * self.block_size * self.head_dim * self.cache_bytes,
|
||||
device_id=self.device,
|
||||
dp_id=self.local_data_parallel_id,
|
||||
)
|
||||
logger.info("Initialized attention store successfully!")
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported storage backend: {args.kvcache_storage_backend}")
|
||||
self.storage_backend_type = args.kvcache_storage_backend
|
||||
|
||||
if args.write_policy not in ["write_through"]:
|
||||
raise ValueError(f"Invalid write policy: {args.write_policy}")
|
||||
@@ -246,18 +275,16 @@ class CacheTransferManager:
|
||||
cache layout: layer_num * [block_num, head_num, block_size, head_dim]
|
||||
buffer layout: [block_num, layer_num, head_num, block_size, head_dim]
|
||||
"""
|
||||
layer_num = args.num_layers + self.num_extra_layers
|
||||
head_num = self.key_cache_shape[1]
|
||||
block_size = self.key_cache_shape[2]
|
||||
head_dim = self.key_cache_shape[3]
|
||||
block_num = (args.max_model_len + block_size - 1) // block_size
|
||||
layer_num = self.num_layers + self.num_extra_layers
|
||||
block_num = (args.max_model_len + self.block_size - 1) // self.block_size
|
||||
logger.info(
|
||||
f"Creating cache buffer for storage with shape: "
|
||||
f"[{block_num}, {layer_num}, {head_num}, {block_size}, {head_dim}]"
|
||||
f"[{block_num}, {layer_num}, {self.head_num}, {self.block_size}, {self.head_dim}]"
|
||||
)
|
||||
|
||||
self.cache_bytes = self._get_cache_bytes(self.cache_dtype)
|
||||
self.storage_buffer_stride_bytes = layer_num * head_num * block_size * head_dim * self.cache_bytes
|
||||
self.storage_buffer_stride_bytes = (
|
||||
layer_num * self.head_num * self.block_size * self.head_dim * self.cache_bytes
|
||||
)
|
||||
total_bytes = block_num * self.storage_buffer_stride_bytes * 2 # key and value
|
||||
|
||||
logger.info(f"Creating cpu buffer cache for alllayers: {total_bytes / 1024 ** 3:.2f}GB")
|
||||
@@ -296,8 +323,8 @@ class CacheTransferManager:
|
||||
|
||||
logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.")
|
||||
set_device(self.device)
|
||||
for i in range(args.num_layers + self.num_extra_layers):
|
||||
num_gpu_blocks = self.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
|
||||
for i in range(self.num_layers + self.num_extra_layers):
|
||||
num_gpu_blocks = self.num_gpu_blocks if i < self.num_layers else self.num_extra_layer_gpu_blocks
|
||||
key_name = f"key_caches_{i}_rank{self.rank}.device{self.device}"
|
||||
val_name = f"value_caches_{i}_rank{self.rank}.device{self.device}"
|
||||
key_cache_scales_name = f"key_cache_scales_{i}_rank{self.rank}.device{self.device}"
|
||||
@@ -415,7 +442,7 @@ class CacheTransferManager:
|
||||
self.v_dst_ptrs = []
|
||||
self.k_scales_ptrs = []
|
||||
self.v_scales_ptrs = []
|
||||
for i in range(args.num_layers + self.num_extra_layers):
|
||||
for i in range(self.num_layers + self.num_extra_layers):
|
||||
key_name = f"key_caches_{i}_rank{self.rank}"
|
||||
val_name = f"value_caches_{i}_rank{self.rank}"
|
||||
key_cache_scales_name = f"key_cache_scales_{i}_rank{self.rank}"
|
||||
@@ -446,228 +473,283 @@ class CacheTransferManager:
|
||||
raise ValueError(f"Unsupported cache dtype: {cache_dtype}")
|
||||
return cache_bytes
|
||||
|
||||
def _storage_exist_block_num(self, k_keys: List[str], v_keys: List[str]):
|
||||
def _run_read_storage(
|
||||
self,
|
||||
task_id: str,
|
||||
token_ids: List[int],
|
||||
start_read_block_idx: int,
|
||||
k_cache_keys: List[str],
|
||||
v_cache_keys: List[str],
|
||||
gpu_block_ids: List[int],
|
||||
cpu_block_ids: List[int],
|
||||
timeout: float,
|
||||
):
|
||||
"""
|
||||
Given the k_keys and v_keys, get the valid blocks number that
|
||||
can be prefetched from storage backend.
|
||||
Read storage data from the given blocks to the corresponding cache tensors on the current rank's GPU.
|
||||
"""
|
||||
assert len(k_keys) == len(v_keys), "k_keys and v_keys must have the same length."
|
||||
result = self.storage_backend.exists(k_keys + v_keys)
|
||||
|
||||
# only consider the case when both key and value exist
|
||||
num = 0
|
||||
for k, v in zip(k_keys, v_keys):
|
||||
if result[k] and result[v]:
|
||||
num += 1
|
||||
return num
|
||||
|
||||
def _run_read_storage(self, k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids):
|
||||
try:
|
||||
logger.debug(
|
||||
f"_run_read_storage, key_hash_keys_num: {len(k_cache_keys)}, "
|
||||
f"value_hash_keys_num: {len(v_cache_keys)}, gpu_block_ids_num: {len(gpu_block_ids)}, "
|
||||
f"cpu_block_ids_num: {len(cpu_block_ids)}"
|
||||
)
|
||||
if self.storage_backend_type == "mooncake":
|
||||
block_num = len(gpu_block_ids)
|
||||
keys = k_cache_keys + v_cache_keys
|
||||
k_cache_ptrs = [
|
||||
self.storage_key_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
|
||||
]
|
||||
v_cache_ptrs = [
|
||||
self.storage_value_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
|
||||
]
|
||||
kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
|
||||
kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value
|
||||
start_time = time.time()
|
||||
result = self.storage_backend.batch_get(
|
||||
keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes
|
||||
)
|
||||
read_cost_time = time.time() - start_time
|
||||
|
||||
block_num = len(gpu_block_ids)
|
||||
keys = k_cache_keys + v_cache_keys
|
||||
k_cache_ptrs = [self.storage_key_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids]
|
||||
v_cache_ptrs = [
|
||||
self.storage_value_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
|
||||
]
|
||||
kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
|
||||
kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value
|
||||
start_time = time.time()
|
||||
result = self.storage_backend.batch_get(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes)
|
||||
read_cost_time = time.time() - start_time
|
||||
k_result, v_result = result[:block_num], result[block_num:]
|
||||
success_block_num = 0
|
||||
for k, v in zip(k_result, v_result):
|
||||
if k > 0 and v > 0:
|
||||
success_block_num += 1
|
||||
logger.debug(f"_run_read_storage, success_block_num: {success_block_num}")
|
||||
valid_gpu_block_ids = gpu_block_ids[:success_block_num]
|
||||
valid_cpu_block_ids = cpu_block_ids[:success_block_num]
|
||||
|
||||
k_result, v_result = result[:block_num], result[block_num:]
|
||||
success_block_num = 0
|
||||
for k, v in zip(k_result, v_result):
|
||||
if k > 0 and v > 0:
|
||||
success_block_num += 1
|
||||
logger.debug(f"_run_read_storage, success_block_num: {success_block_num}")
|
||||
valid_gpu_block_ids = gpu_block_ids[:success_block_num]
|
||||
valid_cpu_block_ids = cpu_block_ids[:success_block_num]
|
||||
mode = 1 # cpu ==> gpu
|
||||
start_time = time.time()
|
||||
swap_cache_layout(
|
||||
self.gpu_cache_k_tensors,
|
||||
self.storage_key_read_buffer,
|
||||
self.key_cache_shape,
|
||||
valid_gpu_block_ids,
|
||||
valid_cpu_block_ids,
|
||||
self.device,
|
||||
mode,
|
||||
)
|
||||
swap_cache_layout(
|
||||
self.gpu_cache_v_tensors,
|
||||
self.storage_value_read_buffer,
|
||||
self.value_cache_shape,
|
||||
valid_gpu_block_ids,
|
||||
valid_cpu_block_ids,
|
||||
self.device,
|
||||
mode,
|
||||
)
|
||||
swap_cost_time = time.time() - start_time
|
||||
logger.debug(
|
||||
f"_run_read_storage, swap_cost_time: {swap_cost_time:.6f}s, read_cost_time: {read_cost_time:.6f}s"
|
||||
)
|
||||
|
||||
mode = 1 # cpu ==> gpu
|
||||
start_time = time.time()
|
||||
swap_cache_layout(
|
||||
self.gpu_cache_k_tensors,
|
||||
self.storage_key_read_buffer,
|
||||
self.key_cache_shape,
|
||||
valid_gpu_block_ids,
|
||||
valid_cpu_block_ids,
|
||||
self.device,
|
||||
mode,
|
||||
)
|
||||
swap_cache_layout(
|
||||
self.gpu_cache_v_tensors,
|
||||
self.storage_value_read_buffer,
|
||||
self.value_cache_shape,
|
||||
valid_gpu_block_ids,
|
||||
valid_cpu_block_ids,
|
||||
self.device,
|
||||
mode,
|
||||
)
|
||||
swap_cost_time = time.time() - start_time
|
||||
logger.debug(
|
||||
f"_run_read_storage, swap_cost_time: {swap_cost_time:.6f}s, read_cost_time: {read_cost_time:.6f}s"
|
||||
)
|
||||
elif self.storage_backend_type == "attention_store":
|
||||
key_cache = []
|
||||
val_cache = []
|
||||
for i in range(self.num_layers + self.num_extra_layers):
|
||||
key_cache.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{self.rank}.device{self.device}"])
|
||||
val_cache.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{self.rank}.device{self.device}"])
|
||||
|
||||
start_time = time.time()
|
||||
read_block_num = self.storage_backend.read(
|
||||
task_id, key_cache, val_cache, token_ids, gpu_block_ids, start_read_block_idx, timeout
|
||||
)
|
||||
read_cost_time = time.time() - start_time
|
||||
valid_gpu_block_ids = gpu_block_ids[:read_block_num]
|
||||
logger.debug(f"_run_read_storage, read_cost_time: {read_cost_time:.6f}s")
|
||||
|
||||
return valid_gpu_block_ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[rank {self.rank}/{self.n_ranks}] An error occurred in _run_read_storage: "
|
||||
f"error:{e}, {traceback.format_exc()}"
|
||||
f"An error occurred in _run_read_storage, " f"error: {e}, traceback:\n{traceback.format_exc()}"
|
||||
)
|
||||
raise
|
||||
|
||||
def read_storage_task(self, task_id, keys, gpu_block_ids, timeout=0.1):
|
||||
def read_storage_task(self, task: ReadStorageTask):
|
||||
"""Read cache from the storage backend to the GPU memory."""
|
||||
try:
|
||||
logger.debug(
|
||||
f"read_storage_task, task id: {task_id}, hash_keys_num: {len(keys)}, "
|
||||
f"gpu_block_ids_num: {len(gpu_block_ids)}, timeout: {timeout}"
|
||||
)
|
||||
k_cache_keys = [f"{key}_key_{self.rank}" for key in keys]
|
||||
v_cache_keys = [f"{key}_value_{self.rank}" for key in keys]
|
||||
match_block_num = self._storage_exist_block_num(k_cache_keys, v_cache_keys)
|
||||
logger.debug(f"read_storage_task, match {match_block_num} blocks from storage for task id: {task_id}")
|
||||
gpu_block_ids = task.gpu_block_ids.copy()
|
||||
cpu_block_ids = [i for i in range(len(gpu_block_ids))]
|
||||
k_cache_keys = [f"{key}_key_{self.rank}" for key in task.keys]
|
||||
v_cache_keys = [f"{key}_value_{self.rank}" for key in task.keys]
|
||||
match_block_num = 0
|
||||
if self.storage_backend_type == "mooncake":
|
||||
match_block_num = self.storage_backend.query(k_cache_keys, v_cache_keys)
|
||||
elif self.storage_backend_type == "attention_store":
|
||||
match_block_num = self.storage_backend.query(
|
||||
task.task_id, task.token_ids, task.start_read_block_idx, task.timeout
|
||||
)
|
||||
logger.info(f"Matched {match_block_num} blocks in cache storage for read task {task.task_id}")
|
||||
|
||||
k_cache_keys = k_cache_keys[:match_block_num]
|
||||
v_cache_keys = v_cache_keys[:match_block_num]
|
||||
gpu_block_ids = gpu_block_ids[:match_block_num]
|
||||
cpu_block_ids = [i for i in range(match_block_num)]
|
||||
cpu_block_ids = cpu_block_ids[:match_block_num]
|
||||
valid_gpu_block_ids = []
|
||||
|
||||
if match_block_num > 0:
|
||||
# TODO: support timeout with actual block count
|
||||
try:
|
||||
valid_gpu_block_ids = self._run_read_storage(
|
||||
k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids
|
||||
task.task_id,
|
||||
task.token_ids[: match_block_num * self.block_size],
|
||||
task.start_read_block_idx,
|
||||
k_cache_keys,
|
||||
v_cache_keys,
|
||||
gpu_block_ids,
|
||||
cpu_block_ids,
|
||||
task.timeout,
|
||||
)
|
||||
logger.info(
|
||||
f"read_storage_task, finish loading {match_block_num} blocks from storage for task {task_id}."
|
||||
f"Successfully read {len(valid_gpu_block_ids)} blocks from cache storage for task {task.task_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[rank {self.rank}/{self.n_ranks}] An error occurred: {task_id} {e}")
|
||||
logger.error(f"Failed to read cache for task {task.task_id}, error: {e}")
|
||||
valid_gpu_block_ids = []
|
||||
|
||||
result = (CacheStatus.STORAGE2GPU, task_id, keys, valid_gpu_block_ids)
|
||||
result = (CacheStatus.STORAGE2GPU, task.task_id, task.keys, valid_gpu_block_ids)
|
||||
self.cache_task_queue.swap_storage_to_gpu_barrier.wait()
|
||||
self.cache_task_queue.swap_storage_to_gpu_barrier.reset()
|
||||
self.cache_task_queue.put_transfer_done_signal(result)
|
||||
logger.debug(f"read_storage_task: put_transfer_done_signal {result}")
|
||||
logger.info(
|
||||
f"read_storage_task: put_transfer_done_signal for transfer_task_id {task_id}, "
|
||||
f"valid block num {len(valid_gpu_block_ids)}"
|
||||
)
|
||||
logger.debug(f"read_storage_task: put transfer done signal for {task.task_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[rank {self.rank}/{self.n_ranks}] An error occurred in read_storage_task: "
|
||||
f"task_id: {task_id}, error:{e}, {traceback.format_exc()}"
|
||||
f"An error occurred in read_storage_task: "
|
||||
f"task_id: {task.task_id}, error:{e}, {traceback.format_exc()}"
|
||||
)
|
||||
|
||||
def _run_write_back_storage(self, k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids):
|
||||
def _run_write_back_storage(
|
||||
self,
|
||||
task_id,
|
||||
token_ids,
|
||||
start_write_block_idx,
|
||||
k_cache_keys,
|
||||
v_cache_keys,
|
||||
gpu_block_ids,
|
||||
cpu_block_ids,
|
||||
timeout,
|
||||
):
|
||||
try:
|
||||
logger.debug(
|
||||
f"_run_write_back_storage, k_cache_keys: {k_cache_keys}, v_cache_keys: {v_cache_keys}, "
|
||||
f"gpu_block_ids: {gpu_block_ids}"
|
||||
)
|
||||
key_cache_size = [
|
||||
self.key_cache_shape[0],
|
||||
self.key_cache_shape[1],
|
||||
self.key_cache_shape[2],
|
||||
self.key_cache_shape[3],
|
||||
]
|
||||
if self.storage_backend_type == "mooncake":
|
||||
key_cache_size = [
|
||||
self.key_cache_shape[0],
|
||||
self.key_cache_shape[1],
|
||||
self.key_cache_shape[2],
|
||||
self.key_cache_shape[3],
|
||||
]
|
||||
mode = 0 # gpu ==> cpu
|
||||
start_time = time.time()
|
||||
swap_cache_layout(
|
||||
self.gpu_cache_k_tensors,
|
||||
self.storage_key_write_buffer,
|
||||
key_cache_size,
|
||||
gpu_block_ids,
|
||||
cpu_block_ids,
|
||||
self.device,
|
||||
mode,
|
||||
)
|
||||
swap_cache_layout(
|
||||
self.gpu_cache_v_tensors,
|
||||
self.storage_value_write_buffer,
|
||||
key_cache_size,
|
||||
gpu_block_ids,
|
||||
cpu_block_ids,
|
||||
self.device,
|
||||
mode,
|
||||
)
|
||||
swap_cost_time = time.time() - start_time
|
||||
|
||||
mode = 0 # gpu ==> cpu
|
||||
start_time = time.time()
|
||||
swap_cache_layout(
|
||||
self.gpu_cache_k_tensors,
|
||||
self.storage_key_write_buffer,
|
||||
key_cache_size,
|
||||
gpu_block_ids,
|
||||
cpu_block_ids,
|
||||
self.device,
|
||||
mode,
|
||||
)
|
||||
swap_cache_layout(
|
||||
self.gpu_cache_v_tensors,
|
||||
self.storage_value_write_buffer,
|
||||
key_cache_size,
|
||||
gpu_block_ids,
|
||||
cpu_block_ids,
|
||||
self.device,
|
||||
mode,
|
||||
)
|
||||
swap_cost_time = time.time() - start_time
|
||||
block_num = len(gpu_block_ids)
|
||||
keys = k_cache_keys + v_cache_keys
|
||||
k_cache_ptrs = [
|
||||
self.storage_key_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
|
||||
]
|
||||
v_cache_ptrs = [
|
||||
self.storage_value_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
|
||||
]
|
||||
kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
|
||||
kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value
|
||||
|
||||
block_num = len(gpu_block_ids)
|
||||
keys = k_cache_keys + v_cache_keys
|
||||
k_cache_ptrs = [
|
||||
self.storage_key_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
|
||||
]
|
||||
v_cache_ptrs = [
|
||||
self.storage_value_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
|
||||
]
|
||||
kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
|
||||
kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value
|
||||
start_time = time.time()
|
||||
self.storage_backend.batch_set(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes)
|
||||
write_cost_time = time.time() - start_time
|
||||
start_time = time.time()
|
||||
self.storage_backend.batch_set(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes)
|
||||
write_cost_time = time.time() - start_time
|
||||
|
||||
logger.debug(
|
||||
f"_run_write_back_storage, swap_cost_time: {swap_cost_time:.6f}s, write_cost_time: {write_cost_time:.6f}s"
|
||||
)
|
||||
return block_num
|
||||
|
||||
elif self.storage_backend_type == "attention_store":
|
||||
key_cache = []
|
||||
val_cache = []
|
||||
for i in range(self.num_layers + self.num_extra_layers):
|
||||
key_cache.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{self.rank}.device{self.device}"])
|
||||
val_cache.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{self.rank}.device{self.device}"])
|
||||
|
||||
start_time = time.time()
|
||||
write_block_num = self.storage_backend.write(
|
||||
task_id, key_cache, val_cache, token_ids, gpu_block_ids, start_write_block_idx, timeout
|
||||
)
|
||||
write_cost_time = time.time() - start_time
|
||||
logger.debug(f"_run_write_back_storage, write_cost_time: {write_cost_time:.6f}s")
|
||||
return write_block_num
|
||||
|
||||
logger.debug(
|
||||
f"_run_write_back_storage, swap_cost_time: {swap_cost_time:.6f}s, write_cost_time: {write_cost_time:.6f}s"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[rank {self.rank}/{self.n_ranks}] An error occurred in _run_write_back_storage: "
|
||||
f"error:{e}, {traceback.format_exc()}"
|
||||
f"An error occurred in _run_write_back_storage, " f"error: {e}, traceback:\n{traceback.format_exc()}"
|
||||
)
|
||||
return 0
|
||||
|
||||
def write_back_storage_task(self, task_id, keys, gpu_block_ids, timeout=0.1):
|
||||
def write_back_storage_task(self, task: WriteStorageTask):
|
||||
"""
|
||||
Write cache to the storage backend from the GPU memory.
|
||||
"""
|
||||
try:
|
||||
logger.debug(
|
||||
f"write cache to storage, keys: {keys}, gpu_block_ids: {gpu_block_ids}, "
|
||||
f"task_id: {task_id}, timeout: {timeout}"
|
||||
)
|
||||
|
||||
k_cache_keys = [f"{key}_key_{self.rank}" for key in keys]
|
||||
v_cache_keys = [f"{key}_value_{self.rank}" for key in keys]
|
||||
match_block_num = self._storage_exist_block_num(k_cache_keys, v_cache_keys)
|
||||
|
||||
k_cache_keys = k_cache_keys[match_block_num:]
|
||||
v_cache_keys = v_cache_keys[match_block_num:]
|
||||
gpu_block_ids = gpu_block_ids[match_block_num:]
|
||||
gpu_block_ids = task.gpu_block_ids.copy()
|
||||
cpu_block_ids = [i for i in range(len(gpu_block_ids))]
|
||||
k_cache_keys = [f"{key}_key_{self.rank}" for key in task.keys]
|
||||
v_cache_keys = [f"{key}_value_{self.rank}" for key in task.keys]
|
||||
|
||||
if len(k_cache_keys) == 0:
|
||||
logger.info(f"No uncached keys found for task {task_id}")
|
||||
match_block_num = 0
|
||||
if self.storage_backend_type == "mooncake":
|
||||
match_block_num = self.storage_backend.query(k_cache_keys, v_cache_keys, task.timeout)
|
||||
elif self.storage_backend_type == "attention_store":
|
||||
match_block_num = self.storage_backend.query(task.task_id, task.token_ids, 0, task.timeout)
|
||||
logger.info(f"Matched {match_block_num} blocks in cache storage for write task {task.task_id}")
|
||||
|
||||
if match_block_num >= len(k_cache_keys):
|
||||
logger.info(f"No uncached keys found for task {task.task_id}")
|
||||
gpu_block_ids = []
|
||||
else:
|
||||
try:
|
||||
k_cache_keys = k_cache_keys[match_block_num:]
|
||||
v_cache_keys = v_cache_keys[match_block_num:]
|
||||
gpu_block_ids = gpu_block_ids[match_block_num:]
|
||||
cpu_block_ids = cpu_block_ids[match_block_num:]
|
||||
# TODO: support timeout with actual block count
|
||||
self._run_write_back_storage(k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids)
|
||||
write_block_num = self._run_write_back_storage(
|
||||
task.task_id,
|
||||
task.token_ids,
|
||||
match_block_num,
|
||||
k_cache_keys,
|
||||
v_cache_keys,
|
||||
gpu_block_ids,
|
||||
cpu_block_ids,
|
||||
task.timeout,
|
||||
)
|
||||
logger.info(
|
||||
f"Successfully wrote {write_block_num} blocks to cache storage for task {task.task_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in write back storage task: {e}")
|
||||
gpu_block_ids = []
|
||||
|
||||
result = (CacheStatus.GPU2STORAGE, task_id, keys, gpu_block_ids)
|
||||
result = (CacheStatus.GPU2STORAGE, task.task_id, task.keys, gpu_block_ids)
|
||||
self.cache_task_queue.swap_to_storage_barrier.wait()
|
||||
if self.rank == 0: # 只有当rank为0时执行同步操作
|
||||
self.cache_task_queue.swap_to_storage_barrier.reset()
|
||||
self.cache_task_queue.put_transfer_done_signal(result) # 发送传输完成信号
|
||||
logger.debug(f"write_back_storage_task: put_transfer_done_signal {result}")
|
||||
logger.info(f"write_back_storage_task: put_transfer_done_signal for transfer_task_id {task_id}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[rank {self.rank}/{self.n_ranks}] An error occurred in write_back_storage_task: "
|
||||
f"error:{e}, {traceback.format_exc()}"
|
||||
f"An error occurred in write_back_storage_task, " f"error: {e}, traceback:\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
def _do_swap_to_cpu_task(
|
||||
@@ -759,12 +841,12 @@ class CacheTransferManager:
|
||||
self.cache_task_queue.barrier1.reset()
|
||||
if self.cache_task_broadcast_signal.value[0] == 1:
|
||||
data, read_finish = self.cache_task_queue.get_transfer_task()
|
||||
logger.debug(f"transfer data: get_transfer_task {data}")
|
||||
logger.debug(f"do_data_transfer: {data}")
|
||||
if read_finish:
|
||||
self.cache_task_broadcast_signal.value[0] = 0
|
||||
event_type, transfer_task_id = data[0], data[1]
|
||||
event_type, event_args = data[0], data[1:]
|
||||
if event_type.value == CacheStatus.SWAP2CPU.value:
|
||||
swap_node_ids, gpu_block_id, cpu_block_id = data[2:]
|
||||
transfer_task_id, swap_node_ids, gpu_block_id, cpu_block_id = event_args
|
||||
self.swap_to_cpu_thread_pool.submit(
|
||||
self._do_swap_to_cpu_task,
|
||||
swap_node_ids,
|
||||
@@ -774,7 +856,7 @@ class CacheTransferManager:
|
||||
transfer_task_id,
|
||||
)
|
||||
elif event_type.value == CacheStatus.SWAP2GPU.value:
|
||||
swap_node_ids, gpu_block_id, cpu_block_id = data[2:]
|
||||
transfer_task_id, swap_node_ids, gpu_block_id, cpu_block_id = event_args
|
||||
self.swap_to_gpu_thread_pool.submit(
|
||||
self._do_swap_to_gpu_task,
|
||||
swap_node_ids,
|
||||
@@ -784,22 +866,16 @@ class CacheTransferManager:
|
||||
transfer_task_id,
|
||||
)
|
||||
elif event_type.value == CacheStatus.STORAGE2GPU.value:
|
||||
hash_keys, gpu_block_ids, timeout = data[2:]
|
||||
read_storage_task = event_args[0]
|
||||
self.read_storage_thread_pool.submit(
|
||||
self.read_storage_task,
|
||||
transfer_task_id,
|
||||
hash_keys,
|
||||
gpu_block_ids,
|
||||
timeout,
|
||||
read_storage_task,
|
||||
)
|
||||
elif event_type.value == CacheStatus.GPU2STORAGE.value:
|
||||
hash_keys, gpu_block_ids, timeout = data[2:]
|
||||
write_storage_task = event_args[0]
|
||||
self.write_back_storage_thread_pool.submit(
|
||||
self.write_back_storage_task,
|
||||
transfer_task_id,
|
||||
hash_keys,
|
||||
gpu_block_ids,
|
||||
timeout,
|
||||
write_storage_task,
|
||||
)
|
||||
else:
|
||||
if self.n_ranks > 1:
|
||||
@@ -1047,7 +1123,11 @@ if __name__ == "__main__":
|
||||
|
||||
args = parse_args()
|
||||
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
|
||||
logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_tprank{args.rank}.log")
|
||||
if args.mp_num > 1:
|
||||
logger = get_logger("cache_transfer", f"cache_transfer_{rank_id}.log")
|
||||
else:
|
||||
logger = get_logger("cache_transfer", "cache_transfer.log")
|
||||
|
||||
logger.info(f"args: {vars(args)}")
|
||||
set_device(args.device_id)
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user