abort requests (#6992)

This commit is contained in:
qwes5s5
2026-03-31 11:02:26 +08:00
committed by GitHub
parent 6d9739f360
commit daa95244f7
13 changed files with 670 additions and 3 deletions
+1
View File
@@ -577,3 +577,4 @@ DeltaFunctionCall:
- `/v1/pause` - Pause generation (the service will refuse new inference requests while paused). Inflight requests are aborted and cache is reset.
- `/v1/resume` - Resume generation.
- `/v1/is_paused` - Check if generation is paused.
- `/v1/abort_requests` - Abort inference requests to release GPU memory (KV Cache blocks) and compute resources. Accepts `req_ids` (list of request IDs) or `abort_all=true` (abort all requests). Returns the list of aborted requests with their generated token counts.
+1
View File
@@ -151,6 +151,7 @@ The Router exposes a set of HTTP services to provide unified request scheduling,
|----------|------|------|
| POST | `/v1/chat/completions` | Provide scheduling services for inference requests based on the Chat Completions API |
| POST | `/v1/completions` | Provide scheduling services for general text completion inference requests |
| POST | `/v1/abort_requests` | Abort inference requests to release GPU memory and compute resources. Accepts `req_ids` or `abort_all=true`. Returns aborted requests with their generated token counts |
| POST | `/register` | Allow inference instances to register their metadata with the Router for scheduling |
| GET | `/registered` | Query the list of currently registered inference instances |
| GET | `/registered_number` | Query the number of currently registered inference instances |
+1
View File
@@ -563,3 +563,4 @@ DeltaFunctionCall:
/v1/pause - 暂停推理生成(会导致服务拒绝推理请求)。正在进行中的请求会被中止,缓存会被重置。
/v1/resume - 恢复推理生成。
/v1/is_paused - 检查推理生成是否已暂停。
/v1/abort_requests - 中断推理请求,释放 GPU 显存(KV Cache blocks)和计算资源。支持传入 `req_ids`(请求 ID 列表)或 `abort_all=true`(中断所有请求)。返回已中断请求列表及其已生成的 token 数。
+1
View File
@@ -152,6 +152,7 @@ Router 通过 HTTP 接口对外提供统一的调度服务,同时支持运行
|----------|------|------|
| POST | `/v1/chat/completions` | 对外提供基于 Chat 接口的推理请求调度服务 |
| POST | `/v1/completions` | 对外提供通用文本补全请求的调度服务 |
| POST | `/v1/abort_requests` | 中断推理请求,释放 GPU 显存和计算资源。支持传入 `req_ids` 或 `abort_all=true`,返回已中断请求列表及其已生成的 token 数 |
| POST | `/register` | 推理实例向 Router 注册自身信息,用于参与调度 |
| GET | `/registered` | 查询当前已注册的推理实例列表 |
| GET | `/registered_number` | 查询当前已注册的推理实例数量 |
+135
View File
@@ -43,9 +43,11 @@ from fastdeploy.cache_manager.cache_data import CacheStatus
from fastdeploy.config import FDConfig
from fastdeploy.engine.register_manager import RegisterManager
from fastdeploy.engine.request import (
CompletionOutput,
ControlRequest,
ControlResponse,
Request,
RequestMetrics,
RequestOutput,
RequestStatus,
RequestType,
@@ -1500,6 +1502,139 @@ class EngineService:
return responses
def _control_abort_requests(self, control_req: ControlRequest):
    """Abort in-flight requests and push their partial results back to clients.

    Args:
        control_req: ControlRequest whose args carry ``abort_all`` (bool) and
            ``req_ids`` (list of request IDs). IDs are matched exactly or via
            the internal ``"<rid>_0"`` suffix form.

    Returns:
        dict with:
            - "aborted": list of ``{"request_id", "output_token_count"}`` for
              every request that was aborted.
            - "not_found": input IDs that matched no live request
              (always ``[]`` when ``abort_all`` is set).

    Raises:
        Exception: if the v1 KV-cache scheduler is not enabled.
    """
    if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
        raise Exception("abort_requests only supported in ENABLE_V1_KVCACHE_SCHEDULER")
    args = control_req.get_args()
    abort_all = args.get("abort_all", False)
    req_ids = args.get("req_ids", [])
    matched_input_ids = set()
    now_reqs = list(set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys()))

    # Step 1: Determine target request list
    if abort_all:
        # all requests in running + waiting
        target_req_ids = now_reqs
    else:
        # filter out requests that actually exist (exact or "<rid>_0" form)
        target_req_ids = []
        for rid in req_ids:
            if rid in now_reqs:
                target_req_ids.append(rid)
                matched_input_ids.add(rid)
            elif f"{rid}_0" in now_reqs:
                target_req_ids.append(f"{rid}_0")
                matched_input_ids.add(rid)
    if not target_req_ids:
        return {"aborted": [], "not_found": req_ids if not abort_all else []}

    # Step 2: Collect partial results
    aborted_info = []
    results = []
    for req_id in target_req_ids:
        request = self.resource_manager.requests.get(req_id)
        if request is None:
            # Fall back to the scheduler's view; .raw holds the engine Request.
            scheduled_req = self.scheduler.requests.get(req_id)
            if scheduled_req is None:
                continue
            request = scheduled_req.raw
        partial_token_ids = list(request.output_token_ids)
        # Construct a "finished" response carrying the partial results.
        now = time.time()
        abort_metrics = RequestMetrics(
            arrival_time=request.metrics.arrival_time if request.metrics else now,
            inference_start_time=request.metrics.inference_start_time if request.metrics else now,
            engine_recv_latest_token_time=now,
            engine_recv_first_token_time=request.metrics.engine_recv_first_token_time if request.metrics else now,
            request_start_time=request.metrics.arrival_time if request.metrics else now,
        )
        result = RequestOutput(
            request_id=req_id,
            finished=True,
            outputs=CompletionOutput(
                index=0,
                send_idx=len(partial_token_ids),
                token_ids=[self.data_processor.eos_token_ids[0]],
            ),
            metrics=abort_metrics,
            error_code=200,
            error_msg="Aborted",
        )
        results.append(result)
        aborted_info.append(
            {
                "request_id": req_id,
                "output_token_count": len(partial_token_ids),
            }
        )

    # Step 3: Execute abort — add all requests to waiting_abort_req_id_set.
    # (ENABLE_V1_KVCACHE_SCHEDULER was already validated above, so the
    # previous redundant re-check has been dropped.)
    for req_id in target_req_ids:
        self.resource_manager.add_abort_req_ids(req_id)
    time.sleep(0.0001)
    if self.cfg.scheduler_config.splitwise_role != "prefill":
        self._wait_abort_complete(target_req_ids)
        # Add results to scheduler, engine will have a thread calling get_results,
        # then cleanup and call send_response to send to client.
        # When client disconnects, send_response will automatically ignore
        try:
            self.scheduler.put_results(results)
        except Exception as e:
            # Client may have disconnected; the abort itself already
            # succeeded, so only log instead of failing the control call.
            self.llm_logger.warning(f"put_results after abort failed: {e}")

    not_found = [rid for rid in req_ids if rid not in matched_input_ids] if not abort_all else []
    return {"aborted": aborted_info, "not_found": not_found}
def _wait_abort_complete(self, target_req_ids, stall_timeout=1):
    """
    Wait for all abort requests to complete.

    - Keep monitoring as long as remaining is not empty, which means cleanup is not done yet
    - If no progress within stall_timeout seconds, force cleanup requests stuck in to_be_aborted_req_id_set,
      reset progress state if any, then continue monitoring

    Args:
        target_req_ids: request IDs whose abort completion to wait for.
        stall_timeout: seconds without progress before force-cleaning
            requests stuck in to_be_aborted_req_id_set.
    """
    target_set = set(target_req_ids)
    prev_remaining_count = len(target_set)
    last_progress_time = time.time()
    # Single poll per iteration (the original polled get_reqs_in_aborting
    # twice per pass: once for the loop condition, once inside the body).
    while True:
        remaining = target_set & self.resource_manager.get_reqs_in_aborting()
        if not remaining:
            self.llm_logger.info(f"all {len(target_set)} abort reqs cleaned")
            return
        current_count = len(remaining)
        if current_count < prev_remaining_count:
            # progress made: recycle_abort_task was called
            self.llm_logger.info(f"abort progress: {prev_remaining_count} -> {current_count}")
            last_progress_time = time.time()
            prev_remaining_count = current_count
        if time.time() - last_progress_time > stall_timeout:
            # no progress timeout: only cleanup requests stuck in to_be_aborted (worker hasn't returned -9)
            stuck = remaining & self.resource_manager.to_be_aborted_req_id_set
            if stuck:
                self.llm_logger.warning(
                    f"no abort progress for {stall_timeout}s, "
                    f"force cleanup {len(stuck)} stuck requests (in to_be_aborted)"
                )
                for req_id in list(stuck):
                    self.llm_logger.warning(f"force cleanup stuck req_id:{req_id}")
                    self.resource_manager.recycle_abort_task(req_id)
                # reset progress state after the forced cleanup
                last_progress_time = time.time()
                prev_remaining_count = current_count - len(stuck)
            # else: remaining are all in waiting_abort_req_id_set, waiting for natural flow
        time.sleep(0.005)
def _parse_tags(self, control_request: ControlRequest):
"""
Parse tags from control request.
@@ -282,6 +282,7 @@ class ResourceManagerV1(ResourceManager):
del self.requests[request_id]
del self.req_dict[request_id]
self.to_be_aborted_req_id_set.remove(request_id)
self.update_metrics()
def _trigger_abort(self, request_id, scheduled_reqs):
if request_id in self.requests:
@@ -1207,6 +1208,9 @@ class ResourceManagerV1(ResourceManager):
return None
inputs["audio_features"] = result
def get_reqs_in_aborting(self):
    """Return the IDs of requests whose abort is still in flight.

    Combines the requests queued for abort (waiting_abort_req_id_set)
    with those awaiting worker acknowledgement (to_be_aborted_req_id_set).
    """
    pending = set(self.waiting_abort_req_id_set)
    pending.update(self.to_be_aborted_req_id_set)
    return pending
def get_available_position(self) -> int:
position = 0
while position < self.max_num_seqs:
@@ -475,6 +475,25 @@ async def update_weights(request: Request) -> Response:
return control_response.to_api_json_response()
@app.post("/v1/abort_requests")
async def abort_requests(request: Request):
    """Abort inference requests by forwarding a ControlRequest to the engine.

    The JSON body accepts ``abort_all`` (bool) and/or ``req_ids`` (list of
    request IDs); at least one of them must be provided, otherwise HTTP 400
    is returned. The engine's result is wrapped into the control-response
    JSON envelope.
    """
    body = await request.json()
    abort_all = body.get("abort_all", False)
    req_ids = body.get("req_ids", None)
    # Validate parameters: reject when neither abort_all nor req_ids is given.
    if not abort_all and not req_ids:
        return JSONResponse(status_code=400, content={"error": "must provide abort_all=true or req_ids"})
    control_request = ControlRequest(
        request_id=f"control-{uuid.uuid4()}",
        method="abort_requests",
        args={"abort_all": abort_all, "req_ids": req_ids or []},
    )
    control_response = await app.state.engine_client.run_control_method(control_request)
    return control_response.to_api_json_response()
def wrap_streaming_generator(original_generator: AsyncGenerator):
"""
Wrap an async generator to release the connection semaphore when the generator is finished.
@@ -469,6 +469,9 @@ class OpenAIServingChat:
if res.get("error_msg") is not None and "Recover" in res["error_msg"]:
choice.finish_reason = "recover_stop"
if res.get("error_msg") is not None and "Aborted" in res["error_msg"]:
choice.finish_reason = "abort"
inference_start_time[idx] = 0
if request.collect_metrics:
@@ -802,6 +805,8 @@ class OpenAIServingChat:
if data.get("error_msg", None) is not None and "Recover" in data["error_msg"]:
finish_reason = "recover_stop"
if data.get("error_msg", None) is not None and "Aborted" in data["error_msg"]:
finish_reason = "abort"
return ChatCompletionResponseChoice(
index=idx,
message=message,
@@ -586,6 +586,8 @@ class OpenAIServingCompletion:
output,
tool_called[idx],
)
if res.get("error_msg") is not None and "Aborted" in res["error_msg"]:
choices[-1].finish_reason = "abort"
inference_start_time[idx] = 0
send_idx = output.get("send_idx")
@@ -726,6 +728,8 @@ class OpenAIServingCompletion:
output,
False,
)
if final_res.get("error_msg", None) is not None and "Aborted" in final_res["error_msg"]:
finish_reason = "abort"
choice_data = CompletionResponseChoice(
token_ids=token_ids,
+44 -2
View File
@@ -17,8 +17,8 @@ from uuid import uuid4
import aiohttp
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, ORJSONResponse, Response, StreamingResponse
from fastdeploy.router.utils import (
InstanceInfo,
@@ -503,6 +503,48 @@ async def health_generate():
return Response(status_code=200)
@app.post("/v1/abort_requests")
async def abort_requests(request: Request):
    """Broadcast an abort request to all registered P/D instances.

    The JSON body is forwarded verbatim to every instance's
    /v1/abort_requests endpoint. Only decode-node results are merged into
    the returned payload; prefill responses (including their failures) are
    deliberately ignored.
    """
    body = await request.json()
    prefill_servers = app.state.router.prefill_servers
    decode_servers = app.state.router.decode_servers
    # Prefill entries come first; decode entries follow.
    all_servers = prefill_servers + decode_servers
    async with aiohttp.ClientSession() as session:
        # Fan out concurrently; return_exceptions=True turns connection
        # failures into entries of `responses` instead of raising here.
        tasks = [session.post(f"{server.url()}/v1/abort_requests", json=body) for server in all_servers]
        responses = await asyncio.gather(*tasks, return_exceptions=True)
        # Aggregate results from Node D only
        all_aborted = []
        all_not_found = []
        errors = []
        decode_start = len(prefill_servers)
        for i, (server, resp) in enumerate(zip(all_servers, responses)):
            if i < decode_start:
                # Prefill responses (and their errors) are intentionally skipped.
                continue
            if isinstance(resp, Exception):
                errors.append({"server": server.url(), "error": str(resp)})
            elif resp.status == 200:
                data = await resp.json()
                result = data.get("result") or {}
                all_aborted.extend(result.get("aborted", []))
                all_not_found.extend(result.get("not_found", []))
            else:
                # NOTE(review): non-200 response bodies are never read or
                # explicitly released here — presumably fine because the
                # session is closed right after, but confirm against the
                # aiohttp connection-pool behavior.
                errors.append({"server": server.url(), "status": resp.status})
    return JSONResponse(
        content={
            "request_id": f"router-{uuid4()}",
            "status": "success" if not errors else "error",
            "error_message": None if not errors else str(errors),
            "result": {
                "aborted": all_aborted,
                # Deduplicate IDs reported missing by multiple decode nodes.
                "not_found": list(set(all_not_found)),
            },
        }
    )
def launch_router(router_args: RouterArgs):
app.state.router_args = router_args
print(f"Starting router with args: {router_args}")
+212
View File
@@ -3510,3 +3510,215 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
# At least one sleep call was made, confirming the inner function executed
self.assertGreaterEqual(call_count[0], 1)
self._detach_finalizer(eng)
# ── _control_abort_requests / _wait_abort_complete ───────────────
def _make_abort_engine(self, splitwise_role="mixed"):
    """Create an engine wired up for abort tests.

    Builds a config/engine through the class helpers, then replaces the
    logger, data_processor, resource_manager and scheduler with mocks so
    _control_abort_requests can run without real workers.
    """
    extra = {}
    if splitwise_role != "mixed":
        # Non-mixed (splitwise) roles require a router endpoint in the config.
        extra["router"] = "0.0.0.0:9000"
    cfg = self._make_cfg(splitwise_role=splitwise_role, num_gpu_blocks_override=4, **extra)
    eng = self._make_engine(cfg)
    eng.llm_logger = MagicMock()
    # data_processor with eos token
    eng.data_processor = MagicMock()
    eng.data_processor.eos_token_ids = [2]
    # resource_manager with requests dict and abort sets
    eng.resource_manager = MagicMock()
    eng.resource_manager.requests = {}
    eng.resource_manager.waiting_abort_req_id_set = set()
    eng.resource_manager.to_be_aborted_req_id_set = set()
    # Mirror ResourceManagerV1.get_reqs_in_aborting: union of both abort sets.
    eng.resource_manager.get_reqs_in_aborting = lambda: (
        eng.resource_manager.waiting_abort_req_id_set | eng.resource_manager.to_be_aborted_req_id_set
    )
    # scheduler with requests dict and put_results
    eng.scheduler = MagicMock()
    eng.scheduler.requests = {}
    eng.scheduler.put_results = MagicMock()
    return eng
def _make_fake_request(self, output_token_ids=None):
    """Create a fake request object for abort tests.

    The mock mimics the attributes _control_abort_requests reads:
    output_token_ids plus a metrics object with timing fields.
    """
    req = MagicMock()
    req.output_token_ids = output_token_ids or [10, 20, 30]
    req.metrics = MagicMock()
    req.metrics.arrival_time = 1000.0
    req.metrics.inference_start_time = 1000.1
    req.metrics.engine_recv_first_token_time = 1000.2
    return req
def test_control_abort_requests_not_v1_raises(self):
    """abort_requests raises when ENABLE_V1_KVCACHE_SCHEDULER is off."""
    eng = self._make_abort_engine()
    control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []})
    # Disable the v1 scheduler flag so the guard clause fires.
    with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0):
        with self.assertRaises(Exception) as ctx:
            eng._control_abort_requests(control_req)
    self.assertIn("only supported", str(ctx.exception))
    self._detach_finalizer(eng)
