mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 17:11:21 +08:00
[CI] 【Hackathon 9th Sprint No.33】NO.33 功能模块单测补充 -new (#5726)
* Add cache messager coverage tests * Add default_dtype parameter to test cache manager --------- Co-authored-by: CSWYF3634076 <wangyafeng@baidu.com>
This commit is contained in:
@@ -0,0 +1,643 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import types
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import pytest
|
||||
|
||||
# Older Paddle builds ship without `paddle.compat`; stub it with a no-op
# `enable_torch_proxy` so that importing fastdeploy below does not fail.
if not hasattr(paddle, "compat"):
    paddle.compat = types.SimpleNamespace(enable_torch_proxy=lambda *args, **kwargs: None)
|
||||
|
||||
from fastdeploy.cache_manager import cache_messager
|
||||
|
||||
|
||||
class _DummyBarrier:
|
||||
def __init__(self):
|
||||
self.calls = 0
|
||||
|
||||
def wait(self):
|
||||
self.calls += 1
|
||||
|
||||
|
||||
class _DummyEngineWorkerQueue:
    """In-memory stand-in for EngineWorkerQueue.

    Hands out pre-scripted cache-info batches and connect tasks one call at
    a time, raising SystemExit once a script is exhausted so the messager's
    infinite loops terminate, and records everything the code under test
    puts back for later assertions.
    """

    def __init__(self, cache_info_sequence=None, connect_task_sequence=None, **kwargs):
        # Scripted responses, consumed front to back.
        self.cache_info_sequence = list(cache_info_sequence or [])
        self.connect_task_sequence = list(connect_task_sequence or [])
        self.cache_info_calls = 0
        self.connect_task_calls = 0
        # Barriers the production code waits on; each only counts waits.
        self.cache_info_barrier = _DummyBarrier()
        self.finish_add_cache_task_barrier = _DummyBarrier()
        self.finish_send_cache_barrier = _DummyBarrier()
        self.connect_task_barrier = _DummyBarrier()
        self.connect_task_response_barrier = _DummyBarrier()
        self.begin_send_cache_barrier = _DummyBarrier()
        # Captured outputs, inspected by the tests.
        self.finished_add_cache_task_req_ids = []
        self.finished_req_payloads = []
        self.connect_task_responses = []

    def get_cache_info(self):
        """Return the next scripted cache-info batch; SystemExit when drained."""
        if self.cache_info_calls >= len(self.cache_info_sequence):
            raise SystemExit
        self.cache_info_calls += 1
        return self.cache_info_sequence[self.cache_info_calls - 1]

    def put_finished_add_cache_task_req(self, req_ids):
        self.finished_add_cache_task_req_ids.append(req_ids)

    def put_finished_req(self, payload):
        self.finished_req_payloads.append(payload)

    def get_connect_rdma_task(self):
        """Return (task, None) from the script; SystemExit when drained."""
        if self.connect_task_calls >= len(self.connect_task_sequence):
            raise SystemExit
        self.connect_task_calls += 1
        return self.connect_task_sequence[self.connect_task_calls - 1], None

    def put_connect_rdma_task_response(self, response):
        self.connect_task_responses.append(response)
|
||||
|
||||
|
||||
class _DummyRDMACommManager:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.connect_calls = []
|
||||
|
||||
def connect(self, *args):
|
||||
self.connect_calls.append(args)
|
||||
return True
|
||||
|
||||
|
||||
class _DummyIPCCommManager:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.write_cache_calls = []
|
||||
self.sync_calls = []
|
||||
|
||||
def write_cache(self, *args):
|
||||
self.write_cache_calls.append(args)
|
||||
return 0
|
||||
|
||||
def write_block_by_sync(self, decode_idx):
|
||||
self.sync_calls.append(decode_idx)
|
||||
|
||||
|
||||
class _DummyLogger:
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
def info(self, msg):
|
||||
self.messages.append(("info", msg))
|
||||
|
||||
def debug(self, msg):
|
||||
self.messages.append(("debug", msg))
|
||||
|
||||
def error(self, msg):
|
||||
self.messages.append(("error", msg))
|
||||
|
||||
|
||||
class _DummySignalValue:
|
||||
def __init__(self, sequence):
|
||||
self.sequence = list(sequence)
|
||||
self.set_calls = []
|
||||
|
||||
def __getitem__(self, index):
|
||||
if not self.sequence:
|
||||
return 0
|
||||
value = self.sequence.pop(0)
|
||||
return value
|
||||
|
||||
def __setitem__(self, index, value):
|
||||
self.set_calls.append((index, value))
|
||||
|
||||
|
||||
class _DummyIPCSignal:
    """IPCSignal double; every instance created is tracked on the class."""

    # Class-level registry so tests can inspect/clear all created signals.
    instances = []

    def __init__(self, name, array, **kwargs):
        self.name = name
        # Fall back to the numpy dtype of `array` when no dtype kwarg given.
        self.dtype = kwargs.get("dtype", np.array(array).dtype)
        self.value = _DummySignalValue(array)
        _DummyIPCSignal.instances.append(self)
|
||||
|
||||
|
||||
class _DummyTensor:
|
||||
def __init__(self, shape, dtype, place):
|
||||
self.shape = shape
|
||||
self.dtype = dtype
|
||||
self.place = place
|
||||
|
||||
def data_ptr(self):
|
||||
return 123
|
||||
|
||||
|
||||
class _DummyPlace:
|
||||
def __str__(self):
|
||||
return "Place(gpu:0)"
|
||||
|
||||
|
||||
def _build_cache_kvs(dtype="float16", include_value_cache=True, num_layers=1):
    """Build real paddle zero tensors keyed the way the messager expects.

    Keys follow the ``{key|value}_caches_<layer>_rank0_device0`` naming
    that CacheMessager looks up; value caches are optional.
    """
    gpu_cache_kvs = {}
    for layer_idx in range(num_layers):
        key_cache = paddle.zeros([2, 3], dtype=dtype)
        gpu_cache_kvs[f"key_caches_{layer_idx}_rank0_device0"] = key_cache
        if include_value_cache:
            gpu_cache_kvs[f"value_caches_{layer_idx}_rank0_device0"] = paddle.zeros([2, 3], dtype=dtype)
    return gpu_cache_kvs
|
||||
|
||||
|
||||
def _build_dummy_cache_kvs(include_value_cache=True, num_layers=1):
    """Like _build_cache_kvs but with _DummyTensor placeholders.

    Used where only shape/dtype/place/data_ptr are consulted, avoiding any
    real device allocation.
    """
    gpu_cache_kvs = {}
    for layer_idx in range(num_layers):
        gpu_cache_kvs[f"key_caches_{layer_idx}_rank0_device0"] = _DummyTensor(
            shape=[2, 3], dtype=paddle.float16, place=_DummyPlace()
        )
        if include_value_cache:
            gpu_cache_kvs[f"value_caches_{layer_idx}_rank0_device0"] = _DummyTensor(
                shape=[2, 3], dtype=paddle.float16, place=_DummyPlace()
            )
    return gpu_cache_kvs
|
||||
|
||||
|
||||
def test_parse_args_and_get_decode_ip_idx(monkeypatch):
    """parse_args should honor every CLI flag, and get_decode_ip_idx should
    accept both the plain and the ``decode_``-prefixed key spellings."""
    args = [
        "prog",
        "--splitwise_role", "prefill",
        "--rank", "1",
        "--device_id", "2",
        "--num_layers", "3",
        "--key_cache_shape", "2,3,4,5",
        "--value_cache_shape", "2,3,4,5",
        "--rdma_port", "1234",
        "--mp_num", "2",
        "--ipc_suffix", "suffix",
        "--protocol", "rdma",
        "--pod_ip", "127.0.0.1",
        "--cache_queue_port", "9911",
        "--engine_worker_queue_port", "9912",
        "--cache_dtype", "uint8",
        "--speculative_config", "{}",
        "--local_data_parallel_id", "1",
    ]
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)
    monkeypatch.setattr(sys, "argv", args, raising=False)
    parsed = cache_messager.parse_args()
    assert parsed.splitwise_role == "prefill"
    assert parsed.rank == 1
    assert parsed.device_id == 2
    assert parsed.cache_dtype == "uint8"

    # Plain key names.
    decode_ip, decode_ports = cache_messager.get_decode_ip_idx({"ip": "1.1.1.1", "rdma_ports": [3]})
    assert decode_ip == "1.1.1.1"
    assert decode_ports == [3]
    # decode_-prefixed key names.
    decode_ip, decode_ports = cache_messager.get_decode_ip_idx({"decode_ip": "2.2.2.2", "decode_rdma_ports": [5]})
    assert decode_ip == "2.2.2.2"
    assert decode_ports == [5]
