Files
FastDeploy/tests/scheduler/test_scheduler_config.py
T
gongweibao a6351dea0b [BugFix][Optimization] Replace silent failures with catchable exceptions and informative error messages (#6533)
* init

* init

* fix format

* add

* add files

* add ut

* fix some

* add ut

* add more

* add

* fix pre-commit

* fix pre-commit

* fix cover

* skip long seq

* add

* add

* fix

* remove not need

* fix set attr

* fix comments

* fix comments

* fix failed tests

---------

Co-authored-by: gongweibao <gognweibao@baidu.com>
2026-03-16 21:32:43 +08:00

303 lines
11 KiB
Python

# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for scheduler configuration classes and FDConfig max_num_batched_tokens
assignment when ENABLE_V1_KVCACHE_SCHEDULER is enabled.
"""
import contextlib
import os
import unittest
from unittest.mock import MagicMock, Mock, patch
from fastdeploy.config import FDConfig
from fastdeploy.scheduler.config import (
DPLocalSchedulerConfig,
GlobalSchedulerConfig,
LocalSchedulerConfig,
SchedulerConfig,
)
class TestLocalSchedulerConfig(unittest.TestCase):
    """Unit tests for the attribute values exposed by ``LocalSchedulerConfig``."""

    def test_defaults(self):
        """A config built without arguments carries the expected defaults."""
        local_cfg = LocalSchedulerConfig()
        expected_defaults = (
            ("max_size", -1),
            ("ttl", 900),
            ("max_model_len", 8192),
        )
        for attr_name, default in expected_defaults:
            self.assertEqual(getattr(local_cfg, attr_name), default)
        self.assertFalse(local_cfg.enable_chunked_prefill)

    def test_auto_threshold(self):
        """A threshold of 0 is expected to resolve to 4% of max_model_len."""
        local_cfg = LocalSchedulerConfig(
            max_model_len=10000,
            long_prefill_token_threshold=0,
        )
        # 4% of 10000 tokens.
        self.assertEqual(local_cfg.long_prefill_token_threshold, 400)

    def test_explicit_threshold(self):
        """A non-zero threshold passed by the caller is kept as given."""
        local_cfg = LocalSchedulerConfig(long_prefill_token_threshold=512)
        self.assertEqual(local_cfg.long_prefill_token_threshold, 512)

    def test_custom_values(self):
        """Constructor overrides are reflected on the matching attributes."""
        overrides = {"max_size": 100, "ttl": 300, "max_model_len": 4096}
        local_cfg = LocalSchedulerConfig(**overrides)
        for attr_name, expected in overrides.items():
            self.assertEqual(getattr(local_cfg, attr_name), expected)

    def test_kwargs_ignored(self):
        """Unknown keyword arguments neither raise nor become attributes."""
        local_cfg = LocalSchedulerConfig(unknown_key="value")
        has_unknown = hasattr(local_cfg, "unknown_key")
        self.assertFalse(has_unknown)
class TestDPLocalSchedulerConfig(unittest.TestCase):
    """Unit tests for the ``splitwise_role`` attribute of ``DPLocalSchedulerConfig``."""

    def test_defaults(self):
        """Without arguments the role is expected to be ``"prefill"``."""
        dp_cfg = DPLocalSchedulerConfig()
        role = dp_cfg.splitwise_role
        self.assertEqual(role, "prefill")

    def test_custom_role(self):
        """A role passed to the constructor is stored unchanged."""
        dp_cfg = DPLocalSchedulerConfig(splitwise_role="decode")
        role = dp_cfg.splitwise_role
        self.assertEqual(role, "decode")
class TestGlobalSchedulerConfig(unittest.TestCase):
    """Unit tests for ``GlobalSchedulerConfig`` defaults and ``check()`` validation."""

    def _assert_check_rejects(self, **overrides):
        """Build a config from *overrides* and expect ``check()`` to raise ValueError.

        The config is constructed outside the ``assertRaises`` context so that
        only an error raised by ``check()`` itself satisfies the assertion.
        """
        global_cfg = GlobalSchedulerConfig(**overrides)
        with self.assertRaises(ValueError):
            global_cfg.check()

    def test_defaults(self):
        """A config built without arguments carries the expected connection defaults."""
        global_cfg = GlobalSchedulerConfig()
        for attr_name, default in (("host", "127.0.0.1"), ("port", 6379), ("db", 0)):
            self.assertEqual(getattr(global_cfg, attr_name), default)
        self.assertIsNone(global_cfg.password)
        self.assertEqual(global_cfg.topic, "default")

    def test_check_invalid_ttl(self):
        """``check()`` rejects ttl=-1."""
        self._assert_check_rejects(ttl=-1)

    def test_check_invalid_min_load_score(self):
        """``check()`` rejects min_load_score=0."""
        self._assert_check_rejects(min_load_score=0)

    def test_check_invalid_load_shards_num(self):
        """``check()`` rejects load_shards_num=0."""
        self._assert_check_rejects(load_shards_num=0)

    def test_auto_threshold(self):
        """A threshold of 0 is expected to resolve to 800 for max_model_len=20000."""
        global_cfg = GlobalSchedulerConfig(
            max_model_len=20000,
            long_prefill_token_threshold=0,
        )
        self.assertEqual(global_cfg.long_prefill_token_threshold, 800)

    @patch("fastdeploy.scheduler.config.redis")
    def test_check_redis_connection_failure_raises(self, mock_redis_mod):
        """``check()`` raises ConnectionError when the Redis client's ping returns False."""
        # The patched module hands out a client whose ping() reports failure.
        unreachable_client = MagicMock(**{"ping.return_value": False})
        mock_redis_mod.Redis.return_value = unreachable_client

        global_cfg = GlobalSchedulerConfig()
        with self.assertRaises(ConnectionError):
            global_cfg.check()
class TestSchedulerConfig(unittest.TestCase):
    """Unit tests for ``SchedulerConfig`` dispatch by name and its top-level attributes."""

    def test_local_scheduler(self):
        """The name "local" yields a ``LocalSchedulerConfig`` carrying the given max_size."""
        args = {"name": "local", "max_size": 50, "ttl": 600}
        scheduler_cfg = SchedulerConfig(args)
        self.assertEqual(scheduler_cfg.name, "local")
        self.assertIsInstance(scheduler_cfg.config, LocalSchedulerConfig)
        self.assertEqual(scheduler_cfg.config.max_size, 50)

    def test_dp_scheduler(self):
        """The name "dp" yields a ``DPLocalSchedulerConfig``."""
        args = {"name": "dp", "splitwise_role": "decode"}
        scheduler_cfg = SchedulerConfig(args)
        self.assertEqual(scheduler_cfg.name, "dp")
        self.assertIsInstance(scheduler_cfg.config, DPLocalSchedulerConfig)

    def test_global_scheduler(self):
        """The name "global" yields a ``GlobalSchedulerConfig`` carrying the given host."""
        args = {"name": "global", "host": "redis.local"}
        scheduler_cfg = SchedulerConfig(args)
        self.assertEqual(scheduler_cfg.name, "global")
        self.assertIsInstance(scheduler_cfg.config, GlobalSchedulerConfig)
        self.assertEqual(scheduler_cfg.config.host, "redis.local")

    def test_check_unknown_name_raises(self):
        """``check()`` raises for a scheduler name that is not recognised."""
        # Constructed outside the context: only check() is expected to raise.
        scheduler_cfg = SchedulerConfig({"name": "unknown"})
        with self.assertRaises(Exception):
            scheduler_cfg.check()

    def test_default_attrs(self):
        """Top-level attributes fall back to their defaults when only a name is given."""
        scheduler_cfg = SchedulerConfig({"name": "local"})
        expected_defaults = (
            ("max_num_batched_tokens", 2048),
            ("max_extra_num_batched_tokens", 16384),
            ("max_num_seqs", 34),
            ("splitwise_role", "mixed"),
        )
        for attr_name, default in expected_defaults:
            self.assertEqual(getattr(scheduler_cfg, attr_name), default)
        self.assertFalse(scheduler_cfg.enable_overlap_schedule)

    def test_attrs_override(self):
        """Top-level attributes supplied in the args dict replace the defaults."""
        args = {
            "name": "local",
            "max_num_seqs": 64,
            "max_num_batched_tokens": 4096,
        }
        scheduler_cfg = SchedulerConfig(args)
        self.assertEqual(scheduler_cfg.max_num_seqs, 64)
        self.assertEqual(scheduler_cfg.max_num_batched_tokens, 4096)
def _create_mock_configs():
    """Build the mock sub-configs that the FDConfig tests attach to an instance.

    Returns:
        tuple: ``(scheduler, model, cache, parallel, load, graph)`` mocks, in
        that order. The scheduler mock starts with
        ``max_num_batched_tokens = None`` so tests can observe whether
        ``postprocess()`` fills it in.
    """
    # Scheduler mock is spec'd against SchedulerConfig. "name" has to go
    # through configure_mock(): passing name= to the Mock constructor would
    # label the mock itself instead of creating a .name attribute.
    scheduler_mock = Mock(spec=SchedulerConfig)
    scheduler_mock.configure_mock(
        max_num_batched_tokens=None,
        max_num_seqs=34,
        splitwise_role="mixed",
        name="local",
        max_extra_num_batched_tokens=16384,
        enable_overlap_schedule=False,
    )

    model_mock = Mock(
        max_model_len=8192,
        architectures=["TestModel"],
        enable_mm=False,
        is_reasoning_model=False,
        mm_max_tokens_per_item=None,
        moe_phase=None,
    )

    cache_mock = Mock(
        enable_prefix_caching=False,
        block_size=64,
        enable_chunked_prefill=False,
        max_block_num_per_seq=128,
        cache_queue_port=None,
        pd_comm_port=None,
        rdma_comm_ports=None,
        max_encoder_cache=0,
        postprocess=Mock(),
    )

    parallel_mock = Mock(
        tensor_parallel_size=1,
        data_parallel_size=1,
        expert_parallel_size=1,
        local_data_parallel_id=0,
        engine_worker_queue_port=[8080],
        local_engine_worker_queue_port=8080,
        device_ids="0",
        use_sequence_parallel_moe=False,
    )

    load_mock = Mock(
        load_strategy="normal",
        dynamic_load_weight=False,
    )

    graph_mock = Mock(
        use_cudagraph=False,
        cudagraph_capture_sizes=None,
        max_capture_shape_prefill=512,
        graph_opt_level=0,
        cudagraph_only_prefill=False,
        filter_capture_size=Mock(),
    )

    return scheduler_mock, model_mock, cache_mock, parallel_mock, load_mock, graph_mock
def _create_fd_config_instance(mock_scheduler, mock_model, mock_cache, mock_parallel, mock_load, mock_graph):
    """Assemble an ``FDConfig`` by hand from the supplied mock sub-configs.

    Args:
        mock_scheduler: Object stored as ``scheduler_config``.
        mock_model: Object stored as ``model_config``.
        mock_cache: Object stored as ``cache_config``.
        mock_parallel: Object stored as ``parallel_config``.
        mock_load: Object stored as ``load_config``.
        mock_graph: Object stored as ``graph_opt_config``.

    Returns:
        FDConfig: An instance whose ``__init__`` was never run; every
        attribute listed below is set directly on it.
    """
    # __new__ allocates the object without invoking FDConfig.__init__.
    instance = FDConfig.__new__(FDConfig)

    # Insertion order mirrors the order in which the attributes are assigned.
    attributes = {
        "model_config": mock_model,
        "cache_config": mock_cache,
        "scheduler_config": mock_scheduler,
        "parallel_config": mock_parallel,
        "load_config": mock_load,
        "graph_opt_config": mock_graph,
        "speculative_config": None,
        "eplb_config": None,
        "structured_outputs_config": None,
        "router_config": None,
        "nnode": 1,
        "node_rank": 0,
        "worker_num_per_node": 1,
        "master_ip": "127.0.0.1",
        "is_master": True,
        "max_num_partial_prefills": 1,
        "max_long_partial_prefills": 1,
        "long_prefill_token_threshold": 0,
        "paddle_commit_id": "test",
        "routing_replay_config": None,
    }
    for attr_name, attr_value in attributes.items():
        setattr(instance, attr_name, attr_value)

    return instance
@contextlib.contextmanager
def _patch_env_and_config(enable_v1_scheduler):
    """Temporarily override the environment variables used by the FDConfig tests.

    Only ``os.environ`` is patched; the previous environment is restored when
    the context exits.

    Args:
        enable_v1_scheduler: Value exported, after ``str()`` conversion, as
            ``ENABLE_V1_KVCACHE_SCHEDULER``.

    Yields:
        None.
    """
    overrides = dict(
        ENABLE_V1_KVCACHE_SCHEDULER=str(enable_v1_scheduler),
        FD_ENABLE_MAX_PREFILL="0",
        FD_FOR_TORCH_MODEL_FORMAT="0",
        FD_MAX_STOP_SEQS_NUM="10",
        FD_STOP_SEQS_MAX_LEN="100",
    )
    with patch.dict(os.environ, overrides):
        yield
class TestSchedulerConfigMaxNumBatchedTokens(unittest.TestCase):
    """Tests how ``FDConfig.postprocess()`` treats ``scheduler_config.max_num_batched_tokens``."""

    @staticmethod
    def _postprocess_with_v1_scheduler(mock_configs):
        """Run ``postprocess()`` on a hand-built FDConfig with the V1 scheduler env enabled.

        Args:
            mock_configs: The tuple returned by ``_create_mock_configs()``.

        Returns:
            The FDConfig instance after ``postprocess()`` has been called.
        """
        with _patch_env_and_config(enable_v1_scheduler=1):
            fd_config = _create_fd_config_instance(*mock_configs)
            fd_config.postprocess()
        return fd_config

    def test_max_num_batched_tokens_set_to_8192_when_v1_scheduler_enabled(self):
        """
        max_num_batched_tokens becomes 8192 when it starts as None and
        ENABLE_V1_KVCACHE_SCHEDULER is set to a truthy value.

        Intended to cover the assignment
            self.scheduler_config.max_num_batched_tokens = 8192
        """
        # The scheduler mock is created with max_num_batched_tokens = None.
        mock_configs = _create_mock_configs()

        fd_config = self._postprocess_with_v1_scheduler(mock_configs)

        self.assertEqual(
            fd_config.scheduler_config.max_num_batched_tokens,
            8192,
            "max_num_batched_tokens should be set to 8192 when "
            "ENABLE_V1_KVCACHE_SCHEDULER is enabled and value is None",
        )

    def test_max_num_batched_tokens_not_overwritten_when_already_set(self):
        """
        max_num_batched_tokens keeps an explicitly assigned non-None value
        after postprocess() has run.
        """
        mock_configs = _create_mock_configs()
        preset_value = 4096
        scheduler_mock = mock_configs[0]
        scheduler_mock.max_num_batched_tokens = preset_value

        fd_config = self._postprocess_with_v1_scheduler(mock_configs)

        self.assertEqual(
            fd_config.scheduler_config.max_num_batched_tokens,
            preset_value,
            "max_num_batched_tokens should not be overwritten when already set",
        )
if __name__ == "__main__":
    # Executed as a script: let unittest discover and run this module's tests.
    unittest.main()