[RL] add pause, update_weights, resume interface for async RL (#6052)

* support dynamic run_control_request through zmq from apiserver to common_engine

* support pause/resume/is_paused/update_weights in apiserver->common_engine by common run_control_method

* change /is_paused from HTTP POST method to GET method

* add pause, resume, is_paused implementation

* support engine <==> worker communication(request&response)

* support sync weights through RDMA from checkpoint_transfer

* support specified version, rsync_config in update_weights rpc call

* add pause, update_weights, resume interface for async RL

* bug fix: update_weights support using default arguments

* fix typo

* typo fix

* typo fix

* typo fix

* add unit tests for control request/response, localscheduler.get_inflight_requests, resource_manager_v1.preempted_all

* add "rsync" to LoadConfig.load_strategy Literal type hints

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* typo fix

* typo fix

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* check version/rsync params

* add error log when version.txt not exists

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* raise a specific ValueError when parameter checks fail

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* tp barrier after run_control_method

* encode 'engine_worker_queue_port' to unique name of worker2engine fmq queue

* typo fix

* typo fix

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
wangyifei
2026-01-23 10:18:07 +08:00
committed by GitHub
parent 96b2cf2c20
commit b7c5daa316
18 changed files with 1170 additions and 16 deletions
+2 -1
View File
@@ -1225,7 +1225,8 @@ class LoadConfig:
): ):
self.load_choices: Union[str, LoadChoices] = LoadChoices.DEFAULT.value self.load_choices: Union[str, LoadChoices] = LoadChoices.DEFAULT.value
self.dynamic_load_weight: bool = False self.dynamic_load_weight: bool = False
self.load_strategy: Optional[Literal["ipc", "ipc_snapshot", "meta", "normal"]] = "normal" self.load_strategy: Optional[Literal["ipc", "ipc_snapshot", "meta", "normal", "rsync"]] = "normal"
self.rsync_config: Optional[Dict[str, Any]] = None
for key, value in args.items(): for key, value in args.items():
if hasattr(self, key): if hasattr(self, key):
setattr(self, key, value) setattr(self, key, value)
+10
View File
@@ -195,6 +195,10 @@ class EngineArgs:
""" """
dynamic load weight strategy dynamic load weight strategy
""" """
rsync_config: Optional[Dict[str, Any]] = None
"""
rsync weights config info
"""
quantization: Optional[Dict[str, Any]] = None quantization: Optional[Dict[str, Any]] = None
guided_decoding_backend: str = "off" guided_decoding_backend: str = "off"
""" """
@@ -812,6 +816,12 @@ class EngineArgs:
default=EngineArgs.load_strategy, default=EngineArgs.load_strategy,
help="Flag to dynamic load strategy.", help="Flag to dynamic load strategy.",
) )
model_group.add_argument(
"--rsync-config",
type=json.loads,
default=EngineArgs.rsync_config,
help="Rsync weights config",
)
model_group.add_argument( model_group.add_argument(
"--engine-worker-queue-port", "--engine-worker-queue-port",
type=lambda s: s.split(",") if s else None, type=lambda s: s.split(",") if s else None,
+239 -1
View File
@@ -16,6 +16,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import copy import copy
import json import json
import multiprocessing import multiprocessing
@@ -38,7 +39,14 @@ import zmq
from tqdm import tqdm from tqdm import tqdm
import fastdeploy.metrics.trace as tracing import fastdeploy.metrics.trace as tracing
from fastdeploy.engine.request import Request, RequestOutput, RequestStatus, RequestType from fastdeploy.engine.request import (
ControlRequest,
ControlResponse,
Request,
RequestOutput,
RequestStatus,
RequestType,
)
from fastdeploy.engine.resource_manager import ResourceManager from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1 from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
from fastdeploy.eplb.utils import init_eplb_signals from fastdeploy.eplb.utils import init_eplb_signals
@@ -50,6 +58,7 @@ from fastdeploy.inter_communicator import (
ZmqIpcServer, ZmqIpcServer,
ZmqTcpServer, ZmqTcpServer,
) )
from fastdeploy.inter_communicator.fmq import FMQ
from fastdeploy.metrics.metrics import main_process_metrics from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.model_executor.guided_decoding import schema_checker from fastdeploy.model_executor.guided_decoding import schema_checker
from fastdeploy.plugins.token_processor import load_token_processor_plugins from fastdeploy.plugins.token_processor import load_token_processor_plugins
@@ -89,6 +98,18 @@ class EngineService:
else: else:
self.llm_logger = llm_logger self.llm_logger = llm_logger
self.is_paused = False # pause request generation
self._pause_cond = threading.Condition()
self._ctrl_worker_output_queues = []
tp_size = cfg.parallel_config.tensor_parallel_size
dp_index = cfg.parallel_config.local_data_parallel_id
for rank in range(tp_size):
engine_worker_queue_port = self.cfg.parallel_config.local_engine_worker_queue_port
name = f"ctrl_w2e_rank{rank+tp_size*dp_index}_{engine_worker_queue_port}"
self.llm_logger.info(f"Init Worker Control Output Queue: {name}(consumer)")
self._ctrl_worker_output_queues.append(FMQ().queue(name, "consumer"))
self.scheduler = cfg.scheduler_config.scheduler() self.scheduler = cfg.scheduler_config.scheduler()
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1" self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
@@ -758,6 +779,8 @@ class EngineService:
def _fetch_request(): def _fetch_request():
try: try:
with self._pause_cond:
self._pause_cond.wait_for(lambda: not self.is_paused)
nonlocal is_fetching nonlocal is_fetching
num_prefill_batch = min( num_prefill_batch = min(
int(self.resource_manager.available_batch()), int(self.resource_manager.available_batch()),
@@ -922,6 +945,8 @@ class EngineService:
is_fetching = False is_fetching = False
while self.running: while self.running:
with self._pause_cond:
self._pause_cond.wait_for(lambda: not self.is_paused)
try: try:
if self.engine_worker_queue.exist_tasks(): if self.engine_worker_queue.exist_tasks():
time.sleep(0.001) time.sleep(0.001)
@@ -1065,6 +1090,17 @@ class EngineService:
self.recv_request_server = ZmqIpcServer(name=self.api_server_pid, mode=zmq.PULL) self.recv_request_server = ZmqIpcServer(name=self.api_server_pid, mode=zmq.PULL)
continue continue
if ControlRequest.is_control_request(data):
try: # todo: run control request async, do not block request generation
control_req = ControlRequest.from_dict(data)
self.run_control_method(control_req)
except Exception as e:
self.llm_logger.error(
f"Failed to process control request {data.get('request_id')}: "
f"{e}, {traceback.format_exc()}"
)
continue
request, insert_task = data, [] request, insert_task = data, []
results: List[Tuple[str, Optional[str]]] = list() results: List[Tuple[str, Optional[str]]] = list()
if data: if data:
@@ -1096,6 +1132,13 @@ class EngineService:
trace_print(LoggingEventName.REQUEST_SCHEDULE_START, data["request_id"], data.get("user", "")) trace_print(LoggingEventName.REQUEST_SCHEDULE_START, data["request_id"], data.get("user", ""))
trace_print(LoggingEventName.REQUEST_QUEUE_START, data["request_id"], data.get("user", "")) trace_print(LoggingEventName.REQUEST_QUEUE_START, data["request_id"], data.get("user", ""))
self.llm_logger.debug(f"Receive request from api server: {request}") self.llm_logger.debug(f"Receive request from api server: {request}")
if self.is_paused:
self.llm_logger.warning(f"Engine is paused, drop request: {request}")
self._send_error_response(
request.request_id, "Request is aborted since LLM Engine is paused."
)
continue
except Exception as e: except Exception as e:
self.llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}") self.llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}")
err_msg = str(e) err_msg = str(e)
@@ -1135,6 +1178,200 @@ class EngineService:
f"traceback={traceback.format_exc()}" f"traceback={traceback.format_exc()}"
) )
def run_control_method(self, control_req: ControlRequest):
    """
    Dispatch a control request to its handler and reply with a ControlResponse.

    Looks up a handler named ``_control_<method>`` on this service. An unknown
    method yields a 400 response, a handler exception yields a 500 response,
    and success yields a 200 response carrying the handler's result.

    Args:
        control_req (ControlRequest): Control request with request ID,
            method name and parameters.

    Returns:
        None: the outcome is delivered through ``send_response_server``.
    """
    method = control_req.get_method()
    request_id = control_req.request_id
    try:
        self.llm_logger.info(f"START run control method {request_id}: {method}")
        handler = getattr(self, f"_control_{method}", None)
        # getattr default None fails callable(), covering the missing case too.
        if not callable(handler):
            error_result = ControlResponse(request_id, 400, f"unknown control method:{method}")
            self.llm_logger.error(str(error_result))
            self.send_response_server.send_response(request_id, [error_result])
            return
        result = handler(control_req)
        self.llm_logger.info(f"SUCCESS run control method {method}.")
        succ_result = ControlResponse(request_id, 200, "Success", result)
        self.send_response_server.send_response(request_id, [succ_result])
    except Exception as e:
        error_msg = f"Failed run control method {method}: {str(e)}"
        self.llm_logger.error(f"{error_msg}\n{traceback.format_exc()}")
        error_result = ControlResponse(request_id, 500, error_msg)
        self.send_response_server.send_response(request_id, [error_result])
def _control_pause(self, control_request: ControlRequest):
    """Pauses the LLM engine and aborts all running/inflight requests.

    Args:
        control_request: The control request containing pause command

    Raises:
        Exception: If pause is not supported in current configuration
        Exception: If engine worker queue cleanup times out

    Returns:
        None
    """
    # Pausing relies on the v1 KV-cache scheduler's preemption support and the
    # local scheduler's ability to enumerate inflight requests.
    if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
        raise Exception("pause only supported in ENABLE_V1_KVCACHE_SCHEDULER")
    if self.cfg.scheduler_config.name != "local":
        raise Exception(f"pause only supported in local scheduler, current {self.cfg.scheduler_config.name}")
    with self._pause_cond:
        if self.is_paused:
            # NOTE(review): this branch only logs and then runs the full abort
            # sequence again; presumably re-pausing is meant to be idempotent —
            # confirm whether an early return was intended.
            self.llm_logger.info("Pause Request Generation: already paused.")
        self.is_paused = True
        self.llm_logger.info("Start Abort Running Requests")
        self.resource_manager.log_status()
        # preempted all running reqs. preempted reqs will be append to ResourceManager.waiting queue
        # Busy-wait in 1 ms steps until the engine->worker queue drains; a
        # hung worker would otherwise block the pause forever.
        timeout, count = 60, 0
        while self.engine_worker_queue.exist_tasks():
            time.sleep(0.001)
            count += 1
            if count >= timeout * 1000:
                break
        if count >= timeout * 1000:
            error_msg = f"wait engine_worker_queue tasks empty timeout after {timeout} seconds, worker may Hanged"
            self.llm_logger.error(error_msg)
            raise Exception(error_msg)
        running_reqs = self.resource_manager.preempted_all()
        if len(running_reqs) > 0:
            self.llm_logger.info(f"Total {len(running_reqs)} requests need to be aborted.")
            self.resource_manager.get_real_bsz()
            # Send the preempt tasks to workers, then wait until they report
            # all inflight requests finished.
            self.engine_worker_queue.put_tasks((running_reqs, self.resource_manager.real_bsz))
            self.resource_manager.wait_worker_inflight_requests_finish(timeout=60)
        # self.engine_worker_queue.clear_data()
        self.token_processor.clear_data()
        self.resource_manager.log_status()
        # abort inflight requests to user
        inflight_requests = self.scheduler.get_inflight_requests()
        self.llm_logger.info(f"Start Abort Inflight Requests, total {len(inflight_requests)} waiting requests")
        for req in inflight_requests:
            self._send_error_response(req.request_id, "Request is aborted since LLM Engine is paused.")
        # Drop all scheduler and KV-cache state so a later resume starts clean.
        self.scheduler.reset()
        self.resource_manager.cache_manager.reset()
    return None
def _control_resume(self, control_request: ControlRequest) -> Optional[dict]:
"""Control function for resuming request generation.
This method resumes the paused request generation process by setting the pause flag
and notifying all waiting threads. It logs the start and end of the resume operation.
Args:
control_request: Control request object containing resume operation information
"""
self.llm_logger.info("START Resume Request Generation")
with self._pause_cond:
if not self.is_paused:
self.llm_logger.info("Resume Request Generation: not paused.")
return None
self.is_paused = False
self._pause_cond.notify_all()
self.llm_logger.info("END Resume Request Generation")
return None
def _control_is_paused(self, control_request: ControlRequest) -> bool:
"""
Check if the LLM engine is in paused state.
Args:
control_request: Control request object.
Returns:
dict: Dictionary containing pause status information, {'is_paused': bool}
"""
self.llm_logger.info(f"LLM Engine request generation is paused: {self.is_paused}")
with self._pause_cond:
return {"is_paused": self.is_paused}
def _control_update_weights(self, control_request: ControlRequest) -> Optional[dict]:
    """Update model weights.

    Forwards the control request to every worker and collects their results.
    The engine must already be paused; updating weights while generation is
    active is rejected.

    Args:
        control_request: Control request carrying weight-update parameters.

    Returns:
        Optional[dict]: Worker results when the update succeeds.

    Raises:
        Exception: Raised when the engine is not in paused state.
    """
    self.llm_logger.info("Update Model Weights")
    with self._pause_cond:
        if not self.is_paused:
            error_msg = "Pause LLM Engine first before calling updating weights"
            self.llm_logger.error(error_msg)
            raise Exception(error_msg)
        return self._call_worker(control_request, 60)
async def _wait_all_control_responses(self, request_id: str, timeout: int):
"""Wait for control responses from all workers with a global timeout.
This method concurrently waits for responses from all control workers
and enforces an overall timeout to avoid leaking pending tasks.
"""
timeout_ms = timeout * 1000
# Create one get() coroutine per worker output queue
tasks = [output_queue.get(timeout=timeout_ms) for output_queue in self._ctrl_worker_output_queues]
try:
results = await asyncio.wait_for(
asyncio.gather(*tasks, return_exceptions=True),
timeout=timeout,
)
except asyncio.TimeoutError:
# Keep the error message consistent with previous behavior
raise Exception("Worker Update Weights Timeouted after 600s")
responses = []
for output_queue, msg in zip(self._ctrl_worker_output_queues, results):
if isinstance(msg, Exception):
self.llm_logger.error(f"Call Worker Failed: {output_queue.name} {repr(msg)}")
raise Exception(f"Call Worker error: {repr(msg)}")
if msg is None:
# Preserve original semantics when no message is received
raise Exception("Worker Update Weights Timeouted after 600s")
response: ControlResponse = msg.payload
if response.request_id != request_id:
self.llm_logger.info(f"ignore old control response from worker:{output_queue.name} {response}")
continue
if response.error_code != 200:
self.llm_logger.info(f"Call Worker Failed: {output_queue.name} {response.error_message}")
raise Exception(f"Call Worker error: {response.error_message}")
self.llm_logger.info(f"Call Worker Succeed: {output_queue.name} {response.result}")
responses.append(response.result)
return responses
def _call_worker(self, control_request: ControlRequest, timeout: int):
    """Broadcast a control request to the workers and gather all responses.

    Args:
        control_request: The control request to forward to every worker.
        timeout: Seconds to wait for all worker responses.

    Returns:
        list: Result payloads collected from all workers.
    """
    self.engine_worker_queue.put_tasks(([control_request], 1))
    # A single asyncio.run() drives the concurrent waits on all worker queues.
    return asyncio.run(self._wait_all_control_responses(control_request.request_id, timeout))
