mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[RL] [KVCache] let cache transfer managers update key prefix after weight update and add unit tests (#7083)
* [test] add a few unit tests
* [feat] update key prefix when model weights are updated
* [test] try to fix test_worker_process
This commit is contained in:
@@ -17,7 +17,7 @@ import sys
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import paddle
|
||||
|
||||
@@ -37,6 +37,7 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..",
|
||||
import fastdeploy.cache_manager.cache_transfer_manager as cache_transfer_manager
|
||||
from fastdeploy.cache_manager.cache_tasks import ReadStorageTask, WriteStorageTask
|
||||
from fastdeploy.cache_manager.cache_transfer_manager import CacheTransferManager
|
||||
from fastdeploy.engine.request import ControlRequest
|
||||
|
||||
|
||||
# ==========================
|
||||
@@ -121,6 +122,16 @@ class TestCacheTransferManager(unittest.TestCase):
|
||||
patcher_thread.start()
|
||||
self.addCleanup(patcher_thread.stop)
|
||||
|
||||
# --------------------------
|
||||
# mock FMQ
|
||||
# --------------------------
|
||||
patcher_fmq = patch("fastdeploy.cache_manager.cache_transfer_manager.FMQ")
|
||||
mock_fmq_cls = patcher_fmq.start()
|
||||
mock_fmq = MagicMock()
|
||||
mock_fmq.queue.return_value = MagicMock(name="ctrl_output_queue")
|
||||
mock_fmq_cls.return_value = mock_fmq
|
||||
self.addCleanup(patcher_fmq.stop)
|
||||
|
||||
# --------------------------
|
||||
# mock _init_cpu_cache and _init_gpu_cache
|
||||
# --------------------------
|
||||
@@ -1515,6 +1526,111 @@ class TestCacheTransferManager(unittest.TestCase):
|
||||
|
||||
self.assertFalse(self.manager.is_paused)
|
||||
|
||||
def test_init_control_builds_expected_queue_name(self):
    """Check the queue that ``_init_control`` requests from a patched ``FMQ``.

    With rank=1, n_ranks=4, local_data_parallel_id=2 and
    cache_queue_port=8899 the test expects exactly one
    ``queue("ctrl_c2e_rank9_8899", "producer")`` call, and that the object
    returned by that call ends up in ``ctrl_output_queue``.
    """
    # NOTE(review): the "rank9" component presumably combines rank, n_ranks
    # and local_data_parallel_id -- confirm against _init_control.
    manager = self.manager
    manager.rank, manager.n_ranks = 1, 4
    manager.local_data_parallel_id = 2
    manager.cache_queue_port = 8899

    expected_queue = MagicMock(name="ctrl_q")
    fake_fmq = MagicMock()
    fake_fmq.queue.return_value = expected_queue

    fmq_target = "fastdeploy.cache_manager.cache_transfer_manager.FMQ"
    with patch(fmq_target, return_value=fake_fmq):
        manager._init_control()

    fake_fmq.queue.assert_called_once_with("ctrl_c2e_rank9_8899", "producer")
    self.assertIs(manager.ctrl_output_queue, expected_queue)
|
||||
|
||||
def test_control_task_success_puts_control_response(self):
    """A "pause" request whose handler succeeds should yield a 200 response.

    ``_handle_pause`` is stubbed to return True and ``asyncio.run`` is
    patched out. The test then verifies that the handler, the barrier wait
    and the output queue ``put`` were each called once, and inspects the
    first positional argument handed to ``put``.
    """
    manager = self.manager

    barrier = MagicMock(wait=Mock())
    manager.cache_task_queue.barrier = barrier
    ctrl_queue = MagicMock(name="ctrl_q")
    manager.ctrl_output_queue = ctrl_queue
    put = Mock(return_value="coro")
    ctrl_queue.put = put
    pause_handler = MagicMock(return_value=True)
    manager._handle_pause = pause_handler

    asyncio_run_target = "fastdeploy.cache_manager.cache_transfer_manager.asyncio.run"
    with patch(asyncio_run_target):
        request = ControlRequest(request_id="ctrl-1", method="pause")
        manager.control_task(request)

    pause_handler.assert_called_once()
    barrier.wait.assert_called_once()
    put.assert_called_once()

    response = put.call_args.args[0]
    self.assertEqual(response.request_id, "ctrl-1")
    self.assertEqual(response.error_code, 200)
|
||||
|
||||
def test_control_task_unknown_method_returns_400(self):
    """A request with method "unknown" should be answered with error code 400.

    The response passed to the output queue ``put`` must also carry an
    error message containing "Unknown control method".
    """
    manager = self.manager

    manager.cache_task_queue.barrier = MagicMock(wait=Mock())
    ctrl_queue = MagicMock(name="ctrl_q")
    manager.ctrl_output_queue = ctrl_queue
    put = Mock(return_value="coro")
    ctrl_queue.put = put

    asyncio_run_target = "fastdeploy.cache_manager.cache_transfer_manager.asyncio.run"
    with patch(asyncio_run_target):
        request = ControlRequest(request_id="ctrl-2", method="unknown")
        manager.control_task(request)

    # The response object is the first positional argument given to put().
    response = put.call_args.args[0]
    self.assertEqual(response.error_code, 400)
    self.assertIn("Unknown control method", response.error_message)
|
||||
|
||||
def test_control_task_exception_returns_500(self):
    """A handler that raises should be reported as error code 500.

    ``_handle_sleep`` is patched to raise ``RuntimeError("boom")``; the
    response passed to the output queue ``put`` must have code 500 and an
    error message containing "Failed to execute sleep".
    """
    manager = self.manager

    manager.cache_task_queue.barrier = MagicMock(wait=Mock())
    ctrl_queue = MagicMock(name="ctrl_q")
    manager.ctrl_output_queue = ctrl_queue
    put = Mock(return_value="coro")
    ctrl_queue.put = put

    asyncio_run_target = "fastdeploy.cache_manager.cache_transfer_manager.asyncio.run"
    # Nested managers are entered in the same order as the original
    # parenthesised form: the failing handler first, then asyncio.run.
    with patch.object(manager, "_handle_sleep", side_effect=RuntimeError("boom")):
        with patch(asyncio_run_target):
            request = ControlRequest(request_id="ctrl-3", method="sleep")
            manager.control_task(request)

    response = put.call_args.args[0]
    self.assertEqual(response.error_code, 500)
    self.assertIn("Failed to execute sleep", response.error_message)
|
||||
|
||||
def test_handle_resume_updates_key_prefix_for_storage_backend(self):
    """Resuming a paused manager with a storage backend refreshes the key prefix.

    With ``is_paused`` set and ``storage_backend_type`` set to "mooncake",
    ``_handle_resume`` is expected to return a truthy value and to call both
    ``resume`` and ``_update_key_prefix`` exactly once.
    """
    manager = self.manager
    manager.is_paused = True
    manager.storage_backend_type = "mooncake"

    resume_mock = MagicMock()
    manager.resume = resume_mock
    update_prefix_mock = MagicMock()
    manager._update_key_prefix = update_prefix_mock

    outcome = manager._handle_resume()

    self.assertTrue(outcome)
    resume_mock.assert_called_once()
    update_prefix_mock.assert_called_once()
|
||||
|
||||
def test_handle_update_weights_updates_key_prefix_for_storage_backend(self):
    """With a storage backend set, a weight update refreshes the key prefix.

    ``storage_backend_type`` is set to "mooncake"; ``_handle_update_weights``
    must return a truthy value and call ``_update_key_prefix`` exactly once.
    """
    manager = self.manager
    manager.storage_backend_type = "mooncake"
    update_prefix_mock = MagicMock()
    manager._update_key_prefix = update_prefix_mock

    outcome = manager._handle_update_weights()

    self.assertTrue(outcome)
    update_prefix_mock.assert_called_once()
|
||||
|
||||
def test_handle_update_weights_skips_without_storage_backend(self):
    """Without a storage backend, a weight update leaves the key prefix alone.

    ``storage_backend_type`` is None; ``_handle_update_weights`` must still
    return a truthy value but must not call ``_update_key_prefix``.
    """
    manager = self.manager
    manager.storage_backend_type = None
    update_prefix_mock = MagicMock()
    manager._update_key_prefix = update_prefix_mock

    outcome = manager._handle_update_weights()

    self.assertTrue(outcome)
    update_prefix_mock.assert_not_called()
|
||||
|
||||
def test_handle_sleep_and_wakeup_are_idempotent(self):
    """Sleeping while asleep, or waking while awake, should touch no caches.

    The four cache hooks are replaced with mocks. ``_handle_sleep`` is
    called with ``is_sleeping`` True and ``_handle_wakeup`` with it False;
    both must return a truthy value without calling the clear/init hooks.
    """
    manager = self.manager
    manager.is_sleeping = True

    clear_hooks = ("_clear_cpu_cache", "_clear_gpu_cache")
    init_hooks = ("_init_cpu_cache", "_init_gpu_cache")
    for hook_name in clear_hooks + init_hooks:
        setattr(manager, hook_name, MagicMock())

    # Already sleeping: a repeated sleep must not clear anything.
    self.assertTrue(manager._handle_sleep())
    for hook_name in clear_hooks:
        getattr(manager, hook_name).assert_not_called()

    # Already awake: a repeated wakeup must not re-initialise anything.
    manager.is_sleeping = False
    self.assertTrue(manager._handle_wakeup())
    for hook_name in init_hooks:
        getattr(manager, hook_name).assert_not_called()
|
||||
|
||||
def test_submit_task_decrements_inflight_on_task_error(self):
|
||||
class DummyPool:
|
||||
def submit(self, fn, *args):
|
||||
|
||||
Reference in New Issue
Block a user