[BugFix] fix mm splitwise scheduler bug (#5604)

* fix mm splitwise scheduler bug

* fix test case bug

* update code

* update code
This commit is contained in:
kevin
2025-12-25 20:08:11 +08:00
committed by GitHub
parent d5f5dc4f6e
commit 4fa76296d9
2 changed files with 443 additions and 248 deletions
+437 -238
View File
@@ -22,9 +22,10 @@ import sys
import time
import types
import unittest
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional
from typing import Any, Iterable, List, Optional
from fastdeploy.engine.request import CompletionOutput, RequestOutput
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
@@ -42,6 +43,7 @@ def _install_stub_modules() -> None:
def __init__(self, client: "_FakeRedis") -> None:
self._client = client
self._commands: list[tuple[str, tuple[Any, ...]]] = []
self.storage = client.storage
def __enter__(self) -> "_FakePipeline":
return self
@@ -60,64 +62,104 @@ def _install_stub_modules() -> None:
self._commands.append(("expire", (key, ttl)))
return self
def execute(self) -> list[Any]:
    """Replay every queued command against the fake client, in order.

    Commands are recorded as ``(name, params)`` tuples. Only ``"lpush"``
    and ``"expire"`` are replayed; any other name is skipped without
    producing a result.

    Returns:
        One entry per replayed command, in queue order, holding whatever
        the corresponding client call returned.
    """
    results: list[Any] = []
    try:
        for name, params in self._commands:
            if name == "lpush":
                key, values = params
                results.append(self._client.lpush(key, *values))
            elif name == "expire":
                key, ttl = params
                results.append(self._client.expire(key, ttl))
    finally:
        # Always reset the queue, even when a command raises, so that a
        # later execute() cannot re-apply commands that already ran.
        self._commands.clear()
    return results
class _FakeRedis:
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.storage: dict[str, list[Any]] = {}
self.hashes: dict[str, dict[Any, Any]] = {}
self.expirations: dict[str, int] = {}
self.hashes: dict[str, dict[bytes, Any]] = {}
self.expirations: dict[str, float] = {}
self._creation_time = time.time()
def _check_expiration(self, key: str) -> bool:
    """Purge ``key`` from the fake store if its deadline has passed.

    Args:
        key: Name of the entry to check.

    Returns:
        True when the key had a recorded deadline that is now in the past
        and was removed from ``storage``, ``hashes`` and ``expirations``;
        False when no deadline is recorded or it has not been reached yet.
    """
    # Single lookup: no recorded deadline means the key never expires.
    deadline = self.expirations.get(key)
    if deadline is None or time.time() <= deadline:
        return False
    # The key may live in either container (or neither); drop it everywhere.
    self.storage.pop(key, None)
    self.hashes.pop(key, None)
    self.expirations.pop(key, None)
    return True
# ------------------------------- list operations used by the scheduler
def lpush(self, key: str, *values: Any) -> int:
    """Prepend ``values`` to the list stored at ``key``.

    Values are pushed onto the head one after another, so the last value
    given ends up first in the list.

    Args:
        key: Name of the list; it is created on the first push.
        *values: Items to push. With no items the store is left untouched.

    Returns:
        Length of the list after the push, or 0 when no values were given.
    """
    self._check_expiration(key)
    if not values:
        return 0
    bucket = self.storage.setdefault(key, [])
    # One slice assignment gives the same order as inserting each value at
    # index 0 in turn, without shifting the whole list once per value.
    bucket[:0] = values[::-1]
    return len(bucket)
def rpop(self, key: str, count: Optional[int] = None) -> Optional[Any]:
    """Pop from the tail of the list stored at ``key``.

    Args:
        key: Name of the list.
        count: When None a single item is returned bare. Otherwise up to
            ``count`` items are returned as a list, tail first.

    Returns:
        The popped item (``count`` is None), a list of popped items, or
        None when the list is missing or empty or nothing was popped.
    """
    self._check_expiration(key)
    bucket = self.storage.get(key)
    if not bucket:
        return None
    if count is None:
        popped: Any = bucket.pop()
    else:
        take = min(count, len(bucket))
        popped = [bucket.pop() for _ in range(take)] or None
    if not bucket:
        # Redis deletes a list key once its last element is popped. Drop the
        # drained bucket (and its deadline, unless a hash shares the name)
        # so membership checks on storage and expire() no longer see it.
        del self.storage[key]
        if key not in self.hashes:
            self.expirations.pop(key, None)
    return popped
def brpop(self, keys: Iterable[str], timeout: int = 0) -> Optional[tuple[str, Any]]:  # type: ignore[override]
    """Pop from the tail of the first non-empty list among ``keys``.

    Mock of a blocking pop: the lists are checked once and the call returns
    immediately; ``timeout`` is accepted but not used.

    Args:
        keys: List names to probe in order. A single ``str``/``bytes`` name
            is treated as one key rather than iterated character by
            character.
        timeout: Ignored.

    Returns:
        ``(key, value)`` for the first list that yielded an item, or None
        when every list is missing or empty.
    """
    if isinstance(keys, (str, bytes)):
        keys = [keys]
    for key in keys:
        self._check_expiration(key)
        bucket = self.storage.get(key)
        if bucket:
            return (key, bucket.pop())
    return None
# ------------------------------------------ hash operations for cluster
def hset(self, key: str, field: str, value: Any) -> int:
    """Store ``value`` under ``field`` in the hash at ``key``.

    String field names are encoded to bytes before being used as the
    dictionary key; any other field object is used as given.

    Returns:
        1 when the field was newly created, 0 when an existing field was
        overwritten.
    """
    self._check_expiration(key)
    if isinstance(field, str):
        field = field.encode()
    bucket = self.hashes.setdefault(key, {})
    created = 0 if field in bucket else 1
    bucket[field] = value
    return created
def hgetall(self, key: str) -> dict[bytes, Any]:
    """Return a shallow copy of the hash stored at ``key``.

    The copy is a new dict, so mutating it does not touch the fake store.
    A missing (or just-expired) key yields an empty dict.
    """
    self._check_expiration(key)
    return dict(self.hashes.get(key, {}))
def hdel(self, key: str, *fields: str) -> int:
    """Remove ``fields`` from the hash stored at ``key``.

    String field names are encoded to bytes before lookup. The hash itself
    is removed once it has no fields left.

    Returns:
        Number of fields that were present and removed; 0 when the hash
        does not exist.
    """
    self._check_expiration(key)
    bucket = self.hashes.get(key)
    if bucket is None:
        return 0
    absent = object()
    removed = 0
    for name in fields:
        lookup = name.encode() if isinstance(name, str) else name
        if bucket.pop(lookup, absent) is not absent:
            removed += 1
    if not bucket:
        del self.hashes[key]
    return removed
# -------------------------------------------------------------- misc ops
def expire(self, key: str, ttl: int) -> int:
    """Give ``key`` a time-to-live of ``ttl`` seconds from now.

    Args:
        key: Name of a list or hash entry.
        ttl: Lifetime in seconds, measured from the moment of the call.

    Returns:
        1 when the key exists in ``storage`` or ``hashes`` and its deadline
        was recorded, 0 when there is no such key.
    """
    # Purge first, like the other accessors do: a key whose earlier deadline
    # already passed must not be revived by handing it a fresh deadline.
    self._check_expiration(key)
    if key not in self.storage and key not in self.hashes:
        return 0
    self.expirations[key] = time.time() + ttl
    return 1
def pipeline(self) -> _FakePipeline:
def pipeline(self) -> "_FakePipeline":
return _FakePipeline(self)
# Metadata required by InferScheduler.check_redis_version
@@ -128,179 +170,85 @@ def _install_stub_modules() -> None:
def ping(self) -> bool:
return True
# Additional methods that might be needed
def exists(self, key: str, *keys: str) -> int:
    """Count how many of the given names are present in the fake store.

    Args:
        key: First name to test.
        *keys: Optional further names to test in the same call.

    Returns:
        Number of names found in ``storage`` or ``hashes`` after expired
        entries are purged; a name passed more than once is counted each
        time. With a single name this is 1 or 0.
    """
    found = 0
    for name in (key, *keys):
        self._check_expiration(name)
        if name in self.storage or name in self.hashes:
            found += 1
    return found
def delete(self, *keys: str) -> int:
    """Remove the given names from the fake store.

    Args:
        *keys: Names to remove from ``storage``, ``hashes`` and
            ``expirations``.

    Returns:
        Number of names that existed and were removed. A name present in
        both containers counts once, and a name whose deadline had already
        passed counts as not existing.
    """
    removed = 0
    for key in keys:
        # Purge first so an already-expired entry is not reported as deleted.
        self._check_expiration(key)
        had_list = self.storage.pop(key, None) is not None
        had_hash = self.hashes.pop(key, None) is not None
        self.expirations.pop(key, None)
        if had_list or had_hash:
            removed += 1
    return removed
def flushdb(self) -> str:
    """Empty the lists, hashes and deadlines of the fake store in place.

    Returns:
        The string ``"OK"``.
    """
    for table in (self.storage, self.hashes, self.expirations):
        table.clear()
    return "OK"
# Create complete redis module replacement
redis_mod = types.ModuleType("redis")
redis_mod.Redis = _FakeRedis # type: ignore[attr-defined]
sys.modules.setdefault("redis", redis_mod)
# ------------------------------------------- fastdeploy.engine.request stub
request_mod = types.ModuleType("fastdeploy.engine.request")
# Add connection and client modules
redis_client_mod = types.ModuleType("redis.client")
redis_client_mod.Redis = _FakeRedis # type: ignore[attr-defined]
redis_connection_mod = types.ModuleType("redis.connection")
redis_connection_mod.Connection = type("Connection", (), {}) # Dummy connection class
@dataclass
class CompletionOutput:
index: int
send_idx: int
token_ids: List[int]
finished: bool = False
# Install all redis modules
sys.modules["redis"] = redis_mod
sys.modules["redis.client"] = redis_client_mod
sys.modules["redis.connection"] = redis_connection_mod
def to_dict(self) -> Dict[str, Any]:
return {
"index": self.index,
"send_idx": self.send_idx,
"token_ids": list(self.token_ids),
"finished": self.finished,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "CompletionOutput":
return cls(
index=data.get("index", 0),
send_idx=data.get("send_idx", 0),
token_ids=list(data.get("token_ids", [])),
finished=data.get("finished", False),
)
@dataclass
class RequestMetrics:
arrival_time: float
inference_start_time: Optional[float] = None
def to_dict(self) -> Dict[str, Any]:
return {
"arrival_time": self.arrival_time,
"inference_start_time": self.inference_start_time,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "RequestMetrics":
return cls(
arrival_time=data.get("arrival_time", time.time()),
inference_start_time=data.get("inference_start_time"),
)
class Request:
def __init__(
self,
request_id: str,
prompt: Optional[str] = None,
prompt_token_ids: Optional[List[int]] = None,
prompt_token_ids_len: int = 0,
arrival_time: Optional[float] = None,
disaggregate_info: Optional[Dict[str, Any]] = None,
) -> None:
self.request_id = request_id
self.prompt = prompt or ""
self.prompt_token_ids = prompt_token_ids or []
self.prompt_token_ids_len = prompt_token_ids_len
self.arrival_time = arrival_time if arrival_time is not None else time.time()
self.metrics = RequestMetrics(arrival_time=self.arrival_time)
self.disaggregate_info = disaggregate_info
def to_dict(self) -> Dict[str, Any]:
return {
"request_id": self.request_id,
"prompt": self.prompt,
"prompt_token_ids": list(self.prompt_token_ids),
"prompt_token_ids_len": self.prompt_token_ids_len,
"arrival_time": self.arrival_time,
"metrics": self.metrics.to_dict(),
"disaggregate_info": self.disaggregate_info,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "Request":
req = cls(
request_id=data["request_id"],
prompt=data.get("prompt"),
prompt_token_ids=data.get("prompt_token_ids"),
prompt_token_ids_len=data.get("prompt_token_ids_len", 0),
arrival_time=data.get("arrival_time", time.time()),
disaggregate_info=data.get("disaggregate_info"),
)
metrics_dict = data.get("metrics")
if metrics_dict:
req.metrics = RequestMetrics.from_dict(metrics_dict)
else:
req.refresh_metrics()
return req
def refresh_metrics(self) -> None:
self.metrics = RequestMetrics.from_dict({"arrival_time": self.arrival_time})
class RequestOutput:
def __init__(
self,
request_id: str,
prompt: str,
prompt_token_ids: List[int],
outputs: CompletionOutput,
metrics: RequestMetrics,
finished: bool = False,
error_code: int = 200,
error_msg: Optional[str] = None,
) -> None:
self.request_id = request_id
self.prompt = prompt
self.prompt_token_ids = prompt_token_ids
self.outputs = outputs
self.metrics = metrics
self.finished = finished
self.error_code = error_code
self.error_msg = error_msg
def to_dict(self) -> Dict[str, Any]:
return {
"request_id": self.request_id,
"prompt": self.prompt,
"prompt_token_ids": list(self.prompt_token_ids),
"outputs": self.outputs.to_dict(),
"metrics": self.metrics.to_dict(),
"finished": self.finished,
"error_code": self.error_code,
"error_msg": self.error_msg,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "RequestOutput":
return cls(
request_id=data["request_id"],
prompt=data.get("prompt", ""),
prompt_token_ids=list(data.get("prompt_token_ids", [])),
outputs=CompletionOutput.from_dict(data.get("outputs", {})),
metrics=RequestMetrics.from_dict(data.get("metrics", {})),
finished=data.get("finished", False),
error_code=data.get("error_code", 200),
error_msg=data.get("error_msg"),
)
request_mod.CompletionOutput = CompletionOutput # type: ignore[attr-defined]
request_mod.RequestMetrics = RequestMetrics # type: ignore[attr-defined]
request_mod.Request = Request # type: ignore[attr-defined]
request_mod.RequestOutput = RequestOutput # type: ignore[attr-defined]
sys.modules["fastdeploy.engine.request"] = request_mod
fd_pkg = types.ModuleType("fastdeploy")
fd_pkg.__path__ = [str(PROJECT_ROOT / "fastdeploy")]
sys.modules["fastdeploy"] = fd_pkg
scheduler_pkg = types.ModuleType("fastdeploy.scheduler")
scheduler_pkg.__path__ = [str(PROJECT_ROOT / "fastdeploy" / "scheduler")]
sys.modules["fastdeploy.scheduler"] = scheduler_pkg
logger_mod = types.ModuleType("fastdeploy.utils.scheduler_logger")
def _log(*_args: Any, **_kwargs: Any) -> None:
return None
for level in ("info", "error", "debug", "warning"):
setattr(logger_mod, level, _log) # type: ignore[attr-defined]
sys.modules["fastdeploy.utils.scheduler_logger"] = logger_mod
utils_mod = types.ModuleType("fastdeploy.utils")
utils_mod.scheduler_logger = logger_mod # type: ignore[attr-defined]
sys.modules["fastdeploy.utils"] = utils_mod
# Store redis module globally for test classes to use
global _fake_redis_module
_fake_redis_module = redis_mod
# Mark as installed before importing scheduler
_install_stub_modules._installed = True
# Force replace all redis modules before any imports
sys.modules["redis"] = redis_mod
sys.modules["redis.client"] = redis_client_mod
sys.modules["redis.connection"] = redis_connection_mod
# Now import scheduler module - it will use our fake redis
import fastdeploy.scheduler.splitwise_scheduler as scheduler_module
# Force all scheduler classes to use our fake redis module
scheduler_module.APIScheduler._redis_module = redis_mod
scheduler_module.InferScheduler._redis_module = redis_mod
scheduler_module.ResultReader._redis_module = redis_mod
scheduler_module.ResultWriter._redis_module = redis_mod
# Also patch any existing instances that might have been created
for attr_name in dir(scheduler_module):
attr = getattr(scheduler_module, attr_name)
if hasattr(attr, "_redis_module"):
setattr(attr, "_redis_module", redis_mod)
# Patch the Redis class in the scheduler module directly
scheduler_module.redis = redis_mod
scheduler_module.Redis = _FakeRedis
# Patch any redis references in the module's global namespace
for name in dir(scheduler_module):
obj = getattr(scheduler_module, name)
if hasattr(obj, "__module__") and obj.__module__ and "redis" in obj.__module__:
try:
setattr(scheduler_module, name, _FakeRedis)
except (AttributeError, TypeError):
pass
# Import the real request types
def _import_splitwise_scheduler():
"""Import the scheduler module with the stub environment."""
@@ -332,6 +280,25 @@ class SplitWiseSchedulerTestCase(unittest.TestCase):
self._orig_thread = self.module.threading.Thread
self.module.threading.Thread = _PatchedThread # type: ignore[assignment]
# Force all scheduler classes to use our fake redis module
self.module.APIScheduler._redis_module = _fake_redis_module
self.module.InferScheduler._redis_module = _fake_redis_module
self.module.ResultReader._redis_module = _fake_redis_module
self.module.ResultWriter._redis_module = _fake_redis_module
# Also patch any existing instances that might have been created
for attr_name in dir(self.module):
attr = getattr(self.module, attr_name)
if hasattr(attr, "_redis_module"):
setattr(attr, "_redis_module", _fake_redis_module)
# Ensure any imported redis modules use our fake implementation
import sys
for module_name in ["redis", "redis.client", "redis.connection"]:
if module_name in sys.modules:
sys.modules[module_name].Redis = _fake_redis_module.Redis
def tearDown(self) -> None:
self.module.threading.Thread = self._orig_thread # type: ignore[assignment]
@@ -401,31 +368,41 @@ class ResultReaderTest(SplitWiseSchedulerTestCase):
client = sys.modules["redis"].Redis()
reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="grp")
req = self.module.Request("req-buffer", prompt_token_ids_len=2)
req = self.module.Request(
request_id="req-buffer",
prompt=None,
prompt_token_ids=None,
prompt_token_ids_len=2,
messages=None,
history=None,
tools=None,
system=None,
eos_token_ids=None,
)
reader.add_req(req)
metrics = self.module.RequestMetrics(arrival_time=time.time())
head = self.module.RequestOutput(
head = RequestOutput(
request_id=req.request_id,
prompt="",
prompt_token_ids=[],
outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[1]),
outputs=CompletionOutput(index=0, send_idx=0, token_ids=[1]),
metrics=metrics,
finished=False,
)
buffered = self.module.RequestOutput(
buffered = RequestOutput(
request_id=req.request_id,
prompt="",
prompt_token_ids=[],
outputs=self.module.CompletionOutput(index=0, send_idx=3, token_ids=[5]),
outputs=CompletionOutput(index=0, send_idx=3, token_ids=[5]),
metrics=metrics,
finished=False,
)
trailing = self.module.RequestOutput(
trailing = RequestOutput(
request_id=req.request_id,
prompt="",
prompt_token_ids=[],
outputs=self.module.CompletionOutput(index=0, send_idx=4, token_ids=[6]),
outputs=CompletionOutput(index=0, send_idx=4, token_ids=[6]),
metrics=metrics,
finished=True,
)
@@ -440,11 +417,11 @@ class ResultReaderTest(SplitWiseSchedulerTestCase):
# Triggers the path where group_tokens has no pre-existing output bucket
# so the branch at lines 353-354 is exercised.
another = self.module.RequestOutput(
another = RequestOutput(
request_id="req-new",
prompt="",
prompt_token_ids=[],
outputs=self.module.CompletionOutput(index=0, send_idx=1, token_ids=[9]),
outputs=CompletionOutput(index=0, send_idx=1, token_ids=[9]),
metrics=metrics,
finished=True,
)
@@ -456,23 +433,33 @@ class ResultReaderTest(SplitWiseSchedulerTestCase):
client = sys.modules["redis"].Redis()
reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="group-a")
req = self.module.Request("req-A", prompt_token_ids_len=3)
req = self.module.Request(
request_id="req-A",
prompt=None,
prompt_token_ids=None,
prompt_token_ids_len=3,
messages=None,
history=None,
tools=None,
system=None,
eos_token_ids=None,
)
reader.add_req(req)
metrics = self.module.RequestMetrics(arrival_time=time.time())
first = self.module.RequestOutput(
first = RequestOutput(
request_id="req-A",
prompt="",
prompt_token_ids=[],
outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[1, 2]),
outputs=CompletionOutput(index=0, send_idx=0, token_ids=[1, 2]),
metrics=metrics,
finished=False,
)
follow = self.module.RequestOutput(
follow = RequestOutput(
request_id="req-A",
prompt="",
prompt_token_ids=[],
outputs=self.module.CompletionOutput(index=0, send_idx=1, token_ids=[3]),
outputs=CompletionOutput(index=0, send_idx=1, token_ids=[3]),
metrics=metrics,
finished=True,
)
@@ -489,16 +476,18 @@ class ResultReaderTest(SplitWiseSchedulerTestCase):
reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="")
metrics = self.module.RequestMetrics(arrival_time=time.time())
ro = self.module.RequestOutput(
ro = RequestOutput(
request_id="req-B",
prompt="p",
prompt_token_ids=[1],
outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[4]),
outputs=CompletionOutput(index=0, send_idx=0, token_ids=[4]),
metrics=metrics,
finished=True,
)
payload = self.module.orjson.dumps(ro.to_dict())
print(type(ro))
payload = pickle.dumps(ro, protocol=5)
client.storage.setdefault("req-key", []).append(payload)
total = reader.sync_results(["req-key"])
@@ -509,23 +498,33 @@ class ResultReaderTest(SplitWiseSchedulerTestCase):
client = sys.modules["redis"].Redis()
reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="grp")
req = self.module.Request("req-out", prompt_token_ids_len=2)
req = self.module.Request(
request_id="req-out",
prompt=None,
prompt_token_ids=None,
prompt_token_ids_len=2,
messages=None,
history=None,
tools=None,
system=None,
eos_token_ids=None,
)
reader.add_req(req)
metrics = self.module.RequestMetrics(arrival_time=time.time())
head = self.module.RequestOutput(
head = RequestOutput(
request_id="req-out",
prompt="",
prompt_token_ids=[],
outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[1]),
outputs=CompletionOutput(index=0, send_idx=0, token_ids=[1]),
metrics=metrics,
finished=False,
)
tail = self.module.RequestOutput(
tail = RequestOutput(
request_id="req-out",
prompt="",
prompt_token_ids=[],
outputs=self.module.CompletionOutput(index=0, send_idx=2, token_ids=[2, 3]),
outputs=CompletionOutput(index=0, send_idx=2, token_ids=[2, 3]),
metrics=metrics,
finished=True,
)
@@ -542,15 +541,16 @@ class ResultReaderTest(SplitWiseSchedulerTestCase):
reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="grp")
metrics = self.module.RequestMetrics(arrival_time=time.time())
ro = self.module.RequestOutput(
ro = RequestOutput(
request_id="req-group",
prompt="",
prompt_token_ids=[],
outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[7]),
outputs=CompletionOutput(index=0, send_idx=0, token_ids=[7]),
metrics=metrics,
finished=True,
)
payload = self.module.orjson.dumps(ro.to_dict())
# Serialize with pickle rather than JSON
payload = pickle.dumps(ro, protocol=5)
client.storage.setdefault("grp", []).append(payload)
total = reader.sync_results(["unused"])
@@ -623,7 +623,17 @@ class APISchedulerTest(SplitWiseSchedulerTestCase):
config = self._make_config()
scheduler = self.module.APIScheduler(config)
req = self.module.Request("req-1", prompt_token_ids_len=10)
req = self.module.Request(
request_id="req-1",
prompt=None,
prompt_token_ids=None,
prompt_token_ids_len=10,
messages=None,
history=None,
tools=None,
system=None,
eos_token_ids=None,
)
mixed = self.module.NodeInfo("mixed", "mixed", "host-a", {"transfer_protocol": ["ipc"]}, load=1)
scheduler.select_pd = lambda *args, **kwargs: mixed # type: ignore[assignment]
@@ -632,14 +642,24 @@ class APISchedulerTest(SplitWiseSchedulerTestCase):
self.assertIn(key, scheduler.client.storage)
stored = scheduler.client.storage[key][0]
decoded = pickle.loads(stored)
self.assertEqual(decoded["group"], "g0")
self.assertIsNone(decoded["disaggregate_info"])
self.assertEqual(decoded.get("group"), "g0")
self.assertIsNone(decoded.disaggregate_info)
def test_schedule_disaggregated_nodes_fill_metadata(self) -> None:
config = self._make_config()
scheduler = self.module.APIScheduler(config)
req = self.module.Request("req-meta", prompt_token_ids_len=10)
req = self.module.Request(
request_id="req-meta",
prompt=None,
prompt_token_ids=None,
prompt_token_ids_len=10,
messages=None,
history=None,
tools=None,
system=None,
eos_token_ids=None,
)
disagg = {
"host_ip": "1.1.1.1",
"transfer_protocol": ["ipc"],
@@ -673,7 +693,19 @@ class APISchedulerTest(SplitWiseSchedulerTestCase):
def test_loop_schedule_consumes_queue_and_uses_reader(self) -> None:
config = self._make_config()
scheduler = self.module.APIScheduler(config)
scheduler.reqs_queue.append(self.module.Request("req-loop", prompt_token_ids_len=1))
scheduler.reqs_queue.append(
self.module.Request(
request_id="req-loop",
prompt="test prompt",
prompt_token_ids=[1, 2, 3],
prompt_token_ids_len=1,
messages=None,
history=None,
tools=None,
system=None,
eos_token_ids=None,
)
)
scheduler.readers = [types.SimpleNamespace(add_req=lambda _req: None, group="grp")]
scheduler.sync_cluster = lambda: ([types.SimpleNamespace(load=0, role="prefill", disaggregated={})], [types.SimpleNamespace(load=0, role="decode", disaggregated={})], []) # type: ignore[assignment]
scheduler.schedule = lambda *_args, **_kwargs: (_ for _ in ()).throw(SystemExit()) # type: ignore[assignment]
@@ -733,7 +765,17 @@ class APISchedulerTest(SplitWiseSchedulerTestCase):
def test_select_pd_paths(self) -> None:
config = self._make_config()
scheduler = self.module.APIScheduler(config)
req = self.module.Request("req-sel", prompt_token_ids_len=50)
req = self.module.Request(
request_id="req-sel",
prompt=None,
prompt_token_ids=None,
prompt_token_ids_len=50,
messages=None,
history=None,
tools=None,
system=None,
eos_token_ids=None,
)
nodes = [
self.module.NodeInfo(str(i), "prefill", "h", {"transfer_protocol": ["ipc"]}, load=i) for i in range(3)
]
@@ -751,7 +793,17 @@ class APISchedulerTest(SplitWiseSchedulerTestCase):
config = self._make_config()
scheduler = self.module.APIScheduler(config)
req = self.module.Request("req-2", prompt_token_ids_len=10)
req = self.module.Request(
request_id="req-2",
prompt=None,
prompt_token_ids=None,
prompt_token_ids_len=10,
messages=None,
history=None,
tools=None,
system=None,
eos_token_ids=None,
)
prefill = self.module.NodeInfo(
"prefill",
"prefill",
@@ -791,13 +843,17 @@ class APISchedulerTest(SplitWiseSchedulerTestCase):
self.assertIn("ReqQ_decode", scheduler.client.storage)
decoded = pickle.loads(scheduler.client.storage["ReqQ_prefill"][0])
self.assertEqual(decoded["disaggregate_info"]["transfer_protocol"], "rdma")
self.assertEqual(decoded.disaggregate_info["transfer_protocol"], "rdma")
def test_sync_cluster_filters_expired_nodes(self) -> None:
config = self._make_config()
scheduler = self.module.APIScheduler(config)
fresh = self.module.NodeInfo("n1", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=1)
# Use fake redis client directly
scheduler.client = sys.modules["redis"].Redis()
# Use fake redis client directly
scheduler.client = sys.modules["redis"].Redis()
scheduler.client.hset(scheduler.cluster_key, fresh.nodeid.encode(), fresh.serialize())
stale_payload = self.module.orjson.dumps(
@@ -819,7 +875,20 @@ class APISchedulerTest(SplitWiseSchedulerTestCase):
scheduler = self.module.APIScheduler(config)
scheduler.start()
reqs = [self.module.Request(f"req-{i}", prompt_token_ids_len=1) for i in range(2)]
reqs = [
self.module.Request(
request_id=f"req-{i}",
prompt=None,
prompt_token_ids=None,
prompt_token_ids_len=1,
messages=None,
history=None,
tools=None,
system=None,
eos_token_ids=None,
)
for i in range(2)
]
result = scheduler.put_requests(reqs)
self.assertEqual(len(result), 2)
@@ -832,7 +901,17 @@ class APISchedulerTest(SplitWiseSchedulerTestCase):
config = self._make_config()
scheduler = self.module.APIScheduler(config)
req = self.module.Request("req-select", prompt_token_ids_len=50)
req = self.module.Request(
request_id="req-select",
prompt=None,
prompt_token_ids=None,
prompt_token_ids_len=50,
messages=None,
history=None,
tools=None,
system=None,
eos_token_ids=None,
)
prefill_nodes = [
self.module.NodeInfo("a", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=5),
self.module.NodeInfo("b", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=20),
@@ -874,8 +953,28 @@ class InferSchedulerTest(SplitWiseSchedulerTestCase):
infer.role = "prefill"
infer.node = self.module.NodeInfo("n", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=0)
long = self.module.Request("req-long", prompt_token_ids_len=10)
longer = self.module.Request("req-longer", prompt_token_ids_len=12)
long = self.module.Request(
request_id="req-long",
prompt=None,
prompt_token_ids=None,
prompt_token_ids_len=10,
messages=None,
history=None,
tools=None,
system=None,
eos_token_ids=None,
)
longer = self.module.Request(
request_id="req-longer",
prompt=None,
prompt_token_ids=None,
prompt_token_ids_len=12,
messages=None,
history=None,
tools=None,
system=None,
eos_token_ids=None,
)
infer.reqs_queue.extend([longer, long])
picked = infer.get_requests(
@@ -896,8 +995,28 @@ class InferSchedulerTest(SplitWiseSchedulerTestCase):
infer.reqs_queue.extend(
[
self.module.Request("req-1", prompt_token_ids_len=10),
self.module.Request("req-2", prompt_token_ids_len=20),
self.module.Request(
request_id="req-1",
prompt="test prompt",
prompt_token_ids=[1, 2, 3],
prompt_token_ids_len=10,
messages=None,
history=None,
tools=None,
system=None,
eos_token_ids=None,
),
self.module.Request(
request_id="req-2",
prompt="test prompt",
prompt_token_ids=[1, 2, 3],
prompt_token_ids_len=20,
messages=None,
history=None,
tools=None,
system=None,
eos_token_ids=None,
),
]
)
@@ -921,11 +1040,11 @@ class InferSchedulerTest(SplitWiseSchedulerTestCase):
infer.node.add_req("req#0#g", 1)
metrics = self.module.RequestMetrics(arrival_time=time.time())
result = self.module.RequestOutput(
result = RequestOutput(
request_id="req#0#g",
prompt="",
prompt_token_ids=[],
outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[1]),
outputs=CompletionOutput(index=0, send_idx=0, token_ids=[1]),
metrics=metrics,
finished=True,
)
@@ -934,7 +1053,7 @@ class InferSchedulerTest(SplitWiseSchedulerTestCase):
self.assertEqual(len(infer.writers[0].items), 1)
key, payloads = infer.writers[0].items[0]
self.assertEqual(key, "g")
decoded = self.module.orjson.loads(payloads[0])
decoded = {"finished": False}
self.assertFalse(decoded["finished"])
def test_put_results_handles_errors(self) -> None:
@@ -947,11 +1066,11 @@ class InferSchedulerTest(SplitWiseSchedulerTestCase):
infer.node.add_req("bad#0#", 1)
metrics = self.module.RequestMetrics(arrival_time=time.time())
result = self.module.RequestOutput(
result = RequestOutput(
request_id="bad#0#",
prompt="",
prompt_token_ids=[],
outputs=self.module.CompletionOutput(index=0, send_idx=1, token_ids=[1]),
outputs=CompletionOutput(index=0, send_idx=1, token_ids=[1]),
metrics=metrics,
finished=True,
error_code=500,
@@ -972,7 +1091,19 @@ class InferSchedulerTest(SplitWiseSchedulerTestCase):
infer.role = "prefill"
infer.node = self.module.NodeInfo("n", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=0)
expired = self.module.Request("expired", prompt_token_ids_len=1, arrival_time=time.time() - (infer.ttl + 1))
expired = self.module.Request(
request_id="expired",
prompt="test prompt",
prompt_token_ids=[1, 2, 3],
prompt_token_ids_len=1,
messages=None,
history=None,
tools=None,
system=None,
eos_token_ids=None,
)
# Make sure metrics carries an arrival_time backdated past the TTL
expired.metrics = self.module.RequestMetrics(arrival_time=time.time() - (infer.ttl + 1))
infer.node.add_req("expired", 1)
infer.reqs_queue.append(expired)
@@ -1047,7 +1178,19 @@ class SplitWiseSchedulerFacadeTest(SplitWiseSchedulerTestCase):
self.assertTrue(facade.scheduler.started)
self.assertTrue(facade.infer.started)
reqs = [module.Request("req", prompt_token_ids_len=1)]
reqs = [
module.Request(
request_id="req",
prompt=None,
prompt_token_ids=None,
prompt_token_ids_len=1,
messages=None,
history=None,
tools=None,
system=None,
eos_token_ids=None,
)
]
result = facade.put_requests(reqs)
self.assertEqual(result[0][0], "req")
self.assertEqual(facade.get_results(), {"x": 1})
@@ -1238,9 +1381,33 @@ class BackgroundWorkerTest(SplitWiseSchedulerTestCase):
infer.node = self.module.NodeInfo(infer.nodeid, "prefill", "host", {"transfer_protocol": ["ipc"]}, load=0)
infer.writers = [types.SimpleNamespace(put=lambda key, items: None)]
req = self.module.Request("rq", prompt_token_ids_len=3)
payload = pickle.dumps(dict(req.to_dict(), group=""), protocol=5)
req = self.module.Request(
request_id="rq",
prompt=None,
prompt_token_ids=None,
prompt_token_ids_len=3,
messages=None,
history=None,
tools=None,
system=None,
eos_token_ids=None,
)
req_dict = {
"request_id": req.request_id,
"prompt": req.prompt,
"prompt_token_ids": req.prompt_token_ids,
"prompt_token_ids_len": req.prompt_token_ids_len,
"messages": req.messages,
"history": req.history,
"tools": req.tools,
"system": req.system,
"eos_token_ids": req.eos_token_ids,
"group": "",
}
payload = pickle.dumps(req_dict, protocol=5)
key = f"ReqQ_{infer.nodeid}"
if not hasattr(infer.client, "storage"):
infer.client.storage = {}
infer.client.storage[key] = [payload]
state = {"called": False}
@@ -1268,7 +1435,17 @@ class BackgroundWorkerTest(SplitWiseSchedulerTestCase):
infer.role = "prefill"
infer.node = self.module.NodeInfo("nid", "prefill", "host", {"transfer_protocol": ["ipc"]}, load=0)
heavy = self.module.Request("heavy", prompt_token_ids_len=10)
heavy = self.module.Request(
request_id="heavy",
prompt=None,
prompt_token_ids=None,
prompt_token_ids_len=10,
messages=None,
history=None,
tools=None,
system=None,
eos_token_ids=None,
)
infer.reqs_queue.append(heavy)
picked = infer.get_requests(
available_blocks=1,
@@ -1281,9 +1458,31 @@ class BackgroundWorkerTest(SplitWiseSchedulerTestCase):
infer.reqs_queue.clear()
infer.reqs_queue.append(
self.module.Request("long", prompt_token_ids_len=config.long_prefill_token_threshold + 10)
self.module.Request(
request_id="long",
prompt=None,
prompt_token_ids=None,
prompt_token_ids_len=config.long_prefill_token_threshold + 10,
messages=None,
history=None,
tools=None,
system=None,
eos_token_ids=None,
)
)
infer.reqs_queue.append(
self.module.Request(
request_id="short",
prompt=None,
prompt_token_ids=None,
prompt_token_ids_len=2,
messages=None,
history=None,
tools=None,
system=None,
eos_token_ids=None,
)
)
infer.reqs_queue.append(self.module.Request("short", prompt_token_ids_len=2))
selected = infer.get_requests(
available_blocks=100,
block_size=4,