# Mirror of https://github.com/PaddlePaddle/FastDeploy.git
# synced 2026-04-23 00:17:25 +08:00 (644 lines, 22 KiB, Python)
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
import sys
|
|
import types
|
|
|
|
import numpy as np
|
|
import paddle
|
|
import pytest
|
|
|
|
if not hasattr(paddle, "enable_compat"):
|
|
paddle.enable_compat = lambda *args, **kwargs: None
|
|
|
|
from fastdeploy.cache_manager import cache_messager
|
|
|
|
|
|
class _DummyBarrier:
|
|
def __init__(self):
|
|
self.calls = 0
|
|
|
|
def wait(self):
|
|
self.calls += 1
|
|
|
|
|
|
class _DummyEngineWorkerQueue:
    """In-memory EngineWorkerQueue double driven by scripted sequences.

    ``get_cache_info``/``get_connect_rdma_task`` replay their scripted items
    and raise ``SystemExit`` once exhausted, which terminates the messager's
    worker loops inside the tests. All ``put_*`` calls are recorded so the
    tests can assert on them afterwards.
    """

    def __init__(self, cache_info_sequence=None, connect_task_sequence=None, **kwargs):
        self.cache_info_sequence = list(cache_info_sequence or [])
        self.connect_task_sequence = list(connect_task_sequence or [])
        self.cache_info_calls = 0
        self.connect_task_calls = 0
        # One dummy barrier per synchronization point the messager touches.
        self.cache_info_barrier = _DummyBarrier()
        self.finish_add_cache_task_barrier = _DummyBarrier()
        self.finish_send_cache_barrier = _DummyBarrier()
        self.connect_task_barrier = _DummyBarrier()
        self.connect_task_response_barrier = _DummyBarrier()
        self.begin_send_cache_barrier = _DummyBarrier()
        # Outputs captured for test assertions.
        self.finished_add_cache_task_req_ids = []
        self.finished_req_payloads = []
        self.connect_task_responses = []

    def get_cache_info(self):
        # Exhausting the script terminates the calling loop.
        if self.cache_info_calls >= len(self.cache_info_sequence):
            raise SystemExit
        item = self.cache_info_sequence[self.cache_info_calls]
        self.cache_info_calls += 1
        return item

    def put_finished_add_cache_task_req(self, req_ids):
        self.finished_add_cache_task_req_ids.append(req_ids)

    def put_finished_req(self, payload):
        self.finished_req_payloads.append(payload)

    def get_connect_rdma_task(self):
        if self.connect_task_calls >= len(self.connect_task_sequence):
            raise SystemExit
        task = self.connect_task_sequence[self.connect_task_calls]
        self.connect_task_calls += 1
        return task, None

    def put_connect_rdma_task_response(self, response):
        self.connect_task_responses.append(response)
|
|
|
|
|
|
class _DummyRDMACommManager:
|
|
def __init__(self, *args, **kwargs):
|
|
self.args = args
|
|
self.kwargs = kwargs
|
|
self.connect_calls = []
|
|
|
|
def connect(self, *args):
|
|
self.connect_calls.append(args)
|
|
return True
|
|
|
|
|
|
class _DummyIPCCommManager:
|
|
def __init__(self, *args, **kwargs):
|
|
self.args = args
|
|
self.kwargs = kwargs
|
|
self.write_cache_calls = []
|
|
self.sync_calls = []
|
|
|
|
def write_cache(self, *args):
|
|
self.write_cache_calls.append(args)
|
|
return 0
|
|
|
|
def write_block_by_sync(self, decode_idx):
|
|
self.sync_calls.append(decode_idx)
|
|
|
|
|
|
class _DummyLogger:
|
|
def __init__(self):
|
|
self.messages = []
|
|
|
|
def info(self, msg):
|
|
self.messages.append(("info", msg))
|
|
|
|
def debug(self, msg):
|
|
self.messages.append(("debug", msg))
|
|
|
|
def error(self, msg):
|
|
self.messages.append(("error", msg))
|
|
|
|
|
|
class _DummySignalValue:
|
|
def __init__(self, sequence):
|
|
self.sequence = list(sequence)
|
|
self.set_calls = []
|
|
|
|
def __getitem__(self, index):
|
|
if not self.sequence:
|
|
return 0
|
|
value = self.sequence.pop(0)
|
|
return value
|
|
|
|
def __setitem__(self, index, value):
|
|
self.set_calls.append((index, value))
|
|
|
|
|
|
class _DummyIPCSignal:
    """IPCSignal double that records every created instance on the class."""

    instances = []

    def __init__(self, name, array, **kwargs):
        self.name = name
        # Fall back to the array's numpy dtype when no dtype kwarg was given.
        self.dtype = kwargs.get("dtype", np.array(array).dtype)
        self.value = _DummySignalValue(array)
        _DummyIPCSignal.instances.append(self)
|
|
|
|
|
|
class _DummyTensor:
|
|
def __init__(self, shape, dtype, place):
|
|
self.shape = shape
|
|
self.dtype = dtype
|
|
self.place = place
|
|
|
|
def data_ptr(self):
|
|
return 123
|
|
|
|
|
|
class _DummyPlace:
|
|
def __str__(self):
|
|
return "Place(gpu:0)"
|
|
|
|
|
|
def _build_cache_kvs(dtype="float16", include_value_cache=True, num_layers=1):
    """Build real paddle tensors keyed like the messager's KV-cache dict.

    Each layer gets a [2, 3] key cache and, optionally, a matching value
    cache, under the ``{key,value}_caches_<layer>_rank0_device0`` names.
    """
    caches = {}
    for layer in range(num_layers):
        caches[f"key_caches_{layer}_rank0_device0"] = paddle.zeros([2, 3], dtype=dtype)
        if include_value_cache:
            caches[f"value_caches_{layer}_rank0_device0"] = paddle.zeros([2, 3], dtype=dtype)
    return caches
|
|
|
|
|
|
def _build_dummy_cache_kvs(include_value_cache=True, num_layers=1):
    """Build lightweight tensor doubles keyed like the messager's KV-cache dict."""
    caches = {}
    for layer in range(num_layers):
        names = [f"key_caches_{layer}_rank0_device0"]
        if include_value_cache:
            names.append(f"value_caches_{layer}_rank0_device0")
        for name in names:
            caches[name] = _DummyTensor(shape=[2, 3], dtype=paddle.float16, place=_DummyPlace())
    return caches
|
|
|
|
|
|
def test_parse_args_and_get_decode_ip_idx(monkeypatch):
    """parse_args honours CLI flags; get_decode_ip_idx accepts both key styles."""
    flag_pairs = [
        ("--splitwise_role", "prefill"),
        ("--rank", "1"),
        ("--device_id", "2"),
        ("--num_layers", "3"),
        ("--key_cache_shape", "2,3,4,5"),
        ("--value_cache_shape", "2,3,4,5"),
        ("--rdma_port", "1234"),
        ("--mp_num", "2"),
        ("--ipc_suffix", "suffix"),
        ("--protocol", "rdma"),
        ("--pod_ip", "127.0.0.1"),
        ("--cache_queue_port", "9911"),
        ("--engine_worker_queue_port", "9912"),
        ("--cache_dtype", "uint8"),
        ("--speculative_config", "{}"),
        ("--local_data_parallel_id", "1"),
    ]
    argv = ["prog"]
    for flag, value in flag_pairs:
        argv.extend((flag, value))

    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)
    monkeypatch.setattr(sys, "argv", argv, raising=False)

    parsed = cache_messager.parse_args()
    assert parsed.splitwise_role == "prefill"
    assert parsed.rank == 1
    assert parsed.device_id == 2
    assert parsed.cache_dtype == "uint8"

    # Plain key names.
    ip, ports = cache_messager.get_decode_ip_idx({"ip": "1.1.1.1", "rdma_ports": [3]})
    assert (ip, ports) == ("1.1.1.1", [3])
    # "decode_"-prefixed key names.
    ip, ports = cache_messager.get_decode_ip_idx({"decode_ip": "2.2.2.2", "decode_rdma_ports": [5]})
    assert (ip, ports) == ("2.2.2.2", [5])
