mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] [KVCache] support attention_store kv cache backend (#5823)
* [feat] support attention_store kv cache backend * [fix] fix codestyle * [chore] optimize log * [fix] fix write storage task * [fix] fix read storage * [fix] fix code conflict after merge develop * [fix] fix cache bytes and read task token ids * [chore] add model for cache transfer manager * [chore] add some log * [chore] remove launched_cache_manager_signal * [fix] fix write_back_storage_task match_block_num condition * [fix] fix swap_cost_time * [ci] fix ci * Update fastdeploy/engine/sched/resource_manager_v1.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update fastdeploy/cache_manager/cache_transfer_manager.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -32,7 +32,7 @@
|
|||||||
| KV缓存 | `fastdeploy:gpu_hit_token_rate` | Gauge | token 级别 GPU 前缀缓存命中率 | 百分比 |
|
| KV缓存 | `fastdeploy:gpu_hit_token_rate` | Gauge | token 级别 GPU 前缀缓存命中率 | 百分比 |
|
||||||
| KV缓存 | `fastdeploy:prefix_cache_token_num` | Counter | 前缀缓存token总数 | 个 |
|
| KV缓存 | `fastdeploy:prefix_cache_token_num` | Counter | 前缀缓存token总数 | 个 |
|
||||||
| KV缓存 | `fastdeploy:prefix_gpu_cache_token_num` | Counter | 位于 GPU 上的前缀缓存 token 总数 | 个 |
|
| KV缓存 | `fastdeploy:prefix_gpu_cache_token_num` | Counter | 位于 GPU 上的前缀缓存 token 总数 | 个 |
|
||||||
| KV缓存 | `fastdeploy:prefix_cpu_cache_token_num` | Counter | 位于 GPU 上的前缀缓存 token 总数 | 个 |
|
| KV缓存 | `fastdeploy:prefix_cpu_cache_token_num` | Counter | 位于 CPU 上的前缀缓存 token 总数 | 个 |
|
||||||
| KV缓存 | `fastdeploy:available_gpu_block_num` | Gauge | 缓存中可用的 GPU 块数量(包含尚未正式释放的前缀缓存块)| 个 |
|
| KV缓存 | `fastdeploy:available_gpu_block_num` | Gauge | 缓存中可用的 GPU 块数量(包含尚未正式释放的前缀缓存块)| 个 |
|
||||||
| KV缓存 | `fastdeploy:free_gpu_block_num` | Gauge | 缓存中的可用块数 | 个 |
|
| KV缓存 | `fastdeploy:free_gpu_block_num` | Gauge | 缓存中的可用块数 | 个 |
|
||||||
| KV缓存 | `fastdeploy:max_gpu_block_num` | Gauge | 服务启动时确定的总块数 | 个 |
|
| KV缓存 | `fastdeploy:max_gpu_block_num` | Gauge | 服务启动时确定的总块数 | 个 |
|
||||||
|
|||||||
@@ -1051,7 +1051,10 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
|
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
|
||||||
logger = get_logger("cache_messager", f"cache_messager_tprank{args.rank}.log")
|
if args.mp_num > 1:
|
||||||
|
logger = get_logger("cache_messager", f"cache_messager_{rank_id}.log")
|
||||||
|
else:
|
||||||
|
logger = get_logger("cache_messager", "cache_messager.log")
|
||||||
|
|
||||||
logger.info("create cache messager...")
|
logger.info("create cache messager...")
|
||||||
logger.info(f"{args}")
|
logger.info(f"{args}")
|
||||||
|
|||||||
@@ -0,0 +1,37 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class CacheTask:
|
||||||
|
task_id: str
|
||||||
|
keys: List[str]
|
||||||
|
token_ids: List[int]
|
||||||
|
gpu_block_ids: List[int]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class ReadStorageTask(CacheTask):
|
||||||
|
start_read_block_idx: int
|
||||||
|
timeout: float = 30.0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, kw_only=True)
|
||||||
|
class WriteStorageTask(CacheTask):
|
||||||
|
timeout: float = 30.0
|
||||||
@@ -29,6 +29,7 @@ import paddle
|
|||||||
|
|
||||||
from fastdeploy import envs
|
from fastdeploy import envs
|
||||||
from fastdeploy.cache_manager.cache_data import CacheStatus
|
from fastdeploy.cache_manager.cache_data import CacheStatus
|
||||||
|
from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask
|
||||||
from fastdeploy.cache_manager.ops import (
|
from fastdeploy.cache_manager.ops import (
|
||||||
cuda_host_alloc,
|
cuda_host_alloc,
|
||||||
cuda_host_free,
|
cuda_host_free,
|
||||||
@@ -40,7 +41,7 @@ from fastdeploy.cache_manager.ops import (
|
|||||||
swap_cache_layout,
|
swap_cache_layout,
|
||||||
unset_data_ipc,
|
unset_data_ipc,
|
||||||
)
|
)
|
||||||
from fastdeploy.cache_manager.transfer_factory import MooncakeStore
|
from fastdeploy.cache_manager.transfer_factory import AttentionStore, MooncakeStore
|
||||||
from fastdeploy.config import SpeculativeConfig
|
from fastdeploy.config import SpeculativeConfig
|
||||||
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
|
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
|
||||||
from fastdeploy.platforms import current_platform
|
from fastdeploy.platforms import current_platform
|
||||||
@@ -58,6 +59,7 @@ def parse_args():
|
|||||||
default="mixed",
|
default="mixed",
|
||||||
help="splitwise role, can be decode, prefill or mixed",
|
help="splitwise role, can be decode, prefill or mixed",
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--model_id", type=str, default="default", help="model id")
|
||||||
parser.add_argument("--rank", type=int, default=0, help="local tp rank")
|
parser.add_argument("--rank", type=int, default=0, help="local tp rank")
|
||||||
parser.add_argument("--device_id", type=int, default=0, help="device id")
|
parser.add_argument("--device_id", type=int, default=0, help="device id")
|
||||||
parser.add_argument("--max_model_len", type=int, default=32768, help="max model length")
|
parser.add_argument("--max_model_len", type=int, default=32768, help="max model length")
|
||||||
@@ -109,7 +111,7 @@ def parse_args():
|
|||||||
"--kvcache_storage_backend",
|
"--kvcache_storage_backend",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
choices=["mooncake", "none"],
|
choices=["mooncake", "attention_store", "none"],
|
||||||
help="The storage backend for kvcache storage. If not set, storage backend is disabled.",
|
help="The storage backend for kvcache storage. If not set, storage backend is disabled.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -133,8 +135,6 @@ class CacheTransferManager:
|
|||||||
"""
|
"""
|
||||||
初始化CacheTransferManager
|
初始化CacheTransferManager
|
||||||
"""
|
"""
|
||||||
device = args.device_id
|
|
||||||
rank = args.rank
|
|
||||||
self.gpu_cache_kvs = {}
|
self.gpu_cache_kvs = {}
|
||||||
self.cpu_cache_kvs = {}
|
self.cpu_cache_kvs = {}
|
||||||
self.gpu_cache_k_tensors = []
|
self.gpu_cache_k_tensors = []
|
||||||
@@ -142,11 +142,31 @@ class CacheTransferManager:
|
|||||||
self.gpu_cache_scales_k_tensors = []
|
self.gpu_cache_scales_k_tensors = []
|
||||||
self.gpu_cache_scales_v_tensors = []
|
self.gpu_cache_scales_v_tensors = []
|
||||||
self.speculative_config = SpeculativeConfig(args.speculative_config)
|
self.speculative_config = SpeculativeConfig(args.speculative_config)
|
||||||
|
|
||||||
|
# parse kv cache shape
|
||||||
self.key_cache_shape = [int(i) for i in args.key_cache_shape.split(",")]
|
self.key_cache_shape = [int(i) for i in args.key_cache_shape.split(",")]
|
||||||
self.value_cache_shape = []
|
self.value_cache_shape = []
|
||||||
if args.value_cache_shape:
|
if args.value_cache_shape:
|
||||||
self.value_cache_shape = [int(i) for i in args.value_cache_shape.split(",")]
|
self.value_cache_shape = [int(i) for i in args.value_cache_shape.split(",")]
|
||||||
|
|
||||||
|
# extract kv cache shape into fields
|
||||||
self.num_gpu_blocks = self.key_cache_shape[0]
|
self.num_gpu_blocks = self.key_cache_shape[0]
|
||||||
|
self.head_num = self.key_cache_shape[1]
|
||||||
|
self.block_size = self.key_cache_shape[2]
|
||||||
|
self.head_dim = self.key_cache_shape[3]
|
||||||
|
|
||||||
|
# compute cache bytes
|
||||||
|
self.cache_dtype = args.cache_dtype
|
||||||
|
self.cache_bytes = self._get_cache_bytes(self.cache_dtype)
|
||||||
|
|
||||||
|
# extract other arg values
|
||||||
|
self.model_id = args.model_id
|
||||||
|
self.n_ranks = args.mp_num
|
||||||
|
self.rank = args.rank
|
||||||
|
self.device = args.device_id
|
||||||
|
self.num_layers = args.num_layers
|
||||||
|
self.ipc_suffix = args.ipc_suffix
|
||||||
|
self.local_data_parallel_id = args.local_data_parallel_id
|
||||||
self.num_extra_layers = self.speculative_config.num_extra_cache_layer
|
self.num_extra_layers = self.speculative_config.num_extra_cache_layer
|
||||||
self.num_extra_layer_gpu_blocks = int(self.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
|
self.num_extra_layer_gpu_blocks = int(self.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
|
||||||
paddle.set_default_dtype(args.default_dtype)
|
paddle.set_default_dtype(args.default_dtype)
|
||||||
@@ -158,18 +178,13 @@ class CacheTransferManager:
|
|||||||
self.timeout_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)
|
self.timeout_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)
|
||||||
self.transfer_task_queue = queue.Queue() # 用来接收传输任务
|
self.transfer_task_queue = queue.Queue() # 用来接收传输任务
|
||||||
self.tansfer_done_queue = queue.Queue() # 用来告知任务执行完毕
|
self.tansfer_done_queue = queue.Queue() # 用来告知任务执行完毕
|
||||||
self.n_ranks = args.mp_num
|
|
||||||
self.rank = rank
|
|
||||||
self.device = device
|
|
||||||
self.ipc_suffix = args.ipc_suffix
|
|
||||||
self.cache_dtype = args.cache_dtype
|
|
||||||
|
|
||||||
address = (args.pod_ip, args.cache_queue_port)
|
address = (args.pod_ip, args.cache_queue_port)
|
||||||
self.cache_task_queue = EngineCacheQueue(
|
self.cache_task_queue = EngineCacheQueue(
|
||||||
address=address,
|
address=address,
|
||||||
is_server=False,
|
is_server=False,
|
||||||
num_client=args.mp_num,
|
num_client=args.mp_num,
|
||||||
client_id=rank,
|
client_id=self.rank,
|
||||||
local_data_parallel_id=args.local_data_parallel_id,
|
local_data_parallel_id=args.local_data_parallel_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -223,8 +238,22 @@ class CacheTransferManager:
|
|||||||
self.storage_backend = MooncakeStore(tp_rank=self.rank)
|
self.storage_backend = MooncakeStore(tp_rank=self.rank)
|
||||||
self._init_storage_buffer(args)
|
self._init_storage_buffer(args)
|
||||||
logger.info("Initialized mooncake store successfully")
|
logger.info("Initialized mooncake store successfully")
|
||||||
|
elif args.kvcache_storage_backend == "attention_store":
|
||||||
|
logger.info("Start initialize attention store...")
|
||||||
|
self.storage_backend = AttentionStore(
|
||||||
|
namespace=self.model_id,
|
||||||
|
shard_id=self.rank,
|
||||||
|
shard_num=self.n_ranks,
|
||||||
|
layer_num=self.num_layers + self.num_extra_layers,
|
||||||
|
block_token_size=self.block_size,
|
||||||
|
bytes_per_shard_layer_per_block=self.head_num * self.block_size * self.head_dim * self.cache_bytes,
|
||||||
|
device_id=self.device,
|
||||||
|
dp_id=self.local_data_parallel_id,
|
||||||
|
)
|
||||||
|
logger.info("Initialized attention store successfully!")
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Unsupported storage backend: {args.kvcache_storage_backend}")
|
raise NotImplementedError(f"Unsupported storage backend: {args.kvcache_storage_backend}")
|
||||||
|
self.storage_backend_type = args.kvcache_storage_backend
|
||||||
|
|
||||||
if args.write_policy not in ["write_through"]:
|
if args.write_policy not in ["write_through"]:
|
||||||
raise ValueError(f"Invalid write policy: {args.write_policy}")
|
raise ValueError(f"Invalid write policy: {args.write_policy}")
|
||||||
@@ -246,18 +275,16 @@ class CacheTransferManager:
|
|||||||
cache layout: layer_num * [block_num, head_num, block_size, head_dim]
|
cache layout: layer_num * [block_num, head_num, block_size, head_dim]
|
||||||
buffer layout: [block_num, layer_num, head_num, block_size, head_dim]
|
buffer layout: [block_num, layer_num, head_num, block_size, head_dim]
|
||||||
"""
|
"""
|
||||||
layer_num = args.num_layers + self.num_extra_layers
|
layer_num = self.num_layers + self.num_extra_layers
|
||||||
head_num = self.key_cache_shape[1]
|
block_num = (args.max_model_len + self.block_size - 1) // self.block_size
|
||||||
block_size = self.key_cache_shape[2]
|
|
||||||
head_dim = self.key_cache_shape[3]
|
|
||||||
block_num = (args.max_model_len + block_size - 1) // block_size
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Creating cache buffer for storage with shape: "
|
f"Creating cache buffer for storage with shape: "
|
||||||
f"[{block_num}, {layer_num}, {head_num}, {block_size}, {head_dim}]"
|
f"[{block_num}, {layer_num}, {self.head_num}, {self.block_size}, {self.head_dim}]"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.cache_bytes = self._get_cache_bytes(self.cache_dtype)
|
self.storage_buffer_stride_bytes = (
|
||||||
self.storage_buffer_stride_bytes = layer_num * head_num * block_size * head_dim * self.cache_bytes
|
layer_num * self.head_num * self.block_size * self.head_dim * self.cache_bytes
|
||||||
|
)
|
||||||
total_bytes = block_num * self.storage_buffer_stride_bytes * 2 # key and value
|
total_bytes = block_num * self.storage_buffer_stride_bytes * 2 # key and value
|
||||||
|
|
||||||
logger.info(f"Creating cpu buffer cache for alllayers: {total_bytes / 1024 ** 3:.2f}GB")
|
logger.info(f"Creating cpu buffer cache for alllayers: {total_bytes / 1024 ** 3:.2f}GB")
|
||||||
@@ -296,8 +323,8 @@ class CacheTransferManager:
|
|||||||
|
|
||||||
logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.")
|
logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.")
|
||||||
set_device(self.device)
|
set_device(self.device)
|
||||||
for i in range(args.num_layers + self.num_extra_layers):
|
for i in range(self.num_layers + self.num_extra_layers):
|
||||||
num_gpu_blocks = self.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
|
num_gpu_blocks = self.num_gpu_blocks if i < self.num_layers else self.num_extra_layer_gpu_blocks
|
||||||
key_name = f"key_caches_{i}_rank{self.rank}.device{self.device}"
|
key_name = f"key_caches_{i}_rank{self.rank}.device{self.device}"
|
||||||
val_name = f"value_caches_{i}_rank{self.rank}.device{self.device}"
|
val_name = f"value_caches_{i}_rank{self.rank}.device{self.device}"
|
||||||
key_cache_scales_name = f"key_cache_scales_{i}_rank{self.rank}.device{self.device}"
|
key_cache_scales_name = f"key_cache_scales_{i}_rank{self.rank}.device{self.device}"
|
||||||
@@ -415,7 +442,7 @@ class CacheTransferManager:
|
|||||||
self.v_dst_ptrs = []
|
self.v_dst_ptrs = []
|
||||||
self.k_scales_ptrs = []
|
self.k_scales_ptrs = []
|
||||||
self.v_scales_ptrs = []
|
self.v_scales_ptrs = []
|
||||||
for i in range(args.num_layers + self.num_extra_layers):
|
for i in range(self.num_layers + self.num_extra_layers):
|
||||||
key_name = f"key_caches_{i}_rank{self.rank}"
|
key_name = f"key_caches_{i}_rank{self.rank}"
|
||||||
val_name = f"value_caches_{i}_rank{self.rank}"
|
val_name = f"value_caches_{i}_rank{self.rank}"
|
||||||
key_cache_scales_name = f"key_cache_scales_{i}_rank{self.rank}"
|
key_cache_scales_name = f"key_cache_scales_{i}_rank{self.rank}"
|
||||||
@@ -446,228 +473,283 @@ class CacheTransferManager:
|
|||||||
raise ValueError(f"Unsupported cache dtype: {cache_dtype}")
|
raise ValueError(f"Unsupported cache dtype: {cache_dtype}")
|
||||||
return cache_bytes
|
return cache_bytes
|
||||||
|
|
||||||
def _storage_exist_block_num(self, k_keys: List[str], v_keys: List[str]):
|
def _run_read_storage(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
token_ids: List[int],
|
||||||
|
start_read_block_idx: int,
|
||||||
|
k_cache_keys: List[str],
|
||||||
|
v_cache_keys: List[str],
|
||||||
|
gpu_block_ids: List[int],
|
||||||
|
cpu_block_ids: List[int],
|
||||||
|
timeout: float,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Given the k_keys and v_keys, get the valid blocks number that
|
Read storage data from the given blocks to the corresponding cache tensors on the current rank's GPU.
|
||||||
can be prefetched from storage backend.
|
|
||||||
"""
|
"""
|
||||||
assert len(k_keys) == len(v_keys), "k_keys and v_keys must have the same length."
|
|
||||||
result = self.storage_backend.exists(k_keys + v_keys)
|
|
||||||
|
|
||||||
# only consider the case when both key and value exist
|
|
||||||
num = 0
|
|
||||||
for k, v in zip(k_keys, v_keys):
|
|
||||||
if result[k] and result[v]:
|
|
||||||
num += 1
|
|
||||||
return num
|
|
||||||
|
|
||||||
def _run_read_storage(self, k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids):
|
|
||||||
try:
|
try:
|
||||||
logger.debug(
|
if self.storage_backend_type == "mooncake":
|
||||||
f"_run_read_storage, key_hash_keys_num: {len(k_cache_keys)}, "
|
block_num = len(gpu_block_ids)
|
||||||
f"value_hash_keys_num: {len(v_cache_keys)}, gpu_block_ids_num: {len(gpu_block_ids)}, "
|
keys = k_cache_keys + v_cache_keys
|
||||||
f"cpu_block_ids_num: {len(cpu_block_ids)}"
|
k_cache_ptrs = [
|
||||||
)
|
self.storage_key_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
|
||||||
|
]
|
||||||
|
v_cache_ptrs = [
|
||||||
|
self.storage_value_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
|
||||||
|
]
|
||||||
|
kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
|
||||||
|
kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value
|
||||||
|
start_time = time.time()
|
||||||
|
result = self.storage_backend.batch_get(
|
||||||
|
keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes
|
||||||
|
)
|
||||||
|
read_cost_time = time.time() - start_time
|
||||||
|
|
||||||
block_num = len(gpu_block_ids)
|
k_result, v_result = result[:block_num], result[block_num:]
|
||||||
keys = k_cache_keys + v_cache_keys
|
success_block_num = 0
|
||||||
k_cache_ptrs = [self.storage_key_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids]
|
for k, v in zip(k_result, v_result):
|
||||||
v_cache_ptrs = [
|
if k > 0 and v > 0:
|
||||||
self.storage_value_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
|
success_block_num += 1
|
||||||
]
|
logger.debug(f"_run_read_storage, success_block_num: {success_block_num}")
|
||||||
kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
|
valid_gpu_block_ids = gpu_block_ids[:success_block_num]
|
||||||
kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value
|
valid_cpu_block_ids = cpu_block_ids[:success_block_num]
|
||||||
start_time = time.time()
|
|
||||||
result = self.storage_backend.batch_get(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes)
|
|
||||||
read_cost_time = time.time() - start_time
|
|
||||||
|
|
||||||
k_result, v_result = result[:block_num], result[block_num:]
|
mode = 1 # cpu ==> gpu
|
||||||
success_block_num = 0
|
start_time = time.time()
|
||||||
for k, v in zip(k_result, v_result):
|
swap_cache_layout(
|
||||||
if k > 0 and v > 0:
|
self.gpu_cache_k_tensors,
|
||||||
success_block_num += 1
|
self.storage_key_read_buffer,
|
||||||
logger.debug(f"_run_read_storage, success_block_num: {success_block_num}")
|
self.key_cache_shape,
|
||||||
valid_gpu_block_ids = gpu_block_ids[:success_block_num]
|
valid_gpu_block_ids,
|
||||||
valid_cpu_block_ids = cpu_block_ids[:success_block_num]
|
valid_cpu_block_ids,
|
||||||
|
self.device,
|
||||||
|
mode,
|
||||||
|
)
|
||||||
|
swap_cache_layout(
|
||||||
|
self.gpu_cache_v_tensors,
|
||||||
|
self.storage_value_read_buffer,
|
||||||
|
self.value_cache_shape,
|
||||||
|
valid_gpu_block_ids,
|
||||||
|
valid_cpu_block_ids,
|
||||||
|
self.device,
|
||||||
|
mode,
|
||||||
|
)
|
||||||
|
swap_cost_time = time.time() - start_time
|
||||||
|
logger.debug(
|
||||||
|
f"_run_read_storage, swap_cost_time: {swap_cost_time:.6f}s, read_cost_time: {read_cost_time:.6f}s"
|
||||||
|
)
|
||||||
|
|
||||||
mode = 1 # cpu ==> gpu
|
elif self.storage_backend_type == "attention_store":
|
||||||
start_time = time.time()
|
key_cache = []
|
||||||
swap_cache_layout(
|
val_cache = []
|
||||||
self.gpu_cache_k_tensors,
|
for i in range(self.num_layers + self.num_extra_layers):
|
||||||
self.storage_key_read_buffer,
|
key_cache.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{self.rank}.device{self.device}"])
|
||||||
self.key_cache_shape,
|
val_cache.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{self.rank}.device{self.device}"])
|
||||||
valid_gpu_block_ids,
|
|
||||||
valid_cpu_block_ids,
|
start_time = time.time()
|
||||||
self.device,
|
read_block_num = self.storage_backend.read(
|
||||||
mode,
|
task_id, key_cache, val_cache, token_ids, gpu_block_ids, start_read_block_idx, timeout
|
||||||
)
|
)
|
||||||
swap_cache_layout(
|
read_cost_time = time.time() - start_time
|
||||||
self.gpu_cache_v_tensors,
|
valid_gpu_block_ids = gpu_block_ids[:read_block_num]
|
||||||
self.storage_value_read_buffer,
|
logger.debug(f"_run_read_storage, read_cost_time: {read_cost_time:.6f}s")
|
||||||
self.value_cache_shape,
|
|
||||||
valid_gpu_block_ids,
|
|
||||||
valid_cpu_block_ids,
|
|
||||||
self.device,
|
|
||||||
mode,
|
|
||||||
)
|
|
||||||
swap_cost_time = time.time() - start_time
|
|
||||||
logger.debug(
|
|
||||||
f"_run_read_storage, swap_cost_time: {swap_cost_time:.6f}s, read_cost_time: {read_cost_time:.6f}s"
|
|
||||||
)
|
|
||||||
|
|
||||||
return valid_gpu_block_ids
|
return valid_gpu_block_ids
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[rank {self.rank}/{self.n_ranks}] An error occurred in _run_read_storage: "
|
f"An error occurred in _run_read_storage, " f"error: {e}, traceback:\n{traceback.format_exc()}"
|
||||||
f"error:{e}, {traceback.format_exc()}"
|
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def read_storage_task(self, task_id, keys, gpu_block_ids, timeout=0.1):
|
def read_storage_task(self, task: ReadStorageTask):
|
||||||
"""Read cache from the storage backend to the GPU memory."""
|
"""Read cache from the storage backend to the GPU memory."""
|
||||||
try:
|
try:
|
||||||
logger.debug(
|
gpu_block_ids = task.gpu_block_ids.copy()
|
||||||
f"read_storage_task, task id: {task_id}, hash_keys_num: {len(keys)}, "
|
cpu_block_ids = [i for i in range(len(gpu_block_ids))]
|
||||||
f"gpu_block_ids_num: {len(gpu_block_ids)}, timeout: {timeout}"
|
k_cache_keys = [f"{key}_key_{self.rank}" for key in task.keys]
|
||||||
)
|
v_cache_keys = [f"{key}_value_{self.rank}" for key in task.keys]
|
||||||
k_cache_keys = [f"{key}_key_{self.rank}" for key in keys]
|
match_block_num = 0
|
||||||
v_cache_keys = [f"{key}_value_{self.rank}" for key in keys]
|
if self.storage_backend_type == "mooncake":
|
||||||
match_block_num = self._storage_exist_block_num(k_cache_keys, v_cache_keys)
|
match_block_num = self.storage_backend.query(k_cache_keys, v_cache_keys)
|
||||||
logger.debug(f"read_storage_task, match {match_block_num} blocks from storage for task id: {task_id}")
|
elif self.storage_backend_type == "attention_store":
|
||||||
|
match_block_num = self.storage_backend.query(
|
||||||
|
task.task_id, task.token_ids, task.start_read_block_idx, task.timeout
|
||||||
|
)
|
||||||
|
logger.info(f"Matched {match_block_num} blocks in cache storage for read task {task.task_id}")
|
||||||
|
|
||||||
k_cache_keys = k_cache_keys[:match_block_num]
|
k_cache_keys = k_cache_keys[:match_block_num]
|
||||||
v_cache_keys = v_cache_keys[:match_block_num]
|
v_cache_keys = v_cache_keys[:match_block_num]
|
||||||
gpu_block_ids = gpu_block_ids[:match_block_num]
|
gpu_block_ids = gpu_block_ids[:match_block_num]
|
||||||
cpu_block_ids = [i for i in range(match_block_num)]
|
cpu_block_ids = cpu_block_ids[:match_block_num]
|
||||||
valid_gpu_block_ids = []
|
valid_gpu_block_ids = []
|
||||||
|
|
||||||
if match_block_num > 0:
|
if match_block_num > 0:
|
||||||
# TODO: support timeout with actual block count
|
# TODO: support timeout with actual block count
|
||||||
try:
|
try:
|
||||||
valid_gpu_block_ids = self._run_read_storage(
|
valid_gpu_block_ids = self._run_read_storage(
|
||||||
k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids
|
task.task_id,
|
||||||
|
task.token_ids[: match_block_num * self.block_size],
|
||||||
|
task.start_read_block_idx,
|
||||||
|
k_cache_keys,
|
||||||
|
v_cache_keys,
|
||||||
|
gpu_block_ids,
|
||||||
|
cpu_block_ids,
|
||||||
|
task.timeout,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"read_storage_task, finish loading {match_block_num} blocks from storage for task {task_id}."
|
f"Successfully read {len(valid_gpu_block_ids)} blocks from cache storage for task {task.task_id}"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[rank {self.rank}/{self.n_ranks}] An error occurred: {task_id} {e}")
|
logger.error(f"Failed to read cache for task {task.task_id}, error: {e}")
|
||||||
valid_gpu_block_ids = []
|
valid_gpu_block_ids = []
|
||||||
|
|
||||||
result = (CacheStatus.STORAGE2GPU, task_id, keys, valid_gpu_block_ids)
|
result = (CacheStatus.STORAGE2GPU, task.task_id, task.keys, valid_gpu_block_ids)
|
||||||
self.cache_task_queue.swap_storage_to_gpu_barrier.wait()
|
self.cache_task_queue.swap_storage_to_gpu_barrier.wait()
|
||||||
self.cache_task_queue.swap_storage_to_gpu_barrier.reset()
|
self.cache_task_queue.swap_storage_to_gpu_barrier.reset()
|
||||||
self.cache_task_queue.put_transfer_done_signal(result)
|
self.cache_task_queue.put_transfer_done_signal(result)
|
||||||
logger.debug(f"read_storage_task: put_transfer_done_signal {result}")
|
logger.debug(f"read_storage_task: put transfer done signal for {task.task_id}")
|
||||||
logger.info(
|
|
||||||
f"read_storage_task: put_transfer_done_signal for transfer_task_id {task_id}, "
|
|
||||||
f"valid block num {len(valid_gpu_block_ids)}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[rank {self.rank}/{self.n_ranks}] An error occurred in read_storage_task: "
|
f"An error occurred in read_storage_task: "
|
||||||
f"task_id: {task_id}, error:{e}, {traceback.format_exc()}"
|
f"task_id: {task.task_id}, error:{e}, {traceback.format_exc()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _run_write_back_storage(self, k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids):
|
def _run_write_back_storage(
|
||||||
|
self,
|
||||||
|
task_id,
|
||||||
|
token_ids,
|
||||||
|
start_write_block_idx,
|
||||||
|
k_cache_keys,
|
||||||
|
v_cache_keys,
|
||||||
|
gpu_block_ids,
|
||||||
|
cpu_block_ids,
|
||||||
|
timeout,
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
logger.debug(
|
if self.storage_backend_type == "mooncake":
|
||||||
f"_run_write_back_storage, k_cache_keys: {k_cache_keys}, v_cache_keys: {v_cache_keys}, "
|
key_cache_size = [
|
||||||
f"gpu_block_ids: {gpu_block_ids}"
|
self.key_cache_shape[0],
|
||||||
)
|
self.key_cache_shape[1],
|
||||||
key_cache_size = [
|
self.key_cache_shape[2],
|
||||||
self.key_cache_shape[0],
|
self.key_cache_shape[3],
|
||||||
self.key_cache_shape[1],
|
]
|
||||||
self.key_cache_shape[2],
|
mode = 0 # gpu ==> cpu
|
||||||
self.key_cache_shape[3],
|
start_time = time.time()
|
||||||
]
|
swap_cache_layout(
|
||||||
|
self.gpu_cache_k_tensors,
|
||||||
|
self.storage_key_write_buffer,
|
||||||
|
key_cache_size,
|
||||||
|
gpu_block_ids,
|
||||||
|
cpu_block_ids,
|
||||||
|
self.device,
|
||||||
|
mode,
|
||||||
|
)
|
||||||
|
swap_cache_layout(
|
||||||
|
self.gpu_cache_v_tensors,
|
||||||
|
self.storage_value_write_buffer,
|
||||||
|
key_cache_size,
|
||||||
|
gpu_block_ids,
|
||||||
|
cpu_block_ids,
|
||||||
|
self.device,
|
||||||
|
mode,
|
||||||
|
)
|
||||||
|
swap_cost_time = time.time() - start_time
|
||||||
|
|
||||||
mode = 0 # gpu ==> cpu
|
block_num = len(gpu_block_ids)
|
||||||
start_time = time.time()
|
keys = k_cache_keys + v_cache_keys
|
||||||
swap_cache_layout(
|
k_cache_ptrs = [
|
||||||
self.gpu_cache_k_tensors,
|
self.storage_key_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
|
||||||
self.storage_key_write_buffer,
|
]
|
||||||
key_cache_size,
|
v_cache_ptrs = [
|
||||||
gpu_block_ids,
|
self.storage_value_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
|
||||||
cpu_block_ids,
|
]
|
||||||
self.device,
|
kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
|
||||||
mode,
|
kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value
|
||||||
)
|
|
||||||
swap_cache_layout(
|
|
||||||
self.gpu_cache_v_tensors,
|
|
||||||
self.storage_value_write_buffer,
|
|
||||||
key_cache_size,
|
|
||||||
gpu_block_ids,
|
|
||||||
cpu_block_ids,
|
|
||||||
self.device,
|
|
||||||
mode,
|
|
||||||
)
|
|
||||||
swap_cost_time = time.time() - start_time
|
|
||||||
|
|
||||||
block_num = len(gpu_block_ids)
|
start_time = time.time()
|
||||||
keys = k_cache_keys + v_cache_keys
|
self.storage_backend.batch_set(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes)
|
||||||
k_cache_ptrs = [
|
write_cost_time = time.time() - start_time
|
||||||
self.storage_key_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
|
|
||||||
]
|
logger.debug(
|
||||||
v_cache_ptrs = [
|
f"_run_write_back_storage, swap_cost_time: {swap_cost_time:.6f}s, write_cost_time: {write_cost_time:.6f}s"
|
||||||
self.storage_value_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
|
)
|
||||||
]
|
return block_num
|
||||||
kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
|
|
||||||
kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value
|
elif self.storage_backend_type == "attention_store":
|
||||||
start_time = time.time()
|
key_cache = []
|
||||||
self.storage_backend.batch_set(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes)
|
val_cache = []
|
||||||
write_cost_time = time.time() - start_time
|
for i in range(self.num_layers + self.num_extra_layers):
|
||||||
|
key_cache.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{self.rank}.device{self.device}"])
|
||||||
|
val_cache.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{self.rank}.device{self.device}"])
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
write_block_num = self.storage_backend.write(
|
||||||
|
task_id, key_cache, val_cache, token_ids, gpu_block_ids, start_write_block_idx, timeout
|
||||||
|
)
|
||||||
|
write_cost_time = time.time() - start_time
|
||||||
|
logger.debug(f"_run_write_back_storage, write_cost_time: {write_cost_time:.6f}s")
|
||||||
|
return write_block_num
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"_run_write_back_storage, swap_cost_time: {swap_cost_time:.6f}s, write_cost_time: {write_cost_time:.6f}s"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[rank {self.rank}/{self.n_ranks}] An error occurred in _run_write_back_storage: "
|
f"An error occurred in _run_write_back_storage, " f"error: {e}, traceback:\n{traceback.format_exc()}"
|
||||||
f"error:{e}, {traceback.format_exc()}"
|
|
||||||
)
|
)
|
||||||
|
return 0
|
||||||
|
|
||||||
def write_back_storage_task(self, task_id, keys, gpu_block_ids, timeout=0.1):
|
def write_back_storage_task(self, task: WriteStorageTask):
|
||||||
"""
|
"""
|
||||||
Write cache to the storage backend from the GPU memory.
|
Write cache to the storage backend from the GPU memory.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logger.debug(
|
gpu_block_ids = task.gpu_block_ids.copy()
|
||||||
f"write cache to storage, keys: {keys}, gpu_block_ids: {gpu_block_ids}, "
|
|
||||||
f"task_id: {task_id}, timeout: {timeout}"
|
|
||||||
)
|
|
||||||
|
|
||||||
k_cache_keys = [f"{key}_key_{self.rank}" for key in keys]
|
|
||||||
v_cache_keys = [f"{key}_value_{self.rank}" for key in keys]
|
|
||||||
match_block_num = self._storage_exist_block_num(k_cache_keys, v_cache_keys)
|
|
||||||
|
|
||||||
k_cache_keys = k_cache_keys[match_block_num:]
|
|
||||||
v_cache_keys = v_cache_keys[match_block_num:]
|
|
||||||
gpu_block_ids = gpu_block_ids[match_block_num:]
|
|
||||||
cpu_block_ids = [i for i in range(len(gpu_block_ids))]
|
cpu_block_ids = [i for i in range(len(gpu_block_ids))]
|
||||||
|
k_cache_keys = [f"{key}_key_{self.rank}" for key in task.keys]
|
||||||
|
v_cache_keys = [f"{key}_value_{self.rank}" for key in task.keys]
|
||||||
|
|
||||||
if len(k_cache_keys) == 0:
|
match_block_num = 0
|
||||||
logger.info(f"No uncached keys found for task {task_id}")
|
if self.storage_backend_type == "mooncake":
|
||||||
|
match_block_num = self.storage_backend.query(k_cache_keys, v_cache_keys, task.timeout)
|
||||||
|
elif self.storage_backend_type == "attention_store":
|
||||||
|
match_block_num = self.storage_backend.query(task.task_id, task.token_ids, 0, task.timeout)
|
||||||
|
logger.info(f"Matched {match_block_num} blocks in cache storage for write task {task.task_id}")
|
||||||
|
|
||||||
|
if match_block_num >= len(k_cache_keys):
|
||||||
|
logger.info(f"No uncached keys found for task {task.task_id}")
|
||||||
gpu_block_ids = []
|
gpu_block_ids = []
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
|
k_cache_keys = k_cache_keys[match_block_num:]
|
||||||
|
v_cache_keys = v_cache_keys[match_block_num:]
|
||||||
|
gpu_block_ids = gpu_block_ids[match_block_num:]
|
||||||
|
cpu_block_ids = cpu_block_ids[match_block_num:]
|
||||||
# TODO: support timeout with actual block count
|
# TODO: support timeout with actual block count
|
||||||
self._run_write_back_storage(k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids)
|
write_block_num = self._run_write_back_storage(
|
||||||
|
task.task_id,
|
||||||
|
task.token_ids,
|
||||||
|
match_block_num,
|
||||||
|
k_cache_keys,
|
||||||
|
v_cache_keys,
|
||||||
|
gpu_block_ids,
|
||||||
|
cpu_block_ids,
|
||||||
|
task.timeout,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Successfully wrote {write_block_num} blocks to cache storage for task {task.task_id}"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in write back storage task: {e}")
|
logger.error(f"Error in write back storage task: {e}")
|
||||||
gpu_block_ids = []
|
gpu_block_ids = []
|
||||||
|
|
||||||
result = (CacheStatus.GPU2STORAGE, task_id, keys, gpu_block_ids)
|
result = (CacheStatus.GPU2STORAGE, task.task_id, task.keys, gpu_block_ids)
|
||||||
self.cache_task_queue.swap_to_storage_barrier.wait()
|
self.cache_task_queue.swap_to_storage_barrier.wait()
|
||||||
if self.rank == 0: # 只有当rank为0时执行同步操作
|
if self.rank == 0: # 只有当rank为0时执行同步操作
|
||||||
self.cache_task_queue.swap_to_storage_barrier.reset()
|
self.cache_task_queue.swap_to_storage_barrier.reset()
|
||||||
self.cache_task_queue.put_transfer_done_signal(result) # 发送传输完成信号
|
self.cache_task_queue.put_transfer_done_signal(result) # 发送传输完成信号
|
||||||
logger.debug(f"write_back_storage_task: put_transfer_done_signal {result}")
|
logger.debug(f"write_back_storage_task: put_transfer_done_signal {result}")
|
||||||
logger.info(f"write_back_storage_task: put_transfer_done_signal for transfer_task_id {task_id}")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[rank {self.rank}/{self.n_ranks}] An error occurred in write_back_storage_task: "
|
f"An error occurred in write_back_storage_task, " f"error: {e}, traceback:\n{traceback.format_exc()}"
|
||||||
f"error:{e}, {traceback.format_exc()}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _do_swap_to_cpu_task(
|
def _do_swap_to_cpu_task(
|
||||||
@@ -759,12 +841,12 @@ class CacheTransferManager:
|
|||||||
self.cache_task_queue.barrier1.reset()
|
self.cache_task_queue.barrier1.reset()
|
||||||
if self.cache_task_broadcast_signal.value[0] == 1:
|
if self.cache_task_broadcast_signal.value[0] == 1:
|
||||||
data, read_finish = self.cache_task_queue.get_transfer_task()
|
data, read_finish = self.cache_task_queue.get_transfer_task()
|
||||||
logger.debug(f"transfer data: get_transfer_task {data}")
|
logger.debug(f"do_data_transfer: {data}")
|
||||||
if read_finish:
|
if read_finish:
|
||||||
self.cache_task_broadcast_signal.value[0] = 0
|
self.cache_task_broadcast_signal.value[0] = 0
|
||||||
event_type, transfer_task_id = data[0], data[1]
|
event_type, event_args = data[0], data[1:]
|
||||||
if event_type.value == CacheStatus.SWAP2CPU.value:
|
if event_type.value == CacheStatus.SWAP2CPU.value:
|
||||||
swap_node_ids, gpu_block_id, cpu_block_id = data[2:]
|
transfer_task_id, swap_node_ids, gpu_block_id, cpu_block_id = event_args
|
||||||
self.swap_to_cpu_thread_pool.submit(
|
self.swap_to_cpu_thread_pool.submit(
|
||||||
self._do_swap_to_cpu_task,
|
self._do_swap_to_cpu_task,
|
||||||
swap_node_ids,
|
swap_node_ids,
|
||||||
@@ -774,7 +856,7 @@ class CacheTransferManager:
|
|||||||
transfer_task_id,
|
transfer_task_id,
|
||||||
)
|
)
|
||||||
elif event_type.value == CacheStatus.SWAP2GPU.value:
|
elif event_type.value == CacheStatus.SWAP2GPU.value:
|
||||||
swap_node_ids, gpu_block_id, cpu_block_id = data[2:]
|
transfer_task_id, swap_node_ids, gpu_block_id, cpu_block_id = event_args
|
||||||
self.swap_to_gpu_thread_pool.submit(
|
self.swap_to_gpu_thread_pool.submit(
|
||||||
self._do_swap_to_gpu_task,
|
self._do_swap_to_gpu_task,
|
||||||
swap_node_ids,
|
swap_node_ids,
|
||||||
@@ -784,22 +866,16 @@ class CacheTransferManager:
|
|||||||
transfer_task_id,
|
transfer_task_id,
|
||||||
)
|
)
|
||||||
elif event_type.value == CacheStatus.STORAGE2GPU.value:
|
elif event_type.value == CacheStatus.STORAGE2GPU.value:
|
||||||
hash_keys, gpu_block_ids, timeout = data[2:]
|
read_storage_task = event_args[0]
|
||||||
self.read_storage_thread_pool.submit(
|
self.read_storage_thread_pool.submit(
|
||||||
self.read_storage_task,
|
self.read_storage_task,
|
||||||
transfer_task_id,
|
read_storage_task,
|
||||||
hash_keys,
|
|
||||||
gpu_block_ids,
|
|
||||||
timeout,
|
|
||||||
)
|
)
|
||||||
elif event_type.value == CacheStatus.GPU2STORAGE.value:
|
elif event_type.value == CacheStatus.GPU2STORAGE.value:
|
||||||
hash_keys, gpu_block_ids, timeout = data[2:]
|
write_storage_task = event_args[0]
|
||||||
self.write_back_storage_thread_pool.submit(
|
self.write_back_storage_thread_pool.submit(
|
||||||
self.write_back_storage_task,
|
self.write_back_storage_task,
|
||||||
transfer_task_id,
|
write_storage_task,
|
||||||
hash_keys,
|
|
||||||
gpu_block_ids,
|
|
||||||
timeout,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if self.n_ranks > 1:
|
if self.n_ranks > 1:
|
||||||
@@ -1047,7 +1123,11 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
|
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
|
||||||
logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_tprank{args.rank}.log")
|
if args.mp_num > 1:
|
||||||
|
logger = get_logger("cache_transfer", f"cache_transfer_{rank_id}.log")
|
||||||
|
else:
|
||||||
|
logger = get_logger("cache_transfer", "cache_transfer.log")
|
||||||
|
|
||||||
logger.info(f"args: {vars(args)}")
|
logger.info(f"args: {vars(args)}")
|
||||||
set_device(args.device_id)
|
set_device(args.device_id)
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -31,7 +31,9 @@ import numpy as np
|
|||||||
from fastdeploy import envs
|
from fastdeploy import envs
|
||||||
from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
|
from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
|
||||||
from fastdeploy.cache_manager.cache_metrics import CacheMetrics
|
from fastdeploy.cache_manager.cache_metrics import CacheMetrics
|
||||||
|
from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask
|
||||||
from fastdeploy.cache_manager.ops import get_all_visible_devices
|
from fastdeploy.cache_manager.ops import get_all_visible_devices
|
||||||
|
from fastdeploy.config import FDConfig
|
||||||
from fastdeploy.engine.request import Request
|
from fastdeploy.engine.request import Request
|
||||||
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus
|
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus
|
||||||
from fastdeploy.metrics.metrics import main_process_metrics
|
from fastdeploy.metrics.metrics import main_process_metrics
|
||||||
@@ -47,7 +49,7 @@ class PrefixCacheManager:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config,
|
config: FDConfig,
|
||||||
tensor_parallel_size,
|
tensor_parallel_size,
|
||||||
splitwise_role="mixed",
|
splitwise_role="mixed",
|
||||||
local_data_parallel_id=0,
|
local_data_parallel_id=0,
|
||||||
@@ -207,7 +209,6 @@ class PrefixCacheManager:
|
|||||||
key_cache_shape, val_cache_shape = self._get_kv_cache_shape(cache_config.total_block_num)
|
key_cache_shape, val_cache_shape = self._get_kv_cache_shape(cache_config.total_block_num)
|
||||||
key_cache_shape = ",".join([str(i) for i in key_cache_shape])
|
key_cache_shape = ",".join([str(i) for i in key_cache_shape])
|
||||||
val_cache_shape = ",".join([str(i) for i in val_cache_shape])
|
val_cache_shape = ",".join([str(i) for i in val_cache_shape])
|
||||||
logger.info(f"key_cache_shape {key_cache_shape} value_cache_shape {val_cache_shape}")
|
|
||||||
if self.enable_splitwise:
|
if self.enable_splitwise:
|
||||||
cache_messager_processes = self.launch_cache_messager(
|
cache_messager_processes = self.launch_cache_messager(
|
||||||
cache_config,
|
cache_config,
|
||||||
@@ -273,6 +274,7 @@ class PrefixCacheManager:
|
|||||||
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
|
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
|
||||||
+ f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}"
|
+ f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}"
|
||||||
+ f" {sys.executable} {py_path}"
|
+ f" {sys.executable} {py_path}"
|
||||||
|
+ f" --model_id {os.path.basename(self.config.model_config.model)}"
|
||||||
+ f" --device_id {int(device_ids[i])}"
|
+ f" --device_id {int(device_ids[i])}"
|
||||||
+ f" --rank {i}"
|
+ f" --rank {i}"
|
||||||
+ f" --splitwise_role {self.splitwise_role}"
|
+ f" --splitwise_role {self.splitwise_role}"
|
||||||
@@ -390,7 +392,7 @@ class PrefixCacheManager:
|
|||||||
+ f" --ipc_suffix {ipc_suffix}"
|
+ f" --ipc_suffix {ipc_suffix}"
|
||||||
+ f" --rdma_port {cache_config.local_rdma_comm_ports[i] if cache_config.local_rdma_comm_ports is not None else '0'}"
|
+ f" --rdma_port {cache_config.local_rdma_comm_ports[i] if cache_config.local_rdma_comm_ports is not None else '0'}"
|
||||||
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
|
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
|
||||||
+ f" >{log_dir}/launch_cache_messager_tprank{i}.log 2>&1"
|
+ f" >{log_dir}/launch_cache_messager_{i}.log 2>&1"
|
||||||
)
|
)
|
||||||
logger.info(f"Launch cache messager, command:{launch_cmd}")
|
logger.info(f"Launch cache messager, command:{launch_cmd}")
|
||||||
cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
|
cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
|
||||||
@@ -789,9 +791,15 @@ class PrefixCacheManager:
|
|||||||
f"start prefetch cache from storage, req_id: {req_id}, block num: {len(no_match_block_keys)}"
|
f"start prefetch cache from storage, req_id: {req_id}, block num: {len(no_match_block_keys)}"
|
||||||
)
|
)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
storage_matched_block_ids = self.issue_prefetch_storage_task(
|
read_storage_task = ReadStorageTask(
|
||||||
req_id, no_match_block_keys, gpu_recv_storage_block_ids
|
task_id=req_id,
|
||||||
|
keys=no_match_block_keys,
|
||||||
|
token_ids=input_token_ids,
|
||||||
|
gpu_block_ids=gpu_recv_storage_block_ids,
|
||||||
|
start_read_block_idx=match_token_num // block_size,
|
||||||
)
|
)
|
||||||
|
logger.debug(f"issue read storage task: {read_storage_task}")
|
||||||
|
storage_matched_block_ids = self.issue_prefetch_storage_task(read_storage_task)
|
||||||
storage_matched_block_num = len(storage_matched_block_ids)
|
storage_matched_block_num = len(storage_matched_block_ids)
|
||||||
storage_match_token_num = storage_matched_block_num * block_size
|
storage_match_token_num = storage_matched_block_num * block_size
|
||||||
cost_time = time.time() - start_time
|
cost_time = time.time() - start_time
|
||||||
@@ -1006,6 +1014,12 @@ class PrefixCacheManager:
|
|||||||
if self.kvcache_storage_backend is None:
|
if self.kvcache_storage_backend is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
token_ids = request.prompt_token_ids
|
||||||
|
if isinstance(token_ids, np.ndarray):
|
||||||
|
token_ids = token_ids.tolist()
|
||||||
|
if self.config.cache_config.enable_output_caching:
|
||||||
|
token_ids += request.output_token_ids
|
||||||
|
|
||||||
req_id = request.request_id
|
req_id = request.request_id
|
||||||
keys = []
|
keys = []
|
||||||
node = self.req_leaf_map[req_id]
|
node = self.req_leaf_map[req_id]
|
||||||
@@ -1018,24 +1032,33 @@ class PrefixCacheManager:
|
|||||||
|
|
||||||
gpu_block_ids = request.block_tables[: len(keys)]
|
gpu_block_ids = request.block_tables[: len(keys)]
|
||||||
logger.info(f"start write cache back to storage, req_id: {req_id}, block num: {len(keys)}")
|
logger.info(f"start write cache back to storage, req_id: {req_id}, block num: {len(keys)}")
|
||||||
|
write_storage_task = WriteStorageTask(
|
||||||
|
task_id=req_id,
|
||||||
|
keys=keys,
|
||||||
|
token_ids=token_ids,
|
||||||
|
gpu_block_ids=gpu_block_ids,
|
||||||
|
)
|
||||||
|
logger.debug(f"issue write storage task: {write_storage_task}")
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
self.issue_write_back_storage_task(req_id=req_id, hash_keys=keys, gpu_block_ids=gpu_block_ids, is_sync=True)
|
self.issue_write_back_storage_task(write_storage_task, is_sync=True)
|
||||||
cost_time = time.time() - tic
|
cost_time = time.time() - tic
|
||||||
logger.info(f"finish write cache back to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s")
|
logger.info(f"finish write cache back to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s")
|
||||||
|
|
||||||
def issue_write_back_storage_task(self, req_id, hash_keys, gpu_block_ids, is_sync=True, timeout=0.5):
|
def issue_write_back_storage_task(self, task: WriteStorageTask, is_sync=True):
|
||||||
if self.kvcache_storage_backend is None:
|
if self.kvcache_storage_backend is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if len(hash_keys) != len(gpu_block_ids):
|
if len(task.keys) != len(task.gpu_block_ids):
|
||||||
err_msg = f"write_back_storage error: hash_keys({len(hash_keys)}) != gpu_block_ids({len(gpu_block_ids)})"
|
err_msg = (
|
||||||
|
f"write_back_storage error: hash_keys({len(task.keys)}) != gpu_block_ids({len(task.gpu_block_ids)})"
|
||||||
|
)
|
||||||
logger.error(err_msg)
|
logger.error(err_msg)
|
||||||
raise ValueError(err_msg)
|
raise ValueError(err_msg)
|
||||||
|
|
||||||
self.task_write_back_event[req_id] = Event()
|
self.task_write_back_event[task.task_id] = Event()
|
||||||
self.cache_task_queue.put_transfer_task((CacheStatus.GPU2STORAGE, req_id, hash_keys, gpu_block_ids, timeout))
|
self.cache_task_queue.put_transfer_task((CacheStatus.GPU2STORAGE, task))
|
||||||
if is_sync:
|
if is_sync:
|
||||||
self.wait_write_storage_task(req_id)
|
self.wait_write_storage_task(task.task_id)
|
||||||
|
|
||||||
def wait_write_storage_task(self, req_id):
|
def wait_write_storage_task(self, req_id):
|
||||||
"""
|
"""
|
||||||
@@ -1045,16 +1068,19 @@ class PrefixCacheManager:
|
|||||||
self.task_write_back_event[req_id].wait()
|
self.task_write_back_event[req_id].wait()
|
||||||
del self.task_write_back_event[req_id]
|
del self.task_write_back_event[req_id]
|
||||||
|
|
||||||
def issue_prefetch_storage_task(self, req_id, hash_keys, gpu_block_ids, is_sync=True, timeout=0.5):
|
def issue_prefetch_storage_task(self, task: ReadStorageTask, is_sync=True):
|
||||||
"""
|
"""
|
||||||
Prefetch cache from storage task
|
Prefetch cache from storage task
|
||||||
"""
|
"""
|
||||||
|
if self.kvcache_storage_backend is None:
|
||||||
|
return []
|
||||||
|
|
||||||
storage_block_ids = []
|
storage_block_ids = []
|
||||||
self.task_prefetch_event[req_id] = Event()
|
self.task_prefetch_event[task.task_id] = Event()
|
||||||
# issue task to cache_transfer_manager
|
# issue task to cache_transfer_manager
|
||||||
self.cache_task_queue.put_transfer_task((CacheStatus.STORAGE2GPU, req_id, hash_keys, gpu_block_ids, timeout))
|
self.cache_task_queue.put_transfer_task((CacheStatus.STORAGE2GPU, task))
|
||||||
if is_sync:
|
if is_sync:
|
||||||
storage_block_ids = self.wait_prefetch_storage_task(req_id)
|
storage_block_ids = self.wait_prefetch_storage_task(task.task_id)
|
||||||
return storage_block_ids
|
return storage_block_ids
|
||||||
|
|
||||||
def wait_prefetch_storage_task(self, req_id):
|
def wait_prefetch_storage_task(self, req_id):
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
from fastdeploy.platforms import current_platform
|
from fastdeploy.platforms import current_platform
|
||||||
|
|
||||||
from .kvcache_storage import KVCacheStorage
|
from .kvcache_storage import KVCacheStorage
|
||||||
from .mooncake_store import MooncakeStore
|
from .mooncake_store import AttentionStore, MooncakeStore
|
||||||
from .rdma_cache_transfer import RDMACommManager
|
from .rdma_cache_transfer import RDMACommManager
|
||||||
|
|
||||||
if current_platform.is_cuda():
|
if current_platform.is_cuda():
|
||||||
@@ -31,4 +31,5 @@ __all__ = [
|
|||||||
"RDMACommManager",
|
"RDMACommManager",
|
||||||
"KVCacheStorage",
|
"KVCacheStorage",
|
||||||
"MooncakeStore",
|
"MooncakeStore",
|
||||||
|
"AttentionStore",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -95,3 +95,10 @@ class KVCacheStorage(ABC):
|
|||||||
Clear all keys in storage
|
Clear all keys in storage
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def query(self) -> int:
|
||||||
|
"""
|
||||||
|
Query the number of blocks stored in the storage.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from .attention_store import AttentionStore
|
||||||
from .mooncake_store import MooncakeStore
|
from .mooncake_store import MooncakeStore
|
||||||
|
|
||||||
__all__ = ["MooncakeStore"]
|
__all__ = ["MooncakeStore", "AttentionStore"]
|
||||||
|
|||||||
@@ -0,0 +1,209 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
from fastdeploy.cache_manager.transfer_factory.kvcache_storage import (
|
||||||
|
KVCacheStorage,
|
||||||
|
logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from attentionstore_sdk.sdk import AttentionStoreSDK, Tokens
|
||||||
|
from attentionstore_sdk.utils.err import AttentionStoreSDKError
|
||||||
|
|
||||||
|
_ATTENTIONSTORE_AVAILABLE = True
|
||||||
|
except Exception:
|
||||||
|
AttentionStoreSDK = None
|
||||||
|
Tokens = None
|
||||||
|
AttentionStoreSDKError = None
|
||||||
|
_ATTENTIONSTORE_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AttentionStoreConfig:
|
||||||
|
namespace: str = "default_ns"
|
||||||
|
pod_name: str = "default_pod"
|
||||||
|
model_version: str = "v0"
|
||||||
|
shard_id: int = 0
|
||||||
|
shard_num: int = 1
|
||||||
|
layer_num: int = 1
|
||||||
|
block_token_size: int = 64
|
||||||
|
bytes_per_shard_layer_per_block: int = 1024
|
||||||
|
device_id: int = 0
|
||||||
|
dp_id: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionStore(KVCacheStorage):
|
||||||
|
def __init__(self, **args):
|
||||||
|
|
||||||
|
if not _ATTENTIONSTORE_AVAILABLE:
|
||||||
|
raise ImportError("Please install attentionstore_sdk to run Fastdeploy with attentionstore_sdk.")
|
||||||
|
|
||||||
|
self.config = AttentionStoreConfig(**args)
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(f"[INIT] Start initializing AttentionStoreSDK with config: {self.config}")
|
||||||
|
self.sdk = AttentionStoreSDK(
|
||||||
|
self.config.namespace,
|
||||||
|
self.config.pod_name,
|
||||||
|
self.config.model_version,
|
||||||
|
self.config.shard_id,
|
||||||
|
self.config.shard_num,
|
||||||
|
self.config.layer_num,
|
||||||
|
self.config.block_token_size,
|
||||||
|
self.config.bytes_per_shard_layer_per_block,
|
||||||
|
self.config.device_id,
|
||||||
|
self.config.dp_id,
|
||||||
|
)
|
||||||
|
self.wait_for_sdk_ready(timeout=300, delta_t=5)
|
||||||
|
logger.info("[INIT] ✅ AttentionStore is initialized successfully!")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"[INIT] ❌ AttentionStore initialization failed, error: {e}, traceback:\n{traceback.format_exc()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def wait_for_sdk_ready(self, timeout: float, delta_t: float):
|
||||||
|
t = 0
|
||||||
|
while t < timeout:
|
||||||
|
try:
|
||||||
|
tokens = Tokens(list(range(self.config.block_token_size + 1)), self.config.block_token_size)
|
||||||
|
self.sdk.match(tokens, 0, delta_t)
|
||||||
|
return
|
||||||
|
except AttentionStoreSDKError as e:
|
||||||
|
if "cuda memory not ready" in str(e):
|
||||||
|
logger.debug("[INIT] cuda memory not ready, try again..")
|
||||||
|
time.sleep(delta_t)
|
||||||
|
t += delta_t
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Unexpected exception during AttentionStoreSDK initialization: {e}\n{traceback.format_exc()}"
|
||||||
|
)
|
||||||
|
raise TimeoutError(f"AttentionStoreSDK initialization timed out after {timeout} seconds")
|
||||||
|
|
||||||
|
def read(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
key_cache: List[paddle.Tensor],
|
||||||
|
val_cache: List[paddle.Tensor],
|
||||||
|
token_ids: List[int],
|
||||||
|
gpu_block_ids: List[int],
|
||||||
|
start_read_block_idx: int,
|
||||||
|
timeout: float = 30.0,
|
||||||
|
):
|
||||||
|
logger.debug(
|
||||||
|
f"[READ BEGIN] task_id: {task_id} token_ids: {token_ids} gpu_block_ids: {gpu_block_ids} start_read_block_idx: {start_read_block_idx} timeout: {timeout}"
|
||||||
|
)
|
||||||
|
tokens = Tokens(token_ids, self.config.block_token_size)
|
||||||
|
k_data_ptrs = [k.data_ptr() for k in key_cache]
|
||||||
|
v_data_ptrs = [v.data_ptr() for v in val_cache]
|
||||||
|
num = 0
|
||||||
|
try:
|
||||||
|
num = self.sdk.read(
|
||||||
|
list(range(self.config.layer_num)),
|
||||||
|
tokens,
|
||||||
|
start_read_block_idx,
|
||||||
|
k_data_ptrs,
|
||||||
|
v_data_ptrs,
|
||||||
|
gpu_block_ids,
|
||||||
|
timeout,
|
||||||
|
)
|
||||||
|
logger.debug(f"[READ END] task_id: {task_id} read_blocks: {num}")
|
||||||
|
except AttentionStoreSDKError:
|
||||||
|
logger.error(
|
||||||
|
f"[READ ERROR] failed to execute sdk read, task_id: {task_id}, traceback:\n{traceback.format_exc()}"
|
||||||
|
)
|
||||||
|
return num
|
||||||
|
|
||||||
|
def write(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
key_cache: List[paddle.Tensor],
|
||||||
|
val_cache: List[paddle.Tensor],
|
||||||
|
token_ids: List[int],
|
||||||
|
gpu_block_ids: List[int],
|
||||||
|
start_write_block_idx: int,
|
||||||
|
timeout: float = 30.0,
|
||||||
|
) -> int:
|
||||||
|
logger.debug(
|
||||||
|
f"[WRITE BEGIN] task_id: {task_id} token_ids: {token_ids} gpu_block_ids: {gpu_block_ids} start_write_block_idx: {start_write_block_idx} timeout: {timeout}"
|
||||||
|
)
|
||||||
|
tokens = Tokens(token_ids, self.config.block_token_size)
|
||||||
|
k_data_ptrs = [k.data_ptr() for k in key_cache]
|
||||||
|
v_data_ptrs = [v.data_ptr() for v in val_cache]
|
||||||
|
num = 0
|
||||||
|
try:
|
||||||
|
num = self.sdk.write(
|
||||||
|
list(range(self.config.layer_num)),
|
||||||
|
tokens,
|
||||||
|
start_write_block_idx,
|
||||||
|
k_data_ptrs,
|
||||||
|
v_data_ptrs,
|
||||||
|
gpu_block_ids,
|
||||||
|
timeout,
|
||||||
|
)
|
||||||
|
logger.debug(f"[WRITE END] task_id: {task_id} written_blocks: {num}")
|
||||||
|
except AttentionStoreSDKError:
|
||||||
|
logger.error(
|
||||||
|
f"[WRITE ERROR] failed to execute sdk write, task_id: {task_id}, traceback:\n{traceback.format_exc()}"
|
||||||
|
)
|
||||||
|
return num
|
||||||
|
|
||||||
|
def query(self, task_id: str, token_ids: List[int], start_match_block_idx: int, timeout: float = 10.0):
|
||||||
|
"""
|
||||||
|
Given the input ids and starting index to match, get the valid blocks number that
|
||||||
|
can be prefetched from storage backend.
|
||||||
|
"""
|
||||||
|
logger.debug(
|
||||||
|
f"[QUERY BEGIN] task_id: {task_id} token_ids: {token_ids} start_match_block_idx: {start_match_block_idx} timeout: {timeout}"
|
||||||
|
)
|
||||||
|
tokens = Tokens(token_ids, self.config.block_token_size)
|
||||||
|
num = 0
|
||||||
|
try:
|
||||||
|
num = self.sdk.match(tokens, start_match_block_idx, timeout)
|
||||||
|
logger.debug(f"[QUERY END] task_id: {task_id} matched_blocks: {num}")
|
||||||
|
except AttentionStoreSDKError:
|
||||||
|
logger.error(
|
||||||
|
f"[QUERY ERROR] Failed to execute sdk match, task_id: {task_id}, traceback:\n{traceback.format_exc()}"
|
||||||
|
)
|
||||||
|
return num
|
||||||
|
|
||||||
|
def get(self, **kwargs):
|
||||||
|
raise NotImplementedError("AttentionStore does not support this method")
|
||||||
|
|
||||||
|
def batch_get(self, **kwargs):
|
||||||
|
raise NotImplementedError("AttentionStore does not support this method")
|
||||||
|
|
||||||
|
def set(self, **kwargs) -> bool:
|
||||||
|
raise NotImplementedError("AttentionStore does not support this method")
|
||||||
|
|
||||||
|
def batch_set(self, **kwargs) -> bool:
|
||||||
|
raise NotImplementedError("AttentionStore does not support this method")
|
||||||
|
|
||||||
|
def exists(self, keys: List[str]) -> bool:
|
||||||
|
raise NotImplementedError("AttentionStore does not support this method")
|
||||||
|
|
||||||
|
def clear(self) -> bool:
|
||||||
|
raise NotImplementedError("AttentionStore does not support this method")
|
||||||
|
|
||||||
|
def register_buffer(self, buffer_ptr, buffer_size, buffer_type="none_type") -> None:
|
||||||
|
raise NotImplementedError("AttentionStore does not support this method")
|
||||||
@@ -237,6 +237,21 @@ class MooncakeStore(KVCacheStorage):
|
|||||||
logger.debug(f"The exists fun processes {len(keys)} objects, cost_time: {cost_time:.3f}ms")
|
logger.debug(f"The exists fun processes {len(keys)} objects, cost_time: {cost_time:.3f}ms")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def query(self, k_keys: List[str], v_keys: List[str], timeout: float = 1.0):
|
||||||
|
"""
|
||||||
|
Given the k_keys and v_keys, get the valid blocks number that
|
||||||
|
can be prefetched from storage backend.
|
||||||
|
"""
|
||||||
|
assert len(k_keys) == len(v_keys), "k_keys and v_keys must have the same length."
|
||||||
|
result = self.exists(k_keys + v_keys)
|
||||||
|
|
||||||
|
# only consider the case when both key and value exist
|
||||||
|
num = 0
|
||||||
|
for k, v in zip(k_keys, v_keys):
|
||||||
|
if result[k] and result[v]:
|
||||||
|
num += 1
|
||||||
|
return num
|
||||||
|
|
||||||
def delete(self, key, timeout=5) -> bool:
|
def delete(self, key, timeout=5) -> bool:
|
||||||
while timeout:
|
while timeout:
|
||||||
result = self.store.remove(key)
|
result = self.store.remove(key)
|
||||||
|
|||||||
@@ -629,7 +629,6 @@ class EngineArgs:
|
|||||||
for port in cur_dp_ports:
|
for port in cur_dp_ports:
|
||||||
assert is_port_available("0.0.0.0", port), f"Parameter `{name}`:{port} is already in use."
|
assert is_port_available("0.0.0.0", port), f"Parameter `{name}`:{port} is already in use."
|
||||||
|
|
||||||
console_logger.debug(f"post init {name}: {ports}")
|
|
||||||
return ports
|
return ports
|
||||||
|
|
||||||
num_nodes = len(self.ips) if self.ips else 1
|
num_nodes = len(self.ips) if self.ips else 1
|
||||||
@@ -1077,7 +1076,7 @@ class EngineArgs:
|
|||||||
cache_group.add_argument(
|
cache_group.add_argument(
|
||||||
"--kvcache-storage-backend",
|
"--kvcache-storage-backend",
|
||||||
type=nullable_str,
|
type=nullable_str,
|
||||||
choices=["mooncake"],
|
choices=["mooncake", "attention_store"],
|
||||||
default=EngineArgs.kvcache_storage_backend,
|
default=EngineArgs.kvcache_storage_backend,
|
||||||
help="The storage backend for kvcache storage. Leave empty to disable.",
|
help="The storage backend for kvcache storage. Leave empty to disable.",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -225,10 +225,6 @@ class EngineService:
|
|||||||
device_ids = self.cfg.parallel_config.device_ids.split(",")
|
device_ids = self.cfg.parallel_config.device_ids.split(",")
|
||||||
self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix)
|
self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix)
|
||||||
|
|
||||||
# Set cache manager signal
|
|
||||||
if self.cfg.scheduler_config.splitwise_role != "mixed":
|
|
||||||
self.launched_cache_manager_signal.value[0] = 1
|
|
||||||
|
|
||||||
# Worker launched
|
# Worker launched
|
||||||
self.check_worker_initialize_status_func_thread.join()
|
self.check_worker_initialize_status_func_thread.join()
|
||||||
if not result_container["worker_is_alive"]:
|
if not result_container["worker_is_alive"]:
|
||||||
|
|||||||
@@ -181,10 +181,6 @@ class LLMEngine:
|
|||||||
device_ids = self.cfg.parallel_config.device_ids.split(",")
|
device_ids = self.cfg.parallel_config.device_ids.split(",")
|
||||||
self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
|
self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
|
||||||
|
|
||||||
# Launch components: scheduler, cache_manager, expert_service et.al.
|
|
||||||
if self.cfg.scheduler_config.splitwise_role != "mixed":
|
|
||||||
self.launched_cache_manager_signal.value[0] = 1
|
|
||||||
|
|
||||||
if self.cfg.scheduler_config.splitwise_role != "mixed" and envs.FD_ENABLE_INTERNAL_ADAPTER:
|
if self.cfg.scheduler_config.splitwise_role != "mixed" and envs.FD_ENABLE_INTERNAL_ADAPTER:
|
||||||
envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT = envs.FD_ZMQ_RECV_REQUEST_SERVER_PORTS.split(",")[0]
|
envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT = envs.FD_ZMQ_RECV_REQUEST_SERVER_PORTS.split(",")[0]
|
||||||
envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT = envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORTS.split(",")[0]
|
envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT = envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORTS.split(",")[0]
|
||||||
|
|||||||
@@ -1338,8 +1338,11 @@ class ResourceManagerV1(ResourceManager):
|
|||||||
def update_metrics(self):
|
def update_metrics(self):
|
||||||
# Update metrics
|
# Update metrics
|
||||||
num_tasks = sum([1 if task else 0 for task in self.tasks_list])
|
num_tasks = sum([1 if task else 0 for task in self.tasks_list])
|
||||||
num_blocks_used_by_tasks = sum([len(task.block_tables) if task else 0 for task in self.tasks_list])
|
blocks_used_by_tasks = set()
|
||||||
main_process_metrics.available_gpu_block_num.set(self.total_block_number() - num_blocks_used_by_tasks)
|
for task in self.tasks_list:
|
||||||
|
if task is not None:
|
||||||
|
blocks_used_by_tasks.update(task.block_tables)
|
||||||
|
main_process_metrics.available_gpu_block_num.set(self.total_block_number() - len(blocks_used_by_tasks))
|
||||||
main_process_metrics.batch_size.set(self.max_num_seqs - self.available_batch())
|
main_process_metrics.batch_size.set(self.max_num_seqs - self.available_batch())
|
||||||
main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
|
main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
|
||||||
main_process_metrics.num_requests_running.set(len(self.running))
|
main_process_metrics.num_requests_running.set(len(self.running))
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ class Args:
|
|||||||
mp_num = 1
|
mp_num = 1
|
||||||
device_id = 0
|
device_id = 0
|
||||||
speculative_config = {}
|
speculative_config = {}
|
||||||
|
model_id = "test_model"
|
||||||
ipc_suffix = "test_ipc_suffix"
|
ipc_suffix = "test_ipc_suffix"
|
||||||
cache_queue_port = 9999
|
cache_queue_port = 9999
|
||||||
pod_ip = "127.0.0.1"
|
pod_ip = "127.0.0.1"
|
||||||
|
|||||||
@@ -185,6 +185,7 @@ def _create_manager(
|
|||||||
swap_space=4,
|
swap_space=4,
|
||||||
)
|
)
|
||||||
model_config = SimpleNamespace(
|
model_config = SimpleNamespace(
|
||||||
|
model="test_model",
|
||||||
num_attention_heads=1,
|
num_attention_heads=1,
|
||||||
num_key_value_heads=1,
|
num_key_value_heads=1,
|
||||||
head_dim=1,
|
head_dim=1,
|
||||||
|
|||||||
@@ -332,8 +332,6 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
|
|||||||
self.assertFalse(ok)
|
self.assertFalse(ok)
|
||||||
# cache manager started before workers (lines 184-185)
|
# cache manager started before workers (lines 184-185)
|
||||||
self.assertTrue(started_cache.get("called", False))
|
self.assertTrue(started_cache.get("called", False))
|
||||||
# launched_cache_manager_signal set (line 221)
|
|
||||||
self.assertEqual(int(eng.launched_cache_manager_signal.value[0]), 1)
|
|
||||||
# avoid atexit finalizer
|
# avoid atexit finalizer
|
||||||
if hasattr(eng, "_finalizer"):
|
if hasattr(eng, "_finalizer"):
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user