mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[RL] add pause, update_weights, resume interface for async RL (#6052)
* support dynamic run_control_request through zmq from apiserver to common_engine * support pause/resume/is_paused/update_weights in apiserver->common_engine by common run_control_method * change /is_puased from HTTP POST method to GET method * add pause、resume、is_paused implementation * support engine <==> worker communication(request&response) * support sync weights through RDMA from checkpoint_transfer * support specified version, rsync_config in update_weights rpc call * add pause, update_weights, resume interface for async RL * bug fix: update_weights support using default arguments * fix typo * typo fix * typo fix * typo fix * add unitest for control request/response, localscheduler.get_inflight_requests, resource_manager_v1.preempted_all * add "rsync" to LoadConfig.load_strategy Literal type hints Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * typo fix * typo fix * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * check version/rsync params * add error log when version.txt not exists Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * raise specified ValueError when paramters check failed Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * tp barrier after run_control_method * encode 'engine_worker_queue_port' to unique name of worker2engine fmq queue * typo fix * typo fix --------- Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -1225,7 +1225,8 @@ class LoadConfig:
|
||||
):
|
||||
self.load_choices: Union[str, LoadChoices] = LoadChoices.DEFAULT.value
|
||||
self.dynamic_load_weight: bool = False
|
||||
self.load_strategy: Optional[Literal["ipc", "ipc_snapshot", "meta", "normal"]] = "normal"
|
||||
self.load_strategy: Optional[Literal["ipc", "ipc_snapshot", "meta", "normal", "rsync"]] = "normal"
|
||||
self.rsync_config: Optional[Dict[str, Any]] = None
|
||||
for key, value in args.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
|
||||
@@ -195,6 +195,10 @@ class EngineArgs:
|
||||
"""
|
||||
dynamic load weight strategy
|
||||
"""
|
||||
rsync_config: Optional[Dict[str, Any]] = None
|
||||
"""
|
||||
rsync weights config info
|
||||
"""
|
||||
quantization: Optional[Dict[str, Any]] = None
|
||||
guided_decoding_backend: str = "off"
|
||||
"""
|
||||
@@ -812,6 +816,12 @@ class EngineArgs:
|
||||
default=EngineArgs.load_strategy,
|
||||
help="Flag to dynamic load strategy.",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--rsync-config",
|
||||
type=json.loads,
|
||||
default=EngineArgs.rsync_config,
|
||||
help="Rsync weights config",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--engine-worker-queue-port",
|
||||
type=lambda s: s.split(",") if s else None,
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
import multiprocessing
|
||||
@@ -38,7 +39,14 @@ import zmq
|
||||
from tqdm import tqdm
|
||||
|
||||
import fastdeploy.metrics.trace as tracing
|
||||
from fastdeploy.engine.request import Request, RequestOutput, RequestStatus, RequestType
|
||||
from fastdeploy.engine.request import (
|
||||
ControlRequest,
|
||||
ControlResponse,
|
||||
Request,
|
||||
RequestOutput,
|
||||
RequestStatus,
|
||||
RequestType,
|
||||
)
|
||||
from fastdeploy.engine.resource_manager import ResourceManager
|
||||
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
|
||||
from fastdeploy.eplb.utils import init_eplb_signals
|
||||
@@ -50,6 +58,7 @@ from fastdeploy.inter_communicator import (
|
||||
ZmqIpcServer,
|
||||
ZmqTcpServer,
|
||||
)
|
||||
from fastdeploy.inter_communicator.fmq import FMQ
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.model_executor.guided_decoding import schema_checker
|
||||
from fastdeploy.plugins.token_processor import load_token_processor_plugins
|
||||
@@ -89,6 +98,18 @@ class EngineService:
|
||||
else:
|
||||
self.llm_logger = llm_logger
|
||||
|
||||
self.is_paused = False # pause request generation
|
||||
self._pause_cond = threading.Condition()
|
||||
|
||||
self._ctrl_worker_output_queues = []
|
||||
tp_size = cfg.parallel_config.tensor_parallel_size
|
||||
dp_index = cfg.parallel_config.local_data_parallel_id
|
||||
for rank in range(tp_size):
|
||||
engine_worker_queue_port = self.cfg.parallel_config.local_engine_worker_queue_port
|
||||
name = f"ctrl_w2e_rank{rank+tp_size*dp_index}_{engine_worker_queue_port}"
|
||||
self.llm_logger.info(f"Init Worker Control Output Queue: {name}(consumer)")
|
||||
self._ctrl_worker_output_queues.append(FMQ().queue(name, "consumer"))
|
||||
|
||||
self.scheduler = cfg.scheduler_config.scheduler()
|
||||
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
|
||||
|
||||
@@ -758,6 +779,8 @@ class EngineService:
|
||||
|
||||
def _fetch_request():
|
||||
try:
|
||||
with self._pause_cond:
|
||||
self._pause_cond.wait_for(lambda: not self.is_paused)
|
||||
nonlocal is_fetching
|
||||
num_prefill_batch = min(
|
||||
int(self.resource_manager.available_batch()),
|
||||
@@ -922,6 +945,8 @@ class EngineService:
|
||||
is_fetching = False
|
||||
|
||||
while self.running:
|
||||
with self._pause_cond:
|
||||
self._pause_cond.wait_for(lambda: not self.is_paused)
|
||||
try:
|
||||
if self.engine_worker_queue.exist_tasks():
|
||||
time.sleep(0.001)
|
||||
@@ -1065,6 +1090,17 @@ class EngineService:
|
||||
self.recv_request_server = ZmqIpcServer(name=self.api_server_pid, mode=zmq.PULL)
|
||||
continue
|
||||
|
||||
if ControlRequest.is_control_request(data):
|
||||
try: # todo: run control request async, do not block request generation
|
||||
control_req = ControlRequest.from_dict(data)
|
||||
self.run_control_method(control_req)
|
||||
except Exception as e:
|
||||
self.llm_logger.error(
|
||||
f"Failed to process control request {data.get('request_id')}: "
|
||||
f"{e}, {traceback.format_exc()}"
|
||||
)
|
||||
continue
|
||||
|
||||
request, insert_task = data, []
|
||||
results: List[Tuple[str, Optional[str]]] = list()
|
||||
if data:
|
||||
@@ -1096,6 +1132,13 @@ class EngineService:
|
||||
trace_print(LoggingEventName.REQUEST_SCHEDULE_START, data["request_id"], data.get("user", ""))
|
||||
trace_print(LoggingEventName.REQUEST_QUEUE_START, data["request_id"], data.get("user", ""))
|
||||
self.llm_logger.debug(f"Receive request from api server: {request}")
|
||||
|
||||
if self.is_paused:
|
||||
self.llm_logger.warning(f"Engine is paused, drop request: {request}")
|
||||
self._send_error_response(
|
||||
request.request_id, "Request is aborted since LLM Engine is paused."
|
||||
)
|
||||
continue
|
||||
except Exception as e:
|
||||
self.llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}")
|
||||
err_msg = str(e)
|
||||
@@ -1135,6 +1178,200 @@ class EngineService:
|
||||
f"traceback={traceback.format_exc()}"
|
||||
)
|
||||
|
||||
def run_control_method(self, control_req: ControlRequest):
    """Dispatch a control request to its handler and reply with a ControlResponse.

    Looks up a handler named ``_control_<method>`` on this instance. When no
    such callable exists, a 400 response is sent back; when the handler raises,
    a 500 response is sent; otherwise a 200 response carrying the handler's
    result is sent. Every outcome is delivered through ``send_response_server``.

    Args:
        control_req (ControlRequest): Control request object containing request ID,
            method name and parameters.

    Returns:
        None: No return value, sends ControlResponse through send_response_server.
    """
    method = control_req.get_method()
    request_id = control_req.request_id

    try:
        self.llm_logger.info(f"START run control method {request_id}: {method}")

        # Handlers follow the naming convention _control_<method>.
        handler = getattr(self, f"_control_{method}", None)
        if not callable(handler):
            error_result = ControlResponse(request_id, 400, f"unknown control method:{method}")
            self.llm_logger.error(str(error_result))
            self.send_response_server.send_response(request_id, [error_result])
            return

        result = handler(control_req)
        self.llm_logger.info(f"SUCCESS run control method {method}.")
        self.send_response_server.send_response(request_id, [ControlResponse(request_id, 200, "Success", result)])

    except Exception as e:
        error_msg = f"Failed run control method {method}: {str(e)}"
        self.llm_logger.error(f"{error_msg}\n{traceback.format_exc()}")
        self.send_response_server.send_response(request_id, [ControlResponse(request_id, 500, error_msg)])
|
||||
|
||||
def _control_pause(self, control_request: ControlRequest):
    """Pause the LLM engine and abort all running/inflight requests.

    Sets the pause flag (blocking new request fetches), drains the engine
    worker queue, preempts every running request, waits for the worker to
    finish its in-flight work, aborts requests still queued in the
    scheduler, and finally resets the KV cache.

    Args:
        control_request: The control request containing pause command.

    Raises:
        Exception: If pause is not supported in current configuration,
            or if engine worker queue cleanup times out.

    Returns:
        None
    """
    if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
        raise Exception("pause only supported in ENABLE_V1_KVCACHE_SCHEDULER")
    if self.cfg.scheduler_config.name != "local":
        raise Exception(f"pause only supported in local scheduler, current {self.cfg.scheduler_config.name}")

    with self._pause_cond:
        if self.is_paused:
            self.llm_logger.info("Pause Request Generation: already paused.")
        self.is_paused = True

    self.llm_logger.info("Start Abort Running Requests")

    self.resource_manager.log_status()
    # preempted all running reqs. preempted reqs will be append to ResourceManager.waiting queue
    timeout = 60
    ticks = 0
    while self.engine_worker_queue.exist_tasks():
        time.sleep(0.001)
        ticks += 1
        if ticks >= timeout * 1000:
            error_msg = f"wait engine_worker_queue tasks empty timeout after {timeout} seconds, worker may Hanged"
            self.llm_logger.error(error_msg)
            raise Exception(error_msg)
    running_reqs = self.resource_manager.preempted_all()
    if len(running_reqs) > 0:
        self.llm_logger.info(f"Total {len(running_reqs)} requests need to be aborted.")
        self.resource_manager.get_real_bsz()
        self.engine_worker_queue.put_tasks((running_reqs, self.resource_manager.real_bsz))
        self.resource_manager.wait_worker_inflight_requests_finish(timeout=60)
    # self.engine_worker_queue.clear_data()
    self.token_processor.clear_data()
    self.resource_manager.log_status()

    # abort inflight requests to user
    inflight_requests = self.scheduler.get_inflight_requests()
    self.llm_logger.info(f"Start Abort Inflight Requests, total {len(inflight_requests)} waiting requests")
    for req in inflight_requests:
        self._send_error_response(req.request_id, "Request is aborted since LLM Engine is paused.")
    self.scheduler.reset()

    self.resource_manager.cache_manager.reset()
    return None
|
||||
|
||||
def _control_resume(self, control_request: ControlRequest) -> Optional[dict]:
|
||||
"""Control function for resuming request generation.
|
||||
|
||||
This method resumes the paused request generation process by setting the pause flag
|
||||
and notifying all waiting threads. It logs the start and end of the resume operation.
|
||||
|
||||
Args:
|
||||
control_request: Control request object containing resume operation information
|
||||
"""
|
||||
self.llm_logger.info("START Resume Request Generation")
|
||||
with self._pause_cond:
|
||||
if not self.is_paused:
|
||||
self.llm_logger.info("Resume Request Generation: not paused.")
|
||||
return None
|
||||
self.is_paused = False
|
||||
self._pause_cond.notify_all()
|
||||
self.llm_logger.info("END Resume Request Generation")
|
||||
return None
|
||||
|
||||
def _control_is_paused(self, control_request: ControlRequest) -> bool:
|
||||
"""
|
||||
Check if the LLM engine is in paused state.
|
||||
|
||||
Args:
|
||||
control_request: Control request object.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary containing pause status information, {'is_paused': bool}
|
||||
"""
|
||||
self.llm_logger.info(f"LLM Engine request generation is paused: {self.is_paused}")
|
||||
with self._pause_cond:
|
||||
return {"is_paused": self.is_paused}
|
||||
|
||||
def _control_update_weights(self, control_request: ControlRequest) -> Optional[dict]:
|
||||
"""Update model weights
|
||||
Args:
|
||||
control_request: Control request object containing parameters for weight updates
|
||||
|
||||
Returns:
|
||||
Optional[dict]: Returns the result dictionary if update succeeds, None otherwise
|
||||
|
||||
Raises:
|
||||
Exception: Raised when the engine is not in paused state
|
||||
"""
|
||||
self.llm_logger.info("Update Model Weights")
|
||||
with self._pause_cond:
|
||||
if self.is_paused is False:
|
||||
error_msg = "Pause LLM Engine first before calling updating weights"
|
||||
self.llm_logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
return self._call_worker(control_request, 60)
|
||||
|
||||
async def _wait_all_control_responses(self, request_id: str, timeout: int):
|
||||
"""Wait for control responses from all workers with a global timeout.
|
||||
|
||||
This method concurrently waits for responses from all control workers
|
||||
and enforces an overall timeout to avoid leaking pending tasks.
|
||||
"""
|
||||
timeout_ms = timeout * 1000
|
||||
# Create one get() coroutine per worker output queue
|
||||
tasks = [output_queue.get(timeout=timeout_ms) for output_queue in self._ctrl_worker_output_queues]
|
||||
|
||||
try:
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.gather(*tasks, return_exceptions=True),
|
||||
timeout=timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
# Keep the error message consistent with previous behavior
|
||||
raise Exception("Worker Update Weights Timeouted after 600s")
|
||||
|
||||
responses = []
|
||||
for output_queue, msg in zip(self._ctrl_worker_output_queues, results):
|
||||
if isinstance(msg, Exception):
|
||||
self.llm_logger.error(f"Call Worker Failed: {output_queue.name} {repr(msg)}")
|
||||
raise Exception(f"Call Worker error: {repr(msg)}")
|
||||
if msg is None:
|
||||
# Preserve original semantics when no message is received
|
||||
raise Exception("Worker Update Weights Timeouted after 600s")
|
||||
response: ControlResponse = msg.payload
|
||||
if response.request_id != request_id:
|
||||
self.llm_logger.info(f"ignore old control response from worker:{output_queue.name} {response}")
|
||||
continue
|
||||
if response.error_code != 200:
|
||||
self.llm_logger.info(f"Call Worker Failed: {output_queue.name} {response.error_message}")
|
||||
raise Exception(f"Call Worker error: {response.error_message}")
|
||||
self.llm_logger.info(f"Call Worker Succeed: {output_queue.name} {response.result}")
|
||||
responses.append(response.result)
|
||||
return responses
|
||||
|
||||
def _call_worker(self, control_request: ControlRequest, timeout: int):
|
||||
request_id = control_request.request_id
|
||||
self.engine_worker_queue.put_tasks(([control_request], 1))
|
||||
# Use a single asyncio.run() to concurrently wait for all worker responses.
|
||||
return asyncio.run(self._wait_all_control_responses(request_id, timeout))
|
||||
|
||||
def _send_error_response(self, request_id, error_msg, error_code: int = 500):
|
||||
self.llm_logger.error(
|
||||
f"Send error response to client, request_id: {request_id}, error_msg: {error_msg}, error_code: {error_code}"
|
||||
@@ -1708,6 +1945,7 @@ class EngineService:
|
||||
f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'"
|
||||
f" --guided_decoding_backend {self.cfg.structured_outputs_config.guided_decoding_backend}"
|
||||
f" --load_strategy {self.cfg.load_config.load_strategy}"
|
||||
f" --rsync_config '{json.dumps(self.cfg.load_config.rsync_config)}'"
|
||||
f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
|
||||
f" --reasoning_parser {self.cfg.structured_outputs_config.reasoning_parser}"
|
||||
f" --load_choices {self.cfg.load_config.load_choices}"
|
||||
|
||||
@@ -556,6 +556,7 @@ class LLMEngine:
|
||||
f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'"
|
||||
f" --guided_decoding_backend {self.cfg.structured_outputs_config.guided_decoding_backend}"
|
||||
f" --load_strategy {self.cfg.load_config.load_strategy}"
|
||||
f" --rsync_config '{json.dumps(self.cfg.load_config.rsync_config)}'"
|
||||
f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
|
||||
f" --reasoning_parser {self.cfg.structured_outputs_config.reasoning_parser}"
|
||||
f" --load_choices {self.cfg.load_config.load_choices}"
|
||||
|
||||
@@ -26,6 +26,7 @@ from typing import TypeVar as TypingTypeVar
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
@@ -99,7 +100,7 @@ class Request:
|
||||
guided_json_object: Optional[bool] = None,
|
||||
enable_thinking: Optional[bool] = None,
|
||||
reasoning_max_tokens: Optional[int] = None,
|
||||
trace_carrier: dict = dict(),
|
||||
trace_carrier: Optional[Dict[str, Any]] = None,
|
||||
dp_rank: Optional[int] = None,
|
||||
chat_template: Optional[str] = None,
|
||||
image_start: int = 0,
|
||||
@@ -544,6 +545,157 @@ class Request:
|
||||
return hasattr(self, key)
|
||||
|
||||
|
||||
class ControlRequest:
    """A generic control request that supports method and args for control operations.

    Used for system-level control operations (pause/resume/update_weights
    and similar) rather than ordinary inference requests: a flexible
    method-plus-args envelope for steering engine behavior.
    """

    def __init__(
        self,
        request_id: str,
        method: str,
        args: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Args:
            request_id: Unique identifier for the control request.
            method: The control method to execute (e.g., "reset_scheduler", "get_metrics").
            args: Optional arguments for the control method.
        """
        self.request_id = request_id
        self.method = method
        # Normalize a missing/None args payload to an empty dict.
        self.args = args if args else {}

    @classmethod
    def from_dict(cls, d: dict):
        """Create ControlRequest instance from dictionary."""
        return cls(d["request_id"], d["method"], d.get("args", {}))

    def to_dict(self) -> dict:
        """Convert ControlRequest into a serializable dict."""
        return {"request_id": self.request_id, "method": self.method, "args": self.args}

    def __repr__(self) -> str:
        """Provide a clean representation of the control request."""
        try:
            if envs.FD_DEBUG:
                # Verbose form, including args, only in debug mode.
                return (
                    f"ControlRequest("
                    f"request_id={self.request_id}, "
                    f"method={self.method}, "
                    f"args={self.args}"
                    f")"
                )
            return f"ControlRequest(request_id={self.request_id}, method={self.method})"
        except Exception as e:
            return f"<ControlRequest repr failed: {e}>"

    def get_method(self) -> str:
        """Get the control method name."""
        return self.method

    def get_args(self) -> Dict[str, Any]:
        """Get the control method arguments (a defensive copy)."""
        return self.args.copy()

    @staticmethod
    def is_control_request(d: dict) -> bool:
        """
        Check if a dictionary represents a valid ControlRequest.

        Args:
            d: Dictionary to check

        Returns:
            bool: True if the dictionary contains the required fields for a ControlRequest
        """
        if not isinstance(d, dict):
            return False

        # Both required fields must be present as strings.
        for required in ("request_id", "method"):
            if not isinstance(d.get(required), str):
                return False

        # Args is optional, but if present should be a dict.
        if "args" in d and not isinstance(d["args"], dict):
            return False

        return True
