[RL] add version to the key of cache storage && refine raising error (#6160)

* Wait for the cache transfer manager to be initialized

* up

* up

* up

* up

* up

* fix according to comments

* fix unittest

* fix

* fix unittest

* fix error

* pass storage_backend to worker
This commit is contained in:
jc
2026-01-27 10:47:46 +08:00
committed by GitHub
parent 7ffa88bb01
commit b1698a79cb
8 changed files with 121 additions and 107 deletions
@@ -18,6 +18,7 @@ import argparse
import concurrent.futures
import gc
import json
import os
import queue
import threading
import time
@@ -26,6 +27,7 @@ from typing import List
import numpy as np
import paddle
import yaml
from fastdeploy import envs
from fastdeploy.cache_manager.cache_data import CacheStatus
@@ -45,7 +47,7 @@ from fastdeploy.cache_manager.transfer_factory import AttentionStore, MooncakeSt
from fastdeploy.config import SpeculativeConfig
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
from fastdeploy.platforms import current_platform
from fastdeploy.utils import get_logger
from fastdeploy.utils import console_logger, get_logger
def parse_args():
@@ -59,7 +61,6 @@ def parse_args():
default="mixed",
help="splitwise role, can be decode, prefill or mixed",
)
parser.add_argument("--model_id", type=str, default="default", help="model id")
parser.add_argument("--rank", type=int, default=0, help="local tp rank")
parser.add_argument("--device_id", type=int, default=0, help="device id")
parser.add_argument("--max_model_len", type=int, default=32768, help="max model length")
@@ -111,7 +112,7 @@ def parse_args():
"--kvcache_storage_backend",
type=str,
default=None,
choices=["mooncake", "attention_store", "none"],
choices=["mooncake", "attention_store"],
help="The storage backend for kvcache storage. If not set, storage backend is disabled.",
)
parser.add_argument(
@@ -121,11 +122,22 @@ def parse_args():
default="write_through",
help="KVCache write policy",
)
parser.add_argument("--model_path", type=str, help="The path of model")
args = parser.parse_args()
return args
def get_key_prefix_from_version(version_file_path):
    """Derive the cache-key prefix from a model's ``version.yaml`` file.

    The version string is expected in the form
    ``RL-STEP{xx}-{timestamp}-{uuid4}``; the prefix is the first two
    dash-separated segments (e.g. ``RL-STEP{xx}``).

    Args:
        version_file_path: Path to the YAML file containing a ``version`` key.

    Returns:
        The key prefix string used to namespace cache-storage entries.
    """
    with open(version_file_path, "r", encoding="utf-8") as fin:
        version_info = yaml.safe_load(fin)
    # Keep only the first two segments; maxsplit=2 avoids splitting the
    # trailing timestamp/uuid portion unnecessarily.
    first_two_segments = version_info["version"].split("-", 2)[:2]
    return "-".join(first_two_segments)
class CacheTransferManager:
"""
管理CPU和GPU之间缓存的交换传输
@@ -160,7 +172,7 @@ class CacheTransferManager:
self.cache_bytes = self._get_cache_bytes(self.cache_dtype)
# extract other arg values
self.model_id = args.model_id
self.model_id = os.path.basename(args.model_path.rstrip("/"))
self.n_ranks = args.mp_num
self.rank = args.rank
self.device = args.device_id
@@ -210,6 +222,7 @@ class CacheTransferManager:
self._init_gpu_cache(args)
if self.num_cpu_blocks > 0:
self._init_cpu_cache(args)
self._init_storage(args)
cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
self.cache_task_broadcast_signal = IPCSignal(
@@ -231,34 +244,6 @@ class CacheTransferManager:
create=False,
)
if args.kvcache_storage_backend is None or args.kvcache_storage_backend == "none":
self.storage_backend = None
elif args.kvcache_storage_backend == "mooncake":
logger.info("Start initialize mooncake store...")
self.storage_backend = MooncakeStore(tp_rank=self.rank)
self._init_storage_buffer(args)
logger.info("Initialized mooncake store successfully")
elif args.kvcache_storage_backend == "attention_store":
logger.info("Start initialize attention store...")
self.storage_backend = AttentionStore(
namespace=self.model_id,
shard_id=self.rank,
shard_num=self.n_ranks,
layer_num=self.num_layers + self.num_extra_layers,
block_token_size=self.block_size,
bytes_per_shard_layer_per_block=self.head_num * self.block_size * self.head_dim * self.cache_bytes,
device_id=self.device,
dp_id=self.local_data_parallel_id,
)
logger.info("Initialized attention store successfully!")
else:
raise NotImplementedError(f"Unsupported storage backend: {args.kvcache_storage_backend}")
self.storage_backend_type = args.kvcache_storage_backend
if args.write_policy not in ["write_through"]:
raise ValueError(f"Invalid write policy: {args.write_policy}")
self.write_policy = args.write_policy
# Initialize update/clear signals for RL
self.kv_cache_status_signal = IPCSignal(
name="kv_cache_status",
@@ -269,6 +254,61 @@ class CacheTransferManager:
)
threading.Thread(target=self.check_cache_status, args=[args], daemon=True).start()
cache_transfer_inited_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
self.cache_transfer_inited_signal = IPCSignal(
name="cache_transfer_inited_signal",
array=cache_transfer_inited_signal_data,
dtype=np.int32,
suffix=args.engine_worker_queue_port,
create=False,
)
self.cache_transfer_inited_signal.value[self.rank] = 1
def _init_storage(self, args):
    """Initialize the KV-cache storage backend selected by ``args``.

    Sets ``self.storage_backend`` / ``self.storage_backend_type``,
    validates the write policy, and computes ``self.key_prefix`` from the
    model's ``version.yaml`` (empty string when the file is absent).

    Raises:
        NotImplementedError: for an unrecognized backend name.
        ValueError: for an unsupported write policy.
        Exception: any backend construction failure is logged (to both the
            file logger and the console) and re-raised unchanged.
    """
    backend_type = args.kvcache_storage_backend
    self.storage_backend_type = backend_type
    try:
        if backend_type is None:
            # Storage is disabled entirely.
            self.storage_backend = None
        elif backend_type == "mooncake":
            logger.info("Start initialize mooncake store...")
            self.storage_backend = MooncakeStore(tp_rank=self.rank)
            # Mooncake needs a pinned host buffer for transfers.
            self._init_storage_buffer(args)
            logger.info("Initialized mooncake store successfully")
        elif backend_type == "attention_store":
            logger.info("Start initialize attention store...")
            # TODO: support different model versions in RL.
            self.storage_backend = AttentionStore(
                namespace=self.model_id,
                shard_id=self.rank,
                shard_num=self.n_ranks,
                layer_num=self.num_layers + self.num_extra_layers,
                block_token_size=self.block_size,
                bytes_per_shard_layer_per_block=self.head_num * self.block_size * self.head_dim * self.cache_bytes,
                device_id=self.device,
                dp_id=self.local_data_parallel_id,
            )
            logger.info("Initialized attention store successfully!")
        else:
            raise NotImplementedError(f"Unsupported storage backend: {self.storage_backend_type}")
    except Exception as e:
        err_msg = f"Fail to initialize storage backend, {e}, traceback: {traceback.format_exc()}"
        logger.error(err_msg)
        console_logger.error(err_msg)  # surface the failure on the console too
        raise

    if args.write_policy not in ["write_through"]:
        raise ValueError(f"Invalid write policy: {args.write_policy}")
    self.write_policy = args.write_policy

    # key_prefix distinguishes cached blocks across RL weight versions;
    # it stays "" when the model ships no version.yaml.
    self.key_prefix = ""
    version_file_path = os.path.join(args.model_path, "version.yaml")
    if os.path.exists(version_file_path):
        self.key_prefix = get_key_prefix_from_version(version_file_path)
        logger.info(f"The key_prefix of cache storage is {self.key_prefix}")
    logger.info("Initialize cache storage successfully")
def _init_storage_buffer(self, args):
"""
Initialize pinned memory buffer that can hold the cache for a longest request
@@ -287,7 +327,7 @@ class CacheTransferManager:
)
total_bytes = block_num * self.storage_buffer_stride_bytes * 2 # key and value
logger.info(f"Creating cpu buffer cache for alllayers: {total_bytes / 1024 ** 3:.2f}GB")
logger.info(f"Creating cpu buffer cache for all layers: {total_bytes / 1024 ** 3:.2f}GB")
read_buffer = cuda_host_alloc(total_bytes)
self.storage_key_read_buffer = read_buffer
self.storage_value_read_buffer = read_buffer + total_bytes // 2
@@ -557,8 +597,8 @@ class CacheTransferManager:
try:
gpu_block_ids = task.gpu_block_ids.copy()
cpu_block_ids = [i for i in range(len(gpu_block_ids))]
k_cache_keys = [f"{key}_key_{self.rank}" for key in task.keys]
v_cache_keys = [f"{key}_value_{self.rank}" for key in task.keys]
k_cache_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_key" for key in task.keys]
v_cache_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_value" for key in task.keys]
match_block_num = 0
if self.storage_backend_type == "mooncake":
match_block_num = self.storage_backend.query(k_cache_keys, v_cache_keys)
@@ -694,8 +734,8 @@ class CacheTransferManager:
try:
gpu_block_ids = task.gpu_block_ids.copy()
cpu_block_ids = [i for i in range(len(gpu_block_ids))]
k_cache_keys = [f"{key}_key_{self.rank}" for key in task.keys]
v_cache_keys = [f"{key}_value_{self.rank}" for key in task.keys]
k_cache_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_key" for key in task.keys]
v_cache_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_value" for key in task.keys]
match_block_num = 0
if self.storage_backend_type == "mooncake":
@@ -1112,6 +1152,13 @@ class CacheTransferManager:
self._init_gpu_cache(args)
logger.debug("[RL] successfully restored gpu caches")
if self.storage_backend_type is not None:
# use key_prefix to distinguish cache for different version of weight in rl
version_file_path = os.path.join(args.model_path, "version.yaml")
assert os.path.exists(version_file_path), f"version.yaml not found at {version_file_path}"
self.key_prefix = get_key_prefix_from_version(version_file_path)
logger.info(f"Update key_prefix of cache storage to {self.key_prefix}")
# wait for all ranks caches to be ready
while np.sum(self.cache_ready_signal.value) != args.mp_num:
time.sleep(0.1)