From d6b3c722c1ad7985bbc6290e7b8bebb23e1b04b5 Mon Sep 17 00:00:00 2001 From: jc <52520497+juncaipeng@users.noreply.github.com> Date: Fri, 6 Feb 2026 12:01:17 +0800 Subject: [PATCH] [KVCache] Storage cache supports c8 model (#6298) * Refine cache transfer manager * Storage cache supports c8 model --- custom_ops/gpu_ops/swap_cache_layout.cu | 26 +- .../cache_manager/cache_transfer_manager.py | 235 ++++++++++++++---- .../mooncake_store/mooncake_store.py | 40 ++- 3 files changed, 232 insertions(+), 69 deletions(-) diff --git a/custom_ops/gpu_ops/swap_cache_layout.cu b/custom_ops/gpu_ops/swap_cache_layout.cu index edd31c9a07..62adccb2d0 100644 --- a/custom_ops/gpu_ops/swap_cache_layout.cu +++ b/custom_ops/gpu_ops/swap_cache_layout.cu @@ -25,25 +25,23 @@ void SwapCacheImpLayout( const std::vector& gpu_block_ids, const std::vector& cpu_block_ids, int mode) { - // mode is 0: gpu to cpu; 1: cpu to gpu - // cache layout: layer_num * [block_num, head_num, block_size, head_dim] - // buffer layout: [block_num, layer_num, head_num, block_size, head_dim] + /* + mode is 0: gpu to cpu; 1: cpu to gpu + + cache layout: layer_num * [block_num, head_num, block_size, head_dim] + scale layout: layer_num * [block_num, head_num, block_size] + cache buffer layout: [block_num, layer_num, head_num, block_size, head_dim] + scale buffer layout: [block_num, layer_num, head_num, block_size] + */ typedef PDTraits traits_; typedef typename traits_::DataType DataType_; typedef typename traits_::data_t data_t; const int64_t layer_number = cache_gpu_tensors.size(); - const int64_t num_heads = cache_shape[1]; - const int64_t block_size = cache_shape[2]; - const int64_t head_dim = cache_shape[3]; - const int64_t cache_block_stride = num_heads * block_size * head_dim; - -#ifdef SWAP_DEBUG - std::cout << "layer_number:" << layer_number << std::endl; - std::cout << "cache_shape:" << cache_shape[0] << ", " << cache_shape[1] - << ", " << cache_shape[2] << ", " << cache_shape[3] << std::endl; - std::cout << "cache_block_stride:" << cache_block_stride << std::endl; -#endif + int64_t cache_block_stride = 1; + for (int i = 1; i < cache_shape.size(); i++) { + cache_block_stride *= cache_shape[i]; + } auto stream = cache_gpu_tensors[0].stream(); const cudaMemcpyKind copy_kind = diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index 5475d2d50f..670d525e40 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -173,7 +173,11 @@ class CacheTransferManager: # compute cache bytes self.cache_dtype = args.cache_dtype - self.cache_bytes = self._get_cache_bytes(self.cache_dtype) + self.cache_item_bytes = self._get_cache_item_bytes(self.cache_dtype) + self.scale_item_bytes = self._get_cache_item_bytes(paddle.get_default_dtype()) + self.has_cache_scale = self.cache_dtype == "block_wise_fp8" + if self.has_cache_scale: + self.cache_scale_shape = [self.num_gpu_blocks, self.head_num, self.block_size] # extract other arg values self.model_id = os.path.basename(args.model_path.rstrip("/")) @@ -272,6 +276,14 @@ class CacheTransferManager: self.storage_backend_type = args.kvcache_storage_backend try: + # TODO: support cache scale for other backend + if self.has_cache_scale: + if self.storage_backend_type not in ["mooncake"]: + raise ValueError( + f"Unsupported storage backend ({self.storage_backend_type}) " + "when cache quantization is block_wise_fp8" + ) + if self.storage_backend_type is None: self.storage_backend = None elif self.storage_backend_type == "mooncake": @@ -288,7 +300,10 @@ class CacheTransferManager: shard_num=self.n_ranks, layer_num=self.num_layers + self.num_extra_layers, block_token_size=self.block_size, - bytes_per_shard_layer_per_block=self.head_num * self.block_size * self.head_dim * self.cache_bytes, + bytes_per_shard_layer_per_block=self.head_num + * self.block_size + * self.head_dim + * self.cache_item_bytes, device_id=self.device, dp_id=self.local_data_parallel_id, ) @@ -326,7 +341,9 @@ class CacheTransferManager: """ Initialize pinned memory buffer that can hold the cache for a longest request cache layout: layer_num * [block_num, head_num, block_size, head_dim] - buffer layout: [block_num, layer_num, head_num, block_size, head_dim] + scale layout: layer_num * [block_num, head_num, block_size] + cache buffer layout: [block_num, layer_num, head_num, block_size, head_dim] + scale buffer layout: [block_num, layer_num, head_num, block_size] """ layer_num = self.num_layers + self.num_extra_layers block_num = (args.max_model_len + self.block_size - 1) // self.block_size @@ -335,21 +352,38 @@ class CacheTransferManager: f"[{block_num}, {layer_num}, {self.head_num}, {self.block_size}, {self.head_dim}]" ) - self.storage_buffer_stride_bytes = ( - layer_num * self.head_num * self.block_size * self.head_dim * self.cache_bytes + self.cache_buffer_stride_bytes = ( + layer_num * self.head_num * self.block_size * self.head_dim * self.cache_item_bytes ) - total_bytes = block_num * self.storage_buffer_stride_bytes * 2 # key and value + cache_buffer_total_bytes = block_num * self.cache_buffer_stride_bytes * 2 # key and value - logger.info(f"Creating cpu buffer cache for all layers: {total_bytes / 1024 ** 3:.2f}GB") - read_buffer = cuda_host_alloc(total_bytes) + logger.info(f"Creating cache cpu buffer for all layers: {cache_buffer_total_bytes / 1024 ** 3:.2f}GB") + read_buffer = cuda_host_alloc(cache_buffer_total_bytes) self.storage_key_read_buffer = read_buffer - self.storage_value_read_buffer = read_buffer + total_bytes // 2 - self.storage_backend.register_buffer(read_buffer, total_bytes) + self.storage_value_read_buffer = read_buffer + cache_buffer_total_bytes // 2 + self.storage_backend.register_buffer(read_buffer, cache_buffer_total_bytes) - write_buffer = cuda_host_alloc(total_bytes) + write_buffer = cuda_host_alloc(cache_buffer_total_bytes) self.storage_key_write_buffer = write_buffer - self.storage_value_write_buffer = write_buffer + total_bytes // 2 - self.storage_backend.register_buffer(write_buffer, total_bytes) + self.storage_value_write_buffer = write_buffer + cache_buffer_total_bytes // 2 + self.storage_backend.register_buffer(write_buffer, cache_buffer_total_bytes) + + if self.has_cache_scale: + self.scale_buffer_stride_bytes = layer_num * self.head_num * self.block_size * self.scale_item_bytes + scale_buffer_total_bytes = block_num * self.scale_buffer_stride_bytes * 2 + logger.info( + f"Creating scale cpu buffer cache for all layers: {scale_buffer_total_bytes / 1024 ** 3:.2f}GB" + ) + + read_buffer = cuda_host_alloc(scale_buffer_total_bytes) + self.storage_key_scale_read_buffer = read_buffer + self.storage_value_scale_read_buffer = read_buffer + scale_buffer_total_bytes // 2 + self.storage_backend.register_buffer(read_buffer, scale_buffer_total_bytes) + + write_buffer = cuda_host_alloc(scale_buffer_total_bytes) + self.storage_key_scale_write_buffer = write_buffer + self.storage_value_scale_write_buffer = write_buffer + scale_buffer_total_bytes // 2 + self.storage_backend.register_buffer(write_buffer, scale_buffer_total_bytes) def _init_gpu_cache(self, args): @@ -367,6 +401,7 @@ class CacheTransferManager: logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.") set_device(self.device) for i in range(self.num_layers + self.num_extra_layers): + # NOTE: num_extra_layer_gpu_blocks is usually equal to num_gpu_blocks num_gpu_blocks = self.num_gpu_blocks if i < self.num_layers else self.num_extra_layer_gpu_blocks key_name = f"key_caches_{i}_rank{self.rank}.device{self.device}" val_name = f"value_caches_{i}_rank{self.rank}.device{self.device}" @@ -464,9 +499,9 @@ class CacheTransferManager: value_cache_size = self.value_cache_shape[1] * self.value_cache_shape[2] * self.value_cache_shape[3] else: value_cache_size = 0 - cache_bytes = self._get_cache_bytes(self.cache_dtype) - key_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * key_cache_size - value_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * value_cache_size + cache_item_bytes = self._get_cache_item_bytes(self.cache_dtype) + key_need_to_allocate_bytes = args.num_cpu_blocks * cache_item_bytes * key_cache_size + value_need_to_allocate_bytes = args.num_cpu_blocks * cache_item_bytes * value_cache_size if args.cache_dtype == "block_wise_fp8": cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype()) cache_scales_size = self.key_cache_shape[1] * self.key_cache_shape[2] @@ -507,14 +542,16 @@ class CacheTransferManager: logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ swap space (cpu cache) is ready!") self.swap_space_ready_signal.value[self.rank] = 1 - def _get_cache_bytes(self, cache_dtype): - if cache_dtype == "bfloat16": - cache_bytes = 2 + def _get_cache_item_bytes(self, cache_dtype): + if cache_dtype == "float32": + bytes = 4 + elif cache_dtype in ("bfloat16", "float16"): + bytes = 2 elif cache_dtype in ["uint8", "block_wise_fp8"]: - cache_bytes = 1 + bytes = 1 else: raise ValueError(f"Unsupported cache dtype: {cache_dtype}") - return cache_bytes + return bytes def _run_read_storage( self, @@ -523,6 +560,8 @@ class CacheTransferManager: start_read_block_idx: int, k_cache_keys: List[str], v_cache_keys: List[str], + k_scale_keys: List[str], + v_scale_keys: List[str], gpu_block_ids: List[int], cpu_block_ids: List[int], timeout: float, @@ -535,23 +574,45 @@ class CacheTransferManager: block_num = len(gpu_block_ids) keys = k_cache_keys + v_cache_keys k_cache_ptrs = [ - self.storage_key_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids + self.storage_key_read_buffer + i * self.cache_buffer_stride_bytes for i in cpu_block_ids ] v_cache_ptrs = [ - self.storage_value_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids + self.storage_value_read_buffer + i * self.cache_buffer_stride_bytes for i in cpu_block_ids ] - kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs - kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value + target_locations = k_cache_ptrs + v_cache_ptrs + target_sizes = [self.cache_buffer_stride_bytes] * block_num * 2 # key and value + if k_scale_keys and v_scale_keys: + keys.extend(k_scale_keys + v_scale_keys) + k_scale_ptrs = [ + self.storage_key_scale_read_buffer + i * self.scale_buffer_stride_bytes for i in cpu_block_ids + ] + v_scale_ptrs = [ + self.storage_value_scale_read_buffer + i * self.scale_buffer_stride_bytes + for i in cpu_block_ids + ] + target_locations.extend(k_scale_ptrs + v_scale_ptrs) + target_sizes.extend([self.scale_buffer_stride_bytes] * block_num * 2) + start_time = time.time() result = self.storage_backend.batch_get( - keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes + keys=keys, target_locations=target_locations, target_sizes=target_sizes ) read_cost_time = time.time() - start_time - k_result, v_result = result[:block_num], result[block_num:] - success_block_num = 0 - for k, v in zip(k_result, v_result): - if k > 0 and v > 0: + if k_scale_keys and v_scale_keys: + k_result, v_result = result[:block_num], result[block_num : 2 * block_num] + k_scale_result, v_scale_result = result[2 * block_num : 3 * block_num], result[3 * block_num :] + success_block_num = 0 + for k, v, k_scale, v_scale in zip(k_result, v_result, k_scale_result, v_scale_result): + if not (k > 0 and v > 0 and k_scale > 0 and v_scale > 0): + break + success_block_num += 1 + else: + k_result, v_result = result[:block_num], result[block_num : 2 * block_num] + success_block_num = 0 + for k, v in zip(k_result, v_result): + if not (k > 0 and v > 0): + break success_block_num += 1 logger.debug(f"_run_read_storage, success_block_num: {success_block_num}") valid_gpu_block_ids = gpu_block_ids[:success_block_num] @@ -577,6 +638,25 @@ class CacheTransferManager: self.device, mode, ) + if k_scale_keys and v_scale_keys: + swap_cache_layout( + self.gpu_cache_scales_k_tensors, + self.storage_key_scale_read_buffer, + self.cache_scale_shape, + valid_gpu_block_ids, + valid_cpu_block_ids, + self.device, + mode, + ) + swap_cache_layout( + self.gpu_cache_scales_v_tensors, + self.storage_value_scale_read_buffer, + self.cache_scale_shape, + valid_gpu_block_ids, + valid_cpu_block_ids, + self.device, + mode, + ) swap_cost_time = time.time() - start_time logger.debug( f"_run_read_storage, swap_cost_time: {swap_cost_time:.6f}s, read_cost_time: {read_cost_time:.6f}s" @@ -607,14 +687,27 @@ class CacheTransferManager: def read_storage_task(self, task: ReadStorageTask): """Read cache from the storage backend to the GPU memory.""" + assert ( + self.storage_backend + ), f"storage_backend not initialized, storage_backend_type: {self.storage_backend_type}" + try: gpu_block_ids = task.gpu_block_ids.copy() cpu_block_ids = [i for i in range(len(gpu_block_ids))] k_cache_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_key" for key in task.keys] v_cache_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_value" for key in task.keys] + if not self.has_cache_scale: + k_scale_keys = None + v_scale_keys = None + else: + k_scale_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_key_scale" for key in task.keys] + v_scale_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_value_scale" for key in task.keys] + match_block_num = 0 if self.storage_backend_type in ("mooncake", "file"): - match_block_num = self.storage_backend.query(k_cache_keys, v_cache_keys) + match_block_num = self.storage_backend.query( + k_cache_keys, v_cache_keys, k_scale_keys, v_scale_keys, task.timeout + ) elif self.storage_backend_type == "attention_store": match_block_num = self.storage_backend.query( task.task_id, task.token_ids, task.start_read_block_idx, task.timeout @@ -623,6 +716,8 @@ class CacheTransferManager: k_cache_keys = k_cache_keys[:match_block_num] v_cache_keys = v_cache_keys[:match_block_num] + k_scale_keys = k_scale_keys[:match_block_num] if k_scale_keys else None + v_scale_keys = v_scale_keys[:match_block_num] if v_scale_keys else None gpu_block_ids = gpu_block_ids[:match_block_num] cpu_block_ids = cpu_block_ids[:match_block_num] valid_gpu_block_ids = [] @@ -635,6 +730,8 @@ class CacheTransferManager: task.start_read_block_idx, k_cache_keys, v_cache_keys, + k_scale_keys, + v_scale_keys, gpu_block_ids, cpu_block_ids, task.timeout, @@ -643,7 +740,9 @@ class CacheTransferManager: f"Successfully read {len(valid_gpu_block_ids)} blocks from cache storage for task {task.task_id}" ) except Exception as e: - logger.error(f"Failed to read cache for task {task.task_id}, error: {e}") + logger.error( + f"Failed to read cache for task {task.task_id}, error: {e}, traceback: {traceback.format_exc()}" + ) valid_gpu_block_ids = [] finally: try: @@ -674,24 +773,20 @@ class CacheTransferManager: start_write_block_idx, k_cache_keys, v_cache_keys, + k_scale_keys, + v_scale_keys, gpu_block_ids, cpu_block_ids, timeout, ): try: if self.storage_backend_type in ("mooncake", "file"): - key_cache_size = [ - self.key_cache_shape[0], - self.key_cache_shape[1], - self.key_cache_shape[2], - self.key_cache_shape[3], - ] mode = 0 # gpu ==> cpu start_time = time.time() swap_cache_layout( self.gpu_cache_k_tensors, self.storage_key_write_buffer, - key_cache_size, + self.key_cache_shape, gpu_block_ids, cpu_block_ids, self.device, @@ -700,27 +795,57 @@ class CacheTransferManager: swap_cache_layout( self.gpu_cache_v_tensors, self.storage_value_write_buffer, - key_cache_size, + self.key_cache_shape, gpu_block_ids, cpu_block_ids, self.device, mode, ) + if k_scale_keys and v_scale_keys: + swap_cache_layout( + self.gpu_cache_scales_k_tensors, + self.storage_key_scale_write_buffer, + self.cache_scale_shape, + gpu_block_ids, + cpu_block_ids, + self.device, + mode, + ) + swap_cache_layout( + self.gpu_cache_scales_v_tensors, + self.storage_value_scale_write_buffer, + self.cache_scale_shape, + gpu_block_ids, + cpu_block_ids, + self.device, + mode, + ) swap_cost_time = time.time() - start_time block_num = len(gpu_block_ids) keys = k_cache_keys + v_cache_keys k_cache_ptrs = [ - self.storage_key_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids + self.storage_key_write_buffer + i * self.cache_buffer_stride_bytes for i in cpu_block_ids ] v_cache_ptrs = [ - self.storage_value_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids + self.storage_value_write_buffer + i * self.cache_buffer_stride_bytes for i in cpu_block_ids ] - kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs - kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value + target_locations = k_cache_ptrs + v_cache_ptrs + target_sizes = [self.cache_buffer_stride_bytes] * block_num * 2 # key and value + if k_scale_keys and v_scale_keys: + keys.extend(k_scale_keys + v_scale_keys) + k_scale_ptrs = [ + self.storage_key_scale_write_buffer + i * self.scale_buffer_stride_bytes for i in cpu_block_ids + ] + v_scale_ptrs = [ + self.storage_value_scale_write_buffer + i * self.scale_buffer_stride_bytes + for i in cpu_block_ids + ] + target_locations.extend(k_scale_ptrs + v_scale_ptrs) + target_sizes.extend([self.scale_buffer_stride_bytes] * block_num * 2) start_time = time.time() - self.storage_backend.batch_set(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes) + self.storage_backend.batch_set(keys=keys, target_locations=target_locations, target_sizes=target_sizes) write_cost_time = time.time() - start_time logger.debug( @@ -753,15 +878,27 @@ class CacheTransferManager: """ Write cache to the storage backend from the GPU memory. """ + assert ( + self.storage_backend + ), f"storage_backend not initialized, storage_backend_type: {self.storage_backend_type}" + try: gpu_block_ids = task.gpu_block_ids.copy() cpu_block_ids = [i for i in range(len(gpu_block_ids))] k_cache_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_key" for key in task.keys] v_cache_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_value" for key in task.keys] + if not self.has_cache_scale: + k_scale_keys = None + v_scale_keys = None + else: + k_scale_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_key_scale" for key in task.keys] + v_scale_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_value_scale" for key in task.keys] match_block_num = 0 if self.storage_backend_type == ("mooncake", "file"): - match_block_num = self.storage_backend.query(k_cache_keys, v_cache_keys, task.timeout) + match_block_num = self.storage_backend.query( + k_cache_keys, v_cache_keys, k_scale_keys, v_scale_keys, task.timeout + ) elif self.storage_backend_type == "attention_store": match_block_num = self.storage_backend.query(task.task_id, task.token_ids, 0, task.timeout) logger.info(f"Matched {match_block_num} blocks in cache storage for write task {task.task_id}") @@ -773,6 +910,8 @@ class CacheTransferManager: try: k_cache_keys = k_cache_keys[match_block_num:] v_cache_keys = v_cache_keys[match_block_num:] + k_scale_keys = k_scale_keys[match_block_num:] if k_scale_keys else None + v_scale_keys = v_scale_keys[match_block_num:] if v_scale_keys else None gpu_block_ids = gpu_block_ids[match_block_num:] cpu_block_ids = cpu_block_ids[match_block_num:] # TODO: support timeout with actual block count @@ -782,6 +921,8 @@ class CacheTransferManager: match_block_num, k_cache_keys, v_cache_keys, + k_scale_keys, + v_scale_keys, gpu_block_ids, cpu_block_ids, task.timeout, @@ -790,7 +931,7 @@ class CacheTransferManager: f"Successfully wrote {write_block_num} blocks to cache storage for task {task.task_id}" ) except Exception as e: - logger.error(f"Error in write back storage task: {e}") + logger.error(f"Error in write back storage task: {e}, traceback:{traceback.format_exc()}") gpu_block_ids = [] finally: try: diff --git a/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py b/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py index ccb378ca94..949f990f38 100644 --- a/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py +++ b/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py @@ -237,19 +237,43 @@ class MooncakeStore(KVCacheStorage): logger.debug(f"The exists fun processes {len(keys)} objects, cost_time: {cost_time:.3f}ms") return result - def query(self, k_keys: List[str], v_keys: List[str], timeout: float = 1.0): + def query( + self, + k_keys: List[str], + v_keys: List[str], + k_scale_keys: List[str] = None, + v_scale_keys: List[str] = None, + timeout: float = 1.0, + ): """ - Given the k_keys and v_keys, get the valid blocks number that + Given the k_keys, v_keys, k_scale_keys and v_scale_keys, get the valid blocks number that can be prefetched from storage backend. """ assert len(k_keys) == len(v_keys), "k_keys and v_keys must have the same length." - result = self.exists(k_keys + v_keys) + + all_keys = k_keys + v_keys + has_scale = k_scale_keys is not None and v_scale_keys is not None + if has_scale: + assert ( + len(k_scale_keys) == len(v_scale_keys) == len(k_keys) == len(v_keys) + ), "k_scale_keys and v_scale_keys must have the same length as k_keys and v_keys." + all_keys.extend(k_scale_keys + v_scale_keys) + + result = self.exists(all_keys) # only consider the case when both key and value exist num = 0 - for k, v in zip(k_keys, v_keys): - if result[k] and result[v]: + if has_scale: + for k, v, k_scale, v_scale in zip(k_keys, v_keys, k_scale_keys, v_scale_keys): + if not (result[k] and result[v] and result[k_scale] and result[v_scale]): + break num += 1 + else: + for k, v in zip(k_keys, v_keys): + if not (result[k] and result[v]): + break + num += 1 + return num def delete(self, key, timeout=5) -> bool: @@ -287,12 +311,12 @@ class MooncakeStore(KVCacheStorage): success_num = result.count(0) if success_num == total_num: logger.debug( - f"Put all data into Mooncake Store successfully." + f"Put all data into Mooncake Store successfully. " f"success_num: {success_num}, cost_time: {cost_time:.6f}s" ) else: logger.error( - f"Some of the data was not put into Mooncake Store." + f"Some of the data was not put into Mooncake Store. " f"total_num: {total_num}, success_num: {success_num}, cost_time: {cost_time:.6f}s" ) if success_num > 0: @@ -322,7 +346,7 @@ class MooncakeStore(KVCacheStorage): ) else: logger.error( - f"Some of the data was not get from Mooncake Store." + f"Some of the data was not get from Mooncake Store. " f"total_num:{total_num}, success_num: {success_num}, cost_time: {cost_time:.6f}s" ) if success_num > 0: