Uploaded Test files

This commit is contained in:
Batuhan Berk Başoğlu 2020-11-12 11:05:57 -05:00
parent f584ad9d97
commit 2e81cb7d99
16627 changed files with 2065359 additions and 102444 deletions

View file

@ -0,0 +1,192 @@
# Copyright (c) PyZMQ Developers.
# Distributed under the terms of the Modified BSD License.
import sys
import time
from threading import Thread
from unittest import TestCase
try:
from unittest import SkipTest
except ImportError:
from unittest2 import SkipTest
from pytest import mark
import zmq
from zmq.utils import jsonapi
try:
import gevent
from zmq import green as gzmq
have_gevent = True
except ImportError:
have_gevent = False
PYPY = 'PyPy' in sys.version
#-----------------------------------------------------------------------------
# skip decorators (directly from unittest)
#-----------------------------------------------------------------------------
_id = lambda x: x
skip_pypy = mark.skipif(PYPY, reason="Doesn't work on PyPy")
require_zmq_4 = mark.skipif(zmq.zmq_version_info() < (4,), reason="requires zmq >= 4")
#-----------------------------------------------------------------------------
# Base test class
#-----------------------------------------------------------------------------
class BaseZMQTestCase(TestCase):
green = False
teardown_timeout = 10
@property
def Context(self):
if self.green:
return gzmq.Context
else:
return zmq.Context
def socket(self, socket_type):
s = self.context.socket(socket_type)
self.sockets.append(s)
return s
def setUp(self):
super(BaseZMQTestCase, self).setUp()
if self.green and not have_gevent:
raise SkipTest("requires gevent")
self.context = self.Context.instance()
self.sockets = []
def tearDown(self):
contexts = set([self.context])
while self.sockets:
sock = self.sockets.pop()
contexts.add(sock.context) # in case additional contexts are created
sock.close(0)
for ctx in contexts:
t = Thread(target=ctx.term)
t.daemon = True
t.start()
t.join(timeout=self.teardown_timeout)
if t.is_alive():
# reset Context.instance, so the failure to term doesn't corrupt subsequent tests
zmq.sugar.context.Context._instance = None
raise RuntimeError("context could not terminate, open sockets likely remain in test")
super(BaseZMQTestCase, self).tearDown()
def create_bound_pair(self, type1=zmq.PAIR, type2=zmq.PAIR, interface='tcp://127.0.0.1'):
"""Create a bound socket pair using a random port."""
s1 = self.context.socket(type1)
s1.setsockopt(zmq.LINGER, 0)
port = s1.bind_to_random_port(interface)
s2 = self.context.socket(type2)
s2.setsockopt(zmq.LINGER, 0)
s2.connect('%s:%s' % (interface, port))
self.sockets.extend([s1,s2])
return s1, s2
def ping_pong(self, s1, s2, msg):
s1.send(msg)
msg2 = s2.recv()
s2.send(msg2)
msg3 = s1.recv()
return msg3
def ping_pong_json(self, s1, s2, o):
if jsonapi.jsonmod is None:
raise SkipTest("No json library")
s1.send_json(o)
o2 = s2.recv_json()
s2.send_json(o2)
o3 = s1.recv_json()
return o3
def ping_pong_pyobj(self, s1, s2, o):
s1.send_pyobj(o)
o2 = s2.recv_pyobj()
s2.send_pyobj(o2)
o3 = s1.recv_pyobj()
return o3
def assertRaisesErrno(self, errno, func, *args, **kwargs):
try:
func(*args, **kwargs)
except zmq.ZMQError as e:
self.assertEqual(e.errno, errno, "wrong error raised, expected '%s' \
got '%s'" % (zmq.ZMQError(errno), zmq.ZMQError(e.errno)))
else:
self.fail("Function did not raise any error")
def _select_recv(self, multipart, socket, **kwargs):
"""call recv[_multipart] in a way that raises if there is nothing to receive"""
if zmq.zmq_version_info() >= (3,1,0):
# zmq 3.1 has a bug, where poll can return false positives,
# so we wait a little bit just in case
# See LIBZMQ-280 on JIRA
time.sleep(0.1)
r,w,x = zmq.select([socket], [], [], timeout=kwargs.pop('timeout', 5))
assert len(r) > 0, "Should have received a message"
kwargs['flags'] = zmq.DONTWAIT | kwargs.get('flags', 0)
recv = socket.recv_multipart if multipart else socket.recv
return recv(**kwargs)
def recv(self, socket, **kwargs):
"""call recv in a way that raises if there is nothing to receive"""
return self._select_recv(False, socket, **kwargs)
def recv_multipart(self, socket, **kwargs):
"""call recv_multipart in a way that raises if there is nothing to receive"""
return self._select_recv(True, socket, **kwargs)
class PollZMQTestCase(BaseZMQTestCase):
pass
class GreenTest:
"""Mixin for making green versions of test classes"""
green = True
teardown_timeout = 10
def assertRaisesErrno(self, errno, func, *args, **kwargs):
if errno == zmq.EAGAIN:
raise SkipTest("Skipping because we're green.")
try:
func(*args, **kwargs)
except zmq.ZMQError:
e = sys.exc_info()[1]
self.assertEqual(e.errno, errno, "wrong error raised, expected '%s' \
got '%s'" % (zmq.ZMQError(errno), zmq.ZMQError(e.errno)))
else:
self.fail("Function did not raise any error")
def tearDown(self):
contexts = set([self.context])
while self.sockets:
sock = self.sockets.pop()
contexts.add(sock.context) # in case additional contexts are created
sock.close()
try:
gevent.joinall(
[gevent.spawn(ctx.term) for ctx in contexts],
timeout=self.teardown_timeout,
raise_error=True,
)
except gevent.Timeout:
raise RuntimeError("context could not terminate, open sockets likely remain in test")
def skip_green(self):
raise SkipTest("Skipping because we are green")
def skip_green(f):
def skipping_test(self, *args, **kwargs):
if self.green:
raise SkipTest("Skipping because we are green")
else:
return f(self, *args, **kwargs)
return skipping_test

View file

@ -0,0 +1,481 @@
"""Test asyncio support"""
# Copyright (c) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import json
from multiprocessing import Process
import os
import sys
import pytest
from pytest import mark
import zmq
from zmq.utils.strtypes import u
try:
import asyncio
import zmq.asyncio as zaio
from zmq.auth.asyncio import AsyncioAuthenticator
except ImportError:
if sys.version_info >= (3,4):
raise
asyncio = None
from concurrent.futures import CancelledError
from zmq.tests import BaseZMQTestCase, SkipTest
from zmq.tests.test_auth import TestThreadAuthentication
class ProcessForTeardownTest(Process):
def __init__(self, event_loop_policy_class):
Process.__init__(self)
self.event_loop_policy_class = event_loop_policy_class
def run(self):
"""Leave context, socket and event loop upon implicit disposal"""
asyncio.set_event_loop_policy(self.event_loop_policy_class())
actx = zaio.Context.instance()
socket = actx.socket(zmq.PAIR)
socket.bind_to_random_port('tcp://127.0.0.1')
@asyncio.coroutine
def never_ending_task(socket):
yield from socket.recv() # never ever receive anything
loop = asyncio.get_event_loop()
coro = asyncio.wait_for(never_ending_task(socket), timeout=1)
try:
loop.run_until_complete(coro)
except asyncio.TimeoutError:
pass # expected timeout
else:
assert False, "never_ending_task was completed unexpectedly"
class TestAsyncIOSocket(BaseZMQTestCase):
if asyncio is not None:
Context = zaio.Context
def setUp(self):
if asyncio is None:
raise SkipTest()
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
super(TestAsyncIOSocket, self).setUp()
def tearDown(self):
self.loop.close()
super().tearDown()
def test_socket_class(self):
s = self.context.socket(zmq.PUSH)
assert isinstance(s, zaio.Socket)
s.close()
def test_instance_subclass_first(self):
actx = zmq.asyncio.Context.instance()
ctx = zmq.Context.instance()
ctx.term()
actx.term()
assert type(ctx) is zmq.Context
assert type(actx) is zmq.asyncio.Context
def test_instance_subclass_second(self):
ctx = zmq.Context.instance()
actx = zmq.asyncio.Context.instance()
ctx.term()
actx.term()
assert type(ctx) is zmq.Context
assert type(actx) is zmq.asyncio.Context
def test_recv_multipart(self):
@asyncio.coroutine
def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_multipart()
assert not f.done()
yield from a.send(b'hi')
recvd = yield from f
self.assertEqual(recvd, [b'hi'])
self.loop.run_until_complete(test())
def test_recv(self):
@asyncio.coroutine
def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f1 = b.recv()
f2 = b.recv()
assert not f1.done()
assert not f2.done()
yield from a.send_multipart([b'hi', b'there'])
recvd = yield from f2
assert f1.done()
self.assertEqual(f1.result(), b'hi')
self.assertEqual(recvd, b'there')
self.loop.run_until_complete(test())
@mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
def test_recv_timeout(self):
@asyncio.coroutine
def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
b.rcvtimeo = 100
f1 = b.recv()
b.rcvtimeo = 1000
f2 = b.recv_multipart()
with self.assertRaises(zmq.Again):
yield from f1
yield from a.send_multipart([b'hi', b'there'])
recvd = yield from f2
assert f2.done()
self.assertEqual(recvd, [b'hi', b'there'])
self.loop.run_until_complete(test())
@mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
def test_send_timeout(self):
@asyncio.coroutine
def test():
s = self.socket(zmq.PUSH)
s.sndtimeo = 100
with self.assertRaises(zmq.Again):
yield from s.send(b'not going anywhere')
self.loop.run_until_complete(test())
def test_recv_string(self):
@asyncio.coroutine
def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_string()
assert not f.done()
msg = u('πøøπ')
yield from a.send_string(msg)
recvd = yield from f
assert f.done()
self.assertEqual(f.result(), msg)
self.assertEqual(recvd, msg)
self.loop.run_until_complete(test())
def test_recv_json(self):
@asyncio.coroutine
def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_json()
assert not f.done()
obj = dict(a=5)
yield from a.send_json(obj)
recvd = yield from f
assert f.done()
self.assertEqual(f.result(), obj)
self.assertEqual(recvd, obj)
self.loop.run_until_complete(test())
def test_recv_json_cancelled(self):
@asyncio.coroutine
def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_json()
assert not f.done()
f.cancel()
# cycle eventloop to allow cancel events to fire
yield from asyncio.sleep(0)
obj = dict(a=5)
yield from a.send_json(obj)
# CancelledError change in 3.8 https://bugs.python.org/issue32528
if sys.version_info < (3, 8):
with pytest.raises(CancelledError):
recvd = yield from f
else:
with pytest.raises(asyncio.exceptions.CancelledError):
recvd = yield from f
assert f.done()
# give it a chance to incorrectly consume the event
events = yield from b.poll(timeout=5)
assert events
yield from asyncio.sleep(0)
# make sure cancelled recv didn't eat up event
f = b.recv_json()
recvd = yield from asyncio.wait_for(f, timeout=5)
assert recvd == obj
self.loop.run_until_complete(test())
def test_recv_pyobj(self):
@asyncio.coroutine
def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_pyobj()
assert not f.done()
obj = dict(a=5)
yield from a.send_pyobj(obj)
recvd = yield from f
assert f.done()
self.assertEqual(f.result(), obj)
self.assertEqual(recvd, obj)
self.loop.run_until_complete(test())
def test_custom_serialize(self):
def serialize(msg):
frames = []
frames.extend(msg.get('identities', []))
content = json.dumps(msg['content']).encode('utf8')
frames.append(content)
return frames
def deserialize(frames):
identities = frames[:-1]
content = json.loads(frames[-1].decode('utf8'))
return {
'identities': identities,
'content': content,
}
@asyncio.coroutine
def test():
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
msg = {
'content': {
'a': 5,
'b': 'bee',
}
}
yield from a.send_serialized(msg, serialize)
recvd = yield from b.recv_serialized(deserialize)
assert recvd['content'] == msg['content']
assert recvd['identities']
# bounce back, tests identities
yield from b.send_serialized(recvd, serialize)
r2 = yield from a.recv_serialized(deserialize)
assert r2['content'] == msg['content']
assert not r2['identities']
self.loop.run_until_complete(test())
def test_custom_serialize_error(self):
@asyncio.coroutine
def test():
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
msg = {
'content': {
'a': 5,
'b': 'bee',
}
}
with pytest.raises(TypeError):
yield from a.send_serialized(json, json.dumps)
yield from a.send(b'not json')
with pytest.raises(TypeError):
recvd = yield from b.recv_serialized(json.loads)
self.loop.run_until_complete(test())
def test_recv_dontwait(self):
@asyncio.coroutine
def test():
push, pull = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = pull.recv(zmq.DONTWAIT)
with self.assertRaises(zmq.Again):
yield from f
yield from push.send(b'ping')
yield from pull.poll() # ensure message will be waiting
f = pull.recv(zmq.DONTWAIT)
assert f.done()
msg = yield from f
self.assertEqual(msg, b'ping')
self.loop.run_until_complete(test())
def test_recv_cancel(self):
@asyncio.coroutine
def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f1 = b.recv()
f2 = b.recv_multipart()
assert f1.cancel()
assert f1.done()
assert not f2.done()
yield from a.send_multipart([b'hi', b'there'])
recvd = yield from f2
assert f1.cancelled()
assert f2.done()
self.assertEqual(recvd, [b'hi', b'there'])
self.loop.run_until_complete(test())
def test_poll(self):
@asyncio.coroutine
def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.poll(timeout=0)
yield from asyncio.sleep(0)
self.assertEqual(f.result(), 0)
f = b.poll(timeout=1)
assert not f.done()
evt = yield from f
self.assertEqual(evt, 0)
f = b.poll(timeout=1000)
assert not f.done()
yield from a.send_multipart([b'hi', b'there'])
evt = yield from f
self.assertEqual(evt, zmq.POLLIN)
recvd = yield from b.recv_multipart()
self.assertEqual(recvd, [b'hi', b'there'])
self.loop.run_until_complete(test())
def test_poll_base_socket(self):
@asyncio.coroutine
def test():
ctx = zmq.Context()
url = 'inproc://test'
a = ctx.socket(zmq.PUSH)
b = ctx.socket(zmq.PULL)
self.sockets.extend([a, b])
a.bind(url)
b.connect(url)
poller = zaio.Poller()
poller.register(b, zmq.POLLIN)
f = poller.poll(timeout=1000)
assert not f.done()
a.send_multipart([b'hi', b'there'])
evt = yield from f
self.assertEqual(evt, [(b, zmq.POLLIN)])
recvd = b.recv_multipart()
self.assertEqual(recvd, [b'hi', b'there'])
self.loop.run_until_complete(test())
def test_poll_on_closed_socket(self):
@asyncio.coroutine
def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.poll(timeout=1)
b.close()
# The test might stall if we try to yield from f directly so instead just make a few
# passes through the event loop to schedule and execute all callbacks
for _ in range(5):
yield from asyncio.sleep(0)
if f.cancelled():
break
assert f.cancelled()
self.loop.run_until_complete(test())
@pytest.mark.skipif(
sys.platform.startswith('win'),
reason='Windows does not support polling on files')
def test_poll_raw(self):
@asyncio.coroutine
def test():
p = zaio.Poller()
# make a pipe
r, w = os.pipe()
r = os.fdopen(r, 'rb')
w = os.fdopen(w, 'wb')
# POLLOUT
p.register(r, zmq.POLLIN)
p.register(w, zmq.POLLOUT)
evts = yield from p.poll(timeout=1)
evts = dict(evts)
assert r.fileno() not in evts
assert w.fileno() in evts
assert evts[w.fileno()] == zmq.POLLOUT
# POLLIN
p.unregister(w)
w.write(b'x')
w.flush()
evts = yield from p.poll(timeout=1000)
evts = dict(evts)
assert r.fileno() in evts
assert evts[r.fileno()] == zmq.POLLIN
assert r.read(1) == b'x'
r.close()
w.close()
loop = asyncio.get_event_loop()
loop.run_until_complete(test())
def test_shadow(self):
@asyncio.coroutine
def test():
ctx = zmq.Context()
s = ctx.socket(zmq.PULL)
async_s = zaio.Socket(s)
assert isinstance(async_s, self.socket_class)
def test_process_teardown(self):
event_loop_policy_class = type(asyncio.get_event_loop_policy())
proc = ProcessForTeardownTest(event_loop_policy_class)
proc.start()
try:
proc.join(10) # starting new Python process may cost a lot
self.assertEqual(proc.exitcode, 0,
"Python process died with code %d" % proc.exitcode
if proc.exitcode else "process teardown hangs")
finally:
proc.terminate()
class TestAsyncioAuthentication(TestThreadAuthentication):
"""Test authentication running in a asyncio task"""
if asyncio is not None:
Context = zaio.Context
def shortDescription(self):
"""Rewrite doc strings from TestThreadAuthentication from
'threaded' to 'asyncio'.
"""
doc = self._testMethodDoc
if doc:
doc = doc.split("\n")[0].strip()
if doc.startswith('threaded auth'):
doc = doc.replace('threaded auth', 'asyncio auth')
return doc
def setUp(self):
if asyncio is None:
raise SkipTest()
self.loop = zaio.ZMQEventLoop()
asyncio.set_event_loop(self.loop)
super().setUp()
def tearDown(self):
super().tearDown()
self.loop.close()
def make_auth(self):
return AsyncioAuthenticator(self.context)
def can_connect(self, server, client):
"""Check if client can connect to server using tcp transport"""
@asyncio.coroutine
def go():
result = False
iface = 'tcp://127.0.0.1'
port = server.bind_to_random_port(iface)
client.connect("%s:%i" % (iface, port))
msg = [b"Hello World"]
yield from server.send_multipart(msg)
if (yield from client.poll(1000)):
rcvd_msg = yield from client.recv_multipart()
self.assertEqual(rcvd_msg, msg)
result = True
return result
return self.loop.run_until_complete(go())
def _select_recv(self, multipart, socket, **kwargs):
recv = socket.recv_multipart if multipart else socket.recv
@asyncio.coroutine
def coro():
if not (yield from socket.poll(5000)):
raise TimeoutError("Should have received a message")
return (yield from recv(**kwargs))
return self.loop.run_until_complete(coro())

View file

@ -0,0 +1,6 @@
"""Test asyncio support"""
try:
from ._test_asyncio import TestAsyncIOSocket, TestAsyncioAuthentication
except SyntaxError:
pass

View file

@ -0,0 +1,14 @@
"""pytest configuration and fixtures"""
import sys
import pytest
@pytest.fixture(scope='session', autouse=True)
def win_py38_asyncio():
"""fix tornado compatibility on py38"""
if sys.version_info < (3, 8) or not sys.platform.startswith('win'):
return
import asyncio
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

View file

