[Cherry-Pick] [Feature] support v1 update/clear api for RL (#6761) (#6974)

* [Feature] support v1 update/clear api for RL

* [fix] fix stale control responses when control method timed out

* [chore] remove unused code

* [chore] optimize tags and key_prefix

* [test] fix ci

* [chore] fix code style

* [fix] fix ep control

* [fix] fix ep control for engine cache queue
Yonghua Li committed 2026-03-25 19:18:35 +08:00 (committed by GitHub)
parent 49c2310854 · commit 35034f91fa
25 changed files with 1665 additions and 328 deletions
+252 -48
@@ -17,6 +17,7 @@
from __future__ import annotations
import asyncio
import collections
import copy
import json
import multiprocessing
@@ -39,6 +40,8 @@ import zmq
from tqdm import tqdm
import fastdeploy.metrics.trace as tracing
from fastdeploy.cache_manager.cache_data import CacheStatus
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import (
ControlRequest,
ControlResponse,
@@ -83,7 +86,7 @@ class EngineService:
Base class containing common engine functionality
"""
def __init__(self, cfg, start_queue=True, use_async_llm=False):
def __init__(self, cfg: FDConfig, start_queue=True, use_async_llm=False):
"""
Initializes the LLMEngine with the provided configuration.
@@ -103,14 +106,23 @@ class EngineService:
self.is_paused = False # pause request generation
self._pause_cond = threading.Condition()
self._ctrl_worker_output_queues = []
self._ctrl_output_queues = {}
self._ctrl_response_mailboxes = collections.defaultdict(collections.OrderedDict)
tp_size = cfg.parallel_config.tensor_parallel_size
dp_index = cfg.parallel_config.local_data_parallel_id
for rank in range(tp_size):
for tp_rank in range(tp_size):
# create worker control response queue
engine_worker_queue_port = self.cfg.parallel_config.local_engine_worker_queue_port
name = f"ctrl_w2e_rank{rank+tp_size*dp_index}_{engine_worker_queue_port}"
self.llm_logger.info(f"Init Worker Control Output Queue: {name}(consumer)")
self._ctrl_worker_output_queues.append(FMQ().queue(name, "consumer"))
name = f"ctrl_w2e_rank{tp_rank+tp_size*dp_index}_{engine_worker_queue_port}"
self.llm_logger.info(f"Init Worker Control Output Queue: {name} (consumer)")
self._ctrl_output_queues[name] = FMQ().queue(name, "consumer")
# create cache control response queue
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
engine_cache_queue_port = self.cfg.cache_config.local_cache_queue_port
name = f"ctrl_c2e_rank{tp_rank+tp_size*dp_index}_{engine_cache_queue_port}"
self.llm_logger.info(f"Init Cache Control Output Queue: {name} (consumer)")
self._ctrl_output_queues[name] = FMQ().queue(name, "consumer")
self.scheduler = cfg.scheduler_config.scheduler()
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
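For orientation, each control queue name above encodes the global rank (tp_rank + tp_size * dp_index) and the transport port. A minimal sketch of that derivation, assuming only the naming scheme visible in this hunk (the helper function itself is hypothetical, and the ports in the example are made up):

# Hypothetical helper mirroring the queue-naming scheme in __init__ above.
def control_queue_names(tp_size, dp_index, worker_port, cache_port=None):
    names = []
    for tp_rank in range(tp_size):
        global_rank = tp_rank + tp_size * dp_index
        # worker-to-engine (w2e) control response queue, one per TP rank
        names.append(f"ctrl_w2e_rank{global_rank}_{worker_port}")
        # cache-to-engine (c2e) queue, only when a cache transfer process exists
        if cache_port is not None:
            names.append(f"ctrl_c2e_rank{global_rank}_{cache_port}")
    return names

# tp_size=2, dp_index=1 -> global ranks 2 and 3:
# ['ctrl_w2e_rank2_8201', 'ctrl_c2e_rank2_8202', 'ctrl_w2e_rank3_8201', 'ctrl_c2e_rank3_8202']
print(control_queue_names(2, 1, 8201, 8202))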
@@ -1248,7 +1260,7 @@ class EngineService:
request_id = control_req.request_id
try:
self.llm_logger.info(f"START run control method {request_id}: {method}")
self.llm_logger.info(f"Start to run control method {method}: {request_id}")
handler_name = f"_control_{method}"
handler = getattr(self, handler_name, None)
@@ -1259,12 +1271,12 @@ class EngineService:
return
result = handler(control_req)
self.llm_logger.info(f"SUCCESS run control method {method}.")
self.llm_logger.info(f"Successfully run control method {method}: {request_id} {result}")
succ_result = ControlResponse(request_id, 200, "Success", result)
self.send_response_server.send_response(request_id, [succ_result])
except Exception as e:
error_msg = f"Failed run control method {method}: {str(e)}"
error_msg = f"Failed to run control method {method}: {request_id} {str(e)}"
self.llm_logger.error(f"{error_msg}\n{traceback.format_exc()}")
error_result = ControlResponse(request_id, 500, error_msg)
self.send_response_server.send_response(request_id, [error_result])
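The hunk above resolves handlers by naming convention: method "pause" maps to _control_pause via getattr, and the outcome is wrapped in a ControlResponse whose positional layout (request_id, code, message, result) is visible in the diff. A standalone reduction of that pattern, with plain tuples standing in for ControlResponse and an assumed 400 code for unknown methods (the diff elides that branch):

# Illustrative reduction of the getattr-based dispatch in _run_control.
class MiniControlService:
    def _control_ping(self, control_req):
        return {"pong": True}

    def run_control(self, method, request_id, control_req):
        handler = getattr(self, f"_control_{method}", None)
        if handler is None:
            # error code for unknown methods is an assumption
            return (request_id, 400, f"Unsupported control method: {method}", None)
        try:
            result = handler(control_req)
            return (request_id, 200, "Success", result)
        except Exception as e:
            return (request_id, 500, f"Failed to run control method {method}: {e}", None)

svc = MiniControlService()
print(svc.run_control("ping", "req-1", None))  # ('req-1', 200, 'Success', {'pong': True})
print(svc.run_control("nope", "req-1", None))  # ('req-1', 400, ...)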
@@ -1287,12 +1299,15 @@ class EngineService:
if self.cfg.scheduler_config.name != "local":
raise Exception(f"pause only supported in local scheduler, current {self.cfg.scheduler_config.name}")
self.llm_logger.info("Start to pause request generation.")
with self._pause_cond:
if self.is_paused:
self.llm_logger.info("Pause Request Generation: already paused.")
self.llm_logger.info("Engine is already paused, no need to pause again.")
return
self.is_paused = True
self.llm_logger.info("Start Abort Running Requests")
self.llm_logger.info("Abort running requests.")
self.resource_manager.log_status()
# Preempt all running reqs; preempted reqs will be appended to the ResourceManager.waiting queue
@@ -1303,7 +1318,7 @@ class EngineService:
if count >= timeout * 1000:
break
if count >= timeout * 1000:
error_msg = f"wait engine_worker_queue tasks empty timeout after {timeout} seconds, worker may Hanged"
error_msg = f"Emptying engine worker queue timed out after {timeout} seconds, worker may hanged!"
self.llm_logger.error(error_msg)
raise Exception(error_msg)
running_reqs = self.resource_manager.preempted_all()
@@ -1318,12 +1333,22 @@ class EngineService:
# abort inflight requests to user
inflight_requests = self.scheduler.get_inflight_requests()
self.llm_logger.info(f"Start Abort Inflight Requests, total {len(inflight_requests)} waiting requests")
self.llm_logger.info(f"Abort inflight requests (total {len(inflight_requests)}).")
for req in inflight_requests:
self._send_error_response(req.request_id, "Request is aborted since LLM Engine is paused.")
self._send_error_response(req.request_id, "Request is aborted since engine is paused.")
self.scheduler.reset()
# pause cache transfer
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
self.llm_logger.info("Start to pause cache transfer.")
pause_transfer_request = ControlRequest(request_id="pause_transfer", method="pause")
self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, pause_transfer_request))
# Wait for cache_transfer responses
asyncio.run(self._wait_for_control_responses("pause_transfer", 60, executors=["cache_transfer"]))
self.llm_logger.info("Successfully paused cache transfer.")
self.resource_manager.cache_manager.reset()
self.llm_logger.info("Successfully paused request generation.")
return None
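Before preempting, the pause path above polls the engine worker queue until it drains, counting elapsed milliseconds against the timeout. A reduced sketch of that drain loop, where the empty() check and the fixed poll interval are assumptions about the queue API:

import time

def drain_worker_queue(queue, timeout_s, poll_ms=10):
    # Poll until the queue is empty; fail loudly once the budget is spent.
    # queue.empty() and the 10ms interval are assumptions, not the real API.
    waited_ms = 0
    while not queue.empty():
        time.sleep(poll_ms / 1000.0)
        waited_ms += poll_ms
        if waited_ms >= timeout_s * 1000:
            raise TimeoutError(
                f"Emptying engine worker queue timed out after {timeout_s} seconds"
            )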
def _control_resume(self, control_request: ControlRequest) -> Optional[dict]:
@@ -1335,14 +1360,24 @@ class EngineService:
Args:
control_request: Control request object containing resume operation information
"""
self.llm_logger.info("START Resume Request Generation")
self.llm_logger.info("Start to resume request generation.")
with self._pause_cond:
if not self.is_paused:
self.llm_logger.info("Resume Request Generation: not paused.")
self.llm_logger.info("Engine is not paused, no need to resume.")
return None
self.is_paused = False
self._pause_cond.notify_all()
self.llm_logger.info("END Resume Request Generation")
# resume cache transfer
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
self.llm_logger.info("Start to resume cache transfer.")
resume_transfer_request = ControlRequest(request_id="resume_transfer", method="resume")
self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, resume_transfer_request))
# Wait for cache_transfer responses
asyncio.run(self._wait_for_control_responses("resume_transfer", 60, executors=["cache_transfer"]))
self.llm_logger.info("Successfully resumed cache transfer.")
self.llm_logger.info("Successfully resumed request generation.")
return None
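Pause and resume both hinge on the threading.Condition created in __init__: pause flips is_paused under the lock, and resume clears it and calls notify_all() to wake anything blocked on the flag. A self-contained sketch of that gate; the wait_if_paused() caller is illustrative (the real engine guards its request-generation loop):

import threading

class PauseGate:
    def __init__(self):
        self.is_paused = False
        self._pause_cond = threading.Condition()

    def pause(self):
        with self._pause_cond:
            self.is_paused = True  # loops checking the flag will block

    def resume(self):
        with self._pause_cond:
            self.is_paused = False
            self._pause_cond.notify_all()  # wake every blocked loop

    def wait_if_paused(self):
        # Called from the generation loop; blocks while the engine is paused.
        with self._pause_cond:
            while self.is_paused:
                self._pause_cond.wait()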
def _control_is_paused(self, control_request: ControlRequest) -> bool:
@@ -1378,49 +1413,218 @@ class EngineService:
raise Exception(error_msg)
return self._call_worker(control_request, 60)
async def _wait_all_control_responses(self, request_id: str, timeout: int):
"""Wait for control responses from all workers with a global timeout.
This method concurrently waits for responses from all control workers
and enforces an overall timeout to avoid leaking pending tasks.
def _parse_tags(self, control_request: ControlRequest):
"""
timeout_ms = timeout * 1000
# Create one get() coroutine per worker output queue
tasks = [output_queue.get(timeout=timeout_ms) for output_queue in self._ctrl_worker_output_queues]
try:
results = await asyncio.wait_for(
asyncio.gather(*tasks, return_exceptions=True),
timeout=timeout,
Parse tags from control request.
"""
allowed_tags = ["weight", "kv_cache"]
tags = control_request.args.get("tags", None)
if tags is None:
tags = ",".join(allowed_tags)
control_request.args["tags"] = tags
self.llm_logger.info(
f"Detected empty tags of request {control_request.request_id}, defaulting to tags: {tags}"
)
except asyncio.TimeoutError:
# Keep the error message consistent with previous behavior
raise Exception("Worker Update Weights Timeouted after 600s")
elif isinstance(tags, list):
tags = ",".join(tags)
for tag in tags.split(","):
if tag not in allowed_tags:
raise ValueError(f"Unsupported tag [{tag}] in [{tags}], expected one of {allowed_tags}")
return tags
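_parse_tags accepts tags as missing, a list, or a comma-separated string, and always normalizes to a validated comma-separated string. The same normalization as a standalone function, with the allowed tags taken from this hunk:

ALLOWED_TAGS = ["weight", "kv_cache"]

def parse_tags(tags):
    if tags is None:                 # empty -> default to every allowed tag
        tags = ",".join(ALLOWED_TAGS)
    elif isinstance(tags, list):     # list -> comma-separated string
        tags = ",".join(tags)
    for tag in tags.split(","):      # validate each tag against the whitelist
        if tag not in ALLOWED_TAGS:
            raise ValueError(f"Unsupported tag [{tag}] in [{tags}], expected one of {ALLOWED_TAGS}")
    return tags

assert parse_tags(None) == "weight,kv_cache"
assert parse_tags(["weight"]) == "weight"
assert parse_tags("weight,kv_cache") == "weight,kv_cache"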
def _control_sleep(self, control_request: ControlRequest):
"""
Offload gpu memory occupation for certain parts, e.g. weight, cache.
Args:
control_request: Control request object containing parameters for offloading memory
tags: list of tags to offload, supported values: ["weight", "kv_cache"]
TODO: support different levels of offloading, to provide the option of releasing
memory permanently instead of merely offloading to CPU memory as is done now.
"""
# Args check
tags = self._parse_tags(control_request)
control_request.args["tags"] = tags
# Make sure llm engine is paused.
self.llm_logger.warning(
"Implicitly pause LLM engine before sleeping. This behavior will be deprecated in future versions. "
"Please explicitly request to /pause the engine before /sleep."
)
self._control_pause(None)
# Determine which executors are needed for the sleep command
executors = set()
if "weight" in tags:
executors.add("worker")
if "kv_cache" in tags:
executors.add("worker")
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
executors.add("cache_transfer")
if self.cfg.cache_config.enable_prefix_caching:
self.resource_manager.cache_manager.reset()
# Dispatch sleep request to executors
self.llm_logger.info(f"Dispatch sleep request to executors: {list(executors)}")
self._dispatch_control_request(control_request, executors)
return asyncio.run(self._wait_for_control_responses(control_request.request_id, 60, executors=executors))
def _control_wakeup(self, control_request: ControlRequest):
"""
Reload offloaded gpu memory occupation for certain parts, e.g. weight, cache.
Args:
control_request: Control request object containing parameters for reloading memory
tags: list of tags to reload, supported values: ["weight", "kv_cache"]
"""
# Args check
tags = self._parse_tags(control_request)
control_request.args["tags"] = tags
# Determine which executors are needed for the wakeup command
executors = set()
if "weight" in tags:
executors.add("worker")
if "kv_cache" in tags:
executors.add("worker")
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
executors.add("cache_transfer")
# Dispatch wakeup request to executors
self.llm_logger.info(f"Dispatch wakeup request to executors: {list(executors)}")
self._dispatch_control_request(control_request, executors)
result = asyncio.run(self._wait_for_control_responses(control_request.request_id, 300, executors=executors))
# Resume the engine after wakeup
self._control_resume(None)
return result
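Both sleep and wakeup route each tag to the executor set that owns it: weight and kv_cache both need the worker, and kv_cache additionally needs the cache_transfer process when CPU blocks or a KV storage backend is configured. A standalone sketch of that routing (the helper name is illustrative):

def select_executors(tags, has_cache_transfer):
    # weight lives in the worker; kv_cache lives in the worker and,
    # when a cache transfer process exists, in cache_transfer as well.
    executors = set()
    tag_list = tags.split(",")
    if "weight" in tag_list:
        executors.add("worker")
    if "kv_cache" in tag_list:
        executors.add("worker")
        if has_cache_transfer:
            executors.add("cache_transfer")
    return executors

assert select_executors("weight", False) == {"worker"}
assert select_executors("kv_cache", True) == {"worker", "cache_transfer"}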
def _dispatch_control_request(self, control_request: ControlRequest, executors: List[str]):
"""
Dispatch control requests to workers, cache managers or engine itself.
Args:
control_request: ControlRequest
executors: list of executor names to dispatch to, e.g. ["worker", "cache_transfer"]
"""
if "worker" in executors:
self.engine_worker_queue.put_tasks(([control_request], 1))
if "cache_transfer" in executors:
if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend:
self.cache_task_queue.put_transfer_task((CacheStatus.CTRL, control_request))
return
async def _wait_for_control_responses(self, request_id: str, timeout: int, executors: List[str] = None):
"""Wait for matching control responses from the selected executor queues.
This helper selects the control-response queues that belong to the requested
executors, then waits for all of them concurrently. Each queue gets a local
waiter that keeps reading until it sees the target request ID and stashes stale
responses into that queue's mailbox.
Args:
request_id: The control request ID that all returned responses must match.
timeout: Global timeout budget in seconds for the full multi-queue wait.
executors: Executor groups to wait for, for example `["worker"]` or
`["worker", "cache_transfer"]`. If `None`, waits for all control
response queues.
Returns:
A list of `response.result` values collected from all matched
`ControlResponse` objects. If no queue is selected, returns `None`.
Raises:
Exception: If the overall wait times out, or if any queue reports a non-200
control response or fails while waiting.
"""
def select_control_queues(executors: List[str] = None):
"""Select control response queues by executors."""
if executors is None:
return self._ctrl_output_queues
else:
queues = {}
for k, v in self._ctrl_output_queues.items():
if "w2e" in k and "worker" in executors:
queues[k] = v
elif "c2e" in k and "cache_transfer" in executors:
queues[k] = v
return queues
async def wait_one(queue_name: str, queue):
"""Wait until one queue returns a response for the current request_id."""
mailbox = self._ctrl_response_mailboxes[queue_name]
# Reuse a previously stashed response for this request before touching FMQ again.
cached_response = mailbox.pop(request_id, None)
if cached_response is not None:
self.llm_logger.info(f"Returning cached control response from {queue_name}.")
return cached_response
while True:
msg = await queue.get()
# Return if the response matches the control request
response: ControlResponse = msg.payload
if response.request_id == request_id:
self.llm_logger.info(f"Returning new control response from {queue_name}.")
return response
# Stash late responses from other control requests so they do not consume the
# current request's only read chance on this queue.
mailbox[response.request_id] = response
self.llm_logger.info(
f"Stashed old control response from {queue_name}. "
f"Expected request {request_id}, got request {response.request_id}"
)
# Select only the control response queues that belong to the requested executors.
queues = select_control_queues(executors)
if not queues:
self.llm_logger.info(f"No queues to wait for, executors: {executors}")
return
self.llm_logger.info(f"Waiting for control responses from {len(queues)} queues: {list(queues.keys())}")
# Each queue gets its own waiter, which will stash stale responses until it finds the
# target request ID for this control request.
tasks = {name: asyncio.create_task(wait_one(name, queue)) for name, queue in queues.items()}
done, pending = await asyncio.wait(tasks.values(), timeout=timeout)
if pending:
pending_names = [name for name, task in tasks.items() if task in pending]
done_names = [name for name, task in tasks.items() if task in done]
self.llm_logger.error(
f"Control request {request_id} execution timeout. "
f"Pending queues: {pending_names}, completed queues: {done_names}."
)
# Stop unfinished queue waiters so they do not outlive the control request.
for task in pending:
task.cancel()
await asyncio.gather(*pending, return_exceptions=True)
raise Exception(f"Control request {request_id} timed out after {timeout}s")
# Collect the results from all completed queues.
responses = []
for output_queue, msg in zip(self._ctrl_worker_output_queues, results):
if isinstance(msg, Exception):
self.llm_logger.error(f"Call Worker Failed: {output_queue.name} {repr(msg)}")
raise Exception(f"Call Worker error: {repr(msg)}")
if msg is None:
# Preserve original semantics when no message is received
raise Exception("Worker Update Weights Timeouted after 600s")
response: ControlResponse = msg.payload
if response.request_id != request_id:
self.llm_logger.info(f"ignore old control response from worker:{output_queue.name} {response}")
continue
for name, task in tasks.items():
try:
response = task.result()
except Exception as e:
self.llm_logger.error(f"Waiting for control response from {name} failed: {repr(e)}")
raise
if response.error_code != 200:
self.llm_logger.info(f"Call Worker Failed: {output_queue.name} {response.error_message}")
raise Exception(f"Call Worker error: {response.error_message}")
self.llm_logger.info(f"Call Worker Succeed: {output_queue.name} {response.result}")
raise Exception(f"Error response from {name}: {response.error_message}")
responses.append(response.result)
return responses
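Two mechanisms in the method above deserve a standalone look: the per-queue mailbox that stashes responses belonging to other request IDs (the fix for stale control responses after a timeout, per the commit message), and asyncio.wait with a global timeout that cancels unfinished waiters. A self-contained sketch under those assumptions, with responses simplified to (request_id, result) tuples and the queues assumed to be asyncio.Queue instances:

import asyncio, collections

mailboxes = collections.defaultdict(collections.OrderedDict)

async def wait_one(name, queue, request_id):
    # Reuse a response stashed by an earlier, timed-out wait before reading again.
    cached = mailboxes[name].pop(request_id, None)
    if cached is not None:
        return cached
    while True:
        response = await queue.get()  # response: (request_id, result)
        if response[0] == request_id:
            return response
        # Stash stale responses so they are neither lost nor allowed to
        # consume the current request's read on this queue.
        mailboxes[name][response[0]] = response

async def wait_all(queues, request_id, timeout):
    tasks = {name: asyncio.create_task(wait_one(name, q, request_id))
             for name, q in queues.items()}
    done, pending = await asyncio.wait(tasks.values(), timeout=timeout)
    if pending:
        for task in pending:
            task.cancel()  # waiters must not outlive the control request
        await asyncio.gather(*pending, return_exceptions=True)
        raise TimeoutError(f"Control request {request_id} timed out after {timeout}s")
    return [task.result()[1] for task in tasks.values()]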
def _call_worker(self, control_request: ControlRequest, timeout: int):
request_id = control_request.request_id
self.engine_worker_queue.put_tasks(([control_request], 1))
# Use a single asyncio.run() to concurrently wait for all worker responses.
return asyncio.run(self._wait_all_control_responses(request_id, timeout))
return asyncio.run(self._wait_for_control_responses(request_id, timeout, executors=["worker"]))
def _send_error_response(self, request_id, error_msg, error_code: int = 500):
self.llm_logger.error(