mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[RL] add pause, update_weights, resume interface for async RL (#6052)
* support dynamic run_control_request through zmq from apiserver to common_engine * support pause/resume/is_paused/update_weights in apiserver->common_engine by common run_control_method * change /is_paused from HTTP POST method to GET method * add pause, resume, is_paused implementation * support engine <==> worker communication(request&response) * support sync weights through RDMA from checkpoint_transfer * support specified version, rsync_config in update_weights rpc call * add pause, update_weights, resume interface for async RL * bug fix: update_weights support using default arguments * fix typo * typo fix * typo fix * typo fix * add unit tests for control request/response, localscheduler.get_inflight_requests, resource_manager_v1.preempted_all * add "rsync" to LoadConfig.load_strategy Literal type hints Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * typo fix * typo fix * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * check version/rsync params * add error log when version.txt does not exist Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * raise specified ValueError when parameters check failed Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * tp barrier after run_control_method * encode 'engine_worker_queue_port' to unique name of worker2engine fmq queue * typo fix * typo fix --------- Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -20,6 +20,7 @@ import signal
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@@ -38,6 +39,7 @@ from fastdeploy.engine.args_utils import EngineArgs
|
||||
from fastdeploy.engine.async_llm import AsyncLLM
|
||||
from fastdeploy.engine.engine import LLMEngine
|
||||
from fastdeploy.engine.expert_service import ExpertService
|
||||
from fastdeploy.engine.request import ControlRequest
|
||||
from fastdeploy.entrypoints.chat_utils import load_chat_template
|
||||
from fastdeploy.entrypoints.engine_client import EngineClient
|
||||
from fastdeploy.entrypoints.openai.middleware import AuthenticationMiddleware
|
||||
@@ -370,6 +372,66 @@ def ping(raw_request: Request) -> Response:
|
||||
return health(raw_request)
|
||||
|
||||
|
||||
@app.post("/v1/pause")
async def pause(request: Request) -> Response:
    """Send a ``pause`` control request to the engine and relay its reply.

    Args:
        request: Incoming HTTP request; its contents are not read.

    Returns:
        The API JSON response built by the engine's control response.
    """
    # todo: support wait_for_inflight_requests(default False), clear_cache(default True) arguments
    pause_request = ControlRequest(f"control-{uuid.uuid4()}", "pause")
    reply = await app.state.engine_client.run_control_method(pause_request)
    return reply.to_api_json_response()
|
||||
|
||||
|
||||
@app.post("/v1/resume")
async def resume(request: Request) -> Response:
    """Send a ``resume`` control request to the engine and relay its reply.

    Args:
        request: Incoming HTTP request; its contents are not read.

    Returns:
        The API JSON response built by the engine's control response.
    """
    # Each control call gets its own unique, "control-"-prefixed request id.
    unique_id = f"control-{uuid.uuid4()}"
    reply = await app.state.engine_client.run_control_method(ControlRequest(unique_id, "resume"))
    return reply.to_api_json_response()
|
||||
|
||||
|
||||
@app.get("/v1/is_paused")
async def is_paused(request: Request) -> Response:
    """Send an ``is_paused`` control request to the engine and relay its reply.

    Args:
        request: Incoming HTTP request; its contents are not read.

    Returns:
        The API JSON response built by the engine's control response.
    """
    query = ControlRequest(f"control-{uuid.uuid4()}", "is_paused")
    # The engine client performs the round trip; this handler only forwards.
    return (await app.state.engine_client.run_control_method(query)).to_api_json_response()
|
||||
|
||||
|
||||
@app.post("/v1/update_weights")
async def update_weights(request: Request) -> Response:
    """Validate the optional JSON body and forward ``update_weights`` to the engine.

    The body may be empty. When present it must be a JSON object, and it may
    carry the following keys (a ``None`` value is treated as absent):

    * ``version``: must be a string.
    * ``rsync_config``: must be a dictionary containing ``etcd_server``.

    Accepted values are passed unchanged as the arguments of the control request.

    Args:
        request: Incoming HTTP request whose body holds the optional parameters.

    Returns:
        A 400 ``JSONResponse`` when the body is not valid JSON, is not a JSON
        object, or a parameter fails validation; otherwise the API JSON
        response built by the engine's control response.
    """
    request_id = f"control-{uuid.uuid4()}"

    if await request.body():
        try:
            request_data = await request.json()
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueError
            # subclasses; report a client error instead of an unhandled 500.
            return JSONResponse(
                status_code=400,
                content={"error": "Invalid request body", "message": "request body must be valid JSON"},
            )
    else:
        request_data = {}

    # Valid JSON is not necessarily an object (it may be a list, string, number
    # or null); the key lookups below are only meaningful on a dict.
    if not isinstance(request_data, dict):
        return JSONResponse(
            status_code=400,
            content={"error": "Invalid request body", "message": "request body must be a JSON object"},
        )

    args = {}

    # Validate and extract version parameter
    version = request_data.get("version")
    if version is not None:
        if not isinstance(version, str):
            return JSONResponse(
                status_code=400, content={"error": "Invalid parameter type", "message": "version must be a string"}
            )
        args["version"] = version

    # Validate and extract rsync_config parameter
    rsync_config = request_data.get("rsync_config")
    if rsync_config is not None:
        if not isinstance(rsync_config, dict):
            return JSONResponse(
                status_code=400,
                content={"error": "Invalid parameter type", "message": "rsync_config must be a dictionary"},
            )
        if "etcd_server" not in rsync_config:
            return JSONResponse(
                status_code=400,
                content={"error": "Invalid parameter type", "message": "rsync_config must contain etcd_server"},
            )
        args["rsync_config"] = rsync_config

    control_request = ControlRequest(request_id, "update_weights", args)
    control_response = await app.state.engine_client.run_control_method(control_request)
    return control_response.to_api_json_response()
|
||||
|
||||
|
||||
def wrap_streaming_generator(original_generator: AsyncGenerator):
|
||||
"""
|
||||
Wrap an async generator to release the connection semaphore when the generator is finished.
|
||||
|
||||
Reference in New Issue
Block a user