@ -0,0 +1,557 @@
# -*- coding: utf8 -*-
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import logging
import os
import shutil
import sys
import tempfile
import pytest
import zmq.auth
from zmq.auth.thread import ThreadAuthenticator
from zmq.utils.strtypes import u
from zmq.tests import BaseZMQTestCase, SkipTest, skip_pypy
class BaseAuthTestCase(BaseZMQTestCase):
def setUp(self):
if zmq.zmq_version_info() < (4,0):
raise SkipTest("security is new in libzmq 4.0")
try:
zmq.curve_keypair()
except zmq.ZMQError:
raise SkipTest("security requires libzmq to have curve support")
super(BaseAuthTestCase, self).setUp()
# enable debug logging while we run tests
logging.getLogger('zmq.auth').setLevel(logging.DEBUG)
self.auth = self.make_auth()
self.auth.start()
self.base_dir, self.public_keys_dir, self.secret_keys_dir = self.create_certs()
def make_auth(self):
raise NotImplementedError()
def tearDown(self):
if self.auth:
self.auth.stop()
self.auth = None
self.remove_certs(self.base_dir)
super(BaseAuthTestCase, self).tearDown()
def create_certs(self):
"""Create CURVE certificates for a test"""
# Create temporary CURVE keypairs for this test run. We create all keys in a
# temp directory and then move them into the appropriate private or public
# directory.
base_dir = tempfile.mkdtemp()
keys_dir = os.path.join(base_dir, 'certificates')
public_keys_dir = os.path.join(base_dir, 'public_keys')
secret_keys_dir = os.path.join(base_dir, 'private_keys')
os.mkdir(keys_dir)
os.mkdir(public_keys_dir)
os.mkdir(secret_keys_dir)
server_public_file, server_secret_file = zmq.auth.create_certificates(keys_dir, "server")
client_public_file, client_secret_file = zmq.auth.create_certificates(keys_dir, "client")
for key_file in os.listdir(keys_dir):
if key_file.endswith(".key"):
shutil.move(os.path.join(keys_dir, key_file),
os.path.join(public_keys_dir, '.'))
for key_file in os.listdir(keys_dir):
if key_file.endswith(".key_secret"):
shutil.move(os.path.join(keys_dir, key_file),
os.path.join(secret_keys_dir, '.'))
return (base_dir, public_keys_dir, secret_keys_dir)
def remove_certs(self, base_dir):
"""Remove certificates for a test"""
shutil.rmtree(base_dir)
def load_certs(self, secret_keys_dir):
"""Return server and client certificate keys"""
server_secret_file = os.path.join(secret_keys_dir, "server.key_secret")
client_secret_file = os.path.join(secret_keys_dir, "client.key_secret")
server_public, server_secret = zmq.auth.load_certificate(server_secret_file)
client_public, client_secret = zmq.auth.load_certificate(client_secret_file)
return server_public, server_secret, client_public, client_secret
class TestThreadAuthentication(BaseAuthTestCase):
"""Test authentication running in a thread"""
def make_auth(self):
return ThreadAuthenticator(self.context)
def can_connect(self, server, client):
"""Check if client can connect to server using tcp transport"""
result = False
iface = 'tcp://127.0.0.1'
port = server.bind_to_random_port(iface)
client.connect("%s:%i" % (iface, port))
msg = [b"Hello World"]
if server.poll(1000, zmq.POLLOUT):
server.send_multipart(msg)
if client.poll(1000):
rcvd_msg = client.recv_multipart()
self.assertEqual(rcvd_msg, msg)
result = True
return result
def test_null(self):
"""threaded auth - NULL"""
# A default NULL connection should always succeed, and not
# go through our authentication infrastructure at all.
self.auth.stop()
self.auth = None
# use a new context, so ZAP isn't inherited
self.context = self.Context()
server = self.socket(zmq.PUSH)
client = self.socket(zmq.PULL)
self.assertTrue(self.can_connect(server, client))
# By setting a domain we switch on authentication for NULL sockets,
# though no policies are configured yet. The client connection
# should still be allowed.
server = self.socket(zmq.PUSH)
server.zap_domain = b'global'
client = self.socket(zmq.PULL)
self.assertTrue(self.can_connect(server, client))
def test_blacklist(self):
"""threaded auth - Blacklist"""
# Blacklist 127.0.0.1, connection should fail
self.auth.deny('127.0.0.1')
server = self.socket(zmq.PUSH)
# By setting a domain we switch on authentication for NULL sockets,
# though no policies are configured yet.
server.zap_domain = b'global'
client = self.socket(zmq.PULL)
self.assertFalse(self.can_connect(server, client))
def test_whitelist(self):
"""threaded auth - Whitelist"""
# Whitelist 127.0.0.1, connection should pass"
self.auth.allow('127.0.0.1')
server = self.socket(zmq.PUSH)
# By setting a domain we switch on authentication for NULL sockets,
# though no policies are configured yet.
server.zap_domain = b'global'
client = self.socket(zmq.PULL)
self.assertTrue(self.can_connect(server, client))
def test_plain(self):
"""threaded auth - PLAIN"""
# Try PLAIN authentication - without configuring server, connection should fail
server = self.socket(zmq.PUSH)
server.plain_server = True
client = self.socket(zmq.PULL)
client.plain_username = b'admin'
client.plain_password = b'Password'
self.assertFalse(self.can_connect(server, client))
# Try PLAIN authentication - with server configured, connection should pass
server = self.socket(zmq.PUSH)
server.plain_server = True
client = self.socket(zmq.PULL)
client.plain_username = b'admin'
client.plain_password = b'Password'
self.auth.configure_plain(domain='*', passwords={'admin': 'Password'})
self.assertTrue(self.can_connect(server, client))
# Try PLAIN authentication - with bogus credentials, connection should fail
server = self.socket(zmq.PUSH)
server.plain_server = True
client = self.socket(zmq.PULL)
client.plain_username = b'admin'
client.plain_password = b'Bogus'
self.assertFalse(self.can_connect(server, client))
# Remove authenticator and check that a normal connection works
self.auth.stop()
self.auth = None
server = self.socket(zmq.PUSH)
client = self.socket(zmq.PULL)
self.assertTrue(self.can_connect(server, client))
client.close()
server.close()
def test_curve(self):
"""threaded auth - CURVE"""
self.auth.allow('127.0.0.1')
certs = self.load_certs(self.secret_keys_dir)
server_public, server_secret, client_public, client_secret = certs
#Try CURVE authentication - without configuring server, connection should fail
server = self.socket(zmq.PUSH)
server.curve_publickey = server_public
server.curve_secretkey = server_secret
server.curve_server = True
client = self.socket(zmq.PULL)
client.curve_publickey = client_public
client.curve_secretkey = client_secret
client.curve_serverkey = server_public
self.assertFalse(self.can_connect(server, client))
#Try CURVE authentication - with server configured to CURVE_ALLOW_ANY, connection should pass
self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
server = self.socket(zmq.PUSH)
server.curve_publickey = server_public
server.curve_secretkey = server_secret
server.curve_server = True
client = self.socket(zmq.PULL)
client.curve_publickey = client_public
client.curve_secretkey = client_secret
client.curve_serverkey = server_public
self.assertTrue(self.can_connect(server, client))
# Try CURVE authentication - with server configured, connection should pass
self.auth.configure_curve(domain='*', location=self.public_keys_dir)
server = self.socket(zmq.PULL)
server.curve_publickey = server_public
server.curve_secretkey = server_secret
server.curve_server = True
client = self.socket(zmq.PUSH)
client.curve_publickey = client_public
client.curve_secretkey = client_secret
client.curve_serverkey = server_public
assert self.can_connect(client, server)
# Remove authenticator and check that a normal connection works
self.auth.stop()
self.auth = None
# Try connecting using NULL and no authentication enabled, connection should pass
server = self.socket(zmq.PUSH)
client = self.socket(zmq.PULL)
self.assertTrue(self.can_connect(server, client))
def test_curve_callback(self):
"""threaded auth - CURVE with callback authentication"""
self.auth.allow('127.0.0.1')
certs = self.load_certs(self.secret_keys_dir)
server_public, server_secret, client_public, client_secret = certs
#Try CURVE authentication - without configuring server, connection should fail
server = self.socket(zmq.PUSH)
server.curve_publickey = server_public
server.curve_secretkey = server_secret
server.curve_server = True
client = self.socket(zmq.PULL)
client.curve_publickey = client_public
client.curve_secretkey = client_secret
client.curve_serverkey = server_public
self.assertFalse(self.can_connect(server, client))
#Try CURVE authentication - with callback authentication configured, connection should pass
class CredentialsProvider(object):
def __init__(self):
self.client = client_public
def callback(self, domain, key):
if (key == self.client):
return True
else:
return False
provider = CredentialsProvider()
self.auth.configure_curve_callback(credentials_provider=provider)
server = self.socket(zmq.PUSH)
server.curve_publickey = server_public
server.curve_secretkey = server_secret
server.curve_server = True
client = self.socket(zmq.PULL)
client.curve_publickey = client_public
client.curve_secretkey = client_secret
client.curve_serverkey = server_public
self.assertTrue(self.can_connect(server, client))
#Try CURVE authentication - with callback authentication configured with wrong key, connection should not pass
class WrongCredentialsProvider(object):
def __init__(self):
self.client = "WrongCredentials"
def callback(self, domain, key):
if (key == self.client):
return True
else:
return False
provider = WrongCredentialsProvider()
self.auth.configure_curve_callback(credentials_provider=provider)
server = self.socket(zmq.PUSH)
server.curve_publickey = server_public
server.curve_secretkey = server_secret
server.curve_server = True
client = self.socket(zmq.PULL)
client.curve_publickey = client_public
client.curve_secretkey = client_secret
client.curve_serverkey = server_public
self.assertFalse(self.can_connect(server, client))
@skip_pypy
def test_curve_user_id(self):
"""threaded auth - CURVE"""
self.auth.allow('127.0.0.1')
certs = self.load_certs(self.secret_keys_dir)
server_public, server_secret, client_public, client_secret = certs
self.auth.configure_curve(domain='*', location=self.public_keys_dir)
server = self.socket(zmq.PULL)
server.curve_publickey = server_public
server.curve_secretkey = server_secret
server.curve_server = True
client = self.socket(zmq.PUSH)
client.curve_publickey = client_public
client.curve_secretkey = client_secret
client.curve_serverkey = server_public
assert self.can_connect(client, server)
# test default user-id map
client.send(b'test')
msg = self.recv(server, copy=False)
assert msg.bytes == b'test'
try:
user_id = msg.get('User-Id')
except zmq.ZMQVersionError:
pass
else:
assert user_id == u(client_public)
# test custom user-id map
self.auth.curve_user_id = lambda client_key: u'custom'
client2 = self.socket(zmq.PUSH)
client2.curve_publickey = client_public
client2.curve_secretkey = client_secret
client2.curve_serverkey = server_public
assert self.can_connect(client2, server)
client2.send(b'test2')
msg = self.recv(server, copy=False)
assert msg.bytes == b'test2'
try:
user_id = msg.get('User-Id')
except zmq.ZMQVersionError:
pass
else:
assert user_id == u'custom'
def with_ioloop(method, expect_success=True):
"""decorator for running tests with an IOLoop"""
def test_method(self):
r = method(self)
loop = self.io_loop
if expect_success:
self.pullstream.on_recv(self.on_message_succeed)
else:
self.pullstream.on_recv(self.on_message_fail)
loop.call_later(1, self.attempt_connection)
loop.call_later(1.2, self.send_msg)
if expect_success:
loop.call_later(2, self.on_test_timeout_fail)
else:
loop.call_later(2, self.on_test_timeout_succeed)
loop.start()
if self.fail_msg:
self.fail(self.fail_msg)
return r
return test_method
def should_auth(method):
return with_ioloop(method, True)
def should_not_auth(method):
return with_ioloop(method, False)
class TestIOLoopAuthentication(BaseAuthTestCase):
"""Test authentication running in ioloop"""
def setUp(self):
try:
from tornado import ioloop
except ImportError:
pytest.skip("Requires tornado")
from zmq.eventloop import zmqstream
self.fail_msg = None
self.io_loop = ioloop.IOLoop()
super(TestIOLoopAuthentication, self).setUp()
self.server = self.socket(zmq.PUSH)
self.client = self.socket(zmq.PULL)
self.pushstream = zmqstream.ZMQStream(self.server, self.io_loop)
self.pullstream = zmqstream.ZMQStream(self.client, self.io_loop)
def make_auth(self):
from zmq.auth.ioloop import IOLoopAuthenticator
return IOLoopAuthenticator(self.context, io_loop=self.io_loop)
def tearDown(self):
if self.auth:
self.auth.stop()
self.auth = None
self.io_loop.close(all_fds=True)
super(TestIOLoopAuthentication, self).tearDown()
def attempt_connection(self):
"""Check if client can connect to server using tcp transport"""
iface = 'tcp://127.0.0.1'
port = self.server.bind_to_random_port(iface)
self.client.connect("%s:%i" % (iface, port))
def send_msg(self):
"""Send a message from server to a client"""
msg = [b"Hello World"]
self.pushstream.send_multipart(msg)
def on_message_succeed(self, frames):
"""A message was received, as expected."""
if frames != [b"Hello World"]:
self.fail_msg = "Unexpected message received"
self.io_loop.stop()
def on_message_fail(self, frames):
"""A message was received, unexpectedly."""
self.fail_msg = 'Received messaged unexpectedly, security failed'
self.io_loop.stop()
def on_test_timeout_succeed(self):
"""Test timer expired, indicates test success"""
self.io_loop.stop()
def on_test_timeout_fail(self):
"""Test timer expired, indicates test failure"""
self.fail_msg = 'Test timed out'
self.io_loop.stop()
@should_auth
def test_none(self):
"""ioloop auth - NONE"""
# A default NULL connection should always succeed, and not
# go through our authentication infrastructure at all.
# no auth should be running
self.auth.stop()
self.auth = None
@should_auth
def test_null(self):
"""ioloop auth - NULL"""
# By setting a domain we switch on authentication for NULL sockets,
# though no policies are configured yet. The client connection
# should still be allowed.
self.server.zap_domain = b'global'
@should_not_auth
def test_blacklist(self):
"""ioloop auth - Blacklist"""
# Blacklist 127.0.0.1, connection should fail
self.auth.deny('127.0.0.1')
self.server.zap_domain = b'global'
@should_auth
def test_whitelist(self):
"""ioloop auth - Whitelist"""
# Whitelist 127.0.0.1, which overrides the blacklist, connection should pass"
self.auth.allow('127.0.0.1')
self.server.setsockopt(zmq.ZAP_DOMAIN, b'global')
@should_not_auth
def test_plain_unconfigured_server(self):
"""ioloop auth - PLAIN, unconfigured server"""
self.client.plain_username = b'admin'
self.client.plain_password = b'Password'
# Try PLAIN authentication - without configuring server, connection should fail
self.server.plain_server = True
@should_auth
def test_plain_configured_server(self):
"""ioloop auth - PLAIN, configured server"""
self.client.plain_username = b'admin'
self.client.plain_password = b'Password'
# Try PLAIN authentication - with server configured, connection should pass
self.server.plain_server = True
self.auth.configure_plain(domain='*', passwords={'admin': 'Password'})
@should_not_auth
def test_plain_bogus_credentials(self):
"""ioloop auth - PLAIN, bogus credentials"""
self.client.plain_username = b'admin'
self.client.plain_password = b'Bogus'
self.server.plain_server = True
self.auth.configure_plain(domain='*', passwords={'admin': 'Password'})
@should_not_auth
def test_curve_unconfigured_server(self):
"""ioloop auth - CURVE, unconfigured server"""
certs = self.load_certs(self.secret_keys_dir)
server_public, server_secret, client_public, client_secret = certs
self.auth.allow('127.0.0.1')
self.server.curve_publickey = server_public
self.server.curve_secretkey = server_secret
self.server.curve_server = True
self.client.curve_publickey = client_public
self.client.curve_secretkey = client_secret
self.client.curve_serverkey = server_public
@should_auth
def test_curve_allow_any(self):
"""ioloop auth - CURVE, CURVE_ALLOW_ANY"""
certs = self.load_certs(self.secret_keys_dir)
server_public, server_secret, client_public, client_secret = certs
self.auth.allow('127.0.0.1')
self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
self.server.curve_publickey = server_public
self.server.curve_secretkey = server_secret
self.server.curve_server = True
self.client.curve_publickey = client_public
self.client.curve_secretkey = client_secret
self.client.curve_serverkey = server_public
@should_auth
def test_curve_configured_server(self):
"""ioloop auth - CURVE, configured server"""
self.auth.allow('127.0.0.1')
certs = self.load_certs(self.secret_keys_dir)
server_public, server_secret, client_public, client_secret = certs
self.auth.configure_curve(domain='*', location=self.public_keys_dir)
self.server.curve_publickey = server_public
self.server.curve_secretkey = server_secret
self.server.curve_server = True
self.client.curve_publickey = client_public
self.client.curve_secretkey = client_secret
self.client.curve_serverkey = server_public

View file

@ -0,0 +1,297 @@
# -*- coding: utf8 -*-
import sys
import time
from unittest import TestCase
from zmq.tests import BaseZMQTestCase, SkipTest
try:
from zmq.backend.cffi import (
zmq_version_info,
PUSH, PULL, IDENTITY,
REQ, REP, POLLIN, POLLOUT,
)
from zmq.backend.cffi._cffi import ffi, C
have_ffi_backend = True
except ImportError:
have_ffi_backend = False
class TestCFFIBackend(TestCase):
def setUp(self):
if not have_ffi_backend:
raise SkipTest('CFFI not available')
def test_zmq_version_info(self):
version = zmq_version_info()
assert version[0] in range(2,11)
def test_zmq_ctx_new_destroy(self):
ctx = C.zmq_ctx_new()
assert ctx != ffi.NULL
assert 0 == C.zmq_ctx_destroy(ctx)
def test_zmq_socket_open_close(self):
ctx = C.zmq_ctx_new()
socket = C.zmq_socket(ctx, PUSH)
assert ctx != ffi.NULL
assert ffi.NULL != socket
assert 0 == C.zmq_close(socket)
assert 0 == C.zmq_ctx_destroy(ctx)
def test_zmq_setsockopt(self):
ctx = C.zmq_ctx_new()
socket = C.zmq_socket(ctx, PUSH)
identity = ffi.new('char[3]', b'zmq')
ret = C.zmq_setsockopt(socket, IDENTITY, ffi.cast('void*', identity), 3)
assert ret == 0
assert ctx != ffi.NULL
assert ffi.NULL != socket
assert 0 == C.zmq_close(socket)
assert 0 == C.zmq_ctx_destroy(ctx)
def test_zmq_getsockopt(self):
ctx = C.zmq_ctx_new()
socket = C.zmq_socket(ctx, PUSH)
identity = ffi.new('char[]', b'zmq')
ret = C.zmq_setsockopt(socket, IDENTITY, ffi.cast('void*', identity), 3)
assert ret == 0
option_len = ffi.new('size_t*', 3)
option = ffi.new('char[3]')
ret = C.zmq_getsockopt(socket,
IDENTITY,
ffi.cast('void*', option),
option_len)
assert ret == 0
assert ffi.string(ffi.cast('char*', option))[0:1] == b"z"
assert ffi.string(ffi.cast('char*', option))[1:2] == b"m"
assert ffi.string(ffi.cast('char*', option))[2:3] == b"q"
assert ctx != ffi.NULL
assert ffi.NULL != socket
assert 0 == C.zmq_close(socket)
assert 0 == C.zmq_ctx_destroy(ctx)
def test_zmq_bind(self):
ctx = C.zmq_ctx_new()
socket = C.zmq_socket(ctx, 8)
assert 0 == C.zmq_bind(socket, b'tcp://*:4444')
assert ctx != ffi.NULL
assert ffi.NULL != socket
assert 0 == C.zmq_close(socket)
assert 0 == C.zmq_ctx_destroy(ctx)
def test_zmq_bind_connect(self):
ctx = C.zmq_ctx_new()
socket1 = C.zmq_socket(ctx, PUSH)
socket2 = C.zmq_socket(ctx, PULL)
assert 0 == C.zmq_bind(socket1, b'tcp://*:4444')
assert 0 == C.zmq_connect(socket2, b'tcp://127.0.0.1:4444')
assert ctx != ffi.NULL
assert ffi.NULL != socket1
assert ffi.NULL != socket2
assert 0 == C.zmq_close(socket1)
assert 0 == C.zmq_close(socket2)
assert 0 == C.zmq_ctx_destroy(ctx)
def test_zmq_msg_init_close(self):
zmq_msg = ffi.new('zmq_msg_t*')
assert ffi.NULL != zmq_msg
assert 0 == C.zmq_msg_init(zmq_msg)
assert 0 == C.zmq_msg_close(zmq_msg)
def test_zmq_msg_init_size(self):
zmq_msg = ffi.new('zmq_msg_t*')
assert ffi.NULL != zmq_msg
assert 0 == C.zmq_msg_init_size(zmq_msg, 10)
assert 0 == C.zmq_msg_close(zmq_msg)
def test_zmq_msg_init_data(self):
zmq_msg = ffi.new('zmq_msg_t*')
message = ffi.new('char[5]', b'Hello')
assert 0 == C.zmq_msg_init_data(zmq_msg,
ffi.cast('void*', message),
5,
ffi.NULL,
ffi.NULL)
assert ffi.NULL != zmq_msg
assert 0 == C.zmq_msg_close(zmq_msg)
def test_zmq_msg_data(self):
zmq_msg = ffi.new('zmq_msg_t*')
message = ffi.new('char[]', b'Hello')
assert 0 == C.zmq_msg_init_data(zmq_msg,
ffi.cast('void*', message),
5,
ffi.NULL,
ffi.NULL)
data = C.zmq_msg_data(zmq_msg)
assert ffi.NULL != zmq_msg
assert ffi.string(ffi.cast("char*", data)) == b'Hello'
assert 0 == C.zmq_msg_close(zmq_msg)
def test_zmq_send(self):
ctx = C.zmq_ctx_new()
sender = C.zmq_socket(ctx, REQ)
receiver = C.zmq_socket(ctx, REP)
assert 0 == C.zmq_bind(receiver, b'tcp://*:7777')
assert 0 == C.zmq_connect(sender, b'tcp://127.0.0.1:7777')
time.sleep(0.1)
zmq_msg = ffi.new('zmq_msg_t*')
message = ffi.new('char[5]', b'Hello')
C.zmq_msg_init_data(zmq_msg,
ffi.cast('void*', message),
ffi.cast('size_t', 5),
ffi.NULL,
ffi.NULL)
assert 5 == C.zmq_msg_send(zmq_msg, sender, 0)
assert 0 == C.zmq_msg_close(zmq_msg)
assert C.zmq_close(sender) == 0
assert C.zmq_close(receiver) == 0
assert C.zmq_ctx_destroy(ctx) == 0
def test_zmq_recv(self):
ctx = C.zmq_ctx_new()
sender = C.zmq_socket(ctx, REQ)
receiver = C.zmq_socket(ctx, REP)
assert 0 == C.zmq_bind(receiver, b'tcp://*:2222')
assert 0 == C.zmq_connect(sender, b'tcp://127.0.0.1:2222')
time.sleep(0.1)
zmq_msg = ffi.new('zmq_msg_t*')
message = ffi.new('char[5]', b'Hello')
C.zmq_msg_init_data(zmq_msg,
ffi.cast('void*', message),
ffi.cast('size_t', 5),
ffi.NULL,
ffi.NULL)
zmq_msg2 = ffi.new('zmq_msg_t*')
C.zmq_msg_init(zmq_msg2)
assert 5 == C.zmq_msg_send(zmq_msg, sender, 0)
assert 5 == C.zmq_msg_recv(zmq_msg2, receiver, 0)
assert 5 == C.zmq_msg_size(zmq_msg2)
assert b"Hello" == ffi.buffer(C.zmq_msg_data(zmq_msg2),
C.zmq_msg_size(zmq_msg2))[:]
assert C.zmq_close(sender) == 0
assert C.zmq_close(receiver) == 0
assert C.zmq_ctx_destroy(ctx) == 0
def test_zmq_poll(self):
ctx = C.zmq_ctx_new()
sender = C.zmq_socket(ctx, REQ)
receiver = C.zmq_socket(ctx, REP)
r1 = C.zmq_bind(receiver, b'tcp://*:3333')
r2 = C.zmq_connect(sender, b'tcp://127.0.0.1:3333')
zmq_msg = ffi.new('zmq_msg_t*')
message = ffi.new('char[5]', b'Hello')
C.zmq_msg_init_data(zmq_msg,
ffi.cast('void*', message),
ffi.cast('size_t', 5),
ffi.NULL,
ffi.NULL)
receiver_pollitem = ffi.new('zmq_pollitem_t*')
receiver_pollitem.socket = receiver
receiver_pollitem.fd = 0
receiver_pollitem.events = POLLIN | POLLOUT
receiver_pollitem.revents = 0
ret = C.zmq_poll(ffi.NULL, 0, 0)
assert ret == 0
ret = C.zmq_poll(receiver_pollitem, 1, 0)
assert ret == 0
ret = C.zmq_msg_send(zmq_msg, sender, 0)
print(ffi.string(C.zmq_strerror(C.zmq_errno())))
assert ret == 5
time.sleep(0.2)
ret = C.zmq_poll(receiver_pollitem, 1, 0)
assert ret == 1
assert int(receiver_pollitem.revents) & POLLIN
assert not int(receiver_pollitem.revents) & POLLOUT
zmq_msg2 = ffi.new('zmq_msg_t*')
C.zmq_msg_init(zmq_msg2)
ret_recv = C.zmq_msg_recv(zmq_msg2, receiver, 0)
assert ret_recv == 5
assert 5 == C.zmq_msg_size(zmq_msg2)
assert b"Hello" == ffi.buffer(C.zmq_msg_data(zmq_msg2),
C.zmq_msg_size(zmq_msg2))[:]
sender_pollitem = ffi.new('zmq_pollitem_t*')
sender_pollitem.socket = sender
sender_pollitem.fd = 0
sender_pollitem.events = POLLIN | POLLOUT
sender_pollitem.revents = 0
ret = C.zmq_poll(sender_pollitem, 1, 0)
assert ret == 0
zmq_msg_again = ffi.new('zmq_msg_t*')
message_again = ffi.new('char[11]', b'Hello Again')
C.zmq_msg_init_data(zmq_msg_again,
ffi.cast('void*', message_again),
ffi.cast('size_t', 11),
ffi.NULL,
ffi.NULL)
assert 11 == C.zmq_msg_send(zmq_msg_again, receiver, 0)
time.sleep(0.2)
assert 0 <= C.zmq_poll(sender_pollitem, 1, 0)
assert int(sender_pollitem.revents) & POLLIN
assert 11 == C.zmq_msg_recv(zmq_msg2, sender, 0)
assert 11 == C.zmq_msg_size(zmq_msg2)
assert b"Hello Again" == ffi.buffer(C.zmq_msg_data(zmq_msg2),
int(C.zmq_msg_size(zmq_msg2)))[:]
assert 0 == C.zmq_close(sender)
assert 0 == C.zmq_close(receiver)
assert 0 == C.zmq_ctx_destroy(ctx)
assert 0 == C.zmq_msg_close(zmq_msg)
assert 0 == C.zmq_msg_close(zmq_msg2)
assert 0 == C.zmq_msg_close(zmq_msg_again)

