mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-22 16:07:51 +08:00
[Feature] Support KV Cache Storage (#5571)
* Support Mooncake Store * up * up * add op * fix conflict * fix error * up for comments * avoid thread lock * up * fix unittest * fix unittest * remove debug info * consider tp_size > 1 * add default rdma_nics * add utils * up * fix error --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
@@ -2,6 +2,10 @@
|
||||
/.venv/
|
||||
/venv/
|
||||
|
||||
tests/log_*
|
||||
benchmarks/openai-chat-infqps*
|
||||
splitwise/log*
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
|
||||
@@ -102,7 +102,7 @@ def metrics_summary(metrics, token_timestamps):
|
||||
|
||||
# prefill 总耗时
|
||||
summary["prefill_cost_time"] = safe_cost(m0.get("send_request_output_to_decode_time"), arrival_time)
|
||||
# prefill准备耗时
|
||||
# prefill准备总耗时
|
||||
summary["prefill_prepare_cost_time"] = safe_cost(inference_start_time, arrival_time)
|
||||
# 预处理耗时
|
||||
summary["preprocess_cost_time"] = safe_cost(m0.get("scheduler_recv_req_time"), arrival_time)
|
||||
@@ -114,6 +114,10 @@ def metrics_summary(metrics, token_timestamps):
|
||||
summary["ask_decode_resource_cost_time"] = safe_cost(
|
||||
m0.get("ask_decode_resource_finish_time"), m0.get("ask_decode_resource_start_time")
|
||||
)
|
||||
# scheduler调度耗时
|
||||
summary["schedule_cost_time"] = safe_cost(
|
||||
m0.get("inference_start_time"), m0.get("ask_decode_resource_finish_time")
|
||||
)
|
||||
# prefill 的首 token 推理耗时
|
||||
summary["prefill_first_token_infer_cost_time"] = safe_cost(
|
||||
m0.get("engine_recv_first_token_time"), inference_start_time
|
||||
@@ -143,6 +147,19 @@ def metrics_summary(metrics, token_timestamps):
|
||||
token_timestamps[1], m_last.get("decode_recv_second_token_time")
|
||||
)
|
||||
|
||||
# MIX 模式下,scheduler调度耗时
|
||||
summary["mixed_schedule_cost_time"] = safe_cost(m0.get("inference_start_time"), m0.get("engine_get_req_time"))
|
||||
# MIX 模式下,返回首 token 链路耗时
|
||||
summary["mixed_first_token_transmission_cost_time"] = safe_cost(
|
||||
token_timestamps[0], m0.get("engine_recv_first_token_time")
|
||||
)
|
||||
|
||||
summary["gpu_cache_token_num"] = m0.get("gpu_cache_token_num")
|
||||
summary["cpu_cache_token_num"] = m0.get("cpu_cache_token_num")
|
||||
summary["storage_cache_token_num"] = m0.get("storage_cache_token_num")
|
||||
summary["gpu_cpu_cache_prepare_time"] = m0.get("gpu_cpu_cache_prepare_time")
|
||||
summary["storage_cache_prepare_time"] = m0.get("storage_cache_prepare_time")
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
|
||||
@@ -695,7 +695,7 @@ async def benchmark(
|
||||
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
|
||||
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
|
||||
|
||||
def process_pd_metrics(model_outputs, metric_key):
|
||||
def process_pd_metrics(model_outputs, metric_key, is_time=True):
|
||||
# 收集所有该 metric 的数值
|
||||
values = []
|
||||
percentiles = []
|
||||
@@ -712,24 +712,29 @@ async def benchmark(
|
||||
print(f"[WARN] metric_key '{metric_key}' not found in outputs.")
|
||||
return
|
||||
|
||||
arr = np.array(values) * 1000 # 秒 -> 毫秒
|
||||
if is_time:
|
||||
arr = np.array(values) * 1000 # 秒 -> 毫秒
|
||||
suffix = "(ms)"
|
||||
else:
|
||||
arr = np.array(values)
|
||||
suffix = ""
|
||||
|
||||
print("{s:{c}^{n}}".format(s=metric_key, n=50, c="-"))
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
f"Mean {metric_key} (ms):",
|
||||
f"Mean {metric_key} {suffix}:",
|
||||
np.mean(arr),
|
||||
)
|
||||
)
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
f"Median {metric_key} (ms):",
|
||||
f"Median {metric_key} {suffix}:",
|
||||
np.median(arr),
|
||||
)
|
||||
)
|
||||
for p in percentiles:
|
||||
v = np.percentile(arr, p)
|
||||
print("{:<40} {:<10.2f}".format(f"P{str(int(p)) if int(p) == p else str(p)} {metric_key} (ms):", v))
|
||||
print("{:<40} {:<10.2f}".format(f"P{str(int(p)) if int(p) == p else str(p)} {metric_key} {suffix}:", v))
|
||||
# print(f"P{str(int(p)) if int(p) == p else str(p)} {metric_key} (ms): {v:10.2f}")
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
@@ -785,6 +790,7 @@ async def benchmark(
|
||||
process_pd_metrics(outputs, "prefill_prepare_cost_time")
|
||||
process_pd_metrics(outputs, "preprocess_cost_time")
|
||||
process_pd_metrics(outputs, "cache_in_scheduler_cost_time")
|
||||
process_pd_metrics(outputs, "schedule_cost_time")
|
||||
process_pd_metrics(outputs, "ask_decode_resource_cost_time")
|
||||
process_pd_metrics(outputs, "prefill_first_token_infer_cost_time")
|
||||
process_pd_metrics(outputs, "wait_sending_cache_cost_time")
|
||||
@@ -793,6 +799,12 @@ async def benchmark(
|
||||
process_pd_metrics(outputs, "decode_second_token_infer_cost_time")
|
||||
process_pd_metrics(outputs, "first_token_transmission_cost_time")
|
||||
process_pd_metrics(outputs, "second_token_transmission_cost_time")
|
||||
process_pd_metrics(outputs, "mixed_schedule_cost_time")
|
||||
process_pd_metrics(outputs, "gpu_cache_token_num", is_time=False)
|
||||
process_pd_metrics(outputs, "cpu_cache_token_num", is_time=False)
|
||||
process_pd_metrics(outputs, "storage_cache_token_num", is_time=False)
|
||||
process_pd_metrics(outputs, "gpu_cpu_cache_prepare_time")
|
||||
process_pd_metrics(outputs, "storage_cache_prepare_time")
|
||||
process_one_length("input_len", "Cached Tokens", "Cached Tokens")
|
||||
process_one_length("s_input_len", "Input Length", "Infer Input Length")
|
||||
process_one_length("reasoning_len", "Reasoning Lenth", "思考长度")
|
||||
|
||||
@@ -0,0 +1,134 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
// #define SWAP_DEBUG
|
||||
|
||||
template <paddle::DataType D>
|
||||
void SwapCacheImpLayout(
|
||||
const std::vector<paddle::Tensor>& cache_gpu_tensors, // gpu
|
||||
const int64_t& cache_cpu_pointer, // cpu
|
||||
const std::vector<int64_t>& cache_shape,
|
||||
const std::vector<int64_t>& gpu_block_ids,
|
||||
const std::vector<int64_t>& cpu_block_ids,
|
||||
int mode) {
|
||||
// mode is 0: gpu to cpu; 1: cpu to gpu
|
||||
// cache layout: layer_num * [block_num, head_num, block_size, head_dim]
|
||||
// buffer layout: [block_num, layer_num, head_num, block_size, head_dim]
|
||||
typedef PDTraits<D> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
const int64_t layer_number = cache_gpu_tensors.size();
|
||||
const int64_t num_heads = cache_shape[1];
|
||||
const int64_t block_size = cache_shape[2];
|
||||
const int64_t head_dim = cache_shape[3];
|
||||
const int64_t cache_block_stride = num_heads * block_size * head_dim;
|
||||
|
||||
#ifdef SWAP_DEBUG
|
||||
std::cout << "layer_number:" << layer_number << std::endl;
|
||||
std::cout << "cache_shape:" << cache_shape[0] << ", " << cache_shape[1]
|
||||
<< ", " << cache_shape[2] << ", " << cache_shape[3] << std::endl;
|
||||
std::cout << "cache_block_stride:" << cache_block_stride << std::endl;
|
||||
#endif
|
||||
|
||||
auto stream = cache_gpu_tensors[0].stream();
|
||||
const cudaMemcpyKind copy_kind =
|
||||
(mode == 0) ? cudaMemcpyDeviceToHost : cudaMemcpyHostToDevice;
|
||||
|
||||
for (int layer_idx = 0; layer_idx < cache_gpu_tensors.size(); layer_idx++) {
|
||||
const paddle::Tensor& cache_gpu = cache_gpu_tensors[layer_idx];
|
||||
data_t* cache_gpu_ptr = const_cast<data_t*>(cache_gpu.data<data_t>());
|
||||
auto* cache_cpu_ptr = reinterpret_cast<data_t*>(cache_cpu_pointer);
|
||||
// auto stream = cache_gpu.stream();
|
||||
|
||||
for (int block_idx = 0; block_idx < gpu_block_ids.size(); block_idx++) {
|
||||
auto cur_gpu_block_id = gpu_block_ids[block_idx];
|
||||
auto cur_cpu_block_id = cpu_block_ids[block_idx];
|
||||
auto* cache_gpu_ptr_now =
|
||||
cache_gpu_ptr + cur_gpu_block_id * cache_block_stride;
|
||||
auto* cache_cpu_ptr_now =
|
||||
cache_cpu_ptr + cur_cpu_block_id * cache_block_stride * layer_number +
|
||||
layer_idx * cache_block_stride;
|
||||
|
||||
cudaError_t status = cudaMemcpyAsync(
|
||||
(copy_kind == cudaMemcpyDeviceToHost) ? cache_cpu_ptr_now
|
||||
: cache_gpu_ptr_now,
|
||||
(copy_kind == cudaMemcpyDeviceToHost) ? cache_gpu_ptr_now
|
||||
: cache_cpu_ptr_now,
|
||||
cache_block_stride * sizeof(DataType_),
|
||||
copy_kind,
|
||||
stream);
|
||||
|
||||
#ifdef SWAP_DEBUG
|
||||
cudaStreamSynchronize(stream);
|
||||
std::cout << "mode:" << mode << ", layer_idx:" << layer_idx
|
||||
<< ", block_idx:" << block_idx << ", cache_cpu_ptr_now data:"
|
||||
<< static_cast<float>(*cache_cpu_ptr_now) << std::endl;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
cudaStreamSynchronize(stream);
|
||||
}
|
||||
|
||||
void SwapCacheLayout(
|
||||
const std::vector<paddle::Tensor>& cache_gpu_tensors, // gpu
|
||||
const int64_t& cache_cpu_ptrs, // cpu memory pointer
|
||||
const std::vector<int64_t>& cache_shape,
|
||||
const std::vector<int64_t>& gpu_block_ids,
|
||||
const std::vector<int64_t>& cpu_block_ids,
|
||||
int rank,
|
||||
int mode) {
|
||||
cudaSetDevice(rank); // used for distributed launch
|
||||
assert(cache_gpu_tensors.size() > 0);
|
||||
switch (cache_gpu_tensors[0].dtype()) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
return SwapCacheImpLayout<paddle::DataType::BFLOAT16>(cache_gpu_tensors,
|
||||
cache_cpu_ptrs,
|
||||
cache_shape,
|
||||
gpu_block_ids,
|
||||
cpu_block_ids,
|
||||
mode);
|
||||
case paddle::DataType::FLOAT16:
|
||||
return SwapCacheImpLayout<paddle::DataType::FLOAT16>(cache_gpu_tensors,
|
||||
cache_cpu_ptrs,
|
||||
cache_shape,
|
||||
gpu_block_ids,
|
||||
cpu_block_ids,
|
||||
mode);
|
||||
case paddle::DataType::UINT8:
|
||||
return SwapCacheImpLayout<paddle::DataType::UINT8>(cache_gpu_tensors,
|
||||
cache_cpu_ptrs,
|
||||
cache_shape,
|
||||
gpu_block_ids,
|
||||
cpu_block_ids,
|
||||
mode);
|
||||
default:
|
||||
PD_THROW("Unsupported data type.");
|
||||
}
|
||||
}
|
||||
|
||||
// Register the custom op `swap_cache_layout`.
// Input: one vector of per-layer GPU cache tensors.
// The CPU buffer is passed as a raw host pointer encoded in the int64_t
// attribute `cache_cpu_ptrs`; remaining attrs mirror SwapCacheLayout's
// parameters (shape, paired block id lists, device rank, copy direction).
PD_BUILD_STATIC_OP(swap_cache_layout)
    .Inputs({paddle::Vec("cache_gpu_tensors")})
    .Attrs({
        "cache_cpu_ptrs: int64_t",
        "cache_shape: std::vector<int64_t>",
        "gpu_block_ids: std::vector<int64_t>",
        "cpu_block_ids: std::vector<int64_t>",
        "rank: int",
        "mode: int",
    })
    .SetKernelFn(PD_KERNEL(SwapCacheLayout));
|
||||
@@ -288,6 +288,7 @@ elif paddle.is_compiled_with_cuda():
|
||||
"gpu_ops/tune_cublaslt_gemm.cu",
|
||||
"gpu_ops/swap_cache_batch.cu",
|
||||
"gpu_ops/swap_cache.cu",
|
||||
"gpu_ops/swap_cache_layout.cu",
|
||||
"gpu_ops/step_system_cache.cu",
|
||||
"gpu_ops/cpp_extensions.cc",
|
||||
"gpu_ops/share_external_data.cu",
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
# MooncakeStore for FastDeploy
|
||||
|
||||
This document describes how to use MooncakeStore as the backend of FastDeploy.
|
||||
|
||||
## Preparation
|
||||
|
||||
### Install FastDeploy
|
||||
|
||||
Refer to [NVIDIA CUDA GPU Installation](https://paddlepaddle.github.io/FastDeploy/get_started/installation/nvidia_gpu/) for FastDeploy installation.
|
||||
|
||||
### Install MooncakeStore
|
||||
|
||||
```bash
|
||||
pip install mooncake-transfer-engine
|
||||
```
|
||||
|
||||
## Run Examples
|
||||
|
||||
The example script is provided in `run.sh`. You can run it directly:
|
||||
```
|
||||
bash run.sh
|
||||
```
|
||||
|
||||
In the example script, we start a Mooncake master server and two FastDeploy servers.
|
||||
|
||||
Launch Mooncake master server:
|
||||
```bash
|
||||
mooncake_master \
|
||||
--port=15001 \
|
||||
--enable_http_metadata_server=true \
|
||||
--http_metadata_server_host=0.0.0.0 \
|
||||
--http_metadata_server_port=15002 \
|
||||
--metrics_port=15003 \
|
||||
```
|
||||
|
||||
More parameters can be found in the [official guide](https://github.com/kvcache-ai/Mooncake/blob/main/docs/source/python-api-reference/transfer-engine.md).
|
||||
|
||||
Launch FastDeploy with Mooncake enabled.
|
||||
|
||||
```bash
|
||||
export MOONCAKE_CONFIG_PATH="./mooncake_config.json"
|
||||
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model ${MODEL_NAME} \
|
||||
--port ${PORT} \
|
||||
--metrics-port $((PORT + 1)) \
|
||||
--engine-worker-queue-port $((PORT + 2)) \
|
||||
--cache-queue-port $((PORT + 3)) \
|
||||
--max-model-len 32768 \
|
||||
--max-num-seqs 32 \
|
||||
--kvcache-storage-backend mooncake
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
For more details, please refer to:
|
||||
https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/troubleshooting.md
|
||||
@@ -0,0 +1,9 @@
|
||||
{
|
||||
"local_hostname":"localhost",
|
||||
"metadata_server":"http://0.0.0.0:15002/metadata",
|
||||
"global_segment_size":8589934592,
|
||||
"local_buffer_size":134217728,
|
||||
"protocol":"rdma",
|
||||
"rdma_devices": "mlx5_1,mlx5_2,mlx5_3,mlx5_4",
|
||||
"master_server_addr":"0.0.0.0:15001"
|
||||
}
|
||||
@@ -0,0 +1,100 @@
|
||||
#!/bin/bash
# Launch a Mooncake master server plus two FastDeploy servers (one GPU each)
# with the Mooncake kv-cache storage backend enabled, wait until both are
# healthy, then send one smoke-test chat request to each server.
# Expects mooncake_config.json, stop.sh and utils.sh in the current directory.
set -e

export MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle"
export MOONCAKE_CONFIG_PATH=./mooncake_config.json
export FD_DEBUG=1

unset http_proxy && unset https_proxy
rm -rf log_*
bash stop.sh

source ./utils.sh

S0_PORT=52700
S1_PORT=52800

# NOTE(review): ROUTER_PORT is never assigned in this script, so it expands
# to an empty element here — confirm whether it should come from utils.sh
# or be removed.
ports=(
    $S0_PORT $((S0_PORT + 1)) $((S0_PORT + 2)) $((S0_PORT + 3))
    $S1_PORT $((S1_PORT + 1)) $((S1_PORT + 2)) $((S1_PORT + 3))
    $ROUTER_PORT
)
check_ports "${ports[@]}" || {
    echo "❌ Some ports are in use. Please release them."
    exit 1
}

# Launch MoonCake master.
# Fix: the original used `2>&1 > log`, which duplicates stderr onto the
# terminal *before* stdout is redirected, so errors never reached the log.
# Redirect stdout first, then duplicate stderr onto it.
nohup mooncake_master \
    --port=15001 \
    --enable_http_metadata_server=true \
    --http_metadata_server_host=0.0.0.0 \
    --http_metadata_server_port=15002 \
    --metrics_port=15003 \
    > log_master 2>&1 &

# Launch FD server 0
export CUDA_VISIBLE_DEVICES=0
export FD_LOG_DIR="log_0"
mkdir -p ${FD_LOG_DIR}
echo "server 0 port: ${S0_PORT}"

nohup python -m fastdeploy.entrypoints.openai.api_server \
    --model ${MODEL_NAME} \
    --port ${S0_PORT} \
    --metrics-port $((S0_PORT + 1)) \
    --engine-worker-queue-port $((S0_PORT + 2)) \
    --cache-queue-port $((S0_PORT + 3)) \
    --max-model-len 32768 \
    --max-num-seqs 32 \
    --kvcache-storage-backend mooncake \
    > ${FD_LOG_DIR}/nohup 2>&1 &

# Launch FD server 1
export CUDA_VISIBLE_DEVICES=1
export FD_LOG_DIR="log_1"
mkdir -p ${FD_LOG_DIR}
echo "server 1 port: ${S1_PORT}"

nohup python -m fastdeploy.entrypoints.openai.api_server \
    --model ${MODEL_NAME} \
    --port ${S1_PORT} \
    --metrics-port $((S1_PORT + 1)) \
    --engine-worker-queue-port $((S1_PORT + 2)) \
    --cache-queue-port $((S1_PORT + 3)) \
    --max-model-len 32768 \
    --max-num-seqs 32 \
    --kvcache-storage-backend mooncake \
    > ${FD_LOG_DIR}/nohup 2>&1 &

wait_for_health ${S0_PORT}
wait_for_health ${S1_PORT}

# send request
msg="深圳是中国经济实力最强的城市之一。近年来,深圳 GDP 持续稳步增长,**2023 年突破 3.4 万亿元人民币,2024 年接近 3.7 万亿元**,长期位居全国城市前列。深圳经济以第二产业和第三产业为主,高端制造业、电子信息产业和现代服务业发达,形成了以科技创新为核心的产业结构。依托华为、腾讯、大疆等龙头企业,深圳在数字经济、人工智能、新能源等领域具有显著优势。同时,深圳进出口总额常年位居全国城市第一,是中国对外开放和高质量发展的重要引擎。深圳2024年 GDP 是多少?"

echo "send request to server_0"
curl -X POST "http://0.0.0.0:${S0_PORT}/v1/chat/completions" \
    -H "Content-Type: application/json" \
    -d "{
        \"messages\": [
            {\"role\": \"user\", \"content\": \"${msg}\"}
        ],
        \"max_tokens\": 50,
        \"stream\": false,
        \"top_p\": 0
    }"

sleep 5

echo "send request to server_1"
curl -X POST "http://0.0.0.0:${S1_PORT}/v1/chat/completions" \
    -H "Content-Type: application/json" \
    -d "{
        \"messages\": [
            {\"role\": \"user\", \"content\": \"${msg}\"}
        ],
        \"max_tokens\": 50,
        \"stream\": false,
        \"top_p\": 0
    }"
|
||||
@@ -0,0 +1,97 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Succeed (return 0) when no listening TCP socket is bound to the given port.
is_port_free() {
    local candidate=$1
    # `ss -ltn` lists listening TCP sockets; column 4 is the local address,
    # which always ends in ":<port>", so anchor the match at end of line.
    if ss -ltn | awk '{print $4}' | grep -q ":${candidate}$"; then
        # Something is already listening here.
        return 1
    fi
    return 0
}
|
||||
|
||||
# Verify that every port passed as an argument is free.
# Prints the first busy port and returns 1; returns 0 when all are free.
check_ports() {
    for port in "$@"; do
        if is_port_free $port; then
            continue
        fi
        echo "❌ Port $port is already in use"
        return 1
    done
    return 0
}
|
||||
|
||||
# Block until every port in the comma-separated list $1 answers HTTP 200 on
# GET /health. Repaints one status line per port plus a summary line in place
# (ANSI cursor-up escape) once per second until all servers are ready.
wait_for_health() {
    IFS=',' read -r -a server_ports <<< "$1"
    local num_ports=${#server_ports[@]}
    # One status line per port plus the trailing summary line.
    local total_lines=$((num_ports + 1))
    local first_run=true  # NOTE(review): set but never read — dead variable?
    local GREEN='\033[0;32m'
    local RED='\033[0;31m'
    local NC='\033[0m' # No Color
    local start_time=$(date +%s)

    while true; do
        local all_ready=true
        for port in "${server_ports[@]}"; do
            # `|| echo 000` maps curl failures (refused/timeout) to a fake
            # status code so the -eq test below stays numeric.
            status_code=$(curl -s --max-time 1 -o /dev/null -w "%{http_code}" "http://0.0.0.0:${port}/health" || echo "000")
            if [ "$status_code" -eq 200 ]; then
                # \033[K clears to end-of-line so shorter text fully
                # overwrites the previous repaint.
                printf "Port %s: ${GREEN}[OK] 200${NC}\033[K\n" "$port"
            else
                all_ready=false
                printf "Port %s: ${RED}[WAIT] %s${NC}\033[K\n" "$port" "$status_code"
            fi
        done
        cur_time=$(date +%s)
        if [ "$all_ready" = "true" ]; then
            echo "All services are ready! [$((cur_time-start_time))s]"
            break
        else
            echo "Waiting for services... [$((cur_time-start_time))s]"
            # Move the cursor back up over everything just printed so the
            # next iteration redraws in place.
            printf "\033[%dA" "$total_lines" # roll back cursor
            sleep 1
        fi
    done
}
|
||||
|
||||
# Print `free_ports_num` ($1, default 1) free TCP ports in the range
# [start_port ($2, default 8000), end_port ($3, default 9000)),
# comma-separated on stdout. Returns 1 on invalid arguments, 0 on success.
# Scanning starts at a random offset so concurrent callers rarely collide.
get_free_ports() {
    free_ports_num=${1:-1}
    start_port=${2:-8000}
    end_port=${3:-9000}

    free_ports=()
    if [[ ! -n ${free_ports_num} || "${free_ports_num}" -le 0 ]]; then
        # NOTE(review): log_warn is not defined in this file — confirm the
        # caller's environment provides it.
        log_warn "param can't be empty, and should > 0"
        echo ${free_ports[@]}
        return 1
    fi

    # Collect ports that appear as local ($4) or remote ($5) endpoints.
    used_ports1=$(netstat -an | grep -E "(0.0.0.0|127.0.0.1|${POD_IP}|tcp6)" | awk '{n=split($4,a,":"); if(a[n]~/^[0-9]+$/) print a[n];}' | sort -u)
    used_ports2=$(netstat -an | grep -E "(0.0.0.0|127.0.0.1|${POD_IP}|tcp6)" | awk '{n=split($5,a,":"); if(a[n]~/^[0-9]+$/) print a[n];}' | sort -u)
    all_used_ports=$(printf "%s\n" "${used_ports1}" "${used_ports2}" | sort -u)

    # Random starting point inside the range.
    random_num=$(( RANDOM ))
    port=$(( random_num % (end_port - start_port + 1) + start_port ))

    while true; do
        (( port++ ))
        if [[ ${port} -ge ${end_port} ]]; then
            port=${start_port}
        fi

        # Fix: the original `[[ "${all_used_ports[@]}" =~ "${port}" ]]` is a
        # substring test, so e.g. port 80 was rejected whenever any used port
        # merely contained "80" (8080, 1800, ...). Match whole lines instead.
        if printf '%s\n' "${all_used_ports}" | grep -qx "${port}"; then
            continue
        fi

        if is_port_free ${port}; then
            free_ports+=("${port}")
            (( free_ports_num-- ))
            if [[ ${free_ports_num} = 0 ]]; then
                break
            fi
        fi

    done

    # Emit the result as a comma-separated list.
    IFS=',' && echo "${free_ports[*]}"
    return 0
}
|
||||
@@ -30,6 +30,8 @@ class CacheStatus(Enum):
|
||||
SWAP2CPU = 1
|
||||
SWAP2GPU = 2
|
||||
CPU = 3
|
||||
GPU2STORAGE = 4
|
||||
STORAGE2GPU = 5
|
||||
|
||||
|
||||
class BlockNode:
|
||||
|
||||
@@ -22,6 +22,7 @@ import queue
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
@@ -36,8 +37,10 @@ from fastdeploy.cache_manager.ops import (
|
||||
set_device,
|
||||
share_external_data_,
|
||||
swap_cache_all_layers,
|
||||
swap_cache_layout,
|
||||
unset_data_ipc,
|
||||
)
|
||||
from fastdeploy.cache_manager.transfer_factory import MooncakeStore
|
||||
from fastdeploy.config import SpeculativeConfig
|
||||
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
|
||||
from fastdeploy.platforms import current_platform
|
||||
@@ -55,8 +58,9 @@ def parse_args():
|
||||
default="mixed",
|
||||
help="splitwise role, can be decode, prefill or mixed",
|
||||
)
|
||||
parser.add_argument("--rank", type=int, default=0, help="current rank")
|
||||
parser.add_argument("--rank", type=int, default=0, help="local tp rank")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="device id")
|
||||
parser.add_argument("--max_model_len", type=int, default=32768, help="max model length")
|
||||
parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
|
||||
parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
|
||||
parser.add_argument(
|
||||
@@ -101,6 +105,20 @@ def parse_args():
|
||||
help="speculative config",
|
||||
)
|
||||
parser.add_argument("--create_cache_tensor", action="store_true")
|
||||
parser.add_argument(
|
||||
"--kvcache_storage_backend",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["mooncake", "none"],
|
||||
help="The storage backend for kvcache storage. If not set, storage backend is disabled.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--write_policy",
|
||||
type=str,
|
||||
choices=["write_through"],
|
||||
default="write_through",
|
||||
help="KVCache write policy",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
@@ -135,6 +153,9 @@ class CacheTransferManager:
|
||||
|
||||
self.swap_to_cpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
self.swap_to_gpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
self.read_storage_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
self.write_back_storage_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
self.timeout_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)
|
||||
self.transfer_task_queue = queue.Queue() # 用来接收传输任务
|
||||
self.tansfer_done_queue = queue.Queue() # 用来告知任务执行完毕
|
||||
self.n_ranks = args.mp_num
|
||||
@@ -194,8 +215,54 @@ class CacheTransferManager:
|
||||
suffix=args.engine_worker_queue_port,
|
||||
create=False,
|
||||
)
|
||||
|
||||
if args.kvcache_storage_backend is None or args.kvcache_storage_backend == "none":
|
||||
self.storage_backend = None
|
||||
elif args.kvcache_storage_backend == "mooncake":
|
||||
logger.info("Start initialize mooncake store...")
|
||||
self.storage_backend = MooncakeStore(tp_rank=self.rank)
|
||||
self._init_storage_buffer(args)
|
||||
logger.info("Initialized mooncake store successfully")
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported storage backend: {args.kvcache_storage_backend}")
|
||||
|
||||
if args.write_policy not in ["write_through"]:
|
||||
raise ValueError(f"Invalid write policy: {args.write_policy}")
|
||||
self.write_policy = args.write_policy
|
||||
|
||||
threading.Thread(target=self.clear_or_update_caches, args=[args], daemon=True).start()
|
||||
|
||||
def _init_storage_buffer(self, args):
    """
    Initialize pinned memory buffer that can hold the cache for a longest request.

    Two host buffers are allocated (one for reads from storage, one for
    write-backs) and both are registered with the storage backend.

    cache layout: layer_num * [block_num, head_num, block_size, head_dim]
    buffer layout: [block_num, layer_num, head_num, block_size, head_dim]
    """
    layer_num = args.num_layers + self.num_extra_layers
    # NOTE(review): buffer sizing uses key_cache_shape for both key and value
    # halves — confirm key and value caches always share this shape here.
    head_num = self.key_cache_shape[1]
    block_size = self.key_cache_shape[2]
    head_dim = self.key_cache_shape[3]
    # ceil(max_model_len / block_size): enough blocks for one longest request.
    block_num = (args.max_model_len + block_size - 1) // block_size
    logger.info(
        f"Creating cache buffer for storage with shape: "
        f"[{block_num}, {layer_num}, {head_num}, {block_size}, {head_dim}]"
    )

    self.cache_bytes = self._get_cache_bytes(self.cache_dtype)
    # Bytes occupied by all layers of a single block in the host buffer.
    self.storage_buffer_stride_bytes = layer_num * head_num * block_size * head_dim * self.cache_bytes
    total_bytes = block_num * self.storage_buffer_stride_bytes * 2  # key and value

    logger.info(f"Creating cpu buffer cache for alllayers: {total_bytes / 1024 ** 3:.2f}GB")
    # Read path buffer: first half holds keys, second half holds values.
    read_buffer = cuda_host_alloc(total_bytes)
    self.storage_key_read_buffer = read_buffer
    self.storage_value_read_buffer = read_buffer + total_bytes // 2
    self.storage_backend.register_buffer(read_buffer, total_bytes)

    # Write-back buffer, laid out the same way as the read buffer.
    write_buffer = cuda_host_alloc(total_bytes)
    self.storage_key_write_buffer = write_buffer
    self.storage_value_write_buffer = write_buffer + total_bytes // 2
    self.storage_backend.register_buffer(write_buffer, total_bytes)
||||
|
||||
def _init_gpu_cache(self, args):
|
||||
|
||||
try:
|
||||
@@ -319,12 +386,7 @@ class CacheTransferManager:
|
||||
value_cache_size = self.value_cache_shape[1] * self.value_cache_shape[2] * self.value_cache_shape[3]
|
||||
else:
|
||||
value_cache_size = 0
|
||||
if args.cache_dtype == "bfloat16":
|
||||
cache_bytes = 2
|
||||
elif args.cache_dtype == "uint8" or args.cache_dtype == "block_wise_fp8":
|
||||
cache_bytes = 1
|
||||
else:
|
||||
raise ValueError(f"Unsupported cache dtype: {args.cache_dtype}")
|
||||
cache_bytes = self._get_cache_bytes(self.cache_dtype)
|
||||
key_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * key_cache_size
|
||||
value_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * value_cache_size
|
||||
if args.cache_dtype == "block_wise_fp8":
|
||||
@@ -367,6 +429,222 @@ class CacheTransferManager:
|
||||
logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ swap space (cpu cache) is ready!")
|
||||
self.swap_space_ready_signal.value[self.rank] = 1
|
||||
|
||||
def _get_cache_bytes(self, cache_dtype):
|
||||
if cache_dtype == "bfloat16":
|
||||
cache_bytes = 2
|
||||
elif cache_dtype in ["uint8", "block_wise_fp8"]:
|
||||
cache_bytes = 1
|
||||
else:
|
||||
raise ValueError(f"Unsupported cache dtype: {cache_dtype}")
|
||||
return cache_bytes
|
||||
|
||||
def _storage_exist_block_num(self, k_keys: List[str], v_keys: List[str]):
|
||||
"""
|
||||
Given the k_keys and v_keys, get the valid blocks number that
|
||||
can be prefetched from storage backend.
|
||||
"""
|
||||
assert len(k_keys) == len(v_keys), "k_keys and v_keys must have the same length."
|
||||
result = self.storage_backend.exists(k_keys + v_keys)
|
||||
|
||||
# only consider the case when both key and value exist
|
||||
num = 0
|
||||
for k, v in zip(k_keys, v_keys):
|
||||
if result[k] and result[v]:
|
||||
num += 1
|
||||
return num
|
||||
|
||||
def _run_read_storage(self, k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids):
    """
    Fetch cache blocks from the storage backend into the pinned read buffer,
    then copy the successfully fetched prefix of blocks onto the GPU.

    Returns the list of GPU block ids that were actually filled.
    Raises: re-raises any exception after logging it.
    """
    try:
        logger.debug(
            f"_run_read_storage, key_hash_keys: {k_cache_keys}, "
            f"value_hash_keys: {v_cache_keys}, gpu_block_ids: {gpu_block_ids}"
        )

        block_num = len(gpu_block_ids)
        # Keys and values share one batch_get call: keys first, values second.
        keys = k_cache_keys + v_cache_keys
        # Host destinations inside the pinned read buffer, one stride per block.
        k_cache_ptrs = [self.storage_key_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids]
        v_cache_ptrs = [
            self.storage_value_read_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
        ]
        kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
        kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2  # key and value
        result = self.storage_backend.batch_get(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes)

        # batch_get returns one per-key result; >0 presumably means bytes
        # fetched — TODO confirm against the backend's batch_get contract.
        k_result, v_result = result[:block_num], result[block_num:]
        success_block_num = 0
        for k, v in zip(k_result, v_result):
            if k > 0 and v > 0:
                success_block_num += 1
        logger.debug(f"_run_read_storage, success_block_num: {success_block_num}")
        # NOTE(review): only the first success_block_num entries are swapped —
        # this assumes successes form a contiguous prefix; a non-contiguous
        # failure pattern would copy a block that was never fetched. Confirm.
        valid_gpu_block_ids = gpu_block_ids[:success_block_num]
        valid_cpu_block_ids = cpu_block_ids[:success_block_num]

        mode = 1  # cpu ==> gpu
        swap_cache_layout(
            self.gpu_cache_k_tensors,
            self.storage_key_read_buffer,
            self.key_cache_shape,
            valid_gpu_block_ids,
            valid_cpu_block_ids,
            self.device,
            mode,
        )
        swap_cache_layout(
            self.gpu_cache_v_tensors,
            self.storage_value_read_buffer,
            self.value_cache_shape,
            valid_gpu_block_ids,
            valid_cpu_block_ids,
            self.device,
            mode,
        )

        return valid_gpu_block_ids
    except Exception as e:
        logger.error(
            f"[rank {self.rank}/{self.n_ranks}] An error occurred in _run_read_storage: "
            f"error:{e}, {traceback.format_exc()}"
        )
        raise
|
||||
|
||||
def read_storage_task(self, task_id, keys, gpu_block_ids, timeout=0.1):
    """
    Read cache from the storage backend to the GPU memory.

    Probes which blocks exist in storage, loads the matching prefix onto the
    GPU, then reports the completed block ids through the cache task queue.

    Args:
        task_id: identifier of the transfer task, echoed back in the result.
        keys: per-block hash keys; expanded to per-rank key/value storage keys.
        gpu_block_ids: destination GPU block ids, paired with `keys`.
        timeout: currently only logged (see TODO below).
    """
    try:
        logger.debug(
            f"read_storage_task, task id: {task_id}, hash_keys: {keys}, "
            f"gpu_block_ids: {gpu_block_ids}, timeout: {timeout}"
        )
        # Storage keys are sharded per tensor-parallel rank.
        k_cache_keys = [f"{key}_key_{self.rank}" for key in keys]
        v_cache_keys = [f"{key}_value_{self.rank}" for key in keys]
        match_block_num = self._storage_exist_block_num(k_cache_keys, v_cache_keys)
        logger.debug(f"read_storage_task, match {match_block_num} blocks from storage for task id: {task_id}")

        # Only attempt the first match_block_num blocks; host staging slots
        # are simply 0..match_block_num-1 in the pinned read buffer.
        k_cache_keys = k_cache_keys[:match_block_num]
        v_cache_keys = v_cache_keys[:match_block_num]
        gpu_block_ids = gpu_block_ids[:match_block_num]
        cpu_block_ids = [i for i in range(match_block_num)]
        valid_gpu_block_ids = []

        if match_block_num > 0:
            # TODO: support timeout with actual block count
            try:
                valid_gpu_block_ids = self._run_read_storage(
                    k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids
                )
                logger.info(
                    f"read_storage_task, finish loading {match_block_num} blocks from storage for task {task_id}."
                )
            except Exception as e:
                # Best-effort: a failed load reports zero valid blocks
                # instead of failing the whole task.
                logger.error(f"[rank {self.rank}/{self.n_ranks}] An error occurred: {task_id} {e}")
                valid_gpu_block_ids = []

        result = (CacheStatus.STORAGE2GPU, task_id, keys, valid_gpu_block_ids)
        # Barrier wait + reset before signaling done — presumably so all tp
        # ranks finish their copies before the task is reported; confirm.
        self.cache_task_queue.swap_storage_to_gpu_barrier.wait()
        self.cache_task_queue.swap_storage_to_gpu_barrier.reset()
        self.cache_task_queue.put_transfer_done_signal(result)
        logger.debug(f"read_storage_task: put_transfer_done_signal {result}")
        logger.info(
            f"read_storage_task: put_transfer_done_signal for transfer_task_id {task_id}, "
            f"valid block num {len(valid_gpu_block_ids)}"
        )
    except Exception as e:
        # Swallow after logging: this runs on a worker thread pool where an
        # uncaught exception would be lost anyway.
        logger.error(
            f"[rank {self.rank}/{self.n_ranks}] An error occurred in read_storage_task: "
            f"task_id: {task_id}, error:{e}, {traceback.format_exc()}"
        )
||||
|
||||
def _run_write_back_storage(self, k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids):
|
||||
try:
|
||||
logger.debug(
|
||||
f"_run_write_back_storage, k_cache_keys: {k_cache_keys}, v_cache_keys: {v_cache_keys}, "
|
||||
f"gpu_block_ids: {gpu_block_ids}"
|
||||
)
|
||||
key_cache_size = [
|
||||
self.key_cache_shape[0],
|
||||
self.key_cache_shape[1],
|
||||
self.key_cache_shape[2],
|
||||
self.key_cache_shape[3],
|
||||
]
|
||||
mode = 0 # gpu ==> cpu
|
||||
swap_cache_layout(
|
||||
self.gpu_cache_k_tensors,
|
||||
self.storage_key_write_buffer,
|
||||
key_cache_size,
|
||||
gpu_block_ids,
|
||||
cpu_block_ids,
|
||||
self.device,
|
||||
mode,
|
||||
)
|
||||
swap_cache_layout(
|
||||
self.gpu_cache_v_tensors,
|
||||
self.storage_value_write_buffer,
|
||||
key_cache_size,
|
||||
gpu_block_ids,
|
||||
cpu_block_ids,
|
||||
self.device,
|
||||
mode,
|
||||
)
|
||||
|
||||
block_num = len(gpu_block_ids)
|
||||
keys = k_cache_keys + v_cache_keys
|
||||
k_cache_ptrs = [
|
||||
self.storage_key_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
|
||||
]
|
||||
v_cache_ptrs = [
|
||||
self.storage_value_write_buffer + i * self.storage_buffer_stride_bytes for i in cpu_block_ids
|
||||
]
|
||||
kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
|
||||
kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value
|
||||
self.storage_backend.batch_set(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[rank {self.rank}/{self.n_ranks}] An error occurred in _run_write_back_storage: "
|
||||
f"error:{e}, {traceback.format_exc()}"
|
||||
)
|
||||
|
||||
def write_back_storage_task(self, task_id, keys, gpu_block_ids, timeout=0.1):
|
||||
"""
|
||||
Write cache to the storage backend from the GPU memory.
|
||||
"""
|
||||
try:
|
||||
logger.debug(
|
||||
f"write cache to storage, keys: {keys}, gpu_block_ids: {gpu_block_ids}, "
|
||||
f"task_id: {task_id}, timeout: {timeout}"
|
||||
)
|
||||
|
||||
k_cache_keys = [f"{key}_key_{self.rank}" for key in keys]
|
||||
v_cache_keys = [f"{key}_value_{self.rank}" for key in keys]
|
||||
match_block_num = self._storage_exist_block_num(k_cache_keys, v_cache_keys)
|
||||
|
||||
k_cache_keys = k_cache_keys[match_block_num:]
|
||||
v_cache_keys = v_cache_keys[match_block_num:]
|
||||
gpu_block_ids = gpu_block_ids[match_block_num:]
|
||||
cpu_block_ids = [i for i in range(len(gpu_block_ids))]
|
||||
|
||||
if len(k_cache_keys) == 0:
|
||||
logger.info(f"No uncached keys found for task {task_id}")
|
||||
gpu_block_ids = []
|
||||
else:
|
||||
try:
|
||||
# TODO: support timeout with actual block count
|
||||
self._run_write_back_storage(k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in write back storage task: {e}")
|
||||
gpu_block_ids = []
|
||||
|
||||
result = (CacheStatus.GPU2STORAGE, task_id, keys, gpu_block_ids)
|
||||
self.cache_task_queue.swap_to_storage_barrier.wait()
|
||||
if self.rank == 0: # 只有当rank为0时执行同步操作
|
||||
self.cache_task_queue.swap_to_storage_barrier.reset()
|
||||
self.cache_task_queue.put_transfer_done_signal(result) # 发送传输完成信号
|
||||
logger.debug(f"write_back_storage_task: put_transfer_done_signal {result}")
|
||||
logger.info(f"write_back_storage_task: put_transfer_done_signal for transfer_task_id {task_id}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[rank {self.rank}/{self.n_ranks}] An error occurred in write_back_storage_task: "
|
||||
f"error:{e}, {traceback.format_exc()}"
|
||||
)
|
||||
|
||||
def _do_swap_to_cpu_task(
|
||||
self,
|
||||
swap_node_ids,
|
||||
@@ -459,14 +737,9 @@ class CacheTransferManager:
|
||||
logger.debug(f"transfer data: get_transfer_task {data}")
|
||||
if read_finish:
|
||||
self.cache_task_broadcast_signal.value[0] = 0
|
||||
(
|
||||
swap_node_ids,
|
||||
gpu_block_id,
|
||||
cpu_block_id,
|
||||
event_type,
|
||||
transfer_task_id,
|
||||
) = data
|
||||
event_type, transfer_task_id = data[0], data[1]
|
||||
if event_type.value == CacheStatus.SWAP2CPU.value:
|
||||
swap_node_ids, gpu_block_id, cpu_block_id = data[2:]
|
||||
self.swap_to_cpu_thread_pool.submit(
|
||||
self._do_swap_to_cpu_task,
|
||||
swap_node_ids,
|
||||
@@ -475,7 +748,8 @@ class CacheTransferManager:
|
||||
event_type,
|
||||
transfer_task_id,
|
||||
)
|
||||
else:
|
||||
elif event_type.value == CacheStatus.SWAP2GPU.value:
|
||||
swap_node_ids, gpu_block_id, cpu_block_id = data[2:]
|
||||
self.swap_to_gpu_thread_pool.submit(
|
||||
self._do_swap_to_gpu_task,
|
||||
swap_node_ids,
|
||||
@@ -484,6 +758,24 @@ class CacheTransferManager:
|
||||
event_type,
|
||||
transfer_task_id,
|
||||
)
|
||||
elif event_type.value == CacheStatus.STORAGE2GPU.value:
|
||||
hash_keys, gpu_block_ids, timeout = data[2:]
|
||||
self.read_storage_thread_pool.submit(
|
||||
self.read_storage_task,
|
||||
transfer_task_id,
|
||||
hash_keys,
|
||||
gpu_block_ids,
|
||||
timeout,
|
||||
)
|
||||
elif event_type.value == CacheStatus.GPU2STORAGE.value:
|
||||
hash_keys, gpu_block_ids, timeout = data[2:]
|
||||
self.write_back_storage_thread_pool.submit(
|
||||
self.write_back_storage_task,
|
||||
transfer_task_id,
|
||||
hash_keys,
|
||||
gpu_block_ids,
|
||||
timeout,
|
||||
)
|
||||
else:
|
||||
if self.n_ranks > 1:
|
||||
self.cache_task_queue.barrier2.wait()
|
||||
@@ -635,11 +927,11 @@ class CacheTransferManager:
|
||||
+ f"transfer {len(gpu_block_ids)} blocks done elapsed_time {elasped_time:.4f}"
|
||||
)
|
||||
return (
|
||||
event_type,
|
||||
transfer_task_id,
|
||||
swap_node_ids,
|
||||
task_gpu_block_id,
|
||||
task_cpu_block_id,
|
||||
event_type,
|
||||
transfer_task_id,
|
||||
)
|
||||
|
||||
def clear_or_update_caches(self, args):
|
||||
@@ -738,9 +1030,7 @@ def main():
|
||||
"""
|
||||
启动cache manager
|
||||
"""
|
||||
|
||||
cache_manager = CacheTransferManager(args)
|
||||
|
||||
cache_manager.do_data_transfer()
|
||||
|
||||
|
||||
@@ -749,5 +1039,10 @@ if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
|
||||
logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_tprank{args.rank}.log")
|
||||
logger.info(f"args: {vars(args)}")
|
||||
set_device(args.device_id)
|
||||
main()
|
||||
try:
|
||||
main()
|
||||
except Exception as e:
|
||||
logger.error(f"cache_transfer_manager failed with error: {e}, traceback: {traceback.format_exc()}")
|
||||
raise
|
||||
|
||||
@@ -30,6 +30,7 @@ try:
|
||||
set_data_ipc,
|
||||
share_external_data,
|
||||
swap_cache_all_layers,
|
||||
swap_cache_layout,
|
||||
unset_data_ipc,
|
||||
)
|
||||
|
||||
@@ -50,6 +51,7 @@ try:
|
||||
)
|
||||
|
||||
unset_data_ipc = None
|
||||
swap_cache_layout = None
|
||||
memory_allocated = paddle.device.xpu.memory_allocated
|
||||
|
||||
def get_data_ptr_ipc(*args, **kwargs):
|
||||
@@ -102,6 +104,7 @@ except:
|
||||
ipc_sent_key_value_cache_by_remote_ptr_block_sync = None
|
||||
get_peer_mem_addr = None
|
||||
get_all_visible_devices = None
|
||||
swap_cache_layout = None
|
||||
|
||||
|
||||
__all__ = [
|
||||
@@ -119,4 +122,5 @@ __all__ = [
|
||||
"ipc_sent_key_value_cache_by_remote_ptr_block_sync",
|
||||
"get_peer_mem_addr",
|
||||
"get_all_visible_devices",
|
||||
"swap_cache_layout",
|
||||
]
|
||||
|
||||
@@ -14,10 +14,8 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import heapq
|
||||
import os
|
||||
import pickle
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
@@ -34,9 +32,10 @@ from fastdeploy import envs
|
||||
from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
|
||||
from fastdeploy.cache_manager.cache_metrics import CacheMetrics
|
||||
from fastdeploy.cache_manager.ops import get_all_visible_devices
|
||||
from fastdeploy.engine.request import Request
|
||||
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.utils import get_logger
|
||||
from fastdeploy.utils import get_hash_str, get_logger
|
||||
|
||||
logger = get_logger("prefix_cache_manager", "cache_manager.log")
|
||||
|
||||
@@ -65,6 +64,7 @@ class PrefixCacheManager:
|
||||
self.enable_splitwise = 0
|
||||
self.splitwise_role = splitwise_role
|
||||
self.config = config
|
||||
self.tensor_parallel_size = tensor_parallel_size
|
||||
self.cache_config = config.cache_config
|
||||
self.speculative_config = config.speculative_config
|
||||
self.local_data_parallel_id = local_data_parallel_id
|
||||
@@ -89,6 +89,13 @@ class PrefixCacheManager:
|
||||
|
||||
self.radix_tree_root = BlockNode(-1, [], 0, 0, -1, 0, None, None, None)
|
||||
|
||||
# prams for cache storage
|
||||
self.kvcache_storage_backend = self.cache_config.kvcache_storage_backend
|
||||
self.write_policy = self.cache_config.write_policy
|
||||
self.task_write_back_event = {}
|
||||
self.task_prefetch_event = {}
|
||||
self.storage_prefetch_block_ids = {}
|
||||
|
||||
# gpu cache data structure
|
||||
self.gpu_lru_leaf_heap = []
|
||||
self.gpu_lru_leaf_set = set()
|
||||
@@ -105,7 +112,7 @@ class PrefixCacheManager:
|
||||
self.req_leaf_map = {} # {request_id: leaf node}
|
||||
self.leaf_req_map = defaultdict(set)
|
||||
self.unfilled_req_block_map = defaultdict(list)
|
||||
self.cache_info = {}
|
||||
self.cache_info = {} # {request_id: (last_match_node, num_cached_tokens)}
|
||||
|
||||
self.executor_pool = ThreadPoolExecutor(max_workers=1)
|
||||
self.free_gpu_executor_pool = ThreadPoolExecutor(max_workers=1)
|
||||
@@ -253,6 +260,10 @@ class PrefixCacheManager:
|
||||
else:
|
||||
val_shape_str = str(val_cache_shape)
|
||||
val_cache_arg_str = f" --value_cache_shape {val_shape_str}"
|
||||
if cache_config.kvcache_storage_backend:
|
||||
kvcache_storage_backend_str = cache_config.kvcache_storage_backend
|
||||
else:
|
||||
kvcache_storage_backend_str = "none"
|
||||
|
||||
for i in range(tensor_parallel_size):
|
||||
launch_cmd = (
|
||||
@@ -281,7 +292,10 @@ class PrefixCacheManager:
|
||||
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
|
||||
+ f" --default_dtype '{self.config.model_config.dtype}'"
|
||||
+ (" --create_cache_tensor" if create_cache_tensor else "")
|
||||
+ f" >{log_dir}/launch_cache_transfer_manager_tprank{i}.log 2>&1"
|
||||
+ f" --kvcache_storage_backend {kvcache_storage_backend_str}"
|
||||
+ f" --write_policy {cache_config.write_policy}"
|
||||
+ f" --max_model_len {self.config.model_config.max_model_len}"
|
||||
+ f" >{log_dir}/launch_cache_transfer_manager_{int(device_ids[i])}.log 2>&1"
|
||||
)
|
||||
logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
|
||||
cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
|
||||
@@ -290,7 +304,7 @@ class PrefixCacheManager:
|
||||
while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
|
||||
time.sleep(1)
|
||||
|
||||
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
|
||||
if self.num_cpu_blocks > 0:
|
||||
while np.sum(self.swap_space_ready_signal.value) != tensor_parallel_size:
|
||||
time.sleep(1)
|
||||
|
||||
@@ -303,7 +317,7 @@ class PrefixCacheManager:
|
||||
)
|
||||
|
||||
# Start additional threads
|
||||
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
|
||||
if cache_config.kvcache_storage_backend or self.num_cpu_blocks > 0:
|
||||
logger.info("Enable hierarchical cache.")
|
||||
threading.Thread(target=self.recv_data_transfer_result).start()
|
||||
if cache_config.enable_prefix_caching:
|
||||
@@ -505,13 +519,7 @@ class PrefixCacheManager:
|
||||
|
||||
self.task_swapping_event[transfer_task_id] = Event()
|
||||
self.cache_task_queue.put_transfer_task(
|
||||
(
|
||||
swap_node_ids,
|
||||
gpu_block_ids,
|
||||
cpu_block_ids,
|
||||
event_type,
|
||||
transfer_task_id,
|
||||
)
|
||||
(event_type, transfer_task_id, swap_node_ids, gpu_block_ids, cpu_block_ids)
|
||||
)
|
||||
if is_sync:
|
||||
self.sync_swap_task(transfer_task_id)
|
||||
@@ -629,6 +637,10 @@ class PrefixCacheManager:
|
||||
can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size
|
||||
if req_id in self.leaf_req_map[last_node]: # delete old leaf record, update later
|
||||
self.leaf_req_map[last_node].remove(req_id)
|
||||
logger.debug(
|
||||
f"update_cache_blocks: req_id {req_id}, num_cached_tokens {num_cached_tokens}, "
|
||||
f"can_cache_computed_tokens {can_cache_computed_tokens}"
|
||||
)
|
||||
|
||||
with self.request_release_lock:
|
||||
leaf_node = self.mm_build_path(
|
||||
@@ -640,7 +652,7 @@ class PrefixCacheManager:
|
||||
)
|
||||
self.req_leaf_map[req_id] = leaf_node
|
||||
self.leaf_req_map[leaf_node].add(req_id)
|
||||
self.cache_info[req_id] = (leaf_node, can_cache_computed_tokens)
|
||||
self.cache_info[req_id] = [leaf_node, can_cache_computed_tokens]
|
||||
task.cached_block_num = can_cache_computed_tokens // block_size
|
||||
except Exception as e:
|
||||
logger.error(f"update_cache_blocks, error: {type(e)} {e}, {str(traceback.format_exc())}")
|
||||
@@ -692,7 +704,7 @@ class PrefixCacheManager:
|
||||
else:
|
||||
prompt_token_ids = task.prompt_token_ids
|
||||
req_id = task.request_id
|
||||
logger.info(f"request_match_blocks: start to allocate blocks for req_id {req_id}")
|
||||
logger.info(f"request_match_blocks: start to process req {req_id}")
|
||||
input_token_num = len(prompt_token_ids + task.output_token_ids)
|
||||
common_block_ids = []
|
||||
# 1. match block
|
||||
@@ -708,12 +720,14 @@ class PrefixCacheManager:
|
||||
# update matched node info
|
||||
self._update_matched_node_info(req_id, match_block_node, current_time=time.time())
|
||||
|
||||
# 2. prepare cache
|
||||
# allocate gpu cache for matched cpu blocks
|
||||
# 2. prepare cache: allocate gpu cache for matched cpu blocks, wait for data transfer to complete
|
||||
gpu_recv_block_ids = []
|
||||
match_cpu_blocks_num = len(match_cpu_block_ids)
|
||||
if self.can_allocate_gpu_blocks(num_blocks=match_cpu_blocks_num):
|
||||
if match_cpu_blocks_num > 0:
|
||||
logger.debug(
|
||||
f"request_match_blocks: req_id {req_id}, allocate {match_cpu_blocks_num} block to receive cpu cache"
|
||||
)
|
||||
gpu_recv_block_ids = self.allocate_gpu_blocks(match_cpu_blocks_num)
|
||||
if len(gpu_recv_block_ids) > 0:
|
||||
self._prepare_cpu_cache(
|
||||
@@ -746,20 +760,58 @@ class PrefixCacheManager:
|
||||
self.metrics._update_history_hit_metrics()
|
||||
if self.metrics.req_count % 10000 == 0:
|
||||
self.metrics.reset_metrics()
|
||||
logger.info(
|
||||
f"request_match_blocks: request block for req_id {req_id}: common_block_ids {common_block_ids}"
|
||||
)
|
||||
logger.info(f"request_match_blocks: req_id {req_id}, matched_block_ids {common_block_ids}")
|
||||
# set leaf node temporarily, then update it in update_cache_blocks
|
||||
self.req_leaf_map[req_id] = match_block_node
|
||||
self.leaf_req_map[match_block_node].add(req_id)
|
||||
# record request cache info
|
||||
self.cache_info[req_id] = (match_block_node, len(common_block_ids) * block_size)
|
||||
self.cache_info[req_id] = [match_block_node, len(common_block_ids) * block_size]
|
||||
task.cached_block_num = len(common_block_ids)
|
||||
return common_block_ids, matched_token_num, hit_info
|
||||
except Exception as e:
|
||||
logger.error(f"request_match_blocks: request_block_ids: error: {type(e)} {e}")
|
||||
raise e
|
||||
|
||||
def request_match_storage_blocks(self, request, extra_gpu_block_ids):
|
||||
"""
|
||||
Match and fetch the cached blocks from the storage backend for the given request.
|
||||
# TODO: merge this function into request_match_blocks
|
||||
|
||||
args:
|
||||
request: The request to be processed
|
||||
extra_gpu_block_ids: A list of GPU block IDs to be used for fetching the cache
|
||||
returns:
|
||||
matched_block_ids: A list of block IDs that prefetched cache from storage
|
||||
"""
|
||||
if self.kvcache_storage_backend is None:
|
||||
return []
|
||||
|
||||
req_id = request.request_id
|
||||
input_ids = request.prompt_token_ids
|
||||
block_size = self.cache_config.block_size
|
||||
|
||||
prefix_block_key = []
|
||||
num_cached_tokens = 0
|
||||
if req_id in self.cache_info:
|
||||
last_node, num_cached_tokens = self.cache_info[req_id]
|
||||
prefix_block_key = [] if last_node.hash_value is None else [last_node.hash_value]
|
||||
|
||||
block_keys = []
|
||||
current_tokens = num_cached_tokens
|
||||
while current_tokens <= len(input_ids) - block_size:
|
||||
cur_block_key = get_hash_str(input_ids[current_tokens : current_tokens + block_size], prefix_block_key)
|
||||
block_keys.append(cur_block_key)
|
||||
current_tokens += block_size
|
||||
prefix_block_key = [cur_block_key]
|
||||
|
||||
logger.info(f"start prefetch cache from storage, req_id: {req_id}, block num: {len(block_keys)}")
|
||||
matched_block_ids = self.issue_prefetch_storage_task(req_id, block_keys, extra_gpu_block_ids)
|
||||
logger.info(
|
||||
f"finish prefetch cache from storage, req_id: {req_id}, matched block num: {len(matched_block_ids)}"
|
||||
)
|
||||
|
||||
return matched_block_ids
|
||||
|
||||
def request_block_ids(self, task, block_size, dec_token_num, *args):
|
||||
"""
|
||||
Allocate blocks for a task.
|
||||
@@ -806,10 +858,7 @@ class PrefixCacheManager:
|
||||
current_time = time.time()
|
||||
self._update_matched_node_info(req_id, match_block_node, current_time)
|
||||
# 2. prepare cache
|
||||
(
|
||||
gpu_recv_block_ids,
|
||||
gpu_extra_block_ids,
|
||||
) = self._prepare_cache(
|
||||
(gpu_recv_block_ids, gpu_extra_block_ids) = self._prepare_cache(
|
||||
req_id,
|
||||
input_ids,
|
||||
block_size,
|
||||
@@ -829,7 +878,6 @@ class PrefixCacheManager:
|
||||
gpu_build_path_block_ids = []
|
||||
|
||||
gpu_build_path_block_ids = gpu_extra_block_ids
|
||||
|
||||
leaf_node = self.build_path(
|
||||
req_id,
|
||||
current_time,
|
||||
@@ -883,6 +931,7 @@ class PrefixCacheManager:
|
||||
with self.request_release_lock:
|
||||
try:
|
||||
req_id = task.request_id
|
||||
keys = []
|
||||
leaf_node = self.req_leaf_map.pop(req_id)
|
||||
if leaf_node in self.leaf_req_map:
|
||||
self.leaf_req_map[leaf_node].remove(req_id)
|
||||
@@ -893,6 +942,7 @@ class PrefixCacheManager:
|
||||
if req_id in node.req_id_set:
|
||||
node.req_id_set.remove(req_id)
|
||||
node.decrement_shared_count()
|
||||
keys.append(node.hash_value)
|
||||
node = node.parent
|
||||
|
||||
if req_id in self.cache_info:
|
||||
@@ -919,6 +969,78 @@ class PrefixCacheManager:
|
||||
logger.error(f"release_block_ids: error: {type(e)} {e}, {str(traceback.format_exc())}")
|
||||
raise e
|
||||
|
||||
def write_cache_to_storage(self, request: Request):
|
||||
"""
|
||||
For finished request, write cache to storage.
|
||||
NOTE: this function does not modify the global params
|
||||
"""
|
||||
if self.kvcache_storage_backend is None:
|
||||
return
|
||||
|
||||
req_id = request.request_id
|
||||
keys = []
|
||||
node = self.req_leaf_map[req_id]
|
||||
while node != self.radix_tree_root:
|
||||
keys.append(node.hash_value)
|
||||
node = node.parent
|
||||
keys = list(reversed(keys))
|
||||
if not keys:
|
||||
return
|
||||
|
||||
gpu_block_ids = request.block_tables[: len(keys)]
|
||||
logger.info(f"start write cache back to storage, req_id: {req_id}, block num: {len(keys)}")
|
||||
tic = time.time()
|
||||
self.issue_write_back_storage_task(req_id=req_id, hash_keys=keys, gpu_block_ids=gpu_block_ids, is_sync=True)
|
||||
cost_time = time.time() - tic
|
||||
logger.info(f"finish write cache back to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s")
|
||||
|
||||
def issue_write_back_storage_task(self, req_id, hash_keys, gpu_block_ids, is_sync=True, timeout=0.5):
|
||||
if self.kvcache_storage_backend is None:
|
||||
return
|
||||
|
||||
if len(hash_keys) != len(gpu_block_ids):
|
||||
err_msg = f"write_back_storage error: hash_keys({len(hash_keys)}) != gpu_block_ids({len(gpu_block_ids)})"
|
||||
logger.error(err_msg)
|
||||
raise ValueError(err_msg)
|
||||
|
||||
self.task_write_back_event[req_id] = Event()
|
||||
self.cache_task_queue.put_transfer_task((CacheStatus.GPU2STORAGE, req_id, hash_keys, gpu_block_ids, timeout))
|
||||
if is_sync:
|
||||
self.wait_write_storage_task(req_id)
|
||||
|
||||
def wait_write_storage_task(self, req_id):
|
||||
"""
|
||||
Sync write back task
|
||||
"""
|
||||
if req_id in self.task_write_back_event:
|
||||
self.task_write_back_event[req_id].wait()
|
||||
del self.task_write_back_event[req_id]
|
||||
|
||||
def issue_prefetch_storage_task(self, req_id, hash_keys, gpu_block_ids, is_sync=True, timeout=0.5):
|
||||
"""
|
||||
Prefetch cache from storage task
|
||||
"""
|
||||
storage_block_ids = []
|
||||
self.task_prefetch_event[req_id] = Event()
|
||||
# issue task to cache_transfer_manager
|
||||
self.cache_task_queue.put_transfer_task((CacheStatus.STORAGE2GPU, req_id, hash_keys, gpu_block_ids, timeout))
|
||||
if is_sync:
|
||||
storage_block_ids = self.wait_prefetch_storage_task(req_id)
|
||||
return storage_block_ids
|
||||
|
||||
def wait_prefetch_storage_task(self, req_id):
|
||||
"""
|
||||
Wait for prefetch cache from storage task to finish
|
||||
"""
|
||||
if req_id not in self.task_prefetch_event:
|
||||
return None
|
||||
|
||||
self.task_prefetch_event[req_id].wait()
|
||||
storage_block_ids = self.storage_prefetch_block_ids[req_id]
|
||||
del self.task_prefetch_event[req_id]
|
||||
del self.storage_prefetch_block_ids[req_id]
|
||||
return storage_block_ids
|
||||
|
||||
def free_nodes_directly(self, node):
|
||||
with self.request_release_lock:
|
||||
try:
|
||||
@@ -1069,10 +1191,7 @@ class PrefixCacheManager:
|
||||
break
|
||||
node = heapq.heappop(self.gpu_lru_leaf_heap)
|
||||
self.gpu_lru_leaf_set.remove(node)
|
||||
if (
|
||||
not self.cache_config.enable_hierarchical_cache
|
||||
or self.cache_config.num_cpu_blocks < need_block_num
|
||||
):
|
||||
if self.cache_config.num_cpu_blocks < need_block_num:
|
||||
if node.shared_count == 0 and node.is_gpu_leaf_node: # 直接回收
|
||||
self._handle_free_gpu_node_without_cpu(node)
|
||||
total_gpu_free_count += 1
|
||||
@@ -1195,12 +1314,6 @@ class PrefixCacheManager:
|
||||
)
|
||||
return total_cpu_free_count
|
||||
|
||||
def cal_block_hash(self, block):
|
||||
"""
|
||||
calculate hash value of a block
|
||||
"""
|
||||
return hash(tuple(block))
|
||||
|
||||
def get_block_hash_extra_keys(self, request, start_idx, end_idx, mm_idx):
|
||||
"""
|
||||
Retrieves additional hash keys for block identification.
|
||||
@@ -1260,16 +1373,6 @@ class PrefixCacheManager:
|
||||
hash_keys.append(mm_inputs["mm_hashes"][img_idx])
|
||||
return len(mm_inputs["mm_positions"]) - 1, hash_keys
|
||||
|
||||
def hash_block_features(self, input_ids, extra_keys: list = []):
|
||||
"""
|
||||
calculate hash value of a block with additional keys
|
||||
|
||||
Args:
|
||||
input_ids: Input token IDs
|
||||
extra_keys: Additional keys for block identification
|
||||
"""
|
||||
return hashlib.sha256(pickle.dumps((input_ids, extra_keys))).hexdigest()
|
||||
|
||||
def _revert_match_blocks(
|
||||
self,
|
||||
request,
|
||||
@@ -1363,6 +1466,7 @@ class PrefixCacheManager:
|
||||
matche_nodes = []
|
||||
has_modified_gpu_lru_leaf_heap = False
|
||||
has_modified_cpu_lru_leaf_heap = False
|
||||
prefix_block_key = []
|
||||
|
||||
with self.cache_status_lock:
|
||||
while match_token_num < total_token_num:
|
||||
@@ -1376,7 +1480,10 @@ class PrefixCacheManager:
|
||||
end_idx=match_token_num + block_size,
|
||||
mm_idx=mm_idx,
|
||||
)
|
||||
hash_value = self.hash_block_features(token_block, extra_keys)
|
||||
prefix_block_key.extend(extra_keys)
|
||||
hash_value = get_hash_str(token_block, prefix_block_key)
|
||||
prefix_block_key = [hash_value]
|
||||
|
||||
if hash_value in current_match_node.children:
|
||||
child = current_match_node.children[hash_value]
|
||||
matche_nodes.append(child)
|
||||
@@ -1476,6 +1583,7 @@ class PrefixCacheManager:
|
||||
matche_nodes = []
|
||||
has_modified_gpu_lru_leaf_heap = False
|
||||
has_modified_cpu_lru_leaf_heap = False
|
||||
prefix_block_key = []
|
||||
|
||||
with self.cache_status_lock:
|
||||
while match_token_num < total_token_num:
|
||||
@@ -1483,7 +1591,8 @@ class PrefixCacheManager:
|
||||
token_num = len(token_block)
|
||||
if token_num != block_size:
|
||||
break
|
||||
hash_value = self.cal_block_hash(token_block)
|
||||
hash_value = get_hash_str(token_block, prefix_block_key)
|
||||
prefix_block_key = [hash_value]
|
||||
if hash_value in current_match_node.children:
|
||||
child = current_match_node.children[hash_value]
|
||||
matche_nodes.append(child)
|
||||
@@ -1515,6 +1624,8 @@ class PrefixCacheManager:
|
||||
swap_node_ids.append(child.node_id)
|
||||
match_token_num = match_token_num + block_size
|
||||
current_match_node = child
|
||||
# record request cache info
|
||||
self.cache_info[req_id] = [child, match_token_num]
|
||||
else:
|
||||
break
|
||||
|
||||
@@ -1577,8 +1688,10 @@ class PrefixCacheManager:
|
||||
has_unfilled_block = False
|
||||
current_time = time.time()
|
||||
|
||||
input_hash_value = self.hash_block_features(input_ids)
|
||||
input_hash_value = get_hash_str(input_ids)
|
||||
gpu_block_ids = request.block_tables[num_cached_tokens // block_size :].copy()
|
||||
prefix_block_key = [] if last_node.hash_value is None else [last_node.hash_value]
|
||||
|
||||
for i in range(num_cached_tokens, can_cache_computed_tokens, block_size):
|
||||
current_block = input_ids[i : i + block_size]
|
||||
current_block_size = len(current_block) # 最后一个block可能没填满
|
||||
@@ -1591,7 +1704,9 @@ class PrefixCacheManager:
|
||||
end_idx=i + block_size,
|
||||
mm_idx=mm_idx,
|
||||
)
|
||||
hash_value = self.hash_block_features(current_block, extra_keys)
|
||||
prefix_block_key.extend(extra_keys)
|
||||
hash_value = get_hash_str(current_block, prefix_block_key)
|
||||
prefix_block_key = [hash_value]
|
||||
allocated_block_id = gpu_block_ids.pop(0)
|
||||
node_id = self.node_id_pool.pop()
|
||||
unique_node_ids.append(node_id)
|
||||
@@ -1651,7 +1766,7 @@ class PrefixCacheManager:
|
||||
gpu_block_ids = gpu_block_ids.copy()
|
||||
node = last_node
|
||||
reverved_dec_block_ids = []
|
||||
input_hash_value = self.cal_block_hash(input_ids)
|
||||
input_hash_value = get_hash_str(input_ids)
|
||||
|
||||
token_num = len(left_input_ids)
|
||||
if token_num == 0:
|
||||
@@ -1663,6 +1778,7 @@ class PrefixCacheManager:
|
||||
unique_node_ids = []
|
||||
new_last_node = last_node
|
||||
has_unfilled_block = False
|
||||
prefix_block_key = [] if last_node.hash_value is None else [last_node.hash_value]
|
||||
|
||||
for i in range(0, token_num, block_size):
|
||||
current_block = left_input_ids[i : i + block_size]
|
||||
@@ -1670,7 +1786,8 @@ class PrefixCacheManager:
|
||||
if current_block_size != block_size:
|
||||
has_unfilled_block = True
|
||||
else:
|
||||
hash_value = self.cal_block_hash(current_block)
|
||||
hash_value = get_hash_str(current_block, prefix_block_key)
|
||||
prefix_block_key = [hash_value]
|
||||
allocated_block_id = gpu_block_ids.pop(0)
|
||||
node_id = self.node_id_pool.pop()
|
||||
unique_node_ids.append(node_id)
|
||||
@@ -1764,28 +1881,47 @@ class PrefixCacheManager:
|
||||
if data is None:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
(
|
||||
swap_node_ids,
|
||||
task_gpu_block_id,
|
||||
task_cpu_block_id,
|
||||
event_type,
|
||||
transfer_task_id,
|
||||
) = data
|
||||
length = len(task_gpu_block_id)
|
||||
for i in range(length):
|
||||
self._handle_swap_result(
|
||||
swap_node_ids[i],
|
||||
task_gpu_block_id[i],
|
||||
task_cpu_block_id[i],
|
||||
event_type = data[0]
|
||||
|
||||
if event_type.value == CacheStatus.STORAGE2GPU.value:
|
||||
logger.info(f"recv_data_transfer_result: {data}")
|
||||
task_id, hash_keys, block_ids = data[1:]
|
||||
if task_id not in self.storage_prefetch_block_ids:
|
||||
self.storage_prefetch_block_ids[task_id] = []
|
||||
saved_block_ids = self.storage_prefetch_block_ids[task_id]
|
||||
saved_block_ids.append(block_ids)
|
||||
if len(saved_block_ids) == self.tensor_parallel_size:
|
||||
self.storage_prefetch_block_ids[task_id] = min(saved_block_ids, key=len)
|
||||
if task_id in self.task_prefetch_event:
|
||||
self.task_prefetch_event[task_id].set()
|
||||
elif event_type.value == CacheStatus.GPU2STORAGE.value:
|
||||
logger.info(f"recv_data_transfer_result: {data}")
|
||||
task_id, hash_keys, block_ids = data[1:]
|
||||
if task_id in self.task_write_back_event:
|
||||
self.task_write_back_event[task_id].set()
|
||||
else:
|
||||
(
|
||||
event_type,
|
||||
transfer_task_id,
|
||||
swap_node_ids,
|
||||
task_gpu_block_id,
|
||||
task_cpu_block_id,
|
||||
) = data
|
||||
length = len(task_gpu_block_id)
|
||||
for i in range(length):
|
||||
self._handle_swap_result(
|
||||
swap_node_ids[i],
|
||||
task_gpu_block_id[i],
|
||||
task_cpu_block_id[i],
|
||||
event_type,
|
||||
)
|
||||
if transfer_task_id in self.task_swapping_event:
|
||||
self.task_swapping_event[transfer_task_id].set()
|
||||
logger.info(
|
||||
f"recv_data_transfer_result: transfer_task_id {transfer_task_id}: "
|
||||
+ f"task_node_ids {swap_node_ids} task_gpu_block_id {task_gpu_block_id} "
|
||||
+ f"task_cpu_block_id {task_cpu_block_id} event_type {event_type} done"
|
||||
)
|
||||
if transfer_task_id in self.task_swapping_event:
|
||||
self.task_swapping_event[transfer_task_id].set()
|
||||
logger.info(
|
||||
f"recv_data_transfer_result: transfer_task_id {transfer_task_id}: "
|
||||
+ f"task_node_ids {swap_node_ids} task_gpu_block_id {task_gpu_block_id} "
|
||||
+ f"task_cpu_block_id {task_cpu_block_id} event_type {event_type} done"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"recv_data_transfer_result: error: {e}, {str(traceback.format_exc())}")
|
||||
raise e
|
||||
|
||||
@@ -14,7 +14,21 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from .ipc_cache_transfer import IPCCommManager
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
from .kvcache_storage import KVCacheStorage
|
||||
from .mooncake_store import MooncakeStore
|
||||
from .rdma_cache_transfer import RDMACommManager
|
||||
|
||||
__all__ = ["IPCCommManager", "RDMACommManager"]
|
||||
if current_platform.is_cuda():
|
||||
from .ipc_cache_transfer import IPCCommManager
|
||||
else:
|
||||
IPCCommManager = None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"IPCCommManager",
|
||||
"RDMACommManager",
|
||||
"KVCacheStorage",
|
||||
"MooncakeStore",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
logger = get_logger("cache_storage", "cache_storage.log")
|
||||
|
||||
|
||||
class KVCacheStorage(ABC):
    """
    KVCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.

    Concrete backends (e.g. a Mooncake-backed store) implement single-key and
    batched get/set operations, an existence check, and a bulk clear. The
    batched variants take parallel lists of keys, memory locations and sizes.
    """

    @abstractmethod
    def get(
        self,
        key: str,
        target_location: Optional[Any] = None,
        target_size: Optional[Any] = None,
    ) -> paddle.Tensor | None:
        """
        Retrieve the value associated with the given key.

        Args:
            key: identifier of the cached object.
            target_location: optional destination for the data — backend-specific
                (presumably a raw buffer pointer for zero-copy reads; confirm per backend).
            target_size: optional byte size of the destination.

        Returns None if the key does not exist.
        """
        pass

    @abstractmethod
    def batch_get(
        self,
        keys: List[str],
        target_locations: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> List[paddle.Tensor | None]:
        """
        Retrieve values for multiple keys.

        Args:
            keys: identifiers of the cached objects.
            target_locations: optional destinations, parallel to `keys`.
            target_sizes: optional byte sizes, parallel to `keys`.

        Returns a list of tensors or None for each key.
        """
        pass

    @abstractmethod
    def set(
        self,
        key: str,
        target_location: Optional[Any] = None,
        target_size: Optional[Any] = None,
    ) -> bool:
        """
        Store the value associated with the given key.

        Args:
            key: identifier of the object to store.
            target_location: optional source of the data (backend-specific).
            target_size: optional byte size of the source.

        Returns True if the operation was successful, False otherwise.
        """
        pass

    @abstractmethod
    def batch_set(
        self,
        keys: List[str],
        target_locations: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> bool:
        """
        Store multiple key-value pairs.

        Args:
            keys: identifiers of the objects to store.
            target_locations: optional sources, parallel to `keys`.
            target_sizes: optional byte sizes, parallel to `keys`.

        Returns True if all operations were successful, False otherwise.
        """
        pass

    @abstractmethod
    def exists(self, keys: List[str]) -> bool:
        """
        Check whether the given keys exist in the storage.

        NOTE(review): the `-> bool` annotation may be too narrow — the Mooncake
        implementation returns a per-key mapping ``{key: bool}``; confirm the
        intended contract and align implementations.
        """
        pass

    @abstractmethod
    def clear(self) -> bool:
        """
        Clear all keys in storage
        """
        pass
|
||||
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from .mooncake_store import MooncakeStore
|
||||
|
||||
__all__ = ["MooncakeStore"]
|
||||
@@ -0,0 +1,320 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from fastdeploy.cache_manager.transfer_factory.kvcache_storage import (
|
||||
KVCacheStorage,
|
||||
logger,
|
||||
)
|
||||
from fastdeploy.cache_manager.transfer_factory.utils import get_rdma_nics
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
DEFAULT_GLOBAL_SEGMENT_SIZE = 1024 * 1024 * 1024  # 1 GiB
DEFAULT_LOCAL_BUFFER_SIZE = 128 * 1024 * 1024  # 128MB


@dataclass
class MooncakeStoreConfig:
    """
    Connection/transport settings for a Mooncake distributed store.

    Attributes:
        local_hostname: hostname this process announces to the store.
        metadata_server: address of the Mooncake metadata server.
        global_segment_size: size (bytes) of the globally shared segment.
        local_buffer_size: size (bytes) of the local transfer buffer.
        protocol: transport protocol, e.g. "rdma" or "tcp".
        rdma_devices: comma-separated RDMA NIC list (may be narrowed to a
            single device via `select_rdma_device`).
        master_server_addr: address of the Mooncake master server.
    """

    local_hostname: str
    metadata_server: str
    global_segment_size: int
    local_buffer_size: int
    protocol: str
    rdma_devices: str
    master_server_addr: str

    @staticmethod
    def create() -> "MooncakeStoreConfig":
        """Load the config from a JSON file (MOONCAKE_CONFIG_PATH) or environment variables.

        Raises:
            FileNotFoundError: if MOONCAKE_CONFIG_PATH is set but the file does not exist.
        """
        file_path = os.getenv("MOONCAKE_CONFIG_PATH")

        if file_path is None:
            # No config file: read every field from environment variables.
            local_hostname = os.environ.get("MOONCAKE_LOCAL_HOSTNAME")
            metadata_server = os.environ.get("MOONCAKE_METADATA_SERVER")
            global_segment_size = os.environ.get("MOONCAKE_GLOBAL_SEGMENT_SIZE", DEFAULT_GLOBAL_SEGMENT_SIZE)
            local_buffer_size = os.environ.get("MOONCAKE_LOCAL_BUFFER_SIZE", DEFAULT_LOCAL_BUFFER_SIZE)
            protocol = os.environ.get("MOONCAKE_PROTOCOL", "rdma")
            rdma_devices = os.environ.get("MOONCAKE_RDMA_DEVICES", "")
            master_server_addr = os.environ.get("MOONCAKE_MASTER_SERVER_ADDR")
        else:
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"File path {file_path} for creating MooncakeStoreConfig does not exist.")
            with open(file_path) as fin:
                config = json.load(fin)

            local_hostname = config.get("local_hostname")
            metadata_server = config.get("metadata_server")
            global_segment_size = config.get("global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE)
            local_buffer_size = config.get("local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE)
            protocol = config.get("protocol", "rdma")
            rdma_devices = config.get("rdma_devices", "")
            master_server_addr = config.get("master_server_addr")

        # BUG FIX: environment variables (and possibly JSON) provide the sizes
        # as strings, but the store setup API expects integers — normalize here.
        global_segment_size = int(global_segment_size)
        local_buffer_size = int(local_buffer_size)

        if rdma_devices == "" and current_platform.is_cuda():
            # FIXME: use auto-select NICs in MooncakeStore will raise error and roll back to using TCP
            rdma_devices = get_rdma_nics()
            logger.info(f"No RDMA devices specified, defaulting to all available devices: {rdma_devices}")

        return MooncakeStoreConfig(
            local_hostname=local_hostname,
            metadata_server=metadata_server,
            global_segment_size=global_segment_size,
            local_buffer_size=local_buffer_size,
            protocol=protocol,
            rdma_devices=rdma_devices,
            master_server_addr=master_server_addr,
        )

    def select_rdma_device(self, tp_rank):
        """Select RDMA device based on rank number.

        Narrows `rdma_devices` from a comma-separated list to the single device
        at index `tp_rank % len(devices)`, so each TP rank gets one NIC.
        """
        device_list = self.rdma_devices.split(",")
        device_index = tp_rank % len(device_list)
        self.rdma_devices = device_list[device_index]
|
||||
|
||||
|
||||
class MooncakeStore(KVCacheStorage):
    """
    KVCacheStorage backend backed by a Mooncake distributed store.

    Data moves through zero-copy batch APIs: callers pass raw buffer pointers
    and byte sizes (see `batch_set` / `batch_get`). Buffers used for zero-copy
    transfers should first be registered via `register_buffer`.
    """

    def __init__(self, tp_rank=None):
        """
        Connect to the Mooncake distributed store.

        Args:
            tp_rank: tensor-parallel rank; when given, it is used to narrow the
                configured RDMA device list down to a single device.

        Raises:
            ImportError: if the `mooncake` package is not installed.
            RuntimeError: if the store setup returns a non-zero code.
        """
        super().__init__()
        self.tp_rank = tp_rank

        try:
            from mooncake.store import MooncakeDistributedStore
        except ImportError as e:
            raise ImportError(
                "Please install mooncake store by following the instructions at "
                "https://kvcache-ai.github.io/Mooncake/python-api-reference/mooncake-store.html"
                "to run Fastdeploy with mooncake store."
            ) from e

        try:
            self.store = MooncakeDistributedStore()
            self.config = MooncakeStoreConfig.create()
            if self.tp_rank is not None:
                # Pin this rank to a single RDMA device from the configured list.
                self.config.select_rdma_device(self.tp_rank)
            logger.info(f"Mooncake Configuration loaded, {self.config}.")

            ret_code = self.store.setup(
                local_hostname=self.config.local_hostname,
                metadata_server=self.config.metadata_server,
                global_segment_size=self.config.global_segment_size,
                local_buffer_size=self.config.local_buffer_size,
                protocol=self.config.protocol,
                rdma_devices=self.config.rdma_devices,
                master_server_addr=self.config.master_server_addr,
            )
            if ret_code != 0:
                logger.error(f"failed to setup mooncake store, error code: {ret_code}")
                raise RuntimeError(f"failed to setup mooncake store, error code: {ret_code}")
            logger.info("Connect to Mooncake store successfully.")

            # Exercise a small put/get/remove round-trip up front so connection
            # problems surface at startup rather than on the first request.
            self.warmup()
            logger.info("Mooncake store warmup successfully.")
        except Exception as e:
            logger.error(f"Mooncake store initialization failed: {e}, traceback: {traceback.format_exc()}")
            raise

    def warmup(self):
        """Round-trip a 1 MB dummy object (put, exist-check, get, remove) to verify connectivity."""
        warmup_key = "fastdeploy_mooncake_store_warmup_key" + str(uuid.uuid4())
        warmup_value = bytes(1 * 1024 * 1024)  # 1 MB
        self.store.put(warmup_key, warmup_value)
        assert self.store.is_exist(warmup_key) == 1
        self.store.get(warmup_key)
        self.store.remove(warmup_key)

    def register_buffer(self, buffer_ptr, buffer_size) -> None:
        """
        Register a memory region with the store for zero-copy transfers.

        Args:
            buffer_ptr: raw pointer (int address) of the buffer.
            buffer_size: size of the buffer in bytes.

        Raises:
            TypeError: if the underlying store rejects the arguments.
        """
        try:
            ret_code = self.store.register_buffer(buffer_ptr, buffer_size)
            # Non-zero return code indicates failure; logged but not raised.
            if ret_code:
                logger.error(f"failed to register buffer, error code: {ret_code}")
        except TypeError as err:
            logger.error("Failed to register buffer to Mooncake Store: %s", err)
            raise TypeError("Mooncake Store Register Buffer Error.") from err

    def set(
        self,
        key,
        target_location: Optional[List[int]] = None,
        target_size: Optional[List[int]] = None,
    ) -> List[int]:
        """Single-key put — not implemented; use `batch_set` (zero-copy batch path)."""
        pass

    def batch_set(
        self,
        keys: List[str],
        target_locations: Optional[List[int]] = None,
        target_sizes: Optional[List[int]] = None,
    ) -> List[int]:
        """
        Batch put multiple objects into the store.
        Args:
            keys (list): list of object names to be stored
            target_locations (list): list of memory locations where the data are stored
            target_sizes (list): list of byte sizes corresponding to each object
        Return:
            List[int]: List of status codes for each operation (0 = success, negative = error)
        Raises:
            ValueError: if the input lists differ in length or are empty.
        """
        if not (len(keys) == len(target_locations) == len(target_sizes)):
            err_msg = "The length of keys, target_location and target_sizes must match."
            logger.error(err_msg)
            raise ValueError(err_msg)

        if len(keys) == 0:
            err_msg = "The length of keys, target_location and target_sizes must be greater than zero"
            logger.error(err_msg)
            raise ValueError(err_msg)

        return self._put_batch_zero_copy_impl(keys, target_locations, target_sizes)

    def get(
        self,
        key,
        target_location: Optional[Any] = None,
        target_size: Optional[Any] = None,
    ) -> List[int]:
        """Single-key get — not implemented; use `batch_get` (zero-copy batch path)."""
        pass

    def batch_get(
        self,
        keys: List[str],
        target_locations: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> List[int]:
        """
        Batch get multiple objects from the store.
        Args:
            keys (list): list of object names to be fetched
            target_locations (list): list of memory locations where the data should be stored
            target_sizes (list): list of byte sizes corresponding to each object
        Returns:
            List[int]: List of bytes read for each operation (positive = success, negative = error)
        Raises:
            ValueError: if the input lists differ in length or are empty.
        """
        if not (len(keys) == len(target_locations) == len(target_sizes)):
            err_msg = "The length of keys, target_locations and target_sizes must match."
            logger.error(err_msg)
            raise ValueError(err_msg)

        if len(keys) == 0:
            err_msg = "The length of keys, target_locations and target_sizes must be greater than zero"
            logger.error(err_msg)
            raise ValueError(err_msg)

        return self._get_batch_zero_copy_impl(keys, target_locations, target_sizes)

    def exists(self, keys: List[str]) -> dict:
        """
        Check existence of multiple objects in a single batch operation.
        Args:
            keys (list): list of object names to be checked
        Returns:
            dict: dictionary mapping each key to its existence status {key: True|False}
        """
        tic = time.time()
        result = {k: v for k, v in zip(keys, self.store.batch_is_exist(keys))}
        cost_time = (time.time() - tic) * 1000
        logger.debug(f"The exists fun processes {len(keys)} objects, cost_time: {cost_time:.3f}ms")
        return result

    def delete(self, key, timeout=5) -> bool:
        """
        Remove a key, retrying once per second for up to `timeout` attempts.

        Returns:
            bool: True if the key was removed, False if all attempts failed.
        """
        while timeout:
            result = self.store.remove(key)
            # The store returns 0 on successful removal.
            if result == 0:
                logger.info("Successfully removed")
                return True
            else:
                time.sleep(1)
                timeout -= 1
        return False

    def close(self):
        """Release the store connection (currently a no-op)."""
        # MooncakeDistributedStore will automatically call the destructor, so
        # it is unnecessary to close it manually.
        pass

    def clear(self) -> bool:
        """
        clear all the objects in the store
        """
        count = self.store.remove_all()
        logger.info(f"Removed {count} objects")
        return True

    def _put_batch_zero_copy_impl(
        self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
    ) -> List[int]:
        """Zero-copy batch put: write `buffer_sizes[i]` bytes from `buffer_ptrs[i]` under `key_strs[i]`."""
        try:
            tic = time.time()
            result = self.store.batch_put_from(key_strs, buffer_ptrs, buffer_sizes)
            # List[int]: List of status codes for each operation (0 = success, negative = error)
            cost_time = time.time() - tic

            total_num = len(key_strs)
            success_num = result.count(0)
            if success_num == total_num:
                logger.debug(
                    f"Put all data into Mooncake Store successfully."
                    f"success_num: {success_num}, cost_time: {cost_time:.6f}s"
                )
            else:
                logger.error(
                    f"Some of the data was not put into Mooncake Store."
                    f"total_num: {total_num}, success_num: {success_num}, cost_time: {cost_time:.6f}s"
                )
            if success_num > 0:
                # Throughput is computed over successfully written bytes only.
                total_bytes = sum(bi for ri, bi in zip(result, buffer_sizes) if ri == 0)
                total_gb = total_bytes / 1073741824
                speed = total_gb / cost_time
                logger.info(f"Put data into Mooncake Store, total_gb: {total_gb:.6f}GB, speed: {speed:.6f}GB/s")

            return result
        except Exception as err:
            logger.error("Failed to put data into Mooncake Store: %s", err)
            raise

    def _get_batch_zero_copy_impl(
        self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
    ) -> List[int]:
        """Zero-copy batch get: read each key's bytes directly into `buffer_ptrs[i]`."""
        try:
            tic = time.time()
            result = self.store.batch_get_into(key_strs, buffer_ptrs, buffer_sizes)
            # List[int]: List of bytes read for each operation (positive = success, negative = error)
            cost_time = time.time() - tic

            total_num = len(key_strs)
            success_num = sum(x > 0 for x in result)
            if success_num == total_num:
                logger.debug(
                    f"Get all data from Mooncake Store successfully. "
                    f"success_num: {success_num}, cost_time: {cost_time:.6f}s"
                )
            else:
                logger.error(
                    f"Some of the data was not get from Mooncake Store."
                    f"total_num:{total_num}, success_num: {success_num}, cost_time: {cost_time:.6f}s"
                )
            if success_num > 0:
                # Throughput is computed over successfully read bytes only.
                total_bytes = sum(bi for ri, bi in zip(result, buffer_sizes) if ri > 0)
                total_gb = total_bytes / 1073741824
                speed = total_gb / cost_time
                logger.info(f"Get data from Mooncake Store, total_gb: {total_gb:.6f}GB, speed: {speed:.6f}GB/s")

            return result
        except Exception as err:
            logger.error("Failed to get data from Mooncake Store: %s", err)
            raise
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
import traceback
|
||||
|
||||
from fastdeploy.cache_manager.transfer_factory.utils import get_rdma_nics
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
logger = get_logger("cache_messager", "cache_messager.log")
|
||||
@@ -40,7 +41,6 @@ class RDMACommManager:
|
||||
prefill_tp_idx,
|
||||
):
|
||||
try:
|
||||
import importlib
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
@@ -66,28 +66,9 @@ class RDMACommManager:
|
||||
logger.info("Setting environment variable: export KVCACHE_GDRCOPY_FLUSH_ENABLE=1")
|
||||
|
||||
if os.getenv("KVCACHE_RDMA_NICS", "") == "" and current_platform.is_cuda():
|
||||
res = importlib.resources.files("fastdeploy.cache_manager.transfer_factory") / "get_rdma_nics.sh"
|
||||
get_rdma_nics = None
|
||||
with importlib.resources.as_file(res) as path:
|
||||
get_rdma_nics = str(path)
|
||||
nic_type = current_platform.device_name
|
||||
command = ["bash", get_rdma_nics, nic_type]
|
||||
result = subprocess.run(
|
||||
command,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
logger.info(f"get_rdma_nics command: {command}")
|
||||
logger.info(f"get_rdma_nics output: {result.stdout}")
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Failed to execute script `get_rdma_nics.sh`: {result.stderr.strip()}")
|
||||
|
||||
env_name, env_value = result.stdout.strip().split("=")
|
||||
assert env_name == "KVCACHE_RDMA_NICS"
|
||||
os.environ[env_name] = env_value
|
||||
logger.info(f"Setting environment variable: export {env_name}={env_value}")
|
||||
rdma_nics = get_rdma_nics()
|
||||
os.environ["KVCACHE_RDMA_NICS"] = rdma_nics
|
||||
logger.info(f"Setting environment variable: export KVCACHE_RDMA_NICS={rdma_nics}")
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to initialize RDMA environment! {e} {traceback.format_exc()}")
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import subprocess
|
||||
|
||||
from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
logger = get_logger("cache_messager", "cache_messager.log")
|
||||
|
||||
|
||||
def get_rdma_nics():
    """
    Discover RDMA NICs by running the bundled `get_rdma_nics.sh` helper script.

    The script prints a single `KVCACHE_RDMA_NICS=<value>` line; this function
    parses it and returns the value (a comma-separated NIC list).

    Returns:
        str: the NIC list emitted by the script.

    Raises:
        RuntimeError: if the script exits with a non-zero status.
        ValueError: if the script emits an unexpected variable name.
    """
    # BUG FIX: `import importlib` alone does not guarantee the `resources`
    # submodule is loaded; import it explicitly so `importlib.resources`
    # is always available.
    import importlib.resources

    res = importlib.resources.files("fastdeploy.cache_manager.transfer_factory") / "get_rdma_nics.sh"
    with importlib.resources.as_file(res) as path:
        file_path = str(path)

    nic_type = current_platform.device_name
    command = ["bash", file_path, nic_type]
    result = subprocess.run(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        check=False,
    )
    logger.info(f"get_rdma_nics command: {command}")
    logger.info(f"get_rdma_nics output: {result.stdout}")
    if result.returncode != 0:
        raise RuntimeError(f"Failed to execute script `get_rdma_nics.sh`: {result.stderr.strip()}")

    # Split only on the first '=' so a value containing '=' cannot break unpacking.
    env_name, env_value = result.stdout.strip().split("=", 1)
    if env_name != "KVCACHE_RDMA_NICS":
        raise ValueError(f"Unexpected variable name: {env_name}, expected 'KVCACHE_RDMA_NICS'")

    return env_value
|
||||
@@ -1296,6 +1296,9 @@ class CacheConfig:
|
||||
self.max_processor_cache = None
|
||||
self.enable_output_caching = False
|
||||
self.disable_chunked_mm_input = False
|
||||
self.kvcache_storage_backend = None
|
||||
self.write_policy = None
|
||||
|
||||
for key, value in args.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
@@ -1304,11 +1307,6 @@ class CacheConfig:
|
||||
self.rdma_comm_ports = parse_ports(self.rdma_comm_ports)
|
||||
self.pd_comm_port = parse_ports(self.pd_comm_port)
|
||||
|
||||
if self.swap_space is None:
|
||||
self.enable_hierarchical_cache = False
|
||||
else:
|
||||
self.enable_hierarchical_cache = True
|
||||
|
||||
if self.model_cfg is not None:
|
||||
if self.model_cfg.quantization is not None and isinstance(self.model_cfg.quantization, dict):
|
||||
self.cache_dtype = self.model_cfg.quantization.get("kv_cache_quant_type", self.cache_dtype)
|
||||
|
||||
@@ -230,6 +230,14 @@ class EngineArgs:
|
||||
"""
|
||||
Port for cache queue.
|
||||
"""
|
||||
kvcache_storage_backend: str = None
|
||||
"""
|
||||
The storage backend for kvcache storage. If set, it will use the kvcache storage backend.
|
||||
"""
|
||||
write_policy: str = "write_through"
|
||||
"""
|
||||
The policy of write cache to storage.
|
||||
"""
|
||||
|
||||
# System configuration parameters
|
||||
use_warmup: int = 0
|
||||
@@ -557,6 +565,14 @@ class EngineArgs:
|
||||
if "PaddleOCR" in get_model_architecture(self.model, self.model_config_name):
|
||||
envs.FD_ENABLE_MAX_PREFILL = 1
|
||||
|
||||
if self.kvcache_storage_backend is not None:
|
||||
if not self.enable_prefix_caching:
|
||||
raise NotImplementedError("kvcache_storage_backend is only supported when enable_prefix_caching=True")
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER == 0:
|
||||
raise NotImplementedError(
|
||||
"kvcache_storage_backend is only supported when ENABLE_V1_KVCACHE_SCHEDULER=1"
|
||||
)
|
||||
|
||||
self.post_init_all_ports()
|
||||
|
||||
def post_init_all_ports(self):
|
||||
@@ -1018,6 +1034,22 @@ class EngineArgs:
|
||||
help="Static decoding blocks num.",
|
||||
)
|
||||
|
||||
cache_group.add_argument(
|
||||
"--kvcache-storage-backend",
|
||||
type=nullable_str,
|
||||
choices=["mooncake"],
|
||||
default=EngineArgs.kvcache_storage_backend,
|
||||
help="The storage backend for kvcache storage. Leave empty to disable.",
|
||||
)
|
||||
|
||||
cache_group.add_argument(
|
||||
"--write-policy",
|
||||
type=str,
|
||||
choices=["write_through"],
|
||||
default=EngineArgs.write_policy,
|
||||
help="KVCache write policy",
|
||||
)
|
||||
|
||||
# Cluster system parameters group
|
||||
system_group = parser.add_argument_group("System Configuration")
|
||||
system_group.add_argument(
|
||||
|
||||
@@ -524,6 +524,13 @@ class RequestMetrics:
|
||||
|
||||
speculate_metrics: Optional[SpeculateMetrics] = None
|
||||
|
||||
# cache related
|
||||
gpu_cache_token_num: Optional[int] = 0
|
||||
cpu_cache_token_num: Optional[int] = 0
|
||||
storage_cache_token_num: Optional[int] = 0
|
||||
gpu_cpu_cache_prepare_time: Optional[float] = None
|
||||
storage_cache_prepare_time: Optional[float] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.arrival_time is None:
|
||||
self.arrival_time = time.time()
|
||||
|
||||
@@ -670,7 +670,8 @@ class ResourceManagerV1(ResourceManager):
|
||||
request, self.config.cache_config.block_size, request.num_computed_tokens
|
||||
)
|
||||
req_index += 1
|
||||
# schedule the WAITING requests.
|
||||
|
||||
# Second, schedule the WAITING requests.
|
||||
if not preempted_reqs:
|
||||
skip_requests: list[Request] = []
|
||||
while self.waiting and token_budget > 0:
|
||||
@@ -699,10 +700,7 @@ class ResourceManagerV1(ResourceManager):
|
||||
self._update_mm_hashes(request)
|
||||
# Enable prefix caching
|
||||
if self.config.cache_config.enable_prefix_caching:
|
||||
if (
|
||||
self.config.cache_config.enable_hierarchical_cache
|
||||
and self.cache_manager.num_cpu_blocks > 0
|
||||
):
|
||||
if self.cache_manager.num_cpu_blocks > 0:
|
||||
if not self.cache_manager.can_allocate_gpu_blocks(
|
||||
(request.need_prefill_tokens + self.config.cache_config.block_size - 1)
|
||||
// self.config.cache_config.block_size
|
||||
@@ -713,12 +711,21 @@ class ResourceManagerV1(ResourceManager):
|
||||
self._free_blocks(request)
|
||||
break
|
||||
|
||||
# Allocate blocks for the tokens that does not hit cache
|
||||
num_new_tokens = self._get_num_new_tokens(request, token_budget)
|
||||
num_new_block = self.get_new_block_nums(request, num_new_tokens)
|
||||
# Allocate blocks to prefill
|
||||
if self.cache_manager.can_allocate_gpu_blocks(num_new_block):
|
||||
if not request.get("skip_allocate", False):
|
||||
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block))
|
||||
extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(num_new_block)
|
||||
request.block_tables.extend(extra_gpu_block_ids)
|
||||
if (
|
||||
self.config.cache_config.enable_prefix_caching
|
||||
and self.config.cache_config.kvcache_storage_backend
|
||||
and num_new_tokens >= self.config.cache_config.block_size
|
||||
):
|
||||
matched_block_ids = self.get_storage_cached_blocks(request, extra_gpu_block_ids)
|
||||
num_new_tokens -= len(matched_block_ids) * self.config.cache_config.block_size
|
||||
|
||||
self.waiting.popleft()
|
||||
self.running.append(request)
|
||||
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
|
||||
@@ -744,10 +751,7 @@ class ResourceManagerV1(ResourceManager):
|
||||
request.num_total_tokens
|
||||
) # Before preempted task rescheduled, preempted task has been sent to engine, no more tokens are output, here num_total_tokens should be static and correct
|
||||
if self.config.cache_config.enable_prefix_caching:
|
||||
if (
|
||||
self.config.cache_config.enable_hierarchical_cache
|
||||
and self.cache_manager.num_cpu_blocks > 0
|
||||
):
|
||||
if self.cache_manager.num_cpu_blocks > 0:
|
||||
if not self.cache_manager.can_allocate_gpu_blocks(
|
||||
(request.need_prefill_tokens + self.config.cache_config.block_size - 1)
|
||||
// self.config.cache_config.block_size
|
||||
@@ -757,12 +761,22 @@ class ResourceManagerV1(ResourceManager):
|
||||
if not success:
|
||||
self._free_blocks(request)
|
||||
break
|
||||
|
||||
# Allocate blocks for the tokens that does not hit cache
|
||||
num_new_tokens = self._get_num_new_tokens(request, token_budget)
|
||||
num_new_block = self.get_new_block_nums(request, num_new_tokens)
|
||||
# Allocate blocks to prefill
|
||||
if self.cache_manager.can_allocate_gpu_blocks(num_new_block):
|
||||
if not request.get("skip_allocate", False):
|
||||
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(num_new_block))
|
||||
extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(num_new_block)
|
||||
request.block_tables.extend(extra_gpu_block_ids)
|
||||
if (
|
||||
self.config.cache_config.enable_prefix_caching
|
||||
and self.config.cache_config.kvcache_storage_backend
|
||||
and num_new_tokens >= self.config.cache_config.block_size
|
||||
):
|
||||
matched_block_ids = self.get_storage_cached_blocks(request, extra_gpu_block_ids)
|
||||
num_new_tokens -= len(matched_block_ids) * self.config.cache_config.block_size
|
||||
|
||||
self.waiting.popleft()
|
||||
self.running.append(request)
|
||||
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
|
||||
@@ -916,28 +930,61 @@ class ResourceManagerV1(ResourceManager):
|
||||
)
|
||||
|
||||
request.num_cached_tokens = matched_token_num
|
||||
request.gpu_cache_token_num = hit_info["gpu_match_token_num"]
|
||||
request.cpu_cache_token_num = hit_info["cpu_match_token_num"]
|
||||
request.cache_info = (matched_block_num, no_cache_block_num)
|
||||
request.metrics.gpu_cache_token_num = hit_info["gpu_match_token_num"]
|
||||
request.metrics.cpu_cache_token_num = hit_info["cpu_match_token_num"]
|
||||
request.cache_info = [matched_block_num, no_cache_block_num]
|
||||
request.block_tables = common_block_ids
|
||||
request.skip_allocate = False
|
||||
|
||||
# Report the number of cached tokens to Prometheus metrics
|
||||
main_process_metrics.prefix_cache_token_num.inc(matched_token_num)
|
||||
main_process_metrics.prefix_gpu_cache_token_num.inc(request.gpu_cache_token_num)
|
||||
main_process_metrics.prefix_cpu_cache_token_num.inc(request.cpu_cache_token_num)
|
||||
main_process_metrics.prefix_gpu_cache_token_num.inc(request.metrics.gpu_cache_token_num)
|
||||
main_process_metrics.prefix_cpu_cache_token_num.inc(request.metrics.gpu_cache_token_num)
|
||||
|
||||
if matched_token_num == request.need_prefill_tokens:
|
||||
request.num_computed_tokens = matched_token_num - self.config.cache_config.block_size
|
||||
request.skip_allocate = True
|
||||
else:
|
||||
request.num_computed_tokens = matched_token_num
|
||||
request.cache_prepare_time = time.time() - cache_prepare_time
|
||||
request.metrics.gpu_cpu_cache_prepare_time = time.time() - cache_prepare_time
|
||||
return True
|
||||
except Exception as e:
|
||||
llm_logger.error(f"prefix match blocks error: {e}, {str(traceback.format_exc())} waiting reschedule...")
|
||||
return False
|
||||
|
||||
def get_storage_cached_blocks(self, request: Request, extra_gpu_block_ids: list = None):
    """
    Match and prefetch the cached blocks from the storage backend.
    TODO: merge this function into get_prefix_cached_blocks

    Args:
        request: the request being scheduled; its metrics and computed-token
            counters are updated in place.
        extra_gpu_block_ids: newly allocated GPU block ids that the storage
            backend may fill with prefetched data. Defaults to an empty list.

    Returns:
        list: matched block ids from storage (empty on error).
    """
    # BUG FIX: avoid the shared-mutable-default pitfall of `extra_gpu_block_ids=[]`.
    if extra_gpu_block_ids is None:
        extra_gpu_block_ids = []
    # BUG FIX: bind req_id before the try block so the except handler can
    # always reference it (previously it could be unbound there).
    req_id = request.request_id
    try:
        tic = time.time()
        llm_logger.debug(f"get_storage_cached_blocks start process req {req_id}")
        matched_block_ids = self.cache_manager.request_match_storage_blocks(request, extra_gpu_block_ids)
        llm_logger.debug(
            f"matched {len(matched_block_ids)} blocks from storage for req_id:{req_id}, "
            f"cost_time: {time.time() - tic:.6f}s"
        )

        matched_token_num = len(matched_block_ids) * self.config.cache_config.block_size
        request.metrics.storage_cache_token_num = matched_token_num
        request.num_computed_tokens += matched_token_num
        # If everything hit cache, leave one block to be recomputed so that
        # prefill still performs at least one step.
        if request.num_computed_tokens == request.need_prefill_tokens:
            request.num_computed_tokens = request.num_computed_tokens - self.config.cache_config.block_size
        request.metrics.storage_cache_prepare_time = time.time() - tic
        request.cache_info[0] += len(matched_block_ids)  # matched_block_num
        request.cache_info[1] -= len(matched_block_ids)  # no_cache_block_num

        main_process_metrics.prefix_cache_token_num.inc(matched_token_num)
        # TODO: main_process_metrics.prefix_storage_cache_token_num.inc(matched_token_num)
        return matched_block_ids
    except Exception as e:
        llm_logger.error(
            f"get_storage_cached_blocks process req {req_id}, error: {e}, {str(traceback.format_exc())} "
        )
        return []
|
||||
|
||||
def add_request(self, request: Request) -> None:
|
||||
with self.lock:
|
||||
self.apply_async_preprocess(request)
|
||||
@@ -980,7 +1027,7 @@ class ResourceManagerV1(ResourceManager):
|
||||
) // self.config.cache_config.block_size + self.config.cache_config.enc_dec_block_num # consider for mtp, plus enc_dec_block_num
|
||||
if self.config.cache_config.enable_prefix_caching:
|
||||
# Enable prefix caching
|
||||
if self.config.cache_config.enable_hierarchical_cache and self.cache_manager.num_cpu_blocks > 0:
|
||||
if self.cache_manager.num_cpu_blocks > 0:
|
||||
if not self.cache_manager.can_allocate_gpu_blocks(
|
||||
need_prealloc_prefill_blocks
|
||||
): # to prevent block allocation for matching in hierarchical cache and cause dead lock
|
||||
@@ -992,7 +1039,10 @@ class ResourceManagerV1(ResourceManager):
|
||||
|
||||
need_extra_prefill_blocks = need_prealloc_prefill_blocks - request.cache_info[0]
|
||||
if self.cache_manager.can_allocate_gpu_blocks(need_extra_prefill_blocks):
|
||||
request.block_tables.extend(self.cache_manager.allocate_gpu_blocks(need_extra_prefill_blocks))
|
||||
extra_gpu_block_ids = self.cache_manager.allocate_gpu_blocks(need_extra_prefill_blocks)
|
||||
if self.config.cache_config.enable_prefix_caching:
|
||||
self.get_storage_cached_blocks(request, extra_gpu_block_ids)
|
||||
request.block_tables.extend(extra_gpu_block_ids)
|
||||
allocated_position = self.get_available_position()
|
||||
request.idx = allocated_position
|
||||
self.tasks_list[request.idx] = request
|
||||
@@ -1113,39 +1163,45 @@ class ResourceManagerV1(ResourceManager):
|
||||
def finish_requests(self, request_ids: Union[str, Iterable[str]]):
|
||||
llm_logger.info(f"recycle resources for requests: {request_ids}")
|
||||
try:
|
||||
if isinstance(request_ids, str):
|
||||
request_ids = (request_ids,)
|
||||
else:
|
||||
request_ids = set(request_ids)
|
||||
|
||||
need_postprocess_reqs = []
|
||||
with self.lock:
|
||||
if isinstance(request_ids, str):
|
||||
request_ids = (request_ids,)
|
||||
else:
|
||||
request_ids = set(request_ids)
|
||||
for req_id in request_ids:
|
||||
request = self.requests.get(req_id)
|
||||
if request is None:
|
||||
# Invalid request ID.
|
||||
continue
|
||||
if request in self.running: # normally run and finished
|
||||
if request in self.waiting:
|
||||
llm_logger.error(f"request {request.request_id} scheduled into waiting list, after finished")
|
||||
continue
|
||||
if request in self.running:
|
||||
self.running.remove(request)
|
||||
request.status = RequestStatus.FINISHED
|
||||
try:
|
||||
self._free_blocks(request)
|
||||
except Exception as e:
|
||||
llm_logger.warning(f"release block failed {req_id}: {e}")
|
||||
if (
|
||||
request.request_id in self.to_be_rescheduled_request_id_set
|
||||
): # finished after preempted, blocks have been recycled.
|
||||
self.to_be_rescheduled_request_id_set.remove(
|
||||
request.request_id
|
||||
) # just remove from to_be_rescheduled_request_id_set
|
||||
if (
|
||||
request in self.waiting
|
||||
): # after finished, this request still scheduled from preempted to waiting, unexpected error, should not be here
|
||||
raise RuntimeError(f"request {request.request_id} scheduled into waiting list, after finished")
|
||||
need_postprocess_reqs.append(request)
|
||||
if request.request_id in self.to_be_rescheduled_request_id_set:
|
||||
# finished after preempted, blocks have been recycled.
|
||||
self.to_be_rescheduled_request_id_set.remove(request.request_id)
|
||||
|
||||
self.tasks_list[request.idx] = None
|
||||
self.stop_flags[request.idx] = True
|
||||
del self.requests[req_id]
|
||||
if req_id in self.req_dict:
|
||||
del self.req_dict[req_id]
|
||||
|
||||
# Do not block the main thread here
|
||||
for req in need_postprocess_reqs:
|
||||
self.cache_manager.write_cache_to_storage(req)
|
||||
|
||||
with self.lock:
|
||||
for req in need_postprocess_reqs:
|
||||
try:
|
||||
self._free_blocks(req)
|
||||
except Exception as e:
|
||||
llm_logger.warning(f"release block failed {req.request_id}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
llm_logger.error(f"finish_request err: {e}, {str(traceback.format_exc())}")
|
||||
finally:
|
||||
|
||||
@@ -102,6 +102,12 @@ class EngineCacheQueue:
|
||||
self.swap_to_gpu_barrier2_init = [
|
||||
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
|
||||
]
|
||||
self.swap_storage_to_gpu_barrier_init = [
|
||||
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
|
||||
]
|
||||
self.swap_to_storage_barrier_init = [
|
||||
threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size)
|
||||
]
|
||||
|
||||
# Register shared objects with proxy types
|
||||
QueueManager.register(
|
||||
@@ -148,7 +154,14 @@ class EngineCacheQueue:
|
||||
"get_swap_to_gpu_barrier2",
|
||||
callable=lambda idx: self.swap_to_gpu_barrier2_init[idx],
|
||||
)
|
||||
|
||||
QueueManager.register(
|
||||
"get_swap_storage_to_gpu_barrier",
|
||||
callable=lambda idx: self.swap_storage_to_gpu_barrier_init[idx],
|
||||
)
|
||||
QueueManager.register(
|
||||
"get_swap_to_storage_barrier",
|
||||
callable=lambda idx: self.swap_to_storage_barrier_init[idx],
|
||||
)
|
||||
self.manager: BaseManager = QueueManager(address=self.address, authkey=self.authkey)
|
||||
self.manager.start()
|
||||
|
||||
@@ -175,6 +188,8 @@ class EngineCacheQueue:
|
||||
QueueManager.register("get_swap_to_cpu_barrier2")
|
||||
QueueManager.register("get_swap_to_gpu_barrier1")
|
||||
QueueManager.register("get_swap_to_gpu_barrier2")
|
||||
QueueManager.register("get_swap_storage_to_gpu_barrier")
|
||||
QueueManager.register("get_swap_to_storage_barrier")
|
||||
|
||||
self.manager = QueueManager(address=self.address, authkey=self.authkey)
|
||||
self._connect_with_retry()
|
||||
@@ -194,6 +209,8 @@ class EngineCacheQueue:
|
||||
self.swap_to_cpu_barrier2 = self.manager.get_swap_to_cpu_barrier2(self.local_data_parallel_id)
|
||||
self.swap_to_gpu_barrier1 = self.manager.get_swap_to_gpu_barrier1(self.local_data_parallel_id)
|
||||
self.swap_to_gpu_barrier2 = self.manager.get_swap_to_gpu_barrier2(self.local_data_parallel_id)
|
||||
self.swap_storage_to_gpu_barrier = self.manager.get_swap_storage_to_gpu_barrier(self.local_data_parallel_id)
|
||||
self.swap_to_storage_barrier = self.manager.get_swap_to_storage_barrier(self.local_data_parallel_id)
|
||||
self.total_num: int = (1 << self.num_client) - 1
|
||||
|
||||
if not is_server:
|
||||
@@ -241,7 +258,7 @@ class EngineCacheQueue:
|
||||
self.task_lock.acquire()
|
||||
self.task_sync_value.set(0)
|
||||
self.transfer_task_queue.append(item)
|
||||
logger.info(f"put_transfer_task: put swap task {item[-1]} to queue successful")
|
||||
logger.info(f"put_transfer_task: put swap task {item} to queue successful")
|
||||
self.task_lock.release()
|
||||
|
||||
def get_transfer_task(self):
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import codecs
|
||||
import hashlib
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
@@ -824,6 +825,18 @@ def retrive_model_from_server(model_name_or_path, revision="master"):
|
||||
return model_name_or_path
|
||||
|
||||
|
||||
def get_hash_str(token_ids: List[int], extra_keys: Optional[Any] = []) -> str:
|
||||
"""
|
||||
calculate hash value of a block with additional keys
|
||||
|
||||
Args:
|
||||
token_ids: Input token IDs
|
||||
extra_keys: Additional keys for block identification
|
||||
"""
|
||||
value = (token_ids, extra_keys)
|
||||
return hashlib.sha256(pickle.dumps(value)).hexdigest()
|
||||
|
||||
|
||||
def is_list_of(
|
||||
value: object,
|
||||
typ: Union[type[T], tuple[type[T], ...]],
|
||||
|
||||
@@ -41,6 +41,8 @@ class Args:
|
||||
create_cache_tensor = False
|
||||
cache_dtype = "bfloat16"
|
||||
default_dtype = "bfloat16"
|
||||
kvcache_storage_backend = None
|
||||
write_policy = "write_through"
|
||||
|
||||
|
||||
# ==========================
|
||||
|
||||
@@ -27,6 +27,7 @@ import pytest
|
||||
from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
|
||||
from fastdeploy.cache_manager.prefix_cache_manager import PrefixCacheManager
|
||||
from fastdeploy.inter_communicator.ipc_signal_const import PrefixTreeStatus
|
||||
from fastdeploy.utils import get_hash_str
|
||||
|
||||
|
||||
# Metric test double used to track metric updates.
|
||||
@@ -179,6 +180,8 @@ def _create_manager(
|
||||
rdma_comm_ports=None,
|
||||
local_cache_queue_port=9000,
|
||||
local_rdma_comm_ports=None,
|
||||
kvcache_storage_backend=None,
|
||||
write_policy="write_through",
|
||||
)
|
||||
model_config = SimpleNamespace(
|
||||
num_attention_heads=1,
|
||||
@@ -186,6 +189,7 @@ def _create_manager(
|
||||
head_dim=1,
|
||||
_architecture="",
|
||||
dtype="float16",
|
||||
max_model_len=128,
|
||||
)
|
||||
config = SimpleNamespace(
|
||||
cache_config=cache_config,
|
||||
@@ -199,7 +203,7 @@ def _create_manager(
|
||||
|
||||
def _make_block_node(manager, node_id, input_ids, *, block_size=2, parent=None, cache_status=CacheStatus.GPU):
|
||||
parent = parent or manager.radix_tree_root
|
||||
block_hash = manager.cal_block_hash(input_ids)
|
||||
block_hash = get_hash_str(input_ids)
|
||||
node = BlockNode(
|
||||
node_id,
|
||||
input_ids,
|
||||
@@ -878,10 +882,10 @@ class PrefixCacheManagerTest(unittest.TestCase):
|
||||
manager = _create_manager(num_gpu_blocks=4)
|
||||
block_size = 2
|
||||
root = manager.radix_tree_root
|
||||
gpu_hash = manager.cal_block_hash([1, 2])
|
||||
gpu_hash = get_hash_str([1, 2])
|
||||
gpu_node = BlockNode(1, [], 0, 1, 0, block_size, gpu_hash, 0, parent=root)
|
||||
root.children[gpu_hash] = gpu_node
|
||||
cpu_hash = manager.cal_block_hash([3, 4])
|
||||
cpu_hash = get_hash_str([3, 4], extra_keys=[gpu_hash])
|
||||
cpu_node = BlockNode(2, [], 0, 2, 1, block_size, cpu_hash, 0, parent=gpu_node, cache_status=CacheStatus.CPU)
|
||||
gpu_node.children[cpu_hash] = cpu_node
|
||||
manager.gpu_lru_leaf_set.add(gpu_node)
|
||||
@@ -917,7 +921,7 @@ class PrefixCacheManagerTest(unittest.TestCase):
|
||||
|
||||
def test_free_block_ids_async_recycles_gpu_nodes(self):
|
||||
manager = _create_manager(num_gpu_blocks=4)
|
||||
node_hash = manager.cal_block_hash([1, 2])
|
||||
node_hash = get_hash_str([1, 2])
|
||||
node = BlockNode(10, [1, 2], node_hash, 1, 0, 2, node_hash, 0, parent=manager.radix_tree_root)
|
||||
node.shared_count = 0
|
||||
manager.radix_tree_root.children[node_hash] = node
|
||||
@@ -941,7 +945,7 @@ class PrefixCacheManagerTest(unittest.TestCase):
|
||||
|
||||
manager.issue_swap_task = _fake_issue
|
||||
|
||||
node_hash = manager.cal_block_hash([3, 4])
|
||||
node_hash = get_hash_str([3, 4])
|
||||
node = BlockNode(11, [3, 4], node_hash, 1, 1, 2, node_hash, 0, parent=manager.radix_tree_root)
|
||||
node.shared_count = 0
|
||||
manager.radix_tree_root.children[node_hash] = node
|
||||
@@ -957,9 +961,9 @@ class PrefixCacheManagerTest(unittest.TestCase):
|
||||
block_size = 2
|
||||
manager.cache_config.disable_chunked_mm_input = False
|
||||
input_ids = [1, 2, 3, 4]
|
||||
hash_input = manager.hash_block_features(input_ids)
|
||||
hash_first = manager.hash_block_features([1, 2])
|
||||
hash_second = manager.hash_block_features([3, 4], ["img"])
|
||||
hash_input = get_hash_str(input_ids)
|
||||
hash_first = get_hash_str([1, 2])
|
||||
hash_second = get_hash_str([3, 4], [hash_first, "img"])
|
||||
|
||||
node1 = BlockNode(30, input_ids, hash_input, 1, 0, block_size, hash_first, 0, parent=manager.radix_tree_root)
|
||||
manager.radix_tree_root.children[hash_first] = node1
|
||||
@@ -1004,9 +1008,9 @@ class PrefixCacheManagerTest(unittest.TestCase):
|
||||
manager.cache_config.disable_chunked_mm_input = False
|
||||
block_size = 2
|
||||
input_ids = [1, 2, 3, 4]
|
||||
hash_input = manager.hash_block_features(input_ids)
|
||||
hash_first = manager.hash_block_features([1, 2])
|
||||
hash_second = manager.hash_block_features([3, 4], ["img"])
|
||||
hash_input = get_hash_str(input_ids)
|
||||
hash_first = get_hash_str([1, 2])
|
||||
hash_second = get_hash_str([3, 4], [hash_first, "img"])
|
||||
node1 = BlockNode(40, input_ids, hash_input, 1, 0, block_size, hash_first, 0, parent=manager.radix_tree_root)
|
||||
node2 = BlockNode(
|
||||
41,
|
||||
@@ -1045,7 +1049,7 @@ class PrefixCacheManagerTest(unittest.TestCase):
|
||||
|
||||
def test_release_block_ids_cleans_request_state(self):
|
||||
manager = _create_manager(num_gpu_blocks=4)
|
||||
node = BlockNode(50, [1, 2], 0, 1, 0, 2, manager.cal_block_hash([1, 2]), 0, parent=manager.radix_tree_root)
|
||||
node = BlockNode(50, [1, 2], 0, 1, 0, 2, get_hash_str([1, 2]), 0, parent=manager.radix_tree_root)
|
||||
node.cache_status = CacheStatus.GPU
|
||||
manager.radix_tree_root.children[node.hash_value] = node
|
||||
req_id = "release-req"
|
||||
@@ -1061,7 +1065,7 @@ class PrefixCacheManagerTest(unittest.TestCase):
|
||||
|
||||
def test_free_cpu_block_ids_eviction(self):
|
||||
manager = _create_manager(num_gpu_blocks=2, num_cpu_blocks=2)
|
||||
cpu_node = BlockNode(60, [3, 4], 0, 1, 0, 2, manager.cal_block_hash([3, 4]), 0, parent=manager.radix_tree_root)
|
||||
cpu_node = BlockNode(60, [3, 4], 0, 1, 0, 2, get_hash_str([3, 4]), 0, parent=manager.radix_tree_root)
|
||||
cpu_node.cache_status = CacheStatus.CPU
|
||||
manager.cpu_lru_leaf_heap.append(cpu_node)
|
||||
manager.cpu_lru_leaf_set.add(cpu_node)
|
||||
@@ -1070,8 +1074,8 @@ class PrefixCacheManagerTest(unittest.TestCase):
|
||||
|
||||
def test_free_nodes_directly_recovers_chain(self):
|
||||
manager = _create_manager(num_gpu_blocks=4)
|
||||
parent = BlockNode(70, [1, 2], 0, 1, 0, 2, manager.cal_block_hash([1, 2]), 0, parent=manager.radix_tree_root)
|
||||
child_hash = manager.cal_block_hash([3, 4])
|
||||
parent = BlockNode(70, [1, 2], 0, 1, 0, 2, get_hash_str([1, 2]), 0, parent=manager.radix_tree_root)
|
||||
child_hash = get_hash_str([3, 4])
|
||||
child = BlockNode(71, [1, 2, 3, 4], 0, 2, 1, 2, child_hash, 0, parent=parent)
|
||||
parent.children[child_hash] = child
|
||||
parent.shared_count = 0
|
||||
@@ -1102,9 +1106,9 @@ class PrefixCacheManagerTest(unittest.TestCase):
|
||||
manager.cache_config.disable_chunked_mm_input = True
|
||||
block_size = 2
|
||||
input_ids = [1, 2, 3, 4]
|
||||
hash_input = manager.hash_block_features(input_ids)
|
||||
hash_first = manager.hash_block_features([1, 2])
|
||||
hash_second = manager.hash_block_features([3, 4], ["img"])
|
||||
hash_input = get_hash_str(input_ids)
|
||||
hash_first = get_hash_str([1, 2])
|
||||
hash_second = get_hash_str([3, 4], ["img"])
|
||||
node1 = BlockNode(80, input_ids, hash_input, 1, 0, block_size, hash_first, 0, parent=manager.radix_tree_root)
|
||||
node2 = BlockNode(81, input_ids, hash_input, 2, 1, block_size, hash_second, 0, parent=node1)
|
||||
manager.radix_tree_root.children[hash_first] = node1
|
||||
@@ -1144,7 +1148,7 @@ class PrefixCacheManagerTest(unittest.TestCase):
|
||||
|
||||
def test_handle_swap_result_updates_status(self):
|
||||
manager = _create_manager(num_gpu_blocks=4, num_cpu_blocks=2)
|
||||
node = BlockNode(90, [1], 0, 1, 0, 1, manager.cal_block_hash([1]), 0, parent=manager.radix_tree_root)
|
||||
node = BlockNode(90, [1], 0, 1, 0, 1, get_hash_str([1]), 0, parent=manager.radix_tree_root)
|
||||
node.cache_status = CacheStatus.SWAP2CPU
|
||||
manager.node_map[node.node_id] = node
|
||||
manager._handle_swap_result(node.node_id, 2, 3, CacheStatus.SWAP2CPU)
|
||||
@@ -1156,7 +1160,7 @@ class PrefixCacheManagerTest(unittest.TestCase):
|
||||
|
||||
def test_reset_clears_internal_state(self):
|
||||
manager = _create_manager(num_gpu_blocks=2, num_cpu_blocks=1)
|
||||
node = BlockNode(100, [1], 0, 1, 0, 1, manager.cal_block_hash([1]), 0, parent=manager.radix_tree_root)
|
||||
node = BlockNode(100, [1], 0, 1, 0, 1, get_hash_str([1]), 0, parent=manager.radix_tree_root)
|
||||
manager.node_map[node.node_id] = node
|
||||
manager.task_swapping_event["evt"] = threading.Event()
|
||||
manager.task_swapping_event["evt"].set()
|
||||
@@ -1166,9 +1170,9 @@ class PrefixCacheManagerTest(unittest.TestCase):
|
||||
|
||||
def test_recv_data_transfer_result_processes_queue(self):
|
||||
manager = _create_manager(num_gpu_blocks=4, num_cpu_blocks=1)
|
||||
node = BlockNode(110, [1], 0, 1, 0, 1, manager.cal_block_hash([1]), 0, parent=manager.radix_tree_root)
|
||||
node = BlockNode(110, [1], 0, 1, 0, 1, get_hash_str([1]), 0, parent=manager.radix_tree_root)
|
||||
manager.node_map[node.node_id] = node
|
||||
payload = [([node.node_id], [2], [3], CacheStatus.SWAP2GPU, "task")]
|
||||
payload = [(CacheStatus.SWAP2GPU, "task", [node.node_id], [2], [3])]
|
||||
manager.cache_task_queue = _FakeTransferQueue(payload, include_none=True)
|
||||
manager.task_swapping_event["task"] = threading.Event()
|
||||
with self.assertRaises(SystemExit):
|
||||
@@ -1196,7 +1200,7 @@ class PrefixCacheManagerTest(unittest.TestCase):
|
||||
request_id="revert",
|
||||
multimodal_inputs={"mm_positions": [SimpleNamespace(offset=2, length=2)]},
|
||||
)
|
||||
node = BlockNode(120, [1, 2], 0, 1, 0, 2, manager.cal_block_hash([1, 2]), 0, parent=manager.radix_tree_root)
|
||||
node = BlockNode(120, [1, 2], 0, 1, 0, 2, get_hash_str([1, 2]), 0, parent=manager.radix_tree_root)
|
||||
matche_nodes = [node]
|
||||
match_gpu = [0]
|
||||
match_node_ids = [node.node_id]
|
||||
|
||||
Reference in New Issue
Block a user