mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[PD Disaggregation][RL] Register to router with version and support rdma eager connect for pd (#6718)
* [Feature] Register to router with version info for PD disaggregation

  Add RegisterManager for PD (Prefill-Decode) disaggregated deployment:
  - All instances (Prefill/Decode) register to Router with heartbeat
  - Prefill instances fetch Decode instance list from Router
  - Prefill instances establish eager RDMA connections to Decode instances
  - Register info includes: host_ip, port, role, version, is_paused, connected_decodes

  Changes:
  - Add RegisterManager class for managing PD registration and RDMA connections
  - Add version field to ModelConfig for model version tracking
  - Add connected_decodes to register_info for tracking connected Decode instances
  - Add FD_ENABLE_PD_RDMA_EAGER_CONNECT environment variable

  Test fixes:
  - Add None checks for load_config in FDConfig.__init__
  - Add version attribute to test mock model configs

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* refine

* remove test

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
+35
-17
@@ -12,6 +12,7 @@ import random
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from itertools import chain
|
||||
from typing import Dict, List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
import aiohttp
|
||||
@@ -114,23 +115,26 @@ class Router:
|
||||
raise RuntimeError(f"Instance {inst_info} is not healthy")
|
||||
|
||||
async with self.lock:
|
||||
if inst_info.role == InstanceRole.MIXED and inst_info not in self.mixed_servers:
|
||||
self.mixed_servers.append(inst_info)
|
||||
logger.info(
|
||||
f"Register mixed instance success: {inst_info}, " f"total mixed: {len(self.mixed_servers)}"
|
||||
)
|
||||
elif inst_info.role == InstanceRole.PREFILL and inst_info not in self.prefill_servers:
|
||||
self.prefill_servers.append(inst_info)
|
||||
logger.info(
|
||||
f"Register prefill instance success: {inst_info}, "
|
||||
f"prefill: {len(self.prefill_servers)}, decode: {len(self.decode_servers)}"
|
||||
)
|
||||
elif inst_info.role == InstanceRole.DECODE and inst_info not in self.decode_servers:
|
||||
self.decode_servers.append(inst_info)
|
||||
logger.info(
|
||||
f"Register decode instance success: {inst_info}, "
|
||||
f"prefill: {len(self.prefill_servers)}, decode: {len(self.decode_servers)}"
|
||||
)
|
||||
instance_key = inst_info.get_key()
|
||||
|
||||
if inst_info.role == InstanceRole.MIXED:
|
||||
self._update_or_add_instance(self.mixed_servers, inst_info, instance_key, "mixed")
|
||||
elif inst_info.role == InstanceRole.PREFILL:
|
||||
self._update_or_add_instance(self.prefill_servers, inst_info, instance_key, "prefill")
|
||||
elif inst_info.role == InstanceRole.DECODE:
|
||||
self._update_or_add_instance(self.decode_servers, inst_info, instance_key, "decode")
|
||||
|
||||
def _update_or_add_instance(self, server_list: List, inst_info: InstanceInfo, key: str, role_name: str):
    """Refresh the entry of ``server_list`` identified by ``key``, or append a new one.

    Args:
        server_list: Role-specific list of registered instances; mutated in place.
        inst_info: Instance description that has just been (re-)registered.
        key: Identifier of ``inst_info`` in 'host_ip:port' form.
        role_name: Human-readable role label used only in log messages.
    """
    # Index of the first entry that shares the key, or None when absent.
    slot = next(
        (idx for idx, current in enumerate(server_list) if current.get_key() == key),
        None,
    )

    if slot is None:
        # Unknown key: this is a first-time registration.
        server_list.append(inst_info)
        logger.info(f"Register {role_name} instance success: {inst_info}, total {role_name}: {len(server_list)}")
        return

    # Known key: replace the stored entry only when something actually changed,
    # so that identical re-registrations stay silent.
    if server_list[slot] != inst_info:
        server_list[slot] = inst_info
        logger.info(f"Updated {role_name} instance, key: {key}, inst_info: {inst_info}")
|
||||
|
||||
async def registered_number(self):
|
||||
"""Get number of registered instances"""
|
||||
@@ -140,6 +144,14 @@ class Router:
|
||||
"decode": len(self.decode_servers),
|
||||
}
|
||||
|
||||
async def get_decode_instances(self, version: Optional[str] = None) -> List[Dict]:
    """Return the registered decode instances as dictionaries.

    Args:
        version: When given, only instances whose ``version`` attribute equals
            this value are returned; ``None`` disables the filter.

    Returns:
        One ``to_dict()`` result per selected decode instance, in registration order.
    """
    async with self.lock:
        if version is None:
            selected = self.decode_servers
        else:
            selected = [server for server in self.decode_servers if server.version == version]
        # Serialize while still holding the lock so the view is consistent.
        return [server.to_dict() for server in selected]
|
||||
|
||||
async def select_pd(self):
|
||||
"""Select one prefill and one decode server"""
|
||||
async with self.lock:
|
||||
@@ -456,6 +468,12 @@ async def registered_number():
|
||||
return await app.state.router.registered_number()
|
||||
|
||||
|
||||
@app.get("/decode_instances")
async def decode_instances(version: Optional[str] = None):
    """List the registered decode instances.

    The optional ``version`` argument is passed through to the router, which
    uses it to restrict the result to instances with a matching version.
    """
    router = app.state.router
    return await router.get_decode_instances(version)
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
async def create_chat_completion(request_data: dict):
    """Hand a chat-completion request body to the router for the same endpoint name."""
    router = app.state.router
    return await router.handle_request(request_data, "v1/chat/completions")
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
import asyncio
|
||||
from dataclasses import MISSING, asdict, dataclass, field, fields
|
||||
from enum import Enum
|
||||
from typing import Any, List, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
@@ -41,6 +41,9 @@ class InstanceInfo:
|
||||
rdma_ports: Union[List[str], List[int]] = field(default_factory=list)
|
||||
device_ids: Union[List[str], List[int]] = field(default_factory=list)
|
||||
tp_size: int = 1
|
||||
is_paused: bool = False
|
||||
version: Optional[str] = None
|
||||
connected_decodes: List[Dict] = field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, info_dict: dict[str, Any]) -> "InstanceInfo":
|
||||
@@ -92,6 +95,10 @@ class InstanceInfo:
|
||||
url = f"http://{url}"
|
||||
return url
|
||||
|
||||
def get_key(self) -> str:
    """Return the identifier of this instance in 'host_ip:port' form."""
    host, port = self.host_ip, self.port
    return f"{host}:{port}"
|
||||
|
||||
|
||||
def check_service_health(base_url: str, timeout: int = 3) -> bool:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user