[RL] add pause, update_weights, resume interface for async RL (#6052)

* support dynamic run_control_request through zmq from apiserver to common_engine

* support pause/resume/is_paused/update_weights in apiserver->common_engine by common run_control_method

* change /is_paused from HTTP POST method to GET method

* add pause, resume, is_paused implementation

* support engine <==> worker communication(request&response)

* support sync weights through RDMA from checkpoint_transfer

* support specified version, rsync_config in update_weights rpc call

* add pause, update_weights, resume interface for async RL

* bug fix: update_weights support using default arguments

* fix typo

* typo fix

* typo fix

* typo fix

* add unit tests for control request/response, localscheduler.get_inflight_requests, resource_manager_v1.preempted_all

* add "rsync" to LoadConfig.load_strategy Literal type hints

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* typo fix

* typo fix

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* check version/rsync params

* add error log when version.txt not exists

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* raise specific ValueError when parameter check fails

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* tp barrier after run_control_method

* encode 'engine_worker_queue_port' to unique name of worker2engine fmq queue

* typo fix

* typo fix

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
wangyifei
2026-01-23 10:18:07 +08:00
committed by GitHub
parent 96b2cf2c20
commit b7c5daa316
18 changed files with 1170 additions and 16 deletions
+58 -5
View File
@@ -15,9 +15,11 @@
"""
import argparse
import asyncio
import json
import os
import time
import traceback
from typing import Tuple
import numpy as np
@@ -42,7 +44,7 @@ from fastdeploy.config import (
SpeculativeConfig,
StructuredOutputsConfig,
)
from fastdeploy.engine.request import RequestType
from fastdeploy.engine.request import ControlRequest, ControlResponse, RequestType
from fastdeploy.eplb.async_expert_loader import (
MODEL_MAIN_NAME,
REARRANGE_EXPERT_MAGIC_NUM,
@@ -57,6 +59,7 @@ from fastdeploy.inter_communicator import (
ModelWeightsStatus,
RearrangeExpertStatus,
)
from fastdeploy.inter_communicator.fmq import FMQ
from fastdeploy.model_executor.layers.quantization import parse_quant_config
from fastdeploy.model_executor.utils import v1_loader_support
from fastdeploy.platforms import current_platform
@@ -164,6 +167,12 @@ class PaddleDisWorkerProc:
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
def init_control(self):
engine_worker_queue_port = self.parallel_config.local_engine_worker_queue_port
queue_name = f"ctrl_w2e_rank{self.local_rank}_{engine_worker_queue_port}"
logger.info(f"Init Control Output Queue: {queue_name}(producer)")
self._ctrl_output = FMQ().queue(queue_name, "producer")
def init_health_status(self) -> None:
"""
Initialize the health status of the worker.
@@ -513,10 +522,20 @@ class PaddleDisWorkerProc:
else:
self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY
req_dicts = []
req_dicts, control_reqs = [], []
for req_dict, bsz in tasks:
max_occupied_batch_index = int(bsz)
req_dicts.extend(req_dict)
if len(req_dict) > 0 and isinstance(req_dict[0], ControlRequest):
control_reqs.append(req_dict[0])
else:
max_occupied_batch_index = int(bsz)
req_dicts.extend(req_dict)
# todo: run control request async
if len(control_reqs) > 0:
logger.info(f"Rank: {self.local_rank} received {len(control_reqs)} control request.")
for control_req in control_reqs:
self.run_control_method(control_req)
self._tp_barrier_wait() if tp_size > 1 else None
# Count prefill requests in current batch
num_prefill_requests = sum(1 for req in req_dicts if req.task_type == RequestType.PREFILL)
@@ -655,6 +674,32 @@ class PaddleDisWorkerProc:
paddle.distributed.barrier()
self.loaded_model_signal.value[0] = 1
def run_control_method(self, control_request: ControlRequest) -> None:
    """Dispatch a ControlRequest to the worker and publish a ControlResponse.

    Looks up ``control_request.method`` on ``self.worker`` and calls it with
    the request's kwargs. Replies on the control output queue with status
    400 (unknown method), 200 (success, result attached), or 500 (handler
    raised). Never raises itself.
    """
    logger.info(f"Start run control request: {control_request}")
    req_id, method, kwargs = (
        control_request.request_id,
        control_request.method,
        control_request.args,
    )
    handler = getattr(self.worker, method, None)
    # A missing attribute (None) and a non-callable attribute are both
    # rejected the same way: callable(None) is False.
    if not callable(handler):
        error_msg = f"Rank-{self.local_rank}: Unknown control method {method}"
        asyncio.run(self._ctrl_output.put(ControlResponse(req_id, 400, error_msg)))
        return
    try:
        result = handler(**kwargs)
        response = ControlResponse(req_id, 200, "Success", result)
        logger.info(
            f"Rank-{self.local_rank} Success run control request: {control_request}, response: {response}"
        )
        # Large results (e.g. weight-sync payloads) go through shared
        # memory once they exceed the 100 MiB threshold.
        asyncio.run(self._ctrl_output.put(response, shm_threshold=100 * 1024 * 1024))
    except Exception as e:
        error_msg = f"Rank-{self.local_rank} Failed run control method {method}: {str(e)}"
        logger.error(f"{error_msg}\n{traceback.format_exc()}")
        asyncio.run(self._ctrl_output.put(ControlResponse(req_id, 500, error_msg)))
def parse_args():
"""
@@ -813,12 +858,18 @@ def parse_args():
parser.add_argument(
"--load_strategy",
type=str,
choices=["ipc", "ipc_snapshot", "meta", "normal"],
choices=["ipc", "ipc_snapshot", "meta", "normal", "rsync"],
default="ipc_snapshot",
help="Weight loading method when dynamic loading is enabled: "
"'ipc': real-time IPC streaming with automatic resharding, "
"'ipc_snapshot': load from disk snapshot of IPC weights.",
)
parser.add_argument(
"--rsync_config",
type=json.loads,
default=None,
help="Rsync weights config",
)
parser.add_argument(
"--enable_logprob",
action="store_true",
@@ -1045,6 +1096,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
logger.info(f"- Load strategy: {load_config.load_strategy}")
logger.info(f"- Rsync config: {load_config.rsync_config}, {type(load_config.rsync_config)}")
if not (
current_platform.is_cuda()
@@ -1112,6 +1164,7 @@ def run_worker_proc() -> None:
worker_proc = IluvatarPaddleDisWorkerProc(fd_config, ranks, local_rank)
else:
worker_proc = PaddleDisWorkerProc(fd_config, ranks, local_rank)
worker_proc.init_control()
# Initialize device and create model runner
worker_proc.init_device()