# Mirror of https://github.com/PaddlePaddle/FastDeploy.git
# Synced 2026-04-23 00:17:25 +08:00 — commit 98f3fc9267
# "[test] add a few unit tests * [feat] update key prefix when model weights
#  are updated * [test] try to fix test_worker_process"
# 291 lines, 12 KiB, Python
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import types
import unittest
from unittest.mock import AsyncMock, Mock, patch

import pytest

from fastdeploy.config import FDConfig
from fastdeploy.engine.request import ControlRequest
from fastdeploy.worker.worker_process import PaddleDisWorkerProc

class TestInterceptPaddleLoggers(unittest.TestCase):
    """Test cases for the ``intercept_paddle_loggers`` context manager."""

    def test_intercept_paddle_loggers_with_paddle_prefix(self):
        """A logger under the ``paddle.`` namespace is reconfigured in-context."""
        from fastdeploy.logger.logger import intercept_paddle_loggers

        logger_name = "paddle.test.logger"
        paddle_logger = logging.getLogger(logger_name)

        # Pre-populate the logger with two handlers so the interceptor's
        # reconfiguration (collapsing to a single handler) is observable.
        paddle_logger.addHandler(logging.StreamHandler())
        paddle_logger.addHandler(logging.StreamHandler())
        self.assertEqual(len(paddle_logger.handlers), 2)

        with intercept_paddle_loggers():
            # Looking the logger up inside the context goes through the
            # patched getLogger and yields the interceptor's configuration.
            patched_logger = logging.getLogger(logger_name)

            self.assertEqual(len(patched_logger.handlers), 1)
            self.assertIsInstance(patched_logger.handlers[0], logging.StreamHandler)
            self.assertEqual(patched_logger.level, logging.INFO)
            self.assertFalse(patched_logger.propagate)

        # Drop the handlers so later tests see a pristine logger.
        paddle_logger.handlers = []

    def test_intercept_paddle_loggers_restores_original(self):
        """``logging.getLogger`` is swapped in-context and restored on exit."""
        from fastdeploy.logger.logger import intercept_paddle_loggers

        saved_getLogger = logging.getLogger

        with intercept_paddle_loggers():
            # While the context is active, the module attribute is patched.
            self.assertNotEqual(logging.getLogger, saved_getLogger)

        # Leaving the context puts the original function back.
        self.assertEqual(logging.getLogger, saved_getLogger)

    def test_intercept_paddle_loggers_non_paddle_logger_unchanged(self):
        """Loggers outside the paddle namespace keep their configuration."""
        from fastdeploy.logger.logger import intercept_paddle_loggers

        logger_name = "other.test.logger"
        plain_logger = logging.getLogger(logger_name)

        attached_handler = logging.StreamHandler()
        plain_logger.addHandler(attached_handler)
        handler_count = len(plain_logger.handlers)

        with intercept_paddle_loggers():
            looked_up = logging.getLogger(logger_name)
            # A non-paddle name passes through untouched: same handler
            # count, same handler object.
            self.assertEqual(len(looked_up.handlers), handler_count)
            self.assertEqual(looked_up.handlers[0], attached_handler)

        # Restore the logger for other tests.
        plain_logger.handlers = []

    def test_intercept_paddle_loggers_exception_safety(self):
        """``getLogger`` is restored even when the context body raises."""
        from fastdeploy.logger.logger import intercept_paddle_loggers

        saved_getLogger = logging.getLogger

        try:
            with intercept_paddle_loggers():
                raise ValueError("Test exception")
        except ValueError:
            pass  # Expected

        # The patch must be undone despite the exception.
        self.assertEqual(logging.getLogger, saved_getLogger)