|
||||
|
||||
|
||||
class ControlResponse:
    """
    Response for control operations
    """

    def __init__(
        self,
        request_id: str,
        error_code: int = 200,
        error_message: Optional[str] = None,
        result: Optional[dict] = None,
        finished: bool = True,
    ) -> None:
        # 200 means success; any other code is treated as an error.
        self.request_id = request_id
        self.error_code = error_code
        self.error_message = error_message
        self.result = result
        self.finished = finished

    def to_dict(self) -> dict:
        """Convert ControlResponse into a serializable dict."""
        return {
            "request_id": self.request_id,
            "finished": self.finished,
            "error_code": self.error_code,
            "error_message": self.error_message,
            "result": self.result,
        }

    @classmethod
    def from_dict(cls, d: dict):
        """Create ControlResponse instance from dictionary."""
        return cls(
            d["request_id"],
            d.get("error_code", 200),
            d.get("error_message"),
            d.get("result"),
            d.get("finished", True),
        )

    def to_api_json_response(self) -> JSONResponse:
        """Convert ControlResponse into a JSONResponse."""
        status = "success" if self.error_code == 200 else "error"
        return JSONResponse(
            status_code=self.error_code,
            content={
                "request_id": self.request_id,
                "status": status,
                "error_message": self.error_message,
                "result": self.result,
            },
        )

    def __repr__(self) -> str:
        """Provide a clean representation of the control response."""
        return (
            f"ControlResponse("
            f"request_id={self.request_id}, "
            f"finished={self.finished}, "
            f"error_code={self.error_code}, "
            f"error_message={self.error_message}, "
            f"result={self.result}"
            f")"
        )
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class CompletionOutput:
|
||||
"""The output data of one completion output of a request.
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
import copy
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from collections import deque
|
||||
from collections.abc import Iterable
|
||||
@@ -240,6 +241,7 @@ class ResourceManagerV1(ResourceManager):
|
||||
|
||||
def reschedule_preempt_task(self, request_id, process_func=None):
|
||||
with self.lock:
|
||||
llm_logger.debug(f"reschedule {request_id} into waiting queue")
|
||||
if request_id in self.to_be_rescheduled_request_id_set and request_id in self.requests:
|
||||
request = self.requests[request_id]
|
||||
if process_func is not None:
|
||||
@@ -266,6 +268,39 @@ class ResourceManagerV1(ResourceManager):
|
||||
return True
|
||||
return False
|
||||
|
||||
def preempted_all(self):
    """Preempt every preemptible request currently running.

    Pops each running request, frees its blocks, resets its progress
    counters, marks it PREEMPTED, and records it for rescheduling.
    Returns the list of prepared preempt tasks to hand to the worker.
    """
    with self.lock:
        preempted = []
        for _ in range(len(self.running)):
            req = self.running.pop()
            # txt2image: req.use_extend_tables is True, req can not be preempted. txt2image is not used in RL.
            if req.use_extend_tables:
                self.running.insert(0, req)
                continue
            req.status = RequestStatus.PREEMPTED
            req.num_computed_tokens = 0
            self._free_blocks(req)
            req.cached_block_num = 0
            self.to_be_rescheduled_request_id_set.add(req.request_id)
            preempted.append(self._prepare_preempt_task(req))
        return preempted
|
||||
|
||||
def wait_worker_inflight_requests_finish(self, timeout=60):
    """Block until the worker has no in-flight requests, or timeout.

    Polls every 1ms for up to ``timeout`` seconds; in-flight work is the
    running set plus requests pending reschedule. Logs (does not raise)
    when the deadline is hit.
    """
    ticks = 0
    while ticks < timeout * 1000:
        # wait ongoing running and rescheduled requests finished in worker
        inflight = len(self.to_be_rescheduled_request_id_set) + len(self.running)
        if inflight == 0:
            break

        ticks += 1
        time.sleep(0.001)
    if ticks >= timeout * 1000:
        llm_logger.info(
            f"wait_inflight_requests_finish timeout after {timeout} seconds, "
            f"still {len(self.to_be_rescheduled_request_id_set)} requests running"
        )
|
||||
|
||||
def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_reqs):
|
||||
"""
|
||||
If the request cannot be scheduled, preempt the running request one by one until it can be scheduled. Last in, first out.
|
||||
@@ -1347,3 +1382,16 @@ class ResourceManagerV1(ResourceManager):
|
||||
main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
|
||||
main_process_metrics.num_requests_running.set(len(self.running))
|
||||
main_process_metrics.num_requests_waiting.set(num_tasks - len(self.running))
|
||||
|
||||
def log_status(self):
    """Log a one-line snapshot of the scheduler's internal state.

    NOTE(review): this dumps req_dict/requests wholesale, which can be
    very large under load — presumably intended for RL debugging; confirm
    before enabling at high request volume.
    """
    snapshot = (
        f"ResourceManagerV1( "
        f"waiting={len(self.waiting)}, "
        f"running={len(self.running)}, "
        f"preempted={len(self.to_be_rescheduled_request_id_set)}, "
        f"tasks_list={self.tasks_list}, "
        f"stop_flags={self.stop_flags}, "
        f"req_dict={self.req_dict}, "
        f"requests={self.requests}, "
        f")"
    )
    llm_logger.info(snapshot)
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
@@ -29,7 +30,12 @@ from filelock import FileLock
|
||||
import fastdeploy.metrics.trace as tracing
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.engine.request import Request, RequestStatus
|
||||
from fastdeploy.engine.request import (
|
||||
ControlRequest,
|
||||
ControlResponse,
|
||||
Request,
|
||||
RequestStatus,
|
||||
)
|
||||
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
|
||||
from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
|
||||
from fastdeploy.eplb.utils import RedundantExpertWorkload
|
||||
@@ -526,6 +532,23 @@ class EngineClient:
|
||||
|
||||
return True, ""
|
||||
|
||||
async def run_control_method(self, request: ControlRequest):
    """Send a control request to the engine over ZMQ and await its response.

    Registers a dealer connection keyed by the request id so the engine's
    reply can be routed back, then waits up to 600 seconds for the
    ControlResponse. On timeout a synthetic 500 response is returned
    instead of raising.
    """
    api_server_logger.info(f"Start Run Control Method: {request}")
    self.zmq_client.send_json(request.to_dict())
    request_id = request.request_id
    dealer, response_queue = await self.connection_manager.get_connection(request_id)
    dealer.write([b"", request_id.encode("utf-8")])
    try:
        # todo: support user specified timeout. default 600s is enough for most control cases
        raw = await asyncio.wait_for(response_queue.get(), timeout=600)
        response = ControlResponse.from_dict(raw[0])
        api_server_logger.info(f"End Run Control Method: {response}")
        return response
    except asyncio.TimeoutError:
        error_response = ControlResponse(request_id, 500, "Timeout waiting for control method response")
        api_server_logger.error(f"Error Run Control Method: {error_response}")
        return error_response
