mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[BugFix][Optimization] Replace silent failures with catchable exceptions and informative error messages (#6533)
* init * init * fix format * add * add files * add ut * fix some * add ut * add more * add * fix pre-commit * fix pre-commit * fix cover * skip long seq * add * add * fix * remove not need * fix set attr * fix comments * fix comments * fix failed tests --------- Co-authored-by: gongweibao <gognweibao@baidu.com>
This commit is contained in:
@@ -12,16 +12,138 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Tests for FDConfig and scheduler configuration, specifically for
|
||||
max_num_batched_tokens assignment when ENABLE_V1_KVCACHE_SCHEDULER is enabled.
|
||||
Tests for scheduler configuration classes and FDConfig max_num_batched_tokens
|
||||
assignment when ENABLE_V1_KVCACHE_SCHEDULER is enabled.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.scheduler.config import SchedulerConfig
|
||||
from fastdeploy.scheduler.config import (
|
||||
DPLocalSchedulerConfig,
|
||||
GlobalSchedulerConfig,
|
||||
LocalSchedulerConfig,
|
||||
SchedulerConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestLocalSchedulerConfig(unittest.TestCase):
|
||||
def test_defaults(self):
|
||||
cfg = LocalSchedulerConfig()
|
||||
self.assertEqual(cfg.max_size, -1)
|
||||
self.assertEqual(cfg.ttl, 900)
|
||||
self.assertEqual(cfg.max_model_len, 8192)
|
||||
self.assertFalse(cfg.enable_chunked_prefill)
|
||||
|
||||
def test_auto_threshold(self):
|
||||
"""long_prefill_token_threshold should be 4% of max_model_len when set to 0."""
|
||||
cfg = LocalSchedulerConfig(max_model_len=10000, long_prefill_token_threshold=0)
|
||||
self.assertEqual(cfg.long_prefill_token_threshold, 400)
|
||||
|
||||
def test_explicit_threshold(self):
|
||||
cfg = LocalSchedulerConfig(long_prefill_token_threshold=512)
|
||||
self.assertEqual(cfg.long_prefill_token_threshold, 512)
|
||||
|
||||
def test_custom_values(self):
|
||||
cfg = LocalSchedulerConfig(max_size=100, ttl=300, max_model_len=4096)
|
||||
self.assertEqual(cfg.max_size, 100)
|
||||
self.assertEqual(cfg.ttl, 300)
|
||||
self.assertEqual(cfg.max_model_len, 4096)
|
||||
|
||||
def test_kwargs_ignored(self):
|
||||
"""Extra kwargs should not raise."""
|
||||
cfg = LocalSchedulerConfig(unknown_key="value")
|
||||
self.assertFalse(hasattr(cfg, "unknown_key"))
|
||||
|
||||
|
||||
class TestDPLocalSchedulerConfig(unittest.TestCase):
|
||||
def test_defaults(self):
|
||||
cfg = DPLocalSchedulerConfig()
|
||||
self.assertEqual(cfg.splitwise_role, "prefill")
|
||||
|
||||
def test_custom_role(self):
|
||||
cfg = DPLocalSchedulerConfig(splitwise_role="decode")
|
||||
self.assertEqual(cfg.splitwise_role, "decode")
|
||||
|
||||
|
||||
class TestGlobalSchedulerConfig(unittest.TestCase):
|
||||
def test_defaults(self):
|
||||
cfg = GlobalSchedulerConfig()
|
||||
self.assertEqual(cfg.host, "127.0.0.1")
|
||||
self.assertEqual(cfg.port, 6379)
|
||||
self.assertEqual(cfg.db, 0)
|
||||
self.assertIsNone(cfg.password)
|
||||
self.assertEqual(cfg.topic, "default")
|
||||
|
||||
def test_check_invalid_ttl(self):
|
||||
cfg = GlobalSchedulerConfig(ttl=-1)
|
||||
with self.assertRaises(ValueError):
|
||||
cfg.check()
|
||||
|
||||
def test_check_invalid_min_load_score(self):
|
||||
cfg = GlobalSchedulerConfig(min_load_score=0)
|
||||
with self.assertRaises(ValueError):
|
||||
cfg.check()
|
||||
|
||||
def test_check_invalid_load_shards_num(self):
|
||||
cfg = GlobalSchedulerConfig(load_shards_num=0)
|
||||
with self.assertRaises(ValueError):
|
||||
cfg.check()
|
||||
|
||||
def test_auto_threshold(self):
|
||||
cfg = GlobalSchedulerConfig(max_model_len=20000, long_prefill_token_threshold=0)
|
||||
self.assertEqual(cfg.long_prefill_token_threshold, 800)
|
||||
|
||||
@patch("fastdeploy.scheduler.config.redis")
|
||||
def test_check_redis_connection_failure_raises(self, mock_redis_mod):
|
||||
"""Redis ping returning False should raise ConnectionError."""
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.ping.return_value = False
|
||||
mock_redis_mod.Redis.return_value = mock_conn
|
||||
|
||||
cfg = GlobalSchedulerConfig()
|
||||
with self.assertRaises(ConnectionError):
|
||||
cfg.check()
|
||||
|
||||
|
||||
class TestSchedulerConfig(unittest.TestCase):
|
||||
def test_local_scheduler(self):
|
||||
cfg = SchedulerConfig({"name": "local", "max_size": 50, "ttl": 600})
|
||||
self.assertEqual(cfg.name, "local")
|
||||
self.assertIsInstance(cfg.config, LocalSchedulerConfig)
|
||||
self.assertEqual(cfg.config.max_size, 50)
|
||||
|
||||
def test_dp_scheduler(self):
|
||||
cfg = SchedulerConfig({"name": "dp", "splitwise_role": "decode"})
|
||||
self.assertEqual(cfg.name, "dp")
|
||||
self.assertIsInstance(cfg.config, DPLocalSchedulerConfig)
|
||||
|
||||
def test_global_scheduler(self):
|
||||
cfg = SchedulerConfig({"name": "global", "host": "redis.local"})
|
||||
self.assertEqual(cfg.name, "global")
|
||||
self.assertIsInstance(cfg.config, GlobalSchedulerConfig)
|
||||
self.assertEqual(cfg.config.host, "redis.local")
|
||||
|
||||
def test_check_unknown_name_raises(self):
|
||||
cfg = SchedulerConfig({"name": "unknown"})
|
||||
with self.assertRaises(Exception):
|
||||
cfg.check()
|
||||
|
||||
def test_default_attrs(self):
|
||||
cfg = SchedulerConfig({"name": "local"})
|
||||
self.assertEqual(cfg.max_num_batched_tokens, 2048)
|
||||
self.assertEqual(cfg.max_extra_num_batched_tokens, 16384)
|
||||
self.assertEqual(cfg.max_num_seqs, 34)
|
||||
self.assertEqual(cfg.splitwise_role, "mixed")
|
||||
self.assertFalse(cfg.enable_overlap_schedule)
|
||||
|
||||
def test_attrs_override(self):
|
||||
cfg = SchedulerConfig({"name": "local", "max_num_seqs": 64, "max_num_batched_tokens": 4096})
|
||||
self.assertEqual(cfg.max_num_seqs, 64)
|
||||
self.assertEqual(cfg.max_num_batched_tokens, 4096)
|
||||
|
||||
|
||||
def _create_mock_configs():
|
||||
@@ -113,21 +235,15 @@ def _create_fd_config_instance(mock_scheduler, mock_model, mock_cache, mock_para
|
||||
@contextlib.contextmanager
|
||||
def _patch_env_and_config(enable_v1_scheduler):
|
||||
"""Context manager to patch all environment variables and config methods."""
|
||||
from fastdeploy import envs as fastdeploy_envs
|
||||
env_vars = {
|
||||
"ENABLE_V1_KVCACHE_SCHEDULER": str(enable_v1_scheduler),
|
||||
"FD_ENABLE_MAX_PREFILL": "0",
|
||||
"FD_FOR_TORCH_MODEL_FORMAT": "0",
|
||||
"FD_MAX_STOP_SEQS_NUM": "10",
|
||||
"FD_STOP_SEQS_MAX_LEN": "100",
|
||||
}
|
||||
|
||||
env_patches = [
|
||||
patch.object(fastdeploy_envs, "ENABLE_V1_KVCACHE_SCHEDULER", enable_v1_scheduler),
|
||||
patch.object(fastdeploy_envs, "FD_ENABLE_MAX_PREFILL", False),
|
||||
patch.object(fastdeploy_envs, "FD_FOR_TORCH_MODEL_FORMAT", False),
|
||||
patch.object(fastdeploy_envs, "FD_MAX_STOP_SEQS_NUM", 10),
|
||||
patch.object(fastdeploy_envs, "FD_STOP_SEQS_MAX_LEN", 100),
|
||||
patch("fastdeploy.config.envs.ENABLE_V1_KVCACHE_SCHEDULER", enable_v1_scheduler),
|
||||
]
|
||||
|
||||
with contextlib.ExitStack() as stack:
|
||||
for p in env_patches:
|
||||
stack.enter_context(p)
|
||||
stack.enter_context(patch.object(FDConfig, "_disable_sequence_parallel_moe_if_needed"))
|
||||
with patch.dict(os.environ, env_vars):
|
||||
yield
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user