Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2026-04-23 00:17:25 +08:00
[RL] add version to the key of cache storage && refine raising error (#6160)
* Waiting for cache transfer manager inited
* up
* up
* up
* up
* up
* fix according comments
* fix unittest
* fix
* fix unittest
* fix error
* pass storage_backend to worker
@@ -18,6 +18,7 @@ import argparse
 import concurrent.futures
 import gc
 import json
+import os
 import queue
 import threading
 import time
@@ -26,6 +27,7 @@ from typing import List

 import numpy as np
 import paddle
+import yaml

 from fastdeploy import envs
 from fastdeploy.cache_manager.cache_data import CacheStatus
@@ -45,7 +47,7 @@ from fastdeploy.cache_manager.transfer_factory import AttentionStore, MooncakeSt
 from fastdeploy.config import SpeculativeConfig
 from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
 from fastdeploy.platforms import current_platform
-from fastdeploy.utils import get_logger
+from fastdeploy.utils import console_logger, get_logger


 def parse_args():
@@ -59,7 +61,6 @@ def parse_args():
         default="mixed",
         help="splitwise role, can be decode, prefill or mixed",
     )
-    parser.add_argument("--model_id", type=str, default="default", help="model id")
     parser.add_argument("--rank", type=int, default=0, help="local tp rank")
     parser.add_argument("--device_id", type=int, default=0, help="device id")
     parser.add_argument("--max_model_len", type=int, default=32768, help="max model length")
@@ -111,7 +112,7 @@ def parse_args():
         "--kvcache_storage_backend",
         type=str,
         default=None,
-        choices=["mooncake", "attention_store", "none"],
+        choices=["mooncake", "attention_store"],
         help="The storage backend for kvcache storage. If not set, storage backend is disabled.",
     )
     parser.add_argument(
@@ -121,11 +122,22 @@ def parse_args():
         default="write_through",
         help="KVCache write policy",
     )
+    parser.add_argument("--model_path", type=str, help="The path of model")

     args = parser.parse_args()
     return args


+def get_key_prefix_from_version(version_file_path):
+    # the format of version string is RL-STEP{xx}-{timestamp}-{uuid4}
+    with open(version_file_path, "r", encoding="utf-8") as f:
+        data = yaml.safe_load(f)
+    version = data["version"]
+    parts = version.split("-", 2)
+    key_prefix = "-".join(parts[:2])
+    return key_prefix
+
+
 class CacheTransferManager:
     """
     Manages swap transfers of the cache between CPU and GPU
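A quick worked example of the new helper above. The version string is hypothetical, following the RL-STEP{xx}-{timestamp}-{uuid4} format noted in the code comment:

    # version.yaml (hypothetical contents): version: RL-STEP12-20250101-1f2e3d4c
    version = "RL-STEP12-20250101-1f2e3d4c"
    parts = version.split("-", 2)        # ["RL", "STEP12", "20250101-1f2e3d4c"]
    key_prefix = "-".join(parts[:2])     # "RL-STEP12"

So only the RL step survives into the prefix; the timestamp and uuid are dropped, meaning all weights produced at the same step share one cache namespace.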
@@ -160,7 +172,7 @@ class CacheTransferManager:
         self.cache_bytes = self._get_cache_bytes(self.cache_dtype)

         # extract other arg values
-        self.model_id = args.model_id
+        self.model_id = os.path.basename(args.model_path.rstrip("/"))
         self.n_ranks = args.mp_num
         self.rank = args.rank
         self.device = args.device_id
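The model id is now derived from the model path instead of a separate --model_id flag; for illustration, with a hypothetical path:

    import os
    os.path.basename("/models/Qwen2-7B/".rstrip("/"))   # -> "Qwen2-7B"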
@@ -210,6 +222,7 @@ class CacheTransferManager:
             self._init_gpu_cache(args)
         if self.num_cpu_blocks > 0:
             self._init_cpu_cache(args)
+        self._init_storage(args)

         cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
         self.cache_task_broadcast_signal = IPCSignal(
@@ -231,34 +244,6 @@ class CacheTransferManager:
             create=False,
         )

-        if args.kvcache_storage_backend is None or args.kvcache_storage_backend == "none":
-            self.storage_backend = None
-        elif args.kvcache_storage_backend == "mooncake":
-            logger.info("Start initialize mooncake store...")
-            self.storage_backend = MooncakeStore(tp_rank=self.rank)
-            self._init_storage_buffer(args)
-            logger.info("Initialized mooncake store successfully")
-        elif args.kvcache_storage_backend == "attention_store":
-            logger.info("Start initialize attention store...")
-            self.storage_backend = AttentionStore(
-                namespace=self.model_id,
-                shard_id=self.rank,
-                shard_num=self.n_ranks,
-                layer_num=self.num_layers + self.num_extra_layers,
-                block_token_size=self.block_size,
-                bytes_per_shard_layer_per_block=self.head_num * self.block_size * self.head_dim * self.cache_bytes,
-                device_id=self.device,
-                dp_id=self.local_data_parallel_id,
-            )
-            logger.info("Initialized attention store successfully!")
-        else:
-            raise NotImplementedError(f"Unsupported storage backend: {args.kvcache_storage_backend}")
-        self.storage_backend_type = args.kvcache_storage_backend
-
-        if args.write_policy not in ["write_through"]:
-            raise ValueError(f"Invalid write policy: {args.write_policy}")
-        self.write_policy = args.write_policy
-
         # Initialize update/clear signals for RL
         self.kv_cache_status_signal = IPCSignal(
             name="kv_cache_status",
@@ -269,6 +254,61 @@ class CacheTransferManager:
         )
         threading.Thread(target=self.check_cache_status, args=[args], daemon=True).start()

+        cache_transfer_inited_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
+        self.cache_transfer_inited_signal = IPCSignal(
+            name="cache_transfer_inited_signal",
+            array=cache_transfer_inited_signal_data,
+            dtype=np.int32,
+            suffix=args.engine_worker_queue_port,
+            create=False,
+        )
+        self.cache_transfer_inited_signal.value[self.rank] = 1
+
+    def _init_storage(self, args):
+        self.storage_backend_type = args.kvcache_storage_backend
+
+        try:
+            if self.storage_backend_type is None:
+                self.storage_backend = None
+            elif self.storage_backend_type == "mooncake":
+                logger.info("Start initialize mooncake store...")
+                self.storage_backend = MooncakeStore(tp_rank=self.rank)
+                self._init_storage_buffer(args)
+                logger.info("Initialized mooncake store successfully")
+            elif self.storage_backend_type == "attention_store":
+                logger.info("Start initialize attention store...")
+                # TODO: support different model version in rl
+                self.storage_backend = AttentionStore(
+                    namespace=self.model_id,
+                    shard_id=self.rank,
+                    shard_num=self.n_ranks,
+                    layer_num=self.num_layers + self.num_extra_layers,
+                    block_token_size=self.block_size,
+                    bytes_per_shard_layer_per_block=self.head_num * self.block_size * self.head_dim * self.cache_bytes,
+                    device_id=self.device,
+                    dp_id=self.local_data_parallel_id,
+                )
+                logger.info("Initialized attention store successfully!")
+            else:
+                raise NotImplementedError(f"Unsupported storage backend: {self.storage_backend_type}")
+        except Exception as e:
+            err_msg = f"Fail to initialize storage backend, {e}, traceback: {traceback.format_exc()}"
+            logger.error(err_msg)
+            console_logger.error(err_msg)  # print error message to console
+            raise
+
+        if args.write_policy not in ["write_through"]:
+            raise ValueError(f"Invalid write policy: {args.write_policy}")
+        self.write_policy = args.write_policy
+
+        self.key_prefix = ""
+        version_file_path = os.path.join(args.model_path, "version.yaml")
+        if os.path.exists(version_file_path):
+            self.key_prefix = get_key_prefix_from_version(version_file_path)
+        logger.info(f"The key_prefix of cache storage is {self.key_prefix}")
+
+        logger.info("Initialize cache storage successfully")
+
     def _init_storage_buffer(self, args):
         """
         Initialize pinned memory buffer that can hold the cache for a longest request
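The cache_transfer_inited_signal above is how each rank reports readiness: rank i writes 1 into slot i of a shared array. A minimal sketch of the waiting side, assuming the launcher process allocates the same IPCSignal with create=True and polls it the same way the cache-ready loop at the end of this diff does; the loop and the parameter values are illustrative, not taken from this commit:

    import time
    import numpy as np
    from fastdeploy.inter_communicator import IPCSignal

    mp_num, engine_worker_queue_port = 8, 8133   # hypothetical launch parameters
    inited = np.zeros(shape=[mp_num], dtype=np.int32)
    cache_transfer_inited_signal = IPCSignal(
        name="cache_transfer_inited_signal",
        array=inited,
        dtype=np.int32,
        suffix=engine_worker_queue_port,
        create=True,
    )
    while np.sum(cache_transfer_inited_signal.value) != mp_num:
        time.sleep(0.1)   # block until every rank has marked itself inited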
@@ -287,7 +327,7 @@ class CacheTransferManager:
         )
         total_bytes = block_num * self.storage_buffer_stride_bytes * 2  # key and value

-        logger.info(f"Creating cpu buffer cache for alllayers: {total_bytes / 1024 ** 3:.2f}GB")
+        logger.info(f"Creating cpu buffer cache for all layers: {total_bytes / 1024 ** 3:.2f}GB")
         read_buffer = cuda_host_alloc(total_bytes)
         self.storage_key_read_buffer = read_buffer
         self.storage_value_read_buffer = read_buffer + total_bytes // 2
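To make the buffer sizing concrete: one pinned allocation holds keys in its first half and values in its second half. A small sketch with hypothetical numbers (block_num and the per-block stride are assumptions):

    block_num = 1024                            # hypothetical block count
    storage_buffer_stride_bytes = 2 * 1024**2   # hypothetical 2 MiB per block
    total_bytes = block_num * storage_buffer_stride_bytes * 2   # keys + values
    key_buffer_offset = 0                       # first half of the pinned buffer
    value_buffer_offset = total_bytes // 2      # second half
    print(f"{total_bytes / 1024 ** 3:.2f}GB")   # -> 4.00GB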
@@ -557,8 +597,8 @@ class CacheTransferManager:
         try:
             gpu_block_ids = task.gpu_block_ids.copy()
             cpu_block_ids = [i for i in range(len(gpu_block_ids))]
-            k_cache_keys = [f"{key}_key_{self.rank}" for key in task.keys]
-            v_cache_keys = [f"{key}_value_{self.rank}" for key in task.keys]
+            k_cache_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_key" for key in task.keys]
+            v_cache_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_value" for key in task.keys]
             match_block_num = 0
             if self.storage_backend_type == "mooncake":
                 match_block_num = self.storage_backend.query(k_cache_keys, v_cache_keys)
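The effect of the new key scheme, shown with a hypothetical block hash on tp rank 0 and the prefix from the earlier example:

    key_prefix = "RL-STEP12"                        # from version.yaml
    key, rank = "abc123", 0                         # hypothetical block hash, tp rank
    old_k = f"{key}_key_{rank}"                     # old: "abc123_key_0"
    new_k = f"prefix{key_prefix}_{key}_{rank}_key"  # new: "prefixRL-STEP12_abc123_0_key"

The same change is applied again in the second lookup path below, so both query sites agree on the key layout.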
@@ -694,8 +734,8 @@ class CacheTransferManager:
         try:
             gpu_block_ids = task.gpu_block_ids.copy()
             cpu_block_ids = [i for i in range(len(gpu_block_ids))]
-            k_cache_keys = [f"{key}_key_{self.rank}" for key in task.keys]
-            v_cache_keys = [f"{key}_value_{self.rank}" for key in task.keys]
+            k_cache_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_key" for key in task.keys]
+            v_cache_keys = [f"prefix{self.key_prefix}_{key}_{self.rank}_value" for key in task.keys]

             match_block_num = 0
             if self.storage_backend_type == "mooncake":
@@ -1112,6 +1152,13 @@ class CacheTransferManager:
                 self._init_gpu_cache(args)
                 logger.debug("[RL] successfully restored gpu caches")

+                if self.storage_backend_type is not None:
+                    # use key_prefix to distinguish cache for different version of weight in rl
+                    version_file_path = os.path.join(args.model_path, "version.yaml")
+                    assert os.path.exists(version_file_path), f"version.yaml not found at {version_file_path}"
+                    self.key_prefix = get_key_prefix_from_version(version_file_path)
+                    logger.info(f"Update key_prefix of cache storage to {self.key_prefix}")
+
                 # wait for all ranks caches to be ready
                 while np.sum(self.cache_ready_signal.value) != args.mp_num:
                     time.sleep(0.1)
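This refresh is what makes per-version isolation work in RL: once the trainer bumps version.yaml after a weight update, keys written under the previous step can no longer collide with keys for the new weights. A toy illustration with invented step numbers:

    old_key = "prefixRL-STEP12_abc123_0_key"   # written before the weight update
    new_key = "prefixRL-STEP13_abc123_0_key"   # same block hash after the update
    assert old_key != new_key                  # stale cache is never matched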