Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2026-04-23 00:17:25 +08:00.
[Feature] [PD Disaggregation] simplify configuration for pd-disaggregated deployment, and refactor post-init and usage for all ports (#5415)
* [feat] simplify configuration for pd-disaggregated deployment, and refactor post-init and usage for all ports
* [fix] fix some bugs
* [fix] fix rdma port for cache manager/messager
* [fix] temporarily cancel port availability check to see if it can pass ci test
* [feat] simplify args for multi api server
* [fix] fix dp
* [fix] fix port for xpu
* [fix] add tests for ports post processing & fix ci
* [test] fix test_multi_api_server
* [fix] fix rdma_comm_ports args for multi_api_server
* [fix] fix test_common_engine
* [fix] fix test_cache_transfer_manager
* [chore] automatically setting FD_ENABLE_MULTI_API_SERVER
* [fix] avoid api server from creating engine_args twice
* [fix] fix test_run_batch
* [fix] fix test_metrics
* [fix] fix splitwise connector init
* [test] add test_rdma_transfer and test_expert_service
* [fix] fix code syntax
* [fix] fix test_rdma_transfer and build wheel with rdma script
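Editor's note: the central idea of this refactor is that every port-like argument (a single int, a comma-separated string, or a list) is normalized once by a `parse_ports` helper and then sliced per data-parallel rank. A minimal sketch of the assumed behavior, inferred from the call sites in the diff below; the real helper lives in fastdeploy.utils and may differ in detail:

    # Sketch only: assumed semantics of fastdeploy.utils.parse_ports.
    from typing import Optional, Union

    def parse_ports(ports: Optional[Union[int, str, list]]) -> Optional[list]:
        """Normalize an int, a comma-separated string, or a list into a list of ints."""
        if ports is None:
            return None  # callers treat None as "auto-assign later"
        if isinstance(ports, int):
            return [ports]
        if isinstance(ports, str):
            return [int(p) for p in ports.split(",")]
        return [int(p) for p in ports]

    assert parse_ports("8001,8002") == [8001, 8002]
    assert parse_ports(8001) == [8001]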
@@ -22,12 +22,12 @@ if [ -z "${KVCACHE_RDMA_NICS}" ]; then
 fi

 unset http_proxy && unset https_proxy
-rm -rf log_*
-source ./utils.sh
+source ${SCRIPT_DIR}/utils.sh

 P_PORT=52400
 D_PORT=52500
-REDIS_PORT="${REDIS_PORT:-56388}"
+REDIS_PORT="${REDIS_PORT:-6379}"
+LOG_DATE=$(date +%Y%m%d_%H%M%S)

 ports=(
     $P_PORT $((P_PORT + 1)) $((P_PORT + 2)) $((P_PORT + 3)) $((P_PORT + 4)) $((P_PORT + 5))
@@ -51,8 +51,8 @@ sleep 1

 # start prefill
 export CUDA_VISIBLE_DEVICES=0
-export FD_LOG_DIR="log_prefill"
-mkdir -p ${FD_LOG_DIR}
+export FD_LOG_DIR="log/$LOG_DATE/prefill"
+rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR}

 nohup python -m fastdeploy.entrypoints.openai.api_server \
     --model ${MODEL_NAME} \
@@ -76,8 +76,8 @@ wait_for_health ${P_PORT}

 # start decode
 export CUDA_VISIBLE_DEVICES=1
-export FD_LOG_DIR="log_decode"
-mkdir -p ${FD_LOG_DIR}
+export FD_LOG_DIR="log/$LOG_DATE/decode"
+rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR}

 nohup python -m fastdeploy.entrypoints.openai.api_server \
     --model ${MODEL_NAME} \
@@ -9,29 +9,19 @@ set -e
 MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle"
 DATA_PARALLEL_SIZE=2
 TENSOR_PARALLEL_SIZE=1
 NUM_GPUS=$(($DATA_PARALLEL_SIZE * $TENSOR_PARALLEL_SIZE))
+LOG_DATE=$(date +%Y%m%d_%H%M%S)

 export FD_DEBUG=1
 export ENABLE_V1_KVCACHE_SCHEDULER=1
 export KVCACHE_GDRCOPY_FLUSH_ENABLE=1
-export FD_ENABLE_MULTI_API_SERVER=1

 SCRIPT_PATH=$(readlink -f "$0")
 SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
-export $(bash ${SCRIPT_DIR}/../../scripts/get_rdma_nics.sh gpu)
-echo "KVCACHE_RDMA_NICS:${KVCACHE_RDMA_NICS}"
-if [ -z "${KVCACHE_RDMA_NICS}" ]; then
-    echo "KVCACHE_RDMA_NICS is empty, please check the output of get_rdma_nics.sh"
-    exit 1
-fi
-source ${SCRIPT_DIR}/utils.sh

 unset http_proxy && unset https_proxy
+source ${SCRIPT_DIR}/utils.sh

 # start router
 ROUTER_PORT=$(get_free_ports 1)
 echo "---------------------------"
 echo ROUTER_PORT: $ROUTER_PORT

 export FD_LOG_DIR="log/$LOG_DATE/router"
@@ -47,18 +37,7 @@ sleep 1

 # start prefill
 P_SERVER_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
 P_METRICS_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
-P_ENGINE_WORKER_QUEUE_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
-P_CACHE_QUEUE_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
-P_RDMA_COMM_PORTS=$(get_free_ports $NUM_GPUS)
-P_PD_COMM_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
-echo "---------------------------"
-echo P_SERVER_PORTS: $P_SERVER_PORTS
-echo P_METRICS_PORTS: $P_METRICS_PORTS
-echo P_ENGINE_WORKER_QUEUE_PORTS: $P_ENGINE_WORKER_QUEUE_PORTS
-echo P_CACHE_QUEUE_PORTS: $P_CACHE_QUEUE_PORTS
-echo P_RDMA_COMM_PORTS: $P_RDMA_COMM_PORTS
-echo P_PD_COMM_PORTS: $P_PD_COMM_PORTS

 export CUDA_VISIBLE_DEVICES="0,1"
 export FD_LOG_DIR="log/$LOG_DATE/prefill"
@@ -68,40 +47,22 @@ mkdir -p ${FD_LOG_DIR}
 nohup python -m fastdeploy.entrypoints.openai.multi_api_server \
     --num-servers ${DATA_PARALLEL_SIZE}\
     --ports ${P_SERVER_PORTS} \
     --metrics-port ${P_METRICS_PORTS} \
     --args --model ${MODEL_NAME} \
-    --engine-worker-queue-port ${P_ENGINE_WORKER_QUEUE_PORTS} \
-    --cache-queue-port ${P_CACHE_QUEUE_PORTS} \
     --max-model-len 32768 \
     --data-parallel-size ${DATA_PARALLEL_SIZE} \
     --tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \
     --splitwise-role "prefill" \
     --cache-transfer-protocol "rdma" \
-    --rdma-comm-ports ${P_RDMA_COMM_PORTS} \
-    --pd-comm-port ${P_PD_COMM_PORTS} \
     --router "0.0.0.0:${ROUTER_PORT}" \
     2>&1 >${FD_LOG_DIR}/nohup &

 echo "--- Health Check Status ---"
 wait_for_health ${P_SERVER_PORTS}


 # start decode
 D_SERVER_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
-D_ENGINE_WORKER_QUEUE_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
-D_CACHE_QUEUE_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
 D_METRICS_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
-D_RDMA_COMM_PORTS=$(get_free_ports $NUM_GPUS)
-D_PD_COMM_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
-echo "---------------------------"
-echo D_SERVER_PORTS: $D_SERVER_PORTS
-echo D_ENGINE_WORKER_QUEUE_PORTS: $D_ENGINE_WORKER_QUEUE_PORTS
-echo D_CACHE_QUEUE_PORTS: $D_CACHE_QUEUE_PORTS
-echo D_METRICS_PORTS: $D_METRICS_PORTS
-echo D_RDMA_COMM_PORTS: $D_RDMA_COMM_PORTS
-echo D_PD_COMM_PORTS: $D_PD_COMM_PORTS

-export CUDA_VISIBLE_DEVICES="2,3"
+export CUDA_VISIBLE_DEVICES="4,5"
 export FD_LOG_DIR="log/$LOG_DATE/decode"
 rm -rf $FD_LOG_DIR
 mkdir -p ${FD_LOG_DIR}
@@ -109,26 +70,18 @@ mkdir -p ${FD_LOG_DIR}
 nohup python -m fastdeploy.entrypoints.openai.multi_api_server \
     --num-servers ${DATA_PARALLEL_SIZE}\
     --ports ${D_SERVER_PORTS} \
     --metrics-port ${D_METRICS_PORTS} \
     --args --model ${MODEL_NAME} \
-    --engine-worker-queue-port ${D_ENGINE_WORKER_QUEUE_PORTS} \
-    --cache-queue-port ${D_CACHE_QUEUE_PORTS} \
     --max-model-len 32768 \
     --data-parallel-size ${DATA_PARALLEL_SIZE} \
     --tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \
     --splitwise-role "decode" \
     --cache-transfer-protocol "rdma" \
-    --rdma-comm-ports ${D_RDMA_COMM_PORTS} \
-    --pd-comm-port ${D_PD_COMM_PORTS} \
     --router "0.0.0.0:${ROUTER_PORT}" \
     2>&1 >${FD_LOG_DIR}/nohup &

 echo "--- Health Check Status ---"
 wait_for_health ${D_SERVER_PORTS}


 # send request
 echo "------ Request Check ------"
 sleep 10 # make sure server is registered to router
 curl -X POST "http://0.0.0.0:${ROUTER_PORT}/v1/chat/completions" \
     -H "Content-Type: application/json" \
@@ -9,39 +9,27 @@ set -e
 # prepare environment
 export MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle"
 export FD_DEBUG=1
 export ENABLE_V1_KVCACHE_SCHEDULER=1
 export KVCACHE_GDRCOPY_FLUSH_ENABLE=1

 SCRIPT_PATH=$(readlink -f "$0")
 SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
-export $(bash ${SCRIPT_DIR}/../../scripts/get_rdma_nics.sh gpu)
-echo "KVCACHE_RDMA_NICS:${KVCACHE_RDMA_NICS}"
-if [ -z "${KVCACHE_RDMA_NICS}" ]; then
-    echo "KVCACHE_RDMA_NICS is empty, please check the output of get_rdma_nics.sh"
-    exit 1
-fi
+source ${SCRIPT_DIR}/utils.sh

 unset http_proxy && unset https_proxy
-rm -rf log_*
-source ./utils.sh

 P_PORT=52400
 D_PORT=52500
 ROUTER_PORT=52700
+LOG_DATE=$(date +%Y%m%d_%H%M%S)

-ports=(
-    $P_PORT $((P_PORT + 1)) $((P_PORT + 2)) $((P_PORT + 3)) $((P_PORT + 4)) $((P_PORT + 5))
-    $D_PORT $((D_PORT + 1)) $((D_PORT + 2)) $((D_PORT + 3)) $((D_PORT + 4)) $((D_PORT + 5))
-    $ROUTER_PORT
-)
+ports=($P_PORT $D_PORT $ROUTER_PORT)
 check_ports "${ports[@]}" || {
     echo "❌ Some ports are in use. Please release them."
     exit 1
 }

 # start router
-export FD_LOG_DIR="log_router"
-mkdir -p ${FD_LOG_DIR}
+export FD_LOG_DIR="log/$LOG_DATE/router"
+rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR}

 nohup python -m fastdeploy.router.launch \
     --port ${ROUTER_PORT} \
@@ -50,43 +38,29 @@ nohup python -m fastdeploy.router.launch \

 # start prefill
 export CUDA_VISIBLE_DEVICES=0
-export FD_LOG_DIR="log_prefill"
-mkdir -p ${FD_LOG_DIR}
+export FD_LOG_DIR="log/$LOG_DATE/prefill"
+rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR}

 nohup python -m fastdeploy.entrypoints.openai.api_server \
-    --model ${MODEL_NAME} \
-    --port "${P_PORT}" \
-    --metrics-port "$((P_PORT + 1))" \
-    --engine-worker-queue-port "$((P_PORT + 2))" \
-    --cache-queue-port "$((P_PORT + 3))" \
-    --max-model-len 32768 \
-    --splitwise-role "prefill" \
-    --cache-transfer-protocol "rdma" \
-    --rdma-comm-ports "$((P_PORT + 4))" \
-    --pd-comm-port "$((P_PORT + 5))" \
-    --router "0.0.0.0:${ROUTER_PORT}" \
-    2>&1 >${FD_LOG_DIR}/nohup &
+    --model ${MODEL_NAME} \
+    --port "${P_PORT}" \
+    --splitwise-role "prefill" \
+    --router "0.0.0.0:${ROUTER_PORT}" \
+    2>&1 >${FD_LOG_DIR}/nohup &

 wait_for_health ${P_PORT}

 # start decode
 export CUDA_VISIBLE_DEVICES=1
-export FD_LOG_DIR="log_decode"
-mkdir -p ${FD_LOG_DIR}
+export FD_LOG_DIR="log/$LOG_DATE/decode"
+rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR}

 nohup python -m fastdeploy.entrypoints.openai.api_server \
-    --model ${MODEL_NAME} \
-    --port "${D_PORT}" \
-    --metrics-port "$((D_PORT + 2))" \
-    --engine-worker-queue-port "$((D_PORT + 3))" \
-    --cache-queue-port "$((D_PORT + 1))" \
-    --max-model-len 32768 \
-    --splitwise-role "decode" \
-    --cache-transfer-protocol "rdma" \
-    --rdma-comm-ports "$((D_PORT + 4))" \
-    --pd-comm-port "$((D_PORT + 5))" \
-    --router "0.0.0.0:${ROUTER_PORT}" \
-    2>&1 >${FD_LOG_DIR}/nohup &
+    --model ${MODEL_NAME} \
+    --port "${D_PORT}" \
+    --splitwise-role "decode" \
+    --router "0.0.0.0:${ROUTER_PORT}" \
+    2>&1 >${FD_LOG_DIR}/nohup &

 wait_for_health ${D_PORT}
@@ -0,0 +1,82 @@
#!/bin/bash
set -e

# Test splitwise deployment
# There are two methods for splitwise deployment:
#   v0: using splitwise_scheduler or dp_scheduler
#   v1: using local_scheduler + router

# prepare environment
export MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle"
export FD_DEBUG=1
export ENABLE_V1_KVCACHE_SCHEDULER=1
export KVCACHE_GDRCOPY_FLUSH_ENABLE=1

SCRIPT_PATH=$(readlink -f "$0")
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
source ${SCRIPT_DIR}/utils.sh

unset http_proxy && unset https_proxy

P_PORT=52400
D_PORT=52500
ROUTER_PORT=52700
LOG_DATE=$(date +%Y%m%d_%H%M%S)

ports=($P_PORT $D_PORT $ROUTER_PORT)
check_ports "${ports[@]}" || {
    echo "❌ Some ports are in use. Please release them."
    exit 1
}

# start router
export FD_LOG_DIR="log/$LOG_DATE/router"
rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR}

nohup python -m fastdeploy.router.launch \
    --port ${ROUTER_PORT} \
    --splitwise \
    2>&1 >${FD_LOG_DIR}/nohup &

# start prefill
export CUDA_VISIBLE_DEVICES=0,1
export FD_LOG_DIR="log/$LOG_DATE/prefill"
rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR}

nohup python -m fastdeploy.entrypoints.openai.api_server \
    --model ${MODEL_NAME} \
    --port "${P_PORT}" \
    --tensor-parallel-size 2 \
    --splitwise-role "prefill" \
    --router "0.0.0.0:${ROUTER_PORT}" \
    2>&1 >${FD_LOG_DIR}/nohup &

wait_for_health ${P_PORT}

# start decode
export CUDA_VISIBLE_DEVICES=2,3
export FD_LOG_DIR="log/$LOG_DATE/decode"
rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR}

nohup python -m fastdeploy.entrypoints.openai.api_server \
    --model ${MODEL_NAME} \
    --port "${D_PORT}" \
    --tensor-parallel-size 2 \
    --splitwise-role "decode" \
    --router "0.0.0.0:${ROUTER_PORT}" \
    2>&1 >${FD_LOG_DIR}/nohup &

wait_for_health ${D_PORT}

# send request
sleep 10 # make sure server is registered to router
echo "send request..."
curl -X POST "http://0.0.0.0:${ROUTER_PORT}/v1/chat/completions" \
    -H "Content-Type: application/json" \
    -d '{
        "messages": [
            {"role": "user", "content": "hello"}
        ],
        "max_tokens": 100,
        "stream": false
    }'
@@ -2,7 +2,7 @@

 is_port_free() {
     local port=$1
-    if ss -ltn | awk '{print $4}' | grep -q ":${port}$"; then
+    if ss -ltun | awk '{print $4}' | grep -q ":${port}$"; then
         return 1 # Port is occupied
     fi
     return 0 # Port is free
@@ -28,6 +28,7 @@ wait_for_health() {
     local NC='\033[0m' # No Color
     local start_time=$(date +%s)

+    echo "-------- WAIT FOR HEALTH --------"
     while true; do
         local all_ready=true
         for port in "${server_ports[@]}"; do
@@ -44,11 +45,12 @@
             echo "All services are ready! [$((cur_time-start_time))s]"
             break
         else
-            echo "Waiting for services... [$((cur_time-start_time))s]"
+            echo "Services not ready.. [$((cur_time-start_time))s]"
+            printf "\033[%dA" "$total_lines" # roll back cursor
             sleep 1
         fi
     done
     echo "---------------------------------"
 }

 get_free_ports() {
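Editor's note: `is_port_free` now inspects UDP listeners as well as TCP (`ss -ltun` instead of `ss -ltn`). On the Python side the diff leans on `find_free_ports` and `is_port_available` from fastdeploy.utils; a plausible minimal sketch of the assumed implementations, not the project's actual code:

    # Sketch only: assumed behavior of fastdeploy.utils helpers.
    import socket

    def is_port_available(host: str, port: int) -> bool:
        """True if a TCP socket can bind to (host, port)."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                s.bind((host, port))
                return True
            except OSError:
                return False

    def find_free_ports(num_ports: int) -> list:
        """Let the kernel pick free ports by binding to port 0."""
        socks, ports = [], []
        for _ in range(num_ports):
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.bind(("", 0))  # port 0: the kernel assigns a free port
            socks.append(s)  # hold sockets open so the ports stay distinct
            ports.append(s.getsockname()[1])
        for s in socks:
            s.close()
        return ports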
@@ -62,7 +62,7 @@ def parse_args():
     parser.add_argument("--value_cache_shape", type=str, default="", help="value cache shape")
     parser.add_argument("--rdma_port", type=str, default="", help="rdma port")
     parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel, i.e. tp_size, tp_num")
-    parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
+    parser.add_argument("--ipc_suffix", type=str, default=None, help="ipc suffix")
     parser.add_argument(
         "--protocol",
         type=str,
@@ -945,7 +945,7 @@ def main():
         name="cache_ready_signal",
         array=cache_ready_signal_data,
         dtype=np.int32,
-        suffix=args.engine_pid,
+        suffix=args.ipc_suffix,
         create=False,
     )
     cache_ready_signal.value[rank] = 1
@@ -85,7 +85,7 @@ def parse_args():
         help="engine worker queue port",
     )
     parser.add_argument("--num_cpu_blocks", type=int, default=4, help="cpu cache block number")
-    parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
+    parser.add_argument("--ipc_suffix", type=str, default=None, help="ipc suffix")
     parser.add_argument(
         "--protocol",
         type=str,
@@ -140,7 +140,7 @@ class CacheTransferManager:
         self.n_ranks = args.mp_num
         self.rank = rank
         self.device = device
-        self.engine_pid = args.engine_pid
+        self.ipc_suffix = args.ipc_suffix
         self.cache_dtype = args.cache_dtype

         address = (args.pod_ip, args.cache_queue_port)
@@ -157,7 +157,7 @@ class CacheTransferManager:
             name="cache_ready_signal",
             array=cache_ready_signal_data,
             dtype=np.int32,
-            suffix=self.engine_pid,
+            suffix=self.ipc_suffix,
             create=False,
         )
         swap_space_ready_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
@@ -165,7 +165,7 @@ class CacheTransferManager:
             name="swap_space_ready_signal",
             array=swap_space_ready_data,
             dtype=np.int32,
-            suffix=self.engine_pid,
+            suffix=self.ipc_suffix,
             create=False,
         )

@@ -180,7 +180,7 @@ class CacheTransferManager:
             name="cache_task_broadcast_signal",
             array=cache_task_broadcast_data,
             dtype=np.int32,
-            suffix=args.engine_pid,
+            suffix=args.ipc_suffix,
             create=False,
         )

@@ -653,7 +653,7 @@ class CacheTransferManager:
             name="kv_cache_status",
             array=kv_cache_status,
             dtype=np.int32,
-            suffix=self.engine_pid,
+            suffix=self.ipc_suffix,
             create=False,
         )
         while True:
@@ -167,7 +167,7 @@ class PrefixCacheManager:
         device_ids,
         pod_ip,
         engine_worker_queue_port,
-        pid_suffix,
+        ipc_suffix,
         create_cache_tensor,
     ):
         """
@@ -184,7 +184,7 @@ class PrefixCacheManager:
         )

         self.cache_task_queue = EngineCacheQueue(
-            address=(pod_ip, cache_config.cache_queue_port),
+            address=(pod_ip, cache_config.local_cache_queue_port),
             authkey=b"cache_queue_service",
             is_server=False,
             num_client=tensor_parallel_size,
@@ -210,7 +210,7 @@ class PrefixCacheManager:
             val_cache_shape,
             pod_ip,
             engine_worker_queue_port,
-            pid_suffix,
+            ipc_suffix,
         )
         if cache_messager_processes is None:
             raise RuntimeError("Launch cache messager failed")
@@ -269,15 +269,15 @@ class PrefixCacheManager:
             + f" --cache_dtype {cache_config.cache_dtype}"
             + f" --key_cache_shape {key_cache_shape}"
             + val_cache_arg_str
-            + f" --cache_queue_port {cache_config.cache_queue_port}"
+            + f" --cache_queue_port {cache_config.local_cache_queue_port}"
             + f" --enable_splitwise {int(self.enable_splitwise)}"
             + f" --pod_ip {pod_ip}"
             + f" --engine_worker_queue_port {engine_worker_queue_port}"
             + f" --num_cpu_blocks {cache_config.num_cpu_blocks}"
-            + f" --engine_pid {pid_suffix}"
+            + f" --ipc_suffix {ipc_suffix}"
             + f" --protocol {cache_config.cache_transfer_protocol}"
             + f" --local_data_parallel_id {self.local_data_parallel_id}"
-            + f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
+            + f" --rdma_port {cache_config.local_rdma_comm_ports[i] if cache_config.local_rdma_comm_ports is not None else '0'}"
             + f" --speculative_config '{self.speculative_config.to_json_string()}'"
             + f" --default_dtype '{self.config.model_config.dtype}'"
             + (" --create_cache_tensor" if create_cache_tensor else "")
@@ -321,7 +321,7 @@ class PrefixCacheManager:
         value_cache_shape,
         pod_ip,
         engine_worker_queue_port,
-        pid_suffix,
+        ipc_suffix,
     ):
         """
         launch_cache_messager function used to initialize the cache messager.
@@ -334,7 +334,7 @@ class PrefixCacheManager:
             name="cache_ready_signal",
             array=cache_ready_signal_data,
             dtype=np.int32,
-            suffix=pid_suffix,
+            suffix=ipc_suffix,
             create=False,
         )

@@ -366,12 +366,12 @@ class PrefixCacheManager:
             + f" --key_cache_shape {key_cache_shape}"
             + val_cache_arg_str
             + f" --pod_ip {pod_ip}"
-            + f" --cache_queue_port {cache_config.cache_queue_port}"
+            + f" --cache_queue_port {cache_config.local_cache_queue_port}"
             + f" --engine_worker_queue_port {engine_worker_queue_port}"
             + f" --protocol {cache_config.cache_transfer_protocol}"
             + f" --local_data_parallel_id {self.local_data_parallel_id}"
-            + f" --engine_pid {pid_suffix}"
-            + f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
+            + f" --ipc_suffix {ipc_suffix}"
+            + f" --rdma_port {cache_config.local_rdma_comm_ports[i] if cache_config.local_rdma_comm_ports is not None else '0'}"
             + f" --speculative_config '{self.speculative_config.to_json_string()}'"
             + f" >{log_dir}/launch_cache_messager_tprank{i}.log 2>&1"
         )
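Editor's note: the rename from `engine_pid`/`pid_suffix` to `ipc_suffix` reflects that the value merely suffixes the names of shared IPC objects and need not be a process id. An assumed naming convention, for illustration only (IPCSignal internals may differ):

    # Assumed scheme: shared objects are keyed by signal name plus suffix.
    def ipc_object_name(name: str, suffix: str) -> str:
        return f"{name}_{suffix}"

    print(ipc_object_name("cache_ready_signal", "8021"))  # cache_ready_signal_8021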
@@ -0,0 +1,225 @@
#!/bin/bash
Cur_Dir=$(cd `dirname $0`; pwd)
NICNAME_TYPE=xgbe # default NIC name prefix to probe
type=$1

if [ "$ENABLE_EP_DP" == "1" ]; then
    gpu_root_port_filename="${Cur_Dir}/gpu_rootport_${DP_RANK}.txt"
else
    gpu_root_port_filename="${Cur_Dir}/gpu_rootport.txt"
fi

function __NEW_GPU_ROOTPORT_FILE__() {
    touch ${gpu_root_port_filename} 2>/dev/null
    echo "" > ${gpu_root_port_filename} 2>/dev/null
    for gpu_bus in $(lspci 2>/dev/null | grep -iE "Communication controller: | controller: NVIDIA" | awk '{print $1}')
    do
        readlink "/sys/bus/pci/devices/0000:${gpu_bus}" 2>/dev/null | awk -F [/] '{print $6}' >> ${gpu_root_port_filename}
    done
}

function __RM_GPU_ROOTPORT_FILE__() {
    rm -rf ${gpu_root_port_filename} 2>/dev/null
}

function __JUDGE_NIC_TYPE__() {
    XGBE_NUM=$(ip a 2>/dev/null | grep -c ": ${NICNAME_TYPE}")
    gpu_first=true
    xpu_first=true
    cpu_first=true
    cpu_ib_first=true

    for (( xgbe_no=0; xgbe_no < XGBE_NUM; xgbe_no++ ))
    do
        [ ! -d "/sys/class/net/${NICNAME_TYPE}${xgbe_no}" ] && continue

        PCI_ADDRESS=$(ethtool -i "${NICNAME_TYPE}${xgbe_no}" 2>/dev/null | awk -F '0000:' '/bus-info/{print $2}')
        [ -z "$PCI_ADDRESS" ] && continue
        NIC_ROOTPORT=$(readlink "/sys/bus/pci/devices/0000:${PCI_ADDRESS}" 2>/dev/null | awk -F '/' '{print $6}')

        NIC_TYPE="CPU_NIC"
        grep -qxF "$NIC_ROOTPORT" ${gpu_root_port_filename} 2>/dev/null && NIC_TYPE="GPU_NIC"

        if [[ "$type" == "gpu" && "$NIC_TYPE" == "GPU_NIC" ]]; then
            ibdev=$(ibdev2netdev 2>/dev/null | awk -v nic="${NICNAME_TYPE}${xgbe_no}" '$5 == nic {print $1}')
            if [ -n "$ibdev" ] && ip link show "${NICNAME_TYPE}${xgbe_no}" | grep -q "state UP"; then
                if $gpu_first; then
                    printf "KVCACHE_RDMA_NICS=%s" "$ibdev"
                    gpu_first=false
                else
                    printf ",%s" "$ibdev"
                fi
            fi
        fi

        if [[ "$type" == "xpu" && "$NIC_TYPE" == "GPU_NIC" ]]; then
            ibdev=$(ibdev2netdev 2>/dev/null | awk -v nic="${NICNAME_TYPE}${xgbe_no}" '$5 == nic {print $1}')
            if [ -n "$ibdev" ] && ip link show "${NICNAME_TYPE}${xgbe_no}" | grep -q "state UP"; then
                if $xpu_first; then
                    printf "KVCACHE_RDMA_NICS=%s,%s" "$ibdev" "$ibdev"
                    xpu_first=false
                else
                    printf ",%s,%s" "$ibdev" "$ibdev"
                fi
            fi
        fi

        if [[ "$type" == "cpu" ]]; then
            for (( xgbe_no=0; xgbe_no < XGBE_NUM; xgbe_no++ ))
            do
                [ ! -d "/sys/class/net/${NICNAME_TYPE}${xgbe_no}" ] && continue

                PCI_ADDRESS=$(ethtool -i "${NICNAME_TYPE}${xgbe_no}" 2>/dev/null | awk -F '0000:' '/bus-info/{print $2}')
                [ -z "$PCI_ADDRESS" ] && continue

                NIC_ROOTPORT=$(readlink "/sys/bus/pci/devices/0000:${PCI_ADDRESS}" 2>/dev/null | awk -F '/' '{print $6}')
                grep -qxF "$NIC_ROOTPORT" ${gpu_root_port_filename} 2>/dev/null && continue

                if ip link show "${NICNAME_TYPE}${xgbe_no}" | grep -q "state UP" && \
                   ip a show "${NICNAME_TYPE}${xgbe_no}" | grep -q "inet"; then
                    printf "KV_CACHE_SOCKET_IFNAME=%s\n" "${NICNAME_TYPE}${xgbe_no}"
                    return 0
                fi
            done
            echo "ERROR: No active CPU NIC with IP found!" >&2
            return 1
        fi

        if [[ "$type" == "cpu_ib" && "$NIC_TYPE" == "CPU_NIC" ]]; then
            ibdev=$(ibdev2netdev 2>/dev/null | awk -v nic="${NICNAME_TYPE}${xgbe_no}" '$5 == nic {print $1}')
            if [ -n "$ibdev" ] && ip link show "${NICNAME_TYPE}${xgbe_no}" | grep -q "state UP" && \
               ip a show "${NICNAME_TYPE}${xgbe_no}" | grep -q "inet "; then
                if $cpu_ib_first; then
                    printf "KVCACHE_RDMA_NICS=%s" "$ibdev"
                    cpu_ib_first=false
                else
                    printf ",%s" "$ibdev"
                fi
            fi
        fi

    done

    case "$type" in
        gpu) ! $gpu_first && printf "\n" ;;
        xpu) ! $xpu_first && printf "\n" ;;
        cpu) ! $cpu_first && printf "\n" ;;
        cpu_ib) ! $cpu_ib_first && printf "\n" ;;
    esac
}

function get_vxpu_nics() {
    local topo_output=$(xpu-smi topo -m)
    local xpu_info=$(echo "$topo_output" | grep -E '^XPU[0-9]+')

    local nic_mapping=()
    while IFS= read -r line; do
        if [[ $line =~ NIC([0-9]+):\ +(mlx[0-9_]+) ]]; then
            local nic_idx=${BASH_REMATCH[1]}
            local nic_name=${BASH_REMATCH[2]}
            nic_mapping[$nic_idx]=$nic_name
        fi
    done < <(echo "$topo_output" | grep -E '^\s*NIC[0-9]+:')

    local nic_count=${#nic_mapping[@]}

    declare -A priority_map=([PIX]=2 [NODE]=1 [SYS]=0)
    local optimal_nics=()

    while IFS= read -r line; do
        local fields=($line)
        local nic_start_index=5
        local max_nics=$(( ${#fields[@]} - nic_start_index ))
        local actual_nic_count=$(( max_nics < nic_count ? max_nics : nic_count ))

        local best_priority=-1
        local best_nic=""

        for ((nic_idx=0; nic_idx<actual_nic_count; nic_idx++)); do
            local conn_type=${fields[nic_idx+nic_start_index]}
            local current_priority=${priority_map[$conn_type]:--1}

            if (( current_priority > best_priority )); then
                best_priority=$current_priority
                best_nic="${nic_mapping[$nic_idx]}"
            fi
        done

        if [[ -n "$best_nic" ]]; then
            optimal_nics+=("$best_nic")
        fi
    done <<< "$xpu_info"

    local IFS=,
    export KVCACHE_RDMA_NICS="${optimal_nics[*]}"
    echo "KVCACHE_RDMA_NICS=${optimal_nics[*]}"
}

function get_vcpu_nics() {
    ip -o addr show | awk '$3 == "inet" && $4 ~ /^10\./ {print "KV_CACHE_SOCKET_IFNAME="$2; exit}'
}

function __main__() {
    if [[ "$type" == "vxpu" ]]; then
        get_vxpu_nics
        return 0
    fi
    if [[ "$type" == "vcpu" ]]; then
        get_vcpu_nics
        return 0
    fi

    # handle bonded interfaces
    if [[ "$type" == "cpu" ]]; then
        for bond in $(ls -d /sys/class/net/bond* 2>/dev/null); do
            bond_if=$(basename "$bond")
            if ip link show "$bond_if" | grep -q "state UP" && \
               ip a show "$bond_if" | grep -q "inet "; then
                printf "KV_CACHE_SOCKET_IFNAME=%s\n" "$bond_if"
                return 0
            fi
        done
    fi

    if [[ "$type" == "cpu_ib" ]]; then
        first=true
        for bond in $(ls -d /sys/class/net/bond* 2>/dev/null); do
            bond_if=$(basename "$bond")
            __NEW_GPU_ROOTPORT_FILE__

            ibdev=$(ibdev2netdev 2>/dev/null | grep -w "$bond_if" | awk '{print $1}')
            if [ -n "$ibdev" ] && ip link show "$bond_if" | grep -q "state UP" && \
               ip a show "$bond_if" | grep -q "inet "; then
                if $first; then
                    printf "KVCACHE_RDMA_NICS=%s" "$ibdev"
                    first=false
                else
                    printf ",%s" "$ibdev"
                fi
            fi

            bondib=$(show_gids 2>/dev/null | grep -w "$bond_if" | awk '{print $1}' | grep "mlx.*bond" | head -1)
            if [ -n "$bondib" ] && ip link show "$bond_if" | grep -q "state UP" && \
               ip a show "$bond_if" | grep -q "inet " && $first; then
                printf "KVCACHE_RDMA_NICS=%s" "$bondib"
                first=false
            fi

            __RM_GPU_ROOTPORT_FILE__
        done

        ! $first && printf "\n"
        ! $first && return 0
    fi

    local nic_types=("eth" "ib" "xgbe")
    for nt in "${nic_types[@]}"; do
        if ip a | grep -iq "$nt"; then
            __NEW_GPU_ROOTPORT_FILE__
            NICNAME_TYPE=$nt
            __JUDGE_NIC_TYPE__
            __RM_GPU_ROOTPORT_FILE__
        fi
    done
}

__main__
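Editor's note: as the scripts earlier in this diff show, the helper is consumed as `export $(bash get_rdma_nics.sh gpu)`. On success it is expected to print exactly one NAME=VALUE line, e.g. `KVCACHE_RDMA_NICS=mlx5_0,mlx5_1` for the gpu/xpu/cpu_ib modes or `KV_CACHE_SOCKET_IFNAME=bond0` for the cpu mode (device names here are illustrative).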
@@ -14,6 +14,8 @@
 # limitations under the License.
 """

+import traceback
+
 from fastdeploy.utils import get_logger

 logger = get_logger("cache_messager", "cache_messager.log")
@@ -37,13 +39,66 @@ class RDMACommManager:
         prefill_tp_size,
         prefill_tp_idx,
     ):
+        try:
+            import importlib
+            import os
+            import subprocess
+
+            from fastdeploy.platforms import current_platform
+
+            if os.getenv("KVCACHE_GDRCOPY_FLUSH_ENABLE", "") == "" and current_platform.is_cuda():
+                command = ["nvidia-smi", "-i", "0", "--query-gpu=compute_cap", "--format=csv,noheader"]
+                result = subprocess.run(
+                    command,
+                    stdout=subprocess.PIPE,
+                    stderr=subprocess.PIPE,
+                    text=True,
+                    check=False,
+                )
+                logger.info(f"nvidia-smi command: {command}")
+                logger.info(f"nvidia-smi output: {result.stdout}")
+                if result.returncode != 0:
+                    raise RuntimeError(f"Failed to get compute capability via nvidia-smi: {result.stderr.strip()}")
+
+                major, minor = result.stdout.strip().split(".")
+                if major == "8":  # for ampere arch
+                    os.environ["KVCACHE_GDRCOPY_FLUSH_ENABLE"] = "1"
+                    logger.info("Setting environment variable: export KVCACHE_GDRCOPY_FLUSH_ENABLE=1")
+
+            if os.getenv("KVCACHE_RDMA_NICS", "") == "" and current_platform.is_cuda():
+                res = importlib.resources.files("fastdeploy.cache_manager.transfer_factory") / "get_rdma_nics.sh"
+                get_rdma_nics = None
+                with importlib.resources.as_file(res) as path:
+                    get_rdma_nics = str(path)
+                nic_type = current_platform.device_name
+                command = ["bash", get_rdma_nics, nic_type]
+                result = subprocess.run(
+                    command,
+                    stdout=subprocess.PIPE,
+                    stderr=subprocess.PIPE,
+                    text=True,
+                    check=False,
+                )
+                logger.info(f"get_rdma_nics command: {command}")
+                logger.info(f"get_rdma_nics output: {result.stdout}")
+                if result.returncode != 0:
+                    raise RuntimeError(f"Failed to execute script `get_rdma_nics.sh`: {result.stderr.strip()}")
+
+                env_name, env_value = result.stdout.strip().split("=")
+                assert env_name == "KVCACHE_RDMA_NICS"
+                os.environ[env_name] = env_value
+                logger.info(f"Setting environment variable: export {env_name}={env_value}")
+
+        except Exception as e:
+            raise RuntimeError(f"Failed to initialize RDMA environment! {e} {traceback.format_exc()}")

         try:
             import rdma_comm
-        except:
+        except ImportError:
             raise RuntimeError(
-                "The installation of the RDMA library failed."
-                "Confirm whether your network card supports RDMA transmission."
+                "The installation of the RDMA library failed. Confirm whether your network card supports RDMA transmission."
             )

         self.messager = rdma_comm.RDMACommunicator(
             splitwise_role,
             gpu_id,
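Editor's note: two pieces of formerly manual setup become automatic here. KVCACHE_GDRCOPY_FLUSH_ENABLE is exported when nvidia-smi reports compute capability 8.x (the inline comment suggests a GDRCopy flush workaround specific to Ampere), and KVCACHE_RDMA_NICS is filled in by running the bundled get_rdma_nics.sh. The `env_name, env_value = result.stdout.strip().split("=")` line and the assert depend on the script's single-line NAME=VALUE output contract noted above.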
@@ -33,7 +33,13 @@
 from fastdeploy.platforms import current_platform
 from fastdeploy.scheduler import SchedulerConfig
 from fastdeploy.transformer_utils.config import get_pooling_config
-from fastdeploy.utils import ceil_div, check_unified_ckpt, get_host_ip, get_logger
+from fastdeploy.utils import (
+    ceil_div,
+    check_unified_ckpt,
+    get_host_ip,
+    get_logger,
+    parse_ports,
+)

 logger = get_logger("config", "config.log")

@@ -559,7 +565,8 @@ class ParallelConfig:

         self.local_data_parallel_id = 0
         # Engine worker queue port
-        self.engine_worker_queue_port: str = "9923"
+        self.engine_worker_queue_port: Union[int, str, list] = None
+        self.local_engine_worker_queue_port: Optional[int] = None
         # cuda visible devices
         self.device_ids: str = "0"
         # First token id
@@ -579,11 +586,9 @@ class ParallelConfig:
         for key, value in args.items():
             if hasattr(self, key):
                 setattr(self, key, value)
-        if isinstance(self.engine_worker_queue_port, str):
-            self.engine_worker_queue_port = [int(port) for port in self.engine_worker_queue_port.split(",")]
-            logger.info(f"engine_worker_queue_port: {self.engine_worker_queue_port}")
-        elif isinstance(self.engine_worker_queue_port, int):
-            self.engine_worker_queue_port = [self.engine_worker_queue_port]
+
+        self.engine_worker_queue_port = parse_ports(self.engine_worker_queue_port)

         # currently, the expert parallel size is equal to the data parallel size
         if self.enable_expert_parallel:
             self.expert_parallel_size = self.data_parallel_size * self.tensor_parallel_size
@@ -1267,11 +1272,14 @@ class CacheConfig:
         self.model_cfg = None
         self.enable_chunked_prefill = False
         self.rdma_comm_ports = None
+        self.local_rdma_comm_ports = None
         self.cache_transfer_protocol = None
         self.pd_comm_port = None
+        self.local_pd_comm_port = None
         self.enable_prefix_caching = False
         self.enable_ssd_cache = False
         self.cache_queue_port = None
+        self.local_cache_queue_port = None
         self.swap_space = None
         self.max_encoder_cache = None
         self.max_processor_cache = None
@@ -1281,11 +1289,9 @@ class CacheConfig:
             if hasattr(self, key):
                 setattr(self, key, value)

-        if self.rdma_comm_ports is not None and isinstance(self.rdma_comm_ports, str):
-            self.rdma_comm_ports = self.rdma_comm_ports.split(",")
-
-        if self.pd_comm_port is not None and isinstance(self.pd_comm_port, str):
-            self.pd_comm_port = [int(port) for port in self.pd_comm_port.split(",")]
+        self.cache_queue_port = parse_ports(self.cache_queue_port)
+        self.rdma_comm_ports = parse_ports(self.rdma_comm_ports)
+        self.pd_comm_port = parse_ports(self.pd_comm_port)

         if self.swap_space is None:
             self.enable_hierarchical_cache = False
@@ -1657,7 +1663,7 @@ class FDConfig:
         if test_mode:
             return
         self.check()
-        self.print()
+        # self.print()  # NOTE: it's better to explicitly call .print() when FDConfig is initialized

     def _disable_sequence_parallel_moe_if_needed(self, mode_name):
         if self.parallel_config.use_sequence_parallel_moe and self.graph_opt_config.use_cudagraph:
@@ -1671,7 +1677,6 @@ class FDConfig:
         """
         calculate some parameters
         """
-        self.local_device_ids = self.parallel_config.device_ids.split(",")[: self.parallel_config.tensor_parallel_size]

         if self.parallel_config.tensor_parallel_size <= self.worker_num_per_node or self.node_rank == 0:
             self.is_master = True
@@ -1776,6 +1781,41 @@ class FDConfig:
         else:
             raise NotImplementedError

+        self.postprocess_devices_and_ports()
+
+    def postprocess_devices_and_ports(self):
+        try:
+            # get devices and ports for current dp
+            self.local_device_ids = self.parallel_config.device_ids.split(",")[
+                self.parallel_config.local_data_parallel_id
+                * self.parallel_config.tensor_parallel_size : (self.parallel_config.local_data_parallel_id + 1)
+                * self.parallel_config.tensor_parallel_size
+            ]
+            self.parallel_config.local_engine_worker_queue_port = self.parallel_config.engine_worker_queue_port[
+                self.parallel_config.local_data_parallel_id
+            ]
+            self.cache_config.local_cache_queue_port = (
+                self.cache_config.cache_queue_port[self.parallel_config.local_data_parallel_id]
+                if self.cache_config.cache_queue_port
+                else None
+            )
+            self.cache_config.local_pd_comm_port = (
+                self.cache_config.pd_comm_port[self.parallel_config.local_data_parallel_id]
+                if self.cache_config.pd_comm_port
+                else None
+            )
+            self.cache_config.local_rdma_comm_ports = (
+                self.cache_config.rdma_comm_ports[
+                    self.parallel_config.local_data_parallel_id
+                    * self.parallel_config.tensor_parallel_size : (self.parallel_config.local_data_parallel_id + 1)
+                    * self.parallel_config.tensor_parallel_size
+                ]
+                if self.cache_config.rdma_comm_ports
+                else None
+            )
+        except Exception as e:
+            logger.error(f"Failed to extract local devices or ports. Servers may not be able to start properly. {e}")

     def check(self):
         """
         check the legality of config
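Editor's note: a worked example of the per-DP slicing above, with illustrative values. With tensor_parallel_size=2, data_parallel_size=2, device_ids="0,1,2,3" and rdma_comm_ports=[7000, 7001, 7002, 7003], the engine with local_data_parallel_id=1 receives:

    # Mirrors the slicing in postprocess_devices_and_ports; numbers illustrative.
    tp, dp_id = 2, 1
    device_ids = "0,1,2,3".split(",")
    rdma_comm_ports = [7000, 7001, 7002, 7003]

    local_device_ids = device_ids[dp_id * tp : (dp_id + 1) * tp]            # ['2', '3']
    local_rdma_comm_ports = rdma_comm_ports[dp_id * tp : (dp_id + 1) * tp]  # [7002, 7003]

Per-DP scalars such as pd_comm_port and cache_queue_port are indexed with dp_id directly.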
@@ -1924,18 +1964,6 @@ class FDConfig:
         elif self.scheduler_config.name == "local" and self.router_config and self.router_config.router:
             self.splitwise_version = "v1"

-        if isinstance(self.parallel_config.engine_worker_queue_port, (int, str)):
-            engine_worker_queue_port = self.parallel_config.engine_worker_queue_port
-        else:
-            engine_worker_queue_port = self.parallel_config.engine_worker_queue_port[
-                self.parallel_config.local_data_parallel_id
-            ]
-        connector_port = (
-            self.cache_config.pd_comm_port[self.parallel_config.local_data_parallel_id]
-            if self.cache_config.pd_comm_port
-            else None
-        )

         # the information for registering this server to router or splitwise_scheduler
         port = self.router_config.api_server_port if self.router_config else None
         transfer_protocol = (
@@ -1945,9 +1973,9 @@ class FDConfig:
             "role": self.scheduler_config.splitwise_role,
             "host_ip": self.host_ip,
             "port": port,
-            "connector_port": connector_port,
-            "rdma_ports": self.cache_config.rdma_comm_ports,
-            "engine_worker_queue_port": engine_worker_queue_port,
+            "connector_port": self.cache_config.local_pd_comm_port,
+            "rdma_ports": self.cache_config.local_rdma_comm_ports,
+            "engine_worker_queue_port": self.parallel_config.local_engine_worker_queue_port,
             "device_ids": self.local_device_ids,
             "transfer_protocol": transfer_protocol,
             "tp_size": self.parallel_config.tensor_parallel_size,
@@ -47,7 +47,9 @@ from fastdeploy.utils import (
     DeprecatedOptionWarning,
     FlexibleArgumentParser,
     console_logger,
+    find_free_ports,
     is_port_available,
+    parse_ports,
     parse_quantization,
 )

@@ -224,7 +226,7 @@ class EngineArgs:
     The amount of CPU memory to offload to.
     """

-    cache_queue_port: str = "0"
+    cache_queue_port: Optional[Union[int, str, list]] = None
     """
     Port for cache queue.
     """
@@ -266,7 +268,7 @@ class EngineArgs:
     # This optimization is enabled by default, and can be disabled by using this flag.
     """

-    engine_worker_queue_port: str = "0"
+    engine_worker_queue_port: Optional[Union[int, str, list]] = None
     """
     Port for worker queue communication.
     """
@@ -301,17 +303,17 @@ class EngineArgs:
     Chunk size of moe input.
     """

-    cache_transfer_protocol: str = "ipc"
+    cache_transfer_protocol: str = "ipc,rdma"
     """
     Protocol to use for cache transfer.
     """

-    pd_comm_port: Optional[List[int]] = None
+    pd_comm_port: Optional[Union[int, str, list]] = None
     """
     Port for splitwise communication.
     """

-    rdma_comm_ports: Optional[List[int]] = None
+    rdma_comm_ports: Optional[Union[int, str, list]] = None
     """
     Ports for rdma communication.
     """
@@ -497,6 +499,11 @@ class EngineArgs:
     Flag to rollout routing replay(r3)
     """

+    skip_port_check: bool = False
+    """
+    Whether to skip the port availability check. Default is False (do not skip).
+    """

    def __post_init__(self):
        """
        Post-initialization processing to set default tokenizer if not provided.
@@ -508,8 +515,6 @@ class EngineArgs:
             self.enable_prefix_caching = False
         if not current_platform.is_cuda() and not current_platform.is_xpu() and not current_platform.is_intel_hpu():
             self.enable_prefix_caching = False
-        # if self.dynamic_load_weight:
-        #     self.enable_prefix_caching = False
         if self.enable_logprob:
             if not current_platform.is_cuda() and not current_platform.is_xpu():
                 raise NotImplementedError("Only CUDA and XPU platforms support logprob.")
@@ -530,33 +535,69 @@ class EngineArgs:
                 f"scheduler, please provide --router argument."
             )

-        if "rdma" in self.cache_transfer_protocol:
-            if self.rdma_comm_ports is None:
-                raise ValueError(
-                    "Please set --rdma_comm_ports argument when using " "rdma cache transfer protocol."
-                )
-            num_nodes = len(self.ips) if self.ips else 1
-            if self.data_parallel_size % num_nodes != 0:
-                raise ValueError(
-                    f"data_parallel_size ({self.data_parallel_size}) must be divisible by "
-                    f"num_nodes ({num_nodes})."
-                )
-            dp_per_node = self.data_parallel_size // num_nodes
-            expected_ports = self.tensor_parallel_size * dp_per_node
-            if len(self.rdma_comm_ports) != expected_ports:
-                raise ValueError(
-                    f"The number of rdma_comm_ports must equal "
-                    f"tensor_parallel_size * (data_parallel_size / num_nodes) = "
-                    f"{self.tensor_parallel_size} * ({self.data_parallel_size} / {num_nodes}) "
-                    f"= {expected_ports}, but got {len(self.rdma_comm_ports)}."
-                )

         if not (current_platform.is_cuda() or current_platform.is_xpu() or current_platform.is_maca()):
             envs.ENABLE_V1_KVCACHE_SCHEDULER = 0

         if "PaddleOCR" in get_model_architecture(self.model, self.model_config_name):
             envs.FD_ENABLE_MAX_PREFILL = 1

+        self.post_init_all_ports()
+
+    def post_init_all_ports(self):
+
+        def post_init_ports(name: str, ports: list, num_total_ports: int):
+            ports = parse_ports(ports)
+            num_cur_dp_ports = num_total_ports
+            if envs.FD_ENABLE_MULTI_API_SERVER:
+                num_cur_dp_ports //= self.data_parallel_size
+            if ports is None:
+                ports = find_free_ports(num_ports=num_cur_dp_ports)
+                console_logger.info(
+                    f"Parameter `{name}` is not specified, found available ports for possible use: {ports}"
+                )
+            else:
+                num_input_ports = len(ports)
+                if num_input_ports != num_total_ports:
+                    ports = find_free_ports(num_ports=num_cur_dp_ports)
+                    console_logger.warn(
+                        f"Parameter `{name}` expects {num_total_ports} ports, but got {num_input_ports}. Ignore them and assign new ones: {ports}"
+                    )
+                else:
+                    console_logger.info(f"Using `{name}`: {ports}")
+
+            if not self.skip_port_check:
+                for port in ports:
+                    assert is_port_available("0.0.0.0", port), f"Parameter `{name}`:{port} is already in use."
+
+            console_logger.debug(f"post init {name}: {ports}")
+            return ports
+
+        num_nodes = len(self.ips) if self.ips else 1
+        if self.data_parallel_size % num_nodes != 0:
+            raise ValueError(
+                f"data_parallel_size ({self.data_parallel_size}) must be divisible by num_nodes ({num_nodes})."
+            )
+        self.engine_worker_queue_port = post_init_ports(
+            "engine_worker_queue_port",
+            self.engine_worker_queue_port,
+            self.data_parallel_size // num_nodes,
+        )
+        self.cache_queue_port = post_init_ports(
+            "cache_queue_port",
+            self.cache_queue_port,
+            self.data_parallel_size // num_nodes,
+        )
+        self.rdma_comm_ports = post_init_ports(
+            "rdma_comm_ports",
+            self.rdma_comm_ports,
+            self.tensor_parallel_size * self.data_parallel_size // num_nodes,
+        )
+        self.pd_comm_port = post_init_ports(
+            "pd_comm_port",
+            self.pd_comm_port,
+            self.data_parallel_size // num_nodes,
+        )
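Editor's note: the expected counts per parameter follow directly from the parallel layout. A worked example with illustrative values:

    # Expected port counts for dp=2, tp=1 on one node; mirrors post_init_all_ports.
    dp, tp, num_nodes = 2, 1, 1
    expected = {
        "engine_worker_queue_port": dp // num_nodes,  # 2 (one per DP rank)
        "cache_queue_port": dp // num_nodes,          # 2
        "rdma_comm_ports": tp * dp // num_nodes,      # 2 (one per worker rank)
        "pd_comm_port": dp // num_nodes,              # 2
    }
    print(expected)

When FD_ENABLE_MULTI_API_SERVER is set, each api server instance only auto-allocates its own share (num_total_ports divided by data_parallel_size), which is why num_cur_dp_ports is derived separately from num_total_ports.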
    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        """
@@ -1166,7 +1207,7 @@ class EngineArgs:
         return parser

     @classmethod
-    def from_cli_args(cls, args: FlexibleArgumentParser) -> "EngineArgs":
+    def from_cli_args(cls, args: FlexibleArgumentParser, skip_port_check=False) -> "EngineArgs":
         """
         Create an instance of EngineArgs from command line arguments.
         """
@@ -1174,7 +1215,7 @@ class EngineArgs:
         for field in dataclass_fields(cls):
             if hasattr(args, field.name):
                 args_dict[field.name] = getattr(args, field.name)
-        return cls(**args_dict)
+        return cls(**args_dict, skip_port_check=skip_port_check)

     def create_speculative_config(self) -> SpeculativeConfig:
         """ """
@@ -1253,7 +1294,7 @@ class EngineArgs:
             routing_replay_args[k] = v
         return RoutingReplayConfig(routing_replay_args)

-    def create_engine_config(self, port_availability_check=True) -> FDConfig:
+    def create_engine_config(self) -> FDConfig:
         """
         Create and return a Config object based on the current settings.
         """
@@ -1282,11 +1323,6 @@ class EngineArgs:
         else:
             self.max_num_batched_tokens = self.max_model_len

-        if isinstance(self.engine_worker_queue_port, int):
-            self.engine_worker_queue_port = str(self.engine_worker_queue_port)
-        if isinstance(self.engine_worker_queue_port, str):
-            self.engine_worker_queue_port = self.engine_worker_queue_port.split(",")

         all_dict = asdict(self)
         all_dict["model_cfg"] = model_cfg
         cache_cfg = CacheConfig(all_dict)
@@ -1302,10 +1338,6 @@ class EngineArgs:
         early_stop_cfg = self.create_early_stop_config()
         early_stop_cfg.update_enable_early_stop(self.enable_early_stop)
         structured_outputs_config: StructuredOutputsConfig = StructuredOutputsConfig(args=all_dict)
-        if port_availability_check:
-            assert is_port_available(
-                "0.0.0.0", int(self.engine_worker_queue_port[parallel_cfg.local_data_parallel_id])
-            ), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use."

         return FDConfig(
             model_config=model_cfg,
@@ -81,13 +81,6 @@ class EngineService:
         """
         self.cfg = cfg
         self.use_async_llm = use_async_llm
-        if cfg.scheduler_config.splitwise_role != "mixed" or cfg.cache_config.enable_prefix_caching:
-            if isinstance(self.cfg.cache_config.cache_queue_port, str):
-                self.cfg.cache_config.cache_queue_port = self.cfg.cache_config.cache_queue_port.split(",")
-            if isinstance(self.cfg.cache_config.cache_queue_port, list):
-                self.cfg.cache_config.cache_queue_port = int(
-                    self.cfg.cache_config.cache_queue_port[self.cfg.parallel_config.local_data_parallel_id]
-                )

         if self.cfg.parallel_config.data_parallel_size > 1:
             self.llm_logger = get_logger(
@@ -120,9 +113,8 @@ class EngineService:

         self.start_worker_queue_service(start_queue)

-        os.environ["INFERENCE_MSG_QUEUE_ID"] = self.cfg.parallel_config.engine_worker_queue_port[
-            self.cfg.parallel_config.local_data_parallel_id
-        ]
+        os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.cfg.parallel_config.local_engine_worker_queue_port)
+        self.llm_logger.info(f"INFERENCE_MSG_QUEUE_ID: {str(self.cfg.parallel_config.local_engine_worker_queue_port)}")

         self.split_connector = SplitwiseConnector(cfg, self.engine_worker_queue, self.resource_manager)
         self.token_processor = TokenProcessor(
@@ -151,9 +143,7 @@ class EngineService:
         self._init_worker_monitor_signals()

         if self.cfg.eplb_config.enable_eplb:
-            current_suffix = int(
-                self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
-            )
+            current_suffix = self.cfg.parallel_config.local_engine_worker_queue_port
             init_eplb_signals(cfg, current_suffix)

         if self.use_async_llm:
@@ -211,7 +201,7 @@ class EngineService:
         def check_worker_initialize_status_func(res: dict):
             res["worker_is_alive"] = True
             if not self.check_worker_initialize_status():
-                llm_logger.error("Failed to launch worker processes, check log/workerlog.* for more details.")
+                self.llm_logger.error("Failed to launch worker processes, check log/workerlog.* for more details.")
                 res["worker_is_alive"] = False

         self.check_worker_initialize_status_func_thread = threading.Thread(
@@ -241,7 +231,7 @@ class EngineService:
         # Worker launched
         self.check_worker_initialize_status_func_thread.join()
         if not result_container["worker_is_alive"]:
-            llm_logger.error("Failed to launch worker processes, check log/workerlog.* for more details.")
+            self.llm_logger.error("Failed to launch worker processes, check log/workerlog.* for more details.")
             return False

         # Start ZMQ service for communication with AsyncLLM
@@ -259,9 +249,7 @@ class EngineService:
         self.data_processor = self.input_processor.create_processor()

     def _init_worker_monitor_signals(self):  # exist_task_signal lets each worker process detect whether new tasks need handling
-        current_suffix = int(
-            self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
-        )
+        current_suffix = self.cfg.parallel_config.local_engine_worker_queue_port
         self.llm_logger.info(f"current_suffix: {current_suffix}")
         exist_task_signal_data = np.zeros([1], dtype=np.int32)
         self.exist_task_signal = IPCSignal(
@@ -353,52 +341,43 @@ class EngineService:
         """
         start queue service for engine worker communication
         """

         if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
-            address = (
-                self.cfg.master_ip,
-                int(
-                    self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
-                ),
-            )
+            address = (self.cfg.master_ip, self.cfg.parallel_config.local_engine_worker_queue_port)
         else:
-            address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]}.sock"
+            address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.local_engine_worker_queue_port}.sock"

-        if start_queue and (self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0"):
-            self.llm_logger.info(f"Starting engine worker queue server service at {address}")
-            self.engine_worker_queue_server = EngineWorkerQueue(
-                address=address,
-                is_server=True,
-                num_client=self.cfg.parallel_config.tensor_parallel_size,
-                local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
-            )
-            # Dynamically updates the port value if an anonymous port is used
-            if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
-                self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id] = (
-                    str(self.engine_worker_queue_server.get_server_port())
-                )
-                address = (
-                    self.cfg.master_ip,
-                    int(
-                        self.cfg.parallel_config.engine_worker_queue_port[
-                            self.cfg.parallel_config.local_data_parallel_id
-                        ]
-                    ),
-                )
+        if self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0":
+            if start_queue:
+                self.llm_logger.info(f"Starting engine worker queue server service at {address}")
+                self.engine_worker_queue_server = EngineWorkerQueue(
+                    address=address,
+                    is_server=True,
+                    num_client=self.cfg.parallel_config.tensor_parallel_size,
+                    local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
+                )
+                # Dynamically updates the port value if an anonymous port is used
+                if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
+                    self.cfg.parallel_config.local_engine_worker_queue_port = (
+                        self.engine_worker_queue_server.get_server_port()
+                    )
+                    address = (
+                        self.cfg.master_ip,
+                        self.cfg.parallel_config.local_engine_worker_queue_port,
+                    )

             if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
                 self.llm_logger.info(
                     f"Starting engine cache queue server service at {self.cfg.cache_config.local_cache_queue_port}"
                 )
                 self.cache_task_queue = EngineCacheQueue(
-                    address=(
-                        self.cfg.master_ip,
-                        self.cfg.cache_config.cache_queue_port,
-                    ),
+                    address=(self.cfg.master_ip, self.cfg.cache_config.local_cache_queue_port),
                     authkey=b"cache_queue_service",
                     is_server=True,
                     num_client=self.cfg.parallel_config.tensor_parallel_size,
                     client_id=-1,
                     local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
                 )
-                self.cfg.cache_config.cache_queue_port = self.cache_task_queue.get_server_port()
+                self.cfg.cache_config.local_cache_queue_port = self.cache_task_queue.get_server_port()

         self.engine_worker_queue = EngineWorkerQueue(
             address=address,
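Editor's note: the task-queue address takes one of two shapes depending on FD_ENGINE_TASK_QUEUE_WITH_SHM, both now derived from the single local port. Illustrative values:

    # Values illustrative; mirrors the address selection above.
    with_shm = False
    master_ip, local_port = "127.0.0.1", 8021

    if not with_shm:
        address = (master_ip, local_port)                      # TCP endpoint tuple
    else:
        address = f"/dev/shm/fd_task_queue_{local_port}.sock"  # unix-domain socket path
    print(address)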
@@ -753,7 +732,7 @@ class EngineService:
                 # so the same request sent by the decode api server will be ignored
                     continue

-            llm_logger.debug(f"get tasks from scheduler: {tasks}")
+            self.llm_logger.debug(f"get tasks from scheduler: {tasks}")
             if self.cfg.scheduler_config.splitwise_role != "mixed":
                 for task in tasks:
                     task.metrics.ask_decode_resource_start_time = time.time()
@@ -965,7 +944,7 @@ class EngineService:
                 get_request_pool.submit(_fetch_request)
             except RuntimeError as e:
                 if "shutdown" in str(e):
-                    llm_logger.info("Thread pool shutdown detected, exiting scheduler loop")
+                    self.llm_logger.info("Thread pool shutdown detected, exiting scheduler loop")
                     break
                 else:
                     raise
@@ -1023,7 +1002,7 @@ class EngineService:
         if error_tasks:
             for request_id, failed in error_tasks:
                 if failed is None:
-                    llm_logger.warning(f"Request {request_id} has no error, skip sending error response.")
+                    self.llm_logger.warning(f"Request {request_id} has no error, skip sending error response.")
                     continue
                 self._send_error_response(request_id, failed)

@@ -1137,7 +1116,7 @@ class EngineService:
         )

     def _send_error_response(self, request_id, error_msg, error_code: int = 500):
-        llm_logger.error(
+        self.llm_logger.error(
             f"Send error response to client, request_id: {request_id}, error_msg: {error_msg}, error_code: {error_code}"
         )
         error_result = RequestOutput(
@@ -1200,7 +1179,7 @@ class EngineService:
                 elif content.finished:
                     new_step_contents.append(content)
                 else:
-                    llm_logger.warning(
+                    self.llm_logger.warning(
                         f"current tokens need to accumulate, req_id: {content.request_id} {content.outputs.token_ids}"
                     )
             else:
@@ -1230,16 +1209,16 @@ class EngineService:
                     elif content.finished:
                         new_contents.append(content)
                     else:
-                        llm_logger.warning(
+                        self.llm_logger.warning(
                             f"current tokens need to accumulate, req_id: {request_id} {content.outputs.token_ids}"
                         )
                 else:
                     new_contents.append(content)
             if len(new_contents):
-                llm_logger.debug(f"Send response for request id: {request_id}")
+                self.llm_logger.debug(f"Send response for request id: {request_id}")
                 self.send_response_server.send_response(request_id, new_contents)
         except Exception as e:
-            llm_logger.error(f"Unexpected error happened: {e}, {traceback.format_exc()!s}")
+            self.llm_logger.error(f"Unexpected error happened: {e}, {traceback.format_exc()!s}")

     def _decode_process_splitwise_requests(self):
         """
@@ -1378,10 +1357,8 @@ class EngineService:
             tensor_parallel_size=self.cfg.parallel_config.tensor_parallel_size,
             device_ids=device_ids,
             pod_ip=self.cfg.master_ip,
-            engine_worker_queue_port=int(
-                self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
-            ),
-            pid_suffix=ipc_signal_suffix,
+            engine_worker_queue_port=self.cfg.parallel_config.local_engine_worker_queue_port,
+            ipc_suffix=ipc_signal_suffix,
             create_cache_tensor=False,
         )

@@ -1390,15 +1367,15 @@ class EngineService:

     def clear_data(self):
         try:
-            llm_logger.info("Clear Data: Start")
+            self.llm_logger.info("Clear Data: Start")
             self.token_processor.clear_data()
             self.engine_worker_queue.clear_data()
             self.send_response_server.req_dict.clear()
             self.recv_request_server.req_dict.clear()
-            llm_logger.info("Clear Data: Successfully")
+            self.llm_logger.info("Clear Data: Successfully")
             return True
         except Exception as e:
-            llm_logger.error(f"Clear data error: {e}")
+            self.llm_logger.error(f"Clear data error: {e}")
             return False

     def _register_to_router(self):
@@ -1424,18 +1401,18 @@ class EngineService:
                 )

                 if resp.ok:
-                    llm_logger.info("Successfully registered to the router!")
+                    self.llm_logger.info("Successfully registered to the router!")
                     break
                 else:
-                    llm_logger.error(
+                    self.llm_logger.error(
                         f"Router registration failed: {resp.status_code}, "
                         f"{resp.text}, {self.cfg.register_info}"
                     )
                     time.sleep(sleep_seconds)
             except requests.exceptions.RequestException as e:
-                llm_logger.error(f"Register to router request error: {e}")
+                self.llm_logger.error(f"Register to router request error: {e}")
             except Exception as e:
-                llm_logger.exception(f"Unexpected error during router registration: {e}")
+                self.llm_logger.exception(f"Unexpected error during router registration: {e}")

         if self.cfg.router_config.router is not None:
             register_thread = threading.Thread(target=_register, daemon=True)
@@ -1445,45 +1422,45 @@ class EngineService:
         """
         exit sub services
         """
-        llm_logger.info("Exit sub services.....")
+        self.llm_logger.info("Exit sub services.....")
         self.running = False

         if self.use_async_llm:
             # Clean up worker processes first (before closing multiprocessing services)
             if hasattr(self, "worker_proc") and self.worker_proc is not None:
-                llm_logger.info("Cleaning up worker processes...")
+                self.llm_logger.info("Cleaning up worker processes...")
                 try:
                     pgid = os.getpgid(self.worker_proc.pid)
                     os.killpg(pgid, signal.SIGTERM)
                 except Exception as e:
-                    llm_logger.error(f"Error extracting sub services: {e}, {str(traceback.format_exc())}")
+                    self.llm_logger.error(f"Error extracting sub services: {e}, {str(traceback.format_exc())}")

             # Clean up cache manager processes
             if hasattr(self, "cache_manager_processes"):
-                llm_logger.info("Cleaning up cache manager processes...")
+                self.llm_logger.info("Cleaning up cache manager processes...")
                 self.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
                 self.resource_manager.cache_manager.cache_ready_signal.clear()
                 for p in self.cache_manager_processes:
-                    llm_logger.info(f"Killing cache manager process {p.pid}")
+                    self.llm_logger.info(f"Killing cache manager process {p.pid}")
                     try:
                         pgid = os.getpgid(p.pid)
                         os.killpg(pgid, signal.SIGTERM)
                     except Exception as e:
-                        llm_logger.error(
+                        self.llm_logger.error(
                             f"Error killing cache manager process {p.pid}: {e}, {str(traceback.format_exc())}"
                         )
f"Error killing cache manager process {p.pid}: {e}, {str(traceback.format_exc())}"
|
||||
)
|
||||
|
||||
if hasattr(self, "cache_task_queue") and self.cache_task_queue is not None:
|
||||
llm_logger.info("Cleaning up cache_task_queue...")
|
||||
self.llm_logger.info("Cleaning up cache_task_queue...")
|
||||
# Check if cleanup method exists
|
||||
if hasattr(self.cache_task_queue, "cleanup"):
|
||||
self.cache_task_queue.cleanup()
|
||||
elif hasattr(self.cache_task_queue, "manager"):
|
||||
try:
|
||||
llm_logger.info("Shutting down cache_task_queue manager...")
|
||||
self.llm_logger.info("Shutting down cache_task_queue manager...")
|
||||
self.cache_task_queue.manager.shutdown()
|
||||
except Exception as e:
|
||||
llm_logger.warning(f"Error shutting down cache_task_queue manager: {e}")
|
||||
self.llm_logger.warning(f"Error shutting down cache_task_queue manager: {e}")
|
||||
|
||||
if hasattr(self, "get_profile_block_num_signal"):
|
||||
self.get_profile_block_num_signal.clear()
|
||||
@@ -1494,7 +1471,7 @@ class EngineService:
|
||||
# Clean up other services
|
||||
if hasattr(self, "dp_processed"):
|
||||
for p in self.dp_processed:
|
||||
llm_logger.info(f"Waiting for worker {p.pid} to exit")
|
||||
self.llm_logger.info(f"Waiting for worker {p.pid} to exit")
|
||||
p.join()
|
||||
for p in self.dp_engine_worker_queue_server:
|
||||
p.cleanup()
|
||||
@@ -1662,13 +1639,13 @@ class EngineService:
|
||||
|
||||
think_end_id = self.data_processor.tokenizer.get_vocab().get("</think>", -1)
|
||||
if think_end_id > 0:
|
||||
llm_logger.info(f"Get think_end_id {think_end_id} from vocab.")
|
||||
self.llm_logger.info(f"Get think_end_id {think_end_id} from vocab.")
|
||||
else:
|
||||
llm_logger.info("No </think> token found in vocabulary, the model can not do reasoning.")
|
||||
self.llm_logger.info("No </think> token found in vocabulary, the model can not do reasoning.")
|
||||
image_patch_id = self.data_processor.tokenizer.get_vocab().get("<|IMAGE_PLACEHOLDER|>", -1)
|
||||
line_break_id = self.data_processor.tokenizer.get_vocab().get("\n", -1)
|
||||
|
||||
ports = ",".join(self.cfg.parallel_config.engine_worker_queue_port)
|
||||
ports = ",".join(map(str, self.cfg.parallel_config.engine_worker_queue_port))
|
||||
ips = None
|
||||
if self.cfg.ips is not None:
|
||||
ips = ",".join(self.cfg.ips)
@@ -1744,7 +1721,7 @@ class EngineService:
if self.cfg.nnode > 1:
pd_cmd = pd_cmd + f" --ips {ips} --nnodes {len(self.cfg.ips)}"
pd_cmd = pd_cmd + arguments + f" 2>{log_dir}/launch_worker.log"
llm_logger.info(f"Launch worker service command: {pd_cmd}")
self.llm_logger.info(f"Launch worker service command: {pd_cmd}")
p = subprocess.Popen(
pd_cmd,
stdout=subprocess.PIPE,
@@ -1820,7 +1797,7 @@ class EngineService:
else:
address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.engine_worker_queue_port[i]}.sock"

llm_logger.info(f"dp start queue service {address}")
self.llm_logger.info(f"dp start queue service {address}")
self.dp_engine_worker_queue_server.append(
EngineWorkerQueue(
address=address,
@@ -1842,7 +1819,7 @@ class EngineService:
),
)
)
llm_logger.info(
self.llm_logger.info(
f"Engine is initialized successfully with {self.cfg.parallel_config.tensor_parallel_size}"
+ f" data parallel id {i}"
)

@@ -16,6 +16,7 @@

from __future__ import annotations

import copy
import json
import multiprocessing
import os
@@ -85,6 +86,7 @@ class LLMEngine:
cfg (Config): Config object containing all the configuration parameters.
"""
self.cfg = cfg
self.cfg.print()
self.running = True
self.is_started = False

@@ -368,7 +370,7 @@ class LLMEngine:
)

# launched_expert_service_signal: Used to sense whether each expert_service is started successfully
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
if self.cfg.parallel_config.data_parallel_size > 1 and not envs.FD_ENABLE_MULTI_API_SERVER:
launched_expert_service_signal_data = np.zeros(
shape=[self.cfg.parallel_config.data_parallel_size // self.cfg.nnode], dtype=np.int32
)
@@ -524,7 +526,7 @@ class LLMEngine:
image_patch_id = self.data_processor.tokenizer.get_vocab().get("<|IMAGE_PLACEHOLDER|>", -1)
line_break_id = self.data_processor.tokenizer.get_vocab().get("\n", -1)

ports = ",".join(self.cfg.parallel_config.engine_worker_queue_port)
ports = ",".join(map(str, self.cfg.parallel_config.engine_worker_queue_port))
ips = None
if self.cfg.ips is not None:
ips = ",".join(self.cfg.ips)
@@ -732,7 +734,7 @@ class LLMEngine:
)

if not envs.FD_ENABLE_MULTI_API_SERVER:
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
if self.cfg.parallel_config.data_parallel_size > 1:
self.launched_expert_service_signal.value[0] = 1
self.dp_processed = []
self.dp_engine_worker_queue_server = []
@@ -757,12 +759,13 @@ class LLMEngine:
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
)
)
ctx = multiprocessing.get_context("spawn")
ctx = multiprocessing.get_context("fork")
cfg = copy.deepcopy(self.cfg)
self.dp_processed.append(
ctx.Process(
target=start_data_parallel_service,
args=(
self.cfg,
cfg,
i,
None,
request_queues_for_dp_ipc,

@@ -27,7 +27,7 @@ import numpy as np

from fastdeploy.engine.common_engine import EngineService
from fastdeploy.inter_communicator import IPCSignal
from fastdeploy.utils import console_logger, envs, llm_logger
from fastdeploy.utils import console_logger, envs, get_logger, llm_logger


class ExpertService:
@@ -48,10 +48,13 @@ class ExpertService:
"""

self.cfg = cfg
start_pos = (local_data_parallel_id * self.cfg.parallel_config.tensor_parallel_size) % cfg.worker_num_per_node
end_pos = start_pos + self.cfg.parallel_config.tensor_parallel_size

if self.cfg.parallel_config.data_parallel_size > 1:
self.llm_logger = get_logger("fastdeploy", f"fastdeploy_dprank{local_data_parallel_id}.log")
else:
self.llm_logger = llm_logger

if cfg.scheduler_config.splitwise_role != "mixed":
self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[start_pos:end_pos]
if envs.FD_ENABLE_INTERNAL_ADAPTER:
envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT = envs.FD_ZMQ_RECV_REQUEST_SERVER_PORTS.split(",")[
local_data_parallel_id
@@ -59,22 +62,25 @@ class ExpertService:
envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT = envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORTS.split(",")[
local_data_parallel_id
]
self.cfg.local_device_ids = self.cfg.parallel_config.device_ids.split(",")[start_pos:end_pos]
llm_logger.info(f"local_data_parallel_id: {local_data_parallel_id}")
self.llm_logger.info(f"local_data_parallel_id: {local_data_parallel_id}")

if self.cfg.cache_config.num_gpu_blocks_override is None:
self.do_profile = True
else:
self.do_profile = False

if cfg.scheduler_config.splitwise_role != "mixed":
if len(self.cfg.cache_config.pd_comm_port) == 1:
self.cfg.cache_config.pd_comm_port[0] = (
int(self.cfg.cache_config.pd_comm_port[0]) + local_data_parallel_id
)
else:
self.cfg.cache_config.pd_comm_port = [self.cfg.cache_config.pd_comm_port[local_data_parallel_id]]
self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
# Update config for the current dp process
if not envs.FD_ENABLE_MULTI_API_SERVER:
self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
self.cfg.postprocess_devices_and_ports()
self.llm_logger.info(
f"Update config for the current dp process: "
f"local_engine_worker_queue_port: {self.cfg.parallel_config.local_engine_worker_queue_port} "
f"local_cache_queue_port: {self.cfg.cache_config.local_cache_queue_port} "
f"local_pd_comm_port: {self.cfg.cache_config.local_pd_comm_port} "
f"local_rdma_comm_ports: {self.cfg.cache_config.local_rdma_comm_ports} "
)

self.engine = EngineService(self.cfg, start_queue)
if self.cfg.scheduler_config.name == "splitwise":
self.engine.scheduler.reset_nodeid(f"{self.engine.scheduler.infer.nodeid}_{local_data_parallel_id!s}")
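
A sketch of the pd_comm_port selection logic above, with hypothetical values: a single configured port fans out per DP rank by offset, while a full list is indexed directly.

    pd_comm_port = [52300]          # single base port configured
    local_data_parallel_id = 3
    if len(pd_comm_port) == 1:
        pd_comm_port = [int(pd_comm_port[0]) + local_data_parallel_id]   # [52303]
    else:
        pd_comm_port = [pd_comm_port[local_data_parallel_id]]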
@@ -107,7 +113,7 @@ class ExpertService:
ipc_signal_suffix = self.cfg.parallel_config.engine_worker_queue_port[0]
self.engine.start_zmq_service(self.cfg.parallel_config.engine_worker_queue_port[local_data_parallel_id])

llm_logger.info(f"start expert service {local_data_parallel_id}")
self.llm_logger.info(f"start expert service {local_data_parallel_id}")

if self.cfg.scheduler_config.name == "splitwise":
self.cfg.init_cache_info()
@@ -125,7 +131,7 @@ class ExpertService:
local_rank = local_data_parallel_id % self.cfg.worker_num_per_node

if not envs.FD_ENABLE_MULTI_API_SERVER:
if self.cfg.parallel_config.enable_expert_parallel:
if self.cfg.parallel_config.data_parallel_size > 1:
launched_expert_service_signal_data = np.zeros(
shape=[self.cfg.parallel_config.data_parallel_size // self.cfg.nnode], dtype=np.int32
)
@@ -137,6 +143,7 @@ class ExpertService:
create=False,
)
self.launched_expert_service_signal.value[local_rank] = 1

if self.do_profile:
get_profile_block_num = np.zeros([1], dtype=np.int32)
while True:
@@ -154,10 +161,9 @@ class ExpertService:
self.reset_kvcache_blocks()

if self.cfg.scheduler_config.splitwise_role != "mixed" or self.cfg.cache_config.enable_prefix_caching:
ipc_signal_suffix_cache = self.cfg.parallel_config.engine_worker_queue_port[local_data_parallel_id]
self.cache_manager_processes = self.engine.start_cache_service(
self.cfg.local_device_ids,
ipc_signal_suffix_cache,
self.cfg.parallel_config.local_engine_worker_queue_port,
)
console_logger.info(
f"Worker processes(rank {local_rank}) are launched with {time.time() - start_time} seconds."
@@ -180,7 +186,7 @@ class ExpertService:
if hasattr(self, "cache_manager_processes"):
self.engine.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear()
for p in self.cache_manager_processes:
llm_logger.info(f"Killing cache manager process {p.pid}")
self.llm_logger.info(f"Killing cache manager process {p.pid}")
try:
os.killpg(p.pid, signal.SIGTERM)
except:

@@ -146,6 +146,7 @@ async def lifespan(app: FastAPI):
"""
async context manager for FastAPI lifespan
"""
global engine_args
import logging

uvicorn_access = logging.getLogger("uvicorn.access")
@@ -172,8 +173,8 @@ async def lifespan(app: FastAPI):
verification = False
model_paths = [ModelPath(name=served_model_names, model_path=args.model, verification=verification)]

engine_args = EngineArgs.from_cli_args(args)
fd_config = engine_args.create_engine_config(port_availability_check=False)
engine_args = EngineArgs.from_cli_args(args, skip_port_check=True)
fd_config = engine_args.create_engine_config()
engine_client = EngineClient(
pid=pid,
port=int(os.environ.get("INFERENCE_MSG_QUEUE_ID", "0")),

@@ -20,37 +20,92 @@ import subprocess
import sys
import time

from fastdeploy.utils import get_logger, is_port_available
from fastdeploy.platforms import current_platform
from fastdeploy.utils import find_free_ports, get_logger, is_port_available

logger = get_logger("multi_api_server", "multi_api_server.log")


def start_servers(server_count, server_args, ports, metrics_ports, controller_ports):
processes = []
logger.info(f"Starting servers on ports: {ports} with args: {server_args} and metrics ports: {metrics_ports}")
for i in range(len(server_args)):
if server_args[i] == "--engine-worker-queue-port":
engine_worker_queue_port = server_args[i + 1].split(",")
break
def start_servers(
server_count=None,
device_count=None,
server_args=None,
ports=None,
metrics_ports=None,
controller_ports=None,
):
ports = ports.split(",")
if not check_param(ports, server_count):
return
if not check_param(metrics_ports, server_count):
return
if not check_param(engine_worker_queue_port, server_count):
return

if metrics_ports != "-1":
metrics_ports = metrics_ports.split(",")
if not check_param(metrics_ports, server_count):
return

if controller_ports != "-1":
controller_ports = controller_ports.split(",")
if not check_param(controller_ports, server_count):
return
else:
controller_ports = [-1] * server_count
# check_param(server_args, server_count)

logger.info(f"Starting servers on ports: {ports} with args: {server_args} and metrics ports: {metrics_ports}")
port_idx = {}
for i in range(len(server_args)):
if server_args[i] == "--engine-worker-queue-port":
port_idx["engine_worker_queue_port"] = i + 1
if server_args[i] == "--cache-queue-port":
port_idx["cache_queue_port"] = i + 1
if server_args[i] == "--pd-comm-port":
port_idx["pd_comm_port"] = i + 1
if server_args[i] == "--rdma-comm-ports":
port_idx["rdma_comm_ports"] = i + 1

if "engine_worker_queue_port" not in port_idx:
port = find_free_ports(num_ports=server_count)
server_args += ["--engine-worker-queue-port", ",".join(map(str, port))]
port_idx["engine_worker_queue_port"] = len(server_args) - 1
logger.info(f"No --engine-worker-queue-port specified, using random ports: {port}")
engine_worker_queue_port = server_args[port_idx["engine_worker_queue_port"]].split(",")
if not check_param(engine_worker_queue_port, server_count):
return

if "cache_queue_port" not in port_idx:
port = find_free_ports(num_ports=server_count)
server_args += ["--cache-queue-port", ",".join(map(str, port))]
port_idx["cache_queue_port"] = len(server_args) - 1
logger.info(f"No --cache-queue-port specified, using random ports: {port}")
cache_queue_port = server_args[port_idx["cache_queue_port"]].split(",")
if not check_param(cache_queue_port, server_count):
return

if "pd_comm_port" not in port_idx:
port = find_free_ports(num_ports=server_count)
server_args += ["--pd-comm-port", ",".join(map(str, port))]
port_idx["pd_comm_port"] = len(server_args) - 1
logger.info(f"No --pd-comm-port specified, using random ports: {port}")
pd_comm_port = server_args[port_idx["pd_comm_port"]].split(",")
if not check_param(pd_comm_port, server_count):
return

if "rdma_comm_ports" not in port_idx:
port = find_free_ports(num_ports=device_count)
server_args += ["--rdma-comm-ports", ",".join(map(str, port))]
port_idx["rdma_comm_ports"] = len(server_args) - 1
logger.info(f"No --rdma-comm-ports specified, using random ports: {port}")
rdma_comm_ports = server_args[port_idx["rdma_comm_ports"]].split(",")
if not check_param(rdma_comm_ports, device_count):
return

logger.info(f"Modified server_args: {server_args}")
processes = []
for i in range(server_count):
port = int(ports[i])
metrics_port = int(metrics_ports[i])
controller_port = int(controller_ports[i])

env = os.environ.copy()
env["FD_ENABLE_MULTI_API_SERVER"] = "1"
env["FD_LOG_DIR"] = env.get("FD_LOG_DIR", "log") + f"/log_{i}"
cmd = [
sys.executable,
@@ -59,13 +114,13 @@ def start_servers(server_count, server_args, ports, metrics_ports, controller_po
*server_args,
"--port",
str(port),
"--metrics-port",
str(metrics_port),
"--controller-port",
str(controller_port),
"--local-data-parallel-id",
str(i),
]
if metrics_ports != "-1":
cmd += ["--metrics-port", metrics_ports[i]]

# Start the subprocess
proc = subprocess.Popen(cmd, env=env)
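
A standalone sketch of the auto-fill pattern above: scan argv for a port flag and append free ports when it is absent. The helper name ensure_port_flag is illustrative, not part of FastDeploy.

    def ensure_port_flag(server_args, flag, free_ports):
        # Append `flag` with comma-joined ports only if it is not already present.
        if flag not in server_args:
            server_args += [flag, ",".join(map(str, free_ports))]
        return server_args

    argv = ["--model", "my_model"]
    ensure_port_flag(argv, "--engine-worker-queue-port", [9000, 9001])
    # argv -> ["--model", "my_model", "--engine-worker-queue-port", "9000,9001"]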
@@ -89,21 +144,25 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument("--ports", default="8000,8002", type=str, help="ports to the http server")
parser.add_argument("--num-servers", default=2, type=int, help="number of workers")
parser.add_argument("--metrics-ports", default="8800,8802", type=str, help="ports for metrics server")
parser.add_argument("--metrics-ports", default="-1", type=str, help="ports for metrics server")
parser.add_argument("--controller-ports", default="-1", type=str, help="ports for controller server port")
parser.add_argument("--args", nargs=argparse.REMAINDER, help="remaining arguments are passed to api_server.py")
args = parser.parse_args()

logger.info(f"Starting {args.num_servers} servers on ports: {args.ports} with args: {args.args}")
# check_param(args.ports, args.num_servers)
# check_param(args.metrics_ports, args.num_servers)
# check_param(args.args.engine_worker_queue_port, args.num_servers)

device_count = 0
if current_platform.is_cuda():
device_count = len(os.getenv("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7").split(","))
elif current_platform.is_xpu():
device_count = len(os.getenv("XPU_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7").split(","))

processes = start_servers(
server_count=args.num_servers,
device_count=device_count,
server_args=args.args,
ports=args.ports.split(","),
metrics_ports=args.metrics_ports.split(","),
ports=args.ports,
metrics_ports=args.metrics_ports,
controller_ports=args.controller_ports,
)
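
A sketch of how argparse.REMAINDER forwards everything after --args to the child api_server processes (flag values below are illustrative):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--num-servers", default=2, type=int)
    parser.add_argument("--args", nargs=argparse.REMAINDER)
    ns = parser.parse_args(["--num-servers", "2", "--args", "--model", "m", "--data-parallel-size", "2"])
    print(ns.args)  # ['--model', 'm', '--data-parallel-size', '2']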

@@ -351,8 +351,8 @@ def create_model_paths(args: Namespace) -> List[ModelPath]:

async def initialize_engine_client(args: Namespace, pid: int) -> EngineClient:
"""Initialize and configure the engine client."""
engine_args = EngineArgs.from_cli_args(args)
fd_config = engine_args.create_engine_config(port_availability_check=False)
engine_args = EngineArgs.from_cli_args(args, skip_port_check=True)
fd_config = engine_args.create_engine_config()
engine_client = EngineClient(
pid=pid,
port=int(args.engine_worker_queue_port[args.local_data_parallel_id]),
@@ -485,7 +485,7 @@ async def main(args: argparse.Namespace):
try:
if args.workers is None:
args.workers = max(min(int(args.max_num_seqs // 32), 8), 1)

console_logger.info(f"Workers: {args.workers}")
args.model = retrive_model_from_server(args.model, args.revision)

if args.tool_parser_plugin:

@@ -55,7 +55,7 @@ class SplitwiseConnector:
self.current_request_ids = dict()
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"

if self.cfg.cache_config.pd_comm_port is not None:
if self.cfg.scheduler_config.splitwise_role != "mixed":
self.zmq_ctx = zmq.Context()
self.push_sockets: Dict[str, zmq.Socket] = {}
self.pull_socket = None
@@ -71,8 +71,8 @@ class SplitwiseConnector:
self.router_socket.setsockopt(zmq.LINGER, 0)
self.router_socket.setsockopt(zmq.SNDHWM, 1000)
self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1)
self.router_socket.bind(f"tcp://*:{self.cfg.cache_config.pd_comm_port[0]}")
self.logger.info(f"_init_network: bind {self.cfg.cache_config.pd_comm_port}")
self.logger.info(f"_init_network: bind {self.cfg.cache_config.local_pd_comm_port}")
self.router_socket.bind(f"tcp://*:{self.cfg.cache_config.local_pd_comm_port}")

self.poller = zmq.Poller()
self.poller.register(self.router_socket, zmq.POLLIN)
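
A minimal standalone pyzmq sketch of the ROUTER-socket setup above (the port is illustrative):

    import zmq

    ctx = zmq.Context()
    router = ctx.socket(zmq.ROUTER)
    router.setsockopt(zmq.LINGER, 0)            # drop pending messages on close
    router.setsockopt(zmq.SNDHWM, 1000)         # bound the send queue
    router.setsockopt(zmq.ROUTER_MANDATORY, 1)  # fail fast on unroutable peers
    router.bind("tcp://*:52301")
    poller = zmq.Poller()
    poller.register(router, zmq.POLLIN)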

@@ -603,6 +603,19 @@ def get_random_port():
continue


def parse_ports(ports):
if ports is None:
return None
elif isinstance(ports, int):
return [ports]
elif isinstance(ports, str):
return [int(p) for p in ports.split(",")]
elif isinstance(ports, list):
return [int(p) for p in ports]
else:
raise TypeError(f"Cannot parse ports into List[int]: {ports}")
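
A quick usage sketch for parse_ports, covering each accepted input shape (editor's illustration):

    assert parse_ports(None) is None
    assert parse_ports(8000) == [8000]
    assert parse_ports("8000,8001") == [8000, 8001]
    assert parse_ports([8000, "8001"]) == [8000, 8001]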


def is_port_available(host, port):
"""
Check the port is available
@@ -621,6 +634,57 @@ def is_port_available(host, port):
return True


def find_free_ports(
port_range: tuple[int, int] = (8000, 65535),
num_ports: int = 1,
host: str = "0.0.0.0",
) -> list[int]:
"""
Find available TCP ports in a given range, scanning from a random start.

Args:
port_range: (start, end), inclusive, e.g. (20000, 30000).
num_ports: number of ports to find.
host: host to bind, default "0.0.0.0".

Returns:
List of available ports with length == num_ports.

Raises:
ValueError: invalid port range or num_ports <= 0.
RuntimeError: not enough free ports in the range.
"""
start, end = port_range
if start < 0 or end > 65535 or start > end:
raise ValueError(f"Invalid port range: {port_range}")

if num_ports <= 0:
raise ValueError("num_ports must be a positive integer")

total_ports = end - start + 1
if num_ports > total_ports:
raise ValueError("num_ports is larger than range size")

# Generate all ports and rotate with a random start index
ports = list(range(start, end + 1))
offset = random.randint(0, total_ports - 1)
ports = ports[offset:] + ports[:offset]

free_ports: list[int] = []

for port in ports:
if is_port_available(host, port):
free_ports.append(port)

if len(free_ports) >= num_ports:
break

if len(free_ports) < num_ports:
raise RuntimeError(f"Only found {len(free_ports)} free ports in {port_range}, requested {num_ports}.")

return free_ports
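
A usage sketch for find_free_ports (the range is illustrative); the randomized start keeps concurrent callers from racing on the same low ports:

    ports = find_free_ports(port_range=(20000, 30000), num_ports=4)
    assert len(ports) == 4 and all(20000 <= p <= 30000 for p in ports)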


def singleton(cls):
"""
Singleton decorator for a class.

@@ -103,7 +103,7 @@ class GCUModelRunner(ModelRunnerBase):
self.forward_meta: ForwardMeta = None

# Postprocess Env params
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.engine_worker_queue_port)
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.local_engine_worker_queue_port)

def exist_prefill(self):
"""

@@ -208,8 +208,8 @@ class GPUModelRunner(ModelRunnerBase):
self.forward_meta: ForwardMeta = None

# Postprocess Env params
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.engine_worker_queue_port)
logger.info(f"queue id is {str(self.parallel_config.engine_worker_queue_port)}")
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.local_engine_worker_queue_port)
logger.info(f"queue id is {str(self.parallel_config.local_engine_worker_queue_port)}")

# Rollout routing replay config
self.routing_replay_manager = None
@@ -1610,7 +1610,7 @@ class GPUModelRunner(ModelRunnerBase):
name="cache_ready_signal",
array=cache_ready_signal_data,
dtype=np.int32,
suffix=self.parallel_config.engine_worker_queue_port,
suffix=self.parallel_config.local_engine_worker_queue_port,
create=False,
)

@@ -1692,6 +1692,7 @@ class GPUModelRunner(ModelRunnerBase):
logger.info(f"✅ kv cache is ready! {cache_ready_signal.value}")

paddle.device.cuda.empty_cache()
logger.info("kv cache is initialized!")

def _initialize_attn_backend(self) -> None:
"""

@@ -386,7 +386,7 @@ class HPUModelRunner(ModelRunnerBase):
self.is_hpu_perf_breakdown_sync_mode = int(os.environ.get("HPU_PERF_BREAKDOWN_SYNC_MODE", 1)) == 1
# Postprocess Env params
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(
self.local_rank + int(self.parallel_config.engine_worker_queue_port)
self.local_rank + int(self.parallel_config.local_engine_worker_queue_port)
)

if int(os.environ.get("HABANA_PROFILE", 0)) == 1:

@@ -185,8 +185,8 @@ class MetaxModelRunner(ModelRunnerBase):
self.forward_meta: ForwardMeta = None

# Postprocess Env params
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.engine_worker_queue_port)
logger.info(f"queue id is {str(self.parallel_config.engine_worker_queue_port)}")
os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.local_engine_worker_queue_port)
logger.info(f"queue id is {str(self.parallel_config.local_engine_worker_queue_port)}")

# Rollout routing replay config
self.routing_replay_manager = None

@@ -173,11 +173,7 @@ class PaddleDisWorkerProc:
exist_swapped_task_signal:
model_weights_status:
"""
if (
self.parallel_config.enable_expert_parallel
and self.parallel_config.data_parallel_size > 1
and not envs.FD_ENABLE_MULTI_API_SERVER
):
if self.parallel_config.data_parallel_size > 1 and not envs.FD_ENABLE_MULTI_API_SERVER:
launched_expert_service_signal_data = np.zeros(
shape=[self.parallel_config.data_parallel_size // self.fd_config.nnode], dtype=np.int32
)
@@ -217,7 +213,7 @@ class PaddleDisWorkerProc:
name="worker_healthy_live_signal",
array=workers_alive,
dtype=np.int32,
suffix=self.parallel_config.engine_worker_queue_port,
suffix=self.parallel_config.local_engine_worker_queue_port,
create=False,
)
local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
@@ -229,7 +225,7 @@ class PaddleDisWorkerProc:
name="model_weights_status",
array=workers_model_weights,
dtype=np.int32,
suffix=self.parallel_config.engine_worker_queue_port,
suffix=self.parallel_config.local_engine_worker_queue_port,
create=False,
)

@@ -239,7 +235,7 @@ class PaddleDisWorkerProc:
name="exist_task_signal",
array=workers_exist_task,
dtype=np.int32,
suffix=self.parallel_config.engine_worker_queue_port,
suffix=self.parallel_config.local_engine_worker_queue_port,
create=False,
)

@@ -249,7 +245,7 @@ class PaddleDisWorkerProc:
name="exist_swapped_task_signal",
array=workers_swapped_task,
dtype=np.int32,
suffix=self.parallel_config.engine_worker_queue_port,
suffix=self.parallel_config.local_engine_worker_queue_port,
create=False,
)

@@ -259,7 +255,7 @@ class PaddleDisWorkerProc:
name="exist_prefill_task_signal",
array=exist_prefill_task_signal_data,
dtype=np.int32,
suffix=self.parallel_config.engine_worker_queue_port,
suffix=self.parallel_config.local_engine_worker_queue_port,
create=False,
)
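
A sketch of the suffix namespacing behind these changes: every IPCSignal of a DP rank derives its shared-memory name from that rank's local engine worker queue port, so ranks on one host never collide (values below are hypothetical):

    local_engine_worker_queue_port = 6798
    local_data_parallel_id = 1
    dp_ipc_signal_suffix = f"{local_engine_worker_queue_port}_dp{local_data_parallel_id}"
    # -> "6798_dp1"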

@@ -304,11 +300,11 @@ class PaddleDisWorkerProc:
rank=self.local_rank,
ep_size=self.ranks,
fd_config=self.fd_config,
ipc_signal_suffix=self.parallel_config.engine_worker_queue_port,
ipc_signal_suffix=self.parallel_config.local_engine_worker_queue_port,
)

dp_ipc_signal_suffix = (
f"{self.parallel_config.engine_worker_queue_port}_dp{self.parallel_config.local_data_parallel_id}"
f"{self.parallel_config.local_engine_worker_queue_port}_dp{self.parallel_config.local_data_parallel_id}"
)
if local_rank == 0: # master rank0
signal_update_weight_from_tensor = np.zeros([1], dtype=np.int32)
@@ -355,7 +351,7 @@ class PaddleDisWorkerProc:
[MODEL_MAIN_NAME],
self.local_rank,
self.ranks,
shm_uuid=self.parallel_config.engine_worker_queue_port,
shm_uuid=self.parallel_config.local_engine_worker_queue_port,
eplb_config=self.eplb_config,
logger=logger,
)
@@ -468,7 +464,7 @@ class PaddleDisWorkerProc:
self.model_weights_status,
# model_weights_signal
self.worker.model_runner,
self.parallel_config.engine_worker_queue_port,
self.parallel_config.local_engine_worker_queue_port,
)
logger.info(f"current task queue data: {self.task_queue.num_tasks()}")
self.task_queue.clear_data()
@@ -596,10 +592,10 @@ class PaddleDisWorkerProc:
if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
task_address = (
self.parallel_config.pod_ip,
self.parallel_config.engine_worker_queue_port,
self.parallel_config.local_engine_worker_queue_port,
)
else:
task_address = f"/dev/shm/fd_task_queue_{self.parallel_config.engine_worker_queue_port}.sock"
task_address = f"/dev/shm/fd_task_queue_{self.parallel_config.local_engine_worker_queue_port}.sock"
logger.info(f"connect task queue address {task_address}")
self.task_queue = TaskQueue(
address=task_address,
@@ -937,10 +933,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
parallel_config.num_experts_per_rank = num_experts_per_rank
parallel_config.num_experts_start_offset = num_experts_start_offset

if args.load_strategy != "meta":
parallel_config.engine_worker_queue_port = parallel_config.engine_worker_queue_port[
parallel_config.local_data_parallel_id
]
parallel_config.set_communicate_group()

load_config = LoadConfig(vars(args))
@@ -1015,6 +1007,8 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
eplb_config=eplb_config,
routing_replay_config=routing_replay_config,
)
logger.info(f"parallel_config.local_engine_worker_queue_port {parallel_config.local_engine_worker_queue_port}")

update_fd_config_for_mm(fd_config)
if fd_config.load_config.load_choices == "default_v1" and not v1_loader_support(fd_config):
fd_config.load_config.load_choices = "default"

@@ -1006,7 +1006,7 @@ class XPUModelRunner(ModelRunnerBase):
name="cache_ready_signal",
array=cache_ready_signal_data,
dtype=np.int32,
suffix=self.parallel_config.engine_worker_queue_port,
suffix=self.parallel_config.local_engine_worker_queue_port,
create=False,
)


@@ -300,6 +300,7 @@ setup(
"input/ernie4_5_vl_processor/utils/*",
"model_executor/ops/gcu/*",
"model_executor/ops/gcu/fastdeploy_ops/*",
"cache_manager/transfer_factory/get_rdma_nics.sh",
"version.txt",
]
},


@@ -29,7 +29,7 @@ class Args:
mp_num = 1
device_id = 0
speculative_config = {}
engine_pid = "test_pid"
ipc_suffix = "test_ipc_suffix"
cache_queue_port = 9999
pod_ip = "127.0.0.1"
engine_worker_queue_port = 9998

@@ -0,0 +1,138 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest
from unittest.mock import Mock, patch

from fastdeploy.cache_manager.transfer_factory.rdma_cache_transfer import (
RDMACommManager,
)


class TestRDMACommManager(unittest.TestCase):
def setUp(self):
self.args = {
"splitwise_role": "prefill",
"rank": 0,
"gpu_id": 0,
"cache_k_ptr_list": [1, 2, 3],
"cache_v_ptr_list": [4, 5, 6],
"max_block_num": 10,
"block_bytes": 1024,
"rdma_port": 12345,
"prefill_tp_size": 1,
"prefill_tp_idx": 0,
}

@patch.dict("os.environ", {"KVCACHE_GDRCOPY_FLUSH_ENABLE": "", "KVCACHE_RDMA_NICS": ""})
@patch("fastdeploy.platforms.current_platform")
@patch("rdma_comm.RDMACommunicator")
@patch("subprocess.run")
def test_init_rdma_comm_manager_on_gpu_init_all(self, mock_run, mock_rdma_comm, mock_platform):
# Case: Automatically set all environment variables
mock_platform.is_cuda.return_value = True
mock_platform.device_name = "gpu"
mock_run.side_effect = [
Mock(returncode=0, stdout="8.0\n"),
Mock(returncode=0, stdout="KVCACHE_RDMA_NICS=mlx5_2\n"),
]

manager = RDMACommManager(**self.args)
self.assertEqual(manager.splitwise_role, "prefill")
self.assertEqual(mock_run.call_count, 2)
mock_rdma_comm.assert_called_once()
self.assertEqual(os.getenv("KVCACHE_GDRCOPY_FLUSH_ENABLE"), "1")
self.assertEqual(os.getenv("KVCACHE_RDMA_NICS"), "mlx5_2")

@patch.dict("os.environ", {"KVCACHE_GDRCOPY_FLUSH_ENABLE": "", "KVCACHE_RDMA_NICS": "mlx5_1"})
@patch("fastdeploy.platforms.current_platform")
@patch("rdma_comm.RDMACommunicator")
@patch("subprocess.run")
def test_init_rdma_comm_manager_on_gpu_init_gdrcopy(self, mock_run, mock_rdma_comm, mock_platform):
# Case: Only set KVCACHE_GDRCOPY_FLUSH_ENABLE
mock_platform.is_cuda.return_value = True
mock_platform.device_name = "gpu"
mock_run.side_effect = [Mock(returncode=0, stdout="8.0\n")]

manager = RDMACommManager(**self.args)
self.assertEqual(manager.splitwise_role, "prefill")
self.assertEqual(mock_run.call_count, 1)
mock_rdma_comm.assert_called_once()
self.assertEqual(os.getenv("KVCACHE_GDRCOPY_FLUSH_ENABLE"), "1")
self.assertEqual(os.getenv("KVCACHE_RDMA_NICS"), "mlx5_1")

@patch.dict("os.environ", {"KVCACHE_GDRCOPY_FLUSH_ENABLE": "0", "KVCACHE_RDMA_NICS": ""})
@patch("fastdeploy.platforms.current_platform")
@patch("rdma_comm.RDMACommunicator")
@patch("subprocess.run")
def test_init_rdma_comm_manager_on_gpu_init_nics(self, mock_run, mock_rdma_comm, mock_platform):
# Case: Only set KVCACHE_RDMA_NICS
mock_platform.is_cuda.return_value = True
mock_platform.device_name = "gpu"
mock_run.side_effect = [Mock(returncode=0, stdout="KVCACHE_RDMA_NICS=mlx5_2\n")]

manager = RDMACommManager(**self.args)
self.assertEqual(manager.splitwise_role, "prefill")
self.assertEqual(mock_run.call_count, 1)
mock_rdma_comm.assert_called_once()
self.assertEqual(os.getenv("KVCACHE_GDRCOPY_FLUSH_ENABLE"), "0")
self.assertEqual(os.getenv("KVCACHE_RDMA_NICS"), "mlx5_2")

@patch.dict("os.environ", {"KVCACHE_GDRCOPY_FLUSH_ENABLE": "0", "KVCACHE_RDMA_NICS": "mlx5_1"})
@patch("fastdeploy.platforms.current_platform")
@patch("rdma_comm.RDMACommunicator")
@patch("subprocess.run")
def test_init_rdma_comm_manager_on_gpu_init_nothing(self, mock_run, mock_rdma_comm, mock_platform):
# Case: Do not set any environment variables
mock_platform.is_cuda.return_value = True
mock_platform.device_name = "gpu"

manager = RDMACommManager(**self.args)
self.assertEqual(manager.splitwise_role, "prefill")
self.assertEqual(mock_run.call_count, 0)
mock_rdma_comm.assert_called_once()
self.assertEqual(os.getenv("KVCACHE_GDRCOPY_FLUSH_ENABLE"), "0")
self.assertEqual(os.getenv("KVCACHE_RDMA_NICS"), "mlx5_1")

@patch.dict("os.environ", {"KVCACHE_GDRCOPY_FLUSH_ENABLE": "0", "KVCACHE_RDMA_NICS": "mlx5_1"})
@patch("fastdeploy.platforms.current_platform")
@patch("rdma_comm.RDMACommunicator")
@patch("subprocess.run")
def test_connect_success(self, mock_run, mock_rdma_comm, mock_platform):
"""Test successful connection"""
manager = RDMACommManager(**self.args)
manager.messager.is_connected.return_value = False
manager.messager.connect.return_value = 0

result = manager.connect("127.0.0.1", 12345)
self.assertTrue(result)
manager.messager.connect.assert_called_once_with("127.0.0.1", "12345", 0)

@patch.dict("os.environ", {"KVCACHE_GDRCOPY_FLUSH_ENABLE": "0", "KVCACHE_RDMA_NICS": "mlx5_1"})
@patch("fastdeploy.platforms.current_platform")
@patch("rdma_comm.RDMACommunicator")
@patch("subprocess.run")
def test_write_cache(self, mock_run, mock_rdma_comm, mock_platform):
"""Test write_cache method"""
manager = RDMACommManager(**self.args)
manager.messager.write_cache.return_value = True

result = manager.write_cache("127.0.0.1", 12345, [1, 2], [3, 4], 0)
self.assertTrue(result)
manager.messager.write_cache.assert_called_once_with("127.0.0.1", "12345", [1, 2], [3, 4], 0)


if __name__ == "__main__":
unittest.main()
@@ -14,6 +14,7 @@

import asyncio
import os
import re
import shutil
import signal
import subprocess
@@ -47,7 +48,7 @@ def setup_and_run_server():
- Tears down server after all tests finish
"""
print("Pre-test port cleanup...")
FD_CONTROLLER_PORT = int(os.getenv("FD_CONTROLLER_PORT", 8333))
FD_CONTROLLER_PORT = int(os.getenv("FD_CONTROLLER_PORT", 8633))
clean_ports([FD_API_PORT, FD_ENGINE_QUEUE_PORT, FD_METRICS_PORT, FD_CACHE_QUEUE_PORT, FD_CONTROLLER_PORT])

env = os.environ.copy()
@@ -174,10 +175,9 @@ def parse_prometheus_to_dict(metrics_text: str):
value = float(line.split("}")[1].strip())

# parse labels
labels = {}
for kv in labels_str.split(","):
k, v = kv.split("=")
labels[k] = v.strip('"')
# extract every key and value with a regex (strips the outer quotes)
pairs = re.findall(r'(\w+)="([^"]*)"', labels_str)
labels = {k: v for k, v in pairs}

# store
if metric_name not in result:
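
A sketch of the regex-based label parsing above, on a sample Prometheus label string:

    import re

    labels_str = 'model="ernie",dp_rank="0"'
    labels = {k: v for k, v in re.findall(r'(\w+)="([^"]*)"', labels_str)}
    print(labels)  # {'model': 'ernie', 'dp_rank': '0'}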
@@ -214,7 +214,7 @@ def test_metrics_with_clear_and_reset():
"""
Test the metrics monitoring endpoint.
"""
FD_CONTROLLER_PORT = int(os.getenv("FD_CONTROLLER_PORT", 8333))
FD_CONTROLLER_PORT = int(os.getenv("FD_CONTROLLER_PORT", 8633))
metrics_url = f"http://0.0.0.0:{FD_METRICS_PORT}/metrics"

async_concurrency(n=10)

@@ -200,14 +200,26 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
and to drive specific code paths that were previously uncovered.
"""

def setUp(self):
patch("fastdeploy.engine.common_engine.EngineCacheQueue").start()

def _make_cfg(self, **kwargs):
# If DP > 1, we must provide enough engine_worker_queue_port for each dp index
dp = kwargs.get("data_parallel_size", 1)
nnode = len(kwargs.get("ips", ["127.0.0.1"]))
engine_worker_queue_port = int(os.getenv("FD_ENGINE_QUEUE_PORT", "6778"))
cache_queue_port = int(os.getenv("FD_CACHE_QUEUE_PORT", "6779"))
if dp and dp > 1:
engine_worker_queue_port = [engine_worker_queue_port + 20 + i for i in range(dp // nnode)]
cache_queue_port = [cache_queue_port + 20 + i for i in range(dp // nnode)]

args = EngineArgs(
model=MODEL_NAME,
max_model_len=128,
tensor_parallel_size=1,
# give unique ports to avoid collision with other tests
engine_worker_queue_port=str(int(os.getenv("FD_ENGINE_QUEUE_PORT", "6778")) + 20),
cache_queue_port=str(int(os.getenv("FD_CACHE_QUEUE_PORT", "6779")) + 20),
engine_worker_queue_port=engine_worker_queue_port,
cache_queue_port=cache_queue_port,
enable_prefix_caching=True,
**kwargs,
)
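
An illustration of the per-DP port lists that _make_cfg now computes (base ports are the test defaults, the DP layout is hypothetical):

    dp, nnode = 2, 1
    engine_base, cache_base = 6778, 6779
    engine_ports = [engine_base + 20 + i for i in range(dp // nnode)]  # [6798, 6799]
    cache_ports = [cache_base + 20 + i for i in range(dp // nnode)]    # [6799, 6800]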
@@ -218,14 +230,7 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
# Always enable chunked prefill in tests to avoid another strict check
args.enable_chunked_prefill = True

# If DP > 1, we must provide enough engine_worker_queue_port for each dp index
dp = kwargs.get("data_parallel_size", args.data_parallel_size)
base = int(args.engine_worker_queue_port.split(",")[0])
if dp and dp > 1:
ports = ",".join(str(base + i) for i in range(dp))
args.engine_worker_queue_port = ports

return args.create_engine_config(port_availability_check=False)
return args.create_engine_config()

def _stub_processor(self):
class _Tok:
@@ -574,7 +579,9 @@ class TestCommonEngineAdditionalCoverage(unittest.TestCase):
def test_start_worker_service_cmd_build(self):
"""Cover 1517, 1526, 1568, 1592, 1595 by building the worker command with mocks."""
with patch("fastdeploy.config.get_host_ip", return_value="127.0.0.1"):
cfg = self._make_cfg(splitwise_role="mixed", num_gpu_blocks_override=4, ips=["127.0.0.1", "127.0.0.2"])
cfg = self._make_cfg(
splitwise_role="mixed", num_gpu_blocks_override=4, ips=["127.0.0.1", "127.0.0.2"], data_parallel_size=2
)
# Make model multi-modal so env var branch already covered above; here not required
cfg.structured_outputs_config.logits_processors = ["A", "B"]


@@ -0,0 +1,238 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import os
import sys
import unittest
from unittest.mock import Mock, patch

# Add the path so the module can be imported
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../"))

from fastdeploy.engine.expert_service import ExpertService, start_data_parallel_service


class TestExpertService(unittest.TestCase):
"""Tests for the ExpertService class"""

def setUp(self):
"""Set up the test environment"""
# Create a mock config object
self.mock_cfg = Mock()
self.mock_cfg.parallel_config = Mock()
self.mock_cfg.parallel_config.data_parallel_size = 1
self.mock_cfg.parallel_config.local_engine_worker_queue_port = 8080
self.mock_cfg.parallel_config.engine_worker_queue_port = [8080, 8081]
self.mock_cfg.cache_config = Mock()
self.mock_cfg.cache_config.num_gpu_blocks_override = None
self.mock_cfg.scheduler_config = Mock()
self.mock_cfg.scheduler_config.name = "default"
self.mock_cfg.scheduler_config.splitwise_role = "mixed"
self.mock_cfg.host_ip = "127.0.0.1"
self.mock_cfg.register_info = {}
self.mock_cfg.worker_num_per_node = 1
self.mock_cfg.nnode = 1
self.mock_cfg.local_device_ids = [0]

@patch("fastdeploy.engine.expert_service.EngineService")
@patch("fastdeploy.engine.expert_service.get_logger")
@patch("fastdeploy.engine.expert_service.llm_logger")
def test_expert_service_init_single_dp(self, mock_llm_logger, mock_get_logger, mock_engine_service):
"""Test initialization in single data parallel mode"""
local_data_parallel_id = 0

# Create an ExpertService instance
expert_service = ExpertService(self.mock_cfg, local_data_parallel_id)

# Verify the config setup
self.assertEqual(expert_service.cfg, self.mock_cfg)

# Verify the logger setup
self.assertEqual(expert_service.llm_logger, mock_llm_logger)

# Verify EngineService initialization
mock_engine_service.assert_called_once_with(self.mock_cfg, True)

@patch("fastdeploy.engine.expert_service.EngineService")
@patch("fastdeploy.engine.expert_service.get_logger")
@patch("fastdeploy.engine.expert_service.envs")
def test_expert_service_init_multi_dp(self, mock_envs, mock_get_logger, mock_engine_service):
"""Test initialization in multi data parallel mode"""
# Configure multiple data parallel ranks
self.mock_cfg.parallel_config.data_parallel_size = 2
mock_envs.FD_ENABLE_MULTI_API_SERVER = False
mock_envs.FD_ENABLE_INTERNAL_ADAPTER = False

local_data_parallel_id = 1
mock_logger = Mock()
mock_get_logger.return_value = mock_logger

# Create an ExpertService instance
expert_service = ExpertService(self.mock_cfg, local_data_parallel_id)

# Verify the config update
self.assertEqual(expert_service.cfg.parallel_config.local_data_parallel_id, local_data_parallel_id)

# Verify the logger setup in multi-DP mode
mock_get_logger.assert_called_once_with("fastdeploy", f"fastdeploy_dprank{local_data_parallel_id}.log")

@patch("fastdeploy.engine.expert_service.EngineService")
@patch("fastdeploy.engine.expert_service.time")
@patch("fastdeploy.engine.expert_service.threading")
@patch("fastdeploy.engine.expert_service.envs")
def test_start_method(self, mock_envs, mock_threading, mock_time, mock_engine_service):
mock_envs.FD_ENABLE_RETURN_TEXT = False
mock_envs.FD_ENABLE_MULTI_API_SERVER = False

local_data_parallel_id = 0

mock_process = Mock()
mock_process.pid = 1234

# Key point: set this on the instance mock
mock_engine_instance = mock_engine_service.return_value
mock_engine_instance.start_cache_service.return_value = [mock_process]

expert_service = ExpertService(self.mock_cfg, local_data_parallel_id)

with patch("fastdeploy.engine.expert_service.IPCSignal") as mock_ipc_signal:
mock_ipc_instance = Mock()
mock_ipc_instance.value = [100]
mock_ipc_signal.return_value = mock_ipc_instance

result = expert_service.start(None, local_data_parallel_id)

# Assert against the EngineService instance mock
mock_engine_instance.start.assert_called_once()
mock_engine_instance.start_zmq_service.assert_called_once_with(
self.mock_cfg.parallel_config.engine_worker_queue_port[local_data_parallel_id]
)
mock_engine_instance.start_cache_service.assert_called_once()

self.assertTrue(result)

@patch("fastdeploy.engine.expert_service.EngineService")
@patch("fastdeploy.engine.expert_service.IPCSignal")
@patch("fastdeploy.engine.expert_service.time")
def test_reset_kvcache_blocks(self, mock_time, mock_ipc_signal, mock_engine_service):
"""Test resetting KV cache blocks"""
local_data_parallel_id = 0

# Create an ExpertService instance
expert_service = ExpertService(self.mock_cfg, local_data_parallel_id)
expert_service.llm_logger = Mock()
expert_service.engine = Mock()
expert_service.engine.resource_manager = Mock()

# Set up a mock signal
mock_signal_instance = Mock()
mock_signal_instance.value = [100]  # simulated number of profiled blocks
expert_service.get_profile_block_num_signal = mock_signal_instance

# Call reset_kvcache_blocks
expert_service.reset_kvcache_blocks()

# Verify the cache config reset
self.mock_cfg.cache_config.reset.assert_called_once_with(100)
expert_service.engine.resource_manager.reset_cache_config.assert_called_once_with(self.mock_cfg.cache_config)

@patch("fastdeploy.engine.expert_service.EngineService")
@patch("fastdeploy.engine.expert_service.os")
@patch("fastdeploy.engine.expert_service.signal")
def test_exit_sub_services(self, mock_signal, mock_os, mock_engine_service):
"""Test exiting sub services"""
local_data_parallel_id = 0

# Create an ExpertService instance
expert_service = ExpertService(self.mock_cfg, local_data_parallel_id)
expert_service.llm_logger = Mock()

# Set up a mock cache manager process
mock_process = Mock()
mock_process.pid = 1234
expert_service.cache_manager_processes = [mock_process]

# Set up a mock engine resource manager
expert_service.engine = Mock()
expert_service.engine.resource_manager = Mock()
expert_service.engine.resource_manager.cache_manager = Mock()
expert_service.engine.resource_manager.cache_manager.shm_cache_task_flag_broadcast = Mock()

# Set up a mock ZMQ server
expert_service.zmq_server = Mock()

# Call the exit method
expert_service._exit_sub_services()

# Verify cache manager cleanup
expert_service.engine.resource_manager.cache_manager.shm_cache_task_flag_broadcast.clear.assert_called_once()
mock_os.killpg.assert_called_once_with(1234, mock_signal.SIGTERM)

# Verify the ZMQ server is closed
expert_service.zmq_server.close.assert_called_once()

@patch("fastdeploy.engine.expert_service.ExpertService")
@patch("fastdeploy.engine.expert_service.threading")
@patch("fastdeploy.engine.expert_service.time")
@patch("fastdeploy.engine.expert_service.traceback")
def test_start_data_parallel_service_success(self, mock_traceback, mock_time, mock_threading, mock_expert_service):
"""Test the success path of starting the data parallel service"""
mock_cfg = Mock()
local_data_parallel_id = 0

# Mock the ExpertService instance
mock_expert_instance = Mock()
mock_expert_service.return_value = mock_expert_instance

# Mock the thread
mock_thread_instance = Mock()
mock_threading.Thread.return_value = mock_thread_instance

# Call the function
start_data_parallel_service(mock_cfg, local_data_parallel_id)

# Verify ExpertService creation and start
mock_expert_service.assert_called_once_with(mock_cfg, local_data_parallel_id, start_queue=False)
mock_expert_instance.start.assert_called_once_with(None, local_data_parallel_id, None, None)

@patch("fastdeploy.engine.expert_service.ExpertService")
@patch("fastdeploy.engine.expert_service.llm_logger")
@patch("fastdeploy.engine.expert_service.traceback")
def test_start_data_parallel_service_exception(self, mock_traceback, mock_llm_logger, mock_expert_service):
"""Test the exception path of starting the data parallel service"""
mock_cfg = Mock()
local_data_parallel_id = 0

# Simulate ExpertService start failure
mock_expert_instance = Mock()
mock_expert_instance.start.side_effect = Exception("Test exception")
mock_expert_service.return_value = mock_expert_instance

# Mock traceback
mock_traceback.format_exc.return_value = "Traceback details"

# Call the function and verify no exception escapes
try:
start_data_parallel_service(mock_cfg, local_data_parallel_id)
except Exception:
self.fail("start_data_parallel_service should handle exceptions gracefully")

# Verify the exception is logged
mock_llm_logger.exception.assert_called_once()


if __name__ == "__main__":
unittest.main()
@@ -14,7 +14,8 @@
# limitations under the License.
"""

import sys
import os
import random
import unittest
from unittest.mock import MagicMock, patch

@@ -30,10 +31,26 @@ class TestMultiApiServer(unittest.TestCase):

def setUp(self):
"""Set up test fixtures"""
self.test_ports = ["8000", "8001"]
self.test_metrics_ports = ["8800", "8801"]
self.test_server_args = ["--model", "test_model", "--engine-worker-queue-port", "9000,9001"]
self.test_model = "test_model"
self.test_ports = "8000,8001"
self.test_metrics_ports = "8800,8801"
self.test_engine_worker_queue_port = "9000,9001"
self.test_server_args = [
"--model",
self.test_model,
"--engine-worker-queue-port",
self.test_engine_worker_queue_port,
]
self.test_server_count = 2
self.test_device_count = 2

patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"}).start()
patch(
"fastdeploy.entrypoints.openai.multi_api_server.find_free_ports",
side_effect=lambda *args, **kwargs: [
random.randint(8000, 65535) for i in range(kwargs.get("num_ports", 1))
],
).start()

@patch("fastdeploy.entrypoints.openai.multi_api_server.subprocess.Popen")
@patch("fastdeploy.entrypoints.openai.multi_api_server.is_port_available")
@@ -49,6 +66,7 @@ class TestMultiApiServer(unittest.TestCase):
# Call start_servers
processes = start_servers(
server_count=self.test_server_count,
device_count=self.test_device_count,
server_args=self.test_server_args,
ports=self.test_ports,
metrics_ports=self.test_metrics_ports,
@@ -63,24 +81,20 @@ class TestMultiApiServer(unittest.TestCase):

# Verify the command arguments for the first server
first_call_args = mock_popen.call_args_list[0][0][0]
expected_cmd = [
sys.executable,
"-m",
"fastdeploy.entrypoints.openai.api_server",
"--model",
"test_model",
"--engine-worker-queue-port",
"9000,9001",
"--port",
"8000",
"--metrics-port",
"8800",
"--controller-port",
"-1",
"--local-data-parallel-id",
"0",
]
self.assertEqual(first_call_args, expected_cmd)
print(first_call_args)
for i, item in enumerate(first_call_args):
if item == "--port":
self.assertEqual(first_call_args[i + 1], self.test_ports.split(",")[0])
if item == "--metrics-port":
self.assertEqual(first_call_args[i + 1], self.test_metrics_ports.split(",")[0])
if item == "--controller-port":
self.assertEqual(first_call_args[i + 1], "-1")
if item == "--model":
self.assertEqual(first_call_args[i + 1], self.test_model)
if item == "--engine-worker-queue-port":
self.assertEqual(first_call_args[i + 1], self.test_engine_worker_queue_port)
if item == "--local-data-parallel-id":
self.assertEqual(first_call_args[i + 1], "0")

# Verify environment variables are set correctly
first_call_kwargs = mock_popen.call_args_list[0][1]
@@ -94,7 +108,7 @@ class TestMultiApiServer(unittest.TestCase):
mock_is_port_available.return_value = True

# Should not raise any exception
check_param(self.test_ports, self.test_server_count)
check_param(self.test_ports.split(","), self.test_server_count)

def test_check_param_wrong_port_count(self):
"""Test parameter validation with wrong port count"""
@@ -108,12 +122,13 @@ class TestMultiApiServer(unittest.TestCase):
# Mock port availability check - first port available, second not
mock_is_port_available.side_effect = [True, False]

self.assertFalse(check_param(self.test_ports, self.test_server_count))
self.assertFalse(check_param(self.test_ports.split(","), self.test_server_count))

@patch("fastdeploy.entrypoints.openai.multi_api_server.is_port_available")
@patch("fastdeploy.entrypoints.openai.multi_api_server.start_servers")
|
||||
@patch("fastdeploy.entrypoints.openai.multi_api_server.time.sleep")
|
||||
@patch("fastdeploy.entrypoints.openai.multi_api_server.check_param")
|
||||
def test_main_function(self, mock_check_param, mock_sleep, mock_start_servers):
|
||||
def test_main_function(self, mock_check_param, mock_sleep, mock_start_servers, mock_is_port_available):
|
||||
"""Test main function with mocked arguments"""
|
||||
# Mock command line arguments
|
||||
test_args = [
|
||||
@@ -133,6 +148,9 @@ class TestMultiApiServer(unittest.TestCase):
|
||||
"9000,9001",
|
||||
]
|
||||
|
||||
# Mock utilization functions
|
||||
mock_is_port_available.return_value = True
|
||||
|
||||
# Mock processes
|
||||
mock_proc1 = MagicMock()
|
||||
mock_proc2 = MagicMock()
|
||||
@@ -144,12 +162,14 @@ class TestMultiApiServer(unittest.TestCase):
|
||||
with patch("sys.argv", test_args):
|
||||
main()
|
||||
|
||||
print(mock_start_servers)
|
||||
# Verify start_servers was called with correct parameters
|
||||
mock_start_servers.assert_called_once_with(
|
||||
server_count=2,
|
||||
server_args=["--model", "test_model", "--engine-worker-queue-port", "9000,9001"],
|
||||
ports=["8000", "8001"],
|
||||
metrics_ports=["8800", "8801"],
|
||||
server_count=self.test_server_count,
|
||||
device_count=self.test_device_count,
|
||||
server_args=self.test_server_args,
|
||||
ports=self.test_ports,
|
||||
metrics_ports=self.test_metrics_ports,
|
||||
controller_ports="8802,8803",
|
||||
)
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import random
|
||||
import unittest
|
||||
from unittest.mock import Mock
|
||||
|
||||
@@ -132,6 +133,59 @@ class TestConfig(unittest.TestCase):
|
||||
fd_config.init_cache_info()
|
||||
assert fd_config.register_info is not None
|
||||
|
||||
def test_fdconfig_postprocess_ports(self):
|
||||
data_parallel_size = 4
|
||||
tensor_parallel_size = 2
|
||||
local_data_parallel_id = random.randint(0, data_parallel_size - 1)
|
||||
engine_worker_queue_ports = [random.randint(8000, 65535) for _ in range(data_parallel_size)]
|
||||
cache_queue_ports = [random.randint(8000, 65535) for _ in range(data_parallel_size)]
|
||||
pd_comm_ports = [random.randint(8000, 65535) for _ in range(data_parallel_size)]
|
||||
rdma_comm_ports = [random.randint(8000, 65535) for _ in range(data_parallel_size * tensor_parallel_size)]
|
||||
|
||||
parallel_config = ParallelConfig(
|
||||
{
|
||||
"engine_worker_queue_port": ",".join(map(str, engine_worker_queue_ports)),
|
||||
"data_parallel_size": data_parallel_size,
|
||||
"tensor_parallel_size": tensor_parallel_size,
|
||||
"local_data_parallel_id": local_data_parallel_id,
|
||||
}
|
||||
)
|
||||
graph_opt_config = GraphOptimizationConfig({})
|
||||
cache_config = CacheConfig(
|
||||
{
|
||||
"cache_queue_port": ",".join(map(str, cache_queue_ports)),
|
||||
"pd_comm_port": ",".join(map(str, pd_comm_ports)),
|
||||
"rdma_comm_ports": ",".join(map(str, rdma_comm_ports)),
|
||||
}
|
||||
)
|
||||
load_config = LoadConfig({})
|
||||
scheduler_config = SchedulerConfig({})
|
||||
model_config: Mock = Mock()
|
||||
model_config.max_model_len = 512
|
||||
|
||||
fd_config = FDConfig(
|
||||
parallel_config=parallel_config,
|
||||
graph_opt_config=graph_opt_config,
|
||||
cache_config=cache_config,
|
||||
load_config=load_config,
|
||||
scheduler_config=scheduler_config,
|
||||
model_config=model_config,
|
||||
ips="0.0.0.0",
|
||||
test_mode=True,
|
||||
)
|
||||
assert (
|
||||
fd_config.parallel_config.local_engine_worker_queue_port
|
||||
== engine_worker_queue_ports[local_data_parallel_id]
|
||||
)
|
||||
assert fd_config.cache_config.local_cache_queue_port == cache_queue_ports[local_data_parallel_id]
|
||||
assert fd_config.cache_config.local_pd_comm_port == pd_comm_ports[local_data_parallel_id]
|
||||
assert (
|
||||
fd_config.cache_config.local_rdma_comm_ports
|
||||
== rdma_comm_ports[
|
||||
local_data_parallel_id * tensor_parallel_size : (local_data_parallel_id + 1) * tensor_parallel_size
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user