[Feature] Add Golang-based Router for Request Scheduling and Load Balancing (#5882)

* [Feature] add golang router

* [Feature] add golang router

* [Feature] add golang router

* [Feature] add golang router

* [Feature] add golang router

* [Feature] Add Golang-based Router for Request Scheduling and Load Balancing

* [Feature] Add Golang-based Router for Request Scheduling and Load Balancing

* [Feature] Add Golang-based Router for Request Scheduling and Load Balancing

* [Feature] Add Golang-based Router for Request Scheduling and Load Balancing

---------

Co-authored-by: mouxin <mouxin@baidu.com>
This commit is contained in:
mouxin
2026-01-07 21:28:08 +08:00
committed by GitHub
parent 925e7edd3c
commit 0a92e96f20
50 changed files with 6298 additions and 0 deletions
+107
View File
@@ -0,0 +1,107 @@
#!/bin/bash
set -e
# Test splitwise deployment
# There are two methods for splitwise deployment:
# v0: using splitwise_scheduler or dp_scheduler
# v1: using local_scheduler + router
# v2: using local_scheduler + golang_router

# prepare environment
export MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle"
export FD_DEBUG=1
SCRIPT_PATH=$(readlink -f "$0")
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
source ${SCRIPT_DIR}/utils.sh
unset http_proxy && unset https_proxy
P_PORT=52400
D_PORT=52500
ROUTER_PORT=52700
LOG_DATE=$(date +%Y%m%d_%H%M%S)
FD_BIN_DIR="/usr/local/bin"
FD_ROUTER_BIN="${FD_BIN_DIR}/fd-router"
FD_ROUTER_URL="https://paddle-qa.bj.bcebos.com/FastDeploy/fd-router"
FD_ROUTER_SHA256="67640aaeebdd886826d3534930b2154cd2c1441a26bc3f38c3af5f0aadba7c2d"
ports=($P_PORT $D_PORT $ROUTER_PORT)
check_ports "${ports[@]}" || {
    echo "❌ Some ports are in use. Please release them."
    exit 1
}

# check fd-router binary: download to a temp path and verify the checksum
# before installing, so a truncated download never lands as the real binary
if [ ! -x "${FD_ROUTER_BIN}" ]; then
    echo "⚠️ fd-router not found, downloading..."
    mkdir -p "${FD_BIN_DIR}"
    TMP_BIN="${FD_ROUTER_BIN}.tmp"
    wget -q --no-proxy "${FD_ROUTER_URL}" -O "${TMP_BIN}" || exit 1
    echo "${FD_ROUTER_SHA256} ${TMP_BIN}" | sha256sum -c - || {
        echo "❌ Integrity check failed"
        rm -f "${TMP_BIN}"
        exit 1
    }
    mv "${TMP_BIN}" "${FD_ROUTER_BIN}"
    chmod +x "${FD_ROUTER_BIN}"
    echo "fd-router installed and verified"
else
    echo "fd-router already exists"
fi

# start router
# NOTE: redirect stdout to the log FIRST, then duplicate stderr onto it.
# The previous "2>&1 >file" form left stderr on the terminal instead of the log.
export FD_LOG_DIR="log/$LOG_DATE/router"
rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR}
nohup "${FD_ROUTER_BIN}" \
    --port ${ROUTER_PORT} \
    --splitwise \
    >${FD_LOG_DIR}/nohup 2>&1 &

# start prefill
export CUDA_VISIBLE_DEVICES=0
export FD_LOG_DIR="log/$LOG_DATE/prefill"
rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR}
nohup python -m fastdeploy.entrypoints.openai.api_server \
    --model ${MODEL_NAME} \
    --port "${P_PORT}" \
    --splitwise-role "prefill" \
    --router "0.0.0.0:${ROUTER_PORT}" \
    >${FD_LOG_DIR}/nohup 2>&1 &
wait_for_health ${P_PORT}

# start decode
export CUDA_VISIBLE_DEVICES=1
export FD_LOG_DIR="log/$LOG_DATE/decode"
rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR}
nohup python -m fastdeploy.entrypoints.openai.api_server \
    --model ${MODEL_NAME} \
    --port "${D_PORT}" \
    --splitwise-role "decode" \
    --router "0.0.0.0:${ROUTER_PORT}" \
    >${FD_LOG_DIR}/nohup 2>&1 &
wait_for_health ${D_PORT}

# send request
sleep 10 # make sure server is registered to router
echo "send request..."
curl -X POST "http://0.0.0.0:${ROUTER_PORT}/v1/chat/completions" \
    -H "Content-Type: application/json" \
    -d '{
"messages": [
{"role": "user", "content": "hello"}
],
"max_tokens": 100,
"stream": false
}'
+31
View File
@@ -0,0 +1,31 @@
# Build management for the fd-router Go binary.
HOMEDIR := $(CURDIR)
OUTDIR := /usr/local/bin
export GOENV = $(HOMEDIR)/go.env
# Deferred to the shell at recipe time ($$ -> $) so the package list is
# computed only when `make test` actually runs.
GOPKGS := $$(go list ./...| grep -vE "vendor")
# make, make all
all: prepare compile
prepare:
	git version # git older than 2.17.1 may not work correctly
	go env
	go mod download || go mod download -x # download module dependencies
#make compile
compile: build
build:
	go build -o $(OUTDIR)/fd-router ./cmd
# make test, test your code
test: prepare
	go test -race -timeout=120s -v -cover $(GOPKGS) -coverprofile=coverage.out | tee unittest.txt
# make clean
clean:
	go clean
	# Remove only our installed binary. OUTDIR defaults to /usr/local/bin,
	# so `rm -rf $(OUTDIR)` would wipe the system bin directory.
	$(RM) $(OUTDIR)/fd-router
# avoid filename conflict and speed up build
.PHONY: all prepare compile test package clean build
+184
View File
@@ -0,0 +1,184 @@
# Golang-Router
## 关于
【正在开发迭代中】
Golang-Router 是一个面向大语言模型推理系统的高性能 Golang 路由框架,作为系统的**控制与调度平面**运行,负责请求接入、实例选择与流量转发,设计上适配 Prefill–Decode(PD)分离推理架构。
Golang-Router 可独立部署运行,也可通过 HTTP 接口与 FastDeploy 推理实例协同工作。框架提供基础而稳定的路由、中间件扩展与健康检查能力,适用于单点推理部署场景,并在架构层面为后续的水平扩展与调度能力演进预留空间。
### 背景与动机
在大语言模型推理系统中,路由组件已从传统的流量转发层演进为影响系统性能与资源利用效率的关键基础设施。随着 Prefill–Decode 分离推理架构的广泛采用,不同推理阶段在计算特征、显存占用与缓存行为方面呈现出明显差异,仅依赖请求级静态信息进行调度已难以满足稳定性与效率需求。
在保持请求级调度模型不变的前提下,引入更细粒度的运行时信号辅助调度决策,成为提升调度能力与系统可预测性的工程共识。Golang-Router 正是在这一背景下构建,作为独立的路由与调度组件,为推理系统提供清晰、可扩展的控制平面。
### 设计目标
Golang-Router 聚焦解决以下核心问题:
- **调度决策信息不足**
传统 Router 通常仅基于请求级元信息或粗粒度实例状态进行调度,难以利用推理过程中产生的细粒度缓存相关信号,从而限制了 cache-aware 策略的实际效果。
- **调度逻辑与推理执行强耦合**
路由与调度逻辑内嵌于推理框架内部,增加了系统复杂度,限制了调度策略的独立演进与复用能力。
- **高并发场景下的可扩展性挑战**
在高并发推理负载下,实例状态维护与实例选择逻辑对路由组件的并发模型、性能与稳定性提出更高要求。
### 核心特性
- 基于 Golang 实现的高性能路由与调度组件,适用于高并发、低延迟推理场景
- 请求级调度模型,保持接口语义清晰与系统复杂度可控
- 利用token级缓存相关运行时信息作为调度策略的辅助输入,用于提升实例选择的准确性与稳定性
- 模块化架构设计(Gateway / Scheduler / Manager),职责边界清晰,便于扩展与维护
- 面向 Prefill–Decode 分离推理架构设计,为复杂调度策略与能力演进提供结构性支持
### 与现有方案的差异
与 sglang 等推理框架内置 Router 相比,Golang-Router 以**独立 Golang 服务**的形式运行,将路由、调度与实例状态管理能力从推理执行逻辑中解耦。
Golang-Router 已支持 cache-aware 调度,在请求级调度框架内引入 token 级缓存相关运行时信号,辅助调度决策制定,以更稳定地适配 Prefill–Decode 分离推理架构下的缓存利用需求。
## 功能特性
- 高性能 HTTP/HTTPS 服务器
- RESTful API 路由支持
- 可扩展的中间件系统
- 动态配置管理
- 内置健康检查和监控
- 负载均衡
- 日志记录和指标收集
## 快速开始
### 前置要求
- Go 1.21
- 构建不依赖特定系统环境
- 可直接在 FastDeploy 官方 Docker 环境中编译与运行
### 编译
```bash
./build.sh
```
### 配置
1. 配置文件准备(可选)
如需修改默认配置,可复制配置模板并进行调整(示例可参考 examples/run_with_config):
```bash
cp config/config.example.yaml config/config.yaml
```
2. 主要配置项说明:
```yaml
server:
port: "8080" # 监听端口
host: "0.0.0.0" # 监听地址
mode: "debug" # 启动模式: debug, release, test
splitwise: true # true代表开启pd分离模式,false代表开启非pd分离模式
scheduler:
policy: "power_of_two" # 调度策略(可选): random, power_of_two, round_robin, process_tokens, request_num
prefill-policy: "cache_aware" # pd分离模式下prefill节点调度策略
decode-policy: "fd_metrics_score" # pd分离模式下decode节点调度策略
eviction-interval-secs: 60 # cache-aware策略清理过期cache的间隔时间
balance-abs-threshold: 1 # cache-aware策略绝对阈值
balance-rel-threshold: 0.2 # cache-aware策略相对阈值
hit-ratio-weight: 1.0 # cache-aware策略命中率权重
load-balance-weight: 0.05 # cache-aware策略负载均衡权重
cache-block-size: 4 # cache-aware策略cache block大小
tokenizer-url: "http://0.0.0.0:8098" # tokenizer服务地址(可选)
tokenizer-timeout-secs: 2 # tokenizer服务超时时间
waiting-weight: 10 # cache-aware策略等待权重
manager:
health-failure-threshold: 3 # 健康检查失败次数,超过次数后认为节点不健康
health-success-threshold: 2 # 健康检查成功次数,超过次数后认为节点健康
health-check-timeout-secs: 5 # 健康检查超时时间
health-check-interval-secs: 5 # 健康检查间隔时间
health-check-endpoint: /health # 健康检查接口
register-path: "config/register.yaml" # 推理实例注册配置文件路径(可选)
log:
level: "info" # 日志打印级别: debug / info / warn / error
output: "file" # 日志输出方式: stdout / file
```
3. 启动时注册实例(可选)
支持通过配置文件在启动阶段注册推理实例(示例可参考 examples/run_with_default_workers):
```bash
cp config/config.example.yaml config/config.yaml
cp config/register.example.yaml config/register.yaml
```
### 运行
本项目支持两种运行方式:直接运行源码 或 构建二进制文件后运行。
方式一:直接运行源码
在项目根目录下,使用 go run 启动服务:
```bash
go run cmd/main.go
```
该方式适用于本地开发与调试场景。
方式二:构建并运行二进制文件
1. 构建二进制文件
通过构建脚本生成可执行文件:
```bash
./build.sh
```
构建完成后,二进制文件将被安装到指定目录(默认为 /usr/local/bin,可通过修改 Makefile 中的 OUTDIR 进行调整)。
此外,也可以在项目根目录下手动构建二进制文件:
```bash
go build -o ./fd-router ./cmd
```
该方式便于本地测试或将二进制文件与配置文件一并分发。
2. 运行二进制文件
可以通过运行脚本启动服务:
```bash
./run.sh
```
运行脚本会自动处理常见启动参数及日志目录,适合标准化部署场景。
也可以直接运行二进制文件,在项目根目录或二进制所在目录下执行:
```bash
./fd-router \
--port 8080 \
--splitwise \
--config_path ./config/config.yaml
```
其中:
- --port 为必填参数
- 其他参数可根据实际需求配置
## 项目结构
```
.
├── cmd/ # 主程序入口
├── config/ # 配置文件
├── internal/ # 核心实现代码
│ ├── common/ # 公共接口定义
│ ├── config/ # 配置处理
│ ├── gateway/ # API网关实现
│ ├── manager/ # 路由管理
│ ├── middleware/ # 中间件实现
│ ├── router/ # 路由核心逻辑
│ └── scheduler/ # 调度器实现
├── logs/ # 日志目录
├── output/ # 构建输出
├── pkg/ # 可复用组件
│ ├── logger/ # 日志组件
│ └── metrics/ # 监控指标
├── build.sh # 构建脚本
├── go.mod # Go模块定义
├── go.sum # 依赖校验
├── Makefile # 构建管理
├── README.md # 项目说明
└── run.sh # 启动脚本
```
### 运行测试
```bash
make test
```
## 贡献
欢迎提交 Issue 和 Pull Request
+4
View File
@@ -0,0 +1,4 @@
#!/bin/bash
# Build wrapper: delegates to the Makefile, which downloads module
# dependencies and compiles fd-router into the Makefile's OUTDIR.
set -e
make all
+54
View File
@@ -0,0 +1,54 @@
package main
import (
"context"
"flag"
"log"
"github.com/PaddlePaddle/FastDeploy/router/internal/config"
"github.com/PaddlePaddle/FastDeploy/router/internal/manager"
"github.com/PaddlePaddle/FastDeploy/router/internal/router"
scheduler_handler "github.com/PaddlePaddle/FastDeploy/router/internal/scheduler/handler"
"github.com/PaddlePaddle/FastDeploy/router/pkg/logger"
)
// main wires together configuration, logging, the instance manager, the
// scheduler and the HTTP router, then blocks serving on the configured port.
func main() {
	// Parse command line arguments
	var configPath, port string
	var splitwise bool
	flag.StringVar(&configPath, "config_path", "", "path to config file")
	flag.StringVar(&port, "port", "", "listen port of router")
	flag.BoolVar(&splitwise, "splitwise", false, "enable splitwise mode")
	flag.Parse()
	// Load configuration; CLI flags take precedence over YAML values
	cfg, err := config.Load(configPath, port, splitwise)
	if err != nil {
		log.Fatalf("Failed to load config: %v", err)
	}
	// Initialize logger; its log file (if file output is used) is closed when main returns
	logger.Init(cfg.Log.Level, cfg.Log.Output)
	defer logger.CloseLogFile()
	// Initialize manager first, then the scheduler handler that consumes it
	manager.Init(cfg)
	scheduler_handler.Init(cfg, manager.DefaultManager)
	registerYamlPath := cfg.Manager.RegisterPath
	// Register any inference instances declared in the optional register file
	manager.RegisterInstancesFromConfig(registerYamlPath)
	// Initialize router
	r := router.New(cfg)
	// Background loops: periodic instance health checks and a periodic
	// cleanup task driven by the scheduler's eviction interval
	intervalSecs := cfg.Manager.HealthCheckIntervalSecs
	go manager.MonitorInstanceHealth(context.Background(), intervalSecs)
	intervalCleanupSecs := cfg.Scheduler.EvictionIntervalSecs
	go scheduler_handler.StartBackupCleanupTask(context.Background(), intervalCleanupSecs)
	// Start server (blocks until the listener fails or the process exits)
	addr := ":" + cfg.Server.Port
	logger.Info("Starting server on %s", addr)
	if err := r.Run(addr); err != nil {
		log.Fatalf("Failed to start server: %v", err)
	}
}
@@ -0,0 +1,31 @@
server:
port: "8080"
host: "0.0.0.0"
mode: "debug" # debug, release, test
splitwise: true # true means pd mode, false means mixed mode
scheduler:
policy: "power_of_two"
prefill-policy: "cache_aware"
decode-policy: "fd_metrics_score"
eviction-interval-secs: 60
balance-abs-threshold: 1
balance-rel-threshold: 0.2
hit-ratio-weight: 1.0
load-balance-weight: 0.05
cache-block-size: 4
tokenizer-url: "http://0.0.0.0:8098" # optional tokenizer service endpoint
tokenizer-timeout-secs: 2
waiting-weight: 10
manager:
health-failure-threshold: 3
health-success-threshold: 2
health-check-timeout-secs: 5
health-check-interval-secs: 5
health-check-endpoint: /health
register-path: "config/register.yaml"
log:
level: "info" # debug, info, warn, error
output: "file" # stdout, file
@@ -0,0 +1,23 @@
#!/bin/bash
# Restart helper: gracefully stop any running fd-router instances, then start a new one.
# Use pgrep instead of "ps | grep | grep -v grep": it cannot match itself, and
# the previous single-variable PID handling broke when several PIDs matched
# (kill -0 "$PID" fails on a multi-line value).
PIDS=$(pgrep -f "fd-router" || true)
for PID in $PIDS; do
    echo "Killing existing fd-router process (PID: $PID)"
    # Try graceful shutdown first
    kill -15 "$PID"
    TIMEOUT=10
    while kill -0 "$PID" 2>/dev/null && [ "$TIMEOUT" -gt 0 ]; do
        echo "Waiting for fd-router (PID: $PID) to exit gracefully... ($TIMEOUT seconds remaining)"
        sleep 1
        TIMEOUT=$((TIMEOUT - 1))
    done
    # Force kill if still running after timeout
    if kill -0 "$PID" 2>/dev/null; then
        echo "fd-router (PID: $PID) did not exit gracefully; sending SIGKILL..."
        kill -9 "$PID"
    fi
done
echo "Starting new fd-router process..."
nohup ./fd-router --config_path ./config/config.yaml --splitwise > fd-router.log 2>&1 &
echo "fd-router started with PID: $!"
@@ -0,0 +1,31 @@
server:
port: "8080"
host: "0.0.0.0"
mode: "debug" # debug, release, test
splitwise: true # true means pd mode, false means mixed mode
scheduler:
policy: "power_of_two"
prefill-policy: "cache_aware"
decode-policy: "fd_metrics_score"
eviction-interval-secs: 60
balance-abs-threshold: 1
balance-rel-threshold: 0.2
hit-ratio-weight: 1.0
load-balance-weight: 0.05
cache-block-size: 4
tokenizer-url: "http://0.0.0.0:8098" # optional tokenizer service endpoint
tokenizer-timeout-secs: 2
waiting-weight: 10
manager:
health-failure-threshold: 3
health-success-threshold: 2
health-check-timeout-secs: 5
health-check-interval-secs: 5
health-check-endpoint: /health
register-path: "config/register.yaml"
log:
level: "info" # debug, info, warn, error
output: "file" # stdout, file
@@ -0,0 +1,21 @@
instances:
- role: "prefill"
host_ip: 127.0.0.1
port: 8097
connector_port: 8001
engine_worker_queue_port: 8002
transfer_protocol:
- ipc
- rdma
rdma_ports: [7100, "7101"]
device_ids: [0, "1"]
- role: "decode"
host_ip: 127.0.0.1
port: 8098
connector_port: 8001
engine_worker_queue_port: 8002
transfer_protocol:
- ipc
- rdma
rdma_ports: ["7100", "7101"]
device_ids: ["0", "1"]
@@ -0,0 +1,18 @@
#!/bin/bash
# Restart helper for the example deployment: stop any running fd-router, then start a new one.
# Use pgrep instead of "ps | grep | grep -v grep": it cannot match itself, and
# iterating handles the case where several fd-router processes are running
# (the previous single-variable form passed a multi-line value to kill).
PIDS=$(pgrep -f "fd-router" || true)
for PID in $PIDS; do
    echo "Killing existing fd-router process (PID: $PID)"
    # First try to terminate gracefully
    kill -15 "$PID"
    sleep 5
    # If still running after timeout, force kill as a last resort
    if ps -p "$PID" > /dev/null 2>&1; then
        echo "Process $PID did not terminate gracefully; force killing..."
        kill -9 "$PID"
    fi
done
echo "Starting new fd-router process..."
nohup ./fd-router --config_path ./config/config.yaml --splitwise > fd-router.log 2>&1 &
echo "fd-router started with PID: $!"
+47
View File
@@ -0,0 +1,47 @@
module github.com/PaddlePaddle/FastDeploy/router
go 1.21
require (
github.com/gin-gonic/gin v1.9.1
github.com/stretchr/testify v1.11.1
github.com/valyala/bytebufferpool v1.0.0
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.17.11 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_golang v1.21.1 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.32.0 // indirect
golang.org/x/net v0.34.0 // indirect
golang.org/x/sys v0.29.0 // indirect
golang.org/x/text v0.21.0 // indirect
google.golang.org/protobuf v1.36.1 // indirect
)
+112
View File
@@ -0,0 +1,112 @@
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.21.1 h1:DOvXXTqVzvkIewV/CDPFdejpMCGeMcbGCQ8YOmu+Ibk=
github.com/prometheus/client_golang v1.21.1/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg=
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk=
google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
@@ -0,0 +1,8 @@
package common
import "context"
// ManagerAPI is the minimal view of the instance manager that scheduling
// code depends on, decoupling schedulers from the concrete manager type.
type ManagerAPI interface {
	// GetHealthyURLs returns the URLs of instances currently considered healthy.
	GetHealthyURLs(ctx context.Context) []string
	// GetMetrics returns three integer metrics for the instance at url.
	// NOTE(review): the meaning of the three values is not visible from this
	// interface — confirm against the implementing type in internal/manager.
	GetMetrics(ctx context.Context, url string) (int, int, int)
}
@@ -0,0 +1,122 @@
package config
import (
"fmt"
"os"
"gopkg.in/yaml.v3"
)
// Config is the top-level router configuration, unmarshalled from YAML by Load.
type Config struct {
	Server    ServerConfig    `yaml:"server"`
	Log       LogConfig       `yaml:"log"`
	Manager   ManagerConfig   `yaml:"manager"`
	Scheduler SchedulerConfig `yaml:"scheduler"`
}

// ServerConfig controls the HTTP listener and the overall serving mode.
type ServerConfig struct {
	Name string `yaml:"name"`
	Port string `yaml:"port"` // listen port; may be overridden by the --port flag (see Load)
	Host string `yaml:"host"`
	Mode string `yaml:"mode"` // debug, release, test
	// Splitwise: true enables PD (prefill/decode disaggregated) mode.
	Splitwise bool `yaml:"splitwise"`
}

// ManagerConfig controls instance registration and health checking.
type ManagerConfig struct {
	RegisterPath string `yaml:"register-path"` // optional YAML file of instances to register at startup
	// Consecutive failures/successes required to flip an instance's health state.
	HealthFailureThreshold  int     `yaml:"health-failure-threshold"`
	HealthSuccessThreshold  int     `yaml:"health-success-threshold"`
	HealthCheckTimeoutSecs  float64 `yaml:"health-check-timeout-secs"`
	HealthCheckIntervalSecs float64 `yaml:"health-check-interval-secs"`
	HealthCheckEndpoint     string  `yaml:"health-check-endpoint"`
}

// SchedulerConfig selects scheduling policies and tunes the cache-aware policy.
type SchedulerConfig struct {
	Policy        string `yaml:"policy"`         // mixed-mode policy
	PrefillPolicy string `yaml:"prefill-policy"` // PD mode: policy for prefill nodes
	DecodePolicy  string `yaml:"decode-policy"`  // PD mode: policy for decode nodes
	EvictionIntervalSecs float64 `yaml:"eviction-interval-secs"`
	// NOTE(review): Load defaults CacheBlockSize to 64 while the shipped
	// example config uses 4 — confirm which is intended.
	CacheBlockSize       int     `yaml:"cache-block-size"`
	TokenizerURL         string  `yaml:"tokenizer-url"` // optional tokenizer service endpoint
	TokenizerTimeoutSecs float64 `yaml:"tokenizer-timeout-secs"`
	BalanceAbsThreshold  float64 `yaml:"balance-abs-threshold"`
	BalanceRelThreshold  float64 `yaml:"balance-rel-threshold"`
	HitRatioWeight       float64 `yaml:"hit-ratio-weight"`
	// NOTE(review): Load defaults LoadBalanceWeight and WaitingWeight to 1,
	// but the example config uses 0.05 and 10 respectively — confirm.
	LoadBalanceWeight float64 `yaml:"load-balance-weight"`
	WaitingWeight     float64 `yaml:"waiting-weight"`
}

