[RL] add pause, update_weights, resume interface for async RL (#6052)

* support dynamic run_control_request through zmq from apiserver to common_engine

* support pause/resume/is_paused/update_weights in apiserver->common_engine by common run_control_method

* change /is_paused from HTTP POST method to GET method

* add pause, resume, is_paused implementation

* support engine <==> worker communication(request&response)

* support sync weights through RDMA from checkpoint_transfer

* support specified version, rsync_config in update_weights rpc call

* add pause, update_weights, resume interface for async RL

* bug fix: update_weights support using default arguments

* fix typo

* typo fix

* typo fix

* typo fix

* add unittest for control request/response, localscheduler.get_inflight_requests, resource_manager_v1.preempted_all

* add "rsync" to LoadConfig.load_strategy Literal type hints

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* typo fix

* typo fix

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* check version/rsync params

* add error log when version.txt not exists

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* raise a specific ValueError when parameter checks fail

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* tp barrier after run_control_method

* encode 'engine_worker_queue_port' to unique name of worker2engine fmq queue

* typo fix

* typo fix

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
wangyifei
2026-01-23 10:18:07 +08:00
committed by GitHub
parent 96b2cf2c20
commit b7c5daa316
18 changed files with 1170 additions and 16 deletions
+239 -1
View File
@@ -16,6 +16,7 @@
from __future__ import annotations
import asyncio
import copy
import json
import multiprocessing
@@ -38,7 +39,14 @@ import zmq
from tqdm import tqdm
import fastdeploy.metrics.trace as tracing
from fastdeploy.engine.request import Request, RequestOutput, RequestStatus, RequestType
from fastdeploy.engine.request import (
ControlRequest,
ControlResponse,
Request,
RequestOutput,
RequestStatus,
RequestType,
)
from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
from fastdeploy.eplb.utils import init_eplb_signals
@@ -50,6 +58,7 @@ from fastdeploy.inter_communicator import (
ZmqIpcServer,
ZmqTcpServer,
)
from fastdeploy.inter_communicator.fmq import FMQ
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.model_executor.guided_decoding import schema_checker
from fastdeploy.plugins.token_processor import load_token_processor_plugins
@@ -89,6 +98,18 @@ class EngineService:
else:
self.llm_logger = llm_logger
self.is_paused = False # pause request generation
self._pause_cond = threading.Condition()
self._ctrl_worker_output_queues = []
tp_size = cfg.parallel_config.tensor_parallel_size
dp_index = cfg.parallel_config.local_data_parallel_id
for rank in range(tp_size):
engine_worker_queue_port = self.cfg.parallel_config.local_engine_worker_queue_port
name = f"ctrl_w2e_rank{rank+tp_size*dp_index}_{engine_worker_queue_port}"
self.llm_logger.info(f"Init Worker Control Output Queue: {name}(consumer)")
self._ctrl_worker_output_queues.append(FMQ().queue(name, "consumer"))
self.scheduler = cfg.scheduler_config.scheduler()
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
@@ -758,6 +779,8 @@ class EngineService:
def _fetch_request():
try:
with self._pause_cond:
self._pause_cond.wait_for(lambda: not self.is_paused)
nonlocal is_fetching
num_prefill_batch = min(
int(self.resource_manager.available_batch()),
@@ -922,6 +945,8 @@ class EngineService:
is_fetching = False
while self.running:
with self._pause_cond:
self._pause_cond.wait_for(lambda: not self.is_paused)
try:
if self.engine_worker_queue.exist_tasks():
time.sleep(0.001)
@@ -1065,6 +1090,17 @@ class EngineService:
self.recv_request_server = ZmqIpcServer(name=self.api_server_pid, mode=zmq.PULL)
continue
if ControlRequest.is_control_request(data):
try: # todo: run control request async, do not block request generation
control_req = ControlRequest.from_dict(data)
self.run_control_method(control_req)
except Exception as e:
self.llm_logger.error(
f"Failed to process control request {data.get('request_id')}: "
f"{e}, {traceback.format_exc()}"
)
continue
request, insert_task = data, []
results: List[Tuple[str, Optional[str]]] = list()
if data:
@@ -1096,6 +1132,13 @@ class EngineService:
trace_print(LoggingEventName.REQUEST_SCHEDULE_START, data["request_id"], data.get("user", ""))
trace_print(LoggingEventName.REQUEST_QUEUE_START, data["request_id"], data.get("user", ""))
self.llm_logger.debug(f"Receive request from api server: {request}")
if self.is_paused:
self.llm_logger.warning(f"Engine is paused, drop request: {request}")
self._send_error_response(
request.request_id, "Request is aborted since LLM Engine is paused."
)
continue
except Exception as e:
self.llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}")
err_msg = str(e)
@@ -1135,6 +1178,200 @@ class EngineService:
f"traceback={traceback.format_exc()}"
)
def run_control_method(self, control_req: ControlRequest):
    """Dispatch a control request to its handler and send back a response.

    Resolves a handler named ``_control_<method>`` on this instance and
    invokes it with the request. Replies through ``send_response_server``
    with a 400 response for an unknown method, a 200 response carrying the
    handler's result on success, or a 500 response if the handler raises.

    Args:
        control_req (ControlRequest): Control request carrying the request
            id, the method name, and its parameters.

    Returns:
        None: The outcome is delivered via ``send_response_server``.
    """
    method_name = control_req.get_method()
    req_id = control_req.request_id
    try:
        self.llm_logger.info(f"START run control method {req_id}: {method_name}")
        # callable() is False for None, so one check covers both missing
        # and non-callable attributes.
        handler = getattr(self, f"_control_{method_name}", None)
        if not callable(handler):
            reply = ControlResponse(req_id, 400, f"unknown control method:{method_name}")
            self.llm_logger.error(str(reply))
            self.send_response_server.send_response(req_id, [reply])
            return
        outcome = handler(control_req)
        self.llm_logger.info(f"SUCCESS run control method {method_name}.")
        reply = ControlResponse(req_id, 200, "Success", outcome)
        self.send_response_server.send_response(req_id, [reply])
    except Exception as exc:
        failure_msg = f"Failed run control method {method_name}: {str(exc)}"
        self.llm_logger.error(f"{failure_msg}\n{traceback.format_exc()}")
        reply = ControlResponse(req_id, 500, failure_msg)
        self.send_response_server.send_response(req_id, [reply])
