mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
a6351dea0b
* init * init * fix format * add * add files * add ut * fix some * add ut * add more * add * fix pre-commit * fix pre-commit * fix cover * skip long seq * add * add * fix * remove not need * fix set attr * fix comments * fix comments * fix failed tests --------- Co-authored-by: gongweibao <gognweibao@baidu.com>
303 lines
11 KiB
Python
303 lines
11 KiB
Python
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
Tests for scheduler configuration classes and FDConfig max_num_batched_tokens
|
|
assignment when ENABLE_V1_KVCACHE_SCHEDULER is enabled.
|
|
"""
|
|
|
|
import contextlib
|
|
import os
|
|
import unittest
|
|
from unittest.mock import MagicMock, Mock, patch
|
|
|
|
from fastdeploy.config import FDConfig
|
|
from fastdeploy.scheduler.config import (
|
|
DPLocalSchedulerConfig,
|
|
GlobalSchedulerConfig,
|
|
LocalSchedulerConfig,
|
|
SchedulerConfig,
|
|
)
|
|
|
|
|
|
class TestLocalSchedulerConfig(unittest.TestCase):
    """Unit tests for LocalSchedulerConfig defaults and derived values."""

    def test_defaults(self):
        """A bare config exposes the documented default values."""
        config = LocalSchedulerConfig()
        self.assertEqual(config.max_size, -1)
        self.assertEqual(config.ttl, 900)
        self.assertEqual(config.max_model_len, 8192)
        self.assertFalse(config.enable_chunked_prefill)

    def test_auto_threshold(self):
        """long_prefill_token_threshold should be 4% of max_model_len when set to 0."""
        config = LocalSchedulerConfig(max_model_len=10000, long_prefill_token_threshold=0)
        self.assertEqual(config.long_prefill_token_threshold, 400)

    def test_explicit_threshold(self):
        """An explicitly supplied threshold is kept verbatim."""
        config = LocalSchedulerConfig(long_prefill_token_threshold=512)
        self.assertEqual(config.long_prefill_token_threshold, 512)

    def test_custom_values(self):
        """Constructor keyword arguments override the defaults."""
        config = LocalSchedulerConfig(max_size=100, ttl=300, max_model_len=4096)
        self.assertEqual(config.max_size, 100)
        self.assertEqual(config.ttl, 300)
        self.assertEqual(config.max_model_len, 4096)

    def test_kwargs_ignored(self):
        """Extra kwargs should not raise."""
        config = LocalSchedulerConfig(unknown_key="value")
        self.assertFalse(hasattr(config, "unknown_key"))
|
|
|
|
|
|
class TestDPLocalSchedulerConfig(unittest.TestCase):
    """Unit tests for DPLocalSchedulerConfig's splitwise_role handling."""

    def test_defaults(self):
        """splitwise_role defaults to "prefill"."""
        config = DPLocalSchedulerConfig()
        self.assertEqual(config.splitwise_role, "prefill")

    def test_custom_role(self):
        """An explicit splitwise_role is kept as given."""
        config = DPLocalSchedulerConfig(splitwise_role="decode")
        self.assertEqual(config.splitwise_role, "decode")
|
|
|
|
|
|
class TestGlobalSchedulerConfig(unittest.TestCase):
    """Unit tests for GlobalSchedulerConfig defaults and check() validation."""

    def test_defaults(self):
        """A bare config carries the documented Redis connection defaults."""
        config = GlobalSchedulerConfig()
        self.assertEqual(config.host, "127.0.0.1")
        self.assertEqual(config.port, 6379)
        self.assertEqual(config.db, 0)
        self.assertIsNone(config.password)
        self.assertEqual(config.topic, "default")

    def test_check_invalid_ttl(self):
        """check() raises ValueError for a negative ttl."""
        config = GlobalSchedulerConfig(ttl=-1)
        with self.assertRaises(ValueError):
            config.check()

    def test_check_invalid_min_load_score(self):
        """check() raises ValueError when min_load_score is 0."""
        config = GlobalSchedulerConfig(min_load_score=0)
        with self.assertRaises(ValueError):
            config.check()

    def test_check_invalid_load_shards_num(self):
        """check() raises ValueError when load_shards_num is 0."""
        config = GlobalSchedulerConfig(load_shards_num=0)
        with self.assertRaises(ValueError):
            config.check()

    def test_auto_threshold(self):
        """long_prefill_token_threshold defaults to 4% of max_model_len when set to 0."""
        config = GlobalSchedulerConfig(max_model_len=20000, long_prefill_token_threshold=0)
        self.assertEqual(config.long_prefill_token_threshold, 800)

    @patch("fastdeploy.scheduler.config.redis")
    def test_check_redis_connection_failure_raises(self, mock_redis_mod):
        """Redis ping returning False should raise ConnectionError."""
        failing_conn = MagicMock()
        failing_conn.ping.return_value = False
        mock_redis_mod.Redis.return_value = failing_conn

        config = GlobalSchedulerConfig()
        with self.assertRaises(ConnectionError):
            config.check()
|
|
|
|
|
|
class TestSchedulerConfig(unittest.TestCase):
    """Unit tests for SchedulerConfig dispatch and top-level attributes."""

    def test_local_scheduler(self):
        """name "local" selects LocalSchedulerConfig and forwards kwargs to it."""
        scheduler = SchedulerConfig({"name": "local", "max_size": 50, "ttl": 600})
        self.assertEqual(scheduler.name, "local")
        self.assertIsInstance(scheduler.config, LocalSchedulerConfig)
        self.assertEqual(scheduler.config.max_size, 50)

    def test_dp_scheduler(self):
        """name "dp" selects DPLocalSchedulerConfig."""
        scheduler = SchedulerConfig({"name": "dp", "splitwise_role": "decode"})
        self.assertEqual(scheduler.name, "dp")
        self.assertIsInstance(scheduler.config, DPLocalSchedulerConfig)

    def test_global_scheduler(self):
        """name "global" selects GlobalSchedulerConfig and forwards kwargs to it."""
        scheduler = SchedulerConfig({"name": "global", "host": "redis.local"})
        self.assertEqual(scheduler.name, "global")
        self.assertIsInstance(scheduler.config, GlobalSchedulerConfig)
        self.assertEqual(scheduler.config.host, "redis.local")

    def test_check_unknown_name_raises(self):
        """check() fails for an unrecognized scheduler name."""
        scheduler = SchedulerConfig({"name": "unknown"})
        with self.assertRaises(Exception):
            scheduler.check()

    def test_default_attrs(self):
        """Top-level attributes fall back to their documented defaults."""
        scheduler = SchedulerConfig({"name": "local"})
        self.assertEqual(scheduler.max_num_batched_tokens, 2048)
        self.assertEqual(scheduler.max_extra_num_batched_tokens, 16384)
        self.assertEqual(scheduler.max_num_seqs, 34)
        self.assertEqual(scheduler.splitwise_role, "mixed")
        self.assertFalse(scheduler.enable_overlap_schedule)

    def test_attrs_override(self):
        """Explicit values in the args dict override the defaults."""
        scheduler = SchedulerConfig({"name": "local", "max_num_seqs": 64, "max_num_batched_tokens": 4096})
        self.assertEqual(scheduler.max_num_seqs, 64)
        self.assertEqual(scheduler.max_num_batched_tokens, 4096)
|
|
|
|
|
|
def _create_mock_configs():
    """Create all mock config objects needed for FDConfig initialization.

    Returns a 6-tuple: (scheduler, model, cache, parallel, load, graph)
    mocks, in the order expected by ``_create_fd_config_instance``.
    """

    def _mock_with(attrs, spec=None):
        # Build a Mock (optionally constrained by a spec) and populate it
        # with the given attribute values.
        mock_obj = Mock(spec=spec) if spec is not None else Mock()
        for attr_name, attr_value in attrs.items():
            setattr(mock_obj, attr_name, attr_value)
        return mock_obj

    # Scheduler mock: max_num_batched_tokens left as None so postprocess()
    # has to fill it in.
    mock_scheduler = _mock_with(
        {
            "max_num_batched_tokens": None,
            "max_num_seqs": 34,
            "splitwise_role": "mixed",
            "name": "local",
            "max_extra_num_batched_tokens": 16384,
            "enable_overlap_schedule": False,
        },
        spec=SchedulerConfig,
    )

    mock_model = _mock_with(
        {
            "max_model_len": 8192,
            "architectures": ["TestModel"],
            "enable_mm": False,
            "is_reasoning_model": False,
            "mm_max_tokens_per_item": None,
            "moe_phase": None,
        }
    )

    mock_cache = _mock_with(
        {
            "enable_prefix_caching": False,
            "block_size": 64,
            "enable_chunked_prefill": False,
            "max_block_num_per_seq": 128,
            "cache_queue_port": None,
            "pd_comm_port": None,
            "rdma_comm_ports": None,
            "max_encoder_cache": 0,
            "postprocess": Mock(),
        }
    )

    mock_parallel = _mock_with(
        {
            "tensor_parallel_size": 1,
            "data_parallel_size": 1,
            "expert_parallel_size": 1,
            "local_data_parallel_id": 0,
            "engine_worker_queue_port": [8080],
            "local_engine_worker_queue_port": 8080,
            "device_ids": "0",
            "use_sequence_parallel_moe": False,
        }
    )

    mock_load = _mock_with(
        {
            "load_strategy": "normal",
            "dynamic_load_weight": False,
        }
    )

    mock_graph = _mock_with(
        {
            "use_cudagraph": False,
            "cudagraph_capture_sizes": None,
            "max_capture_shape_prefill": 512,
            "graph_opt_level": 0,
            "cudagraph_only_prefill": False,
            "filter_capture_size": Mock(),
        }
    )

    return mock_scheduler, mock_model, mock_cache, mock_parallel, mock_load, mock_graph
|
|
|
|
|
|
def _create_fd_config_instance(mock_scheduler, mock_model, mock_cache, mock_parallel, mock_load, mock_graph):
    """Create an FDConfig instance with the given mock configs."""
    # Bypass __init__ so no real validation or side effects run; the test
    # drives postprocess() explicitly afterwards.
    fd_config = FDConfig.__new__(FDConfig)
    attributes = {
        "model_config": mock_model,
        "cache_config": mock_cache,
        "scheduler_config": mock_scheduler,
        "parallel_config": mock_parallel,
        "load_config": mock_load,
        "graph_opt_config": mock_graph,
        "speculative_config": None,
        "eplb_config": None,
        "structured_outputs_config": None,
        "router_config": None,
        "nnode": 1,
        "node_rank": 0,
        "worker_num_per_node": 1,
        "master_ip": "127.0.0.1",
        "is_master": True,
        "max_num_partial_prefills": 1,
        "max_long_partial_prefills": 1,
        "long_prefill_token_threshold": 0,
        "paddle_commit_id": "test",
        "routing_replay_config": None,
    }
    for attr_name, attr_value in attributes.items():
        setattr(fd_config, attr_name, attr_value)
    return fd_config
|
|
|
|
|
|
@contextlib.contextmanager
def _patch_env_and_config(enable_v1_scheduler):
    """Context manager that patches the environment variables FDConfig reads.

    ``enable_v1_scheduler`` is stringified into ENABLE_V1_KVCACHE_SCHEDULER;
    the remaining FD_* variables are pinned to fixed values. The patched
    environment is restored on exit.
    """
    patched_environment = {
        "ENABLE_V1_KVCACHE_SCHEDULER": str(enable_v1_scheduler),
        "FD_ENABLE_MAX_PREFILL": "0",
        "FD_FOR_TORCH_MODEL_FORMAT": "0",
        "FD_MAX_STOP_SEQS_NUM": "10",
        "FD_STOP_SEQS_MAX_LEN": "100",
    }
    with patch.dict(os.environ, patched_environment):
        yield
|
|
|
|
|
|
class TestSchedulerConfigMaxNumBatchedTokens(unittest.TestCase):
    """Test cases for scheduler_config.max_num_batched_tokens assignment logic."""

    @staticmethod
    def _run_postprocess(preset_tokens=None):
        """Build an FDConfig from fresh mocks and run postprocess().

        ``preset_tokens`` (when given) is assigned to the scheduler mock's
        max_num_batched_tokens before construction; the V1 KV-cache
        scheduler env flag is enabled for the duration of postprocess().
        """
        scheduler, model, cache, parallel, load, graph = _create_mock_configs()
        if preset_tokens is not None:
            scheduler.max_num_batched_tokens = preset_tokens
        with _patch_env_and_config(enable_v1_scheduler=1):
            fd_config = _create_fd_config_instance(scheduler, model, cache, parallel, load, graph)
            fd_config.postprocess()
        return fd_config

    def test_max_num_batched_tokens_set_to_8192_when_v1_scheduler_enabled(self):
        """max_num_batched_tokens becomes 8192 when it starts as None and
        ENABLE_V1_KVCACHE_SCHEDULER is enabled (covers the assignment line
        ``self.scheduler_config.max_num_batched_tokens = 8192``)."""
        fd_config = self._run_postprocess()
        self.assertEqual(
            fd_config.scheduler_config.max_num_batched_tokens,
            8192,
            "max_num_batched_tokens should be set to 8192 when "
            "ENABLE_V1_KVCACHE_SCHEDULER is enabled and value is None",
        )

    def test_max_num_batched_tokens_not_overwritten_when_already_set(self):
        """An explicitly preset max_num_batched_tokens survives postprocess()
        unchanged."""
        original_value = 4096
        fd_config = self._run_postprocess(preset_tokens=original_value)
        self.assertEqual(
            fd_config.scheduler_config.max_num_batched_tokens,
            original_value,
            "max_num_batched_tokens should not be overwritten when already set",
        )
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the full unittest suite when this module is executed directly.
    unittest.main()
|