View file

@ -0,0 +1,121 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import json
from unittest import TestCase
import pytest
import zmq
from zmq.utils import constant_names
from zmq.sugar import constants as sugar_constants
from zmq.backend import constants as backend_constants
all_set = set(constant_names.all_names)
class TestConstants(TestCase):
def _duplicate_test(self, namelist, listname):
"""test that a given list has no duplicates"""
dupes = {}
for name in set(namelist):
cnt = namelist.count(name)
if cnt > 1:
dupes[name] = cnt
if dupes:
self.fail("The following names occur more than once in %s: %s" % (listname, json.dumps(dupes, indent=2)))
def test_duplicate_all(self):
return self._duplicate_test(constant_names.all_names, "all_names")
def _change_key(self, change, version):
"""return changed-in key"""
return "%s-in %d.%d.%d" % tuple([change] + list(version))
def test_duplicate_changed(self):
all_changed = []
for change in ("new", "removed"):
d = getattr(constant_names, change + "_in")
for version, namelist in d.items():
all_changed.extend(namelist)
self._duplicate_test(namelist, self._change_key(change, version))
self._duplicate_test(all_changed, "all-changed")
def test_changed_in_all(self):
missing = {}
for change in ("new", "removed"):
d = getattr(constant_names, change + "_in")
for version, namelist in d.items():
key = self._change_key(change, version)
for name in namelist:
if name not in all_set:
if key not in missing:
missing[key] = []
missing[key].append(name)
if missing:
self.fail(
"The following names are missing in `all_names`: %s" % json.dumps(missing, indent=2)
)
def test_no_negative_constants(self):
for name in sugar_constants.__all__:
self.assertNotEqual(getattr(zmq, name), sugar_constants._UNDEFINED)
def test_undefined_constants(self):
all_aliases = []
for alias_group in sugar_constants.aliases:
all_aliases.extend(alias_group)
for name in all_set.difference(all_aliases):
raw = getattr(backend_constants, name)
if raw == sugar_constants._UNDEFINED:
self.assertRaises(AttributeError, getattr, zmq, name)
else:
self.assertEqual(getattr(zmq, name), raw)
def test_new(self):
zmq_version = zmq.zmq_version_info()
for version, new_names in constant_names.new_in.items():
should_have = zmq_version >= version
for name in new_names:
try:
value = getattr(zmq, name)
except AttributeError:
if should_have:
self.fail("AttributeError: zmq.%s" % name)
else:
if not should_have:
self.fail("Shouldn't have: zmq.%s=%s" % (name, value))
@pytest.mark.skipif(not zmq.DRAFT_API, reason="Only test draft API if built with draft API")
def test_draft(self):
zmq_version = zmq.zmq_version_info()
for version, new_names in constant_names.draft_in.items():
should_have = zmq_version >= version
for name in new_names:
try:
value = getattr(zmq, name)
except AttributeError:
if should_have:
self.fail("AttributeError: zmq.%s" % name)
else:
if not should_have:
self.fail("Shouldn't have: zmq.%s=%s" % (name, value))
def test_removed(self):
zmq_version = zmq.zmq_version_info()
for version, new_names in constant_names.removed_in.items():
should_have = zmq_version < version
for name in new_names:
try:
value = getattr(zmq, name)
except AttributeError:
if should_have:
self.fail("AttributeError: zmq.%s" % name)
else:
if not should_have:
self.fail("Shouldn't have: zmq.%s=%s" % (name, value))

View file

@ -0,0 +1,392 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import copy
import gc
import os
import sys
import time
from threading import Thread, Event
try:
from queue import Queue
except ImportError:
from Queue import Queue
try:
from unittest import mock
except ImportError:
mock = None
from pytest import mark
import zmq
from zmq.tests import (
BaseZMQTestCase, have_gevent, GreenTest, skip_green, PYPY, SkipTest,
)
class KwargTestSocket(zmq.Socket):
test_kwarg_value = None
def __init__(self, *args, **kwargs):
self.test_kwarg_value = kwargs.pop('test_kwarg', None)
super(KwargTestSocket, self).__init__(*args, **kwargs)
class KwargTestContext(zmq.Context):
_socket_class = KwargTestSocket
class TestContext(BaseZMQTestCase):
def test_init(self):
c1 = self.Context()
self.assert_(isinstance(c1, self.Context))
del c1
c2 = self.Context()
self.assert_(isinstance(c2, self.Context))
del c2
c3 = self.Context()
self.assert_(isinstance(c3, self.Context))
del c3
def test_dir(self):
ctx = self.Context()
self.assertTrue('socket' in dir(ctx))
if zmq.zmq_version_info() > (3,):
self.assertTrue('IO_THREADS' in dir(ctx))
ctx.term()
@mark.skipif(mock is None, reason="requires unittest.mock")
def test_mockable(self):
m = mock.Mock(spec=self.context)
def test_term(self):
c = self.Context()
c.term()
self.assert_(c.closed)
def test_context_manager(self):
with self.Context() as c:
pass
self.assert_(c.closed)
def test_fail_init(self):
self.assertRaisesErrno(zmq.EINVAL, self.Context, -1)
def test_term_hang(self):
rep,req = self.create_bound_pair(zmq.ROUTER, zmq.DEALER)
req.setsockopt(zmq.LINGER, 0)
req.send(b'hello', copy=False)
req.close()
rep.close()
self.context.term()
def test_instance(self):
ctx = self.Context.instance()
c2 = self.Context.instance(io_threads=2)
self.assertTrue(c2 is ctx)
c2.term()
c3 = self.Context.instance()
c4 = self.Context.instance()
self.assertFalse(c3 is c2)
self.assertFalse(c3.closed)
self.assertTrue(c3 is c4)
def test_instance_subclass_first(self):
self.context.term()
class SubContext(zmq.Context):
pass
sctx = SubContext.instance()
ctx = zmq.Context.instance()
ctx.term()
sctx.term()
assert type(ctx) is zmq.Context
assert type(sctx) is SubContext
def test_instance_subclass_second(self):
self.context.term()
class SubContextInherit(zmq.Context):
pass
class SubContextNoInherit(zmq.Context):
_instance = None
pass
ctx = zmq.Context.instance()
sctx = SubContextInherit.instance()
sctx2 = SubContextNoInherit.instance()
ctx.term()
sctx.term()
sctx2.term()
assert type(ctx) is zmq.Context
assert type(sctx) is zmq.Context
assert type(sctx2) is SubContextNoInherit
def test_instance_threadsafe(self):
self.context.term() # clear default context
q = Queue()
# slow context initialization,
# to ensure that we are both trying to create one at the same time
class SlowContext(self.Context):
def __init__(self, *a, **kw):
time.sleep(1)
super(SlowContext, self).__init__(*a, **kw)
def f():
q.put(SlowContext.instance())
# call ctx.instance() in several threads at once
N = 16
threads = [ Thread(target=f) for i in range(N) ]
[ t.start() for t in threads ]
# also call it in the main thread (not first)
ctx = SlowContext.instance()
assert isinstance(ctx, SlowContext)
# check that all the threads got the same context
for i in range(N):
thread_ctx = q.get(timeout=5)
assert thread_ctx is ctx
# cleanup
ctx.term()
[ t.join(timeout=5) for t in threads ]
def test_socket_passes_kwargs(self):
test_kwarg_value = 'testing one two three'
with KwargTestContext() as ctx:
with ctx.socket(zmq.DEALER, test_kwarg=test_kwarg_value) as socket:
self.assertTrue(socket.test_kwarg_value is test_kwarg_value)
def test_many_sockets(self):
"""opening and closing many sockets shouldn't cause problems"""
ctx = self.Context()
for i in range(16):
sockets = [ ctx.socket(zmq.REP) for i in range(65) ]
[ s.close() for s in sockets ]
# give the reaper a chance
time.sleep(1e-2)
ctx.term()
def test_sockopts(self):
"""setting socket options with ctx attributes"""
ctx = self.Context()
ctx.linger = 5
self.assertEqual(ctx.linger, 5)
s = ctx.socket(zmq.REQ)
self.assertEqual(s.linger, 5)
self.assertEqual(s.getsockopt(zmq.LINGER), 5)
s.close()
# check that subscribe doesn't get set on sockets that don't subscribe:
ctx.subscribe = b''
s = ctx.socket(zmq.REQ)
s.close()
ctx.term()
@mark.skipif(
sys.platform.startswith('win'),
reason='Segfaults on Windows')
def test_destroy(self):
"""Context.destroy should close sockets"""
ctx = self.Context()
sockets = [ ctx.socket(zmq.REP) for i in range(65) ]
# close half of the sockets
[ s.close() for s in sockets[::2] ]
ctx.destroy()
# reaper is not instantaneous
time.sleep(1e-2)
for s in sockets:
self.assertTrue(s.closed)
def test_destroy_linger(self):
"""Context.destroy should set linger on closing sockets"""
req,rep = self.create_bound_pair(zmq.REQ, zmq.REP)
req.send(b'hi')
time.sleep(1e-2)
self.context.destroy(linger=0)
# reaper is not instantaneous
time.sleep(1e-2)
for s in (req,rep):
self.assertTrue(s.closed)
def test_term_noclose(self):
"""Context.term won't close sockets"""
ctx = self.Context()
s = ctx.socket(zmq.REQ)
self.assertFalse(s.closed)
t = Thread(target=ctx.term)
t.start()
t.join(timeout=0.1)
self.assertTrue(t.is_alive(), "Context should be waiting")
s.close()
t.join(timeout=0.1)
self.assertFalse(t.is_alive(), "Context should have closed")
def test_gc(self):
"""test close&term by garbage collection alone"""
if PYPY:
raise SkipTest("GC doesn't work ")
# test credit @dln (GH #137):
def gcf():
def inner():
ctx = self.Context()
s = ctx.socket(zmq.PUSH)
inner()
gc.collect()
t = Thread(target=gcf)
t.start()
t.join(timeout=1)
self.assertFalse(t.is_alive(), "Garbage collection should have cleaned up context")
def test_cyclic_destroy(self):
"""ctx.destroy should succeed when cyclic ref prevents gc"""
# test credit @dln (GH #137):
class CyclicReference(object):
def __init__(self, parent=None):
self.parent = parent
def crash(self, sock):
self.sock = sock
self.child = CyclicReference(self)
def crash_zmq():
ctx = self.Context()
sock = ctx.socket(zmq.PULL)
c = CyclicReference()
c.crash(sock)
ctx.destroy()
crash_zmq()
def test_term_thread(self):
"""ctx.term should not crash active threads (#139)"""
ctx = self.Context()
evt = Event()
evt.clear()
def block():
s = ctx.socket(zmq.REP)
s.bind_to_random_port('tcp://127.0.0.1')
evt.set()
try:
s.recv()
except zmq.ZMQError as e:
self.assertEqual(e.errno, zmq.ETERM)
return
finally:
s.close()
self.fail("recv should have been interrupted with ETERM")
t = Thread(target=block)
t.start()
evt.wait(1)
self.assertTrue(evt.is_set(), "sync event never fired")
time.sleep(0.01)
ctx.term()
t.join(timeout=1)
self.assertFalse(t.is_alive(), "term should have interrupted s.recv()")
def test_destroy_no_sockets(self):
ctx = self.Context()
s = ctx.socket(zmq.PUB)
s.bind_to_random_port('tcp://127.0.0.1')
s.close()
ctx.destroy()
assert s.closed
assert ctx.closed
def test_ctx_opts(self):
if zmq.zmq_version_info() < (3,):
raise SkipTest("context options require libzmq 3")
ctx = self.Context()
ctx.set(zmq.MAX_SOCKETS, 2)
self.assertEqual(ctx.get(zmq.MAX_SOCKETS), 2)
ctx.max_sockets = 100
self.assertEqual(ctx.max_sockets, 100)
self.assertEqual(ctx.get(zmq.MAX_SOCKETS), 100)
def test_copy(self):
c1 = self.Context()
c2 = copy.copy(c1)
c2b = copy.deepcopy(c1)
c3 = copy.deepcopy(c2)
self.assert_(c2._shadow)
self.assert_(c3._shadow)
self.assertEqual(c1.underlying, c2.underlying)
self.assertEqual(c1.underlying, c3.underlying)
self.assertEqual(c1.underlying, c2b.underlying)
s = c3.socket(zmq.PUB)
s.close()
c1.term()
def test_shadow(self):
ctx = self.Context()
ctx2 = self.Context.shadow(ctx.underlying)
self.assertEqual(ctx.underlying, ctx2.underlying)
s = ctx.socket(zmq.PUB)
s.close()
del ctx2
self.assertFalse(ctx.closed)
s = ctx.socket(zmq.PUB)
ctx2 = self.Context.shadow(ctx.underlying)
s2 = ctx2.socket(zmq.PUB)
s.close()
s2.close()
ctx.term()
self.assertRaisesErrno(zmq.EFAULT, ctx2.socket, zmq.PUB)
del ctx2
def test_shadow_pyczmq(self):
try:
from pyczmq import zctx, zsocket, zstr
except Exception:
raise SkipTest("Requires pyczmq")
ctx = zctx.new()
a = zsocket.new(ctx, zmq.PUSH)
zsocket.bind(a, "inproc://a")
ctx2 = self.Context.shadow_pyczmq(ctx)
b = ctx2.socket(zmq.PULL)
b.connect("inproc://a")
zstr.send(a, b'hi')
rcvd = self.recv(b)
self.assertEqual(rcvd, b'hi')
b.close()
@mark.skipif(
sys.platform.startswith('win'),
reason='No fork on Windows')
def test_fork_instance(self):
ctx = self.Context.instance()
parent_ctx_id = id(ctx)
r_fd, w_fd = os.pipe()
reader = os.fdopen(r_fd, 'r')
child_pid = os.fork()
if child_pid == 0:
ctx = self.Context.instance()
writer = os.fdopen(w_fd, 'w')
child_ctx_id = id(ctx)
ctx.term()
writer.write(str(child_ctx_id) + "\n")
writer.flush()
writer.close()
os._exit(0)
else:
os.close(w_fd)
child_id_s = reader.readline()
reader.close()
assert child_id_s
assert int(child_id_s) != parent_ctx_id
ctx.term()
if False: # disable green context tests
class TestContextGreen(GreenTest, TestContext):
"""gevent subclass of context tests"""
# skip tests that use real threads:
test_gc = GreenTest.skip_green
test_term_thread = GreenTest.skip_green
test_destroy_linger = GreenTest.skip_green

View file

@ -0,0 +1,41 @@
import os
import sys
import pytest
import zmq
@pytest.mark.skipif(
'zmq.backend.cython' not in sys.modules, reason="Requires cython backend"
)
@pytest.mark.skipif(
sys.platform.startswith('win'), reason="Don't try runtime Cython on Windows"
)
@pytest.mark.parametrize('language_level', [3, 2])
def test_cython(language_level, request, tmpdir):
import pyximport
assert 'zmq.tests.cython_ext' not in sys.modules
importers = pyximport.install(
setup_args=dict(include_dirs=zmq.get_includes()),
language_level=language_level,
build_dir=str(tmpdir),
)
cython_ext = None
def unimport():
pyximport.uninstall(*importers)
sys.modules.pop('zmq.tests.cython_ext', None)
request.addfinalizer(unimport)
# this import tests the compilation
from . import cython_ext
assert hasattr(cython_ext, 'send_recv_test')
# call the compiled function
# this shouldn't do much
msg = b'my msg'
received = cython_ext.send_recv_test(msg)
assert received == msg

View file

@ -0,0 +1,375 @@
import threading
import zmq
from pytest import raises
from zmq.decorators import context, socket
##############################################
# Test cases for @context
##############################################
def test_ctx():
@context()
def test(ctx):
assert isinstance(ctx, zmq.Context), ctx
test()
def test_ctx_orig_args():
@context()
def f(foo, bar, ctx, baz=None):
assert isinstance(ctx, zmq.Context), ctx
assert foo == 42
assert bar is True
assert baz == 'mock'
f(42, True, baz='mock')
def test_ctx_arg_naming():
@context('myctx')
def test(myctx):
assert isinstance(myctx, zmq.Context), myctx
test()
def test_ctx_args():
@context('ctx', 5)
def test(ctx):
assert isinstance(ctx, zmq.Context), ctx
assert ctx.IO_THREADS == 5, ctx.IO_THREADS
test()
def test_ctx_arg_kwarg():
@context('ctx', io_threads=5)
def test(ctx):
assert isinstance(ctx, zmq.Context), ctx
assert ctx.IO_THREADS == 5, ctx.IO_THREADS
test()
def test_ctx_kw_naming():
@context(name='myctx')
def test(myctx):
assert isinstance(myctx, zmq.Context), myctx
test()
def test_ctx_kwargs():
@context(name='ctx', io_threads=5)
def test(ctx):
assert isinstance(ctx, zmq.Context), ctx
assert ctx.IO_THREADS == 5, ctx.IO_THREADS
test()
def test_ctx_kwargs_default():
@context(name='ctx', io_threads=5)
def test(ctx=None):
assert isinstance(ctx, zmq.Context), ctx
assert ctx.IO_THREADS == 5, ctx.IO_THREADS
test()
def test_ctx_keyword_miss():
@context(name='ctx')
def test(other_name):
pass # the keyword ``ctx`` not found
with raises(TypeError):
test()
def test_ctx_multi_assign():
@context(name='ctx')
def test(ctx):
pass # explosion
with raises(TypeError):
test('mock')
def test_ctx_reinit():
result = {'foo': None, 'bar': None}
@context()
def f(key, ctx):
assert isinstance(ctx, zmq.Context), ctx
result[key] = ctx
foo_t = threading.Thread(target=f, args=('foo',))
bar_t = threading.Thread(target=f, args=('bar',))
foo_t.start()
bar_t.start()
foo_t.join()
bar_t.join()
assert result['foo'] is not None, result
assert result['bar'] is not None, result
assert result['foo'] is not result['bar'], result
def test_ctx_multi_thread():
@context()
@context()
def f(foo, bar):
assert isinstance(foo, zmq.Context), foo
assert isinstance(bar, zmq.Context), bar
assert len(set(map(id, [foo, bar]))) == 2, set(map(id, [foo, bar]))
threads = [threading.Thread(target=f) for i in range(8)]
[t.start() for t in threads]
[t.join() for t in threads]
##############################################
# Test cases for @socket
##############################################
def test_ctx_skt():
@context()
@socket(zmq.PUB)
def test(ctx, skt):
assert isinstance(ctx, zmq.Context), ctx
assert isinstance(skt, zmq.Socket), skt
assert skt.type == zmq.PUB
test()
def test_skt_name():
@context()
@socket('myskt', zmq.PUB)
def test(ctx, myskt):
assert isinstance(myskt, zmq.Socket), myskt
assert isinstance(ctx, zmq.Context), ctx
assert myskt.type == zmq.PUB
test()
def test_skt_kwarg():
@context()
@socket(zmq.PUB, name='myskt')
def test(ctx, myskt):
assert isinstance(myskt, zmq.Socket), myskt
assert isinstance(ctx, zmq.Context), ctx
assert myskt.type == zmq.PUB
test()
def test_ctx_skt_name():
@context('ctx')
@socket('skt', zmq.PUB, context_name='ctx')
def test(ctx, skt):
assert isinstance(skt, zmq.Socket), skt
assert isinstance(ctx, zmq.Context), ctx
assert skt.type == zmq.PUB
test()
def test_skt_default_ctx():
@socket(zmq.PUB)
def test(skt):
assert isinstance(skt, zmq.Socket), skt
assert skt.context is zmq.Context.instance()
assert skt.type == zmq.PUB
test()
def test_skt_reinit():
result = {'foo': None, 'bar': None}
@socket(zmq.PUB)
def f(key, skt):
assert isinstance(skt, zmq.Socket), skt
result[key] = skt
foo_t = threading.Thread(target=f, args=('foo',))
bar_t = threading.Thread(target=f, args=('bar',))
foo_t.start()
bar_t.start()
foo_t.join()
bar_t.join()
assert result['foo'] is not None, result
assert result['bar'] is not None, result
assert result['foo'] is not result['bar'], result
def test_ctx_skt_reinit():
result = {'foo': {'ctx': None, 'skt': None},
'bar': {'ctx': None, 'skt': None}}
@context()
@socket(zmq.PUB)
def f(key, ctx, skt):
assert isinstance(ctx, zmq.Context), ctx
assert isinstance(skt, zmq.Socket), skt
result[key]['ctx'] = ctx
result[key]['skt'] = skt
foo_t = threading.Thread(target=f, args=('foo',))
bar_t = threading.Thread(target=f, args=('bar',))
foo_t.start()
bar_t.start()
foo_t.join()
bar_t.join()
assert result['foo']['ctx'] is not None, result
assert result['foo']['skt'] is not None, result
assert result['bar']['ctx'] is not None, result
assert result['bar']['skt'] is not None, result
assert result['foo']['ctx'] is not result['bar']['ctx'], result
assert result['foo']['skt'] is not result['bar']['skt'], result
def test_skt_type_miss():
@context()
@socket('myskt')
def f(ctx, myskt):
pass # the socket type is missing
with raises(TypeError):
f()
def test_multi_skts():
@socket(zmq.PUB)
@socket(zmq.SUB)
@socket(zmq.PUSH)
def test(pub, sub, push):
assert isinstance(pub, zmq.Socket), pub
assert isinstance(sub, zmq.Socket), sub
assert isinstance(push, zmq.Socket), push
assert pub.context is zmq.Context.instance()
assert sub.context is zmq.Context.instance()
assert push.context is zmq.Context.instance()
assert pub.type == zmq.PUB
assert sub.type == zmq.SUB
assert push.type == zmq.PUSH
test()
def test_multi_skts_single_ctx():
@context()
@socket(zmq.PUB)
@socket(zmq.SUB)
@socket(zmq.PUSH)
def test(ctx, pub, sub, push):
assert isinstance(ctx, zmq.Context), ctx
assert isinstance(pub, zmq.Socket), pub
assert isinstance(sub, zmq.Socket), sub
assert isinstance(push, zmq.Socket), push
assert pub.context is ctx
assert sub.context is ctx
assert push.context is ctx
assert pub.type == zmq.PUB
assert sub.type == zmq.SUB
assert push.type == zmq.PUSH
test()
def test_multi_skts_with_name():
@socket('foo', zmq.PUSH)
@socket('bar', zmq.SUB)
@socket('baz', zmq.PUB)
def test(foo, bar, baz):
assert isinstance(foo, zmq.Socket), foo
assert isinstance(bar, zmq.Socket), bar
assert isinstance(baz, zmq.Socket), baz
assert foo.context is zmq.Context.instance()
assert bar.context is zmq.Context.instance()
assert baz.context is zmq.Context.instance()
assert foo.type == zmq.PUSH
assert bar.type == zmq.SUB
assert baz.type == zmq.PUB
test()
def test_func_return():
@context()
def f(ctx):
assert isinstance(ctx, zmq.Context), ctx
return 'something'
assert f() == 'something'
def test_skt_multi_thread():
@socket(zmq.PUB)
@socket(zmq.SUB)
@socket(zmq.PUSH)
def f(pub, sub, push):
assert isinstance(pub, zmq.Socket), pub
assert isinstance(sub, zmq.Socket), sub
assert isinstance(push, zmq.Socket), push
assert pub.context is zmq.Context.instance()
assert sub.context is zmq.Context.instance()
assert push.context is zmq.Context.instance()
assert pub.type == zmq.PUB
assert sub.type == zmq.SUB
assert push.type == zmq.PUSH
assert len(set(map(id, [pub, sub, push]))) == 3
threads = [threading.Thread(target=f) for i in range(8)]
[t.start() for t in threads]
[t.join() for t in threads]
class TestMethodDecorators():
@context()
@socket(zmq.PUB)
@socket(zmq.SUB)
def multi_skts_method(self, ctx, pub, sub, foo='bar'):
assert isinstance(self, TestMethodDecorators), self
assert isinstance(ctx, zmq.Context), ctx
assert isinstance(pub, zmq.Socket), pub
assert isinstance(sub, zmq.Socket), sub
assert foo == 'bar'
assert pub.context is ctx
assert sub.context is ctx
assert pub.type is zmq.PUB
assert sub.type is zmq.SUB
def test_multi_skts_method(self):
self.multi_skts_method()
def multi_skts_method_other_args(self):
@socket(zmq.PUB)
@socket(zmq.SUB)
def f(foo, pub, sub, bar=None):
assert isinstance(pub, zmq.Socket), pub
assert isinstance(sub, zmq.Socket), sub
assert foo == 'mock'
assert bar == 'fake'
assert pub.context is zmq.Context.instance()
assert sub.context is zmq.Context.instance()
assert pub.type is zmq.PUB
assert sub.type is zmq.SUB
f('mock', bar='fake')
def test_multi_skts_method_other_args(self):
self.multi_skts_method_other_args()

View file

@ -0,0 +1,167 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import time
import zmq
from zmq import devices
from zmq.tests import BaseZMQTestCase, SkipTest, have_gevent, GreenTest, PYPY
from zmq.utils.strtypes import (bytes,unicode,basestring)
if PYPY:
# cleanup of shared Context doesn't work on PyPy
devices.Device.context_factory = zmq.Context
class TestDevice(BaseZMQTestCase):
def test_device_types(self):
for devtype in (zmq.STREAMER, zmq.FORWARDER, zmq.QUEUE):
dev = devices.Device(devtype, zmq.PAIR, zmq.PAIR)
self.assertEqual(dev.device_type, devtype)
del dev
def test_device_attributes(self):
dev = devices.Device(zmq.QUEUE, zmq.SUB, zmq.PUB)
self.assertEqual(dev.in_type, zmq.SUB)
self.assertEqual(dev.out_type, zmq.PUB)
self.assertEqual(dev.device_type, zmq.QUEUE)
self.assertEqual(dev.daemon, True)
del dev
def test_single_socket_forwarder_connect(self):
if zmq.zmq_version() in ('4.1.1', '4.0.6'):
raise SkipTest("libzmq-%s broke single-socket devices" % zmq.zmq_version())
dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1)
req = self.context.socket(zmq.REQ)
port = req.bind_to_random_port('tcp://127.0.0.1')
dev.connect_in('tcp://127.0.0.1:%i'%port)
dev.start()
time.sleep(.25)
msg = b'hello'
req.send(msg)
self.assertEqual(msg, self.recv(req))
del dev
req.close()
dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1)
req = self.context.socket(zmq.REQ)
port = req.bind_to_random_port('tcp://127.0.0.1')
dev.connect_out('tcp://127.0.0.1:%i'%port)
dev.start()
time.sleep(.25)
msg = b'hello again'
req.send(msg)
self.assertEqual(msg, self.recv(req))
del dev
req.close()
def test_single_socket_forwarder_bind(self):
if zmq.zmq_version() in ('4.1.1', '4.0.6'):
raise SkipTest("libzmq-%s broke single-socket devices" % zmq.zmq_version())
dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1)
port = dev.bind_in_to_random_port('tcp://127.0.0.1')
req = self.context.socket(zmq.REQ)
req.connect('tcp://127.0.0.1:%i'%port)
dev.start()
time.sleep(.25)
msg = b'hello'
req.send(msg)
self.assertEqual(msg, self.recv(req))
del dev
req.close()
dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1)
port = dev.bind_in_to_random_port('tcp://127.0.0.1')
req = self.context.socket(zmq.REQ)
req.connect('tcp://127.0.0.1:%i'%port)
dev.start()
time.sleep(.25)
msg = b'hello again'
req.send(msg)
self.assertEqual(msg, self.recv(req))
del dev
req.close()
def test_device_bind_to_random_with_args(self):
dev = devices.ThreadDevice(zmq.PULL, zmq.PUSH, -1)
iface = 'tcp://127.0.0.1'
ports = []
min, max = 5000, 5050
ports.extend([
dev.bind_in_to_random_port(iface, min_port=min, max_port=max),
dev.bind_out_to_random_port(iface, min_port=min, max_port=max)
])
for port in ports:
if port < min or port > max:
self.fail('Unexpected port number: %i' % port)
def test_device_bind_to_random_binderror(self):
dev = devices.ThreadDevice(zmq.PULL, zmq.PUSH, -1)
iface = 'tcp://127.0.0.1'
try:
for i in range(11):
dev.bind_in_to_random_port(
iface, min_port=10000, max_port=10010
)
except zmq.ZMQBindError as e:
return
else:
self.fail('Should have failed')
def test_proxy(self):
if zmq.zmq_version_info() < (3,2):
raise SkipTest("Proxies only in libzmq >= 3")
dev = devices.ThreadProxy(zmq.PULL, zmq.PUSH, zmq.PUSH)
iface = 'tcp://127.0.0.1'
port = dev.bind_in_to_random_port(iface)
port2 = dev.bind_out_to_random_port(iface)
port3 = dev.bind_mon_to_random_port(iface)
dev.start()
time.sleep(0.25)
msg = b'hello'
push = self.context.socket(zmq.PUSH)
push.connect("%s:%i" % (iface, port))
pull = self.context.socket(zmq.PULL)
pull.connect("%s:%i" % (iface, port2))
mon = self.context.socket(zmq.PULL)
mon.connect("%s:%i" % (iface, port3))
push.send(msg)
self.sockets.extend([push, pull, mon])
self.assertEqual(msg, self.recv(pull))
self.assertEqual(msg, self.recv(mon))
def test_proxy_bind_to_random_with_args(self):
if zmq.zmq_version_info() < (3, 2):
raise SkipTest("Proxies only in libzmq >= 3")
dev = devices.ThreadProxy(zmq.PULL, zmq.PUSH, zmq.PUSH)
iface = 'tcp://127.0.0.1'
ports = []
min, max = 5000, 5050
ports.extend([
dev.bind_in_to_random_port(iface, min_port=min, max_port=max),
dev.bind_out_to_random_port(iface, min_port=min, max_port=max),
dev.bind_mon_to_random_port(iface, min_port=min, max_port=max)
])
for port in ports:
if port < min or port > max:
self.fail('Unexpected port number: %i' % port)
if have_gevent:
import gevent
import zmq.green
class TestDeviceGreen(GreenTest, BaseZMQTestCase):
def test_green_device(self):
rep = self.context.socket(zmq.REP)
req = self.context.socket(zmq.REQ)
self.sockets.extend([req, rep])
port = rep.bind_to_random_port('tcp://127.0.0.1')
g = gevent.spawn(zmq.green.device, zmq.QUEUE, rep, rep)
req.connect('tcp://127.0.0.1:%i' % port)
req.send(b'hi')
timeout = gevent.Timeout(3)
timeout.start()
receiver = gevent.spawn(req.recv)
self.assertEqual(receiver.get(2), b'hi')
timeout.cancel()
g.kill(block=True)

View file

@ -0,0 +1,52 @@
# -*- coding: utf8 -*-
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import os
import platform
import time
import pytest
import zmq
from zmq.tests import (
BaseZMQTestCase, skip_pypy
)
class TestDraftSockets(BaseZMQTestCase):
def setUp(self):
if not zmq.DRAFT_API:
raise pytest.skip("draft api unavailable")
super(TestDraftSockets, self).setUp()
def test_client_server(self):
client, server = self.create_bound_pair(zmq.CLIENT, zmq.SERVER)
client.send(b'request')
msg = self.recv(server, copy=False)
assert msg.routing_id is not None
server.send(b'reply', routing_id=msg.routing_id)
reply = self.recv(client)
assert reply == b'reply'
def test_radio_dish(self):
dish, radio = self.create_bound_pair(zmq.DISH, zmq.RADIO)
dish.rcvtimeo = 250
group = 'mygroup'
dish.join(group)
received_count = 0
received = set()
sent = set()
for i in range(10):
msg = str(i).encode('ascii')
sent.add(msg)
radio.send(msg, group=group)
try:
recvd = dish.recv()
except zmq.Again:
time.sleep(0.1)
else:
received.add(recvd)
received_count += 1
# assert that we got *something*
assert len(received.intersection(sent)) >= 5

View file

@ -0,0 +1,43 @@
# -*- coding: utf8 -*-
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import sys
import time
from threading import Thread
import zmq
from zmq import ZMQError, strerror, Again, ContextTerminated
from zmq.tests import BaseZMQTestCase
if sys.version_info[0] >= 3:
long = int
class TestZMQError(BaseZMQTestCase):
def test_strerror(self):
"""test that strerror gets the right type."""
for i in range(10):
e = strerror(i)
self.assertTrue(isinstance(e, str))
def test_zmqerror(self):
for errno in range(10):
e = ZMQError(errno)
self.assertEqual(e.errno, errno)
self.assertEqual(str(e), strerror(errno))
def test_again(self):
s = self.context.socket(zmq.REP)
self.assertRaises(Again, s.recv, zmq.NOBLOCK)
self.assertRaisesErrno(zmq.EAGAIN, s.recv, zmq.NOBLOCK)
s.close()
def atest_ctxterm(self):
s = self.context.socket(zmq.REP)
t = Thread(target=self.context.term)
t.start()
self.assertRaises(ContextTerminated, s.recv, zmq.NOBLOCK)
self.assertRaisesErrno(zmq.TERM, s.recv, zmq.NOBLOCK)
s.close()
t.join()

View file

@ -0,0 +1,20 @@
# Copyright (c) PyZMQ Developers.
# Distributed under the terms of the Modified BSD License.
import sys
import zmq
from pytest import mark
@mark.skipif('zmq.zmq_version_info() < (4,1)')
def test_has():
assert not zmq.has('something weird')
has_ipc = zmq.has('ipc')
not_windows = not sys.platform.startswith('win')
assert has_ipc == not_windows
@mark.skipif(not hasattr(zmq, '_libzmq'), reason="bundled libzmq")
def test_has_curve():
"""bundled libzmq has curve support"""
assert zmq.has('curve')

View file

@ -0,0 +1,353 @@
# coding: utf-8
# Copyright (c) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from datetime import timedelta
import os
import json
import sys
import pytest
gen = pytest.importorskip('tornado.gen')
import zmq
from zmq.eventloop import future
from tornado.ioloop import IOLoop
from zmq.utils.strtypes import u
from zmq.tests import BaseZMQTestCase
class TestFutureSocket(BaseZMQTestCase):
Context = future.Context
def setUp(self):
self.loop = IOLoop()
self.loop.make_current()
super(TestFutureSocket, self).setUp()
def tearDown(self):
super(TestFutureSocket, self).tearDown()
if self.loop:
self.loop.close(all_fds=True)
IOLoop.clear_current()
IOLoop.clear_instance()
def test_socket_class(self):
s = self.context.socket(zmq.PUSH)
assert isinstance(s, future.Socket)
s.close()
def test_instance_subclass_first(self):
actx = self.Context.instance()
ctx = zmq.Context.instance()
ctx.term()
actx.term()
assert type(ctx) is zmq.Context
assert type(actx) is self.Context
def test_instance_subclass_second(self):
ctx = zmq.Context.instance()
actx = self.Context.instance()
ctx.term()
actx.term()
assert type(ctx) is zmq.Context
assert type(actx) is self.Context
def test_recv_multipart(self):
@gen.coroutine
def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_multipart()
assert not f.done()
yield a.send(b'hi')
recvd = yield f
self.assertEqual(recvd, [b'hi'])
self.loop.run_sync(test)
def test_recv(self):
@gen.coroutine
def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f1 = b.recv()
f2 = b.recv()
assert not f1.done()
assert not f2.done()
yield a.send_multipart([b'hi', b'there'])
recvd = yield f2
assert f1.done()
self.assertEqual(f1.result(), b'hi')
self.assertEqual(recvd, b'there')
self.loop.run_sync(test)
def test_recv_cancel(self):
@gen.coroutine
def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f1 = b.recv()
f2 = b.recv_multipart()
assert f1.cancel()
assert f1.done()
assert not f2.done()
yield a.send_multipart([b'hi', b'there'])
recvd = yield f2
assert f1.cancelled()
assert f2.done()
self.assertEqual(recvd, [b'hi', b'there'])
self.loop.run_sync(test)
@pytest.mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
def test_recv_timeout(self):
@gen.coroutine
def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
b.rcvtimeo = 100
f1 = b.recv()
b.rcvtimeo = 1000
f2 = b.recv_multipart()
with pytest.raises(zmq.Again):
yield f1
yield a.send_multipart([b'hi', b'there'])
recvd = yield f2
assert f2.done()
self.assertEqual(recvd, [b'hi', b'there'])
self.loop.run_sync(test)
@pytest.mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
def test_send_timeout(self):
@gen.coroutine
def test():
s = self.socket(zmq.PUSH)
s.sndtimeo = 100
with pytest.raises(zmq.Again):
yield s.send(b'not going anywhere')
self.loop.run_sync(test)
@pytest.mark.now
def test_send_noblock(self):
@gen.coroutine
def test():
s = self.socket(zmq.PUSH)
with pytest.raises(zmq.Again):
yield s.send(b'not going anywhere', flags=zmq.NOBLOCK)
self.loop.run_sync(test)
@pytest.mark.now
def test_send_multipart_noblock(self):
@gen.coroutine
def test():
s = self.socket(zmq.PUSH)
with pytest.raises(zmq.Again):
yield s.send_multipart([b'not going anywhere'], flags=zmq.NOBLOCK)
self.loop.run_sync(test)
def test_recv_string(self):
@gen.coroutine
def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_string()
assert not f.done()
msg = u('πøøπ')
yield a.send_string(msg)
recvd = yield f
assert f.done()
self.assertEqual(f.result(), msg)
self.assertEqual(recvd, msg)
self.loop.run_sync(test)
def test_recv_json(self):
@gen.coroutine
def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_json()
assert not f.done()
obj = dict(a=5)
yield a.send_json(obj)
recvd = yield f
assert f.done()
self.assertEqual(f.result(), obj)
self.assertEqual(recvd, obj)
self.loop.run_sync(test)
def test_recv_json_cancelled(self):
@gen.coroutine
def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_json()
assert not f.done()
f.cancel()
# cycle eventloop to allow cancel events to fire
yield gen.sleep(0)
obj = dict(a=5)
yield a.send_json(obj)
with pytest.raises(future.CancelledError):
recvd = yield f
assert f.done()
# give it a chance to incorrectly consume the event
events = yield b.poll(timeout=5)
assert events
yield gen.sleep(0)
# make sure cancelled recv didn't eat up event
recvd = yield gen.with_timeout(timedelta(seconds=5), b.recv_json())
assert recvd == obj
self.loop.run_sync(test)
def test_recv_pyobj(self):
@gen.coroutine
def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.recv_pyobj()
assert not f.done()
obj = dict(a=5)
yield a.send_pyobj(obj)
recvd = yield f
assert f.done()
self.assertEqual(f.result(), obj)
self.assertEqual(recvd, obj)
self.loop.run_sync(test)
def test_custom_serialize(self):
def serialize(msg):
frames = []
frames.extend(msg.get('identities', []))
content = json.dumps(msg['content']).encode('utf8')
frames.append(content)
return frames
def deserialize(frames):
identities = frames[:-1]
content = json.loads(frames[-1].decode('utf8'))
return {
'identities': identities,
'content': content,
}
@gen.coroutine
def test():
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
msg = {
'content': {
'a': 5,
'b': 'bee',
}
}
yield a.send_serialized(msg, serialize)
recvd = yield b.recv_serialized(deserialize)
assert recvd['content'] == msg['content']
assert recvd['identities']
# bounce back, tests identities
yield b.send_serialized(recvd, serialize)
r2 = yield a.recv_serialized(deserialize)
assert r2['content'] == msg['content']
assert not r2['identities']
self.loop.run_sync(test)
def test_custom_serialize_error(self):
@gen.coroutine
def test():
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
msg = {
'content': {
'a': 5,
'b': 'bee',
}
}
with pytest.raises(TypeError):
yield a.send_serialized(json, json.dumps)
yield a.send(b'not json')
with pytest.raises(TypeError):
recvd = yield b.recv_serialized(json.loads)
self.loop.run_sync(test)
def test_poll(self):
@gen.coroutine
def test():
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
f = b.poll(timeout=0)
assert f.done()
self.assertEqual(f.result(), 0)
f = b.poll(timeout=1)
assert not f.done()
evt = yield f
self.assertEqual(evt, 0)
f = b.poll(timeout=1000)
assert not f.done()
yield a.send_multipart([b'hi', b'there'])
evt = yield f
self.assertEqual(evt, zmq.POLLIN)
recvd = yield b.recv_multipart()
self.assertEqual(recvd, [b'hi', b'there'])
self.loop.run_sync(test)
@pytest.mark.skipif(
sys.platform.startswith('win'),
reason='Windows unsupported socket type')
def test_poll_base_socket(self):
@gen.coroutine
def test():
ctx = zmq.Context()
url = 'inproc://test'
a = ctx.socket(zmq.PUSH)
b = ctx.socket(zmq.PULL)
self.sockets.extend([a, b])
a.bind(url)
b.connect(url)
poller = future.Poller()
poller.register(b, zmq.POLLIN)
f = poller.poll(timeout=1000)
assert not f.done()
a.send_multipart([b'hi', b'there'])
evt = yield f
self.assertEqual(evt, [(b, zmq.POLLIN)])
recvd = b.recv_multipart()
self.assertEqual(recvd, [b'hi', b'there'])
a.close()
b.close()
ctx.term()
self.loop.run_sync(test)
def test_close_all_fds(self):
s = self.socket(zmq.PUB)
self.loop.close(all_fds=True)
self.loop = None # avoid second close later
assert s.closed
@pytest.mark.skipif(
sys.platform.startswith('win'),
reason='Windows does not support polling on files')
def test_poll_raw(self):
@gen.coroutine
def test():
p = future.Poller()
# make a pipe
r, w = os.pipe()
r = os.fdopen(r, 'rb')
w = os.fdopen(w, 'wb')
# POLLOUT
p.register(r, zmq.POLLIN)
p.register(w, zmq.POLLOUT)
evts = yield p.poll(timeout=1)
evts = dict(evts)
assert r.fileno() not in evts
assert w.fileno() in evts
assert evts[w.fileno()] == zmq.POLLOUT
# POLLIN
p.unregister(w)
w.write(b'x')
w.flush()
evts = yield p.poll(timeout=1000)
evts = dict(evts)
assert r.fileno() in evts
assert evts[r.fileno()] == zmq.POLLIN
assert r.read(1) == b'x'
r.close()
w.close()
self.loop.run_sync(test)

