mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2026-04-23 00:17:25 +08:00
642 lines
25 KiB
Python
642 lines
25 KiB
Python
"""
Tests for ZmqServerBase, ZmqIpcServer, ZmqTcpServer and the _msgpack_default
serializer, covering the send, response and receive paths.
"""
|
|
|
|
import os
|
|
import tempfile
|
|
import threading
|
|
import time
|
|
import unittest
|
|
from collections import defaultdict
|
|
from unittest import mock
|
|
|
|
import msgpack
|
|
import paddle
|
|
import zmq
|
|
from zmq.utils import jsonapi
|
|
|
|
# Some paddle builds lack ``paddle.enable_compat``; install a no-op stand-in
# before the fastdeploy imports below.
# NOTE(review): presumably fastdeploy calls enable_compat at import time -- confirm.
if not hasattr(paddle, "enable_compat"):
    setattr(paddle, "enable_compat", lambda **kwargs: None)
|
|
|
|
from fastdeploy import envs
|
|
from fastdeploy.inter_communicator.zmq_server import (
|
|
ZmqIpcServer,
|
|
ZmqServerBase,
|
|
ZmqTcpServer,
|
|
_msgpack_default,
|
|
)
|
|
|
|
|
|
class _DummyResponse:
    """Lightweight response double fed into the server's send paths.

    Holds a ``value``, a ``finished`` flag and a one-element paddle tensor
    built from ``value``. Only ``value`` and ``finished`` are pickled; the
    tensor is rebuilt when the object is unpickled.
    """

    def __init__(self, value, finished=False):
        self.value = value
        self.finished = finished
        self.tensor = paddle.to_tensor([value])

    def add(self, other):
        """Fold ``other`` into this response in place.

        Values and tensors are summed; the merged response is finished as
        soon as either side is finished.
        """
        self.value += other.value
        self.tensor = self.tensor + other.tensor
        if not self.finished:
            self.finished = other.finished

    def to_dict(self):
        """Return a plain-Python snapshot of the response."""
        return dict(
            value=int(self.value),
            finished=bool(self.finished),
            tensor_sum=int(self.tensor.sum()),
        )

    def __getstate__(self):
        # The tensor is deliberately left out of the pickled state.
        return dict(value=self.value, finished=self.finished)

    def __setstate__(self, state):
        self.value = state["value"]
        self.finished = state["finished"]
        # Recreate the tensor that __getstate__ did not pickle.
        self.tensor = paddle.to_tensor([self.value])
|
|
|
|
|
|
class _FakeSocket:
    """In-memory replacement for a ``zmq.Socket``.

    Outgoing calls are appended to ``sent``; ``recv``/``recv_multipart``
    hand back the preset payload attributes, raising them instead when they
    hold an exception instance.
    """

    def __init__(self):
        self.closed = False
        # Log of outgoing calls, oldest first:
        #   ("send", msg, flags, kwargs) or ("send_multipart", parts, copy)
        self.sent = []
        # Canned inbound data; an Exception instance is raised, not returned.
        self.recv_payload = None
        self.recv_multipart_payload = None
        # Options captured by setsockopt(), keyed by option.
        self.options = {}

    def setsockopt(self, option, value):
        """Store ``value`` under ``option``."""
        self.options[option] = value

    def bind(self, address):
        """Remember the bind address; no endpoint is opened."""
        self.address = address

    def send(self, msg, flags=0, **kwargs):
        """Record a single-frame send and return ``msg`` unchanged."""
        entry = ("send", msg, flags, kwargs)
        self.sent.append(entry)
        return msg

    def send_multipart(self, parts, copy=True):
        """Record a multi-frame send."""
        entry = ("send_multipart", parts, copy)
        self.sent.append(entry)

    def recv(self, flags=0):
        """Return ``recv_payload``, or raise it if it is an exception."""
        payload = self.recv_payload
        if not isinstance(payload, Exception):
            return payload
        raise payload

    def recv_multipart(self, flags=0):
        """Return ``recv_multipart_payload``, or raise it if it is an exception."""
        frames = self.recv_multipart_payload
        if not isinstance(frames, Exception):
            return frames
        raise frames

    def _deserialize(self, msg, loader):
        # Mirrors pyzmq's private ``Socket._deserialize(recvd, load)`` hook;
        # presumably invoked by the server's recv helpers -- confirm.
        return loader(msg)

    def close(self):
        """Mark the socket as closed."""
        self.closed = True
