[fix] remove cache tensor creation for cache_transfer_manager (#4420)

* [fix] remove cache tensor creation for cache_transfer_manager

* [fix] fix code style

* [fix] fix code style

---------

Co-authored-by: ltd0924 <luotingdan@baidu.com>
Author: 李泳桦
Date: 2025-10-20 16:19:56 +08:00
Committed by: GitHub
Parent commit: de2eaf4f81
This commit: b8d235445e
5 changed files with 28 additions and 23 deletions
+6 -12
View File
@@ -758,27 +758,20 @@ def main():
         gpu_cache_k_tensors = []
         gpu_cache_v_tensors = []
+        logger.info(f"[rank {rank}/{args.mp_num}] Initializing kv cache for all layers.")
         for i in range(args.num_layers + num_extra_layers):
             num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else num_extra_layer_gpu_blocks
+            cache_shape = [num_gpu_blocks, args.kv_num_head, args.block_size, args.head_dim]
+            logger.info(f"[rank {rank}/{args.mp_num}] ..creating kv cache for layer {i}: {cache_shape}")
             gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
-                shape=[
-                    num_gpu_blocks,
-                    args.kv_num_head,
-                    args.block_size,
-                    args.head_dim,
-                ],
+                shape=cache_shape,
                 fill_value=0,
                 dtype=cache_type,
             )
             gpu_cache_k_tensors.append(gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
             gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
-                shape=[
-                    num_gpu_blocks,
-                    args.kv_num_head,
-                    args.block_size,
-                    args.head_dim,
-                ],
+                shape=cache_shape,
                 fill_value=0,
                 dtype=cache_type,
             )
@@ -835,6 +828,7 @@ def main():
             create=False,
         )
         cache_ready_signal.value[rank] = 1
+        logger.info(f"[rank {rank}/{args.mp_num}] ✅ kv cache is ready!")
         if args.splitwise_role == "mixed":
             while True:
                 time.sleep(1)