[Feature] Support KV Cache Storage (#5571)

* Support Mooncake Store

* up

* up

* add op

* fix conflict

* fix error

* up for comments

* avoid thread lock

* up

* fix unittest

* fix unittest

* remove debug info

* consider tp_size > 1

* add default rdma_nics

* add utils

* up

* fix error

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
Juncai
2025-12-25 16:30:35 +08:00
committed by GitHub
parent be3be4913a
commit 412867fd99
27 changed files with 1672 additions and 195 deletions
+4
View File
@@ -2,6 +2,10 @@
/.venv/
/venv/
tests/log_*
benchmarks/openai-chat-infqps*
splitwise/log*
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
+18 -1
View File
@@ -102,7 +102,7 @@ def metrics_summary(metrics, token_timestamps):
# prefill 总耗时
summary["prefill_cost_time"] = safe_cost(m0.get("send_request_output_to_decode_time"), arrival_time)
# prefill准备耗时
# prefill准备耗时
summary["prefill_prepare_cost_time"] = safe_cost(inference_start_time, arrival_time)
# 预处理耗时
summary["preprocess_cost_time"] = safe_cost(m0.get("scheduler_recv_req_time"), arrival_time)
@@ -114,6 +114,10 @@ def metrics_summary(metrics, token_timestamps):
summary["ask_decode_resource_cost_time"] = safe_cost(
m0.get("ask_decode_resource_finish_time"), m0.get("ask_decode_resource_start_time")
)
# scheduler调度耗时
summary["schedule_cost_time"] = safe_cost(
m0.get("inference_start_time"), m0.get("ask_decode_resource_finish_time")
)
# prefill 的首 token 推理耗时
summary["prefill_first_token_infer_cost_time"] = safe_cost(
m0.get("engine_recv_first_token_time"), inference_start_time
@@ -143,6 +147,19 @@ def metrics_summary(metrics, token_timestamps):
token_timestamps[1], m_last.get("decode_recv_second_token_time")
)
# MIX 模式下,scheduler调度耗时
summary["mixed_schedule_cost_time"] = safe_cost(m0.get("inference_start_time"), m0.get("engine_get_req_time"))
# MIX 模式下,返回首 token 链路耗时
summary["mixed_first_token_transmission_cost_time"] = safe_cost(
token_timestamps[0], m0.get("engine_recv_first_token_time")
)
summary["gpu_cache_token_num"] = m0.get("gpu_cache_token_num")
summary["cpu_cache_token_num"] = m0.get("cpu_cache_token_num")
summary["storage_cache_token_num"] = m0.get("storage_cache_token_num")
summary["gpu_cpu_cache_prepare_time"] = m0.get("gpu_cpu_cache_prepare_time")
summary["storage_cache_prepare_time"] = m0.get("storage_cache_prepare_time")
return summary
+17 -5
View File
@@ -695,7 +695,7 @@ async def benchmark(
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
def process_pd_metrics(model_outputs, metric_key):
def process_pd_metrics(model_outputs, metric_key, is_time=True):
# 收集所有该 metric 的数值
values = []
percentiles = []
@@ -712,24 +712,29 @@ async def benchmark(
print(f"[WARN] metric_key '{metric_key}' not found in outputs.")
return
arr = np.array(values) * 1000 # 秒 -> 毫秒
if is_time:
arr = np.array(values) * 1000 # 秒 -> 毫秒
suffix = "(ms)"
else:
arr = np.array(values)
suffix = ""
print("{s:{c}^{n}}".format(s=metric_key, n=50, c="-"))
print(
"{:<40} {:<10.2f}".format(
f"Mean {metric_key} (ms):",
f"Mean {metric_key} {suffix}:",
np.mean(arr),
)
)
print(
"{:<40} {:<10.2f}".format(
f"Median {metric_key} (ms):",
f"Median {metric_key} {suffix}:",
np.median(arr),
)
)
for p in percentiles:
v = np.percentile(arr, p)
print("{:<40} {:<10.2f}".format(f"P{str(int(p)) if int(p) == p else str(p)} {metric_key} (ms):", v))
print("{:<40} {:<10.2f}".format(f"P{str(int(p)) if int(p) == p else str(p)} {metric_key} {suffix}:", v))
# print(f"P{str(int(p)) if int(p) == p else str(p)} {metric_key} (ms): {v:10.2f}")
print(
"{:<40} {:<10.2f}".format(
@@ -785,6 +790,7 @@ async def benchmark(
process_pd_metrics(outputs, "prefill_prepare_cost_time")
process_pd_metrics(outputs, "preprocess_cost_time")
process_pd_metrics(outputs, "cache_in_scheduler_cost_time")
process_pd_metrics(outputs, "schedule_cost_time")
process_pd_metrics(outputs, "ask_decode_resource_cost_time")
process_pd_metrics(outputs, "prefill_first_token_infer_cost_time")
process_pd_metrics(outputs, "wait_sending_cache_cost_time")
@@ -793,6 +799,12 @@ async def benchmark(
process_pd_metrics(outputs, "decode_second_token_infer_cost_time")
process_pd_metrics(outputs, "first_token_transmission_cost_time")
process_pd_metrics(outputs, "second_token_transmission_cost_time")
process_pd_metrics(outputs, "mixed_schedule_cost_time")
process_pd_metrics(outputs, "gpu_cache_token_num", is_time=False)
process_pd_metrics(outputs, "cpu_cache_token_num", is_time=False)
process_pd_metrics(outputs, "storage_cache_token_num", is_time=False)
process_pd_metrics(outputs, "gpu_cpu_cache_prepare_time")
process_pd_metrics(outputs, "storage_cache_prepare_time")
process_one_length("input_len", "Cached Tokens", "Cached Tokens")
process_one_length("s_input_len", "Input Length", "Infer Input Length")
process_one_length("reasoning_len", "Reasoning Lenth", "思考长度")
+134
View File
@@ -0,0 +1,134 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/extension.h"
// #define SWAP_DEBUG
// Copy KV cache blocks between the per-layer GPU cache tensors and a flat
// pinned CPU buffer, changing layout on the way.
//
// mode 0: gpu -> cpu; mode 1: cpu -> gpu
// cache layout:  layer_num * [block_num, head_num, block_size, head_dim]
// buffer layout: [block_num, layer_num, head_num, block_size, head_dim]
//
// `cache_cpu_pointer` is the CPU buffer's raw address passed through an int64
// attribute. All copies are enqueued on the tensors' stream and the function
// synchronizes once at the end, so it returns only after data has landed.
template <paddle::DataType D>
void SwapCacheImpLayout(
    const std::vector<paddle::Tensor>& cache_gpu_tensors,  // gpu
    const int64_t& cache_cpu_pointer,                      // cpu
    const std::vector<int64_t>& cache_shape,
    const std::vector<int64_t>& gpu_block_ids,
    const std::vector<int64_t>& cpu_block_ids,
    int mode) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;
  const int64_t layer_number = cache_gpu_tensors.size();
  const int64_t num_heads = cache_shape[1];
  const int64_t block_size = cache_shape[2];
  const int64_t head_dim = cache_shape[3];
  // Elements in one (layer, block) tile; identical on both sides.
  const int64_t cache_block_stride = num_heads * block_size * head_dim;
  // Fix: validate the paired id lists instead of silently indexing past one.
  PD_CHECK(gpu_block_ids.size() == cpu_block_ids.size(),
           "gpu_block_ids and cpu_block_ids must have the same length.");
#ifdef SWAP_DEBUG
  std::cout << "layer_number:" << layer_number << std::endl;
  std::cout << "cache_shape:" << cache_shape[0] << ", " << cache_shape[1]
            << ", " << cache_shape[2] << ", " << cache_shape[3] << std::endl;
  std::cout << "cache_block_stride:" << cache_block_stride << std::endl;
#endif
  auto stream = cache_gpu_tensors[0].stream();
  const cudaMemcpyKind copy_kind =
      (mode == 0) ? cudaMemcpyDeviceToHost : cudaMemcpyHostToDevice;
  for (int64_t layer_idx = 0; layer_idx < layer_number; layer_idx++) {
    const paddle::Tensor& cache_gpu = cache_gpu_tensors[layer_idx];
    data_t* cache_gpu_ptr = const_cast<data_t*>(cache_gpu.data<data_t>());
    auto* cache_cpu_ptr = reinterpret_cast<data_t*>(cache_cpu_pointer);
    for (size_t block_idx = 0; block_idx < gpu_block_ids.size(); block_idx++) {
      const int64_t cur_gpu_block_id = gpu_block_ids[block_idx];
      const int64_t cur_cpu_block_id = cpu_block_ids[block_idx];
      auto* cache_gpu_ptr_now =
          cache_gpu_ptr + cur_gpu_block_id * cache_block_stride;
      // The CPU buffer interleaves all layers of one block, hence the extra
      // layer_number factor in the block stride.
      auto* cache_cpu_ptr_now =
          cache_cpu_ptr + cur_cpu_block_id * cache_block_stride * layer_number +
          layer_idx * cache_block_stride;
      cudaError_t status = cudaMemcpyAsync(
          (copy_kind == cudaMemcpyDeviceToHost) ? cache_cpu_ptr_now
                                                : cache_gpu_ptr_now,
          (copy_kind == cudaMemcpyDeviceToHost) ? cache_gpu_ptr_now
                                                : cache_cpu_ptr_now,
          cache_block_stride * sizeof(DataType_),
          copy_kind,
          stream);
      // Fix: `status` was assigned but never inspected, silently dropping
      // enqueue failures (bad pointer, invalid stream, ...).
      PD_CHECK(status == cudaSuccess,
               "cudaMemcpyAsync failed: ", cudaGetErrorString(status));
#ifdef SWAP_DEBUG
      cudaStreamSynchronize(stream);
      std::cout << "mode:" << mode << ", layer_idx:" << layer_idx
                << ", block_idx:" << block_idx << ", cache_cpu_ptr_now data:"
                << static_cast<float>(*cache_cpu_ptr_now) << std::endl;
#endif
    }
  }
  cudaStreamSynchronize(stream);
}
// Dispatch SwapCacheImpLayout on the runtime dtype of the GPU cache tensors.
// `rank` selects the CUDA device (used for distributed launch); `mode` is
// 0 for gpu->cpu and 1 for cpu->gpu. Throws on unsupported dtypes.
void SwapCacheLayout(
    const std::vector<paddle::Tensor>& cache_gpu_tensors,  // gpu
    const int64_t& cache_cpu_ptrs,  // cpu memory pointer
    const std::vector<int64_t>& cache_shape,
    const std::vector<int64_t>& gpu_block_ids,
    const std::vector<int64_t>& cpu_block_ids,
    int rank,
    int mode) {
  cudaSetDevice(rank);  // used for distributed launch
  // Fix: plain assert() is compiled out under NDEBUG, so an empty tensor list
  // would fall through to an out-of-bounds access below. Fail loudly instead.
  PD_CHECK(!cache_gpu_tensors.empty(), "cache_gpu_tensors must not be empty.");
  switch (cache_gpu_tensors[0].dtype()) {
    case paddle::DataType::BFLOAT16:
      return SwapCacheImpLayout<paddle::DataType::BFLOAT16>(cache_gpu_tensors,
                                                            cache_cpu_ptrs,
                                                            cache_shape,
                                                            gpu_block_ids,
                                                            cpu_block_ids,
                                                            mode);
    case paddle::DataType::FLOAT16:
      return SwapCacheImpLayout<paddle::DataType::FLOAT16>(cache_gpu_tensors,
                                                           cache_cpu_ptrs,
                                                           cache_shape,
                                                           gpu_block_ids,
                                                           cpu_block_ids,
                                                           mode);
    case paddle::DataType::UINT8:
      return SwapCacheImpLayout<paddle::DataType::UINT8>(cache_gpu_tensors,
                                                         cache_cpu_ptrs,
                                                         cache_shape,
                                                         gpu_block_ids,
                                                         cpu_block_ids,
                                                         mode);
    default:
      PD_THROW("Unsupported data type.");
  }
}
// Register `swap_cache_layout` as a static Paddle custom op.
// The per-layer GPU cache tensors are the only tensor inputs; everything
// else — including the raw CPU buffer address ("cache_cpu_ptrs", an int64) —
// is passed as an attribute.
PD_BUILD_STATIC_OP(swap_cache_layout)
    .Inputs({paddle::Vec("cache_gpu_tensors")})
    .Attrs({
        "cache_cpu_ptrs: int64_t",
        "cache_shape: std::vector<int64_t>",
        "gpu_block_ids: std::vector<int64_t>",
        "cpu_block_ids: std::vector<int64_t>",
        "rank: int",
        "mode: int",
    })
    .SetKernelFn(PD_KERNEL(SwapCacheLayout));
+1
View File
@@ -288,6 +288,7 @@ elif paddle.is_compiled_with_cuda():
"gpu_ops/tune_cublaslt_gemm.cu",
"gpu_ops/swap_cache_batch.cu",
"gpu_ops/swap_cache.cu",
"gpu_ops/swap_cache_layout.cu",
"gpu_ops/step_system_cache.cu",
"gpu_ops/cpp_extensions.cc",
"gpu_ops/share_external_data.cu",
+57
View File
@@ -0,0 +1,57 @@
# MooncakeStore for FastDeploy
This document describes how to use MooncakeStore as the backend of FastDeploy.
## Preparation
### Install FastDeploy
Refer to [NVIDIA CUDA GPU Installation](https://paddlepaddle.github.io/FastDeploy/get_started/installation/nvidia_gpu/) for FastDeploy installation.
### Install MooncakeStore
```bash
pip install mooncake-transfer-engine
```
## Run Examples
The example script is provided in `run.sh`. You can run it directly:
```
bash run.sh
```
In the example script, we will start a Mooncake master server and two FastDeploy servers.
Launch Mooncake master server:
```bash
mooncake_master \
--port=15001 \
--enable_http_metadata_server=true \
--http_metadata_server_host=0.0.0.0 \
--http_metadata_server_port=15002 \
--metrics_port=15003 \
```
More parameters can be found in the [official guide](https://github.com/kvcache-ai/Mooncake/blob/main/docs/source/python-api-reference/transfer-engine.md).
Launch FastDeploy with Mooncake enabled:
```bash
export MOONCAKE_CONFIG_PATH="./mooncake_config.json"
python -m fastdeploy.entrypoints.openai.api_server \
--model ${MODEL_NAME} \
--port ${PORT} \
--metrics-port $((PORT + 1)) \
--engine-worker-queue-port $((PORT + 2)) \
--cache-queue-port $((PORT + 3)) \
--max-model-len 32768 \
--max-num-seqs 32 \
--kvcache-storage-backend mooncake
```
## Troubleshooting
For more details, please refer to:
https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/troubleshooting.md
@@ -0,0 +1,9 @@
{
"local_hostname":"localhost",
"metadata_server":"http://0.0.0.0:15002/metadata",
"global_segment_size":8589934592,
"local_buffer_size":134217728,
"protocol":"rdma",
"rdma_devices": "mlx5_1,mlx5_2,mlx5_3,mlx5_4",
"master_server_addr":"0.0.0.0:15001"
}
+100
View File
@@ -0,0 +1,100 @@
#!/bin/bash
# Launch a Mooncake master plus two single-GPU FastDeploy servers that use it
# as the KV-cache storage backend, then send the same prompt to both servers
# so the second request can reuse the cache written by the first.
set -e
export MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle"
export MOONCAKE_CONFIG_PATH=./mooncake_config.json
export FD_DEBUG=1
unset http_proxy && unset https_proxy
rm -rf log_*
bash stop.sh
source ./utils.sh
S0_PORT=52700
S1_PORT=52800
# NOTE(review): ROUTER_PORT is never assigned in this script or utils.sh, so
# it expands to nothing in the array below — confirm whether it should be set.
ports=(
    $S0_PORT $((S0_PORT + 1)) $((S0_PORT + 2)) $((S0_PORT + 3))
    $S1_PORT $((S1_PORT + 1)) $((S1_PORT + 2)) $((S1_PORT + 3))
    $ROUTER_PORT
)
check_ports "${ports[@]}" || {
    echo "❌ Some ports are in use. Please release them."
    exit 1
}
# Launch MoonCake master.
# Fix: the original used `2>&1 > file`, which redirects stderr to the
# terminal, not the file; `> file 2>&1` captures both streams.
nohup mooncake_master \
    --port=15001 \
    --enable_http_metadata_server=true \
    --http_metadata_server_host=0.0.0.0 \
    --http_metadata_server_port=15002 \
    --metrics_port=15003 \
    > log_master 2>&1 &
# Launch FD server 0
export CUDA_VISIBLE_DEVICES=0
export FD_LOG_DIR="log_0"
mkdir -p ${FD_LOG_DIR}
echo "server 0 port: ${S0_PORT}"
nohup python -m fastdeploy.entrypoints.openai.api_server \
    --model ${MODEL_NAME} \
    --port ${S0_PORT} \
    --metrics-port $((S0_PORT + 1)) \
    --engine-worker-queue-port $((S0_PORT + 2)) \
    --cache-queue-port $((S0_PORT + 3)) \
    --max-model-len 32768 \
    --max-num-seqs 32 \
    --kvcache-storage-backend mooncake \
    > ${FD_LOG_DIR}/nohup 2>&1 &
# Launch FD server 1
export CUDA_VISIBLE_DEVICES=1
export FD_LOG_DIR="log_1"
mkdir -p ${FD_LOG_DIR}
echo "server 1 port: ${S1_PORT}"
nohup python -m fastdeploy.entrypoints.openai.api_server \
    --model ${MODEL_NAME} \
    --port ${S1_PORT} \
    --metrics-port $((S1_PORT + 1)) \
    --engine-worker-queue-port $((S1_PORT + 2)) \
    --cache-queue-port $((S1_PORT + 3)) \
    --max-model-len 32768 \
    --max-num-seqs 32 \
    --kvcache-storage-backend mooncake \
    > ${FD_LOG_DIR}/nohup 2>&1 &
wait_for_health ${S0_PORT}
wait_for_health ${S1_PORT}
# Send the same long prompt to both servers; top_p=0 keeps output deterministic.
msg="深圳是中国经济实力最强的城市之一。近年来,深圳 GDP 持续稳步增长,**2023 年突破 3.4 万亿元人民币,2024 年接近 3.7 万亿元**,长期位居全国城市前列。深圳经济以第二产业和第三产业为主,高端制造业、电子信息产业和现代服务业发达,形成了以科技创新为核心的产业结构。依托华为、腾讯、大疆等龙头企业,深圳在数字经济、人工智能、新能源等领域具有显著优势。同时,深圳进出口总额常年位居全国城市第一,是中国对外开放和高质量发展的重要引擎。深圳2024年 GDP 是多少?"
echo "send request to server_0"
curl -X POST "http://0.0.0.0:${S0_PORT}/v1/chat/completions" \
    -H "Content-Type: application/json" \
    -d "{
        \"messages\": [
            {\"role\": \"user\", \"content\": \"${msg}\"}
        ],
        \"max_tokens\": 50,
        \"stream\": false,
        \"top_p\": 0
    }"
sleep 5
echo "send request to server_1"
curl -X POST "http://0.0.0.0:${S1_PORT}/v1/chat/completions" \
    -H "Content-Type: application/json" \
    -d "{
        \"messages\": [
            {\"role\": \"user\", \"content\": \"${msg}\"}
        ],
        \"max_tokens\": 50,
        \"stream\": false,
        \"top_p\": 0
    }"
+97
View File
@@ -0,0 +1,97 @@
#!/bin/bash
is_port_free() {
    # Succeed (return 0) when no listening TCP socket is bound to the given
    # port; fail (return 1) when one is.
    local candidate=$1
    ! ss -ltn | awk '{print $4}' | grep -q ":${candidate}$"
}
check_ports() {
    # Verify that every port passed as an argument is free.
    # Reports the first occupied port and fails; succeeds when all are free.
    local p
    for p in "$@"; do
        is_port_free "$p" && continue
        echo "❌ Port $p is already in use"
        return 1
    done
    return 0
}
wait_for_health() {
    # Block until every listed server answers /health with HTTP 200.
    # $1: comma-separated list of ports, e.g. "8000,8001".
    # Redraws one status line per port each second using ANSI escapes.
    IFS=',' read -r -a server_ports <<< "$1"
    local num_ports=${#server_ports[@]}
    # One line per port plus the trailing summary line — the cursor is moved
    # back up by this amount each round to redraw in place.
    local total_lines=$((num_ports + 1))
    local first_run=true  # NOTE(review): set but never read — confirm it can be removed
    local GREEN='\033[0;32m'
    local RED='\033[0;31m'
    local NC='\033[0m' # No Color
    local start_time=$(date +%s)
    while true; do
        local all_ready=true
        for port in "${server_ports[@]}"; do
            # --max-time 1 keeps one stuck server from stalling the poll;
            # "000" stands in for "no response at all".
            status_code=$(curl -s --max-time 1 -o /dev/null -w "%{http_code}" "http://0.0.0.0:${port}/health" || echo "000")
            if [ "$status_code" -eq 200 ]; then
                printf "Port %s: ${GREEN}[OK] 200${NC}\033[K\n" "$port"
            else
                all_ready=false
                printf "Port %s: ${RED}[WAIT] %s${NC}\033[K\n" "$port" "$status_code"
            fi
        done
        cur_time=$(date +%s)
        if [ "$all_ready" = "true" ]; then
            echo "All services are ready! [$((cur_time-start_time))s]"
            break
        else
            echo "Waiting for services... [$((cur_time-start_time))s]"
            printf "\033[%dA" "$total_lines" # roll back cursor
            sleep 1
        fi
    done
}
get_free_ports() {
    # Print `free_ports_num` free TCP ports from [start_port, end_port) as a
    # comma-separated list. Snapshots currently used ports once, then probes
    # candidates starting from a random offset so concurrent callers spread out.
    #
    # $1: number of ports wanted (default 1, must be > 0)
    # $2: start of the scan range (default 8000)
    # $3: end of the scan range (default 9000)
    local free_ports_num=${1:-1}
    local start_port=${2:-8000}
    local end_port=${3:-9000}
    local free_ports=()
    if [[ -z ${free_ports_num} || "${free_ports_num}" -le 0 ]]; then
        # Fix: the original called `log_warn`, which is not defined anywhere
        # in this file — write the warning to stderr instead.
        echo "param can't be empty, and should > 0" >&2
        echo "${free_ports[@]}"
        return 1
    fi
    local used_ports1 used_ports2 all_used_ports
    used_ports1=$(netstat -an | grep -E "(0.0.0.0|127.0.0.1|${POD_IP}|tcp6)" | awk '{n=split($4,a,":"); if(a[n]~/^[0-9]+$/) print a[n];}' | sort -u)
    used_ports2=$(netstat -an | grep -E "(0.0.0.0|127.0.0.1|${POD_IP}|tcp6)" | awk '{n=split($5,a,":"); if(a[n]~/^[0-9]+$/) print a[n];}' | sort -u)
    all_used_ports=$(printf "%s\n" "${used_ports1}" "${used_ports2}" | sort -u)
    # Random starting point inside the range.
    local random_num=$(( RANDOM ))
    local port=$(( random_num % (end_port - start_port + 1) + start_port ))
    # NOTE: loops forever if fewer than free_ports_num ports are free in range.
    while true; do
        (( port++ ))
        if [[ ${port} -ge ${end_port} ]]; then
            port=${start_port}
        fi
        # Fix: the original `[[ "${all_used_ports[@]}" =~ "${port}" ]]` was a
        # substring match, so e.g. port 80 wrongly matched 8080. Use an exact
        # whole-line match against the snapshot instead.
        if grep -qx "${port}" <<< "${all_used_ports}"; then
            continue
        fi
        if is_port_free ${port}; then
            free_ports+=("${port}")
            (( free_ports_num-- ))
            if [[ ${free_ports_num} -eq 0 ]]; then
                break
            fi
        fi
    done
    # Fix: `IFS=',' && echo ...` permanently changed IFS in the caller's
    # shell; join in a subshell so the caller's IFS is untouched.
    (IFS=','; echo "${free_ports[*]}")
    return 0
}
+2
View File
@@ -30,6 +30,8 @@ class CacheStatus(Enum):
SWAP2CPU = 1
SWAP2GPU = 2
CPU = 3
GPU2STORAGE = 4
STORAGE2GPU = 5
class BlockNode:
@@ -22,6 +22,7 @@ import queue
import threading
import time
import traceback
from typing import List
import numpy as np
import paddle
@@ -36,8 +37,10 @@ from fastdeploy.cache_manager.ops import (
set_device,
share_external_data_,
swap_cache_all_layers,
swap_cache_layout,
unset_data_ipc,
)
from fastdeploy.cache_manager.transfer_factory import MooncakeStore
from fastdeploy.config import SpeculativeConfig
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
from fastdeploy.platforms import current_platform
@@ -55,8 +58,9 @@ def parse_args():
default="mixed",
help="splitwise role, can be decode, prefill or mixed",
)
parser.add_argument("--rank", type=int, default=0, help="current rank")
parser.add_argument("--rank", type=int, default=0, help="local tp rank")
parser.add_argument("--device_id", type=int, default=0, help="device id")
parser.add_argument("--max_model_len", type=int, default=32768, help="max model length")
parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
parser.add_argument(
@@ -101,6 +105,20 @@ def parse_args():
help="speculative config",
)
parser.add_argument("--create_cache_tensor", action="store_true")
parser.add_argument(
"--kvcache_storage_backend",
type=str,
default=None,
choices=["mooncake", "none"],
help="The storage backend for kvcache storage. If not set, storage backend is disabled.",
)
parser.add_argument(
"--write_policy",
type=str,
choices=["write_through"],
default="write_through",
help="KVCache write policy",
)
args = parser.parse_args()
return args
@@ -135,6 +153,9 @@ class CacheTransferManager:
self.swap_to_cpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.swap_to_gpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.read_storage_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.write_back_storage_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.timeout_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)
self.transfer_task_queue = queue.Queue() # 用来接收传输任务
self.tansfer_done_queue = queue.Queue() # 用来告知任务执行完毕
self.n_ranks = args.mp_num
@@ -194,8 +215,54 @@ class CacheTransferManager:
suffix=args.engine_worker_queue_port,
create=False,
)
if args.kvcache_storage_backend is None or args.kvcache_storage_backend == "none":
self.storage_backend = None
elif args.kvcache_storage_backend == "mooncake":
logger.info("Start initialize mooncake store...")
self.storage_backend = MooncakeStore(tp_rank=self.rank)
self._init_storage_buffer(args)
logger.info("Initialized mooncake store successfully")
else:
raise NotImplementedError(f"Unsupported storage backend: {args.kvcache_storage_backend}")
if args.write_policy not in ["write_through"]:
raise ValueError(f"Invalid write policy: {args.write_policy}")
self.write_policy = args.write_policy
threading.Thread(target=self.clear_or_update_caches, args=[args], daemon=True).start()
def _init_storage_buffer(self, args):
    """
    Initialize pinned memory buffer that can hold the cache for a longest request
    cache layout: layer_num * [block_num, head_num, block_size, head_dim]
    buffer layout: [block_num, layer_num, head_num, block_size, head_dim]

    Allocates two pinned host buffers (one for reads from storage, one for
    writes back to storage), each split in half for key and value caches,
    and registers both with the storage backend.
    """
    layer_num = args.num_layers + self.num_extra_layers
    # NOTE(review): buffer geometry is derived from the *key* cache shape
    # only; assumes value blocks are never larger — confirm.
    head_num = self.key_cache_shape[1]
    block_size = self.key_cache_shape[2]
    head_dim = self.key_cache_shape[3]
    # Enough blocks (rounded up) to cover one maximal-length request.
    block_num = (args.max_model_len + block_size - 1) // block_size
    logger.info(
        f"Creating cache buffer for storage with shape: "
        f"[{block_num}, {layer_num}, {head_num}, {block_size}, {head_dim}]"
    )
    self.cache_bytes = self._get_cache_bytes(self.cache_dtype)
    # Bytes occupied by one block across all layers — one "row" of the buffer.
    self.storage_buffer_stride_bytes = layer_num * head_num * block_size * head_dim * self.cache_bytes
    total_bytes = block_num * self.storage_buffer_stride_bytes * 2  # key and value
    logger.info(f"Creating cpu buffer cache for alllayers: {total_bytes / 1024 ** 3:.2f}GB")
    # Read buffer: first half stages key blocks, second half value blocks.
    read_buffer = cuda_host_alloc(total_bytes)
    self.storage_key_read_buffer = read_buffer
    self.storage_value_read_buffer = read_buffer + total_bytes // 2
    self.storage_backend.register_buffer(read_buffer, total_bytes)
    # Write buffer: same split, used when persisting GPU blocks to storage.
    write_buffer = cuda_host_alloc(total_bytes)
    self.storage_key_write_buffer = write_buffer
    self.storage_value_write_buffer = write_buffer + total_bytes // 2
    self.storage_backend.register_buffer(write_buffer, total_bytes)
def _init_gpu_cache(self, args):
try:
@@ -319,12 +386,7 @@ class CacheTransferManager:
value_cache_size = self.value_cache_shape[1] * self.value_cache_shape[2] * self.value_cache_shape[3]
else:
value_cache_size = 0
if args.cache_dtype == "bfloat16":
cache_bytes = 2
elif args.cache_dtype == "uint8" or args.cache_dtype == "block_wise_fp8":
cache_bytes = 1
else:
raise ValueError(f"Unsupported cache dtype: {args.cache_dtype}")
cache_bytes = self._get_cache_bytes(self.cache_dtype)
key_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * key_cache_size
value_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * value_cache_size
if args.cache_dtype == "block_wise_fp8":
@@ -367,6 +429,222 @@ class CacheTransferManager:
logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ swap space (cpu cache) is ready!")
self.swap_space_ready_signal.value[self.rank] = 1
def _get_cache_bytes(self, cache_dtype):
if cache_dtype == "bfloat16":
cache_bytes = 2
elif cache_dtype in ["uint8", "block_wise_fp8"]:
cache_bytes = 1
else:
raise ValueError(f"Unsupported cache dtype: {cache_dtype}")
return cache_bytes
def _storage_exist_block_num(self, k_keys: List[str], v_keys: List[str]):
    """
    Given the k_keys and v_keys, get the valid blocks number that
    can be prefetched from storage backend.

    A block counts only when both its key part and value part exist.
    """
    assert len(k_keys) == len(v_keys), "k_keys and v_keys must have the same length."
    # One round trip for all keys at once.
    result = self.storage_backend.exists(k_keys + v_keys)
    return sum(1 for k, v in zip(k_keys, v_keys) if result[k] and result[v])
def _run_read_storage(self, k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids):
    """
    Fetch key/value blocks from the storage backend into the pinned read
    buffers, then copy the successfully fetched prefix onto the GPU cache.

    Args:
        k_cache_keys: storage keys of the key-cache blocks to fetch.
        v_cache_keys: storage keys of the value-cache blocks to fetch.
        gpu_block_ids: destination GPU block ids, aligned with the keys.
        cpu_block_ids: staging slots in the pinned read buffer.

    Returns:
        The GPU block ids actually filled (a prefix of ``gpu_block_ids``).

    Raises:
        Re-raises any exception after logging it.
    """
    try:
        logger.debug(
            f"_run_read_storage, key_hash_keys: {k_cache_keys}, "
            f"value_hash_keys: {v_cache_keys}, gpu_block_ids: {gpu_block_ids}"
        )
        block_num = len(gpu_block_ids)
        keys = k_cache_keys + v_cache_keys
        # Byte addresses inside the pinned buffers where each block lands.
        k_cache_ptrs = [self.storage_key_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids]
        v_cache_ptrs = [
            self.storage_value_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
        ]
        kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
        kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2  # key and value
        result = self.storage_backend.batch_get(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes)
        k_result, v_result = result[:block_num], result[block_num:]
        # A block counts as fetched only when both key and value parts arrived.
        success_block_num = 0
        for k, v in zip(k_result, v_result):
            if k > 0 and v > 0:
                success_block_num += 1
        logger.debug(f"_run_read_storage, success_block_num: {success_block_num}")
        # NOTE(review): taking a prefix assumes failures occur only at the
        # tail; a mid-list miss would mis-align later blocks — confirm
        # batch_get guarantees prefix semantics.
        valid_gpu_block_ids = gpu_block_ids[:success_block_num]
        valid_cpu_block_ids = cpu_block_ids[:success_block_num]
        mode = 1  # cpu ==> gpu
        swap_cache_layout(
            self.gpu_cache_k_tensors,
            self.storage_key_read_buffer,
            self.key_cache_shape,
            valid_gpu_block_ids,
            valid_cpu_block_ids,
            self.device,
            mode,
        )
        swap_cache_layout(
            self.gpu_cache_v_tensors,
            self.storage_value_read_buffer,
            self.value_cache_shape,
            valid_gpu_block_ids,
            valid_cpu_block_ids,
            self.device,
            mode,
        )
        return valid_gpu_block_ids
    except Exception as e:
        logger.error(
            f"[rank {self.rank}/{self.n_ranks}] An error occurred in _run_read_storage: "
            f"error:{e}, {traceback.format_exc()}"
        )
        raise
def read_storage_task(self, task_id, keys, gpu_block_ids, timeout=0.1):
    """Read cache from the storage backend to the GPU memory.

    Args:
        task_id: identifier echoed back in the completion signal.
        keys: block hash keys; expanded into per-rank key/value storage keys.
        gpu_block_ids: GPU cache blocks to fill, aligned with ``keys``.
        timeout: currently unused (see TODO below).

    Always emits a transfer-done signal, even when nothing was loaded.
    """
    try:
        logger.debug(
            f"read_storage_task, task id: {task_id}, hash_keys: {keys}, "
            f"gpu_block_ids: {gpu_block_ids}, timeout: {timeout}"
        )
        # Per-rank key names: each tp rank stores its own shard of a block.
        k_cache_keys = [f"{key}_key_{self.rank}" for key in keys]
        v_cache_keys = [f"{key}_value_{self.rank}" for key in keys]
        match_block_num = self._storage_exist_block_num(k_cache_keys, v_cache_keys)
        logger.debug(f"read_storage_task, match {match_block_num} blocks from storage for task id: {task_id}")
        # Only fetch the prefix of blocks known to exist in storage.
        k_cache_keys = k_cache_keys[:match_block_num]
        v_cache_keys = v_cache_keys[:match_block_num]
        gpu_block_ids = gpu_block_ids[:match_block_num]
        # Staging slots in the pinned read buffer are assigned sequentially.
        cpu_block_ids = [i for i in range(match_block_num)]
        valid_gpu_block_ids = []
        if match_block_num > 0:
            # TODO: support timeout with actual block count
            try:
                valid_gpu_block_ids = self._run_read_storage(
                    k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids
                )
                logger.info(
                    f"read_storage_task, finish loading {match_block_num} blocks from storage for task {task_id}."
                )
            except Exception as e:
                # Best-effort: a failed fetch degrades to "no blocks loaded".
                logger.error(f"[rank {self.rank}/{self.n_ranks}] An error occurred: {task_id} {e}")
                valid_gpu_block_ids = []
        result = (CacheStatus.STORAGE2GPU, task_id, keys, valid_gpu_block_ids)
        # Wait for all ranks before signalling completion.
        # NOTE(review): reset() runs on every rank here, while the write-back
        # path resets only on rank 0 — confirm which is intended.
        self.cache_task_queue.swap_storage_to_gpu_barrier.wait()
        self.cache_task_queue.swap_storage_to_gpu_barrier.reset()
        self.cache_task_queue.put_transfer_done_signal(result)
        logger.debug(f"read_storage_task: put_transfer_done_signal {result}")
        logger.info(
            f"read_storage_task: put_transfer_done_signal for transfer_task_id {task_id}, "
            f"valid block num {len(valid_gpu_block_ids)}"
        )
    except Exception as e:
        logger.error(
            f"[rank {self.rank}/{self.n_ranks}] An error occurred in read_storage_task: "
            f"task_id: {task_id}, error:{e}, {traceback.format_exc()}"
        )
def _run_write_back_storage(self, k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids):
    """
    Copy GPU cache blocks into the pinned write buffers, then push them to
    the storage backend under the given keys.

    Args:
        k_cache_keys: storage keys for the key-cache blocks.
        v_cache_keys: storage keys for the value-cache blocks.
        gpu_block_ids: source GPU block ids, aligned with the keys.
        cpu_block_ids: staging slots in the pinned write buffer.
    """
    try:
        logger.debug(
            f"_run_write_back_storage, k_cache_keys: {k_cache_keys}, v_cache_keys: {v_cache_keys}, "
            f"gpu_block_ids: {gpu_block_ids}"
        )
        key_cache_size = [
            self.key_cache_shape[0],
            self.key_cache_shape[1],
            self.key_cache_shape[2],
            self.key_cache_shape[3],
        ]
        mode = 0  # gpu ==> cpu
        swap_cache_layout(
            self.gpu_cache_k_tensors,
            self.storage_key_write_buffer,
            key_cache_size,
            gpu_block_ids,
            cpu_block_ids,
            self.device,
            mode,
        )
        # NOTE(review): the value copy reuses key_cache_size (the read path
        # uses self.value_cache_shape) — confirm key and value caches always
        # share a shape here.
        swap_cache_layout(
            self.gpu_cache_v_tensors,
            self.storage_value_write_buffer,
            key_cache_size,
            gpu_block_ids,
            cpu_block_ids,
            self.device,
            mode,
        )
        block_num = len(gpu_block_ids)
        keys = k_cache_keys + v_cache_keys
        # Byte addresses inside the pinned buffers for each staged block.
        k_cache_ptrs = [
            self.storage_key_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
        ]
        v_cache_ptrs = [
            self.storage_value_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
        ]
        kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
        kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2  # key and value
        self.storage_backend.batch_set(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes)
    except Exception as e:
        # NOTE(review): unlike _run_read_storage this swallows the exception
        # instead of re-raising, so the caller's failure fallback never
        # triggers — confirm this is intended.
        logger.error(
            f"[rank {self.rank}/{self.n_ranks}] An error occurred in _run_write_back_storage: "
            f"error:{e}, {traceback.format_exc()}"
        )
def write_back_storage_task(self, task_id, keys, gpu_block_ids, timeout=0.1):
    """
    Write cache to the storage backend from the GPU memory.

    Args:
        task_id: identifier echoed back in the completion signal.
        keys: block hash keys; expanded into per-rank key/value storage keys.
        gpu_block_ids: GPU cache blocks to persist, aligned with ``keys``.
        timeout: currently unused (see TODO below).

    Skips blocks already present in storage and always emits a
    transfer-done signal.
    """
    try:
        logger.debug(
            f"write cache to storage, keys: {keys}, gpu_block_ids: {gpu_block_ids}, "
            f"task_id: {task_id}, timeout: {timeout}"
        )
        # Per-rank key names: each tp rank stores its own shard of a block.
        k_cache_keys = [f"{key}_key_{self.rank}" for key in keys]
        v_cache_keys = [f"{key}_value_{self.rank}" for key in keys]
        match_block_num = self._storage_exist_block_num(k_cache_keys, v_cache_keys)
        # NOTE(review): slicing off the first match_block_num entries assumes
        # the already-stored blocks form a prefix of `keys`; a hole in the
        # middle would skip the wrong blocks — confirm.
        k_cache_keys = k_cache_keys[match_block_num:]
        v_cache_keys = v_cache_keys[match_block_num:]
        gpu_block_ids = gpu_block_ids[match_block_num:]
        # Staging slots in the pinned write buffer are assigned sequentially.
        cpu_block_ids = [i for i in range(len(gpu_block_ids))]
        if len(k_cache_keys) == 0:
            logger.info(f"No uncached keys found for task {task_id}")
            gpu_block_ids = []
        else:
            try:
                # TODO: support timeout with actual block count
                self._run_write_back_storage(k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids)
            except Exception as e:
                # Best-effort: report zero written blocks on failure.
                logger.error(f"Error in write back storage task: {e}")
                gpu_block_ids = []
        result = (CacheStatus.GPU2STORAGE, task_id, keys, gpu_block_ids)
        # Wait for all ranks; only rank 0 resets the barrier for reuse.
        self.cache_task_queue.swap_to_storage_barrier.wait()
        if self.rank == 0:  # only rank 0 performs the barrier reset
            self.cache_task_queue.swap_to_storage_barrier.reset()
        self.cache_task_queue.put_transfer_done_signal(result)  # announce completion
        logger.debug(f"write_back_storage_task: put_transfer_done_signal {result}")
        logger.info(f"write_back_storage_task: put_transfer_done_signal for transfer_task_id {task_id}")
    except Exception as e:
        logger.error(
            f"[rank {self.rank}/{self.n_ranks}] An error occurred in write_back_storage_task: "
            f"error:{e}, {traceback.format_exc()}"
        )
def _do_swap_to_cpu_task(
self,
swap_node_ids,
@@ -459,14 +737,9 @@ class CacheTransferManager:
logger.debug(f"transfer data: get_transfer_task {data}")
if read_finish:
self.cache_task_broadcast_signal.value[0] = 0
(
swap_node_ids,
gpu_block_id,
cpu_block_id,
event_type,
transfer_task_id,
) = data
event_type, transfer_task_id = data[0], data[1]
if event_type.value == CacheStatus.SWAP2CPU.value:
swap_node_ids, gpu_block_id, cpu_block_id = data[2:]
self.swap_to_cpu_thread_pool.submit(
self._do_swap_to_cpu_task,
swap_node_ids,
@@ -475,7 +748,8 @@ class CacheTransferManager:
event_type,
transfer_task_id,
)
else:
elif event_type.value == CacheStatus.SWAP2GPU.value:
swap_node_ids, gpu_block_id, cpu_block_id = data[2:]
self.swap_to_gpu_thread_pool.submit(
self._do_swap_to_gpu_task,
swap_node_ids,
@@ -484,6 +758,24 @@ class CacheTransferManager:
event_type,
transfer_task_id,
)
elif event_type.value == CacheStatus.STORAGE2GPU.value:
hash_keys, gpu_block_ids, timeout = data[2:]
self.read_storage_thread_pool.submit(
self.read_storage_task,
transfer_task_id,
hash_keys,
gpu_block_ids,
timeout,
)
elif event_type.value == CacheStatus.GPU2STORAGE.value:
hash_keys, gpu_block_ids, timeout = data[2:]
self.write_back_storage_thread_pool.submit(
self.write_back_storage_task,
transfer_task_id,
hash_keys,
gpu_block_ids,
timeout,
)
else:
if self.n_ranks > 1:
self.cache_task_queue.barrier2.wait()
@@ -635,11 +927,11 @@ class CacheTransferManager:
+ f"transfer {len(gpu_block_ids)} blocks done elapsed_time {elasped_time:.4f}"
)
return (
event_type,
transfer_task_id,
swap_node_ids,
task_gpu_block_id,
task_cpu_block_id,
event_type,
transfer_task_id,
)
def clear_or_update_caches(self, args):
@@ -738,9 +1030,7 @@ def main():
"""
启动cache manager
"""
cache_manager = CacheTransferManager(args)
cache_manager.do_data_transfer()
@@ -749,5 +1039,10 @@ if __name__ == "__main__":
args = parse_args()
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_tprank{args.rank}.log")
logger.info(f"args: {vars(args)}")
set_device(args.device_id)
main()
try:
main()
except Exception as e:
logger.error(f"cache_transfer_manager failed with error: {e}, traceback: {traceback.format_exc()}")
raise
+4
View File
@@ -30,6 +30,7 @@ try:
set_data_ipc,
share_external_data,
swap_cache_all_layers,
swap_cache_layout,
unset_data_ipc,
)
@@ -50,6 +51,7 @@ try:
)
unset_data_ipc = None
swap_cache_layout = None
memory_allocated = paddle.device.xpu.memory_allocated
def get_data_ptr_ipc(*args, **kwargs):
@@ -102,6 +104,7 @@ except:
ipc_sent_key_value_cache_by_remote_ptr_block_sync = None
get_peer_mem_addr = None
get_all_visible_devices = None
swap_cache_layout = None
__all__ = [
@@ -119,4 +122,5 @@ __all__ = [
"ipc_sent_key_value_cache_by_remote_ptr_block_sync",
"get_peer_mem_addr",
"get_all_visible_devices",
"swap_cache_layout",
]
+209 -73
View File
@@ -14,10 +14,8 @@
# limitations under the License.
"""
import hashlib
import heapq
import os
import pickle
import subprocess
import sys
import threading
@@ -34,9 +32,10 @@ from fastdeploy import envs
from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
from fastdeploy.cache_manager.cache_metrics import CacheMetrics
from fastdeploy.cache_manager.ops import get_all_visible_devices
from fastdeploy.engine.request import Request
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import get_logger
from fastdeploy.utils import get_hash_str, get_logger
logger = get_logger("prefix_cache_manager", "cache_manager.log")
@@ -65,6 +64,7 @@ class PrefixCacheManager:
self.enable_splitwise = 0
self.splitwise_role = splitwise_role
self.config = config
self.tensor_parallel_size = tensor_parallel_size
self.cache_config = config.cache_config
self.speculative_config = config.speculative_config
self.local_data_parallel_id = local_data_parallel_id
@@ -89,6 +89,13 @@ class PrefixCacheManager:
self.radix_tree_root = BlockNode(-1, [], 0, 0, -1, 0, None, None, None)
# prams for cache storage
self.kvcache_storage_backend = self.cache_config.kvcache_storage_backend
self.write_policy = self.cache_config.write_policy
self.task_write_back_event = {}
self.task_prefetch_event = {}
self.storage_prefetch_block_ids = {}
# gpu cache data structure
self.gpu_lru_leaf_heap = []
self.gpu_lru_leaf_set = set()
@@ -105,7 +112,7 @@ class PrefixCacheManager:
self.req_leaf_map = {} # {request_id: leaf node}
self.leaf_req_map = defaultdict(set)
self.unfilled_req_block_map = defaultdict(list)
self.cache_info = {}
self.cache_info = {} # {request_id: (last_match_node, num_cached_tokens)}
self.executor_pool = ThreadPoolExecutor(max_workers=1)
self.free_gpu_executor_pool = ThreadPoolExecutor(max_workers=1)
@@ -253,6 +260,10 @@ class PrefixCacheManager:
else:
val_shape_str = str(val_cache_shape)
val_cache_arg_str = f" --value_cache_shape {val_shape_str}"
if cache_config.kvcache_storage_backend:
kvcache_storage_backend_str = cache_config.kvcache_storage_backend
else:
kvcache_storage_backend_str = "none"
for i in range(tensor_parallel_size):
launch_cmd = (
@@ -281,7 +292,10 @@ class PrefixCacheManager:
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
+ f" --default_dtype '{self.config.model_config.dtype}'"
+ (" --create_cache_tensor" if create_cache_tensor else "")
+ f" >{log_dir}/launch_cache_transfer_manager_tprank{i}.log 2>&1"
+ f" --kvcache_storage_backend {kvcache_storage_backend_str}"
+ f" --write_policy {cache_config.write_policy}"
+ f" --max_model_len {self.config.model_config.max_model_len}"
+ f" >{log_dir}/launch_cache_transfer_manager_{int(device_ids[i])}.log 2>&1"
)
logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
@@ -290,7 +304,7 @@ class PrefixCacheManager:
while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
time.sleep(1)
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
if self.num_cpu_blocks > 0:
while np.sum(self.swap_space_ready_signal.value) != tensor_parallel_size:
time.sleep(1)
@@ -303,7 +317,7 @@ class PrefixCacheManager:
)
# Start additional threads
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
if cache_config.kvcache_storage_backend or self.num_cpu_blocks > 0:
logger.info("Enable hierarchical cache.")
threading.Thread(target=self.recv_data_transfer_result).start()
if cache_config.enable_prefix_caching:
@@ -505,13 +519,7 @@ class PrefixCacheManager:
self.task_swapping_event[transfer_task_id] = Event()
self.cache_task_queue.put_transfer_task(
(
swap_node_ids,
gpu_block_ids,
cpu_block_ids,
event_type,
transfer_task_id,
)
(event_type, transfer_task_id, swap_node_ids, gpu_block_ids, cpu_block_ids)
)
if is_sync:
self.sync_swap_task(transfer_task_id)
@@ -629,6 +637,10 @@ class PrefixCacheManager:
can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size
if req_id in self.leaf_req_map[last_node]: # delete old leaf record, update later
self.leaf_req_map[last_node].remove(req_id)
logger.debug(
f"update_cache_blocks: req_id {req_id}, num_cached_tokens {num_cached_tokens}, "
f"can_cache_computed_tokens {can_cache_computed_tokens}"
)
with self.request_release_lock:
leaf_node = self.mm_build_path(
@@ -640,7 +652,7 @@ class PrefixCacheManager:
)
self.req_leaf_map[req_id] = leaf_node
self.leaf_req_map[leaf_node].add(req_id)
self.cache_info[req_id] = (leaf_node, can_cache_computed_tokens)
self.cache_info[req_id] = [leaf_node, can_cache_computed_tokens]
task.cached_block_num = can_cache_computed_tokens // block_size
except Exception as e:
logger.error(f"update_cache_blocks, error: {type(e)} {e}, {str(traceback.format_exc())}")
@@ -692,7 +704,7 @@ class PrefixCacheManager:
else:
prompt_token_ids = task.prompt_token_ids
req_id = task.request_id
logger.info(f"request_match_blocks: start to allocate blocks for req_id {req_id}")
logger.info(f"request_match_blocks: start to process req {req_id}")
input_token_num = len(prompt_token_ids + task.output_token_ids)
common_block_ids = []
# 1. match block
@@ -708,12 +720,14 @@ class PrefixCacheManager:
# update matched node info
self._update_matched_node_info(req_id, match_block_node, current_time=time.time())
# 2. prepare cache
# allocate gpu cache for matched cpu blocks
# 2. prepare cache: allocate gpu cache for matched cpu blocks, wait for data transfer to complete
gpu_recv_block_ids = []
match_cpu_blocks_num = len(match_cpu_block_ids)
if self.can_allocate_gpu_blocks(num_blocks=match_cpu_blocks_num):
if match_cpu_blocks_num > 0:
logger.debug(
f"request_match_blocks: req_id {req_id}, allocate {match_cpu_blocks_num} block to receive cpu cache"
)
gpu_recv_block_ids = self.allocate_gpu_blocks(match_cpu_blocks_num)
if len(gpu_recv_block_ids) > 0:
self._prepare_cpu_cache(
@@ -746,20 +760,58 @@ class PrefixCacheManager:
self.metrics._update_history_hit_metrics()
if self.metrics.req_count % 10000 == 0:
self.metrics.reset_metrics()
logger.info(
f"request_match_blocks: request block for req_id {req_id}: common_block_ids {common_block_ids}"
)
logger.info(f"request_match_blocks: req_id {req_id}, matched_block_ids {common_block_ids}")
# set leaf node temporarily, then update it in update_cache_blocks
self.req_leaf_map[req_id] = match_block_node
self.leaf_req_map[match_block_node].add(req_id)
# record request cache info
self.cache_info[req_id] = (match_block_node, len(common_block_ids) * block_size)
self.cache_info[req_id] = [match_block_node, len(common_block_ids) * block_size]
task.cached_block_num = len(common_block_ids)
return common_block_ids, matched_token_num, hit_info
except Exception as e:
logger.error(f"request_match_blocks: request_block_ids: error: {type(e)} {e}")
raise e
def request_match_storage_blocks(self, request, extra_gpu_block_ids):
    """
    Look up the storage backend for cached blocks of this request and prefetch them.

    Block hash keys are chained: each block's key is derived from its token ids
    plus the previous block's key, starting from the last radix-tree node
    already matched for this request (if any).

    TODO: merge this function into request_match_blocks

    Args:
        request: the request being processed
        extra_gpu_block_ids: GPU block ids available to receive the fetched cache
    Returns:
        list of block ids whose cache was prefetched from storage
    """
    if self.kvcache_storage_backend is None:
        return []
    request_id = request.request_id
    tokens = request.prompt_token_ids
    block_size = self.cache_config.block_size
    cached_token_num = 0
    chained_key = []
    if request_id in self.cache_info:
        matched_node, cached_token_num = self.cache_info[request_id]
        chained_key = [matched_node.hash_value] if matched_node.hash_value is not None else []

    # Hash every full block beyond the already-cached prefix; partial tail blocks
    # are not addressable in storage and are skipped.
    hash_keys = []
    for start in range(cached_token_num, len(tokens) - block_size + 1, block_size):
        block_hash = get_hash_str(tokens[start : start + block_size], chained_key)
        hash_keys.append(block_hash)
        chained_key = [block_hash]

    logger.info(f"start prefetch cache from storage, req_id: {request_id}, block num: {len(hash_keys)}")
    matched_block_ids = self.issue_prefetch_storage_task(request_id, hash_keys, extra_gpu_block_ids)
    logger.info(
        f"finish prefetch cache from storage, req_id: {request_id}, matched block num: {len(matched_block_ids)}"
    )
    return matched_block_ids
def request_block_ids(self, task, block_size, dec_token_num, *args):
"""
Allocate blocks for a task.
@@ -806,10 +858,7 @@ class PrefixCacheManager:
current_time = time.time()
self._update_matched_node_info(req_id, match_block_node, current_time)
# 2. prepare cache
(
gpu_recv_block_ids,
gpu_extra_block_ids,
) = self._prepare_cache(
(gpu_recv_block_ids, gpu_extra_block_ids) = self._prepare_cache(
req_id,
input_ids,
block_size,
@@ -829,7 +878,6 @@ class PrefixCacheManager:
gpu_build_path_block_ids = []
gpu_build_path_block_ids = gpu_extra_block_ids
leaf_node = self.build_path(
req_id,
current_time,
@@ -883,6 +931,7 @@ class PrefixCacheManager:
with self.request_release_lock:
try:
req_id = task.request_id
keys = []
leaf_node = self.req_leaf_map.pop(req_id)
if leaf_node in self.leaf_req_map:
self.leaf_req_map[leaf_node].remove(req_id)
@@ -893,6 +942,7 @@ class PrefixCacheManager:
if req_id in node.req_id_set:
node.req_id_set.remove(req_id)
node.decrement_shared_count()
keys.append(node.hash_value)
node = node.parent
if req_id in self.cache_info:
@@ -919,6 +969,78 @@ class PrefixCacheManager:
logger.error(f"release_block_ids: error: {type(e)} {e}, {str(traceback.format_exc())}")
raise e
def write_cache_to_storage(self, request: Request):
    """
    Persist the KV cache blocks of a finished request to the storage backend.

    Walks the radix-tree path from the request's leaf node up to the root to
    collect block hash keys (root first after reversal), then synchronously
    writes the corresponding GPU blocks back to storage.
    NOTE: this function does not modify the global params
    """
    if self.kvcache_storage_backend is None:
        return
    request_id = request.request_id
    hash_keys = []
    cursor = self.req_leaf_map[request_id]
    while cursor != self.radix_tree_root:
        hash_keys.append(cursor.hash_value)
        cursor = cursor.parent
    hash_keys.reverse()
    if not hash_keys:
        return
    # One GPU block per hash key, in prefix order.
    block_ids = request.block_tables[: len(hash_keys)]
    logger.info(f"start write cache back to storage, req_id: {request_id}, block num: {len(hash_keys)}")
    start_ts = time.time()
    self.issue_write_back_storage_task(
        req_id=request_id, hash_keys=hash_keys, gpu_block_ids=block_ids, is_sync=True
    )
    elapsed = time.time() - start_ts
    logger.info(f"finish write cache back to storage, req_id: {request_id}, cost_time: {elapsed:.6f}s")
def issue_write_back_storage_task(self, req_id, hash_keys, gpu_block_ids, is_sync=True, timeout=0.5):
    """
    Enqueue a GPU->storage write-back task for the given request.

    Args:
        req_id: id of the request whose blocks are written back
        hash_keys: hash key of each block to store
        gpu_block_ids: GPU block ids holding the data, aligned with hash_keys
        is_sync: when True, block until the write-back finishes
        timeout: per-task timeout forwarded to the transfer worker
    Raises:
        ValueError: if hash_keys and gpu_block_ids differ in length
    """
    if self.kvcache_storage_backend is None:
        return
    key_num, block_num = len(hash_keys), len(gpu_block_ids)
    if key_num != block_num:
        err_msg = f"write_back_storage error: hash_keys({key_num}) != gpu_block_ids({block_num})"
        logger.error(err_msg)
        raise ValueError(err_msg)
    # Register the completion event before queueing so the receiver thread can find it.
    self.task_write_back_event[req_id] = Event()
    task = (CacheStatus.GPU2STORAGE, req_id, hash_keys, gpu_block_ids, timeout)
    self.cache_task_queue.put_transfer_task(task)
    if is_sync:
        self.wait_write_storage_task(req_id)
def wait_write_storage_task(self, req_id):
    """
    Block until the write-back task of ``req_id`` completes, then drop its event.

    No-op when no write-back task is pending for this request. The event stays
    registered while waiting so the result-receiver thread can still set it.
    """
    pending = self.task_write_back_event.get(req_id)
    if pending is None:
        return
    pending.wait()
    del self.task_write_back_event[req_id]
def issue_prefetch_storage_task(self, req_id, hash_keys, gpu_block_ids, is_sync=True, timeout=0.5):
    """
    Enqueue a storage->GPU prefetch task for the given request.

    Args:
        req_id: id of the request to prefetch for
        hash_keys: hash keys of the blocks to fetch from storage
        gpu_block_ids: GPU block ids that will receive the fetched data
        is_sync: when True, wait for the prefetch to finish
        timeout: per-task timeout forwarded to the transfer worker
    Returns:
        the fetched block ids when is_sync, otherwise an empty list
    """
    # Register the completion event before queueing so the receiver thread can find it.
    self.task_prefetch_event[req_id] = Event()
    # hand the task over to cache_transfer_manager
    task = (CacheStatus.STORAGE2GPU, req_id, hash_keys, gpu_block_ids, timeout)
    self.cache_task_queue.put_transfer_task(task)
    if not is_sync:
        return []
    return self.wait_prefetch_storage_task(req_id)
def wait_prefetch_storage_task(self, req_id):
    """
    Block until the prefetch task of ``req_id`` finishes and return its block ids.

    Returns None when no prefetch task is pending for this request; otherwise
    waits for completion, clears the bookkeeping entries, and returns the list
    of block ids whose cache was fetched from storage.
    """
    pending = self.task_prefetch_event.get(req_id)
    if pending is None:
        return None
    pending.wait()
    fetched_block_ids = self.storage_prefetch_block_ids.pop(req_id)
    del self.task_prefetch_event[req_id]
    return fetched_block_ids
def free_nodes_directly(self, node):
with self.request_release_lock:
try:
@@ -1069,10 +1191,7 @@ class PrefixCacheManager:
break
node = heapq.heappop(self.gpu_lru_leaf_heap)
self.gpu_lru_leaf_set.remove(node)
if (
not self.cache_config.enable_hierarchical_cache
or self.cache_config.num_cpu_blocks < need_block_num
):
if self.cache_config.num_cpu_blocks < need_block_num:
if node.shared_count == 0 and node.is_gpu_leaf_node: # 直接回收
self._handle_free_gpu_node_without_cpu(node)
total_gpu_free_count += 1
@@ -1195,12 +1314,6 @@ class PrefixCacheManager:
)
return total_cpu_free_count
def cal_block_hash(self, block):
"""
calculate hash value of a block
"""
return hash(tuple(block))
def get_block_hash_extra_keys(self, request, start_idx, end_idx, mm_idx):
"""
Retrieves additional hash keys for block identification.
@@ -1260,16 +1373,6 @@ class PrefixCacheManager:
hash_keys.append(mm_inputs["mm_hashes"][img_idx])
return len(mm_inputs["mm_positions"]) - 1, hash_keys
def hash_block_features(self, input_ids, extra_keys: list = []):
"""
calculate hash value of a block with additional keys
Args:
input_ids: Input token IDs
extra_keys: Additional keys for block identification
"""
return hashlib.sha256(pickle.dumps((input_ids, extra_keys))).hexdigest()
def _revert_match_blocks(
self,
request,
@@ -1363,6 +1466,7 @@ class PrefixCacheManager:
matche_nodes = []
has_modified_gpu_lru_leaf_heap = False
has_modified_cpu_lru_leaf_heap = False
prefix_block_key = []
with self.cache_status_lock:
while match_token_num < total_token_num:
@@ -1376,7 +1480,10 @@ class PrefixCacheManager:
end_idx=match_token_num + block_size,
mm_idx=mm_idx,
)
hash_value = self.hash_block_features(token_block, extra_keys)
prefix_block_key.extend(extra_keys)
hash_value = get_hash_str(token_block, prefix_block_key)
prefix_block_key = [hash_value]
if hash_value in current_match_node.children:
child = current_match_node.children[hash_value]
matche_nodes.append(child)
@@ -1476,6 +1583,7 @@ class PrefixCacheManager:
matche_nodes = []
has_modified_gpu_lru_leaf_heap = False
has_modified_cpu_lru_leaf_heap = False
prefix_block_key = []
with self.cache_status_lock:
while match_token_num < total_token_num:
@@ -1483,7 +1591,8 @@ class PrefixCacheManager:
token_num = len(token_block)
if token_num != block_size:
break
hash_value = self.cal_block_hash(token_block)
hash_value = get_hash_str(token_block, prefix_block_key)
prefix_block_key = [hash_value]
if hash_value in current_match_node.children:
child = current_match_node.children[hash_value]
matche_nodes.append(child)
@@ -1515,6 +1624,8 @@ class PrefixCacheManager:
swap_node_ids.append(child.node_id)
match_token_num = match_token_num + block_size
current_match_node = child
# record request cache info
self.cache_info[req_id] = [child, match_token_num]
else:
break
@@ -1577,8 +1688,10 @@ class PrefixCacheManager:
has_unfilled_block = False
current_time = time.time()
input_hash_value = self.hash_block_features(input_ids)
input_hash_value = get_hash_str(input_ids)
gpu_block_ids = request.block_tables[num_cached_tokens // block_size :].copy()
prefix_block_key = [] if last_node.hash_value is None else [last_node.hash_value]
for i in range(num_cached_tokens, can_cache_computed_tokens, block_size):
current_block = input_ids[i : i + block_size]
current_block_size = len(current_block) # 最后一个block可能没填满
@@ -1591,7 +1704,9 @@ class PrefixCacheManager:
end_idx=i + block_size,
mm_idx=mm_idx,
)
hash_value = self.hash_block_features(current_block, extra_keys)
prefix_block_key.extend(extra_keys)
hash_value = get_hash_str(current_block, prefix_block_key)
prefix_block_key = [hash_value]
allocated_block_id = gpu_block_ids.pop(0)
node_id = self.node_id_pool.pop()
unique_node_ids.append(node_id)
@@ -1651,7 +1766,7 @@ class PrefixCacheManager:
gpu_block_ids = gpu_block_ids.copy()
node = last_node
reverved_dec_block_ids = []
input_hash_value = self.cal_block_hash(input_ids)
input_hash_value = get_hash_str(input_ids)
token_num = len(left_input_ids)
if token_num == 0:
@@ -1663,6 +1778,7 @@ class PrefixCacheManager:
unique_node_ids = []
new_last_node = last_node
has_unfilled_block = False
prefix_block_key = [] if last_node.hash_value is None else [last_node.hash_value]
for i in range(0, token_num, block_size):
current_block = left_input_ids[i : i + block_size]
@@ -1670,7 +1786,8 @@ class PrefixCacheManager:
if current_block_size != block_size:
has_unfilled_block = True
else:
hash_value = self.cal_block_hash(current_block)
hash_value = get_hash_str(current_block, prefix_block_key)
prefix_block_key = [hash_value]
allocated_block_id = gpu_block_ids.pop(0)
node_id = self.node_id_pool.pop()
unique_node_ids.append(node_id)
@@ -1764,28 +1881,47 @@ class PrefixCacheManager:
if data is None:
time.sleep(0.001)
continue
(
swap_node_ids,
task_gpu_block_id,
task_cpu_block_id,
event_type,
transfer_task_id,
) = data
length = len(task_gpu_block_id)
for i in range(length):
self._handle_swap_result(
swap_node_ids[i],
task_gpu_block_id[i],
task_cpu_block_id[i],
event_type = data[0]
if event_type.value == CacheStatus.STORAGE2GPU.value:
logger.info(f"recv_data_transfer_result: {data}")
task_id, hash_keys, block_ids = data[1:]
if task_id not in self.storage_prefetch_block_ids:
self.storage_prefetch_block_ids[task_id] = []
saved_block_ids = self.storage_prefetch_block_ids[task_id]
saved_block_ids.append(block_ids)
if len(saved_block_ids) == self.tensor_parallel_size:
self.storage_prefetch_block_ids[task_id] = min(saved_block_ids, key=len)
if task_id in self.task_prefetch_event:
self.task_prefetch_event[task_id].set()
elif event_type.value == CacheStatus.GPU2STORAGE.value:
logger.info(f"recv_data_transfer_result: {data}")
task_id, hash_keys, block_ids = data[1:]
if task_id in self.task_write_back_event:
self.task_write_back_event[task_id].set()
else:
(
event_type,
transfer_task_id,
swap_node_ids,
task_gpu_block_id,
task_cpu_block_id,
) = data
length = len(task_gpu_block_id)
for i in range(length):
self._handle_swap_result(
swap_node_ids[i],
task_gpu_block_id[i],
task_cpu_block_id[i],
event_type,
)
if transfer_task_id in self.task_swapping_event:
self.task_swapping_event[transfer_task_id].set()
logger.info(
f"recv_data_transfer_result: transfer_task_id {transfer_task_id}: "
+ f"task_node_ids {swap_node_ids} task_gpu_block_id {task_gpu_block_id} "
+ f"task_cpu_block_id {task_cpu_block_id} event_type {event_type} done"
)
if transfer_task_id in self.task_swapping_event:
self.task_swapping_event[transfer_task_id].set()
logger.info(
f"recv_data_transfer_result: transfer_task_id {transfer_task_id}: "
+ f"task_node_ids {swap_node_ids} task_gpu_block_id {task_gpu_block_id} "
+ f"task_cpu_block_id {task_cpu_block_id} event_type {event_type} done"
)
except Exception as e:
logger.warning(f"recv_data_transfer_result: error: {e}, {str(traceback.format_exc())}")
raise e
@@ -14,7 +14,21 @@
# limitations under the License.
"""
from .ipc_cache_transfer import IPCCommManager
from fastdeploy.platforms import current_platform
from .kvcache_storage import KVCacheStorage
from .mooncake_store import MooncakeStore
from .rdma_cache_transfer import RDMACommManager
__all__ = ["IPCCommManager", "RDMACommManager"]
if current_platform.is_cuda():
from .ipc_cache_transfer import IPCCommManager
else:
IPCCommManager = None
__all__ = [
"IPCCommManager",
"RDMACommManager",
"KVCacheStorage",
"MooncakeStore",
]
@@ -0,0 +1,97 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from abc import ABC, abstractmethod
from typing import Any, List, Optional
import paddle
from fastdeploy.utils import get_logger
logger = get_logger("cache_storage", "cache_storage.log")
class KVCacheStorage(ABC):
    """
    Abstract key-value interface for storing and retrieving KV cache.

    Concrete backends (e.g. MooncakeStore) implement these primitives. Values
    are addressed by string keys; ``target_location``/``target_size`` describe
    the memory the backend should read from or write into — backend specific,
    e.g. raw buffer pointers and byte sizes for zero-copy stores.
    """

    @abstractmethod
    def get(
        self,
        key: str,
        target_location: Optional[Any] = None,
        target_size: Optional[Any] = None,
    ) -> paddle.Tensor | None:
        """
        Retrieve the value associated with the given key.
        Returns None if the key does not exist.
        """
        pass

    @abstractmethod
    def batch_get(
        self,
        keys: List[str],
        target_locations: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> List[paddle.Tensor | None]:
        """
        Retrieve values for multiple keys in one call.
        Returns a list of tensors or None for each key.
        """
        pass

    @abstractmethod
    def set(
        self,
        key: str,
        target_location: Optional[Any] = None,
        target_size: Optional[Any] = None,
    ) -> bool:
        """
        Store the value associated with the given key.

        There is no explicit value argument: the data to store is described by
        ``target_location``/``target_size`` (the backend reads from there).
        Returns True if the operation was successful, False otherwise.
        """
        pass

    @abstractmethod
    def batch_set(
        self,
        keys: List[str],
        target_locations: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> bool:
        """
        Store multiple key-value pairs in one call.
        Returns True if all operations were successful, False otherwise.
        """
        pass

    @abstractmethod
    def exists(self, keys: List[str]) -> bool:
        """
        Check whether the given keys exist in the storage.

        NOTE(review): annotated as bool, but implementations may return a
        per-key mapping (e.g. {key: True|False}) — confirm the intended
        contract against concrete backends.
        """
        pass

    @abstractmethod
    def clear(self) -> bool:
        """
        Clear all keys in storage.
        """
        pass
@@ -0,0 +1,19 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from .mooncake_store import MooncakeStore
__all__ = ["MooncakeStore"]
@@ -0,0 +1,320 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import json
import os
import time
import traceback
import uuid
from dataclasses import dataclass
from typing import Any, List, Optional
from fastdeploy.cache_manager.transfer_factory.kvcache_storage import (
KVCacheStorage,
logger,
)
from fastdeploy.cache_manager.transfer_factory.utils import get_rdma_nics
from fastdeploy.platforms import current_platform
DEFAULT_GLOBAL_SEGMENT_SIZE = 1024 * 1024 * 1024  # 1 GiB
DEFAULT_LOCAL_BUFFER_SIZE = 128 * 1024 * 1024  # 128MB


@dataclass
class MooncakeStoreConfig:
    """Connection and transport settings for a Mooncake distributed store."""

    local_hostname: str
    metadata_server: str
    global_segment_size: int
    local_buffer_size: int
    protocol: str
    rdma_devices: str
    master_server_addr: str

    @staticmethod
    def create() -> "MooncakeStoreConfig":
        """Load the config from a JSON file or environment variables.

        When MOONCAKE_CONFIG_PATH is set, settings are read from that JSON
        file; otherwise from MOONCAKE_* environment variables. Size fields are
        coerced to int because environment variable values are always strings.

        Raises:
            FileNotFoundError: if MOONCAKE_CONFIG_PATH points to a missing file.
        """
        config = {}
        file_path = os.getenv("MOONCAKE_CONFIG_PATH")
        if file_path is None:
            local_hostname = os.environ.get("MOONCAKE_LOCAL_HOSTNAME")
            metadata_server = os.environ.get("MOONCAKE_METADATA_SERVER")
            # BUGFIX: os.environ values are strings, so without int() the size
            # fields silently become str whenever the env var is set.
            global_segment_size = int(os.environ.get("MOONCAKE_GLOBAL_SEGMENT_SIZE", DEFAULT_GLOBAL_SEGMENT_SIZE))
            local_buffer_size = int(os.environ.get("MOONCAKE_LOCAL_BUFFER_SIZE", DEFAULT_LOCAL_BUFFER_SIZE))
            protocol = os.environ.get("MOONCAKE_PROTOCOL", "rdma")
            rdma_devices = os.environ.get("MOONCAKE_RDMA_DEVICES", "")
            master_server_addr = os.environ.get("MOONCAKE_MASTER_SERVER_ADDR")
        else:
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"File path {file_path} for creating MooncakeStoreConfig does not exist.")
            with open(file_path) as fin:
                config = json.load(fin)
            local_hostname = config.get("local_hostname")
            metadata_server = config.get("metadata_server")
            # int() also tolerates JSON configs that quote the sizes as strings.
            global_segment_size = int(config.get("global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE))
            local_buffer_size = int(config.get("local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE))
            protocol = config.get("protocol", "rdma")
            rdma_devices = config.get("rdma_devices", "")
            master_server_addr = config.get("master_server_addr")

        if rdma_devices == "" and current_platform.is_cuda():
            # FIXME: use auto-select NICs in MooncakeStore will raise error and roll back to using TCP
            rdma_devices = get_rdma_nics()
            logger.info(f"No RDMA devices specified, defaulting to all available devices: {rdma_devices}")

        return MooncakeStoreConfig(
            local_hostname=local_hostname,
            metadata_server=metadata_server,
            global_segment_size=global_segment_size,
            local_buffer_size=local_buffer_size,
            protocol=protocol,
            rdma_devices=rdma_devices,
            master_server_addr=master_server_addr,
        )

    def select_rdma_device(self, tp_rank):
        """Restrict rdma_devices to the single NIC chosen by tensor-parallel rank."""
        device_list = self.rdma_devices.split(",")
        device_index = tp_rank % len(device_list)
        self.rdma_devices = device_list[device_index]
class MooncakeStore(KVCacheStorage):
    """
    KVCacheStorage backend built on the Mooncake distributed store.

    Data moves through zero-copy batch APIs: callers pass raw buffer pointers
    and byte sizes, and the store reads/writes those buffers directly.
    """

    def __init__(self, tp_rank=None):
        """
        Connect to the Mooncake store using MooncakeStoreConfig and warm it up.

        Args:
            tp_rank: tensor-parallel rank; when given, a single RDMA device is
                selected for this rank via MooncakeStoreConfig.select_rdma_device.
        Raises:
            ImportError: if the mooncake package is not installed.
            RuntimeError: if store setup returns a non-zero code.
        """
        super().__init__()
        self.tp_rank = tp_rank
        try:
            from mooncake.store import MooncakeDistributedStore
        except ImportError as e:
            # BUGFIX: adjacent string literals were concatenated without a
            # separating space ("...htmlto run...").
            raise ImportError(
                "Please install mooncake store by following the instructions at "
                "https://kvcache-ai.github.io/Mooncake/python-api-reference/mooncake-store.html "
                "to run Fastdeploy with mooncake store."
            ) from e

        try:
            self.store = MooncakeDistributedStore()
            self.config = MooncakeStoreConfig.create()
            if self.tp_rank is not None:
                self.config.select_rdma_device(self.tp_rank)
            logger.info(f"Mooncake Configuration loaded, {self.config}.")
            ret_code = self.store.setup(
                local_hostname=self.config.local_hostname,
                metadata_server=self.config.metadata_server,
                global_segment_size=self.config.global_segment_size,
                local_buffer_size=self.config.local_buffer_size,
                protocol=self.config.protocol,
                rdma_devices=self.config.rdma_devices,
                master_server_addr=self.config.master_server_addr,
            )
            if ret_code != 0:
                logger.error(f"failed to setup mooncake store, error code: {ret_code}")
                raise RuntimeError(f"failed to setup mooncake store, error code: {ret_code}")
            logger.info("Connect to Mooncake store successfully.")
            self.warmup()
            logger.info("Mooncake store warmup successfully.")
        except Exception as e:
            logger.error(f"Mooncake store initialization failed: {e}, traceback: {traceback.format_exc()}")
            raise

    def warmup(self):
        """Round-trip a throwaway 1 MB object to verify the store connection works."""
        warmup_key = "fastdeploy_mooncake_store_warmup_key" + str(uuid.uuid4())
        warmup_value = bytes(1 * 1024 * 1024)  # 1 MB
        self.store.put(warmup_key, warmup_value)
        assert self.store.is_exist(warmup_key) == 1
        self.store.get(warmup_key)
        self.store.remove(warmup_key)

    def register_buffer(self, buffer_ptr, buffer_size) -> None:
        """
        Register a local memory region with the store for zero-copy transfers.

        A non-zero return code is logged but not raised, matching the original
        best-effort behavior; a TypeError (bad pointer/size types) is re-raised.
        """
        try:
            ret_code = self.store.register_buffer(buffer_ptr, buffer_size)
            if ret_code:
                logger.error(f"failed to register buffer, error code: {ret_code}")
        except TypeError as err:
            logger.error("Failed to register buffer to Mooncake Store: %s", err)
            raise TypeError("Mooncake Store Register Buffer Error.") from err

    def set(
        self,
        key,
        target_location: Optional[List[int]] = None,
        target_size: Optional[List[int]] = None,
    ) -> List[int]:
        """Single-key put is not implemented; use batch_set instead."""
        pass

    def batch_set(
        self,
        keys: List[str],
        target_locations: Optional[List[int]] = None,
        target_sizes: Optional[List[int]] = None,
    ) -> List[int]:
        """
        Batch put multiple objects into the store.

        Args:
            keys (list): list of object names to be stored
            target_locations (list): list of memory locations where the data are stored
            target_sizes (list): list of byte sizes corresponding to each object
        Return:
            List[int]: List of status codes for each operation (0 = success, negative = error)
        Raises:
            ValueError: if the three lists differ in length or are empty.
        """
        if not (len(keys) == len(target_locations) == len(target_sizes)):
            err_msg = "The length of keys, target_location and target_sizes must match."
            logger.error(err_msg)
            raise ValueError(err_msg)
        if len(keys) == 0:
            err_msg = "The length of keys, target_location and target_sizes must be greater than zero"
            logger.error(err_msg)
            raise ValueError(err_msg)
        return self._put_batch_zero_copy_impl(keys, target_locations, target_sizes)

    def get(
        self,
        key,
        target_location: Optional[Any] = None,
        target_size: Optional[Any] = None,
    ) -> List[int]:
        """Single-key get is not implemented; use batch_get instead."""
        pass

    def batch_get(
        self,
        keys: List[str],
        target_locations: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> List[int]:
        """
        Batch get multiple objects from the store.

        Args:
            keys (list): list of object names to be fetched
            target_locations (list): list of memory locations where the data should be stored
            target_sizes (list): list of byte sizes corresponding to each object
        Returns:
            List[int]: List of bytes read for each operation (positive = success, negative = error)
        Raises:
            ValueError: if the three lists differ in length or are empty.
        """
        if not (len(keys) == len(target_locations) == len(target_sizes)):
            err_msg = "The length of keys, target_locations and target_sizes must match."
            logger.error(err_msg)
            raise ValueError(err_msg)
        if len(keys) == 0:
            err_msg = "The length of keys, target_locations and target_sizes must be greater than zero"
            logger.error(err_msg)
            raise ValueError(err_msg)
        return self._get_batch_zero_copy_impl(keys, target_locations, target_sizes)

    def exists(self, keys: List[str]) -> dict:
        """
        Check existence of multiple objects in a single batch operation.

        Args:
            keys (list): list of object names to be checked
        Returns:
            dict: dictionary mapping each key to its existence status {key: True|False}
        """
        tic = time.time()
        result = {k: v for k, v in zip(keys, self.store.batch_is_exist(keys))}
        cost_time = (time.time() - tic) * 1000
        logger.debug(f"The exists fun processes {len(keys)} objects, cost_time: {cost_time:.3f}ms")
        return result

    def delete(self, key, timeout=5) -> bool:
        """
        Remove a key, retrying once per second for up to ``timeout`` seconds.

        Returns True on successful removal, False if every attempt failed.
        """
        while timeout:
            result = self.store.remove(key)
            if result == 0:
                logger.info("Successfully removed")
                return True
            else:
                time.sleep(1)
                timeout -= 1
        return False

    def close(self):
        """No-op close."""
        # MooncakeDistributedStore will automatically call the destructor, so
        # it is unnecessary to close it manually.
        pass

    def clear(self) -> bool:
        """
        clear all the objects in the store
        """
        count = self.store.remove_all()
        logger.info(f"Removed {count} objects")
        return True

    def _put_batch_zero_copy_impl(
        self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
    ) -> List[int]:
        """Zero-copy batch put; returns per-key status codes (0 = success)."""
        try:
            tic = time.time()
            result = self.store.batch_put_from(key_strs, buffer_ptrs, buffer_sizes)
            # List[int]: List of status codes for each operation (0 = success, negative = error)
            cost_time = time.time() - tic
            total_num = len(key_strs)
            success_num = result.count(0)
            if success_num == total_num:
                # BUGFIX: added the missing separator space between the
                # concatenated f-string fragments.
                logger.debug(
                    f"Put all data into Mooncake Store successfully. "
                    f"success_num: {success_num}, cost_time: {cost_time:.6f}s"
                )
            else:
                logger.error(
                    f"Some of the data was not put into Mooncake Store. "
                    f"total_num: {total_num}, success_num: {success_num}, cost_time: {cost_time:.6f}s"
                )
            if success_num > 0:
                # Sum bytes of successful puts only; 1073741824 = 1 GiB.
                total_bytes = sum(bi for ri, bi in zip(result, buffer_sizes) if ri == 0)
                total_gb = total_bytes / 1073741824
                speed = total_gb / cost_time
                logger.info(f"Put data into Mooncake Store, total_gb: {total_gb:.6f}GB, speed: {speed:.6f}GB/s")
            return result
        except Exception as err:
            logger.error("Failed to put data into Mooncake Store: %s", err)
            raise

    def _get_batch_zero_copy_impl(
        self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
    ) -> List[int]:
        """Zero-copy batch get; returns per-key byte counts (positive = success)."""
        try:
            tic = time.time()
            result = self.store.batch_get_into(key_strs, buffer_ptrs, buffer_sizes)
            # List[int]: List of bytes read for each operation (positive = success, negative = error)
            cost_time = time.time() - tic
            total_num = len(key_strs)
            success_num = sum(x > 0 for x in result)
            if success_num == total_num:
                logger.debug(
                    f"Get all data from Mooncake Store successfully. "
                    f"success_num: {success_num}, cost_time: {cost_time:.6f}s"
                )
            else:
                # BUGFIX: added the missing separator space between the
                # concatenated f-string fragments.
                logger.error(
                    f"Some of the data was not get from Mooncake Store. "
                    f"total_num:{total_num}, success_num: {success_num}, cost_time: {cost_time:.6f}s"
                )
            if success_num > 0:
                # Sum bytes of successful gets only; 1073741824 = 1 GiB.
                total_bytes = sum(bi for ri, bi in zip(result, buffer_sizes) if ri > 0)
                total_gb = total_bytes / 1073741824
                speed = total_gb / cost_time
                logger.info(f"Get data from Mooncake Store, total_gb: {total_gb:.6f}GB, speed: {speed:.6f}GB/s")
            return result
        except Exception as err:
            logger.error("Failed to get data from Mooncake Store: %s", err)
            raise
@@ -16,6 +16,7 @@
import traceback
from fastdeploy.cache_manager.transfer_factory.utils import get_rdma_nics
from fastdeploy.utils import get_logger
logger = get_logger("cache_messager", "cache_messager.log")
@@ -40,7 +41,6 @@ class RDMACommManager:
prefill_tp_idx,
):
try:
import importlib
import os
import subprocess
@@ -66,28 +66,9 @@ class RDMACommManager:
logger.info("Setting environment variable: export KVCACHE_GDRCOPY_FLUSH_ENABLE=1")
if os.getenv("KVCACHE_RDMA_NICS", "") == "" and current_platform.is_cuda():
res = importlib.resources.files("fastdeploy.cache_manager.transfer_factory") / "get_rdma_nics.sh"
get_rdma_nics = None
with importlib.resources.as_file(res) as path:
get_rdma_nics = str(path)
nic_type = current_platform.device_name
command = ["bash", get_rdma_nics, nic_type]
result = subprocess.run(
command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
check=False,
)
logger.info(f"get_rdma_nics command: {command}")
logger.info(f"get_rdma_nics output: {result.stdout}")
if result.returncode != 0:
raise RuntimeError(f"Failed to execute script `get_rdma_nics.sh`: {result.stderr.strip()}")
env_name, env_value = result.stdout.strip().split("=")
assert env_name == "KVCACHE_RDMA_NICS"
os.environ[env_name] = env_value
logger.info(f"Setting environment variable: export {env_name}={env_value}")
rdma_nics = get_rdma_nics()
os.environ["KVCACHE_RDMA_NICS"] = rdma_nics
logger.info(f"Setting environment variable: export KVCACHE_RDMA_NICS={rdma_nics}")
except Exception as e:
raise RuntimeError(f"Failed to initialize RDMA environment! {e} {traceback.format_exc()}")
@@ -0,0 +1,49 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import importlib
import subprocess
from fastdeploy.platforms import current_platform
from fastdeploy.utils import get_logger
logger = get_logger("cache_messager", "cache_messager.log")
def get_rdma_nics():
    """
    Detect the RDMA NICs available on this machine.

    Runs the bundled ``get_rdma_nics.sh`` helper for the current platform's
    device type and parses its single-line output of the form
    ``KVCACHE_RDMA_NICS=<value>``.

    Returns:
        str: The NIC list (the value part of the script output).

    Raises:
        RuntimeError: If the helper script exits with a non-zero status.
        ValueError: If the script output is not the expected variable.
    """
    # `import importlib` alone does not guarantee the `resources` submodule
    # is loaded; import it explicitly so attribute access cannot fail.
    import importlib.resources

    res = importlib.resources.files("fastdeploy.cache_manager.transfer_factory") / "get_rdma_nics.sh"
    # Run the script *inside* the as_file() context: for resources extracted
    # from a zip/wheel the materialized path may be deleted on context exit.
    with importlib.resources.as_file(res) as path:
        nic_type = current_platform.device_name
        command = ["bash", str(path), nic_type]
        result = subprocess.run(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            check=False,
        )
    logger.info(f"get_rdma_nics command: {command}")
    logger.info(f"get_rdma_nics output: {result.stdout}")
    if result.returncode != 0:
        raise RuntimeError(f"Failed to execute script `get_rdma_nics.sh`: {result.stderr.strip()}")
    # Split only on the first '=' so NIC values containing '=' stay intact.
    env_name, env_value = result.stdout.strip().split("=", 1)
    if env_name != "KVCACHE_RDMA_NICS":
        raise ValueError(f"Unexpected variable name: {env_name}, expected 'KVCACHE_RDMA_NICS'")
    return env_value
+3 -5
View File
@@ -1296,6 +1296,9 @@ class CacheConfig:
self.max_processor_cache = None
self.enable_output_caching = False
self.disable_chunked_mm_input = False
self.kvcache_storage_backend = None
self.write_policy = None
for key, value in args.items():
if hasattr(self, key):
setattr(self, key, value)
@@ -1304,11 +1307,6 @@ class CacheConfig:
self.rdma_comm_ports = parse_ports(self.rdma_comm_ports)
self.pd_comm_port = parse_ports(self.pd_comm_port)
if self.swap_space is None:
self.enable_hierarchical_cache = False
else:
self.enable_hierarchical_cache = True
if self.model_cfg is not None:
if self.model_cfg.quantization is not None and isinstance(self.model_cfg.quantization, dict):
self.cache_dtype = self.model_cfg.quantization.get("kv_cache_quant_type", self.cache_dtype)
+32
View File
@@ -230,6 +230,14 @@ class EngineArgs:
"""
Port for cache queue.
"""
kvcache_storage_backend: str = None
"""
The storage backend for kvcache storage. If set, it will use the kvcache storage backend.
"""
write_policy: str = "write_through"
"""
The policy of write cache to storage.
"""
# System configuration parameters
use_warmup: int = 0
@@ -557,6 +565,14 @@ class EngineArgs:
if "PaddleOCR" in get_model_architecture(self.model, self.model_config_name):
envs.FD_ENABLE_MAX_PREFILL = 1
if self.kvcache_storage_backend is not None:
if not self.enable_prefix_caching:
raise NotImplementedError("kvcache_storage_backend is only supported when enable_prefix_caching=True")
if envs.ENABLE_V1_KVCACHE_SCHEDULER == 0:
raise NotImplementedError(
"kvcache_storage_backend is only supported when ENABLE_V1_KVCACHE_SCHEDULER=1"
)
self.post_init_all_ports()
def post_init_all_ports(self):
@@ -1018,6 +1034,22 @@ class EngineArgs:
help="Static decoding blocks num.",
)
cache_group.add_argument(
"--kvcache-storage-backend",
type=nullable_str,
choices=["mooncake"],
default=EngineArgs.kvcache_storage_backend,
help="The storage backend for kvcache storage. Leave empty to disable.",
)
cache_group.add_argument(
"--write-policy",
type=str,
choices=["write_through"],
default=EngineArgs.write_policy,
help="KVCache write policy",
)
# Cluster system parameters group
system_group = parser.add_argument_group("System Configuration")
system_group.add_argument(
+7
View File
@@ -524,6 +524,13 @@ class RequestMetrics:
speculate_metrics: Optional[SpeculateMetrics] = None
# cache related
gpu_cache_token_num: Optional[int] = 0
cpu_cache_token_num: Optional[int] = 0
storage_cache_token_num: Optional[int] = 0
gpu_cpu_cache_prepare_time: Optional[float] = None
storage_cache_prepare_time: Optional[float] = None
def __post_init__(self):
if self.arrival_time is None:
self.arrival_time = time.time()
+97 -41
View File
@@ -670,7 +670,8 @@ class ResourceManagerV1(ResourceManager):
request, self.config.cache_config.block_size, request.num_computed_tokens
)
req_index += 1
# schedule the WAITING requests.
# Second, schedule the WAITING requests.
if not preempted_reqs:
skip_requests: list[Request] = []
while self.waiting and token_budget > 0:
@@ -699,10 +700,7 @@ class ResourceManagerV1(ResourceManager):
self._update_mm_hashes(request)
# Enable prefix caching
if self.config.cache_config.enable_prefix_caching:
if (
self.config.cache_config.enable_hierarchical_cache
and self.cache_manager.num_cpu_blocks > 0
):
if self.cache_manager.num_cpu_blocks > 0:
if not self.cache_manager.can_allocate_gpu_blocks(
(request.need_prefill_tokens + self.config.cache_config.block_size - 1)
// self.config.cache_config.block_size
@@ -713,12 +711,21 @@ class ResourceManagerV1(ResourceManager):
self._free_blocks(request)
break
# Allocate blocks for the tokens that does not hit cache
num_new_tokens = self._get_num_new_tokens(request, token_budget)
num_new_block = self.get_new_block_nums(request, num_new_tokens)
# Allocate blocks to prefill
if self.cache_manager.can_allocate_gpu_blocks(num_new_block):
if not request.get("skip_allocate", False):
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block))
extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(num_new_block)
request.block_tables.extend(extra_gpu_block_ids)
if (
self.config.cache_config.enable_prefix_caching
and self.config.cache_config.kvcache_storage_backend
and num_new_tokens >= self.config.cache_config.block_size
):
matched_block_ids = self.get_storage_cached_blocks(request, extra_gpu_block_ids)
num_new_tokens -= len(matched_block_ids) * self.config.cache_config.block_size
self.waiting.popleft()
self.running.append(request)
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
@@ -744,10 +751,7 @@ class ResourceManagerV1(ResourceManager):
request.num_total_tokens
) # Before preempted task rescheduled, preempted task has been sent to engine, no more tokens are output, here num_total_tokens should be static and correct
if self.config.cache_config.enable_prefix_caching:
if (
self.config.cache_config.enable_hierarchical_cache
and self.cache_manager.num_cpu_blocks > 0
):
if self.cache_manager.num_cpu_blocks > 0:
if not self.cache_manager.can_allocate_gpu_blocks(
(request.need_prefill_tokens + self.config.cache_config.block_size - 1)
// self.config.cache_config.block_size
@@ -757,12 +761,22 @@ class ResourceManagerV1(ResourceManager):
if not success:
self._free_blocks(request)
break
# Allocate blocks for the tokens that does not hit cache
num_new_tokens = self._get_num_new_tokens(request, token_budget)
num_new_block = self.get_new_block_nums(request, num_new_tokens)
# Allocate blocks to prefill
if self.cache_manager.can_allocate_gpu_blocks(num_new_block):
if not request.get("skip_allocate", False):
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block))
extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(num_new_block)
request.block_tables.extend(extra_gpu_block_ids)
if (
self.config.cache_config.enable_prefix_caching
and self.config.cache_config.kvcache_storage_backend
and num_new_tokens >= self.config.cache_config.block_size
):
matched_block_ids = self.get_storage_cached_blocks(request, extra_gpu_block_ids)
num_new_tokens -= len(matched_block_ids) * self.config.cache_config.block_size
self.waiting.popleft()
self.running.append(request)
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
@@ -916,28 +930,61 @@ class ResourceManagerV1(ResourceManager):
)
request.num_cached_tokens = matched_token_num
request.gpu_cache_token_num = hit_info["gpu_match_token_num"]
request.cpu_cache_token_num = hit_info["cpu_match_token_num"]
request.cache_info = (matched_block_num, no_cache_block_num)
request.metrics.gpu_cache_token_num = hit_info["gpu_match_token_num"]
request.metrics.cpu_cache_token_num = hit_info["cpu_match_token_num"]
request.cache_info = [matched_block_num, no_cache_block_num]
request.block_tables = common_block_ids
request.skip_allocate = False
# Report the number of cached tokens to Prometheus metrics
main_process_metrics.prefix_cache_token_num.inc(matched_token_num)
main_process_metrics.prefix_gpu_cache_token_num.inc(request.gpu_cache_token_num)
main_process_metrics.prefix_cpu_cache_token_num.inc(request.cpu_cache_token_num)
main_process_metrics.prefix_gpu_cache_token_num.inc(request.metrics.gpu_cache_token_num)
main_process_metrics.prefix_cpu_cache_token_num.inc(request.metrics.gpu_cache_token_num)
if matched_token_num == request.need_prefill_tokens:
request.num_computed_tokens = matched_token_num - self.config.cache_config.block_size
request.skip_allocate = True
else:
request.num_computed_tokens = matched_token_num
request.cache_prepare_time = time.time() - cache_prepare_time
request.metrics.gpu_cpu_cache_prepare_time = time.time() - cache_prepare_time
return True
except Exception as e:
llm_logger.error(f"prefix match blocks error: {e}, {str(traceback.format_exc())} waiting reschedule...")
return False
def get_storage_cached_blocks(self, request: Request, extra_gpu_block_ids=None):
    """
    Match and prefetch cached blocks for `request` from the storage backend.

    Args:
        request: The request whose prefix may hit the storage cache.
        extra_gpu_block_ids: Freshly allocated GPU block ids that matched
            storage blocks can be fetched into. Defaults to an empty list.

    Returns:
        list: GPU block ids matched (and prefetched) from the storage
        backend; empty list on any failure.

    TODO: merge this function into get_prefix_cached_blocks
    """
    # Avoid the shared-mutable-default-argument pitfall.
    if extra_gpu_block_ids is None:
        extra_gpu_block_ids = []
    # Bind req_id before the try so the except handler can always log it.
    req_id = request.request_id
    try:
        tic = time.time()
        llm_logger.debug(f"get_storage_cached_blocks start process req {req_id}")
        matched_block_ids = self.cache_manager.request_match_storage_blocks(request, extra_gpu_block_ids)
        llm_logger.debug(
            f"matched {len(matched_block_ids)} blocks from storage for req_id:{req_id}, "
            f"cost_time: {time.time() - tic:.6f}s"
        )
        matched_token_num = len(matched_block_ids) * self.config.cache_config.block_size
        request.metrics.storage_cache_token_num = matched_token_num
        request.num_computed_tokens += matched_token_num
        if request.num_computed_tokens == request.need_prefill_tokens:
            # Back off one block on a full-prefix hit (mirrors the handling in
            # get_prefix_cached_blocks) — presumably so prefill still computes
            # at least one block; confirm against that function.
            request.num_computed_tokens = request.num_computed_tokens - self.config.cache_config.block_size
        request.metrics.storage_cache_prepare_time = time.time() - tic
        request.cache_info[0] += len(matched_block_ids)  # matched_block_num
        request.cache_info[1] -= len(matched_block_ids)  # no_cache_block_num
        main_process_metrics.prefix_cache_token_num.inc(matched_token_num)
        # TODO: main_process_metrics.prefix_storage_cache_token_num.inc(matched_token_num)
        return matched_block_ids
    except Exception as e:
        llm_logger.error(
            f"get_storage_cached_blocks process req {req_id}, error: {e}, {str(traceback.format_exc())} "
        )
        return []
def add_request(self, request: Request) -> None:
with self.lock:
self.apply_async_preprocess(request)
@@ -980,7 +1027,7 @@ class ResourceManagerV1(ResourceManager):
) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num # consider for mtp, plus enc_dec_block_num
if self.config.cache_config.enable_prefix_caching:
# Enable prefix caching
if self.config.cache_config.enable_hierarchical_cache and self.cache_manager.num_cpu_blocks > 0:
if self.cache_manager.num_cpu_blocks > 0:
if not self.cache_manager.can_allocate_gpu_blocks(
need_prealloc_prefill_blocks
): # to prevent block allocation for matching in hierarchical cache and cause dead lock
@@ -992,7 +1039,10 @@ class ResourceManagerV1(ResourceManager):
need_extra_prefill_blocks = need_prealloc_prefill_blocks - request.cache_info[0]
if self.cache_manager.can_allocate_gpu_blocks(need_extra_prefill_blocks):
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(need_extra_prefill_blocks))
extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(need_extra_prefill_blocks)
if self.config.cache_config.enable_prefix_caching:
self.get_storage_cached_blocks(request, extra_gpu_block_ids)
request.block_tables.extend(extra_gpu_block_ids)
allocated_position = self.get_available_position()
request.idx = allocated_position
self.tasks_list[request.idx] = request
@@ -1113,39 +1163,45 @@ class ResourceManagerV1(ResourceManager):
def finish_requests(self, request_ids: Union[str, Iterable[str]]):
llm_logger.info(f"recycle resources for requests: {request_ids}")
try:
if isinstance(request_ids, str):
request_ids = (request_ids,)
else:
request_ids = set(request_ids)
need_postprocess_reqs = []
with self.lock:
if isinstance(request_ids, str):
request_ids = (request_ids,)
else:
request_ids = set(request_ids)
for req_id in request_ids:
request = self.requests.get(req_id)
if request is None:
# Invalid request ID.
continue
if request in self.running: # normally run and finished
if request in self.waiting:
llm_logger.error(f"request {request.request_id} scheduled into waiting list, after finished")
continue
if request in self.running:
self.running.remove(request)
request.status = RequestStatus.FINISHED
try:
self._free_blocks(request)
except Exception as e:
llm_logger.warning(f"release block failed {req_id}: {e}")
if (
request.request_id in self.to_be_rescheduled_request_id_set
): # finished after preempted, blocks have been recycled.
self.to_be_rescheduled_request_id_set.remove(
request.request_id
) # just remove from to_be_rescheduled_request_id_set
if (
request in self.waiting
): # after finished, this request still scheduled from preempted to waiting, unexpected error, should not be here
raise RuntimeError(f"request {request.request_id} scheduled into waiting list, after finished")
need_postprocess_reqs.append(request)
if request.request_id in self.to_be_rescheduled_request_id_set:
# finished after preempted, blocks have been recycled.
self.to_be_rescheduled_request_id_set.remove(request.request_id)
self.tasks_list[request.idx] = None
self.stop_flags[request.idx] = True
del self.requests[req_id]
if req_id in self.req_dict:
del self.req_dict[req_id]
# Do not block the main thread here
for req in need_postprocess_reqs:
self.cache_manager.write_cache_to_storage(req)
with self.lock:
for req in need_postprocess_reqs:
try:
self._free_blocks(req)
except Exception as e:
llm_logger.warning(f"release block failed {req.request_id}: {e}")
except Exception as e:
llm_logger.error(f"finish_request err: {e}, {str(traceback.format_exc())}")
finally:
@@ -102,6 +102,12 @@ class EngineCacheQueue:
self.swap_to_gpu_barrier2_init = [
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
]
self.swap_storage_to_gpu_barrier_init = [
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
]
self.swap_to_storage_barrier_init = [
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
]
# Register shared objects with proxy types
QueueManager.register(
@@ -148,7 +154,14 @@ class EngineCacheQueue:
"get_swap_to_gpu_barrier2",
callable=lambda idx: self.swap_to_gpu_barrier2_init[idx],
)
QueueManager.register(
"get_swap_storage_to_gpu_barrier",
callable=lambda idx: self.swap_storage_to_gpu_barrier_init[idx],
)
QueueManager.register(
"get_swap_to_storage_barrier",
callable=lambda idx: self.swap_to_storage_barrier_init[idx],
)
self.manager: BaseManager = QueueManager(address=self.address, authkey=self.authkey)
self.manager.start()
@@ -175,6 +188,8 @@ class EngineCacheQueue:
QueueManager.register("get_swap_to_cpu_barrier2")
QueueManager.register("get_swap_to_gpu_barrier1")
QueueManager.register("get_swap_to_gpu_barrier2")
QueueManager.register("get_swap_storage_to_gpu_barrier")
QueueManager.register("get_swap_to_storage_barrier")
self.manager = QueueManager(address=self.address, authkey=self.authkey)
self._connect_with_retry()
@@ -194,6 +209,8 @@ class EngineCacheQueue:
self.swap_to_cpu_barrier2 = self.manager.get_swap_to_cpu_barrier2(self.local_data_parallel_id)
self.swap_to_gpu_barrier1 = self.manager.get_swap_to_gpu_barrier1(self.local_data_parallel_id)
self.swap_to_gpu_barrier2 = self.manager.get_swap_to_gpu_barrier2(self.local_data_parallel_id)
self.swap_storage_to_gpu_barrier = self.manager.get_swap_storage_to_gpu_barrier(self.local_data_parallel_id)
self.swap_to_storage_barrier = self.manager.get_swap_to_storage_barrier(self.local_data_parallel_id)
self.total_num: int = (1 << self.num_client) - 1
if not is_server:
@@ -241,7 +258,7 @@ class EngineCacheQueue:
self.task_lock.acquire()
self.task_sync_value.set(0)
self.transfer_task_queue.append(item)
logger.info(f"put_transfer_task: put swap task {item[-1]} to queue successful")
logger.info(f"put_transfer_task: put swap task {item} to queue successful")
self.task_lock.release()
def get_transfer_task(self):
+13
View File
@@ -17,6 +17,7 @@
import argparse
import asyncio
import codecs
import hashlib
import importlib
import json
import logging
@@ -824,6 +825,18 @@ def retrive_model_from_server(model_name_or_path, revision="master"):
return model_name_or_path
def get_hash_str(token_ids: List[int], extra_keys: Optional[Any] = None) -> str:
    """
    Calculate a stable hash value of a token block with additional keys.

    Args:
        token_ids: Input token IDs.
        extra_keys: Additional keys for block identification (e.g. parent
            block hash or multimodal hashes). Defaults to an empty list.

    Returns:
        str: Hex-encoded SHA-256 digest of the pickled
        ``(token_ids, extra_keys)`` pair.
    """
    # Use a None sentinel instead of a mutable `[]` default; substitute []
    # here so the digest stays identical to calls passing an explicit [].
    if extra_keys is None:
        extra_keys = []
    value = (token_ids, extra_keys)
    return hashlib.sha256(pickle.dumps(value)).hexdigest()
def is_list_of(
value: object,
typ: Union[type[T], tuple[type[T], ...]],
@@ -41,6 +41,8 @@ class Args:
create_cache_tensor = False
cache_dtype = "bfloat16"
default_dtype = "bfloat16"
kvcache_storage_backend = None
write_policy = "write_through"
# ==========================
@@ -27,6 +27,7 @@ import pytest
from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
from fastdeploy.cache_manager.prefix_cache_manager import PrefixCacheManager
from fastdeploy.inter_communicator.ipc_signal_const import PrefixTreeStatus
from fastdeploy.utils import get_hash_str
# Metric test double used to track metric updates.
@@ -179,6 +180,8 @@ def _create_manager(
rdma_comm_ports=None,
local_cache_queue_port=9000,
local_rdma_comm_ports=None,
kvcache_storage_backend=None,
write_policy="write_through",
)
model_config = SimpleNamespace(
num_attention_heads=1,
@@ -186,6 +189,7 @@ def _create_manager(
head_dim=1,
_architecture="",
dtype="float16",
max_model_len=128,
)
config = SimpleNamespace(
cache_config=cache_config,
@@ -199,7 +203,7 @@ def _create_manager(
def _make_block_node(manager, node_id, input_ids, *, block_size=2, parent=None, cache_status=CacheStatus.GPU):
parent = parent or manager.radix_tree_root
block_hash = manager.cal_block_hash(input_ids)
block_hash = get_hash_str(input_ids)
node = BlockNode(
node_id,
input_ids,
@@ -878,10 +882,10 @@ class PrefixCacheManagerTest(unittest.TestCase):
manager = _create_manager(num_gpu_blocks=4)
block_size = 2
root = manager.radix_tree_root
gpu_hash = manager.cal_block_hash([1, 2])
gpu_hash = get_hash_str([1, 2])
gpu_node = BlockNode(1, [], 0, 1, 0, block_size, gpu_hash, 0, parent=root)
root.children[gpu_hash] = gpu_node
cpu_hash = manager.cal_block_hash([3, 4])
cpu_hash = get_hash_str([3, 4], extra_keys=[gpu_hash])
cpu_node = BlockNode(2, [], 0, 2, 1, block_size, cpu_hash, 0, parent=gpu_node, cache_status=CacheStatus.CPU)
gpu_node.children[cpu_hash] = cpu_node
manager.gpu_lru_leaf_set.add(gpu_node)
@@ -917,7 +921,7 @@ class PrefixCacheManagerTest(unittest.TestCase):
def test_free_block_ids_async_recycles_gpu_nodes(self):
manager = _create_manager(num_gpu_blocks=4)
node_hash = manager.cal_block_hash([1, 2])
node_hash = get_hash_str([1, 2])
node = BlockNode(10, [1, 2], node_hash, 1, 0, 2, node_hash, 0, parent=manager.radix_tree_root)
node.shared_count = 0
manager.radix_tree_root.children[node_hash] = node
@@ -941,7 +945,7 @@ class PrefixCacheManagerTest(unittest.TestCase):
manager.issue_swap_task = _fake_issue
node_hash = manager.cal_block_hash([3, 4])
node_hash = get_hash_str([3, 4])
node = BlockNode(11, [3, 4], node_hash, 1, 1, 2, node_hash, 0, parent=manager.radix_tree_root)
node.shared_count = 0
manager.radix_tree_root.children[node_hash] = node
@@ -957,9 +961,9 @@ class PrefixCacheManagerTest(unittest.TestCase):
block_size = 2
manager.cache_config.disable_chunked_mm_input = False
input_ids = [1, 2, 3, 4]
hash_input = manager.hash_block_features(input_ids)
hash_first = manager.hash_block_features([1, 2])
hash_second = manager.hash_block_features([3, 4], ["img"])
hash_input = get_hash_str(input_ids)
hash_first = get_hash_str([1, 2])
hash_second = get_hash_str([3, 4], [hash_first, "img"])
node1 = BlockNode(30, input_ids, hash_input, 1, 0, block_size, hash_first, 0, parent=manager.radix_tree_root)
manager.radix_tree_root.children[hash_first] = node1
@@ -1004,9 +1008,9 @@ class PrefixCacheManagerTest(unittest.TestCase):
manager.cache_config.disable_chunked_mm_input = False
block_size = 2
input_ids = [1, 2, 3, 4]
hash_input = manager.hash_block_features(input_ids)
hash_first = manager.hash_block_features([1, 2])
hash_second = manager.hash_block_features([3, 4], ["img"])
hash_input = get_hash_str(input_ids)
hash_first = get_hash_str([1, 2])
hash_second = get_hash_str([3, 4], [hash_first, "img"])
node1 = BlockNode(40, input_ids, hash_input, 1, 0, block_size, hash_first, 0, parent=manager.radix_tree_root)
node2 = BlockNode(
41,
@@ -1045,7 +1049,7 @@ class PrefixCacheManagerTest(unittest.TestCase):
def test_release_block_ids_cleans_request_state(self):
manager = _create_manager(num_gpu_blocks=4)
node = BlockNode(50, [1, 2], 0, 1, 0, 2, manager.cal_block_hash([1, 2]), 0, parent=manager.radix_tree_root)
node = BlockNode(50, [1, 2], 0, 1, 0, 2, get_hash_str([1, 2]), 0, parent=manager.radix_tree_root)
node.cache_status = CacheStatus.GPU
manager.radix_tree_root.children[node.hash_value] = node
req_id = "release-req"
@@ -1061,7 +1065,7 @@ class PrefixCacheManagerTest(unittest.TestCase):
def test_free_cpu_block_ids_eviction(self):
manager = _create_manager(num_gpu_blocks=2, num_cpu_blocks=2)
cpu_node = BlockNode(60, [3, 4], 0, 1, 0, 2, manager.cal_block_hash([3, 4]), 0, parent=manager.radix_tree_root)
cpu_node = BlockNode(60, [3, 4], 0, 1, 0, 2, get_hash_str([3, 4]), 0, parent=manager.radix_tree_root)
cpu_node.cache_status = CacheStatus.CPU
manager.cpu_lru_leaf_heap.append(cpu_node)
manager.cpu_lru_leaf_set.add(cpu_node)
@@ -1070,8 +1074,8 @@ class PrefixCacheManagerTest(unittest.TestCase):
def test_free_nodes_directly_recovers_chain(self):
manager = _create_manager(num_gpu_blocks=4)
parent = BlockNode(70, [1, 2], 0, 1, 0, 2, manager.cal_block_hash([1, 2]), 0, parent=manager.radix_tree_root)
child_hash = manager.cal_block_hash([3, 4])
parent = BlockNode(70, [1, 2], 0, 1, 0, 2, get_hash_str([1, 2]), 0, parent=manager.radix_tree_root)
child_hash = get_hash_str([3, 4])
child = BlockNode(71, [1, 2, 3, 4], 0, 2, 1, 2, child_hash, 0, parent=parent)
parent.children[child_hash] = child
parent.shared_count = 0
@@ -1102,9 +1106,9 @@ class PrefixCacheManagerTest(unittest.TestCase):
manager.cache_config.disable_chunked_mm_input = True
block_size = 2
input_ids = [1, 2, 3, 4]
hash_input = manager.hash_block_features(input_ids)
hash_first = manager.hash_block_features([1, 2])
hash_second = manager.hash_block_features([3, 4], ["img"])
hash_input = get_hash_str(input_ids)
hash_first = get_hash_str([1, 2])
hash_second = get_hash_str([3, 4], ["img"])
node1 = BlockNode(80, input_ids, hash_input, 1, 0, block_size, hash_first, 0, parent=manager.radix_tree_root)
node2 = BlockNode(81, input_ids, hash_input, 2, 1, block_size, hash_second, 0, parent=node1)
manager.radix_tree_root.children[hash_first] = node1
@@ -1144,7 +1148,7 @@ class PrefixCacheManagerTest(unittest.TestCase):
def test_handle_swap_result_updates_status(self):
manager = _create_manager(num_gpu_blocks=4, num_cpu_blocks=2)
node = BlockNode(90, [1], 0, 1, 0, 1, manager.cal_block_hash([1]), 0, parent=manager.radix_tree_root)
node = BlockNode(90, [1], 0, 1, 0, 1, get_hash_str([1]), 0, parent=manager.radix_tree_root)
node.cache_status = CacheStatus.SWAP2CPU
manager.node_map[node.node_id] = node
manager._handle_swap_result(node.node_id, 2, 3, CacheStatus.SWAP2CPU)
@@ -1156,7 +1160,7 @@ class PrefixCacheManagerTest(unittest.TestCase):
def test_reset_clears_internal_state(self):
manager = _create_manager(num_gpu_blocks=2, num_cpu_blocks=1)
node = BlockNode(100, [1], 0, 1, 0, 1, manager.cal_block_hash([1]), 0, parent=manager.radix_tree_root)
node = BlockNode(100, [1], 0, 1, 0, 1, get_hash_str([1]), 0, parent=manager.radix_tree_root)
manager.node_map[node.node_id] = node
manager.task_swapping_event["evt"] = threading.Event()
manager.task_swapping_event["evt"].set()
@@ -1166,9 +1170,9 @@ class PrefixCacheManagerTest(unittest.TestCase):
def test_recv_data_transfer_result_processes_queue(self):
manager = _create_manager(num_gpu_blocks=4, num_cpu_blocks=1)
node = BlockNode(110, [1], 0, 1, 0, 1, manager.cal_block_hash([1]), 0, parent=manager.radix_tree_root)
node = BlockNode(110, [1], 0, 1, 0, 1, get_hash_str([1]), 0, parent=manager.radix_tree_root)
manager.node_map[node.node_id] = node
payload = [([node.node_id], [2], [3], CacheStatus.SWAP2GPU, "task")]
payload = [(CacheStatus.SWAP2GPU, "task", [node.node_id], [2], [3])]
manager.cache_task_queue = _FakeTransferQueue(payload, include_none=True)
manager.task_swapping_event["task"] = threading.Event()
with self.assertRaises(SystemExit):
@@ -1196,7 +1200,7 @@ class PrefixCacheManagerTest(unittest.TestCase):
request_id="revert",
multimodal_inputs={"mm_positions": [SimpleNamespace(offset=2, length=2)]},
)
node = BlockNode(120, [1, 2], 0, 1, 0, 2, manager.cal_block_hash([1, 2]), 0, parent=manager.radix_tree_root)
node = BlockNode(120, [1, 2], 0, 1, 0, 2, get_hash_str([1, 2]), 0, parent=manager.radix_tree_root)
matche_nodes = [node]
match_gpu = [0]
match_node_ids = [node.node_id]