[RL] [KVCache] let cache transfer managers update key prefix after weight update and add unit tests (#7083)

* [test] add a few unit tests

* [feat] update key prefix when model weights are updated

* [test] try to fix test_worker_process
This commit is contained in:
Yonghua Li
2026-04-02 19:58:41 +08:00
committed by GitHub
parent 9f3b3ce7f5
commit 98f3fc9267
8 changed files with 636 additions and 11 deletions
@@ -17,7 +17,7 @@ import sys
import tempfile
import time
import unittest
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, Mock, patch
import paddle
@@ -37,6 +37,7 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..",
import fastdeploy.cache_manager.cache_transfer_manager as cache_transfer_manager
from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask
from fastdeploy.cache_manager.cache_transfer_manager import CacheTransferManager
from fastdeploy.engine.request import ControlRequest
# ==========================
@@ -121,6 +122,16 @@ class TestCacheTransferManager(unittest.TestCase):
patcher_thread.start()
self.addCleanup(patcher_thread.stop)
# --------------------------
# mock FMQ
# --------------------------
patcher_fmq = patch("fastdeploy.cache_manager.cache_transfer_manager.FMQ")
mock_fmq_cls = patcher_fmq.start()
mock_fmq = MagicMock()
mock_fmq.queue.return_value = MagicMock(name="ctrl_output_queue")
mock_fmq_cls.return_value = mock_fmq
self.addCleanup(patcher_fmq.stop)
# --------------------------
# mock _init_cpu_cache and _init_gpu_cache
# --------------------------
@@ -1515,6 +1526,111 @@ class TestCacheTransferManager(unittest.TestCase):
self.assertFalse(self.manager.is_paused)
def test_init_control_builds_expected_queue_name(self):
    """``_init_control`` should open the control queue as a producer under the expected name.

    With rank=1, n_ranks=4, local_data_parallel_id=2 and port 8899 the test
    expects the queue name ``ctrl_c2e_rank9_8899`` and expects the queue
    returned by FMQ to be stored on ``ctrl_output_queue``.
    """
    # NOTE(review): the "9" presumably derives from rank + n_ranks * local_data_parallel_id;
    # confirm against _init_control.
    mgr = self.manager
    mgr.rank, mgr.n_ranks = 1, 4
    mgr.local_data_parallel_id = 2
    mgr.cache_queue_port = 8899

    ctrl_queue = MagicMock(name="ctrl_q")
    fmq_instance = MagicMock()
    fmq_instance.queue.return_value = ctrl_queue

    fmq_target = "fastdeploy.cache_manager.cache_transfer_manager.FMQ"
    with patch(fmq_target, return_value=fmq_instance):
        mgr._init_control()

    fmq_instance.queue.assert_called_once_with("ctrl_c2e_rank9_8899", "producer")
    self.assertIs(mgr.ctrl_output_queue, ctrl_queue)
def test_control_task_success_puts_control_response(self):
    """A successful ``pause`` control task waits on the barrier and emits a 200 response."""
    mgr = self.manager
    mgr.cache_task_queue.barrier = MagicMock(wait=Mock())

    ctrl_queue = MagicMock(name="ctrl_q")
    ctrl_queue.put = Mock(return_value="coro")
    mgr.ctrl_output_queue = ctrl_queue
    mgr._handle_pause = MagicMock(return_value=True)

    request = ControlRequest(request_id="ctrl-1", method="pause")
    # asyncio.run is patched out, so the value returned by put() is never actually executed.
    with patch("fastdeploy.cache_manager.cache_transfer_manager.asyncio.run"):
        mgr.control_task(request)

    mgr._handle_pause.assert_called_once()
    mgr.cache_task_queue.barrier.wait.assert_called_once()
    ctrl_queue.put.assert_called_once()

    response = ctrl_queue.put.call_args[0][0]
    self.assertEqual(response.request_id, "ctrl-1")
    self.assertEqual(response.error_code, 200)
def test_control_task_unknown_method_returns_400(self):
    """An unrecognised control method must be answered with a 400 response.

    The response is read from the first positional argument passed to
    ``ctrl_output_queue.put``.
    """
    self.manager.cache_task_queue.barrier = MagicMock(wait=Mock())
    self.manager.ctrl_output_queue = MagicMock(name="ctrl_q")
    self.manager.ctrl_output_queue.put = Mock(return_value="coro")

    with patch("fastdeploy.cache_manager.cache_transfer_manager.asyncio.run"):
        self.manager.control_task(ControlRequest(request_id="ctrl-2", method="unknown"))

    # Guard before dereferencing call_args: if no response was queued, call_args is None
    # and the test would die with an AttributeError instead of a readable assertion.
    self.manager.ctrl_output_queue.put.assert_called()
    response = self.manager.ctrl_output_queue.put.call_args.args[0]
    self.assertEqual(response.error_code, 400)
    self.assertIn("Unknown control method", response.error_message)
def test_control_task_exception_returns_500(self):
    """A handler that raises must be reported as a 500 response rather than propagating.

    ``_handle_sleep`` is patched to raise ``RuntimeError("boom")``; the test
    then inspects the response handed to ``ctrl_output_queue.put``.
    """
    self.manager.cache_task_queue.barrier = MagicMock(wait=Mock())
    self.manager.ctrl_output_queue = MagicMock(name="ctrl_q")
    self.manager.ctrl_output_queue.put = Mock(return_value="coro")

    with (
        patch.object(self.manager, "_handle_sleep", side_effect=RuntimeError("boom")),
        patch("fastdeploy.cache_manager.cache_transfer_manager.asyncio.run"),
    ):
        self.manager.control_task(ControlRequest(request_id="ctrl-3", method="sleep"))

    # Guard before dereferencing call_args: if no response was queued, call_args is None
    # and the test would die with an AttributeError instead of a readable assertion.
    self.manager.ctrl_output_queue.put.assert_called()
    response = self.manager.ctrl_output_queue.put.call_args.args[0]
    self.assertEqual(response.error_code, 500)
    self.assertIn("Failed to execute sleep", response.error_message)
def test_handle_resume_updates_key_prefix_for_storage_backend(self):
    """Resuming a paused manager with a storage backend resumes and refreshes the key prefix."""
    mgr = self.manager
    mgr.is_paused = True
    mgr.storage_backend_type = "mooncake"
    mgr.resume = MagicMock()
    mgr._update_key_prefix = MagicMock()

    self.assertTrue(mgr._handle_resume())

    for expected_call in (mgr.resume, mgr._update_key_prefix):
        expected_call.assert_called_once()
def test_handle_update_weights_updates_key_prefix_for_storage_backend(self):
    """With a storage backend configured, ``_handle_update_weights`` refreshes the key prefix once."""
    mgr = self.manager
    mgr.storage_backend_type = "mooncake"
    mgr._update_key_prefix = MagicMock()

    outcome = mgr._handle_update_weights()

    self.assertTrue(outcome)
    mgr._update_key_prefix.assert_called_once()
def test_handle_update_weights_skips_without_storage_backend(self):
    """Without a storage backend, ``_handle_update_weights`` still succeeds but leaves the key prefix alone."""
    mgr = self.manager
    mgr.storage_backend_type = None
    mgr._update_key_prefix = MagicMock()

    outcome = mgr._handle_update_weights()

    self.assertTrue(outcome)
    mgr._update_key_prefix.assert_not_called()
def test_handle_sleep_and_wakeup_are_idempotent(self):
    """Sleeping while already asleep, or waking while already awake, touches no cache."""
    mgr = self.manager
    mgr.is_sleeping = True
    for hook in ("_clear_cpu_cache", "_clear_gpu_cache", "_init_cpu_cache", "_init_gpu_cache"):
        setattr(mgr, hook, MagicMock())

    # Already sleeping: the handler reports success and clears nothing.
    self.assertTrue(mgr._handle_sleep())
    mgr._clear_cpu_cache.assert_not_called()
    mgr._clear_gpu_cache.assert_not_called()

    # Already awake: the handler reports success and re-initialises nothing.
    mgr.is_sleeping = False
    self.assertTrue(mgr._handle_wakeup())
    mgr._init_cpu_cache.assert_not_called()
    mgr._init_gpu_cache.assert_not_called()
def test_submit_task_decrements_inflight_on_task_error(self):
class DummyPool:
def submit(self, fn, *args):
+22
View File
@@ -39,6 +39,7 @@ if not hasattr(paddle, "compat"):
paddle.compat = _PaddleCompat()
from fastdeploy.cache_manager.cache_data import CacheStatus
from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.common_engine import (
EngineService,
@@ -1117,6 +1118,27 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
self.assertEqual(eng.cfg.model_config.version, "new-version")
self._detach_finalizer(eng)
def test_control_update_weights_updates_cache_transfer_metadata(self):
    """``_control_update_weights`` must forward an ``update_weights`` control task to the cache queue.

    The engine is set up as paused with one CPU block. The test checks the
    worker result is returned unchanged, that a ``(CacheStatus.CTRL, request)``
    payload is put on ``cache_task_queue``, and that the engine awaits the
    ``cache_transfer`` executor's response for that request with a timeout of 60.
    """
    eng = self._make_mixed_engine()
    eng.is_paused = True
    eng._pause_cond = threading.Condition()
    eng.cfg.cache_config.num_cpu_blocks = 1
    eng._call_worker = Mock(return_value=[{"version": "new-version"}])
    eng.cache_task_queue = Mock(put_transfer_task=Mock())
    eng._wait_for_control_responses = AsyncMock(return_value=[{"ok": True}])

    result = eng._control_update_weights(ControlRequest(request_id="ctrl", method="update_weights"))

    self.assertEqual(result, [{"version": "new-version"}])

    # Guard before dereferencing call_args: if nothing was queued, call_args is None
    # and the test would die with an AttributeError instead of a readable assertion.
    eng.cache_task_queue.put_transfer_task.assert_called()
    payload = eng.cache_task_queue.put_transfer_task.call_args.args[0]
    self.assertEqual(payload[0], CacheStatus.CTRL)
    self.assertEqual(payload[1].method, "update_weights")
    self.assertIn("update_weights", payload[1].request_id)
    eng._wait_for_control_responses.assert_awaited_once_with(
        payload[1].request_id, 60, executors=["cache_transfer"]
    )
    self._detach_finalizer(eng)
def test_control_pause_and_resume_paths(self):
eng = self._make_mixed_engine()
eng.is_paused = False
+60 -1
View File
@@ -25,7 +25,7 @@ import numpy as np
import paddle
import pytest
from fastdeploy.engine.request import ControlRequest
from fastdeploy.engine.request import ControlRequest, ControlResponse
from fastdeploy.entrypoints.engine_client import EngineClient
from fastdeploy.inter_communicator import (
KVCacheStatus,
@@ -1882,6 +1882,65 @@ def test_valid_parameters_and_control_timeout(minimal_engine_client):
assert resp.error_code == 500
def test_run_control_method_uses_send_pyobj_for_mm_requests(minimal_engine_client):
    """With ``enable_mm`` set, the control request goes out via ``send_pyobj`` and not ``send_json``."""
    client = minimal_engine_client

    # Pre-load the response queue so run_control_method finds its reply immediately.
    response_queue = asyncio.Queue()
    asyncio.run(response_queue.put(({"request_id": "mm-1", "status": 200, "msg": "ok"},)))
    dealer = Mock(write=Mock())

    client.enable_mm = True
    client.connection_manager = MagicMock(get_connection=AsyncMock(return_value=(dealer, response_queue)))

    request = ControlRequest(request_id="mm-1", method="ping")
    with patch("fastdeploy.entrypoints.engine_client.envs.ZMQ_SEND_BATCH_DATA", 0):
        resp = asyncio.run(client.run_control_method(request))

    assert resp.error_code == 200
    client.zmq_client.send_pyobj.assert_called_once()
    client.zmq_client.send_json.assert_not_called()
def test_run_control_method_adds_worker_pid_in_batch_mode(minimal_engine_client):
    """With ``ZMQ_SEND_BATCH_DATA`` patched to 1, the JSON payload must carry the client's worker pid.

    The payload is read from the first positional argument passed to
    ``zmq_client.send_json``.
    """
    queue = asyncio.Queue()
    asyncio.run(queue.put(({"request_id": "batch-1", "status": 200, "msg": "ok"},)))
    minimal_engine_client.connection_manager = MagicMock(get_connection=AsyncMock(return_value=(None, queue)))

    with patch("fastdeploy.entrypoints.engine_client.envs.ZMQ_SEND_BATCH_DATA", 1):
        resp = asyncio.run(
            minimal_engine_client.run_control_method(ControlRequest(request_id="batch-1", method="ping"))
        )

    assert resp.error_code == 200

    # Guard before dereferencing call_args: if send_json was never used, call_args is None
    # and the test would die with an AttributeError instead of a readable assertion.
    minimal_engine_client.zmq_client.send_json.assert_called()
    payload = minimal_engine_client.zmq_client.send_json.call_args.args[0]
    assert payload["zmq_worker_pid"] == minimal_engine_client.worker_pid
def test_run_control_method_generic_exception_returns_error(minimal_engine_client):
    """A failure while reading the response queue is turned into a 500 response carrying the message."""
    client = minimal_engine_client

    failing_queue = MagicMock()
    failing_queue.get = AsyncMock(side_effect=RuntimeError("queue failed"))
    dealer = Mock(write=Mock())
    client.connection_manager = MagicMock(get_connection=AsyncMock(return_value=(dealer, failing_queue)))

    request = ControlRequest(request_id="r3", method="m")
    with patch("fastdeploy.entrypoints.engine_client.envs.ZMQ_SEND_BATCH_DATA", 0):
        resp = asyncio.run(client.run_control_method(request))

    assert resp.error_code == 500
    assert "queue failed" in resp.error_message
def test_run_control_method_sync_uses_threadsafe_bridge(minimal_engine_client):
    """``run_control_method_sync`` should go through ``asyncio.run_coroutine_threadsafe`` exactly once."""
    client = minimal_engine_client
    request = ControlRequest(request_id="sync-1", method="ping")

    bridged_response = ControlResponse("sync-1", 200, "Success")
    future = Mock(result=Mock(return_value=bridged_response))
    client.run_control_method = AsyncMock(return_value=ControlResponse("sync-1", 200, "Success"))

    bridge_target = "fastdeploy.entrypoints.engine_client.asyncio.run_coroutine_threadsafe"
    with patch(bridge_target, return_value=future) as mock_run:
        resp = client.run_control_method_sync(request, Mock())

    assert resp.error_code == 200
    mock_run.assert_called_once()
    # The first positional argument handed to the patched bridge is closed explicitly;
    # presumably it is the un-awaited coroutine and closing it avoids a "never awaited" warning.
    mock_run.call_args.args[0].close()
def test_rearrange_and_redundant_branch_matrix(minimal_engine_client):
cfg = create_mock_fd_config(enable_eplb=True)
cfg.parallel_config.tensor_parallel_rank = 0
+116 -1
View File
@@ -14,12 +14,13 @@
import unittest
from dataclasses import dataclass
from unittest.mock import Mock
from unittest.mock import Mock, patch
import numpy as np
import paddle
from fastdeploy.engine.request import ImagePosition
from fastdeploy.spec_decode import SpecMethod
from fastdeploy.worker.gpu_model_runner import GPUModelRunner
from fastdeploy.worker.input_batch import InputBatch
@@ -476,5 +477,119 @@ class TestProcessMMFeatures(unittest.TestCase):
)
class TestSleepWakeupBehavior(unittest.TestCase):
    """Exercise ``GPUModelRunner.sleep`` and ``GPUModelRunner.wakeup`` on a runner assembled from mocks."""

    def _make_runner(self):
        """Return a ``GPUModelRunner`` created without ``__init__`` and wired with mock collaborators.

        The runner starts awake (neither weights nor KV cache sleeping), with
        CUDA graph, speculative decoding, expert parallelism and idle
        comm-group shutdown all disabled; individual tests flip what they need.
        """
        runner = GPUModelRunner.__new__(GPUModelRunner)

        plain_state = {
            "is_weight_sleeping": False,
            "is_kvcache_sleeping": False,
            "use_cudagraph": False,
            "spec_method": None,
            "local_rank": 0,
            "device_id": 1,
            "num_gpu_blocks": 8,
        }
        for attr, value in plain_state.items():
            setattr(runner, attr, value)

        # "clear_grpah_opt_backend" is spelled exactly as the attribute asserted in
        # test_sleep_offloads_weight_and_cache; do not "correct" it here alone.
        runner.model = Mock(clear_grpah_opt_backend=Mock())
        runner.clear_cache = Mock()
        runner.initialize_kv_cache = Mock()
        runner.capture_model = Mock()
        runner.share_inputs = Mock(reset_share_inputs=Mock())

        weight_manager_methods = (
            "clear_deepep_buffer",
            "clear_model_weight",
            "clear_communication_group",
            "restart_communication_group",
            "recreate_deepep_buffer",
            "reload_model_weights",
        )
        runner.dynamic_weight_manager = Mock(**{method: Mock() for method in weight_manager_methods})

        runner.fd_config = Mock()
        runner.fd_config.parallel_config = Mock(
            enable_expert_parallel=False,
            shutdown_comm_group_if_worker_idle=False,
        )

        runner.proposer = Mock(
            clear_mtp_cache=Mock(),
            initialize_kv_cache=Mock(),
            model_inputs=Mock(reset_model_inputs=Mock()),
        )
        return runner

    @patch("fastdeploy.worker.gpu_model_runner.print_gpu_memory_use")
    @patch("paddle.device.cuda.empty_cache")
    def test_sleep_offloads_weight_and_cache(self, mock_empty_cache, mock_print_memory):
        """Sleeping both tags releases graph backend, weights, comm group, MTP cache and KV cache."""
        runner = self._make_runner()
        runner.use_cudagraph = True
        runner.spec_method = SpecMethod.MTP
        parallel_config = runner.fd_config.parallel_config
        parallel_config.enable_expert_parallel = True
        parallel_config.shutdown_comm_group_if_worker_idle = True

        runner.sleep("weight,kv_cache")

        weight_manager = runner.dynamic_weight_manager
        released_once = (
            runner.model.clear_grpah_opt_backend,
            weight_manager.clear_deepep_buffer,
            weight_manager.clear_model_weight,
            weight_manager.clear_communication_group,
            runner.proposer.clear_mtp_cache,
            runner.clear_cache,
        )
        for release in released_once:
            release.assert_called_once()

        self.assertTrue(runner.is_weight_sleeping)
        self.assertTrue(runner.is_kvcache_sleeping)
        mock_empty_cache.assert_called_once()
        mock_print_memory.assert_called_once()

    @patch("fastdeploy.worker.gpu_model_runner.print_gpu_memory_use")
    @patch("paddle.device.cuda.empty_cache")
    def test_sleep_weight_is_idempotent(self, mock_empty_cache, mock_print_memory):
        """Sleeping weights that are already asleep does nothing."""
        runner = self._make_runner()
        runner.is_weight_sleeping = True

        runner.sleep("weight")

        untouched = (
            runner.dynamic_weight_manager.clear_model_weight,
            runner.clear_cache,
            mock_empty_cache,
            mock_print_memory,
        )
        for mocked in untouched:
            mocked.assert_not_called()

    def test_wakeup_rejects_weight_only_when_cudagraph_requires_kvcache(self):
        """With CUDA graph on and the KV cache asleep, a weight-only wakeup raises ``RuntimeError``."""
        runner = self._make_runner()
        runner.use_cudagraph = True
        runner.is_kvcache_sleeping = True

        with self.assertRaises(RuntimeError):
            runner.wakeup("weight")

    @patch("fastdeploy.worker.gpu_model_runner.print_gpu_memory_use")
    def test_wakeup_restores_weight_and_cache(self, mock_print_memory):
        """Waking both tags restores inputs, KV caches, comm group, weights and the captured model."""
        runner = self._make_runner()
        runner.use_cudagraph = True
        runner.spec_method = SpecMethod.MTP
        runner.is_weight_sleeping = True
        runner.is_kvcache_sleeping = True
        parallel_config = runner.fd_config.parallel_config
        parallel_config.enable_expert_parallel = True
        parallel_config.shutdown_comm_group_if_worker_idle = True

        runner.wakeup("weight,kv_cache")

        runner.proposer.model_inputs.reset_model_inputs.assert_called_once()
        runner.share_inputs.reset_share_inputs.assert_called_once()
        runner.proposer.initialize_kv_cache.assert_called_once_with(main_model_num_blocks=runner.num_gpu_blocks)
        runner.initialize_kv_cache.assert_called_once()

        weight_manager = runner.dynamic_weight_manager
        weight_manager.restart_communication_group.assert_called_once()
        weight_manager.recreate_deepep_buffer.assert_called_once()
        weight_manager.reload_model_weights.assert_called_once()
        runner.capture_model.assert_called_once()

        self.assertFalse(runner.is_weight_sleeping)
        self.assertFalse(runner.is_kvcache_sleeping)
        mock_print_memory.assert_called_once()

    @patch("fastdeploy.worker.gpu_model_runner.print_gpu_memory_use")
    def test_wakeup_kvcache_is_idempotent(self, mock_print_memory):
        """Waking a KV cache that is not asleep does nothing."""
        runner = self._make_runner()
        runner.is_kvcache_sleeping = False

        runner.wakeup("kv_cache")

        runner.initialize_kv_cache.assert_not_called()
        runner.dynamic_weight_manager.reload_model_weights.assert_not_called()
        mock_print_memory.assert_not_called()
# Allow this test module to be executed directly as a script.
if __name__ == "__main__":
    unittest.main()
+99
View File
@@ -0,0 +1,99 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from unittest.mock import Mock
from fastdeploy.config import FDConfig
from fastdeploy.worker.gpu_worker import GpuWorker
class TestGpuWorkerSleepWakeup(unittest.TestCase):
    """Check that ``GpuWorker.sleep`` and ``GpuWorker.wakeup`` forward their arguments to ``model_runner``."""

    def setUp(self):
        """Build a mock ``FDConfig`` whose parallel config reports a tensor-parallel size of 1."""
        parallel_config = Mock()
        parallel_config.tensor_parallel_size = 1
        self.mock_fd_config = Mock(spec=FDConfig)
        self.mock_fd_config.parallel_config = parallel_config

    @staticmethod
    def _bare_worker():
        """Return a ``GpuWorker`` created without ``__init__`` whose ``model_runner`` is a ``Mock``."""
        worker = GpuWorker.__new__(GpuWorker)
        worker.model_runner = Mock()
        return worker

    def test_sleep_delegates_to_model_runner(self):
        """``sleep`` hands a single tag straight to ``model_runner.sleep``."""
        worker = self._bare_worker()

        worker.sleep(tags="weight")

        worker.model_runner.sleep.assert_called_once_with(tags="weight")

    def test_sleep_with_multiple_tags(self):
        """``sleep`` forwards a comma-separated tag string unchanged."""
        worker = self._bare_worker()

        worker.sleep(tags="weight,kv_cache")

        worker.model_runner.sleep.assert_called_once_with(tags="weight,kv_cache")

    def test_sleep_with_kwargs(self):
        """``sleep`` forwards extra keyword arguments to ``model_runner.sleep``."""
        worker = self._bare_worker()

        worker.sleep(tags="weight", force=True, timeout=100)

        worker.model_runner.sleep.assert_called_once_with(tags="weight", force=True, timeout=100)

    def test_wakeup_delegates_to_model_runner(self):
        """``wakeup`` hands a single tag straight to ``model_runner.wakeup``."""
        worker = self._bare_worker()

        worker.wakeup(tags="weight")

        worker.model_runner.wakeup.assert_called_once_with(tags="weight")

    def test_wakeup_with_multiple_tags(self):
        """``wakeup`` forwards a comma-separated tag string unchanged."""
        worker = self._bare_worker()

        worker.wakeup(tags="weight,kv_cache")

        worker.model_runner.wakeup.assert_called_once_with(tags="weight,kv_cache")

    def test_wakeup_with_kwargs(self):
        """``wakeup`` forwards extra keyword arguments to ``model_runner.wakeup``."""
        worker = self._bare_worker()

        worker.wakeup(tags="kv_cache", async_load=True)

        worker.model_runner.wakeup.assert_called_once_with(tags="kv_cache", async_load=True)
# Allow this test module to be executed directly as a script.
if __name__ == "__main__":
    unittest.main()
+186 -4
View File
@@ -13,14 +13,22 @@
# limitations under the License.
import logging
import types
import unittest
from unittest.mock import AsyncMock, Mock, patch
import pytest
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import ControlRequest
from fastdeploy.worker.worker_process import PaddleDisWorkerProc
class TestInterceptPaddleLoggers(unittest.TestCase):
"""Test cases for intercept_paddle_loggers context manager from tools.logger_patch"""
def test_intercept_paddle_loggers_with_paddle_prefix(self):
"""Test intercept_paddle_loggers configures paddle loggers correctly (line 28-30)"""
"""Test intercept_paddle_loggers configures paddle loggers correctly"""
from fastdeploy.logger.logger import intercept_paddle_loggers
# Create a logger with existing handlers before interception
@@ -34,12 +42,12 @@ class TestInterceptPaddleLoggers(unittest.TestCase):
test_logger.addHandler(handler2)
self.assertEqual(len(test_logger.handlers), 2)
# Use the context manager to intercept paddle loggers
# Use context manager to intercept paddle loggers
with intercept_paddle_loggers():
# Get logger inside context - should be configured by interceptor
intercepted_logger = logging.getLogger(test_logger_name)
# Verify the logger was reconfigured by the interceptor
# Verify the logger was reconfigured by interceptor
self.assertEqual(len(intercepted_logger.handlers), 1)
self.assertIsInstance(intercepted_logger.handlers[0], logging.StreamHandler)
self.assertEqual(intercepted_logger.level, logging.INFO)
@@ -49,7 +57,7 @@ class TestInterceptPaddleLoggers(unittest.TestCase):
test_logger.handlers = []
def test_intercept_paddle_loggers_restores_original(self):
"""Test intercept_paddle_loggers restores original getLogger after exit (line 46)"""
"""Test intercept_paddle_loggers restores original getLogger after exit"""
from fastdeploy.logger.logger import intercept_paddle_loggers
# Store original getLogger before context
@@ -104,5 +112,179 @@ class TestInterceptPaddleLoggers(unittest.TestCase):
self.assertEqual(logging.getLogger, original_getLogger)
class TestWorkerProcessControlMethod(unittest.TestCase):
    """Test ``PaddleDisWorkerProc`` control-request handling and related event-loop branches."""

    # Reason shared by both skip decorators on the EP caching test.
    _EP_SKIP_REASON = "This case might hang in ci environment, to be fixed in the future"

    def setUp(self):
        """Create a ``PaddleDisWorkerProc`` without ``__init__`` and attach mock collaborators.

        The process is configured with EP and dynamic weight loading disabled,
        a tensor-parallel size of 1, EPLB disabled, and an async-mock control
        output queue that tests inspect for the emitted response.
        """
        self.mock_fd_config = Mock(spec=FDConfig)
        self.mock_fd_config.parallel_config = Mock()
        self.mock_fd_config.parallel_config.use_ep = False
        self.mock_fd_config.parallel_config.tensor_parallel_size = 1
        self.mock_fd_config.load_config = Mock()
        self.mock_fd_config.load_config.dynamic_load_weight = False

        self.process = PaddleDisWorkerProc.__new__(PaddleDisWorkerProc)
        self.process.fd_config = self.mock_fd_config
        self.process.parallel_config = self.mock_fd_config.parallel_config
        self.process.local_rank = 0
        self.process.eplb_config = types.SimpleNamespace(enable_eplb=False)

        # An empty spec stops the mock from auto-creating attributes, so a method
        # only "exists" on the worker when a test assigns it explicitly.
        self.process.worker = Mock(spec=[])

        # Control responses are written through an awaitable put().
        self.mock_queue = Mock()
        self.mock_queue.put = AsyncMock()
        self.process._ctrl_output = self.mock_queue

    def test_run_control_method_unknown_handler(self):
        """A method the worker does not define is answered with a 400 response."""
        request = ControlRequest(request_id="test_id", method="unknown_method", args={})

        self.process.run_control_method(request)

        self.mock_queue.put.assert_called_once()
        call_args = self.mock_queue.put.call_args[0][0]
        self.assertEqual(call_args.request_id, "test_id")
        self.assertEqual(call_args.error_code, 400)

    def test_run_control_method_non_callable_handler(self):
        """A worker attribute that is not callable is answered with a 400 response."""
        self.process.worker.some_method = "not_callable"
        request = ControlRequest(request_id="test_id", method="some_method", args={})

        self.process.run_control_method(request)

        self.mock_queue.put.assert_called_once()
        call_args = self.mock_queue.put.call_args[0][0]
        self.assertEqual(call_args.error_code, 400)

    def test_run_control_method_success(self):
        """A callable handler receives the request args as keywords and yields a 200 response."""
        mock_result = {"result": "success"}
        self.process.worker.test_method = Mock(return_value=mock_result)
        request = ControlRequest(request_id="test_id", method="test_method", args={"param": "value"})

        self.process.run_control_method(request)

        self.process.worker.test_method.assert_called_once_with(param="value")
        self.mock_queue.put.assert_called_once()
        call_args = self.mock_queue.put.call_args[0][0]
        self.assertEqual(call_args.request_id, "test_id")
        self.assertEqual(call_args.error_code, 200)

    def test_run_control_method_exception(self):
        """A handler that raises is reported as a 500 response."""

        def failing_method(**kwargs):
            raise ValueError("Test error")

        self.process.worker.test_method = failing_method
        request = ControlRequest(request_id="test_id", method="test_method", args={})

        with patch("fastdeploy.worker.worker_process.traceback") as mock_traceback:
            mock_traceback.format_exc.return_value = "Traceback..."
            self.process.run_control_method(request)

        self.mock_queue.put.assert_called_once()
        call_args = self.mock_queue.put.call_args[0][0]
        self.assertEqual(call_args.request_id, "test_id")
        self.assertEqual(call_args.error_code, 500)

    def test_run_control_directly_when_not_use_ep(self):
        """With ``use_ep`` disabled, the handler runs and exactly one response is queued."""
        self.process.parallel_config.use_ep = False
        self.process.worker.test_method = Mock(return_value={"result": "ok"})
        control_req = ControlRequest(request_id="test_id", method="test_method", args={})

        self.process.run_control_method(control_req)

        self.process.worker.test_method.assert_called_once()
        self.mock_queue.put.assert_called_once()

    # unittest.skip is honoured both by pytest and by the unittest.main() entry point at the
    # bottom of this module; a pytest mark alone is ignored when the file is run directly,
    # which would execute a test that may hang.
    @unittest.skip(_EP_SKIP_REASON)
    @pytest.mark.skip(_EP_SKIP_REASON)
    def test_event_loop_caches_ep_control_requests_before_collective_run(self):
        """With ``use_ep`` on, a control request from the task queue is cached, not run immediately.

        ``all_gather_values`` is patched to raise ``SystemExit`` so the event
        loop stops at that point.
        """
        self.process.parallel_config.use_ep = True
        self.process.parallel_config.ep_group = Mock(world_size=1)
        self.process.cached_control_reqs = []
        self.process._run_eplb = Mock()
        self.process._tp_barrier_wait = Mock()
        self.process.run_control_method = Mock()
        self.process.worker_healthy_live_signal = Mock(value=[0])
        self.process.max_chips_per_node = 8
        self.process.nnode = 1
        self.process.ranks = 1
        self.process.task_queue = Mock()
        self.process.task_queue.exist_tasks.return_value = False
        self.process.task_queue.read_finish_flag = types.SimpleNamespace(get=Mock(return_value=1))
        control_req = ControlRequest(request_id="ep-ctrl", method="pause", args={})
        self.process.task_queue.get_tasks.return_value = ([([control_req], 1)], False)
        self.process.exist_task_signal = types.SimpleNamespace(value=[1])
        self.process.worker = types.SimpleNamespace(
            preprocess_new_task=Mock(),
            model_runner=types.SimpleNamespace(),
            execute_model=Mock(),
            exist_prefill=Mock(return_value=False),
        )

        with (
            patch("fastdeploy.utils.all_gather_values", side_effect=SystemExit),
            patch("fastdeploy.worker.worker_process.all_gather_values", side_effect=SystemExit),
        ):
            with self.assertRaises(SystemExit):
                self.process.event_loop_normal()

        self.assertEqual(self.process.cached_control_reqs, [control_req])
        self.process.run_control_method.assert_not_called()

    def test_event_loop_skips_execute_model_when_runner_is_sleeping(self):
        """While the model runner reports ``is_sleeping``, the event loop must not call ``execute_model``.

        ``_tp_barrier_wait`` raises ``SystemExit`` so the loop terminates.
        """
        self.process.parallel_config.use_ep = False
        self.process.parallel_config.tensor_parallel_size = 2
        self.process.fd_config.load_config.dynamic_load_weight = True
        self.process.cached_control_reqs = []
        self.process._run_eplb = Mock()
        self.process._tp_barrier_wait = Mock(side_effect=SystemExit)
        self.process.worker_healthy_live_signal = Mock(value=[0])
        self.process.max_chips_per_node = 8
        self.process.nnode = 1
        self.process.ranks = 1
        self.process.local_rank = 0
        self.process.task_queue = Mock()
        self.process.task_queue.exist_tasks.return_value = False
        self.process.task_queue.read_finish_flag = types.SimpleNamespace(get=Mock(return_value=0))
        self.process.exist_task_signal = types.SimpleNamespace(value=[0])
        self.process.worker = types.SimpleNamespace(
            model_runner=types.SimpleNamespace(is_sleeping=True),
            execute_model=Mock(),
            exist_prefill=Mock(return_value=False),
        )

        with patch("fastdeploy.worker.worker_process.envs.FD_ENABLE_V1_UPDATE_WEIGHTS", "1"):
            with self.assertRaises(SystemExit):
                self.process.event_loop_normal()

        self.process.worker.execute_model.assert_not_called()
# Allow this test module to be executed directly as a script.
if __name__ == "__main__":
    unittest.main()