mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[PD Disaggregation][RL] Register to router with version and support rdma eager connect for pd (#6718)
* [Feature] Register to router with version info for PD disaggregation Add RegisterManager for PD (Prefill-Decode) disaggregated deployment: - All instances (Prefill/Decode) register to Router with heartbeat - Prefill instances fetch Decode instance list from Router - Prefill instances establish eager RDMA connections to Decode instances - Register info includes: host_ip, port, role, version, is_paused, connected_decodes Changes: - Add RegisterManager class for managing PD registration and RDMA connections - Add version field to ModelConfig for model version tracking - Add connected_decodes to register_info for tracking connected Decode instances - Add FD_ENABLE_PD_RDMA_EAGER_CONNECT environment variable Test fixes: - Add None checks for load_config in FDConfig.__init__ - Add version attribute to test mock model configs Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * refine * remove test --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -107,6 +107,8 @@ def get_decode_ip_idx(task):
|
|||||||
"""For compatibility, get decode ip and idx from task"""
|
"""For compatibility, get decode ip and idx from task"""
|
||||||
if "decode_ip" in task:
|
if "decode_ip" in task:
|
||||||
decode_ip = task["decode_ip"]
|
decode_ip = task["decode_ip"]
|
||||||
|
elif "host_ip" in task:
|
||||||
|
decode_ip = task["host_ip"]
|
||||||
else:
|
else:
|
||||||
decode_ip = task["ip"]
|
decode_ip = task["ip"]
|
||||||
if "decode_rdma_ports" in task:
|
if "decode_rdma_ports" in task:
|
||||||
@@ -911,6 +913,7 @@ class CacheMessagerV1:
|
|||||||
response = {"task_id": task_id, "success": True}
|
response = {"task_id": task_id, "success": True}
|
||||||
self.engine_worker_queue.connect_task_response_barrier.wait()
|
self.engine_worker_queue.connect_task_response_barrier.wait()
|
||||||
self.engine_worker_queue.put_connect_rdma_task_response(response)
|
self.engine_worker_queue.put_connect_rdma_task_response(response)
|
||||||
|
logger.debug(f"_handle_connect_task send response: {response}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"handle_connect_task has exception: {e}, {traceback.format_exc()}")
|
logger.error(f"handle_connect_task has exception: {e}, {traceback.format_exc()}")
|
||||||
|
|
||||||
|
|||||||
+27
-1
@@ -24,6 +24,7 @@ from typing import Any, Dict, Literal, Optional, Union
|
|||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
import paddle.distributed as dist
|
import paddle.distributed as dist
|
||||||
|
import yaml
|
||||||
from packaging.version import parse as parse_version
|
from packaging.version import parse as parse_version
|
||||||
from paddleformers.transformers.configuration_utils import PretrainedConfig
|
from paddleformers.transformers.configuration_utils import PretrainedConfig
|
||||||
from typing_extensions import assert_never
|
from typing_extensions import assert_never
|
||||||
@@ -227,6 +228,7 @@ class ModelConfig:
|
|||||||
self.kv_cache_quant_scale_path = ""
|
self.kv_cache_quant_scale_path = ""
|
||||||
self.enable_entropy = False
|
self.enable_entropy = False
|
||||||
self.model_impl: ModelImpl = "auto"
|
self.model_impl: ModelImpl = "auto"
|
||||||
|
self.version: str = "init" # will override by the version.yaml in model dir
|
||||||
|
|
||||||
self.partial_rotary_factor: float = 1.0
|
self.partial_rotary_factor: float = 1.0
|
||||||
self.num_nextn_predict_layers = 0
|
self.num_nextn_predict_layers = 0
|
||||||
@@ -440,6 +442,17 @@ class ModelConfig:
|
|||||||
f"Config file path: {config_path}"
|
f"Config file path: {config_path}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def read_model_version(self):
|
||||||
|
"""
|
||||||
|
Read the version information from a YAML file located at 'version.yaml' within the model directory.
|
||||||
|
If the file exists, it extracts the 'version' field using yaml.safe_load.
|
||||||
|
Raises an assertion error if the file is not found at the specified path.
|
||||||
|
"""
|
||||||
|
version_path = os.path.join(self.model, "version.yaml")
|
||||||
|
assert os.path.exists(version_path), f"version.yaml not exist at {version_path}"
|
||||||
|
with open(version_path, "r", encoding="utf-8") as f:
|
||||||
|
self.version = yaml.safe_load(f)["version"]
|
||||||
|
|
||||||
def _get_default_runner_type(
|
def _get_default_runner_type(
|
||||||
self,
|
self,
|
||||||
architectures: list[str],
|
architectures: list[str],
|
||||||
@@ -1939,7 +1952,7 @@ class FDConfig:
|
|||||||
|
|
||||||
num_ranks = self.parallel_config.tensor_parallel_size * self.parallel_config.data_parallel_size
|
num_ranks = self.parallel_config.tensor_parallel_size * self.parallel_config.data_parallel_size
|
||||||
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
|
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
|
||||||
if num_ranks > self.max_chips_per_node and self.load_config.load_strategy != "meta":
|
if num_ranks > self.max_chips_per_node and self.load_config and self.load_config.load_strategy != "meta":
|
||||||
self.worker_num_per_node = self.max_chips_per_node
|
self.worker_num_per_node = self.max_chips_per_node
|
||||||
nnode = ceil_div(num_ranks, self.worker_num_per_node)
|
nnode = ceil_div(num_ranks, self.worker_num_per_node)
|
||||||
assert nnode == self.nnode, f"nnode: {nnode}, but got {self.nnode}"
|
assert nnode == self.nnode, f"nnode: {nnode}, but got {self.nnode}"
|
||||||
@@ -1953,6 +1966,16 @@ class FDConfig:
|
|||||||
if current_platform.is_intel_hpu():
|
if current_platform.is_intel_hpu():
|
||||||
self.parallel_config.device_ids = os.getenv("HPU_VISIBLE_DEVICES", self.parallel_config.device_ids)
|
self.parallel_config.device_ids = os.getenv("HPU_VISIBLE_DEVICES", self.parallel_config.device_ids)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.load_config
|
||||||
|
and self.load_config.dynamic_load_weight
|
||||||
|
and self.router_config
|
||||||
|
and self.router_config.router
|
||||||
|
):
|
||||||
|
# For RL scenario: version.yaml will be required for models in future releases.
|
||||||
|
# Temporarily enforce use router to be enabled.
|
||||||
|
self.model_config.read_model_version()
|
||||||
|
|
||||||
self.read_from_config()
|
self.read_from_config()
|
||||||
self.postprocess()
|
self.postprocess()
|
||||||
self.init_cache_info()
|
self.init_cache_info()
|
||||||
@@ -2293,6 +2316,9 @@ class FDConfig:
|
|||||||
"device_ids": self.local_device_ids,
|
"device_ids": self.local_device_ids,
|
||||||
"transfer_protocol": transfer_protocol,
|
"transfer_protocol": transfer_protocol,
|
||||||
"tp_size": self.parallel_config.tensor_parallel_size,
|
"tp_size": self.parallel_config.tensor_parallel_size,
|
||||||
|
"is_paused": False,
|
||||||
|
"version": self.model_config.version,
|
||||||
|
"connected_decodes": [],
|
||||||
}
|
}
|
||||||
logger.info(f"register_info: {self.register_info}")
|
logger.info(f"register_info: {self.register_info}")
|
||||||
|
|
||||||
|
|||||||
@@ -34,11 +34,11 @@ from typing import Dict, List, Optional, Tuple
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import paddle
|
import paddle
|
||||||
import requests
|
|
||||||
import zmq
|
import zmq
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import fastdeploy.metrics.trace as tracing
|
import fastdeploy.metrics.trace as tracing
|
||||||
|
from fastdeploy.engine.register_manager import RegisterManager
|
||||||
from fastdeploy.engine.request import (
|
from fastdeploy.engine.request import (
|
||||||
ControlRequest,
|
ControlRequest,
|
||||||
ControlResponse,
|
ControlResponse,
|
||||||
@@ -64,7 +64,6 @@ from fastdeploy.inter_communicator.fmq import FMQ
|
|||||||
from fastdeploy.metrics.metrics import main_process_metrics
|
from fastdeploy.metrics.metrics import main_process_metrics
|
||||||
from fastdeploy.model_executor.guided_decoding import schema_checker
|
from fastdeploy.model_executor.guided_decoding import schema_checker
|
||||||
from fastdeploy.plugins.token_processor import load_token_processor_plugins
|
from fastdeploy.plugins.token_processor import load_token_processor_plugins
|
||||||
from fastdeploy.router.utils import check_service_health
|
|
||||||
from fastdeploy.spec_decode import SpecMethod
|
from fastdeploy.spec_decode import SpecMethod
|
||||||
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
|
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
|
||||||
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
|
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
|
||||||
@@ -178,6 +177,13 @@ class EngineService:
|
|||||||
# between the CPU transfer process and the worker process.
|
# between the CPU transfer process and the worker process.
|
||||||
self.resource_manager.cache_manager.gpu_cache_lock = self.gpu_cache_lock
|
self.resource_manager.cache_manager.gpu_cache_lock = self.gpu_cache_lock
|
||||||
|
|
||||||
|
# Initialize RegisterManager
|
||||||
|
self._register_manager = RegisterManager(
|
||||||
|
cfg=self.cfg,
|
||||||
|
engine_worker_queue=self.engine_worker_queue,
|
||||||
|
get_is_paused=self._get_is_paused_safe,
|
||||||
|
)
|
||||||
|
|
||||||
if self.cfg.eplb_config.enable_eplb:
|
if self.cfg.eplb_config.enable_eplb:
|
||||||
current_suffix = self.cfg.parallel_config.local_engine_worker_queue_port
|
current_suffix = self.cfg.parallel_config.local_engine_worker_queue_port
|
||||||
init_eplb_signals(cfg, current_suffix)
|
init_eplb_signals(cfg, current_suffix)
|
||||||
@@ -210,7 +216,7 @@ class EngineService:
|
|||||||
if self.cfg.scheduler_config.splitwise_role == "decode":
|
if self.cfg.scheduler_config.splitwise_role == "decode":
|
||||||
self._decode_process_splitwise_requests()
|
self._decode_process_splitwise_requests()
|
||||||
|
|
||||||
self._register_to_router()
|
self._register_manager.start()
|
||||||
|
|
||||||
def start_worker_service(self, async_llm_pid=None):
|
def start_worker_service(self, async_llm_pid=None):
|
||||||
# Initialize IPC signals for worker management
|
# Initialize IPC signals for worker management
|
||||||
@@ -1370,6 +1376,11 @@ class EngineService:
|
|||||||
with self._pause_cond:
|
with self._pause_cond:
|
||||||
return {"is_paused": self.is_paused}
|
return {"is_paused": self.is_paused}
|
||||||
|
|
||||||
|
def _get_is_paused_safe(self) -> bool:
|
||||||
|
"""Thread-safe getter for is_paused state, used by RegisterManager."""
|
||||||
|
with self._pause_cond:
|
||||||
|
return self.is_paused
|
||||||
|
|
||||||
def _control_update_weights(self, control_request: ControlRequest) -> Optional[dict]:
|
def _control_update_weights(self, control_request: ControlRequest) -> Optional[dict]:
|
||||||
"""Update model weights
|
"""Update model weights
|
||||||
Args:
|
Args:
|
||||||
@@ -1387,7 +1398,20 @@ class EngineService:
|
|||||||
error_msg = "Pause LLM Engine first before calling updating weights"
|
error_msg = "Pause LLM Engine first before calling updating weights"
|
||||||
self.llm_logger.error(error_msg)
|
self.llm_logger.error(error_msg)
|
||||||
raise Exception(error_msg)
|
raise Exception(error_msg)
|
||||||
return self._call_worker(control_request, 60)
|
responses = self._call_worker(control_request, 60)
|
||||||
|
|
||||||
|
if responses:
|
||||||
|
new_version = None
|
||||||
|
for resp in responses:
|
||||||
|
# Expect each worker response to be a dict-like object
|
||||||
|
if isinstance(resp, dict) and "version" in resp:
|
||||||
|
new_version = resp.get("version")
|
||||||
|
self.llm_logger.info(f"Update Weights Version in Config: {new_version}")
|
||||||
|
break
|
||||||
|
if new_version is not None:
|
||||||
|
self.cfg.model_config.version = new_version
|
||||||
|
|
||||||
|
return responses
|
||||||
|
|
||||||
async def _wait_all_control_responses(self, request_id: str, timeout: int):
|
async def _wait_all_control_responses(self, request_id: str, timeout: int):
|
||||||
"""Wait for control responses from all workers with a global timeout.
|
"""Wait for control responses from all workers with a global timeout.
|
||||||
@@ -1699,55 +1723,6 @@ class EngineService:
|
|||||||
self.llm_logger.error(f"Clear data error: {e}")
|
self.llm_logger.error(f"Clear data error: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _register_to_router(self):
|
|
||||||
"""
|
|
||||||
Periodically send server information to the router for registeration, and it is used
|
|
||||||
as a heartbeat signal.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _register():
|
|
||||||
timeout = 5
|
|
||||||
sleep_seconds = 5
|
|
||||||
is_registered = False
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
api_server_host = self.cfg.router_config.api_server_host
|
|
||||||
api_server_port = self.cfg.router_config.api_server_port
|
|
||||||
api_server_url = f"http://{api_server_host}:{api_server_port}"
|
|
||||||
if not check_service_health(api_server_url):
|
|
||||||
time.sleep(sleep_seconds)
|
|
||||||
self.llm_logger.info("Wait for API service health and then register to router")
|
|
||||||
time.sleep(sleep_seconds)
|
|
||||||
continue
|
|
||||||
|
|
||||||
router_url = self.cfg.router_config.router
|
|
||||||
resp = requests.post(
|
|
||||||
f"{router_url}/register",
|
|
||||||
json=self.cfg.register_info,
|
|
||||||
timeout=timeout,
|
|
||||||
)
|
|
||||||
|
|
||||||
if resp.ok:
|
|
||||||
if not is_registered:
|
|
||||||
is_registered = True
|
|
||||||
self.llm_logger.info("Register to router successfully")
|
|
||||||
else:
|
|
||||||
self.llm_logger.error(
|
|
||||||
f"Send server info to router failed: {resp.status_code}, "
|
|
||||||
f"{resp.text}, {self.cfg.register_info}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
self.llm_logger.exception(f"Unexpected error during router registration: {e}")
|
|
||||||
|
|
||||||
time.sleep(sleep_seconds)
|
|
||||||
|
|
||||||
if self.cfg.router_config.router is None:
|
|
||||||
self.llm_logger.info("Router is not enabled, skip registering to router")
|
|
||||||
else:
|
|
||||||
register_thread = threading.Thread(target=_register, daemon=True)
|
|
||||||
register_thread.start()
|
|
||||||
|
|
||||||
def _exit_sub_services(self):
|
def _exit_sub_services(self):
|
||||||
"""
|
"""
|
||||||
exit sub services
|
exit sub services
|
||||||
|
|||||||
@@ -0,0 +1,396 @@
|
|||||||
|
"""
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from typing import Callable, Dict, List
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from fastdeploy import envs
|
||||||
|
from fastdeploy.router.utils import check_service_health
|
||||||
|
from fastdeploy.utils import register_manager_logger as logger
|
||||||
|
|
||||||
|
|
||||||
|
class RegisterManager:
|
||||||
|
"""
|
||||||
|
Manages Prefill/Decode instance registration and RDMA connection for PD disaggregation.
|
||||||
|
|
||||||
|
In PD (Prefill-Decode) disaggregated deployment:
|
||||||
|
- All instances (Prefill/Decode) register to Router with heartbeat
|
||||||
|
- Prefill instances fetch Decode instance list from Router
|
||||||
|
- Prefill instances establish eager RDMA connections to Decode instances
|
||||||
|
|
||||||
|
Thread Model:
|
||||||
|
- _register_to_router: Periodic heartbeat registration thread
|
||||||
|
- _eager_connect_loop: Periodic RDMA connection management thread
|
||||||
|
- _get_connect_rdma_task_response_loop: Async RDMA connection result receiver thread
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg,
|
||||||
|
engine_worker_queue,
|
||||||
|
get_is_paused: Callable[[], bool],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize RegisterManager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: FDConfig object containing router, scheduler, cache configurations
|
||||||
|
engine_worker_queue: Queue for communicating RDMA connect tasks with worker
|
||||||
|
get_is_paused: Callable that returns current engine pause state
|
||||||
|
"""
|
||||||
|
self.cfg = cfg
|
||||||
|
self.engine_worker_queue = engine_worker_queue
|
||||||
|
self.get_is_paused = get_is_paused
|
||||||
|
|
||||||
|
# Registration state
|
||||||
|
self._is_registered = False
|
||||||
|
|
||||||
|
# RDMA connection state (protected by _lock)
|
||||||
|
self.connected_decodes: List[Dict] = [] # Successfully connected Decode instances
|
||||||
|
self.connect_status: Dict[str, bool] = {} # task_id -> connection result
|
||||||
|
|
||||||
|
# Timing configuration (seconds)
|
||||||
|
self._timeout = 5 # HTTP request and RDMA connect timeout
|
||||||
|
self._sleep_seconds = 5 # Interval between iterations
|
||||||
|
|
||||||
|
self._lock = threading.Lock() # Protects connected_decodes and connect_status
|
||||||
|
|
||||||
|
def start(self) -> None:
|
||||||
|
"""Start background threads for registration and RDMA connection management."""
|
||||||
|
self._register_to_router()
|
||||||
|
self._start_eager_connect_loop()
|
||||||
|
|
||||||
|
def get_connected_decodes(self) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
Return a snapshot of successfully connected Decode instances.
|
||||||
|
Thread-safe: returns a copy to avoid concurrent modification issues.
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
return list(self.connected_decodes)
|
||||||
|
|
||||||
|
def is_registered(self) -> bool:
|
||||||
|
"""Return whether this instance has successfully registered to Router."""
|
||||||
|
return self._is_registered
|
||||||
|
|
||||||
|
def _register_to_router(self) -> None:
|
||||||
|
"""
|
||||||
|
Start background thread for periodic Router registration (heartbeat).
|
||||||
|
|
||||||
|
Registration info includes: host_ip, port, role, version, is_paused, etc.
|
||||||
|
This serves as both initial registration and keep-alive heartbeat.
|
||||||
|
"""
|
||||||
|
router_url = self.cfg.router_config.router
|
||||||
|
if router_url is None:
|
||||||
|
logger.info("Router is not enabled, skip registering to router")
|
||||||
|
return
|
||||||
|
|
||||||
|
def _register():
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
api_server_host = self.cfg.router_config.api_server_host
|
||||||
|
api_server_port = self.cfg.router_config.api_server_port
|
||||||
|
api_server_url = f"http://{api_server_host}:{api_server_port}"
|
||||||
|
if not check_service_health(api_server_url):
|
||||||
|
logger.info("Wait for API service health and then register to router")
|
||||||
|
time.sleep(self._sleep_seconds)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Update registration info
|
||||||
|
self.cfg.register_info["is_paused"] = self.get_is_paused()
|
||||||
|
self.cfg.register_info["version"] = self.cfg.model_config.version
|
||||||
|
self.cfg.register_info["connected_decodes"] = self.get_connected_decodes()
|
||||||
|
|
||||||
|
resp = requests.post(
|
||||||
|
f"{router_url}/register",
|
||||||
|
json=self.cfg.register_info,
|
||||||
|
timeout=self._timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
if resp.ok:
|
||||||
|
if not self._is_registered:
|
||||||
|
self._is_registered = True
|
||||||
|
logger.info("Register to router successfully")
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"Send server info to router failed: {resp.status_code}, "
|
||||||
|
f"{resp.text}, {self.cfg.register_info}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Unexpected error during router registration: {e}")
|
||||||
|
|
||||||
|
time.sleep(self._sleep_seconds)
|
||||||
|
|
||||||
|
register_thread = threading.Thread(target=_register, daemon=True)
|
||||||
|
register_thread.start()
|
||||||
|
|
||||||
|
def _start_eager_connect_loop(self) -> None:
|
||||||
|
"""
|
||||||
|
Start background threads for eager RDMA connection management.
|
||||||
|
|
||||||
|
Only enabled when all conditions are met:
|
||||||
|
- Router is configured
|
||||||
|
- This instance is Prefill role
|
||||||
|
- FD_ENABLE_PD_RDMA_EAGER_CONNECT=1
|
||||||
|
- RDMA transfer protocol is enabled
|
||||||
|
|
||||||
|
Starts two threads:
|
||||||
|
1. _eager_connect_loop: Periodically discovers and connects to Decode instances
|
||||||
|
2. _get_connect_rdma_task_response_loop: Receives async RDMA connection results
|
||||||
|
"""
|
||||||
|
if not self._should_enable_eager_connect():
|
||||||
|
logger.info("Eager RDMA connect is not enabled, skip")
|
||||||
|
return
|
||||||
|
|
||||||
|
def _eager_connect_loop():
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
self._eager_connect_iteration()
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Error in eager connect loop: {e}")
|
||||||
|
time.sleep(self._sleep_seconds)
|
||||||
|
|
||||||
|
connect_thread = threading.Thread(target=_eager_connect_loop, daemon=True)
|
||||||
|
connect_thread.start()
|
||||||
|
logger.info("Eager RDMA connect loop started")
|
||||||
|
|
||||||
|
def _get_connect_rdma_task_response_loop():
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
resp = self.engine_worker_queue.get_connect_rdma_task_response()
|
||||||
|
if resp:
|
||||||
|
task_id = resp["task_id"]
|
||||||
|
is_success = resp["success"]
|
||||||
|
with self._lock:
|
||||||
|
self.connect_status[task_id] = is_success
|
||||||
|
logger.debug(f"get_connect_rdma_task_response response: {resp}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"_keep_get_connect_rdma_task_response got error: {e}, " f"{traceback.format_exc()}")
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
get_resp_thread = threading.Thread(target=_get_connect_rdma_task_response_loop, daemon=True)
|
||||||
|
get_resp_thread.start()
|
||||||
|
logger.info("Get connect rdma task response loop started")
|
||||||
|
|
||||||
|
def _should_enable_eager_connect(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if eager RDMA connect should be enabled.
|
||||||
|
|
||||||
|
Returns True only when:
|
||||||
|
- Router URL is configured
|
||||||
|
- Instance role is 'prefill'
|
||||||
|
- FD_ENABLE_PD_RDMA_EAGER_CONNECT env is set
|
||||||
|
- RDMA protocol is in transfer_protocol list
|
||||||
|
- Local RDMA ports are configured
|
||||||
|
"""
|
||||||
|
if self.cfg.router_config.router is None:
|
||||||
|
return False
|
||||||
|
if self.cfg.scheduler_config.splitwise_role != "prefill":
|
||||||
|
return False
|
||||||
|
if not envs.FD_ENABLE_PD_RDMA_EAGER_CONNECT:
|
||||||
|
return False
|
||||||
|
|
||||||
|
transfer_protocol = self.cfg.register_info.get("transfer_protocol", [])
|
||||||
|
rdma_ports = self.cfg.cache_config.local_rdma_comm_ports
|
||||||
|
if not ("rdma" in transfer_protocol and rdma_ports):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _eager_connect_iteration(self) -> None:
|
||||||
|
"""
|
||||||
|
Single iteration of the eager RDMA connect loop.
|
||||||
|
|
||||||
|
Workflow:
|
||||||
|
1. Fetch Decode instances from Router (filtered by model version)
|
||||||
|
2. For new instances: check health -> check RDMA support -> establish RDMA connection
|
||||||
|
3. For existing instances: verify health and RDMA connection status
|
||||||
|
4. Remove unhealthy or disconnected instances from connected_decodes
|
||||||
|
"""
|
||||||
|
if not self._is_registered:
|
||||||
|
logger.info("This instance has not registered to router, skip eager connect in this step")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Step 1: Fetch Decode instances from Router
|
||||||
|
instances = self._fetch_decode_instances_internal()
|
||||||
|
if not instances:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Step 2: Process new instances - try to establish RDMA connection
|
||||||
|
with self._lock:
|
||||||
|
connected_decodes_snapshot = list(self.connected_decodes)
|
||||||
|
existing_keys = {self._get_instance_key(inst) for inst in connected_decodes_snapshot}
|
||||||
|
|
||||||
|
for instance in instances:
|
||||||
|
try:
|
||||||
|
instance_key = self._get_instance_key(instance)
|
||||||
|
|
||||||
|
# Skip already connected instances
|
||||||
|
if instance_key in existing_keys:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Skip unhealthy instances
|
||||||
|
if not self._check_instance_health(instance):
|
||||||
|
logger.debug(f"Instance {instance_key} is unhealthy, skip")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Skip instances without RDMA support
|
||||||
|
if not self._supports_rdma(instance):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Try RDMA connection
|
||||||
|
if self._try_rdma_connect(instance):
|
||||||
|
with self._lock:
|
||||||
|
if instance not in self.connected_decodes:
|
||||||
|
self.connected_decodes.append(instance)
|
||||||
|
logger.info(f"RDMA connect succeeded: {instance_key}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"RDMA connect failed: {instance_key}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Error processing instance {instance}: {e}")
|
||||||
|
|
||||||
|
# Step 3: Verify existing connections - check health and RDMA status
|
||||||
|
to_remove = []
|
||||||
|
for instance in connected_decodes_snapshot:
|
||||||
|
instance_key = self._get_instance_key(instance)
|
||||||
|
|
||||||
|
if not self._check_instance_health(instance):
|
||||||
|
to_remove.append(instance)
|
||||||
|
logger.warning(f"Instance {instance_key} is unhealthy, will remove")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not self._check_rdma_connection(instance):
|
||||||
|
to_remove.append(instance)
|
||||||
|
logger.warning(f"Instance {instance_key} RDMA connection lost, will remove")
|
||||||
|
|
||||||
|
# Step 4: Remove failed instances from connected list
|
||||||
|
for instance in to_remove:
|
||||||
|
with self._lock:
|
||||||
|
if instance in self.connected_decodes:
|
||||||
|
self.connected_decodes.remove(instance)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Connected decodes num is {len(self.connected_decodes)}, "
|
||||||
|
f"connected decodes are {[self._get_instance_key(inst) for inst in self.connected_decodes]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _fetch_decode_instances_internal(self) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
Fetch Decode instance list from Router.
|
||||||
|
|
||||||
|
Queries Router's /decode_instances endpoint with model version filter.
|
||||||
|
Returns empty list on error to allow retry in next iteration.
|
||||||
|
"""
|
||||||
|
router_url = self.cfg.router_config.router
|
||||||
|
if router_url is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
version = self.cfg.model_config.version
|
||||||
|
resp = requests.get(
|
||||||
|
f"{router_url}/decode_instances",
|
||||||
|
params={"version": version},
|
||||||
|
timeout=self._timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
if resp.ok:
|
||||||
|
instances = resp.json()
|
||||||
|
logger.debug(
|
||||||
|
f"Fetched {len(instances)} decode instances from router, "
|
||||||
|
f"{[self._get_instance_key(instance) for instance in instances]}"
|
||||||
|
)
|
||||||
|
return instances
|
||||||
|
else:
|
||||||
|
logger.error(f"Fetch decode instances failed: {resp.status_code}")
|
||||||
|
return []
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Error fetching decode instances: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _get_instance_key(self, instance: Dict) -> str:
|
||||||
|
"""Generate unique identifier for an instance: 'host_ip:port'."""
|
||||||
|
return f"{instance.get('host_ip')}:{instance.get('port')}"
|
||||||
|
|
||||||
|
def _supports_rdma(self, instance: Dict) -> bool:
|
||||||
|
"""Check if instance supports RDMA transfer protocol and has RDMA ports configured."""
|
||||||
|
transfer_protocol = instance.get("transfer_protocol", [])
|
||||||
|
return "rdma" in transfer_protocol and instance.get("rdma_ports")
|
||||||
|
|
||||||
|
def _check_instance_health(self, instance: Dict) -> bool:
|
||||||
|
"""Check if Decode instance is healthy via HTTP /health endpoint."""
|
||||||
|
try:
|
||||||
|
host_ip = instance.get("host_ip")
|
||||||
|
port = instance.get("port")
|
||||||
|
url = f"http://{host_ip}:{port}/health"
|
||||||
|
response = requests.get(url, timeout=self._timeout)
|
||||||
|
return response.status_code == 200
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"_is_decode_health error: {e}, host: {host_ip}, port: {port}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _try_rdma_connect(self, instance: Dict) -> bool:
|
||||||
|
"""
|
||||||
|
Attempt to establish RDMA connection to a Decode instance.
|
||||||
|
|
||||||
|
Workflow:
|
||||||
|
1. Generate unique task_id and submit connect task to engine_worker_queue
|
||||||
|
2. Wait for connection result in connect_status dict (set by response loop)
|
||||||
|
3. Return True if connected successfully within timeout, False otherwise
|
||||||
|
|
||||||
|
Note: If already connected, the underlying RDMA layer will reuse existing connection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instance: Decode instance info dict with 'host_ip' and 'rdma_ports'
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if connection succeeded, False if failed or timeout
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
key = self._get_instance_key(instance)
|
||||||
|
task_id = f"{key}-{uuid4().hex}"
|
||||||
|
task = {"task_id": task_id, "ip": instance.get("host_ip"), "rdma_ports": instance.get("rdma_ports")}
|
||||||
|
self.engine_worker_queue.put_connect_rdma_task(task)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
while time.time() - start_time <= self._timeout:
|
||||||
|
with self._lock:
|
||||||
|
if task_id in self.connect_status:
|
||||||
|
result = self.connect_status[task_id]
|
||||||
|
del self.connect_status[task_id]
|
||||||
|
return result
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
# Timeout: clean up any late-arriving result to prevent memory leak
|
||||||
|
with self._lock:
|
||||||
|
self.connect_status.pop(task_id, None)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"_try_rdma_connect error: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _check_rdma_connection(self, instance: Dict) -> bool:
|
||||||
|
"""
|
||||||
|
Verify RDMA connection to instance is still alive.
|
||||||
|
|
||||||
|
Reuses _try_rdma_connect() since the underlying RDMA layer:
|
||||||
|
- Returns success immediately if connection already exists
|
||||||
|
- Attempts reconnection if connection was lost
|
||||||
|
"""
|
||||||
|
return self._try_rdma_connect(instance)
|
||||||
@@ -171,6 +171,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"FD_FILL_BITMASK_BATCH": lambda: int(os.getenv("FD_FILL_BITMASK_BATCH", "4")),
|
"FD_FILL_BITMASK_BATCH": lambda: int(os.getenv("FD_FILL_BITMASK_BATCH", "4")),
|
||||||
"FD_ENABLE_PDL": lambda: int(os.getenv("FD_ENABLE_PDL", "1")),
|
"FD_ENABLE_PDL": lambda: int(os.getenv("FD_ENABLE_PDL", "1")),
|
||||||
"FD_ENABLE_ASYNC_LLM": lambda: int(os.getenv("FD_ENABLE_ASYNC_LLM", "0")),
|
"FD_ENABLE_ASYNC_LLM": lambda: int(os.getenv("FD_ENABLE_ASYNC_LLM", "0")),
|
||||||
|
# Enable early RDMA connection for PD disaggregation
|
||||||
|
"FD_ENABLE_PD_RDMA_EAGER_CONNECT": lambda: bool(int(os.getenv("FD_ENABLE_PD_RDMA_EAGER_CONNECT", "0"))),
|
||||||
"FD_GUIDANCE_DISABLE_ADDITIONAL": lambda: bool(int(os.getenv("FD_GUIDANCE_DISABLE_ADDITIONAL", "1"))),
|
"FD_GUIDANCE_DISABLE_ADDITIONAL": lambda: bool(int(os.getenv("FD_GUIDANCE_DISABLE_ADDITIONAL", "1"))),
|
||||||
"FD_LLGUIDANCE_LOG_LEVEL": lambda: int(os.getenv("FD_LLGUIDANCE_LOG_LEVEL", "0")),
|
"FD_LLGUIDANCE_LOG_LEVEL": lambda: int(os.getenv("FD_LLGUIDANCE_LOG_LEVEL", "0")),
|
||||||
# "Number of tokens in the group for Mixture of Experts (MoE) computation processing on HPU"
|
# "Number of tokens in the group for Mixture of Experts (MoE) computation processing on HPU"
|
||||||
|
|||||||
+35
-17
@@ -12,6 +12,7 @@ import random
|
|||||||
import traceback
|
import traceback
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
|
from typing import Dict, List, Optional
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
@@ -114,23 +115,26 @@ class Router:
|
|||||||
raise RuntimeError(f"Instance {inst_info} is not healthy")
|
raise RuntimeError(f"Instance {inst_info} is not healthy")
|
||||||
|
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
if inst_info.role == InstanceRole.MIXED and inst_info not in self.mixed_servers:
|
instance_key = inst_info.get_key()
|
||||||
self.mixed_servers.append(inst_info)
|
|
||||||
logger.info(
|
if inst_info.role == InstanceRole.MIXED:
|
||||||
f"Register mixed instance success: {inst_info}, " f"total mixed: {len(self.mixed_servers)}"
|
self._update_or_add_instance(self.mixed_servers, inst_info, instance_key, "mixed")
|
||||||
)
|
elif inst_info.role == InstanceRole.PREFILL:
|
||||||
elif inst_info.role == InstanceRole.PREFILL and inst_info not in self.prefill_servers:
|
self._update_or_add_instance(self.prefill_servers, inst_info, instance_key, "prefill")
|
||||||
self.prefill_servers.append(inst_info)
|
elif inst_info.role == InstanceRole.DECODE:
|
||||||
logger.info(
|
self._update_or_add_instance(self.decode_servers, inst_info, instance_key, "decode")
|
||||||
f"Register prefill instance success: {inst_info}, "
|
|
||||||
f"prefill: {len(self.prefill_servers)}, decode: {len(self.decode_servers)}"
|
def _update_or_add_instance(self, server_list: List, inst_info: InstanceInfo, key: str, role_name: str):
|
||||||
)
|
"""Update existing instance or add new one based on key (host_ip:port)."""
|
||||||
elif inst_info.role == InstanceRole.DECODE and inst_info not in self.decode_servers:
|
for i, existing in enumerate(server_list):
|
||||||
self.decode_servers.append(inst_info)
|
if existing.get_key() == key:
|
||||||
logger.info(
|
if existing != inst_info:
|
||||||
f"Register decode instance success: {inst_info}, "
|
server_list[i] = inst_info
|
||||||
f"prefill: {len(self.prefill_servers)}, decode: {len(self.decode_servers)}"
|
logger.info(f"Updated {role_name} instance, key: {key}, inst_info: {inst_info}")
|
||||||
)
|
return
|
||||||
|
|
||||||
|
server_list.append(inst_info)
|
||||||
|
logger.info(f"Register {role_name} instance success: {inst_info}, total {role_name}: {len(server_list)}")
|
||||||
|
|
||||||
async def registered_number(self):
|
async def registered_number(self):
|
||||||
"""Get number of registered instances"""
|
"""Get number of registered instances"""
|
||||||
@@ -140,6 +144,14 @@ class Router:
|
|||||||
"decode": len(self.decode_servers),
|
"decode": len(self.decode_servers),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def get_decode_instances(self, version: Optional[str] = None) -> List[Dict]:
|
||||||
|
"""Get all registered decode instances, optionally filtered by version"""
|
||||||
|
async with self.lock:
|
||||||
|
instances = self.decode_servers
|
||||||
|
if version is not None:
|
||||||
|
instances = [inst for inst in instances if inst.version == version]
|
||||||
|
return [inst.to_dict() for inst in instances]
|
||||||
|
|
||||||
async def select_pd(self):
|
async def select_pd(self):
|
||||||
"""Select one prefill and one decode server"""
|
"""Select one prefill and one decode server"""
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
@@ -456,6 +468,12 @@ async def registered_number():
|
|||||||
return await app.state.router.registered_number()
|
return await app.state.router.registered_number()
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/decode_instances")
|
||||||
|
async def decode_instances(version: Optional[str] = None):
|
||||||
|
"""Get all registered decode instances, optionally filtered by version"""
|
||||||
|
return await app.state.router.get_decode_instances(version)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions")
|
@app.post("/v1/chat/completions")
|
||||||
async def create_chat_completion(request_data: dict):
|
async def create_chat_completion(request_data: dict):
|
||||||
return await app.state.router.handle_request(request_data, "v1/chat/completions")
|
return await app.state.router.handle_request(request_data, "v1/chat/completions")
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from dataclasses import MISSING, asdict, dataclass, field, fields
|
from dataclasses import MISSING, asdict, dataclass, field, fields
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, List, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import requests
|
import requests
|
||||||
@@ -41,6 +41,9 @@ class InstanceInfo:
|
|||||||
rdma_ports: Union[List[str], List[int]] = field(default_factory=list)
|
rdma_ports: Union[List[str], List[int]] = field(default_factory=list)
|
||||||
device_ids: Union[List[str], List[int]] = field(default_factory=list)
|
device_ids: Union[List[str], List[int]] = field(default_factory=list)
|
||||||
tp_size: int = 1
|
tp_size: int = 1
|
||||||
|
is_paused: bool = False
|
||||||
|
version: Optional[str] = None
|
||||||
|
connected_decodes: List[Dict] = field(default_factory=list)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, info_dict: dict[str, Any]) -> "InstanceInfo":
|
def from_dict(cls, info_dict: dict[str, Any]) -> "InstanceInfo":
|
||||||
@@ -92,6 +95,10 @@ class InstanceInfo:
|
|||||||
url = f"http://{url}"
|
url = f"http://{url}"
|
||||||
return url
|
return url
|
||||||
|
|
||||||
|
def get_key(self) -> str:
|
||||||
|
"""Generate unique identifier for an instance: 'host_ip:port'."""
|
||||||
|
return f"{self.host_ip}:{self.port}"
|
||||||
|
|
||||||
|
|
||||||
def check_service_health(base_url: str, timeout: int = 3) -> bool:
|
def check_service_health(base_url: str, timeout: int = 3) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1163,6 +1163,7 @@ trace_logger = FastDeployLogger().get_trace_logger("trace", "trace.log")
|
|||||||
router_logger = get_logger("router", "router.log")
|
router_logger = get_logger("router", "router.log")
|
||||||
fmq_logger = get_logger("fmq", "fmq.log")
|
fmq_logger = get_logger("fmq", "fmq.log")
|
||||||
obj_logger = get_logger("obj", "obj.log") # debug内存问题
|
obj_logger = get_logger("obj", "obj.log") # debug内存问题
|
||||||
|
register_manager_logger = get_logger("register_manager", "register_manager.log")
|
||||||
|
|
||||||
|
|
||||||
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
|
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
|
||||||
|
|||||||
@@ -1152,31 +1152,6 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
|
|||||||
eng._control_update_weights(ControlRequest(request_id="ctrl", method="update_weights"))
|
eng._control_update_weights(ControlRequest(request_id="ctrl", method="update_weights"))
|
||||||
self._detach_finalizer(eng)
|
self._detach_finalizer(eng)
|
||||||
|
|
||||||
def test_register_to_router_disabled(self):
|
|
||||||
eng = self._make_mixed_engine()
|
|
||||||
eng.cfg.router_config.router = None
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(eng, "llm_logger") as mock_logger,
|
|
||||||
patch("fastdeploy.engine.common_engine.threading.Thread") as thread_mock,
|
|
||||||
):
|
|
||||||
eng._register_to_router()
|
|
||||||
|
|
||||||
mock_logger.info.assert_called()
|
|
||||||
thread_mock.assert_not_called()
|
|
||||||
self._detach_finalizer(eng)
|
|
||||||
|
|
||||||
def test_register_to_router_enabled_starts_thread(self):
|
|
||||||
eng = self._make_mixed_engine()
|
|
||||||
eng.cfg.router_config.router = "http://router"
|
|
||||||
|
|
||||||
with patch("fastdeploy.engine.common_engine.threading.Thread") as thread_mock:
|
|
||||||
eng._register_to_router()
|
|
||||||
|
|
||||||
thread_mock.assert_called_once()
|
|
||||||
thread_mock.return_value.start.assert_called_once()
|
|
||||||
self._detach_finalizer(eng)
|
|
||||||
|
|
||||||
def test_insert_zmq_task_to_scheduler_normal_request(self):
|
def test_insert_zmq_task_to_scheduler_normal_request(self):
|
||||||
eng = self._make_mixed_engine()
|
eng = self._make_mixed_engine()
|
||||||
eng.running = True
|
eng.running = True
|
||||||
|
|||||||
@@ -177,6 +177,7 @@ class TestInitEplbSignals(unittest.TestCase):
|
|||||||
model_cfg.model = "/test/model"
|
model_cfg.model = "/test/model"
|
||||||
model_cfg.architectures = ["test_model"]
|
model_cfg.architectures = ["test_model"]
|
||||||
model_cfg.mm_max_tokens_per_item = None
|
model_cfg.mm_max_tokens_per_item = None
|
||||||
|
model_cfg.version = None # Required for register_info
|
||||||
cache_cfg.bytes_per_layer_per_block = 1
|
cache_cfg.bytes_per_layer_per_block = 1
|
||||||
|
|
||||||
parallel_cfg = ParallelConfig(args)
|
parallel_cfg = ParallelConfig(args)
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ class TestRedundantExpertManager(unittest.TestCase):
|
|||||||
model_cfg.model = "/test/model"
|
model_cfg.model = "/test/model"
|
||||||
model_cfg.architectures = ["test_model"]
|
model_cfg.architectures = ["test_model"]
|
||||||
model_cfg.mm_max_tokens_per_item = None
|
model_cfg.mm_max_tokens_per_item = None
|
||||||
|
model_cfg.version = None # Required for register_info
|
||||||
cache_cfg.bytes_per_layer_per_block = 1
|
cache_cfg.bytes_per_layer_per_block = 1
|
||||||
|
|
||||||
parallel_cfg = ParallelConfig(args)
|
parallel_cfg = ParallelConfig(args)
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ class FakeModelConfig:
|
|||||||
self.logprobs_mode = "raw_logprobs"
|
self.logprobs_mode = "raw_logprobs"
|
||||||
self.architectures = ["test_model"]
|
self.architectures = ["test_model"]
|
||||||
self.mm_max_tokens_per_item = None
|
self.mm_max_tokens_per_item = None
|
||||||
|
self.version = None # Required for register_info
|
||||||
|
|
||||||
|
|
||||||
class FakeLoadConfig:
|
class FakeLoadConfig:
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ def make_prefix_cache_manager(max_num_seqs, enable_mm=False, num_gpu_blocks_over
|
|||||||
model_cfg.print = print
|
model_cfg.print = print
|
||||||
model_cfg.architectures = ["test_model"]
|
model_cfg.architectures = ["test_model"]
|
||||||
model_cfg.mm_max_tokens_per_item = None
|
model_cfg.mm_max_tokens_per_item = None
|
||||||
|
model_cfg.version = None # Required for register_info
|
||||||
cache_cfg.bytes_per_token_per_layer = 1
|
cache_cfg.bytes_per_token_per_layer = 1
|
||||||
|
|
||||||
parallel_cfg = ParallelConfig(args)
|
parallel_cfg = ParallelConfig(args)
|
||||||
|
|||||||
@@ -81,6 +81,7 @@ def _build_manager(
|
|||||||
model_cfg.max_model_len = max_model_len
|
model_cfg.max_model_len = max_model_len
|
||||||
model_cfg.architectures = architectures or ["test_model"]
|
model_cfg.architectures = architectures or ["test_model"]
|
||||||
model_cfg.mm_max_tokens_per_item = None
|
model_cfg.mm_max_tokens_per_item = None
|
||||||
|
model_cfg.version = None # Required for register_info
|
||||||
cache_cfg.bytes_per_token_per_layer = 1
|
cache_cfg.bytes_per_token_per_layer = 1
|
||||||
cache_cfg.kv_cache_ratio = 1.0
|
cache_cfg.kv_cache_ratio = 1.0
|
||||||
parallel_cfg = ParallelConfig(args)
|
parallel_cfg = ParallelConfig(args)
|
||||||
@@ -142,6 +143,7 @@ class TestResourceManagerV1(unittest.TestCase):
|
|||||||
model_cfg.max_model_len = 3200
|
model_cfg.max_model_len = 3200
|
||||||
model_cfg.architectures = ["test_model"]
|
model_cfg.architectures = ["test_model"]
|
||||||
model_cfg.mm_max_tokens_per_item = None
|
model_cfg.mm_max_tokens_per_item = None
|
||||||
|
model_cfg.version = None # Required for register_info
|
||||||
cache_cfg.bytes_per_token_per_layer = 1
|
cache_cfg.bytes_per_token_per_layer = 1
|
||||||
cache_cfg.kv_cache_ratio = 1.0
|
cache_cfg.kv_cache_ratio = 1.0
|
||||||
parallel_cfg = ParallelConfig(args)
|
parallel_cfg = ParallelConfig(args)
|
||||||
@@ -304,6 +306,7 @@ class TestRevertChunkedMMInput(unittest.TestCase):
|
|||||||
model_cfg.max_model_len = 3200
|
model_cfg.max_model_len = 3200
|
||||||
model_cfg.architectures = ["test_model"]
|
model_cfg.architectures = ["test_model"]
|
||||||
model_cfg.mm_max_tokens_per_item = None
|
model_cfg.mm_max_tokens_per_item = None
|
||||||
|
model_cfg.version = None # Required for register_info
|
||||||
cache_cfg.bytes_per_token_per_layer = 1
|
cache_cfg.bytes_per_token_per_layer = 1
|
||||||
cache_cfg.kv_cache_ratio = 1.0
|
cache_cfg.kv_cache_ratio = 1.0
|
||||||
cache_cfg.block_size = 64
|
cache_cfg.block_size = 64
|
||||||
|
|||||||
Reference in New Issue
Block a user