mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-22 16:07:51 +08:00
[KVCache] Storage cache supports c8 model (#6298)
* Refine cache transfer manager * Storage cache supports c8 model
This commit is contained in:
@@ -25,25 +25,23 @@ void SwapCacheImpLayout(
|
||||
const std::vector<int64_t>& gpu_block_ids,
|
||||
const std::vector<int64_t>& cpu_block_ids,
|
||||
int mode) {
|
||||
// mode is 0: gpu to cpu; 1: cpu to gpu
|
||||
// cache layout: layer_num * [block_num, head_num, block_size, head_dim]
|
||||
// buffer layout: [block_num, layer_num, head_num, block_size, head_dim]
|
||||
/*
|
||||
mode is 0: gpu to cpu; 1: cpu to gpu
|
||||
|
||||
cache layout: layer_num * [block_num, head_num, block_size, head_dim]
|
||||
scale layout: layer_num * [block_num, head_num, block_size]
|
||||
cache buffer layout: [block_num, layer_num, head_num, block_size, head_dim]
|
||||
scale buffer layout: [block_num, layer_num, head_num, block_size]
|
||||
*/
|
||||
typedef PDTraits<D> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
const int64_t layer_number = cache_gpu_tensors.size();
|
||||
const int64_t num_heads = cache_shape[1];
|
||||
const int64_t block_size = cache_shape[2];
|
||||
const int64_t head_dim = cache_shape[3];
|
||||
const int64_t cache_block_stride = num_heads * block_size * head_dim;
|
||||
|
||||
#ifdef SWAP_DEBUG
|
||||
std::cout << "layer_number:" << layer_number << std::endl;
|
||||
std::cout << "cache_shape:" << cache_shape[0] << ", " << cache_shape[1]
|
||||
<< ", " << cache_shape[2] << ", " << cache_shape[3] << std::endl;
|
||||
std::cout << "cache_block_stride:" << cache_block_stride << std::endl;
|
||||
#endif
|
||||
int64_t cache_block_stride = 1;
|
||||
for (int i = 1; i < cache_shape.size(); i++) {
|
||||
cache_block_stride *= cache_shape[i];
|
||||
}
|
||||
|
||||
auto stream = cache_gpu_tensors[0].stream();
|
||||
const cudaMemcpyKind copy_kind =
|
||||
|
||||
@@ -173,7 +173,11 @@ class CacheTransferManager:
|
||||
|
||||
# compute cache bytes
|
||||
self.cache_dtype = args.cache_dtype
|
||||
self.cache_bytes = self._get_cache_bytes(self.cache_dtype)
|
||||
self.cache_item_bytes = self._get_cache_item_bytes(self.cache_dtype)
|
||||
self.scale_item_bytes = self._get_cache_item_bytes(paddle.get_default_dtype())
|
||||
self.has_cache_scale = self.cache_dtype == "block_wise_fp8"
|
||||
if self.has_cache_scale:
|
||||
self.cache_scale_shape = [self.num_gpu_blocks, self.head_num, self.block_size]
|
||||
|
||||
# extract other arg values
|
||||
self.model_id = os.path.basename(args.model_path.rstrip("/"))
|
||||
@@ -272,6 +276,14 @@ class CacheTransferManager:
|
||||
self.storage_backend_type = args.kvcache_storage_backend
|
||||
|
||||
try:
|
||||
# TODO: support cache scale for other backend
|
||||
if self.has_cache_scale:
|
||||
if self.storage_backend_type not in ["mooncake"]:
|
||||
raise ValueError(
|
||||
f"Unsupported storage backend ({self.storage_backend_type}) "
|
||||
"when cache quantization is block_wise_fp8"
|
||||
)
|
||||
|
||||
if self.storage_backend_type is None:
|
||||
self.storage_backend = None
|
||||
elif self.storage_backend_type == "mooncake":
|
||||
@@ -288,7 +300,10 @@ class CacheTransferManager:
|
||||
shard_num=self.n_ranks,
|
||||
layer_num=self.num_layers + self.num_extra_layers,
|
||||
block_token_size=self.block_size,
|
||||
bytes_per_shard_layer_per_block=self.head_num * self.block_size * self.head_dim * self.cache_bytes,
|
||||
bytes_per_shard_layer_per_block=self.head_num
|
||||
* self.block_size
|
||||
* self.head_dim
|
||||
* self.cache_item_bytes,
|
||||
device_id=self.device,
|
||||
dp_id=self.local_data_parallel_id,
|
||||
)
|
||||
@@ -326,7 +341,9 @@ class CacheTransferManager:
|
||||
"""
|
||||
Initialize pinned memory buffer that can hold the cache for a longest request
|
||||
cache layout: layer_num * [block_num, head_num, block_size, head_dim]
|
||||
buffer layout: [block_num, layer_num, head_num, block_size, head_dim]
|
||||
scale layout: layer_num * [block_num, head_num, block_size]
|
||||
cache buffer layout: [block_num, layer_num, head_num, block_size, head_dim]
|
||||
scale buffer layout: [block_num, layer_num, head_num, block_size]
|
||||
"""
|
||||
layer_num = self.num_layers + self.num_extra_layers
|
||||
block_num = (args.max_model_len + self.block_size - 1) // self.block_size
|
||||
@@ -335,21 +352,38 @@ class CacheTransferManager:
|
||||
f"[{block_num}, {layer_num}, {self.head_num}, {self.block_size}, {self.head_dim}]"
|
||||
)
|
||||
|
||||
self.storage_buffer_stride_bytes = (
|
||||
layer_num * self.head_num * self.block_size * self.head_dim * self.cache_bytes
|
||||
self.cache_buffer_stride_bytes = (
|
||||
layer_num * self.head_num * self.block_size * self.head_dim * self.cache_item_bytes
|
||||
)
|
||||
total_bytes = block_num * self.storage_buffer_stride_bytes * 2 # key and value
|
||||
cache_buffer_total_bytes = block_num * self.cache_buffer_stride_bytes * 2 # key and value
|
||||
|
||||
logger.info(f"Creating cpu buffer cache for all layers: {total_bytes / 1024 ** 3:.2f}GB")
|
||||
read_buffer = cuda_host_alloc(total_bytes)
|
||||
logger.info(f"Creating cache cpu buffer for all layers: {cache_buffer_total_bytes / 1024 ** 3:.2f}GB")
|
||||
read_buffer = cuda_host_alloc(cache_buffer_total_bytes)
|
||||
self.storage_key_read_buffer = read_buffer
|
||||
self.storage_value_read_buffer = read_buffer + total_bytes // 2
|
||||
self.storage_backend.register_buffer(read_buffer, total_bytes)
|
||||
self.storage_value_read_buffer = read_buffer + cache_buffer_total_bytes // 2
|
||||
self.storage_backend.register_buffer(read_buffer, cache_buffer_total_bytes)
|
||||
|
||||
write_buffer = cuda_host_alloc(total_bytes)
|
||||
write_buffer = cuda_host_alloc(cache_buffer_total_bytes)
|
||||
self.storage_key_write_buffer = write_buffer
|
||||
self.storage_value_write_buffer = write_buffer + total_bytes // 2
|
||||
self.storage_backend.register_buffer(write_buffer, total_bytes)
|
||||
self.storage_value_write_buffer = write_buffer + cache_buffer_total_bytes // 2
|
||||
self.storage_backend.register_buffer(write_buffer, cache_buffer_total_bytes)
|
||||
|
||||
if self.has_cache_scale:
|
||||
self.scale_buffer_stride_bytes = layer_num * self.head_num * self.block_size * self.scale_item_bytes
|
||||
scale_buffer_total_bytes = block_num * self.scale_buffer_stride_bytes * 2
|
||||
logger.info(
|
||||
f"Creating scale cpu buffer cache for all layers: {scale_buffer_total_bytes / 1024 ** 3:.2f}GB"
|
||||
)
|
||||
|
||||
read_buffer = cuda_host_alloc(scale_buffer_total_bytes)
|
||||
self.storage_key_scale_read_buffer = read_buffer
|
||||
self.storage_value_scale_read_buffer = read_buffer + scale_buffer_total_bytes // 2
|
||||
self.storage_backend.register_buffer(read_buffer, scale_buffer_total_bytes)
|
||||
|
||||
write_buffer = cuda_host_alloc(scale_buffer_total_bytes)
|
||||
self.storage_key_scale_write_buffer = write_buffer
|
||||
self.storage_value_scale_write_buffer = write_buffer + scale_buffer_total_bytes // 2
|
||||
self.storage_backend.register_buffer(write_buffer, scale_buffer_total_bytes)
|
||||
|
||||
def _init_gpu_cache(self, args):
|
||||
|
||||
@@ -367,6 +401,7 @@ class CacheTransferManager:
|
||||
logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.")
|
||||
set_device(self.device)
|
||||
for i in range(self.num_layers + self.num_extra_layers):
|
||||
# NOTE: num_extra_layer_gpu_blocks is usually equal to num_gpu_blocks
|
||||
num_gpu_blocks = self.num_gpu_blocks if i < self.num_layers else self.num_extra_layer_gpu_blocks
|
||||
key_name = f"key_caches_{i}_rank{self.rank}.device{self.device}"
|
||||
val_name = f"value_caches_{i}_rank{self.rank}.device{self.device}"
|
||||
@@ -464,9 +499,9 @@ class CacheTransferManager:
|
||||
value_cache_size = self.value_cache_shape[1] * self.value_cache_shape[2] * self.value_cache_shape[3]
|
||||
else:
|
||||
value_cache_size = 0
|
||||
cache_bytes = self._get_cache_bytes(self.cache_dtype)
|
||||
key_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * key_cache_size
|
||||
value_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * value_cache_size
|
||||
cache_item_bytes = self._get_cache_item_bytes(self.cache_dtype)
|
||||
key_need_to_allocate_bytes = args.num_cpu_blocks * cache_item_bytes * key_cache_size
|
||||
value_need_to_allocate_bytes = args.num_cpu_blocks * cache_item_bytes * value_cache_size
|
||||
if args.cache_dtype == "block_wise_fp8":
|
||||
cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
|
||||
cache_scales_size = self.key_cache_shape[1] * self.key_cache_shape[2]
|
||||
@@ -507,14 +542,16 @@ class CacheTransferManager:
|
||||
logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ swap space (cpu cache) is ready!")
|
||||
self.swap_space_ready_signal.value[self.rank] = 1
|
||||
|
||||
def _get_cache_bytes(self, cache_dtype):
|
||||
if cache_dtype == "bfloat16":
|
||||
cache_bytes = 2
|
||||
def _get_cache_item_bytes(self, cache_dtype):
|
||||
if cache_dtype == "float32":
|
||||
bytes = 4
|
||||
elif cache_dtype in ("bfloat16", "float16"):
|
||||
bytes = 2
|
||||
elif cache_dtype in ["uint8", "block_wise_fp8"]:
|
||||
cache_bytes = 1
|
||||
bytes = 1
|
||||
else:
|
||||
raise ValueError(f"Unsupported cache dtype: {cache_dtype}")
|
||||
return cache_bytes
|
||||
return bytes
|
||||
|
||||
def _run_read_storage(
|
||||
self,
|
||||
@@ -523,6 +560,8 @@ class CacheTransferManager:
|
||||
start_read_block_idx: int,
|
||||
k_cache_keys: List[str],
|
||||
v_cache_keys: List[str],
|
||||
k_scale_keys: List[str],
|
||||
v_scale_keys: List[str],
|
||||
gpu_block_ids: List[int],
|
||||
cpu_block_ids: List[int],
|
||||
timeout: float,
|
||||
@@ -535,23 +574,45 @@ class CacheTransferManager:
|
||||
block_num = len(gpu_block_ids)
|
||||
keys = k_cache_keys + v_cache_keys
|
||||
k_cache_ptrs = [
|
||||
self.storage_key_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
|
||||
self.storage_key_read_buffer + i * self.cache_buffer_stride_bytes for i in cpu_block_ids
|
||||
]
|
||||
v_cache_ptrs = [
|
||||
self.storage_value_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
|
||||
self.storage_value_read_buffer + i * self.cache_buffer_stride_bytes for i in cpu_block_ids
|
||||
]
|
||||
kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
|
||||
kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value
|
||||
target_locations = k_cache_ptrs + v_cache_ptrs
|
||||
target_sizes = [self.cache_buffer_stride_bytes] * block_num * 2 # key and value
|
||||
if k_scale_keys and v_scale_keys:
|
||||
keys.extend(k_scale_keys + v_scale_keys)
|
||||
k_scale_ptrs = [
|
||||
self.storage_key_scale_read_buffer + i * self.scale_buffer_stride_bytes for i in cpu_block_ids
|
||||
]
|
||||
v_scale_ptrs = [
|
||||
self.storage_value_scale_read_buffer + i * self.scale_buffer_stride_bytes
|
||||
for i in cpu_block_ids
|
||||
]
|
||||
target_locations.extend(k_scale_ptrs + v_scale_ptrs)
|
||||
target_sizes.extend([self.scale_buffer_stride_bytes] * block_num * 2)
|
||||
|
||||
start_time = time.time()
|
||||
result = self.storage_backend.batch_get(
|
||||
keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes
|
||||
keys=keys, target_locations=target_locations, target_sizes=target_sizes
|
||||
)
|
||||
read_cost_time = time.time() - start_time
|
||||
|
||||
k_result, v_result = result[:block_num], result[block_num:]
|
||||
success_block_num = 0
|
||||
for k, v in zip(k_result, v_result):
|
||||
if k > 0 and v > 0:
|
||||
if k_scale_keys and v_scale_keys:
|
||||
k_result, v_result = result[:block_num], result[block_num : 2 * block_num]
|
||||
k_scale_result, v_scale_result = result[2 * block_num : 3 * block_num], result[3 * block_num :]
|
||||
success_block_num = 0
|
||||
for k, v, k_scale, v_scale in zip(k_result, v_result, k_scale_result, v_scale_result):
|
||||
if not (k > 0 and v > 0 and k_scale > 0 and v_scale > 0):
|
||||
break
|
||||
success_block_num += 1
|
||||
else:
|
||||
k_result, v_result = result[:block_num], result[block_num : 2 * block_num]
|
||||
success_block_num = 0
|
||||
for k, v in zip(k_result, v_result):
|
||||
if not (k > 0 and v > 0):
|
||||
break
|
||||
success_block_num += 1
|
||||
logger.debug(f"_run_read_storage, success_block_num: {success_block_num}")
|
||||
valid_gpu_block_ids = gpu_block_ids[:success_block_num]
|
||||
@@ -577,6 +638,25 @@ class CacheTransferManager:
|
||||
self.device,
|
||||
mode,
|
||||
)
|
||||
if k_scale_keys and v_scale_keys:
|
||||
swap_cache_layout(
|
||||
self.gpu_cache_scales_k_tensors,
|
||||
self.storage_key_scale_read_buffer,
|
||||
self.cache_scale_shape,
|
||||
valid_gpu_block_ids,
|
||||
valid_cpu_block_ids,
|
||||
self.device,
|
||||
mode,
|
||||
)
|
||||
swap_cache_layout(
|
||||
self.gpu_cache_scales_v_tensors,
|
||||
self.storage_value_scale_read_buffer,
|
||||
self.cache_scale_shape,
|
||||
valid_gpu_block_ids,
|
||||
valid_cpu_block_ids,
|
||||
self.device,
|
||||
mode,
|
||||
)
|
||||
swap_cost_time = time.time() - start_time
|
||||
logger.debug(
|
||||
f"_run_read_storage, swap_cost_time: {swap_cost_time:.6f}s, read_cost_time: {read_cost_time:.6f}s"
|
||||
@@ -607,14 +687,27 @@ class CacheTransferManager:
|
||||
|
||||
def read_storage_task(self, task: ReadStorageTask):
|
||||
"""Read cache from the storage backend to the GPU memory."""
|
||||
assert (
|
||||
self.storage_backend
|
||||
), f"storage_backend not initialized, storage_backend_type: {self.storage_backend_type}"
|
||||
|
||||
try:
|
||||
gpu_block_ids = task.gpu_block_ids.copy()
|
||||
cpu_block_ids = [i for i in range(len(gpu_block_ids))]
|
||||
k_cache_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_key" for key in task.keys]
|
||||
v_cache_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_value" for key in task.keys]
|
||||
if not self.has_cache_scale:
|
||||
k_scale_keys = None
|
||||
v_scale_keys = None
|
||||
else:
|
||||
k_scale_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_key_scale" for key in task.keys]
|
||||
v_scale_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_value_scale" for key in task.keys]
|
||||
|
||||
match_block_num = 0
|
||||
if self.storage_backend_type in ("mooncake", "file"):
|
||||
match_block_num = self.storage_backend.query(k_cache_keys, v_cache_keys)
|
||||
match_block_num = self.storage_backend.query(
|
||||
k_cache_keys, v_cache_keys, k_scale_keys, v_scale_keys, task.timeout
|
||||
)
|
||||
elif self.storage_backend_type == "attention_store":
|
||||
match_block_num = self.storage_backend.query(
|
||||
task.task_id, task.token_ids, task.start_read_block_idx, task.timeout
|
||||
@@ -623,6 +716,8 @@ class CacheTransferManager:
|
||||
|
||||
k_cache_keys = k_cache_keys[:match_block_num]
|
||||
v_cache_keys = v_cache_keys[:match_block_num]
|
||||
k_scale_keys = k_scale_keys[:match_block_num] if k_scale_keys else None
|
||||
v_scale_keys = v_scale_keys[:match_block_num] if v_scale_keys else None
|
||||
gpu_block_ids = gpu_block_ids[:match_block_num]
|
||||
cpu_block_ids = cpu_block_ids[:match_block_num]
|
||||
valid_gpu_block_ids = []
|
||||
@@ -635,6 +730,8 @@ class CacheTransferManager:
|
||||
task.start_read_block_idx,
|
||||
k_cache_keys,
|
||||
v_cache_keys,
|
||||
k_scale_keys,
|
||||
v_scale_keys,
|
||||
gpu_block_ids,
|
||||
cpu_block_ids,
|
||||
task.timeout,
|
||||
@@ -643,7 +740,9 @@ class CacheTransferManager:
|
||||
f"Successfully read {len(valid_gpu_block_ids)} blocks from cache storage for task {task.task_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read cache for task {task.task_id}, error: {e}")
|
||||
logger.error(
|
||||
f"Failed to read cache for task {task.task_id}, error: {e}, traceback: {traceback.format_exc()}"
|
||||
)
|
||||
valid_gpu_block_ids = []
|
||||
finally:
|
||||
try:
|
||||
@@ -674,24 +773,20 @@ class CacheTransferManager:
|
||||
start_write_block_idx,
|
||||
k_cache_keys,
|
||||
v_cache_keys,
|
||||
k_scale_keys,
|
||||
v_scale_keys,
|
||||
gpu_block_ids,
|
||||
cpu_block_ids,
|
||||
timeout,
|
||||
):
|
||||
try:
|
||||
if self.storage_backend_type in ("mooncake", "file"):
|
||||
key_cache_size = [
|
||||
self.key_cache_shape[0],
|
||||
self.key_cache_shape[1],
|
||||
self.key_cache_shape[2],
|
||||
self.key_cache_shape[3],
|
||||
]
|
||||
mode = 0 # gpu ==> cpu
|
||||
start_time = time.time()
|
||||
swap_cache_layout(
|
||||
self.gpu_cache_k_tensors,
|
||||
self.storage_key_write_buffer,
|
||||
key_cache_size,
|
||||
self.key_cache_shape,
|
||||
gpu_block_ids,
|
||||
cpu_block_ids,
|
||||
self.device,
|
||||
@@ -700,27 +795,57 @@ class CacheTransferManager:
|
||||
swap_cache_layout(
|
||||
self.gpu_cache_v_tensors,
|
||||
self.storage_value_write_buffer,
|
||||
key_cache_size,
|
||||
self.key_cache_shape,
|
||||
gpu_block_ids,
|
||||
cpu_block_ids,
|
||||
self.device,
|
||||
mode,
|
||||
)
|
||||
if k_scale_keys and v_scale_keys:
|
||||
swap_cache_layout(
|
||||
self.gpu_cache_scales_k_tensors,
|
||||
self.storage_key_scale_write_buffer,
|
||||
self.cache_scale_shape,
|
||||
gpu_block_ids,
|
||||
cpu_block_ids,
|
||||
self.device,
|
||||
mode,
|
||||
)
|
||||
swap_cache_layout(
|
||||
self.gpu_cache_scales_v_tensors,
|
||||
self.storage_value_scale_write_buffer,
|
||||
self.cache_scale_shape,
|
||||
gpu_block_ids,
|
||||
cpu_block_ids,
|
||||
self.device,
|
||||
mode,
|
||||
)
|
||||
swap_cost_time = time.time() - start_time
|
||||
|
||||
block_num = len(gpu_block_ids)
|
||||
keys = k_cache_keys + v_cache_keys
|
||||
k_cache_ptrs = [
|
||||
self.storage_key_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
|
||||
self.storage_key_write_buffer + i * self.cache_buffer_stride_bytes for i in cpu_block_ids
|
||||
]
|
||||
v_cache_ptrs = [
|
||||
self.storage_value_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
|
||||
self.storage_value_write_buffer + i * self.cache_buffer_stride_bytes for i in cpu_block_ids
|
||||
]
|
||||
kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
|
||||
kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value
|
||||
target_locations = k_cache_ptrs + v_cache_ptrs
|
||||
target_sizes = [self.cache_buffer_stride_bytes] * block_num * 2 # key and value
|
||||
if k_scale_keys and v_scale_keys:
|
||||
keys.extend(k_scale_keys + v_scale_keys)
|
||||
k_scale_ptrs = [
|
||||
self.storage_key_scale_write_buffer + i * self.scale_buffer_stride_bytes for i in cpu_block_ids
|
||||
]
|
||||
v_scale_ptrs = [
|
||||
self.storage_value_scale_write_buffer + i * self.scale_buffer_stride_bytes
|
||||
for i in cpu_block_ids
|
||||
]
|
||||
target_locations.extend(k_scale_ptrs + v_scale_ptrs)
|
||||
target_sizes.extend([self.scale_buffer_stride_bytes] * block_num * 2)
|
||||
|
||||
start_time = time.time()
|
||||
self.storage_backend.batch_set(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes)
|
||||
self.storage_backend.batch_set(keys=keys, target_locations=target_locations, target_sizes=target_sizes)
|
||||
write_cost_time = time.time() - start_time
|
||||
|
||||
logger.debug(
|
||||
@@ -753,15 +878,27 @@ class CacheTransferManager:
|
||||
"""
|
||||
Write cache to the storage backend from the GPU memory.
|
||||
"""
|
||||
assert (
|
||||
self.storage_backend
|
||||
), f"storage_backend not initialized, storage_backend_type: {self.storage_backend_type}"
|
||||
|
||||
try:
|
||||
gpu_block_ids = task.gpu_block_ids.copy()
|
||||
cpu_block_ids = [i for i in range(len(gpu_block_ids))]
|
||||
k_cache_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_key" for key in task.keys]
|
||||
v_cache_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_value" for key in task.keys]
|
||||
if not self.has_cache_scale:
|
||||
k_scale_keys = None
|
||||
v_scale_keys = None
|
||||
else:
|
||||
k_scale_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_key_scale" for key in task.keys]
|
||||
v_scale_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_value_scale" for key in task.keys]
|
||||
|
||||
match_block_num = 0
|
||||
if self.storage_backend_type == ("mooncake", "file"):
|
||||
match_block_num = self.storage_backend.query(k_cache_keys, v_cache_keys, task.timeout)
|
||||
match_block_num = self.storage_backend.query(
|
||||
k_cache_keys, v_cache_keys, k_scale_keys, v_scale_keys, task.timeout
|
||||
)
|
||||
elif self.storage_backend_type == "attention_store":
|
||||
match_block_num = self.storage_backend.query(task.task_id, task.token_ids, 0, task.timeout)
|
||||
logger.info(f"Matched {match_block_num} blocks in cache storage for write task {task.task_id}")
|
||||
@@ -773,6 +910,8 @@ class CacheTransferManager:
|
||||
try:
|
||||
k_cache_keys = k_cache_keys[match_block_num:]
|
||||
v_cache_keys = v_cache_keys[match_block_num:]
|
||||
k_scale_keys = k_scale_keys[match_block_num:] if k_scale_keys else None
|
||||
v_scale_keys = v_scale_keys[match_block_num:] if v_scale_keys else None
|
||||
gpu_block_ids = gpu_block_ids[match_block_num:]
|
||||
cpu_block_ids = cpu_block_ids[match_block_num:]
|
||||
# TODO: support timeout with actual block count
|
||||
@@ -782,6 +921,8 @@ class CacheTransferManager:
|
||||
match_block_num,
|
||||
k_cache_keys,
|
||||
v_cache_keys,
|
||||
k_scale_keys,
|
||||
v_scale_keys,
|
||||
gpu_block_ids,
|
||||
cpu_block_ids,
|
||||
task.timeout,
|
||||
@@ -790,7 +931,7 @@ class CacheTransferManager:
|
||||
f"Successfully wrote {write_block_num} blocks to cache storage for task {task.task_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in write back storage task: {e}")
|
||||
logger.error(f"Error in write back storage task: {e}, traceback:{traceback.format_exc()}")
|
||||
gpu_block_ids = []
|
||||
finally:
|
||||
try:
|
||||
|
||||
@@ -237,19 +237,43 @@ class MooncakeStore(KVCacheStorage):
|
||||
logger.debug(f"The exists fun processes {len(keys)} objects, cost_time: {cost_time:.3f}ms")
|
||||
return result
|
||||
|
||||
def query(
    self,
    k_keys: List[str],
    v_keys: List[str],
    k_scale_keys: List[str] = None,
    v_scale_keys: List[str] = None,
    timeout: float = 1.0,
):
    """
    Given the k_keys, v_keys, k_scale_keys and v_scale_keys, return the number of
    leading blocks that can be prefetched from the storage backend.

    A block counts only if every key belonging to it exists (key and value, plus
    key scale and value scale when both scale lists are given). Counting stops at
    the first block with a missing key, so the result is the length of the
    fully-present prefix.

    Args:
        k_keys: storage keys of the key-cache blocks, in block order.
        v_keys: storage keys of the value-cache blocks; same length as k_keys.
        k_scale_keys: optional storage keys of the key-scale blocks.
        v_scale_keys: optional storage keys of the value-scale blocks.
            Scales are checked only when both scale lists are not None.
        timeout: accepted for interface compatibility; not used in this method.

    Returns:
        int: number of consecutive blocks, starting from the first, whose keys all exist.
    """
    assert len(k_keys) == len(v_keys), "k_keys and v_keys must have the same length."

    key_groups = [k_keys, v_keys]
    has_scale = k_scale_keys is not None and v_scale_keys is not None
    if has_scale:
        assert (
            len(k_scale_keys) == len(v_scale_keys) == len(k_keys) == len(v_keys)
        ), "k_scale_keys and v_scale_keys must have the same length as k_keys and v_keys."
        key_groups.extend([k_scale_keys, v_scale_keys])

    # One existence lookup for all keys; the result is indexed by key below.
    all_keys = [key for group in key_groups for key in group]
    result = self.exists(all_keys)

    # Only count a block when all of its keys exist; stop at the first miss.
    num = 0
    for block_keys in zip(*key_groups):
        if not all(result[key] for key in block_keys):
            break
        num += 1

    return num
|
||||
|
||||
def delete(self, key, timeout=5) -> bool:
|
||||
@@ -287,12 +311,12 @@ class MooncakeStore(KVCacheStorage):
|
||||
success_num = result.count(0)
|
||||
if success_num == total_num:
|
||||
logger.debug(
|
||||
f"Put all data into Mooncake Store successfully."
|
||||
f"Put all data into Mooncake Store successfully. "
|
||||
f"success_num: {success_num}, cost_time: {cost_time:.6f}s"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Some of the data was not put into Mooncake Store."
|
||||
f"Some of the data was not put into Mooncake Store. "
|
||||
f"total_num: {total_num}, success_num: {success_num}, cost_time: {cost_time:.6f}s"
|
||||
)
|
||||
if success_num > 0:
|
||||
@@ -322,7 +346,7 @@ class MooncakeStore(KVCacheStorage):
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Some of the data was not get from Mooncake Store."
|
||||
f"Some of the data was not get from Mooncake Store. "
|
||||
f"total_num:{total_num}, success_num: {success_num}, cost_time: {cost_time:.6f}s"
|
||||
)
|
||||
if success_num > 0:
|
||||
|
||||
Reference in New Issue
Block a user