|
|
|
|
|
|
def test_cache_messager_init_rdma_block_bytes(monkeypatch):
    """CacheMessager over RDMA computes block_bytes and rank_id at init time."""
    monkeypatch.setattr(cache_messager, "EngineWorkerQueue", _DummyEngineWorkerQueue)
    monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)

    kv_caches = _build_cache_kvs(dtype="float16", include_value_cache=True, num_layers=1)
    messager = cache_messager.CacheMessager(
        splitwise_role="mixed",
        transfer_protocol="rdma",
        pod_ip="0.0.0.0",
        engine_worker_queue_port=9000,
        local_data_parallel_id=1,
        gpu_cache_kvs=kv_caches,
        rank=0,
        nranks=2,
        num_layers=1,
        gpu_id=0,
        rdma_port="1111",
    )

    assert messager.block_bytes == 6
    # rank=0, nranks=2, dp id 1 — presumably rank_id = rank + dp_id * nranks.
    assert messager.rank_id == 2
    assert "rdma" in messager.messager
|
|
|
|
|
|
def test_cache_messager_init_ipc_uses_local_device_id(monkeypatch):
    """IPC-protocol CacheMessager initialises with dummy tensors and block_bytes."""
    monkeypatch.setattr(cache_messager, "EngineWorkerQueue", _DummyEngineWorkerQueue)
    monkeypatch.setattr(cache_messager, "IPCCommManager", _DummyIPCCommManager)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)

    kv_caches = {
        name: _DummyTensor(shape=[2, 3], dtype=paddle.float16, place=_DummyPlace())
        for name in ("key_caches_0_rank0_device0", "value_caches_0_rank0_device0")
    }
    messager = cache_messager.CacheMessager(
        splitwise_role="mixed",
        transfer_protocol="ipc",
        pod_ip="0.0.0.0",
        engine_worker_queue_port=9000,
        local_data_parallel_id=0,
        gpu_cache_kvs=kv_caches,
        rank=0,
        nranks=1,
        num_layers=1,
        gpu_id=0,
        rdma_port=None,
    )

    assert messager.block_bytes == 6
    assert "ipc" in messager.messager
|
|
|
|
|
|
def test_cache_messager_prefill_layerwise_send_cache_thread(monkeypatch):
    """The prefill loop sends one IPC task and reports the finished request id."""

    class _StepSignal:
        # Signals whose name contains "step" advance 0 -> 1 -> 1;
        # every other signal reads as 0.
        def __init__(self, name, array, **kwargs):
            self.name = name
            self.dtype = kwargs.get("dtype", np.array(array).dtype)
            values = [0, 1, 1] if "step" in name else [0, 0, 0]
            self.value = _DummySignalValue(values)

    task = {
        "request_id": "req-1",
        "src_block_ids": [0],
        "dest_block_ids": [1],
        "status": "init",
        "current_id": 0,
        "transfer_protocol": "ipc",
        "device_ids": [0],
    }
    # One scripted task batch, then two empty polls before the loop exits.
    dummy_queue = _DummyEngineWorkerQueue(cache_info_sequence=[[task], None, None])

    monkeypatch.setattr(cache_messager, "EngineWorkerQueue", lambda *args, **kwargs: dummy_queue)
    monkeypatch.setattr(cache_messager, "IPCCommManager", _DummyIPCCommManager)
    monkeypatch.setattr(cache_messager, "shared_memory_exists", lambda name: False)
    monkeypatch.setattr(cache_messager, "IPCSignal", _StepSignal)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)

    messager = cache_messager.CacheMessager(
        splitwise_role="mixed",
        transfer_protocol="ipc",
        pod_ip="0.0.0.0",
        engine_worker_queue_port=9000,
        local_data_parallel_id=0,
        gpu_cache_kvs=_build_dummy_cache_kvs(include_value_cache=True, num_layers=2),
        rank=0,
        nranks=1,
        num_layers=2,
        gpu_id=0,
        rdma_port=None,
    )
    _DummyIPCSignal.instances.clear()

    with pytest.raises(SystemExit):
        messager.prefill_layerwise_send_cache_thread()

    assert dummy_queue.finished_req_payloads
    assert dummy_queue.finished_req_payloads[0][0][0] == "req-1"