// LogConfig controls log verbosity and destination.
type LogConfig struct {
	Level  string `yaml:"level"`  // debug, info, warn, error
	Output string `yaml:"output"` // stdout, file
}
// Load builds the router configuration. When configPath is non-empty the YAML
// file is read and parsed first; the CLI values then take precedence (a
// non-empty listenPort always overrides the file, and isSplitwise can only
// turn splitwise mode on, never off). Finally, every zero-valued field with a
// documented default is filled in. Returns an error when the file cannot be
// read or parsed, or when no listen port was provided anywhere.
func Load(configPath, listenPort string, isSplitwise bool) (*Config, error) {
	var cfg Config
	if configPath != "" {
		raw, err := os.ReadFile(configPath)
		if err != nil {
			return nil, fmt.Errorf("failed to read config file: %w", err)
		}
		if err := yaml.Unmarshal(raw, &cfg); err != nil {
			return nil, fmt.Errorf("failed to parse config: %w", err)
		}
	}
	// The port must come from the CLI or the file; there is no built-in default.
	switch {
	case listenPort != "":
		cfg.Server.Port = listenPort
	case cfg.Server.Port == "":
		return nil, fmt.Errorf("failed to set router listen port")
	}
	if isSplitwise {
		cfg.Server.Splitwise = true
	}
	// Default helpers: a zero value means "unset" and receives the fallback.
	defStr := func(field *string, fallback string) {
		if *field == "" {
			*field = fallback
		}
	}
	defInt := func(field *int, fallback int) {
		if *field == 0 {
			*field = fallback
		}
	}
	defFloat := func(field *float64, fallback float64) {
		if *field == 0 {
			*field = fallback
		}
	}
	defStr(&cfg.Server.Mode, "release")
	defStr(&cfg.Log.Level, "info")
	defStr(&cfg.Manager.HealthCheckEndpoint, "/health")
	defFloat(&cfg.Manager.HealthCheckTimeoutSecs, 5)
	defFloat(&cfg.Manager.HealthCheckIntervalSecs, 5)
	defInt(&cfg.Manager.HealthFailureThreshold, 1)
	defInt(&cfg.Manager.HealthSuccessThreshold, 1)
	defFloat(&cfg.Scheduler.EvictionIntervalSecs, 60)
	defInt(&cfg.Scheduler.CacheBlockSize, 64)
	defFloat(&cfg.Scheduler.TokenizerTimeoutSecs, 2)
	defFloat(&cfg.Scheduler.HitRatioWeight, 1)
	defFloat(&cfg.Scheduler.LoadBalanceWeight, 1)
	defFloat(&cfg.Scheduler.BalanceAbsThreshold, 1)
	defFloat(&cfg.Scheduler.BalanceRelThreshold, 0.2)
	defFloat(&cfg.Scheduler.WaitingWeight, 1)
	return &cfg, nil
}
@@ -0,0 +1,372 @@
package gateway
import (
"bufio"
"bytes"
crand "crypto/rand"
"encoding/json"
"fmt"
"io"
"math/rand"
"net/http"
"strconv"
"strings"
"time"
"github.com/PaddlePaddle/FastDeploy/router/internal/manager"
scheduler_handler "github.com/PaddlePaddle/FastDeploy/router/internal/scheduler/handler"
"github.com/PaddlePaddle/FastDeploy/router/pkg/logger"
"github.com/PaddlePaddle/FastDeploy/router/pkg/metrics"
"github.com/gin-gonic/gin"
"github.com/valyala/bytebufferpool"
)
const maxCapacity = 10 * 1024 * 1024 // 10MB
// newRequestID generates a UUIDv4-style request_id.
// It draws 16 bytes from crypto/rand and stamps the RFC 4122 version and
// variant bits; if the secure random source fails (extremely unlikely) it
// falls back to a timestamp-plus-math/rand identifier.
func newRequestID() string {
	var raw [16]byte
	if _, err := crand.Read(raw[:]); err != nil {
		// Fallback: not RFC 4122, but still effectively unique.
		return fmt.Sprintf("%d-%d", time.Now().UnixNano(), rand.Int63())
	}
	raw[6] = (raw[6] & 0x0f) | 0x40 // version 4
	raw[8] = (raw[8] & 0x3f) | 0x80 // RFC 4122 variant
	return fmt.Sprintf("%x-%x-%x-%x-%x", raw[0:4], raw[4:6], raw[6:8], raw[8:10], raw[10:16])
}
// extractPromptFromChatRequest flattens the textual content of an OpenAI
// ChatCompletions-style request into a single prompt string. String contents
// and "text"-typed parts of array contents are trimmed and joined with single
// spaces; every other content shape is ignored. Returns "" when the request
// has no messages or no extractable text.
func extractPromptFromChatRequest(rawReq map[string]any) string {
	msgs, ok := rawReq["messages"].([]any)
	if !ok {
		return ""
	}
	var parts []string
	add := func(text string) {
		if trimmed := strings.TrimSpace(text); trimmed != "" {
			parts = append(parts, trimmed)
		}
	}
	for _, rawMsg := range msgs {
		msg, ok := rawMsg.(map[string]any)
		if !ok {
			continue
		}
		switch content := msg["content"].(type) {
		case string:
			add(content)
		case []any:
			for _, rawPart := range content {
				part, ok := rawPart.(map[string]any)
				if !ok {
					continue
				}
				if partType, _ := part["type"].(string); partType != "text" {
					continue
				}
				if text, ok := part["text"].(string); ok {
					add(text)
				}
			}
		default:
			// Other structures are ignored for now
		}
	}
	return strings.Join(parts, " ")
}
// PostToPD sends requests to both Prefill and Decode instances, only returns Decode node response.
// Both instances receive the same body concurrently; the prefill response body
// is always closed here, while the caller owns (and must close) the returned
// decode response body. An error from either side fails the whole call, with
// the decode error taking priority.
// NOTE(review): the http.Client has no timeout, so a hung backend can block
// this call indefinitely — confirm whether request-context cancellation is
// the intended sole bound.
func PostToPD(c *gin.Context, decodeURL, prefillURL string, reqBody []byte) (*http.Response, error) {
	ctx := c.Request.Context()
	decodeEndpoint := fmt.Sprintf("%s/v1/%s", decodeURL, "chat/completions")
	prefillEndpoint := fmt.Sprintf("%s/v1/%s", prefillURL, "chat/completions")
	// Construct two requests (each gets its own reader over the shared body bytes)
	decodeReq, err := http.NewRequestWithContext(ctx, "POST", decodeEndpoint, bytes.NewReader(reqBody))
	if err != nil {
		return nil, err
	}
	prefillReq, err := http.NewRequestWithContext(ctx, "POST", prefillEndpoint, bytes.NewReader(reqBody))
	if err != nil {
		return nil, err
	}
	// Copy request headers (Content-Length is recomputed from the new body)
	for k, v := range c.Request.Header {
		if k != "Content-Length" {
			decodeReq.Header[k] = v
			prefillReq.Header[k] = v
		}
	}
	client := &http.Client{}
	type respResult struct {
		resp *http.Response
		err  error
	}
	// Buffered (size 1) so each goroutine can send its result and exit
	// even if this function has already returned on the other side's error.
	prefillCh := make(chan respResult, 1)
	decodeCh := make(chan respResult, 1)
	// Concurrently send requests to P/D
	go func() {
		resp, err := client.Do(prefillReq)
		prefillCh <- respResult{resp: resp, err: err}
	}()
	go func() {
		resp, err := client.Do(decodeReq)
		decodeCh <- respResult{resp: resp, err: err}
	}()
	prefillRes := <-prefillCh
	decodeRes := <-decodeCh
	// Prioritize returning Decode errors
	if decodeRes.err != nil {
		if prefillRes.resp != nil {
			prefillRes.resp.Body.Close()
		}
		return nil, decodeRes.err
	}
	if prefillRes.err != nil {
		// Prefill errors are also considered failures to avoid inconsistent behavior
		if decodeRes.resp != nil {
			decodeRes.resp.Body.Close()
		}
		return nil, prefillRes.err
	}
	// Prefill response content is not used; close it to release the connection.
	if prefillRes.resp != nil {
		prefillRes.resp.Body.Close()
	}
	return decodeRes.resp, nil
}
// ChatCompletions implements request forwarding to actual large model inference service.
// In splitwise (PD) mode it selects a prefill/decode worker pair, injects
// disaggregate_info into the request, and fans the request out to both nodes;
// otherwise it forwards to a single mixed worker. The backend's status,
// headers and body are relayed back to the caller (line-by-line for streams).
func ChatCompletions(c *gin.Context) {
	ctx := c.Request.Context()
	bodyBytes, err := io.ReadAll(c.Request.Body)
	if err != nil {
		c.Writer.WriteHeader(http.StatusBadRequest)
		c.Writer.Write([]byte(`{"error": "Invalid request body"}`))
		return
	}
	var rawReq map[string]any
	if err := json.Unmarshal(bodyBytes, &rawReq); err != nil {
		c.Writer.WriteHeader(http.StatusBadRequest)
		c.Writer.Write([]byte(`{"error": "Invalid JSON format"}`))
		return
	}
	isSplitwise := manager.GetSplitwise(ctx)
	var (
		destURL         string   // worker the response is read from (decode node in PD mode)
		releaseTargets  []string // workers whose request counters must be released on exit
		requestBodyData []byte   // body actually sent to the backend(s)
	)
	if isSplitwise {
		// PD mode: select instances for Prefill/Decode separately
		message := extractPromptFromChatRequest(rawReq)
		prefillURL, decodeURL, err := manager.SelectWorkerPair(ctx, message)
		if err != nil {
			c.Writer.WriteHeader(http.StatusBadGateway)
			c.Writer.Write([]byte(`{"error": "Failed to select worker pair"}`))
			return
		}
		if prefillURL == "" || decodeURL == "" {
			c.Writer.WriteHeader(http.StatusServiceUnavailable)
			c.Writer.Write([]byte(`{"error": "No available prefill/decode workers"}`))
			return
		}
		// Construct disaggregate_info to ensure selected P/D work in pairs within FastDeploy
		disagg, err := manager.BuildDisaggregateInfo(ctx, prefillURL, decodeURL)
		if err != nil {
			c.Writer.WriteHeader(http.StatusInternalServerError)
			c.Writer.Write([]byte(`{"error": "Failed to build disaggregate_info"}`))
			return
		}
		rawReq["disaggregate_info"] = disagg
		// If user didn't provide request_id, generate one
		if _, ok := rawReq["request_id"]; !ok {
			rawReq["request_id"] = newRequestID()
		}
		// Re-encode request body and send to P and D
		requestBodyData, err = json.Marshal(rawReq)
		if err != nil {
			c.Writer.WriteHeader(http.StatusInternalServerError)
			c.Writer.Write([]byte(`{"error": "Failed to encode modified request"}`))
			return
		}
		destURL = decodeURL
		releaseTargets = []string{prefillURL, decodeURL}
		// Expose scheduling results to caller for debugging/validating scheduling strategy
		c.Writer.Header().Set("X-Router-Prefill-URL", prefillURL)
		c.Writer.Header().Set("X-Router-Decode-URL", decodeURL)
		// Prefill node token count was added in SelectWorker, release when request ends
		defer scheduler_handler.ReleasePrefillTokens(ctx, prefillURL, message)
	} else {
		// Non-PD mode: use Mixed instance
		dest, err := manager.SelectWorker(ctx, "")
		if err != nil {
			c.Writer.WriteHeader(http.StatusBadGateway)
			c.Writer.Write([]byte(`{"error": "Failed to select worker"}`))
			return
		}
		destURL = dest
		releaseTargets = []string{destURL}
		requestBodyData = bodyBytes
	}
	// Maintain request_num count for related instances (Inc done in SelectWorker, Release here)
	defer func() {
		for _, url := range releaseTargets {
			scheduler_handler.Release(ctx, url)
		}
	}()
	// Send request
	var backendResp *http.Response
	if isSplitwise {
		backendResp, err = PostToPD(c, destURL, releaseTargets[0], requestBodyData)
	} else {
		backendResp, err = GetClientWithRetry(c, requestBodyData, destURL)
	}
	if err != nil {
		c.Writer.WriteHeader(http.StatusBadGateway)
		c.Writer.Write([]byte(`{"error": "Failed to connect to backend service"}`))
		return
	}
	defer backendResp.Body.Close()
	if isSplitwise {
		metrics.InferenceRequests.WithLabelValues("", releaseTargets[0], destURL, strconv.Itoa(backendResp.StatusCode)).Inc()
	} else {
		metrics.InferenceRequests.WithLabelValues(destURL, "", "", strconv.Itoa(backendResp.StatusCode)).Inc()
	}
	// Copy response headers; Content-Length is dropped because the body may be re-chunked.
	for k, v := range backendResp.Header {
		if k != "Content-Length" {
			c.Writer.Header()[k] = v
		}
	}
	// FIX: propagate the backend status unconditionally. Previously only 200
	// was written, so backend error responses were relayed with an implicit
	// 200 OK status on the first body write.
	c.Writer.WriteHeader(backendResp.StatusCode)
	// Detect stream mode from the original request body.
	isStream := false
	if v, ok := rawReq["stream"]; ok {
		if stream, ok := v.(bool); ok && stream {
			isStream = true
		}
	}
	redirect(c, isStream, backendResp)
}
// redirect forwards the backend response body to the client. Stream
// responses are relayed line by line with a flush per line; non-stream
// responses are copied through in one go.
func redirect(c *gin.Context, isStream bool, backendResp *http.Response) {
	if !isStream {
		// Compatible with non-stream response
		io.Copy(c.Writer, backendResp.Body)
		return
	}
	// Stream response: a pooled buffer backs the scanner so each request does
	// not allocate its own scan buffer.
	buf := bytebufferpool.Get()
	buf.Reset()
	defer bytebufferpool.Put(buf)
	lineScanner := bufio.NewScanner(backendResp.Body)
	lineScanner.Buffer(buf.B, maxCapacity)
	for lineScanner.Scan() {
		c.Writer.Write([]byte(lineScanner.Text() + "\n"))
		c.Writer.Flush()
	}
	if err := lineScanner.Err(); err != nil {
		logger.Error("scanner error: %v", err)
	}
}
// GetClientWithRetry forwards the request to destUrl, retrying failed
// attempts up to maxRetry times (the old comment claimed five retries while
// the code performed three; the comment was wrong). It returns the first
// successful response, or the last error once retries are exhausted.
func GetClientWithRetry(c *gin.Context, bodyBytes []byte, destUrl string) (
	backendResp *http.Response, err error) {
	const maxRetry = 3
	for i := 0; i < maxRetry; i++ {
		// Stop retrying once the caller's request context is canceled —
		// hammering the backend after the client disconnected is pointless.
		if ctxErr := c.Request.Context().Err(); ctxErr != nil {
			return nil, ctxErr
		}
		backendResp, err = GetClient(c, destUrl, "chat/completions", bodyBytes)
		if err == nil {
			return backendResp, nil
		}
	}
	return nil, err
}
// GetClient POSTs reqBody to "{address}/v1/{api}" with the caller's headers
// (minus Content-Length, which the client recomputes) and returns the raw
// backend response.
func GetClient(c *gin.Context, address, api string, reqBody []byte) (*http.Response, error) {
	endpoint := fmt.Sprintf("%s/v1/%s", address, api)
	backendReq, err := http.NewRequestWithContext(
		c.Request.Context(),
		"POST",
		endpoint,
		bytes.NewReader(reqBody),
	)
	if err != nil {
		return nil, err
	}
	// Forward the original request headers, dropping Content-Length.
	for name, values := range c.Request.Header {
		if name == "Content-Length" {
			continue
		}
		backendReq.Header[name] = values
	}
	// http.Client.Do already returns a nil response on error.
	return (&http.Client{}).Do(backendReq)
}
@@ -0,0 +1,146 @@
package gateway
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
// TestChatCompletions is a placeholder: the handler depends on package-level
// state (DefaultManager) that a unit test would have to configure fully, so
// it is treated as an integration test and skipped here.
func TestChatCompletions(t *testing.T) {
	t.Skip("Integration test requiring manager setup")
}
// TestExtractPromptFromChatRequest checks that message contents are joined
// into a single space-separated prompt string.
func TestExtractPromptFromChatRequest(t *testing.T) {
	cases := []struct {
		name     string
		input    string
		expected string
	}{
		{
			"simple message",
			`{"messages": [{"role": "user", "content": "hello"}]}`,
			"hello",
		},
		{
			"multiple messages",
			`{"messages": [
				{"role": "user", "content": "hello"},
				{"role": "assistant", "content": "hi"},
				{"role": "user", "content": "how are you"}
			]}`,
			"hello hi how are you",
		},
		{
			"empty messages",
			`{"messages": []}`,
			"",
		},
	}
	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			var parsed map[string]any
			assert.NoError(t, json.Unmarshal([]byte(tc.input), &parsed))
			assert.Equal(t, tc.expected, extractPromptFromChatRequest(parsed))
		})
	}
}
// TestPostToPD verifies that the fan-out helper returns the decode node's
// response when both prefill and decode backends answer successfully.
func TestPostToPD(t *testing.T) {
	// Stub prefill backend: empty 200 response.
	prefillTS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer prefillTS.Close()
	// Stub decode backend: the response the caller should receive.
	decodeTS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"response": "test"}`))
	}))
	defer decodeTS.Close()
	// Setup test context
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(`{"test": "data"}`))
	resp, err := PostToPD(c, decodeTS.URL, prefillTS.URL, []byte(`{"test": "data"}`))
	assert.NoError(t, err)
	// FIX: the returned body was never closed, leaking the connection.
	if assert.NotNil(t, resp) {
		defer resp.Body.Close()
		assert.Equal(t, http.StatusOK, resp.StatusCode)
	}
}
// TestRedirect checks both forwarding modes of redirect: stream mode relays
// line by line (appending a newline per line), non-stream copies verbatim.
func TestRedirect(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("test response"))
	}))
	defer ts.Close()
	t.Run("stream response", func(t *testing.T) {
		w := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(w)
		c.Request = httptest.NewRequest("GET", "/", nil)
		resp, err := http.Get(ts.URL)
		assert.NoError(t, err)
		// FIX: the response body was never closed, leaking the connection.
		defer resp.Body.Close()
		redirect(c, true, resp)
		assert.Equal(t, "test response\n", w.Body.String())
	})
	t.Run("non-stream response", func(t *testing.T) {
		w := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(w)
		c.Request = httptest.NewRequest("GET", "/", nil)
		resp, err := http.Get(ts.URL)
		assert.NoError(t, err)
		// FIX: close the body here as well.
		defer resp.Body.Close()
		redirect(c, false, resp)
		assert.Equal(t, "test response", w.Body.String())
	})
}
// TestGetClient verifies that GetClient forwards the request and returns the
// backend's response.
func TestGetClient(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("test response"))
	}))
	defer ts.Close()
	// Setup test context
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(`{"test": "data"}`))
	resp, err := GetClient(c, ts.URL, "chat/completions", []byte(`{"test": "data"}`))
	assert.NoError(t, err)
	// FIX: the response body was never closed, leaking the connection.
	if assert.NotNil(t, resp) {
		defer resp.Body.Close()
		assert.Equal(t, http.StatusOK, resp.StatusCode)
	}
}
// TestNewRequestID checks that generated request IDs are non-empty and
// unique across calls.
func TestNewRequestID(t *testing.T) {
	first := newRequestID()
	second := newRequestID()
	assert.NotEmpty(t, first)
	assert.NotEmpty(t, second)
	assert.NotEqual(t, first, second)
}
@@ -0,0 +1,138 @@
package manager
import (
"context"
"sort"
"sync"
"time"
"github.com/PaddlePaddle/FastDeploy/router/internal/config"
scheduler_handler "github.com/PaddlePaddle/FastDeploy/router/internal/scheduler/handler"
)
// Manager tracks the registered inference workers, grouped by role.
// All three maps are keyed by the worker's registry identifier and guarded
// by mu.
type Manager struct {
	mixedWorkerMap   map[string]*WorkerInfo // non-disaggregated (mixed) instances
	prefillWorkerMap map[string]*WorkerInfo // prefill-role instances (splitwise mode)
	decodeWorkerMap  map[string]*WorkerInfo // decode-role instances (splitwise mode)
	splitwise        bool                   // true when running in prefill/decode (PD) disaggregated mode
	mu               sync.RWMutex           // guards the three worker maps above
}

// WorkerInfo describes one registered inference worker as exchanged over
// JSON (field tags) with the registration API.
type WorkerInfo struct {
	Url                   string   `json:"url"`                      // worker base URL
	WorkerType            string   `json:"worker_type"`              // "mixed", "prefill" or "decode"
	ConnectorPort         string   `json:"connector_port"`           // splitwise connector port
	EngineWorkerQueuePort string   `json:"engine_worker_queue_port"` // engine worker queue port
	TransferProtocol      []string `json:"transfer_protocol"`        // e.g. "ipc", "rdma"
	RdmaPorts             []string `json:"rdma_ports"`               // RDMA ports, one per device
	DeviceIDs             []string `json:"device_ids"`               // accelerator device ids
	MetricsPort           string   `json:"metrics_port"`             // metrics endpoint port
}
// DefaultManager is the process-wide manager instance, populated by Init.
var DefaultManager *Manager

// Health-check tuning, cached from config by Init.
var defaultCheckTimeout time.Duration // per-probe HTTP timeout
var healthEndpoint string             // path probed on each worker, e.g. "/health"
var failureThreshold int              // probes a known worker gets before it counts as unhealthy
var successThreshold int              // consecutive successes a new worker needs to be accepted
// Init initializes the manager module: it builds the global DefaultManager
// from the configuration and caches the health-check tuning parameters.
func Init(cfg *config.Config) {
	DefaultManager = &Manager{
		mixedWorkerMap:   map[string]*WorkerInfo{},
		prefillWorkerMap: map[string]*WorkerInfo{},
		decodeWorkerMap:  map[string]*WorkerInfo{},
		splitwise:        cfg.Server.Splitwise,
	}
	// Cache health-check settings for the monitoring code.
	defaultCheckTimeout = time.Duration(cfg.Manager.HealthCheckTimeoutSecs * float64(time.Second))
	healthEndpoint = cfg.Manager.HealthCheckEndpoint
	failureThreshold = cfg.Manager.HealthFailureThreshold
	successThreshold = cfg.Manager.HealthSuccessThreshold
}
// WorkerMapToList returns the URLs of all workers of the given type
// ("mixed", "prefill" or "decode"), ordered by registry key so the result
// is deterministic. Unknown types yield an empty slice.
func WorkerMapToList(ctx context.Context, workerType string) []string {
	DefaultManager.mu.RLock()
	defer DefaultManager.mu.RUnlock()
	var pool map[string]*WorkerInfo
	switch workerType {
	case "mixed":
		pool = DefaultManager.mixedWorkerMap
	case "prefill":
		pool = DefaultManager.prefillWorkerMap
	case "decode":
		pool = DefaultManager.decodeWorkerMap
	default:
		return []string{}
	}
	if pool == nil {
		return []string{}
	}
	// Sort the registry keys for a stable iteration order.
	ids := make([]string, 0, len(pool))
	for id := range pool {
		ids = append(ids, id)
	}
	sort.Strings(ids)
	urls := make([]string, 0, len(ids))
	for _, id := range ids {
		if info, ok := pool[id]; ok {
			urls = append(urls, info.Url)
		}
	}
	return urls
}
// GetHealthyURLs returns the registry keys of every known worker (prefill,
// decode and mixed). NOTE(review): despite the name, these are the map keys
// the workers were registered under, not the WorkerInfo.Url values —
// confirm callers expect keys before renaming.
func (m *Manager) GetHealthyURLs(ctx context.Context) []string {
	if m == nil {
		return []string{}
	}
	m.mu.RLock()
	defer m.mu.RUnlock()
	total := len(m.prefillWorkerMap) + len(m.decodeWorkerMap) + len(m.mixedWorkerMap)
	ids := make([]string, 0, total)
	for id := range m.prefillWorkerMap {
		ids = append(ids, id)
	}
	for id := range m.decodeWorkerMap {
		ids = append(ids, id)
	}
	for id := range m.mixedWorkerMap {
		ids = append(ids, id)
	}
	return ids
}
// SelectWorker asks the scheduler to pick one mixed worker for the message.
func SelectWorker(ctx context.Context, message string) (string, error) {
	candidates := WorkerMapToList(ctx, "mixed")
	return scheduler_handler.SelectWorker(ctx, candidates, message, "mixed")
}
// SelectWorkerPair picks one prefill and one decode worker for the message.
// When either pool is empty it returns two empty URLs with a nil error;
// callers must check for that case themselves.
func SelectWorkerPair(ctx context.Context, message string) (string, string, error) {
	prefillPool := WorkerMapToList(ctx, "prefill")
	decodePool := WorkerMapToList(ctx, "decode")
	if len(prefillPool) == 0 || len(decodePool) == 0 {
		return "", "", nil
	}
	prefillURL, err := scheduler_handler.SelectWorker(ctx, prefillPool, message, "prefill")
	if err != nil {
		return "", "", err
	}
	decodeURL, err := scheduler_handler.SelectWorker(ctx, decodePool, message, "decode")
	if err != nil {
		return "", "", err
	}
	return prefillURL, decodeURL, nil
}
@@ -0,0 +1,117 @@
package manager
import (
"context"
"testing"
"github.com/PaddlePaddle/FastDeploy/router/internal/config"
"github.com/stretchr/testify/assert"
)
// TestInit verifies that Init populates the global manager and the
// health-check tuning variables from the configuration.
func TestInit(t *testing.T) {
	cfg := &config.Config{
		Server: config.ServerConfig{Splitwise: true},
		Manager: config.ManagerConfig{
			HealthCheckTimeoutSecs: 5.0,
			HealthCheckEndpoint:    "/health",
			HealthFailureThreshold: 3,
			HealthSuccessThreshold: 2,
		},
	}
	Init(cfg)
	assert.NotNil(t, DefaultManager)
	assert.True(t, DefaultManager.splitwise)
	assert.Equal(t, "/health", healthEndpoint)
	assert.Equal(t, 3, failureThreshold)
	assert.Equal(t, 2, successThreshold)
}
// TestWorkerMapToList seeds workers of every role and checks that each
// worker type returns exactly its own pool (and unknown types return empty).
func TestWorkerMapToList(t *testing.T) {
	Init(&config.Config{})
	DefaultManager.prefillWorkerMap = map[string]*WorkerInfo{
		"http://worker1": {Url: "http://worker1"},
		"http://worker2": {Url: "http://worker2"},
	}
	DefaultManager.decodeWorkerMap = map[string]*WorkerInfo{
		"http://worker3": {Url: "http://worker3"},
	}
	DefaultManager.mixedWorkerMap = map[string]*WorkerInfo{
		"http://worker4": {Url: "http://worker4"},
	}
	cases := []struct {
		name       string
		workerType string
		expected   []string
	}{
		{"prefill workers", "prefill", []string{"http://worker1", "http://worker2"}},
		{"decode workers", "decode", []string{"http://worker3"}},
		{"mixed workers", "mixed", []string{"http://worker4"}},
		{"invalid worker type", "invalid", []string{}},
	}
	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			workers := WorkerMapToList(context.Background(), tc.workerType)
			assert.Len(t, workers, len(tc.expected))
			for _, url := range tc.expected {
				assert.Contains(t, workers, url)
			}
		})
	}
}
// TestManager_GetHealthyURLs checks that one entry per registered worker is
// returned; note the function returns registry keys, not WorkerInfo URLs.
func TestManager_GetHealthyURLs(t *testing.T) {
	Init(&config.Config{})
	DefaultManager.prefillWorkerMap = map[string]*WorkerInfo{
		"worker1": {Url: "http://worker1"},
	}
	DefaultManager.decodeWorkerMap = map[string]*WorkerInfo{
		"worker2": {Url: "http://worker2"},
	}
	DefaultManager.mixedWorkerMap = map[string]*WorkerInfo{
		"worker3": {Url: "http://worker3"},
	}
	urls := DefaultManager.GetHealthyURLs(context.Background())
	assert.Len(t, urls, 3)
	for _, key := range []string{"worker1", "worker2", "worker3"} {
		assert.Contains(t, urls, key)
	}
}
// TestSelectWorker is a placeholder: SelectWorker depends on the scheduler
// package, which this unit test does not mock.
// FIX: skip first — previously the skip came after map setup, so the test
// did pointless work before skipping.
func TestSelectWorker(t *testing.T) {
	t.Skip("Integration test requiring scheduler setup")
}
// TestSelectWorkerPair is a placeholder: SelectWorkerPair depends on the
// scheduler package, which this unit test does not mock.
// FIX: skip first — previously the skip came after map setup, so the test
// did pointless work before skipping.
func TestSelectWorkerPair(t *testing.T) {
	t.Skip("Integration test requiring scheduler setup")
}
@@ -0,0 +1,279 @@
package manager
import (
"context"
"io"
"net/http"
"sync"
"time"
scheduler_handler "github.com/PaddlePaddle/FastDeploy/router/internal/scheduler/handler"
"github.com/PaddlePaddle/FastDeploy/router/pkg/logger"
"github.com/gin-gonic/gin"
)
// healthCheckResult carries the outcome of probing one worker URL.
type healthCheckResult struct {
	url       string // worker base URL that was probed
	isHealthy bool   // true when the probe passed CheckWorkerHealth
}

// healthMonitorResult carries the outcome of the periodic monitor's probe
// of one registered worker.
type healthMonitorResult struct {
	id        string      // registry key the worker is stored under
	worker    *WorkerInfo // node information
	isHealthy bool        // true when the worker passed CheckWorkerHealth
}
// CheckServiceHealth probes baseURL's health endpoint with a GET request and
// reports whether it answered 200 OK. An optional positive timeout overrides
// the configured default.
func CheckServiceHealth(ctx context.Context, baseURL string, timeout ...time.Duration) bool {
	if baseURL == "" {
		logger.Error("empty baseURL provided")
		return false
	}
	target := baseURL + healthEndpoint
	// Use the configured default unless the caller supplied a valid timeout.
	effectiveTimeout := defaultCheckTimeout
	if len(timeout) > 0 && timeout[0] > 0 {
		effectiveTimeout = timeout[0]
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, target, nil)
	if err != nil {
		logger.Error("failed to create request: %v", err)
		return false
	}
	client := &http.Client{Timeout: effectiveTimeout}
	resp, err := client.Do(req)
	if err != nil {
		logger.Error("failed to send request to %s with error: %v", target, err)
		return false
	}
	defer resp.Body.Close()
	// Drain the body so the underlying connection can be reused.
	if _, err := io.ReadAll(resp.Body); err != nil {
		logger.Error("failed to read response body: %v", err)
		return false
	}
	return resp.StatusCode == http.StatusOK
}
// CheckWorkerHealth decides whether baseURL is healthy. Workers already in
// the registry get up to failureThreshold attempts and count as healthy if
// any probe succeeds; unknown workers must pass successThreshold consecutive
// probes before being accepted.
func CheckWorkerHealth(ctx context.Context, baseURL string) bool {
	_, known := GetAllMapServers(ctx)[baseURL]
	if known {
		// Known worker: tolerant — any success within the budget keeps it.
		for attempt := 0; attempt < failureThreshold; attempt++ {
			if CheckServiceHealth(ctx, baseURL) {
				return true
			}
		}
		return false
	}
	// New worker: strict — every probe must succeed.
	for attempt := 0; attempt < successThreshold; attempt++ {
		if !CheckServiceHealth(ctx, baseURL) {
			return false
		}
	}
	return true
}
// GetAllServerURLs returns the URLs of all Prefill and Decode servers
// (mixed workers are intentionally excluded) under the read lock.
func GetAllServerURLs(ctx context.Context) []string {
	DefaultManager.mu.RLock()
	defer DefaultManager.mu.RUnlock()
	total := len(DefaultManager.prefillWorkerMap) + len(DefaultManager.decodeWorkerMap)
	urls := make([]string, 0, total)
	for _, worker := range DefaultManager.prefillWorkerMap {
		urls = append(urls, worker.Url)
	}
	for _, worker := range DefaultManager.decodeWorkerMap {
		urls = append(urls, worker.Url)
	}
	return urls
}
// HealthGenerate is the HTTP handler that probes every prefill/decode worker
// concurrently, logs each result, and reports completion to the caller.
func HealthGenerate(c *gin.Context) {
	// FIX: snapshot the worker URLs first (GetAllServerURLs takes the read
	// lock). Previously the channel buffer was sized from the manager maps
	// directly, reading them without holding the lock — a data race against
	// concurrent registration/removal.
	allServerURLs := GetAllServerURLs(c.Request.Context())
	// Buffer equal to the number of probes so no goroutine blocks on send.
	results := make(chan healthCheckResult, len(allServerURLs))
	var wg sync.WaitGroup
	for _, s := range allServerURLs {
		wg.Add(1)
		go func(serverURL string) {
			defer wg.Done()
			results <- healthCheckResult{
				url:       serverURL,
				isHealthy: CheckWorkerHealth(c.Request.Context(), serverURL),
			}
		}(s)
	}
	// Close the channel once all probes finish so the range below terminates.
	go func() {
		wg.Wait()
		close(results)
	}()
	for res := range results {
		if !res.isHealthy {
			logger.Warn("Server %s is not healthy", res.url)
		} else {
			logger.Info("Server %s is healthy", res.url)
		}
	}
	c.JSON(http.StatusOK, gin.H{
		"code": 200,
		"msg":  "Health check complete",
	})
}
// RemoveServers deletes the given registry ids from the prefill, decode and
// mixed pools under the write lock, logging each removal.
func RemoveServers(ctx context.Context, prefillToRemove []string, decodeToRemove []string, mixedToRemove []string) {
	DefaultManager.mu.Lock()
	defer DefaultManager.mu.Unlock()
	// drop removes each id from pool, logging the worker's URL with format.
	drop := func(ids []string, pool map[string]*WorkerInfo, format string) {
		for _, id := range ids {
			if worker, ok := pool[id]; ok {
				delete(pool, id)
				logger.Info(format, worker.Url)
			}
		}
	}
	drop(prefillToRemove, DefaultManager.prefillWorkerMap, "Removed unhealthy prefill instance: %s")
	drop(decodeToRemove, DefaultManager.decodeWorkerMap, "Removed unhealthy decode instance: %s")
	drop(mixedToRemove, DefaultManager.mixedWorkerMap, "Removed unhealthy mixed instance: %s")
}
// ReadServers returns a snapshot of the URLs of all registered prefill,
// decode and mixed workers, logging the result at debug level. A nil
// DefaultManager yields three empty slices.
func ReadServers(ctx context.Context) (prefillInstances, decodeInstances, mixedInstances []string) {
	if DefaultManager == nil {
		logger.Debug("Healthy instances: prefill=[], decode=[], mixed=[] (DefaultManager is nil)")
		return []string{}, []string{}, []string{}
	}
	DefaultManager.mu.RLock()
	defer DefaultManager.mu.RUnlock()
	// Copy URLs out so callers never touch the maps without the lock.
	collect := func(pool map[string]*WorkerInfo) []string {
		urls := make([]string, 0, len(pool))
		for _, w := range pool {
			urls = append(urls, w.Url)
		}
		return urls
	}
	prefillInstances = collect(DefaultManager.prefillWorkerMap)
	decodeInstances = collect(DefaultManager.decodeWorkerMap)
	mixedInstances = collect(DefaultManager.mixedWorkerMap)
	logger.Debug(
		"Healthy instances: prefill=%v, decode=%v, mixed=%v",
		prefillInstances,
		decodeInstances,
		mixedInstances,
	)
	return prefillInstances, decodeInstances, mixedInstances
}
// MonitorInstanceHealthCore performs one health sweep: it probes every
// registered worker concurrently, removes the ones that failed from the
// registry, and clears their scheduler counters.
func MonitorInstanceHealthCore(ctx context.Context) {
	if DefaultManager == nil {
		return
	}
	// Snapshot the workers (GetAllMapServers is defined elsewhere; presumably
	// it copies under the lock — confirm) and probe them concurrently.
	allServers := GetAllMapServers(ctx)
	length := len(allServers)
	// Buffer equal to the probe count so no goroutine blocks on send.
	resultCh := make(chan healthMonitorResult, length)
	var wg sync.WaitGroup
	for id, server := range allServers {
		wg.Add(1)
		go func(id string, server *WorkerInfo) {
			defer wg.Done()
			// Execute health check logic
			baseURL := server.Url
			isHealthy := CheckWorkerHealth(ctx, baseURL)
			resultCh <- healthMonitorResult{
				id:        id,
				worker:    server,
				isHealthy: isHealthy,
			}
		}(id, server)
	}
	// Close the channel once every probe has reported, so the range below ends.
	go func() {
		wg.Wait()
		close(resultCh)
	}()
	var prefillToRemove, decodeToRemove, mixedToRemove []string
	for res := range resultCh {
		if !res.isHealthy {
			// Collect the unhealthy worker under its role-specific removal list.
			switch res.worker.WorkerType {
			case "prefill":
				prefillToRemove = append(prefillToRemove, res.id)
			case "decode":
				decodeToRemove = append(decodeToRemove, res.id)
			case "mixed":
				mixedToRemove = append(mixedToRemove, res.id)
			}
			// Drop the dead worker's scheduler bookkeeping in the background.
			go scheduler_handler.CleanupUnhealthyCounter(ctx, res.id)
		}
	}
	// Remove unhealthy instances, then log the survivors at debug level.
	RemoveServers(ctx, prefillToRemove, decodeToRemove, mixedToRemove)
	ReadServers(ctx)
}
// MonitorInstanceHealth runs the health sweep every intervalSecs seconds
// until ctx is canceled. Each tick launches MonitorInstanceHealthCore in its
// own goroutine so a slow sweep does not delay the next tick.
func MonitorInstanceHealth(ctx context.Context, intervalSecs float64) {
	interval := time.Duration(intervalSecs * float64(time.Second))
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			go MonitorInstanceHealthCore(ctx)
		}
	}
}
@@ -0,0 +1,160 @@
package manager
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/PaddlePaddle/FastDeploy/router/internal/config"
"github.com/PaddlePaddle/FastDeploy/router/pkg/logger"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
// init configures the shared logger once (info level, stdout) so every test
// in this package can log without its own setup.
func init() {
	// Initialize logger for all tests
	logger.Init("info", "stdout")
}
func TestCheckServiceHealth(t *testing.T) {
// Setup test server
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer ts.Close()
t.Run("healthy service", func(t *testing.T) {
healthy := CheckServiceHealth(context.Background(), ts.URL)
assert.True(t, healthy)
})
t.Run("unhealthy service", func(t *testing.T) {
unhealthyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer unhealthyServer.Close()
healthy := CheckServiceHealth(context.Background(), unhealthyServer.URL)
assert.False(t, healthy)
})
t.Run("empty baseURL", func(t *testing.T) {
healthy := CheckServiceHealth(context.Background(), "")
assert.False(t, healthy)
})
t.Run("timeout", func(t *testing.T) {
slowServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(100 * time.Millisecond)
w.WriteHeader(http.StatusOK)
}))
defer slowServer.Close()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
healthy := CheckServiceHealth(ctx, slowServer.URL)
assert.False(t, healthy)
})
}
// TestCheckWorkerHealth probes a healthy and an unhealthy worker with the
// thresholds set to 1 so each check loop runs a single attempt.
func TestCheckWorkerHealth(t *testing.T) {
	Init(&config.Config{
		Manager: config.ManagerConfig{
			HealthFailureThreshold: 1,
			HealthSuccessThreshold: 1,
		},
	})
	okServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer okServer.Close()
	t.Run("healthy worker", func(t *testing.T) {
		assert.True(t, CheckWorkerHealth(context.Background(), okServer.URL))
	})
	t.Run("unhealthy worker", func(t *testing.T) {
		failing := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusInternalServerError)
		}))
		defer failing.Close()
		assert.False(t, CheckWorkerHealth(context.Background(), failing.URL))
	})
}
// TestHealthGenerate registers two workers backed by a stub server and
// checks the handler reports completion with 200 OK.
func TestHealthGenerate(t *testing.T) {
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer backend.Close()
	Init(&config.Config{})
	DefaultManager.prefillWorkerMap = map[string]*WorkerInfo{
		"worker1": {Url: backend.URL},
	}
	DefaultManager.decodeWorkerMap = map[string]*WorkerInfo{
		"worker2": {Url: backend.URL},
	}
	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	// The handler reads c.Request.Context(), so a real request is required.
	c.Request = httptest.NewRequest("GET", "/health", nil)
	HealthGenerate(c)
	assert.Equal(t, http.StatusOK, recorder.Code)
	assert.Contains(t, recorder.Body.String(), "Health check complete")
}
// TestMonitorInstanceHealthCore runs one sweep over two healthy workers and
// checks that neither is removed from the registry.
func TestMonitorInstanceHealthCore(t *testing.T) {
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer backend.Close()
	Init(&config.Config{})
	DefaultManager.prefillWorkerMap = map[string]*WorkerInfo{
		"worker1": {Url: backend.URL, WorkerType: "prefill"},
	}
	DefaultManager.decodeWorkerMap = map[string]*WorkerInfo{
		"worker2": {Url: backend.URL, WorkerType: "decode"},
	}
	MonitorInstanceHealthCore(context.Background())
	// Healthy workers must survive the sweep.
	_, stillThere := DefaultManager.prefillWorkerMap["worker1"]
	assert.True(t, stillThere)
	_, stillThere = DefaultManager.decodeWorkerMap["worker2"]
	assert.True(t, stillThere)
}
// TestReadServers checks that the snapshot reflects exactly the registered
// URLs, with an empty mixed pool.
func TestReadServers(t *testing.T) {
	Init(&config.Config{})
	DefaultManager.prefillWorkerMap = map[string]*WorkerInfo{
		"worker1": {Url: "http://worker1"},
	}
	DefaultManager.decodeWorkerMap = map[string]*WorkerInfo{
		"worker2": {Url: "http://worker2"},
	}
	gotPrefill, gotDecode, gotMixed := ReadServers(context.Background())
	assert.Equal(t, []string{"http://worker1"}, gotPrefill)
	assert.Equal(t, []string{"http://worker2"}, gotDecode)
	assert.Equal(t, []string{}, gotMixed)
}
@@ -0,0 +1,221 @@
package manager
import (
"encoding/json"
"errors"
"fmt"
"net"
"slices"
"strconv"
"strings"
)
// InstanceRole identifies how an inference instance participates in serving:
// a mixed instance handles whole requests, while prefill/decode instances
// split the work in disaggregated (splitwise) mode.
type InstanceRole int

const (
	MIXED InstanceRole = iota
	PREFILL
	DECODE
)

// roleNames maps each InstanceRole to its canonical lowercase name.
var roleNames = [...]string{"mixed", "prefill", "decode"}

// String returns the canonical name for r. FIX: out-of-range values (e.g.
// the -1 sentinel ParseInstanceRole returns on failure) previously panicked
// on the array index; they now render as "unknown".
func (r InstanceRole) String() string {
	if r < 0 || int(r) >= len(roleNames) {
		return "unknown"
	}
	return roleNames[r]
}
// ParseInstanceRole converts a case-insensitive role name ("mixed",
// "prefill", "decode") to its InstanceRole. It returns -1 and an error for
// unknown names.
func ParseInstanceRole(s string) (InstanceRole, error) {
	for i, name := range roleNames {
		// FIX: strings.EqualFold is already case-insensitive; the previous
		// strings.ToLower(s) wrapper was redundant work.
		if strings.EqualFold(s, name) {
			return InstanceRole(i), nil
		}
	}
	return -1, fmt.Errorf("invalid role: %s", s)
}
// Role wraps InstanceRole so it can be decoded from JSON/YAML given either
// an integer, a known role name, or an arbitrary custom string.
type Role struct {
	EnumValue  InstanceRole // valid when IsCustom is false
	CustomName string       // holds the raw input when it was not a known role name
	IsCustom   bool         // true when CustomName carries the value
	IsSet      bool         // true once any value has been decoded into this Role
}
// parse decodes a role from either an integer (which must map onto a defined
// InstanceRole) or a string (known names become enum values, anything else
// is kept as a custom name). getInt/getStr abstract over JSON vs YAML.
func (r *Role) parse(getInt func() (int, error), getStr func() (string, error)) error {
	r.IsSet = true
	// Integer form takes priority when it decodes successfully.
	if n, err := getInt(); err == nil {
		if n < 0 || n > int(DECODE) {
			return fmt.Errorf("invalid role integer: %d", n)
		}
		r.EnumValue, r.IsCustom = InstanceRole(n), false
		return nil
	}
	value, err := getStr()
	if err != nil {
		return err
	}
	// Known names map onto the enum; anything else becomes a custom name.
	if enum, parseErr := ParseInstanceRole(value); parseErr == nil {
		r.EnumValue, r.IsCustom = enum, false
	} else {
		r.CustomName, r.IsCustom = value, true
	}
	return nil
}
// UnmarshalJSON decodes a Role from a JSON number or string.
func (r *Role) UnmarshalJSON(data []byte) error {
	asInt := func() (int, error) {
		var i int
		err := json.Unmarshal(data, &i)
		return i, err
	}
	asStr := func() (string, error) {
		var s string
		err := json.Unmarshal(data, &s)
		return s, err
	}
	return r.parse(asInt, asStr)
}
// UnmarshalYAML decodes a Role from a YAML integer or string.
func (r *Role) UnmarshalYAML(u func(interface{}) error) error {
	asInt := func() (int, error) {
		var i int
		err := u(&i)
		return i, err
	}
	asStr := func() (string, error) {
		var s string
		err := u(&s)
		return s, err
	}
	return r.parse(asInt, asStr)
}
// MarshalJSON encodes the role as its string form: the custom name when set,
// otherwise the canonical enum name.
func (r Role) MarshalJSON() ([]byte, error) {
	name := r.EnumValue.String()
	if r.IsCustom {
		name = r.CustomName
	}
	return json.Marshal(name)
}
type Port string
func (p *Port) UnmarshalJSON(data []byte) error {
var i int
if json.Unmarshal(data, &i) == nil {
*p = Port(strconv.Itoa(i))
return nil
}
return json.Unmarshal(data, (*string)(p))
}
func (p *Port) UnmarshalYAML(u func(interface{}) error) error {
var i int
if u(&i) == nil {
*p = Port(strconv.Itoa(i))
return nil
}
return u((*string)(p))
}
type IntToStringList []string
func (sl *IntToStringList) UnmarshalJSON(data []byte) error {
return sl.unmarshal(data, json.Unmarshal)
}
func (sl *IntToStringList) UnmarshalYAML(u func(interface{}) error) error {
return sl.unmarshal(nil, func(_ []byte, v interface{}) error { return u(v) })
}
func (sl *IntToStringList) unmarshal(data []byte, u func([]byte, interface{}) error) error {
var raw []interface{}
if err := u(data, &raw); err != nil {
return err
}
res := make([]string, len(raw))
for i, v := range raw {
switch val := v.(type) {
case string:
res[i] = val
case int:
res[i] = strconv.Itoa(val)
case float64:
if val == float64(int(val)) {
res[i] = strconv.Itoa(int(val))
} else {
return fmt.Errorf("element %d: %v not integer", i, val)
}
default:
return fmt.Errorf("element %d: type %T unsupported", i, v)
}
}
*sl = res
return nil
}
// InstanceInfo describes one inference instance as submitted at
// registration. It decodes from JSON or YAML; Port-typed fields accept
// numbers or strings, and the IntToStringList fields accept integers too.
// Validation and normalization happen in NewInstanceInfo.
type InstanceInfo struct {
	// Role is required; custom role names are rejected by NewInstanceInfo.
	Role Role `json:"role" yaml:"role"`
	// HostIP is required and must parse as an IP address.
	HostIP string `json:"host_ip" yaml:"host_ip"`
	// Port is required; 1..65535.
	Port Port `json:"port" yaml:"port"`
	// ConnectorPort is validated only in splitwise mode, when non-empty.
	ConnectorPort Port `json:"connector_port,omitempty" yaml:"connector_port,omitempty"`
	// EngineWorkerQueuePort is validated only in splitwise mode, when non-empty.
	EngineWorkerQueuePort Port `json:"engine_worker_queue_port,omitempty" yaml:"engine_worker_queue_port,omitempty"`
	// TransferProtocol entries must be "ipc" or "rdma".
	TransferProtocol []string `json:"transfer_protocol,omitempty" yaml:"transfer_protocol,omitempty"`
	// RDMAPorts are checked as a port list in splitwise mode.
	RDMAPorts IntToStringList `json:"rdma_ports,omitempty" yaml:"rdma_ports,omitempty"`
	// DeviceIDs are accelerator ids; not validated here.
	DeviceIDs IntToStringList `json:"device_ids,omitempty" yaml:"device_ids,omitempty"`
	// MetricsPort defaults to Port when empty.
	MetricsPort Port `json:"metrics_port,omitempty" yaml:"metrics_port,omitempty"`
}
// isValidPort reports whether p is a decimal integer in the range 1..65535.
func isValidPort(p Port) bool {
	n, err := strconv.Atoi(string(p))
	return err == nil && n > 0 && n <= 65535
}
// isValidIP reports whether ip parses as an IPv4 or IPv6 address.
func isValidIP(ip string) bool {
	parsed := net.ParseIP(ip)
	return parsed != nil
}
// validatePortList returns an error naming the offending entry when any
// element of list is not a decimal port in 1..65535; nil otherwise.
func validatePortList(name string, list []string) error {
	for i, raw := range list {
		n, err := strconv.Atoi(raw)
		if err != nil || n <= 0 || n > 65535 {
			return fmt.Errorf("%s[%d] invalid port: %s", name, i, raw)
		}
	}
	return nil
}
// URL builds the instance's base URL from host and port, prepending an
// http:// scheme unless the joined address already starts with "http"
// (which cannot happen for a validated IP host).
func (info *InstanceInfo) URL() string {
	addr := fmt.Sprintf("%s:%s", info.HostIP, info.Port)
	if strings.HasPrefix(addr, "http") {
		return addr
	}
	return "http://" + addr
}
// NewInstanceInfo validates info and normalizes optional fields, returning
// the same pointer on success. Validation covers role, host/port syntax and
// transfer protocols; connector, engine-worker-queue and RDMA ports are only
// checked in splitwise mode. MetricsPort defaults to Port when unset.
func NewInstanceInfo(info *InstanceInfo) (*InstanceInfo, error) {
	if !info.Role.IsSet {
		return nil, errors.New("role is required")
	}
	if info.Role.IsCustom {
		return nil, fmt.Errorf("invalid role: %s", info.Role.CustomName)
	}
	if info.HostIP == "" {
		return nil, errors.New("host_ip is required")
	}
	if !isValidIP(info.HostIP) {
		return nil, fmt.Errorf("invalid host_ip: %s", info.HostIP)
	}
	if info.Port == "" {
		return nil, errors.New("port is required")
	}
	if !isValidPort(info.Port) {
		return nil, fmt.Errorf("invalid port: %s", info.Port)
	}
	// Guard against a nil DefaultManager for consistency with GetSplitwise /
	// getWorkerInfo; the original dereferenced it unconditionally and would
	// panic before initialization. Splitwise-only checks are skipped then.
	splitwise := DefaultManager != nil && DefaultManager.splitwise
	if splitwise {
		if info.ConnectorPort != "" && !isValidPort(info.ConnectorPort) {
			return nil, fmt.Errorf("invalid connector_port: %s", info.ConnectorPort)
		}
		if info.EngineWorkerQueuePort != "" && !isValidPort(info.EngineWorkerQueuePort) {
			return nil, fmt.Errorf("invalid engine_worker_queue_port: %s", info.EngineWorkerQueuePort)
		}
	}
	for _, proto := range info.TransferProtocol {
		if !slices.Contains([]string{"ipc", "rdma"}, proto) {
			return nil, fmt.Errorf("invalid protocol: %s", proto)
		}
	}
	// Only run RDMA port validation when the result is actually enforced;
	// the original evaluated it unconditionally and discarded the result
	// outside splitwise mode.
	if splitwise {
		if err := validatePortList("rdma_ports", info.RDMAPorts); err != nil {
			return nil, err
		}
	}
	if info.MetricsPort == "" {
		// Fall back to the serving port when no dedicated metrics port is given.
		info.MetricsPort = info.Port
	} else if !isValidPort(info.MetricsPort) {
		return nil, fmt.Errorf("invalid metrics_port: %s", info.MetricsPort)
	}
	return info, nil
}
@@ -0,0 +1,304 @@
package manager
import (
"encoding/json"
"testing"
"github.com/PaddlePaddle/FastDeploy/router/internal/config"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v3"
)
func TestInstanceRole_String(t *testing.T) {
assert.Equal(t, "mixed", MIXED.String())
assert.Equal(t, "prefill", PREFILL.String())
assert.Equal(t, "decode", DECODE.String())
}
// ParseInstanceRole must be case-insensitive for the three known roles and
// reject anything else.
func TestParseInstanceRole(t *testing.T) {
	cases := []struct {
		label   string
		raw     string
		want    InstanceRole
		wantErr bool
	}{
		{"mixed lowercase", "mixed", MIXED, false},
		{"mixed uppercase", "MIXED", MIXED, false},
		{"prefill lowercase", "prefill", PREFILL, false},
		{"prefill uppercase", "PREFILL", PREFILL, false},
		{"decode lowercase", "decode", DECODE, false},
		{"decode uppercase", "DECODE", DECODE, false},
		{"invalid role", "invalid", -1, true},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			got, err := ParseInstanceRole(tc.raw)
			if tc.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, tc.want, got)
		})
	}
}
// Role JSON decoding: integers map to enum values, known strings to enums,
// unknown strings become custom roles, out-of-range integers fail.
func TestRole_UnmarshalJSON(t *testing.T) {
	t.Run("valid role from integer", func(t *testing.T) {
		var r Role
		assert.NoError(t, json.Unmarshal([]byte("0"), &r))
		assert.Equal(t, MIXED, r.EnumValue)
		assert.False(t, r.IsCustom)
	})
	t.Run("valid role from string", func(t *testing.T) {
		var r Role
		assert.NoError(t, json.Unmarshal([]byte(`"prefill"`), &r))
		assert.Equal(t, PREFILL, r.EnumValue)
		assert.False(t, r.IsCustom)
	})
	t.Run("custom role", func(t *testing.T) {
		var r Role
		assert.NoError(t, json.Unmarshal([]byte(`"custom-role"`), &r))
		assert.Equal(t, "custom-role", r.CustomName)
		assert.True(t, r.IsCustom)
	})
	t.Run("invalid integer", func(t *testing.T) {
		var r Role
		assert.Error(t, json.Unmarshal([]byte("99"), &r))
	})
}
// Role YAML decoding mirrors JSON: integer and string scalars map to enums.
func TestRole_UnmarshalYAML(t *testing.T) {
	t.Run("valid role from integer", func(t *testing.T) {
		var r Role
		assert.NoError(t, yaml.Unmarshal([]byte("1"), &r))
		assert.Equal(t, PREFILL, r.EnumValue)
		assert.False(t, r.IsCustom)
	})
	t.Run("valid role from string", func(t *testing.T) {
		var r Role
		assert.NoError(t, yaml.Unmarshal([]byte("decode"), &r))
		assert.Equal(t, DECODE, r.EnumValue)
		assert.False(t, r.IsCustom)
	})
}
// Role JSON encoding: both enum and custom roles serialize to their name.
func TestRole_MarshalJSON(t *testing.T) {
	cases := []struct {
		label string
		role  Role
		want  string
	}{
		{"standard role", Role{EnumValue: MIXED, IsCustom: false}, `"mixed"`},
		{"custom role", Role{CustomName: "custom", IsCustom: true}, `"custom"`},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			out, err := json.Marshal(tc.role)
			assert.NoError(t, err)
			assert.Equal(t, tc.want, string(out))
		})
	}
}
// Port JSON decoding accepts both numeric and string forms.
func TestPort_UnmarshalJSON(t *testing.T) {
	t.Run("port as integer", func(t *testing.T) {
		var p Port
		assert.NoError(t, json.Unmarshal([]byte("8080"), &p))
		assert.Equal(t, Port("8080"), p)
	})
	t.Run("port as string", func(t *testing.T) {
		var p Port
		assert.NoError(t, json.Unmarshal([]byte(`"9090"`), &p))
		assert.Equal(t, Port("9090"), p)
	})
}
// Port YAML decoding accepts a bare integer scalar.
func TestPort_UnmarshalYAML(t *testing.T) {
	t.Run("port as integer", func(t *testing.T) {
		var p Port
		assert.NoError(t, yaml.Unmarshal([]byte("8080"), &p))
		assert.Equal(t, Port("8080"), p)
	})
}
// IntToStringList JSON decoding: strings pass through, whole numbers are
// stringified, fractional floats and non-numeric types are rejected.
func TestIntToStringList_UnmarshalJSON(t *testing.T) {
	t.Run("mixed types", func(t *testing.T) {
		var got IntToStringList
		assert.NoError(t, json.Unmarshal([]byte(`["1", 2, 3.0]`), &got))
		assert.Equal(t, IntToStringList{"1", "2", "3"}, got)
	})
	t.Run("invalid float", func(t *testing.T) {
		var got IntToStringList
		assert.Error(t, json.Unmarshal([]byte(`[1.5]`), &got))
	})
	t.Run("invalid type", func(t *testing.T) {
		var got IntToStringList
		assert.Error(t, json.Unmarshal([]byte(`[true]`), &got))
	})
}
// URL() must prepend http:// to a bare host:port pair.
func TestInstanceInfo_URL(t *testing.T) {
	instance := &InstanceInfo{HostIP: "127.0.0.1", Port: Port("8080")}
	assert.Equal(t, "http://127.0.0.1:8080", instance.URL())
}
// TestNewInstanceInfo exercises every validation branch of NewInstanceInfo.
// NOTE(review): all subtests depend on the package-global DefaultManager
// being put into splitwise mode by the Init call below; the splitwise-only
// branches (connector/rdma ports) would not error otherwise.
func TestNewInstanceInfo(t *testing.T) {
	// Setup DefaultManager
	Init(&config.Config{Server: config.ServerConfig{Splitwise: true}})
	t.Run("valid instance", func(t *testing.T) {
		info := &InstanceInfo{
			Role: Role{EnumValue: PREFILL, IsSet: true},
			HostIP: "127.0.0.1",
			Port: Port("8080"),
			ConnectorPort: Port("9000"),
			TransferProtocol: []string{"rdma"},
			RDMAPorts: IntToStringList{"5000", "5001"},
			DeviceIDs: IntToStringList{"0", "1"},
		}
		result, err := NewInstanceInfo(info)
		assert.NoError(t, err)
		assert.NotNil(t, result)
	})
	t.Run("missing role", func(t *testing.T) {
		info := &InstanceInfo{
			HostIP: "127.0.0.1",
			Port: Port("8080"),
		}
		_, err := NewInstanceInfo(info)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "role is required")
	})
	t.Run("custom role", func(t *testing.T) {
		// Custom (non-enum) roles are rejected by NewInstanceInfo.
		info := &InstanceInfo{
			Role: Role{CustomName: "custom", IsCustom: true, IsSet: true},
			HostIP: "127.0.0.1",
			Port: Port("8080"),
		}
		_, err := NewInstanceInfo(info)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "invalid role")
	})
	t.Run("missing host_ip", func(t *testing.T) {
		info := &InstanceInfo{
			Role: Role{EnumValue: MIXED, IsSet: true},
			Port: Port("8080"),
		}
		_, err := NewInstanceInfo(info)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "host_ip is required")
	})
	t.Run("invalid host_ip", func(t *testing.T) {
		info := &InstanceInfo{
			Role: Role{EnumValue: MIXED, IsSet: true},
			HostIP: "invalid-ip",
			Port: Port("8080"),
		}
		_, err := NewInstanceInfo(info)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "invalid host_ip")
	})
	t.Run("invalid port", func(t *testing.T) {
		// 99999 is above the 65535 maximum.
		info := &InstanceInfo{
			Role: Role{EnumValue: MIXED, IsSet: true},
			HostIP: "127.0.0.1",
			Port: Port("99999"),
		}
		_, err := NewInstanceInfo(info)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "invalid port")
	})
	t.Run("invalid connector_port", func(t *testing.T) {
		// Only enforced because splitwise mode is enabled above.
		info := &InstanceInfo{
			Role: Role{EnumValue: PREFILL, IsSet: true},
			HostIP: "127.0.0.1",
			Port: Port("8080"),
			ConnectorPort: Port("99999"),
		}
		_, err := NewInstanceInfo(info)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "invalid connector_port")
	})
	t.Run("invalid transfer protocol", func(t *testing.T) {
		info := &InstanceInfo{
			Role: Role{EnumValue: PREFILL, IsSet: true},
			HostIP: "127.0.0.1",
			Port: Port("8080"),
			TransferProtocol: []string{"invalid"},
		}
		_, err := NewInstanceInfo(info)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "invalid protocol")
	})
	t.Run("invalid rdma port", func(t *testing.T) {
		info := &InstanceInfo{
			Role: Role{EnumValue: PREFILL, IsSet: true},
			HostIP: "127.0.0.1",
			Port: Port("8080"),
			TransferProtocol: []string{"rdma"},
			RDMAPorts: IntToStringList{"99999"},
		}
		_, err := NewInstanceInfo(info)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "rdma_ports[0] invalid port")
	})
}
// Boundary checks for isValidPort: the valid range is [1, 65535].
func TestIsValidPort(t *testing.T) {
	for _, good := range []string{"8080", "1", "65535"} {
		assert.True(t, isValidPort(Port(good)), good)
	}
	for _, bad := range []string{"0", "65536", "invalid", ""} {
		assert.False(t, isValidPort(Port(bad)), bad)
	}
}
// isValidIP accepts IPv4 and IPv6 literals and rejects everything else.
func TestIsValidIP(t *testing.T) {
	for _, good := range []string{"127.0.0.1", "192.168.1.1", "::1"} {
		assert.True(t, isValidIP(good), good)
	}
	for _, bad := range []string{"invalid", ""} {
		assert.False(t, isValidIP(bad), bad)
	}
}
@@ -0,0 +1,109 @@
package manager
import (
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"regexp"
"strconv"
scheduler_handler "github.com/PaddlePaddle/FastDeploy/router/internal/scheduler/handler"
)
// Precompile regex to avoid repeated compilation
var (
	// Matches the `fastdeploy:num_requests_running <value>` metric line.
	runningRequestsRegex = regexp.MustCompile(`fastdeploy:num_requests_running\s+([0-9.]+)`)
	// Matches the `fastdeploy:num_requests_waiting <value>` metric line.
	waitingRequestsRegex = regexp.MustCompile(`fastdeploy:num_requests_waiting\s+([0-9.]+)`)
	// Matches the `available_gpu_block_num <value>` metric line.
	availableGpuBlockNumRegex = regexp.MustCompile(`available_gpu_block_num\s+([0-9.]+)`)
)

// extractMetricValue returns the first captured numeric value of re in
// response, or -1 when the metric is absent or its value fails to parse.
func extractMetricValue(re *regexp.Regexp, response string) float64 {
	if matches := re.FindStringSubmatch(response); len(matches) >= 2 {
		if score, err := strconv.ParseFloat(matches[1], 64); err == nil {
			return score
		}
	}
	return -1
}

// parseMetricsResponseOptimized extracts (running, waiting, availableGpuBlockNum)
// from a raw /metrics response body. Each value is -1 when its metric line is
// missing or unparsable, letting callers distinguish "absent" from zero.
// The previous version repeated the match-and-parse sequence three times;
// it is factored into extractMetricValue.
func parseMetricsResponseOptimized(response string) (float64, float64, float64) {
	runningCnt := extractMetricValue(runningRequestsRegex, response)
	waitingCnt := extractMetricValue(waitingRequestsRegex, response)
	availableGpuBlockNum := extractMetricValue(availableGpuBlockNumRegex, response)
	return runningCnt, waitingCnt, availableGpuBlockNum
}
// redrictCounter gets or creates a counter for the given URL and returns current count.
// Callers use it as a fallback running-request estimate when /metrics is
// unreachable (see Manager.GetMetrics). The name "redrict" looks like a
// historical typo of "redirect"; it is kept because it is the exported-in-file
// contract other code calls.
func redrictCounter(ctx context.Context, rawURL string) int {
	counter := scheduler_handler.GetOrCreateCounter(ctx, rawURL)
	return int(counter.Get())
}
// GetMetricsByURL retrieves (running, waiting, availableGpuBlockNum) for the
// worker registered under rawURL by scraping /metrics on the worker's
// dedicated metrics port. It returns an error when the worker is unknown,
// the URL is malformed, the HTTP request fails or returns a non-200 status,
// or any of the three metrics is missing from the response body.
func GetMetricsByURL(ctx context.Context, rawURL string) (int, int, int, error) {
	workerInfo := getWorkerInfo(ctx, rawURL)
	if workerInfo == nil {
		return 0, 0, 0, errors.New("worker info not found for URL")
	}
	u, err := url.Parse(rawURL)
	if err != nil {
		return 0, 0, 0, err
	}
	host, _, err := net.SplitHostPort(u.Host)
	if err != nil {
		return 0, 0, 0, err
	}
	// Metrics are served on a separate port from the serving endpoint.
	u.Host = net.JoinHostPort(host, workerInfo.MetricsPort)
	metricsUrl := fmt.Sprintf("%s/metrics", u.String())
	client := &http.Client{Timeout: defaultCheckTimeout}
	resp, err := client.Get(metricsUrl)
	if err != nil {
		return 0, 0, 0, err
	}
	defer resp.Body.Close()
	// Fail fast on non-200 responses; previously an error page body would be
	// handed to the metrics parser as if it were a valid scrape.
	if resp.StatusCode != http.StatusOK {
		return 0, 0, 0, fmt.Errorf("metrics endpoint returned status %d", resp.StatusCode)
	}
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return 0, 0, 0, err
	}
	if len(body) == 0 {
		return 0, 0, 0, errors.New("metrics response is empty")
	}
	// Parse metrics response; any -1 means a required metric was absent.
	runningCnt, waitingCnt, availableGpuBlockNum := parseMetricsResponseOptimized(string(body))
	if runningCnt < 0 || waitingCnt < 0 || availableGpuBlockNum < 0 {
		return 0, 0, 0, errors.New("failed to parse metrics response")
	}
	return int(runningCnt), int(waitingCnt), int(availableGpuBlockNum), nil
}
// GetMetrics returns (running, waiting, availableGpuBlockNum) for rawURL.
// When the metrics scrape fails it degrades gracefully: the running count is
// approximated by the per-URL redirect counter and the other values are zero.
func (m *Manager) GetMetrics(ctx context.Context, rawURL string) (int, int, int) {
	running, waiting, gpuBlocks, err := GetMetricsByURL(ctx, rawURL)
	if err == nil {
		return running, waiting, gpuBlocks
	}
	return redrictCounter(ctx, rawURL), 0, 0
}
@@ -0,0 +1,357 @@
package manager
import (
"context"
"net"
"net/http"
"net/http/httptest"
"reflect"
"strconv"
"testing"
"github.com/PaddlePaddle/FastDeploy/router/internal/config"
scheduler_handler "github.com/PaddlePaddle/FastDeploy/router/internal/scheduler/handler"
"github.com/stretchr/testify/assert"
)
// Helper function to get map keys
func getMapKeys(m interface{}) []string {
v := reflect.ValueOf(m)
if v.Kind() != reflect.Map {
return nil
}
keys := v.MapKeys()
result := make([]string, len(keys))
for i, key := range keys {
result[i] = key.String()
}
return result
}
// TestRedrictCounter verifies the per-URL fallback counters that back
// redrictCounter. NOTE(review): counters live in package-global state inside
// scheduler_handler, so these subtests are order-sensitive and must each use
// distinct URLs to stay independent.
func TestRedrictCounter(t *testing.T) {
	// Initialize scheduler for counter tests
	cfg := &config.Config{
		Scheduler: config.SchedulerConfig{
			Policy: "random",
			PrefillPolicy: "random",
			DecodePolicy: "random",
			WaitingWeight: 1.0,
		},
	}
	// Using nil for managerAPI since we're only testing counters
	scheduler_handler.Init(cfg, nil)
	// Setup test context
	ctx := context.Background()
	// Test case 1: First call with new URL should return 0
	t.Run("new_url_returns_zero", func(t *testing.T) {
		count := redrictCounter(ctx, "http://new-service-1")
		assert.Equal(t, 0, count)
	})
	// Test case 2: Multiple calls with same URL should return incremented count
	t.Run("same_url_increments", func(t *testing.T) {
		url := "http://same-service"
		// First call should return 0
		count1 := redrictCounter(ctx, url)
		assert.Equal(t, 0, count1)
		// Simulate counter increment by calling GetOrCreateCounter and incrementing
		counter := scheduler_handler.GetOrCreateCounter(ctx, url)
		counter.Inc()
		// Second call should return incremented count
		count2 := redrictCounter(ctx, url)
		assert.Equal(t, 1, count2)
	})
	// Test case 3: Different URLs should have independent counters
	t.Run("different_urls_independent", func(t *testing.T) {
		url1 := "http://service-a"
		url2 := "http://service-b"
		// Call first URL
		count1 := redrictCounter(ctx, url1)
		assert.Equal(t, 0, count1)
		// Call second URL
		count2 := redrictCounter(ctx, url2)
		assert.Equal(t, 0, count2)
		// Increment first URL's counter
		counter1 := scheduler_handler.GetOrCreateCounter(ctx, url1)
		counter1.Inc()
		counter1.Inc()
		// Verify first URL shows incremented count
		assert.Equal(t, 2, redrictCounter(ctx, url1))
		// Second URL should still be 0
		assert.Equal(t, 0, redrictCounter(ctx, url2))
	})
	// Test case 4: Empty URL should work (edge case)
	t.Run("empty_url", func(t *testing.T) {
		count := redrictCounter(ctx, "")
		assert.Equal(t, 0, count)
	})
	// Test case 5: Nil context should work (though not recommended)
	// NOTE(review): despite the name, this subtest passes a real context, not
	// nil — the nil-context path is never actually exercised.
	t.Run("nil_context", func(t *testing.T) {
		count := redrictCounter(ctx, "http://nil-context-test")
		assert.Equal(t, 0, count)
	})
}
// Missing metrics must surface as -1 so callers can distinguish "absent"
// from a genuine zero reading.
func TestParseMetricsResponseOptimized(t *testing.T) {
	cases := []struct {
		label          string
		body           string
		run, wait, gpu float64
	}{
		{
			label: "valid metrics response",
			body: `fastdeploy:num_requests_running 10
fastdeploy:num_requests_waiting 5
available_gpu_block_num 3`,
			run: 10, wait: 5, gpu: 3,
		},
		{
			label: "partial metrics response",
			body: `fastdeploy:num_requests_running 8
available_gpu_block_num 2`,
			run: 8, wait: -1, gpu: 2,
		},
		{label: "empty response", body: "", run: -1, wait: -1, gpu: -1},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			run, wait, gpu := parseMetricsResponseOptimized(tc.body)
			assert.Equal(t, tc.run, run)
			assert.Equal(t, tc.wait, wait)
			assert.Equal(t, tc.gpu, gpu)
		})
	}
}
// TestGetMetricsByURL_Integration exercises GetMetricsByURL against a mock
// /metrics server, routing real scrapes to the httptest listener via the
// registered worker's MetricsPort.
// NOTE(review): the scrape URL built by GetMetricsByURL carries no query
// string, so only the "default" switch branch of the mock server is ever
// reached from a registered worker; the other scenarios are effectively
// dead here — confirm whether they were meant to be exercised.
func TestGetMetricsByURL_Integration(t *testing.T) {
	// Initialize manager for testing
	Init(&config.Config{})
	// Test with mock HTTP server
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/metrics" {
			w.WriteHeader(http.StatusNotFound)
			return
		}
		switch r.URL.Query().Get("scenario") {
		case "valid":
			w.Write([]byte(`fastdeploy:num_requests_running 5
fastdeploy:num_requests_waiting 3
available_gpu_block_num 10`))
		case "partial":
			w.Write([]byte(`fastdeploy:num_requests_running 2
available_gpu_block_num 5`))
		case "empty":
			// Empty response body
			w.WriteHeader(http.StatusOK)
		case "error":
			w.WriteHeader(http.StatusInternalServerError)
		default:
			w.Write([]byte(`fastdeploy:num_requests_running 1
fastdeploy:num_requests_waiting 0
available_gpu_block_num 8`))
		}
	}))
	defer server.Close()
	// Test worker not found
	t.Run("worker_not_found", func(t *testing.T) {
		_, _, _, err := GetMetricsByURL(context.Background(), "http://nonexistent-worker:8080")
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "worker info not found")
	})
	// Test with registered worker
	t.Run("with_registered_worker", func(t *testing.T) {
		// Register a test worker with localhost to avoid DNS resolution issues
		workerURL := "http://localhost:8080"
		DefaultManager.mu.Lock()
		DefaultManager.prefillWorkerMap[workerURL] = &WorkerInfo{
			Url: workerURL,
			WorkerType: "prefill",
			MetricsPort: strconv.Itoa(server.Listener.Addr().(*net.TCPAddr).Port),
		}
		t.Logf("Registered worker: URL=%s, MetricsPort=%s", workerURL, strconv.Itoa(server.Listener.Addr().(*net.TCPAddr).Port))
		DefaultManager.mu.Unlock()
		// Debug: check what URLs are in the map
		DefaultManager.mu.RLock()
		t.Logf("prefillWorkerMap keys: %v", getMapKeys(DefaultManager.prefillWorkerMap))
		DefaultManager.mu.RUnlock()
		t.Logf("Looking up worker for URL: %s", workerURL)
		// Test valid metrics response - the test server should handle scenarios based on the request
		running, waiting, gpu, err := GetMetricsByURL(context.Background(), workerURL)
		if err != nil {
			t.Logf("Error: %v", err)
		}
		assert.NoError(t, err)
		assert.Equal(t, 1, running) // Default scenario in test server
		assert.Equal(t, 0, waiting) // Default scenario in test server
		assert.Equal(t, 8, gpu)     // Default scenario in test server
	})
	// Test invalid URL
	t.Run("invalid_url", func(t *testing.T) {
		_, _, _, err := GetMetricsByURL(context.Background(), "http://invalid url")
		assert.Error(t, err)
	})
	// Test partial metrics
	// NOTE(review): this subtest also ends up in the default scenario (see
	// note above) — the "partial" branch is not actually hit.
	t.Run("partial_metrics", func(t *testing.T) {
		workerURL := "http://localhost:8081" // Different port
		DefaultManager.mu.Lock()
		DefaultManager.prefillWorkerMap[workerURL] = &WorkerInfo{
			Url: workerURL,
			WorkerType: "prefill",
			MetricsPort: strconv.Itoa(server.Listener.Addr().(*net.TCPAddr).Port),
		}
		t.Logf("Registered worker: URL=%s, MetricsPort=%s", workerURL, strconv.Itoa(server.Listener.Addr().(*net.TCPAddr).Port))
		DefaultManager.mu.Unlock()
		t.Logf("Looking up worker for URL: %s", workerURL)
		running, waiting, gpu, err := GetMetricsByURL(context.Background(), workerURL)
		if err != nil {
			t.Logf("Error: %v", err)
		}
		assert.NoError(t, err)
		assert.Equal(t, 1, running) // Default scenario
		assert.Equal(t, 0, waiting) // Default scenario
		assert.Equal(t, 8, gpu)     // Default scenario
	})
}
// TestManagerGetMetrics_Integration covers both paths of Manager.GetMetrics:
// a successful scrape through a registered worker, and the fallback to the
// redirect counter when the worker is unknown. Depends on the global
// DefaultManager and scheduler_handler state initialized below.
func TestManagerGetMetrics_Integration(t *testing.T) {
	// Initialize manager for testing
	Init(&config.Config{})
	// Initialize scheduler for counter tests
	cfg := &config.Config{
		Scheduler: config.SchedulerConfig{
			Policy: "random",
			PrefillPolicy: "random",
			DecodePolicy: "random",
			WaitingWeight: 1.0,
		},
	}
	scheduler_handler.Init(cfg, nil)
	m := &Manager{}
	// Setup mock HTTP server
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/metrics" {
			w.WriteHeader(http.StatusNotFound)
			return
		}
		w.Write([]byte(`fastdeploy:num_requests_running 2
fastdeploy:num_requests_waiting 1
available_gpu_block_num 5`))
	}))
	defer server.Close()
	// Test normal case with registered worker
	t.Run("normal_case", func(t *testing.T) {
		workerURL := "http://localhost:8080" // Use localhost to avoid DNS lookup
		DefaultManager.mu.Lock()
		DefaultManager.prefillWorkerMap[workerURL] = &WorkerInfo{
			Url: workerURL,
			WorkerType: "prefill",
			MetricsPort: strconv.Itoa(server.Listener.Addr().(*net.TCPAddr).Port),
		}
		t.Logf("Registered worker with MetricsPort: %s", strconv.Itoa(server.Listener.Addr().(*net.TCPAddr).Port))
		DefaultManager.mu.Unlock()
		// Debug: test GetMetricsByURL directly first
		t.Logf("Testing GetMetricsByURL directly...")
		running, waiting, gpu, err := GetMetricsByURL(context.Background(), workerURL)
		if err != nil {
			t.Logf("GetMetricsByURL failed: %v", err)
		} else {
			t.Logf("GetMetricsByURL succeeded: running=%d, waiting=%d, gpu=%d", running, waiting, gpu)
		}
		// Now test Manager.GetMetrics
		running, waiting, gpu = m.GetMetrics(context.Background(), workerURL)
		t.Logf("Manager.GetMetrics result: running=%d, waiting=%d, gpu=%d", running, waiting, gpu)
		assert.Equal(t, 2, running)
		assert.Equal(t, 1, waiting)
		assert.Equal(t, 5, gpu)
	})
	// Test error case (should fall back to counter)
	t.Run("error_fallback", func(t *testing.T) {
		// Use a URL that doesn't have a registered worker
		workerURL := "http://unknown-worker:8080"
		running, waiting, gpu := m.GetMetrics(context.Background(), workerURL)
		// Should fall back to counter (which should be 0 for new URL)
		assert.Equal(t, 0, running) // Should be 0 for new counter
		assert.Equal(t, 0, waiting) // Should be 0 in error case
		assert.Equal(t, 0, gpu)     // Should be 0 in error case
	})
}
// Same coverage as TestParseMetricsResponseOptimized, expressed with integer
// expectations converted through float64.
func TestMetricsParsingHelper(t *testing.T) {
	cases := []struct {
		label string
		body  string
		want  [3]int
	}{
		{
			label: "complete_metrics",
			body: `fastdeploy:num_requests_running 5
fastdeploy:num_requests_waiting 3
available_gpu_block_num 10`,
			want: [3]int{5, 3, 10},
		},
		{
			label: "missing_waiting",
			body: `fastdeploy:num_requests_running 2