View file

@ -0,0 +1,68 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import sys
from unittest import TestCase
import pytest
class TestImports(TestCase):
"""Test Imports - the quickest test to ensure that we haven't
introduced version-incompatible syntax errors."""
def test_toplevel(self):
"""test toplevel import"""
import zmq
def test_core(self):
"""test core imports"""
from zmq import Context
from zmq import Socket
from zmq import Poller
from zmq import Frame
from zmq import constants
from zmq import device, proxy
from zmq import (
zmq_version,
zmq_version_info,
pyzmq_version,
pyzmq_version_info,
)
def test_devices(self):
"""test device imports"""
import zmq.devices
from zmq.devices import basedevice
from zmq.devices import monitoredqueue
from zmq.devices import monitoredqueuedevice
def test_log(self):
"""test log imports"""
import zmq.log
from zmq.log import handlers
def test_eventloop(self):
"""test eventloop imports"""
try:
import tornado
except ImportError:
pytest.skip('requires tornado')
import zmq.eventloop
from zmq.eventloop import ioloop
from zmq.eventloop import zmqstream
def test_utils(self):
"""test util imports"""
import zmq.utils
from zmq.utils import strtypes
from zmq.utils import jsonapi
def test_ssh(self):
"""test ssh imports"""
from zmq.ssh import tunnel
def test_decorators(self):
"""test decorators imports"""
from zmq.decorators import context, socket

View file

@ -0,0 +1,33 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from unittest import TestCase
import zmq
import os
class TestIncludes(TestCase):
def test_get_includes(self):
from os.path import dirname, basename
includes = zmq.get_includes()
self.assertTrue(isinstance(includes, list))
self.assertTrue(len(includes) >= 2)
parent = includes[0]
self.assertTrue(isinstance(parent, str))
utilsdir = includes[1]
self.assertTrue(isinstance(utilsdir, str))
utils = basename(utilsdir)
self.assertEqual(utils, "utils")
def test_get_library_dirs(self):
from os.path import dirname, basename
libdirs = zmq.get_library_dirs()
self.assertTrue(isinstance(libdirs, list))
self.assertEqual(len(libdirs), 1)
parent = libdirs[0]
self.assertTrue(isinstance(parent, str))
libdir = basename(parent)
self.assertEqual(libdir, "zmq")

View file

@ -0,0 +1,141 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from __future__ import absolute_import
try:
import asyncio
except ImportError:
asyncio = None
import time
import os
import threading
import pytest
import zmq
from zmq.tests import BaseZMQTestCase, have_gevent
try:
from tornado.ioloop import IOLoop as BaseIOLoop
from zmq.eventloop import ioloop
_tornado = True
except ImportError:
_tornado = False
# tornado 5 with asyncio disables custom IOLoop implementations
t5asyncio = False
if _tornado:
import tornado
if tornado.version_info >= (5,) and asyncio:
t5asyncio = True
def printer():
os.system("say hello")
raise Exception
print (time.time())
class Delay(threading.Thread):
def __init__(self, f, delay=1):
self.f=f
self.delay=delay
self.aborted=False
self.cond=threading.Condition()
super(Delay, self).__init__()
def run(self):
self.cond.acquire()
self.cond.wait(self.delay)
self.cond.release()
if not self.aborted:
self.f()
def abort(self):
self.aborted=True
self.cond.acquire()
self.cond.notify()
self.cond.release()
class TestIOLoop(BaseZMQTestCase):
if _tornado:
IOLoop = ioloop.IOLoop
def setUp(self):
if not _tornado:
pytest.skip("tornado required")
super(TestIOLoop, self).setUp()
if asyncio:
asyncio.set_event_loop(asyncio.new_event_loop())
def tearDown(self):
super(TestIOLoop, self).tearDown()
BaseIOLoop.clear_current()
BaseIOLoop.clear_instance()
def test_simple(self):
"""simple IOLoop creation test"""
loop = self.IOLoop()
loop.make_current()
dc = ioloop.PeriodicCallback(loop.stop, 200)
pc = ioloop.PeriodicCallback(lambda : None, 10)
pc.start()
dc.start()
t = Delay(loop.stop,1)
t.start()
loop.start()
if t.is_alive():
t.abort()
else:
self.fail("IOLoop failed to exit")
def test_instance(self):
"""IOLoop.instance returns the right object"""
loop = self.IOLoop.instance()
if not t5asyncio:
assert isinstance(loop, self.IOLoop)
base_loop = BaseIOLoop.instance()
assert base_loop is loop
def test_current(self):
"""IOLoop.current returns the right object"""
loop = ioloop.IOLoop.current()
if not t5asyncio:
assert isinstance(loop, self.IOLoop)
base_loop = BaseIOLoop.current()
assert base_loop is loop
def test_close_all(self):
"""Test close(all_fds=True)"""
loop = self.IOLoop.current()
req,rep = self.create_bound_pair(zmq.REQ, zmq.REP)
loop.add_handler(req, lambda msg: msg, ioloop.IOLoop.READ)
loop.add_handler(rep, lambda msg: msg, ioloop.IOLoop.READ)
self.assertEqual(req.closed, False)
self.assertEqual(rep.closed, False)
loop.close(all_fds=True)
self.assertEqual(req.closed, True)
self.assertEqual(rep.closed, True)
if have_gevent and _tornado:
import zmq.green.eventloop.ioloop as green_ioloop
class TestIOLoopGreen(TestIOLoop):
IOLoop = green_ioloop.IOLoop
def xtest_instance(self):
"""Green IOLoop.instance returns the right object"""
loop = self.IOLoop.instance()
if not t5asyncio:
assert isinstance(loop, self.IOLoop)
base_loop = BaseIOLoop.instance()
assert base_loop is loop
def xtest_current(self):
"""Green IOLoop.current returns the right object"""
loop = self.IOLoop.current()
if not t5asyncio:
assert isinstance(loop, self.IOLoop)
base_loop = BaseIOLoop.current()
assert base_loop is loop

View file

@ -0,0 +1,178 @@
# encoding: utf-8
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import logging
import time
from unittest import TestCase
import zmq
from zmq.log import handlers
from zmq.utils.strtypes import b, u
from zmq.tests import BaseZMQTestCase
class TestPubLog(BaseZMQTestCase):
iface = 'inproc://zmqlog'
topic= 'zmq'
@property
def logger(self):
# print dir(self)
logger = logging.getLogger('zmqtest')
logger.setLevel(logging.DEBUG)
return logger
def connect_handler(self, topic=None):
topic = self.topic if topic is None else topic
logger = self.logger
pub,sub = self.create_bound_pair(zmq.PUB, zmq.SUB)
handler = handlers.PUBHandler(pub)
handler.setLevel(logging.DEBUG)
handler.root_topic = topic
logger.addHandler(handler)
sub.setsockopt(zmq.SUBSCRIBE, b(topic))
time.sleep(0.1)
return logger, handler, sub
def test_init_iface(self):
logger = self.logger
ctx = self.context
handler = handlers.PUBHandler(self.iface)
self.assertFalse(handler.ctx is ctx)
self.sockets.append(handler.socket)
# handler.ctx.term()
handler = handlers.PUBHandler(self.iface, self.context)
self.sockets.append(handler.socket)
self.assertTrue(handler.ctx is ctx)
handler.setLevel(logging.DEBUG)
handler.root_topic = self.topic
logger.addHandler(handler)
sub = ctx.socket(zmq.SUB)
self.sockets.append(sub)
sub.setsockopt(zmq.SUBSCRIBE, b(self.topic))
sub.connect(self.iface)
import time; time.sleep(0.25)
msg1 = 'message'
logger.info(msg1)
(topic, msg2) = sub.recv_multipart()
self.assertEqual(topic, b'zmq.INFO')
self.assertEqual(msg2, b(msg1)+b'\n')
logger.removeHandler(handler)
def test_init_socket(self):
pub,sub = self.create_bound_pair(zmq.PUB, zmq.SUB)
logger = self.logger
handler = handlers.PUBHandler(pub)
handler.setLevel(logging.DEBUG)
handler.root_topic = self.topic
logger.addHandler(handler)
self.assertTrue(handler.socket is pub)
self.assertTrue(handler.ctx is pub.context)
self.assertTrue(handler.ctx is self.context)
sub.setsockopt(zmq.SUBSCRIBE, b(self.topic))
import time; time.sleep(0.1)
msg1 = 'message'
logger.info(msg1)
(topic, msg2) = sub.recv_multipart()
self.assertEqual(topic, b'zmq.INFO')
self.assertEqual(msg2, b(msg1)+b'\n')
logger.removeHandler(handler)
def test_root_topic(self):
logger, handler, sub = self.connect_handler()
handler.socket.bind(self.iface)
sub2 = sub.context.socket(zmq.SUB)
self.sockets.append(sub2)
sub2.connect(self.iface)
sub2.setsockopt(zmq.SUBSCRIBE, b'')
handler.root_topic = b'twoonly'
msg1 = 'ignored'
logger.info(msg1)
self.assertRaisesErrno(zmq.EAGAIN, sub.recv, zmq.NOBLOCK)
topic,msg2 = sub2.recv_multipart()
self.assertEqual(topic, b'twoonly.INFO')
self.assertEqual(msg2, b(msg1)+b'\n')
logger.removeHandler(handler)
def test_blank_root_topic(self):
logger, handler, sub_everything = self.connect_handler()
sub_everything.setsockopt(zmq.SUBSCRIBE, b'')
handler.socket.bind(self.iface)
sub_only_info = sub_everything.context.socket(zmq.SUB)
self.sockets.append(sub_only_info)
sub_only_info.connect(self.iface)
sub_only_info.setsockopt(zmq.SUBSCRIBE, b'INFO')
handler.setRootTopic(b'')
msg_debug = 'debug_message'
logger.debug(msg_debug)
self.assertRaisesErrno(zmq.EAGAIN, sub_only_info.recv, zmq.NOBLOCK)
topic, msg_debug_response = sub_everything.recv_multipart()
self.assertEqual(topic, b'DEBUG')
msg_info = 'info_message'
logger.info(msg_info)
topic, msg_info_response_everything = sub_everything.recv_multipart()
self.assertEqual(topic, b'INFO')
topic, msg_info_response_onlyinfo = sub_only_info.recv_multipart()
self.assertEqual(topic, b'INFO')
self.assertEqual(msg_info_response_everything, msg_info_response_onlyinfo)
logger.removeHandler(handler)
def test_unicode_message(self):
logger, handler, sub = self.connect_handler()
base_topic = b(self.topic + '.INFO')
for msg, expected in [
(u('hello'), [base_topic, b('hello\n')]),
(u('héllo'), [base_topic, b('héllo\n')]),
(u('tøpic::héllo'), [base_topic + b('.tøpic'), b('héllo\n')]),
]:
logger.info(msg)
received = sub.recv_multipart()
self.assertEqual(received, expected)
logger.removeHandler(handler)
def test_set_info_formatter_via_property(self):
logger, handler, sub = self.connect_handler()
handler.formatters[logging.INFO] = logging.Formatter("%(message)s UNITTEST\n")
handler.socket.bind(self.iface)
sub.setsockopt(zmq.SUBSCRIBE, b(handler.root_topic))
logger.info('info message')
topic, msg = sub.recv_multipart()
self.assertEqual(msg, b'info message UNITTEST\n')
logger.removeHandler(handler)
def test_custom_global_formatter(self):
logger, handler, sub = self.connect_handler()
formatter = logging.Formatter("UNITTEST %(message)s")
handler.setFormatter(formatter)
handler.socket.bind(self.iface)
sub.setsockopt(zmq.SUBSCRIBE, b(handler.root_topic))
logger.info('info message')
topic, msg = sub.recv_multipart()
self.assertEqual(msg, b'UNITTEST info message')
logger.debug('debug message')
topic, msg = sub.recv_multipart()
self.assertEqual(msg, b'UNITTEST debug message')
logger.removeHandler(handler)
def test_custom_debug_formatter(self):
logger, handler, sub = self.connect_handler()
formatter = logging.Formatter("UNITTEST DEBUG %(message)s")
handler.setFormatter(formatter, logging.DEBUG)
handler.socket.bind(self.iface)
sub.setsockopt(zmq.SUBSCRIBE, b(handler.root_topic))
logger.info('info message')
topic, msg = sub.recv_multipart()
self.assertEqual(msg, b'info message\n')
logger.debug('debug message')
topic, msg = sub.recv_multipart()
self.assertEqual(msg, b'UNITTEST DEBUG debug message')
logger.removeHandler(handler)

View file

@ -0,0 +1,348 @@
# -*- coding: utf8 -*-
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import copy
import sys
try:
from sys import getrefcount as grc
except ImportError:
grc = None
import time
from pprint import pprint
from unittest import TestCase
import zmq
from zmq.tests import BaseZMQTestCase, SkipTest, skip_pypy, PYPY
from zmq.utils.strtypes import unicode, bytes, b, u
# some useful constants:
x = b'x'
if grc:
rc0 = grc(x)
v = memoryview(x)
view_rc = grc(x) - rc0
def await_gc(obj, rc):
"""wait for refcount on an object to drop to an expected value
Necessary because of the zero-copy gc thread,
which can take some time to receive its DECREF message.
"""
for i in range(50):
# rc + 2 because of the refs in this function
if grc(obj) <= rc + 2:
return
time.sleep(0.05)
class TestFrame(BaseZMQTestCase):
@skip_pypy
def test_above_30(self):
"""Message above 30 bytes are never copied by 0MQ."""
for i in range(5, 16): # 32, 64,..., 65536
s = (2**i)*x
self.assertEqual(grc(s), 2)
m = zmq.Frame(s, copy=False)
self.assertEqual(grc(s), 4)
del m
await_gc(s, 2)
self.assertEqual(grc(s), 2)
del s
def test_str(self):
"""Test the str representations of the Frames."""
for i in range(16):
s = (2**i)*x
m = zmq.Frame(s)
m_str = str(m)
m_str_b = b(m_str) # py3compat
self.assertEqual(s, m_str_b)
def test_bytes(self):
"""Test the Frame.bytes property."""
for i in range(1,16):
s = (2**i)*x
m = zmq.Frame(s)
b = m.bytes
self.assertEqual(s, m.bytes)
if not PYPY:
# check that it copies
self.assert_(b is not s)
# check that it copies only once
self.assert_(b is m.bytes)
def test_unicode(self):
"""Test the unicode representations of the Frames."""
s = u('asdf')
self.assertRaises(TypeError, zmq.Frame, s)
for i in range(16):
s = (2**i)*u('§')
m = zmq.Frame(s.encode('utf8'))
self.assertEqual(s, unicode(m.bytes,'utf8'))
def test_len(self):
"""Test the len of the Frames."""
for i in range(16):
s = (2**i)*x
m = zmq.Frame(s)
self.assertEqual(len(s), len(m))
@skip_pypy
def test_lifecycle1(self):
"""Run through a ref counting cycle with a copy."""
for i in range(5, 16): # 32, 64,..., 65536
s = (2**i)*x
rc = 2
self.assertEqual(grc(s), rc)
m = zmq.Frame(s, copy=False)
rc += 2
self.assertEqual(grc(s), rc)
m2 = copy.copy(m)
rc += 1
self.assertEqual(grc(s), rc)
buf = m2.buffer
rc += view_rc
self.assertEqual(grc(s), rc)
self.assertEqual(s, b(str(m)))
self.assertEqual(s, bytes(m2))
self.assertEqual(s, m.bytes)
# self.assert_(s is str(m))
# self.assert_(s is str(m2))
del m2
rc -= 1
self.assertEqual(grc(s), rc)
rc -= view_rc
del buf
self.assertEqual(grc(s), rc)
del m
rc -= 2
await_gc(s, rc)
self.assertEqual(grc(s), rc)
self.assertEqual(rc, 2)
del s
@skip_pypy
def test_lifecycle2(self):
"""Run through a different ref counting cycle with a copy."""
for i in range(5, 16): # 32, 64,..., 65536
s = (2**i)*x
rc = 2
self.assertEqual(grc(s), rc)
m = zmq.Frame(s, copy=False)
rc += 2
self.assertEqual(grc(s), rc)
m2 = copy.copy(m)
rc += 1
self.assertEqual(grc(s), rc)
buf = m.buffer
rc += view_rc
self.assertEqual(grc(s), rc)
self.assertEqual(s, b(str(m)))
self.assertEqual(s, bytes(m2))
self.assertEqual(s, m2.bytes)
self.assertEqual(s, m.bytes)
# self.assert_(s is str(m))
# self.assert_(s is str(m2))
del buf
self.assertEqual(grc(s), rc)
del m
# m.buffer is kept until m is del'd
rc -= view_rc
rc -= 1
self.assertEqual(grc(s), rc)
del m2
rc -= 2
await_gc(s, rc)
self.assertEqual(grc(s), rc)
self.assertEqual(rc, 2)
del s
@skip_pypy
def test_tracker(self):
m = zmq.Frame(b'asdf', copy=False, track=True)
self.assertFalse(m.tracker.done)
pm = zmq.MessageTracker(m)
self.assertFalse(pm.done)
del m
for i in range(10):
if pm.done:
break
time.sleep(0.1)
self.assertTrue(pm.done)
def test_no_tracker(self):
m = zmq.Frame(b'asdf', track=False)
self.assertEqual(m.tracker, None)
m2 = copy.copy(m)
self.assertEqual(m2.tracker, None)
self.assertRaises(ValueError, zmq.MessageTracker, m)
@skip_pypy
def test_multi_tracker(self):
m = zmq.Frame(b'asdf', copy=False, track=True)
m2 = zmq.Frame(b'whoda', copy=False, track=True)
mt = zmq.MessageTracker(m,m2)
self.assertFalse(m.tracker.done)
self.assertFalse(mt.done)
self.assertRaises(zmq.NotDone, mt.wait, 0.1)
del m
time.sleep(0.1)
self.assertRaises(zmq.NotDone, mt.wait, 0.1)
self.assertFalse(mt.done)
del m2
self.assertTrue(mt.wait() is None)
self.assertTrue(mt.done)
def test_buffer_in(self):
"""test using a buffer as input"""
ins = b("§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√")
m = zmq.Frame(memoryview(ins))
def test_bad_buffer_in(self):
"""test using a bad object"""
self.assertRaises(TypeError, zmq.Frame, 5)
self.assertRaises(TypeError, zmq.Frame, object())
def test_buffer_out(self):
"""receiving buffered output"""
ins = b("§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√")
m = zmq.Frame(ins)
outb = m.buffer
self.assertTrue(isinstance(outb, memoryview))
self.assert_(outb is m.buffer)
self.assert_(m.buffer is m.buffer)
@skip_pypy
def test_memoryview_shape(self):
"""memoryview shape info"""
if sys.version_info < (3,):
raise SkipTest("only test memoryviews on Python 3")
data = b("§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√")
n = len(data)
f = zmq.Frame(data)
view1 = f.buffer
self.assertEqual(view1.ndim, 1)
self.assertEqual(view1.shape, (n,))
self.assertEqual(view1.tobytes(), data)
view2 = memoryview(f)
self.assertEqual(view2.ndim, 1)
self.assertEqual(view2.shape, (n,))
self.assertEqual(view2.tobytes(), data)
def test_multisend(self):
"""ensure that a message remains intact after multiple sends"""
a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
s = b"message"
m = zmq.Frame(s)
self.assertEqual(s, m.bytes)
a.send(m, copy=False)
time.sleep(0.1)
self.assertEqual(s, m.bytes)
a.send(m, copy=False)
time.sleep(0.1)
self.assertEqual(s, m.bytes)
a.send(m, copy=True)
time.sleep(0.1)
self.assertEqual(s, m.bytes)
a.send(m, copy=True)
time.sleep(0.1)
self.assertEqual(s, m.bytes)
for i in range(4):
r = b.recv()
self.assertEqual(s,r)
self.assertEqual(s, m.bytes)
def test_memoryview(self):
"""test messages from memoryview"""
major,minor = sys.version_info[:2]
if not (major >= 3 or (major == 2 and minor >= 7)):
raise SkipTest("memoryviews only in python >= 2.7")
s = b'carrotjuice'
v = memoryview(s)
m = zmq.Frame(s)
buf = m.buffer
s2 = buf.tobytes()
self.assertEqual(s2,s)
self.assertEqual(m.bytes,s)
def test_noncopying_recv(self):
"""check for clobbering message buffers"""
null = b'\0'*64
sa,sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
for i in range(32):
# try a few times
sb.send(null, copy=False)
m = sa.recv(copy=False)
mb = m.bytes
# buf = memoryview(m)
buf = m.buffer
del m
for i in range(5):
ff=b'\xff'*(40 + i*10)
sb.send(ff, copy=False)
m2 = sa.recv(copy=False)
b = buf.tobytes()
self.assertEqual(b, null)
self.assertEqual(mb, null)
self.assertEqual(m2.bytes, ff)
@skip_pypy
def test_buffer_numpy(self):
"""test non-copying numpy array messages"""
try:
import numpy
from numpy.testing import assert_array_equal
except ImportError:
raise SkipTest("requires numpy")
if sys.version_info < (2,7):
raise SkipTest("requires new-style buffer interface (py >= 2.7)")
rand = numpy.random.randint
shapes = [ rand(2,5) for i in range(5) ]
a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
dtypes = [int, float, '>i4', 'B']
for i in range(1,len(shapes)+1):
shape = shapes[:i]
for dt in dtypes:
A = numpy.empty(shape, dtype=dt)
a.send(A, copy=False)
msg = b.recv(copy=False)
B = numpy.frombuffer(msg, A.dtype).reshape(A.shape)
assert_array_equal(A, B)
A = numpy.empty(shape, dtype=[('a', int), ('b', float), ('c', 'a32')])
A['a'] = 1024
A['b'] = 1e9
A['c'] = 'hello there'
a.send(A, copy=False)
msg = b.recv(copy=False)
B = numpy.frombuffer(msg, A.dtype).reshape(A.shape)
assert_array_equal(A, B)
def test_frame_more(self):
"""test Frame.more attribute"""
frame = zmq.Frame(b"hello")
self.assertFalse(frame.more)
sa,sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
sa.send_multipart([b'hi', b'there'])
frame = self.recv(sb, copy=False)
self.assertTrue(frame.more)
if zmq.zmq_version_info()[0] >= 3 and not PYPY:
self.assertTrue(frame.get(zmq.MORE))
frame = self.recv(sb, copy=False)
self.assertFalse(frame.more)
if zmq.zmq_version_info()[0] >= 3 and not PYPY:
self.assertFalse(frame.get(zmq.MORE))

View file

@ -0,0 +1,83 @@
# -*- coding: utf-8 -*-
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import errno
import sys
import time
import struct
from unittest import TestCase
from pytest import mark
import zmq
from zmq.tests import BaseZMQTestCase, skip_pypy, require_zmq_4
from zmq.utils.monitor import recv_monitor_message
class TestSocketMonitor(BaseZMQTestCase):
@require_zmq_4
def test_monitor(self):
"""Test monitoring interface for sockets."""
s_rep = self.context.socket(zmq.REP)
s_req = self.context.socket(zmq.REQ)
self.sockets.extend([s_rep, s_req])
s_req.bind("tcp://127.0.0.1:6666")
# try monitoring the REP socket
s_rep.monitor("inproc://monitor.rep", zmq.EVENT_CONNECT_DELAYED | zmq.EVENT_CONNECTED | zmq.EVENT_MONITOR_STOPPED)
# create listening socket for monitor
s_event = self.context.socket(zmq.PAIR)
self.sockets.append(s_event)
s_event.connect("inproc://monitor.rep")
s_event.linger = 0
# test receive event for connect event
s_rep.connect("tcp://127.0.0.1:6666")
m = recv_monitor_message(s_event)
if m['event'] == zmq.EVENT_CONNECT_DELAYED:
self.assertEqual(m['endpoint'], b"tcp://127.0.0.1:6666")
# test receive event for connected event
m = recv_monitor_message(s_event)
self.assertEqual(m['event'], zmq.EVENT_CONNECTED)
self.assertEqual(m['endpoint'], b"tcp://127.0.0.1:6666")
# test monitor can be disabled.
s_rep.disable_monitor()
m = recv_monitor_message(s_event)
self.assertEqual(m['event'], zmq.EVENT_MONITOR_STOPPED)
@require_zmq_4
def test_monitor_repeat(self):
s = self.socket(zmq.PULL)
m = s.get_monitor_socket()
self.sockets.append(m)
m2 = s.get_monitor_socket()
assert m is m2
s.disable_monitor()
evt = recv_monitor_message(m)
self.assertEqual(evt['event'], zmq.EVENT_MONITOR_STOPPED)
m.close()
s.close()
@require_zmq_4
def test_monitor_connected(self):
"""Test connected monitoring socket."""
s_rep = self.context.socket(zmq.REP)
s_req = self.context.socket(zmq.REQ)
self.sockets.extend([s_rep, s_req])
s_req.bind("tcp://127.0.0.1:6667")
# try monitoring the REP socket
# create listening socket for monitor
s_event = s_rep.get_monitor_socket()
s_event.linger = 0
self.sockets.append(s_event)
# test receive event for connect event
s_rep.connect("tcp://127.0.0.1:6667")
m = recv_monitor_message(s_event)
if m['event'] == zmq.EVENT_CONNECT_DELAYED:
self.assertEqual(m['endpoint'], b"tcp://127.0.0.1:6667")
# test receive event for connected event
m = recv_monitor_message(s_event)
self.assertEqual(m['event'], zmq.EVENT_CONNECTED)
self.assertEqual(m['endpoint'], b"tcp://127.0.0.1:6667")

View file

@ -0,0 +1,221 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import time
from unittest import TestCase
import zmq
from zmq import devices
from zmq.tests import BaseZMQTestCase, SkipTest, PYPY
from zmq.utils.strtypes import unicode
if PYPY or zmq.zmq_version_info() >= (4,1):
# cleanup of shared Context doesn't work on PyPy
# there also seems to be a bug in cleanup in libzmq-4.1 (zeromq/libzmq#1052)
devices.Device.context_factory = zmq.Context
class TestMonitoredQueue(BaseZMQTestCase):
sockets = []
def build_device(self, mon_sub=b"", in_prefix=b'in', out_prefix=b'out'):
self.device = devices.ThreadMonitoredQueue(zmq.PAIR, zmq.PAIR, zmq.PUB,
in_prefix, out_prefix)
alice = self.context.socket(zmq.PAIR)
bob = self.context.socket(zmq.PAIR)
mon = self.context.socket(zmq.SUB)
aport = alice.bind_to_random_port('tcp://127.0.0.1')
bport = bob.bind_to_random_port('tcp://127.0.0.1')
mport = mon.bind_to_random_port('tcp://127.0.0.1')
mon.setsockopt(zmq.SUBSCRIBE, mon_sub)
self.device.connect_in("tcp://127.0.0.1:%i"%aport)
self.device.connect_out("tcp://127.0.0.1:%i"%bport)
self.device.connect_mon("tcp://127.0.0.1:%i"%mport)
self.device.start()
time.sleep(.2)
try:
# this is currenlty necessary to ensure no dropped monitor messages
# see LIBZMQ-248 for more info
mon.recv_multipart(zmq.NOBLOCK)
except zmq.ZMQError:
pass
self.sockets.extend([alice, bob, mon])
return alice, bob, mon
def teardown_device(self):
for socket in self.sockets:
socket.close()
del socket
del self.device
def test_reply(self):
alice, bob, mon = self.build_device()
alices = b"hello bob".split()
alice.send_multipart(alices)
bobs = self.recv_multipart(bob)
self.assertEqual(alices, bobs)
bobs = b"hello alice".split()
bob.send_multipart(bobs)
alices = self.recv_multipart(alice)
self.assertEqual(alices, bobs)
self.teardown_device()
def test_queue(self):
alice, bob, mon = self.build_device()
alices = b"hello bob".split()
alice.send_multipart(alices)
alices2 = b"hello again".split()
alice.send_multipart(alices2)
alices3 = b"hello again and again".split()
alice.send_multipart(alices3)
bobs = self.recv_multipart(bob)
self.assertEqual(alices, bobs)
bobs = self.recv_multipart(bob)
self.assertEqual(alices2, bobs)
bobs = self.recv_multipart(bob)
self.assertEqual(alices3, bobs)
bobs = b"hello alice".split()
bob.send_multipart(bobs)
alices = self.recv_multipart(alice)
self.assertEqual(alices, bobs)
self.teardown_device()
def test_monitor(self):
alice, bob, mon = self.build_device()
alices = b"hello bob".split()
alice.send_multipart(alices)
alices2 = b"hello again".split()
alice.send_multipart(alices2)
alices3 = b"hello again and again".split()
alice.send_multipart(alices3)
bobs = self.recv_multipart(bob)
self.assertEqual(alices, bobs)
mons = self.recv_multipart(mon)
self.assertEqual([b'in']+bobs, mons)
bobs = self.recv_multipart(bob)
self.assertEqual(alices2, bobs)
bobs = self.recv_multipart(bob)
self.assertEqual(alices3, bobs)
mons = self.recv_multipart(mon)
self.assertEqual([b'in']+alices2, mons)
bobs = b"hello alice".split()
bob.send_multipart(bobs)
alices = self.recv_multipart(alice)
self.assertEqual(alices, bobs)
mons = self.recv_multipart(mon)
self.assertEqual([b'in']+alices3, mons)
mons = self.recv_multipart(mon)
self.assertEqual([b'out']+bobs, mons)
self.teardown_device()
def test_prefix(self):
alice, bob, mon = self.build_device(b"", b'foo', b'bar')
alices = b"hello bob".split()
alice.send_multipart(alices)
alices2 = b"hello again".split()
alice.send_multipart(alices2)
alices3 = b"hello again and again".split()
alice.send_multipart(alices3)
bobs = self.recv_multipart(bob)
self.assertEqual(alices, bobs)
mons = self.recv_multipart(mon)
self.assertEqual([b'foo']+bobs, mons)
bobs = self.recv_multipart(bob)
self.assertEqual(alices2, bobs)
bobs = self.recv_multipart(bob)
self.assertEqual(alices3, bobs)
mons = self.recv_multipart(mon)
self.assertEqual([b'foo']+alices2, mons)
bobs = b"hello alice".split()
bob.send_multipart(bobs)
alices = self.recv_multipart(alice)
self.assertEqual(alices, bobs)
mons = self.recv_multipart(mon)
self.assertEqual([b'foo']+alices3, mons)
mons = self.recv_multipart(mon)
self.assertEqual([b'bar']+bobs, mons)
self.teardown_device()
def test_monitor_subscribe(self):
alice, bob, mon = self.build_device(b"out")
alices = b"hello bob".split()
alice.send_multipart(alices)
alices2 = b"hello again".split()
alice.send_multipart(alices2)
alices3 = b"hello again and again".split()
alice.send_multipart(alices3)
bobs = self.recv_multipart(bob)
self.assertEqual(alices, bobs)
bobs = self.recv_multipart(bob)
self.assertEqual(alices2, bobs)
bobs = self.recv_multipart(bob)
self.assertEqual(alices3, bobs)
bobs = b"hello alice".split()
bob.send_multipart(bobs)
alices = self.recv_multipart(alice)
self.assertEqual(alices, bobs)
mons = self.recv_multipart(mon)
self.assertEqual([b'out']+bobs, mons)
self.teardown_device()
def test_router_router(self):
"""test router-router MQ devices"""
dev = devices.ThreadMonitoredQueue(zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'in', b'out')
self.device = dev
dev.setsockopt_in(zmq.LINGER, 0)
dev.setsockopt_out(zmq.LINGER, 0)
dev.setsockopt_mon(zmq.LINGER, 0)
porta = dev.bind_in_to_random_port('tcp://127.0.0.1')
portb = dev.bind_out_to_random_port('tcp://127.0.0.1')
a = self.context.socket(zmq.DEALER)
a.identity = b'a'
b = self.context.socket(zmq.DEALER)
b.identity = b'b'
self.sockets.extend([a, b])
a.connect('tcp://127.0.0.1:%i'%porta)
b.connect('tcp://127.0.0.1:%i'%portb)
dev.start()
time.sleep(1)
if zmq.zmq_version_info() >= (3,1,0):
# flush erroneous poll state, due to LIBZMQ-280
ping_msg = [ b'ping', b'pong' ]
for s in (a,b):
s.send_multipart(ping_msg)
try:
s.recv(zmq.NOBLOCK)
except zmq.ZMQError:
pass
msg = [ b'hello', b'there' ]
a.send_multipart([b'b']+msg)
bmsg = self.recv_multipart(b)
self.assertEqual(bmsg, [b'a']+msg)
b.send_multipart(bmsg)
amsg = self.recv_multipart(a)
self.assertEqual(amsg, [b'b']+msg)
self.teardown_device()
def test_default_mq_args(self):
self.device = dev = devices.ThreadMonitoredQueue(zmq.ROUTER, zmq.DEALER, zmq.PUB)
dev.setsockopt_in(zmq.LINGER, 0)
dev.setsockopt_out(zmq.LINGER, 0)
dev.setsockopt_mon(zmq.LINGER, 0)
# this will raise if default args are wrong
dev.start()
self.teardown_device()
def test_mq_check_prefix(self):
ins = self.context.socket(zmq.ROUTER)
outs = self.context.socket(zmq.DEALER)
mons = self.context.socket(zmq.PUB)
self.sockets.extend([ins, outs, mons])
ins = unicode('in')
outs = unicode('out')
self.assertRaises(TypeError, devices.monitoredqueue, ins, outs, mons)

View file

@ -0,0 +1,35 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import zmq
from zmq.tests import BaseZMQTestCase, SkipTest, have_gevent, GreenTest
class TestMultipart(BaseZMQTestCase):
def test_router_dealer(self):
router, dealer = self.create_bound_pair(zmq.ROUTER, zmq.DEALER)
msg1 = b'message1'
dealer.send(msg1)
ident = self.recv(router)
more = router.rcvmore
self.assertEqual(more, True)
msg2 = self.recv(router)
self.assertEqual(msg1, msg2)
more = router.rcvmore
self.assertEqual(more, False)
def test_basic_multipart(self):
a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
msg = [ b'hi', b'there', b'b']
a.send_multipart(msg)
recvd = b.recv_multipart()
self.assertEqual(msg, recvd)
if have_gevent:
class TestMultipartGreen(GreenTest, TestMultipart):
pass

View file

@ -0,0 +1,53 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import zmq
from zmq.tests import BaseZMQTestCase, have_gevent, GreenTest
x = b' '
class TestPair(BaseZMQTestCase):
def test_basic(self):
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
msg1 = b'message1'
msg2 = self.ping_pong(s1, s2, msg1)
self.assertEqual(msg1, msg2)
def test_multiple(self):
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
for i in range(10):
msg = i*x
s1.send(msg)
for i in range(10):
msg = i*x
s2.send(msg)
for i in range(10):
msg = s1.recv()
self.assertEqual(msg, i*x)
for i in range(10):
msg = s2.recv()
self.assertEqual(msg, i*x)
def test_json(self):
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
o = dict(a=10,b=list(range(10)))
o2 = self.ping_pong_json(s1, s2, o)
def test_pyobj(self):
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
o = dict(a=10,b=range(10))
o2 = self.ping_pong_pyobj(s1, s2, o)
if have_gevent:
class TestReqRepGreen(GreenTest, TestPair):
pass

View file

@ -0,0 +1,238 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import os
import sys
import time
from pytest import mark
import zmq
from zmq.tests import PollZMQTestCase, have_gevent, GreenTest
def wait():
time.sleep(.25)
class TestPoll(PollZMQTestCase):
Poller = zmq.Poller
def test_pair(self):
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
# Sleep to allow sockets to connect.
wait()
poller = self.Poller()
poller.register(s1, zmq.POLLIN|zmq.POLLOUT)
poller.register(s2, zmq.POLLIN|zmq.POLLOUT)
# Poll result should contain both sockets
socks = dict(poller.poll())
# Now make sure that both are send ready.
self.assertEqual(socks[s1], zmq.POLLOUT)
self.assertEqual(socks[s2], zmq.POLLOUT)
# Now do a send on both, wait and test for zmq.POLLOUT|zmq.POLLIN
s1.send(b'msg1')
s2.send(b'msg2')
wait()
socks = dict(poller.poll())
self.assertEqual(socks[s1], zmq.POLLOUT|zmq.POLLIN)
self.assertEqual(socks[s2], zmq.POLLOUT|zmq.POLLIN)
# Make sure that both are in POLLOUT after recv.
s1.recv()
s2.recv()
socks = dict(poller.poll())
self.assertEqual(socks[s1], zmq.POLLOUT)
self.assertEqual(socks[s2], zmq.POLLOUT)
poller.unregister(s1)
poller.unregister(s2)
def test_reqrep(self):
s1, s2 = self.create_bound_pair(zmq.REP, zmq.REQ)
# Sleep to allow sockets to connect.
wait()
poller = self.Poller()
poller.register(s1, zmq.POLLIN|zmq.POLLOUT)
poller.register(s2, zmq.POLLIN|zmq.POLLOUT)
# Make sure that s1 is in state 0 and s2 is in POLLOUT
socks = dict(poller.poll())
self.assertEqual(s1 in socks, 0)
self.assertEqual(socks[s2], zmq.POLLOUT)
# Make sure that s2 goes immediately into state 0 after send.
s2.send(b'msg1')
socks = dict(poller.poll())
self.assertEqual(s2 in socks, 0)
# Make sure that s1 goes into POLLIN state after a time.sleep().
time.sleep(0.5)
socks = dict(poller.poll())
self.assertEqual(socks[s1], zmq.POLLIN)
# Make sure that s1 goes into POLLOUT after recv.
s1.recv()
socks = dict(poller.poll())
self.assertEqual(socks[s1], zmq.POLLOUT)
# Make sure s1 goes into state 0 after send.
s1.send(b'msg2')
socks = dict(poller.poll())
self.assertEqual(s1 in socks, 0)
# Wait and then see that s2 is in POLLIN.
time.sleep(0.5)
socks = dict(poller.poll())
self.assertEqual(socks[s2], zmq.POLLIN)
# Make sure that s2 is in POLLOUT after recv.
s2.recv()
socks = dict(poller.poll())
self.assertEqual(socks[s2], zmq.POLLOUT)
poller.unregister(s1)
poller.unregister(s2)
def test_no_events(self):
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
poller = self.Poller()
poller.register(s1, zmq.POLLIN|zmq.POLLOUT)
poller.register(s2, 0)
self.assertTrue(s1 in poller)
self.assertFalse(s2 in poller)
poller.register(s1, 0)
self.assertFalse(s1 in poller)
def test_pubsub(self):
s1, s2 = self.create_bound_pair(zmq.PUB, zmq.SUB)
s2.setsockopt(zmq.SUBSCRIBE, b'')
# Sleep to allow sockets to connect.
wait()
poller = self.Poller()
poller.register(s1, zmq.POLLIN|zmq.POLLOUT)
poller.register(s2, zmq.POLLIN)
# Now make sure that both are send ready.
socks = dict(poller.poll())
self.assertEqual(socks[s1], zmq.POLLOUT)
self.assertEqual(s2 in socks, 0)
# Make sure that s1 stays in POLLOUT after a send.
s1.send(b'msg1')
socks = dict(poller.poll())
self.assertEqual(socks[s1], zmq.POLLOUT)
# Make sure that s2 is POLLIN after waiting.
wait()
socks = dict(poller.poll())
self.assertEqual(socks[s2], zmq.POLLIN)
# Make sure that s2 goes into 0 after recv.
s2.recv()
socks = dict(poller.poll())
self.assertEqual(s2 in socks, 0)
poller.unregister(s1)
poller.unregister(s2)
@mark.skipif(sys.platform.startswith('win'), reason='Windows')
def test_raw(self):
r, w = os.pipe()
r = os.fdopen(r, 'rb')
w = os.fdopen(w, 'wb')
p = self.Poller()
p.register(r, zmq.POLLIN)
socks = dict(p.poll(1))
assert socks == {}
w.write(b'x')
w.flush()
socks = dict(p.poll(1))
assert socks == {r.fileno(): zmq.POLLIN}
w.close()
r.close()
def test_timeout(self):
"""make sure Poller.poll timeout has the right units (milliseconds)."""
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
poller = self.Poller()
poller.register(s1, zmq.POLLIN)
tic = time.time()
evt = poller.poll(.005)
toc = time.time()
self.assertTrue(toc-tic < 0.1)
tic = time.time()
evt = poller.poll(5)
toc = time.time()
self.assertTrue(toc-tic < 0.1)
self.assertTrue(toc-tic > .001)
tic = time.time()
evt = poller.poll(500)
toc = time.time()
self.assertTrue(toc-tic < 1)
self.assertTrue(toc-tic > 0.1)
class TestSelect(PollZMQTestCase):
def test_pair(self):
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
# Sleep to allow sockets to connect.
wait()
rlist, wlist, xlist = zmq.select([s1, s2], [s1, s2], [s1, s2])
self.assert_(s1 in wlist)
self.assert_(s2 in wlist)
self.assert_(s1 not in rlist)
self.assert_(s2 not in rlist)
def test_timeout(self):
"""make sure select timeout has the right units (seconds)."""
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
tic = time.time()
r,w,x = zmq.select([s1,s2],[],[],.005)
toc = time.time()
self.assertTrue(toc-tic < 1)
self.assertTrue(toc-tic > 0.001)
tic = time.time()
r,w,x = zmq.select([s1,s2],[],[],.25)
toc = time.time()
self.assertTrue(toc-tic < 1)
self.assertTrue(toc-tic > 0.1)
if have_gevent:
import gevent
from zmq import green as gzmq
class TestPollGreen(GreenTest, TestPoll):
Poller = gzmq.Poller
def test_wakeup(self):
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
poller = self.Poller()
poller.register(s2, zmq.POLLIN)
tic = time.time()
r = gevent.spawn(lambda: poller.poll(10000))
s = gevent.spawn(lambda: s1.send(b'msg1'))
r.join()
toc = time.time()
self.assertTrue(toc-tic < 1)
def test_socket_poll(self):
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
tic = time.time()
r = gevent.spawn(lambda: s2.poll(10000))
s = gevent.spawn(lambda: s1.send(b'msg1'))
r.join()
toc = time.time()
self.assertTrue(toc-tic < 1)

View file