def test_control_abort_requests_abort_all(self):
    """abort_all=True aborts all requests in resource_manager + scheduler."""
    eng = self._make_abort_engine()
    eng.resource_manager.requests = {"req-1_0": self._make_fake_request([10, 20])}
    # Scheduler-only requests are reached through their .raw attribute.
    eng.scheduler.requests = {"req-2_0": MagicMock(raw=self._make_fake_request([30]))}
    control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []})

    def clear_abort_sets(req_id):
        # Simulate immediate abort completion
        eng.resource_manager.waiting_abort_req_id_set.discard(req_id)

    eng.resource_manager.add_abort_req_ids = MagicMock(side_effect=clear_abort_sets)
    with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1):
        result = eng._control_abort_requests(control_req)
    self.assertEqual(len(result["aborted"]), 2)
    self.assertEqual(result["not_found"], [])
    ids = {a["request_id"] for a in result["aborted"]}
    self.assertEqual(ids, {"req-1_0", "req-2_0"})
    # put_results should have been called (not prefill)
    eng.scheduler.put_results.assert_called_once()
    self._detach_finalizer(eng)
def test_control_abort_requests_by_req_ids_with_suffix_match(self):
    """req_ids match both exact and _0 suffix."""
    eng = self._make_abort_engine()
    eng.resource_manager.requests = {
        "req-A_0": self._make_fake_request([1, 2, 3]),
        "req-B": self._make_fake_request([4, 5]),
    }
    # "req-C" exists nowhere, so it must end up in not_found.
    control_req = ControlRequest(
        "ctrl-1",
        "abort_requests",
        {
            "abort_all": False,
            "req_ids": ["req-A", "req-B", "req-C"],
        },
    )

    def clear_abort_sets(req_id):
        # Simulate immediate abort completion so the wait loop exits.
        eng.resource_manager.waiting_abort_req_id_set.discard(req_id)

    eng.resource_manager.add_abort_req_ids = MagicMock(side_effect=clear_abort_sets)
    with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1):
        result = eng._control_abort_requests(control_req)
    aborted_ids = {a["request_id"] for a in result["aborted"]}
    self.assertIn("req-A_0", aborted_ids)  # matched via _0 suffix
    self.assertIn("req-B", aborted_ids)  # exact match
    self.assertEqual(result["not_found"], ["req-C"])
    self._detach_finalizer(eng)
