Uploaded Test files
This commit is contained in:
parent
f584ad9d97
commit
2e81cb7d99
16627 changed files with 2065359 additions and 102444 deletions
972
venv/Lib/site-packages/jupyter_client/session.py
Normal file
972
venv/Lib/site-packages/jupyter_client/session.py
Normal file
|
@ -0,0 +1,972 @@
|
|||
"""Session object for building, serializing, sending, and receiving messages.
|
||||
|
||||
The Session object supports serialization, HMAC signatures,
|
||||
and metadata on messages.
|
||||
|
||||
Also defined here are utilities for working with Sessions:
|
||||
* A SessionFactory to be used as a base class for configurables that work with
|
||||
Sessions.
|
||||
* A Message object for convenience that allows attribute-access to the msg dict.
|
||||
"""
|
||||
|
||||
# Copyright (c) Jupyter Development Team.
|
||||
# Distributed under the terms of the Modified BSD License.
|
||||
|
||||
from binascii import b2a_hex
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import pprint
|
||||
import random
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
|
||||
PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL
|
||||
|
||||
# We are using compare_digest to limit the surface of timing attacks
|
||||
from hmac import compare_digest
|
||||
|
||||
from datetime import timezone
|
||||
utc = timezone.utc
|
||||
|
||||
import zmq
|
||||
from zmq.utils import jsonapi
|
||||
from zmq.eventloop.ioloop import IOLoop
|
||||
from zmq.eventloop.zmqstream import ZMQStream
|
||||
|
||||
from traitlets.config.configurable import Configurable, LoggingConfigurable
|
||||
from ipython_genutils.importstring import import_item
|
||||
from jupyter_client.jsonutil import extract_dates, squash_dates, date_default
|
||||
from ipython_genutils.py3compat import str_to_bytes, str_to_unicode
|
||||
from traitlets import (
|
||||
CBytes, Unicode, Bool, Any, Instance, Set, DottedObjectName, CUnicode,
|
||||
Dict, Integer, TraitError, observe
|
||||
)
|
||||
from jupyter_client import protocol_version
|
||||
from jupyter_client.adapter import adapt
|
||||
from traitlets.log import get_logger
|
||||
|
||||
|
||||
#-----------------------------------------------------------------------------
|
||||
# utility functions
|
||||
#-----------------------------------------------------------------------------
|
||||
|
||||
def squash_unicode(obj):
    """Coerce unicode back to bytestrings, recursively.

    dicts have both keys and values squashed (in place); lists are squashed
    element-wise (in place); str becomes utf8-encoded bytes. Any other type
    is returned unchanged.

    Parameters
    ----------
    obj : dict, list, str, or anything
        The object to convert.

    Returns
    -------
    The same object with all contained str coerced to bytes.
    """
    if isinstance(obj, dict):
        # iterate over a snapshot of the keys: the loop body pops str keys
        # and re-inserts them as bytes, which is unsafe over a live view
        for key in list(obj.keys()):
            obj[key] = squash_unicode(obj[key])
            if isinstance(key, str):
                # replace the str key with its bytes equivalent
                obj[squash_unicode(key)] = obj.pop(key)
    elif isinstance(obj, list):
        for i, v in enumerate(obj):
            obj[i] = squash_unicode(v)
    elif isinstance(obj, str):
        obj = obj.encode('utf8')
    return obj
|
||||
|
||||
#-----------------------------------------------------------------------------
# globals and defaults
#-----------------------------------------------------------------------------

# default values for the thresholds:
MAX_ITEMS = 64
MAX_BYTES = 1024

# ISO8601-ify datetime objects
# allow unicode
# disallow nan, because it's not actually valid JSON
json_packer = lambda obj: jsonapi.dumps(obj, default=date_default,
    ensure_ascii=False, allow_nan=False,
)
json_unpacker = lambda s: jsonapi.loads(s)

# pickle packer squashes datetimes to ISO8601 strings before dumping,
# so both serializers produce JSON-compatible structures
pickle_packer = lambda o: pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)
pickle_unpacker = pickle.loads