|
||||
|
||||
def is_workers_alive(self):
|
||||
"""
|
||||
Check the health of the model server by checking whether all workers are alive.
|
||||
|
||||
@@ -20,6 +20,7 @@ import signal
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@@ -38,6 +39,7 @@ from fastdeploy.engine.args_utils import EngineArgs
|
||||
from fastdeploy.engine.async_llm import AsyncLLM
|
||||
from fastdeploy.engine.engine import LLMEngine
|
||||
from fastdeploy.engine.expert_service import ExpertService
|
||||
from fastdeploy.engine.request import ControlRequest
|
||||
from fastdeploy.entrypoints.chat_utils import load_chat_template
|
||||
from fastdeploy.entrypoints.engine_client import EngineClient
|
||||
from fastdeploy.entrypoints.openai.middleware import AuthenticationMiddleware
|
||||
@@ -370,6 +372,66 @@ def ping(raw_request: Request) -> Response:
|
||||
return health(raw_request)
|
||||
|
||||
|
||||
@app.post("/v1/pause")
async def pause(request: Request) -> Response:
    """Pause engine request generation via a control request."""
    # todo: support wait_for_inflight_requests(default False), clear_cache(default True) arguments
    request_id = f"control-{uuid.uuid4()}"
    control_response = await app.state.engine_client.run_control_method(ControlRequest(request_id, "pause"))
    return control_response.to_api_json_response()
|
||||
|
||||
|
||||
@app.post("/v1/resume")
async def resume(request: Request) -> Response:
    """Resume engine request generation via a control request."""
    request_id = f"control-{uuid.uuid4()}"
    control_response = await app.state.engine_client.run_control_method(ControlRequest(request_id, "resume"))
    return control_response.to_api_json_response()
|
||||
|
||||
|
||||
@app.get("/v1/is_paused")
async def is_paused(request: Request) -> Response:
    """Report whether engine request generation is currently paused."""
    request_id = f"control-{uuid.uuid4()}"
    control_response = await app.state.engine_client.run_control_method(ControlRequest(request_id, "is_paused"))
    return control_response.to_api_json_response()
|
||||
|
||||
|
||||
@app.post("/v1/update_weights")
async def update_weights(request: Request) -> Response:
    """Trigger a model weight update, optionally pinned to a version/rsync config.

    Accepts an optional JSON body with ``version`` (str) and ``rsync_config``
    (dict containing ``etcd_server``); invalid parameters yield a 400 before
    any control request is sent to the engine.
    """
    request_id = f"control-{uuid.uuid4()}"

    body = await request.body()
    request_data = await request.json() if body else {}

    args = {}

    # Validate and extract version parameter
    version = request_data.get("version")
    if version is not None:
        if not isinstance(version, str):
            return JSONResponse(
                status_code=400, content={"error": "Invalid parameter type", "message": "version must be a string"}
            )
        args["version"] = version

    # Validate and extract rsync_config parameter
    rsync_config = request_data.get("rsync_config")
    if rsync_config is not None:
        if not isinstance(rsync_config, dict):
            return JSONResponse(
                status_code=400,
                content={"error": "Invalid parameter type", "message": "rsync_config must be a dictionary"},
            )
        if "etcd_server" not in rsync_config:
            return JSONResponse(
                status_code=400,
                content={"error": "Invalid parameter type", "message": "rsync_config must contain etcd_server"},
            )
        args["rsync_config"] = rsync_config

    control_request = ControlRequest(request_id, "update_weights", args)
    control_response = await app.state.engine_client.run_control_method(control_request)
    return control_response.to_api_json_response()
|
||||
|
||||
|
||||
def wrap_streaming_generator(original_generator: AsyncGenerator):
|
||||
"""
|
||||
Wrap an async generator to release the connection semaphore when the generator is finished.
|
||||
|
||||
@@ -214,7 +214,7 @@ class Queue(BaseComponent):
|
||||
else:
|
||||
self.socket.bind(full_ep)
|
||||
|
||||
fmq_logger.info(f"Queue {name} initialized on {full_ep}")
|
||||
fmq_logger.info(f"Queue {name}({role}) initialized on {full_ep}")
|
||||
|
||||
async def put(self, data: Any, shm_threshold: int = 1024 * 1024):
|
||||
"""
|
||||
|
||||
@@ -1034,7 +1034,7 @@ class TokenProcessor:
|
||||
finished=True,
|
||||
metrics=RequestMetrics(
|
||||
arrival_time=time.time(),
|
||||
request_start_time=task.arrival_time,
|
||||
request_start_time=task.metrics.arrival_time,
|
||||
),
|
||||
)
|
||||
is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
@@ -27,13 +28,38 @@ from fastdeploy.config import FDConfig
|
||||
from fastdeploy.inter_communicator import KVCacheStatus, ModelWeightsStatus
|
||||
|
||||
|
||||
def sync_weights_by_rdma(config, step, rank):
|
||||
from checkpoint_transfer.core import RDMAWeightsDownloader
|
||||
|
||||
downloader = RDMAWeightsDownloader(config)
|
||||
downloader.initialize()
|
||||
logger.info(f"Fetching weights for step:{step}, rank:{rank}...")
|
||||
data = downloader.get_weights(step, rank)
|
||||
if data is None:
|
||||
logger.error("Failed to get weights!")
|
||||
raise Exception("Failed to rsync weights through checkpoint_transfer")
|
||||
logger.info(f"Successfully retrieved data. Type: {type(data)}")
|
||||
if isinstance(data, np.ndarray):
|
||||
data_bytes = data.tobytes()
|
||||
elif isinstance(data, (bytes, bytearray)):
|
||||
data_bytes = data
|
||||
else:
|
||||
data_bytes = bytes(data)
|
||||
logger.info(f"Data size: {len(data_bytes)} bytes")
|
||||
|
||||
buffer = io.BytesIO(data_bytes)
|
||||
new_state_dict = paddle.load(buffer)
|
||||
return new_state_dict
|
||||
|
||||
|
||||
class DynamicWeightManager:
|
||||
"""Manages model weights loading, updating and shared state across processes."""
|
||||
|
||||
def __init__(self, fd_config: FDConfig, models):
|
||||
def __init__(self, fd_config: FDConfig, models, local_rank: int):
|
||||
"""Initialize with config and model instances."""
|
||||
self.fd_config = fd_config
|
||||
self.load_config = fd_config.load_config
|
||||
self.local_rank = local_rank
|
||||
self.parallel_config = fd_config.parallel_config
|
||||
self.state_dict: Dict[str, paddle.Tensor] = {}
|
||||
self.rank = fd_config.parallel_config.tensor_parallel_rank
|
||||
@@ -46,7 +72,10 @@ class DynamicWeightManager:
|
||||
else:
|
||||
self.model_list = models
|
||||
self._capture_model_state()
|
||||
self.update_parameters()
|
||||
if self.load_config.load_strategy == "rsync":
|
||||
self.update_weights_by_rdma()
|
||||
else:
|
||||
self.update_parameters()
|
||||
self.finalize_update()
|
||||
|
||||
logger.info(
|
||||
@@ -62,6 +91,74 @@ class DynamicWeightManager:
|
||||
logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
|
||||
self.state_dict[name] = param
|
||||
|
||||
def update_weights_by_rdma(self, version: str = None, rsync_config: Dict[str, Any] = None):
|
||||
def valid_parameters(old_state_dict, new_state_dict):
|
||||
is_valid = True
|
||||
for key in old_state_dict:
|
||||
if key not in new_state_dict:
|
||||
is_valid = False
|
||||
logger.error(f"Invalid parameter: {key} not in new_state_dict")
|
||||
elif old_state_dict[key].shape != new_state_dict[key].shape:
|
||||
is_valid = False
|
||||
logger.error(
|
||||
f"Invalid parameter: {key} shape mismatch, "
|
||||
f"new shape:{new_state_dict[key].shape}, "
|
||||
f"old shape:{old_state_dict[key].shape}"
|
||||
)
|
||||
elif old_state_dict[key].dtype != new_state_dict[key].dtype:
|
||||
is_valid = False
|
||||
logger.error(f"Invalid parameter: {key} dtype mismatch")
|
||||
return is_valid
|
||||
|
||||
if rsync_config is None:
|
||||
rsync_config = self.fd_config.load_config.rsync_config
|
||||
if rsync_config is None or len(rsync_config) == 0:
|
||||
raise Exception(
|
||||
"rsync config not set, please set it in 1) launch arguments '--rsync-config' "
|
||||
"or 2) interface arguments 'rsync_config'"
|
||||
)
|
||||
|
||||
if version is None or version == "":
|
||||
version = self.read_model_version_from_file()
|
||||
if version is None or version == "":
|
||||
raise Exception(
|
||||
"rsync model version not set, please set it in 1) {model_version}/version.txt "
|
||||
"or 2) interface arguments 'version'"
|
||||
)
|
||||
|
||||
logger.info(f"START update_weights_by_rdma, version:{version}, rsync_config:{rsync_config}")
|
||||
rank = self.local_rank
|
||||
|
||||
sync_start = time.perf_counter()
|
||||
new_state_dict = sync_weights_by_rdma(rsync_config, version, rank)
|
||||
sync_cost = time.perf_counter() - sync_start
|
||||
logger.info(f"weights sync cost {sync_cost:.2f} seconds")
|
||||
|
||||
old_state_dict = self.state_dict
|
||||
if not valid_parameters(old_state_dict, new_state_dict):
|
||||
error_msg = "Invalid new_state_dict, update parameters failed"
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
update_start = time.perf_counter()
|
||||
for name, param in old_state_dict.items():
|
||||
param.set_value(new_state_dict[name])
|
||||
update_cost = time.perf_counter() - update_start
|
||||
logger.info(f"params set value cost {update_cost:.2f} seconds")
|
||||
|
||||
total_cost = time.perf_counter() - sync_start
|
||||
logger.info(
|
||||
f"END update_weights_by_rdma, cost {total_cost:.2f} seconds"
|
||||
f" version:{version}, rsync_config: {rsync_config}",
|
||||
)
|
||||
return {
|
||||
"sync_cost": sync_cost,
|
||||
"update_cost": update_cost,
|
||||
"total_cost": total_cost,
|
||||
"version": version,
|
||||
"rank": rank,
|
||||
}
|
||||
|
||||
def update_parameters(self, pid: int = 0, restart_process_group=False) -> None:
|
||||
"""Core method to update model parameters based on strategy."""
|
||||
start_time = time.perf_counter()
|
||||
@@ -257,6 +354,17 @@ class DynamicWeightManager:
|
||||
if self.rank == 0:
|
||||
value[self.rank] = status
|
||||
|
||||
def read_model_version_from_file(self):
|
||||
model_dir = self.fd_config.model_config.model
|
||||
version_file = os.path.join(model_dir, "version.txt")
|
||||
try:
|
||||
with open(version_file, "r", encoding="utf-8") as f:
|
||||
version = f.read().strip()
|
||||
return version
|
||||
except (FileNotFoundError, OSError, IOError) as e:
|
||||
logger.error(f"Failed to read model version file '{version_file}': {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def check_model_weights_status(model_weights_status, kv_cache_status, model_runner, pid, block):
|
||||
"""
|
||||
|
||||
@@ -158,6 +158,10 @@ class LocalScheduler:
|
||||
else:
|
||||
self.ids_read_cursor -= len(expired_ids)
|
||||
|
||||
def get_inflight_requests(self) -> List[Request]:
|
||||
with self.mutex:
|
||||
return [request.raw for request in self.requests.values()]
|
||||
|
||||
def put_requests(self, requests: List[Request]) -> List[Tuple[str, Optional[str]]]:
|
||||
"""
|
||||
Add new requests to the scheduler queue.
|
||||
|
||||
@@ -20,7 +20,7 @@ import queue
|
||||
import time
|
||||
from concurrent.futures import Future
|
||||
from threading import Thread
|
||||
from typing import List, Optional, cast
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
@@ -1518,7 +1518,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
if self.fd_config.load_config.dynamic_load_weight:
|
||||
from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager
|
||||
|
||||
self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model)
|
||||
self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model, self.local_rank)
|
||||
|
||||
# 2. Load lora model
|
||||
|
||||
@@ -2798,6 +2798,9 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
|
||||
self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")
|
||||
|
||||
def update_weights(self, version: str = None, rsync_config: Dict[str, Any] = None):
|
||||
return self.dynamic_weight_manager.update_weights_by_rdma(version, rsync_config)
|
||||
|
||||
def padding_cudagraph_inputs(self) -> None:
|
||||
"""
|
||||
Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch.
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
import gc
|
||||
import time
|
||||
from typing import List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import paddle
|
||||
import pynvml
|
||||
@@ -188,6 +188,10 @@ class GpuWorker(WorkerBase):
|
||||
# accurate cache size
|
||||
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
|
||||
|
||||
def update_weights(self, version: str = None, rsync_config: Dict[str, Any] = None):
|
||||
"""update weights in place"""
|
||||
return self.model_runner.update_weights(version, rsync_config)
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
model_forward_batch: Optional[List[Request]] = None,
|
||||
|
||||
@@ -15,9 +15,11 @@
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
@@ -42,7 +44,7 @@ from fastdeploy.config import (
|
||||
SpeculativeConfig,
|
||||
StructuredOutputsConfig,
|
||||
)
|
||||
from fastdeploy.engine.request import RequestType
|
||||
from fastdeploy.engine.request import ControlRequest, ControlResponse, RequestType
|
||||
from fastdeploy.eplb.async_expert_loader import (
|
||||
MODEL_MAIN_NAME,
|
||||
REARRANGE_EXPERT_MAGIC_NUM,
|
||||
@@ -57,6 +59,7 @@ from fastdeploy.inter_communicator import (
|
||||
ModelWeightsStatus,
|
||||
RearrangeExpertStatus,
|
||||
)
|
||||
from fastdeploy.inter_communicator.fmq import FMQ
|
||||
from fastdeploy.model_executor.layers.quantization import parse_quant_config
|
||||
from fastdeploy.model_executor.utils import v1_loader_support
|
||||
from fastdeploy.platforms import current_platform
|
||||
@@ -164,6 +167,12 @@ class PaddleDisWorkerProc:
|
||||
|
||||
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
|
||||
|
||||
def init_control(self):
|
||||
engine_worker_queue_port = self.parallel_config.local_engine_worker_queue_port
|
||||
queue_name = f"ctrl_w2e_rank{self.local_rank}_{engine_worker_queue_port}"
|
||||
logger.info(f"Init Control Output Queue: {queue_name}(producer)")
|
||||
self._ctrl_output = FMQ().queue(queue_name, "producer")
|
||||
|
||||
def init_health_status(self) -> None:
|
||||
"""
|
||||
Initialize the health status of the worker.
|
||||
@@ -513,10 +522,20 @@ class PaddleDisWorkerProc:
|
||||
else:
|
||||
self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY
|
||||
|
||||
req_dicts = []
|
||||
req_dicts, control_reqs = [], []
|
||||
for req_dict, bsz in tasks:
|
||||
max_occupied_batch_index = int(bsz)
|
||||
req_dicts.extend(req_dict)
|
||||
if len(req_dict) > 0 and isinstance(req_dict[0], ControlRequest):
|
||||
control_reqs.append(req_dict[0])
|
||||
else:
|
||||
max_occupied_batch_index = int(bsz)
|
||||
req_dicts.extend(req_dict)
|
||||
|
||||
# todo: run control request async
|
||||
if len(control_reqs) > 0:
|
||||
logger.info(f"Rank: {self.local_rank} received {len(control_reqs)} control request.")
|
||||
for control_req in control_reqs:
|
||||
self.run_control_method(control_req)
|
||||
self._tp_barrier_wait() if tp_size > 1 else None
|
||||
|
||||
# Count prefill requests in current batch
|
||||
num_prefill_requests = sum(1 for req in req_dicts if req.task_type == RequestType.PREFILL)
|
||||
@@ -655,6 +674,32 @@ class PaddleDisWorkerProc:
|
||||
paddle.distributed.barrier()
|
||||
self.loaded_model_signal.value[0] = 1
|
||||
|
||||
def run_control_method(self, control_request: ControlRequest) -> None:
|
||||
logger.info(f"Start run control request: {control_request}")
|
||||
request_id = control_request.request_id
|
||||
method = control_request.method
|
||||
kwargs = control_request.args
|
||||
|
||||
handler = getattr(self.worker, method, None)
|
||||
if handler is None or not callable(handler):
|
||||
error_msg = f"Rank-{self.local_rank}: Unknown control method {method}"
|
||||
error_result = ControlResponse(request_id, 400, error_msg)
|
||||
asyncio.run(self._ctrl_output.put(error_result))
|
||||
return
|
||||
|
||||
try:
|
||||
result = handler(**kwargs)
|
||||
succ_result = ControlResponse(request_id, 200, "Success", result)
|
||||
logger.info(
|
||||
f"Rank-{self.local_rank} Success run control request: {control_request}, response: {succ_result}"
|
||||
)
|
||||
asyncio.run(self._ctrl_output.put(succ_result, shm_threshold=100 * 1024 * 1024))
|
||||
except Exception as e:
|
||||
error_msg = f"Rank-{self.local_rank} Failed run control method {method}: {str(e)}"
|
||||
logger.error(f"{error_msg}\n{traceback.format_exc()}")
|
||||
error_result = ControlResponse(request_id, 500, error_msg)
|
||||
asyncio.run(self._ctrl_output.put(error_result))
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""
|
||||
@@ -813,12 +858,18 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--load_strategy",
|
||||
type=str,
|
||||
choices=["ipc", "ipc_snapshot", "meta", "normal"],
|
||||
choices=["ipc", "ipc_snapshot", "meta", "normal", "rsync"],
|
||||
default="ipc_snapshot",
|
||||
help="Weight loading method when dynamic loading is enabled: "
|
||||
"'ipc': real-time IPC streaming with automatic resharding, "
|
||||
"'ipc_snapshot': load from disk snapshot of IPC weights.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rsync_config",
|
||||
type=json.loads,
|
||||
default=None,
|
||||
help="Rsync weights config",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_logprob",
|
||||
action="store_true",
|
||||
@@ -1045,6 +1096,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
|
||||
|
||||
logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
|
||||
logger.info(f"- Load strategy: {load_config.load_strategy}")
|
||||
logger.info(f"- Rsync config: {load_config.rsync_config}, {type(load_config.rsync_config)}")
|
||||
|
||||
if not (
|
||||
current_platform.is_cuda()
|
||||
@@ -1112,6 +1164,7 @@ def run_worker_proc() -> None:
|
||||
worker_proc = IluvatarPaddleDisWorkerProc(fd_config, ranks, local_rank)
|
||||
else:
|
||||
worker_proc = PaddleDisWorkerProc(fd_config, ranks, local_rank)
|
||||
worker_proc.init_control()
|
||||
|
||||
# Initialize device and create model runner
|
||||
worker_proc.init_device()
|
||||
|
||||
@@ -0,0 +1,329 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from fastdeploy.engine.request import ControlRequest, ControlResponse
|
||||
|
||||
|
||||
class TestControlRequest(unittest.TestCase):
|
||||
"""Test cases for ControlRequest class."""
|
||||
|
||||
def test_initialization_basic(self):
|
||||
"""Test basic initialization of ControlRequest."""
|
||||
request_id = "test_request_123"
|
||||
method = "get_metrics"
|
||||
|
||||
request = ControlRequest(request_id=request_id, method=method)
|
||||
|
||||
self.assertEqual(request.request_id, request_id)
|
||||
self.assertEqual(request.method, method)
|
||||
self.assertEqual(request.args, {})
|
||||
|
||||
def test_initialization_with_args(self):
|
||||
"""Test initialization with arguments."""
|
||||
request_id = "test_request_456"
|
||||
method = "reset_scheduler"
|
||||
args = {"force": True, "timeout": 30}
|
||||
|
||||
request = ControlRequest(request_id=request_id, method=method, args=args)
|
||||
|
||||
self.assertEqual(request.request_id, request_id)
|
||||
self.assertEqual(request.method, method)
|
||||
self.assertEqual(request.args, args)
|
||||
|
||||
def test_from_dict_basic(self):
|
||||
"""Test creating ControlRequest from dictionary (basic case)."""
|
||||
data = {"request_id": "test_from_dict", "method": "status_check"}
|
||||
|
||||
request = ControlRequest.from_dict(data)
|
||||
|
||||
self.assertEqual(request.request_id, data["request_id"])
|
||||
self.assertEqual(request.method, data["method"])
|
||||
self.assertEqual(request.args, {})
|
||||
|
||||
def test_from_dict_with_args(self):
|
||||
"""Test creating ControlRequest from dictionary with arguments."""
|
||||
data = {
|
||||
"request_id": "test_from_dict_args",
|
||||
"method": "configure",
|
||||
"args": {"max_requests": 100, "queue_timeout": 60},
|
||||
}
|
||||
|
||||
request = ControlRequest.from_dict(data)
|
||||
|
||||
self.assertEqual(request.request_id, data["request_id"])
|
||||
self.assertEqual(request.method, data["method"])
|
||||
self.assertEqual(request.args, data["args"])
|
||||
|
||||
def test_to_dict_basic(self):
|
||||
"""Test converting ControlRequest to dictionary (basic case)."""
|
||||
request = ControlRequest(request_id="test_to_dict", method="health_check")
|
||||
|
||||
result = request.to_dict()
|
||||
|
||||
expected = {"request_id": "test_to_dict", "method": "health_check", "args": {}}
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_to_dict_with_args(self):
|
||||
"""Test converting ControlRequest to dictionary with arguments."""
|
||||
args = {"setting1": "value1", "setting2": 42}
|
||||
request = ControlRequest(request_id="test_to_dict_args", method="update_settings", args=args)
|
||||
|
||||
result = request.to_dict()
|
||||
|
||||
expected = {"request_id": "test_to_dict_args", "method": "update_settings", "args": args}
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_get_method(self):
|
||||
"""Test get_method method."""
|
||||
method = "custom_operation"
|
||||
request = ControlRequest(request_id="test", method=method)
|
||||
|
||||
self.assertEqual(request.get_method(), method)
|
||||
|
||||
def test_get_args(self):
|
||||
"""Test get_args method."""
|
||||
args = {"param1": "value1", "param2": 123}
|
||||
request = ControlRequest(request_id="test", method="test", args=args)
|
||||
|
||||
result_args = request.get_args()
|
||||
|
||||
self.assertEqual(result_args, args)
|
||||
# Ensure it returns a copy, not the original dict
|
||||
self.assertIsNot(result_args, args)
|
||||
|
||||
def test_is_control_request_valid(self):
|
||||
"""Test is_control_request method with valid data."""
|
||||
valid_data = [
|
||||
{"request_id": "test1", "method": "method1"},
|
||||
{"request_id": "test2", "method": "method2", "args": {}},
|
||||
{"request_id": "test3", "method": "method3", "args": {"key": "value"}},
|
||||
]
|
||||
|
||||
for data in valid_data:
|
||||
with self.subTest(data=data):
|
||||
self.assertTrue(ControlRequest.is_control_request(data))
|
||||
|
||||
def test_is_control_request_invalid(self):
|
||||
"""Test is_control_request method with invalid data."""
|
||||
invalid_data = [
|
||||
# Missing required fields
|
||||
{"method": "test"}, # missing request_id
|
||||
{"request_id": "test"}, # missing method
|
||||
# Wrong field types
|
||||
{"request_id": 123, "method": "test"}, # request_id not string
|
||||
{"request_id": "test", "method": 456}, # method not string
|
||||
{"request_id": "test", "method": "test", "args": "not_a_dict"}, # args not dict
|
||||
# Not a dict
|
||||
"not_a_dict",
|
||||
123,
|
||||
None,
|
||||
]
|
||||
|
||||
for data in invalid_data:
|
||||
with self.subTest(data=data):
|
||||
self.assertFalse(ControlRequest.is_control_request(data))
|
||||
|
||||
def test_repr_simple(self):
|
||||
"""Test __repr__ method in simple mode."""
|
||||
with patch("fastdeploy.envs.FD_DEBUG", False):
|
||||
request = ControlRequest(request_id="test_repr", method="test_method")
|
||||
repr_str = repr(request)
|
||||
|
||||
self.assertIn("ControlRequest", repr_str)
|
||||
self.assertIn("test_repr", repr_str)
|
||||
self.assertIn("test_method", repr_str)
|
||||
self.assertNotIn("args", repr_str) # Args not shown in simple mode
|
||||
|
||||
def test_repr_debug_mode(self):
|
||||
"""Test __repr__ method in debug mode."""
|
||||
with patch("fastdeploy.envs.FD_DEBUG", True):
|
||||
args = {"debug_param": "debug_value"}
|
||||
request = ControlRequest(request_id="test_repr", method="test_method", args=args)
|
||||
repr_str = repr(request)
|
||||
|
||||
self.assertIn("ControlRequest", repr_str)
|
||||
self.assertIn("test_repr", repr_str)
|
||||
self.assertIn("test_method", repr_str)
|
||||
self.assertIn("debug_param", repr_str) # Args shown in debug mode
|
||||
|
||||
|
||||
class TestControlResponse(unittest.TestCase):
|
||||
"""Test cases for ControlResponse class."""
|
||||
|
||||
def test_initialization_basic(self):
|
||||
"""Test basic initialization of ControlResponse."""
|
||||
request_id = "test_response_123"
|
||||
|
||||
response = ControlResponse(request_id=request_id)
|
||||
|
||||
self.assertEqual(response.request_id, request_id)
|
||||
self.assertEqual(response.error_code, 200)
|
||||
self.assertIsNone(response.error_message)
|
||||
self.assertIsNone(response.result)
|
||||
self.assertTrue(response.finished)
|
||||
|
||||
def test_initialization_with_all_params(self):
|
||||
"""Test initialization with all parameters."""
|
||||
request_id = "test_response_456"
|
||||
error_code = 404
|
||||
error_message = "Not found"
|
||||
result = {"data": "some_result"}
|
||||
finished = False
|
||||
|
||||
response = ControlResponse(
|
||||
request_id=request_id, error_code=error_code, error_message=error_message, result=result, finished=finished
|
||||
)
|
||||
|
||||
self.assertEqual(response.request_id, request_id)
|
||||
self.assertEqual(response.error_code, error_code)
|
||||
self.assertEqual(response.error_message, error_message)
|
||||
self.assertEqual(response.result, result)
|
||||
self.assertEqual(response.finished, finished)
|
||||
|
||||
def test_initialization_error_cases(self):
|
||||
"""Test initialization with various error codes."""
|
||||
test_cases = [
|
||||
(200, None, True), # Success case
|
||||
(400, "Bad Request", False), # Client error
|
||||
(500, "Internal Error", True), # Server error
|
||||
]
|
||||
|
||||
for error_code, error_message, finished in test_cases:
|
||||
with self.subTest(error_code=error_code):
|
||||
response = ControlResponse(
|
||||
request_id="test", error_code=error_code, error_message=error_message, finished=finished
|
||||
)
|
||||
|
||||
self.assertEqual(response.error_code, error_code)
|
||||
self.assertEqual(response.error_message, error_message)
|
||||
self.assertEqual(response.finished, finished)
|
||||
|
||||
def test_from_dict_basic(self):
|
||||
"""Test creating ControlResponse from dictionary (basic case)."""
|
||||
data = {"request_id": "test_from_dict"}
|
||||
|
||||
response = ControlResponse.from_dict(data)
|
||||
|
||||
self.assertEqual(response.request_id, data["request_id"])
|
||||
self.assertEqual(response.error_code, 200)
|
||||
self.assertIsNone(response.error_message)
|
||||
self.assertIsNone(response.result)
|
||||
self.assertTrue(response.finished)
|
||||
|
||||
def test_from_dict_with_all_fields(self):
|
||||
"""Test creating ControlResponse from dictionary with all fields."""
|
||||
data = {
|
||||
"request_id": "test_from_dict_full",
|
||||
"error_code": 500,
|
||||
"error_message": "Test error",
|
||||
"result": {"key": "value"},
|
||||
"finished": False,
|
||||
}
|
||||
|
||||
response = ControlResponse.from_dict(data)
|
||||
|
||||
self.assertEqual(response.request_id, data["request_id"])
|
||||
self.assertEqual(response.error_code, data["error_code"])
|
||||
self.assertEqual(response.error_message, data["error_message"])
|
||||
self.assertEqual(response.result, data["result"])
|
||||
self.assertEqual(response.finished, data["finished"])
|
||||
|
||||
def test_to_dict_basic(self):
|
||||
"""Test converting ControlResponse to dictionary (basic case)."""
|
||||
response = ControlResponse(request_id="test_to_dict")
|
||||
|
||||
result = response.to_dict()
|
||||
|
||||
expected = {
|
||||
"request_id": "test_to_dict",
|
||||
"finished": True,
|
||||
"error_code": 200,
|
||||
"error_message": None,
|
||||
"result": None,
|
||||
}
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_to_dict_with_all_fields(self):
|
||||
"""Test converting ControlResponse to dictionary with all fields."""
|
||||
response = ControlResponse(
|
||||
request_id="test_to_dict_full",
|
||||
error_code=400,
|
||||
error_message="Validation failed",
|
||||
result={"valid": False, "reason": "missing_field"},
|
||||
finished=False,
|
||||
)
|
||||
|
||||
result = response.to_dict()
|
||||
|
||||
expected = {
|
||||
"request_id": "test_to_dict_full",
|
||||
"finished": False,
|
||||
"error_code": 400,
|
||||
"error_message": "Validation failed",
|
||||
"result": {"valid": False, "reason": "missing_field"},
|
||||
}
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_to_api_json_response_success(self):
|
||||
"""Test converting to JSONResponse for successful response."""
|
||||
result_data = {"metrics": {"cpu_usage": 0.5, "memory_used": 1024}}
|
||||
response = ControlResponse(request_id="test_json_success", result=result_data)
|
||||
|
||||
json_response = response.to_api_json_response()
|
||||
|
||||
self.assertIsInstance(json_response, JSONResponse)
|
||||
self.assertEqual(json_response.status_code, 200)
|
||||
|
||||
content = json_response.body.decode("utf-8")
|
||||
self.assertIn("success", content)
|
||||
self.assertIn("test_json_success", content)
|
||||
self.assertIn("cpu_usage", content)
|
||||
|
||||
def test_to_api_json_response_error(self):
|
||||
"""Test converting to JSONResponse for error response."""
|
||||
response = ControlResponse(request_id="test_json_error", error_code=503, error_message="Service unavailable")
|
||||
|
||||
json_response = response.to_api_json_response()
|
||||
|
||||
self.assertIsInstance(json_response, JSONResponse)
|
||||
self.assertEqual(json_response.status_code, 503)
|
||||
|
||||
content = json_response.body.decode("utf-8")
|
||||
self.assertIn("error", content)
|
||||
self.assertIn("test_json_error", content)
|
||||
self.assertIn("Service unavailable", content)
|
||||
|
||||
def test_repr_method(self):
|
||||
"""Test __repr__ method."""
|
||||
response = ControlResponse(
|
||||
request_id="test_repr", error_code=200, error_message=None, result={"data": "test"}, finished=True
|
||||
)
|
||||
|
||||
repr_str = repr(response)
|
||||
|
||||
# Check that all important fields are represented
|
||||
self.assertIn("ControlResponse", repr_str)
|
||||
self.assertIn("test_repr", repr_str)
|
||||
self.assertIn("200", repr_str)
|
||||
self.assertIn("test", repr_str) # from result
|
||||
self.assertIn("True", repr_str) # finished flag
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1,104 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import Mock
|
||||
|
||||
from fastdeploy.engine.args_utils import EngineArgs
|
||||
from fastdeploy.engine.request import Request, RequestStatus
|
||||
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
|
||||
|
||||
MODEL_NAME = os.getenv("MODEL_PATH", "/path/to/models") + "/ERNIE-4.5-0.3B-Paddle"
|
||||
|
||||
|
||||
class TestResourceManagerV1(unittest.TestCase):
|
||||
"""Test cases for ResourceManagerV1."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
engine_args = EngineArgs(
|
||||
model=MODEL_NAME,
|
||||
max_model_len=8192,
|
||||
tensor_parallel_size=1,
|
||||
engine_worker_queue_port=int(os.getenv("FD_ENGINE_QUEUE_PORT", "6778")),
|
||||
cache_queue_port=int(os.getenv("FD_CACHE_QUEUE_PORT", "6779")),
|
||||
)
|
||||
# Create and start the engine service
|
||||
mock_config = engine_args.create_engine_config()
|
||||
|
||||
self.manager = ResourceManagerV1(
|
||||
max_num_seqs=4,
|
||||
config=mock_config,
|
||||
tensor_parallel_size=1,
|
||||
splitwise_role="mixed",
|
||||
local_data_parallel_id=0,
|
||||
)
|
||||
|
||||
# Mock cache manager
|
||||
self.manager.cache_manager = Mock()
|
||||
self.manager.cache_manager.free_blocks = Mock()
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.manager.need_block_num_signal.clear()
|
||||
|
||||
def test_preempted_all_with_no_running_requests(self):
|
||||
"""Test preempted_all with no running requests."""
|
||||
self.assertEqual(len(self.manager.running), 0)
|
||||
preempted_reqs = self.manager.preempted_all()
|
||||
self.assertEqual(len(preempted_reqs), 0)
|
||||
|
||||
def test_preempted_all_with_normal_requests(self):
|
||||
"""Test preempted_all with normal running requests."""
|
||||
# Add mock running requests
|
||||
req1 = Mock(spec=Request)
|
||||
req1.request_id = "req1"
|
||||
req1.use_extend_tables = False
|
||||
req1.status = RequestStatus.RUNNING
|
||||
req1.block_tables = [1, 2, 3]
|
||||
req1.num_cached_blocks = 0
|
||||
req1.idx = 0
|
||||
|
||||
req2 = Mock(spec=Request)
|
||||
req2.request_id = "req2"
|
||||
req2.use_extend_tables = False
|
||||
req2.status = RequestStatus.RUNNING
|
||||
req2.block_tables = [4, 5]
|
||||
req2.num_cached_blocks = 0
|
||||
req2.idx = 1
|
||||
|
||||
self.manager.running = [req1, req2]
|
||||
|
||||
preempted_reqs = self.manager.preempted_all()
|
||||
|
||||
# Verify
|
||||
self.assertEqual(len(preempted_reqs), 2)
|
||||
self.assertEqual(preempted_reqs[0].request_id, "req2")
|
||||
self.assertEqual(preempted_reqs[1].request_id, "req1")
|
||||
|
||||
# Verify request status changed
|
||||
self.assertEqual(req1.status, RequestStatus.PREEMPTED)
|
||||
self.assertEqual(req2.status, RequestStatus.PREEMPTED)
|
||||
|
||||
# Verify added to to_be_rescheduled_request_id_set
|
||||
self.assertIn("req1", self.manager.to_be_rescheduled_request_id_set)
|
||||
self.assertIn("req2", self.manager.to_be_rescheduled_request_id_set)
|
||||
|
||||
self.assertEqual(len(self.manager.running), 0)
|
||||
self.assertEqual(len(self.manager.waiting), 0)
|
||||
self.assertEqual(len(self.manager.to_be_rescheduled_request_id_set), 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -246,6 +246,20 @@ class TestLocalScheduler(unittest.TestCase):
|
||||
# Verify only one request exists in scheduler
|
||||
self.assertEqual(len(self.scheduler.requests), 1)
|
||||
|
||||
def test_get_inflight_requests(self):
|
||||
"""Test getting inflight requests."""
|
||||
# Add some requests
|
||||
requests = [self.mock_request_1, self.mock_request_2]
|
||||
self.scheduler.put_requests(requests)
|
||||
|
||||
# Get inflight requests
|
||||
inflight_requests = self.scheduler.get_inflight_requests()
|
||||
|
||||
# Verify correct requests are returned
|
||||
self.assertEqual(len(inflight_requests), len(requests))
|
||||
for req in inflight_requests:
|
||||
self.assertIn(req, requests)
|
||||
|
||||
def test_put_requests_max_size_limit(self):
|
||||
"""Test that max size limit is enforced."""
|
||||
# Create scheduler with small max size
|
||||
|
||||
Reference in New Issue
Block a user