[PD Disaggregation][RL] Register to router with version and support RDMA eager connect for PD (#6718)

* [Feature] Register to router with version info for PD disaggregation

Add RegisterManager for PD (Prefill-Decode) disaggregated deployment:
- All instances (Prefill/Decode) register to Router with heartbeat
- Prefill instances fetch Decode instance list from Router
- Prefill instances establish eager RDMA connections to Decode instances
- Register info includes: host_ip, port, role, version, is_paused, connected_decodes

Changes:
- Add RegisterManager class for managing PD registration and RDMA connections
- Add version field to ModelConfig for model version tracking
- Add connected_decodes to register_info for tracking connected Decode instances
- Add FD_ENABLE_PD_RDMA_EAGER_CONNECT environment variable

Test fixes:
- Add None checks for load_config in FDConfig.__init__
- Add version attribute to test mock model configs

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* refine

* remove test

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
jc
2026-03-17 14:43:35 +08:00
committed by GitHub
parent b152baeeee
commit 950366e58d
14 changed files with 507 additions and 97 deletions
@@ -107,6 +107,8 @@ def get_decode_ip_idx(task):
     """For compatibility, get decode ip and idx from task"""
     if "decode_ip" in task:
         decode_ip = task["decode_ip"]
+    elif "host_ip" in task:
+        decode_ip = task["host_ip"]
     else:
         decode_ip = task["ip"]
     if "decode_rdma_ports" in task:

@@ -911,6 +913,7 @@ class CacheMessagerV1:
             response = {"task_id": task_id, "success": True}
             self.engine_worker_queue.connect_task_response_barrier.wait()
             self.engine_worker_queue.put_connect_rdma_task_response(response)
+            logger.debug(f"_handle_connect_task send response: {response}")
         except Exception as e:
             logger.error(f"handle_connect_task has exception: {e}, {traceback.format_exc()}")
+27 -1
@@ -24,6 +24,7 @@ from typing import Any, Dict, Literal, Optional, Union

 import paddle
 import paddle.distributed as dist
+import yaml
 from packaging.version import parse as parse_version
 from paddleformers.transformers.configuration_utils import PretrainedConfig
 from typing_extensions import assert_never

@@ -227,6 +228,7 @@ class ModelConfig:
         self.kv_cache_quant_scale_path = ""
         self.enable_entropy = False
         self.model_impl: ModelImpl = "auto"
+        self.version: str = "init"  # will be overridden by version.yaml in the model dir
         self.partial_rotary_factor: float = 1.0
         self.num_nextn_predict_layers = 0

@@ -440,6 +442,17 @@ class ModelConfig:
                 f"Config file path: {config_path}"
             )

+    def read_model_version(self):
+        """
+        Read the version information from 'version.yaml' in the model directory,
+        extracting the 'version' field with yaml.safe_load. Raises an AssertionError
+        if the file does not exist at the expected path.
+        """
+        version_path = os.path.join(self.model, "version.yaml")
+        assert os.path.exists(version_path), f"version.yaml not exist at {version_path}"
+        with open(version_path, "r", encoding="utf-8") as f:
+            self.version = yaml.safe_load(f)["version"]
+
     def _get_default_runner_type(
         self,
         architectures: list[str],
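As a point of reference, a minimal sketch of what read_model_version() expects on disk; all yaml.safe_load(f)["version"] needs is a top-level "version" key, and the directory path and version string here are hypothetical:

import os
import yaml

# Hypothetical model directory; in FastDeploy this is ModelConfig.model.
model_dir = "/path/to/model"

# Hypothetical contents of <model_dir>/version.yaml:
#   version: "20260317"
with open(os.path.join(model_dir, "version.yaml"), "r", encoding="utf-8") as f:
    version = yaml.safe_load(f)["version"]
print(version)  # -> "20260317"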
@@ -1939,7 +1952,7 @@ class FDConfig:
         num_ranks = self.parallel_config.tensor_parallel_size * self.parallel_config.data_parallel_size
         self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
-        if num_ranks > self.max_chips_per_node and self.load_config.load_strategy != "meta":
+        if num_ranks > self.max_chips_per_node and self.load_config and self.load_config.load_strategy != "meta":
             self.worker_num_per_node = self.max_chips_per_node
             nnode = ceil_div(num_ranks, self.worker_num_per_node)
             assert nnode == self.nnode, f"nnode: {nnode}, but got {self.nnode}"

@@ -1953,6 +1966,16 @@ class FDConfig:
         if current_platform.is_intel_hpu():
             self.parallel_config.device_ids = os.getenv("HPU_VISIBLE_DEVICES", self.parallel_config.device_ids)
+        if (
+            self.load_config
+            and self.load_config.dynamic_load_weight
+            and self.router_config
+            and self.router_config.router
+        ):
+            # For the RL scenario: version.yaml will be required for models in future releases.
+            # For now it is only enforced when the router is enabled.
+            self.model_config.read_model_version()
         self.read_from_config()
         self.postprocess()
         self.init_cache_info()

@@ -2293,6 +2316,9 @@ class FDConfig:
             "device_ids": self.local_device_ids,
             "transfer_protocol": transfer_protocol,
             "tp_size": self.parallel_config.tensor_parallel_size,
+            "is_paused": False,
+            "version": self.model_config.version,
+            "connected_decodes": [],
         }
         logger.info(f"register_info: {self.register_info}")
+28 -53
@@ -34,11 +34,11 @@ from typing import Dict, List, Optional, Tuple

 import numpy as np
 import paddle
-import requests
 import zmq
 from tqdm import tqdm

 import fastdeploy.metrics.trace as tracing
+from fastdeploy.engine.register_manager import RegisterManager
 from fastdeploy.engine.request import (
     ControlRequest,
     ControlResponse,

@@ -64,7 +64,6 @@ from fastdeploy.inter_communicator.fmq import FMQ
 from fastdeploy.metrics.metrics import main_process_metrics
 from fastdeploy.model_executor.guided_decoding import schema_checker
 from fastdeploy.plugins.token_processor import load_token_processor_plugins
-from fastdeploy.router.utils import check_service_health
 from fastdeploy.spec_decode import SpecMethod
 from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
 from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector

@@ -178,6 +177,13 @@ class EngineService:
         # between the CPU transfer process and the worker process.
         self.resource_manager.cache_manager.gpu_cache_lock = self.gpu_cache_lock

+        # Initialize RegisterManager
+        self._register_manager = RegisterManager(
+            cfg=self.cfg,
+            engine_worker_queue=self.engine_worker_queue,
+            get_is_paused=self._get_is_paused_safe,
+        )
+
         if self.cfg.eplb_config.enable_eplb:
             current_suffix = self.cfg.parallel_config.local_engine_worker_queue_port
             init_eplb_signals(cfg, current_suffix)

@@ -210,7 +216,7 @@ class EngineService:
         if self.cfg.scheduler_config.splitwise_role == "decode":
             self._decode_process_splitwise_requests()

-        self._register_to_router()
+        self._register_manager.start()

     def start_worker_service(self, async_llm_pid=None):
         # Initialize IPC signals for worker management

@@ -1370,6 +1376,11 @@ class EngineService:
         with self._pause_cond:
             return {"is_paused": self.is_paused}

+    def _get_is_paused_safe(self) -> bool:
+        """Thread-safe getter for is_paused state, used by RegisterManager."""
+        with self._pause_cond:
+            return self.is_paused
+
     def _control_update_weights(self, control_request: ControlRequest) -> Optional[dict]:
         """Update model weights
         Args:
@@ -1387,7 +1398,20 @@ class EngineService:
             error_msg = "Pause LLM Engine first before calling updating weights"
             self.llm_logger.error(error_msg)
             raise Exception(error_msg)
-        return self._call_worker(control_request, 60)
+        responses = self._call_worker(control_request, 60)
+        if responses:
+            new_version = None
+            for resp in responses:
+                # Expect each worker response to be a dict-like object
+                if isinstance(resp, dict) and "version" in resp:
+                    new_version = resp.get("version")
+                    self.llm_logger.info(f"Update Weights Version in Config: {new_version}")
+                    break
+            if new_version is not None:
+                self.cfg.model_config.version = new_version
+        return responses

     async def _wait_all_control_responses(self, request_id: str, timeout: int):
         """Wait for control responses from all workers with a global timeout.
@@ -1699,55 +1723,6 @@ class EngineService:
             self.llm_logger.error(f"Clear data error: {e}")
             return False

-    def _register_to_router(self):
-        """
-        Periodically send server information to the router for registeration, and it is used
-        as a heartbeat signal.
-        """
-
-        def _register():
-            timeout = 5
-            sleep_seconds = 5
-            is_registered = False
-            while True:
-                try:
-                    api_server_host = self.cfg.router_config.api_server_host
-                    api_server_port = self.cfg.router_config.api_server_port
-                    api_server_url = f"http://{api_server_host}:{api_server_port}"
-                    if not check_service_health(api_server_url):
-                        time.sleep(sleep_seconds)
-                        self.llm_logger.info("Wait for API service health and then register to router")
-                        time.sleep(sleep_seconds)
-                        continue
-
-                    router_url = self.cfg.router_config.router
-                    resp = requests.post(
-                        f"{router_url}/register",
-                        json=self.cfg.register_info,
-                        timeout=timeout,
-                    )
-                    if resp.ok:
-                        if not is_registered:
-                            is_registered = True
-                            self.llm_logger.info("Register to router successfully")
-                    else:
-                        self.llm_logger.error(
-                            f"Send server info to router failed: {resp.status_code}, "
-                            f"{resp.text}, {self.cfg.register_info}"
-                        )
-                except Exception as e:
-                    self.llm_logger.exception(f"Unexpected error during router registration: {e}")
-
-                time.sleep(sleep_seconds)
-
-        if self.cfg.router_config.router is None:
-            self.llm_logger.info("Router is not enabled, skip registering to router")
-        else:
-            register_thread = threading.Thread(target=_register, daemon=True)
-            register_thread.start()
-
     def _exit_sub_services(self):
         """
         exit sub services
+396
@@ -0,0 +1,396 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import threading
import time
import traceback
from typing import Callable, Dict, List
from uuid import uuid4

import requests

from fastdeploy import envs
from fastdeploy.router.utils import check_service_health
from fastdeploy.utils import register_manager_logger as logger


class RegisterManager:
    """
    Manages Prefill/Decode instance registration and RDMA connections for PD disaggregation.

    In PD (Prefill-Decode) disaggregated deployment:
    - All instances (Prefill/Decode) register to the Router with a heartbeat
    - Prefill instances fetch the Decode instance list from the Router
    - Prefill instances establish eager RDMA connections to Decode instances

    Thread model:
    - _register_to_router: periodic heartbeat registration thread
    - _eager_connect_loop: periodic RDMA connection management thread
    - _get_connect_rdma_task_response_loop: async RDMA connection result receiver thread
    """
    def __init__(
        self,
        cfg,
        engine_worker_queue,
        get_is_paused: Callable[[], bool],
    ):
        """
        Initialize RegisterManager.

        Args:
            cfg: FDConfig object containing router, scheduler, cache configurations
            engine_worker_queue: Queue for communicating RDMA connect tasks with workers
            get_is_paused: Callable that returns the current engine pause state
        """
        self.cfg = cfg
        self.engine_worker_queue = engine_worker_queue
        self.get_is_paused = get_is_paused

        # Registration state
        self._is_registered = False

        # RDMA connection state (protected by _lock)
        self.connected_decodes: List[Dict] = []  # Successfully connected Decode instances
        self.connect_status: Dict[str, bool] = {}  # task_id -> connection result

        # Timing configuration (seconds)
        self._timeout = 5  # HTTP request and RDMA connect timeout
        self._sleep_seconds = 5  # Interval between iterations

        self._lock = threading.Lock()  # Protects connected_decodes and connect_status

    def start(self) -> None:
        """Start background threads for registration and RDMA connection management."""
        self._register_to_router()
        self._start_eager_connect_loop()

    def get_connected_decodes(self) -> List[Dict]:
        """
        Return a snapshot of successfully connected Decode instances.

        Thread-safe: returns a copy to avoid concurrent modification issues.
        """
        with self._lock:
            return list(self.connected_decodes)

    def is_registered(self) -> bool:
        """Return whether this instance has successfully registered to the Router."""
        return self._is_registered
    def _register_to_router(self) -> None:
        """
        Start a background thread for periodic Router registration (heartbeat).

        Registration info includes: host_ip, port, role, version, is_paused, etc.
        This serves as both initial registration and keep-alive heartbeat.
        """
        router_url = self.cfg.router_config.router
        if router_url is None:
            logger.info("Router is not enabled, skip registering to router")
            return

        def _register():
            while True:
                try:
                    api_server_host = self.cfg.router_config.api_server_host
                    api_server_port = self.cfg.router_config.api_server_port
                    api_server_url = f"http://{api_server_host}:{api_server_port}"
                    if not check_service_health(api_server_url):
                        logger.info("Wait for API service health and then register to router")
                        time.sleep(self._sleep_seconds)
                        continue

                    # Update registration info
                    self.cfg.register_info["is_paused"] = self.get_is_paused()
                    self.cfg.register_info["version"] = self.cfg.model_config.version
                    self.cfg.register_info["connected_decodes"] = self.get_connected_decodes()

                    resp = requests.post(
                        f"{router_url}/register",
                        json=self.cfg.register_info,
                        timeout=self._timeout,
                    )
                    if resp.ok:
                        if not self._is_registered:
                            self._is_registered = True
                            logger.info("Register to router successfully")
                    else:
                        logger.error(
                            f"Send server info to router failed: {resp.status_code}, "
                            f"{resp.text}, {self.cfg.register_info}"
                        )
                except Exception as e:
                    logger.exception(f"Unexpected error during router registration: {e}")

                time.sleep(self._sleep_seconds)

        register_thread = threading.Thread(target=_register, daemon=True)
        register_thread.start()
    def _start_eager_connect_loop(self) -> None:
        """
        Start background threads for eager RDMA connection management.

        Only enabled when all conditions are met:
        - Router is configured
        - This instance has the Prefill role
        - FD_ENABLE_PD_RDMA_EAGER_CONNECT=1
        - The RDMA transfer protocol is enabled

        Starts two threads:
        1. _eager_connect_loop: periodically discovers and connects to Decode instances
        2. _get_connect_rdma_task_response_loop: receives async RDMA connection results
        """
        if not self._should_enable_eager_connect():
            logger.info("Eager RDMA connect is not enabled, skip")
            return

        def _eager_connect_loop():
            while True:
                try:
                    self._eager_connect_iteration()
                except Exception as e:
                    logger.exception(f"Error in eager connect loop: {e}")
                time.sleep(self._sleep_seconds)

        connect_thread = threading.Thread(target=_eager_connect_loop, daemon=True)
        connect_thread.start()
        logger.info("Eager RDMA connect loop started")

        def _get_connect_rdma_task_response_loop():
            while True:
                try:
                    resp = self.engine_worker_queue.get_connect_rdma_task_response()
                    if resp:
                        task_id = resp["task_id"]
                        is_success = resp["success"]
                        with self._lock:
                            self.connect_status[task_id] = is_success
                        logger.debug(f"get_connect_rdma_task_response response: {resp}")
                except Exception as e:
                    logger.error(
                        f"_get_connect_rdma_task_response_loop got error: {e}, {traceback.format_exc()}"
                    )
                time.sleep(0.01)

        get_resp_thread = threading.Thread(target=_get_connect_rdma_task_response_loop, daemon=True)
        get_resp_thread.start()
        logger.info("Get connect rdma task response loop started")

    def _should_enable_eager_connect(self) -> bool:
        """
        Check whether eager RDMA connect should be enabled.

        Returns True only when:
        - The Router URL is configured
        - The instance role is 'prefill'
        - The FD_ENABLE_PD_RDMA_EAGER_CONNECT env var is set
        - The RDMA protocol is in the transfer_protocol list
        - Local RDMA ports are configured
        """
        if self.cfg.router_config.router is None:
            return False
        if self.cfg.scheduler_config.splitwise_role != "prefill":
            return False
        if not envs.FD_ENABLE_PD_RDMA_EAGER_CONNECT:
            return False
        transfer_protocol = self.cfg.register_info.get("transfer_protocol", [])
        rdma_ports = self.cfg.cache_config.local_rdma_comm_ports
        if not ("rdma" in transfer_protocol and rdma_ports):
            return False
        return True
    def _eager_connect_iteration(self) -> None:
        """
        Single iteration of the eager RDMA connect loop.

        Workflow:
        1. Fetch Decode instances from the Router (filtered by model version)
        2. For new instances: check health -> check RDMA support -> establish RDMA connection
        3. For existing instances: verify health and RDMA connection status
        4. Remove unhealthy or disconnected instances from connected_decodes
        """
        if not self._is_registered:
            logger.info("This instance has not registered to router, skip eager connect in this step")
            return

        # Step 1: Fetch Decode instances from Router
        instances = self._fetch_decode_instances_internal()
        if not instances:
            return

        # Step 2: Process new instances - try to establish RDMA connection
        with self._lock:
            connected_decodes_snapshot = list(self.connected_decodes)
        existing_keys = {self._get_instance_key(inst) for inst in connected_decodes_snapshot}

        for instance in instances:
            try:
                instance_key = self._get_instance_key(instance)

                # Skip already connected instances
                if instance_key in existing_keys:
                    continue

                # Skip unhealthy instances
                if not self._check_instance_health(instance):
                    logger.debug(f"Instance {instance_key} is unhealthy, skip")
                    continue

                # Skip instances without RDMA support
                if not self._supports_rdma(instance):
                    continue

                # Try RDMA connection
                if self._try_rdma_connect(instance):
                    with self._lock:
                        if instance not in self.connected_decodes:
                            self.connected_decodes.append(instance)
                    logger.info(f"RDMA connect succeeded: {instance_key}")
                else:
                    logger.warning(f"RDMA connect failed: {instance_key}")
            except Exception as e:
                logger.exception(f"Error processing instance {instance}: {e}")

        # Step 3: Verify existing connections - check health and RDMA status
        to_remove = []
        for instance in connected_decodes_snapshot:
            instance_key = self._get_instance_key(instance)
            if not self._check_instance_health(instance):
                to_remove.append(instance)
                logger.warning(f"Instance {instance_key} is unhealthy, will remove")
                continue
            if not self._check_rdma_connection(instance):
                to_remove.append(instance)
                logger.warning(f"Instance {instance_key} RDMA connection lost, will remove")

        # Step 4: Remove failed instances from connected list
        for instance in to_remove:
            with self._lock:
                if instance in self.connected_decodes:
                    self.connected_decodes.remove(instance)

        logger.debug(
            f"Connected decodes num is {len(self.connected_decodes)}, "
            f"connected decodes are {[self._get_instance_key(inst) for inst in self.connected_decodes]}"
        )
    def _fetch_decode_instances_internal(self) -> List[Dict]:
        """
        Fetch the Decode instance list from the Router.

        Queries the Router's /decode_instances endpoint with a model version filter.
        Returns an empty list on error to allow a retry in the next iteration.
        """
        router_url = self.cfg.router_config.router
        if router_url is None:
            return []
        try:
            version = self.cfg.model_config.version
            resp = requests.get(
                f"{router_url}/decode_instances",
                params={"version": version},
                timeout=self._timeout,
            )
            if resp.ok:
                instances = resp.json()
                logger.debug(
                    f"Fetched {len(instances)} decode instances from router, "
                    f"{[self._get_instance_key(instance) for instance in instances]}"
                )
                return instances
            else:
                logger.error(f"Fetch decode instances failed: {resp.status_code}")
                return []
        except Exception as e:
            logger.exception(f"Error fetching decode instances: {e}")
            return []

    def _get_instance_key(self, instance: Dict) -> str:
        """Generate a unique identifier for an instance: 'host_ip:port'."""
        return f"{instance.get('host_ip')}:{instance.get('port')}"

    def _supports_rdma(self, instance: Dict) -> bool:
        """Check whether the instance supports the RDMA transfer protocol and has RDMA ports configured."""
        transfer_protocol = instance.get("transfer_protocol", [])
        return "rdma" in transfer_protocol and instance.get("rdma_ports")

    def _check_instance_health(self, instance: Dict) -> bool:
        """Check whether a Decode instance is healthy via its HTTP /health endpoint."""
        host_ip = instance.get("host_ip")
        port = instance.get("port")
        try:
            url = f"http://{host_ip}:{port}/health"
            response = requests.get(url, timeout=self._timeout)
            return response.status_code == 200
        except Exception as e:
            logger.error(f"_check_instance_health error: {e}, host: {host_ip}, port: {port}")
            return False
    def _try_rdma_connect(self, instance: Dict) -> bool:
        """
        Attempt to establish an RDMA connection to a Decode instance.

        Workflow:
        1. Generate a unique task_id and submit a connect task to engine_worker_queue
        2. Wait for the connection result in the connect_status dict (set by the response loop)
        3. Return True if connected successfully within the timeout, False otherwise

        Note: If already connected, the underlying RDMA layer will reuse the existing connection.

        Args:
            instance: Decode instance info dict with 'host_ip' and 'rdma_ports'

        Returns:
            True if the connection succeeded, False on failure or timeout
        """
        try:
            key = self._get_instance_key(instance)
            task_id = f"{key}-{uuid4().hex}"
            task = {"task_id": task_id, "ip": instance.get("host_ip"), "rdma_ports": instance.get("rdma_ports")}
            self.engine_worker_queue.put_connect_rdma_task(task)

            start_time = time.time()
            while time.time() - start_time <= self._timeout:
                with self._lock:
                    if task_id in self.connect_status:
                        result = self.connect_status[task_id]
                        del self.connect_status[task_id]
                        return result
                time.sleep(0.01)

            # Timeout: clean up any late-arriving result to prevent a memory leak
            with self._lock:
                self.connect_status.pop(task_id, None)
        except Exception as e:
            logger.error(f"_try_rdma_connect error: {e}")
        return False

    def _check_rdma_connection(self, instance: Dict) -> bool:
        """
        Verify that the RDMA connection to an instance is still alive.

        Reuses _try_rdma_connect() since the underlying RDMA layer:
        - Returns success immediately if the connection already exists
        - Attempts reconnection if the connection was lost
        """
        return self._try_rdma_connect(instance)
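To make the queue handshake concrete, here is a minimal sketch of the task/response protocol between RegisterManager and the cache messager. The two queue.Queue objects are stand-ins for the real engine_worker_queue; the task and response dict fields match the code above, and the addresses are hypothetical:

import queue
from uuid import uuid4

# Hypothetical stand-in for engine_worker_queue: one direction per queue.
connect_tasks, connect_responses = queue.Queue(), queue.Queue()

# Prefill side (_try_rdma_connect) submits a connect task...
task_id = f"10.0.0.2:8001-{uuid4().hex}"
connect_tasks.put({"task_id": task_id, "ip": "10.0.0.2", "rdma_ports": [7791]})

# ...and the worker side (CacheMessagerV1._handle_connect_task) answers with the same task_id.
task = connect_tasks.get()
connect_responses.put({"task_id": task["task_id"], "success": True})

assert connect_responses.get()["task_id"] == task_id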
+2
@@ -171,6 +171,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "FD_FILL_BITMASK_BATCH": lambda: int(os.getenv("FD_FILL_BITMASK_BATCH", "4")),
     "FD_ENABLE_PDL": lambda: int(os.getenv("FD_ENABLE_PDL", "1")),
     "FD_ENABLE_ASYNC_LLM": lambda: int(os.getenv("FD_ENABLE_ASYNC_LLM", "0")),
+    # Enable early RDMA connection for PD disaggregation
+    "FD_ENABLE_PD_RDMA_EAGER_CONNECT": lambda: bool(int(os.getenv("FD_ENABLE_PD_RDMA_EAGER_CONNECT", "0"))),
     "FD_GUIDANCE_DISABLE_ADDITIONAL": lambda: bool(int(os.getenv("FD_GUIDANCE_DISABLE_ADDITIONAL", "1"))),
     "FD_LLGUIDANCE_LOG_LEVEL": lambda: int(os.getenv("FD_LLGUIDANCE_LOG_LEVEL", "0")),
     # "Number of tokens in the group for Mixture of Experts (MoE) computation processing on HPU"
+35 -17
@@ -12,6 +12,7 @@ import random
 import traceback
 from dataclasses import dataclass
 from itertools import chain
+from typing import Dict, List, Optional
 from uuid import uuid4

 import aiohttp

@@ -114,23 +115,26 @@ class Router:
             raise RuntimeError(f"Instance {inst_info} is not healthy")

         async with self.lock:
-            if inst_info.role == InstanceRole.MIXED and inst_info not in self.mixed_servers:
-                self.mixed_servers.append(inst_info)
-                logger.info(
-                    f"Register mixed instance success: {inst_info}, " f"total mixed: {len(self.mixed_servers)}"
-                )
-            elif inst_info.role == InstanceRole.PREFILL and inst_info not in self.prefill_servers:
-                self.prefill_servers.append(inst_info)
-                logger.info(
-                    f"Register prefill instance success: {inst_info}, "
-                    f"prefill: {len(self.prefill_servers)}, decode: {len(self.decode_servers)}"
-                )
-            elif inst_info.role == InstanceRole.DECODE and inst_info not in self.decode_servers:
-                self.decode_servers.append(inst_info)
-                logger.info(
-                    f"Register decode instance success: {inst_info}, "
-                    f"prefill: {len(self.prefill_servers)}, decode: {len(self.decode_servers)}"
-                )
+            instance_key = inst_info.get_key()
+
+            if inst_info.role == InstanceRole.MIXED:
+                self._update_or_add_instance(self.mixed_servers, inst_info, instance_key, "mixed")
+            elif inst_info.role == InstanceRole.PREFILL:
+                self._update_or_add_instance(self.prefill_servers, inst_info, instance_key, "prefill")
+            elif inst_info.role == InstanceRole.DECODE:
+                self._update_or_add_instance(self.decode_servers, inst_info, instance_key, "decode")
+
+    def _update_or_add_instance(self, server_list: List, inst_info: InstanceInfo, key: str, role_name: str):
+        """Update an existing instance or add a new one based on key (host_ip:port)."""
+        for i, existing in enumerate(server_list):
+            if existing.get_key() == key:
+                if existing != inst_info:
+                    server_list[i] = inst_info
+                    logger.info(f"Updated {role_name} instance, key: {key}, inst_info: {inst_info}")
+                return
+        server_list.append(inst_info)
+        logger.info(f"Register {role_name} instance success: {inst_info}, total {role_name}: {len(server_list)}")

     async def registered_number(self):
         """Get number of registered instances"""

@@ -140,6 +144,14 @@ class Router:
             "decode": len(self.decode_servers),
         }

+    async def get_decode_instances(self, version: Optional[str] = None) -> List[Dict]:
+        """Get all registered decode instances, optionally filtered by version"""
+        async with self.lock:
+            instances = self.decode_servers
+            if version is not None:
+                instances = [inst for inst in instances if inst.version == version]
+            return [inst.to_dict() for inst in instances]
+
     async def select_pd(self):
         """Select one prefill and one decode server"""
         async with self.lock:

@@ -456,6 +468,12 @@ async def registered_number():
     return await app.state.router.registered_number()

+@app.get("/decode_instances")
+async def decode_instances(version: Optional[str] = None):
+    """Get all registered decode instances, optionally filtered by version"""
+    return await app.state.router.get_decode_instances(version)
+
 @app.post("/v1/chat/completions")
 async def create_chat_completion(request_data: dict):
     return await app.state.router.handle_request(request_data, "v1/chat/completions")
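A quick client-side sketch of the new endpoint, mirroring what RegisterManager._fetch_decode_instances_internal does; the router address and version string are hypothetical:

import requests

router_url = "http://127.0.0.1:9000"  # hypothetical router address
resp = requests.get(
    f"{router_url}/decode_instances",
    params={"version": "20260317"},  # omit to fetch all decode instances
    timeout=5,
)
resp.raise_for_status()
for inst in resp.json():
    print(inst["host_ip"], inst["port"], inst.get("rdma_ports"))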
+8 -1
@@ -17,7 +17,7 @@

 import asyncio
 from dataclasses import MISSING, asdict, dataclass, field, fields
 from enum import Enum
-from typing import Any, List, Union
+from typing import Any, Dict, List, Optional, Union

 import aiohttp
 import requests

@@ -41,6 +41,9 @@ class InstanceInfo:
     rdma_ports: Union[List[str], List[int]] = field(default_factory=list)
     device_ids: Union[List[str], List[int]] = field(default_factory=list)
     tp_size: int = 1
+    is_paused: bool = False
+    version: Optional[str] = None
+    connected_decodes: List[Dict] = field(default_factory=list)

     @classmethod
     def from_dict(cls, info_dict: dict[str, Any]) -> "InstanceInfo":

@@ -92,6 +95,10 @@ class InstanceInfo:
             url = f"http://{url}"
         return url

+    def get_key(self) -> str:
+        """Generate a unique identifier for an instance: 'host_ip:port'."""
+        return f"{self.host_ip}:{self.port}"

 def check_service_health(base_url: str, timeout: int = 3) -> bool:
     """
+1
@@ -1163,6 +1163,7 @@ trace_logger = FastDeployLogger().get_trace_logger("trace", "trace.log")
 router_logger = get_logger("router", "router.log")
 fmq_logger = get_logger("fmq", "fmq.log")
 obj_logger = get_logger("obj", "obj.log")  # for debugging memory issues
+register_manager_logger = get_logger("register_manager", "register_manager.log")

 def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
-25
@@ -1152,31 +1152,6 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
             eng._control_update_weights(ControlRequest(request_id="ctrl", method="update_weights"))
         self._detach_finalizer(eng)

-    def test_register_to_router_disabled(self):
-        eng = self._make_mixed_engine()
-        eng.cfg.router_config.router = None
-        with (
-            patch.object(eng, "llm_logger") as mock_logger,
-            patch("fastdeploy.engine.common_engine.threading.Thread") as thread_mock,
-        ):
-            eng._register_to_router()
-        mock_logger.info.assert_called()
-        thread_mock.assert_not_called()
-        self._detach_finalizer(eng)
-
-    def test_register_to_router_enabled_starts_thread(self):
-        eng = self._make_mixed_engine()
-        eng.cfg.router_config.router = "http://router"
-        with patch("fastdeploy.engine.common_engine.threading.Thread") as thread_mock:
-            eng._register_to_router()
-        thread_mock.assert_called_once()
-        thread_mock.return_value.start.assert_called_once()
-        self._detach_finalizer(eng)
-
     def test_insert_zmq_task_to_scheduler_normal_request(self):
         eng = self._make_mixed_engine()
         eng.running = True
+1
@@ -177,6 +177,7 @@ class TestInitEplbSignals(unittest.TestCase):
         model_cfg.model = "/test/model"
         model_cfg.architectures = ["test_model"]
         model_cfg.mm_max_tokens_per_item = None
+        model_cfg.version = None  # Required for register_info
         cache_cfg.bytes_per_layer_per_block = 1
         parallel_cfg = ParallelConfig(args)
+1
@@ -57,6 +57,7 @@ class TestRedundantExpertManager(unittest.TestCase):
         model_cfg.model = "/test/model"
         model_cfg.architectures = ["test_model"]
         model_cfg.mm_max_tokens_per_item = None
+        model_cfg.version = None  # Required for register_info
         cache_cfg.bytes_per_layer_per_block = 1
         parallel_cfg = ParallelConfig(args)
+1
@@ -63,6 +63,7 @@ class FakeModelConfig:
         self.logprobs_mode = "raw_logprobs"
         self.architectures = ["test_model"]
         self.mm_max_tokens_per_item = None
+        self.version = None  # Required for register_info

 class FakeLoadConfig:
@@ -35,6 +35,7 @@ def make_prefix_cache_manager(max_num_seqs, enable_mm=False, num_gpu_blocks_over
     model_cfg.print = print
     model_cfg.architectures = ["test_model"]
     model_cfg.mm_max_tokens_per_item = None
+    model_cfg.version = None  # Required for register_info
     cache_cfg.bytes_per_token_per_layer = 1
     parallel_cfg = ParallelConfig(args)
+3
@@ -81,6 +81,7 @@ def _build_manager(
     model_cfg.max_model_len = max_model_len
     model_cfg.architectures = architectures or ["test_model"]
     model_cfg.mm_max_tokens_per_item = None
+    model_cfg.version = None  # Required for register_info
     cache_cfg.bytes_per_token_per_layer = 1
     cache_cfg.kv_cache_ratio = 1.0
     parallel_cfg = ParallelConfig(args)

@@ -142,6 +143,7 @@ class TestResourceManagerV1(unittest.TestCase):
         model_cfg.max_model_len = 3200
         model_cfg.architectures = ["test_model"]
         model_cfg.mm_max_tokens_per_item = None
+        model_cfg.version = None  # Required for register_info
         cache_cfg.bytes_per_token_per_layer = 1
         cache_cfg.kv_cache_ratio = 1.0
         parallel_cfg = ParallelConfig(args)

@@ -304,6 +306,7 @@ class TestRevertChunkedMMInput(unittest.TestCase):
         model_cfg.max_model_len = 3200
         model_cfg.architectures = ["test_model"]
         model_cfg.mm_max_tokens_per_item = None
+        model_cfg.version = None  # Required for register_info
         cache_cfg.bytes_per_token_per_layer = 1
         cache_cfg.kv_cache_ratio = 1.0
         cache_cfg.block_size = 64