[BugFix] move cache creation back to cache transfer process and adapt clear/update (#6144)

* [fix] move cache creation back to cache transfer process

* [fix] fix clear cache

* [chore] change some log level

* [fix] fix clear cache

* [fix] fix clear cache for blockwisefp8 and mtp

* [fix] fix c8

* [fix] fix clear_mtp_cache args

* [chore] update cache_transfer_manager

* [fix] fix update mtp cache
This commit is contained in:
Yonghua Li
2026-01-24 21:59:13 +08:00
committed by GitHub
parent 976203cf60
commit 833d00e2d7
5 changed files with 158 additions and 71 deletions
+49 -26
View File
@@ -69,6 +69,7 @@ else:
speculate_save_output_topk,
update_attn_mask_offsets,
set_data_ipc,
unset_data_ipc,
)
from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding
@@ -99,6 +100,7 @@ class MTPProposer(Proposer):
self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps
self.enable_logprob = self.model_config.enable_logprob
self.enable_draft_logprob = self.speculative_config.enable_draft_logprob
self.cache_kvs_map = {}
# [mixed, prefill, decoder]
self.role = self.scheduler_config.splitwise_role
@@ -221,8 +223,12 @@ class MTPProposer(Proposer):
# Check if gpu runner needs to create kv cache
# 1. During profiling, it creates its own kv cache.
# 2. GPU runner creates kv cache tensor unless p/d disaggregation is enabled.
create_cache_tensor = profile or self.scheduler_config.splitwise_role == "mixed"
# 2. When not profiling, create the kv cache here only if no cache manager will create it
#    (no CPU cache blocks, no kvcache storage backend, and splitwise role is "mixed").
create_cache_tensor = profile or not (
self.fd_config.cache_config.num_cpu_blocks > 0
or self.fd_config.cache_config.kvcache_storage_backend
or self.fd_config.scheduler_config.splitwise_role != "mixed"
)
if not create_cache_tensor:
logger.info(f"Waiting for cache managers to create kv cache.. {cache_ready_signal.value}")
@@ -245,72 +251,80 @@ class MTPProposer(Proposer):
key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
key_cache = self._share_external_data(key_cache, key_cache_name, key_cache_shape)
self.cache_kvs_map[key_cache_name] = key_cache
cache_kvs_list.append(key_cache)
value_cache = paddle.empty(shape=[], dtype=cache_type)
value_cache = self._share_external_data(value_cache, val_cache_name, value_cache_shape)
self.cache_kvs_map[val_cache_name] = value_cache
cache_kvs_list.append(value_cache)
if kv_cache_quant_type == "block_wise_fp8":
scale_key_cache_name = f"key_cache_scales_{i}_rank{local_rank}.device{self.device_id}"
scale_val_cache_name = f"value_cache_scales_{i}_rank{local_rank}.device{self.device_id}"
key_scale_cache = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
key_scale_cache = share_external_data(key_scale_cache, scale_key_cache_name, kv_cache_scale_shape)
key_scale_cache = self._share_external_data(
key_scale_cache, scale_key_cache_name, kv_cache_scale_shape
)
self.cache_kvs_map[scale_key_cache_name] = key_scale_cache
cache_kvs_list.append(key_scale_cache)
value_scale_cache = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
value_scale_cache = share_external_data(
value_scale_cache = self._share_external_data(
value_scale_cache, scale_val_cache_name, kv_cache_scale_shape
)
self.cache_kvs_map[scale_val_cache_name] = value_scale_cache
cache_kvs_list.append(value_scale_cache)
self.model_inputs["caches"] = cache_kvs_list
else:
cache_kvs_list = []
for i in range(
self.num_main_model_layers,
self.num_main_model_layers + self.model_config.num_hidden_layers,
):
logger.info(f"..creating kv cache for mtp layer {i}: key:{key_cache_shape}, value:{value_cache_shape}")
self.cache_kvs[f"key_caches_{i}"] = paddle.full(
key_cache = paddle.full(
shape=key_cache_shape,
fill_value=0,
dtype=cache_type,
)
set_data_ipc(
self.cache_kvs[f"key_caches_{i}"], f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
)
key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
set_data_ipc(key_cache, key_cache_name)
self.cache_kvs_map[key_cache_name] = key_cache
cache_kvs_list.append(key_cache)
self.cache_kvs[f"value_caches_{i}"] = paddle.full(
val_cache = paddle.full(
shape=value_cache_shape,
fill_value=0,
dtype=cache_type,
)
set_data_ipc(
self.cache_kvs[f"value_caches_{i}"], f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
)
val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
set_data_ipc(val_cache, val_cache_name)
self.cache_kvs_map[val_cache_name] = val_cache
cache_kvs_list.append(val_cache)
if kv_cache_quant_type == "block_wise_fp8":
self.cache_kvs[f"key_cache_scales_{i}"] = paddle.full(
key_cache_scales = paddle.full(
shape=kv_cache_scale_shape,
fill_value=0,
dtype=paddle.get_default_dtype(),
)
set_data_ipc(
self.cache_kvs[f"key_cache_scales_{i}"],
f"key_cache_scales_{i}_rank{local_rank}.device{self.device_id}",
)
key_cache_scales_name = f"key_cache_scales_{i}_rank{local_rank}.device{self.device_id}"
set_data_ipc(key_cache_scales, key_cache_scales_name)
self.cache_kvs_map[key_cache_scales_name] = key_cache_scales
cache_kvs_list.append(key_cache_scales)
self.cache_kvs[f"value_cache_scales_{i}"] = paddle.full(
val_cache_scales = paddle.full(
shape=kv_cache_scale_shape,
fill_value=0,
dtype=paddle.get_default_dtype(),
)
set_data_ipc(
self.cache_kvs[f"value_cache_scales_{i}"],
f"value_cache_scales_{i}_rank{local_rank}.device{self.device_id}",
)
val_cache_scales_name = f"value_cache_scales_{i}_rank{local_rank}.device{self.device_id}"
set_data_ipc(val_cache_scales, val_cache_scales_name)
self.cache_kvs_map[val_cache_scales_name] = val_cache_scales
cache_kvs_list.append(val_cache_scales)
self.model_inputs["caches"] = cache_kvs_list
self.model_inputs["caches"] = list(self.cache_kvs.values())
for value in self.cache_kvs.values():
del value
self._empty_cache()
def _initialize_attn_backend(
@@ -385,10 +399,19 @@ class MTPProposer(Proposer):
)
self.attn_backends.append(attn_backend)
def clear_mtp_cache(self):
def clear_mtp_cache(self, profile=False):
"""
Clear the allocated KV cache.
"""
create_cache_tensor = profile or not (
self.fd_config.cache_config.num_cpu_blocks > 0
or self.fd_config.cache_config.kvcache_storage_backend
or self.fd_config.scheduler_config.splitwise_role != "mixed"
)
if not create_cache_tensor:
for name, tensor in self.cache_kvs_map.items():
unset_data_ipc(tensor, name, True, False)
self.cache_kvs_map.clear()
del self.model_inputs["caches"]
if self.forward_meta is not None:
del self.forward_meta.caches