mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
b7c5daa316
* support dynamic run_control_request through zmq from apiserver to common_engine * support pause/resume/is_paused/update_weights in apiserver->common_engine by common run_control_method * change /is_puased from HTTP POST method to GET method * add pause、resume、is_paused implementation * support engine <==> worker communication(request&response) * support sync weights through RDMA from checkpoint_transfer * support specified version, rsync_config in update_weights rpc call * add pause, update_weights, resume interface for async RL * bug fix: update_weights support using default arguments * fix typo * typo fix * typo fix * typo fix * add unitest for control request/response, localscheduler.get_inflight_requests, resource_manager_v1.preempted_all * add "rsync" to LoadConfig.load_strategy Literal type hints Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * typo fix * typo fix * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * check version/rsync params * add error log when version.txt not exists Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * raise specified ValueError when paramters check failed Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * tp barrier after run_control_method * encode 'engine_worker_queue_port' to unique name of worker2engine fmq queue * typo fix * typo fix --------- Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
330 lines
12 KiB
Python
330 lines
12 KiB
Python
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import unittest
|
|
from unittest.mock import patch
|
|
|
|
from fastapi.responses import JSONResponse
|
|
|
|
from fastdeploy.engine.request import ControlRequest, ControlResponse
|
|
|
|
|
|
class TestControlRequest(unittest.TestCase):
|
|
"""Test cases for ControlRequest class."""
|
|
|
|
def test_initialization_basic(self):
|
|
"""Test basic initialization of ControlRequest."""
|
|
request_id = "test_request_123"
|
|
method = "get_metrics"
|
|
|
|
request = ControlRequest(request_id=request_id, method=method)
|
|
|
|
self.assertEqual(request.request_id, request_id)
|
|
self.assertEqual(request.method, method)
|
|
self.assertEqual(request.args, {})
|
|
|
|
def test_initialization_with_args(self):
|
|
"""Test initialization with arguments."""
|
|
request_id = "test_request_456"
|
|
method = "reset_scheduler"
|
|
args = {"force": True, "timeout": 30}
|
|
|
|
request = ControlRequest(request_id=request_id, method=method, args=args)
|
|
|
|
self.assertEqual(request.request_id, request_id)
|
|
self.assertEqual(request.method, method)
|
|
self.assertEqual(request.args, args)
|
|
|
|
def test_from_dict_basic(self):
|
|
"""Test creating ControlRequest from dictionary (basic case)."""
|
|
data = {"request_id": "test_from_dict", "method": "status_check"}
|
|
|
|
request = ControlRequest.from_dict(data)
|
|
|
|
self.assertEqual(request.request_id, data["request_id"])
|
|
self.assertEqual(request.method, data["method"])
|
|
self.assertEqual(request.args, {})
|
|
|
|
def test_from_dict_with_args(self):
|
|
"""Test creating ControlRequest from dictionary with arguments."""
|
|
data = {
|
|
"request_id": "test_from_dict_args",
|
|
"method": "configure",
|
|
"args": {"max_requests": 100, "queue_timeout": 60},
|
|
}
|
|
|
|
request = ControlRequest.from_dict(data)
|
|
|
|
self.assertEqual(request.request_id, data["request_id"])
|
|
self.assertEqual(request.method, data["method"])
|
|
self.assertEqual(request.args, data["args"])
|
|
|
|
def test_to_dict_basic(self):
|
|
"""Test converting ControlRequest to dictionary (basic case)."""
|
|
request = ControlRequest(request_id="test_to_dict", method="health_check")
|
|
|
|
result = request.to_dict()
|
|
|
|
expected = {"request_id": "test_to_dict", "method": "health_check", "args": {}}
|
|
self.assertEqual(result, expected)
|
|
|
|
def test_to_dict_with_args(self):
|
|
"""Test converting ControlRequest to dictionary with arguments."""
|
|
args = {"setting1": "value1", "setting2": 42}
|
|
request = ControlRequest(request_id="test_to_dict_args", method="update_settings", args=args)
|
|
|
|
result = request.to_dict()
|
|
|
|
expected = {"request_id": "test_to_dict_args", "method": "update_settings", "args": args}
|
|
self.assertEqual(result, expected)
|
|
|
|
def test_get_method(self):
|
|
"""Test get_method method."""
|
|
method = "custom_operation"
|
|
request = ControlRequest(request_id="test", method=method)
|
|
|
|
self.assertEqual(request.get_method(), method)
|
|
|
|
def test_get_args(self):
|
|
"""Test get_args method."""
|
|
args = {"param1": "value1", "param2": 123}
|
|
request = ControlRequest(request_id="test", method="test", args=args)
|
|
|
|
result_args = request.get_args()
|
|
|
|
self.assertEqual(result_args, args)
|
|
# Ensure it returns a copy, not the original dict
|
|
self.assertIsNot(result_args, args)
|
|
|
|
def test_is_control_request_valid(self):
|
|
"""Test is_control_request method with valid data."""
|
|
valid_data = [
|
|
{"request_id": "test1", "method": "method1"},
|
|
{"request_id": "test2", "method": "method2", "args": {}},
|
|
{"request_id": "test3", "method": "method3", "args": {"key": "value"}},
|
|
]
|
|
|
|
for data in valid_data:
|
|
with self.subTest(data=data):
|
|
self.assertTrue(ControlRequest.is_control_request(data))
|
|
|
|
def test_is_control_request_invalid(self):
|
|
"""Test is_control_request method with invalid data."""
|
|
invalid_data = [
|
|
# Missing required fields
|
|
{"method": "test"}, # missing request_id
|
|
{"request_id": "test"}, # missing method
|
|
# Wrong field types
|
|
{"request_id": 123, "method": "test"}, # request_id not string
|
|
{"request_id": "test", "method": 456}, # method not string
|
|
{"request_id": "test", "method": "test", "args": "not_a_dict"}, # args not dict
|
|
# Not a dict
|
|
"not_a_dict",
|
|
123,
|
|
None,
|
|
]
|
|
|
|
for data in invalid_data:
|
|
with self.subTest(data=data):
|
|
self.assertFalse(ControlRequest.is_control_request(data))
|
|
|
|
def test_repr_simple(self):
|
|
"""Test __repr__ method in simple mode."""
|
|
with patch("fastdeploy.envs.FD_DEBUG", False):
|
|
request = ControlRequest(request_id="test_repr", method="test_method")
|
|
repr_str = repr(request)
|
|
|
|
self.assertIn("ControlRequest", repr_str)
|
|
self.assertIn("test_repr", repr_str)
|
|
self.assertIn("test_method", repr_str)
|
|
self.assertNotIn("args", repr_str) # Args not shown in simple mode
|
|
|
|
def test_repr_debug_mode(self):
|
|
"""Test __repr__ method in debug mode."""
|
|
with patch("fastdeploy.envs.FD_DEBUG", True):
|
|
args = {"debug_param": "debug_value"}
|
|
request = ControlRequest(request_id="test_repr", method="test_method", args=args)
|
|
repr_str = repr(request)
|
|
|
|
self.assertIn("ControlRequest", repr_str)
|
|
self.assertIn("test_repr", repr_str)
|
|
self.assertIn("test_method", repr_str)
|
|
self.assertIn("debug_param", repr_str) # Args shown in debug mode
|
|
|
|
|
|
class TestControlResponse(unittest.TestCase):
|
|
"""Test cases for ControlResponse class."""
|
|
|
|
def test_initialization_basic(self):
|
|
"""Test basic initialization of ControlResponse."""
|
|
request_id = "test_response_123"
|
|
|
|
response = ControlResponse(request_id=request_id)
|
|
|
|
self.assertEqual(response.request_id, request_id)
|
|
self.assertEqual(response.error_code, 200)
|
|
self.assertIsNone(response.error_message)
|
|
self.assertIsNone(response.result)
|
|
self.assertTrue(response.finished)
|
|
|
|
def test_initialization_with_all_params(self):
|
|
"""Test initialization with all parameters."""
|
|
request_id = "test_response_456"
|
|
error_code = 404
|
|
error_message = "Not found"
|
|
result = {"data": "some_result"}
|
|
finished = False
|
|
|
|
response = ControlResponse(
|
|
request_id=request_id, error_code=error_code, error_message=error_message, result=result, finished=finished
|
|
)
|
|
|
|
self.assertEqual(response.request_id, request_id)
|
|
self.assertEqual(response.error_code, error_code)
|
|
self.assertEqual(response.error_message, error_message)
|
|
self.assertEqual(response.result, result)
|
|
self.assertEqual(response.finished, finished)
|
|
|
|
def test_initialization_error_cases(self):
|
|
"""Test initialization with various error codes."""
|
|
test_cases = [
|
|
(200, None, True), # Success case
|
|
(400, "Bad Request", False), # Client error
|
|
(500, "Internal Error", True), # Server error
|
|
]
|
|
|
|
for error_code, error_message, finished in test_cases:
|
|
with self.subTest(error_code=error_code):
|
|
response = ControlResponse(
|
|
request_id="test", error_code=error_code, error_message=error_message, finished=finished
|
|
)
|
|
|
|
self.assertEqual(response.error_code, error_code)
|
|
self.assertEqual(response.error_message, error_message)
|
|
self.assertEqual(response.finished, finished)
|
|
|
|
def test_from_dict_basic(self):
|
|
"""Test creating ControlResponse from dictionary (basic case)."""
|
|
data = {"request_id": "test_from_dict"}
|
|
|
|
response = ControlResponse.from_dict(data)
|
|
|
|
self.assertEqual(response.request_id, data["request_id"])
|
|
self.assertEqual(response.error_code, 200)
|
|
self.assertIsNone(response.error_message)
|
|
self.assertIsNone(response.result)
|
|
self.assertTrue(response.finished)
|
|
|
|
def test_from_dict_with_all_fields(self):
|
|
"""Test creating ControlResponse from dictionary with all fields."""
|
|
data = {
|
|
"request_id": "test_from_dict_full",
|
|
"error_code": 500,
|
|
"error_message": "Test error",
|
|
"result": {"key": "value"},
|
|
"finished": False,
|
|
}
|
|
|
|
response = ControlResponse.from_dict(data)
|
|
|
|
self.assertEqual(response.request_id, data["request_id"])
|
|
self.assertEqual(response.error_code, data["error_code"])
|
|
self.assertEqual(response.error_message, data["error_message"])
|
|
self.assertEqual(response.result, data["result"])
|
|
self.assertEqual(response.finished, data["finished"])
|
|
|
|
def test_to_dict_basic(self):
|
|
"""Test converting ControlResponse to dictionary (basic case)."""
|
|
response = ControlResponse(request_id="test_to_dict")
|
|
|
|
result = response.to_dict()
|
|
|
|
expected = {
|
|
"request_id": "test_to_dict",
|
|
"finished": True,
|
|
"error_code": 200,
|
|
"error_message": None,
|
|
"result": None,
|
|
}
|
|
self.assertEqual(result, expected)
|
|
|
|
def test_to_dict_with_all_fields(self):
|
|
"""Test converting ControlResponse to dictionary with all fields."""
|
|
response = ControlResponse(
|
|
request_id="test_to_dict_full",
|
|
error_code=400,
|
|
error_message="Validation failed",
|
|
result={"valid": False, "reason": "missing_field"},
|
|
finished=False,
|
|
)
|
|
|
|
result = response.to_dict()
|
|
|
|
expected = {
|
|
"request_id": "test_to_dict_full",
|
|
"finished": False,
|
|
"error_code": 400,
|
|
"error_message": "Validation failed",
|
|
"result": {"valid": False, "reason": "missing_field"},
|
|
}
|
|
self.assertEqual(result, expected)
|
|
|
|
def test_to_api_json_response_success(self):
|
|
"""Test converting to JSONResponse for successful response."""
|
|
result_data = {"metrics": {"cpu_usage": 0.5, "memory_used": 1024}}
|
|
response = ControlResponse(request_id="test_json_success", result=result_data)
|
|
|
|
json_response = response.to_api_json_response()
|
|
|
|
self.assertIsInstance(json_response, JSONResponse)
|
|
self.assertEqual(json_response.status_code, 200)
|
|
|
|
content = json_response.body.decode("utf-8")
|
|
self.assertIn("success", content)
|
|
self.assertIn("test_json_success", content)
|
|
self.assertIn("cpu_usage", content)
|
|
|
|
def test_to_api_json_response_error(self):
|
|
"""Test converting to JSONResponse for error response."""
|
|
response = ControlResponse(request_id="test_json_error", error_code=503, error_message="Service unavailable")
|
|
|
|
json_response = response.to_api_json_response()
|
|
|
|
self.assertIsInstance(json_response, JSONResponse)
|
|
self.assertEqual(json_response.status_code, 503)
|
|
|
|
content = json_response.body.decode("utf-8")
|
|
self.assertIn("error", content)
|
|
self.assertIn("test_json_error", content)
|
|
self.assertIn("Service unavailable", content)
|
|
|
|
def test_repr_method(self):
|
|
"""Test __repr__ method."""
|
|
response = ControlResponse(
|
|
request_id="test_repr", error_code=200, error_message=None, result={"data": "test"}, finished=True
|
|
)
|
|
|
|
repr_str = repr(response)
|
|
|
|
# Check that all important fields are represented
|
|
self.assertIn("ControlResponse", repr_str)
|
|
self.assertIn("test_repr", repr_str)
|
|
self.assertIn("200", repr_str)
|
|
self.assertIn("test", repr_str) # from result
|
|
self.assertIn("True", repr_str) # finished flag
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|