def test_control_abort_requests_no_match(self):
    """No requests found returns empty aborted and all in not_found."""
    eng = self._make_abort_engine()
    control_req = ControlRequest(
        "ctrl-1",
        "abort_requests",
        {
            "abort_all": False,
            "req_ids": ["nonexistent"],
        },
    )
    with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1):
        result = eng._control_abort_requests(control_req)
    # Early-return path: nothing matched, so nothing was aborted.
    self.assertEqual(result["aborted"], [])
    self.assertEqual(result["not_found"], ["nonexistent"])
    self._detach_finalizer(eng)
def test_control_abort_requests_prefill_skips_wait_and_put(self):
    """Prefill role skips _wait_abort_complete and put_results."""
    eng = self._make_abort_engine(splitwise_role="prefill")
    eng.resource_manager.requests = {"req-1_0": self._make_fake_request()}
    control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []})
    # No side effect needed: the prefill role never waits on the abort sets.
    eng.resource_manager.add_abort_req_ids = MagicMock()
    with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1):
        result = eng._control_abort_requests(control_req)
    self.assertEqual(len(result["aborted"]), 1)
    eng.scheduler.put_results.assert_not_called()
    self._detach_finalizer(eng)
def test_control_abort_requests_output_token_count(self):
    """output_token_count reflects partial_token_ids length."""
    eng = self._make_abort_engine()
    eng.resource_manager.requests = {"req-1_0": self._make_fake_request([10, 20, 30, 40, 50])}
    control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []})

    def clear_abort_sets(req_id):
        # Simulate immediate abort completion so the wait loop exits.
        eng.resource_manager.waiting_abort_req_id_set.discard(req_id)

    eng.resource_manager.add_abort_req_ids = MagicMock(side_effect=clear_abort_sets)
    with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1):
        result = eng._control_abort_requests(control_req)
    # Five partial tokens were generated before the abort.
    self.assertEqual(result["aborted"][0]["output_token_count"], 5)
    self._detach_finalizer(eng)