|
||||
|
||||
|
||||
def test_cache_messager_init_rdma_block_bytes(monkeypatch):
    """CacheMessager over RDMA derives block_bytes from the key cache and a
    global rank id from local_data_parallel_id/nranks/rank."""
    monkeypatch.setattr(cache_messager, "EngineWorkerQueue", _DummyEngineWorkerQueue)
    monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)
    gpu_cache_kvs = _build_cache_kvs(dtype="float16", include_value_cache=True, num_layers=1)
    messager = cache_messager.CacheMessager(
        splitwise_role="mixed",
        transfer_protocol="rdma",
        pod_ip="0.0.0.0",
        engine_worker_queue_port=9000,
        local_data_parallel_id=1,
        gpu_cache_kvs=gpu_cache_kvs,
        rank=0,
        nranks=2,
        num_layers=1,
        gpu_id=0,
        rdma_port="1111",
    )
    # For the [2, 3] float16 cache the implementation reports 6 block bytes.
    assert messager.block_bytes == 6
    # rank_id == 2 — presumably local_data_parallel_id * nranks + rank.
    assert messager.rank_id == 2
    assert "rdma" in messager.messager
|
||||
|
||||
|
||||
def test_cache_messager_init_ipc_uses_local_device_id(monkeypatch):
    """IPC-protocol construction works with dummy tensors (no real device)
    and registers an 'ipc' transport in messager.messager."""
    monkeypatch.setattr(cache_messager, "EngineWorkerQueue", _DummyEngineWorkerQueue)
    monkeypatch.setattr(cache_messager, "IPCCommManager", _DummyIPCCommManager)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)
    dummy_key = _DummyTensor(shape=[2, 3], dtype=paddle.float16, place=_DummyPlace())
    dummy_val = _DummyTensor(shape=[2, 3], dtype=paddle.float16, place=_DummyPlace())
    gpu_cache_kvs = {
        "key_caches_0_rank0_device0": dummy_key,
        "value_caches_0_rank0_device0": dummy_val,
    }
    messager = cache_messager.CacheMessager(
        splitwise_role="mixed",
        transfer_protocol="ipc",
        pod_ip="0.0.0.0",
        engine_worker_queue_port=9000,
        local_data_parallel_id=0,
        gpu_cache_kvs=gpu_cache_kvs,
        rank=0,
        nranks=1,
        num_layers=1,
        gpu_id=0,
        rdma_port=None,
    )
    assert messager.block_bytes == 6
    assert "ipc" in messager.messager
