mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[BugFix] fix num_cpu_blocks computation (#6438)
* [BugFix] fix num_cpu_blocks computation * [fix] fix syntax and log * [fix] pre-commit * [fix] use getattr * [fix] ci test
This commit is contained in:
@@ -48,7 +48,7 @@ from fastdeploy.cache_manager.transfer_factory import (
|
||||
FileStore,
|
||||
MooncakeStore,
|
||||
)
|
||||
from fastdeploy.config import SpeculativeConfig
|
||||
from fastdeploy.config import CacheConfig, SpeculativeConfig
|
||||
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
|
||||
from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.utils import console_logger, get_logger
|
||||
@@ -173,8 +173,8 @@ class CacheTransferManager:
|
||||
|
||||
# compute cache bytes
|
||||
self.cache_dtype = args.cache_dtype
|
||||
self.cache_item_bytes = self._get_cache_item_bytes(self.cache_dtype)
|
||||
self.scale_item_bytes = self._get_cache_item_bytes(paddle.get_default_dtype())
|
||||
self.cache_item_bytes = CacheConfig.get_cache_bytes(self.cache_dtype)
|
||||
self.scale_item_bytes = CacheConfig.get_cache_bytes(paddle.get_default_dtype())
|
||||
self.has_cache_scale = self.cache_dtype == "block_wise_fp8"
|
||||
if self.has_cache_scale:
|
||||
self.cache_scale_shape = [self.num_gpu_blocks, self.head_num, self.block_size]
|
||||
@@ -521,7 +521,7 @@ class CacheTransferManager:
|
||||
value_cache_size = self.value_cache_shape[1] * self.value_cache_shape[2] * self.value_cache_shape[3]
|
||||
else:
|
||||
value_cache_size = 0
|
||||
cache_item_bytes = self._get_cache_item_bytes(self.cache_dtype)
|
||||
cache_item_bytes = CacheConfig.get_cache_bytes(self.cache_dtype)
|
||||
key_need_to_allocate_bytes = args.num_cpu_blocks * cache_item_bytes * key_cache_size
|
||||
value_need_to_allocate_bytes = args.num_cpu_blocks * cache_item_bytes * value_cache_size
|
||||
if args.cache_dtype == "block_wise_fp8":
|
||||
@@ -564,17 +564,6 @@ class CacheTransferManager:
|
||||
logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ swap space (cpu cache) is ready!")
|
||||
self.swap_space_ready_signal.value[self.rank] = 1
|
||||
|
||||
def _get_cache_item_bytes(self, cache_dtype):
|
||||
if cache_dtype == "float32":
|
||||
bytes = 4
|
||||
elif cache_dtype in ("bfloat16", "float16"):
|
||||
bytes = 2
|
||||
elif cache_dtype in ["uint8", "block_wise_fp8"]:
|
||||
bytes = 1
|
||||
else:
|
||||
raise ValueError(f"Unsupported cache dtype: {cache_dtype}")
|
||||
return bytes
|
||||
|
||||
def _run_read_storage(
|
||||
self,
|
||||
task_id: str,
|
||||
|
||||
Reference in New Issue
Block a user