Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2026-04-23 00:17:25 +08:00
[Feature] [KVCache] support attention_store kv cache backend (#5823)
* [feat] support attention_store kv cache backend
* [fix] fix codestyle
* [chore] optimize log
* [fix] fix write storage task
* [fix] fix read storage
* [fix] fix code conflict after merge develop
* [fix] fix cache bytes and read task token ids
* [chore] add model for cache transfer manager
* [chore] add some log
* [chore] remove launched_cache_manager_signal
* [fix] fix write_back_storage_task match_block_num condition
* [fix] fix swap_cost_time
* [ci] fix ci
* Update fastdeploy/engine/sched/resource_manager_v1.py
* Update fastdeploy/cache_manager/cache_transfer_manager.py
* Update fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@@ -32,7 +32,7 @@
 | KV cache | `fastdeploy:gpu_hit_token_rate` | Gauge | Token-level GPU prefix-cache hit rate | percent |
 | KV cache | `fastdeploy:prefix_cache_token_num` | Counter | Total number of prefix-cached tokens | count |
 | KV cache | `fastdeploy:prefix_gpu_cache_token_num` | Counter | Total number of prefix-cached tokens on GPU | count |
-| KV cache | `fastdeploy:prefix_cpu_cache_token_num` | Counter | Total number of prefix-cached tokens on GPU | count |
+| KV cache | `fastdeploy:prefix_cpu_cache_token_num` | Counter | Total number of prefix-cached tokens on CPU | count |
 | KV cache | `fastdeploy:available_gpu_block_num` | Gauge | Number of available GPU blocks in the cache (including prefix-cache blocks not yet formally released) | count |
 | KV cache | `fastdeploy:free_gpu_block_num` | Gauge | Number of free blocks in the cache | count |
 | KV cache | `fastdeploy:max_gpu_block_num` | Gauge | Total number of blocks determined at service startup | count |
@@ -1051,7 +1051,10 @@ if __name__ == "__main__":
 
     args = parse_args()
-    logger = get_logger("cache_messager", f"cache_messager_tprank{args.rank}.log")
+    rank_id = args.rank + args.local_data_parallel_id * args.mp_num
+    if args.mp_num > 1:
+        logger = get_logger("cache_messager", f"cache_messager_{rank_id}.log")
+    else:
+        logger = get_logger("cache_messager", "cache_messager.log")
 
     logger.info("create cache messager...")
     logger.info(f"{args}")
@@ -0,0 +1,37 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from dataclasses import dataclass
+from typing import List
+
+
+@dataclass(frozen=True, kw_only=True)
+class CacheTask:
+    task_id: str
+    keys: List[str]
+    token_ids: List[int]
+    gpu_block_ids: List[int]
+
+
+@dataclass(frozen=True, kw_only=True)
+class ReadStorageTask(CacheTask):
+    start_read_block_idx: int
+    timeout: float = 30.0
+
+
+@dataclass(frozen=True, kw_only=True)
+class WriteStorageTask(CacheTask):
+    timeout: float = 30.0
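Note: the task objects introduced above are frozen, keyword-only dataclasses, so callers construct them with named fields. A minimal sketch with illustrative values (the ids and hashes below are made up):

```python
# Minimal sketch of constructing the new task objects; all values are illustrative.
from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask

read_task = ReadStorageTask(
    task_id="req-0",
    keys=["hash0", "hash1"],      # one content hash per cache block
    token_ids=list(range(128)),   # tokens covered by the blocks
    gpu_block_ids=[3, 7],         # destination GPU blocks
    start_read_block_idx=0,       # skip blocks already resident on GPU
)
write_task = WriteStorageTask(
    task_id="req-0",
    keys=["hash0", "hash1"],
    token_ids=list(range(128)),
    gpu_block_ids=[3, 7],
)
```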
@@ -29,6 +29,7 @@ import paddle
 
 from fastdeploy import envs
 from fastdeploy.cache_manager.cache_data import CacheStatus
+from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask
 from fastdeploy.cache_manager.ops import (
     cuda_host_alloc,
     cuda_host_free,
@@ -40,7 +41,7 @@ from fastdeploy.cache_manager.ops import (
     swap_cache_layout,
     unset_data_ipc,
 )
-from fastdeploy.cache_manager.transfer_factory import MooncakeStore
+from fastdeploy.cache_manager.transfer_factory import AttentionStore, MooncakeStore
 from fastdeploy.config import SpeculativeConfig
 from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
 from fastdeploy.platforms import current_platform
@@ -58,6 +59,7 @@ def parse_args():
         default="mixed",
         help="splitwise role, can be decode, prefill or mixed",
     )
+    parser.add_argument("--model_id", type=str, default="default", help="model id")
     parser.add_argument("--rank", type=int, default=0, help="local tp rank")
     parser.add_argument("--device_id", type=int, default=0, help="device id")
     parser.add_argument("--max_model_len", type=int, default=32768, help="max model length")
@@ -109,7 +111,7 @@ def parse_args():
         "--kvcache_storage_backend",
         type=str,
         default=None,
-        choices=["mooncake", "none"],
+        choices=["mooncake", "attention_store", "none"],
         help="The storage backend for kvcache storage. If not set, storage backend is disabled.",
     )
     parser.add_argument(
@@ -133,8 +135,6 @@ class CacheTransferManager:
         """
        Initialize CacheTransferManager
         """
-        device = args.device_id
-        rank = args.rank
         self.gpu_cache_kvs = {}
         self.cpu_cache_kvs = {}
         self.gpu_cache_k_tensors = []
@@ -142,11 +142,31 @@ class CacheTransferManager:
         self.gpu_cache_scales_k_tensors = []
         self.gpu_cache_scales_v_tensors = []
         self.speculative_config = SpeculativeConfig(args.speculative_config)
+
+        # parse kv cache shape
+        self.key_cache_shape = [int(i) for i in args.key_cache_shape.split(",")]
+        self.value_cache_shape = []
+        if args.value_cache_shape:
+            self.value_cache_shape = [int(i) for i in args.value_cache_shape.split(",")]
+
+        # extract kv cache shape into fields
+        self.num_gpu_blocks = self.key_cache_shape[0]
+        self.head_num = self.key_cache_shape[1]
+        self.block_size = self.key_cache_shape[2]
+        self.head_dim = self.key_cache_shape[3]
+
+        # compute cache bytes
+        self.cache_dtype = args.cache_dtype
+        self.cache_bytes = self._get_cache_bytes(self.cache_dtype)
+
+        # extract other arg values
+        self.model_id = args.model_id
+        self.n_ranks = args.mp_num
+        self.rank = args.rank
+        self.device = args.device_id
+        self.num_layers = args.num_layers
+        self.ipc_suffix = args.ipc_suffix
+        self.local_data_parallel_id = args.local_data_parallel_id
         self.num_extra_layers = self.speculative_config.num_extra_cache_layer
         self.num_extra_layer_gpu_blocks = int(self.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
         paddle.set_default_dtype(args.default_dtype)
@@ -158,18 +178,13 @@ class CacheTransferManager:
         self.timeout_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)
         self.transfer_task_queue = queue.Queue()  # queue for receiving transfer tasks
         self.tansfer_done_queue = queue.Queue()  # queue for signaling that a task has finished
-        self.n_ranks = args.mp_num
-        self.rank = rank
-        self.device = device
-        self.ipc_suffix = args.ipc_suffix
-        self.cache_dtype = args.cache_dtype
 
         address = (args.pod_ip, args.cache_queue_port)
         self.cache_task_queue = EngineCacheQueue(
             address=address,
             is_server=False,
             num_client=args.mp_num,
-            client_id=rank,
+            client_id=self.rank,
             local_data_parallel_id=args.local_data_parallel_id,
         )
@@ -223,8 +238,22 @@ class CacheTransferManager:
             self.storage_backend = MooncakeStore(tp_rank=self.rank)
             self._init_storage_buffer(args)
             logger.info("Initialized mooncake store successfully")
+        elif args.kvcache_storage_backend == "attention_store":
+            logger.info("Start initialize attention store...")
+            self.storage_backend = AttentionStore(
+                namespace=self.model_id,
+                shard_id=self.rank,
+                shard_num=self.n_ranks,
+                layer_num=self.num_layers + self.num_extra_layers,
+                block_token_size=self.block_size,
+                bytes_per_shard_layer_per_block=self.head_num * self.block_size * self.head_dim * self.cache_bytes,
+                device_id=self.device,
+                dp_id=self.local_data_parallel_id,
+            )
+            logger.info("Initialized attention store successfully!")
         else:
             raise NotImplementedError(f"Unsupported storage backend: {args.kvcache_storage_backend}")
         self.storage_backend_type = args.kvcache_storage_backend
 
         if args.write_policy not in ["write_through"]:
             raise ValueError(f"Invalid write policy: {args.write_policy}")
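For intuition about the `bytes_per_shard_layer_per_block` argument passed above: it is this rank's per-layer, per-block footprint of one K (or V) cache tensor. A back-of-the-envelope calculation under assumed shapes (none taken from a real deployment):

```python
# Illustrative only: head_num, block_size, head_dim, and dtype width are assumed.
head_num, block_size, head_dim = 8, 64, 128  # per-rank KV heads, tokens/block, head dim
cache_bytes = 2                              # e.g. bfloat16
bytes_per_shard_layer_per_block = head_num * block_size * head_dim * cache_bytes
print(bytes_per_shard_layer_per_block)       # 131072 bytes = 128 KiB per layer per block
```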
@@ -246,18 +275,16 @@ class CacheTransferManager:
         cache layout: layer_num * [block_num, head_num, block_size, head_dim]
         buffer layout: [block_num, layer_num, head_num, block_size, head_dim]
         """
-        layer_num = args.num_layers + self.num_extra_layers
-        head_num = self.key_cache_shape[1]
-        block_size = self.key_cache_shape[2]
-        head_dim = self.key_cache_shape[3]
-        block_num = (args.max_model_len + block_size - 1) // block_size
+        layer_num = self.num_layers + self.num_extra_layers
+        block_num = (args.max_model_len + self.block_size - 1) // self.block_size
         logger.info(
             f"Creating cache buffer for storage with shape: "
-            f"[{block_num}, {layer_num}, {head_num}, {block_size}, {head_dim}]"
+            f"[{block_num}, {layer_num}, {self.head_num}, {self.block_size}, {self.head_dim}]"
         )
 
-        self.cache_bytes = self._get_cache_bytes(self.cache_dtype)
-        self.storage_buffer_stride_bytes = layer_num * head_num * block_size * head_dim * self.cache_bytes
+        self.storage_buffer_stride_bytes = (
+            layer_num * self.head_num * self.block_size * self.head_dim * self.cache_bytes
+        )
         total_bytes = block_num * self.storage_buffer_stride_bytes * 2  # key and value
 
         logger.info(f"Creating cpu buffer cache for alllayers: {total_bytes / 1024 ** 3:.2f}GB")
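Continuing the assumed shapes from the note above, the staging buffer sized here scales with `max_model_len` rather than the GPU block pool. A sketch of the same arithmetic:

```python
# Illustrative sizing of the CPU staging buffer (same assumed shapes as before).
layer_num, max_model_len = 32, 32768
head_num, block_size, head_dim, cache_bytes = 8, 64, 128, 2

storage_buffer_stride_bytes = layer_num * head_num * block_size * head_dim * cache_bytes
block_num = (max_model_len + block_size - 1) // block_size  # ceiling division
total_bytes = block_num * storage_buffer_stride_bytes * 2   # key and value
print(f"{total_bytes / 1024 ** 3:.2f}GB")                   # 4.00GB
```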
@@ -296,8 +323,8 @@ class CacheTransferManager:
 
         logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.")
         set_device(self.device)
-        for i in range(args.num_layers + self.num_extra_layers):
-            num_gpu_blocks = self.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
+        for i in range(self.num_layers + self.num_extra_layers):
+            num_gpu_blocks = self.num_gpu_blocks if i < self.num_layers else self.num_extra_layer_gpu_blocks
             key_name = f"key_caches_{i}_rank{self.rank}.device{self.device}"
             val_name = f"value_caches_{i}_rank{self.rank}.device{self.device}"
             key_cache_scales_name = f"key_cache_scales_{i}_rank{self.rank}.device{self.device}"
@@ -415,7 +442,7 @@ class CacheTransferManager:
         self.v_dst_ptrs = []
         self.k_scales_ptrs = []
         self.v_scales_ptrs = []
-        for i in range(args.num_layers + self.num_extra_layers):
+        for i in range(self.num_layers + self.num_extra_layers):
             key_name = f"key_caches_{i}_rank{self.rank}"
             val_name = f"value_caches_{i}_rank{self.rank}"
             key_cache_scales_name = f"key_cache_scales_{i}_rank{self.rank}"
@@ -446,228 +473,283 @@ class CacheTransferManager:
             raise ValueError(f"Unsupported cache dtype: {cache_dtype}")
         return cache_bytes
 
-    def _storage_exist_block_num(self, k_keys: List[str], v_keys: List[str]):
-        """
-        Given the k_keys and v_keys, get the valid blocks number that
-        can be prefetched from storage backend.
-        """
-        assert len(k_keys) == len(v_keys), "k_keys and v_keys must have the same length."
-        result = self.storage_backend.exists(k_keys + v_keys)
-
-        # only consider the case when both key and value exist
-        num = 0
-        for k, v in zip(k_keys, v_keys):
-            if result[k] and result[v]:
-                num += 1
-        return num
-
-    def _run_read_storage(self, k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids):
+    def _run_read_storage(
+        self,
+        task_id: str,
+        token_ids: List[int],
+        start_read_block_idx: int,
+        k_cache_keys: List[str],
+        v_cache_keys: List[str],
+        gpu_block_ids: List[int],
+        cpu_block_ids: List[int],
+        timeout: float,
+    ):
+        """
+        Read storage data from the given blocks to the corresponding cache tensors on the current rank's GPU.
+        """
         try:
             logger.debug(
                 f"_run_read_storage, key_hash_keys_num: {len(k_cache_keys)}, "
                 f"value_hash_keys_num: {len(v_cache_keys)}, gpu_block_ids_num: {len(gpu_block_ids)}, "
                 f"cpu_block_ids_num: {len(cpu_block_ids)}"
             )
-            block_num = len(gpu_block_ids)
-            keys = k_cache_keys + v_cache_keys
-            k_cache_ptrs = [self.storage_key_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids]
-            v_cache_ptrs = [
-                self.storage_value_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
-            ]
-            kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
-            kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2  # key and value
-            start_time = time.time()
-            result = self.storage_backend.batch_get(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes)
-            read_cost_time = time.time() - start_time
-
-            k_result, v_result = result[:block_num], result[block_num:]
-            success_block_num = 0
-            for k, v in zip(k_result, v_result):
-                if k > 0 and v > 0:
-                    success_block_num += 1
-            logger.debug(f"_run_read_storage, success_block_num: {success_block_num}")
-            valid_gpu_block_ids = gpu_block_ids[:success_block_num]
-            valid_cpu_block_ids = cpu_block_ids[:success_block_num]
-
-            mode = 1  # cpu ==> gpu
-            start_time = time.time()
-            swap_cache_layout(
-                self.gpu_cache_k_tensors,
-                self.storage_key_read_buffer,
-                self.key_cache_shape,
-                valid_gpu_block_ids,
-                valid_cpu_block_ids,
-                self.device,
-                mode,
-            )
-            swap_cache_layout(
-                self.gpu_cache_v_tensors,
-                self.storage_value_read_buffer,
-                self.value_cache_shape,
-                valid_gpu_block_ids,
-                valid_cpu_block_ids,
-                self.device,
-                mode,
-            )
-            swap_cost_time = time.time() - start_time
-            logger.debug(
-                f"_run_read_storage, swap_cost_time: {swap_cost_time:.6f}s, read_cost_time: {read_cost_time:.6f}s"
-            )
+            if self.storage_backend_type == "mooncake":
+                block_num = len(gpu_block_ids)
+                keys = k_cache_keys + v_cache_keys
+                k_cache_ptrs = [
+                    self.storage_key_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
+                ]
+                v_cache_ptrs = [
+                    self.storage_value_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
+                ]
+                kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
+                kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2  # key and value
+                start_time = time.time()
+                result = self.storage_backend.batch_get(
+                    keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes
+                )
+                read_cost_time = time.time() - start_time
+                k_result, v_result = result[:block_num], result[block_num:]
+                success_block_num = 0
+                for k, v in zip(k_result, v_result):
+                    if k > 0 and v > 0:
+                        success_block_num += 1
+                logger.debug(f"_run_read_storage, success_block_num: {success_block_num}")
+                valid_gpu_block_ids = gpu_block_ids[:success_block_num]
+                valid_cpu_block_ids = cpu_block_ids[:success_block_num]
+                mode = 1  # cpu ==> gpu
+                start_time = time.time()
+                swap_cache_layout(
+                    self.gpu_cache_k_tensors,
+                    self.storage_key_read_buffer,
+                    self.key_cache_shape,
+                    valid_gpu_block_ids,
+                    valid_cpu_block_ids,
+                    self.device,
+                    mode,
+                )
+                swap_cache_layout(
+                    self.gpu_cache_v_tensors,
+                    self.storage_value_read_buffer,
+                    self.value_cache_shape,
+                    valid_gpu_block_ids,
+                    valid_cpu_block_ids,
+                    self.device,
+                    mode,
+                )
+                swap_cost_time = time.time() - start_time
+                logger.debug(
+                    f"_run_read_storage, swap_cost_time: {swap_cost_time:.6f}s, read_cost_time: {read_cost_time:.6f}s"
+                )
+            elif self.storage_backend_type == "attention_store":
+                key_cache = []
+                val_cache = []
+                for i in range(self.num_layers + self.num_extra_layers):
+                    key_cache.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{self.rank}.device{self.device}"])
+                    val_cache.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{self.rank}.device{self.device}"])
+
+                start_time = time.time()
+                read_block_num = self.storage_backend.read(
+                    task_id, key_cache, val_cache, token_ids, gpu_block_ids, start_read_block_idx, timeout
+                )
+                read_cost_time = time.time() - start_time
+                valid_gpu_block_ids = gpu_block_ids[:read_block_num]
+                logger.debug(f"_run_read_storage, read_cost_time: {read_cost_time:.6f}s")
 
             return valid_gpu_block_ids
 
         except Exception as e:
             logger.error(
-                f"[rank {self.rank}/{self.n_ranks}] An error occurred in _run_read_storage: "
-                f"error:{e}, {traceback.format_exc()}"
+                f"An error occurred in _run_read_storage, " f"error: {e}, traceback:\n{traceback.format_exc()}"
             )
             raise
 
-    def read_storage_task(self, task_id, keys, gpu_block_ids, timeout=0.1):
+    def read_storage_task(self, task: ReadStorageTask):
         """Read cache from the storage backend to the GPU memory."""
         try:
-            logger.debug(
-                f"read_storage_task, task id: {task_id}, hash_keys_num: {len(keys)}, "
-                f"gpu_block_ids_num: {len(gpu_block_ids)}, timeout: {timeout}"
-            )
-            k_cache_keys = [f"{key}_key_{self.rank}" for key in keys]
-            v_cache_keys = [f"{key}_value_{self.rank}" for key in keys]
-            match_block_num = self._storage_exist_block_num(k_cache_keys, v_cache_keys)
-            logger.debug(f"read_storage_task, match {match_block_num} blocks from storage for task id: {task_id}")
+            gpu_block_ids = task.gpu_block_ids.copy()
+            cpu_block_ids = [i for i in range(len(gpu_block_ids))]
+            k_cache_keys = [f"{key}_key_{self.rank}" for key in task.keys]
+            v_cache_keys = [f"{key}_value_{self.rank}" for key in task.keys]
+            match_block_num = 0
+            if self.storage_backend_type == "mooncake":
+                match_block_num = self.storage_backend.query(k_cache_keys, v_cache_keys)
+            elif self.storage_backend_type == "attention_store":
+                match_block_num = self.storage_backend.query(
+                    task.task_id, task.token_ids, task.start_read_block_idx, task.timeout
+                )
+            logger.info(f"Matched {match_block_num} blocks in cache storage for read task {task.task_id}")
 
             k_cache_keys = k_cache_keys[:match_block_num]
             v_cache_keys = v_cache_keys[:match_block_num]
             gpu_block_ids = gpu_block_ids[:match_block_num]
-            cpu_block_ids = [i for i in range(match_block_num)]
+            cpu_block_ids = cpu_block_ids[:match_block_num]
             valid_gpu_block_ids = []
 
             if match_block_num > 0:
                 # TODO: support timeout with actual block count
                 try:
                     valid_gpu_block_ids = self._run_read_storage(
-                        k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids
+                        task.task_id,
+                        task.token_ids[: match_block_num * self.block_size],
+                        task.start_read_block_idx,
+                        k_cache_keys,
+                        v_cache_keys,
+                        gpu_block_ids,
+                        cpu_block_ids,
+                        task.timeout,
                     )
                     logger.info(
-                        f"read_storage_task, finish loading {match_block_num} blocks from storage for task {task_id}."
+                        f"Successfully read {len(valid_gpu_block_ids)} blocks from cache storage for task {task.task_id}"
                     )
                 except Exception as e:
-                    logger.error(f"[rank {self.rank}/{self.n_ranks}] An error occurred: {task_id} {e}")
+                    logger.error(f"Failed to read cache for task {task.task_id}, error: {e}")
                     valid_gpu_block_ids = []
 
-            result = (CacheStatus.STORAGE2GPU, task_id, keys, valid_gpu_block_ids)
+            result = (CacheStatus.STORAGE2GPU, task.task_id, task.keys, valid_gpu_block_ids)
             self.cache_task_queue.swap_storage_to_gpu_barrier.wait()
             self.cache_task_queue.swap_storage_to_gpu_barrier.reset()
             self.cache_task_queue.put_transfer_done_signal(result)
-            logger.debug(f"read_storage_task: put_transfer_done_signal {result}")
-            logger.info(
-                f"read_storage_task: put_transfer_done_signal for transfer_task_id {task_id}, "
-                f"valid block num {len(valid_gpu_block_ids)}"
-            )
+            logger.debug(f"read_storage_task: put transfer done signal for {task.task_id}")
 
         except Exception as e:
             logger.error(
-                f"[rank {self.rank}/{self.n_ranks}] An error occurred in read_storage_task: "
-                f"task_id: {task_id}, error:{e}, {traceback.format_exc()}"
+                f"An error occurred in read_storage_task: "
+                f"task_id: {task.task_id}, error:{e}, {traceback.format_exc()}"
             )
 
-    def _run_write_back_storage(self, k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids):
+    def _run_write_back_storage(
+        self,
+        task_id,
+        token_ids,
+        start_write_block_idx,
+        k_cache_keys,
+        v_cache_keys,
+        gpu_block_ids,
+        cpu_block_ids,
+        timeout,
+    ):
         try:
             logger.debug(
                 f"_run_write_back_storage, k_cache_keys: {k_cache_keys}, v_cache_keys: {v_cache_keys}, "
                 f"gpu_block_ids: {gpu_block_ids}"
             )
-            key_cache_size = [
-                self.key_cache_shape[0],
-                self.key_cache_shape[1],
-                self.key_cache_shape[2],
-                self.key_cache_shape[3],
-            ]
-            mode = 0  # gpu ==> cpu
-            start_time = time.time()
-            swap_cache_layout(
-                self.gpu_cache_k_tensors,
-                self.storage_key_write_buffer,
-                key_cache_size,
-                gpu_block_ids,
-                cpu_block_ids,
-                self.device,
-                mode,
-            )
-            swap_cache_layout(
-                self.gpu_cache_v_tensors,
-                self.storage_value_write_buffer,
-                key_cache_size,
-                gpu_block_ids,
-                cpu_block_ids,
-                self.device,
-                mode,
-            )
-            swap_cost_time = time.time() - start_time
-
-            block_num = len(gpu_block_ids)
-            keys = k_cache_keys + v_cache_keys
-            k_cache_ptrs = [
-                self.storage_key_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
-            ]
-            v_cache_ptrs = [
-                self.storage_value_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
-            ]
-            kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
-            kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2  # key and value
-            start_time = time.time()
-            self.storage_backend.batch_set(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes)
-            write_cost_time = time.time() - start_time
-
-            logger.debug(
-                f"_run_write_back_storage, swap_cost_time: {swap_cost_time:.6f}s, write_cost_time: {write_cost_time:.6f}s"
-            )
+            if self.storage_backend_type == "mooncake":
+                key_cache_size = [
+                    self.key_cache_shape[0],
+                    self.key_cache_shape[1],
+                    self.key_cache_shape[2],
+                    self.key_cache_shape[3],
+                ]
+                mode = 0  # gpu ==> cpu
+                start_time = time.time()
+                swap_cache_layout(
+                    self.gpu_cache_k_tensors,
+                    self.storage_key_write_buffer,
+                    key_cache_size,
+                    gpu_block_ids,
+                    cpu_block_ids,
+                    self.device,
+                    mode,
+                )
+                swap_cache_layout(
+                    self.gpu_cache_v_tensors,
+                    self.storage_value_write_buffer,
+                    key_cache_size,
+                    gpu_block_ids,
+                    cpu_block_ids,
+                    self.device,
+                    mode,
+                )
+                swap_cost_time = time.time() - start_time
+                block_num = len(gpu_block_ids)
+                keys = k_cache_keys + v_cache_keys
+                k_cache_ptrs = [
+                    self.storage_key_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
+                ]
+                v_cache_ptrs = [
+                    self.storage_value_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
+                ]
+                kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
+                kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2  # key and value
+                start_time = time.time()
+                self.storage_backend.batch_set(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes)
+                write_cost_time = time.time() - start_time
+                logger.debug(
+                    f"_run_write_back_storage, swap_cost_time: {swap_cost_time:.6f}s, write_cost_time: {write_cost_time:.6f}s"
+                )
+                return block_num
+            elif self.storage_backend_type == "attention_store":
+                key_cache = []
+                val_cache = []
+                for i in range(self.num_layers + self.num_extra_layers):
+                    key_cache.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{self.rank}.device{self.device}"])
+                    val_cache.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{self.rank}.device{self.device}"])
+
+                start_time = time.time()
+                write_block_num = self.storage_backend.write(
+                    task_id, key_cache, val_cache, token_ids, gpu_block_ids, start_write_block_idx, timeout
+                )
+                write_cost_time = time.time() - start_time
+                logger.debug(f"_run_write_back_storage, write_cost_time: {write_cost_time:.6f}s")
+                return write_block_num
         except Exception as e:
             logger.error(
-                f"[rank {self.rank}/{self.n_ranks}] An error occurred in _run_write_back_storage: "
-                f"error:{e}, {traceback.format_exc()}"
+                f"An error occurred in _run_write_back_storage, " f"error: {e}, traceback:\n{traceback.format_exc()}"
             )
             return 0
 
-    def write_back_storage_task(self, task_id, keys, gpu_block_ids, timeout=0.1):
+    def write_back_storage_task(self, task: WriteStorageTask):
         """
         Write cache to the storage backend from the GPU memory.
         """
         try:
-            logger.debug(
-                f"write cache to storage, keys: {keys}, gpu_block_ids: {gpu_block_ids}, "
-                f"task_id: {task_id}, timeout: {timeout}"
-            )
-
-            k_cache_keys = [f"{key}_key_{self.rank}" for key in keys]
-            v_cache_keys = [f"{key}_value_{self.rank}" for key in keys]
-            match_block_num = self._storage_exist_block_num(k_cache_keys, v_cache_keys)
-
-            k_cache_keys = k_cache_keys[match_block_num:]
-            v_cache_keys = v_cache_keys[match_block_num:]
-            gpu_block_ids = gpu_block_ids[match_block_num:]
+            gpu_block_ids = task.gpu_block_ids.copy()
+            cpu_block_ids = [i for i in range(len(gpu_block_ids))]
+            k_cache_keys = [f"{key}_key_{self.rank}" for key in task.keys]
+            v_cache_keys = [f"{key}_value_{self.rank}" for key in task.keys]
 
-            if len(k_cache_keys) == 0:
-                logger.info(f"No uncached keys found for task {task_id}")
+            match_block_num = 0
+            if self.storage_backend_type == "mooncake":
+                match_block_num = self.storage_backend.query(k_cache_keys, v_cache_keys, task.timeout)
+            elif self.storage_backend_type == "attention_store":
+                match_block_num = self.storage_backend.query(task.task_id, task.token_ids, 0, task.timeout)
+            logger.info(f"Matched {match_block_num} blocks in cache storage for write task {task.task_id}")
+
+            if match_block_num >= len(k_cache_keys):
+                logger.info(f"No uncached keys found for task {task.task_id}")
                 gpu_block_ids = []
             else:
                 try:
+                    k_cache_keys = k_cache_keys[match_block_num:]
+                    v_cache_keys = v_cache_keys[match_block_num:]
+                    gpu_block_ids = gpu_block_ids[match_block_num:]
+                    cpu_block_ids = cpu_block_ids[match_block_num:]
                     # TODO: support timeout with actual block count
-                    self._run_write_back_storage(k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids)
+                    write_block_num = self._run_write_back_storage(
+                        task.task_id,
+                        task.token_ids,
+                        match_block_num,
+                        k_cache_keys,
+                        v_cache_keys,
+                        gpu_block_ids,
+                        cpu_block_ids,
+                        task.timeout,
+                    )
+                    logger.info(
+                        f"Successfully wrote {write_block_num} blocks to cache storage for task {task.task_id}"
+                    )
                 except Exception as e:
                     logger.error(f"Error in write back storage task: {e}")
                     gpu_block_ids = []
 
-            result = (CacheStatus.GPU2STORAGE, task_id, keys, gpu_block_ids)
+            result = (CacheStatus.GPU2STORAGE, task.task_id, task.keys, gpu_block_ids)
             self.cache_task_queue.swap_to_storage_barrier.wait()
             if self.rank == 0:  # only rank 0 performs the synchronization
                 self.cache_task_queue.swap_to_storage_barrier.reset()
                 self.cache_task_queue.put_transfer_done_signal(result)  # send the transfer-done signal
             logger.debug(f"write_back_storage_task: put_transfer_done_signal {result}")
-            logger.info(f"write_back_storage_task: put_transfer_done_signal for transfer_task_id {task_id}")
         except Exception as e:
             logger.error(
-                f"[rank {self.rank}/{self.n_ranks}] An error occurred in write_back_storage_task: "
-                f"error:{e}, {traceback.format_exc()}"
+                f"An error occurred in write_back_storage_task, " f"error: {e}, traceback:\n{traceback.format_exc()}"
            )
 
     def _do_swap_to_cpu_task(
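Note on key naming in the mooncake path above: each block hash is expanded into rank-scoped key and value entries, so existence checks, reads, and writes stay per tensor-parallel shard. A small sketch of the derivation:

```python
# Sketch of the per-rank key derivation used by the mooncake read/write paths.
keys = ["hash0", "hash1"]  # illustrative block hashes
rank = 0
k_cache_keys = [f"{key}_key_{rank}" for key in keys]    # ['hash0_key_0', 'hash1_key_0']
v_cache_keys = [f"{key}_value_{rank}" for key in keys]  # ['hash0_value_0', 'hash1_value_0']
```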
@@ -759,12 +841,12 @@ class CacheTransferManager:
                 self.cache_task_queue.barrier1.reset()
             if self.cache_task_broadcast_signal.value[0] == 1:
                 data, read_finish = self.cache_task_queue.get_transfer_task()
-                logger.debug(f"transfer data: get_transfer_task {data}")
+                logger.debug(f"do_data_transfer: {data}")
                 if read_finish:
                     self.cache_task_broadcast_signal.value[0] = 0
-                event_type, transfer_task_id = data[0], data[1]
+                event_type, event_args = data[0], data[1:]
                 if event_type.value == CacheStatus.SWAP2CPU.value:
-                    swap_node_ids, gpu_block_id, cpu_block_id = data[2:]
+                    transfer_task_id, swap_node_ids, gpu_block_id, cpu_block_id = event_args
                     self.swap_to_cpu_thread_pool.submit(
                         self._do_swap_to_cpu_task,
                         swap_node_ids,
@@ -774,7 +856,7 @@ class CacheTransferManager:
                         transfer_task_id,
                     )
                 elif event_type.value == CacheStatus.SWAP2GPU.value:
-                    swap_node_ids, gpu_block_id, cpu_block_id = data[2:]
+                    transfer_task_id, swap_node_ids, gpu_block_id, cpu_block_id = event_args
                     self.swap_to_gpu_thread_pool.submit(
                         self._do_swap_to_gpu_task,
                         swap_node_ids,
@@ -784,22 +866,16 @@ class CacheTransferManager:
                         transfer_task_id,
                     )
                 elif event_type.value == CacheStatus.STORAGE2GPU.value:
-                    hash_keys, gpu_block_ids, timeout = data[2:]
+                    read_storage_task = event_args[0]
                     self.read_storage_thread_pool.submit(
                         self.read_storage_task,
-                        transfer_task_id,
-                        hash_keys,
-                        gpu_block_ids,
-                        timeout,
+                        read_storage_task,
                     )
                 elif event_type.value == CacheStatus.GPU2STORAGE.value:
-                    hash_keys, gpu_block_ids, timeout = data[2:]
+                    write_storage_task = event_args[0]
                     self.write_back_storage_thread_pool.submit(
                         self.write_back_storage_task,
-                        transfer_task_id,
-                        hash_keys,
-                        gpu_block_ids,
-                        timeout,
+                        write_storage_task,
                     )
                 else:
                     if self.n_ranks > 1:
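The dispatch change above narrows the transfer-queue protocol: swap events still carry positional fields (now unpacked from `event_args`), while storage events carry a single task object. A rough sketch, reusing the task objects from the earlier dataclass example (all ids illustrative):

```python
# Payload shapes on the cache task queue after this change (illustrative values;
# read_task/write_task are the ReadStorageTask/WriteStorageTask sketched earlier).
swap_event = (CacheStatus.SWAP2CPU, "task-1", [0], [3], [5])  # positional fields
read_event = (CacheStatus.STORAGE2GPU, read_task)             # one ReadStorageTask
write_event = (CacheStatus.GPU2STORAGE, write_task)           # one WriteStorageTask
```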
@@ -1047,7 +1123,11 @@ if __name__ == "__main__":
 
     args = parse_args()
-    logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_tprank{args.rank}.log")
+    rank_id = args.rank + args.local_data_parallel_id * args.mp_num
+    if args.mp_num > 1:
+        logger = get_logger("cache_transfer", f"cache_transfer_{rank_id}.log")
+    else:
+        logger = get_logger("cache_transfer", "cache_transfer.log")
 
     logger.info(f"args: {vars(args)}")
     set_device(args.device_id)
     try:
@@ -31,7 +31,9 @@ import numpy as np
 from fastdeploy import envs
 from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
 from fastdeploy.cache_manager.cache_metrics import CacheMetrics
+from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask
 from fastdeploy.cache_manager.ops import get_all_visible_devices
+from fastdeploy.config import FDConfig
 from fastdeploy.engine.request import Request
 from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus
 from fastdeploy.metrics.metrics import main_process_metrics
@@ -47,7 +49,7 @@ class PrefixCacheManager:
 
     def __init__(
         self,
-        config,
+        config: FDConfig,
         tensor_parallel_size,
         splitwise_role="mixed",
         local_data_parallel_id=0,
@@ -207,7 +209,6 @@ class PrefixCacheManager:
         key_cache_shape, val_cache_shape = self._get_kv_cache_shape(cache_config.total_block_num)
         key_cache_shape = ",".join([str(i) for i in key_cache_shape])
         val_cache_shape = ",".join([str(i) for i in val_cache_shape])
-        logger.info(f"key_cache_shape {key_cache_shape} value_cache_shape {val_cache_shape}")
         if self.enable_splitwise:
             cache_messager_processes = self.launch_cache_messager(
                 cache_config,
@@ -273,6 +274,7 @@ class PrefixCacheManager:
             + " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
             + f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}"
             + f" {sys.executable} {py_path}"
+            + f" --model_id {os.path.basename(self.config.model_config.model)}"
             + f" --device_id {int(device_ids[i])}"
             + f" --rank {i}"
             + f" --splitwise_role {self.splitwise_role}"
@@ -390,7 +392,7 @@ class PrefixCacheManager:
             + f" --ipc_suffix {ipc_suffix}"
             + f" --rdma_port {cache_config.local_rdma_comm_ports[i] if cache_config.local_rdma_comm_ports is not None else '0'}"
             + f" --speculative_config '{self.speculative_config.to_json_string()}'"
-            + f" >{log_dir}/launch_cache_messager_tprank{i}.log 2>&1"
+            + f" >{log_dir}/launch_cache_messager_{i}.log 2>&1"
         )
         logger.info(f"Launch cache messager, command:{launch_cmd}")
         cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
@@ -789,9 +791,15 @@ class PrefixCacheManager:
                     f"start prefetch cache from storage, req_id: {req_id}, block num: {len(no_match_block_keys)}"
                 )
                 start_time = time.time()
-                storage_matched_block_ids = self.issue_prefetch_storage_task(
-                    req_id, no_match_block_keys, gpu_recv_storage_block_ids
+                read_storage_task = ReadStorageTask(
+                    task_id=req_id,
+                    keys=no_match_block_keys,
+                    token_ids=input_token_ids,
+                    gpu_block_ids=gpu_recv_storage_block_ids,
+                    start_read_block_idx=match_token_num // block_size,
                 )
+                logger.debug(f"issue read storage task: {read_storage_task}")
+                storage_matched_block_ids = self.issue_prefetch_storage_task(read_storage_task)
                 storage_matched_block_num = len(storage_matched_block_ids)
                 storage_match_token_num = storage_matched_block_num * block_size
                 cost_time = time.time() - start_time
@@ -1006,6 +1014,12 @@ class PrefixCacheManager:
         if self.kvcache_storage_backend is None:
             return
 
+        token_ids = request.prompt_token_ids
+        if isinstance(token_ids, np.ndarray):
+            token_ids = token_ids.tolist()
+        if self.config.cache_config.enable_output_caching:
+            token_ids += request.output_token_ids
+
         req_id = request.request_id
         keys = []
         node = self.req_leaf_map[req_id]
@@ -1018,24 +1032,33 @@ class PrefixCacheManager:
 
         gpu_block_ids = request.block_tables[: len(keys)]
         logger.info(f"start write cache back to storage, req_id: {req_id}, block num: {len(keys)}")
+        write_storage_task = WriteStorageTask(
+            task_id=req_id,
+            keys=keys,
+            token_ids=token_ids,
+            gpu_block_ids=gpu_block_ids,
+        )
+        logger.debug(f"issue write storage task: {write_storage_task}")
         tic = time.time()
-        self.issue_write_back_storage_task(req_id=req_id, hash_keys=keys, gpu_block_ids=gpu_block_ids, is_sync=True)
+        self.issue_write_back_storage_task(write_storage_task, is_sync=True)
         cost_time = time.time() - tic
         logger.info(f"finish write cache back to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s")
 
-    def issue_write_back_storage_task(self, req_id, hash_keys, gpu_block_ids, is_sync=True, timeout=0.5):
+    def issue_write_back_storage_task(self, task: WriteStorageTask, is_sync=True):
         if self.kvcache_storage_backend is None:
             return
 
-        if len(hash_keys) != len(gpu_block_ids):
-            err_msg = f"write_back_storage error: hash_keys({len(hash_keys)}) != gpu_block_ids({len(gpu_block_ids)})"
+        if len(task.keys) != len(task.gpu_block_ids):
+            err_msg = (
+                f"write_back_storage error: hash_keys({len(task.keys)}) != gpu_block_ids({len(task.gpu_block_ids)})"
+            )
             logger.error(err_msg)
             raise ValueError(err_msg)
 
-        self.task_write_back_event[req_id] = Event()
-        self.cache_task_queue.put_transfer_task((CacheStatus.GPU2STORAGE, req_id, hash_keys, gpu_block_ids, timeout))
+        self.task_write_back_event[task.task_id] = Event()
+        self.cache_task_queue.put_transfer_task((CacheStatus.GPU2STORAGE, task))
         if is_sync:
-            self.wait_write_storage_task(req_id)
+            self.wait_write_storage_task(task.task_id)
 
     def wait_write_storage_task(self, req_id):
         """
@@ -1045,16 +1068,19 @@ class PrefixCacheManager:
         self.task_write_back_event[req_id].wait()
         del self.task_write_back_event[req_id]
 
-    def issue_prefetch_storage_task(self, req_id, hash_keys, gpu_block_ids, is_sync=True, timeout=0.5):
+    def issue_prefetch_storage_task(self, task: ReadStorageTask, is_sync=True):
         """
         Prefetch cache from storage task
         """
         if self.kvcache_storage_backend is None:
             return []
 
         storage_block_ids = []
-        self.task_prefetch_event[req_id] = Event()
+        self.task_prefetch_event[task.task_id] = Event()
         # issue task to cache_transfer_manager
-        self.cache_task_queue.put_transfer_task((CacheStatus.STORAGE2GPU, req_id, hash_keys, gpu_block_ids, timeout))
+        self.cache_task_queue.put_transfer_task((CacheStatus.STORAGE2GPU, task))
         if is_sync:
-            storage_block_ids = self.wait_prefetch_storage_task(req_id)
+            storage_block_ids = self.wait_prefetch_storage_task(task.task_id)
         return storage_block_ids
 
     def wait_prefetch_storage_task(self, req_id):
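As the two issue_* methods above show, the manager arms a per-task Event, publishes the task on the queue, and optionally blocks until the transfer manager posts the done signal. A condensed sketch of that handshake, with the real queue and signal plumbing stubbed out:

```python
# Condensed sketch of the issue-and-wait handshake; the real plumbing lives in
# EngineCacheQueue and CacheTransferManager.
from threading import Event

from fastdeploy.cache_manager.cache_data import CacheStatus

task_prefetch_event = {}

def issue_prefetch(task, put_transfer_task, is_sync=True):
    task_prefetch_event[task.task_id] = Event()         # armed before publishing
    put_transfer_task((CacheStatus.STORAGE2GPU, task))  # picked up by the transfer manager
    if is_sync:
        task_prefetch_event[task.task_id].wait()        # set when the done signal arrives
```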
@@ -17,7 +17,7 @@
 from fastdeploy.platforms import current_platform
 
 from .kvcache_storage import KVCacheStorage
-from .mooncake_store import MooncakeStore
+from .mooncake_store import AttentionStore, MooncakeStore
 from .rdma_cache_transfer import RDMACommManager
 
 if current_platform.is_cuda():
@@ -31,4 +31,5 @@ __all__ = [
     "RDMACommManager",
     "KVCacheStorage",
     "MooncakeStore",
+    "AttentionStore",
 ]
@@ -95,3 +95,10 @@ class KVCacheStorage(ABC):
         Clear all keys in storage
         """
         pass
+
+    @abstractmethod
+    def query(self) -> int:
+        """
+        Query the number of blocks stored in the storage.
+        """
+        pass
@@ -14,6 +14,7 @@
 # limitations under the License.
 """
 
+from .attention_store import AttentionStore
 from .mooncake_store import MooncakeStore
 
-__all__ = ["MooncakeStore"]
+__all__ = ["MooncakeStore", "AttentionStore"]
@@ -0,0 +1,209 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import time
+import traceback
+from dataclasses import dataclass
+from typing import List
+
+import paddle
+
+from fastdeploy.cache_manager.transfer_factory.kvcache_storage import (
+    KVCacheStorage,
+    logger,
+)
+try:
+    from attentionstore_sdk.sdk import AttentionStoreSDK, Tokens
+    from attentionstore_sdk.utils.err import AttentionStoreSDKError
+
+    _ATTENTIONSTORE_AVAILABLE = True
+except Exception:
+    AttentionStoreSDK = None
+    Tokens = None
+    AttentionStoreSDKError = None
+    _ATTENTIONSTORE_AVAILABLE = False
+
+
+@dataclass
+class AttentionStoreConfig:
+    namespace: str = "default_ns"
+    pod_name: str = "default_pod"
+    model_version: str = "v0"
+    shard_id: int = 0
+    shard_num: int = 1
+    layer_num: int = 1
+    block_token_size: int = 64
+    bytes_per_shard_layer_per_block: int = 1024
+    device_id: int = 0
+    dp_id: int = 0
+
+
+class AttentionStore(KVCacheStorage):
+    def __init__(self, **args):
+
+        if not _ATTENTIONSTORE_AVAILABLE:
+            raise ImportError("Please install attentionstore_sdk to run Fastdeploy with attentionstore_sdk.")
+
+        self.config = AttentionStoreConfig(**args)
+
+        try:
+            logger.info(f"[INIT] Start initializing AttentionStoreSDK with config: {self.config}")
+            self.sdk = AttentionStoreSDK(
+                self.config.namespace,
+                self.config.pod_name,
+                self.config.model_version,
+                self.config.shard_id,
+                self.config.shard_num,
+                self.config.layer_num,
+                self.config.block_token_size,
+                self.config.bytes_per_shard_layer_per_block,
+                self.config.device_id,
+                self.config.dp_id,
+            )
+            self.wait_for_sdk_ready(timeout=300, delta_t=5)
+            logger.info("[INIT] ✅ AttentionStore is initialized successfully!")
+        except Exception as e:
+            logger.error(
+                f"[INIT] ❌ AttentionStore initialization failed, error: {e}, traceback:\n{traceback.format_exc()}"
+            )
+
+    def wait_for_sdk_ready(self, timeout: float, delta_t: float):
+        t = 0
+        while t < timeout:
+            try:
+                tokens = Tokens(list(range(self.config.block_token_size + 1)), self.config.block_token_size)
+                self.sdk.match(tokens, 0, delta_t)
+                return
+            except AttentionStoreSDKError as e:
+                if "cuda memory not ready" in str(e):
+                    logger.debug("[INIT] cuda memory not ready, try again..")
+                    time.sleep(delta_t)
+                    t += delta_t
+                    continue
+                else:
+                    raise RuntimeError(
+                        f"Unexpected exception during AttentionStoreSDK initialization: {e}\n{traceback.format_exc()}"
+                    )
+        raise TimeoutError(f"AttentionStoreSDK initialization timed out after {timeout} seconds")
+
+    def read(
+        self,
+        task_id: str,
+        key_cache: List[paddle.Tensor],
+        val_cache: List[paddle.Tensor],
+        token_ids: List[int],
+        gpu_block_ids: List[int],
+        start_read_block_idx: int,
+        timeout: float = 30.0,
+    ):
+        logger.debug(
+            f"[READ BEGIN] task_id: {task_id} token_ids: {token_ids} gpu_block_ids: {gpu_block_ids} start_read_block_idx: {start_read_block_idx} timeout: {timeout}"
+        )
+        tokens = Tokens(token_ids, self.config.block_token_size)
+        k_data_ptrs = [k.data_ptr() for k in key_cache]
+        v_data_ptrs = [v.data_ptr() for v in val_cache]
+        num = 0
+        try:
+            num = self.sdk.read(
+                list(range(self.config.layer_num)),
+                tokens,
+                start_read_block_idx,
+                k_data_ptrs,
+                v_data_ptrs,
+                gpu_block_ids,
+                timeout,
+            )
+            logger.debug(f"[READ END] task_id: {task_id} read_blocks: {num}")
+        except AttentionStoreSDKError:
+            logger.error(
+                f"[READ ERROR] failed to execute sdk read, task_id: {task_id}, traceback:\n{traceback.format_exc()}"
+            )
+        return num
+
+    def write(
+        self,
+        task_id: str,
+        key_cache: List[paddle.Tensor],
+        val_cache: List[paddle.Tensor],
+        token_ids: List[int],
+        gpu_block_ids: List[int],
+        start_write_block_idx: int,
+        timeout: float = 30.0,
+    ) -> int:
+        logger.debug(
+            f"[WRITE BEGIN] task_id: {task_id} token_ids: {token_ids} gpu_block_ids: {gpu_block_ids} start_write_block_idx: {start_write_block_idx} timeout: {timeout}"
+        )
+        tokens = Tokens(token_ids, self.config.block_token_size)
+        k_data_ptrs = [k.data_ptr() for k in key_cache]
+        v_data_ptrs = [v.data_ptr() for v in val_cache]
+        num = 0
+        try:
+            num = self.sdk.write(
+                list(range(self.config.layer_num)),
+                tokens,
+                start_write_block_idx,
+                k_data_ptrs,
+                v_data_ptrs,
+                gpu_block_ids,
+                timeout,
+            )
+            logger.debug(f"[WRITE END] task_id: {task_id} written_blocks: {num}")
+        except AttentionStoreSDKError:
+            logger.error(
+                f"[WRITE ERROR] failed to execute sdk write, task_id: {task_id}, traceback:\n{traceback.format_exc()}"
+            )
+        return num
+
+    def query(self, task_id: str, token_ids: List[int], start_match_block_idx: int, timeout: float = 10.0):
+        """
+        Given the input ids and starting index to match, get the valid blocks number that
+        can be prefetched from storage backend.
+        """
+        logger.debug(
+            f"[QUERY BEGIN] task_id: {task_id} token_ids: {token_ids} start_match_block_idx: {start_match_block_idx} timeout: {timeout}"
+        )
+        tokens = Tokens(token_ids, self.config.block_token_size)
+        num = 0
+        try:
+            num = self.sdk.match(tokens, start_match_block_idx, timeout)
+            logger.debug(f"[QUERY END] task_id: {task_id} matched_blocks: {num}")
+        except AttentionStoreSDKError:
+            logger.error(
+                f"[QUERY ERROR] Failed to execute sdk match, task_id: {task_id}, traceback:\n{traceback.format_exc()}"
+            )
+        return num
+
+    def get(self, **kwargs):
+        raise NotImplementedError("AttentionStore does not support this method")
+
+    def batch_get(self, **kwargs):
+        raise NotImplementedError("AttentionStore does not support this method")
+
+    def set(self, **kwargs) -> bool:
+        raise NotImplementedError("AttentionStore does not support this method")
+
+    def batch_set(self, **kwargs) -> bool:
+        raise NotImplementedError("AttentionStore does not support this method")
+
+    def exists(self, keys: List[str]) -> bool:
+        raise NotImplementedError("AttentionStore does not support this method")
+
+    def clear(self) -> bool:
+        raise NotImplementedError("AttentionStore does not support this method")
+
+    def register_buffer(self, buffer_ptr, buffer_size, buffer_type="none_type") -> None:
+        raise NotImplementedError("AttentionStore does not support this method")
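Taken together, the class above exposes a query/read/write surface keyed by token ids rather than by hash keys. A hypothetical end-to-end sketch; `key_cache`/`val_cache` stand for the per-layer GPU tensor lists that CacheTransferManager gathers, and all ids and shapes here are illustrative:

```python
# Hypothetical usage sketch of the AttentionStore surface; not from a real run.
store = AttentionStore(
    namespace="my_model",
    shard_id=0,
    shard_num=1,
    layer_num=32,
    block_token_size=64,
    bytes_per_shard_layer_per_block=131072,
    device_id=0,
    dp_id=0,
)
token_ids = list(range(128))                  # two 64-token blocks
matched = store.query("req-0", token_ids, 0)  # blocks already present in storage
read = store.read("req-0", key_cache, val_cache, token_ids, [3, 7], 0)
written = store.write("req-0", key_cache, val_cache, token_ids, [3, 7], matched)
```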
@@ -237,6 +237,21 @@ class MooncakeStore(KVCacheStorage):
         logger.debug(f"The exists fun processes {len(keys)} objects, cost_time: {cost_time:.3f}ms")
         return result
 
+    def query(self, k_keys: List[str], v_keys: List[str], timeout: float = 1.0):
+        """
+        Given the k_keys and v_keys, get the valid blocks number that
+        can be prefetched from storage backend.
+        """
+        assert len(k_keys) == len(v_keys), "k_keys and v_keys must have the same length."
+        result = self.exists(k_keys + v_keys)
+
+        # only consider the case when both key and value exist
+        num = 0
+        for k, v in zip(k_keys, v_keys):
+            if result[k] and result[v]:
+                num += 1
+        return num
+
     def delete(self, key, timeout=5) -> bool:
         while timeout:
             result = self.store.remove(key)
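The counting rule in `query` above is worth spelling out: a block counts only when both its key and value entries exist in storage. A tiny illustration with an assumed `exists()` result:

```python
# Illustration of the pair-wise counting in MooncakeStore.query (assumed exists() output).
exists_result = {"h0_key_0": True, "h0_value_0": True,
                 "h1_key_0": True, "h1_value_0": False}  # value blob missing
k_keys, v_keys = ["h0_key_0", "h1_key_0"], ["h0_value_0", "h1_value_0"]
num = sum(1 for k, v in zip(k_keys, v_keys) if exists_result[k] and exists_result[v])
print(num)  # 1 -- only blocks with both key and value present are counted
```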
@@ -629,7 +629,6 @@ class EngineArgs:
                 for port in cur_dp_ports:
                     assert is_port_available("0.0.0.0", port), f"Parameter `{name}`:{port} is already in use."
 
-            console_logger.debug(f"post init {name}: {ports}")
             return ports
 
         num_nodes = len(self.ips) if self.ips else 1
@@ -1077,7 +1076,7 @@ class EngineArgs:
         cache_group.add_argument(
             "--kvcache-storage-backend",
             type=nullable_str,
-            choices=["mooncake"],
+            choices=["mooncake", "attention_store"],
             default=EngineArgs.kvcache_storage_backend,
             help="The storage backend for kvcache storage. Leave empty to disable.",
         )
@@ -225,10 +225,6 @@ class EngineService:
             device_ids = self.cfg.parallel_config.device_ids.split(",")
             self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix)
 
-            # Set cache manager signal
-            if self.cfg.scheduler_config.splitwise_role != "mixed":
-                self.launched_cache_manager_signal.value[0] = 1
-
         # Worker launched
         self.check_worker_initialize_status_func_thread.join()
         if not result_container["worker_is_alive"]:
@@ -181,10 +181,6 @@ class LLMEngine:
         device_ids = self.cfg.parallel_config.device_ids.split(",")
         self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
 
         # Launch components: scheduler, cache_manager, expert_service et.al.
-        if self.cfg.scheduler_config.splitwise_role != "mixed":
-            self.launched_cache_manager_signal.value[0] = 1
-
         if self.cfg.scheduler_config.splitwise_role != "mixed" and envs.FD_ENABLE_INTERNAL_ADAPTER:
             envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT = envs.FD_ZMQ_RECV_REQUEST_SERVER_PORTS.split(",")[0]
             envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT = envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORTS.split(",")[0]
@@ -1338,8 +1338,11 @@ class ResourceManagerV1(ResourceManager):
     def update_metrics(self):
         # Update metrics
         num_tasks = sum([1 if task else 0 for task in self.tasks_list])
-        num_blocks_used_by_tasks = sum([len(task.block_tables) if task else 0 for task in self.tasks_list])
-        main_process_metrics.available_gpu_block_num.set(self.total_block_number() - num_blocks_used_by_tasks)
+        blocks_used_by_tasks = set()
+        for task in self.tasks_list:
+            if task is not None:
+                blocks_used_by_tasks.update(task.block_tables)
+        main_process_metrics.available_gpu_block_num.set(self.total_block_number() - len(blocks_used_by_tasks))
         main_process_metrics.batch_size.set(self.max_num_seqs - self.available_batch())
         main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
         main_process_metrics.num_requests_running.set(len(self.running))
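The metrics fix above matters when requests share prefix-cache blocks: summing per-task block-table lengths double-counts shared blocks, while a set counts each block once. For example:

```python
# Why the set-based count differs when block tables overlap (illustrative ids).
tables = [[0, 1, 2], [0, 1, 3]]               # two tasks sharing blocks 0 and 1
summed = sum(len(t) for t in tables)          # 6 -- old behavior, double counts
unique = len({b for t in tables for b in t})  # 4 -- new behavior
```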
@@ -29,6 +29,7 @@ class Args:
     mp_num = 1
     device_id = 0
     speculative_config = {}
+    model_id = "test_model"
     ipc_suffix = "test_ipc_suffix"
     cache_queue_port = 9999
     pod_ip = "127.0.0.1"
@@ -185,6 +185,7 @@ def _create_manager(
         swap_space=4,
     )
     model_config = SimpleNamespace(
+        model="test_model",
         num_attention_heads=1,
         num_key_value_heads=1,
         head_dim=1,
@@ -332,8 +332,6 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
         self.assertFalse(ok)
         # cache manager started before workers (lines 184-185)
         self.assertTrue(started_cache.get("called", False))
-        # launched_cache_manager_signal set (line 221)
-        self.assertEqual(int(eng.launched_cache_manager_signal.value[0]), 1)
         # avoid atexit finalizer
         if hasattr(eng, "_finalizer"):
             try: