557 lines
20 KiB
Python
557 lines
20 KiB
Python
# -*- coding: utf8 -*-
|
|
|
|
# Copyright (C) PyZMQ Developers
|
|
# Distributed under the terms of the Modified BSD License.
|
|
|
|
import logging
|
|
import os
|
|
import shutil
|
|
import sys
|
|
import tempfile
|
|
|
|
import pytest
|
|
|
|
import zmq.auth
|
|
from zmq.auth.thread import ThreadAuthenticator
|
|
|
|
from zmq.utils.strtypes import u
|
|
from zmq.tests import BaseZMQTestCase, SkipTest, skip_pypy
|
|
|
|
|
|
class BaseAuthTestCase(BaseZMQTestCase):
    """Shared fixture for the authentication tests.

    ``setUp`` skips the test when libzmq is older than 4.0 or has no CURVE
    support, starts the authenticator returned by :meth:`make_auth`, and
    creates a fresh set of CURVE certificates in a temporary directory.
    Subclasses must implement :meth:`make_auth`.
    """

    def setUp(self):
        """Skip if security is unavailable, then start auth and create certs."""
        if zmq.zmq_version_info() < (4, 0):
            raise SkipTest("security is new in libzmq 4.0")
        try:
            zmq.curve_keypair()
        except zmq.ZMQError:
            raise SkipTest("security requires libzmq to have curve support")
        super(BaseAuthTestCase, self).setUp()
        # enable debug logging while we run tests
        logging.getLogger('zmq.auth').setLevel(logging.DEBUG)
        self.auth = self.make_auth()
        self.auth.start()
        try:
            self.base_dir, self.public_keys_dir, self.secret_keys_dir = self.create_certs()
        except Exception:
            # tearDown is not run after a failed setUp, so the authenticator
            # has to be stopped here or it would outlive this test.
            self.auth.stop()
            self.auth = None
            raise

    def make_auth(self):
        """Return the authenticator under test; subclasses must override."""
        raise NotImplementedError()

    def tearDown(self):
        """Stop the authenticator, delete the certificates, tear down the base.

        Each step runs even if the previous one raised, so a failing
        ``stop()`` cannot leave the temporary directory behind.
        """
        try:
            if self.auth:
                self.auth.stop()
                self.auth = None
        finally:
            try:
                self.remove_certs(self.base_dir)
            finally:
                super(BaseAuthTestCase, self).tearDown()

    def create_certs(self):
        """Create CURVE certificates for a test

        Returns a ``(base_dir, public_keys_dir, secret_keys_dir)`` tuple.
        ``base_dir`` is a new temporary directory that contains the other two.
        """
        # Create temporary CURVE keypairs for this test run. We create all keys in a
        # temp directory and then move them into the appropriate private or public
        # directory.
        base_dir = tempfile.mkdtemp()
        try:
            keys_dir = os.path.join(base_dir, 'certificates')
            public_keys_dir = os.path.join(base_dir, 'public_keys')
            secret_keys_dir = os.path.join(base_dir, 'private_keys')
            for directory in (keys_dir, public_keys_dir, secret_keys_dir):
                os.mkdir(directory)

            for name in ("server", "client"):
                zmq.auth.create_certificates(keys_dir, name)

            # Sort every generated file into its directory in a single pass.
            for key_file in os.listdir(keys_dir):
                if key_file.endswith(".key"):
                    destination = public_keys_dir
                elif key_file.endswith(".key_secret"):
                    destination = secret_keys_dir
                else:
                    continue
                # moving onto an existing directory places the file inside it
                shutil.move(os.path.join(keys_dir, key_file), destination)
        except Exception:
            # do not leave a half-populated temp directory behind
            shutil.rmtree(base_dir, ignore_errors=True)
            raise

        return (base_dir, public_keys_dir, secret_keys_dir)

    def remove_certs(self, base_dir):
        """Remove certificates for a test"""
        shutil.rmtree(base_dir)

    def load_certs(self, secret_keys_dir):
        """Return server and client certificate keys

        The result is the 4-tuple
        ``(server_public, server_secret, client_public, client_secret)``
        read from ``server.key_secret`` and ``client.key_secret`` in
        *secret_keys_dir*.
        """
        server_secret_file = os.path.join(secret_keys_dir, "server.key_secret")
        client_secret_file = os.path.join(secret_keys_dir, "client.key_secret")

        server_public, server_secret = zmq.auth.load_certificate(server_secret_file)
        client_public, client_secret = zmq.auth.load_certificate(client_secret_file)

        return server_public, server_secret, client_public, client_secret
|
|
|
|
|
|
class TestThreadAuthentication(BaseAuthTestCase):
    """Test authentication running in a thread"""

    def make_auth(self):
        """Return a ThreadAuthenticator bound to the test context."""
        return ThreadAuthenticator(self.context)

    def can_connect(self, server, client):
        """Check if client can connect to server using tcp transport

        Binds *server* to a random port on 127.0.0.1, connects *client* to
        it and sends one message from *server* to *client*. Returns True
        only if the message arrives; returns False when either poll times
        out.
        """
        iface = 'tcp://127.0.0.1'
        port = server.bind_to_random_port(iface)
        client.connect("%s:%i" % (iface, port))
        msg = [b"Hello World"]
        if not server.poll(1000, zmq.POLLOUT):
            return False
        server.send_multipart(msg)
        if not client.poll(1000):
            return False
        self.assertEqual(client.recv_multipart(), msg)
        return True

    # -- socket factories shared by the tests below ---------------------

    def _null_pair(self, zap_domain=None):
        """Return a PUSH/PULL pair; set *zap_domain* on the server if given."""
        server = self.socket(zmq.PUSH)
        if zap_domain is not None:
            server.zap_domain = zap_domain
        client = self.socket(zmq.PULL)
        return server, client

    def _plain_pair(self, password):
        """Return a PLAIN server and a client logging in as admin/*password*."""
        server = self.socket(zmq.PUSH)
        server.plain_server = True
        client = self.socket(zmq.PULL)
        client.plain_username = b'admin'
        client.plain_password = password
        return server, client

    def _curve_server(self, certs, socket_type):
        """Return a CURVE server socket keyed with the server part of *certs*."""
        server_public, server_secret = certs[0], certs[1]
        server = self.socket(socket_type)
        server.curve_publickey = server_public
        server.curve_secretkey = server_secret
        server.curve_server = True
        return server

    def _curve_client(self, certs, socket_type):
        """Return a CURVE client socket keyed with the client part of *certs*."""
        server_public, client_public, client_secret = certs[0], certs[2], certs[3]
        client = self.socket(socket_type)
        client.curve_publickey = client_public
        client.curve_secretkey = client_secret
        client.curve_serverkey = server_public
        return client

    def _curve_pair(self, certs, server_type=zmq.PUSH, client_type=zmq.PULL):
        """Return a (server, client) CURVE pair; the server is created first."""
        server = self._curve_server(certs, server_type)
        client = self._curve_client(certs, client_type)
        return server, client

    def _check_user_id(self, msg, expected):
        """Assert the 'User-Id' property of *msg* when it can be read."""
        try:
            user_id = msg.get('User-Id')
        except zmq.ZMQVersionError:
            # NOTE(review): presumably the installed libzmq cannot expose
            # message properties; there is nothing to verify in that case.
            pass
        else:
            assert user_id == expected

    # -- tests -----------------------------------------------------------

    def test_null(self):
        """threaded auth - NULL"""
        # A default NULL connection should always succeed, and not
        # go through our authentication infrastructure at all.
        self.auth.stop()
        self.auth = None
        # use a new context, so ZAP isn't inherited
        self.context = self.Context()

        server, client = self._null_pair()
        self.assertTrue(self.can_connect(server, client))

        # By setting a domain we switch on authentication for NULL sockets,
        # though no policies are configured yet. The client connection
        # should still be allowed.
        server, client = self._null_pair(zap_domain=b'global')
        self.assertTrue(self.can_connect(server, client))

    def test_blacklist(self):
        """threaded auth - Blacklist"""
        # Blacklist 127.0.0.1, connection should fail
        self.auth.deny('127.0.0.1')
        # By setting a domain we switch on authentication for NULL sockets,
        # though no policies are configured yet.
        server, client = self._null_pair(zap_domain=b'global')
        self.assertFalse(self.can_connect(server, client))

    def test_whitelist(self):
        """threaded auth - Whitelist"""
        # Whitelist 127.0.0.1, connection should pass
        self.auth.allow('127.0.0.1')
        # By setting a domain we switch on authentication for NULL sockets,
        # though no policies are configured yet.
        server, client = self._null_pair(zap_domain=b'global')
        self.assertTrue(self.can_connect(server, client))

    def test_plain(self):
        """threaded auth - PLAIN"""

        # Try PLAIN authentication - without configuring server, connection should fail
        server, client = self._plain_pair(b'Password')
        self.assertFalse(self.can_connect(server, client))

        # Try PLAIN authentication - with server configured, connection should pass
        server, client = self._plain_pair(b'Password')
        self.auth.configure_plain(domain='*', passwords={'admin': 'Password'})
        self.assertTrue(self.can_connect(server, client))

        # Try PLAIN authentication - with bogus credentials, connection should fail
        server, client = self._plain_pair(b'Bogus')
        self.assertFalse(self.can_connect(server, client))

        # Remove authenticator and check that a normal connection works
        self.auth.stop()
        self.auth = None

        server, client = self._null_pair()
        self.assertTrue(self.can_connect(server, client))
        client.close()
        server.close()

    def test_curve(self):
        """threaded auth - CURVE"""
        self.auth.allow('127.0.0.1')
        certs = self.load_certs(self.secret_keys_dir)

        # Try CURVE authentication - without configuring server, connection should fail
        server, client = self._curve_pair(certs)
        self.assertFalse(self.can_connect(server, client))

        # Try CURVE authentication - with server configured to CURVE_ALLOW_ANY, connection should pass
        self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
        server, client = self._curve_pair(certs)
        self.assertTrue(self.can_connect(server, client))

        # Try CURVE authentication - with server configured, connection should pass
        # (roles are swapped here: the CURVE client is the PUSH side that binds and sends)
        self.auth.configure_curve(domain='*', location=self.public_keys_dir)
        server, client = self._curve_pair(certs, server_type=zmq.PULL, client_type=zmq.PUSH)
        assert self.can_connect(client, server)

        # Remove authenticator and check that a normal connection works
        self.auth.stop()
        self.auth = None

        # Try connecting using NULL and no authentication enabled, connection should pass
        server, client = self._null_pair()
        self.assertTrue(self.can_connect(server, client))

    def test_curve_callback(self):
        """threaded auth - CURVE with callback authentication"""
        self.auth.allow('127.0.0.1')
        certs = self.load_certs(self.secret_keys_dir)
        client_public = certs[2]

        # Try CURVE authentication - without configuring server, connection should fail
        server, client = self._curve_pair(certs)
        self.assertFalse(self.can_connect(server, client))

        class CredentialsProvider(object):
            """Accept exactly the key handed to the constructor."""

            def __init__(self, accepted_key):
                self.client = accepted_key

            def callback(self, domain, key):
                return key == self.client

        # Try CURVE authentication - with callback authentication configured, connection should pass
        provider = CredentialsProvider(client_public)
        self.auth.configure_curve_callback(credentials_provider=provider)
        server, client = self._curve_pair(certs)
        self.assertTrue(self.can_connect(server, client))

        # Try CURVE authentication - with callback authentication configured with wrong key, connection should not pass
        provider = CredentialsProvider("WrongCredentials")
        self.auth.configure_curve_callback(credentials_provider=provider)
        server, client = self._curve_pair(certs)
        self.assertFalse(self.can_connect(server, client))

    @skip_pypy
    def test_curve_user_id(self):
        """threaded auth - CURVE user-id"""
        self.auth.allow('127.0.0.1')
        certs = self.load_certs(self.secret_keys_dir)
        client_public = certs[2]

        self.auth.configure_curve(domain='*', location=self.public_keys_dir)
        server, client = self._curve_pair(certs, server_type=zmq.PULL, client_type=zmq.PUSH)
        assert self.can_connect(client, server)

        # test default user-id map
        client.send(b'test')
        msg = self.recv(server, copy=False)
        assert msg.bytes == b'test'
        self._check_user_id(msg, u(client_public))

        # test custom user-id map
        self.auth.curve_user_id = lambda client_key: u'custom'

        client2 = self._curve_client(certs, zmq.PUSH)
        assert self.can_connect(client2, server)

        client2.send(b'test2')
        msg = self.recv(server, copy=False)
        assert msg.bytes == b'test2'
        self._check_user_id(msg, u'custom')
|
|
|
|
|
|
def with_ioloop(method, expect_success=True):
    """decorator for running tests with an IOLoop

    The wrapped test first runs *method* (which configures the sockets and
    the authenticator), then schedules a connection attempt and a message
    send on ``self.io_loop`` and runs the loop until one of the test's
    callbacks stops it.

    With ``expect_success`` true, a received message goes to
    ``on_message_succeed`` and the timeout goes to ``on_test_timeout_fail``.
    Otherwise a received message goes to ``on_message_fail`` and the timeout
    goes to ``on_test_timeout_succeed``.

    After the loop returns, a non-empty ``self.fail_msg`` fails the test.
    The wrapper returns whatever *method* returned.
    """
    # Local import, like the other optional imports in this module.
    import functools

    # wraps() keeps the test's own name and docstring visible to the runner
    # instead of reporting every decorated test as "test_method".
    @functools.wraps(method)
    def test_method(self):
        r = method(self)

        loop = self.io_loop
        if expect_success:
            on_recv = self.on_message_succeed
            on_timeout = self.on_test_timeout_fail
        else:
            on_recv = self.on_message_fail
            on_timeout = self.on_test_timeout_succeed
        self.pullstream.on_recv(on_recv)

        # connect first, send shortly after, give up one second after connecting
        loop.call_later(1, self.attempt_connection)
        loop.call_later(1.2, self.send_msg)
        loop.call_later(2, on_timeout)

        loop.start()
        if self.fail_msg:
            self.fail(self.fail_msg)

        return r
    return test_method
|
|
|
|
def should_auth(method):
    """Decorate an IOLoop test whose message is expected to be delivered."""
    return with_ioloop(method, expect_success=True)
|
|
|
|
def should_not_auth(method):
    """Decorate an IOLoop test whose message is expected NOT to be delivered."""
    return with_ioloop(method, expect_success=False)
|
|
|
|
class TestIOLoopAuthentication(BaseAuthTestCase):
    """Test authentication running in ioloop

    Each ``test_*`` method only configures ``self.server``, ``self.client``
    and ``self.auth``. The ``should_auth`` / ``should_not_auth`` decorators
    then drive the connection attempt on ``self.io_loop`` and report the
    outcome through ``self.fail_msg``.
    """

    def setUp(self):
        """Create the IOLoop, run the base setup, then build the streams."""
        try:
            from tornado import ioloop
        except ImportError:
            pytest.skip("Requires tornado")
        from zmq.eventloop import zmqstream
        self.fail_msg = None
        # must exist before the base setUp, which calls make_auth()
        self.io_loop = ioloop.IOLoop()
        try:
            super(TestIOLoopAuthentication, self).setUp()
        except BaseException:
            # A skip or failure in the base setUp means tearDown never
            # runs, so close the loop here rather than leak it.
            self.io_loop.close(all_fds=True)
            raise
        self.server = self.socket(zmq.PUSH)
        self.client = self.socket(zmq.PULL)
        self.pushstream = zmqstream.ZMQStream(self.server, self.io_loop)
        self.pullstream = zmqstream.ZMQStream(self.client, self.io_loop)

    def make_auth(self):
        """Return an IOLoopAuthenticator running on this test's loop."""
        from zmq.auth.ioloop import IOLoopAuthenticator
        return IOLoopAuthenticator(self.context, io_loop=self.io_loop)

    def tearDown(self):
        """Stop auth, close the loop, run the base teardown.

        Every step runs even if an earlier one raised.
        """
        try:
            if self.auth:
                self.auth.stop()
                self.auth = None
        finally:
            try:
                self.io_loop.close(all_fds=True)
            finally:
                super(TestIOLoopAuthentication, self).tearDown()

    def attempt_connection(self):
        """Check if client can connect to server using tcp transport"""
        iface = 'tcp://127.0.0.1'
        port = self.server.bind_to_random_port(iface)
        self.client.connect("%s:%i" % (iface, port))

    def send_msg(self):
        """Send a message from server to a client"""
        msg = [b"Hello World"]
        self.pushstream.send_multipart(msg)

    def on_message_succeed(self, frames):
        """A message was received, as expected."""
        if frames != [b"Hello World"]:
            self.fail_msg = "Unexpected message received"
        self.io_loop.stop()

    def on_message_fail(self, frames):
        """A message was received, unexpectedly."""
        self.fail_msg = 'Received messaged unexpectedly, security failed'
        self.io_loop.stop()

    def on_test_timeout_succeed(self):
        """Test timer expired, indicates test success"""
        self.io_loop.stop()

    def on_test_timeout_fail(self):
        """Test timer expired, indicates test failure"""
        self.fail_msg = 'Test timed out'
        self.io_loop.stop()

    def _apply_curve_keys(self):
        """Load the test certificates and key both sockets for CURVE."""
        certs = self.load_certs(self.secret_keys_dir)
        server_public, server_secret, client_public, client_secret = certs

        self.server.curve_publickey = server_public
        self.server.curve_secretkey = server_secret
        self.server.curve_server = True

        self.client.curve_publickey = client_public
        self.client.curve_secretkey = client_secret
        self.client.curve_serverkey = server_public

    @should_auth
    def test_none(self):
        """ioloop auth - NONE"""
        # A default NULL connection should always succeed, and not
        # go through our authentication infrastructure at all.
        # no auth should be running
        self.auth.stop()
        self.auth = None

    @should_auth
    def test_null(self):
        """ioloop auth - NULL"""
        # By setting a domain we switch on authentication for NULL sockets,
        # though no policies are configured yet. The client connection
        # should still be allowed.
        self.server.zap_domain = b'global'

    @should_not_auth
    def test_blacklist(self):
        """ioloop auth - Blacklist"""
        # Blacklist 127.0.0.1, connection should fail
        self.auth.deny('127.0.0.1')
        self.server.zap_domain = b'global'

    @should_auth
    def test_whitelist(self):
        """ioloop auth - Whitelist"""
        # Whitelist 127.0.0.1, connection should pass
        self.auth.allow('127.0.0.1')

        self.server.setsockopt(zmq.ZAP_DOMAIN, b'global')

    @should_not_auth
    def test_plain_unconfigured_server(self):
        """ioloop auth - PLAIN, unconfigured server"""
        self.client.plain_username = b'admin'
        self.client.plain_password = b'Password'
        # Try PLAIN authentication - without configuring server, connection should fail
        self.server.plain_server = True

    @should_auth
    def test_plain_configured_server(self):
        """ioloop auth - PLAIN, configured server"""
        self.client.plain_username = b'admin'
        self.client.plain_password = b'Password'
        # Try PLAIN authentication - with server configured, connection should pass
        self.server.plain_server = True
        self.auth.configure_plain(domain='*', passwords={'admin': 'Password'})

    @should_not_auth
    def test_plain_bogus_credentials(self):
        """ioloop auth - PLAIN, bogus credentials"""
        self.client.plain_username = b'admin'
        self.client.plain_password = b'Bogus'
        self.server.plain_server = True

        self.auth.configure_plain(domain='*', passwords={'admin': 'Password'})

    @should_not_auth
    def test_curve_unconfigured_server(self):
        """ioloop auth - CURVE, unconfigured server"""
        self.auth.allow('127.0.0.1')
        # no configure_curve() call: the authenticator knows no client keys
        self._apply_curve_keys()

    @should_auth
    def test_curve_allow_any(self):
        """ioloop auth - CURVE, CURVE_ALLOW_ANY"""
        self.auth.allow('127.0.0.1')
        self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
        self._apply_curve_keys()

    @should_auth
    def test_curve_configured_server(self):
        """ioloop auth - CURVE, configured server"""
        self.auth.allow('127.0.0.1')
        self.auth.configure_curve(domain='*', location=self.public_keys_dir)
        self._apply_curve_keys()
|