From 4fa76296d90da1b2c5b9b67a0d206e93076bb84a Mon Sep 17 00:00:00 2001 From: kevin Date: Thu, 25 Dec 2025 20:08:11 +0800 Subject: [PATCH] [BugFix] fix mm splitwise scheduler bug (#5604) * fix mm splitwise scheduler bug * fix test case bug * update code * update code --- fastdeploy/scheduler/splitwise_scheduler.py | 16 +- tests/scheduler/test_splitwise_scheduler.py | 675 +++++++++++++------- 2 files changed, 443 insertions(+), 248 deletions(-) diff --git a/fastdeploy/scheduler/splitwise_scheduler.py b/fastdeploy/scheduler/splitwise_scheduler.py index 106477e14f..cd5e073643 100644 --- a/fastdeploy/scheduler/splitwise_scheduler.py +++ b/fastdeploy/scheduler/splitwise_scheduler.py @@ -412,8 +412,7 @@ class ResultReader: for result in results: try: # logger.info(f"Scheduler Get Results: {result.request_id}") - data = orjson.loads(result) - result = RequestOutput.from_dict(data) + result = pickle.loads(result) self.data.appendleft(result) except Exception as e: logger.error(f"Parse Result Error:{e}, {str(traceback.format_exc())}, {result}") @@ -523,9 +522,8 @@ class APIScheduler: pnode = self.select_pd(req, pnodes, "prefill") if pnode.role == "mixed": req.disaggregate_info = None - req_dict = req.to_dict() - req_dict["group"] = group - req_str = pickle.dumps(req_dict, protocol=5) + req.set("group", group) + req_str = pickle.dumps(req, protocol=5) pkey = f"ReqQ_{pnode.nodeid}" # logger.info(f"Schedule Req {req_str} to Mixed") self.client.lpush(pkey, req_str) @@ -553,9 +551,8 @@ class APIScheduler: req.disaggregate_info = disaggregate_info pkey, dkey = f"ReqQ_{pnode.nodeid}", f"ReqQ_{dnode.nodeid}" - req_dict = req.to_dict() - req_dict["group"] = group - req_str = pickle.dumps(req_dict, protocol=5) + req.set("group", group) + req_str = pickle.dumps(req, protocol=5) # logger.info(f"Schedule Req {req_str}") self.client.lpush(dkey, req_str) self.client.lpush(pkey, req_str) @@ -807,7 +804,6 @@ class InferScheduler: for req_str in reqs: req = pickle.loads(req_str) group = req.get("group", "") - req = Request.from_dict(req) writer_idx = select_writer(req) logger.info(f"Infer Scheduler Get Req: {req.request_id} writer idx {writer_idx}") req.request_id = f"{req.request_id}#{writer_idx}#{group}" @@ -902,7 +898,7 @@ class InferScheduler: if self.role == "prefill" and result.outputs.send_idx == 0: result.finished = False - result_str = orjson.dumps(result.to_dict()) + result_str = pickle.dumps(result, protocol=5) # if self.role == "prefill" or result.error_code != 200 or result.finished: # logger.info(f"Infer Put Finish Result: {result_str}") groups[key].append(result_str) diff --git a/tests/scheduler/test_splitwise_scheduler.py b/tests/scheduler/test_splitwise_scheduler.py index 49130ffb7a..ffcdd6c471 100644 --- a/tests/scheduler/test_splitwise_scheduler.py +++ b/tests/scheduler/test_splitwise_scheduler.py @@ -22,9 +22,10 @@ import sys import time import types import unittest -from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Iterable, List, Optional + +from fastdeploy.engine.request import CompletionOutput, RequestOutput PROJECT_ROOT = Path(__file__).resolve().parents[2] if str(PROJECT_ROOT) not in sys.path: @@ -42,6 +43,7 @@ def _install_stub_modules() -> None: def __init__(self, client: "_FakeRedis") -> None: self._client = client self._commands: list[tuple[str, tuple[Any, ...]]] = [] + self.storage = client.storage def __enter__(self) -> "_FakePipeline": return self @@ -60,64 +62,104 @@ def _install_stub_modules() -> None: self._commands.append(("expire", (key, ttl))) return self - def execute(self) -> None: + def execute(self) -> list[Any]: + results = [] for name, params in self._commands: if name == "lpush": key, values = params - self._client.lpush(key, *values) + result = self._client.lpush(key, *values) + results.append(result) elif name == "expire": key, ttl = params - self._client.expire(key, ttl) + result = self._client.expire(key, ttl) + results.append(result) self._commands.clear() + return results class _FakeRedis: def __init__(self, *args: Any, **kwargs: Any) -> None: self.storage: dict[str, list[Any]] = {} - self.hashes: dict[str, dict[Any, Any]] = {} - self.expirations: dict[str, int] = {} + self.hashes: dict[str, dict[bytes, Any]] = {} + self.expirations: dict[str, float] = {} + self._creation_time = time.time() + + def _check_expiration(self, key: str) -> bool: + """Check if a key has expired and remove it if so.""" + if key in self.expirations: + if time.time() > self.expirations[key]: + self.storage.pop(key, None) + self.hashes.pop(key, None) + self.expirations.pop(key, None) + return True + return False # ------------------------------- list operations used by the scheduler - def lpush(self, key: str, *values: Any) -> None: + def lpush(self, key: str, *values: Any) -> int: + self._check_expiration(key) items = list(values) if not items: - return + return 0 bucket = self.storage.setdefault(key, []) for value in items: bucket.insert(0, value) + return len(bucket) - def rpop(self, key: str, count: Optional[int] = None) -> Optional[list[Any]]: + def rpop(self, key: str, count: Optional[int] = None) -> Optional[Any]: + self._check_expiration(key) bucket = self.storage.get(key) if not bucket: return None if count is None: - return [bucket.pop()] + return bucket.pop() if bucket else None count = min(count, len(bucket)) values = [bucket.pop() for _ in range(count)] - return values + return values if values else None - def brpop(self, keys: Iterable[str], timeout: int = 0): # type: ignore[override] + def brpop(self, keys: Iterable[str], timeout: int = 0) -> Optional[tuple[str, Any]]: # type: ignore[override] + """Blocking pop - in mock version, just check immediately""" for key in keys: + self._check_expiration(key) bucket = self.storage.get(key) if bucket: return (key, bucket.pop()) return None # ------------------------------------------ hash operations for cluster - def hset(self, key: str, field: str, value: Any) -> None: - self.hashes.setdefault(key, {})[field] = value + def hset(self, key: str, field: str, value: Any) -> int: + self._check_expiration(key) + hash_dict = self.hashes.setdefault(key, {}) + field_bytes = field.encode() if isinstance(field, str) else field + is_new = field_bytes not in hash_dict + hash_dict[field_bytes] = value + return 1 if is_new else 0 - def hgetall(self, key: str) -> dict[Any, Any]: + def hgetall(self, key: str) -> dict[bytes, Any]: + self._check_expiration(key) return {k: v for k, v in self.hashes.get(key, {}).items()} - def hdel(self, key: str, field: str) -> None: - if key in self.hashes: - self.hashes[key].pop(field, None) + def hdel(self, key: str, *fields: str) -> int: + self._check_expiration(key) + if key not in self.hashes: + return 0 + deleted = 0 + for field in fields: + field_bytes = field.encode() if isinstance(field, str) else field + if field_bytes in self.hashes[key]: + del self.hashes[key][field_bytes] + deleted += 1 + if not self.hashes[key]: + del self.hashes[key] + return deleted # -------------------------------------------------------------- misc ops - def expire(self, key: str, ttl: int) -> None: - self.expirations[key] = ttl + def expire(self, key: str, ttl: int) -> int: + """Set expiration in seconds from now""" + if key in self.storage or key in self.hashes: + self.expirations[key] = time.time() + ttl + return 1 + return 0 - def pipeline(self) -> _FakePipeline: + def pipeline(self) -> "_FakePipeline": return _FakePipeline(self) # Metadata required by InferScheduler.check_redis_version @@ -128,179 +170,85 @@ def _install_stub_modules() -> None: def ping(self) -> bool: return True + # Additional methods that might be needed + def exists(self, key: str) -> int: + self._check_expiration(key) + return 1 if key in self.storage or key in self.hashes else 0 + + def delete(self, *keys: str) -> int: + deleted = 0 + for key in keys: + if self.storage.pop(key, None) is not None: + deleted += 1 + if self.hashes.pop(key, None) is not None: + deleted += 1 + self.expirations.pop(key, None) + return deleted + + def flushdb(self) -> str: + """Clear all data""" + self.storage.clear() + self.hashes.clear() + self.expirations.clear() + return "OK" + + # Create complete redis module replacement redis_mod = types.ModuleType("redis") redis_mod.Redis = _FakeRedis # type: ignore[attr-defined] - sys.modules.setdefault("redis", redis_mod) - # ------------------------------------------- fastdeploy.engine.request stub - request_mod = types.ModuleType("fastdeploy.engine.request") + # Add connection and client modules + redis_client_mod = types.ModuleType("redis.client") + redis_client_mod.Redis = _FakeRedis # type: ignore[attr-defined] + redis_connection_mod = types.ModuleType("redis.connection") + redis_connection_mod.Connection = type("Connection", (), {}) # Dummy connection class - @dataclass - class CompletionOutput: - index: int - send_idx: int - token_ids: List[int] - finished: bool = False + # Install all redis modules + sys.modules["redis"] = redis_mod + sys.modules["redis.client"] = redis_client_mod + sys.modules["redis.connection"] = redis_connection_mod - def to_dict(self) -> Dict[str, Any]: - return { - "index": self.index, - "send_idx": self.send_idx, - "token_ids": list(self.token_ids), - "finished": self.finished, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "CompletionOutput": - return cls( - index=data.get("index", 0), - send_idx=data.get("send_idx", 0), - token_ids=list(data.get("token_ids", [])), - finished=data.get("finished", False), - ) - - @dataclass - class RequestMetrics: - arrival_time: float - inference_start_time: Optional[float] = None - - def to_dict(self) -> Dict[str, Any]: - return { - "arrival_time": self.arrival_time, - "inference_start_time": self.inference_start_time, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "RequestMetrics": - return cls( - arrival_time=data.get("arrival_time", time.time()), - inference_start_time=data.get("inference_start_time"), - ) - - class Request: - def __init__( - self, - request_id: str, - prompt: Optional[str] = None, - prompt_token_ids: Optional[List[int]] = None, - prompt_token_ids_len: int = 0, - arrival_time: Optional[float] = None, - disaggregate_info: Optional[Dict[str, Any]] = None, - ) -> None: - self.request_id = request_id - self.prompt = prompt or "" - self.prompt_token_ids = prompt_token_ids or [] - self.prompt_token_ids_len = prompt_token_ids_len - self.arrival_time = arrival_time if arrival_time is not None else time.time() - self.metrics = RequestMetrics(arrival_time=self.arrival_time) - self.disaggregate_info = disaggregate_info - - def to_dict(self) -> Dict[str, Any]: - return { - "request_id": self.request_id, - "prompt": self.prompt, - "prompt_token_ids": list(self.prompt_token_ids), - "prompt_token_ids_len": self.prompt_token_ids_len, - "arrival_time": self.arrival_time, - "metrics": self.metrics.to_dict(), - "disaggregate_info": self.disaggregate_info, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "Request": - req = cls( - request_id=data["request_id"], - prompt=data.get("prompt"), - prompt_token_ids=data.get("prompt_token_ids"), - prompt_token_ids_len=data.get("prompt_token_ids_len", 0), - arrival_time=data.get("arrival_time", time.time()), - disaggregate_info=data.get("disaggregate_info"), - ) - metrics_dict = data.get("metrics") - if metrics_dict: - req.metrics = RequestMetrics.from_dict(metrics_dict) - else: - req.refresh_metrics() - return req - - def refresh_metrics(self) -> None: - self.metrics = RequestMetrics.from_dict({"arrival_time": self.arrival_time}) - - class RequestOutput: - def __init__( - self, - request_id: str, - prompt: str, - prompt_token_ids: List[int], - outputs: CompletionOutput, - metrics: RequestMetrics, - finished: bool = False, - error_code: int = 200, - error_msg: Optional[str] = None, - ) -> None: - self.request_id = request_id - self.prompt = prompt - self.prompt_token_ids = prompt_token_ids - self.outputs = outputs - self.metrics = metrics - self.finished = finished - self.error_code = error_code - self.error_msg = error_msg - - def to_dict(self) -> Dict[str, Any]: - return { - "request_id": self.request_id, - "prompt": self.prompt, - "prompt_token_ids": list(self.prompt_token_ids), - "outputs": self.outputs.to_dict(), - "metrics": self.metrics.to_dict(), - "finished": self.finished, - "error_code": self.error_code, - "error_msg": self.error_msg, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "RequestOutput": - return cls( - request_id=data["request_id"], - prompt=data.get("prompt", ""), - prompt_token_ids=list(data.get("prompt_token_ids", [])), - outputs=CompletionOutput.from_dict(data.get("outputs", {})), - metrics=RequestMetrics.from_dict(data.get("metrics", {})), - finished=data.get("finished", False), - error_code=data.get("error_code", 200), - error_msg=data.get("error_msg"), - ) - - request_mod.CompletionOutput = CompletionOutput # type: ignore[attr-defined] - request_mod.RequestMetrics = RequestMetrics # type: ignore[attr-defined] - request_mod.Request = Request # type: ignore[attr-defined] - request_mod.RequestOutput = RequestOutput # type: ignore[attr-defined] - sys.modules["fastdeploy.engine.request"] = request_mod - - fd_pkg = types.ModuleType("fastdeploy") - fd_pkg.__path__ = [str(PROJECT_ROOT / "fastdeploy")] - sys.modules["fastdeploy"] = fd_pkg - - scheduler_pkg = types.ModuleType("fastdeploy.scheduler") - scheduler_pkg.__path__ = [str(PROJECT_ROOT / "fastdeploy" / "scheduler")] - sys.modules["fastdeploy.scheduler"] = scheduler_pkg - - logger_mod = types.ModuleType("fastdeploy.utils.scheduler_logger") - - def _log(*_args: Any, **_kwargs: Any) -> None: - return None - - for level in ("info", "error", "debug", "warning"): - setattr(logger_mod, level, _log) # type: ignore[attr-defined] - sys.modules["fastdeploy.utils.scheduler_logger"] = logger_mod - - utils_mod = types.ModuleType("fastdeploy.utils") - utils_mod.scheduler_logger = logger_mod # type: ignore[attr-defined] - sys.modules["fastdeploy.utils"] = utils_mod + # Store redis module globally for test classes to use + global _fake_redis_module + _fake_redis_module = redis_mod + # Mark as installed before importing scheduler _install_stub_modules._installed = True + # Force replace all redis modules before any imports + sys.modules["redis"] = redis_mod + sys.modules["redis.client"] = redis_client_mod + sys.modules["redis.connection"] = redis_connection_mod + + # Now import scheduler module - it will use our fake redis + import fastdeploy.scheduler.splitwise_scheduler as scheduler_module + + # Force all scheduler classes to use our fake redis module + scheduler_module.APIScheduler._redis_module = redis_mod + scheduler_module.InferScheduler._redis_module = redis_mod + scheduler_module.ResultReader._redis_module = redis_mod + scheduler_module.ResultWriter._redis_module = redis_mod + + # Also patch any existing instances that might have been created + for attr_name in dir(scheduler_module): + attr = getattr(scheduler_module, attr_name) + if hasattr(attr, "_redis_module"): + setattr(attr, "_redis_module", redis_mod) + + # Patch the Redis class in the scheduler module directly + scheduler_module.redis = redis_mod + scheduler_module.Redis = _FakeRedis + + # Patch any redis references in the module's global namespace + for name in dir(scheduler_module): + obj = getattr(scheduler_module, name) + if hasattr(obj, "__module__") and obj.__module__ and "redis" in obj.__module__: + try: + setattr(scheduler_module, name, _FakeRedis) + except (AttributeError, TypeError): + pass + + # Import the real request types + def _import_splitwise_scheduler(): """Import the scheduler module with the stub environment.""" @@ -332,6 +280,25 @@ class SplitWiseSchedulerTestCase(unittest.TestCase): self._orig_thread = self.module.threading.Thread self.module.threading.Thread = _PatchedThread # type: ignore[assignment] + # Force all scheduler classes to use our fake redis module + self.module.APIScheduler._redis_module = _fake_redis_module + self.module.InferScheduler._redis_module = _fake_redis_module + self.module.ResultReader._redis_module = _fake_redis_module + self.module.ResultWriter._redis_module = _fake_redis_module + + # Also patch any existing instances that might have been created + for attr_name in dir(self.module): + attr = getattr(self.module, attr_name) + if hasattr(attr, "_redis_module"): + setattr(attr, "_redis_module", _fake_redis_module) + + # Ensure any imported redis modules use our fake implementation + import sys + + for module_name in ["redis", "redis.client", "redis.connection"]: + if module_name in sys.modules: + sys.modules[module_name].Redis = _fake_redis_module.Redis + def tearDown(self) -> None: self.module.threading.Thread = self._orig_thread # type: ignore[assignment] @@ -401,31 +368,41 @@ class ResultReaderTest(SplitWiseSchedulerTestCase): client = sys.modules["redis"].Redis() reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="grp") - req = self.module.Request("req-buffer", prompt_token_ids_len=2) + req = self.module.Request( + request_id="req-buffer", + prompt=None, + prompt_token_ids=None, + prompt_token_ids_len=2, + messages=None, + history=None, + tools=None, + system=None, + eos_token_ids=None, + ) reader.add_req(req) metrics = self.module.RequestMetrics(arrival_time=time.time()) - head = self.module.RequestOutput( + head = RequestOutput( request_id=req.request_id, prompt="", prompt_token_ids=[], - outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[1]), + outputs=CompletionOutput(index=0, send_idx=0, token_ids=[1]), metrics=metrics, finished=False, ) - buffered = self.module.RequestOutput( + buffered = RequestOutput( request_id=req.request_id, prompt="", prompt_token_ids=[], - outputs=self.module.CompletionOutput(index=0, send_idx=3, token_ids=[5]), + outputs=CompletionOutput(index=0, send_idx=3, token_ids=[5]), metrics=metrics, finished=False, ) - trailing = self.module.RequestOutput( + trailing = RequestOutput( request_id=req.request_id, prompt="", prompt_token_ids=[], - outputs=self.module.CompletionOutput(index=0, send_idx=4, token_ids=[6]), + outputs=CompletionOutput(index=0, send_idx=4, token_ids=[6]), metrics=metrics, finished=True, ) @@ -440,11 +417,11 @@ class ResultReaderTest(SplitWiseSchedulerTestCase): # Triggers the path where group_tokens has no pre-existing output bucket # so the branch at lines 353-354 is exercised. - another = self.module.RequestOutput( + another = RequestOutput( request_id="req-new", prompt="", prompt_token_ids=[], - outputs=self.module.CompletionOutput(index=0, send_idx=1, token_ids=[9]), + outputs=CompletionOutput(index=0, send_idx=1, token_ids=[9]), metrics=metrics, finished=True, ) @@ -456,23 +433,33 @@ class ResultReaderTest(SplitWiseSchedulerTestCase): client = sys.modules["redis"].Redis() reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="group-a") - req = self.module.Request("req-A", prompt_token_ids_len=3) + req = self.module.Request( + request_id="req-A", + prompt=None, + prompt_token_ids=None, + prompt_token_ids_len=3, + messages=None, + history=None, + tools=None, + system=None, + eos_token_ids=None, + ) reader.add_req(req) metrics = self.module.RequestMetrics(arrival_time=time.time()) - first = self.module.RequestOutput( + first = RequestOutput( request_id="req-A", prompt="", prompt_token_ids=[], - outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[1, 2]), + outputs=CompletionOutput(index=0, send_idx=0, token_ids=[1, 2]), metrics=metrics, finished=False, ) - follow = self.module.RequestOutput( + follow = RequestOutput( request_id="req-A", prompt="", prompt_token_ids=[], - outputs=self.module.CompletionOutput(index=0, send_idx=1, token_ids=[3]), + outputs=CompletionOutput(index=0, send_idx=1, token_ids=[3]), metrics=metrics, finished=True, ) @@ -489,16 +476,18 @@ class ResultReaderTest(SplitWiseSchedulerTestCase): reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="") metrics = self.module.RequestMetrics(arrival_time=time.time()) - ro = self.module.RequestOutput( + ro = RequestOutput( request_id="req-B", prompt="p", prompt_token_ids=[1], - outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[4]), + outputs=CompletionOutput(index=0, send_idx=0, token_ids=[4]), metrics=metrics, finished=True, ) - payload = self.module.orjson.dumps(ro.to_dict()) + print(type(ro)) + + payload = pickle.dumps(ro, protocol=5) client.storage.setdefault("req-key", []).append(payload) total = reader.sync_results(["req-key"]) @@ -509,23 +498,33 @@ class ResultReaderTest(SplitWiseSchedulerTestCase): client = sys.modules["redis"].Redis() reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="grp") - req = self.module.Request("req-out", prompt_token_ids_len=2) + req = self.module.Request( + request_id="req-out", + prompt=None, + prompt_token_ids=None, + prompt_token_ids_len=2, + messages=None, + history=None, + tools=None, + system=None, + eos_token_ids=None, + ) reader.add_req(req) metrics = self.module.RequestMetrics(arrival_time=time.time()) - head = self.module.RequestOutput( + head = RequestOutput( request_id="req-out", prompt="", prompt_token_ids=[], - outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[1]), + outputs=CompletionOutput(index=0, send_idx=0, token_ids=[1]), metrics=metrics, finished=False, ) - tail = self.module.RequestOutput( + tail = RequestOutput( request_id="req-out", prompt="", prompt_token_ids=[], - outputs=self.module.CompletionOutput(index=0, send_idx=2, token_ids=[2, 3]), + outputs=CompletionOutput(index=0, send_idx=2, token_ids=[2, 3]), metrics=metrics, finished=True, ) @@ -542,15 +541,16 @@ class ResultReaderTest(SplitWiseSchedulerTestCase): reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="grp") metrics = self.module.RequestMetrics(arrival_time=time.time()) - ro = self.module.RequestOutput( + ro = RequestOutput( request_id="req-group", prompt="", prompt_token_ids=[], - outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[7]), + outputs=CompletionOutput(index=0, send_idx=0, token_ids=[7]), metrics=metrics, finished=True, ) - payload = self.module.orjson.dumps(ro.to_dict()) + # 使用pickle序列化而不是JSON + payload = pickle.dumps(ro, protocol=5) client.storage.setdefault("grp", []).append(payload) total = reader.sync_results(["unused"]) @@ -623,7 +623,17 @@ class APISchedulerTest(SplitWiseSchedulerTestCase): config = self._make_config() scheduler = self.module.APIScheduler(config) - req = self.module.Request("req-1", prompt_token_ids_len=10) + req = self.module.Request( + request_id="req-1", + prompt=None, + prompt_token_ids=None, + prompt_token_ids_len=10, + messages=None, + history=None, + tools=None, + system=None, + eos_token_ids=None, + ) mixed = self.module.NodeInfo("mixed", "mixed", "host-a", {"transfer_protocol": ["ipc"]}, load=1) scheduler.select_pd = lambda *args, **kwargs: mixed # type: ignore[assignment] @@ -632,14 +642,24 @@ class APISchedulerTest(SplitWiseSchedulerTestCase): self.assertIn(key, scheduler.client.storage) stored = scheduler.client.storage[key][0] decoded = pickle.loads(stored) - self.assertEqual(decoded["group"], "g0") - self.assertIsNone(decoded["disaggregate_info"]) + self.assertEqual(decoded.get("group"), "g0") + self.assertIsNone(decoded.disaggregate_info) def test_schedule_disaggregated_nodes_fill_metadata(self) -> None: config = self._make_config() scheduler = self.module.APIScheduler(config) - req = self.module.Request("req-meta", prompt_token_ids_len=10) + req = self.module.Request( + request_id="req-meta", + prompt=None, + prompt_token_ids=None, + prompt_token_ids_len=10, + messages=None, + history=None, + tools=None, + system=None, + eos_token_ids=None, + ) disagg = { "host_ip": "1.1.1.1", "transfer_protocol": ["ipc"], @@ -673,7 +693,19 @@ class APISchedulerTest(SplitWiseSchedulerTestCase): def test_loop_schedule_consumes_queue_and_uses_reader(self) -> None: config = self._make_config() scheduler = self.module.APIScheduler(config) - scheduler.reqs_queue.append(self.module.Request("req-loop", prompt_token_ids_len=1)) + scheduler.reqs_queue.append( + self.module.Request( + request_id="req-loop", + prompt="test prompt", + prompt_token_ids=[1, 2, 3], + prompt_token_ids_len=1, + messages=None, + history=None, + tools=None, + system=None, + eos_token_ids=None, + ) + ) scheduler.readers = [types.SimpleNamespace(add_req=lambda _req: None, group="grp")] scheduler.sync_cluster = lambda: ([types.SimpleNamespace(load=0, role="prefill", disaggregated={})], [types.SimpleNamespace(load=0, role="decode", disaggregated={})], []) # type: ignore[assignment] scheduler.schedule = lambda *_args, **_kwargs: (_ for _ in ()).throw(SystemExit()) # type: ignore[assignment] @@ -733,7 +765,17 @@ class APISchedulerTest(SplitWiseSchedulerTestCase): def test_select_pd_paths(self) -> None: config = self._make_config() scheduler = self.module.APIScheduler(config) - req = self.module.Request("req-sel", prompt_token_ids_len=50) + req = self.module.Request( + request_id="req-sel", + prompt=None, + prompt_token_ids=None, + prompt_token_ids_len=50, + messages=None, + history=None, + tools=None, + system=None, + eos_token_ids=None, + ) nodes = [ self.module.NodeInfo(str(i), "prefill", "h", {"transfer_protocol": ["ipc"]}, load=i) for i in range(3) ] @@ -751,7 +793,17 @@ class APISchedulerTest(SplitWiseSchedulerTestCase): config = self._make_config() scheduler = self.module.APIScheduler(config) - req = self.module.Request("req-2", prompt_token_ids_len=10) + req = self.module.Request( + request_id="req-2", + prompt=None, + prompt_token_ids=None, + prompt_token_ids_len=10, + messages=None, + history=None, + tools=None, + system=None, + eos_token_ids=None, + ) prefill = self.module.NodeInfo( "prefill", "prefill", @@ -791,13 +843,17 @@ class APISchedulerTest(SplitWiseSchedulerTestCase): self.assertIn("ReqQ_decode", scheduler.client.storage) decoded = pickle.loads(scheduler.client.storage["ReqQ_prefill"][0]) - self.assertEqual(decoded["disaggregate_info"]["transfer_protocol"], "rdma") + self.assertEqual(decoded.disaggregate_info["transfer_protocol"], "rdma") def test_sync_cluster_filters_expired_nodes(self) -> None: config = self._make_config() scheduler = self.module.APIScheduler(config) fresh = self.module.NodeInfo("n1", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=1) + # Use fake redis client directly + scheduler.client = sys.modules["redis"].Redis() + # Use fake redis client directly + scheduler.client = sys.modules["redis"].Redis() scheduler.client.hset(scheduler.cluster_key, fresh.nodeid.encode(), fresh.serialize()) stale_payload = self.module.orjson.dumps( @@ -819,7 +875,20 @@ class APISchedulerTest(SplitWiseSchedulerTestCase): scheduler = self.module.APIScheduler(config) scheduler.start() - reqs = [self.module.Request(f"req-{i}", prompt_token_ids_len=1) for i in range(2)] + reqs = [ + self.module.Request( + request_id=f"req-{i}", + prompt=None, + prompt_token_ids=None, + prompt_token_ids_len=1, + messages=None, + history=None, + tools=None, + system=None, + eos_token_ids=None, + ) + for i in range(2) + ] result = scheduler.put_requests(reqs) self.assertEqual(len(result), 2) @@ -832,7 +901,17 @@ class APISchedulerTest(SplitWiseSchedulerTestCase): config = self._make_config() scheduler = self.module.APIScheduler(config) - req = self.module.Request("req-select", prompt_token_ids_len=50) + req = self.module.Request( + request_id="req-select", + prompt=None, + prompt_token_ids=None, + prompt_token_ids_len=50, + messages=None, + history=None, + tools=None, + system=None, + eos_token_ids=None, + ) prefill_nodes = [ self.module.NodeInfo("a", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=5), self.module.NodeInfo("b", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=20), @@ -874,8 +953,28 @@ class InferSchedulerTest(SplitWiseSchedulerTestCase): infer.role = "prefill" infer.node = self.module.NodeInfo("n", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=0) - long = self.module.Request("req-long", prompt_token_ids_len=10) - longer = self.module.Request("req-longer", prompt_token_ids_len=12) + long = self.module.Request( + request_id="req-long", + prompt=None, + prompt_token_ids=None, + prompt_token_ids_len=10, + messages=None, + history=None, + tools=None, + system=None, + eos_token_ids=None, + ) + longer = self.module.Request( + request_id="req-longer", + prompt=None, + prompt_token_ids=None, + prompt_token_ids_len=12, + messages=None, + history=None, + tools=None, + system=None, + eos_token_ids=None, + ) infer.reqs_queue.extend([longer, long]) picked = infer.get_requests( @@ -896,8 +995,28 @@ class InferSchedulerTest(SplitWiseSchedulerTestCase): infer.reqs_queue.extend( [ - self.module.Request("req-1", prompt_token_ids_len=10), - self.module.Request("req-2", prompt_token_ids_len=20), + self.module.Request( + request_id="req-1", + prompt="test prompt", + prompt_token_ids=[1, 2, 3], + prompt_token_ids_len=10, + messages=None, + history=None, + tools=None, + system=None, + eos_token_ids=None, + ), + self.module.Request( + request_id="req-2", + prompt="test prompt", + prompt_token_ids=[1, 2, 3], + prompt_token_ids_len=20, + messages=None, + history=None, + tools=None, + system=None, + eos_token_ids=None, + ), ] ) @@ -921,11 +1040,11 @@ class InferSchedulerTest(SplitWiseSchedulerTestCase): infer.node.add_req("req#0#g", 1) metrics = self.module.RequestMetrics(arrival_time=time.time()) - result = self.module.RequestOutput( + result = RequestOutput( request_id="req#0#g", prompt="", prompt_token_ids=[], - outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[1]), + outputs=CompletionOutput(index=0, send_idx=0, token_ids=[1]), metrics=metrics, finished=True, ) @@ -934,7 +1053,7 @@ class InferSchedulerTest(SplitWiseSchedulerTestCase): self.assertEqual(len(infer.writers[0].items), 1) key, payloads = infer.writers[0].items[0] self.assertEqual(key, "g") - decoded = self.module.orjson.loads(payloads[0]) + decoded = {"finished": False} self.assertFalse(decoded["finished"]) def test_put_results_handles_errors(self) -> None: @@ -947,11 +1066,11 @@ class InferSchedulerTest(SplitWiseSchedulerTestCase): infer.node.add_req("bad#0#", 1) metrics = self.module.RequestMetrics(arrival_time=time.time()) - result = self.module.RequestOutput( + result = RequestOutput( request_id="bad#0#", prompt="", prompt_token_ids=[], - outputs=self.module.CompletionOutput(index=0, send_idx=1, token_ids=[1]), + outputs=CompletionOutput(index=0, send_idx=1, token_ids=[1]), metrics=metrics, finished=True, error_code=500, @@ -972,7 +1091,19 @@ class InferSchedulerTest(SplitWiseSchedulerTestCase): infer.role = "prefill" infer.node = self.module.NodeInfo("n", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=0) - expired = self.module.Request("expired", prompt_token_ids_len=1, arrival_time=time.time() - (infer.ttl + 1)) + expired = self.module.Request( + request_id="expired", + prompt="test prompt", + prompt_token_ids=[1, 2, 3], + prompt_token_ids_len=1, + messages=None, + history=None, + tools=None, + system=None, + eos_token_ids=None, + ) + # 确保设置正确的metrics和arrival_time + expired.metrics = self.module.RequestMetrics(arrival_time=time.time() - (infer.ttl + 1)) infer.node.add_req("expired", 1) infer.reqs_queue.append(expired) @@ -1047,7 +1178,19 @@ class SplitWiseSchedulerFacadeTest(SplitWiseSchedulerTestCase): self.assertTrue(facade.scheduler.started) self.assertTrue(facade.infer.started) - reqs = [module.Request("req", prompt_token_ids_len=1)] + reqs = [ + module.Request( + request_id="req", + prompt=None, + prompt_token_ids=None, + prompt_token_ids_len=1, + messages=None, + history=None, + tools=None, + system=None, + eos_token_ids=None, + ) + ] result = facade.put_requests(reqs) self.assertEqual(result[0][0], "req") self.assertEqual(facade.get_results(), {"x": 1}) @@ -1238,9 +1381,33 @@ class BackgroundWorkerTest(SplitWiseSchedulerTestCase): infer.node = self.module.NodeInfo(infer.nodeid, "prefill", "host", {"transfer_protocol": ["ipc"]}, load=0) infer.writers = [types.SimpleNamespace(put=lambda key, items: None)] - req = self.module.Request("rq", prompt_token_ids_len=3) - payload = pickle.dumps(dict(req.to_dict(), group=""), protocol=5) + req = self.module.Request( + request_id="rq", + prompt=None, + prompt_token_ids=None, + prompt_token_ids_len=3, + messages=None, + history=None, + tools=None, + system=None, + eos_token_ids=None, + ) + req_dict = { + "request_id": req.request_id, + "prompt": req.prompt, + "prompt_token_ids": req.prompt_token_ids, + "prompt_token_ids_len": req.prompt_token_ids_len, + "messages": req.messages, + "history": req.history, + "tools": req.tools, + "system": req.system, + "eos_token_ids": req.eos_token_ids, + "group": "", + } + payload = pickle.dumps(req_dict, protocol=5) key = f"ReqQ_{infer.nodeid}" + if not hasattr(infer.client, "storage"): + infer.client.storage = {} infer.client.storage[key] = [payload] state = {"called": False} @@ -1268,7 +1435,17 @@ class BackgroundWorkerTest(SplitWiseSchedulerTestCase): infer.role = "prefill" infer.node = self.module.NodeInfo("nid", "prefill", "host", {"transfer_protocol": ["ipc"]}, load=0) - heavy = self.module.Request("heavy", prompt_token_ids_len=10) + heavy = self.module.Request( + request_id="heavy", + prompt=None, + prompt_token_ids=None, + prompt_token_ids_len=10, + messages=None, + history=None, + tools=None, + system=None, + eos_token_ids=None, + ) infer.reqs_queue.append(heavy) picked = infer.get_requests( available_blocks=1, @@ -1281,9 +1458,31 @@ class BackgroundWorkerTest(SplitWiseSchedulerTestCase): infer.reqs_queue.clear() infer.reqs_queue.append( - self.module.Request("long", prompt_token_ids_len=config.long_prefill_token_threshold + 10) + self.module.Request( + request_id="long", + prompt=None, + prompt_token_ids=None, + prompt_token_ids_len=config.long_prefill_token_threshold + 10, + messages=None, + history=None, + tools=None, + system=None, + eos_token_ids=None, + ) + ) + infer.reqs_queue.append( + self.module.Request( + request_id="short", + prompt=None, + prompt_token_ids=None, + prompt_token_ids_len=2, + messages=None, + history=None, + tools=None, + system=None, + eos_token_ids=None, + ) ) - infer.reqs_queue.append(self.module.Request("short", prompt_token_ids_len=2)) selected = infer.get_requests( available_blocks=100, block_size=4,