diff --git a/docs/online_serving/router.md b/docs/online_serving/router.md index fc973de9a8..82940e5680 100644 --- a/docs/online_serving/router.md +++ b/docs/online_serving/router.md @@ -195,6 +195,7 @@ scheduler: prefill-policy: "cache_aware" # Prefill scheduling policy in PD mode decode-policy: "request_num" # Decode scheduling policy in PD mode eviction-interval-secs: 60 # Cache eviction interval for CacheAware scheduling + eviction-duration-mins: 30 # Eviction duration for cache-aware radix tree nodes (minutes); default: 30 balance-abs-threshold: 1 # Absolute threshold for CacheAware balancing balance-rel-threshold: 0.2 # Relative threshold for CacheAware balancing hit-ratio-weight: 1.0 # Cache hit ratio weight diff --git a/docs/zh/online_serving/router.md b/docs/zh/online_serving/router.md index c5748daa7b..0ace28c2da 100644 --- a/docs/zh/online_serving/router.md +++ b/docs/zh/online_serving/router.md @@ -195,6 +195,7 @@ scheduler: prefill-policy: "cache_aware" # pd分离模式下prefill节点调度策略; 默认: process_tokens decode-policy: "request_num" # pd分离模式下decode节点调度策略; 默认: request_num eviction-interval-secs: 60 # cache-aware策略清理过期cache的间隔时间 + eviction-duration-mins: 30 # cache-aware策略radix tree节点驱逐时间(分钟); 默认: 30 balance-abs-threshold: 1 # cache-aware策略绝对阈值 balance-rel-threshold: 0.2 # cache-aware策略相对阈值 hit-ratio-weight: 1.0 # cache-aware策略命中率权重 diff --git a/fastdeploy/golang_router/examples/run_with_config/config/config.example.yaml b/fastdeploy/golang_router/examples/run_with_config/config/config.example.yaml index ba1a51acc3..be4b11227d 100644 --- a/fastdeploy/golang_router/examples/run_with_config/config/config.example.yaml +++ b/fastdeploy/golang_router/examples/run_with_config/config/config.example.yaml @@ -9,6 +9,7 @@ scheduler: prefill-policy: "cache_aware" decode-policy: "request_num" eviction-interval-secs: 60 + eviction-duration-mins: 30 # eviction duration for cache-aware radix tree nodes (minutes); default: 30 balance-abs-threshold: 1 balance-rel-threshold: 0.2 hit-ratio-weight: 1.0 diff --git a/fastdeploy/golang_router/examples/run_with_default_workers/config/config.example.yaml b/fastdeploy/golang_router/examples/run_with_default_workers/config/config.example.yaml index ba1a51acc3..be4b11227d 100644 --- a/fastdeploy/golang_router/examples/run_with_default_workers/config/config.example.yaml +++ b/fastdeploy/golang_router/examples/run_with_default_workers/config/config.example.yaml @@ -9,6 +9,7 @@ scheduler: prefill-policy: "cache_aware" decode-policy: "request_num" eviction-interval-secs: 60 + eviction-duration-mins: 30 # eviction duration for cache-aware radix tree nodes (minutes); default: 30 balance-abs-threshold: 1 balance-rel-threshold: 0.2 hit-ratio-weight: 1.0 diff --git a/fastdeploy/golang_router/internal/config/config.go b/fastdeploy/golang_router/internal/config/config.go index 9f1c1c0e34..2cb8226961 100644 --- a/fastdeploy/golang_router/internal/config/config.go +++ b/fastdeploy/golang_router/internal/config/config.go @@ -36,6 +36,7 @@ type SchedulerConfig struct { PrefillPolicy string `yaml:"prefill-policy"` DecodePolicy string `yaml:"decode-policy"` EvictionIntervalSecs float64 `yaml:"eviction-interval-secs"` + EvictionDurationMins float64 `yaml:"eviction-duration-mins"` CacheBlockSize int `yaml:"cache-block-size"` TokenizerURL string `yaml:"tokenizer-url"` TokenizerTimeoutSecs float64 `yaml:"tokenizer-timeout-secs"` @@ -98,6 +99,9 @@ func Load(configPath, listenPort string, isSplitwise bool) (*Config, error) { if cfg.Scheduler.EvictionIntervalSecs == 0 { cfg.Scheduler.EvictionIntervalSecs = 60 } + if cfg.Scheduler.EvictionDurationMins == 0 { + cfg.Scheduler.EvictionDurationMins = 30 + } if cfg.Scheduler.CacheBlockSize == 0 { cfg.Scheduler.CacheBlockSize = 64 } diff --git a/fastdeploy/golang_router/internal/gateway/completions_test.go b/fastdeploy/golang_router/internal/gateway/completions_test.go index b188532aba..825544ff5e 100644 --- a/fastdeploy/golang_router/internal/gateway/completions_test.go +++ b/fastdeploy/golang_router/internal/gateway/completions_test.go @@ -4,16 +4,27 @@ import ( "bytes" "context" "encoding/json" + "errors" "io" "net/http" "net/http/httptest" + "os" "strings" + "sync/atomic" "testing" + "time" + "github.com/PaddlePaddle/FastDeploy/router/pkg/logger" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" ) +func TestMain(m *testing.M) { + logger.Init("info", "stdout") + gin.SetMode(gin.TestMode) + os.Exit(m.Run()) +} + func TestChatCompletions(t *testing.T) { // Since the actual implementation uses package-level functions that depend on DefaultManager, // and we don't want to set up a full manager for unit tests, @@ -570,3 +581,405 @@ func TestCommonCompletions(t *testing.T) { w.Body.String() != "") }) } + +// ============================================================ +// Test helpers for timeout / hang simulation +// ============================================================ + +// newHangingServer creates an httptest.Server whose handler blocks until +// the returned cleanup function is called. Always call cleanup in defer. +func newHangingServer() (server *httptest.Server, cleanup func()) { + done := make(chan struct{}) + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-done + })) + cleanup = func() { + close(done) + server.Close() + } + return +} + +// newSlowServer creates an httptest.Server that waits for the given duration +// before responding. cleanup unblocks any in-flight handlers. +func newSlowServer(delay time.Duration, statusCode int, body string) (server *httptest.Server, cleanup func()) { + done := make(chan struct{}) + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select { + case <-time.After(delay): + w.WriteHeader(statusCode) + w.Write([]byte(body)) + case <-done: + // test cleanup + } + })) + cleanup = func() { + close(done) + server.Close() + } + return +} + +// newGinContextWithTimeout creates a gin.Context with a request whose context +// has the given timeout. Returns cancel for cleanup. +func newGinContextWithTimeout(timeout time.Duration) (*gin.Context, *httptest.ResponseRecorder, context.CancelFunc) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + req := httptest.NewRequest("POST", "/v1/chat/completions", + bytes.NewBufferString(`{"test":"data"}`)) + c.Request = req.WithContext(ctx) + return c, w, cancel +} + +// hangingReader is a custom io.ReadCloser that returns initial data, +// then blocks until its context is cancelled (simulating mid-stream hang). +type hangingReader struct { + data []byte + offset int + hangAt int // byte offset at which to start hanging + ctx context.Context +} + +func (r *hangingReader) Read(p []byte) (int, error) { + if r.offset >= r.hangAt { + <-r.ctx.Done() + return 0, r.ctx.Err() + } + end := r.offset + len(p) + if end > r.hangAt { + end = r.hangAt + } + n := copy(p, r.data[r.offset:end]) + r.offset += n + return n, nil +} + +func (r *hangingReader) Close() error { return nil } + +// ============================================================ +// PostToPD timeout tests +// ============================================================ + +func TestPostToPD_PrefillHangs(t *testing.T) { + hangServer, hangCleanup := newHangingServer() + defer hangCleanup() + + decodeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("decode ok")) + })) + defer decodeServer.Close() + + c, _, cancel := newGinContextWithTimeout(200 * time.Millisecond) + defer cancel() + + start := time.Now() + resp, err := PostToPD(c, decodeServer.URL, hangServer.URL, []byte(`{"test":"data"}`), false, "msg", "chat/completions") + elapsed := time.Since(start) + + assert.Error(t, err) + assert.True(t, errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled), + "expected context deadline exceeded or canceled, got: %v", err) + assert.Nil(t, resp) + assert.Less(t, elapsed, 5*time.Second, "should not hang indefinitely") +} + +func TestPostToPD_DecodeHangs(t *testing.T) { + prefillServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("prefill ok")) + })) + defer prefillServer.Close() + + hangServer, hangCleanup := newHangingServer() + defer hangCleanup() + + c, _, cancel := newGinContextWithTimeout(200 * time.Millisecond) + defer cancel() + + start := time.Now() + resp, err := PostToPD(c, hangServer.URL, prefillServer.URL, []byte(`{"test":"data"}`), false, "msg", "chat/completions") + elapsed := time.Since(start) + + assert.Error(t, err) + assert.True(t, errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled), + "expected context deadline exceeded or canceled, got: %v", err) + assert.Nil(t, resp) + assert.Less(t, elapsed, 5*time.Second) +} + +func TestPostToPD_BothHang(t *testing.T) { + hangP, cleanupP := newHangingServer() + defer cleanupP() + + hangD, cleanupD := newHangingServer() + defer cleanupD() + + c, _, cancel := newGinContextWithTimeout(200 * time.Millisecond) + defer cancel() + + start := time.Now() + resp, err := PostToPD(c, hangD.URL, hangP.URL, []byte(`{"test":"data"}`), false, "msg", "chat/completions") + elapsed := time.Since(start) + + assert.Error(t, err) + assert.True(t, errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled), + "expected context deadline exceeded or canceled, got: %v", err) + assert.Nil(t, resp) + assert.Less(t, elapsed, 5*time.Second) +} + +func TestPostToPD_ContextCancellation(t *testing.T) { + hangP, cleanupP := newHangingServer() + defer cleanupP() + + hangD, cleanupD := newHangingServer() + defer cleanupD() + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + ctx, cancel := context.WithCancel(context.Background()) + req := httptest.NewRequest("POST", "/v1/chat/completions", + bytes.NewBufferString(`{"test":"data"}`)) + c.Request = req.WithContext(ctx) + + type result struct { + resp *http.Response + err error + } + ch := make(chan result, 1) + go func() { + resp, err := PostToPD(c, hangD.URL, hangP.URL, []byte(`{"test":"data"}`), false, "msg", "chat/completions") + ch <- result{resp, err} + }() + + // Cancel after a short delay + time.Sleep(50 * time.Millisecond) + cancel() + + select { + case res := <-ch: + assert.Error(t, res.err) + assert.True(t, errors.Is(res.err, context.Canceled), + "expected context.Canceled, got: %v", res.err) + assert.Nil(t, res.resp) + case <-time.After(5 * time.Second): + t.Fatal("PostToPD did not return after context cancellation") + } +} + +func TestPostToPD_PrefillSlowButCompletes(t *testing.T) { + slowPrefill, cleanupP := newSlowServer(50*time.Millisecond, http.StatusOK, "prefill done") + defer cleanupP() + + decodeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("decode done")) + })) + defer decodeServer.Close() + + c, _, cancel := newGinContextWithTimeout(2 * time.Second) + defer cancel() + + resp, err := PostToPD(c, decodeServer.URL, slowPrefill.URL, []byte(`{"test":"data"}`), false, "msg", "chat/completions") + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + assert.Equal(t, "decode done", string(body)) +} + +// ============================================================ +// GetClient / GetClientWithRetry timeout tests +// ============================================================ + +func TestGetClient_Timeout(t *testing.T) { + hangServer, hangCleanup := newHangingServer() + defer hangCleanup() + + c, _, cancel := newGinContextWithTimeout(200 * time.Millisecond) + defer cancel() + + start := time.Now() + resp, err := GetClient(c, hangServer.URL, "chat/completions", []byte(`{"test":"data"}`)) + elapsed := time.Since(start) + + assert.Error(t, err) + assert.True(t, errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled), + "expected context deadline exceeded or canceled, got: %v", err) + assert.Nil(t, resp) + assert.Less(t, elapsed, 5*time.Second) +} + +func TestGetClientWithRetry_TimeoutAcrossRetries(t *testing.T) { + var hitCount atomic.Int32 + done := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hitCount.Add(1) + <-done + })) + defer func() { + close(done) + server.Close() + }() + + c, _, cancel := newGinContextWithTimeout(200 * time.Millisecond) + defer cancel() + + start := time.Now() + resp, err := GetClientWithRetry(c, []byte(`{"test":"data"}`), server.URL, "chat/completions") + elapsed := time.Since(start) + + assert.Error(t, err) + assert.True(t, errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled), + "expected context deadline exceeded or canceled, got: %v", err) + assert.Nil(t, resp) + // Should not have completed all 3 retries; the shared context expires + assert.Less(t, elapsed, 5*time.Second) + // At most 3 attempts, but with a 200ms timeout the context should expire during/after the first attempt + assert.LessOrEqual(t, hitCount.Load(), int32(3)) +} + +func TestGetClientWithRetry_ContextCancelled(t *testing.T) { + done := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-done + })) + defer func() { + close(done) + server.Close() + }() + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + ctx, cancel := context.WithCancel(context.Background()) + req := httptest.NewRequest("POST", "/v1/chat/completions", + bytes.NewBufferString(`{"test":"data"}`)) + c.Request = req.WithContext(ctx) + + type result struct { + resp *http.Response + err error + } + ch := make(chan result, 1) + go func() { + resp, err := GetClientWithRetry(c, []byte(`{"test":"data"}`), server.URL, "chat/completions") + ch <- result{resp, err} + }() + + time.Sleep(50 * time.Millisecond) + cancel() + + select { + case res := <-ch: + assert.Error(t, res.err) + assert.True(t, errors.Is(res.err, context.Canceled), + "expected context.Canceled, got: %v", res.err) + assert.Nil(t, res.resp) + case <-time.After(5 * time.Second): + t.Fatal("GetClientWithRetry did not return after context cancellation") + } +} + +// ============================================================ +// Streaming hang / mid-stream interruption tests +// ============================================================ + +func TestRedirect_StreamingHangMidStream(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond) + defer cancel() + + initialData := "data: {\"choices\":[{\"text\":\"chunk1\"}]}\n" + reader := &hangingReader{ + data: []byte(initialData), + hangAt: len(initialData), + ctx: ctx, + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: reader, + } + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("GET", "/", nil).WithContext(ctx) + + done := make(chan struct{}) + go func() { + redirect(c, true, resp) + close(done) + }() + + select { + case <-done: + // redirect returned, check partial output was written + assert.Contains(t, w.Body.String(), "data: {\"choices\":[{\"text\":\"chunk1\"}]}") + case <-time.After(5 * time.Second): + t.Fatal("redirect did not return after context timeout") + } +} + +func TestReadPrefillRecv_StreamHang(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond) + defer cancel() + + // First chunk followed by a hang + initialData := "data: first-chunk\n" + reader := &hangingReader{ + data: []byte(initialData), + hangAt: len(initialData), + ctx: ctx, + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: reader, + } + + done := make(chan struct{}) + go func() { + readPrefillRecv(ctx, "test-url", true, "test message", resp) + close(done) + }() + + select { + case <-done: + // completed without panic + case <-time.After(5 * time.Second): + t.Fatal("readPrefillRecv did not return after context timeout") + } +} + +func TestReadPrefillRecv_NonStreamHang(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond) + defer cancel() + + // Reader that immediately hangs (no data before hang) + reader := &hangingReader{ + data: []byte{}, + hangAt: 0, + ctx: ctx, + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: reader, + } + + done := make(chan struct{}) + go func() { + readPrefillRecv(ctx, "test-url", false, "test message", resp) + close(done) + }() + + select { + case <-done: + // completed without panic + case <-time.After(5 * time.Second): + t.Fatal("readPrefillRecv did not return after context timeout") + } +} diff --git a/fastdeploy/golang_router/internal/scheduler/handler/handler.go b/fastdeploy/golang_router/internal/scheduler/handler/handler.go index 0d62346de5..070ee2109e 100644 --- a/fastdeploy/golang_router/internal/scheduler/handler/handler.go +++ b/fastdeploy/golang_router/internal/scheduler/handler/handler.go @@ -46,6 +46,7 @@ func Init(cfg *config.Config, managerAPI common.ManagerAPI) { cacheBlockSize: cfg.Scheduler.CacheBlockSize, tokenizerURL: cfg.Scheduler.TokenizerURL, tokenizerTimeout: time.Duration(cfg.Scheduler.TokenizerTimeoutSecs * float64(time.Second)), + evictionDuration: time.Duration(cfg.Scheduler.EvictionDurationMins * float64(time.Minute)), } scheduler := &Scheduler{ @@ -124,7 +125,7 @@ func SelectWorker(ctx context.Context, workers []string, message string, workerT // 2) Prefill: current token processing count (process_tokens) var tokens uint64 - if workerType == "prefill" && message != "" { + if (workerType == "prefill" || workerType == "mixed") && message != "" { tokenCounter := GetOrCreateTokenCounter(ctx, selectWorkerURL) tokenCounter.Add(estimateTokens(message)) tokens = tokenCounter.Get() diff --git a/fastdeploy/golang_router/internal/scheduler/handler/prefill_cache_aware.go b/fastdeploy/golang_router/internal/scheduler/handler/prefill_cache_aware.go index 6f5cb80eec..48737c03c7 100644 --- a/fastdeploy/golang_router/internal/scheduler/handler/prefill_cache_aware.go +++ b/fastdeploy/golang_router/internal/scheduler/handler/prefill_cache_aware.go @@ -38,6 +38,7 @@ type schedulerConfigSnapshot struct { cacheBlockSize int tokenizerURL string tokenizerTimeout time.Duration + evictionDuration time.Duration } // newPrefillCacheStrategy initializes cache-aware strategy config @@ -47,7 +48,7 @@ func newPrefillCacheStrategy(cfg *schedulerConfigSnapshot) *prefillCacheStrategy relThreshold: cfg.balanceRelThreshold, hitRatioWeight: cfg.hitRatioWeight, loadBalanceWeight: cfg.loadBalanceWeight, - cache: newRadixPrefixCache(cfg.cacheBlockSize), + cache: newRadixPrefixCache(cfg.cacheBlockSize, cfg.evictionDuration), tokenizer: NewHTTPTokenizer(cfg.tokenizerURL, cfg.tokenizerTimeout), sessionWorkerMap: make(map[string]string), } @@ -297,11 +298,10 @@ type radixNode struct { } // newRadixPrefixCache initializes radix prefix cache with eviction and capacity control -func newRadixPrefixCache(blockSize int) *radixPrefixCache { +func newRadixPrefixCache(blockSize int, evictionDuration time.Duration) *radixPrefixCache { if blockSize <= 0 { blockSize = 64 } - const defaultEvictionDuration = 5 * time.Minute const defaultMaxNodes = 200000 root := &radixNode{ key: nil, @@ -311,7 +311,7 @@ func newRadixPrefixCache(blockSize int) *radixPrefixCache { cache := &radixPrefixCache{ root: root, hasher: newBlockHasher(blockSize), - evictionDuration: defaultEvictionDuration, + evictionDuration: evictionDuration, maxNodes: defaultMaxNodes, nodeCount: 1, // root allNodes: map[*radixNode]struct{}{root: {}},