[BugFix] Refine the preparation of CPU and storage cache (#5777)

* Refine the preparation of CPU and storage cache

* fix error

* fix error

* up

* fix

* update docs

* fix unittest

* remove debug info
This commit is contained in:
jc
2026-01-05 10:13:30 +08:00
committed by GitHub
parent 95257c1dbd
commit e911ac2ce7
10 changed files with 156 additions and 149 deletions
@@ -456,8 +456,9 @@ class CacheTransferManager:
def _run_read_storage(self, k_cache_keys, v_cache_keys, gpu_block_ids, cpu_block_ids):
try:
logger.debug(
f"_run_read_storage, key_hash_keys: {k_cache_keys}, "
f"value_hash_keys: {v_cache_keys}, gpu_block_ids: {gpu_block_ids}"
f"_run_read_storage, key_hash_keys_num: {len(k_cache_keys)}, "
f"value_hash_keys_num: {len(v_cache_keys)}, gpu_block_ids_num: {len(gpu_block_ids)}, "
f"cpu_block_ids_num: {len(cpu_block_ids)}"
)
block_num = len(gpu_block_ids)
@@ -468,7 +469,9 @@ class CacheTransferManager:
]
kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value
start_time = time.time()
result = self.storage_backend.batch_get(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes)
read_cost_time = time.time() - start_time
k_result, v_result = result[:block_num], result[block_num:]
success_block_num = 0
@@ -480,6 +483,7 @@ class CacheTransferManager:
valid_cpu_block_ids = cpu_block_ids[:success_block_num]
mode = 1 # cpu ==> gpu
start_time = time.time()
swap_cache_layout(
self.gpu_cache_k_tensors,
self.storage_key_read_buffer,
@@ -498,6 +502,10 @@ class CacheTransferManager:
self.device,
mode,
)
swap_cost_time = time.time() - start_time
logger.debug(
f"_run_read_storage, swap_cost_time: {swap_cost_time:.6f}s, read_cost_time: {read_cost_time:.6f}s"
)
return valid_gpu_block_ids
except Exception as e:
@@ -511,8 +519,8 @@ class CacheTransferManager:
"""Read cache from the storage backend to the GPU memory."""
try:
logger.debug(
f"read_storage_task, task id: {task_id}, hash_keys: {keys}, "
f"gpu_block_ids: {gpu_block_ids}, timeout: {timeout}"
f"read_storage_task, task id: {task_id}, hash_keys_num: {len(keys)}, "
f"gpu_block_ids_num: {len(gpu_block_ids)}, timeout: {timeout}"
)
k_cache_keys = [f"{key}_key_{self.rank}" for key in keys]
v_cache_keys = [f"{key}_value_{self.rank}" for key in keys]
@@ -565,7 +573,9 @@ class CacheTransferManager:
self.key_cache_shape[2],
self.key_cache_shape[3],
]
mode = 0 # gpu ==> cpu
start_time = time.time()
swap_cache_layout(
self.gpu_cache_k_tensors,
self.storage_key_write_buffer,
@@ -584,6 +594,7 @@ class CacheTransferManager:
self.device,
mode,
)
swap_cost_time = time.time() - start_time
block_num = len(gpu_block_ids)
keys = k_cache_keys + v_cache_keys
@@ -595,7 +606,13 @@ class CacheTransferManager:
]
kv_cache_ptrs = k_cache_ptrs + v_cache_ptrs
kv_block_sizes = [self.storage_buffer_stride_bytes] * block_num * 2 # key and value
start_time = time.time()
self.storage_backend.batch_set(keys, target_locations=kv_cache_ptrs, target_sizes=kv_block_sizes)
write_cost_time = time.time() - start_time
logger.debug(
f"_run_write_back_storage, swap_cost_time: {swap_cost_time:.6f}s, write_cost_time: {write_cost_time:.6f}s"
)
except Exception as e:
logger.error(
f"[rank {self.rank}/{self.n_ranks}] An error occurred in _run_write_back_storage: "