|
||||
|
||||
|
||||
def test_cache_messager_prefill_layerwise_send_cache_thread(monkeypatch):
    """Drive one full send through prefill_layerwise_send_cache_thread: one
    scripted cache-info batch, then SystemExit when the queue is drained; the
    finished request id must be reported back to the engine queue."""

    class _PrefillIPCSignal:
        # IPCSignal double whose 'step' signal advances (0 -> 1 -> 1) so the
        # send loop makes progress, while every other signal stays at 0.
        def __init__(self, name, array, **kwargs):
            if "step" in name:
                sequence = [0, 1, 1]
            else:
                sequence = [0, 0, 0]
            self.name = name
            self.dtype = kwargs.get("dtype", np.array(array).dtype)
            self.value = _DummySignalValue(sequence)

    dummy_queue = _DummyEngineWorkerQueue(
        cache_info_sequence=[
            [
                {
                    "request_id": "req-1",
                    "src_block_ids": [0],
                    "dest_block_ids": [1],
                    "status": "init",
                    "current_id": 0,
                    "transfer_protocol": "ipc",
                    "device_ids": [0],
                }
            ],
            None,
            None,
        ]
    )
    monkeypatch.setattr(cache_messager, "EngineWorkerQueue", lambda *args, **kwargs: dummy_queue)
    monkeypatch.setattr(cache_messager, "IPCCommManager", _DummyIPCCommManager)
    monkeypatch.setattr(cache_messager, "shared_memory_exists", lambda name: False)
    monkeypatch.setattr(cache_messager, "IPCSignal", _PrefillIPCSignal)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)
    gpu_cache_kvs = _build_dummy_cache_kvs(include_value_cache=True, num_layers=2)
    messager = cache_messager.CacheMessager(
        splitwise_role="mixed",
        transfer_protocol="ipc",
        pod_ip="0.0.0.0",
        engine_worker_queue_port=9000,
        local_data_parallel_id=0,
        gpu_cache_kvs=gpu_cache_kvs,
        rank=0,
        nranks=1,
        num_layers=2,
        gpu_id=0,
        rdma_port=None,
    )
    _DummyIPCSignal.instances.clear()
    with pytest.raises(SystemExit):
        messager.prefill_layerwise_send_cache_thread()
    assert dummy_queue.finished_req_payloads
    assert dummy_queue.finished_req_payloads[0][0][0] == "req-1"
