mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[KVCache] support unified cache backend (#4903)
* [Feature] support unified cache backend
* fix
* fix
* fix
* fix
* Update metax_model_runner.py
* fix
* update
* Update test_moba_attention_backend.py

---------

Co-authored-by: ltd0924 <luotingdan@baidu.com>
This commit is contained in:
@@ -1196,11 +1196,11 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
kv_cache_quant_type = self.quant_config.kv_cache_quant_type
|
||||
|
||||
# Get kv cache shape
|
||||
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
|
||||
key_cache_shape, value_cache_shape = self.attn_backends[0].get_kv_cache_shape(
|
||||
max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type
|
||||
)
|
||||
if kv_cache_quant_type == "block_wise_fp8":
|
||||
kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]
|
||||
kv_cache_scale_shape = [key_cache_shape[0], key_cache_shape[1], key_cache_shape[2]]
|
||||
local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
|
||||
|
||||
cache_ready_signal_data = np.zeros(shape=[self.parallel_config.tensor_parallel_size], dtype=np.int32)
|
||||
@@ -1226,21 +1226,16 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
logger.info(f"Initializing kv cache for all layers. {cache_ready_signal.value}")
|
||||
cache_kvs_list = []
|
||||
|
||||
# NOTE:(changwenbin) Determine whether it is Multi-Head Latent Attention,
|
||||
# To rationalize the allocation of kvcache.
|
||||
from fastdeploy import envs
|
||||
|
||||
self.mla_cache = envs.FD_ATTENTION_BACKEND == "MLA_ATTN"
|
||||
for i in range(self.model_config.num_hidden_layers):
|
||||
key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
|
||||
if not self.mla_cache:
|
||||
if value_cache_shape:
|
||||
val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
|
||||
if create_cache_tensor:
|
||||
logger.info(f"..creating kv cache for layer {i}: {kv_cache_shape}")
|
||||
key_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
|
||||
logger.info(f"..creating kv cache for layer {i}: {key_cache_shape} {value_cache_shape}")
|
||||
key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_type)
|
||||
set_data_ipc(key_cache, key_cache_name)
|
||||
if not self.mla_cache:
|
||||
val_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
|
||||
if value_cache_shape:
|
||||
val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_type)
|
||||
set_data_ipc(val_cache, val_cache_name)
|
||||
cache_kvs_list.extend([key_cache, val_cache])
|
||||
else:
|
||||
@@ -1257,12 +1252,12 @@ class MetaxModelRunner(ModelRunnerBase):
|
||||
else:
|
||||
cache_kvs_list.extend([key_cache_scales])
|
||||
else:
|
||||
logger.info(f"..attaching kv cache for layer {i}: {kv_cache_shape}")
|
||||
logger.info(f"..attaching kv cache for layer {i}: {key_cache_shape} {value_cache_shape}")
|
||||
key_cache = paddle.empty(shape=[], dtype=cache_type)
|
||||
key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
|
||||
if not self.mla_cache:
|
||||
key_cache = share_external_data(key_cache, key_cache_name, key_cache_shape)
|
||||
if value_cache_shape:
|
||||
val_cache = paddle.empty(shape=[], dtype=cache_type)
|
||||
val_cache = share_external_data(val_cache, val_cache_name, kv_cache_shape)
|
||||
val_cache = share_external_data(val_cache, val_cache_name, value_cache_shape)
|
||||
cache_kvs_list.extend([key_cache, val_cache])
|
||||
else:
|
||||
cache_kvs_list.extend([key_cache])
|
||||
|
||||
Reference in New Issue
Block a user