mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 17:11:21 +08:00
[RL] add pause, update_weights, resume interface for async RL (#6052)
* support dynamic run_control_request through zmq from apiserver to common_engine * support pause/resume/is_paused/update_weights in apiserver->common_engine by common run_control_method * change /is_puased from HTTP POST method to GET method * add pause、resume、is_paused implementation * support engine <==> worker communication(request&response) * support sync weights through RDMA from checkpoint_transfer * support specified version, rsync_config in update_weights rpc call * add pause, update_weights, resume interface for async RL * bug fix: update_weights support using default arguments * fix typo * typo fix * typo fix * typo fix * add unitest for control request/response, localscheduler.get_inflight_requests, resource_manager_v1.preempted_all * add "rsync" to LoadConfig.load_strategy Literal type hints Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * typo fix * typo fix * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * check version/rsync params * add error log when version.txt not exists Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * raise specified ValueError when paramters check failed Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * tp barrier after run_control_method * encode 'engine_worker_queue_port' to unique name of worker2engine fmq queue * typo fix * typo fix --------- Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
import multiprocessing
|
||||
@@ -38,7 +39,14 @@ import zmq
|
||||
from tqdm import tqdm
|
||||
|
||||
import fastdeploy.metrics.trace as tracing
|
||||
from fastdeploy.engine.request import Request, RequestOutput, RequestStatus, RequestType
|
||||
from fastdeploy.engine.request import (
|
||||
ControlRequest,
|
||||
ControlResponse,
|
||||
Request,
|
||||
RequestOutput,
|
||||
RequestStatus,
|
||||
RequestType,
|
||||
)
|
||||
from fastdeploy.engine.resource_manager import ResourceManager
|
||||
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
|
||||
from fastdeploy.eplb.utils import init_eplb_signals
|
||||
@@ -50,6 +58,7 @@ from fastdeploy.inter_communicator import (
|
||||
ZmqIpcServer,
|
||||
ZmqTcpServer,
|
||||
)
|
||||
from fastdeploy.inter_communicator.fmq import FMQ
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.model_executor.guided_decoding import schema_checker
|
||||
from fastdeploy.plugins.token_processor import load_token_processor_plugins
|
||||
@@ -89,6 +98,18 @@ class EngineService:
|
||||
else:
|
||||
self.llm_logger = llm_logger
|
||||
|
||||
self.is_paused = False # pause request generation
|
||||
self._pause_cond = threading.Condition()
|
||||
|
||||
self._ctrl_worker_output_queues = []
|
||||
tp_size = cfg.parallel_config.tensor_parallel_size
|
||||
dp_index = cfg.parallel_config.local_data_parallel_id
|
||||
for rank in range(tp_size):
|
||||
engine_worker_queue_port = self.cfg.parallel_config.local_engine_worker_queue_port
|
||||
name = f"ctrl_w2e_rank{rank+tp_size*dp_index}_{engine_worker_queue_port}"
|
||||
self.llm_logger.info(f"Init Worker Control Output Queue: {name}(consumer)")
|
||||
self._ctrl_worker_output_queues.append(FMQ().queue(name, "consumer"))
|
||||
|
||||
self.scheduler = cfg.scheduler_config.scheduler()
|
||||
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
|
||||
|
||||
@@ -758,6 +779,8 @@ class EngineService:
|
||||
|
||||
def _fetch_request():
|
||||
try:
|
||||
with self._pause_cond:
|
||||
self._pause_cond.wait_for(lambda: not self.is_paused)
|
||||
nonlocal is_fetching
|
||||
num_prefill_batch = min(
|
||||
int(self.resource_manager.available_batch()),
|
||||
@@ -922,6 +945,8 @@ class EngineService:
|
||||
is_fetching = False
|
||||
|
||||
while self.running:
|
||||
with self._pause_cond:
|
||||
self._pause_cond.wait_for(lambda: not self.is_paused)
|
||||
try:
|
||||
if self.engine_worker_queue.exist_tasks():
|
||||
time.sleep(0.001)
|
||||
@@ -1065,6 +1090,17 @@ class EngineService:
|
||||
self.recv_request_server = ZmqIpcServer(name=self.api_server_pid, mode=zmq.PULL)
|
||||
continue
|
||||
|
||||
if ControlRequest.is_control_request(data):
|
||||
try: # todo: run control request async, do not block request generation
|
||||
control_req = ControlRequest.from_dict(data)
|
||||
self.run_control_method(control_req)
|
||||
except Exception as e:
|
||||
self.llm_logger.error(
|
||||
f"Failed to process control request {data.get('request_id')}: "
|
||||
f"{e}, {traceback.format_exc()}"
|
||||
)
|
||||
continue
|
||||
|
||||
request, insert_task = data, []
|
||||
results: List[Tuple[str, Optional[str]]] = list()
|
||||
if data:
|
||||
@@ -1096,6 +1132,13 @@ class EngineService:
|
||||
trace_print(LoggingEventName.REQUEST_SCHEDULE_START, data["request_id"], data.get("user", ""))
|
||||
trace_print(LoggingEventName.REQUEST_QUEUE_START, data["request_id"], data.get("user", ""))
|
||||
self.llm_logger.debug(f"Receive request from api server: {request}")
|
||||
|
||||
if self.is_paused:
|
||||
self.llm_logger.warning(f"Engine is paused, drop request: {request}")
|
||||
self._send_error_response(
|
||||
request.request_id, "Request is aborted since LLM Engine is paused."
|
||||
)
|
||||
continue
|
||||
except Exception as e:
|
||||
self.llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}")
|
||||
err_msg = str(e)
|
||||
@@ -1135,6 +1178,200 @@ class EngineService:
|
||||
f"traceback={traceback.format_exc()}"
|
||||
)
|
||||
|
||||
def run_control_method(self, control_req: ControlRequest):
    """Dispatch a control request to its handler and reply with a ControlResponse.

    Resolves a handler named ``_control_<method>`` on this instance. A missing
    or non-callable handler produces a 400 response, a handler exception a 500
    response, and a successful call a 200 response carrying the handler's
    return value.

    Args:
        control_req (ControlRequest): Control request object containing
            request ID, method name and parameters.

    Returns:
        None: Outcomes are delivered through ``send_response_server``.
    """
    method_name = control_req.get_method()
    req_id = control_req.request_id

    try:
        self.llm_logger.info(f"START run control method {req_id}: {method_name}")

        # Handlers follow the "_control_<method>" naming convention.
        handler = getattr(self, f"_control_{method_name}", None)
        if not callable(handler):
            bad_method = ControlResponse(req_id, 400, f"unknown control method:{method_name}")
            self.llm_logger.error(str(bad_method))
            self.send_response_server.send_response(req_id, [bad_method])
            return

        outcome = handler(control_req)
        self.llm_logger.info(f"SUCCESS run control method {method_name}.")
        ok = ControlResponse(req_id, 200, "Success", outcome)
        self.send_response_server.send_response(req_id, [ok])

    except Exception as e:
        # Any handler failure is reported back to the caller as a 500.
        error_msg = f"Failed run control method {method_name}: {str(e)}"
        self.llm_logger.error(f"{error_msg}\n{traceback.format_exc()}")
        failed = ControlResponse(req_id, 500, error_msg)
        self.send_response_server.send_response(req_id, [failed])
