Files
FastDeploy/fastdeploy/golang_router/internal/scheduler/handler/handler.go
T
mouxin 6cae9b1f50 [Feature] Config eviction_duration (#7125)
* [Feature] Config eviction_duration

* [Feature] Config eviction_duration

* [Feature] Config eviction_duration

* [Feature] Config eviction_duration

---------

Co-authored-by: mouxin <mouxin@baidu.com>
2026-04-01 16:46:21 +08:00

375 lines
12 KiB
Go

package handler
import (
"context"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
"unicode/utf8"
common "github.com/PaddlePaddle/FastDeploy/router/internal/common"
"github.com/PaddlePaddle/FastDeploy/router/internal/config"
scheduler_common "github.com/PaddlePaddle/FastDeploy/router/internal/scheduler/common"
"github.com/PaddlePaddle/FastDeploy/router/pkg/logger"
)
type Scheduler struct {
policy string
prefillPolicy string
decodePolicy string
IdCounterMap map[string]*scheduler_common.Counter
tokenMap map[string]*scheduler_common.TokenCounter
managerAPI common.ManagerAPI
prefillCache *prefillCacheStrategy
mu sync.RWMutex
}
// CounterPolicy carries counter state that is shared across worker
// selections.
type CounterPolicy struct {
	// workerType is overwritten by SelectWorker on every call with
	// "prefill", "decode" or "mixed".
	workerType string

	// NOTE(review): presumably the cursors advanced by the counter-based
	// selection strategies (general and prefill respectively) — confirm
	// against the strategy implementations, which are not in this file.
	counter        atomic.Uint64
	prefillCounter atomic.Uint64
}
// Package-level scheduler state. All three values are assigned by Init.
var (
	// DefaultScheduler is the scheduler instance the package-level
	// functions operate on; it is nil until Init has run.
	DefaultScheduler *Scheduler

	// DefaultCounterPolicy is the shared counter state; SelectWorker
	// records the current worker type on it.
	DefaultCounterPolicy *CounterPolicy

	// waitingWeight is copied from cfg.Scheduler.WaitingWeight.
	waitingWeight float64
)
// Init builds the package-level scheduler state from cfg and managerAPI.
//
// It replaces DefaultScheduler (with fresh, empty counter maps) and
// DefaultCounterPolicy, and stores the configured waiting weight.
func Init(cfg *config.Config, managerAPI common.ManagerAPI) {
	// The configuration expresses these as floating-point seconds and
	// minutes; convert them to time.Duration once, up front.
	tokenizerTimeout := time.Duration(cfg.Scheduler.TokenizerTimeoutSecs * float64(time.Second))
	evictionDuration := time.Duration(cfg.Scheduler.EvictionDurationMins * float64(time.Minute))

	snapshot := &schedulerConfigSnapshot{
		balanceAbsThreshold: cfg.Scheduler.BalanceAbsThreshold,
		balanceRelThreshold: cfg.Scheduler.BalanceRelThreshold,
		hitRatioWeight:      cfg.Scheduler.HitRatioWeight,
		loadBalanceWeight:   cfg.Scheduler.LoadBalanceWeight,
		cacheBlockSize:      cfg.Scheduler.CacheBlockSize,
		tokenizerURL:        cfg.Scheduler.TokenizerURL,
		tokenizerTimeout:    tokenizerTimeout,
		evictionDuration:    evictionDuration,
	}

	DefaultScheduler = &Scheduler{
		policy:        cfg.Scheduler.Policy,
		prefillPolicy: cfg.Scheduler.PrefillPolicy,
		decodePolicy:  cfg.Scheduler.DecodePolicy,
		IdCounterMap:  make(map[string]*scheduler_common.Counter),
		tokenMap:      make(map[string]*scheduler_common.TokenCounter),
		managerAPI:    managerAPI,
		prefillCache:  newPrefillCacheStrategy(snapshot),
	}
	DefaultCounterPolicy = &CounterPolicy{}
	waitingWeight = cfg.Scheduler.WaitingWeight
}
// SelectWorker picks one worker from workers for a request of the given
// workerType and records the request against that worker's load counters.
//
// The policy is chosen by workerType: "prefill" uses the prefill policy,
// "decode" uses the decode policy, and any other value uses the general
// policy. Unknown policy names fall back to random selection.
//
// The returned URL always carries an "http://" or "https://" scheme; "http://"
// is prepended when the strategy returns a bare address. That normalized URL
// is the key of the request counter incremented here, so it is the value that
// must later be passed to Release. For "prefill" and "mixed" worker types with
// a non-empty message, the estimated token count of message is also added to
// the worker's token counter.
//
// An error is returned when workers is empty, when Init has not been called,
// or when the selected strategy fails.
func SelectWorker(ctx context.Context, workers []string, message string, workerType string) (string, error) {
	if len(workers) == 0 {
		return "", fmt.Errorf("no healthy workers available")
	}
	// Guard against use before Init instead of panicking on a nil pointer,
	// matching the nil checks in Release and the cleanup functions.
	if DefaultScheduler == nil || DefaultCounterPolicy == nil {
		return "", fmt.Errorf("select worker failed: scheduler not initialized")
	}

	// NOTE(review): workerType is written to shared package state without
	// synchronization; if SelectWorker is called from multiple goroutines
	// this is a data race — confirm how the strategies read it.
	var policy string
	switch workerType {
	case "prefill":
		policy = DefaultScheduler.prefillPolicy
		DefaultCounterPolicy.workerType = "prefill"
	case "decode":
		policy = DefaultScheduler.decodePolicy
		DefaultCounterPolicy.workerType = "decode"
	default:
		policy = DefaultScheduler.policy
		DefaultCounterPolicy.workerType = "mixed"
	}

	var strategyFunc scheduler_common.SelectStrategyFunc
	switch policy {
	case "random":
		strategyFunc = RandomSelectWorker
	case "round_robin":
		strategyFunc = RoundRobinSelectWorker
	case "power_of_two":
		strategyFunc = PowerOfTwoSelectWorker
	case "process_tokens":
		// Prefill: prioritize the instance with the smallest number of tokens currently being processed
		strategyFunc = ProcessTokensSelectWorker
	case "request_num":
		// Decode/mixed: prioritize the instance with the smallest number of current requests
		strategyFunc = RequestNumSelectWorker
	case "fd_metrics_score":
		strategyFunc = FDMetricsScoreSelectWorker
	case "fd_remote_metrics_score":
		strategyFunc = FDRemoteMetricsScoreSelectWorker
	case "cache_aware":
		strategyFunc = CacheAwarePrefillSelectWorker
	case "remote_cache_aware":
		strategyFunc = RemoteCacheAwarePrefillSelectWorker
	default:
		strategyFunc = RandomSelectWorker
	}

	selectWorkerURL, err := strategyFunc(ctx, workers, message)
	if err != nil {
		// Report the policy that was actually selected for this worker
		// type, not the general policy.
		return "", fmt.Errorf("select worker failed [policy: %s]: %w", policy, err)
	}

	if !strings.HasPrefix(selectWorkerURL, "http://") && !strings.HasPrefix(selectWorkerURL, "https://") {
		selectWorkerURL = "http://" + selectWorkerURL
	}

	// 1) All node types: request concurrency count (request_num)
	counter := GetOrCreateCounter(ctx, selectWorkerURL)
	counter.Inc()
	count := counter.Get()

	// 2) Prefill/mixed: current token processing count (process_tokens)
	var tokens uint64
	if (workerType == "prefill" || workerType == "mixed") && message != "" {
		tokenCounter := GetOrCreateTokenCounter(ctx, selectWorkerURL)
		tokenCounter.Add(estimateTokens(message))
		tokens = tokenCounter.Get()
	}

	if workerType == "prefill" {
		logger.Info(ctx, "select worker (prefill): %s, tokens: %d", selectWorkerURL, tokens)
	} else {
		logger.Info(ctx, "select worker (%s): %s, count: %d", workerType, selectWorkerURL, count)
	}
	return selectWorkerURL, nil
}
// Release decrements the request counter of the worker identified by url.
//
// The counter is looked up with GetCounter rather than GetOrCreateCounter so
// that releasing a worker whose counter was already cleaned up does not
// re-create an entry. A missing counter, or a counter whose Dec reports
// failure, is logged as a warning and otherwise ignored. The call is a no-op
// when the scheduler has not been initialized.
func Release(ctx context.Context, url string) {
	if DefaultScheduler == nil {
		return
	}

	workerCounter, found := GetCounter(ctx, url)
	switch {
	case !found:
		logger.Warn(ctx, "release worker: %s skipped, counter already cleaned up", url)
	case !workerCounter.Dec():
		logger.Warn(ctx, "release worker: %s skipped, counter already zero (possible double-release)", url)
	default:
		logger.Info(ctx, "release worker: %s, count: %d", url, workerCounter.Get())
	}
}
// GetCounter looks up the request counter registered for rootURL.
// The boolean result reports whether an entry exists; the lookup is done
// under the scheduler's read lock and never creates an entry.
func GetCounter(ctx context.Context, rootURL string) (*scheduler_common.Counter, bool) {
	sched := DefaultScheduler
	sched.mu.RLock()
	defer sched.mu.RUnlock()

	found, ok := sched.IdCounterMap[rootURL]
	return found, ok
}
// GetOrCreateCounter returns the request counter for url, registering a new
// zero-valued counter when none exists yet.
func GetOrCreateCounter(ctx context.Context, url string) *scheduler_common.Counter {
	// Fast path: the counter already exists (read lock only).
	if existing, ok := GetCounter(ctx, url); ok {
		return existing
	}

	DefaultScheduler.mu.Lock()
	defer DefaultScheduler.mu.Unlock()

	// Look again under the write lock: another goroutine may have inserted
	// the counter between the read-locked lookup and acquiring this lock.
	result, ok := DefaultScheduler.IdCounterMap[url]
	if !ok {
		result = &scheduler_common.Counter{}
		DefaultScheduler.IdCounterMap[url] = result
	}
	return result
}
// CleanupUnhealthyCounter drops the request counter and the token counter of
// unhealthyRootURL, each only if its current value is not greater than zero.
// A counter that still reports a positive value is kept (and logged) so that
// later releases for inflight requests still find it.
//
// The call is a no-op for an empty URL or an uninitialized scheduler.
func CleanupUnhealthyCounter(ctx context.Context, unhealthyRootURL string) {
	if unhealthyRootURL == "" || DefaultScheduler == nil {
		return
	}

	DefaultScheduler.mu.Lock()
	defer DefaultScheduler.mu.Unlock()

	if reqCounter, tracked := DefaultScheduler.IdCounterMap[unhealthyRootURL]; tracked {
		switch {
		case reqCounter.Get() > 0:
			logger.Info(ctx, "unhealthy worker counter preserved (inflight requests): %s, count: %d", unhealthyRootURL, reqCounter.Get())
		default:
			delete(DefaultScheduler.IdCounterMap, unhealthyRootURL)
			logger.Info(ctx, "cleanup unhealthy worker counter: %s", unhealthyRootURL)
		}
	}

	if tokCounter, tracked := DefaultScheduler.tokenMap[unhealthyRootURL]; tracked {
		switch {
		case tokCounter.Get() > 0:
			logger.Info(ctx, "unhealthy worker token counter preserved (inflight requests): %s, tokens: %d", unhealthyRootURL, tokCounter.Get())
		default:
			delete(DefaultScheduler.tokenMap, unhealthyRootURL)
			logger.Info(ctx, "cleanup unhealthy worker token counter: %s", unhealthyRootURL)
		}
	}
}
// CleanupInvalidCounters sweeps both counter maps and removes the entries of
// every worker URL that the manager does not currently report as healthy,
// provided the entry's counter has no remaining load. Request counters that
// are still positive are kept and reported in the log; token counters are
// removed only when they read exactly zero.
//
// Nothing happens when the scheduler or its manager API is unset, or when the
// manager reports no healthy URLs at all.
func CleanupInvalidCounters(ctx context.Context) {
	if DefaultScheduler == nil || DefaultScheduler.managerAPI == nil {
		return
	}

	healthyURLs := DefaultScheduler.managerAPI.GetHealthyURLs(ctx)
	if len(healthyURLs) == 0 {
		// NOTE(review): presumably this avoids wiping every counter when the
		// healthy list is momentarily empty — confirm intent.
		return
	}

	healthy := make(map[string]struct{}, len(healthyURLs))
	for _, healthyURL := range healthyURLs {
		healthy[healthyURL] = struct{}{}
	}

	DefaultScheduler.mu.Lock()
	defer DefaultScheduler.mu.Unlock()

	var removed, preserved []string
	for rootURL, reqCounter := range DefaultScheduler.IdCounterMap {
		if _, ok := healthy[rootURL]; ok {
			continue
		}
		if reqCounter.Get() > 0 {
			preserved = append(preserved, rootURL)
			continue
		}
		delete(DefaultScheduler.IdCounterMap, rootURL)
		removed = append(removed, rootURL)
	}

	for rootURL, tokCounter := range DefaultScheduler.tokenMap {
		_, ok := healthy[rootURL]
		if !ok && tokCounter.Get() == 0 {
			delete(DefaultScheduler.tokenMap, rootURL)
		}
	}

	if len(removed) > 0 {
		logger.Info(ctx, "removed counters for %d unhealthy workers: %v", len(removed), removed)
	}
	if len(preserved) > 0 {
		logger.Info(ctx, "preserved counters for %d workers with inflight requests: %v", len(preserved), preserved)
	}
}
// StartBackupCleanupTask runs CleanupInvalidCounters every interval seconds
// until ctx is done. It blocks for the lifetime of the task, so callers that
// want it in the background must start it in its own goroutine.
//
// interval is in (possibly fractional) seconds. A value that does not convert
// to a positive duration disables the task: the function returns immediately
// instead of letting time.NewTicker panic.
func StartBackupCleanupTask(ctx context.Context, interval float64) {
	period := time.Duration(interval * float64(time.Second))
	if period <= 0 {
		// time.NewTicker panics on a non-positive duration, which would
		// take the whole process down from a misconfigured interval.
		return
	}

	ticker := time.NewTicker(period)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			// Context cancelled or timed out: stop the cleanup task.
			return
		case <-ticker.C:
			CleanupInvalidCounters(ctx)
		}
	}
}
// GetTokenCounter looks up the token counter registered for rootURL.
// The boolean result reports whether an entry exists; the lookup is done
// under the scheduler's read lock and never creates an entry.
func GetTokenCounter(ctx context.Context, rootURL string) (*scheduler_common.TokenCounter, bool) {
	sched := DefaultScheduler
	sched.mu.RLock()
	defer sched.mu.RUnlock()

	found, ok := sched.tokenMap[rootURL]
	return found, ok
}
// GetOrCreateTokenCounter returns the token counter for url, registering a
// new zero-valued counter when none exists yet.
func GetOrCreateTokenCounter(ctx context.Context, url string) *scheduler_common.TokenCounter {
	// Fast path: the counter already exists (read lock only).
	if existing, ok := GetTokenCounter(ctx, url); ok {
		return existing
	}

	DefaultScheduler.mu.Lock()
	defer DefaultScheduler.mu.Unlock()

	// Look again under the write lock so a counter inserted by another
	// goroutine in the meantime is not overwritten.
	result, ok := DefaultScheduler.tokenMap[url]
	if !ok {
		result = &scheduler_common.TokenCounter{}
		DefaultScheduler.tokenMap[url] = result
	}
	return result
}
// estimateTokens returns a rough token estimate for message: twice the number
// of runes it contains. An empty message yields 0.
func estimateTokens(message string) uint64 {
	var estimate uint64
	if len(message) > 0 {
		// Count runes, not bytes, so multi-byte characters count once.
		estimate = uint64(2 * utf8.RuneCountInString(message))
	}
	return estimate
}
// ReleasePrefillTokens subtracts the estimated token count of message from
// the token counter of the worker identified by url.
//
// The counter is looked up with GetTokenCounter rather than
// GetOrCreateTokenCounter so that no entry is created for a worker that is
// not tracked; in that case nothing happens. The call is also a no-op when
// the scheduler is uninitialized or when url or message is empty.
func ReleasePrefillTokens(ctx context.Context, url, message string) {
	switch {
	case DefaultScheduler == nil, url == "", message == "":
		return
	}

	workerTokens, found := GetTokenCounter(ctx, url)
	if found {
		workerTokens.Sub(estimateTokens(message))
		logger.Info(ctx, "release prefill tokens: %s, tokens: %d", url, workerTokens.Get())
	}
}
// StartStatsReporter calls reportStats every interval seconds until ctx is
// done. It blocks for the lifetime of the reporter, so callers that want it
// in the background must start it in its own goroutine.
//
// interval is in (possibly fractional) seconds. A value that does not convert
// to a positive duration disables reporting: the function returns immediately
// instead of letting time.NewTicker panic.
func StartStatsReporter(ctx context.Context, interval float64) {
	period := time.Duration(interval * float64(time.Second))
	if period <= 0 {
		// time.NewTicker panics on a non-positive duration, which would
		// take the whole process down from a misconfigured interval.
		return
	}

	ticker := time.NewTicker(period)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			reportStats(ctx)
		}
	}
}
// reportStats writes a single "[stats]" log line containing the running
// count of every healthy worker, their sum, and the cache hit rate since the
// previous call to GetAndResetCacheHitStats. It does nothing when the
// scheduler or its manager API is unset.
func reportStats(ctx context.Context) {
	if DefaultScheduler == nil || DefaultScheduler.managerAPI == nil {
		return
	}
	api := DefaultScheduler.managerAPI

	workerURLs := api.GetHealthyURLs(ctx)
	perWorker := make([]string, 0, len(workerURLs))
	var totalRunning int
	for _, workerURL := range workerURLs {
		// Only the first value returned by GetMetrics is reported.
		running, _, _ := api.GetMetrics(ctx, workerURL)
		totalRunning += running
		perWorker = append(perWorker, fmt.Sprintf("%s: running=%d", workerURL, running))
	}

	// Cache hit stats; the helper's name indicates the counters are reset on
	// each read, so the rate covers one reporting period.
	hits, total := GetAndResetCacheHitStats()
	var hitRate float64
	if total > 0 {
		hitRate = float64(hits) * 100 / float64(total)
	}

	logger.Info(ctx, "[stats] total_running=%d, workers: [%s], cache_hit_rate=%.2f%% (hits=%d/total=%d)",
		totalRunning, strings.Join(perWorker, ", "), hitRate, hits, total)
}