diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index ed3524ed24..78daf35d09 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -48,7 +48,7 @@ from fastdeploy.cache_manager.transfer_factory import ( FileStore, MooncakeStore, ) -from fastdeploy.config import SpeculativeConfig +from fastdeploy.config import CacheConfig, SpeculativeConfig from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus from fastdeploy.platforms import current_platform from fastdeploy.utils import console_logger, get_logger @@ -173,8 +173,8 @@ class CacheTransferManager: # compute cache bytes self.cache_dtype = args.cache_dtype - self.cache_item_bytes = self._get_cache_item_bytes(self.cache_dtype) - self.scale_item_bytes = self._get_cache_item_bytes(paddle.get_default_dtype()) + self.cache_item_bytes = CacheConfig.get_cache_bytes(self.cache_dtype) + self.scale_item_bytes = CacheConfig.get_cache_bytes(paddle.get_default_dtype()) self.has_cache_scale = self.cache_dtype == "block_wise_fp8" if self.has_cache_scale: self.cache_scale_shape = [self.num_gpu_blocks, self.head_num, self.block_size] @@ -521,7 +521,7 @@ class CacheTransferManager: value_cache_size = self.value_cache_shape[1] * self.value_cache_shape[2] * self.value_cache_shape[3] else: value_cache_size = 0 - cache_item_bytes = self._get_cache_item_bytes(self.cache_dtype) + cache_item_bytes = CacheConfig.get_cache_bytes(self.cache_dtype) key_need_to_allocate_bytes = args.num_cpu_blocks * cache_item_bytes * key_cache_size value_need_to_allocate_bytes = args.num_cpu_blocks * cache_item_bytes * value_cache_size if args.cache_dtype == "block_wise_fp8": @@ -564,17 +564,6 @@ class CacheTransferManager: logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ swap space (cpu cache) is ready!") self.swap_space_ready_signal.value[self.rank] = 1 - def _get_cache_item_bytes(self, cache_dtype): - if cache_dtype == "float32": - bytes = 4 - elif cache_dtype in ("bfloat16", "float16"): - bytes = 2 - elif cache_dtype in ["uint8", "block_wise_fp8"]: - bytes = 1 - else: - raise ValueError(f"Unsupported cache dtype: {cache_dtype}") - return bytes - def _run_read_storage( self, task_id: str, diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index f30274ceae..577901a7cd 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -124,8 +124,9 @@ class PrefixCacheManager: self.cache_status_lock = Lock() logger.info( - f"num_gpu_blocks_server_owned {self.num_gpu_blocks} num_cpu_blocks " - + f"{self.num_cpu_blocks}, bytes_per_layer_per_block {self.cache_config.bytes_per_layer_per_block}" + f"Prefix cache manager is initialized with {self.num_gpu_blocks} gpu blocks " + f"and {self.num_cpu_blocks} cpu blocks, bytes_per_token_per_layer for each rank: " + f"{self.cache_config.bytes_per_token_per_layer / self.config.parallel_config.tensor_parallel_size}" ) main_process_metrics.max_gpu_block_num.set(self.num_gpu_blocks) diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 49068366ba..a49c3e4cba 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1393,6 +1393,7 @@ class CacheConfig: self.kvcache_storage_backend = None self.write_policy = None self.num_cpu_blocks = None + self.use_mla_cache = envs.FD_ATTENTION_BACKEND == "MLA_ATTN" for key, value in args.items(): if hasattr(self, key): @@ -1407,35 +1408,16 @@ class CacheConfig: self.cache_dtype = self.model_cfg.quantization.get("kv_cache_quant_type", self.cache_dtype) if self.model_cfg.quantization_config is not None: self.cache_dtype = self.model_cfg.quantization_config.get("kv_cache_quant_type", self.cache_dtype) - if ( - hasattr(self.model_cfg, "num_key_value_heads") - and hasattr(self.model_cfg, "num_key_value_heads") - and self.model_cfg.num_key_value_heads is not None - and int(self.model_cfg.num_key_value_heads) > 0 - ): - kv_num_head = int(self.model_cfg.num_key_value_heads) - else: - kv_num_head = self.model_cfg.num_attention_heads - self.model_cfg.kv_num_head = kv_num_head - # TODO check name - if "int4" in self.cache_dtype.lower() or "float4" in self.cache_dtype.lower(): - byte_size = 0.5 - self.cache_dtype = "uint8" - elif "int8" in self.cache_dtype.lower() or "float8" in self.cache_dtype.lower(): - self.cache_dtype = "uint8" - byte_size = 1 - else: - byte_size = 2 - self.each_token_cache_space = int( - self.model_cfg.num_hidden_layers * kv_num_head * self.model_cfg.head_dim * byte_size + self.head_num = getattr(self.model_cfg, "num_key_value_heads", None) or getattr( + self.model_cfg, "num_attention_heads", None ) - self.bytes_per_block = int(self.each_token_cache_space * self.block_size) - self.bytes_per_layer_per_block = int( - self.block_size - * self.model_cfg.kv_num_head - * self.model_cfg.head_dim - // args["tensor_parallel_size"] - * byte_size + self.head_dim = getattr(self.model_cfg, "head_dim") + self.byte_size = self.get_cache_bytes(self.cache_dtype) + self.kv_factor = 1 if self.use_mla_cache else 2 + + self.bytes_per_token_per_layer = int(self.head_num * self.head_dim * self.byte_size * self.kv_factor) + self.bytes_per_block = int( + self.bytes_per_token_per_layer * self.block_size * self.model_cfg.num_hidden_layers ) if self.num_cpu_blocks is None: @@ -1446,6 +1428,19 @@ class CacheConfig: self._verify_args() + @staticmethod + def get_cache_bytes(cache_dtype: str): + if any(t in cache_dtype.lower() for t in ["float32", "fp32"]): + return 4 + elif any(t in cache_dtype.lower() for t in ["float16", "bf16", "fp16"]): + return 2 + elif any(t in cache_dtype.lower() for t in ["uint8", "int8", "float8", "fp8"]): + return 1 + elif any(t in cache_dtype.lower() for t in ["int4"]): + return 0.5 + else: + raise ValueError(f"Unsupported cache dtype: {cache_dtype}") + def metrics_info(self): """Convert cache_config to dict(key: str, value: str) for prometheus metrics info.""" return {key: str(value) for key, value in self.__dict__.items()} @@ -1690,7 +1685,6 @@ class FDConfig: self.quant_config: Optional[QuantConfigBase] = quant_config self.graph_opt_config: Optional[GraphOptimizationConfig] = graph_opt_config self.early_stop_config: Optional[EarlyStopConfig] = early_stop_config - self.cache_config: CacheConfig = cache_config # type: ignore self.plas_attention_config: Optional[PlasAttentionConfig] = plas_attention_config self.structured_outputs_config: StructuredOutputsConfig = structured_outputs_config self.router_config: RouterConfig = router_config diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index e4ab707f62..7fa1c6f5db 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1533,12 +1533,6 @@ class GPUModelRunner(ModelRunnerBase): logger.info(f"Initializing kv cache for all layers. {cache_ready_signal.value}") cache_kvs_list = [] - - # NOTE:(changwenbin) Determine whether it is Multi-Head Latent Attention, - # To rationalize the allocation of kvcache. - from fastdeploy import envs - - self.mla_cache = envs.FD_ATTENTION_BACKEND == "MLA_ATTN" for i in range(self.model_config.num_hidden_layers): # init key cache key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}" @@ -2748,7 +2742,7 @@ class GPUModelRunner(ModelRunnerBase): # NOTE:(changwenbin) Determie whether it is Multi-Head Latent Attention, # To rationalize the allocation of kvcache. - if self.mla_cache: + if self.fd_config.cache_config.use_mla_cache: required_memory = ( byte_of_dtype * (self.fd_config.model_config.kv_lora_rank + self.fd_config.model_config.qk_rope_head_dim) diff --git a/tests/cache_manager/test_cache_transfer_manager.py b/tests/cache_manager/test_cache_transfer_manager.py index 15585bf6fe..c4f4dea6a8 100644 --- a/tests/cache_manager/test_cache_transfer_manager.py +++ b/tests/cache_manager/test_cache_transfer_manager.py @@ -310,12 +310,6 @@ class TestCacheTransferManager(unittest.TestCase): # ========================== # 工具函数与存储相关测试 # ========================== - def test_get_cache_bytes_and_invalid(self): - self.assertEqual(self.manager._get_cache_item_bytes("bfloat16"), 2) - self.assertEqual(self.manager._get_cache_item_bytes("float32"), 4) - with self.assertRaises(ValueError): - self.manager._get_cache_item_bytes("int32") - def test_run_read_storage_swaps_valid_blocks(self): self.manager.storage_backend = MagicMock() self.manager.storage_backend_type = "mooncake" diff --git a/tests/cache_manager/test_prefix_cache_manager.py b/tests/cache_manager/test_prefix_cache_manager.py index 6f095c4f39..58cb6a076a 100644 --- a/tests/cache_manager/test_prefix_cache_manager.py +++ b/tests/cache_manager/test_prefix_cache_manager.py @@ -185,6 +185,7 @@ def _create_manager( local_rdma_comm_ports=None, kvcache_storage_backend=None, write_policy="write_through", + bytes_per_token_per_layer=2048, swap_space=4, ) model_config = SimpleNamespace( diff --git a/tests/utils/test_config.py b/tests/utils/test_config.py index 0d2581d94b..240cf702ed 100644 --- a/tests/utils/test_config.py +++ b/tests/utils/test_config.py @@ -196,6 +196,131 @@ class TestConfig(unittest.TestCase): ] ) + def test_fdconfig_get_cache_bytes(self): + """Test CacheConfig.get_cache_bytes static method for various dtypes.""" + # Test float32/fp32 variants + for dtype in ["float32", "fp32"]: + assert CacheConfig.get_cache_bytes(dtype) == 4 + + # Test float16/bf16/fp16 variants + for dtype in ["float16", "bf16", "fp16"]: + assert CacheConfig.get_cache_bytes(dtype) == 2 + + # Test 8-bit types + for dtype in ["uint8", "int8", "float8", "fp8"]: + assert CacheConfig.get_cache_bytes(dtype) == 1 + + # Test int4 + assert CacheConfig.get_cache_bytes("int4") == 0.5 + + # Test unsupported dtype raises ValueError + with self.assertRaises(ValueError) as ctx: + CacheConfig.get_cache_bytes("bf11") + assert "Unsupported cache dtype" in str(ctx.exception) + + def test_fdconfig_num_cpu_blocks(self): + """Test num_cpu_blocks calculation with swap_space.""" + # Create mock model config with required attributes + model_config = Mock() + model_config.num_key_value_heads = 32 + model_config.num_attention_heads = 32 + model_config.head_dim = 128 + model_config.num_hidden_layers = 24 + model_config.quantization = None + model_config.quantization_config = None + + # Test case 1: swap_space is None -> num_cpu_blocks = 0 + cache_config = CacheConfig( + { + "model_cfg": model_config, + "cache_dtype": "bfloat16", + "swap_space": None, + } + ) + assert cache_config.num_cpu_blocks == 0 + + # Test case 2: swap_space = 1GB + # bytes_per_block = head_num * head_dim * byte_size * kv_factor * block_size * num_hidden_layers + # = 32 * 128 * 2 * 2 * 64 * 24 = 25165824 bytes + # num_cpu_blocks = 1 * 1024^3 / 25165824 = 42 + cache_config = CacheConfig( + { + "model_cfg": model_config, + "cache_dtype": "bfloat16", + "swap_space": 1, + } + ) + expected_blocks = int(1 * 1024**3 / (32 * 128 * 2 * 2 * 64 * 24)) + assert cache_config.num_cpu_blocks == expected_blocks + assert cache_config.num_cpu_blocks == 42 + + # Test case 3: swap_space = 2GB + cache_config = CacheConfig( + { + "model_cfg": model_config, + "cache_dtype": "bfloat16", + "swap_space": 2, + } + ) + assert cache_config.num_cpu_blocks == 85 + + # Test case 4: with fp32 dtype (4 bytes) + cache_config = CacheConfig( + { + "model_cfg": model_config, + "cache_dtype": "float32", + "swap_space": 1, + } + ) + expected_blocks = int(1 * 1024**3 / (32 * 128 * 4 * 2 * 64 * 24)) + assert cache_config.num_cpu_blocks == expected_blocks + assert cache_config.num_cpu_blocks == 21 + + # Test case 5: with int8 dtype (1 byte) + cache_config = CacheConfig( + { + "model_cfg": model_config, + "cache_dtype": "int8", + "swap_space": 1, + } + ) + expected_blocks = int(1 * 1024**3 / (32 * 128 * 1 * 2 * 64 * 24)) + assert cache_config.num_cpu_blocks == expected_blocks + assert cache_config.num_cpu_blocks == 85 + + # Test case 6: num_cpu_blocks is explicitly set (not affected by swap_space) + cache_config = CacheConfig( + { + "model_cfg": model_config, + "cache_dtype": "bfloat16", + "swap_space": 10, + "num_cpu_blocks": 100, + } + ) + assert cache_config.num_cpu_blocks == 100 + + # Test case 7: with num_key_value_heads (GQA) + model_config_with_gqa = Mock() + model_config_with_gqa.num_key_value_heads = 8 # GQA + model_config_with_gqa.num_attention_heads = 32 + model_config_with_gqa.head_dim = 128 + model_config_with_gqa.num_hidden_layers = 24 + model_config_with_gqa.quantization = None + model_config_with_gqa.quantization_config = None + + cache_config = CacheConfig( + { + "model_cfg": model_config_with_gqa, + "cache_dtype": "bfloat16", + "swap_space": 1, + } + ) + # bytes_per_block = 8 * 128 * 2 * 2 * 64 * 24 = 6291456 bytes + # num_cpu_blocks = 1 * 1024^3 / 6291456 = 170 + expected_blocks = int(1 * 1024**3 / (8 * 128 * 2 * 2 * 64 * 24)) + assert cache_config.num_cpu_blocks == expected_blocks + assert cache_config.num_cpu_blocks == 170 + if __name__ == "__main__": unittest.main() diff --git a/tests/v1/cache_manager/test_prefix_cache.py b/tests/v1/cache_manager/test_prefix_cache.py index 5362276af7..8dccd0e500 100644 --- a/tests/v1/cache_manager/test_prefix_cache.py +++ b/tests/v1/cache_manager/test_prefix_cache.py @@ -35,7 +35,8 @@ def make_prefix_cache_manager(max_num_seqs, enable_mm=False, num_gpu_blocks_over model_cfg.print = print model_cfg.architectures = ["test_model"] model_cfg.mm_max_tokens_per_item = None - cache_cfg.bytes_per_layer_per_block = 1 + cache_cfg.bytes_per_token_per_layer = 1 + parallel_cfg = ParallelConfig(args) scheduler_cfg = SchedulerConfig(args) graph_opt_cfg = engine_args.create_graph_optimization_config() diff --git a/tests/v1/test_resource_manager_v1.py b/tests/v1/test_resource_manager_v1.py index bfa28c9514..0723bd08b7 100644 --- a/tests/v1/test_resource_manager_v1.py +++ b/tests/v1/test_resource_manager_v1.py @@ -81,7 +81,7 @@ def _build_manager( model_cfg.max_model_len = max_model_len model_cfg.architectures = architectures or ["test_model"] model_cfg.mm_max_tokens_per_item = None - cache_cfg.bytes_per_layer_per_block = 1 + cache_cfg.bytes_per_token_per_layer = 1 cache_cfg.kv_cache_ratio = 1.0 parallel_cfg = ParallelConfig(args) scheduler_cfg = SchedulerConfig(args) @@ -142,7 +142,7 @@ class TestResourceManagerV1(unittest.TestCase): model_cfg.max_model_len = 3200 model_cfg.architectures = ["test_model"] model_cfg.mm_max_tokens_per_item = None - cache_cfg.bytes_per_layer_per_block = 1 + cache_cfg.bytes_per_token_per_layer = 1 cache_cfg.kv_cache_ratio = 1.0 parallel_cfg = ParallelConfig(args) scheduler_cfg = SchedulerConfig(args) @@ -304,7 +304,7 @@ class TestRevertChunkedMMInput(unittest.TestCase): model_cfg.max_model_len = 3200 model_cfg.architectures = ["test_model"] model_cfg.mm_max_tokens_per_item = None - cache_cfg.bytes_per_layer_per_block = 1 + cache_cfg.bytes_per_token_per_layer = 1 cache_cfg.kv_cache_ratio = 1.0 cache_cfg.block_size = 64 parallel_cfg = ParallelConfig(args)