|
||||
|
||||
def _control_pause(self, control_request: ControlRequest):
    """Pauses the LLM engine and aborts all running/inflight requests.

    Args:
        control_request: The control request containing pause command

    Raises:
        Exception: If pause is not supported in current configuration
            (requires ENABLE_V1_KVCACHE_SCHEDULER and the local scheduler)
        Exception: If engine worker queue cleanup times out

    Returns:
        None
    """

    # Pause is only implemented for the v1 KV-cache scheduler together with
    # the local scheduler; reject any other configuration up front.
    if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
        raise Exception("pause only supported in ENABLE_V1_KVCACHE_SCHEDULER")
    if self.cfg.scheduler_config.name != "local":
        raise Exception(f"pause only supported in local scheduler, current {self.cfg.scheduler_config.name}")

    # Flip the pause flag under the condition lock; the fetch/schedule loops
    # wait_for() on this condition and will block once they observe it.
    # NOTE(review): a repeated pause is only logged, NOT short-circuited —
    # the abort/cleanup below runs again; confirm this re-entry is intended.
    with self._pause_cond:
        if self.is_paused:
            self.llm_logger.info("Pause Request Generation: already paused.")
        self.is_paused = True

    self.llm_logger.info("Start Abort Running Requests")

    self.resource_manager.log_status()
    # preempted all running reqs. preempted reqs will be append to ResourceManager.waiting queue
    # Busy-wait in 1 ms polls until the engine->worker queue drains, capped at
    # `timeout` seconds so a hung worker cannot block the pause forever.
    timeout, count = 60, 0
    while self.engine_worker_queue.exist_tasks():
        time.sleep(0.001)
        count += 1
        if count >= timeout * 1000:
            break
    if count >= timeout * 1000:
        error_msg = f"wait engine_worker_queue tasks empty timeout after {timeout} seconds, worker may Hanged"
        self.llm_logger.error(error_msg)
        raise Exception(error_msg)
    running_reqs = self.resource_manager.preempted_all()
    if len(running_reqs) > 0:
        self.llm_logger.info(f"Total {len(running_reqs)} requests need to be aborted.")
        # Refresh the real batch size before notifying workers of the
        # preempted batch.
        self.resource_manager.get_real_bsz()
        self.engine_worker_queue.put_tasks((running_reqs, self.resource_manager.real_bsz))
        # Block until workers report that all inflight requests finished.
        self.resource_manager.wait_worker_inflight_requests_finish(timeout=60)
        # self.engine_worker_queue.clear_data()
        self.token_processor.clear_data()
    self.resource_manager.log_status()

    # abort inflight requests to user
    inflight_requests = self.scheduler.get_inflight_requests()
    self.llm_logger.info(f"Start Abort Inflight Requests, total {len(inflight_requests)} waiting requests")
    for req in inflight_requests:
        self._send_error_response(req.request_id, "Request is aborted since LLM Engine is paused.")
    self.scheduler.reset()

    # NOTE(review): presumably drops all cached KV blocks so a later resume /
    # weight update starts from a clean cache — confirm against CacheManager.
    self.resource_manager.cache_manager.reset()
    return None