|
|
|
|
|
|
def test_cache_messager_v1_add_cache_task_thread(monkeypatch):
    """_add_cache_task_thread marks a known task as init and acks its request id."""
    known_task = {
        "request_id": "req-2",
        "src_block_ids": [0, 1, 2],
        "dest_block_ids": [3],
        "current_id": 7,
        "need_prefill_tokens": 128,
        "transfer_protocol": "rdma",
    }
    # The queue yields one batch: the known task plus an unknown request id.
    # dict(known_task) keeps the queued copy distinct from cache_info's copy.
    dummy_queue = _DummyEngineWorkerQueue(
        cache_info_sequence=[[dict(known_task), {"request_id": "req-new"}]]
    )
    monkeypatch.setattr(cache_messager, "EngineWorkerQueue", lambda *args, **kwargs: dummy_queue)
    monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)

    messager = cache_messager.CacheMessagerV1(
        splitwise_role="mixed",
        transfer_protocol="rdma",
        pod_ip="0.0.0.0",
        engine_worker_queue_port=9000,
        local_data_parallel_id=0,
        gpu_cache_kvs=_build_cache_kvs(dtype="float16", include_value_cache=True, num_layers=1),
        rank=0,
        nranks=1,
        num_layers=1,
        gpu_id=0,
        block_size=64,
        rdma_port="2222",
    )
    messager.cache_info["req-2"] = dict(known_task)

    with pytest.raises(SystemExit):
        messager._add_cache_task_thread()

    assert dummy_queue.finished_add_cache_task_req_ids == [["req-2"]]
    assert messager.cache_info["req-2"]["status"] == "init"
|
|
|
|
|
|
def test_cache_messager_v1_prefill_layerwise_send_cache_thread(monkeypatch):
    """The V1 prefill loop consumes one engine-id batch and finishes the task."""

    class _OneShotQueue:
        # Yields a single batch, then terminates the loop with SystemExit.
        def __init__(self):
            self.called = False

        def get(self):
            if self.called:
                raise SystemExit
            self.called = True
            return [(0, 64)]

    dummy_queue = _DummyEngineWorkerQueue()
    monkeypatch.setattr(cache_messager, "EngineWorkerQueue", lambda *args, **kwargs: dummy_queue)
    monkeypatch.setattr(cache_messager, "IPCCommManager", _DummyIPCCommManager)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)

    messager = cache_messager.CacheMessagerV1(
        splitwise_role="mixed",
        transfer_protocol="ipc",
        pod_ip="0.0.0.0",
        engine_worker_queue_port=9000,
        local_data_parallel_id=0,
        gpu_cache_kvs=_build_dummy_cache_kvs(include_value_cache=True, num_layers=2),
        rank=0,
        nranks=1,
        num_layers=2,
        gpu_id=0,
        block_size=64,
        rdma_port=None,
    )
    messager.cache_prefilled_engine_ids_queue = _OneShotQueue()
    task = {
        "request_id": "req-3",
        "src_block_ids": [0],
        "dest_block_ids": [1],
        "transfer_protocol": "ipc",
        "device_ids": [0],
        "need_prefill_tokens": 64,
        "sended_layer_id": -1,
        "sended_block_num": 0,
        "status": "init",
        "current_id": 0,
    }
    messager.idx_cache_task_dict[0] = task
    messager.engine_cache_tasks[0] = {"prefilled_layer_idx": 1, "prefilled_token_num": 64}
    # Share the same task dict between both lookup tables.
    messager.cache_info["req-3"] = task

    with pytest.raises(SystemExit):
        messager.prefill_layerwise_send_cache_thread()

    assert dummy_queue.finished_req_payloads
    assert dummy_queue.finished_req_payloads[0][0][0] == "req-3"
|
|
|
|
|
|
def test_cache_messager_v1_handle_connect_task(monkeypatch):
    """Every scripted connect task receives a success response."""
    dummy_queue = _DummyEngineWorkerQueue(
        connect_task_sequence=[
            {"task_id": 1, "decode_ip": "1.1.1.1", "decode_rdma_ports": [1234, 5678]},
            {"task_id": 2, "decode_ip": "2.2.2.2", "decode_rdma_ports": [4321]},
        ]
    )
    monkeypatch.setattr(cache_messager, "EngineWorkerQueue", lambda *args, **kwargs: dummy_queue)
    monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)

    messager = cache_messager.CacheMessagerV1(
        splitwise_role="mixed",
        transfer_protocol="rdma",
        pod_ip="0.0.0.0",
        engine_worker_queue_port=9000,
        local_data_parallel_id=0,
        gpu_cache_kvs=_build_cache_kvs(dtype="float16", include_value_cache=False, num_layers=1),
        rank=0,
        nranks=2,
        num_layers=1,
        gpu_id=0,
        block_size=64,
        rdma_port="2222",
    )

    with pytest.raises(SystemExit):
        messager._handle_connect_task()

    responses = dummy_queue.connect_task_responses
    assert responses[0]["success"] is True
    assert responses[1]["success"] is True
