mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[KVCache] support unified cache backend (#4903)
* [Feature] support unified cache backend * fix * fix * fix * fix * Update metax_model_runner.py * fix * update * Update test_moba_attention_backend.py --------- Co-authored-by: ltd0924 <luotingdan@baidu.com>
This commit is contained in:
@@ -63,7 +63,7 @@ class PrefixCacheManager:
|
||||
else:
|
||||
self.enable_splitwise = 0
|
||||
self.splitwise_role = splitwise_role
|
||||
|
||||
self.config = config
|
||||
self.cache_config = config.cache_config
|
||||
self.speculative_config = config.speculative_config
|
||||
self.local_data_parallel_id = local_data_parallel_id
|
||||
@@ -82,6 +82,8 @@ class PrefixCacheManager:
|
||||
heapq.heapify(self.gpu_free_block_list)
|
||||
heapq.heapify(self.cpu_free_block_list)
|
||||
|
||||
self.key_cache_shape = []
|
||||
self.val_cache_shape = []
|
||||
self.node_id_pool = list(range(self.num_gpu_blocks + self.num_cpu_blocks))
|
||||
|
||||
self.radix_tree_root = BlockNode(-1, [], 0, 0, -1, 0, None, None, None)
|
||||
@@ -120,6 +122,39 @@ class PrefixCacheManager:
|
||||
main_process_metrics.free_gpu_block_num.set(self.num_gpu_blocks)
|
||||
main_process_metrics.available_gpu_resource.set(1.0)
|
||||
|
||||
def _get_kv_cache_shape(self, max_block_num):
|
||||
from fastdeploy.model_executor.layers.attention import get_attention_backend
|
||||
|
||||
attn_cls = get_attention_backend()
|
||||
num_heads = self.config.model_config.num_attention_heads // self.config.parallel_config.tensor_parallel_size
|
||||
kv_num_heads = max(
|
||||
1,
|
||||
int(self.config.model_config.num_key_value_heads) // self.config.parallel_config.tensor_parallel_size,
|
||||
)
|
||||
head_dim = self.config.model_config.head_dim
|
||||
|
||||
kv_cache_quant_type = None
|
||||
if (
|
||||
self.config.quant_config
|
||||
and hasattr(self.config.quant_config, "kv_cache_quant_type")
|
||||
and self.config.quant_config.kv_cache_quant_type is not None
|
||||
):
|
||||
kv_cache_quant_type = self.config.quant_config.kv_cache_quant_type
|
||||
|
||||
# Initialize AttentionBackend buffers
|
||||
encoder_block_shape_q = 64
|
||||
decoder_block_shape_q = 16
|
||||
key_cache_shape, value_cache_shape = attn_cls(
|
||||
self.config,
|
||||
kv_num_heads=kv_num_heads,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
encoder_block_shape_q=encoder_block_shape_q,
|
||||
decoder_block_shape_q=decoder_block_shape_q,
|
||||
).get_kv_cache_shape(max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type)
|
||||
logger.info(f"key_cache_shape {key_cache_shape} value_cache_shape {value_cache_shape}")
|
||||
return key_cache_shape, value_cache_shape
|
||||
|
||||
@property
|
||||
def available_gpu_resource(self):
|
||||
return len(self.gpu_free_block_list) / self.num_gpu_blocks if self.num_gpu_blocks > 0 else 0.0
|
||||
@@ -161,11 +196,17 @@ class PrefixCacheManager:
|
||||
py_path = os.path.join(current_dir_path, filename)
|
||||
|
||||
cache_messager_processes = []
|
||||
key_cache_shape, val_cache_shape = self._get_kv_cache_shape(cache_config.total_block_num)
|
||||
key_cache_shape = ",".join([str(i) for i in key_cache_shape])
|
||||
val_cache_shape = ",".join([str(i) for i in val_cache_shape])
|
||||
logger.info(f"key_cache_shape {key_cache_shape} value_cache_shape {val_cache_shape}")
|
||||
if self.enable_splitwise:
|
||||
cache_messager_processes = self.launch_cache_messager(
|
||||
cache_config,
|
||||
tensor_parallel_size,
|
||||
device_ids,
|
||||
key_cache_shape,
|
||||
val_cache_shape,
|
||||
pod_ip,
|
||||
engine_worker_queue_port,
|
||||
pid_suffix,
|
||||
@@ -174,17 +215,6 @@ class PrefixCacheManager:
|
||||
raise RuntimeError("Launch cache messager failed")
|
||||
return []
|
||||
|
||||
if (
|
||||
hasattr(cache_config.model_cfg, "num_key_value_heads")
|
||||
and hasattr(cache_config.model_cfg, "num_key_value_heads")
|
||||
and cache_config.model_cfg.num_key_value_heads is not None
|
||||
and int(cache_config.model_cfg.num_key_value_heads) > 0
|
||||
):
|
||||
kv_num_head = int(cache_config.model_cfg.num_key_value_heads) // tensor_parallel_size
|
||||
else:
|
||||
kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size
|
||||
kv_num_head = max(1, kv_num_head)
|
||||
|
||||
cache_ready_signal_data = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
|
||||
self.cache_ready_signal = IPCSignal(
|
||||
name="cache_ready_signal",
|
||||
@@ -223,18 +253,15 @@ class PrefixCacheManager:
|
||||
+ f" --rank {i}"
|
||||
+ f" --splitwise_role {self.splitwise_role}"
|
||||
+ f" --num_layers {cache_config.model_cfg.num_hidden_layers}"
|
||||
+ f" --head_dim {cache_config.model_cfg.head_dim}"
|
||||
+ f" --kv_num_head {kv_num_head}"
|
||||
+ f" --mp_num {tensor_parallel_size}"
|
||||
+ f" --cache_dtype {cache_config.cache_dtype}"
|
||||
+ f" --key_cache_shape {key_cache_shape}"
|
||||
+ f" --value_cache_shape {val_cache_shape}"
|
||||
+ f" --cache_queue_port {cache_config.cache_queue_port}"
|
||||
+ f" --enable_splitwise {int(self.enable_splitwise)}"
|
||||
+ f" --pod_ip {pod_ip}"
|
||||
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
|
||||
+ f" --num_gpu_blocks {cache_config.total_block_num}"
|
||||
+ f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
|
||||
+ f" --bytes_per_layer_per_block {cache_config.bytes_per_layer_per_block}"
|
||||
+ f" --block_size {cache_config.block_size}"
|
||||
+ f" --engine_pid {pid_suffix}"
|
||||
+ f" --protocol {cache_config.cache_transfer_protocol}"
|
||||
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
|
||||
@@ -273,22 +300,21 @@ class PrefixCacheManager:
|
||||
return all_cache_processes
|
||||
|
||||
def launch_cache_messager(
|
||||
self, cache_config, tensor_parallel_size, device_ids, pod_ip, engine_worker_queue_port, pid_suffix
|
||||
self,
|
||||
cache_config,
|
||||
tensor_parallel_size,
|
||||
device_ids,
|
||||
key_cache_shape,
|
||||
value_cache_shape,
|
||||
pod_ip,
|
||||
engine_worker_queue_port,
|
||||
pid_suffix,
|
||||
):
|
||||
"""
|
||||
launch_cache_messager function used to initialize the cache messager.
|
||||
"""
|
||||
current_dir_path = os.path.split(os.path.abspath(__file__))[0]
|
||||
filename = "cache_messager.py"
|
||||
if (
|
||||
hasattr(cache_config.model_cfg, "num_key_value_heads")
|
||||
and hasattr(cache_config.model_cfg, "num_key_value_heads")
|
||||
and cache_config.model_cfg.num_key_value_heads is not None
|
||||
and int(cache_config.model_cfg.num_key_value_heads) > 0
|
||||
):
|
||||
kv_num_head = int(cache_config.model_cfg.num_key_value_heads) // tensor_parallel_size
|
||||
else:
|
||||
kv_num_head = cache_config.model_cfg.num_attention_heads // tensor_parallel_size
|
||||
|
||||
cache_ready_signal_data = np.zeros(shape=[tensor_parallel_size], dtype=np.int32)
|
||||
self.cache_ready_signal = IPCSignal(
|
||||
@@ -311,15 +337,13 @@ class PrefixCacheManager:
|
||||
+ f" --rank {i}"
|
||||
+ f" --splitwise_role {self.splitwise_role}"
|
||||
+ f" --num_layers {cache_config.model_cfg.num_hidden_layers}"
|
||||
+ f" --head_dim {cache_config.model_cfg.head_dim}"
|
||||
+ f" --kv_num_head {kv_num_head}"
|
||||
+ f" --mp_num {tensor_parallel_size}"
|
||||
+ f" --cache_dtype {cache_config.cache_dtype}"
|
||||
+ f" --key_cache_shape {key_cache_shape}"
|
||||
+ f" --value_cache_shape {value_cache_shape}"
|
||||
+ f" --pod_ip {pod_ip}"
|
||||
+ f" --cache_queue_port {cache_config.cache_queue_port}"
|
||||
+ f" --engine_worker_queue_port {engine_worker_queue_port}"
|
||||
+ f" --num_gpu_blocks {cache_config.total_block_num}"
|
||||
+ f" --block_size {cache_config.block_size}"
|
||||
+ f" --protocol {cache_config.cache_transfer_protocol}"
|
||||
+ f" --local_data_parallel_id {self.local_data_parallel_id}"
|
||||
+ f" --engine_pid {pid_suffix}"
|
||||
|
||||
Reference in New Issue
Block a user