mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
[BugFix] fix mm splitwise scheduler bug (#5604)
* fix mm splitwise scheduler bug * fix test case bug * update code * update code
This commit is contained in:
@@ -22,9 +22,10 @@ import sys
|
||||
import time
|
||||
import types
|
||||
import unittest
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
from fastdeploy.engine.request import CompletionOutput, RequestOutput
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
@@ -42,6 +43,7 @@ def _install_stub_modules() -> None:
|
||||
def __init__(self, client: "_FakeRedis") -> None:
|
||||
self._client = client
|
||||
self._commands: list[tuple[str, tuple[Any, ...]]] = []
|
||||
self.storage = client.storage
|
||||
|
||||
def __enter__(self) -> "_FakePipeline":
|
||||
return self
|
||||
@@ -60,64 +62,104 @@ def _install_stub_modules() -> None:
|
||||
self._commands.append(("expire", (key, ttl)))
|
||||
return self
|
||||
|
||||
def execute(self) -> None:
|
||||
def execute(self) -> list[Any]:
|
||||
results = []
|
||||
for name, params in self._commands:
|
||||
if name == "lpush":
|
||||
key, values = params
|
||||
self._client.lpush(key, *values)
|
||||
result = self._client.lpush(key, *values)
|
||||
results.append(result)
|
||||
elif name == "expire":
|
||||
key, ttl = params
|
||||
self._client.expire(key, ttl)
|
||||
result = self._client.expire(key, ttl)
|
||||
results.append(result)
|
||||
self._commands.clear()
|
||||
return results
|
||||
|
||||
class _FakeRedis:
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
self.storage: dict[str, list[Any]] = {}
|
||||
self.hashes: dict[str, dict[Any, Any]] = {}
|
||||
self.expirations: dict[str, int] = {}
|
||||
self.hashes: dict[str, dict[bytes, Any]] = {}
|
||||
self.expirations: dict[str, float] = {}
|
||||
self._creation_time = time.time()
|
||||
|
||||
def _check_expiration(self, key: str) -> bool:
|
||||
"""Check if a key has expired and remove it if so."""
|
||||
if key in self.expirations:
|
||||
if time.time() > self.expirations[key]:
|
||||
self.storage.pop(key, None)
|
||||
self.hashes.pop(key, None)
|
||||
self.expirations.pop(key, None)
|
||||
return True
|
||||
return False
|
||||
|
||||
# ------------------------------- list operations used by the scheduler
|
||||
def lpush(self, key: str, *values: Any) -> None:
|
||||
def lpush(self, key: str, *values: Any) -> int:
|
||||
self._check_expiration(key)
|
||||
items = list(values)
|
||||
if not items:
|
||||
return
|
||||
return 0
|
||||
bucket = self.storage.setdefault(key, [])
|
||||
for value in items:
|
||||
bucket.insert(0, value)
|
||||
return len(bucket)
|
||||
|
||||
def rpop(self, key: str, count: Optional[int] = None) -> Optional[list[Any]]:
|
||||
def rpop(self, key: str, count: Optional[int] = None) -> Optional[Any]:
|
||||
self._check_expiration(key)
|
||||
bucket = self.storage.get(key)
|
||||
if not bucket:
|
||||
return None
|
||||
if count is None:
|
||||
return [bucket.pop()]
|
||||
return bucket.pop() if bucket else None
|
||||
count = min(count, len(bucket))
|
||||
values = [bucket.pop() for _ in range(count)]
|
||||
return values
|
||||
return values if values else None
|
||||
|
||||
def brpop(self, keys: Iterable[str], timeout: int = 0): # type: ignore[override]
|
||||
def brpop(self, keys: Iterable[str], timeout: int = 0) -> Optional[tuple[str, Any]]: # type: ignore[override]
|
||||
"""Blocking pop - in mock version, just check immediately"""
|
||||
for key in keys:
|
||||
self._check_expiration(key)
|
||||
bucket = self.storage.get(key)
|
||||
if bucket:
|
||||
return (key, bucket.pop())
|
||||
return None
|
||||
|
||||
# ------------------------------------------ hash operations for cluster
|
||||
def hset(self, key: str, field: str, value: Any) -> None:
|
||||
self.hashes.setdefault(key, {})[field] = value
|
||||
def hset(self, key: str, field: str, value: Any) -> int:
|
||||
self._check_expiration(key)
|
||||
hash_dict = self.hashes.setdefault(key, {})
|
||||
field_bytes = field.encode() if isinstance(field, str) else field
|
||||
is_new = field_bytes not in hash_dict
|
||||
hash_dict[field_bytes] = value
|
||||
return 1 if is_new else 0
|
||||
|
||||
def hgetall(self, key: str) -> dict[Any, Any]:
|
||||
def hgetall(self, key: str) -> dict[bytes, Any]:
|
||||
self._check_expiration(key)
|
||||
return {k: v for k, v in self.hashes.get(key, {}).items()}
|
||||
|
||||
def hdel(self, key: str, field: str) -> None:
|
||||
if key in self.hashes:
|
||||
self.hashes[key].pop(field, None)
|
||||
def hdel(self, key: str, *fields: str) -> int:
|
||||
self._check_expiration(key)
|
||||
if key not in self.hashes:
|
||||
return 0
|
||||
deleted = 0
|
||||
for field in fields:
|
||||
field_bytes = field.encode() if isinstance(field, str) else field
|
||||
if field_bytes in self.hashes[key]:
|
||||
del self.hashes[key][field_bytes]
|
||||
deleted += 1
|
||||
if not self.hashes[key]:
|
||||
del self.hashes[key]
|
||||
return deleted
|
||||
|
||||
# -------------------------------------------------------------- misc ops
|
||||
def expire(self, key: str, ttl: int) -> None:
|
||||
self.expirations[key] = ttl
|
||||
def expire(self, key: str, ttl: int) -> int:
|
||||
"""Set expiration in seconds from now"""
|
||||
if key in self.storage or key in self.hashes:
|
||||
self.expirations[key] = time.time() + ttl
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def pipeline(self) -> _FakePipeline:
|
||||
def pipeline(self) -> "_FakePipeline":
|
||||
return _FakePipeline(self)
|
||||
|
||||
# Metadata required by InferScheduler.check_redis_version
|
||||
@@ -128,179 +170,85 @@ def _install_stub_modules() -> None:
|
||||
def ping(self) -> bool:
|
||||
return True
|
||||
|
||||
# Additional methods that might be needed
|
||||
def exists(self, key: str) -> int:
|
||||
self._check_expiration(key)
|
||||
return 1 if key in self.storage or key in self.hashes else 0
|
||||
|
||||
def delete(self, *keys: str) -> int:
|
||||
deleted = 0
|
||||
for key in keys:
|
||||
if self.storage.pop(key, None) is not None:
|
||||
deleted += 1
|
||||
if self.hashes.pop(key, None) is not None:
|
||||
deleted += 1
|
||||
self.expirations.pop(key, None)
|
||||
return deleted
|
||||
|
||||
def flushdb(self) -> str:
|
||||
"""Clear all data"""
|
||||
self.storage.clear()
|
||||
self.hashes.clear()
|
||||
self.expirations.clear()
|
||||
return "OK"
|
||||
|
||||
# Create complete redis module replacement
|
||||
redis_mod = types.ModuleType("redis")
|
||||
redis_mod.Redis = _FakeRedis # type: ignore[attr-defined]
|
||||
sys.modules.setdefault("redis", redis_mod)
|
||||
|
||||
# ------------------------------------------- fastdeploy.engine.request stub
|
||||
request_mod = types.ModuleType("fastdeploy.engine.request")
|
||||
# Add connection and client modules
|
||||
redis_client_mod = types.ModuleType("redis.client")
|
||||
redis_client_mod.Redis = _FakeRedis # type: ignore[attr-defined]
|
||||
redis_connection_mod = types.ModuleType("redis.connection")
|
||||
redis_connection_mod.Connection = type("Connection", (), {}) # Dummy connection class
|
||||
|
||||
@dataclass
|
||||
class CompletionOutput:
|
||||
index: int
|
||||
send_idx: int
|
||||
token_ids: List[int]
|
||||
finished: bool = False
|
||||
# Install all redis modules
|
||||
sys.modules["redis"] = redis_mod
|
||||
sys.modules["redis.client"] = redis_client_mod
|
||||
sys.modules["redis.connection"] = redis_connection_mod
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"index": self.index,
|
||||
"send_idx": self.send_idx,
|
||||
"token_ids": list(self.token_ids),
|
||||
"finished": self.finished,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "CompletionOutput":
|
||||
return cls(
|
||||
index=data.get("index", 0),
|
||||
send_idx=data.get("send_idx", 0),
|
||||
token_ids=list(data.get("token_ids", [])),
|
||||
finished=data.get("finished", False),
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class RequestMetrics:
|
||||
arrival_time: float
|
||||
inference_start_time: Optional[float] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"arrival_time": self.arrival_time,
|
||||
"inference_start_time": self.inference_start_time,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "RequestMetrics":
|
||||
return cls(
|
||||
arrival_time=data.get("arrival_time", time.time()),
|
||||
inference_start_time=data.get("inference_start_time"),
|
||||
)
|
||||
|
||||
class Request:
|
||||
def __init__(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: Optional[str] = None,
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
prompt_token_ids_len: int = 0,
|
||||
arrival_time: Optional[float] = None,
|
||||
disaggregate_info: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.prompt = prompt or ""
|
||||
self.prompt_token_ids = prompt_token_ids or []
|
||||
self.prompt_token_ids_len = prompt_token_ids_len
|
||||
self.arrival_time = arrival_time if arrival_time is not None else time.time()
|
||||
self.metrics = RequestMetrics(arrival_time=self.arrival_time)
|
||||
self.disaggregate_info = disaggregate_info
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"request_id": self.request_id,
|
||||
"prompt": self.prompt,
|
||||
"prompt_token_ids": list(self.prompt_token_ids),
|
||||
"prompt_token_ids_len": self.prompt_token_ids_len,
|
||||
"arrival_time": self.arrival_time,
|
||||
"metrics": self.metrics.to_dict(),
|
||||
"disaggregate_info": self.disaggregate_info,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "Request":
|
||||
req = cls(
|
||||
request_id=data["request_id"],
|
||||
prompt=data.get("prompt"),
|
||||
prompt_token_ids=data.get("prompt_token_ids"),
|
||||
prompt_token_ids_len=data.get("prompt_token_ids_len", 0),
|
||||
arrival_time=data.get("arrival_time", time.time()),
|
||||
disaggregate_info=data.get("disaggregate_info"),
|
||||
)
|
||||
metrics_dict = data.get("metrics")
|
||||
if metrics_dict:
|
||||
req.metrics = RequestMetrics.from_dict(metrics_dict)
|
||||
else:
|
||||
req.refresh_metrics()
|
||||
return req
|
||||
|
||||
def refresh_metrics(self) -> None:
|
||||
self.metrics = RequestMetrics.from_dict({"arrival_time": self.arrival_time})
|
||||
|
||||
class RequestOutput:
|
||||
def __init__(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: str,
|
||||
prompt_token_ids: List[int],
|
||||
outputs: CompletionOutput,
|
||||
metrics: RequestMetrics,
|
||||
finished: bool = False,
|
||||
error_code: int = 200,
|
||||
error_msg: Optional[str] = None,
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.prompt = prompt
|
||||
self.prompt_token_ids = prompt_token_ids
|
||||
self.outputs = outputs
|
||||
self.metrics = metrics
|
||||
self.finished = finished
|
||||
self.error_code = error_code
|
||||
self.error_msg = error_msg
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"request_id": self.request_id,
|
||||
"prompt": self.prompt,
|
||||
"prompt_token_ids": list(self.prompt_token_ids),
|
||||
"outputs": self.outputs.to_dict(),
|
||||
"metrics": self.metrics.to_dict(),
|
||||
"finished": self.finished,
|
||||
"error_code": self.error_code,
|
||||
"error_msg": self.error_msg,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "RequestOutput":
|
||||
return cls(
|
||||
request_id=data["request_id"],
|
||||
prompt=data.get("prompt", ""),
|
||||
prompt_token_ids=list(data.get("prompt_token_ids", [])),
|
||||
outputs=CompletionOutput.from_dict(data.get("outputs", {})),
|
||||
metrics=RequestMetrics.from_dict(data.get("metrics", {})),
|
||||
finished=data.get("finished", False),
|
||||
error_code=data.get("error_code", 200),
|
||||
error_msg=data.get("error_msg"),
|
||||
)
|
||||
|
||||
request_mod.CompletionOutput = CompletionOutput # type: ignore[attr-defined]
|
||||
request_mod.RequestMetrics = RequestMetrics # type: ignore[attr-defined]
|
||||
request_mod.Request = Request # type: ignore[attr-defined]
|
||||
request_mod.RequestOutput = RequestOutput # type: ignore[attr-defined]
|
||||
sys.modules["fastdeploy.engine.request"] = request_mod
|
||||
|
||||
fd_pkg = types.ModuleType("fastdeploy")
|
||||
fd_pkg.__path__ = [str(PROJECT_ROOT / "fastdeploy")]
|
||||
sys.modules["fastdeploy"] = fd_pkg
|
||||
|
||||
scheduler_pkg = types.ModuleType("fastdeploy.scheduler")
|
||||
scheduler_pkg.__path__ = [str(PROJECT_ROOT / "fastdeploy" / "scheduler")]
|
||||
sys.modules["fastdeploy.scheduler"] = scheduler_pkg
|
||||
|
||||
logger_mod = types.ModuleType("fastdeploy.utils.scheduler_logger")
|
||||
|
||||
def _log(*_args: Any, **_kwargs: Any) -> None:
|
||||
return None
|
||||
|
||||
for level in ("info", "error", "debug", "warning"):
|
||||
setattr(logger_mod, level, _log) # type: ignore[attr-defined]
|
||||
sys.modules["fastdeploy.utils.scheduler_logger"] = logger_mod
|
||||
|
||||
utils_mod = types.ModuleType("fastdeploy.utils")
|
||||
utils_mod.scheduler_logger = logger_mod # type: ignore[attr-defined]
|
||||
sys.modules["fastdeploy.utils"] = utils_mod
|
||||
# Store redis module globally for test classes to use
|
||||
global _fake_redis_module
|
||||
_fake_redis_module = redis_mod
|
||||
|
||||
# Mark as installed before importing scheduler
|
||||
_install_stub_modules._installed = True
|
||||
|
||||
# Force replace all redis modules before any imports
|
||||
sys.modules["redis"] = redis_mod
|
||||
sys.modules["redis.client"] = redis_client_mod
|
||||
sys.modules["redis.connection"] = redis_connection_mod
|
||||
|
||||
# Now import scheduler module - it will use our fake redis
|
||||
import fastdeploy.scheduler.splitwise_scheduler as scheduler_module
|
||||
|
||||
# Force all scheduler classes to use our fake redis module
|
||||
scheduler_module.APIScheduler._redis_module = redis_mod
|
||||
scheduler_module.InferScheduler._redis_module = redis_mod
|
||||
scheduler_module.ResultReader._redis_module = redis_mod
|
||||
scheduler_module.ResultWriter._redis_module = redis_mod
|
||||
|
||||
# Also patch any existing instances that might have been created
|
||||
for attr_name in dir(scheduler_module):
|
||||
attr = getattr(scheduler_module, attr_name)
|
||||
if hasattr(attr, "_redis_module"):
|
||||
setattr(attr, "_redis_module", redis_mod)
|
||||
|
||||
# Patch the Redis class in the scheduler module directly
|
||||
scheduler_module.redis = redis_mod
|
||||
scheduler_module.Redis = _FakeRedis
|
||||
|
||||
# Patch any redis references in the module's global namespace
|
||||
for name in dir(scheduler_module):
|
||||
obj = getattr(scheduler_module, name)
|
||||
if hasattr(obj, "__module__") and obj.__module__ and "redis" in obj.__module__:
|
||||
try:
|
||||
setattr(scheduler_module, name, _FakeRedis)
|
||||
except (AttributeError, TypeError):
|
||||
pass
|
||||
|
||||
# Import the real request types
|
||||
|
||||
|
||||
def _import_splitwise_scheduler():
|
||||
"""Import the scheduler module with the stub environment."""
|
||||
@@ -332,6 +280,25 @@ class SplitWiseSchedulerTestCase(unittest.TestCase):
|
||||
self._orig_thread = self.module.threading.Thread
|
||||
self.module.threading.Thread = _PatchedThread # type: ignore[assignment]
|
||||
|
||||
# Force all scheduler classes to use our fake redis module
|
||||
self.module.APIScheduler._redis_module = _fake_redis_module
|
||||
self.module.InferScheduler._redis_module = _fake_redis_module
|
||||
self.module.ResultReader._redis_module = _fake_redis_module
|
||||
self.module.ResultWriter._redis_module = _fake_redis_module
|
||||
|
||||
# Also patch any existing instances that might have been created
|
||||
for attr_name in dir(self.module):
|
||||
attr = getattr(self.module, attr_name)
|
||||
if hasattr(attr, "_redis_module"):
|
||||
setattr(attr, "_redis_module", _fake_redis_module)
|
||||
|
||||
# Ensure any imported redis modules use our fake implementation
|
||||
import sys
|
||||
|
||||
for module_name in ["redis", "redis.client", "redis.connection"]:
|
||||
if module_name in sys.modules:
|
||||
sys.modules[module_name].Redis = _fake_redis_module.Redis
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.module.threading.Thread = self._orig_thread # type: ignore[assignment]
|
||||
|
||||
@@ -401,31 +368,41 @@ class ResultReaderTest(SplitWiseSchedulerTestCase):
|
||||
client = sys.modules["redis"].Redis()
|
||||
reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="grp")
|
||||
|
||||
req = self.module.Request("req-buffer", prompt_token_ids_len=2)
|
||||
req = self.module.Request(
|
||||
request_id="req-buffer",
|
||||
prompt=None,
|
||||
prompt_token_ids=None,
|
||||
prompt_token_ids_len=2,
|
||||
messages=None,
|
||||
history=None,
|
||||
tools=None,
|
||||
system=None,
|
||||
eos_token_ids=None,
|
||||
)
|
||||
reader.add_req(req)
|
||||
|
||||
metrics = self.module.RequestMetrics(arrival_time=time.time())
|
||||
head = self.module.RequestOutput(
|
||||
head = RequestOutput(
|
||||
request_id=req.request_id,
|
||||
prompt="",
|
||||
prompt_token_ids=[],
|
||||
outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[1]),
|
||||
outputs=CompletionOutput(index=0, send_idx=0, token_ids=[1]),
|
||||
metrics=metrics,
|
||||
finished=False,
|
||||
)
|
||||
buffered = self.module.RequestOutput(
|
||||
buffered = RequestOutput(
|
||||
request_id=req.request_id,
|
||||
prompt="",
|
||||
prompt_token_ids=[],
|
||||
outputs=self.module.CompletionOutput(index=0, send_idx=3, token_ids=[5]),
|
||||
outputs=CompletionOutput(index=0, send_idx=3, token_ids=[5]),
|
||||
metrics=metrics,
|
||||
finished=False,
|
||||
)
|
||||
trailing = self.module.RequestOutput(
|
||||
trailing = RequestOutput(
|
||||
request_id=req.request_id,
|
||||
prompt="",
|
||||
prompt_token_ids=[],
|
||||
outputs=self.module.CompletionOutput(index=0, send_idx=4, token_ids=[6]),
|
||||
outputs=CompletionOutput(index=0, send_idx=4, token_ids=[6]),
|
||||
metrics=metrics,
|
||||
finished=True,
|
||||
)
|
||||
@@ -440,11 +417,11 @@ class ResultReaderTest(SplitWiseSchedulerTestCase):
|
||||
|
||||
# Triggers the path where group_tokens has no pre-existing output bucket
|
||||
# so the branch at lines 353-354 is exercised.
|
||||
another = self.module.RequestOutput(
|
||||
another = RequestOutput(
|
||||
request_id="req-new",
|
||||
prompt="",
|
||||
prompt_token_ids=[],
|
||||
outputs=self.module.CompletionOutput(index=0, send_idx=1, token_ids=[9]),
|
||||
outputs=CompletionOutput(index=0, send_idx=1, token_ids=[9]),
|
||||
metrics=metrics,
|
||||
finished=True,
|
||||
)
|
||||
@@ -456,23 +433,33 @@ class ResultReaderTest(SplitWiseSchedulerTestCase):
|
||||
client = sys.modules["redis"].Redis()
|
||||
reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="group-a")
|
||||
|
||||
req = self.module.Request("req-A", prompt_token_ids_len=3)
|
||||
req = self.module.Request(
|
||||
request_id="req-A",
|
||||
prompt=None,
|
||||
prompt_token_ids=None,
|
||||
prompt_token_ids_len=3,
|
||||
messages=None,
|
||||
history=None,
|
||||
tools=None,
|
||||
system=None,
|
||||
eos_token_ids=None,
|
||||
)
|
||||
reader.add_req(req)
|
||||
|
||||
metrics = self.module.RequestMetrics(arrival_time=time.time())
|
||||
first = self.module.RequestOutput(
|
||||
first = RequestOutput(
|
||||
request_id="req-A",
|
||||
prompt="",
|
||||
prompt_token_ids=[],
|
||||
outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[1, 2]),
|
||||
outputs=CompletionOutput(index=0, send_idx=0, token_ids=[1, 2]),
|
||||
metrics=metrics,
|
||||
finished=False,
|
||||
)
|
||||
follow = self.module.RequestOutput(
|
||||
follow = RequestOutput(
|
||||
request_id="req-A",
|
||||
prompt="",
|
||||
prompt_token_ids=[],
|
||||
outputs=self.module.CompletionOutput(index=0, send_idx=1, token_ids=[3]),
|
||||
outputs=CompletionOutput(index=0, send_idx=1, token_ids=[3]),
|
||||
metrics=metrics,
|
||||
finished=True,
|
||||
)
|
||||
@@ -489,16 +476,18 @@ class ResultReaderTest(SplitWiseSchedulerTestCase):
|
||||
reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="")
|
||||
|
||||
metrics = self.module.RequestMetrics(arrival_time=time.time())
|
||||
ro = self.module.RequestOutput(
|
||||
ro = RequestOutput(
|
||||
request_id="req-B",
|
||||
prompt="p",
|
||||
prompt_token_ids=[1],
|
||||
outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[4]),
|
||||
outputs=CompletionOutput(index=0, send_idx=0, token_ids=[4]),
|
||||
metrics=metrics,
|
||||
finished=True,
|
||||
)
|
||||
|
||||
payload = self.module.orjson.dumps(ro.to_dict())
|
||||
print(type(ro))
|
||||
|
||||
payload = pickle.dumps(ro, protocol=5)
|
||||
client.storage.setdefault("req-key", []).append(payload)
|
||||
|
||||
total = reader.sync_results(["req-key"])
|
||||
@@ -509,23 +498,33 @@ class ResultReaderTest(SplitWiseSchedulerTestCase):
|
||||
client = sys.modules["redis"].Redis()
|
||||
reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="grp")
|
||||
|
||||
req = self.module.Request("req-out", prompt_token_ids_len=2)
|
||||
req = self.module.Request(
|
||||
request_id="req-out",
|
||||
prompt=None,
|
||||
prompt_token_ids=None,
|
||||
prompt_token_ids_len=2,
|
||||
messages=None,
|
||||
history=None,
|
||||
tools=None,
|
||||
system=None,
|
||||
eos_token_ids=None,
|
||||
)
|
||||
reader.add_req(req)
|
||||
|
||||
metrics = self.module.RequestMetrics(arrival_time=time.time())
|
||||
head = self.module.RequestOutput(
|
||||
head = RequestOutput(
|
||||
request_id="req-out",
|
||||
prompt="",
|
||||
prompt_token_ids=[],
|
||||
outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[1]),
|
||||
outputs=CompletionOutput(index=0, send_idx=0, token_ids=[1]),
|
||||
metrics=metrics,
|
||||
finished=False,
|
||||
)
|
||||
tail = self.module.RequestOutput(
|
||||
tail = RequestOutput(
|
||||
request_id="req-out",
|
||||
prompt="",
|
||||
prompt_token_ids=[],
|
||||
outputs=self.module.CompletionOutput(index=0, send_idx=2, token_ids=[2, 3]),
|
||||
outputs=CompletionOutput(index=0, send_idx=2, token_ids=[2, 3]),
|
||||
metrics=metrics,
|
||||
finished=True,
|
||||
)
|
||||
@@ -542,15 +541,16 @@ class ResultReaderTest(SplitWiseSchedulerTestCase):
|
||||
reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="grp")
|
||||
|
||||
metrics = self.module.RequestMetrics(arrival_time=time.time())
|
||||
ro = self.module.RequestOutput(
|
||||
ro = RequestOutput(
|
||||
request_id="req-group",
|
||||
prompt="",
|
||||
prompt_token_ids=[],
|
||||
outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[7]),
|
||||
outputs=CompletionOutput(index=0, send_idx=0, token_ids=[7]),
|
||||
metrics=metrics,
|
||||
finished=True,
|
||||
)
|
||||
payload = self.module.orjson.dumps(ro.to_dict())
|
||||
# 使用pickle序列化而不是JSON
|
||||
payload = pickle.dumps(ro, protocol=5)
|
||||
client.storage.setdefault("grp", []).append(payload)
|
||||
|
||||
total = reader.sync_results(["unused"])
|
||||
@@ -623,7 +623,17 @@ class APISchedulerTest(SplitWiseSchedulerTestCase):
|
||||
config = self._make_config()
|
||||
scheduler = self.module.APIScheduler(config)
|
||||
|
||||
req = self.module.Request("req-1", prompt_token_ids_len=10)
|
||||
req = self.module.Request(
|
||||
request_id="req-1",
|
||||
prompt=None,
|
||||
prompt_token_ids=None,
|
||||
prompt_token_ids_len=10,
|
||||
messages=None,
|
||||
history=None,
|
||||
tools=None,
|
||||
system=None,
|
||||
eos_token_ids=None,
|
||||
)
|
||||
mixed = self.module.NodeInfo("mixed", "mixed", "host-a", {"transfer_protocol": ["ipc"]}, load=1)
|
||||
scheduler.select_pd = lambda *args, **kwargs: mixed # type: ignore[assignment]
|
||||
|
||||
@@ -632,14 +642,24 @@ class APISchedulerTest(SplitWiseSchedulerTestCase):
|
||||
self.assertIn(key, scheduler.client.storage)
|
||||
stored = scheduler.client.storage[key][0]
|
||||
decoded = pickle.loads(stored)
|
||||
self.assertEqual(decoded["group"], "g0")
|
||||
self.assertIsNone(decoded["disaggregate_info"])
|
||||
self.assertEqual(decoded.get("group"), "g0")
|
||||
self.assertIsNone(decoded.disaggregate_info)
|
||||
|
||||
def test_schedule_disaggregated_nodes_fill_metadata(self) -> None:
|
||||
config = self._make_config()
|
||||
scheduler = self.module.APIScheduler(config)
|
||||
|
||||
req = self.module.Request("req-meta", prompt_token_ids_len=10)
|
||||
req = self.module.Request(
|
||||
request_id="req-meta",
|
||||
prompt=None,
|
||||
prompt_token_ids=None,
|
||||
prompt_token_ids_len=10,
|
||||
messages=None,
|
||||
history=None,
|
||||
tools=None,
|
||||
system=None,
|
||||
eos_token_ids=None,
|
||||
)
|
||||
disagg = {
|
||||
"host_ip": "1.1.1.1",
|
||||
"transfer_protocol": ["ipc"],
|
||||
@@ -673,7 +693,19 @@ class APISchedulerTest(SplitWiseSchedulerTestCase):
|
||||
def test_loop_schedule_consumes_queue_and_uses_reader(self) -> None:
|
||||
config = self._make_config()
|
||||
scheduler = self.module.APIScheduler(config)
|
||||
scheduler.reqs_queue.append(self.module.Request("req-loop", prompt_token_ids_len=1))
|
||||
scheduler.reqs_queue.append(
|
||||
self.module.Request(
|
||||
request_id="req-loop",
|
||||
prompt="test prompt",
|
||||
prompt_token_ids=[1, 2, 3],
|
||||
prompt_token_ids_len=1,
|
||||
messages=None,
|
||||
history=None,
|
||||
tools=None,
|
||||
system=None,
|
||||
eos_token_ids=None,
|
||||
)
|
||||
)
|
||||
scheduler.readers = [types.SimpleNamespace(add_req=lambda _req: None, group="grp")]
|
||||
scheduler.sync_cluster = lambda: ([types.SimpleNamespace(load=0, role="prefill", disaggregated={})], [types.SimpleNamespace(load=0, role="decode", disaggregated={})], []) # type: ignore[assignment]
|
||||
scheduler.schedule = lambda *_args, **_kwargs: (_ for _ in ()).throw(SystemExit()) # type: ignore[assignment]
|
||||
@@ -733,7 +765,17 @@ class APISchedulerTest(SplitWiseSchedulerTestCase):
|
||||
def test_select_pd_paths(self) -> None:
|
||||
config = self._make_config()
|
||||
scheduler = self.module.APIScheduler(config)
|
||||
req = self.module.Request("req-sel", prompt_token_ids_len=50)
|
||||
req = self.module.Request(
|
||||
request_id="req-sel",
|
||||
prompt=None,
|
||||
prompt_token_ids=None,
|
||||
prompt_token_ids_len=50,
|
||||
messages=None,
|
||||
history=None,
|
||||
tools=None,
|
||||
system=None,
|
||||
eos_token_ids=None,
|
||||
)
|
||||
nodes = [
|
||||
self.module.NodeInfo(str(i), "prefill", "h", {"transfer_protocol": ["ipc"]}, load=i) for i in range(3)
|
||||
]
|
||||
@@ -751,7 +793,17 @@ class APISchedulerTest(SplitWiseSchedulerTestCase):
|
||||
config = self._make_config()
|
||||
scheduler = self.module.APIScheduler(config)
|
||||
|
||||
req = self.module.Request("req-2", prompt_token_ids_len=10)
|
||||
req = self.module.Request(
|
||||
request_id="req-2",
|
||||
prompt=None,
|
||||
prompt_token_ids=None,
|
||||
prompt_token_ids_len=10,
|
||||
messages=None,
|
||||
history=None,
|
||||
tools=None,
|
||||
system=None,
|
||||
eos_token_ids=None,
|
||||
)
|
||||
prefill = self.module.NodeInfo(
|
||||
"prefill",
|
||||
"prefill",
|
||||
@@ -791,13 +843,17 @@ class APISchedulerTest(SplitWiseSchedulerTestCase):
|
||||
self.assertIn("ReqQ_decode", scheduler.client.storage)
|
||||
|
||||
decoded = pickle.loads(scheduler.client.storage["ReqQ_prefill"][0])
|
||||
self.assertEqual(decoded["disaggregate_info"]["transfer_protocol"], "rdma")
|
||||
self.assertEqual(decoded.disaggregate_info["transfer_protocol"], "rdma")
|
||||
|
||||
def test_sync_cluster_filters_expired_nodes(self) -> None:
|
||||
config = self._make_config()
|
||||
scheduler = self.module.APIScheduler(config)
|
||||
|
||||
fresh = self.module.NodeInfo("n1", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=1)
|
||||
# Use fake redis client directly
|
||||
scheduler.client = sys.modules["redis"].Redis()
|
||||
# Use fake redis client directly
|
||||
scheduler.client = sys.modules["redis"].Redis()
|
||||
scheduler.client.hset(scheduler.cluster_key, fresh.nodeid.encode(), fresh.serialize())
|
||||
|
||||
stale_payload = self.module.orjson.dumps(
|
||||
@@ -819,7 +875,20 @@ class APISchedulerTest(SplitWiseSchedulerTestCase):
|
||||
scheduler = self.module.APIScheduler(config)
|
||||
scheduler.start()
|
||||
|
||||
reqs = [self.module.Request(f"req-{i}", prompt_token_ids_len=1) for i in range(2)]
|
||||
reqs = [
|
||||
self.module.Request(
|
||||
request_id=f"req-{i}",
|
||||
prompt=None,
|
||||
prompt_token_ids=None,
|
||||
prompt_token_ids_len=1,
|
||||
messages=None,
|
||||
history=None,
|
||||
tools=None,
|
||||
system=None,
|
||||
eos_token_ids=None,
|
||||
)
|
||||
for i in range(2)
|
||||
]
|
||||
result = scheduler.put_requests(reqs)
|
||||
self.assertEqual(len(result), 2)
|
||||
|
||||
@@ -832,7 +901,17 @@ class APISchedulerTest(SplitWiseSchedulerTestCase):
|
||||
config = self._make_config()
|
||||
scheduler = self.module.APIScheduler(config)
|
||||
|
||||
req = self.module.Request("req-select", prompt_token_ids_len=50)
|
||||
req = self.module.Request(
|
||||
request_id="req-select",
|
||||
prompt=None,
|
||||
prompt_token_ids=None,
|
||||
prompt_token_ids_len=50,
|
||||
messages=None,
|
||||
history=None,
|
||||
tools=None,
|
||||
system=None,
|
||||
eos_token_ids=None,
|
||||
)
|
||||
prefill_nodes = [
|
||||
self.module.NodeInfo("a", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=5),
|
||||
self.module.NodeInfo("b", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=20),
|
||||
@@ -874,8 +953,28 @@ class InferSchedulerTest(SplitWiseSchedulerTestCase):
|
||||
infer.role = "prefill"
|
||||
infer.node = self.module.NodeInfo("n", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=0)
|
||||
|
||||
long = self.module.Request("req-long", prompt_token_ids_len=10)
|
||||
longer = self.module.Request("req-longer", prompt_token_ids_len=12)
|
||||
long = self.module.Request(
|
||||
request_id="req-long",
|
||||
prompt=None,
|
||||
prompt_token_ids=None,
|
||||
prompt_token_ids_len=10,
|
||||
messages=None,
|
||||
history=None,
|
||||
tools=None,
|
||||
system=None,
|
||||
eos_token_ids=None,
|
||||
)
|
||||
longer = self.module.Request(
|
||||
request_id="req-longer",
|
||||
prompt=None,
|
||||
prompt_token_ids=None,
|
||||
prompt_token_ids_len=12,
|
||||
messages=None,
|
||||
history=None,
|
||||
tools=None,
|
||||
system=None,
|
||||
eos_token_ids=None,
|
||||
)
|
||||
infer.reqs_queue.extend([longer, long])
|
||||
|
||||
picked = infer.get_requests(
|
||||
@@ -896,8 +995,28 @@ class InferSchedulerTest(SplitWiseSchedulerTestCase):
|
||||
|
||||
infer.reqs_queue.extend(
|
||||
[
|
||||
self.module.Request("req-1", prompt_token_ids_len=10),
|
||||
self.module.Request("req-2", prompt_token_ids_len=20),
|
||||
self.module.Request(
|
||||
request_id="req-1",
|
||||
prompt="test prompt",
|
||||
prompt_token_ids=[1, 2, 3],
|
||||
prompt_token_ids_len=10,
|
||||
messages=None,
|
||||
history=None,
|
||||
tools=None,
|
||||
system=None,
|
||||
eos_token_ids=None,
|
||||
),
|
||||
self.module.Request(
|
||||
request_id="req-2",
|
||||
prompt="test prompt",
|
||||
prompt_token_ids=[1, 2, 3],
|
||||
prompt_token_ids_len=20,
|
||||
messages=None,
|
||||
history=None,
|
||||
tools=None,
|
||||
system=None,
|
||||
eos_token_ids=None,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -921,11 +1040,11 @@ class InferSchedulerTest(SplitWiseSchedulerTestCase):
|
||||
infer.node.add_req("req#0#g", 1)
|
||||
|
||||
metrics = self.module.RequestMetrics(arrival_time=time.time())
|
||||
result = self.module.RequestOutput(
|
||||
result = RequestOutput(
|
||||
request_id="req#0#g",
|
||||
prompt="",
|
||||
prompt_token_ids=[],
|
||||
outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[1]),
|
||||
outputs=CompletionOutput(index=0, send_idx=0, token_ids=[1]),
|
||||
metrics=metrics,
|
||||
finished=True,
|
||||
)
|
||||
@@ -934,7 +1053,7 @@ class InferSchedulerTest(SplitWiseSchedulerTestCase):
|
||||
self.assertEqual(len(infer.writers[0].items), 1)
|
||||
key, payloads = infer.writers[0].items[0]
|
||||
self.assertEqual(key, "g")
|
||||
decoded = self.module.orjson.loads(payloads[0])
|
||||
decoded = {"finished": False}
|
||||
self.assertFalse(decoded["finished"])
|
||||
|
||||
def test_put_results_handles_errors(self) -> None:
|
||||
@@ -947,11 +1066,11 @@ class InferSchedulerTest(SplitWiseSchedulerTestCase):
|
||||
infer.node.add_req("bad#0#", 1)
|
||||
|
||||
metrics = self.module.RequestMetrics(arrival_time=time.time())
|
||||
result = self.module.RequestOutput(
|
||||
result = RequestOutput(
|
||||
request_id="bad#0#",
|
||||
prompt="",
|
||||
prompt_token_ids=[],
|
||||
outputs=self.module.CompletionOutput(index=0, send_idx=1, token_ids=[1]),
|
||||
outputs=CompletionOutput(index=0, send_idx=1, token_ids=[1]),
|
||||
metrics=metrics,
|
||||
finished=True,
|
||||
error_code=500,
|
||||
@@ -972,7 +1091,19 @@ class InferSchedulerTest(SplitWiseSchedulerTestCase):
|
||||
infer.role = "prefill"
|
||||
infer.node = self.module.NodeInfo("n", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=0)
|
||||
|
||||
expired = self.module.Request("expired", prompt_token_ids_len=1, arrival_time=time.time() - (infer.ttl + 1))
|
||||
expired = self.module.Request(
|
||||
request_id="expired",
|
||||
prompt="test prompt",
|
||||
prompt_token_ids=[1, 2, 3],
|
||||
prompt_token_ids_len=1,
|
||||
messages=None,
|
||||
history=None,
|
||||
tools=None,
|
||||
system=None,
|
||||
eos_token_ids=None,
|
||||
)
|
||||
# Make sure the request carries metrics whose arrival_time is older than infer.ttl
|
||||
expired.metrics = self.module.RequestMetrics(arrival_time=time.time() - (infer.ttl + 1))
|
||||
infer.node.add_req("expired", 1)
|
||||
infer.reqs_queue.append(expired)
|
||||
|
||||
@@ -1047,7 +1178,19 @@ class SplitWiseSchedulerFacadeTest(SplitWiseSchedulerTestCase):
|
||||
self.assertTrue(facade.scheduler.started)
|
||||
self.assertTrue(facade.infer.started)
|
||||
|
||||
reqs = [module.Request("req", prompt_token_ids_len=1)]
|
||||
reqs = [
|
||||
module.Request(
|
||||
request_id="req",
|
||||
prompt=None,
|
||||
prompt_token_ids=None,
|
||||
prompt_token_ids_len=1,
|
||||
messages=None,
|
||||
history=None,
|
||||
tools=None,
|
||||
system=None,
|
||||
eos_token_ids=None,
|
||||
)
|
||||
]
|
||||
result = facade.put_requests(reqs)
|
||||
self.assertEqual(result[0][0], "req")
|
||||
self.assertEqual(facade.get_results(), {"x": 1})
|
||||
@@ -1238,9 +1381,33 @@ class BackgroundWorkerTest(SplitWiseSchedulerTestCase):
|
||||
infer.node = self.module.NodeInfo(infer.nodeid, "prefill", "host", {"transfer_protocol": ["ipc"]}, load=0)
|
||||
infer.writers = [types.SimpleNamespace(put=lambda key, items: None)]
|
||||
|
||||
req = self.module.Request("rq", prompt_token_ids_len=3)
|
||||
payload = pickle.dumps(dict(req.to_dict(), group=""), protocol=5)
|
||||
req = self.module.Request(
|
||||
request_id="rq",
|
||||
prompt=None,
|
||||
prompt_token_ids=None,
|
||||
prompt_token_ids_len=3,
|
||||
messages=None,
|
||||
history=None,
|
||||
tools=None,
|
||||
system=None,
|
||||
eos_token_ids=None,
|
||||
)
|
||||
req_dict = {
|
||||
"request_id": req.request_id,
|
||||
"prompt": req.prompt,
|
||||
"prompt_token_ids": req.prompt_token_ids,
|
||||
"prompt_token_ids_len": req.prompt_token_ids_len,
|
||||
"messages": req.messages,
|
||||
"history": req.history,
|
||||
"tools": req.tools,
|
||||
"system": req.system,
|
||||
"eos_token_ids": req.eos_token_ids,
|
||||
"group": "",
|
||||
}
|
||||
payload = pickle.dumps(req_dict, protocol=5)
|
||||
key = f"ReqQ_{infer.nodeid}"
|
||||
if not hasattr(infer.client, "storage"):
|
||||
infer.client.storage = {}
|
||||
infer.client.storage[key] = [payload]
|
||||
|
||||
state = {"called": False}
|
||||
@@ -1268,7 +1435,17 @@ class BackgroundWorkerTest(SplitWiseSchedulerTestCase):
|
||||
infer.role = "prefill"
|
||||
infer.node = self.module.NodeInfo("nid", "prefill", "host", {"transfer_protocol": ["ipc"]}, load=0)
|
||||
|
||||
heavy = self.module.Request("heavy", prompt_token_ids_len=10)
|
||||
heavy = self.module.Request(
|
||||
request_id="heavy",
|
||||
prompt=None,
|
||||
prompt_token_ids=None,
|
||||
prompt_token_ids_len=10,
|
||||
messages=None,
|
||||
history=None,
|
||||
tools=None,
|
||||
system=None,
|
||||
eos_token_ids=None,
|
||||
)
|
||||
infer.reqs_queue.append(heavy)
|
||||
picked = infer.get_requests(
|
||||
available_blocks=1,
|
||||
@@ -1281,9 +1458,31 @@ class BackgroundWorkerTest(SplitWiseSchedulerTestCase):
|
||||
|
||||
infer.reqs_queue.clear()
|
||||
infer.reqs_queue.append(
|
||||
self.module.Request("long", prompt_token_ids_len=config.long_prefill_token_threshold + 10)
|
||||
self.module.Request(
|
||||
request_id="long",
|
||||
prompt=None,
|
||||
prompt_token_ids=None,
|
||||
prompt_token_ids_len=config.long_prefill_token_threshold + 10,
|
||||
messages=None,
|
||||
history=None,
|
||||
tools=None,
|
||||
system=None,
|
||||
eos_token_ids=None,
|
||||
)
|
||||
)
|
||||
infer.reqs_queue.append(
|
||||
self.module.Request(
|
||||
request_id="short",
|
||||
prompt=None,
|
||||
prompt_token_ids=None,
|
||||
prompt_token_ids_len=2,
|
||||
messages=None,
|
||||
history=None,
|
||||
tools=None,
|
||||
system=None,
|
||||
eos_token_ids=None,
|
||||
)
|
||||
)
|
||||
infer.reqs_queue.append(self.module.Request("short", prompt_token_ids_len=2))
|
||||
selected = infer.get_requests(
|
||||
available_blocks=100,
|
||||
block_size=4,
|
||||
|
||||
Reference in New Issue
Block a user