|
|
|
|
|
|
def test_cache_messager_init_shm_and_xpu_paths(monkeypatch):
    """Init completes with SHM task queues enabled and the XPU branch active."""
    monkeypatch.setattr(cache_messager.envs, "FD_ENGINE_TASK_QUEUE_WITH_SHM", True)
    monkeypatch.setattr(cache_messager, "EngineWorkerQueue", _DummyEngineWorkerQueue)
    monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager)
    # Peer address resolution is stubbed to a trivially-derived value.
    monkeypatch.setattr(cache_messager, "get_peer_mem_addr", lambda ptr: ptr + 1)
    monkeypatch.setattr(cache_messager.paddle, "is_compiled_with_xpu", lambda: True)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)

    messager = cache_messager.CacheMessager(
        splitwise_role="mixed",
        transfer_protocol="rdma",
        pod_ip="0.0.0.0",
        engine_worker_queue_port=9000,
        local_data_parallel_id=0,
        gpu_cache_kvs=_build_dummy_cache_kvs(include_value_cache=True, num_layers=1),
        rank=0,
        nranks=1,
        num_layers=1,
        gpu_id=0,
        rdma_port="1111",
    )

    assert messager.engine_worker_queue is not None
|
|
|
|
|
|
def test_cache_messager_handle_connect_task_error(monkeypatch):
    """A failing RDMA connect is reported back with success=False."""

    class _FailingRDMACommManager(_DummyRDMACommManager):
        def connect(self, *args):
            # Record the attempt but always report failure.
            self.connect_calls.append(args)
            return False

    # A None entry first, then a real connect task whose connect() will fail.
    dummy_queue = _DummyEngineWorkerQueue(
        connect_task_sequence=[
            None,
            {"task_id": 1, "decode_ip": "1.1.1.1", "decode_rdma_ports": [1234]},
        ]
    )
    monkeypatch.setattr(cache_messager, "EngineWorkerQueue", lambda *args, **kwargs: dummy_queue)
    monkeypatch.setattr(cache_messager, "RDMACommManager", _FailingRDMACommManager)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)

    messager = cache_messager.CacheMessager(
        splitwise_role="mixed",
        transfer_protocol="rdma",
        pod_ip="0.0.0.0",
        engine_worker_queue_port=9000,
        local_data_parallel_id=0,
        gpu_cache_kvs=_build_cache_kvs(dtype="float16", include_value_cache=False, num_layers=1),
        rank=0,
        nranks=1,
        num_layers=1,
        gpu_id=0,
        rdma_port="1111",
    )

    with pytest.raises(SystemExit):
        messager._handle_connect_task()

    assert dummy_queue.connect_task_responses[0]["success"] is False
|
|
|
|
|
|
def test_cache_messager_v1_shm_xpu_and_bfloat16(monkeypatch):
    """CacheMessagerV1 computes block_bytes for a bfloat16 key cache on SHM/XPU."""
    monkeypatch.setattr(cache_messager.envs, "FD_ENGINE_TASK_QUEUE_WITH_SHM", True)
    monkeypatch.setattr(cache_messager, "EngineWorkerQueue", _DummyEngineWorkerQueue)
    monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager)
    monkeypatch.setattr(cache_messager, "get_peer_mem_addr", lambda ptr: ptr + 1)
    monkeypatch.setattr(cache_messager.paddle, "is_compiled_with_xpu", lambda: True)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)

    kv_caches = _build_dummy_cache_kvs(include_value_cache=True, num_layers=1)
    # Flip the key cache to bfloat16 to exercise that dtype branch.
    kv_caches["key_caches_0_rank0_device0"].dtype = paddle.bfloat16

    messager = cache_messager.CacheMessagerV1(
        splitwise_role="mixed",
        transfer_protocol="rdma",
        pod_ip="0.0.0.0",
        engine_worker_queue_port=9000,
        local_data_parallel_id=0,
        gpu_cache_kvs=kv_caches,
        rank=0,
        nranks=1,
        num_layers=1,
        gpu_id=0,
        block_size=64,
        rdma_port="2222",
    )

    assert messager.block_bytes == 6