def test_wait_abort_complete_immediate(self):
    """_wait_abort_complete returns immediately when all requests already cleaned."""
    eng = self._make_abort_engine()
    # Empty abort sets → remaining is empty → returns immediately
    eng._wait_abort_complete(["req-1_0"])
    self._detach_finalizer(eng)
def test_wait_abort_complete_progress(self):
    """_wait_abort_complete exits when background thread cleans up."""
    eng = self._make_abort_engine()
    eng.resource_manager.waiting_abort_req_id_set = {"req-1_0"}
    call_count = [0]

    def fake_sleep(s):
        call_count[0] += 1
        # Simulate background thread cleaning up after first sleep
        eng.resource_manager.waiting_abort_req_id_set.discard("req-1_0")

    with patch("fastdeploy.engine.common_engine.time.sleep", fake_sleep):
        eng._wait_abort_complete(["req-1_0"])
    # At least one polling iteration must have slept before completion.
    self.assertGreaterEqual(call_count[0], 1)
    self._detach_finalizer(eng)
def test_wait_abort_complete_force_cleanup_stuck_in_to_be_aborted(self):
    """Stall timeout triggers force cleanup for requests in to_be_aborted_req_id_set."""
    eng = self._make_abort_engine()
    eng.resource_manager.to_be_aborted_req_id_set = {"req-1_0"}

    def mock_recycle(req_id):
        # Force cleanup removes the request from the stuck set.
        eng.resource_manager.to_be_aborted_req_id_set.discard(req_id)

    eng.resource_manager.recycle_abort_task = MagicMock(side_effect=mock_recycle)
    # Make time.time() advance past stall_timeout
    time_values = [100.0, 100.0, 102.0, 102.0, 102.0]
    time_idx = [0]

    def fake_time():
        # Return scripted timestamps; clamp to the last value once exhausted.
        idx = min(time_idx[0], len(time_values) - 1)
        time_idx[0] += 1
        return time_values[idx]

    with (
        patch("fastdeploy.engine.common_engine.time.time", fake_time),
        patch("fastdeploy.engine.common_engine.time.sleep", lambda s: None),
    ):
        eng._wait_abort_complete(["req-1_0"], stall_timeout=1)
    eng.resource_manager.recycle_abort_task.assert_called_with("req-1_0")
    self._detach_finalizer(eng)