available_gpu_block_num 5`,
			want: [3]int{2, -1, 5},
		},
		{label: "empty_body", body: "", want: [3]int{-1, -1, -1}},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			run, wait, gpu := parseMetricsResponseOptimized(tc.body)
			assert.Equal(t, float64(tc.want[0]), run)
			assert.Equal(t, float64(tc.want[1]), wait)
			assert.Equal(t, float64(tc.want[2]), gpu)
		})
	}
}
@@ -0,0 +1,333 @@
package manager
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"slices"
"github.com/PaddlePaddle/FastDeploy/router/pkg/logger"
"github.com/gin-gonic/gin"
"gopkg.in/yaml.v3"
)
// RegisterConfig is the YAML schema for bulk instance registration files:
// a top-level "instances" list of InstanceInfo entries, consumed by
// RegisterInstancesFromConfig.
type RegisterConfig struct {
	Instances []InstanceInfo `yaml:"instances"`
}
// GetSplitwise reports whether the router runs in splitwise (prefill/decode
// disaggregated) mode; false when the manager has not been initialized.
func GetSplitwise(ctx context.Context) bool {
	if DefaultManager == nil {
		return false
	}
	DefaultManager.mu.RLock()
	enabled := DefaultManager.splitwise
	DefaultManager.mu.RUnlock()
	return enabled
}
// GetAllMapServers returns a merged snapshot of every registered worker
// (prefill, decode and mixed) keyed by instance id. Always non-nil; empty
// when the manager is uninitialized.
func GetAllMapServers(ctx context.Context) map[string]*WorkerInfo {
	merged := make(map[string]*WorkerInfo)
	if DefaultManager == nil {
		return merged
	}
	DefaultManager.mu.RLock()
	defer DefaultManager.mu.RUnlock()
	for _, pool := range []map[string]*WorkerInfo{
		DefaultManager.prefillWorkerMap,
		DefaultManager.decodeWorkerMap,
		DefaultManager.mixedWorkerMap,
	} {
		for id, w := range pool {
			merged[id] = w
		}
	}
	return merged
}
// getWorkerInfo looks url up across the prefill, decode and mixed worker
// pools, returning nil when the URL is unknown or the manager uninitialized.
func getWorkerInfo(ctx context.Context, url string) *WorkerInfo {
	if DefaultManager == nil {
		return nil
	}
	DefaultManager.mu.RLock()
	defer DefaultManager.mu.RUnlock()
	for _, pool := range []map[string]*WorkerInfo{
		DefaultManager.prefillWorkerMap,
		DefaultManager.decodeWorkerMap,
		DefaultManager.mixedWorkerMap,
	} {
		if w, ok := pool[url]; ok {
			return w
		}
	}
	return nil
}
// BuildDisaggregateInfo assembles the disaggregate_info payload shared
// between a prefill and a decode worker. IPC transfer is selected only when
// both workers sit on the same host, both advertise "ipc", and their
// tensor-parallel sizes are compatible; otherwise RDMA is used.
func BuildDisaggregateInfo(ctx context.Context, prefillURL, decodeURL string) (map[string]any, error) {
	prefill := getWorkerInfo(ctx, prefillURL)
	decode := getWorkerInfo(ctx, decodeURL)
	if prefill == nil || decode == nil {
		return nil, fmt.Errorf("worker instance not found for prefill=%s, decode=%s", prefillURL, decodeURL)
	}
	prefillHost := hostFromURL(prefill.Url)
	decodeHost := hostFromURL(decode.Url)
	tpPrefill := tpSizeFromWorker(prefill)
	tpDecode := tpSizeFromWorker(decode)
	// IPC requires colocation, mutual protocol support, and a compatible
	// tensor-parallel layout (equal sizes, or decode running tp=1).
	sameNode := prefillHost != "" && prefillHost == decodeHost
	bothSupportIPC := slices.Contains(prefill.TransferProtocol, "ipc") &&
		slices.Contains(decode.TransferProtocol, "ipc")
	tpCompatible := tpPrefill == tpDecode || tpDecode == 1
	transferProto := "rdma"
	if sameNode && bothSupportIPC && tpCompatible {
		transferProto = "ipc"
	}
	return map[string]any{
		"prefill_ip":             prefillHost,
		"decode_ip":              decodeHost,
		"prefill_connector_port": portStringToInt(Port(prefill.ConnectorPort)),
		"decode_connector_port":  portStringToInt(Port(decode.ConnectorPort)),
		"decode_device_ids":      []string(decode.DeviceIDs),
		"decode_rdma_ports":      []string(decode.RdmaPorts),
		"transfer_protocol":      transferProto,
		"decode_tp_size":         tpDecode,
	}, nil
}
// portStringToInt converts a Port to its integer value; empty or
// non-numeric ports map to 0.
func portStringToInt(p Port) int {
	if n, err := strconv.Atoi(string(p)); err == nil {
		return n
	}
	return 0
}
// tpSizeFromWorker derives the tensor-parallel size from the number of
// device IDs (minimum 1, as there is no explicit tp_size field yet);
// a nil worker yields 0.
func tpSizeFromWorker(w *WorkerInfo) int {
	switch {
	case w == nil:
		return 0
	case len(w.DeviceIDs) > 0:
		return len(w.DeviceIDs)
	default:
		return 1
	}
}
// hostFromURL extracts the bare hostname (no port, no scheme) from raw,
// defaulting to an http:// scheme so schemeless host:port strings still
// parse. Returns "" for empty or unparsable input.
func hostFromURL(raw string) string {
	if raw == "" {
		return ""
	}
	hasScheme := strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://")
	if !hasScheme {
		raw = "http://" + raw
	}
	parsed, err := url.Parse(raw)
	if err != nil {
		return ""
	}
	return parsed.Hostname()
}
// RegisterInstanceCore validates, health-checks and registers a worker
// instance into the role map matching its role. Re-registration of an
// existing id first removes it from whichever role map held it.
// Returns an error when validation fails, the role is incompatible with the
// current splitwise mode, or the instance fails its health check.
func RegisterInstanceCore(ctx context.Context, rawInstance *InstanceInfo) error {
	instance, err := NewInstanceInfo(rawInstance)
	if err != nil {
		return fmt.Errorf("invalid InstanceInfo format:%v", err)
	}
	splitwiseMode := GetSplitwise(ctx)
	instanceRole := instance.Role.EnumValue
	// Splitwise deployments accept only PREFILL/DECODE roles; otherwise only MIXED.
	if splitwiseMode && instanceRole == MIXED {
		return fmt.Errorf("splitwise mode only supports PREFILL/DECODE instances")
	}
	if !splitwiseMode && instanceRole != MIXED {
		return fmt.Errorf("only MIXED instances are allowed")
	}
	// Check instance health status
	if !CheckWorkerHealth(ctx, instance.URL()) {
		return fmt.Errorf("service is not healthy")
	}
	// NOTE(review): this snapshot is taken under the read lock and released
	// before the write lock below, so a concurrent registration could slip in
	// between — confirm whether callers serialize registration.
	allServers := GetAllMapServers(ctx)
	DefaultManager.mu.Lock()
	defer DefaultManager.mu.Unlock()
	workerInfo := &WorkerInfo{
		Url: instance.URL(),
		WorkerType: instance.Role.EnumValue.String(),
		ConnectorPort: string(instance.ConnectorPort),
		EngineWorkerQueuePort: string(instance.EngineWorkerQueuePort),
		TransferProtocol: instance.TransferProtocol,
		RdmaPorts: []string(instance.RDMAPorts),
		DeviceIDs: []string(instance.DeviceIDs),
		MetricsPort: string(instance.MetricsPort),
	}
	id := instance.URL()
	// Drop any previous registration of this id from its old role map so a
	// role change does not leave a stale duplicate behind.
	if w, exists := allServers[id]; exists {
		wType, err := ParseInstanceRole(w.WorkerType)
		if err == nil {
			switch wType {
			case MIXED:
				delete(DefaultManager.mixedWorkerMap, id)
			case PREFILL:
				delete(DefaultManager.prefillWorkerMap, id)
			case DECODE:
				delete(DefaultManager.decodeWorkerMap, id)
			}
		}
	}
	switch instanceRole {
	case MIXED:
		DefaultManager.mixedWorkerMap[id] = workerInfo
	case PREFILL:
		DefaultManager.prefillWorkerMap[id] = workerInfo
	case DECODE:
		DefaultManager.decodeWorkerMap[id] = workerInfo
	default:
		logger.Warn("Instance %s role is unknown", id)
	}
	return nil
}
// RegisterInstance is the gin handler for worker registration: it decodes an
// InstanceInfo from the JSON body, delegates to RegisterInstanceCore, and
// maps failures to 400 (bad input) or 503 (instance unhealthy).
func RegisterInstance(c *gin.Context) {
	payload, readErr := io.ReadAll(c.Request.Body)
	if readErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"code": 400,
			"msg":  "Invalid request body",
		})
		return
	}
	var rawInstance InstanceInfo
	if unmarshalErr := json.Unmarshal(payload, &rawInstance); unmarshalErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"code": 400,
			"msg":  fmt.Sprintf("Invalid InstanceInfo JSON format: %v", unmarshalErr),
		})
		return
	}
	if regErr := RegisterInstanceCore(c.Request.Context(), &rawInstance); regErr != nil {
		logger.Error("Failed to register instance: %v", regErr)
		// Health failures surface as 503 so callers know to retry later;
		// everything else is a 400.
		status, code := http.StatusBadRequest, 400
		if strings.Contains(regErr.Error(), "not healthy") {
			status, code = http.StatusServiceUnavailable, 503
		}
		c.JSON(status, gin.H{
			"code": code,
			"msg":  regErr.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"code": 200,
		"msg":  "Register success",
	})
}
// RegisterInstancesFromConfig bulk-registers instances from a YAML file of
// the RegisterConfig shape. Failures are logged per instance and never abort
// the remaining registrations; an empty path is a no-op.
func RegisterInstancesFromConfig(yamlPath string) {
	if yamlPath == "" {
		return
	}
	raw, readErr := os.ReadFile(yamlPath)
	if readErr != nil {
		logger.Error("Failed to read YAML file %s: %v", yamlPath, readErr)
		return
	}
	var cfg RegisterConfig
	if parseErr := yaml.Unmarshal(raw, &cfg); parseErr != nil {
		logger.Error("Failed to unmarshal YAML file %s: %v", yamlPath, parseErr)
		return
	}
	if len(cfg.Instances) == 0 {
		logger.Info("No instances found in config file %s", yamlPath)
		return
	}
	for i := range cfg.Instances {
		if err := RegisterInstanceCore(context.Background(), &cfg.Instances[i]); err != nil {
			logger.Error("Failed to register instance from index %d: %v", i, err)
		} else {
			logger.Info("Successfully registered instance from index %d", i)
		}
	}
}
// RegisteredNumber reports the number of registered workers per role,
// answering 400 when the manager has not been initialized.
func RegisteredNumber(c *gin.Context) {
	if DefaultManager == nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"code": 400,
			"msg":  "DefaultManager is nil",
		})
		return
	}
	DefaultManager.mu.RLock()
	defer DefaultManager.mu.RUnlock()
	counts := gin.H{
		"mixed":   len(DefaultManager.mixedWorkerMap),
		"prefill": len(DefaultManager.prefillWorkerMap),
		"decode":  len(DefaultManager.decodeWorkerMap),
	}
	c.JSON(http.StatusOK, counts)
}
// Registered lists every registered worker grouped by role, returning value
// copies of the WorkerInfo entries. NOTE: unlike RegisteredNumber, this
// handler assumes DefaultManager has been initialized.
func Registered(c *gin.Context) {
	DefaultManager.mu.RLock()
	defer DefaultManager.mu.RUnlock()
	// collect copies a pool's workers into a (never-nil) slice for JSON output.
	collect := func(pool map[string]*WorkerInfo) []WorkerInfo {
		out := make([]WorkerInfo, 0, len(pool))
		for _, w := range pool {
			out = append(out, *w)
		}
		return out
	}
	c.JSON(http.StatusOK, gin.H{
		"code":    http.StatusOK,
		"msg":     "success",
		"decode":  collect(DefaultManager.decodeWorkerMap),
		"prefill": collect(DefaultManager.prefillWorkerMap),
		"mixed":   collect(DefaultManager.mixedWorkerMap),
	})
}
@@ -0,0 +1,395 @@
package manager
import (
"bytes"
"context"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"github.com/PaddlePaddle/FastDeploy/router/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
// GetSplitwise must be nil-safe and mirror the configured server mode.
func TestGetSplitwise(t *testing.T) {
	t.Run("DefaultManager is nil", func(t *testing.T) {
		saved := DefaultManager
		DefaultManager = nil
		defer func() { DefaultManager = saved }()
		assert.False(t, GetSplitwise(context.Background()))
	})
	t.Run("splitwise mode enabled", func(t *testing.T) {
		Init(&config.Config{Server: config.ServerConfig{Splitwise: true}})
		assert.True(t, GetSplitwise(context.Background()))
	})
	t.Run("splitwise mode disabled", func(t *testing.T) {
		Init(&config.Config{Server: config.ServerConfig{Splitwise: false}})
		assert.False(t, GetSplitwise(context.Background()))
	})
}
// The merged snapshot must contain workers from all three role maps.
func TestGetAllMapServers(t *testing.T) {
	Init(&config.Config{})
	DefaultManager.prefillWorkerMap = map[string]*WorkerInfo{"worker1": {Url: "http://worker1"}}
	DefaultManager.decodeWorkerMap = map[string]*WorkerInfo{"worker2": {Url: "http://worker2"}}
	DefaultManager.mixedWorkerMap = map[string]*WorkerInfo{"worker3": {Url: "http://worker3"}}
	merged := GetAllMapServers(context.Background())
	assert.Len(t, merged, 3)
	for _, id := range []string{"worker1", "worker2", "worker3"} {
		assert.Contains(t, merged, id)
	}
}
// A nil manager must yield an empty (but non-nil) map, not a panic.
func TestGetAllMapServers_NilManager(t *testing.T) {
	saved := DefaultManager
	DefaultManager = nil
	defer func() { DefaultManager = saved }()
	merged := GetAllMapServers(context.Background())
	assert.NotNil(t, merged)
	assert.Len(t, merged, 0)
}
// TestGetWorkerInfo exercises worker lookup across the prefill and decode
// maps, plus the miss path.
func TestGetWorkerInfo(t *testing.T) {
	Init(&config.Config{})
	DefaultManager.prefillWorkerMap = map[string]*WorkerInfo{
		"http://worker1": {Url: "http://worker1", WorkerType: "prefill"},
	}
	DefaultManager.decodeWorkerMap = map[string]*WorkerInfo{
		"http://worker2": {Url: "http://worker2", WorkerType: "decode"},
	}
	t.Run("find prefill worker", func(t *testing.T) {
		got := getWorkerInfo(context.Background(), "http://worker1")
		if assert.NotNil(t, got) {
			assert.Equal(t, "prefill", got.WorkerType)
		}
	})
	t.Run("find decode worker", func(t *testing.T) {
		got := getWorkerInfo(context.Background(), "http://worker2")
		if assert.NotNil(t, got) {
			assert.Equal(t, "decode", got.WorkerType)
		}
	})
	t.Run("worker not found", func(t *testing.T) {
		assert.Nil(t, getWorkerInfo(context.Background(), "http://notfound"))
	})
}
// TestBuildDisaggregateInfo verifies that disaggregated (splitwise) transfer
// metadata is assembled from a registered prefill/decode worker pair, and
// that an unknown worker URL yields a "worker instance not found" error.
func TestBuildDisaggregateInfo(t *testing.T) {
	Init(&config.Config{Server: config.ServerConfig{Splitwise: true}})
	// Setup test workers: one prefill and one decode instance, both on RDMA
	// with two devices each (so decode_tp_size is expected to be 2 below).
	DefaultManager.prefillWorkerMap = map[string]*WorkerInfo{
		"http://127.0.0.1:8000": {
			Url:              "http://127.0.0.1:8000",
			WorkerType:       "prefill",
			ConnectorPort:    "9000",
			TransferProtocol: []string{"rdma"},
			DeviceIDs:        []string{"0", "1"},
			RdmaPorts:        []string{"5000", "5001"},
		},
	}
	DefaultManager.decodeWorkerMap = map[string]*WorkerInfo{
		"http://127.0.0.1:8001": {
			Url:              "http://127.0.0.1:8001",
			WorkerType:       "decode",
			ConnectorPort:    "9001",
			TransferProtocol: []string{"rdma"},
			DeviceIDs:        []string{"0", "1"},
			RdmaPorts:        []string{"5002", "5003"},
		},
	}
	t.Run("successful build", func(t *testing.T) {
		info, err := BuildDisaggregateInfo(context.Background(),
			"http://127.0.0.1:8000", "http://127.0.0.1:8001")
		assert.NoError(t, err)
		assert.NotNil(t, info)
		// IPs are extracted from the registered worker URLs.
		assert.Equal(t, "127.0.0.1", info["prefill_ip"])
		assert.Equal(t, "127.0.0.1", info["decode_ip"])
		assert.Equal(t, "rdma", info["transfer_protocol"])
		// TP size is derived from the decode worker's DeviceIDs length.
		assert.Equal(t, 2, info["decode_tp_size"])
	})
	t.Run("worker not found", func(t *testing.T) {
		_, err := BuildDisaggregateInfo(context.Background(),
			"http://notfound", "http://notfound2")
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "worker instance not found")
	})
}
// TestPortStringToInt covers numeric, empty, and non-numeric port strings;
// the last two are expected to fall back to zero.
func TestPortStringToInt(t *testing.T) {
	cases := map[string]int{"8080": 8080, "": 0, "invalid": 0}
	for input, want := range cases {
		assert.Equal(t, want, portStringToInt(Port(input)))
	}
}
func TestTpSizeFromWorker(t *testing.T) {
t.Run("worker with device IDs", func(t *testing.T) {
worker := &WorkerInfo{DeviceIDs: []string{"0", "1", "2"}}
assert.Equal(t, 3, tpSizeFromWorker(worker))
})
t.Run("worker without device IDs", func(t *testing.T) {
worker := &WorkerInfo{DeviceIDs: []string{}}
assert.Equal(t, 1, tpSizeFromWorker(worker))
})
t.Run("nil worker", func(t *testing.T) {
assert.Equal(t, 0, tpSizeFromWorker(nil))
})
}
func TestHostFromURL(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"standard URL", "http://127.0.0.1:8080", "127.0.0.1"},
{"HTTPS URL", "https://example.com:443", "example.com"},
{"URL without protocol", "127.0.0.1:8080", "127.0.0.1"},
{"empty URL", "", ""},
{"invalid URL", "://invalid", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := hostFromURL(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
// TestRegisterInstanceCore exercises registration validation: role rules
// differ between splitwise (PREFILL/DECODE only) and non-splitwise (MIXED
// only) configurations, and incomplete instance info must be rejected.
// Subtests share global state via Init, so their ordering matters.
func TestRegisterInstanceCore(t *testing.T) {
	// Setup test server
	// NOTE(review): ts is never wired into the instances below (they point at
	// 127.0.0.1:8080, not ts.URL); confirm whether RegisterInstanceCore makes
	// any HTTP call that should target this server.
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer ts.Close()
	Init(&config.Config{Server: config.ServerConfig{Splitwise: false}})
	t.Run("successful registration", func(t *testing.T) {
		instance := &InstanceInfo{
			Role:             Role{EnumValue: MIXED, IsSet: true},
			HostIP:           "127.0.0.1",
			Port:             Port("8080"),
			TransferProtocol: []string{"rdma"},
		}
		err := RegisterInstanceCore(context.Background(), instance)
		assert.NoError(t, err)
	})
	t.Run("invalid instance info", func(t *testing.T) {
		instance := &InstanceInfo{
			Role: Role{EnumValue: MIXED, IsSet: true},
			// Missing required fields
		}
		err := RegisterInstanceCore(context.Background(), instance)
		assert.Error(t, err)
	})
	t.Run("splitwise mode with mixed instance", func(t *testing.T) {
		// Re-Init flips the global manager into splitwise mode for this case.
		Init(&config.Config{Server: config.ServerConfig{Splitwise: true}})
		instance := &InstanceInfo{
			Role:             Role{EnumValue: MIXED, IsSet: true},
			HostIP:           "127.0.0.1",
			Port:             Port("8080"),
			TransferProtocol: []string{"rdma"},
		}
		err := RegisterInstanceCore(context.Background(), instance)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "splitwise mode only supports PREFILL/DECODE instances")
	})
	t.Run("non-splitwise mode with prefill instance", func(t *testing.T) {
		Init(&config.Config{Server: config.ServerConfig{Splitwise: false}})
		instance := &InstanceInfo{
			Role:             Role{EnumValue: PREFILL, IsSet: true},
			HostIP:           "127.0.0.1",
			Port:             Port("8080"),
			TransferProtocol: []string{"rdma"},
		}
		err := RegisterInstanceCore(context.Background(), instance)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "only MIXED instances are allowed")
	})
}
// TestRegisterInstance drives the HTTP registration handler through gin test
// contexts: a valid JSON body registers successfully, while malformed or
// empty bodies return 400.
func TestRegisterInstance(t *testing.T) {
	// Setup test server
	// NOTE(review): ts is not referenced by the request bodies below — confirm
	// whether the handler performs any outbound call that should use ts.URL.
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer ts.Close()
	Init(&config.Config{Server: config.ServerConfig{Splitwise: false}})
	t.Run("successful registration", func(t *testing.T) {
		w := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(w)
		body := `{"role": "mixed", "host_ip": "127.0.0.1", "port": 8080, "transfer_protocol": ["rdma"]}`
		c.Request = httptest.NewRequest("POST", "/register", bytes.NewBufferString(body))
		RegisterInstance(c)
		assert.Equal(t, http.StatusOK, w.Code)
		assert.Contains(t, w.Body.String(), "Register success")
	})
	t.Run("invalid JSON", func(t *testing.T) {
		w := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(w)
		c.Request = httptest.NewRequest("POST", "/register", bytes.NewBufferString("invalid json"))
		RegisterInstance(c)
		assert.Equal(t, http.StatusBadRequest, w.Code)
	})
	t.Run("empty body", func(t *testing.T) {
		w := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(w)
		c.Request = httptest.NewRequest("POST", "/register", bytes.NewBufferString(""))
		RegisterInstance(c)
		assert.Equal(t, http.StatusBadRequest, w.Code)
	})
}
// TestRegisterInstancesFromConfig writes a minimal instance-list YAML to a
// temp dir and checks that loading it does not panic. No assertions are made
// on the registration result itself.
func TestRegisterInstancesFromConfig(t *testing.T) {
	Init(&config.Config{Server: config.ServerConfig{Splitwise: false}})
	// Create a temporary YAML file
	tmpDir := t.TempDir()
	yamlPath := filepath.Join(tmpDir, "test_config.yaml")
	yamlContent := `
