mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[RL] add pause, update_weights, resume interface for async RL (#6052)
* support dynamic run_control_request through zmq from apiserver to common_engine * support pause/resume/is_paused/update_weights in apiserver->common_engine by common run_control_method * change /is_puased from HTTP POST method to GET method * add pause、resume、is_paused implementation * support engine <==> worker communication(request&response) * support sync weights through RDMA from checkpoint_transfer * support specified version, rsync_config in update_weights rpc call * add pause, update_weights, resume interface for async RL * bug fix: update_weights support using default arguments * fix typo * typo fix * typo fix * typo fix * add unitest for control request/response, localscheduler.get_inflight_requests, resource_manager_v1.preempted_all * add "rsync" to LoadConfig.load_strategy Literal type hints Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * typo fix * typo fix * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * check version/rsync params * add error log when version.txt not exists Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * raise specified ValueError when paramters check failed Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * tp barrier after run_control_method * encode 'engine_worker_queue_port' to unique name of worker2engine fmq queue * typo fix * typo fix --------- Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -1225,7 +1225,8 @@ class LoadConfig:
|
|||||||
):
|
):
|
||||||
self.load_choices: Union[str, LoadChoices] = LoadChoices.DEFAULT.value
|
self.load_choices: Union[str, LoadChoices] = LoadChoices.DEFAULT.value
|
||||||
self.dynamic_load_weight: bool = False
|
self.dynamic_load_weight: bool = False
|
||||||
self.load_strategy: Optional[Literal["ipc", "ipc_snapshot", "meta", "normal"]] = "normal"
|
self.load_strategy: Optional[Literal["ipc", "ipc_snapshot", "meta", "normal", "rsync"]] = "normal"
|
||||||
|
self.rsync_config: Optional[Dict[str, Any]] = None
|
||||||
for key, value in args.items():
|
for key, value in args.items():
|
||||||
if hasattr(self, key):
|
if hasattr(self, key):
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
|||||||
@@ -195,6 +195,10 @@ class EngineArgs:
|
|||||||
"""
|
"""
|
||||||
dynamic load weight strategy
|
dynamic load weight strategy
|
||||||
"""
|
"""
|
||||||
|
rsync_config: Optional[Dict[str, Any]] = None
|
||||||
|
"""
|
||||||
|
rsync weights config info
|
||||||
|
"""
|
||||||
quantization: Optional[Dict[str, Any]] = None
|
quantization: Optional[Dict[str, Any]] = None
|
||||||
guided_decoding_backend: str = "off"
|
guided_decoding_backend: str = "off"
|
||||||
"""
|
"""
|
||||||
@@ -812,6 +816,12 @@ class EngineArgs:
|
|||||||
default=EngineArgs.load_strategy,
|
default=EngineArgs.load_strategy,
|
||||||
help="Flag to dynamic load strategy.",
|
help="Flag to dynamic load strategy.",
|
||||||
)
|
)
|
||||||
|
model_group.add_argument(
|
||||||
|
"--rsync-config",
|
||||||
|
type=json.loads,
|
||||||
|
default=EngineArgs.rsync_config,
|
||||||
|
help="Rsync weights config",
|
||||||
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
"--engine-worker-queue-port",
|
"--engine-worker-queue-port",
|
||||||
type=lambda s: s.split(",") if s else None,
|
type=lambda s: s.split(",") if s else None,
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
@@ -38,7 +39,14 @@ import zmq
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import fastdeploy.metrics.trace as tracing
|
import fastdeploy.metrics.trace as tracing
|
||||||
from fastdeploy.engine.request import Request, RequestOutput, RequestStatus, RequestType
|
from fastdeploy.engine.request import (
|
||||||
|
ControlRequest,
|
||||||
|
ControlResponse,
|
||||||
|
Request,
|
||||||
|
RequestOutput,
|
||||||
|
RequestStatus,
|
||||||
|
RequestType,
|
||||||
|
)
|
||||||
from fastdeploy.engine.resource_manager import ResourceManager
|
from fastdeploy.engine.resource_manager import ResourceManager
|
||||||
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
|
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
|
||||||
from fastdeploy.eplb.utils import init_eplb_signals
|
from fastdeploy.eplb.utils import init_eplb_signals
|
||||||
@@ -50,6 +58,7 @@ from fastdeploy.inter_communicator import (
|
|||||||
ZmqIpcServer,
|
ZmqIpcServer,
|
||||||
ZmqTcpServer,
|
ZmqTcpServer,
|
||||||
)
|
)
|
||||||
|
from fastdeploy.inter_communicator.fmq import FMQ
|
||||||
from fastdeploy.metrics.metrics import main_process_metrics
|
from fastdeploy.metrics.metrics import main_process_metrics
|
||||||
from fastdeploy.model_executor.guided_decoding import schema_checker
|
from fastdeploy.model_executor.guided_decoding import schema_checker
|
||||||
from fastdeploy.plugins.token_processor import load_token_processor_plugins
|
from fastdeploy.plugins.token_processor import load_token_processor_plugins
|
||||||
@@ -89,6 +98,18 @@ class EngineService:
|
|||||||
else:
|
else:
|
||||||
self.llm_logger = llm_logger
|
self.llm_logger = llm_logger
|
||||||
|
|
||||||
|
self.is_paused = False # pause request generation
|
||||||
|
self._pause_cond = threading.Condition()
|
||||||
|
|
||||||
|
self._ctrl_worker_output_queues = []
|
||||||
|
tp_size = cfg.parallel_config.tensor_parallel_size
|
||||||
|
dp_index = cfg.parallel_config.local_data_parallel_id
|
||||||
|
for rank in range(tp_size):
|
||||||
|
engine_worker_queue_port = self.cfg.parallel_config.local_engine_worker_queue_port
|
||||||
|
name = f"ctrl_w2e_rank{rank+tp_size*dp_index}_{engine_worker_queue_port}"
|
||||||
|
self.llm_logger.info(f"Init Worker Control Output Queue: {name}(consumer)")
|
||||||
|
self._ctrl_worker_output_queues.append(FMQ().queue(name, "consumer"))
|
||||||
|
|
||||||
self.scheduler = cfg.scheduler_config.scheduler()
|
self.scheduler = cfg.scheduler_config.scheduler()
|
||||||
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
|
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
|
||||||
|
|
||||||
@@ -758,6 +779,8 @@ class EngineService:
|
|||||||
|
|
||||||
def _fetch_request():
|
def _fetch_request():
|
||||||
try:
|
try:
|
||||||
|
with self._pause_cond:
|
||||||
|
self._pause_cond.wait_for(lambda: not self.is_paused)
|
||||||
nonlocal is_fetching
|
nonlocal is_fetching
|
||||||
num_prefill_batch = min(
|
num_prefill_batch = min(
|
||||||
int(self.resource_manager.available_batch()),
|
int(self.resource_manager.available_batch()),
|
||||||
@@ -922,6 +945,8 @@ class EngineService:
|
|||||||
is_fetching = False
|
is_fetching = False
|
||||||
|
|
||||||
while self.running:
|
while self.running:
|
||||||
|
with self._pause_cond:
|
||||||
|
self._pause_cond.wait_for(lambda: not self.is_paused)
|
||||||
try:
|
try:
|
||||||
if self.engine_worker_queue.exist_tasks():
|
if self.engine_worker_queue.exist_tasks():
|
||||||
time.sleep(0.001)
|
time.sleep(0.001)
|
||||||
@@ -1065,6 +1090,17 @@ class EngineService:
|
|||||||
self.recv_request_server = ZmqIpcServer(name=self.api_server_pid, mode=zmq.PULL)
|
self.recv_request_server = ZmqIpcServer(name=self.api_server_pid, mode=zmq.PULL)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if ControlRequest.is_control_request(data):
|
||||||
|
try: # todo: run control request async, do not block request generation
|
||||||
|
control_req = ControlRequest.from_dict(data)
|
||||||
|
self.run_control_method(control_req)
|
||||||
|
except Exception as e:
|
||||||
|
self.llm_logger.error(
|
||||||
|
f"Failed to process control request {data.get('request_id')}: "
|
||||||
|
f"{e}, {traceback.format_exc()}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
request, insert_task = data, []
|
request, insert_task = data, []
|
||||||
results: List[Tuple[str, Optional[str]]] = list()
|
results: List[Tuple[str, Optional[str]]] = list()
|
||||||
if data:
|
if data:
|
||||||
@@ -1096,6 +1132,13 @@ class EngineService:
|
|||||||
trace_print(LoggingEventName.REQUEST_SCHEDULE_START, data["request_id"], data.get("user", ""))
|
trace_print(LoggingEventName.REQUEST_SCHEDULE_START, data["request_id"], data.get("user", ""))
|
||||||
trace_print(LoggingEventName.REQUEST_QUEUE_START, data["request_id"], data.get("user", ""))
|
trace_print(LoggingEventName.REQUEST_QUEUE_START, data["request_id"], data.get("user", ""))
|
||||||
self.llm_logger.debug(f"Receive request from api server: {request}")
|
self.llm_logger.debug(f"Receive request from api server: {request}")
|
||||||
|
|
||||||
|
if self.is_paused:
|
||||||
|
self.llm_logger.warning(f"Engine is paused, drop request: {request}")
|
||||||
|
self._send_error_response(
|
||||||
|
request.request_id, "Request is aborted since LLM Engine is paused."
|
||||||
|
)
|
||||||
|
continue
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}")
|
self.llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}")
|
||||||
err_msg = str(e)
|
err_msg = str(e)
|
||||||
@@ -1135,6 +1178,200 @@ class EngineService:
|
|||||||
f"traceback={traceback.format_exc()}"
|
f"traceback={traceback.format_exc()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def run_control_method(self, control_req: ControlRequest):
|
||||||
|
"""
|
||||||
|
Execute control method, process control request and return response.
|
||||||
|
|
||||||
|
This method is responsible for handling control requests, calling the corresponding
|
||||||
|
handler function based on the method name in the request. If the method doesn't exist
|
||||||
|
or is not callable, it returns an error response; otherwise executes the method and
|
||||||
|
returns a success response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
control_req (ControlRequest): Control request object containing request ID,
|
||||||
|
method name and parameters.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None: No return value, sends ControlResponse through send_response_server.
|
||||||
|
"""
|
||||||
|
method = control_req.get_method()
|
||||||
|
request_id = control_req.request_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.llm_logger.info(f"START run control method {request_id}: {method}")
|
||||||
|
|
||||||
|
handler_name = f"_control_{method}"
|
||||||
|
handler = getattr(self, handler_name, None)
|
||||||
|
if handler is None or not callable(handler):
|
||||||
|
error_result = ControlResponse(request_id, 400, f"unknown control method:{method}")
|
||||||
|
self.llm_logger.error(str(error_result))
|
||||||
|
self.send_response_server.send_response(request_id, [error_result])
|
||||||
|
return
|
||||||
|
|
||||||
|
result = handler(control_req)
|
||||||
|
self.llm_logger.info(f"SUCCESS run control method {method}.")
|
||||||
|
succ_result = ControlResponse(request_id, 200, "Success", result)
|
||||||
|
self.send_response_server.send_response(request_id, [succ_result])
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Failed run control method {method}: {str(e)}"
|
||||||
|
self.llm_logger.error(f"{error_msg}\n{traceback.format_exc()}")
|
||||||
|
error_result = ControlResponse(request_id, 500, error_msg)
|
||||||
|
self.send_response_server.send_response(request_id, [error_result])
|
||||||
|
|
||||||
|
def _control_pause(self, control_request: ControlRequest):
|
||||||
|
"""Pauses the LLM engine and aborts all running/inflight requests.
|
||||||
|
Args:
|
||||||
|
control_request: The control request containing pause command
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If pause is not supported in current configuration
|
||||||
|
Exception: If engine worker queue cleanup times out
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||||
|
raise Exception("pause only supported in ENABLE_V1_KVCACHE_SCHEDULER")
|
||||||
|
if self.cfg.scheduler_config.name != "local":
|
||||||
|
raise Exception(f"pause only supported in local scheduler, current {self.cfg.scheduler_config.name}")
|
||||||
|
|
||||||
|
with self._pause_cond:
|
||||||
|
if self.is_paused:
|
||||||
|
self.llm_logger.info("Pause Request Generation: already paused.")
|
||||||
|
self.is_paused = True
|
||||||
|
|
||||||
|
self.llm_logger.info("Start Abort Running Requests")
|
||||||
|
|
||||||
|
self.resource_manager.log_status()
|
||||||
|
# preempted all running reqs. preempted reqs will be append to ResourceManager.waiting queue
|
||||||
|
timeout, count = 60, 0
|
||||||
|
while self.engine_worker_queue.exist_tasks():
|
||||||
|
time.sleep(0.001)
|
||||||
|
count += 1
|
||||||
|
if count >= timeout * 1000:
|
||||||
|
break
|
||||||
|
if count >= timeout * 1000:
|
||||||
|
error_msg = f"wait engine_worker_queue tasks empty timeout after {timeout} seconds, worker may Hanged"
|
||||||
|
self.llm_logger.error(error_msg)
|
||||||
|
raise Exception(error_msg)
|
||||||
|
running_reqs = self.resource_manager.preempted_all()
|
||||||
|
if len(running_reqs) > 0:
|
||||||
|
self.llm_logger.info(f"Total {len(running_reqs)} requests need to be aborted.")
|
||||||
|
self.resource_manager.get_real_bsz()
|
||||||
|
self.engine_worker_queue.put_tasks((running_reqs, self.resource_manager.real_bsz))
|
||||||
|
self.resource_manager.wait_worker_inflight_requests_finish(timeout=60)
|
||||||
|
# self.engine_worker_queue.clear_data()
|
||||||
|
self.token_processor.clear_data()
|
||||||
|
self.resource_manager.log_status()
|
||||||
|
|
||||||
|
# abort inflight requests to user
|
||||||
|
inflight_requests = self.scheduler.get_inflight_requests()
|
||||||
|
self.llm_logger.info(f"Start Abort Inflight Requests, total {len(inflight_requests)} waiting requests")
|
||||||
|
for req in inflight_requests:
|
||||||
|
self._send_error_response(req.request_id, "Request is aborted since LLM Engine is paused.")
|
||||||
|
self.scheduler.reset()
|
||||||
|
|
||||||
|
self.resource_manager.cache_manager.reset()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _control_resume(self, control_request: ControlRequest) -> Optional[dict]:
|
||||||
|
"""Control function for resuming request generation.
|
||||||
|
|
||||||
|
This method resumes the paused request generation process by setting the pause flag
|
||||||
|
and notifying all waiting threads. It logs the start and end of the resume operation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
control_request: Control request object containing resume operation information
|
||||||
|
"""
|
||||||
|
self.llm_logger.info("START Resume Request Generation")
|
||||||
|
with self._pause_cond:
|
||||||
|
if not self.is_paused:
|
||||||
|
self.llm_logger.info("Resume Request Generation: not paused.")
|
||||||
|
return None
|
||||||
|
self.is_paused = False
|
||||||
|
self._pause_cond.notify_all()
|
||||||
|
self.llm_logger.info("END Resume Request Generation")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _control_is_paused(self, control_request: ControlRequest) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the LLM engine is in paused state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
control_request: Control request object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Dictionary containing pause status information, {'is_paused': bool}
|
||||||
|
"""
|
||||||
|
self.llm_logger.info(f"LLM Engine request generation is paused: {self.is_paused}")
|
||||||
|
with self._pause_cond:
|
||||||
|
return {"is_paused": self.is_paused}
|
||||||
|
|
||||||
|
def _control_update_weights(self, control_request: ControlRequest) -> Optional[dict]:
|
||||||
|
"""Update model weights
|
||||||
|
Args:
|
||||||
|
control_request: Control request object containing parameters for weight updates
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[dict]: Returns the result dictionary if update succeeds, None otherwise
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: Raised when the engine is not in paused state
|
||||||
|
"""
|
||||||
|
self.llm_logger.info("Update Model Weights")
|
||||||
|
with self._pause_cond:
|
||||||
|
if self.is_paused is False:
|
||||||
|
error_msg = "Pause LLM Engine first before calling updating weights"
|
||||||
|
self.llm_logger.error(error_msg)
|
||||||
|
raise Exception(error_msg)
|
||||||
|
return self._call_worker(control_request, 60)
|
||||||
|
|
||||||
|
async def _wait_all_control_responses(self, request_id: str, timeout: int):
|
||||||
|
"""Wait for control responses from all workers with a global timeout.
|
||||||
|
|
||||||
|
This method concurrently waits for responses from all control workers
|
||||||
|
and enforces an overall timeout to avoid leaking pending tasks.
|
||||||
|
"""
|
||||||
|
timeout_ms = timeout * 1000
|
||||||
|
# Create one get() coroutine per worker output queue
|
||||||
|
tasks = [output_queue.get(timeout=timeout_ms) for output_queue in self._ctrl_worker_output_queues]
|
||||||
|
|
||||||
|
try:
|
||||||
|
results = await asyncio.wait_for(
|
||||||
|
asyncio.gather(*tasks, return_exceptions=True),
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# Keep the error message consistent with previous behavior
|
||||||
|
raise Exception("Worker Update Weights Timeouted after 600s")
|
||||||
|
|
||||||
|
responses = []
|
||||||
|
for output_queue, msg in zip(self._ctrl_worker_output_queues, results):
|
||||||
|
if isinstance(msg, Exception):
|
||||||
|
self.llm_logger.error(f"Call Worker Failed: {output_queue.name} {repr(msg)}")
|
||||||
|
raise Exception(f"Call Worker error: {repr(msg)}")
|
||||||
|
if msg is None:
|
||||||
|
# Preserve original semantics when no message is received
|
||||||
|
raise Exception("Worker Update Weights Timeouted after 600s")
|
||||||
|
response: ControlResponse = msg.payload
|
||||||
|
if response.request_id != request_id:
|
||||||
|
self.llm_logger.info(f"ignore old control response from worker:{output_queue.name} {response}")
|
||||||
|
continue
|
||||||
|
if response.error_code != 200:
|
||||||
|
self.llm_logger.info(f"Call Worker Failed: {output_queue.name} {response.error_message}")
|
||||||
|
raise Exception(f"Call Worker error: {response.error_message}")
|
||||||
|
self.llm_logger.info(f"Call Worker Succeed: {output_queue.name} {response.result}")
|
||||||
|
responses.append(response.result)
|
||||||
|
return responses
|
||||||
|
|
||||||
|
def _call_worker(self, control_request: ControlRequest, timeout: int):
|
||||||
|
request_id = control_request.request_id
|
||||||
|
self.engine_worker_queue.put_tasks(([control_request], 1))
|
||||||
|
# Use a single asyncio.run() to concurrently wait for all worker responses.
|
||||||
|
return asyncio.run(self._wait_all_control_responses(request_id, timeout))
|
||||||
|
|
||||||
def _send_error_response(self, request_id, error_msg, error_code: int = 500):
|
def _send_error_response(self, request_id, error_msg, error_code: int = 500):
|
||||||
self.llm_logger.error(
|
self.llm_logger.error(
|
||||||
f"Send error response to client, request_id: {request_id}, error_msg: {error_msg}, error_code: {error_code}"
|
f"Send error response to client, request_id: {request_id}, error_msg: {error_msg}, error_code: {error_code}"
|
||||||
@@ -1708,6 +1945,7 @@ class EngineService:
|
|||||||
f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'"
|
f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'"
|
||||||
f" --guided_decoding_backend {self.cfg.structured_outputs_config.guided_decoding_backend}"
|
f" --guided_decoding_backend {self.cfg.structured_outputs_config.guided_decoding_backend}"
|
||||||
f" --load_strategy {self.cfg.load_config.load_strategy}"
|
f" --load_strategy {self.cfg.load_config.load_strategy}"
|
||||||
|
f" --rsync_config '{json.dumps(self.cfg.load_config.rsync_config)}'"
|
||||||
f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
|
f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
|
||||||
f" --reasoning_parser {self.cfg.structured_outputs_config.reasoning_parser}"
|
f" --reasoning_parser {self.cfg.structured_outputs_config.reasoning_parser}"
|
||||||
f" --load_choices {self.cfg.load_config.load_choices}"
|
f" --load_choices {self.cfg.load_config.load_choices}"
|
||||||
|
|||||||
@@ -556,6 +556,7 @@ class LLMEngine:
|
|||||||
f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'"
|
f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'"
|
||||||
f" --guided_decoding_backend {self.cfg.structured_outputs_config.guided_decoding_backend}"
|
f" --guided_decoding_backend {self.cfg.structured_outputs_config.guided_decoding_backend}"
|
||||||
f" --load_strategy {self.cfg.load_config.load_strategy}"
|
f" --load_strategy {self.cfg.load_config.load_strategy}"
|
||||||
|
f" --rsync_config '{json.dumps(self.cfg.load_config.rsync_config)}'"
|
||||||
f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
|
f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
|
||||||
f" --reasoning_parser {self.cfg.structured_outputs_config.reasoning_parser}"
|
f" --reasoning_parser {self.cfg.structured_outputs_config.reasoning_parser}"
|
||||||
f" --load_choices {self.cfg.load_config.load_choices}"
|
f" --load_choices {self.cfg.load_config.load_choices}"
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from typing import TypeVar as TypingTypeVar
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing_extensions import TypeVar
|
from typing_extensions import TypeVar
|
||||||
|
|
||||||
@@ -99,7 +100,7 @@ class Request:
|
|||||||
guided_json_object: Optional[bool] = None,
|
guided_json_object: Optional[bool] = None,
|
||||||
enable_thinking: Optional[bool] = None,
|
enable_thinking: Optional[bool] = None,
|
||||||
reasoning_max_tokens: Optional[int] = None,
|
reasoning_max_tokens: Optional[int] = None,
|
||||||
trace_carrier: dict = dict(),
|
trace_carrier: Optional[Dict[str, Any]] = None,
|
||||||
dp_rank: Optional[int] = None,
|
dp_rank: Optional[int] = None,
|
||||||
chat_template: Optional[str] = None,
|
chat_template: Optional[str] = None,
|
||||||
image_start: int = 0,
|
image_start: int = 0,
|
||||||
@@ -544,6 +545,157 @@ class Request:
|
|||||||
return hasattr(self, key)
|
return hasattr(self, key)
|
||||||
|
|
||||||
|
|
||||||
|
class ControlRequest:
|
||||||
|
"""A generic control request that supports method and args for control operations.
|
||||||
|
|
||||||
|
This request type is used for system-level control operations rather than
|
||||||
|
typical inference requests. It enables dynamic control of engine behavior,
|
||||||
|
resource management, and system configuration via a flexible method-args interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
request_id: str,
|
||||||
|
method: str,
|
||||||
|
args: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
request_id: Unique identifier for the control request.
|
||||||
|
method: The control method to execute (e.g., "reset_scheduler", "get_metrics").
|
||||||
|
args: Optional arguments for the control method.
|
||||||
|
"""
|
||||||
|
self.request_id = request_id
|
||||||
|
self.method = method
|
||||||
|
self.args = args or {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d: dict):
|
||||||
|
"""Create ControlRequest instance from dictionary."""
|
||||||
|
return cls(request_id=d["request_id"], method=d["method"], args=d.get("args", {}))
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
"""Convert ControlRequest into a serializable dict."""
|
||||||
|
return {"request_id": self.request_id, "method": self.method, "args": self.args}
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
"""Provide a clean representation of the control request."""
|
||||||
|
try:
|
||||||
|
if not envs.FD_DEBUG:
|
||||||
|
return f"ControlRequest(request_id={self.request_id}, method={self.method})"
|
||||||
|
else:
|
||||||
|
return (
|
||||||
|
f"ControlRequest("
|
||||||
|
f"request_id={self.request_id}, "
|
||||||
|
f"method={self.method}, "
|
||||||
|
f"args={self.args}"
|
||||||
|
f")"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return f"<ControlRequest repr failed: {e}>"
|
||||||
|
|
||||||
|
def get_method(self) -> str:
|
||||||
|
"""Get the control method name."""
|
||||||
|
return self.method
|
||||||
|
|
||||||
|
def get_args(self) -> Dict[str, Any]:
|
||||||
|
"""Get the control method arguments."""
|
||||||
|
return self.args.copy()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_control_request(d: dict) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a dictionary represents a valid ControlRequest.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
d: Dictionary to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the dictionary contains the required fields for a ControlRequest
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Check if all required fields are present and have correct types
|
||||||
|
if not isinstance(d, dict):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check field types
|
||||||
|
if "request_id" not in d or not isinstance(d.get("request_id"), str):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if "method" not in d or not isinstance(d.get("method"), str):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Args is optional, but if present should be a dict
|
||||||
|
if "args" in d and not isinstance(d["args"], dict):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class ControlResponse:
|
||||||
|
"""
|
||||||
|
Response for control operations
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
request_id: str,
|
||||||
|
error_code: int = 200,
|
||||||
|
error_message: Optional[str] = None,
|
||||||
|
result: Optional[dict] = None,
|
||||||
|
finished: bool = True,
|
||||||
|
) -> None:
|
||||||
|
self.request_id = request_id
|
||||||
|
self.finished = finished
|
||||||
|
self.error_message = error_message
|
||||||
|
self.result = result
|
||||||
|
self.error_code = error_code
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
"""Convert ControlResponse into a serializable dict."""
|
||||||
|
return {
|
||||||
|
"request_id": self.request_id,
|
||||||
|
"finished": self.finished,
|
||||||
|
"error_code": self.error_code,
|
||||||
|
"error_message": self.error_message,
|
||||||
|
"result": self.result,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d: dict):
|
||||||
|
"""Create ControlResponse instance from dictionary."""
|
||||||
|
return cls(
|
||||||
|
request_id=d["request_id"],
|
||||||
|
finished=d.get("finished", True),
|
||||||
|
error_code=d.get("error_code", 200),
|
||||||
|
error_message=d.get("error_message"),
|
||||||
|
result=d.get("result"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_api_json_response(self) -> JSONResponse:
|
||||||
|
"""Convert ControlResponse into a JSONResponse."""
|
||||||
|
status = "success" if self.error_code == 200 else "error"
|
||||||
|
content = {
|
||||||
|
"request_id": self.request_id,
|
||||||
|
"status": status,
|
||||||
|
"error_message": self.error_message,
|
||||||
|
"result": self.result,
|
||||||
|
}
|
||||||
|
return JSONResponse(status_code=self.error_code, content=content)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
"""Provide a clean representation of the control response."""
|
||||||
|
return (
|
||||||
|
f"ControlResponse("
|
||||||
|
f"request_id={self.request_id}, "
|
||||||
|
f"finished={self.finished}, "
|
||||||
|
f"error_code={self.error_code}, "
|
||||||
|
f"error_message={self.error_message}, "
|
||||||
|
f"result={self.result}"
|
||||||
|
f")"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
class CompletionOutput:
|
class CompletionOutput:
|
||||||
"""The output data of one completion output of a request.
|
"""The output data of one completion output of a request.
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
@@ -240,6 +241,7 @@ class ResourceManagerV1(ResourceManager):
|
|||||||
|
|
||||||
def reschedule_preempt_task(self, request_id, process_func=None):
|
def reschedule_preempt_task(self, request_id, process_func=None):
|
||||||
with self.lock:
|
with self.lock:
|
||||||
|
llm_logger.debug(f"reschedule {request_id} into waiting queue")
|
||||||
if request_id in self.to_be_rescheduled_request_id_set and request_id in self.requests:
|
if request_id in self.to_be_rescheduled_request_id_set and request_id in self.requests:
|
||||||
request = self.requests[request_id]
|
request = self.requests[request_id]
|
||||||
if process_func is not None:
|
if process_func is not None:
|
||||||
@@ -266,6 +268,39 @@ class ResourceManagerV1(ResourceManager):
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def preempted_all(self):
|
||||||
|
with self.lock:
|
||||||
|
preempted_reqs = []
|
||||||
|
for i in range(len(self.running)):
|
||||||
|
req = self.running.pop()
|
||||||
|
# txt2image: req.use_extend_tables is True, req can not be preempted. txt2image is not used in RL.
|
||||||
|
if req.use_extend_tables:
|
||||||
|
self.running.insert(0, req)
|
||||||
|
continue
|
||||||
|
req.status = RequestStatus.PREEMPTED
|
||||||
|
req.num_computed_tokens = 0
|
||||||
|
self._free_blocks(req)
|
||||||
|
req.cached_block_num = 0
|
||||||
|
self.to_be_rescheduled_request_id_set.add(req.request_id)
|
||||||
|
preempted_reqs.append(self._prepare_preempt_task(req))
|
||||||
|
return preempted_reqs
|
||||||
|
|
||||||
|
def wait_worker_inflight_requests_finish(self, timeout=60):
|
||||||
|
count = 0
|
||||||
|
while count < timeout * 1000:
|
||||||
|
# wait ongoing running and rescheduled requests finished in worker
|
||||||
|
running_reqs_count = len(self.to_be_rescheduled_request_id_set) + len(self.running)
|
||||||
|
if running_reqs_count == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
count += 1
|
||||||
|
time.sleep(0.001)
|
||||||
|
if count >= timeout * 1000:
|
||||||
|
llm_logger.info(
|
||||||
|
f"wait_inflight_requests_finish timeout after {timeout} seconds, "
|
||||||
|
f"still {len(self.to_be_rescheduled_request_id_set)} requests running"
|
||||||
|
)
|
||||||
|
|
||||||
def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_reqs):
|
def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_reqs):
|
||||||
"""
|
"""
|
||||||
If the request cannot be scheduled, preempt the running request one by one until it can be scheduled. Last in, first out.
|
If the request cannot be scheduled, preempt the running request one by one until it can be scheduled. Last in, first out.
|
||||||
@@ -1347,3 +1382,16 @@ class ResourceManagerV1(ResourceManager):
|
|||||||
main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
|
main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc())
|
||||||
main_process_metrics.num_requests_running.set(len(self.running))
|
main_process_metrics.num_requests_running.set(len(self.running))
|
||||||
main_process_metrics.num_requests_waiting.set(num_tasks - len(self.running))
|
main_process_metrics.num_requests_waiting.set(num_tasks - len(self.running))
|
||||||
|
|
||||||
|
def log_status(self):
|
||||||
|
llm_logger.info(
|
||||||
|
f"ResourceManagerV1( "
|
||||||
|
f"waiting={len(self.waiting)}, "
|
||||||
|
f"running={len(self.running)}, "
|
||||||
|
f"preempted={len(self.to_be_rescheduled_request_id_set)}, "
|
||||||
|
f"tasks_list={self.tasks_list}, "
|
||||||
|
f"stop_flags={self.stop_flags}, "
|
||||||
|
f"req_dict={self.req_dict}, "
|
||||||
|
f"requests={self.requests}, "
|
||||||
|
f")"
|
||||||
|
)
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
@@ -29,7 +30,12 @@ from filelock import FileLock
|
|||||||
import fastdeploy.metrics.trace as tracing
|
import fastdeploy.metrics.trace as tracing
|
||||||
from fastdeploy import envs
|
from fastdeploy import envs
|
||||||
from fastdeploy.config import FDConfig
|
from fastdeploy.config import FDConfig
|
||||||
from fastdeploy.engine.request import Request, RequestStatus
|
from fastdeploy.engine.request import (
|
||||||
|
ControlRequest,
|
||||||
|
ControlResponse,
|
||||||
|
Request,
|
||||||
|
RequestStatus,
|
||||||
|
)
|
||||||
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
|
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
|
||||||
from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
|
from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
|
||||||
from fastdeploy.eplb.utils import RedundantExpertWorkload
|
from fastdeploy.eplb.utils import RedundantExpertWorkload
|
||||||
@@ -526,6 +532,23 @@ class EngineClient:
|
|||||||
|
|
||||||
return True, ""
|
return True, ""
|
||||||
|
|
||||||
|
async def run_control_method(self, request: ControlRequest):
|
||||||
|
api_server_logger.info(f"Start Run Control Method: {request}")
|
||||||
|
self.zmq_client.send_json(request.to_dict())
|
||||||
|
request_id = request.request_id
|
||||||
|
dealer, response_queue = await self.connection_manager.get_connection(request_id)
|
||||||
|
dealer.write([b"", request_id.encode("utf-8")])
|
||||||
|
try:
|
||||||
|
# todo: support user specified timeout. default 600s is enough for most control cases
|
||||||
|
response = await asyncio.wait_for(response_queue.get(), timeout=600)
|
||||||
|
response = ControlResponse.from_dict(response[0])
|
||||||
|
api_server_logger.info(f"End Run Control Method: {response}")
|
||||||
|
return response
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
error_response = ControlResponse(request_id, 500, "Timeout waiting for control method response")
|
||||||
|
api_server_logger.error(f"Error Run Control Method: {error_response}")
|
||||||
|
return error_response
|
||||||
|
|
||||||
def is_workers_alive(self):
|
def is_workers_alive(self):
|
||||||
"""
|
"""
|
||||||
Check the health of the model server by checking whether all workers are alive.
|
Check the health of the model server by checking whether all workers are alive.
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import signal
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
import uuid
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
@@ -38,6 +39,7 @@ from fastdeploy.engine.args_utils import EngineArgs
|
|||||||
from fastdeploy.engine.async_llm import AsyncLLM
|
from fastdeploy.engine.async_llm import AsyncLLM
|
||||||
from fastdeploy.engine.engine import LLMEngine
|
from fastdeploy.engine.engine import LLMEngine
|
||||||
from fastdeploy.engine.expert_service import ExpertService
|
from fastdeploy.engine.expert_service import ExpertService
|
||||||
|
from fastdeploy.engine.request import ControlRequest
|
||||||
from fastdeploy.entrypoints.chat_utils import load_chat_template
|
from fastdeploy.entrypoints.chat_utils import load_chat_template
|
||||||
from fastdeploy.entrypoints.engine_client import EngineClient
|
from fastdeploy.entrypoints.engine_client import EngineClient
|
||||||
from fastdeploy.entrypoints.openai.middleware import AuthenticationMiddleware
|
from fastdeploy.entrypoints.openai.middleware import AuthenticationMiddleware
|
||||||
@@ -370,6 +372,66 @@ def ping(raw_request: Request) -> Response:
|
|||||||
return health(raw_request)
|
return health(raw_request)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/pause")
|
||||||
|
async def pause(request: Request) -> Response:
|
||||||
|
# todo: support wait_for_inflight_requests(default False), clear_cache(default True) arguments
|
||||||
|
request_id = f"control-{uuid.uuid4()}"
|
||||||
|
control_request = ControlRequest(request_id, "pause")
|
||||||
|
control_response = await app.state.engine_client.run_control_method(control_request)
|
||||||
|
return control_response.to_api_json_response()
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/resume")
|
||||||
|
async def resume(request: Request) -> Response:
|
||||||
|
request_id = f"control-{uuid.uuid4()}"
|
||||||
|
control_request = ControlRequest(request_id, "resume")
|
||||||
|
control_response = await app.state.engine_client.run_control_method(control_request)
|
||||||
|
return control_response.to_api_json_response()
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/v1/is_paused")
|
||||||
|
async def is_paused(request: Request) -> Response:
|
||||||
|
request_id = f"control-{uuid.uuid4()}"
|
||||||
|
control_request = ControlRequest(request_id, "is_paused")
|
||||||
|
control_response = await app.state.engine_client.run_control_method(control_request)
|
||||||
|
return control_response.to_api_json_response()
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/update_weights")
|
||||||
|
async def update_weights(request: Request) -> Response:
|
||||||
|
request_id = f"control-{uuid.uuid4()}"
|
||||||
|
|
||||||
|
request_data = await request.json() if await request.body() else {}
|
||||||
|
|
||||||
|
args = {}
|
||||||
|
|
||||||
|
# Validate and extract version parameter
|
||||||
|
if "version" in request_data and request_data["version"] is not None:
|
||||||
|
if not isinstance(request_data["version"], str):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=400, content={"error": "Invalid parameter type", "message": "version must be a string"}
|
||||||
|
)
|
||||||
|
args["version"] = request_data["version"]
|
||||||
|
|
||||||
|
# Validate and extract rsync_config parameter
|
||||||
|
if "rsync_config" in request_data and request_data["rsync_config"] is not None:
|
||||||
|
if not isinstance(request_data["rsync_config"], dict):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content={"error": "Invalid parameter type", "message": "rsync_config must be a dictionary"},
|
||||||
|
)
|
||||||
|
if "etcd_server" not in request_data["rsync_config"]:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content={"error": "Invalid parameter type", "message": "rsync_config must contain etcd_server"},
|
||||||
|
)
|
||||||
|
args["rsync_config"] = request_data["rsync_config"]
|
||||||
|
|
||||||
|
control_request = ControlRequest(request_id, "update_weights", args)
|
||||||
|
control_response = await app.state.engine_client.run_control_method(control_request)
|
||||||
|
return control_response.to_api_json_response()
|
||||||
|
|
||||||
|
|
||||||
def wrap_streaming_generator(original_generator: AsyncGenerator):
|
def wrap_streaming_generator(original_generator: AsyncGenerator):
|
||||||
"""
|
"""
|
||||||
Wrap an async generator to release the connection semaphore when the generator is finished.
|
Wrap an async generator to release the connection semaphore when the generator is finished.
|
||||||
|
|||||||
@@ -214,7 +214,7 @@ class Queue(BaseComponent):
|
|||||||
else:
|
else:
|
||||||
self.socket.bind(full_ep)
|
self.socket.bind(full_ep)
|
||||||
|
|
||||||
fmq_logger.info(f"Queue {name} initialized on {full_ep}")
|
fmq_logger.info(f"Queue {name}({role}) initialized on {full_ep}")
|
||||||
|
|
||||||
async def put(self, data: Any, shm_threshold: int = 1024 * 1024):
|
async def put(self, data: Any, shm_threshold: int = 1024 * 1024):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1034,7 +1034,7 @@ class TokenProcessor:
|
|||||||
finished=True,
|
finished=True,
|
||||||
metrics=RequestMetrics(
|
metrics=RequestMetrics(
|
||||||
arrival_time=time.time(),
|
arrival_time=time.time(),
|
||||||
request_start_time=task.arrival_time,
|
request_start_time=task.metrics.arrival_time,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"
|
is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import io
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from multiprocessing.shared_memory import SharedMemory
|
from multiprocessing.shared_memory import SharedMemory
|
||||||
@@ -27,13 +28,38 @@ from fastdeploy.config import FDConfig
|
|||||||
from fastdeploy.inter_communicator import KVCacheStatus, ModelWeightsStatus
|
from fastdeploy.inter_communicator import KVCacheStatus, ModelWeightsStatus
|
||||||
|
|
||||||
|
|
||||||
|
def sync_weights_by_rdma(config, step, rank):
|
||||||
|
from checkpoint_transfer.core import RDMAWeightsDownloader
|
||||||
|
|
||||||
|
downloader = RDMAWeightsDownloader(config)
|
||||||
|
downloader.initialize()
|
||||||
|
logger.info(f"Fetching weights for step:{step}, rank:{rank}...")
|
||||||
|
data = downloader.get_weights(step, rank)
|
||||||
|
if data is None:
|
||||||
|
logger.error("Failed to get weights!")
|
||||||
|
raise Exception("Failed to rsync weights through checkpoint_transfer")
|
||||||
|
logger.info(f"Successfully retrieved data. Type: {type(data)}")
|
||||||
|
if isinstance(data, np.ndarray):
|
||||||
|
data_bytes = data.tobytes()
|
||||||
|
elif isinstance(data, (bytes, bytearray)):
|
||||||
|
data_bytes = data
|
||||||
|
else:
|
||||||
|
data_bytes = bytes(data)
|
||||||
|
logger.info(f"Data size: {len(data_bytes)} bytes")
|
||||||
|
|
||||||
|
buffer = io.BytesIO(data_bytes)
|
||||||
|
new_state_dict = paddle.load(buffer)
|
||||||
|
return new_state_dict
|
||||||
|
|
||||||
|
|
||||||
class DynamicWeightManager:
|
class DynamicWeightManager:
|
||||||
"""Manages model weights loading, updating and shared state across processes."""
|
"""Manages model weights loading, updating and shared state across processes."""
|
||||||
|
|
||||||
def __init__(self, fd_config: FDConfig, models):
|
def __init__(self, fd_config: FDConfig, models, local_rank: int):
|
||||||
"""Initialize with config and model instances."""
|
"""Initialize with config and model instances."""
|
||||||
self.fd_config = fd_config
|
self.fd_config = fd_config
|
||||||
self.load_config = fd_config.load_config
|
self.load_config = fd_config.load_config
|
||||||
|
self.local_rank = local_rank
|
||||||
self.parallel_config = fd_config.parallel_config
|
self.parallel_config = fd_config.parallel_config
|
||||||
self.state_dict: Dict[str, paddle.Tensor] = {}
|
self.state_dict: Dict[str, paddle.Tensor] = {}
|
||||||
self.rank = fd_config.parallel_config.tensor_parallel_rank
|
self.rank = fd_config.parallel_config.tensor_parallel_rank
|
||||||
@@ -46,7 +72,10 @@ class DynamicWeightManager:
|
|||||||
else:
|
else:
|
||||||
self.model_list = models
|
self.model_list = models
|
||||||
self._capture_model_state()
|
self._capture_model_state()
|
||||||
self.update_parameters()
|
if self.load_config.load_strategy == "rsync":
|
||||||
|
self.update_weights_by_rdma()
|
||||||
|
else:
|
||||||
|
self.update_parameters()
|
||||||
self.finalize_update()
|
self.finalize_update()
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -62,6 +91,74 @@ class DynamicWeightManager:
|
|||||||
logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
|
logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
|
||||||
self.state_dict[name] = param
|
self.state_dict[name] = param
|
||||||
|
|
||||||
|
def update_weights_by_rdma(self, version: str = None, rsync_config: Dict[str, Any] = None):
|
||||||
|
def valid_parameters(old_state_dict, new_state_dict):
|
||||||
|
is_valid = True
|
||||||
|
for key in old_state_dict:
|
||||||
|
if key not in new_state_dict:
|
||||||
|
is_valid = False
|
||||||
|
logger.error(f"Invalid parameter: {key} not in new_state_dict")
|
||||||
|
elif old_state_dict[key].shape != new_state_dict[key].shape:
|
||||||
|
is_valid = False
|
||||||
|
logger.error(
|
||||||
|
f"Invalid parameter: {key} shape mismatch, "
|
||||||
|
f"new shape:{new_state_dict[key].shape}, "
|
||||||
|
f"old shape:{old_state_dict[key].shape}"
|
||||||
|
)
|
||||||
|
elif old_state_dict[key].dtype != new_state_dict[key].dtype:
|
||||||
|
is_valid = False
|
||||||
|
logger.error(f"Invalid parameter: {key} dtype mismatch")
|
||||||
|
return is_valid
|
||||||
|
|
||||||
|
if rsync_config is None:
|
||||||
|
rsync_config = self.fd_config.load_config.rsync_config
|
||||||
|
if rsync_config is None or len(rsync_config) == 0:
|
||||||
|
raise Exception(
|
||||||
|
"rsync config not set, please set it in 1) launch arguments '--rsync-config' "
|
||||||
|
"or 2) interface arguments 'rsync_config'"
|
||||||
|
)
|
||||||
|
|
||||||
|
if version is None or version == "":
|
||||||
|
version = self.read_model_version_from_file()
|
||||||
|
if version is None or version == "":
|
||||||
|
raise Exception(
|
||||||
|
"rsync model version not set, please set it in 1) {model_version}/version.txt "
|
||||||
|
"or 2) interface arguments 'version'"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"START update_weights_by_rdma, version:{version}, rsync_config:{rsync_config}")
|
||||||
|
rank = self.local_rank
|
||||||
|
|
||||||
|
sync_start = time.perf_counter()
|
||||||
|
new_state_dict = sync_weights_by_rdma(rsync_config, version, rank)
|
||||||
|
sync_cost = time.perf_counter() - sync_start
|
||||||
|
logger.info(f"weights sync cost {sync_cost:.2f} seconds")
|
||||||
|
|
||||||
|
old_state_dict = self.state_dict
|
||||||
|
if not valid_parameters(old_state_dict, new_state_dict):
|
||||||
|
error_msg = "Invalid new_state_dict, update parameters failed"
|
||||||
|
logger.error(error_msg)
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
update_start = time.perf_counter()
|
||||||
|
for name, param in old_state_dict.items():
|
||||||
|
param.set_value(new_state_dict[name])
|
||||||
|
update_cost = time.perf_counter() - update_start
|
||||||
|
logger.info(f"params set value cost {update_cost:.2f} seconds")
|
||||||
|
|
||||||
|
total_cost = time.perf_counter() - sync_start
|
||||||
|
logger.info(
|
||||||
|
f"END update_weights_by_rdma, cost {total_cost:.2f} seconds"
|
||||||
|
f" version:{version}, rsync_config: {rsync_config}",
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"sync_cost": sync_cost,
|
||||||
|
"update_cost": update_cost,
|
||||||
|
"total_cost": total_cost,
|
||||||
|
"version": version,
|
||||||
|
"rank": rank,
|
||||||
|
}
|
||||||
|
|
||||||
def update_parameters(self, pid: int = 0, restart_process_group=False) -> None:
|
def update_parameters(self, pid: int = 0, restart_process_group=False) -> None:
|
||||||
"""Core method to update model parameters based on strategy."""
|
"""Core method to update model parameters based on strategy."""
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
@@ -257,6 +354,17 @@ class DynamicWeightManager:
|
|||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
value[self.rank] = status
|
value[self.rank] = status
|
||||||
|
|
||||||
|
def read_model_version_from_file(self):
|
||||||
|
model_dir = self.fd_config.model_config.model
|
||||||
|
version_file = os.path.join(model_dir, "version.txt")
|
||||||
|
try:
|
||||||
|
with open(version_file, "r", encoding="utf-8") as f:
|
||||||
|
version = f.read().strip()
|
||||||
|
return version
|
||||||
|
except (FileNotFoundError, OSError, IOError) as e:
|
||||||
|
logger.error(f"Failed to read model version file '{version_file}': {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_model_weights_status(model_weights_status, kv_cache_status, model_runner, pid, block):
|
def check_model_weights_status(model_weights_status, kv_cache_status, model_runner, pid, block):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -158,6 +158,10 @@ class LocalScheduler:
|
|||||||
else:
|
else:
|
||||||
self.ids_read_cursor -= len(expired_ids)
|
self.ids_read_cursor -= len(expired_ids)
|
||||||
|
|
||||||
|
def get_inflight_requests(self) -> List[Request]:
|
||||||
|
with self.mutex:
|
||||||
|
return [request.raw for request in self.requests.values()]
|
||||||
|
|
||||||
def put_requests(self, requests: List[Request]) -> List[Tuple[str, Optional[str]]]:
|
def put_requests(self, requests: List[Request]) -> List[Tuple[str, Optional[str]]]:
|
||||||
"""
|
"""
|
||||||
Add new requests to the scheduler queue.
|
Add new requests to the scheduler queue.
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import queue
|
|||||||
import time
|
import time
|
||||||
from concurrent.futures import Future
|
from concurrent.futures import Future
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import List, Optional, cast
|
from typing import Any, Dict, List, Optional, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import paddle
|
import paddle
|
||||||
@@ -1518,7 +1518,7 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
if self.fd_config.load_config.dynamic_load_weight:
|
if self.fd_config.load_config.dynamic_load_weight:
|
||||||
from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager
|
from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager
|
||||||
|
|
||||||
self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model)
|
self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model, self.local_rank)
|
||||||
|
|
||||||
# 2. Load lora model
|
# 2. Load lora model
|
||||||
|
|
||||||
@@ -2798,6 +2798,9 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
|
|
||||||
self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")
|
self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")
|
||||||
|
|
||||||
|
def update_weights(self, version: str = None, rsync_config: Dict[str, Any] = None):
|
||||||
|
return self.dynamic_weight_manager.update_weights_by_rdma(version, rsync_config)
|
||||||
|
|
||||||
def padding_cudagraph_inputs(self) -> None:
|
def padding_cudagraph_inputs(self) -> None:
|
||||||
"""
|
"""
|
||||||
Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch.
|
Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch.
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
|
|
||||||
import gc
|
import gc
|
||||||
import time
|
import time
|
||||||
from typing import List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
import pynvml
|
import pynvml
|
||||||
@@ -188,6 +188,10 @@ class GpuWorker(WorkerBase):
|
|||||||
# accurate cache size
|
# accurate cache size
|
||||||
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
|
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
|
||||||
|
|
||||||
|
def update_weights(self, version: str = None, rsync_config: Dict[str, Any] = None):
|
||||||
|
"""update weights in place"""
|
||||||
|
return self.model_runner.update_weights(version, rsync_config)
|
||||||
|
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self,
|
self,
|
||||||
model_forward_batch: Optional[List[Request]] = None,
|
model_forward_batch: Optional[List[Request]] = None,
|
||||||
|
|||||||
@@ -15,9 +15,11 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import traceback
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -42,7 +44,7 @@ from fastdeploy.config import (
|
|||||||
SpeculativeConfig,
|
SpeculativeConfig,
|
||||||
StructuredOutputsConfig,
|
StructuredOutputsConfig,
|
||||||
)
|
)
|
||||||
from fastdeploy.engine.request import RequestType
|
from fastdeploy.engine.request import ControlRequest, ControlResponse, RequestType
|
||||||
from fastdeploy.eplb.async_expert_loader import (
|
from fastdeploy.eplb.async_expert_loader import (
|
||||||
MODEL_MAIN_NAME,
|
MODEL_MAIN_NAME,
|
||||||
REARRANGE_EXPERT_MAGIC_NUM,
|
REARRANGE_EXPERT_MAGIC_NUM,
|
||||||
@@ -57,6 +59,7 @@ from fastdeploy.inter_communicator import (
|
|||||||
ModelWeightsStatus,
|
ModelWeightsStatus,
|
||||||
RearrangeExpertStatus,
|
RearrangeExpertStatus,
|
||||||
)
|
)
|
||||||
|
from fastdeploy.inter_communicator.fmq import FMQ
|
||||||
from fastdeploy.model_executor.layers.quantization import parse_quant_config
|
from fastdeploy.model_executor.layers.quantization import parse_quant_config
|
||||||
from fastdeploy.model_executor.utils import v1_loader_support
|
from fastdeploy.model_executor.utils import v1_loader_support
|
||||||
from fastdeploy.platforms import current_platform
|
from fastdeploy.platforms import current_platform
|
||||||
@@ -164,6 +167,12 @@ class PaddleDisWorkerProc:
|
|||||||
|
|
||||||
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
|
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
|
||||||
|
|
||||||
|
def init_control(self):
|
||||||
|
engine_worker_queue_port = self.parallel_config.local_engine_worker_queue_port
|
||||||
|
queue_name = f"ctrl_w2e_rank{self.local_rank}_{engine_worker_queue_port}"
|
||||||
|
logger.info(f"Init Control Output Queue: {queue_name}(producer)")
|
||||||
|
self._ctrl_output = FMQ().queue(queue_name, "producer")
|
||||||
|
|
||||||
def init_health_status(self) -> None:
|
def init_health_status(self) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize the health status of the worker.
|
Initialize the health status of the worker.
|
||||||
@@ -513,10 +522,20 @@ class PaddleDisWorkerProc:
|
|||||||
else:
|
else:
|
||||||
self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY
|
self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY
|
||||||
|
|
||||||
req_dicts = []
|
req_dicts, control_reqs = [], []
|
||||||
for req_dict, bsz in tasks:
|
for req_dict, bsz in tasks:
|
||||||
max_occupied_batch_index = int(bsz)
|
if len(req_dict) > 0 and isinstance(req_dict[0], ControlRequest):
|
||||||
req_dicts.extend(req_dict)
|
control_reqs.append(req_dict[0])
|
||||||
|
else:
|
||||||
|
max_occupied_batch_index = int(bsz)
|
||||||
|
req_dicts.extend(req_dict)
|
||||||
|
|
||||||
|
# todo: run control request async
|
||||||
|
if len(control_reqs) > 0:
|
||||||
|
logger.info(f"Rank: {self.local_rank} received {len(control_reqs)} control request.")
|
||||||
|
for control_req in control_reqs:
|
||||||
|
self.run_control_method(control_req)
|
||||||
|
self._tp_barrier_wait() if tp_size > 1 else None
|
||||||
|
|
||||||
# Count prefill requests in current batch
|
# Count prefill requests in current batch
|
||||||
num_prefill_requests = sum(1 for req in req_dicts if req.task_type == RequestType.PREFILL)
|
num_prefill_requests = sum(1 for req in req_dicts if req.task_type == RequestType.PREFILL)
|
||||||
@@ -655,6 +674,32 @@ class PaddleDisWorkerProc:
|
|||||||
paddle.distributed.barrier()
|
paddle.distributed.barrier()
|
||||||
self.loaded_model_signal.value[0] = 1
|
self.loaded_model_signal.value[0] = 1
|
||||||
|
|
||||||
|
def run_control_method(self, control_request: ControlRequest) -> None:
|
||||||
|
logger.info(f"Start run control request: {control_request}")
|
||||||
|
request_id = control_request.request_id
|
||||||
|
method = control_request.method
|
||||||
|
kwargs = control_request.args
|
||||||
|
|
||||||
|
handler = getattr(self.worker, method, None)
|
||||||
|
if handler is None or not callable(handler):
|
||||||
|
error_msg = f"Rank-{self.local_rank}: Unknown control method {method}"
|
||||||
|
error_result = ControlResponse(request_id, 400, error_msg)
|
||||||
|
asyncio.run(self._ctrl_output.put(error_result))
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = handler(**kwargs)
|
||||||
|
succ_result = ControlResponse(request_id, 200, "Success", result)
|
||||||
|
logger.info(
|
||||||
|
f"Rank-{self.local_rank} Success run control request: {control_request}, response: {succ_result}"
|
||||||
|
)
|
||||||
|
asyncio.run(self._ctrl_output.put(succ_result, shm_threshold=100 * 1024 * 1024))
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Rank-{self.local_rank} Failed run control method {method}: {str(e)}"
|
||||||
|
logger.error(f"{error_msg}\n{traceback.format_exc()}")
|
||||||
|
error_result = ControlResponse(request_id, 500, error_msg)
|
||||||
|
asyncio.run(self._ctrl_output.put(error_result))
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
"""
|
"""
|
||||||
@@ -813,12 +858,18 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--load_strategy",
|
"--load_strategy",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["ipc", "ipc_snapshot", "meta", "normal"],
|
choices=["ipc", "ipc_snapshot", "meta", "normal", "rsync"],
|
||||||
default="ipc_snapshot",
|
default="ipc_snapshot",
|
||||||
help="Weight loading method when dynamic loading is enabled: "
|
help="Weight loading method when dynamic loading is enabled: "
|
||||||
"'ipc': real-time IPC streaming with automatic resharding, "
|
"'ipc': real-time IPC streaming with automatic resharding, "
|
||||||
"'ipc_snapshot': load from disk snapshot of IPC weights.",
|
"'ipc_snapshot': load from disk snapshot of IPC weights.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--rsync_config",
|
||||||
|
type=json.loads,
|
||||||
|
default=None,
|
||||||
|
help="Rsync weights config",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--enable_logprob",
|
"--enable_logprob",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
@@ -1045,6 +1096,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
|
|||||||
|
|
||||||
logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
|
logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
|
||||||
logger.info(f"- Load strategy: {load_config.load_strategy}")
|
logger.info(f"- Load strategy: {load_config.load_strategy}")
|
||||||
|
logger.info(f"- Rsync config: {load_config.rsync_config}, {type(load_config.rsync_config)}")
|
||||||
|
|
||||||
if not (
|
if not (
|
||||||
current_platform.is_cuda()
|
current_platform.is_cuda()
|
||||||
@@ -1112,6 +1164,7 @@ def run_worker_proc() -> None:
|
|||||||
worker_proc = IluvatarPaddleDisWorkerProc(fd_config, ranks, local_rank)
|
worker_proc = IluvatarPaddleDisWorkerProc(fd_config, ranks, local_rank)
|
||||||
else:
|
else:
|
||||||
worker_proc = PaddleDisWorkerProc(fd_config, ranks, local_rank)
|
worker_proc = PaddleDisWorkerProc(fd_config, ranks, local_rank)
|
||||||
|
worker_proc.init_control()
|
||||||
|
|
||||||
# Initialize device and create model runner
|
# Initialize device and create model runner
|
||||||
worker_proc.init_device()
|
worker_proc.init_device()
|
||||||
|
|||||||
@@ -0,0 +1,329 @@
|
|||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
from fastdeploy.engine.request import ControlRequest, ControlResponse
|
||||||
|
|
||||||
|
|
||||||
|
class TestControlRequest(unittest.TestCase):
|
||||||
|
"""Test cases for ControlRequest class."""
|
||||||
|
|
||||||
|
def test_initialization_basic(self):
|
||||||
|
"""Test basic initialization of ControlRequest."""
|
||||||
|
request_id = "test_request_123"
|
||||||
|
method = "get_metrics"
|
||||||
|
|
||||||
|
request = ControlRequest(request_id=request_id, method=method)
|
||||||
|
|
||||||
|
self.assertEqual(request.request_id, request_id)
|
||||||
|
self.assertEqual(request.method, method)
|
||||||
|
self.assertEqual(request.args, {})
|
||||||
|
|
||||||
|
def test_initialization_with_args(self):
|
||||||
|
"""Test initialization with arguments."""
|
||||||
|
request_id = "test_request_456"
|
||||||
|
method = "reset_scheduler"
|
||||||
|
args = {"force": True, "timeout": 30}
|
||||||
|
|
||||||
|
request = ControlRequest(request_id=request_id, method=method, args=args)
|
||||||
|
|
||||||
|
self.assertEqual(request.request_id, request_id)
|
||||||
|
self.assertEqual(request.method, method)
|
||||||
|
self.assertEqual(request.args, args)
|
||||||
|
|
||||||
|
def test_from_dict_basic(self):
|
||||||
|
"""Test creating ControlRequest from dictionary (basic case)."""
|
||||||
|
data = {"request_id": "test_from_dict", "method": "status_check"}
|
||||||
|
|
||||||
|
request = ControlRequest.from_dict(data)
|
||||||
|
|
||||||
|
self.assertEqual(request.request_id, data["request_id"])
|
||||||
|
self.assertEqual(request.method, data["method"])
|
||||||
|
self.assertEqual(request.args, {})
|
||||||
|
|
||||||
|
def test_from_dict_with_args(self):
|
||||||
|
"""Test creating ControlRequest from dictionary with arguments."""
|
||||||
|
data = {
|
||||||
|
"request_id": "test_from_dict_args",
|
||||||
|
"method": "configure",
|
||||||
|
"args": {"max_requests": 100, "queue_timeout": 60},
|
||||||
|
}
|
||||||
|
|
||||||
|
request = ControlRequest.from_dict(data)
|
||||||
|
|
||||||
|
self.assertEqual(request.request_id, data["request_id"])
|
||||||
|
self.assertEqual(request.method, data["method"])
|
||||||
|
self.assertEqual(request.args, data["args"])
|
||||||
|
|
||||||
|
def test_to_dict_basic(self):
|
||||||
|
"""Test converting ControlRequest to dictionary (basic case)."""
|
||||||
|
request = ControlRequest(request_id="test_to_dict", method="health_check")
|
||||||
|
|
||||||
|
result = request.to_dict()
|
||||||
|
|
||||||
|
expected = {"request_id": "test_to_dict", "method": "health_check", "args": {}}
|
||||||
|
self.assertEqual(result, expected)
|
||||||
|
|
||||||
|
def test_to_dict_with_args(self):
|
||||||
|
"""Test converting ControlRequest to dictionary with arguments."""
|
||||||
|
args = {"setting1": "value1", "setting2": 42}
|
||||||
|
request = ControlRequest(request_id="test_to_dict_args", method="update_settings", args=args)
|
||||||
|
|
||||||
|
result = request.to_dict()
|
||||||
|
|
||||||
|
expected = {"request_id": "test_to_dict_args", "method": "update_settings", "args": args}
|
||||||
|
self.assertEqual(result, expected)
|
||||||
|
|
||||||
|
def test_get_method(self):
|
||||||
|
"""Test get_method method."""
|
||||||
|
method = "custom_operation"
|
||||||
|
request = ControlRequest(request_id="test", method=method)
|
||||||
|
|
||||||
|
self.assertEqual(request.get_method(), method)
|
||||||
|
|
||||||
|
def test_get_args(self):
|
||||||
|
"""Test get_args method."""
|
||||||
|
args = {"param1": "value1", "param2": 123}
|
||||||
|
request = ControlRequest(request_id="test", method="test", args=args)
|
||||||
|
|
||||||
|
result_args = request.get_args()
|
||||||
|
|
||||||
|
self.assertEqual(result_args, args)
|
||||||
|
# Ensure it returns a copy, not the original dict
|
||||||
|
self.assertIsNot(result_args, args)
|
||||||
|
|
||||||
|
def test_is_control_request_valid(self):
|
||||||
|
"""Test is_control_request method with valid data."""
|
||||||
|
valid_data = [
|
||||||
|
{"request_id": "test1", "method": "method1"},
|
||||||
|
{"request_id": "test2", "method": "method2", "args": {}},
|
||||||
|
{"request_id": "test3", "method": "method3", "args": {"key": "value"}},
|
||||||
|
]
|
||||||
|
|
||||||
|
for data in valid_data:
|
||||||
|
with self.subTest(data=data):
|
||||||
|
self.assertTrue(ControlRequest.is_control_request(data))
|
||||||
|
|
||||||
|
def test_is_control_request_invalid(self):
|
||||||
|
"""Test is_control_request method with invalid data."""
|
||||||
|
invalid_data = [
|
||||||
|
# Missing required fields
|
||||||
|
{"method": "test"}, # missing request_id
|
||||||
|
{"request_id": "test"}, # missing method
|
||||||
|
# Wrong field types
|
||||||
|
{"request_id": 123, "method": "test"}, # request_id not string
|
||||||
|
{"request_id": "test", "method": 456}, # method not string
|
||||||
|
{"request_id": "test", "method": "test", "args": "not_a_dict"}, # args not dict
|
||||||
|
# Not a dict
|
||||||
|
"not_a_dict",
|
||||||
|
123,
|
||||||
|
None,
|
||||||
|
]
|
||||||
|
|
||||||
|
for data in invalid_data:
|
||||||
|
with self.subTest(data=data):
|
||||||
|
self.assertFalse(ControlRequest.is_control_request(data))
|
||||||
|
|
||||||
|
def test_repr_simple(self):
|
||||||
|
"""Test __repr__ method in simple mode."""
|
||||||
|
with patch("fastdeploy.envs.FD_DEBUG", False):
|
||||||
|
request = ControlRequest(request_id="test_repr", method="test_method")
|
||||||
|
repr_str = repr(request)
|
||||||
|
|
||||||
|
self.assertIn("ControlRequest", repr_str)
|
||||||
|
self.assertIn("test_repr", repr_str)
|
||||||
|
self.assertIn("test_method", repr_str)
|
||||||
|
self.assertNotIn("args", repr_str) # Args not shown in simple mode
|
||||||
|
|
||||||
|
def test_repr_debug_mode(self):
|
||||||
|
"""Test __repr__ method in debug mode."""
|
||||||
|
with patch("fastdeploy.envs.FD_DEBUG", True):
|
||||||
|
args = {"debug_param": "debug_value"}
|
||||||
|
request = ControlRequest(request_id="test_repr", method="test_method", args=args)
|
||||||
|
repr_str = repr(request)
|
||||||
|
|
||||||
|
self.assertIn("ControlRequest", repr_str)
|
||||||
|
self.assertIn("test_repr", repr_str)
|
||||||
|
self.assertIn("test_method", repr_str)
|
||||||
|
self.assertIn("debug_param", repr_str) # Args shown in debug mode
|
||||||
|
|
||||||
|
|
||||||
|
class TestControlResponse(unittest.TestCase):
|
||||||
|
"""Test cases for ControlResponse class."""
|
||||||
|
|
||||||
|
def test_initialization_basic(self):
|
||||||
|
"""Test basic initialization of ControlResponse."""
|
||||||
|
request_id = "test_response_123"
|
||||||
|
|
||||||
|
response = ControlResponse(request_id=request_id)
|
||||||
|
|
||||||
|
self.assertEqual(response.request_id, request_id)
|
||||||
|
self.assertEqual(response.error_code, 200)
|
||||||
|
self.assertIsNone(response.error_message)
|
||||||
|
self.assertIsNone(response.result)
|
||||||
|
self.assertTrue(response.finished)
|
||||||
|
|
||||||
|
def test_initialization_with_all_params(self):
|
||||||
|
"""Test initialization with all parameters."""
|
||||||
|
request_id = "test_response_456"
|
||||||
|
error_code = 404
|
||||||
|
error_message = "Not found"
|
||||||
|
result = {"data": "some_result"}
|
||||||
|
finished = False
|
||||||
|
|
||||||
|
response = ControlResponse(
|
||||||
|
request_id=request_id, error_code=error_code, error_message=error_message, result=result, finished=finished
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(response.request_id, request_id)
|
||||||
|
self.assertEqual(response.error_code, error_code)
|
||||||
|
self.assertEqual(response.error_message, error_message)
|
||||||
|
self.assertEqual(response.result, result)
|
||||||
|
self.assertEqual(response.finished, finished)
|
||||||
|
|
||||||
|
def test_initialization_error_cases(self):
|
||||||
|
"""Test initialization with various error codes."""
|
||||||
|
test_cases = [
|
||||||
|
(200, None, True), # Success case
|
||||||
|
(400, "Bad Request", False), # Client error
|
||||||
|
(500, "Internal Error", True), # Server error
|
||||||
|
]
|
||||||
|
|
||||||
|
for error_code, error_message, finished in test_cases:
|
||||||
|
with self.subTest(error_code=error_code):
|
||||||
|
response = ControlResponse(
|
||||||
|
request_id="test", error_code=error_code, error_message=error_message, finished=finished
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(response.error_code, error_code)
|
||||||
|
self.assertEqual(response.error_message, error_message)
|
||||||
|
self.assertEqual(response.finished, finished)
|
||||||
|
|
||||||
|
def test_from_dict_basic(self):
|
||||||
|
"""Test creating ControlResponse from dictionary (basic case)."""
|
||||||
|
data = {"request_id": "test_from_dict"}
|
||||||
|
|
||||||
|
response = ControlResponse.from_dict(data)
|
||||||
|
|
||||||
|
self.assertEqual(response.request_id, data["request_id"])
|
||||||
|
self.assertEqual(response.error_code, 200)
|
||||||
|
self.assertIsNone(response.error_message)
|
||||||
|
self.assertIsNone(response.result)
|
||||||
|
self.assertTrue(response.finished)
|
||||||
|
|
||||||
|
def test_from_dict_with_all_fields(self):
|
||||||
|
"""Test creating ControlResponse from dictionary with all fields."""
|
||||||
|
data = {
|
||||||
|
"request_id": "test_from_dict_full",
|
||||||
|
"error_code": 500,
|
||||||
|
"error_message": "Test error",
|
||||||
|
"result": {"key": "value"},
|
||||||
|
"finished": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
response = ControlResponse.from_dict(data)
|
||||||
|
|
||||||
|
self.assertEqual(response.request_id, data["request_id"])
|
||||||
|
self.assertEqual(response.error_code, data["error_code"])
|
||||||
|
self.assertEqual(response.error_message, data["error_message"])
|
||||||
|
self.assertEqual(response.result, data["result"])
|
||||||
|
self.assertEqual(response.finished, data["finished"])
|
||||||
|
|
||||||
|
def test_to_dict_basic(self):
|
||||||
|
"""Test converting ControlResponse to dictionary (basic case)."""
|
||||||
|
response = ControlResponse(request_id="test_to_dict")
|
||||||
|
|
||||||
|
result = response.to_dict()
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"request_id": "test_to_dict",
|
||||||
|
"finished": True,
|
||||||
|
"error_code": 200,
|
||||||
|
"error_message": None,
|
||||||
|
"result": None,
|
||||||
|
}
|
||||||
|
self.assertEqual(result, expected)
|
||||||
|
|
||||||
|
def test_to_dict_with_all_fields(self):
|
||||||
|
"""Test converting ControlResponse to dictionary with all fields."""
|
||||||
|
response = ControlResponse(
|
||||||
|
request_id="test_to_dict_full",
|
||||||
|
error_code=400,
|
||||||
|
error_message="Validation failed",
|
||||||
|
result={"valid": False, "reason": "missing_field"},
|
||||||
|
finished=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = response.to_dict()
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"request_id": "test_to_dict_full",
|
||||||
|
"finished": False,
|
||||||
|
"error_code": 400,
|
||||||
|
"error_message": "Validation failed",
|
||||||
|
"result": {"valid": False, "reason": "missing_field"},
|
||||||
|
}
|
||||||
|
self.assertEqual(result, expected)
|
||||||
|
|
||||||
|
def test_to_api_json_response_success(self):
|
||||||
|
"""Test converting to JSONResponse for successful response."""
|
||||||
|
result_data = {"metrics": {"cpu_usage": 0.5, "memory_used": 1024}}
|
||||||
|
response = ControlResponse(request_id="test_json_success", result=result_data)
|
||||||
|
|
||||||
|
json_response = response.to_api_json_response()
|
||||||
|
|
||||||
|
self.assertIsInstance(json_response, JSONResponse)
|
||||||
|
self.assertEqual(json_response.status_code, 200)
|
||||||
|
|
||||||
|
content = json_response.body.decode("utf-8")
|
||||||
|
self.assertIn("success", content)
|
||||||
|
self.assertIn("test_json_success", content)
|
||||||
|
self.assertIn("cpu_usage", content)
|
||||||
|
|
||||||
|
def test_to_api_json_response_error(self):
|
||||||
|
"""Test converting to JSONResponse for error response."""
|
||||||
|
response = ControlResponse(request_id="test_json_error", error_code=503, error_message="Service unavailable")
|
||||||
|
|
||||||
|
json_response = response.to_api_json_response()
|
||||||
|
|
||||||
|
self.assertIsInstance(json_response, JSONResponse)
|
||||||
|
self.assertEqual(json_response.status_code, 503)
|
||||||
|
|
||||||
|
content = json_response.body.decode("utf-8")
|
||||||
|
self.assertIn("error", content)
|
||||||
|
self.assertIn("test_json_error", content)
|
||||||
|
self.assertIn("Service unavailable", content)
|
||||||
|
|
||||||
|
def test_repr_method(self):
|
||||||
|
"""Test __repr__ method."""
|
||||||
|
response = ControlResponse(
|
||||||
|
request_id="test_repr", error_code=200, error_message=None, result={"data": "test"}, finished=True
|
||||||
|
)
|
||||||
|
|
||||||
|
repr_str = repr(response)
|
||||||
|
|
||||||
|
# Check that all important fields are represented
|
||||||
|
self.assertIn("ControlResponse", repr_str)
|
||||||
|
self.assertIn("test_repr", repr_str)
|
||||||
|
self.assertIn("200", repr_str)
|
||||||
|
self.assertIn("test", repr_str) # from result
|
||||||
|
self.assertIn("True", repr_str) # finished flag
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -0,0 +1,104 @@
|
|||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from fastdeploy.engine.args_utils import EngineArgs
|
||||||
|
from fastdeploy.engine.request import Request, RequestStatus
|
||||||
|
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
|
||||||
|
|
||||||
|
MODEL_NAME = os.getenv("MODEL_PATH", "/path/to/models") + "/ERNIE-4.5-0.3B-Paddle"
|
||||||
|
|
||||||
|
|
||||||
|
class TestResourceManagerV1(unittest.TestCase):
|
||||||
|
"""Test cases for ResourceManagerV1."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Set up test fixtures."""
|
||||||
|
engine_args = EngineArgs(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
max_model_len=8192,
|
||||||
|
tensor_parallel_size=1,
|
||||||
|
engine_worker_queue_port=int(os.getenv("FD_ENGINE_QUEUE_PORT", "6778")),
|
||||||
|
cache_queue_port=int(os.getenv("FD_CACHE_QUEUE_PORT", "6779")),
|
||||||
|
)
|
||||||
|
# Create and start the engine service
|
||||||
|
mock_config = engine_args.create_engine_config()
|
||||||
|
|
||||||
|
self.manager = ResourceManagerV1(
|
||||||
|
max_num_seqs=4,
|
||||||
|
config=mock_config,
|
||||||
|
tensor_parallel_size=1,
|
||||||
|
splitwise_role="mixed",
|
||||||
|
local_data_parallel_id=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock cache manager
|
||||||
|
self.manager.cache_manager = Mock()
|
||||||
|
self.manager.cache_manager.free_blocks = Mock()
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
self.manager.need_block_num_signal.clear()
|
||||||
|
|
||||||
|
def test_preempted_all_with_no_running_requests(self):
|
||||||
|
"""Test preempted_all with no running requests."""
|
||||||
|
self.assertEqual(len(self.manager.running), 0)
|
||||||
|
preempted_reqs = self.manager.preempted_all()
|
||||||
|
self.assertEqual(len(preempted_reqs), 0)
|
||||||
|
|
||||||
|
def test_preempted_all_with_normal_requests(self):
|
||||||
|
"""Test preempted_all with normal running requests."""
|
||||||
|
# Add mock running requests
|
||||||
|
req1 = Mock(spec=Request)
|
||||||
|
req1.request_id = "req1"
|
||||||
|
req1.use_extend_tables = False
|
||||||
|
req1.status = RequestStatus.RUNNING
|
||||||
|
req1.block_tables = [1, 2, 3]
|
||||||
|
req1.num_cached_blocks = 0
|
||||||
|
req1.idx = 0
|
||||||
|
|
||||||
|
req2 = Mock(spec=Request)
|
||||||
|
req2.request_id = "req2"
|
||||||
|
req2.use_extend_tables = False
|
||||||
|
req2.status = RequestStatus.RUNNING
|
||||||
|
req2.block_tables = [4, 5]
|
||||||
|
req2.num_cached_blocks = 0
|
||||||
|
req2.idx = 1
|
||||||
|
|
||||||
|
self.manager.running = [req1, req2]
|
||||||
|
|
||||||
|
preempted_reqs = self.manager.preempted_all()
|
||||||
|
|
||||||
|
# Verify
|
||||||
|
self.assertEqual(len(preempted_reqs), 2)
|
||||||
|
self.assertEqual(preempted_reqs[0].request_id, "req2")
|
||||||
|
self.assertEqual(preempted_reqs[1].request_id, "req1")
|
||||||
|
|
||||||
|
# Verify request status changed
|
||||||
|
self.assertEqual(req1.status, RequestStatus.PREEMPTED)
|
||||||
|
self.assertEqual(req2.status, RequestStatus.PREEMPTED)
|
||||||
|
|
||||||
|
# Verify added to to_be_rescheduled_request_id_set
|
||||||
|
self.assertIn("req1", self.manager.to_be_rescheduled_request_id_set)
|
||||||
|
self.assertIn("req2", self.manager.to_be_rescheduled_request_id_set)
|
||||||
|
|
||||||
|
self.assertEqual(len(self.manager.running), 0)
|
||||||
|
self.assertEqual(len(self.manager.waiting), 0)
|
||||||
|
self.assertEqual(len(self.manager.to_be_rescheduled_request_id_set), 2)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -246,6 +246,20 @@ class TestLocalScheduler(unittest.TestCase):
|
|||||||
# Verify only one request exists in scheduler
|
# Verify only one request exists in scheduler
|
||||||
self.assertEqual(len(self.scheduler.requests), 1)
|
self.assertEqual(len(self.scheduler.requests), 1)
|
||||||
|
|
||||||
|
def test_get_inflight_requests(self):
|
||||||
|
"""Test getting inflight requests."""
|
||||||
|
# Add some requests
|
||||||
|
requests = [self.mock_request_1, self.mock_request_2]
|
||||||
|
self.scheduler.put_requests(requests)
|
||||||
|
|
||||||
|
# Get inflight requests
|
||||||
|
inflight_requests = self.scheduler.get_inflight_requests()
|
||||||
|
|
||||||
|
# Verify correct requests are returned
|
||||||
|
self.assertEqual(len(inflight_requests), len(requests))
|
||||||
|
for req in inflight_requests:
|
||||||
|
self.assertIn(req, requests)
|
||||||
|
|
||||||
def test_put_requests_max_size_limit(self):
|
def test_put_requests_max_size_limit(self):
|
||||||
"""Test that max size limit is enforced."""
|
"""Test that max size limit is enforced."""
|
||||||
# Create scheduler with small max size
|
# Create scheduler with small max size
|
||||||
|
|||||||
Reference in New Issue
Block a user