236 lines
8 KiB
Python
236 lines
8 KiB
Python
"""Test libzmq security (libzmq >= 3.3.0)"""
|
|
# -*- coding: utf8 -*-
|
|
|
|
# Copyright (C) PyZMQ Developers
|
|
# Distributed under the terms of the Modified BSD License.
|
|
|
|
import os
|
|
import contextlib
|
|
import time
|
|
from threading import Thread
|
|
|
|
import zmq
|
|
from zmq.tests import (
|
|
BaseZMQTestCase, SkipTest, PYPY
|
|
)
|
|
from zmq.utils import z85
|
|
|
|
|
|
# Credentials that TestSecurity.zap_handler accepts for the PLAIN mechanism.
USER = b"admin"
PASS = b"password"
|
|
|
|
class TestSecurity(BaseZMQTestCase):
    """Exercise the NULL, PLAIN and CURVE security mechanisms on DEALER pairs.

    Tests that need an authenticator run a one-shot ZAP handler
    (``zap_handler``) in a background thread, via the ``zap()`` context
    manager, while the sockets connect and exchange a message.
    """

    def setUp(self):
        """Skip unless libzmq is >= 4.0 and CURVE key generation works."""
        if zmq.zmq_version_info() < (4, 0):
            raise SkipTest("security is new in libzmq 4.0")
        try:
            # Probe only: the generated keys are discarded.
            zmq.curve_keypair()
        except zmq.ZMQError:
            raise SkipTest("security requires libzmq to be built with CURVE support")
        super(TestSecurity, self).setUp()

    def zap_handler(self):
        """Answer exactly one ZAP request on ``inproc://zeromq.zap.01``.

        CURVE and NULL requests are always accepted; PLAIN requests are
        accepted only when the credentials equal USER/PASS.  Everything else
        gets a 400 reply.  The REP socket is closed after the first reply, or
        as soon as anything in the handler raises.

        NOTE: this runs as a ``Thread`` target, so a failing assertion here
        is raised in the handler thread and is not propagated to the thread
        running the test.
        """
        socket = self.context.socket(zmq.REP)
        socket.bind("inproc://zeromq.zap.01")
        try:
            msg = self.recv_multipart(socket)

            # The first six frames are common to every mechanism; the
            # remaining frames carry mechanism-specific credentials.
            version, sequence, domain, address, identity, mechanism = msg[:6]
            if mechanism == b'PLAIN':
                username, password = msg[6:]
                authorized = username == USER and password == PASS
            elif mechanism == b'CURVE':
                # Credential frame is read but not checked: any CURVE
                # client is accepted.
                key = msg[6]
                authorized = True
            else:
                authorized = mechanism == b'NULL'

            self.assertEqual(version, b"1.0")
            self.assertEqual(identity, b"IDENT")

            reply = [version, sequence]
            if authorized:
                reply.extend([
                    b"200",
                    b"OK",
                    # Checked by bounce() as the 'User-Id' frame property.
                    b"anonymous",
                    # Length-prefixed metadata; bounce() reads it back as
                    # the frame property 'Hello' == 'World'.
                    b"\5Hello\0\0\0\5World",
                ])
            else:
                reply.extend([
                    b"400",
                    b"Invalid username or password",
                    b"",
                    b"",
                ])
            socket.send_multipart(reply)
        finally:
            socket.close()

    @contextlib.contextmanager
    def zap(self):
        """Run the ZAP handler thread for the duration of the ``with`` body.

        The handler thread is always joined on exit, even if the body raises.
        """
        self.start_zap()
        time.sleep(0.5)  # allow time for the Thread to start
        try:
            yield
        finally:
            self.stop_zap()

    def start_zap(self):
        """Start ``zap_handler`` in a new thread stored on ``self.zap_thread``."""
        self.zap_thread = Thread(target=self.zap_handler)
        self.zap_thread.start()

    def stop_zap(self):
        """Block until the ZAP handler thread has finished."""
        self.zap_thread.join()

    def bounce(self, server, client, test_metadata=True):
        """Round-trip a random two-frame message client -> server -> client.

        Parameters
        ----------
        server, client : connected sockets to bounce the message between.
        test_metadata : when true (and not running on PyPy), also check the
            per-frame properties that the ZAP handler's reply should produce.
        """
        msg = [os.urandom(64), os.urandom(64)]
        client.send_multipart(msg)
        # copy=False yields frame objects, needed for the metadata lookups.
        frames = self.recv_multipart(server, copy=False)
        recvd = [frame.bytes for frame in frames]

        try:
            if test_metadata and not PYPY:
                for frame in frames:
                    self.assertEqual(frame.get('User-Id'), 'anonymous')
                    self.assertEqual(frame.get('Hello'), 'World')
                    self.assertEqual(frame['Socket-Type'], 'DEALER')
        except zmq.ZMQVersionError:
            # Deliberate best-effort: presumably raised when the linked
            # libzmq is too old to expose frame properties -- the payload
            # checks below still run.
            pass

        self.assertEqual(recvd, msg)
        server.send_multipart(recvd)
        msg2 = self.recv_multipart(client)
        self.assertEqual(msg2, msg)

    def test_null(self):
        """test NULL (default) security"""
        server = self.socket(zmq.DEALER)
        client = self.socket(zmq.DEALER)
        self.assertEqual(client.MECHANISM, zmq.NULL)
        self.assertEqual(server.mechanism, zmq.NULL)
        self.assertEqual(client.plain_server, 0)
        self.assertEqual(server.plain_server, 0)
        iface = 'tcp://127.0.0.1'
        port = server.bind_to_random_port(iface)
        client.connect("%s:%i" % (iface, port))
        # No ZAP handler is running here, so there is no metadata to check.
        self.bounce(server, client, False)

    def test_plain(self):
        """test PLAIN authentication"""
        server = self.socket(zmq.DEALER)
        server.identity = b'IDENT'
        client = self.socket(zmq.DEALER)
        self.assertEqual(client.plain_username, b'')
        self.assertEqual(client.plain_password, b'')
        client.plain_username = USER
        client.plain_password = PASS
        self.assertEqual(client.getsockopt(zmq.PLAIN_USERNAME), USER)
        self.assertEqual(client.getsockopt(zmq.PLAIN_PASSWORD), PASS)
        self.assertEqual(client.plain_server, 0)
        self.assertEqual(server.plain_server, 0)
        server.plain_server = True
        self.assertEqual(server.mechanism, zmq.PLAIN)
        self.assertEqual(client.mechanism, zmq.PLAIN)

        # unittest assertions instead of bare ``assert`` so the checks are
        # not stripped when running under ``python -O``.
        self.assertFalse(client.plain_server)
        self.assertTrue(server.plain_server)

        with self.zap():
            iface = 'tcp://127.0.0.1'
            port = server.bind_to_random_port(iface)
            client.connect("%s:%i" % (iface, port))
            self.bounce(server, client)

    def skip_plain_inauth(self):
        """test PLAIN failed authentication"""
        # The name does not start with ``test``; presumably this keeps the
        # case out of test collection on purpose -- confirm before renaming.
        server = self.socket(zmq.DEALER)
        server.identity = b'IDENT'
        client = self.socket(zmq.DEALER)
        # NOTE(review): test_null/test_plain rely on self.socket() alone;
        # this explicit registration may be redundant -- confirm against
        # BaseZMQTestCase.
        self.sockets.extend([server, client])
        client.plain_username = USER
        client.plain_password = b'incorrect'
        server.plain_server = True
        self.assertEqual(server.mechanism, zmq.PLAIN)
        self.assertEqual(client.mechanism, zmq.PLAIN)

        with self.zap():
            iface = 'tcp://127.0.0.1'
            port = server.bind_to_random_port(iface)
            client.connect("%s:%i" % (iface, port))
            client.send(b'ping')
            # The rejected client's message must never arrive: expect the
            # receive to time out rather than return data.
            server.rcvtimeo = 250
            self.assertRaisesErrno(zmq.EAGAIN, server.recv)

    def test_keypair(self):
        """test curve_keypair"""
        try:
            public, secret = zmq.curve_keypair()
        except zmq.ZMQError:
            raise SkipTest("CURVE unsupported")

        self.assertEqual(type(secret), bytes)
        self.assertEqual(type(public), bytes)
        self.assertEqual(len(secret), 40)
        self.assertEqual(len(public), 40)

        # verify that it is indeed Z85
        # (names now match the order of the tuple being decoded)
        bpublic, bsecret = [z85.decode(key) for key in (public, secret)]
        self.assertEqual(type(bsecret), bytes)
        self.assertEqual(type(bpublic), bytes)
        self.assertEqual(len(bsecret), 32)
        self.assertEqual(len(bpublic), 32)

    def test_curve_public(self):
        """test curve_public"""
        try:
            public, secret = zmq.curve_keypair()
        except zmq.ZMQError:
            raise SkipTest("CURVE unsupported")
        if zmq.zmq_version_info() < (4, 2):
            raise SkipTest("curve_public is new in libzmq 4.2")

        derived_public = zmq.curve_public(secret)

        self.assertEqual(type(derived_public), bytes)
        self.assertEqual(len(derived_public), 40)

        # verify that it is indeed Z85
        bpublic = z85.decode(derived_public)
        self.assertEqual(type(bpublic), bytes)
        self.assertEqual(len(bpublic), 32)

        # verify that it is equal to the known public key
        self.assertEqual(derived_public, public)

    def test_curve(self):
        """test CURVE encryption"""
        server = self.socket(zmq.DEALER)
        server.identity = b'IDENT'
        client = self.socket(zmq.DEALER)
        # NOTE(review): test_null/test_plain rely on self.socket() alone;
        # this explicit registration may be redundant -- confirm against
        # BaseZMQTestCase.
        self.sockets.extend([server, client])
        try:
            server.curve_server = True
        except zmq.ZMQError as e:
            # will raise EINVAL if no CURVE support
            if e.errno == zmq.EINVAL:
                raise SkipTest("CURVE unsupported")
            # Any other failure is a real error: do not continue with a
            # server that was never switched into CURVE mode.
            raise

        server_public, server_secret = zmq.curve_keypair()
        client_public, client_secret = zmq.curve_keypair()

        server.curve_secretkey = server_secret
        server.curve_publickey = server_public
        client.curve_serverkey = server_public
        client.curve_publickey = client_public
        client.curve_secretkey = client_secret

        self.assertEqual(server.mechanism, zmq.CURVE)
        self.assertEqual(client.mechanism, zmq.CURVE)

        self.assertEqual(server.get(zmq.CURVE_SERVER), True)
        self.assertEqual(client.get(zmq.CURVE_SERVER), False)

        with self.zap():
            iface = 'tcp://127.0.0.1'
            port = server.bind_to_random_port(iface)
            client.connect("%s:%i" % (iface, port))
            self.bounce(server, client)
|