[Feature] [KVCache] support attention_store kv cache backend (#5823)

* [feat] support attention_store kv cache backend

* [fix] fix codestyle

* [chore] optimize log

* [fix] fix write storage task

* [fix] fix read storage

* [fix] fix code conflict after merge develop

* [fix] fix cache bytes and read task token ids

* [chore] add model for cache transfer manager

* [chore] add some log

* [chore] remove launched_cache_manager_signal

* [fix] fix write_back_storage_task match_block_num condition

* [fix] fix swap_cost_time

* [ci] fix ci

* Update fastdeploy/engine/sched/resource_manager_v1.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update fastdeploy/cache_manager/cache_transfer_manager.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Yonghua Li
2026-01-22 21:01:23 +08:00
committed by GitHub
parent 3cd0ffe36c
commit 8d27a523e7
17 changed files with 599 additions and 226 deletions
+1 -1
View File
@@ -32,7 +32,7 @@
| KV缓存 | `fastdeploy:gpu_hit_token_rate` | Gauge | token 级别 GPU 前缀缓存命中率 | 百分比 | | KV缓存 | `fastdeploy:gpu_hit_token_rate` | Gauge | token 级别 GPU 前缀缓存命中率 | 百分比 |
| KV缓存 | `fastdeploy:prefix_cache_token_num` | Counter | 前缀缓存token总数 | 个 | | KV缓存 | `fastdeploy:prefix_cache_token_num` | Counter | 前缀缓存token总数 | 个 |
| KV缓存 | `fastdeploy:prefix_gpu_cache_token_num` | Counter | 位于 GPU 上的前缀缓存 token 总数 | 个 | | KV缓存 | `fastdeploy:prefix_gpu_cache_token_num` | Counter | 位于 GPU 上的前缀缓存 token 总数 | 个 |
| KV缓存 | `fastdeploy:prefix_cpu_cache_token_num` | Counter | 位于 GPU 上的前缀缓存 token 总数 | 个 | | KV缓存 | `fastdeploy:prefix_cpu_cache_token_num` | Counter | 位于 CPU 上的前缀缓存 token 总数 | 个 |
| KV缓存 | `fastdeploy:available_gpu_block_num` | Gauge | 缓存中可用的 GPU 块数量(包含尚未正式释放的前缀缓存块)| 个 | | KV缓存 | `fastdeploy:available_gpu_block_num` | Gauge | 缓存中可用的 GPU 块数量(包含尚未正式释放的前缀缓存块)| 个 |
| KV缓存 | `fastdeploy:free_gpu_block_num` | Gauge | 缓存中的可用块数 | 个 | | KV缓存 | `fastdeploy:free_gpu_block_num` | Gauge | 缓存中的可用块数 | 个 |
| KV缓存 | `fastdeploy:max_gpu_block_num` | Gauge | 服务启动时确定的总块数 | 个 | | KV缓存 | `fastdeploy:max_gpu_block_num` | Gauge | 服务启动时确定的总块数 | 个 |
+4 -1
View File
@@ -1051,7 +1051,10 @@ if __name__ == "__main__":
args = parse_args() args = parse_args()
rank_id = args.rank + args.local_data_parallel_id * args.mp_num rank_id = args.rank + args.local_data_parallel_id * args.mp_num
logger = get_logger("cache_messager", f"cache_messager_tprank{args.rank}.log") if args.mp_num > 1:
logger = get_logger("cache_messager", f"cache_messager_{rank_id}.log")
else:
logger = get_logger("cache_messager", "cache_messager.log")
logger.info("create cache messager...") logger.info("create cache messager...")
logger.info(f"{args}") logger.info(f"{args}")
+37
View File
@@ -0,0 +1,37 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from dataclasses import dataclass
from typing import List
@dataclass(frozen=True, kw_only=True)
class CacheTask:
task_id: str
keys: List[str]
token_ids: List[int]
gpu_block_ids: List[int]
@dataclass(frozen=True, kw_only=True)
class ReadStorageTask(CacheTask):
start_read_block_idx: int
timeout: float = 30.0
@dataclass(frozen=True, kw_only=True)
class WriteStorageTask(CacheTask):
timeout: float = 30.0
+272 -192
View File
@@ -29,6 +29,7 @@ import paddle
from fastdeploy import envs from fastdeploy import envs
from fastdeploy.cache_manager.cache_data import CacheStatus from fastdeploy.cache_manager.cache_data import CacheStatus
from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask
from fastdeploy.cache_manager.ops import ( from fastdeploy.cache_manager.ops import (
cuda_host_alloc, cuda_host_alloc,
cuda_host_free, cuda_host_free,
@@ -40,7 +41,7 @@ from fastdeploy.cache_manager.ops import (
swap_cache_layout, swap_cache_layout,
unset_data_ipc, unset_data_ipc,
) )
from fastdeploy.cache_manager.transfer_factory import MooncakeStore from fastdeploy.cache_manager.transfer_factory import AttentionStore, MooncakeStore
from fastdeploy.config import SpeculativeConfig from fastdeploy.config import SpeculativeConfig
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
from fastdeploy.platforms import current_platform from fastdeploy.platforms import current_platform
@@ -58,6 +59,7 @@ def parse_args():
default="mixed", default="mixed",
help="splitwise role, can be decode, prefill or mixed", help="splitwise role, can be decode, prefill or mixed",
) )
parser.add_argument("--model_id", type=str, default="default", help="model id")
parser.add_argument("--rank", type=int, default=0, help="local tp rank") parser.add_argument("--rank", type=int, default=0, help="local tp rank")
parser.add_argument("--device_id", type=int, default=0, help="device id") parser.add_argument("--device_id", type=int, default=0, help="device id")
parser.add_argument("--max_model_len", type=int, default=32768, help="max model length") parser.add_argument("--max_model_len", type=int, default=32768, help="max model length")
@@ -109,7 +111,7 @@ def parse_args():
"--kvcache_storage_backend", "--kvcache_storage_backend",
type=str, type=str,
default=None, default=None,
choices=["mooncake", "none"], choices=["mooncake", "attention_store", "none"],
help="The storage backend for kvcache storage. If not set, storage backend is disabled.", help="The storage backend for kvcache storage. If not set, storage backend is disabled.",
) )
parser.add_argument( parser.add_argument(
@@ -133,8 +135,6 @@ class CacheTransferManager:
""" """
初始化CacheTransferManager 初始化CacheTransferManager
""" """
device = args.device_id
rank = args.rank
self.gpu_cache_kvs = {} self.gpu_cache_kvs = {}
self.cpu_cache_kvs = {} self.cpu_cache_kvs = {}
self.gpu_cache_k_tensors = [] self.gpu_cache_k_tensors = []
@@ -142,11 +142,31 @@ class CacheTransferManager:
self.gpu_cache_scales_k_tensors = [] self.gpu_cache_scales_k_tensors = []
self.gpu_cache_scales_v_tensors = [] self.gpu_cache_scales_v_tensors = []
self.speculative_config = SpeculativeConfig(args.speculative_config) self.speculative_config = SpeculativeConfig(args.speculative_config)
# parse kv cache shape
self.key_cache_shape = [int(i) for i in args.key_cache_shape.split(",")] self.key_cache_shape = [int(i) for i in args.key_cache_shape.split(",")]
self.value_cache_shape = [] self.value_cache_shape = []
if args.value_cache_shape: if args.value_cache_shape:
self.value_cache_shape = [int(i) for i in args.value_cache_shape.split(",")] self.value_cache_shape = [int(i) for i in args.value_cache_shape.split(",")]
# extract kv cache shape into fields
self.num_gpu_blocks = self.key_cache_shape[0] self.num_gpu_blocks = self.key_cache_shape[0]
self.head_num = self.key_cache_shape[1]
self.block_size = self.key_cache_shape[2]
self.head_dim = self.key_cache_shape[3]
# compute cache bytes
self.cache_dtype = args.cache_dtype
self.cache_bytes = self._get_cache_bytes(self.cache_dtype)
# extract other arg values
self.model_id = args.model_id
self.n_ranks = args.mp_num
self.rank = args.rank
self.device = args.device_id
self.num_layers = args.num_layers
self.ipc_suffix = args.ipc_suffix
self.local_data_parallel_id = args.local_data_parallel_id
self.num_extra_layers = self.speculative_config.num_extra_cache_layer self.num_extra_layers = self.speculative_config.num_extra_cache_layer
self.num_extra_layer_gpu_blocks = int(self.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio) self.num_extra_layer_gpu_blocks = int(self.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
paddle.set_default_dtype(args.default_dtype) paddle.set_default_dtype(args.default_dtype)
@@ -158,18 +178,13 @@ class CacheTransferManager:
self.timeout_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2) self.timeout_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)
self.transfer_task_queue = queue.Queue() # 用来接收传输任务 self.transfer_task_queue = queue.Queue() # 用来接收传输任务
self.tansfer_done_queue = queue.Queue() # 用来告知任务执行完毕 self.tansfer_done_queue = queue.Queue() # 用来告知任务执行完毕
self.n_ranks = args.mp_num
self.rank = rank
self.device = device
self.ipc_suffix = args.ipc_suffix
self.cache_dtype = args.cache_dtype
address = (args.pod_ip, args.cache_queue_port) address = (args.pod_ip, args.cache_queue_port)
self.cache_task_queue = EngineCacheQueue( self.cache_task_queue = EngineCacheQueue(
address=address, address=address,
is_server=False, is_server=False,
num_client=args.mp_num, num_client=args.mp_num,
client_id=rank, client_id=self.rank,
local_data_parallel_id=args.local_data_parallel_id, local_data_parallel_id=args.local_data_parallel_id,
) )
@@ -223,8 +238,22 @@ class CacheTransferManager:
self.storage_backend = MooncakeStore(tp_rank=self.rank) self.storage_backend = MooncakeStore(tp_rank=self.rank)
self._init_storage_buffer(args) self._init_storage_buffer(args)
logger.info("Initialized mooncake store successfully") logger.info("Initialized mooncake store successfully")
elif args.kvcache_storage_backend == "attention_store":
logger.info("Start initialize attention store...")
self.storage_backend = AttentionStore(
namespace=self.model_id,
shard_id=self.rank,
shard_num=self.n_ranks,
layer_num=self.num_layers + self.num_extra_layers,
block_token_size=self.block_size,
bytes_per_shard_layer_per_block=self.head_num * self.block_size * self.head_dim * self.cache_bytes,
device_id=self.device,
dp_id=self.local_data_parallel_id,
)
logger.info("Initialized attention store successfully!")
else: else:
raise NotImplementedError(f"Unsupported storage backend: {args.kvcache_storage_backend}") raise NotImplementedError(f"Unsupported storage backend: {args.kvcache_storage_backend}")
self.storage_backend_type = args.kvcache_storage_backend
if args.write_policy not in ["write_through"]: if args.write_policy not in ["write_through"]:
raise ValueError(f"Invalid write policy: {args.write_policy}") raise ValueError(f"Invalid write policy: {args.write_policy}")
@@ -246,18 +275,16 @@ class CacheTransferManager:
cache layout: layer_num * [block_num, head_num, block_size, head_dim] cache layout: layer_num * [block_num, head_num, block_size, head_dim]
buffer layout: [block_num, layer_num, head_num, block_size, head_dim] buffer layout: [block_num, layer_num, head_num, block_size, head_dim]
""" """
layer_num = args.num_layers + self.num_extra_layers layer_num = self.num_layers + self.num_extra_layers
head_num = self.key_cache_shape[1] block_num = (args.max_model_len + self.block_size - 1) // self.block_size
block_size = self.key_cache_shape[2]
head_dim = self.key_cache_shape[3]
block_num = (args.max_model_len + block_size - 1) // block_size
logger.info( logger.info(
f"Creating cache buffer for storage with shape: " f"Creating cache buffer for storage with shape: "
f"[{block_num}, {layer_num}, {head_num}, {block_size}, {head_dim}]" f"[{block_num}, {layer_num}, {self.head_num}, {self.block_size}, {self.head_dim}]"
) )
self.cache_bytes = self._get_cache_bytes(self.cache_dtype) self.storage_buffer_stride_bytes = (
self.storage_buffer_stride_bytes = layer_num * head_num * block_size * head_dim * self.cache_bytes layer_num * self.head_num * self.block_size * self.head_dim * self.cache_bytes
)
total_bytes = block_num * self.storage_buffer_stride_bytes * 2 # key and value total_bytes = block_num * self.storage_buffer_stride_bytes * 2 # key and value
logger.info(f"Creating cpu buffer cache for alllayers: {total_bytes / 1024 ** 3:.2f}GB") logger.info(f"Creating cpu buffer cache for alllayers: {total_bytes / 1024 ** 3:.2f}GB")
@@ -296,8 +323,8 @@ class CacheTransferManager:
logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.") logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.")
set_device(self.device) set_device(self.device)
for i in range(args.num_layers + self.num_extra_layers): for i in range(self.num_layers + self.num_extra_layers):
num_gpu_blocks = self.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks num_gpu_blocks = self.num_gpu_blocks if i < self.num_layers else self.num_extra_layer_gpu_blocks
key_name = f"key_caches_{i}_rank{self.rank}.device{self.device}" key_name = f"key_caches_{i}_rank{self.rank}.device{self.device}"
val_name = f"value_caches_{i}_rank{self.rank}.device{self.device}" val_name = f"value_caches_{i}_rank{self.rank}.device{self.device}"
key_cache_scales_name = f"key_cache_scales_{i}_rank{self.rank}.device{self.device}" key_cache_scales_name = f"key_cache_scales_{i}_rank{self.rank}.device{self.device}"
@@ -415,7 +442,7 @@ class CacheTransferManager:
self.v_dst_ptrs = [] self.v_dst_ptrs = []
self.k_scales_ptrs = [] self.k_scales_ptrs = []
self.v_scales_ptrs = [] self.v_scales_ptrs = []
for i in range(args.num_layers + self.num_extra_layers): for i in range(self.num_layers + self.num_extra_layers):
key_name = f"key_caches_{i}_rank{self.rank}" key_name = f"key_caches_{i}_rank{self.rank}"
val_name = f"value_caches_{i}_rank{self.rank}" val_name = f"value_caches_{i}_rank{self.rank}"
key_cache_scales_name = f"key_cache_scales_{i}_rank{self.rank}" key_cache_scales_name = f"key_cache_scales_{i}_rank{self.rank}"
@@ -446,228 +473,283 @@ class CacheTransferManager:
raise ValueError(f"Unsupported cache dtype: {cache_dtype}") raise ValueError(f"Unsupported cache dtype: {cache_dtype}")
return cache_bytes return cache_bytes
def _storage_exist_block_num(self, k_keys: List[str], v_keys: List[str]): def _run_read_storage(
self,
task_id: str,
token_ids: List[int],
start_read_block_idx: int,
k_cache_keys: List[str],
v_cache_keys: List[str],
gpu_block_ids: List[int],
cpu_block_ids: List[int],
timeout: float,
):
""" """
Given the k_keys and v_keys, get the valid blocks number that Read storage data from the given blocks to the corresponding cache tensors on the current rank's GPU.
can be prefetched from storage backend.
""" """
assert len(k_keys) == len(v_keys), "k_keys and v_keys must have the same length."
result = self.storage_backend.exists(k_keys + v_keys)
# only consider the case when both key and value exist
num = 0
for k, v in zip(k_keys, v_keys):
if result[k] and result[v]:
num += 1
return num
def _run_read_storage(self, k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids):
try: try:
logger.debug( if self.storage_backend_type == "mooncake":
f"_run_read_storage, key_hash_keys_num: {len(k_cache_keys)}, " block_num = len(gpu_block_ids)
f"value_hash_keys_num: {len(v_cache_keys)}, gpu_block_ids_num: {len(gpu_block_ids)}, " keys = k_cache_keys + v_cache_keys
f"cpu_block_ids_num: {len(cpu_block_ids)}" k_cache_ptrs = [
) self.storage_key_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
]
v_cache_ptrs = [
self.storage_value_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
]
kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value
start_time = time.time()
result = self.storage_backend.batch_get(
keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes
)
read_cost_time = time.time() - start_time
block_num = len(gpu_block_ids) k_result, v_result = result[:block_num], result[block_num:]
keys = k_cache_keys + v_cache_keys success_block_num = 0
k_cache_ptrs = [self.storage_key_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids] for k, v in zip(k_result, v_result):
v_cache_ptrs = [ if k > 0 and v > 0:
self.storage_value_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids success_block_num += 1
] logger.debug(f"_run_read_storage, success_block_num: {success_block_num}")
kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs valid_gpu_block_ids = gpu_block_ids[:success_block_num]
kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value valid_cpu_block_ids = cpu_block_ids[:success_block_num]
start_time = time.time()
result = self.storage_backend.batch_get(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes)
read_cost_time = time.time() - start_time
k_result, v_result = result[:block_num], result[block_num:] mode = 1 # cpu ==> gpu
success_block_num = 0 start_time = time.time()
for k, v in zip(k_result, v_result): swap_cache_layout(
if k > 0 and v > 0: self.gpu_cache_k_tensors,
success_block_num += 1 self.storage_key_read_buffer,
logger.debug(f"_run_read_storage, success_block_num: {success_block_num}") self.key_cache_shape,
valid_gpu_block_ids = gpu_block_ids[:success_block_num] valid_gpu_block_ids,
valid_cpu_block_ids = cpu_block_ids[:success_block_num] valid_cpu_block_ids,
self.device,
mode,
)
swap_cache_layout(
self.gpu_cache_v_tensors,
self.storage_value_read_buffer,
self.value_cache_shape,
valid_gpu_block_ids,
valid_cpu_block_ids,
self.device,
mode,
)
swap_cost_time = time.time() - start_time
logger.debug(
f"_run_read_storage, swap_cost_time: {swap_cost_time:.6f}s, read_cost_time: {read_cost_time:.6f}s"
)
mode = 1 # cpu ==> gpu elif self.storage_backend_type == "attention_store":
start_time = time.time() key_cache = []
swap_cache_layout( val_cache = []
self.gpu_cache_k_tensors, for i in range(self.num_layers + self.num_extra_layers):
self.storage_key_read_buffer, key_cache.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{self.rank}.device{self.device}"])
self.key_cache_shape, val_cache.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{self.rank}.device{self.device}"])
valid_gpu_block_ids,
valid_cpu_block_ids, start_time = time.time()
self.device, read_block_num = self.storage_backend.read(
mode, task_id, key_cache, val_cache, token_ids, gpu_block_ids, start_read_block_idx, timeout
) )
swap_cache_layout( read_cost_time = time.time() - start_time
self.gpu_cache_v_tensors, valid_gpu_block_ids = gpu_block_ids[:read_block_num]
self.storage_value_read_buffer, logger.debug(f"_run_read_storage, read_cost_time: {read_cost_time:.6f}s")
self.value_cache_shape,
valid_gpu_block_ids,
valid_cpu_block_ids,
self.device,
mode,
)
swap_cost_time = time.time() - start_time
logger.debug(
f"_run_read_storage, swap_cost_time: {swap_cost_time:.6f}s, read_cost_time: {read_cost_time:.6f}s"
)
return valid_gpu_block_ids return valid_gpu_block_ids
except Exception as e: except Exception as e:
logger.error( logger.error(
f"[rank {self.rank}/{self.n_ranks}] An error occurred in _run_read_storage: " f"An error occurred in _run_read_storage, " f"error: {e}, traceback:\n{traceback.format_exc()}"
f"error:{e}, {traceback.format_exc()}"
) )
raise raise
def read_storage_task(self, task_id, keys, gpu_block_ids, timeout=0.1): def read_storage_task(self, task: ReadStorageTask):
"""Read cache from the storage backend to the GPU memory.""" """Read cache from the storage backend to the GPU memory."""
try: try:
logger.debug( gpu_block_ids = task.gpu_block_ids.copy()
f"read_storage_task, task id: {task_id}, hash_keys_num: {len(keys)}, " cpu_block_ids = [i for i in range(len(gpu_block_ids))]
f"gpu_block_ids_num: {len(gpu_block_ids)}, timeout: {timeout}" k_cache_keys = [f"{key}_key_{self.rank}" for key in task.keys]
) v_cache_keys = [f"{key}_value_{self.rank}" for key in task.keys]
k_cache_keys = [f"{key}_key_{self.rank}" for key in keys] match_block_num = 0
v_cache_keys = [f"{key}_value_{self.rank}" for key in keys] if self.storage_backend_type == "mooncake":
match_block_num = self._storage_exist_block_num(k_cache_keys, v_cache_keys) match_block_num = self.storage_backend.query(k_cache_keys, v_cache_keys)
logger.debug(f"read_storage_task, match {match_block_num} blocks from storage for task id: {task_id}") elif self.storage_backend_type == "attention_store":
match_block_num = self.storage_backend.query(
task.task_id, task.token_ids, task.start_read_block_idx, task.timeout
)
logger.info(f"Matched {match_block_num} blocks in cache storage for read task {task.task_id}")
k_cache_keys = k_cache_keys[:match_block_num] k_cache_keys = k_cache_keys[:match_block_num]
v_cache_keys = v_cache_keys[:match_block_num] v_cache_keys = v_cache_keys[:match_block_num]
gpu_block_ids = gpu_block_ids[:match_block_num] gpu_block_ids = gpu_block_ids[:match_block_num]
cpu_block_ids = [i for i in range(match_block_num)] cpu_block_ids = cpu_block_ids[:match_block_num]
valid_gpu_block_ids = [] valid_gpu_block_ids = []
if match_block_num > 0: if match_block_num > 0:
# TODO: support timeout with actual block count # TODO: support timeout with actual block count
try: try:
valid_gpu_block_ids = self._run_read_storage( valid_gpu_block_ids = self._run_read_storage(
k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids task.task_id,
task.token_ids[: match_block_num * self.block_size],
task.start_read_block_idx,
k_cache_keys,
v_cache_keys,
gpu_block_ids,
cpu_block_ids,
task.timeout,
) )
logger.info( logger.info(
f"read_storage_task, finish loading {match_block_num} blocks from storage for task {task_id}." f"Successfully read {len(valid_gpu_block_ids)} blocks from cache storage for task {task.task_id}"
) )
except Exception as e: except Exception as e:
logger.error(f"[rank {self.rank}/{self.n_ranks}] An error occurred: {task_id} {e}") logger.error(f"Failed to read cache for task {task.task_id}, error: {e}")
valid_gpu_block_ids = [] valid_gpu_block_ids = []
result = (CacheStatus.STORAGE2GPU, task_id, keys, valid_gpu_block_ids) result = (CacheStatus.STORAGE2GPU, task.task_id, task.keys, valid_gpu_block_ids)
self.cache_task_queue.swap_storage_to_gpu_barrier.wait() self.cache_task_queue.swap_storage_to_gpu_barrier.wait()
self.cache_task_queue.swap_storage_to_gpu_barrier.reset() self.cache_task_queue.swap_storage_to_gpu_barrier.reset()
self.cache_task_queue.put_transfer_done_signal(result) self.cache_task_queue.put_transfer_done_signal(result)
logger.debug(f"read_storage_task: put_transfer_done_signal {result}") logger.debug(f"read_storage_task: put transfer done signal for {task.task_id}")
logger.info(
f"read_storage_task: put_transfer_done_signal for transfer_task_id {task_id}, "
f"valid block num {len(valid_gpu_block_ids)}"
)
except Exception as e: except Exception as e:
logger.error( logger.error(
f"[rank {self.rank}/{self.n_ranks}] An error occurred in read_storage_task: " f"An error occurred in read_storage_task: "
f"task_id: {task_id}, error:{e}, {traceback.format_exc()}" f"task_id: {task.task_id}, error:{e}, {traceback.format_exc()}"
) )
def _run_write_back_storage(self, k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids): def _run_write_back_storage(
self,
task_id,
token_ids,
start_write_block_idx,
k_cache_keys,
v_cache_keys,
gpu_block_ids,
cpu_block_ids,
timeout,
):
try: try:
logger.debug( if self.storage_backend_type == "mooncake":
f"_run_write_back_storage, k_cache_keys: {k_cache_keys}, v_cache_keys: {v_cache_keys}, " key_cache_size = [
f"gpu_block_ids: {gpu_block_ids}" self.key_cache_shape[0],
) self.key_cache_shape[1],
key_cache_size = [ self.key_cache_shape[2],
self.key_cache_shape[0], self.key_cache_shape[3],
self.key_cache_shape[1], ]
self.key_cache_shape[2], mode = 0 # gpu ==> cpu
self.key_cache_shape[3], start_time = time.time()
] swap_cache_layout(
self.gpu_cache_k_tensors,
self.storage_key_write_buffer,
key_cache_size,
gpu_block_ids,
cpu_block_ids,
self.device,
mode,
)
swap_cache_layout(
self.gpu_cache_v_tensors,
self.storage_value_write_buffer,
key_cache_size,
gpu_block_ids,
cpu_block_ids,
self.device,
mode,
)
swap_cost_time = time.time() - start_time
mode = 0 # gpu ==> cpu block_num = len(gpu_block_ids)
start_time = time.time() keys = k_cache_keys + v_cache_keys
swap_cache_layout( k_cache_ptrs = [
self.gpu_cache_k_tensors, self.storage_key_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
self.storage_key_write_buffer, ]
key_cache_size, v_cache_ptrs = [
gpu_block_ids, self.storage_value_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
cpu_block_ids, ]
self.device, kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
mode, kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value
)
swap_cache_layout(
self.gpu_cache_v_tensors,
self.storage_value_write_buffer,
key_cache_size,
gpu_block_ids,
cpu_block_ids,
self.device,
mode,
)
swap_cost_time = time.time() - start_time
block_num = len(gpu_block_ids) start_time = time.time()
keys = k_cache_keys + v_cache_keys self.storage_backend.batch_set(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes)
k_cache_ptrs = [ write_cost_time = time.time() - start_time
self.storage_key_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
] logger.debug(
v_cache_ptrs = [ f"_run_write_back_storage, swap_cost_time: {swap_cost_time:.6f}s, write_cost_time: {write_cost_time:.6f}s"
self.storage_value_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids )
] return block_num
kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value elif self.storage_backend_type == "attention_store":
start_time = time.time() key_cache = []
self.storage_backend.batch_set(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes) val_cache = []
write_cost_time = time.time() - start_time for i in range(self.num_layers + self.num_extra_layers):
key_cache.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{self.rank}.device{self.device}"])
val_cache.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{self.rank}.device{self.device}"])
start_time = time.time()
write_block_num = self.storage_backend.write(
task_id, key_cache, val_cache, token_ids, gpu_block_ids, start_write_block_idx, timeout
)
write_cost_time = time.time() - start_time
logger.debug(f"_run_write_back_storage, write_cost_time: {write_cost_time:.6f}s")
return write_block_num
logger.debug(
f"_run_write_back_storage, swap_cost_time: {swap_cost_time:.6f}s, write_cost_time: {write_cost_time:.6f}s"
)
except Exception as e: except Exception as e:
logger.error( logger.error(
f"[rank {self.rank}/{self.n_ranks}] An error occurred in _run_write_back_storage: " f"An error occurred in _run_write_back_storage, " f"error: {e}, traceback:\n{traceback.format_exc()}"
f"error:{e}, {traceback.format_exc()}"
) )
return 0
def write_back_storage_task(self, task_id, keys, gpu_block_ids, timeout=0.1): def write_back_storage_task(self, task: WriteStorageTask):
""" """
Write cache to the storage backend from the GPU memory. Write cache to the storage backend from the GPU memory.
""" """
try: try:
logger.debug( gpu_block_ids = task.gpu_block_ids.copy()
f"write cache to storage, keys: {keys}, gpu_block_ids: {gpu_block_ids}, "
f"task_id: {task_id}, timeout: {timeout}"
)
k_cache_keys = [f"{key}_key_{self.rank}" for key in keys]
v_cache_keys = [f"{key}_value_{self.rank}" for key in keys]
match_block_num = self._storage_exist_block_num(k_cache_keys, v_cache_keys)
k_cache_keys = k_cache_keys[match_block_num:]
v_cache_keys = v_cache_keys[match_block_num:]
gpu_block_ids = gpu_block_ids[match_block_num:]
cpu_block_ids = [i for i in range(len(gpu_block_ids))] cpu_block_ids = [i for i in range(len(gpu_block_ids))]
k_cache_keys = [f"{key}_key_{self.rank}" for key in task.keys]
v_cache_keys = [f"{key}_value_{self.rank}" for key in task.keys]
if len(k_cache_keys) == 0: match_block_num = 0
logger.info(f"No uncached keys found for task {task_id}") if self.storage_backend_type == "mooncake":
match_block_num = self.storage_backend.query(k_cache_keys, v_cache_keys, task.timeout)
elif self.storage_backend_type == "attention_store":
match_block_num = self.storage_backend.query(task.task_id, task.token_ids, 0, task.timeout)
logger.info(f"Matched {match_block_num} blocks in cache storage for write task {task.task_id}")
if match_block_num >= len(k_cache_keys):
logger.info(f"No uncached keys found for task {task.task_id}")
gpu_block_ids = [] gpu_block_ids = []
else: else:
try: try:
k_cache_keys = k_cache_keys[match_block_num:]
v_cache_keys = v_cache_keys[match_block_num:]
gpu_block_ids = gpu_block_ids[match_block_num:]
cpu_block_ids = cpu_block_ids[match_block_num:]
# TODO: support timeout with actual block count # TODO: support timeout with actual block count
self._run_write_back_storage(k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids) write_block_num = self._run_write_back_storage(
task.task_id,
task.token_ids,
match_block_num,
k_cache_keys,
v_cache_keys,
gpu_block_ids,
cpu_block_ids,
task.timeout,
)
logger.info(
f"Successfully wrote {write_block_num} blocks to cache storage for task {task.task_id}"
)
except Exception as e: except Exception as e:
logger.error(f"Error in write back storage task: {e}") logger.error(f"Error in write back storage task: {e}")
gpu_block_ids = [] gpu_block_ids = []
result = (CacheStatus.GPU2STORAGE, task_id, keys, gpu_block_ids) result = (CacheStatus.GPU2STORAGE, task.task_id, task.keys, gpu_block_ids)
self.cache_task_queue.swap_to_storage_barrier.wait() self.cache_task_queue.swap_to_storage_barrier.wait()
if self.rank == 0: # 只有当rank为0时执行同步操作 if self.rank == 0: # 只有当rank为0时执行同步操作
self.cache_task_queue.swap_to_storage_barrier.reset() self.cache_task_queue.swap_to_storage_barrier.reset()
self.cache_task_queue.put_transfer_done_signal(result) # 发送传输完成信号 self.cache_task_queue.put_transfer_done_signal(result) # 发送传输完成信号
logger.debug(f"write_back_storage_task: put_transfer_done_signal {result}") logger.debug(f"write_back_storage_task: put_transfer_done_signal {result}")
logger.info(f"write_back_storage_task: put_transfer_done_signal for transfer_task_id {task_id}")
except Exception as e: except Exception as e:
logger.error( logger.error(
f"[rank {self.rank}/{self.n_ranks}] An error occurred in write_back_storage_task: " f"An error occurred in write_back_storage_task, " f"error: {e}, traceback:\n{traceback.format_exc()}"
f"error:{e}, {traceback.format_exc()}"
) )
def _do_swap_to_cpu_task( def _do_swap_to_cpu_task(
@@ -759,12 +841,12 @@ class CacheTransferManager:
self.cache_task_queue.barrier1.reset() self.cache_task_queue.barrier1.reset()
if self.cache_task_broadcast_signal.value[0] == 1: if self.cache_task_broadcast_signal.value[0] == 1:
data, read_finish = self.cache_task_queue.get_transfer_task() data, read_finish = self.cache_task_queue.get_transfer_task()
logger.debug(f"transfer data: get_transfer_task {data}") logger.debug(f"do_data_transfer: {data}")
if read_finish: if read_finish:
self.cache_task_broadcast_signal.value[0] = 0 self.cache_task_broadcast_signal.value[0] = 0
event_type, transfer_task_id = data[0], data[1] event_type, event_args = data[0], data[1:]
if event_type.value == CacheStatus.SWAP2CPU.value: if event_type.value == CacheStatus.SWAP2CPU.value:
swap_node_ids, gpu_block_id, cpu_block_id = data[2:] transfer_task_id, swap_node_ids, gpu_block_id, cpu_block_id = event_args
self.swap_to_cpu_thread_pool.submit( self.swap_to_cpu_thread_pool.submit(
self._do_swap_to_cpu_task, self._do_swap_to_cpu_task,
swap_node_ids, swap_node_ids,
@@ -774,7 +856,7 @@ class CacheTransferManager:
transfer_task_id, transfer_task_id,
) )
elif event_type.value == CacheStatus.SWAP2GPU.value: elif event_type.value == CacheStatus.SWAP2GPU.value:
swap_node_ids, gpu_block_id, cpu_block_id = data[2:] transfer_task_id, swap_node_ids, gpu_block_id, cpu_block_id = event_args
self.swap_to_gpu_thread_pool.submit( self.swap_to_gpu_thread_pool.submit(
self._do_swap_to_gpu_task, self._do_swap_to_gpu_task,
swap_node_ids, swap_node_ids,
@@ -784,22 +866,16 @@ class CacheTransferManager:
transfer_task_id, transfer_task_id,
) )
elif event_type.value == CacheStatus.STORAGE2GPU.value: elif event_type.value == CacheStatus.STORAGE2GPU.value:
hash_keys, gpu_block_ids, timeout = data[2:] read_storage_task = event_args[0]
self.read_storage_thread_pool.submit( self.read_storage_thread_pool.submit(
self.read_storage_task, self.read_storage_task,
transfer_task_id, read_storage_task,
hash_keys,
gpu_block_ids,
timeout,
) )
elif event_type.value == CacheStatus.GPU2STORAGE.value: elif event_type.value == CacheStatus.GPU2STORAGE.value:
hash_keys, gpu_block_ids, timeout = data[2:] write_storage_task = event_args[0]
self.write_back_storage_thread_pool.submit( self.write_back_storage_thread_pool.submit(
self.write_back_storage_task, self.write_back_storage_task,
transfer_task_id, write_storage_task,
hash_keys,
gpu_block_ids,
timeout,
) )
else: else:
if self.n_ranks > 1: if self.n_ranks > 1:
@@ -1047,7 +1123,11 @@ if __name__ == "__main__":
args = parse_args() args = parse_args()
rank_id = args.rank + args.local_data_parallel_id * args.mp_num rank_id = args.rank + args.local_data_parallel_id * args.mp_num
logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_tprank{args.rank}.log") if args.mp_num > 1:
logger = get_logger("cache_transfer", f"cache_transfer_{rank_id}.log")
else:
logger = get_logger("cache_transfer", "cache_transfer.log")
logger.info(f"args: {vars(args)}") logger.info(f"args: {vars(args)}")
set_device(args.device_id) set_device(args.device_id)
try: try:
@@ -31,7 +31,9 @@ import numpy as np
from fastdeploy import envs from fastdeploy import envs
from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
from fastdeploy.cache_manager.cache_metrics import CacheMetrics from fastdeploy.cache_manager.cache_metrics import CacheMetrics
from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask
from fastdeploy.cache_manager.ops import get_all_visible_devices from fastdeploy.cache_manager.ops import get_all_visible_devices
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request from fastdeploy.engine.request import Request
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus
from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.metrics.metrics import main_process_metrics
@@ -47,7 +49,7 @@ class PrefixCacheManager:
def __init__( def __init__(
self, self,
config, config: FDConfig,
tensor_parallel_size, tensor_parallel_size,
splitwise_role="mixed", splitwise_role="mixed",
local_data_parallel_id=0, local_data_parallel_id=0,
@@ -207,7 +209,6 @@ class PrefixCacheManager:
key_cache_shape, val_cache_shape = self._get_kv_cache_shape(cache_config.total_block_num) key_cache_shape, val_cache_shape = self._get_kv_cache_shape(cache_config.total_block_num)
key_cache_shape = ",".join([str(i) for i in key_cache_shape]) key_cache_shape = ",".join([str(i) for i in key_cache_shape])
val_cache_shape = ",".join([str(i) for i in val_cache_shape]) val_cache_shape = ",".join([str(i) for i in val_cache_shape])
logger.info(f"key_cache_shape {key_cache_shape} value_cache_shape {val_cache_shape}")
if self.enable_splitwise: if self.enable_splitwise:
cache_messager_processes = self.launch_cache_messager( cache_messager_processes = self.launch_cache_messager(
cache_config, cache_config,
@@ -273,6 +274,7 @@ class PrefixCacheManager:
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0" + " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
+ f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}" + f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}"
+ f" {sys.executable} {py_path}" + f" {sys.executable} {py_path}"
+ f" --model_id {os.path.basename(self.config.model_config.model)}"
+ f" --device_id {int(device_ids[i])}" + f" --device_id {int(device_ids[i])}"
+ f" --rank {i}" + f" --rank {i}"
+ f" --splitwise_role {self.splitwise_role}" + f" --splitwise_role {self.splitwise_role}"
@@ -390,7 +392,7 @@ class PrefixCacheManager:
+ f" --ipc_suffix {ipc_suffix}" + f" --ipc_suffix {ipc_suffix}"
+ f" --rdma_port {cache_config.local_rdma_comm_ports[i] if cache_config.local_rdma_comm_ports is not None else '0'}" + f" --rdma_port {cache_config.local_rdma_comm_ports[i] if cache_config.local_rdma_comm_ports is not None else '0'}"
+ f" --speculative_config '{self.speculative_config.to_json_string()}'" + f" --speculative_config '{self.speculative_config.to_json_string()}'"
+ f" >{log_dir}/launch_cache_messager_tprank{i}.log 2>&1" + f" >{log_dir}/launch_cache_messager_{i}.log 2>&1"
) )
logger.info(f"Launch cache messager, command:{launch_cmd}") logger.info(f"Launch cache messager, command:{launch_cmd}")
cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid)) cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
@@ -789,9 +791,15 @@ class PrefixCacheManager:
f"start prefetch cache from storage, req_id: {req_id}, block num: {len(no_match_block_keys)}" f"start prefetch cache from storage, req_id: {req_id}, block num: {len(no_match_block_keys)}"
) )
start_time = time.time() start_time = time.time()
storage_matched_block_ids = self.issue_prefetch_storage_task( read_storage_task = ReadStorageTask(
req_id, no_match_block_keys, gpu_recv_storage_block_ids task_id=req_id,
keys=no_match_block_keys,
token_ids=input_token_ids,
gpu_block_ids=gpu_recv_storage_block_ids,
start_read_block_idx=match_token_num // block_size,
) )
logger.debug(f"issue read storage task: {read_storage_task}")
storage_matched_block_ids = self.issue_prefetch_storage_task(read_storage_task)
storage_matched_block_num = len(storage_matched_block_ids) storage_matched_block_num = len(storage_matched_block_ids)
storage_match_token_num = storage_matched_block_num * block_size storage_match_token_num = storage_matched_block_num * block_size
cost_time = time.time() - start_time cost_time = time.time() - start_time
@@ -1006,6 +1014,12 @@ class PrefixCacheManager:
if self.kvcache_storage_backend is None: if self.kvcache_storage_backend is None:
return return
token_ids = request.prompt_token_ids
if isinstance(token_ids, np.ndarray):
token_ids = token_ids.tolist()
if self.config.cache_config.enable_output_caching:
token_ids += request.output_token_ids
req_id = request.request_id req_id = request.request_id
keys = [] keys = []
node = self.req_leaf_map[req_id] node = self.req_leaf_map[req_id]
@@ -1018,24 +1032,33 @@ class PrefixCacheManager:
gpu_block_ids = request.block_tables[: len(keys)] gpu_block_ids = request.block_tables[: len(keys)]
logger.info(f"start write cache back to storage, req_id: {req_id}, block num: {len(keys)}") logger.info(f"start write cache back to storage, req_id: {req_id}, block num: {len(keys)}")
write_storage_task = WriteStorageTask(
task_id=req_id,
keys=keys,
token_ids=token_ids,
gpu_block_ids=gpu_block_ids,
)
logger.debug(f"issue write storage task: {write_storage_task}")
tic = time.time() tic = time.time()
self.issue_write_back_storage_task(req_id=req_id, hash_keys=keys, gpu_block_ids=gpu_block_ids, is_sync=True) self.issue_write_back_storage_task(write_storage_task, is_sync=True)
cost_time = time.time() - tic cost_time = time.time() - tic
logger.info(f"finish write cache back to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s") logger.info(f"finish write cache back to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s")
def issue_write_back_storage_task(self, req_id, hash_keys, gpu_block_ids, is_sync=True, timeout=0.5): def issue_write_back_storage_task(self, task: WriteStorageTask, is_sync=True):
if self.kvcache_storage_backend is None: if self.kvcache_storage_backend is None:
return return
if len(hash_keys) != len(gpu_block_ids): if len(task.keys) != len(task.gpu_block_ids):
err_msg = f"write_back_storage error: hash_keys({len(hash_keys)}) != gpu_block_ids({len(gpu_block_ids)})" err_msg = (
f"write_back_storage error: hash_keys({len(task.keys)}) != gpu_block_ids({len(task.gpu_block_ids)})"
)
logger.error(err_msg) logger.error(err_msg)
raise ValueError(err_msg) raise ValueError(err_msg)
self.task_write_back_event[req_id] = Event() self.task_write_back_event[task.task_id] = Event()
self.cache_task_queue.put_transfer_task((CacheStatus.GPU2STORAGE, req_id, hash_keys, gpu_block_ids, timeout)) self.cache_task_queue.put_transfer_task((CacheStatus.GPU2STORAGE, task))
if is_sync: if is_sync:
self.wait_write_storage_task(req_id) self.wait_write_storage_task(task.task_id)
def wait_write_storage_task(self, req_id): def wait_write_storage_task(self, req_id):
""" """
@@ -1045,16 +1068,19 @@ class PrefixCacheManager:
self.task_write_back_event[req_id].wait() self.task_write_back_event[req_id].wait()
del self.task_write_back_event[req_id] del self.task_write_back_event[req_id]
def issue_prefetch_storage_task(self, req_id, hash_keys, gpu_block_ids, is_sync=True, timeout=0.5): def issue_prefetch_storage_task(self, task: ReadStorageTask, is_sync=True):
""" """
Prefetch cache from storage task Prefetch cache from storage task
""" """
if self.kvcache_storage_backend is None:
return []
storage_block_ids = [] storage_block_ids = []
self.task_prefetch_event[req_id] = Event() self.task_prefetch_event[task.task_id] = Event()
# issue task to cache_transfer_manager # issue task to cache_transfer_manager
self.cache_task_queue.put_transfer_task((CacheStatus.STORAGE2GPU, req_id, hash_keys, gpu_block_ids, timeout)) self.cache_task_queue.put_transfer_task((CacheStatus.STORAGE2GPU, task))
if is_sync: if is_sync:
storage_block_ids = self.wait_prefetch_storage_task(req_id) storage_block_ids = self.wait_prefetch_storage_task(task.task_id)
return storage_block_ids return storage_block_ids
def wait_prefetch_storage_task(self, req_id): def wait_prefetch_storage_task(self, req_id):
@@ -17,7 +17,7 @@
from fastdeploy.platforms import current_platform from fastdeploy.platforms import current_platform
from .kvcache_storage import KVCacheStorage from .kvcache_storage import KVCacheStorage
from .mooncake_store import MooncakeStore from .mooncake_store import AttentionStore, MooncakeStore
from .rdma_cache_transfer import RDMACommManager from .rdma_cache_transfer import RDMACommManager
if current_platform.is_cuda(): if current_platform.is_cuda():
@@ -31,4 +31,5 @@ __all__ = [
"RDMACommManager", "RDMACommManager",
"KVCacheStorage", "KVCacheStorage",
"MooncakeStore", "MooncakeStore",
"AttentionStore",
] ]
@@ -95,3 +95,10 @@ class KVCacheStorage(ABC):
Clear all keys in storage Clear all keys in storage
""" """
pass pass
@abstractmethod
def query(self) -> int:
"""
Query the number of blocks stored in the storage.
"""
pass
@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
""" """
from .attention_store import AttentionStore
from .mooncake_store import MooncakeStore from .mooncake_store import MooncakeStore
__all__ = ["MooncakeStore"] __all__ = ["MooncakeStore", "AttentionStore"]
@@ -0,0 +1,209 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import time
import traceback
from dataclasses import dataclass
from typing import List
import paddle
from fastdeploy.cache_manager.transfer_factory.kvcache_storage import (
KVCacheStorage,
logger,
)
try:
from attentionstore_sdk.sdk import AttentionStoreSDK, Tokens
from attentionstore_sdk.utils.err import AttentionStoreSDKError
_ATTENTIONSTORE_AVAILABLE = True
except Exception:
AttentionStoreSDK = None
Tokens = None
AttentionStoreSDKError = None
_ATTENTIONSTORE_AVAILABLE = False
@dataclass
class AttentionStoreConfig:
namespace: str = "default_ns"
pod_name: str = "default_pod"
model_version: str = "v0"
shard_id: int = 0
shard_num: int = 1
layer_num: int = 1
block_token_size: int = 64
bytes_per_shard_layer_per_block: int = 1024
device_id: int = 0
dp_id: int = 0
class AttentionStore(KVCacheStorage):
def __init__(self, **args):
if not _ATTENTIONSTORE_AVAILABLE:
raise ImportError("Please install attentionstore_sdk to run Fastdeploy with attentionstore_sdk.")
self.config = AttentionStoreConfig(**args)
try:
logger.info(f"[INIT] Start initializing AttentionStoreSDK with config: {self.config}")
self.sdk = AttentionStoreSDK(
self.config.namespace,
self.config.pod_name,
self.config.model_version,
self.config.shard_id,
self.config.shard_num,
self.config.layer_num,
self.config.block_token_size,
self.config.bytes_per_shard_layer_per_block,
self.config.device_id,
self.config.dp_id,
)
self.wait_for_sdk_ready(timeout=300, delta_t=5)
logger.info("[INIT] ✅ AttentionStore is initialized successfully!")
except Exception as e:
logger.error(
f"[INIT] ❌ AttentionStore initialization failed, error: {e}, traceback:\n{traceback.format_exc()}"
)
def wait_for_sdk_ready(self, timeout: float, delta_t: float):
t = 0
while t < timeout:
try:
tokens = Tokens(list(range(self.config.block_token_size + 1)), self.config.block_token_size)
self.sdk.match(tokens, 0, delta_t)
return
except AttentionStoreSDKError as e:
if "cuda memory not ready" in str(e):
logger.debug("[INIT] cuda memory not ready, try again..")
time.sleep(delta_t)
t += delta_t
continue
else:
raise RuntimeError(
f"Unexpected exception during AttentionStoreSDK initialization: {e}\n{traceback.format_exc()}"
)
raise TimeoutError(f"AttentionStoreSDK initialization timed out after {timeout} seconds")
def read(
self,
task_id: str,
key_cache: List[paddle.Tensor],
val_cache: List[paddle.Tensor],
token_ids: List[int],
gpu_block_ids: List[int],
start_read_block_idx: int,
timeout: float = 30.0,
):
logger.debug(
f"[READ BEGIN] task_id: {task_id} token_ids: {token_ids} gpu_block_ids: {gpu_block_ids} start_read_block_idx: {start_read_block_idx} timeout: {timeout}"
)
tokens = Tokens(token_ids, self.config.block_token_size)
k_data_ptrs = [k.data_ptr() for k in key_cache]
v_data_ptrs = [v.data_ptr() for v in val_cache]
num = 0
try:
num = self.sdk.read(
list(range(self.config.layer_num)),
tokens,
start_read_block_idx,
k_data_ptrs,
v_data_ptrs,
gpu_block_ids,
timeout,
)
logger.debug(f"[READ END] task_id: {task_id} read_blocks: {num}")
except AttentionStoreSDKError:
logger.error(
f"[READ ERROR] failed to execute sdk read, task_id: {task_id}, traceback:\n{traceback.format_exc()}"
)
return num
def write(
self,
task_id: str,
key_cache: List[paddle.Tensor],
val_cache: List[paddle.Tensor],
token_ids: List[int],
gpu_block_ids: List[int],
start_write_block_idx: int,
timeout: float = 30.0,
) -> int:
logger.debug(
f"[WRITE BEGIN] task_id: {task_id} token_ids: {token_ids} gpu_block_ids: {gpu_block_ids} start_write_block_idx: {start_write_block_idx} timeout: {timeout}"
)
tokens = Tokens(token_ids, self.config.block_token_size)
k_data_ptrs = [k.data_ptr() for k in key_cache]
v_data_ptrs = [v.data_ptr() for v in val_cache]
num = 0
try:
num = self.sdk.write(
list(range(self.config.layer_num)),
tokens,
start_write_block_idx,
k_data_ptrs,
v_data_ptrs,
gpu_block_ids,
timeout,
)
logger.debug(f"[WRITE END] task_id: {task_id} written_blocks: {num}")
except AttentionStoreSDKError:
logger.error(
f"[WRITE ERROR] failed to execute sdk write, task_id: {task_id}, traceback:\n{traceback.format_exc()}"
)
return num
def query(self, task_id: str, token_ids: List[int], start_match_block_idx: int, timeout: float = 10.0):
"""
Given the input ids and starting index to match, get the valid blocks number that
can be prefetched from storage backend.
"""
logger.debug(
f"[QUERY BEGIN] task_id: {task_id} token_ids: {token_ids} start_match_block_idx: {start_match_block_idx} timeout: {timeout}"
)
tokens = Tokens(token_ids, self.config.block_token_size)
num = 0
try:
num = self.sdk.match(tokens, start_match_block_idx, timeout)
logger.debug(f"[QUERY END] task_id: {task_id} matched_blocks: {num}")
except AttentionStoreSDKError:
logger.error(
f"[QUERY ERROR] Failed to execute sdk match, task_id: {task_id}, traceback:\n{traceback.format_exc()}"
)
return num
def get(self, **kwargs):
raise NotImplementedError("AttentionStore does not support this method")
def batch_get(self, **kwargs):
raise NotImplementedError("AttentionStore does not support this method")
def set(self, **kwargs) -> bool:
raise NotImplementedError("AttentionStore does not support this method")
def batch_set(self, **kwargs) -> bool:
raise NotImplementedError("AttentionStore does not support this method")
def exists(self, keys: List[str]) -> bool:
raise NotImplementedError("AttentionStore does not support this method")
def clear(self) -> bool:
raise NotImplementedError("AttentionStore does not support this method")
def register_buffer(self, buffer_ptr, buffer_size, buffer_type="none_type") -> None:
raise NotImplementedError("AttentionStore does not support this method")
@@ -237,6 +237,21 @@ class MooncakeStore(KVCacheStorage):
logger.debug(f"The exists fun processes {len(keys)} objects, cost_time: {cost_time:.3f}ms") logger.debug(f"The exists fun processes {len(keys)} objects, cost_time: {cost_time:.3f}ms")
return result return result
def query(self, k_keys: List[str], v_keys: List[str], timeout: float = 1.0):
"""
Given the k_keys and v_keys, get the valid blocks number that
can be prefetched from storage backend.
"""
assert len(k_keys) == len(v_keys), "k_keys and v_keys must have the same length."
result = self.exists(k_keys + v_keys)
# only consider the case when both key and value exist
num = 0
for k, v in zip(k_keys, v_keys):
if result[k] and result[v]:
num += 1
return num
def delete(self, key, timeout=5) -> bool: def delete(self, key, timeout=5) -> bool:
while timeout: while timeout:
result = self.store.remove(key) result = self.store.remove(key)
+1 -2
View File
@@ -629,7 +629,6 @@ class EngineArgs:
for port in cur_dp_ports: for port in cur_dp_ports:
assert is_port_available("0.0.0.0", port), f"Parameter `{name}`:{port} is already in use." assert is_port_available("0.0.0.0", port), f"Parameter `{name}`:{port} is already in use."
console_logger.debug(f"post init {name}: {ports}")
return ports return ports
num_nodes = len(self.ips) if self.ips else 1 num_nodes = len(self.ips) if self.ips else 1
@@ -1077,7 +1076,7 @@ class EngineArgs:
cache_group.add_argument( cache_group.add_argument(
"--kvcache-storage-backend", "--kvcache-storage-backend",
type=nullable_str, type=nullable_str,
choices=["mooncake"], choices=["mooncake", "attention_store"],
default=EngineArgs.kvcache_storage_backend, default=EngineArgs.kvcache_storage_backend,
help="The storage backend for kvcache storage. Leave empty to disable.", help="The storage backend for kvcache storage. Leave empty to disable.",
) )
-4
View File
@@ -225,10 +225,6 @@ class EngineService:
device_ids = self.cfg.parallel_config.device_ids.split(",") device_ids = self.cfg.parallel_config.device_ids.split(",")
self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix) self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix)
# Set cache manager signal
if self.cfg.scheduler_config.splitwise_role != "mixed":
self.launched_cache_manager_signal.value[0] = 1
# Worker launched # Worker launched
self.check_worker_initialize_status_func_thread.join() self.check_worker_initialize_status_func_thread.join()
if not result_container["worker_is_alive"]: if not result_container["worker_is_alive"]:
-4
View File
@@ -181,10 +181,6 @@ class LLMEngine:
device_ids = self.cfg.parallel_config.device_ids.split(",") device_ids = self.cfg.parallel_config.device_ids.split(",")
self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix) self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
# Launch components: scheduler, cache_manager, expert_service et.al.
if self.cfg.scheduler_config.splitwise_role != "mixed":
self.launched_cache_manager_signal.value[0] = 1
if self.cfg.scheduler_config.splitwise_role != "mixed" and envs.FD_ENABLE_INTERNAL_ADAPTER: if self.cfg.scheduler_config.splitwise_role != "mixed" and envs.FD_ENABLE_INTERNAL_ADAPTER:
envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT = envs.FD_ZMQ_RECV_REQUEST_SERVER_PORTS.split(",")[0] envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT = envs.FD_ZMQ_RECV_REQUEST_SERVER_PORTS.split(",")[0]
envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT = envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORTS.split(",")[0] envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT = envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORTS.split(",")[0]
@@ -1338,8 +1338,11 @@ class ResourceManagerV1(ResourceManager):
def update_metrics(self): def update_metrics(self):
# Update metrics # Update metrics
num_tasks = sum([1 if task else 0 for task in self.tasks_list]) num_tasks = sum([1 if task else 0 for task in self.tasks_list])
num_blocks_used_by_tasks = sum([len(task.block_tables) if task else 0 for task in self.tasks_list]) blocks_used_by_tasks = set()
main_process_metrics.available_gpu_block_num.set(self.total_block_number() - num_blocks_used_by_tasks) for task in self.tasks_list:
if task is not None:
blocks_used_by_tasks.update(task.block_tables)
main_process_metrics.available_gpu_block_num.set(self.total_block_number() - len(blocks_used_by_tasks))
main_process_metrics.batch_size.set(self.max_num_seqs - self.available_batch()) main_process_metrics.batch_size.set(self.max_num_seqs - self.available_batch())
main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc()) main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
main_process_metrics.num_requests_running.set(len(self.running)) main_process_metrics.num_requests_running.set(len(self.running))
@@ -29,6 +29,7 @@ class Args:
mp_num = 1 mp_num = 1
device_id = 0 device_id = 0
speculative_config = {} speculative_config = {}
model_id = "test_model"
ipc_suffix = "test_ipc_suffix" ipc_suffix = "test_ipc_suffix"
cache_queue_port = 9999 cache_queue_port = 9999
pod_ip = "127.0.0.1" pod_ip = "127.0.0.1"
@@ -185,6 +185,7 @@ def _create_manager(
swap_space=4, swap_space=4,
) )
model_config = SimpleNamespace( model_config = SimpleNamespace(
model="test_model",
num_attention_heads=1, num_attention_heads=1,
num_key_value_heads=1, num_key_value_heads=1,
head_dim=1, head_dim=1,
-2
View File
@@ -332,8 +332,6 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
self.assertFalse(ok) self.assertFalse(ok)
# cache manager started before workers (lines 184-185) # cache manager started before workers (lines 184-185)
self.assertTrue(started_cache.get("called", False)) self.assertTrue(started_cache.get("called", False))
# launched_cache_manager_signal set (line 221)
self.assertEqual(int(eng.launched_cache_manager_signal.value[0]), 1)
# avoid atexit finalizer # avoid atexit finalizer
if hasattr(eng, "_finalizer"): if hasattr(eng, "_finalizer"):
try: try: