Files
FastDeploy/tests/inter_communicator/test_zmq_server.py
T

642 lines
25 KiB
Python

"""
Tests for ZmqServerBase and derived helpers covering response/recv paths.
"""
import os
import tempfile
import threading
import time
import unittest
from collections import defaultdict
from unittest import mock
import msgpack
import paddle
import zmq
from zmq.utils import jsonapi
# Some paddle builds do not provide ``enable_compat``; install a keyword-only
# no-op so later code that calls it does not fail with AttributeError.
# NOTE(review): presumably needed because fastdeploy calls it on import — confirm.
if not hasattr(paddle, "enable_compat"):
    setattr(paddle, "enable_compat", lambda **kwargs: None)
from fastdeploy import envs
from fastdeploy.inter_communicator.zmq_server import (
ZmqIpcServer,
ZmqServerBase,
ZmqTcpServer,
_msgpack_default,
)
class _DummyResponse:
    """Picklable fake response holding a value, a ``finished`` flag and a tensor.

    Only ``value`` and ``finished`` are pickled; the tensor is rebuilt from
    ``value`` when the object is unpickled.
    """

    def __init__(self, value, finished=False):
        self.value = value
        self.finished = finished
        self.tensor = paddle.to_tensor([value])

    def add(self, other):
        """Merge *other* into this response in place (values and tensors are summed)."""
        self.value += other.value
        self.tensor = self.tensor + other.tensor
        # A merged response is finished as soon as either side is.
        if not self.finished:
            self.finished = other.finished

    def to_dict(self):
        """Return a plain-Python summary (int value, bool flag, int tensor sum)."""
        return dict(
            value=int(self.value),
            finished=bool(self.finished),
            tensor_sum=int(self.tensor.sum()),
        )

    def __getstate__(self):
        # Drop the tensor: it is derived from ``value`` and recreated on load.
        return dict(value=self.value, finished=self.finished)

    def __setstate__(self, state):
        value = state["value"]
        self.value = value
        self.finished = state["finished"]
        self.tensor = paddle.to_tensor([value])
class _FakeSocket:
def __init__(self):
self.closed = False
self.sent = []
self.recv_payload = None
self.recv_multipart_payload = None
self.options = {}
def setsockopt(self, option, value):
self.options[option] = value
def bind(self, address):
self.address = address
def send(self, msg, flags=0, **kwargs):
self.sent.append(("send", msg, flags, kwargs))
return msg
def send_multipart(self, parts, copy=True):
self.sent.append(("send_multipart", parts, copy))
def recv(self, flags=0):
if isinstance(self.recv_payload, Exception):
raise self.recv_payload
return self.recv_payload
def recv_multipart(self, flags=0):
if isinstance(self.recv_multipart_payload, Exception):
raise self.recv_multipart_payload
return self.recv_multipart_payload
def _deserialize(self, msg, loader):
return loader(msg)
def close(self):
self.closed = True
class _FakeContext:
    """Minimal ZMQ context replacement that always hands out one shared fake socket."""

    def __init__(self):
        self.socket_instance = _FakeSocket()
        self.closed = False

    def socket(self, mode):
        """Return the shared fake socket after tagging it with *mode*."""
        shared = self.socket_instance
        shared.mode = mode
        return shared

    def term(self):
        """Mark the context as terminated."""
        self.closed = True
class _DummyServer(ZmqServerBase):
    """Minimal concrete ``ZmqServerBase`` used by the tests.

    The socket is injected through the constructor instead of being opened,
    and ``close`` only records that it was called so tests can assert on it.
    """

    def __init__(self, socket=None):
        """Initialise the base class, then install test-controlled state.

        Args:
            socket: Object stored as ``self.socket`` (typically a ``_FakeSocket``
                or None).
        """
        super().__init__()
        # These assignments run after the base initialiser, so they override
        # any defaults it may set for the same attribute names.
        self.socket = socket
        self.running = False
        self.mutex = threading.Lock()
        self.req_dict = {}
        self.aggregate_send = False
        self.address = "test-address"
        self.response_token_lock = threading.Lock()

    def _create_socket(self):
        # Return the current ``self.socket`` (initially the injected one, possibly
        # None) instead of opening a real ZMQ socket.
        return self.socket

    def close(self):
        # Only flag that close() was called; the socket itself is not touched.
        self.closed = True
