[PD Disaggregation] support DP via v1 router and decouple DP and EP (#5197)

* [fix] support DP via v1 router and decouple DP and EP

* [fix] fix scripts

* [fix] reset model path

* [fix] dp use get_output_ep, fix router port type, update scripts

* [merge] merge with latest code

* [chore] remove some debug log

* [fix] fix code style check

* [fix] fix test_multi_api_server for log_dir name

* [chore] reduce logs

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Yonghua Li
2025-12-04 15:38:43 +08:00
committed by GitHub
parent 5cd17fd662
commit f4119d51b4
15 changed files with 394 additions and 146 deletions
+48 -49
View File
@@ -26,71 +26,70 @@
#define MAX_BSZ 512 #define MAX_BSZ 512
// #define GET_OUTPUT_DEBUG // #define GET_OUTPUT_DEBUG
struct msgdata { struct msgdata {
long mtype; long mtype;
int mtext[MAX_BSZ + 2]; // stop_flag, bsz, tokens int mtext[MAX_BSZ + 2]; // stop_flag, bsz, tokens
}; };
void GetOutput(const paddle::Tensor& x, void GetOutput(const paddle::Tensor& x,
int64_t rank_id, int64_t rank_id,
bool wait_flag, bool wait_flag,
int msg_queue_id) { int msg_queue_id) {
if (rank_id > 0) { if (rank_id > 0) {
return;
}
static struct msgdata msg_rcv;
if (const char* inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string inference_msg_queue_id_env_str(
inference_msg_queue_id_env_p);
int inference_msg_queue_id_from_env =
std::stoi(inference_msg_queue_id_env_str);
#ifdef GET_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
<< inference_msg_queue_id_from_env << std::endl;
#endif
msg_queue_id = inference_msg_queue_id_from_env;
}
static key_t key = ftok("/dev/shm", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);
#ifdef GET_OUTPUT_DEBUG
std::cout << "get_output_key: " << key << std::endl;
std::cout << "get_output msgid: " << msgid << std::endl;
#endif
int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>());
int ret = -1;
if (!wait_flag) {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT);
} else {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, 0);
}
if (ret == -1) {
out_data[0] = -2;
out_data[1] = 0;
return;
}
int bsz = msg_rcv.mtext[1];
for (int64_t i = 0; i < bsz + 2; i++) {
out_data[i] = (int64_t)msg_rcv.mtext[i];
}
#ifdef GET_OUTPUT_DEBUG
std::cout << "get_output finished: " << msgid << std::endl;
#endif
return; return;
}
static struct msgdata msg_rcv;
if (const char* inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
int inference_msg_queue_id_from_env =
std::stoi(inference_msg_queue_id_env_str);
#ifdef GET_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
<< inference_msg_queue_id_from_env << std::endl;
#endif
msg_queue_id = inference_msg_queue_id_from_env;
}
static key_t key = ftok("/dev/shm", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);
#ifdef GET_OUTPUT_DEBUG
std::cout << "get_output_key: " << key << std::endl;
std::cout << "get_output msgid: " << msgid << std::endl;
#endif
int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>());
int ret = -1;
if (!wait_flag) {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT);
} else {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, 0);
}
if (ret == -1) {
out_data[0] = -2;
out_data[1] = 0;
return;
}
int bsz = msg_rcv.mtext[1];
for (int64_t i = 0; i < bsz + 2; i++) {
out_data[i] = (int64_t)msg_rcv.mtext[i];
}
#ifdef GET_OUTPUT_DEBUG
std::cout << "get_output finished: " << msgid << std::endl;
#endif
return;
} }
void GetOutputStatic(const paddle::Tensor& x, int64_t rank_id, bool wait_flag) { void GetOutputStatic(const paddle::Tensor& x, int64_t rank_id, bool wait_flag) {
GetOutput(x, rank_id, wait_flag, 1); GetOutput(x, rank_id, wait_flag, 1);
} }
void GetOutputDynamic(const paddle::Tensor& x, void GetOutputDynamic(const paddle::Tensor& x,
int64_t rank_id, int64_t rank_id,
bool wait_flag, bool wait_flag,
int msg_queue_id) { int msg_queue_id) {
GetOutput(x, rank_id, wait_flag, msg_queue_id); GetOutput(x, rank_id, wait_flag, msg_queue_id);
} }
PD_BUILD_STATIC_OP(get_output) PD_BUILD_STATIC_OP(get_output)
+141
View File
@@ -0,0 +1,141 @@
#!/bin/bash
set -e

# Test splitwise (prefill/decode disaggregated) deployment.
# There are two methods for splitwise deployment:
#   v0: using splitwise_scheduler or dp_scheduler
#   v1: using local_scheduler + router
# This script exercises the v1 path: start a router, then a prefill group
# and a decode group (each with DATA_PARALLEL_SIZE API servers), and
# finally send one chat completion request through the router.

MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle"
DATA_PARALLEL_SIZE=2
TENSOR_PARALLEL_SIZE=1
NUM_GPUS=$(($DATA_PARALLEL_SIZE * $TENSOR_PARALLEL_SIZE))
LOG_DATE=$(date +%Y%m%d_%H%M%S)

export FD_DEBUG=1
export ENABLE_V1_KVCACHE_SCHEDULER=1
export KVCACHE_GDRCOPY_FLUSH_ENABLE=1
export FD_ENABLE_MULTI_API_SERVER=1

SCRIPT_PATH=$(readlink -f "$0")
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")

# Detect the RDMA NICs used for KV-cache transfer; abort early when none found.
export $(bash ${SCRIPT_DIR}/../../scripts/get_rdma_nics.sh gpu)
echo "KVCACHE_RDMA_NICS:${KVCACHE_RDMA_NICS}"
if [ -z "${KVCACHE_RDMA_NICS}" ]; then
    echo "KVCACHE_RDMA_NICS is empty, please check the output of get_rdma_nics.sh"
    exit 1
fi

unset http_proxy && unset https_proxy
source ${SCRIPT_DIR}/utils.sh

# ---------------- start router ----------------
ROUTER_PORT=$(get_free_ports 1)
echo "---------------------------"
echo ROUTER_PORT: $ROUTER_PORT

export FD_LOG_DIR="log/$LOG_DATE/router"
rm -rf $FD_LOG_DIR
mkdir -p ${FD_LOG_DIR}

# NOTE: redirect stdout to the file first and then dup stderr onto it so
# BOTH streams land in the log ("2>&1 >file" would leave stderr on the
# terminal, losing the most useful diagnostics).
nohup python -m fastdeploy.router.launch \
    --port ${ROUTER_PORT} \
    --splitwise \
    >${FD_LOG_DIR}/nohup 2>&1 &
sleep 1

# ---------------- start prefill ----------------
P_SERVER_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
P_METRICS_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
P_ENGINE_WORKER_QUEUE_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
P_CACHE_QUEUE_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
P_RDMA_COMM_PORTS=$(get_free_ports $NUM_GPUS)
P_PD_COMM_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
echo "---------------------------"
echo P_SERVER_PORTS: $P_SERVER_PORTS
echo P_METRICS_PORTS: $P_METRICS_PORTS
echo P_ENGINE_WORKER_QUEUE_PORTS: $P_ENGINE_WORKER_QUEUE_PORTS
echo P_CACHE_QUEUE_PORTS: $P_CACHE_QUEUE_PORTS
echo P_RDMA_COMM_PORTS: $P_RDMA_COMM_PORTS
echo P_PD_COMM_PORTS: $P_PD_COMM_PORTS

export CUDA_VISIBLE_DEVICES="0,1"
export FD_LOG_DIR="log/$LOG_DATE/prefill"
rm -rf $FD_LOG_DIR
mkdir -p ${FD_LOG_DIR}

nohup python -m fastdeploy.entrypoints.openai.multi_api_server \
    --num-servers ${DATA_PARALLEL_SIZE} \
    --ports ${P_SERVER_PORTS} \
    --metrics-port ${P_METRICS_PORTS} \
    --args --model ${MODEL_NAME} \
    --engine-worker-queue-port ${P_ENGINE_WORKER_QUEUE_PORTS} \
    --cache-queue-port ${P_CACHE_QUEUE_PORTS} \
    --max-model-len 32768 \
    --data-parallel-size ${DATA_PARALLEL_SIZE} \
    --tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \
    --splitwise-role "prefill" \
    --cache-transfer-protocol "rdma" \
    --rdma-comm-ports ${P_RDMA_COMM_PORTS} \
    --pd-comm-port ${P_PD_COMM_PORTS} \
    --router "0.0.0.0:${ROUTER_PORT}" \
    >${FD_LOG_DIR}/nohup 2>&1 &

echo "--- Health Check Status ---"
wait_for_health ${P_SERVER_PORTS}

# ---------------- start decode ----------------
D_SERVER_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
D_ENGINE_WORKER_QUEUE_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
D_CACHE_QUEUE_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
D_METRICS_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
D_RDMA_COMM_PORTS=$(get_free_ports $NUM_GPUS)
D_PD_COMM_PORTS=$(get_free_ports $DATA_PARALLEL_SIZE)
echo "---------------------------"
echo D_SERVER_PORTS: $D_SERVER_PORTS
echo D_ENGINE_WORKER_QUEUE_PORTS: $D_ENGINE_WORKER_QUEUE_PORTS
echo D_CACHE_QUEUE_PORTS: $D_CACHE_QUEUE_PORTS
echo D_METRICS_PORTS: $D_METRICS_PORTS
echo D_RDMA_COMM_PORTS: $D_RDMA_COMM_PORTS
echo D_PD_COMM_PORTS: $D_PD_COMM_PORTS

export CUDA_VISIBLE_DEVICES="2,3"
export FD_LOG_DIR="log/$LOG_DATE/decode"
rm -rf $FD_LOG_DIR
mkdir -p ${FD_LOG_DIR}

nohup python -m fastdeploy.entrypoints.openai.multi_api_server \
    --num-servers ${DATA_PARALLEL_SIZE} \
    --ports ${D_SERVER_PORTS} \
    --metrics-port ${D_METRICS_PORTS} \
    --args --model ${MODEL_NAME} \
    --engine-worker-queue-port ${D_ENGINE_WORKER_QUEUE_PORTS} \
    --cache-queue-port ${D_CACHE_QUEUE_PORTS} \
    --max-model-len 32768 \
    --data-parallel-size ${DATA_PARALLEL_SIZE} \
    --tensor-parallel-size ${TENSOR_PARALLEL_SIZE} \
    --splitwise-role "decode" \
    --cache-transfer-protocol "rdma" \
    --rdma-comm-ports ${D_RDMA_COMM_PORTS} \
    --pd-comm-port ${D_PD_COMM_PORTS} \
    --router "0.0.0.0:${ROUTER_PORT}" \
    >${FD_LOG_DIR}/nohup 2>&1 &

echo "--- Health Check Status ---"
wait_for_health ${D_SERVER_PORTS}

# ---------------- send request ----------------
echo "------ Request Check ------"
sleep 10 # make sure server is registered to router
curl -X POST "http://0.0.0.0:${ROUTER_PORT}/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
    "messages": [
        {"role": "user", "content": "hello"}
    ],
    "max_tokens": 100,
    "stream": false
}'
+81 -8
View File
@@ -1,8 +1,16 @@
#!/bin/bash #!/bin/bash
# Succeed (exit 0) when no listening TCP socket is bound to the given
# port; fail (exit 1) when the port is occupied.
is_port_free() {
    local candidate=$1
    # `ss -ltn` lists listening TCP sockets; column 4 is the local
    # addr:port, so an anchored ":PORT$" match means "occupied".
    ! ss -ltn | awk '{print $4}' | grep -q ":${candidate}$"
}
check_ports() { check_ports() {
for port in "$@"; do for port in "$@"; do
if ss -tuln | grep -q ":$port "; then if ! is_port_free $port; then
echo "❌ Port $port is already in use" echo "❌ Port $port is already in use"
return 1 return 1
fi fi
@@ -11,14 +19,79 @@ check_ports() {
} }
wait_for_health() { wait_for_health() {
local server_port=$1 IFS=',' read -r -a server_ports <<< "$1"
local num_ports=${#server_ports[@]}
local total_lines=$((num_ports + 1))
local first_run=true
local GREEN='\033[0;32m'
local RED='\033[0;31m'
local NC='\033[0m' # No Color
local start_time=$(date +%s)
while true; do while true; do
status_code=$(curl -s -o /dev/null -w "%{http_code}" "http://0.0.0.0:${server_port}/health" || echo "000") local all_ready=true
if [ "$status_code" -eq 200 ]; then for port in "${server_ports[@]}"; do
status_code=$(curl -s --max-time 1 -o /dev/null -w "%{http_code}" "http://0.0.0.0:${port}/health" || echo "000")
if [ "$status_code" -eq 200 ]; then
printf "Port %s: ${GREEN}[OK] 200${NC}\033[K\n" "$port"
else
all_ready=false
printf "Port %s: ${RED}[WAIT] %s${NC}\033[K\n" "$port" "$status_code"
fi
done
cur_time=$(date +%s)
if [ "$all_ready" = "true" ]; then
echo "All services are ready! [$((cur_time-start_time))s]"
break break
else else
echo "Service not ready. Retrying in 4s..." echo "Waiting for services... [$((cur_time-start_time))s]"
sleep 4 printf "\033[%dA" "$total_lines" # roll back cursor
fi sleep 1
fi
done done
} }
# Find N free TCP ports inside [start_port, end_port) and print them
# comma-separated on stdout.
#   $1: number of ports wanted (default 1, must be > 0)
#   $2: lower bound of the search range (default 8000)
#   $3: upper bound of the search range (default 9000)
# Returns 1 on bad arguments or when the range cannot satisfy the request.
get_free_ports() {
    local free_ports_num=${1:-1}
    local start_port=${2:-8000}
    local end_port=${3:-9000}
    local free_ports=()
    if [[ -z ${free_ports_num} || "${free_ports_num}" -le 0 ]]; then
        log_warn "param can't be empty, and should > 0"
        echo "${free_ports[@]}"
        return 1
    fi
    # Collect every port netstat reports as in use, from both the local
    # address column ($4) and the peer address column ($5).
    local used_ports1 used_ports2 all_used_ports
    used_ports1=$(netstat -an | grep -E "(0.0.0.0|127.0.0.1|${POD_IP}|tcp6)" | awk '{n=split($4,a,":"); if(a[n]~/^[0-9]+$/) print a[n];}' | sort -u)
    used_ports2=$(netstat -an | grep -E "(0.0.0.0|127.0.0.1|${POD_IP}|tcp6)" | awk '{n=split($5,a,":"); if(a[n]~/^[0-9]+$/) print a[n];}' | sort -u)
    all_used_ports=$(printf "%s\n" "${used_ports1}" "${used_ports2}" | sort -u)
    # Start the scan at a random offset so concurrent callers are unlikely
    # to race for the same ports.
    local port=$(( RANDOM % (end_port - start_port + 1) + start_port ))
    # Never scan the range more than twice: bail out instead of spinning
    # forever when not enough free ports exist.
    local max_attempts=$(( 2 * (end_port - start_port + 1) ))
    local attempts=0
    while true; do
        (( port++ ))
        if [[ ${port} -ge ${end_port} ]]; then
            port=${start_port}
        fi
        if (( ++attempts > max_attempts )); then
            log_warn "could not find ${free_ports_num} more free port(s) in [${start_port}, ${end_port}]"
            return 1
        fi
        # Whole-line match (-x): a plain substring test would wrongly treat
        # e.g. 8001 as used when only 18001 is.
        if grep -qx "${port}" <<< "${all_used_ports}"; then
            continue
        fi
        # Skip ports already handed out in this call (possible after the
        # scan wraps around).
        if [[ " ${free_ports[*]} " == *" ${port} "* ]]; then
            continue
        fi
        if is_port_free ${port}; then
            free_ports+=("${port}")
            (( free_ports_num-- ))
            if [[ ${free_ports_num} = 0 ]]; then
                break
            fi
        fi
    done
    # Join with commas in a subshell so the caller's IFS is untouched.
    (IFS=','; echo "${free_ports[*]}")
    return 0
}
+12 -7
View File
@@ -41,7 +41,7 @@ from fastdeploy.inter_communicator import (
) )
from fastdeploy.utils import envs, get_logger from fastdeploy.utils import envs, get_logger
logger = get_logger("cache_messager", "cache_messager.log") # logger = get_logger("cache_messager", "cache_messager.log")
def parse_args(): def parse_args():
@@ -552,6 +552,7 @@ class CacheMessagerV1:
cache_info = self.engine_worker_queue.get_cache_info() cache_info = self.engine_worker_queue.get_cache_info()
finished_add_cache_task_req_ids = [] finished_add_cache_task_req_ids = []
if cache_info: if cache_info:
logger.debug(f"Get cache info from engine worker queue, {cache_info}")
self.engine_worker_queue.cache_info_barrier.wait() self.engine_worker_queue.cache_info_barrier.wait()
for info in cache_info: for info in cache_info:
if info["request_id"] in self.cache_info: if info["request_id"] in self.cache_info:
@@ -570,14 +571,15 @@ class CacheMessagerV1:
current_info["sended_layer_id"] = -1 current_info["sended_layer_id"] = -1
current_info["sended_block_num"] = current_info["decode_cached_tokens"] // self.block_size current_info["sended_block_num"] = current_info["decode_cached_tokens"] // self.block_size
current_info["status"] = "init" current_info["status"] = "init"
logger.info(f"Get cache info from P: finish add cache task: {current_info}") logger.info(f"Get cache info from D: finish add cache task: {current_info}")
self.cache_info[info["request_id"]] = current_info self.cache_info[info["request_id"]] = current_info
self.idx_cache_task_dict[current_info["current_id"]] = current_info self.idx_cache_task_dict[current_info["current_id"]] = current_info
else: else:
logger.info(f"Get cache info from D: {info}") logger.info(f"Get cache info from P: {info}")
self.cache_info[info["request_id"]] = info self.cache_info[info["request_id"]] = info
if finished_add_cache_task_req_ids: if finished_add_cache_task_req_ids:
logger.info(f"Put processed tasks into engine worker queue: {finished_add_cache_task_req_ids}")
self.engine_worker_queue.put_finished_add_cache_task_req(finished_add_cache_task_req_ids) self.engine_worker_queue.put_finished_add_cache_task_req(finished_add_cache_task_req_ids)
self.engine_worker_queue.finish_add_cache_task_barrier.wait() self.engine_worker_queue.finish_add_cache_task_barrier.wait()
else: else:
@@ -671,7 +673,7 @@ class CacheMessagerV1:
target_ip, target_id, decode_tp_size target_ip, target_id, decode_tp_size
) )
if status: if status:
logger.info(f"connect to {target_ip}:{target_id} success") logger.debug(f"connect to {target_ip}:{target_id} success")
else: else:
logger.error(f"connect to {target_ip}:{target_id} failed") logger.error(f"connect to {target_ip}:{target_id} failed")
task["status"] = "connection error" task["status"] = "connection error"
@@ -722,7 +724,7 @@ class CacheMessagerV1:
if "error" not in task["status"]: if "error" not in task["status"]:
task["status"] = "finished" task["status"] = "finished"
logger.info( logger.info(
f"finish write cache for all layers, req_id: {req_id}, block_id_end {block_id_end} need_prefill_tokens {task['need_prefill_tokens']}" f"Finish write cache for all layers, req_id: {req_id}, block_id_end {block_id_end} need_prefill_tokens {task['need_prefill_tokens']}"
) )
else: else:
task["sended_layer_id"] = -1 task["sended_layer_id"] = -1
@@ -736,7 +738,9 @@ class CacheMessagerV1:
self.messager["ipc"].write_block_by_sync(target_id) self.messager["ipc"].write_block_by_sync(target_id)
self.engine_worker_queue.finish_send_cache_barrier.wait() self.engine_worker_queue.finish_send_cache_barrier.wait()
self.engine_worker_queue.put_finished_req([[task["request_id"], task["status"]]]) self.engine_worker_queue.put_finished_req([[task["request_id"], task["status"]]])
logger.info(f"put write cache {task['request_id']}, status {task['status']}") logger.info(
f"Put successful cache writing task in engine worker queue, req_id: {task['request_id']}, status: {task['status']}"
)
self.engine_cache_tasks[task["current_id"]] = dict() self.engine_cache_tasks[task["current_id"]] = dict()
del self.cache_info[task["request_id"]] del self.cache_info[task["request_id"]]
del self.idx_cache_task_dict[task["current_id"]] del self.idx_cache_task_dict[task["current_id"]]
@@ -928,7 +932,8 @@ if __name__ == "__main__":
args = parse_args() args = parse_args()
rank_id = args.rank + args.local_data_parallel_id * args.mp_num rank_id = args.rank + args.local_data_parallel_id * args.mp_num
logger = get_logger("cache_messager", f"cache_messager_rank{rank_id}.log") logger = get_logger("cache_messager", f"cache_messager_tprank{args.rank}.log")
logger.info("create cache messager...") logger.info("create cache messager...")
logger.info(f"{args}") logger.info(f"{args}")
main() main()
@@ -740,6 +740,6 @@ if __name__ == "__main__":
args = parse_args() args = parse_args()
rank_id = args.rank + args.local_data_parallel_id * args.mp_num rank_id = args.rank + args.local_data_parallel_id * args.mp_num
logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_rank{rank_id}.log") logger = get_logger("cache_transfer_manager", f"cache_transfer_manager_tprank{args.rank}.log")
set_device(args.device_id) set_device(args.device_id)
main() main()
@@ -280,7 +280,7 @@ class PrefixCacheManager:
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}" + f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
+ f" --speculative_config '{self.speculative_config.to_json_string()}'" + f" --speculative_config '{self.speculative_config.to_json_string()}'"
+ (" --create_cache_tensor" if create_cache_tensor else "") + (" --create_cache_tensor" if create_cache_tensor else "")
+ f" >{log_dir}/launch_cache_transfer_manager_{int(device_ids[i])}.log 2>&1" + f" >{log_dir}/launch_cache_transfer_manager_tprank{i}.log 2>&1"
) )
logger.info(f"Launch cache transfer manager, command:{launch_cmd}") logger.info(f"Launch cache transfer manager, command:{launch_cmd}")
cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid)) cache_manager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
@@ -372,7 +372,7 @@ class PrefixCacheManager:
+ f" --engine_pid {pid_suffix}" + f" --engine_pid {pid_suffix}"
+ f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}" + f" --rdma_port {cache_config.rdma_comm_ports[i] if cache_config.rdma_comm_ports is not None else '0'}"
+ f" --speculative_config '{self.speculative_config.to_json_string()}'" + f" --speculative_config '{self.speculative_config.to_json_string()}'"
+ f" >{log_dir}/launch_cache_messager_{int(device_ids[i])}.log 2>&1" + f" >{log_dir}/launch_cache_messager_tprank{i}.log 2>&1"
) )
logger.info(f"Launch cache messager, command:{launch_cmd}") logger.info(f"Launch cache messager, command:{launch_cmd}")
cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid)) cache_messager_processes.append(subprocess.Popen(launch_cmd, shell=True, preexec_fn=os.setsid))
+6 -1
View File
@@ -545,6 +545,7 @@ class ParallelConfig:
self.tensor_parallel_size = 1 # TP degree self.tensor_parallel_size = 1 # TP degree
self.expert_parallel_rank = 0 # EP rank ID self.expert_parallel_rank = 0 # EP rank ID
self.expert_parallel_size = 1 # EP degree self.expert_parallel_size = 1 # EP degree
self.data_parallel_rank = 0 # DP rank ID
self.data_parallel_size = 1 # DP degree self.data_parallel_size = 1 # DP degree
self.enable_expert_parallel = False self.enable_expert_parallel = False
self.enable_chunked_moe = False self.enable_chunked_moe = False
@@ -1887,7 +1888,11 @@ class FDConfig:
engine_worker_queue_port = self.parallel_config.engine_worker_queue_port[ engine_worker_queue_port = self.parallel_config.engine_worker_queue_port[
self.parallel_config.local_data_parallel_id self.parallel_config.local_data_parallel_id
] ]
connector_port = self.cache_config.pd_comm_port[0] if self.cache_config.pd_comm_port else None connector_port = (
self.cache_config.pd_comm_port[self.parallel_config.local_data_parallel_id]
if self.cache_config.pd_comm_port
else None
)
self.disaggregate_info = {} self.disaggregate_info = {}
if self.scheduler_config.splitwise_role != "mixed": if self.scheduler_config.splitwise_role != "mixed":
+38 -16
View File
@@ -82,9 +82,9 @@ class EngineService:
self.cfg.cache_config.cache_queue_port[self.cfg.parallel_config.local_data_parallel_id] self.cfg.cache_config.cache_queue_port[self.cfg.parallel_config.local_data_parallel_id]
) )
if self.cfg.parallel_config.enable_expert_parallel: if self.cfg.parallel_config.data_parallel_size > 1:
self.llm_logger = get_logger( self.llm_logger = get_logger(
"fastdeploy", f"fastdeploy_rank{self.cfg.parallel_config.local_data_parallel_id}.log" "fastdeploy", f"fastdeploy_dprank{self.cfg.parallel_config.local_data_parallel_id}.log"
) )
else: else:
self.llm_logger = llm_logger self.llm_logger = llm_logger
@@ -716,7 +716,11 @@ class EngineService:
is_fetching = False is_fetching = False
return return
self.llm_logger.debug(f"get tasks from {type(self.scheduler)}: {tasks}") if tasks:
self.llm_logger.debug(
f"Engine has fetched tasks from {self.scheduler.__class__.__name__}: {[task.request_id for task in tasks]}"
)
if self.cfg.scheduler_config.splitwise_role != "mixed": if self.cfg.scheduler_config.splitwise_role != "mixed":
need_delete_tasks = [] need_delete_tasks = []
if envs.FD_OFFLINE_PERF_TEST_FOR_PD: if envs.FD_OFFLINE_PERF_TEST_FOR_PD:
@@ -724,22 +728,24 @@ class EngineService:
# assure can allocate block ids in P # assure can allocate block ids in P
while not self.resource_manager.preallocate_resource_in_p(task): while not self.resource_manager.preallocate_resource_in_p(task):
time.sleep(0.005) time.sleep(0.005)
self.llm_logger.info(f"ask D resource for req_id: {task.request_id}") self.llm_logger.debug(f"P has allocated resources for request: {task.request_id}")
while True: while True:
self.split_connector.send_splitwise_tasks([task], task.idx) self.split_connector.send_splitwise_tasks([task], task.idx)
status, msg = self.split_connector.check_decode_allocated(task) status, msg = self.split_connector.check_decode_allocated(task)
if not status: if not status:
self.llm_logger.error(f"{task.request_id} ask D resource failed, try again.") self.llm_logger.error(
f"D failed to allocate resource for request {task.request_id}, try again."
)
time.sleep(0.05) time.sleep(0.05)
else: else:
break break
self.llm_logger.debug(f"D has allocated resource for request: {task.request_id}")
else: else:
for task in tasks: for task in tasks:
# assure can allocate block ids in P # assure can allocate block ids in P
while not self.resource_manager.preallocate_resource_in_p(task): while not self.resource_manager.preallocate_resource_in_p(task):
self.llm_logger.info("wait for preallocate_resource_in_p")
time.sleep(0.005) time.sleep(0.005)
self.llm_logger.info(f"ask D resource for req_id: {task.request_id}") self.llm_logger.debug(f"P has allocated resources for request: {task.request_id}")
self.split_connector.send_splitwise_tasks([task], task.idx) self.split_connector.send_splitwise_tasks([task], task.idx)
for task in tasks: for task in tasks:
@@ -747,7 +753,9 @@ class EngineService:
# assure fetch block ids from D # assure fetch block ids from D
status, msg = self.split_connector.check_decode_allocated(task) status, msg = self.split_connector.check_decode_allocated(task)
if not status: if not status:
self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.") self.llm_logger.error(
f"D failed to allocate resource for request {task.request_id}, message: {msg}."
)
self.scheduler.put_results( self.scheduler.put_results(
[ [
RequestOutput( RequestOutput(
@@ -760,25 +768,32 @@ class EngineService:
) )
need_delete_tasks.append(task) need_delete_tasks.append(task)
continue continue
else:
self.llm_logger.debug(f"D has allocated resource for request: {task.request_id}")
for tmp_task in need_delete_tasks: for tmp_task in need_delete_tasks:
tasks.remove(tmp_task) tasks.remove(tmp_task)
# release resource in P # release resource in P
self.resource_manager.pre_recycle_resource(tmp_task.request_id) self.resource_manager.pre_recycle_resource(tmp_task.request_id)
if self.cfg.scheduler_config.splitwise_role == "prefill": if self.cfg.scheduler_config.splitwise_role == "prefill":
# to send cache info to cache messager # to send cache info to cache messager
if tasks: if tasks:
need_check_req_ids = [task.request_id for task in tasks]
self.split_connector.send_cache_info_to_messager(tasks, 0) self.split_connector.send_cache_info_to_messager(tasks, 0)
# ensure cache tasks has sent to cache_messager # ensure cache tasks has sent to cache_messager
need_check_req_ids = [task.request_id for task in tasks]
while need_check_req_ids: while need_check_req_ids:
req_ids = self.engine_worker_queue.get_finished_add_cache_task_req() req_ids = self.engine_worker_queue.get_finished_add_cache_task_req()
self.llm_logger.info(f"get_finished_add_cache_task_req: {req_ids}")
if req_ids: if req_ids:
self.llm_logger.debug(
f"P has successfully sent cache infos to cache messager for requests: {req_ids}"
)
for req_id in req_ids: for req_id in req_ids:
assert req_id in need_check_req_ids assert req_id in need_check_req_ids
need_check_req_ids.remove(req_id) need_check_req_ids.remove(req_id)
else: else:
time.sleep(0.001) time.sleep(0.001)
# Fetch requests and add them to the scheduling queue # Fetch requests and add them to the scheduling queue
if tasks: if tasks:
for task in tasks: for task in tasks:
@@ -787,6 +802,9 @@ class EngineService:
) )
if self.cfg.scheduler_config.splitwise_role == "prefill": if self.cfg.scheduler_config.splitwise_role == "prefill":
self.resource_manager.add_request_in_p(tasks) self.resource_manager.add_request_in_p(tasks)
self.llm_logger.info(
f"P add requests into running queue: {[task.request_id for task in tasks]}"
)
else: else:
for task in tasks: for task in tasks:
self.resource_manager.add_request(task) self.resource_manager.add_request(task)
@@ -917,7 +935,6 @@ class EngineService:
request.llm_engine_recv_req_timestamp = time.time() request.llm_engine_recv_req_timestamp = time.time()
start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER) start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER)
main_process_metrics.requests_number.inc() main_process_metrics.requests_number.inc()
self.llm_logger.debug(f"Receive request: {request}")
trace_print(LoggingEventName.PREPROCESSING_END, data["request_id"], data.get("user", "")) trace_print(LoggingEventName.PREPROCESSING_END, data["request_id"], data.get("user", ""))
trace_print(LoggingEventName.REQUEST_SCHEDULE_START, data["request_id"], data.get("user", "")) trace_print(LoggingEventName.REQUEST_SCHEDULE_START, data["request_id"], data.get("user", ""))
trace_print(LoggingEventName.REQUEST_QUEUE_START, data["request_id"], data.get("user", "")) trace_print(LoggingEventName.REQUEST_QUEUE_START, data["request_id"], data.get("user", ""))
@@ -1082,10 +1099,14 @@ class EngineService:
for item in items: for item in items:
tasks = item[1] tasks = item[1]
if isinstance(tasks[0], Request): if isinstance(tasks[0], Request):
self.llm_logger.debug(f"receive tasks to preallocate resource, {tasks}") self.llm_logger.debug(
f"D has received tasks to preallocate resource for tasks: {[task.request_id for task in tasks]}"
)
allocate_resource_requests.extend(tasks) allocate_resource_requests.extend(tasks)
elif isinstance(tasks[0], RequestOutput): elif isinstance(tasks[0], RequestOutput):
self.llm_logger.debug(f"receive prefilled tasks, {tasks}") self.llm_logger.debug(
f"D has received tasks to process prefilled tasks: {[task.request_id for task in tasks]}"
)
if not isinstance(tasks, list): if not isinstance(tasks, list):
tasks = [tasks] tasks = [tasks]
for task in tasks: for task in tasks:
@@ -1099,13 +1120,13 @@ class EngineService:
if envs.ENABLE_V1_KVCACHE_SCHEDULER: if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if self.resource_manager.preallocate_resource_in_d(task): if self.resource_manager.preallocate_resource_in_d(task):
self.llm_logger.info(f"Resource available, processing task {task.request_id}")
self.split_connector.send_cache_info_to_prefill([task]) self.split_connector.send_cache_info_to_prefill([task])
self.llm_logger.debug(f"D has successfully sent cache infos for task {task.request_id}")
processed_indices.append(idx) processed_indices.append(idx)
is_success = True is_success = True
else: else:
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len): if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
self.llm_logger.info(f"Resource available, processing task {task.request_id}") self.llm_logger.debug(f"D Resource available, processing task {task.request_id}")
self.insert_tasks([task]) self.insert_tasks([task])
processed_indices.append(idx) processed_indices.append(idx)
is_success = True is_success = True
@@ -1114,6 +1135,7 @@ class EngineService:
if not self.enable_decode_cache_task: if not self.enable_decode_cache_task:
task.error_msg = "Not enough resources" task.error_msg = "Not enough resources"
self.split_connector.send_cache_info_to_prefill([task]) self.split_connector.send_cache_info_to_prefill([task])
self.llm_logger.warning(f"D has failed to send cache infos for task {task.request_id}")
processed_indices.append(idx) processed_indices.append(idx)
else: else:
self.llm_logger.debug(f"Still waiting for resources {task.request_id}") self.llm_logger.debug(f"Still waiting for resources {task.request_id}")
@@ -1169,7 +1191,7 @@ class EngineService:
if envs.FD_ENABLE_INTERNAL_ADAPTER: # first token sent by D instance if envs.FD_ENABLE_INTERNAL_ADAPTER: # first token sent by D instance
self.scheduler.put_results([req_output]) self.scheduler.put_results([req_output])
self.resource_manager.add_prefilled_request(req_output) self.resource_manager.add_prefilled_request(req_output)
self.llm_logger.debug(f"add prefilled request success, {request_id}") self.llm_logger.info(f"D has successfully added prefilled request, {request_id}")
def decode_loop(): def decode_loop():
while self.running: while self.running:
+12 -12
View File
@@ -61,7 +61,6 @@ class ExpertService:
] ]
self.cfg.local_device_ids = self.cfg.parallel_config.device_ids.split(",")[start_pos:end_pos] self.cfg.local_device_ids = self.cfg.parallel_config.device_ids.split(",")[start_pos:end_pos]
llm_logger.info(f"local_data_parallel_id: {local_data_parallel_id}") llm_logger.info(f"local_data_parallel_id: {local_data_parallel_id}")
self.cfg.disaggregate_info = None
if self.cfg.cache_config.num_gpu_blocks_override is None: if self.cfg.cache_config.num_gpu_blocks_override is None:
self.do_profile = True self.do_profile = True
@@ -127,17 +126,18 @@ class ExpertService:
local_rank = local_data_parallel_id % self.cfg.worker_num_per_node local_rank = local_data_parallel_id % self.cfg.worker_num_per_node
if not envs.FD_ENABLE_MULTI_API_SERVER: if not envs.FD_ENABLE_MULTI_API_SERVER:
launched_expert_service_signal_data = np.zeros( if self.cfg.parallel_config.enable_expert_parallel:
shape=[self.cfg.parallel_config.data_parallel_size // self.cfg.nnode], dtype=np.int32 launched_expert_service_signal_data = np.zeros(
) shape=[self.cfg.parallel_config.data_parallel_size // self.cfg.nnode], dtype=np.int32
self.launched_expert_service_signal = IPCSignal( )
name="launched_expert_service_signal", self.launched_expert_service_signal = IPCSignal(
array=launched_expert_service_signal_data, name="launched_expert_service_signal",
dtype=np.int32, array=launched_expert_service_signal_data,
suffix=ipc_signal_suffix, dtype=np.int32,
create=False, suffix=ipc_signal_suffix,
) create=False,
self.launched_expert_service_signal.value[local_rank] = 1 )
self.launched_expert_service_signal.value[local_rank] = 1
if self.do_profile: if self.do_profile:
get_profile_block_num = np.zeros([1], dtype=np.int32) get_profile_block_num = np.zeros([1], dtype=np.int32)
while True: while True:
@@ -496,6 +496,8 @@ class EngineWorkerQueue:
self.tasks.append(tasks) self.tasks.append(tasks)
self.lock.release() self.lock.release()
llm_logger.debug(f"put_tasks: tasks={tasks}")
def get_tasks(self) -> Tuple[List[Any], bool]: def get_tasks(self) -> Tuple[List[Any], bool]:
""" """
Retrieve tasks from the shared queue and update read status. Retrieve tasks from the shared queue and update read status.
@@ -512,6 +514,7 @@ class EngineWorkerQueue:
if all_client_read: if all_client_read:
self.tasks[:] = list() self.tasks[:] = list()
self.lock.release() self.lock.release()
llm_logger.debug(f"get_tasks: tasks={tasks}")
return tasks, all_client_read return tasks, all_client_read
def num_tasks(self) -> int: def num_tasks(self) -> int:
@@ -600,8 +603,7 @@ class EngineWorkerQueue:
self.cache_infos.extend(cache_info) self.cache_infos.extend(cache_info)
llm_logger.debug( llm_logger.debug(
f"put cache_infos to engine worker queue: {self.cache_infos}, " f"put_cache_info: cache_info={cache_info}, local_data_parallel_id={self.local_data_parallel_id}"
f"local_data_parallel_id:{self.local_data_parallel_id}"
) )
self.lock_info.release() self.lock_info.release()
@@ -214,6 +214,9 @@ class ZmqServerBase(ABC):
except zmq.Again: except zmq.Again:
time.sleep(0.001) time.sleep(0.001)
continue continue
except zmq.error.ZMQError as e:
llm_logger.error(f"recv_result_handle get zmq error: {e}")
break
except Exception as e: except Exception as e:
llm_logger.error(f"recv_result_handle get unknown exception: {e}") llm_logger.error(f"recv_result_handle get unknown exception: {e}")
continue continue
+1 -5
View File
@@ -402,12 +402,8 @@ class TokenProcessor:
rank_id, rank_id,
is_blocking, is_blocking,
) )
elif ( elif self.cfg.parallel_config.data_parallel_size > 1:
self.cfg.parallel_config.enable_expert_parallel
and self.cfg.parallel_config.data_parallel_size > 1
):
get_output_ep(self.output_tokens, rank_id, is_blocking) get_output_ep(self.output_tokens, rank_id, is_blocking)
else: else:
get_output(self.output_tokens, rank_id, is_blocking) get_output(self.output_tokens, rank_id, is_blocking)
+4 -1
View File
@@ -91,6 +91,7 @@ class Router:
self.prefill_servers = [] self.prefill_servers = []
self.decode_servers = [] self.decode_servers = []
self.lock = asyncio.Lock() # async-safe lock self.lock = asyncio.Lock() # async-safe lock
logger.info("Router started at http://{}:{}".format(self.host, self.port))
async def register_instance(self, instance_info_dict: dict): async def register_instance(self, instance_info_dict: dict):
"""Register an instance asynchronously""" """Register an instance asynchronously"""
@@ -172,6 +173,8 @@ class Router:
async def handle_splitwise_request(self, request_data: dict, endpoint_name: str): async def handle_splitwise_request(self, request_data: dict, endpoint_name: str):
logger.debug(f"Received request: {request_data}") logger.debug(f"Received request: {request_data}")
prefill_server, decode_server = await self.select_pd() prefill_server, decode_server = await self.select_pd()
logger.debug(f"Selected prefill server: {prefill_server}")
logger.debug(f"Selected decode server: {decode_server}")
if prefill_server.tp_size != decode_server.tp_size and decode_server.tp_size != 1: if prefill_server.tp_size != decode_server.tp_size and decode_server.tp_size != 1:
raise HTTPException( raise HTTPException(
@@ -371,4 +374,4 @@ def launch_router(router_args: RouterArgs):
app.state.router = Router(app.state.router_args) app.state.router = Router(app.state.router_args)
asyncio.create_task(app.state.router.monitor_instance_health(interval_secs=5)) asyncio.create_task(app.state.router.monitor_instance_health(interval_secs=5))
uvicorn.run(app, host=router_args.host, port=router_args.port) uvicorn.run(app, host=router_args.host, port=int(router_args.port))
+30 -36
View File
@@ -44,9 +44,10 @@ class SplitwiseConnector:
resource_manager (object): Resource manager object. resource_manager (object): Resource manager object.
""" """
self.cfg = cfg self.cfg = cfg
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1: self.local_data_parallel_id = self.cfg.parallel_config.local_data_parallel_id
if self.cfg.parallel_config.data_parallel_size > 1:
self.logger = get_logger( self.logger = get_logger(
"splitwise_connector", f"splitwise_connector_{self.cfg.parallel_config.local_data_parallel_id}.log" "splitwise_connector", f"splitwise_connector_dprank{self.local_data_parallel_id}.log"
) )
else: else:
self.logger = get_logger("splitwise_connector", "splitwise_connector.log") self.logger = get_logger("splitwise_connector", "splitwise_connector.log")
@@ -54,7 +55,6 @@ class SplitwiseConnector:
self.resource_manager = resource_manager self.resource_manager = resource_manager
self.connect_innode_instances = {} self.connect_innode_instances = {}
self.current_request_ids = dict() self.current_request_ids = dict()
self.idx = self.cfg.parallel_config.local_data_parallel_id
self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1" self.enable_decode_cache_task = envs.FD_ENABLE_CACHE_TASK == "1"
if self.cfg.cache_config.pd_comm_port is not None: if self.cfg.cache_config.pd_comm_port is not None:
@@ -74,7 +74,7 @@ class SplitwiseConnector:
self.router_socket.setsockopt(zmq.SNDHWM, 1000) self.router_socket.setsockopt(zmq.SNDHWM, 1000)
self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1) self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1)
self.router_socket.bind(f"tcp://*:{self.cfg.cache_config.pd_comm_port[0]}") self.router_socket.bind(f"tcp://*:{self.cfg.cache_config.pd_comm_port[0]}")
self.logger.info(f"bind {self.cfg.cache_config.pd_comm_port}") self.logger.info(f"_init_network: bind {self.cfg.cache_config.pd_comm_port}")
self.poller = zmq.Poller() self.poller = zmq.Poller()
self.poller.register(self.router_socket, zmq.POLLIN) self.poller.register(self.router_socket, zmq.POLLIN)
@@ -94,17 +94,17 @@ class SplitwiseConnector:
if not socks: if not socks:
continue continue
else: else:
self.logger.debug(f"receive {socks}") self.logger.debug(f"start_receiver: receive {socks}")
frames = self.router_socket.recv_multipart() frames = self.router_socket.recv_multipart()
self.logger.debug(f"frames: {frames}") self.logger.debug(f"start_receiver: frames: {frames}")
message = frames[-1] message = frames[-1]
self.io_executor.submit(self._process_message, message) self.io_executor.submit(self._process_message, message)
time.sleep(0.001) time.sleep(0.001)
else: else:
time.sleep(5) time.sleep(5)
except Exception as e: except Exception as e:
self.logger.error(f"Receiver error: {e}, {str(traceback.format_exc())}") self.logger.error(f"start_receiver: Receiver error: {e}, {str(traceback.format_exc())}")
time.sleep(1) time.sleep(1)
def _get_push_socket(self, addr): def _get_push_socket(self, addr):
@@ -116,7 +116,7 @@ class SplitwiseConnector:
return sock return sock
try: try:
self.logger.info(f"Establishing new connection to {addr}") self.logger.info(f"_get_push_socket: Establishing new connection to {addr}")
sock = self.zmq_ctx.socket(zmq.DEALER) sock = self.zmq_ctx.socket(zmq.DEALER)
# 设置连接参数 # 设置连接参数
@@ -135,36 +135,29 @@ class SplitwiseConnector:
return sock return sock
except zmq.ZMQError as e: except zmq.ZMQError as e:
self.logger.error(f"Connection to {addr} failed: {e}") self.logger.error(f"_get_push_socket: Connection to {addr} failed: {e}")
raise ConnectionError(f"Failed to connect to {addr}") from e raise ConnectionError(f"Failed to connect to {addr}") from e
def _send_message(self, addr, msg_type: str, payload): def _send_message(self, addr, msg_type: str, payload):
if not addr: if not addr:
return return
try: try:
self.logger.info(f"Sent {msg_type} to {addr}")
message = self._serialize_message(msg_type, payload) message = self._serialize_message(msg_type, payload)
try: try:
self.logger.info(f"_send_message: msg_type={msg_type} addr={addr}")
sock = self._get_push_socket(addr) sock = self._get_push_socket(addr)
sock.send_multipart([b"", message]) sock.send_multipart([b"", message])
self.logger.info(f"Sent {msg_type} to {addr}")
except ConnectionError: except ConnectionError:
self.logger.warning(f"Connection to {addr} not established") self.logger.warning(f"_send_message: Connection to {addr} not established")
except zmq.Again: except zmq.Again:
self.logger.warning(f"Send queue full for {addr}") self.logger.warning(f"_send_message: Send queue full for {addr}")
except Exception as e: except Exception as e:
self.logger.error(f"Send to {addr} failed: {e}, {str(traceback.format_exc())}") self.logger.error(f"_send_message: Send to {addr} failed: {e}, {str(traceback.format_exc())}")
main_process_metrics.send_cache_failed_num.inc() main_process_metrics.send_cache_failed_num.inc()
self._close_connection(addr) self._close_connection(addr)
except Exception as e: except Exception as e:
self.logger.error(f"Message preparation failed: {e}") self.logger.error(f"_send_message: Message preparation failed: {e}")
def _close_connection(self, addr): def _close_connection(self, addr):
""" """
@@ -191,21 +184,20 @@ class SplitwiseConnector:
if task.disaggregate_info["transfer_protocol"] == "ipc": if task.disaggregate_info["transfer_protocol"] == "ipc":
addr = task.disaggregate_info["cache_info"]["ipc"]["port"] addr = task.disaggregate_info["cache_info"]["ipc"]["port"]
task.disaggregate_info["cache_info"]["ipc"]["current_id"] = current_id task.disaggregate_info["cache_info"]["ipc"]["current_id"] = current_id
self.logger.info(f"send_splitwise_tasks: protocol=ipc, addr={addr}, task={task.request_id}")
self.send_splitwise_tasks_innode([task], addr) self.send_splitwise_tasks_innode([task], addr)
else: else:
addr = ( addr = (
f"{task.disaggregate_info['cache_info']['rdma']['ip']}:" f"{task.disaggregate_info['cache_info']['rdma']['ip']}:"
+ f"{task.disaggregate_info['cache_info']['rdma']['port']}" + f"{task.disaggregate_info['cache_info']['rdma']['port']}"
) )
self.logger.info(f"send splitwise tasks to port {addr} decode, {task.request_id}")
self.current_request_ids[task.request_id] = "init" self.current_request_ids[task.request_id] = "init"
decode_diagg = task.disaggregate_info["cache_info"] decode_diagg = task.disaggregate_info["cache_info"]
task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"] task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
task.disaggregate_info["cache_info"]["rdma"]["current_id"] = current_id task.disaggregate_info["cache_info"]["rdma"]["current_id"] = current_id
task.disaggregate_info["role"] = "decode" task.disaggregate_info["role"] = "decode"
self.logger.debug(f"send task to coupled instance, {addr}, {task}") self.logger.info(f"send_splitwise_tasks: protocol=rdma, addr={addr}, task={task.request_id}")
self._send_message(addr, "prefill", [task]) self._send_message(addr, "prefill", [task])
task.disaggregate_info["cache_info"] = decode_diagg task.disaggregate_info["cache_info"] = decode_diagg
task.disaggregate_info["role"] = "prefill" task.disaggregate_info["role"] = "prefill"
@@ -226,12 +218,12 @@ class SplitwiseConnector:
self.create_connection(port) self.create_connection(port)
for task in tasks: for task in tasks:
task.disaggregate_info["cache_info"]["ipc"]["port"] = self.cfg.parallel_config.engine_worker_queue_port[ task.disaggregate_info["cache_info"]["ipc"]["port"] = self.cfg.parallel_config.engine_worker_queue_port[
self.idx self.local_data_parallel_id
] ]
self.logger.info(f"send_splitwise_tasks_innode: port={port}, tasks={[task.request_id for task in tasks]}")
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks)) self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks))
for task in tasks: for task in tasks:
task.disaggregate_info["cache_info"]["ipc"]["port"] = port task.disaggregate_info["cache_info"]["ipc"]["port"] = port
self.logger.info(f"send splitwise tasks to port {port} decode")
current_port = port current_port = port
return current_port return current_port
@@ -241,7 +233,7 @@ class SplitwiseConnector:
""" """
if not isinstance(tasks_list, list): if not isinstance(tasks_list, list):
tasks_list = [tasks_list] tasks_list = [tasks_list]
self.logger.info(f"send first token to decode, {[x.request_id for x in tasks_list]}") self.logger.info(f"send_first_token: send first token to decode, {[x.request_id for x in tasks_list]}")
if prefill_msg["transfer_protocol"] == "ipc": if prefill_msg["transfer_protocol"] == "ipc":
port = prefill_msg["cache_info"]["ipc"]["port"] port = prefill_msg["cache_info"]["ipc"]["port"]
if port not in self.connect_innode_instances: if port not in self.connect_innode_instances:
@@ -249,7 +241,7 @@ class SplitwiseConnector:
self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks_list)) self.connect_innode_instances[port].put_disaggregated_tasks(("decode", tasks_list))
else: else:
node = f"{prefill_msg['cache_info']['rdma']['ip']}:{prefill_msg['cache_info']['rdma']['port']}" node = f"{prefill_msg['cache_info']['rdma']['ip']}:{prefill_msg['cache_info']['rdma']['port']}"
self.logger.info(f"send first token to port {node} decode") self.logger.info(f"send_first_token: send first token to port {node} decode")
self._send_message(node, "decode", tasks_list) self._send_message(node, "decode", tasks_list)
def create_connection(self, port): def create_connection(self, port):
@@ -288,7 +280,7 @@ class SplitwiseConnector:
del self.current_request_ids[task.request_id] del self.current_request_ids[task.request_id]
if msg == "finished": if msg == "finished":
return True, "" return True, ""
self.logger.error(f"Receive_decode_allocated error: {msg}") self.logger.error(f"check_decode_allocated: Receive_decode_allocated error: {msg}")
return False, msg return False, msg
def send_cache_info_to_messager(self, tasks: List[Request], current_id): def send_cache_info_to_messager(self, tasks: List[Request], current_id):
@@ -359,9 +351,11 @@ class SplitwiseConnector:
else: else:
info = { info = {
"request_id": tasks[i].request_id, "request_id": tasks[i].request_id,
"device_ids": self.cfg.parallel_config.device_ids.split(","), "device_ids": [self.cfg.parallel_config.device_ids.split(",")[self.local_data_parallel_id]],
"ip": self.cfg.host_ip, "ip": self.cfg.host_ip,
"rdma_ports": self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"], "rdma_ports": [
self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"][self.local_data_parallel_id]
],
"transfer_protocol": "rdma", "transfer_protocol": "rdma",
"dest_block_ids": dsg_info["block_tables"], "dest_block_ids": dsg_info["block_tables"],
"decode_tp_size": self.cfg.parallel_config.tensor_parallel_size, "decode_tp_size": self.cfg.parallel_config.tensor_parallel_size,
@@ -404,7 +398,7 @@ class SplitwiseConnector:
""" """
try: try:
msg_type, payload = self._deserialize_message(message) msg_type, payload = self._deserialize_message(message)
self.logger.info(f"{msg_type}") self.logger.info(f"_process_message: {msg_type}")
if msg_type == "prefill": if msg_type == "prefill":
self._handle_prefill(payload) self._handle_prefill(payload)
@@ -412,7 +406,7 @@ class SplitwiseConnector:
self._handle_decode(payload) self._handle_decode(payload)
elif msg_type == "cache_sync": elif msg_type == "cache_sync":
for task in payload: for task in payload:
self.logger.info(f"cache_sync task: {task}") self.logger.info(f"_process_message: cache_sync task: {task}")
current_status = task.get("error_msg", "finished") current_status = task.get("error_msg", "finished")
self.current_request_ids[task["request_id"]] = current_status self.current_request_ids[task["request_id"]] = current_status
if self.enable_decode_cache_task: if self.enable_decode_cache_task:
@@ -421,13 +415,13 @@ class SplitwiseConnector:
self.engine_worker_queue.put_cache_info(payload) self.engine_worker_queue.put_cache_info(payload)
except Exception as e: except Exception as e:
self.logger.error(f"Message processing failed: {e}, {str(traceback.format_exc())}") self.logger.error(f"_process_message: Message processing failed: {e}, {str(traceback.format_exc())}")
def _handle_prefill(self, tasks): def _handle_prefill(self, tasks):
""" """
Handle prefill tasks from other nodes. Handle prefill tasks from other nodes.
""" """
self.logger.debug(f"_handle_prefill function receive {tasks}") self.logger.debug(f"_handle_prefill: receive payload {tasks}")
tasks_data = [Request.from_dict(task) for task in tasks] tasks_data = [Request.from_dict(task) for task in tasks]
self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks_data)) self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks_data))
@@ -435,7 +429,7 @@ class SplitwiseConnector:
""" """
Handle decode tasks from other nodes. Handle decode tasks from other nodes.
""" """
self.logger.debug(f"_handle_decode function receive {payload}") self.logger.debug(f"_handle_decode: receive payload {payload}")
tasks = [] tasks = []
for task in payload: for task in payload:
tasks.append(RequestOutput.from_dict(task)) tasks.append(RequestOutput.from_dict(task))
+11 -6
View File
@@ -173,7 +173,11 @@ class PaddleDisWorkerProc:
model_weights_status: model_weights_status:
""" """
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8 self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
if self.parallel_config.data_parallel_size > 1 and not envs.FD_ENABLE_MULTI_API_SERVER: if (
self.parallel_config.enable_expert_parallel
and self.parallel_config.data_parallel_size > 1
and not envs.FD_ENABLE_MULTI_API_SERVER
):
launched_expert_service_signal_data = np.zeros( launched_expert_service_signal_data = np.zeros(
shape=[self.parallel_config.data_parallel_size // self.fd_config.nnode], dtype=np.int32 shape=[self.parallel_config.data_parallel_size // self.fd_config.nnode], dtype=np.int32
) )
@@ -905,6 +909,12 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
parallel_config.tensor_parallel_rank = local_rank % parallel_config.tensor_parallel_size parallel_config.tensor_parallel_rank = local_rank % parallel_config.tensor_parallel_size
parallel_config.data_parallel_rank = local_rank // parallel_config.tensor_parallel_size parallel_config.data_parallel_rank = local_rank // parallel_config.tensor_parallel_size
# config for DP
if parallel_config.data_parallel_size > 1:
max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
parallel_config.local_data_parallel_id = parallel_config.data_parallel_rank % (
max_chips_per_node // parallel_config.tensor_parallel_size
)
# config for EP # config for EP
if parallel_config.expert_parallel_size > 1: if parallel_config.expert_parallel_size > 1:
expert_parallel_rank = int(local_rank % parallel_config.expert_parallel_size) expert_parallel_rank = int(local_rank % parallel_config.expert_parallel_size)
@@ -914,11 +924,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
num_experts = model_config.moe_num_experts num_experts = model_config.moe_num_experts
num_experts_per_rank = num_experts // parallel_config.expert_parallel_size num_experts_per_rank = num_experts // parallel_config.expert_parallel_size
num_experts_start_offset = expert_parallel_rank * num_experts_per_rank num_experts_start_offset = expert_parallel_rank * num_experts_per_rank
max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
parallel_config.local_data_parallel_id = parallel_config.data_parallel_rank % (
max_chips_per_node // parallel_config.tensor_parallel_size
)
parallel_config.expert_parallel_rank = expert_parallel_rank parallel_config.expert_parallel_rank = expert_parallel_rank
parallel_config.num_experts_per_rank = num_experts_per_rank parallel_config.num_experts_per_rank = num_experts_per_rank
parallel_config.num_experts_start_offset = num_experts_start_offset parallel_config.num_experts_start_offset = num_experts_start_offset