[PD Disaggregation][RL] Register to router with version and support rdma eager connect for pd (#6718)

* [Feature] Register to router with version info for PD disaggregation

Add RegisterManager for PD (Prefill-Decode) disaggregated deployment:
- All instances (Prefill/Decode) register to Router with heartbeat
- Prefill instances fetch Decode instance list from Router
- Prefill instances establish eager RDMA connections to Decode instances
- Register info includes: host_ip, port, role, version, is_paused, connected_decodes

Changes:
- Add RegisterManager class for managing PD registration and RDMA connections
- Add version field to ModelConfig for model version tracking
- Add connected_decodes to register_info for tracking connected Decode instances
- Add FD_ENABLE_PD_RDMA_EAGER_CONNECT environment variable

Test fixes:
- Add None checks for load_config in FDConfig.__init__
- Add version attribute to test mock model configs

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* refine

* remove test

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
jc
2026-03-17 14:43:35 +08:00
committed by GitHub
parent b152baeeee
commit 950366e58d
14 changed files with 507 additions and 97 deletions
+28 -53
View File
@@ -34,11 +34,11 @@ from typing import Dict, List, Optional, Tuple
import numpy as np
import paddle
import requests
import zmq
from tqdm import tqdm
import fastdeploy.metrics.trace as tracing
from fastdeploy.engine.register_manager import RegisterManager
from fastdeploy.engine.request import (
ControlRequest,
ControlResponse,
@@ -64,7 +64,6 @@ from fastdeploy.inter_communicator.fmq import FMQ
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.model_executor.guided_decoding import schema_checker
from fastdeploy.plugins.token_processor import load_token_processor_plugins
from fastdeploy.router.utils import check_service_health
from fastdeploy.spec_decode import SpecMethod
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
@@ -178,6 +177,13 @@ class EngineService:
# between the CPU transfer process and the worker process.
self.resource_manager.cache_manager.gpu_cache_lock = self.gpu_cache_lock
# Initialize RegisterManager
self._register_manager = RegisterManager(
cfg=self.cfg,
engine_worker_queue=self.engine_worker_queue,
get_is_paused=self._get_is_paused_safe,
)
if self.cfg.eplb_config.enable_eplb:
current_suffix = self.cfg.parallel_config.local_engine_worker_queue_port
init_eplb_signals(cfg, current_suffix)
@@ -210,7 +216,7 @@ class EngineService:
if self.cfg.scheduler_config.splitwise_role == "decode":
self._decode_process_splitwise_requests()
self._register_to_router()
self._register_manager.start()
def start_worker_service(self, async_llm_pid=None):
# Initialize IPC signals for worker management
@@ -1370,6 +1376,11 @@ class EngineService:
with self._pause_cond:
return {"is_paused": self.is_paused}
def _get_is_paused_safe(self) -> bool:
"""Thread-safe getter for is_paused state, used by RegisterManager."""
with self._pause_cond:
return self.is_paused
def _control_update_weights(self, control_request: ControlRequest) -> Optional[dict]:
"""Update model weights
Args:
@@ -1387,7 +1398,20 @@ class EngineService:
error_msg = "Pause LLM Engine first before calling updating weights"
self.llm_logger.error(error_msg)
raise Exception(error_msg)
return self._call_worker(control_request, 60)
responses = self._call_worker(control_request, 60)
if responses:
new_version = None
for resp in responses:
# Expect each worker response to be a dict-like object
if isinstance(resp, dict) and "version" in resp:
new_version = resp.get("version")
self.llm_logger.info(f"Update Weights Version in Config: {new_version}")
break
if new_version is not None:
self.cfg.model_config.version = new_version
return responses
async def _wait_all_control_responses(self, request_id: str, timeout: int):
"""Wait for control responses from all workers with a global timeout.
@@ -1699,55 +1723,6 @@ class EngineService:
self.llm_logger.error(f"Clear data error: {e}")
return False
def _register_to_router(self):
"""
Periodically send server information to the router for registeration, and it is used
as a heartbeat signal.
"""
def _register():
timeout = 5
sleep_seconds = 5
is_registered = False
while True:
try:
api_server_host = self.cfg.router_config.api_server_host
api_server_port = self.cfg.router_config.api_server_port
api_server_url = f"http://{api_server_host}:{api_server_port}"
if not check_service_health(api_server_url):
time.sleep(sleep_seconds)
self.llm_logger.info("Wait for API service health and then register to router")
time.sleep(sleep_seconds)
continue
router_url = self.cfg.router_config.router
resp = requests.post(
f"{router_url}/register",
json=self.cfg.register_info,
timeout=timeout,
)
if resp.ok:
if not is_registered:
is_registered = True
self.llm_logger.info("Register to router successfully")
else:
self.llm_logger.error(
f"Send server info to router failed: {resp.status_code}, "
f"{resp.text}, {self.cfg.register_info}"
)
except Exception as e:
self.llm_logger.exception(f"Unexpected error during router registration: {e}")
time.sleep(sleep_seconds)
if self.cfg.router_config.router is None:
self.llm_logger.info("Router is not enabled, skip registering to router")
else:
register_thread = threading.Thread(target=_register, daemon=True)
register_thread.start()
def _exit_sub_services(self):
"""
exit sub services