[KVCache] support unified cache backend (#4903)

* [Feature] support unified cache backend

* fix

* fix

* fix

* fix

* Update metax_model_runner.py

* fix

* update

* Update test_moba_attention_backend.py

---------

Co-authored-by: ltd0924 <luotingdan@baidu.com>
This commit is contained in:
ltd0924
2025-11-12 14:54:52 +08:00
committed by GitHub
parent 76e60e98f8
commit 5bf48de999
19 changed files with 281 additions and 202 deletions
+11 -16
View File
@@ -1196,11 +1196,11 @@ class MetaxModelRunner(ModelRunnerBase):
kv_cache_quant_type = self.quant_config.kv_cache_quant_type
# Get kv cache shape
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
key_cache_shape, value_cache_shape = self.attn_backends[0].get_kv_cache_shape(
max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type
)
if kv_cache_quant_type == "block_wise_fp8":
kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]
kv_cache_scale_shape = [key_cache_shape[0], key_cache_shape[1], key_cache_shape[2]]
local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
cache_ready_signal_data = np.zeros(shape=[self.parallel_config.tensor_parallel_size], dtype=np.int32)
@@ -1226,21 +1226,16 @@ class MetaxModelRunner(ModelRunnerBase):
logger.info(f"Initializing kv cache for all layers. {cache_ready_signal.value}")
cache_kvs_list = []
# NOTE:(changwenbin) Determine whether it is Multi-Head Latent Attention,
# To rationalize the allocation of kvcache.
from fastdeploy import envs
self.mla_cache = envs.FD_ATTENTION_BACKEND == "MLA_ATTN"
for i in range(self.model_config.num_hidden_layers):
key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
if not self.mla_cache:
if value_cache_shape:
val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
if create_cache_tensor:
logger.info(f"..creating kv cache for layer {i}: {kv_cache_shape}")
key_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
logger.info(f"..creating kv cache for layer {i}: {key_cache_shape} {value_cache_shape}")
key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_type)
set_data_ipc(key_cache, key_cache_name)
if not self.mla_cache:
val_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type)
if value_cache_shape:
val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_type)
set_data_ipc(val_cache, val_cache_name)
cache_kvs_list.extend([key_cache, val_cache])
else:
@@ -1257,12 +1252,12 @@ class MetaxModelRunner(ModelRunnerBase):
else:
cache_kvs_list.extend([key_cache_scales])
else:
logger.info(f"..attaching kv cache for layer {i}: {kv_cache_shape}")
logger.info(f"..attaching kv cache for layer {i}: {key_cache_shape} {value_cache_shape}")
key_cache = paddle.empty(shape=[], dtype=cache_type)
key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
if not self.mla_cache:
key_cache = share_external_data(key_cache, key_cache_name, key_cache_shape)
if value_cache_shape:
val_cache = paddle.empty(shape=[], dtype=cache_type)
val_cache = share_external_data(val_cache, val_cache_name, kv_cache_shape)
val_cache = share_external_data(val_cache, val_cache_name, value_cache_shape)
cache_kvs_list.extend([key_cache, val_cache])
else:
cache_kvs_list.extend([key_cache])