# JSON is the default wire format for the Jupyter message protocol
default_packer = json_packer
default_unpacker = json_unpacker

# delimiter frame separating zmq routing identities from message frames
DELIM = b"<IDS|MSG>"
# singleton dummy tracker, which will always report as done
DONE = zmq.MessageTracker()
|
||||
|
||||
#-----------------------------------------------------------------------------
|
||||
# Mixin tools for apps that use Sessions
|
||||
#-----------------------------------------------------------------------------
|
||||
|
||||
def new_id():
    """Generate a new random id.

    Avoids problematic runtime import in stdlib uuid on Python 2.

    Returns
    -------

    id string (16 random bytes as hex-encoded text, chunks separated by '-')
    """
    raw = os.urandom(16)
    # hex-encode the two chunks and join them with a dash:
    # 8 hex chars, '-', 24 hex chars
    return raw[:4].hex() + '-' + raw[4:].hex()
|
||||
|
||||
def new_id_bytes():
    """Return new_id as ascii bytes"""
    text = new_id()
    return text.encode('ascii')
|
||||
|
||||
# command-line aliases mapping short option names onto Session traits
session_aliases = dict(
    ident = 'Session.session',
    user = 'Session.username',
    keyfile = 'Session.keyfile',
)
|
||||
|
||||
# command-line flags toggling message authentication on/off
session_flags = {
    'secure' : ({'Session' : { 'key' : new_id_bytes(),
                               'keyfile' : '' }},
        """Use HMAC digests for authentication of messages.
        Setting this flag will generate a new UUID to use as the HMAC key.
        """),
    'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
        """Don't authenticate messages."""),
}
|
||||
|
||||
def default_secure(cfg):
    """Set the default behavior for a config environment to be secure.

    If Session.key/keyfile have not been set, set Session.key to
    a new random UUID.

    .. deprecated:: emits a DeprecationWarning on every call.
    """
    warnings.warn("default_secure is deprecated", DeprecationWarning)
    already_configured = (
        'Session' in cfg
        and ('key' in cfg.Session or 'keyfile' in cfg.Session)
    )
    if already_configured:
        # an explicit key or keyfile is present; leave it alone
        return
    # key/keyfile not specified, generate new UUID:
    cfg.Session.key = new_id_bytes()
|
||||
|
||||
def utcnow():
    """Return timezone-aware UTC timestamp.

    Returns
    -------
    datetime
        The current time with ``tzinfo`` set to UTC.
    """
    # datetime.utcnow() is deprecated (naive result, deprecated in 3.12);
    # datetime.now(timezone.utc) yields the identical aware timestamp
    return datetime.now(timezone.utc)
|
||||
|
||||
#-----------------------------------------------------------------------------
|
||||
# Classes
|
||||
#-----------------------------------------------------------------------------
|
||||
|
||||
class SessionFactory(LoggingConfigurable):
    """The Base class for configurables that have a Session, Context, logger,
    and IOLoop.
    """

    # name passed to logging.getLogger for this object's logger
    logname = Unicode('')

    @observe('logname')
    def _logname_changed(self, change):
        # swap the logger whenever logname is reassigned
        self.log = logging.getLogger(change['new'])

    # not configurable:
    context = Instance('zmq.Context')
    def _context_default(self):
        # each factory gets its own zmq context unless one is supplied
        return zmq.Context()

    session = Instance('jupyter_client.session.Session',
                       allow_none=True)

    loop = Instance('tornado.ioloop.IOLoop')
    def _loop_default(self):
        # bind to the current (thread-local) tornado IOLoop
        return IOLoop.current()

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        if self.session is None:
            # construct the session; unused kwargs are ignored by traitlets
            self.session = Session(**kwargs)
