Files
FastDeploy/tests/engine/test_engine.py
T

129 lines
5.0 KiB
Python

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for fastdeploy/engine/engine.py (LLMEngine)
"""
import os
import tempfile
import types
import unittest
from unittest.mock import Mock, patch
import numpy as np
from fastdeploy.engine.engine import LLMEngine
class TestLLMEngineStopProfile(unittest.TestCase):
"""测试 LLMEngine._stop_profile 方法"""
def test_stop_profile_logs_worker_traceback_and_returns_false(self):
"""测试 worker 进程失败时,_stop_profile 打印 traceback 并返回 False"""
eng = object.__new__(LLMEngine)
eng.do_profile = 1
eng.get_profile_block_num_signal = type("Sig", (), {"value": np.array([0])})()
eng.worker_proc = Mock(poll=lambda: 1)
with tempfile.TemporaryDirectory() as temp_dir:
worker_log = os.path.join(temp_dir, "workerlog.0")
with open(worker_log, "w", encoding="utf-8") as fp:
fp.write(
"Traceback (most recent call last):\n"
"ValueError: The total number of blocks cannot be less than zero.\n"
)
with (
patch("fastdeploy.engine.engine.time.sleep", lambda *_: None),
patch("fastdeploy.engine.engine.envs.FD_LOG_DIR", temp_dir),
patch("fastdeploy.engine.engine.console_logger.error") as mock_error,
):
result = eng._stop_profile()
self.assertFalse(result)
error_messages = [call.args[0] for call in mock_error.call_args_list]
self.assertTrue(any("Traceback (most recent call last):" in msg for msg in error_messages))
self.assertTrue(any("The total number of blocks cannot be less than zero" in msg for msg in error_messages))
def test_stop_profile_returns_true_on_success(self):
"""测试 _stop_profile 正常完成时返回 True"""
eng = object.__new__(LLMEngine)
eng.do_profile = 1
eng.get_profile_block_num_signal = type("Sig", (), {"value": np.array([100])})()
eng.worker_proc = Mock(poll=lambda: None)
eng.ipc_signal_suffix = "_test"
eng.cfg = types.SimpleNamespace(
parallel_config=types.SimpleNamespace(device_ids="0"),
scheduler_config=types.SimpleNamespace(splitwise_role="decode"),
cache_config=Mock(enable_prefix_caching=False, reset=Mock()),
)
eng.engine = types.SimpleNamespace(
start_cache_service=lambda *_: None,
resource_manager=Mock(reset_cache_config=Mock()),
)
eng.cache_manager_processes = None
result = eng._stop_profile()
self.assertTrue(result)
class TestLLMEngineStart(unittest.TestCase):
"""测试 LLMEngine.start 方法中的错误处理"""
class _Sig:
def __init__(self, val):
self.value = np.array([val])
def test_start_returns_false_when_profile_worker_dies(self):
"""测试当 profile worker 失败时,start 返回 False"""
eng = object.__new__(LLMEngine)
eng.is_started = False
eng.api_server_pid = None
eng.do_profile = 1
port = int(os.getenv("FD_ENGINE_QUEUE_PORT", "6778"))
eng.cfg = types.SimpleNamespace(
parallel_config=types.SimpleNamespace(engine_worker_queue_port=[port], device_ids="0"),
scheduler_config=types.SimpleNamespace(splitwise_role="mixed", max_num_seqs=8),
cache_config=types.SimpleNamespace(
enable_prefix_caching=True,
block_size=64,
num_gpu_blocks_override=None,
total_block_num=0,
num_cpu_blocks=0,
),
model_config=types.SimpleNamespace(max_model_len=128),
)
eng._init_worker_signals = lambda: setattr(eng, "loaded_model_signal", self._Sig(1))
eng.launch_components = lambda: None
eng.worker_proc = None
eng.engine = types.SimpleNamespace(
start=lambda: None,
create_data_processor=lambda: setattr(eng.engine, "data_processor", object()),
data_processor=object(),
)
eng._start_worker_service = lambda: Mock(stdout=Mock(), poll=lambda: 1)
eng.check_worker_initialize_status = lambda: False
eng._stop_profile = lambda: False
with patch("fastdeploy.engine.engine.time.sleep", lambda *_: None):
result = eng.start(api_server_pid=None)
self.assertFalse(result)
if __name__ == "__main__":
unittest.main()