Uploaded Test files
This commit is contained in:
parent
f584ad9d97
commit
2e81cb7d99
16627 changed files with 2065359 additions and 102444 deletions
192
venv/Lib/site-packages/zmq/tests/__init__.py
Normal file
192
venv/Lib/site-packages/zmq/tests/__init__.py
Normal file
|
@ -0,0 +1,192 @@
|
|||
# Copyright (c) PyZMQ Developers.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import sys
|
||||
import time
|
||||
from threading import Thread
|
||||
|
||||
from unittest import TestCase
|
||||
try:
|
||||
from unittest import SkipTest
|
||||
except ImportError:
|
||||
from unittest2 import SkipTest
|
||||
|
||||
from pytest import mark
|
||||
import zmq
|
||||
from zmq.utils import jsonapi
|
||||
|
||||
try:
|
||||
import gevent
|
||||
from zmq import green as gzmq
|
||||
have_gevent = True
|
||||
except ImportError:
|
||||
have_gevent = False
|
||||
|
||||
|
||||
PYPY = 'PyPy' in sys.version
|
||||
|
||||
#-----------------------------------------------------------------------------
|
||||
# skip decorators (directly from unittest)
|
||||
#-----------------------------------------------------------------------------
|
||||
|
||||
_id = lambda x: x
|
||||
|
||||
skip_pypy = mark.skipif(PYPY, reason="Doesn't work on PyPy")
|
||||
require_zmq_4 = mark.skipif(zmq.zmq_version_info() < (4,), reason="requires zmq >= 4")
|
||||
|
||||
#-----------------------------------------------------------------------------
|
||||
# Base test class
|
||||
#-----------------------------------------------------------------------------
|
||||
|
||||
class BaseZMQTestCase(TestCase):
|
||||
green = False
|
||||
teardown_timeout = 10
|
||||
|
||||
@property
|
||||
def Context(self):
|
||||
if self.green:
|
||||
return gzmq.Context
|
||||
else:
|
||||
return zmq.Context
|
||||
|
||||
def socket(self, socket_type):
|
||||
s = self.context.socket(socket_type)
|
||||
self.sockets.append(s)
|
||||
return s
|
||||
|
||||
def setUp(self):
|
||||
super(BaseZMQTestCase, self).setUp()
|
||||
if self.green and not have_gevent:
|
||||
raise SkipTest("requires gevent")
|
||||
self.context = self.Context.instance()
|
||||
self.sockets = []
|
||||
|
||||
def tearDown(self):
|
||||
contexts = set([self.context])
|
||||
while self.sockets:
|
||||
sock = self.sockets.pop()
|
||||
contexts.add(sock.context) # in case additional contexts are created
|
||||
sock.close(0)
|
||||
for ctx in contexts:
|
||||
t = Thread(target=ctx.term)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
t.join(timeout=self.teardown_timeout)
|
||||
if t.is_alive():
|
||||
# reset Context.instance, so the failure to term doesn't corrupt subsequent tests
|
||||
zmq.sugar.context.Context._instance = None
|
||||
raise RuntimeError("context could not terminate, open sockets likely remain in test")
|
||||
super(BaseZMQTestCase, self).tearDown()
|
||||
|
||||
def create_bound_pair(self, type1=zmq.PAIR, type2=zmq.PAIR, interface='tcp://127.0.0.1'):
|
||||
"""Create a bound socket pair using a random port."""
|
||||
s1 = self.context.socket(type1)
|
||||
s1.setsockopt(zmq.LINGER, 0)
|
||||
port = s1.bind_to_random_port(interface)
|
||||
s2 = self.context.socket(type2)
|
||||
s2.setsockopt(zmq.LINGER, 0)
|
||||
s2.connect('%s:%s' % (interface, port))
|
||||
self.sockets.extend([s1,s2])
|
||||
return s1, s2
|
||||
|
||||
def ping_pong(self, s1, s2, msg):
|
||||
s1.send(msg)
|
||||
msg2 = s2.recv()
|
||||
s2.send(msg2)
|
||||
msg3 = s1.recv()
|
||||
return msg3
|
||||
|
||||
def ping_pong_json(self, s1, s2, o):
|
||||
if jsonapi.jsonmod is None:
|
||||
raise SkipTest("No json library")
|
||||
s1.send_json(o)
|
||||
o2 = s2.recv_json()
|
||||
s2.send_json(o2)
|
||||
o3 = s1.recv_json()
|
||||
return o3
|
||||
|
||||
def ping_pong_pyobj(self, s1, s2, o):
|
||||
s1.send_pyobj(o)
|
||||
o2 = s2.recv_pyobj()
|
||||
s2.send_pyobj(o2)
|
||||
o3 = s1.recv_pyobj()
|
||||
return o3
|
||||
|
||||
def assertRaisesErrno(self, errno, func, *args, **kwargs):
|
||||
try:
|
||||
func(*args, **kwargs)
|
||||
except zmq.ZMQError as e:
|
||||
self.assertEqual(e.errno, errno, "wrong error raised, expected '%s' \
|
||||
got '%s'" % (zmq.ZMQError(errno), zmq.ZMQError(e.errno)))
|
||||
else:
|
||||
self.fail("Function did not raise any error")
|
||||
|
||||
def _select_recv(self, multipart, socket, **kwargs):
|
||||
"""call recv[_multipart] in a way that raises if there is nothing to receive"""
|
||||
if zmq.zmq_version_info() >= (3,1,0):
|
||||
# zmq 3.1 has a bug, where poll can return false positives,
|
||||
# so we wait a little bit just in case
|
||||
# See LIBZMQ-280 on JIRA
|
||||
time.sleep(0.1)
|
||||
|
||||
r,w,x = zmq.select([socket], [], [], timeout=kwargs.pop('timeout', 5))
|
||||
assert len(r) > 0, "Should have received a message"
|
||||
kwargs['flags'] = zmq.DONTWAIT | kwargs.get('flags', 0)
|
||||
|
||||
recv = socket.recv_multipart if multipart else socket.recv
|
||||
return recv(**kwargs)
|
||||
|
||||
def recv(self, socket, **kwargs):
|
||||
"""call recv in a way that raises if there is nothing to receive"""
|
||||
return self._select_recv(False, socket, **kwargs)
|
||||
|
||||
def recv_multipart(self, socket, **kwargs):
|
||||
"""call recv_multipart in a way that raises if there is nothing to receive"""
|
||||
return self._select_recv(True, socket, **kwargs)
|
||||
|
||||
|
||||
class PollZMQTestCase(BaseZMQTestCase):
|
||||
pass
|
||||
|
||||
class GreenTest:
|
||||
"""Mixin for making green versions of test classes"""
|
||||
green = True
|
||||
teardown_timeout = 10
|
||||
|
||||
def assertRaisesErrno(self, errno, func, *args, **kwargs):
|
||||
if errno == zmq.EAGAIN:
|
||||
raise SkipTest("Skipping because we're green.")
|
||||
try:
|
||||
func(*args, **kwargs)
|
||||
except zmq.ZMQError:
|
||||
e = sys.exc_info()[1]
|
||||
self.assertEqual(e.errno, errno, "wrong error raised, expected '%s' \
|
||||
got '%s'" % (zmq.ZMQError(errno), zmq.ZMQError(e.errno)))
|
||||
else:
|
||||
self.fail("Function did not raise any error")
|
||||
|
||||
def tearDown(self):
|
||||
contexts = set([self.context])
|
||||
while self.sockets:
|
||||
sock = self.sockets.pop()
|
||||
contexts.add(sock.context) # in case additional contexts are created
|
||||
sock.close()
|
||||
try:
|
||||
gevent.joinall(
|
||||
[gevent.spawn(ctx.term) for ctx in contexts],
|
||||
timeout=self.teardown_timeout,
|
||||
raise_error=True,
|
||||
)
|
||||
except gevent.Timeout:
|
||||
raise RuntimeError("context could not terminate, open sockets likely remain in test")
|
||||
|
||||
def skip_green(self):
|
||||
raise SkipTest("Skipping because we are green")
|
||||
|
||||
def skip_green(f):
|
||||
def skipping_test(self, *args, **kwargs):
|
||||
if self.green:
|
||||
raise SkipTest("Skipping because we are green")
|
||||
else:
|
||||
return f(self, *args, **kwargs)
|
||||
return skipping_test
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
0
venv/Lib/site-packages/zmq/tests/asyncio/__init__.py
Normal file
0
venv/Lib/site-packages/zmq/tests/asyncio/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
481
venv/Lib/site-packages/zmq/tests/asyncio/_test_asyncio.py
Normal file
481
venv/Lib/site-packages/zmq/tests/asyncio/_test_asyncio.py
Normal file
|
@ -0,0 +1,481 @@
|
|||
"""Test asyncio support"""
|
||||
# Copyright (c) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import json
|
||||
from multiprocessing import Process
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from pytest import mark
|
||||
|
||||
import zmq
|
||||
from zmq.utils.strtypes import u
|
||||
|
||||
try:
|
||||
import asyncio
|
||||
import zmq.asyncio as zaio
|
||||
from zmq.auth.asyncio import AsyncioAuthenticator
|
||||
except ImportError:
|
||||
if sys.version_info >= (3,4):
|
||||
raise
|
||||
asyncio = None
|
||||
|
||||
from concurrent.futures import CancelledError
|
||||
from zmq.tests import BaseZMQTestCase, SkipTest
|
||||
from zmq.tests.test_auth import TestThreadAuthentication
|
||||
|
||||
|
||||
class ProcessForTeardownTest(Process):
|
||||
def __init__(self, event_loop_policy_class):
|
||||
Process.__init__(self)
|
||||
self.event_loop_policy_class = event_loop_policy_class
|
||||
|
||||
def run(self):
|
||||
"""Leave context, socket and event loop upon implicit disposal"""
|
||||
asyncio.set_event_loop_policy(self.event_loop_policy_class())
|
||||
|
||||
actx = zaio.Context.instance()
|
||||
socket = actx.socket(zmq.PAIR)
|
||||
socket.bind_to_random_port('tcp://127.0.0.1')
|
||||
|
||||
@asyncio.coroutine
|
||||
def never_ending_task(socket):
|
||||
yield from socket.recv() # never ever receive anything
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
coro = asyncio.wait_for(never_ending_task(socket), timeout=1)
|
||||
try:
|
||||
loop.run_until_complete(coro)
|
||||
except asyncio.TimeoutError:
|
||||
pass # expected timeout
|
||||
else:
|
||||
assert False, "never_ending_task was completed unexpectedly"
|
||||
|
||||
|
||||
class TestAsyncIOSocket(BaseZMQTestCase):
|
||||
if asyncio is not None:
|
||||
Context = zaio.Context
|
||||
|
||||
def setUp(self):
|
||||
if asyncio is None:
|
||||
raise SkipTest()
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
super(TestAsyncIOSocket, self).setUp()
|
||||
|
||||
def tearDown(self):
|
||||
self.loop.close()
|
||||
super().tearDown()
|
||||
|
||||
def test_socket_class(self):
|
||||
s = self.context.socket(zmq.PUSH)
|
||||
assert isinstance(s, zaio.Socket)
|
||||
s.close()
|
||||
|
||||
def test_instance_subclass_first(self):
|
||||
actx = zmq.asyncio.Context.instance()
|
||||
ctx = zmq.Context.instance()
|
||||
ctx.term()
|
||||
actx.term()
|
||||
assert type(ctx) is zmq.Context
|
||||
assert type(actx) is zmq.asyncio.Context
|
||||
|
||||
def test_instance_subclass_second(self):
|
||||
ctx = zmq.Context.instance()
|
||||
actx = zmq.asyncio.Context.instance()
|
||||
ctx.term()
|
||||
actx.term()
|
||||
assert type(ctx) is zmq.Context
|
||||
assert type(actx) is zmq.asyncio.Context
|
||||
|
||||
def test_recv_multipart(self):
|
||||
@asyncio.coroutine
|
||||
def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.recv_multipart()
|
||||
assert not f.done()
|
||||
yield from a.send(b'hi')
|
||||
recvd = yield from f
|
||||
self.assertEqual(recvd, [b'hi'])
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
def test_recv(self):
|
||||
@asyncio.coroutine
|
||||
def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f1 = b.recv()
|
||||
f2 = b.recv()
|
||||
assert not f1.done()
|
||||
assert not f2.done()
|
||||
yield from a.send_multipart([b'hi', b'there'])
|
||||
recvd = yield from f2
|
||||
assert f1.done()
|
||||
self.assertEqual(f1.result(), b'hi')
|
||||
self.assertEqual(recvd, b'there')
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
@mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
|
||||
def test_recv_timeout(self):
|
||||
@asyncio.coroutine
|
||||
def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
b.rcvtimeo = 100
|
||||
f1 = b.recv()
|
||||
b.rcvtimeo = 1000
|
||||
f2 = b.recv_multipart()
|
||||
with self.assertRaises(zmq.Again):
|
||||
yield from f1
|
||||
yield from a.send_multipart([b'hi', b'there'])
|
||||
recvd = yield from f2
|
||||
assert f2.done()
|
||||
self.assertEqual(recvd, [b'hi', b'there'])
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
@mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
|
||||
def test_send_timeout(self):
|
||||
@asyncio.coroutine
|
||||
def test():
|
||||
s = self.socket(zmq.PUSH)
|
||||
s.sndtimeo = 100
|
||||
with self.assertRaises(zmq.Again):
|
||||
yield from s.send(b'not going anywhere')
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
def test_recv_string(self):
|
||||
@asyncio.coroutine
|
||||
def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.recv_string()
|
||||
assert not f.done()
|
||||
msg = u('πøøπ')
|
||||
yield from a.send_string(msg)
|
||||
recvd = yield from f
|
||||
assert f.done()
|
||||
self.assertEqual(f.result(), msg)
|
||||
self.assertEqual(recvd, msg)
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
def test_recv_json(self):
|
||||
@asyncio.coroutine
|
||||
def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.recv_json()
|
||||
assert not f.done()
|
||||
obj = dict(a=5)
|
||||
yield from a.send_json(obj)
|
||||
recvd = yield from f
|
||||
assert f.done()
|
||||
self.assertEqual(f.result(), obj)
|
||||
self.assertEqual(recvd, obj)
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
def test_recv_json_cancelled(self):
|
||||
@asyncio.coroutine
|
||||
def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.recv_json()
|
||||
assert not f.done()
|
||||
f.cancel()
|
||||
# cycle eventloop to allow cancel events to fire
|
||||
yield from asyncio.sleep(0)
|
||||
obj = dict(a=5)
|
||||
yield from a.send_json(obj)
|
||||
# CancelledError change in 3.8 https://bugs.python.org/issue32528
|
||||
if sys.version_info < (3, 8):
|
||||
with pytest.raises(CancelledError):
|
||||
recvd = yield from f
|
||||
else:
|
||||
with pytest.raises(asyncio.exceptions.CancelledError):
|
||||
recvd = yield from f
|
||||
assert f.done()
|
||||
# give it a chance to incorrectly consume the event
|
||||
events = yield from b.poll(timeout=5)
|
||||
assert events
|
||||
yield from asyncio.sleep(0)
|
||||
# make sure cancelled recv didn't eat up event
|
||||
f = b.recv_json()
|
||||
recvd = yield from asyncio.wait_for(f, timeout=5)
|
||||
assert recvd == obj
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
def test_recv_pyobj(self):
|
||||
@asyncio.coroutine
|
||||
def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.recv_pyobj()
|
||||
assert not f.done()
|
||||
obj = dict(a=5)
|
||||
yield from a.send_pyobj(obj)
|
||||
recvd = yield from f
|
||||
assert f.done()
|
||||
self.assertEqual(f.result(), obj)
|
||||
self.assertEqual(recvd, obj)
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
|
||||
def test_custom_serialize(self):
|
||||
def serialize(msg):
|
||||
frames = []
|
||||
frames.extend(msg.get('identities', []))
|
||||
content = json.dumps(msg['content']).encode('utf8')
|
||||
frames.append(content)
|
||||
return frames
|
||||
|
||||
def deserialize(frames):
|
||||
identities = frames[:-1]
|
||||
content = json.loads(frames[-1].decode('utf8'))
|
||||
return {
|
||||
'identities': identities,
|
||||
'content': content,
|
||||
}
|
||||
|
||||
@asyncio.coroutine
|
||||
def test():
|
||||
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
|
||||
|
||||
msg = {
|
||||
'content': {
|
||||
'a': 5,
|
||||
'b': 'bee',
|
||||
}
|
||||
}
|
||||
yield from a.send_serialized(msg, serialize)
|
||||
recvd = yield from b.recv_serialized(deserialize)
|
||||
assert recvd['content'] == msg['content']
|
||||
assert recvd['identities']
|
||||
# bounce back, tests identities
|
||||
yield from b.send_serialized(recvd, serialize)
|
||||
r2 = yield from a.recv_serialized(deserialize)
|
||||
assert r2['content'] == msg['content']
|
||||
assert not r2['identities']
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
def test_custom_serialize_error(self):
|
||||
@asyncio.coroutine
|
||||
def test():
|
||||
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
|
||||
|
||||
msg = {
|
||||
'content': {
|
||||
'a': 5,
|
||||
'b': 'bee',
|
||||
}
|
||||
}
|
||||
with pytest.raises(TypeError):
|
||||
yield from a.send_serialized(json, json.dumps)
|
||||
|
||||
yield from a.send(b'not json')
|
||||
with pytest.raises(TypeError):
|
||||
recvd = yield from b.recv_serialized(json.loads)
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
def test_recv_dontwait(self):
|
||||
@asyncio.coroutine
|
||||
def test():
|
||||
push, pull = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = pull.recv(zmq.DONTWAIT)
|
||||
with self.assertRaises(zmq.Again):
|
||||
yield from f
|
||||
yield from push.send(b'ping')
|
||||
yield from pull.poll() # ensure message will be waiting
|
||||
f = pull.recv(zmq.DONTWAIT)
|
||||
assert f.done()
|
||||
msg = yield from f
|
||||
self.assertEqual(msg, b'ping')
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
def test_recv_cancel(self):
|
||||
@asyncio.coroutine
|
||||
def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f1 = b.recv()
|
||||
f2 = b.recv_multipart()
|
||||
assert f1.cancel()
|
||||
assert f1.done()
|
||||
assert not f2.done()
|
||||
yield from a.send_multipart([b'hi', b'there'])
|
||||
recvd = yield from f2
|
||||
assert f1.cancelled()
|
||||
assert f2.done()
|
||||
self.assertEqual(recvd, [b'hi', b'there'])
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
def test_poll(self):
|
||||
@asyncio.coroutine
|
||||
def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.poll(timeout=0)
|
||||
yield from asyncio.sleep(0)
|
||||
self.assertEqual(f.result(), 0)
|
||||
|
||||
f = b.poll(timeout=1)
|
||||
assert not f.done()
|
||||
evt = yield from f
|
||||
|
||||
self.assertEqual(evt, 0)
|
||||
|
||||
f = b.poll(timeout=1000)
|
||||
assert not f.done()
|
||||
yield from a.send_multipart([b'hi', b'there'])
|
||||
evt = yield from f
|
||||
self.assertEqual(evt, zmq.POLLIN)
|
||||
recvd = yield from b.recv_multipart()
|
||||
self.assertEqual(recvd, [b'hi', b'there'])
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
def test_poll_base_socket(self):
|
||||
@asyncio.coroutine
|
||||
def test():
|
||||
ctx = zmq.Context()
|
||||
url = 'inproc://test'
|
||||
a = ctx.socket(zmq.PUSH)
|
||||
b = ctx.socket(zmq.PULL)
|
||||
self.sockets.extend([a, b])
|
||||
a.bind(url)
|
||||
b.connect(url)
|
||||
|
||||
poller = zaio.Poller()
|
||||
poller.register(b, zmq.POLLIN)
|
||||
|
||||
f = poller.poll(timeout=1000)
|
||||
assert not f.done()
|
||||
a.send_multipart([b'hi', b'there'])
|
||||
evt = yield from f
|
||||
self.assertEqual(evt, [(b, zmq.POLLIN)])
|
||||
recvd = b.recv_multipart()
|
||||
self.assertEqual(recvd, [b'hi', b'there'])
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
def test_poll_on_closed_socket(self):
|
||||
@asyncio.coroutine
|
||||
def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
|
||||
f = b.poll(timeout=1)
|
||||
b.close()
|
||||
|
||||
# The test might stall if we try to yield from f directly so instead just make a few
|
||||
# passes through the event loop to schedule and execute all callbacks
|
||||
for _ in range(5):
|
||||
yield from asyncio.sleep(0)
|
||||
if f.cancelled():
|
||||
break
|
||||
assert f.cancelled()
|
||||
|
||||
self.loop.run_until_complete(test())
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith('win'),
|
||||
reason='Windows does not support polling on files')
|
||||
def test_poll_raw(self):
|
||||
@asyncio.coroutine
|
||||
def test():
|
||||
p = zaio.Poller()
|
||||
# make a pipe
|
||||
r, w = os.pipe()
|
||||
r = os.fdopen(r, 'rb')
|
||||
w = os.fdopen(w, 'wb')
|
||||
|
||||
# POLLOUT
|
||||
p.register(r, zmq.POLLIN)
|
||||
p.register(w, zmq.POLLOUT)
|
||||
evts = yield from p.poll(timeout=1)
|
||||
evts = dict(evts)
|
||||
assert r.fileno() not in evts
|
||||
assert w.fileno() in evts
|
||||
assert evts[w.fileno()] == zmq.POLLOUT
|
||||
|
||||
# POLLIN
|
||||
p.unregister(w)
|
||||
w.write(b'x')
|
||||
w.flush()
|
||||
evts = yield from p.poll(timeout=1000)
|
||||
evts = dict(evts)
|
||||
assert r.fileno() in evts
|
||||
assert evts[r.fileno()] == zmq.POLLIN
|
||||
assert r.read(1) == b'x'
|
||||
r.close()
|
||||
w.close()
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(test())
|
||||
|
||||
def test_shadow(self):
|
||||
@asyncio.coroutine
|
||||
def test():
|
||||
ctx = zmq.Context()
|
||||
s = ctx.socket(zmq.PULL)
|
||||
async_s = zaio.Socket(s)
|
||||
assert isinstance(async_s, self.socket_class)
|
||||
|
||||
def test_process_teardown(self):
|
||||
event_loop_policy_class = type(asyncio.get_event_loop_policy())
|
||||
proc = ProcessForTeardownTest(event_loop_policy_class)
|
||||
proc.start()
|
||||
try:
|
||||
proc.join(10) # starting new Python process may cost a lot
|
||||
self.assertEqual(proc.exitcode, 0,
|
||||
"Python process died with code %d" % proc.exitcode
|
||||
if proc.exitcode else "process teardown hangs")
|
||||
finally:
|
||||
proc.terminate()
|
||||
|
||||
|
||||
class TestAsyncioAuthentication(TestThreadAuthentication):
|
||||
"""Test authentication running in a asyncio task"""
|
||||
|
||||
if asyncio is not None:
|
||||
Context = zaio.Context
|
||||
|
||||
def shortDescription(self):
|
||||
"""Rewrite doc strings from TestThreadAuthentication from
|
||||
'threaded' to 'asyncio'.
|
||||
"""
|
||||
doc = self._testMethodDoc
|
||||
if doc:
|
||||
doc = doc.split("\n")[0].strip()
|
||||
if doc.startswith('threaded auth'):
|
||||
doc = doc.replace('threaded auth', 'asyncio auth')
|
||||
return doc
|
||||
|
||||
def setUp(self):
|
||||
if asyncio is None:
|
||||
raise SkipTest()
|
||||
self.loop = zaio.ZMQEventLoop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
super().setUp()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
self.loop.close()
|
||||
|
||||
def make_auth(self):
|
||||
return AsyncioAuthenticator(self.context)
|
||||
|
||||
def can_connect(self, server, client):
|
||||
"""Check if client can connect to server using tcp transport"""
|
||||
@asyncio.coroutine
|
||||
def go():
|
||||
result = False
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = server.bind_to_random_port(iface)
|
||||
client.connect("%s:%i" % (iface, port))
|
||||
msg = [b"Hello World"]
|
||||
yield from server.send_multipart(msg)
|
||||
if (yield from client.poll(1000)):
|
||||
rcvd_msg = yield from client.recv_multipart()
|
||||
self.assertEqual(rcvd_msg, msg)
|
||||
result = True
|
||||
return result
|
||||
return self.loop.run_until_complete(go())
|
||||
|
||||
def _select_recv(self, multipart, socket, **kwargs):
|
||||
recv = socket.recv_multipart if multipart else socket.recv
|
||||
@asyncio.coroutine
|
||||
def coro():
|
||||
if not (yield from socket.poll(5000)):
|
||||
raise TimeoutError("Should have received a message")
|
||||
return (yield from recv(**kwargs))
|
||||
return self.loop.run_until_complete(coro())
|
6
venv/Lib/site-packages/zmq/tests/asyncio/test_asyncio.py
Normal file
6
venv/Lib/site-packages/zmq/tests/asyncio/test_asyncio.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
"""Test asyncio support"""
|
||||
|
||||
try:
|
||||
from ._test_asyncio import TestAsyncIOSocket, TestAsyncioAuthentication
|
||||
except SyntaxError:
|
||||
pass
|
14
venv/Lib/site-packages/zmq/tests/conftest.py
Normal file
14
venv/Lib/site-packages/zmq/tests/conftest.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
"""pytest configuration and fixtures"""
|
||||
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope='session', autouse=True)
|
||||
def win_py38_asyncio():
|
||||
"""fix tornado compatibility on py38"""
|
||||
if sys.version_info < (3, 8) or not sys.platform.startswith('win'):
|
||||
return
|
||||
import asyncio
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
557
venv/Lib/site-packages/zmq/tests/test_auth.py
Normal file
557
venv/Lib/site-packages/zmq/tests/test_auth.py
Normal file
|
@ -0,0 +1,557 @@
|
|||
# -*- coding: utf8 -*-
|
||||
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
import zmq.auth
|
||||
from zmq.auth.thread import ThreadAuthenticator
|
||||
|
||||
from zmq.utils.strtypes import u
|
||||
from zmq.tests import BaseZMQTestCase, SkipTest, skip_pypy
|
||||
|
||||
|
||||
class BaseAuthTestCase(BaseZMQTestCase):
|
||||
def setUp(self):
|
||||
if zmq.zmq_version_info() < (4,0):
|
||||
raise SkipTest("security is new in libzmq 4.0")
|
||||
try:
|
||||
zmq.curve_keypair()
|
||||
except zmq.ZMQError:
|
||||
raise SkipTest("security requires libzmq to have curve support")
|
||||
super(BaseAuthTestCase, self).setUp()
|
||||
# enable debug logging while we run tests
|
||||
logging.getLogger('zmq.auth').setLevel(logging.DEBUG)
|
||||
self.auth = self.make_auth()
|
||||
self.auth.start()
|
||||
self.base_dir, self.public_keys_dir, self.secret_keys_dir = self.create_certs()
|
||||
|
||||
def make_auth(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def tearDown(self):
|
||||
if self.auth:
|
||||
self.auth.stop()
|
||||
self.auth = None
|
||||
self.remove_certs(self.base_dir)
|
||||
super(BaseAuthTestCase, self).tearDown()
|
||||
|
||||
def create_certs(self):
|
||||
"""Create CURVE certificates for a test"""
|
||||
|
||||
# Create temporary CURVE keypairs for this test run. We create all keys in a
|
||||
# temp directory and then move them into the appropriate private or public
|
||||
# directory.
|
||||
|
||||
base_dir = tempfile.mkdtemp()
|
||||
keys_dir = os.path.join(base_dir, 'certificates')
|
||||
public_keys_dir = os.path.join(base_dir, 'public_keys')
|
||||
secret_keys_dir = os.path.join(base_dir, 'private_keys')
|
||||
|
||||
os.mkdir(keys_dir)
|
||||
os.mkdir(public_keys_dir)
|
||||
os.mkdir(secret_keys_dir)
|
||||
|
||||
server_public_file, server_secret_file = zmq.auth.create_certificates(keys_dir, "server")
|
||||
client_public_file, client_secret_file = zmq.auth.create_certificates(keys_dir, "client")
|
||||
|
||||
for key_file in os.listdir(keys_dir):
|
||||
if key_file.endswith(".key"):
|
||||
shutil.move(os.path.join(keys_dir, key_file),
|
||||
os.path.join(public_keys_dir, '.'))
|
||||
|
||||
for key_file in os.listdir(keys_dir):
|
||||
if key_file.endswith(".key_secret"):
|
||||
shutil.move(os.path.join(keys_dir, key_file),
|
||||
os.path.join(secret_keys_dir, '.'))
|
||||
|
||||
return (base_dir, public_keys_dir, secret_keys_dir)
|
||||
|
||||
def remove_certs(self, base_dir):
|
||||
"""Remove certificates for a test"""
|
||||
shutil.rmtree(base_dir)
|
||||
|
||||
def load_certs(self, secret_keys_dir):
|
||||
"""Return server and client certificate keys"""
|
||||
server_secret_file = os.path.join(secret_keys_dir, "server.key_secret")
|
||||
client_secret_file = os.path.join(secret_keys_dir, "client.key_secret")
|
||||
|
||||
server_public, server_secret = zmq.auth.load_certificate(server_secret_file)
|
||||
client_public, client_secret = zmq.auth.load_certificate(client_secret_file)
|
||||
|
||||
return server_public, server_secret, client_public, client_secret
|
||||
|
||||
|
||||
class TestThreadAuthentication(BaseAuthTestCase):
|
||||
"""Test authentication running in a thread"""
|
||||
|
||||
def make_auth(self):
|
||||
return ThreadAuthenticator(self.context)
|
||||
|
||||
def can_connect(self, server, client):
|
||||
"""Check if client can connect to server using tcp transport"""
|
||||
result = False
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = server.bind_to_random_port(iface)
|
||||
client.connect("%s:%i" % (iface, port))
|
||||
msg = [b"Hello World"]
|
||||
if server.poll(1000, zmq.POLLOUT):
|
||||
server.send_multipart(msg)
|
||||
if client.poll(1000):
|
||||
rcvd_msg = client.recv_multipart()
|
||||
self.assertEqual(rcvd_msg, msg)
|
||||
result = True
|
||||
return result
|
||||
|
||||
def test_null(self):
|
||||
"""threaded auth - NULL"""
|
||||
# A default NULL connection should always succeed, and not
|
||||
# go through our authentication infrastructure at all.
|
||||
self.auth.stop()
|
||||
self.auth = None
|
||||
# use a new context, so ZAP isn't inherited
|
||||
self.context = self.Context()
|
||||
|
||||
server = self.socket(zmq.PUSH)
|
||||
client = self.socket(zmq.PULL)
|
||||
self.assertTrue(self.can_connect(server, client))
|
||||
|
||||
# By setting a domain we switch on authentication for NULL sockets,
|
||||
# though no policies are configured yet. The client connection
|
||||
# should still be allowed.
|
||||
server = self.socket(zmq.PUSH)
|
||||
server.zap_domain = b'global'
|
||||
client = self.socket(zmq.PULL)
|
||||
self.assertTrue(self.can_connect(server, client))
|
||||
|
||||
def test_blacklist(self):
|
||||
"""threaded auth - Blacklist"""
|
||||
# Blacklist 127.0.0.1, connection should fail
|
||||
self.auth.deny('127.0.0.1')
|
||||
server = self.socket(zmq.PUSH)
|
||||
# By setting a domain we switch on authentication for NULL sockets,
|
||||
# though no policies are configured yet.
|
||||
server.zap_domain = b'global'
|
||||
client = self.socket(zmq.PULL)
|
||||
self.assertFalse(self.can_connect(server, client))
|
||||
|
||||
def test_whitelist(self):
|
||||
"""threaded auth - Whitelist"""
|
||||
# Whitelist 127.0.0.1, connection should pass"
|
||||
self.auth.allow('127.0.0.1')
|
||||
server = self.socket(zmq.PUSH)
|
||||
# By setting a domain we switch on authentication for NULL sockets,
|
||||
# though no policies are configured yet.
|
||||
server.zap_domain = b'global'
|
||||
client = self.socket(zmq.PULL)
|
||||
self.assertTrue(self.can_connect(server, client))
|
||||
|
||||
def test_plain(self):
|
||||
"""threaded auth - PLAIN"""
|
||||
|
||||
# Try PLAIN authentication - without configuring server, connection should fail
|
||||
server = self.socket(zmq.PUSH)
|
||||
server.plain_server = True
|
||||
client = self.socket(zmq.PULL)
|
||||
client.plain_username = b'admin'
|
||||
client.plain_password = b'Password'
|
||||
self.assertFalse(self.can_connect(server, client))
|
||||
|
||||
# Try PLAIN authentication - with server configured, connection should pass
|
||||
server = self.socket(zmq.PUSH)
|
||||
server.plain_server = True
|
||||
client = self.socket(zmq.PULL)
|
||||
client.plain_username = b'admin'
|
||||
client.plain_password = b'Password'
|
||||
self.auth.configure_plain(domain='*', passwords={'admin': 'Password'})
|
||||
self.assertTrue(self.can_connect(server, client))
|
||||
|
||||
# Try PLAIN authentication - with bogus credentials, connection should fail
|
||||
server = self.socket(zmq.PUSH)
|
||||
server.plain_server = True
|
||||
client = self.socket(zmq.PULL)
|
||||
client.plain_username = b'admin'
|
||||
client.plain_password = b'Bogus'
|
||||
self.assertFalse(self.can_connect(server, client))
|
||||
|
||||
# Remove authenticator and check that a normal connection works
|
||||
self.auth.stop()
|
||||
self.auth = None
|
||||
|
||||
server = self.socket(zmq.PUSH)
|
||||
client = self.socket(zmq.PULL)
|
||||
self.assertTrue(self.can_connect(server, client))
|
||||
client.close()
|
||||
server.close()
|
||||
|
||||
def test_curve(self):
|
||||
"""threaded auth - CURVE"""
|
||||
self.auth.allow('127.0.0.1')
|
||||
certs = self.load_certs(self.secret_keys_dir)
|
||||
server_public, server_secret, client_public, client_secret = certs
|
||||
|
||||
#Try CURVE authentication - without configuring server, connection should fail
|
||||
server = self.socket(zmq.PUSH)
|
||||
server.curve_publickey = server_public
|
||||
server.curve_secretkey = server_secret
|
||||
server.curve_server = True
|
||||
client = self.socket(zmq.PULL)
|
||||
client.curve_publickey = client_public
|
||||
client.curve_secretkey = client_secret
|
||||
client.curve_serverkey = server_public
|
||||
self.assertFalse(self.can_connect(server, client))
|
||||
|
||||
#Try CURVE authentication - with server configured to CURVE_ALLOW_ANY, connection should pass
|
||||
self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
|
||||
server = self.socket(zmq.PUSH)
|
||||
server.curve_publickey = server_public
|
||||
server.curve_secretkey = server_secret
|
||||
server.curve_server = True
|
||||
client = self.socket(zmq.PULL)
|
||||
client.curve_publickey = client_public
|
||||
client.curve_secretkey = client_secret
|
||||
client.curve_serverkey = server_public
|
||||
self.assertTrue(self.can_connect(server, client))
|
||||
|
||||
# Try CURVE authentication - with server configured, connection should pass
|
||||
self.auth.configure_curve(domain='*', location=self.public_keys_dir)
|
||||
server = self.socket(zmq.PULL)
|
||||
server.curve_publickey = server_public
|
||||
server.curve_secretkey = server_secret
|
||||
server.curve_server = True
|
||||
client = self.socket(zmq.PUSH)
|
||||
client.curve_publickey = client_public
|
||||
client.curve_secretkey = client_secret
|
||||
client.curve_serverkey = server_public
|
||||
assert self.can_connect(client, server)
|
||||
|
||||
# Remove authenticator and check that a normal connection works
|
||||
self.auth.stop()
|
||||
self.auth = None
|
||||
|
||||
# Try connecting using NULL and no authentication enabled, connection should pass
|
||||
server = self.socket(zmq.PUSH)
|
||||
client = self.socket(zmq.PULL)
|
||||
self.assertTrue(self.can_connect(server, client))
|
||||
|
||||
def test_curve_callback(self):
|
||||
"""threaded auth - CURVE with callback authentication"""
|
||||
self.auth.allow('127.0.0.1')
|
||||
certs = self.load_certs(self.secret_keys_dir)
|
||||
server_public, server_secret, client_public, client_secret = certs
|
||||
|
||||
#Try CURVE authentication - without configuring server, connection should fail
|
||||
server = self.socket(zmq.PUSH)
|
||||
server.curve_publickey = server_public
|
||||
server.curve_secretkey = server_secret
|
||||
server.curve_server = True
|
||||
client = self.socket(zmq.PULL)
|
||||
client.curve_publickey = client_public
|
||||
client.curve_secretkey = client_secret
|
||||
client.curve_serverkey = server_public
|
||||
self.assertFalse(self.can_connect(server, client))
|
||||
|
||||
#Try CURVE authentication - with callback authentication configured, connection should pass
|
||||
|
||||
class CredentialsProvider(object):
|
||||
def __init__(self):
|
||||
self.client = client_public
|
||||
|
||||
def callback(self, domain, key):
|
||||
if (key == self.client):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
provider = CredentialsProvider()
|
||||
self.auth.configure_curve_callback(credentials_provider=provider)
|
||||
server = self.socket(zmq.PUSH)
|
||||
server.curve_publickey = server_public
|
||||
server.curve_secretkey = server_secret
|
||||
server.curve_server = True
|
||||
client = self.socket(zmq.PULL)
|
||||
client.curve_publickey = client_public
|
||||
client.curve_secretkey = client_secret
|
||||
client.curve_serverkey = server_public
|
||||
self.assertTrue(self.can_connect(server, client))
|
||||
|
||||
#Try CURVE authentication - with callback authentication configured with wrong key, connection should not pass
|
||||
|
||||
class WrongCredentialsProvider(object):
|
||||
def __init__(self):
|
||||
self.client = "WrongCredentials"
|
||||
|
||||
def callback(self, domain, key):
|
||||
if (key == self.client):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
provider = WrongCredentialsProvider()
|
||||
self.auth.configure_curve_callback(credentials_provider=provider)
|
||||
server = self.socket(zmq.PUSH)
|
||||
server.curve_publickey = server_public
|
||||
server.curve_secretkey = server_secret
|
||||
server.curve_server = True
|
||||
client = self.socket(zmq.PULL)
|
||||
client.curve_publickey = client_public
|
||||
client.curve_secretkey = client_secret
|
||||
client.curve_serverkey = server_public
|
||||
self.assertFalse(self.can_connect(server, client))
|
||||
|
||||
|
||||
|
||||
@skip_pypy
|
||||
def test_curve_user_id(self):
|
||||
"""threaded auth - CURVE"""
|
||||
self.auth.allow('127.0.0.1')
|
||||
certs = self.load_certs(self.secret_keys_dir)
|
||||
server_public, server_secret, client_public, client_secret = certs
|
||||
|
||||
self.auth.configure_curve(domain='*', location=self.public_keys_dir)
|
||||
server = self.socket(zmq.PULL)
|
||||
server.curve_publickey = server_public
|
||||
server.curve_secretkey = server_secret
|
||||
server.curve_server = True
|
||||
client = self.socket(zmq.PUSH)
|
||||
client.curve_publickey = client_public
|
||||
client.curve_secretkey = client_secret
|
||||
client.curve_serverkey = server_public
|
||||
assert self.can_connect(client, server)
|
||||
|
||||
# test default user-id map
|
||||
client.send(b'test')
|
||||
msg = self.recv(server, copy=False)
|
||||
assert msg.bytes == b'test'
|
||||
try:
|
||||
user_id = msg.get('User-Id')
|
||||
except zmq.ZMQVersionError:
|
||||
pass
|
||||
else:
|
||||
assert user_id == u(client_public)
|
||||
|
||||
# test custom user-id map
|
||||
self.auth.curve_user_id = lambda client_key: u'custom'
|
||||
|
||||
client2 = self.socket(zmq.PUSH)
|
||||
client2.curve_publickey = client_public
|
||||
client2.curve_secretkey = client_secret
|
||||
client2.curve_serverkey = server_public
|
||||
assert self.can_connect(client2, server)
|
||||
|
||||
client2.send(b'test2')
|
||||
msg = self.recv(server, copy=False)
|
||||
assert msg.bytes == b'test2'
|
||||
try:
|
||||
user_id = msg.get('User-Id')
|
||||
except zmq.ZMQVersionError:
|
||||
pass
|
||||
else:
|
||||
assert user_id == u'custom'
|
||||
|
||||
|
||||
def with_ioloop(method, expect_success=True):
|
||||
"""decorator for running tests with an IOLoop"""
|
||||
def test_method(self):
|
||||
r = method(self)
|
||||
|
||||
loop = self.io_loop
|
||||
if expect_success:
|
||||
self.pullstream.on_recv(self.on_message_succeed)
|
||||
else:
|
||||
self.pullstream.on_recv(self.on_message_fail)
|
||||
|
||||
loop.call_later(1, self.attempt_connection)
|
||||
loop.call_later(1.2, self.send_msg)
|
||||
|
||||
if expect_success:
|
||||
loop.call_later(2, self.on_test_timeout_fail)
|
||||
else:
|
||||
loop.call_later(2, self.on_test_timeout_succeed)
|
||||
|
||||
loop.start()
|
||||
if self.fail_msg:
|
||||
self.fail(self.fail_msg)
|
||||
|
||||
return r
|
||||
return test_method
|
||||
|
||||
def should_auth(method):
|
||||
return with_ioloop(method, True)
|
||||
|
||||
def should_not_auth(method):
|
||||
return with_ioloop(method, False)
|
||||
|
||||
class TestIOLoopAuthentication(BaseAuthTestCase):
|
||||
"""Test authentication running in ioloop"""
|
||||
|
||||
def setUp(self):
|
||||
try:
|
||||
from tornado import ioloop
|
||||
except ImportError:
|
||||
pytest.skip("Requires tornado")
|
||||
from zmq.eventloop import zmqstream
|
||||
self.fail_msg = None
|
||||
self.io_loop = ioloop.IOLoop()
|
||||
super(TestIOLoopAuthentication, self).setUp()
|
||||
self.server = self.socket(zmq.PUSH)
|
||||
self.client = self.socket(zmq.PULL)
|
||||
self.pushstream = zmqstream.ZMQStream(self.server, self.io_loop)
|
||||
self.pullstream = zmqstream.ZMQStream(self.client, self.io_loop)
|
||||
|
||||
def make_auth(self):
|
||||
from zmq.auth.ioloop import IOLoopAuthenticator
|
||||
return IOLoopAuthenticator(self.context, io_loop=self.io_loop)
|
||||
|
||||
def tearDown(self):
|
||||
if self.auth:
|
||||
self.auth.stop()
|
||||
self.auth = None
|
||||
self.io_loop.close(all_fds=True)
|
||||
super(TestIOLoopAuthentication, self).tearDown()
|
||||
|
||||
def attempt_connection(self):
|
||||
"""Check if client can connect to server using tcp transport"""
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = self.server.bind_to_random_port(iface)
|
||||
self.client.connect("%s:%i" % (iface, port))
|
||||
|
||||
def send_msg(self):
|
||||
"""Send a message from server to a client"""
|
||||
msg = [b"Hello World"]
|
||||
self.pushstream.send_multipart(msg)
|
||||
|
||||
def on_message_succeed(self, frames):
|
||||
"""A message was received, as expected."""
|
||||
if frames != [b"Hello World"]:
|
||||
self.fail_msg = "Unexpected message received"
|
||||
self.io_loop.stop()
|
||||
|
||||
def on_message_fail(self, frames):
|
||||
"""A message was received, unexpectedly."""
|
||||
self.fail_msg = 'Received messaged unexpectedly, security failed'
|
||||
self.io_loop.stop()
|
||||
|
||||
def on_test_timeout_succeed(self):
|
||||
"""Test timer expired, indicates test success"""
|
||||
self.io_loop.stop()
|
||||
|
||||
def on_test_timeout_fail(self):
|
||||
"""Test timer expired, indicates test failure"""
|
||||
self.fail_msg = 'Test timed out'
|
||||
self.io_loop.stop()
|
||||
|
||||
@should_auth
|
||||
def test_none(self):
|
||||
"""ioloop auth - NONE"""
|
||||
# A default NULL connection should always succeed, and not
|
||||
# go through our authentication infrastructure at all.
|
||||
# no auth should be running
|
||||
self.auth.stop()
|
||||
self.auth = None
|
||||
|
||||
@should_auth
|
||||
def test_null(self):
|
||||
"""ioloop auth - NULL"""
|
||||
# By setting a domain we switch on authentication for NULL sockets,
|
||||
# though no policies are configured yet. The client connection
|
||||
# should still be allowed.
|
||||
self.server.zap_domain = b'global'
|
||||
|
||||
@should_not_auth
|
||||
def test_blacklist(self):
|
||||
"""ioloop auth - Blacklist"""
|
||||
# Blacklist 127.0.0.1, connection should fail
|
||||
self.auth.deny('127.0.0.1')
|
||||
self.server.zap_domain = b'global'
|
||||
|
||||
@should_auth
|
||||
def test_whitelist(self):
|
||||
"""ioloop auth - Whitelist"""
|
||||
# Whitelist 127.0.0.1, which overrides the blacklist, connection should pass"
|
||||
self.auth.allow('127.0.0.1')
|
||||
|
||||
self.server.setsockopt(zmq.ZAP_DOMAIN, b'global')
|
||||
|
||||
@should_not_auth
|
||||
def test_plain_unconfigured_server(self):
|
||||
"""ioloop auth - PLAIN, unconfigured server"""
|
||||
self.client.plain_username = b'admin'
|
||||
self.client.plain_password = b'Password'
|
||||
# Try PLAIN authentication - without configuring server, connection should fail
|
||||
self.server.plain_server = True
|
||||
|
||||
@should_auth
|
||||
def test_plain_configured_server(self):
|
||||
"""ioloop auth - PLAIN, configured server"""
|
||||
self.client.plain_username = b'admin'
|
||||
self.client.plain_password = b'Password'
|
||||
# Try PLAIN authentication - with server configured, connection should pass
|
||||
self.server.plain_server = True
|
||||
self.auth.configure_plain(domain='*', passwords={'admin': 'Password'})
|
||||
|
||||
@should_not_auth
|
||||
def test_plain_bogus_credentials(self):
|
||||
"""ioloop auth - PLAIN, bogus credentials"""
|
||||
self.client.plain_username = b'admin'
|
||||
self.client.plain_password = b'Bogus'
|
||||
self.server.plain_server = True
|
||||
|
||||
self.auth.configure_plain(domain='*', passwords={'admin': 'Password'})
|
||||
|
||||
@should_not_auth
|
||||
def test_curve_unconfigured_server(self):
|
||||
"""ioloop auth - CURVE, unconfigured server"""
|
||||
certs = self.load_certs(self.secret_keys_dir)
|
||||
server_public, server_secret, client_public, client_secret = certs
|
||||
|
||||
self.auth.allow('127.0.0.1')
|
||||
|
||||
self.server.curve_publickey = server_public
|
||||
self.server.curve_secretkey = server_secret
|
||||
self.server.curve_server = True
|
||||
|
||||
self.client.curve_publickey = client_public
|
||||
self.client.curve_secretkey = client_secret
|
||||
self.client.curve_serverkey = server_public
|
||||
|
||||
@should_auth
|
||||
def test_curve_allow_any(self):
|
||||
"""ioloop auth - CURVE, CURVE_ALLOW_ANY"""
|
||||
certs = self.load_certs(self.secret_keys_dir)
|
||||
server_public, server_secret, client_public, client_secret = certs
|
||||
|
||||
self.auth.allow('127.0.0.1')
|
||||
self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
|
||||
|
||||
self.server.curve_publickey = server_public
|
||||
self.server.curve_secretkey = server_secret
|
||||
self.server.curve_server = True
|
||||
|
||||
self.client.curve_publickey = client_public
|
||||
self.client.curve_secretkey = client_secret
|
||||
self.client.curve_serverkey = server_public
|
||||
|
||||
@should_auth
|
||||
def test_curve_configured_server(self):
|
||||
"""ioloop auth - CURVE, configured server"""
|
||||
self.auth.allow('127.0.0.1')
|
||||
certs = self.load_certs(self.secret_keys_dir)
|
||||
server_public, server_secret, client_public, client_secret = certs
|
||||
|
||||
self.auth.configure_curve(domain='*', location=self.public_keys_dir)
|
||||
|
||||
self.server.curve_publickey = server_public
|
||||
self.server.curve_secretkey = server_secret
|
||||
self.server.curve_server = True
|
||||
|
||||
self.client.curve_publickey = client_public
|
||||
self.client.curve_secretkey = client_secret
|
||||
self.client.curve_serverkey = server_public
|
297
venv/Lib/site-packages/zmq/tests/test_cffi_backend.py
Normal file
297
venv/Lib/site-packages/zmq/tests/test_cffi_backend.py
Normal file
|
@ -0,0 +1,297 @@
|
|||
# -*- coding: utf8 -*-
|
||||
|
||||
import sys
|
||||
import time
|
||||
|
||||
from unittest import TestCase
|
||||
|
||||
from zmq.tests import BaseZMQTestCase, SkipTest
|
||||
|
||||
try:
|
||||
from zmq.backend.cffi import (
|
||||
zmq_version_info,
|
||||
PUSH, PULL, IDENTITY,
|
||||
REQ, REP, POLLIN, POLLOUT,
|
||||
)
|
||||
from zmq.backend.cffi._cffi import ffi, C
|
||||
have_ffi_backend = True
|
||||
except ImportError:
|
||||
have_ffi_backend = False
|
||||
|
||||
|
||||
class TestCFFIBackend(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
if not have_ffi_backend:
|
||||
raise SkipTest('CFFI not available')
|
||||
|
||||
def test_zmq_version_info(self):
|
||||
version = zmq_version_info()
|
||||
|
||||
assert version[0] in range(2,11)
|
||||
|
||||
def test_zmq_ctx_new_destroy(self):
|
||||
ctx = C.zmq_ctx_new()
|
||||
|
||||
assert ctx != ffi.NULL
|
||||
assert 0 == C.zmq_ctx_destroy(ctx)
|
||||
|
||||
def test_zmq_socket_open_close(self):
|
||||
ctx = C.zmq_ctx_new()
|
||||
socket = C.zmq_socket(ctx, PUSH)
|
||||
|
||||
assert ctx != ffi.NULL
|
||||
assert ffi.NULL != socket
|
||||
assert 0 == C.zmq_close(socket)
|
||||
assert 0 == C.zmq_ctx_destroy(ctx)
|
||||
|
||||
def test_zmq_setsockopt(self):
|
||||
ctx = C.zmq_ctx_new()
|
||||
socket = C.zmq_socket(ctx, PUSH)
|
||||
|
||||
identity = ffi.new('char[3]', b'zmq')
|
||||
ret = C.zmq_setsockopt(socket, IDENTITY, ffi.cast('void*', identity), 3)
|
||||
|
||||
assert ret == 0
|
||||
assert ctx != ffi.NULL
|
||||
assert ffi.NULL != socket
|
||||
assert 0 == C.zmq_close(socket)
|
||||
assert 0 == C.zmq_ctx_destroy(ctx)
|
||||
|
||||
def test_zmq_getsockopt(self):
|
||||
ctx = C.zmq_ctx_new()
|
||||
socket = C.zmq_socket(ctx, PUSH)
|
||||
|
||||
identity = ffi.new('char[]', b'zmq')
|
||||
ret = C.zmq_setsockopt(socket, IDENTITY, ffi.cast('void*', identity), 3)
|
||||
assert ret == 0
|
||||
|
||||
option_len = ffi.new('size_t*', 3)
|
||||
option = ffi.new('char[3]')
|
||||
ret = C.zmq_getsockopt(socket,
|
||||
IDENTITY,
|
||||
ffi.cast('void*', option),
|
||||
option_len)
|
||||
|
||||
assert ret == 0
|
||||
assert ffi.string(ffi.cast('char*', option))[0:1] == b"z"
|
||||
assert ffi.string(ffi.cast('char*', option))[1:2] == b"m"
|
||||
assert ffi.string(ffi.cast('char*', option))[2:3] == b"q"
|
||||
assert ctx != ffi.NULL
|
||||
assert ffi.NULL != socket
|
||||
assert 0 == C.zmq_close(socket)
|
||||
assert 0 == C.zmq_ctx_destroy(ctx)
|
||||
|
||||
def test_zmq_bind(self):
|
||||
ctx = C.zmq_ctx_new()
|
||||
socket = C.zmq_socket(ctx, 8)
|
||||
|
||||
assert 0 == C.zmq_bind(socket, b'tcp://*:4444')
|
||||
assert ctx != ffi.NULL
|
||||
assert ffi.NULL != socket
|
||||
assert 0 == C.zmq_close(socket)
|
||||
assert 0 == C.zmq_ctx_destroy(ctx)
|
||||
|
||||
def test_zmq_bind_connect(self):
|
||||
ctx = C.zmq_ctx_new()
|
||||
|
||||
socket1 = C.zmq_socket(ctx, PUSH)
|
||||
socket2 = C.zmq_socket(ctx, PULL)
|
||||
|
||||
assert 0 == C.zmq_bind(socket1, b'tcp://*:4444')
|
||||
assert 0 == C.zmq_connect(socket2, b'tcp://127.0.0.1:4444')
|
||||
assert ctx != ffi.NULL
|
||||
assert ffi.NULL != socket1
|
||||
assert ffi.NULL != socket2
|
||||
assert 0 == C.zmq_close(socket1)
|
||||
assert 0 == C.zmq_close(socket2)
|
||||
assert 0 == C.zmq_ctx_destroy(ctx)
|
||||
|
||||
def test_zmq_msg_init_close(self):
|
||||
zmq_msg = ffi.new('zmq_msg_t*')
|
||||
|
||||
assert ffi.NULL != zmq_msg
|
||||
assert 0 == C.zmq_msg_init(zmq_msg)
|
||||
assert 0 == C.zmq_msg_close(zmq_msg)
|
||||
|
||||
def test_zmq_msg_init_size(self):
|
||||
zmq_msg = ffi.new('zmq_msg_t*')
|
||||
|
||||
assert ffi.NULL != zmq_msg
|
||||
assert 0 == C.zmq_msg_init_size(zmq_msg, 10)
|
||||
assert 0 == C.zmq_msg_close(zmq_msg)
|
||||
|
||||
def test_zmq_msg_init_data(self):
|
||||
zmq_msg = ffi.new('zmq_msg_t*')
|
||||
message = ffi.new('char[5]', b'Hello')
|
||||
|
||||
assert 0 == C.zmq_msg_init_data(zmq_msg,
|
||||
ffi.cast('void*', message),
|
||||
5,
|
||||
ffi.NULL,
|
||||
ffi.NULL)
|
||||
|
||||
assert ffi.NULL != zmq_msg
|
||||
assert 0 == C.zmq_msg_close(zmq_msg)
|
||||
|
||||
def test_zmq_msg_data(self):
|
||||
zmq_msg = ffi.new('zmq_msg_t*')
|
||||
message = ffi.new('char[]', b'Hello')
|
||||
assert 0 == C.zmq_msg_init_data(zmq_msg,
|
||||
ffi.cast('void*', message),
|
||||
5,
|
||||
ffi.NULL,
|
||||
ffi.NULL)
|
||||
|
||||
data = C.zmq_msg_data(zmq_msg)
|
||||
|
||||
assert ffi.NULL != zmq_msg
|
||||
assert ffi.string(ffi.cast("char*", data)) == b'Hello'
|
||||
assert 0 == C.zmq_msg_close(zmq_msg)
|
||||
|
||||
|
||||
def test_zmq_send(self):
|
||||
ctx = C.zmq_ctx_new()
|
||||
|
||||
sender = C.zmq_socket(ctx, REQ)
|
||||
receiver = C.zmq_socket(ctx, REP)
|
||||
|
||||
assert 0 == C.zmq_bind(receiver, b'tcp://*:7777')
|
||||
assert 0 == C.zmq_connect(sender, b'tcp://127.0.0.1:7777')
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
zmq_msg = ffi.new('zmq_msg_t*')
|
||||
message = ffi.new('char[5]', b'Hello')
|
||||
|
||||
C.zmq_msg_init_data(zmq_msg,
|
||||
ffi.cast('void*', message),
|
||||
ffi.cast('size_t', 5),
|
||||
ffi.NULL,
|
||||
ffi.NULL)
|
||||
|
||||
assert 5 == C.zmq_msg_send(zmq_msg, sender, 0)
|
||||
assert 0 == C.zmq_msg_close(zmq_msg)
|
||||
assert C.zmq_close(sender) == 0
|
||||
assert C.zmq_close(receiver) == 0
|
||||
assert C.zmq_ctx_destroy(ctx) == 0
|
||||
|
||||
def test_zmq_recv(self):
|
||||
ctx = C.zmq_ctx_new()
|
||||
|
||||
sender = C.zmq_socket(ctx, REQ)
|
||||
receiver = C.zmq_socket(ctx, REP)
|
||||
|
||||
assert 0 == C.zmq_bind(receiver, b'tcp://*:2222')
|
||||
assert 0 == C.zmq_connect(sender, b'tcp://127.0.0.1:2222')
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
zmq_msg = ffi.new('zmq_msg_t*')
|
||||
message = ffi.new('char[5]', b'Hello')
|
||||
|
||||
C.zmq_msg_init_data(zmq_msg,
|
||||
ffi.cast('void*', message),
|
||||
ffi.cast('size_t', 5),
|
||||
ffi.NULL,
|
||||
ffi.NULL)
|
||||
|
||||
zmq_msg2 = ffi.new('zmq_msg_t*')
|
||||
C.zmq_msg_init(zmq_msg2)
|
||||
|
||||
assert 5 == C.zmq_msg_send(zmq_msg, sender, 0)
|
||||
assert 5 == C.zmq_msg_recv(zmq_msg2, receiver, 0)
|
||||
assert 5 == C.zmq_msg_size(zmq_msg2)
|
||||
assert b"Hello" == ffi.buffer(C.zmq_msg_data(zmq_msg2),
|
||||
C.zmq_msg_size(zmq_msg2))[:]
|
||||
assert C.zmq_close(sender) == 0
|
||||
assert C.zmq_close(receiver) == 0
|
||||
assert C.zmq_ctx_destroy(ctx) == 0
|
||||
|
||||
def test_zmq_poll(self):
|
||||
ctx = C.zmq_ctx_new()
|
||||
|
||||
sender = C.zmq_socket(ctx, REQ)
|
||||
receiver = C.zmq_socket(ctx, REP)
|
||||
|
||||
r1 = C.zmq_bind(receiver, b'tcp://*:3333')
|
||||
r2 = C.zmq_connect(sender, b'tcp://127.0.0.1:3333')
|
||||
|
||||
zmq_msg = ffi.new('zmq_msg_t*')
|
||||
message = ffi.new('char[5]', b'Hello')
|
||||
|
||||
C.zmq_msg_init_data(zmq_msg,
|
||||
ffi.cast('void*', message),
|
||||
ffi.cast('size_t', 5),
|
||||
ffi.NULL,
|
||||
ffi.NULL)
|
||||
|
||||
receiver_pollitem = ffi.new('zmq_pollitem_t*')
|
||||
receiver_pollitem.socket = receiver
|
||||
receiver_pollitem.fd = 0
|
||||
receiver_pollitem.events = POLLIN | POLLOUT
|
||||
receiver_pollitem.revents = 0
|
||||
|
||||
ret = C.zmq_poll(ffi.NULL, 0, 0)
|
||||
assert ret == 0
|
||||
|
||||
ret = C.zmq_poll(receiver_pollitem, 1, 0)
|
||||
assert ret == 0
|
||||
|
||||
ret = C.zmq_msg_send(zmq_msg, sender, 0)
|
||||
print(ffi.string(C.zmq_strerror(C.zmq_errno())))
|
||||
assert ret == 5
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
ret = C.zmq_poll(receiver_pollitem, 1, 0)
|
||||
assert ret == 1
|
||||
|
||||
assert int(receiver_pollitem.revents) & POLLIN
|
||||
assert not int(receiver_pollitem.revents) & POLLOUT
|
||||
|
||||
zmq_msg2 = ffi.new('zmq_msg_t*')
|
||||
C.zmq_msg_init(zmq_msg2)
|
||||
|
||||
ret_recv = C.zmq_msg_recv(zmq_msg2, receiver, 0)
|
||||
assert ret_recv == 5
|
||||
|
||||
assert 5 == C.zmq_msg_size(zmq_msg2)
|
||||
assert b"Hello" == ffi.buffer(C.zmq_msg_data(zmq_msg2),
|
||||
C.zmq_msg_size(zmq_msg2))[:]
|
||||
|
||||
sender_pollitem = ffi.new('zmq_pollitem_t*')
|
||||
sender_pollitem.socket = sender
|
||||
sender_pollitem.fd = 0
|
||||
sender_pollitem.events = POLLIN | POLLOUT
|
||||
sender_pollitem.revents = 0
|
||||
|
||||
ret = C.zmq_poll(sender_pollitem, 1, 0)
|
||||
assert ret == 0
|
||||
|
||||
zmq_msg_again = ffi.new('zmq_msg_t*')
|
||||
message_again = ffi.new('char[11]', b'Hello Again')
|
||||
|
||||
C.zmq_msg_init_data(zmq_msg_again,
|
||||
ffi.cast('void*', message_again),
|
||||
ffi.cast('size_t', 11),
|
||||
ffi.NULL,
|
||||
ffi.NULL)
|
||||
|
||||
assert 11 == C.zmq_msg_send(zmq_msg_again, receiver, 0)
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
assert 0 <= C.zmq_poll(sender_pollitem, 1, 0)
|
||||
assert int(sender_pollitem.revents) & POLLIN
|
||||
assert 11 == C.zmq_msg_recv(zmq_msg2, sender, 0)
|
||||
assert 11 == C.zmq_msg_size(zmq_msg2)
|
||||
assert b"Hello Again" == ffi.buffer(C.zmq_msg_data(zmq_msg2),
|
||||
int(C.zmq_msg_size(zmq_msg2)))[:]
|
||||
assert 0 == C.zmq_close(sender)
|
||||
assert 0 == C.zmq_close(receiver)
|
||||
assert 0 == C.zmq_ctx_destroy(ctx)
|
||||
assert 0 == C.zmq_msg_close(zmq_msg)
|
||||
assert 0 == C.zmq_msg_close(zmq_msg2)
|
||||
assert 0 == C.zmq_msg_close(zmq_msg_again)
|
||||
|
121
venv/Lib/site-packages/zmq/tests/test_constants.py
Normal file
121
venv/Lib/site-packages/zmq/tests/test_constants.py
Normal file
|
@ -0,0 +1,121 @@
|
|||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import json
|
||||
from unittest import TestCase
|
||||
|
||||
import pytest
|
||||
|
||||
import zmq
|
||||
|
||||
from zmq.utils import constant_names
|
||||
from zmq.sugar import constants as sugar_constants
|
||||
from zmq.backend import constants as backend_constants
|
||||
|
||||
all_set = set(constant_names.all_names)
|
||||
|
||||
class TestConstants(TestCase):
|
||||
|
||||
def _duplicate_test(self, namelist, listname):
|
||||
"""test that a given list has no duplicates"""
|
||||
dupes = {}
|
||||
for name in set(namelist):
|
||||
cnt = namelist.count(name)
|
||||
if cnt > 1:
|
||||
dupes[name] = cnt
|
||||
if dupes:
|
||||
self.fail("The following names occur more than once in %s: %s" % (listname, json.dumps(dupes, indent=2)))
|
||||
|
||||
def test_duplicate_all(self):
|
||||
return self._duplicate_test(constant_names.all_names, "all_names")
|
||||
|
||||
def _change_key(self, change, version):
|
||||
"""return changed-in key"""
|
||||
return "%s-in %d.%d.%d" % tuple([change] + list(version))
|
||||
|
||||
def test_duplicate_changed(self):
|
||||
all_changed = []
|
||||
for change in ("new", "removed"):
|
||||
d = getattr(constant_names, change + "_in")
|
||||
for version, namelist in d.items():
|
||||
all_changed.extend(namelist)
|
||||
self._duplicate_test(namelist, self._change_key(change, version))
|
||||
|
||||
self._duplicate_test(all_changed, "all-changed")
|
||||
|
||||
def test_changed_in_all(self):
|
||||
missing = {}
|
||||
for change in ("new", "removed"):
|
||||
d = getattr(constant_names, change + "_in")
|
||||
for version, namelist in d.items():
|
||||
key = self._change_key(change, version)
|
||||
for name in namelist:
|
||||
if name not in all_set:
|
||||
if key not in missing:
|
||||
missing[key] = []
|
||||
missing[key].append(name)
|
||||
|
||||
if missing:
|
||||
self.fail(
|
||||
"The following names are missing in `all_names`: %s" % json.dumps(missing, indent=2)
|
||||
)
|
||||
|
||||
def test_no_negative_constants(self):
|
||||
for name in sugar_constants.__all__:
|
||||
self.assertNotEqual(getattr(zmq, name), sugar_constants._UNDEFINED)
|
||||
|
||||
def test_undefined_constants(self):
|
||||
all_aliases = []
|
||||
for alias_group in sugar_constants.aliases:
|
||||
all_aliases.extend(alias_group)
|
||||
|
||||
for name in all_set.difference(all_aliases):
|
||||
raw = getattr(backend_constants, name)
|
||||
if raw == sugar_constants._UNDEFINED:
|
||||
self.assertRaises(AttributeError, getattr, zmq, name)
|
||||
else:
|
||||
self.assertEqual(getattr(zmq, name), raw)
|
||||
|
||||
def test_new(self):
|
||||
zmq_version = zmq.zmq_version_info()
|
||||
for version, new_names in constant_names.new_in.items():
|
||||
should_have = zmq_version >= version
|
||||
for name in new_names:
|
||||
try:
|
||||
value = getattr(zmq, name)
|
||||
except AttributeError:
|
||||
if should_have:
|
||||
self.fail("AttributeError: zmq.%s" % name)
|
||||
else:
|
||||
if not should_have:
|
||||
self.fail("Shouldn't have: zmq.%s=%s" % (name, value))
|
||||
|
||||
@pytest.mark.skipif(not zmq.DRAFT_API, reason="Only test draft API if built with draft API")
|
||||
def test_draft(self):
|
||||
zmq_version = zmq.zmq_version_info()
|
||||
for version, new_names in constant_names.draft_in.items():
|
||||
should_have = zmq_version >= version
|
||||
for name in new_names:
|
||||
try:
|
||||
value = getattr(zmq, name)
|
||||
except AttributeError:
|
||||
if should_have:
|
||||
self.fail("AttributeError: zmq.%s" % name)
|
||||
else:
|
||||
if not should_have:
|
||||
self.fail("Shouldn't have: zmq.%s=%s" % (name, value))
|
||||
|
||||
def test_removed(self):
|
||||
zmq_version = zmq.zmq_version_info()
|
||||
for version, new_names in constant_names.removed_in.items():
|
||||
should_have = zmq_version < version
|
||||
for name in new_names:
|
||||
try:
|
||||
value = getattr(zmq, name)
|
||||
except AttributeError:
|
||||
if should_have:
|
||||
self.fail("AttributeError: zmq.%s" % name)
|
||||
else:
|
||||
if not should_have:
|
||||
self.fail("Shouldn't have: zmq.%s=%s" % (name, value))
|
||||
|
392
venv/Lib/site-packages/zmq/tests/test_context.py
Normal file
392
venv/Lib/site-packages/zmq/tests/test_context.py
Normal file
|
@ -0,0 +1,392 @@
|
|||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from threading import Thread, Event
|
||||
try:
|
||||
from queue import Queue
|
||||
except ImportError:
|
||||
from Queue import Queue
|
||||
try:
|
||||
from unittest import mock
|
||||
except ImportError:
|
||||
mock = None
|
||||
|
||||
from pytest import mark
|
||||
|
||||
import zmq
|
||||
from zmq.tests import (
|
||||
BaseZMQTestCase, have_gevent, GreenTest, skip_green, PYPY, SkipTest,
|
||||
)
|
||||
|
||||
|
||||
class KwargTestSocket(zmq.Socket):
|
||||
test_kwarg_value = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.test_kwarg_value = kwargs.pop('test_kwarg', None)
|
||||
super(KwargTestSocket, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class KwargTestContext(zmq.Context):
|
||||
_socket_class = KwargTestSocket
|
||||
|
||||
|
||||
class TestContext(BaseZMQTestCase):
|
||||
|
||||
def test_init(self):
|
||||
c1 = self.Context()
|
||||
self.assert_(isinstance(c1, self.Context))
|
||||
del c1
|
||||
c2 = self.Context()
|
||||
self.assert_(isinstance(c2, self.Context))
|
||||
del c2
|
||||
c3 = self.Context()
|
||||
self.assert_(isinstance(c3, self.Context))
|
||||
del c3
|
||||
|
||||
def test_dir(self):
|
||||
ctx = self.Context()
|
||||
self.assertTrue('socket' in dir(ctx))
|
||||
if zmq.zmq_version_info() > (3,):
|
||||
self.assertTrue('IO_THREADS' in dir(ctx))
|
||||
ctx.term()
|
||||
|
||||
@mark.skipif(mock is None, reason="requires unittest.mock")
|
||||
def test_mockable(self):
|
||||
m = mock.Mock(spec=self.context)
|
||||
|
||||
|
||||
def test_term(self):
|
||||
c = self.Context()
|
||||
c.term()
|
||||
self.assert_(c.closed)
|
||||
|
||||
def test_context_manager(self):
|
||||
with self.Context() as c:
|
||||
pass
|
||||
self.assert_(c.closed)
|
||||
|
||||
def test_fail_init(self):
|
||||
self.assertRaisesErrno(zmq.EINVAL, self.Context, -1)
|
||||
|
||||
def test_term_hang(self):
|
||||
rep,req = self.create_bound_pair(zmq.ROUTER, zmq.DEALER)
|
||||
req.setsockopt(zmq.LINGER, 0)
|
||||
req.send(b'hello', copy=False)
|
||||
req.close()
|
||||
rep.close()
|
||||
self.context.term()
|
||||
|
||||
def test_instance(self):
|
||||
ctx = self.Context.instance()
|
||||
c2 = self.Context.instance(io_threads=2)
|
||||
self.assertTrue(c2 is ctx)
|
||||
c2.term()
|
||||
c3 = self.Context.instance()
|
||||
c4 = self.Context.instance()
|
||||
self.assertFalse(c3 is c2)
|
||||
self.assertFalse(c3.closed)
|
||||
self.assertTrue(c3 is c4)
|
||||
|
||||
def test_instance_subclass_first(self):
|
||||
self.context.term()
|
||||
class SubContext(zmq.Context):
|
||||
pass
|
||||
sctx = SubContext.instance()
|
||||
ctx = zmq.Context.instance()
|
||||
ctx.term()
|
||||
sctx.term()
|
||||
assert type(ctx) is zmq.Context
|
||||
assert type(sctx) is SubContext
|
||||
|
||||
def test_instance_subclass_second(self):
|
||||
self.context.term()
|
||||
class SubContextInherit(zmq.Context):
|
||||
pass
|
||||
class SubContextNoInherit(zmq.Context):
|
||||
_instance = None
|
||||
pass
|
||||
ctx = zmq.Context.instance()
|
||||
sctx = SubContextInherit.instance()
|
||||
sctx2 = SubContextNoInherit.instance()
|
||||
ctx.term()
|
||||
sctx.term()
|
||||
sctx2.term()
|
||||
assert type(ctx) is zmq.Context
|
||||
assert type(sctx) is zmq.Context
|
||||
assert type(sctx2) is SubContextNoInherit
|
||||
|
||||
def test_instance_threadsafe(self):
|
||||
self.context.term() # clear default context
|
||||
|
||||
q = Queue()
|
||||
# slow context initialization,
|
||||
# to ensure that we are both trying to create one at the same time
|
||||
class SlowContext(self.Context):
|
||||
def __init__(self, *a, **kw):
|
||||
time.sleep(1)
|
||||
super(SlowContext, self).__init__(*a, **kw)
|
||||
|
||||
def f():
|
||||
q.put(SlowContext.instance())
|
||||
|
||||
# call ctx.instance() in several threads at once
|
||||
N = 16
|
||||
threads = [ Thread(target=f) for i in range(N) ]
|
||||
[ t.start() for t in threads ]
|
||||
# also call it in the main thread (not first)
|
||||
ctx = SlowContext.instance()
|
||||
assert isinstance(ctx, SlowContext)
|
||||
# check that all the threads got the same context
|
||||
for i in range(N):
|
||||
thread_ctx = q.get(timeout=5)
|
||||
assert thread_ctx is ctx
|
||||
# cleanup
|
||||
ctx.term()
|
||||
[ t.join(timeout=5) for t in threads ]
|
||||
|
||||
def test_socket_passes_kwargs(self):
|
||||
test_kwarg_value = 'testing one two three'
|
||||
with KwargTestContext() as ctx:
|
||||
with ctx.socket(zmq.DEALER, test_kwarg=test_kwarg_value) as socket:
|
||||
self.assertTrue(socket.test_kwarg_value is test_kwarg_value)
|
||||
|
||||
def test_many_sockets(self):
|
||||
"""opening and closing many sockets shouldn't cause problems"""
|
||||
ctx = self.Context()
|
||||
for i in range(16):
|
||||
sockets = [ ctx.socket(zmq.REP) for i in range(65) ]
|
||||
[ s.close() for s in sockets ]
|
||||
# give the reaper a chance
|
||||
time.sleep(1e-2)
|
||||
ctx.term()
|
||||
|
||||
def test_sockopts(self):
|
||||
"""setting socket options with ctx attributes"""
|
||||
ctx = self.Context()
|
||||
ctx.linger = 5
|
||||
self.assertEqual(ctx.linger, 5)
|
||||
s = ctx.socket(zmq.REQ)
|
||||
self.assertEqual(s.linger, 5)
|
||||
self.assertEqual(s.getsockopt(zmq.LINGER), 5)
|
||||
s.close()
|
||||
# check that subscribe doesn't get set on sockets that don't subscribe:
|
||||
ctx.subscribe = b''
|
||||
s = ctx.socket(zmq.REQ)
|
||||
s.close()
|
||||
|
||||
ctx.term()
|
||||
|
||||
@mark.skipif(
|
||||
sys.platform.startswith('win'),
|
||||
reason='Segfaults on Windows')
|
||||
def test_destroy(self):
|
||||
"""Context.destroy should close sockets"""
|
||||
ctx = self.Context()
|
||||
sockets = [ ctx.socket(zmq.REP) for i in range(65) ]
|
||||
|
||||
# close half of the sockets
|
||||
[ s.close() for s in sockets[::2] ]
|
||||
|
||||
ctx.destroy()
|
||||
# reaper is not instantaneous
|
||||
time.sleep(1e-2)
|
||||
for s in sockets:
|
||||
self.assertTrue(s.closed)
|
||||
|
||||
def test_destroy_linger(self):
|
||||
"""Context.destroy should set linger on closing sockets"""
|
||||
req,rep = self.create_bound_pair(zmq.REQ, zmq.REP)
|
||||
req.send(b'hi')
|
||||
time.sleep(1e-2)
|
||||
self.context.destroy(linger=0)
|
||||
# reaper is not instantaneous
|
||||
time.sleep(1e-2)
|
||||
for s in (req,rep):
|
||||
self.assertTrue(s.closed)
|
||||
|
||||
def test_term_noclose(self):
|
||||
"""Context.term won't close sockets"""
|
||||
ctx = self.Context()
|
||||
s = ctx.socket(zmq.REQ)
|
||||
self.assertFalse(s.closed)
|
||||
t = Thread(target=ctx.term)
|
||||
t.start()
|
||||
t.join(timeout=0.1)
|
||||
self.assertTrue(t.is_alive(), "Context should be waiting")
|
||||
s.close()
|
||||
t.join(timeout=0.1)
|
||||
self.assertFalse(t.is_alive(), "Context should have closed")
|
||||
|
||||
def test_gc(self):
|
||||
"""test close&term by garbage collection alone"""
|
||||
if PYPY:
|
||||
raise SkipTest("GC doesn't work ")
|
||||
|
||||
# test credit @dln (GH #137):
|
||||
def gcf():
|
||||
def inner():
|
||||
ctx = self.Context()
|
||||
s = ctx.socket(zmq.PUSH)
|
||||
inner()
|
||||
gc.collect()
|
||||
t = Thread(target=gcf)
|
||||
t.start()
|
||||
t.join(timeout=1)
|
||||
self.assertFalse(t.is_alive(), "Garbage collection should have cleaned up context")
|
||||
|
||||
def test_cyclic_destroy(self):
|
||||
"""ctx.destroy should succeed when cyclic ref prevents gc"""
|
||||
# test credit @dln (GH #137):
|
||||
class CyclicReference(object):
|
||||
def __init__(self, parent=None):
|
||||
self.parent = parent
|
||||
|
||||
def crash(self, sock):
|
||||
self.sock = sock
|
||||
self.child = CyclicReference(self)
|
||||
|
||||
def crash_zmq():
|
||||
ctx = self.Context()
|
||||
sock = ctx.socket(zmq.PULL)
|
||||
c = CyclicReference()
|
||||
c.crash(sock)
|
||||
ctx.destroy()
|
||||
|
||||
crash_zmq()
|
||||
|
||||
def test_term_thread(self):
|
||||
"""ctx.term should not crash active threads (#139)"""
|
||||
ctx = self.Context()
|
||||
evt = Event()
|
||||
evt.clear()
|
||||
|
||||
def block():
|
||||
s = ctx.socket(zmq.REP)
|
||||
s.bind_to_random_port('tcp://127.0.0.1')
|
||||
evt.set()
|
||||
try:
|
||||
s.recv()
|
||||
except zmq.ZMQError as e:
|
||||
self.assertEqual(e.errno, zmq.ETERM)
|
||||
return
|
||||
finally:
|
||||
s.close()
|
||||
self.fail("recv should have been interrupted with ETERM")
|
||||
t = Thread(target=block)
|
||||
t.start()
|
||||
|
||||
evt.wait(1)
|
||||
self.assertTrue(evt.is_set(), "sync event never fired")
|
||||
time.sleep(0.01)
|
||||
ctx.term()
|
||||
t.join(timeout=1)
|
||||
self.assertFalse(t.is_alive(), "term should have interrupted s.recv()")
|
||||
|
||||
def test_destroy_no_sockets(self):
|
||||
ctx = self.Context()
|
||||
s = ctx.socket(zmq.PUB)
|
||||
s.bind_to_random_port('tcp://127.0.0.1')
|
||||
s.close()
|
||||
ctx.destroy()
|
||||
assert s.closed
|
||||
assert ctx.closed
|
||||
|
||||
def test_ctx_opts(self):
|
||||
if zmq.zmq_version_info() < (3,):
|
||||
raise SkipTest("context options require libzmq 3")
|
||||
ctx = self.Context()
|
||||
ctx.set(zmq.MAX_SOCKETS, 2)
|
||||
self.assertEqual(ctx.get(zmq.MAX_SOCKETS), 2)
|
||||
ctx.max_sockets = 100
|
||||
self.assertEqual(ctx.max_sockets, 100)
|
||||
self.assertEqual(ctx.get(zmq.MAX_SOCKETS), 100)
|
||||
|
||||
def test_copy(self):
|
||||
c1 = self.Context()
|
||||
c2 = copy.copy(c1)
|
||||
c2b = copy.deepcopy(c1)
|
||||
c3 = copy.deepcopy(c2)
|
||||
self.assert_(c2._shadow)
|
||||
self.assert_(c3._shadow)
|
||||
self.assertEqual(c1.underlying, c2.underlying)
|
||||
self.assertEqual(c1.underlying, c3.underlying)
|
||||
self.assertEqual(c1.underlying, c2b.underlying)
|
||||
s = c3.socket(zmq.PUB)
|
||||
s.close()
|
||||
c1.term()
|
||||
|
||||
def test_shadow(self):
|
||||
ctx = self.Context()
|
||||
ctx2 = self.Context.shadow(ctx.underlying)
|
||||
self.assertEqual(ctx.underlying, ctx2.underlying)
|
||||
s = ctx.socket(zmq.PUB)
|
||||
s.close()
|
||||
del ctx2
|
||||
self.assertFalse(ctx.closed)
|
||||
s = ctx.socket(zmq.PUB)
|
||||
ctx2 = self.Context.shadow(ctx.underlying)
|
||||
s2 = ctx2.socket(zmq.PUB)
|
||||
s.close()
|
||||
s2.close()
|
||||
ctx.term()
|
||||
self.assertRaisesErrno(zmq.EFAULT, ctx2.socket, zmq.PUB)
|
||||
del ctx2
|
||||
|
||||
def test_shadow_pyczmq(self):
|
||||
try:
|
||||
from pyczmq import zctx, zsocket, zstr
|
||||
except Exception:
|
||||
raise SkipTest("Requires pyczmq")
|
||||
|
||||
ctx = zctx.new()
|
||||
a = zsocket.new(ctx, zmq.PUSH)
|
||||
zsocket.bind(a, "inproc://a")
|
||||
ctx2 = self.Context.shadow_pyczmq(ctx)
|
||||
b = ctx2.socket(zmq.PULL)
|
||||
b.connect("inproc://a")
|
||||
zstr.send(a, b'hi')
|
||||
rcvd = self.recv(b)
|
||||
self.assertEqual(rcvd, b'hi')
|
||||
b.close()
|
||||
|
||||
@mark.skipif(
|
||||
sys.platform.startswith('win'),
|
||||
reason='No fork on Windows')
|
||||
def test_fork_instance(self):
|
||||
ctx = self.Context.instance()
|
||||
parent_ctx_id = id(ctx)
|
||||
r_fd, w_fd = os.pipe()
|
||||
reader = os.fdopen(r_fd, 'r')
|
||||
child_pid = os.fork()
|
||||
if child_pid == 0:
|
||||
ctx = self.Context.instance()
|
||||
writer = os.fdopen(w_fd, 'w')
|
||||
child_ctx_id = id(ctx)
|
||||
ctx.term()
|
||||
writer.write(str(child_ctx_id) + "\n")
|
||||
writer.flush()
|
||||
writer.close()
|
||||
os._exit(0)
|
||||
else:
|
||||
os.close(w_fd)
|
||||
|
||||
child_id_s = reader.readline()
|
||||
reader.close()
|
||||
assert child_id_s
|
||||
assert int(child_id_s) != parent_ctx_id
|
||||
ctx.term()
|
||||
|
||||
|
||||
if False: # disable green context tests
|
||||
class TestContextGreen(GreenTest, TestContext):
|
||||
"""gevent subclass of context tests"""
|
||||
# skip tests that use real threads:
|
||||
test_gc = GreenTest.skip_green
|
||||
test_term_thread = GreenTest.skip_green
|
||||
test_destroy_linger = GreenTest.skip_green
|
41
venv/Lib/site-packages/zmq/tests/test_cython.py
Normal file
41
venv/Lib/site-packages/zmq/tests/test_cython.py
Normal file
|
@ -0,0 +1,41 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
import zmq
|
||||
|
||||
@pytest.mark.skipif(
|
||||
'zmq.backend.cython' not in sys.modules, reason="Requires cython backend"
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith('win'), reason="Don't try runtime Cython on Windows"
|
||||
)
|
||||
@pytest.mark.parametrize('language_level', [3, 2])
|
||||
def test_cython(language_level, request, tmpdir):
|
||||
import pyximport
|
||||
|
||||
assert 'zmq.tests.cython_ext' not in sys.modules
|
||||
|
||||
importers = pyximport.install(
|
||||
setup_args=dict(include_dirs=zmq.get_includes()),
|
||||
language_level=language_level,
|
||||
build_dir=str(tmpdir),
|
||||
)
|
||||
|
||||
cython_ext = None
|
||||
|
||||
def unimport():
|
||||
pyximport.uninstall(*importers)
|
||||
sys.modules.pop('zmq.tests.cython_ext', None)
|
||||
|
||||
request.addfinalizer(unimport)
|
||||
|
||||
# this import tests the compilation
|
||||
from . import cython_ext
|
||||
assert hasattr(cython_ext, 'send_recv_test')
|
||||
|
||||
# call the compiled function
|
||||
# this shouldn't do much
|
||||
msg = b'my msg'
|
||||
received = cython_ext.send_recv_test(msg)
|
||||
assert received == msg
|
375
venv/Lib/site-packages/zmq/tests/test_decorators.py
Normal file
375
venv/Lib/site-packages/zmq/tests/test_decorators.py
Normal file
|
@ -0,0 +1,375 @@
|
|||
import threading
|
||||
import zmq
|
||||
|
||||
from pytest import raises
|
||||
from zmq.decorators import context, socket
|
||||
|
||||
|
||||
##############################################
|
||||
# Test cases for @context
|
||||
##############################################
|
||||
|
||||
|
||||
def test_ctx():
|
||||
@context()
|
||||
def test(ctx):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
test()
|
||||
|
||||
|
||||
def test_ctx_orig_args():
|
||||
@context()
|
||||
def f(foo, bar, ctx, baz=None):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert foo == 42
|
||||
assert bar is True
|
||||
assert baz == 'mock'
|
||||
|
||||
f(42, True, baz='mock')
|
||||
|
||||
|
||||
def test_ctx_arg_naming():
|
||||
@context('myctx')
|
||||
def test(myctx):
|
||||
assert isinstance(myctx, zmq.Context), myctx
|
||||
test()
|
||||
|
||||
|
||||
def test_ctx_args():
|
||||
@context('ctx', 5)
|
||||
def test(ctx):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert ctx.IO_THREADS == 5, ctx.IO_THREADS
|
||||
test()
|
||||
|
||||
|
||||
def test_ctx_arg_kwarg():
|
||||
@context('ctx', io_threads=5)
|
||||
def test(ctx):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert ctx.IO_THREADS == 5, ctx.IO_THREADS
|
||||
test()
|
||||
|
||||
|
||||
def test_ctx_kw_naming():
|
||||
@context(name='myctx')
|
||||
def test(myctx):
|
||||
assert isinstance(myctx, zmq.Context), myctx
|
||||
test()
|
||||
|
||||
|
||||
def test_ctx_kwargs():
|
||||
@context(name='ctx', io_threads=5)
|
||||
def test(ctx):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert ctx.IO_THREADS == 5, ctx.IO_THREADS
|
||||
test()
|
||||
|
||||
|
||||
def test_ctx_kwargs_default():
|
||||
@context(name='ctx', io_threads=5)
|
||||
def test(ctx=None):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert ctx.IO_THREADS == 5, ctx.IO_THREADS
|
||||
test()
|
||||
|
||||
|
||||
def test_ctx_keyword_miss():
|
||||
@context(name='ctx')
|
||||
def test(other_name):
|
||||
pass # the keyword ``ctx`` not found
|
||||
with raises(TypeError):
|
||||
test()
|
||||
|
||||
|
||||
def test_ctx_multi_assign():
|
||||
@context(name='ctx')
|
||||
def test(ctx):
|
||||
pass # explosion
|
||||
with raises(TypeError):
|
||||
test('mock')
|
||||
|
||||
|
||||
def test_ctx_reinit():
|
||||
result = {'foo': None, 'bar': None}
|
||||
|
||||
@context()
|
||||
def f(key, ctx):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
result[key] = ctx
|
||||
|
||||
foo_t = threading.Thread(target=f, args=('foo',))
|
||||
bar_t = threading.Thread(target=f, args=('bar',))
|
||||
|
||||
foo_t.start()
|
||||
bar_t.start()
|
||||
|
||||
foo_t.join()
|
||||
bar_t.join()
|
||||
|
||||
assert result['foo'] is not None, result
|
||||
assert result['bar'] is not None, result
|
||||
assert result['foo'] is not result['bar'], result
|
||||
|
||||
|
||||
def test_ctx_multi_thread():
|
||||
@context()
|
||||
@context()
|
||||
def f(foo, bar):
|
||||
assert isinstance(foo, zmq.Context), foo
|
||||
assert isinstance(bar, zmq.Context), bar
|
||||
|
||||
assert len(set(map(id, [foo, bar]))) == 2, set(map(id, [foo, bar]))
|
||||
|
||||
threads = [threading.Thread(target=f) for i in range(8)]
|
||||
[t.start() for t in threads]
|
||||
[t.join() for t in threads]
|
||||
|
||||
|
||||
##############################################
|
||||
# Test cases for @socket
|
||||
##############################################
|
||||
|
||||
|
||||
def test_ctx_skt():
|
||||
@context()
|
||||
@socket(zmq.PUB)
|
||||
def test(ctx, skt):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert isinstance(skt, zmq.Socket), skt
|
||||
assert skt.type == zmq.PUB
|
||||
test()
|
||||
|
||||
|
||||
def test_skt_name():
|
||||
@context()
|
||||
@socket('myskt', zmq.PUB)
|
||||
def test(ctx, myskt):
|
||||
assert isinstance(myskt, zmq.Socket), myskt
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert myskt.type == zmq.PUB
|
||||
test()
|
||||
|
||||
|
||||
def test_skt_kwarg():
|
||||
@context()
|
||||
@socket(zmq.PUB, name='myskt')
|
||||
def test(ctx, myskt):
|
||||
assert isinstance(myskt, zmq.Socket), myskt
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert myskt.type == zmq.PUB
|
||||
test()
|
||||
|
||||
|
||||
def test_ctx_skt_name():
|
||||
@context('ctx')
|
||||
@socket('skt', zmq.PUB, context_name='ctx')
|
||||
def test(ctx, skt):
|
||||
assert isinstance(skt, zmq.Socket), skt
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert skt.type == zmq.PUB
|
||||
test()
|
||||
|
||||
|
||||
def test_skt_default_ctx():
|
||||
@socket(zmq.PUB)
|
||||
def test(skt):
|
||||
assert isinstance(skt, zmq.Socket), skt
|
||||
assert skt.context is zmq.Context.instance()
|
||||
assert skt.type == zmq.PUB
|
||||
test()
|
||||
|
||||
|
||||
def test_skt_reinit():
|
||||
result = {'foo': None, 'bar': None}
|
||||
|
||||
@socket(zmq.PUB)
|
||||
def f(key, skt):
|
||||
assert isinstance(skt, zmq.Socket), skt
|
||||
|
||||
result[key] = skt
|
||||
|
||||
foo_t = threading.Thread(target=f, args=('foo',))
|
||||
bar_t = threading.Thread(target=f, args=('bar',))
|
||||
|
||||
foo_t.start()
|
||||
bar_t.start()
|
||||
|
||||
foo_t.join()
|
||||
bar_t.join()
|
||||
|
||||
assert result['foo'] is not None, result
|
||||
assert result['bar'] is not None, result
|
||||
assert result['foo'] is not result['bar'], result
|
||||
|
||||
|
||||
def test_ctx_skt_reinit():
|
||||
result = {'foo': {'ctx': None, 'skt': None},
|
||||
'bar': {'ctx': None, 'skt': None}}
|
||||
|
||||
@context()
|
||||
@socket(zmq.PUB)
|
||||
def f(key, ctx, skt):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert isinstance(skt, zmq.Socket), skt
|
||||
|
||||
result[key]['ctx'] = ctx
|
||||
result[key]['skt'] = skt
|
||||
|
||||
foo_t = threading.Thread(target=f, args=('foo',))
|
||||
bar_t = threading.Thread(target=f, args=('bar',))
|
||||
|
||||
foo_t.start()
|
||||
bar_t.start()
|
||||
|
||||
foo_t.join()
|
||||
bar_t.join()
|
||||
|
||||
assert result['foo']['ctx'] is not None, result
|
||||
assert result['foo']['skt'] is not None, result
|
||||
assert result['bar']['ctx'] is not None, result
|
||||
assert result['bar']['skt'] is not None, result
|
||||
assert result['foo']['ctx'] is not result['bar']['ctx'], result
|
||||
assert result['foo']['skt'] is not result['bar']['skt'], result
|
||||
|
||||
|
||||
def test_skt_type_miss():
|
||||
@context()
|
||||
@socket('myskt')
|
||||
def f(ctx, myskt):
|
||||
pass # the socket type is missing
|
||||
with raises(TypeError):
|
||||
f()
|
||||
|
||||
|
||||
def test_multi_skts():
|
||||
@socket(zmq.PUB)
|
||||
@socket(zmq.SUB)
|
||||
@socket(zmq.PUSH)
|
||||
def test(pub, sub, push):
|
||||
assert isinstance(pub, zmq.Socket), pub
|
||||
assert isinstance(sub, zmq.Socket), sub
|
||||
assert isinstance(push, zmq.Socket), push
|
||||
|
||||
assert pub.context is zmq.Context.instance()
|
||||
assert sub.context is zmq.Context.instance()
|
||||
assert push.context is zmq.Context.instance()
|
||||
|
||||
assert pub.type == zmq.PUB
|
||||
assert sub.type == zmq.SUB
|
||||
assert push.type == zmq.PUSH
|
||||
test()
|
||||
|
||||
|
||||
def test_multi_skts_single_ctx():
|
||||
@context()
|
||||
@socket(zmq.PUB)
|
||||
@socket(zmq.SUB)
|
||||
@socket(zmq.PUSH)
|
||||
def test(ctx, pub, sub, push):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert isinstance(pub, zmq.Socket), pub
|
||||
assert isinstance(sub, zmq.Socket), sub
|
||||
assert isinstance(push, zmq.Socket), push
|
||||
|
||||
assert pub.context is ctx
|
||||
assert sub.context is ctx
|
||||
assert push.context is ctx
|
||||
|
||||
assert pub.type == zmq.PUB
|
||||
assert sub.type == zmq.SUB
|
||||
assert push.type == zmq.PUSH
|
||||
test()
|
||||
|
||||
|
||||
def test_multi_skts_with_name():
|
||||
@socket('foo', zmq.PUSH)
|
||||
@socket('bar', zmq.SUB)
|
||||
@socket('baz', zmq.PUB)
|
||||
def test(foo, bar, baz):
|
||||
assert isinstance(foo, zmq.Socket), foo
|
||||
assert isinstance(bar, zmq.Socket), bar
|
||||
assert isinstance(baz, zmq.Socket), baz
|
||||
|
||||
assert foo.context is zmq.Context.instance()
|
||||
assert bar.context is zmq.Context.instance()
|
||||
assert baz.context is zmq.Context.instance()
|
||||
|
||||
assert foo.type == zmq.PUSH
|
||||
assert bar.type == zmq.SUB
|
||||
assert baz.type == zmq.PUB
|
||||
test()
|
||||
|
||||
def test_func_return():
|
||||
@context()
|
||||
def f(ctx):
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
return 'something'
|
||||
|
||||
assert f() == 'something'
|
||||
|
||||
|
||||
def test_skt_multi_thread():
|
||||
@socket(zmq.PUB)
|
||||
@socket(zmq.SUB)
|
||||
@socket(zmq.PUSH)
|
||||
def f(pub, sub, push):
|
||||
assert isinstance(pub, zmq.Socket), pub
|
||||
assert isinstance(sub, zmq.Socket), sub
|
||||
assert isinstance(push, zmq.Socket), push
|
||||
|
||||
assert pub.context is zmq.Context.instance()
|
||||
assert sub.context is zmq.Context.instance()
|
||||
assert push.context is zmq.Context.instance()
|
||||
|
||||
assert pub.type == zmq.PUB
|
||||
assert sub.type == zmq.SUB
|
||||
assert push.type == zmq.PUSH
|
||||
|
||||
assert len(set(map(id, [pub, sub, push]))) == 3
|
||||
|
||||
threads = [threading.Thread(target=f) for i in range(8)]
|
||||
[t.start() for t in threads]
|
||||
[t.join() for t in threads]
|
||||
|
||||
|
||||
class TestMethodDecorators():
|
||||
@context()
|
||||
@socket(zmq.PUB)
|
||||
@socket(zmq.SUB)
|
||||
def multi_skts_method(self, ctx, pub, sub, foo='bar'):
|
||||
assert isinstance(self, TestMethodDecorators), self
|
||||
assert isinstance(ctx, zmq.Context), ctx
|
||||
assert isinstance(pub, zmq.Socket), pub
|
||||
assert isinstance(sub, zmq.Socket), sub
|
||||
assert foo == 'bar'
|
||||
|
||||
assert pub.context is ctx
|
||||
assert sub.context is ctx
|
||||
|
||||
assert pub.type is zmq.PUB
|
||||
assert sub.type is zmq.SUB
|
||||
|
||||
def test_multi_skts_method(self):
|
||||
self.multi_skts_method()
|
||||
|
||||
def multi_skts_method_other_args(self):
|
||||
@socket(zmq.PUB)
|
||||
@socket(zmq.SUB)
|
||||
def f(foo, pub, sub, bar=None):
|
||||
assert isinstance(pub, zmq.Socket), pub
|
||||
assert isinstance(sub, zmq.Socket), sub
|
||||
|
||||
assert foo == 'mock'
|
||||
assert bar == 'fake'
|
||||
|
||||
assert pub.context is zmq.Context.instance()
|
||||
assert sub.context is zmq.Context.instance()
|
||||
|
||||
assert pub.type is zmq.PUB
|
||||
assert sub.type is zmq.SUB
|
||||
|
||||
f('mock', bar='fake')
|
||||
|
||||
def test_multi_skts_method_other_args(self):
|
||||
self.multi_skts_method_other_args()
|
167
venv/Lib/site-packages/zmq/tests/test_device.py
Normal file
167
venv/Lib/site-packages/zmq/tests/test_device.py
Normal file
|
@ -0,0 +1,167 @@
|
|||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import time
|
||||
|
||||
import zmq
|
||||
from zmq import devices
|
||||
from zmq.tests import BaseZMQTestCase, SkipTest, have_gevent, GreenTest, PYPY
|
||||
from zmq.utils.strtypes import (bytes,unicode,basestring)
|
||||
|
||||
if PYPY:
|
||||
# cleanup of shared Context doesn't work on PyPy
|
||||
devices.Device.context_factory = zmq.Context
|
||||
|
||||
class TestDevice(BaseZMQTestCase):
|
||||
|
||||
def test_device_types(self):
|
||||
for devtype in (zmq.STREAMER, zmq.FORWARDER, zmq.QUEUE):
|
||||
dev = devices.Device(devtype, zmq.PAIR, zmq.PAIR)
|
||||
self.assertEqual(dev.device_type, devtype)
|
||||
del dev
|
||||
|
||||
def test_device_attributes(self):
|
||||
dev = devices.Device(zmq.QUEUE, zmq.SUB, zmq.PUB)
|
||||
self.assertEqual(dev.in_type, zmq.SUB)
|
||||
self.assertEqual(dev.out_type, zmq.PUB)
|
||||
self.assertEqual(dev.device_type, zmq.QUEUE)
|
||||
self.assertEqual(dev.daemon, True)
|
||||
del dev
|
||||
|
||||
def test_single_socket_forwarder_connect(self):
|
||||
if zmq.zmq_version() in ('4.1.1', '4.0.6'):
|
||||
raise SkipTest("libzmq-%s broke single-socket devices" % zmq.zmq_version())
|
||||
dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1)
|
||||
req = self.context.socket(zmq.REQ)
|
||||
port = req.bind_to_random_port('tcp://127.0.0.1')
|
||||
dev.connect_in('tcp://127.0.0.1:%i'%port)
|
||||
dev.start()
|
||||
time.sleep(.25)
|
||||
msg = b'hello'
|
||||
req.send(msg)
|
||||
self.assertEqual(msg, self.recv(req))
|
||||
del dev
|
||||
req.close()
|
||||
dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1)
|
||||
req = self.context.socket(zmq.REQ)
|
||||
port = req.bind_to_random_port('tcp://127.0.0.1')
|
||||
dev.connect_out('tcp://127.0.0.1:%i'%port)
|
||||
dev.start()
|
||||
time.sleep(.25)
|
||||
msg = b'hello again'
|
||||
req.send(msg)
|
||||
self.assertEqual(msg, self.recv(req))
|
||||
del dev
|
||||
req.close()
|
||||
|
||||
def test_single_socket_forwarder_bind(self):
|
||||
if zmq.zmq_version() in ('4.1.1', '4.0.6'):
|
||||
raise SkipTest("libzmq-%s broke single-socket devices" % zmq.zmq_version())
|
||||
dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1)
|
||||
port = dev.bind_in_to_random_port('tcp://127.0.0.1')
|
||||
req = self.context.socket(zmq.REQ)
|
||||
req.connect('tcp://127.0.0.1:%i'%port)
|
||||
dev.start()
|
||||
time.sleep(.25)
|
||||
msg = b'hello'
|
||||
req.send(msg)
|
||||
self.assertEqual(msg, self.recv(req))
|
||||
del dev
|
||||
req.close()
|
||||
dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1)
|
||||
port = dev.bind_in_to_random_port('tcp://127.0.0.1')
|
||||
req = self.context.socket(zmq.REQ)
|
||||
req.connect('tcp://127.0.0.1:%i'%port)
|
||||
dev.start()
|
||||
time.sleep(.25)
|
||||
msg = b'hello again'
|
||||
req.send(msg)
|
||||
self.assertEqual(msg, self.recv(req))
|
||||
del dev
|
||||
req.close()
|
||||
|
||||
def test_device_bind_to_random_with_args(self):
|
||||
dev = devices.ThreadDevice(zmq.PULL, zmq.PUSH, -1)
|
||||
iface = 'tcp://127.0.0.1'
|
||||
ports = []
|
||||
min, max = 5000, 5050
|
||||
ports.extend([
|
||||
dev.bind_in_to_random_port(iface, min_port=min, max_port=max),
|
||||
dev.bind_out_to_random_port(iface, min_port=min, max_port=max)
|
||||
])
|
||||
for port in ports:
|
||||
if port < min or port > max:
|
||||
self.fail('Unexpected port number: %i' % port)
|
||||
|
||||
def test_device_bind_to_random_binderror(self):
|
||||
dev = devices.ThreadDevice(zmq.PULL, zmq.PUSH, -1)
|
||||
iface = 'tcp://127.0.0.1'
|
||||
try:
|
||||
for i in range(11):
|
||||
dev.bind_in_to_random_port(
|
||||
iface, min_port=10000, max_port=10010
|
||||
)
|
||||
except zmq.ZMQBindError as e:
|
||||
return
|
||||
else:
|
||||
self.fail('Should have failed')
|
||||
|
||||
def test_proxy(self):
|
||||
if zmq.zmq_version_info() < (3,2):
|
||||
raise SkipTest("Proxies only in libzmq >= 3")
|
||||
dev = devices.ThreadProxy(zmq.PULL, zmq.PUSH, zmq.PUSH)
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = dev.bind_in_to_random_port(iface)
|
||||
port2 = dev.bind_out_to_random_port(iface)
|
||||
port3 = dev.bind_mon_to_random_port(iface)
|
||||
dev.start()
|
||||
time.sleep(0.25)
|
||||
msg = b'hello'
|
||||
push = self.context.socket(zmq.PUSH)
|
||||
push.connect("%s:%i" % (iface, port))
|
||||
pull = self.context.socket(zmq.PULL)
|
||||
pull.connect("%s:%i" % (iface, port2))
|
||||
mon = self.context.socket(zmq.PULL)
|
||||
mon.connect("%s:%i" % (iface, port3))
|
||||
push.send(msg)
|
||||
self.sockets.extend([push, pull, mon])
|
||||
self.assertEqual(msg, self.recv(pull))
|
||||
self.assertEqual(msg, self.recv(mon))
|
||||
|
||||
def test_proxy_bind_to_random_with_args(self):
|
||||
if zmq.zmq_version_info() < (3, 2):
|
||||
raise SkipTest("Proxies only in libzmq >= 3")
|
||||
dev = devices.ThreadProxy(zmq.PULL, zmq.PUSH, zmq.PUSH)
|
||||
iface = 'tcp://127.0.0.1'
|
||||
ports = []
|
||||
min, max = 5000, 5050
|
||||
ports.extend([
|
||||
dev.bind_in_to_random_port(iface, min_port=min, max_port=max),
|
||||
dev.bind_out_to_random_port(iface, min_port=min, max_port=max),
|
||||
dev.bind_mon_to_random_port(iface, min_port=min, max_port=max)
|
||||
])
|
||||
for port in ports:
|
||||
if port < min or port > max:
|
||||
self.fail('Unexpected port number: %i' % port)
|
||||
|
||||
if have_gevent:
|
||||
import gevent
|
||||
import zmq.green
|
||||
|
||||
class TestDeviceGreen(GreenTest, BaseZMQTestCase):
|
||||
|
||||
def test_green_device(self):
|
||||
rep = self.context.socket(zmq.REP)
|
||||
req = self.context.socket(zmq.REQ)
|
||||
self.sockets.extend([req, rep])
|
||||
port = rep.bind_to_random_port('tcp://127.0.0.1')
|
||||
g = gevent.spawn(zmq.green.device, zmq.QUEUE, rep, rep)
|
||||
req.connect('tcp://127.0.0.1:%i' % port)
|
||||
req.send(b'hi')
|
||||
timeout = gevent.Timeout(3)
|
||||
timeout.start()
|
||||
receiver = gevent.spawn(req.recv)
|
||||
self.assertEqual(receiver.get(2), b'hi')
|
||||
timeout.cancel()
|
||||
g.kill(block=True)
|
||||
|
52
venv/Lib/site-packages/zmq/tests/test_draft.py
Normal file
52
venv/Lib/site-packages/zmq/tests/test_draft.py
Normal file
|
@ -0,0 +1,52 @@
|
|||
# -*- coding: utf8 -*-
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import os
|
||||
import platform
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import zmq
|
||||
from zmq.tests import (
|
||||
BaseZMQTestCase, skip_pypy
|
||||
)
|
||||
|
||||
|
||||
class TestDraftSockets(BaseZMQTestCase):
|
||||
def setUp(self):
|
||||
if not zmq.DRAFT_API:
|
||||
raise pytest.skip("draft api unavailable")
|
||||
super(TestDraftSockets, self).setUp()
|
||||
|
||||
|
||||
def test_client_server(self):
|
||||
client, server = self.create_bound_pair(zmq.CLIENT, zmq.SERVER)
|
||||
client.send(b'request')
|
||||
msg = self.recv(server, copy=False)
|
||||
assert msg.routing_id is not None
|
||||
server.send(b'reply', routing_id=msg.routing_id)
|
||||
reply = self.recv(client)
|
||||
assert reply == b'reply'
|
||||
|
||||
def test_radio_dish(self):
|
||||
dish, radio = self.create_bound_pair(zmq.DISH, zmq.RADIO)
|
||||
dish.rcvtimeo = 250
|
||||
group = 'mygroup'
|
||||
dish.join(group)
|
||||
received_count = 0
|
||||
received = set()
|
||||
sent = set()
|
||||
for i in range(10):
|
||||
msg = str(i).encode('ascii')
|
||||
sent.add(msg)
|
||||
radio.send(msg, group=group)
|
||||
try:
|
||||
recvd = dish.recv()
|
||||
except zmq.Again:
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
received.add(recvd)
|
||||
received_count += 1
|
||||
# assert that we got *something*
|
||||
assert len(received.intersection(sent)) >= 5
|
43
venv/Lib/site-packages/zmq/tests/test_error.py
Normal file
43
venv/Lib/site-packages/zmq/tests/test_error.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
# -*- coding: utf8 -*-
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import sys
|
||||
import time
|
||||
from threading import Thread
|
||||
|
||||
import zmq
|
||||
from zmq import ZMQError, strerror, Again, ContextTerminated
|
||||
from zmq.tests import BaseZMQTestCase
|
||||
|
||||
if sys.version_info[0] >= 3:
|
||||
long = int
|
||||
|
||||
class TestZMQError(BaseZMQTestCase):
|
||||
|
||||
def test_strerror(self):
|
||||
"""test that strerror gets the right type."""
|
||||
for i in range(10):
|
||||
e = strerror(i)
|
||||
self.assertTrue(isinstance(e, str))
|
||||
|
||||
def test_zmqerror(self):
|
||||
for errno in range(10):
|
||||
e = ZMQError(errno)
|
||||
self.assertEqual(e.errno, errno)
|
||||
self.assertEqual(str(e), strerror(errno))
|
||||
|
||||
def test_again(self):
|
||||
s = self.context.socket(zmq.REP)
|
||||
self.assertRaises(Again, s.recv, zmq.NOBLOCK)
|
||||
self.assertRaisesErrno(zmq.EAGAIN, s.recv, zmq.NOBLOCK)
|
||||
s.close()
|
||||
|
||||
def atest_ctxterm(self):
|
||||
s = self.context.socket(zmq.REP)
|
||||
t = Thread(target=self.context.term)
|
||||
t.start()
|
||||
self.assertRaises(ContextTerminated, s.recv, zmq.NOBLOCK)
|
||||
self.assertRaisesErrno(zmq.TERM, s.recv, zmq.NOBLOCK)
|
||||
s.close()
|
||||
t.join()
|
20
venv/Lib/site-packages/zmq/tests/test_etc.py
Normal file
20
venv/Lib/site-packages/zmq/tests/test_etc.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) PyZMQ Developers.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import sys
|
||||
|
||||
import zmq
|
||||
|
||||
from pytest import mark
|
||||
|
||||
@mark.skipif('zmq.zmq_version_info() < (4,1)')
|
||||
def test_has():
|
||||
assert not zmq.has('something weird')
|
||||
has_ipc = zmq.has('ipc')
|
||||
not_windows = not sys.platform.startswith('win')
|
||||
assert has_ipc == not_windows
|
||||
|
||||
@mark.skipif(not hasattr(zmq, '_libzmq'), reason="bundled libzmq")
|
||||
def test_has_curve():
|
||||
"""bundled libzmq has curve support"""
|
||||
assert zmq.has('curve')
|
353
venv/Lib/site-packages/zmq/tests/test_future.py
Normal file
353
venv/Lib/site-packages/zmq/tests/test_future.py
Normal file
|
@ -0,0 +1,353 @@
|
|||
# coding: utf-8
|
||||
# Copyright (c) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
from datetime import timedelta
|
||||
import os
|
||||
import json
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
gen = pytest.importorskip('tornado.gen')
|
||||
|
||||
import zmq
|
||||
from zmq.eventloop import future
|
||||
from tornado.ioloop import IOLoop
|
||||
from zmq.utils.strtypes import u
|
||||
|
||||
from zmq.tests import BaseZMQTestCase
|
||||
|
||||
class TestFutureSocket(BaseZMQTestCase):
|
||||
Context = future.Context
|
||||
|
||||
def setUp(self):
|
||||
self.loop = IOLoop()
|
||||
self.loop.make_current()
|
||||
super(TestFutureSocket, self).setUp()
|
||||
|
||||
def tearDown(self):
|
||||
super(TestFutureSocket, self).tearDown()
|
||||
if self.loop:
|
||||
self.loop.close(all_fds=True)
|
||||
IOLoop.clear_current()
|
||||
IOLoop.clear_instance()
|
||||
|
||||
def test_socket_class(self):
|
||||
s = self.context.socket(zmq.PUSH)
|
||||
assert isinstance(s, future.Socket)
|
||||
s.close()
|
||||
|
||||
def test_instance_subclass_first(self):
|
||||
actx = self.Context.instance()
|
||||
ctx = zmq.Context.instance()
|
||||
ctx.term()
|
||||
actx.term()
|
||||
assert type(ctx) is zmq.Context
|
||||
assert type(actx) is self.Context
|
||||
|
||||
def test_instance_subclass_second(self):
|
||||
ctx = zmq.Context.instance()
|
||||
actx = self.Context.instance()
|
||||
ctx.term()
|
||||
actx.term()
|
||||
assert type(ctx) is zmq.Context
|
||||
assert type(actx) is self.Context
|
||||
|
||||
def test_recv_multipart(self):
|
||||
@gen.coroutine
|
||||
def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.recv_multipart()
|
||||
assert not f.done()
|
||||
yield a.send(b'hi')
|
||||
recvd = yield f
|
||||
self.assertEqual(recvd, [b'hi'])
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_recv(self):
|
||||
@gen.coroutine
|
||||
def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f1 = b.recv()
|
||||
f2 = b.recv()
|
||||
assert not f1.done()
|
||||
assert not f2.done()
|
||||
yield a.send_multipart([b'hi', b'there'])
|
||||
recvd = yield f2
|
||||
assert f1.done()
|
||||
self.assertEqual(f1.result(), b'hi')
|
||||
self.assertEqual(recvd, b'there')
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_recv_cancel(self):
|
||||
@gen.coroutine
|
||||
def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f1 = b.recv()
|
||||
f2 = b.recv_multipart()
|
||||
assert f1.cancel()
|
||||
assert f1.done()
|
||||
assert not f2.done()
|
||||
yield a.send_multipart([b'hi', b'there'])
|
||||
recvd = yield f2
|
||||
assert f1.cancelled()
|
||||
assert f2.done()
|
||||
self.assertEqual(recvd, [b'hi', b'there'])
|
||||
self.loop.run_sync(test)
|
||||
|
||||
@pytest.mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
|
||||
def test_recv_timeout(self):
|
||||
@gen.coroutine
|
||||
def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
b.rcvtimeo = 100
|
||||
f1 = b.recv()
|
||||
b.rcvtimeo = 1000
|
||||
f2 = b.recv_multipart()
|
||||
with pytest.raises(zmq.Again):
|
||||
yield f1
|
||||
yield a.send_multipart([b'hi', b'there'])
|
||||
recvd = yield f2
|
||||
assert f2.done()
|
||||
self.assertEqual(recvd, [b'hi', b'there'])
|
||||
self.loop.run_sync(test)
|
||||
|
||||
@pytest.mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
|
||||
def test_send_timeout(self):
|
||||
@gen.coroutine
|
||||
def test():
|
||||
s = self.socket(zmq.PUSH)
|
||||
s.sndtimeo = 100
|
||||
with pytest.raises(zmq.Again):
|
||||
yield s.send(b'not going anywhere')
|
||||
self.loop.run_sync(test)
|
||||
|
||||
@pytest.mark.now
|
||||
def test_send_noblock(self):
|
||||
@gen.coroutine
|
||||
def test():
|
||||
s = self.socket(zmq.PUSH)
|
||||
with pytest.raises(zmq.Again):
|
||||
yield s.send(b'not going anywhere', flags=zmq.NOBLOCK)
|
||||
self.loop.run_sync(test)
|
||||
|
||||
@pytest.mark.now
|
||||
def test_send_multipart_noblock(self):
|
||||
@gen.coroutine
|
||||
def test():
|
||||
s = self.socket(zmq.PUSH)
|
||||
with pytest.raises(zmq.Again):
|
||||
yield s.send_multipart([b'not going anywhere'], flags=zmq.NOBLOCK)
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_recv_string(self):
|
||||
@gen.coroutine
|
||||
def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.recv_string()
|
||||
assert not f.done()
|
||||
msg = u('πøøπ')
|
||||
yield a.send_string(msg)
|
||||
recvd = yield f
|
||||
assert f.done()
|
||||
self.assertEqual(f.result(), msg)
|
||||
self.assertEqual(recvd, msg)
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_recv_json(self):
|
||||
@gen.coroutine
|
||||
def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.recv_json()
|
||||
assert not f.done()
|
||||
obj = dict(a=5)
|
||||
yield a.send_json(obj)
|
||||
recvd = yield f
|
||||
assert f.done()
|
||||
self.assertEqual(f.result(), obj)
|
||||
self.assertEqual(recvd, obj)
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_recv_json_cancelled(self):
|
||||
@gen.coroutine
|
||||
def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.recv_json()
|
||||
assert not f.done()
|
||||
f.cancel()
|
||||
# cycle eventloop to allow cancel events to fire
|
||||
yield gen.sleep(0)
|
||||
obj = dict(a=5)
|
||||
yield a.send_json(obj)
|
||||
with pytest.raises(future.CancelledError):
|
||||
recvd = yield f
|
||||
assert f.done()
|
||||
# give it a chance to incorrectly consume the event
|
||||
events = yield b.poll(timeout=5)
|
||||
assert events
|
||||
yield gen.sleep(0)
|
||||
# make sure cancelled recv didn't eat up event
|
||||
recvd = yield gen.with_timeout(timedelta(seconds=5), b.recv_json())
|
||||
assert recvd == obj
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_recv_pyobj(self):
|
||||
@gen.coroutine
|
||||
def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.recv_pyobj()
|
||||
assert not f.done()
|
||||
obj = dict(a=5)
|
||||
yield a.send_pyobj(obj)
|
||||
recvd = yield f
|
||||
assert f.done()
|
||||
self.assertEqual(f.result(), obj)
|
||||
self.assertEqual(recvd, obj)
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_custom_serialize(self):
|
||||
def serialize(msg):
|
||||
frames = []
|
||||
frames.extend(msg.get('identities', []))
|
||||
content = json.dumps(msg['content']).encode('utf8')
|
||||
frames.append(content)
|
||||
return frames
|
||||
|
||||
def deserialize(frames):
|
||||
identities = frames[:-1]
|
||||
content = json.loads(frames[-1].decode('utf8'))
|
||||
return {
|
||||
'identities': identities,
|
||||
'content': content,
|
||||
}
|
||||
|
||||
@gen.coroutine
|
||||
def test():
|
||||
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
|
||||
|
||||
msg = {
|
||||
'content': {
|
||||
'a': 5,
|
||||
'b': 'bee',
|
||||
}
|
||||
}
|
||||
yield a.send_serialized(msg, serialize)
|
||||
recvd = yield b.recv_serialized(deserialize)
|
||||
assert recvd['content'] == msg['content']
|
||||
assert recvd['identities']
|
||||
# bounce back, tests identities
|
||||
yield b.send_serialized(recvd, serialize)
|
||||
r2 = yield a.recv_serialized(deserialize)
|
||||
assert r2['content'] == msg['content']
|
||||
assert not r2['identities']
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_custom_serialize_error(self):
|
||||
@gen.coroutine
|
||||
def test():
|
||||
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
|
||||
|
||||
msg = {
|
||||
'content': {
|
||||
'a': 5,
|
||||
'b': 'bee',
|
||||
}
|
||||
}
|
||||
with pytest.raises(TypeError):
|
||||
yield a.send_serialized(json, json.dumps)
|
||||
|
||||
yield a.send(b'not json')
|
||||
with pytest.raises(TypeError):
|
||||
recvd = yield b.recv_serialized(json.loads)
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_poll(self):
|
||||
@gen.coroutine
|
||||
def test():
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
f = b.poll(timeout=0)
|
||||
assert f.done()
|
||||
self.assertEqual(f.result(), 0)
|
||||
|
||||
f = b.poll(timeout=1)
|
||||
assert not f.done()
|
||||
evt = yield f
|
||||
self.assertEqual(evt, 0)
|
||||
|
||||
f = b.poll(timeout=1000)
|
||||
assert not f.done()
|
||||
yield a.send_multipart([b'hi', b'there'])
|
||||
evt = yield f
|
||||
self.assertEqual(evt, zmq.POLLIN)
|
||||
recvd = yield b.recv_multipart()
|
||||
self.assertEqual(recvd, [b'hi', b'there'])
|
||||
self.loop.run_sync(test)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith('win'),
|
||||
reason='Windows unsupported socket type')
|
||||
def test_poll_base_socket(self):
|
||||
@gen.coroutine
|
||||
def test():
|
||||
ctx = zmq.Context()
|
||||
url = 'inproc://test'
|
||||
a = ctx.socket(zmq.PUSH)
|
||||
b = ctx.socket(zmq.PULL)
|
||||
self.sockets.extend([a, b])
|
||||
a.bind(url)
|
||||
b.connect(url)
|
||||
|
||||
poller = future.Poller()
|
||||
poller.register(b, zmq.POLLIN)
|
||||
|
||||
f = poller.poll(timeout=1000)
|
||||
assert not f.done()
|
||||
a.send_multipart([b'hi', b'there'])
|
||||
evt = yield f
|
||||
self.assertEqual(evt, [(b, zmq.POLLIN)])
|
||||
recvd = b.recv_multipart()
|
||||
self.assertEqual(recvd, [b'hi', b'there'])
|
||||
a.close()
|
||||
b.close()
|
||||
ctx.term()
|
||||
self.loop.run_sync(test)
|
||||
|
||||
def test_close_all_fds(self):
|
||||
s = self.socket(zmq.PUB)
|
||||
self.loop.close(all_fds=True)
|
||||
self.loop = None # avoid second close later
|
||||
assert s.closed
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith('win'),
|
||||
reason='Windows does not support polling on files')
|
||||
def test_poll_raw(self):
|
||||
@gen.coroutine
|
||||
def test():
|
||||
p = future.Poller()
|
||||
# make a pipe
|
||||
r, w = os.pipe()
|
||||
r = os.fdopen(r, 'rb')
|
||||
w = os.fdopen(w, 'wb')
|
||||
|
||||
# POLLOUT
|
||||
p.register(r, zmq.POLLIN)
|
||||
p.register(w, zmq.POLLOUT)
|
||||
evts = yield p.poll(timeout=1)
|
||||
evts = dict(evts)
|
||||
assert r.fileno() not in evts
|
||||
assert w.fileno() in evts
|
||||
assert evts[w.fileno()] == zmq.POLLOUT
|
||||
|
||||
# POLLIN
|
||||
p.unregister(w)
|
||||
w.write(b'x')
|
||||
w.flush()
|
||||
evts = yield p.poll(timeout=1000)
|
||||
evts = dict(evts)
|
||||
assert r.fileno() in evts
|
||||
assert evts[r.fileno()] == zmq.POLLIN
|
||||
assert r.read(1) == b'x'
|
||||
r.close()
|
||||
w.close()
|
||||
self.loop.run_sync(test)
|
68
venv/Lib/site-packages/zmq/tests/test_imports.py
Normal file
68
venv/Lib/site-packages/zmq/tests/test_imports.py
Normal file
|
@ -0,0 +1,68 @@
|
|||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import sys
|
||||
from unittest import TestCase
|
||||
|
||||
import pytest
|
||||
|
||||
class TestImports(TestCase):
|
||||
"""Test Imports - the quickest test to ensure that we haven't
|
||||
introduced version-incompatible syntax errors."""
|
||||
|
||||
def test_toplevel(self):
|
||||
"""test toplevel import"""
|
||||
import zmq
|
||||
|
||||
def test_core(self):
|
||||
"""test core imports"""
|
||||
from zmq import Context
|
||||
from zmq import Socket
|
||||
from zmq import Poller
|
||||
from zmq import Frame
|
||||
from zmq import constants
|
||||
from zmq import device, proxy
|
||||
from zmq import (
|
||||
zmq_version,
|
||||
zmq_version_info,
|
||||
pyzmq_version,
|
||||
pyzmq_version_info,
|
||||
)
|
||||
|
||||
def test_devices(self):
|
||||
"""test device imports"""
|
||||
import zmq.devices
|
||||
from zmq.devices import basedevice
|
||||
from zmq.devices import monitoredqueue
|
||||
from zmq.devices import monitoredqueuedevice
|
||||
|
||||
def test_log(self):
|
||||
"""test log imports"""
|
||||
import zmq.log
|
||||
from zmq.log import handlers
|
||||
|
||||
def test_eventloop(self):
|
||||
"""test eventloop imports"""
|
||||
try:
|
||||
import tornado
|
||||
except ImportError:
|
||||
pytest.skip('requires tornado')
|
||||
import zmq.eventloop
|
||||
from zmq.eventloop import ioloop
|
||||
from zmq.eventloop import zmqstream
|
||||
|
||||
def test_utils(self):
|
||||
"""test util imports"""
|
||||
import zmq.utils
|
||||
from zmq.utils import strtypes
|
||||
from zmq.utils import jsonapi
|
||||
|
||||
def test_ssh(self):
|
||||
"""test ssh imports"""
|
||||
from zmq.ssh import tunnel
|
||||
|
||||
def test_decorators(self):
|
||||
"""test decorators imports"""
|
||||
from zmq.decorators import context, socket
|
||||
|
||||
|
33
venv/Lib/site-packages/zmq/tests/test_includes.py
Normal file
33
venv/Lib/site-packages/zmq/tests/test_includes.py
Normal file
|
@ -0,0 +1,33 @@
|
|||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
from unittest import TestCase
|
||||
import zmq
|
||||
import os
|
||||
|
||||
class TestIncludes(TestCase):
|
||||
|
||||
def test_get_includes(self):
|
||||
from os.path import dirname, basename
|
||||
includes = zmq.get_includes()
|
||||
self.assertTrue(isinstance(includes, list))
|
||||
self.assertTrue(len(includes) >= 2)
|
||||
parent = includes[0]
|
||||
self.assertTrue(isinstance(parent, str))
|
||||
utilsdir = includes[1]
|
||||
self.assertTrue(isinstance(utilsdir, str))
|
||||
utils = basename(utilsdir)
|
||||
self.assertEqual(utils, "utils")
|
||||
|
||||
def test_get_library_dirs(self):
|
||||
from os.path import dirname, basename
|
||||
libdirs = zmq.get_library_dirs()
|
||||
self.assertTrue(isinstance(libdirs, list))
|
||||
self.assertEqual(len(libdirs), 1)
|
||||
parent = libdirs[0]
|
||||
self.assertTrue(isinstance(parent, str))
|
||||
libdir = basename(parent)
|
||||
self.assertEqual(libdir, "zmq")
|
||||
|
||||
|
141
venv/Lib/site-packages/zmq/tests/test_ioloop.py
Normal file
141
venv/Lib/site-packages/zmq/tests/test_ioloop.py
Normal file
|
@ -0,0 +1,141 @@
|
|||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
try:
|
||||
import asyncio
|
||||
except ImportError:
|
||||
asyncio = None
|
||||
import time
|
||||
import os
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
|
||||
import zmq
|
||||
from zmq.tests import BaseZMQTestCase, have_gevent
|
||||
try:
|
||||
from tornado.ioloop import IOLoop as BaseIOLoop
|
||||
from zmq.eventloop import ioloop
|
||||
_tornado = True
|
||||
except ImportError:
|
||||
_tornado = False
|
||||
|
||||
|
||||
# tornado 5 with asyncio disables custom IOLoop implementations
|
||||
t5asyncio = False
|
||||
if _tornado:
|
||||
import tornado
|
||||
if tornado.version_info >= (5,) and asyncio:
|
||||
t5asyncio = True
|
||||
|
||||
def printer():
|
||||
os.system("say hello")
|
||||
raise Exception
|
||||
print (time.time())
|
||||
|
||||
|
||||
class Delay(threading.Thread):
|
||||
def __init__(self, f, delay=1):
|
||||
self.f=f
|
||||
self.delay=delay
|
||||
self.aborted=False
|
||||
self.cond=threading.Condition()
|
||||
super(Delay, self).__init__()
|
||||
|
||||
def run(self):
|
||||
self.cond.acquire()
|
||||
self.cond.wait(self.delay)
|
||||
self.cond.release()
|
||||
if not self.aborted:
|
||||
self.f()
|
||||
|
||||
def abort(self):
|
||||
self.aborted=True
|
||||
self.cond.acquire()
|
||||
self.cond.notify()
|
||||
self.cond.release()
|
||||
|
||||
|
||||
class TestIOLoop(BaseZMQTestCase):
|
||||
if _tornado:
|
||||
IOLoop = ioloop.IOLoop
|
||||
|
||||
def setUp(self):
|
||||
if not _tornado:
|
||||
pytest.skip("tornado required")
|
||||
super(TestIOLoop, self).setUp()
|
||||
if asyncio:
|
||||
asyncio.set_event_loop(asyncio.new_event_loop())
|
||||
|
||||
def tearDown(self):
|
||||
super(TestIOLoop, self).tearDown()
|
||||
BaseIOLoop.clear_current()
|
||||
BaseIOLoop.clear_instance()
|
||||
|
||||
def test_simple(self):
|
||||
"""simple IOLoop creation test"""
|
||||
loop = self.IOLoop()
|
||||
loop.make_current()
|
||||
dc = ioloop.PeriodicCallback(loop.stop, 200)
|
||||
pc = ioloop.PeriodicCallback(lambda : None, 10)
|
||||
pc.start()
|
||||
dc.start()
|
||||
t = Delay(loop.stop,1)
|
||||
t.start()
|
||||
loop.start()
|
||||
if t.is_alive():
|
||||
t.abort()
|
||||
else:
|
||||
self.fail("IOLoop failed to exit")
|
||||
|
||||
def test_instance(self):
|
||||
"""IOLoop.instance returns the right object"""
|
||||
loop = self.IOLoop.instance()
|
||||
if not t5asyncio:
|
||||
assert isinstance(loop, self.IOLoop)
|
||||
base_loop = BaseIOLoop.instance()
|
||||
assert base_loop is loop
|
||||
|
||||
def test_current(self):
|
||||
"""IOLoop.current returns the right object"""
|
||||
loop = ioloop.IOLoop.current()
|
||||
if not t5asyncio:
|
||||
assert isinstance(loop, self.IOLoop)
|
||||
base_loop = BaseIOLoop.current()
|
||||
assert base_loop is loop
|
||||
|
||||
def test_close_all(self):
|
||||
"""Test close(all_fds=True)"""
|
||||
loop = self.IOLoop.current()
|
||||
req,rep = self.create_bound_pair(zmq.REQ, zmq.REP)
|
||||
loop.add_handler(req, lambda msg: msg, ioloop.IOLoop.READ)
|
||||
loop.add_handler(rep, lambda msg: msg, ioloop.IOLoop.READ)
|
||||
self.assertEqual(req.closed, False)
|
||||
self.assertEqual(rep.closed, False)
|
||||
loop.close(all_fds=True)
|
||||
self.assertEqual(req.closed, True)
|
||||
self.assertEqual(rep.closed, True)
|
||||
|
||||
|
||||
if have_gevent and _tornado:
|
||||
import zmq.green.eventloop.ioloop as green_ioloop
|
||||
|
||||
class TestIOLoopGreen(TestIOLoop):
|
||||
IOLoop = green_ioloop.IOLoop
|
||||
def xtest_instance(self):
|
||||
"""Green IOLoop.instance returns the right object"""
|
||||
loop = self.IOLoop.instance()
|
||||
if not t5asyncio:
|
||||
assert isinstance(loop, self.IOLoop)
|
||||
base_loop = BaseIOLoop.instance()
|
||||
assert base_loop is loop
|
||||
|
||||
def xtest_current(self):
|
||||
"""Green IOLoop.current returns the right object"""
|
||||
loop = self.IOLoop.current()
|
||||
if not t5asyncio:
|
||||
assert isinstance(loop, self.IOLoop)
|
||||
base_loop = BaseIOLoop.current()
|
||||
assert base_loop is loop
|
||||
|
178
venv/Lib/site-packages/zmq/tests/test_log.py
Normal file
178
venv/Lib/site-packages/zmq/tests/test_log.py
Normal file
|
@ -0,0 +1,178 @@
|
|||
# encoding: utf-8
|
||||
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
import logging
|
||||
import time
|
||||
from unittest import TestCase
|
||||
|
||||
import zmq
|
||||
from zmq.log import handlers
|
||||
from zmq.utils.strtypes import b, u
|
||||
from zmq.tests import BaseZMQTestCase
|
||||
|
||||
|
||||
class TestPubLog(BaseZMQTestCase):
|
||||
|
||||
iface = 'inproc://zmqlog'
|
||||
topic= 'zmq'
|
||||
|
||||
@property
|
||||
def logger(self):
|
||||
# print dir(self)
|
||||
logger = logging.getLogger('zmqtest')
|
||||
logger.setLevel(logging.DEBUG)
|
||||
return logger
|
||||
|
||||
def connect_handler(self, topic=None):
|
||||
topic = self.topic if topic is None else topic
|
||||
logger = self.logger
|
||||
pub,sub = self.create_bound_pair(zmq.PUB, zmq.SUB)
|
||||
handler = handlers.PUBHandler(pub)
|
||||
handler.setLevel(logging.DEBUG)
|
||||
handler.root_topic = topic
|
||||
logger.addHandler(handler)
|
||||
sub.setsockopt(zmq.SUBSCRIBE, b(topic))
|
||||
time.sleep(0.1)
|
||||
return logger, handler, sub
|
||||
|
||||
def test_init_iface(self):
|
||||
logger = self.logger
|
||||
ctx = self.context
|
||||
handler = handlers.PUBHandler(self.iface)
|
||||
self.assertFalse(handler.ctx is ctx)
|
||||
self.sockets.append(handler.socket)
|
||||
# handler.ctx.term()
|
||||
handler = handlers.PUBHandler(self.iface, self.context)
|
||||
self.sockets.append(handler.socket)
|
||||
self.assertTrue(handler.ctx is ctx)
|
||||
handler.setLevel(logging.DEBUG)
|
||||
handler.root_topic = self.topic
|
||||
logger.addHandler(handler)
|
||||
sub = ctx.socket(zmq.SUB)
|
||||
self.sockets.append(sub)
|
||||
sub.setsockopt(zmq.SUBSCRIBE, b(self.topic))
|
||||
sub.connect(self.iface)
|
||||
import time; time.sleep(0.25)
|
||||
msg1 = 'message'
|
||||
logger.info(msg1)
|
||||
|
||||
(topic, msg2) = sub.recv_multipart()
|
||||
self.assertEqual(topic, b'zmq.INFO')
|
||||
self.assertEqual(msg2, b(msg1)+b'\n')
|
||||
logger.removeHandler(handler)
|
||||
|
||||
def test_init_socket(self):
|
||||
pub,sub = self.create_bound_pair(zmq.PUB, zmq.SUB)
|
||||
logger = self.logger
|
||||
handler = handlers.PUBHandler(pub)
|
||||
handler.setLevel(logging.DEBUG)
|
||||
handler.root_topic = self.topic
|
||||
logger.addHandler(handler)
|
||||
|
||||
self.assertTrue(handler.socket is pub)
|
||||
self.assertTrue(handler.ctx is pub.context)
|
||||
self.assertTrue(handler.ctx is self.context)
|
||||
sub.setsockopt(zmq.SUBSCRIBE, b(self.topic))
|
||||
import time; time.sleep(0.1)
|
||||
msg1 = 'message'
|
||||
logger.info(msg1)
|
||||
|
||||
(topic, msg2) = sub.recv_multipart()
|
||||
self.assertEqual(topic, b'zmq.INFO')
|
||||
self.assertEqual(msg2, b(msg1)+b'\n')
|
||||
logger.removeHandler(handler)
|
||||
|
||||
def test_root_topic(self):
|
||||
logger, handler, sub = self.connect_handler()
|
||||
handler.socket.bind(self.iface)
|
||||
sub2 = sub.context.socket(zmq.SUB)
|
||||
self.sockets.append(sub2)
|
||||
sub2.connect(self.iface)
|
||||
sub2.setsockopt(zmq.SUBSCRIBE, b'')
|
||||
handler.root_topic = b'twoonly'
|
||||
msg1 = 'ignored'
|
||||
logger.info(msg1)
|
||||
self.assertRaisesErrno(zmq.EAGAIN, sub.recv, zmq.NOBLOCK)
|
||||
topic,msg2 = sub2.recv_multipart()
|
||||
self.assertEqual(topic, b'twoonly.INFO')
|
||||
self.assertEqual(msg2, b(msg1)+b'\n')
|
||||
|
||||
logger.removeHandler(handler)
|
||||
|
||||
def test_blank_root_topic(self):
|
||||
logger, handler, sub_everything = self.connect_handler()
|
||||
sub_everything.setsockopt(zmq.SUBSCRIBE, b'')
|
||||
handler.socket.bind(self.iface)
|
||||
sub_only_info = sub_everything.context.socket(zmq.SUB)
|
||||
self.sockets.append(sub_only_info)
|
||||
sub_only_info.connect(self.iface)
|
||||
sub_only_info.setsockopt(zmq.SUBSCRIBE, b'INFO')
|
||||
handler.setRootTopic(b'')
|
||||
msg_debug = 'debug_message'
|
||||
logger.debug(msg_debug)
|
||||
self.assertRaisesErrno(zmq.EAGAIN, sub_only_info.recv, zmq.NOBLOCK)
|
||||
topic, msg_debug_response = sub_everything.recv_multipart()
|
||||
self.assertEqual(topic, b'DEBUG')
|
||||
msg_info = 'info_message'
|
||||
logger.info(msg_info)
|
||||
topic, msg_info_response_everything = sub_everything.recv_multipart()
|
||||
self.assertEqual(topic, b'INFO')
|
||||
topic, msg_info_response_onlyinfo = sub_only_info.recv_multipart()
|
||||
self.assertEqual(topic, b'INFO')
|
||||
self.assertEqual(msg_info_response_everything, msg_info_response_onlyinfo)
|
||||
|
||||
logger.removeHandler(handler)
|
||||
|
||||
def test_unicode_message(self):
|
||||
logger, handler, sub = self.connect_handler()
|
||||
base_topic = b(self.topic + '.INFO')
|
||||
for msg, expected in [
|
||||
(u('hello'), [base_topic, b('hello\n')]),
|
||||
(u('héllo'), [base_topic, b('héllo\n')]),
|
||||
(u('tøpic::héllo'), [base_topic + b('.tøpic'), b('héllo\n')]),
|
||||
]:
|
||||
logger.info(msg)
|
||||
received = sub.recv_multipart()
|
||||
self.assertEqual(received, expected)
|
||||
logger.removeHandler(handler)
|
||||
|
||||
def test_set_info_formatter_via_property(self):
|
||||
logger, handler, sub = self.connect_handler()
|
||||
handler.formatters[logging.INFO] = logging.Formatter("%(message)s UNITTEST\n")
|
||||
handler.socket.bind(self.iface)
|
||||
sub.setsockopt(zmq.SUBSCRIBE, b(handler.root_topic))
|
||||
logger.info('info message')
|
||||
topic, msg = sub.recv_multipart()
|
||||
self.assertEqual(msg, b'info message UNITTEST\n')
|
||||
logger.removeHandler(handler)
|
||||
|
||||
def test_custom_global_formatter(self):
|
||||
logger, handler, sub = self.connect_handler()
|
||||
formatter = logging.Formatter("UNITTEST %(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
handler.socket.bind(self.iface)
|
||||
sub.setsockopt(zmq.SUBSCRIBE, b(handler.root_topic))
|
||||
logger.info('info message')
|
||||
topic, msg = sub.recv_multipart()
|
||||
self.assertEqual(msg, b'UNITTEST info message')
|
||||
logger.debug('debug message')
|
||||
topic, msg = sub.recv_multipart()
|
||||
self.assertEqual(msg, b'UNITTEST debug message')
|
||||
logger.removeHandler(handler)
|
||||
|
||||
def test_custom_debug_formatter(self):
|
||||
logger, handler, sub = self.connect_handler()
|
||||
formatter = logging.Formatter("UNITTEST DEBUG %(message)s")
|
||||
handler.setFormatter(formatter, logging.DEBUG)
|
||||
handler.socket.bind(self.iface)
|
||||
sub.setsockopt(zmq.SUBSCRIBE, b(handler.root_topic))
|
||||
logger.info('info message')
|
||||
topic, msg = sub.recv_multipart()
|
||||
self.assertEqual(msg, b'info message\n')
|
||||
logger.debug('debug message')
|
||||
topic, msg = sub.recv_multipart()
|
||||
self.assertEqual(msg, b'UNITTEST DEBUG debug message')
|
||||
logger.removeHandler(handler)
|
348
venv/Lib/site-packages/zmq/tests/test_message.py
Normal file
348
venv/Lib/site-packages/zmq/tests/test_message.py
Normal file
|
@ -0,0 +1,348 @@
|
|||
# -*- coding: utf8 -*-
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
import copy
|
||||
import sys
|
||||
try:
|
||||
from sys import getrefcount as grc
|
||||
except ImportError:
|
||||
grc = None
|
||||
|
||||
import time
|
||||
from pprint import pprint
|
||||
from unittest import TestCase
|
||||
|
||||
import zmq
|
||||
from zmq.tests import BaseZMQTestCase, SkipTest, skip_pypy, PYPY
|
||||
from zmq.utils.strtypes import unicode, bytes, b, u
|
||||
|
||||
|
||||
# some useful constants:
|
||||
|
||||
x = b'x'
|
||||
|
||||
if grc:
|
||||
rc0 = grc(x)
|
||||
v = memoryview(x)
|
||||
view_rc = grc(x) - rc0
|
||||
|
||||
def await_gc(obj, rc):
|
||||
"""wait for refcount on an object to drop to an expected value
|
||||
|
||||
Necessary because of the zero-copy gc thread,
|
||||
which can take some time to receive its DECREF message.
|
||||
"""
|
||||
for i in range(50):
|
||||
# rc + 2 because of the refs in this function
|
||||
if grc(obj) <= rc + 2:
|
||||
return
|
||||
time.sleep(0.05)
|
||||
|
||||
class TestFrame(BaseZMQTestCase):
|
||||
|
||||
@skip_pypy
|
||||
def test_above_30(self):
|
||||
"""Message above 30 bytes are never copied by 0MQ."""
|
||||
for i in range(5, 16): # 32, 64,..., 65536
|
||||
s = (2**i)*x
|
||||
self.assertEqual(grc(s), 2)
|
||||
m = zmq.Frame(s, copy=False)
|
||||
self.assertEqual(grc(s), 4)
|
||||
del m
|
||||
await_gc(s, 2)
|
||||
self.assertEqual(grc(s), 2)
|
||||
del s
|
||||
|
||||
def test_str(self):
|
||||
"""Test the str representations of the Frames."""
|
||||
for i in range(16):
|
||||
s = (2**i)*x
|
||||
m = zmq.Frame(s)
|
||||
m_str = str(m)
|
||||
m_str_b = b(m_str) # py3compat
|
||||
self.assertEqual(s, m_str_b)
|
||||
|
||||
def test_bytes(self):
|
||||
"""Test the Frame.bytes property."""
|
||||
for i in range(1,16):
|
||||
s = (2**i)*x
|
||||
m = zmq.Frame(s)
|
||||
b = m.bytes
|
||||
self.assertEqual(s, m.bytes)
|
||||
if not PYPY:
|
||||
# check that it copies
|
||||
self.assert_(b is not s)
|
||||
# check that it copies only once
|
||||
self.assert_(b is m.bytes)
|
||||
|
||||
def test_unicode(self):
|
||||
"""Test the unicode representations of the Frames."""
|
||||
s = u('asdf')
|
||||
self.assertRaises(TypeError, zmq.Frame, s)
|
||||
for i in range(16):
|
||||
s = (2**i)*u('§')
|
||||
m = zmq.Frame(s.encode('utf8'))
|
||||
self.assertEqual(s, unicode(m.bytes,'utf8'))
|
||||
|
||||
def test_len(self):
|
||||
"""Test the len of the Frames."""
|
||||
for i in range(16):
|
||||
s = (2**i)*x
|
||||
m = zmq.Frame(s)
|
||||
self.assertEqual(len(s), len(m))
|
||||
|
||||
@skip_pypy
|
||||
def test_lifecycle1(self):
|
||||
"""Run through a ref counting cycle with a copy."""
|
||||
for i in range(5, 16): # 32, 64,..., 65536
|
||||
s = (2**i)*x
|
||||
rc = 2
|
||||
self.assertEqual(grc(s), rc)
|
||||
m = zmq.Frame(s, copy=False)
|
||||
rc += 2
|
||||
self.assertEqual(grc(s), rc)
|
||||
m2 = copy.copy(m)
|
||||
rc += 1
|
||||
self.assertEqual(grc(s), rc)
|
||||
buf = m2.buffer
|
||||
|
||||
rc += view_rc
|
||||
self.assertEqual(grc(s), rc)
|
||||
|
||||
self.assertEqual(s, b(str(m)))
|
||||
self.assertEqual(s, bytes(m2))
|
||||
self.assertEqual(s, m.bytes)
|
||||
# self.assert_(s is str(m))
|
||||
# self.assert_(s is str(m2))
|
||||
del m2
|
||||
rc -= 1
|
||||
self.assertEqual(grc(s), rc)
|
||||
rc -= view_rc
|
||||
del buf
|
||||
self.assertEqual(grc(s), rc)
|
||||
del m
|
||||
rc -= 2
|
||||
await_gc(s, rc)
|
||||
self.assertEqual(grc(s), rc)
|
||||
self.assertEqual(rc, 2)
|
||||
del s
|
||||
|
||||
@skip_pypy
|
||||
def test_lifecycle2(self):
|
||||
"""Run through a different ref counting cycle with a copy."""
|
||||
for i in range(5, 16): # 32, 64,..., 65536
|
||||
s = (2**i)*x
|
||||
rc = 2
|
||||
self.assertEqual(grc(s), rc)
|
||||
m = zmq.Frame(s, copy=False)
|
||||
rc += 2
|
||||
self.assertEqual(grc(s), rc)
|
||||
m2 = copy.copy(m)
|
||||
rc += 1
|
||||
self.assertEqual(grc(s), rc)
|
||||
buf = m.buffer
|
||||
rc += view_rc
|
||||
self.assertEqual(grc(s), rc)
|
||||
self.assertEqual(s, b(str(m)))
|
||||
self.assertEqual(s, bytes(m2))
|
||||
self.assertEqual(s, m2.bytes)
|
||||
self.assertEqual(s, m.bytes)
|
||||
# self.assert_(s is str(m))
|
||||
# self.assert_(s is str(m2))
|
||||
del buf
|
||||
self.assertEqual(grc(s), rc)
|
||||
del m
|
||||
# m.buffer is kept until m is del'd
|
||||
rc -= view_rc
|
||||
rc -= 1
|
||||
self.assertEqual(grc(s), rc)
|
||||
del m2
|
||||
rc -= 2
|
||||
await_gc(s, rc)
|
||||
self.assertEqual(grc(s), rc)
|
||||
self.assertEqual(rc, 2)
|
||||
del s
|
||||
|
||||
@skip_pypy
|
||||
def test_tracker(self):
|
||||
m = zmq.Frame(b'asdf', copy=False, track=True)
|
||||
self.assertFalse(m.tracker.done)
|
||||
pm = zmq.MessageTracker(m)
|
||||
self.assertFalse(pm.done)
|
||||
del m
|
||||
for i in range(10):
|
||||
if pm.done:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
self.assertTrue(pm.done)
|
||||
|
||||
def test_no_tracker(self):
|
||||
m = zmq.Frame(b'asdf', track=False)
|
||||
self.assertEqual(m.tracker, None)
|
||||
m2 = copy.copy(m)
|
||||
self.assertEqual(m2.tracker, None)
|
||||
self.assertRaises(ValueError, zmq.MessageTracker, m)
|
||||
|
||||
@skip_pypy
|
||||
def test_multi_tracker(self):
|
||||
m = zmq.Frame(b'asdf', copy=False, track=True)
|
||||
m2 = zmq.Frame(b'whoda', copy=False, track=True)
|
||||
mt = zmq.MessageTracker(m,m2)
|
||||
self.assertFalse(m.tracker.done)
|
||||
self.assertFalse(mt.done)
|
||||
self.assertRaises(zmq.NotDone, mt.wait, 0.1)
|
||||
del m
|
||||
time.sleep(0.1)
|
||||
self.assertRaises(zmq.NotDone, mt.wait, 0.1)
|
||||
self.assertFalse(mt.done)
|
||||
del m2
|
||||
self.assertTrue(mt.wait() is None)
|
||||
self.assertTrue(mt.done)
|
||||
|
||||
|
||||
def test_buffer_in(self):
|
||||
"""test using a buffer as input"""
|
||||
ins = b("§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√")
|
||||
m = zmq.Frame(memoryview(ins))
|
||||
|
||||
def test_bad_buffer_in(self):
|
||||
"""test using a bad object"""
|
||||
self.assertRaises(TypeError, zmq.Frame, 5)
|
||||
self.assertRaises(TypeError, zmq.Frame, object())
|
||||
|
||||
def test_buffer_out(self):
|
||||
"""receiving buffered output"""
|
||||
ins = b("§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√")
|
||||
m = zmq.Frame(ins)
|
||||
outb = m.buffer
|
||||
self.assertTrue(isinstance(outb, memoryview))
|
||||
self.assert_(outb is m.buffer)
|
||||
self.assert_(m.buffer is m.buffer)
|
||||
|
||||
@skip_pypy
|
||||
def test_memoryview_shape(self):
|
||||
"""memoryview shape info"""
|
||||
if sys.version_info < (3,):
|
||||
raise SkipTest("only test memoryviews on Python 3")
|
||||
data = b("§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√")
|
||||
n = len(data)
|
||||
f = zmq.Frame(data)
|
||||
view1 = f.buffer
|
||||
self.assertEqual(view1.ndim, 1)
|
||||
self.assertEqual(view1.shape, (n,))
|
||||
self.assertEqual(view1.tobytes(), data)
|
||||
view2 = memoryview(f)
|
||||
self.assertEqual(view2.ndim, 1)
|
||||
self.assertEqual(view2.shape, (n,))
|
||||
self.assertEqual(view2.tobytes(), data)
|
||||
|
||||
def test_multisend(self):
|
||||
"""ensure that a message remains intact after multiple sends"""
|
||||
a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
s = b"message"
|
||||
m = zmq.Frame(s)
|
||||
self.assertEqual(s, m.bytes)
|
||||
|
||||
a.send(m, copy=False)
|
||||
time.sleep(0.1)
|
||||
self.assertEqual(s, m.bytes)
|
||||
a.send(m, copy=False)
|
||||
time.sleep(0.1)
|
||||
self.assertEqual(s, m.bytes)
|
||||
a.send(m, copy=True)
|
||||
time.sleep(0.1)
|
||||
self.assertEqual(s, m.bytes)
|
||||
a.send(m, copy=True)
|
||||
time.sleep(0.1)
|
||||
self.assertEqual(s, m.bytes)
|
||||
for i in range(4):
|
||||
r = b.recv()
|
||||
self.assertEqual(s,r)
|
||||
self.assertEqual(s, m.bytes)
|
||||
|
||||
def test_memoryview(self):
|
||||
"""test messages from memoryview"""
|
||||
major,minor = sys.version_info[:2]
|
||||
if not (major >= 3 or (major == 2 and minor >= 7)):
|
||||
raise SkipTest("memoryviews only in python >= 2.7")
|
||||
|
||||
s = b'carrotjuice'
|
||||
v = memoryview(s)
|
||||
m = zmq.Frame(s)
|
||||
buf = m.buffer
|
||||
s2 = buf.tobytes()
|
||||
self.assertEqual(s2,s)
|
||||
self.assertEqual(m.bytes,s)
|
||||
|
||||
def test_noncopying_recv(self):
|
||||
"""check for clobbering message buffers"""
|
||||
null = b'\0'*64
|
||||
sa,sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
for i in range(32):
|
||||
# try a few times
|
||||
sb.send(null, copy=False)
|
||||
m = sa.recv(copy=False)
|
||||
mb = m.bytes
|
||||
# buf = memoryview(m)
|
||||
buf = m.buffer
|
||||
del m
|
||||
for i in range(5):
|
||||
ff=b'\xff'*(40 + i*10)
|
||||
sb.send(ff, copy=False)
|
||||
m2 = sa.recv(copy=False)
|
||||
b = buf.tobytes()
|
||||
self.assertEqual(b, null)
|
||||
self.assertEqual(mb, null)
|
||||
self.assertEqual(m2.bytes, ff)
|
||||
|
||||
@skip_pypy
|
||||
def test_buffer_numpy(self):
|
||||
"""test non-copying numpy array messages"""
|
||||
try:
|
||||
import numpy
|
||||
from numpy.testing import assert_array_equal
|
||||
except ImportError:
|
||||
raise SkipTest("requires numpy")
|
||||
if sys.version_info < (2,7):
|
||||
raise SkipTest("requires new-style buffer interface (py >= 2.7)")
|
||||
rand = numpy.random.randint
|
||||
shapes = [ rand(2,5) for i in range(5) ]
|
||||
a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
dtypes = [int, float, '>i4', 'B']
|
||||
for i in range(1,len(shapes)+1):
|
||||
shape = shapes[:i]
|
||||
for dt in dtypes:
|
||||
A = numpy.empty(shape, dtype=dt)
|
||||
a.send(A, copy=False)
|
||||
msg = b.recv(copy=False)
|
||||
|
||||
B = numpy.frombuffer(msg, A.dtype).reshape(A.shape)
|
||||
assert_array_equal(A, B)
|
||||
|
||||
A = numpy.empty(shape, dtype=[('a', int), ('b', float), ('c', 'a32')])
|
||||
A['a'] = 1024
|
||||
A['b'] = 1e9
|
||||
A['c'] = 'hello there'
|
||||
a.send(A, copy=False)
|
||||
msg = b.recv(copy=False)
|
||||
|
||||
B = numpy.frombuffer(msg, A.dtype).reshape(A.shape)
|
||||
assert_array_equal(A, B)
|
||||
|
||||
def test_frame_more(self):
|
||||
"""test Frame.more attribute"""
|
||||
frame = zmq.Frame(b"hello")
|
||||
self.assertFalse(frame.more)
|
||||
sa,sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
sa.send_multipart([b'hi', b'there'])
|
||||
frame = self.recv(sb, copy=False)
|
||||
self.assertTrue(frame.more)
|
||||
if zmq.zmq_version_info()[0] >= 3 and not PYPY:
|
||||
self.assertTrue(frame.get(zmq.MORE))
|
||||
frame = self.recv(sb, copy=False)
|
||||
self.assertFalse(frame.more)
|
||||
if zmq.zmq_version_info()[0] >= 3 and not PYPY:
|
||||
self.assertFalse(frame.get(zmq.MORE))
|
||||
|
83
venv/Lib/site-packages/zmq/tests/test_monitor.py
Normal file
83
venv/Lib/site-packages/zmq/tests/test_monitor.py
Normal file
|
@ -0,0 +1,83 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
import errno
|
||||
import sys
|
||||
import time
|
||||
import struct
|
||||
|
||||
from unittest import TestCase
|
||||
from pytest import mark
|
||||
import zmq
|
||||
from zmq.tests import BaseZMQTestCase, skip_pypy, require_zmq_4
|
||||
from zmq.utils.monitor import recv_monitor_message
|
||||
|
||||
|
||||
class TestSocketMonitor(BaseZMQTestCase):
|
||||
|
||||
@require_zmq_4
|
||||
def test_monitor(self):
|
||||
"""Test monitoring interface for sockets."""
|
||||
s_rep = self.context.socket(zmq.REP)
|
||||
s_req = self.context.socket(zmq.REQ)
|
||||
self.sockets.extend([s_rep, s_req])
|
||||
s_req.bind("tcp://127.0.0.1:6666")
|
||||
# try monitoring the REP socket
|
||||
|
||||
s_rep.monitor("inproc://monitor.rep", zmq.EVENT_CONNECT_DELAYED | zmq.EVENT_CONNECTED | zmq.EVENT_MONITOR_STOPPED)
|
||||
# create listening socket for monitor
|
||||
s_event = self.context.socket(zmq.PAIR)
|
||||
self.sockets.append(s_event)
|
||||
s_event.connect("inproc://monitor.rep")
|
||||
s_event.linger = 0
|
||||
# test receive event for connect event
|
||||
s_rep.connect("tcp://127.0.0.1:6666")
|
||||
m = recv_monitor_message(s_event)
|
||||
if m['event'] == zmq.EVENT_CONNECT_DELAYED:
|
||||
self.assertEqual(m['endpoint'], b"tcp://127.0.0.1:6666")
|
||||
# test receive event for connected event
|
||||
m = recv_monitor_message(s_event)
|
||||
self.assertEqual(m['event'], zmq.EVENT_CONNECTED)
|
||||
self.assertEqual(m['endpoint'], b"tcp://127.0.0.1:6666")
|
||||
|
||||
# test monitor can be disabled.
|
||||
s_rep.disable_monitor()
|
||||
m = recv_monitor_message(s_event)
|
||||
self.assertEqual(m['event'], zmq.EVENT_MONITOR_STOPPED)
|
||||
|
||||
@require_zmq_4
|
||||
def test_monitor_repeat(self):
|
||||
s = self.socket(zmq.PULL)
|
||||
m = s.get_monitor_socket()
|
||||
self.sockets.append(m)
|
||||
m2 = s.get_monitor_socket()
|
||||
assert m is m2
|
||||
s.disable_monitor()
|
||||
evt = recv_monitor_message(m)
|
||||
self.assertEqual(evt['event'], zmq.EVENT_MONITOR_STOPPED)
|
||||
m.close()
|
||||
s.close()
|
||||
|
||||
@require_zmq_4
|
||||
def test_monitor_connected(self):
|
||||
"""Test connected monitoring socket."""
|
||||
s_rep = self.context.socket(zmq.REP)
|
||||
s_req = self.context.socket(zmq.REQ)
|
||||
self.sockets.extend([s_rep, s_req])
|
||||
s_req.bind("tcp://127.0.0.1:6667")
|
||||
# try monitoring the REP socket
|
||||
# create listening socket for monitor
|
||||
s_event = s_rep.get_monitor_socket()
|
||||
s_event.linger = 0
|
||||
self.sockets.append(s_event)
|
||||
# test receive event for connect event
|
||||
s_rep.connect("tcp://127.0.0.1:6667")
|
||||
m = recv_monitor_message(s_event)
|
||||
if m['event'] == zmq.EVENT_CONNECT_DELAYED:
|
||||
self.assertEqual(m['endpoint'], b"tcp://127.0.0.1:6667")
|
||||
# test receive event for connected event
|
||||
m = recv_monitor_message(s_event)
|
||||
self.assertEqual(m['event'], zmq.EVENT_CONNECTED)
|
||||
self.assertEqual(m['endpoint'], b"tcp://127.0.0.1:6667")
|
221
venv/Lib/site-packages/zmq/tests/test_monqueue.py
Normal file
221
venv/Lib/site-packages/zmq/tests/test_monqueue.py
Normal file
|
@ -0,0 +1,221 @@
|
|||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import time
|
||||
from unittest import TestCase
|
||||
|
||||
import zmq
|
||||
from zmq import devices
|
||||
|
||||
from zmq.tests import BaseZMQTestCase, SkipTest, PYPY
|
||||
from zmq.utils.strtypes import unicode
|
||||
|
||||
|
||||
if PYPY or zmq.zmq_version_info() >= (4,1):
|
||||
# cleanup of shared Context doesn't work on PyPy
|
||||
# there also seems to be a bug in cleanup in libzmq-4.1 (zeromq/libzmq#1052)
|
||||
devices.Device.context_factory = zmq.Context
|
||||
|
||||
|
||||
class TestMonitoredQueue(BaseZMQTestCase):
|
||||
|
||||
sockets = []
|
||||
|
||||
def build_device(self, mon_sub=b"", in_prefix=b'in', out_prefix=b'out'):
|
||||
self.device = devices.ThreadMonitoredQueue(zmq.PAIR, zmq.PAIR, zmq.PUB,
|
||||
in_prefix, out_prefix)
|
||||
alice = self.context.socket(zmq.PAIR)
|
||||
bob = self.context.socket(zmq.PAIR)
|
||||
mon = self.context.socket(zmq.SUB)
|
||||
|
||||
aport = alice.bind_to_random_port('tcp://127.0.0.1')
|
||||
bport = bob.bind_to_random_port('tcp://127.0.0.1')
|
||||
mport = mon.bind_to_random_port('tcp://127.0.0.1')
|
||||
mon.setsockopt(zmq.SUBSCRIBE, mon_sub)
|
||||
|
||||
self.device.connect_in("tcp://127.0.0.1:%i"%aport)
|
||||
self.device.connect_out("tcp://127.0.0.1:%i"%bport)
|
||||
self.device.connect_mon("tcp://127.0.0.1:%i"%mport)
|
||||
self.device.start()
|
||||
time.sleep(.2)
|
||||
try:
|
||||
# this is currenlty necessary to ensure no dropped monitor messages
|
||||
# see LIBZMQ-248 for more info
|
||||
mon.recv_multipart(zmq.NOBLOCK)
|
||||
except zmq.ZMQError:
|
||||
pass
|
||||
self.sockets.extend([alice, bob, mon])
|
||||
return alice, bob, mon
|
||||
|
||||
def teardown_device(self):
|
||||
for socket in self.sockets:
|
||||
socket.close()
|
||||
del socket
|
||||
del self.device
|
||||
|
||||
def test_reply(self):
|
||||
alice, bob, mon = self.build_device()
|
||||
alices = b"hello bob".split()
|
||||
alice.send_multipart(alices)
|
||||
bobs = self.recv_multipart(bob)
|
||||
self.assertEqual(alices, bobs)
|
||||
bobs = b"hello alice".split()
|
||||
bob.send_multipart(bobs)
|
||||
alices = self.recv_multipart(alice)
|
||||
self.assertEqual(alices, bobs)
|
||||
self.teardown_device()
|
||||
|
||||
def test_queue(self):
|
||||
alice, bob, mon = self.build_device()
|
||||
alices = b"hello bob".split()
|
||||
alice.send_multipart(alices)
|
||||
alices2 = b"hello again".split()
|
||||
alice.send_multipart(alices2)
|
||||
alices3 = b"hello again and again".split()
|
||||
alice.send_multipart(alices3)
|
||||
bobs = self.recv_multipart(bob)
|
||||
self.assertEqual(alices, bobs)
|
||||
bobs = self.recv_multipart(bob)
|
||||
self.assertEqual(alices2, bobs)
|
||||
bobs = self.recv_multipart(bob)
|
||||
self.assertEqual(alices3, bobs)
|
||||
bobs = b"hello alice".split()
|
||||
bob.send_multipart(bobs)
|
||||
alices = self.recv_multipart(alice)
|
||||
self.assertEqual(alices, bobs)
|
||||
self.teardown_device()
|
||||
|
||||
def test_monitor(self):
|
||||
alice, bob, mon = self.build_device()
|
||||
alices = b"hello bob".split()
|
||||
alice.send_multipart(alices)
|
||||
alices2 = b"hello again".split()
|
||||
alice.send_multipart(alices2)
|
||||
alices3 = b"hello again and again".split()
|
||||
alice.send_multipart(alices3)
|
||||
bobs = self.recv_multipart(bob)
|
||||
self.assertEqual(alices, bobs)
|
||||
mons = self.recv_multipart(mon)
|
||||
self.assertEqual([b'in']+bobs, mons)
|
||||
bobs = self.recv_multipart(bob)
|
||||
self.assertEqual(alices2, bobs)
|
||||
bobs = self.recv_multipart(bob)
|
||||
self.assertEqual(alices3, bobs)
|
||||
mons = self.recv_multipart(mon)
|
||||
self.assertEqual([b'in']+alices2, mons)
|
||||
bobs = b"hello alice".split()
|
||||
bob.send_multipart(bobs)
|
||||
alices = self.recv_multipart(alice)
|
||||
self.assertEqual(alices, bobs)
|
||||
mons = self.recv_multipart(mon)
|
||||
self.assertEqual([b'in']+alices3, mons)
|
||||
mons = self.recv_multipart(mon)
|
||||
self.assertEqual([b'out']+bobs, mons)
|
||||
self.teardown_device()
|
||||
|
||||
def test_prefix(self):
|
||||
alice, bob, mon = self.build_device(b"", b'foo', b'bar')
|
||||
alices = b"hello bob".split()
|
||||
alice.send_multipart(alices)
|
||||
alices2 = b"hello again".split()
|
||||
alice.send_multipart(alices2)
|
||||
alices3 = b"hello again and again".split()
|
||||
alice.send_multipart(alices3)
|
||||
bobs = self.recv_multipart(bob)
|
||||
self.assertEqual(alices, bobs)
|
||||
mons = self.recv_multipart(mon)
|
||||
self.assertEqual([b'foo']+bobs, mons)
|
||||
bobs = self.recv_multipart(bob)
|
||||
self.assertEqual(alices2, bobs)
|
||||
bobs = self.recv_multipart(bob)
|
||||
self.assertEqual(alices3, bobs)
|
||||
mons = self.recv_multipart(mon)
|
||||
self.assertEqual([b'foo']+alices2, mons)
|
||||
bobs = b"hello alice".split()
|
||||
bob.send_multipart(bobs)
|
||||
alices = self.recv_multipart(alice)
|
||||
self.assertEqual(alices, bobs)
|
||||
mons = self.recv_multipart(mon)
|
||||
self.assertEqual([b'foo']+alices3, mons)
|
||||
mons = self.recv_multipart(mon)
|
||||
self.assertEqual([b'bar']+bobs, mons)
|
||||
self.teardown_device()
|
||||
|
||||
def test_monitor_subscribe(self):
|
||||
alice, bob, mon = self.build_device(b"out")
|
||||
alices = b"hello bob".split()
|
||||
alice.send_multipart(alices)
|
||||
alices2 = b"hello again".split()
|
||||
alice.send_multipart(alices2)
|
||||
alices3 = b"hello again and again".split()
|
||||
alice.send_multipart(alices3)
|
||||
bobs = self.recv_multipart(bob)
|
||||
self.assertEqual(alices, bobs)
|
||||
bobs = self.recv_multipart(bob)
|
||||
self.assertEqual(alices2, bobs)
|
||||
bobs = self.recv_multipart(bob)
|
||||
self.assertEqual(alices3, bobs)
|
||||
bobs = b"hello alice".split()
|
||||
bob.send_multipart(bobs)
|
||||
alices = self.recv_multipart(alice)
|
||||
self.assertEqual(alices, bobs)
|
||||
mons = self.recv_multipart(mon)
|
||||
self.assertEqual([b'out']+bobs, mons)
|
||||
self.teardown_device()
|
||||
|
||||
def test_router_router(self):
|
||||
"""test router-router MQ devices"""
|
||||
dev = devices.ThreadMonitoredQueue(zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'in', b'out')
|
||||
self.device = dev
|
||||
dev.setsockopt_in(zmq.LINGER, 0)
|
||||
dev.setsockopt_out(zmq.LINGER, 0)
|
||||
dev.setsockopt_mon(zmq.LINGER, 0)
|
||||
|
||||
porta = dev.bind_in_to_random_port('tcp://127.0.0.1')
|
||||
portb = dev.bind_out_to_random_port('tcp://127.0.0.1')
|
||||
a = self.context.socket(zmq.DEALER)
|
||||
a.identity = b'a'
|
||||
b = self.context.socket(zmq.DEALER)
|
||||
b.identity = b'b'
|
||||
self.sockets.extend([a, b])
|
||||
|
||||
a.connect('tcp://127.0.0.1:%i'%porta)
|
||||
b.connect('tcp://127.0.0.1:%i'%portb)
|
||||
dev.start()
|
||||
time.sleep(1)
|
||||
if zmq.zmq_version_info() >= (3,1,0):
|
||||
# flush erroneous poll state, due to LIBZMQ-280
|
||||
ping_msg = [ b'ping', b'pong' ]
|
||||
for s in (a,b):
|
||||
s.send_multipart(ping_msg)
|
||||
try:
|
||||
s.recv(zmq.NOBLOCK)
|
||||
except zmq.ZMQError:
|
||||
pass
|
||||
msg = [ b'hello', b'there' ]
|
||||
a.send_multipart([b'b']+msg)
|
||||
bmsg = self.recv_multipart(b)
|
||||
self.assertEqual(bmsg, [b'a']+msg)
|
||||
b.send_multipart(bmsg)
|
||||
amsg = self.recv_multipart(a)
|
||||
self.assertEqual(amsg, [b'b']+msg)
|
||||
self.teardown_device()
|
||||
|
||||
def test_default_mq_args(self):
|
||||
self.device = dev = devices.ThreadMonitoredQueue(zmq.ROUTER, zmq.DEALER, zmq.PUB)
|
||||
dev.setsockopt_in(zmq.LINGER, 0)
|
||||
dev.setsockopt_out(zmq.LINGER, 0)
|
||||
dev.setsockopt_mon(zmq.LINGER, 0)
|
||||
# this will raise if default args are wrong
|
||||
dev.start()
|
||||
self.teardown_device()
|
||||
|
||||
def test_mq_check_prefix(self):
|
||||
ins = self.context.socket(zmq.ROUTER)
|
||||
outs = self.context.socket(zmq.DEALER)
|
||||
mons = self.context.socket(zmq.PUB)
|
||||
self.sockets.extend([ins, outs, mons])
|
||||
|
||||
ins = unicode('in')
|
||||
outs = unicode('out')
|
||||
self.assertRaises(TypeError, devices.monitoredqueue, ins, outs, mons)
|
35
venv/Lib/site-packages/zmq/tests/test_multipart.py
Normal file
35
venv/Lib/site-packages/zmq/tests/test_multipart.py
Normal file
|
@ -0,0 +1,35 @@
|
|||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
import zmq
|
||||
|
||||
|
||||
from zmq.tests import BaseZMQTestCase, SkipTest, have_gevent, GreenTest
|
||||
|
||||
|
||||
class TestMultipart(BaseZMQTestCase):
|
||||
|
||||
def test_router_dealer(self):
|
||||
router, dealer = self.create_bound_pair(zmq.ROUTER, zmq.DEALER)
|
||||
|
||||
msg1 = b'message1'
|
||||
dealer.send(msg1)
|
||||
ident = self.recv(router)
|
||||
more = router.rcvmore
|
||||
self.assertEqual(more, True)
|
||||
msg2 = self.recv(router)
|
||||
self.assertEqual(msg1, msg2)
|
||||
more = router.rcvmore
|
||||
self.assertEqual(more, False)
|
||||
|
||||
def test_basic_multipart(self):
|
||||
a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
msg = [ b'hi', b'there', b'b']
|
||||
a.send_multipart(msg)
|
||||
recvd = b.recv_multipart()
|
||||
self.assertEqual(msg, recvd)
|
||||
|
||||
if have_gevent:
|
||||
class TestMultipartGreen(GreenTest, TestMultipart):
|
||||
pass
|
53
venv/Lib/site-packages/zmq/tests/test_pair.py
Normal file
53
venv/Lib/site-packages/zmq/tests/test_pair.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
import zmq
|
||||
|
||||
|
||||
from zmq.tests import BaseZMQTestCase, have_gevent, GreenTest
|
||||
|
||||
|
||||
x = b' '
|
||||
class TestPair(BaseZMQTestCase):
|
||||
|
||||
def test_basic(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
|
||||
msg1 = b'message1'
|
||||
msg2 = self.ping_pong(s1, s2, msg1)
|
||||
self.assertEqual(msg1, msg2)
|
||||
|
||||
def test_multiple(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
|
||||
for i in range(10):
|
||||
msg = i*x
|
||||
s1.send(msg)
|
||||
|
||||
for i in range(10):
|
||||
msg = i*x
|
||||
s2.send(msg)
|
||||
|
||||
for i in range(10):
|
||||
msg = s1.recv()
|
||||
self.assertEqual(msg, i*x)
|
||||
|
||||
for i in range(10):
|
||||
msg = s2.recv()
|
||||
self.assertEqual(msg, i*x)
|
||||
|
||||
def test_json(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
o = dict(a=10,b=list(range(10)))
|
||||
o2 = self.ping_pong_json(s1, s2, o)
|
||||
|
||||
def test_pyobj(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
o = dict(a=10,b=range(10))
|
||||
o2 = self.ping_pong_pyobj(s1, s2, o)
|
||||
|
||||
if have_gevent:
|
||||
class TestReqRepGreen(GreenTest, TestPair):
|
||||
pass
|
||||
|
238
venv/Lib/site-packages/zmq/tests/test_poll.py
Normal file
238
venv/Lib/site-packages/zmq/tests/test_poll.py
Normal file
|
@ -0,0 +1,238 @@
|
|||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
from pytest import mark
|
||||
|
||||
import zmq
|
||||
|
||||
from zmq.tests import PollZMQTestCase, have_gevent, GreenTest
|
||||
|
||||
def wait():
|
||||
time.sleep(.25)
|
||||
|
||||
|
||||
class TestPoll(PollZMQTestCase):
|
||||
|
||||
Poller = zmq.Poller
|
||||
|
||||
def test_pair(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
|
||||
# Sleep to allow sockets to connect.
|
||||
wait()
|
||||
|
||||
poller = self.Poller()
|
||||
poller.register(s1, zmq.POLLIN|zmq.POLLOUT)
|
||||
poller.register(s2, zmq.POLLIN|zmq.POLLOUT)
|
||||
# Poll result should contain both sockets
|
||||
socks = dict(poller.poll())
|
||||
# Now make sure that both are send ready.
|
||||
self.assertEqual(socks[s1], zmq.POLLOUT)
|
||||
self.assertEqual(socks[s2], zmq.POLLOUT)
|
||||
# Now do a send on both, wait and test for zmq.POLLOUT|zmq.POLLIN
|
||||
s1.send(b'msg1')
|
||||
s2.send(b'msg2')
|
||||
wait()
|
||||
socks = dict(poller.poll())
|
||||
self.assertEqual(socks[s1], zmq.POLLOUT|zmq.POLLIN)
|
||||
self.assertEqual(socks[s2], zmq.POLLOUT|zmq.POLLIN)
|
||||
# Make sure that both are in POLLOUT after recv.
|
||||
s1.recv()
|
||||
s2.recv()
|
||||
socks = dict(poller.poll())
|
||||
self.assertEqual(socks[s1], zmq.POLLOUT)
|
||||
self.assertEqual(socks[s2], zmq.POLLOUT)
|
||||
|
||||
poller.unregister(s1)
|
||||
poller.unregister(s2)
|
||||
|
||||
|
||||
def test_reqrep(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.REP, zmq.REQ)
|
||||
|
||||
# Sleep to allow sockets to connect.
|
||||
wait()
|
||||
|
||||
poller = self.Poller()
|
||||
poller.register(s1, zmq.POLLIN|zmq.POLLOUT)
|
||||
poller.register(s2, zmq.POLLIN|zmq.POLLOUT)
|
||||
|
||||
# Make sure that s1 is in state 0 and s2 is in POLLOUT
|
||||
socks = dict(poller.poll())
|
||||
self.assertEqual(s1 in socks, 0)
|
||||
self.assertEqual(socks[s2], zmq.POLLOUT)
|
||||
|
||||
# Make sure that s2 goes immediately into state 0 after send.
|
||||
s2.send(b'msg1')
|
||||
socks = dict(poller.poll())
|
||||
self.assertEqual(s2 in socks, 0)
|
||||
|
||||
# Make sure that s1 goes into POLLIN state after a time.sleep().
|
||||
time.sleep(0.5)
|
||||
socks = dict(poller.poll())
|
||||
self.assertEqual(socks[s1], zmq.POLLIN)
|
||||
|
||||
# Make sure that s1 goes into POLLOUT after recv.
|
||||
s1.recv()
|
||||
socks = dict(poller.poll())
|
||||
self.assertEqual(socks[s1], zmq.POLLOUT)
|
||||
|
||||
# Make sure s1 goes into state 0 after send.
|
||||
s1.send(b'msg2')
|
||||
socks = dict(poller.poll())
|
||||
self.assertEqual(s1 in socks, 0)
|
||||
|
||||
# Wait and then see that s2 is in POLLIN.
|
||||
time.sleep(0.5)
|
||||
socks = dict(poller.poll())
|
||||
self.assertEqual(socks[s2], zmq.POLLIN)
|
||||
|
||||
# Make sure that s2 is in POLLOUT after recv.
|
||||
s2.recv()
|
||||
socks = dict(poller.poll())
|
||||
self.assertEqual(socks[s2], zmq.POLLOUT)
|
||||
|
||||
poller.unregister(s1)
|
||||
poller.unregister(s2)
|
||||
|
||||
def test_no_events(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
poller = self.Poller()
|
||||
poller.register(s1, zmq.POLLIN|zmq.POLLOUT)
|
||||
poller.register(s2, 0)
|
||||
self.assertTrue(s1 in poller)
|
||||
self.assertFalse(s2 in poller)
|
||||
poller.register(s1, 0)
|
||||
self.assertFalse(s1 in poller)
|
||||
|
||||
def test_pubsub(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PUB, zmq.SUB)
|
||||
s2.setsockopt(zmq.SUBSCRIBE, b'')
|
||||
|
||||
# Sleep to allow sockets to connect.
|
||||
wait()
|
||||
|
||||
poller = self.Poller()
|
||||
poller.register(s1, zmq.POLLIN|zmq.POLLOUT)
|
||||
poller.register(s2, zmq.POLLIN)
|
||||
|
||||
# Now make sure that both are send ready.
|
||||
socks = dict(poller.poll())
|
||||
self.assertEqual(socks[s1], zmq.POLLOUT)
|
||||
self.assertEqual(s2 in socks, 0)
|
||||
# Make sure that s1 stays in POLLOUT after a send.
|
||||
s1.send(b'msg1')
|
||||
socks = dict(poller.poll())
|
||||
self.assertEqual(socks[s1], zmq.POLLOUT)
|
||||
|
||||
# Make sure that s2 is POLLIN after waiting.
|
||||
wait()
|
||||
socks = dict(poller.poll())
|
||||
self.assertEqual(socks[s2], zmq.POLLIN)
|
||||
|
||||
# Make sure that s2 goes into 0 after recv.
|
||||
s2.recv()
|
||||
socks = dict(poller.poll())
|
||||
self.assertEqual(s2 in socks, 0)
|
||||
|
||||
poller.unregister(s1)
|
||||
poller.unregister(s2)
|
||||
|
||||
@mark.skipif(sys.platform.startswith('win'), reason='Windows')
|
||||
def test_raw(self):
|
||||
r, w = os.pipe()
|
||||
r = os.fdopen(r, 'rb')
|
||||
w = os.fdopen(w, 'wb')
|
||||
p = self.Poller()
|
||||
p.register(r, zmq.POLLIN)
|
||||
socks = dict(p.poll(1))
|
||||
assert socks == {}
|
||||
w.write(b'x')
|
||||
w.flush()
|
||||
socks = dict(p.poll(1))
|
||||
assert socks == {r.fileno(): zmq.POLLIN}
|
||||
w.close()
|
||||
r.close()
|
||||
|
||||
def test_timeout(self):
|
||||
"""make sure Poller.poll timeout has the right units (milliseconds)."""
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
poller = self.Poller()
|
||||
poller.register(s1, zmq.POLLIN)
|
||||
tic = time.time()
|
||||
evt = poller.poll(.005)
|
||||
toc = time.time()
|
||||
self.assertTrue(toc-tic < 0.1)
|
||||
tic = time.time()
|
||||
evt = poller.poll(5)
|
||||
toc = time.time()
|
||||
self.assertTrue(toc-tic < 0.1)
|
||||
self.assertTrue(toc-tic > .001)
|
||||
tic = time.time()
|
||||
evt = poller.poll(500)
|
||||
toc = time.time()
|
||||
self.assertTrue(toc-tic < 1)
|
||||
self.assertTrue(toc-tic > 0.1)
|
||||
|
||||
class TestSelect(PollZMQTestCase):
|
||||
|
||||
def test_pair(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
|
||||
# Sleep to allow sockets to connect.
|
||||
wait()
|
||||
|
||||
rlist, wlist, xlist = zmq.select([s1, s2], [s1, s2], [s1, s2])
|
||||
self.assert_(s1 in wlist)
|
||||
self.assert_(s2 in wlist)
|
||||
self.assert_(s1 not in rlist)
|
||||
self.assert_(s2 not in rlist)
|
||||
|
||||
def test_timeout(self):
|
||||
"""make sure select timeout has the right units (seconds)."""
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
tic = time.time()
|
||||
r,w,x = zmq.select([s1,s2],[],[],.005)
|
||||
toc = time.time()
|
||||
self.assertTrue(toc-tic < 1)
|
||||
self.assertTrue(toc-tic > 0.001)
|
||||
tic = time.time()
|
||||
r,w,x = zmq.select([s1,s2],[],[],.25)
|
||||
toc = time.time()
|
||||
self.assertTrue(toc-tic < 1)
|
||||
self.assertTrue(toc-tic > 0.1)
|
||||
|
||||
|
||||
if have_gevent:
|
||||
import gevent
|
||||
from zmq import green as gzmq
|
||||
|
||||
class TestPollGreen(GreenTest, TestPoll):
|
||||
Poller = gzmq.Poller
|
||||
|
||||
def test_wakeup(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
poller = self.Poller()
|
||||
poller.register(s2, zmq.POLLIN)
|
||||
|
||||
tic = time.time()
|
||||
r = gevent.spawn(lambda: poller.poll(10000))
|
||||
s = gevent.spawn(lambda: s1.send(b'msg1'))
|
||||
r.join()
|
||||
toc = time.time()
|
||||
self.assertTrue(toc-tic < 1)
|
||||
|
||||
def test_socket_poll(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
|
||||
tic = time.time()
|
||||
r = gevent.spawn(lambda: s2.poll(10000))
|
||||
s = gevent.spawn(lambda: s1.send(b'msg1'))
|
||||
r.join()
|
||||
toc = time.time()
|
||||
self.assertTrue(toc-tic < 1)
|
||||
|
109
venv/Lib/site-packages/zmq/tests/test_proxy_steerable.py
Normal file
109
venv/Lib/site-packages/zmq/tests/test_proxy_steerable.py
Normal file
|
@ -0,0 +1,109 @@
|
|||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import time
|
||||
import struct
|
||||
|
||||
import zmq
|
||||
from zmq import devices
|
||||
from zmq.tests import BaseZMQTestCase, SkipTest, PYPY
|
||||
|
||||
if PYPY:
|
||||
# cleanup of shared Context doesn't work on PyPy
|
||||
devices.Device.context_factory = zmq.Context
|
||||
|
||||
|
||||
class TestProxySteerable(BaseZMQTestCase):
|
||||
|
||||
def test_proxy_steerable(self):
|
||||
if zmq.zmq_version_info() < (4, 1):
|
||||
raise SkipTest("Steerable Proxies only in libzmq >= 4.1")
|
||||
dev = devices.ThreadProxySteerable(
|
||||
zmq.PULL,
|
||||
zmq.PUSH,
|
||||
zmq.PUSH,
|
||||
zmq.PAIR
|
||||
)
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = dev.bind_in_to_random_port(iface)
|
||||
port2 = dev.bind_out_to_random_port(iface)
|
||||
port3 = dev.bind_mon_to_random_port(iface)
|
||||
port4 = dev.bind_ctrl_to_random_port(iface)
|
||||
dev.start()
|
||||
time.sleep(0.25)
|
||||
msg = b'hello'
|
||||
push = self.context.socket(zmq.PUSH)
|
||||
push.connect("%s:%i" % (iface, port))
|
||||
pull = self.context.socket(zmq.PULL)
|
||||
pull.connect("%s:%i" % (iface, port2))
|
||||
mon = self.context.socket(zmq.PULL)
|
||||
mon.connect("%s:%i" % (iface, port3))
|
||||
ctrl = self.context.socket(zmq.PAIR)
|
||||
ctrl.connect("%s:%i" % (iface, port4))
|
||||
push.send(msg)
|
||||
self.sockets.extend([push, pull, mon, ctrl])
|
||||
self.assertEqual(msg, self.recv(pull))
|
||||
self.assertEqual(msg, self.recv(mon))
|
||||
ctrl.send(b'TERMINATE')
|
||||
dev.join()
|
||||
|
||||
def test_proxy_steerable_bind_to_random_with_args(self):
|
||||
if zmq.zmq_version_info() < (4, 1):
|
||||
raise SkipTest("Steerable Proxies only in libzmq >= 4.1")
|
||||
dev = devices.ThreadProxySteerable(
|
||||
zmq.PULL,
|
||||
zmq.PUSH,
|
||||
zmq.PUSH,
|
||||
zmq.PAIR
|
||||
)
|
||||
iface = 'tcp://127.0.0.1'
|
||||
ports = []
|
||||
min, max = 5000, 5050
|
||||
ports.extend([
|
||||
dev.bind_in_to_random_port(iface, min_port=min, max_port=max),
|
||||
dev.bind_out_to_random_port(iface, min_port=min, max_port=max),
|
||||
dev.bind_mon_to_random_port(iface, min_port=min, max_port=max),
|
||||
dev.bind_ctrl_to_random_port(iface, min_port=min, max_port=max)
|
||||
])
|
||||
for port in ports:
|
||||
if port < min or port > max:
|
||||
self.fail('Unexpected port number: %i' % port)
|
||||
|
||||
def test_proxy_steerable_statistics(self):
|
||||
if zmq.zmq_version_info() < (4, 3):
|
||||
raise SkipTest("STATISTICS only in libzmq >= 4.3")
|
||||
dev = devices.ThreadProxySteerable(
|
||||
zmq.PULL,
|
||||
zmq.PUSH,
|
||||
zmq.PUSH,
|
||||
zmq.PAIR
|
||||
)
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = dev.bind_in_to_random_port(iface)
|
||||
port2 = dev.bind_out_to_random_port(iface)
|
||||
port3 = dev.bind_mon_to_random_port(iface)
|
||||
port4 = dev.bind_ctrl_to_random_port(iface)
|
||||
dev.start()
|
||||
time.sleep(0.25)
|
||||
msg = b'hello'
|
||||
push = self.context.socket(zmq.PUSH)
|
||||
push.connect("%s:%i" % (iface, port))
|
||||
pull = self.context.socket(zmq.PULL)
|
||||
pull.connect("%s:%i" % (iface, port2))
|
||||
mon = self.context.socket(zmq.PULL)
|
||||
mon.connect("%s:%i" % (iface, port3))
|
||||
ctrl = self.context.socket(zmq.PAIR)
|
||||
ctrl.connect("%s:%i" % (iface, port4))
|
||||
push.send(msg)
|
||||
self.sockets.extend([push, pull, mon, ctrl])
|
||||
self.assertEqual(msg, self.recv(pull))
|
||||
self.assertEqual(msg, self.recv(mon))
|
||||
ctrl.send(b'STATISTICS')
|
||||
stats = self.recv_multipart(ctrl)
|
||||
stats_int = [struct.unpack("=Q", x)[0] for x in stats]
|
||||
self.assertEqual(1, stats_int[0])
|
||||
self.assertEqual(len(msg), stats_int[1])
|
||||
self.assertEqual(1, stats_int[6])
|
||||
self.assertEqual(len(msg), stats_int[7])
|
||||
ctrl.send(b'TERMINATE')
|
||||
dev.join()
|
42
venv/Lib/site-packages/zmq/tests/test_pubsub.py
Normal file
42
venv/Lib/site-packages/zmq/tests/test_pubsub.py
Normal file
|
@ -0,0 +1,42 @@
|
|||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
from random import Random
|
||||
import time
|
||||
from unittest import TestCase
|
||||
|
||||
import zmq
|
||||
|
||||
from zmq.tests import BaseZMQTestCase, have_gevent, GreenTest
|
||||
|
||||
|
||||
class TestPubSub(BaseZMQTestCase):
|
||||
|
||||
pass
|
||||
|
||||
# We are disabling this test while an issue is being resolved.
|
||||
def test_basic(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PUB, zmq.SUB)
|
||||
s2.setsockopt(zmq.SUBSCRIBE, b'')
|
||||
time.sleep(0.1)
|
||||
msg1 = b'message'
|
||||
s1.send(msg1)
|
||||
msg2 = s2.recv() # This is blocking!
|
||||
self.assertEqual(msg1, msg2)
|
||||
|
||||
def test_topic(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.PUB, zmq.SUB)
|
||||
s2.setsockopt(zmq.SUBSCRIBE, b'x')
|
||||
time.sleep(0.1)
|
||||
msg1 = b'message'
|
||||
s1.send(msg1)
|
||||
self.assertRaisesErrno(zmq.EAGAIN, s2.recv, zmq.NOBLOCK)
|
||||
msg1 = b'xmessage'
|
||||
s1.send(msg1)
|
||||
msg2 = s2.recv()
|
||||
self.assertEqual(msg1, msg2)
|
||||
|
||||
if have_gevent:
|
||||
class TestPubSubGreen(GreenTest, TestPubSub):
|
||||
pass
|
62
venv/Lib/site-packages/zmq/tests/test_reqrep.py
Normal file
62
venv/Lib/site-packages/zmq/tests/test_reqrep.py
Normal file
|
@ -0,0 +1,62 @@
|
|||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
from unittest import TestCase
|
||||
|
||||
import zmq
|
||||
from zmq.tests import BaseZMQTestCase, have_gevent, GreenTest
|
||||
|
||||
|
||||
class TestReqRep(BaseZMQTestCase):
|
||||
|
||||
def test_basic(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
|
||||
|
||||
msg1 = b'message 1'
|
||||
msg2 = self.ping_pong(s1, s2, msg1)
|
||||
self.assertEqual(msg1, msg2)
|
||||
|
||||
def test_multiple(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
|
||||
|
||||
for i in range(10):
|
||||
msg1 = i*b' '
|
||||
msg2 = self.ping_pong(s1, s2, msg1)
|
||||
self.assertEqual(msg1, msg2)
|
||||
|
||||
def test_bad_send_recv(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
|
||||
|
||||
if zmq.zmq_version() != '2.1.8':
|
||||
# this doesn't work on 2.1.8
|
||||
for copy in (True,False):
|
||||
self.assertRaisesErrno(zmq.EFSM, s1.recv, copy=copy)
|
||||
self.assertRaisesErrno(zmq.EFSM, s2.send, b'asdf', copy=copy)
|
||||
|
||||
# I have to have this or we die on an Abort trap.
|
||||
msg1 = b'asdf'
|
||||
msg2 = self.ping_pong(s1, s2, msg1)
|
||||
self.assertEqual(msg1, msg2)
|
||||
|
||||
def test_json(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
|
||||
o = dict(a=10,b=list(range(10)))
|
||||
o2 = self.ping_pong_json(s1, s2, o)
|
||||
|
||||
def test_pyobj(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
|
||||
o = dict(a=10,b=range(10))
|
||||
o2 = self.ping_pong_pyobj(s1, s2, o)
|
||||
|
||||
def test_large_msg(self):
|
||||
s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP)
|
||||
msg1 = 10000*b'X'
|
||||
|
||||
for i in range(10):
|
||||
msg2 = self.ping_pong(s1, s2, msg1)
|
||||
self.assertEqual(msg1, msg2)
|
||||
|
||||
if have_gevent:
|
||||
class TestReqRepGreen(GreenTest, TestReqRep):
|
||||
pass
|
95
venv/Lib/site-packages/zmq/tests/test_retry_eintr.py
Normal file
95
venv/Lib/site-packages/zmq/tests/test_retry_eintr.py
Normal file
|
@ -0,0 +1,95 @@
|
|||
# -*- coding: utf8 -*-
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import signal
|
||||
import time
|
||||
from threading import Thread
|
||||
|
||||
from pytest import mark
|
||||
|
||||
import zmq
|
||||
from zmq.tests import (
|
||||
BaseZMQTestCase, SkipTest, skip_pypy
|
||||
)
|
||||
from zmq.utils.strtypes import b
|
||||
|
||||
|
||||
# Partially based on EINTRBaseTest from CPython 3.5 eintr_tester
|
||||
|
||||
class TestEINTRSysCall(BaseZMQTestCase):
|
||||
""" Base class for EINTR tests. """
|
||||
|
||||
# delay for initial signal delivery
|
||||
signal_delay = 0.1
|
||||
# timeout for tests. Must be > signal_delay
|
||||
timeout = .25
|
||||
timeout_ms = int(timeout * 1e3)
|
||||
|
||||
def alarm(self, t=None):
|
||||
"""start a timer to fire only once
|
||||
|
||||
like signal.alarm, but with better resolution than integer seconds.
|
||||
"""
|
||||
if not hasattr(signal, 'setitimer'):
|
||||
raise SkipTest('EINTR tests require setitimer')
|
||||
if t is None:
|
||||
t = self.signal_delay
|
||||
self.timer_fired = False
|
||||
self.orig_handler = signal.signal(signal.SIGALRM, self.stop_timer)
|
||||
# signal_period ignored, since only one timer event is allowed to fire
|
||||
signal.setitimer(signal.ITIMER_REAL, t, 1000)
|
||||
|
||||
def stop_timer(self, *args):
|
||||
self.timer_fired = True
|
||||
signal.setitimer(signal.ITIMER_REAL, 0, 0)
|
||||
signal.signal(signal.SIGALRM, self.orig_handler)
|
||||
|
||||
@mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
|
||||
def test_retry_recv(self):
|
||||
pull = self.socket(zmq.PULL)
|
||||
pull.rcvtimeo = self.timeout_ms
|
||||
self.alarm()
|
||||
self.assertRaises(zmq.Again, pull.recv)
|
||||
assert self.timer_fired
|
||||
|
||||
@mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
|
||||
def test_retry_send(self):
|
||||
push = self.socket(zmq.PUSH)
|
||||
push.sndtimeo = self.timeout_ms
|
||||
self.alarm()
|
||||
self.assertRaises(zmq.Again, push.send, b('buf'))
|
||||
assert self.timer_fired
|
||||
|
||||
def test_retry_poll(self):
|
||||
x, y = self.create_bound_pair()
|
||||
poller = zmq.Poller()
|
||||
poller.register(x, zmq.POLLIN)
|
||||
self.alarm()
|
||||
def send():
|
||||
time.sleep(2 * self.signal_delay)
|
||||
y.send(b('ping'))
|
||||
t = Thread(target=send)
|
||||
t.start()
|
||||
evts = dict(poller.poll(2 * self.timeout_ms))
|
||||
t.join()
|
||||
assert x in evts
|
||||
assert self.timer_fired
|
||||
x.recv()
|
||||
|
||||
def test_retry_term(self):
|
||||
push = self.socket(zmq.PUSH)
|
||||
push.linger = self.timeout_ms
|
||||
push.connect('tcp://127.0.0.1:5555')
|
||||
push.send(b('ping'))
|
||||
time.sleep(0.1)
|
||||
self.alarm()
|
||||
self.context.destroy()
|
||||
assert self.timer_fired
|
||||
assert self.context.closed
|
||||
|
||||
def test_retry_getsockopt(self):
|
||||
raise SkipTest("TODO: find a way to interrupt getsockopt")
|
||||
|
||||
def test_retry_setsockopt(self):
|
||||
raise SkipTest("TODO: find a way to interrupt setsockopt")
|
236
venv/Lib/site-packages/zmq/tests/test_security.py
Normal file
236
venv/Lib/site-packages/zmq/tests/test_security.py
Normal file
|
@ -0,0 +1,236 @@
|
|||
"""Test libzmq security (libzmq >= 3.3.0)"""
|
||||
# -*- coding: utf8 -*-
|
||||
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import os
|
||||
import contextlib
|
||||
import time
|
||||
from threading import Thread
|
||||
|
||||
import zmq
|
||||
from zmq.tests import (
|
||||
BaseZMQTestCase, SkipTest, PYPY
|
||||
)
|
||||
from zmq.utils import z85
|
||||
|
||||
|
||||
USER = b"admin"
|
||||
PASS = b"password"
|
||||
|
||||
class TestSecurity(BaseZMQTestCase):
|
||||
|
||||
def setUp(self):
|
||||
if zmq.zmq_version_info() < (4,0):
|
||||
raise SkipTest("security is new in libzmq 4.0")
|
||||
try:
|
||||
zmq.curve_keypair()
|
||||
except zmq.ZMQError:
|
||||
raise SkipTest("security requires libzmq to be built with CURVE support")
|
||||
super(TestSecurity, self).setUp()
|
||||
|
||||
def zap_handler(self):
|
||||
socket = self.context.socket(zmq.REP)
|
||||
socket.bind("inproc://zeromq.zap.01")
|
||||
try:
|
||||
msg = self.recv_multipart(socket)
|
||||
|
||||
version, sequence, domain, address, identity, mechanism = msg[:6]
|
||||
if mechanism == b'PLAIN':
|
||||
username, password = msg[6:]
|
||||
elif mechanism == b'CURVE':
|
||||
key = msg[6]
|
||||
|
||||
self.assertEqual(version, b"1.0")
|
||||
self.assertEqual(identity, b"IDENT")
|
||||
reply = [version, sequence]
|
||||
if mechanism == b'CURVE' or \
|
||||
(mechanism == b'PLAIN' and username == USER and password == PASS) or \
|
||||
(mechanism == b'NULL'):
|
||||
reply.extend([
|
||||
b"200",
|
||||
b"OK",
|
||||
b"anonymous",
|
||||
b"\5Hello\0\0\0\5World",
|
||||
])
|
||||
else:
|
||||
reply.extend([
|
||||
b"400",
|
||||
b"Invalid username or password",
|
||||
b"",
|
||||
b"",
|
||||
])
|
||||
socket.send_multipart(reply)
|
||||
finally:
|
||||
socket.close()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def zap(self):
|
||||
self.start_zap()
|
||||
time.sleep(0.5) # allow time for the Thread to start
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.stop_zap()
|
||||
|
||||
def start_zap(self):
|
||||
self.zap_thread = Thread(target=self.zap_handler)
|
||||
self.zap_thread.start()
|
||||
|
||||
def stop_zap(self):
|
||||
self.zap_thread.join()
|
||||
|
||||
def bounce(self, server, client, test_metadata=True):
|
||||
msg = [os.urandom(64), os.urandom(64)]
|
||||
client.send_multipart(msg)
|
||||
frames = self.recv_multipart(server, copy=False)
|
||||
recvd = list(map(lambda x: x.bytes, frames))
|
||||
|
||||
try:
|
||||
if test_metadata and not PYPY:
|
||||
for frame in frames:
|
||||
self.assertEqual(frame.get('User-Id'), 'anonymous')
|
||||
self.assertEqual(frame.get('Hello'), 'World')
|
||||
self.assertEqual(frame['Socket-Type'], 'DEALER')
|
||||
except zmq.ZMQVersionError:
|
||||
pass
|
||||
|
||||
self.assertEqual(recvd, msg)
|
||||
server.send_multipart(recvd)
|
||||
msg2 = self.recv_multipart(client)
|
||||
self.assertEqual(msg2, msg)
|
||||
|
||||
def test_null(self):
|
||||
"""test NULL (default) security"""
|
||||
server = self.socket(zmq.DEALER)
|
||||
client = self.socket(zmq.DEALER)
|
||||
self.assertEqual(client.MECHANISM, zmq.NULL)
|
||||
self.assertEqual(server.mechanism, zmq.NULL)
|
||||
self.assertEqual(client.plain_server, 0)
|
||||
self.assertEqual(server.plain_server, 0)
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = server.bind_to_random_port(iface)
|
||||
client.connect("%s:%i" % (iface, port))
|
||||
self.bounce(server, client, False)
|
||||
|
||||
def test_plain(self):
|
||||
"""test PLAIN authentication"""
|
||||
server = self.socket(zmq.DEALER)
|
||||
server.identity = b'IDENT'
|
||||
client = self.socket(zmq.DEALER)
|
||||
self.assertEqual(client.plain_username, b'')
|
||||
self.assertEqual(client.plain_password, b'')
|
||||
client.plain_username = USER
|
||||
client.plain_password = PASS
|
||||
self.assertEqual(client.getsockopt(zmq.PLAIN_USERNAME), USER)
|
||||
self.assertEqual(client.getsockopt(zmq.PLAIN_PASSWORD), PASS)
|
||||
self.assertEqual(client.plain_server, 0)
|
||||
self.assertEqual(server.plain_server, 0)
|
||||
server.plain_server = True
|
||||
self.assertEqual(server.mechanism, zmq.PLAIN)
|
||||
self.assertEqual(client.mechanism, zmq.PLAIN)
|
||||
|
||||
assert not client.plain_server
|
||||
assert server.plain_server
|
||||
|
||||
with self.zap():
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = server.bind_to_random_port(iface)
|
||||
client.connect("%s:%i" % (iface, port))
|
||||
self.bounce(server, client)
|
||||
|
||||
def skip_plain_inauth(self):
|
||||
"""test PLAIN failed authentication"""
|
||||
server = self.socket(zmq.DEALER)
|
||||
server.identity = b'IDENT'
|
||||
client = self.socket(zmq.DEALER)
|
||||
self.sockets.extend([server, client])
|
||||
client.plain_username = USER
|
||||
client.plain_password = b'incorrect'
|
||||
server.plain_server = True
|
||||
self.assertEqual(server.mechanism, zmq.PLAIN)
|
||||
self.assertEqual(client.mechanism, zmq.PLAIN)
|
||||
|
||||
with self.zap():
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = server.bind_to_random_port(iface)
|
||||
client.connect("%s:%i" % (iface, port))
|
||||
client.send(b'ping')
|
||||
server.rcvtimeo = 250
|
||||
self.assertRaisesErrno(zmq.EAGAIN, server.recv)
|
||||
|
||||
def test_keypair(self):
|
||||
"""test curve_keypair"""
|
||||
try:
|
||||
public, secret = zmq.curve_keypair()
|
||||
except zmq.ZMQError:
|
||||
raise SkipTest("CURVE unsupported")
|
||||
|
||||
self.assertEqual(type(secret), bytes)
|
||||
self.assertEqual(type(public), bytes)
|
||||
self.assertEqual(len(secret), 40)
|
||||
self.assertEqual(len(public), 40)
|
||||
|
||||
# verify that it is indeed Z85
|
||||
bsecret, bpublic = [ z85.decode(key) for key in (public, secret) ]
|
||||
self.assertEqual(type(bsecret), bytes)
|
||||
self.assertEqual(type(bpublic), bytes)
|
||||
self.assertEqual(len(bsecret), 32)
|
||||
self.assertEqual(len(bpublic), 32)
|
||||
|
||||
def test_curve_public(self):
|
||||
"""test curve_public"""
|
||||
try:
|
||||
public, secret = zmq.curve_keypair()
|
||||
except zmq.ZMQError:
|
||||
raise SkipTest("CURVE unsupported")
|
||||
if zmq.zmq_version_info() < (4,2):
|
||||
raise SkipTest("curve_public is new in libzmq 4.2")
|
||||
|
||||
derived_public = zmq.curve_public(secret)
|
||||
|
||||
self.assertEqual(type(derived_public), bytes)
|
||||
self.assertEqual(len(derived_public), 40)
|
||||
|
||||
# verify that it is indeed Z85
|
||||
bpublic = z85.decode(derived_public)
|
||||
self.assertEqual(type(bpublic), bytes)
|
||||
self.assertEqual(len(bpublic), 32)
|
||||
|
||||
# verify that it is equal to the known public key
|
||||
self.assertEqual(derived_public, public)
|
||||
|
||||
def test_curve(self):
|
||||
"""test CURVE encryption"""
|
||||
server = self.socket(zmq.DEALER)
|
||||
server.identity = b'IDENT'
|
||||
client = self.socket(zmq.DEALER)
|
||||
self.sockets.extend([server, client])
|
||||
try:
|
||||
server.curve_server = True
|
||||
except zmq.ZMQError as e:
|
||||
# will raise EINVAL if no CURVE support
|
||||
if e.errno == zmq.EINVAL:
|
||||
raise SkipTest("CURVE unsupported")
|
||||
|
||||
server_public, server_secret = zmq.curve_keypair()
|
||||
client_public, client_secret = zmq.curve_keypair()
|
||||
|
||||
server.curve_secretkey = server_secret
|
||||
server.curve_publickey = server_public
|
||||
client.curve_serverkey = server_public
|
||||
client.curve_publickey = client_public
|
||||
client.curve_secretkey = client_secret
|
||||
|
||||
self.assertEqual(server.mechanism, zmq.CURVE)
|
||||
self.assertEqual(client.mechanism, zmq.CURVE)
|
||||
|
||||
self.assertEqual(server.get(zmq.CURVE_SERVER), True)
|
||||
self.assertEqual(client.get(zmq.CURVE_SERVER), False)
|
||||
|
||||
with self.zap():
|
||||
iface = 'tcp://127.0.0.1'
|
||||
port = server.bind_to_random_port(iface)
|
||||
client.connect("%s:%i" % (iface, port))
|
||||
self.bounce(server, client)
|
615
venv/Lib/site-packages/zmq/tests/test_socket.py
Normal file
615
venv/Lib/site-packages/zmq/tests/test_socket.py
Normal file
|
@ -0,0 +1,615 @@
|
|||
# -*- coding: utf8 -*-
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
import copy
|
||||
import errno
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import time
|
||||
import warnings
|
||||
import socket
|
||||
import sys
|
||||
try:
|
||||
from unittest import mock
|
||||
except ImportError:
|
||||
mock = None
|
||||
|
||||
import pytest
|
||||
from pytest import mark
|
||||
|
||||
import zmq
|
||||
from zmq.tests import (
|
||||
BaseZMQTestCase, SkipTest, have_gevent, GreenTest, skip_pypy
|
||||
)
|
||||
from zmq.utils.strtypes import unicode
|
||||
|
||||
pypy = platform.python_implementation().lower() == 'pypy'
|
||||
windows = platform.platform().lower().startswith('windows')
|
||||
on_travis = bool(os.environ.get('TRAVIS_PYTHON_VERSION'))
|
||||
|
||||
# polling on windows is slow
|
||||
POLL_TIMEOUT = 1000 if windows else 100
|
||||
|
||||
class TestSocket(BaseZMQTestCase):
|
||||
|
||||
def test_create(self):
|
||||
ctx = self.Context()
|
||||
s = ctx.socket(zmq.PUB)
|
||||
# Superluminal protocol not yet implemented
|
||||
self.assertRaisesErrno(zmq.EPROTONOSUPPORT, s.bind, 'ftl://a')
|
||||
self.assertRaisesErrno(zmq.EPROTONOSUPPORT, s.connect, 'ftl://a')
|
||||
self.assertRaisesErrno(zmq.EINVAL, s.bind, 'tcp://')
|
||||
s.close()
|
||||
del ctx
|
||||
|
||||
def test_context_manager(self):
|
||||
url = 'inproc://a'
|
||||
with self.Context() as ctx:
|
||||
with ctx.socket(zmq.PUSH) as a:
|
||||
a.bind(url)
|
||||
with ctx.socket(zmq.PULL) as b:
|
||||
b.connect(url)
|
||||
msg = b'hi'
|
||||
a.send(msg)
|
||||
rcvd = self.recv(b)
|
||||
self.assertEqual(rcvd, msg)
|
||||
self.assertEqual(b.closed, True)
|
||||
self.assertEqual(a.closed, True)
|
||||
self.assertEqual(ctx.closed, True)
|
||||
|
||||
def test_dir(self):
|
||||
ctx = self.Context()
|
||||
s = ctx.socket(zmq.PUB)
|
||||
self.assertTrue('send' in dir(s))
|
||||
self.assertTrue('IDENTITY' in dir(s))
|
||||
self.assertTrue('AFFINITY' in dir(s))
|
||||
self.assertTrue('FD' in dir(s))
|
||||
s.close()
|
||||
ctx.term()
|
||||
|
||||
@mark.skipif(mock is None, reason="requires unittest.mock")
|
||||
def test_mockable(self):
|
||||
s = self.socket(zmq.SUB)
|
||||
m = mock.Mock(spec=s)
|
||||
s.close()
|
||||
|
||||
def test_bind_unicode(self):
|
||||
s = self.socket(zmq.PUB)
|
||||
p = s.bind_to_random_port(unicode("tcp://*"))
|
||||
|
||||
def test_connect_unicode(self):
|
||||
s = self.socket(zmq.PUB)
|
||||
s.connect(unicode("tcp://127.0.0.1:5555"))
|
||||
|
||||
def test_bind_to_random_port(self):
|
||||
# Check that bind_to_random_port do not hide useful exception
|
||||
ctx = self.Context()
|
||||
c = ctx.socket(zmq.PUB)
|
||||
# Invalid format
|
||||
try:
|
||||
c.bind_to_random_port('tcp:*')
|
||||
except zmq.ZMQError as e:
|
||||
self.assertEqual(e.errno, zmq.EINVAL)
|
||||
# Invalid protocol
|
||||
try:
|
||||
c.bind_to_random_port('rand://*')
|
||||
except zmq.ZMQError as e:
|
||||
self.assertEqual(e.errno, zmq.EPROTONOSUPPORT)
|
||||
|
||||
def test_identity(self):
|
||||
s = self.context.socket(zmq.PULL)
|
||||
self.sockets.append(s)
|
||||
ident = b'identity\0\0'
|
||||
s.identity = ident
|
||||
self.assertEqual(s.get(zmq.IDENTITY), ident)
|
||||
|
||||
def test_unicode_sockopts(self):
|
||||
"""test setting/getting sockopts with unicode strings"""
|
||||
topic = "tést"
|
||||
if str is not unicode:
|
||||
topic = topic.decode('utf8')
|
||||
p,s = self.create_bound_pair(zmq.PUB, zmq.SUB)
|
||||
self.assertEqual(s.send_unicode, s.send_unicode)
|
||||
self.assertEqual(p.recv_unicode, p.recv_unicode)
|
||||
self.assertRaises(TypeError, s.setsockopt, zmq.SUBSCRIBE, topic)
|
||||
self.assertRaises(TypeError, s.setsockopt, zmq.IDENTITY, topic)
|
||||
s.setsockopt_unicode(zmq.IDENTITY, topic, 'utf16')
|
||||
self.assertRaises(TypeError, s.setsockopt, zmq.AFFINITY, topic)
|
||||
s.setsockopt_unicode(zmq.SUBSCRIBE, topic)
|
||||
self.assertRaises(TypeError, s.getsockopt_unicode, zmq.AFFINITY)
|
||||
self.assertRaisesErrno(zmq.EINVAL, s.getsockopt_unicode, zmq.SUBSCRIBE)
|
||||
|
||||
identb = s.getsockopt(zmq.IDENTITY)
|
||||
identu = identb.decode('utf16')
|
||||
identu2 = s.getsockopt_unicode(zmq.IDENTITY, 'utf16')
|
||||
self.assertEqual(identu, identu2)
|
||||
time.sleep(0.1) # wait for connection/subscription
|
||||
p.send_unicode(topic,zmq.SNDMORE)
|
||||
p.send_unicode(topic*2, encoding='latin-1')
|
||||
self.assertEqual(topic, s.recv_unicode())
|
||||
self.assertEqual(topic*2, s.recv_unicode(encoding='latin-1'))
|
||||
|
||||
def test_int_sockopts(self):
|
||||
"test integer sockopts"
|
||||
v = zmq.zmq_version_info()
|
||||
if v < (3,0):
|
||||
default_hwm = 0
|
||||
else:
|
||||
default_hwm = 1000
|
||||
p,s = self.create_bound_pair(zmq.PUB, zmq.SUB)
|
||||
p.setsockopt(zmq.LINGER, 0)
|
||||
self.assertEqual(p.getsockopt(zmq.LINGER), 0)
|
||||
p.setsockopt(zmq.LINGER, -1)
|
||||
self.assertEqual(p.getsockopt(zmq.LINGER), -1)
|
||||
self.assertEqual(p.hwm, default_hwm)
|
||||
p.hwm = 11
|
||||
self.assertEqual(p.hwm, 11)
|
||||
# p.setsockopt(zmq.EVENTS, zmq.POLLIN)
|
||||
self.assertEqual(p.getsockopt(zmq.EVENTS), zmq.POLLOUT)
|
||||
self.assertRaisesErrno(zmq.EINVAL, p.setsockopt,zmq.EVENTS, 2**7-1)
|
||||
self.assertEqual(p.getsockopt(zmq.TYPE), p.socket_type)
|
||||
self.assertEqual(p.getsockopt(zmq.TYPE), zmq.PUB)
|
||||
self.assertEqual(s.getsockopt(zmq.TYPE), s.socket_type)
|
||||
self.assertEqual(s.getsockopt(zmq.TYPE), zmq.SUB)
|
||||
|
||||
# check for overflow / wrong type:
|
||||
errors = []
|
||||
backref = {}
|
||||
constants = zmq.constants
|
||||
for name in constants.__all__:
|
||||
value = getattr(constants, name)
|
||||
if isinstance(value, int):
|
||||
backref[value] = name
|
||||
for opt in zmq.constants.int_sockopts.union(zmq.constants.int64_sockopts):
|
||||
sopt = backref[opt]
|
||||
if sopt.startswith((
|
||||
'ROUTER', 'XPUB', 'TCP', 'FAIL',
|
||||
'REQ_', 'CURVE_', 'PROBE_ROUTER',
|
||||
'IPC_FILTER', 'GSSAPI', 'STREAM_',
|
||||
'VMCI_BUFFER_SIZE', 'VMCI_BUFFER_MIN_SIZE',
|
||||
'VMCI_BUFFER_MAX_SIZE', 'VMCI_CONNECT_TIMEOUT',
|
||||
)):
|
||||
# some sockopts are write-only
|
||||
continue
|
||||
try:
|
||||
n = p.getsockopt(opt)
|
||||
except zmq.ZMQError as e:
|
||||
errors.append("getsockopt(zmq.%s) raised '%s'."%(sopt, e))
|
||||
else:
|
||||
if n > 2**31:
|
||||
errors.append("getsockopt(zmq.%s) returned a ridiculous value."
|
||||
" It is probably the wrong type."%sopt)
|
||||
if errors:
|
||||
self.fail('\n'.join([''] + errors))
|
||||
|
||||
def test_bad_sockopts(self):
|
||||
"""Test that appropriate errors are raised on bad socket options"""
|
||||
s = self.context.socket(zmq.PUB)
|
||||
self.sockets.append(s)
|
||||
s.setsockopt(zmq.LINGER, 0)
|
||||
# unrecognized int sockopts pass through to libzmq, and should raise EINVAL
|
||||
self.assertRaisesErrno(zmq.EINVAL, s.setsockopt, 9999, 5)
|
||||
self.assertRaisesErrno(zmq.EINVAL, s.getsockopt, 9999)
|
||||
# but only int sockopts are allowed through this way, otherwise raise a TypeError
|
||||
self.assertRaises(TypeError, s.setsockopt, 9999, b"5")
|
||||
# some sockopts are valid in general, but not on every socket:
|
||||
self.assertRaisesErrno(zmq.EINVAL, s.setsockopt, zmq.SUBSCRIBE, b'hi')
|
||||
|
||||
def test_sockopt_roundtrip(self):
|
||||
"test set/getsockopt roundtrip."
|
||||
p = self.context.socket(zmq.PUB)
|
||||
self.sockets.append(p)
|
||||
p.setsockopt(zmq.LINGER, 11)
|
||||
self.assertEqual(p.getsockopt(zmq.LINGER), 11)
|
||||
|
||||
def test_send_unicode(self):
|
||||
"test sending unicode objects"
|
||||
a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
self.sockets.extend([a,b])
|
||||
u = "çπ§"
|
||||
if str is not unicode:
|
||||
u = u.decode('utf8')
|
||||
self.assertRaises(TypeError, a.send, u,copy=False)
|
||||
self.assertRaises(TypeError, a.send, u,copy=True)
|
||||
a.send_unicode(u)
|
||||
s = b.recv()
|
||||
self.assertEqual(s,u.encode('utf8'))
|
||||
self.assertEqual(s.decode('utf8'),u)
|
||||
a.send_unicode(u,encoding='utf16')
|
||||
s = b.recv_unicode(encoding='utf16')
|
||||
self.assertEqual(s,u)
|
||||
|
||||
def test_send_multipart_check_type(self):
|
||||
"check type on all frames in send_multipart"
|
||||
a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
|
||||
self.sockets.extend([a,b])
|
||||
self.assertRaises(TypeError, a.send_multipart, [b'a', 5])
|
||||
a.send_multipart([b'b'])
|
||||
rcvd = self.recv_multipart(b)
|
||||
self.assertEqual(rcvd, [b'b'])
|
||||
|
||||
@skip_pypy
|
||||
def test_tracker(self):
|
||||
"test the MessageTracker object for tracking when zmq is done with a buffer"
|
||||
addr = 'tcp://127.0.0.1'
|
||||
# get a port:
|
||||
sock = socket.socket()
|
||||
sock.bind(('127.0.0.1', 0))
|
||||
port = sock.getsockname()[1]
|
||||
iface = "%s:%i" % (addr, port)
|
||||
sock.close()
|
||||
time.sleep(0.1)
|
||||
|
||||
a = self.context.socket(zmq.PUSH)
|
||||
b = self.context.socket(zmq.PULL)
|
||||
self.sockets.extend([a,b])
|
||||
a.connect(iface)
|
||||
time.sleep(0.1)
|
||||
p1 = a.send(b'something', copy=False, track=True)
|
||||
assert isinstance(p1, zmq.MessageTracker)
|
||||
assert p1 is zmq._FINISHED_TRACKER
|
||||
# small message, should start done
|
||||
assert p1.done
|
||||
|
||||
# disable zero-copy threshold
|
||||
a.copy_threshold = 0
|
||||
|
||||
p2 = a.send_multipart([b'something', b'else'], copy=False, track=True)
|
||||
assert isinstance(p2, zmq.MessageTracker)
|
||||
assert not p2.done
|
||||
|
||||
b.bind(iface)
|
||||
msg = self.recv_multipart(b)
|
||||
for i in range(10):
|
||||
if p1.done:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
self.assertEqual(p1.done, True)
|
||||
self.assertEqual(msg, [b'something'])
|
||||
msg = self.recv_multipart(b)
|
||||
for i in range(10):
|
||||
if p2.done:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
self.assertEqual(p2.done, True)
|
||||
self.assertEqual(msg, [b'something', b'else'])
|
||||
m = zmq.Frame(b"again", copy=False, track=True)
|
||||
self.assertEqual(m.tracker.done, False)
|
||||
p1 = a.send(m, copy=False)
|
||||
p2 = a.send(m, copy=False)
|
||||
self.assertEqual(m.tracker.done, False)
|
||||
self.assertEqual(p1.done, False)
|
||||
self.assertEqual(p2.done, False)
|
||||
msg = self.recv_multipart(b)
|
||||
self.assertEqual(m.tracker.done, False)
|
||||
self.assertEqual(msg, [b'again'])
|
||||
msg = self.recv_multipart(b)
|
||||
self.assertEqual(m.tracker.done, False)
|
||||
self.assertEqual(msg, [b'again'])
|
||||
self.assertEqual(p1.done, False)
|
||||
self.assertEqual(p2.done, False)
|
||||
pm = m.tracker
|
||||
del m
|
||||
for i in range(10):
|
||||
if p1.done:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
self.assertEqual(p1.done, True)
|
||||
self.assertEqual(p2.done, True)
|
||||
m = zmq.Frame(b'something', track=False)
|
||||
self.assertRaises(ValueError, a.send, m, copy=False, track=True)
|
||||
|
||||
def test_close(self):
|
||||
ctx = self.Context()
|
||||
s = ctx.socket(zmq.PUB)
|
||||
s.close()
|
||||
self.assertRaisesErrno(zmq.ENOTSOCK, s.bind, b'')
|
||||
self.assertRaisesErrno(zmq.ENOTSOCK, s.connect, b'')
|
||||
self.assertRaisesErrno(zmq.ENOTSOCK, s.setsockopt, zmq.SUBSCRIBE, b'')
|
||||
self.assertRaisesErrno(zmq.ENOTSOCK, s.send, b'asdf')
|
||||
self.assertRaisesErrno(zmq.ENOTSOCK, s.recv)
|
||||
del ctx
|
||||
|
||||
def test_attr(self):
|
||||
"""set setting/getting sockopts as attributes"""
|
||||
s = self.context.socket(zmq.DEALER)
|
||||
self.sockets.append(s)
|
||||
linger = 10
|
||||
s.linger = linger
|
||||
self.assertEqual(linger, s.linger)
|
||||
self.assertEqual(linger, s.getsockopt(zmq.LINGER))
|
||||
self.assertEqual(s.fd, s.getsockopt(zmq.FD))
|
||||
|
||||
def test_bad_attr(self):
|
||||
s = self.context.socket(zmq.DEALER)
|
||||
self.sockets.append(s)
|
||||
try:
|
||||
s.apple='foo'
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
self.fail("bad setattr should have raised AttributeError")
|
||||
try:
|
||||
s.apple
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
self.fail("bad getattr should have raised AttributeError")
|
||||
|
||||
def test_subclass(self):
|
||||
"""subclasses can assign attributes"""
|
||||
class S(zmq.Socket):
|
||||
a = None
|
||||
def __init__(self, *a, **kw):
|
||||
self.a=-1
|
||||
super(S, self).__init__(*a, **kw)
|
||||
|
||||
s = S(self.context, zmq.REP)
|
||||
self.sockets.append(s)
|
||||
self.assertEqual(s.a, -1)
|
||||
s.a=1
|
||||
self.assertEqual(s.a, 1)
|
||||
a=s.a
|
||||
self.assertEqual(a, 1)
|
||||
|
||||
def test_recv_multipart(self):
|
||||
a,b = self.create_bound_pair()
|
||||
msg = b'hi'
|
||||
for i in range(3):
|
||||
a.send(msg)
|
||||
time.sleep(0.1)
|
||||
for i in range(3):
|
||||
self.assertEqual(self.recv_multipart(b), [msg])
|
||||
|
||||
def test_close_after_destroy(self):
|
||||
"""s.close() after ctx.destroy() should be fine"""
|
||||
ctx = self.Context()
|
||||
s = ctx.socket(zmq.REP)
|
||||
ctx.destroy()
|
||||
# reaper is not instantaneous
|
||||
time.sleep(1e-2)
|
||||
s.close()
|
||||
self.assertTrue(s.closed)
|
||||
|
||||
def test_poll(self):
|
||||
a,b = self.create_bound_pair()
|
||||
tic = time.time()
|
||||
evt = a.poll(POLL_TIMEOUT)
|
||||
self.assertEqual(evt, 0)
|
||||
evt = a.poll(POLL_TIMEOUT, zmq.POLLOUT)
|
||||
self.assertEqual(evt, zmq.POLLOUT)
|
||||
msg = b'hi'
|
||||
a.send(msg)
|
||||
evt = b.poll(POLL_TIMEOUT)
|
||||
self.assertEqual(evt, zmq.POLLIN)
|
||||
msg2 = self.recv(b)
|
||||
evt = b.poll(POLL_TIMEOUT)
|
||||
self.assertEqual(evt, 0)
|
||||
self.assertEqual(msg2, msg)
|
||||
|
||||
def test_ipc_path_max_length(self):
|
||||
"""IPC_PATH_MAX_LEN is a sensible value"""
|
||||
if zmq.IPC_PATH_MAX_LEN == 0:
|
||||
raise SkipTest("IPC_PATH_MAX_LEN undefined")
|
||||
|
||||
msg = "Surprising value for IPC_PATH_MAX_LEN: %s" % zmq.IPC_PATH_MAX_LEN
|
||||
self.assertTrue(zmq.IPC_PATH_MAX_LEN > 30, msg)
|
||||
self.assertTrue(zmq.IPC_PATH_MAX_LEN < 1025, msg)
|
||||
|
||||
def test_ipc_path_max_length_msg(self):
|
||||
if zmq.IPC_PATH_MAX_LEN == 0:
|
||||
raise SkipTest("IPC_PATH_MAX_LEN undefined")
|
||||
|
||||
s = self.context.socket(zmq.PUB)
|
||||
self.sockets.append(s)
|
||||
try:
|
||||
s.bind('ipc://{0}'.format('a' * (zmq.IPC_PATH_MAX_LEN + 1)))
|
||||
except zmq.ZMQError as e:
|
||||
self.assertTrue(str(zmq.IPC_PATH_MAX_LEN) in e.strerror)
|
||||
|
||||
@mark.skipif(windows, reason="ipc not supported on Windows.")
|
||||
def test_ipc_path_no_such_file_or_directory_message(self):
|
||||
"""Display the ipc path in case of an ENOENT exception"""
|
||||
s = self.context.socket(zmq.PUB)
|
||||
self.sockets.append(s)
|
||||
invalid_path = '/foo/bar'
|
||||
with pytest.raises(zmq.ZMQError) as error:
|
||||
s.bind('ipc://{0}'.format(invalid_path))
|
||||
assert error.value.errno == errno.ENOENT
|
||||
error_message = str(error.value)
|
||||
assert invalid_path in error_message
|
||||
assert "no such file or directory" in error_message.lower()
|
||||
|
||||
def test_hwm(self):
|
||||
zmq3 = zmq.zmq_version_info()[0] >= 3
|
||||
for stype in (zmq.PUB, zmq.ROUTER, zmq.SUB, zmq.REQ, zmq.DEALER):
|
||||
s = self.context.socket(stype)
|
||||
s.hwm = 100
|
||||
self.assertEqual(s.hwm, 100)
|
||||
if zmq3:
|
||||
try:
|
||||
self.assertEqual(s.sndhwm, 100)
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
self.assertEqual(s.rcvhwm, 100)
|
||||
except AttributeError:
|
||||
pass
|
||||
s.close()
|
||||
|
||||
def test_copy(self):
|
||||
s = self.socket(zmq.PUB)
|
||||
scopy = copy.copy(s)
|
||||
sdcopy = copy.deepcopy(s)
|
||||
self.assert_(scopy._shadow)
|
||||
self.assert_(sdcopy._shadow)
|
||||
self.assertEqual(s.underlying, scopy.underlying)
|
||||
self.assertEqual(s.underlying, sdcopy.underlying)
|
||||
s.close()
|
||||
|
||||
def test_send_buffer(self):
|
||||
a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
|
||||
for buffer_type in (memoryview, bytearray):
|
||||
rawbytes = str(buffer_type).encode('ascii')
|
||||
msg = buffer_type(rawbytes)
|
||||
a.send(msg)
|
||||
recvd = b.recv()
|
||||
assert recvd == rawbytes
|
||||
|
||||
def test_shadow(self):
|
||||
p = self.socket(zmq.PUSH)
|
||||
p.bind("tcp://127.0.0.1:5555")
|
||||
p2 = zmq.Socket.shadow(p.underlying)
|
||||
self.assertEqual(p.underlying, p2.underlying)
|
||||
s = self.socket(zmq.PULL)
|
||||
s2 = zmq.Socket.shadow(s.underlying)
|
||||
self.assertNotEqual(s.underlying, p.underlying)
|
||||
self.assertEqual(s.underlying, s2.underlying)
|
||||
s2.connect("tcp://127.0.0.1:5555")
|
||||
sent = b'hi'
|
||||
p2.send(sent)
|
||||
rcvd = self.recv(s2)
|
||||
self.assertEqual(rcvd, sent)
|
||||
|
||||
def test_shadow_pyczmq(self):
|
||||
try:
|
||||
from pyczmq import zctx, zsocket
|
||||
except Exception:
|
||||
raise SkipTest("Requires pyczmq")
|
||||
|
||||
ctx = zctx.new()
|
||||
ca = zsocket.new(ctx, zmq.PUSH)
|
||||
cb = zsocket.new(ctx, zmq.PULL)
|
||||
a = zmq.Socket.shadow(ca)
|
||||
b = zmq.Socket.shadow(cb)
|
||||
a.bind("inproc://a")
|
||||
b.connect("inproc://a")
|
||||
a.send(b'hi')
|
||||
rcvd = self.recv(b)
|
||||
self.assertEqual(rcvd, b'hi')
|
||||
|
||||
def test_subscribe_method(self):
|
||||
pub, sub = self.create_bound_pair(zmq.PUB, zmq.SUB)
|
||||
sub.subscribe('prefix')
|
||||
sub.subscribe = 'c'
|
||||
p = zmq.Poller()
|
||||
p.register(sub, zmq.POLLIN)
|
||||
# wait for subscription handshake
|
||||
for i in range(100):
|
||||
pub.send(b'canary')
|
||||
events = p.poll(250)
|
||||
if events:
|
||||
break
|
||||
self.recv(sub)
|
||||
pub.send(b'prefixmessage')
|
||||
msg = self.recv(sub)
|
||||
self.assertEqual(msg, b'prefixmessage')
|
||||
sub.unsubscribe('prefix')
|
||||
pub.send(b'prefixmessage')
|
||||
events = p.poll(1000)
|
||||
self.assertEqual(events, [])
|
||||
|
||||
# Travis can't handle how much memory PyPy uses on this test
|
||||
@mark.skipif(
|
||||
(
|
||||
pypy and on_travis
|
||||
) or (
|
||||
sys.maxsize < 2**32
|
||||
) or (
|
||||
windows
|
||||
),
|
||||
reason="only run on 64b and not on Travis."
|
||||
)
|
||||
@mark.large
|
||||
def test_large_send(self):
|
||||
c = os.urandom(1)
|
||||
N = 2**31 + 1
|
||||
try:
|
||||
buf = c * N
|
||||
except MemoryError as e:
|
||||
raise SkipTest("Not enough memory: %s" % e)
|
||||
a, b = self.create_bound_pair()
|
||||
try:
|
||||
a.send(buf, copy=False)
|
||||
rcvd = b.recv(copy=False)
|
||||
except MemoryError as e:
|
||||
raise SkipTest("Not enough memory: %s" % e)
|
||||
# sample the front and back of the received message
|
||||
# without checking the whole content
|
||||
# Python 2: items in memoryview are bytes
|
||||
# Python 3: items im memoryview are int
|
||||
byte = c if sys.version_info < (3,) else ord(c)
|
||||
view = memoryview(rcvd)
|
||||
assert len(view) == N
|
||||
assert view[0] == byte
|
||||
assert view[-1] == byte
|
||||
|
||||
def test_custom_serialize(self):
|
||||
a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
|
||||
def serialize(msg):
|
||||
frames = []
|
||||
frames.extend(msg.get('identities', []))
|
||||
content = json.dumps(msg['content']).encode('utf8')
|
||||
frames.append(content)
|
||||
return frames
|
||||
|
||||
def deserialize(frames):
|
||||
identities = frames[:-1]
|
||||
content = json.loads(frames[-1].decode('utf8'))
|
||||
return {
|
||||
'identities': identities,
|
||||
'content': content,
|
||||
}
|
||||
|
||||
msg = {
|
||||
'content': {
|
||||
'a': 5,
|
||||
'b': 'bee',
|
||||
}
|
||||
}
|
||||
a.send_serialized(msg, serialize)
|
||||
recvd = b.recv_serialized(deserialize)
|
||||
assert recvd['content'] == msg['content']
|
||||
assert recvd['identities']
|
||||
# bounce back, tests identities
|
||||
b.send_serialized(recvd, serialize)
|
||||
r2 = a.recv_serialized(deserialize)
|
||||
assert r2['content'] == msg['content']
|
||||
assert not r2['identities']
|
||||
|
||||
|
||||
if have_gevent and not windows:
|
||||
import gevent
|
||||
|
||||
class TestSocketGreen(GreenTest, TestSocket):
|
||||
test_bad_attr = GreenTest.skip_green
|
||||
test_close_after_destroy = GreenTest.skip_green
|
||||
|
||||
def test_timeout(self):
|
||||
a,b = self.create_bound_pair()
|
||||
g = gevent.spawn_later(0.5, lambda: a.send(b'hi'))
|
||||
timeout = gevent.Timeout(0.1)
|
||||
timeout.start()
|
||||
self.assertRaises(gevent.Timeout, b.recv)
|
||||
g.kill()
|
||||
|
||||
@mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
|
||||
def test_warn_set_timeo(self):
|
||||
s = self.context.socket(zmq.REQ)
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
s.rcvtimeo = 5
|
||||
s.close()
|
||||
self.assertEqual(len(w), 1)
|
||||
self.assertEqual(w[0].category, UserWarning)
|
||||
|
||||
|
||||
@mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
|
||||
def test_warn_get_timeo(self):
|
||||
s = self.context.socket(zmq.REQ)
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
s.sndtimeo
|
||||
s.close()
|
||||
self.assertEqual(len(w), 1)
|
||||
self.assertEqual(w[0].category, UserWarning)
|
8
venv/Lib/site-packages/zmq/tests/test_ssh.py
Normal file
8
venv/Lib/site-packages/zmq/tests/test_ssh.py
Normal file
|
@ -0,0 +1,8 @@
|
|||
from zmq.ssh.tunnel import select_random_ports
|
||||
|
||||
def test_random_ports():
|
||||
for i in range(4096):
|
||||
ports = select_random_ports(10)
|
||||
assert len(ports) == 10
|
||||
for p in ports:
|
||||
assert ports.count(p) == 1
|
44
venv/Lib/site-packages/zmq/tests/test_version.py
Normal file
44
venv/Lib/site-packages/zmq/tests/test_version.py
Normal file
|
@ -0,0 +1,44 @@
|
|||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
|
||||
from unittest import TestCase
|
||||
import zmq
|
||||
from zmq.sugar import version
|
||||
|
||||
|
||||
class TestVersion(TestCase):
|
||||
|
||||
def test_pyzmq_version(self):
|
||||
vs = zmq.pyzmq_version()
|
||||
vs2 = zmq.__version__
|
||||
self.assertTrue(isinstance(vs, str))
|
||||
if zmq.__revision__:
|
||||
self.assertEqual(vs, '@'.join(vs2, zmq.__revision__))
|
||||
else:
|
||||
self.assertEqual(vs, vs2)
|
||||
if version.VERSION_EXTRA:
|
||||
self.assertTrue(version.VERSION_EXTRA in vs)
|
||||
self.assertTrue(version.VERSION_EXTRA in vs2)
|
||||
|
||||
def test_pyzmq_version_info(self):
|
||||
info = zmq.pyzmq_version_info()
|
||||
self.assertTrue(isinstance(info, tuple))
|
||||
for n in info[:3]:
|
||||
self.assertTrue(isinstance(n, int))
|
||||
if version.VERSION_EXTRA:
|
||||
self.assertEqual(len(info), 4)
|
||||
self.assertEqual(info[-1], float('inf'))
|
||||
else:
|
||||
self.assertEqual(len(info), 3)
|
||||
|
||||
def test_zmq_version_info(self):
|
||||
info = zmq.zmq_version_info()
|
||||
self.assertTrue(isinstance(info, tuple))
|
||||
for n in info[:3]:
|
||||
self.assertTrue(isinstance(n, int))
|
||||
|
||||
def test_zmq_version(self):
|
||||
v = zmq.zmq_version()
|
||||
self.assertTrue(isinstance(v, str))
|
||||
|
63
venv/Lib/site-packages/zmq/tests/test_win32_shim.py
Normal file
63
venv/Lib/site-packages/zmq/tests/test_win32_shim.py
Normal file
|
@ -0,0 +1,63 @@
|
|||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import time
|
||||
import sys
|
||||
from functools import wraps
|
||||
|
||||
from pytest import mark
|
||||
|
||||
from zmq.tests import BaseZMQTestCase
|
||||
from zmq.utils.win32 import allow_interrupt
|
||||
|
||||
|
||||
def count_calls(f):
|
||||
@wraps(f)
|
||||
def _(*args, **kwds):
|
||||
try:
|
||||
return f(*args, **kwds)
|
||||
finally:
|
||||
_.__calls__ += 1
|
||||
_.__calls__ = 0
|
||||
return _
|
||||
|
||||
|
||||
@mark.new_console
|
||||
class TestWindowsConsoleControlHandler(BaseZMQTestCase):
|
||||
|
||||
@mark.new_console
|
||||
@mark.skipif(
|
||||
not sys.platform.startswith('win'),
|
||||
reason='Windows only test')
|
||||
def test_handler(self):
|
||||
@count_calls
|
||||
def interrupt_polling():
|
||||
print('Caught CTRL-C!')
|
||||
|
||||
from ctypes import windll
|
||||
from ctypes.wintypes import BOOL, DWORD
|
||||
|
||||
kernel32 = windll.LoadLibrary('kernel32')
|
||||
|
||||
# <http://msdn.microsoft.com/en-us/library/ms683155.aspx>
|
||||
GenerateConsoleCtrlEvent = kernel32.GenerateConsoleCtrlEvent
|
||||
GenerateConsoleCtrlEvent.argtypes = (DWORD, DWORD)
|
||||
GenerateConsoleCtrlEvent.restype = BOOL
|
||||
|
||||
# Simulate CTRL-C event while handler is active.
|
||||
try:
|
||||
with allow_interrupt(interrupt_polling) as context:
|
||||
result = GenerateConsoleCtrlEvent(0, 0)
|
||||
# Sleep so that we give time to the handler to
|
||||
# capture the Ctrl-C event.
|
||||
time.sleep(0.5)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
else:
|
||||
if result == 0:
|
||||
raise WindowsError()
|
||||
else:
|
||||
self.fail('Expecting `KeyboardInterrupt` exception!')
|
||||
|
||||
# Make sure our handler was called.
|
||||
self.assertEqual(interrupt_polling.__calls__, 1)
|
63
venv/Lib/site-packages/zmq/tests/test_z85.py
Normal file
63
venv/Lib/site-packages/zmq/tests/test_z85.py
Normal file
|
@ -0,0 +1,63 @@
|
|||
# -*- coding: utf8 -*-
|
||||
"""Test Z85 encoding
|
||||
|
||||
confirm values and roundtrip with test values from the reference implementation.
|
||||
"""
|
||||
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
from unittest import TestCase
|
||||
from zmq.utils import z85
|
||||
|
||||
|
||||
class TestZ85(TestCase):
|
||||
|
||||
def test_client_public(self):
|
||||
client_public = \
|
||||
b"\xBB\x88\x47\x1D\x65\xE2\x65\x9B" \
|
||||
b"\x30\xC5\x5A\x53\x21\xCE\xBB\x5A" \
|
||||
b"\xAB\x2B\x70\xA3\x98\x64\x5C\x26" \
|
||||
b"\xDC\xA2\xB2\xFC\xB4\x3F\xC5\x18"
|
||||
encoded = z85.encode(client_public)
|
||||
|
||||
self.assertEqual(encoded, b"Yne@$w-vo<fVvi]a<NY6T1ed:M$fCG*[IaLV{hID")
|
||||
decoded = z85.decode(encoded)
|
||||
self.assertEqual(decoded, client_public)
|
||||
|
||||
def test_client_secret(self):
|
||||
client_secret = \
|
||||
b"\x7B\xB8\x64\xB4\x89\xAF\xA3\x67" \
|
||||
b"\x1F\xBE\x69\x10\x1F\x94\xB3\x89" \
|
||||
b"\x72\xF2\x48\x16\xDF\xB0\x1B\x51" \
|
||||
b"\x65\x6B\x3F\xEC\x8D\xFD\x08\x88"
|
||||
encoded = z85.encode(client_secret)
|
||||
|
||||
self.assertEqual(encoded, b"D:)Q[IlAW!ahhC2ac:9*A}h:p?([4%wOTJ%JR%cs")
|
||||
decoded = z85.decode(encoded)
|
||||
self.assertEqual(decoded, client_secret)
|
||||
|
||||
def test_server_public(self):
|
||||
server_public = \
|
||||
b"\x54\xFC\xBA\x24\xE9\x32\x49\x96" \
|
||||
b"\x93\x16\xFB\x61\x7C\x87\x2B\xB0" \
|
||||
b"\xC1\xD1\xFF\x14\x80\x04\x27\xC5" \
|
||||
b"\x94\xCB\xFA\xCF\x1B\xC2\xD6\x52"
|
||||
encoded = z85.encode(server_public)
|
||||
|
||||
self.assertEqual(encoded, b"rq:rM>}U?@Lns47E1%kR.o@n%FcmmsL/@{H8]yf7")
|
||||
decoded = z85.decode(encoded)
|
||||
self.assertEqual(decoded, server_public)
|
||||
|
||||
def test_server_secret(self):
|
||||
server_secret = \
|
||||
b"\x8E\x0B\xDD\x69\x76\x28\xB9\x1D" \
|
||||
b"\x8F\x24\x55\x87\xEE\x95\xC5\xB0" \
|
||||
b"\x4D\x48\x96\x3F\x79\x25\x98\x77" \
|
||||
b"\xB4\x9C\xD9\x06\x3A\xEA\xD3\xB7"
|
||||
encoded = z85.encode(server_secret)
|
||||
|
||||
self.assertEqual(encoded, b"JTKVSB%%)wK0E.X)V>+}o?pNmC{O&4W4b!Ni{Lh6")
|
||||
decoded = z85.decode(encoded)
|
||||
self.assertEqual(decoded, server_secret)
|
||||
|
79
venv/Lib/site-packages/zmq/tests/test_zmqstream.py
Normal file
79
venv/Lib/site-packages/zmq/tests/test_zmqstream.py
Normal file
|
@ -0,0 +1,79 @@
|
|||
# -*- coding: utf8 -*-
|
||||
# Copyright (C) PyZMQ Developers
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
try:
|
||||
import asyncio
|
||||
except ImportError:
|
||||
asyncio = None
|
||||
|
||||
from unittest import TestCase
|
||||
|
||||
import pytest
|
||||
import zmq
|
||||
try:
|
||||
import tornado
|
||||
from tornado import gen
|
||||
from zmq.eventloop import ioloop, zmqstream
|
||||
except ImportError:
|
||||
tornado = None
|
||||
|
||||
class TestZMQStream(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
if tornado is None:
|
||||
pytest.skip()
|
||||
if asyncio:
|
||||
asyncio.set_event_loop(asyncio.new_event_loop())
|
||||
self.context = zmq.Context()
|
||||
self.loop = ioloop.IOLoop()
|
||||
self.loop.make_current()
|
||||
self.push = zmqstream.ZMQStream(self.context.socket(zmq.PUSH))
|
||||
self.pull = zmqstream.ZMQStream(self.context.socket(zmq.PULL))
|
||||
port = self.push.bind_to_random_port('tcp://127.0.0.1')
|
||||
self.pull.connect('tcp://127.0.0.1:%i' % port)
|
||||
self.stream = self.push
|
||||
|
||||
def tearDown(self):
|
||||
self.loop.close(all_fds=True)
|
||||
self.context.term()
|
||||
ioloop.IOLoop.clear_current()
|
||||
|
||||
def run_until_timeout(self, timeout=10):
|
||||
timed_out = []
|
||||
@gen.coroutine
|
||||
def sleep_timeout():
|
||||
yield gen.sleep(timeout)
|
||||
timed_out[:] = ['timed out']
|
||||
self.loop.stop()
|
||||
self.loop.add_callback(lambda : sleep_timeout())
|
||||
self.loop.start()
|
||||
assert not timed_out
|
||||
|
||||
def test_callable_check(self):
|
||||
"""Ensure callable check works (py3k)."""
|
||||
|
||||
self.stream.on_send(lambda *args: None)
|
||||
self.stream.on_recv(lambda *args: None)
|
||||
self.assertRaises(AssertionError, self.stream.on_recv, 1)
|
||||
self.assertRaises(AssertionError, self.stream.on_send, 1)
|
||||
self.assertRaises(AssertionError, self.stream.on_recv, zmq)
|
||||
|
||||
def test_on_recv_basic(self):
|
||||
sent = [b'basic']
|
||||
def callback(msg):
|
||||
assert msg == sent
|
||||
self.loop.stop()
|
||||
self.loop.add_callback(lambda : self.push.send_multipart(sent))
|
||||
self.pull.on_recv(callback)
|
||||
self.run_until_timeout()
|
||||
|
||||
def test_on_recv_wake(self):
|
||||
sent = [b'wake']
|
||||
def callback(msg):
|
||||
assert msg == sent
|
||||
self.loop.stop()
|
||||
self.pull.on_recv(callback)
|
||||
self.loop.call_later(1, lambda : self.push.send_multipart(sent))
|
||||
self.run_until_timeout()
|
Loading…
Add table
Add a link
Reference in a new issue