|
||||
|
||||
|
||||
class Message(object):
    """A simple message object that maps dict keys to attributes.

    A Message can be created from a dict and a dict from a Message instance
    simply by calling dict(msg_obj).

    Nested dicts are recursively wrapped, so ``msg.header.msg_id`` works.
    """

    def __init__(self, msg_dict):
        # write straight into __dict__ so keys become attributes
        dct = self.__dict__
        for k, v in dict(msg_dict).items():
            if isinstance(v, dict):
                v = Message(v)
            dct[k] = v

    # Having this iterator lets dict(msg_obj) work out of the box.
    def __iter__(self):
        # __iter__ must return an *iterator*; returning the dict_items view
        # directly makes iter(msg) raise TypeError ("non-iterator returned")
        return iter(self.__dict__.items())

    def __repr__(self):
        return repr(self.__dict__)

    def __str__(self):
        return pprint.pformat(self.__dict__)

    def __contains__(self, k):
        return k in self.__dict__

    def __getitem__(self, k):
        return self.__dict__[k]
|
||||
|
||||
|
||||
def msg_header(msg_id, msg_type, username, session):
    """Create a new message header.

    Returns a dict with the standard Jupyter header fields: the four
    arguments plus a UTC timestamp and the protocol version.
    """
    return {
        'msg_id': msg_id,
        'msg_type': msg_type,
        'username': username,
        'session': session,
        'date': utcnow(),
        'version': protocol_version,
    }
|
||||
|
||||
def extract_header(msg_or_header):
    """Given a message or header, return the header.

    Accepts a full message dict (returns its 'header'), a bare header
    (detected by the presence of 'msg_id'), or a falsy value (returns {}).
    Raises KeyError when the argument is neither.
    """
    if not msg_or_header:
        return {}
    try:
        # full message: pull out the header
        header = msg_or_header['header']
    except KeyError:
        # not a full message; must itself be a header (raises KeyError if not)
        msg_or_header['msg_id']
        header = msg_or_header
    if isinstance(header, dict):
        return header
    # coerce Message (or other mapping-like) to a plain dict
    return dict(header)