@ -0,0 +1,109 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import time
import struct
import zmq
from zmq import devices
from zmq.tests import BaseZMQTestCase, SkipTest, PYPY
if PYPY:
# cleanup of shared Context doesn't work on PyPy
devices.Device.context_factory = zmq.Context
class TestProxySteerable(BaseZMQTestCase):
def test_proxy_steerable(self):
if zmq.zmq_version_info() < (4, 1):
raise SkipTest("Steerable Proxies only in libzmq >= 4.1")
dev = devices.ThreadProxySteerable(
zmq.PULL,
zmq.PUSH,
zmq.PUSH,
zmq.PAIR
)
iface = 'tcp://127.0.0.1'
port = dev.bind_in_to_random_port(iface)
port2 = dev.bind_out_to_random_port(iface)
port3 = dev.bind_mon_to_random_port(iface)
port4 = dev.bind_ctrl_to_random_port(iface)
dev.start()
time.sleep(0.25)
msg = b'hello'
push = self.context.socket(zmq.PUSH)
push.connect("%s:%i" % (iface, port))
pull = self.context.socket(zmq.PULL)
pull.connect("%s:%i" % (iface, port2))
mon = self.context.socket(zmq.PULL)
mon.connect("%s:%i" % (iface, port3))
ctrl = self.context.socket(zmq.PAIR)
ctrl.connect("%s:%i" % (iface, port4))
push.send(msg)
self.sockets.extend([push, pull, mon, ctrl])
self.assertEqual(msg, self.recv(pull))
self.assertEqual(msg, self.recv(mon))
ctrl.send(b'TERMINATE')
dev.join()
def test_proxy_steerable_bind_to_random_with_args(self):
if zmq.zmq_version_info() < (4, 1):
raise SkipTest("Steerable Proxies only in libzmq >= 4.1")
dev = devices.ThreadProxySteerable(
zmq.PULL,
zmq.PUSH,
zmq.PUSH,
zmq.PAIR
)
iface = 'tcp://127.0.0.1'
ports = []
min, max = 5000, 5050
ports.extend([
dev.bind_in_to_random_port(iface, min_port=min, max_port=max),
dev.bind_out_to_random_port(iface, min_port=min, max_port=max),
dev.bind_mon_to_random_port(iface, min_port=min, max_port=max),
dev.bind_ctrl_to_random_port(iface, min_port=min, max_port=max)
])
for port in ports:
if port < min or port > max:
self.fail('Unexpected port number: %i' % port)
def test_proxy_steerable_statistics(self):
if zmq.zmq_version_info() < (4, 3):
raise SkipTest("STATISTICS only in libzmq >= 4.3")
dev = devices.ThreadProxySteerable(
zmq.PULL,
zmq.PUSH,
zmq.PUSH,
zmq.PAIR
)
iface = 'tcp://127.0.0.1'
port = dev.bind_in_to_random_port(iface)
port2 = dev.bind_out_to_random_port(iface)
port3 = dev.bind_mon_to_random_port(iface)
port4 = dev.bind_ctrl_to_random_port(iface)
dev.start()
time.sleep(0.25)
msg = b'hello'
push = self.context.socket(zmq.PUSH)
push.connect("%s:%i" % (iface, port))
pull = self.context.socket(zmq.PULL)
pull.connect("%s:%i" % (iface, port2))
mon = self.context.socket(zmq.PULL)
mon.connect("%s:%i" % (iface, port3))
ctrl = self.context.socket(zmq.PAIR)
ctrl.connect("%s:%i" % (iface, port4))
push.send(msg)
self.sockets.extend([push, pull, mon, ctrl])
self.assertEqual(msg, self.recv(pull))
self.assertEqual(msg, self.recv(mon))
ctrl.send(b'STATISTICS')
stats = self.recv_multipart(ctrl)
stats_int = [struct.unpack("=Q", x)[0] for x in stats]
self.assertEqual(1, stats_int[0])
self.assertEqual(len(msg), stats_int[1])
self.assertEqual(1, stats_int[6])
self.assertEqual(len(msg), stats_int[7])
ctrl.send(b'TERMINATE')
dev.join()

View file

@ -0,0 +1,42 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from random import Random
import time
from unittest import TestCase
import zmq
from zmq.tests import BaseZMQTestCase, have_gevent, GreenTest
class TestPubSub(BaseZMQTestCase):
pass
# We are disabling this test while an issue is being resolved.
def test_basic(self):
s1, s2 = self.create_bound_pair(zmq.PUB, zmq.SUB)
s2.setsockopt(zmq.SUBSCRIBE, b'')
time.sleep(0.1)
msg1 = b'message'
s1.send(msg1)
msg2 = s2.recv() # This is blocking!
self.assertEqual(msg1, msg2)
def test_topic(self):
s1, s2 = self.create_bound_pair(zmq.PUB, zmq.SUB)
s2.setsockopt(zmq.SUBSCRIBE, b'x')
time.sleep(0.1)
msg1 = b'message'
s1.send(msg1)
self.assertRaisesErrno(zmq.EAGAIN, s2.recv, zmq.NOBLOCK)
msg1 = b'xmessage'
s1.send(msg1)
msg2 = s2.recv()
self.assertEqual(msg1, msg2)
if have_gevent:
class TestPubSubGreen(GreenTest, TestPubSub):
pass

View file

@ -0,0 +1,62 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from unittest import TestCase
import zmq
from zmq.tests import BaseZMQTestCase, have_gevent, GreenTest
class TestReqRep(BaseZMQTestCase):
def test_basic(self):
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
msg1 = b'message 1'
msg2 = self.ping_pong(s1, s2, msg1)
self.assertEqual(msg1, msg2)
def test_multiple(self):
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
for i in range(10):
msg1 = i*b' '
msg2 = self.ping_pong(s1, s2, msg1)
self.assertEqual(msg1, msg2)
def test_bad_send_recv(self):
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
if zmq.zmq_version() != '2.1.8':
# this doesn't work on 2.1.8
for copy in (True,False):
self.assertRaisesErrno(zmq.EFSM, s1.recv, copy=copy)
self.assertRaisesErrno(zmq.EFSM, s2.send, b'asdf', copy=copy)
# I have to have this or we die on an Abort trap.
msg1 = b'asdf'
msg2 = self.ping_pong(s1, s2, msg1)
self.assertEqual(msg1, msg2)
def test_json(self):
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
o = dict(a=10,b=list(range(10)))
o2 = self.ping_pong_json(s1, s2, o)
def test_pyobj(self):
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
o = dict(a=10,b=range(10))
o2 = self.ping_pong_pyobj(s1, s2, o)
def test_large_msg(self):
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
msg1 = 10000*b'X'
for i in range(10):
msg2 = self.ping_pong(s1, s2, msg1)
self.assertEqual(msg1, msg2)
if have_gevent:
class TestReqRepGreen(GreenTest, TestReqRep):
pass

View file

@ -0,0 +1,95 @@
# -*- coding: utf8 -*-
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import signal
import time
from threading import Thread
from pytest import mark
import zmq
from zmq.tests import (
BaseZMQTestCase, SkipTest, skip_pypy
)
from zmq.utils.strtypes import b
# Partially based on EINTRBaseTest from CPython 3.5 eintr_tester
class TestEINTRSysCall(BaseZMQTestCase):
""" Base class for EINTR tests. """
# delay for initial signal delivery
signal_delay = 0.1
# timeout for tests. Must be > signal_delay
timeout = .25
timeout_ms = int(timeout * 1e3)
def alarm(self, t=None):
"""start a timer to fire only once
like signal.alarm, but with better resolution than integer seconds.
"""
if not hasattr(signal, 'setitimer'):
raise SkipTest('EINTR tests require setitimer')
if t is None:
t = self.signal_delay
self.timer_fired = False
self.orig_handler = signal.signal(signal.SIGALRM, self.stop_timer)
# signal_period ignored, since only one timer event is allowed to fire
signal.setitimer(signal.ITIMER_REAL, t, 1000)
def stop_timer(self, *args):
self.timer_fired = True
signal.setitimer(signal.ITIMER_REAL, 0, 0)
signal.signal(signal.SIGALRM, self.orig_handler)
@mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
def test_retry_recv(self):
pull = self.socket(zmq.PULL)
pull.rcvtimeo = self.timeout_ms
self.alarm()
self.assertRaises(zmq.Again, pull.recv)
assert self.timer_fired
@mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
def test_retry_send(self):
push = self.socket(zmq.PUSH)
push.sndtimeo = self.timeout_ms
self.alarm()
self.assertRaises(zmq.Again, push.send, b('buf'))
assert self.timer_fired
def test_retry_poll(self):
x, y = self.create_bound_pair()
poller = zmq.Poller()
poller.register(x, zmq.POLLIN)
self.alarm()
def send():
time.sleep(2 * self.signal_delay)
y.send(b('ping'))
t = Thread(target=send)
t.start()
evts = dict(poller.poll(2 * self.timeout_ms))
t.join()
assert x in evts
assert self.timer_fired
x.recv()
def test_retry_term(self):
push = self.socket(zmq.PUSH)
push.linger = self.timeout_ms
push.connect('tcp://127.0.0.1:5555')
push.send(b('ping'))
time.sleep(0.1)
self.alarm()
self.context.destroy()
assert self.timer_fired
assert self.context.closed
def test_retry_getsockopt(self):
raise SkipTest("TODO: find a way to interrupt getsockopt")
def test_retry_setsockopt(self):
raise SkipTest("TODO: find a way to interrupt setsockopt")

View file

@ -0,0 +1,236 @@
"""Test libzmq security (libzmq >= 3.3.0)"""
# -*- coding: utf8 -*-
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import os
import contextlib
import time
from threading import Thread
import zmq
from zmq.tests import (
BaseZMQTestCase, SkipTest, PYPY
)
from zmq.utils import z85
USER = b"admin"
PASS = b"password"
class TestSecurity(BaseZMQTestCase):
def setUp(self):
if zmq.zmq_version_info() < (4,0):
raise SkipTest("security is new in libzmq 4.0")
try:
zmq.curve_keypair()
except zmq.ZMQError:
raise SkipTest("security requires libzmq to be built with CURVE support")
super(TestSecurity, self).setUp()
def zap_handler(self):
socket = self.context.socket(zmq.REP)
socket.bind("inproc://zeromq.zap.01")
try:
msg = self.recv_multipart(socket)
version, sequence, domain, address, identity, mechanism = msg[:6]
if mechanism == b'PLAIN':
username, password = msg[6:]
elif mechanism == b'CURVE':
key = msg[6]
self.assertEqual(version, b"1.0")
self.assertEqual(identity, b"IDENT")
reply = [version, sequence]
if mechanism == b'CURVE' or \
(mechanism == b'PLAIN' and username == USER and password == PASS) or \
(mechanism == b'NULL'):
reply.extend([
b"200",
b"OK",
b"anonymous",
b"\5Hello\0\0\0\5World",
])
else:
reply.extend([
b"400",
b"Invalid username or password",
b"",
b"",
])
socket.send_multipart(reply)
finally:
socket.close()
@contextlib.contextmanager
def zap(self):
self.start_zap()
time.sleep(0.5) # allow time for the Thread to start
try:
yield
finally:
self.stop_zap()
def start_zap(self):
self.zap_thread = Thread(target=self.zap_handler)
self.zap_thread.start()
def stop_zap(self):
self.zap_thread.join()
def bounce(self, server, client, test_metadata=True):
msg = [os.urandom(64), os.urandom(64)]
client.send_multipart(msg)
frames = self.recv_multipart(server, copy=False)
recvd = list(map(lambda x: x.bytes, frames))
try:
if test_metadata and not PYPY:
for frame in frames:
self.assertEqual(frame.get('User-Id'), 'anonymous')
self.assertEqual(frame.get('Hello'), 'World')
self.assertEqual(frame['Socket-Type'], 'DEALER')
except zmq.ZMQVersionError:
pass
self.assertEqual(recvd, msg)
server.send_multipart(recvd)
msg2 = self.recv_multipart(client)
self.assertEqual(msg2, msg)
def test_null(self):
"""test NULL (default) security"""
server = self.socket(zmq.DEALER)
client = self.socket(zmq.DEALER)
self.assertEqual(client.MECHANISM, zmq.NULL)
self.assertEqual(server.mechanism, zmq.NULL)
self.assertEqual(client.plain_server, 0)
self.assertEqual(server.plain_server, 0)
iface = 'tcp://127.0.0.1'
port = server.bind_to_random_port(iface)
client.connect("%s:%i" % (iface, port))
self.bounce(server, client, False)
def test_plain(self):
"""test PLAIN authentication"""
server = self.socket(zmq.DEALER)
server.identity = b'IDENT'
client = self.socket(zmq.DEALER)
self.assertEqual(client.plain_username, b'')
self.assertEqual(client.plain_password, b'')
client.plain_username = USER
client.plain_password = PASS
self.assertEqual(client.getsockopt(zmq.PLAIN_USERNAME), USER)
self.assertEqual(client.getsockopt(zmq.PLAIN_PASSWORD), PASS)
self.assertEqual(client.plain_server, 0)
self.assertEqual(server.plain_server, 0)
server.plain_server = True
self.assertEqual(server.mechanism, zmq.PLAIN)
self.assertEqual(client.mechanism, zmq.PLAIN)
assert not client.plain_server
assert server.plain_server
with self.zap():
iface = 'tcp://127.0.0.1'
port = server.bind_to_random_port(iface)
client.connect("%s:%i" % (iface, port))
self.bounce(server, client)
def skip_plain_inauth(self):
"""test PLAIN failed authentication"""
server = self.socket(zmq.DEALER)
server.identity = b'IDENT'
client = self.socket(zmq.DEALER)
self.sockets.extend([server, client])
client.plain_username = USER
client.plain_password = b'incorrect'
server.plain_server = True
self.assertEqual(server.mechanism, zmq.PLAIN)
self.assertEqual(client.mechanism, zmq.PLAIN)
with self.zap():
iface = 'tcp://127.0.0.1'
port = server.bind_to_random_port(iface)
client.connect("%s:%i" % (iface, port))
client.send(b'ping')
server.rcvtimeo = 250
self.assertRaisesErrno(zmq.EAGAIN, server.recv)
def test_keypair(self):
"""test curve_keypair"""
try:
public, secret = zmq.curve_keypair()
except zmq.ZMQError:
raise SkipTest("CURVE unsupported")
self.assertEqual(type(secret), bytes)
self.assertEqual(type(public), bytes)
self.assertEqual(len(secret), 40)
self.assertEqual(len(public), 40)
# verify that it is indeed Z85
bsecret, bpublic = [ z85.decode(key) for key in (public, secret) ]
self.assertEqual(type(bsecret), bytes)
self.assertEqual(type(bpublic), bytes)
self.assertEqual(len(bsecret), 32)
self.assertEqual(len(bpublic), 32)
def test_curve_public(self):
"""test curve_public"""
try:
public, secret = zmq.curve_keypair()
except zmq.ZMQError:
raise SkipTest("CURVE unsupported")
if zmq.zmq_version_info() < (4,2):
raise SkipTest("curve_public is new in libzmq 4.2")
derived_public = zmq.curve_public(secret)
self.assertEqual(type(derived_public), bytes)
self.assertEqual(len(derived_public), 40)
# verify that it is indeed Z85
bpublic = z85.decode(derived_public)
self.assertEqual(type(bpublic), bytes)
self.assertEqual(len(bpublic), 32)
# verify that it is equal to the known public key
self.assertEqual(derived_public, public)
def test_curve(self):
"""test CURVE encryption"""
server = self.socket(zmq.DEALER)
server.identity = b'IDENT'
client = self.socket(zmq.DEALER)
self.sockets.extend([server, client])
try:
server.curve_server = True
except zmq.ZMQError as e:
# will raise EINVAL if no CURVE support
if e.errno == zmq.EINVAL:
raise SkipTest("CURVE unsupported")
server_public, server_secret = zmq.curve_keypair()
client_public, client_secret = zmq.curve_keypair()
server.curve_secretkey = server_secret
server.curve_publickey = server_public
client.curve_serverkey = server_public
client.curve_publickey = client_public
client.curve_secretkey = client_secret
self.assertEqual(server.mechanism, zmq.CURVE)
self.assertEqual(client.mechanism, zmq.CURVE)
self.assertEqual(server.get(zmq.CURVE_SERVER), True)
self.assertEqual(client.get(zmq.CURVE_SERVER), False)
with self.zap():
iface = 'tcp://127.0.0.1'
port = server.bind_to_random_port(iface)
client.connect("%s:%i" % (iface, port))
self.bounce(server, client)

View file