|
|
|
|
|
|
def test_cache_messager_v1_consume_signals(monkeypatch):
    """consume_signals turns one fake KV signal batch into a queued id/token pair."""
    monkeypatch.setattr(cache_messager, "EngineWorkerQueue", _DummyEngineWorkerQueue)
    monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)

    class _QueueRecorder:
        def __init__(self):
            self.items = []

        def put(self, item):
            self.items.append(item)

    counter = {"calls": 0}

    def _fake_get_output_kv_signal(kv_signal_data, rank_id, wait_flag):
        # First call publishes a single signal batch; the second ends the loop.
        if counter["calls"] > 0:
            raise SystemExit
        counter["calls"] += 1
        data = np.full(kv_signal_data.shape, -1, dtype="int32")
        data[0] = 1
        data[1] = 0
        data[2] = 2
        data[3] = 4
        data[4] = 5
        kv_signal_data.set_value(data)

    monkeypatch.setattr(cache_messager, "get_output_kv_signal", _fake_get_output_kv_signal)

    messager = cache_messager.CacheMessagerV1(
        splitwise_role="mixed",
        transfer_protocol="rdma",
        pod_ip="0.0.0.0",
        engine_worker_queue_port=9000,
        local_data_parallel_id=0,
        gpu_cache_kvs=_build_cache_kvs(dtype="float16", include_value_cache=False, num_layers=1),
        rank=0,
        nranks=1,
        num_layers=1,
        gpu_id=0,
        block_size=64,
        rdma_port="2222",
    )
    messager.cache_info["req-4"] = {"request_id": "req-4"}
    messager.cache_prefilled_engine_ids_queue = _QueueRecorder()

    with pytest.raises(SystemExit):
        messager.consume_signals()

    assert messager.cache_prefilled_engine_ids_queue.items == [[(2, 9)]]
|
|
|
|
|
|
def test_main_initializes_cache_and_exits(monkeypatch):
    """main() wires up CacheMessagerV1 and propagates the worker's SystemExit."""

    class _DummySpeculativeConfig:
        def __init__(self, *args, **kwargs):
            self.num_extra_cache_layer = 0
            self.num_gpu_block_expand_ratio = 0

    # Stub out every collaborator main() touches.
    replacements = {
        "set_device": lambda device: None,
        "set_data_ipc": lambda tensor, name: None,
        "EngineWorkerQueue": _DummyEngineWorkerQueue,
        "IPCSignal": _DummyIPCSignal,
        "RDMACommManager": _DummyRDMACommManager,
        "SpeculativeConfig": _DummySpeculativeConfig,
    }
    for attr, replacement in replacements.items():
        monkeypatch.setattr(cache_messager, attr, replacement)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)
    monkeypatch.setattr(cache_messager.envs, "ENABLE_V1_KVCACHE_SCHEDULER", True)
    monkeypatch.setattr(cache_messager.CacheMessagerV1, "_handle_connect_task", lambda self: None)

    cli_args = types.SimpleNamespace(
        device_id=0,
        rank=0,
        default_dtype="float16",
        cache_dtype="float16",
        key_cache_shape="1,1,1,1",
        value_cache_shape="",
        mp_num=1,
        num_layers=1,
        splitwise_role="decode",
        protocol="rdma",
        pod_ip="0.0.0.0",
        engine_worker_queue_port=9000,
        cache_queue_port=9001,
        ipc_suffix=None,
        rdma_port="2222",
        speculative_config={},
        local_data_parallel_id=0,
    )
    monkeypatch.setattr(cache_messager, "args", cli_args, raising=False)
    # The send thread immediately raises SystemExit so main() terminates.
    monkeypatch.setattr(
        cache_messager.CacheMessagerV1,
        "prefill_layerwise_send_cache_thread",
        lambda self: (_ for _ in ()).throw(SystemExit),
    )

    with pytest.raises(SystemExit):
        cache_messager.main()