mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
abort requests (#6992)
This commit is contained in:
@@ -43,9 +43,11 @@ from fastdeploy.cache_manager.cache_data import CacheStatus
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.engine.register_manager import RegisterManager
|
||||
from fastdeploy.engine.request import (
|
||||
CompletionOutput,
|
||||
ControlRequest,
|
||||
ControlResponse,
|
||||
Request,
|
||||
RequestMetrics,
|
||||
RequestOutput,
|
||||
RequestStatus,
|
||||
RequestType,
|
||||
@@ -1500,6 +1502,139 @@ class EngineService:
|
||||
|
||||
return responses
|
||||
|
||||
def _control_abort_requests(self, control_req: ControlRequest):
    """
    Handle an "abort requests" control command.

    Args:
        control_req: ControlRequest whose args may contain:
            - abort_all (bool): when True, abort every request currently
              known to the resource manager or scheduler.
            - req_ids (list): explicit request ids to abort; each id may
              also match its internally-suffixed "<id>_0" form.

    Returns:
        dict: {"aborted": [{"request_id": ..., "output_token_count": ...}, ...],
               "not_found": [input ids that matched no live request]}
        ("not_found" is always [] when abort_all is True.)

    Raises:
        Exception: if the v1 KV-cache scheduler is not enabled.
    """
    if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
        raise Exception("abort_requests only supported in ENABLE_V1_KVCACHE_SCHEDULER")
    args = control_req.get_args()
    abort_all = args.get("abort_all", False)
    req_ids = args.get("req_ids", [])
    matched_input_ids = set()
    # Use a set: membership is probed once per requested id below, and the
    # original list made each probe O(n).
    now_reqs = set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys())

    # Step 1: Determine target request list
    if abort_all:
        # all requests in running + waiting
        target_req_ids = list(now_reqs)
    else:
        # filter out requests that actually exist
        target_req_ids = []
        for rid in req_ids:
            if rid in now_reqs:
                target_req_ids.append(rid)
                matched_input_ids.add(rid)
            elif f"{rid}_0" in now_reqs:
                # ids may have been suffixed with "_0" internally
                target_req_ids.append(f"{rid}_0")
                matched_input_ids.add(rid)

    if not target_req_ids:
        return {"aborted": [], "not_found": req_ids if not abort_all else []}

    # Step 2: Collect partial results
    aborted_info = []
    results = []
    for req_id in target_req_ids:
        request = self.resource_manager.requests.get(req_id)
        if request is None:
            scheduled_req = self.scheduler.requests.get(req_id)
            if scheduled_req is None:
                # raced away between the snapshot above and now
                continue
            request = scheduled_req.raw

        partial_token_ids = list(request.output_token_ids)

        # Construct finished response with partial results
        now = time.time()
        abort_metrics = RequestMetrics(
            arrival_time=request.metrics.arrival_time if request.metrics else now,
            inference_start_time=request.metrics.inference_start_time if request.metrics else now,
            engine_recv_latest_token_time=now,
            engine_recv_first_token_time=request.metrics.engine_recv_first_token_time if request.metrics else now,
            request_start_time=request.metrics.arrival_time if request.metrics else now,
        )
        result = RequestOutput(
            request_id=req_id,
            finished=True,
            outputs=CompletionOutput(
                index=0,
                send_idx=len(partial_token_ids),
                # emit an EOS so the client-side stream terminates cleanly
                token_ids=[self.data_processor.eos_token_ids[0]],
            ),
            metrics=abort_metrics,
            error_code=200,
            error_msg="Aborted",
        )
        results.append(result)
        aborted_info.append(
            {
                "request_id": req_id,
                "output_token_count": len(partial_token_ids),
            }
        )

    # Step 3: Execute abort — add all requests to waiting_abort_req_id_set.
    # NOTE: the original re-checked envs.ENABLE_V1_KVCACHE_SCHEDULER here,
    # but the guard at the top already guarantees it is enabled.
    for req_id in target_req_ids:
        self.resource_manager.add_abort_req_ids(req_id)
        time.sleep(0.0001)  # brief stagger between aborts — TODO confirm placement vs. original intent
    if self.cfg.scheduler_config.splitwise_role != "prefill":
        self._wait_abort_complete(target_req_ids)

    # Add results to scheduler, engine will have a thread calling get_results,
    # then cleanup and call send_response to send to client.
    # When client disconnects, send_response will automatically ignore
    if self.cfg.scheduler_config.splitwise_role != "prefill":
        try:
            self.scheduler.put_results(results)
        except Exception as e:
            # Best effort: the client may already have disconnected.
            self.llm_logger.warning(f"failed to publish abort results: {e}")

    not_found = [rid for rid in req_ids if rid not in matched_input_ids] if not abort_all else []

    return {"aborted": aborted_info, "not_found": not_found}
|
||||
|
||||
def _wait_abort_complete(self, target_req_ids, stall_timeout=1):
|
||||
"""
|
||||
Wait for all abort requests to complete.
|
||||
- Keep monitoring as long as remaining is not empty, which means cleanup is not done yet
|
||||
- If no progress within stall_timeout seconds, force cleanup requests stuck in to_be_aborted_req_id_set,
|
||||
reset progress state if any, then continue monitoring
|
||||
"""
|
||||
target_set = set(target_req_ids)
|
||||
prev_remaining_count = len(target_set)
|
||||
last_progress_time = time.time()
|
||||
remaining = target_set & self.resource_manager.get_reqs_in_aborting()
|
||||
while remaining:
|
||||
remaining = target_set & self.resource_manager.get_reqs_in_aborting()
|
||||
if not remaining:
|
||||
self.llm_logger.info(f"all {len(target_set)} abort reqs cleaned")
|
||||
return
|
||||
|
||||
current_count = len(remaining)
|
||||
if current_count < prev_remaining_count:
|
||||
# progress made: recycle_abort_task was called
|
||||
self.llm_logger.info(f"abort progress: {prev_remaining_count} -> {current_count}")
|
||||
last_progress_time = time.time()
|
||||
prev_remaining_count = current_count
|
||||
|
||||
if time.time() - last_progress_time > stall_timeout:
|
||||
# no progress timeout: only cleanup requests stuck in to_be_aborted (worker hasn't returned -9)
|
||||
stuck = remaining & self.resource_manager.to_be_aborted_req_id_set
|
||||
if stuck:
|
||||
self.llm_logger.warning(
|
||||
f"no abort progress for {stall_timeout}s, "
|
||||
f"force cleanup {len(stuck)} stuck requests (in to_be_aborted)"
|
||||
)
|
||||
for req_id in list(stuck):
|
||||
self.llm_logger.warning(f"force cleanup stuck req_id:{req_id}")
|
||||
self.resource_manager.recycle_abort_task(req_id)
|
||||
# reset progress state
|
||||
last_progress_time = time.time()
|
||||
prev_remaining_count = current_count - len(stuck)
|
||||
# else: remaining are all in waiting_abort_req_id_set, waiting for natural flow
|
||||
|
||||
time.sleep(0.005)
|
||||
|
||||
def _parse_tags(self, control_request: ControlRequest):
|
||||
"""
|
||||
Parse tags from control request.
|
||||
|
||||
Reference in New Issue
Block a user