@@ -809,3 +809,80 @@ def test_config_info():
api_server = _reload_api_server(args)
api_server.llm_engine = None
assert api_server.config_info().status_code == 500
# ── /v1/abort_requests ──────────────────────────────────────────────
def _mock_abort_control_response(api_server, result, status_code=200):
    """Install a mocked engine_client whose run_control_method resolves to `result`."""
    mock_resp = MagicMock()
    mock_resp.to_api_json_response.return_value = api_server.JSONResponse(
        content={"request_id": "control-test", "status": "success", "error_message": None, "result": result},
        status_code=status_code,
    )
    api_server.app.state.engine_client = MagicMock()
    api_server.app.state.engine_client.run_control_method = AsyncMock(return_value=mock_resp)
@pytest.mark.asyncio
async def test_abort_requests_with_req_ids():
    """Passing req_ids builds a ControlRequest with abort_all=False."""
    args = _build_args()
    api_server = _reload_api_server(args)
    _mock_abort_control_response(
        api_server,
        {
            "aborted": [{"request_id": "req-1_0", "output_token_count": 10}],
            "not_found": ["req-999"],
        },
    )
    req = MagicMock()
    req.json = AsyncMock(return_value={"req_ids": ["req-1", "req-999"]})
    resp = await api_server.abort_requests(req)
    assert resp.status_code == 200
    # Inspect the ControlRequest that was forwarded to the engine client.
    control_req = api_server.app.state.engine_client.run_control_method.await_args.args[0]
    assert control_req.method == "abort_requests"
    assert control_req.args["req_ids"] == ["req-1", "req-999"]
    assert control_req.args["abort_all"] is False
@pytest.mark.asyncio
async def test_abort_requests_with_abort_all():
    """abort_all=True is forwarded with an empty req_ids list."""
    args = _build_args()
    api_server = _reload_api_server(args)
    _mock_abort_control_response(
        api_server,
        {
            "aborted": [
                {"request_id": "req-1_0", "output_token_count": 5},
                {"request_id": "req-2_0", "output_token_count": 12},
            ],
            "not_found": [],
        },
    )
    req = MagicMock()
    req.json = AsyncMock(return_value={"abort_all": True})
    resp = await api_server.abort_requests(req)
    assert resp.status_code == 200
    control_req = api_server.app.state.engine_client.run_control_method.await_args.args[0]
    assert control_req.args["abort_all"] is True
    assert control_req.args["req_ids"] == []
@pytest.mark.asyncio
async def test_abort_requests_missing_params():
    """An empty body (no abort_all, no req_ids) is rejected with HTTP 400."""
    args = _build_args()
    api_server = _reload_api_server(args)
    req = MagicMock()
    req.json = AsyncMock(return_value={})
    resp = await api_server.abort_requests(req)
    assert resp.status_code == 400
@pytest.mark.asyncio
async def test_abort_requests_empty_req_ids():
    """An empty req_ids list counts as missing and is rejected with HTTP 400."""
    args = _build_args()
    api_server = _reload_api_server(args)
    req = MagicMock()
    req.json = AsyncMock(return_value={"req_ids": []})
    resp = await api_server.abort_requests(req)
    assert resp.status_code == 400
+166 -1
View File
@@ -22,7 +22,7 @@ Why mock:
import unittest
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
from fastdeploy.router.router import Router, RouterArgs
@@ -144,5 +144,170 @@ class TestRouterRegisteredNumber(unittest.IsolatedAsyncioTestCase):
self.assertEqual(result["decode"], 0)
class TestRouterAbortRequests(unittest.IsolatedAsyncioTestCase):
    """Tests for /v1/abort_requests route in router.py."""

    def _make_mock_session(self, responses):
        """Create a mock aiohttp.ClientSession where post() returns coroutines."""
        mock_session = MagicMock()
        call_count = 0

        def post_side_effect(*args, **kwargs):
            # Hand out the scripted responses in call order; exception
            # entries are raised synchronously from post() itself.
            nonlocal call_count
            resp = responses[call_count]
            call_count += 1
            if isinstance(resp, Exception):
                raise resp

            async def _coro():
                return resp

            return _coro()

        mock_session.post = MagicMock(side_effect=post_side_effect)
        # Support "async with aiohttp.ClientSession() as session".
        mock_session.__aenter__ = AsyncMock(return_value=mock_session)
        mock_session.__aexit__ = AsyncMock(return_value=False)
        return mock_session

    @patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=True)
    async def test_abort_broadcasts_to_all_but_returns_decode_only(self, mock_health):
        """P and D both receive the request, but only D results are aggregated."""
        from fastdeploy.router.router import abort_requests as abort_fn
        from fastdeploy.router.router import app

        router = Router(_make_args(splitwise=True))
        await router.register_instance(_make_instance_dict(host_ip="10.0.0.1", port=8001, role="prefill"))
        await router.register_instance(_make_instance_dict(host_ip="10.0.0.2", port=8002, role="decode"))
        app.state.router = router
        # Prefill result (token_count=0) must NOT appear in the aggregate.
        prefill_resp = AsyncMock()
        prefill_resp.status = 200
        prefill_resp.json = AsyncMock(
            return_value={
                "request_id": "control-p",
                "status": "success",
                "error_message": None,
                "result": {"aborted": [{"request_id": "req-1_0", "output_token_count": 0}], "not_found": []},
            }
        )
        # Decode result (token_count=15) is the one the router should report.
        decode_resp = AsyncMock()
        decode_resp.status = 200
        decode_resp.json = AsyncMock(
            return_value={
                "request_id": "control-d",
                "status": "success",
                "error_message": None,
                "result": {"aborted": [{"request_id": "req-1_0", "output_token_count": 15}], "not_found": []},
            }
        )
        mock_session = self._make_mock_session([prefill_resp, decode_resp])
        mock_request = AsyncMock()
        mock_request.json = AsyncMock(return_value={"req_ids": ["req-1"]})
        with patch("fastdeploy.router.router.aiohttp.ClientSession", return_value=mock_session):
            resp = await abort_fn(mock_request)
        import json

        body = json.loads(resp.body)
        self.assertEqual(len(body["result"]["aborted"]), 1)
        self.assertEqual(body["result"]["aborted"][0]["output_token_count"], 15)
        self.assertEqual(body["status"], "success")
        # Both servers were posted to, even though only D was aggregated.
        self.assertEqual(mock_session.post.call_count, 2)

    @patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=True)
    async def test_abort_decode_error_returns_error_status(self, mock_health):
        """When D node returns a non-200 status, status should be 'error'."""
        from fastdeploy.router.router import abort_requests as abort_fn
        from fastdeploy.router.router import app

        router = Router(_make_args(splitwise=True))
        await router.register_instance(_make_instance_dict(host_ip="10.0.0.1", port=8001, role="prefill"))
        await router.register_instance(_make_instance_dict(host_ip="10.0.0.2", port=8002, role="decode"))
        app.state.router = router
        prefill_resp = AsyncMock()
        prefill_resp.status = 200
        prefill_resp.json = AsyncMock(
            return_value={
                "request_id": "control-p",
                "status": "success",
                "error_message": None,
                "result": {"aborted": [], "not_found": []},
            }
        )
        # Decode node fails with HTTP 500.
        decode_resp = AsyncMock()
        decode_resp.status = 500
        mock_session = self._make_mock_session([prefill_resp, decode_resp])
        mock_request = AsyncMock()
        mock_request.json = AsyncMock(return_value={"abort_all": True})
        with patch("fastdeploy.router.router.aiohttp.ClientSession", return_value=mock_session):
            resp = await abort_fn(mock_request)
        import json

        body = json.loads(resp.body)
        self.assertEqual(body["status"], "error")
        self.assertIsNotNone(body["error_message"])

    @patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=True)
    async def test_abort_decode_exception_returns_error(self, mock_health):
        """When D node connection fails (exception), error should be captured."""
        from fastdeploy.router.router import abort_requests as abort_fn
        from fastdeploy.router.router import app

        router = Router(_make_args(splitwise=True))
        await router.register_instance(_make_instance_dict(host_ip="10.0.0.1", port=8001, role="prefill"))
        await router.register_instance(_make_instance_dict(host_ip="10.0.0.2", port=8002, role="decode"))
        app.state.router = router
        prefill_resp = AsyncMock()
        prefill_resp.status = 200
        prefill_resp.json = AsyncMock(
            return_value={
                "request_id": "control-p",
                "status": "success",
                "error_message": None,
                "result": {"aborted": [], "not_found": []},
            }
        )
        # D node raises exception — but asyncio.gather(return_exceptions=True) captures it
        # So we pass the exception as a response directly
        mock_session = self._make_mock_session([prefill_resp, prefill_resp])  # placeholder
        call_idx = [0]

        def post_with_exception(*args, **kwargs):
            # First call (prefill) succeeds; second (decode) awaits into an error.
            call_idx[0] += 1
            if call_idx[0] == 1:
                # prefill: normal
                async def _coro():
                    return prefill_resp

                return _coro()
            else:
                # decode: raise (gather with return_exceptions=True will catch)
                async def _coro_err():
                    raise ConnectionError("refused")

                return _coro_err()

        mock_session.post = MagicMock(side_effect=post_with_exception)
        mock_request = AsyncMock()
        mock_request.json = AsyncMock(return_value={"abort_all": True})
        with patch("fastdeploy.router.router.aiohttp.ClientSession", return_value=mock_session):
            resp = await abort_fn(mock_request)
        import json

        body = json.loads(resp.body)
        self.assertEqual(body["status"], "error")
        self.assertIn("refused", body["error_message"])
# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()