[KVCache] Storage cache supports c8 model (#6298)

* Refine cache transfer manager
* Storage cache supports c8 model
This commit is contained in:
jc
2026-02-06 12:01:17 +08:00
committed by GitHub
parent 72fe94cb13
commit d6b3c722c1
3 changed files with 232 additions and 69 deletions
+12 -14
View File
@@ -25,25 +25,23 @@ void SwapCacheImpLayout(
const std::vector<int64_t>& gpu_block_ids,
const std::vector<int64_t>& cpu_block_ids,
int mode) {
// mode is 0: gpu to cpu; 1: cpu to gpu
// cache layout: layer_num * [block_num, head_num, block_size, head_dim]
// buffer layout: [block_num, layer_num, head_num, block_size, head_dim]
/*
mode is 0: gpu to cpu; 1: cpu to gpu
cache layout: layer_num * [block_num, head_num, block_size, head_dim]
scale layout: layer_num * [block_num, head_num, block_size]
cache buffer layout: [block_num, layer_num, head_num, block_size, head_dim]
scale buffer layout: [block_num, layer_num, head_num, block_size]
*/
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
const int64_t layer_number = cache_gpu_tensors.size();
const int64_t num_heads = cache_shape[1];
const int64_t block_size = cache_shape[2];
const int64_t head_dim = cache_shape[3];
const int64_t cache_block_stride = num_heads * block_size * head_dim;
#ifdef SWAP_DEBUG
std::cout << "layer_number:" << layer_number << std::endl;
std::cout << "cache_shape:" << cache_shape[0] << ", " << cache_shape[1]
<< ", " << cache_shape[2] << ", " << cache_shape[3] << std::endl;
std::cout << "cache_block_stride:" << cache_block_stride << std::endl;
#endif
int64_t cache_block_stride = 1;
for (int i = 1; i < cache_shape.size(); i++) {
cache_block_stride *= cache_shape[i];
}
auto stream = cache_gpu_tensors[0].stream();
const cudaMemcpyKind copy_kind =
@@ -173,7 +173,11 @@ class CacheTransferManager:
# compute cache bytes
self.cache_dtype = args.cache_dtype
self.cache_bytes = self._get_cache_bytes(self.cache_dtype)
self.cache_item_bytes = self._get_cache_item_bytes(self.cache_dtype)
self.scale_item_bytes = self._get_cache_item_bytes(paddle.get_default_dtype())
self.has_cache_scale = self.cache_dtype == "block_wise_fp8"
if self.has_cache_scale:
self.cache_scale_shape = [self.num_gpu_blocks, self.head_num, self.block_size]
# extract other arg values
self.model_id = os.path.basename(args.model_path.rstrip("/"))
@@ -272,6 +276,14 @@ class CacheTransferManager:
self.storage_backend_type = args.kvcache_storage_backend
try:
# TODO: support cache scale for other backend
if self.has_cache_scale:
if self.storage_backend_type not in ["mooncake"]:
raise ValueError(
f"Unsupported storage backend ({self.storage_backend_type}) "
"when cache quantization is block_wise_fp8"
)
if self.storage_backend_type is None:
self.storage_backend = None
elif self.storage_backend_type == "mooncake":
@@ -288,7 +300,10 @@ class CacheTransferManager:
shard_num=self.n_ranks,
layer_num=self.num_layers + self.num_extra_layers,
block_token_size=self.block_size,
bytes_per_shard_layer_per_block=self.head_num * self.block_size * self.head_dim * self.cache_bytes,
bytes_per_shard_layer_per_block=self.head_num
* self.block_size
* self.head_dim
* self.cache_item_bytes,
device_id=self.device,
dp_id=self.local_data_parallel_id,
)
@@ -326,7 +341,9 @@ class CacheTransferManager:
"""
Initialize a pinned memory buffer that can hold the cache for the longest request
cache layout: layer_num * [block_num, head_num, block_size, head_dim]
buffer layout: [block_num, layer_num, head_num, block_size, head_dim]
scale layout: layer_num * [block_num, head_num, block_size]
cache buffer layout: [block_num, layer_num, head_num, block_size, head_dim]
scale buffer layout: [block_num, layer_num, head_num, block_size]
"""
layer_num = self.num_layers + self.num_extra_layers
block_num = (args.max_model_len + self.block_size - 1) // self.block_size
@@ -335,21 +352,38 @@ class CacheTransferManager:
f"[{block_num}, {layer_num}, {self.head_num}, {self.block_size}, {self.head_dim}]"
)
self.storage_buffer_stride_bytes = (
layer_num * self.head_num * self.block_size * self.head_dim * self.cache_bytes
self.cache_buffer_stride_bytes = (
layer_num * self.head_num * self.block_size * self.head_dim * self.cache_item_bytes
)
total_bytes = block_num * self.storage_buffer_stride_bytes * 2 # key and value
cache_buffer_total_bytes = block_num * self.cache_buffer_stride_bytes * 2 # key and value
logger.info(f"Creating cpu buffer cache for all layers: {total_bytes / 1024 ** 3:.2f}GB")
read_buffer = cuda_host_alloc(total_bytes)
logger.info(f"Creating cache cpu buffer for all layers: {cache_buffer_total_bytes / 1024 ** 3:.2f}GB")
read_buffer = cuda_host_alloc(cache_buffer_total_bytes)
self.storage_key_read_buffer = read_buffer
self.storage_value_read_buffer = read_buffer + total_bytes // 2
self.storage_backend.register_buffer(read_buffer, total_bytes)
self.storage_value_read_buffer = read_buffer + cache_buffer_total_bytes // 2
self.storage_backend.register_buffer(read_buffer, cache_buffer_total_bytes)
write_buffer = cuda_host_alloc(total_bytes)
write_buffer = cuda_host_alloc(cache_buffer_total_bytes)
self.storage_key_write_buffer = write_buffer
self.storage_value_write_buffer = write_buffer + total_bytes // 2
self.storage_backend.register_buffer(write_buffer, total_bytes)
self.storage_value_write_buffer = write_buffer + cache_buffer_total_bytes // 2
self.storage_backend.register_buffer(write_buffer, cache_buffer_total_bytes)
if self.has_cache_scale:
self.scale_buffer_stride_bytes = layer_num * self.head_num * self.block_size * self.scale_item_bytes
scale_buffer_total_bytes = block_num * self.scale_buffer_stride_bytes * 2
logger.info(
f"Creating scale cpu buffer cache for all layers: {scale_buffer_total_bytes / 1024 ** 3:.2f}GB"
)
read_buffer = cuda_host_alloc(scale_buffer_total_bytes)
self.storage_key_scale_read_buffer = read_buffer
self.storage_value_scale_read_buffer = read_buffer + scale_buffer_total_bytes // 2
self.storage_backend.register_buffer(read_buffer, scale_buffer_total_bytes)
write_buffer = cuda_host_alloc(scale_buffer_total_bytes)
self.storage_key_scale_write_buffer = write_buffer
self.storage_value_scale_write_buffer = write_buffer + scale_buffer_total_bytes // 2
self.storage_backend.register_buffer(write_buffer, scale_buffer_total_bytes)
def _init_gpu_cache(self, args):
@@ -367,6 +401,7 @@ class CacheTransferManager:
logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.")
set_device(self.device)
for i in range(self.num_layers + self.num_extra_layers):
# NOTE: num_extra_layer_gpu_blocks is usually equal to num_gpu_blocks
num_gpu_blocks = self.num_gpu_blocks if i < self.num_layers else self.num_extra_layer_gpu_blocks
key_name = f"key_caches_{i}_rank{self.rank}.device{self.device}"
val_name = f"value_caches_{i}_rank{self.rank}.device{self.device}"
@@ -464,9 +499,9 @@ class CacheTransferManager:
value_cache_size = self.value_cache_shape[1] * self.value_cache_shape[2] * self.value_cache_shape[3]
else:
value_cache_size = 0
cache_bytes = self._get_cache_bytes(self.cache_dtype)
key_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * key_cache_size
value_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * value_cache_size
cache_item_bytes = self._get_cache_item_bytes(self.cache_dtype)
key_need_to_allocate_bytes = args.num_cpu_blocks * cache_item_bytes * key_cache_size
value_need_to_allocate_bytes = args.num_cpu_blocks * cache_item_bytes * value_cache_size
if args.cache_dtype == "block_wise_fp8":
cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
cache_scales_size = self.key_cache_shape[1] * self.key_cache_shape[2]
@@ -507,14 +542,16 @@ class CacheTransferManager:
logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ swap space (cpu cache) is ready!")
self.swap_space_ready_signal.value[self.rank] = 1
def _get_cache_bytes(self, cache_dtype):
if cache_dtype == "bfloat16":
cache_bytes = 2
def _get_cache_item_bytes(self, cache_dtype):
if cache_dtype == "float32":
bytes = 4
elif cache_dtype in ("bfloat16", "float16"):
bytes = 2
elif cache_dtype in ["uint8", "block_wise_fp8"]:
cache_bytes = 1
bytes = 1
else:
raise ValueError(f"Unsupported cache dtype: {cache_dtype}")
return cache_bytes
return bytes
def _run_read_storage(
self,
@@ -523,6 +560,8 @@ class CacheTransferManager:
start_read_block_idx: int,
k_cache_keys: List[str],
v_cache_keys: List[str],
k_scale_keys: List[str],
v_scale_keys: List[str],
gpu_block_ids: List[int],
cpu_block_ids: List[int],
timeout: float,
@@ -535,23 +574,45 @@ class CacheTransferManager:
block_num = len(gpu_block_ids)
keys = k_cache_keys + v_cache_keys
k_cache_ptrs = [
self.storage_key_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
self.storage_key_read_buffer + i * self.cache_buffer_stride_bytes for i in cpu_block_ids
]
v_cache_ptrs = [
self.storage_value_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
self.storage_value_read_buffer + i * self.cache_buffer_stride_bytes for i in cpu_block_ids
]
kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value
target_locations = k_cache_ptrs + v_cache_ptrs
target_sizes = [self.cache_buffer_stride_bytes] * block_num * 2 # key and value
if k_scale_keys and v_scale_keys:
keys.extend(k_scale_keys + v_scale_keys)
k_scale_ptrs = [
self.storage_key_scale_read_buffer + i * self.scale_buffer_stride_bytes for i in cpu_block_ids
]
v_scale_ptrs = [
self.storage_value_scale_read_buffer + i * self.scale_buffer_stride_bytes
for i in cpu_block_ids
]
target_locations.extend(k_scale_ptrs + v_scale_ptrs)
target_sizes.extend([self.scale_buffer_stride_bytes] * block_num * 2)
start_time = time.time()
result = self.storage_backend.batch_get(
keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes
keys=keys, target_locations=target_locations, target_sizes=target_sizes
)
read_cost_time = time.time() - start_time
k_result, v_result = result[:block_num], result[block_num:]
success_block_num = 0
for k, v in zip(k_result, v_result):
if k > 0 and v > 0:
if k_scale_keys and v_scale_keys:
k_result, v_result = result[:block_num], result[block_num : 2 * block_num]
k_scale_result, v_scale_result = result[2 * block_num : 3 * block_num], result[3 * block_num :]
success_block_num = 0
for k, v, k_scale, v_scale in zip(k_result, v_result, k_scale_result, v_scale_result):
if not (k > 0 and v > 0 and k_scale > 0 and v_scale > 0):
break
success_block_num += 1
else:
k_result, v_result = result[:block_num], result[block_num : 2 * block_num]
success_block_num = 0
for k, v in zip(k_result, v_result):
if not (k > 0 and v > 0):
break
success_block_num += 1
logger.debug(f"_run_read_storage, success_block_num: {success_block_num}")
valid_gpu_block_ids = gpu_block_ids[:success_block_num]
@@ -577,6 +638,25 @@ class CacheTransferManager:
self.device,
mode,
)
if k_scale_keys and v_scale_keys:
swap_cache_layout(
self.gpu_cache_scales_k_tensors,
self.storage_key_scale_read_buffer,
self.cache_scale_shape,
valid_gpu_block_ids,
valid_cpu_block_ids,
self.device,
mode,
)
swap_cache_layout(
self.gpu_cache_scales_v_tensors,
self.storage_value_scale_read_buffer,
self.cache_scale_shape,
valid_gpu_block_ids,
valid_cpu_block_ids,
self.device,
mode,
)
swap_cost_time = time.time() - start_time
logger.debug(
f"_run_read_storage, swap_cost_time: {swap_cost_time:.6f}s, read_cost_time: {read_cost_time:.6f}s"
@@ -607,14 +687,27 @@ class CacheTransferManager:
def read_storage_task(self, task: ReadStorageTask):
"""Read cache from the storage backend to the GPU memory."""
assert (
self.storage_backend
), f"storage_backend not initialized, storage_backend_type: {self.storage_backend_type}"
try:
gpu_block_ids = task.gpu_block_ids.copy()
cpu_block_ids = [i for i in range(len(gpu_block_ids))]
k_cache_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_key" for key in task.keys]
v_cache_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_value" for key in task.keys]
if not self.has_cache_scale:
k_scale_keys = None
v_scale_keys = None
else:
k_scale_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_key_scale" for key in task.keys]
v_scale_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_value_scale" for key in task.keys]
match_block_num = 0
if self.storage_backend_type in ("mooncake", "file"):
match_block_num = self.storage_backend.query(k_cache_keys, v_cache_keys)
match_block_num = self.storage_backend.query(
k_cache_keys, v_cache_keys, k_scale_keys, v_scale_keys, task.timeout
)
elif self.storage_backend_type == "attention_store":
match_block_num = self.storage_backend.query(
task.task_id, task.token_ids, task.start_read_block_idx, task.timeout
@@ -623,6 +716,8 @@ class CacheTransferManager:
k_cache_keys = k_cache_keys[:match_block_num]
v_cache_keys = v_cache_keys[:match_block_num]
k_scale_keys = k_scale_keys[:match_block_num] if k_scale_keys else None
v_scale_keys = v_scale_keys[:match_block_num] if v_scale_keys else None
gpu_block_ids = gpu_block_ids[:match_block_num]
cpu_block_ids = cpu_block_ids[:match_block_num]
valid_gpu_block_ids = []
@@ -635,6 +730,8 @@ class CacheTransferManager:
task.start_read_block_idx,
k_cache_keys,
v_cache_keys,
k_scale_keys,
v_scale_keys,
gpu_block_ids,
cpu_block_ids,
task.timeout,
@@ -643,7 +740,9 @@ class CacheTransferManager:
f"Successfully read {len(valid_gpu_block_ids)} blocks from cache storage for task {task.task_id}"
)
except Exception as e:
logger.error(f"Failed to read cache for task {task.task_id}, error: {e}")
logger.error(
f"Failed to read cache for task {task.task_id}, error: {e}, traceback: {traceback.format_exc()}"
)
valid_gpu_block_ids = []
finally:
try:
@@ -674,24 +773,20 @@ class CacheTransferManager:
start_write_block_idx,
k_cache_keys,
v_cache_keys,
k_scale_keys,
v_scale_keys,
gpu_block_ids,
cpu_block_ids,
timeout,
):
try:
if self.storage_backend_type in ("mooncake", "file"):
key_cache_size = [
self.key_cache_shape[0],
self.key_cache_shape[1],
self.key_cache_shape[2],
self.key_cache_shape[3],
]
mode = 0 # gpu ==> cpu
start_time = time.time()
swap_cache_layout(
self.gpu_cache_k_tensors,
self.storage_key_write_buffer,
key_cache_size,
self.key_cache_shape,
gpu_block_ids,
cpu_block_ids,
self.device,
@@ -700,27 +795,57 @@ class CacheTransferManager:
swap_cache_layout(
self.gpu_cache_v_tensors,
self.storage_value_write_buffer,
key_cache_size,
self.key_cache_shape,
gpu_block_ids,
cpu_block_ids,
self.device,
mode,
)
if k_scale_keys and v_scale_keys:
swap_cache_layout(
self.gpu_cache_scales_k_tensors,
self.storage_key_scale_write_buffer,
self.cache_scale_shape,
gpu_block_ids,
cpu_block_ids,
self.device,
mode,
)
swap_cache_layout(
self.gpu_cache_scales_v_tensors,
self.storage_value_scale_write_buffer,
self.cache_scale_shape,
gpu_block_ids,
cpu_block_ids,
self.device,
mode,
)
swap_cost_time = time.time() - start_time
block_num = len(gpu_block_ids)
keys = k_cache_keys + v_cache_keys
k_cache_ptrs = [
self.storage_key_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
self.storage_key_write_buffer + i * self.cache_buffer_stride_bytes for i in cpu_block_ids
]
v_cache_ptrs = [
self.storage_value_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
self.storage_value_write_buffer + i * self.cache_buffer_stride_bytes for i in cpu_block_ids
]
kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value
target_locations = k_cache_ptrs + v_cache_ptrs
target_sizes = [self.cache_buffer_stride_bytes] * block_num * 2 # key and value
if k_scale_keys and v_scale_keys:
keys.extend(k_scale_keys + v_scale_keys)
k_scale_ptrs = [
self.storage_key_scale_write_buffer + i * self.scale_buffer_stride_bytes for i in cpu_block_ids
]
v_scale_ptrs = [
self.storage_value_scale_write_buffer + i * self.scale_buffer_stride_bytes
for i in cpu_block_ids
]
target_locations.extend(k_scale_ptrs + v_scale_ptrs)
target_sizes.extend([self.scale_buffer_stride_bytes] * block_num * 2)
start_time = time.time()
self.storage_backend.batch_set(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes)
self.storage_backend.batch_set(keys=keys, target_locations=target_locations, target_sizes=target_sizes)
write_cost_time = time.time() - start_time
logger.debug(
@@ -753,15 +878,27 @@ class CacheTransferManager:
"""
Write cache to the storage backend from the GPU memory.
"""
assert (
self.storage_backend
), f"storage_backend not initialized, storage_backend_type: {self.storage_backend_type}"
try:
gpu_block_ids = task.gpu_block_ids.copy()
cpu_block_ids = [i for i in range(len(gpu_block_ids))]
k_cache_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_key" for key in task.keys]
v_cache_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_value" for key in task.keys]
if not self.has_cache_scale:
k_scale_keys = None
v_scale_keys = None
else:
k_scale_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_key_scale" for key in task.keys]
v_scale_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_value_scale" for key in task.keys]
match_block_num = 0
if self.storage_backend_type == ("mooncake", "file"):
match_block_num = self.storage_backend.query(k_cache_keys, v_cache_keys, task.timeout)
match_block_num = self.storage_backend.query(
k_cache_keys, v_cache_keys, k_scale_keys, v_scale_keys, task.timeout
)
elif self.storage_backend_type == "attention_store":
match_block_num = self.storage_backend.query(task.task_id, task.token_ids, 0, task.timeout)
logger.info(f"Matched {match_block_num} blocks in cache storage for write task {task.task_id}")
@@ -773,6 +910,8 @@ class CacheTransferManager:
try:
k_cache_keys = k_cache_keys[match_block_num:]
v_cache_keys = v_cache_keys[match_block_num:]
k_scale_keys = k_scale_keys[match_block_num:] if k_scale_keys else None
v_scale_keys = v_scale_keys[match_block_num:] if v_scale_keys else None
gpu_block_ids = gpu_block_ids[match_block_num:]
cpu_block_ids = cpu_block_ids[match_block_num:]
# TODO: support timeout with actual block count
@@ -782,6 +921,8 @@ class CacheTransferManager:
match_block_num,
k_cache_keys,
v_cache_keys,
k_scale_keys,
v_scale_keys,
gpu_block_ids,
cpu_block_ids,
task.timeout,
@@ -790,7 +931,7 @@ class CacheTransferManager:
f"Successfully wrote {write_block_num} blocks to cache storage for task {task.task_id}"
)
except Exception as e:
logger.error(f"Error in write back storage task: {e}")
logger.error(f"Error in write back storage task: {e}, traceback:{traceback.format_exc()}")
gpu_block_ids = []
finally:
try:
@@ -237,19 +237,43 @@ class MooncakeStore(KVCacheStorage):
logger.debug(f"The exists fun processes {len(keys)} objects, cost_time: {cost_time:.3f}ms")
return result
def query(
    self,
    k_keys: List[str],
    v_keys: List[str],
    k_scale_keys: List[str] = None,
    v_scale_keys: List[str] = None,
    timeout: float = 1.0,
):
    """
    Count how many leading blocks can be prefetched from the storage backend.

    A block counts only if every key belonging to it exists: its key and
    value entries, plus its key-scale and value-scale entries when both
    scale key lists are supplied. Counting stops at the first block with
    a missing entry, so the result is the length of the fully present prefix.

    Args:
        k_keys: storage keys of the key-cache blocks.
        v_keys: storage keys of the value-cache blocks; same length as ``k_keys``.
        k_scale_keys: optional storage keys of the key-scale blocks.
        v_scale_keys: optional storage keys of the value-scale blocks. Scales
            are checked only when both scale lists are not None.
        timeout: accepted for interface compatibility; not used in this method.

    Returns:
        int: number of consecutive blocks, from index 0, whose entries all exist.
    """
    assert len(k_keys) == len(v_keys), "k_keys and v_keys must have the same length."

    key_groups = [k_keys, v_keys]
    if k_scale_keys is not None and v_scale_keys is not None:
        assert (
            len(k_scale_keys) == len(v_scale_keys) == len(k_keys) == len(v_keys)
        ), "k_scale_keys and v_scale_keys must have the same length as k_keys and v_keys."
        key_groups.extend([k_scale_keys, v_scale_keys])

    # One batched existence lookup; keys are ordered group by group
    # (all keys, all values, then all key scales and all value scales).
    all_keys = [key for group in key_groups for key in group]
    result = self.exists(all_keys)

    # Walk the blocks in order and stop at the first one with a missing entry.
    num = 0
    for block_keys in zip(*key_groups):
        if not all(result[key] for key in block_keys):
            break
        num += 1
    return num
def delete(self, key, timeout=5) -> bool:
@@ -287,12 +311,12 @@ class MooncakeStore(KVCacheStorage):
success_num = result.count(0)
if success_num == total_num:
logger.debug(
f"Put all data into Mooncake Store successfully."
f"Put all data into Mooncake Store successfully. "
f"success_num: {success_num}, cost_time: {cost_time:.6f}s"
)
else:
logger.error(
f"Some of the data was not put into Mooncake Store."
f"Some of the data was not put into Mooncake Store. "
f"total_num: {total_num}, success_num: {success_num}, cost_time: {cost_time:.6f}s"
)
if success_num > 0:
@@ -322,7 +346,7 @@ class MooncakeStore(KVCacheStorage):
)
else:
logger.error(
f"Some of the data was not get from Mooncake Store."
f"Some of the data was not get from Mooncake Store. "
f"total_num:{total_num}, success_num: {success_num}, cost_time: {cost_time:.6f}s"
)
if success_num > 0: