[PD Disaggregation][RL] Register to router with version and support rdma eager connect for pd (#6718)

* [Feature] Register to router with version info for PD disaggregation

Add RegisterManager for PD (Prefill-Decode) disaggregated deployment:
- All instances (Prefill/Decode) register to Router with heartbeat
- Prefill instances fetch Decode instance list from Router
- Prefill instances establish eager RDMA connections to Decode instances
- Register info includes: host_ip, port, role, version, is_paused, connected_decodes

Changes:
- Add RegisterManager class for managing PD registration and RDMA connections
- Add version field to ModelConfig for model version tracking
- Add connected_decodes to register_info for tracking connected Decode instances
- Add FD_ENABLE_PD_RDMA_EAGER_CONNECT environment variable

Test fixes:
- Add None checks for load_config in FDConfig.__init__
- Add version attribute to test mock model configs

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* refine

* remove test

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
jc
2026-03-17 14:43:35 +08:00
committed by GitHub
parent b152baeeee
commit 950366e58d
14 changed files with 507 additions and 97 deletions
+35 -17
View File
@@ -12,6 +12,7 @@ import random
import traceback
from dataclasses import dataclass
from itertools import chain
from typing import Dict, List, Optional
from uuid import uuid4
import aiohttp
@@ -114,23 +115,26 @@ class Router:
raise RuntimeError(f"Instance {inst_info} is not healthy")
async with self.lock:
if inst_info.role == InstanceRole.MIXED and inst_info not in self.mixed_servers:
self.mixed_servers.append(inst_info)
logger.info(
f"Register mixed instance success: {inst_info}, " f"total mixed: {len(self.mixed_servers)}"
)
elif inst_info.role == InstanceRole.PREFILL and inst_info not in self.prefill_servers:
self.prefill_servers.append(inst_info)
logger.info(
f"Register prefill instance success: {inst_info}, "
f"prefill: {len(self.prefill_servers)}, decode: {len(self.decode_servers)}"
)
elif inst_info.role == InstanceRole.DECODE and inst_info not in self.decode_servers:
self.decode_servers.append(inst_info)
logger.info(
f"Register decode instance success: {inst_info}, "
f"prefill: {len(self.prefill_servers)}, decode: {len(self.decode_servers)}"
)
instance_key = inst_info.get_key()
if inst_info.role == InstanceRole.MIXED:
self._update_or_add_instance(self.mixed_servers, inst_info, instance_key, "mixed")
elif inst_info.role == InstanceRole.PREFILL:
self._update_or_add_instance(self.prefill_servers, inst_info, instance_key, "prefill")
elif inst_info.role == InstanceRole.DECODE:
self._update_or_add_instance(self.decode_servers, inst_info, instance_key, "decode")
def _update_or_add_instance(self, server_list: List, inst_info: InstanceInfo, key: str, role_name: str):
"""Update existing instance or add new one based on key (host_ip:port)."""
for i, existing in enumerate(server_list):
if existing.get_key() == key:
if existing != inst_info:
server_list[i] = inst_info
logger.info(f"Updated {role_name} instance, key: {key}, inst_info: {inst_info}")
return
server_list.append(inst_info)
logger.info(f"Register {role_name} instance success: {inst_info}, total {role_name}: {len(server_list)}")
async def registered_number(self):
"""Get number of registered instances"""
@@ -140,6 +144,14 @@ class Router:
"decode": len(self.decode_servers),
}
async def get_decode_instances(self, version: Optional[str] = None) -> List[Dict]:
"""Get all registered decode instances, optionally filtered by version"""
async with self.lock:
instances = self.decode_servers
if version is not None:
instances = [inst for inst in instances if inst.version == version]
return [inst.to_dict() for inst in instances]
async def select_pd(self):
"""Select one prefill and one decode server"""
async with self.lock:
@@ -456,6 +468,12 @@ async def registered_number():
return await app.state.router.registered_number()
@app.get("/decode_instances")
async def decode_instances(version: Optional[str] = None):
"""Get all registered decode instances, optionally filtered by version"""
return await app.state.router.get_decode_instances(version)
@app.post("/v1/chat/completions")
async def create_chat_completion(request_data: dict):
return await app.state.router.handle_request(request_data, "v1/chat/completions")
+8 -1
View File
@@ -17,7 +17,7 @@
import asyncio
from dataclasses import MISSING, asdict, dataclass, field, fields
from enum import Enum
from typing import Any, List, Union
from typing import Any, Dict, List, Optional, Union
import aiohttp
import requests
@@ -41,6 +41,9 @@ class InstanceInfo:
rdma_ports: Union[List[str], List[int]] = field(default_factory=list)
device_ids: Union[List[str], List[int]] = field(default_factory=list)
tp_size: int = 1
is_paused: bool = False
version: Optional[str] = None
connected_decodes: List[Dict] = field(default_factory=list)
@classmethod
def from_dict(cls, info_dict: dict[str, Any]) -> "InstanceInfo":
@@ -92,6 +95,10 @@ class InstanceInfo:
url = f"http://{url}"
return url
def get_key(self) -> str:
"""Generate unique identifier for an instance: 'host_ip:port'."""
return f"{self.host_ip}:{self.port}"
def check_service_health(base_url: str, timeout: int = 3) -> bool:
"""