mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[BugFix] move cache creation back to cache transfer process and adapt clear/update (#6144)
* [fix] move cache creation back to cache transfer process
* [fix] fix clear cache
* [chore] change some log level
* [fix] fix clear cache
* [fix] fix clear cache for blockwisefp8 and mtp
* [fix] fix c8
* [fix] fix clear_mtp_cache args
* [chore] update cache_transfer_manager
* [fix] fix update mtp cache
This commit is contained in:
@@ -69,6 +69,7 @@ else:
|
||||
speculate_save_output_topk,
|
||||
update_attn_mask_offsets,
|
||||
set_data_ipc,
|
||||
unset_data_ipc,
|
||||
)
|
||||
from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding
|
||||
|
||||
@@ -99,6 +100,7 @@ class MTPProposer(Proposer):
|
||||
self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps
|
||||
self.enable_logprob = self.model_config.enable_logprob
|
||||
self.enable_draft_logprob = self.speculative_config.enable_draft_logprob
|
||||
self.cache_kvs_map = {}
|
||||
|
||||
# [mixed, prefill, decoder]
|
||||
self.role = self.scheduler_config.splitwise_role
|
||||
@@ -221,8 +223,12 @@ class MTPProposer(Proposer):
|
||||
|
||||
# Check if gpu runner needs to create kv cache
|
||||
# 1. During profiling, it creates its own kv cache.
|
||||
# 2. GPU runner creates kv cache tensor unless p/d disaggregation is enabled.
|
||||
create_cache_tensor = profile or self.scheduler_config.splitwise_role == "mixed"
|
||||
# 2. If no need to profile, create kv cache if cache managers do not exist.
|
||||
create_cache_tensor = profile or not (
|
||||
self.fd_config.cache_config.num_cpu_blocks > 0
|
||||
or self.fd_config.cache_config.kvcache_storage_backend
|
||||
or self.fd_config.scheduler_config.splitwise_role != "mixed"
|
||||
)
|
||||
|
||||
if not create_cache_tensor:
|
||||
logger.info(f"Waiting for cache managers to create kv cache.. {cache_ready_signal.value}")
|
||||
@@ -245,72 +251,80 @@ class MTPProposer(Proposer):
|
||||
key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
|
||||
val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
|
||||
key_cache = self._share_external_data(key_cache, key_cache_name, key_cache_shape)
|
||||
self.cache_kvs_map[key_cache_name] = key_cache
|
||||
cache_kvs_list.append(key_cache)
|
||||
value_cache = paddle.empty(shape=[], dtype=cache_type)
|
||||
value_cache = self._share_external_data(value_cache, val_cache_name, value_cache_shape)
|
||||
self.cache_kvs_map[val_cache_name] = value_cache
|
||||
cache_kvs_list.append(value_cache)
|
||||
|
||||
if kv_cache_quant_type == "block_wise_fp8":
|
||||
scale_key_cache_name = f"key_cache_scales_{i}_rank{local_rank}.device{self.device_id}"
|
||||
scale_val_cache_name = f"value_cache_scales_{i}_rank{local_rank}.device{self.device_id}"
|
||||
key_scale_cache = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
|
||||
key_scale_cache = share_external_data(key_scale_cache, scale_key_cache_name, kv_cache_scale_shape)
|
||||
key_scale_cache = self._share_external_data(
|
||||
key_scale_cache, scale_key_cache_name, kv_cache_scale_shape
|
||||
)
|
||||
self.cache_kvs_map[scale_key_cache_name] = key_scale_cache
|
||||
cache_kvs_list.append(key_scale_cache)
|
||||
value_scale_cache = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
|
||||
value_scale_cache = share_external_data(
|
||||
value_scale_cache = self._share_external_data(
|
||||
value_scale_cache, scale_val_cache_name, kv_cache_scale_shape
|
||||
)
|
||||
self.cache_kvs_map[scale_val_cache_name] = value_scale_cache
|
||||
cache_kvs_list.append(value_scale_cache)
|
||||
|
||||
self.model_inputs["caches"] = cache_kvs_list
|
||||
else:
|
||||
cache_kvs_list = []
|
||||
for i in range(
|
||||
self.num_main_model_layers,
|
||||
self.num_main_model_layers + self.model_config.num_hidden_layers,
|
||||
):
|
||||
logger.info(f"..creating kv cache for mtp layer {i}: key:{key_cache_shape}, value:{value_cache_shape}")
|
||||
self.cache_kvs[f"key_caches_{i}"] = paddle.full(
|
||||
key_cache = paddle.full(
|
||||
shape=key_cache_shape,
|
||||
fill_value=0,
|
||||
dtype=cache_type,
|
||||
)
|
||||
set_data_ipc(
|
||||
self.cache_kvs[f"key_caches_{i}"], f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
|
||||
)
|
||||
key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
|
||||
set_data_ipc(key_cache, key_cache_name)
|
||||
self.cache_kvs_map[key_cache_name] = key_cache
|
||||
cache_kvs_list.append(key_cache)
|
||||
|
||||
self.cache_kvs[f"value_caches_{i}"] = paddle.full(
|
||||
val_cache = paddle.full(
|
||||
shape=value_cache_shape,
|
||||
fill_value=0,
|
||||
dtype=cache_type,
|
||||
)
|
||||
set_data_ipc(
|
||||
self.cache_kvs[f"value_caches_{i}"], f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
|
||||
)
|
||||
val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
|
||||
set_data_ipc(val_cache, val_cache_name)
|
||||
self.cache_kvs_map[val_cache_name] = val_cache
|
||||
cache_kvs_list.append(val_cache)
|
||||
|
||||
if kv_cache_quant_type == "block_wise_fp8":
|
||||
self.cache_kvs[f"key_cache_scales_{i}"] = paddle.full(
|
||||
key_cache_scales = paddle.full(
|
||||
shape=kv_cache_scale_shape,
|
||||
fill_value=0,
|
||||
dtype=paddle.get_default_dtype(),
|
||||
)
|
||||
set_data_ipc(
|
||||
self.cache_kvs[f"key_cache_scales_{i}"],
|
||||
f"key_cache_scales_{i}_rank{local_rank}.device{self.device_id}",
|
||||
)
|
||||
key_cache_scales_name = f"key_cache_scales_{i}_rank{local_rank}.device{self.device_id}"
|
||||
set_data_ipc(key_cache_scales, key_cache_scales_name)
|
||||
self.cache_kvs_map[key_cache_scales_name] = key_cache_scales
|
||||
cache_kvs_list.append(key_cache_scales)
|
||||
|
||||
self.cache_kvs[f"value_cache_scales_{i}"] = paddle.full(
|
||||
val_cache_scales = paddle.full(
|
||||
shape=kv_cache_scale_shape,
|
||||
fill_value=0,
|
||||
dtype=paddle.get_default_dtype(),
|
||||
)
|
||||
set_data_ipc(
|
||||
self.cache_kvs[f"value_cache_scales_{i}"],
|
||||
f"value_cache_scales_{i}_rank{local_rank}.device{self.device_id}",
|
||||
)
|
||||
val_cache_scales_name = f"value_cache_scales_{i}_rank{local_rank}.device{self.device_id}"
|
||||
set_data_ipc(val_cache_scales, val_cache_scales_name)
|
||||
self.cache_kvs_map[val_cache_scales_name] = val_cache_scales
|
||||
cache_kvs_list.append(val_cache_scales)
|
||||
|
||||
self.model_inputs["caches"] = cache_kvs_list
|
||||
|
||||
self.model_inputs["caches"] = list(self.cache_kvs.values())
|
||||
for value in self.cache_kvs.values():
|
||||
del value
|
||||
self._empty_cache()
|
||||
|
||||
def _initialize_attn_backend(
|
||||
@@ -385,10 +399,19 @@ class MTPProposer(Proposer):
|
||||
)
|
||||
self.attn_backends.append(attn_backend)
|
||||
|
||||
def clear_mtp_cache(self):
|
||||
def clear_mtp_cache(self, profile=False):
|
||||
"""
|
||||
Clear allocated cacheKV
|
||||
"""
|
||||
create_cache_tensor = profile or not (
|
||||
self.fd_config.cache_config.num_cpu_blocks > 0
|
||||
or self.fd_config.cache_config.kvcache_storage_backend
|
||||
or self.fd_config.scheduler_config.splitwise_role != "mixed"
|
||||
)
|
||||
if not create_cache_tensor:
|
||||
for name, tensor in self.cache_kvs_map.items():
|
||||
unset_data_ipc(tensor, name, True, False)
|
||||
self.cache_kvs_map.clear()
|
||||
del self.model_inputs["caches"]
|
||||
if self.forward_meta is not None:
|
||||
del self.forward_meta.caches
|
||||
|
||||
Reference in New Issue
Block a user