class TestMsgpackDefault(unittest.TestCase):
    """Unit tests for ``_msgpack_default``, the ``default=`` hook handed to msgpack."""

    def test_tensor_with_tolist(self):
        """Objects exposing ``tolist`` (paddle.Tensor, numpy.ndarray) become nested lists."""
        import numpy as np

        cases = [
            (paddle.to_tensor([1, 2, 3]), [1, 2, 3]),
            (np.array([[1, 2], [3, 4]]), [[1, 2], [3, 4]]),
        ]
        for source, expected in cases:
            with self.subTest(kind=type(source).__name__):
                self.assertEqual(_msgpack_default(source), expected)

    def test_namedtuple(self):
        """A NamedTuple is serialised as a plain list of its fields."""
        from collections import namedtuple

        Point = namedtuple("Point", ["x", "y"])
        self.assertEqual(_msgpack_default(Point(10, 20)), [10, 20])

    def test_dataclass(self):
        """A dataclass instance is serialised as a dict of its fields."""
        from dataclasses import dataclass

        @dataclass
        class SampleData:
            name: str
            value: int

        converted = _msgpack_default(SampleData(name="test", value=42))
        self.assertEqual(converted, {"name": "test", "value": 42})

    def test_unsupported_type_raises(self):
        """Unsupported objects raise a TypeError whose message names the type."""

        class UnsupportedClass:
            pass

        with self.assertRaises(TypeError) as ctx:
            _msgpack_default(UnsupportedClass())
        message = str(ctx.exception)
        for fragment in ("UnsupportedClass", "not msgpack serializable"):
            self.assertIn(fragment, message)

    def test_integration_with_msgpack(self):
        """With the hook, ``msgpack.packb`` round-trips tensors and dataclasses in one payload."""
        from dataclasses import dataclass

        @dataclass
        class Config:
            lr: float
            steps: int

        packed = msgpack.packb(
            {
                "tensor": paddle.to_tensor([1.0, 2.0, 3.0]),
                "config": Config(lr=0.001, steps=100),
                "regular_list": [1, 2, 3],
            },
            default=_msgpack_default,
        )
        unpacked = msgpack.unpackb(packed)
        expected = {
            "tensor": [1.0, 2.0, 3.0],
            "config": {"lr": 0.001, "steps": 100},
            "regular_list": [1, 2, 3],
        }
        for key, value in expected.items():
            self.assertEqual(unpacked[key], value)
