mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] [KVCache] support attention_store kv cache backend (#5823)
* [feat] support attention_store kv cache backend * [fix] fix codestyle * [chore] optimize log * [fix] fix write storage task * [fix] fix read storage * [fix] fix code conflict after merge develop * [fix] fix cache bytes and read task token ids * [chore] add model for cache transfer manager * [chore] add some log * [chore] remove launched_cache_manager_signal * [fix] fix write_back_storage_task match_block_num condition * [fix] fix swap_cost_time * [ci] fix ci * Update fastdeploy/engine/sched/resource_manager_v1.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update fastdeploy/cache_manager/cache_transfer_manager.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -31,7 +31,9 @@ import numpy as np
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
|
||||
from fastdeploy.cache_manager.cache_metrics import CacheMetrics
|
||||
from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask
|
||||
from fastdeploy.cache_manager.ops import get_all_visible_devices
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.engine.request import Request
|
||||
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
@@ -47,7 +49,7 @@ class PrefixCacheManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
config: FDConfig,
|
||||
tensor_parallel_size,
|
||||
splitwise_role="mixed",
|
||||
local_data_parallel_id=0,
|
||||
@@ -207,7 +209,6 @@ class PrefixCacheManager:
|
||||
key_cache_shape, val_cache_shape = self._get_kv_cache_shape(cache_config.total_block_num)
|
||||
key_cache_shape = ",".join([str(i) for i in key_cache_shape])
|
||||
val_cache_shape = ",".join([str(i) for i in val_cache_shape])
|
||||
logger.info(f"key_cache_shape {key_cache_shape} value_cache_shape {val_cache_shape}")
|
||||
if self.enable_splitwise:
|
||||
cache_messager_processes = self.launch_cache_messager(
|
||||
cache_config,
|
||||
@@ -273,6 +274,7 @@ class PrefixCacheManager:
|
||||
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
|
||||
+ f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}"
|
||||
+ f" {sys.executable} {py_path}"
|
||||
+ f" --model_id {os.path.basename(self.config.model_config.model)}"
|
||||
+ f" --device_id {int(device_ids[i])}"
|
||||
+ f" --rank {i}"
|
||||
+ f" --splitwise_role {self.splitwise_role}"
|
||||
@@ -390,7 +392,7 @@ class PrefixCacheManager:
|
||||
+ f" --ipc_suffix {ipc_suffix}"
|
||||
+ f" --rdma_port {cache_config.local_rdma_comm_ports[i] if cache_config.local_rdma_comm_ports is not None else '0'}"
|
||||
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
|
||||
+ f" >{log_dir}/launch_cache_messager_tprank{i}.log 2>&1"
|
||||
+ f" >{log_dir}/launch_cache_messager_{i}.log 2>&1"
|
||||
)
|
||||
logger.info(f"Launch cache messager, command:{launch_cmd}")
|
||||
cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
|
||||
@@ -789,9 +791,15 @@ class PrefixCacheManager:
|
||||
f"start prefetch cache from storage, req_id: {req_id}, block num: {len(no_match_block_keys)}"
|
||||
)
|
||||
start_time = time.time()
|
||||
storage_matched_block_ids = self.issue_prefetch_storage_task(
|
||||
req_id, no_match_block_keys, gpu_recv_storage_block_ids
|
||||
read_storage_task = ReadStorageTask(
|
||||
task_id=req_id,
|
||||
keys=no_match_block_keys,
|
||||
token_ids=input_token_ids,
|
||||
gpu_block_ids=gpu_recv_storage_block_ids,
|
||||
start_read_block_idx=match_token_num // block_size,
|
||||
)
|
||||
logger.debug(f"issue read storage task: {read_storage_task}")
|
||||
storage_matched_block_ids = self.issue_prefetch_storage_task(read_storage_task)
|
||||
storage_matched_block_num = len(storage_matched_block_ids)
|
||||
storage_match_token_num = storage_matched_block_num * block_size
|
||||
cost_time = time.time() - start_time
|
||||
@@ -1006,6 +1014,12 @@ class PrefixCacheManager:
|
||||
if self.kvcache_storage_backend is None:
|
||||
return
|
||||
|
||||
token_ids = request.prompt_token_ids
|
||||
if isinstance(token_ids, np.ndarray):
|
||||
token_ids = token_ids.tolist()
|
||||
if self.config.cache_config.enable_output_caching:
|
||||
token_ids += request.output_token_ids
|
||||
|
||||
req_id = request.request_id
|
||||
keys = []
|
||||
node = self.req_leaf_map[req_id]
|
||||
@@ -1018,24 +1032,33 @@ class PrefixCacheManager:
|
||||
|
||||
gpu_block_ids = request.block_tables[: len(keys)]
|
||||
logger.info(f"start write cache back to storage, req_id: {req_id}, block num: {len(keys)}")
|
||||
write_storage_task = WriteStorageTask(
|
||||
task_id=req_id,
|
||||
keys=keys,
|
||||
token_ids=token_ids,
|
||||
gpu_block_ids=gpu_block_ids,
|
||||
)
|
||||
logger.debug(f"issue write storage task: {write_storage_task}")
|
||||
tic = time.time()
|
||||
self.issue_write_back_storage_task(req_id=req_id, hash_keys=keys, gpu_block_ids=gpu_block_ids, is_sync=True)
|
||||
self.issue_write_back_storage_task(write_storage_task, is_sync=True)
|
||||
cost_time = time.time() - tic
|
||||
logger.info(f"finish write cache back to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s")
|
||||
|
||||
def issue_write_back_storage_task(self, req_id, hash_keys, gpu_block_ids, is_sync=True, timeout=0.5):
|
||||
def issue_write_back_storage_task(self, task: WriteStorageTask, is_sync=True):
|
||||
if self.kvcache_storage_backend is None:
|
||||
return
|
||||
|
||||
if len(hash_keys) != len(gpu_block_ids):
|
||||
err_msg = f"write_back_storage error: hash_keys({len(hash_keys)}) != gpu_block_ids({len(gpu_block_ids)})"
|
||||
if len(task.keys) != len(task.gpu_block_ids):
|
||||
err_msg = (
|
||||
f"write_back_storage error: hash_keys({len(task.keys)}) != gpu_block_ids({len(task.gpu_block_ids)})"
|
||||
)
|
||||
logger.error(err_msg)
|
||||
raise ValueError(err_msg)
|
||||
|
||||
self.task_write_back_event[req_id] = Event()
|
||||
self.cache_task_queue.put_transfer_task((CacheStatus.GPU2STORAGE, req_id, hash_keys, gpu_block_ids, timeout))
|
||||
self.task_write_back_event[task.task_id] = Event()
|
||||
self.cache_task_queue.put_transfer_task((CacheStatus.GPU2STORAGE, task))
|
||||
if is_sync:
|
||||
self.wait_write_storage_task(req_id)
|
||||
self.wait_write_storage_task(task.task_id)
|
||||
|
||||
def wait_write_storage_task(self, req_id):
|
||||
"""
|
||||
@@ -1045,16 +1068,19 @@ class PrefixCacheManager:
|
||||
self.task_write_back_event[req_id].wait()
|
||||
del self.task_write_back_event[req_id]
|
||||
|
||||
def issue_prefetch_storage_task(self, req_id, hash_keys, gpu_block_ids, is_sync=True, timeout=0.5):
|
||||
def issue_prefetch_storage_task(self, task: ReadStorageTask, is_sync=True):
|
||||
"""
|
||||
Prefetch cache from storage task
|
||||
"""
|
||||
if self.kvcache_storage_backend is None:
|
||||
return []
|
||||
|
||||
storage_block_ids = []
|
||||
self.task_prefetch_event[req_id] = Event()
|
||||
self.task_prefetch_event[task.task_id] = Event()
|
||||
# issue task to cache_transfer_manager
|
||||
self.cache_task_queue.put_transfer_task((CacheStatus.STORAGE2GPU, req_id, hash_keys, gpu_block_ids, timeout))
|
||||
self.cache_task_queue.put_transfer_task((CacheStatus.STORAGE2GPU, task))
|
||||
if is_sync:
|
||||
storage_block_ids = self.wait_prefetch_storage_task(req_id)
|
||||
storage_block_ids = self.wait_prefetch_storage_task(task.task_id)
|
||||
return storage_block_ids
|
||||
|
||||
def wait_prefetch_storage_task(self, req_id):
|
||||
|
||||
Reference in New Issue
Block a user