192 lines
6.2 KiB
Python
192 lines
6.2 KiB
Python
# Copyright (c) PyZMQ Developers.
|
|
# Distributed under the terms of the Modified BSD License.
|
|
|
|
import sys
|
|
import time
|
|
from threading import Thread
|
|
|
|
from unittest import TestCase
|
|
try:
|
|
from unittest import SkipTest
|
|
except ImportError:
|
|
from unittest2 import SkipTest
|
|
|
|
from pytest import mark
|
|
import zmq
|
|
from zmq.utils import jsonapi
|
|
|
|
try:
|
|
import gevent
|
|
from zmq import green as gzmq
|
|
have_gevent = True
|
|
except ImportError:
|
|
have_gevent = False
|
|
|
|
|
|
# True when the running interpreter identifies itself as PyPy in sys.version.
PYPY = 'PyPy' in sys.version

#-----------------------------------------------------------------------------
# skip markers (pytest.mark.skipif)
#-----------------------------------------------------------------------------


def _id(x):
    """Identity function: return ``x`` unchanged."""
    return x


# Marker that skips a test when running on PyPy.
skip_pypy = mark.skipif(PYPY, reason="Doesn't work on PyPy")
# Marker that skips a test when zmq.zmq_version_info() reports a version below 4.
require_zmq_4 = mark.skipif(zmq.zmq_version_info() < (4,), reason="requires zmq >= 4")
|
|
|
|
#-----------------------------------------------------------------------------
|
|
# Base test class
|
|
#-----------------------------------------------------------------------------
|
|
|
|
class BaseZMQTestCase(TestCase):
    """Base TestCase that tracks sockets and terminates contexts on teardown.

    ``setUp`` stores ``Context.instance()`` on ``self.context`` and starts an
    empty ``self.sockets`` list. ``tearDown`` closes every tracked socket and
    terminates every context those sockets belong to.
    """

    # When True, the Context class comes from ``zmq.green`` instead of ``zmq``.
    green = False
    # Seconds to wait for each context to terminate during tearDown.
    teardown_timeout = 10

    @property
    def Context(self):
        """Return the Context class to use: green (gevent) or regular."""
        if self.green:
            return gzmq.Context
        else:
            return zmq.Context

    def socket(self, socket_type):
        """Create a socket on ``self.context`` and register it for cleanup."""
        s = self.context.socket(socket_type)
        self.sockets.append(s)
        return s

    def setUp(self):
        """Skip green tests without gevent, then set up context and socket list."""
        super(BaseZMQTestCase, self).setUp()
        if self.green and not have_gevent:
            raise SkipTest("requires gevent")
        self.context = self.Context.instance()
        self.sockets = []

    def tearDown(self):
        """Close all tracked sockets and terminate their contexts.

        Raises:
            RuntimeError: if a context does not terminate within
                ``teardown_timeout`` seconds.
        """
        try:
            contexts = set([self.context])
            while self.sockets:
                sock = self.sockets.pop()
                contexts.add(sock.context)  # in case additional contexts are created
                sock.close(0)
            for ctx in contexts:
                # term() can block forever if sockets remain open, so run it
                # in a daemon thread and bound the wait.
                t = Thread(target=ctx.term)
                t.daemon = True
                t.start()
                t.join(timeout=self.teardown_timeout)
                if t.is_alive():
                    # reset Context.instance, so the failure to term doesn't corrupt subsequent tests
                    zmq.sugar.context.Context._instance = None
                    raise RuntimeError("context could not terminate, open sockets likely remain in test")
        finally:
            # Always run the parent teardown, even when termination failed.
            super(BaseZMQTestCase, self).tearDown()

    def create_bound_pair(self, type1=zmq.PAIR, type2=zmq.PAIR, interface='tcp://127.0.0.1'):
        """Create a bound socket pair using a random port.

        The first socket binds to a random port on ``interface`` and the
        second connects to it. Both have LINGER set to 0. Each socket is
        registered for cleanup as soon as it is created, so a failure in
        bind/connect does not leave it untracked.

        Returns:
            The tuple ``(s1, s2)`` of bound and connected sockets.
        """
        s1 = self.socket(type1)
        s1.setsockopt(zmq.LINGER, 0)
        port = s1.bind_to_random_port(interface)
        s2 = self.socket(type2)
        s2.setsockopt(zmq.LINGER, 0)
        s2.connect('%s:%s' % (interface, port))
        return s1, s2

    def ping_pong(self, s1, s2, msg):
        """Send ``msg`` from s1 to s2 and back; return what s1 receives."""
        s1.send(msg)
        msg2 = s2.recv()
        s2.send(msg2)
        msg3 = s1.recv()
        return msg3

    def ping_pong_json(self, s1, s2, o):
        """Round-trip ``o`` between s1 and s2 with send_json/recv_json.

        Raises:
            SkipTest: if ``jsonapi.jsonmod`` is None.
        """
        if jsonapi.jsonmod is None:
            raise SkipTest("No json library")
        s1.send_json(o)
        o2 = s2.recv_json()
        s2.send_json(o2)
        o3 = s1.recv_json()
        return o3

    def ping_pong_pyobj(self, s1, s2, o):
        """Round-trip ``o`` between s1 and s2 with send_pyobj/recv_pyobj."""
        s1.send_pyobj(o)
        o2 = s2.recv_pyobj()
        s2.send_pyobj(o2)
        o3 = s1.recv_pyobj()
        return o3

    def assertRaisesErrno(self, errno, func, *args, **kwargs):
        """Assert that ``func(*args, **kwargs)`` raises ZMQError with ``errno``.

        Fails the test if no ZMQError is raised or if its errno differs.
        """
        try:
            func(*args, **kwargs)
        except zmq.ZMQError as e:
            self.assertEqual(
                e.errno, errno,
                "wrong error raised, expected '%s' got '%s'"
                % (zmq.ZMQError(errno), zmq.ZMQError(e.errno)),
            )
        else:
            self.fail("Function did not raise any error")

    def _select_recv(self, multipart, socket, **kwargs):
        """call recv[_multipart] in a way that raises if there is nothing to receive

        ``timeout`` is popped from ``kwargs`` (default 5) and passed to
        ``zmq.select``. The remaining kwargs go to the recv call, with
        ``zmq.DONTWAIT`` OR-ed into ``flags``.
        """
        if zmq.zmq_version_info() >= (3,1,0):
            # zmq 3.1 has a bug, where poll can return false positives,
            # so we wait a little bit just in case
            # See LIBZMQ-280 on JIRA
            time.sleep(0.1)

        r,w,x = zmq.select([socket], [], [], timeout=kwargs.pop('timeout', 5))
        assert len(r) > 0, "Should have received a message"
        kwargs['flags'] = zmq.DONTWAIT | kwargs.get('flags', 0)

        recv = socket.recv_multipart if multipart else socket.recv
        return recv(**kwargs)

    def recv(self, socket, **kwargs):
        """call recv in a way that raises if there is nothing to receive"""
        return self._select_recv(False, socket, **kwargs)

    def recv_multipart(self, socket, **kwargs):
        """call recv_multipart in a way that raises if there is nothing to receive"""
        return self._select_recv(True, socket, **kwargs)
|
|
|
|
|
|
class PollZMQTestCase(BaseZMQTestCase):
    """Subclass of BaseZMQTestCase that adds no behavior of its own."""
|
|
|
|
class GreenTest:
    """Mixin for making green versions of test classes"""

    # Selects the gevent-based Context in the test case this is mixed into.
    green = True
    # Seconds to wait for all contexts to terminate during tearDown.
    teardown_timeout = 10

    def assertRaisesErrno(self, errno, func, *args, **kwargs):
        """Assert that ``func(*args, **kwargs)`` raises ZMQError with ``errno``.

        Raises:
            SkipTest: when ``errno`` is ``zmq.EAGAIN``; such checks are
                skipped for green tests.
        """
        if errno == zmq.EAGAIN:
            raise SkipTest("Skipping because we're green.")
        try:
            func(*args, **kwargs)
        except zmq.ZMQError as e:
            self.assertEqual(
                e.errno, errno,
                "wrong error raised, expected '%s' got '%s'"
                % (zmq.ZMQError(errno), zmq.ZMQError(e.errno)),
            )
        else:
            self.fail("Function did not raise any error")

    def tearDown(self):
        """Close all tracked sockets and terminate their contexts in greenlets.

        Raises:
            RuntimeError: if the contexts do not all terminate within
                ``teardown_timeout`` seconds.
        """
        contexts = set([self.context])
        while self.sockets:
            sock = self.sockets.pop()
            contexts.add(sock.context)  # in case additional contexts are created
            sock.close()
        greenlets = [gevent.spawn(ctx.term) for ctx in contexts]
        try:
            finished = gevent.joinall(
                greenlets,
                timeout=self.teardown_timeout,
                raise_error=True,
            )
        except gevent.Timeout:
            raise RuntimeError("context could not terminate, open sockets likely remain in test")
        # joinall returns only the greenlets that finished before the timeout
        # instead of raising, so a shortfall means some ctx.term is still blocked.
        if len(finished) < len(greenlets):
            raise RuntimeError("context could not terminate, open sockets likely remain in test")

    def skip_green(self):
        """Unconditionally skip the current test."""
        raise SkipTest("Skipping because we are green")
|
|
|
|
def skip_green(f):
    """Decorator that skips a test method when its test case is green.

    The wrapper raises SkipTest if ``self.green`` is true and otherwise
    calls ``f`` with the same arguments and returns its result. The
    wrapped function's name and docstring are preserved.
    """
    # Imported locally so the decorator does not depend on a module-level import.
    from functools import wraps

    @wraps(f)
    def skipping_test(self, *args, **kwargs):
        if self.green:
            raise SkipTest("Skipping because we are green")
        else:
            return f(self, *args, **kwargs)
    return skipping_test
|