@ -0,0 +1,615 @@
# -*- coding: utf8 -*-
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import copy
import errno
import json
import os
import platform
import time
import warnings
import socket
import sys
try:
from unittest import mock
except ImportError:
mock = None
import pytest
from pytest import mark
import zmq
from zmq.tests import (
BaseZMQTestCase, SkipTest, have_gevent, GreenTest, skip_pypy
)
from zmq.utils.strtypes import unicode
pypy = platform.python_implementation().lower() == 'pypy'
windows = platform.platform().lower().startswith('windows')
on_travis = bool(os.environ.get('TRAVIS_PYTHON_VERSION'))
# polling on windows is slow
POLL_TIMEOUT = 1000 if windows else 100
class TestSocket(BaseZMQTestCase):
def test_create(self):
ctx = self.Context()
s = ctx.socket(zmq.PUB)
# Superluminal protocol not yet implemented
self.assertRaisesErrno(zmq.EPROTONOSUPPORT, s.bind, 'ftl://a')
self.assertRaisesErrno(zmq.EPROTONOSUPPORT, s.connect, 'ftl://a')
self.assertRaisesErrno(zmq.EINVAL, s.bind, 'tcp://')
s.close()
del ctx
def test_context_manager(self):
url = 'inproc://a'
with self.Context() as ctx:
with ctx.socket(zmq.PUSH) as a:
a.bind(url)
with ctx.socket(zmq.PULL) as b:
b.connect(url)
msg = b'hi'
a.send(msg)
rcvd = self.recv(b)
self.assertEqual(rcvd, msg)
self.assertEqual(b.closed, True)
self.assertEqual(a.closed, True)
self.assertEqual(ctx.closed, True)
def test_dir(self):
ctx = self.Context()
s = ctx.socket(zmq.PUB)
self.assertTrue('send' in dir(s))
self.assertTrue('IDENTITY' in dir(s))
self.assertTrue('AFFINITY' in dir(s))
self.assertTrue('FD' in dir(s))
s.close()
ctx.term()
@mark.skipif(mock is None, reason="requires unittest.mock")
def test_mockable(self):
s = self.socket(zmq.SUB)
m = mock.Mock(spec=s)
s.close()
def test_bind_unicode(self):
s = self.socket(zmq.PUB)
p = s.bind_to_random_port(unicode("tcp://*"))
def test_connect_unicode(self):
s = self.socket(zmq.PUB)
s.connect(unicode("tcp://127.0.0.1:5555"))
def test_bind_to_random_port(self):
# Check that bind_to_random_port do not hide useful exception
ctx = self.Context()
c = ctx.socket(zmq.PUB)
# Invalid format
try:
c.bind_to_random_port('tcp:*')
except zmq.ZMQError as e:
self.assertEqual(e.errno, zmq.EINVAL)
# Invalid protocol
try:
c.bind_to_random_port('rand://*')
except zmq.ZMQError as e:
self.assertEqual(e.errno, zmq.EPROTONOSUPPORT)
def test_identity(self):
s = self.context.socket(zmq.PULL)
self.sockets.append(s)
ident = b'identity\0\0'
s.identity = ident
self.assertEqual(s.get(zmq.IDENTITY), ident)
def test_unicode_sockopts(self):
"""test setting/getting sockopts with unicode strings"""
topic = "tést"
if str is not unicode:
topic = topic.decode('utf8')
p,s = self.create_bound_pair(zmq.PUB, zmq.SUB)
self.assertEqual(s.send_unicode, s.send_unicode)
self.assertEqual(p.recv_unicode, p.recv_unicode)
self.assertRaises(TypeError, s.setsockopt, zmq.SUBSCRIBE, topic)
self.assertRaises(TypeError, s.setsockopt, zmq.IDENTITY, topic)
s.setsockopt_unicode(zmq.IDENTITY, topic, 'utf16')
self.assertRaises(TypeError, s.setsockopt, zmq.AFFINITY, topic)
s.setsockopt_unicode(zmq.SUBSCRIBE, topic)
self.assertRaises(TypeError, s.getsockopt_unicode, zmq.AFFINITY)
self.assertRaisesErrno(zmq.EINVAL, s.getsockopt_unicode, zmq.SUBSCRIBE)
identb = s.getsockopt(zmq.IDENTITY)
identu = identb.decode('utf16')
identu2 = s.getsockopt_unicode(zmq.IDENTITY, 'utf16')
self.assertEqual(identu, identu2)
time.sleep(0.1) # wait for connection/subscription
p.send_unicode(topic,zmq.SNDMORE)
p.send_unicode(topic*2, encoding='latin-1')
self.assertEqual(topic, s.recv_unicode())
self.assertEqual(topic*2, s.recv_unicode(encoding='latin-1'))
def test_int_sockopts(self):
"test integer sockopts"
v = zmq.zmq_version_info()
if v < (3,0):
default_hwm = 0
else:
default_hwm = 1000
p,s = self.create_bound_pair(zmq.PUB, zmq.SUB)
p.setsockopt(zmq.LINGER, 0)
self.assertEqual(p.getsockopt(zmq.LINGER), 0)
p.setsockopt(zmq.LINGER, -1)
self.assertEqual(p.getsockopt(zmq.LINGER), -1)
self.assertEqual(p.hwm, default_hwm)
p.hwm = 11
self.assertEqual(p.hwm, 11)
# p.setsockopt(zmq.EVENTS, zmq.POLLIN)
self.assertEqual(p.getsockopt(zmq.EVENTS), zmq.POLLOUT)
self.assertRaisesErrno(zmq.EINVAL, p.setsockopt,zmq.EVENTS, 2**7-1)
self.assertEqual(p.getsockopt(zmq.TYPE), p.socket_type)
self.assertEqual(p.getsockopt(zmq.TYPE), zmq.PUB)
self.assertEqual(s.getsockopt(zmq.TYPE), s.socket_type)
self.assertEqual(s.getsockopt(zmq.TYPE), zmq.SUB)
# check for overflow / wrong type:
errors = []
backref = {}
constants = zmq.constants
for name in constants.__all__:
value = getattr(constants, name)
if isinstance(value, int):
backref[value] = name
for opt in zmq.constants.int_sockopts.union(zmq.constants.int64_sockopts):
sopt = backref[opt]
if sopt.startswith((
'ROUTER', 'XPUB', 'TCP', 'FAIL',
'REQ_', 'CURVE_', 'PROBE_ROUTER',
'IPC_FILTER', 'GSSAPI', 'STREAM_',
'VMCI_BUFFER_SIZE', 'VMCI_BUFFER_MIN_SIZE',
'VMCI_BUFFER_MAX_SIZE', 'VMCI_CONNECT_TIMEOUT',
)):
# some sockopts are write-only
continue
try:
n = p.getsockopt(opt)
except zmq.ZMQError as e:
errors.append("getsockopt(zmq.%s) raised '%s'."%(sopt, e))
else:
if n > 2**31:
errors.append("getsockopt(zmq.%s) returned a ridiculous value."
" It is probably the wrong type."%sopt)
if errors:
self.fail('\n'.join([''] + errors))
def test_bad_sockopts(self):
"""Test that appropriate errors are raised on bad socket options"""
s = self.context.socket(zmq.PUB)
self.sockets.append(s)
s.setsockopt(zmq.LINGER, 0)
# unrecognized int sockopts pass through to libzmq, and should raise EINVAL
self.assertRaisesErrno(zmq.EINVAL, s.setsockopt, 9999, 5)
self.assertRaisesErrno(zmq.EINVAL, s.getsockopt, 9999)
# but only int sockopts are allowed through this way, otherwise raise a TypeError
self.assertRaises(TypeError, s.setsockopt, 9999, b"5")
# some sockopts are valid in general, but not on every socket:
self.assertRaisesErrno(zmq.EINVAL, s.setsockopt, zmq.SUBSCRIBE, b'hi')
def test_sockopt_roundtrip(self):
"test set/getsockopt roundtrip."
p = self.context.socket(zmq.PUB)
self.sockets.append(p)
p.setsockopt(zmq.LINGER, 11)
self.assertEqual(p.getsockopt(zmq.LINGER), 11)
def test_send_unicode(self):
"test sending unicode objects"
a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
self.sockets.extend([a,b])
u = "çπ§"
if str is not unicode:
u = u.decode('utf8')
self.assertRaises(TypeError, a.send, u,copy=False)
self.assertRaises(TypeError, a.send, u,copy=True)
a.send_unicode(u)
s = b.recv()
self.assertEqual(s,u.encode('utf8'))
self.assertEqual(s.decode('utf8'),u)
a.send_unicode(u,encoding='utf16')
s = b.recv_unicode(encoding='utf16')
self.assertEqual(s,u)
def test_send_multipart_check_type(self):
"check type on all frames in send_multipart"
a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
self.sockets.extend([a,b])
self.assertRaises(TypeError, a.send_multipart, [b'a', 5])
a.send_multipart([b'b'])
rcvd = self.recv_multipart(b)
self.assertEqual(rcvd, [b'b'])
@skip_pypy
def test_tracker(self):
"test the MessageTracker object for tracking when zmq is done with a buffer"
addr = 'tcp://127.0.0.1'
# get a port:
sock = socket.socket()
sock.bind(('127.0.0.1', 0))
port = sock.getsockname()[1]
iface = "%s:%i" % (addr, port)
sock.close()
time.sleep(0.1)
a = self.context.socket(zmq.PUSH)
b = self.context.socket(zmq.PULL)
self.sockets.extend([a,b])
a.connect(iface)
time.sleep(0.1)
p1 = a.send(b'something', copy=False, track=True)
assert isinstance(p1, zmq.MessageTracker)
assert p1 is zmq._FINISHED_TRACKER
# small message, should start done
assert p1.done
# disable zero-copy threshold
a.copy_threshold = 0
p2 = a.send_multipart([b'something', b'else'], copy=False, track=True)
assert isinstance(p2, zmq.MessageTracker)
assert not p2.done
b.bind(iface)
msg = self.recv_multipart(b)
for i in range(10):
if p1.done:
break
time.sleep(0.1)
self.assertEqual(p1.done, True)
self.assertEqual(msg, [b'something'])
msg = self.recv_multipart(b)
for i in range(10):
if p2.done:
break
time.sleep(0.1)
self.assertEqual(p2.done, True)
self.assertEqual(msg, [b'something', b'else'])
m = zmq.Frame(b"again", copy=False, track=True)
self.assertEqual(m.tracker.done, False)
p1 = a.send(m, copy=False)
p2 = a.send(m, copy=False)
self.assertEqual(m.tracker.done, False)
self.assertEqual(p1.done, False)
self.assertEqual(p2.done, False)
msg = self.recv_multipart(b)
self.assertEqual(m.tracker.done, False)
self.assertEqual(msg, [b'again'])
msg = self.recv_multipart(b)
self.assertEqual(m.tracker.done, False)
self.assertEqual(msg, [b'again'])
self.assertEqual(p1.done, False)
self.assertEqual(p2.done, False)
pm = m.tracker
del m
for i in range(10):
if p1.done:
break
time.sleep(0.1)
self.assertEqual(p1.done, True)
self.assertEqual(p2.done, True)
m = zmq.Frame(b'something', track=False)
self.assertRaises(ValueError, a.send, m, copy=False, track=True)
def test_close(self):
ctx = self.Context()
s = ctx.socket(zmq.PUB)
s.close()
self.assertRaisesErrno(zmq.ENOTSOCK, s.bind, b'')
self.assertRaisesErrno(zmq.ENOTSOCK, s.connect, b'')
self.assertRaisesErrno(zmq.ENOTSOCK, s.setsockopt, zmq.SUBSCRIBE, b'')
self.assertRaisesErrno(zmq.ENOTSOCK, s.send, b'asdf')
self.assertRaisesErrno(zmq.ENOTSOCK, s.recv)
del ctx
def test_attr(self):
"""set setting/getting sockopts as attributes"""
s = self.context.socket(zmq.DEALER)
self.sockets.append(s)
linger = 10
s.linger = linger
self.assertEqual(linger, s.linger)
self.assertEqual(linger, s.getsockopt(zmq.LINGER))
self.assertEqual(s.fd, s.getsockopt(zmq.FD))
def test_bad_attr(self):
s = self.context.socket(zmq.DEALER)
self.sockets.append(s)
try:
s.apple='foo'
except AttributeError:
pass
else:
self.fail("bad setattr should have raised AttributeError")
try:
s.apple
except AttributeError:
pass
else:
self.fail("bad getattr should have raised AttributeError")
def test_subclass(self):
"""subclasses can assign attributes"""
class S(zmq.Socket):
a = None
def __init__(self, *a, **kw):
self.a=-1
super(S, self).__init__(*a, **kw)
s = S(self.context, zmq.REP)
self.sockets.append(s)
self.assertEqual(s.a, -1)
s.a=1
self.assertEqual(s.a, 1)
a=s.a
self.assertEqual(a, 1)
def test_recv_multipart(self):
a,b = self.create_bound_pair()
msg = b'hi'
for i in range(3):
a.send(msg)
time.sleep(0.1)
for i in range(3):
self.assertEqual(self.recv_multipart(b), [msg])
def test_close_after_destroy(self):
"""s.close() after ctx.destroy() should be fine"""
ctx = self.Context()
s = ctx.socket(zmq.REP)
ctx.destroy()
# reaper is not instantaneous
time.sleep(1e-2)
s.close()
self.assertTrue(s.closed)
def test_poll(self):
a,b = self.create_bound_pair()
tic = time.time()
evt = a.poll(POLL_TIMEOUT)
self.assertEqual(evt, 0)
evt = a.poll(POLL_TIMEOUT, zmq.POLLOUT)
self.assertEqual(evt, zmq.POLLOUT)
msg = b'hi'
a.send(msg)
evt = b.poll(POLL_TIMEOUT)
self.assertEqual(evt, zmq.POLLIN)
msg2 = self.recv(b)
evt = b.poll(POLL_TIMEOUT)
self.assertEqual(evt, 0)
self.assertEqual(msg2, msg)
def test_ipc_path_max_length(self):
"""IPC_PATH_MAX_LEN is a sensible value"""
if zmq.IPC_PATH_MAX_LEN == 0:
raise SkipTest("IPC_PATH_MAX_LEN undefined")
msg = "Surprising value for IPC_PATH_MAX_LEN: %s" % zmq.IPC_PATH_MAX_LEN
self.assertTrue(zmq.IPC_PATH_MAX_LEN > 30, msg)
self.assertTrue(zmq.IPC_PATH_MAX_LEN < 1025, msg)
def test_ipc_path_max_length_msg(self):
if zmq.IPC_PATH_MAX_LEN == 0:
raise SkipTest("IPC_PATH_MAX_LEN undefined")
s = self.context.socket(zmq.PUB)
self.sockets.append(s)
try:
s.bind('ipc://{0}'.format('a' * (zmq.IPC_PATH_MAX_LEN + 1)))
except zmq.ZMQError as e:
self.assertTrue(str(zmq.IPC_PATH_MAX_LEN) in e.strerror)
@mark.skipif(windows, reason="ipc not supported on Windows.")
def test_ipc_path_no_such_file_or_directory_message(self):
"""Display the ipc path in case of an ENOENT exception"""
s = self.context.socket(zmq.PUB)
self.sockets.append(s)
invalid_path = '/foo/bar'
with pytest.raises(zmq.ZMQError) as error:
s.bind('ipc://{0}'.format(invalid_path))
assert error.value.errno == errno.ENOENT
error_message = str(error.value)
assert invalid_path in error_message
assert "no such file or directory" in error_message.lower()
def test_hwm(self):
zmq3 = zmq.zmq_version_info()[0] >= 3
for stype in (zmq.PUB, zmq.ROUTER, zmq.SUB, zmq.REQ, zmq.DEALER):
s = self.context.socket(stype)
s.hwm = 100
self.assertEqual(s.hwm, 100)
if zmq3:
try:
self.assertEqual(s.sndhwm, 100)
except AttributeError:
pass
try:
self.assertEqual(s.rcvhwm, 100)
except AttributeError:
pass
s.close()
def test_copy(self):
s = self.socket(zmq.PUB)
scopy = copy.copy(s)
sdcopy = copy.deepcopy(s)
self.assert_(scopy._shadow)
self.assert_(sdcopy._shadow)
self.assertEqual(s.underlying, scopy.underlying)
self.assertEqual(s.underlying, sdcopy.underlying)
s.close()
def test_send_buffer(self):
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
for buffer_type in (memoryview, bytearray):
rawbytes = str(buffer_type).encode('ascii')
msg = buffer_type(rawbytes)
a.send(msg)
recvd = b.recv()
assert recvd == rawbytes
def test_shadow(self):
p = self.socket(zmq.PUSH)
p.bind("tcp://127.0.0.1:5555")
p2 = zmq.Socket.shadow(p.underlying)
self.assertEqual(p.underlying, p2.underlying)
s = self.socket(zmq.PULL)
s2 = zmq.Socket.shadow(s.underlying)
self.assertNotEqual(s.underlying, p.underlying)
self.assertEqual(s.underlying, s2.underlying)
s2.connect("tcp://127.0.0.1:5555")
sent = b'hi'
p2.send(sent)
rcvd = self.recv(s2)
self.assertEqual(rcvd, sent)
def test_shadow_pyczmq(self):
try:
from pyczmq import zctx, zsocket
except Exception:
raise SkipTest("Requires pyczmq")
ctx = zctx.new()
ca = zsocket.new(ctx, zmq.PUSH)
cb = zsocket.new(ctx, zmq.PULL)
a = zmq.Socket.shadow(ca)
b = zmq.Socket.shadow(cb)
a.bind("inproc://a")
b.connect("inproc://a")
a.send(b'hi')
rcvd = self.recv(b)
self.assertEqual(rcvd, b'hi')
def test_subscribe_method(self):
pub, sub = self.create_bound_pair(zmq.PUB, zmq.SUB)
sub.subscribe('prefix')
sub.subscribe = 'c'
p = zmq.Poller()
p.register(sub, zmq.POLLIN)
# wait for subscription handshake
for i in range(100):
pub.send(b'canary')
events = p.poll(250)
if events:
break
self.recv(sub)
pub.send(b'prefixmessage')
msg = self.recv(sub)
self.assertEqual(msg, b'prefixmessage')
sub.unsubscribe('prefix')
pub.send(b'prefixmessage')
events = p.poll(1000)
self.assertEqual(events, [])
# Travis can't handle how much memory PyPy uses on this test
@mark.skipif(
(
pypy and on_travis
) or (
sys.maxsize < 2**32
) or (
windows
),
reason="only run on 64b and not on Travis."
)
@mark.large
def test_large_send(self):
c = os.urandom(1)
N = 2**31 + 1
try:
buf = c * N
except MemoryError as e:
raise SkipTest("Not enough memory: %s" % e)
a, b = self.create_bound_pair()
try:
a.send(buf, copy=False)
rcvd = b.recv(copy=False)
except MemoryError as e:
raise SkipTest("Not enough memory: %s" % e)
# sample the front and back of the received message
# without checking the whole content
# Python 2: items in memoryview are bytes
# Python 3: items im memoryview are int
byte = c if sys.version_info < (3,) else ord(c)
view = memoryview(rcvd)
assert len(view) == N
assert view[0] == byte
assert view[-1] == byte
def test_custom_serialize(self):
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
def serialize(msg):
frames = []
frames.extend(msg.get('identities', []))
content = json.dumps(msg['content']).encode('utf8')
frames.append(content)
return frames
def deserialize(frames):
identities = frames[:-1]
content = json.loads(frames[-1].decode('utf8'))
return {
'identities': identities,
'content': content,
}
msg = {
'content': {
'a': 5,
'b': 'bee',
}
}
a.send_serialized(msg, serialize)
recvd = b.recv_serialized(deserialize)
assert recvd['content'] == msg['content']
assert recvd['identities']
# bounce back, tests identities
b.send_serialized(recvd, serialize)
r2 = a.recv_serialized(deserialize)
assert r2['content'] == msg['content']
assert not r2['identities']
if have_gevent and not windows:
import gevent
class TestSocketGreen(GreenTest, TestSocket):
test_bad_attr = GreenTest.skip_green
test_close_after_destroy = GreenTest.skip_green
def test_timeout(self):
a,b = self.create_bound_pair()
g = gevent.spawn_later(0.5, lambda: a.send(b'hi'))
timeout = gevent.Timeout(0.1)
timeout.start()
self.assertRaises(gevent.Timeout, b.recv)
g.kill()
@mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
def test_warn_set_timeo(self):
s = self.context.socket(zmq.REQ)
with warnings.catch_warnings(record=True) as w:
s.rcvtimeo = 5
s.close()
self.assertEqual(len(w), 1)
self.assertEqual(w[0].category, UserWarning)
@mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
def test_warn_get_timeo(self):
s = self.context.socket(zmq.REQ)
with warnings.catch_warnings(record=True) as w:
s.sndtimeo
s.close()
self.assertEqual(len(w), 1)
self.assertEqual(w[0].category, UserWarning)

View file

@ -0,0 +1,8 @@
from zmq.ssh.tunnel import select_random_ports
def test_random_ports():
for i in range(4096):
ports = select_random_ports(10)
assert len(ports) == 10
for p in ports:
assert ports.count(p) == 1

View file

@ -0,0 +1,44 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from unittest import TestCase
import zmq
from zmq.sugar import version
class TestVersion(TestCase):
def test_pyzmq_version(self):
vs = zmq.pyzmq_version()
vs2 = zmq.__version__
self.assertTrue(isinstance(vs, str))
if zmq.__revision__:
self.assertEqual(vs, '@'.join(vs2, zmq.__revision__))
else:
self.assertEqual(vs, vs2)
if version.VERSION_EXTRA:
self.assertTrue(version.VERSION_EXTRA in vs)
self.assertTrue(version.VERSION_EXTRA in vs2)
def test_pyzmq_version_info(self):
info = zmq.pyzmq_version_info()
self.assertTrue(isinstance(info, tuple))
for n in info[:3]:
self.assertTrue(isinstance(n, int))
if version.VERSION_EXTRA:
self.assertEqual(len(info), 4)
self.assertEqual(info[-1], float('inf'))
else:
self.assertEqual(len(info), 3)
def test_zmq_version_info(self):
info = zmq.zmq_version_info()
self.assertTrue(isinstance(info, tuple))
for n in info[:3]:
self.assertTrue(isinstance(n, int))
def test_zmq_version(self):
v = zmq.zmq_version()
self.assertTrue(isinstance(v, str))

View file

@ -0,0 +1,63 @@
from __future__ import print_function
import os
import time
import sys
from functools import wraps
from pytest import mark
from zmq.tests import BaseZMQTestCase
from zmq.utils.win32 import allow_interrupt
def count_calls(f):
@wraps(f)
def _(*args, **kwds):
try:
return f(*args, **kwds)
finally:
_.__calls__ += 1
_.__calls__ = 0
return _
@mark.new_console
class TestWindowsConsoleControlHandler(BaseZMQTestCase):
@mark.new_console
@mark.skipif(
not sys.platform.startswith('win'),
reason='Windows only test')
def test_handler(self):
@count_calls
def interrupt_polling():
print('Caught CTRL-C!')
from ctypes import windll
from ctypes.wintypes import BOOL, DWORD
kernel32 = windll.LoadLibrary('kernel32')
# <http://msdn.microsoft.com/en-us/library/ms683155.aspx>
GenerateConsoleCtrlEvent = kernel32.GenerateConsoleCtrlEvent
GenerateConsoleCtrlEvent.argtypes = (DWORD, DWORD)
GenerateConsoleCtrlEvent.restype = BOOL
# Simulate CTRL-C event while handler is active.
try:
with allow_interrupt(interrupt_polling) as context:
result = GenerateConsoleCtrlEvent(0, 0)
# Sleep so that we give time to the handler to
# capture the Ctrl-C event.
time.sleep(0.5)
except KeyboardInterrupt:
pass
else:
if result == 0:
raise WindowsError()
else:
self.fail('Expecting `KeyboardInterrupt` exception!')
# Make sure our handler was called.
self.assertEqual(interrupt_polling.__calls__, 1)

View file

@ -0,0 +1,63 @@
# -*- coding: utf8 -*-
"""Test Z85 encoding
confirm values and roundtrip with test values from the reference implementation.
"""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from unittest import TestCase
from zmq.utils import z85
class TestZ85(TestCase):
def test_client_public(self):
client_public = \
b"\xBB\x88\x47\x1D\x65\xE2\x65\x9B" \
b"\x30\xC5\x5A\x53\x21\xCE\xBB\x5A" \
b"\xAB\x2B\x70\xA3\x98\x64\x5C\x26" \
b"\xDC\xA2\xB2\xFC\xB4\x3F\xC5\x18"
encoded = z85.encode(client_public)
self.assertEqual(encoded, b"Yne@$w-vo<fVvi]a<NY6T1ed:M$fCG*[IaLV{hID")
decoded = z85.decode(encoded)
self.assertEqual(decoded, client_public)
def test_client_secret(self):
client_secret = \
b"\x7B\xB8\x64\xB4\x89\xAF\xA3\x67" \
b"\x1F\xBE\x69\x10\x1F\x94\xB3\x89" \
b"\x72\xF2\x48\x16\xDF\xB0\x1B\x51" \
b"\x65\x6B\x3F\xEC\x8D\xFD\x08\x88"
encoded = z85.encode(client_secret)
self.assertEqual(encoded, b"D:)Q[IlAW!ahhC2ac:9*A}h:p?([4%wOTJ%JR%cs")
decoded = z85.decode(encoded)
self.assertEqual(decoded, client_secret)
def test_server_public(self):
server_public = \
b"\x54\xFC\xBA\x24\xE9\x32\x49\x96" \
b"\x93\x16\xFB\x61\x7C\x87\x2B\xB0" \
b"\xC1\xD1\xFF\x14\x80\x04\x27\xC5" \
b"\x94\xCB\xFA\xCF\x1B\xC2\xD6\x52"
encoded = z85.encode(server_public)
self.assertEqual(encoded, b"rq:rM>}U?@Lns47E1%kR.o@n%FcmmsL/@{H8]yf7")
decoded = z85.decode(encoded)
self.assertEqual(decoded, server_public)
def test_server_secret(self):
server_secret = \
b"\x8E\x0B\xDD\x69\x76\x28\xB9\x1D" \
b"\x8F\x24\x55\x87\xEE\x95\xC5\xB0" \
b"\x4D\x48\x96\x3F\x79\x25\x98\x77" \
b"\xB4\x9C\xD9\x06\x3A\xEA\xD3\xB7"
encoded = z85.encode(server_secret)
self.assertEqual(encoded, b"JTKVSB%%)wK0E.X)V>+}o?pNmC{O&4W4b!Ni{Lh6")
decoded = z85.decode(encoded)
self.assertEqual(decoded, server_secret)

View file

@ -0,0 +1,79 @@
# -*- coding: utf8 -*-
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
from __future__ import absolute_import
try:
import asyncio
except ImportError:
asyncio = None
from unittest import TestCase
import pytest
import zmq
try:
import tornado
from tornado import gen
from zmq.eventloop import ioloop, zmqstream
except ImportError:
tornado = None
class TestZMQStream(TestCase):
def setUp(self):
if tornado is None:
pytest.skip()
if asyncio:
asyncio.set_event_loop(asyncio.new_event_loop())
self.context = zmq.Context()
self.loop = ioloop.IOLoop()
self.loop.make_current()
self.push = zmqstream.ZMQStream(self.context.socket(zmq.PUSH))
self.pull = zmqstream.ZMQStream(self.context.socket(zmq.PULL))
port = self.push.bind_to_random_port('tcp://127.0.0.1')
self.pull.connect('tcp://127.0.0.1:%i' % port)
self.stream = self.push
def tearDown(self):
self.loop.close(all_fds=True)
self.context.term()
ioloop.IOLoop.clear_current()
def run_until_timeout(self, timeout=10):
timed_out = []
@gen.coroutine
def sleep_timeout():
yield gen.sleep(timeout)
timed_out[:] = ['timed out']
self.loop.stop()
self.loop.add_callback(lambda : sleep_timeout())
self.loop.start()
assert not timed_out
def test_callable_check(self):
"""Ensure callable check works (py3k)."""
self.stream.on_send(lambda *args: None)
self.stream.on_recv(lambda *args: None)
self.assertRaises(AssertionError, self.stream.on_recv, 1)
self.assertRaises(AssertionError, self.stream.on_send, 1)
self.assertRaises(AssertionError, self.stream.on_recv, zmq)
def test_on_recv_basic(self):
sent = [b'basic']
def callback(msg):
assert msg == sent
self.loop.stop()
self.loop.add_callback(lambda : self.push.send_multipart(sent))
self.pull.on_recv(callback)
self.run_until_timeout()
def test_on_recv_wake(self):
sent = [b'wake']
def callback(msg):
assert msg == sent
self.loop.stop()
self.pull.on_recv(callback)
self.loop.call_later(1, lambda : self.push.send_multipart(sent))
self.run_until_timeout()