[RL] [KVCache] let cache transfer managers update key prefix after weight update and add unit tests (#7083)

* [test] add a few unit tests

* [feat] update key prefix when model weights are updated

* [test] try to fix test_worker_process
This commit is contained in:
Yonghua Li
2026-04-02 19:58:41 +08:00
committed by GitHub
parent 9f3b3ce7f5
commit 98f3fc9267
8 changed files with 636 additions and 11 deletions
@@ -17,7 +17,7 @@ import sys
import tempfile
import time
import unittest
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, Mock, patch
import paddle
@@ -37,6 +37,7 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..",
import fastdeploy.cache_manager.cache_transfer_manager as cache_transfer_manager
from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask
from fastdeploy.cache_manager.cache_transfer_manager import CacheTransferManager
from fastdeploy.engine.request import ControlRequest
# ==========================
@@ -121,6 +122,16 @@ class TestCacheTransferManager(unittest.TestCase):
patcher_thread.start()
self.addCleanup(patcher_thread.stop)
# --------------------------
# mock FMQ
# --------------------------
patcher_fmq = patch("fastdeploy.cache_manager.cache_transfer_manager.FMQ")
mock_fmq_cls = patcher_fmq.start()
mock_fmq = MagicMock()
mock_fmq.queue.return_value = MagicMock(name="ctrl_output_queue")
mock_fmq_cls.return_value = mock_fmq
self.addCleanup(patcher_fmq.stop)
# --------------------------
# mock _init_cpu_cache 和 _init_gpu_cache
# --------------------------
@@ -1515,6 +1526,111 @@ class TestCacheTransferManager(unittest.TestCase):
self.assertFalse(self.manager.is_paused)
def test_init_control_builds_expected_queue_name(self):
self.manager.rank = 1
self.manager.n_ranks = 4
self.manager.local_data_parallel_id = 2
self.manager.cache_queue_port = 8899
queue = MagicMock(name="ctrl_q")
fmq = MagicMock()
fmq.queue.return_value = queue
with patch("fastdeploy.cache_manager.cache_transfer_manager.FMQ", return_value=fmq):
self.manager._init_control()
fmq.queue.assert_called_once_with("ctrl_c2e_rank9_8899", "producer")
self.assertIs(self.manager.ctrl_output_queue, queue)
def test_control_task_success_puts_control_response(self):
self.manager.cache_task_queue.barrier = MagicMock(wait=Mock())
self.manager.ctrl_output_queue = MagicMock(name="ctrl_q")
self.manager.ctrl_output_queue.put = Mock(return_value="coro")
self.manager._handle_pause = MagicMock(return_value=True)
with patch("fastdeploy.cache_manager.cache_transfer_manager.asyncio.run"):
self.manager.control_task(ControlRequest(request_id="ctrl-1", method="pause"))
self.manager._handle_pause.assert_called_once()
self.manager.cache_task_queue.barrier.wait.assert_called_once()
self.manager.ctrl_output_queue.put.assert_called_once()
response = self.manager.ctrl_output_queue.put.call_args.args[0]
self.assertEqual(response.request_id, "ctrl-1")
self.assertEqual(response.error_code, 200)
def test_control_task_unknown_method_returns_400(self):
self.manager.cache_task_queue.barrier = MagicMock(wait=Mock())
self.manager.ctrl_output_queue = MagicMock(name="ctrl_q")
self.manager.ctrl_output_queue.put = Mock(return_value="coro")
with patch("fastdeploy.cache_manager.cache_transfer_manager.asyncio.run"):
self.manager.control_task(ControlRequest(request_id="ctrl-2", method="unknown"))
response = self.manager.ctrl_output_queue.put.call_args.args[0]
self.assertEqual(response.error_code, 400)
self.assertIn("Unknown control method", response.error_message)
def test_control_task_exception_returns_500(self):
self.manager.cache_task_queue.barrier = MagicMock(wait=Mock())
self.manager.ctrl_output_queue = MagicMock(name="ctrl_q")
self.manager.ctrl_output_queue.put = Mock(return_value="coro")
with (
patch.object(self.manager, "_handle_sleep", side_effect=RuntimeError("boom")),
patch("fastdeploy.cache_manager.cache_transfer_manager.asyncio.run"),
):
self.manager.control_task(ControlRequest(request_id="ctrl-3", method="sleep"))
response = self.manager.ctrl_output_queue.put.call_args.args[0]
self.assertEqual(response.error_code, 500)
self.assertIn("Failed to execute sleep", response.error_message)
def test_handle_resume_updates_key_prefix_for_storage_backend(self):
self.manager.is_paused = True
self.manager.storage_backend_type = "mooncake"
self.manager.resume = MagicMock()
self.manager._update_key_prefix = MagicMock()
result = self.manager._handle_resume()
self.assertTrue(result)
self.manager.resume.assert_called_once()
self.manager._update_key_prefix.assert_called_once()
def test_handle_update_weights_updates_key_prefix_for_storage_backend(self):
self.manager.storage_backend_type = "mooncake"
self.manager._update_key_prefix = MagicMock()
result = self.manager._handle_update_weights()
self.assertTrue(result)
self.manager._update_key_prefix.assert_called_once()
def test_handle_update_weights_skips_without_storage_backend(self):
self.manager.storage_backend_type = None
self.manager._update_key_prefix = MagicMock()
result = self.manager._handle_update_weights()
self.assertTrue(result)
self.manager._update_key_prefix.assert_not_called()
def test_handle_sleep_and_wakeup_are_idempotent(self):
self.manager.is_sleeping = True
self.manager._clear_cpu_cache = MagicMock()
self.manager._clear_gpu_cache = MagicMock()
self.manager._init_cpu_cache = MagicMock()
self.manager._init_gpu_cache = MagicMock()
self.assertTrue(self.manager._handle_sleep())
self.manager._clear_cpu_cache.assert_not_called()
self.manager._clear_gpu_cache.assert_not_called()
self.manager.is_sleeping = False
self.assertTrue(self.manager._handle_wakeup())
self.manager._init_cpu_cache.assert_not_called()
self.manager._init_gpu_cache.assert_not_called()
def test_submit_task_decrements_inflight_on_task_error(self):
class DummyPool:
def submit(self, fn, *args):