mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[RL] add pause, update_weights, resume interface for async RL (#6052)
* support dynamic run_control_request through zmq from apiserver to common_engine * support pause/resume/is_paused/update_weights in apiserver->common_engine by common run_control_method * change /is_puased from HTTP POST method to GET method * add pause、resume、is_paused implementation * support engine <==> worker communication(request&response) * support sync weights through RDMA from checkpoint_transfer * support specified version, rsync_config in update_weights rpc call * add pause, update_weights, resume interface for async RL * bug fix: update_weights support using default arguments * fix typo * typo fix * typo fix * typo fix * add unitest for control request/response, localscheduler.get_inflight_requests, resource_manager_v1.preempted_all * add "rsync" to LoadConfig.load_strategy Literal type hints Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * typo fix * typo fix * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * check version/rsync params * add error log when version.txt not exists Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * raise specified ValueError when paramters check failed Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * tp barrier after run_control_method * encode 'engine_worker_queue_port' to unique name of worker2engine fmq queue * typo fix * typo fix --------- Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -15,9 +15,11 @@
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
@@ -42,7 +44,7 @@ from fastdeploy.config import (
|
||||
SpeculativeConfig,
|
||||
StructuredOutputsConfig,
|
||||
)
|
||||
from fastdeploy.engine.request import RequestType
|
||||
from fastdeploy.engine.request import ControlRequest, ControlResponse, RequestType
|
||||
from fastdeploy.eplb.async_expert_loader import (
|
||||
MODEL_MAIN_NAME,
|
||||
REARRANGE_EXPERT_MAGIC_NUM,
|
||||
@@ -57,6 +59,7 @@ from fastdeploy.inter_communicator import (
|
||||
ModelWeightsStatus,
|
||||
RearrangeExpertStatus,
|
||||
)
|
||||
from fastdeploy.inter_communicator.fmq import FMQ
|
||||
from fastdeploy.model_executor.layers.quantization import parse_quant_config
|
||||
from fastdeploy.model_executor.utils import v1_loader_support
|
||||
from fastdeploy.platforms import current_platform
|
||||
@@ -164,6 +167,12 @@ class PaddleDisWorkerProc:
|
||||
|
||||
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
|
||||
|
||||
def init_control(self):
    """Create the worker→engine control-response queue (producer side).

    The queue name encodes both the local rank and the engine worker queue
    port, so each worker process gets its own unique FMQ channel back to the
    engine for control-method responses.
    """
    port = self.parallel_config.local_engine_worker_queue_port
    queue_name = f"ctrl_w2e_rank{self.local_rank}_{port}"
    logger.info(f"Init Control Output Queue: {queue_name}(producer)")
    self._ctrl_output = FMQ().queue(queue_name, "producer")
|
||||
|
||||
def init_health_status(self) -> None:
|
||||
"""
|
||||
Initialize the health status of the worker.
|
||||
@@ -513,10 +522,20 @@ class PaddleDisWorkerProc:
|
||||
else:
|
||||
self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY
|
||||
|
||||
req_dicts = []
|
||||
req_dicts, control_reqs = [], []
|
||||
for req_dict, bsz in tasks:
|
||||
max_occupied_batch_index = int(bsz)
|
||||
req_dicts.extend(req_dict)
|
||||
if len(req_dict) > 0 and isinstance(req_dict[0], ControlRequest):
|
||||
control_reqs.append(req_dict[0])
|
||||
else:
|
||||
max_occupied_batch_index = int(bsz)
|
||||
req_dicts.extend(req_dict)
|
||||
|
||||
# todo: run control request async
|
||||
if len(control_reqs) > 0:
|
||||
logger.info(f"Rank: {self.local_rank} received {len(control_reqs)} control request.")
|
||||
for control_req in control_reqs:
|
||||
self.run_control_method(control_req)
|
||||
self._tp_barrier_wait() if tp_size > 1 else None
|
||||
|
||||
# Count prefill requests in current batch
|
||||
num_prefill_requests = sum(1 for req in req_dicts if req.task_type == RequestType.PREFILL)
|
||||
@@ -655,6 +674,32 @@ class PaddleDisWorkerProc:
|
||||
paddle.distributed.barrier()
|
||||
self.loaded_model_signal.value[0] = 1
|
||||
|
||||
def run_control_method(self, control_request: ControlRequest) -> None:
    """Dispatch a ControlRequest to the matching worker method and reply.

    Looks up ``control_request.method`` on ``self.worker``, invokes it with
    ``control_request.args`` as keyword arguments, and pushes a
    ``ControlResponse`` (200 on success, 400 for an unknown method, 500 on an
    exception) back to the engine via the control output queue created in
    ``init_control``.
    """
    logger.info(f"Start run control request: {control_request}")
    request_id = control_request.request_id
    method = control_request.method
    kwargs = control_request.args

    def _reply(response, **put_kwargs):
        # The FMQ producer exposes an async ``put``; drive it to completion here.
        asyncio.run(self._ctrl_output.put(response, **put_kwargs))

    handler = getattr(self.worker, method, None)
    # ``callable(None)`` is False, so this covers both a missing attribute
    # and an attribute that exists but is not invocable.
    if not callable(handler):
        error_msg = f"Rank-{self.local_rank}: Unknown control method {method}"
        _reply(ControlResponse(request_id, 400, error_msg))
        return

    try:
        result = handler(**kwargs)
        succ_result = ControlResponse(request_id, 200, "Success", result)
        logger.info(
            f"Rank-{self.local_rank} Success run control request: {control_request}, response: {succ_result}"
        )
        # shm_threshold is only supplied on the success path — presumably it
        # routes payloads above 100 MiB through shared memory; confirm in FMQ.
        _reply(succ_result, shm_threshold=100 * 1024 * 1024)
    except Exception as e:
        error_msg = f"Rank-{self.local_rank} Failed run control method {method}: {str(e)}"
        logger.error(f"{error_msg}\n{traceback.format_exc()}")
        _reply(ControlResponse(request_id, 500, error_msg))
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""
|
||||
@@ -813,12 +858,18 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--load_strategy",
|
||||
type=str,
|
||||
choices=["ipc", "ipc_snapshot", "meta", "normal"],
|
||||
choices=["ipc", "ipc_snapshot", "meta", "normal", "rsync"],
|
||||
default="ipc_snapshot",
|
||||
help="Weight loading method when dynamic loading is enabled: "
|
||||
"'ipc': real-time IPC streaming with automatic resharding, "
|
||||
"'ipc_snapshot': load from disk snapshot of IPC weights.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rsync_config",
|
||||
type=json.loads,
|
||||
default=None,
|
||||
help="Rsync weights config",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_logprob",
|
||||
action="store_true",
|
||||
@@ -1045,6 +1096,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
|
||||
|
||||
logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
|
||||
logger.info(f"- Load strategy: {load_config.load_strategy}")
|
||||
logger.info(f"- Rsync config: {load_config.rsync_config}, {type(load_config.rsync_config)}")
|
||||
|
||||
if not (
|
||||
current_platform.is_cuda()
|
||||
@@ -1112,6 +1164,7 @@ def run_worker_proc() -> None:
|
||||
worker_proc = IluvatarPaddleDisWorkerProc(fd_config, ranks, local_rank)
|
||||
else:
|
||||
worker_proc = PaddleDisWorkerProc(fd_config, ranks, local_rank)
|
||||
worker_proc.init_control()
|
||||
|
||||
# Initialize device and create model runner
|
||||
worker_proc.init_device()
|
||||
|
||||
Reference in New Issue
Block a user