mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Feature] Add Golang-based Router for Request Scheduling and Load Balancing (#5882)
* [Feature] add golang router * [Feature] add golang router * [Feature] add golang router * [Feature] add golang router * [Feature] add golang router * [Feature] Add Golang-based Router for Request Scheduling and Load Balancing * [Feature] Add Golang-based Router for Request Scheduling and Load Balancing * [Feature] Add Golang-based Router for Request Scheduling and Load Balancing * [Feature] Add Golang-based Router for Request Scheduling and Load Balancing --------- Co-authored-by: mouxin <mouxin@baidu.com>
This commit is contained in:
@@ -0,0 +1,107 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
# Test splitwise deployment
|
||||||
|
# There are two methods for splitwise deployment:
|
||||||
|
# v0: using splitwise_scheduler or dp_scheduler
|
||||||
|
# v1: using local_scheduler + router
|
||||||
|
# v2: using local_scheduler + golang_router
|
||||||
|
|
||||||
|
# prepare environment
|
||||||
|
export MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle"
|
||||||
|
export FD_DEBUG=1
|
||||||
|
|
||||||
|
SCRIPT_PATH=$(readlink -f "$0")
|
||||||
|
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
|
||||||
|
source ${SCRIPT_DIR}/utils.sh
|
||||||
|
|
||||||
|
unset http_proxy && unset https_proxy
|
||||||
|
|
||||||
|
P_PORT=52400
|
||||||
|
D_PORT=52500
|
||||||
|
ROUTER_PORT=52700
|
||||||
|
LOG_DATE=$(date +%Y%m%d_%H%M%S)
|
||||||
|
|
||||||
|
FD_BIN_DIR="/usr/local/bin"
|
||||||
|
FD_ROUTER_BIN="${FD_BIN_DIR}/fd-router"
|
||||||
|
FD_ROUTER_URL="https://paddle-qa.bj.bcebos.com/FastDeploy/fd-router"
|
||||||
|
FD_ROUTER_SHA256="67640aaeebdd886826d3534930b2154cd2c1441a26bc3f38c3af5f0aadba7c2d"
|
||||||
|
|
||||||
|
ports=($P_PORT $D_PORT $ROUTER_PORT)
|
||||||
|
check_ports "${ports[@]}" || {
|
||||||
|
echo "❌ Some ports are in use. Please release them."
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
# check fd-router binary
|
||||||
|
if [ ! -x "${FD_ROUTER_BIN}" ]; then
|
||||||
|
echo "⚠️ fd-router not found, downloading..."
|
||||||
|
|
||||||
|
mkdir -p "${FD_BIN_DIR}"
|
||||||
|
TMP_BIN="${FD_ROUTER_BIN}.tmp"
|
||||||
|
|
||||||
|
wget -q --no-proxy "${FD_ROUTER_URL}" -O "${TMP_BIN}" || exit 1
|
||||||
|
|
||||||
|
echo "${FD_ROUTER_SHA256} ${TMP_BIN}" | sha256sum -c - || {
|
||||||
|
echo "❌ Integrity check failed"
|
||||||
|
rm -f "${TMP_BIN}"
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
mv "${TMP_BIN}" "${FD_ROUTER_BIN}"
|
||||||
|
chmod +x "${FD_ROUTER_BIN}"
|
||||||
|
|
||||||
|
echo "fd-router installed and verified"
|
||||||
|
else
|
||||||
|
echo "fd-router already exists"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# start router
|
||||||
|
export FD_LOG_DIR="log/$LOG_DATE/router"
|
||||||
|
rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR}
|
||||||
|
|
||||||
|
nohup /usr/local/bin/fd-router \
|
||||||
|
--port ${ROUTER_PORT} \
|
||||||
|
--splitwise \
|
||||||
|
2>&1 >${FD_LOG_DIR}/nohup &
|
||||||
|
|
||||||
|
# start prefill
|
||||||
|
export CUDA_VISIBLE_DEVICES=0
|
||||||
|
export FD_LOG_DIR="log/$LOG_DATE/prefill"
|
||||||
|
rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR}
|
||||||
|
|
||||||
|
nohup python -m fastdeploy.entrypoints.openai.api_server \
|
||||||
|
--model ${MODEL_NAME} \
|
||||||
|
--port "${P_PORT}" \
|
||||||
|
--splitwise-role "prefill" \
|
||||||
|
--router "0.0.0.0:${ROUTER_PORT}" \
|
||||||
|
2>&1 >${FD_LOG_DIR}/nohup &
|
||||||
|
|
||||||
|
wait_for_health ${P_PORT}
|
||||||
|
|
||||||
|
# start decode
|
||||||
|
export CUDA_VISIBLE_DEVICES=1
|
||||||
|
export FD_LOG_DIR="log/$LOG_DATE/decode"
|
||||||
|
rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR}
|
||||||
|
|
||||||
|
nohup python -m fastdeploy.entrypoints.openai.api_server \
|
||||||
|
--model ${MODEL_NAME} \
|
||||||
|
--port "${D_PORT}" \
|
||||||
|
--splitwise-role "decode" \
|
||||||
|
--router "0.0.0.0:${ROUTER_PORT}" \
|
||||||
|
2>&1 >${FD_LOG_DIR}/nohup &
|
||||||
|
|
||||||
|
wait_for_health ${D_PORT}
|
||||||
|
|
||||||
|
# send request
|
||||||
|
sleep 10 # make sure server is registered to router
|
||||||
|
echo "send request..."
|
||||||
|
curl -X POST "http://0.0.0.0:${ROUTER_PORT}/v1/chat/completions" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "hello"}
|
||||||
|
],
|
||||||
|
"max_tokens": 100,
|
||||||
|
"stream": false
|
||||||
|
}'
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
HOMEDIR := $(shell pwd)
|
||||||
|
OUTDIR := /usr/local/bin
|
||||||
|
|
||||||
|
export GOENV = $(HOMEDIR)/go.env
|
||||||
|
|
||||||
|
GOPKGS := $$(go list ./...| grep -vE "vendor")
|
||||||
|
|
||||||
|
# make, make all
|
||||||
|
all: prepare compile
|
||||||
|
|
||||||
|
prepare:
|
||||||
|
git version # 低于 2.17.1 可能不能正常工作
|
||||||
|
go env
|
||||||
|
go mod download || go mod download -x # 下载 依赖
|
||||||
|
|
||||||
|
#make compile
|
||||||
|
compile: build
|
||||||
|
build:
|
||||||
|
go build -o $(OUTDIR)/fd-router ./cmd
|
||||||
|
|
||||||
|
# make test, test your code
|
||||||
|
test: prepare
|
||||||
|
go test -race -timeout=120s -v -cover $(GOPKGS) -coverprofile=coverage.out | tee unittest.txt
|
||||||
|
|
||||||
|
# make clean
|
||||||
|
clean:
|
||||||
|
go clean
|
||||||
|
rm -rf $(OUTDIR)
|
||||||
|
|
||||||
|
# avoid filename conflict and speed up build
|
||||||
|
.PHONY: all prepare compile test package clean build
|
||||||
@@ -0,0 +1,184 @@
|
|||||||
|
# Golang-Router
|
||||||
|
## 关于
|
||||||
|
【正在开发迭代中】
|
||||||
|
Golang-Router 是一个面向大语言模型推理系统的高性能 Golang 路由框架,作为系统的**控制与调度平面**运行,负责请求接入、实例选择与流量转发,设计上适配 Prefill–Decode(PD)分离推理架构。
|
||||||
|
|
||||||
|
Golang-Router 可独立部署运行,也可通过 HTTP 接口与 FastDeploy 推理实例协同工作。框架提供基础而稳定的路由、中间件扩展与健康检查能力,适用于单点推理部署场景,并在架构层面为后续的水平扩展与调度能力演进预留空间。
|
||||||
|
|
||||||
|
### 背景与动机
|
||||||
|
在大语言模型推理系统中,路由组件已从传统的流量转发层演进为影响系统性能与资源利用效率的关键基础设施。随着 Prefill–Decode 分离推理架构的广泛采用,不同推理阶段在计算特征、显存占用与缓存行为方面呈现出明显差异,仅依赖请求级静态信息进行调度已难以满足稳定性与效率需求。
|
||||||
|
|
||||||
|
在保持请求级调度模型不变的前提下,引入更细粒度的运行时信号辅助调度决策,成为提升调度能力与系统可预测性的工程共识。Golang-Router 正是在这一背景下构建,作为独立的路由与调度组件,为推理系统提供清晰、可扩展的控制平面。
|
||||||
|
|
||||||
|
### 设计目标
|
||||||
|
Golang-Router 聚焦解决以下核心问题:
|
||||||
|
- **调度决策信息不足**
|
||||||
|
传统 Router 通常仅基于请求级元信息或粗粒度实例状态进行调度,难以利用推理过程中产生的细粒度缓存相关信号,从而限制了 cache-aware 策略的实际效果。
|
||||||
|
- **调度逻辑与推理执行强耦合**
|
||||||
|
路由与调度逻辑内嵌于推理框架内部,增加了系统复杂度,限制了调度策略的独立演进与复用能力。
|
||||||
|
- **高并发场景下的可扩展性挑战**
|
||||||
|
在高并发推理负载下,实例状态维护与实例选择逻辑对路由组件的并发模型、性能与稳定性提出更高要求。
|
||||||
|
|
||||||
|
### 核心特性
|
||||||
|
- 基于 Golang 实现的高性能路由与调度组件,适用于高并发、低延迟推理场景
|
||||||
|
- 请求级调度模型,保持接口语义清晰与系统复杂度可控
|
||||||
|
- 利用token级缓存相关运行时信息作为调度策略的辅助输入,用于提升实例选择的准确性与稳定性
|
||||||
|
- 模块化架构设计(Gateway / Scheduler / Manager),职责边界清晰,便于扩展与维护
|
||||||
|
- 面向 Prefill–Decode 分离推理架构设计,为复杂调度策略与能力演进提供结构性支持
|
||||||
|
|
||||||
|
### 与现有方案的差异
|
||||||
|
与 sglang 等推理框架内置 Router 相比,Golang-Router 以**独立 Golang 服务**的形式运行,将路由、调度与实例状态管理能力从推理执行逻辑中解耦。
|
||||||
|
|
||||||
|
Golang-Router 已支持 cache-aware 调度,在请求级调度框架内引入 token 级缓存相关运行时信号,辅助调度决策制定,以更稳定地适配 Prefill–Decode 分离推理架构下的缓存利用需求。
|
||||||
|
|
||||||
|
## 功能特性
|
||||||
|
|
||||||
|
- 高性能 HTTP/HTTPS 服务器
|
||||||
|
- RESTful API 路由支持
|
||||||
|
- 可扩展的中间件系统
|
||||||
|
- 动态配置管理
|
||||||
|
- 内置健康检查和监控
|
||||||
|
- 负载均衡
|
||||||
|
- 日志记录和指标收集
|
||||||
|
|
||||||
|
## 快速开始
|
||||||
|
|
||||||
|
### 前置要求
|
||||||
|
|
||||||
|
- Go 1.21
|
||||||
|
- 构建不依赖特定系统环境
|
||||||
|
- 可直接在 FastDeploy 官方 Docker 环境中编译与运行
|
||||||
|
|
||||||
|
### 编译
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./build.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
### 配置
|
||||||
|
|
||||||
|
1. 配置文件准备(可选)
|
||||||
|
如需修改默认配置,可复制配置模板并进行调整(示例可参考 examples/run_with_config):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cp config/config.example.yaml config/config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
2. 主要配置项说明:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
server:
|
||||||
|
port: "8080" # 监听端口
|
||||||
|
host: "0.0.0.0" # 监听地址
|
||||||
|
mode: "debug" # 启动模式: debug, release, test
|
||||||
|
splitwise: true # true代表开启pd分离模式,false代表开启非pd分离模式
|
||||||
|
|
||||||
|
scheduler:
|
||||||
|
policy: "power_of_two" # 调度策略(可选): random, power_of_two, round_robin, process_tokens, request_num
|
||||||
|
prefill-policy: "cache_aware" # pd分离模式下prefill节点调度策略
|
||||||
|
decode-policy: "fd_metrics_score" # pd分离模式下decode节点调度策略
|
||||||
|
eviction-interval-secs: 60 # cache-aware策略清理过期cache的间隔时间
|
||||||
|
balance-abs-threshold: 1 # cache-aware策略绝对阈值
|
||||||
|
balance-rel-threshold: 0.2 # cache-aware策略相对阈值
|
||||||
|
hit-ratio-weight: 1.0 # cache-aware策略命中率权重
|
||||||
|
load-balance-weight: 0.05 # cache-aware策略负载均衡权重
|
||||||
|
cache-block-size: 4 # cache-aware策略cache block大小
|
||||||
|
tokenizer-url: "http://0.0.0.0:8098" # tokenizer服务地址(可选)
|
||||||
|
tokenizer-timeout-secs: 2 # tokenizer服务超时时间
|
||||||
|
waiting-weight: 10 # cache-aware策略等待权重
|
||||||
|
|
||||||
|
manager:
|
||||||
|
health-failure-threshold: 3 # 健康检查失败次数,超过次数后认为节点不健康
|
||||||
|
health-success-threshold: 2 # 健康检查成功次数,超过次数后认为节点健康
|
||||||
|
health-check-timeout-secs: 5 # 健康检查超时时间
|
||||||
|
health-check-interval-secs: 5 # 健康检查间隔时间
|
||||||
|
health-check-endpoint: /health # 健康检查接口
|
||||||
|
register-path: "config/register.yaml" # 推理实例注册配置文件路径(可选)
|
||||||
|
|
||||||
|
log:
|
||||||
|
level: "info" # 日志打印级别: debug / info / warn / error
|
||||||
|
output: "file" # 日志输出方式: stdout / file
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
3. 启动时注册实例(可选)
|
||||||
|
支持通过配置文件在启动阶段注册推理实例(示例可参考 examples/run_with_default_workers):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cp config/config.example.yaml config/config.yaml
|
||||||
|
cp config/register.example.yaml config/register.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
### 运行
|
||||||
|
本项目支持两种运行方式:直接运行源码 或 构建二进制文件后运行。
|
||||||
|
方式一:直接运行源码
|
||||||
|
在项目根目录下,使用 go run 启动服务:
|
||||||
|
```bash
|
||||||
|
go run cmd/main.go
|
||||||
|
```
|
||||||
|
该方式适用于本地开发与调试场景。
|
||||||
|
方式二:构建并运行二进制文件
|
||||||
|
1. 构建二进制文件
|
||||||
|
通过构建脚本生成可执行文件:
|
||||||
|
```bash
|
||||||
|
./build.sh
|
||||||
|
```
|
||||||
|
构建完成后,二进制文件将被安装到指定目录(默认为 /usr/local/bin,可通过修改 Makefile 中的 OUTDIR 进行调整)。
|
||||||
|
此外,也可以在项目根目录下手动构建二进制文件:
|
||||||
|
```bash
|
||||||
|
go build -o ./fd-router ./cmd
|
||||||
|
```
|
||||||
|
该方式便于本地测试或将二进制文件与配置文件一并分发。
|
||||||
|
2. 运行二进制文件
|
||||||
|
可以通过运行脚本启动服务:
|
||||||
|
```bash
|
||||||
|
./run.sh
|
||||||
|
```
|
||||||
|
运行脚本会自动处理常见启动参数及日志目录,适合标准化部署场景。
|
||||||
|
也可以直接运行二进制文件,在项目根目录或二进制所在目录下执行:
|
||||||
|
```bash
|
||||||
|
./fd-router \
|
||||||
|
--port 8080 \
|
||||||
|
--splitwise \
|
||||||
|
--config_path ./config/config.yaml
|
||||||
|
```
|
||||||
|
其中:
|
||||||
|
- --port 为必填参数
|
||||||
|
- 其他参数可根据实际需求配置
|
||||||
|
|
||||||
|
## 项目结构
|
||||||
|
|
||||||
|
```
|
||||||
|
.
|
||||||
|
├── cmd/ # 主程序入口
|
||||||
|
├── config/ # 配置文件
|
||||||
|
├── internal/ # 核心实现代码
|
||||||
|
│ ├── common/ # 公共接口定义
|
||||||
|
│ ├── config/ # 配置处理
|
||||||
|
│ ├── gateway/ # API网关实现
|
||||||
|
│ ├── manager/ # 路由管理
|
||||||
|
│ ├── middleware/ # 中间件实现
|
||||||
|
│ ├── router/ # 路由核心逻辑
|
||||||
|
│ └── scheduler/ # 调度器实现
|
||||||
|
├── logs/ # 日志目录
|
||||||
|
├── output/ # 构建输出
|
||||||
|
├── pkg/ # 可复用组件
|
||||||
|
│ ├── logger/ # 日志组件
|
||||||
|
│ └── metrics/ # 监控指标
|
||||||
|
├── build.sh # 构建脚本
|
||||||
|
├── go.mod # Go模块定义
|
||||||
|
├── go.sum # 依赖校验
|
||||||
|
├── Makefile # 构建管理
|
||||||
|
├── README.md # 项目说明
|
||||||
|
└── run.sh # 启动脚本
|
||||||
|
```
|
||||||
|
|
||||||
|
### 运行测试
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make test
|
||||||
|
```
|
||||||
|
|
||||||
|
## 贡献
|
||||||
|
|
||||||
|
欢迎提交 Issue 和 Pull Request!
|
||||||
Executable
+4
@@ -0,0 +1,4 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
make all
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"flag"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/internal/config"
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/internal/manager"
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/internal/router"
|
||||||
|
scheduler_handler "github.com/PaddlePaddle/FastDeploy/router/internal/scheduler/handler"
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Parse command line arguments
|
||||||
|
var configPath, port string
|
||||||
|
var splitwise bool
|
||||||
|
flag.StringVar(&configPath, "config_path", "", "path to config file")
|
||||||
|
flag.StringVar(&port, "port", "", "listen port of router")
|
||||||
|
flag.BoolVar(&splitwise, "splitwise", false, "enable splitwise mode")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
// Load configuration
|
||||||
|
cfg, err := config.Load(configPath, port, splitwise)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize logger
|
||||||
|
logger.Init(cfg.Log.Level, cfg.Log.Output)
|
||||||
|
defer logger.CloseLogFile()
|
||||||
|
|
||||||
|
// Initialize manager
|
||||||
|
manager.Init(cfg)
|
||||||
|
scheduler_handler.Init(cfg, manager.DefaultManager)
|
||||||
|
registerYamlPath := cfg.Manager.RegisterPath
|
||||||
|
manager.RegisterInstancesFromConfig(registerYamlPath)
|
||||||
|
|
||||||
|
// Initialize router
|
||||||
|
r := router.New(cfg)
|
||||||
|
|
||||||
|
intervalSecs := cfg.Manager.HealthCheckIntervalSecs
|
||||||
|
go manager.MonitorInstanceHealth(context.Background(), intervalSecs)
|
||||||
|
intervalCleanupSecs := cfg.Scheduler.EvictionIntervalSecs
|
||||||
|
go scheduler_handler.StartBackupCleanupTask(context.Background(), intervalCleanupSecs)
|
||||||
|
|
||||||
|
// Start server
|
||||||
|
addr := ":" + cfg.Server.Port
|
||||||
|
logger.Info("Starting server on %s", addr)
|
||||||
|
if err := r.Run(addr); err != nil {
|
||||||
|
log.Fatalf("Failed to start server: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
server:
|
||||||
|
port: "8080"
|
||||||
|
host: "0.0.0.0"
|
||||||
|
mode: "debug" # debug, release, test
|
||||||
|
splitwise: true # true means pd mode, false means mixed mode
|
||||||
|
|
||||||
|
scheduler:
|
||||||
|
policy: "power_of_two"
|
||||||
|
prefill-policy: "cache_aware"
|
||||||
|
decode-policy: "fd_metrics_score"
|
||||||
|
eviction-interval-secs: 60
|
||||||
|
balance-abs-threshold: 1
|
||||||
|
balance-rel-threshold: 0.2
|
||||||
|
hit-ratio-weight: 1.0
|
||||||
|
load-balance-weight: 0.05
|
||||||
|
cache-block-size: 4
|
||||||
|
tokenizer-url: "http://0.0.0.0:8098" # optional tokenizer service endpoint
|
||||||
|
tokenizer-timeout-secs: 2
|
||||||
|
waiting-weight: 10
|
||||||
|
|
||||||
|
manager:
|
||||||
|
health-failure-threshold: 3
|
||||||
|
health-success-threshold: 2
|
||||||
|
health-check-timeout-secs: 5
|
||||||
|
health-check-interval-secs: 5
|
||||||
|
health-check-endpoint: /health
|
||||||
|
register-path: "config/register.yaml"
|
||||||
|
|
||||||
|
log:
|
||||||
|
level: "info" # debug, info, warn, error
|
||||||
|
output: "file" # stdout, file
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
PID=$(ps -ef | grep "fd-router" | grep -v grep | awk '{print $2}')
|
||||||
|
if [ -n "$PID" ]; then
|
||||||
|
echo "Killing existing fd-router process (PID: $PID)"
|
||||||
|
# Try graceful shutdown first
|
||||||
|
kill -15 $PID
|
||||||
|
TIMEOUT=10
|
||||||
|
while kill -0 $PID 2>/dev/null && [ $TIMEOUT -gt 0 ]; do
|
||||||
|
echo "Waiting for fd-router (PID: $PID) to exit gracefully... ($TIMEOUT seconds remaining)"
|
||||||
|
sleep 1
|
||||||
|
TIMEOUT=$((TIMEOUT - 1))
|
||||||
|
done
|
||||||
|
# Force kill if still running after timeout
|
||||||
|
if kill -0 $PID 2>/dev/null; then
|
||||||
|
echo "fd-router (PID: $PID) did not exit gracefully; sending SIGKILL..."
|
||||||
|
kill -9 $PID
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Starting new fd-router process..."
|
||||||
|
nohup ./fd-router --config_path ./config/config.yaml --splitwise > fd-router.log 2>&1 &
|
||||||
|
echo "fd-router started with PID: $!"
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
server:
|
||||||
|
port: "8080"
|
||||||
|
host: "0.0.0.0"
|
||||||
|
mode: "debug" # debug, release, test
|
||||||
|
splitwise: true # true means pd mode, false means mixed mode
|
||||||
|
|
||||||
|
scheduler:
|
||||||
|
policy: "power_of_two"
|
||||||
|
prefill-policy: "cache_aware"
|
||||||
|
decode-policy: "fd_metrics_score"
|
||||||
|
eviction-interval-secs: 60
|
||||||
|
balance-abs-threshold: 1
|
||||||
|
balance-rel-threshold: 0.2
|
||||||
|
hit-ratio-weight: 1.0
|
||||||
|
load-balance-weight: 0.05
|
||||||
|
cache-block-size: 4
|
||||||
|
tokenizer-url: "http://0.0.0.0:8098" # optional tokenizer service endpoint
|
||||||
|
tokenizer-timeout-secs: 2
|
||||||
|
waiting-weight: 10
|
||||||
|
|
||||||
|
manager:
|
||||||
|
health-failure-threshold: 3
|
||||||
|
health-success-threshold: 2
|
||||||
|
health-check-timeout-secs: 5
|
||||||
|
health-check-interval-secs: 5
|
||||||
|
health-check-endpoint: /health
|
||||||
|
register-path: "config/register.yaml"
|
||||||
|
|
||||||
|
log:
|
||||||
|
level: "info" # debug, info, warn, error
|
||||||
|
output: "file" # stdout, file
|
||||||
+21
@@ -0,0 +1,21 @@
|
|||||||
|
instances:
|
||||||
|
- role: "prefill"
|
||||||
|
host_ip: 127.0.0.1
|
||||||
|
port: 8097
|
||||||
|
connector_port: 8001
|
||||||
|
engine_worker_queue_port: 8002
|
||||||
|
transfer_protocol:
|
||||||
|
- ipc
|
||||||
|
- rdma
|
||||||
|
rdma_ports: [7100, "7101"]
|
||||||
|
device_ids: [0, "1"]
|
||||||
|
- role: "decode"
|
||||||
|
host_ip: 127.0.0.1
|
||||||
|
port: 8098
|
||||||
|
connector_port: 8001
|
||||||
|
engine_worker_queue_port: 8002
|
||||||
|
transfer_protocol:
|
||||||
|
- ipc
|
||||||
|
- rdma
|
||||||
|
rdma_ports: ["7100", "7101"]
|
||||||
|
device_ids: ["0", "1"]
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
PID=$(ps -ef | grep "fd-router" | grep -v grep | awk '{print $2}')
|
||||||
|
if [ -n "$PID" ]; then
|
||||||
|
echo "Killing existing fd-router process (PID: $PID)"
|
||||||
|
# First try to terminate gracefully
|
||||||
|
kill -15 "$PID"
|
||||||
|
sleep 5
|
||||||
|
# If still running after timeout, force kill as a last resort
|
||||||
|
if ps -p "$PID" > /dev/null 2>&1; then
|
||||||
|
echo "Process $PID did not terminate gracefully; force killing..."
|
||||||
|
kill -9 "$PID"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Starting new fd-router process..."
|
||||||
|
nohup ./fd-router --config_path ./config/config.yaml --splitwise > fd-router.log 2>&1 &
|
||||||
|
echo "fd-router started with PID: $!"
|
||||||
@@ -0,0 +1,47 @@
|
|||||||
|
module github.com/PaddlePaddle/FastDeploy/router
|
||||||
|
|
||||||
|
go 1.21
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/gin-gonic/gin v1.9.1
|
||||||
|
github.com/stretchr/testify v1.11.1
|
||||||
|
github.com/valyala/bytebufferpool v1.0.0
|
||||||
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/beorn7/perks v1.0.1 // indirect
|
||||||
|
github.com/bytedance/sonic v1.9.1 // indirect
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||||
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||||
|
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||||
|
github.com/go-playground/locales v0.14.1 // indirect
|
||||||
|
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||||
|
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
||||||
|
github.com/goccy/go-json v0.10.2 // indirect
|
||||||
|
github.com/google/go-cmp v0.6.0 // indirect
|
||||||
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
|
github.com/klauspost/compress v1.17.11 // indirect
|
||||||
|
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||||
|
github.com/leodido/go-urn v1.2.4 // indirect
|
||||||
|
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||||
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||||
|
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
github.com/prometheus/client_golang v1.21.1 // indirect
|
||||||
|
github.com/prometheus/client_model v0.6.1 // indirect
|
||||||
|
github.com/prometheus/common v0.62.0 // indirect
|
||||||
|
github.com/prometheus/procfs v0.15.1 // indirect
|
||||||
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
|
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||||
|
golang.org/x/arch v0.3.0 // indirect
|
||||||
|
golang.org/x/crypto v0.32.0 // indirect
|
||||||
|
golang.org/x/net v0.34.0 // indirect
|
||||||
|
golang.org/x/sys v0.29.0 // indirect
|
||||||
|
golang.org/x/text v0.21.0 // indirect
|
||||||
|
google.golang.org/protobuf v1.36.1 // indirect
|
||||||
|
)
|
||||||
@@ -0,0 +1,112 @@
|
|||||||
|
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||||
|
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||||
|
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||||
|
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
||||||
|
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
|
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
||||||
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
||||||
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
||||||
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
|
||||||
|
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
|
||||||
|
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||||
|
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||||
|
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
|
||||||
|
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
|
||||||
|
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||||
|
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||||
|
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||||
|
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||||
|
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||||
|
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||||
|
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
|
||||||
|
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
|
||||||
|
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||||
|
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||||
|
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||||
|
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
|
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||||
|
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||||
|
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||||
|
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||||
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
|
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||||
|
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
|
||||||
|
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
|
||||||
|
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||||
|
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
|
||||||
|
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
|
||||||
|
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
|
||||||
|
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
|
||||||
|
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||||
|
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
|
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||||
|
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||||
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||||
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||||
|
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
|
||||||
|
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/prometheus/client_golang v1.21.1 h1:DOvXXTqVzvkIewV/CDPFdejpMCGeMcbGCQ8YOmu+Ibk=
|
||||||
|
github.com/prometheus/client_golang v1.21.1/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg=
|
||||||
|
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
|
||||||
|
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
|
||||||
|
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
|
||||||
|
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
|
||||||
|
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||||
|
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||||
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||||
|
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
|
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
|
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
|
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||||
|
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
|
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||||
|
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||||
|
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
|
||||||
|
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||||
|
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||||
|
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||||
|
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||||
|
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
||||||
|
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||||
|
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
|
||||||
|
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
|
||||||
|
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
|
||||||
|
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
|
||||||
|
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
|
||||||
|
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
|
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||||
|
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||||
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||||
|
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
|
||||||
|
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||||
|
google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk=
|
||||||
|
google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
type ManagerAPI interface {
|
||||||
|
GetHealthyURLs(ctx context.Context) []string
|
||||||
|
GetMetrics(ctx context.Context, url string) (int, int, int)
|
||||||
|
}
|
||||||
@@ -0,0 +1,122 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Server ServerConfig `yaml:"server"`
|
||||||
|
Log LogConfig `yaml:"log"`
|
||||||
|
Manager ManagerConfig `yaml:"manager"`
|
||||||
|
Scheduler SchedulerConfig `yaml:"scheduler"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerConfig struct {
|
||||||
|
Name string `yaml:"name"`
|
||||||
|
Port string `yaml:"port"`
|
||||||
|
Host string `yaml:"host"`
|
||||||
|
Mode string `yaml:"mode"` // debug, release, test
|
||||||
|
Splitwise bool `yaml:"splitwise"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ManagerConfig struct {
|
||||||
|
RegisterPath string `yaml:"register-path"`
|
||||||
|
HealthFailureThreshold int `yaml:"health-failure-threshold"`
|
||||||
|
HealthSuccessThreshold int `yaml:"health-success-threshold"`
|
||||||
|
HealthCheckTimeoutSecs float64 `yaml:"health-check-timeout-secs"`
|
||||||
|
HealthCheckIntervalSecs float64 `yaml:"health-check-interval-secs"`
|
||||||
|
HealthCheckEndpoint string `yaml:"health-check-endpoint"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SchedulerConfig struct {
|
||||||
|
Policy string `yaml:"policy"`
|
||||||
|
PrefillPolicy string `yaml:"prefill-policy"`
|
||||||
|
DecodePolicy string `yaml:"decode-policy"`
|
||||||
|
EvictionIntervalSecs float64 `yaml:"eviction-interval-secs"`
|
||||||
|
CacheBlockSize int `yaml:"cache-block-size"`
|
||||||
|
TokenizerURL string `yaml:"tokenizer-url"`
|
||||||
|
TokenizerTimeoutSecs float64 `yaml:"tokenizer-timeout-secs"`
|
||||||
|
BalanceAbsThreshold float64 `yaml:"balance-abs-threshold"`
|
||||||
|
BalanceRelThreshold float64 `yaml:"balance-rel-threshold"`
|
||||||
|
HitRatioWeight float64 `yaml:"hit-ratio-weight"`
|
||||||
|
LoadBalanceWeight float64 `yaml:"load-balance-weight"`
|
||||||
|
WaitingWeight float64 `yaml:"waiting-weight"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type LogConfig struct {
|
||||||
|
Level string `yaml:"level"` // debug, info, warn, error
|
||||||
|
Output string `yaml:"output"` // stdout, file
|
||||||
|
}
|
||||||
|
|
||||||
|
func Load(configPath, listenPort string, isSplitwise bool) (*Config, error) {
|
||||||
|
var cfg Config
|
||||||
|
if configPath != "" {
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read config file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse config: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set default values
|
||||||
|
if listenPort != "" {
|
||||||
|
cfg.Server.Port = listenPort
|
||||||
|
} else if cfg.Server.Port == "" {
|
||||||
|
return nil, fmt.Errorf("failed to set router listen port")
|
||||||
|
}
|
||||||
|
if isSplitwise {
|
||||||
|
cfg.Server.Splitwise = true
|
||||||
|
}
|
||||||
|
if cfg.Server.Mode == "" {
|
||||||
|
cfg.Server.Mode = "release"
|
||||||
|
}
|
||||||
|
if cfg.Log.Level == "" {
|
||||||
|
cfg.Log.Level = "info"
|
||||||
|
}
|
||||||
|
if cfg.Manager.HealthCheckEndpoint == "" {
|
||||||
|
cfg.Manager.HealthCheckEndpoint = "/health"
|
||||||
|
}
|
||||||
|
if cfg.Manager.HealthCheckTimeoutSecs == 0 {
|
||||||
|
cfg.Manager.HealthCheckTimeoutSecs = 5
|
||||||
|
}
|
||||||
|
if cfg.Manager.HealthCheckIntervalSecs == 0 {
|
||||||
|
cfg.Manager.HealthCheckIntervalSecs = 5
|
||||||
|
}
|
||||||
|
if cfg.Manager.HealthFailureThreshold == 0 {
|
||||||
|
cfg.Manager.HealthFailureThreshold = 1
|
||||||
|
}
|
||||||
|
if cfg.Manager.HealthSuccessThreshold == 0 {
|
||||||
|
cfg.Manager.HealthSuccessThreshold = 1
|
||||||
|
}
|
||||||
|
if cfg.Scheduler.EvictionIntervalSecs == 0 {
|
||||||
|
cfg.Scheduler.EvictionIntervalSecs = 60
|
||||||
|
}
|
||||||
|
if cfg.Scheduler.CacheBlockSize == 0 {
|
||||||
|
cfg.Scheduler.CacheBlockSize = 64
|
||||||
|
}
|
||||||
|
if cfg.Scheduler.TokenizerTimeoutSecs == 0 {
|
||||||
|
cfg.Scheduler.TokenizerTimeoutSecs = 2
|
||||||
|
}
|
||||||
|
if cfg.Scheduler.HitRatioWeight == 0 {
|
||||||
|
cfg.Scheduler.HitRatioWeight = 1
|
||||||
|
}
|
||||||
|
if cfg.Scheduler.LoadBalanceWeight == 0 {
|
||||||
|
cfg.Scheduler.LoadBalanceWeight = 1
|
||||||
|
}
|
||||||
|
if cfg.Scheduler.BalanceAbsThreshold == 0 {
|
||||||
|
cfg.Scheduler.BalanceAbsThreshold = 1
|
||||||
|
}
|
||||||
|
if cfg.Scheduler.BalanceRelThreshold == 0 {
|
||||||
|
cfg.Scheduler.BalanceRelThreshold = 0.2
|
||||||
|
}
|
||||||
|
if cfg.Scheduler.WaitingWeight == 0 {
|
||||||
|
cfg.Scheduler.WaitingWeight = 1
|
||||||
|
}
|
||||||
|
return &cfg, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,372 @@
|
|||||||
|
package gateway
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
crand "crypto/rand"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math/rand"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/internal/manager"
|
||||||
|
scheduler_handler "github.com/PaddlePaddle/FastDeploy/router/internal/scheduler/handler"
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/pkg/logger"
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/pkg/metrics"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/valyala/bytebufferpool"
|
||||||
|
)
|
||||||
|
|
||||||
|
const maxCapacity = 10 * 1024 * 1024 // 10MB
|
||||||
|
|
||||||
|
// newRequestID generates UUIDv4 style request_id
|
||||||
|
func newRequestID() string {
|
||||||
|
b := make([]byte, 16)
|
||||||
|
if _, err := crand.Read(b); err == nil {
|
||||||
|
// Set version and variant bits, compliant with RFC 4122
|
||||||
|
b[6] = (b[6] & 0x0f) | 0x40
|
||||||
|
b[8] = (b[8] & 0x3f) | 0x80
|
||||||
|
return fmt.Sprintf("%x-%x-%x-%x-%x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:16])
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%d-%d", time.Now().UnixNano(), rand.Int63())
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractPromptFromChatRequest extracts text prompt from OpenAI ChatCompletions style request
|
||||||
|
func extractPromptFromChatRequest(rawReq map[string]any) string {
|
||||||
|
messagesVal, ok := rawReq["messages"]
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
messages, ok := messagesVal.([]any)
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var builder strings.Builder
|
||||||
|
|
||||||
|
appendText := func(s string) {
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
if s == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if builder.Len() > 0 {
|
||||||
|
builder.WriteByte(' ')
|
||||||
|
}
|
||||||
|
builder.WriteString(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, msg := range messages {
|
||||||
|
msgMap, ok := msg.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
content, ok := msgMap["content"]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v := content.(type) {
|
||||||
|
case string:
|
||||||
|
appendText(v)
|
||||||
|
case []any:
|
||||||
|
for _, item := range v {
|
||||||
|
itemMap, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
itemType, _ := itemMap["type"].(string)
|
||||||
|
if itemType != "text" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if textVal, ok := itemMap["text"].(string); ok {
|
||||||
|
appendText(textVal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// Other structures are ignored for now
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return builder.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// PostToPD sends requests to both Prefill and Decode instances, only returns Decode node response
|
||||||
|
func PostToPD(c *gin.Context, decodeURL, prefillURL string, reqBody []byte) (*http.Response, error) {
|
||||||
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
|
decodeEndpoint := fmt.Sprintf("%s/v1/%s", decodeURL, "chat/completions")
|
||||||
|
prefillEndpoint := fmt.Sprintf("%s/v1/%s", prefillURL, "chat/completions")
|
||||||
|
|
||||||
|
// Construct two requests
|
||||||
|
decodeReq, err := http.NewRequestWithContext(ctx, "POST", decodeEndpoint, bytes.NewReader(reqBody))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
prefillReq, err := http.NewRequestWithContext(ctx, "POST", prefillEndpoint, bytes.NewReader(reqBody))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy request headers
|
||||||
|
for k, v := range c.Request.Header {
|
||||||
|
if k != "Content-Length" {
|
||||||
|
decodeReq.Header[k] = v
|
||||||
|
prefillReq.Header[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
|
||||||
|
type respResult struct {
|
||||||
|
resp *http.Response
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
prefillCh := make(chan respResult, 1)
|
||||||
|
decodeCh := make(chan respResult, 1)
|
||||||
|
|
||||||
|
// Concurrently send requests to P/D
|
||||||
|
go func() {
|
||||||
|
resp, err := client.Do(prefillReq)
|
||||||
|
prefillCh <- respResult{resp: resp, err: err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
resp, err := client.Do(decodeReq)
|
||||||
|
decodeCh <- respResult{resp: resp, err: err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
prefillRes := <-prefillCh
|
||||||
|
decodeRes := <-decodeCh
|
||||||
|
|
||||||
|
// Prioritize returning Decode errors
|
||||||
|
if decodeRes.err != nil {
|
||||||
|
if prefillRes.resp != nil {
|
||||||
|
prefillRes.resp.Body.Close()
|
||||||
|
}
|
||||||
|
return nil, decodeRes.err
|
||||||
|
}
|
||||||
|
if prefillRes.err != nil {
|
||||||
|
// Prefill errors are also considered failures to avoid inconsistent behavior
|
||||||
|
if decodeRes.resp != nil {
|
||||||
|
decodeRes.resp.Body.Close()
|
||||||
|
}
|
||||||
|
return nil, prefillRes.err
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefillRes.resp != nil {
|
||||||
|
prefillRes.resp.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
return decodeRes.resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatCompletions implements request forwarding to actual large model inference service
|
||||||
|
func ChatCompletions(c *gin.Context) {
|
||||||
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
|
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
c.Writer.WriteHeader(http.StatusBadRequest)
|
||||||
|
c.Writer.Write([]byte(`{"error": "Invalid request body"}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var rawReq map[string]any
|
||||||
|
if err := json.Unmarshal(bodyBytes, &rawReq); err != nil {
|
||||||
|
c.Writer.WriteHeader(http.StatusBadRequest)
|
||||||
|
c.Writer.Write([]byte(`{"error": "Invalid JSON format"}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
isSplitwise := manager.GetSplitwise(ctx)
|
||||||
|
|
||||||
|
var (
|
||||||
|
destURL string
|
||||||
|
releaseTargets []string
|
||||||
|
requestBodyData []byte
|
||||||
|
)
|
||||||
|
|
||||||
|
if isSplitwise {
|
||||||
|
// PD mode: select instances for Prefill/Decode separately
|
||||||
|
message := extractPromptFromChatRequest(rawReq)
|
||||||
|
|
||||||
|
prefillURL, decodeURL, err := manager.SelectWorkerPair(ctx, message)
|
||||||
|
if err != nil {
|
||||||
|
c.Writer.WriteHeader(http.StatusBadGateway)
|
||||||
|
c.Writer.Write([]byte(`{"error": "Failed to select worker pair"}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if prefillURL == "" || decodeURL == "" {
|
||||||
|
c.Writer.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
c.Writer.Write([]byte(`{"error": "No available prefill/decode workers"}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct disaggregate_info to ensure selected P/D work in pairs within FastDeploy
|
||||||
|
disagg, err := manager.BuildDisaggregateInfo(ctx, prefillURL, decodeURL)
|
||||||
|
if err != nil {
|
||||||
|
c.Writer.WriteHeader(http.StatusInternalServerError)
|
||||||
|
c.Writer.Write([]byte(`{"error": "Failed to build disaggregate_info"}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rawReq["disaggregate_info"] = disagg
|
||||||
|
|
||||||
|
// If user didn't provide request_id, generate one
|
||||||
|
if _, ok := rawReq["request_id"]; !ok {
|
||||||
|
rawReq["request_id"] = newRequestID()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-encode request body and send to P and D
|
||||||
|
requestBodyData, err = json.Marshal(rawReq)
|
||||||
|
if err != nil {
|
||||||
|
c.Writer.WriteHeader(http.StatusInternalServerError)
|
||||||
|
c.Writer.Write([]byte(`{"error": "Failed to encode modified request"}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
destURL = decodeURL
|
||||||
|
releaseTargets = []string{prefillURL, decodeURL}
|
||||||
|
|
||||||
|
// Expose scheduling results to caller for debugging/validating scheduling strategy
|
||||||
|
c.Writer.Header().Set("X-Router-Prefill-URL", prefillURL)
|
||||||
|
c.Writer.Header().Set("X-Router-Decode-URL", decodeURL)
|
||||||
|
|
||||||
|
// Prefill node token count was added in SelectWorker, release when request ends
|
||||||
|
defer scheduler_handler.ReleasePrefillTokens(ctx, prefillURL, message)
|
||||||
|
} else {
|
||||||
|
// Non-PD mode: use Mixed instance
|
||||||
|
dest, err := manager.SelectWorker(ctx, "")
|
||||||
|
if err != nil {
|
||||||
|
c.Writer.WriteHeader(http.StatusBadGateway)
|
||||||
|
c.Writer.Write([]byte(`{"error": "Failed to select worker"}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
destURL = dest
|
||||||
|
releaseTargets = []string{destURL}
|
||||||
|
requestBodyData = bodyBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
// Maintain request_num count for related instances (Inc done in SelectWorker, Release here)
|
||||||
|
defer func() {
|
||||||
|
for _, url := range releaseTargets {
|
||||||
|
scheduler_handler.Release(ctx, url)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Send request
|
||||||
|
var backendResp *http.Response
|
||||||
|
if isSplitwise {
|
||||||
|
backendResp, err = PostToPD(c, destURL, releaseTargets[0], requestBodyData)
|
||||||
|
} else {
|
||||||
|
backendResp, err = GetClientWithRetry(c, requestBodyData, destURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
c.Writer.WriteHeader(http.StatusBadGateway)
|
||||||
|
c.Writer.Write([]byte(`{"error": "Failed to connect to backend service"}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer backendResp.Body.Close()
|
||||||
|
|
||||||
|
if isSplitwise {
|
||||||
|
metrics.InferenceRequests.WithLabelValues("", releaseTargets[0], destURL, strconv.Itoa(backendResp.StatusCode)).Inc()
|
||||||
|
} else {
|
||||||
|
metrics.InferenceRequests.WithLabelValues(destURL, "", "", strconv.Itoa(backendResp.StatusCode)).Inc()
|
||||||
|
}
|
||||||
|
// Copy response headers
|
||||||
|
for k, v := range backendResp.Header {
|
||||||
|
if k != "Content-Length" { // Remove Content-Length header
|
||||||
|
c.Writer.Header()[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//c.Writer.Header().Set("Transfer-Encoding", "chunked") // Set chunked transfer
|
||||||
|
if backendResp.StatusCode == http.StatusOK {
|
||||||
|
c.Writer.WriteHeader(backendResp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
isStream := false
|
||||||
|
if v, ok := rawReq["stream"]; ok {
|
||||||
|
stream, ok := v.(bool)
|
||||||
|
if ok && stream {
|
||||||
|
isStream = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
redirect(c, isStream, backendResp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func redirect(c *gin.Context, isStream bool, backendResp *http.Response) {
|
||||||
|
// Forward response body
|
||||||
|
if isStream {
|
||||||
|
// Stream response, use buffer pool to avoid frequent buffer creation/destruction
|
||||||
|
buffer := bytebufferpool.Get()
|
||||||
|
buffer.Reset()
|
||||||
|
defer bytebufferpool.Put(buffer)
|
||||||
|
scanner := bufio.NewScanner(backendResp.Body)
|
||||||
|
scanner.Buffer(buffer.B, maxCapacity) // Key: reset buffer
|
||||||
|
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
c.Writer.Write([]byte(line + "\n"))
|
||||||
|
c.Writer.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
logger.Error("scanner error: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Compatible with non-stream response
|
||||||
|
io.Copy(c.Writer, backendResp.Body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClientWithRetry adds retry
|
||||||
|
func GetClientWithRetry(c *gin.Context, bodyBytes []byte, destUrl string) (
|
||||||
|
backendResp *http.Response, err error) {
|
||||||
|
// Five retries
|
||||||
|
maxRetry := 3
|
||||||
|
for i := 0; i < maxRetry; i++ {
|
||||||
|
// If creating request fails, it's network connection error, check if selected node is elastic resource, if so, delete it
|
||||||
|
backendResp, err = GetClient(c, destUrl, "chat/completions", bodyBytes)
|
||||||
|
if err == nil { // Return latest bucketsize
|
||||||
|
return backendResp, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetClient(c *gin.Context, address, api string, reqBody []byte) (*http.Response, error) {
|
||||||
|
backendURL := fmt.Sprintf("%s/v1/%s", address, api)
|
||||||
|
|
||||||
|
backendReq, err := http.NewRequestWithContext(
|
||||||
|
c.Request.Context(),
|
||||||
|
"POST",
|
||||||
|
backendURL,
|
||||||
|
bytes.NewReader(reqBody),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Copy request headers
|
||||||
|
for k, v := range c.Request.Header {
|
||||||
|
if k != "Content-Length" { // Remove Content-Length header
|
||||||
|
backendReq.Header[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
backendResp, err := client.Do(backendReq)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return backendResp, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,146 @@
|
|||||||
|
package gateway
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestChatCompletions(t *testing.T) {
|
||||||
|
// Since the actual implementation uses package-level functions that depend on DefaultManager,
|
||||||
|
// and we don't want to set up a full manager for unit tests,
|
||||||
|
// this test will be marked as integration test and skipped for now
|
||||||
|
t.Skip("Integration test requiring manager setup")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractPromptFromChatRequest(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"simple message",
|
||||||
|
`{"messages": [{"role": "user", "content": "hello"}]}`,
|
||||||
|
"hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"multiple messages",
|
||||||
|
`{"messages": [
|
||||||
|
{"role": "user", "content": "hello"},
|
||||||
|
{"role": "assistant", "content": "hi"},
|
||||||
|
{"role": "user", "content": "how are you"}
|
||||||
|
]}`,
|
||||||
|
"hello hi how are you",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"empty messages",
|
||||||
|
`{"messages": []}`,
|
||||||
|
"",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var rawReq map[string]any
|
||||||
|
err := json.Unmarshal([]byte(tt.input), &rawReq)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
result := extractPromptFromChatRequest(rawReq)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostToPD(t *testing.T) {
|
||||||
|
// Setup test servers
|
||||||
|
prefillTS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer prefillTS.Close()
|
||||||
|
|
||||||
|
decodeTS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(`{"response": "test"}`))
|
||||||
|
}))
|
||||||
|
defer decodeTS.Close()
|
||||||
|
|
||||||
|
// Setup test context
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(`{"test": "data"}`))
|
||||||
|
|
||||||
|
resp, err := PostToPD(c, decodeTS.URL, prefillTS.URL, []byte(`{"test": "data"}`))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRedirect(t *testing.T) {
|
||||||
|
// Setup test server
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("test response"))
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
// Test stream response
|
||||||
|
t.Run("stream response", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||||
|
|
||||||
|
resp, err := http.Get(ts.URL)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
redirect(c, true, resp)
|
||||||
|
assert.Equal(t, "test response\n", w.Body.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test non-stream response
|
||||||
|
t.Run("non-stream response", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||||
|
|
||||||
|
resp, err := http.Get(ts.URL)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
redirect(c, false, resp)
|
||||||
|
assert.Equal(t, "test response", w.Body.String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetClient(t *testing.T) {
|
||||||
|
// Setup test server
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("test response"))
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
// Setup test context
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(`{"test": "data"}`))
|
||||||
|
|
||||||
|
resp, err := GetClient(c, ts.URL, "chat/completions", []byte(`{"test": "data"}`))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewRequestID(t *testing.T) {
|
||||||
|
id1 := newRequestID()
|
||||||
|
id2 := newRequestID()
|
||||||
|
|
||||||
|
// Check that IDs are not empty
|
||||||
|
assert.NotEmpty(t, id1)
|
||||||
|
assert.NotEmpty(t, id2)
|
||||||
|
|
||||||
|
// Check that IDs are different
|
||||||
|
assert.NotEqual(t, id1, id2)
|
||||||
|
}
|
||||||
@@ -0,0 +1,138 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sort"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/internal/config"
|
||||||
|
scheduler_handler "github.com/PaddlePaddle/FastDeploy/router/internal/scheduler/handler"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Manager struct {
|
||||||
|
mixedWorkerMap map[string]*WorkerInfo
|
||||||
|
prefillWorkerMap map[string]*WorkerInfo
|
||||||
|
decodeWorkerMap map[string]*WorkerInfo
|
||||||
|
splitwise bool
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
type WorkerInfo struct {
|
||||||
|
Url string `json:"url"`
|
||||||
|
WorkerType string `json:"worker_type"`
|
||||||
|
ConnectorPort string `json:"connector_port"`
|
||||||
|
EngineWorkerQueuePort string `json:"engine_worker_queue_port"`
|
||||||
|
TransferProtocol []string `json:"transfer_protocol"`
|
||||||
|
RdmaPorts []string `json:"rdma_ports"`
|
||||||
|
DeviceIDs []string `json:"device_ids"`
|
||||||
|
MetricsPort string `json:"metrics_port"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var DefaultManager *Manager
|
||||||
|
var defaultCheckTimeout time.Duration
|
||||||
|
var healthEndpoint string
|
||||||
|
var failureThreshold int
|
||||||
|
var successThreshold int
|
||||||
|
|
||||||
|
// Manager module initialization
|
||||||
|
func Init(cfg *config.Config) {
|
||||||
|
manager := &Manager{
|
||||||
|
mixedWorkerMap: make(map[string]*WorkerInfo),
|
||||||
|
prefillWorkerMap: make(map[string]*WorkerInfo),
|
||||||
|
decodeWorkerMap: make(map[string]*WorkerInfo),
|
||||||
|
splitwise: cfg.Server.Splitwise,
|
||||||
|
}
|
||||||
|
DefaultManager = manager
|
||||||
|
// Define a default timeout duration
|
||||||
|
defaultCheckTimeout = time.Duration(cfg.Manager.HealthCheckTimeoutSecs * float64(time.Second))
|
||||||
|
healthEndpoint = cfg.Manager.HealthCheckEndpoint
|
||||||
|
failureThreshold = cfg.Manager.HealthFailureThreshold
|
||||||
|
successThreshold = cfg.Manager.HealthSuccessThreshold
|
||||||
|
}
|
||||||
|
|
||||||
|
func WorkerMapToList(ctx context.Context, workerType string) []string {
|
||||||
|
DefaultManager.mu.RLock()
|
||||||
|
defer DefaultManager.mu.RUnlock()
|
||||||
|
|
||||||
|
var workerMap map[string]*WorkerInfo
|
||||||
|
switch workerType {
|
||||||
|
case "mixed":
|
||||||
|
workerMap = DefaultManager.mixedWorkerMap
|
||||||
|
case "prefill":
|
||||||
|
workerMap = DefaultManager.prefillWorkerMap
|
||||||
|
case "decode":
|
||||||
|
workerMap = DefaultManager.decodeWorkerMap
|
||||||
|
default:
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if workerMap == nil {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get all keys and sort them
|
||||||
|
keys := make([]string, 0, len(workerMap))
|
||||||
|
for key := range workerMap {
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
sort.Strings(keys)
|
||||||
|
|
||||||
|
// Build worker list
|
||||||
|
workerURLs := make([]string, 0, len(keys))
|
||||||
|
for _, key := range keys {
|
||||||
|
if workerInfo, exists := workerMap[key]; exists {
|
||||||
|
workerURLs = append(workerURLs, workerInfo.Url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return workerURLs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) GetHealthyURLs(ctx context.Context) []string {
|
||||||
|
if m == nil {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
totalSeversLength := len(m.prefillWorkerMap) + len(m.decodeWorkerMap) + len(m.mixedWorkerMap)
|
||||||
|
allServerURLs := make([]string, 0, totalSeversLength)
|
||||||
|
|
||||||
|
for id := range m.prefillWorkerMap {
|
||||||
|
allServerURLs = append(allServerURLs, id)
|
||||||
|
}
|
||||||
|
for id := range m.decodeWorkerMap {
|
||||||
|
allServerURLs = append(allServerURLs, id)
|
||||||
|
}
|
||||||
|
for id := range m.mixedWorkerMap {
|
||||||
|
allServerURLs = append(allServerURLs, id)
|
||||||
|
}
|
||||||
|
return allServerURLs
|
||||||
|
}
|
||||||
|
|
||||||
|
func SelectWorker(ctx context.Context, message string) (string, error) {
|
||||||
|
workers := WorkerMapToList(ctx, "mixed")
|
||||||
|
selectedWorkerURL, err := scheduler_handler.SelectWorker(ctx, workers, message, "mixed")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return selectedWorkerURL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SelectWorkerPair(ctx context.Context, message string) (string, string, error) {
|
||||||
|
prefillWorkers := WorkerMapToList(ctx, "prefill")
|
||||||
|
decodeWorkers := WorkerMapToList(ctx, "decode")
|
||||||
|
if len(prefillWorkers) == 0 || len(decodeWorkers) == 0 {
|
||||||
|
return "", "", nil
|
||||||
|
}
|
||||||
|
selectedPrefillWorkerURL, err := scheduler_handler.SelectWorker(ctx, prefillWorkers, message, "prefill")
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
selectedDecodeWorkerURL, err := scheduler_handler.SelectWorker(ctx, decodeWorkers, message, "decode")
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
return selectedPrefillWorkerURL, selectedDecodeWorkerURL, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,117 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/internal/config"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInit(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Splitwise: true,
|
||||||
|
},
|
||||||
|
Manager: config.ManagerConfig{
|
||||||
|
HealthCheckTimeoutSecs: 5.0,
|
||||||
|
HealthCheckEndpoint: "/health",
|
||||||
|
HealthFailureThreshold: 3,
|
||||||
|
HealthSuccessThreshold: 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
Init(cfg)
|
||||||
|
|
||||||
|
assert.NotNil(t, DefaultManager)
|
||||||
|
assert.True(t, DefaultManager.splitwise)
|
||||||
|
assert.Equal(t, "/health", healthEndpoint)
|
||||||
|
assert.Equal(t, 3, failureThreshold)
|
||||||
|
assert.Equal(t, 2, successThreshold)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWorkerMapToList(t *testing.T) {
|
||||||
|
// Setup test data
|
||||||
|
Init(&config.Config{})
|
||||||
|
DefaultManager.prefillWorkerMap = map[string]*WorkerInfo{
|
||||||
|
"http://worker1": {Url: "http://worker1"},
|
||||||
|
"http://worker2": {Url: "http://worker2"},
|
||||||
|
}
|
||||||
|
DefaultManager.decodeWorkerMap = map[string]*WorkerInfo{
|
||||||
|
"http://worker3": {Url: "http://worker3"},
|
||||||
|
}
|
||||||
|
DefaultManager.mixedWorkerMap = map[string]*WorkerInfo{
|
||||||
|
"http://worker4": {Url: "http://worker4"},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("prefill workers", func(t *testing.T) {
|
||||||
|
workers := WorkerMapToList(context.Background(), "prefill")
|
||||||
|
assert.Len(t, workers, 2)
|
||||||
|
assert.Contains(t, workers, "http://worker1")
|
||||||
|
assert.Contains(t, workers, "http://worker2")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("decode workers", func(t *testing.T) {
|
||||||
|
workers := WorkerMapToList(context.Background(), "decode")
|
||||||
|
assert.Len(t, workers, 1)
|
||||||
|
assert.Contains(t, workers, "http://worker3")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("mixed workers", func(t *testing.T) {
|
||||||
|
workers := WorkerMapToList(context.Background(), "mixed")
|
||||||
|
assert.Len(t, workers, 1)
|
||||||
|
assert.Contains(t, workers, "http://worker4")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid worker type", func(t *testing.T) {
|
||||||
|
workers := WorkerMapToList(context.Background(), "invalid")
|
||||||
|
assert.Len(t, workers, 0)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_GetHealthyURLs(t *testing.T) {
|
||||||
|
// Setup test data
|
||||||
|
Init(&config.Config{})
|
||||||
|
DefaultManager.prefillWorkerMap = map[string]*WorkerInfo{
|
||||||
|
"worker1": {Url: "http://worker1"},
|
||||||
|
}
|
||||||
|
DefaultManager.decodeWorkerMap = map[string]*WorkerInfo{
|
||||||
|
"worker2": {Url: "http://worker2"},
|
||||||
|
}
|
||||||
|
DefaultManager.mixedWorkerMap = map[string]*WorkerInfo{
|
||||||
|
"worker3": {Url: "http://worker3"},
|
||||||
|
}
|
||||||
|
|
||||||
|
urls := DefaultManager.GetHealthyURLs(context.Background())
|
||||||
|
assert.Len(t, urls, 3)
|
||||||
|
assert.Contains(t, urls, "worker1")
|
||||||
|
assert.Contains(t, urls, "worker2")
|
||||||
|
assert.Contains(t, urls, "worker3")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectWorker(t *testing.T) {
|
||||||
|
// Setup test data
|
||||||
|
Init(&config.Config{})
|
||||||
|
DefaultManager.mixedWorkerMap = map[string]*WorkerInfo{
|
||||||
|
"http://worker1": {Url: "http://worker1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// This will fail because SelectWorker depends on scheduler
|
||||||
|
// which we don't want to mock in this unit test
|
||||||
|
t.Skip("Integration test requiring scheduler setup")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectWorkerPair(t *testing.T) {
|
||||||
|
// Setup test data
|
||||||
|
Init(&config.Config{})
|
||||||
|
DefaultManager.prefillWorkerMap = map[string]*WorkerInfo{
|
||||||
|
"http://worker1": {Url: "http://worker1"},
|
||||||
|
}
|
||||||
|
DefaultManager.decodeWorkerMap = map[string]*WorkerInfo{
|
||||||
|
"http://worker2": {Url: "http://worker2"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// This will fail because SelectWorkerPair depends on scheduler
|
||||||
|
// which we don't want to mock in this unit test
|
||||||
|
t.Skip("Integration test requiring scheduler setup")
|
||||||
|
}
|
||||||
@@ -0,0 +1,279 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
scheduler_handler "github.com/PaddlePaddle/FastDeploy/router/internal/scheduler/handler"
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/pkg/logger"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type healthCheckResult struct {
|
||||||
|
url string
|
||||||
|
isHealthy bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type healthMonitorResult struct {
|
||||||
|
id string
|
||||||
|
worker *WorkerInfo // node information
|
||||||
|
isHealthy bool // check error (nil means healthy)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CheckServiceHealth(ctx context.Context, baseURL string, timeout ...time.Duration) bool {
|
||||||
|
// Handle empty baseURL
|
||||||
|
if baseURL == "" {
|
||||||
|
logger.Error("empty baseURL provided")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
healthPath := healthEndpoint
|
||||||
|
url := baseURL + healthPath
|
||||||
|
timeoutToUse := defaultCheckTimeout // Default timeout
|
||||||
|
|
||||||
|
// Override default value if caller provides valid timeout parameter
|
||||||
|
if len(timeout) > 0 && timeout[0] > 0 {
|
||||||
|
timeoutToUse = timeout[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create HTTP request
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to create request: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send request
|
||||||
|
client := &http.Client{Timeout: timeoutToUse}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to send request to %s with error: %v", url, err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// Read response body
|
||||||
|
_, err = io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to read response body: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check response status code
|
||||||
|
if resp.StatusCode == http.StatusOK {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func CheckWorkerHealth(ctx context.Context, baseURL string) bool {
|
||||||
|
allServers := GetAllMapServers(ctx)
|
||||||
|
_, exists := allServers[baseURL]
|
||||||
|
if exists {
|
||||||
|
for i := 0; i < failureThreshold; i++ {
|
||||||
|
checkOk := CheckServiceHealth(ctx, baseURL)
|
||||||
|
if checkOk {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := 0; i < successThreshold; i++ {
|
||||||
|
checkOk := CheckServiceHealth(ctx, baseURL)
|
||||||
|
if !checkOk {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get all URLs of Prefill and Decode servers
|
||||||
|
func GetAllServerURLs(ctx context.Context) []string {
|
||||||
|
DefaultManager.mu.RLock()
|
||||||
|
defer DefaultManager.mu.RUnlock()
|
||||||
|
|
||||||
|
totalSeversLength := len(DefaultManager.prefillWorkerMap) + len(DefaultManager.decodeWorkerMap)
|
||||||
|
allServerURLs := make([]string, 0, totalSeversLength)
|
||||||
|
|
||||||
|
for _, server := range DefaultManager.prefillWorkerMap {
|
||||||
|
allServerURLs = append(allServerURLs, server.Url)
|
||||||
|
}
|
||||||
|
for _, server := range DefaultManager.decodeWorkerMap {
|
||||||
|
allServerURLs = append(allServerURLs, server.Url)
|
||||||
|
}
|
||||||
|
return allServerURLs
|
||||||
|
}
|
||||||
|
|
||||||
|
func HealthGenerate(c *gin.Context) {
|
||||||
|
// The buffer size of this channel equals the total number of tasks, avoids goroutine blocking
|
||||||
|
results := make(chan healthCheckResult, len(DefaultManager.prefillWorkerMap)+len(DefaultManager.decodeWorkerMap))
|
||||||
|
// Use WaitGroup to wait for all goroutines to complete sending results
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
allServerURLs := GetAllServerURLs(c.Request.Context())
|
||||||
|
|
||||||
|
for _, s := range allServerURLs {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(serverURL string) {
|
||||||
|
defer wg.Done()
|
||||||
|
baseURL := serverURL
|
||||||
|
isHealthy := CheckWorkerHealth(c.Request.Context(), baseURL)
|
||||||
|
results <- healthCheckResult{
|
||||||
|
url: serverURL,
|
||||||
|
isHealthy: isHealthy,
|
||||||
|
}
|
||||||
|
}(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start a goroutine to close the result channel after all check tasks complete
|
||||||
|
// Used to notify the range loop can end
|
||||||
|
go func() {
|
||||||
|
wg.Wait()
|
||||||
|
close(results)
|
||||||
|
}()
|
||||||
|
|
||||||
|
for res := range results {
|
||||||
|
// Process each result
|
||||||
|
if !res.isHealthy {
|
||||||
|
logger.Warn("Server %s is not healthy", res.url)
|
||||||
|
} else {
|
||||||
|
logger.Info("Server %s is healthy", res.url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": 200,
|
||||||
|
"msg": "Health check complete",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func RemoveServers(ctx context.Context, prefillToRemove []string, decodeToRemove []string, mixedToRemove []string) {
|
||||||
|
DefaultManager.mu.Lock()
|
||||||
|
defer DefaultManager.mu.Unlock()
|
||||||
|
|
||||||
|
for _, id := range prefillToRemove {
|
||||||
|
if worker, exists := DefaultManager.prefillWorkerMap[id]; exists {
|
||||||
|
delete(DefaultManager.prefillWorkerMap, id)
|
||||||
|
logger.Info("Removed unhealthy prefill instance: %s", worker.Url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, id := range decodeToRemove {
|
||||||
|
if worker, exists := DefaultManager.decodeWorkerMap[id]; exists {
|
||||||
|
delete(DefaultManager.decodeWorkerMap, id)
|
||||||
|
logger.Info("Removed unhealthy decode instance: %s", worker.Url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, id := range mixedToRemove {
|
||||||
|
if worker, exists := DefaultManager.mixedWorkerMap[id]; exists {
|
||||||
|
delete(DefaultManager.mixedWorkerMap, id)
|
||||||
|
logger.Info("Removed unhealthy mixed instance: %s", worker.Url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReadServers(ctx context.Context) (prefillInstances, decodeInstances, mixedInstances []string) {
|
||||||
|
if DefaultManager == nil {
|
||||||
|
logger.Debug("Healthy instances: prefill=[], decode=[], mixed=[] (DefaultManager is nil)")
|
||||||
|
return []string{}, []string{}, []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
DefaultManager.mu.RLock()
|
||||||
|
defer DefaultManager.mu.RUnlock()
|
||||||
|
|
||||||
|
// Pre-allocate sufficient capacity to avoid multiple expansions
|
||||||
|
prefillInstances = make([]string, 0, len(DefaultManager.prefillWorkerMap))
|
||||||
|
decodeInstances = make([]string, 0, len(DefaultManager.decodeWorkerMap))
|
||||||
|
mixedInstances = make([]string, 0, len(DefaultManager.mixedWorkerMap))
|
||||||
|
|
||||||
|
// Copy data to avoid holding lock for long time
|
||||||
|
for _, w := range DefaultManager.prefillWorkerMap {
|
||||||
|
prefillInstances = append(prefillInstances, w.Url)
|
||||||
|
}
|
||||||
|
for _, w := range DefaultManager.decodeWorkerMap {
|
||||||
|
decodeInstances = append(decodeInstances, w.Url)
|
||||||
|
}
|
||||||
|
for _, w := range DefaultManager.mixedWorkerMap {
|
||||||
|
mixedInstances = append(mixedInstances, w.Url)
|
||||||
|
}
|
||||||
|
logger.Debug(
|
||||||
|
"Healthy instances: prefill=%v, decode=%v, mixed=%v",
|
||||||
|
prefillInstances,
|
||||||
|
decodeInstances,
|
||||||
|
mixedInstances,
|
||||||
|
)
|
||||||
|
return prefillInstances, decodeInstances, mixedInstances
|
||||||
|
}
|
||||||
|
|
||||||
|
func MonitorInstanceHealthCore(ctx context.Context) {
|
||||||
|
if DefaultManager == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Concurrently check health status of all nodes (fix concurrent security issues)
|
||||||
|
allServers := GetAllMapServers(ctx)
|
||||||
|
length := len(allServers)
|
||||||
|
|
||||||
|
resultCh := make(chan healthMonitorResult, length)
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
for id, server := range allServers {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(id string, server *WorkerInfo) {
|
||||||
|
defer wg.Done()
|
||||||
|
// Execute health check logic
|
||||||
|
baseURL := server.Url
|
||||||
|
isHealthy := CheckWorkerHealth(ctx, baseURL)
|
||||||
|
resultCh <- healthMonitorResult{
|
||||||
|
id: id,
|
||||||
|
worker: server,
|
||||||
|
isHealthy: isHealthy,
|
||||||
|
}
|
||||||
|
}(id, server)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all checks to complete
|
||||||
|
go func() {
|
||||||
|
wg.Wait()
|
||||||
|
close(resultCh)
|
||||||
|
}()
|
||||||
|
|
||||||
|
var prefillToRemove, decodeToRemove, mixedToRemove []string
|
||||||
|
|
||||||
|
for res := range resultCh {
|
||||||
|
if !res.isHealthy {
|
||||||
|
// logger.Warn("Server %s meets error: %v", res.worker.url, res.err)
|
||||||
|
switch res.worker.WorkerType {
|
||||||
|
case "prefill":
|
||||||
|
prefillToRemove = append(prefillToRemove, res.id)
|
||||||
|
case "decode":
|
||||||
|
decodeToRemove = append(decodeToRemove, res.id)
|
||||||
|
case "mixed":
|
||||||
|
mixedToRemove = append(mixedToRemove, res.id)
|
||||||
|
}
|
||||||
|
go scheduler_handler.CleanupUnhealthyCounter(ctx, res.id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove unhealthy instances
|
||||||
|
RemoveServers(ctx, prefillToRemove, decodeToRemove, mixedToRemove)
|
||||||
|
|
||||||
|
ReadServers(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func MonitorInstanceHealth(ctx context.Context, intervalSecs float64) {
|
||||||
|
ticker := time.NewTicker(time.Duration(intervalSecs * float64(time.Second)))
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
// Infinite loop: continuously execute health checks
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
go MonitorInstanceHealthCore(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,160 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/internal/config"
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/pkg/logger"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// Initialize logger for all tests
|
||||||
|
logger.Init("info", "stdout")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckServiceHealth(t *testing.T) {
|
||||||
|
// Setup test server
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
t.Run("healthy service", func(t *testing.T) {
|
||||||
|
healthy := CheckServiceHealth(context.Background(), ts.URL)
|
||||||
|
assert.True(t, healthy)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unhealthy service", func(t *testing.T) {
|
||||||
|
unhealthyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
}))
|
||||||
|
defer unhealthyServer.Close()
|
||||||
|
|
||||||
|
healthy := CheckServiceHealth(context.Background(), unhealthyServer.URL)
|
||||||
|
assert.False(t, healthy)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty baseURL", func(t *testing.T) {
|
||||||
|
healthy := CheckServiceHealth(context.Background(), "")
|
||||||
|
assert.False(t, healthy)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("timeout", func(t *testing.T) {
|
||||||
|
slowServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer slowServer.Close()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
healthy := CheckServiceHealth(ctx, slowServer.URL)
|
||||||
|
assert.False(t, healthy)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckWorkerHealth(t *testing.T) {
|
||||||
|
// Setup test data
|
||||||
|
Init(&config.Config{
|
||||||
|
Manager: config.ManagerConfig{
|
||||||
|
HealthFailureThreshold: 1,
|
||||||
|
HealthSuccessThreshold: 1,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Setup test server
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
t.Run("healthy worker", func(t *testing.T) {
|
||||||
|
healthy := CheckWorkerHealth(context.Background(), ts.URL)
|
||||||
|
assert.True(t, healthy)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unhealthy worker", func(t *testing.T) {
|
||||||
|
unhealthyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
}))
|
||||||
|
defer unhealthyServer.Close()
|
||||||
|
|
||||||
|
healthy := CheckWorkerHealth(context.Background(), unhealthyServer.URL)
|
||||||
|
assert.False(t, healthy)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHealthGenerate(t *testing.T) {
|
||||||
|
// Setup test server
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
// Setup test data
|
||||||
|
Init(&config.Config{})
|
||||||
|
DefaultManager.prefillWorkerMap = map[string]*WorkerInfo{
|
||||||
|
"worker1": {Url: ts.URL},
|
||||||
|
}
|
||||||
|
DefaultManager.decodeWorkerMap = map[string]*WorkerInfo{
|
||||||
|
"worker2": {Url: ts.URL},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Gin handler
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
// Set up a valid HTTP request for the context
|
||||||
|
c.Request = httptest.NewRequest("GET", "/health", nil)
|
||||||
|
HealthGenerate(c)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "Health check complete")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMonitorInstanceHealthCore(t *testing.T) {
|
||||||
|
// Setup test server
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
// Setup test data
|
||||||
|
Init(&config.Config{})
|
||||||
|
DefaultManager.prefillWorkerMap = map[string]*WorkerInfo{
|
||||||
|
"worker1": {Url: ts.URL, WorkerType: "prefill"},
|
||||||
|
}
|
||||||
|
DefaultManager.decodeWorkerMap = map[string]*WorkerInfo{
|
||||||
|
"worker2": {Url: ts.URL, WorkerType: "decode"},
|
||||||
|
}
|
||||||
|
|
||||||
|
MonitorInstanceHealthCore(context.Background())
|
||||||
|
|
||||||
|
// Verify workers still exist (since they're healthy)
|
||||||
|
_, exists := DefaultManager.prefillWorkerMap["worker1"]
|
||||||
|
assert.True(t, exists)
|
||||||
|
_, exists = DefaultManager.decodeWorkerMap["worker2"]
|
||||||
|
assert.True(t, exists)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadServers(t *testing.T) {
|
||||||
|
// Setup test data
|
||||||
|
Init(&config.Config{})
|
||||||
|
DefaultManager.prefillWorkerMap = map[string]*WorkerInfo{
|
||||||
|
"worker1": {Url: "http://worker1"},
|
||||||
|
}
|
||||||
|
DefaultManager.decodeWorkerMap = map[string]*WorkerInfo{
|
||||||
|
"worker2": {Url: "http://worker2"},
|
||||||
|
}
|
||||||
|
|
||||||
|
prefill, decode, mixed := ReadServers(context.Background())
|
||||||
|
assert.Equal(t, []string{"http://worker1"}, prefill)
|
||||||
|
assert.Equal(t, []string{"http://worker2"}, decode)
|
||||||
|
assert.Equal(t, []string{}, mixed)
|
||||||
|
}
|
||||||
@@ -0,0 +1,221 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type InstanceRole int
|
||||||
|
|
||||||
|
const (
|
||||||
|
MIXED InstanceRole = iota
|
||||||
|
PREFILL
|
||||||
|
DECODE
|
||||||
|
)
|
||||||
|
|
||||||
|
var roleNames = [...]string{"mixed", "prefill", "decode"}
|
||||||
|
|
||||||
|
func (r InstanceRole) String() string { return roleNames[r] }
|
||||||
|
|
||||||
|
func ParseInstanceRole(s string) (InstanceRole, error) {
|
||||||
|
for i, name := range roleNames {
|
||||||
|
if strings.EqualFold(strings.ToLower(s), name) {
|
||||||
|
return InstanceRole(i), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1, fmt.Errorf("invalid role: %s", s)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Role struct {
|
||||||
|
EnumValue InstanceRole
|
||||||
|
CustomName string
|
||||||
|
IsCustom bool
|
||||||
|
IsSet bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Role) parse(getInt func() (int, error), getStr func() (string, error)) error {
|
||||||
|
r.IsSet = true
|
||||||
|
if i, err := getInt(); err == nil {
|
||||||
|
if i >= 0 && i <= int(DECODE) {
|
||||||
|
r.EnumValue, r.IsCustom = InstanceRole(i), false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("invalid role integer: %d", i)
|
||||||
|
}
|
||||||
|
s, err := getStr()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if e, err := ParseInstanceRole(s); err == nil {
|
||||||
|
r.EnumValue, r.IsCustom = e, false
|
||||||
|
} else {
|
||||||
|
r.CustomName, r.IsCustom = s, true
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Role) UnmarshalJSON(data []byte) error {
|
||||||
|
return r.parse(
|
||||||
|
func() (int, error) { var i int; return i, json.Unmarshal(data, &i) },
|
||||||
|
func() (string, error) { var s string; return s, json.Unmarshal(data, &s) },
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Role) UnmarshalYAML(u func(interface{}) error) error {
|
||||||
|
return r.parse(
|
||||||
|
func() (int, error) { var i int; return i, u(&i) },
|
||||||
|
func() (string, error) { var s string; return s, u(&s) },
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r Role) MarshalJSON() ([]byte, error) {
|
||||||
|
if r.IsCustom {
|
||||||
|
return json.Marshal(r.CustomName)
|
||||||
|
}
|
||||||
|
return json.Marshal(r.EnumValue.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
type Port string
|
||||||
|
|
||||||
|
func (p *Port) UnmarshalJSON(data []byte) error {
|
||||||
|
var i int
|
||||||
|
if json.Unmarshal(data, &i) == nil {
|
||||||
|
*p = Port(strconv.Itoa(i))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return json.Unmarshal(data, (*string)(p))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Port) UnmarshalYAML(u func(interface{}) error) error {
|
||||||
|
var i int
|
||||||
|
if u(&i) == nil {
|
||||||
|
*p = Port(strconv.Itoa(i))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return u((*string)(p))
|
||||||
|
}
|
||||||
|
|
||||||
|
type IntToStringList []string
|
||||||
|
|
||||||
|
func (sl *IntToStringList) UnmarshalJSON(data []byte) error {
|
||||||
|
return sl.unmarshal(data, json.Unmarshal)
|
||||||
|
}
|
||||||
|
func (sl *IntToStringList) UnmarshalYAML(u func(interface{}) error) error {
|
||||||
|
return sl.unmarshal(nil, func(_ []byte, v interface{}) error { return u(v) })
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sl *IntToStringList) unmarshal(data []byte, u func([]byte, interface{}) error) error {
|
||||||
|
var raw []interface{}
|
||||||
|
if err := u(data, &raw); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
res := make([]string, len(raw))
|
||||||
|
for i, v := range raw {
|
||||||
|
switch val := v.(type) {
|
||||||
|
case string:
|
||||||
|
res[i] = val
|
||||||
|
case int:
|
||||||
|
res[i] = strconv.Itoa(val)
|
||||||
|
case float64:
|
||||||
|
if val == float64(int(val)) {
|
||||||
|
res[i] = strconv.Itoa(int(val))
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("element %d: %v not integer", i, val)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("element %d: type %T unsupported", i, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*sl = res
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type InstanceInfo struct {
|
||||||
|
Role Role `json:"role" yaml:"role"`
|
||||||
|
HostIP string `json:"host_ip" yaml:"host_ip"`
|
||||||
|
Port Port `json:"port" yaml:"port"`
|
||||||
|
ConnectorPort Port `json:"connector_port,omitempty" yaml:"connector_port,omitempty"`
|
||||||
|
EngineWorkerQueuePort Port `json:"engine_worker_queue_port,omitempty" yaml:"engine_worker_queue_port,omitempty"`
|
||||||
|
TransferProtocol []string `json:"transfer_protocol,omitempty" yaml:"transfer_protocol,omitempty"`
|
||||||
|
RDMAPorts IntToStringList `json:"rdma_ports,omitempty" yaml:"rdma_ports,omitempty"`
|
||||||
|
DeviceIDs IntToStringList `json:"device_ids,omitempty" yaml:"device_ids,omitempty"`
|
||||||
|
MetricsPort Port `json:"metrics_port,omitempty" yaml:"metrics_port,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func isValidPort(p Port) bool {
|
||||||
|
i, err := strconv.Atoi(string(p))
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return i > 0 && i <= 65535
|
||||||
|
}
|
||||||
|
|
||||||
|
func isValidIP(ip string) bool {
|
||||||
|
return net.ParseIP(ip) != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validatePortList(name string, list []string) error {
|
||||||
|
for i, p := range list {
|
||||||
|
portInt, err := strconv.Atoi(p)
|
||||||
|
if err != nil || portInt <= 0 || portInt > 65535 {
|
||||||
|
return fmt.Errorf("%s[%d] invalid port: %s", name, i, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (info *InstanceInfo) URL() string {
|
||||||
|
url := fmt.Sprintf("%s:%s", info.HostIP, info.Port)
|
||||||
|
if !strings.HasPrefix(url, "http") {
|
||||||
|
url = "http://" + url
|
||||||
|
}
|
||||||
|
return url
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewInstanceInfo(info *InstanceInfo) (*InstanceInfo, error) {
|
||||||
|
if !info.Role.IsSet {
|
||||||
|
return nil, errors.New("role is required")
|
||||||
|
}
|
||||||
|
if info.Role.IsCustom {
|
||||||
|
return nil, fmt.Errorf("invalid role: %s", info.Role.CustomName)
|
||||||
|
}
|
||||||
|
if info.HostIP == "" {
|
||||||
|
return nil, errors.New("host_ip is required")
|
||||||
|
}
|
||||||
|
if !isValidIP(info.HostIP) {
|
||||||
|
return nil, fmt.Errorf("invalid host_ip: %s", info.HostIP)
|
||||||
|
}
|
||||||
|
if info.Port == "" {
|
||||||
|
return nil, errors.New("port is required")
|
||||||
|
}
|
||||||
|
if !isValidPort(info.Port) {
|
||||||
|
return nil, fmt.Errorf("invalid port: %s", info.Port)
|
||||||
|
}
|
||||||
|
if DefaultManager.splitwise && info.ConnectorPort != "" && !isValidPort(info.ConnectorPort) {
|
||||||
|
return nil, fmt.Errorf("invalid connector_port: %s", info.ConnectorPort)
|
||||||
|
}
|
||||||
|
if DefaultManager.splitwise && info.EngineWorkerQueuePort != "" && !isValidPort(info.EngineWorkerQueuePort) {
|
||||||
|
return nil, fmt.Errorf("invalid engine_worker_queue_port: %s", info.EngineWorkerQueuePort)
|
||||||
|
}
|
||||||
|
for _, proto := range info.TransferProtocol {
|
||||||
|
if !slices.Contains([]string{"ipc", "rdma"}, proto) {
|
||||||
|
return nil, fmt.Errorf("invalid protocol: %s", proto)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := validatePortList("rdma_ports", info.RDMAPorts); DefaultManager.splitwise && err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if info.MetricsPort == "" {
|
||||||
|
info.MetricsPort = info.Port
|
||||||
|
} else {
|
||||||
|
if !isValidPort(info.MetricsPort) {
|
||||||
|
return nil, fmt.Errorf("invalid metrics_port: %s", info.MetricsPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return info, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,304 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/internal/config"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInstanceRole_String(t *testing.T) {
|
||||||
|
assert.Equal(t, "mixed", MIXED.String())
|
||||||
|
assert.Equal(t, "prefill", PREFILL.String())
|
||||||
|
assert.Equal(t, "decode", DECODE.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseInstanceRole(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected InstanceRole
|
||||||
|
expectErr bool
|
||||||
|
}{
|
||||||
|
{"mixed lowercase", "mixed", MIXED, false},
|
||||||
|
{"mixed uppercase", "MIXED", MIXED, false},
|
||||||
|
{"prefill lowercase", "prefill", PREFILL, false},
|
||||||
|
{"prefill uppercase", "PREFILL", PREFILL, false},
|
||||||
|
{"decode lowercase", "decode", DECODE, false},
|
||||||
|
{"decode uppercase", "DECODE", DECODE, false},
|
||||||
|
{"invalid role", "invalid", -1, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := ParseInstanceRole(tt.input)
|
||||||
|
if tt.expectErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRole_UnmarshalJSON(t *testing.T) {
|
||||||
|
t.Run("valid role from integer", func(t *testing.T) {
|
||||||
|
var role Role
|
||||||
|
err := json.Unmarshal([]byte("0"), &role)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, MIXED, role.EnumValue)
|
||||||
|
assert.False(t, role.IsCustom)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("valid role from string", func(t *testing.T) {
|
||||||
|
var role Role
|
||||||
|
err := json.Unmarshal([]byte(`"prefill"`), &role)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, PREFILL, role.EnumValue)
|
||||||
|
assert.False(t, role.IsCustom)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("custom role", func(t *testing.T) {
|
||||||
|
var role Role
|
||||||
|
err := json.Unmarshal([]byte(`"custom-role"`), &role)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "custom-role", role.CustomName)
|
||||||
|
assert.True(t, role.IsCustom)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid integer", func(t *testing.T) {
|
||||||
|
var role Role
|
||||||
|
err := json.Unmarshal([]byte("99"), &role)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRole_UnmarshalYAML(t *testing.T) {
|
||||||
|
t.Run("valid role from integer", func(t *testing.T) {
|
||||||
|
var role Role
|
||||||
|
err := yaml.Unmarshal([]byte("1"), &role)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, PREFILL, role.EnumValue)
|
||||||
|
assert.False(t, role.IsCustom)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("valid role from string", func(t *testing.T) {
|
||||||
|
var role Role
|
||||||
|
err := yaml.Unmarshal([]byte("decode"), &role)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, DECODE, role.EnumValue)
|
||||||
|
assert.False(t, role.IsCustom)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRole_MarshalJSON(t *testing.T) {
|
||||||
|
t.Run("standard role", func(t *testing.T) {
|
||||||
|
role := Role{EnumValue: MIXED, IsCustom: false}
|
||||||
|
data, err := json.Marshal(role)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, `"mixed"`, string(data))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("custom role", func(t *testing.T) {
|
||||||
|
role := Role{CustomName: "custom", IsCustom: true}
|
||||||
|
data, err := json.Marshal(role)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, `"custom"`, string(data))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPort_UnmarshalJSON(t *testing.T) {
|
||||||
|
t.Run("port as integer", func(t *testing.T) {
|
||||||
|
var port Port
|
||||||
|
err := json.Unmarshal([]byte("8080"), &port)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, Port("8080"), port)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("port as string", func(t *testing.T) {
|
||||||
|
var port Port
|
||||||
|
err := json.Unmarshal([]byte(`"9090"`), &port)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, Port("9090"), port)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPort_UnmarshalYAML(t *testing.T) {
|
||||||
|
t.Run("port as integer", func(t *testing.T) {
|
||||||
|
var port Port
|
||||||
|
err := yaml.Unmarshal([]byte("8080"), &port)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, Port("8080"), port)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntToStringList_UnmarshalJSON(t *testing.T) {
|
||||||
|
t.Run("mixed types", func(t *testing.T) {
|
||||||
|
var list IntToStringList
|
||||||
|
err := json.Unmarshal([]byte(`["1", 2, 3.0]`), &list)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, IntToStringList{"1", "2", "3"}, list)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid float", func(t *testing.T) {
|
||||||
|
var list IntToStringList
|
||||||
|
err := json.Unmarshal([]byte(`[1.5]`), &list)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid type", func(t *testing.T) {
|
||||||
|
var list IntToStringList
|
||||||
|
err := json.Unmarshal([]byte(`[true]`), &list)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInstanceInfo_URL(t *testing.T) {
|
||||||
|
info := &InstanceInfo{
|
||||||
|
HostIP: "127.0.0.1",
|
||||||
|
Port: Port("8080"),
|
||||||
|
}
|
||||||
|
|
||||||
|
url := info.URL()
|
||||||
|
assert.Equal(t, "http://127.0.0.1:8080", url)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewInstanceInfo(t *testing.T) {
|
||||||
|
// Setup DefaultManager
|
||||||
|
Init(&config.Config{Server: config.ServerConfig{Splitwise: true}})
|
||||||
|
|
||||||
|
t.Run("valid instance", func(t *testing.T) {
|
||||||
|
info := &InstanceInfo{
|
||||||
|
Role: Role{EnumValue: PREFILL, IsSet: true},
|
||||||
|
HostIP: "127.0.0.1",
|
||||||
|
Port: Port("8080"),
|
||||||
|
ConnectorPort: Port("9000"),
|
||||||
|
TransferProtocol: []string{"rdma"},
|
||||||
|
RDMAPorts: IntToStringList{"5000", "5001"},
|
||||||
|
DeviceIDs: IntToStringList{"0", "1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := NewInstanceInfo(info)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing role", func(t *testing.T) {
|
||||||
|
info := &InstanceInfo{
|
||||||
|
HostIP: "127.0.0.1",
|
||||||
|
Port: Port("8080"),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := NewInstanceInfo(info)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "role is required")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("custom role", func(t *testing.T) {
|
||||||
|
info := &InstanceInfo{
|
||||||
|
Role: Role{CustomName: "custom", IsCustom: true, IsSet: true},
|
||||||
|
HostIP: "127.0.0.1",
|
||||||
|
Port: Port("8080"),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := NewInstanceInfo(info)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "invalid role")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing host_ip", func(t *testing.T) {
|
||||||
|
info := &InstanceInfo{
|
||||||
|
Role: Role{EnumValue: MIXED, IsSet: true},
|
||||||
|
Port: Port("8080"),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := NewInstanceInfo(info)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "host_ip is required")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid host_ip", func(t *testing.T) {
|
||||||
|
info := &InstanceInfo{
|
||||||
|
Role: Role{EnumValue: MIXED, IsSet: true},
|
||||||
|
HostIP: "invalid-ip",
|
||||||
|
Port: Port("8080"),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := NewInstanceInfo(info)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "invalid host_ip")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid port", func(t *testing.T) {
|
||||||
|
info := &InstanceInfo{
|
||||||
|
Role: Role{EnumValue: MIXED, IsSet: true},
|
||||||
|
HostIP: "127.0.0.1",
|
||||||
|
Port: Port("99999"),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := NewInstanceInfo(info)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "invalid port")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid connector_port", func(t *testing.T) {
|
||||||
|
info := &InstanceInfo{
|
||||||
|
Role: Role{EnumValue: PREFILL, IsSet: true},
|
||||||
|
HostIP: "127.0.0.1",
|
||||||
|
Port: Port("8080"),
|
||||||
|
ConnectorPort: Port("99999"),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := NewInstanceInfo(info)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "invalid connector_port")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid transfer protocol", func(t *testing.T) {
|
||||||
|
info := &InstanceInfo{
|
||||||
|
Role: Role{EnumValue: PREFILL, IsSet: true},
|
||||||
|
HostIP: "127.0.0.1",
|
||||||
|
Port: Port("8080"),
|
||||||
|
TransferProtocol: []string{"invalid"},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := NewInstanceInfo(info)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "invalid protocol")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid rdma port", func(t *testing.T) {
|
||||||
|
info := &InstanceInfo{
|
||||||
|
Role: Role{EnumValue: PREFILL, IsSet: true},
|
||||||
|
HostIP: "127.0.0.1",
|
||||||
|
Port: Port("8080"),
|
||||||
|
TransferProtocol: []string{"rdma"},
|
||||||
|
RDMAPorts: IntToStringList{"99999"},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := NewInstanceInfo(info)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "rdma_ports[0] invalid port")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsValidPort(t *testing.T) {
|
||||||
|
assert.True(t, isValidPort(Port("8080")))
|
||||||
|
assert.True(t, isValidPort(Port("1")))
|
||||||
|
assert.True(t, isValidPort(Port("65535")))
|
||||||
|
assert.False(t, isValidPort(Port("0")))
|
||||||
|
assert.False(t, isValidPort(Port("65536")))
|
||||||
|
assert.False(t, isValidPort(Port("invalid")))
|
||||||
|
assert.False(t, isValidPort(Port("")))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsValidIP(t *testing.T) {
|
||||||
|
assert.True(t, isValidIP("127.0.0.1"))
|
||||||
|
assert.True(t, isValidIP("192.168.1.1"))
|
||||||
|
assert.True(t, isValidIP("::1"))
|
||||||
|
assert.False(t, isValidIP("invalid"))
|
||||||
|
assert.False(t, isValidIP(""))
|
||||||
|
}
|
||||||
@@ -0,0 +1,109 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
scheduler_handler "github.com/PaddlePaddle/FastDeploy/router/internal/scheduler/handler"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Precompile regex to avoid repeated compilation
|
||||||
|
var (
|
||||||
|
runningRequestsRegex = regexp.MustCompile(`fastdeploy:num_requests_running\s+([0-9.]+)`)
|
||||||
|
waitingRequestsRegex = regexp.MustCompile(`fastdeploy:num_requests_waiting\s+([0-9.]+)`)
|
||||||
|
availableGpuBlockNumRegex = regexp.MustCompile(`available_gpu_block_num\s+([0-9.]+)`)
|
||||||
|
)
|
||||||
|
|
||||||
|
// parseMetricsResponseOptimized parses metrics response string and extracts key metrics
|
||||||
|
func parseMetricsResponseOptimized(response string) (float64, float64, float64) {
|
||||||
|
waitingCnt := -1.0
|
||||||
|
availableGpuBlockNum := -1.0
|
||||||
|
runningCnt := -1.0
|
||||||
|
|
||||||
|
// Find fastdeploy:num_requests_running field using precompiled regex
|
||||||
|
if matches := runningRequestsRegex.FindStringSubmatch(response); len(matches) >= 2 {
|
||||||
|
if score, err := strconv.ParseFloat(matches[1], 64); err == nil {
|
||||||
|
runningCnt = score
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find fastdeploy:num_requests_waiting field using precompiled regex
|
||||||
|
if matches := waitingRequestsRegex.FindStringSubmatch(response); len(matches) >= 2 {
|
||||||
|
if score, err := strconv.ParseFloat(matches[1], 64); err == nil {
|
||||||
|
waitingCnt = score
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse available_gpu_block_num field
|
||||||
|
if matches := availableGpuBlockNumRegex.FindStringSubmatch(response); len(matches) >= 2 {
|
||||||
|
if score, err := strconv.ParseFloat(matches[1], 64); err == nil {
|
||||||
|
availableGpuBlockNum = score
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return runningCnt, waitingCnt, availableGpuBlockNum
|
||||||
|
}
|
||||||
|
|
||||||
|
// redrictCounter gets or creates a counter for the given URL and returns current count
|
||||||
|
func redrictCounter(ctx context.Context, rawURL string) int {
|
||||||
|
counter := scheduler_handler.GetOrCreateCounter(ctx, rawURL)
|
||||||
|
return int(counter.Get())
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMetricsByURL retrieves running metrics of a worker by specified URL
|
||||||
|
func GetMetricsByURL(ctx context.Context, rawURL string) (int, int, int, error) {
|
||||||
|
workerInfo := getWorkerInfo(ctx, rawURL)
|
||||||
|
if workerInfo == nil {
|
||||||
|
return 0, 0, 0, errors.New("worker info not found for URL")
|
||||||
|
}
|
||||||
|
u, err := url.Parse(rawURL)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, 0, err
|
||||||
|
}
|
||||||
|
host, _, err := net.SplitHostPort(u.Host)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, 0, err
|
||||||
|
}
|
||||||
|
u.Host = net.JoinHostPort(host, workerInfo.MetricsPort)
|
||||||
|
metricsUrl := fmt.Sprintf("%s/metrics", u.String())
|
||||||
|
|
||||||
|
client := &http.Client{Timeout: defaultCheckTimeout}
|
||||||
|
resp, err := client.Get(metricsUrl)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, 0, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(body) == 0 {
|
||||||
|
return 0, 0, 0, errors.New("metrics response is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse metrics response
|
||||||
|
runningCnt, waitingCnt, availableGpuBlockNum := parseMetricsResponseOptimized(string(body))
|
||||||
|
if runningCnt < 0 || waitingCnt < 0 || availableGpuBlockNum < 0 {
|
||||||
|
return 0, 0, 0, errors.New("failed to parse metrics response")
|
||||||
|
}
|
||||||
|
return int(runningCnt), int(waitingCnt), int(availableGpuBlockNum), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMetrics retrieves running metrics of the worker for the specified URL
|
||||||
|
func (m *Manager) GetMetrics(ctx context.Context, rawURL string) (int, int, int) {
|
||||||
|
runningCnt, waitingCnt, availableGpuBlockNum, err := GetMetricsByURL(ctx, rawURL)
|
||||||
|
if err != nil {
|
||||||
|
runningNewCnt := redrictCounter(ctx, rawURL)
|
||||||
|
return runningNewCnt, 0, 0
|
||||||
|
}
|
||||||
|
return runningCnt, waitingCnt, availableGpuBlockNum
|
||||||
|
}
|
||||||
@@ -0,0 +1,357 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/internal/config"
|
||||||
|
scheduler_handler "github.com/PaddlePaddle/FastDeploy/router/internal/scheduler/handler"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Helper function to get map keys
|
||||||
|
func getMapKeys(m interface{}) []string {
|
||||||
|
v := reflect.ValueOf(m)
|
||||||
|
if v.Kind() != reflect.Map {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := v.MapKeys()
|
||||||
|
result := make([]string, len(keys))
|
||||||
|
for i, key := range keys {
|
||||||
|
result[i] = key.String()
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRedrictCounter(t *testing.T) {
|
||||||
|
// Initialize scheduler for counter tests
|
||||||
|
cfg := &config.Config{
|
||||||
|
Scheduler: config.SchedulerConfig{
|
||||||
|
Policy: "random",
|
||||||
|
PrefillPolicy: "random",
|
||||||
|
DecodePolicy: "random",
|
||||||
|
WaitingWeight: 1.0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
// Using nil for managerAPI since we're only testing counters
|
||||||
|
scheduler_handler.Init(cfg, nil)
|
||||||
|
|
||||||
|
// Setup test context
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Test case 1: First call with new URL should return 0
|
||||||
|
t.Run("new_url_returns_zero", func(t *testing.T) {
|
||||||
|
count := redrictCounter(ctx, "http://new-service-1")
|
||||||
|
assert.Equal(t, 0, count)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test case 2: Multiple calls with same URL should return incremented count
|
||||||
|
t.Run("same_url_increments", func(t *testing.T) {
|
||||||
|
url := "http://same-service"
|
||||||
|
|
||||||
|
// First call should return 0
|
||||||
|
count1 := redrictCounter(ctx, url)
|
||||||
|
assert.Equal(t, 0, count1)
|
||||||
|
|
||||||
|
// Simulate counter increment by calling GetOrCreateCounter and incrementing
|
||||||
|
counter := scheduler_handler.GetOrCreateCounter(ctx, url)
|
||||||
|
counter.Inc()
|
||||||
|
|
||||||
|
// Second call should return incremented count
|
||||||
|
count2 := redrictCounter(ctx, url)
|
||||||
|
assert.Equal(t, 1, count2)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test case 3: Different URLs should have independent counters
|
||||||
|
t.Run("different_urls_independent", func(t *testing.T) {
|
||||||
|
url1 := "http://service-a"
|
||||||
|
url2 := "http://service-b"
|
||||||
|
|
||||||
|
// Call first URL
|
||||||
|
count1 := redrictCounter(ctx, url1)
|
||||||
|
assert.Equal(t, 0, count1)
|
||||||
|
|
||||||
|
// Call second URL
|
||||||
|
count2 := redrictCounter(ctx, url2)
|
||||||
|
assert.Equal(t, 0, count2)
|
||||||
|
|
||||||
|
// Increment first URL's counter
|
||||||
|
counter1 := scheduler_handler.GetOrCreateCounter(ctx, url1)
|
||||||
|
counter1.Inc()
|
||||||
|
counter1.Inc()
|
||||||
|
|
||||||
|
// Verify first URL shows incremented count
|
||||||
|
assert.Equal(t, 2, redrictCounter(ctx, url1))
|
||||||
|
// Second URL should still be 0
|
||||||
|
assert.Equal(t, 0, redrictCounter(ctx, url2))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test case 4: Empty URL should work (edge case)
|
||||||
|
t.Run("empty_url", func(t *testing.T) {
|
||||||
|
count := redrictCounter(ctx, "")
|
||||||
|
assert.Equal(t, 0, count)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test case 5: Nil context should work (though not recommended)
|
||||||
|
t.Run("nil_context", func(t *testing.T) {
|
||||||
|
count := redrictCounter(ctx, "http://nil-context-test")
|
||||||
|
assert.Equal(t, 0, count)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseMetricsResponseOptimized(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expectedRun float64
|
||||||
|
expectedWait float64
|
||||||
|
expectedGpu float64
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid metrics response",
|
||||||
|
input: `fastdeploy:num_requests_running 10
|
||||||
|
fastdeploy:num_requests_waiting 5
|
||||||
|
available_gpu_block_num 3`,
|
||||||
|
expectedRun: 10,
|
||||||
|
expectedWait: 5,
|
||||||
|
expectedGpu: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "partial metrics response",
|
||||||
|
input: `fastdeploy:num_requests_running 8
|
||||||
|
available_gpu_block_num 2`,
|
||||||
|
expectedRun: 8,
|
||||||
|
expectedWait: -1,
|
||||||
|
expectedGpu: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty response",
|
||||||
|
input: "",
|
||||||
|
expectedRun: -1,
|
||||||
|
expectedWait: -1,
|
||||||
|
expectedGpu: -1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
run, wait, gpu := parseMetricsResponseOptimized(tt.input)
|
||||||
|
assert.Equal(t, tt.expectedRun, run)
|
||||||
|
assert.Equal(t, tt.expectedWait, wait)
|
||||||
|
assert.Equal(t, tt.expectedGpu, gpu)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetMetricsByURL_Integration(t *testing.T) {
|
||||||
|
// Initialize manager for testing
|
||||||
|
Init(&config.Config{})
|
||||||
|
|
||||||
|
// Test with mock HTTP server
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/metrics" {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch r.URL.Query().Get("scenario") {
|
||||||
|
case "valid":
|
||||||
|
w.Write([]byte(`fastdeploy:num_requests_running 5
|
||||||
|
fastdeploy:num_requests_waiting 3
|
||||||
|
available_gpu_block_num 10`))
|
||||||
|
case "partial":
|
||||||
|
w.Write([]byte(`fastdeploy:num_requests_running 2
|
||||||
|
available_gpu_block_num 5`))
|
||||||
|
case "empty":
|
||||||
|
// Empty response body
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
case "error":
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
default:
|
||||||
|
w.Write([]byte(`fastdeploy:num_requests_running 1
|
||||||
|
fastdeploy:num_requests_waiting 0
|
||||||
|
available_gpu_block_num 8`))
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Test worker not found
|
||||||
|
t.Run("worker_not_found", func(t *testing.T) {
|
||||||
|
_, _, _, err := GetMetricsByURL(context.Background(), "http://nonexistent-worker:8080")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "worker info not found")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test with registered worker
|
||||||
|
t.Run("with_registered_worker", func(t *testing.T) {
|
||||||
|
// Register a test worker with localhost to avoid DNS resolution issues
|
||||||
|
workerURL := "http://localhost:8080"
|
||||||
|
DefaultManager.mu.Lock()
|
||||||
|
DefaultManager.prefillWorkerMap[workerURL] = &WorkerInfo{
|
||||||
|
Url: workerURL,
|
||||||
|
WorkerType: "prefill",
|
||||||
|
MetricsPort: strconv.Itoa(server.Listener.Addr().(*net.TCPAddr).Port),
|
||||||
|
}
|
||||||
|
t.Logf("Registered worker: URL=%s, MetricsPort=%s", workerURL, strconv.Itoa(server.Listener.Addr().(*net.TCPAddr).Port))
|
||||||
|
DefaultManager.mu.Unlock()
|
||||||
|
|
||||||
|
// Debug: check what URLs are in the map
|
||||||
|
DefaultManager.mu.RLock()
|
||||||
|
t.Logf("prefillWorkerMap keys: %v", getMapKeys(DefaultManager.prefillWorkerMap))
|
||||||
|
DefaultManager.mu.RUnlock()
|
||||||
|
|
||||||
|
t.Logf("Looking up worker for URL: %s", workerURL)
|
||||||
|
|
||||||
|
// Test valid metrics response - the test server should handle scenarios based on the request
|
||||||
|
running, waiting, gpu, err := GetMetricsByURL(context.Background(), workerURL)
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("Error: %v", err)
|
||||||
|
}
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 1, running) // Default scenario in test server
|
||||||
|
assert.Equal(t, 0, waiting) // Default scenario in test server
|
||||||
|
assert.Equal(t, 8, gpu) // Default scenario in test server
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test invalid URL
|
||||||
|
t.Run("invalid_url", func(t *testing.T) {
|
||||||
|
_, _, _, err := GetMetricsByURL(context.Background(), "http://invalid url")
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
// Test partial metrics
|
||||||
|
t.Run("partial_metrics", func(t *testing.T) {
|
||||||
|
workerURL := "http://localhost:8081" // Different port
|
||||||
|
DefaultManager.mu.Lock()
|
||||||
|
DefaultManager.prefillWorkerMap[workerURL] = &WorkerInfo{
|
||||||
|
Url: workerURL,
|
||||||
|
WorkerType: "prefill",
|
||||||
|
MetricsPort: strconv.Itoa(server.Listener.Addr().(*net.TCPAddr).Port),
|
||||||
|
}
|
||||||
|
t.Logf("Registered worker: URL=%s, MetricsPort=%s", workerURL, strconv.Itoa(server.Listener.Addr().(*net.TCPAddr).Port))
|
||||||
|
DefaultManager.mu.Unlock()
|
||||||
|
|
||||||
|
t.Logf("Looking up worker for URL: %s", workerURL)
|
||||||
|
|
||||||
|
running, waiting, gpu, err := GetMetricsByURL(context.Background(), workerURL)
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("Error: %v", err)
|
||||||
|
}
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 1, running) // Default scenario
|
||||||
|
assert.Equal(t, 0, waiting) // Default scenario
|
||||||
|
assert.Equal(t, 8, gpu) // Default scenario
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerGetMetrics_Integration(t *testing.T) {
|
||||||
|
// Initialize manager for testing
|
||||||
|
Init(&config.Config{})
|
||||||
|
// Initialize scheduler for counter tests
|
||||||
|
cfg := &config.Config{
|
||||||
|
Scheduler: config.SchedulerConfig{
|
||||||
|
Policy: "random",
|
||||||
|
PrefillPolicy: "random",
|
||||||
|
DecodePolicy: "random",
|
||||||
|
WaitingWeight: 1.0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
scheduler_handler.Init(cfg, nil)
|
||||||
|
|
||||||
|
m := &Manager{}
|
||||||
|
|
||||||
|
// Setup mock HTTP server
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/metrics" {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Write([]byte(`fastdeploy:num_requests_running 2
|
||||||
|
fastdeploy:num_requests_waiting 1
|
||||||
|
available_gpu_block_num 5`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Test normal case with registered worker
|
||||||
|
t.Run("normal_case", func(t *testing.T) {
|
||||||
|
workerURL := "http://localhost:8080" // Use localhost to avoid DNS lookup
|
||||||
|
DefaultManager.mu.Lock()
|
||||||
|
DefaultManager.prefillWorkerMap[workerURL] = &WorkerInfo{
|
||||||
|
Url: workerURL,
|
||||||
|
WorkerType: "prefill",
|
||||||
|
MetricsPort: strconv.Itoa(server.Listener.Addr().(*net.TCPAddr).Port),
|
||||||
|
}
|
||||||
|
t.Logf("Registered worker with MetricsPort: %s", strconv.Itoa(server.Listener.Addr().(*net.TCPAddr).Port))
|
||||||
|
DefaultManager.mu.Unlock()
|
||||||
|
|
||||||
|
// Debug: test GetMetricsByURL directly first
|
||||||
|
t.Logf("Testing GetMetricsByURL directly...")
|
||||||
|
running, waiting, gpu, err := GetMetricsByURL(context.Background(), workerURL)
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("GetMetricsByURL failed: %v", err)
|
||||||
|
} else {
|
||||||
|
t.Logf("GetMetricsByURL succeeded: running=%d, waiting=%d, gpu=%d", running, waiting, gpu)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now test Manager.GetMetrics
|
||||||
|
running, waiting, gpu = m.GetMetrics(context.Background(), workerURL)
|
||||||
|
t.Logf("Manager.GetMetrics result: running=%d, waiting=%d, gpu=%d", running, waiting, gpu)
|
||||||
|
assert.Equal(t, 2, running)
|
||||||
|
assert.Equal(t, 1, waiting)
|
||||||
|
assert.Equal(t, 5, gpu)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test error case (should fall back to counter)
|
||||||
|
t.Run("error_fallback", func(t *testing.T) {
|
||||||
|
// Use a URL that doesn't have a registered worker
|
||||||
|
workerURL := "http://unknown-worker:8080"
|
||||||
|
running, waiting, gpu := m.GetMetrics(context.Background(), workerURL)
|
||||||
|
|
||||||
|
// Should fall back to counter (which should be 0 for new URL)
|
||||||
|
assert.Equal(t, 0, running) // Should be 0 for new counter
|
||||||
|
assert.Equal(t, 0, waiting) // Should be 0 in error case
|
||||||
|
assert.Equal(t, 0, gpu) // Should be 0 in error case
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to test metrics parsing directly
|
||||||
|
func TestMetricsParsingHelper(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
metricsBody string
|
||||||
|
expected []int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "complete_metrics",
|
||||||
|
metricsBody: `fastdeploy:num_requests_running 5
|
||||||
|
fastdeploy:num_requests_waiting 3
|
||||||
|
available_gpu_block_num 10`,
|
||||||
|
expected: []int{5, 3, 10},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing_waiting",
|
||||||
|
metricsBody: `fastdeploy:num_requests_running 2
|
||||||
|
available_gpu_block_num 5`,
|
||||||
|
expected: []int{2, -1, 5},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty_body",
|
||||||
|
metricsBody: "",
|
||||||
|
expected: []int{-1, -1, -1},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
run, wait, gpu := parseMetricsResponseOptimized(tt.metricsBody)
|
||||||
|
assert.Equal(t, float64(tt.expected[0]), run)
|
||||||
|
assert.Equal(t, float64(tt.expected[1]), wait)
|
||||||
|
assert.Equal(t, float64(tt.expected[2]), gpu)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,333 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/pkg/logger"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RegisterConfig struct {
|
||||||
|
Instances []InstanceInfo `yaml:"instances"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetSplitwise(ctx context.Context) bool {
|
||||||
|
if DefaultManager == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
DefaultManager.mu.RLock()
|
||||||
|
defer DefaultManager.mu.RUnlock()
|
||||||
|
return DefaultManager.splitwise
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetAllMapServers(ctx context.Context) map[string]*WorkerInfo {
|
||||||
|
if DefaultManager == nil {
|
||||||
|
return make(map[string]*WorkerInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
DefaultManager.mu.RLock()
|
||||||
|
defer DefaultManager.mu.RUnlock()
|
||||||
|
allServers := make(map[string]*WorkerInfo)
|
||||||
|
for id, workerInfo := range DefaultManager.prefillWorkerMap {
|
||||||
|
allServers[id] = workerInfo
|
||||||
|
}
|
||||||
|
for id, workerInfo := range DefaultManager.decodeWorkerMap {
|
||||||
|
allServers[id] = workerInfo
|
||||||
|
}
|
||||||
|
for id, workerInfo := range DefaultManager.mixedWorkerMap {
|
||||||
|
allServers[id] = workerInfo
|
||||||
|
}
|
||||||
|
return allServers
|
||||||
|
}
|
||||||
|
|
||||||
|
// getWorkerInfo gets WorkerInfo based on URL
|
||||||
|
func getWorkerInfo(ctx context.Context, url string) *WorkerInfo {
|
||||||
|
if DefaultManager == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
DefaultManager.mu.RLock()
|
||||||
|
defer DefaultManager.mu.RUnlock()
|
||||||
|
|
||||||
|
if w, ok := DefaultManager.prefillWorkerMap[url]; ok {
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
if w, ok := DefaultManager.decodeWorkerMap[url]; ok {
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
if w, ok := DefaultManager.mixedWorkerMap[url]; ok {
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildDisaggregateInfo builds disaggregate_info structure
|
||||||
|
func BuildDisaggregateInfo(ctx context.Context, prefillURL, decodeURL string) (map[string]any, error) {
|
||||||
|
prefillInfo := getWorkerInfo(ctx, prefillURL)
|
||||||
|
decodeInfo := getWorkerInfo(ctx, decodeURL)
|
||||||
|
if prefillInfo == nil || decodeInfo == nil {
|
||||||
|
return nil, fmt.Errorf("worker instance not found for prefill=%s, decode=%s", prefillURL, decodeURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
prefillHost := hostFromURL(prefillInfo.Url)
|
||||||
|
decodeHost := hostFromURL(decodeInfo.Url)
|
||||||
|
|
||||||
|
// Check if IPC can be used
|
||||||
|
isSameNode := prefillHost != "" && prefillHost == decodeHost
|
||||||
|
isSupportIPC := slices.Contains(prefillInfo.TransferProtocol, "ipc") &&
|
||||||
|
slices.Contains(decodeInfo.TransferProtocol, "ipc")
|
||||||
|
tpPrefill := tpSizeFromWorker(prefillInfo)
|
||||||
|
tpDecode := tpSizeFromWorker(decodeInfo)
|
||||||
|
isSameTpSize := tpPrefill == tpDecode || tpDecode == 1
|
||||||
|
useIPC := isSameNode && isSupportIPC && isSameTpSize
|
||||||
|
|
||||||
|
transferProto := "rdma"
|
||||||
|
if useIPC {
|
||||||
|
transferProto = "ipc"
|
||||||
|
}
|
||||||
|
|
||||||
|
disagg := map[string]any{
|
||||||
|
"prefill_ip": prefillHost,
|
||||||
|
"decode_ip": decodeHost,
|
||||||
|
"prefill_connector_port": portStringToInt(Port(prefillInfo.ConnectorPort)),
|
||||||
|
"decode_connector_port": portStringToInt(Port(decodeInfo.ConnectorPort)),
|
||||||
|
"decode_device_ids": []string(decodeInfo.DeviceIDs),
|
||||||
|
"decode_rdma_ports": []string(decodeInfo.RdmaPorts),
|
||||||
|
"transfer_protocol": transferProto,
|
||||||
|
"decode_tp_size": tpDecode,
|
||||||
|
}
|
||||||
|
return disagg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// portStringToInt converts Port (string) to int
|
||||||
|
func portStringToInt(p Port) int {
|
||||||
|
s := string(p)
|
||||||
|
if s == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
i, err := strconv.Atoi(s)
|
||||||
|
if err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
// tpSizeFromWorker calculates tp_size (currently no explicit field, uses device_ids count, minimum 1)
|
||||||
|
func tpSizeFromWorker(w *WorkerInfo) int {
|
||||||
|
if w == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if len(w.DeviceIDs) > 0 {
|
||||||
|
return len(w.DeviceIDs)
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// hostFromURL extracts host part (without port)
|
||||||
|
func hostFromURL(raw string) string {
|
||||||
|
if raw == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(raw, "http://") && !strings.HasPrefix(raw, "https://") {
|
||||||
|
raw = "http://" + raw
|
||||||
|
}
|
||||||
|
u, err := url.Parse(raw)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return u.Hostname()
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterInstanceCore(ctx context.Context, rawInstance *InstanceInfo) error {
|
||||||
|
instance, err := NewInstanceInfo(rawInstance)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid InstanceInfo format:%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
splitwiseMode := GetSplitwise(ctx)
|
||||||
|
|
||||||
|
instanceRole := instance.Role.EnumValue
|
||||||
|
if splitwiseMode && instanceRole == MIXED {
|
||||||
|
return fmt.Errorf("splitwise mode only supports PREFILL/DECODE instances")
|
||||||
|
}
|
||||||
|
if !splitwiseMode && instanceRole != MIXED {
|
||||||
|
return fmt.Errorf("only MIXED instances are allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check instance health status
|
||||||
|
if !CheckWorkerHealth(ctx, instance.URL()) {
|
||||||
|
return fmt.Errorf("service is not healthy")
|
||||||
|
}
|
||||||
|
|
||||||
|
allServers := GetAllMapServers(ctx)
|
||||||
|
|
||||||
|
DefaultManager.mu.Lock()
|
||||||
|
defer DefaultManager.mu.Unlock()
|
||||||
|
|
||||||
|
workerInfo := &WorkerInfo{
|
||||||
|
Url: instance.URL(),
|
||||||
|
WorkerType: instance.Role.EnumValue.String(),
|
||||||
|
ConnectorPort: string(instance.ConnectorPort),
|
||||||
|
EngineWorkerQueuePort: string(instance.EngineWorkerQueuePort),
|
||||||
|
TransferProtocol: instance.TransferProtocol,
|
||||||
|
RdmaPorts: []string(instance.RDMAPorts),
|
||||||
|
DeviceIDs: []string(instance.DeviceIDs),
|
||||||
|
MetricsPort: string(instance.MetricsPort),
|
||||||
|
}
|
||||||
|
|
||||||
|
id := instance.URL()
|
||||||
|
|
||||||
|
if w, exists := allServers[id]; exists {
|
||||||
|
wType, err := ParseInstanceRole(w.WorkerType)
|
||||||
|
if err == nil {
|
||||||
|
switch wType {
|
||||||
|
case MIXED:
|
||||||
|
delete(DefaultManager.mixedWorkerMap, id)
|
||||||
|
case PREFILL:
|
||||||
|
delete(DefaultManager.prefillWorkerMap, id)
|
||||||
|
case DECODE:
|
||||||
|
delete(DefaultManager.decodeWorkerMap, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch instanceRole {
|
||||||
|
case MIXED:
|
||||||
|
DefaultManager.mixedWorkerMap[id] = workerInfo
|
||||||
|
case PREFILL:
|
||||||
|
DefaultManager.prefillWorkerMap[id] = workerInfo
|
||||||
|
case DECODE:
|
||||||
|
DefaultManager.decodeWorkerMap[id] = workerInfo
|
||||||
|
default:
|
||||||
|
logger.Warn("Instance %s role is unknown", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterInstance(c *gin.Context) {
|
||||||
|
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
|
"code": 400,
|
||||||
|
"msg": "Invalid request body",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var rawInstance InstanceInfo
|
||||||
|
err = json.Unmarshal(bodyBytes, &rawInstance)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
|
"code": 400,
|
||||||
|
"msg": fmt.Sprintf("Invalid InstanceInfo JSON format: %v", err),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := RegisterInstanceCore(c.Request.Context(), &rawInstance); err != nil {
|
||||||
|
logger.Error("Failed to register instance: %v", err)
|
||||||
|
// Return different HTTP status codes based on error type
|
||||||
|
if strings.Contains(err.Error(), "not healthy") {
|
||||||
|
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||||
|
"code": 503,
|
||||||
|
"msg": err.Error(),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
|
"code": 400,
|
||||||
|
"msg": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": 200,
|
||||||
|
"msg": "Register success",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterInstancesFromConfig(yamlPath string) {
|
||||||
|
if yamlPath == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data, err := os.ReadFile(yamlPath)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to read YAML file %s: %v", yamlPath, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var config RegisterConfig
|
||||||
|
if err := yaml.Unmarshal(data, &config); err != nil {
|
||||||
|
logger.Error("Failed to unmarshal YAML file %s: %v", yamlPath, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(config.Instances) == 0 {
|
||||||
|
logger.Info("No instances found in config file %s", yamlPath)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, instanceConfig := range config.Instances {
|
||||||
|
if err := RegisterInstanceCore(context.Background(), &instanceConfig); err != nil {
|
||||||
|
logger.Error("Failed to register instance from index %d: %v", i, err)
|
||||||
|
} else {
|
||||||
|
logger.Info("Successfully registered instance from index %d", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisteredNumber(c *gin.Context) {
|
||||||
|
if DefaultManager == nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
|
"code": 400,
|
||||||
|
"msg": "DefaultManager is nil",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
DefaultManager.mu.RLock()
|
||||||
|
defer DefaultManager.mu.RUnlock()
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"mixed": len(DefaultManager.mixedWorkerMap),
|
||||||
|
"prefill": len(DefaultManager.prefillWorkerMap),
|
||||||
|
"decode": len(DefaultManager.decodeWorkerMap),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Registered(c *gin.Context) {
|
||||||
|
DefaultManager.mu.RLock()
|
||||||
|
defer DefaultManager.mu.RUnlock()
|
||||||
|
|
||||||
|
var prefillInstances, decodeInstances, mixedInstances []WorkerInfo
|
||||||
|
decodeInstances = make([]WorkerInfo, 0)
|
||||||
|
prefillInstances = make([]WorkerInfo, 0)
|
||||||
|
mixedInstances = make([]WorkerInfo, 0)
|
||||||
|
for _, w := range DefaultManager.prefillWorkerMap {
|
||||||
|
prefillInstances = append(prefillInstances, *w)
|
||||||
|
}
|
||||||
|
for _, w := range DefaultManager.decodeWorkerMap {
|
||||||
|
decodeInstances = append(decodeInstances, *w)
|
||||||
|
}
|
||||||
|
for _, w := range DefaultManager.mixedWorkerMap {
|
||||||
|
mixedInstances = append(mixedInstances, *w)
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"code": http.StatusOK,
|
||||||
|
"msg": "success",
|
||||||
|
"decode": decodeInstances,
|
||||||
|
"prefill": prefillInstances,
|
||||||
|
"mixed": mixedInstances,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,395 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/internal/config"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetSplitwise(t *testing.T) {
|
||||||
|
t.Run("DefaultManager is nil", func(t *testing.T) {
|
||||||
|
originalManager := DefaultManager
|
||||||
|
DefaultManager = nil
|
||||||
|
defer func() { DefaultManager = originalManager }()
|
||||||
|
|
||||||
|
result := GetSplitwise(context.Background())
|
||||||
|
assert.False(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("splitwise mode enabled", func(t *testing.T) {
|
||||||
|
Init(&config.Config{Server: config.ServerConfig{Splitwise: true}})
|
||||||
|
result := GetSplitwise(context.Background())
|
||||||
|
assert.True(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("splitwise mode disabled", func(t *testing.T) {
|
||||||
|
Init(&config.Config{Server: config.ServerConfig{Splitwise: false}})
|
||||||
|
result := GetSplitwise(context.Background())
|
||||||
|
assert.False(t, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAllMapServers(t *testing.T) {
|
||||||
|
Init(&config.Config{})
|
||||||
|
DefaultManager.prefillWorkerMap = map[string]*WorkerInfo{
|
||||||
|
"worker1": {Url: "http://worker1"},
|
||||||
|
}
|
||||||
|
DefaultManager.decodeWorkerMap = map[string]*WorkerInfo{
|
||||||
|
"worker2": {Url: "http://worker2"},
|
||||||
|
}
|
||||||
|
DefaultManager.mixedWorkerMap = map[string]*WorkerInfo{
|
||||||
|
"worker3": {Url: "http://worker3"},
|
||||||
|
}
|
||||||
|
|
||||||
|
servers := GetAllMapServers(context.Background())
|
||||||
|
assert.Len(t, servers, 3)
|
||||||
|
assert.Contains(t, servers, "worker1")
|
||||||
|
assert.Contains(t, servers, "worker2")
|
||||||
|
assert.Contains(t, servers, "worker3")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAllMapServers_NilManager(t *testing.T) {
|
||||||
|
originalManager := DefaultManager
|
||||||
|
DefaultManager = nil
|
||||||
|
defer func() { DefaultManager = originalManager }()
|
||||||
|
|
||||||
|
servers := GetAllMapServers(context.Background())
|
||||||
|
assert.NotNil(t, servers)
|
||||||
|
assert.Len(t, servers, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetWorkerInfo(t *testing.T) {
|
||||||
|
Init(&config.Config{})
|
||||||
|
DefaultManager.prefillWorkerMap = map[string]*WorkerInfo{
|
||||||
|
"http://worker1": {Url: "http://worker1", WorkerType: "prefill"},
|
||||||
|
}
|
||||||
|
DefaultManager.decodeWorkerMap = map[string]*WorkerInfo{
|
||||||
|
"http://worker2": {Url: "http://worker2", WorkerType: "decode"},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("find prefill worker", func(t *testing.T) {
|
||||||
|
info := getWorkerInfo(context.Background(), "http://worker1")
|
||||||
|
assert.NotNil(t, info)
|
||||||
|
assert.Equal(t, "prefill", info.WorkerType)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("find decode worker", func(t *testing.T) {
|
||||||
|
info := getWorkerInfo(context.Background(), "http://worker2")
|
||||||
|
assert.NotNil(t, info)
|
||||||
|
assert.Equal(t, "decode", info.WorkerType)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("worker not found", func(t *testing.T) {
|
||||||
|
info := getWorkerInfo(context.Background(), "http://notfound")
|
||||||
|
assert.Nil(t, info)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildDisaggregateInfo(t *testing.T) {
|
||||||
|
Init(&config.Config{Server: config.ServerConfig{Splitwise: true}})
|
||||||
|
|
||||||
|
// Setup test workers
|
||||||
|
DefaultManager.prefillWorkerMap = map[string]*WorkerInfo{
|
||||||
|
"http://127.0.0.1:8000": {
|
||||||
|
Url: "http://127.0.0.1:8000",
|
||||||
|
WorkerType: "prefill",
|
||||||
|
ConnectorPort: "9000",
|
||||||
|
TransferProtocol: []string{"rdma"},
|
||||||
|
DeviceIDs: []string{"0", "1"},
|
||||||
|
RdmaPorts: []string{"5000", "5001"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
DefaultManager.decodeWorkerMap = map[string]*WorkerInfo{
|
||||||
|
"http://127.0.0.1:8001": {
|
||||||
|
Url: "http://127.0.0.1:8001",
|
||||||
|
WorkerType: "decode",
|
||||||
|
ConnectorPort: "9001",
|
||||||
|
TransferProtocol: []string{"rdma"},
|
||||||
|
DeviceIDs: []string{"0", "1"},
|
||||||
|
RdmaPorts: []string{"5002", "5003"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("successful build", func(t *testing.T) {
|
||||||
|
info, err := BuildDisaggregateInfo(context.Background(),
|
||||||
|
"http://127.0.0.1:8000", "http://127.0.0.1:8001")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, info)
|
||||||
|
assert.Equal(t, "127.0.0.1", info["prefill_ip"])
|
||||||
|
assert.Equal(t, "127.0.0.1", info["decode_ip"])
|
||||||
|
assert.Equal(t, "rdma", info["transfer_protocol"])
|
||||||
|
assert.Equal(t, 2, info["decode_tp_size"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("worker not found", func(t *testing.T) {
|
||||||
|
_, err := BuildDisaggregateInfo(context.Background(),
|
||||||
|
"http://notfound", "http://notfound2")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "worker instance not found")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPortStringToInt(t *testing.T) {
|
||||||
|
assert.Equal(t, 8080, portStringToInt(Port("8080")))
|
||||||
|
assert.Equal(t, 0, portStringToInt(Port("")))
|
||||||
|
assert.Equal(t, 0, portStringToInt(Port("invalid")))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTpSizeFromWorker(t *testing.T) {
|
||||||
|
t.Run("worker with device IDs", func(t *testing.T) {
|
||||||
|
worker := &WorkerInfo{DeviceIDs: []string{"0", "1", "2"}}
|
||||||
|
assert.Equal(t, 3, tpSizeFromWorker(worker))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("worker without device IDs", func(t *testing.T) {
|
||||||
|
worker := &WorkerInfo{DeviceIDs: []string{}}
|
||||||
|
assert.Equal(t, 1, tpSizeFromWorker(worker))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nil worker", func(t *testing.T) {
|
||||||
|
assert.Equal(t, 0, tpSizeFromWorker(nil))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHostFromURL(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"standard URL", "http://127.0.0.1:8080", "127.0.0.1"},
|
||||||
|
{"HTTPS URL", "https://example.com:443", "example.com"},
|
||||||
|
{"URL without protocol", "127.0.0.1:8080", "127.0.0.1"},
|
||||||
|
{"empty URL", "", ""},
|
||||||
|
{"invalid URL", "://invalid", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := hostFromURL(tt.input)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegisterInstanceCore(t *testing.T) {
|
||||||
|
// Setup test server
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
Init(&config.Config{Server: config.ServerConfig{Splitwise: false}})
|
||||||
|
|
||||||
|
t.Run("successful registration", func(t *testing.T) {
|
||||||
|
instance := &InstanceInfo{
|
||||||
|
Role: Role{EnumValue: MIXED, IsSet: true},
|
||||||
|
HostIP: "127.0.0.1",
|
||||||
|
Port: Port("8080"),
|
||||||
|
TransferProtocol: []string{"rdma"},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := RegisterInstanceCore(context.Background(), instance)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid instance info", func(t *testing.T) {
|
||||||
|
instance := &InstanceInfo{
|
||||||
|
Role: Role{EnumValue: MIXED, IsSet: true},
|
||||||
|
// Missing required fields
|
||||||
|
}
|
||||||
|
|
||||||
|
err := RegisterInstanceCore(context.Background(), instance)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("splitwise mode with mixed instance", func(t *testing.T) {
|
||||||
|
Init(&config.Config{Server: config.ServerConfig{Splitwise: true}})
|
||||||
|
|
||||||
|
instance := &InstanceInfo{
|
||||||
|
Role: Role{EnumValue: MIXED, IsSet: true},
|
||||||
|
HostIP: "127.0.0.1",
|
||||||
|
Port: Port("8080"),
|
||||||
|
TransferProtocol: []string{"rdma"},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := RegisterInstanceCore(context.Background(), instance)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "splitwise mode only supports PREFILL/DECODE instances")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non-splitwise mode with prefill instance", func(t *testing.T) {
|
||||||
|
Init(&config.Config{Server: config.ServerConfig{Splitwise: false}})
|
||||||
|
|
||||||
|
instance := &InstanceInfo{
|
||||||
|
Role: Role{EnumValue: PREFILL, IsSet: true},
|
||||||
|
HostIP: "127.0.0.1",
|
||||||
|
Port: Port("8080"),
|
||||||
|
TransferProtocol: []string{"rdma"},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := RegisterInstanceCore(context.Background(), instance)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "only MIXED instances are allowed")
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegisterInstance(t *testing.T) {
|
||||||
|
// Setup test server
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
Init(&config.Config{Server: config.ServerConfig{Splitwise: false}})
|
||||||
|
|
||||||
|
t.Run("successful registration", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
body := `{"role": "mixed", "host_ip": "127.0.0.1", "port": 8080, "transfer_protocol": ["rdma"]}`
|
||||||
|
c.Request = httptest.NewRequest("POST", "/register", bytes.NewBufferString(body))
|
||||||
|
|
||||||
|
RegisterInstance(c)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "Register success")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid JSON", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
c.Request = httptest.NewRequest("POST", "/register", bytes.NewBufferString("invalid json"))
|
||||||
|
|
||||||
|
RegisterInstance(c)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty body", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
c.Request = httptest.NewRequest("POST", "/register", bytes.NewBufferString(""))
|
||||||
|
|
||||||
|
RegisterInstance(c)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegisterInstancesFromConfig(t *testing.T) {
|
||||||
|
Init(&config.Config{Server: config.ServerConfig{Splitwise: false}})
|
||||||
|
|
||||||
|
// Create a temporary YAML file
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
yamlPath := filepath.Join(tmpDir, "test_config.yaml")
|
||||||
|
|
||||||
|
yamlContent := `
|
||||||
|
instances:
|
||||||
|
- role: mixed
|
||||||
|
host_ip: 127.0.0.1
|
||||||
|
port: 8080
|
||||||
|
transfer_protocol:
|
||||||
|
- rdma
|
||||||
|
`
|
||||||
|
err := os.WriteFile(yamlPath, []byte(yamlContent), 0644)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// This should not panic
|
||||||
|
RegisterInstancesFromConfig(yamlPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegisterInstancesFromConfig_InvalidPath(t *testing.T) {
|
||||||
|
// Should not panic with invalid path
|
||||||
|
RegisterInstancesFromConfig("/nonexistent/path.yaml")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegisterInstancesFromConfig_InvalidYAML(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
yamlPath := filepath.Join(tmpDir, "invalid.yaml")
|
||||||
|
|
||||||
|
err := os.WriteFile(yamlPath, []byte("invalid: yaml: content:"), 0644)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Should not panic with invalid YAML
|
||||||
|
RegisterInstancesFromConfig(yamlPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegisterInstancesFromConfig_EmptyFile(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
yamlPath := filepath.Join(tmpDir, "empty.yaml")
|
||||||
|
|
||||||
|
err := os.WriteFile(yamlPath, []byte("instances: []"), 0644)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Should not panic with empty instances
|
||||||
|
RegisterInstancesFromConfig(yamlPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegisteredNumber(t *testing.T) {
|
||||||
|
t.Run("DefaultManager is nil", func(t *testing.T) {
|
||||||
|
originalManager := DefaultManager
|
||||||
|
DefaultManager = nil
|
||||||
|
defer func() { DefaultManager = originalManager }()
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest("GET", "/registered/number", nil)
|
||||||
|
|
||||||
|
RegisteredNumber(c)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "DefaultManager is nil")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("successful query", func(t *testing.T) {
|
||||||
|
Init(&config.Config{})
|
||||||
|
DefaultManager.prefillWorkerMap = map[string]*WorkerInfo{
|
||||||
|
"worker1": {Url: "http://worker1"},
|
||||||
|
}
|
||||||
|
DefaultManager.decodeWorkerMap = map[string]*WorkerInfo{
|
||||||
|
"worker2": {Url: "http://worker2"},
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest("GET", "/registered/number", nil)
|
||||||
|
|
||||||
|
RegisteredNumber(c)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), `"prefill":1`)
|
||||||
|
assert.Contains(t, w.Body.String(), `"decode":1`)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistered(t *testing.T) {
|
||||||
|
Init(&config.Config{})
|
||||||
|
DefaultManager.prefillWorkerMap = map[string]*WorkerInfo{
|
||||||
|
"worker1": {Url: "http://worker1", WorkerType: "prefill"},
|
||||||
|
}
|
||||||
|
DefaultManager.decodeWorkerMap = map[string]*WorkerInfo{
|
||||||
|
"worker2": {Url: "http://worker2", WorkerType: "decode"},
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest("GET", "/registered", nil)
|
||||||
|
|
||||||
|
Registered(c)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), `"decode"`)
|
||||||
|
assert.Contains(t, w.Body.String(), `"prefill"`)
|
||||||
|
}
|
||||||
@@ -0,0 +1,33 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/pkg/logger"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Logger logger middleware
|
||||||
|
func Logger() gin.HandlerFunc {
|
||||||
|
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||||
|
logger.Info("[%s] %s %s %d %s %s",
|
||||||
|
param.Method,
|
||||||
|
param.Path,
|
||||||
|
param.Request.Proto,
|
||||||
|
param.StatusCode,
|
||||||
|
param.Latency,
|
||||||
|
param.ClientIP,
|
||||||
|
)
|
||||||
|
return ""
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recovery recovery middleware
|
||||||
|
func Recovery() gin.HandlerFunc {
|
||||||
|
return gin.CustomRecovery(func(c *gin.Context, recovered interface{}) {
|
||||||
|
logger.Error("Panic recovered: %v", recovered)
|
||||||
|
c.JSON(500, gin.H{
|
||||||
|
"code": 500,
|
||||||
|
"msg": "Internal server error",
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,48 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/pkg/logger"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// Initialize logger to avoid nil pointer dereference in recovery middleware
|
||||||
|
logger.Init("info", "stdout")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoggerMiddleware(t *testing.T) {
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(Logger())
|
||||||
|
|
||||||
|
router.GET("/test", func(c *gin.Context) {
|
||||||
|
c.String(200, "OK")
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/test", nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecoveryMiddleware(t *testing.T) {
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(Recovery())
|
||||||
|
|
||||||
|
router.GET("/panic", func(c *gin.Context) {
|
||||||
|
panic("test panic")
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/panic", nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 500, w.Code)
|
||||||
|
// The response should contain the error message
|
||||||
|
assert.Contains(t, w.Body.String(), "Internal server error")
|
||||||
|
}
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/pkg/metrics"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Metrics provides middleware for collecting HTTP request metrics
|
||||||
|
func Metrics() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
path := c.Request.URL.Path
|
||||||
|
method := c.Request.Method
|
||||||
|
|
||||||
|
// Time before request processing starts
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
c.Next() // Process the request
|
||||||
|
|
||||||
|
// Collect response time statistics after request processing completes
|
||||||
|
duration := time.Since(start)
|
||||||
|
status := strconv.Itoa(c.Writer.Status())
|
||||||
|
|
||||||
|
// Collect metrics information
|
||||||
|
metrics.TotalRequests.WithLabelValues(method, path, status).Inc() // Increment request count
|
||||||
|
metrics.RequestDuration.WithLabelValues(method, path).Observe(duration.Seconds()) // Record request response time
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMetricsMiddleware(t *testing.T) {
|
||||||
|
// Setup test router
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(Metrics())
|
||||||
|
router.GET("/test", func(c *gin.Context) {
|
||||||
|
c.String(http.StatusOK, "OK")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Make test request
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("GET", "/test", nil)
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// Verify response
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Equal(t, "OK", w.Body.String())
|
||||||
|
}
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
package router
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/internal/config"
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/internal/middleware"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
|
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/internal/gateway"
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/internal/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
func New(cfg *config.Config) *gin.Engine {
|
||||||
|
// Set Gin mode
|
||||||
|
gin.SetMode(cfg.Server.Mode)
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
|
||||||
|
// Global middleware
|
||||||
|
r.Use(middleware.Logger())
|
||||||
|
r.Use(middleware.Recovery())
|
||||||
|
|
||||||
|
// API route group
|
||||||
|
v1 := r.Group("/v1")
|
||||||
|
{
|
||||||
|
v1.POST("/chat/completions", gateway.ChatCompletions)
|
||||||
|
v1.POST("/completions", gateway.ChatCompletions)
|
||||||
|
}
|
||||||
|
r.POST("/register", manager.RegisterInstance)
|
||||||
|
r.GET("/registered_number", manager.RegisteredNumber)
|
||||||
|
r.GET("/registered", manager.Registered)
|
||||||
|
r.GET("/health_generate", manager.HealthGenerate)
|
||||||
|
r.GET("/metrics", gin.WrapH(promhttp.Handler()))
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
@@ -0,0 +1,55 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Counter struct {
|
||||||
|
count atomic.Uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Counter) Inc() {
|
||||||
|
c.count.Add(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Counter) Dec() {
|
||||||
|
c.count.Add(^uint64(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Counter) Get() uint64 {
|
||||||
|
return c.count.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenCounter records the number of tokens currently being processed by each P instance
|
||||||
|
type TokenCounter struct {
|
||||||
|
tokens atomic.Uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TokenCounter) Add(n uint64) {
|
||||||
|
c.tokens.Add(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TokenCounter) Get() uint64 {
|
||||||
|
return c.tokens.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TokenCounter) Sub(n uint64) {
|
||||||
|
if n == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
old := c.tokens.Load()
|
||||||
|
if old == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var newVal uint64
|
||||||
|
if old <= n {
|
||||||
|
newVal = 0
|
||||||
|
} else {
|
||||||
|
newVal = old - n
|
||||||
|
}
|
||||||
|
if c.tokens.CompareAndSwap(old, newVal) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
type SelectStrategyFunc func(ctx context.Context, workers []string, message string) (string, error)
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"math"
|
||||||
|
)
|
||||||
|
|
||||||
|
func computeScore(ctx context.Context, runningCnt int, waitingCnt int) float64 {
|
||||||
|
score := float64(runningCnt) + float64(waitingCnt)*waitingWeight
|
||||||
|
return score
|
||||||
|
}
|
||||||
|
|
||||||
|
func FDMetricsScoreSelectWorker(ctx context.Context, workers []string, message string) (string, error) {
|
||||||
|
if len(workers) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
selectedURL string = ""
|
||||||
|
minScore float64 = math.MaxFloat64
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, w := range workers {
|
||||||
|
runningCnt, waitingCnt, _ := DefaultScheduler.managerAPI.GetMetrics(ctx, w)
|
||||||
|
score := computeScore(ctx, runningCnt, waitingCnt)
|
||||||
|
if score < minScore {
|
||||||
|
minScore = score
|
||||||
|
selectedURL = w
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return selectedURL, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,282 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
|
common "github.com/PaddlePaddle/FastDeploy/router/internal/common"
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/internal/config"
|
||||||
|
scheduler_common "github.com/PaddlePaddle/FastDeploy/router/internal/scheduler/common"
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Scheduler struct {
|
||||||
|
policy string
|
||||||
|
prefillPolicy string
|
||||||
|
decodePolicy string
|
||||||
|
IdCounterMap map[string]*scheduler_common.Counter
|
||||||
|
tokenMap map[string]*scheduler_common.TokenCounter
|
||||||
|
managerAPI common.ManagerAPI
|
||||||
|
prefillCache *prefillCacheStrategy
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
type CounterPolicy struct {
|
||||||
|
counter atomic.Uint64
|
||||||
|
prefillCounter atomic.Uint64
|
||||||
|
workerType string
|
||||||
|
}
|
||||||
|
|
||||||
|
var DefaultScheduler *Scheduler
|
||||||
|
var DefaultCounterPolicy *CounterPolicy
|
||||||
|
var waitingWeight float64
|
||||||
|
|
||||||
|
// Init initializes the scheduler with the given configuration and manager API
|
||||||
|
func Init(cfg *config.Config, managerAPI common.ManagerAPI) {
|
||||||
|
prefillCfg := &schedulerConfigSnapshot{
|
||||||
|
balanceAbsThreshold: cfg.Scheduler.BalanceAbsThreshold,
|
||||||
|
balanceRelThreshold: cfg.Scheduler.BalanceRelThreshold,
|
||||||
|
hitRatioWeight: cfg.Scheduler.HitRatioWeight,
|
||||||
|
loadBalanceWeight: cfg.Scheduler.LoadBalanceWeight,
|
||||||
|
cacheBlockSize: cfg.Scheduler.CacheBlockSize,
|
||||||
|
tokenizerURL: cfg.Scheduler.TokenizerURL,
|
||||||
|
tokenizerTimeout: time.Duration(cfg.Scheduler.TokenizerTimeoutSecs * float64(time.Second)),
|
||||||
|
}
|
||||||
|
|
||||||
|
scheduler := &Scheduler{
|
||||||
|
policy: cfg.Scheduler.Policy,
|
||||||
|
prefillPolicy: cfg.Scheduler.PrefillPolicy,
|
||||||
|
decodePolicy: cfg.Scheduler.DecodePolicy,
|
||||||
|
IdCounterMap: make(map[string]*scheduler_common.Counter),
|
||||||
|
tokenMap: make(map[string]*scheduler_common.TokenCounter),
|
||||||
|
managerAPI: managerAPI,
|
||||||
|
prefillCache: newPrefillCacheStrategy(prefillCfg),
|
||||||
|
}
|
||||||
|
counterPolicy := &CounterPolicy{}
|
||||||
|
DefaultScheduler = scheduler
|
||||||
|
DefaultCounterPolicy = counterPolicy
|
||||||
|
waitingWeight = cfg.Scheduler.WaitingWeight
|
||||||
|
}
|
||||||
|
|
||||||
|
// SelectWorker selects a worker based on the specified policy and worker type
|
||||||
|
func SelectWorker(ctx context.Context, workers []string, message string, workerType string) (string, error) {
|
||||||
|
if len(workers) == 0 {
|
||||||
|
return "", fmt.Errorf("no healthy workers available")
|
||||||
|
}
|
||||||
|
|
||||||
|
var policy string
|
||||||
|
switch workerType {
|
||||||
|
case "prefill":
|
||||||
|
policy = DefaultScheduler.prefillPolicy
|
||||||
|
DefaultCounterPolicy.workerType = "prefill"
|
||||||
|
case "decode":
|
||||||
|
policy = DefaultScheduler.decodePolicy
|
||||||
|
DefaultCounterPolicy.workerType = "decode"
|
||||||
|
default:
|
||||||
|
policy = DefaultScheduler.policy
|
||||||
|
DefaultCounterPolicy.workerType = "mixed"
|
||||||
|
}
|
||||||
|
|
||||||
|
var strategyFunc scheduler_common.SelectStrategyFunc
|
||||||
|
switch policy {
|
||||||
|
case "random":
|
||||||
|
strategyFunc = RandomSelectWorker
|
||||||
|
case "round_robin":
|
||||||
|
strategyFunc = RoundRobinSelectWorker
|
||||||
|
case "power_of_two":
|
||||||
|
strategyFunc = PowerOfTwoSelectWorker
|
||||||
|
case "process_tokens":
|
||||||
|
// Prefill: prioritize the instance with the smallest number of tokens currently being processed
|
||||||
|
strategyFunc = ProcessTokensSelectWorker
|
||||||
|
case "request_num":
|
||||||
|
// Decode/mixed: prioritize the instance with the smallest number of current requests
|
||||||
|
strategyFunc = RequestNumSelectWorker
|
||||||
|
case "fd_metrics_score":
|
||||||
|
strategyFunc = FDMetricsScoreSelectWorker
|
||||||
|
case "cache_aware":
|
||||||
|
strategyFunc = CacheAwarePrefillSelectWorker
|
||||||
|
default:
|
||||||
|
strategyFunc = RandomSelectWorker
|
||||||
|
}
|
||||||
|
|
||||||
|
selectWorkerURL, err := strategyFunc(ctx, workers, message)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("select worker failed [policy: %s]: %w", DefaultScheduler.policy, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.HasPrefix(selectWorkerURL, "http://") && !strings.HasPrefix(selectWorkerURL, "https://") {
|
||||||
|
selectWorkerURL = "http://" + selectWorkerURL
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1) All node types: request concurrency count (request_num)
|
||||||
|
counter := GetOrCreateCounter(ctx, selectWorkerURL)
|
||||||
|
counter.Inc()
|
||||||
|
count := counter.Get()
|
||||||
|
|
||||||
|
// 2) Prefill: current token processing count (process_tokens)
|
||||||
|
var tokens uint64
|
||||||
|
if workerType == "prefill" && message != "" {
|
||||||
|
tokenCounter := GetOrCreateTokenCounter(ctx, selectWorkerURL)
|
||||||
|
tokenCounter.Add(estimateTokens(message))
|
||||||
|
tokens = tokenCounter.Get()
|
||||||
|
}
|
||||||
|
|
||||||
|
if workerType == "prefill" {
|
||||||
|
logger.Info("select worker (prefill): %s, tokens: %d", selectWorkerURL, tokens)
|
||||||
|
} else {
|
||||||
|
logger.Info("select worker (%s): %s, count: %d", workerType, selectWorkerURL, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
return selectWorkerURL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Release decreases the counter for the specified worker URL
|
||||||
|
func Release(ctx context.Context, url string) {
|
||||||
|
counter := GetOrCreateCounter(ctx, url)
|
||||||
|
counter.Dec()
|
||||||
|
logger.Info("release worker: %s, count: %d", url, counter.Get())
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCounter retrieves the counter for the specified root URL
|
||||||
|
func GetCounter(ctx context.Context, rootURL string) (*scheduler_common.Counter, bool) {
|
||||||
|
DefaultScheduler.mu.RLock()
|
||||||
|
defer DefaultScheduler.mu.RUnlock()
|
||||||
|
counter, exists := DefaultScheduler.IdCounterMap[rootURL]
|
||||||
|
return counter, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrCreateCounter retrieves an existing counter or creates a new one
|
||||||
|
func GetOrCreateCounter(ctx context.Context, url string) *scheduler_common.Counter {
|
||||||
|
counter, exists := GetCounter(ctx, url)
|
||||||
|
if exists {
|
||||||
|
return counter
|
||||||
|
}
|
||||||
|
DefaultScheduler.mu.Lock()
|
||||||
|
defer DefaultScheduler.mu.Unlock()
|
||||||
|
// Double check: avoid overwriting what other goroutines may have created before acquiring write lock
|
||||||
|
if counter, exists = DefaultScheduler.IdCounterMap[url]; exists {
|
||||||
|
return counter
|
||||||
|
}
|
||||||
|
newCounter := &scheduler_common.Counter{}
|
||||||
|
DefaultScheduler.IdCounterMap[url] = newCounter
|
||||||
|
return newCounter
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanupUnhealthyCounter removes counters for unhealthy worker URLs
|
||||||
|
func CleanupUnhealthyCounter(ctx context.Context, unhealthyRootURL string) {
|
||||||
|
if unhealthyRootURL == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if DefaultScheduler == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
DefaultScheduler.mu.Lock()
|
||||||
|
defer DefaultScheduler.mu.Unlock()
|
||||||
|
|
||||||
|
delete(DefaultScheduler.IdCounterMap, unhealthyRootURL)
|
||||||
|
delete(DefaultScheduler.tokenMap, unhealthyRootURL)
|
||||||
|
logger.Info("After cleanup unhealthy counter: %v", DefaultScheduler.IdCounterMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanupInvalidCounters removes counters for invalid or unreachable workers
|
||||||
|
func CleanupInvalidCounters(ctx context.Context) {
|
||||||
|
if DefaultScheduler == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if DefaultScheduler.managerAPI == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
healthyURLs := DefaultScheduler.managerAPI.GetHealthyURLs(ctx)
|
||||||
|
if len(healthyURLs) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
healthyMap := make(map[string]bool)
|
||||||
|
for _, rootURL := range healthyURLs {
|
||||||
|
healthyMap[rootURL] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
DefaultScheduler.mu.Lock()
|
||||||
|
defer DefaultScheduler.mu.Unlock()
|
||||||
|
|
||||||
|
for rootURL := range DefaultScheduler.IdCounterMap {
|
||||||
|
if _, exists := healthyMap[rootURL]; !exists {
|
||||||
|
delete(DefaultScheduler.IdCounterMap, rootURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for rootURL := range DefaultScheduler.tokenMap {
|
||||||
|
if _, exists := healthyMap[rootURL]; !exists {
|
||||||
|
delete(DefaultScheduler.tokenMap, rootURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("After cleanup invalid counters: %v", DefaultScheduler.IdCounterMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartBackupCleanupTask starts a background task for cleaning up invalid counters
|
||||||
|
func StartBackupCleanupTask(ctx context.Context, interval float64) {
|
||||||
|
ticker := time.NewTicker(time.Duration(interval * float64(time.Second)))
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
// case 1: listen for context cancellation/timeout events → graceful exit
|
||||||
|
case <-ctx.Done():
|
||||||
|
return // Exit loop, stop cleanup task
|
||||||
|
// case 2: listen for timer trigger events → perform cleanup
|
||||||
|
case <-ticker.C:
|
||||||
|
CleanupInvalidCounters(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTokenCounter gets the TokenCounter for the specified instance
|
||||||
|
func GetTokenCounter(ctx context.Context, rootURL string) (*scheduler_common.TokenCounter, bool) {
|
||||||
|
DefaultScheduler.mu.RLock()
|
||||||
|
defer DefaultScheduler.mu.RUnlock()
|
||||||
|
counter, exists := DefaultScheduler.tokenMap[rootURL]
|
||||||
|
return counter, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrCreateTokenCounter gets or creates TokenCounter
|
||||||
|
func GetOrCreateTokenCounter(ctx context.Context, url string) *scheduler_common.TokenCounter {
|
||||||
|
counter, exists := GetTokenCounter(ctx, url)
|
||||||
|
if exists {
|
||||||
|
return counter
|
||||||
|
}
|
||||||
|
DefaultScheduler.mu.Lock()
|
||||||
|
defer DefaultScheduler.mu.Unlock()
|
||||||
|
// Double check to avoid overwriting
|
||||||
|
if counter, exists = DefaultScheduler.tokenMap[url]; exists {
|
||||||
|
return counter
|
||||||
|
}
|
||||||
|
newCounter := &scheduler_common.TokenCounter{}
|
||||||
|
DefaultScheduler.tokenMap[url] = newCounter
|
||||||
|
return newCounter
|
||||||
|
}
|
||||||
|
|
||||||
|
// estimateTokens estimates token count based on character count: character count * 2
|
||||||
|
func estimateTokens(message string) uint64 {
|
||||||
|
if message == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
runeCount := utf8.RuneCountInString(message)
|
||||||
|
return uint64(runeCount * 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReleasePrefillTokens releases the corresponding token load when request ends
|
||||||
|
func ReleasePrefillTokens(ctx context.Context, url, message string) {
|
||||||
|
if url == "" || message == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tokenCounter := GetOrCreateTokenCounter(ctx, url)
|
||||||
|
tokenCounter.Sub(estimateTokens(message))
|
||||||
|
logger.Info("release prefill tokens: %s, tokens: %d", url, tokenCounter.Get())
|
||||||
|
}
|
||||||
@@ -0,0 +1,213 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/internal/config"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockManagerAPI struct{}
|
||||||
|
|
||||||
|
func (m *mockManagerAPI) GetHealthyURLs(ctx context.Context) []string {
|
||||||
|
return []string{"worker1", "worker2"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockManagerAPI) GetMetrics(ctx context.Context, url string) (int, int, int) {
|
||||||
|
return 0, 0, 0 // 返回默认值用于测试
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSchedulerInit(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Scheduler: config.SchedulerConfig{
|
||||||
|
Policy: "random",
|
||||||
|
PrefillPolicy: "process_tokens",
|
||||||
|
DecodePolicy: "request_num",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
Init(cfg, &mockManagerAPI{})
|
||||||
|
|
||||||
|
assert.NotNil(t, DefaultScheduler)
|
||||||
|
assert.Equal(t, "random", DefaultScheduler.policy)
|
||||||
|
assert.Equal(t, "process_tokens", DefaultScheduler.prefillPolicy)
|
||||||
|
assert.Equal(t, "request_num", DefaultScheduler.decodePolicy)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectWorker(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
workers := []string{"worker1", "worker2", "worker3"}
|
||||||
|
|
||||||
|
Init(&config.Config{
|
||||||
|
Scheduler: config.SchedulerConfig{
|
||||||
|
Policy: "random",
|
||||||
|
PrefillPolicy: "process_tokens",
|
||||||
|
DecodePolicy: "request_num",
|
||||||
|
},
|
||||||
|
}, &mockManagerAPI{})
|
||||||
|
|
||||||
|
t.Run("prefill worker selection", func(t *testing.T) {
|
||||||
|
// Set up token counts
|
||||||
|
tc1 := GetOrCreateTokenCounter(ctx, "worker1")
|
||||||
|
tc1.Add(100)
|
||||||
|
tc2 := GetOrCreateTokenCounter(ctx, "worker2")
|
||||||
|
tc2.Add(50) // Should be selected
|
||||||
|
tc3 := GetOrCreateTokenCounter(ctx, "worker3")
|
||||||
|
tc3.Add(200)
|
||||||
|
|
||||||
|
selected, err := SelectWorker(ctx, workers, "test message", "prefill")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "http://worker2", selected)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("decode worker selection", func(t *testing.T) {
|
||||||
|
// Set up request counts
|
||||||
|
c1 := GetOrCreateCounter(ctx, "worker1")
|
||||||
|
c1.Inc()
|
||||||
|
c1.Inc() // count = 2
|
||||||
|
c2 := GetOrCreateCounter(ctx, "worker2") // count = 0 (should be selected)
|
||||||
|
c3 := GetOrCreateCounter(ctx, "worker3")
|
||||||
|
c3.Inc() // count = 1
|
||||||
|
|
||||||
|
// Verify counts
|
||||||
|
assert.Equal(t, uint64(2), c1.Get())
|
||||||
|
assert.Equal(t, uint64(0), c2.Get())
|
||||||
|
assert.Equal(t, uint64(1), c3.Get())
|
||||||
|
|
||||||
|
selected, err := SelectWorker(ctx, workers, "test", "decode")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "http://worker2", selected)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCounterOperations(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
Init(&config.Config{}, nil)
|
||||||
|
|
||||||
|
t.Run("counter increment", func(t *testing.T) {
|
||||||
|
counter := GetOrCreateCounter(ctx, "test")
|
||||||
|
assert.Equal(t, uint64(0), counter.Get())
|
||||||
|
|
||||||
|
counter.Inc()
|
||||||
|
assert.Equal(t, uint64(1), counter.Get())
|
||||||
|
|
||||||
|
counter.Dec()
|
||||||
|
assert.Equal(t, uint64(0), counter.Get())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("token counter operations", func(t *testing.T) {
|
||||||
|
tc := GetOrCreateTokenCounter(ctx, "test")
|
||||||
|
assert.Equal(t, uint64(0), tc.Get())
|
||||||
|
|
||||||
|
tc.Add(100)
|
||||||
|
assert.Equal(t, uint64(100), tc.Get())
|
||||||
|
|
||||||
|
tc.Sub(50)
|
||||||
|
assert.Equal(t, uint64(50), tc.Get())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCleanupInvalidCounters(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
Init(&config.Config{}, &mockManagerAPI{})
|
||||||
|
|
||||||
|
// Add some counters
|
||||||
|
c1 := GetOrCreateCounter(ctx, "worker1")
|
||||||
|
c1.Inc()
|
||||||
|
GetOrCreateCounter(ctx, "invalid-worker") // Should be cleaned up
|
||||||
|
|
||||||
|
tc1 := GetOrCreateTokenCounter(ctx, "worker1")
|
||||||
|
tc1.Add(100)
|
||||||
|
GetOrCreateTokenCounter(ctx, "invalid-worker") // Should be cleaned up
|
||||||
|
|
||||||
|
CleanupInvalidCounters(ctx)
|
||||||
|
|
||||||
|
// Verify counters
|
||||||
|
_, exists := GetCounter(ctx, "worker1")
|
||||||
|
assert.True(t, exists)
|
||||||
|
_, exists = GetCounter(ctx, "invalid-worker")
|
||||||
|
assert.False(t, exists)
|
||||||
|
|
||||||
|
// Verify token counters
|
||||||
|
_, exists = GetTokenCounter(ctx, "worker1")
|
||||||
|
assert.True(t, exists)
|
||||||
|
_, exists = GetTokenCounter(ctx, "invalid-worker")
|
||||||
|
assert.False(t, exists)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEstimateTokens(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected uint64
|
||||||
|
}{
|
||||||
|
{"", 0},
|
||||||
|
{"hello", 10}, // 5 chars * 2
|
||||||
|
{"你好", 4}, // 2 chars * 2 (Chinese characters count as 1 char each)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.input, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.expected, estimateTokens(tt.input))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReleasePrefillTokens(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
Init(&config.Config{}, nil)
|
||||||
|
|
||||||
|
t.Run("valid release", func(t *testing.T) {
|
||||||
|
tc := GetOrCreateTokenCounter(ctx, "worker1")
|
||||||
|
tc.Add(100)
|
||||||
|
ReleasePrefillTokens(ctx, "worker1", "hello") // 5 chars * 2 = 10 tokens
|
||||||
|
assert.Equal(t, uint64(90), tc.Get())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty url or message", func(t *testing.T) {
|
||||||
|
tc := GetOrCreateTokenCounter(ctx, "worker2")
|
||||||
|
tc.Add(100)
|
||||||
|
ReleasePrefillTokens(ctx, "", "hello") // no-op
|
||||||
|
ReleasePrefillTokens(ctx, "worker2", "") // no-op
|
||||||
|
assert.Equal(t, uint64(100), tc.Get())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCleanupUnhealthyCounter(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
Init(&config.Config{}, nil)
|
||||||
|
|
||||||
|
// Add counters
|
||||||
|
c := GetOrCreateCounter(ctx, "unhealthy-worker")
|
||||||
|
c.Inc()
|
||||||
|
tc := GetOrCreateTokenCounter(ctx, "unhealthy-worker")
|
||||||
|
tc.Add(100)
|
||||||
|
|
||||||
|
CleanupUnhealthyCounter(ctx, "unhealthy-worker")
|
||||||
|
|
||||||
|
// Verify cleanup
|
||||||
|
_, exists := GetCounter(ctx, "unhealthy-worker")
|
||||||
|
assert.False(t, exists)
|
||||||
|
_, exists = GetTokenCounter(ctx, "unhealthy-worker")
|
||||||
|
assert.False(t, exists)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartBackupCleanupTask(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
Init(&config.Config{}, &mockManagerAPI{})
|
||||||
|
|
||||||
|
// Add invalid counter
|
||||||
|
GetOrCreateCounter(ctx, "invalid-worker")
|
||||||
|
|
||||||
|
// Start cleanup task with short interval
|
||||||
|
go StartBackupCleanupTask(ctx, 0.1) // 0.1 second interval
|
||||||
|
|
||||||
|
// Wait for cleanup
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
// Verify cleanup
|
||||||
|
_, exists := GetCounter(ctx, "invalid-worker")
|
||||||
|
assert.False(t, exists)
|
||||||
|
}
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"math"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProcessTokensSelectWorker selects the instance with the smallest number of tokens currently being processed for Prefill nodes.
|
||||||
|
func ProcessTokensSelectWorker(ctx context.Context, workers []string, message string) (string, error) {
|
||||||
|
if len(workers) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
selected string
|
||||||
|
minTokens uint64 = math.MaxUint64
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, w := range workers {
|
||||||
|
tc := GetOrCreateTokenCounter(ctx, w)
|
||||||
|
load := tc.Get()
|
||||||
|
if load < minTokens {
|
||||||
|
minTokens = load
|
||||||
|
selected = w
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return selected, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestNumSelectWorker selects the instance with the smallest number of current requests for Decode nodes.
|
||||||
|
func RequestNumSelectWorker(ctx context.Context, workers []string, message string) (string, error) {
|
||||||
|
if len(workers) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
selected string
|
||||||
|
minCount uint64 = math.MaxUint64
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, w := range workers {
|
||||||
|
c := GetOrCreateCounter(ctx, w)
|
||||||
|
load := c.Get()
|
||||||
|
if load < minCount {
|
||||||
|
minCount = load
|
||||||
|
selected = w
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return selected, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,79 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/internal/config"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestProcessTokensSelectWorker(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Setup test data
|
||||||
|
workers := []string{"worker1", "worker2", "worker3"}
|
||||||
|
|
||||||
|
// Initialize scheduler and token counters
|
||||||
|
Init(&config.Config{
|
||||||
|
Scheduler: config.SchedulerConfig{
|
||||||
|
Policy: "process_tokens",
|
||||||
|
},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
t.Run("select worker with least tokens", func(t *testing.T) {
|
||||||
|
// Set up token counts
|
||||||
|
tc1 := GetOrCreateTokenCounter(ctx, "worker1")
|
||||||
|
tc1.Add(100)
|
||||||
|
tc2 := GetOrCreateTokenCounter(ctx, "worker2")
|
||||||
|
tc2.Add(50) // Should be selected
|
||||||
|
tc3 := GetOrCreateTokenCounter(ctx, "worker3")
|
||||||
|
tc3.Add(200)
|
||||||
|
|
||||||
|
selected, err := ProcessTokensSelectWorker(ctx, workers, "test message")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "worker2", selected)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty workers list", func(t *testing.T) {
|
||||||
|
selected, err := ProcessTokensSelectWorker(ctx, []string{}, "test")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "", selected)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestNumSelectWorker(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
workers := []string{"worker1", "worker2", "worker3"}
|
||||||
|
|
||||||
|
Init(&config.Config{
|
||||||
|
Scheduler: config.SchedulerConfig{
|
||||||
|
Policy: "request_num",
|
||||||
|
},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
t.Run("select worker with least requests", func(t *testing.T) {
|
||||||
|
// Set up request counts
|
||||||
|
c1 := GetOrCreateCounter(ctx, "worker1")
|
||||||
|
c1.Inc()
|
||||||
|
c1.Inc() // count = 2
|
||||||
|
c2 := GetOrCreateCounter(ctx, "worker2") // count = 0 (should be selected)
|
||||||
|
c3 := GetOrCreateCounter(ctx, "worker3")
|
||||||
|
c3.Inc() // count = 1
|
||||||
|
|
||||||
|
// Verify counts (use variables to avoid "declared and not used" error)
|
||||||
|
assert.Equal(t, uint64(2), c1.Get())
|
||||||
|
assert.Equal(t, uint64(0), c2.Get())
|
||||||
|
assert.Equal(t, uint64(1), c3.Get())
|
||||||
|
|
||||||
|
selected, err := RequestNumSelectWorker(ctx, workers, "test")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "worker2", selected)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty workers list", func(t *testing.T) {
|
||||||
|
selected, err := RequestNumSelectWorker(ctx, []string{}, "test")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "", selected)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"math/rand"
|
||||||
|
)
|
||||||
|
|
||||||
|
func PowerOfTwoSelectWorker(ctx context.Context, workers []string, message string) (string, error) {
|
||||||
|
if len(workers) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
if len(workers) == 1 {
|
||||||
|
return workers[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
length := len(workers)
|
||||||
|
randomNum1 := rand.Intn(length)
|
||||||
|
randomNum2 := rand.Intn(length)
|
||||||
|
|
||||||
|
for randomNum2 == randomNum1 {
|
||||||
|
randomNum2 = rand.Intn(length)
|
||||||
|
}
|
||||||
|
|
||||||
|
worker1 := workers[randomNum1]
|
||||||
|
worker2 := workers[randomNum2]
|
||||||
|
|
||||||
|
counter1 := GetOrCreateCounter(ctx, worker1)
|
||||||
|
counter2 := GetOrCreateCounter(ctx, worker2)
|
||||||
|
load1 := counter1.Get()
|
||||||
|
load2 := counter2.Get()
|
||||||
|
|
||||||
|
var selectedURL string
|
||||||
|
if load1 <= load2 {
|
||||||
|
selectedURL = worker1
|
||||||
|
} else {
|
||||||
|
selectedURL = worker2
|
||||||
|
}
|
||||||
|
return selectedURL, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,545 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"hash/fnv"
|
||||||
|
"math"
|
||||||
|
"math/rand"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/PaddlePaddle/FastDeploy/router/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
type prefillCacheStrategy struct {
|
||||||
|
absThreshold float64
|
||||||
|
relThreshold float64
|
||||||
|
hitRatioWeight float64
|
||||||
|
loadBalanceWeight float64
|
||||||
|
cache *radixPrefixCache
|
||||||
|
tokenizer TokenizerClient
|
||||||
|
}
|
||||||
|
|
||||||
|
type schedulerConfigSnapshot struct {
|
||||||
|
balanceAbsThreshold float64
|
||||||
|
balanceRelThreshold float64
|
||||||
|
hitRatioWeight float64
|
||||||
|
loadBalanceWeight float64
|
||||||
|
cacheBlockSize int
|
||||||
|
tokenizerURL string
|
||||||
|
tokenizerTimeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// newPrefillCacheStrategy initializes cache-aware strategy config
|
||||||
|
func newPrefillCacheStrategy(cfg *schedulerConfigSnapshot) *prefillCacheStrategy {
|
||||||
|
return &prefillCacheStrategy{
|
||||||
|
absThreshold: cfg.balanceAbsThreshold,
|
||||||
|
relThreshold: cfg.balanceRelThreshold,
|
||||||
|
hitRatioWeight: cfg.hitRatioWeight,
|
||||||
|
loadBalanceWeight: cfg.loadBalanceWeight,
|
||||||
|
cache: newRadixPrefixCache(cfg.cacheBlockSize),
|
||||||
|
tokenizer: NewHTTPTokenizer(cfg.tokenizerURL, cfg.tokenizerTimeout),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CacheAwarePrefillSelectWorker fallbacks to min tokens on extreme imbalance; otherwise scores by hit rate and load
|
||||||
|
func CacheAwarePrefillSelectWorker(ctx context.Context, workers []string, message string) (string, error) {
|
||||||
|
if len(workers) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
if DefaultScheduler == nil || DefaultScheduler.prefillCache == nil {
|
||||||
|
return ProcessTokensSelectWorker(ctx, workers, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
strategy := DefaultScheduler.prefillCache
|
||||||
|
|
||||||
|
// 1) Fetch node load; fallback to min tokens on extreme imbalance
|
||||||
|
loads := strategy.getRunningRequests(ctx, workers)
|
||||||
|
if strategy.isLoadImbalanced(loads) {
|
||||||
|
return ProcessTokensSelectWorker(ctx, workers, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2)tokenize
|
||||||
|
tokens, err := strategy.tokenize(ctx, message)
|
||||||
|
if err != nil || len(tokens) == 0 {
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("cache-aware prefill: tokenizer failed, fallback to process_tokens: %v", err)
|
||||||
|
}
|
||||||
|
return ProcessTokensSelectWorker(ctx, workers, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3) Compute prefix tree hit rate
|
||||||
|
hitRatios := strategy.cache.Match(tokens, toWorkerSet(workers))
|
||||||
|
logger.Debug("cache-aware prefill: hashes=%d workers=%d load=%v hit=%v", len(strategy.cache.hasher.prefixHashes(tokens)), len(workers), loads, hitRatios)
|
||||||
|
|
||||||
|
// 4) Compute weighted score from hit rate and load
|
||||||
|
selected := strategy.chooseByScore(ctx, workers, loads, hitRatios)
|
||||||
|
|
||||||
|
// 5) Record prefix
|
||||||
|
strategy.cache.Record(tokens, selected)
|
||||||
|
logger.Debug("cache-aware prefill: selected=%s", selected)
|
||||||
|
return selected, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tokenize calls remote tokenizer service
|
||||||
|
func (p *prefillCacheStrategy) tokenize(ctx context.Context, message string) ([]int, error) {
|
||||||
|
if message == "" {
|
||||||
|
return nil, errors.New("empty prompt for tokenizer")
|
||||||
|
}
|
||||||
|
if p.tokenizer == nil {
|
||||||
|
// Fallback to character-based tokenization
|
||||||
|
return charsToTokens(message), nil
|
||||||
|
}
|
||||||
|
tokens, err := p.tokenizer.Tokenize(ctx, message)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("cache-aware prefill: tokenizer failed, fallback to char tokens: %v", err)
|
||||||
|
return charsToTokens(message), nil
|
||||||
|
}
|
||||||
|
logger.Debug("cache-aware prefill: tokenizer tokens=%v", tokens)
|
||||||
|
return tokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isLoadImbalanced determines if load is imbalanced
|
||||||
|
func (p *prefillCacheStrategy) isLoadImbalanced(loads map[string]uint64) bool {
|
||||||
|
if len(loads) < 2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
maxLoad := uint64(0)
|
||||||
|
minLoad := uint64(math.MaxUint64)
|
||||||
|
for _, v := range loads {
|
||||||
|
if v > maxLoad {
|
||||||
|
maxLoad = v
|
||||||
|
}
|
||||||
|
if v < minLoad {
|
||||||
|
minLoad = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if maxLoad == minLoad {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
diff := float64(maxLoad - minLoad)
|
||||||
|
relative := diff / float64(maxLoad)
|
||||||
|
|
||||||
|
return diff > p.absThreshold && relative > p.relThreshold
|
||||||
|
}
|
||||||
|
|
||||||
|
// chooseByScore selects worker by cache hit rate and load
|
||||||
|
func (p *prefillCacheStrategy) chooseByScore(ctx context.Context, workers []string, loads map[string]uint64, hitRatios map[string]int) string {
|
||||||
|
if len(workers) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: reuse maxLoad from isLoadImbalanced
|
||||||
|
var maxLoad uint64
|
||||||
|
for _, w := range workers {
|
||||||
|
if v := loads[w]; v > maxLoad {
|
||||||
|
maxLoad = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bestScore := math.MaxFloat64
|
||||||
|
selected := ""
|
||||||
|
|
||||||
|
for _, w := range workers {
|
||||||
|
hit := float64(hitRatios[w])
|
||||||
|
loadRatio := 0.0
|
||||||
|
if maxLoad > 0 {
|
||||||
|
loadRatio = float64(loads[w]) / float64(maxLoad)
|
||||||
|
}
|
||||||
|
|
||||||
|
score := (100.0-hit)/100*p.hitRatioWeight + loadRatio*p.loadBalanceWeight
|
||||||
|
logger.Debug("cache-aware score: worker=%s hit=%.1f loadRatio=%.3f score=%.3f", w, hit, loadRatio, score)
|
||||||
|
|
||||||
|
if score < bestScore {
|
||||||
|
bestScore = score
|
||||||
|
selected = w
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tie-breaker: prefer lower token load if scores are equal
|
||||||
|
if score == bestScore && selected != "" {
|
||||||
|
selectedTokens := GetOrCreateTokenCounter(ctx, selected).Get()
|
||||||
|
currentTokens := GetOrCreateTokenCounter(ctx, w).Get()
|
||||||
|
if currentTokens < selectedTokens {
|
||||||
|
selected = w
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return selected
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRunningRequests retrieves running request metrics
|
||||||
|
func (p *prefillCacheStrategy) getRunningRequests(ctx context.Context, workers []string) map[string]uint64 {
|
||||||
|
result := make(map[string]uint64, len(workers))
|
||||||
|
if DefaultScheduler == nil || DefaultScheduler.managerAPI == nil {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, w := range workers {
|
||||||
|
running, _, _ := DefaultScheduler.managerAPI.GetMetrics(ctx, w)
|
||||||
|
result[w] = uint64(running)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track prefix hits using a radix tree keyed by block hash
|
||||||
|
type radixPrefixCache struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
root *radixNode
|
||||||
|
hasher *blockHasher
|
||||||
|
evictionDuration time.Duration
|
||||||
|
maxNodes int
|
||||||
|
nodeCount int
|
||||||
|
allNodes map[*radixNode]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type radixNode struct {
|
||||||
|
key []uint64
|
||||||
|
children map[uint64]*radixNode
|
||||||
|
parent *radixNode
|
||||||
|
workers map[string]time.Time
|
||||||
|
lastAccess time.Time
|
||||||
|
contextLen int
|
||||||
|
}
|
||||||
|
|
||||||
|
// newRadixPrefixCache initializes radix prefix cache with eviction and capacity control
|
||||||
|
func newRadixPrefixCache(blockSize int) *radixPrefixCache {
|
||||||
|
if blockSize <= 0 {
|
||||||
|
blockSize = 64
|
||||||
|
}
|
||||||
|
const defaultEvictionDuration = 5 * time.Minute
|
||||||
|
const defaultMaxNodes = 200000
|
||||||
|
root := &radixNode{
|
||||||
|
key: nil,
|
||||||
|
children: make(map[uint64]*radixNode),
|
||||||
|
contextLen: 0,
|
||||||
|
}
|
||||||
|
cache := &radixPrefixCache{
|
||||||
|
root: root,
|
||||||
|
hasher: newBlockHasher(blockSize),
|
||||||
|
evictionDuration: defaultEvictionDuration,
|
||||||
|
maxNodes: defaultMaxNodes,
|
||||||
|
nodeCount: 1, // root
|
||||||
|
allNodes: map[*radixNode]struct{}{root: {}},
|
||||||
|
}
|
||||||
|
go cache.evictionWorker(cache.evictionDuration / 2)
|
||||||
|
return cache
|
||||||
|
}
|
||||||
|
|
||||||
|
// Match returns prefix hit rate per candidate worker (0–100)
|
||||||
|
func (c *radixPrefixCache) Match(tokens []int, allowed map[string]struct{}) map[string]int {
|
||||||
|
result := make(map[string]int)
|
||||||
|
hashes := c.hasher.prefixHashes(tokens)
|
||||||
|
if len(hashes) == 0 {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.RLock()
|
||||||
|
node, matched := c.matchPrefixHelper(c.root, hashes)
|
||||||
|
length := matched
|
||||||
|
logger.Debug("radix match: hashes=%d matched_len=%d node_children=%d", len(hashes), matched, len(node.children))
|
||||||
|
for n := node; n != nil; n = n.parent {
|
||||||
|
ratio := 0
|
||||||
|
if len(hashes) > 0 {
|
||||||
|
ratio = length * 100 / len(hashes)
|
||||||
|
}
|
||||||
|
for w := range n.workers {
|
||||||
|
if allowed != nil {
|
||||||
|
if _, ok := allowed[w]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ratio > result[w] {
|
||||||
|
result[w] = ratio
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(result) > 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if n.parent != nil {
|
||||||
|
length = n.parent.contextLen
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.mu.RUnlock()
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record inserts block-hash prefix into radix tree and tags worker
|
||||||
|
func (c *radixPrefixCache) Record(tokens []int, worker string) {
|
||||||
|
if worker == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
hashes := c.hasher.prefixHashes(tokens)
|
||||||
|
if len(hashes) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
node := c.insertHelper(c.root, hashes)
|
||||||
|
now := time.Now()
|
||||||
|
for n := node; n != nil; n = n.parent {
|
||||||
|
if n.workers == nil {
|
||||||
|
n.workers = make(map[string]time.Time)
|
||||||
|
}
|
||||||
|
n.workers[worker] = now
|
||||||
|
}
|
||||||
|
logger.Debug("radix record: worker=%s hashes=%d node_depth=%d", worker, len(hashes), node.contextLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
// evictionWorker periodically evicts inactive nodes
|
||||||
|
func (c *radixPrefixCache) evictionWorker(interval time.Duration) {
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
<-ticker.C
|
||||||
|
c.evictExpired()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *radixPrefixCache) evictExpired() {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
now := time.Now()
|
||||||
|
removed := 0
|
||||||
|
for childKey, child := range c.root.children {
|
||||||
|
removed += c.evictSubtreeIfExpired(c.root, childKey, child, now)
|
||||||
|
}
|
||||||
|
if removed > 0 {
|
||||||
|
logger.Debug("radix eviction: removed=%d nodeCount=%d", removed, c.nodeCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// evictSubtreeIfExpired evicts expired nodes and subtrees, returns count of removed nodes
|
||||||
|
func (c *radixPrefixCache) evictSubtreeIfExpired(parent *radixNode, childKey uint64, node *radixNode, now time.Time) int {
|
||||||
|
// Process child nodes first
|
||||||
|
removed := 0
|
||||||
|
for k, child := range node.children {
|
||||||
|
removed += c.evictSubtreeIfExpired(node, k, child, now)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do not delete root node
|
||||||
|
if parent == nil {
|
||||||
|
return removed
|
||||||
|
}
|
||||||
|
|
||||||
|
if now.Sub(node.lastAccess) <= c.evictionDuration {
|
||||||
|
return removed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete expired node and its subtree
|
||||||
|
if parent != nil {
|
||||||
|
delete(parent.children, childKey)
|
||||||
|
}
|
||||||
|
removedSubtree := c.countSubtree(node)
|
||||||
|
c.nodeCount -= removedSubtree
|
||||||
|
if c.nodeCount < 1 {
|
||||||
|
c.nodeCount = 1 // At least include root
|
||||||
|
}
|
||||||
|
c.removeSubtreeFromAll(node)
|
||||||
|
return removed + removedSubtree
|
||||||
|
}
|
||||||
|
|
||||||
|
// countSubtree counts nodes in subtree rooted at node
|
||||||
|
func (c *radixPrefixCache) countSubtree(node *radixNode) int {
|
||||||
|
count := 1
|
||||||
|
for _, child := range node.children {
|
||||||
|
count += c.countSubtree(child)
|
||||||
|
}
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeSubtreeFromAll removes subtree references from allNodes
|
||||||
|
func (c *radixPrefixCache) removeSubtreeFromAll(node *radixNode) {
|
||||||
|
if node == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
delete(c.allNodes, node)
|
||||||
|
for _, child := range node.children {
|
||||||
|
c.removeSubtreeFromAll(child)
|
||||||
|
}
|
||||||
|
// Release references for GC
|
||||||
|
node.children = nil
|
||||||
|
node.parent = nil
|
||||||
|
node.workers = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchPrefixHelper finds longest common prefix node in radix tree
|
||||||
|
func (c *radixPrefixCache) matchPrefixHelper(node *radixNode, hashes []uint64) (*radixNode, int) {
|
||||||
|
if len(hashes) == 0 {
|
||||||
|
return node, node.contextLen
|
||||||
|
}
|
||||||
|
|
||||||
|
if child, ok := node.children[hashes[0]]; ok {
|
||||||
|
prefixLen := matchUint64Len(child.key, hashes)
|
||||||
|
if prefixLen > 0 {
|
||||||
|
if prefixLen == len(child.key) {
|
||||||
|
if prefixLen == len(hashes) {
|
||||||
|
return child, child.contextLen
|
||||||
|
}
|
||||||
|
if deeperNode, deeperMatched := c.matchPrefixHelper(child, hashes[prefixLen:]); deeperNode != nil && deeperMatched > 0 {
|
||||||
|
return deeperNode, deeperMatched
|
||||||
|
}
|
||||||
|
return child, child.contextLen
|
||||||
|
}
|
||||||
|
return child, node.contextLen + prefixLen
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return node, node.contextLen
|
||||||
|
}
|
||||||
|
|
||||||
|
// insertHelper inserts hash sequence into radix tree, splits nodes if needed
|
||||||
|
func (c *radixPrefixCache) insertHelper(node *radixNode, key []uint64) *radixNode {
|
||||||
|
node.lastAccess = time.Now()
|
||||||
|
|
||||||
|
if len(key) == 0 {
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
if child, ok := node.children[key[0]]; ok {
|
||||||
|
prefixLen := matchUint64Len(child.key, key)
|
||||||
|
|
||||||
|
if prefixLen == len(child.key) {
|
||||||
|
if prefixLen == len(key) {
|
||||||
|
child.lastAccess = time.Now()
|
||||||
|
return child
|
||||||
|
}
|
||||||
|
return c.insertHelper(child, key[prefixLen:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Partial match, split required
|
||||||
|
newNode := c.splitNode(node, child, prefixLen)
|
||||||
|
if prefixLen == len(key) {
|
||||||
|
return newNode
|
||||||
|
}
|
||||||
|
return c.insertHelper(newNode, key[prefixLen:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// No matching child, create new node and add to children
|
||||||
|
newNode := newRadixNode(node, key)
|
||||||
|
node.children[key[0]] = newNode
|
||||||
|
c.nodeCount++
|
||||||
|
c.allNodes[newNode] = struct{}{}
|
||||||
|
c.maybeEvictLocked()
|
||||||
|
return newNode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *radixPrefixCache) splitNode(parent *radixNode, child *radixNode, prefixLen int) *radixNode {
|
||||||
|
commonKey := append([]uint64{}, child.key[:prefixLen]...)
|
||||||
|
|
||||||
|
newNode := newRadixNode(parent, commonKey)
|
||||||
|
parent.children[commonKey[0]] = newNode
|
||||||
|
|
||||||
|
// Adjust atomic node
|
||||||
|
child.key = append([]uint64{}, child.key[prefixLen:]...)
|
||||||
|
child.parent = newNode
|
||||||
|
child.contextLen = newNode.contextLen + len(child.key)
|
||||||
|
|
||||||
|
if len(child.key) > 0 {
|
||||||
|
newNode.children[child.key[0]] = child
|
||||||
|
}
|
||||||
|
return newNode
|
||||||
|
}
|
||||||
|
|
||||||
|
// maybeEvictLocked checks node count under write lock and evicts expired nodes if over capacity
|
||||||
|
func (c *radixPrefixCache) maybeEvictLocked() {
|
||||||
|
if c.maxNodes <= 0 || c.nodeCount <= c.maxNodes {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.evictExpired()
|
||||||
|
// TODO: implement stronger eviction if still over capacity (e.g., evict oldest by lastAccess)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newRadixNode creates radix tree node and computes context length
|
||||||
|
func newRadixNode(parent *radixNode, key []uint64) *radixNode {
|
||||||
|
n := &radixNode{
|
||||||
|
key: append([]uint64{}, key...),
|
||||||
|
children: make(map[uint64]*radixNode),
|
||||||
|
parent: parent,
|
||||||
|
lastAccess: time.Now(),
|
||||||
|
}
|
||||||
|
if parent != nil {
|
||||||
|
n.contextLen = parent.contextLen + len(key)
|
||||||
|
} else {
|
||||||
|
n.contextLen = len(key)
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
type blockHasher struct {
|
||||||
|
blockSize int
|
||||||
|
seed uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// newBlockHasher creates and initializes a new block hasher
|
||||||
|
func newBlockHasher(blockSize int) *blockHasher {
|
||||||
|
if blockSize <= 0 {
|
||||||
|
blockSize = 64
|
||||||
|
}
|
||||||
|
r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
return &blockHasher{
|
||||||
|
blockSize: blockSize,
|
||||||
|
seed: r.Uint64(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// prefixHashes generates parent-chain hash sequence by block
|
||||||
|
func (h *blockHasher) prefixHashes(tokens []int) []uint64 {
|
||||||
|
if h.blockSize <= 0 || len(tokens) < h.blockSize {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
blockCount := len(tokens) / h.blockSize
|
||||||
|
hashes := make([]uint64, 0, blockCount)
|
||||||
|
parent := h.seed
|
||||||
|
buf := make([]byte, 8)
|
||||||
|
|
||||||
|
for i := 0; i+h.blockSize <= len(tokens); i += h.blockSize {
|
||||||
|
hasher := fnv.New64a()
|
||||||
|
binary.LittleEndian.PutUint64(buf, parent)
|
||||||
|
_, _ = hasher.Write(buf)
|
||||||
|
|
||||||
|
for _, token := range tokens[i : i+h.blockSize] {
|
||||||
|
binary.LittleEndian.PutUint64(buf, uint64(token))
|
||||||
|
_, _ = hasher.Write(buf)
|
||||||
|
}
|
||||||
|
current := hasher.Sum64()
|
||||||
|
hashes = append(hashes, current)
|
||||||
|
parent = current
|
||||||
|
}
|
||||||
|
return hashes
|
||||||
|
}
|
||||||
|
|
||||||
|
func matchUint64Len(a, b []uint64) int {
|
||||||
|
minLen := len(a)
|
||||||
|
if len(b) < minLen {
|
||||||
|
minLen = len(b)
|
||||||
|
}
|
||||||
|
i := 0
|
||||||
|
for i < minLen && a[i] == b[i] {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
func charsToTokens(message string) []int {
|
||||||
|
tokens := make([]int, 0, len(message))
|
||||||
|
for _, r := range message {
|
||||||
|
tokens = append(tokens, int(r))
|
||||||
|
}
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
func toWorkerSet(workers []string) map[string]struct{} {
|
||||||
|
set := make(map[string]struct{}, len(workers))
|
||||||
|
for _, w := range workers {
|
||||||
|
set[w] = struct{}{}
|
||||||
|
}
|
||||||
|
return set
|
||||||
|
}
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"math/rand"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RandomSelectWorker(ctx context.Context, workers []string, message string) (string, error) {
|
||||||
|
if len(workers) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
randomNum := rand.Intn(len(workers))
|
||||||
|
return workers[randomNum], nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RoundRobinSelectWorker(ctx context.Context, workers []string, message string) (string, error) {
|
||||||
|
if len(workers) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var count uint64
|
||||||
|
if DefaultCounterPolicy.workerType == "prefill" {
|
||||||
|
count = DefaultCounterPolicy.prefillCounter.Add(1) - 1
|
||||||
|
} else {
|
||||||
|
count = DefaultCounterPolicy.counter.Add(1) - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
selectedNum := count % uint64(len(workers))
|
||||||
|
return workers[selectedNum], nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,108 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
urlpkg "net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Abstract remote tokenizer interface
|
||||||
|
type TokenizerClient interface {
|
||||||
|
Tokenize(ctx context.Context, prompt string) ([]int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement HTTP /tokenize call
|
||||||
|
type httpTokenizer struct {
|
||||||
|
url string
|
||||||
|
timeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return HTTP-based tokenizer client
|
||||||
|
func NewHTTPTokenizer(rawURL string, timeout time.Duration) TokenizerClient {
|
||||||
|
if rawURL == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(rawURL, "http://") && !strings.HasPrefix(rawURL, "https://") {
|
||||||
|
rawURL = "http://" + rawURL
|
||||||
|
}
|
||||||
|
parsed, err := urlpkg.Parse(rawURL)
|
||||||
|
if err == nil {
|
||||||
|
if parsed.Path == "" || parsed.Path == "/" {
|
||||||
|
parsed.Path = "/tokenize"
|
||||||
|
}
|
||||||
|
rawURL = parsed.String()
|
||||||
|
}
|
||||||
|
if timeout <= 0 {
|
||||||
|
timeout = 2 * time.Second
|
||||||
|
}
|
||||||
|
return &httpTokenizer{
|
||||||
|
url: rawURL,
|
||||||
|
timeout: timeout,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type tokenizerHTTPReq struct {
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
Prompt string `json:"prompt,omitempty"`
|
||||||
|
Message string `json:"message,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *httpTokenizer) Tokenize(ctx context.Context, prompt string) ([]int, error) {
|
||||||
|
if c == nil {
|
||||||
|
return nil, errors.New("tokenizer client is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := tokenizerHTTPReq{Text: prompt, Prompt: prompt, Message: prompt}
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("marshal tokenizer request failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("build tokenizer request failed: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
client := &http.Client{Timeout: c.timeout}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("tokenizer request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read tokenizer response failed: %w", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 300 {
|
||||||
|
return nil, fmt.Errorf("tokenizer response status %d: %s", resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens, err := parseTokensFromBody(respBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse tokenizer response failed: %w", err)
|
||||||
|
}
|
||||||
|
return tokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse tokens from body JSON {"input_ids": []}
|
||||||
|
func parseTokensFromBody(body []byte) ([]int, error) {
|
||||||
|
var input struct {
|
||||||
|
InputIDs []int `json:"input_ids"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &input); err == nil {
|
||||||
|
if len(input.InputIDs) > 0 {
|
||||||
|
return input.InputIDs, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errors.New("tokenizer response missing tokens")
|
||||||
|
}
|
||||||
@@ -0,0 +1,650 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewHTTPTokenizer(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rawURL string
|
||||||
|
timeout time.Duration
|
||||||
|
expected string
|
||||||
|
wantNil bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty URL should return nil",
|
||||||
|
rawURL: "",
|
||||||
|
timeout: time.Second,
|
||||||
|
wantNil: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "URL without scheme should add http://",
|
||||||
|
rawURL: "example.com",
|
||||||
|
timeout: time.Second,
|
||||||
|
expected: "http://example.com/tokenize",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "URL with http scheme should keep it",
|
||||||
|
rawURL: "http://example.com",
|
||||||
|
timeout: time.Second,
|
||||||
|
expected: "http://example.com/tokenize",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "URL with https scheme should keep it",
|
||||||
|
rawURL: "https://example.com",
|
||||||
|
timeout: time.Second,
|
||||||
|
expected: "https://example.com/tokenize",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "URL with path should keep it",
|
||||||
|
rawURL: "http://example.com/api",
|
||||||
|
timeout: time.Second,
|
||||||
|
expected: "http://example.com/api",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "URL with root path should replace with /tokenize",
|
||||||
|
rawURL: "http://example.com/",
|
||||||
|
timeout: time.Second,
|
||||||
|
expected: "http://example.com/tokenize",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "URL with empty path should add /tokenize",
|
||||||
|
rawURL: "http://example.com",
|
||||||
|
timeout: time.Second,
|
||||||
|
expected: "http://example.com/tokenize",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid URL should still work with added scheme",
|
||||||
|
rawURL: "example.com:invalid:port",
|
||||||
|
timeout: time.Second,
|
||||||
|
expected: "http://example.com:invalid:port",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero timeout should use default 2s",
|
||||||
|
rawURL: "example.com",
|
||||||
|
timeout: 0,
|
||||||
|
expected: "http://example.com/tokenize",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative timeout should use default 2s",
|
||||||
|
rawURL: "example.com",
|
||||||
|
timeout: -time.Second,
|
||||||
|
expected: "http://example.com/tokenize",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "URL with port and path should be preserved",
|
||||||
|
rawURL: "http://example.com:8080/v1",
|
||||||
|
timeout: time.Second,
|
||||||
|
expected: "http://example.com:8080/v1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "URL with query parameters should be preserved",
|
||||||
|
rawURL: "https://example.com?token=abc",
|
||||||
|
timeout: time.Second,
|
||||||
|
expected: "https://example.com/tokenize?token=abc",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "URL with fragment should be preserved",
|
||||||
|
rawURL: "http://example.com#section",
|
||||||
|
timeout: time.Second,
|
||||||
|
expected: "http://example.com/tokenize#section",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
client := NewHTTPTokenizer(tt.rawURL, tt.timeout)
|
||||||
|
|
||||||
|
if tt.wantNil {
|
||||||
|
if client != nil {
|
||||||
|
t.Errorf("Expected nil client for empty URL, got %v", client)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if client == nil {
|
||||||
|
t.Fatal("Expected non-nil client, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
httpClient, ok := client.(*httpTokenizer)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected *httpTokenizer, got %T", client)
|
||||||
|
}
|
||||||
|
|
||||||
|
if httpClient.url != tt.expected {
|
||||||
|
t.Errorf("Expected URL %q, got %q", tt.expected, httpClient.url)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedTimeout := tt.timeout
|
||||||
|
if expectedTimeout <= 0 {
|
||||||
|
expectedTimeout = 2 * time.Second
|
||||||
|
}
|
||||||
|
if httpClient.timeout != expectedTimeout {
|
||||||
|
t.Errorf("Expected timeout %v, got %v", expectedTimeout, httpClient.timeout)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewHTTPTokenizer_URLParsing(t *testing.T) {
|
||||||
|
// Test that URL parsing preserves all components
|
||||||
|
rawURL := "https://user:pass@example.com:8080/path?query=value#fragment"
|
||||||
|
client := NewHTTPTokenizer(rawURL, time.Second)
|
||||||
|
|
||||||
|
if client == nil {
|
||||||
|
t.Fatal("Expected non-nil client, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
httpClient := client.(*httpTokenizer)
|
||||||
|
parsed, err := url.Parse(httpClient.url)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to parse client URL: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if parsed.Scheme != "https" {
|
||||||
|
t.Errorf("Expected scheme https, got %q", parsed.Scheme)
|
||||||
|
}
|
||||||
|
if parsed.Host != "example.com:8080" {
|
||||||
|
t.Errorf("Expected host example.com:8080, got %q", parsed.Host)
|
||||||
|
}
|
||||||
|
if parsed.Path != "/path" {
|
||||||
|
t.Errorf("Expected path /path, got %q", parsed.Path)
|
||||||
|
}
|
||||||
|
if parsed.RawQuery != "query=value" {
|
||||||
|
t.Errorf("Expected query query=value, got %q", parsed.RawQuery)
|
||||||
|
}
|
||||||
|
if parsed.Fragment != "fragment" {
|
||||||
|
t.Errorf("Expected fragment fragment, got %q", parsed.Fragment)
|
||||||
|
}
|
||||||
|
if parsed.User.String() != "user:pass" {
|
||||||
|
t.Errorf("Expected user user:pass, got %q", parsed.User)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewHTTPTokenizer_ImplementsInterface(t *testing.T) {
|
||||||
|
client := NewHTTPTokenizer("example.com", time.Second)
|
||||||
|
|
||||||
|
if client == nil {
|
||||||
|
t.Fatal("Expected non-nil client")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that the returned client implements the TokenizerClient interface
|
||||||
|
var _ TokenizerClient = client
|
||||||
|
|
||||||
|
// Test type assertion
|
||||||
|
_, ok := client.(TokenizerClient)
|
||||||
|
if !ok {
|
||||||
|
t.Error("Returned client does not implement TokenizerClient interface")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPTokenizer_Tokenize(t *testing.T) {
|
||||||
|
t.Run("nil tokenizer client", func(t *testing.T) {
|
||||||
|
var tokenizer *httpTokenizer = nil
|
||||||
|
_, err := tokenizer.Tokenize(context.Background(), "test prompt")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for nil tokenizer client")
|
||||||
|
}
|
||||||
|
if err.Error() != "tokenizer client is nil" {
|
||||||
|
t.Errorf("Expected 'tokenizer client is nil', got '%v'", err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("successful tokenization", func(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// httptest server uses root path by default
|
||||||
|
if r.URL.Path != "/" {
|
||||||
|
t.Errorf("Expected path /, got %s", r.URL.Path)
|
||||||
|
}
|
||||||
|
if r.Method != "POST" {
|
||||||
|
t.Errorf("Expected POST method, got %s", r.Method)
|
||||||
|
}
|
||||||
|
if r.Header.Get("Content-Type") != "application/json" {
|
||||||
|
t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify request body
|
||||||
|
var req tokenizerHTTPReq
|
||||||
|
err := json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to decode request: %v", err)
|
||||||
|
}
|
||||||
|
if req.Text != "test prompt" {
|
||||||
|
t.Errorf("Expected Text 'test prompt', got '%s'", req.Text)
|
||||||
|
}
|
||||||
|
if req.Prompt != "test prompt" {
|
||||||
|
t.Errorf("Expected Prompt 'test prompt', got '%s'", req.Prompt)
|
||||||
|
}
|
||||||
|
if req.Message != "test prompt" {
|
||||||
|
t.Errorf("Expected Message 'test prompt', got '%s'", req.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send successful response
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"input_ids": []int{1, 2, 3, 4, 5},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(response)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
tokenizer := &httpTokenizer{
|
||||||
|
url: server.URL,
|
||||||
|
timeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens, err := tokenizer.Tokenize(context.Background(), "test prompt")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Tokenize failed: %v", err)
|
||||||
|
}
|
||||||
|
expectedTokens := []int{1, 2, 3, 4, 5}
|
||||||
|
if len(tokens) != len(expectedTokens) {
|
||||||
|
t.Errorf("Expected %d tokens, got %d", len(expectedTokens), len(tokens))
|
||||||
|
}
|
||||||
|
for i, token := range tokens {
|
||||||
|
if token != expectedTokens[i] {
|
||||||
|
t.Errorf("Token at index %d: expected %d, got %d", i, expectedTokens[i], token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("http request creation failure", func(t *testing.T) {
|
||||||
|
tokenizer := &httpTokenizer{
|
||||||
|
url: "://invalid-url", // Invalid URL to force http.NewRequest to fail
|
||||||
|
timeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := tokenizer.Tokenize(context.Background(), "test prompt")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for invalid URL")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "build tokenizer request failed") {
|
||||||
|
t.Errorf("Expected error to contain 'build tokenizer request failed', got '%v'", err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("http request timeout", func(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(100 * time.Millisecond) // Simulate slow response
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{"input_ids": []int{1}})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
tokenizer := &httpTokenizer{
|
||||||
|
url: server.URL,
|
||||||
|
timeout: 10 * time.Millisecond, // Very short timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := tokenizer.Tokenize(context.Background(), "test prompt")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for timeout")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "tokenizer request failed") {
|
||||||
|
t.Errorf("Expected error to contain 'tokenizer request failed', got '%v'", err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("http connection failure", func(t *testing.T) {
|
||||||
|
tokenizer := &httpTokenizer{
|
||||||
|
url: "http://invalid-server:9999", // Non-existent server
|
||||||
|
timeout: 1 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := tokenizer.Tokenize(context.Background(), "test prompt")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for connection failure")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "tokenizer request failed") {
|
||||||
|
t.Errorf("Expected error to contain 'tokenizer request failed', got '%v'", err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("http response status error", func(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
w.Write([]byte("internal server error"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
tokenizer := &httpTokenizer{
|
||||||
|
url: server.URL,
|
||||||
|
timeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := tokenizer.Tokenize(context.Background(), "test prompt")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for status code 500")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "tokenizer response status 500") {
|
||||||
|
t.Errorf("Expected error to contain 'tokenizer response status 500', got '%v'", err.Error())
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "internal server error") {
|
||||||
|
t.Errorf("Expected error to contain response body, got '%v'", err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid json response", func(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("invalid json content"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
tokenizer := &httpTokenizer{
|
||||||
|
url: server.URL,
|
||||||
|
timeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := tokenizer.Tokenize(context.Background(), "test prompt")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for invalid JSON")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "parse tokenizer response failed") {
|
||||||
|
t.Errorf("Expected error to contain 'parse tokenizer response failed', got '%v'", err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty tokens in response", func(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"input_ids": []int{},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(response)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
tokenizer := &httpTokenizer{
|
||||||
|
url: server.URL,
|
||||||
|
timeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := tokenizer.Tokenize(context.Background(), "test prompt")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for empty tokens")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "tokenizer response missing tokens") {
|
||||||
|
t.Errorf("Expected error to contain 'tokenizer response missing tokens', got '%v'", err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing input_ids field", func(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"tokens": []int{1, 2, 3}, // Wrong field name
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(response)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
tokenizer := &httpTokenizer{
|
||||||
|
url: server.URL,
|
||||||
|
timeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := tokenizer.Tokenize(context.Background(), "test prompt")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for missing input_ids")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "tokenizer response missing tokens") {
|
||||||
|
t.Errorf("Expected error to contain 'tokenizer response missing tokens', got '%v'", err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("context cancellation", func(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(100 * time.Millisecond) // Simulate slow processing
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{"input_ids": []int{1}})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
tokenizer := &httpTokenizer{
|
||||||
|
url: server.URL,
|
||||||
|
timeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // Cancel immediately
|
||||||
|
|
||||||
|
_, err := tokenizer.Tokenize(ctx, "test prompt")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for cancelled context")
|
||||||
|
}
|
||||||
|
// Context cancellation should cause request to fail
|
||||||
|
if !strings.Contains(err.Error(), "tokenizer request failed") {
|
||||||
|
t.Errorf("Expected error to contain 'tokenizer request failed', got '%v'", err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty prompt", func(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req tokenizerHTTPReq
|
||||||
|
err := json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to decode request: %v", err)
|
||||||
|
}
|
||||||
|
if req.Text != "" {
|
||||||
|
t.Errorf("Expected empty Text, got '%s'", req.Text)
|
||||||
|
}
|
||||||
|
if req.Prompt != "" {
|
||||||
|
t.Errorf("Expected empty Prompt, got '%s'", req.Prompt)
|
||||||
|
}
|
||||||
|
if req.Message != "" {
|
||||||
|
t.Errorf("Expected empty Message, got '%s'", req.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"input_ids": []int{},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(response)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
tokenizer := &httpTokenizer{
|
||||||
|
url: server.URL,
|
||||||
|
timeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := tokenizer.Tokenize(context.Background(), "")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for empty prompt")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "tokenizer response missing tokens") {
|
||||||
|
t.Errorf("Expected error to contain 'tokenizer response missing tokens', got '%v'", err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("very long prompt", func(t *testing.T) {
|
||||||
|
longPrompt := string(make([]byte, 10000)) // 10KB prompt
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req tokenizerHTTPReq
|
||||||
|
err := json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to decode request: %v", err)
|
||||||
|
}
|
||||||
|
if req.Text != longPrompt {
|
||||||
|
t.Error("Request text does not match long prompt")
|
||||||
|
}
|
||||||
|
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"input_ids": []int{1, 2, 3},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(response)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
tokenizer := &httpTokenizer{
|
||||||
|
url: server.URL,
|
||||||
|
timeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens, err := tokenizer.Tokenize(context.Background(), longPrompt)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Tokenize failed for long prompt: %v", err)
|
||||||
|
}
|
||||||
|
expectedTokens := []int{1, 2, 3}
|
||||||
|
if len(tokens) != len(expectedTokens) {
|
||||||
|
t.Errorf("Expected %d tokens, got %d", len(expectedTokens), len(tokens))
|
||||||
|
}
|
||||||
|
for i, token := range tokens {
|
||||||
|
if token != expectedTokens[i] {
|
||||||
|
t.Errorf("Token at index %d: expected %d, got %d", i, expectedTokens[i], token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("special characters in prompt", func(t *testing.T) {
|
||||||
|
specialPrompt := "Hello, 世界! 🚀 Test with emoji and unicode: ñáéíóú"
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req tokenizerHTTPReq
|
||||||
|
err := json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to decode request: %v", err)
|
||||||
|
}
|
||||||
|
if req.Text != specialPrompt {
|
||||||
|
t.Errorf("Expected Text '%s', got '%s'", specialPrompt, req.Text)
|
||||||
|
}
|
||||||
|
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"input_ids": []int{10, 20, 30, 40},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(response)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
tokenizer := &httpTokenizer{
|
||||||
|
url: server.URL,
|
||||||
|
timeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens, err := tokenizer.Tokenize(context.Background(), specialPrompt)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Tokenize failed for special characters: %v", err)
|
||||||
|
}
|
||||||
|
expectedTokens := []int{10, 20, 30, 40}
|
||||||
|
if len(tokens) != len(expectedTokens) {
|
||||||
|
t.Errorf("Expected %d tokens, got %d", len(expectedTokens), len(tokens))
|
||||||
|
}
|
||||||
|
for i, token := range tokens {
|
||||||
|
if token != expectedTokens[i] {
|
||||||
|
t.Errorf("Token at index %d: expected %d, got %d", i, expectedTokens[i], token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTokensFromBody(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input []byte
|
||||||
|
expected []int
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid input with tokens",
|
||||||
|
input: []byte(`{"input_ids": [1, 2, 3]}`),
|
||||||
|
expected: []int{1, 2, 3},
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty input_ids array",
|
||||||
|
input: []byte(`{"input_ids": []}`),
|
||||||
|
expected: nil,
|
||||||
|
err: errors.New("tokenizer response missing tokens"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing input_ids field",
|
||||||
|
input: []byte(`{"other_field": "value"}`),
|
||||||
|
expected: nil,
|
||||||
|
err: errors.New("tokenizer response missing tokens"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid JSON format",
|
||||||
|
input: []byte(`invalid json`),
|
||||||
|
expected: nil,
|
||||||
|
err: errors.New("tokenizer response missing tokens"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty body",
|
||||||
|
input: []byte(``),
|
||||||
|
expected: nil,
|
||||||
|
err: errors.New("tokenizer response missing tokens"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "large array of tokens",
|
||||||
|
input: []byte(`{"input_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}`),
|
||||||
|
expected: []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "null input_ids",
|
||||||
|
input: []byte(`{"input_ids": null}`),
|
||||||
|
expected: nil,
|
||||||
|
err: errors.New("tokenizer response missing tokens"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-array input_ids",
|
||||||
|
input: []byte(`{"input_ids": "not an array"}`),
|
||||||
|
expected: nil,
|
||||||
|
err: errors.New("tokenizer response missing tokens"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "malformed array",
|
||||||
|
input: []byte(`{"input_ids": [1, "two", 3]}`),
|
||||||
|
expected: nil,
|
||||||
|
err: errors.New("tokenizer response missing tokens"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := parseTokensFromBody(tt.input)
|
||||||
|
|
||||||
|
// Check if error is expected
|
||||||
|
if (err != nil && tt.err == nil) || (err == nil && tt.err != nil) {
|
||||||
|
t.Errorf("parseTokensFromBody() error = %v, wantErr %v", err, tt.err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil && tt.err != nil && err.Error() != tt.err.Error() {
|
||||||
|
t.Errorf("parseTokensFromBody() error message = %v, want %v", err.Error(), tt.err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compare actual and expected results
|
||||||
|
if len(got) != len(tt.expected) {
|
||||||
|
t.Errorf("parseTokensFromBody() = %v, want %v", got, tt.expected)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i := range got {
|
||||||
|
if got[i] != tt.expected[i] {
|
||||||
|
t.Errorf("parseTokensFromBody() = %v, want %v", got, tt.expected)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,80 @@
|
|||||||
|
package logger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
infoLogger *log.Logger
|
||||||
|
errorLogger *log.Logger
|
||||||
|
warnLogger *log.Logger
|
||||||
|
debugLogger *log.Logger
|
||||||
|
level string
|
||||||
|
once sync.Once
|
||||||
|
logFile *os.File
|
||||||
|
)
|
||||||
|
|
||||||
|
// Init initialize logger
|
||||||
|
func Init(logLevel, output string) {
|
||||||
|
once.Do(func() {
|
||||||
|
level = logLevel
|
||||||
|
|
||||||
|
flags := log.LstdFlags | log.Lshortfile
|
||||||
|
|
||||||
|
if output == "file" {
|
||||||
|
// Check if logs directory exists
|
||||||
|
if _, err := os.Stat("logs"); os.IsNotExist(err) {
|
||||||
|
if err := os.MkdirAll("logs", 0755); err != nil {
|
||||||
|
log.Fatalln("Failed to create logs directory:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logFile, err := os.OpenFile("logs/router.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln("Failed to open log file:", err)
|
||||||
|
}
|
||||||
|
infoLogger = log.New(logFile, "[INFO] ", flags)
|
||||||
|
errorLogger = log.New(logFile, "[ERROR] ", flags)
|
||||||
|
warnLogger = log.New(logFile, "[WARN] ", flags)
|
||||||
|
debugLogger = log.New(logFile, "[DEBUG] ", flags)
|
||||||
|
} else {
|
||||||
|
infoLogger = log.New(os.Stdout, "[INFO] ", flags)
|
||||||
|
errorLogger = log.New(os.Stderr, "[ERROR] ", flags)
|
||||||
|
warnLogger = log.New(os.Stdout, "[WARN] ", flags)
|
||||||
|
debugLogger = log.New(os.Stdout, "[DEBUG] ", flags)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func CloseLogFile() {
|
||||||
|
if logFile != nil {
|
||||||
|
logFile.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Info logs informational messages
|
||||||
|
func Info(format string, v ...interface{}) {
|
||||||
|
if level == "debug" || level == "info" {
|
||||||
|
infoLogger.Printf(format, v...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error logs error messages
|
||||||
|
func Error(format string, v ...interface{}) {
|
||||||
|
errorLogger.Printf(format, v...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warn logs warning messages
|
||||||
|
func Warn(format string, v ...interface{}) {
|
||||||
|
if level == "debug" || level == "info" || level == "warn" {
|
||||||
|
warnLogger.Printf(format, v...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug logs debug messages
|
||||||
|
func Debug(format string, v ...interface{}) {
|
||||||
|
if level == "debug" {
|
||||||
|
debugLogger.Printf(format, v...)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,127 @@
|
|||||||
|
package logger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLoggerInit(t *testing.T) {
|
||||||
|
t.Run("stdout output", func(t *testing.T) {
|
||||||
|
Init("debug", "stdout")
|
||||||
|
|
||||||
|
if infoLogger == nil || errorLogger == nil || warnLogger == nil || debugLogger == nil {
|
||||||
|
t.Error("Loggers should be initialized")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("file output", func(t *testing.T) {
|
||||||
|
// Clean up existing log file and directory
|
||||||
|
_ = os.RemoveAll("logs")
|
||||||
|
_ = os.MkdirAll("logs", 0755)
|
||||||
|
defer os.RemoveAll("logs")
|
||||||
|
|
||||||
|
Init("debug", "file")
|
||||||
|
|
||||||
|
if _, err := os.Stat("logs/router.log"); os.IsNotExist(err) {
|
||||||
|
t.Error("Log file should be created")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogLevels(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
level string
|
||||||
|
expected map[string]bool
|
||||||
|
}{
|
||||||
|
{"debug level", "debug", map[string]bool{
|
||||||
|
"debug": true,
|
||||||
|
"info": true,
|
||||||
|
"warn": true,
|
||||||
|
"error": true,
|
||||||
|
}},
|
||||||
|
{"info level", "info", map[string]bool{
|
||||||
|
"debug": false,
|
||||||
|
"info": true,
|
||||||
|
"warn": true,
|
||||||
|
"error": true,
|
||||||
|
}},
|
||||||
|
{"warn level", "warn", map[string]bool{
|
||||||
|
"debug": false,
|
||||||
|
"info": false,
|
||||||
|
"warn": true,
|
||||||
|
"error": true,
|
||||||
|
}},
|
||||||
|
{"error level", "error", map[string]bool{
|
||||||
|
"debug": false,
|
||||||
|
"info": false,
|
||||||
|
"warn": false,
|
||||||
|
"error": true,
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Initialize logger with test level
|
||||||
|
Init(tt.level, "stdout")
|
||||||
|
|
||||||
|
// Capture output for each level separately
|
||||||
|
testLevel := func(logFunc func(string, ...interface{}), message string) bool {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
oldOutput := infoLogger.Writer()
|
||||||
|
|
||||||
|
infoLogger.SetOutput(&buf)
|
||||||
|
errorLogger.SetOutput(&buf)
|
||||||
|
warnLogger.SetOutput(&buf)
|
||||||
|
debugLogger.SetOutput(&buf)
|
||||||
|
|
||||||
|
logFunc(message)
|
||||||
|
|
||||||
|
infoLogger.SetOutput(oldOutput)
|
||||||
|
errorLogger.SetOutput(oldOutput)
|
||||||
|
warnLogger.SetOutput(oldOutput)
|
||||||
|
debugLogger.SetOutput(oldOutput)
|
||||||
|
|
||||||
|
return strings.Contains(buf.String(), message)
|
||||||
|
}
|
||||||
|
|
||||||
|
debugPrinted := testLevel(Debug, "debug message")
|
||||||
|
infoPrinted := testLevel(Info, "info message")
|
||||||
|
warnPrinted := testLevel(Warn, "warn message")
|
||||||
|
errorPrinted := testLevel(Error, "error message")
|
||||||
|
|
||||||
|
// Check expected behavior
|
||||||
|
if tt.expected["debug"] != debugPrinted {
|
||||||
|
t.Errorf("Debug log: expected %v, got %v", tt.expected["debug"], debugPrinted)
|
||||||
|
}
|
||||||
|
if tt.expected["info"] != infoPrinted {
|
||||||
|
t.Errorf("Info log: expected %v, got %v", tt.expected["info"], infoPrinted)
|
||||||
|
}
|
||||||
|
if tt.expected["warn"] != warnPrinted {
|
||||||
|
t.Errorf("Warn log: expected %v, got %v", tt.expected["warn"], warnPrinted)
|
||||||
|
}
|
||||||
|
if tt.expected["error"] != errorPrinted {
|
||||||
|
t.Errorf("Error log: expected %v, got %v", tt.expected["error"], errorPrinted)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogFunctions(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
Init("debug", "stdout")
|
||||||
|
|
||||||
|
// Redirect output
|
||||||
|
oldOutput := infoLogger.Writer()
|
||||||
|
defer func() { infoLogger.SetOutput(oldOutput) }()
|
||||||
|
infoLogger.SetOutput(&buf)
|
||||||
|
|
||||||
|
Info("test %s", "message")
|
||||||
|
if !strings.Contains(buf.String(), "test message") {
|
||||||
|
t.Error("Info log should contain the message")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Similar tests for Error, Warn, Debug...
|
||||||
|
}
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
package metrics
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
prometheus.MustRegister(TotalRequests)
|
||||||
|
prometheus.MustRegister(InferenceRequests)
|
||||||
|
prometheus.MustRegister(RequestDuration)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotalRequests tracks the total number of HTTP requests
|
||||||
|
var TotalRequests = prometheus.NewCounterVec(
|
||||||
|
prometheus.CounterOpts{
|
||||||
|
Name: "http_requests_total",
|
||||||
|
Help: "Total number of HTTP requests",
|
||||||
|
},
|
||||||
|
[]string{"method", "endpoint", "status_code"},
|
||||||
|
)
|
||||||
|
|
||||||
|
// InferenceRequests tracks the total number of inference requests
|
||||||
|
var InferenceRequests = prometheus.NewCounterVec(
|
||||||
|
prometheus.CounterOpts{
|
||||||
|
Name: "inference_requests_total",
|
||||||
|
Help: "Total number of inference requests",
|
||||||
|
},
|
||||||
|
[]string{"mixed_worker", "prefill_worker", "decode_worker", "status_code"},
|
||||||
|
)
|
||||||
|
|
||||||
|
// RequestDuration tracks the response latency of HTTP requests
|
||||||
|
var RequestDuration = prometheus.NewSummaryVec(
|
||||||
|
prometheus.SummaryOpts{
|
||||||
|
Name: "http_request_duration_seconds",
|
||||||
|
Help: "Summary of the response latency (seconds) of HTTP requests",
|
||||||
|
Objectives: map[float64]float64{0.95: 0.01, 0.99: 0.01}, // Objectives define the required quantiles
|
||||||
|
},
|
||||||
|
[]string{"method", "endpoint"},
|
||||||
|
)
|
||||||
@@ -0,0 +1,45 @@
|
|||||||
|
package metrics
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMetricsInitialization(t *testing.T) {
|
||||||
|
// Verify all metrics are registered
|
||||||
|
metrics := []prometheus.Collector{
|
||||||
|
TotalRequests,
|
||||||
|
InferenceRequests,
|
||||||
|
RequestDuration,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, metric := range metrics {
|
||||||
|
if err := prometheus.Register(metric); err == nil {
|
||||||
|
t.Errorf("Metric %T should already be registered", metric)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsHelpText(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
metric prometheus.Collector
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"TotalRequests", TotalRequests, "Total number of HTTP requests"},
|
||||||
|
{"InferenceRequests", InferenceRequests, "Total number of inference requests"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
desc := make(chan *prometheus.Desc, 1)
|
||||||
|
tt.metric.Describe(desc)
|
||||||
|
d := <-desc
|
||||||
|
if !strings.Contains(d.String(), tt.expected) {
|
||||||
|
t.Errorf("Expected help text to contain '%s', got '%s'", tt.expected, d.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Executable
+11
@@ -0,0 +1,11 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
PID=$(ps -ef | grep "fd-router" | grep -v grep | awk '{print $2}')
|
||||||
|
if [ -n "$PID" ]; then
|
||||||
|
echo "Killing existing fd-router process (PID: $PID)"
|
||||||
|
kill -15 $PID
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Starting new fd-router process..."
|
||||||
|
nohup /usr/local/bin/fd-router --port 8080 --splitwise > fd-router.log 2>&1 &
|
||||||
|
echo "fd-router started with PID: $!"
|
||||||
@@ -45,3 +45,4 @@ omit =
|
|||||||
*/fastdeploy/worker/iluvatar*.py
|
*/fastdeploy/worker/iluvatar*.py
|
||||||
*/fastdeploy/**/xpu/*
|
*/fastdeploy/**/xpu/*
|
||||||
*/fastdeploy/worker/xpu*.py
|
*/fastdeploy/worker/xpu*.py
|
||||||
|
*/fastdeploy/golang_router/*
|
||||||
|
|||||||
Reference in New Issue
Block a user