class TestWorkerProcessControlMethod(unittest.TestCase):
    """Test cases for PaddleDisWorkerProc control-method handling (covers lines 761-786)."""

    def setUp(self):
        """Assemble a bare PaddleDisWorkerProc wired to mocked collaborators."""
        fd_config = Mock(spec=FDConfig)
        fd_config.parallel_config = Mock()
        fd_config.parallel_config.use_ep = False
        fd_config.parallel_config.tensor_parallel_size = 1
        fd_config.load_config = Mock()
        fd_config.load_config.dynamic_load_weight = False
        self.mock_fd_config = fd_config

        # __new__ bypasses __init__ so no real engine resources are touched.
        proc = PaddleDisWorkerProc.__new__(PaddleDisWorkerProc)
        proc.fd_config = fd_config
        proc.parallel_config = fd_config.parallel_config
        proc.local_rank = 0
        proc.eplb_config = types.SimpleNamespace(enable_eplb=False)

        # An empty spec means undefined worker methods raise AttributeError
        # instead of being auto-created as Mocks.
        proc.worker = Mock(spec=[])

        # Control-output queue whose put() is awaitable.
        ctrl_queue = Mock()
        ctrl_queue.put = AsyncMock()
        proc._ctrl_output = ctrl_queue

        self.process = proc
        self.mock_queue = ctrl_queue

    def test_run_control_method_unknown_handler(self):
        """An unrecognised control method yields a 400 error response."""
        bad_request = ControlRequest(request_id="test_id", method="unknown_method", args={})

        self.process.run_control_method(bad_request)

        self.mock_queue.put.assert_called_once()
        response = self.mock_queue.put.call_args[0][0]
        self.assertEqual(response.request_id, "test_id")
        self.assertEqual(response.error_code, 400)

    def test_run_control_method_non_callable_handler(self):
        """A handler attribute that is not callable yields a 400 response."""
        self.process.worker.some_method = "not_callable"

        request = ControlRequest(request_id="test_id", method="some_method", args={})

        self.process.run_control_method(request)

        self.mock_queue.put.assert_called_once()
        response = self.mock_queue.put.call_args[0][0]
        self.assertEqual(response.error_code, 400)

    def test_run_control_method_success(self):
        """A callable handler runs with the request args and reports 200."""
        self.process.worker.test_method = Mock(return_value={"result": "success"})

        request = ControlRequest(request_id="test_id", method="test_method", args={"param": "value"})

        self.process.run_control_method(request)

        # The request args are forwarded as keyword arguments.
        self.process.worker.test_method.assert_called_once_with(param="value")

        self.mock_queue.put.assert_called_once()
        response = self.mock_queue.put.call_args[0][0]
        self.assertEqual(response.request_id, "test_id")
        self.assertEqual(response.error_code, 200)

    def test_run_control_method_exception(self):
        """A handler that raises produces a 500 error response."""

        def failing_method(**kwargs):
            raise ValueError("Test error")

        self.process.worker.test_method = failing_method

        request = ControlRequest(request_id="test_id", method="test_method", args={})

        with patch("fastdeploy.worker.worker_process.traceback") as mock_traceback:
            mock_traceback.format_exc.return_value = "Traceback..."

            self.process.run_control_method(request)

        self.mock_queue.put.assert_called_once()
        response = self.mock_queue.put.call_args[0][0]
        self.assertEqual(response.request_id, "test_id")
        self.assertEqual(response.error_code, 500)

    def test_run_control_directly_when_not_use_ep(self):
        """With use_ep disabled the control request is executed immediately."""
        self.process.parallel_config.use_ep = False

        self.process.worker.test_method = Mock(return_value={"result": "ok"})

        control_req = ControlRequest(request_id="test_id", method="test_method", args={})

        self.process.run_control_method(control_req)

        self.process.worker.test_method.assert_called_once()
        self.mock_queue.put.assert_called_once()

    @pytest.mark.skip("This case might hang in ci environment, to be fixed in the future")
    def test_event_loop_caches_ep_control_requests_before_collective_run(self):
        """EP mode caches control requests instead of running them inline."""
        proc = self.process
        proc.parallel_config.use_ep = True
        proc.parallel_config.ep_group = Mock(world_size=1)
        proc.cached_control_reqs = []
        proc._run_eplb = Mock()
        proc._tp_barrier_wait = Mock()
        proc.run_control_method = Mock()
        proc.worker_healthy_live_signal = Mock(value=[0])
        proc.max_chips_per_node = 8
        proc.nnode = 1
        proc.ranks = 1
        proc.task_queue = Mock()
        proc.task_queue.exist_tasks.return_value = False
        proc.task_queue.read_finish_flag = types.SimpleNamespace(get=Mock(return_value=1))
        control_req = ControlRequest(request_id="ep-ctrl", method="pause", args={})
        proc.task_queue.get_tasks.return_value = ([([control_req], 1)], False)
        proc.exist_task_signal = types.SimpleNamespace(value=[1])
        proc.worker = types.SimpleNamespace(
            preprocess_new_task=Mock(),
            model_runner=types.SimpleNamespace(),
            execute_model=Mock(),
            exist_prefill=Mock(return_value=False),
        )
        # all_gather_values raising SystemExit breaks out of the event loop
        # before the collective actually runs.
        with (
            patch("fastdeploy.utils.all_gather_values", side_effect=SystemExit),
            patch("fastdeploy.worker.worker_process.all_gather_values", side_effect=SystemExit),
        ):
            with self.assertRaises(SystemExit):
                proc.event_loop_normal()

        self.assertEqual(proc.cached_control_reqs, [control_req])
        proc.run_control_method.assert_not_called()

    def test_event_loop_skips_execute_model_when_runner_is_sleeping(self):
        """A sleeping model runner must not have execute_model invoked."""
        proc = self.process
        proc.parallel_config.use_ep = False
        proc.parallel_config.tensor_parallel_size = 2
        proc.fd_config.load_config.dynamic_load_weight = True
        proc.cached_control_reqs = []
        proc._run_eplb = Mock()
        # SystemExit from the barrier terminates the otherwise endless loop.
        proc._tp_barrier_wait = Mock(side_effect=SystemExit)
        proc.worker_healthy_live_signal = Mock(value=[0])
        proc.max_chips_per_node = 8
        proc.nnode = 1
        proc.ranks = 1
        proc.local_rank = 0
        proc.task_queue = Mock()
        proc.task_queue.exist_tasks.return_value = False
        proc.task_queue.read_finish_flag = types.SimpleNamespace(get=Mock(return_value=0))
        proc.exist_task_signal = types.SimpleNamespace(value=[0])
        proc.worker = types.SimpleNamespace(
            model_runner=types.SimpleNamespace(is_sleeping=True),
            execute_model=Mock(),
            exist_prefill=Mock(return_value=False),
        )

        with patch("fastdeploy.worker.worker_process.envs.FD_ENABLE_V1_UPDATE_WEIGHTS", "1"):
            with self.assertRaises(SystemExit):
                proc.event_loop_normal()

        proc.worker.execute_model.assert_not_called()
# Allow this test module to be executed directly as a script.
if __name__ == "__main__":
    unittest.main()