Files
FastDeploy/tests/engine/test_control_request_response.py
T
wangyifei b7c5daa316 [RL] add pause, update_weights, resume interface for async RL (#6052)
* support dynamic run_control_request through zmq from apiserver to common_engine

* support pause/resume/is_paused/update_weights in apiserver->common_engine by common run_control_method

* change /is_paused from HTTP POST method to GET method

* add pause, resume, is_paused implementation

* support engine <==> worker communication(request&response)

* support sync weights through RDMA from checkpoint_transfer

* support specified version, rsync_config in update_weights rpc call

* add pause, update_weights, resume interface for async RL

* bug fix: update_weights support using default arguments

* fix typo

* typo fix

* typo fix

* typo fix

* add unit tests for control request/response, localscheduler.get_inflight_requests, resource_manager_v1.preempted_all

* add "rsync" to LoadConfig.load_strategy Literal type hints

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* typo fix

* typo fix

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* check version/rsync params

* add error log when version.txt does not exist

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* raise specified ValueError when parameters check failed

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* tp barrier after run_control_method

* encode 'engine_worker_queue_port' to unique name of worker2engine fmq queue

* typo fix

* typo fix

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-01-23 10:18:07 +08:00

330 lines
12 KiB
Python

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from unittest.mock import patch
from fastapi.responses import JSONResponse
from fastdeploy.engine.request import ControlRequest, ControlResponse
class TestControlRequest(unittest.TestCase):
    """Unit tests for ControlRequest: construction, dict round-tripping, accessors, validation and repr."""

    def test_initialization_basic(self):
        """Constructing with only an id and a method is expected to leave args as an empty dict."""
        rid = "test_request_123"
        name = "get_metrics"

        req = ControlRequest(request_id=rid, method=name)

        self.assertEqual(req.request_id, rid)
        self.assertEqual(req.method, name)
        self.assertEqual(req.args, {})

    def test_initialization_with_args(self):
        """Explicitly supplied args are expected to be stored on the request."""
        rid = "test_request_456"
        name = "reset_scheduler"
        extra = {"force": True, "timeout": 30}

        req = ControlRequest(request_id=rid, method=name, args=extra)

        self.assertEqual(req.request_id, rid)
        self.assertEqual(req.method, name)
        self.assertEqual(req.args, extra)

    def test_from_dict_basic(self):
        """from_dict with no "args" key is expected to yield an empty args dict."""
        payload = {"request_id": "test_from_dict", "method": "status_check"}

        req = ControlRequest.from_dict(payload)

        for field in ("request_id", "method"):
            self.assertEqual(getattr(req, field), payload[field])
        self.assertEqual(req.args, {})

    def test_from_dict_with_args(self):
        """from_dict is expected to carry request_id, method and args over from the input dict."""
        payload = {
            "request_id": "test_from_dict_args",
            "method": "configure",
            "args": {"max_requests": 100, "queue_timeout": 60},
        }

        req = ControlRequest.from_dict(payload)

        for field in ("request_id", "method", "args"):
            self.assertEqual(getattr(req, field), payload[field])

    def test_to_dict_basic(self):
        """to_dict on a request without args is expected to include an empty "args" entry."""
        serialized = ControlRequest(request_id="test_to_dict", method="health_check").to_dict()

        self.assertEqual(
            serialized,
            {"request_id": "test_to_dict", "method": "health_check", "args": {}},
        )

    def test_to_dict_with_args(self):
        """to_dict is expected to include the args that were passed at construction."""
        settings = {"setting1": "value1", "setting2": 42}

        serialized = ControlRequest(
            request_id="test_to_dict_args", method="update_settings", args=settings
        ).to_dict()

        self.assertEqual(
            serialized,
            {"request_id": "test_to_dict_args", "method": "update_settings", "args": settings},
        )

    def test_get_method(self):
        """get_method is expected to return the method name given at construction."""
        name = "custom_operation"

        self.assertEqual(ControlRequest(request_id="test", method=name).get_method(), name)

    def test_get_args(self):
        """get_args is expected to return an equal dict that is a distinct object from the input."""
        supplied = {"param1": "value1", "param2": 123}

        returned = ControlRequest(request_id="test", method="test", args=supplied).get_args()

        self.assertEqual(returned, supplied)
        # Identity check: the accessor must hand back a copy, not the caller's dict.
        self.assertIsNot(returned, supplied)

    def test_is_control_request_valid(self):
        """is_control_request is expected to accept dicts with string id/method and optional dict args."""
        accepted = (
            {"request_id": "test1", "method": "method1"},
            {"request_id": "test2", "method": "method2", "args": {}},
            {"request_id": "test3", "method": "method3", "args": {"key": "value"}},
        )

        for candidate in accepted:
            with self.subTest(data=candidate):
                self.assertTrue(ControlRequest.is_control_request(candidate))

    def test_is_control_request_invalid(self):
        """is_control_request is expected to reject missing fields, wrongly typed fields and non-dicts."""
        rejected = (
            # Required keys absent
            {"method": "test"},  # no request_id
            {"request_id": "test"},  # no method
            # Keys present but with the wrong type
            {"request_id": 123, "method": "test"},  # request_id is not a str
            {"request_id": "test", "method": 456},  # method is not a str
            {"request_id": "test", "method": "test", "args": "not_a_dict"},  # args is not a dict
            # Values that are not dicts at all
            "not_a_dict",
            123,
            None,
        )

        for candidate in rejected:
            with self.subTest(data=candidate):
                self.assertFalse(ControlRequest.is_control_request(candidate))

    def test_repr_simple(self):
        """With FD_DEBUG patched to False, repr is expected to omit the args."""
        with patch("fastdeploy.envs.FD_DEBUG", False):
            text = repr(ControlRequest(request_id="test_repr", method="test_method"))

            for fragment in ("ControlRequest", "test_repr", "test_method"):
                self.assertIn(fragment, text)
            # The simple (non-debug) form must not mention args.
            self.assertNotIn("args", text)

    def test_repr_debug_mode(self):
        """With FD_DEBUG patched to True, repr is expected to include the args."""
        with patch("fastdeploy.envs.FD_DEBUG", True):
            debug_args = {"debug_param": "debug_value"}
            text = repr(ControlRequest(request_id="test_repr", method="test_method", args=debug_args))

            # The last fragment comes from args, which only the debug form shows.
            for fragment in ("ControlRequest", "test_repr", "test_method", "debug_param"):
                self.assertIn(fragment, text)
class TestControlResponse(unittest.TestCase):
    """Test cases for ControlResponse class.

    Covers construction defaults, dict round-tripping, conversion to a
    JSONResponse and the string representation.
    """

    def test_initialization_basic(self):
        """Test that a response built from only a request_id gets the expected defaults."""
        request_id = "test_response_123"
        response = ControlResponse(request_id=request_id)

        self.assertEqual(response.request_id, request_id)
        self.assertEqual(response.error_code, 200)
        self.assertIsNone(response.error_message)
        self.assertIsNone(response.result)
        self.assertTrue(response.finished)

    def test_initialization_with_all_params(self):
        """Test initialization with every parameter given explicitly."""
        request_id = "test_response_456"
        error_code = 404
        error_message = "Not found"
        result = {"data": "some_result"}
        finished = False

        response = ControlResponse(
            request_id=request_id, error_code=error_code, error_message=error_message, result=result, finished=finished
        )

        self.assertEqual(response.request_id, request_id)
        self.assertEqual(response.error_code, error_code)
        self.assertEqual(response.error_message, error_message)
        self.assertEqual(response.result, result)
        self.assertEqual(response.finished, finished)

    def test_initialization_error_cases(self):
        """Test initialization with success, client-error and server-error codes."""
        test_cases = [
            (200, None, True),  # Success case
            (400, "Bad Request", False),  # Client error
            (500, "Internal Error", True),  # Server error
        ]

        for error_code, error_message, finished in test_cases:
            with self.subTest(error_code=error_code):
                response = ControlResponse(
                    request_id="test", error_code=error_code, error_message=error_message, finished=finished
                )

                self.assertEqual(response.error_code, error_code)
                self.assertEqual(response.error_message, error_message)
                self.assertEqual(response.finished, finished)

    def test_from_dict_basic(self):
        """Test that from_dict with only a request_id yields the same defaults as the constructor test."""
        data = {"request_id": "test_from_dict"}
        response = ControlResponse.from_dict(data)

        self.assertEqual(response.request_id, data["request_id"])
        self.assertEqual(response.error_code, 200)
        self.assertIsNone(response.error_message)
        self.assertIsNone(response.result)
        self.assertTrue(response.finished)

    def test_from_dict_with_all_fields(self):
        """Test creating ControlResponse from a dictionary that sets every field."""
        data = {
            "request_id": "test_from_dict_full",
            "error_code": 500,
            "error_message": "Test error",
            "result": {"key": "value"},
            "finished": False,
        }
        response = ControlResponse.from_dict(data)

        self.assertEqual(response.request_id, data["request_id"])
        self.assertEqual(response.error_code, data["error_code"])
        self.assertEqual(response.error_message, data["error_message"])
        self.assertEqual(response.result, data["result"])
        self.assertEqual(response.finished, data["finished"])

    def test_to_dict_basic(self):
        """Test converting a default ControlResponse to a dictionary."""
        response = ControlResponse(request_id="test_to_dict")
        result = response.to_dict()

        expected = {
            "request_id": "test_to_dict",
            "finished": True,
            "error_code": 200,
            "error_message": None,
            "result": None,
        }
        self.assertEqual(result, expected)

    def test_to_dict_with_all_fields(self):
        """Test converting a fully populated ControlResponse to a dictionary."""
        response = ControlResponse(
            request_id="test_to_dict_full",
            error_code=400,
            error_message="Validation failed",
            result={"valid": False, "reason": "missing_field"},
            finished=False,
        )
        result = response.to_dict()

        expected = {
            "request_id": "test_to_dict_full",
            "finished": False,
            "error_code": 400,
            "error_message": "Validation failed",
            "result": {"valid": False, "reason": "missing_field"},
        }
        self.assertEqual(result, expected)

    def test_to_api_json_response_success(self):
        """Test converting to JSONResponse for a successful response."""
        result_data = {"metrics": {"cpu_usage": 0.5, "memory_used": 1024}}
        # The request_id deliberately does not contain the word "success":
        # otherwise the "success" assertion below would be satisfied by the
        # echoed id alone and could never fail.
        response = ControlResponse(request_id="test_json_ok", result=result_data)
        json_response = response.to_api_json_response()

        self.assertIsInstance(json_response, JSONResponse)
        self.assertEqual(json_response.status_code, 200)

        content = json_response.body.decode("utf-8")
        self.assertIn("success", content)
        self.assertIn("test_json_ok", content)
        self.assertIn("cpu_usage", content)

    def test_to_api_json_response_error(self):
        """Test converting to JSONResponse for an error response."""
        # The request_id deliberately does not contain the word "error" so the
        # "error" assertion below is not trivially satisfied by the echoed id.
        response = ControlResponse(
            request_id="test_json_failure", error_code=503, error_message="Service unavailable"
        )
        json_response = response.to_api_json_response()

        self.assertIsInstance(json_response, JSONResponse)
        self.assertEqual(json_response.status_code, 503)

        content = json_response.body.decode("utf-8")
        self.assertIn("error", content)
        self.assertIn("test_json_failure", content)
        self.assertIn("Service unavailable", content)

    def test_repr_method(self):
        """Test that __repr__ mentions the class name, id, error code, result and finished flag."""
        # The result value shares no substring with the request_id, so the
        # result assertion below really checks that the result is rendered.
        response = ControlResponse(
            request_id="test_repr", error_code=200, error_message=None, result={"data": "payload_value"}, finished=True
        )
        repr_str = repr(response)

        # Check that all important fields are represented
        self.assertIn("ControlResponse", repr_str)
        self.assertIn("test_repr", repr_str)
        self.assertIn("200", repr_str)
        self.assertIn("payload_value", repr_str)  # from result
        self.assertIn("True", repr_str)  # finished flag
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()