mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
314 lines
13 KiB
Python
314 lines
13 KiB
Python
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""
|
|
Tests for Router class.
|
|
|
|
Why mock:
|
|
- register_instance calls check_service_health_async which does real HTTP.
|
|
We mock it at the network boundary to test Router's registration and selection logic.
|
|
"""
|
|
|
|
import unittest
|
|
from types import SimpleNamespace
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
from fastdeploy.router.router import Router, RouterArgs
|
|
|
|
|
|
def _make_args(**kwargs):
|
|
defaults = {"host": "0.0.0.0", "port": 9000, "splitwise": False, "request_timeout_secs": 30}
|
|
defaults.update(kwargs)
|
|
return SimpleNamespace(**defaults)
|
|
|
|
|
|
def _make_instance_dict(host_ip="10.0.0.1", port=8080, role="mixed", **kwargs):
|
|
d = {
|
|
"host_ip": host_ip,
|
|
"port": port,
|
|
"role": role,
|
|
}
|
|
d.update(kwargs)
|
|
return d
|
|
|
|
|
|
class TestRouterArgs(unittest.TestCase):
|
|
def test_defaults(self):
|
|
args = RouterArgs()
|
|
self.assertEqual(args.host, "0.0.0.0")
|
|
self.assertEqual(args.port, 9000)
|
|
self.assertFalse(args.splitwise)
|
|
self.assertEqual(args.request_timeout_secs, 1800)
|
|
|
|
|
|
class TestRouterInit(unittest.TestCase):
|
|
def test_init(self):
|
|
args = _make_args()
|
|
router = Router(args)
|
|
self.assertEqual(router.host, "0.0.0.0")
|
|
self.assertEqual(router.port, 9000)
|
|
self.assertFalse(router.splitwise)
|
|
self.assertEqual(router.mixed_servers, [])
|
|
self.assertEqual(router.prefill_servers, [])
|
|
self.assertEqual(router.decode_servers, [])
|
|
|
|
|
|
class TestRouterRegistration(unittest.IsolatedAsyncioTestCase):
|
|
@patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=True)
|
|
async def test_register_mixed_instance(self, mock_health):
|
|
router = Router(_make_args(splitwise=False))
|
|
inst_dict = _make_instance_dict(role="mixed")
|
|
await router.register_instance(inst_dict)
|
|
self.assertEqual(len(router.mixed_servers), 1)
|
|
|
|
@patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=True)
|
|
async def test_register_splitwise_instances(self, mock_health):
|
|
router = Router(_make_args(splitwise=True))
|
|
|
|
await router.register_instance(_make_instance_dict(host_ip="10.0.0.1", role="prefill"))
|
|
await router.register_instance(_make_instance_dict(host_ip="10.0.0.2", role="decode"))
|
|
|
|
self.assertEqual(len(router.prefill_servers), 1)
|
|
self.assertEqual(len(router.decode_servers), 1)
|
|
|
|
@patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=True)
|
|
async def test_register_invalid_role_raises(self, mock_health):
|
|
"""Splitwise mode should reject mixed instances."""
|
|
router = Router(_make_args(splitwise=True))
|
|
with self.assertRaises(ValueError):
|
|
await router.register_instance(_make_instance_dict(role="mixed"))
|
|
|
|
@patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=False)
|
|
async def test_register_unhealthy_instance_raises(self, mock_health):
|
|
router = Router(_make_args(splitwise=False))
|
|
with self.assertRaises(RuntimeError):
|
|
await router.register_instance(_make_instance_dict(role="mixed"))
|
|
|
|
|
|
class TestRouterSelection(unittest.IsolatedAsyncioTestCase):
|
|
async def test_select_mixed_no_servers_raises(self):
|
|
router = Router(_make_args(splitwise=False))
|
|
with self.assertRaises(RuntimeError):
|
|
await router.select_mixed()
|
|
|
|
async def test_select_pd_no_prefill_raises(self):
|
|
router = Router(_make_args(splitwise=True))
|
|
with self.assertRaises(RuntimeError):
|
|
await router.select_pd()
|
|
|
|
async def test_select_pd_no_decode_raises(self):
|
|
"""Test select_pd raises when no decode servers available (line 152)."""
|
|
router = Router(_make_args(splitwise=True))
|
|
# Manually add a prefill server without going through health check
|
|
router.prefill_servers.append(_make_instance_dict(role="prefill"))
|
|
with self.assertRaises(RuntimeError) as ctx:
|
|
await router.select_pd()
|
|
self.assertIn("No decode servers available", str(ctx.exception))
|
|
|
|
@patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=True)
|
|
async def test_select_mixed_returns_instance(self, mock_health):
|
|
router = Router(_make_args(splitwise=False))
|
|
await router.register_instance(_make_instance_dict(role="mixed"))
|
|
inst = await router.select_mixed()
|
|
self.assertIsNotNone(inst)
|
|
|
|
@patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=True)
|
|
async def test_select_pd_returns_pair(self, mock_health):
|
|
router = Router(_make_args(splitwise=True))
|
|
await router.register_instance(_make_instance_dict(host_ip="10.0.0.1", role="prefill"))
|
|
await router.register_instance(_make_instance_dict(host_ip="10.0.0.2", role="decode"))
|
|
prefill, decode = await router.select_pd()
|
|
self.assertIsNotNone(prefill)
|
|
self.assertIsNotNone(decode)
|
|
|
|
|
|
class TestRouterRegisteredNumber(unittest.IsolatedAsyncioTestCase):
|
|
@patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=True)
|
|
async def test_registered_number(self, mock_health):
|
|
router = Router(_make_args(splitwise=False))
|
|
await router.register_instance(_make_instance_dict(role="mixed"))
|
|
result = await router.registered_number()
|
|
self.assertEqual(result["mixed"], 1)
|
|
self.assertEqual(result["prefill"], 0)
|
|
self.assertEqual(result["decode"], 0)
|
|
|
|
|
|
class TestRouterAbortRequests(unittest.IsolatedAsyncioTestCase):
|
|
"""Tests for /v1/abort_requests route in router.py."""
|
|
|
|
def _make_mock_session(self, responses):
|
|
"""Create a mock aiohttp.ClientSession where post() returns coroutines."""
|
|
mock_session = MagicMock()
|
|
call_count = 0
|
|
|
|
def post_side_effect(*args, **kwargs):
|
|
nonlocal call_count
|
|
resp = responses[call_count]
|
|
call_count += 1
|
|
if isinstance(resp, Exception):
|
|
raise resp
|
|
|
|
async def _coro():
|
|
return resp
|
|
|
|
return _coro()
|
|
|
|
mock_session.post = MagicMock(side_effect=post_side_effect)
|
|
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
|
mock_session.__aexit__ = AsyncMock(return_value=False)
|
|
return mock_session
|
|
|
|
@patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=True)
|
|
async def test_abort_broadcasts_to_all_but_returns_decode_only(self, mock_health):
|
|
"""P and D both receive the request, but only D results are aggregated."""
|
|
from fastdeploy.router.router import abort_requests as abort_fn
|
|
from fastdeploy.router.router import app
|
|
|
|
router = Router(_make_args(splitwise=True))
|
|
await router.register_instance(_make_instance_dict(host_ip="10.0.0.1", port=8001, role="prefill"))
|
|
await router.register_instance(_make_instance_dict(host_ip="10.0.0.2", port=8002, role="decode"))
|
|
app.state.router = router
|
|
|
|
prefill_resp = AsyncMock()
|
|
prefill_resp.status = 200
|
|
prefill_resp.json = AsyncMock(
|
|
return_value={
|
|
"request_id": "control-p",
|
|
"status": "success",
|
|
"error_message": None,
|
|
"result": {"aborted": [{"request_id": "req-1_0", "output_token_count": 0}], "not_found": []},
|
|
}
|
|
)
|
|
decode_resp = AsyncMock()
|
|
decode_resp.status = 200
|
|
decode_resp.json = AsyncMock(
|
|
return_value={
|
|
"request_id": "control-d",
|
|
"status": "success",
|
|
"error_message": None,
|
|
"result": {"aborted": [{"request_id": "req-1_0", "output_token_count": 15}], "not_found": []},
|
|
}
|
|
)
|
|
|
|
mock_session = self._make_mock_session([prefill_resp, decode_resp])
|
|
mock_request = AsyncMock()
|
|
mock_request.json = AsyncMock(return_value={"req_ids": ["req-1"]})
|
|
|
|
with patch("fastdeploy.router.router.aiohttp.ClientSession", return_value=mock_session):
|
|
resp = await abort_fn(mock_request)
|
|
|
|
import json
|
|
|
|
body = json.loads(resp.body)
|
|
self.assertEqual(len(body["result"]["aborted"]), 1)
|
|
self.assertEqual(body["result"]["aborted"][0]["output_token_count"], 15)
|
|
self.assertEqual(body["status"], "success")
|
|
self.assertEqual(mock_session.post.call_count, 2)
|
|
|
|
@patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=True)
|
|
async def test_abort_decode_error_returns_error_status(self, mock_health):
|
|
"""When D node returns a non-200 status, status should be 'error'."""
|
|
from fastdeploy.router.router import abort_requests as abort_fn
|
|
from fastdeploy.router.router import app
|
|
|
|
router = Router(_make_args(splitwise=True))
|
|
await router.register_instance(_make_instance_dict(host_ip="10.0.0.1", port=8001, role="prefill"))
|
|
await router.register_instance(_make_instance_dict(host_ip="10.0.0.2", port=8002, role="decode"))
|
|
app.state.router = router
|
|
|
|
prefill_resp = AsyncMock()
|
|
prefill_resp.status = 200
|
|
prefill_resp.json = AsyncMock(
|
|
return_value={
|
|
"request_id": "control-p",
|
|
"status": "success",
|
|
"error_message": None,
|
|
"result": {"aborted": [], "not_found": []},
|
|
}
|
|
)
|
|
decode_resp = AsyncMock()
|
|
decode_resp.status = 500
|
|
|
|
mock_session = self._make_mock_session([prefill_resp, decode_resp])
|
|
mock_request = AsyncMock()
|
|
mock_request.json = AsyncMock(return_value={"abort_all": True})
|
|
|
|
with patch("fastdeploy.router.router.aiohttp.ClientSession", return_value=mock_session):
|
|
resp = await abort_fn(mock_request)
|
|
|
|
import json
|
|
|
|
body = json.loads(resp.body)
|
|
self.assertEqual(body["status"], "error")
|
|
self.assertIsNotNone(body["error_message"])
|
|
|
|
@patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=True)
|
|
async def test_abort_decode_exception_returns_error(self, mock_health):
|
|
"""When D node connection fails (exception), error should be captured."""
|
|
from fastdeploy.router.router import abort_requests as abort_fn
|
|
from fastdeploy.router.router import app
|
|
|
|
router = Router(_make_args(splitwise=True))
|
|
await router.register_instance(_make_instance_dict(host_ip="10.0.0.1", port=8001, role="prefill"))
|
|
await router.register_instance(_make_instance_dict(host_ip="10.0.0.2", port=8002, role="decode"))
|
|
app.state.router = router
|
|
|
|
prefill_resp = AsyncMock()
|
|
prefill_resp.status = 200
|
|
prefill_resp.json = AsyncMock(
|
|
return_value={
|
|
"request_id": "control-p",
|
|
"status": "success",
|
|
"error_message": None,
|
|
"result": {"aborted": [], "not_found": []},
|
|
}
|
|
)
|
|
|
|
# D node raises exception — but asyncio.gather(return_exceptions=True) captures it
|
|
# So we pass the exception as a response directly
|
|
mock_session = self._make_mock_session([prefill_resp, prefill_resp]) # placeholder
|
|
call_idx = [0]
|
|
|
|
def post_with_exception(*args, **kwargs):
|
|
call_idx[0] += 1
|
|
if call_idx[0] == 1:
|
|
# prefill: normal
|
|
async def _coro():
|
|
return prefill_resp
|
|
|
|
return _coro()
|
|
else:
|
|
# decode: raise (gather with return_exceptions=True will catch)
|
|
async def _coro_err():
|
|
raise ConnectionError("refused")
|
|
|
|
return _coro_err()
|
|
|
|
mock_session.post = MagicMock(side_effect=post_with_exception)
|
|
mock_request = AsyncMock()
|
|
mock_request.json = AsyncMock(return_value={"abort_all": True})
|
|
|
|
with patch("fastdeploy.router.router.aiohttp.ClientSession", return_value=mock_session):
|
|
resp = await abort_fn(mock_request)
|
|
|
|
import json
|
|
|
|
body = json.loads(resp.body)
|
|
self.assertEqual(body["status"], "error")
|
|
self.assertIn("refused", body["error_message"])
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|