|
||||
|
||||
def _control_resume(self, control_request: ControlRequest) -> Optional[dict]:
|
||||
"""Control function for resuming request generation.
|
||||
|
||||
This method resumes the paused request generation process by setting the pause flag
|
||||
and notifying all waiting threads. It logs the start and end of the resume operation.
|
||||
|
||||
Args:
|
||||
control_request: Control request object containing resume operation information
|
||||
"""
|
||||
self.llm_logger.info("START Resume Request Generation")
|
||||
with self._pause_cond:
|
||||
if not self.is_paused:
|
||||
self.llm_logger.info("Resume Request Generation: not paused.")
|
||||
return None
|
||||
self.is_paused = False
|
||||
self._pause_cond.notify_all()
|
||||
self.llm_logger.info("END Resume Request Generation")
|
||||
return None
|
||||
|
||||
def _control_is_paused(self, control_request: ControlRequest) -> bool:
|
||||
"""
|
||||
Check if the LLM engine is in paused state.
|
||||
|
||||
Args:
|
||||
control_request: Control request object.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary containing pause status information, {'is_paused': bool}
|
||||
"""
|
||||
self.llm_logger.info(f"LLM Engine request generation is paused: {self.is_paused}")
|
||||
with self._pause_cond:
|
||||
return {"is_paused": self.is_paused}
|
||||
|
||||
def _control_update_weights(self, control_request: ControlRequest) -> Optional[dict]:
|
||||
"""Update model weights
|
||||
Args:
|
||||
control_request: Control request object containing parameters for weight updates
|
||||
|
||||
Returns:
|
||||
Optional[dict]: Returns the result dictionary if update succeeds, None otherwise
|
||||
|
||||
Raises:
|
||||
Exception: Raised when the engine is not in paused state
|
||||
"""
|
||||
self.llm_logger.info("Update Model Weights")
|
||||
with self._pause_cond:
|
||||
if self.is_paused is False:
|
||||
error_msg = "Pause LLM Engine first before calling updating weights"
|
||||
self.llm_logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
return self._call_worker(control_request, 60)
|
||||
|
||||
async def _wait_all_control_responses(self, request_id: str, timeout: int):
|
||||
"""Wait for control responses from all workers with a global timeout.
|
||||
|
||||
This method concurrently waits for responses from all control workers
|
||||
and enforces an overall timeout to avoid leaking pending tasks.
|
||||
"""
|
||||
timeout_ms = timeout * 1000
|
||||
# Create one get() coroutine per worker output queue
|
||||
tasks = [output_queue.get(timeout=timeout_ms) for output_queue in self._ctrl_worker_output_queues]
|
||||
|
||||
try:
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.gather(*tasks, return_exceptions=True),
|
||||
timeout=timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
# Keep the error message consistent with previous behavior
|
||||
raise Exception("Worker Update Weights Timeouted after 600s")
|
||||
|
||||
responses = []
|
||||
for output_queue, msg in zip(self._ctrl_worker_output_queues, results):
|
||||
if isinstance(msg, Exception):
|
||||
self.llm_logger.error(f"Call Worker Failed: {output_queue.name} {repr(msg)}")
|
||||
raise Exception(f"Call Worker error: {repr(msg)}")
|
||||
if msg is None:
|
||||
# Preserve original semantics when no message is received
|
||||
raise Exception("Worker Update Weights Timeouted after 600s")
|
||||
response: ControlResponse = msg.payload
|
||||
if response.request_id != request_id:
|
||||
self.llm_logger.info(f"ignore old control response from worker:{output_queue.name} {response}")
|
||||
continue
|
||||
if response.error_code != 200:
|
||||
self.llm_logger.info(f"Call Worker Failed: {output_queue.name} {response.error_message}")
|
||||
raise Exception(f"Call Worker error: {response.error_message}")
|
||||
self.llm_logger.info(f"Call Worker Succeed: {output_queue.name} {response.result}")
|
||||
responses.append(response.result)
|
||||
return responses
|
||||
|
||||
def _call_worker(self, control_request: ControlRequest, timeout: int):
|
||||
request_id = control_request.request_id
|
||||
self.engine_worker_queue.put_tasks(([control_request], 1))
|
||||
# Use a single asyncio.run() to concurrently wait for all worker responses.
|
||||
return asyncio.run(self._wait_all_control_responses(request_id, timeout))
|
||||
|
||||
def _send_error_response(self, request_id, error_msg, error_code: int = 500):
|
||||
self.llm_logger.error(
|
||||
f"Send error response to client, request_id: {request_id}, error_msg: {error_msg}, error_code: {error_code}"
|
||||
@@ -1708,6 +1945,7 @@ class EngineService:
|
||||
f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'"
|
||||
f" --guided_decoding_backend {self.cfg.structured_outputs_config.guided_decoding_backend}"
|
||||
f" --load_strategy {self.cfg.load_config.load_strategy}"
|
||||
f" --rsync_config '{json.dumps(self.cfg.load_config.rsync_config)}'"
|
||||
f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
|
||||
f" --reasoning_parser {self.cfg.structured_outputs_config.reasoning_parser}"
|
||||
f" --load_choices {self.cfg.load_config.load_choices}"
|
||||
|
||||
Reference in New Issue
Block a user