mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] Support KV Cache Storage (#5571)
* Support Mooncake Store * up * up * add op * fix conflict * fix error * up for comments * avoid thread lock * up * fix unittest * fix unittest * remove debug info * consider tp_size > 1 * add default rdma_nics * add utils * up * fix error --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
@@ -14,10 +14,8 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import heapq
|
||||
import os
|
||||
import pickle
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
@@ -34,9 +32,10 @@ from fastdeploy import envs
|
||||
from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
|
||||
from fastdeploy.cache_manager.cache_metrics import CacheMetrics
|
||||
from fastdeploy.cache_manager.ops import get_all_visible_devices
|
||||
from fastdeploy.engine.request import Request
|
||||
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.utils import get_logger
|
||||
from fastdeploy.utils import get_hash_str, get_logger
|
||||
|
||||
logger = get_logger("prefix_cache_manager", "cache_manager.log")
|
||||
|
||||
@@ -65,6 +64,7 @@ class PrefixCacheManager:
|
||||
self.enable_splitwise = 0
|
||||
self.splitwise_role = splitwise_role
|
||||
self.config = config
|
||||
self.tensor_parallel_size = tensor_parallel_size
|
||||
self.cache_config = config.cache_config
|
||||
self.speculative_config = config.speculative_config
|
||||
self.local_data_parallel_id = local_data_parallel_id
|
||||
@@ -89,6 +89,13 @@ class PrefixCacheManager:
|
||||
|
||||
self.radix_tree_root = BlockNode(-1, [], 0, 0, -1, 0, None, None, None)
|
||||
|
||||
# params for cache storage
|
||||
self.kvcache_storage_backend = self.cache_config.kvcache_storage_backend
|
||||
self.write_policy = self.cache_config.write_policy
|
||||
self.task_write_back_event = {}
|
||||
self.task_prefetch_event = {}
|
||||
self.storage_prefetch_block_ids = {}
|
||||
|
||||
# gpu cache data structure
|
||||
self.gpu_lru_leaf_heap = []
|
||||
self.gpu_lru_leaf_set = set()
|
||||
@@ -105,7 +112,7 @@ class PrefixCacheManager:
|
||||
self.req_leaf_map = {} # {request_id: leaf node}
|
||||
self.leaf_req_map = defaultdict(set)
|
||||
self.unfilled_req_block_map = defaultdict(list)
|
||||
self.cache_info = {}
|
||||
self.cache_info = {} # {request_id: (last_match_node, num_cached_tokens)}
|
||||
|
||||
self.executor_pool = ThreadPoolExecutor(max_workers=1)
|
||||
self.free_gpu_executor_pool = ThreadPoolExecutor(max_workers=1)
|
||||
@@ -253,6 +260,10 @@ class PrefixCacheManager:
|
||||
else:
|
||||
val_shape_str = str(val_cache_shape)
|
||||
val_cache_arg_str = f" --value_cache_shape {val_shape_str}"
|
||||
if cache_config.kvcache_storage_backend:
|
||||
kvcache_storage_backend_str = cache_config.kvcache_storage_backend
|
||||
else:
|
||||
kvcache_storage_backend_str = "none"
|
||||
|
||||
for i in range(tensor_parallel_size):
|
||||
launch_cmd = (
|
||||
@@ -281,7 +292,10 @@ class PrefixCacheManager:
|
||||
+ f" --speculative_config '{self.speculative_config.to_json_string()}'"
|
||||
+ f" --default_dtype '{self.config.model_config.dtype}'"
|
||||
+ (" --create_cache_tensor" if create_cache_tensor else "")
|
||||
+ f" >{log_dir}/launch_cache_transfer_manager_tprank{i}.log 2>&1"
|
||||
+ f" --kvcache_storage_backend {kvcache_storage_backend_str}"
|
||||
+ f" --write_policy {cache_config.write_policy}"
|
||||
+ f" --max_model_len {self.config.model_config.max_model_len}"
|
||||
+ f" >{log_dir}/launch_cache_transfer_manager_{int(device_ids[i])}.log 2>&1"
|
||||
)
|
||||
logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
|
||||
cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
|
||||
@@ -290,7 +304,7 @@ class PrefixCacheManager:
|
||||
while np.sum(self.cache_ready_signal.value) != tensor_parallel_size:
|
||||
time.sleep(1)
|
||||
|
||||
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
|
||||
if self.num_cpu_blocks > 0:
|
||||
while np.sum(self.swap_space_ready_signal.value) != tensor_parallel_size:
|
||||
time.sleep(1)
|
||||
|
||||
@@ -303,7 +317,7 @@ class PrefixCacheManager:
|
||||
)
|
||||
|
||||
# Start additional threads
|
||||
if cache_config.enable_hierarchical_cache and self.num_cpu_blocks > 0:
|
||||
if cache_config.kvcache_storage_backend or self.num_cpu_blocks > 0:
|
||||
logger.info("Enable hierarchical cache.")
|
||||
threading.Thread(target=self.recv_data_transfer_result).start()
|
||||
if cache_config.enable_prefix_caching:
|
||||
@@ -505,13 +519,7 @@ class PrefixCacheManager:
|
||||
|
||||
self.task_swapping_event[transfer_task_id] = Event()
|
||||
self.cache_task_queue.put_transfer_task(
|
||||
(
|
||||
swap_node_ids,
|
||||
gpu_block_ids,
|
||||
cpu_block_ids,
|
||||
event_type,
|
||||
transfer_task_id,
|
||||
)
|
||||
(event_type, transfer_task_id, swap_node_ids, gpu_block_ids, cpu_block_ids)
|
||||
)
|
||||
if is_sync:
|
||||
self.sync_swap_task(transfer_task_id)
|
||||
@@ -629,6 +637,10 @@ class PrefixCacheManager:
|
||||
can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size
|
||||
if req_id in self.leaf_req_map[last_node]: # delete old leaf record, update later
|
||||
self.leaf_req_map[last_node].remove(req_id)
|
||||
logger.debug(
|
||||
f"update_cache_blocks: req_id {req_id}, num_cached_tokens {num_cached_tokens}, "
|
||||
f"can_cache_computed_tokens {can_cache_computed_tokens}"
|
||||
)
|
||||
|
||||
with self.request_release_lock:
|
||||
leaf_node = self.mm_build_path(
|
||||
@@ -640,7 +652,7 @@ class PrefixCacheManager:
|
||||
)
|
||||
self.req_leaf_map[req_id] = leaf_node
|
||||
self.leaf_req_map[leaf_node].add(req_id)
|
||||
self.cache_info[req_id] = (leaf_node, can_cache_computed_tokens)
|
||||
self.cache_info[req_id] = [leaf_node, can_cache_computed_tokens]
|
||||
task.cached_block_num = can_cache_computed_tokens // block_size
|
||||
except Exception as e:
|
||||
logger.error(f"update_cache_blocks, error: {type(e)} {e}, {str(traceback.format_exc())}")
|
||||
@@ -692,7 +704,7 @@ class PrefixCacheManager:
|
||||
else:
|
||||
prompt_token_ids = task.prompt_token_ids
|
||||
req_id = task.request_id
|
||||
logger.info(f"request_match_blocks: start to allocate blocks for req_id {req_id}")
|
||||
logger.info(f"request_match_blocks: start to process req {req_id}")
|
||||
input_token_num = len(prompt_token_ids + task.output_token_ids)
|
||||
common_block_ids = []
|
||||
# 1. match block
|
||||
@@ -708,12 +720,14 @@ class PrefixCacheManager:
|
||||
# update matched node info
|
||||
self._update_matched_node_info(req_id, match_block_node, current_time=time.time())
|
||||
|
||||
# 2. prepare cache
|
||||
# allocate gpu cache for matched cpu blocks
|
||||
# 2. prepare cache: allocate gpu cache for matched cpu blocks, wait for data transfer to complete
|
||||
gpu_recv_block_ids = []
|
||||
match_cpu_blocks_num = len(match_cpu_block_ids)
|
||||
if self.can_allocate_gpu_blocks(num_blocks=match_cpu_blocks_num):
|
||||
if match_cpu_blocks_num > 0:
|
||||
logger.debug(
|
||||
f"request_match_blocks: req_id {req_id}, allocate {match_cpu_blocks_num} block to receive cpu cache"
|
||||
)
|
||||
gpu_recv_block_ids = self.allocate_gpu_blocks(match_cpu_blocks_num)
|
||||
if len(gpu_recv_block_ids) > 0:
|
||||
self._prepare_cpu_cache(
|
||||
@@ -746,20 +760,58 @@ class PrefixCacheManager:
|
||||
self.metrics._update_history_hit_metrics()
|
||||
if self.metrics.req_count % 10000 == 0:
|
||||
self.metrics.reset_metrics()
|
||||
logger.info(
|
||||
f"request_match_blocks: request block for req_id {req_id}: common_block_ids {common_block_ids}"
|
||||
)
|
||||
logger.info(f"request_match_blocks: req_id {req_id}, matched_block_ids {common_block_ids}")
|
||||
# set leaf node temporarily, then update it in update_cache_blocks
|
||||
self.req_leaf_map[req_id] = match_block_node
|
||||
self.leaf_req_map[match_block_node].add(req_id)
|
||||
# record request cache info
|
||||
self.cache_info[req_id] = (match_block_node, len(common_block_ids) * block_size)
|
||||
self.cache_info[req_id] = [match_block_node, len(common_block_ids) * block_size]
|
||||
task.cached_block_num = len(common_block_ids)
|
||||
return common_block_ids, matched_token_num, hit_info
|
||||
except Exception as e:
|
||||
logger.error(f"request_match_blocks: request_block_ids: error: {type(e)} {e}")
|
||||
raise e
|
||||
|
||||
def request_match_storage_blocks(self, request, extra_gpu_block_ids):
    """
    Look up the storage backend for cached KV blocks matching this
    request's prompt and prefetch them into the supplied GPU blocks.

    # TODO: merge this function into request_match_blocks

    args:
        request: The request to be processed
        extra_gpu_block_ids: A list of GPU block IDs to be used for fetching the cache
    returns:
        matched_block_ids: A list of block IDs that prefetched cache from storage
    """
    # Nothing to do when no storage backend is configured.
    if self.kvcache_storage_backend is None:
        return []

    req_id = request.request_id
    tokens = request.prompt_token_ids
    block_size = self.cache_config.block_size

    # Start hashing after the tokens already cached locally; each block
    # hash is chained on its predecessor so a key encodes the whole prefix.
    cached_tokens = 0
    prev_key = []
    if req_id in self.cache_info:
        last_node, cached_tokens = self.cache_info[req_id]
        prev_key = [] if last_node.hash_value is None else [last_node.hash_value]

    block_keys = []
    for start in range(cached_tokens, len(tokens) - block_size + 1, block_size):
        key = get_hash_str(tokens[start : start + block_size], prev_key)
        block_keys.append(key)
        prev_key = [key]

    logger.info(f"start prefetch cache from storage, req_id: {req_id}, block num: {len(block_keys)}")
    matched_block_ids = self.issue_prefetch_storage_task(req_id, block_keys, extra_gpu_block_ids)
    logger.info(
        f"finish prefetch cache from storage, req_id: {req_id}, matched block num: {len(matched_block_ids)}"
    )

    return matched_block_ids
|
||||
|
||||
def request_block_ids(self, task, block_size, dec_token_num, *args):
|
||||
"""
|
||||
Allocate blocks for a task.
|
||||
@@ -806,10 +858,7 @@ class PrefixCacheManager:
|
||||
current_time = time.time()
|
||||
self._update_matched_node_info(req_id, match_block_node, current_time)
|
||||
# 2. prepare cache
|
||||
(
|
||||
gpu_recv_block_ids,
|
||||
gpu_extra_block_ids,
|
||||
) = self._prepare_cache(
|
||||
(gpu_recv_block_ids, gpu_extra_block_ids) = self._prepare_cache(
|
||||
req_id,
|
||||
input_ids,
|
||||
block_size,
|
||||
@@ -829,7 +878,6 @@ class PrefixCacheManager:
|
||||
gpu_build_path_block_ids = []
|
||||
|
||||
gpu_build_path_block_ids = gpu_extra_block_ids
|
||||
|
||||
leaf_node = self.build_path(
|
||||
req_id,
|
||||
current_time,
|
||||
@@ -883,6 +931,7 @@ class PrefixCacheManager:
|
||||
with self.request_release_lock:
|
||||
try:
|
||||
req_id = task.request_id
|
||||
keys = []
|
||||
leaf_node = self.req_leaf_map.pop(req_id)
|
||||
if leaf_node in self.leaf_req_map:
|
||||
self.leaf_req_map[leaf_node].remove(req_id)
|
||||
@@ -893,6 +942,7 @@ class PrefixCacheManager:
|
||||
if req_id in node.req_id_set:
|
||||
node.req_id_set.remove(req_id)
|
||||
node.decrement_shared_count()
|
||||
keys.append(node.hash_value)
|
||||
node = node.parent
|
||||
|
||||
if req_id in self.cache_info:
|
||||
@@ -919,6 +969,78 @@ class PrefixCacheManager:
|
||||
logger.error(f"release_block_ids: error: {type(e)} {e}, {str(traceback.format_exc())}")
|
||||
raise e
|
||||
|
||||
def write_cache_to_storage(self, request: Request):
    """
    Persist a finished request's cached KV blocks to the storage backend.
    NOTE: this function does not modify the global params
    """
    if self.kvcache_storage_backend is None:
        return

    req_id = request.request_id

    # Walk from the request's leaf node up to the radix-tree root,
    # collecting block hashes, then flip them into prefix order.
    hash_keys = []
    cursor = self.req_leaf_map[req_id]
    while cursor != self.radix_tree_root:
        hash_keys.append(cursor.hash_value)
        cursor = cursor.parent
    hash_keys = hash_keys[::-1]
    if not hash_keys:
        return

    # One GPU block per collected hash key.
    gpu_block_ids = request.block_tables[: len(hash_keys)]
    logger.info(f"start write cache back to storage, req_id: {req_id}, block num: {len(hash_keys)}")
    start_ts = time.time()
    self.issue_write_back_storage_task(req_id=req_id, hash_keys=hash_keys, gpu_block_ids=gpu_block_ids, is_sync=True)
    cost_time = time.time() - start_ts
    logger.info(f"finish write cache back to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s")
|
||||
|
||||
def issue_write_back_storage_task(self, req_id, hash_keys, gpu_block_ids, is_sync=True, timeout=0.5):
    """
    Enqueue a GPU -> storage write-back task for the given blocks.

    args:
        req_id: ID of the request whose blocks are written back
        hash_keys: block hash keys, one per GPU block
        gpu_block_ids: GPU block IDs holding the data to persist
        is_sync: when True, block until the transfer completes
        timeout: per-task timeout forwarded with the transfer task
    raises:
        ValueError: when hash_keys and gpu_block_ids differ in length
    """
    if self.kvcache_storage_backend is None:
        return

    # Every hash key must map to exactly one GPU block.
    n_keys, n_blocks = len(hash_keys), len(gpu_block_ids)
    if n_keys != n_blocks:
        err_msg = f"write_back_storage error: hash_keys({n_keys}) != gpu_block_ids({n_blocks})"
        logger.error(err_msg)
        raise ValueError(err_msg)

    # Register the completion event before issuing the task so the
    # result handler can always signal it.
    self.task_write_back_event[req_id] = Event()
    self.cache_task_queue.put_transfer_task((CacheStatus.GPU2STORAGE, req_id, hash_keys, gpu_block_ids, timeout))
    if is_sync:
        self.wait_write_storage_task(req_id)
|
||||
|
||||
def wait_write_storage_task(self, req_id):
    """
    Block until the pending write-back task for *req_id* finishes, then
    drop its completion event. No-op when no such task is pending.
    """
    event = self.task_write_back_event.get(req_id)
    if event is not None:
        event.wait()
        del self.task_write_back_event[req_id]
|
||||
|
||||
def issue_prefetch_storage_task(self, req_id, hash_keys, gpu_block_ids, is_sync=True, timeout=0.5):
    """
    Enqueue a storage -> GPU prefetch task for the given block keys.

    Registers a completion event, hands the task to the cache transfer
    manager, and — when is_sync — waits for and returns the prefetched
    block IDs (an empty list otherwise).
    """
    # Register the completion event before issuing the task so the
    # result handler can always signal it.
    self.task_prefetch_event[req_id] = Event()
    # issue task to cache_transfer_manager
    self.cache_task_queue.put_transfer_task((CacheStatus.STORAGE2GPU, req_id, hash_keys, gpu_block_ids, timeout))
    if not is_sync:
        return []
    return self.wait_prefetch_storage_task(req_id)
|
||||
|
||||
def wait_prefetch_storage_task(self, req_id):
    """
    Block until the pending prefetch task for *req_id* completes and
    return the prefetched storage block IDs; None when nothing is pending.
    """
    event = self.task_prefetch_event.get(req_id)
    if event is None:
        return None

    event.wait()
    result = self.storage_prefetch_block_ids[req_id]
    # Clear both bookkeeping entries now that the result is claimed.
    del self.task_prefetch_event[req_id]
    del self.storage_prefetch_block_ids[req_id]
    return result
|
||||
|
||||
def free_nodes_directly(self, node):
|
||||
with self.request_release_lock:
|
||||
try:
|
||||
@@ -1069,10 +1191,7 @@ class PrefixCacheManager:
|
||||
break
|
||||
node = heapq.heappop(self.gpu_lru_leaf_heap)
|
||||
self.gpu_lru_leaf_set.remove(node)
|
||||
if (
|
||||
not self.cache_config.enable_hierarchical_cache
|
||||
or self.cache_config.num_cpu_blocks < need_block_num
|
||||
):
|
||||
if self.cache_config.num_cpu_blocks < need_block_num:
|
||||
if node.shared_count == 0 and node.is_gpu_leaf_node: # 直接回收
|
||||
self._handle_free_gpu_node_without_cpu(node)
|
||||
total_gpu_free_count += 1
|
||||
@@ -1195,12 +1314,6 @@ class PrefixCacheManager:
|
||||
)
|
||||
return total_cpu_free_count
|
||||
|
||||
def cal_block_hash(self, block):
|
||||
"""
|
||||
calculate hash value of a block
|
||||
"""
|
||||
return hash(tuple(block))
|
||||
|
||||
def get_block_hash_extra_keys(self, request, start_idx, end_idx, mm_idx):
|
||||
"""
|
||||
Retrieves additional hash keys for block identification.
|
||||
@@ -1260,16 +1373,6 @@ class PrefixCacheManager:
|
||||
hash_keys.append(mm_inputs["mm_hashes"][img_idx])
|
||||
return len(mm_inputs["mm_positions"]) - 1, hash_keys
|
||||
|
||||
def hash_block_features(self, input_ids, extra_keys: list = []):
|
||||
"""
|
||||
calculate hash value of a block with additional keys
|
||||
|
||||
Args:
|
||||
input_ids: Input token IDs
|
||||
extra_keys: Additional keys for block identification
|
||||
"""
|
||||
return hashlib.sha256(pickle.dumps((input_ids, extra_keys))).hexdigest()
|
||||
|
||||
def _revert_match_blocks(
|
||||
self,
|
||||
request,
|
||||
@@ -1363,6 +1466,7 @@ class PrefixCacheManager:
|
||||
matche_nodes = []
|
||||
has_modified_gpu_lru_leaf_heap = False
|
||||
has_modified_cpu_lru_leaf_heap = False
|
||||
prefix_block_key = []
|
||||
|
||||
with self.cache_status_lock:
|
||||
while match_token_num < total_token_num:
|
||||
@@ -1376,7 +1480,10 @@ class PrefixCacheManager:
|
||||
end_idx=match_token_num + block_size,
|
||||
mm_idx=mm_idx,
|
||||
)
|
||||
hash_value = self.hash_block_features(token_block, extra_keys)
|
||||
prefix_block_key.extend(extra_keys)
|
||||
hash_value = get_hash_str(token_block, prefix_block_key)
|
||||
prefix_block_key = [hash_value]
|
||||
|
||||
if hash_value in current_match_node.children:
|
||||
child = current_match_node.children[hash_value]
|
||||
matche_nodes.append(child)
|
||||
@@ -1476,6 +1583,7 @@ class PrefixCacheManager:
|
||||
matche_nodes = []
|
||||
has_modified_gpu_lru_leaf_heap = False
|
||||
has_modified_cpu_lru_leaf_heap = False
|
||||
prefix_block_key = []
|
||||
|
||||
with self.cache_status_lock:
|
||||
while match_token_num < total_token_num:
|
||||
@@ -1483,7 +1591,8 @@ class PrefixCacheManager:
|
||||
token_num = len(token_block)
|
||||
if token_num != block_size:
|
||||
break
|
||||
hash_value = self.cal_block_hash(token_block)
|
||||
hash_value = get_hash_str(token_block, prefix_block_key)
|
||||
prefix_block_key = [hash_value]
|
||||
if hash_value in current_match_node.children:
|
||||
child = current_match_node.children[hash_value]
|
||||
matche_nodes.append(child)
|
||||
@@ -1515,6 +1624,8 @@ class PrefixCacheManager:
|
||||
swap_node_ids.append(child.node_id)
|
||||
match_token_num = match_token_num + block_size
|
||||
current_match_node = child
|
||||
# record request cache info
|
||||
self.cache_info[req_id] = [child, match_token_num]
|
||||
else:
|
||||
break
|
||||
|
||||
@@ -1577,8 +1688,10 @@ class PrefixCacheManager:
|
||||
has_unfilled_block = False
|
||||
current_time = time.time()
|
||||
|
||||
input_hash_value = self.hash_block_features(input_ids)
|
||||
input_hash_value = get_hash_str(input_ids)
|
||||
gpu_block_ids = request.block_tables[num_cached_tokens // block_size :].copy()
|
||||
prefix_block_key = [] if last_node.hash_value is None else [last_node.hash_value]
|
||||
|
||||
for i in range(num_cached_tokens, can_cache_computed_tokens, block_size):
|
||||
current_block = input_ids[i : i + block_size]
|
||||
current_block_size = len(current_block) # 最后一个block可能没填满
|
||||
@@ -1591,7 +1704,9 @@ class PrefixCacheManager:
|
||||
end_idx=i + block_size,
|
||||
mm_idx=mm_idx,
|
||||
)
|
||||
hash_value = self.hash_block_features(current_block, extra_keys)
|
||||
prefix_block_key.extend(extra_keys)
|
||||
hash_value = get_hash_str(current_block, prefix_block_key)
|
||||
prefix_block_key = [hash_value]
|
||||
allocated_block_id = gpu_block_ids.pop(0)
|
||||
node_id = self.node_id_pool.pop()
|
||||
unique_node_ids.append(node_id)
|
||||
@@ -1651,7 +1766,7 @@ class PrefixCacheManager:
|
||||
gpu_block_ids = gpu_block_ids.copy()
|
||||
node = last_node
|
||||
reverved_dec_block_ids = []
|
||||
input_hash_value = self.cal_block_hash(input_ids)
|
||||
input_hash_value = get_hash_str(input_ids)
|
||||
|
||||
token_num = len(left_input_ids)
|
||||
if token_num == 0:
|
||||
@@ -1663,6 +1778,7 @@ class PrefixCacheManager:
|
||||
unique_node_ids = []
|
||||
new_last_node = last_node
|
||||
has_unfilled_block = False
|
||||
prefix_block_key = [] if last_node.hash_value is None else [last_node.hash_value]
|
||||
|
||||
for i in range(0, token_num, block_size):
|
||||
current_block = left_input_ids[i : i + block_size]
|
||||
@@ -1670,7 +1786,8 @@ class PrefixCacheManager:
|
||||
if current_block_size != block_size:
|
||||
has_unfilled_block = True
|
||||
else:
|
||||
hash_value = self.cal_block_hash(current_block)
|
||||
hash_value = get_hash_str(current_block, prefix_block_key)
|
||||
prefix_block_key = [hash_value]
|
||||
allocated_block_id = gpu_block_ids.pop(0)
|
||||
node_id = self.node_id_pool.pop()
|
||||
unique_node_ids.append(node_id)
|
||||
@@ -1764,28 +1881,47 @@ class PrefixCacheManager:
|
||||
if data is None:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
(
|
||||
swap_node_ids,
|
||||
task_gpu_block_id,
|
||||
task_cpu_block_id,
|
||||
event_type,
|
||||
transfer_task_id,
|
||||
) = data
|
||||
length = len(task_gpu_block_id)
|
||||
for i in range(length):
|
||||
self._handle_swap_result(
|
||||
swap_node_ids[i],
|
||||
task_gpu_block_id[i],
|
||||
task_cpu_block_id[i],
|
||||
event_type = data[0]
|
||||
|
||||
if event_type.value == CacheStatus.STORAGE2GPU.value:
|
||||
logger.info(f"recv_data_transfer_result: {data}")
|
||||
task_id, hash_keys, block_ids = data[1:]
|
||||
if task_id not in self.storage_prefetch_block_ids:
|
||||
self.storage_prefetch_block_ids[task_id] = []
|
||||
saved_block_ids = self.storage_prefetch_block_ids[task_id]
|
||||
saved_block_ids.append(block_ids)
|
||||
if len(saved_block_ids) == self.tensor_parallel_size:
|
||||
self.storage_prefetch_block_ids[task_id] = min(saved_block_ids, key=len)
|
||||
if task_id in self.task_prefetch_event:
|
||||
self.task_prefetch_event[task_id].set()
|
||||
elif event_type.value == CacheStatus.GPU2STORAGE.value:
|
||||
logger.info(f"recv_data_transfer_result: {data}")
|
||||
task_id, hash_keys, block_ids = data[1:]
|
||||
if task_id in self.task_write_back_event:
|
||||
self.task_write_back_event[task_id].set()
|
||||
else:
|
||||
(
|
||||
event_type,
|
||||
transfer_task_id,
|
||||
swap_node_ids,
|
||||
task_gpu_block_id,
|
||||
task_cpu_block_id,
|
||||
) = data
|
||||
length = len(task_gpu_block_id)
|
||||
for i in range(length):
|
||||
self._handle_swap_result(
|
||||
swap_node_ids[i],
|
||||
task_gpu_block_id[i],
|
||||
task_cpu_block_id[i],
|
||||
event_type,
|
||||
)
|
||||
if transfer_task_id in self.task_swapping_event:
|
||||
self.task_swapping_event[transfer_task_id].set()
|
||||
logger.info(
|
||||
f"recv_data_transfer_result: transfer_task_id {transfer_task_id}: "
|
||||
+ f"task_node_ids {swap_node_ids} task_gpu_block_id {task_gpu_block_id} "
|
||||
+ f"task_cpu_block_id {task_cpu_block_id} event_type {event_type} done"
|
||||
)
|
||||
if transfer_task_id in self.task_swapping_event:
|
||||
self.task_swapping_event[transfer_task_id].set()
|
||||
logger.info(
|
||||
f"recv_data_transfer_result: transfer_task_id {transfer_task_id}: "
|
||||
+ f"task_node_ids {swap_node_ids} task_gpu_block_id {task_gpu_block_id} "
|
||||
+ f"task_cpu_block_id {task_cpu_block_id} event_type {event_type} done"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"recv_data_transfer_result: error: {e}, {str(traceback.format_exc())}")
|
||||
raise e
|
||||
|
||||
Reference in New Issue
Block a user