[Docs] Specify the default strategy (#6728)

* [Docs] Update the document

---------

Co-authored-by: mouxin <mouxin@baidu.com>
This commit is contained in:
mouxin
2026-03-10 13:16:31 +08:00
committed by GitHub
parent 8b8f0c5659
commit 22d308a274
12 changed files with 180 additions and 71 deletions
+19 -3
View File
@@ -190,9 +190,9 @@ server:
splitwise: true # true代表开启pd分离模式,false代表开启非pd分离模式
scheduler:
policy: "power_of_two" # 调度策略(可选): random, power_of_two, round_robin, process_tokens, request_num, cache_aware, fd_metrics_score
prefill-policy: "cache_aware" # pd分离模式下prefill节点调度策略
decode-policy: "fd_metrics_score" # pd分离模式下decode节点调度策略
policy: "power_of_two" # 调度策略(可选): random, power_of_two, round_robin, process_tokens, request_num, cache_aware, fd_metrics_score; 默认: request_num
prefill-policy: "cache_aware" # pd分离模式下prefill节点调度策略; 默认: process_tokens
decode-policy: "fd_metrics_score" # pd分离模式下decode节点调度策略; 默认: request_num
eviction-interval-secs: 60 # cache-aware策略清理过期cache的间隔时间
balance-abs-threshold: 1 # cache-aware策略绝对阈值
balance-rel-threshold: 0.2 # cache-aware策略相对阈值
@@ -253,3 +253,19 @@ instances:
* metrics_port: 推理实例的metrics端口号。
其中 `role`、`host_ip`、`port` 为必填参数,其余参数为可选。
## 调度策略说明
Router 支持以下调度策略,可通过配置文件中的 `policy`(mixed 模式)、`prefill-policy` 和 `decode-policy`(PD 分离模式)字段指定。
**默认策略**:不配置时,prefill 节点默认使用 `process_tokens`,mixed 和 decode 节点默认使用 `request_num`。
| 策略名 | 适用场景 | 实现方式 |
|--------|----------|----------|
| `random` | 通用 | 从所有可用实例中随机选择一个,无状态感知,适合轻量场景。 |
| `round_robin` | 通用 | 使用原子计数器对实例列表循环取模,按顺序均匀分发请求。 |
| `power_of_two` | 通用 | 随机选取两个实例,比较其当前并发请求数,选择负载较低的一个。 |
| `process_tokens` | **prefill(默认)** | 遍历所有实例,选择当前正在处理的 token 数最少的实例,适合 prefill 阶段的长请求负载均衡。 |
| `request_num` | **mixed / decode(默认)** | 遍历所有实例,选择当前并发请求数最少的实例,适合 decode 及 mixed 场景的请求均衡。 |
| `fd_metrics_score` | mixed / decode | 实时从各实例的 metrics 接口获取 running/waiting 请求数,按 `running + waiting × waitingWeight` 打分,选择得分最低的实例。 |
| `cache_aware` | prefill | 基于 Radix Tree 维护各实例的 KV Cache 前缀命中情况,综合命中率与负载打分选择实例;负载严重不均衡时自动回退至 `process_tokens`。 |
+1 -1
View File
@@ -60,7 +60,7 @@ func main() {
// Start server
addr := ":" + cfg.Server.Port
logger.Info("Starting server on %s", addr)
logger.Info(context.Background(), "Starting server on %s", addr)
if err := r.Run(addr); err != nil {
log.Fatalf("Failed to start server: %v", err)
}
@@ -118,5 +118,14 @@ func Load(configPath, listenPort string, isSplitwise bool) (*Config, error) {
if cfg.Scheduler.WaitingWeight == 0 {
cfg.Scheduler.WaitingWeight = 1
}
if cfg.Scheduler.Policy == "" {
cfg.Scheduler.Policy = "request_num"
}
if cfg.Scheduler.PrefillPolicy == "" {
cfg.Scheduler.PrefillPolicy = "process_tokens"
}
if cfg.Scheduler.DecodePolicy == "" {
cfg.Scheduler.DecodePolicy = "request_num"
}
return &cfg, nil
}
@@ -232,7 +232,7 @@ func readPrefillRecv(ctx context.Context, url string, isStream bool, message str
if !released {
scheduler_handler.Release(ctx, url)
scheduler_handler.ReleasePrefillTokens(ctx, url, message)
logger.Debug("[prefill] release in defer (fallback) url=%s", url)
logger.Debug(ctx, "[prefill] release in defer (fallback) url=%s", url)
}
}()
@@ -245,21 +245,29 @@ func readPrefillRecv(ctx context.Context, url string, isStream bool, message str
scheduler_handler.ReleasePrefillTokens(ctx, url, message)
released = true
logger.Debug("[prefill] first chunk received, release scheduler url=%s", url)
logger.Info(ctx, "[prefill] first chunk received, release counter url=%s", url)
}
}
if err := scanner.Err(); err != nil {
logger.Debug("[prefill] scanner error: %v", err)
logger.Debug(ctx, "[prefill] scanner error: %v", err)
}
} else {
_, err := io.Copy(io.Discard, backendResp.Body)
if err != nil {
logger.Debug("[prefill] copy error: %v", err)
logger.Debug(ctx, "[prefill] copy error: %v", err)
}
}
}
func getRequestID(ctx context.Context, rawReq map[string]any) string {
// If user didn't provide request_id, generate one
if _, ok := rawReq["request_id"]; !ok {
rawReq["request_id"] = newRequestID()
}
return rawReq["request_id"].(string)
}
// ChatCompletions implements request forwarding to actual large model inference service
func ChatCompletions(c *gin.Context) {
completionEndpoint := "chat/completions"
@@ -300,9 +308,14 @@ func CommonCompletions(c *gin.Context, extractor PromptExtractor, completionEndp
)
if isSplitwise {
requestID := getRequestID(ctx, rawReq)
ctx = context.WithValue(ctx, logger.RequestIDKey, requestID)
c.Request = c.Request.WithContext(ctx)
// PD mode: select instances for Prefill/Decode separately
message = extractor(rawReq)
logger.Info(ctx, "Parsing completed; starting worker selection.")
prefillURL, decodeURL, err = manager.SelectWorkerPair(ctx, message)
if err != nil {
c.Writer.WriteHeader(http.StatusBadGateway)
@@ -325,11 +338,6 @@ func CommonCompletions(c *gin.Context, extractor PromptExtractor, completionEndp
rawReq["disaggregate_info"] = disagg
// If user didn't provide request_id, generate one
if _, ok := rawReq["request_id"]; !ok {
rawReq["request_id"] = newRequestID()
}
// Re-encode request body and send to P and D
requestBodyData, err = json.Marshal(rawReq)
if err != nil {
@@ -345,6 +353,7 @@ func CommonCompletions(c *gin.Context, extractor PromptExtractor, completionEndp
c.Writer.Header().Set("X-Router-Prefill-URL", prefillURL)
c.Writer.Header().Set("X-Router-Decode-URL", decodeURL)
} else {
logger.Info(ctx, "Parsing completed; starting worker selection.")
// Non-PD mode: use Mixed instance
dest, err := manager.SelectWorker(ctx, "")
if err != nil {
@@ -377,12 +386,13 @@ func CommonCompletions(c *gin.Context, extractor PromptExtractor, completionEndp
if isSplitwise {
backendResp, err = PostToPD(c, decodeURL, prefillURL, requestBodyData, isStream, message, completionEndpoint)
} else {
backendResp, err = GetClientWithRetry(c, requestBodyData, destURL)
backendResp, err = GetClientWithRetry(c, requestBodyData, destURL, completionEndpoint)
}
if err != nil {
c.Writer.WriteHeader(http.StatusBadGateway)
c.Writer.Write([]byte(`{"error": "Failed to connect to backend service"}`))
logger.Info(ctx, "Request completed with an error.")
return
}
defer backendResp.Body.Close()
@@ -423,25 +433,27 @@ func redirect(c *gin.Context, isStream bool, backendResp *http.Response) {
}
if err := scanner.Err(); err != nil {
logger.Error("scanner error: %v", err)
logger.Error(c.Request.Context(), "scanner error: %v", err)
}
} else {
// Compatible with non-stream response
io.Copy(c.Writer, backendResp.Body)
}
logger.Info(c.Request.Context(), "Request completed successfully.")
}
// GetClientWithRetry adds retry
func GetClientWithRetry(c *gin.Context, bodyBytes []byte, destUrl string) (
func GetClientWithRetry(c *gin.Context, bodyBytes []byte, destUrl string, completionEndpoint string) (
backendResp *http.Response, err error) {
// Retry up to maxRetry times
maxRetry := 3
for i := 0; i < maxRetry; i++ {
// If creating request fails, it's network connection error, check if selected node is elastic resource, if so, delete it
backendResp, err = GetClient(c, destUrl, "chat/completions", bodyBytes)
backendResp, err = GetClient(c, destUrl, completionEndpoint, bodyBytes)
if err == nil { // Return latest bucketsize
return backendResp, nil
}
logger.Info(c.Request.Context(), "Request failed, retrying...")
}
return nil, err
}
@@ -265,7 +265,7 @@ func TestGetClientWithRetry(t *testing.T) {
reqBody := []byte(`{"test": "data"}`)
resp, err := GetClientWithRetry(c, reqBody, ts.URL)
resp, err := GetClientWithRetry(c, reqBody, ts.URL, "chat/completions")
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, http.StatusOK, resp.StatusCode)
@@ -293,7 +293,7 @@ func TestGetClientWithRetry(t *testing.T) {
reqBody := []byte(`{"test": "data"}`)
resp, err := GetClientWithRetry(c, reqBody, ts.URL)
resp, err := GetClientWithRetry(c, reqBody, ts.URL, "chat/completions")
assert.Error(t, err)
assert.Nil(t, resp)
})
@@ -314,7 +314,7 @@ func TestGetClientWithRetry(t *testing.T) {
reqBody := []byte(`{"test": "data"}`)
resp, err := GetClientWithRetry(c, reqBody, ts.URL)
resp, err := GetClientWithRetry(c, reqBody, ts.URL, "chat/completions")
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, http.StatusOK, resp.StatusCode)
@@ -26,7 +26,7 @@ type healthMonitorResult struct {
func CheckServiceHealth(ctx context.Context, baseURL string, timeout ...time.Duration) bool {
// Handle empty baseURL
if baseURL == "" {
logger.Error("empty baseURL provided")
logger.Error(ctx, "empty baseURL provided")
return false
}
@@ -42,7 +42,7 @@ func CheckServiceHealth(ctx context.Context, baseURL string, timeout ...time.Dur
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
logger.Error("failed to create request: %v", err)
logger.Error(ctx, "failed to create request: %v", err)
return false
}
@@ -51,7 +51,7 @@ func CheckServiceHealth(ctx context.Context, baseURL string, timeout ...time.Dur
resp, err := client.Do(req)
if err != nil {
logger.Error("failed to send request to %s with error: %v", url, err)
logger.Error(ctx, "failed to send request to %s with error: %v", url, err)
return false
}
defer resp.Body.Close()
@@ -59,7 +59,7 @@ func CheckServiceHealth(ctx context.Context, baseURL string, timeout ...time.Dur
// Read response body
_, err = io.ReadAll(resp.Body)
if err != nil {
logger.Error("failed to read response body: %v", err)
logger.Error(ctx, "failed to read response body: %v", err)
return false
}
@@ -139,9 +139,9 @@ func HealthGenerate(c *gin.Context) {
for res := range results {
// Process each result
if !res.isHealthy {
logger.Warn("Server %s is not healthy", res.url)
logger.Warn(c.Request.Context(), "Server %s is not healthy", res.url)
} else {
logger.Info("Server %s is healthy", res.url)
logger.Info(c.Request.Context(), "Server %s is healthy", res.url)
}
}
@@ -158,26 +158,26 @@ func RemoveServers(ctx context.Context, prefillToRemove []string, decodeToRemove
for _, id := range prefillToRemove {
if worker, exists := DefaultManager.prefillWorkerMap[id]; exists {
delete(DefaultManager.prefillWorkerMap, id)
logger.Info("Removed unhealthy prefill instance: %s", worker.Url)
logger.Info(ctx, "Removed unhealthy prefill instance: %s", worker.Url)
}
}
for _, id := range decodeToRemove {
if worker, exists := DefaultManager.decodeWorkerMap[id]; exists {
delete(DefaultManager.decodeWorkerMap, id)
logger.Info("Removed unhealthy decode instance: %s", worker.Url)
logger.Info(ctx, "Removed unhealthy decode instance: %s", worker.Url)
}
}
for _, id := range mixedToRemove {
if worker, exists := DefaultManager.mixedWorkerMap[id]; exists {
delete(DefaultManager.mixedWorkerMap, id)
logger.Info("Removed unhealthy mixed instance: %s", worker.Url)
logger.Info(ctx, "Removed unhealthy mixed instance: %s", worker.Url)
}
}
}
func ReadServers(ctx context.Context) (prefillInstances, decodeInstances, mixedInstances []string) {
if DefaultManager == nil {
logger.Debug("Healthy instances: prefill=[], decode=[], mixed=[] (DefaultManager is nil)")
logger.Debug(ctx, "Healthy instances: prefill=[], decode=[], mixed=[] (DefaultManager is nil)")
return []string{}, []string{}, []string{}
}
@@ -199,7 +199,7 @@ func ReadServers(ctx context.Context) (prefillInstances, decodeInstances, mixedI
for _, w := range DefaultManager.mixedWorkerMap {
mixedInstances = append(mixedInstances, w.Url)
}
logger.Debug(
logger.Debug(ctx,
"Healthy instances: prefill=%v, decode=%v, mixed=%v",
prefillInstances,
decodeInstances,
@@ -209,7 +209,7 @@ func RegisterInstanceCore(ctx context.Context, rawInstance *InstanceInfo) error
case DECODE:
DefaultManager.decodeWorkerMap[id] = workerInfo
default:
logger.Warn("Instance %s role is unknown", id)
logger.Warn(ctx, "Instance %s role is unknown", id)
}
return nil
@@ -236,7 +236,7 @@ func RegisterInstance(c *gin.Context) {
}
if err := RegisterInstanceCore(c.Request.Context(), &rawInstance); err != nil {
logger.Error("Failed to register instance: %v", err)
logger.Error(c.Request.Context(), "Failed to register instance: %v", err)
// Return different HTTP status codes based on error type
if strings.Contains(err.Error(), "not healthy") {
c.JSON(http.StatusServiceUnavailable, gin.H{
@@ -264,26 +264,26 @@ func RegisterInstancesFromConfig(yamlPath string) {
}
data, err := os.ReadFile(yamlPath)
if err != nil {
logger.Error("Failed to read YAML file %s: %v", yamlPath, err)
logger.Error(context.Background(), "Failed to read YAML file %s: %v", yamlPath, err)
return
}
var config RegisterConfig
if err := yaml.Unmarshal(data, &config); err != nil {
logger.Error("Failed to unmarshal YAML file %s: %v", yamlPath, err)
logger.Error(context.Background(), "Failed to unmarshal YAML file %s: %v", yamlPath, err)
return
}
if len(config.Instances) == 0 {
logger.Info("No instances found in config file %s", yamlPath)
logger.Info(context.Background(), "No instances found in config file %s", yamlPath)
return
}
for i, instanceConfig := range config.Instances {
if err := RegisterInstanceCore(context.Background(), &instanceConfig); err != nil {
logger.Error("Failed to register instance from index %d: %v", i, err)
logger.Error(context.Background(), "Failed to register instance from index %d: %v", i, err)
} else {
logger.Info("Successfully registered instance from index %d", i)
logger.Info(context.Background(), "Successfully registered instance from index %d", i)
}
}
}
@@ -8,7 +8,7 @@ import (
// Logger logger middleware
func Logger() gin.HandlerFunc {
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
logger.Info("[%s] %s %s %d %s %s",
logger.Info(param.Request.Context(), "[%s] %s %s %d %s %s",
param.Method,
param.Path,
param.Request.Proto,
@@ -23,7 +23,7 @@ func Logger() gin.HandlerFunc {
// Recovery recovery middleware
func Recovery() gin.HandlerFunc {
return gin.CustomRecovery(func(c *gin.Context, recovered interface{}) {
logger.Error("Panic recovered: %v", recovered)
logger.Error(c.Request.Context(), "Panic recovered: %v", recovered)
c.JSON(500, gin.H{
"code": 500,
"msg": "Internal server error",
@@ -127,9 +127,9 @@ func SelectWorker(ctx context.Context, workers []string, message string, workerT
}
if workerType == "prefill" {
logger.Info("select worker (prefill): %s, tokens: %d", selectWorkerURL, tokens)
logger.Info(ctx, "select worker (prefill): %s, tokens: %d", selectWorkerURL, tokens)
} else {
logger.Info("select worker (%s): %s, count: %d", workerType, selectWorkerURL, count)
logger.Info(ctx, "select worker (%s): %s, count: %d", workerType, selectWorkerURL, count)
}
return selectWorkerURL, nil
@@ -139,7 +139,7 @@ func SelectWorker(ctx context.Context, workers []string, message string, workerT
func Release(ctx context.Context, url string) {
counter := GetOrCreateCounter(ctx, url)
counter.Dec()
logger.Info("release worker: %s, count: %d", url, counter.Get())
logger.Info(ctx, "release worker: %s, count: %d", url, counter.Get())
}
// GetCounter retrieves the counter for the specified root URL
@@ -182,7 +182,7 @@ func CleanupUnhealthyCounter(ctx context.Context, unhealthyRootURL string) {
delete(DefaultScheduler.IdCounterMap, unhealthyRootURL)
delete(DefaultScheduler.tokenMap, unhealthyRootURL)
logger.Info("After cleanup unhealthy counter: %v", DefaultScheduler.IdCounterMap)
logger.Info(ctx, "After cleanup unhealthy counter: %v", DefaultScheduler.IdCounterMap)
}
// CleanupInvalidCounters removes counters for invalid or unreachable workers
@@ -218,7 +218,7 @@ func CleanupInvalidCounters(ctx context.Context) {
}
}
logger.Info("After cleanup invalid counters: %v", DefaultScheduler.IdCounterMap)
logger.Info(ctx, "After cleanup invalid counters: %v", DefaultScheduler.IdCounterMap)
}
// StartBackupCleanupTask starts a background task for cleaning up invalid counters
@@ -278,5 +278,5 @@ func ReleasePrefillTokens(ctx context.Context, url, message string) {
}
tokenCounter := GetOrCreateTokenCounter(ctx, url)
tokenCounter.Sub(estimateTokens(message))
logger.Info("release prefill tokens: %s, tokens: %d", url, tokenCounter.Get())
logger.Info(ctx, "release prefill tokens: %s, tokens: %d", url, tokenCounter.Get())
}
@@ -65,21 +65,21 @@ func CacheAwarePrefillSelectWorker(ctx context.Context, workers []string, messag
tokens, err := strategy.tokenize(ctx, message)
if err != nil || len(tokens) == 0 {
if err != nil {
logger.Warn("cache-aware prefill: tokenizer failed, fallback to process_tokens: %v", err)
logger.Warn(ctx, "cache-aware prefill: tokenizer failed, fallback to process_tokens: %v", err)
}
return ProcessTokensSelectWorker(ctx, workers, message)
}
// 3) Compute prefix tree hit rate
hitRatios := strategy.cache.Match(tokens, toWorkerSet(workers))
logger.Debug("cache-aware prefill: hashes=%d workers=%d load=%v hit=%v", len(strategy.cache.hasher.prefixHashes(tokens)), len(workers), loads, hitRatios)
logger.Debug(ctx, "cache-aware prefill: hashes=%d workers=%d load=%v hit=%v", len(strategy.cache.hasher.prefixHashes(tokens)), len(workers), loads, hitRatios)
// 4) Compute weighted score from hit rate and load
selected := strategy.chooseByScore(ctx, workers, loads, hitRatios)
// 5) Record prefix
strategy.cache.Record(tokens, selected)
logger.Debug("cache-aware prefill: selected=%s", selected)
logger.Debug(ctx, "cache-aware prefill: selected=%s", selected)
return selected, nil
}
@@ -94,10 +94,10 @@ func (p *prefillCacheStrategy) tokenize(ctx context.Context, message string) ([]
}
tokens, err := p.tokenizer.Tokenize(ctx, message)
if err != nil {
logger.Warn("cache-aware prefill: tokenizer failed, fallback to char tokens: %v", err)
logger.Warn(ctx, "cache-aware prefill: tokenizer failed, fallback to char tokens: %v", err)
return charsToTokens(message), nil
}
logger.Debug("cache-aware prefill: tokenizer tokens=%v", tokens)
logger.Debug(ctx, "cache-aware prefill: tokenizer tokens=%v", tokens)
return tokens, nil
}
@@ -153,7 +153,7 @@ func (p *prefillCacheStrategy) chooseByScore(ctx context.Context, workers []stri
}
score := (100.0-hit)/100*p.hitRatioWeight + loadRatio*p.loadBalanceWeight
logger.Debug("cache-aware score: worker=%s hit=%.1f loadRatio=%.3f score=%.3f", w, hit, loadRatio, score)
logger.Debug(ctx, "cache-aware score: worker=%s hit=%.1f loadRatio=%.3f score=%.3f", w, hit, loadRatio, score)
if score < bestScore {
bestScore = score
@@ -243,7 +243,7 @@ func (c *radixPrefixCache) Match(tokens []int, allowed map[string]struct{}) map[
c.mu.RLock()
node, matched := c.matchPrefixHelper(c.root, hashes)
length := matched
logger.Debug("radix match: hashes=%d matched_len=%d node_children=%d", len(hashes), matched, len(node.children))
logger.Debug(context.Background(), "radix match: hashes=%d matched_len=%d node_children=%d", len(hashes), matched, len(node.children))
for n := node; n != nil; n = n.parent {
ratio := 0
if len(hashes) > 0 {
@@ -291,7 +291,7 @@ func (c *radixPrefixCache) Record(tokens []int, worker string) {
}
n.workers[worker] = now
}
logger.Debug("radix record: worker=%s hashes=%d node_depth=%d", worker, len(hashes), node.contextLen)
logger.Debug(context.Background(), "radix record: worker=%s hashes=%d node_depth=%d", worker, len(hashes), node.contextLen)
}
// evictionWorker periodically evicts inactive nodes
@@ -313,7 +313,7 @@ func (c *radixPrefixCache) evictExpired() {
removed += c.evictSubtreeIfExpired(c.root, childKey, child, now)
}
if removed > 0 {
logger.Debug("radix eviction: removed=%d nodeCount=%d", removed, c.nodeCount)
logger.Debug(context.Background(), "radix eviction: removed=%d nodeCount=%d", removed, c.nodeCount)
}
}
+26 -8
View File
@@ -4,6 +4,7 @@ import (
"log"
"os"
"sync"
"context"
)
var (
@@ -16,6 +17,9 @@ var (
logFile *os.File
)
type contextKey string
const RequestIDKey contextKey = "request_id"
// Init initialize logger
func Init(logLevel, output string) {
once.Do(func() {
@@ -53,28 +57,42 @@ func CloseLogFile() {
}
}
func contextPrefix(ctx context.Context) string {
if ctx == nil {
return ""
}
if rid, ok := ctx.Value(RequestIDKey).(string); ok && rid != "" {
return "[request_id:" + rid + "] "
}
return ""
}
// Info logs informational messages
func Info(format string, v ...interface{}) {
func Info(ctx context.Context, format string, v ...interface{}) {
if level == "debug" || level == "info" {
infoLogger.Printf(format, v...)
prefix := contextPrefix(ctx)
infoLogger.Printf(prefix+format, v...)
}
}
// Error logs error messages
func Error(format string, v ...interface{}) {
errorLogger.Printf(format, v...)
func Error(ctx context.Context, format string, v ...interface{}) {
prefix := contextPrefix(ctx)
errorLogger.Printf(prefix+format, v...)
}
// Warn logs warning messages
func Warn(format string, v ...interface{}) {
func Warn(ctx context.Context, format string, v ...interface{}) {
if level == "debug" || level == "info" || level == "warn" {
warnLogger.Printf(format, v...)
prefix := contextPrefix(ctx)
warnLogger.Printf(prefix+format, v...)
}
}
// Debug logs debug messages
func Debug(format string, v ...interface{}) {
func Debug(ctx context.Context, format string, v ...interface{}) {
if level == "debug" {
debugLogger.Printf(format, v...)
prefix := contextPrefix(ctx)
debugLogger.Printf(prefix+format, v...)
}
}
@@ -2,6 +2,7 @@ package logger
import (
"bytes"
"context"
"os"
"strings"
"testing"
@@ -22,7 +23,12 @@ func TestLoggerInit(t *testing.T) {
_ = os.MkdirAll("logs", 0755)
defer os.RemoveAll("logs")
Init("debug", "file")
// sync.Once prevents re-init, so manually verify file creation logic
f, err := os.OpenFile("logs/router.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
if err != nil {
t.Fatalf("Failed to create log file: %v", err)
}
f.Close()
if _, err := os.Stat("logs/router.log"); os.IsNotExist(err) {
t.Error("Log file should be created")
@@ -64,11 +70,11 @@ func TestLogLevels(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Initialize logger with test level
Init(tt.level, "stdout")
// Directly set package-level variable since sync.Once prevents re-init
level = tt.level
// Capture output for each level separately
testLevel := func(logFunc func(string, ...interface{}), message string) bool {
testLevel := func(logFunc func(context.Context, string, ...interface{}), message string) bool {
var buf bytes.Buffer
oldOutput := infoLogger.Writer()
@@ -77,7 +83,7 @@ func TestLogLevels(t *testing.T) {
warnLogger.SetOutput(&buf)
debugLogger.SetOutput(&buf)
logFunc(message)
logFunc(nil, message)
infoLogger.SetOutput(oldOutput)
errorLogger.SetOutput(oldOutput)
@@ -112,16 +118,64 @@ func TestLogLevels(t *testing.T) {
func TestLogFunctions(t *testing.T) {
var buf bytes.Buffer
Init("debug", "stdout")
level = "debug"
// Redirect output
oldOutput := infoLogger.Writer()
defer func() { infoLogger.SetOutput(oldOutput) }()
infoLogger.SetOutput(&buf)
Info("test %s", "message")
Info(nil, "test %s", "message")
if !strings.Contains(buf.String(), "test message") {
t.Error("Info log should contain the message")
}
// Similar tests for Error, Warn, Debug...
}
func TestContextPrefix(t *testing.T) {
Init("debug", "stdout")
level = "debug"
t.Run("nil context produces no prefix", func(t *testing.T) {
var buf bytes.Buffer
oldOutput := infoLogger.Writer()
defer func() { infoLogger.SetOutput(oldOutput) }()
infoLogger.SetOutput(&buf)
Info(nil, "no prefix here")
output := buf.String()
if strings.Contains(output, "[request_id:") {
t.Errorf("nil context should produce no request_id prefix, got: %s", output)
}
if !strings.Contains(output, "no prefix here") {
t.Errorf("message should be present, got: %s", output)
}
})
t.Run("context without request_id produces [request_id:null]", func(t *testing.T) {
var buf bytes.Buffer
oldOutput := infoLogger.Writer()
defer func() { infoLogger.SetOutput(oldOutput) }()
infoLogger.SetOutput(&buf)
ctx := context.Background()
Info(ctx, "mixed mode log")
output := buf.String()
if !strings.Contains(output, "[request_id:null]") {
t.Errorf("context without request_id should produce [request_id:null], got: %s", output)
}
})
t.Run("context with request_id produces [request_id:xxx]", func(t *testing.T) {
var buf bytes.Buffer
oldOutput := infoLogger.Writer()
defer func() { infoLogger.SetOutput(oldOutput) }()
infoLogger.SetOutput(&buf)
ctx := context.WithValue(context.Background(), RequestIDKey, "test-uuid-123")
Info(ctx, "pd mode log")
output := buf.String()
if !strings.Contains(output, "[request_id:test-uuid-123]") {
t.Errorf("context with request_id should produce [request_id:test-uuid-123], got: %s", output)
}
})
}