mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[PD Disaggregation][RL] Register to router with version and support rdma eager connect for pd (#6718)
* [Feature] Register to router with version info for PD disaggregation Add RegisterManager for PD (Prefill-Decode) disaggregated deployment: - All instances (Prefill/Decode) register to Router with heartbeat - Prefill instances fetch Decode instance list from Router - Prefill instances establish eager RDMA connections to Decode instances - Register info includes: host_ip, port, role, version, is_paused, connected_decodes Changes: - Add RegisterManager class for managing PD registration and RDMA connections - Add version field to ModelConfig for model version tracking - Add connected_decodes to register_info for tracking connected Decode instances - Add FD_ENABLE_PD_RDMA_EAGER_CONNECT environment variable Test fixes: - Add None checks for load_config in FDConfig.__init__ - Add version attribute to test mock model configs Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * refine * remove test --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -34,11 +34,11 @@ from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import requests
|
||||
import zmq
|
||||
from tqdm import tqdm
|
||||
|
||||
import fastdeploy.metrics.trace as tracing
|
||||
from fastdeploy.engine.register_manager import RegisterManager
|
||||
from fastdeploy.engine.request import (
|
||||
ControlRequest,
|
||||
ControlResponse,
|
||||
@@ -64,7 +64,6 @@ from fastdeploy.inter_communicator.fmq import FMQ
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.model_executor.guided_decoding import schema_checker
|
||||
from fastdeploy.plugins.token_processor import load_token_processor_plugins
|
||||
from fastdeploy.router.utils import check_service_health
|
||||
from fastdeploy.spec_decode import SpecMethod
|
||||
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
|
||||
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
|
||||
@@ -178,6 +177,13 @@ class EngineService:
|
||||
# between the CPU transfer process and the worker process.
|
||||
self.resource_manager.cache_manager.gpu_cache_lock = self.gpu_cache_lock
|
||||
|
||||
# Initialize RegisterManager
|
||||
self._register_manager = RegisterManager(
|
||||
cfg=self.cfg,
|
||||
engine_worker_queue=self.engine_worker_queue,
|
||||
get_is_paused=self._get_is_paused_safe,
|
||||
)
|
||||
|
||||
if self.cfg.eplb_config.enable_eplb:
|
||||
current_suffix = self.cfg.parallel_config.local_engine_worker_queue_port
|
||||
init_eplb_signals(cfg, current_suffix)
|
||||
@@ -210,7 +216,7 @@ class EngineService:
|
||||
if self.cfg.scheduler_config.splitwise_role == "decode":
|
||||
self._decode_process_splitwise_requests()
|
||||
|
||||
self._register_to_router()
|
||||
self._register_manager.start()
|
||||
|
||||
def start_worker_service(self, async_llm_pid=None):
|
||||
# Initialize IPC signals for worker management
|
||||
@@ -1370,6 +1376,11 @@ class EngineService:
|
||||
with self._pause_cond:
|
||||
return {"is_paused": self.is_paused}
|
||||
|
||||
def _get_is_paused_safe(self) -> bool:
|
||||
"""Thread-safe getter for is_paused state, used by RegisterManager."""
|
||||
with self._pause_cond:
|
||||
return self.is_paused
|
||||
|
||||
def _control_update_weights(self, control_request: ControlRequest) -> Optional[dict]:
|
||||
"""Update model weights
|
||||
Args:
|
||||
@@ -1387,7 +1398,20 @@ class EngineService:
|
||||
error_msg = "Pause LLM Engine first before calling updating weights"
|
||||
self.llm_logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
return self._call_worker(control_request, 60)
|
||||
responses = self._call_worker(control_request, 60)
|
||||
|
||||
if responses:
|
||||
new_version = None
|
||||
for resp in responses:
|
||||
# Expect each worker response to be a dict-like object
|
||||
if isinstance(resp, dict) and "version" in resp:
|
||||
new_version = resp.get("version")
|
||||
self.llm_logger.info(f"Update Weights Version in Config: {new_version}")
|
||||
break
|
||||
if new_version is not None:
|
||||
self.cfg.model_config.version = new_version
|
||||
|
||||
return responses
|
||||
|
||||
async def _wait_all_control_responses(self, request_id: str, timeout: int):
|
||||
"""Wait for control responses from all workers with a global timeout.
|
||||
@@ -1699,55 +1723,6 @@ class EngineService:
|
||||
self.llm_logger.error(f"Clear data error: {e}")
|
||||
return False
|
||||
|
||||
def _register_to_router(self):
|
||||
"""
|
||||
Periodically send server information to the router for registeration, and it is used
|
||||
as a heartbeat signal.
|
||||
"""
|
||||
|
||||
def _register():
|
||||
timeout = 5
|
||||
sleep_seconds = 5
|
||||
is_registered = False
|
||||
|
||||
while True:
|
||||
try:
|
||||
api_server_host = self.cfg.router_config.api_server_host
|
||||
api_server_port = self.cfg.router_config.api_server_port
|
||||
api_server_url = f"http://{api_server_host}:{api_server_port}"
|
||||
if not check_service_health(api_server_url):
|
||||
time.sleep(sleep_seconds)
|
||||
self.llm_logger.info("Wait for API service health and then register to router")
|
||||
time.sleep(sleep_seconds)
|
||||
continue
|
||||
|
||||
router_url = self.cfg.router_config.router
|
||||
resp = requests.post(
|
||||
f"{router_url}/register",
|
||||
json=self.cfg.register_info,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
if resp.ok:
|
||||
if not is_registered:
|
||||
is_registered = True
|
||||
self.llm_logger.info("Register to router successfully")
|
||||
else:
|
||||
self.llm_logger.error(
|
||||
f"Send server info to router failed: {resp.status_code}, "
|
||||
f"{resp.text}, {self.cfg.register_info}"
|
||||
)
|
||||
except Exception as e:
|
||||
self.llm_logger.exception(f"Unexpected error during router registration: {e}")
|
||||
|
||||
time.sleep(sleep_seconds)
|
||||
|
||||
if self.cfg.router_config.router is None:
|
||||
self.llm_logger.info("Router is not enabled, skip registering to router")
|
||||
else:
|
||||
register_thread = threading.Thread(target=_register, daemon=True)
|
||||
register_thread.start()
|
||||
|
||||
def _exit_sub_services(self):
|
||||
"""
|
||||
exit sub services
|
||||
|
||||
Reference in New Issue
Block a user