[KVCache] support unified cache backend (#4903)
* [Feature] support unified cache backend
* fix
* fix
* fix
* fix
* Update metax_model_runner.py
* fix
* update
* Update test_moba_attention_backend.py

---------

Co-authored-by: ltd0924 <luotingdan@baidu.com>
@@ -58,36 +58,7 @@ def parse_args():
     parser.add_argument("--rank", type=int, default=0, help="current rank")
     parser.add_argument("--device_id", type=int, default=0, help="device id")
     parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
-    parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
-    parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
-    parser.add_argument("--rdma_port", type=str, default="", help="rdma port")
     parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
-    parser.add_argument(
-        "--protocol",
-        type=str,
-        default="ipc",
-        help="cache transfer protocol, only support ipc now",
-    )
-    parser.add_argument("--enable_splitwise", type=int, default=0, help="enable splitwise")
-    parser.add_argument("--cache_queue_port", type=int, default=9923, help="cache queue port")
-    parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
-    parser.add_argument(
-        "--engine_worker_queue_port",
-        type=int,
-        default=9923,
-        help="engine worker queue port",
-    )
-    parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
-
-    parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
-    parser.add_argument("--num_cpu_blocks", type=int, default=4, help="cpu cache block number")
-    parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)")
-    parser.add_argument(
-        "--bytes_per_layer_per_block",
-        type=int,
-        default=1024,
-        help="per layer per block bytes",
-    )
     parser.add_argument(
         "--cache_dtype",
         type=str,
@@ -95,13 +66,33 @@ def parse_args():
         choices=["uint8", "bfloat16"],
         help="cache dtype",
     )
+    parser.add_argument("--key_cache_shape", type=str, default="", help="key cache shape")
+    parser.add_argument("--value_cache_shape", type=str, default="", help="value cache shape")
+    parser.add_argument("--cache_queue_port", type=int, default=9923, help="cache queue port")
+    parser.add_argument("--enable_splitwise", type=int, default=0, help="enable splitwise")
+    parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
+    parser.add_argument(
+        "--engine_worker_queue_port",
+        type=int,
+        default=9923,
+        help="engine worker queue port",
+    )
+    parser.add_argument("--num_cpu_blocks", type=int, default=4, help="cpu cache block number")
+    parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
+    parser.add_argument(
+        "--protocol",
+        type=str,
+        default="ipc",
+        help="cache transfer protocol, only support ipc now",
+    )
+    parser.add_argument("--local_data_parallel_id", type=int, default=0)
+    parser.add_argument("--rdma_port", type=str, default="", help="rdma port")
     parser.add_argument(
         "--speculative_config",
         type=json.loads,
         default="{}",
         help="speculative config",
     )
-    parser.add_argument("--local_data_parallel_id", type=int, default=0)
+    parser.add_argument("--create_cache_tensor", action="store_true")

     args = parser.parse_args()
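Taken together with the previous hunk, the cache layout is no longer described dimension by dimension (--num_gpu_blocks, --kv_num_head, --block_size, --head_dim, --bytes_per_layer_per_block); callers now pass whole shape strings. A minimal sketch of how such a string decomposes, where the value "8192,8,64,128" is illustrative rather than taken from the PR (the dimension names follow the old cache_shape order removed below):

    # Hypothetical layout: [num_blocks, kv_num_head, block_size, head_dim].
    key_cache_shape = [int(i) for i in "8192,8,64,128".split(",")]
    num_gpu_blocks, kv_num_head, block_size, head_dim = key_cache_shape
    assert num_gpu_blocks == 8192  # dimension 0 doubles as the GPU block count

An empty --value_cache_shape signals that the model publishes no separate value cache; the later hunks handle that by skipping every value-side allocation.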
@@ -124,8 +115,13 @@ class CacheTransferManager:
         self.gpu_cache_k_tensors = []
         self.gpu_cache_v_tensors = []
         self.speculative_config = SpeculativeConfig(args.speculative_config)
+        self.key_cache_shape = [int(i) for i in args.key_cache_shape.split(",")]
+        self.value_cache_shape = []
+        if args.value_cache_shape:
+            self.value_cache_shape = [int(i) for i in args.value_cache_shape.split(",")]
+        self.num_gpu_blocks = self.key_cache_shape[0]
         self.num_extra_layers = self.speculative_config.num_extra_cache_layer
-        self.num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
+        self.num_extra_layer_gpu_blocks = int(self.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)

         self.swap_to_cpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
         self.swap_to_gpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
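With the block count now read from dimension 0 of the key cache shape, the speculative-decoding expansion is driven by the same value. A worked example of the arithmetic above, using made-up numbers:

    key_cache_shape = [1000, 8, 64, 128]  # hypothetical parsed --key_cache_shape
    num_gpu_block_expand_ratio = 1.25     # hypothetical speculative config value
    num_gpu_blocks = key_cache_shape[0]   # 1000, no longer a CLI flag
    num_extra_layer_gpu_blocks = int(num_gpu_blocks * num_gpu_block_expand_ratio)  # 1250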
@@ -164,8 +160,9 @@ class CacheTransferManager:

         self.num_cpu_blocks = args.num_cpu_blocks

-        self._init_cpu_cache(args)
         self._init_gpu_cache(args)
+        if self.num_cpu_blocks > 0:
+            self._init_cpu_cache(args)

         cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
         self.cache_task_broadcast_signal = IPCSignal(
@@ -209,28 +206,47 @@ class CacheTransferManager:
         logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.")
         set_device(self.device)
         for i in range(args.num_layers + self.num_extra_layers):
-            num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
-            cache_shape = [num_gpu_blocks, args.kv_num_head, args.block_size, args.head_dim]
+            num_gpu_blocks = self.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
             key_name = f"key_caches_{i}_rank{self.rank}.device{self.device}"
             val_name = f"value_caches_{i}_rank{self.rank}.device{self.device}"

+            key_cache_shape = [
+                num_gpu_blocks,
+                self.key_cache_shape[1],
+                self.key_cache_shape[2],
+                self.key_cache_shape[3],
+            ]
+            value_cache_shape = []
+            if self.value_cache_shape:
+                value_cache_shape = [
+                    num_gpu_blocks,
+                    self.value_cache_shape[1],
+                    self.value_cache_shape[2],
+                    self.value_cache_shape[3],
+                ]
             if args.create_cache_tensor:
-                logger.info(f"[rank {self.rank}/{self.n_ranks}] ..creating kv cache for layer {i}: {cache_shape}")
-                key_cache = paddle.full(shape=cache_shape, fill_value=0, dtype=args.cache_dtype)
-                val_cache = paddle.full(shape=cache_shape, fill_value=0, dtype=args.cache_dtype)
+                logger.info(
+                    f"[rank {self.rank}/{self.n_ranks}] ..creating kv cache for layer {i}: {key_cache_shape} {value_cache_shape}"
+                )
+                key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=args.cache_dtype)
                 set_data_ipc(key_cache, key_name)
-                set_data_ipc(val_cache, val_name)
+                if self.value_cache_shape:
+                    val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=args.cache_dtype)
+                    set_data_ipc(val_cache, val_name)
             else:
-                logger.info(f"[rank {self.rank}/{self.n_ranks}] ..attaching kv cache for layer {i}: {cache_shape}")
+                logger.info(
+                    f"[rank {self.rank}/{self.n_ranks}] ..attaching kv cache for layer {i}: {key_cache_shape} {value_cache_shape}"
+                )
                 key_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
                 val_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
-                key_cache = share_external_data_(key_cache, key_name, cache_shape, True)
-                val_cache = share_external_data_(val_cache, val_name, cache_shape, True)
+                key_cache = share_external_data_(key_cache, key_name, key_cache_shape, True)
+                if self.value_cache_shape:
+                    val_cache = share_external_data_(val_cache, val_name, value_cache_shape, True)

             self.gpu_cache_kvs[key_name] = key_cache
-            self.gpu_cache_kvs[val_name] = val_cache
             self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name])
-            self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[val_name])
+            if args.value_cache_shape:
+                self.gpu_cache_kvs[val_name] = val_cache
+                self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[val_name])

         if args.create_cache_tensor:
             logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ kv cache is ready!")
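Both branches now size each layer from the parsed shapes, and a value tensor exists only when a value shape was supplied, so attention variants that fold keys and values into a single per-layer cache tensor can simply omit --value_cache_shape. Below is a condensed sketch of the create-versus-attach split; init_layer_cache is a hypothetical helper, and the FastDeploy IPC calls (set_data_ipc, share_external_data_) are left as comments because their import paths sit outside this diff:

    import paddle

    def init_layer_cache(create, key_shape, value_shape, dtype="bfloat16"):
        """Sketch of one layer's cache setup under the unified backend."""
        val_cache = None
        if create:
            # Producer: materialize zero-filled tensors, then export them.
            key_cache = paddle.full(shape=key_shape, fill_value=0, dtype=dtype)
            # set_data_ipc(key_cache, key_name) would export the key tensor here
            if value_shape:
                val_cache = paddle.full(shape=value_shape, fill_value=0, dtype=dtype)
                # set_data_ipc(val_cache, val_name)
        else:
            # Consumer: start from an empty tensor and attach to the exported one.
            key_cache = paddle.empty(shape=[], dtype=dtype)
            # key_cache = share_external_data_(key_cache, key_name, key_shape, True)
            # if value_shape: attach val_cache the same way
        return key_cache, val_cache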
@@ -242,6 +258,22 @@ class CacheTransferManager:
         logger.info(f"[rank {self.rank}/{self.n_ranks}] done init cache (full) gmem alloc : {memory_allocated()}")

     def _init_cpu_cache(self, args):
+        key_cache_size = self.key_cache_shape[1] * self.key_cache_shape[2] * self.key_cache_shape[3]
+        if args.value_cache_shape:
+            value_cache_size = self.value_cache_shape[1] * self.value_cache_shape[2] * self.value_cache_shape[3]
+        else:
+            value_cache_size = 0
+        if args.cache_dtype == "bfloat16":
+            cache_bytes = 2
+        elif args.cache_dtype == "uint8":
+            cache_bytes = 1
+        else:
+            raise ValueError(f"Unsupported cache dtype: {args.cache_dtype}")
+        key_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * key_cache_size
+        value_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * value_cache_size
+        logger.info(
+            f"[rank {self.rank}/{self.n_ranks}] ..swap space size : {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB"
+        )
         if args.num_cpu_blocks == 0:
             logger.info(f"[rank {self.rank}/{self.n_ranks}] 💡 no swap space (cpu cache) is specified.")
             self.swap_space_ready_signal.value[self.rank] = 1
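The swap space is now sized from the real cache shape and dtype width instead of the removed --bytes_per_layer_per_block estimate. A worked example of the sizing arithmetic, with illustrative numbers:

    num_cpu_blocks = 2048
    cache_bytes = 2                # bfloat16 is two bytes per element
    key_cache_size = 8 * 64 * 128  # elements per block for a hypothetical shape
    key_need_to_allocate_bytes = num_cpu_blocks * cache_bytes * key_cache_size
    print(f"{key_need_to_allocate_bytes / 1024 ** 3:.2f}GB")  # 0.25GB of keys per layer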
@@ -253,14 +285,14 @@ class CacheTransferManager:
         for i in range(args.num_layers + self.num_extra_layers):
             key_name = f"key_caches_{i}_rank{self.rank}"
             val_name = f"value_caches_{i}_rank{self.rank}"
-            need_to_allocate_bytes = args.num_cpu_blocks * args.bytes_per_layer_per_block
             logger.info(
-                f"[rank {self.rank}/{self.n_ranks}] ..creating cpu cache for layer {i}: {2 * need_to_allocate_bytes / 1024 ** 3:.2f}GB"
+                f"[rank {self.rank}/{self.n_ranks}] ..creating cpu cache for layer {i}: {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB"
             )
-            self.cpu_cache_kvs[key_name] = cuda_host_alloc(need_to_allocate_bytes)
+            self.cpu_cache_kvs[key_name] = cuda_host_alloc(key_need_to_allocate_bytes)
             self.k_dst_ptrs.append(self.cpu_cache_kvs[key_name])
-            self.cpu_cache_kvs[val_name] = cuda_host_alloc(need_to_allocate_bytes)
-            self.v_dst_ptrs.append(self.cpu_cache_kvs[val_name])
+            if value_need_to_allocate_bytes > 0:
+                self.cpu_cache_kvs[val_name] = cuda_host_alloc(value_need_to_allocate_bytes)
+                self.v_dst_ptrs.append(self.cpu_cache_kvs[val_name])
         logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ swap space (cpu cache) is ready!")
         self.swap_space_ready_signal.value[self.rank] = 1
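Finally, a sketch of how a caller might assemble the new flag set when spawning the manager as a subprocess. The script path and all values are assumptions for illustration; only the flag names come from this diff:

    import subprocess
    import sys

    # Hypothetical launch; path and values are illustrative only.
    argv = [
        sys.executable,
        "fastdeploy/cache_manager/cache_transfer_manager.py",  # assumed path
        "--rank", "0",
        "--device_id", "0",
        "--num_layers", "32",
        "--cache_dtype", "bfloat16",
        "--key_cache_shape", "8192,8,64,128",
        "--value_cache_shape", "8192,8,64,128",
        "--num_cpu_blocks", "2048",
        "--create_cache_tensor",
    ]
    proc = subprocess.Popen(argv)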