mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-22 16:07:51 +08:00
[PD Disaggregation][RL] Register to router with version and support rdma eager connect for pd (#6718)
* [Feature] Register to router with version info for PD disaggregation Add RegisterManager for PD (Prefill-Decode) disaggregated deployment: - All instances (Prefill/Decode) register to Router with heartbeat - Prefill instances fetch Decode instance list from Router - Prefill instances establish eager RDMA connections to Decode instances - Register info includes: host_ip, port, role, version, is_paused, connected_decodes Changes: - Add RegisterManager class for managing PD registration and RDMA connections - Add version field to ModelConfig for model version tracking - Add connected_decodes to register_info for tracking connected Decode instances - Add FD_ENABLE_PD_RDMA_EAGER_CONNECT environment variable Test fixes: - Add None checks for load_config in FDConfig.__init__ - Add version attribute to test mock model configs Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * refine * remove test --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -107,6 +107,8 @@ def get_decode_ip_idx(task):
|
||||
"""For compatibility, get decode ip and idx from task"""
|
||||
if "decode_ip" in task:
|
||||
decode_ip = task["decode_ip"]
|
||||
elif "host_ip" in task:
|
||||
decode_ip = task["host_ip"]
|
||||
else:
|
||||
decode_ip = task["ip"]
|
||||
if "decode_rdma_ports" in task:
|
||||
@@ -911,6 +913,7 @@ class CacheMessagerV1:
|
||||
response = {"task_id": task_id, "success": True}
|
||||
self.engine_worker_queue.connect_task_response_barrier.wait()
|
||||
self.engine_worker_queue.put_connect_rdma_task_response(response)
|
||||
logger.debug(f"_handle_connect_task send response: {response}")
|
||||
except Exception as e:
|
||||
logger.error(f"handle_connect_task has exception: {e}, {traceback.format_exc()}")
|
||||
|
||||
|
||||
+27
-1
@@ -24,6 +24,7 @@ from typing import Any, Dict, Literal, Optional, Union
|
||||
|
||||
import paddle
|
||||
import paddle.distributed as dist
|
||||
import yaml
|
||||
from packaging.version import parse as parse_version
|
||||
from paddleformers.transformers.configuration_utils import PretrainedConfig
|
||||
from typing_extensions import assert_never
|
||||
@@ -227,6 +228,7 @@ class ModelConfig:
|
||||
self.kv_cache_quant_scale_path = ""
|
||||
self.enable_entropy = False
|
||||
self.model_impl: ModelImpl = "auto"
|
||||
self.version: str = "init" # will override by the version.yaml in model dir
|
||||
|
||||
self.partial_rotary_factor: float = 1.0
|
||||
self.num_nextn_predict_layers = 0
|
||||
@@ -440,6 +442,17 @@ class ModelConfig:
|
||||
f"Config file path: {config_path}"
|
||||
)
|
||||
|
||||
def read_model_version(self):
|
||||
"""
|
||||
Read the version information from a YAML file located at 'version.yaml' within the model directory.
|
||||
If the file exists, it extracts the 'version' field using yaml.safe_load.
|
||||
Raises an assertion error if the file is not found at the specified path.
|
||||
"""
|
||||
version_path = os.path.join(self.model, "version.yaml")
|
||||
assert os.path.exists(version_path), f"version.yaml not exist at {version_path}"
|
||||
with open(version_path, "r", encoding="utf-8") as f:
|
||||
self.version = yaml.safe_load(f)["version"]
|
||||
|
||||
def _get_default_runner_type(
|
||||
self,
|
||||
architectures: list[str],
|
||||
@@ -1939,7 +1952,7 @@ class FDConfig:
|
||||
|
||||
num_ranks = self.parallel_config.tensor_parallel_size * self.parallel_config.data_parallel_size
|
||||
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
|
||||
if num_ranks > self.max_chips_per_node and self.load_config.load_strategy != "meta":
|
||||
if num_ranks > self.max_chips_per_node and self.load_config and self.load_config.load_strategy != "meta":
|
||||
self.worker_num_per_node = self.max_chips_per_node
|
||||
nnode = ceil_div(num_ranks, self.worker_num_per_node)
|
||||
assert nnode == self.nnode, f"nnode: {nnode}, but got {self.nnode}"
|
||||
@@ -1953,6 +1966,16 @@ class FDConfig:
|
||||
if current_platform.is_intel_hpu():
|
||||
self.parallel_config.device_ids = os.getenv("HPU_VISIBLE_DEVICES", self.parallel_config.device_ids)
|
||||
|
||||
if (
|
||||
self.load_config
|
||||
and self.load_config.dynamic_load_weight
|
||||
and self.router_config
|
||||
and self.router_config.router
|
||||
):
|
||||
# For RL scenario: version.yaml will be required for models in future releases.
|
||||
# Temporarily enforce use router to be enabled.
|
||||
self.model_config.read_model_version()
|
||||
|
||||
self.read_from_config()
|
||||
self.postprocess()
|
||||
self.init_cache_info()
|
||||
@@ -2293,6 +2316,9 @@ class FDConfig:
|
||||
"device_ids": self.local_device_ids,
|
||||
"transfer_protocol": transfer_protocol,
|
||||
"tp_size": self.parallel_config.tensor_parallel_size,
|
||||
"is_paused": False,
|
||||
"version": self.model_config.version,
|
||||
"connected_decodes": [],
|
||||
}
|
||||
logger.info(f"register_info: {self.register_info}")
|
||||
|
||||
|
||||
@@ -34,11 +34,11 @@ from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import requests
|
||||
import zmq
|
||||
from tqdm import tqdm
|
||||
|
||||
import fastdeploy.metrics.trace as tracing
|
||||
from fastdeploy.engine.register_manager import RegisterManager
|
||||
from fastdeploy.engine.request import (
|
||||
ControlRequest,
|
||||
ControlResponse,
|
||||
@@ -64,7 +64,6 @@ from fastdeploy.inter_communicator.fmq import FMQ
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.model_executor.guided_decoding import schema_checker
|
||||
from fastdeploy.plugins.token_processor import load_token_processor_plugins
|
||||
from fastdeploy.router.utils import check_service_health
|
||||
from fastdeploy.spec_decode import SpecMethod
|
||||
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
|
||||
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
|
||||
@@ -178,6 +177,13 @@ class EngineService:
|
||||
# between the CPU transfer process and the worker process.
|
||||
self.resource_manager.cache_manager.gpu_cache_lock = self.gpu_cache_lock
|
||||
|
||||
# Initialize RegisterManager
|
||||
self._register_manager = RegisterManager(
|
||||
cfg=self.cfg,
|
||||
engine_worker_queue=self.engine_worker_queue,
|
||||
get_is_paused=self._get_is_paused_safe,
|
||||
)
|
||||
|
||||
if self.cfg.eplb_config.enable_eplb:
|
||||
current_suffix = self.cfg.parallel_config.local_engine_worker_queue_port
|
||||
init_eplb_signals(cfg, current_suffix)
|
||||
@@ -210,7 +216,7 @@ class EngineService:
|
||||
if self.cfg.scheduler_config.splitwise_role == "decode":
|
||||
self._decode_process_splitwise_requests()
|
||||
|
||||
self._register_to_router()
|
||||
self._register_manager.start()
|
||||
|
||||
def start_worker_service(self, async_llm_pid=None):
|
||||
# Initialize IPC signals for worker management
|
||||
@@ -1370,6 +1376,11 @@ class EngineService:
|
||||
with self._pause_cond:
|
||||
return {"is_paused": self.is_paused}
|
||||
|
||||
def _get_is_paused_safe(self) -> bool:
|
||||
"""Thread-safe getter for is_paused state, used by RegisterManager."""
|
||||
with self._pause_cond:
|
||||
return self.is_paused
|
||||
|
||||
def _control_update_weights(self, control_request: ControlRequest) -> Optional[dict]:
|
||||
"""Update model weights
|
||||
Args:
|
||||
@@ -1387,7 +1398,20 @@ class EngineService:
|
||||
error_msg = "Pause LLM Engine first before calling updating weights"
|
||||
self.llm_logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
return self._call_worker(control_request, 60)
|
||||
responses = self._call_worker(control_request, 60)
|
||||
|
||||
if responses:
|
||||
new_version = None
|
||||
for resp in responses:
|
||||
# Expect each worker response to be a dict-like object
|
||||
if isinstance(resp, dict) and "version" in resp:
|
||||
new_version = resp.get("version")
|
||||
self.llm_logger.info(f"Update Weights Version in Config: {new_version}")
|
||||
break
|
||||
if new_version is not None:
|
||||
self.cfg.model_config.version = new_version
|
||||
|
||||
return responses
|
||||
|
||||
async def _wait_all_control_responses(self, request_id: str, timeout: int):
|
||||
"""Wait for control responses from all workers with a global timeout.
|
||||
@@ -1699,55 +1723,6 @@ class EngineService:
|
||||
self.llm_logger.error(f"Clear data error: {e}")
|
||||
return False
|
||||
|
||||
def _register_to_router(self):
|
||||
"""
|
||||
Periodically send server information to the router for registeration, and it is used
|
||||
as a heartbeat signal.
|
||||
"""
|
||||
|
||||
def _register():
|
||||
timeout = 5
|
||||
sleep_seconds = 5
|
||||
is_registered = False
|
||||
|
||||
while True:
|
||||
try:
|
||||
api_server_host = self.cfg.router_config.api_server_host
|
||||
api_server_port = self.cfg.router_config.api_server_port
|
||||
api_server_url = f"http://{api_server_host}:{api_server_port}"
|
||||
if not check_service_health(api_server_url):
|
||||
time.sleep(sleep_seconds)
|
||||
self.llm_logger.info("Wait for API service health and then register to router")
|
||||
time.sleep(sleep_seconds)
|
||||
continue
|
||||
|
||||
router_url = self.cfg.router_config.router
|
||||
resp = requests.post(
|
||||
f"{router_url}/register",
|
||||
json=self.cfg.register_info,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
if resp.ok:
|
||||
if not is_registered:
|
||||
is_registered = True
|
||||
self.llm_logger.info("Register to router successfully")
|
||||
else:
|
||||
self.llm_logger.error(
|
||||
f"Send server info to router failed: {resp.status_code}, "
|
||||
f"{resp.text}, {self.cfg.register_info}"
|
||||
)
|
||||
except Exception as e:
|
||||
self.llm_logger.exception(f"Unexpected error during router registration: {e}")
|
||||
|
||||
time.sleep(sleep_seconds)
|
||||
|
||||
if self.cfg.router_config.router is None:
|
||||
self.llm_logger.info("Router is not enabled, skip registering to router")
|
||||
else:
|
||||
register_thread = threading.Thread(target=_register, daemon=True)
|
||||
register_thread.start()
|
||||
|
||||
def _exit_sub_services(self):
|
||||
"""
|
||||
exit sub services
|
||||
|
||||
@@ -0,0 +1,396 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from typing import Callable, Dict, List
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.router.utils import check_service_health
|
||||
from fastdeploy.utils import register_manager_logger as logger
|
||||
|
||||
|
||||
class RegisterManager:
    """
    Manages Prefill/Decode instance registration and RDMA connection for PD disaggregation.

    In PD (Prefill-Decode) disaggregated deployment:
    - All instances (Prefill/Decode) register to Router with heartbeat
    - Prefill instances fetch Decode instance list from Router
    - Prefill instances establish eager RDMA connections to Decode instances

    Thread Model:
    - _register_to_router: Periodic heartbeat registration thread
    - _eager_connect_loop: Periodic RDMA connection management thread
    - _get_connect_rdma_task_response_loop: Async RDMA connection result receiver thread
    """

    def __init__(
        self,
        cfg,
        engine_worker_queue,
        get_is_paused: Callable[[], bool],
    ):
        """
        Initialize RegisterManager.

        Args:
            cfg: FDConfig object containing router, scheduler, cache configurations
            engine_worker_queue: Queue for communicating RDMA connect tasks with worker
            get_is_paused: Callable that returns current engine pause state
        """
        self.cfg = cfg
        self.engine_worker_queue = engine_worker_queue
        self.get_is_paused = get_is_paused

        # Registration state; written only by the heartbeat thread.
        self._is_registered = False

        # RDMA connection state (protected by _lock)
        self.connected_decodes: List[Dict] = []  # Successfully connected Decode instances
        self.connect_status: Dict[str, bool] = {}  # task_id -> connection result

        # Timing configuration (seconds)
        self._timeout = 5  # HTTP request and RDMA connect timeout
        self._sleep_seconds = 5  # Interval between iterations

        self._lock = threading.Lock()  # Protects connected_decodes and connect_status

    def start(self) -> None:
        """Start background threads for registration and RDMA connection management."""
        self._register_to_router()
        self._start_eager_connect_loop()

    def get_connected_decodes(self) -> List[Dict]:
        """
        Return a snapshot of successfully connected Decode instances.
        Thread-safe: returns a copy to avoid concurrent modification issues.
        """
        with self._lock:
            return list(self.connected_decodes)

    def is_registered(self) -> bool:
        """Return whether this instance has successfully registered to Router."""
        return self._is_registered

    def _register_to_router(self) -> None:
        """
        Start background thread for periodic Router registration (heartbeat).

        Registration info includes: host_ip, port, role, version, is_paused, etc.
        This serves as both initial registration and keep-alive heartbeat.
        """
        router_url = self.cfg.router_config.router
        if router_url is None:
            logger.info("Router is not enabled, skip registering to router")
            return

        def _register():
            while True:
                try:
                    api_server_host = self.cfg.router_config.api_server_host
                    api_server_port = self.cfg.router_config.api_server_port
                    api_server_url = f"http://{api_server_host}:{api_server_port}"
                    if not check_service_health(api_server_url):
                        logger.info("Wait for API service health and then register to router")
                        time.sleep(self._sleep_seconds)
                        continue

                    # Refresh the mutable fields before every heartbeat so the
                    # Router always sees the current pause state / version /
                    # connection list.
                    self.cfg.register_info["is_paused"] = self.get_is_paused()
                    self.cfg.register_info["version"] = self.cfg.model_config.version
                    self.cfg.register_info["connected_decodes"] = self.get_connected_decodes()

                    resp = requests.post(
                        f"{router_url}/register",
                        json=self.cfg.register_info,
                        timeout=self._timeout,
                    )

                    if resp.ok:
                        if not self._is_registered:
                            self._is_registered = True
                            logger.info("Register to router successfully")
                    else:
                        logger.error(
                            f"Send server info to router failed: {resp.status_code}, "
                            f"{resp.text}, {self.cfg.register_info}"
                        )
                except Exception as e:
                    # Never let the heartbeat thread die on a transient error.
                    logger.exception(f"Unexpected error during router registration: {e}")

                time.sleep(self._sleep_seconds)

        register_thread = threading.Thread(target=_register, daemon=True)
        register_thread.start()

    def _start_eager_connect_loop(self) -> None:
        """
        Start background threads for eager RDMA connection management.

        Only enabled when all conditions are met:
        - Router is configured
        - This instance is Prefill role
        - FD_ENABLE_PD_RDMA_EAGER_CONNECT=1
        - RDMA transfer protocol is enabled

        Starts two threads:
        1. _eager_connect_loop: Periodically discovers and connects to Decode instances
        2. _get_connect_rdma_task_response_loop: Receives async RDMA connection results
        """
        if not self._should_enable_eager_connect():
            logger.info("Eager RDMA connect is not enabled, skip")
            return

        def _eager_connect_loop():
            while True:
                try:
                    self._eager_connect_iteration()
                except Exception as e:
                    logger.exception(f"Error in eager connect loop: {e}")
                time.sleep(self._sleep_seconds)

        connect_thread = threading.Thread(target=_eager_connect_loop, daemon=True)
        connect_thread.start()
        logger.info("Eager RDMA connect loop started")

        def _get_connect_rdma_task_response_loop():
            while True:
                try:
                    resp = self.engine_worker_queue.get_connect_rdma_task_response()
                    if resp:
                        task_id = resp["task_id"]
                        is_success = resp["success"]
                        with self._lock:
                            self.connect_status[task_id] = is_success
                        logger.debug(f"get_connect_rdma_task_response response: {resp}")
                except Exception as e:
                    # BUGFIX: the error message previously named a nonexistent
                    # helper (_keep_get_connect_rdma_task_response).
                    logger.error(
                        f"_get_connect_rdma_task_response_loop got error: {e}, " f"{traceback.format_exc()}"
                    )
                time.sleep(0.01)

        get_resp_thread = threading.Thread(target=_get_connect_rdma_task_response_loop, daemon=True)
        get_resp_thread.start()
        logger.info("Get connect rdma task response loop started")

    def _should_enable_eager_connect(self) -> bool:
        """
        Check if eager RDMA connect should be enabled.

        Returns True only when:
        - Router URL is configured
        - Instance role is 'prefill'
        - FD_ENABLE_PD_RDMA_EAGER_CONNECT env is set
        - RDMA protocol is in transfer_protocol list
        - Local RDMA ports are configured
        """
        if self.cfg.router_config.router is None:
            return False
        if self.cfg.scheduler_config.splitwise_role != "prefill":
            return False
        if not envs.FD_ENABLE_PD_RDMA_EAGER_CONNECT:
            return False

        transfer_protocol = self.cfg.register_info.get("transfer_protocol", [])
        rdma_ports = self.cfg.cache_config.local_rdma_comm_ports
        if not ("rdma" in transfer_protocol and rdma_ports):
            return False

        return True

    def _eager_connect_iteration(self) -> None:
        """
        Single iteration of the eager RDMA connect loop.

        Workflow:
        1. Fetch Decode instances from Router (filtered by model version)
        2. For new instances: check health -> check RDMA support -> establish RDMA connection
        3. For existing instances: verify health and RDMA connection status
        4. Remove unhealthy or disconnected instances from connected_decodes
        """
        if not self._is_registered:
            logger.info("This instance has not registered to router, skip eager connect in this step")
            return

        # Step 1: Fetch Decode instances from Router
        instances = self._fetch_decode_instances_internal()
        if not instances:
            return

        # Step 2: Process new instances - try to establish RDMA connection
        with self._lock:
            connected_decodes_snapshot = list(self.connected_decodes)
            existing_keys = {self._get_instance_key(inst) for inst in connected_decodes_snapshot}

        for instance in instances:
            try:
                instance_key = self._get_instance_key(instance)

                # Skip already connected instances
                if instance_key in existing_keys:
                    continue

                # Skip unhealthy instances
                if not self._check_instance_health(instance):
                    logger.debug(f"Instance {instance_key} is unhealthy, skip")
                    continue

                # Skip instances without RDMA support
                if not self._supports_rdma(instance):
                    continue

                # Try RDMA connection
                if self._try_rdma_connect(instance):
                    with self._lock:
                        if instance not in self.connected_decodes:
                            self.connected_decodes.append(instance)
                    logger.info(f"RDMA connect succeeded: {instance_key}")
                else:
                    logger.warning(f"RDMA connect failed: {instance_key}")
            except Exception as e:
                logger.exception(f"Error processing instance {instance}: {e}")

        # Step 3: Verify existing connections - check health and RDMA status
        to_remove = []
        for instance in connected_decodes_snapshot:
            instance_key = self._get_instance_key(instance)

            if not self._check_instance_health(instance):
                to_remove.append(instance)
                logger.warning(f"Instance {instance_key} is unhealthy, will remove")
                continue

            if not self._check_rdma_connection(instance):
                to_remove.append(instance)
                logger.warning(f"Instance {instance_key} RDMA connection lost, will remove")

        # Step 4: Remove failed instances from connected list
        for instance in to_remove:
            with self._lock:
                if instance in self.connected_decodes:
                    self.connected_decodes.remove(instance)

        logger.debug(
            f"Connected decodes num is {len(self.connected_decodes)}, "
            f"connected decodes are {[self._get_instance_key(inst) for inst in self.connected_decodes]}"
        )

    def _fetch_decode_instances_internal(self) -> List[Dict]:
        """
        Fetch Decode instance list from Router.

        Queries Router's /decode_instances endpoint with model version filter.
        Returns empty list on error to allow retry in next iteration.
        """
        router_url = self.cfg.router_config.router
        if router_url is None:
            return []

        try:
            version = self.cfg.model_config.version
            resp = requests.get(
                f"{router_url}/decode_instances",
                params={"version": version},
                timeout=self._timeout,
            )

            if resp.ok:
                instances = resp.json()
                logger.debug(
                    f"Fetched {len(instances)} decode instances from router, "
                    f"{[self._get_instance_key(instance) for instance in instances]}"
                )
                return instances
            else:
                logger.error(f"Fetch decode instances failed: {resp.status_code}")
                return []
        except Exception as e:
            logger.exception(f"Error fetching decode instances: {e}")
            return []

    def _get_instance_key(self, instance: Dict) -> str:
        """Generate unique identifier for an instance: 'host_ip:port'."""
        return f"{instance.get('host_ip')}:{instance.get('port')}"

    def _supports_rdma(self, instance: Dict) -> bool:
        """Check if instance supports RDMA transfer protocol and has RDMA ports configured."""
        transfer_protocol = instance.get("transfer_protocol", [])
        # BUGFIX: wrap in bool() so the return value honors the annotation;
        # previously this could return an empty list / None instead of False.
        return "rdma" in transfer_protocol and bool(instance.get("rdma_ports"))

    def _check_instance_health(self, instance: Dict) -> bool:
        """Check if Decode instance is healthy via HTTP /health endpoint."""
        # Bind before the try block so the except handler can always log them
        # (previously they could be unbound inside the handler).
        host_ip = instance.get("host_ip")
        port = instance.get("port")
        try:
            url = f"http://{host_ip}:{port}/health"
            response = requests.get(url, timeout=self._timeout)
            return response.status_code == 200
        except Exception as e:
            # BUGFIX: log message previously named a nonexistent helper
            # (_is_decode_health).
            logger.error(f"_check_instance_health error: {e}, host: {host_ip}, port: {port}")
            return False

    def _try_rdma_connect(self, instance: Dict) -> bool:
        """
        Attempt to establish RDMA connection to a Decode instance.

        Workflow:
        1. Generate unique task_id and submit connect task to engine_worker_queue
        2. Wait for connection result in connect_status dict (set by response loop)
        3. Return True if connected successfully within timeout, False otherwise

        Note: If already connected, the underlying RDMA layer will reuse existing connection.

        Args:
            instance: Decode instance info dict with 'host_ip' and 'rdma_ports'

        Returns:
            True if connection succeeded, False if failed or timeout
        """
        try:
            key = self._get_instance_key(instance)
            task_id = f"{key}-{uuid4().hex}"
            task = {"task_id": task_id, "ip": instance.get("host_ip"), "rdma_ports": instance.get("rdma_ports")}
            self.engine_worker_queue.put_connect_rdma_task(task)

            start_time = time.time()
            while time.time() - start_time <= self._timeout:
                with self._lock:
                    if task_id in self.connect_status:
                        result = self.connect_status[task_id]
                        del self.connect_status[task_id]
                        return result
                time.sleep(0.01)

            # Timeout: clean up any late-arriving result to prevent memory leak,
            # then fall through to the shared `return False` below.
            with self._lock:
                self.connect_status.pop(task_id, None)

        except Exception as e:
            logger.error(f"_try_rdma_connect error: {e}")
        return False

    def _check_rdma_connection(self, instance: Dict) -> bool:
        """
        Verify RDMA connection to instance is still alive.

        Reuses _try_rdma_connect() since the underlying RDMA layer:
        - Returns success immediately if connection already exists
        - Attempts reconnection if connection was lost
        """
        return self._try_rdma_connect(instance)
|
||||
@@ -171,6 +171,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"FD_FILL_BITMASK_BATCH": lambda: int(os.getenv("FD_FILL_BITMASK_BATCH", "4")),
|
||||
"FD_ENABLE_PDL": lambda: int(os.getenv("FD_ENABLE_PDL", "1")),
|
||||
"FD_ENABLE_ASYNC_LLM": lambda: int(os.getenv("FD_ENABLE_ASYNC_LLM", "0")),
|
||||
# Enable early RDMA connection for PD disaggregation
|
||||
"FD_ENABLE_PD_RDMA_EAGER_CONNECT": lambda: bool(int(os.getenv("FD_ENABLE_PD_RDMA_EAGER_CONNECT", "0"))),
|
||||
"FD_GUIDANCE_DISABLE_ADDITIONAL": lambda: bool(int(os.getenv("FD_GUIDANCE_DISABLE_ADDITIONAL", "1"))),
|
||||
"FD_LLGUIDANCE_LOG_LEVEL": lambda: int(os.getenv("FD_LLGUIDANCE_LOG_LEVEL", "0")),
|
||||
# "Number of tokens in the group for Mixture of Experts (MoE) computation processing on HPU"
|
||||
|
||||
+35
-17
@@ -12,6 +12,7 @@ import random
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from itertools import chain
|
||||
from typing import Dict, List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
import aiohttp
|
||||
@@ -114,23 +115,26 @@ class Router:
|
||||
raise RuntimeError(f"Instance {inst_info} is not healthy")
|
||||
|
||||
async with self.lock:
|
||||
if inst_info.role == InstanceRole.MIXED and inst_info not in self.mixed_servers:
|
||||
self.mixed_servers.append(inst_info)
|
||||
logger.info(
|
||||
f"Register mixed instance success: {inst_info}, " f"total mixed: {len(self.mixed_servers)}"
|
||||
)
|
||||
elif inst_info.role == InstanceRole.PREFILL and inst_info not in self.prefill_servers:
|
||||
self.prefill_servers.append(inst_info)
|
||||
logger.info(
|
||||
f"Register prefill instance success: {inst_info}, "
|
||||
f"prefill: {len(self.prefill_servers)}, decode: {len(self.decode_servers)}"
|
||||
)
|
||||
elif inst_info.role == InstanceRole.DECODE and inst_info not in self.decode_servers:
|
||||
self.decode_servers.append(inst_info)
|
||||
logger.info(
|
||||
f"Register decode instance success: {inst_info}, "
|
||||
f"prefill: {len(self.prefill_servers)}, decode: {len(self.decode_servers)}"
|
||||
)
|
||||
instance_key = inst_info.get_key()
|
||||
|
||||
if inst_info.role == InstanceRole.MIXED:
|
||||
self._update_or_add_instance(self.mixed_servers, inst_info, instance_key, "mixed")
|
||||
elif inst_info.role == InstanceRole.PREFILL:
|
||||
self._update_or_add_instance(self.prefill_servers, inst_info, instance_key, "prefill")
|
||||
elif inst_info.role == InstanceRole.DECODE:
|
||||
self._update_or_add_instance(self.decode_servers, inst_info, instance_key, "decode")
|
||||
|
||||
def _update_or_add_instance(self, server_list: List, inst_info: InstanceInfo, key: str, role_name: str):
|
||||
"""Update existing instance or add new one based on key (host_ip:port)."""
|
||||
for i, existing in enumerate(server_list):
|
||||
if existing.get_key() == key:
|
||||
if existing != inst_info:
|
||||
server_list[i] = inst_info
|
||||
logger.info(f"Updated {role_name} instance, key: {key}, inst_info: {inst_info}")
|
||||
return
|
||||
|
||||
server_list.append(inst_info)
|
||||
logger.info(f"Register {role_name} instance success: {inst_info}, total {role_name}: {len(server_list)}")
|
||||
|
||||
async def registered_number(self):
|
||||
"""Get number of registered instances"""
|
||||
@@ -140,6 +144,14 @@ class Router:
|
||||
"decode": len(self.decode_servers),
|
||||
}
|
||||
|
||||
async def get_decode_instances(self, version: Optional[str] = None) -> List[Dict]:
|
||||
"""Get all registered decode instances, optionally filtered by version"""
|
||||
async with self.lock:
|
||||
instances = self.decode_servers
|
||||
if version is not None:
|
||||
instances = [inst for inst in instances if inst.version == version]
|
||||
return [inst.to_dict() for inst in instances]
|
||||
|
||||
async def select_pd(self):
|
||||
"""Select one prefill and one decode server"""
|
||||
async with self.lock:
|
||||
@@ -456,6 +468,12 @@ async def registered_number():
|
||||
return await app.state.router.registered_number()
|
||||
|
||||
|
||||
@app.get("/decode_instances")
async def decode_instances(version: Optional[str] = None):
    """Get all registered decode instances, optionally filtered by version"""
    # Thin HTTP wrapper: delegates to the Router held on app.state and returns
    # a list of serialized (dict) decode instance entries.
    return await app.state.router.get_decode_instances(version)
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def create_chat_completion(request_data: dict):
|
||||
return await app.state.router.handle_request(request_data, "v1/chat/completions")
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
import asyncio
|
||||
from dataclasses import MISSING, asdict, dataclass, field, fields
|
||||
from enum import Enum
|
||||
from typing import Any, List, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
@@ -41,6 +41,9 @@ class InstanceInfo:
|
||||
rdma_ports: Union[List[str], List[int]] = field(default_factory=list)
|
||||
device_ids: Union[List[str], List[int]] = field(default_factory=list)
|
||||
tp_size: int = 1
|
||||
is_paused: bool = False
|
||||
version: Optional[str] = None
|
||||
connected_decodes: List[Dict] = field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, info_dict: dict[str, Any]) -> "InstanceInfo":
|
||||
@@ -92,6 +95,10 @@ class InstanceInfo:
|
||||
url = f"http://{url}"
|
||||
return url
|
||||
|
||||
def get_key(self) -> str:
|
||||
"""Generate unique identifier for an instance: 'host_ip:port'."""
|
||||
return f"{self.host_ip}:{self.port}"
|
||||
|
||||
|
||||
def check_service_health(base_url: str, timeout: int = 3) -> bool:
|
||||
"""
|
||||
|
||||
@@ -1163,6 +1163,7 @@ trace_logger = FastDeployLogger().get_trace_logger("trace", "trace.log")
|
||||
router_logger = get_logger("router", "router.log")
|
||||
fmq_logger = get_logger("fmq", "fmq.log")
|
||||
obj_logger = get_logger("obj", "obj.log") # debug内存问题
|
||||
register_manager_logger = get_logger("register_manager", "register_manager.log")
|
||||
|
||||
|
||||
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
|
||||
|
||||
@@ -1152,31 +1152,6 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
|
||||
eng._control_update_weights(ControlRequest(request_id="ctrl", method="update_weights"))
|
||||
self._detach_finalizer(eng)
|
||||
|
||||
    def test_register_to_router_disabled(self):
        # When no router URL is configured, _register_to_router must log a
        # skip message and never spawn the background registration thread.
        eng = self._make_mixed_engine()
        eng.cfg.router_config.router = None

        with (
            patch.object(eng, "llm_logger") as mock_logger,
            patch("fastdeploy.engine.common_engine.threading.Thread") as thread_mock,
        ):
            eng._register_to_router()

        mock_logger.info.assert_called()
        thread_mock.assert_not_called()
        self._detach_finalizer(eng)
||||
|
||||
    def test_register_to_router_enabled_starts_thread(self):
        # With a router URL configured, exactly one registration thread must be
        # created and started (Thread is patched so nothing actually runs).
        eng = self._make_mixed_engine()
        eng.cfg.router_config.router = "http://router"

        with patch("fastdeploy.engine.common_engine.threading.Thread") as thread_mock:
            eng._register_to_router()

        thread_mock.assert_called_once()
        thread_mock.return_value.start.assert_called_once()
        self._detach_finalizer(eng)
||||
|
||||
def test_insert_zmq_task_to_scheduler_normal_request(self):
|
||||
eng = self._make_mixed_engine()
|
||||
eng.running = True
|
||||
|
||||
@@ -177,6 +177,7 @@ class TestInitEplbSignals(unittest.TestCase):
|
||||
model_cfg.model = "/test/model"
|
||||
model_cfg.architectures = ["test_model"]
|
||||
model_cfg.mm_max_tokens_per_item = None
|
||||
model_cfg.version = None # Required for register_info
|
||||
cache_cfg.bytes_per_layer_per_block = 1
|
||||
|
||||
parallel_cfg = ParallelConfig(args)
|
||||
|
||||
@@ -57,6 +57,7 @@ class TestRedundantExpertManager(unittest.TestCase):
|
||||
model_cfg.model = "/test/model"
|
||||
model_cfg.architectures = ["test_model"]
|
||||
model_cfg.mm_max_tokens_per_item = None
|
||||
model_cfg.version = None # Required for register_info
|
||||
cache_cfg.bytes_per_layer_per_block = 1
|
||||
|
||||
parallel_cfg = ParallelConfig(args)
|
||||
|
||||
@@ -63,6 +63,7 @@ class FakeModelConfig:
|
||||
self.logprobs_mode = "raw_logprobs"
|
||||
self.architectures = ["test_model"]
|
||||
self.mm_max_tokens_per_item = None
|
||||
self.version = None # Required for register_info
|
||||
|
||||
|
||||
class FakeLoadConfig:
|
||||
|
||||
@@ -35,6 +35,7 @@ def make_prefix_cache_manager(max_num_seqs, enable_mm=False, num_gpu_blocks_over
|
||||
model_cfg.print = print
|
||||
model_cfg.architectures = ["test_model"]
|
||||
model_cfg.mm_max_tokens_per_item = None
|
||||
model_cfg.version = None # Required for register_info
|
||||
cache_cfg.bytes_per_token_per_layer = 1
|
||||
|
||||
parallel_cfg = ParallelConfig(args)
|
||||
|
||||
@@ -81,6 +81,7 @@ def _build_manager(
|
||||
model_cfg.max_model_len = max_model_len
|
||||
model_cfg.architectures = architectures or ["test_model"]
|
||||
model_cfg.mm_max_tokens_per_item = None
|
||||
model_cfg.version = None # Required for register_info
|
||||
cache_cfg.bytes_per_token_per_layer = 1
|
||||
cache_cfg.kv_cache_ratio = 1.0
|
||||
parallel_cfg = ParallelConfig(args)
|
||||
@@ -142,6 +143,7 @@ class TestResourceManagerV1(unittest.TestCase):
|
||||
model_cfg.max_model_len = 3200
|
||||
model_cfg.architectures = ["test_model"]
|
||||
model_cfg.mm_max_tokens_per_item = None
|
||||
model_cfg.version = None # Required for register_info
|
||||
cache_cfg.bytes_per_token_per_layer = 1
|
||||
cache_cfg.kv_cache_ratio = 1.0
|
||||
parallel_cfg = ParallelConfig(args)
|
||||
@@ -304,6 +306,7 @@ class TestRevertChunkedMMInput(unittest.TestCase):
|
||||
model_cfg.max_model_len = 3200
|
||||
model_cfg.architectures = ["test_model"]
|
||||
model_cfg.mm_max_tokens_per_item = None
|
||||
model_cfg.version = None # Required for register_info
|
||||
cache_cfg.bytes_per_token_per_layer = 1
|
||||
cache_cfg.kv_cache_ratio = 1.0
|
||||
cache_cfg.block_size = 64
|
||||
|
||||
Reference in New Issue
Block a user