Files
FastDeploy/tests/worker/test_worker_process.py
T
Yonghua Li 98f3fc9267 [RL] [KVCache] let cache transfer managers update key prefix after weight update and add unit tests (#7083)
* [test] add a few unit tests

* [feat] update key prefix when model weights are updated

* [test] try to fix test_worker_process
2026-04-02 19:58:41 +08:00

291 lines
12 KiB
Python

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import types
import unittest
from unittest.mock import AsyncMock, Mock, patch
import pytest
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import ControlRequest
from fastdeploy.worker.worker_process import PaddleDisWorkerProc
class TestInterceptPaddleLoggers(unittest.TestCase):
"""Test cases for intercept_paddle_loggers context manager from tools.logger_patch"""
def test_intercept_paddle_loggers_with_paddle_prefix(self):
"""Test intercept_paddle_loggers configures paddle loggers correctly"""
from fastdeploy.logger.logger import intercept_paddle_loggers
# Create a logger with existing handlers before interception
test_logger_name = "paddle.test.logger"
test_logger = logging.getLogger(test_logger_name)
# Add some handlers to the logger
handler1 = logging.StreamHandler()
handler2 = logging.StreamHandler()
test_logger.addHandler(handler1)
test_logger.addHandler(handler2)
self.assertEqual(len(test_logger.handlers), 2)
# Use context manager to intercept paddle loggers
with intercept_paddle_loggers():
# Get logger inside context - should be configured by interceptor
intercepted_logger = logging.getLogger(test_logger_name)
# Verify the logger was reconfigured by interceptor
self.assertEqual(len(intercepted_logger.handlers), 1)
self.assertIsInstance(intercepted_logger.handlers[0], logging.StreamHandler)
self.assertEqual(intercepted_logger.level, logging.INFO)
self.assertFalse(intercepted_logger.propagate)
# Clean up
test_logger.handlers = []
def test_intercept_paddle_loggers_restores_original(self):
"""Test intercept_paddle_loggers restores original getLogger after exit"""
from fastdeploy.logger.logger import intercept_paddle_loggers
# Store original getLogger before context
original_getLogger = logging.getLogger
# Use the context manager
with intercept_paddle_loggers():
# Inside context, getLogger should be patched
self.assertNotEqual(logging.getLogger, original_getLogger)
# After exit, getLogger should be restored
self.assertEqual(logging.getLogger, original_getLogger)
def test_intercept_paddle_loggers_non_paddle_logger_unchanged(self):
"""Test non-paddle loggers are not affected by intercept_paddle_loggers"""
from fastdeploy.logger.logger import intercept_paddle_loggers
# Create a non-paddle logger
test_logger_name = "other.test.logger"
test_logger = logging.getLogger(test_logger_name)
# Add a handler
original_handler = logging.StreamHandler()
test_logger.addHandler(original_handler)
original_handler_count = len(test_logger.handlers)
# Use the context manager
with intercept_paddle_loggers():
# Get the same logger
result_logger = logging.getLogger(test_logger_name)
# Non-paddle loggers should not be modified
self.assertEqual(len(result_logger.handlers), original_handler_count)
self.assertEqual(result_logger.handlers[0], original_handler)
# Clean up
test_logger.handlers = []
def test_intercept_paddle_loggers_exception_safety(self):
"""Test intercept_paddle_loggers restores getLogger even if exception occurs"""
from fastdeploy.logger.logger import intercept_paddle_loggers
original_getLogger = logging.getLogger
try:
with intercept_paddle_loggers():
# Raise an exception inside context
raise ValueError("Test exception")
except ValueError:
pass # Expected
# After exception, getLogger should still be restored
self.assertEqual(logging.getLogger, original_getLogger)
class TestWorkerProcessControlMethod(unittest.TestCase):
"""Test cases for PaddleDisWorkerProc control method handling - Coverage for lines 761-786"""
def setUp(self):
"""Set up test fixtures"""
self.mock_fd_config = Mock(spec=FDConfig)
self.mock_fd_config.parallel_config = Mock()
self.mock_fd_config.parallel_config.use_ep = False
self.mock_fd_config.parallel_config.tensor_parallel_size = 1
self.mock_fd_config.load_config = Mock()
self.mock_fd_config.load_config.dynamic_load_weight = False
self.process = PaddleDisWorkerProc.__new__(PaddleDisWorkerProc)
self.process.fd_config = self.mock_fd_config
self.process.parallel_config = self.mock_fd_config.parallel_config
self.process.local_rank = 0
self.process.eplb_config = types.SimpleNamespace(enable_eplb=False)
# Mock worker - use spec to avoid auto-creating Mock methods
self.process.worker = Mock(spec=[]) # Empty spec = no methods defined
# Create async mock for queue
self.mock_queue = Mock()
self.mock_queue.put = AsyncMock()
self.process._ctrl_output = self.mock_queue
def test_run_control_method_unknown_handler(self):
"""Test run_control_method with unknown control method"""
# Create a request with unknown method
request = ControlRequest(request_id="test_id", method="unknown_method", args={})
self.process.run_control_method(request)
# Verify put was called with error response
self.mock_queue.put.assert_called_once()
call_args = self.mock_queue.put.call_args[0][0]
self.assertEqual(call_args.request_id, "test_id")
self.assertEqual(call_args.error_code, 400)
def test_run_control_method_non_callable_handler(self):
"""Test run_control_method with non-callable handler"""
# Add a non-callable attribute to worker
self.process.worker.some_method = "not_callable"
request = ControlRequest(request_id="test_id", method="some_method", args={})
self.process.run_control_method(request)
# Verify put was called with error response
self.mock_queue.put.assert_called_once()
call_args = self.mock_queue.put.call_args[0][0]
self.assertEqual(call_args.error_code, 400)
def test_run_control_method_success(self):
"""Test run_control_method with successful execution"""
# Add a callable method to worker
mock_result = {"result": "success"}
self.process.worker.test_method = Mock(return_value=mock_result)
request = ControlRequest(request_id="test_id", method="test_method", args={"param": "value"})
self.process.run_control_method(request)
# Verify handler was called with args
self.process.worker.test_method.assert_called_once_with(param="value")
# Verify put was called with success response
self.mock_queue.put.assert_called_once()
call_args = self.mock_queue.put.call_args[0][0]
self.assertEqual(call_args.request_id, "test_id")
self.assertEqual(call_args.error_code, 200)
def test_run_control_method_exception(self):
"""Test run_control_method with exception in handler"""
# Add a method that raises exception
def failing_method(**kwargs):
raise ValueError("Test error")
self.process.worker.test_method = failing_method
request = ControlRequest(request_id="test_id", method="test_method", args={})
with patch("fastdeploy.worker.worker_process.traceback") as mock_traceback:
mock_traceback.format_exc.return_value = "Traceback..."
self.process.run_control_method(request)
# Verify put was called with error response
self.mock_queue.put.assert_called_once()
call_args = self.mock_queue.put.call_args[0][0]
self.assertEqual(call_args.request_id, "test_id")
self.assertEqual(call_args.error_code, 500)
def test_run_control_directly_when_not_use_ep(self):
"""Test running control request directly when use_ep is disabled"""
self.process.parallel_config.use_ep = False
# Add a callable method to worker
self.process.worker.test_method = Mock(return_value={"result": "ok"})
control_req = ControlRequest(request_id="test_id", method="test_method", args={})
self.process.run_control_method(control_req)
# Verify handler was called
self.process.worker.test_method.assert_called_once()
# Verify put was called
self.mock_queue.put.assert_called_once()
@pytest.mark.skip("This case might hang in ci environment, to be fixed in the future")
def test_event_loop_caches_ep_control_requests_before_collective_run(self):
self.process.parallel_config.use_ep = True
self.process.parallel_config.ep_group = Mock(world_size=1)
self.process.cached_control_reqs = []
self.process._run_eplb = Mock()
self.process._tp_barrier_wait = Mock()
self.process.run_control_method = Mock()
self.process.worker_healthy_live_signal = Mock(value=[0])
self.process.max_chips_per_node = 8
self.process.nnode = 1
self.process.ranks = 1
self.process.task_queue = Mock()
self.process.task_queue.exist_tasks.return_value = False
self.process.task_queue.read_finish_flag = types.SimpleNamespace(get=Mock(return_value=1))
control_req = ControlRequest(request_id="ep-ctrl", method="pause", args={})
self.process.task_queue.get_tasks.return_value = ([([control_req], 1)], False)
self.process.exist_task_signal = types.SimpleNamespace(value=[1])
self.process.worker = types.SimpleNamespace(
preprocess_new_task=Mock(),
model_runner=types.SimpleNamespace(),
execute_model=Mock(),
exist_prefill=Mock(return_value=False),
)
with (
patch("fastdeploy.utils.all_gather_values", side_effect=SystemExit),
patch("fastdeploy.worker.worker_process.all_gather_values", side_effect=SystemExit),
):
with self.assertRaises(SystemExit):
self.process.event_loop_normal()
self.assertEqual(self.process.cached_control_reqs, [control_req])
self.process.run_control_method.assert_not_called()
def test_event_loop_skips_execute_model_when_runner_is_sleeping(self):
self.process.parallel_config.use_ep = False
self.process.parallel_config.tensor_parallel_size = 2
self.process.fd_config.load_config.dynamic_load_weight = True
self.process.cached_control_reqs = []
self.process._run_eplb = Mock()
self.process._tp_barrier_wait = Mock(side_effect=SystemExit)
self.process.worker_healthy_live_signal = Mock(value=[0])
self.process.max_chips_per_node = 8
self.process.nnode = 1
self.process.ranks = 1
self.process.local_rank = 0
self.process.task_queue = Mock()
self.process.task_queue.exist_tasks.return_value = False
self.process.task_queue.read_finish_flag = types.SimpleNamespace(get=Mock(return_value=0))
self.process.exist_task_signal = types.SimpleNamespace(value=[0])
self.process.worker = types.SimpleNamespace(
model_runner=types.SimpleNamespace(is_sleeping=True),
execute_model=Mock(),
exist_prefill=Mock(return_value=False),
)
with patch("fastdeploy.worker.worker_process.envs.FD_ENABLE_V1_UPDATE_WEIGHTS", "1"):
with self.assertRaises(SystemExit):
self.process.event_loop_normal()
self.process.worker.execute_model.assert_not_called()
if __name__ == "__main__":
unittest.main()