def _control_pause(self, control_request: ControlRequest):
    """Pause the LLM engine and abort all running/inflight requests.

    Sets ``is_paused`` under ``_pause_cond`` (fetch loops wait on this flag),
    drains the engine worker queue, preempts all running requests back to the
    waiting queue, and aborts every inflight request with an error response
    before resetting scheduler and cache state.

    Args:
        control_request (ControlRequest): The pause control request (unused
            beyond dispatch).

    Raises:
        Exception: If the V1 KV-cache scheduler is disabled, the scheduler is
            not "local", or the worker queue does not drain within 60s.

    Returns:
        None
    """
    # Pause relies on preempted_all()/wait_worker_inflight_requests_finish,
    # which only exist under the V1 KV-cache scheduler and local scheduling.
    if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
        raise Exception("pause only supported in ENABLE_V1_KVCACHE_SCHEDULER")
    if self.cfg.scheduler_config.name != "local":
        raise Exception(f"pause only supported in local scheduler, current {self.cfg.scheduler_config.name}")
    with self._pause_cond:
        # NOTE(review): when already paused this only logs and still re-runs
        # the full abort sequence below — confirm an early return is not
        # intended here.
        if self.is_paused:
            self.llm_logger.info("Pause Request Generation: already paused.")
        self.is_paused = True
        self.llm_logger.info("Start Abort Running Requests")
        self.resource_manager.log_status()
        # preempted all running reqs. preempted reqs will be append to ResourceManager.waiting queue
        # Busy-wait (1ms steps) until the worker queue is empty, up to 60s.
        timeout, count = 60, 0
        while self.engine_worker_queue.exist_tasks():
            time.sleep(0.001)
            count += 1
            if count >= timeout * 1000:
                break
        if count >= timeout * 1000:
            # Queue never drained — the worker is presumed hung; abort pause.
            error_msg = f"wait engine_worker_queue tasks empty timeout after {timeout} seconds, worker may Hanged"
            self.llm_logger.error(error_msg)
            raise Exception(error_msg)
        running_reqs = self.resource_manager.preempted_all()
        if len(running_reqs) > 0:
            self.llm_logger.info(f"Total {len(running_reqs)} requests need to be aborted.")
            # Refresh real batch size before notifying the worker of aborts.
            self.resource_manager.get_real_bsz()
            self.engine_worker_queue.put_tasks((running_reqs, self.resource_manager.real_bsz))
            self.resource_manager.wait_worker_inflight_requests_finish(timeout=60)
            # self.engine_worker_queue.clear_data()
            self.token_processor.clear_data()
        self.resource_manager.log_status()
        # abort inflight requests to user
        inflight_requests = self.scheduler.get_inflight_requests()
        self.llm_logger.info(f"Start Abort Inflight Requests, total {len(inflight_requests)} waiting requests")
        for req in inflight_requests:
            self._send_error_response(req.request_id, "Request is aborted since LLM Engine is paused.")
        # Drop all scheduler/cache state so a later resume starts clean.
        self.scheduler.reset()
        self.resource_manager.cache_manager.reset()
    return None