|
||||
|
||||
|
||||
def test_cache_messager_v1_add_cache_task_thread(monkeypatch):
    """_add_cache_task_thread should acknowledge known requests back to the
    engine queue and flip their cached status to 'init'."""
    dummy_queue = _DummyEngineWorkerQueue(
        cache_info_sequence=[
            [
                {
                    "request_id": "req-2",
                    "src_block_ids": [0, 1, 2],
                    "dest_block_ids": [3],
                    "current_id": 7,
                    "need_prefill_tokens": 128,
                    "transfer_protocol": "rdma",
                },
                # Unknown request in the same batch exercises the other branch.
                {"request_id": "req-new"},
            ]
        ]
    )
    monkeypatch.setattr(cache_messager, "EngineWorkerQueue", lambda *args, **kwargs: dummy_queue)
    monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)
    gpu_cache_kvs = _build_cache_kvs(dtype="float16", include_value_cache=True, num_layers=1)
    messager = cache_messager.CacheMessagerV1(
        splitwise_role="mixed",
        transfer_protocol="rdma",
        pod_ip="0.0.0.0",
        engine_worker_queue_port=9000,
        local_data_parallel_id=0,
        gpu_cache_kvs=gpu_cache_kvs,
        rank=0,
        nranks=1,
        num_layers=1,
        gpu_id=0,
        block_size=64,
        rdma_port="2222",
    )
    # Pre-seed cache_info so "req-2" is recognized by the thread.
    messager.cache_info["req-2"] = {
        "request_id": "req-2",
        "src_block_ids": [0, 1, 2],
        "dest_block_ids": [3],
        "current_id": 7,
        "need_prefill_tokens": 128,
        "transfer_protocol": "rdma",
    }
    with pytest.raises(SystemExit):
        messager._add_cache_task_thread()
    assert dummy_queue.finished_add_cache_task_req_ids == [["req-2"]]
    assert messager.cache_info["req-2"]["status"] == "init"