def _send_error_response(self, request_id, error_msg, error_code: int = 500): def _send_error_response(self, request_id, error_msg, error_code: int = 500):
self.llm_logger.error( self.llm_logger.error(
f"Send error response to client, request_id: {request_id}, error_msg: {error_msg}, error_code: {error_code}" f"Send error response to client, request_id: {request_id}, error_msg: {error_msg}, error_code: {error_code}"
@@ -1708,6 +1945,7 @@ class EngineService:
f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'" f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'"
f" --guided_decoding_backend {self.cfg.structured_outputs_config.guided_decoding_backend}" f" --guided_decoding_backend {self.cfg.structured_outputs_config.guided_decoding_backend}"
f" --load_strategy {self.cfg.load_config.load_strategy}" f" --load_strategy {self.cfg.load_config.load_strategy}"
f" --rsync_config '{json.dumps(self.cfg.load_config.rsync_config)}'"
f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'" f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
f" --reasoning_parser {self.cfg.structured_outputs_config.reasoning_parser}" f" --reasoning_parser {self.cfg.structured_outputs_config.reasoning_parser}"
f" --load_choices {self.cfg.load_config.load_choices}" f" --load_choices {self.cfg.load_config.load_choices}"
+1
View File
@@ -556,6 +556,7 @@ class LLMEngine:
f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'" f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'"
f" --guided_decoding_backend {self.cfg.structured_outputs_config.guided_decoding_backend}" f" --guided_decoding_backend {self.cfg.structured_outputs_config.guided_decoding_backend}"
f" --load_strategy {self.cfg.load_config.load_strategy}" f" --load_strategy {self.cfg.load_config.load_strategy}"
f" --rsync_config '{json.dumps(self.cfg.load_config.rsync_config)}'"
f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'" f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
f" --reasoning_parser {self.cfg.structured_outputs_config.reasoning_parser}" f" --reasoning_parser {self.cfg.structured_outputs_config.reasoning_parser}"
f" --load_choices {self.cfg.load_config.load_choices}" f" --load_choices {self.cfg.load_config.load_choices}"
+153 -1
View File
@@ -26,6 +26,7 @@ from typing import TypeVar as TypingTypeVar
from typing import Union from typing import Union
import numpy as np import numpy as np
from fastapi.responses import JSONResponse
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import TypeVar from typing_extensions import TypeVar
@@ -99,7 +100,7 @@ class Request:
guided_json_object: Optional[bool] = None, guided_json_object: Optional[bool] = None,
enable_thinking: Optional[bool] = None, enable_thinking: Optional[bool] = None,
reasoning_max_tokens: Optional[int] = None, reasoning_max_tokens: Optional[int] = None,
trace_carrier: dict = dict(), trace_carrier: Optional[Dict[str, Any]] = None,
dp_rank: Optional[int] = None, dp_rank: Optional[int] = None,
chat_template: Optional[str] = None, chat_template: Optional[str] = None,
image_start: int = 0, image_start: int = 0,
@@ -544,6 +545,157 @@ class Request:
return hasattr(self, key) return hasattr(self, key)
class ControlRequest:
    """A generic control request that supports method and args for control operations.

    Used for system-level control operations rather than typical inference
    requests; the flexible method-plus-args interface enables dynamic control
    of engine behavior, resource management, and system configuration.
    """

    def __init__(
        self,
        request_id: str,
        method: str,
        args: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Args:
            request_id: Unique identifier for the control request.
            method: The control method to execute (e.g., "reset_scheduler", "get_metrics").
            args: Optional arguments for the control method.
        """
        self.request_id = request_id
        self.method = method
        self.args = args or {}

    @classmethod
    def from_dict(cls, d: dict):
        """Create ControlRequest instance from dictionary."""
        return cls(d["request_id"], d["method"], d.get("args", {}))

    def to_dict(self) -> dict:
        """Convert ControlRequest into a serializable dict."""
        return dict(request_id=self.request_id, method=self.method, args=self.args)

    def __repr__(self) -> str:
        """Provide a clean representation of the control request."""
        try:
            if envs.FD_DEBUG:
                return (
                    f"ControlRequest("
                    f"request_id={self.request_id}, "
                    f"method={self.method}, "
                    f"args={self.args}"
                    f")"
                )
            return f"ControlRequest(request_id={self.request_id}, method={self.method})"
        except Exception as e:
            # Never let repr raise (e.g. if envs is unavailable).
            return f"<ControlRequest repr failed: {e}>"

    def get_method(self) -> str:
        """Get the control method name."""
        return self.method

    def get_args(self) -> Dict[str, Any]:
        """Get the control method arguments."""
        return self.args.copy()

    @staticmethod
    def is_control_request(d: dict) -> bool:
        """
        Check if a dictionary represents a valid ControlRequest.

        Args:
            d: Dictionary to check

        Returns:
            bool: True if the dictionary contains the required fields for a ControlRequest
        """
        if not isinstance(d, dict):
            return False
        # Mandatory string fields (d.get() returns None when absent, which
        # fails the isinstance check, covering the "missing" case too).
        for field in ("request_id", "method"):
            if not isinstance(d.get(field), str):
                return False
        # "args" is optional, but must be a dict when present.
        return not ("args" in d and not isinstance(d["args"], dict))
class ControlResponse:
    """
    Response for control operations
    """

    def __init__(
        self,
        request_id: str,
        error_code: int = 200,
        error_message: Optional[str] = None,
        result: Optional[dict] = None,
        finished: bool = True,
    ) -> None:
        # error_code follows HTTP semantics: 200 is success, anything else an error.
        self.request_id = request_id
        self.error_code = error_code
        self.error_message = error_message
        self.result = result
        self.finished = finished

    def to_dict(self) -> dict:
        """Convert ControlResponse into a serializable dict."""
        return {
            "request_id": self.request_id,
            "finished": self.finished,
            "error_code": self.error_code,
            "error_message": self.error_message,
            "result": self.result,
        }

    @classmethod
    def from_dict(cls, d: dict):
        """Create ControlResponse instance from dictionary."""
        return cls(
            request_id=d["request_id"],
            error_code=d.get("error_code", 200),
            error_message=d.get("error_message"),
            result=d.get("result"),
            finished=d.get("finished", True),
        )

    def to_api_json_response(self) -> JSONResponse:
        """Convert ControlResponse into a JSONResponse."""
        # The engine's error_code doubles as the HTTP status code.
        body = {
            "request_id": self.request_id,
            "status": "success" if self.error_code == 200 else "error",
            "error_message": self.error_message,
            "result": self.result,
        }
        return JSONResponse(status_code=self.error_code, content=body)

    def __repr__(self) -> str:
        """Provide a clean representation of the control response."""
        return (
            f"ControlResponse("
            f"request_id={self.request_id}, "
            f"finished={self.finished}, "
            f"error_code={self.error_code}, "
            f"error_message={self.error_message}, "
            f"result={self.result}"
            f")"
        )
@dataclass(slots=True) @dataclass(slots=True)
class CompletionOutput: class CompletionOutput:
"""The output data of one completion output of a request. """The output data of one completion output of a request.
@@ -16,6 +16,7 @@
import copy import copy
import threading import threading
import time
import traceback import traceback
from collections import deque from collections import deque
from collections.abc import Iterable from collections.abc import Iterable
@@ -240,6 +241,7 @@ class ResourceManagerV1(ResourceManager):
def reschedule_preempt_task(self, request_id, process_func=None): def reschedule_preempt_task(self, request_id, process_func=None):
with self.lock: with self.lock:
llm_logger.debug(f"reschedule {request_id} into waiting queue")
if request_id in self.to_be_rescheduled_request_id_set and request_id in self.requests: if request_id in self.to_be_rescheduled_request_id_set and request_id in self.requests:
request = self.requests[request_id] request = self.requests[request_id]
if process_func is not None: if process_func is not None:
@@ -266,6 +268,39 @@ class ResourceManagerV1(ResourceManager):
return True return True
return False return False
def preempted_all(self):
    """Preempt every running request and return their preempt tasks.

    Each preemptible request is reset (status PREEMPTED, computed tokens and
    cached block count cleared, KV blocks freed) and registered for
    rescheduling via ``to_be_rescheduled_request_id_set``.

    Returns:
        list: One prepared preempt task per preempted request, for workers.
    """
    with self.lock:
        preempted_reqs = []
        for i in range(len(self.running)):
            # Pop from the tail so preemption is effectively last-in-first-out.
            req = self.running.pop()
            # txt2image: req.use_extend_tables is True, req can not be preempted. txt2image is not used in RL.
            if req.use_extend_tables:
                # Re-insert at the head so it is not popped again this pass.
                self.running.insert(0, req)
                continue
            req.status = RequestStatus.PREEMPTED
            req.num_computed_tokens = 0
            self._free_blocks(req)
            req.cached_block_num = 0
            self.to_be_rescheduled_request_id_set.add(req.request_id)
            preempted_reqs.append(self._prepare_preempt_task(req))
        return preempted_reqs
def wait_worker_inflight_requests_finish(self, timeout=60):
    """Block until the workers have drained all inflight requests.

    Polls every millisecond until no request is running or pending reschedule,
    giving up after `timeout` seconds with an informational log message.

    Args:
        timeout: Maximum seconds to wait before giving up. Defaults to 60.
    """
    deadline_ticks = timeout * 1000
    elapsed = 0
    while elapsed < deadline_ticks:
        # wait ongoing running and rescheduled requests finished in worker
        if not self.to_be_rescheduled_request_id_set and not self.running:
            break
        elapsed += 1
        time.sleep(0.001)
    if elapsed >= deadline_ticks:
        llm_logger.info(
            f"wait_inflight_requests_finish timeout after {timeout} seconds, "
            f"still {len(self.to_be_rescheduled_request_id_set)} requests running"
        )
def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_reqs): def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_reqs):
""" """
If the request cannot be scheduled, preempt the running request one by one until it can be scheduled. Last in, first out. If the request cannot be scheduled, preempt the running request one by one until it can be scheduled. Last in, first out.
@@ -1347,3 +1382,16 @@ class ResourceManagerV1(ResourceManager):
main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc()) main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
main_process_metrics.num_requests_running.set(len(self.running)) main_process_metrics.num_requests_running.set(len(self.running))
main_process_metrics.num_requests_waiting.set(num_tasks - len(self.running)) main_process_metrics.num_requests_waiting.set(num_tasks - len(self.running))
def log_status(self):
    """Log a one-line snapshot of the scheduler's internal state."""
    # Dump queue sizes plus the raw bookkeeping structures for debugging
    # pause/preempt sequences.
    status = (
        f"ResourceManagerV1( "
        f"waiting={len(self.waiting)}, "
        f"running={len(self.running)}, "
        f"preempted={len(self.to_be_rescheduled_request_id_set)}, "
        f"tasks_list={self.tasks_list}, "
        f"stop_flags={self.stop_flags}, "
        f"req_dict={self.req_dict}, "
        f"requests={self.requests}, "
        f")"
    )
    llm_logger.info(status)
+24 -1
View File
@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
""" """
import asyncio
import inspect import inspect
import os import os
import re import re
@@ -29,7 +30,12 @@ from filelock import FileLock
import fastdeploy.metrics.trace as tracing import fastdeploy.metrics.trace as tracing
from fastdeploy import envs from fastdeploy import envs
from fastdeploy.config import FDConfig from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request, RequestStatus from fastdeploy.engine.request import (
ControlRequest,
ControlResponse,
Request,
RequestStatus,
)
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
from fastdeploy.eplb.utils import RedundantExpertWorkload from fastdeploy.eplb.utils import RedundantExpertWorkload
@@ -526,6 +532,23 @@ class EngineClient:
return True, "" return True, ""
async def run_control_method(self, request: ControlRequest):
    """Send a control request to the engine and await its response.

    The request is pushed to the engine over ZMQ; a dealer connection is then
    registered under `request_id` so the engine's ControlResponse is routed
    back to this coroutine's queue.

    Args:
        request (ControlRequest): Control request to execute.

    Returns:
        ControlResponse: The engine's response, or a synthesized 500 response
        when no reply arrives within 600 seconds.
    """
    api_server_logger.info(f"Start Run Control Method: {request}")
    self.zmq_client.send_json(request.to_dict())
    request_id = request.request_id
    # NOTE(review): the request is sent before the dealer registers this
    # request_id; presumably the response path tolerates that ordering — confirm.
    dealer, response_queue = await self.connection_manager.get_connection(request_id)
    dealer.write([b"", request_id.encode("utf-8")])
    try:
        # todo: support user specified timeout. default 600s is enough for most control cases
        response = await asyncio.wait_for(response_queue.get(), timeout=600)
        response = ControlResponse.from_dict(response[0])
        api_server_logger.info(f"End Run Control Method: {response}")
        return response
    except asyncio.TimeoutError:
        error_response = ControlResponse(request_id, 500, "Timeout waiting for control method response")
        api_server_logger.error(f"Error Run Control Method: {error_response}")
        return error_response
def is_workers_alive(self): def is_workers_alive(self):
""" """
Check the health of the model server by checking whether all workers are alive. Check the health of the model server by checking whether all workers are alive.
@@ -20,6 +20,7 @@ import signal
import threading import threading
import time import time
import traceback import traceback
import uuid
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@@ -38,6 +39,7 @@ from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.async_llm import AsyncLLM from fastdeploy.engine.async_llm import AsyncLLM
from fastdeploy.engine.engine import LLMEngine from fastdeploy.engine.engine import LLMEngine
from fastdeploy.engine.expert_service import ExpertService from fastdeploy.engine.expert_service import ExpertService
from fastdeploy.engine.request import ControlRequest
from fastdeploy.entrypoints.chat_utils import load_chat_template from fastdeploy.entrypoints.chat_utils import load_chat_template
from fastdeploy.entrypoints.engine_client import EngineClient from fastdeploy.entrypoints.engine_client import EngineClient
from fastdeploy.entrypoints.openai.middleware import AuthenticationMiddleware from fastdeploy.entrypoints.openai.middleware import AuthenticationMiddleware
@@ -370,6 +372,66 @@ def ping(raw_request: Request) -> Response:
return health(raw_request) return health(raw_request)
@app.post("/v1/pause")
async def pause(request: Request) -> Response:
    """Pause request generation in the engine (async RL control API)."""
    # todo: support wait_for_inflight_requests(default False), clear_cache(default True) arguments
    request_id = f"control-{uuid.uuid4()}"
    control_request = ControlRequest(request_id, "pause")
    control_response = await app.state.engine_client.run_control_method(control_request)
    return control_response.to_api_json_response()
@app.post("/v1/resume")
async def resume(request: Request) -> Response:
    """Resume request generation after a pause (async RL control API)."""
    request_id = f"control-{uuid.uuid4()}"
    control_request = ControlRequest(request_id, "resume")
    control_response = await app.state.engine_client.run_control_method(control_request)
    return control_response.to_api_json_response()
@app.get("/v1/is_paused")
async def is_paused(request: Request) -> Response:
    """Report whether engine request generation is currently paused."""
    request_id = f"control-{uuid.uuid4()}"
    control_request = ControlRequest(request_id, "is_paused")
    control_response = await app.state.engine_client.run_control_method(control_request)
    return control_response.to_api_json_response()
@app.post("/v1/update_weights")
async def update_weights(request: Request) -> Response:
    """Trigger a model weight update on the paused engine.

    Optional JSON body: {"version": str, "rsync_config": dict containing
    "etcd_server"}. Invalid parameter types yield HTTP 400; otherwise the
    engine's ControlResponse is converted to the HTTP reply.
    """
    request_id = f"control-{uuid.uuid4()}"
    # An empty body means "update with engine defaults"; Starlette caches the
    # body, so the follow-up json() call does not re-read the stream.
    request_data = await request.json() if await request.body() else {}
    args = {}
    # Validate and extract version parameter
    if "version" in request_data and request_data["version"] is not None:
        if not isinstance(request_data["version"], str):
            return JSONResponse(
                status_code=400, content={"error": "Invalid parameter type", "message": "version must be a string"}
            )
        args["version"] = request_data["version"]
    # Validate and extract rsync_config parameter
    if "rsync_config" in request_data and request_data["rsync_config"] is not None:
        if not isinstance(request_data["rsync_config"], dict):
            return JSONResponse(
                status_code=400,
                content={"error": "Invalid parameter type", "message": "rsync_config must be a dictionary"},
            )
        if "etcd_server" not in request_data["rsync_config"]:
            return JSONResponse(
                status_code=400,
                content={"error": "Invalid parameter type", "message": "rsync_config must contain etcd_server"},
            )
        args["rsync_config"] = request_data["rsync_config"]
    control_request = ControlRequest(request_id, "update_weights", args)
    control_response = await app.state.engine_client.run_control_method(control_request)
    return control_response.to_api_json_response()
def wrap_streaming_generator(original_generator: AsyncGenerator): def wrap_streaming_generator(original_generator: AsyncGenerator):
""" """
Wrap an async generator to release the connection semaphore when the generator is finished. Wrap an async generator to release the connection semaphore when the generator is finished.
+1 -1
View File
@@ -214,7 +214,7 @@ class Queue(BaseComponent):
else: else:
self.socket.bind(full_ep) self.socket.bind(full_ep)
fmq_logger.info(f"Queue {name} initialized on {full_ep}") fmq_logger.info(f"Queue {name}({role}) initialized on {full_ep}")
async def put(self, data: Any, shm_threshold: int = 1024 * 1024): async def put(self, data: Any, shm_threshold: int = 1024 * 1024):
""" """
+1 -1
View File
@@ -1034,7 +1034,7 @@ class TokenProcessor:
finished=True, finished=True,
metrics=RequestMetrics( metrics=RequestMetrics(
arrival_time=time.time(), arrival_time=time.time(),
request_start_time=task.arrival_time, request_start_time=task.metrics.arrival_time,
), ),
) )
is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill" is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"
+110 -2
View File
@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
""" """
import io
import os import os
import time import time
from multiprocessing.shared_memory import SharedMemory from multiprocessing.shared_memory import SharedMemory
@@ -27,13 +28,38 @@ from fastdeploy.config import FDConfig
from fastdeploy.inter_communicator import KVCacheStatus, ModelWeightsStatus from fastdeploy.inter_communicator import KVCacheStatus, ModelWeightsStatus
def sync_weights_by_rdma(config, step, rank):
    """Fetch a serialized weight snapshot over RDMA and deserialize it.

    Args:
        config: checkpoint_transfer downloader configuration.
        step: weight version/step identifier to fetch.
        rank: this worker's rank (selects the correct shard).

    Returns:
        The state dict deserialized with ``paddle.load``.

    Raises:
        Exception: when the downloader returns no data.
    """
    from checkpoint_transfer.core import RDMAWeightsDownloader

    fetcher = RDMAWeightsDownloader(config)
    fetcher.initialize()
    logger.info(f"Fetching weights for step:{step}, rank:{rank}...")
    payload = fetcher.get_weights(step, rank)
    if payload is None:
        logger.error("Failed to get weights!")
        raise Exception("Failed to rsync weights through checkpoint_transfer")
    logger.info(f"Successfully retrieved data. Type: {type(payload)}")
    # Normalize whatever container the downloader returned into raw bytes.
    if isinstance(payload, np.ndarray):
        raw = payload.tobytes()
    elif isinstance(payload, (bytes, bytearray)):
        raw = payload
    else:
        raw = bytes(payload)
    logger.info(f"Data size: {len(raw)} bytes")
    return paddle.load(io.BytesIO(raw))
class DynamicWeightManager: class DynamicWeightManager:
"""Manages model weights loading, updating and shared state across processes.""" """Manages model weights loading, updating and shared state across processes."""
def __init__(self, fd_config: FDConfig, models): def __init__(self, fd_config: FDConfig, models, local_rank: int):
"""Initialize with config and model instances.""" """Initialize with config and model instances."""
self.fd_config = fd_config self.fd_config = fd_config
self.load_config = fd_config.load_config self.load_config = fd_config.load_config
self.local_rank = local_rank
self.parallel_config = fd_config.parallel_config self.parallel_config = fd_config.parallel_config
self.state_dict: Dict[str, paddle.Tensor] = {} self.state_dict: Dict[str, paddle.Tensor] = {}
self.rank = fd_config.parallel_config.tensor_parallel_rank self.rank = fd_config.parallel_config.tensor_parallel_rank
@@ -46,7 +72,10 @@ class DynamicWeightManager:
else: else:
self.model_list = models self.model_list = models
self._capture_model_state() self._capture_model_state()
self.update_parameters() if self.load_config.load_strategy == "rsync":
self.update_weights_by_rdma()
else:
self.update_parameters()
self.finalize_update() self.finalize_update()
logger.info( logger.info(
@@ -62,6 +91,74 @@ class DynamicWeightManager:
logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}") logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
self.state_dict[name] = param self.state_dict[name] = param
    def update_weights_by_rdma(self, version: str = None, rsync_config: Dict[str, Any] = None):
        """Fetch a new weight snapshot over RDMA and apply it to the live model in place.

        Args:
            version: Weight version to fetch. When None/empty, falls back to
                reading {model_dir}/version.txt via read_model_version_from_file().
            rsync_config: checkpoint_transfer configuration. When None, falls
                back to load_config.rsync_config from the launch arguments.

        Returns:
            Dict with timing stats in seconds ("sync_cost", "update_cost",
            "total_cost") plus the resolved "version" and this worker's "rank".

        Raises:
            Exception: when neither an rsync config nor a version can be
                resolved, or when the RDMA fetch fails (from sync_weights_by_rdma).
            ValueError: when the fetched state dict is missing keys or has
                shape/dtype mismatches against the current parameters.
        """

        def valid_parameters(old_state_dict, new_state_dict):
            # The new state dict must contain every current parameter with an
            # identical shape and dtype; every mismatch is logged before
            # returning so the operator sees the full list, not just the first.
            is_valid = True
            for key in old_state_dict:
                if key not in new_state_dict:
                    is_valid = False
                    logger.error(f"Invalid parameter: {key} not in new_state_dict")
                elif old_state_dict[key].shape != new_state_dict[key].shape:
                    is_valid = False
                    logger.error(
                        f"Invalid parameter: {key} shape mismatch, "
                        f"new shape:{new_state_dict[key].shape}, "
                        f"old shape:{old_state_dict[key].shape}"
                    )
                elif old_state_dict[key].dtype != new_state_dict[key].dtype:
                    is_valid = False
                    logger.error(f"Invalid parameter: {key} dtype mismatch")
            return is_valid

        # Resolve rsync config: explicit argument first, then launch config.
        if rsync_config is None:
            rsync_config = self.fd_config.load_config.rsync_config
        if rsync_config is None or len(rsync_config) == 0:
            raise Exception(
                "rsync config not set, please set it in 1) launch arguments '--rsync-config' "
                "or 2) interface arguments 'rsync_config'"
            )
        # Resolve version: explicit argument first, then {model_dir}/version.txt.
        if version is None or version == "":
            version = self.read_model_version_from_file()
        if version is None or version == "":
            raise Exception(
                "rsync model version not set, please set it in 1) {model_version}/version.txt "
                "or 2) interface arguments 'version'"
            )
        logger.info(f"START update_weights_by_rdma, version:{version}, rsync_config:{rsync_config}")
        rank = self.local_rank
        sync_start = time.perf_counter()
        new_state_dict = sync_weights_by_rdma(rsync_config, version, rank)
        sync_cost = time.perf_counter() - sync_start
        logger.info(f"weights sync cost {sync_cost:.2f} seconds")
        old_state_dict = self.state_dict
        # Refuse to apply an incompatible snapshot (no partial updates).
        if not valid_parameters(old_state_dict, new_state_dict):
            error_msg = "Invalid new_state_dict, update parameters failed"
            logger.error(error_msg)
            raise ValueError(error_msg)
        update_start = time.perf_counter()
        # Copy the new values into the existing parameter tensors in place.
        for name, param in old_state_dict.items():
            param.set_value(new_state_dict[name])
        update_cost = time.perf_counter() - update_start
        logger.info(f"params set value cost {update_cost:.2f} seconds")
        total_cost = time.perf_counter() - sync_start
        logger.info(
            f"END update_weights_by_rdma, cost {total_cost:.2f} seconds"
            f" version:{version}, rsync_config: {rsync_config}",
        )
        return {
            "sync_cost": sync_cost,
            "update_cost": update_cost,
            "total_cost": total_cost,
            "version": version,
            "rank": rank,
        }
def update_parameters(self, pid: int = 0, restart_process_group=False) -> None: def update_parameters(self, pid: int = 0, restart_process_group=False) -> None:
"""Core method to update model parameters based on strategy.""" """Core method to update model parameters based on strategy."""
start_time = time.perf_counter() start_time = time.perf_counter()
@@ -257,6 +354,17 @@ class DynamicWeightManager:
if self.rank == 0: if self.rank == 0:
value[self.rank] = status value[self.rank] = status
def read_model_version_from_file(self):
model_dir = self.fd_config.model_config.model
version_file = os.path.join(model_dir, "version.txt")
try:
with open(version_file, "r", encoding="utf-8") as f:
version = f.read().strip()
return version
except (FileNotFoundError, OSError, IOError) as e:
logger.error(f"Failed to read model version file '{version_file}': {e}")
return None
@staticmethod @staticmethod
def check_model_weights_status(model_weights_status, kv_cache_status, model_runner, pid, block): def check_model_weights_status(model_weights_status, kv_cache_status, model_runner, pid, block):
""" """
+4
View File
@@ -158,6 +158,10 @@ class LocalScheduler:
else: else:
self.ids_read_cursor -= len(expired_ids) self.ids_read_cursor -= len(expired_ids)
    def get_inflight_requests(self) -> List[Request]:
        """Return the raw Request objects for every request the scheduler is
        currently tracking.

        Takes the scheduler mutex so the snapshot is consistent with
        concurrent put/expire operations.
        """
        with self.mutex:
            return [request.raw for request in self.requests.values()]
def put_requests(self, requests: List[Request]) -> List[Tuple[str, Optional[str]]]: def put_requests(self, requests: List[Request]) -> List[Tuple[str, Optional[str]]]:
""" """
Add new requests to the scheduler queue. Add new requests to the scheduler queue.
+5 -2
View File
@@ -20,7 +20,7 @@ import queue
import time import time
from concurrent.futures import Future from concurrent.futures import Future
from threading import Thread from threading import Thread
from typing import List, Optional, cast from typing import Any, Dict, List, Optional, cast
import numpy as np import numpy as np
import paddle import paddle
@@ -1518,7 +1518,7 @@ class GPUModelRunner(ModelRunnerBase):
if self.fd_config.load_config.dynamic_load_weight: if self.fd_config.load_config.dynamic_load_weight:
from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager
self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model) self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model, self.local_rank)
# 2. Load lora model # 2. Load lora model
@@ -2798,6 +2798,9 @@ class GPUModelRunner(ModelRunnerBase):
self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory") self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")
def update_weights(self, version: str = None, rsync_config: Dict[str, Any] = None):
return self.dynamic_weight_manager.update_weights_by_rdma(version, rsync_config)
def padding_cudagraph_inputs(self) -> None: def padding_cudagraph_inputs(self) -> None:
""" """
Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch. Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch.
+5 -1
View File
@@ -16,7 +16,7 @@
import gc import gc
import time import time
from typing import List, Optional from typing import Any, Dict, List, Optional
import paddle import paddle
import pynvml import pynvml
@@ -188,6 +188,10 @@ class GpuWorker(WorkerBase):
# accurate cache size # accurate cache size
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks) self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
def update_weights(self, version: str = None, rsync_config: Dict[str, Any] = None):
"""update weights in place"""
return self.model_runner.update_weights(version, rsync_config)
def execute_model( def execute_model(
self, self,
model_forward_batch: Optional[List[Request]] = None, model_forward_batch: Optional[List[Request]] = None,
+58 -5
View File
@@ -15,9 +15,11 @@
""" """
import argparse import argparse
import asyncio
import json import json
import os import os
import time import time
import traceback
from typing import Tuple from typing import Tuple
import numpy as np import numpy as np
@@ -42,7 +44,7 @@ from fastdeploy.config import (
SpeculativeConfig, SpeculativeConfig,
StructuredOutputsConfig, StructuredOutputsConfig,
) )
from fastdeploy.engine.request import RequestType from fastdeploy.engine.request import ControlRequest, ControlResponse, RequestType
from fastdeploy.eplb.async_expert_loader import ( from fastdeploy.eplb.async_expert_loader import (
MODEL_MAIN_NAME, MODEL_MAIN_NAME,
REARRANGE_EXPERT_MAGIC_NUM, REARRANGE_EXPERT_MAGIC_NUM,
@@ -57,6 +59,7 @@ from fastdeploy.inter_communicator import (
ModelWeightsStatus, ModelWeightsStatus,
RearrangeExpertStatus, RearrangeExpertStatus,
) )
from fastdeploy.inter_communicator.fmq import FMQ
from fastdeploy.model_executor.layers.quantization import parse_quant_config from fastdeploy.model_executor.layers.quantization import parse_quant_config
from fastdeploy.model_executor.utils import v1_loader_support from fastdeploy.model_executor.utils import v1_loader_support
from fastdeploy.platforms import current_platform from fastdeploy.platforms import current_platform
@@ -164,6 +167,12 @@ class PaddleDisWorkerProc:
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8 self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
def init_control(self):
engine_worker_queue_port = self.parallel_config.local_engine_worker_queue_port
queue_name = f"ctrl_w2e_rank{self.local_rank}_{engine_worker_queue_port}"
logger.info(f"Init Control Output Queue: {queue_name}(producer)")
self._ctrl_output = FMQ().queue(queue_name, "producer")
def init_health_status(self) -> None: def init_health_status(self) -> None:
""" """
Initialize the health status of the worker. Initialize the health status of the worker.
@@ -513,10 +522,20 @@ class PaddleDisWorkerProc:
else: else:
self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY
req_dicts = [] req_dicts, control_reqs = [], []
for req_dict, bsz in tasks: for req_dict, bsz in tasks:
max_occupied_batch_index = int(bsz) if len(req_dict) > 0 and isinstance(req_dict[0], ControlRequest):
req_dicts.extend(req_dict) control_reqs.append(req_dict[0])
else:
max_occupied_batch_index = int(bsz)
req_dicts.extend(req_dict)
# todo: run control request async
if len(control_reqs) > 0:
logger.info(f"Rank: {self.local_rank} received {len(control_reqs)} control request.")
for control_req in control_reqs:
self.run_control_method(control_req)
self._tp_barrier_wait() if tp_size > 1 else None
# Count prefill requests in current batch # Count prefill requests in current batch
num_prefill_requests = sum(1 for req in req_dicts if req.task_type == RequestType.PREFILL) num_prefill_requests = sum(1 for req in req_dicts if req.task_type == RequestType.PREFILL)
@@ -655,6 +674,32 @@ class PaddleDisWorkerProc:
paddle.distributed.barrier() paddle.distributed.barrier()
self.loaded_model_signal.value[0] = 1 self.loaded_model_signal.value[0] = 1
    def run_control_method(self, control_request: ControlRequest) -> None:
        """Dispatch a ControlRequest to the matching method on self.worker and
        push the resulting ControlResponse back over the control output queue.

        Response codes (HTTP-style):
          * 400 -- the named method does not exist / is not callable on the worker
          * 200 -- the handler ran and returned ``result``
          * 500 -- the handler raised (the exception is logged with traceback)
        """
        logger.info(f"Start run control request: {control_request}")
        request_id = control_request.request_id
        method = control_request.method
        kwargs = control_request.args
        # Resolve the control method by name on the worker instance.
        handler = getattr(self.worker, method, None)
        if handler is None or not callable(handler):
            error_msg = f"Rank-{self.local_rank}: Unknown control method {method}"
            error_result = ControlResponse(request_id, 400, error_msg)
            # The queue put() is a coroutine; this method runs outside any
            # event loop, so asyncio.run drives it to completion.
            asyncio.run(self._ctrl_output.put(error_result))
            return
        try:
            result = handler(**kwargs)
            succ_result = ControlResponse(request_id, 200, "Success", result)
            logger.info(
                f"Rank-{self.local_rank} Success run control request: {control_request}, response: {succ_result}"
            )
            # NOTE(review): raising shm_threshold to 100 MiB presumably keeps
            # typical success payloads inline instead of going through shared
            # memory -- confirm against FMQ Queue.put semantics.
            asyncio.run(self._ctrl_output.put(succ_result, shm_threshold=100 * 1024 * 1024))
        except Exception as e:
            error_msg = f"Rank-{self.local_rank} Failed run control method {method}: {str(e)}"
            logger.error(f"{error_msg}\n{traceback.format_exc()}")
            error_result = ControlResponse(request_id, 500, error_msg)
            asyncio.run(self._ctrl_output.put(error_result))
def parse_args(): def parse_args():
""" """
@@ -813,12 +858,18 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--load_strategy", "--load_strategy",
type=str, type=str,
choices=["ipc", "ipc_snapshot", "meta", "normal"], choices=["ipc", "ipc_snapshot", "meta", "normal", "rsync"],
default="ipc_snapshot", default="ipc_snapshot",
help="Weight loading method when dynamic loading is enabled: " help="Weight loading method when dynamic loading is enabled: "
"'ipc': real-time IPC streaming with automatic resharding, " "'ipc': real-time IPC streaming with automatic resharding, "
"'ipc_snapshot': load from disk snapshot of IPC weights.", "'ipc_snapshot': load from disk snapshot of IPC weights.",
) )
parser.add_argument(
"--rsync_config",
type=json.loads,
default=None,
help="Rsync weights config",
)
parser.add_argument( parser.add_argument(
"--enable_logprob", "--enable_logprob",
action="store_true", action="store_true",
@@ -1045,6 +1096,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}") logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
logger.info(f"- Load strategy: {load_config.load_strategy}") logger.info(f"- Load strategy: {load_config.load_strategy}")
logger.info(f"- Rsync config: {load_config.rsync_config}, {type(load_config.rsync_config)}")
if not ( if not (
current_platform.is_cuda() current_platform.is_cuda()
@@ -1112,6 +1164,7 @@ def run_worker_proc() -> None:
worker_proc = IluvatarPaddleDisWorkerProc(fd_config, ranks, local_rank) worker_proc = IluvatarPaddleDisWorkerProc(fd_config, ranks, local_rank)
else: else:
worker_proc = PaddleDisWorkerProc(fd_config, ranks, local_rank) worker_proc = PaddleDisWorkerProc(fd_config, ranks, local_rank)
worker_proc.init_control()
# Initialize device and create model runner # Initialize device and create model runner
worker_proc.init_device() worker_proc.init_device()
@@ -0,0 +1,329 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from unittest.mock import patch
from fastapi.responses import JSONResponse
from fastdeploy.engine.request import ControlRequest, ControlResponse
class TestControlRequest(unittest.TestCase):
    """Test cases for ControlRequest class."""

    def test_initialization_basic(self):
        """Test basic initialization of ControlRequest."""
        request_id = "test_request_123"
        method = "get_metrics"
        request = ControlRequest(request_id=request_id, method=method)
        self.assertEqual(request.request_id, request_id)
        self.assertEqual(request.method, method)
        # args defaults to an empty dict when omitted
        self.assertEqual(request.args, {})

    def test_initialization_with_args(self):
        """Test initialization with arguments."""
        request_id = "test_request_456"
        method = "reset_scheduler"
        args = {"force": True, "timeout": 30}
        request = ControlRequest(request_id=request_id, method=method, args=args)
        self.assertEqual(request.request_id, request_id)
        self.assertEqual(request.method, method)
        self.assertEqual(request.args, args)

    def test_from_dict_basic(self):
        """Test creating ControlRequest from dictionary (basic case)."""
        data = {"request_id": "test_from_dict", "method": "status_check"}
        request = ControlRequest.from_dict(data)
        self.assertEqual(request.request_id, data["request_id"])
        self.assertEqual(request.method, data["method"])
        self.assertEqual(request.args, {})

    def test_from_dict_with_args(self):
        """Test creating ControlRequest from dictionary with arguments."""
        data = {
            "request_id": "test_from_dict_args",
            "method": "configure",
            "args": {"max_requests": 100, "queue_timeout": 60},
        }
        request = ControlRequest.from_dict(data)
        self.assertEqual(request.request_id, data["request_id"])
        self.assertEqual(request.method, data["method"])
        self.assertEqual(request.args, data["args"])

    def test_to_dict_basic(self):
        """Test converting ControlRequest to dictionary (basic case)."""
        request = ControlRequest(request_id="test_to_dict", method="health_check")
        result = request.to_dict()
        expected = {"request_id": "test_to_dict", "method": "health_check", "args": {}}
        self.assertEqual(result, expected)

    def test_to_dict_with_args(self):
        """Test converting ControlRequest to dictionary with arguments."""
        args = {"setting1": "value1", "setting2": 42}
        request = ControlRequest(request_id="test_to_dict_args", method="update_settings", args=args)
        result = request.to_dict()
        expected = {"request_id": "test_to_dict_args", "method": "update_settings", "args": args}
        self.assertEqual(result, expected)

    def test_get_method(self):
        """Test get_method method."""
        method = "custom_operation"
        request = ControlRequest(request_id="test", method=method)
        self.assertEqual(request.get_method(), method)

    def test_get_args(self):
        """Test get_args method."""
        args = {"param1": "value1", "param2": 123}
        request = ControlRequest(request_id="test", method="test", args=args)
        result_args = request.get_args()
        self.assertEqual(result_args, args)
        # Ensure it returns a copy, not the original dict
        self.assertIsNot(result_args, args)

    def test_is_control_request_valid(self):
        """Test is_control_request method with valid data."""
        valid_data = [
            {"request_id": "test1", "method": "method1"},
            {"request_id": "test2", "method": "method2", "args": {}},
            {"request_id": "test3", "method": "method3", "args": {"key": "value"}},
        ]
        for data in valid_data:
            with self.subTest(data=data):
                self.assertTrue(ControlRequest.is_control_request(data))

    def test_is_control_request_invalid(self):
        """Test is_control_request method with invalid data."""
        invalid_data = [
            # Missing required fields
            {"method": "test"},  # missing request_id
            {"request_id": "test"},  # missing method
            # Wrong field types
            {"request_id": 123, "method": "test"},  # request_id not string
            {"request_id": "test", "method": 456},  # method not string
            {"request_id": "test", "method": "test", "args": "not_a_dict"},  # args not dict
            # Not a dict
            "not_a_dict",
            123,
            None,
        ]
        for data in invalid_data:
            with self.subTest(data=data):
                self.assertFalse(ControlRequest.is_control_request(data))

    def test_repr_simple(self):
        """Test __repr__ method in simple mode."""
        # NOTE(review): patching fastdeploy.envs.FD_DEBUG assumes __repr__
        # reads the flag at call time -- confirm against the implementation.
        with patch("fastdeploy.envs.FD_DEBUG", False):
            request = ControlRequest(request_id="test_repr", method="test_method")
            repr_str = repr(request)
            self.assertIn("ControlRequest", repr_str)
            self.assertIn("test_repr", repr_str)
            self.assertIn("test_method", repr_str)
            self.assertNotIn("args", repr_str)  # Args not shown in simple mode

    def test_repr_debug_mode(self):
        """Test __repr__ method in debug mode."""
        with patch("fastdeploy.envs.FD_DEBUG", True):
            args = {"debug_param": "debug_value"}
            request = ControlRequest(request_id="test_repr", method="test_method", args=args)
            repr_str = repr(request)
            self.assertIn("ControlRequest", repr_str)
            self.assertIn("test_repr", repr_str)
            self.assertIn("test_method", repr_str)
            self.assertIn("debug_param", repr_str)  # Args shown in debug mode
class TestControlResponse(unittest.TestCase):
    """Test cases for ControlResponse class."""

    def test_initialization_basic(self):
        """Test basic initialization of ControlResponse."""
        request_id = "test_response_123"
        response = ControlResponse(request_id=request_id)
        self.assertEqual(response.request_id, request_id)
        # Defaults: HTTP-style 200, no error message, no result, finished=True
        self.assertEqual(response.error_code, 200)
        self.assertIsNone(response.error_message)
        self.assertIsNone(response.result)
        self.assertTrue(response.finished)

    def test_initialization_with_all_params(self):
        """Test initialization with all parameters."""
        request_id = "test_response_456"
        error_code = 404
        error_message = "Not found"
        result = {"data": "some_result"}
        finished = False
        response = ControlResponse(
            request_id=request_id, error_code=error_code, error_message=error_message, result=result, finished=finished
        )
        self.assertEqual(response.request_id, request_id)
        self.assertEqual(response.error_code, error_code)
        self.assertEqual(response.error_message, error_message)
        self.assertEqual(response.result, result)
        self.assertEqual(response.finished, finished)

    def test_initialization_error_cases(self):
        """Test initialization with various error codes."""
        test_cases = [
            (200, None, True),  # Success case
            (400, "Bad Request", False),  # Client error
            (500, "Internal Error", True),  # Server error
        ]
        for error_code, error_message, finished in test_cases:
            with self.subTest(error_code=error_code):
                response = ControlResponse(
                    request_id="test", error_code=error_code, error_message=error_message, finished=finished
                )
                self.assertEqual(response.error_code, error_code)
                self.assertEqual(response.error_message, error_message)
                self.assertEqual(response.finished, finished)

    def test_from_dict_basic(self):
        """Test creating ControlResponse from dictionary (basic case)."""
        data = {"request_id": "test_from_dict"}
        response = ControlResponse.from_dict(data)
        self.assertEqual(response.request_id, data["request_id"])
        # Missing fields fall back to constructor defaults
        self.assertEqual(response.error_code, 200)
        self.assertIsNone(response.error_message)
        self.assertIsNone(response.result)
        self.assertTrue(response.finished)

    def test_from_dict_with_all_fields(self):
        """Test creating ControlResponse from dictionary with all fields."""
        data = {
            "request_id": "test_from_dict_full",
            "error_code": 500,
            "error_message": "Test error",
            "result": {"key": "value"},
            "finished": False,
        }
        response = ControlResponse.from_dict(data)
        self.assertEqual(response.request_id, data["request_id"])
        self.assertEqual(response.error_code, data["error_code"])
        self.assertEqual(response.error_message, data["error_message"])
        self.assertEqual(response.result, data["result"])
        self.assertEqual(response.finished, data["finished"])

    def test_to_dict_basic(self):
        """Test converting ControlResponse to dictionary (basic case)."""
        response = ControlResponse(request_id="test_to_dict")
        result = response.to_dict()
        expected = {
            "request_id": "test_to_dict",
            "finished": True,
            "error_code": 200,
            "error_message": None,
            "result": None,
        }
        self.assertEqual(result, expected)

    def test_to_dict_with_all_fields(self):
        """Test converting ControlResponse to dictionary with all fields."""
        response = ControlResponse(
            request_id="test_to_dict_full",
            error_code=400,
            error_message="Validation failed",
            result={"valid": False, "reason": "missing_field"},
            finished=False,
        )
        result = response.to_dict()
        expected = {
            "request_id": "test_to_dict_full",
            "finished": False,
            "error_code": 400,
            "error_message": "Validation failed",
            "result": {"valid": False, "reason": "missing_field"},
        }
        self.assertEqual(result, expected)

    def test_to_api_json_response_success(self):
        """Test converting to JSONResponse for successful response."""
        result_data = {"metrics": {"cpu_usage": 0.5, "memory_used": 1024}}
        response = ControlResponse(request_id="test_json_success", result=result_data)
        json_response = response.to_api_json_response()
        self.assertIsInstance(json_response, JSONResponse)
        self.assertEqual(json_response.status_code, 200)
        # Inspect the rendered body rather than internal attributes
        content = json_response.body.decode("utf-8")
        self.assertIn("success", content)
        self.assertIn("test_json_success", content)
        self.assertIn("cpu_usage", content)

    def test_to_api_json_response_error(self):
        """Test converting to JSONResponse for error response."""
        response = ControlResponse(request_id="test_json_error", error_code=503, error_message="Service unavailable")
        json_response = response.to_api_json_response()
        self.assertIsInstance(json_response, JSONResponse)
        # The control error code is propagated as the HTTP status code
        self.assertEqual(json_response.status_code, 503)
        content = json_response.body.decode("utf-8")
        self.assertIn("error", content)
        self.assertIn("test_json_error", content)
        self.assertIn("Service unavailable", content)

    def test_repr_method(self):
        """Test __repr__ method."""
        response = ControlResponse(
            request_id="test_repr", error_code=200, error_message=None, result={"data": "test"}, finished=True
        )
        repr_str = repr(response)
        # Check that all important fields are represented
        self.assertIn("ControlResponse", repr_str)
        self.assertIn("test_repr", repr_str)
        self.assertIn("200", repr_str)
        self.assertIn("test", repr_str)  # from result
        self.assertIn("True", repr_str)  # finished flag
if __name__ == "__main__":
unittest.main()
+104
View File
@@ -0,0 +1,104 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from unittest.mock import Mock
from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.request import Request, RequestStatus
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
MODEL_NAME = os.getenv("MODEL_PATH", "/path/to/models") + "/ERNIE-4.5-0.3B-Paddle"
class TestResourceManagerV1(unittest.TestCase):
    """Test cases for ResourceManagerV1."""

    def setUp(self):
        """Set up test fixtures."""
        engine_args = EngineArgs(
            model=MODEL_NAME,
            max_model_len=8192,
            tensor_parallel_size=1,
            engine_worker_queue_port=int(os.getenv("FD_ENGINE_QUEUE_PORT", "6778")),
            cache_queue_port=int(os.getenv("FD_CACHE_QUEUE_PORT", "6779")),
        )
        # Create and start the engine service
        mock_config = engine_args.create_engine_config()
        self.manager = ResourceManagerV1(
            max_num_seqs=4,
            config=mock_config,
            tensor_parallel_size=1,
            splitwise_role="mixed",
            local_data_parallel_id=0,
        )
        # Mock cache manager
        self.manager.cache_manager = Mock()
        self.manager.cache_manager.free_blocks = Mock()

    def tearDown(self) -> None:
        # Clear the manager's signal so repeated test runs start clean
        # (NOTE(review): presumably an inter-process signal -- verify).
        self.manager.need_block_num_signal.clear()

    def test_preempted_all_with_no_running_requests(self):
        """Test preempted_all with no running requests."""
        self.assertEqual(len(self.manager.running), 0)
        preempted_reqs = self.manager.preempted_all()
        self.assertEqual(len(preempted_reqs), 0)

    def test_preempted_all_with_normal_requests(self):
        """Test preempted_all with normal running requests."""
        # Add mock running requests
        req1 = Mock(spec=Request)
        req1.request_id = "req1"
        req1.use_extend_tables = False
        req1.status = RequestStatus.RUNNING
        req1.block_tables = [1, 2, 3]
        req1.num_cached_blocks = 0
        req1.idx = 0
        req2 = Mock(spec=Request)
        req2.request_id = "req2"
        req2.use_extend_tables = False
        req2.status = RequestStatus.RUNNING
        req2.block_tables = [4, 5]
        req2.num_cached_blocks = 0
        req2.idx = 1
        self.manager.running = [req1, req2]
        preempted_reqs = self.manager.preempted_all()
        # Verify: requests are preempted in reverse order (last running first)
        self.assertEqual(len(preempted_reqs), 2)
        self.assertEqual(preempted_reqs[0].request_id, "req2")
        self.assertEqual(preempted_reqs[1].request_id, "req1")
        # Verify request status changed
        self.assertEqual(req1.status, RequestStatus.PREEMPTED)
        self.assertEqual(req2.status, RequestStatus.PREEMPTED)
        # Verify added to to_be_rescheduled_request_id_set
        self.assertIn("req1", self.manager.to_be_rescheduled_request_id_set)
        self.assertIn("req2", self.manager.to_be_rescheduled_request_id_set)
        self.assertEqual(len(self.manager.running), 0)
        self.assertEqual(len(self.manager.waiting), 0)
        self.assertEqual(len(self.manager.to_be_rescheduled_request_id_set), 2)
if __name__ == "__main__":
unittest.main()
+14
View File
@@ -246,6 +246,20 @@ class TestLocalScheduler(unittest.TestCase):
# Verify only one request exists in scheduler # Verify only one request exists in scheduler
self.assertEqual(len(self.scheduler.requests), 1) self.assertEqual(len(self.scheduler.requests), 1)
def test_get_inflight_requests(self):
"""Test getting inflight requests."""
# Add some requests
requests = [self.mock_request_1, self.mock_request_2]
self.scheduler.put_requests(requests)
# Get inflight requests
inflight_requests = self.scheduler.get_inflight_requests()
# Verify correct requests are returned
self.assertEqual(len(inflight_requests), len(requests))
for req in inflight_requests:
self.assertIn(req, requests)
def test_put_requests_max_size_limit(self): def test_put_requests_max_size_limit(self):
"""Test that max size limit is enforced.""" """Test that max size limit is enforced."""
# Create scheduler with small max size # Create scheduler with small max size