def _control_resume(self, control_request: ControlRequest) -> Optional[dict]:
"""Control function for resuming request generation.
This method resumes the paused request generation process by setting the pause flag
and notifying all waiting threads. It logs the start and end of the resume operation.
Args:
control_request: Control request object containing resume operation information
"""
self.llm_logger.info("START Resume Request Generation")
with self._pause_cond:
if not self.is_paused:
self.llm_logger.info("Resume Request Generation: not paused.")
return None
self.is_paused = False
self._pause_cond.notify_all()
self.llm_logger.info("END Resume Request Generation")
return None
def _control_is_paused(self, control_request: ControlRequest) -> bool:
"""
Check if the LLM engine is in paused state.
Args:
control_request: Control request object.
Returns:
dict: Dictionary containing pause status information, {'is_paused': bool}
"""
self.llm_logger.info(f"LLM Engine request generation is paused: {self.is_paused}")
with self._pause_cond:
return {"is_paused": self.is_paused}
def _control_update_weights(self, control_request: ControlRequest) -> Optional[dict]:
"""Update model weights
Args:
control_request: Control request object containing parameters for weight updates
Returns:
Optional[dict]: Returns the result dictionary if update succeeds, None otherwise
Raises:
Exception: Raised when the engine is not in paused state
"""
self.llm_logger.info("Update Model Weights")
with self._pause_cond:
if self.is_paused is False:
error_msg = "Pause LLM Engine first before calling updating weights"
self.llm_logger.error(error_msg)
raise Exception(error_msg)
return self._call_worker(control_request, 60)
async def _wait_all_control_responses(self, request_id: str, timeout: int):
"""Wait for control responses from all workers with a global timeout.
This method concurrently waits for responses from all control workers
and enforces an overall timeout to avoid leaking pending tasks.
"""
timeout_ms = timeout * 1000
# Create one get() coroutine per worker output queue
tasks = [output_queue.get(timeout=timeout_ms) for output_queue in self._ctrl_worker_output_queues]
try:
results = await asyncio.wait_for(
asyncio.gather(*tasks, return_exceptions=True),
timeout=timeout,
)
except asyncio.TimeoutError:
# Keep the error message consistent with previous behavior
raise Exception("Worker Update Weights Timeouted after 600s")
responses = []
for output_queue, msg in zip(self._ctrl_worker_output_queues, results):
if isinstance(msg, Exception):
self.llm_logger.error(f"Call Worker Failed: {output_queue.name} {repr(msg)}")
raise Exception(f"Call Worker error: {repr(msg)}")
if msg is None:
# Preserve original semantics when no message is received
raise Exception("Worker Update Weights Timeouted after 600s")
response: ControlResponse = msg.payload
if response.request_id != request_id:
self.llm_logger.info(f"ignore old control response from worker:{output_queue.name} {response}")
continue
if response.error_code != 200:
self.llm_logger.info(f"Call Worker Failed: {output_queue.name} {response.error_message}")
raise Exception(f"Call Worker error: {response.error_message}")
self.llm_logger.info(f"Call Worker Succeed: {output_queue.name} {response.result}")
responses.append(response.result)
return responses
def _call_worker(self, control_request: ControlRequest, timeout: int):
request_id = control_request.request_id
self.engine_worker_queue.put_tasks(([control_request], 1))
# Use a single asyncio.run() to concurrently wait for all worker responses.
return asyncio.run(self._wait_all_control_responses(request_id, timeout))
def _send_error_response(self, request_id, error_msg, error_code: int = 500):
self.llm_logger.error(
f"Send error response to client, request_id: {request_id}, error_msg: {error_msg}, error_code: {error_code}"
@@ -1708,6 +1945,7 @@ class EngineService:
f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'"
f" --guided_decoding_backend {self.cfg.structured_outputs_config.guided_decoding_backend}"
f" --load_strategy {self.cfg.load_config.load_strategy}"
f" --rsync_config '{json.dumps(self.cfg.load_config.rsync_config)}'"
f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
f" --reasoning_parser {self.cfg.structured_outputs_config.reasoning_parser}"
f" --load_choices {self.cfg.load_config.load_choices}"