class TestZmqServerBase(unittest.TestCase):
    """Tests for the ``ZmqServerBase`` send/recv helpers and response dispatch.

    Every test drives a ``_DummyServer`` wired to an in-memory ``_FakeSocket``;
    no real ZMQ endpoint is opened.
    """

    def test_send_and_recv_json_roundtrip(self):
        """``send_json`` wraps data in a ``__meta`` envelope and ``recv_json`` unwraps it."""
        fake_socket = _FakeSocket()
        server = _DummyServer(socket=fake_socket)
        payload = {"hello": "world"}
        server.send_json(payload)
        sent_msg = fake_socket.sent[-1][1]
        envelope = jsonapi.loads(sent_msg)
        self.assertIn("__meta", envelope)
        self.assertEqual(envelope["data"], payload)

        recv_envelope = {"__meta": {"send_ts": time.perf_counter() - 0.01}, "data": {"ok": True}}
        fake_socket.recv_payload = jsonapi.dumps(recv_envelope)
        self.assertEqual(server.recv_json(), {"ok": True})

    def test_ensure_socket_creates_socket(self):
        """``_ensure_socket`` fills in ``socket`` via ``_create_socket`` when it is unset."""
        server = _DummyServer(socket=None)
        server._create_socket = lambda: _FakeSocket()
        server._ensure_socket()
        self.assertIsNotNone(server.socket)

    def test_recv_json_returns_raw_payload(self):
        """A JSON payload without a ``__meta`` envelope is returned unchanged."""
        fake_socket = _FakeSocket()
        server = _DummyServer(socket=fake_socket)
        fake_socket.recv_payload = jsonapi.dumps(["plain", "list"])
        self.assertEqual(server.recv_json(), ["plain", "list"])

    def test_send_json_raises_on_socket_error(self):
        """Errors raised by the socket's ``send`` propagate out of ``send_json``."""

        class _ErrorSocket(_FakeSocket):
            # Accept **kwargs like _FakeSocket.send (and the other _ErrorSocket
            # overrides) so the simulated failure is always a RuntimeError, never
            # a TypeError caused by a keyword argument send_json may pass.
            def send(self, msg, flags=0, **kwargs):
                raise RuntimeError("send failed")

        server = _DummyServer(socket=_ErrorSocket())
        with self.assertRaises(RuntimeError):
            server.send_json({"boom": True})

    def test_recv_pyobj_meta_envelope(self):
        """When the unpickled object is a ``__meta`` envelope, only its ``data`` is returned."""
        fake_socket = _FakeSocket()
        server = _DummyServer(socket=fake_socket)
        data = {"token": 1}
        envelope = {"__meta": {"send_ts": time.perf_counter() - 0.05}, "data": data}
        fake_socket.recv_payload = b"payload"
        with mock.patch("fastdeploy.inter_communicator.zmq_server.ForkingPickler.loads", return_value=envelope):
            self.assertEqual(server.recv_pyobj(), data)

    def test_send_pyobj_and_recv_pyobj_fallback(self):
        """``send_pyobj`` uses a single-frame send; a non-envelope object is returned as-is."""
        fake_socket = _FakeSocket()
        server = _DummyServer(socket=fake_socket)
        server.send_pyobj({"hello": "world"})
        self.assertEqual(fake_socket.sent[-1][0], "send")

        envelope = {"payload": "raw"}
        fake_socket.recv_payload = b"payload"
        with mock.patch("fastdeploy.inter_communicator.zmq_server.ForkingPickler.loads", return_value=envelope):
            self.assertEqual(server.recv_pyobj(), envelope)

    def test_send_pyobj_raises_on_socket_error(self):
        """Errors raised by the socket's ``send`` propagate out of ``send_pyobj``."""

        class _ErrorSocket(_FakeSocket):
            def send(self, msg, flags=0, **kwargs):
                raise RuntimeError("send failed")

        server = _DummyServer(socket=_ErrorSocket())
        with self.assertRaises(RuntimeError):
            server.send_pyobj({"boom": True})

    def test_receive_json_once_paths(self):
        """``receive_json_once`` covers closed-socket, ``zmq.Again`` and generic-error paths."""
        # Closed socket -> fixed error string, no data.
        fake_socket = _FakeSocket()
        fake_socket.closed = True
        server = _DummyServer(socket=fake_socket)
        error, data = server.receive_json_once()
        self.assertEqual(error, "zmp socket has closed")
        self.assertIsNone(data)

        # Nothing pending (zmq.Again) -> neither error nor data.
        server = _DummyServer(socket=_FakeSocket())
        server.recv_json = mock.Mock(side_effect=zmq.Again())
        error, data = server.receive_json_once()
        self.assertIsNone(error)
        self.assertIsNone(data)

        # Any other exception -> its message is returned and the server is closed.
        server = _DummyServer(socket=_FakeSocket())
        server.recv_json = mock.Mock(side_effect=ValueError("boom"))
        error, data = server.receive_json_once()
        self.assertEqual(error, "boom")
        self.assertIsNone(data)
        self.assertTrue(server.closed)

    def test_receive_pyobj_once_paths(self):
        """``receive_pyobj_once`` covers closed-socket, ``zmq.Again`` and generic-error paths."""
        # Closed socket -> fixed error string, no data.
        fake_socket = _FakeSocket()
        fake_socket.closed = True
        server = _DummyServer(socket=fake_socket)
        error, data = server.receive_pyobj_once()
        self.assertEqual(error, "zmp socket has closed")
        self.assertIsNone(data)

        # Nothing pending (zmq.Again) -> neither error nor data.
        server = _DummyServer(socket=_FakeSocket())
        server.recv_pyobj = mock.Mock(side_effect=zmq.Again())
        error, data = server.receive_pyobj_once()
        self.assertIsNone(error)
        self.assertIsNone(data)

        # Any other exception -> its message is returned and the server is closed.
        server = _DummyServer(socket=_FakeSocket())
        server.recv_pyobj = mock.Mock(side_effect=ValueError("boom"))
        error, data = server.receive_pyobj_once()
        self.assertEqual(error, "boom")
        self.assertIsNone(data)
        self.assertTrue(server.closed)

    def test_send_response_per_step_caches_and_sends(self):
        """Per-step responses are cached until a client handle exists, then sent multipart."""
        fake_socket = _FakeSocket()
        server = _DummyServer(socket=fake_socket)
        server.response_handle_per_step = None
        server.cached_results = {"data": []}
        server.batch_id_per_step = 0

        # No client handle yet -> the response is cached.
        server._send_response_per_step(0, [[_DummyResponse(1)]])
        self.assertEqual(len(server.cached_results["data"]), 1)

        # With a client handle -> sent and the batch counter advances.
        server.response_handle_per_step = b"client"
        server._send_response_per_step(0, [[_DummyResponse(2)]])
        self.assertEqual(server.batch_id_per_step, 1)
        self.assertEqual(fake_socket.sent[-1][0], "send_multipart")

    def test_send_response_per_step_raises_without_socket(self):
        """``_send_response_per_step`` raises RuntimeError when no socket is available."""
        server = _DummyServer(socket=None)
        with self.assertRaises(RuntimeError):
            server._send_response_per_step(0, [[_DummyResponse(1)]])

    def test_send_response_per_step_handles_send_error(self):
        """A failing multipart send is swallowed and does not advance the batch counter."""

        class _ErrorSocket(_FakeSocket):
            def send_multipart(self, parts, copy=True):
                raise RuntimeError("send failed")

        server = _DummyServer(socket=_ErrorSocket())
        server.response_handle_per_step = b"client"
        server.cached_results = {"data": []}
        # Pin the starting value explicitly (as the sibling test does) instead of
        # relying on the base-class default, so "unchanged" is what is asserted.
        server.batch_id_per_step = 0
        server._send_response_per_step(0, [[_DummyResponse(1)]])
        self.assertEqual(server.batch_id_per_step, 0)

    def test_send_response_per_query_cache_and_flush(self):
        """Per-query responses are cached until the request is known, then flushed on finish."""
        fake_socket = _FakeSocket()
        server = _DummyServer(socket=fake_socket)
        server.cached_results = defaultdict(list)
        server.req_dict = {}
        server.aggregate_send = False
        req_id = "req-1"

        # Unknown request -> cached.
        server._send_response_per_query(req_id, [_DummyResponse(3)])
        self.assertIn(req_id, server.cached_results)

        # Known request with a finished response -> sent and forgotten.
        server.req_dict[req_id] = b"client"
        server._send_response_per_query(req_id, [_DummyResponse(4, finished=True)])
        self.assertNotIn(req_id, server.req_dict)
        self.assertEqual(fake_socket.sent[-1][0], "send_multipart")

    def test_send_response_per_query_aggregate(self):
        """With ``aggregate_send`` enabled, a finished response is still sent multipart."""
        fake_socket = _FakeSocket()
        server = _DummyServer(socket=fake_socket)
        server.req_dict["req-agg"] = b"client"
        server.aggregate_send = True
        server._send_response_per_query("req-agg", [_DummyResponse(5, finished=True)])
        self.assertEqual(fake_socket.sent[-1][0], "send_multipart")

    def test_send_response_per_query_send_failure(self):
        """A failing multipart send is swallowed and the request is still removed."""

        class _ErrorSocket(_FakeSocket):
            def send_multipart(self, parts, copy=True):
                raise RuntimeError("send failed")

        server = _DummyServer(socket=_ErrorSocket())
        server.req_dict["req-error"] = b"client"
        server.aggregate_send = False
        server._send_response_per_query("req-error", [_DummyResponse(7, finished=True)])
        self.assertEqual(server.req_dict, {})

    def test_send_response_per_query_raises_without_socket(self):
        """``_send_response_per_query`` raises RuntimeError when no socket is available."""
        server = _DummyServer(socket=None)
        with self.assertRaises(RuntimeError):
            server._send_response_per_query("req-missing", [_DummyResponse(1)])

    def test_send_response_dispatches_by_env(self):
        """``send_response`` picks the per-step, batch or per-query path from env flags."""
        server = _DummyServer(socket=_FakeSocket())
        server._send_response_per_step = mock.Mock()
        server._send_response_per_query = mock.Mock()
        server._send_batch_response = mock.Mock()

        # Branch 1: FD_ENABLE_INTERNAL_ADAPTER=True -> _send_response_per_step
        with mock.patch.object(envs, "FD_ENABLE_INTERNAL_ADAPTER", True):
            server.send_response("req", [_DummyResponse(1)])
        server._send_response_per_step.assert_called_once()

        # Branch 2: FD_ENABLE_INTERNAL_ADAPTER=False, ZMQ_SEND_BATCH_DATA=True -> _send_batch_response
        with mock.patch.object(envs, "FD_ENABLE_INTERNAL_ADAPTER", False):
            with mock.patch.object(envs, "ZMQ_SEND_BATCH_DATA", True):
                batch_data = [[_DummyResponse(1)]]
                server.send_response(None, batch_data)
        server._send_batch_response.assert_called_once_with(batch_data, worker_pid=None)

        # Branch 3: FD_ENABLE_INTERNAL_ADAPTER=False, ZMQ_SEND_BATCH_DATA=False -> _send_response_per_query
        with mock.patch.object(envs, "FD_ENABLE_INTERNAL_ADAPTER", False):
            with mock.patch.object(envs, "ZMQ_SEND_BATCH_DATA", False):
                server.send_response("req", [_DummyResponse(1)])
        server._send_response_per_query.assert_called_once()

    def test_send_response_with_none_req_id(self):
        """Test send_response with req_id=None (batch format)"""
        server = _DummyServer(socket=_FakeSocket())
        server._send_batch_response = mock.Mock()
        with mock.patch.object(envs, "FD_ENABLE_INTERNAL_ADAPTER", False):
            with mock.patch.object(envs, "ZMQ_SEND_BATCH_DATA", True):
                batch_data = [[_DummyResponse(1)], [_DummyResponse(2)]]
                server.send_response(None, batch_data)
        server._send_batch_response.assert_called_once_with(batch_data, worker_pid=None)

    def test_send_batch_response_success(self):
        """Test _send_batch_response sends data successfully"""
        fake_socket = _FakeSocket()
        server = _DummyServer(socket=fake_socket)
        server.address = "test-address"
        batch_data = [[_DummyResponse(1, finished=True)]]
        server._send_batch_response(batch_data)
        self.assertEqual(len(fake_socket.sent), 1)
        self.assertEqual(fake_socket.sent[0][0], "send")

    def test_send_batch_response_raises_without_socket(self):
        """Without a socket, ``_send_batch_response`` returns instead of raising.

        Despite the method name, this pins the non-raising behaviour.
        """
        server = _DummyServer(socket=None)
        server._create_socket = lambda: None
        batch_data = [[_DummyResponse(1)]]
        # Production code logs error and returns (does not raise)
        server._send_batch_response(batch_data)

    def test_send_batch_response_handles_send_error(self):
        """Test _send_batch_response handles socket send errors"""

        class _ErrorSocket(_FakeSocket):
            def send(self, msg, flags=0, **kwargs):
                raise RuntimeError("send failed")

        server = _DummyServer(socket=_ErrorSocket())
        server.address = "test-address"
        batch_data = [[_DummyResponse(1)]]
        # Should not raise, error is caught and logged
        server._send_batch_response(batch_data)

    def test_recv_result_handle_paths(self):
        """``recv_result_handle`` records the per-step client or flushes cached per-query results."""
        fake_socket = _FakeSocket()
        server = _DummyServer(socket=fake_socket)
        server.running = True
        server.cached_results = defaultdict(list)
        server.req_dict = {}
        client_id = b"client"
        req_id = b"req-1"

        def _recv_once(*args, **kwargs):
            # Stop the receive loop after exactly one message.
            server.running = False
            return client_id, b"", req_id

        fake_socket.recv_multipart = _recv_once

        # Internal-adapter mode -> the client handle is remembered for per-step sends.
        with mock.patch.object(envs, "FD_ENABLE_INTERNAL_ADAPTER", True):
            server.recv_result_handle()
        self.assertEqual(server.response_handle_per_step, client_id)

        # Per-query mode -> cached results for the request are flushed via send_response.
        server.running = True
        server.response_handle_per_step = None
        server.cached_results = defaultdict(list)
        server.cached_results["req-1"].append([_DummyResponse(1, finished=True)])
        fake_socket.recv_multipart = _recv_once
        server.send_response = mock.Mock()
        with mock.patch.object(envs, "FD_ENABLE_INTERNAL_ADAPTER", False):
            server.recv_result_handle()
        server.send_response.assert_called_once_with("req-1", [])

    def test_exit_calls_close(self):
        """Leaving the context manager (``__exit__``) calls ``close``."""
        server = _DummyServer(socket=_FakeSocket())
        server.close = mock.Mock()
        server.__exit__(None, None, None)
        server.close.assert_called_once()