|
||||
|
||||
class Session(Configurable):
    """Object for handling serialization and sending of messages.

    The Session object handles building messages and sending them
    with ZMQ sockets or ZMQStream objects. Objects can communicate with each
    other over the network via Session objects, and only need to work with the
    dict-based IPython message spec. The Session will handle
    serialization/deserialization, security, and metadata.

    Sessions support configurable serialization via packer/unpacker traits,
    and signing with HMAC digests via the key/keyfile traits.

    Parameters
    ----------

    debug : bool
        whether to trigger extra debugging statements
    packer/unpacker : str : 'json', 'pickle' or import_string
        importstrings for methods to serialize message parts. If just
        'json' or 'pickle', predefined JSON and pickle packers will be used.
        Otherwise, the entire importstring must be used.

        The functions must accept at least valid JSON input, and output *bytes*.

        For example, to use msgpack:
        packer = 'msgpack.packb', unpacker='msgpack.unpackb'
    pack/unpack : callables
        You can also set the pack/unpack callables for serialization directly.
    session : bytes
        the ID of this Session object. The default is to generate a new UUID.
    username : unicode
        username added to message headers. The default is to ask the OS.
    key : bytes
        The key used to initialize an HMAC signature. If unset, messages
        will not be signed or checked.
    keyfile : filepath
        The file containing a key. If this is set, `key` will be initialized
        to the contents of the file.

    """

    debug = Bool(False, config=True, help="""Debug output in the Session""")

    check_pid = Bool(True, config=True,
        help="""Whether to check PID to protect against calls after fork.

        This check can be disabled if fork-safety is handled elsewhere.
        """)

    packer = DottedObjectName('json',config=True,
            help="""The name of the packer for serializing messages.
            Should be one of 'json', 'pickle', or an import name
            for a custom callable serializer.""")

    @observe('packer')
    def _packer_changed(self, change):
        # keep pack/unpack (and the paired unpacker name) in sync with the
        # packer importstring; a custom packer only sets self.pack
        new = change['new']
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
            self.unpacker = new
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.unpacker = new
        else:
            self.pack = import_item(str(new))

    unpacker = DottedObjectName('json', config=True,
        help="""The name of the unpacker for unserializing messages.
        Only used with custom functions for `packer`.""")

    @observe('unpacker')
    def _unpacker_changed(self, change):
        # mirror of _packer_changed for the unpacker importstring
        new = change['new']
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
            self.packer = new
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.packer = new
        else:
            self.unpack = import_item(str(new))

    session = CUnicode('', config=True,
        help="""The UUID identifying this session.""")
    def _session_default(self):
        # generate an id lazily, and cache its bytes form
        u = new_id()
        self.bsession = u.encode('ascii')
        return u

    @observe('session')
    def _session_changed(self, change):
        # keep the cached bytes form in sync with the unicode trait
        self.bsession = self.session.encode('ascii')

    # bsession is the session as bytes
    bsession = CBytes(b'')

    username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
        help="""Username for the Session. Default is your system username.""",
        config=True)

    metadata = Dict({}, config=True,
        help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")

    # if 0, no adapting to do.
    adapt_version = Integer(0)

    # message signature related traits:

    key = CBytes(config=True,
        help="""execution key, for signing messages.""")
    def _key_default(self):
        return new_id_bytes()

    @observe('key')
    def _key_changed(self, change):
        # rebuild the HMAC template whenever the key changes
        self._new_auth()

    signature_scheme = Unicode('hmac-sha256', config=True,
        help="""The digest scheme used to construct the message signatures.
        Must have the form 'hmac-HASH'.""")

    @observe('signature_scheme')
    def _signature_scheme_changed(self, change):
        # validate the 'hmac-HASH' form and resolve HASH against hashlib
        new = change['new']
        if not new.startswith('hmac-'):
            raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
        hash_name = new.split('-', 1)[1]
        try:
            self.digest_mod = getattr(hashlib, hash_name)
        except AttributeError as e:
            raise TraitError("hashlib has no such attribute: %s" %
                             hash_name) from e
        self._new_auth()

    digest_mod = Any()
    def _digest_mod_default(self):
        return hashlib.sha256

    # HMAC template copied per message; None disables signing entirely
    auth = Instance(hmac.HMAC, allow_none=True)

    def _new_auth(self):
        if self.key:
            self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
        else:
            self.auth = None

    # signatures already seen, for replay-attack protection
    digest_history = Set()
    digest_history_size = Integer(2**16, config=True,
        help="""The maximum number of digests to remember.

        The digest history will be culled when it exceeds this value.
        """
    )

    keyfile = Unicode('', config=True,
        help="""path to file containing execution key.""")

    @observe('keyfile')
    def _keyfile_changed(self, change):
        # setting keyfile loads the key from disk immediately
        with open(change['new'], 'rb') as f:
            self.key = f.read().strip()

    # for protecting against sends from forks
    pid = Integer()

    # serialization traits:

    pack = Any(default_packer) # the actual packer function

    @observe('pack')
    def _pack_changed(self, change):
        new = change['new']
        if not callable(new):
            raise TypeError("packer must be callable, not %s"%type(new))

    unpack = Any(default_unpacker) # the actual packer function

    @observe('unpack')
    def _unpack_changed(self, change):
        # unpacker is not checked - it is assumed to be
        new = change['new']
        if not callable(new):
            raise TypeError("unpacker must be callable, not %s"%type(new))

    # thresholds:
    copy_threshold = Integer(2**16, config=True,
        help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
    buffer_threshold = Integer(MAX_BYTES, config=True,
        help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
    item_threshold = Integer(MAX_ITEMS, config=True,
        help="""The maximum number of items for a container to be introspected for custom serialization.
        Containers larger than this are pickled outright.
        """
    )


    def __init__(self, **kwargs):
        """create a Session object

        Parameters
        ----------

        debug : bool
            whether to trigger extra debugging statements
        packer/unpacker : str : 'json', 'pickle' or import_string
            importstrings for methods to serialize message parts. If just
            'json' or 'pickle', predefined JSON and pickle packers will be used.
            Otherwise, the entire importstring must be used.

            The functions must accept at least valid JSON input, and output
            *bytes*.

            For example, to use msgpack:
            packer = 'msgpack.packb', unpacker='msgpack.unpackb'
        pack/unpack : callables
            You can also set the pack/unpack callables for serialization
            directly.
        session : unicode (must be ascii)
            the ID of this Session object. The default is to generate a new
            UUID.
        bsession : bytes
            The session as bytes
        username : unicode
            username added to message headers. The default is to ask the OS.
        key : bytes
            The key used to initialize an HMAC signature. If unset, messages
            will not be signed or checked.
        signature_scheme : str
            The message digest scheme. Currently must be of the form 'hmac-HASH',
            where 'HASH' is a hashing function available in Python's hashlib.
            The default is 'hmac-sha256'.
            This is ignored if 'key' is empty.
        keyfile : filepath
            The file containing a key. If this is set, `key` will be
            initialized to the contents of the file.
        """
        super().__init__(**kwargs)
        self._check_packers()
        # cache the packed empty dict, used when a message has no content
        self.none = self.pack({})
        # ensure self._session_default() if necessary, so bsession is defined:
        self.session
        self.pid = os.getpid()
        self._new_auth()
        if not self.key:
            get_logger().warning("Message signing is disabled. This is insecure and not recommended!")

    def clone(self):
        """Create a copy of this Session

        Useful when connecting multiple times to a given kernel.
        This prevents a shared digest_history warning about duplicate digests
        due to multiple connections to IOPub in the same process.

        .. versionadded:: 5.1
        """
        # make a copy
        new_session = type(self)()
        for name in self.traits():
            setattr(new_session, name, getattr(self, name))
        # fork digest_history
        new_session.digest_history = set()
        new_session.digest_history.update(self.digest_history)
        return new_session

    # per-instance counter feeding msg_id (class-level default)
    message_count = 0
    @property
    def msg_id(self):
        # msg_id is "<session>_<n>", with n incrementing per message
        message_number = self.message_count
        self.message_count += 1
        return '{}_{}'.format(self.session, message_number)

    def _check_packers(self):
        """check packers for datetime support."""
        pack = self.pack
        unpack = self.unpack

        # check simple serialization
        msg = dict(a=[1,'hi'])
        try:
            packed = pack(msg)
        except Exception as e:
            msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
            if self.packer == 'json':
                jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
            else:
                jsonmsg = ""
            raise ValueError(
                msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
            ) from e

        # ensure packed message is bytes
        if not isinstance(packed, bytes):
            raise ValueError("message packed to %r, but bytes are required"%type(packed))

        # check that unpack is pack's inverse
        try:
            unpacked = unpack(packed)
            assert unpacked == msg
        except Exception as e:
            msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
            if self.packer == 'json':
                jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
            else:
                jsonmsg = ""
            raise ValueError(
                msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
            ) from e

        # check datetime support; if the round-trip keeps datetimes,
        # wrap pack to squash them to ISO8601 strings first
        msg = dict(t=utcnow())
        try:
            unpacked = unpack(pack(msg))
            if isinstance(unpacked['t'], datetime):
                raise ValueError("Shouldn't deserialize to datetime")
        except Exception:
            self.pack = lambda o: pack(squash_dates(o))
            self.unpack = lambda s: unpack(s)

    def msg_header(self, msg_type):
        # build a header using this session's identity and counter
        return msg_header(self.msg_id, msg_type, self.username, self.session)

    def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
        """Return the nested message dict.

        This format is different from what is sent over the wire. The
        serialize/deserialize methods converts this nested message dict to the wire
        format, which is a list of message parts.
        """
        msg = {}
        header = self.msg_header(msg_type) if header is None else header
        msg['header'] = header
        msg['msg_id'] = header['msg_id']
        msg['msg_type'] = header['msg_type']
        msg['parent_header'] = {} if parent is None else extract_header(parent)
        msg['content'] = {} if content is None else content
        # session-level default metadata, overlaid by per-message metadata
        msg['metadata'] = self.metadata.copy()
        if metadata is not None:
            msg['metadata'].update(metadata)
        return msg

    def sign(self, msg_list):
        """Sign a message with HMAC digest. If no auth, return b''.

        Parameters
        ----------
        msg_list : list
            The [p_header,p_parent,p_content] part of the message list.
        """
        if self.auth is None:
            return b''
        # copy the HMAC template so the keyed state is reused per message
        h = self.auth.copy()
        for m in msg_list:
            h.update(m)
        return str_to_bytes(h.hexdigest())

    def serialize(self, msg, ident=None):
        """Serialize the message components to bytes.

        This is roughly the inverse of deserialize. The serialize/deserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg : dict or Message
            The next message dict as returned by the self.msg method.

        Returns
        -------
        msg_list : list
            The list of bytes objects to be sent with the format::

                [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
                 p_metadata, p_content, buffer1, buffer2, ...]

            In this list, the ``p_*`` entities are the packed or serialized
            versions, so if JSON is used, these are utf8 encoded JSON strings.
        """
        content = msg.get('content', {})
        if content is None:
            # use the cached pre-packed empty dict
            content = self.none
        elif isinstance(content, dict):
            content = self.pack(content)
        elif isinstance(content, bytes):
            # content is already packed, as in a relayed message
            pass
        elif isinstance(content, str):
            # should be bytes, but JSON often spits out unicode
            content = content.encode('utf8')
        else:
            raise TypeError("Content incorrect type: %s"%type(content))

        real_message = [self.pack(msg['header']),
                        self.pack(msg['parent_header']),
                        self.pack(msg['metadata']),
                        content,
        ]

        to_send = []

        if isinstance(ident, list):
            # accept list of idents
            to_send.extend(ident)
        elif ident is not None:
            to_send.append(ident)
        to_send.append(DELIM)

        # signature covers header/parent/metadata/content frames only
        signature = self.sign(real_message)
        to_send.append(signature)

        to_send.extend(real_message)

        return to_send

    def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
             buffers=None, track=False, header=None, metadata=None):
        """Build and send a message via stream or socket.

        The message format used by this function internally is as follows:

        [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
         buffer1,buffer2,...]

        The serialize/deserialize methods convert the nested message dict into this
        format.

        Parameters
        ----------

        stream : zmq.Socket or ZMQStream
            The socket-like object used to send the data.
        msg_or_type : str or Message/dict
            Normally, msg_or_type will be a msg_type unless a message is being
            sent more than once. If a header is supplied, this can be set to
            None and the msg_type will be pulled from the header.

        content : dict or None
            The content of the message (ignored if msg_or_type is a message).
        header : dict or None
            The header dict for the message (ignored if msg_to_type is a message).
        parent : Message or dict or None
            The parent or parent header describing the parent of this message
            (ignored if msg_or_type is a message).
        ident : bytes or list of bytes
            The zmq.IDENTITY routing path.
        metadata : dict or None
            The metadata describing the message
        buffers : list or None
            The already-serialized buffers to be appended to the message.
        track : bool
            Whether to track. Only for use with Sockets, because ZMQStream
            objects cannot track messages.


        Returns
        -------
        msg : dict
            The constructed message.
        """
        if not isinstance(stream, zmq.Socket):
            # ZMQStreams and dummy sockets do not support tracking.
            track = False

        if isinstance(msg_or_type, (Message, dict)):
            # We got a Message or message dict, not a msg_type so don't
            # build a new Message.
            msg = msg_or_type
            buffers = buffers or msg.get('buffers', [])
        else:
            msg = self.msg(msg_or_type, content=content, parent=parent,
                           header=header, metadata=metadata)
        if self.check_pid and not os.getpid() == self.pid:
            # refuse to send from a forked child sharing our sockets
            get_logger().warning("WARNING: attempted to send message from fork\n%s",
                msg
            )
            return
        buffers = [] if buffers is None else buffers
        for idx, buf in enumerate(buffers):
            if isinstance(buf, memoryview):
                view = buf
            else:
                try:
                    # check to see if buf supports the buffer protocol.
                    view = memoryview(buf)
                except TypeError as e:
                    raise TypeError("Buffer objects must support the buffer protocol.") from e
            # memoryview.contiguous is new in 3.3,
            # just skip the check on Python 2
            if hasattr(view, 'contiguous') and not view.contiguous:
                # zmq requires memoryviews to be contiguous
                raise ValueError("Buffer %i (%r) is not contiguous" % (idx, buf))

        if self.adapt_version:
            msg = adapt(msg, self.adapt_version)
        to_send = self.serialize(msg, ident)
        to_send.extend(buffers)
        # copy small messages; zero-copy only pays off for large frames
        longest = max([ len(s) for s in to_send ])
        copy = (longest < self.copy_threshold)

        if buffers and track and not copy:
            # only really track when we are doing zero-copy buffers
            tracker = stream.send_multipart(to_send, copy=False, track=True)
        else:
            # use dummy tracker, which will be done immediately
            tracker = DONE
            stream.send_multipart(to_send, copy=copy)

        if self.debug:
            pprint.pprint(msg)
            pprint.pprint(to_send)
            pprint.pprint(buffers)

        msg['tracker'] = tracker

        return msg

    def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
        """Send a raw message via ident path.

        This method is used to send a already serialized message.

        Parameters
        ----------
        stream : ZMQStream or Socket
            The ZMQ stream or socket to use for sending the message.
        msg_list : list
            The serialized list of messages to send. This only includes the
            [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
            the message.
        ident : ident or list
            A single ident or a list of idents to use in sending.
        """
        to_send = []
        if isinstance(ident, bytes):
            ident = [ident]
        if ident is not None:
            to_send.extend(ident)

        to_send.append(DELIM)
        # Don't include buffers in signature (per spec).
        to_send.append(self.sign(msg_list[0:4]))
        to_send.extend(msg_list)
        stream.send_multipart(to_send, flags, copy=copy)

    def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
        """Receive and unpack a message.

        Parameters
        ----------
        socket : ZMQStream or Socket
            The socket or stream to use in receiving.

        Returns
        -------
        [idents], msg
            [idents] is a list of idents and msg is a nested message dict of
            same format as self.msg returns.
        """
        if isinstance(socket, ZMQStream):
            socket = socket.socket
        try:
            msg_list = socket.recv_multipart(mode, copy=copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # We can convert EAGAIN to None as we know in this case
                # recv_multipart won't return None.
                return None,None
            else:
                raise
        # split multipart message into identity list and message dict
        # invalid large messages can cause very expensive string comparisons
        idents, msg_list = self.feed_identities(msg_list, copy)
        try:
            return idents, self.deserialize(msg_list, content=content, copy=copy)
        except Exception as e:
            # TODO: handle it
            raise e

    def feed_identities(self, msg_list, copy=True):
        """Split the identities from the rest of the message.

        Feed until DELIM is reached, then return the prefix as idents and
        remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
        but that would be silly.

        Parameters
        ----------
        msg_list : a list of Message or bytes objects
            The message to be split.
        copy : bool
            flag determining whether the arguments are bytes or Messages

        Returns
        -------
        (idents, msg_list) : two lists
            idents will always be a list of bytes, each of which is a ZMQ
            identity. msg_list will be a list of bytes or zmq.Messages of the
            form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
            should be unpackable/unserializable via self.deserialize at this
            point.
        """
        if copy:
            # frames are plain bytes; list.index finds the delimiter directly
            idx = msg_list.index(DELIM)
            return msg_list[:idx], msg_list[idx+1:]
        else:
            # frames are zmq.Frame objects; compare their .bytes payloads
            failed = True
            for idx,m in enumerate(msg_list):
                if m.bytes == DELIM:
                    failed = False
                    break
            if failed:
                raise ValueError("DELIM not in msg_list")
            idents, msg_list = msg_list[:idx], msg_list[idx+1:]
            return [m.bytes for m in idents], msg_list

    def _add_digest(self, signature):
        """add a digest to history to protect against replay attacks"""
        if self.digest_history_size == 0:
            # no history, never add digests
            return

        self.digest_history.add(signature)
        if len(self.digest_history) > self.digest_history_size:
            # threshold reached, cull 10%
            self._cull_digest_history()

    def _cull_digest_history(self):
        """cull the digest history

        Removes a randomly selected 10% of the digest history
        """
        current = len(self.digest_history)
        n_to_cull = max(int(current // 10), current - self.digest_history_size)
        if n_to_cull >= current:
            self.digest_history = set()
            return
        # NOTE(review): random.sample on a set is deprecated in 3.9 and
        # removed in 3.11 — may need list(self.digest_history); confirm
        # against the supported Python versions
        to_cull = random.sample(self.digest_history, n_to_cull)
        self.digest_history.difference_update(to_cull)

    def deserialize(self, msg_list, content=True, copy=True):
        """Unserialize a msg_list to a nested message dict.

        This is roughly the inverse of serialize. The serialize/deserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg_list : list of bytes or Message objects
            The list of message parts of the form [HMAC,p_header,p_parent,
            p_metadata,p_content,buffer1,buffer2,...].
        content : bool (True)
            Whether to unpack the content dict (True), or leave it packed
            (False).
        copy : bool (True)
            Whether msg_list contains bytes (True) or the non-copying Message
            objects in each place (False).

        Returns
        -------
        msg : dict
            The nested message dict with top-level keys [header, parent_header,
            content, buffers]. The buffers are returned as memoryviews.
        """
        minlen = 5
        message = {}
        if not copy:
            # pyzmq didn't copy the first parts of the message, so we'll do it
            for i in range(minlen):
                msg_list[i] = msg_list[i].bytes
        if self.auth is not None:
            signature = msg_list[0]
            if not signature:
                raise ValueError("Unsigned Message")
            if signature in self.digest_history:
                raise ValueError("Duplicate Signature: %r" % signature)
            if content:
                # Only store signature if we are unpacking content, don't store if just peeking.
                self._add_digest(signature)
            check = self.sign(msg_list[1:5])
            # constant-time comparison to limit timing attacks
            if not compare_digest(signature, check):
                raise ValueError("Invalid Signature: %r" % signature)
        if not len(msg_list) >= minlen:
            raise TypeError("malformed message, must have at least %i elements"%minlen)
        header = self.unpack(msg_list[1])
        message['header'] = extract_dates(header)
        message['msg_id'] = header['msg_id']
        message['msg_type'] = header['msg_type']
        message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
        message['metadata'] = self.unpack(msg_list[3])
        if content:
            message['content'] = self.unpack(msg_list[4])
        else:
            message['content'] = msg_list[4]
        buffers = [memoryview(b) for b in msg_list[5:]]
        if buffers and buffers[0].shape is None:
            # force copy to workaround pyzmq #646
            buffers = [memoryview(b.bytes) for b in msg_list[5:]]
        message['buffers'] = buffers
        if self.debug:
            pprint.pprint(message)
        # adapt to the current version
        return adapt(message)

    def unserialize(self, *args, **kwargs):
        """Deprecated alias for :meth:`deserialize`."""
        warnings.warn(
            "Session.unserialize is deprecated. Use Session.deserialize.",
            DeprecationWarning,
        )
        return self.deserialize(*args, **kwargs)
|
||||
|
||||
|
||||
def test_msg2obj():
    """Smoke-test Message: attribute access, nesting, item access, dict round-trip."""
    source = dict(x=1)
    msg = Message(source)
    assert msg.x == source['x']

    # nested dicts become nested Messages
    source['y'] = dict(z=1)
    msg = Message(source)
    assert msg.y.z == source['y']['z']

    outer, inner = 'y', 'z'
    assert msg[outer][inner] == source[outer][inner]

    # convert back to a plain dict
    roundtrip = dict(msg)
    assert source['x'] == roundtrip['x']
    assert source['y']['z'] == roundtrip['y']['z']
|
Loading…
Add table
Add a link
Reference in a new issue