|
||||
|
||||
|
||||
def test_cache_messager_v1_prefill_layerwise_send_cache_thread(monkeypatch):
    """CacheMessagerV1 layerwise send: one prefilled engine id is dequeued,
    the task completes, and the finished request is reported back."""

    class _OneShotQueue:
        # Yields a single [(engine_id, token_num)] batch, then SystemExit to
        # break the send loop.
        def __init__(self):
            self.called = False

        def get(self):
            if self.called:
                raise SystemExit
            self.called = True
            return [(0, 64)]

    dummy_queue = _DummyEngineWorkerQueue()
    monkeypatch.setattr(cache_messager, "EngineWorkerQueue", lambda *args, **kwargs: dummy_queue)
    monkeypatch.setattr(cache_messager, "IPCCommManager", _DummyIPCCommManager)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)
    gpu_cache_kvs = _build_dummy_cache_kvs(include_value_cache=True, num_layers=2)
    messager = cache_messager.CacheMessagerV1(
        splitwise_role="mixed",
        transfer_protocol="ipc",
        pod_ip="0.0.0.0",
        engine_worker_queue_port=9000,
        local_data_parallel_id=0,
        gpu_cache_kvs=gpu_cache_kvs,
        rank=0,
        nranks=1,
        num_layers=2,
        gpu_id=0,
        block_size=64,
        rdma_port=None,
    )
    messager.cache_prefilled_engine_ids_queue = _OneShotQueue()
    # Seed a ready-to-send task for engine id 0.
    messager.idx_cache_task_dict[0] = {
        "request_id": "req-3",
        "src_block_ids": [0],
        "dest_block_ids": [1],
        "transfer_protocol": "ipc",
        "device_ids": [0],
        "need_prefill_tokens": 64,
        "sended_layer_id": -1,
        "sended_block_num": 0,
        "status": "init",
        "current_id": 0,
    }
    # All layers/tokens already prefilled, so the send can finish immediately.
    messager.engine_cache_tasks[0] = {"prefilled_layer_idx": 1, "prefilled_token_num": 64}
    messager.cache_info["req-3"] = messager.idx_cache_task_dict[0]
    with pytest.raises(SystemExit):
        messager.prefill_layerwise_send_cache_thread()
    assert dummy_queue.finished_req_payloads
    assert dummy_queue.finished_req_payloads[0][0][0] == "req-3"
|
||||
|
||||
|
||||
def test_cache_messager_v1_handle_connect_task(monkeypatch):
    """_handle_connect_task should answer success for every scripted RDMA
    connect request when the comm manager accepts the connection."""
    dummy_queue = _DummyEngineWorkerQueue(
        connect_task_sequence=[
            {"task_id": 1, "decode_ip": "1.1.1.1", "decode_rdma_ports": [1234, 5678]},
            {"task_id": 2, "decode_ip": "2.2.2.2", "decode_rdma_ports": [4321]},
        ]
    )
    monkeypatch.setattr(cache_messager, "EngineWorkerQueue", lambda *args, **kwargs: dummy_queue)
    monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)
    gpu_cache_kvs = _build_cache_kvs(dtype="float16", include_value_cache=False, num_layers=1)
    messager = cache_messager.CacheMessagerV1(
        splitwise_role="mixed",
        transfer_protocol="rdma",
        pod_ip="0.0.0.0",
        engine_worker_queue_port=9000,
        local_data_parallel_id=0,
        gpu_cache_kvs=gpu_cache_kvs,
        rank=0,
        nranks=2,
        num_layers=1,
        gpu_id=0,
        block_size=64,
        rdma_port="2222",
    )
    with pytest.raises(SystemExit):
        messager._handle_connect_task()
    assert dummy_queue.connect_task_responses[0]["success"] is True
    assert dummy_queue.connect_task_responses[1]["success"] is True
