[KVCache] support unified cache backend (#4903)

* [Feature] support unified cache backend

* fix

* fix

* fix

* fix

* Update metax_model_runner.py

* fix

* update

* Update test_moba_attention_backend.py

---------

Co-authored-by: ltd0924 <luotingdan@baidu.com>
This commit is contained in:
ltd0924
2025-11-12 14:54:52 +08:00
committed by GitHub
parent 76e60e98f8
commit 5bf48de999
19 changed files with 281 additions and 202 deletions
@@ -58,36 +58,7 @@ def parse_args():
parser.add_argument("--rank", type=int, default=0, help="current rank")
parser.add_argument("--device_id", type=int, default=0, help="device id")
parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
parser.add_argument("--rdma_port", type=str, default="", help="rdma port")
parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
parser.add_argument(
"--protocol",
type=str,
default="ipc",
help="cache transfer protocol, only support ipc now",
)
parser.add_argument("--enable_splitwise", type=int, default=0, help="enable splitwise ")
parser.add_argument("--cache_queue_port", type=int, default=9923, help="cache queue port")
parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
parser.add_argument(
"--engine_worker_queue_port",
type=int,
default=9923,
help="engine worker queue port",
)
parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
parser.add_argument("--num_cpu_blocks", type=int, default=4, help="cpu cache block number")
parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)")
parser.add_argument(
"--bytes_per_layer_per_block",
type=int,
default=1024,
help="per layer per block bytes",
)
parser.add_argument(
"--cache_dtype",
type=str,
@@ -95,13 +66,33 @@ def parse_args():
choices=["uint8", "bfloat16"],
help="cache dtype",
)
parser.add_argument("--key_cache_shape", type=str, default="", help="key cache shape")
parser.add_argument("--value_cache_shape", type=str, default="", help="value cache shape")
parser.add_argument("--cache_queue_port", type=int, default=9923, help="cache queue port")
parser.add_argument("--enable_splitwise", type=int, default=0, help="enable splitwise ")
parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
parser.add_argument(
"--engine_worker_queue_port",
type=int,
default=9923,
help="engine worker queue port",
)
parser.add_argument("--num_cpu_blocks", type=int, default=4, help="cpu cache block number")
parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
parser.add_argument(
"--protocol",
type=str,
default="ipc",
help="cache transfer protocol, only support ipc now",
)
parser.add_argument("--local_data_parallel_id", type=int, default=0)
parser.add_argument("--rdma_port", type=str, default="", help="rdma port")
parser.add_argument(
"--speculative_config",
type=json.loads,
default="{}",
help="speculative config",
)
parser.add_argument("--local_data_parallel_id", type=int, default=0)
parser.add_argument("--create_cache_tensor", action="store_true")
args = parser.parse_args()
@@ -124,8 +115,13 @@ class CacheTransferManager:
self.gpu_cache_k_tensors = []
self.gpu_cache_v_tensors = []
self.speculative_config = SpeculativeConfig(args.speculative_config)
self.key_cache_shape = [int(i) for i in args.key_cache_shape.split(",")]
self.value_cache_shape = []
if args.value_cache_shape:
self.value_cache_shape = [int(i) for i in args.value_cache_shape.split(",")]
self.num_gpu_blocks = self.key_cache_shape[0]
self.num_extra_layers = self.speculative_config.num_extra_cache_layer
self.num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
self.num_extra_layer_gpu_blocks = int(self.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
self.swap_to_cpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.swap_to_gpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
@@ -164,8 +160,9 @@ class CacheTransferManager:
self.num_cpu_blocks = args.num_cpu_blocks
self._init_cpu_cache(args)
self._init_gpu_cache(args)
if self.num_cpu_blocks > 0:
self._init_cpu_cache(args)
cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
self.cache_task_broadcast_signal = IPCSignal(
@@ -209,28 +206,47 @@ class CacheTransferManager:
logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.")
set_device(self.device)
for i in range(args.num_layers + self.num_extra_layers):
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
cache_shape = [num_gpu_blocks, args.kv_num_head, args.block_size, args.head_dim]
num_gpu_blocks = self.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
key_name = f"key_caches_{i}_rank{self.rank}.device{self.device}"
val_name = f"value_caches_{i}_rank{self.rank}.device{self.device}"
key_cache_shape = [
num_gpu_blocks,
self.key_cache_shape[1],
self.key_cache_shape[2],
self.key_cache_shape[3],
]
value_cache_shape = []
if self.value_cache_shape:
value_cache_shape = [
num_gpu_blocks,
self.value_cache_shape[1],
self.value_cache_shape[2],
self.value_cache_shape[3],
]
if args.create_cache_tensor:
logger.info(f"[rank {self.rank}/{self.n_ranks}] ..creating kv cache for layer {i}: {cache_shape}")
key_cache = paddle.full(shape=cache_shape, fill_value=0, dtype=args.cache_dtype)
val_cache = paddle.full(shape=cache_shape, fill_value=0, dtype=args.cache_dtype)
logger.info(
f"[rank {self.rank}/{self.n_ranks}] ..creating kv cache for layer {i}: {key_cache_shape} {value_cache_shape}"
)
key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=args.cache_dtype)
set_data_ipc(key_cache, key_name)
set_data_ipc(val_cache, val_name)
if self.value_cache_shape:
val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=args.cache_dtype)
set_data_ipc(val_cache, val_name)
else:
logger.info(f"[rank {self.rank}/{self.n_ranks}] ..attaching kv cache for layer {i}: {cache_shape}")
logger.info(
f"[rank {self.rank}/{self.n_ranks}] ..attaching kv cache for layer {i}: {key_cache_shape} {value_cache_shape}"
)
key_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
val_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
key_cache = share_external_data_(key_cache, key_name, cache_shape, True)
val_cache = share_external_data_(val_cache, val_name, cache_shape, True)
key_cache = share_external_data_(key_cache, key_name, key_cache_shape, True)
if self.value_cache_shape:
val_cache = share_external_data_(val_cache, val_name, value_cache_shape, True)
self.gpu_cache_kvs[key_name] = key_cache
self.gpu_cache_kvs[val_name] = val_cache
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name])
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[val_name])
if args.value_cache_shape:
self.gpu_cache_kvs[val_name] = val_cache
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[val_name])
if args.create_cache_tensor:
logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ kv cache is ready!")
@@ -242,6 +258,22 @@ class CacheTransferManager:
logger.info(f"[rank {self.rank}/{self.n_ranks}] done init cache (full) gmem alloc : {memory_allocated()}")
def _init_cpu_cache(self, args):
key_cache_size = self.key_cache_shape[1] * self.key_cache_shape[2] * self.key_cache_shape[3]
if args.value_cache_shape:
value_cache_size = self.value_cache_shape[1] * self.value_cache_shape[2] * self.value_cache_shape[3]
else:
value_cache_size = 0
if args.cache_dtype == "bfloat16":
cache_bytes = 2
elif args.cache_dtype == "uint8":
cache_bytes = 1
else:
raise ValueError(f"Unsupported cache dtype: {args.cache_dtype}")
key_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * key_cache_size
value_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * value_cache_size
logger.info(
f"[rank {self.rank}/{self.n_ranks}] ..swap space size : {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB"
)
if args.num_cpu_blocks == 0:
logger.info(f"[rank {self.rank}/{self.n_ranks}] 💡 no swap space (cpu cache) is specified.")
self.swap_space_ready_signal.value[self.rank] = 1
@@ -253,14 +285,14 @@ class CacheTransferManager:
for i in range(args.num_layers + self.num_extra_layers):
key_name = f"key_caches_{i}_rank{self.rank}"
val_name = f"value_caches_{i}_rank{self.rank}"
need_to_allocate_bytes = args.num_cpu_blocks * args.bytes_per_layer_per_block
logger.info(
f"[rank {self.rank}/{self.n_ranks}] ..creating cpu cache for layer {i}: {2 * need_to_allocate_bytes / 1024 ** 3:.2f}GB"
f"[rank {self.rank}/{self.n_ranks}] ..creating cpu cache for layer {i}: {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB"
)
self.cpu_cache_kvs[key_name] = cuda_host_alloc(need_to_allocate_bytes)
self.cpu_cache_kvs[key_name] = cuda_host_alloc(key_need_to_allocate_bytes)
self.k_dst_ptrs.append(self.cpu_cache_kvs[key_name])
self.cpu_cache_kvs[val_name] = cuda_host_alloc(need_to_allocate_bytes)
self.v_dst_ptrs.append(self.cpu_cache_kvs[val_name])
if value_need_to_allocate_bytes > 0:
self.cpu_cache_kvs[val_name] = cuda_host_alloc(value_need_to_allocate_bytes)
self.v_dst_ptrs.append(self.cpu_cache_kvs[val_name])
logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ swap space (cpu cache) is ready!")
self.swap_space_ready_signal.value[self.rank] = 1