mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
abort requests (#6992)
This commit is contained in:
@@ -17,8 +17,8 @@ from uuid import uuid4
|
||||
|
||||
import aiohttp
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, ORJSONResponse, Response, StreamingResponse
|
||||
|
||||
from fastdeploy.router.utils import (
|
||||
InstanceInfo,
|
||||
@@ -503,6 +503,48 @@ async def health_generate():
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.post("/v1/abort_requests")
|
||||
async def abort_requests(request: Request):
|
||||
body = await request.json()
|
||||
prefill_servers = app.state.router.prefill_servers
|
||||
decode_servers = app.state.router.decode_servers
|
||||
all_servers = prefill_servers + decode_servers
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
tasks = [session.post(f"{server.url()}/v1/abort_requests", json=body) for server in all_servers]
|
||||
responses = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Aggregate results from Node D only
|
||||
all_aborted = []
|
||||
all_not_found = []
|
||||
errors = []
|
||||
decode_start = len(prefill_servers)
|
||||
for i, (server, resp) in enumerate(zip(all_servers, responses)):
|
||||
if i < decode_start:
|
||||
continue
|
||||
if isinstance(resp, Exception):
|
||||
errors.append({"server": server.url(), "error": str(resp)})
|
||||
elif resp.status == 200:
|
||||
data = await resp.json()
|
||||
result = data.get("result") or {}
|
||||
all_aborted.extend(result.get("aborted", []))
|
||||
all_not_found.extend(result.get("not_found", []))
|
||||
else:
|
||||
errors.append({"server": server.url(), "status": resp.status})
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
"request_id": f"router-{uuid4()}",
|
||||
"status": "success" if not errors else "error",
|
||||
"error_message": None if not errors else str(errors),
|
||||
"result": {
|
||||
"aborted": all_aborted,
|
||||
"not_found": list(set(all_not_found)),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def launch_router(router_args: RouterArgs):
|
||||
app.state.router_args = router_args
|
||||
print(f"Starting router with args: {router_args}")
|
||||
|
||||
Reference in New Issue
Block a user