mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
950366e58d
* [Feature] Register to router with version info for PD disaggregation Add RegisterManager for PD (Prefill-Decode) disaggregated deployment: - All instances (Prefill/Decode) register to Router with heartbeat - Prefill instances fetch Decode instance list from Router - Prefill instances establish eager RDMA connections to Decode instances - Register info includes: host_ip, port, role, version, is_paused, connected_decodes Changes: - Add RegisterManager class for managing PD registration and RDMA connections - Add version field to ModelConfig for model version tracking - Add connected_decodes to register_info for tracking connected Decode instances - Add FD_ENABLE_PD_RDMA_EAGER_CONNECT environment variable Test fixes: - Add None checks for load_config in FDConfig.__init__ - Add version attribute to test mock model configs Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * refine * remove test --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
397 lines
16 KiB
Python
397 lines
16 KiB
Python
"""
|
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|
import threading
|
|
import time
|
|
import traceback
|
|
from typing import Callable, Dict, List
|
|
from uuid import uuid4
|
|
|
|
import requests
|
|
|
|
from fastdeploy import envs
|
|
from fastdeploy.router.utils import check_service_health
|
|
from fastdeploy.utils import register_manager_logger as logger
|
|
|
|
|
|
class RegisterManager:
|
|
"""
|
|
Manages Prefill/Decode instance registration and RDMA connection for PD disaggregation.
|
|
|
|
In PD (Prefill-Decode) disaggregated deployment:
|
|
- All instances (Prefill/Decode) register to Router with heartbeat
|
|
- Prefill instances fetch Decode instance list from Router
|
|
- Prefill instances establish eager RDMA connections to Decode instances
|
|
|
|
Thread Model:
|
|
- _register_to_router: Periodic heartbeat registration thread
|
|
- _eager_connect_loop: Periodic RDMA connection management thread
|
|
- _get_connect_rdma_task_response_loop: Async RDMA connection result receiver thread
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
cfg,
|
|
engine_worker_queue,
|
|
get_is_paused: Callable[[], bool],
|
|
):
|
|
"""
|
|
Initialize RegisterManager.
|
|
|
|
Args:
|
|
cfg: FDConfig object containing router, scheduler, cache configurations
|
|
engine_worker_queue: Queue for communicating RDMA connect tasks with worker
|
|
get_is_paused: Callable that returns current engine pause state
|
|
"""
|
|
self.cfg = cfg
|
|
self.engine_worker_queue = engine_worker_queue
|
|
self.get_is_paused = get_is_paused
|
|
|
|
# Registration state
|
|
self._is_registered = False
|
|
|
|
# RDMA connection state (protected by _lock)
|
|
self.connected_decodes: List[Dict] = [] # Successfully connected Decode instances
|
|
self.connect_status: Dict[str, bool] = {} # task_id -> connection result
|
|
|
|
# Timing configuration (seconds)
|
|
self._timeout = 5 # HTTP request and RDMA connect timeout
|
|
self._sleep_seconds = 5 # Interval between iterations
|
|
|
|
self._lock = threading.Lock() # Protects connected_decodes and connect_status
|
|
|
|
def start(self) -> None:
|
|
"""Start background threads for registration and RDMA connection management."""
|
|
self._register_to_router()
|
|
self._start_eager_connect_loop()
|
|
|
|
def get_connected_decodes(self) -> List[Dict]:
|
|
"""
|
|
Return a snapshot of successfully connected Decode instances.
|
|
Thread-safe: returns a copy to avoid concurrent modification issues.
|
|
"""
|
|
with self._lock:
|
|
return list(self.connected_decodes)
|
|
|
|
def is_registered(self) -> bool:
|
|
"""Return whether this instance has successfully registered to Router."""
|
|
return self._is_registered
|
|
|
|
def _register_to_router(self) -> None:
|
|
"""
|
|
Start background thread for periodic Router registration (heartbeat).
|
|
|
|
Registration info includes: host_ip, port, role, version, is_paused, etc.
|
|
This serves as both initial registration and keep-alive heartbeat.
|
|
"""
|
|
router_url = self.cfg.router_config.router
|
|
if router_url is None:
|
|
logger.info("Router is not enabled, skip registering to router")
|
|
return
|
|
|
|
def _register():
|
|
while True:
|
|
try:
|
|
api_server_host = self.cfg.router_config.api_server_host
|
|
api_server_port = self.cfg.router_config.api_server_port
|
|
api_server_url = f"http://{api_server_host}:{api_server_port}"
|
|
if not check_service_health(api_server_url):
|
|
logger.info("Wait for API service health and then register to router")
|
|
time.sleep(self._sleep_seconds)
|
|
continue
|
|
|
|
# Update registration info
|
|
self.cfg.register_info["is_paused"] = self.get_is_paused()
|
|
self.cfg.register_info["version"] = self.cfg.model_config.version
|
|
self.cfg.register_info["connected_decodes"] = self.get_connected_decodes()
|
|
|
|
resp = requests.post(
|
|
f"{router_url}/register",
|
|
json=self.cfg.register_info,
|
|
timeout=self._timeout,
|
|
)
|
|
|
|
if resp.ok:
|
|
if not self._is_registered:
|
|
self._is_registered = True
|
|
logger.info("Register to router successfully")
|
|
else:
|
|
logger.error(
|
|
f"Send server info to router failed: {resp.status_code}, "
|
|
f"{resp.text}, {self.cfg.register_info}"
|
|
)
|
|
except Exception as e:
|
|
logger.exception(f"Unexpected error during router registration: {e}")
|
|
|
|
time.sleep(self._sleep_seconds)
|
|
|
|
register_thread = threading.Thread(target=_register, daemon=True)
|
|
register_thread.start()
|
|
|
|
def _start_eager_connect_loop(self) -> None:
|
|
"""
|
|
Start background threads for eager RDMA connection management.
|
|
|
|
Only enabled when all conditions are met:
|
|
- Router is configured
|
|
- This instance is Prefill role
|
|
- FD_ENABLE_PD_RDMA_EAGER_CONNECT=1
|
|
- RDMA transfer protocol is enabled
|
|
|
|
Starts two threads:
|
|
1. _eager_connect_loop: Periodically discovers and connects to Decode instances
|
|
2. _get_connect_rdma_task_response_loop: Receives async RDMA connection results
|
|
"""
|
|
if not self._should_enable_eager_connect():
|
|
logger.info("Eager RDMA connect is not enabled, skip")
|
|
return
|
|
|
|
def _eager_connect_loop():
|
|
while True:
|
|
try:
|
|
self._eager_connect_iteration()
|
|
except Exception as e:
|
|
logger.exception(f"Error in eager connect loop: {e}")
|
|
time.sleep(self._sleep_seconds)
|
|
|
|
connect_thread = threading.Thread(target=_eager_connect_loop, daemon=True)
|
|
connect_thread.start()
|
|
logger.info("Eager RDMA connect loop started")
|
|
|
|
def _get_connect_rdma_task_response_loop():
|
|
while True:
|
|
try:
|
|
resp = self.engine_worker_queue.get_connect_rdma_task_response()
|
|
if resp:
|
|
task_id = resp["task_id"]
|
|
is_success = resp["success"]
|
|
with self._lock:
|
|
self.connect_status[task_id] = is_success
|
|
logger.debug(f"get_connect_rdma_task_response response: {resp}")
|
|
except Exception as e:
|
|
logger.error(f"_keep_get_connect_rdma_task_response got error: {e}, " f"{traceback.format_exc()}")
|
|
time.sleep(0.01)
|
|
|
|
get_resp_thread = threading.Thread(target=_get_connect_rdma_task_response_loop, daemon=True)
|
|
get_resp_thread.start()
|
|
logger.info("Get connect rdma task response loop started")
|
|
|
|
def _should_enable_eager_connect(self) -> bool:
|
|
"""
|
|
Check if eager RDMA connect should be enabled.
|
|
|
|
Returns True only when:
|
|
- Router URL is configured
|
|
- Instance role is 'prefill'
|
|
- FD_ENABLE_PD_RDMA_EAGER_CONNECT env is set
|
|
- RDMA protocol is in transfer_protocol list
|
|
- Local RDMA ports are configured
|
|
"""
|
|
if self.cfg.router_config.router is None:
|
|
return False
|
|
if self.cfg.scheduler_config.splitwise_role != "prefill":
|
|
return False
|
|
if not envs.FD_ENABLE_PD_RDMA_EAGER_CONNECT:
|
|
return False
|
|
|
|
transfer_protocol = self.cfg.register_info.get("transfer_protocol", [])
|
|
rdma_ports = self.cfg.cache_config.local_rdma_comm_ports
|
|
if not ("rdma" in transfer_protocol and rdma_ports):
|
|
return False
|
|
|
|
return True
|
|
|
|
def _eager_connect_iteration(self) -> None:
|
|
"""
|
|
Single iteration of the eager RDMA connect loop.
|
|
|
|
Workflow:
|
|
1. Fetch Decode instances from Router (filtered by model version)
|
|
2. For new instances: check health -> check RDMA support -> establish RDMA connection
|
|
3. For existing instances: verify health and RDMA connection status
|
|
4. Remove unhealthy or disconnected instances from connected_decodes
|
|
"""
|
|
if not self._is_registered:
|
|
logger.info("This instance has not registered to router, skip eager connect in this step")
|
|
return
|
|
|
|
# Step 1: Fetch Decode instances from Router
|
|
instances = self._fetch_decode_instances_internal()
|
|
if not instances:
|
|
return
|
|
|
|
# Step 2: Process new instances - try to establish RDMA connection
|
|
with self._lock:
|
|
connected_decodes_snapshot = list(self.connected_decodes)
|
|
existing_keys = {self._get_instance_key(inst) for inst in connected_decodes_snapshot}
|
|
|
|
for instance in instances:
|
|
try:
|
|
instance_key = self._get_instance_key(instance)
|
|
|
|
# Skip already connected instances
|
|
if instance_key in existing_keys:
|
|
continue
|
|
|
|
# Skip unhealthy instances
|
|
if not self._check_instance_health(instance):
|
|
logger.debug(f"Instance {instance_key} is unhealthy, skip")
|
|
continue
|
|
|
|
# Skip instances without RDMA support
|
|
if not self._supports_rdma(instance):
|
|
continue
|
|
|
|
# Try RDMA connection
|
|
if self._try_rdma_connect(instance):
|
|
with self._lock:
|
|
if instance not in self.connected_decodes:
|
|
self.connected_decodes.append(instance)
|
|
logger.info(f"RDMA connect succeeded: {instance_key}")
|
|
else:
|
|
logger.warning(f"RDMA connect failed: {instance_key}")
|
|
except Exception as e:
|
|
logger.exception(f"Error processing instance {instance}: {e}")
|
|
|
|
# Step 3: Verify existing connections - check health and RDMA status
|
|
to_remove = []
|
|
for instance in connected_decodes_snapshot:
|
|
instance_key = self._get_instance_key(instance)
|
|
|
|
if not self._check_instance_health(instance):
|
|
to_remove.append(instance)
|
|
logger.warning(f"Instance {instance_key} is unhealthy, will remove")
|
|
continue
|
|
|
|
if not self._check_rdma_connection(instance):
|
|
to_remove.append(instance)
|
|
logger.warning(f"Instance {instance_key} RDMA connection lost, will remove")
|
|
|
|
# Step 4: Remove failed instances from connected list
|
|
for instance in to_remove:
|
|
with self._lock:
|
|
if instance in self.connected_decodes:
|
|
self.connected_decodes.remove(instance)
|
|
|
|
logger.debug(
|
|
f"Connected decodes num is {len(self.connected_decodes)}, "
|
|
f"connected decodes are {[self._get_instance_key(inst) for inst in self.connected_decodes]}"
|
|
)
|
|
|
|
def _fetch_decode_instances_internal(self) -> List[Dict]:
|
|
"""
|
|
Fetch Decode instance list from Router.
|
|
|
|
Queries Router's /decode_instances endpoint with model version filter.
|
|
Returns empty list on error to allow retry in next iteration.
|
|
"""
|
|
router_url = self.cfg.router_config.router
|
|
if router_url is None:
|
|
return []
|
|
|
|
try:
|
|
version = self.cfg.model_config.version
|
|
resp = requests.get(
|
|
f"{router_url}/decode_instances",
|
|
params={"version": version},
|
|
timeout=self._timeout,
|
|
)
|
|
|
|
if resp.ok:
|
|
instances = resp.json()
|
|
logger.debug(
|
|
f"Fetched {len(instances)} decode instances from router, "
|
|
f"{[self._get_instance_key(instance) for instance in instances]}"
|
|
)
|
|
return instances
|
|
else:
|
|
logger.error(f"Fetch decode instances failed: {resp.status_code}")
|
|
return []
|
|
except Exception as e:
|
|
logger.exception(f"Error fetching decode instances: {e}")
|
|
return []
|
|
|
|
def _get_instance_key(self, instance: Dict) -> str:
|
|
"""Generate unique identifier for an instance: 'host_ip:port'."""
|
|
return f"{instance.get('host_ip')}:{instance.get('port')}"
|
|
|
|
def _supports_rdma(self, instance: Dict) -> bool:
|
|
"""Check if instance supports RDMA transfer protocol and has RDMA ports configured."""
|
|
transfer_protocol = instance.get("transfer_protocol", [])
|
|
return "rdma" in transfer_protocol and instance.get("rdma_ports")
|
|
|
|
def _check_instance_health(self, instance: Dict) -> bool:
|
|
"""Check if Decode instance is healthy via HTTP /health endpoint."""
|
|
try:
|
|
host_ip = instance.get("host_ip")
|
|
port = instance.get("port")
|
|
url = f"http://{host_ip}:{port}/health"
|
|
response = requests.get(url, timeout=self._timeout)
|
|
return response.status_code == 200
|
|
except Exception as e:
|
|
logger.error(f"_is_decode_health error: {e}, host: {host_ip}, port: {port}")
|
|
return False
|
|
|
|
def _try_rdma_connect(self, instance: Dict) -> bool:
|
|
"""
|
|
Attempt to establish RDMA connection to a Decode instance.
|
|
|
|
Workflow:
|
|
1. Generate unique task_id and submit connect task to engine_worker_queue
|
|
2. Wait for connection result in connect_status dict (set by response loop)
|
|
3. Return True if connected successfully within timeout, False otherwise
|
|
|
|
Note: If already connected, the underlying RDMA layer will reuse existing connection.
|
|
|
|
Args:
|
|
instance: Decode instance info dict with 'host_ip' and 'rdma_ports'
|
|
|
|
Returns:
|
|
True if connection succeeded, False if failed or timeout
|
|
"""
|
|
try:
|
|
key = self._get_instance_key(instance)
|
|
task_id = f"{key}-{uuid4().hex}"
|
|
task = {"task_id": task_id, "ip": instance.get("host_ip"), "rdma_ports": instance.get("rdma_ports")}
|
|
self.engine_worker_queue.put_connect_rdma_task(task)
|
|
|
|
start_time = time.time()
|
|
while time.time() - start_time <= self._timeout:
|
|
with self._lock:
|
|
if task_id in self.connect_status:
|
|
result = self.connect_status[task_id]
|
|
del self.connect_status[task_id]
|
|
return result
|
|
time.sleep(0.01)
|
|
|
|
# Timeout: clean up any late-arriving result to prevent memory leak
|
|
with self._lock:
|
|
self.connect_status.pop(task_id, None)
|
|
|
|
except Exception as e:
|
|
logger.error(f"_try_rdma_connect error: {e}")
|
|
return False
|
|
|
|
def _check_rdma_connection(self, instance: Dict) -> bool:
|
|
"""
|
|
Verify RDMA connection to instance is still alive.
|
|
|
|
Reuses _try_rdma_connect() since the underlying RDMA layer:
|
|
- Returns success immediately if connection already exists
|
|
- Attempts reconnection if connection was lost
|
|
"""
|
|
return self._try_rdma_connect(instance)
|