[KVCache] support unified cache backend (#4903)

* [Feature] support unified cache backend

* fix

* fix

* fix

* fix

* Update metax_model_runner.py

* fix

* update

* Update test_moba_attention_backend.py

---------

Co-authored-by: ltd0924 <luotingdan@baidu.com>
This commit is contained in:
ltd0924
2025-11-12 14:54:52 +08:00
committed by GitHub
parent 76e60e98f8
commit 5bf48de999
19 changed files with 281 additions and 202 deletions
+39 -20
View File
@@ -52,8 +52,8 @@ def parse_args():
parser.add_argument("--rank", type=int, default=0, help="current rank")
parser.add_argument("--device_id", type=int, default=0, help="device id")
parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
parser.add_argument("--key_cache_shape", type=str, default="", help="key cache shape")
parser.add_argument("--value_cache_shape", type=str, default="", help="value cache shape")
parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
@@ -71,8 +71,6 @@ def parse_args():
default=9923,
help="engine worker queue port",
)
parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)")
parser.add_argument(
"--cache_dtype",
type=str,
@@ -764,38 +762,59 @@ def main():
cache_type = args.cache_dtype
speculative_config = SpeculativeConfig(args.speculative_config)
num_extra_layers = speculative_config.num_extra_cache_layer
num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * speculative_config.num_gpu_block_expand_ratio)
key_cache_shape_list = [int(i) for i in args.key_cache_shape.split(",")]
value_cache_shape_list = []
if args.value_cache_shape:
value_cache_shape_list = [int(i) for i in args.value_cache_shape.split(",")]
total_gpu_blocks = key_cache_shape_list[0]
num_extra_layer_gpu_blocks = int(total_gpu_blocks * speculative_config.num_gpu_block_expand_ratio)
gpu_cache_kvs = {}
gpu_cache_k_tensors = []
gpu_cache_v_tensors = []
logger.info(f"[rank {rank}/{args.mp_num}] Initializing kv cache for all layers.")
for i in range(args.num_layers + num_extra_layers):
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else num_extra_layer_gpu_blocks
cache_shape = [num_gpu_blocks, args.kv_num_head, args.block_size, args.head_dim]
logger.info(f"[rank {rank}/{args.mp_num}] ..creating kv cache for layer {i}: {cache_shape}")
num_gpu_blocks = total_gpu_blocks if i < args.num_layers else num_extra_layer_gpu_blocks
key_cache_shape = [
num_gpu_blocks,
key_cache_shape_list[1],
key_cache_shape_list[2],
key_cache_shape_list[3],
]
value_cache_shape = []
if value_cache_shape_list:
value_cache_shape = [
num_gpu_blocks,
value_cache_shape_list[1],
value_cache_shape_list[2],
value_cache_shape_list[3],
]
logger.info(
f"[rank {rank}/{args.mp_num}] ..creating kv cache for layer {i}: {key_cache_shape} {value_cache_shape}"
)
gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
shape=cache_shape,
shape=key_cache_shape,
fill_value=0,
dtype=cache_type,
)
gpu_cache_k_tensors.append(gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
shape=cache_shape,
fill_value=0,
dtype=cache_type,
)
gpu_cache_v_tensors.append(gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
set_data_ipc(
gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
f"key_caches_{i}_rank{rank}.device{device}",
)
set_data_ipc(
gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
f"value_caches_{i}_rank{rank}.device{device}",
)
if value_cache_shape_list:
gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
shape=value_cache_shape,
fill_value=0,
dtype=cache_type,
)
gpu_cache_v_tensors.append(gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
set_data_ipc(
gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
f"value_caches_{i}_rank{rank}.device{device}",
)
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in gpu_cache_kvs.items()])
logger.info(f"device :{device}")
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")