[RL] [KVCache] let cache transfer managers update key prefix after weight update and add unit tests (#7083)

* [test] add a few unit tests

* [feat] update key prefix when model weights are updated

* [test] try to fix test_worker_process
This commit is contained in:
Yonghua Li
2026-04-02 19:58:41 +08:00
committed by GitHub
parent 9f3b3ce7f5
commit 98f3fc9267
8 changed files with 636 additions and 11 deletions
+186 -4
View File
@@ -13,14 +13,22 @@
# limitations under the License.
import logging
import types
import unittest
from unittest.mock import AsyncMock, Mock, patch
import pytest
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import ControlRequest
from fastdeploy.worker.worker_process import PaddleDisWorkerProc
class TestInterceptPaddleLoggers(unittest.TestCase):
"""Test cases for intercept_paddle_loggers context manager from tools.logger_patch"""
def test_intercept_paddle_loggers_with_paddle_prefix(self):
"""Test intercept_paddle_loggers configures paddle loggers correctly (line 28-30)"""
"""Test intercept_paddle_loggers configures paddle loggers correctly"""
from fastdeploy.logger.logger import intercept_paddle_loggers
# Create a logger with existing handlers before interception
@@ -34,12 +42,12 @@ class TestInterceptPaddleLoggers(unittest.TestCase):
test_logger.addHandler(handler2)
self.assertEqual(len(test_logger.handlers), 2)
# Use the context manager to intercept paddle loggers
# Use context manager to intercept paddle loggers
with intercept_paddle_loggers():
# Get logger inside context - should be configured by interceptor
intercepted_logger = logging.getLogger(test_logger_name)
# Verify the logger was reconfigured by the interceptor
# Verify the logger was reconfigured by interceptor
self.assertEqual(len(intercepted_logger.handlers), 1)
self.assertIsInstance(intercepted_logger.handlers[0], logging.StreamHandler)
self.assertEqual(intercepted_logger.level, logging.INFO)
@@ -49,7 +57,7 @@ class TestInterceptPaddleLoggers(unittest.TestCase):
test_logger.handlers = []
def test_intercept_paddle_loggers_restores_original(self):
"""Test intercept_paddle_loggers restores original getLogger after exit (line 46)"""
"""Test intercept_paddle_loggers restores original getLogger after exit"""
from fastdeploy.logger.logger import intercept_paddle_loggers
# Store original getLogger before context
@@ -104,5 +112,179 @@ class TestInterceptPaddleLoggers(unittest.TestCase):
self.assertEqual(logging.getLogger, original_getLogger)
class TestWorkerProcessControlMethod(unittest.TestCase):
    """Exercise control-request handling of PaddleDisWorkerProc.

    The cases call ``run_control_method`` with unknown, non-callable,
    succeeding and raising handlers, and drive ``event_loop_normal`` in two
    scenarios. The worker process is created with ``__new__`` and every
    collaborator it touches is replaced by a mock or a SimpleNamespace.
    """

    def setUp(self):
        """Create a bare PaddleDisWorkerProc and attach mocked collaborators."""
        parallel_config = Mock()
        parallel_config.use_ep = False
        parallel_config.tensor_parallel_size = 1

        load_config = Mock()
        load_config.dynamic_load_weight = False

        self.mock_fd_config = Mock(spec=FDConfig)
        self.mock_fd_config.parallel_config = parallel_config
        self.mock_fd_config.load_config = load_config

        # __new__ bypasses __init__, so the instance only carries the
        # attributes assigned here (and those added by individual tests).
        proc = PaddleDisWorkerProc.__new__(PaddleDisWorkerProc)
        proc.fd_config = self.mock_fd_config
        proc.parallel_config = parallel_config
        proc.local_rank = 0
        proc.eplb_config = types.SimpleNamespace(enable_eplb=False)

        # An empty spec gives a worker with no attributes at all; each test
        # adds exactly the handler it needs.
        proc.worker = Mock(spec=[])

        # The tests inspect what gets passed to the queue's awaitable ``put``.
        self.mock_queue = Mock()
        self.mock_queue.put = AsyncMock()
        proc._ctrl_output = self.mock_queue

        self.process = proc

    def _single_response(self):
        """Assert ``put`` was called exactly once and return its first positional argument."""
        self.mock_queue.put.assert_called_once()
        return self.mock_queue.put.call_args.args[0]

    def test_run_control_method_unknown_handler(self):
        """A method name the worker does not define yields a single 400 response."""
        req = ControlRequest(request_id="test_id", method="unknown_method", args={})

        self.process.run_control_method(req)

        response = self._single_response()
        self.assertEqual(response.request_id, "test_id")
        self.assertEqual(response.error_code, 400)

    def test_run_control_method_non_callable_handler(self):
        """A worker attribute that is not callable yields a single 400 response."""
        # The attribute exists on the worker but is a plain string.
        self.process.worker.some_method = "not_callable"
        req = ControlRequest(request_id="test_id", method="some_method", args={})

        self.process.run_control_method(req)

        response = self._single_response()
        self.assertEqual(response.error_code, 400)

    def test_run_control_method_success(self):
        """A callable handler receives the request args as keywords and a 200 response is put."""
        handler = Mock(return_value={"result": "success"})
        self.process.worker.test_method = handler
        req = ControlRequest(request_id="test_id", method="test_method", args={"param": "value"})

        self.process.run_control_method(req)

        # The args dict must have been expanded into keyword arguments.
        handler.assert_called_once_with(param="value")

        response = self._single_response()
        self.assertEqual(response.request_id, "test_id")
        self.assertEqual(response.error_code, 200)

    def test_run_control_method_exception(self):
        """A handler that raises yields a single 500 response."""

        def raise_value_error(**kwargs):
            raise ValueError("Test error")

        self.process.worker.test_method = raise_value_error
        req = ControlRequest(request_id="test_id", method="test_method", args={})

        # The module's traceback is patched so format_exc returns a fixed string.
        with patch("fastdeploy.worker.worker_process.traceback") as traceback_mock:
            traceback_mock.format_exc.return_value = "Traceback..."
            self.process.run_control_method(req)

        response = self._single_response()
        self.assertEqual(response.request_id, "test_id")
        self.assertEqual(response.error_code, 500)

    def test_run_control_directly_when_not_use_ep(self):
        """With use_ep disabled, the handler runs once and one response is put."""
        self.process.parallel_config.use_ep = False
        handler = Mock(return_value={"result": "ok"})
        self.process.worker.test_method = handler
        req = ControlRequest(request_id="test_id", method="test_method", args={})

        self.process.run_control_method(req)

        handler.assert_called_once()
        self.mock_queue.put.assert_called_once()

    @pytest.mark.skip("This case might hang in ci environment, to be fixed in the future")
    def test_event_loop_caches_ep_control_requests_before_collective_run(self):
        """With use_ep enabled, a control request is cached rather than run.

        ``all_gather_values`` is patched to raise SystemExit; the test expects
        the request to already sit in ``cached_control_reqs`` at that point and
        ``run_control_method`` to be untouched.
        """
        proc = self.process
        control_req = ControlRequest(request_id="ep-ctrl", method="pause", args={})

        # Parallel configuration for the expert-parallel path.
        proc.parallel_config.use_ep = True
        proc.parallel_config.ep_group = Mock(world_size=1)

        # Process state read by the event loop.
        proc.cached_control_reqs = []
        proc.worker_healthy_live_signal = Mock(value=[0])
        proc.exist_task_signal = types.SimpleNamespace(value=[1])
        proc.max_chips_per_node = 8
        proc.nnode = 1
        proc.ranks = 1

        # Methods replaced so their invocation can be observed or neutralised.
        proc._run_eplb = Mock()
        proc._tp_barrier_wait = Mock()
        proc.run_control_method = Mock()

        # Task queue handing out exactly the control request above.
        task_queue = Mock()
        task_queue.exist_tasks.return_value = False
        task_queue.read_finish_flag = types.SimpleNamespace(get=Mock(return_value=1))
        task_queue.get_tasks.return_value = ([([control_req], 1)], False)
        proc.task_queue = task_queue

        proc.worker = types.SimpleNamespace(
            preprocess_new_task=Mock(),
            model_runner=types.SimpleNamespace(),
            execute_model=Mock(),
            exist_prefill=Mock(return_value=False),
        )

        with (
            patch("fastdeploy.utils.all_gather_values", side_effect=SystemExit),
            patch("fastdeploy.worker.worker_process.all_gather_values", side_effect=SystemExit),
            self.assertRaises(SystemExit),
        ):
            proc.event_loop_normal()

        self.assertEqual(proc.cached_control_reqs, [control_req])
        proc.run_control_method.assert_not_called()

    def test_event_loop_skips_execute_model_when_runner_is_sleeping(self):
        """With ``model_runner.is_sleeping`` set to True, ``execute_model`` is never called."""
        proc = self.process

        # Configuration for this scenario.
        proc.parallel_config.use_ep = False
        proc.parallel_config.tensor_parallel_size = 2
        proc.fd_config.load_config.dynamic_load_weight = True

        # Process state read by the event loop.
        proc.cached_control_reqs = []
        proc.worker_healthy_live_signal = Mock(value=[0])
        proc.exist_task_signal = types.SimpleNamespace(value=[0])
        proc.max_chips_per_node = 8
        proc.nnode = 1
        proc.ranks = 1
        proc.local_rank = 0

        proc._run_eplb = Mock()
        # The mocked barrier raises SystemExit, which the test expects to
        # propagate out of event_loop_normal (presumably ending the loop).
        proc._tp_barrier_wait = Mock(side_effect=SystemExit)

        task_queue = Mock()
        task_queue.exist_tasks.return_value = False
        task_queue.read_finish_flag = types.SimpleNamespace(get=Mock(return_value=0))
        proc.task_queue = task_queue

        proc.worker = types.SimpleNamespace(
            model_runner=types.SimpleNamespace(is_sleeping=True),
            execute_model=Mock(),
            exist_prefill=Mock(return_value=False),
        )

        with patch(
            "fastdeploy.worker.worker_process.envs.FD_ENABLE_V1_UPDATE_WEIGHTS", "1"
        ), self.assertRaises(SystemExit):
            proc.event_loop_normal()

        proc.worker.execute_model.assert_not_called()
if __name__ == "__main__":
    # Run the test cases in this module when it is executed as a script.
    unittest.main()