mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-22 16:07:51 +08:00
abort requests (#6992)
This commit is contained in:
@@ -577,3 +577,4 @@ DeltaFunctionCall:
|
||||
- `/v1/pause` - Pause generation (the service will reject inference requests while paused). Inflight requests are aborted and cache is reset.
|
||||
- `/v1/resume` - Resume generation.
|
||||
- `/v1/is_paused` - Check if generation is paused.
|
||||
- `/v1/abort_requests` - Abort inference requests to release GPU memory (KV Cache blocks) and compute resources. Accepts `req_ids` (list of request IDs) or `abort_all=true` (abort all requests). Returns the list of aborted requests with their generated token counts.
|
||||
|
||||
@@ -151,6 +151,7 @@ The Router exposes a set of HTTP services to provide unified request scheduling,
|
||||
|----------|------|------|
|
||||
| POST | `/v1/chat/completions` | Provide scheduling services for inference requests based on the Chat Completions API |
|
||||
| POST | `/v1/completions` | Provide scheduling services for general text completion inference requests |
|
||||
| POST | `/v1/abort_requests` | Abort inference requests to release GPU memory and compute resources. Accepts `req_ids` or `abort_all=true`. Returns aborted requests with their generated token counts |
|
||||
| POST | `/register` | Allow inference instances to register their metadata with the Router for scheduling |
|
||||
| GET | `/registered` | Query the list of currently registered inference instances |
|
||||
| GET | `/registered_number` | Query the number of currently registered inference instances |
|
||||
|
||||
@@ -563,3 +563,4 @@ DeltaFunctionCall:
|
||||
/v1/pause - 暂停推理生成(会导致服务拒绝推理请求)。正在进行中的请求会被中止,缓存会被重置。
|
||||
/v1/resume - 恢复推理生成。
|
||||
/v1/is_paused - 检查推理生成是否已暂停。
|
||||
/v1/abort_requests - 中断推理请求,释放 GPU 显存(KV Cache blocks)和计算资源。支持传入 `req_ids`(请求 ID 列表)或 `abort_all=true`(中断所有请求)。返回已中断请求列表及其已生成的 token 数。
|
||||
|
||||
@@ -152,6 +152,7 @@ Router 通过 HTTP 接口对外提供统一的调度服务,同时支持运行
|
||||
|----------|------|------|
|
||||
| POST | `/v1/chat/completions` | 对外提供基于 Chat 接口的推理请求调度服务 |
|
||||
| POST | `/v1/completions` | 对外提供通用文本补全请求的调度服务 |
|
||||
| POST | `/v1/abort_requests` | 中断推理请求,释放 GPU 显存和计算资源。支持传入 `req_ids` 或 `abort_all=true`,返回已中断请求列表及其已生成的 token 数 |
|
||||
| POST | `/register` | 推理实例向 Router 注册自身信息,用于参与调度 |
|
||||
| GET | `/registered` | 查询当前已注册的推理实例列表 |
|
||||
| GET | `/registered_number` | 查询当前已注册的推理实例数量 |
|
||||
|
||||
@@ -43,9 +43,11 @@ from fastdeploy.cache_manager.cache_data import CacheStatus
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.engine.register_manager import RegisterManager
|
||||
from fastdeploy.engine.request import (
|
||||
CompletionOutput,
|
||||
ControlRequest,
|
||||
ControlResponse,
|
||||
Request,
|
||||
RequestMetrics,
|
||||
RequestOutput,
|
||||
RequestStatus,
|
||||
RequestType,
|
||||
@@ -1500,6 +1502,139 @@ class EngineService:
|
||||
|
||||
return responses
|
||||
|
||||
def _control_abort_requests(self, control_req: ControlRequest):
    """Abort in-flight inference requests and return their partial results.

    Handles the ``abort_requests`` control method. Either aborts every
    request currently known to the resource manager / scheduler
    (``abort_all=True``) or only the requests named in ``req_ids``; for the
    latter, an id is matched either exactly or with an ``_0`` suffix
    (the engine's internal per-request naming).

    Args:
        control_req: ControlRequest whose args carry ``abort_all`` (bool)
            and ``req_ids`` (list of request ids).

    Returns:
        dict with ``aborted`` (list of ``{"request_id", "output_token_count"}``
        for each aborted request) and ``not_found`` (input ids that matched
        nothing; always empty when ``abort_all`` is used).

    Raises:
        Exception: when the V1 KV-cache scheduler is not enabled — the abort
            pipeline only exists in that scheduler.
    """
    if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
        raise Exception("abort_requests only supported in ENABLE_V1_KVCACHE_SCHEDULER")
    args = control_req.get_args()
    abort_all = args.get("abort_all", False)
    req_ids = args.get("req_ids", [])
    matched_input_ids = set()
    # Snapshot of every request id currently tracked by either component.
    now_reqs = list(set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys()))

    # Step 1: Determine target request list
    if abort_all:
        # all requests in running + waiting
        target_req_ids = now_reqs
    else:
        # filter out requests that actually exist (exact id or "<id>_0" variant)
        target_req_ids = []
        for rid in req_ids:
            if rid in now_reqs:
                target_req_ids.append(rid)
                matched_input_ids.add(rid)
            elif f"{rid}_0" in now_reqs:
                target_req_ids.append(f"{rid}_0")
                matched_input_ids.add(rid)

    if not target_req_ids:
        return {"aborted": [], "not_found": req_ids if not abort_all else []}

    # Step 2: Collect partial results before the abort destroys them.
    aborted_info = []
    results = []
    for req_id in target_req_ids:
        request = self.resource_manager.requests.get(req_id)
        if request is None:
            # Not yet admitted by the resource manager — fall back to the
            # scheduler's queued copy, whose .raw holds the original request.
            scheduled_req = self.scheduler.requests.get(req_id)
            if scheduled_req is None:
                continue
            request = scheduled_req.raw

        partial_token_ids = list(request.output_token_ids)

        # Construct a finished response carrying the partial results.
        # Metrics fall back to "now" when the request never recorded any.
        now = time.time()
        abort_metrics = RequestMetrics(
            arrival_time=request.metrics.arrival_time if request.metrics else now,
            inference_start_time=request.metrics.inference_start_time if request.metrics else now,
            engine_recv_latest_token_time=now,
            engine_recv_first_token_time=request.metrics.engine_recv_first_token_time if request.metrics else now,
            request_start_time=request.metrics.arrival_time if request.metrics else now,
        )
        result = RequestOutput(
            request_id=req_id,
            finished=True,
            outputs=CompletionOutput(
                index=0,
                send_idx=len(partial_token_ids),
                # Emit an EOS token so downstream stream handling terminates.
                token_ids=[self.data_processor.eos_token_ids[0]],
            ),
            metrics=abort_metrics,
            # error_msg "Aborted" is matched by the serving layer to set
            # finish_reason = "abort"; do not change this string.
            error_code=200,
            error_msg="Aborted",
        )
        results.append(result)
        aborted_info.append(
            {
                "request_id": req_id,
                "output_token_count": len(partial_token_ids),
            }
        )

    # Step 3: Execute abort — add all requests to waiting_abort_req_id_set.
    # NOTE: this condition is always true here (we raised above otherwise);
    # kept for defensive symmetry with the guard at the top.
    if envs.ENABLE_V1_KVCACHE_SCHEDULER:
        for req_id in target_req_ids:
            self.resource_manager.add_abort_req_ids(req_id)
            # Tiny yield so the scheduler thread can observe each abort.
            time.sleep(0.0001)
        if self.cfg.scheduler_config.splitwise_role != "prefill":
            self._wait_abort_complete(target_req_ids)

    # Add results to scheduler; the engine has a thread calling get_results,
    # which cleans up and calls send_response to deliver to the client.
    # When the client has disconnected, send_response ignores the payload.
    if self.cfg.scheduler_config.splitwise_role != "prefill":
        try:
            # self.send_response_server.send_response(req_id, [result])
            self.scheduler.put_results(results)
        except Exception:
            pass  # client may have disconnected

    not_found = [rid for rid in req_ids if rid not in matched_input_ids] if not abort_all else []

    return {"aborted": aborted_info, "not_found": not_found}
|
||||
|
||||
def _wait_abort_complete(self, target_req_ids, stall_timeout=1):
|
||||
"""
|
||||
Wait for all abort requests to complete.
|
||||
- Keep monitoring as long as remaining is not empty, which means cleanup is not done yet
|
||||
- If no progress within stall_timeout seconds, force cleanup requests stuck in to_be_aborted_req_id_set,
|
||||
reset progress state if any, then continue monitoring
|
||||
"""
|
||||
target_set = set(target_req_ids)
|
||||
prev_remaining_count = len(target_set)
|
||||
last_progress_time = time.time()
|
||||
remaining = target_set & self.resource_manager.get_reqs_in_aborting()
|
||||
while remaining:
|
||||
remaining = target_set & self.resource_manager.get_reqs_in_aborting()
|
||||
if not remaining:
|
||||
self.llm_logger.info(f"all {len(target_set)} abort reqs cleaned")
|
||||
return
|
||||
|
||||
current_count = len(remaining)
|
||||
if current_count < prev_remaining_count:
|
||||
# progress made: recycle_abort_task was called
|
||||
self.llm_logger.info(f"abort progress: {prev_remaining_count} -> {current_count}")
|
||||
last_progress_time = time.time()
|
||||
prev_remaining_count = current_count
|
||||
|
||||
if time.time() - last_progress_time > stall_timeout:
|
||||
# no progress timeout: only cleanup requests stuck in to_be_aborted (worker hasn't returned -9)
|
||||
stuck = remaining & self.resource_manager.to_be_aborted_req_id_set
|
||||
if stuck:
|
||||
self.llm_logger.warning(
|
||||
f"no abort progress for {stall_timeout}s, "
|
||||
f"force cleanup {len(stuck)} stuck requests (in to_be_aborted)"
|
||||
)
|
||||
for req_id in list(stuck):
|
||||
self.llm_logger.warning(f"force cleanup stuck req_id:{req_id}")
|
||||
self.resource_manager.recycle_abort_task(req_id)
|
||||
# reset progress state
|
||||
last_progress_time = time.time()
|
||||
prev_remaining_count = current_count - len(stuck)
|
||||
# else: remaining are all in waiting_abort_req_id_set, waiting for natural flow
|
||||
|
||||
time.sleep(0.005)
|
||||
|
||||
def _parse_tags(self, control_request: ControlRequest):
|
||||
"""
|
||||
Parse tags from control request.
|
||||
|
||||
@@ -282,6 +282,7 @@ class ResourceManagerV1(ResourceManager):
|
||||
del self.requests[request_id]
|
||||
del self.req_dict[request_id]
|
||||
self.to_be_aborted_req_id_set.remove(request_id)
|
||||
self.update_metrics()
|
||||
|
||||
def _trigger_abort(self, request_id, scheduled_reqs):
|
||||
if request_id in self.requests:
|
||||
@@ -1207,6 +1208,9 @@ class ResourceManagerV1(ResourceManager):
|
||||
return None
|
||||
inputs["audio_features"] = result
|
||||
|
||||
def get_reqs_in_aborting(self):
    """Return ids of every request currently in the abort pipeline.

    Combines requests queued for abort (waiting set) with requests whose
    abort was issued but has not yet been recycled (to-be-aborted set).
    """
    return self.waiting_abort_req_id_set.union(self.to_be_aborted_req_id_set)
|
||||
|
||||
def get_available_position(self) -> int:
|
||||
position = 0
|
||||
while position < self.max_num_seqs:
|
||||
|
||||
@@ -475,6 +475,25 @@ async def update_weights(request: Request) -> Response:
|
||||
return control_response.to_api_json_response()
|
||||
|
||||
|
||||
@app.post("/v1/abort_requests")
|
||||
async def abort_requests(request: Request):
|
||||
body = await request.json()
|
||||
abort_all = body.get("abort_all", False)
|
||||
req_ids = body.get("req_ids", None)
|
||||
|
||||
# 参数校验
|
||||
if not abort_all and not req_ids:
|
||||
return JSONResponse(status_code=400, content={"error": "must provide abort_all=true or req_ids"})
|
||||
|
||||
control_request = ControlRequest(
|
||||
request_id=f"control-{uuid.uuid4()}",
|
||||
method="abort_requests",
|
||||
args={"abort_all": abort_all, "req_ids": req_ids or []},
|
||||
)
|
||||
control_response = await app.state.engine_client.run_control_method(control_request)
|
||||
return control_response.to_api_json_response()
|
||||
|
||||
|
||||
def wrap_streaming_generator(original_generator: AsyncGenerator):
|
||||
"""
|
||||
Wrap an async generator to release the connection semaphore when the generator is finished.
|
||||
|
||||
@@ -469,6 +469,9 @@ class OpenAIServingChat:
|
||||
if res.get("error_msg") is not None and "Recover" in res["error_msg"]:
|
||||
choice.finish_reason = "recover_stop"
|
||||
|
||||
if res.get("error_msg") is not None and "Aborted" in res["error_msg"]:
|
||||
choice.finish_reason = "abort"
|
||||
|
||||
inference_start_time[idx] = 0
|
||||
|
||||
if request.collect_metrics:
|
||||
@@ -802,6 +805,8 @@ class OpenAIServingChat:
|
||||
if data.get("error_msg", None) is not None and "Recover" in data["error_msg"]:
|
||||
finish_reason = "recover_stop"
|
||||
|
||||
if data.get("error_msg", None) is not None and "Aborted" in data["error_msg"]:
|
||||
finish_reason = "abort"
|
||||
return ChatCompletionResponseChoice(
|
||||
index=idx,
|
||||
message=message,
|
||||
|
||||
@@ -586,6 +586,8 @@ class OpenAIServingCompletion:
|
||||
output,
|
||||
tool_called[idx],
|
||||
)
|
||||
if res.get("error_msg") is not None and "Aborted" in res["error_msg"]:
|
||||
choices[-1].finish_reason = "abort"
|
||||
inference_start_time[idx] = 0
|
||||
|
||||
send_idx = output.get("send_idx")
|
||||
@@ -726,6 +728,8 @@ class OpenAIServingCompletion:
|
||||
output,
|
||||
False,
|
||||
)
|
||||
if final_res.get("error_msg", None) is not None and "Aborted" in final_res["error_msg"]:
|
||||
finish_reason = "abort"
|
||||
|
||||
choice_data = CompletionResponseChoice(
|
||||
token_ids=token_ids,
|
||||
|
||||
@@ -17,8 +17,8 @@ from uuid import uuid4
|
||||
|
||||
import aiohttp
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, ORJSONResponse, Response, StreamingResponse
|
||||
|
||||
from fastdeploy.router.utils import (
|
||||
InstanceInfo,
|
||||
@@ -503,6 +503,48 @@ async def health_generate():
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.post("/v1/abort_requests")
|
||||
async def abort_requests(request: Request):
|
||||
body = await request.json()
|
||||
prefill_servers = app.state.router.prefill_servers
|
||||
decode_servers = app.state.router.decode_servers
|
||||
all_servers = prefill_servers + decode_servers
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
tasks = [session.post(f"{server.url()}/v1/abort_requests", json=body) for server in all_servers]
|
||||
responses = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Aggregate results from Node D only
|
||||
all_aborted = []
|
||||
all_not_found = []
|
||||
errors = []
|
||||
decode_start = len(prefill_servers)
|
||||
for i, (server, resp) in enumerate(zip(all_servers, responses)):
|
||||
if i < decode_start:
|
||||
continue
|
||||
if isinstance(resp, Exception):
|
||||
errors.append({"server": server.url(), "error": str(resp)})
|
||||
elif resp.status == 200:
|
||||
data = await resp.json()
|
||||
result = data.get("result") or {}
|
||||
all_aborted.extend(result.get("aborted", []))
|
||||
all_not_found.extend(result.get("not_found", []))
|
||||
else:
|
||||
errors.append({"server": server.url(), "status": resp.status})
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
"request_id": f"router-{uuid4()}",
|
||||
"status": "success" if not errors else "error",
|
||||
"error_message": None if not errors else str(errors),
|
||||
"result": {
|
||||
"aborted": all_aborted,
|
||||
"not_found": list(set(all_not_found)),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def launch_router(router_args: RouterArgs):
|
||||
app.state.router_args = router_args
|
||||
print(f"Starting router with args: {router_args}")
|
||||
|
||||
@@ -3510,3 +3510,215 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
|
||||
# At least one sleep call was made, confirming the inner function executed
|
||||
self.assertGreaterEqual(call_count[0], 1)
|
||||
self._detach_finalizer(eng)
|
||||
|
||||
# ── _control_abort_requests / _wait_abort_complete ───────────────
|
||||
|
||||
def _make_abort_engine(self, splitwise_role="mixed"):
    """Create an engine wired up for abort tests.

    Builds a minimally-configured engine whose llm_logger, data_processor,
    resource_manager and scheduler are mocked so that the abort-path
    methods can run without a real model, scheduler thread, or GPU.
    """
    extra = {}
    if splitwise_role != "mixed":
        # NOTE(review): non-mixed (splitwise) roles appear to need a router
        # endpoint in the config — confirm against _make_cfg's contract.
        extra["router"] = "0.0.0.0:9000"
    cfg = self._make_cfg(splitwise_role=splitwise_role, num_gpu_blocks_override=4, **extra)
    eng = self._make_engine(cfg)
    eng.llm_logger = MagicMock()

    # data_processor with eos token (used to terminate aborted streams)
    eng.data_processor = MagicMock()
    eng.data_processor.eos_token_ids = [2]

    # resource_manager with requests dict and abort sets
    eng.resource_manager = MagicMock()
    eng.resource_manager.requests = {}
    eng.resource_manager.waiting_abort_req_id_set = set()
    eng.resource_manager.to_be_aborted_req_id_set = set()
    # Mirror ResourceManagerV1.get_reqs_in_aborting: union of both sets.
    eng.resource_manager.get_reqs_in_aborting = lambda: (
        eng.resource_manager.waiting_abort_req_id_set | eng.resource_manager.to_be_aborted_req_id_set
    )

    # scheduler with requests dict and put_results
    eng.scheduler = MagicMock()
    eng.scheduler.requests = {}
    eng.scheduler.put_results = MagicMock()

    return eng
|
||||
|
||||
def _make_fake_request(self, output_token_ids=None):
|
||||
"""Create a fake request object for abort tests."""
|
||||
req = MagicMock()
|
||||
req.output_token_ids = output_token_ids or [10, 20, 30]
|
||||
req.metrics = MagicMock()
|
||||
req.metrics.arrival_time = 1000.0
|
||||
req.metrics.inference_start_time = 1000.1
|
||||
req.metrics.engine_recv_first_token_time = 1000.2
|
||||
return req
|
||||
|
||||
def test_control_abort_requests_not_v1_raises(self):
|
||||
"""abort_requests raises when ENABLE_V1_KVCACHE_SCHEDULER is off."""
|
||||
eng = self._make_abort_engine()
|
||||
control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []})
|
||||
with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0):
|
||||
with self.assertRaises(Exception) as ctx:
|
||||
eng._control_abort_requests(control_req)
|
||||
self.assertIn("only supported", str(ctx.exception))
|
||||
self._detach_finalizer(eng)
|
||||
|
||||
def test_control_abort_requests_abort_all(self):
|
||||
"""abort_all=True aborts all requests in resource_manager + scheduler."""
|
||||
eng = self._make_abort_engine()
|
||||
eng.resource_manager.requests = {"req-1_0": self._make_fake_request([10, 20])}
|
||||
eng.scheduler.requests = {"req-2_0": MagicMock(raw=self._make_fake_request([30]))}
|
||||
|
||||
control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []})
|
||||
|
||||
def clear_abort_sets(req_id):
|
||||
# Simulate immediate abort completion
|
||||
eng.resource_manager.waiting_abort_req_id_set.discard(req_id)
|
||||
|
||||
eng.resource_manager.add_abort_req_ids = MagicMock(side_effect=clear_abort_sets)
|
||||
|
||||
with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1):
|
||||
result = eng._control_abort_requests(control_req)
|
||||
|
||||
self.assertEqual(len(result["aborted"]), 2)
|
||||
self.assertEqual(result["not_found"], [])
|
||||
ids = {a["request_id"] for a in result["aborted"]}
|
||||
self.assertEqual(ids, {"req-1_0", "req-2_0"})
|
||||
# put_results should have been called (not prefill)
|
||||
eng.scheduler.put_results.assert_called_once()
|
||||
self._detach_finalizer(eng)
|
||||
|
||||
def test_control_abort_requests_by_req_ids_with_suffix_match(self):
|
||||
"""req_ids match both exact and _0 suffix."""
|
||||
eng = self._make_abort_engine()
|
||||
eng.resource_manager.requests = {
|
||||
"req-A_0": self._make_fake_request([1, 2, 3]),
|
||||
"req-B": self._make_fake_request([4, 5]),
|
||||
}
|
||||
|
||||
control_req = ControlRequest(
|
||||
"ctrl-1",
|
||||
"abort_requests",
|
||||
{
|
||||
"abort_all": False,
|
||||
"req_ids": ["req-A", "req-B", "req-C"],
|
||||
},
|
||||
)
|
||||
|
||||
def clear_abort_sets(req_id):
|
||||
eng.resource_manager.waiting_abort_req_id_set.discard(req_id)
|
||||
|
||||
eng.resource_manager.add_abort_req_ids = MagicMock(side_effect=clear_abort_sets)
|
||||
|
||||
with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1):
|
||||
result = eng._control_abort_requests(control_req)
|
||||
|
||||
aborted_ids = {a["request_id"] for a in result["aborted"]}
|
||||
self.assertIn("req-A_0", aborted_ids) # matched via _0 suffix
|
||||
self.assertIn("req-B", aborted_ids) # exact match
|
||||
self.assertEqual(result["not_found"], ["req-C"])
|
||||
self._detach_finalizer(eng)
|
||||
|
||||
def test_control_abort_requests_no_match(self):
|
||||
"""No requests found returns empty aborted and all in not_found."""
|
||||
eng = self._make_abort_engine()
|
||||
control_req = ControlRequest(
|
||||
"ctrl-1",
|
||||
"abort_requests",
|
||||
{
|
||||
"abort_all": False,
|
||||
"req_ids": ["nonexistent"],
|
||||
},
|
||||
)
|
||||
|
||||
with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1):
|
||||
result = eng._control_abort_requests(control_req)
|
||||
|
||||
self.assertEqual(result["aborted"], [])
|
||||
self.assertEqual(result["not_found"], ["nonexistent"])
|
||||
self._detach_finalizer(eng)
|
||||
|
||||
def test_control_abort_requests_prefill_skips_wait_and_put(self):
|
||||
"""Prefill role skips _wait_abort_complete and put_results."""
|
||||
eng = self._make_abort_engine(splitwise_role="prefill")
|
||||
eng.resource_manager.requests = {"req-1_0": self._make_fake_request()}
|
||||
|
||||
control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []})
|
||||
eng.resource_manager.add_abort_req_ids = MagicMock()
|
||||
|
||||
with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1):
|
||||
result = eng._control_abort_requests(control_req)
|
||||
|
||||
self.assertEqual(len(result["aborted"]), 1)
|
||||
eng.scheduler.put_results.assert_not_called()
|
||||
self._detach_finalizer(eng)
|
||||
|
||||
def test_control_abort_requests_output_token_count(self):
|
||||
"""output_token_count reflects partial_token_ids length."""
|
||||
eng = self._make_abort_engine()
|
||||
eng.resource_manager.requests = {"req-1_0": self._make_fake_request([10, 20, 30, 40, 50])}
|
||||
|
||||
control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []})
|
||||
|
||||
def clear_abort_sets(req_id):
|
||||
eng.resource_manager.waiting_abort_req_id_set.discard(req_id)
|
||||
|
||||
eng.resource_manager.add_abort_req_ids = MagicMock(side_effect=clear_abort_sets)
|
||||
|
||||
with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1):
|
||||
result = eng._control_abort_requests(control_req)
|
||||
|
||||
self.assertEqual(result["aborted"][0]["output_token_count"], 5)
|
||||
self._detach_finalizer(eng)
|
||||
|
||||
def test_wait_abort_complete_immediate(self):
|
||||
"""_wait_abort_complete returns immediately when all requests already cleaned."""
|
||||
eng = self._make_abort_engine()
|
||||
# Empty abort sets → remaining is empty → returns immediately
|
||||
eng._wait_abort_complete(["req-1_0"])
|
||||
self._detach_finalizer(eng)
|
||||
|
||||
def test_wait_abort_complete_progress(self):
|
||||
"""_wait_abort_complete exits when background thread cleans up."""
|
||||
eng = self._make_abort_engine()
|
||||
eng.resource_manager.waiting_abort_req_id_set = {"req-1_0"}
|
||||
|
||||
call_count = [0]
|
||||
|
||||
def fake_sleep(s):
|
||||
call_count[0] += 1
|
||||
# Simulate background thread cleaning up after first sleep
|
||||
eng.resource_manager.waiting_abort_req_id_set.discard("req-1_0")
|
||||
|
||||
with patch("fastdeploy.engine.common_engine.time.sleep", fake_sleep):
|
||||
eng._wait_abort_complete(["req-1_0"])
|
||||
|
||||
self.assertGreaterEqual(call_count[0], 1)
|
||||
self._detach_finalizer(eng)
|
||||
|
||||
def test_wait_abort_complete_force_cleanup_stuck_in_to_be_aborted(self):
|
||||
"""Stall timeout triggers force cleanup for requests in to_be_aborted_req_id_set."""
|
||||
eng = self._make_abort_engine()
|
||||
eng.resource_manager.to_be_aborted_req_id_set = {"req-1_0"}
|
||||
|
||||
def mock_recycle(req_id):
|
||||
eng.resource_manager.to_be_aborted_req_id_set.discard(req_id)
|
||||
|
||||
eng.resource_manager.recycle_abort_task = MagicMock(side_effect=mock_recycle)
|
||||
|
||||
# Make time.time() advance past stall_timeout
|
||||
time_values = [100.0, 100.0, 102.0, 102.0, 102.0]
|
||||
time_idx = [0]
|
||||
|
||||
def fake_time():
|
||||
idx = min(time_idx[0], len(time_values) - 1)
|
||||
time_idx[0] += 1
|
||||
return time_values[idx]
|
||||
|
||||
with (
|
||||
patch("fastdeploy.engine.common_engine.time.time", fake_time),
|
||||
patch("fastdeploy.engine.common_engine.time.sleep", lambda s: None),
|
||||
):
|
||||
eng._wait_abort_complete(["req-1_0"], stall_timeout=1)
|
||||
|
||||
eng.resource_manager.recycle_abort_task.assert_called_with("req-1_0")
|
||||
self._detach_finalizer(eng)
|
||||
|
||||
@@ -809,3 +809,80 @@ def test_config_info():
|
||||
api_server = _reload_api_server(args)
|
||||
api_server.llm_engine = None
|
||||
assert api_server.config_info().status_code == 500
|
||||
|
||||
|
||||
# ── /v1/abort_requests ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_abort_control_response(api_server, result, status_code=200):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.to_api_json_response.return_value = api_server.JSONResponse(
|
||||
content={"request_id": "control-test", "status": "success", "error_message": None, "result": result},
|
||||
status_code=status_code,
|
||||
)
|
||||
api_server.app.state.engine_client = MagicMock()
|
||||
api_server.app.state.engine_client.run_control_method = AsyncMock(return_value=mock_resp)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abort_requests_with_req_ids():
|
||||
args = _build_args()
|
||||
api_server = _reload_api_server(args)
|
||||
_mock_abort_control_response(
|
||||
api_server,
|
||||
{
|
||||
"aborted": [{"request_id": "req-1_0", "output_token_count": 10}],
|
||||
"not_found": ["req-999"],
|
||||
},
|
||||
)
|
||||
req = MagicMock()
|
||||
req.json = AsyncMock(return_value={"req_ids": ["req-1", "req-999"]})
|
||||
resp = await api_server.abort_requests(req)
|
||||
assert resp.status_code == 200
|
||||
control_req = api_server.app.state.engine_client.run_control_method.await_args.args[0]
|
||||
assert control_req.method == "abort_requests"
|
||||
assert control_req.args["req_ids"] == ["req-1", "req-999"]
|
||||
assert control_req.args["abort_all"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abort_requests_with_abort_all():
|
||||
args = _build_args()
|
||||
api_server = _reload_api_server(args)
|
||||
_mock_abort_control_response(
|
||||
api_server,
|
||||
{
|
||||
"aborted": [
|
||||
{"request_id": "req-1_0", "output_token_count": 5},
|
||||
{"request_id": "req-2_0", "output_token_count": 12},
|
||||
],
|
||||
"not_found": [],
|
||||
},
|
||||
)
|
||||
req = MagicMock()
|
||||
req.json = AsyncMock(return_value={"abort_all": True})
|
||||
resp = await api_server.abort_requests(req)
|
||||
assert resp.status_code == 200
|
||||
control_req = api_server.app.state.engine_client.run_control_method.await_args.args[0]
|
||||
assert control_req.args["abort_all"] is True
|
||||
assert control_req.args["req_ids"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abort_requests_missing_params():
|
||||
args = _build_args()
|
||||
api_server = _reload_api_server(args)
|
||||
req = MagicMock()
|
||||
req.json = AsyncMock(return_value={})
|
||||
resp = await api_server.abort_requests(req)
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abort_requests_empty_req_ids():
|
||||
args = _build_args()
|
||||
api_server = _reload_api_server(args)
|
||||
req = MagicMock()
|
||||
req.json = AsyncMock(return_value={"req_ids": []})
|
||||
resp = await api_server.abort_requests(req)
|
||||
assert resp.status_code == 400
|
||||
|
||||
+166
-1
@@ -22,7 +22,7 @@ Why mock:
|
||||
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from fastdeploy.router.router import Router, RouterArgs
|
||||
|
||||
@@ -144,5 +144,170 @@ class TestRouterRegisteredNumber(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(result["decode"], 0)
|
||||
|
||||
|
||||
class TestRouterAbortRequests(unittest.IsolatedAsyncioTestCase):
|
||||
"""Tests for /v1/abort_requests route in router.py."""
|
||||
|
||||
def _make_mock_session(self, responses):
|
||||
"""Create a mock aiohttp.ClientSession where post() returns coroutines."""
|
||||
mock_session = MagicMock()
|
||||
call_count = 0
|
||||
|
||||
def post_side_effect(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
resp = responses[call_count]
|
||||
call_count += 1
|
||||
if isinstance(resp, Exception):
|
||||
raise resp
|
||||
|
||||
async def _coro():
|
||||
return resp
|
||||
|
||||
return _coro()
|
||||
|
||||
mock_session.post = MagicMock(side_effect=post_side_effect)
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=False)
|
||||
return mock_session
|
||||
|
||||
@patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=True)
|
||||
async def test_abort_broadcasts_to_all_but_returns_decode_only(self, mock_health):
|
||||
"""P and D both receive the request, but only D results are aggregated."""
|
||||
from fastdeploy.router.router import abort_requests as abort_fn
|
||||
from fastdeploy.router.router import app
|
||||
|
||||
router = Router(_make_args(splitwise=True))
|
||||
await router.register_instance(_make_instance_dict(host_ip="10.0.0.1", port=8001, role="prefill"))
|
||||
await router.register_instance(_make_instance_dict(host_ip="10.0.0.2", port=8002, role="decode"))
|
||||
app.state.router = router
|
||||
|
||||
prefill_resp = AsyncMock()
|
||||
prefill_resp.status = 200
|
||||
prefill_resp.json = AsyncMock(
|
||||
return_value={
|
||||
"request_id": "control-p",
|
||||
"status": "success",
|
||||
"error_message": None,
|
||||
"result": {"aborted": [{"request_id": "req-1_0", "output_token_count": 0}], "not_found": []},
|
||||
}
|
||||
)
|
||||
decode_resp = AsyncMock()
|
||||
decode_resp.status = 200
|
||||
decode_resp.json = AsyncMock(
|
||||
return_value={
|
||||
"request_id": "control-d",
|
||||
"status": "success",
|
||||
"error_message": None,
|
||||
"result": {"aborted": [{"request_id": "req-1_0", "output_token_count": 15}], "not_found": []},
|
||||
}
|
||||
)
|
||||
|
||||
mock_session = self._make_mock_session([prefill_resp, decode_resp])
|
||||
mock_request = AsyncMock()
|
||||
mock_request.json = AsyncMock(return_value={"req_ids": ["req-1"]})
|
||||
|
||||
with patch("fastdeploy.router.router.aiohttp.ClientSession", return_value=mock_session):
|
||||
resp = await abort_fn(mock_request)
|
||||
|
||||
import json
|
||||
|
||||
body = json.loads(resp.body)
|
||||
self.assertEqual(len(body["result"]["aborted"]), 1)
|
||||
self.assertEqual(body["result"]["aborted"][0]["output_token_count"], 15)
|
||||
self.assertEqual(body["status"], "success")
|
||||
self.assertEqual(mock_session.post.call_count, 2)
|
||||
|
||||
@patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=True)
|
||||
async def test_abort_decode_error_returns_error_status(self, mock_health):
|
||||
"""When D node returns a non-200 status, status should be 'error'."""
|
||||
from fastdeploy.router.router import abort_requests as abort_fn
|
||||
from fastdeploy.router.router import app
|
||||
|
||||
router = Router(_make_args(splitwise=True))
|
||||
await router.register_instance(_make_instance_dict(host_ip="10.0.0.1", port=8001, role="prefill"))
|
||||
await router.register_instance(_make_instance_dict(host_ip="10.0.0.2", port=8002, role="decode"))
|
||||
app.state.router = router
|
||||
|
||||
prefill_resp = AsyncMock()
|
||||
prefill_resp.status = 200
|
||||
prefill_resp.json = AsyncMock(
|
||||
return_value={
|
||||
"request_id": "control-p",
|
||||
"status": "success",
|
||||
"error_message": None,
|
||||
"result": {"aborted": [], "not_found": []},
|
||||
}
|
||||
)
|
||||
decode_resp = AsyncMock()
|
||||
decode_resp.status = 500
|
||||
|
||||
mock_session = self._make_mock_session([prefill_resp, decode_resp])
|
||||
mock_request = AsyncMock()
|
||||
mock_request.json = AsyncMock(return_value={"abort_all": True})
|
||||
|
||||
with patch("fastdeploy.router.router.aiohttp.ClientSession", return_value=mock_session):
|
||||
resp = await abort_fn(mock_request)
|
||||
|
||||
import json
|
||||
|
||||
body = json.loads(resp.body)
|
||||
self.assertEqual(body["status"], "error")
|
||||
self.assertIsNotNone(body["error_message"])
|
||||
|
||||
@patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=True)
|
||||
async def test_abort_decode_exception_returns_error(self, mock_health):
|
||||
"""When D node connection fails (exception), error should be captured."""
|
||||
from fastdeploy.router.router import abort_requests as abort_fn
|
||||
from fastdeploy.router.router import app
|
||||
|
||||
router = Router(_make_args(splitwise=True))
|
||||
await router.register_instance(_make_instance_dict(host_ip="10.0.0.1", port=8001, role="prefill"))
|
||||
await router.register_instance(_make_instance_dict(host_ip="10.0.0.2", port=8002, role="decode"))
|
||||
app.state.router = router
|
||||
|
||||
prefill_resp = AsyncMock()
|
||||
prefill_resp.status = 200
|
||||
prefill_resp.json = AsyncMock(
|
||||
return_value={
|
||||
"request_id": "control-p",
|
||||
"status": "success",
|
||||
"error_message": None,
|
||||
"result": {"aborted": [], "not_found": []},
|
||||
}
|
||||
)
|
||||
|
||||
# D node raises exception — but asyncio.gather(return_exceptions=True) captures it
|
||||
# So we pass the exception as a response directly
|
||||
mock_session = self._make_mock_session([prefill_resp, prefill_resp]) # placeholder
|
||||
call_idx = [0]
|
||||
|
||||
def post_with_exception(*args, **kwargs):
|
||||
call_idx[0] += 1
|
||||
if call_idx[0] == 1:
|
||||
# prefill: normal
|
||||
async def _coro():
|
||||
return prefill_resp
|
||||
|
||||
return _coro()
|
||||
else:
|
||||
# decode: raise (gather with return_exceptions=True will catch)
|
||||
async def _coro_err():
|
||||
raise ConnectionError("refused")
|
||||
|
||||
return _coro_err()
|
||||
|
||||
mock_session.post = MagicMock(side_effect=post_with_exception)
|
||||
mock_request = AsyncMock()
|
||||
mock_request.json = AsyncMock(return_value={"abort_all": True})
|
||||
|
||||
with patch("fastdeploy.router.router.aiohttp.ClientSession", return_value=mock_session):
|
||||
resp = await abort_fn(mock_request)
|
||||
|
||||
import json
|
||||
|
||||
body = json.loads(resp.body)
|
||||
self.assertEqual(body["status"], "error")
|
||||
self.assertIn("refused", body["error_message"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user