class TestZmqServers(unittest.TestCase):
    """Tests for the concrete ``ZmqIpcServer`` and ``ZmqTcpServer`` classes.

    ``zmq.Context`` is patched to return a ``_FakeContext`` so no real sockets
    or IPC endpoints are created.
    """

    @staticmethod
    def _remove_quietly(path):
        """Delete *path* if it still exists; registered via ``addCleanup``."""
        try:
            os.remove(path)
        except FileNotFoundError:
            pass

    def test_zmq_ipc_server_file_name_and_clear_ipc_error(self):
        """The IPC file name contains ``router_test.ipc``; ``_clear_ipc`` tolerates ``OSError``."""
        fake_context = _FakeContext()
        with mock.patch("fastdeploy.inter_communicator.zmq_server.zmq.Context", return_value=fake_context):
            server = ZmqIpcServer("test", zmq.ROUTER)
        self.assertIn("router_test.ipc", server.file_name)
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            server.file_name = tmp.name
        # os.remove is forced to fail below, so the temp file would otherwise be
        # left behind. Cleanups run after the patch is undone, so the real
        # os.remove is used.
        self.addCleanup(self._remove_quietly, tmp.name)
        with mock.patch("fastdeploy.inter_communicator.zmq_server.os.remove", side_effect=OSError("fail")):
            server._clear_ipc(server.file_name)

    def test_zmq_ipc_server_close_cleans_ipc(self):
        """``close`` deletes the IPC file, terminates the context and can be called twice."""
        fake_context = _FakeContext()
        with mock.patch("fastdeploy.inter_communicator.zmq_server.zmq.Context", return_value=fake_context):
            server = ZmqIpcServer("test", zmq.PULL)
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            server.file_name = tmp.name
        # Safety net: if close() fails to delete the file, don't leak it.
        self.addCleanup(self._remove_quietly, tmp.name)
        server.close()
        self.assertFalse(os.path.exists(server.file_name))
        self.assertTrue(fake_context.closed)
        # A second close must not raise.
        server.close()
        self.assertFalse(server.running)

    def test_zmq_ipc_server_close_exception(self):
        """An error raised by the socket's ``close`` does not escape ``ZmqIpcServer.close``."""
        fake_context = _FakeContext()
        with mock.patch("fastdeploy.inter_communicator.zmq_server.zmq.Context", return_value=fake_context):
            server = ZmqIpcServer("test", zmq.PULL)

        class _BadSocket(_FakeSocket):
            def close(self):
                raise RuntimeError("close failed")

        server.socket = _BadSocket()
        server.context = _FakeContext()
        server.close()
        self.assertFalse(server.running)

    def test_zmq_tcp_server_control_cmd_flow(self):
        """A received control command is tracked in ``req_dict`` until it is answered."""
        fake_context = _FakeContext()
        task = {"task_id": "task-1", "payload": "ok"}
        fake_context.socket_instance.recv_multipart_payload = [
            b"client",
            b"",
            msgpack.packb(task),
        ]
        with mock.patch("fastdeploy.inter_communicator.zmq_server.zmq.Context", return_value=fake_context):
            server = ZmqTcpServer(12345, zmq.ROUTER)
        received = server.recv_control_cmd()
        self.assertEqual(received["task_id"], "task-1")
        self.assertIn("task-1", server.req_dict)

        server.response_for_control_cmd("task-1", {"status": "done"})
        self.assertNotIn("task-1", server.req_dict)
        self.assertEqual(fake_context.socket_instance.sent[-1][0], "send_multipart")

    def test_zmq_tcp_server_control_cmd_empty(self):
        """``recv_control_cmd`` returns None when no message is pending (``zmq.Again``)."""
        fake_context = _FakeContext()
        fake_context.socket_instance.recv_multipart_payload = zmq.Again()
        with mock.patch("fastdeploy.inter_communicator.zmq_server.zmq.Context", return_value=fake_context):
            server = ZmqTcpServer(12345, zmq.ROUTER)
        self.assertIsNone(server.recv_control_cmd())

    def test_zmq_tcp_server_response_errors_and_close(self):
        """``response_for_control_cmd`` raises without a socket and drops the task on send failure."""
        fake_context = _FakeContext()
        with mock.patch("fastdeploy.inter_communicator.zmq_server.zmq.Context", return_value=fake_context):
            server = ZmqTcpServer(12345, zmq.ROUTER)

        # No socket and none can be created -> RuntimeError.
        server.socket = None
        server._create_socket = lambda: None
        with self.assertRaises(RuntimeError):
            server.response_for_control_cmd("task", {"status": "fail"})

        class _ErrorSocket(_FakeSocket):
            def send_multipart(self, parts, copy=True):
                raise RuntimeError("send failed")

        # Send failure -> swallowed, and the pending task is removed.
        server.socket = _ErrorSocket()
        server.req_dict["task"] = b"client"
        server.response_for_control_cmd("task", {"status": "fail"})
        self.assertEqual(server.req_dict, {})

        server.running = False
        server.close()

    def test_zmq_ipc_server_unsupported_mode_raises(self):
        """``ZmqIpcServer`` rejects an unsupported socket type (PUB) with ValueError."""
        fake_context = _FakeContext()
        with mock.patch("fastdeploy.inter_communicator.zmq_server.zmq.Context", return_value=fake_context):
            with self.assertRaises(ValueError):
                ZmqIpcServer("test", zmq.PUB)

    def test_zmq_ipc_server_get_worker_push_socket_creates_and_caches(self):
        """``_get_worker_push_socket`` connects one socket per worker pid and caches it."""
        # Track all sockets created by context.socket()
        created_sockets = []

        class _TrackingContext(_FakeContext):
            def socket(self, mode):
                sock = _FakeSocket()
                sock.mode = mode
                sock.connected_to = None
                sock.connect = lambda addr: setattr(sock, "connected_to", addr)
                created_sockets.append(sock)
                return sock

        tracking_ctx = _TrackingContext()
        with mock.patch("fastdeploy.inter_communicator.zmq_server.zmq.Context", return_value=tracking_ctx):
            server = ZmqIpcServer("myservice", zmq.PUSH)

        # First call: a new socket is created and connected to an address that
        # includes the worker pid.
        sock1 = server._get_worker_push_socket(1234)
        self.assertIsNotNone(sock1)
        self.assertIn(1234, server.worker_push_sockets)
        self.assertIsNotNone(sock1.connected_to)
        self.assertIn("1234", sock1.connected_to)

        # Second call with same worker_pid: should return cached socket
        sock2 = server._get_worker_push_socket(1234)
        self.assertIs(sock1, sock2)

        # Different worker_pid: should create a new socket
        sock3 = server._get_worker_push_socket(5678)
        self.assertIsNot(sock1, sock3)
        self.assertIn(5678, server.worker_push_sockets)

    def test_send_batch_response_with_worker_pid_none_uses_default_socket(self):
        """With ``worker_pid=None``, ``_send_batch_response`` sends on the server's own socket."""
        fake_socket = _FakeSocket()
        server = _DummyServer(socket=fake_socket)
        server.address = "test-address"
        batch_data = [[_DummyResponse(1, finished=True)]]
        server._send_batch_response(batch_data, worker_pid=None)
        # The default socket should have been used to send the data
        self.assertEqual(len(fake_socket.sent), 1)
        self.assertEqual(fake_socket.sent[0][0], "send")

    def test_zmq_ipc_server_close_with_worker_push_sockets(self):
        """``close`` closes and clears per-worker push sockets, tolerating close() errors."""
        fake_context = _FakeContext()
        with mock.patch("fastdeploy.inter_communicator.zmq_server.zmq.Context", return_value=fake_context):
            server = ZmqIpcServer("test", zmq.PULL)

        # Add a well-behaved push socket
        good_sock = _FakeSocket()
        server.worker_push_sockets[100] = good_sock

        # Add a push socket whose close() raises
        class _BadPushSocket(_FakeSocket):
            def close(self):
                raise RuntimeError("push close failed")

        bad_sock = _BadPushSocket()
        server.worker_push_sockets[200] = bad_sock

        # close() should not raise even if a push socket close() fails
        server.close()
        self.assertFalse(server.running)
        # worker_push_sockets should be cleared
        self.assertEqual(len(server.worker_push_sockets), 0)
        # The good socket should have been closed
        self.assertTrue(good_sock.closed)
if __name__ == "__main__":
    # Allow running this test module directly; otherwise pytest/unittest
    # discovery picks up the TestCase classes above.
    unittest.main()