|
|
|
|
|
|
class _FakeContext:
    """Stand-in for ``zmq.Context`` that always hands out one shared _FakeSocket."""

    def __init__(self):
        self.closed = False
        # Every socket() call returns this same object.
        self.socket_instance = _FakeSocket()

    def socket(self, mode):
        """Tag the shared fake socket with ``mode`` and return it."""
        shared = self.socket_instance
        shared.mode = mode
        return shared

    def term(self):
        """Mark the context as terminated."""
        self.closed = True
|
|
|
|
|
|
class _DummyServer(ZmqServerBase):
    """Minimal ZmqServerBase subclass whose socket is injected by the test."""

    def __init__(self, socket=None):
        super().__init__()
        # Reset, after the base initializer, every attribute the tests rely on
        # so each test starts from the same baseline.
        self.socket = socket
        self.running = False
        self.mutex = threading.Lock()
        self.req_dict = {}
        self.aggregate_send = False
        self.address = "test-address"
        self.response_token_lock = threading.Lock()

    def _create_socket(self):
        """Hand back the injected socket (possibly None) instead of opening one."""
        return self.socket

    def close(self):
        """Only flag the server as closed; the injected socket is not touched."""
        self.closed = True
|
|
|
|
|
|
class TestMsgpackDefault(unittest.TestCase):
    """Unit tests for the ``_msgpack_default`` fallback serializer."""

    def test_tensor_with_tolist(self):
        """Anything exposing ``tolist`` (paddle tensors, numpy arrays) becomes a list."""
        import numpy as np

        cases = [
            (paddle.to_tensor([1, 2, 3]), [1, 2, 3]),
            (np.array([[1, 2], [3, 4]]), [[1, 2], [3, 4]]),
        ]
        for obj, expected in cases:
            self.assertEqual(_msgpack_default(obj), expected)

    def test_namedtuple(self):
        """NamedTuple instances are flattened to a list of their fields."""
        from collections import namedtuple

        Point = namedtuple("Point", ["x", "y"])
        self.assertEqual(_msgpack_default(Point(10, 20)), [10, 20])

    def test_dataclass(self):
        """Dataclass instances are converted to a field-name -> value dict."""
        from dataclasses import dataclass

        @dataclass
        class SampleData:
            name: str
            value: int

        self.assertEqual(
            _msgpack_default(SampleData(name="test", value=42)),
            {"name": "test", "value": 42},
        )

    def test_unsupported_type_raises(self):
        """Unknown types raise a TypeError naming the offending class."""

        class UnsupportedClass:
            pass

        with self.assertRaises(TypeError) as ctx:
            _msgpack_default(UnsupportedClass())
        message = str(ctx.exception)
        for fragment in ("UnsupportedClass", "not msgpack serializable"):
            self.assertIn(fragment, message)

    def test_integration_with_msgpack(self):
        """A mixed payload round-trips through msgpack using the fallback hook."""
        from dataclasses import dataclass

        @dataclass
        class Config:
            lr: float
            steps: int

        original = {
            "tensor": paddle.to_tensor([1.0, 2.0, 3.0]),
            "config": Config(lr=0.001, steps=100),
            "regular_list": [1, 2, 3],
        }
        roundtrip = msgpack.unpackb(msgpack.packb(original, default=_msgpack_default))

        expected = {
            "tensor": [1.0, 2.0, 3.0],
            "config": {"lr": 0.001, "steps": 100},
            "regular_list": [1, 2, 3],
        }
        for key, value in expected.items():
            self.assertEqual(roundtrip[key], value)
