mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[Docs] Specify the default strategy (#6728)
* [Docs] Update the document --------- Co-authored-by: mouxin <mouxin@baidu.com>
This commit is contained in:
@@ -190,9 +190,9 @@ server:
|
||||
splitwise: true # true代表开启pd分离模式,false代表开启非pd分离模式
|
||||
|
||||
scheduler:
|
||||
policy: "power_of_two" # 调度策略(可选): random, power_of_two, round_robin, process_tokens, request_num, cache_aware, fd_metrics_score
|
||||
prefill-policy: "cache_aware" # pd分离模式下prefill节点调度策略
|
||||
decode-policy: "fd_metrics_score" # pd分离模式下decode节点调度策略
|
||||
policy: "power_of_two" # 调度策略(可选): random, power_of_two, round_robin, process_tokens, request_num, cache_aware, fd_metrics_score; 默认: request_num
|
||||
prefill-policy: "cache_aware" # pd分离模式下prefill节点调度策略; 默认: process_tokens
|
||||
decode-policy: "fd_metrics_score" # pd分离模式下decode节点调度策略; 默认: request_num
|
||||
eviction-interval-secs: 60 # cache-aware策略清理过期cache的间隔时间
|
||||
balance-abs-threshold: 1 # cache-aware策略绝对阈值
|
||||
balance-rel-threshold: 0.2 # cache-aware策略相对阈值
|
||||
@@ -253,3 +253,19 @@ instances:
|
||||
* metrics_port: 推理实例的metrics端口号。
|
||||
|
||||
其中 `role`、`host_ip` 和 `port` 为必填参数,其余参数为可选。
|
||||
|
||||
## 调度策略说明
|
||||
|
||||
Router 支持以下调度策略,可通过配置文件中的 `policy`(mixed 模式)、`prefill-policy` 和 `decode-policy`(PD 分离模式)字段指定。
|
||||
|
||||
**默认策略**:不配置时,prefill 节点默认使用 `process_tokens`,mixed 和 decode 节点默认使用 `request_num`。
|
||||
|
||||
| 策略名 | 适用场景 | 实现方式 |
|
||||
|--------|----------|----------|
|
||||
| `random` | 通用 | 从所有可用实例中随机选择一个,无状态感知,适合轻量场景。 |
|
||||
| `round_robin` | 通用 | 使用原子计数器对实例列表循环取模,按顺序均匀分发请求。 |
|
||||
| `power_of_two` | 通用 | 随机选取两个实例,比较其当前并发请求数,选择负载较低的一个。 |
|
||||
| `process_tokens` | **prefill(默认)** | 遍历所有实例,选择当前正在处理的 token 数最少的实例,适合 prefill 阶段的长请求负载均衡。 |
|
||||
| `request_num` | **mixed / decode(默认)** | 遍历所有实例,选择当前并发请求数最少的实例,适合 decode 及 mixed 场景的请求均衡。 |
|
||||
| `fd_metrics_score` | mixed / decode | 实时从各实例的 metrics 接口获取 running/waiting 请求数,按 `running + waiting × waitingWeight` 打分,选择得分最低的实例。 |
|
||||
| `cache_aware` | prefill | 基于 Radix Tree 维护各实例的 KV Cache 前缀命中情况,综合命中率与负载打分选择实例;负载严重不均衡时自动回退至 `process_tokens`。 |
|
||||
|
||||
@@ -60,7 +60,7 @@ func main() {
|
||||
|
||||
// Start server
|
||||
addr := ":" + cfg.Server.Port
|
||||
logger.Info("Starting server on %s", addr)
|
||||
logger.Info(context.Background(), "Starting server on %s", addr)
|
||||
if err := r.Run(addr); err != nil {
|
||||
log.Fatalf("Failed to start server: %v", err)
|
||||
}
|
||||
|
||||
@@ -118,5 +118,14 @@ func Load(configPath, listenPort string, isSplitwise bool) (*Config, error) {
|
||||
if cfg.Scheduler.WaitingWeight == 0 {
|
||||
cfg.Scheduler.WaitingWeight = 1
|
||||
}
|
||||
if cfg.Scheduler.Policy == "" {
|
||||
cfg.Scheduler.Policy = "request_num"
|
||||
}
|
||||
if cfg.Scheduler.PrefillPolicy == "" {
|
||||
cfg.Scheduler.PrefillPolicy = "process_tokens"
|
||||
}
|
||||
if cfg.Scheduler.DecodePolicy == "" {
|
||||
cfg.Scheduler.DecodePolicy = "request_num"
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
@@ -232,7 +232,7 @@ func readPrefillRecv(ctx context.Context, url string, isStream bool, message str
|
||||
if !released {
|
||||
scheduler_handler.Release(ctx, url)
|
||||
scheduler_handler.ReleasePrefillTokens(ctx, url, message)
|
||||
logger.Debug("[prefill] release in defer (fallback) url=%s", url)
|
||||
logger.Debug(ctx, "[prefill] release in defer (fallback) url=%s", url)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -245,21 +245,29 @@ func readPrefillRecv(ctx context.Context, url string, isStream bool, message str
|
||||
scheduler_handler.ReleasePrefillTokens(ctx, url, message)
|
||||
released = true
|
||||
|
||||
logger.Debug("[prefill] first chunk received, release scheduler url=%s", url)
|
||||
logger.Info(ctx, "[prefill] first chunk received, release counter url=%s", url)
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
logger.Debug("[prefill] scanner error: %v", err)
|
||||
logger.Debug(ctx, "[prefill] scanner error: %v", err)
|
||||
}
|
||||
} else {
|
||||
_, err := io.Copy(io.Discard, backendResp.Body)
|
||||
if err != nil {
|
||||
logger.Debug("[prefill] copy error: %v", err)
|
||||
logger.Debug(ctx, "[prefill] copy error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getRequestID(ctx context.Context, rawReq map[string]any) string {
|
||||
// If user didn't provide request_id, generate one
|
||||
if _, ok := rawReq["request_id"]; !ok {
|
||||
rawReq["request_id"] = newRequestID()
|
||||
}
|
||||
return rawReq["request_id"].(string)
|
||||
}
|
||||
|
||||
// ChatCompletions implements request forwarding to actual large model inference service
|
||||
func ChatCompletions(c *gin.Context) {
|
||||
completionEndpoint := "chat/completions"
|
||||
@@ -300,9 +308,14 @@ func CommonCompletions(c *gin.Context, extractor PromptExtractor, completionEndp
|
||||
)
|
||||
|
||||
if isSplitwise {
|
||||
requestID := getRequestID(ctx, rawReq)
|
||||
ctx = context.WithValue(ctx, logger.RequestIDKey, requestID)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
|
||||
// PD mode: select instances for Prefill/Decode separately
|
||||
message = extractor(rawReq)
|
||||
|
||||
logger.Info(ctx, "Parsing completed; starting worker selection.")
|
||||
prefillURL, decodeURL, err = manager.SelectWorkerPair(ctx, message)
|
||||
if err != nil {
|
||||
c.Writer.WriteHeader(http.StatusBadGateway)
|
||||
@@ -325,11 +338,6 @@ func CommonCompletions(c *gin.Context, extractor PromptExtractor, completionEndp
|
||||
|
||||
rawReq["disaggregate_info"] = disagg
|
||||
|
||||
// If user didn't provide request_id, generate one
|
||||
if _, ok := rawReq["request_id"]; !ok {
|
||||
rawReq["request_id"] = newRequestID()
|
||||
}
|
||||
|
||||
// Re-encode request body and send to P and D
|
||||
requestBodyData, err = json.Marshal(rawReq)
|
||||
if err != nil {
|
||||
@@ -345,6 +353,7 @@ func CommonCompletions(c *gin.Context, extractor PromptExtractor, completionEndp
|
||||
c.Writer.Header().Set("X-Router-Prefill-URL", prefillURL)
|
||||
c.Writer.Header().Set("X-Router-Decode-URL", decodeURL)
|
||||
} else {
|
||||
logger.Info(ctx, "Parsing completed; starting worker selection.")
|
||||
// Non-PD mode: use Mixed instance
|
||||
dest, err := manager.SelectWorker(ctx, "")
|
||||
if err != nil {
|
||||
@@ -377,12 +386,13 @@ func CommonCompletions(c *gin.Context, extractor PromptExtractor, completionEndp
|
||||
if isSplitwise {
|
||||
backendResp, err = PostToPD(c, decodeURL, prefillURL, requestBodyData, isStream, message, completionEndpoint)
|
||||
} else {
|
||||
backendResp, err = GetClientWithRetry(c, requestBodyData, destURL)
|
||||
backendResp, err = GetClientWithRetry(c, requestBodyData, destURL, completionEndpoint)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
c.Writer.WriteHeader(http.StatusBadGateway)
|
||||
c.Writer.Write([]byte(`{"error": "Failed to connect to backend service"}`))
|
||||
logger.Info(ctx, "Request completed with an error.")
|
||||
return
|
||||
}
|
||||
defer backendResp.Body.Close()
|
||||
@@ -423,25 +433,27 @@ func redirect(c *gin.Context, isStream bool, backendResp *http.Response) {
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
logger.Error("scanner error: %v", err)
|
||||
logger.Error(c.Request.Context(), "scanner error: %v", err)
|
||||
}
|
||||
} else {
|
||||
// Compatible with non-stream response
|
||||
io.Copy(c.Writer, backendResp.Body)
|
||||
}
|
||||
logger.Info(c.Request.Context(), "Request completed successfully.")
|
||||
}
|
||||
|
||||
// GetClientWithRetry adds retry
|
||||
func GetClientWithRetry(c *gin.Context, bodyBytes []byte, destUrl string) (
|
||||
func GetClientWithRetry(c *gin.Context, bodyBytes []byte, destUrl string, completionEndpoint string) (
|
||||
backendResp *http.Response, err error) {
|
||||
// Try the request up to maxRetry (three) times in total
|
||||
maxRetry := 3
|
||||
for i := 0; i < maxRetry; i++ {
|
||||
// If creating request fails, it's network connection error, check if selected node is elastic resource, if so, delete it
|
||||
backendResp, err = GetClient(c, destUrl, "chat/completions", bodyBytes)
|
||||
backendResp, err = GetClient(c, destUrl, completionEndpoint, bodyBytes)
|
||||
if err == nil { // Return latest bucketsize
|
||||
return backendResp, nil
|
||||
}
|
||||
logger.Info(c.Request.Context(), "Request failed, retrying...")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -265,7 +265,7 @@ func TestGetClientWithRetry(t *testing.T) {
|
||||
|
||||
reqBody := []byte(`{"test": "data"}`)
|
||||
|
||||
resp, err := GetClientWithRetry(c, reqBody, ts.URL)
|
||||
resp, err := GetClientWithRetry(c, reqBody, ts.URL, "chat/completions")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
@@ -293,7 +293,7 @@ func TestGetClientWithRetry(t *testing.T) {
|
||||
|
||||
reqBody := []byte(`{"test": "data"}`)
|
||||
|
||||
resp, err := GetClientWithRetry(c, reqBody, ts.URL)
|
||||
resp, err := GetClientWithRetry(c, reqBody, ts.URL, "chat/completions")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, resp)
|
||||
})
|
||||
@@ -314,7 +314,7 @@ func TestGetClientWithRetry(t *testing.T) {
|
||||
|
||||
reqBody := []byte(`{"test": "data"}`)
|
||||
|
||||
resp, err := GetClientWithRetry(c, reqBody, ts.URL)
|
||||
resp, err := GetClientWithRetry(c, reqBody, ts.URL, "chat/completions")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
@@ -26,7 +26,7 @@ type healthMonitorResult struct {
|
||||
func CheckServiceHealth(ctx context.Context, baseURL string, timeout ...time.Duration) bool {
|
||||
// Handle empty baseURL
|
||||
if baseURL == "" {
|
||||
logger.Error("empty baseURL provided")
|
||||
logger.Error(ctx, "empty baseURL provided")
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ func CheckServiceHealth(ctx context.Context, baseURL string, timeout ...time.Dur
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
logger.Error("failed to create request: %v", err)
|
||||
logger.Error(ctx, "failed to create request: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ func CheckServiceHealth(ctx context.Context, baseURL string, timeout ...time.Dur
|
||||
resp, err := client.Do(req)
|
||||
|
||||
if err != nil {
|
||||
logger.Error("failed to send request to %s with error: %v", url, err)
|
||||
logger.Error(ctx, "failed to send request to %s with error: %v", url, err)
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
@@ -59,7 +59,7 @@ func CheckServiceHealth(ctx context.Context, baseURL string, timeout ...time.Dur
|
||||
// Read response body
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
logger.Error("failed to read response body: %v", err)
|
||||
logger.Error(ctx, "failed to read response body: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -139,9 +139,9 @@ func HealthGenerate(c *gin.Context) {
|
||||
for res := range results {
|
||||
// Process each result
|
||||
if !res.isHealthy {
|
||||
logger.Warn("Server %s is not healthy", res.url)
|
||||
logger.Warn(c.Request.Context(), "Server %s is not healthy", res.url)
|
||||
} else {
|
||||
logger.Info("Server %s is healthy", res.url)
|
||||
logger.Info(c.Request.Context(), "Server %s is healthy", res.url)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,26 +158,26 @@ func RemoveServers(ctx context.Context, prefillToRemove []string, decodeToRemove
|
||||
for _, id := range prefillToRemove {
|
||||
if worker, exists := DefaultManager.prefillWorkerMap[id]; exists {
|
||||
delete(DefaultManager.prefillWorkerMap, id)
|
||||
logger.Info("Removed unhealthy prefill instance: %s", worker.Url)
|
||||
logger.Info(ctx, "Removed unhealthy prefill instance: %s", worker.Url)
|
||||
}
|
||||
}
|
||||
for _, id := range decodeToRemove {
|
||||
if worker, exists := DefaultManager.decodeWorkerMap[id]; exists {
|
||||
delete(DefaultManager.decodeWorkerMap, id)
|
||||
logger.Info("Removed unhealthy decode instance: %s", worker.Url)
|
||||
logger.Info(ctx, "Removed unhealthy decode instance: %s", worker.Url)
|
||||
}
|
||||
}
|
||||
for _, id := range mixedToRemove {
|
||||
if worker, exists := DefaultManager.mixedWorkerMap[id]; exists {
|
||||
delete(DefaultManager.mixedWorkerMap, id)
|
||||
logger.Info("Removed unhealthy mixed instance: %s", worker.Url)
|
||||
logger.Info(ctx, "Removed unhealthy mixed instance: %s", worker.Url)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ReadServers(ctx context.Context) (prefillInstances, decodeInstances, mixedInstances []string) {
|
||||
if DefaultManager == nil {
|
||||
logger.Debug("Healthy instances: prefill=[], decode=[], mixed=[] (DefaultManager is nil)")
|
||||
logger.Debug(ctx, "Healthy instances: prefill=[], decode=[], mixed=[] (DefaultManager is nil)")
|
||||
return []string{}, []string{}, []string{}
|
||||
}
|
||||
|
||||
@@ -199,7 +199,7 @@ func ReadServers(ctx context.Context) (prefillInstances, decodeInstances, mixedI
|
||||
for _, w := range DefaultManager.mixedWorkerMap {
|
||||
mixedInstances = append(mixedInstances, w.Url)
|
||||
}
|
||||
logger.Debug(
|
||||
logger.Debug(ctx,
|
||||
"Healthy instances: prefill=%v, decode=%v, mixed=%v",
|
||||
prefillInstances,
|
||||
decodeInstances,
|
||||
|
||||
@@ -209,7 +209,7 @@ func RegisterInstanceCore(ctx context.Context, rawInstance *InstanceInfo) error
|
||||
case DECODE:
|
||||
DefaultManager.decodeWorkerMap[id] = workerInfo
|
||||
default:
|
||||
logger.Warn("Instance %s role is unknown", id)
|
||||
logger.Warn(ctx, "Instance %s role is unknown", id)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -236,7 +236,7 @@ func RegisterInstance(c *gin.Context) {
|
||||
}
|
||||
|
||||
if err := RegisterInstanceCore(c.Request.Context(), &rawInstance); err != nil {
|
||||
logger.Error("Failed to register instance: %v", err)
|
||||
logger.Error(c.Request.Context(), "Failed to register instance: %v", err)
|
||||
// Return different HTTP status codes based on error type
|
||||
if strings.Contains(err.Error(), "not healthy") {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||
@@ -264,26 +264,26 @@ func RegisterInstancesFromConfig(yamlPath string) {
|
||||
}
|
||||
data, err := os.ReadFile(yamlPath)
|
||||
if err != nil {
|
||||
logger.Error("Failed to read YAML file %s: %v", yamlPath, err)
|
||||
logger.Error(context.Background(), "Failed to read YAML file %s: %v", yamlPath, err)
|
||||
return
|
||||
}
|
||||
|
||||
var config RegisterConfig
|
||||
if err := yaml.Unmarshal(data, &config); err != nil {
|
||||
logger.Error("Failed to unmarshal YAML file %s: %v", yamlPath, err)
|
||||
logger.Error(context.Background(), "Failed to unmarshal YAML file %s: %v", yamlPath, err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(config.Instances) == 0 {
|
||||
logger.Info("No instances found in config file %s", yamlPath)
|
||||
logger.Info(context.Background(), "No instances found in config file %s", yamlPath)
|
||||
return
|
||||
}
|
||||
|
||||
for i, instanceConfig := range config.Instances {
|
||||
if err := RegisterInstanceCore(context.Background(), &instanceConfig); err != nil {
|
||||
logger.Error("Failed to register instance from index %d: %v", i, err)
|
||||
logger.Error(context.Background(), "Failed to register instance from index %d: %v", i, err)
|
||||
} else {
|
||||
logger.Info("Successfully registered instance from index %d", i)
|
||||
logger.Info(context.Background(), "Successfully registered instance from index %d", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
// Logger logger middleware
|
||||
func Logger() gin.HandlerFunc {
|
||||
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||
logger.Info("[%s] %s %s %d %s %s",
|
||||
logger.Info(param.Request.Context(), "[%s] %s %s %d %s %s",
|
||||
param.Method,
|
||||
param.Path,
|
||||
param.Request.Proto,
|
||||
@@ -23,7 +23,7 @@ func Logger() gin.HandlerFunc {
|
||||
// Recovery recovery middleware
|
||||
func Recovery() gin.HandlerFunc {
|
||||
return gin.CustomRecovery(func(c *gin.Context, recovered interface{}) {
|
||||
logger.Error("Panic recovered: %v", recovered)
|
||||
logger.Error(c.Request.Context(), "Panic recovered: %v", recovered)
|
||||
c.JSON(500, gin.H{
|
||||
"code": 500,
|
||||
"msg": "Internal server error",
|
||||
|
||||
@@ -127,9 +127,9 @@ func SelectWorker(ctx context.Context, workers []string, message string, workerT
|
||||
}
|
||||
|
||||
if workerType == "prefill" {
|
||||
logger.Info("select worker (prefill): %s, tokens: %d", selectWorkerURL, tokens)
|
||||
logger.Info(ctx, "select worker (prefill): %s, tokens: %d", selectWorkerURL, tokens)
|
||||
} else {
|
||||
logger.Info("select worker (%s): %s, count: %d", workerType, selectWorkerURL, count)
|
||||
logger.Info(ctx, "select worker (%s): %s, count: %d", workerType, selectWorkerURL, count)
|
||||
}
|
||||
|
||||
return selectWorkerURL, nil
|
||||
@@ -139,7 +139,7 @@ func SelectWorker(ctx context.Context, workers []string, message string, workerT
|
||||
func Release(ctx context.Context, url string) {
|
||||
counter := GetOrCreateCounter(ctx, url)
|
||||
counter.Dec()
|
||||
logger.Info("release worker: %s, count: %d", url, counter.Get())
|
||||
logger.Info(ctx, "release worker: %s, count: %d", url, counter.Get())
|
||||
}
|
||||
|
||||
// GetCounter retrieves the counter for the specified root URL
|
||||
@@ -182,7 +182,7 @@ func CleanupUnhealthyCounter(ctx context.Context, unhealthyRootURL string) {
|
||||
|
||||
delete(DefaultScheduler.IdCounterMap, unhealthyRootURL)
|
||||
delete(DefaultScheduler.tokenMap, unhealthyRootURL)
|
||||
logger.Info("After cleanup unhealthy counter: %v", DefaultScheduler.IdCounterMap)
|
||||
logger.Info(ctx, "After cleanup unhealthy counter: %v", DefaultScheduler.IdCounterMap)
|
||||
}
|
||||
|
||||
// CleanupInvalidCounters removes counters for invalid or unreachable workers
|
||||
@@ -218,7 +218,7 @@ func CleanupInvalidCounters(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("After cleanup invalid counters: %v", DefaultScheduler.IdCounterMap)
|
||||
logger.Info(ctx, "After cleanup invalid counters: %v", DefaultScheduler.IdCounterMap)
|
||||
}
|
||||
|
||||
// StartBackupCleanupTask starts a background task for cleaning up invalid counters
|
||||
@@ -278,5 +278,5 @@ func ReleasePrefillTokens(ctx context.Context, url, message string) {
|
||||
}
|
||||
tokenCounter := GetOrCreateTokenCounter(ctx, url)
|
||||
tokenCounter.Sub(estimateTokens(message))
|
||||
logger.Info("release prefill tokens: %s, tokens: %d", url, tokenCounter.Get())
|
||||
logger.Info(ctx, "release prefill tokens: %s, tokens: %d", url, tokenCounter.Get())
|
||||
}
|
||||
|
||||
@@ -65,21 +65,21 @@ func CacheAwarePrefillSelectWorker(ctx context.Context, workers []string, messag
|
||||
tokens, err := strategy.tokenize(ctx, message)
|
||||
if err != nil || len(tokens) == 0 {
|
||||
if err != nil {
|
||||
logger.Warn("cache-aware prefill: tokenizer failed, fallback to process_tokens: %v", err)
|
||||
logger.Warn(ctx, "cache-aware prefill: tokenizer failed, fallback to process_tokens: %v", err)
|
||||
}
|
||||
return ProcessTokensSelectWorker(ctx, workers, message)
|
||||
}
|
||||
|
||||
// 3) Compute prefix tree hit rate
|
||||
hitRatios := strategy.cache.Match(tokens, toWorkerSet(workers))
|
||||
logger.Debug("cache-aware prefill: hashes=%d workers=%d load=%v hit=%v", len(strategy.cache.hasher.prefixHashes(tokens)), len(workers), loads, hitRatios)
|
||||
logger.Debug(ctx, "cache-aware prefill: hashes=%d workers=%d load=%v hit=%v", len(strategy.cache.hasher.prefixHashes(tokens)), len(workers), loads, hitRatios)
|
||||
|
||||
// 4) Compute weighted score from hit rate and load
|
||||
selected := strategy.chooseByScore(ctx, workers, loads, hitRatios)
|
||||
|
||||
// 5) Record prefix
|
||||
strategy.cache.Record(tokens, selected)
|
||||
logger.Debug("cache-aware prefill: selected=%s", selected)
|
||||
logger.Debug(ctx, "cache-aware prefill: selected=%s", selected)
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
@@ -94,10 +94,10 @@ func (p *prefillCacheStrategy) tokenize(ctx context.Context, message string) ([]
|
||||
}
|
||||
tokens, err := p.tokenizer.Tokenize(ctx, message)
|
||||
if err != nil {
|
||||
logger.Warn("cache-aware prefill: tokenizer failed, fallback to char tokens: %v", err)
|
||||
logger.Warn(ctx, "cache-aware prefill: tokenizer failed, fallback to char tokens: %v", err)
|
||||
return charsToTokens(message), nil
|
||||
}
|
||||
logger.Debug("cache-aware prefill: tokenizer tokens=%v", tokens)
|
||||
logger.Debug(ctx, "cache-aware prefill: tokenizer tokens=%v", tokens)
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
@@ -153,7 +153,7 @@ func (p *prefillCacheStrategy) chooseByScore(ctx context.Context, workers []stri
|
||||
}
|
||||
|
||||
score := (100.0-hit)/100*p.hitRatioWeight + loadRatio*p.loadBalanceWeight
|
||||
logger.Debug("cache-aware score: worker=%s hit=%.1f loadRatio=%.3f score=%.3f", w, hit, loadRatio, score)
|
||||
logger.Debug(ctx, "cache-aware score: worker=%s hit=%.1f loadRatio=%.3f score=%.3f", w, hit, loadRatio, score)
|
||||
|
||||
if score < bestScore {
|
||||
bestScore = score
|
||||
@@ -243,7 +243,7 @@ func (c *radixPrefixCache) Match(tokens []int, allowed map[string]struct{}) map[
|
||||
c.mu.RLock()
|
||||
node, matched := c.matchPrefixHelper(c.root, hashes)
|
||||
length := matched
|
||||
logger.Debug("radix match: hashes=%d matched_len=%d node_children=%d", len(hashes), matched, len(node.children))
|
||||
logger.Debug(context.Background(), "radix match: hashes=%d matched_len=%d node_children=%d", len(hashes), matched, len(node.children))
|
||||
for n := node; n != nil; n = n.parent {
|
||||
ratio := 0
|
||||
if len(hashes) > 0 {
|
||||
@@ -291,7 +291,7 @@ func (c *radixPrefixCache) Record(tokens []int, worker string) {
|
||||
}
|
||||
n.workers[worker] = now
|
||||
}
|
||||
logger.Debug("radix record: worker=%s hashes=%d node_depth=%d", worker, len(hashes), node.contextLen)
|
||||
logger.Debug(context.Background(), "radix record: worker=%s hashes=%d node_depth=%d", worker, len(hashes), node.contextLen)
|
||||
}
|
||||
|
||||
// evictionWorker periodically evicts inactive nodes
|
||||
@@ -313,7 +313,7 @@ func (c *radixPrefixCache) evictExpired() {
|
||||
removed += c.evictSubtreeIfExpired(c.root, childKey, child, now)
|
||||
}
|
||||
if removed > 0 {
|
||||
logger.Debug("radix eviction: removed=%d nodeCount=%d", removed, c.nodeCount)
|
||||
logger.Debug(context.Background(), "radix eviction: removed=%d nodeCount=%d", removed, c.nodeCount)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
"sync"
|
||||
"context"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -16,6 +17,9 @@ var (
|
||||
logFile *os.File
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
const RequestIDKey contextKey = "request_id"
|
||||
|
||||
// Init initialize logger
|
||||
func Init(logLevel, output string) {
|
||||
once.Do(func() {
|
||||
@@ -53,28 +57,42 @@ func CloseLogFile() {
|
||||
}
|
||||
}
|
||||
|
||||
func contextPrefix(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
if rid, ok := ctx.Value(RequestIDKey).(string); ok && rid != "" {
|
||||
return "[request_id:" + rid + "] "
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Info logs informational messages
|
||||
func Info(format string, v ...interface{}) {
|
||||
func Info(ctx context.Context, format string, v ...interface{}) {
|
||||
if level == "debug" || level == "info" {
|
||||
infoLogger.Printf(format, v...)
|
||||
prefix := contextPrefix(ctx)
|
||||
infoLogger.Printf(prefix+format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
// Error logs error messages
|
||||
func Error(format string, v ...interface{}) {
|
||||
errorLogger.Printf(format, v...)
|
||||
func Error(ctx context.Context, format string, v ...interface{}) {
|
||||
prefix := contextPrefix(ctx)
|
||||
errorLogger.Printf(prefix+format, v...)
|
||||
}
|
||||
|
||||
// Warn logs warning messages
|
||||
func Warn(format string, v ...interface{}) {
|
||||
func Warn(ctx context.Context, format string, v ...interface{}) {
|
||||
if level == "debug" || level == "info" || level == "warn" {
|
||||
warnLogger.Printf(format, v...)
|
||||
prefix := contextPrefix(ctx)
|
||||
warnLogger.Printf(prefix+format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
// Debug logs debug messages
|
||||
func Debug(format string, v ...interface{}) {
|
||||
func Debug(ctx context.Context, format string, v ...interface{}) {
|
||||
if level == "debug" {
|
||||
debugLogger.Printf(format, v...)
|
||||
prefix := contextPrefix(ctx)
|
||||
debugLogger.Printf(prefix+format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package logger
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -22,7 +23,12 @@ func TestLoggerInit(t *testing.T) {
|
||||
_ = os.MkdirAll("logs", 0755)
|
||||
defer os.RemoveAll("logs")
|
||||
|
||||
Init("debug", "file")
|
||||
// sync.Once prevents re-init, so manually verify file creation logic
|
||||
f, err := os.OpenFile("logs/router.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create log file: %v", err)
|
||||
}
|
||||
f.Close()
|
||||
|
||||
if _, err := os.Stat("logs/router.log"); os.IsNotExist(err) {
|
||||
t.Error("Log file should be created")
|
||||
@@ -64,11 +70,11 @@ func TestLogLevels(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Initialize logger with test level
|
||||
Init(tt.level, "stdout")
|
||||
// Directly set package-level variable since sync.Once prevents re-init
|
||||
level = tt.level
|
||||
|
||||
// Capture output for each level separately
|
||||
testLevel := func(logFunc func(string, ...interface{}), message string) bool {
|
||||
testLevel := func(logFunc func(context.Context, string, ...interface{}), message string) bool {
|
||||
var buf bytes.Buffer
|
||||
oldOutput := infoLogger.Writer()
|
||||
|
||||
@@ -77,7 +83,7 @@ func TestLogLevels(t *testing.T) {
|
||||
warnLogger.SetOutput(&buf)
|
||||
debugLogger.SetOutput(&buf)
|
||||
|
||||
logFunc(message)
|
||||
logFunc(nil, message)
|
||||
|
||||
infoLogger.SetOutput(oldOutput)
|
||||
errorLogger.SetOutput(oldOutput)
|
||||
@@ -112,16 +118,64 @@ func TestLogLevels(t *testing.T) {
|
||||
func TestLogFunctions(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
Init("debug", "stdout")
|
||||
level = "debug"
|
||||
|
||||
// Redirect output
|
||||
oldOutput := infoLogger.Writer()
|
||||
defer func() { infoLogger.SetOutput(oldOutput) }()
|
||||
infoLogger.SetOutput(&buf)
|
||||
|
||||
Info("test %s", "message")
|
||||
Info(nil, "test %s", "message")
|
||||
if !strings.Contains(buf.String(), "test message") {
|
||||
t.Error("Info log should contain the message")
|
||||
}
|
||||
|
||||
// Similar tests for Error, Warn, Debug...
|
||||
}
|
||||
|
||||
func TestContextPrefix(t *testing.T) {
|
||||
Init("debug", "stdout")
|
||||
level = "debug"
|
||||
|
||||
t.Run("nil context produces no prefix", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
oldOutput := infoLogger.Writer()
|
||||
defer func() { infoLogger.SetOutput(oldOutput) }()
|
||||
infoLogger.SetOutput(&buf)
|
||||
|
||||
Info(nil, "no prefix here")
|
||||
output := buf.String()
|
||||
if strings.Contains(output, "[request_id:") {
|
||||
t.Errorf("nil context should produce no request_id prefix, got: %s", output)
|
||||
}
|
||||
if !strings.Contains(output, "no prefix here") {
|
||||
t.Errorf("message should be present, got: %s", output)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("context without request_id produces [request_id:null]", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
oldOutput := infoLogger.Writer()
|
||||
defer func() { infoLogger.SetOutput(oldOutput) }()
|
||||
infoLogger.SetOutput(&buf)
|
||||
|
||||
ctx := context.Background()
|
||||
Info(ctx, "mixed mode log")
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "[request_id:null]") {
|
||||
t.Errorf("context without request_id should produce [request_id:null], got: %s", output)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("context with request_id produces [request_id:xxx]", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
oldOutput := infoLogger.Writer()
|
||||
defer func() { infoLogger.SetOutput(oldOutput) }()
|
||||
infoLogger.SetOutput(&buf)
|
||||
|
||||
ctx := context.WithValue(context.Background(), RequestIDKey, "test-uuid-123")
|
||||
Info(ctx, "pd mode log")
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "[request_id:test-uuid-123]") {
|
||||
t.Errorf("context with request_id should produce [request_id:test-uuid-123], got: %s", output)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user