instances:
- role: mixed
host_ip: 127.0.0.1
port: 8080
transfer_protocol:
- rdma
`
	err := os.WriteFile(yamlPath, []byte(yamlContent), 0644)
	assert.NoError(t, err)
	// This should not panic
	// NOTE(review): as rendered, the keys under "- role" are not indented; if
	// this is not the literal on-disk content, the YAML may be invalid and
	// this test only exercises the error path — verify the raw string.
	RegisterInstancesFromConfig(yamlPath)
}
// TestRegisterInstancesFromConfig_InvalidPath ensures a missing config file
// is tolerated without panicking.
func TestRegisterInstancesFromConfig_InvalidPath(t *testing.T) {
	// A nonexistent path must be handled gracefully (no panic).
	RegisterInstancesFromConfig("/nonexistent/path.yaml")
}
// TestRegisterInstancesFromConfig_InvalidYAML feeds malformed YAML and
// expects a graceful (non-panicking) return.
func TestRegisterInstancesFromConfig_InvalidYAML(t *testing.T) {
	path := filepath.Join(t.TempDir(), "invalid.yaml")
	assert.NoError(t, os.WriteFile(path, []byte("invalid: yaml: content:"), 0644))
	RegisterInstancesFromConfig(path)
}
// TestRegisterInstancesFromConfig_EmptyFile loads a YAML with an empty
// instance list and expects a graceful (non-panicking) return.
func TestRegisterInstancesFromConfig_EmptyFile(t *testing.T) {
	path := filepath.Join(t.TempDir(), "empty.yaml")
	assert.NoError(t, os.WriteFile(path, []byte("instances: []"), 0644))
	RegisterInstancesFromConfig(path)
}
// TestRegisteredNumber covers the /registered/number handler: a nil manager
// yields 400, and a populated manager reports per-role worker counts.
func TestRegisteredNumber(t *testing.T) {
	t.Run("DefaultManager is nil", func(t *testing.T) {
		originalManager := DefaultManager
		DefaultManager = nil
		defer func() { DefaultManager = originalManager }()
		w := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(w)
		c.Request = httptest.NewRequest("GET", "/registered/number", nil)
		RegisteredNumber(c)
		assert.Equal(t, http.StatusBadRequest, w.Code)
		assert.Contains(t, w.Body.String(), "DefaultManager is nil")
	})
	t.Run("successful query", func(t *testing.T) {
		// One prefill and one decode worker → counts of 1 each in the JSON.
		Init(&config.Config{})
		DefaultManager.prefillWorkerMap = map[string]*WorkerInfo{
			"worker1": {Url: "http://worker1"},
		}
		DefaultManager.decodeWorkerMap = map[string]*WorkerInfo{
			"worker2": {Url: "http://worker2"},
		}
		w := httptest.NewRecorder()
		c, _ := gin.CreateTestContext(w)
		c.Request = httptest.NewRequest("GET", "/registered/number", nil)
		RegisteredNumber(c)
		assert.Equal(t, http.StatusOK, w.Code)
		assert.Contains(t, w.Body.String(), `"prefill":1`)
		assert.Contains(t, w.Body.String(), `"decode":1`)
	})
}
// TestRegistered checks the /registered handler returns 200 with both the
// decode and prefill sections present in the JSON body.
func TestRegistered(t *testing.T) {
	Init(&config.Config{})
	DefaultManager.prefillWorkerMap = map[string]*WorkerInfo{
		"worker1": {Url: "http://worker1", WorkerType: "prefill"},
	}
	DefaultManager.decodeWorkerMap = map[string]*WorkerInfo{
		"worker2": {Url: "http://worker2", WorkerType: "decode"},
	}
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest("GET", "/registered", nil)
	Registered(c)
	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Contains(t, rec.Body.String(), `"decode"`)
	assert.Contains(t, rec.Body.String(), `"prefill"`)
}
@@ -0,0 +1,33 @@
package middleware
import (
"github.com/PaddlePaddle/FastDeploy/router/pkg/logger"
"github.com/gin-gonic/gin"
)
// Logger logger middleware
// Logger returns a gin middleware that routes gin's access-log records
// through the shared application logger instead of gin's default writer.
func Logger() gin.HandlerFunc {
	formatter := func(p gin.LogFormatterParams) string {
		logger.Info("[%s] %s %s %d %s %s",
			p.Method, p.Path, p.Request.Proto, p.StatusCode, p.Latency, p.ClientIP)
		// Empty string: gin itself writes nothing to its own output.
		return ""
	}
	return gin.LoggerWithFormatter(formatter)
}
// Recovery recovery middleware
func Recovery() gin.HandlerFunc {
return gin.CustomRecovery(func(c *gin.Context, recovered interface{}) {
logger.Error("Panic recovered: %v", recovered)
c.JSON(500, gin.H{
"code": 500,
"msg": "Internal server error",
})
c.Abort()
})
}
@@ -0,0 +1,48 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/PaddlePaddle/FastDeploy/router/pkg/logger"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
// init prepares the package-level logger before any middleware test runs;
// Recovery() logs on panic, so an uninitialized logger would nil-deref.
func init() {
	// Initialize logger to avoid nil pointer dereference in recovery middleware
	logger.Init("info", "stdout")
}
// TestLoggerMiddleware verifies a request passes through the logging
// middleware and the handler still answers 200.
func TestLoggerMiddleware(t *testing.T) {
	engine := gin.New()
	engine.Use(Logger())
	engine.GET("/test", func(c *gin.Context) { c.String(200, "OK") })
	rec := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/test", nil)
	engine.ServeHTTP(rec, req)
	assert.Equal(t, 200, rec.Code)
}
// TestRecoveryMiddleware verifies a panicking handler is converted into a
// 500 response carrying the generic error message.
func TestRecoveryMiddleware(t *testing.T) {
	engine := gin.New()
	engine.Use(Recovery())
	engine.GET("/panic", func(c *gin.Context) { panic("test panic") })
	rec := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/panic", nil)
	engine.ServeHTTP(rec, req)
	assert.Equal(t, 500, rec.Code)
	// The response should contain the generic error message, not the panic value.
	assert.Contains(t, rec.Body.String(), "Internal server error")
}
@@ -0,0 +1,30 @@
package middleware
import (
"strconv"
"time"
"github.com/PaddlePaddle/FastDeploy/router/pkg/metrics"
"github.com/gin-gonic/gin"
)
// Metrics provides middleware for collecting HTTP request metrics
// Metrics returns a gin middleware that records per-request Prometheus
// metrics: a request counter labeled (method, path, status) and a latency
// observation labeled (method, path).
func Metrics() gin.HandlerFunc {
	return func(c *gin.Context) {
		method, path := c.Request.Method, c.Request.URL.Path
		began := time.Now()
		// Run the rest of the handler chain before measuring.
		c.Next()
		elapsed := time.Since(began)
		code := strconv.Itoa(c.Writer.Status())
		metrics.TotalRequests.WithLabelValues(method, path, code).Inc()
		metrics.RequestDuration.WithLabelValues(method, path).Observe(elapsed.Seconds())
	}
}
@@ -0,0 +1,28 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
// TestMetricsMiddleware checks that a request flows through the metrics
// middleware unchanged (status and body intact).
func TestMetricsMiddleware(t *testing.T) {
	engine := gin.New()
	engine.Use(Metrics())
	engine.GET("/test", func(c *gin.Context) { c.String(http.StatusOK, "OK") })
	rec := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/test", nil)
	engine.ServeHTTP(rec, req)
	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, "OK", rec.Body.String())
}
@@ -0,0 +1,36 @@
package router
import (
"github.com/PaddlePaddle/FastDeploy/router/internal/config"
"github.com/PaddlePaddle/FastDeploy/router/internal/middleware"
"github.com/gin-gonic/gin"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/PaddlePaddle/FastDeploy/router/internal/gateway"
"github.com/PaddlePaddle/FastDeploy/router/internal/manager"
)
// New builds the gin engine: sets the configured run mode, installs the
// logging and recovery middleware, and wires the OpenAI-compatible inference
// routes plus the instance-management, health, and Prometheus endpoints.
//
// NOTE(review): middleware.Metrics() exists in this package but is not
// installed here — confirm whether /metrics is meant to reflect HTTP stats.
func New(cfg *config.Config) *gin.Engine {
	// Set Gin mode
	gin.SetMode(cfg.Server.Mode)
	r := gin.New()
	// Global middleware
	r.Use(middleware.Logger())
	r.Use(middleware.Recovery())
	// API route group
	v1 := r.Group("/v1")
	{
		// Both endpoints share one handler; confirm ChatCompletions
		// distinguishes the two payload shapes. (NOTE(review))
		v1.POST("/chat/completions", gateway.ChatCompletions)
		v1.POST("/completions", gateway.ChatCompletions)
	}
	// Instance registration and introspection endpoints.
	r.POST("/register", manager.RegisterInstance)
	r.GET("/registered_number", manager.RegisteredNumber)
	r.GET("/registered", manager.Registered)
	r.GET("/health_generate", manager.HealthGenerate)
	// Prometheus scrape endpoint.
	r.GET("/metrics", gin.WrapH(promhttp.Handler()))
	return r
}
@@ -0,0 +1,55 @@
package common
import (
"sync/atomic"
)
// Counter is a concurrency-safe, non-negative counter used to track the
// number of in-flight requests per worker instance.
type Counter struct {
	count atomic.Uint64
}

// Inc atomically increments the counter by one.
func (c *Counter) Inc() {
	c.count.Add(1)
}

// Dec atomically decrements the counter by one, saturating at zero.
//
// The previous implementation added ^uint64(0) (two's-complement -1), so an
// unmatched Dec — e.g. releasing a worker whose counter was never
// incremented — wrapped around to math.MaxUint64 and made that instance look
// permanently overloaded to count-based selection policies. A CAS loop that
// clamps at zero (mirroring TokenCounter.Sub) removes that failure mode.
func (c *Counter) Dec() {
	for {
		old := c.count.Load()
		if old == 0 {
			return
		}
		if c.count.CompareAndSwap(old, old-1) {
			return
		}
	}
}

// Get returns the current counter value.
func (c *Counter) Get() uint64 {
	return c.count.Load()
}
// TokenCounter records the number of tokens currently being processed by each P instance
type TokenCounter struct {
	tokens atomic.Uint64
}

// Add increases the in-flight token total by n.
func (c *TokenCounter) Add(n uint64) {
	c.tokens.Add(n)
}

// Get returns the current in-flight token total.
func (c *TokenCounter) Get() uint64 {
	return c.tokens.Load()
}

// Sub decreases the total by n, clamping at zero instead of wrapping.
func (c *TokenCounter) Sub(n uint64) {
	if n == 0 {
		return
	}
	for {
		current := c.tokens.Load()
		if current == 0 {
			return
		}
		next := uint64(0)
		if current > n {
			next = current - n
		}
		if c.tokens.CompareAndSwap(current, next) {
			return
		}
	}
}
@@ -0,0 +1,5 @@
package common
import "context"
// SelectStrategyFunc is the common signature for worker-selection strategies:
// given candidate worker URLs and the request message, it returns the chosen
// worker URL (empty when there are no candidates) and an error.
type SelectStrategyFunc func(ctx context.Context, workers []string, message string) (string, error)
@@ -0,0 +1,32 @@
package handler
import (
"context"
"math"
)
// computeScore scores a worker by its load: running requests count fully,
// while waiting requests are scaled by the configured waitingWeight.
func computeScore(ctx context.Context, runningCnt int, waitingCnt int) float64 {
	return float64(runningCnt) + waitingWeight*float64(waitingCnt)
}
// FDMetricsScoreSelectWorker picks the worker with the lowest metrics-derived
// load score. Metric-fetch errors are ignored (best effort — the zero values
// returned alongside the error are scored as-is).
func FDMetricsScoreSelectWorker(ctx context.Context, workers []string, message string) (string, error) {
	if len(workers) == 0 {
		return "", nil
	}
	best := ""
	bestScore := math.MaxFloat64
	for _, candidate := range workers {
		running, waiting, _ := DefaultScheduler.managerAPI.GetMetrics(ctx, candidate)
		if s := computeScore(ctx, running, waiting); s < bestScore {
			bestScore = s
			best = candidate
		}
	}
	return best, nil
}
@@ -0,0 +1,282 @@
package handler
import (
"context"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
"unicode/utf8"
common "github.com/PaddlePaddle/FastDeploy/router/internal/common"
"github.com/PaddlePaddle/FastDeploy/router/internal/config"
scheduler_common "github.com/PaddlePaddle/FastDeploy/router/internal/scheduler/common"
"github.com/PaddlePaddle/FastDeploy/router/pkg/logger"
)
// Scheduler holds the configured selection policies plus per-worker load
// bookkeeping (in-flight request counters and token counters).
type Scheduler struct {
	policy        string // policy for mixed deployments
	prefillPolicy string // policy for prefill instances (splitwise)
	decodePolicy  string // policy for decode instances (splitwise)
	// IdCounterMap tracks in-flight request counts, keyed by worker URL.
	IdCounterMap map[string]*scheduler_common.Counter
	// tokenMap tracks estimated in-flight token counts, keyed by worker URL.
	tokenMap   map[string]*scheduler_common.TokenCounter
	managerAPI common.ManagerAPI
	// prefillCache backs the cache-aware prefill selection strategy.
	prefillCache *prefillCacheStrategy
	mu           sync.RWMutex // guards IdCounterMap and tokenMap
}

// CounterPolicy carries round-robin counters and the worker type of the most
// recent selection.
type CounterPolicy struct {
	counter        atomic.Uint64
	prefillCounter atomic.Uint64
	// NOTE(review): workerType is written without synchronization in
	// SelectWorker — confirm selections are serialized, else this races.
	workerType string
}

// DefaultScheduler is the process-wide scheduler instance set by Init.
var DefaultScheduler *Scheduler

// DefaultCounterPolicy is the process-wide counter-policy state set by Init.
var DefaultCounterPolicy *CounterPolicy

// waitingWeight scales waiting-request counts in fd_metrics_score scoring.
var waitingWeight float64
// Init initializes the scheduler with the given configuration and manager API
func Init(cfg *config.Config, managerAPI common.ManagerAPI) {
	sc := cfg.Scheduler
	// Snapshot the cache-aware tuning knobs for the prefill strategy.
	cacheCfg := &schedulerConfigSnapshot{
		balanceAbsThreshold: sc.BalanceAbsThreshold,
		balanceRelThreshold: sc.BalanceRelThreshold,
		hitRatioWeight:      sc.HitRatioWeight,
		loadBalanceWeight:   sc.LoadBalanceWeight,
		cacheBlockSize:      sc.CacheBlockSize,
		tokenizerURL:        sc.TokenizerURL,
		tokenizerTimeout:    time.Duration(sc.TokenizerTimeoutSecs * float64(time.Second)),
	}
	DefaultScheduler = &Scheduler{
		policy:        sc.Policy,
		prefillPolicy: sc.PrefillPolicy,
		decodePolicy:  sc.DecodePolicy,
		IdCounterMap:  make(map[string]*scheduler_common.Counter),
		tokenMap:      make(map[string]*scheduler_common.TokenCounter),
		managerAPI:    managerAPI,
		prefillCache:  newPrefillCacheStrategy(cacheCfg),
	}
	DefaultCounterPolicy = &CounterPolicy{}
	waitingWeight = sc.WaitingWeight
}
// SelectWorker selects a worker based on the policy configured for the given
// worker type (prefill/decode/mixed), records the request (and, for prefill,
// estimated token) load against the chosen instance, and returns its
// normalized ("http://"-prefixed) URL.
func SelectWorker(ctx context.Context, workers []string, message string, workerType string) (string, error) {
	if len(workers) == 0 {
		return "", fmt.Errorf("no healthy workers available")
	}
	var policy string
	switch workerType {
	case "prefill":
		policy = DefaultScheduler.prefillPolicy
		DefaultCounterPolicy.workerType = "prefill"
	case "decode":
		policy = DefaultScheduler.decodePolicy
		DefaultCounterPolicy.workerType = "decode"
	default:
		policy = DefaultScheduler.policy
		DefaultCounterPolicy.workerType = "mixed"
	}
	var strategyFunc scheduler_common.SelectStrategyFunc
	switch policy {
	case "random":
		strategyFunc = RandomSelectWorker
	case "round_robin":
		strategyFunc = RoundRobinSelectWorker
	case "power_of_two":
		strategyFunc = PowerOfTwoSelectWorker
	case "process_tokens":
		// Prefill: prioritize the instance with the smallest number of tokens currently being processed
		strategyFunc = ProcessTokensSelectWorker
	case "request_num":
		// Decode/mixed: prioritize the instance with the smallest number of current requests
		strategyFunc = RequestNumSelectWorker
	case "fd_metrics_score":
		strategyFunc = FDMetricsScoreSelectWorker
	case "cache_aware":
		strategyFunc = CacheAwarePrefillSelectWorker
	default:
		strategyFunc = RandomSelectWorker
	}
	selectWorkerURL, err := strategyFunc(ctx, workers, message)
	if err != nil {
		// BUG FIX: report the policy actually used for this worker type —
		// the old message printed DefaultScheduler.policy (the mixed-mode
		// policy) even for prefill/decode selections.
		return "", fmt.Errorf("select worker failed [policy: %s]: %w", policy, err)
	}
	if !strings.HasPrefix(selectWorkerURL, "http://") && !strings.HasPrefix(selectWorkerURL, "https://") {
		selectWorkerURL = "http://" + selectWorkerURL
	}
	// NOTE(review): the counters below are keyed by the normalized URL, while
	// the strategy functions look counters up by the raw strings in `workers`.
	// If callers register workers without a scheme, load recorded here is
	// invisible to the strategies — confirm worker lists are scheme-prefixed.
	// 1) All node types: request concurrency count (request_num)
	counter := GetOrCreateCounter(ctx, selectWorkerURL)
	counter.Inc()
	count := counter.Get()
	// 2) Prefill: current token processing count (process_tokens)
	var tokens uint64
	if workerType == "prefill" && message != "" {
		tokenCounter := GetOrCreateTokenCounter(ctx, selectWorkerURL)
		tokenCounter.Add(estimateTokens(message))
		tokens = tokenCounter.Get()
	}
	if workerType == "prefill" {
		logger.Info("select worker (prefill): %s, tokens: %d", selectWorkerURL, tokens)
	} else {
		logger.Info("select worker (%s): %s, count: %d", workerType, selectWorkerURL, count)
	}
	return selectWorkerURL, nil
}
// Release decreases the counter for the specified worker URL
func Release(ctx context.Context, url string) {
counter := GetOrCreateCounter(ctx, url)
counter.Dec()
logger.Info("release worker: %s, count: %d", url, counter.Get())
}
// GetCounter retrieves the counter for the specified root URL
// under a read lock, returning (nil, false) when none exists.
// NOTE(review): unlike the Cleanup* helpers, this does not guard against a
// nil DefaultScheduler — callers must run Init first or this panics.
func GetCounter(ctx context.Context, rootURL string) (*scheduler_common.Counter, bool) {
	DefaultScheduler.mu.RLock()
	defer DefaultScheduler.mu.RUnlock()
	counter, exists := DefaultScheduler.IdCounterMap[rootURL]
	return counter, exists
}
// GetOrCreateCounter retrieves an existing counter or creates a new one
func GetOrCreateCounter(ctx context.Context, url string) *scheduler_common.Counter {
	if existing, ok := GetCounter(ctx, url); ok {
		return existing
	}
	DefaultScheduler.mu.Lock()
	defer DefaultScheduler.mu.Unlock()
	// Double check: another goroutine may have created the counter between
	// releasing the read lock and acquiring the write lock.
	if existing, ok := DefaultScheduler.IdCounterMap[url]; ok {
		return existing
	}
	created := &scheduler_common.Counter{}
	DefaultScheduler.IdCounterMap[url] = created
	return created
}
// CleanupUnhealthyCounter removes counters for unhealthy worker URLs
func CleanupUnhealthyCounter(ctx context.Context, unhealthyRootURL string) {
	// No-op for an empty URL or an uninitialized scheduler.
	if unhealthyRootURL == "" || DefaultScheduler == nil {
		return
	}
	DefaultScheduler.mu.Lock()
	defer DefaultScheduler.mu.Unlock()
	delete(DefaultScheduler.IdCounterMap, unhealthyRootURL)
	delete(DefaultScheduler.tokenMap, unhealthyRootURL)
	logger.Info("After cleanup unhealthy counter: %v", DefaultScheduler.IdCounterMap)
}
// CleanupInvalidCounters removes counters for invalid or unreachable workers
func CleanupInvalidCounters(ctx context.Context) {
	if DefaultScheduler == nil || DefaultScheduler.managerAPI == nil {
		return
	}
	healthyURLs := DefaultScheduler.managerAPI.GetHealthyURLs(ctx)
	// An empty healthy list is treated as "no information" — keep all
	// counters rather than wiping everything.
	if len(healthyURLs) == 0 {
		return
	}
	healthy := make(map[string]bool, len(healthyURLs))
	for _, u := range healthyURLs {
		healthy[u] = true
	}
	DefaultScheduler.mu.Lock()
	defer DefaultScheduler.mu.Unlock()
	for u := range DefaultScheduler.IdCounterMap {
		if !healthy[u] {
			delete(DefaultScheduler.IdCounterMap, u)
		}
	}
	for u := range DefaultScheduler.tokenMap {
		if !healthy[u] {
			delete(DefaultScheduler.tokenMap, u)
		}
	}
	logger.Info("After cleanup invalid counters: %v", DefaultScheduler.IdCounterMap)
}
// StartBackupCleanupTask periodically invokes CleanupInvalidCounters every
// `interval` seconds until ctx is cancelled. Intended to run as a goroutine.
func StartBackupCleanupTask(ctx context.Context, interval float64) {
	// Guard: time.NewTicker panics on a non-positive duration, which a zero
	// or negative configured interval would produce.
	if interval <= 0 {
		return
	}
	ticker := time.NewTicker(time.Duration(interval * float64(time.Second)))
	defer ticker.Stop()
	for {
		select {
		// case 1: context cancellation/timeout → graceful exit
		case <-ctx.Done():
			return
		// case 2: timer fired → perform cleanup
		case <-ticker.C:
			CleanupInvalidCounters(ctx)
		}
	}
}
// GetTokenCounter gets the TokenCounter for the specified instance
// under a read lock, returning (nil, false) when none exists.
// NOTE(review): like GetCounter, this assumes DefaultScheduler is non-nil.
func GetTokenCounter(ctx context.Context, rootURL string) (*scheduler_common.TokenCounter, bool) {
	DefaultScheduler.mu.RLock()
	defer DefaultScheduler.mu.RUnlock()
	counter, exists := DefaultScheduler.tokenMap[rootURL]
	return counter, exists
}
// GetOrCreateTokenCounter gets or creates TokenCounter
func GetOrCreateTokenCounter(ctx context.Context, url string) *scheduler_common.TokenCounter {
	if existing, ok := GetTokenCounter(ctx, url); ok {
		return existing
	}
	DefaultScheduler.mu.Lock()
	defer DefaultScheduler.mu.Unlock()
	// Double check: another goroutine may have created the counter while we
	// were upgrading from the read lock to the write lock.
	if existing, ok := DefaultScheduler.tokenMap[url]; ok {
		return existing
	}
	created := &scheduler_common.TokenCounter{}
	DefaultScheduler.tokenMap[url] = created
	return created
}
// estimateTokens estimates token count based on character count: character count * 2
func estimateTokens(message string) uint64 {
	if len(message) == 0 {
		return 0
	}
	return 2 * uint64(utf8.RuneCountInString(message))
}
// ReleasePrefillTokens releases the corresponding token load when request ends
func ReleasePrefillTokens(ctx context.Context, url, message string) {
	// Nothing to release without a target URL or an original message.
	if url == "" || message == "" {
		return
	}
	tc := GetOrCreateTokenCounter(ctx, url)
	tc.Sub(estimateTokens(message))
	logger.Info("release prefill tokens: %s, tokens: %d", url, tc.Get())
}
@@ -0,0 +1,213 @@
package handler
import (
"context"
"testing"
"time"
"github.com/PaddlePaddle/FastDeploy/router/internal/config"
"github.com/stretchr/testify/assert"
)
// mockManagerAPI is a stub ManagerAPI implementation for scheduler tests.
type mockManagerAPI struct{}

// GetHealthyURLs returns a fixed pair of healthy worker URLs.
func (m *mockManagerAPI) GetHealthyURLs(ctx context.Context) []string {
	return []string{"worker1", "worker2"}
}

// GetMetrics reports zero running/waiting/total counts for every worker.
func (m *mockManagerAPI) GetMetrics(ctx context.Context, url string) (int, int, int) {
	return 0, 0, 0 // return default values for testing
}
func TestSchedulerInit(t *testing.T) {
cfg := &config.Config{
Scheduler: config.SchedulerConfig{
Policy: "random",
PrefillPolicy: "process_tokens",
DecodePolicy: "request_num",
},
}
Init(cfg, &mockManagerAPI{})
assert.NotNil(t, DefaultScheduler)
assert.Equal(t, "random", DefaultScheduler.policy)
assert.Equal(t, "process_tokens", DefaultScheduler.prefillPolicy)
assert.Equal(t, "request_num", DefaultScheduler.decodePolicy)
}
// TestSelectWorker drives SelectWorker end-to-end for both prefill
// (process_tokens) and decode (request_num) policies. Subtests share the
// global counter maps created by Init, so ordering matters: the prefill
// subtest runs first and its side effects key off the normalized
// "http://..." URL, leaving the raw "workerN" counters untouched for the
// decode subtest.
func TestSelectWorker(t *testing.T) {
	ctx := context.Background()
	workers := []string{"worker1", "worker2", "worker3"}
	Init(&config.Config{
		Scheduler: config.SchedulerConfig{
			Policy:        "random",
			PrefillPolicy: "process_tokens",
			DecodePolicy:  "request_num",
		},
	}, &mockManagerAPI{})
	t.Run("prefill worker selection", func(t *testing.T) {
		// Set up token counts
		tc1 := GetOrCreateTokenCounter(ctx, "worker1")
		tc1.Add(100)
		tc2 := GetOrCreateTokenCounter(ctx, "worker2")
		tc2.Add(50) // Should be selected
		tc3 := GetOrCreateTokenCounter(ctx, "worker3")
		tc3.Add(200)
		// The returned URL is scheme-normalized by SelectWorker.
		selected, err := SelectWorker(ctx, workers, "test message", "prefill")
		assert.NoError(t, err)
		assert.Equal(t, "http://worker2", selected)
	})
	t.Run("decode worker selection", func(t *testing.T) {
		// Set up request counts
		c1 := GetOrCreateCounter(ctx, "worker1")
		c1.Inc()
		c1.Inc() // count = 2
		c2 := GetOrCreateCounter(ctx, "worker2") // count = 0 (should be selected)
		c3 := GetOrCreateCounter(ctx, "worker3")
		c3.Inc() // count = 1
		// Verify counts
		assert.Equal(t, uint64(2), c1.Get())
		assert.Equal(t, uint64(0), c2.Get())
		assert.Equal(t, uint64(1), c3.Get())
		selected, err := SelectWorker(ctx, workers, "test", "decode")
		assert.NoError(t, err)
		assert.Equal(t, "http://worker2", selected)
	})
}
// TestCounterOperations checks basic Inc/Dec and Add/Sub accounting on the
// scheduler's request and token counters.
func TestCounterOperations(t *testing.T) {
	ctx := context.Background()
	Init(&config.Config{}, nil)
	t.Run("counter increment", func(t *testing.T) {
		c := GetOrCreateCounter(ctx, "test")
		assert.Equal(t, uint64(0), c.Get())
		c.Inc()
		assert.Equal(t, uint64(1), c.Get())
		c.Dec()
		assert.Equal(t, uint64(0), c.Get())
	})
	t.Run("token counter operations", func(t *testing.T) {
		tokens := GetOrCreateTokenCounter(ctx, "test")
		assert.Equal(t, uint64(0), tokens.Get())
		tokens.Add(100)
		assert.Equal(t, uint64(100), tokens.Get())
		tokens.Sub(50)
		assert.Equal(t, uint64(50), tokens.Get())
	})
}
// TestCleanupInvalidCounters verifies counters for URLs absent from the
// mock's healthy list are dropped while healthy ones survive.
func TestCleanupInvalidCounters(t *testing.T) {
	ctx := context.Background()
	Init(&config.Config{}, &mockManagerAPI{})
	// Seed one healthy and one stale entry in each map.
	healthy := GetOrCreateCounter(ctx, "worker1")
	healthy.Inc()
	GetOrCreateCounter(ctx, "invalid-worker") // Should be cleaned up
	healthyTokens := GetOrCreateTokenCounter(ctx, "worker1")
	healthyTokens.Add(100)
	GetOrCreateTokenCounter(ctx, "invalid-worker") // Should be cleaned up
	CleanupInvalidCounters(ctx)
	// Request counters: healthy survives, stale is gone.
	_, exists := GetCounter(ctx, "worker1")
	assert.True(t, exists)
	_, exists = GetCounter(ctx, "invalid-worker")
	assert.False(t, exists)
	// Token counters: same expectations.
	_, exists = GetTokenCounter(ctx, "worker1")
	assert.True(t, exists)
	_, exists = GetTokenCounter(ctx, "invalid-worker")
	assert.False(t, exists)
}
// TestEstimateTokens checks the rune-count-times-two token estimate for
// empty, ASCII, and multibyte inputs.
func TestEstimateTokens(t *testing.T) {
	cases := []struct {
		input    string
		expected uint64
	}{
		{"", 0},
		{"hello", 10}, // 5 chars * 2
		{"你好", 4},     // 2 chars * 2 (Chinese characters count as 1 char each)
	}
	for _, tc := range cases {
		t.Run(tc.input, func(t *testing.T) {
			assert.Equal(t, tc.expected, estimateTokens(tc.input))
		})
	}
}
// TestReleasePrefillTokens covers the happy path (token load subtracted by
// the message's estimate) and the no-op guards for empty url/message.
func TestReleasePrefillTokens(t *testing.T) {
	ctx := context.Background()
	Init(&config.Config{}, nil)
	t.Run("valid release", func(t *testing.T) {
		tokens := GetOrCreateTokenCounter(ctx, "worker1")
		tokens.Add(100)
		ReleasePrefillTokens(ctx, "worker1", "hello") // 5 chars * 2 = 10 tokens
		assert.Equal(t, uint64(90), tokens.Get())
	})
	t.Run("empty url or message", func(t *testing.T) {
		tokens := GetOrCreateTokenCounter(ctx, "worker2")
		tokens.Add(100)
		ReleasePrefillTokens(ctx, "", "hello")   // no-op
		ReleasePrefillTokens(ctx, "worker2", "") // no-op
		assert.Equal(t, uint64(100), tokens.Get())
	})
}
// TestCleanupUnhealthyCounter verifies both the request and token counters
// for a named unhealthy worker are removed.
func TestCleanupUnhealthyCounter(t *testing.T) {
	ctx := context.Background()
	Init(&config.Config{}, nil)
	// Seed both maps for the worker about to be cleaned.
	reqCounter := GetOrCreateCounter(ctx, "unhealthy-worker")
	reqCounter.Inc()
	tokenCounter := GetOrCreateTokenCounter(ctx, "unhealthy-worker")
	tokenCounter.Add(100)
	CleanupUnhealthyCounter(ctx, "unhealthy-worker")
	// Both entries must be gone.
	_, exists := GetCounter(ctx, "unhealthy-worker")
	assert.False(t, exists)
	_, exists = GetTokenCounter(ctx, "unhealthy-worker")
	assert.False(t, exists)
}
// TestStartBackupCleanupTask runs the periodic cleanup goroutine with a
// 100ms interval and checks an invalid counter is removed within 200ms.
// NOTE(review): sleep-based timing — may flake on slow CI; consider polling
// with a deadline instead.
func TestStartBackupCleanupTask(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	Init(&config.Config{}, &mockManagerAPI{})
	// Add invalid counter
	GetOrCreateCounter(ctx, "invalid-worker")
	// Start cleanup task with short interval
	go StartBackupCleanupTask(ctx, 0.1) // 0.1 second interval
	// Wait for cleanup
	time.Sleep(200 * time.Millisecond)
	cancel()
	// Verify cleanup
	_, exists := GetCounter(ctx, "invalid-worker")
	assert.False(t, exists)
}
@@ -0,0 +1,52 @@
package handler
import (
"context"
"math"
)
// ProcessTokensSelectWorker selects the instance with the smallest number of tokens currently being processed for Prefill nodes.
func ProcessTokensSelectWorker(ctx context.Context, workers []string, message string) (string, error) {
	if len(workers) == 0 {
		return "", nil
	}
	best := ""
	var bestLoad uint64 = math.MaxUint64
	for _, candidate := range workers {
		if load := GetOrCreateTokenCounter(ctx, candidate).Get(); load < bestLoad {
			bestLoad = load
			best = candidate
		}
	}
	return best, nil
}
// RequestNumSelectWorker selects the instance with the smallest number of current requests for Decode nodes.
func RequestNumSelectWorker(ctx context.Context, workers []string, message string) (string, error) {
	if len(workers) == 0 {
		return "", nil
	}
	best := ""
	var bestLoad uint64 = math.MaxUint64
	for _, candidate := range workers {
		if load := GetOrCreateCounter(ctx, candidate).Get(); load < bestLoad {
			bestLoad = load
			best = candidate
		}
	}
	return best, nil
}
@@ -0,0 +1,79 @@
package handler
import (
"context"
"testing"
"github.com/PaddlePaddle/FastDeploy/router/internal/config"
"github.com/stretchr/testify/assert"
)
// TestProcessTokensSelectWorker checks the min-tokens strategy picks the
// least-loaded worker and tolerates an empty candidate list.
func TestProcessTokensSelectWorker(t *testing.T) {
	ctx := context.Background()
	workers := []string{"worker1", "worker2", "worker3"}
	Init(&config.Config{
		Scheduler: config.SchedulerConfig{
			Policy: "process_tokens",
		},
	}, nil)
	t.Run("select worker with least tokens", func(t *testing.T) {
		// worker2 carries the smallest token load and must win.
		GetOrCreateTokenCounter(ctx, "worker1").Add(100)
		GetOrCreateTokenCounter(ctx, "worker2").Add(50)
		GetOrCreateTokenCounter(ctx, "worker3").Add(200)
		selected, err := ProcessTokensSelectWorker(ctx, workers, "test message")
		assert.NoError(t, err)
		assert.Equal(t, "worker2", selected)
	})
	t.Run("empty workers list", func(t *testing.T) {
		selected, err := ProcessTokensSelectWorker(ctx, []string{}, "test")
		assert.NoError(t, err)
		assert.Equal(t, "", selected)
	})
}
// TestRequestNumSelectWorker checks the min-requests strategy picks the
// least-loaded worker and tolerates an empty candidate list.
func TestRequestNumSelectWorker(t *testing.T) {
	ctx := context.Background()
	workers := []string{"worker1", "worker2", "worker3"}
	Init(&config.Config{
		Scheduler: config.SchedulerConfig{
			Policy: "request_num",
		},
	}, nil)
	t.Run("select worker with least requests", func(t *testing.T) {
		c1 := GetOrCreateCounter(ctx, "worker1")
		c1.Inc()
		c1.Inc() // count = 2
		c2 := GetOrCreateCounter(ctx, "worker2") // count = 0 (should be selected)
		c3 := GetOrCreateCounter(ctx, "worker3")
		c3.Inc() // count = 1
		// Sanity-check the seeded loads before selecting.
		assert.Equal(t, uint64(2), c1.Get())
		assert.Equal(t, uint64(0), c2.Get())
		assert.Equal(t, uint64(1), c3.Get())
		selected, err := RequestNumSelectWorker(ctx, workers, "test")
		assert.NoError(t, err)
		assert.Equal(t, "worker2", selected)
	})
	t.Run("empty workers list", func(t *testing.T) {
		selected, err := RequestNumSelectWorker(ctx, []string{}, "test")
		assert.NoError(t, err)
		assert.Equal(t, "", selected)
	})
}
@@ -0,0 +1,39 @@
package handler
import (
"context"
"math/rand"
)
// PowerOfTwoSelectWorker samples two distinct workers uniformly at random
// and returns the one with the lower in-flight request count; ties go to the
// first sample. Zero or one candidates are handled trivially.
func PowerOfTwoSelectWorker(ctx context.Context, workers []string, message string) (string, error) {
	switch len(workers) {
	case 0:
		return "", nil
	case 1:
		return workers[0], nil
	}
	n := len(workers)
	i := rand.Intn(n)
	j := rand.Intn(n)
	// Re-draw until the two samples differ (n >= 2 guarantees termination).
	for j == i {
		j = rand.Intn(n)
	}
	first, second := workers[i], workers[j]
	firstLoad := GetOrCreateCounter(ctx, first).Get()
	secondLoad := GetOrCreateCounter(ctx, second).Get()
	if firstLoad <= secondLoad {
		return first, nil
	}
	return second, nil
}
@@ -0,0 +1,545 @@
package handler
import (
"context"
"encoding/binary"
"errors"
"hash/fnv"
"math"
"math/rand"
"sync"
"time"
"github.com/PaddlePaddle/FastDeploy/router/pkg/logger"
)
// prefillCacheStrategy holds the tunables and state of the cache-aware
// prefill selection policy.
type prefillCacheStrategy struct {
	absThreshold      float64           // absolute load-gap threshold triggering the imbalance fallback
	relThreshold      float64           // relative load-gap threshold triggering the imbalance fallback
	hitRatioWeight    float64           // weight of the cache-miss term in the score
	loadBalanceWeight float64           // weight of the load term in the score
	cache             *radixPrefixCache // prefix tree tracking which worker served which prefix
	tokenizer         TokenizerClient   // remote tokenizer; nil enables the per-character fallback
}
// schedulerConfigSnapshot captures the scheduler settings needed to build a
// prefillCacheStrategy.
type schedulerConfigSnapshot struct {
	balanceAbsThreshold float64       // absolute imbalance threshold
	balanceRelThreshold float64       // relative imbalance threshold
	hitRatioWeight      float64       // scoring weight for the cache hit ratio
	loadBalanceWeight   float64       // scoring weight for the load ratio
	cacheBlockSize      int           // token block size used for prefix hashing
	tokenizerURL        string        // remote tokenizer endpoint; empty disables it
	tokenizerTimeout    time.Duration // per-request tokenizer timeout
}
// newPrefillCacheStrategy initializes cache-aware strategy config.
// It builds the radix prefix cache (which starts a background eviction
// loop) and the HTTP tokenizer client (nil if cfg.tokenizerURL is empty).
func newPrefillCacheStrategy(cfg *schedulerConfigSnapshot) *prefillCacheStrategy {
	return &prefillCacheStrategy{
		absThreshold:      cfg.balanceAbsThreshold,
		relThreshold:      cfg.balanceRelThreshold,
		hitRatioWeight:    cfg.hitRatioWeight,
		loadBalanceWeight: cfg.loadBalanceWeight,
		cache:             newRadixPrefixCache(cfg.cacheBlockSize),
		tokenizer:         NewHTTPTokenizer(cfg.tokenizerURL, cfg.tokenizerTimeout),
	}
}
// CacheAwarePrefillSelectWorker falls back to the min-token policy on
// extreme load imbalance (or any tokenization failure); otherwise it scores
// candidates by prefix-cache hit rate and load, records the chosen worker's
// prefix, and returns it.
func CacheAwarePrefillSelectWorker(ctx context.Context, workers []string, message string) (string, error) {
	if len(workers) == 0 {
		return "", nil
	}
	// Without an initialized scheduler / cache strategy there is nothing
	// cache-aware to do: defer to the plain min-token policy.
	if DefaultScheduler == nil || DefaultScheduler.prefillCache == nil {
		return ProcessTokensSelectWorker(ctx, workers, message)
	}
	strategy := DefaultScheduler.prefillCache
	// 1) Fetch node load; fallback to min tokens on extreme imbalance
	loads := strategy.getRunningRequests(ctx, workers)
	if strategy.isLoadImbalanced(loads) {
		return ProcessTokensSelectWorker(ctx, workers, message)
	}
	// 2) Tokenize the prompt (errors or empty output also fall back)
	tokens, err := strategy.tokenize(ctx, message)
	if err != nil || len(tokens) == 0 {
		if err != nil {
			logger.Warn("cache-aware prefill: tokenizer failed, fallback to process_tokens: %v", err)
		}
		return ProcessTokensSelectWorker(ctx, workers, message)
	}
	// 3) Compute prefix tree hit rate
	// NOTE(review): prefixHashes is recomputed below purely for the debug
	// log; Match already computes the same hashes internally.
	hitRatios := strategy.cache.Match(tokens, toWorkerSet(workers))
	logger.Debug("cache-aware prefill: hashes=%d workers=%d load=%v hit=%v", len(strategy.cache.hasher.prefixHashes(tokens)), len(workers), loads, hitRatios)
	// 4) Compute weighted score from hit rate and load
	selected := strategy.chooseByScore(ctx, workers, loads, hitRatios)
	// 5) Record prefix
	strategy.cache.Record(tokens, selected)
	logger.Debug("cache-aware prefill: selected=%s", selected)
	return selected, nil
}
// tokenize calls the remote tokenizer service for message.
// Falls back to per-character pseudo-tokens when no tokenizer client is
// configured or the remote call fails; only an empty prompt is an error.
func (p *prefillCacheStrategy) tokenize(ctx context.Context, message string) ([]int, error) {
	if message == "" {
		return nil, errors.New("empty prompt for tokenizer")
	}
	if p.tokenizer == nil {
		// Fallback to character-based tokenization
		return charsToTokens(message), nil
	}
	tokens, err := p.tokenizer.Tokenize(ctx, message)
	if err != nil {
		// Best-effort: degrade to the character fallback instead of failing.
		logger.Warn("cache-aware prefill: tokenizer failed, fallback to char tokens: %v", err)
		return charsToTokens(message), nil
	}
	logger.Debug("cache-aware prefill: tokenizer tokens=%v", tokens)
	return tokens, nil
}
// isLoadImbalanced reports whether the spread between the most and least
// loaded workers exceeds both the absolute and the relative thresholds.
// Fewer than two samples, or all-equal loads, are never imbalanced.
func (p *prefillCacheStrategy) isLoadImbalanced(loads map[string]uint64) bool {
	if len(loads) < 2 {
		return false
	}
	var initialized bool
	var minLoad, maxLoad uint64
	for _, load := range loads {
		if !initialized {
			minLoad, maxLoad = load, load
			initialized = true
			continue
		}
		if load > maxLoad {
			maxLoad = load
		} else if load < minLoad {
			minLoad = load
		}
	}
	if maxLoad == minLoad {
		return false
	}
	spread := float64(maxLoad - minLoad)
	relative := spread / float64(maxLoad)
	return spread > p.absThreshold && relative > p.relThreshold
}
// chooseByScore selects a worker by cache hit rate and load; lower score
// wins: score = (100-hit)/100*hitRatioWeight + load/maxLoad*loadBalanceWeight.
// Exactly-equal scores tie-break on the lower token counter (exact float
// equality, so this path triggers rarely in practice).
func (p *prefillCacheStrategy) chooseByScore(ctx context.Context, workers []string, loads map[string]uint64, hitRatios map[string]int) string {
	if len(workers) == 0 {
		return ""
	}
	// TODO: reuse maxLoad from isLoadImbalanced
	var maxLoad uint64
	for _, w := range workers {
		if v := loads[w]; v > maxLoad {
			maxLoad = v
		}
	}
	bestScore := math.MaxFloat64
	selected := ""
	for _, w := range workers {
		hit := float64(hitRatios[w])
		// Normalize load against the busiest worker; all-zero loads score 0.
		loadRatio := 0.0
		if maxLoad > 0 {
			loadRatio = float64(loads[w]) / float64(maxLoad)
		}
		score := (100.0-hit)/100*p.hitRatioWeight + loadRatio*p.loadBalanceWeight
		logger.Debug("cache-aware score: worker=%s hit=%.1f loadRatio=%.3f score=%.3f", w, hit, loadRatio, score)
		if score < bestScore {
			bestScore = score
			selected = w
			continue
		}
		// Tie-breaker: prefer lower token load if scores are equal
		if score == bestScore && selected != "" {
			selectedTokens := GetOrCreateTokenCounter(ctx, selected).Get()
			currentTokens := GetOrCreateTokenCounter(ctx, w).Get()
			if currentTokens < selectedTokens {
				selected = w
			}
		}
	}
	return selected
}
// getRunningRequests retrieves the running-request count per worker.
// Returns an empty map when the scheduler or its manager API is
// unavailable. Errors from GetMetrics are deliberately discarded — this is
// a best-effort snapshot, and a failed worker simply reports zero.
func (p *prefillCacheStrategy) getRunningRequests(ctx context.Context, workers []string) map[string]uint64 {
	result := make(map[string]uint64, len(workers))
	if DefaultScheduler == nil || DefaultScheduler.managerAPI == nil {
		return result
	}
	for _, w := range workers {
		running, _, _ := DefaultScheduler.managerAPI.GetMetrics(ctx, w)
		result[w] = uint64(running)
	}
	return result
}
// radixPrefixCache tracks prefix hits using a radix tree keyed by block hash.
type radixPrefixCache struct {
	mu               sync.RWMutex // guards the tree, nodeCount and allNodes
	root             *radixNode
	hasher           *blockHasher
	evictionDuration time.Duration // idle time after which a node expires
	maxNodes         int           // soft capacity that triggers eviction on insert
	nodeCount        int           // live node count, root included
	allNodes         map[*radixNode]struct{} // index of every live node
}

// radixNode is one edge-compressed node of the prefix tree.
type radixNode struct {
	key        []uint64              // block hashes on the edge leading into this node
	children   map[uint64]*radixNode // children keyed by the first hash of their edge
	parent     *radixNode
	workers    map[string]time.Time // workers that served this prefix, by last-record time
	lastAccess time.Time            // last insert touching this node (Match does not update it)
	contextLen int                  // total hash count from the root to this node
}
// newRadixPrefixCache initializes the radix prefix cache with eviction and
// capacity control. A non-positive blockSize falls back to 64. It also
// starts a background eviction sweep at half the eviction window.
// NOTE(review): the eviction goroutine has no stop channel, so it lives for
// the whole process — fine for a singleton cache, a leak if caches are
// created repeatedly.
func newRadixPrefixCache(blockSize int) *radixPrefixCache {
	if blockSize <= 0 {
		blockSize = 64
	}
	const defaultEvictionDuration = 5 * time.Minute
	const defaultMaxNodes = 200000
	root := &radixNode{
		key:        nil,
		children:   make(map[uint64]*radixNode),
		contextLen: 0,
	}
	cache := &radixPrefixCache{
		root:             root,
		hasher:           newBlockHasher(blockSize),
		evictionDuration: defaultEvictionDuration,
		maxNodes:         defaultMaxNodes,
		nodeCount:        1, // root
		allNodes:         map[*radixNode]struct{}{root: {}},
	}
	go cache.evictionWorker(cache.evictionDuration / 2)
	return cache
}
// Match returns the prefix hit rate per candidate worker (0-100).
// The ratio is matched-hash-count over total hash count at the deepest
// matched node; when no allowed worker is tagged there, the walk continues
// toward the root with a correspondingly shorter matched length and stops
// at the first level that yields any result.
func (c *radixPrefixCache) Match(tokens []int, allowed map[string]struct{}) map[string]int {
	result := make(map[string]int)
	hashes := c.hasher.prefixHashes(tokens)
	if len(hashes) == 0 {
		return result
	}
	c.mu.RLock()
	node, matched := c.matchPrefixHelper(c.root, hashes)
	length := matched
	logger.Debug("radix match: hashes=%d matched_len=%d node_children=%d", len(hashes), matched, len(node.children))
	for n := node; n != nil; n = n.parent {
		ratio := 0
		if len(hashes) > 0 {
			ratio = length * 100 / len(hashes)
		}
		for w := range n.workers {
			// A nil allowed set means "no filtering".
			if allowed != nil {
				if _, ok := allowed[w]; !ok {
					continue
				}
			}
			if ratio > result[w] {
				result[w] = ratio
			}
		}
		// Stop at the first ancestor level that produced any allowed worker.
		if len(result) > 0 {
			break
		}
		if n.parent != nil {
			length = n.parent.contextLen
		}
	}
	c.mu.RUnlock()
	return result
}
// Record inserts the block-hash prefix of tokens into the radix tree and
// tags every node from the inserted leaf back to the root with the worker
// and the current time. No-op for an empty worker name or when the tokens
// are shorter than one full block.
func (c *radixPrefixCache) Record(tokens []int, worker string) {
	if worker == "" {
		return
	}
	hashes := c.hasher.prefixHashes(tokens)
	if len(hashes) == 0 {
		return
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	node := c.insertHelper(c.root, hashes)
	now := time.Now()
	// Tag the whole ancestor chain so partial-prefix matches can find the worker.
	for n := node; n != nil; n = n.parent {
		if n.workers == nil {
			n.workers = make(map[string]time.Time)
		}
		n.workers[worker] = now
	}
	logger.Debug("radix record: worker=%s hashes=%d node_depth=%d", worker, len(hashes), node.contextLen)
}
// evictionWorker periodically evicts inactive nodes.
// NOTE(review): no stop channel — this loop runs for the life of the
// process; see the note on newRadixPrefixCache.
func (c *radixPrefixCache) evictionWorker(interval time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		<-ticker.C
		c.evictExpired()
	}
}
// evictExpired removes, under the write lock, every subtree whose nodes
// have not been inserted into within evictionDuration. Must NOT be called
// while c.mu is already held (it acquires the lock itself).
func (c *radixPrefixCache) evictExpired() {
	c.mu.Lock()
	defer c.mu.Unlock()
	now := time.Now()
	removed := 0
	for childKey, child := range c.root.children {
		removed += c.evictSubtreeIfExpired(c.root, childKey, child, now)
	}
	if removed > 0 {
		logger.Debug("radix eviction: removed=%d nodeCount=%d", removed, c.nodeCount)
	}
}
// evictSubtreeIfExpired walks the subtree rooted at node depth-first,
// evicting expired descendants first and then node itself (with whatever
// remains beneath it) when its lastAccess is older than the eviction
// window. Returns the number of nodes removed. Must be called with c.mu
// held for writing. The redundant second `parent != nil` check from the
// original has been removed: the early return above already guarantees
// parent is non-nil past that point.
func (c *radixPrefixCache) evictSubtreeIfExpired(parent *radixNode, childKey uint64, node *radixNode, now time.Time) int {
	// Process child nodes first so their removals are counted (and their
	// entries detached) before deciding about this node.
	removed := 0
	for k, child := range node.children {
		removed += c.evictSubtreeIfExpired(node, k, child, now)
	}
	// The root node is never evicted.
	if parent == nil {
		return removed
	}
	// Still within the eviction window: keep the node.
	if now.Sub(node.lastAccess) <= c.evictionDuration {
		return removed
	}
	// Expired: detach the node from its parent and drop the remaining subtree.
	delete(parent.children, childKey)
	removedSubtree := c.countSubtree(node)
	c.nodeCount -= removedSubtree
	if c.nodeCount < 1 {
		c.nodeCount = 1 // the root always remains
	}
	c.removeSubtreeFromAll(node)
	return removed + removedSubtree
}
// countSubtree returns the total number of nodes in the subtree rooted at
// node, including node itself. Iterative traversal with an explicit stack.
func (c *radixPrefixCache) countSubtree(node *radixNode) int {
	total := 0
	stack := []*radixNode{node}
	for len(stack) > 0 {
		current := stack[len(stack)-1]
		stack = stack[:len(stack)-1]
		total++
		for _, child := range current.children {
			stack = append(stack, child)
		}
	}
	return total
}
// removeSubtreeFromAll deletes every node of the given subtree from the
// allNodes index and severs each node's internal references so the garbage
// collector can reclaim the whole subtree. Safe for a nil argument.
func (c *radixPrefixCache) removeSubtreeFromAll(node *radixNode) {
	if node == nil {
		return
	}
	pending := []*radixNode{node}
	for len(pending) > 0 {
		current := pending[len(pending)-1]
		pending = pending[:len(pending)-1]
		delete(c.allNodes, current)
		for _, child := range current.children {
			pending = append(pending, child)
		}
		// Release references for GC.
		current.children = nil
		current.parent = nil
		current.workers = nil
	}
}
// matchPrefixHelper finds the longest-common-prefix node in the radix tree.
// Returns the deepest node whose edge chain matches a prefix of hashes,
// together with the number of hashes matched. A match that diverges inside
// an edge still returns that child, but counts only the shared portion.
func (c *radixPrefixCache) matchPrefixHelper(node *radixNode, hashes []uint64) (*radixNode, int) {
	if len(hashes) == 0 {
		return node, node.contextLen
	}
	if child, ok := node.children[hashes[0]]; ok {
		prefixLen := matchUint64Len(child.key, hashes)
		if prefixLen > 0 {
			if prefixLen == len(child.key) {
				// Whole edge consumed: either the query ends exactly here,
				// or we try to descend with the remaining hashes.
				if prefixLen == len(hashes) {
					return child, child.contextLen
				}
				if deeperNode, deeperMatched := c.matchPrefixHelper(child, hashes[prefixLen:]); deeperNode != nil && deeperMatched > 0 {
					return deeperNode, deeperMatched
				}
				return child, child.contextLen
			}
			// Query diverges inside the edge: partial match of this child.
			return child, node.contextLen + prefixLen
		}
	}
	// No child shares even the first hash.
	return node, node.contextLen
}
// insertHelper inserts the hash sequence into the radix tree, splitting
// edges where the key diverges mid-edge. It refreshes lastAccess along the
// descent and returns the node representing the full key. Caller must hold
// c.mu for writing.
func (c *radixPrefixCache) insertHelper(node *radixNode, key []uint64) *radixNode {
	node.lastAccess = time.Now()
	if len(key) == 0 {
		return node
	}
	if child, ok := node.children[key[0]]; ok {
		prefixLen := matchUint64Len(child.key, key)
		if prefixLen == len(child.key) {
			if prefixLen == len(key) {
				// Exact edge match: the node already exists.
				child.lastAccess = time.Now()
				return child
			}
			// Edge fully shared: continue the insert below the child.
			return c.insertHelper(child, key[prefixLen:])
		}
		// Partial match, split required
		newNode := c.splitNode(node, child, prefixLen)
		if prefixLen == len(key) {
			return newNode
		}
		return c.insertHelper(newNode, key[prefixLen:])
	}
	// No matching child, create new node and add to children
	newNode := newRadixNode(node, key)
	node.children[key[0]] = newNode
	c.nodeCount++
	c.allNodes[newNode] = struct{}{}
	c.maybeEvictLocked()
	return newNode
}
// splitNode splits child's edge at prefixLen: a new intermediate node takes
// the shared prefix of child.key, and child keeps only the remainder.
// Returns the new intermediate node. Caller must hold c.mu for writing.
func (c *radixPrefixCache) splitNode(parent *radixNode, child *radixNode, prefixLen int) *radixNode {
	commonKey := append([]uint64{}, child.key[:prefixLen]...)
	newNode := newRadixNode(parent, commonKey)
	parent.children[commonKey[0]] = newNode
	// Re-parent the original child under the intermediate node with the
	// unshared suffix of its key; its total depth (contextLen) is unchanged.
	child.key = append([]uint64{}, child.key[prefixLen:]...)
	child.parent = newNode
	child.contextLen = newNode.contextLen + len(child.key)
	if len(child.key) > 0 {
		newNode.children[child.key[0]] = child
	}
	return newNode
}
// maybeEvictLocked runs an eviction sweep when the node count exceeds
// maxNodes. It is invoked from the insert path, i.e. while c.mu is already
// held for writing.
// BUG FIX: the original called c.evictExpired() here, which acquires
// c.mu.Lock() again — Go mutexes are not reentrant, so exceeding maxNodes
// during an insert self-deadlocked. The sweep is now performed inline,
// assuming the caller already holds the write lock.
func (c *radixPrefixCache) maybeEvictLocked() {
	if c.maxNodes <= 0 || c.nodeCount <= c.maxNodes {
		return
	}
	now := time.Now()
	removed := 0
	for childKey, child := range c.root.children {
		removed += c.evictSubtreeIfExpired(c.root, childKey, child, now)
	}
	if removed > 0 {
		logger.Debug("radix eviction: removed=%d nodeCount=%d", removed, c.nodeCount)
	}
	// TODO: implement stronger eviction if still over capacity (e.g., evict oldest by lastAccess)
}
// newRadixNode builds a radix tree node holding a private copy of key,
// linked to parent, with contextLen set to the total hash depth from the
// root (a nil parent means the key itself is the full depth).
func newRadixNode(parent *radixNode, key []uint64) *radixNode {
	keyCopy := make([]uint64, len(key))
	copy(keyCopy, key)
	node := &radixNode{
		key:        keyCopy,
		children:   map[uint64]*radixNode{},
		parent:     parent,
		lastAccess: time.Now(),
		contextLen: len(key),
	}
	if parent != nil {
		node.contextLen += parent.contextLen
	}
	return node
}
// blockHasher turns token sequences into chained per-block hashes so that
// equal token prefixes yield equal hash prefixes.
type blockHasher struct {
	blockSize int    // number of tokens folded into each hash block
	seed      uint64 // random per-hasher seed chained into the first block
}
// newBlockHasher returns a block hasher with the given block size (falling
// back to 64 for non-positive values) and a time-seeded random seed, so
// hash chains differ between hasher instances.
func newBlockHasher(blockSize int) *blockHasher {
	size := blockSize
	if size <= 0 {
		size = 64
	}
	seedSource := rand.New(rand.NewSource(time.Now().UnixNano()))
	return &blockHasher{
		blockSize: size,
		seed:      seedSource.Uint64(),
	}
}
// prefixHashes folds tokens into one FNV-1a hash per full block, chaining
// each block's hash into the next so that the i-th hash identifies the
// whole prefix of i blocks. Inputs shorter than one block yield nil;
// trailing tokens that do not fill a block are ignored.
func (h *blockHasher) prefixHashes(tokens []int) []uint64 {
	if h.blockSize <= 0 || len(tokens) < h.blockSize {
		return nil
	}
	hashes := make([]uint64, 0, len(tokens)/h.blockSize)
	prev := h.seed
	scratch := make([]byte, 8)
	for start := 0; start+h.blockSize <= len(tokens); start += h.blockSize {
		digest := fnv.New64a()
		// Chain the previous block's hash (or the seed) first.
		binary.LittleEndian.PutUint64(scratch, prev)
		_, _ = digest.Write(scratch)
		for _, tok := range tokens[start : start+h.blockSize] {
			binary.LittleEndian.PutUint64(scratch, uint64(tok))
			_, _ = digest.Write(scratch)
		}
		prev = digest.Sum64()
		hashes = append(hashes, prev)
	}
	return hashes
}
// matchUint64Len returns the length of the longest common prefix of a and b.
func matchUint64Len(a, b []uint64) int {
	limit := len(a)
	if len(b) < limit {
		limit = len(b)
	}
	for idx := 0; idx < limit; idx++ {
		if a[idx] != b[idx] {
			return idx
		}
	}
	return limit
}
// charsToTokens converts a string into one pseudo-token per rune (its
// Unicode code point). Used as a degraded fallback when no real tokenizer
// is available.
func charsToTokens(message string) []int {
	runes := []rune(message)
	tokens := make([]int, len(runes))
	for i, r := range runes {
		tokens[i] = int(r)
	}
	return tokens
}
// toWorkerSet converts a worker slice into a membership set; duplicates
// collapse and a nil slice yields an empty set.
func toWorkerSet(workers []string) map[string]struct{} {
	result := make(map[string]struct{}, len(workers))
	for _, worker := range workers {
		result[worker] = struct{}{}
	}
	return result
}
@@ -0,0 +1,15 @@
package handler
import (
"context"
"math/rand"
)
func RandomSelectWorker(ctx context.Context, workers []string, message string) (string, error) {
if len(workers) == 0 {
return "", nil
}
randomNum := rand.Intn(len(workers))
return workers[randomNum], nil
}
@@ -0,0 +1,21 @@
package handler
import (
"context"
)
// RoundRobinSelectWorker cycles through workers using a shared atomic
// counter; prefill instances rotate on a separate counter from the general
// pool. Never returns an error.
// NOTE(review): the modulo is taken against the current list length, so the
// rotation shifts whenever workers join or leave — presumably acceptable
// here; confirm with callers.
func RoundRobinSelectWorker(ctx context.Context, workers []string, message string) (string, error) {
	if len(workers) == 0 {
		return "", nil
	}
	var count uint64
	// Add(1)-1 yields the pre-increment value, starting the cycle at index 0.
	if DefaultCounterPolicy.workerType == "prefill" {
		count = DefaultCounterPolicy.prefillCounter.Add(1) - 1
	} else {
		count = DefaultCounterPolicy.counter.Add(1) - 1
	}
	selectedNum := count % uint64(len(workers))
	return workers[selectedNum], nil
}
@@ -0,0 +1,108 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
urlpkg "net/url"
"strings"
"time"
)
// TokenizerClient is the abstract remote tokenizer interface: it converts
// a prompt into model token IDs.
type TokenizerClient interface {
	// Tokenize returns the token IDs for prompt, or an error on failure.
	Tokenize(ctx context.Context, prompt string) ([]int, error)
}
// httpTokenizer implements TokenizerClient over an HTTP /tokenize call.
type httpTokenizer struct {
	url     string        // fully-qualified tokenizer endpoint URL
	timeout time.Duration // per-request HTTP timeout
}
// NewHTTPTokenizer returns an HTTP-based tokenizer client for rawURL.
// A missing scheme defaults to http, an empty or root path defaults to
// /tokenize, and a non-positive timeout defaults to 2s. An empty rawURL
// returns nil (tokenization disabled). Unparseable URLs are used as-is
// after the scheme fix-up.
func NewHTTPTokenizer(rawURL string, timeout time.Duration) TokenizerClient {
	if rawURL == "" {
		return nil
	}
	hasScheme := strings.HasPrefix(rawURL, "http://") || strings.HasPrefix(rawURL, "https://")
	if !hasScheme {
		rawURL = "http://" + rawURL
	}
	if parsed, err := urlpkg.Parse(rawURL); err == nil {
		switch parsed.Path {
		case "", "/":
			parsed.Path = "/tokenize"
		}
		rawURL = parsed.String()
	}
	effectiveTimeout := timeout
	if effectiveTimeout <= 0 {
		effectiveTimeout = 2 * time.Second
	}
	return &httpTokenizer{
		url:     rawURL,
		timeout: effectiveTimeout,
	}
}
// tokenizerHTTPReq duplicates the prompt under three common field names so
// the request works against servers expecting "text", "prompt", or "message".
type tokenizerHTTPReq struct {
	Text    string `json:"text,omitempty"`
	Prompt  string `json:"prompt,omitempty"`
	Message string `json:"message,omitempty"`
}
// Tokenize POSTs the prompt as JSON to the configured endpoint and returns
// the parsed token IDs. Errors cover a nil receiver, request building,
// transport failures (including ctx cancellation and the client timeout),
// any status >= 300, and unparsable response bodies.
// NOTE(review): a fresh http.Client is built per call; connections are
// still pooled because it uses the shared http.DefaultTransport.
func (c *httpTokenizer) Tokenize(ctx context.Context, prompt string) ([]int, error) {
	if c == nil {
		return nil, errors.New("tokenizer client is nil")
	}
	// Fill all three aliases so differing server schemas are covered.
	payload := tokenizerHTTPReq{Text: prompt, Prompt: prompt, Message: prompt}
	body, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("marshal tokenizer request failed: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("build tokenizer request failed: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	client := &http.Client{Timeout: c.timeout}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("tokenizer request failed: %w", err)
	}
	defer resp.Body.Close()
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read tokenizer response failed: %w", err)
	}
	if resp.StatusCode >= 300 {
		return nil, fmt.Errorf("tokenizer response status %d: %s", resp.StatusCode, string(respBody))
	}
	tokens, err := parseTokensFromBody(respBody)
	if err != nil {
		return nil, fmt.Errorf("parse tokenizer response failed: %w", err)
	}
	return tokens, nil
}
// parseTokensFromBody extracts the "input_ids" array from a tokenizer
// response body (JSON of the form {"input_ids": [...]}) and returns an
// error when the body is invalid JSON or the field is absent, null, or empty.
func parseTokensFromBody(body []byte) ([]int, error) {
	var payload struct {
		InputIDs []int `json:"input_ids"`
	}
	if err := json.Unmarshal(body, &payload); err != nil || len(payload.InputIDs) == 0 {
		return nil, errors.New("tokenizer response missing tokens")
	}
	return payload.InputIDs, nil
}
@@ -0,0 +1,650 @@
package handler
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
)
func TestNewHTTPTokenizer(t *testing.T) {
tests := []struct {
name string
rawURL string
timeout time.Duration
expected string
wantNil bool
}{
{
name: "empty URL should return nil",
rawURL: "",
timeout: time.Second,
wantNil: true,
},
{
name: "URL without scheme should add http://",
rawURL: "example.com",
timeout: time.Second,
expected: "http://example.com/tokenize",
},
{
name: "URL with http scheme should keep it",
rawURL: "http://example.com",
timeout: time.Second,
expected: "http://example.com/tokenize",
},
{
name: "URL with https scheme should keep it",
rawURL: "https://example.com",
timeout: time.Second,
expected: "https://example.com/tokenize",
},
{
name: "URL with path should keep it",
rawURL: "http://example.com/api",
timeout: time.Second,
expected: "http://example.com/api",
},
{
name: "URL with root path should replace with /tokenize",
rawURL: "http://example.com/",
timeout: time.Second,
expected: "http://example.com/tokenize",
},
{
name: "URL with empty path should add /tokenize",
rawURL: "http://example.com",
timeout: time.Second,
expected: "http://example.com/tokenize",
},
{
name: "invalid URL should still work with added scheme",
rawURL: "example.com:invalid:port",
timeout: time.Second,
expected: "http://example.com:invalid:port",
},
{
name: "zero timeout should use default 2s",
rawURL: "example.com",
timeout: 0,
expected: "http://example.com/tokenize",
},
{
name: "negative timeout should use default 2s",
rawURL: "example.com",
timeout: -time.Second,
expected: "http://example.com/tokenize",
},
{
name: "URL with port and path should be preserved",
rawURL: "http://example.com:8080/v1",
timeout: time.Second,
expected: "http://example.com:8080/v1",
},
{
name: "URL with query parameters should be preserved",
rawURL: "https://example.com?token=abc",
timeout: time.Second,
expected: "https://example.com/tokenize?token=abc",
},
{
name: "URL with fragment should be preserved",
rawURL: "http://example.com#section",
timeout: time.Second,
expected: "http://example.com/tokenize#section",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := NewHTTPTokenizer(tt.rawURL, tt.timeout)
if tt.wantNil {
if client != nil {
t.Errorf("Expected nil client for empty URL, got %v", client)
}
return
}
if client == nil {
t.Fatal("Expected non-nil client, got nil")
}
httpClient, ok := client.(*httpTokenizer)
if !ok {
t.Fatalf("Expected *httpTokenizer, got %T", client)
}
if httpClient.url != tt.expected {
t.Errorf("Expected URL %q, got %q", tt.expected, httpClient.url)
}
expectedTimeout := tt.timeout
if expectedTimeout <= 0 {
expectedTimeout = 2 * time.Second
}
if httpClient.timeout != expectedTimeout {
t.Errorf("Expected timeout %v, got %v", expectedTimeout, httpClient.timeout)
}
})
}
}
// TestNewHTTPTokenizer_URLParsing verifies that the constructor preserves
// every URL component (scheme, userinfo, host:port, path, query, fragment)
// when the input already carries an explicit non-root path.
func TestNewHTTPTokenizer_URLParsing(t *testing.T) {
	// Test that URL parsing preserves all components
	rawURL := "https://user:pass@example.com:8080/path?query=value#fragment"
	client := NewHTTPTokenizer(rawURL, time.Second)
	if client == nil {
		t.Fatal("Expected non-nil client, got nil")
	}
	httpClient := client.(*httpTokenizer)
	parsed, err := url.Parse(httpClient.url)
	if err != nil {
		t.Fatalf("Failed to parse client URL: %v", err)
	}
	if parsed.Scheme != "https" {
		t.Errorf("Expected scheme https, got %q", parsed.Scheme)
	}
	if parsed.Host != "example.com:8080" {
		t.Errorf("Expected host example.com:8080, got %q", parsed.Host)
	}
	if parsed.Path != "/path" {
		t.Errorf("Expected path /path, got %q", parsed.Path)
	}
	if parsed.RawQuery != "query=value" {
		t.Errorf("Expected query query=value, got %q", parsed.RawQuery)
	}
	if parsed.Fragment != "fragment" {
		t.Errorf("Expected fragment fragment, got %q", parsed.Fragment)
	}
	if parsed.User.String() != "user:pass" {
		t.Errorf("Expected user user:pass, got %q", parsed.User)
	}
}
// TestNewHTTPTokenizer_ImplementsInterface checks that the constructor's
// return value satisfies the TokenizerClient interface.
func TestNewHTTPTokenizer_ImplementsInterface(t *testing.T) {
	client := NewHTTPTokenizer("example.com", time.Second)
	if client == nil {
		t.Fatal("Expected non-nil client")
	}
	// Verify that the returned client implements the TokenizerClient interface
	var _ TokenizerClient = client
	// Test type assertion
	_, ok := client.(TokenizerClient)
	if !ok {
		t.Error("Returned client does not implement TokenizerClient interface")
	}
}
func TestHTTPTokenizer_Tokenize(t *testing.T) {
t.Run("nil tokenizer client", func(t *testing.T) {
var tokenizer *httpTokenizer = nil
_, err := tokenizer.Tokenize(context.Background(), "test prompt")
if err == nil {
t.Error("Expected error for nil tokenizer client")
}
if err.Error() != "tokenizer client is nil" {
t.Errorf("Expected 'tokenizer client is nil', got '%v'", err.Error())
}
})
t.Run("successful tokenization", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// httptest server uses root path by default
if r.URL.Path != "/" {
t.Errorf("Expected path /, got %s", r.URL.Path)
}
if r.Method != "POST" {
t.Errorf("Expected POST method, got %s", r.Method)
}
if r.Header.Get("Content-Type") != "application/json" {
t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
}
// Verify request body
var req tokenizerHTTPReq
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
t.Fatalf("Failed to decode request: %v", err)
}
if req.Text != "test prompt" {
t.Errorf("Expected Text 'test prompt', got '%s'", req.Text)
}
if req.Prompt != "test prompt" {
t.Errorf("Expected Prompt 'test prompt', got '%s'", req.Prompt)
}
if req.Message != "test prompt" {
t.Errorf("Expected Message 'test prompt', got '%s'", req.Message)
}
// Send successful response
response := map[string]interface{}{
"input_ids": []int{1, 2, 3, 4, 5},
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
tokenizer := &httpTokenizer{
url: server.URL,
timeout: 2 * time.Second,
}
tokens, err := tokenizer.Tokenize(context.Background(), "test prompt")
if err != nil {
t.Fatalf("Tokenize failed: %v", err)
}
expectedTokens := []int{1, 2, 3, 4, 5}
if len(tokens) != len(expectedTokens) {
t.Errorf("Expected %d tokens, got %d", len(expectedTokens), len(tokens))
}
for i, token := range tokens {
if token != expectedTokens[i] {
t.Errorf("Token at index %d: expected %d, got %d", i, expectedTokens[i], token)
}
}
})
t.Run("http request creation failure", func(t *testing.T) {
tokenizer := &httpTokenizer{
url: "://invalid-url", // Invalid URL to force http.NewRequest to fail
timeout: 2 * time.Second,
}
_, err := tokenizer.Tokenize(context.Background(), "test prompt")
if err == nil {
t.Error("Expected error for invalid URL")
}
if !strings.Contains(err.Error(), "build tokenizer request failed") {
t.Errorf("Expected error to contain 'build tokenizer request failed', got '%v'", err.Error())
}
})
t.Run("http request timeout", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(100 * time.Millisecond) // Simulate slow response
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]interface{}{"input_ids": []int{1}})
}))
defer server.Close()
tokenizer := &httpTokenizer{
url: server.URL,
timeout: 10 * time.Millisecond, // Very short timeout
}
_, err := tokenizer.Tokenize(context.Background(), "test prompt")
if err == nil {
t.Error("Expected error for timeout")
}
if !strings.Contains(err.Error(), "tokenizer request failed") {
t.Errorf("Expected error to contain 'tokenizer request failed', got '%v'", err.Error())
}
})
t.Run("http connection failure", func(t *testing.T) {
tokenizer := &httpTokenizer{
url: "http://invalid-server:9999", // Non-existent server
timeout: 1 * time.Second,
}
_, err := tokenizer.Tokenize(context.Background(), "test prompt")
if err == nil {
t.Error("Expected error for connection failure")
}
if !strings.Contains(err.Error(), "tokenizer request failed") {
t.Errorf("Expected error to contain 'tokenizer request failed', got '%v'", err.Error())
}
})
t.Run("http response status error", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("internal server error"))
}))
defer server.Close()
tokenizer := &httpTokenizer{
url: server.URL,
timeout: 2 * time.Second,
}
_, err := tokenizer.Tokenize(context.Background(), "test prompt")
if err == nil {
t.Error("Expected error for status code 500")
}
if !strings.Contains(err.Error(), "tokenizer response status 500") {
t.Errorf("Expected error to contain 'tokenizer response status 500', got '%v'", err.Error())
}
if !strings.Contains(err.Error(), "internal server error") {
t.Errorf("Expected error to contain response body, got '%v'", err.Error())
}
})
t.Run("invalid json response", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte("invalid json content"))
}))
defer server.Close()
tokenizer := &httpTokenizer{
url: server.URL,
timeout: 2 * time.Second,
}
_, err := tokenizer.Tokenize(context.Background(), "test prompt")
if err == nil {
t.Error("Expected error for invalid JSON")
}
if !strings.Contains(err.Error(), "parse tokenizer response failed") {
t.Errorf("Expected error to contain 'parse tokenizer response failed', got '%v'", err.Error())
}
})
t.Run("empty tokens in response", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := map[string]interface{}{
"input_ids": []int{},
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
tokenizer := &httpTokenizer{
url: server.URL,
timeout: 2 * time.Second,
}
_, err := tokenizer.Tokenize(context.Background(), "test prompt")
if err == nil {
t.Error("Expected error for empty tokens")
}
if !strings.Contains(err.Error(), "tokenizer response missing tokens") {
t.Errorf("Expected error to contain 'tokenizer response missing tokens', got '%v'", err.Error())
}
})
t.Run("missing input_ids field", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := map[string]interface{}{
"tokens": []int{1, 2, 3}, // Wrong field name
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
tokenizer := &httpTokenizer{
url: server.URL,
timeout: 2 * time.Second,
}
_, err := tokenizer.Tokenize(context.Background(), "test prompt")
if err == nil {
t.Error("Expected error for missing input_ids")
}
if !strings.Contains(err.Error(), "tokenizer response missing tokens") {
t.Errorf("Expected error to contain 'tokenizer response missing tokens', got '%v'", err.Error())
}
})
t.Run("context cancellation", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(100 * time.Millisecond) // Simulate slow processing
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]interface{}{"input_ids": []int{1}})
}))
defer server.Close()
tokenizer := &httpTokenizer{
url: server.URL,
timeout: 2 * time.Second,
}
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
_, err := tokenizer.Tokenize(ctx, "test prompt")
if err == nil {
t.Error("Expected error for cancelled context")
}
// Context cancellation should cause request to fail
if !strings.Contains(err.Error(), "tokenizer request failed") {
t.Errorf("Expected error to contain 'tokenizer request failed', got '%v'", err.Error())
}
})
t.Run("empty prompt", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req tokenizerHTTPReq
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
t.Fatalf("Failed to decode request: %v", err)
}
if req.Text != "" {
t.Errorf("Expected empty Text, got '%s'", req.Text)
}
if req.Prompt != "" {
t.Errorf("Expected empty Prompt, got '%s'", req.Prompt)
}
if req.Message != "" {
t.Errorf("Expected empty Message, got '%s'", req.Message)
}
response := map[string]interface{}{
"input_ids": []int{},
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
tokenizer := &httpTokenizer{
url: server.URL,
timeout: 2 * time.Second,
}
_, err := tokenizer.Tokenize(context.Background(), "")
if err == nil {
t.Error("Expected error for empty prompt")
}
if !strings.Contains(err.Error(), "tokenizer response missing tokens") {
t.Errorf("Expected error to contain 'tokenizer response missing tokens', got '%v'", err.Error())
}
})
t.Run("very long prompt", func(t *testing.T) {
longPrompt := string(make([]byte, 10000)) // 10KB prompt
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req tokenizerHTTPReq
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
t.Fatalf("Failed to decode request: %v", err)
}
if req.Text != longPrompt {
t.Error("Request text does not match long prompt")
}
response := map[string]interface{}{
"input_ids": []int{1, 2, 3},
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
tokenizer := &httpTokenizer{
url: server.URL,
timeout: 2 * time.Second,
}
tokens, err := tokenizer.Tokenize(context.Background(), longPrompt)
if err != nil {
t.Fatalf("Tokenize failed for long prompt: %v", err)
}
expectedTokens := []int{1, 2, 3}
if len(tokens) != len(expectedTokens) {
t.Errorf("Expected %d tokens, got %d", len(expectedTokens), len(tokens))
}
for i, token := range tokens {
if token != expectedTokens[i] {
t.Errorf("Token at index %d: expected %d, got %d", i, expectedTokens[i], token)
}
}
})
t.Run("special characters in prompt", func(t *testing.T) {
specialPrompt := "Hello, 世界! 🚀 Test with emoji and unicode: ñáéíóú"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req tokenizerHTTPReq
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
t.Fatalf("Failed to decode request: %v", err)
}
if req.Text != specialPrompt {
t.Errorf("Expected Text '%s', got '%s'", specialPrompt, req.Text)
}
response := map[string]interface{}{
"input_ids": []int{10, 20, 30, 40},
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
tokenizer := &httpTokenizer{
url: server.URL,
timeout: 2 * time.Second,
}
tokens, err := tokenizer.Tokenize(context.Background(), specialPrompt)
if err != nil {
t.Fatalf("Tokenize failed for special characters: %v", err)
}
expectedTokens := []int{10, 20, 30, 40}
if len(tokens) != len(expectedTokens) {
t.Errorf("Expected %d tokens, got %d", len(expectedTokens), len(tokens))
}
for i, token := range tokens {
if token != expectedTokens[i] {
t.Errorf("Token at index %d: expected %d, got %d", i, expectedTokens[i], token)
}
}
})
}
// TestParseTokensFromBody exercises parseTokensFromBody against a table of
// well-formed and malformed JSON bodies, checking both the returned token
// slice and the produced error.
func TestParseTokensFromBody(t *testing.T) {
	cases := []struct {
		name     string
		input    []byte
		expected []int
		err      error
	}{
		{
			name:     "valid input with tokens",
			input:    []byte(`{"input_ids": [1, 2, 3]}`),
			expected: []int{1, 2, 3},
			err:      nil,
		},
		{
			name:     "empty input_ids array",
			input:    []byte(`{"input_ids": []}`),
			expected: nil,
			err:      errors.New("tokenizer response missing tokens"),
		},
		{
			name:     "missing input_ids field",
			input:    []byte(`{"other_field": "value"}`),
			expected: nil,
			err:      errors.New("tokenizer response missing tokens"),
		},
		{
			name:     "invalid JSON format",
			input:    []byte(`invalid json`),
			expected: nil,
			err:      errors.New("tokenizer response missing tokens"),
		},
		{
			name:     "empty body",
			input:    []byte(``),
			expected: nil,
			err:      errors.New("tokenizer response missing tokens"),
		},
		{
			name:     "large array of tokens",
			input:    []byte(`{"input_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}`),
			expected: []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
			err:      nil,
		},
		{
			name:     "null input_ids",
			input:    []byte(`{"input_ids": null}`),
			expected: nil,
			err:      errors.New("tokenizer response missing tokens"),
		},
		{
			name:     "non-array input_ids",
			input:    []byte(`{"input_ids": "not an array"}`),
			expected: nil,
			err:      errors.New("tokenizer response missing tokens"),
		},
		{
			name:     "malformed array",
			input:    []byte(`{"input_ids": [1, "two", 3]}`),
			expected: nil,
			err:      errors.New("tokenizer response missing tokens"),
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := parseTokensFromBody(tc.input)
			switch {
			// Exactly one of (err, tc.err) being nil is a mismatch.
			case err == nil && tc.err != nil, err != nil && tc.err == nil:
				t.Errorf("parseTokensFromBody() error = %v, wantErr %v", err, tc.err)
				return
			// Both non-nil: messages must match (errors.New values can only
			// be compared by text).
			case err != nil && err.Error() != tc.err.Error():
				t.Errorf("parseTokensFromBody() error message = %v, want %v", err.Error(), tc.err.Error())
				return
			}
			// Token slice must match element-for-element.
			if len(got) != len(tc.expected) {
				t.Errorf("parseTokensFromBody() = %v, want %v", got, tc.expected)
				return
			}
			for i, tok := range got {
				if tok != tc.expected[i] {
					t.Errorf("parseTokensFromBody() = %v, want %v", got, tc.expected)
					return
				}
			}
		})
	}
}
@@ -0,0 +1,80 @@
package logger
import (
"log"
"os"
"sync"
)
// Package-level logger state; all fields are assigned exactly once by Init.
var (
	infoLogger  *log.Logger // nil until Init is called
	errorLogger *log.Logger // writes unconditionally (see Error)
	warnLogger  *log.Logger
	debugLogger *log.Logger
	level       string    // one of "debug"/"info"/"warn"/"error"; filters Info/Warn/Debug
	once        sync.Once // guards Init so initialization runs at most once per process
	logFile     *os.File  // open handle when output == "file"; closed by CloseLogFile
)
// Init initializes the package loggers exactly once (guarded by sync.Once;
// subsequent calls are no-ops).
//
// logLevel selects which of Info/Warn/Debug actually print (Error always
// prints). output selects the sink: "file" appends to logs/router.log,
// anything else writes to stdout (info/warn/debug) and stderr (error).
func Init(logLevel, output string) {
	once.Do(func() {
		level = logLevel
		flags := log.LstdFlags | log.Lshortfile
		if output == "file" {
			// MkdirAll is a no-op when the directory already exists, so no
			// os.Stat pre-check is needed.
			if err := os.MkdirAll("logs", 0755); err != nil {
				log.Fatalln("Failed to create logs directory:", err)
			}
			f, err := os.OpenFile("logs/router.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
			if err != nil {
				log.Fatalln("Failed to open log file:", err)
			}
			// Assign to the package-level variable. The previous `logFile,
			// err := ...` declared a new local that shadowed it, leaving the
			// package variable nil so CloseLogFile could never close the file.
			logFile = f
			infoLogger = log.New(logFile, "[INFO] ", flags)
			errorLogger = log.New(logFile, "[ERROR] ", flags)
			warnLogger = log.New(logFile, "[WARN] ", flags)
			debugLogger = log.New(logFile, "[DEBUG] ", flags)
		} else {
			infoLogger = log.New(os.Stdout, "[INFO] ", flags)
			errorLogger = log.New(os.Stderr, "[ERROR] ", flags)
			warnLogger = log.New(os.Stdout, "[WARN] ", flags)
			debugLogger = log.New(os.Stdout, "[DEBUG] ", flags)
		}
	})
}
// CloseLogFile closes the log file opened by Init when output was "file".
// It is safe to call when logging went to stdout/stderr (logFile is nil).
func CloseLogFile() {
	if logFile == nil {
		return
	}
	logFile.Close()
}
// Info logs informational messages; suppressed at the "warn" and "error"
// levels.
func Info(format string, v ...interface{}) {
	switch level {
	case "debug", "info":
		infoLogger.Printf(format, v...)
	}
}
// Error logs error messages unconditionally, regardless of the configured
// level. NOTE(review): panics with a nil-pointer dereference if called
// before Init — callers must initialize the logger first.
func Error(format string, v ...interface{}) {
	errorLogger.Printf(format, v...)
}
// Warn logs warning messages; suppressed only at the "error" level.
func Warn(format string, v ...interface{}) {
	switch level {
	case "debug", "info", "warn":
		warnLogger.Printf(format, v...)
	}
}
// Debug logs debug messages; printed only at the "debug" level.
func Debug(format string, v ...interface{}) {
	if level != "debug" {
		return
	}
	debugLogger.Printf(format, v...)
}
@@ -0,0 +1,127 @@
package logger
import (
	"bytes"
	"os"
	"strings"
	"sync"
	"testing"
)
func TestLoggerInit(t *testing.T) {
t.Run("stdout output", func(t *testing.T) {
Init("debug", "stdout")
if infoLogger == nil || errorLogger == nil || warnLogger == nil || debugLogger == nil {
t.Error("Loggers should be initialized")
}
})
t.Run("file output", func(t *testing.T) {
// Clean up existing log file and directory
_ = os.RemoveAll("logs")
_ = os.MkdirAll("logs", 0755)
defer os.RemoveAll("logs")
Init("debug", "file")
if _, err := os.Stat("logs/router.log"); os.IsNotExist(err) {
t.Error("Log file should be created")
}
})
}
func TestLogLevels(t *testing.T) {
tests := []struct {
name string
level string
expected map[string]bool
}{
{"debug level", "debug", map[string]bool{
"debug": true,
"info": true,
"warn": true,
"error": true,
}},
{"info level", "info", map[string]bool{
"debug": false,
"info": true,
"warn": true,
"error": true,
}},
{"warn level", "warn", map[string]bool{
"debug": false,
"info": false,
"warn": true,
"error": true,
}},
{"error level", "error", map[string]bool{
"debug": false,
"info": false,
"warn": false,
"error": true,
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Initialize logger with test level
Init(tt.level, "stdout")
// Capture output for each level separately
testLevel := func(logFunc func(string, ...interface{}), message string) bool {
var buf bytes.Buffer
oldOutput := infoLogger.Writer()
infoLogger.SetOutput(&buf)
errorLogger.SetOutput(&buf)
warnLogger.SetOutput(&buf)
debugLogger.SetOutput(&buf)
logFunc(message)
infoLogger.SetOutput(oldOutput)
errorLogger.SetOutput(oldOutput)
warnLogger.SetOutput(oldOutput)
debugLogger.SetOutput(oldOutput)
return strings.Contains(buf.String(), message)
}
debugPrinted := testLevel(Debug, "debug message")
infoPrinted := testLevel(Info, "info message")
warnPrinted := testLevel(Warn, "warn message")
errorPrinted := testLevel(Error, "error message")
// Check expected behavior
if tt.expected["debug"] != debugPrinted {
t.Errorf("Debug log: expected %v, got %v", tt.expected["debug"], debugPrinted)
}
if tt.expected["info"] != infoPrinted {
t.Errorf("Info log: expected %v, got %v", tt.expected["info"], infoPrinted)
}
if tt.expected["warn"] != warnPrinted {
t.Errorf("Warn log: expected %v, got %v", tt.expected["warn"], warnPrinted)
}
if tt.expected["error"] != errorPrinted {
t.Errorf("Error log: expected %v, got %v", tt.expected["error"], errorPrinted)
}
})
}
}
func TestLogFunctions(t *testing.T) {
var buf bytes.Buffer
Init("debug", "stdout")
// Redirect output
oldOutput := infoLogger.Writer()
defer func() { infoLogger.SetOutput(oldOutput) }()
infoLogger.SetOutput(&buf)
Info("test %s", "message")
if !strings.Contains(buf.String(), "test message") {
t.Error("Info log should contain the message")
}
// Similar tests for Error, Warn, Debug...
}
@@ -0,0 +1,39 @@
package metrics
import (
"github.com/prometheus/client_golang/prometheus"
)
// init registers all router metrics with the default Prometheus registry at
// package load time; MustRegister panics on duplicate registration.
func init() {
	prometheus.MustRegister(
		TotalRequests,
		InferenceRequests,
		RequestDuration,
	)
}
// TotalRequests tracks the total number of HTTP requests,
// partitioned by method, endpoint and response status code.
var TotalRequests = prometheus.NewCounterVec(
	prometheus.CounterOpts{
		Name: "http_requests_total",
		Help: "Total number of HTTP requests",
	},
	[]string{"method", "endpoint", "status_code"},
)

// InferenceRequests tracks the total number of inference requests,
// partitioned by the workers involved and the response status code.
// NOTE(review): label values for mixed/prefill/decode workers presumably
// identify the serving instances — confirm against the callers.
var InferenceRequests = prometheus.NewCounterVec(
	prometheus.CounterOpts{
		Name: "inference_requests_total",
		Help: "Total number of inference requests",
	},
	[]string{"mixed_worker", "prefill_worker", "decode_worker", "status_code"},
)

// RequestDuration tracks the response latency of HTTP requests as a summary
// with p95 and p99 quantiles (each with a 1% allowed error).
var RequestDuration = prometheus.NewSummaryVec(
	prometheus.SummaryOpts{
		Name:       "http_request_duration_seconds",
		Help:       "Summary of the response latency (seconds) of HTTP requests",
		Objectives: map[float64]float64{0.95: 0.01, 0.99: 0.01}, // Objectives define the required quantiles
	},
	[]string{"method", "endpoint"},
)
@@ -0,0 +1,45 @@
package metrics
import (
"strings"
"testing"
"github.com/prometheus/client_golang/prometheus"
)
// TestMetricsInitialization checks that every collector was already
// registered by the package init(): a second Register must fail with an
// AlreadyRegisteredError.
func TestMetricsInitialization(t *testing.T) {
	for _, c := range []prometheus.Collector{
		TotalRequests,
		InferenceRequests,
		RequestDuration,
	} {
		err := prometheus.Register(c)
		if err == nil {
			t.Errorf("Metric %T should already be registered", c)
		}
	}
}
// TestMetricsHelpText verifies that each metric's Desc carries the expected
// help string.
func TestMetricsHelpText(t *testing.T) {
	cases := []struct {
		name     string
		metric   prometheus.Collector
		expected string
	}{
		{"TotalRequests", TotalRequests, "Total number of HTTP requests"},
		{"InferenceRequests", InferenceRequests, "Total number of inference requests"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Each metric vec describes a single Desc, so a one-slot buffer
			// lets the synchronous Describe call complete without blocking.
			ch := make(chan *prometheus.Desc, 1)
			tc.metric.Describe(ch)
			got := (<-ch).String()
			if !strings.Contains(got, tc.expected) {
				t.Errorf("Expected help text to contain '%s', got '%s'", tc.expected, got)
			}
		})
	}
}
+11
View File
@@ -0,0 +1,11 @@
#!/bin/bash
# Restart the fd-router daemon: stop any running instance, wait for it to
# exit, then start a fresh one in the background (port 8080, splitwise mode).

# pgrep -x matches the exact process name — more robust than the previous
# `ps -ef | grep | grep -v grep | awk` pipeline, and it cannot match this
# script itself.
PIDS=$(pgrep -x fd-router || true)
if [ -n "$PIDS" ]; then
    echo "Killing existing fd-router process (PID: $PIDS)"
    # Graceful shutdown first; ignore failures for processes that already exited.
    kill -15 $PIDS 2>/dev/null || true
    # Wait (up to ~5s) for the old process to terminate so its listen port is
    # free before the new instance starts — previously the restart raced the
    # shutdown and could fail to bind.
    for _ in $(seq 1 50); do
        pgrep -x fd-router > /dev/null || break
        sleep 0.1
    done
fi
echo "Starting new fd-router process..."
nohup /usr/local/bin/fd-router --port 8080 --splitwise > fd-router.log 2>&1 &
echo "fd-router started with PID: $!"
+1
View File
@@ -45,3 +45,4 @@ omit =
*/fastdeploy/worker/iluvatar*.py */fastdeploy/worker/iluvatar*.py
*/fastdeploy/**/xpu/* */fastdeploy/**/xpu/*
*/fastdeploy/worker/xpu*.py */fastdeploy/worker/xpu*.py
*/fastdeploy/golang_router/*