|
|
|
|
|
|
class TestZmqServerBase(unittest.TestCase):
    """Tests for the send/recv/response helpers provided by ZmqServerBase.

    Each test drives a _DummyServer wired to an in-memory _FakeSocket.
    """

    @staticmethod
    def _failing_socket(method_name):
        """Return a _FakeSocket whose ``method_name`` always raises RuntimeError.

        The stub accepts any positional/keyword arguments, so a test only ever
        observes the intended RuntimeError -- never a TypeError caused by the
        server passing an argument (e.g. ``copy=``) the stub did not declare.
        """

        def _raise(*args, **kwargs):
            raise RuntimeError("send failed")

        sock = _FakeSocket()
        setattr(sock, method_name, _raise)
        return sock

    def test_send_and_recv_json_roundtrip(self):
        fake_socket = _FakeSocket()
        server = _DummyServer(socket=fake_socket)
        payload = {"hello": "world"}
        server.send_json(payload)
        # Outgoing JSON is wrapped in an envelope with "__meta" and "data".
        envelope = jsonapi.loads(fake_socket.sent[-1][1])
        self.assertIn("__meta", envelope)
        self.assertEqual(envelope["data"], payload)

        # An incoming envelope is unwrapped back to its "data" field.
        recv_envelope = {"__meta": {"send_ts": time.perf_counter() - 0.01}, "data": {"ok": True}}
        fake_socket.recv_payload = jsonapi.dumps(recv_envelope)
        self.assertEqual(server.recv_json(), {"ok": True})

    def test_ensure_socket_creates_socket(self):
        server = _DummyServer(socket=None)
        server._create_socket = lambda: _FakeSocket()
        server._ensure_socket()
        self.assertIsNotNone(server.socket)

    def test_recv_json_returns_raw_payload(self):
        # A payload without the meta envelope is returned unchanged.
        fake_socket = _FakeSocket()
        server = _DummyServer(socket=fake_socket)
        fake_socket.recv_payload = jsonapi.dumps(["plain", "list"])
        self.assertEqual(server.recv_json(), ["plain", "list"])

    def test_send_json_raises_on_socket_error(self):
        server = _DummyServer(socket=self._failing_socket("send"))
        with self.assertRaises(RuntimeError):
            server.send_json({"boom": True})

    def test_recv_pyobj_meta_envelope(self):
        fake_socket = _FakeSocket()
        server = _DummyServer(socket=fake_socket)
        data = {"token": 1}
        envelope = {"__meta": {"send_ts": time.perf_counter() - 0.05}, "data": data}
        fake_socket.recv_payload = b"payload"
        with mock.patch("fastdeploy.inter_communicator.zmq_server.ForkingPickler.loads", return_value=envelope):
            self.assertEqual(server.recv_pyobj(), data)

    def test_send_pyobj_and_recv_pyobj_fallback(self):
        fake_socket = _FakeSocket()
        server = _DummyServer(socket=fake_socket)
        server.send_pyobj({"hello": "world"})
        self.assertEqual(fake_socket.sent[-1][0], "send")

        # An object without the meta envelope is returned as-is.
        envelope = {"payload": "raw"}
        fake_socket.recv_payload = b"payload"
        with mock.patch("fastdeploy.inter_communicator.zmq_server.ForkingPickler.loads", return_value=envelope):
            self.assertEqual(server.recv_pyobj(), envelope)

    def test_send_pyobj_raises_on_socket_error(self):
        server = _DummyServer(socket=self._failing_socket("send"))
        with self.assertRaises(RuntimeError):
            server.send_pyobj({"boom": True})

    def test_receive_json_once_paths(self):
        # Closed socket -> error message and no data.
        fake_socket = _FakeSocket()
        fake_socket.closed = True
        server = _DummyServer(socket=fake_socket)
        error, data = server.receive_json_once()
        self.assertEqual(error, "zmp socket has closed")
        self.assertIsNone(data)

        # zmq.Again -> neither error nor data.
        server = _DummyServer(socket=_FakeSocket())
        server.recv_json = mock.Mock(side_effect=zmq.Again())
        error, data = server.receive_json_once()
        self.assertIsNone(error)
        self.assertIsNone(data)

        # Any other exception -> its message is returned and the server is closed.
        server = _DummyServer(socket=_FakeSocket())
        server.recv_json = mock.Mock(side_effect=ValueError("boom"))
        error, data = server.receive_json_once()
        self.assertEqual(error, "boom")
        self.assertIsNone(data)
        self.assertTrue(server.closed)

    def test_receive_pyobj_once_paths(self):
        # Closed socket -> error message and no data.
        fake_socket = _FakeSocket()
        fake_socket.closed = True
        server = _DummyServer(socket=fake_socket)
        error, data = server.receive_pyobj_once()
        self.assertEqual(error, "zmp socket has closed")
        self.assertIsNone(data)

        # zmq.Again -> neither error nor data.
        server = _DummyServer(socket=_FakeSocket())
        server.recv_pyobj = mock.Mock(side_effect=zmq.Again())
        error, data = server.receive_pyobj_once()
        self.assertIsNone(error)
        self.assertIsNone(data)

        # Any other exception -> its message is returned and the server is closed.
        server = _DummyServer(socket=_FakeSocket())
        server.recv_pyobj = mock.Mock(side_effect=ValueError("boom"))
        error, data = server.receive_pyobj_once()
        self.assertEqual(error, "boom")
        self.assertIsNone(data)
        self.assertTrue(server.closed)

    def test_send_response_per_step_caches_and_sends(self):
        fake_socket = _FakeSocket()
        server = _DummyServer(socket=fake_socket)
        # With no client handle yet, step results are only cached.
        server.response_handle_per_step = None
        server.cached_results = {"data": []}
        server.batch_id_per_step = 0
        server._send_response_per_step(0, [[_DummyResponse(1)]])
        self.assertEqual(len(server.cached_results["data"]), 1)

        # Once a client handle is known, the batch is sent and the counter advances.
        server.response_handle_per_step = b"client"
        server._send_response_per_step(0, [[_DummyResponse(2)]])
        self.assertEqual(server.batch_id_per_step, 1)
        self.assertEqual(fake_socket.sent[-1][0], "send_multipart")

    def test_send_response_per_step_raises_without_socket(self):
        server = _DummyServer(socket=None)
        with self.assertRaises(RuntimeError):
            server._send_response_per_step(0, [[_DummyResponse(1)]])

    def test_send_response_per_step_handles_send_error(self):
        server = _DummyServer(socket=self._failing_socket("send_multipart"))
        server.response_handle_per_step = b"client"
        server.cached_results = {"data": []}
        server._send_response_per_step(0, [[_DummyResponse(1)]])
        # A failed send must not advance the batch counter.
        self.assertEqual(server.batch_id_per_step, 0)

    def test_send_response_per_query_cache_and_flush(self):
        fake_socket = _FakeSocket()
        server = _DummyServer(socket=fake_socket)
        server.cached_results = defaultdict(list)
        server.req_dict = {}
        server.aggregate_send = False
        req_id = "req-1"
        # Unknown client -> results are cached under the request id.
        server._send_response_per_query(req_id, [_DummyResponse(3)])
        self.assertIn(req_id, server.cached_results)

        # Known client + finished response -> sent and the request is released.
        server.req_dict[req_id] = b"client"
        server._send_response_per_query(req_id, [_DummyResponse(4, finished=True)])
        self.assertNotIn(req_id, server.req_dict)
        self.assertEqual(fake_socket.sent[-1][0], "send_multipart")

    def test_send_response_per_query_aggregate(self):
        fake_socket = _FakeSocket()
        server = _DummyServer(socket=fake_socket)
        server.req_dict["req-agg"] = b"client"
        server.aggregate_send = True
        server._send_response_per_query("req-agg", [_DummyResponse(5, finished=True)])
        self.assertEqual(fake_socket.sent[-1][0], "send_multipart")

    def test_send_response_per_query_send_failure(self):
        server = _DummyServer(socket=self._failing_socket("send_multipart"))
        server.req_dict["req-error"] = b"client"
        server.aggregate_send = False
        server._send_response_per_query("req-error", [_DummyResponse(7, finished=True)])
        # The failed request is still dropped from the pending map.
        self.assertEqual(server.req_dict, {})

    def test_send_response_per_query_raises_without_socket(self):
        server = _DummyServer(socket=None)
        with self.assertRaises(RuntimeError):
            server._send_response_per_query("req-missing", [_DummyResponse(1)])

    def test_send_response_dispatches_by_env(self):
        server = _DummyServer(socket=_FakeSocket())
        server._send_response_per_step = mock.Mock()
        server._send_response_per_query = mock.Mock()
        server._send_batch_response = mock.Mock()
        # Branch 1: FD_ENABLE_INTERNAL_ADAPTER=True -> _send_response_per_step
        with mock.patch.object(envs, "FD_ENABLE_INTERNAL_ADAPTER", True):
            server.send_response("req", [_DummyResponse(1)])
            server._send_response_per_step.assert_called_once()
        # Branch 2: FD_ENABLE_INTERNAL_ADAPTER=False, ZMQ_SEND_BATCH_DATA=True -> _send_batch_response
        with mock.patch.object(envs, "FD_ENABLE_INTERNAL_ADAPTER", False):
            with mock.patch.object(envs, "ZMQ_SEND_BATCH_DATA", True):
                batch_data = [[_DummyResponse(1)]]
                server.send_response(None, batch_data)
                server._send_batch_response.assert_called_once_with(batch_data, worker_pid=None)
        # Branch 3: FD_ENABLE_INTERNAL_ADAPTER=False, ZMQ_SEND_BATCH_DATA=False -> _send_response_per_query
        with mock.patch.object(envs, "FD_ENABLE_INTERNAL_ADAPTER", False):
            with mock.patch.object(envs, "ZMQ_SEND_BATCH_DATA", False):
                server.send_response("req", [_DummyResponse(1)])
                server._send_response_per_query.assert_called_once()

    def test_send_response_with_none_req_id(self):
        """send_response(None, batch) forwards the whole batch to _send_batch_response."""
        server = _DummyServer(socket=_FakeSocket())
        server._send_batch_response = mock.Mock()
        with mock.patch.object(envs, "FD_ENABLE_INTERNAL_ADAPTER", False):
            with mock.patch.object(envs, "ZMQ_SEND_BATCH_DATA", True):
                batch_data = [[_DummyResponse(1)], [_DummyResponse(2)]]
                server.send_response(None, batch_data)
                server._send_batch_response.assert_called_once_with(batch_data, worker_pid=None)

    def test_send_batch_response_success(self):
        """_send_batch_response sends the batch as a single frame."""
        fake_socket = _FakeSocket()
        server = _DummyServer(socket=fake_socket)
        server.address = "test-address"
        batch_data = [[_DummyResponse(1, finished=True)]]
        server._send_batch_response(batch_data)
        self.assertEqual(len(fake_socket.sent), 1)
        self.assertEqual(fake_socket.sent[0][0], "send")

    def test_send_batch_response_raises_without_socket(self):
        """With no socket obtainable, _send_batch_response must NOT raise.

        Despite the test name, the expected behavior is that no exception
        escapes; per the original note, the production code logs an error
        and returns.
        """
        server = _DummyServer(socket=None)
        server._create_socket = lambda: None
        batch_data = [[_DummyResponse(1)]]
        server._send_batch_response(batch_data)

    def test_send_batch_response_handles_send_error(self):
        """A socket send failure inside _send_batch_response does not propagate."""
        server = _DummyServer(socket=self._failing_socket("send"))
        server.address = "test-address"
        batch_data = [[_DummyResponse(1)]]
        server._send_batch_response(batch_data)

    def test_recv_result_handle_paths(self):
        fake_socket = _FakeSocket()
        server = _DummyServer(socket=fake_socket)
        server.running = True
        server.cached_results = defaultdict(list)
        server.req_dict = {}
        client_id = b"client"
        req_id = b"req-1"

        def _recv_once(*args, **kwargs):
            # Stop the receive loop after exactly one message.
            server.running = False
            return client_id, b"", req_id

        # Internal-adapter mode: the client id becomes the per-step handle.
        fake_socket.recv_multipart = _recv_once
        with mock.patch.object(envs, "FD_ENABLE_INTERNAL_ADAPTER", True):
            server.recv_result_handle()
        self.assertEqual(server.response_handle_per_step, client_id)

        # Per-query mode: cached results for the request trigger send_response.
        server.running = True
        server.response_handle_per_step = None
        server.cached_results = defaultdict(list)
        server.cached_results["req-1"].append([_DummyResponse(1, finished=True)])
        fake_socket.recv_multipart = _recv_once
        server.send_response = mock.Mock()
        with mock.patch.object(envs, "FD_ENABLE_INTERNAL_ADAPTER", False):
            server.recv_result_handle()
        server.send_response.assert_called_once_with("req-1", [])

    def test_exit_calls_close(self):
        server = _DummyServer(socket=_FakeSocket())
        server.close = mock.Mock()
        server.__exit__(None, None, None)
        server.close.assert_called_once()