|
||||
|
||||
|
||||
def test_cache_messager_init_shm_and_xpu_paths(monkeypatch):
    """Construction still succeeds when the SHM task-queue env flag is on and
    paddle reports an XPU build (get_peer_mem_addr path)."""
    monkeypatch.setattr(cache_messager.envs, "FD_ENGINE_TASK_QUEUE_WITH_SHM", True)
    monkeypatch.setattr(cache_messager, "EngineWorkerQueue", _DummyEngineWorkerQueue)
    monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager)
    monkeypatch.setattr(cache_messager, "get_peer_mem_addr", lambda ptr: ptr + 1)
    monkeypatch.setattr(cache_messager.paddle, "is_compiled_with_xpu", lambda: True)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)
    gpu_cache_kvs = _build_dummy_cache_kvs(include_value_cache=True, num_layers=1)
    messager = cache_messager.CacheMessager(
        splitwise_role="mixed",
        transfer_protocol="rdma",
        pod_ip="0.0.0.0",
        engine_worker_queue_port=9000,
        local_data_parallel_id=0,
        gpu_cache_kvs=gpu_cache_kvs,
        rank=0,
        nranks=1,
        num_layers=1,
        gpu_id=0,
        rdma_port="1111",
    )
    assert messager.engine_worker_queue is not None
|
||||
|
||||
|
||||
def test_cache_messager_handle_connect_task_error(monkeypatch):
    """A malformed task (None) followed by a failing RDMA connect must yield
    a response with success=False rather than crashing the handler."""
    dummy_queue = _DummyEngineWorkerQueue(
        connect_task_sequence=[
            None,
            {"task_id": 1, "decode_ip": "1.1.1.1", "decode_rdma_ports": [1234]},
        ]
    )

    class _FailingRDMACommManager(_DummyRDMACommManager):
        # Same recording behavior, but every connect attempt is rejected.
        def connect(self, *args):
            self.connect_calls.append(args)
            return False

    monkeypatch.setattr(cache_messager, "EngineWorkerQueue", lambda *args, **kwargs: dummy_queue)
    monkeypatch.setattr(cache_messager, "RDMACommManager", _FailingRDMACommManager)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)
    gpu_cache_kvs = _build_cache_kvs(dtype="float16", include_value_cache=False, num_layers=1)
    messager = cache_messager.CacheMessager(
        splitwise_role="mixed",
        transfer_protocol="rdma",
        pod_ip="0.0.0.0",
        engine_worker_queue_port=9000,
        local_data_parallel_id=0,
        gpu_cache_kvs=gpu_cache_kvs,
        rank=0,
        nranks=1,
        num_layers=1,
        gpu_id=0,
        rdma_port="1111",
    )
    with pytest.raises(SystemExit):
        messager._handle_connect_task()
    assert dummy_queue.connect_task_responses[0]["success"] is False
|
||||
|
||||
|
||||
def test_cache_messager_v1_shm_xpu_and_bfloat16(monkeypatch):
    """CacheMessagerV1 handles the SHM/XPU code path with a bfloat16 key
    cache and still reports the same block_bytes as float16 ([2, 3] -> 6)."""
    monkeypatch.setattr(cache_messager.envs, "FD_ENGINE_TASK_QUEUE_WITH_SHM", True)
    monkeypatch.setattr(cache_messager, "EngineWorkerQueue", _DummyEngineWorkerQueue)
    monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager)
    monkeypatch.setattr(cache_messager, "get_peer_mem_addr", lambda ptr: ptr + 1)
    monkeypatch.setattr(cache_messager.paddle, "is_compiled_with_xpu", lambda: True)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)
    gpu_cache_kvs = _build_dummy_cache_kvs(include_value_cache=True, num_layers=1)
    gpu_cache_kvs["key_caches_0_rank0_device0"].dtype = paddle.bfloat16
    messager = cache_messager.CacheMessagerV1(
        splitwise_role="mixed",
        transfer_protocol="rdma",
        pod_ip="0.0.0.0",
        engine_worker_queue_port=9000,
        local_data_parallel_id=0,
        gpu_cache_kvs=gpu_cache_kvs,
        rank=0,
        nranks=1,
        num_layers=1,
        gpu_id=0,
        block_size=64,
        rdma_port="2222",
    )
    assert messager.block_bytes == 6
