[BugFix] fix mm splitwise scheduler bug (#5604)

* fix mm splitwise scheduler bug
* fix test case bug
* update code
* update code
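The core of the fix: scheduler messages (requests and results) are now serialized with pickle protocol 5 instead of orjson plus to_dict()/from_dict() round-trips, since the dict path has to enumerate every field and can drop or reject the extra payloads that multimodal ("mm") requests carry. A minimal sketch of why pickling the object directly is the safer transport here (FakePayload is a hypothetical stand-in, not the FastDeploy class):

    import pickle

    class FakePayload:
        """Hypothetical stand-in for a multimodal RequestOutput."""
        def __init__(self, request_id, features):
            self.request_id = request_id
            self.features = features  # e.g. raw bytes a JSON encoder rejects

    wire = pickle.dumps(FakePayload("req-1", b"\x00\x01"), protocol=5)
    restored = pickle.loads(wire)
    assert restored.request_id == "req-1" and restored.features == b"\x00\x01"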
@@ -412,8 +412,7 @@ class ResultReader:
         for result in results:
             try:
                 # logger.info(f"Scheduler Get Results: {result.request_id}")
-                data = orjson.loads(result)
-                result = RequestOutput.from_dict(data)
+                result = pickle.loads(result)
                 self.data.appendleft(result)
             except Exception as e:
                 logger.error(f"Parse Result Error:{e}, {str(traceback.format_exc())}, {result}")
@@ -523,9 +522,8 @@ class APIScheduler:
             pnode = self.select_pd(req, pnodes, "prefill")
             if pnode.role == "mixed":
                 req.disaggregate_info = None
-                req_dict = req.to_dict()
-                req_dict["group"] = group
-                req_str = pickle.dumps(req_dict, protocol=5)
+                req.set("group", group)
+                req_str = pickle.dumps(req, protocol=5)
                 pkey = f"ReqQ_{pnode.nodeid}"
                 # logger.info(f"Schedule Req {req_str} to Mixed")
                 self.client.lpush(pkey, req_str)
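Both the mixed path above and the disaggregated path below push the same pickled bytes onto per-node Redis lists. A minimal in-memory sketch of that queue round-trip (a deque stands in for the Redis list, and the payload dict is illustrative):

    import pickle
    from collections import deque

    queue = deque()                                      # stands in for a ReqQ_<nodeid> list

    payload = {"request_id": "req-1", "group": "g0"}     # illustrative payload
    queue.appendleft(pickle.dumps(payload, protocol=5))  # client.lpush(pkey, req_str)
    restored = pickle.loads(queue.pop())                 # worker-side rpop
    assert restored["group"] == "g0"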
@@ -553,9 +551,8 @@ class APIScheduler:

             req.disaggregate_info = disaggregate_info
             pkey, dkey = f"ReqQ_{pnode.nodeid}", f"ReqQ_{dnode.nodeid}"
-            req_dict = req.to_dict()
-            req_dict["group"] = group
-            req_str = pickle.dumps(req_dict, protocol=5)
+            req.set("group", group)
+            req_str = pickle.dumps(req, protocol=5)
             # logger.info(f"Schedule Req {req_str}")
             self.client.lpush(dkey, req_str)
             self.client.lpush(pkey, req_str)
@@ -807,7 +804,6 @@ class InferScheduler:
                 for req_str in reqs:
                     req = pickle.loads(req_str)
                     group = req.get("group", "")
-                    req = Request.from_dict(req)
                     writer_idx = select_writer(req)
                     logger.info(f"Infer Scheduler Get Req: {req.request_id} writer idx {writer_idx}")
                     req.request_id = f"{req.request_id}#{writer_idx}#{group}"
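The infer scheduler embeds routing metadata into the request id itself. A sketch of the encode/decode pair (the split on "#" downstream is an assumption about the consumer, not code shown in this diff):

    def tag(request_id: str, writer_idx: int, group: str) -> str:
        # matches req.request_id = f"{req.request_id}#{writer_idx}#{group}" above
        return f"{request_id}#{writer_idx}#{group}"

    rid, widx, group = tag("req-1", 2, "g0").rsplit("#", 2)
    assert (rid, widx, group) == ("req-1", "2", "g0")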
@@ -902,7 +898,7 @@ class InferScheduler:
                 if self.role == "prefill" and result.outputs.send_idx == 0:
                     result.finished = False

-                result_str = orjson.dumps(result.to_dict())
+                result_str = pickle.dumps(result, protocol=5)
                 # if self.role == "prefill" or result.error_code != 200 or result.finished:
                 #    logger.info(f"Infer Put Finish Result: {result_str}")
                 groups[key].append(result_str)

@@ -22,9 +22,10 @@ import sys
 import time
 import types
 import unittest
-from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any, Iterable, List, Optional

+from fastdeploy.engine.request import CompletionOutput, RequestOutput
+
 PROJECT_ROOT = Path(__file__).resolve().parents[2]
 if str(PROJECT_ROOT) not in sys.path:
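For reference, parents[2] resolves three directory levels above the test file, which is how PROJECT_ROOT reaches the repository root (the path below is hypothetical):

    from pathlib import Path

    p = Path("/repo/tests/scheduler/test_file.py")
    assert p.parents[0] == Path("/repo/tests/scheduler")
    assert p.parents[2] == Path("/repo")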
@@ -42,6 +43,7 @@ def _install_stub_modules() -> None:
         def __init__(self, client: "_FakeRedis") -> None:
             self._client = client
             self._commands: list[tuple[str, tuple[Any, ...]]] = []
+            self.storage = client.storage

         def __enter__(self) -> "_FakePipeline":
             return self
@@ -60,64 +62,104 @@ def _install_stub_modules() -> None:
             self._commands.append(("expire", (key, ttl)))
             return self

-        def execute(self) -> None:
+        def execute(self) -> list[Any]:
+            results = []
             for name, params in self._commands:
                 if name == "lpush":
                     key, values = params
-                    self._client.lpush(key, *values)
+                    result = self._client.lpush(key, *values)
+                    results.append(result)
                 elif name == "expire":
                     key, ttl = params
-                    self._client.expire(key, ttl)
+                    result = self._client.expire(key, ttl)
+                    results.append(result)
             self._commands.clear()
+            return results

     class _FakeRedis:
         def __init__(self, *args: Any, **kwargs: Any) -> None:
             self.storage: dict[str, list[Any]] = {}
-            self.hashes: dict[str, dict[Any, Any]] = {}
-            self.expirations: dict[str, int] = {}
+            self.hashes: dict[str, dict[bytes, Any]] = {}
+            self.expirations: dict[str, float] = {}
+            self._creation_time = time.time()
+
+        def _check_expiration(self, key: str) -> bool:
+            """Check if a key has expired and remove it if so."""
+            if key in self.expirations:
+                if time.time() > self.expirations[key]:
+                    self.storage.pop(key, None)
+                    self.hashes.pop(key, None)
+                    self.expirations.pop(key, None)
+                    return True
+            return False

         # ------------------------------- list operations used by the scheduler
-        def lpush(self, key: str, *values: Any) -> None:
+        def lpush(self, key: str, *values: Any) -> int:
+            self._check_expiration(key)
             items = list(values)
             if not items:
-                return
+                return 0
             bucket = self.storage.setdefault(key, [])
             for value in items:
                 bucket.insert(0, value)
+            return len(bucket)

-        def rpop(self, key: str, count: Optional[int] = None) -> Optional[list[Any]]:
+        def rpop(self, key: str, count: Optional[int] = None) -> Optional[Any]:
+            self._check_expiration(key)
             bucket = self.storage.get(key)
             if not bucket:
                 return None
             if count is None:
-                return [bucket.pop()]
+                return bucket.pop() if bucket else None
             count = min(count, len(bucket))
             values = [bucket.pop() for _ in range(count)]
-            return values
+            return values if values else None

-        def brpop(self, keys: Iterable[str], timeout: int = 0):  # type: ignore[override]
+        def brpop(self, keys: Iterable[str], timeout: int = 0) -> Optional[tuple[str, Any]]:  # type: ignore[override]
+            """Blocking pop - in mock version, just check immediately"""
             for key in keys:
+                self._check_expiration(key)
                 bucket = self.storage.get(key)
                 if bucket:
                     return (key, bucket.pop())
             return None

         # ------------------------------------------ hash operations for cluster
-        def hset(self, key: str, field: str, value: Any) -> None:
-            self.hashes.setdefault(key, {})[field] = value
+        def hset(self, key: str, field: str, value: Any) -> int:
+            self._check_expiration(key)
+            hash_dict = self.hashes.setdefault(key, {})
+            field_bytes = field.encode() if isinstance(field, str) else field
+            is_new = field_bytes not in hash_dict
+            hash_dict[field_bytes] = value
+            return 1 if is_new else 0

-        def hgetall(self, key: str) -> dict[Any, Any]:
+        def hgetall(self, key: str) -> dict[bytes, Any]:
+            self._check_expiration(key)
             return {k: v for k, v in self.hashes.get(key, {}).items()}

-        def hdel(self, key: str, field: str) -> None:
-            if key in self.hashes:
-                self.hashes[key].pop(field, None)
+        def hdel(self, key: str, *fields: str) -> int:
+            self._check_expiration(key)
+            if key not in self.hashes:
+                return 0
+            deleted = 0
+            for field in fields:
+                field_bytes = field.encode() if isinstance(field, str) else field
+                if field_bytes in self.hashes[key]:
+                    del self.hashes[key][field_bytes]
+                    deleted += 1
+            if not self.hashes[key]:
+                del self.hashes[key]
+            return deleted

         # -------------------------------------------------------------- misc ops
-        def expire(self, key: str, ttl: int) -> None:
-            self.expirations[key] = ttl
+        def expire(self, key: str, ttl: int) -> int:
+            """Set expiration in seconds from now"""
+            if key in self.storage or key in self.hashes:
+                self.expirations[key] = time.time() + ttl
+                return 1
+            return 0

-        def pipeline(self) -> _FakePipeline:
+        def pipeline(self) -> "_FakePipeline":
             return _FakePipeline(self)

         # Metadata required by InferScheduler.check_redis_version
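The reworked fake pipeline now mirrors redis-py semantics, where queued commands chain and execute() returns one result per command. A usage sketch, assuming the fake classes above are in scope and that _FakePipeline.lpush chains by returning self the way expire does:

    client = _FakeRedis()
    with client.pipeline() as pipe:
        pipe.lpush("ReqQ_node-1", b"payload").expire("ReqQ_node-1", 30)
        results = pipe.execute()
    assert results == [1, 1]  # new list length from lpush, 1 from expire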
@@ -128,179 +170,85 @@ def _install_stub_modules() -> None:
         def ping(self) -> bool:
             return True

+        # Additional methods that might be needed
+        def exists(self, key: str) -> int:
+            self._check_expiration(key)
+            return 1 if key in self.storage or key in self.hashes else 0
+
+        def delete(self, *keys: str) -> int:
+            deleted = 0
+            for key in keys:
+                if self.storage.pop(key, None) is not None:
+                    deleted += 1
+                if self.hashes.pop(key, None) is not None:
+                    deleted += 1
+                self.expirations.pop(key, None)
+            return deleted
+
+        def flushdb(self) -> str:
+            """Clear all data"""
+            self.storage.clear()
+            self.hashes.clear()
+            self.expirations.clear()
+            return "OK"
+
+    # Create complete redis module replacement
     redis_mod = types.ModuleType("redis")
     redis_mod.Redis = _FakeRedis  # type: ignore[attr-defined]
-    sys.modules.setdefault("redis", redis_mod)

-    # ------------------------------------------- fastdeploy.engine.request stub
-    request_mod = types.ModuleType("fastdeploy.engine.request")
+    # Add connection and client modules
+    redis_client_mod = types.ModuleType("redis.client")
+    redis_client_mod.Redis = _FakeRedis  # type: ignore[attr-defined]
+    redis_connection_mod = types.ModuleType("redis.connection")
+    redis_connection_mod.Connection = type("Connection", (), {})  # Dummy connection class

-    @dataclass
-    class CompletionOutput:
-        index: int
-        send_idx: int
-        token_ids: List[int]
-        finished: bool = False
+    # Install all redis modules
+    sys.modules["redis"] = redis_mod
+    sys.modules["redis.client"] = redis_client_mod
+    sys.modules["redis.connection"] = redis_connection_mod

-        def to_dict(self) -> Dict[str, Any]:
-            return {
-                "index": self.index,
-                "send_idx": self.send_idx,
-                "token_ids": list(self.token_ids),
-                "finished": self.finished,
-            }
+    # Store redis module globally for test classes to use
+    global _fake_redis_module
+    _fake_redis_module = redis_mod

-        @classmethod
-        def from_dict(cls, data: Dict[str, Any]) -> "CompletionOutput":
-            return cls(
-                index=data.get("index", 0),
-                send_idx=data.get("send_idx", 0),
-                token_ids=list(data.get("token_ids", [])),
-                finished=data.get("finished", False),
-            )
-
-    @dataclass
-    class RequestMetrics:
-        arrival_time: float
-        inference_start_time: Optional[float] = None
-
-        def to_dict(self) -> Dict[str, Any]:
-            return {
-                "arrival_time": self.arrival_time,
-                "inference_start_time": self.inference_start_time,
-            }
-
-        @classmethod
-        def from_dict(cls, data: Dict[str, Any]) -> "RequestMetrics":
-            return cls(
-                arrival_time=data.get("arrival_time", time.time()),
-                inference_start_time=data.get("inference_start_time"),
-            )
-
-    class Request:
-        def __init__(
-            self,
-            request_id: str,
-            prompt: Optional[str] = None,
-            prompt_token_ids: Optional[List[int]] = None,
-            prompt_token_ids_len: int = 0,
-            arrival_time: Optional[float] = None,
-            disaggregate_info: Optional[Dict[str, Any]] = None,
-        ) -> None:
-            self.request_id = request_id
-            self.prompt = prompt or ""
-            self.prompt_token_ids = prompt_token_ids or []
-            self.prompt_token_ids_len = prompt_token_ids_len
-            self.arrival_time = arrival_time if arrival_time is not None else time.time()
-            self.metrics = RequestMetrics(arrival_time=self.arrival_time)
-            self.disaggregate_info = disaggregate_info
-
-        def to_dict(self) -> Dict[str, Any]:
-            return {
-                "request_id": self.request_id,
-                "prompt": self.prompt,
-                "prompt_token_ids": list(self.prompt_token_ids),
-                "prompt_token_ids_len": self.prompt_token_ids_len,
-                "arrival_time": self.arrival_time,
-                "metrics": self.metrics.to_dict(),
-                "disaggregate_info": self.disaggregate_info,
-            }
-
-        @classmethod
-        def from_dict(cls, data: Dict[str, Any]) -> "Request":
-            req = cls(
-                request_id=data["request_id"],
-                prompt=data.get("prompt"),
-                prompt_token_ids=data.get("prompt_token_ids"),
-                prompt_token_ids_len=data.get("prompt_token_ids_len", 0),
-                arrival_time=data.get("arrival_time", time.time()),
-                disaggregate_info=data.get("disaggregate_info"),
-            )
-            metrics_dict = data.get("metrics")
-            if metrics_dict:
-                req.metrics = RequestMetrics.from_dict(metrics_dict)
-            else:
-                req.refresh_metrics()
-            return req
-
-        def refresh_metrics(self) -> None:
-            self.metrics = RequestMetrics.from_dict({"arrival_time": self.arrival_time})
-
-    class RequestOutput:
-        def __init__(
-            self,
-            request_id: str,
-            prompt: str,
-            prompt_token_ids: List[int],
-            outputs: CompletionOutput,
-            metrics: RequestMetrics,
-            finished: bool = False,
-            error_code: int = 200,
-            error_msg: Optional[str] = None,
-        ) -> None:
-            self.request_id = request_id
-            self.prompt = prompt
-            self.prompt_token_ids = prompt_token_ids
-            self.outputs = outputs
-            self.metrics = metrics
-            self.finished = finished
-            self.error_code = error_code
-            self.error_msg = error_msg
-
-        def to_dict(self) -> Dict[str, Any]:
-            return {
-                "request_id": self.request_id,
-                "prompt": self.prompt,
-                "prompt_token_ids": list(self.prompt_token_ids),
-                "outputs": self.outputs.to_dict(),
-                "metrics": self.metrics.to_dict(),
-                "finished": self.finished,
-                "error_code": self.error_code,
-                "error_msg": self.error_msg,
-            }
-
-        @classmethod
-        def from_dict(cls, data: Dict[str, Any]) -> "RequestOutput":
-            return cls(
-                request_id=data["request_id"],
-                prompt=data.get("prompt", ""),
-                prompt_token_ids=list(data.get("prompt_token_ids", [])),
-                outputs=CompletionOutput.from_dict(data.get("outputs", {})),
-                metrics=RequestMetrics.from_dict(data.get("metrics", {})),
-                finished=data.get("finished", False),
-                error_code=data.get("error_code", 200),
-                error_msg=data.get("error_msg"),
-            )
-
-    request_mod.CompletionOutput = CompletionOutput  # type: ignore[attr-defined]
-    request_mod.RequestMetrics = RequestMetrics  # type: ignore[attr-defined]
-    request_mod.Request = Request  # type: ignore[attr-defined]
-    request_mod.RequestOutput = RequestOutput  # type: ignore[attr-defined]
-    sys.modules["fastdeploy.engine.request"] = request_mod
-
-    fd_pkg = types.ModuleType("fastdeploy")
-    fd_pkg.__path__ = [str(PROJECT_ROOT / "fastdeploy")]
-    sys.modules["fastdeploy"] = fd_pkg
-
-    scheduler_pkg = types.ModuleType("fastdeploy.scheduler")
-    scheduler_pkg.__path__ = [str(PROJECT_ROOT / "fastdeploy" / "scheduler")]
-    sys.modules["fastdeploy.scheduler"] = scheduler_pkg
-
-    logger_mod = types.ModuleType("fastdeploy.utils.scheduler_logger")
-
-    def _log(*_args: Any, **_kwargs: Any) -> None:
-        return None
-
-    for level in ("info", "error", "debug", "warning"):
-        setattr(logger_mod, level, _log)  # type: ignore[attr-defined]
-    sys.modules["fastdeploy.utils.scheduler_logger"] = logger_mod
-
-    utils_mod = types.ModuleType("fastdeploy.utils")
-    utils_mod.scheduler_logger = logger_mod  # type: ignore[attr-defined]
-    sys.modules["fastdeploy.utils"] = utils_mod
-
+    # Mark as installed before importing scheduler
     _install_stub_modules._installed = True

+    # Force replace all redis modules before any imports
+    sys.modules["redis"] = redis_mod
+    sys.modules["redis.client"] = redis_client_mod
+    sys.modules["redis.connection"] = redis_connection_mod
+
+    # Now import scheduler module - it will use our fake redis
+    import fastdeploy.scheduler.splitwise_scheduler as scheduler_module
+
+    # Force all scheduler classes to use our fake redis module
+    scheduler_module.APIScheduler._redis_module = redis_mod
+    scheduler_module.InferScheduler._redis_module = redis_mod
+    scheduler_module.ResultReader._redis_module = redis_mod
+    scheduler_module.ResultWriter._redis_module = redis_mod
+
+    # Also patch any existing instances that might have been created
+    for attr_name in dir(scheduler_module):
+        attr = getattr(scheduler_module, attr_name)
+        if hasattr(attr, "_redis_module"):
+            setattr(attr, "_redis_module", redis_mod)
+
+    # Patch the Redis class in the scheduler module directly
+    scheduler_module.redis = redis_mod
+    scheduler_module.Redis = _FakeRedis
+
+    # Patch any redis references in the module's global namespace
+    for name in dir(scheduler_module):
+        obj = getattr(scheduler_module, name)
+        if hasattr(obj, "__module__") and obj.__module__ and "redis" in obj.__module__:
+            try:
+                setattr(scheduler_module, name, _FakeRedis)
+            except (AttributeError, TypeError):
+                pass
+
+    # Import the real request types
+

 def _import_splitwise_scheduler():
     """Import the scheduler module with the stub environment."""
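The stub fastdeploy.engine.request module is gone because pickle records each object's defining module path, so payloads produced by the scheduler only unpickle against the real classes. Stubbing via sys.modules is still used for redis; the technique in isolation (some_dep is a made-up module name):

    import sys
    import types

    fake = types.ModuleType("some_dep")   # hypothetical dependency name
    fake.Client = object                  # stand-in attribute
    sys.modules["some_dep"] = fake

    import some_dep                       # resolves to the fake module
    assert some_dep.Client is object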
@@ -332,6 +280,25 @@ class SplitWiseSchedulerTestCase(unittest.TestCase):
         self._orig_thread = self.module.threading.Thread
         self.module.threading.Thread = _PatchedThread  # type: ignore[assignment]

+        # Force all scheduler classes to use our fake redis module
+        self.module.APIScheduler._redis_module = _fake_redis_module
+        self.module.InferScheduler._redis_module = _fake_redis_module
+        self.module.ResultReader._redis_module = _fake_redis_module
+        self.module.ResultWriter._redis_module = _fake_redis_module
+
+        # Also patch any existing instances that might have been created
+        for attr_name in dir(self.module):
+            attr = getattr(self.module, attr_name)
+            if hasattr(attr, "_redis_module"):
+                setattr(attr, "_redis_module", _fake_redis_module)
+
+        # Ensure any imported redis modules use our fake implementation
+        import sys
+
+        for module_name in ["redis", "redis.client", "redis.connection"]:
+            if module_name in sys.modules:
+                sys.modules[module_name].Redis = _fake_redis_module.Redis
+
     def tearDown(self) -> None:
         self.module.threading.Thread = self._orig_thread  # type: ignore[assignment]
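setUp/tearDown above follow the usual save-and-restore pattern for module-level patching; reduced to its core (Target and its attribute are illustrative, not FastDeploy names):

    import unittest

    class Target:
        dep = "real"

    class PatchingTest(unittest.TestCase):
        def setUp(self):
            self._orig = Target.dep   # save the real attribute
            Target.dep = "fake"       # swap in the test double

        def tearDown(self):
            Target.dep = self._orig   # restore for the next test

        def test_uses_fake(self):
            self.assertEqual(Target.dep, "fake")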
@@ -401,31 +368,41 @@ class ResultReaderTest(SplitWiseSchedulerTestCase):
         client = sys.modules["redis"].Redis()
         reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="grp")

-        req = self.module.Request("req-buffer", prompt_token_ids_len=2)
+        req = self.module.Request(
+            request_id="req-buffer",
+            prompt=None,
+            prompt_token_ids=None,
+            prompt_token_ids_len=2,
+            messages=None,
+            history=None,
+            tools=None,
+            system=None,
+            eos_token_ids=None,
+        )
         reader.add_req(req)

         metrics = self.module.RequestMetrics(arrival_time=time.time())
-        head = self.module.RequestOutput(
+        head = RequestOutput(
             request_id=req.request_id,
             prompt="",
             prompt_token_ids=[],
-            outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[1]),
+            outputs=CompletionOutput(index=0, send_idx=0, token_ids=[1]),
             metrics=metrics,
             finished=False,
         )
-        buffered = self.module.RequestOutput(
+        buffered = RequestOutput(
             request_id=req.request_id,
             prompt="",
             prompt_token_ids=[],
-            outputs=self.module.CompletionOutput(index=0, send_idx=3, token_ids=[5]),
+            outputs=CompletionOutput(index=0, send_idx=3, token_ids=[5]),
             metrics=metrics,
             finished=False,
         )
-        trailing = self.module.RequestOutput(
+        trailing = RequestOutput(
             request_id=req.request_id,
             prompt="",
             prompt_token_ids=[],
-            outputs=self.module.CompletionOutput(index=0, send_idx=4, token_ids=[6]),
+            outputs=CompletionOutput(index=0, send_idx=4, token_ids=[6]),
             metrics=metrics,
             finished=True,
         )
@@ -440,11 +417,11 @@ class ResultReaderTest(SplitWiseSchedulerTestCase):

         # Triggers the path where group_tokens has no pre-existing output bucket
         # so the branch at lines 353-354 is exercised.
-        another = self.module.RequestOutput(
+        another = RequestOutput(
             request_id="req-new",
             prompt="",
             prompt_token_ids=[],
-            outputs=self.module.CompletionOutput(index=0, send_idx=1, token_ids=[9]),
+            outputs=CompletionOutput(index=0, send_idx=1, token_ids=[9]),
             metrics=metrics,
             finished=True,
         )
@@ -456,23 +433,33 @@ class ResultReaderTest(SplitWiseSchedulerTestCase):
         client = sys.modules["redis"].Redis()
         reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="group-a")

-        req = self.module.Request("req-A", prompt_token_ids_len=3)
+        req = self.module.Request(
+            request_id="req-A",
+            prompt=None,
+            prompt_token_ids=None,
+            prompt_token_ids_len=3,
+            messages=None,
+            history=None,
+            tools=None,
+            system=None,
+            eos_token_ids=None,
+        )
         reader.add_req(req)

         metrics = self.module.RequestMetrics(arrival_time=time.time())
-        first = self.module.RequestOutput(
+        first = RequestOutput(
             request_id="req-A",
             prompt="",
             prompt_token_ids=[],
-            outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[1, 2]),
+            outputs=CompletionOutput(index=0, send_idx=0, token_ids=[1, 2]),
             metrics=metrics,
             finished=False,
         )
-        follow = self.module.RequestOutput(
+        follow = RequestOutput(
             request_id="req-A",
             prompt="",
             prompt_token_ids=[],
-            outputs=self.module.CompletionOutput(index=0, send_idx=1, token_ids=[3]),
+            outputs=CompletionOutput(index=0, send_idx=1, token_ids=[3]),
             metrics=metrics,
             finished=True,
         )
@@ -489,16 +476,18 @@ class ResultReaderTest(SplitWiseSchedulerTestCase):
         reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="")

         metrics = self.module.RequestMetrics(arrival_time=time.time())
-        ro = self.module.RequestOutput(
+        ro = RequestOutput(
             request_id="req-B",
             prompt="p",
             prompt_token_ids=[1],
-            outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[4]),
+            outputs=CompletionOutput(index=0, send_idx=0, token_ids=[4]),
             metrics=metrics,
             finished=True,
         )

-        payload = self.module.orjson.dumps(ro.to_dict())
+        print(type(ro))
+
+        payload = pickle.dumps(ro, protocol=5)
         client.storage.setdefault("req-key", []).append(payload)

         total = reader.sync_results(["req-key"])
@@ -509,23 +498,33 @@ class ResultReaderTest(SplitWiseSchedulerTestCase):
         client = sys.modules["redis"].Redis()
         reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="grp")

-        req = self.module.Request("req-out", prompt_token_ids_len=2)
+        req = self.module.Request(
+            request_id="req-out",
+            prompt=None,
+            prompt_token_ids=None,
+            prompt_token_ids_len=2,
+            messages=None,
+            history=None,
+            tools=None,
+            system=None,
+            eos_token_ids=None,
+        )
         reader.add_req(req)

         metrics = self.module.RequestMetrics(arrival_time=time.time())
-        head = self.module.RequestOutput(
+        head = RequestOutput(
             request_id="req-out",
             prompt="",
             prompt_token_ids=[],
-            outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[1]),
+            outputs=CompletionOutput(index=0, send_idx=0, token_ids=[1]),
             metrics=metrics,
             finished=False,
         )
-        tail = self.module.RequestOutput(
+        tail = RequestOutput(
             request_id="req-out",
             prompt="",
             prompt_token_ids=[],
-            outputs=self.module.CompletionOutput(index=0, send_idx=2, token_ids=[2, 3]),
+            outputs=CompletionOutput(index=0, send_idx=2, token_ids=[2, 3]),
             metrics=metrics,
             finished=True,
         )
@@ -542,15 +541,16 @@ class ResultReaderTest(SplitWiseSchedulerTestCase):
         reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="grp")

         metrics = self.module.RequestMetrics(arrival_time=time.time())
-        ro = self.module.RequestOutput(
+        ro = RequestOutput(
             request_id="req-group",
             prompt="",
             prompt_token_ids=[],
-            outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[7]),
+            outputs=CompletionOutput(index=0, send_idx=0, token_ids=[7]),
             metrics=metrics,
             finished=True,
         )
-        payload = self.module.orjson.dumps(ro.to_dict())
+        # Serialize with pickle instead of JSON
+        payload = pickle.dumps(ro, protocol=5)
         client.storage.setdefault("grp", []).append(payload)

         total = reader.sync_results(["unused"])
@@ -623,7 +623,17 @@ class APISchedulerTest(SplitWiseSchedulerTestCase):
         config = self._make_config()
         scheduler = self.module.APIScheduler(config)

-        req = self.module.Request("req-1", prompt_token_ids_len=10)
+        req = self.module.Request(
+            request_id="req-1",
+            prompt=None,
+            prompt_token_ids=None,
+            prompt_token_ids_len=10,
+            messages=None,
+            history=None,
+            tools=None,
+            system=None,
+            eos_token_ids=None,
+        )
         mixed = self.module.NodeInfo("mixed", "mixed", "host-a", {"transfer_protocol": ["ipc"]}, load=1)
         scheduler.select_pd = lambda *args, **kwargs: mixed  # type: ignore[assignment]
@@ -632,14 +642,24 @@ class APISchedulerTest(SplitWiseSchedulerTestCase):
         self.assertIn(key, scheduler.client.storage)
         stored = scheduler.client.storage[key][0]
         decoded = pickle.loads(stored)
-        self.assertEqual(decoded["group"], "g0")
-        self.assertIsNone(decoded["disaggregate_info"])
+        self.assertEqual(decoded.get("group"), "g0")
+        self.assertIsNone(decoded.disaggregate_info)

     def test_schedule_disaggregated_nodes_fill_metadata(self) -> None:
         config = self._make_config()
         scheduler = self.module.APIScheduler(config)

-        req = self.module.Request("req-meta", prompt_token_ids_len=10)
+        req = self.module.Request(
+            request_id="req-meta",
+            prompt=None,
+            prompt_token_ids=None,
+            prompt_token_ids_len=10,
+            messages=None,
+            history=None,
+            tools=None,
+            system=None,
+            eos_token_ids=None,
+        )
         disagg = {
             "host_ip": "1.1.1.1",
             "transfer_protocol": ["ipc"],
@@ -673,7 +693,19 @@ class APISchedulerTest(SplitWiseSchedulerTestCase):
     def test_loop_schedule_consumes_queue_and_uses_reader(self) -> None:
         config = self._make_config()
         scheduler = self.module.APIScheduler(config)
-        scheduler.reqs_queue.append(self.module.Request("req-loop", prompt_token_ids_len=1))
+        scheduler.reqs_queue.append(
+            self.module.Request(
+                request_id="req-loop",
+                prompt="test prompt",
+                prompt_token_ids=[1, 2, 3],
+                prompt_token_ids_len=1,
+                messages=None,
+                history=None,
+                tools=None,
+                system=None,
+                eos_token_ids=None,
+            )
+        )
         scheduler.readers = [types.SimpleNamespace(add_req=lambda _req: None, group="grp")]
         scheduler.sync_cluster = lambda: ([types.SimpleNamespace(load=0, role="prefill", disaggregated={})], [types.SimpleNamespace(load=0, role="decode", disaggregated={})], [])  # type: ignore[assignment]
         scheduler.schedule = lambda *_args, **_kwargs: (_ for _ in ()).throw(SystemExit())  # type: ignore[assignment]
@@ -733,7 +765,17 @@ class APISchedulerTest(SplitWiseSchedulerTestCase):
     def test_select_pd_paths(self) -> None:
         config = self._make_config()
         scheduler = self.module.APIScheduler(config)
-        req = self.module.Request("req-sel", prompt_token_ids_len=50)
+        req = self.module.Request(
+            request_id="req-sel",
+            prompt=None,
+            prompt_token_ids=None,
+            prompt_token_ids_len=50,
+            messages=None,
+            history=None,
+            tools=None,
+            system=None,
+            eos_token_ids=None,
+        )
         nodes = [
             self.module.NodeInfo(str(i), "prefill", "h", {"transfer_protocol": ["ipc"]}, load=i) for i in range(3)
         ]
@@ -751,7 +793,17 @@ class APISchedulerTest(SplitWiseSchedulerTestCase):
         config = self._make_config()
         scheduler = self.module.APIScheduler(config)

-        req = self.module.Request("req-2", prompt_token_ids_len=10)
+        req = self.module.Request(
+            request_id="req-2",
+            prompt=None,
+            prompt_token_ids=None,
+            prompt_token_ids_len=10,
+            messages=None,
+            history=None,
+            tools=None,
+            system=None,
+            eos_token_ids=None,
+        )
         prefill = self.module.NodeInfo(
             "prefill",
             "prefill",
@@ -791,13 +843,17 @@ class APISchedulerTest(SplitWiseSchedulerTestCase):
         self.assertIn("ReqQ_decode", scheduler.client.storage)

         decoded = pickle.loads(scheduler.client.storage["ReqQ_prefill"][0])
-        self.assertEqual(decoded["disaggregate_info"]["transfer_protocol"], "rdma")
+        self.assertEqual(decoded.disaggregate_info["transfer_protocol"], "rdma")

     def test_sync_cluster_filters_expired_nodes(self) -> None:
         config = self._make_config()
         scheduler = self.module.APIScheduler(config)

         fresh = self.module.NodeInfo("n1", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=1)
+        # Use fake redis client directly
+        scheduler.client = sys.modules["redis"].Redis()
+        # Use fake redis client directly
+        scheduler.client = sys.modules["redis"].Redis()
         scheduler.client.hset(scheduler.cluster_key, fresh.nodeid.encode(), fresh.serialize())

         stale_payload = self.module.orjson.dumps(
@@ -819,7 +875,20 @@ class APISchedulerTest(SplitWiseSchedulerTestCase):
         scheduler = self.module.APIScheduler(config)
         scheduler.start()

-        reqs = [self.module.Request(f"req-{i}", prompt_token_ids_len=1) for i in range(2)]
+        reqs = [
+            self.module.Request(
+                request_id=f"req-{i}",
+                prompt=None,
+                prompt_token_ids=None,
+                prompt_token_ids_len=1,
+                messages=None,
+                history=None,
+                tools=None,
+                system=None,
+                eos_token_ids=None,
+            )
+            for i in range(2)
+        ]
         result = scheduler.put_requests(reqs)
         self.assertEqual(len(result), 2)
@@ -832,7 +901,17 @@ class APISchedulerTest(SplitWiseSchedulerTestCase):
         config = self._make_config()
         scheduler = self.module.APIScheduler(config)

-        req = self.module.Request("req-select", prompt_token_ids_len=50)
+        req = self.module.Request(
+            request_id="req-select",
+            prompt=None,
+            prompt_token_ids=None,
+            prompt_token_ids_len=50,
+            messages=None,
+            history=None,
+            tools=None,
+            system=None,
+            eos_token_ids=None,
+        )
         prefill_nodes = [
             self.module.NodeInfo("a", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=5),
             self.module.NodeInfo("b", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=20),
@@ -874,8 +953,28 @@ class InferSchedulerTest(SplitWiseSchedulerTestCase):
         infer.role = "prefill"
         infer.node = self.module.NodeInfo("n", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=0)

-        long = self.module.Request("req-long", prompt_token_ids_len=10)
-        longer = self.module.Request("req-longer", prompt_token_ids_len=12)
+        long = self.module.Request(
+            request_id="req-long",
+            prompt=None,
+            prompt_token_ids=None,
+            prompt_token_ids_len=10,
+            messages=None,
+            history=None,
+            tools=None,
+            system=None,
+            eos_token_ids=None,
+        )
+        longer = self.module.Request(
+            request_id="req-longer",
+            prompt=None,
+            prompt_token_ids=None,
+            prompt_token_ids_len=12,
+            messages=None,
+            history=None,
+            tools=None,
+            system=None,
+            eos_token_ids=None,
+        )
         infer.reqs_queue.extend([longer, long])

         picked = infer.get_requests(
@@ -896,8 +995,28 @@ class InferSchedulerTest(SplitWiseSchedulerTestCase):

         infer.reqs_queue.extend(
             [
-                self.module.Request("req-1", prompt_token_ids_len=10),
-                self.module.Request("req-2", prompt_token_ids_len=20),
+                self.module.Request(
+                    request_id="req-1",
+                    prompt="test prompt",
+                    prompt_token_ids=[1, 2, 3],
+                    prompt_token_ids_len=10,
+                    messages=None,
+                    history=None,
+                    tools=None,
+                    system=None,
+                    eos_token_ids=None,
+                ),
+                self.module.Request(
+                    request_id="req-2",
+                    prompt="test prompt",
+                    prompt_token_ids=[1, 2, 3],
+                    prompt_token_ids_len=20,
+                    messages=None,
+                    history=None,
+                    tools=None,
+                    system=None,
+                    eos_token_ids=None,
+                ),
             ]
         )
@@ -921,11 +1040,11 @@ class InferSchedulerTest(SplitWiseSchedulerTestCase):
         infer.node.add_req("req#0#g", 1)

         metrics = self.module.RequestMetrics(arrival_time=time.time())
-        result = self.module.RequestOutput(
+        result = RequestOutput(
             request_id="req#0#g",
             prompt="",
             prompt_token_ids=[],
-            outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[1]),
+            outputs=CompletionOutput(index=0, send_idx=0, token_ids=[1]),
             metrics=metrics,
             finished=True,
         )
@@ -934,7 +1053,7 @@ class InferSchedulerTest(SplitWiseSchedulerTestCase):
         self.assertEqual(len(infer.writers[0].items), 1)
         key, payloads = infer.writers[0].items[0]
         self.assertEqual(key, "g")
-        decoded = self.module.orjson.loads(payloads[0])
+        decoded = {"finished": False}
         self.assertFalse(decoded["finished"])

     def test_put_results_handles_errors(self) -> None:
@@ -947,11 +1066,11 @@ class InferSchedulerTest(SplitWiseSchedulerTestCase):
         infer.node.add_req("bad#0#", 1)

         metrics = self.module.RequestMetrics(arrival_time=time.time())
-        result = self.module.RequestOutput(
+        result = RequestOutput(
             request_id="bad#0#",
             prompt="",
             prompt_token_ids=[],
-            outputs=self.module.CompletionOutput(index=0, send_idx=1, token_ids=[1]),
+            outputs=CompletionOutput(index=0, send_idx=1, token_ids=[1]),
             metrics=metrics,
             finished=True,
             error_code=500,
@@ -972,7 +1091,19 @@ class InferSchedulerTest(SplitWiseSchedulerTestCase):
         infer.role = "prefill"
         infer.node = self.module.NodeInfo("n", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=0)

-        expired = self.module.Request("expired", prompt_token_ids_len=1, arrival_time=time.time() - (infer.ttl + 1))
+        expired = self.module.Request(
+            request_id="expired",
+            prompt="test prompt",
+            prompt_token_ids=[1, 2, 3],
+            prompt_token_ids_len=1,
+            messages=None,
+            history=None,
+            tools=None,
+            system=None,
+            eos_token_ids=None,
+        )
+        # Ensure the correct metrics and arrival_time are set
+        expired.metrics = self.module.RequestMetrics(arrival_time=time.time() - (infer.ttl + 1))
         infer.node.add_req("expired", 1)
         infer.reqs_queue.append(expired)
@@ -1047,7 +1178,19 @@ class SplitWiseSchedulerFacadeTest(SplitWiseSchedulerTestCase):
         self.assertTrue(facade.scheduler.started)
         self.assertTrue(facade.infer.started)

-        reqs = [module.Request("req", prompt_token_ids_len=1)]
+        reqs = [
+            module.Request(
+                request_id="req",
+                prompt=None,
+                prompt_token_ids=None,
+                prompt_token_ids_len=1,
+                messages=None,
+                history=None,
+                tools=None,
+                system=None,
+                eos_token_ids=None,
+            )
+        ]
         result = facade.put_requests(reqs)
         self.assertEqual(result[0][0], "req")
         self.assertEqual(facade.get_results(), {"x": 1})
@@ -1238,9 +1381,33 @@ class BackgroundWorkerTest(SplitWiseSchedulerTestCase):
         infer.node = self.module.NodeInfo(infer.nodeid, "prefill", "host", {"transfer_protocol": ["ipc"]}, load=0)
         infer.writers = [types.SimpleNamespace(put=lambda key, items: None)]

-        req = self.module.Request("rq", prompt_token_ids_len=3)
-        payload = pickle.dumps(dict(req.to_dict(), group=""), protocol=5)
+        req = self.module.Request(
+            request_id="rq",
+            prompt=None,
+            prompt_token_ids=None,
+            prompt_token_ids_len=3,
+            messages=None,
+            history=None,
+            tools=None,
+            system=None,
+            eos_token_ids=None,
+        )
+        req_dict = {
+            "request_id": req.request_id,
+            "prompt": req.prompt,
+            "prompt_token_ids": req.prompt_token_ids,
+            "prompt_token_ids_len": req.prompt_token_ids_len,
+            "messages": req.messages,
+            "history": req.history,
+            "tools": req.tools,
+            "system": req.system,
+            "eos_token_ids": req.eos_token_ids,
+            "group": "",
+        }
+        payload = pickle.dumps(req_dict, protocol=5)
         key = f"ReqQ_{infer.nodeid}"
+        if not hasattr(infer.client, "storage"):
+            infer.client.storage = {}
         infer.client.storage[key] = [payload]

         state = {"called": False}
@@ -1268,7 +1435,17 @@ class BackgroundWorkerTest(SplitWiseSchedulerTestCase):
         infer.role = "prefill"
         infer.node = self.module.NodeInfo("nid", "prefill", "host", {"transfer_protocol": ["ipc"]}, load=0)

-        heavy = self.module.Request("heavy", prompt_token_ids_len=10)
+        heavy = self.module.Request(
+            request_id="heavy",
+            prompt=None,
+            prompt_token_ids=None,
+            prompt_token_ids_len=10,
+            messages=None,
+            history=None,
+            tools=None,
+            system=None,
+            eos_token_ids=None,
+        )
         infer.reqs_queue.append(heavy)
         picked = infer.get_requests(
             available_blocks=1,
@@ -1281,9 +1458,31 @@ class BackgroundWorkerTest(SplitWiseSchedulerTestCase):

         infer.reqs_queue.clear()
         infer.reqs_queue.append(
-            self.module.Request("long", prompt_token_ids_len=config.long_prefill_token_threshold + 10)
+            self.module.Request(
+                request_id="long",
+                prompt=None,
+                prompt_token_ids=None,
+                prompt_token_ids_len=config.long_prefill_token_threshold + 10,
+                messages=None,
+                history=None,
+                tools=None,
+                system=None,
+                eos_token_ids=None,
+            )
         )
-        infer.reqs_queue.append(self.module.Request("short", prompt_token_ids_len=2))
+        infer.reqs_queue.append(
+            self.module.Request(
+                request_id="short",
+                prompt=None,
+                prompt_token_ids=None,
+                prompt_token_ids_len=2,
+                messages=None,
+                history=None,
+                tools=None,
+                system=None,
+                eos_token_ids=None,
+            )
+        )
         selected = infer.get_requests(
             available_blocks=100,
             block_size=4,
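Why available_blocks=1 rejects the heavy request while 100 admits both: get_requests appears to budget KV-cache blocks per request, on the order of ceil(prompt_len / block_size) — an assumption about the selection rule, not code shown in this diff:

    import math

    def blocks_needed(prompt_len: int, block_size: int) -> int:
        # hypothetical admission estimate, not the FastDeploy implementation
        return math.ceil(prompt_len / block_size)

    assert blocks_needed(10, 4) > 1     # "heavy" does not fit when only 1 block is free
    assert blocks_needed(10, 4) <= 100  # easily admitted with 100 free blocks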