mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 17:11:21 +08:00
[Feature] [PD Disaggregation] simplify configuration for pd-disaggregated deployment, and refactor post-init and usage for all ports (#5415)
* [feat] simplify configuration for pd-disaggregated deployment, and refactor post-init and usage for all ports * [fix] fix some bugs * [fix] fix rdma port for cache manager/messager * [fix] temporarily cancel port availability check to see if it can pass ci test * [feat] simplify args for multi api server * [fix] fix dp * [fix] fix port for xpu * [fix] add tests for ports post processing & fix ci * [test] fix test_multi_api_server * [fix] fix rdma_comm_ports args for multi_api_server * [fix] fix test_common_engine * [fix] fix test_cache_transfer_manager * [chore] automatically setting FD_ENABLE_MULTI_API_SERVER * [fix] avoid api server from creating engine_args twice * [fix] fix test_run_batch * [fix] fix test_metrics * [fix] fix splitwise connector init * [test] add test_rdma_transfer and test_expert_service * [fix] fix code syntax * [fix] fix test_rdma_transfer and build wheel with rdma script
This commit is contained in:
@@ -20,37 +20,92 @@ import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
from fastdeploy.utils import get_logger, is_port_available
|
||||
from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.utils import find_free_ports, get_logger, is_port_available
|
||||
|
||||
logger = get_logger("multi_api_server", "multi_api_server.log")
|
||||
|
||||
|
||||
def start_servers(server_count, server_args, ports, metrics_ports, controller_ports):
|
||||
processes = []
|
||||
logger.info(f"Starting servers on ports: {ports} with args: {server_args} and metrics ports: {metrics_ports}")
|
||||
for i in range(len(server_args)):
|
||||
if server_args[i] == "--engine-worker-queue-port":
|
||||
engine_worker_queue_port = server_args[i + 1].split(",")
|
||||
break
|
||||
def start_servers(
|
||||
server_count=None,
|
||||
device_count=None,
|
||||
server_args=None,
|
||||
ports=None,
|
||||
metrics_ports=None,
|
||||
controller_ports=None,
|
||||
):
|
||||
ports = ports.split(",")
|
||||
if not check_param(ports, server_count):
|
||||
return
|
||||
if not check_param(metrics_ports, server_count):
|
||||
return
|
||||
if not check_param(engine_worker_queue_port, server_count):
|
||||
return
|
||||
|
||||
if metrics_ports != "-1":
|
||||
metrics_ports = metrics_ports.split(",")
|
||||
if not check_param(metrics_ports, server_count):
|
||||
return
|
||||
|
||||
if controller_ports != "-1":
|
||||
controller_ports = controller_ports.split(",")
|
||||
if not check_param(controller_ports, server_count):
|
||||
return
|
||||
else:
|
||||
controller_ports = [-1] * server_count
|
||||
# check_param(server_args, server_count)
|
||||
|
||||
logger.info(f"Starting servers on ports: {ports} with args: {server_args} and metrics ports: {metrics_ports}")
|
||||
port_idx = {}
|
||||
for i in range(len(server_args)):
|
||||
if server_args[i] == "--engine-worker-queue-port":
|
||||
port_idx["engine_worker_queue_port"] = i + 1
|
||||
if server_args[i] == "--cache-queue-port":
|
||||
port_idx["cache_queue_port"] = i + 1
|
||||
if server_args[i] == "--pd-comm-port":
|
||||
port_idx["pd_comm_port"] = i + 1
|
||||
if server_args[i] == "--rdma-comm-ports":
|
||||
port_idx["rdma_comm_ports"] = i + 1
|
||||
|
||||
if "engine_worker_queue_port" not in port_idx:
|
||||
port = find_free_ports(num_ports=server_count)
|
||||
server_args += ["--engine-worker-queue-port", ",".join(map(str, port))]
|
||||
port_idx["engine_worker_queue_port"] = len(server_args) - 1
|
||||
logger.info(f"No --engine-worker-queue-port specified, using random ports: {port}")
|
||||
engine_worker_queue_port = server_args[port_idx["engine_worker_queue_port"]].split(",")
|
||||
if not check_param(engine_worker_queue_port, server_count):
|
||||
return
|
||||
|
||||
if "cache_queue_port" not in port_idx:
|
||||
port = find_free_ports(num_ports=server_count)
|
||||
server_args += ["--cache-queue-port", ",".join(map(str, port))]
|
||||
port_idx["cache_queue_port"] = len(server_args) - 1
|
||||
logger.info(f"No --cache-queue-port specified, using random ports: {port}")
|
||||
cache_queue_port = server_args[port_idx["cache_queue_port"]].split(",")
|
||||
if not check_param(cache_queue_port, server_count):
|
||||
return
|
||||
|
||||
if "pd_comm_port" not in port_idx:
|
||||
port = find_free_ports(num_ports=server_count)
|
||||
server_args += ["--pd-comm-port", ",".join(map(str, port))]
|
||||
port_idx["pd_comm_port"] = len(server_args) - 1
|
||||
logger.info(f"No --pd-comm-port specified, using random ports: {port}")
|
||||
pd_comm_port = server_args[port_idx["pd_comm_port"]].split(",")
|
||||
if not check_param(pd_comm_port, server_count):
|
||||
return
|
||||
|
||||
if "rdma_comm_ports" not in port_idx:
|
||||
port = find_free_ports(num_ports=device_count)
|
||||
server_args += ["--rdma-comm-ports", ",".join(map(str, port))]
|
||||
port_idx["rdma_comm_ports"] = len(server_args) - 1
|
||||
logger.info(f"No --rdma-comm-ports specified, using random ports: {port}")
|
||||
rdma_comm_ports = server_args[port_idx["rdma_comm_ports"]].split(",")
|
||||
if not check_param(rdma_comm_ports, device_count):
|
||||
return
|
||||
|
||||
logger.info(f"Modified server_args: {server_args}")
|
||||
processes = []
|
||||
for i in range(server_count):
|
||||
port = int(ports[i])
|
||||
metrics_port = int(metrics_ports[i])
|
||||
controller_port = int(controller_ports[i])
|
||||
|
||||
env = os.environ.copy()
|
||||
env["FD_ENABLE_MULTI_API_SERVER"] = "1"
|
||||
env["FD_LOG_DIR"] = env.get("FD_LOG_DIR", "log") + f"/log_{i}"
|
||||
cmd = [
|
||||
sys.executable,
|
||||
@@ -59,13 +114,13 @@ def start_servers(server_count, server_args, ports, metrics_ports, controller_po
|
||||
*server_args,
|
||||
"--port",
|
||||
str(port),
|
||||
"--metrics-port",
|
||||
str(metrics_port),
|
||||
"--controller-port",
|
||||
str(controller_port),
|
||||
"--local-data-parallel-id",
|
||||
str(i),
|
||||
]
|
||||
if metrics_ports != "-1":
|
||||
cmd += ["--metrics-port", metrics_ports[i]]
|
||||
|
||||
# 启动子进程
|
||||
proc = subprocess.Popen(cmd, env=env)
|
||||
@@ -89,21 +144,25 @@ def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ports", default="8000,8002", type=str, help="ports to the http server")
|
||||
parser.add_argument("--num-servers", default=2, type=int, help="number of workers")
|
||||
parser.add_argument("--metrics-ports", default="8800,8802", type=str, help="ports for metrics server")
|
||||
parser.add_argument("--metrics-ports", default="-1", type=str, help="ports for metrics server")
|
||||
parser.add_argument("--controller-ports", default="-1", type=str, help="ports for controller server port")
|
||||
parser.add_argument("--args", nargs=argparse.REMAINDER, help="remaining arguments are passed to api_server.py")
|
||||
args = parser.parse_args()
|
||||
|
||||
logger.info(f"Starting {args.num_servers} servers on ports: {args.ports} with args: {args.args}")
|
||||
# check_param(args.ports, args.num_servers)
|
||||
# check_param(args.metrics_ports, args.num_servers)
|
||||
# check_param(args.args.engine_worker_queue_port, args.num_servers)
|
||||
|
||||
device_count = 0
|
||||
if current_platform.is_cuda():
|
||||
device_count = len(os.getenv("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7").split(","))
|
||||
elif current_platform.is_xpu():
|
||||
device_count = len(os.getenv("XPU_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7").split(","))
|
||||
|
||||
processes = start_servers(
|
||||
server_count=args.num_servers,
|
||||
device_count=device_count,
|
||||
server_args=args.args,
|
||||
ports=args.ports.split(","),
|
||||
metrics_ports=args.metrics_ports.split(","),
|
||||
ports=args.ports,
|
||||
metrics_ports=args.metrics_ports,
|
||||
controller_ports=args.controller_ports,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user