|
||||
|
||||
|
||||
def test_cache_messager_v1_consume_signals(monkeypatch):
    """consume_signals should translate one fake kv-signal tensor into a
    batch on cache_prefilled_engine_ids_queue, then stop on SystemExit."""
    monkeypatch.setattr(cache_messager, "EngineWorkerQueue", _DummyEngineWorkerQueue)
    monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)

    class _QueueRecorder:
        # Captures everything put on the prefilled-engine-ids queue.
        def __init__(self):
            self.items = []

        def put(self, item):
            self.items.append(item)

    counter = {"calls": 0}

    def _fake_get_output_kv_signal(kv_signal_data, rank_id, wait_flag):
        # First call fills the signal tensor; second call ends the loop.
        if counter["calls"] > 0:
            raise SystemExit
        counter["calls"] += 1
        data = np.full(kv_signal_data.shape, -1, dtype="int32")
        data[0] = 1
        data[1] = 0
        data[2] = 2
        data[3] = 4
        data[4] = 5
        kv_signal_data.set_value(data)

    monkeypatch.setattr(cache_messager, "get_output_kv_signal", _fake_get_output_kv_signal)
    gpu_cache_kvs = _build_cache_kvs(dtype="float16", include_value_cache=False, num_layers=1)
    messager = cache_messager.CacheMessagerV1(
        splitwise_role="mixed",
        transfer_protocol="rdma",
        pod_ip="0.0.0.0",
        engine_worker_queue_port=9000,
        local_data_parallel_id=0,
        gpu_cache_kvs=gpu_cache_kvs,
        rank=0,
        nranks=1,
        num_layers=1,
        gpu_id=0,
        block_size=64,
        rdma_port="2222",
    )
    messager.cache_info["req-4"] = {"request_id": "req-4"}
    messager.cache_prefilled_engine_ids_queue = _QueueRecorder()
    with pytest.raises(SystemExit):
        messager.consume_signals()
    assert messager.cache_prefilled_engine_ids_queue.items == [[(2, 9)]]
|
||||
|
||||
|
||||
def test_main_initializes_cache_and_exits(monkeypatch):
    """main() with the V1 scheduler enabled should build the cache, start the
    messager, and reach prefill_layerwise_send_cache_thread (which we make
    raise SystemExit so the test terminates)."""
    monkeypatch.setattr(cache_messager, "set_device", lambda device: None)
    monkeypatch.setattr(cache_messager, "set_data_ipc", lambda tensor, name: None)
    monkeypatch.setattr(cache_messager, "EngineWorkerQueue", _DummyEngineWorkerQueue)
    monkeypatch.setattr(cache_messager, "IPCSignal", _DummyIPCSignal)
    monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager)
    monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False)
    monkeypatch.setattr(cache_messager.envs, "ENABLE_V1_KVCACHE_SCHEDULER", True)

    class _DummySpeculativeConfig:
        # No extra layers / no block expansion keeps allocation minimal.
        def __init__(self, *args, **kwargs):
            self.num_extra_cache_layer = 0
            self.num_gpu_block_expand_ratio = 0

    monkeypatch.setattr(cache_messager, "SpeculativeConfig", _DummySpeculativeConfig)
    monkeypatch.setattr(cache_messager.CacheMessagerV1, "_handle_connect_task", lambda self: None)

    # Simulated CLI arguments for main().
    args = types.SimpleNamespace(
        device_id=0,
        rank=0,
        default_dtype="float16",
        cache_dtype="float16",
        key_cache_shape="1,1,1,1",
        value_cache_shape="",
        mp_num=1,
        num_layers=1,
        splitwise_role="decode",
        protocol="rdma",
        pod_ip="0.0.0.0",
        engine_worker_queue_port=9000,
        cache_queue_port=9001,
        ipc_suffix=None,
        rdma_port="2222",
        speculative_config={},
        local_data_parallel_id=0,
    )
    monkeypatch.setattr(cache_messager, "args", args, raising=False)
    # Abort main() once it reaches the send thread.
    monkeypatch.setattr(
        cache_messager.CacheMessagerV1,
        "prefill_layerwise_send_cache_thread",
        lambda self: (_ for _ in ()).throw(SystemExit),
    )

    with pytest.raises(SystemExit):
        cache_messager.main()
|
||||
Reference in New Issue
Block a user