|
|
|
|
|
|
class TestZmqServers(unittest.TestCase):
    """Tests for the concrete ZmqIpcServer and ZmqTcpServer classes.

    ``zmq.Context`` is patched to return a fake context, so the sockets the
    servers obtain from it are in-memory fakes.
    """

    def _remove_if_exists(self, path):
        """Cleanup hook: delete ``path`` if a test left it on disk."""
        if os.path.exists(path):
            os.remove(path)

    def test_zmq_ipc_server_file_name_and_clear_ipc_error(self):
        """The IPC file name reflects mode and name; _clear_ipc tolerates OSError."""
        fake_context = _FakeContext()
        with mock.patch("fastdeploy.inter_communicator.zmq_server.zmq.Context", return_value=fake_context):
            server = ZmqIpcServer("test", zmq.ROUTER)
            self.assertIn("router_test.ipc", server.file_name)

            with tempfile.NamedTemporaryFile(delete=False) as tmp:
                server.file_name = tmp.name
            # os.remove is forced to fail below, so the temp file would leak.
            # Cleanups run after the test returns (patch undone), so the real
            # os.remove is used here.
            self.addCleanup(self._remove_if_exists, tmp.name)

            with mock.patch("fastdeploy.inter_communicator.zmq_server.os.remove", side_effect=OSError("fail")):
                server._clear_ipc(server.file_name)

    def test_zmq_ipc_server_close_cleans_ipc(self):
        """close() removes the IPC file, terminates the context and is repeatable."""
        fake_context = _FakeContext()
        with mock.patch("fastdeploy.inter_communicator.zmq_server.zmq.Context", return_value=fake_context):
            server = ZmqIpcServer("test", zmq.PULL)
            with tempfile.NamedTemporaryFile(delete=False) as tmp:
                server.file_name = tmp.name
            # Safety net so a failing assertion below does not leak the file.
            self.addCleanup(self._remove_if_exists, tmp.name)

            server.close()
            self.assertFalse(os.path.exists(server.file_name))
            self.assertTrue(fake_context.closed)

            # A second close() must be harmless.
            server.close()
            self.assertFalse(server.running)

    def test_zmq_ipc_server_close_exception(self):
        """close() swallows errors raised by socket.close() and still stops the server."""
        fake_context = _FakeContext()
        with mock.patch("fastdeploy.inter_communicator.zmq_server.zmq.Context", return_value=fake_context):
            server = ZmqIpcServer("test", zmq.PULL)

            class _BadSocket(_FakeSocket):
                def close(self):
                    raise RuntimeError("close failed")

            server.socket = _BadSocket()
            server.context = _FakeContext()
            server.close()
            self.assertFalse(server.running)

    def test_zmq_tcp_server_control_cmd_flow(self):
        """A control command is received, tracked by task_id, answered and released."""
        fake_context = _FakeContext()
        task = {"task_id": "task-1", "payload": "ok"}
        fake_context.socket_instance.recv_multipart_payload = [
            b"client",
            b"",
            msgpack.packb(task),
        ]
        with mock.patch("fastdeploy.inter_communicator.zmq_server.zmq.Context", return_value=fake_context):
            server = ZmqTcpServer(12345, zmq.ROUTER)
            received = server.recv_control_cmd()
            self.assertEqual(received["task_id"], "task-1")
            self.assertIn("task-1", server.req_dict)

            server.response_for_control_cmd("task-1", {"status": "done"})
            self.assertNotIn("task-1", server.req_dict)
            self.assertEqual(fake_context.socket_instance.sent[-1][0], "send_multipart")

    def test_zmq_tcp_server_control_cmd_empty(self):
        """recv_control_cmd returns None when nothing is available (zmq.Again)."""
        fake_context = _FakeContext()
        fake_context.socket_instance.recv_multipart_payload = zmq.Again()
        with mock.patch("fastdeploy.inter_communicator.zmq_server.zmq.Context", return_value=fake_context):
            server = ZmqTcpServer(12345, zmq.ROUTER)
            self.assertIsNone(server.recv_control_cmd())

    def test_zmq_tcp_server_response_errors_and_close(self):
        """Missing socket raises; a send failure does not raise but releases the task."""
        fake_context = _FakeContext()
        with mock.patch("fastdeploy.inter_communicator.zmq_server.zmq.Context", return_value=fake_context):
            server = ZmqTcpServer(12345, zmq.ROUTER)
            server.socket = None
            server._create_socket = lambda: None
            with self.assertRaises(RuntimeError):
                server.response_for_control_cmd("task", {"status": "fail"})

            class _ErrorSocket(_FakeSocket):
                # Accept any signature so only the RuntimeError path is exercised.
                def send_multipart(self, *args, **kwargs):
                    raise RuntimeError("send failed")

            server.socket = _ErrorSocket()
            server.req_dict["task"] = b"client"
            server.response_for_control_cmd("task", {"status": "fail"})
            self.assertEqual(server.req_dict, {})

            server.running = False
            server.close()

    def test_zmq_ipc_server_unsupported_mode_raises(self):
        """ZmqIpcServer rejects an unsupported socket mode (PUB) with ValueError."""
        fake_context = _FakeContext()
        with mock.patch("fastdeploy.inter_communicator.zmq_server.zmq.Context", return_value=fake_context):
            with self.assertRaises(ValueError):
                ZmqIpcServer("test", zmq.PUB)

    def test_zmq_ipc_server_get_worker_push_socket_creates_and_caches(self):
        """_get_worker_push_socket connects one socket per worker pid and caches it."""
        # Track every socket handed out by context.socket().
        created_sockets = []

        class _TrackingContext(_FakeContext):
            def socket(self, mode):
                sock = _FakeSocket()
                sock.mode = mode
                sock.connected_to = None
                sock.connect = lambda addr: setattr(sock, "connected_to", addr)
                created_sockets.append(sock)
                return sock

        tracking_ctx = _TrackingContext()
        with mock.patch("fastdeploy.inter_communicator.zmq_server.zmq.Context", return_value=tracking_ctx):
            server = ZmqIpcServer("myservice", zmq.PUSH)

            # First call: a new socket is created and connected to a pid-specific endpoint.
            sock1 = server._get_worker_push_socket(1234)
            self.assertIsNotNone(sock1)
            self.assertIn(1234, server.worker_push_sockets)
            self.assertIsNotNone(sock1.connected_to)
            self.assertIn("1234", sock1.connected_to)

            # Same worker pid: the cached socket is returned.
            sock2 = server._get_worker_push_socket(1234)
            self.assertIs(sock1, sock2)

            # Different worker pid: a separate socket is created.
            sock3 = server._get_worker_push_socket(5678)
            self.assertIsNot(sock1, sock3)
            self.assertIn(5678, server.worker_push_sockets)

    def test_send_batch_response_with_worker_pid_none_uses_default_socket(self):
        """With worker_pid=None, _send_batch_response sends on the server's default socket."""
        fake_socket = _FakeSocket()
        server = _DummyServer(socket=fake_socket)
        server.address = "test-address"

        batch_data = [[_DummyResponse(1, finished=True)]]
        server._send_batch_response(batch_data, worker_pid=None)

        self.assertEqual(len(fake_socket.sent), 1)
        self.assertEqual(fake_socket.sent[0][0], "send")

    def test_zmq_ipc_server_close_with_worker_push_sockets(self):
        """close() closes and clears per-worker sockets, swallowing their close() errors."""
        fake_context = _FakeContext()
        with mock.patch("fastdeploy.inter_communicator.zmq_server.zmq.Context", return_value=fake_context):
            server = ZmqIpcServer("test", zmq.PULL)

            # A well-behaved per-worker socket.
            good_sock = _FakeSocket()
            server.worker_push_sockets[100] = good_sock

            # A per-worker socket whose close() raises.
            class _BadPushSocket(_FakeSocket):
                def close(self):
                    raise RuntimeError("push close failed")

            bad_sock = _BadPushSocket()
            server.worker_push_sockets[200] = bad_sock

            server.close()

            self.assertFalse(server.running)
            self.assertEqual(len(server.worker_push_sockets), 0)
            self.assertTrue(good_sock.closed)
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
|