Uploaded Test files

This commit is contained in:
Batuhan Berk Başoğlu 2020-11-12 11:05:57 -05:00
parent f584ad9d97
commit 2e81cb7d99
16627 changed files with 2065359 additions and 102444 deletions

View file

@ -0,0 +1,10 @@
"""Client-side implementations of the Jupyter protocol"""
from ._version import version_info, __version__, protocol_version_info, protocol_version
from .connect import *
from .launcher import *
from .client import KernelClient
from .manager import KernelManager, AsyncKernelManager, run_kernel
from .blocking import BlockingKernelClient
from .asynchronous import AsyncKernelClient
from .multikernelmanager import MultiKernelManager, AsyncMultiKernelManager

View file

@ -0,0 +1,5 @@
version_info = (6, 1, 7)
__version__ = '.'.join(map(str, version_info))
protocol_version_info = (5, 3)
protocol_version = "%i.%i" % protocol_version_info

View file

@ -0,0 +1,405 @@
"""Adapters for Jupyter msg spec versions."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import re
import json
from jupyter_client import protocol_version_info
def code_to_line(code, cursor_pos):
"""Turn a multiline code block and cursor position into a single line
and new cursor position.
For adapting ``complete_`` and ``object_info_request``.
"""
if not code:
return "", 0
for line in code.splitlines(True):
n = len(line)
if cursor_pos > n:
cursor_pos -= n
else:
break
return line, cursor_pos
_match_bracket = re.compile(r'\([^\(\)]+\)', re.UNICODE)
_end_bracket = re.compile(r'\([^\(]*$', re.UNICODE)
_identifier = re.compile(r'[a-z_][0-9a-z._]*', re.I|re.UNICODE)
def extract_oname_v4(code, cursor_pos):
"""Reimplement token-finding logic from IPython 2.x javascript
for adapting object_info_request from v5 to v4
"""
line, _ = code_to_line(code, cursor_pos)
oldline = line
line = _match_bracket.sub('', line)
while oldline != line:
oldline = line
line = _match_bracket.sub('', line)
# remove everything after last open bracket
line = _end_bracket.sub('', line)
matches = _identifier.findall(line)
if matches:
return matches[-1]
else:
return ''
class Adapter(object):
"""Base class for adapting messages
Override message_type(msg) methods to create adapters.
"""
msg_type_map = {}
def update_header(self, msg):
return msg
def update_metadata(self, msg):
return msg
def update_msg_type(self, msg):
header = msg['header']
msg_type = header['msg_type']
if msg_type in self.msg_type_map:
msg['msg_type'] = header['msg_type'] = self.msg_type_map[msg_type]
return msg
def handle_reply_status_error(self, msg):
"""This will be called *instead of* the regular handler
on any reply with status != ok
"""
return msg
def __call__(self, msg):
msg = self.update_header(msg)
msg = self.update_metadata(msg)
msg = self.update_msg_type(msg)
header = msg['header']
handler = getattr(self, header['msg_type'], None)
if handler is None:
return msg
# handle status=error replies separately (no change, at present)
if msg['content'].get('status', None) in {'error', 'aborted'}:
return self.handle_reply_status_error(msg)
return handler(msg)
def _version_str_to_list(version):
"""convert a version string to a list of ints
non-int segments are excluded
"""
v = []
for part in version.split('.'):
try:
v.append(int(part))
except ValueError:
pass
return v
class V5toV4(Adapter):
"""Adapt msg protocol v5 to v4"""
version = '4.1'
msg_type_map = {
'execute_result' : 'pyout',
'execute_input' : 'pyin',
'error' : 'pyerr',
'inspect_request' : 'object_info_request',
'inspect_reply' : 'object_info_reply',
}
def update_header(self, msg):
msg['header'].pop('version', None)
msg['parent_header'].pop('version', None)
return msg
# shell channel
def kernel_info_reply(self, msg):
v4c = {}
content = msg['content']
for key in ('language_version', 'protocol_version'):
if key in content:
v4c[key] = _version_str_to_list(content[key])
if content.get('implementation', '') == 'ipython' \
and 'implementation_version' in content:
v4c['ipython_version'] = _version_str_to_list(content['implementation_version'])
language_info = content.get('language_info', {})
language = language_info.get('name', '')
v4c.setdefault('language', language)
if 'version' in language_info:
v4c.setdefault('language_version', _version_str_to_list(language_info['version']))
msg['content'] = v4c
return msg
def execute_request(self, msg):
content = msg['content']
content.setdefault('user_variables', [])
return msg
def execute_reply(self, msg):
content = msg['content']
content.setdefault('user_variables', {})
# TODO: handle payloads
return msg
def complete_request(self, msg):
content = msg['content']
code = content['code']
cursor_pos = content['cursor_pos']
line, cursor_pos = code_to_line(code, cursor_pos)
new_content = msg['content'] = {}
new_content['text'] = ''
new_content['line'] = line
new_content['block'] = None
new_content['cursor_pos'] = cursor_pos
return msg
def complete_reply(self, msg):
content = msg['content']
cursor_start = content.pop('cursor_start')
cursor_end = content.pop('cursor_end')
match_len = cursor_end - cursor_start
content['matched_text'] = content['matches'][0][:match_len]
content.pop('metadata', None)
return msg
def object_info_request(self, msg):
content = msg['content']
code = content['code']
cursor_pos = content['cursor_pos']
line, _ = code_to_line(code, cursor_pos)
new_content = msg['content'] = {}
new_content['oname'] = extract_oname_v4(code, cursor_pos)
new_content['detail_level'] = content['detail_level']
return msg
def object_info_reply(self, msg):
"""inspect_reply can't be easily backward compatible"""
msg['content'] = {'found' : False, 'oname' : 'unknown'}
return msg
# iopub channel
def stream(self, msg):
content = msg['content']
content['data'] = content.pop('text')
return msg
def display_data(self, msg):
content = msg['content']
content.setdefault("source", "display")
data = content['data']
if 'application/json' in data:
try:
data['application/json'] = json.dumps(data['application/json'])
except Exception:
# warn?
pass
return msg
# stdin channel
def input_request(self, msg):
msg['content'].pop('password', None)
return msg
class V4toV5(Adapter):
"""Convert msg spec V4 to V5"""
version = '5.0'
# invert message renames above
msg_type_map = {v:k for k,v in V5toV4.msg_type_map.items()}
def update_header(self, msg):
msg['header']['version'] = self.version
if msg['parent_header']:
msg['parent_header']['version'] = self.version
return msg
# shell channel
def kernel_info_reply(self, msg):
content = msg['content']
for key in ('protocol_version', 'ipython_version'):
if key in content:
content[key] = '.'.join(map(str, content[key]))
content.setdefault('protocol_version', '4.1')
if content['language'].startswith('python') and 'ipython_version' in content:
content['implementation'] = 'ipython'
content['implementation_version'] = content.pop('ipython_version')
language = content.pop('language')
language_info = content.setdefault('language_info', {})
language_info.setdefault('name', language)
if 'language_version' in content:
language_version = '.'.join(map(str, content.pop('language_version')))
language_info.setdefault('version', language_version)
content['banner'] = ''
return msg
def execute_request(self, msg):
content = msg['content']
user_variables = content.pop('user_variables', [])
user_expressions = content.setdefault('user_expressions', {})
for v in user_variables:
user_expressions[v] = v
return msg
def execute_reply(self, msg):
content = msg['content']
user_expressions = content.setdefault('user_expressions', {})
user_variables = content.pop('user_variables', {})
if user_variables:
user_expressions.update(user_variables)
# Pager payloads became a mime bundle
for payload in content.get('payload', []):
if payload.get('source', None) == 'page' and ('text' in payload):
if 'data' not in payload:
payload['data'] = {}
payload['data']['text/plain'] = payload.pop('text')
return msg
def complete_request(self, msg):
old_content = msg['content']
new_content = msg['content'] = {}
new_content['code'] = old_content['line']
new_content['cursor_pos'] = old_content['cursor_pos']
return msg
def complete_reply(self, msg):
# complete_reply needs more context than we have to get cursor_start and end.
# use special end=null to indicate current cursor position and negative offset
# for start relative to the cursor.
# start=None indicates that start == end (accounts for no -0).
content = msg['content']
new_content = msg['content'] = {'status' : 'ok'}
new_content['matches'] = content['matches']
if content['matched_text']:
new_content['cursor_start'] = -len(content['matched_text'])
else:
# no -0, use None to indicate that start == end
new_content['cursor_start'] = None
new_content['cursor_end'] = None
new_content['metadata'] = {}
return msg
def inspect_request(self, msg):
content = msg['content']
name = content['oname']
new_content = msg['content'] = {}
new_content['code'] = name
new_content['cursor_pos'] = len(name)
new_content['detail_level'] = content['detail_level']
return msg
def inspect_reply(self, msg):
"""inspect_reply can't be easily backward compatible"""
content = msg['content']
new_content = msg['content'] = {'status' : 'ok'}
found = new_content['found'] = content['found']
new_content['data'] = data = {}
new_content['metadata'] = {}
if found:
lines = []
for key in ('call_def', 'init_definition', 'definition'):
if content.get(key, False):
lines.append(content[key])
break
for key in ('call_docstring', 'init_docstring', 'docstring'):
if content.get(key, False):
lines.append(content[key])
break
if not lines:
lines.append("<empty docstring>")
data['text/plain'] = '\n'.join(lines)
return msg
# iopub channel
def stream(self, msg):
content = msg['content']
content['text'] = content.pop('data')
return msg
def display_data(self, msg):
content = msg['content']
content.pop("source", None)
data = content['data']
if 'application/json' in data:
try:
data['application/json'] = json.loads(data['application/json'])
except Exception:
# warn?
pass
return msg
# stdin channel
def input_request(self, msg):
msg['content'].setdefault('password', False)
return msg
def adapt(msg, to_version=protocol_version_info[0]):
"""Adapt a single message to a target version
Parameters
----------
msg : dict
A Jupyter message.
to_version : int, optional
The target major version.
If unspecified, adapt to the current version.
Returns
-------
msg : dict
A Jupyter message appropriate in the new version.
"""
from .session import utcnow
header = msg['header']
if 'date' not in header:
header['date'] = utcnow()
if 'version' in header:
from_version = int(header['version'].split('.')[0])
else:
# assume last version before adding the key to the header
from_version = 4
adapter = adapters.get((from_version, to_version), None)
if adapter is None:
return msg
return adapter(msg)
# one adapter per major version from,to
adapters = {
(5,4) : V5toV4(),
(4,5) : V4toV5(),
}

View file

@ -0,0 +1 @@
from .client import AsyncKernelClient

View file

@ -0,0 +1,82 @@
"""Async channels"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from queue import Queue, Empty
class ZMQSocketChannel(object):
"""A ZMQ socket in an async API"""
session = None
socket = None
stream = None
_exiting = False
proxy_methods = []
def __init__(self, socket, session, loop=None):
"""Create a channel.
Parameters
----------
socket : :class:`zmq.asyncio.Socket`
The ZMQ socket to use.
session : :class:`session.Session`
The session to use.
loop
Unused here, for other implementations
"""
super().__init__()
self.socket = socket
self.session = session
async def _recv(self, **kwargs):
msg = await self.socket.recv_multipart(**kwargs)
ident,smsg = self.session.feed_identities(msg)
return self.session.deserialize(smsg)
async def get_msg(self, timeout=None):
""" Gets a message if there is one that is ready. """
if timeout is not None:
timeout *= 1000 # seconds to ms
ready = await self.socket.poll(timeout)
if ready:
return await self._recv()
else:
raise Empty
async def get_msgs(self):
""" Get all messages that are currently ready. """
msgs = []
while True:
try:
msgs.append(await self.get_msg())
except Empty:
break
return msgs
async def msg_ready(self):
""" Is there a message that has been received? """
return bool(await self.socket.poll(timeout=0))
def close(self):
if self.socket is not None:
try:
self.socket.close(linger=0)
except Exception:
pass
self.socket = None
stop = close
def is_alive(self):
return (self.socket is not None)
def send(self, msg):
"""Pass a message to the ZMQ socket to send
"""
self.session.send(self.socket, msg)
def start(self):
pass

View file

@ -0,0 +1,388 @@
"""Implements an async kernel client"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from functools import partial
from getpass import getpass
from queue import Empty
import sys
import time
import zmq
import zmq.asyncio
import asyncio
from traitlets import (Type, Instance)
from jupyter_client.channels import HBChannel
from jupyter_client.client import KernelClient
from .channels import ZMQSocketChannel
def reqrep(meth, channel='shell'):
def wrapped(self, *args, **kwargs):
reply = kwargs.pop('reply', False)
timeout = kwargs.pop('timeout', None)
msg_id = meth(self, *args, **kwargs)
if not reply:
return msg_id
return self._recv_reply(msg_id, timeout=timeout, channel=channel)
if not meth.__doc__:
# python -OO removes docstrings,
# so don't bother building the wrapped docstring
return wrapped
basedoc, _ = meth.__doc__.split('Returns\n', 1)
parts = [basedoc.strip()]
if 'Parameters' not in basedoc:
parts.append("""
Parameters
----------
""")
parts.append("""
reply: bool (default: False)
Whether to wait for and return reply
timeout: float or None (default: None)
Timeout to use when waiting for a reply
Returns
-------
msg_id: str
The msg_id of the request sent, if reply=False (default)
reply: dict
The reply message for this request, if reply=True
""")
wrapped.__doc__ = '\n'.join(parts)
return wrapped
class AsyncKernelClient(KernelClient):
"""A KernelClient with async APIs
``get_[channel]_msg()`` methods wait for and return messages on channels,
raising :exc:`queue.Empty` if no message arrives within ``timeout`` seconds.
"""
# The PyZMQ Context to use for communication with the kernel.
context = Instance(zmq.asyncio.Context)
def _context_default(self):
return zmq.asyncio.Context()
#--------------------------------------------------------------------------
# Channel proxy methods
#--------------------------------------------------------------------------
async def get_shell_msg(self, *args, **kwargs):
"""Get a message from the shell channel"""
return await self.shell_channel.get_msg(*args, **kwargs)
async def get_iopub_msg(self, *args, **kwargs):
"""Get a message from the iopub channel"""
return await self.iopub_channel.get_msg(*args, **kwargs)
async def get_stdin_msg(self, *args, **kwargs):
"""Get a message from the stdin channel"""
return await self.stdin_channel.get_msg(*args, **kwargs)
async def get_control_msg(self, *args, **kwargs):
"""Get a message from the control channel"""
return await self.control_channel.get_msg(*args, **kwargs)
@property
def hb_channel(self):
"""Get the hb channel object for this kernel."""
if self._hb_channel is None:
url = self._make_url('hb')
self.log.debug("connecting heartbeat channel to %s", url)
loop = asyncio.new_event_loop()
self._hb_channel = self.hb_channel_class(
self.context, self.session, url, loop
)
return self._hb_channel
async def wait_for_ready(self, timeout=None):
"""Waits for a response when a client is blocked
- Sets future time for timeout
- Blocks on shell channel until a message is received
- Exit if the kernel has died
- If client times out before receiving a message from the kernel, send RuntimeError
- Flush the IOPub channel
"""
if timeout is None:
abs_timeout = float('inf')
else:
abs_timeout = time.time() + timeout
from ..manager import KernelManager
if not isinstance(self.parent, KernelManager):
# This Client was not created by a KernelManager,
# so wait for kernel to become responsive to heartbeats
# before checking for kernel_info reply
while not self.is_alive():
if time.time() > abs_timeout:
raise RuntimeError("Kernel didn't respond to heartbeats in %d seconds and timed out" % timeout)
await asyncio.sleep(0.2)
# Wait for kernel info reply on shell channel
while True:
try:
msg = await self.shell_channel.get_msg(timeout=1)
except Empty:
pass
else:
if msg['msg_type'] == 'kernel_info_reply':
self._handle_kernel_info_reply(msg)
break
if not await self.is_alive():
raise RuntimeError('Kernel died before replying to kernel_info')
# Check if current time is ready check time plus timeout
if time.time() > abs_timeout:
raise RuntimeError("Kernel didn't respond in %d seconds" % timeout)
# Flush IOPub channel
while True:
try:
msg = await self.iopub_channel.get_msg(timeout=0.2)
except Empty:
break
# The classes to use for the various channels
shell_channel_class = Type(ZMQSocketChannel)
iopub_channel_class = Type(ZMQSocketChannel)
stdin_channel_class = Type(ZMQSocketChannel)
hb_channel_class = Type(HBChannel)
control_channel_class = Type(ZMQSocketChannel)
async def _recv_reply(self, msg_id, timeout=None, channel='shell'):
"""Receive and return the reply for a given request"""
if timeout is not None:
deadline = time.monotonic() + timeout
while True:
if timeout is not None:
timeout = max(0, deadline - time.monotonic())
try:
if channel == 'control':
reply = await self.get_control_msg(timeout=timeout)
else:
reply = await self.get_shell_msg(timeout=timeout)
except Empty as e:
raise TimeoutError("Timeout waiting for reply") from e
if reply['parent_header'].get('msg_id') != msg_id:
# not my reply, someone may have forgotten to retrieve theirs
continue
return reply
# replies come on the shell channel
execute = reqrep(KernelClient.execute)
history = reqrep(KernelClient.history)
complete = reqrep(KernelClient.complete)
inspect = reqrep(KernelClient.inspect)
kernel_info = reqrep(KernelClient.kernel_info)
comm_info = reqrep(KernelClient.comm_info)
# replies come on the control channel
shutdown = reqrep(KernelClient.shutdown, channel='control')
def _stdin_hook_default(self, msg):
"""Handle an input request"""
content = msg['content']
if content.get('password', False):
prompt = getpass
else:
prompt = input
try:
raw_data = prompt(content["prompt"])
except EOFError:
# turn EOFError into EOF character
raw_data = '\x04'
except KeyboardInterrupt:
sys.stdout.write('\n')
return
# only send stdin reply if there *was not* another request
# or execution finished while we were reading.
if not (self.stdin_channel.msg_ready() or self.shell_channel.msg_ready()):
self.input(raw_data)
def _output_hook_default(self, msg):
"""Default hook for redisplaying plain-text output"""
msg_type = msg['header']['msg_type']
content = msg['content']
if msg_type == 'stream':
stream = getattr(sys, content['name'])
stream.write(content['text'])
elif msg_type in ('display_data', 'execute_result'):
sys.stdout.write(content['data'].get('text/plain', ''))
elif msg_type == 'error':
print('\n'.join(content['traceback']), file=sys.stderr)
def _output_hook_kernel(self, session, socket, parent_header, msg):
"""Output hook when running inside an IPython kernel
adds rich output support.
"""
msg_type = msg['header']['msg_type']
if msg_type in ('display_data', 'execute_result', 'error'):
session.send(socket, msg_type, msg['content'], parent=parent_header)
else:
self._output_hook_default(msg)
async def is_alive(self):
"""Is the kernel process still running?"""
from ..manager import KernelManager, AsyncKernelManager
if isinstance(self.parent, KernelManager):
# This KernelClient was created by a KernelManager,
# we can ask the parent KernelManager:
if isinstance(self.parent, AsyncKernelManager):
return await self.parent.is_alive()
return self.parent.is_alive()
if self._hb_channel is not None:
# We don't have access to the KernelManager,
# so we use the heartbeat.
return self._hb_channel.is_beating()
else:
# no heartbeat and not local, we can't tell if it's running,
# so naively return True
return True
async def execute_interactive(self, code, silent=False, store_history=True,
user_expressions=None, allow_stdin=None, stop_on_error=True,
timeout=None, output_hook=None, stdin_hook=None,
):
"""Execute code in the kernel interactively
Output will be redisplayed, and stdin prompts will be relayed as well.
If an IPython kernel is detected, rich output will be displayed.
You can pass a custom output_hook callable that will be called
with every IOPub message that is produced instead of the default redisplay.
Parameters
----------
code : str
A string of code in the kernel's language.
silent : bool, optional (default False)
If set, the kernel will execute the code as quietly possible, and
will force store_history to be False.
store_history : bool, optional (default True)
If set, the kernel will store command history. This is forced
to be False if silent is True.
user_expressions : dict, optional
A dict mapping names to expressions to be evaluated in the user's
dict. The expression values are returned as strings formatted using
:func:`repr`.
allow_stdin : bool, optional (default self.allow_stdin)
Flag for whether the kernel can send stdin requests to frontends.
Some frontends (e.g. the Notebook) do not support stdin requests.
If raw_input is called from code executed from such a frontend, a
StdinNotImplementedError will be raised.
stop_on_error: bool, optional (default True)
Flag whether to abort the execution queue, if an exception is encountered.
timeout: float or None (default: None)
Timeout to use when waiting for a reply
output_hook: callable(msg)
Function to be called with output messages.
If not specified, output will be redisplayed.
stdin_hook: callable(msg)
Function to be called with stdin_request messages.
If not specified, input/getpass will be called.
Returns
-------
reply: dict
The reply message for this request
"""
if not self.iopub_channel.is_alive():
raise RuntimeError("IOPub channel must be running to receive output")
if allow_stdin is None:
allow_stdin = self.allow_stdin
if allow_stdin and not self.stdin_channel.is_alive():
raise RuntimeError("stdin channel must be running to allow input")
msg_id = await self.execute(code,
silent=silent,
store_history=store_history,
user_expressions=user_expressions,
allow_stdin=allow_stdin,
stop_on_error=stop_on_error,
)
if stdin_hook is None:
stdin_hook = self._stdin_hook_default
if output_hook is None:
# detect IPython kernel
if 'IPython' in sys.modules:
from IPython import get_ipython
ip = get_ipython()
in_kernel = getattr(ip, 'kernel', False)
if in_kernel:
output_hook = partial(
self._output_hook_kernel,
ip.display_pub.session,
ip.display_pub.pub_socket,
ip.display_pub.parent_header,
)
if output_hook is None:
# default: redisplay plain-text outputs
output_hook = self._output_hook_default
# set deadline based on timeout
if timeout is not None:
deadline = time.monotonic() + timeout
else:
timeout_ms = None
poller = zmq.Poller()
iopub_socket = self.iopub_channel.socket
poller.register(iopub_socket, zmq.POLLIN)
if allow_stdin:
stdin_socket = self.stdin_channel.socket
poller.register(stdin_socket, zmq.POLLIN)
else:
stdin_socket = None
# wait for output and redisplay it
while True:
if timeout is not None:
timeout = max(0, deadline - time.monotonic())
timeout_ms = 1e3 * timeout
events = dict(poller.poll(timeout_ms))
if not events:
raise TimeoutError("Timeout waiting for output")
if stdin_socket in events:
req = await self.stdin_channel.get_msg(timeout=0)
stdin_hook(req)
continue
if iopub_socket not in events:
continue
msg = await self.iopub_channel.get_msg(timeout=0)
if msg['parent_header'].get('msg_id') != msg_id:
# not from my request
continue
output_hook(msg)
# stop on idle
if msg['header']['msg_type'] == 'status' and \
msg['content']['execution_state'] == 'idle':
break
# output is done, get the reply
if timeout is not None:
timeout = max(0, deadline - time.monotonic())
return await self._recv_reply(msg_id, timeout=timeout)

View file

@ -0,0 +1 @@
from .client import BlockingKernelClient

View file

@ -0,0 +1,88 @@
"""Blocking channels
Useful for test suites and blocking terminal interfaces.
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from queue import Queue, Empty
class ZMQSocketChannel(object):
"""A ZMQ socket in a simple blocking API"""
session = None
socket = None
stream = None
_exiting = False
proxy_methods = []
def __init__(self, socket, session, loop=None):
"""Create a channel.
Parameters
----------
socket : :class:`zmq.Socket`
The ZMQ socket to use.
session : :class:`session.Session`
The session to use.
loop
Unused here, for other implementations
"""
super().__init__()
self.socket = socket
self.session = session
def _recv(self, **kwargs):
msg = self.socket.recv_multipart(**kwargs)
ident,smsg = self.session.feed_identities(msg)
return self.session.deserialize(smsg)
def get_msg(self, block=True, timeout=None):
""" Gets a message if there is one that is ready. """
if block:
if timeout is not None:
timeout *= 1000 # seconds to ms
ready = self.socket.poll(timeout)
else:
ready = self.socket.poll(timeout=0)
if ready:
return self._recv()
else:
raise Empty
def get_msgs(self):
""" Get all messages that are currently ready. """
msgs = []
while True:
try:
msgs.append(self.get_msg(block=False))
except Empty:
break
return msgs
def msg_ready(self):
""" Is there a message that has been received? """
return bool(self.socket.poll(timeout=0))
def close(self):
if self.socket is not None:
try:
self.socket.close(linger=0)
except Exception:
pass
self.socket = None
stop = close
def is_alive(self):
return (self.socket is not None)
def send(self, msg):
"""Pass a message to the ZMQ socket to send
"""
self.session.send(self.socket, msg)
def start(self):
pass

View file

@ -0,0 +1,337 @@
"""Implements a fully blocking kernel client.
Useful for test suites and blocking terminal interfaces.
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from functools import partial
from getpass import getpass
from queue import Empty
import sys
import time
import zmq
from time import monotonic
from traitlets import Type
from jupyter_client.channels import HBChannel
from jupyter_client.client import KernelClient
from .channels import ZMQSocketChannel
def reqrep(meth, channel='shell'):
def wrapped(self, *args, **kwargs):
reply = kwargs.pop('reply', False)
timeout = kwargs.pop('timeout', None)
msg_id = meth(self, *args, **kwargs)
if not reply:
return msg_id
return self._recv_reply(msg_id, timeout=timeout, channel=channel)
if not meth.__doc__:
# python -OO removes docstrings,
# so don't bother building the wrapped docstring
return wrapped
basedoc, _ = meth.__doc__.split('Returns\n', 1)
parts = [basedoc.strip()]
if 'Parameters' not in basedoc:
parts.append("""
Parameters
----------
""")
parts.append("""
reply: bool (default: False)
Whether to wait for and return reply
timeout: float or None (default: None)
Timeout to use when waiting for a reply
Returns
-------
msg_id: str
The msg_id of the request sent, if reply=False (default)
reply: dict
The reply message for this request, if reply=True
""")
wrapped.__doc__ = '\n'.join(parts)
return wrapped
class BlockingKernelClient(KernelClient):
"""A KernelClient with blocking APIs
``get_[channel]_msg()`` methods wait for and return messages on channels,
raising :exc:`queue.Empty` if no message arrives within ``timeout`` seconds.
"""
def wait_for_ready(self, timeout=None):
"""Waits for a response when a client is blocked
- Sets future time for timeout
- Blocks on shell channel until a message is received
- Exit if the kernel has died
- If client times out before receiving a message from the kernel, send RuntimeError
- Flush the IOPub channel
"""
if timeout is None:
abs_timeout = float('inf')
else:
abs_timeout = time.time() + timeout
from ..manager import KernelManager
if not isinstance(self.parent, KernelManager):
# This Client was not created by a KernelManager,
# so wait for kernel to become responsive to heartbeats
# before checking for kernel_info reply
while not self.is_alive():
if time.time() > abs_timeout:
raise RuntimeError("Kernel didn't respond to heartbeats in %d seconds and timed out" % timeout)
time.sleep(0.2)
# Wait for kernel info reply on shell channel
while True:
try:
msg = self.shell_channel.get_msg(block=True, timeout=1)
except Empty:
pass
else:
if msg['msg_type'] == 'kernel_info_reply':
self._handle_kernel_info_reply(msg)
break
if not self.is_alive():
raise RuntimeError('Kernel died before replying to kernel_info')
# Check if current time is ready check time plus timeout
if time.time() > abs_timeout:
raise RuntimeError("Kernel didn't respond in %d seconds" % timeout)
# Flush IOPub channel
while True:
try:
msg = self.iopub_channel.get_msg(block=True, timeout=0.2)
except Empty:
break
# The classes to use for the various channels
shell_channel_class = Type(ZMQSocketChannel)
iopub_channel_class = Type(ZMQSocketChannel)
stdin_channel_class = Type(ZMQSocketChannel)
hb_channel_class = Type(HBChannel)
control_channel_class = Type(ZMQSocketChannel)
def _recv_reply(self, msg_id, timeout=None, channel='shell'):
"""Receive and return the reply for a given request"""
if timeout is not None:
deadline = monotonic() + timeout
while True:
if timeout is not None:
timeout = max(0, deadline - monotonic())
try:
if channel == 'control':
reply = self.get_control_msg(timeout=timeout)
else:
reply = self.get_shell_msg(timeout=timeout)
except Empty as e:
raise TimeoutError("Timeout waiting for reply") from e
if reply['parent_header'].get('msg_id') != msg_id:
# not my reply, someone may have forgotten to retrieve theirs
continue
return reply
# replies come on the shell channel
execute = reqrep(KernelClient.execute)
history = reqrep(KernelClient.history)
complete = reqrep(KernelClient.complete)
inspect = reqrep(KernelClient.inspect)
kernel_info = reqrep(KernelClient.kernel_info)
comm_info = reqrep(KernelClient.comm_info)
# replies come on the control channel
shutdown = reqrep(KernelClient.shutdown, channel='control')
def _stdin_hook_default(self, msg):
"""Handle an input request"""
content = msg['content']
if content.get('password', False):
prompt = getpass
else:
prompt = input
try:
raw_data = prompt(content["prompt"])
except EOFError:
# turn EOFError into EOF character
raw_data = '\x04'
except KeyboardInterrupt:
sys.stdout.write('\n')
return
# only send stdin reply if there *was not* another request
# or execution finished while we were reading.
if not (self.stdin_channel.msg_ready() or self.shell_channel.msg_ready()):
self.input(raw_data)
def _output_hook_default(self, msg):
"""Default hook for redisplaying plain-text output"""
msg_type = msg['header']['msg_type']
content = msg['content']
if msg_type == 'stream':
stream = getattr(sys, content['name'])
stream.write(content['text'])
elif msg_type in ('display_data', 'execute_result'):
sys.stdout.write(content['data'].get('text/plain', ''))
elif msg_type == 'error':
print('\n'.join(content['traceback']), file=sys.stderr)
def _output_hook_kernel(self, session, socket, parent_header, msg):
"""Output hook when running inside an IPython kernel
adds rich output support.
"""
msg_type = msg['header']['msg_type']
if msg_type in ('display_data', 'execute_result', 'error'):
session.send(socket, msg_type, msg['content'], parent=parent_header)
else:
self._output_hook_default(msg)
def execute_interactive(self, code, silent=False, store_history=True,
user_expressions=None, allow_stdin=None, stop_on_error=True,
timeout=None, output_hook=None, stdin_hook=None,
):
"""Execute code in the kernel interactively
Output will be redisplayed, and stdin prompts will be relayed as well.
If an IPython kernel is detected, rich output will be displayed.
You can pass a custom output_hook callable that will be called
with every IOPub message that is produced instead of the default redisplay.
.. versionadded:: 5.0
Parameters
----------
code : str
A string of code in the kernel's language.
silent : bool, optional (default False)
If set, the kernel will execute the code as quietly possible, and
will force store_history to be False.
store_history : bool, optional (default True)
If set, the kernel will store command history. This is forced
to be False if silent is True.
user_expressions : dict, optional
A dict mapping names to expressions to be evaluated in the user's
dict. The expression values are returned as strings formatted using
:func:`repr`.
allow_stdin : bool, optional (default self.allow_stdin)
Flag for whether the kernel can send stdin requests to frontends.
Some frontends (e.g. the Notebook) do not support stdin requests.
If raw_input is called from code executed from such a frontend, a
StdinNotImplementedError will be raised.
stop_on_error: bool, optional (default True)
Flag whether to abort the execution queue, if an exception is encountered.
timeout: float or None (default: None)
Timeout to use when waiting for a reply
output_hook: callable(msg)
Function to be called with output messages.
If not specified, output will be redisplayed.
stdin_hook: callable(msg)
Function to be called with stdin_request messages.
If not specified, input/getpass will be called.
Returns
-------
reply: dict
The reply message for this request
"""
if not self.iopub_channel.is_alive():
raise RuntimeError("IOPub channel must be running to receive output")
if allow_stdin is None:
allow_stdin = self.allow_stdin
if allow_stdin and not self.stdin_channel.is_alive():
raise RuntimeError("stdin channel must be running to allow input")
msg_id = self.execute(code,
silent=silent,
store_history=store_history,
user_expressions=user_expressions,
allow_stdin=allow_stdin,
stop_on_error=stop_on_error,
)
if stdin_hook is None:
stdin_hook = self._stdin_hook_default
if output_hook is None:
# detect IPython kernel
if 'IPython' in sys.modules:
from IPython import get_ipython
ip = get_ipython()
in_kernel = getattr(ip, 'kernel', False)
if in_kernel:
output_hook = partial(
self._output_hook_kernel,
ip.display_pub.session,
ip.display_pub.pub_socket,
ip.display_pub.parent_header,
)
if output_hook is None:
# default: redisplay plain-text outputs
output_hook = self._output_hook_default
# set deadline based on timeout
if timeout is not None:
deadline = monotonic() + timeout
else:
timeout_ms = None
poller = zmq.Poller()
iopub_socket = self.iopub_channel.socket
poller.register(iopub_socket, zmq.POLLIN)
if allow_stdin:
stdin_socket = self.stdin_channel.socket
poller.register(stdin_socket, zmq.POLLIN)
else:
stdin_socket = None
# wait for output and redisplay it
while True:
if timeout is not None:
timeout = max(0, deadline - monotonic())
timeout_ms = 1e3 * timeout
events = dict(poller.poll(timeout_ms))
if not events:
raise TimeoutError("Timeout waiting for output")
if stdin_socket in events:
req = self.stdin_channel.get_msg(timeout=0)
stdin_hook(req)
continue
if iopub_socket not in events:
continue
msg = self.iopub_channel.get_msg(timeout=0)
if msg['parent_header'].get('msg_id') != msg_id:
# not from my request
continue
output_hook(msg)
# stop on idle
if msg['header']['msg_type'] == 'status' and \
msg['content']['execution_state'] == 'idle':
break
# output is done, get the reply
if timeout is not None:
timeout = max(0, deadline - monotonic())
return self._recv_reply(msg_id, timeout=timeout)

View file

@ -0,0 +1,213 @@
"""Base classes to manage a Client's interaction with a running kernel"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import atexit
import errno
from threading import Thread, Event
import time
import asyncio
import zmq
# import ZMQError in top-level namespace, to avoid ugly attribute-error messages
# during garbage collection of threads at exit:
from zmq import ZMQError
from jupyter_client import protocol_version_info
from .channelsabc import HBChannelABC
#-----------------------------------------------------------------------------
# Constants and exceptions
#-----------------------------------------------------------------------------
major_protocol_version = protocol_version_info[0]
class InvalidPortNumber(Exception):
pass
class HBChannel(Thread):
"""The heartbeat channel which monitors the kernel heartbeat.
Note that the heartbeat channel is paused by default. As long as you start
this channel, the kernel manager will ensure that it is paused and un-paused
as appropriate.
"""
context = None
session = None
socket = None
address = None
_exiting = False
time_to_dead = 1.
poller = None
_running = None
_pause = None
_beating = None
def __init__(self, context=None, session=None, address=None, loop=None):
"""Create the heartbeat monitor thread.
Parameters
----------
context : :class:`zmq.Context`
The ZMQ context to use.
session : :class:`session.Session`
The session to use.
address : zmq url
Standard (ip, port) tuple that the kernel is listening on.
"""
super().__init__()
self.daemon = True
self.loop = loop
self.context = context
self.session = session
if isinstance(address, tuple):
if address[1] == 0:
message = 'The port number for a channel cannot be 0.'
raise InvalidPortNumber(message)
address = "tcp://%s:%i" % address
self.address = address
# running is False until `.start()` is called
self._running = False
self._exit = Event()
# don't start paused
self._pause = False
self.poller = zmq.Poller()
@staticmethod
@atexit.register
def _notice_exit():
# Class definitions can be torn down during interpreter shutdown.
# We only need to set _exiting flag if this hasn't happened.
if HBChannel is not None:
HBChannel._exiting = True
def _create_socket(self):
if self.socket is not None:
# close previous socket, before opening a new one
self.poller.unregister(self.socket)
self.socket.close()
self.socket = self.context.socket(zmq.REQ)
self.socket.linger = 1000
self.socket.connect(self.address)
self.poller.register(self.socket, zmq.POLLIN)
def _poll(self, start_time):
"""poll for heartbeat replies until we reach self.time_to_dead.
Ignores interrupts, and returns the result of poll(), which
will be an empty list if no messages arrived before the timeout,
or the event tuple if there is a message to receive.
"""
until_dead = self.time_to_dead - (time.time() - start_time)
# ensure poll at least once
until_dead = max(until_dead, 1e-3)
events = []
while True:
try:
events = self.poller.poll(1000 * until_dead)
except ZMQError as e:
if e.errno == errno.EINTR:
# ignore interrupts during heartbeat
# this may never actually happen
until_dead = self.time_to_dead - (time.time() - start_time)
until_dead = max(until_dead, 1e-3)
pass
else:
raise
except Exception:
if self._exiting:
break
else:
raise
else:
break
return events
def run(self):
"""The thread's main activity. Call start() instead."""
if self.loop is not None:
asyncio.set_event_loop(self.loop)
self._create_socket()
self._running = True
self._beating = True
while self._running:
if self._pause:
# just sleep, and skip the rest of the loop
self._exit.wait(self.time_to_dead)
continue
since_last_heartbeat = 0.0
# no need to catch EFSM here, because the previous event was
# either a recv or connect, which cannot be followed by EFSM
self.socket.send(b'ping')
request_time = time.time()
ready = self._poll(request_time)
if ready:
self._beating = True
# the poll above guarantees we have something to recv
self.socket.recv()
# sleep the remainder of the cycle
remainder = self.time_to_dead - (time.time() - request_time)
if remainder > 0:
self._exit.wait(remainder)
continue
else:
# nothing was received within the time limit, signal heart failure
self._beating = False
since_last_heartbeat = time.time() - request_time
self.call_handlers(since_last_heartbeat)
# and close/reopen the socket, because the REQ/REP cycle has been broken
self._create_socket()
continue
def pause(self):
"""Pause the heartbeat."""
self._pause = True
def unpause(self):
"""Unpause the heartbeat."""
self._pause = False
def is_beating(self):
"""Is the heartbeat running and responsive (and not paused)."""
if self.is_alive() and not self._pause and self._beating:
return True
else:
return False
def stop(self):
"""Stop the channel's event loop and join its thread."""
self._running = False
self._exit.set()
self.join()
self.close()
def close(self):
if self.socket is not None:
try:
self.socket.close(linger=0)
except Exception:
pass
self.socket = None
def call_handlers(self, since_last_heartbeat):
"""This method is called in the ioloop thread when a message arrives.
Subclasses should override this method to handle incoming messages.
It is important to remember that this method is called in the thread
so that some logic must be done to ensure that the application level
handlers are called in the application thread.
"""
pass
HBChannelABC.register(HBChannel)

View file

@ -0,0 +1,47 @@
"""Abstract base classes for kernel client channels"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import abc
class ChannelABC(object, metaclass=abc.ABCMeta):
"""A base class for all channel ABCs."""
@abc.abstractmethod
def start(self):
pass
@abc.abstractmethod
def stop(self):
pass
@abc.abstractmethod
def is_alive(self):
pass
class HBChannelABC(ChannelABC):
"""HBChannel ABC.
The docstrings for this class can be found in the base implementation:
`jupyter_client.channels.HBChannel`
"""
@abc.abstractproperty
def time_to_dead(self):
pass
@abc.abstractmethod
def pause(self):
pass
@abc.abstractmethod
def unpause(self):
pass
@abc.abstractmethod
def is_beating(self):
pass

View file

@ -0,0 +1,445 @@
"""Base class to manage the interaction with a running kernel"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from jupyter_client.channels import major_protocol_version
import zmq
from traitlets import (
Any, Instance, Type,
)
from .channelsabc import (ChannelABC, HBChannelABC)
from .clientabc import KernelClientABC
from .connect import ConnectionFileMixin
# some utilities to validate message structure, these might get moved elsewhere
# if they prove to have more generic utility
def validate_string_dict(dct):
"""Validate that the input is a dict with string keys and values.
Raises ValueError if not."""
for k,v in dct.items():
if not isinstance(k, str):
raise ValueError('key %r in dict must be a string' % k)
if not isinstance(v, str):
raise ValueError('value %r in dict must be a string' % v)
class KernelClient(ConnectionFileMixin):
"""Communicates with a single kernel on any host via zmq channels.
There are five channels associated with each kernel:
* shell: for request/reply calls to the kernel.
* iopub: for the kernel to publish results to frontends.
* hb: for monitoring the kernel's heartbeat.
* stdin: for frontends to reply to raw_input calls in the kernel.
* control: for kernel management calls to the kernel.
The messages that can be sent on these channels are exposed as methods of the
client (KernelClient.execute, complete, history, etc.). These methods only
send the message, they don't wait for a reply. To get results, use e.g.
:meth:`get_shell_msg` to fetch messages from the shell channel.
"""
# The PyZMQ Context to use for communication with the kernel.
context = Instance(zmq.Context)
def _context_default(self):
return zmq.Context()
# The classes to use for the various channels
shell_channel_class = Type(ChannelABC)
iopub_channel_class = Type(ChannelABC)
stdin_channel_class = Type(ChannelABC)
hb_channel_class = Type(HBChannelABC)
control_channel_class = Type(ChannelABC)
# Protected traits
_shell_channel = Any()
_iopub_channel = Any()
_stdin_channel = Any()
_hb_channel = Any()
_control_channel = Any()
# flag for whether execute requests should be allowed to call raw_input:
allow_stdin = True
#--------------------------------------------------------------------------
# Channel proxy methods
#--------------------------------------------------------------------------
def get_shell_msg(self, *args, **kwargs):
"""Get a message from the shell channel"""
return self.shell_channel.get_msg(*args, **kwargs)
def get_iopub_msg(self, *args, **kwargs):
"""Get a message from the iopub channel"""
return self.iopub_channel.get_msg(*args, **kwargs)
def get_stdin_msg(self, *args, **kwargs):
"""Get a message from the stdin channel"""
return self.stdin_channel.get_msg(*args, **kwargs)
def get_control_msg(self, *args, **kwargs):
"""Get a message from the control channel"""
return self.control_channel.get_msg(*args, **kwargs)
#--------------------------------------------------------------------------
# Channel management methods
#--------------------------------------------------------------------------
def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, control=True):
"""Starts the channels for this kernel.
This will create the channels if they do not exist and then start
them (their activity runs in a thread). If port numbers of 0 are
being used (random ports) then you must first call
:meth:`start_kernel`. If the channels have been stopped and you
call this, :class:`RuntimeError` will be raised.
"""
if shell:
self.shell_channel.start()
self.kernel_info()
if iopub:
self.iopub_channel.start()
if stdin:
self.stdin_channel.start()
self.allow_stdin = True
else:
self.allow_stdin = False
if hb:
self.hb_channel.start()
if control:
self.control_channel.start()
def stop_channels(self):
"""Stops all the running channels for this kernel.
This stops their event loops and joins their threads.
"""
if self.shell_channel.is_alive():
self.shell_channel.stop()
if self.iopub_channel.is_alive():
self.iopub_channel.stop()
if self.stdin_channel.is_alive():
self.stdin_channel.stop()
if self.hb_channel.is_alive():
self.hb_channel.stop()
if self.control_channel.is_alive():
self.control_channel.stop()
@property
def channels_running(self):
"""Are any of the channels created and running?"""
return (self.shell_channel.is_alive() or self.iopub_channel.is_alive() or
self.stdin_channel.is_alive() or self.hb_channel.is_alive() or
self.control_channel.is_alive())
ioloop = None # Overridden in subclasses that use pyzmq event loop
@property
def shell_channel(self):
"""Get the shell channel object for this kernel."""
if self._shell_channel is None:
url = self._make_url('shell')
self.log.debug("connecting shell channel to %s", url)
socket = self.connect_shell(identity=self.session.bsession)
self._shell_channel = self.shell_channel_class(
socket, self.session, self.ioloop
)
return self._shell_channel
@property
def iopub_channel(self):
"""Get the iopub channel object for this kernel."""
if self._iopub_channel is None:
url = self._make_url('iopub')
self.log.debug("connecting iopub channel to %s", url)
socket = self.connect_iopub()
self._iopub_channel = self.iopub_channel_class(
socket, self.session, self.ioloop
)
return self._iopub_channel
@property
def stdin_channel(self):
"""Get the stdin channel object for this kernel."""
if self._stdin_channel is None:
url = self._make_url('stdin')
self.log.debug("connecting stdin channel to %s", url)
socket = self.connect_stdin(identity=self.session.bsession)
self._stdin_channel = self.stdin_channel_class(
socket, self.session, self.ioloop
)
return self._stdin_channel
@property
def hb_channel(self):
"""Get the hb channel object for this kernel."""
if self._hb_channel is None:
url = self._make_url('hb')
self.log.debug("connecting heartbeat channel to %s", url)
self._hb_channel = self.hb_channel_class(
self.context, self.session, url
)
return self._hb_channel
@property
def control_channel(self):
"""Get the control channel object for this kernel."""
if self._control_channel is None:
url = self._make_url('control')
self.log.debug("connecting control channel to %s", url)
socket = self.connect_control(identity=self.session.bsession)
self._control_channel = self.control_channel_class(
socket, self.session, self.ioloop
)
return self._control_channel
def is_alive(self):
"""Is the kernel process still running?"""
from .manager import KernelManager
if isinstance(self.parent, KernelManager):
# This KernelClient was created by a KernelManager,
# we can ask the parent KernelManager:
return self.parent.is_alive()
if self._hb_channel is not None:
# We don't have access to the KernelManager,
# so we use the heartbeat.
return self._hb_channel.is_beating()
else:
# no heartbeat and not local, we can't tell if it's running,
# so naively return True
return True
# Methods to send specific messages on channels
def execute(self, code, silent=False, store_history=True,
user_expressions=None, allow_stdin=None, stop_on_error=True):
"""Execute code in the kernel.
Parameters
----------
code : str
A string of code in the kernel's language.
silent : bool, optional (default False)
If set, the kernel will execute the code as quietly possible, and
will force store_history to be False.
store_history : bool, optional (default True)
If set, the kernel will store command history. This is forced
to be False if silent is True.
user_expressions : dict, optional
A dict mapping names to expressions to be evaluated in the user's
dict. The expression values are returned as strings formatted using
:func:`repr`.
allow_stdin : bool, optional (default self.allow_stdin)
Flag for whether the kernel can send stdin requests to frontends.
Some frontends (e.g. the Notebook) do not support stdin requests.
If raw_input is called from code executed from such a frontend, a
StdinNotImplementedError will be raised.
stop_on_error: bool, optional (default True)
Flag whether to abort the execution queue, if an exception is encountered.
Returns
-------
The msg_id of the message sent.
"""
if user_expressions is None:
user_expressions = {}
if allow_stdin is None:
allow_stdin = self.allow_stdin
# Don't waste network traffic if inputs are invalid
if not isinstance(code, str):
raise ValueError('code %r must be a string' % code)
validate_string_dict(user_expressions)
# Create class for content/msg creation. Related to, but possibly
# not in Session.
content = dict(code=code, silent=silent, store_history=store_history,
user_expressions=user_expressions,
allow_stdin=allow_stdin, stop_on_error=stop_on_error
)
msg = self.session.msg('execute_request', content)
self.shell_channel.send(msg)
return msg['header']['msg_id']
def complete(self, code, cursor_pos=None):
"""Tab complete text in the kernel's namespace.
Parameters
----------
code : str
The context in which completion is requested.
Can be anything between a variable name and an entire cell.
cursor_pos : int, optional
The position of the cursor in the block of code where the completion was requested.
Default: ``len(code)``
Returns
-------
The msg_id of the message sent.
"""
if cursor_pos is None:
cursor_pos = len(code)
content = dict(code=code, cursor_pos=cursor_pos)
msg = self.session.msg('complete_request', content)
self.shell_channel.send(msg)
return msg['header']['msg_id']
def inspect(self, code, cursor_pos=None, detail_level=0):
"""Get metadata information about an object in the kernel's namespace.
It is up to the kernel to determine the appropriate object to inspect.
Parameters
----------
code : str
The context in which info is requested.
Can be anything between a variable name and an entire cell.
cursor_pos : int, optional
The position of the cursor in the block of code where the info was requested.
Default: ``len(code)``
detail_level : int, optional
The level of detail for the introspection (0-2)
Returns
-------
The msg_id of the message sent.
"""
if cursor_pos is None:
cursor_pos = len(code)
content = dict(code=code, cursor_pos=cursor_pos,
detail_level=detail_level,
)
msg = self.session.msg('inspect_request', content)
self.shell_channel.send(msg)
return msg['header']['msg_id']
def history(self, raw=True, output=False, hist_access_type='range', **kwargs):
"""Get entries from the kernel's history list.
Parameters
----------
raw : bool
If True, return the raw input.
output : bool
If True, then return the output as well.
hist_access_type : str
'range' (fill in session, start and stop params), 'tail' (fill in n)
or 'search' (fill in pattern param).
session : int
For a range request, the session from which to get lines. Session
numbers are positive integers; negative ones count back from the
current session.
start : int
The first line number of a history range.
stop : int
The final (excluded) line number of a history range.
n : int
The number of lines of history to get for a tail request.
pattern : str
The glob-syntax pattern for a search request.
Returns
-------
The ID of the message sent.
"""
if hist_access_type == 'range':
kwargs.setdefault('session', 0)
kwargs.setdefault('start', 0)
content = dict(raw=raw, output=output, hist_access_type=hist_access_type,
**kwargs)
msg = self.session.msg('history_request', content)
self.shell_channel.send(msg)
return msg['header']['msg_id']
def kernel_info(self):
"""Request kernel info
Returns
-------
The msg_id of the message sent
"""
msg = self.session.msg('kernel_info_request')
self.shell_channel.send(msg)
return msg['header']['msg_id']
def comm_info(self, target_name=None):
"""Request comm info
Returns
-------
The msg_id of the message sent
"""
if target_name is None:
content = {}
else:
content = dict(target_name=target_name)
msg = self.session.msg('comm_info_request', content)
self.shell_channel.send(msg)
return msg['header']['msg_id']
def _handle_kernel_info_reply(self, msg):
"""handle kernel info reply
sets protocol adaptation version. This might
be run from a separate thread.
"""
adapt_version = int(msg['content']['protocol_version'].split('.')[0])
if adapt_version != major_protocol_version:
self.session.adapt_version = adapt_version
def is_complete(self, code):
"""Ask the kernel whether some code is complete and ready to execute."""
msg = self.session.msg('is_complete_request', {'code': code})
self.shell_channel.send(msg)
return msg['header']['msg_id']
def input(self, string):
"""Send a string of raw input to the kernel.
This should only be called in response to the kernel sending an
``input_request`` message on the stdin channel.
"""
content = dict(value=string)
msg = self.session.msg('input_reply', content)
self.stdin_channel.send(msg)
def shutdown(self, restart=False):
"""Request an immediate kernel shutdown on the control channel.
Upon receipt of the (empty) reply, client code can safely assume that
the kernel has shut down and it's safe to forcefully terminate it if
it's still alive.
The kernel will send the reply via a function registered with Python's
atexit module, ensuring it's truly done as the kernel is done with all
normal operation.
Returns
-------
The msg_id of the message sent
"""
# Send quit message to kernel. Once we implement kernel-side setattr,
# this should probably be done that way, but for now this will do.
msg = self.session.msg('shutdown_request', {'restart':restart})
self.control_channel.send(msg)
return msg['header']['msg_id']
KernelClientABC.register(KernelClient)

View file

@ -0,0 +1,87 @@
"""Abstract base class for kernel clients"""
#-----------------------------------------------------------------------------
# Copyright (c) The Jupyter Development Team
#
# Distributed under the terms of the BSD License. The full license is in
# the file COPYING, distributed as part of this software.
#-----------------------------------------------------------------------------
#-----------------------------------------------------------------------------
# Imports
#-----------------------------------------------------------------------------
import abc
#-----------------------------------------------------------------------------
# Main kernel client class
#-----------------------------------------------------------------------------
class KernelClientABC(object, metaclass=abc.ABCMeta):
"""KernelManager ABC.
The docstrings for this class can be found in the base implementation:
`jupyter_client.client.KernelClient`
"""
@abc.abstractproperty
def kernel(self):
pass
@abc.abstractproperty
def shell_channel_class(self):
pass
@abc.abstractproperty
def iopub_channel_class(self):
pass
@abc.abstractproperty
def hb_channel_class(self):
pass
@abc.abstractproperty
def stdin_channel_class(self):
pass
@abc.abstractproperty
def control_channel_class(self):
pass
#--------------------------------------------------------------------------
# Channel management methods
#--------------------------------------------------------------------------
@abc.abstractmethod
def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, control=True):
pass
@abc.abstractmethod
def stop_channels(self):
pass
@abc.abstractproperty
def channels_running(self):
pass
@abc.abstractproperty
def shell_channel(self):
pass
@abc.abstractproperty
def iopub_channel(self):
pass
@abc.abstractproperty
def stdin_channel(self):
pass
@abc.abstractproperty
def hb_channel(self):
pass
@abc.abstractproperty
def control_channel(self):
pass

View file

@ -0,0 +1,580 @@
"""Utilities for connecting to jupyter kernels
The :class:`ConnectionFileMixin` class in this module encapsulates the logic
related to writing and reading connections files.
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import errno
import glob
import json
import os
import socket
import stat
import tempfile
import warnings
from getpass import getpass
from contextlib import contextmanager
import zmq
from traitlets.config import LoggingConfigurable
from .localinterfaces import localhost
from ipython_genutils.path import filefind
from ipython_genutils.py3compat import (
bytes_to_str, cast_bytes,
)
from traitlets import (
Bool, Integer, Unicode, CaselessStrEnum, Instance, Type, observe
)
from jupyter_core.paths import jupyter_data_dir, jupyter_runtime_dir, secure_write
def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, hb_port=0,
control_port=0, ip='', key=b'', transport='tcp',
signature_scheme='hmac-sha256', kernel_name=''
):
"""Generates a JSON config file, including the selection of random ports.
Parameters
----------
fname : unicode
The path to the file to write
shell_port : int, optional
The port to use for ROUTER (shell) channel.
iopub_port : int, optional
The port to use for the SUB channel.
stdin_port : int, optional
The port to use for the ROUTER (raw input) channel.
control_port : int, optional
The port to use for the ROUTER (control) channel.
hb_port : int, optional
The port to use for the heartbeat REP channel.
ip : str, optional
The ip address the kernel will bind to.
key : str, optional
The Session key used for message authentication.
signature_scheme : str, optional
The scheme used for message authentication.
This has the form 'digest-hash', where 'digest'
is the scheme used for digests, and 'hash' is the name of the hash function
used by the digest scheme.
Currently, 'hmac' is the only supported digest scheme,
and 'sha256' is the default hash function.
kernel_name : str, optional
The name of the kernel currently connected to.
"""
if not ip:
ip = localhost()
# default to temporary connector file
if not fname:
fd, fname = tempfile.mkstemp('.json')
os.close(fd)
# Find open ports as necessary.
ports = []
ports_needed = int(shell_port <= 0) + \
int(iopub_port <= 0) + \
int(stdin_port <= 0) + \
int(control_port <= 0) + \
int(hb_port <= 0)
if transport == 'tcp':
for i in range(ports_needed):
sock = socket.socket()
# struct.pack('ii', (0,0)) is 8 null bytes
sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b'\0' * 8)
sock.bind((ip, 0))
ports.append(sock)
for i, sock in enumerate(ports):
port = sock.getsockname()[1]
sock.close()
ports[i] = port
else:
N = 1
for i in range(ports_needed):
while os.path.exists("%s-%s" % (ip, str(N))):
N += 1
ports.append(N)
N += 1
if shell_port <= 0:
shell_port = ports.pop(0)
if iopub_port <= 0:
iopub_port = ports.pop(0)
if stdin_port <= 0:
stdin_port = ports.pop(0)
if control_port <= 0:
control_port = ports.pop(0)
if hb_port <= 0:
hb_port = ports.pop(0)
cfg = dict( shell_port=shell_port,
iopub_port=iopub_port,
stdin_port=stdin_port,
control_port=control_port,
hb_port=hb_port,
)
cfg['ip'] = ip
cfg['key'] = bytes_to_str(key)
cfg['transport'] = transport
cfg['signature_scheme'] = signature_scheme
cfg['kernel_name'] = kernel_name
# Only ever write this file as user read/writeable
# This would otherwise introduce a vulnerability as a file has secrets
# which would let others execute arbitrarily code as you
with secure_write(fname) as f:
f.write(json.dumps(cfg, indent=2))
if hasattr(stat, 'S_ISVTX'):
# set the sticky bit on the file and its parent directory
# to avoid periodic cleanup
paths = [fname]
runtime_dir = os.path.dirname(fname)
if runtime_dir:
paths.append(runtime_dir)
for path in paths:
permissions = os.stat(path).st_mode
new_permissions = permissions | stat.S_ISVTX
if new_permissions != permissions:
try:
os.chmod(path, new_permissions)
except OSError as e:
if e.errno == errno.EPERM and path == runtime_dir:
# suppress permission errors setting sticky bit on runtime_dir,
# which we may not own.
pass
else:
# failed to set sticky bit, probably not a big deal
warnings.warn(
"Failed to set sticky bit on %r: %s"
"\nProbably not a big deal, but runtime files may be cleaned up periodically." % (path, e),
RuntimeWarning,
)
return fname, cfg
def find_connection_file(filename='kernel-*.json', path=None, profile=None):
"""find a connection file, and return its absolute path.
The current working directory and optional search path
will be searched for the file if it is not given by absolute path.
If the argument does not match an existing file, it will be interpreted as a
fileglob, and the matching file in the profile's security dir with
the latest access time will be used.
Parameters
----------
filename : str
The connection file or fileglob to search for.
path : str or list of strs[optional]
Paths in which to search for connection files.
Returns
-------
str : The absolute path of the connection file.
"""
if profile is not None:
warnings.warn("Jupyter has no profiles. profile=%s has been ignored." % profile)
if path is None:
path = ['.', jupyter_runtime_dir()]
if isinstance(path, str):
path = [path]
try:
# first, try explicit name
return filefind(filename, path)
except IOError:
pass
# not found by full name
if '*' in filename:
# given as a glob already
pat = filename
else:
# accept any substring match
pat = '*%s*' % filename
matches = []
for p in path:
matches.extend(glob.glob(os.path.join(p, pat)))
matches = [ os.path.abspath(m) for m in matches ]
if not matches:
raise IOError("Could not find %r in %r" % (filename, path))
elif len(matches) == 1:
return matches[0]
else:
# get most recent match, by access time:
return sorted(matches, key=lambda f: os.stat(f).st_atime)[-1]
def tunnel_to_kernel(connection_info, sshserver, sshkey=None):
"""tunnel connections to a kernel via ssh
This will open five SSH tunnels from localhost on this machine to the
ports associated with the kernel. They can be either direct
localhost-localhost tunnels, or if an intermediate server is necessary,
the kernel must be listening on a public IP.
Parameters
----------
connection_info : dict or str (path)
Either a connection dict, or the path to a JSON connection file
sshserver : str
The ssh sever to use to tunnel to the kernel. Can be a full
`user@server:port` string. ssh config aliases are respected.
sshkey : str [optional]
Path to file containing ssh key to use for authentication.
Only necessary if your ssh config does not already associate
a keyfile with the host.
Returns
-------
(shell, iopub, stdin, hb, control) : ints
The five ports on localhost that have been forwarded to the kernel.
"""
from .ssh import tunnel
if isinstance(connection_info, str):
# it's a path, unpack it
with open(connection_info) as f:
connection_info = json.loads(f.read())
cf = connection_info
lports = tunnel.select_random_ports(5)
rports = cf['shell_port'], cf['iopub_port'], cf['stdin_port'], cf['hb_port'], cf['control_port']
remote_ip = cf['ip']
if tunnel.try_passwordless_ssh(sshserver, sshkey):
password=False
else:
password = getpass("SSH Password for %s: " % sshserver)
for lp,rp in zip(lports, rports):
tunnel.ssh_tunnel(lp, rp, sshserver, remote_ip, sshkey, password)
return tuple(lports)
#-----------------------------------------------------------------------------
# Mixin for classes that work with connection files
#-----------------------------------------------------------------------------
channel_socket_types = {
'hb' : zmq.REQ,
'shell' : zmq.DEALER,
'iopub' : zmq.SUB,
'stdin' : zmq.DEALER,
'control': zmq.DEALER,
}
port_names = [ "%s_port" % channel for channel in ('shell', 'stdin', 'iopub', 'hb', 'control')]
class ConnectionFileMixin(LoggingConfigurable):
"""Mixin for configurable classes that work with connection files"""
data_dir = Unicode()
def _data_dir_default(self):
return jupyter_data_dir()
# The addresses for the communication channels
connection_file = Unicode('', config=True,
help="""JSON file in which to store connection info [default: kernel-<pid>.json]
This file will contain the IP, ports, and authentication key needed to connect
clients to this kernel. By default, this file will be created in the security dir
of the current profile, but can be specified by absolute path.
""")
_connection_file_written = Bool(False)
transport = CaselessStrEnum(['tcp', 'ipc'], default_value='tcp', config=True)
kernel_name = Unicode()
ip = Unicode(config=True,
help="""Set the kernel\'s IP address [default localhost].
If the IP address is something other than localhost, then
Consoles on other machines will be able to connect
to the Kernel, so be careful!"""
)
def _ip_default(self):
if self.transport == 'ipc':
if self.connection_file:
return os.path.splitext(self.connection_file)[0] + '-ipc'
else:
return 'kernel-ipc'
else:
return localhost()
@observe('ip')
def _ip_changed(self, change):
if change['new'] == '*':
self.ip = '0.0.0.0'
# protected traits
hb_port = Integer(0, config=True,
help="set the heartbeat port [default: random]")
shell_port = Integer(0, config=True,
help="set the shell (ROUTER) port [default: random]")
iopub_port = Integer(0, config=True,
help="set the iopub (PUB) port [default: random]")
stdin_port = Integer(0, config=True,
help="set the stdin (ROUTER) port [default: random]")
control_port = Integer(0, config=True,
help="set the control (ROUTER) port [default: random]")
# names of the ports with random assignment
_random_port_names = None
@property
def ports(self):
return [ getattr(self, name) for name in port_names ]
# The Session to use for communication with the kernel.
session = Instance('jupyter_client.session.Session')
def _session_default(self):
from jupyter_client.session import Session
return Session(parent=self)
#--------------------------------------------------------------------------
# Connection and ipc file management
#--------------------------------------------------------------------------
def get_connection_info(self, session=False):
"""Return the connection info as a dict
Parameters
----------
session : bool [default: False]
If True, return our session object will be included in the connection info.
If False (default), the configuration parameters of our session object will be included,
rather than the session object itself.
Returns
-------
connect_info : dict
dictionary of connection information.
"""
info = dict(
transport=self.transport,
ip=self.ip,
shell_port=self.shell_port,
iopub_port=self.iopub_port,
stdin_port=self.stdin_port,
hb_port=self.hb_port,
control_port=self.control_port,
)
if session:
# add *clone* of my session,
# so that state such as digest_history is not shared.
info['session'] = self.session.clone()
else:
# add session info
info.update(dict(
signature_scheme=self.session.signature_scheme,
key=self.session.key,
))
return info
# factory for blocking clients
blocking_class = Type(klass=object, default_value='jupyter_client.BlockingKernelClient')
def blocking_client(self):
"""Make a blocking client connected to my kernel"""
info = self.get_connection_info()
info['parent'] = self
bc = self.blocking_class(**info)
bc.session.key = self.session.key
return bc
def cleanup_connection_file(self):
"""Cleanup connection file *if we wrote it*
Will not raise if the connection file was already removed somehow.
"""
if self._connection_file_written:
# cleanup connection files on full shutdown of kernel we started
self._connection_file_written = False
try:
os.remove(self.connection_file)
except (IOError, OSError, AttributeError):
pass
def cleanup_ipc_files(self):
"""Cleanup ipc files if we wrote them."""
if self.transport != 'ipc':
return
for port in self.ports:
ipcfile = "%s-%i" % (self.ip, port)
try:
os.remove(ipcfile)
except (IOError, OSError):
pass
def _record_random_port_names(self):
"""Records which of the ports are randomly assigned.
Records on first invocation, if the transport is tcp.
Does nothing on later invocations."""
if self.transport != 'tcp':
return
if self._random_port_names is not None:
return
self._random_port_names = []
for name in port_names:
if getattr(self, name) <= 0:
self._random_port_names.append(name)
def cleanup_random_ports(self):
"""Forgets randomly assigned port numbers and cleans up the connection file.
Does nothing if no port numbers have been randomly assigned.
In particular, does nothing unless the transport is tcp.
"""
if not self._random_port_names:
return
for name in self._random_port_names:
setattr(self, name, 0)
self.cleanup_connection_file()
def write_connection_file(self):
"""Write connection info to JSON dict in self.connection_file."""
if self._connection_file_written and os.path.exists(self.connection_file):
return
self.connection_file, cfg = write_connection_file(self.connection_file,
transport=self.transport, ip=self.ip, key=self.session.key,
stdin_port=self.stdin_port, iopub_port=self.iopub_port,
shell_port=self.shell_port, hb_port=self.hb_port,
control_port=self.control_port,
signature_scheme=self.session.signature_scheme,
kernel_name=self.kernel_name
)
# write_connection_file also sets default ports:
self._record_random_port_names()
for name in port_names:
setattr(self, name, cfg[name])
self._connection_file_written = True
def load_connection_file(self, connection_file=None):
"""Load connection info from JSON dict in self.connection_file.
Parameters
----------
connection_file: unicode, optional
Path to connection file to load.
If unspecified, use self.connection_file
"""
if connection_file is None:
connection_file = self.connection_file
self.log.debug("Loading connection file %s", connection_file)
with open(connection_file) as f:
info = json.load(f)
self.load_connection_info(info)
def load_connection_info(self, info):
"""Load connection info from a dict containing connection info.
Typically this data comes from a connection file
and is called by load_connection_file.
Parameters
----------
info: dict
Dictionary containing connection_info.
See the connection_file spec for details.
"""
self.transport = info.get('transport', self.transport)
self.ip = info.get('ip', self._ip_default())
self._record_random_port_names()
for name in port_names:
if getattr(self, name) == 0 and name in info:
# not overridden by config or cl_args
setattr(self, name, info[name])
if 'key' in info:
self.session.key = cast_bytes(info['key'])
if 'signature_scheme' in info:
self.session.signature_scheme = info['signature_scheme']
#--------------------------------------------------------------------------
# Creating connected sockets
#--------------------------------------------------------------------------
def _make_url(self, channel):
"""Make a ZeroMQ URL for a given channel."""
transport = self.transport
ip = self.ip
port = getattr(self, '%s_port' % channel)
if transport == 'tcp':
return "tcp://%s:%i" % (ip, port)
else:
return "%s://%s-%s" % (transport, ip, port)
def _create_connected_socket(self, channel, identity=None):
"""Create a zmq Socket and connect it to the kernel."""
url = self._make_url(channel)
socket_type = channel_socket_types[channel]
self.log.debug("Connecting to: %s" % url)
sock = self.context.socket(socket_type)
# set linger to 1s to prevent hangs at exit
sock.linger = 1000
if identity:
sock.identity = identity
sock.connect(url)
return sock
def connect_iopub(self, identity=None):
"""return zmq Socket connected to the IOPub channel"""
sock = self._create_connected_socket('iopub', identity=identity)
sock.setsockopt(zmq.SUBSCRIBE, b'')
return sock
def connect_shell(self, identity=None):
"""return zmq Socket connected to the Shell channel"""
return self._create_connected_socket('shell', identity=identity)
def connect_stdin(self, identity=None):
"""return zmq Socket connected to the StdIn channel"""
return self._create_connected_socket('stdin', identity=identity)
def connect_hb(self, identity=None):
"""return zmq Socket connected to the Heartbeat channel"""
return self._create_connected_socket('hb', identity=identity)
def connect_control(self, identity=None):
"""return zmq Socket connected to the Control channel"""
return self._create_connected_socket('control', identity=identity)
__all__ = [
'write_connection_file',
'find_connection_file',
'tunnel_to_kernel',
]

View file

@ -0,0 +1,347 @@
""" A minimal application base mixin for all ZMQ based IPython frontends.
This is not a complete console app, as subprocess will not be able to receive
input, there is no real readline support, among other limitations. This is a
refactoring of what used to be the IPython/qt/console/qtconsoleapp.py
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import atexit
import os
import signal
import sys
import uuid
import warnings
from traitlets.config.application import boolean_flag
from ipython_genutils.path import filefind
from traitlets import (
Dict, List, Unicode, CUnicode, CBool, Any, Type
)
from jupyter_core.application import base_flags, base_aliases
from .blocking import BlockingKernelClient
from .restarter import KernelRestarter
from . import KernelManager, tunnel_to_kernel, find_connection_file, connect
from .kernelspec import NoSuchKernel
from .session import Session
ConnectionFileMixin = connect.ConnectionFileMixin
from .localinterfaces import localhost
#-----------------------------------------------------------------------------
# Aliases and Flags
#-----------------------------------------------------------------------------
flags = {}
flags.update(base_flags)
# the flags that are specific to the frontend
# these must be scrubbed before being passed to the kernel,
# or it will raise an error on unrecognized flags
app_flags = {
'existing' : ({'JupyterConsoleApp' : {'existing' : 'kernel*.json'}},
"Connect to an existing kernel. If no argument specified, guess most recent"),
}
app_flags.update(boolean_flag(
'confirm-exit', 'JupyterConsoleApp.confirm_exit',
"""Set to display confirmation dialog on exit. You can always use 'exit' or
'quit', to force a direct exit without any confirmation. This can also
be set in the config file by setting
`c.JupyterConsoleApp.confirm_exit`.
""",
"""Don't prompt the user when exiting. This will terminate the kernel
if it is owned by the frontend, and leave it alive if it is external.
This can also be set in the config file by setting
`c.JupyterConsoleApp.confirm_exit`.
"""
))
flags.update(app_flags)
aliases = {}
aliases.update(base_aliases)
# also scrub aliases from the frontend
app_aliases = dict(
ip = 'JupyterConsoleApp.ip',
transport = 'JupyterConsoleApp.transport',
hb = 'JupyterConsoleApp.hb_port',
shell = 'JupyterConsoleApp.shell_port',
iopub = 'JupyterConsoleApp.iopub_port',
stdin = 'JupyterConsoleApp.stdin_port',
control = 'JupyterConsoleApp.control_port',
existing = 'JupyterConsoleApp.existing',
f = 'JupyterConsoleApp.connection_file',
kernel = 'JupyterConsoleApp.kernel_name',
ssh = 'JupyterConsoleApp.sshserver',
)
aliases.update(app_aliases)
#-----------------------------------------------------------------------------
# Classes
#-----------------------------------------------------------------------------
classes = [KernelManager, KernelRestarter, Session]
class JupyterConsoleApp(ConnectionFileMixin):
name = 'jupyter-console-mixin'
description = """
The Jupyter Console Mixin.
This class contains the common portions of console client (QtConsole,
ZMQ-based terminal console, etc). It is not a full console, in that
launched terminal subprocesses will not be able to accept input.
The Console using this mixing supports various extra features beyond
the single-process Terminal IPython shell, such as connecting to
existing kernel, via:
jupyter console <appname> --existing
as well as tunnel via SSH
"""
classes = classes
flags = Dict(flags)
aliases = Dict(aliases)
kernel_manager_class = Type(
default_value=KernelManager,
config=True,
help='The kernel manager class to use.'
)
kernel_client_class = BlockingKernelClient
kernel_argv = List(Unicode())
# connection info:
sshserver = Unicode('', config=True,
help="""The SSH server to use to connect to the kernel.""")
sshkey = Unicode('', config=True,
help="""Path to the ssh key to use for logging in to the ssh server.""")
def _connection_file_default(self):
return 'kernel-%i.json' % os.getpid()
existing = CUnicode('', config=True,
help="""Connect to an already running kernel""")
kernel_name = Unicode('python', config=True,
help="""The name of the default kernel to start.""")
confirm_exit = CBool(True, config=True,
help="""
Set to display confirmation dialog on exit. You can always use 'exit' or 'quit',
to force a direct exit without any confirmation.""",
)
def build_kernel_argv(self, argv=None):
"""build argv to be passed to kernel subprocess
Override in subclasses if any args should be passed to the kernel
"""
self.kernel_argv = self.extra_args
def init_connection_file(self):
"""find the connection file, and load the info if found.
The current working directory and the current profile's security
directory will be searched for the file if it is not given by
absolute path.
When attempting to connect to an existing kernel and the `--existing`
argument does not match an existing file, it will be interpreted as a
fileglob, and the matching file in the current profile's security dir
with the latest access time will be used.
After this method is called, self.connection_file contains the *full path*
to the connection file, never just its name.
"""
if self.existing:
try:
cf = find_connection_file(self.existing, ['.', self.runtime_dir])
except Exception:
self.log.critical("Could not find existing kernel connection file %s", self.existing)
self.exit(1)
self.log.debug("Connecting to existing kernel: %s" % cf)
self.connection_file = cf
else:
# not existing, check if we are going to write the file
# and ensure that self.connection_file is a full path, not just the shortname
try:
cf = find_connection_file(self.connection_file, [self.runtime_dir])
except Exception:
# file might not exist
if self.connection_file == os.path.basename(self.connection_file):
# just shortname, put it in security dir
cf = os.path.join(self.runtime_dir, self.connection_file)
else:
cf = self.connection_file
self.connection_file = cf
try:
self.connection_file = filefind(self.connection_file, ['.', self.runtime_dir])
except IOError:
self.log.debug("Connection File not found: %s", self.connection_file)
return
# should load_connection_file only be used for existing?
# as it is now, this allows reusing ports if an existing
# file is requested
try:
self.load_connection_file()
except Exception:
self.log.error("Failed to load connection file: %r", self.connection_file, exc_info=True)
self.exit(1)
def init_ssh(self):
"""set up ssh tunnels, if needed."""
if not self.existing or (not self.sshserver and not self.sshkey):
return
self.load_connection_file()
transport = self.transport
ip = self.ip
if transport != 'tcp':
self.log.error("Can only use ssh tunnels with TCP sockets, not %s", transport)
sys.exit(-1)
if self.sshkey and not self.sshserver:
# specifying just the key implies that we are connecting directly
self.sshserver = ip
ip = localhost()
# build connection dict for tunnels:
info = dict(ip=ip,
shell_port=self.shell_port,
iopub_port=self.iopub_port,
stdin_port=self.stdin_port,
hb_port=self.hb_port,
control_port=self.control_port
)
self.log.info("Forwarding connections to %s via %s"%(ip, self.sshserver))
# tunnels return a new set of ports, which will be on localhost:
self.ip = localhost()
try:
newports = tunnel_to_kernel(info, self.sshserver, self.sshkey)
except:
# even catch KeyboardInterrupt
self.log.error("Could not setup tunnels", exc_info=True)
self.exit(1)
self.shell_port, self.iopub_port, self.stdin_port, self.hb_port, self.control_port = newports
cf = self.connection_file
root, ext = os.path.splitext(cf)
self.connection_file = root + '-ssh' + ext
self.write_connection_file() # write the new connection file
self.log.info("To connect another client via this tunnel, use:")
self.log.info("--existing %s" % os.path.basename(self.connection_file))
def _new_connection_file(self):
cf = ''
while not cf:
# we don't need a 128b id to distinguish kernels, use more readable
# 48b node segment (12 hex chars). Users running more than 32k simultaneous
# kernels can subclass.
ident = str(uuid.uuid4()).split('-')[-1]
cf = os.path.join(self.runtime_dir, 'kernel-%s.json' % ident)
# only keep if it's actually new. Protect against unlikely collision
# in 48b random search space
cf = cf if not os.path.exists(cf) else ''
return cf
def init_kernel_manager(self):
# Don't let Qt or ZMQ swallow KeyboardInterupts.
if self.existing:
self.kernel_manager = None
return
signal.signal(signal.SIGINT, signal.SIG_DFL)
# Create a KernelManager and start a kernel.
try:
self.kernel_manager = self.kernel_manager_class(
ip=self.ip,
session=self.session,
transport=self.transport,
shell_port=self.shell_port,
iopub_port=self.iopub_port,
stdin_port=self.stdin_port,
hb_port=self.hb_port,
control_port=self.control_port,
connection_file=self.connection_file,
kernel_name=self.kernel_name,
parent=self,
data_dir=self.data_dir,
)
except NoSuchKernel:
self.log.critical("Could not find kernel %s", self.kernel_name)
self.exit(1)
self.kernel_manager.client_factory = self.kernel_client_class
kwargs = {}
kwargs['extra_arguments'] = self.kernel_argv
self.kernel_manager.start_kernel(**kwargs)
atexit.register(self.kernel_manager.cleanup_ipc_files)
if self.sshserver:
# ssh, write new connection file
self.kernel_manager.write_connection_file()
# in case KM defaults / ssh writing changes things:
km = self.kernel_manager
self.shell_port=km.shell_port
self.iopub_port=km.iopub_port
self.stdin_port=km.stdin_port
self.hb_port=km.hb_port
self.control_port=km.control_port
self.connection_file = km.connection_file
atexit.register(self.kernel_manager.cleanup_connection_file)
def init_kernel_client(self):
if self.kernel_manager is not None:
self.kernel_client = self.kernel_manager.client()
else:
self.kernel_client = self.kernel_client_class(
session=self.session,
ip=self.ip,
transport=self.transport,
shell_port=self.shell_port,
iopub_port=self.iopub_port,
stdin_port=self.stdin_port,
hb_port=self.hb_port,
control_port=self.control_port,
connection_file=self.connection_file,
parent=self,
)
self.kernel_client.start_channels()
def initialize(self, argv=None):
"""
Classes which mix this class in should call:
JupyterConsoleApp.initialize(self,argv)
"""
if self._dispatching:
return
self.init_connection_file()
self.init_ssh()
self.init_kernel_manager()
self.init_kernel_client()
class IPythonConsoleApp(JupyterConsoleApp):
def __init__(self, *args, **kwargs):
warnings.warn("IPythonConsoleApp is deprecated. Use JupyterConsoleApp")
super().__init__(*args, **kwargs)

View file

@ -0,0 +1,2 @@
from .manager import IOLoopKernelManager, AsyncIOLoopKernelManager
from .restarter import IOLoopKernelRestarter, AsyncIOLoopKernelRestarter

View file

@ -0,0 +1,102 @@
"""A kernel manager with a tornado IOLoop"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from tornado import ioloop
from zmq.eventloop.zmqstream import ZMQStream
from traitlets import (
Instance,
Type,
)
from jupyter_client.manager import KernelManager, AsyncKernelManager
from .restarter import IOLoopKernelRestarter, AsyncIOLoopKernelRestarter
def as_zmqstream(f):
def wrapped(self, *args, **kwargs):
socket = f(self, *args, **kwargs)
return ZMQStream(socket, self.loop)
return wrapped
class IOLoopKernelManager(KernelManager):
loop = Instance('tornado.ioloop.IOLoop')
def _loop_default(self):
return ioloop.IOLoop.current()
restarter_class = Type(
default_value=IOLoopKernelRestarter,
klass=IOLoopKernelRestarter,
help=(
'Type of KernelRestarter to use. '
'Must be a subclass of IOLoopKernelRestarter.\n'
'Override this to customize how kernel restarts are managed.'
),
config=True,
)
_restarter = Instance('jupyter_client.ioloop.IOLoopKernelRestarter', allow_none=True)
def start_restarter(self):
if self.autorestart and self.has_kernel:
if self._restarter is None:
self._restarter = self.restarter_class(
kernel_manager=self, loop=self.loop,
parent=self, log=self.log
)
self._restarter.start()
def stop_restarter(self):
if self.autorestart:
if self._restarter is not None:
self._restarter.stop()
connect_shell = as_zmqstream(KernelManager.connect_shell)
connect_control = as_zmqstream(KernelManager.connect_control)
connect_iopub = as_zmqstream(KernelManager.connect_iopub)
connect_stdin = as_zmqstream(KernelManager.connect_stdin)
connect_hb = as_zmqstream(KernelManager.connect_hb)
class AsyncIOLoopKernelManager(AsyncKernelManager):
loop = Instance('tornado.ioloop.IOLoop')
def _loop_default(self):
return ioloop.IOLoop.current()
restarter_class = Type(
default_value=AsyncIOLoopKernelRestarter,
klass=AsyncIOLoopKernelRestarter,
help=(
'Type of KernelRestarter to use. '
'Must be a subclass of AsyncIOLoopKernelManager.\n'
'Override this to customize how kernel restarts are managed.'
),
config=True,
)
_restarter = Instance('jupyter_client.ioloop.AsyncIOLoopKernelRestarter', allow_none=True)
def start_restarter(self):
if self.autorestart and self.has_kernel:
if self._restarter is None:
self._restarter = self.restarter_class(
kernel_manager=self, loop=self.loop,
parent=self, log=self.log
)
self._restarter.start()
def stop_restarter(self):
if self.autorestart:
if self._restarter is not None:
self._restarter.stop()
connect_shell = as_zmqstream(AsyncKernelManager.connect_shell)
connect_control = as_zmqstream(AsyncKernelManager.connect_control)
connect_iopub = as_zmqstream(AsyncKernelManager.connect_iopub)
connect_stdin = as_zmqstream(AsyncKernelManager.connect_stdin)
connect_hb = as_zmqstream(AsyncKernelManager.connect_hb)

View file

@ -0,0 +1,81 @@
"""A basic in process kernel monitor with autorestarting.
This watches a kernel's state using KernelManager.is_alive and auto
restarts the kernel if it dies.
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import warnings
from zmq.eventloop import ioloop
from jupyter_client.restarter import KernelRestarter
from traitlets import (
Instance,
)
class IOLoopKernelRestarter(KernelRestarter):
"""Monitor and autorestart a kernel."""
loop = Instance('tornado.ioloop.IOLoop')
def _loop_default(self):
warnings.warn("IOLoopKernelRestarter.loop is deprecated in jupyter-client 5.2",
DeprecationWarning, stacklevel=4,
)
return ioloop.IOLoop.current()
_pcallback = None
def start(self):
"""Start the polling of the kernel."""
if self._pcallback is None:
self._pcallback = ioloop.PeriodicCallback(
self.poll, 1000*self.time_to_dead,
)
self._pcallback.start()
def stop(self):
"""Stop the kernel polling."""
if self._pcallback is not None:
self._pcallback.stop()
self._pcallback = None
class AsyncIOLoopKernelRestarter(IOLoopKernelRestarter):
async def poll(self):
if self.debug:
self.log.debug('Polling kernel...')
is_alive = await self.kernel_manager.is_alive()
if not is_alive:
if self._restarting:
self._restart_count += 1
else:
self._restart_count = 1
if self._restart_count >= self.restart_limit:
self.log.warning("AsyncIOLoopKernelRestarter: restart failed")
self._fire_callbacks('dead')
self._restarting = False
self._restart_count = 0
self.stop()
else:
newports = self.random_ports_until_alive and self._initial_startup
self.log.info('AsyncIOLoopKernelRestarter: restarting kernel (%i/%i), %s random ports',
self._restart_count,
self.restart_limit,
'new' if newports else 'keep'
)
self._fire_callbacks('restart')
await self.kernel_manager.restart_kernel(now=True, newports=newports)
self._restarting = True
else:
if self._initial_startup:
self._initial_startup = False
if self._restarting:
self.log.debug("AsyncIOLoopKernelRestarter: restart apparently succeeded")
self._restarting = False

View file

@ -0,0 +1,91 @@
"""Utilities to manipulate JSON objects."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from datetime import datetime
import re
import warnings
from dateutil.parser import parse as _dateutil_parse
from dateutil.tz import tzlocal
next_attr_name = '__next__' # Not sure what downstream library uses this, but left it to be safe
#-----------------------------------------------------------------------------
# Globals and constants
#-----------------------------------------------------------------------------
# timestamp formats
ISO8601 = "%Y-%m-%dT%H:%M:%S.%f"
ISO8601_PAT = re.compile(r"^(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2})(\.\d{1,6})?(Z|([\+\-]\d{2}:?\d{2}))?$")
# holy crap, strptime is not threadsafe.
# Calling it once at import seems to help.
datetime.strptime("1", "%d")
#-----------------------------------------------------------------------------
# Classes and functions
#-----------------------------------------------------------------------------
def _ensure_tzinfo(dt):
"""Ensure a datetime object has tzinfo
If no tzinfo is present, add tzlocal
"""
if not dt.tzinfo:
# No more naïve datetime objects!
warnings.warn("Interpreting naive datetime as local %s. Please add timezone info to timestamps." % dt,
DeprecationWarning,
stacklevel=4)
dt = dt.replace(tzinfo=tzlocal())
return dt
def parse_date(s):
"""parse an ISO8601 date string
If it is None or not a valid ISO8601 timestamp,
it will be returned unmodified.
Otherwise, it will return a datetime object.
"""
if s is None:
return s
m = ISO8601_PAT.match(s)
if m:
dt = _dateutil_parse(s)
return _ensure_tzinfo(dt)
return s
def extract_dates(obj):
"""extract ISO8601 dates from unpacked JSON"""
if isinstance(obj, dict):
new_obj = {} # don't clobber
for k,v in obj.items():
new_obj[k] = extract_dates(v)
obj = new_obj
elif isinstance(obj, (list, tuple)):
obj = [ extract_dates(o) for o in obj ]
elif isinstance(obj, str):
obj = parse_date(obj)
return obj
def squash_dates(obj):
"""squash datetime objects into ISO8601 strings"""
if isinstance(obj, dict):
obj = dict(obj) # don't clobber
for k,v in obj.items():
obj[k] = squash_dates(v)
elif isinstance(obj, (list, tuple)):
obj = [ squash_dates(o) for o in obj ]
elif isinstance(obj, datetime):
obj = obj.isoformat()
return obj
def date_default(obj):
"""default function for packing datetime objects in JSON."""
if isinstance(obj, datetime):
obj = _ensure_tzinfo(obj)
return obj.isoformat().replace('+00:00', 'Z')
else:
raise TypeError("%r is not JSON serializable" % obj)

View file

@ -0,0 +1,83 @@
import os
import signal
import uuid
from jupyter_core.application import JupyterApp, base_flags
from tornado.ioloop import IOLoop
from traitlets import Unicode
from . import __version__
from .kernelspec import KernelSpecManager, NATIVE_KERNEL_NAME
from .manager import KernelManager
class KernelApp(JupyterApp):
"""Launch a kernel by name in a local subprocess.
"""
version = __version__
description = "Run a kernel locally in a subprocess"
classes = [KernelManager, KernelSpecManager]
aliases = {
'kernel': 'KernelApp.kernel_name',
'ip': 'KernelManager.ip',
}
flags = {'debug': base_flags['debug']}
kernel_name = Unicode(NATIVE_KERNEL_NAME,
help = 'The name of a kernel type to start'
).tag(config=True)
def initialize(self, argv=None):
super().initialize(argv)
cf_basename = 'kernel-%s.json' % uuid.uuid4()
self.config.setdefault('KernelManager', {}).setdefault('connection_file', os.path.join(self.runtime_dir, cf_basename))
self.km = KernelManager(kernel_name=self.kernel_name,
config=self.config)
self.loop = IOLoop.current()
self.loop.add_callback(self._record_started)
def setup_signals(self):
"""Shutdown on SIGTERM or SIGINT (Ctrl-C)"""
if os.name == 'nt':
return
def shutdown_handler(signo, frame):
self.loop.add_callback_from_signal(self.shutdown, signo)
for sig in [signal.SIGTERM, signal.SIGINT]:
signal.signal(sig, shutdown_handler)
def shutdown(self, signo):
self.log.info('Shutting down on signal %d' % signo)
self.km.shutdown_kernel()
self.loop.stop()
def log_connection_info(self):
cf = self.km.connection_file
self.log.info('Connection file: %s', cf)
self.log.info("To connect a client: --existing %s", os.path.basename(cf))
def _record_started(self):
"""For tests, create a file to indicate that we've started
Do not rely on this except in our own tests!
"""
fn = os.environ.get('JUPYTER_CLIENT_TEST_RECORD_STARTUP_PRIVATE')
if fn is not None:
with open(fn, 'wb'):
pass
def start(self):
self.log.info('Starting kernel %r', self.kernel_name)
try:
self.km.start_kernel()
self.log_connection_info()
self.setup_signals()
self.loop.start()
finally:
self.km.cleanup_resources()
main = KernelApp.launch_instance

View file

@ -0,0 +1,378 @@
"""Tools for managing kernel specs"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import errno
import io
import json
import os
import re
import shutil
import warnings
pjoin = os.path.join
from traitlets import (
HasTraits, List, Unicode, Dict, Set, Bool, Type, CaselessStrEnum
)
from traitlets.config import LoggingConfigurable
from jupyter_core.paths import jupyter_data_dir, jupyter_path, SYSTEM_JUPYTER_PATH
NATIVE_KERNEL_NAME = 'python3'
class KernelSpec(HasTraits):
argv = List()
display_name = Unicode()
language = Unicode()
env = Dict()
resource_dir = Unicode()
interrupt_mode = CaselessStrEnum(
['message', 'signal'], default_value='signal'
)
metadata = Dict()
@classmethod
def from_resource_dir(cls, resource_dir):
"""Create a KernelSpec object by reading kernel.json
Pass the path to the *directory* containing kernel.json.
"""
kernel_file = pjoin(resource_dir, 'kernel.json')
with io.open(kernel_file, 'r', encoding='utf-8') as f:
kernel_dict = json.load(f)
return cls(resource_dir=resource_dir, **kernel_dict)
def to_dict(self):
d = dict(argv=self.argv,
env=self.env,
display_name=self.display_name,
language=self.language,
interrupt_mode=self.interrupt_mode,
metadata=self.metadata,
)
return d
def to_json(self):
"""Serialise this kernelspec to a JSON object.
Returns a string.
"""
return json.dumps(self.to_dict())
_kernel_name_pat = re.compile(r'^[a-z0-9._\-]+$', re.IGNORECASE)
def _is_valid_kernel_name(name):
"""Check that a kernel name is valid."""
# quote is not unicode-safe on Python 2
return _kernel_name_pat.match(name)
_kernel_name_description = "Kernel names can only contain ASCII letters and numbers and these separators:" \
" - . _ (hyphen, period, and underscore)."
def _is_kernel_dir(path):
"""Is ``path`` a kernel directory?"""
return os.path.isdir(path) and os.path.isfile(pjoin(path, 'kernel.json'))
def _list_kernels_in(dir):
"""Return a mapping of kernel names to resource directories from dir.
If dir is None or does not exist, returns an empty dict.
"""
if dir is None or not os.path.isdir(dir):
return {}
kernels = {}
for f in os.listdir(dir):
path = pjoin(dir, f)
if not _is_kernel_dir(path):
continue
key = f.lower()
if not _is_valid_kernel_name(key):
warnings.warn("Invalid kernelspec directory name (%s): %s"
% (_kernel_name_description, path), stacklevel=3,
)
kernels[key] = path
return kernels
class NoSuchKernel(KeyError):
def __init__(self, name):
self.name = name
def __str__(self):
return "No such kernel named {}".format(self.name)
class KernelSpecManager(LoggingConfigurable):
kernel_spec_class = Type(KernelSpec, config=True,
help="""The kernel spec class. This is configurable to allow
subclassing of the KernelSpecManager for customized behavior.
"""
)
ensure_native_kernel = Bool(True, config=True,
help="""If there is no Python kernelspec registered and the IPython
kernel is available, ensure it is added to the spec list.
"""
)
data_dir = Unicode()
def _data_dir_default(self):
return jupyter_data_dir()
user_kernel_dir = Unicode()
def _user_kernel_dir_default(self):
return pjoin(self.data_dir, 'kernels')
whitelist = Set(config=True,
help="""Whitelist of allowed kernel names.
By default, all installed kernels are allowed.
"""
)
kernel_dirs = List(
help="List of kernel directories to search. Later ones take priority over earlier."
)
def _kernel_dirs_default(self):
dirs = jupyter_path('kernels')
# At some point, we should stop adding .ipython/kernels to the path,
# but the cost to keeping it is very small.
try:
from IPython.paths import get_ipython_dir
except ImportError:
try:
from IPython.utils.path import get_ipython_dir
except ImportError:
# no IPython, no ipython dir
get_ipython_dir = None
if get_ipython_dir is not None:
dirs.append(os.path.join(get_ipython_dir(), 'kernels'))
return dirs
def find_kernel_specs(self):
"""Returns a dict mapping kernel names to resource directories."""
d = {}
for kernel_dir in self.kernel_dirs:
kernels = _list_kernels_in(kernel_dir)
for kname, spec in kernels.items():
if kname not in d:
self.log.debug("Found kernel %s in %s", kname, kernel_dir)
d[kname] = spec
if self.ensure_native_kernel and NATIVE_KERNEL_NAME not in d:
try:
from ipykernel.kernelspec import RESOURCES
self.log.debug("Native kernel (%s) available from %s",
NATIVE_KERNEL_NAME, RESOURCES)
d[NATIVE_KERNEL_NAME] = RESOURCES
except ImportError:
self.log.warning("Native kernel (%s) is not available", NATIVE_KERNEL_NAME)
if self.whitelist:
# filter if there's a whitelist
d = {name:spec for name,spec in d.items() if name in self.whitelist}
return d
# TODO: Caching?
def _get_kernel_spec_by_name(self, kernel_name, resource_dir):
""" Returns a :class:`KernelSpec` instance for a given kernel_name
and resource_dir.
"""
if kernel_name == NATIVE_KERNEL_NAME:
try:
from ipykernel.kernelspec import RESOURCES, get_kernel_dict
except ImportError:
# It should be impossible to reach this, but let's play it safe
pass
else:
if resource_dir == RESOURCES:
return self.kernel_spec_class(resource_dir=resource_dir, **get_kernel_dict())
return self.kernel_spec_class.from_resource_dir(resource_dir)
def _find_spec_directory(self, kernel_name):
"""Find the resource directory of a named kernel spec"""
for kernel_dir in self.kernel_dirs:
try:
files = os.listdir(kernel_dir)
except OSError as e:
if e.errno in (errno.ENOTDIR, errno.ENOENT):
continue
raise
for f in files:
path = pjoin(kernel_dir, f)
if f.lower() == kernel_name and _is_kernel_dir(path):
return path
if kernel_name == NATIVE_KERNEL_NAME:
try:
from ipykernel.kernelspec import RESOURCES
except ImportError:
pass
else:
return RESOURCES
def get_kernel_spec(self, kernel_name):
"""Returns a :class:`KernelSpec` instance for the given kernel_name.
Raises :exc:`NoSuchKernel` if the given kernel name is not found.
"""
if not _is_valid_kernel_name(kernel_name):
self.log.warning("Kernelspec name %r is invalid: %s", kernel_name,
_kernel_name_description)
resource_dir = self._find_spec_directory(kernel_name.lower())
if resource_dir is None:
raise NoSuchKernel(kernel_name)
return self._get_kernel_spec_by_name(kernel_name, resource_dir)
def get_all_specs(self):
"""Returns a dict mapping kernel names to kernelspecs.
Returns a dict of the form::
{
'kernel_name': {
'resource_dir': '/path/to/kernel_name',
'spec': {"the spec itself": ...}
},
...
}
"""
d = self.find_kernel_specs()
res = {}
for kname, resource_dir in d.items():
try:
if self.__class__ is KernelSpecManager:
spec = self._get_kernel_spec_by_name(kname, resource_dir)
else:
# avoid calling private methods in subclasses,
# which may have overridden find_kernel_specs
# and get_kernel_spec, but not the newer get_all_specs
spec = self.get_kernel_spec(kname)
res[kname] = {
"resource_dir": resource_dir,
"spec": spec.to_dict()
}
except Exception:
self.log.warning("Error loading kernelspec %r", kname, exc_info=True)
return res
def remove_kernel_spec(self, name):
"""Remove a kernel spec directory by name.
Returns the path that was deleted.
"""
save_native = self.ensure_native_kernel
try:
self.ensure_native_kernel = False
specs = self.find_kernel_specs()
finally:
self.ensure_native_kernel = save_native
spec_dir = specs[name]
self.log.debug("Removing %s", spec_dir)
if os.path.islink(spec_dir):
os.remove(spec_dir)
else:
shutil.rmtree(spec_dir)
return spec_dir
def _get_destination_dir(self, kernel_name, user=False, prefix=None):
if user:
return os.path.join(self.user_kernel_dir, kernel_name)
elif prefix:
return os.path.join(os.path.abspath(prefix), 'share', 'jupyter', 'kernels', kernel_name)
else:
return os.path.join(SYSTEM_JUPYTER_PATH[0], 'kernels', kernel_name)
def install_kernel_spec(self, source_dir, kernel_name=None, user=False,
replace=None, prefix=None):
"""Install a kernel spec by copying its directory.
If ``kernel_name`` is not given, the basename of ``source_dir`` will
be used.
If ``user`` is False, it will attempt to install into the systemwide
kernel registry. If the process does not have appropriate permissions,
an :exc:`OSError` will be raised.
If ``prefix`` is given, the kernelspec will be installed to
PREFIX/share/jupyter/kernels/KERNEL_NAME. This can be sys.prefix
for installation inside virtual or conda envs.
"""
source_dir = source_dir.rstrip('/\\')
if not kernel_name:
kernel_name = os.path.basename(source_dir)
kernel_name = kernel_name.lower()
if not _is_valid_kernel_name(kernel_name):
raise ValueError("Invalid kernel name %r. %s" % (kernel_name, _kernel_name_description))
if user and prefix:
raise ValueError("Can't specify both user and prefix. Please choose one or the other.")
if replace is not None:
warnings.warn(
"replace is ignored. Installing a kernelspec always replaces an existing installation",
DeprecationWarning,
stacklevel=2,
)
destination = self._get_destination_dir(kernel_name, user=user, prefix=prefix)
self.log.debug('Installing kernelspec in %s', destination)
kernel_dir = os.path.dirname(destination)
if kernel_dir not in self.kernel_dirs:
self.log.warning("Installing to %s, which is not in %s. The kernelspec may not be found.",
kernel_dir, self.kernel_dirs,
)
if os.path.isdir(destination):
self.log.info('Removing existing kernelspec in %s', destination)
shutil.rmtree(destination)
shutil.copytree(source_dir, destination)
self.log.info('Installed kernelspec %s in %s', kernel_name, destination)
return destination
def install_native_kernel_spec(self, user=False):
"""DEPRECATED: Use ipykernel.kenelspec.install"""
warnings.warn("install_native_kernel_spec is deprecated."
" Use ipykernel.kernelspec import install.", stacklevel=2)
from ipykernel.kernelspec import install
install(self, user=user)
def find_kernel_specs():
"""Returns a dict mapping kernel names to resource directories."""
return KernelSpecManager().find_kernel_specs()
def get_kernel_spec(kernel_name):
"""Returns a :class:`KernelSpec` instance for the given kernel_name.
Raises KeyError if the given kernel name is not found.
"""
return KernelSpecManager().get_kernel_spec(kernel_name)
def install_kernel_spec(source_dir, kernel_name=None, user=False, replace=False,
prefix=None):
return KernelSpecManager().install_kernel_spec(source_dir, kernel_name,
user, replace, prefix)
install_kernel_spec.__doc__ = KernelSpecManager.install_kernel_spec.__doc__
def install_native_kernel_spec(user=False):
return KernelSpecManager().install_native_kernel_spec(user=user)
install_native_kernel_spec.__doc__ = KernelSpecManager.install_native_kernel_spec.__doc__

View file

@ -0,0 +1,270 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import errno
import os.path
import sys
import json
from traitlets.config.application import Application
from jupyter_core.application import (
JupyterApp, base_flags, base_aliases
)
from traitlets import Instance, Dict, Unicode, Bool, List
from . import __version__
from .kernelspec import KernelSpecManager
class ListKernelSpecs(JupyterApp):
version = __version__
description = """List installed kernel specifications."""
kernel_spec_manager = Instance(KernelSpecManager)
json_output = Bool(False, help='output spec name and location as machine-readable json.',
config=True)
flags = {'json': ({'ListKernelSpecs': {'json_output': True}},
"output spec name and location as machine-readable json."),
'debug': base_flags['debug'],
}
def _kernel_spec_manager_default(self):
return KernelSpecManager(parent=self, data_dir=self.data_dir)
def start(self):
paths = self.kernel_spec_manager.find_kernel_specs()
specs = self.kernel_spec_manager.get_all_specs()
if not self.json_output:
if not specs:
print("No kernels available")
return
# pad to width of longest kernel name
name_len = len(sorted(paths, key=lambda name: len(name))[-1])
def path_key(item):
"""sort key function for Jupyter path priority"""
path = item[1]
for idx, prefix in enumerate(self.jupyter_path):
if path.startswith(prefix):
return (idx, path)
# not in jupyter path, artificially added to the front
return (-1, path)
print("Available kernels:")
for kernelname, path in sorted(paths.items(), key=path_key):
print(" %s %s" % (kernelname.ljust(name_len), path))
else:
print(json.dumps({
'kernelspecs': specs
}, indent=2))
class InstallKernelSpec(JupyterApp):
version = __version__
description = """Install a kernel specification directory.
Given a SOURCE DIRECTORY containing a kernel spec,
jupyter will copy that directory into one of the Jupyter kernel directories.
The default is to install kernelspecs for all users.
`--user` can be specified to install a kernel only for the current user.
"""
examples = """
jupyter kernelspec install /path/to/my_kernel --user
"""
usage = "jupyter kernelspec install SOURCE_DIR [--options]"
kernel_spec_manager = Instance(KernelSpecManager)
def _kernel_spec_manager_default(self):
return KernelSpecManager(data_dir=self.data_dir)
sourcedir = Unicode()
kernel_name = Unicode("", config=True,
help="Install the kernel spec with this name"
)
def _kernel_name_default(self):
return os.path.basename(self.sourcedir)
user = Bool(False, config=True,
help="""
Try to install the kernel spec to the per-user directory instead of
the system or environment directory.
"""
)
prefix = Unicode('', config=True,
help="""Specify a prefix to install to, e.g. an env.
The kernelspec will be installed in PREFIX/share/jupyter/kernels/
"""
)
replace = Bool(False, config=True,
help="Replace any existing kernel spec with this name."
)
aliases = {
'name': 'InstallKernelSpec.kernel_name',
'prefix': 'InstallKernelSpec.prefix',
}
aliases.update(base_aliases)
flags = {'user': ({'InstallKernelSpec': {'user': True}},
"Install to the per-user kernel registry"),
'replace': ({'InstallKernelSpec': {'replace': True}},
"Replace any existing kernel spec with this name."),
'sys-prefix': ({'InstallKernelSpec': {'prefix': sys.prefix}},
"Install to Python's sys.prefix. Useful in conda/virtual environments."),
'debug': base_flags['debug'],
}
def parse_command_line(self, argv):
super().parse_command_line(argv)
# accept positional arg as profile name
if self.extra_args:
self.sourcedir = self.extra_args[0]
else:
print("No source directory specified.")
self.exit(1)
def start(self):
if self.user and self.prefix:
self.exit("Can't specify both user and prefix. Please choose one or the other.")
try:
self.kernel_spec_manager.install_kernel_spec(self.sourcedir,
kernel_name=self.kernel_name,
user=self.user,
prefix=self.prefix,
replace=self.replace,
)
except OSError as e:
if e.errno == errno.EACCES:
print(e, file=sys.stderr)
if not self.user:
print("Perhaps you want to install with `sudo` or `--user`?", file=sys.stderr)
self.exit(1)
elif e.errno == errno.EEXIST:
print("A kernel spec is already present at %s" % e.filename, file=sys.stderr)
self.exit(1)
raise
class RemoveKernelSpec(JupyterApp):
version = __version__
description = """Remove one or more Jupyter kernelspecs by name."""
examples = """jupyter kernelspec remove python2 [my_kernel ...]"""
force = Bool(False, config=True,
help="""Force removal, don't prompt for confirmation."""
)
spec_names = List(Unicode())
kernel_spec_manager = Instance(KernelSpecManager)
def _kernel_spec_manager_default(self):
return KernelSpecManager(data_dir=self.data_dir, parent=self)
flags = {
'f': ({'RemoveKernelSpec': {'force': True}}, force.get_metadata('help')),
}
flags.update(JupyterApp.flags)
def parse_command_line(self, argv):
super().parse_command_line(argv)
# accept positional arg as profile name
if self.extra_args:
self.spec_names = sorted(set(self.extra_args)) # remove duplicates
else:
self.exit("No kernelspec specified.")
def start(self):
self.kernel_spec_manager.ensure_native_kernel = False
spec_paths = self.kernel_spec_manager.find_kernel_specs()
missing = set(self.spec_names).difference(set(spec_paths))
if missing:
self.exit("Couldn't find kernel spec(s): %s" % ', '.join(missing))
if not self.force:
print("Kernel specs to remove:")
for name in self.spec_names:
print(" %s\t%s" % (name.ljust(20), spec_paths[name]))
answer = input("Remove %i kernel specs [y/N]: " % len(self.spec_names))
if not answer.lower().startswith('y'):
return
for kernel_name in self.spec_names:
try:
path = self.kernel_spec_manager.remove_kernel_spec(kernel_name)
except OSError as e:
if e.errno == errno.EACCES:
print(e, file=sys.stderr)
print("Perhaps you want sudo?", file=sys.stderr)
self.exit(1)
else:
raise
self.log.info("Removed %s", path)
class InstallNativeKernelSpec(JupyterApp):
version = __version__
description = """[DEPRECATED] Install the IPython kernel spec directory for this Python."""
kernel_spec_manager = Instance(KernelSpecManager)
def _kernel_spec_manager_default(self):
return KernelSpecManager(data_dir=self.data_dir)
user = Bool(False, config=True,
help="""
Try to install the kernel spec to the per-user directory instead of
the system or environment directory.
"""
)
flags = {'user': ({'InstallNativeKernelSpec': {'user': True}},
"Install to the per-user kernel registry"),
'debug': base_flags['debug'],
}
def start(self):
self.log.warning("`jupyter kernelspec install-self` is DEPRECATED as of 4.0."
" You probably want `ipython kernel install` to install the IPython kernelspec.")
try:
from ipykernel import kernelspec
except ImportError:
print("ipykernel not available, can't install its spec.", file=sys.stderr)
self.exit(1)
try:
kernelspec.install(self.kernel_spec_manager, user=self.user)
except OSError as e:
if e.errno == errno.EACCES:
print(e, file=sys.stderr)
if not self.user:
print("Perhaps you want to install with `sudo` or `--user`?", file=sys.stderr)
self.exit(1)
self.exit(e)
class KernelSpecApp(Application):
version = __version__
name = "jupyter kernelspec"
description = """Manage Jupyter kernel specifications."""
subcommands = Dict({
'list': (ListKernelSpecs, ListKernelSpecs.description.splitlines()[0]),
'install': (InstallKernelSpec, InstallKernelSpec.description.splitlines()[0]),
'uninstall': (RemoveKernelSpec, "Alias for remove"),
'remove': (RemoveKernelSpec, RemoveKernelSpec.description.splitlines()[0]),
'install-self': (InstallNativeKernelSpec, InstallNativeKernelSpec.description.splitlines()[0]),
})
aliases = {}
flags = {}
def start(self):
if self.subapp is None:
print("No subcommand specified. Must specify one of: %s"% list(self.subcommands))
print()
self.print_description()
self.print_subcommands()
self.exit(1)
else:
return self.subapp.start()
if __name__ == '__main__':
KernelSpecApp.launch_instance()

View file

@ -0,0 +1,158 @@
"""Utilities for launching kernels"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import os
import sys
from subprocess import Popen, PIPE
from ipython_genutils.encoding import getdefaultencoding
from traitlets.log import get_logger
def launch_kernel(cmd, stdin=None, stdout=None, stderr=None, env=None,
independent=False, cwd=None, **kw):
""" Launches a localhost kernel, binding to the specified ports.
Parameters
----------
cmd : Popen list,
A string of Python code that imports and executes a kernel entry point.
stdin, stdout, stderr : optional (default None)
Standards streams, as defined in subprocess.Popen.
env: dict, optional
Environment variables passed to the kernel
independent : bool, optional (default False)
If set, the kernel process is guaranteed to survive if this process
dies. If not set, an effort is made to ensure that the kernel is killed
when this process dies. Note that in this case it is still good practice
to kill kernels manually before exiting.
cwd : path, optional
The working dir of the kernel process (default: cwd of this process).
**kw: optional
Additional arguments for Popen
Returns
-------
Popen instance for the kernel subprocess
"""
# Popen will fail (sometimes with a deadlock) if stdin, stdout, and stderr
# are invalid. Unfortunately, there is in general no way to detect whether
# they are valid. The following two blocks redirect them to (temporary)
# pipes in certain important cases.
# If this process has been backgrounded, our stdin is invalid. Since there
# is no compelling reason for the kernel to inherit our stdin anyway, we'll
# place this one safe and always redirect.
redirect_in = True
_stdin = PIPE if stdin is None else stdin
# If this process in running on pythonw, we know that stdin, stdout, and
# stderr are all invalid.
redirect_out = sys.executable.endswith('pythonw.exe')
if redirect_out:
blackhole = open(os.devnull, 'w')
_stdout = blackhole if stdout is None else stdout
_stderr = blackhole if stderr is None else stderr
else:
_stdout, _stderr = stdout, stderr
env = env if (env is not None) else os.environ.copy()
encoding = getdefaultencoding(prefer_stream=False)
kwargs = kw.copy()
main_args = dict(
stdin=_stdin,
stdout=_stdout,
stderr=_stderr,
cwd=cwd,
env=env,
)
kwargs.update(main_args)
# Spawn a kernel.
if sys.platform == 'win32':
if cwd:
kwargs['cwd'] = cwd
from .win_interrupt import create_interrupt_event
# Create a Win32 event for interrupting the kernel
# and store it in an environment variable.
interrupt_event = create_interrupt_event()
env["JPY_INTERRUPT_EVENT"] = str(interrupt_event)
# deprecated old env name:
env["IPY_INTERRUPT_EVENT"] = env["JPY_INTERRUPT_EVENT"]
try:
from _winapi import DuplicateHandle, GetCurrentProcess, \
DUPLICATE_SAME_ACCESS, CREATE_NEW_PROCESS_GROUP
except:
from _subprocess import DuplicateHandle, GetCurrentProcess, \
DUPLICATE_SAME_ACCESS, CREATE_NEW_PROCESS_GROUP
# create a handle on the parent to be inherited
if independent:
kwargs['creationflags'] = CREATE_NEW_PROCESS_GROUP
else:
pid = GetCurrentProcess()
handle = DuplicateHandle(pid, pid, pid, 0,
True, # Inheritable by new processes.
DUPLICATE_SAME_ACCESS)
env['JPY_PARENT_PID'] = str(int(handle))
# Prevent creating new console window on pythonw
if redirect_out:
kwargs['creationflags'] = kwargs.setdefault('creationflags', 0) | 0x08000000 # CREATE_NO_WINDOW
# Avoid closing the above parent and interrupt handles.
# close_fds is True by default on Python >=3.7
# or when no stream is captured on Python <3.7
# (we always capture stdin, so this is already False by default on <3.7)
kwargs['close_fds'] = False
else:
# Create a new session.
# This makes it easier to interrupt the kernel,
# because we want to interrupt the whole process group.
# We don't use setpgrp, which is known to cause problems for kernels starting
# certain interactive subprocesses, such as bash -i.
kwargs['start_new_session'] = True
if not independent:
env['JPY_PARENT_PID'] = str(os.getpid())
try:
proc = Popen(cmd, **kwargs)
except Exception as exc:
msg = (
"Failed to run command:\n{}\n"
" PATH={!r}\n"
" with kwargs:\n{!r}\n"
)
# exclude environment variables,
# which may contain access tokens and the like.
without_env = {key:value for key, value in kwargs.items() if key != 'env'}
msg = msg.format(cmd, env.get('PATH', os.defpath), without_env)
get_logger().error(msg)
raise
if sys.platform == 'win32':
# Attach the interrupt event to the Popen objet so it can be used later.
proc.win32_interrupt_event = interrupt_event
# Clean up pipes created to work around Popen bug.
if redirect_in:
if stdin is None:
proc.stdin.close()
return proc
__all__ = [
'launch_kernel',
]

View file

@ -0,0 +1,274 @@
"""Utilities for identifying local IP addresses."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import os
import re
import socket
import subprocess
from subprocess import Popen, PIPE
from warnings import warn
LOCAL_IPS = []
PUBLIC_IPS = []
LOCALHOST = ''
def _uniq_stable(elems):
"""uniq_stable(elems) -> list
Return from an iterable, a list of all the unique elements in the input,
maintaining the order in which they first appear.
From ipython_genutils.data
"""
seen = set()
return [x for x in elems if x not in seen and not seen.add(x)]
def _get_output(cmd):
"""Get output of a command, raising IOError if it fails"""
startupinfo = None
if os.name == 'nt':
startupinfo = subprocess.STARTUPINFO()
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
p = Popen(cmd, stdout=PIPE, stderr=PIPE, startupinfo=startupinfo)
stdout, stderr = p.communicate()
if p.returncode:
raise IOError("Failed to run %s: %s" % (cmd, stderr.decode('utf8', 'replace')))
return stdout.decode('utf8', 'replace')
def _only_once(f):
"""decorator to only run a function once"""
f.called = False
def wrapped(**kwargs):
if f.called:
return
ret = f(**kwargs)
f.called = True
return ret
return wrapped
def _requires_ips(f):
"""decorator to ensure load_ips has been run before f"""
def ips_loaded(*args, **kwargs):
_load_ips()
return f(*args, **kwargs)
return ips_loaded
# subprocess-parsing ip finders
class NoIPAddresses(Exception):
pass
def _populate_from_list(addrs):
"""populate local and public IPs from flat list of all IPs"""
if not addrs:
raise NoIPAddresses
global LOCALHOST
public_ips = []
local_ips = []
for ip in addrs:
local_ips.append(ip)
if not ip.startswith('127.'):
public_ips.append(ip)
elif not LOCALHOST:
LOCALHOST = ip
if not LOCALHOST:
LOCALHOST = '127.0.0.1'
local_ips.insert(0, LOCALHOST)
local_ips.extend(['0.0.0.0', ''])
LOCAL_IPS[:] = _uniq_stable(local_ips)
PUBLIC_IPS[:] = _uniq_stable(public_ips)
_ifconfig_ipv4_pat = re.compile(r'inet\b.*?(\d+\.\d+\.\d+\.\d+)', re.IGNORECASE)
def _load_ips_ifconfig():
"""load ip addresses from `ifconfig` output (posix)"""
try:
out = _get_output('ifconfig')
except (IOError, OSError):
# no ifconfig, it's usually in /sbin and /sbin is not on everyone's PATH
out = _get_output('/sbin/ifconfig')
lines = out.splitlines()
addrs = []
for line in lines:
m = _ifconfig_ipv4_pat.match(line.strip())
if m:
addrs.append(m.group(1))
_populate_from_list(addrs)
def _load_ips_ip():
"""load ip addresses from `ip addr` output (Linux)"""
out = _get_output(['ip', '-f', 'inet', 'addr'])
lines = out.splitlines()
addrs = []
for line in lines:
blocks = line.lower().split()
if (len(blocks) >= 2) and (blocks[0] == 'inet'):
addrs.append(blocks[1].split('/')[0])
_populate_from_list(addrs)
_ipconfig_ipv4_pat = re.compile(r'ipv4.*?(\d+\.\d+\.\d+\.\d+)$', re.IGNORECASE)
def _load_ips_ipconfig():
"""load ip addresses from `ipconfig` output (Windows)"""
out = _get_output('ipconfig')
lines = out.splitlines()
addrs = []
for line in lines:
m = _ipconfig_ipv4_pat.match(line.strip())
if m:
addrs.append(m.group(1))
_populate_from_list(addrs)
def _load_ips_netifaces():
"""load ip addresses with netifaces"""
import netifaces
global LOCALHOST
local_ips = []
public_ips = []
# list of iface names, 'lo0', 'eth0', etc.
for iface in netifaces.interfaces():
# list of ipv4 addrinfo dicts
ipv4s = netifaces.ifaddresses(iface).get(netifaces.AF_INET, [])
for entry in ipv4s:
addr = entry.get('addr')
if not addr:
continue
if not (iface.startswith('lo') or addr.startswith('127.')):
public_ips.append(addr)
elif not LOCALHOST:
LOCALHOST = addr
local_ips.append(addr)
if not LOCALHOST:
# we never found a loopback interface (can this ever happen?), assume common default
LOCALHOST = '127.0.0.1'
local_ips.insert(0, LOCALHOST)
local_ips.extend(['0.0.0.0', ''])
LOCAL_IPS[:] = _uniq_stable(local_ips)
PUBLIC_IPS[:] = _uniq_stable(public_ips)
def _load_ips_gethostbyname():
"""load ip addresses with socket.gethostbyname_ex
This can be slow.
"""
global LOCALHOST
try:
LOCAL_IPS[:] = socket.gethostbyname_ex('localhost')[2]
except socket.error:
# assume common default
LOCAL_IPS[:] = ['127.0.0.1']
try:
hostname = socket.gethostname()
PUBLIC_IPS[:] = socket.gethostbyname_ex(hostname)[2]
# try hostname.local, in case hostname has been short-circuited to loopback
if not hostname.endswith('.local') and all(ip.startswith('127') for ip in PUBLIC_IPS):
PUBLIC_IPS[:] = socket.gethostbyname_ex(socket.gethostname() + '.local')[2]
except socket.error:
pass
finally:
PUBLIC_IPS[:] = _uniq_stable(PUBLIC_IPS)
LOCAL_IPS.extend(PUBLIC_IPS)
# include all-interface aliases: 0.0.0.0 and ''
LOCAL_IPS.extend(['0.0.0.0', ''])
LOCAL_IPS[:] = _uniq_stable(LOCAL_IPS)
LOCALHOST = LOCAL_IPS[0]
def _load_ips_dumb():
"""Fallback in case of unexpected failure"""
global LOCALHOST
LOCALHOST = '127.0.0.1'
LOCAL_IPS[:] = [LOCALHOST, '0.0.0.0', '']
PUBLIC_IPS[:] = []
@_only_once
def _load_ips(suppress_exceptions=True):
"""load the IPs that point to this machine
This function will only ever be called once.
It will use netifaces to do it quickly if available.
Then it will fallback on parsing the output of ifconfig / ip addr / ipconfig, as appropriate.
Finally, it will fallback on socket.gethostbyname_ex, which can be slow.
"""
try:
# first priority, use netifaces
try:
return _load_ips_netifaces()
except ImportError:
pass
# second priority, parse subprocess output (how reliable is this?)
if os.name == 'nt':
try:
return _load_ips_ipconfig()
except (IOError, NoIPAddresses):
pass
else:
try:
return _load_ips_ip()
except (IOError, OSError, NoIPAddresses):
pass
try:
return _load_ips_ifconfig()
except (IOError, OSError, NoIPAddresses):
pass
# lowest priority, use gethostbyname
return _load_ips_gethostbyname()
except Exception as e:
if not suppress_exceptions:
raise
# unexpected error shouldn't crash, load dumb default values instead.
warn("Unexpected error discovering local network interfaces: %s" % e)
_load_ips_dumb()
@_requires_ips
def local_ips():
"""return the IP addresses that point to this machine"""
return LOCAL_IPS
@_requires_ips
def public_ips():
"""return the IP addresses for this machine that are visible to other machines"""
return PUBLIC_IPS
@_requires_ips
def localhost():
"""return ip for localhost (almost always 127.0.0.1)"""
return LOCALHOST
@_requires_ips
def is_local_ip(ip):
"""does `ip` point to this machine?"""
return ip in LOCAL_IPS
@_requires_ips
def is_public_ip(ip):
"""is `ip` a publicly visible address?"""
return ip in PUBLIC_IPS

View file

@ -0,0 +1,827 @@
"""Base class to manage a running kernel"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from contextlib import contextmanager
import asyncio
import os
import re
import signal
import sys
import time
import warnings
import zmq
from ipython_genutils.importstring import import_item
from .localinterfaces import is_local_ip, local_ips
from traitlets import (
Any, Float, Instance, Unicode, List, Bool, Type, DottedObjectName,
default, observe
)
from jupyter_client import (
launch_kernel,
kernelspec,
)
from .connect import ConnectionFileMixin
from .managerabc import (
KernelManagerABC
)
class KernelManager(ConnectionFileMixin):
"""Manages a single kernel in a subprocess on this host.
This version starts kernels with Popen.
"""
_created_context = Bool(False)
# The PyZMQ Context to use for communication with the kernel.
context = Instance(zmq.Context)
def _context_default(self):
self._created_context = True
return zmq.Context()
# the class to create with our `client` method
client_class = DottedObjectName('jupyter_client.blocking.BlockingKernelClient')
client_factory = Type(klass='jupyter_client.KernelClient')
def _client_factory_default(self):
return import_item(self.client_class)
@observe('client_class')
def _client_class_changed(self, change):
self.client_factory = import_item(str(change['new']))
# The kernel process with which the KernelManager is communicating.
# generally a Popen instance
kernel = Any()
kernel_spec_manager = Instance(kernelspec.KernelSpecManager)
def _kernel_spec_manager_default(self):
return kernelspec.KernelSpecManager(data_dir=self.data_dir)
def _kernel_spec_manager_changed(self):
self._kernel_spec = None
shutdown_wait_time = Float(
5.0, config=True,
help="Time to wait for a kernel to terminate before killing it, "
"in seconds.")
kernel_name = Unicode(kernelspec.NATIVE_KERNEL_NAME)
@observe('kernel_name')
def _kernel_name_changed(self, change):
self._kernel_spec = None
if change['new'] == 'python':
self.kernel_name = kernelspec.NATIVE_KERNEL_NAME
_kernel_spec = None
@property
def kernel_spec(self):
if self._kernel_spec is None and self.kernel_name != '':
self._kernel_spec = self.kernel_spec_manager.get_kernel_spec(self.kernel_name)
return self._kernel_spec
kernel_cmd = List(Unicode(), config=True,
help="""DEPRECATED: Use kernel_name instead.
The Popen Command to launch the kernel.
Override this if you have a custom kernel.
If kernel_cmd is specified in a configuration file,
Jupyter does not pass any arguments to the kernel,
because it cannot make any assumptions about the
arguments that the kernel understands. In particular,
this means that the kernel does not receive the
option --debug if it given on the Jupyter command line.
"""
)
def _kernel_cmd_changed(self, name, old, new):
warnings.warn("Setting kernel_cmd is deprecated, use kernel_spec to "
"start different kernels.")
cache_ports = Bool(help='True if the MultiKernelManager should cache ports for this KernelManager instance')
@default('cache_ports')
def _default_cache_ports(self):
return self.transport == 'tcp'
@property
def ipykernel(self):
return self.kernel_name in {'python', 'python2', 'python3'}
# Protected traits
_launch_args = Any()
_control_socket = Any()
_restarter = Any()
autorestart = Bool(True, config=True,
help="""Should we autorestart the kernel if it dies."""
)
def __del__(self):
self._close_control_socket()
self.cleanup_connection_file()
#--------------------------------------------------------------------------
# Kernel restarter
#--------------------------------------------------------------------------
def start_restarter(self):
pass
def stop_restarter(self):
pass
def add_restart_callback(self, callback, event='restart'):
"""register a callback to be called when a kernel is restarted"""
if self._restarter is None:
return
self._restarter.add_callback(callback, event)
def remove_restart_callback(self, callback, event='restart'):
"""unregister a callback to be called when a kernel is restarted"""
if self._restarter is None:
return
self._restarter.remove_callback(callback, event)
#--------------------------------------------------------------------------
# create a Client connected to our Kernel
#--------------------------------------------------------------------------
def client(self, **kwargs):
"""Create a client configured to connect to our kernel"""
kw = {}
kw.update(self.get_connection_info(session=True))
kw.update(dict(
connection_file=self.connection_file,
parent=self,
))
# add kwargs last, for manual overrides
kw.update(kwargs)
return self.client_factory(**kw)
#--------------------------------------------------------------------------
# Kernel management
#--------------------------------------------------------------------------
def format_kernel_cmd(self, extra_arguments=None):
"""replace templated args (e.g. {connection_file})"""
extra_arguments = extra_arguments or []
if self.kernel_cmd:
cmd = self.kernel_cmd + extra_arguments
else:
cmd = self.kernel_spec.argv + extra_arguments
if cmd and cmd[0] in {'python',
'python%i' % sys.version_info[0],
'python%i.%i' % sys.version_info[:2]}:
# executable is 'python' or 'python3', use sys.executable.
# These will typically be the same,
# but if the current process is in an env
# and has been launched by abspath without
# activating the env, python on PATH may not be sys.executable,
# but it should be.
cmd[0] = sys.executable
# Make sure to use the realpath for the connection_file
# On windows, when running with the store python, the connection_file path
# is not usable by non python kernels because the path is being rerouted when
# inside of a store app.
# See this bug here: https://bugs.python.org/issue41196
ns = dict(connection_file=os.path.realpath(self.connection_file),
prefix=sys.prefix,
)
if self.kernel_spec:
ns["resource_dir"] = self.kernel_spec.resource_dir
ns.update(self._launch_args)
pat = re.compile(r'\{([A-Za-z0-9_]+)\}')
def from_ns(match):
"""Get the key out of ns if it's there, otherwise no change."""
return ns.get(match.group(1), match.group())
return [ pat.sub(from_ns, arg) for arg in cmd ]
def _launch_kernel(self, kernel_cmd, **kw):
"""actually launch the kernel
override in a subclass to launch kernel subprocesses differently
"""
return launch_kernel(kernel_cmd, **kw)
# Control socket used for polite kernel shutdown
def _connect_control_socket(self):
if self._control_socket is None:
self._control_socket = self._create_connected_socket('control')
self._control_socket.linger = 100
def _close_control_socket(self):
if self._control_socket is None:
return
self._control_socket.close()
self._control_socket = None
def pre_start_kernel(self, **kw):
"""Prepares a kernel for startup in a separate process.
If random ports (port=0) are being used, this method must be called
before the channels are created.
Parameters
----------
`**kw` : optional
keyword arguments that are passed down to build the kernel_cmd
and launching the kernel (e.g. Popen kwargs).
"""
if self.transport == 'tcp' and not is_local_ip(self.ip):
raise RuntimeError("Can only launch a kernel on a local interface. "
"This one is not: %s."
"Make sure that the '*_address' attributes are "
"configured properly. "
"Currently valid addresses are: %s" % (self.ip, local_ips())
)
# write connection file / get default ports
self.write_connection_file()
# save kwargs for use in restart
self._launch_args = kw.copy()
# build the Popen cmd
extra_arguments = kw.pop('extra_arguments', [])
kernel_cmd = self.format_kernel_cmd(extra_arguments=extra_arguments)
env = kw.pop('env', os.environ).copy()
# Don't allow PYTHONEXECUTABLE to be passed to kernel process.
# If set, it can bork all the things.
env.pop('PYTHONEXECUTABLE', None)
if not self.kernel_cmd:
# If kernel_cmd has been set manually, don't refer to a kernel spec.
# Environment variables from kernel spec are added to os.environ.
env.update(self._get_env_substitutions(self.kernel_spec.env, env))
elif self.extra_env:
env.update(self._get_env_substitutions(self.extra_env, env))
kw['env'] = env
return kernel_cmd, kw
def _get_env_substitutions(self, templated_env, substitution_values):
""" Walks env entries in templated_env and applies possible substitutions from current env
(represented by substitution_values).
Returns the substituted list of env entries.
"""
substituted_env = {}
if templated_env:
from string import Template
# For each templated env entry, fill any templated references
# matching names of env variables with those values and build
# new dict with substitutions.
for k, v in templated_env.items():
substituted_env.update({k: Template(v).safe_substitute(substitution_values)})
return substituted_env
def post_start_kernel(self, **kw):
self.start_restarter()
self._connect_control_socket()
def start_kernel(self, **kw):
"""Starts a kernel on this host in a separate process.
If random ports (port=0) are being used, this method must be called
before the channels are created.
Parameters
----------
`**kw` : optional
keyword arguments that are passed down to build the kernel_cmd
and launching the kernel (e.g. Popen kwargs).
"""
kernel_cmd, kw = self.pre_start_kernel(**kw)
# launch the kernel subprocess
self.log.debug("Starting kernel: %s", kernel_cmd)
self.kernel = self._launch_kernel(kernel_cmd, **kw)
self.post_start_kernel(**kw)
def request_shutdown(self, restart=False):
"""Send a shutdown request via control channel
"""
content = dict(restart=restart)
msg = self.session.msg("shutdown_request", content=content)
# ensure control socket is connected
self._connect_control_socket()
self.session.send(self._control_socket, msg)
def finish_shutdown(self, waittime=None, pollinterval=0.1):
"""Wait for kernel shutdown, then kill process if it doesn't shutdown.
This does not send shutdown requests - use :meth:`request_shutdown`
first.
"""
if waittime is None:
waittime = max(self.shutdown_wait_time, 0)
for i in range(int(waittime/pollinterval)):
if self.is_alive():
time.sleep(pollinterval)
else:
# If there's still a proc, wait and clear
if self.has_kernel:
self.kernel.wait()
self.kernel = None
break
else:
# OK, we've waited long enough.
if self.has_kernel:
self.log.debug("Kernel is taking too long to finish, killing")
self._kill_kernel()
def cleanup_resources(self, restart=False):
"""Clean up resources when the kernel is shut down"""
if not restart:
self.cleanup_connection_file()
self.cleanup_ipc_files()
self._close_control_socket()
self.session.parent = None
if self._created_context and not restart:
self.context.destroy(linger=100)
def cleanup(self, connection_file=True):
"""Clean up resources when the kernel is shut down"""
warnings.warn("Method cleanup(connection_file=True) is deprecated, use cleanup_resources(restart=False).",
FutureWarning)
self.cleanup_resources(restart=not connection_file)
def shutdown_kernel(self, now=False, restart=False):
"""Attempts to stop the kernel process cleanly.
This attempts to shutdown the kernels cleanly by:
1. Sending it a shutdown message over the control channel.
2. If that fails, the kernel is shutdown forcibly by sending it
a signal.
Parameters
----------
now : bool
Should the kernel be forcible killed *now*. This skips the
first, nice shutdown attempt.
restart: bool
Will this kernel be restarted after it is shutdown. When this
is True, connection files will not be cleaned up.
"""
# Stop monitoring for restarting while we shutdown.
self.stop_restarter()
if now:
self._kill_kernel()
else:
self.request_shutdown(restart=restart)
# Don't send any additional kernel kill messages immediately, to give
# the kernel a chance to properly execute shutdown actions. Wait for at
# most 1s, checking every 0.1s.
self.finish_shutdown()
# In 6.1.5, a new method, cleanup_resources(), was introduced to address
# a leak issue (https://github.com/jupyter/jupyter_client/pull/548) and
# replaced the existing cleanup() method. However, that method introduction
# breaks subclass implementations that override cleanup() since it would
# circumvent cleanup() functionality implemented in subclasses.
# By detecting if the current instance overrides cleanup(), we can determine
# if the deprecated path of calling cleanup() should be performed - which avoids
# unnecessary deprecation warnings in a majority of configurations in which
# subclassed KernelManager instances are not in use.
# Note: because subclasses may have already implemented cleanup_resources()
# but need to support older jupyter_clients, we should only take the deprecated
# path if cleanup() is overridden but cleanup_resources() is not.
overrides_cleanup = type(self).cleanup is not KernelManager.cleanup
overrides_cleanup_resources = type(self).cleanup_resources is not KernelManager.cleanup_resources
if overrides_cleanup and not overrides_cleanup_resources:
self.cleanup(connection_file=not restart)
else:
self.cleanup_resources(restart=restart)
def restart_kernel(self, now=False, newports=False, **kw):
"""Restarts a kernel with the arguments that were used to launch it.
Parameters
----------
now : bool, optional
If True, the kernel is forcefully restarted *immediately*, without
having a chance to do any cleanup action. Otherwise the kernel is
given 1s to clean up before a forceful restart is issued.
In all cases the kernel is restarted, the only difference is whether
it is given a chance to perform a clean shutdown or not.
newports : bool, optional
If the old kernel was launched with random ports, this flag decides
whether the same ports and connection file will be used again.
If False, the same ports and connection file are used. This is
the default. If True, new random port numbers are chosen and a
new connection file is written. It is still possible that the newly
chosen random port numbers happen to be the same as the old ones.
`**kw` : optional
Any options specified here will overwrite those used to launch the
kernel.
"""
if self._launch_args is None:
raise RuntimeError("Cannot restart the kernel. "
"No previous call to 'start_kernel'.")
else:
# Stop currently running kernel.
self.shutdown_kernel(now=now, restart=True)
if newports:
self.cleanup_random_ports()
# Start new kernel.
self._launch_args.update(kw)
self.start_kernel(**self._launch_args)
@property
def has_kernel(self):
"""Has a kernel been started that we are managing."""
return self.kernel is not None
def _kill_kernel(self):
"""Kill the running kernel.
This is a private method, callers should use shutdown_kernel(now=True).
"""
if self.has_kernel:
# Signal the kernel to terminate (sends SIGKILL on Unix and calls
# TerminateProcess() on Win32).
try:
if hasattr(signal, 'SIGKILL'):
self.signal_kernel(signal.SIGKILL)
else:
self.kernel.kill()
except OSError as e:
# In Windows, we will get an Access Denied error if the process
# has already terminated. Ignore it.
if sys.platform == 'win32':
if e.winerror != 5:
raise
# On Unix, we may get an ESRCH error if the process has already
# terminated. Ignore it.
else:
from errno import ESRCH
if e.errno != ESRCH:
raise
# Block until the kernel terminates.
self.kernel.wait()
self.kernel = None
def interrupt_kernel(self):
"""Interrupts the kernel by sending it a signal.
Unlike ``signal_kernel``, this operation is well supported on all
platforms.
"""
if self.has_kernel:
interrupt_mode = self.kernel_spec.interrupt_mode
if interrupt_mode == 'signal':
if sys.platform == 'win32':
from .win_interrupt import send_interrupt
send_interrupt(self.kernel.win32_interrupt_event)
else:
self.signal_kernel(signal.SIGINT)
elif interrupt_mode == 'message':
msg = self.session.msg("interrupt_request", content={})
self._connect_control_socket()
self.session.send(self._control_socket, msg)
else:
raise RuntimeError("Cannot interrupt kernel. No kernel is running!")
def signal_kernel(self, signum):
"""Sends a signal to the process group of the kernel (this
usually includes the kernel and any subprocesses spawned by
the kernel).
Note that since only SIGTERM is supported on Windows, this function is
only useful on Unix systems.
"""
if self.has_kernel:
if hasattr(os, "getpgid") and hasattr(os, "killpg"):
try:
pgid = os.getpgid(self.kernel.pid)
os.killpg(pgid, signum)
return
except OSError:
pass
self.kernel.send_signal(signum)
else:
raise RuntimeError("Cannot signal kernel. No kernel is running!")
def is_alive(self):
"""Is the kernel process still running?"""
if self.has_kernel:
if self.kernel.poll() is None:
return True
else:
return False
else:
# we don't have a kernel
return False
class AsyncKernelManager(KernelManager):
"""Manages kernels in an asynchronous manner """
client_class = DottedObjectName('jupyter_client.asynchronous.AsyncKernelClient')
client_factory = Type(klass='jupyter_client.asynchronous.AsyncKernelClient')
async def _launch_kernel(self, kernel_cmd, **kw):
"""actually launch the kernel
override in a subclass to launch kernel subprocesses differently
"""
res = launch_kernel(kernel_cmd, **kw)
return res
async def start_kernel(self, **kw):
"""Starts a kernel in a separate process in an asynchronous manner.
If random ports (port=0) are being used, this method must be called
before the channels are created.
Parameters
----------
`**kw` : optional
keyword arguments that are passed down to build the kernel_cmd
and launching the kernel (e.g. Popen kwargs).
"""
kernel_cmd, kw = self.pre_start_kernel(**kw)
# launch the kernel subprocess
self.log.debug("Starting kernel (async): %s", kernel_cmd)
self.kernel = await self._launch_kernel(kernel_cmd, **kw)
self.post_start_kernel(**kw)
async def finish_shutdown(self, waittime=None, pollinterval=0.1):
"""Wait for kernel shutdown, then kill process if it doesn't shutdown.
This does not send shutdown requests - use :meth:`request_shutdown`
first.
"""
if waittime is None:
waittime = max(self.shutdown_wait_time, 0)
try:
await asyncio.wait_for(self._async_wait(pollinterval=pollinterval), timeout=waittime)
except asyncio.TimeoutError:
self.log.debug("Kernel is taking too long to finish, killing")
await self._kill_kernel()
else:
# Process is no longer alive, wait and clear
if self.kernel is not None:
self.kernel.wait()
self.kernel = None
async def shutdown_kernel(self, now=False, restart=False):
"""Attempts to stop the kernel process cleanly.
This attempts to shutdown the kernels cleanly by:
1. Sending it a shutdown message over the shell channel.
2. If that fails, the kernel is shutdown forcibly by sending it
a signal.
Parameters
----------
now : bool
Should the kernel be forcible killed *now*. This skips the
first, nice shutdown attempt.
restart: bool
Will this kernel be restarted after it is shutdown. When this
is True, connection files will not be cleaned up.
"""
# Stop monitoring for restarting while we shutdown.
self.stop_restarter()
if now:
await self._kill_kernel()
else:
self.request_shutdown(restart=restart)
# Don't send any additional kernel kill messages immediately, to give
# the kernel a chance to properly execute shutdown actions. Wait for at
# most 1s, checking every 0.1s.
await self.finish_shutdown()
# See comment in KernelManager.shutdown_kernel().
overrides_cleanup = type(self).cleanup is not AsyncKernelManager.cleanup
overrides_cleanup_resources = type(self).cleanup_resources is not AsyncKernelManager.cleanup_resources
if overrides_cleanup and not overrides_cleanup_resources:
self.cleanup(connection_file=not restart)
else:
self.cleanup_resources(restart=restart)
async def restart_kernel(self, now=False, newports=False, **kw):
"""Restarts a kernel with the arguments that were used to launch it.
Parameters
----------
now : bool, optional
If True, the kernel is forcefully restarted *immediately*, without
having a chance to do any cleanup action. Otherwise the kernel is
given 1s to clean up before a forceful restart is issued.
In all cases the kernel is restarted, the only difference is whether
it is given a chance to perform a clean shutdown or not.
newports : bool, optional
If the old kernel was launched with random ports, this flag decides
whether the same ports and connection file will be used again.
If False, the same ports and connection file are used. This is
the default. If True, new random port numbers are chosen and a
new connection file is written. It is still possible that the newly
chosen random port numbers happen to be the same as the old ones.
`**kw` : optional
Any options specified here will overwrite those used to launch the
kernel.
"""
if self._launch_args is None:
raise RuntimeError("Cannot restart the kernel. "
"No previous call to 'start_kernel'.")
else:
# Stop currently running kernel.
await self.shutdown_kernel(now=now, restart=True)
if newports:
self.cleanup_random_ports()
# Start new kernel.
self._launch_args.update(kw)
await self.start_kernel(**self._launch_args)
return None
async def _kill_kernel(self):
"""Kill the running kernel.
This is a private method, callers should use shutdown_kernel(now=True).
"""
if self.has_kernel:
# Signal the kernel to terminate (sends SIGKILL on Unix and calls
# TerminateProcess() on Win32).
try:
if hasattr(signal, 'SIGKILL'):
await self.signal_kernel(signal.SIGKILL)
else:
self.kernel.kill()
except OSError as e:
# In Windows, we will get an Access Denied error if the process
# has already terminated. Ignore it.
if sys.platform == 'win32':
if e.winerror != 5:
raise
# On Unix, we may get an ESRCH error if the process has already
# terminated. Ignore it.
else:
from errno import ESRCH
if e.errno != ESRCH:
raise
# Wait until the kernel terminates.
try:
await asyncio.wait_for(self._async_wait(), timeout=5.0)
except asyncio.TimeoutError:
# Wait timed out, just log warning but continue - not much more we can do.
self.log.warning("Wait for final termination of kernel timed out - continuing...")
pass
else:
# Process is no longer alive, wait and clear
if self.kernel is not None:
self.kernel.wait()
self.kernel = None
async def interrupt_kernel(self):
"""Interrupts the kernel by sending it a signal.
Unlike ``signal_kernel``, this operation is well supported on all
platforms.
"""
if self.has_kernel:
interrupt_mode = self.kernel_spec.interrupt_mode
if interrupt_mode == 'signal':
if sys.platform == 'win32':
from .win_interrupt import send_interrupt
send_interrupt(self.kernel.win32_interrupt_event)
else:
await self.signal_kernel(signal.SIGINT)
elif interrupt_mode == 'message':
msg = self.session.msg("interrupt_request", content={})
self._connect_control_socket()
self.session.send(self._control_socket, msg)
else:
raise RuntimeError("Cannot interrupt kernel. No kernel is running!")
async def signal_kernel(self, signum):
"""Sends a signal to the process group of the kernel (this
usually includes the kernel and any subprocesses spawned by
the kernel).
Note that since only SIGTERM is supported on Windows, this function is
only useful on Unix systems.
"""
if self.has_kernel:
if hasattr(os, "getpgid") and hasattr(os, "killpg"):
try:
pgid = os.getpgid(self.kernel.pid)
os.killpg(pgid, signum)
return
except OSError:
pass
self.kernel.send_signal(signum)
else:
raise RuntimeError("Cannot signal kernel. No kernel is running!")
async def is_alive(self):
"""Is the kernel process still running?"""
if self.has_kernel:
if self.kernel.poll() is None:
return True
else:
return False
else:
# we don't have a kernel
return False
async def _async_wait(self, pollinterval=0.1):
# Use busy loop at 100ms intervals, polling until the process is
# not alive. If we find the process is no longer alive, complete
# its cleanup via the blocking wait(). Callers are responsible for
# issuing calls to wait() using a timeout (see _kill_kernel()).
while await self.is_alive():
await asyncio.sleep(pollinterval)
KernelManagerABC.register(KernelManager)
def start_new_kernel(startup_timeout=60, kernel_name='python', **kwargs):
"""Start a new kernel, and return its Manager and Client"""
km = KernelManager(kernel_name=kernel_name)
km.start_kernel(**kwargs)
kc = km.client()
kc.start_channels()
try:
kc.wait_for_ready(timeout=startup_timeout)
except RuntimeError:
kc.stop_channels()
km.shutdown_kernel()
raise
return km, kc
async def start_new_async_kernel(startup_timeout=60, kernel_name='python', **kwargs):
"""Start a new kernel, and return its Manager and Client"""
km = AsyncKernelManager(kernel_name=kernel_name)
await km.start_kernel(**kwargs)
kc = km.client()
kc.start_channels()
try:
await kc.wait_for_ready(timeout=startup_timeout)
except RuntimeError:
kc.stop_channels()
await km.shutdown_kernel()
raise
return (km, kc)
@contextmanager
def run_kernel(**kwargs):
"""Context manager to create a kernel in a subprocess.
The kernel is shut down when the context exits.
Returns
-------
kernel_client: connected KernelClient instance
"""
km, kc = start_new_kernel(**kwargs)
try:
yield kc
finally:
kc.stop_channels()
km.shutdown_kernel(now=True)

View file

@ -0,0 +1,51 @@
"""Abstract base class for kernel managers."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import abc
class KernelManagerABC(object, metaclass=abc.ABCMeta):
"""KernelManager ABC.
The docstrings for this class can be found in the base implementation:
`jupyter_client.kernelmanager.KernelManager`
"""
@abc.abstractproperty
def kernel(self):
pass
#--------------------------------------------------------------------------
# Kernel management
#--------------------------------------------------------------------------
@abc.abstractmethod
def start_kernel(self, **kw):
pass
@abc.abstractmethod
def shutdown_kernel(self, now=False, restart=False):
pass
@abc.abstractmethod
def restart_kernel(self, now=False, **kw):
pass
@abc.abstractproperty
def has_kernel(self):
pass
@abc.abstractmethod
def interrupt_kernel(self):
pass
@abc.abstractmethod
def signal_kernel(self, signum):
pass
@abc.abstractmethod
def is_alive(self):
pass

View file

@ -0,0 +1,554 @@
"""A kernel manager for multiple kernels"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import os
import uuid
import socket
import zmq
from traitlets.config.configurable import LoggingConfigurable
from ipython_genutils.importstring import import_item
from traitlets import (
Any, Bool, Dict, DottedObjectName, Instance, Unicode, default, observe
)
from .kernelspec import NATIVE_KERNEL_NAME, KernelSpecManager
from .manager import KernelManager, AsyncKernelManager
class DuplicateKernelError(Exception):
pass
def kernel_method(f):
"""decorator for proxying MKM.method(kernel_id) to individual KMs by ID"""
def wrapped(self, kernel_id, *args, **kwargs):
# get the kernel
km = self.get_kernel(kernel_id)
method = getattr(km, f.__name__)
# call the kernel's method
r = method(*args, **kwargs)
# last thing, call anything defined in the actual class method
# such as logging messages
f(self, kernel_id, *args, **kwargs)
# return the method result
return r
return wrapped
class MultiKernelManager(LoggingConfigurable):
"""A class for managing multiple kernels."""
default_kernel_name = Unicode(NATIVE_KERNEL_NAME, config=True,
help="The name of the default kernel to start"
)
kernel_spec_manager = Instance(KernelSpecManager, allow_none=True)
kernel_manager_class = DottedObjectName(
"jupyter_client.ioloop.IOLoopKernelManager", config=True,
help="""The kernel manager class. This is configurable to allow
subclassing of the KernelManager for customized behavior.
"""
)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Cache all the currently used ports
self.currently_used_ports = set()
@observe('kernel_manager_class')
def _kernel_manager_class_changed(self, change):
self.kernel_manager_factory = self._create_kernel_manager_factory()
kernel_manager_factory = Any(help="this is kernel_manager_class after import")
@default('kernel_manager_factory')
def _kernel_manager_factory_default(self):
return self._create_kernel_manager_factory()
def _create_kernel_manager_factory(self):
kernel_manager_ctor = import_item(self.kernel_manager_class)
def create_kernel_manager(*args, **kwargs):
if self.shared_context:
if self.context.closed:
# recreate context if closed
self.context = self._context_default()
kwargs.setdefault("context", self.context)
km = kernel_manager_ctor(*args, **kwargs)
if km.cache_ports:
km.shell_port = self._find_available_port(km.ip)
km.iopub_port = self._find_available_port(km.ip)
km.stdin_port = self._find_available_port(km.ip)
km.hb_port = self._find_available_port(km.ip)
km.control_port = self._find_available_port(km.ip)
return km
return create_kernel_manager
def _find_available_port(self, ip):
while True:
tmp_sock = socket.socket()
tmp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b'\0' * 8)
tmp_sock.bind((ip, 0))
port = tmp_sock.getsockname()[1]
tmp_sock.close()
# This is a workaround for https://github.com/jupyter/jupyter_client/issues/487
# We prevent two kernels to have the same ports.
if port not in self.currently_used_ports:
self.currently_used_ports.add(port)
return port
shared_context = Bool(
True,
config=True,
help="Share a single zmq.Context to talk to all my kernels",
)
_created_context = Bool(False)
context = Instance('zmq.Context')
@default("context")
def _context_default(self):
self._created_context = True
return zmq.Context()
def __del__(self):
if self._created_context and self.context and not self.context.closed:
if self.log:
self.log.debug("Destroying zmq context for %s", self)
self.context.destroy()
try:
super_del = super().__del__
except AttributeError:
pass
else:
super_del()
connection_dir = Unicode('')
_kernels = Dict()
def list_kernel_ids(self):
"""Return a list of the kernel ids of the active kernels."""
# Create a copy so we can iterate over kernels in operations
# that delete keys.
return list(self._kernels.keys())
def __len__(self):
"""Return the number of running kernels."""
return len(self.list_kernel_ids())
def __contains__(self, kernel_id):
return kernel_id in self._kernels
def pre_start_kernel(self, kernel_name, kwargs):
# kwargs should be mutable, passing it as a dict argument.
kernel_id = kwargs.pop('kernel_id', self.new_kernel_id(**kwargs))
if kernel_id in self:
raise DuplicateKernelError('Kernel already exists: %s' % kernel_id)
if kernel_name is None:
kernel_name = self.default_kernel_name
# kernel_manager_factory is the constructor for the KernelManager
# subclass we are using. It can be configured as any Configurable,
# including things like its transport and ip.
constructor_kwargs = {}
if self.kernel_spec_manager:
constructor_kwargs['kernel_spec_manager'] = self.kernel_spec_manager
km = self.kernel_manager_factory(connection_file=os.path.join(
self.connection_dir, "kernel-%s.json" % kernel_id),
parent=self, log=self.log, kernel_name=kernel_name,
**constructor_kwargs
)
return km, kernel_name, kernel_id
def start_kernel(self, kernel_name=None, **kwargs):
"""Start a new kernel.
The caller can pick a kernel_id by passing one in as a keyword arg,
otherwise one will be generated using new_kernel_id().
The kernel ID for the newly started kernel is returned.
"""
km, kernel_name, kernel_id = self.pre_start_kernel(kernel_name, kwargs)
km.start_kernel(**kwargs)
self._kernels[kernel_id] = km
return kernel_id
def shutdown_kernel(self, kernel_id, now=False, restart=False):
"""Shutdown a kernel by its kernel uuid.
Parameters
==========
kernel_id : uuid
The id of the kernel to shutdown.
now : bool
Should the kernel be shutdown forcibly using a signal.
restart : bool
Will the kernel be restarted?
"""
self.log.info("Kernel shutdown: %s" % kernel_id)
km = self.get_kernel(kernel_id)
ports = (
km.shell_port, km.iopub_port, km.stdin_port,
km.hb_port, km.control_port
)
km.shutdown_kernel(now=now, restart=restart)
self.remove_kernel(kernel_id)
if km.cache_ports and not restart:
for port in ports:
self.currently_used_ports.remove(port)
@kernel_method
def request_shutdown(self, kernel_id, restart=False):
"""Ask a kernel to shut down by its kernel uuid"""
@kernel_method
def finish_shutdown(self, kernel_id, waittime=None, pollinterval=0.1):
"""Wait for a kernel to finish shutting down, and kill it if it doesn't
"""
self.log.info("Kernel shutdown: %s" % kernel_id)
@kernel_method
def cleanup(self, kernel_id, connection_file=True):
"""Clean up a kernel's resources"""
@kernel_method
def cleanup_resources(self, kernel_id, restart=False):
"""Clean up a kernel's resources"""
def remove_kernel(self, kernel_id):
"""remove a kernel from our mapping.
Mainly so that a kernel can be removed if it is already dead,
without having to call shutdown_kernel.
The kernel object is returned.
"""
return self._kernels.pop(kernel_id)
def shutdown_all(self, now=False):
"""Shutdown all kernels."""
kids = self.list_kernel_ids()
for kid in kids:
self.request_shutdown(kid)
for kid in kids:
self.finish_shutdown(kid)
# Determine which cleanup method to call
# See comment in KernelManager.shutdown_kernel().
km = self.get_kernel(kid)
overrides_cleanup = type(km).cleanup is not KernelManager.cleanup
overrides_cleanup_resources = type(km).cleanup_resources is not KernelManager.cleanup_resources
if overrides_cleanup and not overrides_cleanup_resources:
km.cleanup(connection_file=True)
else:
km.cleanup_resources(restart=False)
self.remove_kernel(kid)
@kernel_method
def interrupt_kernel(self, kernel_id):
"""Interrupt (SIGINT) the kernel by its uuid.
Parameters
==========
kernel_id : uuid
The id of the kernel to interrupt.
"""
self.log.info("Kernel interrupted: %s" % kernel_id)
@kernel_method
def signal_kernel(self, kernel_id, signum):
"""Sends a signal to the kernel by its uuid.
Note that since only SIGTERM is supported on Windows, this function
is only useful on Unix systems.
Parameters
==========
kernel_id : uuid
The id of the kernel to signal.
"""
self.log.info("Signaled Kernel %s with %s" % (kernel_id, signum))
@kernel_method
def restart_kernel(self, kernel_id, now=False):
"""Restart a kernel by its uuid, keeping the same ports.
Parameters
==========
kernel_id : uuid
The id of the kernel to interrupt.
"""
self.log.info("Kernel restarted: %s" % kernel_id)
@kernel_method
def is_alive(self, kernel_id):
"""Is the kernel alive.
This calls KernelManager.is_alive() which calls Popen.poll on the
actual kernel subprocess.
Parameters
==========
kernel_id : uuid
The id of the kernel.
"""
def _check_kernel_id(self, kernel_id):
"""check that a kernel id is valid"""
if kernel_id not in self:
raise KeyError("Kernel with id not found: %s" % kernel_id)
def get_kernel(self, kernel_id):
"""Get the single KernelManager object for a kernel by its uuid.
Parameters
==========
kernel_id : uuid
The id of the kernel.
"""
self._check_kernel_id(kernel_id)
return self._kernels[kernel_id]
@kernel_method
def add_restart_callback(self, kernel_id, callback, event='restart'):
"""add a callback for the KernelRestarter"""
@kernel_method
def remove_restart_callback(self, kernel_id, callback, event='restart'):
"""remove a callback for the KernelRestarter"""
@kernel_method
def get_connection_info(self, kernel_id):
"""Return a dictionary of connection data for a kernel.
Parameters
==========
kernel_id : uuid
The id of the kernel.
Returns
=======
connection_dict : dict
A dict of the information needed to connect to a kernel.
This includes the ip address and the integer port
numbers of the different channels (stdin_port, iopub_port,
shell_port, hb_port).
"""
@kernel_method
def connect_iopub(self, kernel_id, identity=None):
"""Return a zmq Socket connected to the iopub channel.
Parameters
==========
kernel_id : uuid
The id of the kernel
identity : bytes (optional)
The zmq identity of the socket
Returns
=======
stream : zmq Socket or ZMQStream
"""
@kernel_method
def connect_shell(self, kernel_id, identity=None):
"""Return a zmq Socket connected to the shell channel.
Parameters
==========
kernel_id : uuid
The id of the kernel
identity : bytes (optional)
The zmq identity of the socket
Returns
=======
stream : zmq Socket or ZMQStream
"""
@kernel_method
def connect_control(self, kernel_id, identity=None):
"""Return a zmq Socket connected to the control channel.
Parameters
==========
kernel_id : uuid
The id of the kernel
identity : bytes (optional)
The zmq identity of the socket
Returns
=======
stream : zmq Socket or ZMQStream
"""
@kernel_method
def connect_stdin(self, kernel_id, identity=None):
"""Return a zmq Socket connected to the stdin channel.
Parameters
==========
kernel_id : uuid
The id of the kernel
identity : bytes (optional)
The zmq identity of the socket
Returns
=======
stream : zmq Socket or ZMQStream
"""
@kernel_method
def connect_hb(self, kernel_id, identity=None):
"""Return a zmq Socket connected to the hb channel.
Parameters
==========
kernel_id : uuid
The id of the kernel
identity : bytes (optional)
The zmq identity of the socket
Returns
=======
stream : zmq Socket or ZMQStream
"""
def new_kernel_id(self, **kwargs):
"""
Returns the id to associate with the kernel for this request. Subclasses may override
this method to substitute other sources of kernel ids.
:param kwargs:
:return: string-ized version 4 uuid
"""
return str(uuid.uuid4())
class AsyncMultiKernelManager(MultiKernelManager):
kernel_manager_class = DottedObjectName(
"jupyter_client.ioloop.AsyncIOLoopKernelManager", config=True,
help="""The kernel manager class. This is configurable to allow
subclassing of the AsyncKernelManager for customized behavior.
"""
)
async def start_kernel(self, kernel_name=None, **kwargs):
"""Start a new kernel.
The caller can pick a kernel_id by passing one in as a keyword arg,
otherwise one will be generated using new_kernel_id().
The kernel ID for the newly started kernel is returned.
"""
km, kernel_name, kernel_id = self.pre_start_kernel(kernel_name, kwargs)
if not isinstance(km, AsyncKernelManager):
self.log.warning("Kernel manager class ({km_class}) is not an instance of 'AsyncKernelManager'!".
format(km_class=self.kernel_manager_class.__class__))
await km.start_kernel(**kwargs)
self._kernels[kernel_id] = km
return kernel_id
async def shutdown_kernel(self, kernel_id, now=False, restart=False):
"""Shutdown a kernel by its kernel uuid.
Parameters
==========
kernel_id : uuid
The id of the kernel to shutdown.
now : bool
Should the kernel be shutdown forcibly using a signal.
restart : bool
Will the kernel be restarted?
"""
self.log.info("Kernel shutdown: %s" % kernel_id)
km = self.get_kernel(kernel_id)
ports = (
km.shell_port, km.iopub_port, km.stdin_port,
km.hb_port, km.control_port
)
await km.shutdown_kernel(now, restart)
self.remove_kernel(kernel_id)
if km.cache_ports and not restart:
for port in ports:
self.currently_used_ports.remove(port)
async def finish_shutdown(self, kernel_id, waittime=None, pollinterval=0.1):
"""Wait for a kernel to finish shutting down, and kill it if it doesn't
"""
km = self.get_kernel(kernel_id)
await km.finish_shutdown(waittime, pollinterval)
self.log.info("Kernel shutdown: %s" % kernel_id)
async def interrupt_kernel(self, kernel_id):
"""Interrupt (SIGINT) the kernel by its uuid.
Parameters
==========
kernel_id : uuid
The id of the kernel to interrupt.
"""
km = self.get_kernel(kernel_id)
await km.interrupt_kernel()
self.log.info("Kernel interrupted: %s" % kernel_id)
async def signal_kernel(self, kernel_id, signum):
"""Sends a signal to the kernel by its uuid.
Note that since only SIGTERM is supported on Windows, this function
is only useful on Unix systems.
Parameters
==========
kernel_id : uuid
The id of the kernel to signal.
"""
km = self.get_kernel(kernel_id)
await km.signal_kernel(signum)
self.log.info("Signaled Kernel %s with %s" % (kernel_id, signum))
async def restart_kernel(self, kernel_id, now=False):
"""Restart a kernel by its uuid, keeping the same ports.
Parameters
==========
kernel_id : uuid
The id of the kernel to interrupt.
"""
km = self.get_kernel(kernel_id)
await km.restart_kernel(now)
self.log.info("Kernel restarted: %s" % kernel_id)
async def shutdown_all(self, now=False):
"""Shutdown all kernels."""
kids = self.list_kernel_ids()
for kid in kids:
self.request_shutdown(kid)
for kid in kids:
await self.finish_shutdown(kid)
self.cleanup_resources(kid)
self.remove_kernel(kid)

View file

@ -0,0 +1,120 @@
"""A basic kernel monitor with autorestarting.
This watches a kernel's state using KernelManager.is_alive and auto
restarts the kernel if it dies.
It is an incomplete base class, and must be subclassed.
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from traitlets.config.configurable import LoggingConfigurable
from traitlets import (
Instance, Float, Dict, Bool, Integer,
)
class KernelRestarter(LoggingConfigurable):
"""Monitor and autorestart a kernel."""
kernel_manager = Instance('jupyter_client.KernelManager')
debug = Bool(False, config=True,
help="""Whether to include every poll event in debugging output.
Has to be set explicitly, because there will be *a lot* of output.
"""
)
time_to_dead = Float(3.0, config=True,
help="""Kernel heartbeat interval in seconds."""
)
restart_limit = Integer(5, config=True,
help="""The number of consecutive autorestarts before the kernel is presumed dead."""
)
random_ports_until_alive = Bool(True, config=True,
help="""Whether to choose new random ports when restarting before the kernel is alive."""
)
_restarting = Bool(False)
_restart_count = Integer(0)
_initial_startup = Bool(True)
callbacks = Dict()
def _callbacks_default(self):
return dict(restart=[], dead=[])
def start(self):
"""Start the polling of the kernel."""
raise NotImplementedError("Must be implemented in a subclass")
def stop(self):
"""Stop the kernel polling."""
raise NotImplementedError("Must be implemented in a subclass")
def add_callback(self, f, event='restart'):
"""register a callback to fire on a particular event
Possible values for event:
'restart' (default): kernel has died, and will be restarted.
'dead': restart has failed, kernel will be left dead.
"""
self.callbacks[event].append(f)
def remove_callback(self, f, event='restart'):
"""unregister a callback to fire on a particular event
Possible values for event:
'restart' (default): kernel has died, and will be restarted.
'dead': restart has failed, kernel will be left dead.
"""
try:
self.callbacks[event].remove(f)
except ValueError:
pass
def _fire_callbacks(self, event):
"""fire our callbacks for a particular event"""
for callback in self.callbacks[event]:
try:
callback()
except Exception as e:
self.log.error("KernelRestarter: %s callback %r failed", event, callback, exc_info=True)
def poll(self):
if self.debug:
self.log.debug('Polling kernel...')
if not self.kernel_manager.is_alive():
if self._restarting:
self._restart_count += 1
else:
self._restart_count = 1
if self._restart_count >= self.restart_limit:
self.log.warning("KernelRestarter: restart failed")
self._fire_callbacks('dead')
self._restarting = False
self._restart_count = 0
self.stop()
else:
newports = self.random_ports_until_alive and self._initial_startup
self.log.info('KernelRestarter: restarting kernel (%i/%i), %s random ports',
self._restart_count,
self.restart_limit,
'new' if newports else 'keep'
)
self._fire_callbacks('restart')
self.kernel_manager.restart_kernel(now=True, newports=newports)
self._restarting = True
else:
if self._initial_startup:
self._initial_startup = False
if self._restarting:
self.log.debug("KernelRestarter: restart apparently succeeded")
self._restarting = False

View file

@ -0,0 +1,120 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import logging
import signal
import queue
import time
import sys
from traitlets.config import catch_config_error
from traitlets import (
Instance, Dict, Unicode, Bool, List, CUnicode, Any, Float
)
from jupyter_core.application import (
JupyterApp, base_flags, base_aliases
)
from . import __version__
from .consoleapp import JupyterConsoleApp, app_aliases, app_flags
OUTPUT_TIMEOUT = 10
# copy flags from mixin:
flags = dict(base_flags)
# start with mixin frontend flags:
frontend_flags = dict(app_flags)
# update full dict with frontend flags:
flags.update(frontend_flags)
# copy flags from mixin
aliases = dict(base_aliases)
# start with mixin frontend flags
frontend_aliases = dict(app_aliases)
# load updated frontend flags into full dict
aliases.update(frontend_aliases)
# get flags&aliases into sets, and remove a couple that
# shouldn't be scrubbed from backend flags:
frontend_aliases = set(frontend_aliases.keys())
frontend_flags = set(frontend_flags.keys())
class RunApp(JupyterApp, JupyterConsoleApp):
version = __version__
name = "jupyter run"
description = """Run Jupyter kernel code."""
flags = Dict(flags)
aliases = Dict(aliases)
frontend_aliases = Any(frontend_aliases)
frontend_flags = Any(frontend_flags)
kernel_timeout = Float(60, config=True,
help="""Timeout for giving up on a kernel (in seconds).
On first connect and restart, the console tests whether the
kernel is running and responsive by sending kernel_info_requests.
This sets the timeout in seconds for how long the kernel can take
before being presumed dead.
"""
)
def parse_command_line(self, argv=None):
super().parse_command_line(argv)
self.build_kernel_argv(self.extra_args)
self.filenames_to_run = self.extra_args[:]
@catch_config_error
def initialize(self, argv=None):
self.log.debug("jupyter run: initialize...")
super().initialize(argv)
JupyterConsoleApp.initialize(self)
signal.signal(signal.SIGINT, self.handle_sigint)
self.init_kernel_info()
def handle_sigint(self, *args):
if self.kernel_manager:
self.kernel_manager.interrupt_kernel()
else:
print("", file=sys.stderr)
error("Cannot interrupt kernels we didn't start.\n")
def init_kernel_info(self):
"""Wait for a kernel to be ready, and store kernel info"""
timeout = self.kernel_timeout
tic = time.time()
self.kernel_client.hb_channel.unpause()
msg_id = self.kernel_client.kernel_info()
while True:
try:
reply = self.kernel_client.get_shell_msg(timeout=1)
except queue.Empty as e:
if (time.time() - tic) > timeout:
raise RuntimeError("Kernel didn't respond to kernel_info_request") from e
else:
if reply['parent_header'].get('msg_id') == msg_id:
self.kernel_info = reply['content']
return
def start(self):
self.log.debug("jupyter run: starting...")
super().start()
if self.filenames_to_run:
for filename in self.filenames_to_run:
self.log.debug("jupyter run: executing `%s`" % filename)
with open(filename) as fp:
code = fp.read()
reply = self.kernel_client.execute_interactive(code, timeout=OUTPUT_TIMEOUT)
return_code = 0 if reply['content']['status'] == 'ok' else 1
if return_code:
raise Exception("jupyter-run error running '%s'" % filename)
else:
code = sys.stdin.read()
reply = self.kernel_client.execute_interactive(code, timeout=OUTPUT_TIMEOUT)
return_code = 0 if reply['content']['status'] == 'ok' else 1
if return_code:
raise Exception("jupyter-run error running 'stdin'")
main = launch_new_instance = RunApp.launch_instance
if __name__ == '__main__':
main()

View file

@ -0,0 +1,972 @@
"""Session object for building, serializing, sending, and receiving messages.
The Session object supports serialization, HMAC signatures,
and metadata on messages.
Also defined here are utilities for working with Sessions:
* A SessionFactory to be used as a base class for configurables that work with
Sessions.
* A Message object for convenience that allows attribute-access to the msg dict.
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from binascii import b2a_hex
import hashlib
import hmac
import logging
import os
import pickle
import pprint
import random
import warnings
from datetime import datetime
PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL
# We are using compare_digest to limit the surface of timing attacks
from hmac import compare_digest
from datetime import timezone
utc = timezone.utc
import zmq
from zmq.utils import jsonapi
from zmq.eventloop.ioloop import IOLoop
from zmq.eventloop.zmqstream import ZMQStream
from traitlets.config.configurable import Configurable, LoggingConfigurable
from ipython_genutils.importstring import import_item
from jupyter_client.jsonutil import extract_dates, squash_dates, date_default
from ipython_genutils.py3compat import str_to_bytes, str_to_unicode
from traitlets import (
CBytes, Unicode, Bool, Any, Instance, Set, DottedObjectName, CUnicode,
Dict, Integer, TraitError, observe
)
from jupyter_client import protocol_version
from jupyter_client.adapter import adapt
from traitlets.log import get_logger
#-----------------------------------------------------------------------------
# utility functions
#-----------------------------------------------------------------------------
def squash_unicode(obj):
"""coerce unicode back to bytestrings."""
if isinstance(obj,dict):
for key in obj.keys():
obj[key] = squash_unicode(obj[key])
if isinstance(key, str):
obj[squash_unicode(key)] = obj.pop(key)
elif isinstance(obj, list):
for i,v in enumerate(obj):
obj[i] = squash_unicode(v)
elif isinstance(obj, str):
obj = obj.encode('utf8')
return obj
#-----------------------------------------------------------------------------
# globals and defaults
#-----------------------------------------------------------------------------
# default values for the thresholds:
MAX_ITEMS = 64
MAX_BYTES = 1024
# ISO8601-ify datetime objects
# allow unicode
# disallow nan, because it's not actually valid JSON
json_packer = lambda obj: jsonapi.dumps(obj, default=date_default,
ensure_ascii=False, allow_nan=False,
)
json_unpacker = lambda s: jsonapi.loads(s)
pickle_packer = lambda o: pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)
pickle_unpacker = pickle.loads
default_packer = json_packer
default_unpacker = json_unpacker
DELIM = b"<IDS|MSG>"
# singleton dummy tracker, which will always report as done
DONE = zmq.MessageTracker()
#-----------------------------------------------------------------------------
# Mixin tools for apps that use Sessions
#-----------------------------------------------------------------------------
def new_id():
"""Generate a new random id.
Avoids problematic runtime import in stdlib uuid on Python 2.
Returns
-------
id string (16 random bytes as hex-encoded text, chunks separated by '-')
"""
buf = os.urandom(16)
return '-'.join(b2a_hex(x).decode('ascii') for x in (
buf[:4], buf[4:]
))
def new_id_bytes():
"""Return new_id as ascii bytes"""
return new_id().encode('ascii')
session_aliases = dict(
ident = 'Session.session',
user = 'Session.username',
keyfile = 'Session.keyfile',
)
session_flags = {
'secure' : ({'Session' : { 'key' : new_id_bytes(),
'keyfile' : '' }},
"""Use HMAC digests for authentication of messages.
Setting this flag will generate a new UUID to use as the HMAC key.
"""),
'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
"""Don't authenticate messages."""),
}
def default_secure(cfg):
"""Set the default behavior for a config environment to be secure.
If Session.key/keyfile have not been set, set Session.key to
a new random UUID.
"""
warnings.warn("default_secure is deprecated", DeprecationWarning)
if 'Session' in cfg:
if 'key' in cfg.Session or 'keyfile' in cfg.Session:
return
# key/keyfile not specified, generate new UUID:
cfg.Session.key = new_id_bytes()
def utcnow():
"""Return timezone-aware UTC timestamp"""
return datetime.utcnow().replace(tzinfo=utc)
#-----------------------------------------------------------------------------
# Classes
#-----------------------------------------------------------------------------
class SessionFactory(LoggingConfigurable):
"""The Base class for configurables that have a Session, Context, logger,
and IOLoop.
"""
logname = Unicode('')
@observe('logname')
def _logname_changed(self, change):
self.log = logging.getLogger(change['new'])
# not configurable:
context = Instance('zmq.Context')
def _context_default(self):
return zmq.Context()
session = Instance('jupyter_client.session.Session',
allow_none=True)
loop = Instance('tornado.ioloop.IOLoop')
def _loop_default(self):
return IOLoop.current()
def __init__(self, **kwargs):
super().__init__(**kwargs)
if self.session is None:
# construct the session
self.session = Session(**kwargs)
class Message(object):
"""A simple message object that maps dict keys to attributes.
A Message can be created from a dict and a dict from a Message instance
simply by calling dict(msg_obj)."""
def __init__(self, msg_dict):
dct = self.__dict__
for k, v in dict(msg_dict).items():
if isinstance(v, dict):
v = Message(v)
dct[k] = v
# Having this iterator lets dict(msg_obj) work out of the box.
def __iter__(self):
return self.__dict__.items()
def __repr__(self):
return repr(self.__dict__)
def __str__(self):
return pprint.pformat(self.__dict__)
def __contains__(self, k):
return k in self.__dict__
def __getitem__(self, k):
return self.__dict__[k]
def msg_header(msg_id, msg_type, username, session):
"""Create a new message header"""
date = utcnow()
version = protocol_version
return locals()
def extract_header(msg_or_header):
"""Given a message or header, return the header."""
if not msg_or_header:
return {}
try:
# See if msg_or_header is the entire message.
h = msg_or_header['header']
except KeyError:
try:
# See if msg_or_header is just the header
h = msg_or_header['msg_id']
except KeyError:
raise
else:
h = msg_or_header
if not isinstance(h, dict):
h = dict(h)
return h
class Session(Configurable):
"""Object for handling serialization and sending of messages.
The Session object handles building messages and sending them
with ZMQ sockets or ZMQStream objects. Objects can communicate with each
other over the network via Session objects, and only need to work with the
dict-based IPython message spec. The Session will handle
serialization/deserialization, security, and metadata.
Sessions support configurable serialization via packer/unpacker traits,
and signing with HMAC digests via the key/keyfile traits.
Parameters
----------
debug : bool
whether to trigger extra debugging statements
packer/unpacker : str : 'json', 'pickle' or import_string
importstrings for methods to serialize message parts. If just
'json' or 'pickle', predefined JSON and pickle packers will be used.
Otherwise, the entire importstring must be used.
The functions must accept at least valid JSON input, and output *bytes*.
For example, to use msgpack:
packer = 'msgpack.packb', unpacker='msgpack.unpackb'
pack/unpack : callables
You can also set the pack/unpack callables for serialization directly.
session : bytes
the ID of this Session object. The default is to generate a new UUID.
username : unicode
username added to message headers. The default is to ask the OS.
key : bytes
The key used to initialize an HMAC signature. If unset, messages
will not be signed or checked.
keyfile : filepath
The file containing a key. If this is set, `key` will be initialized
to the contents of the file.
"""
debug = Bool(False, config=True, help="""Debug output in the Session""")
check_pid = Bool(True, config=True,
help="""Whether to check PID to protect against calls after fork.
This check can be disabled if fork-safety is handled elsewhere.
""")
packer = DottedObjectName('json',config=True,
help="""The name of the packer for serializing messages.
Should be one of 'json', 'pickle', or an import name
for a custom callable serializer.""")
@observe('packer')
def _packer_changed(self, change):
new = change['new']
if new.lower() == 'json':
self.pack = json_packer
self.unpack = json_unpacker
self.unpacker = new
elif new.lower() == 'pickle':
self.pack = pickle_packer
self.unpack = pickle_unpacker
self.unpacker = new
else:
self.pack = import_item(str(new))
unpacker = DottedObjectName('json', config=True,
help="""The name of the unpacker for unserializing messages.
Only used with custom functions for `packer`.""")
@observe('unpacker')
def _unpacker_changed(self, change):
new = change['new']
if new.lower() == 'json':
self.pack = json_packer
self.unpack = json_unpacker
self.packer = new
elif new.lower() == 'pickle':
self.pack = pickle_packer
self.unpack = pickle_unpacker
self.packer = new
else:
self.unpack = import_item(str(new))
session = CUnicode('', config=True,
help="""The UUID identifying this session.""")
def _session_default(self):
u = new_id()
self.bsession = u.encode('ascii')
return u
@observe('session')
def _session_changed(self, change):
self.bsession = self.session.encode('ascii')
# bsession is the session as bytes
bsession = CBytes(b'')
username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
help="""Username for the Session. Default is your system username.""",
config=True)
metadata = Dict({}, config=True,
help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
# if 0, no adapting to do.
adapt_version = Integer(0)
# message signature related traits:
key = CBytes(config=True,
help="""execution key, for signing messages.""")
def _key_default(self):
return new_id_bytes()
@observe('key')
def _key_changed(self, change):
self._new_auth()
signature_scheme = Unicode('hmac-sha256', config=True,
help="""The digest scheme used to construct the message signatures.
Must have the form 'hmac-HASH'.""")
@observe('signature_scheme')
def _signature_scheme_changed(self, change):
new = change['new']
if not new.startswith('hmac-'):
raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
hash_name = new.split('-', 1)[1]
try:
self.digest_mod = getattr(hashlib, hash_name)
except AttributeError as e:
raise TraitError("hashlib has no such attribute: %s" %
hash_name) from e
self._new_auth()
digest_mod = Any()
def _digest_mod_default(self):
return hashlib.sha256
auth = Instance(hmac.HMAC, allow_none=True)
def _new_auth(self):
if self.key:
self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
else:
self.auth = None
digest_history = Set()
digest_history_size = Integer(2**16, config=True,
help="""The maximum number of digests to remember.
The digest history will be culled when it exceeds this value.
"""
)
keyfile = Unicode('', config=True,
help="""path to file containing execution key.""")
@observe('keyfile')
def _keyfile_changed(self, change):
with open(change['new'], 'rb') as f:
self.key = f.read().strip()
# for protecting against sends from forks
pid = Integer()
# serialization traits:
pack = Any(default_packer) # the actual packer function
@observe('pack')
def _pack_changed(self, change):
new = change['new']
if not callable(new):
raise TypeError("packer must be callable, not %s"%type(new))
unpack = Any(default_unpacker) # the actual packer function
@observe('unpack')
def _unpack_changed(self, change):
# unpacker is not checked - it is assumed to be
new = change['new']
if not callable(new):
raise TypeError("unpacker must be callable, not %s"%type(new))
# thresholds:
copy_threshold = Integer(2**16, config=True,
help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
buffer_threshold = Integer(MAX_BYTES, config=True,
help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
item_threshold = Integer(MAX_ITEMS, config=True,
help="""The maximum number of items for a container to be introspected for custom serialization.
Containers larger than this are pickled outright.
"""
)
def __init__(self, **kwargs):
"""create a Session object
Parameters
----------
debug : bool
whether to trigger extra debugging statements
packer/unpacker : str : 'json', 'pickle' or import_string
importstrings for methods to serialize message parts. If just
'json' or 'pickle', predefined JSON and pickle packers will be used.
Otherwise, the entire importstring must be used.
The functions must accept at least valid JSON input, and output
*bytes*.
For example, to use msgpack:
packer = 'msgpack.packb', unpacker='msgpack.unpackb'
pack/unpack : callables
You can also set the pack/unpack callables for serialization
directly.
session : unicode (must be ascii)
the ID of this Session object. The default is to generate a new
UUID.
bsession : bytes
The session as bytes
username : unicode
username added to message headers. The default is to ask the OS.
key : bytes
The key used to initialize an HMAC signature. If unset, messages
will not be signed or checked.
signature_scheme : str
The message digest scheme. Currently must be of the form 'hmac-HASH',
where 'HASH' is a hashing function available in Python's hashlib.
The default is 'hmac-sha256'.
This is ignored if 'key' is empty.
keyfile : filepath
The file containing a key. If this is set, `key` will be
initialized to the contents of the file.
"""
super().__init__(**kwargs)
self._check_packers()
self.none = self.pack({})
# ensure self._session_default() if necessary, so bsession is defined:
self.session
self.pid = os.getpid()
self._new_auth()
if not self.key:
get_logger().warning("Message signing is disabled. This is insecure and not recommended!")
def clone(self):
"""Create a copy of this Session
Useful when connecting multiple times to a given kernel.
This prevents a shared digest_history warning about duplicate digests
due to multiple connections to IOPub in the same process.
.. versionadded:: 5.1
"""
# make a copy
new_session = type(self)()
for name in self.traits():
setattr(new_session, name, getattr(self, name))
# fork digest_history
new_session.digest_history = set()
new_session.digest_history.update(self.digest_history)
return new_session
message_count = 0
@property
def msg_id(self):
message_number = self.message_count
self.message_count += 1
return '{}_{}'.format(self.session, message_number)
def _check_packers(self):
"""check packers for datetime support."""
pack = self.pack
unpack = self.unpack
# check simple serialization
msg = dict(a=[1,'hi'])
try:
packed = pack(msg)
except Exception as e:
msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
if self.packer == 'json':
jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
else:
jsonmsg = ""
raise ValueError(
msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
) from e
# ensure packed message is bytes
if not isinstance(packed, bytes):
raise ValueError("message packed to %r, but bytes are required"%type(packed))
# check that unpack is pack's inverse
try:
unpacked = unpack(packed)
assert unpacked == msg
except Exception as e:
msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
if self.packer == 'json':
jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
else:
jsonmsg = ""
raise ValueError(
msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
) from e
# check datetime support
msg = dict(t=utcnow())
try:
unpacked = unpack(pack(msg))
if isinstance(unpacked['t'], datetime):
raise ValueError("Shouldn't deserialize to datetime")
except Exception:
self.pack = lambda o: pack(squash_dates(o))
self.unpack = lambda s: unpack(s)
def msg_header(self, msg_type):
return msg_header(self.msg_id, msg_type, self.username, self.session)
def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
"""Return the nested message dict.
This format is different from what is sent over the wire. The
serialize/deserialize methods converts this nested message dict to the wire
format, which is a list of message parts.
"""
msg = {}
header = self.msg_header(msg_type) if header is None else header
msg['header'] = header
msg['msg_id'] = header['msg_id']
msg['msg_type'] = header['msg_type']
msg['parent_header'] = {} if parent is None else extract_header(parent)
msg['content'] = {} if content is None else content
msg['metadata'] = self.metadata.copy()
if metadata is not None:
msg['metadata'].update(metadata)
return msg
def sign(self, msg_list):
"""Sign a message with HMAC digest. If no auth, return b''.
Parameters
----------
msg_list : list
The [p_header,p_parent,p_content] part of the message list.
"""
if self.auth is None:
return b''
h = self.auth.copy()
for m in msg_list:
h.update(m)
return str_to_bytes(h.hexdigest())
def serialize(self, msg, ident=None):
"""Serialize the message components to bytes.
This is roughly the inverse of deserialize. The serialize/deserialize
methods work with full message lists, whereas pack/unpack work with
the individual message parts in the message list.
Parameters
----------
msg : dict or Message
The next message dict as returned by the self.msg method.
Returns
-------
msg_list : list
The list of bytes objects to be sent with the format::
[ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
p_metadata, p_content, buffer1, buffer2, ...]
In this list, the ``p_*`` entities are the packed or serialized
versions, so if JSON is used, these are utf8 encoded JSON strings.
"""
content = msg.get('content', {})
if content is None:
content = self.none
elif isinstance(content, dict):
content = self.pack(content)
elif isinstance(content, bytes):
# content is already packed, as in a relayed message
pass
elif isinstance(content, str):
# should be bytes, but JSON often spits out unicode
content = content.encode('utf8')
else:
raise TypeError("Content incorrect type: %s"%type(content))
real_message = [self.pack(msg['header']),
self.pack(msg['parent_header']),
self.pack(msg['metadata']),
content,
]
to_send = []
if isinstance(ident, list):
# accept list of idents
to_send.extend(ident)
elif ident is not None:
to_send.append(ident)
to_send.append(DELIM)
signature = self.sign(real_message)
to_send.append(signature)
to_send.extend(real_message)
return to_send
def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
buffers=None, track=False, header=None, metadata=None):
"""Build and send a message via stream or socket.
The message format used by this function internally is as follows:
[ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
buffer1,buffer2,...]
The serialize/deserialize methods convert the nested message dict into this
format.
Parameters
----------
stream : zmq.Socket or ZMQStream
The socket-like object used to send the data.
msg_or_type : str or Message/dict
Normally, msg_or_type will be a msg_type unless a message is being
sent more than once. If a header is supplied, this can be set to
None and the msg_type will be pulled from the header.
content : dict or None
The content of the message (ignored if msg_or_type is a message).
header : dict or None
The header dict for the message (ignored if msg_to_type is a message).
parent : Message or dict or None
The parent or parent header describing the parent of this message
(ignored if msg_or_type is a message).
ident : bytes or list of bytes
The zmq.IDENTITY routing path.
metadata : dict or None
The metadata describing the message
buffers : list or None
The already-serialized buffers to be appended to the message.
track : bool
Whether to track. Only for use with Sockets, because ZMQStream
objects cannot track messages.
Returns
-------
msg : dict
The constructed message.
"""
if not isinstance(stream, zmq.Socket):
# ZMQStreams and dummy sockets do not support tracking.
track = False
if isinstance(msg_or_type, (Message, dict)):
# We got a Message or message dict, not a msg_type so don't
# build a new Message.
msg = msg_or_type
buffers = buffers or msg.get('buffers', [])
else:
msg = self.msg(msg_or_type, content=content, parent=parent,
header=header, metadata=metadata)
if self.check_pid and not os.getpid() == self.pid:
get_logger().warning("WARNING: attempted to send message from fork\n%s",
msg
)
return
buffers = [] if buffers is None else buffers
for idx, buf in enumerate(buffers):
if isinstance(buf, memoryview):
view = buf
else:
try:
# check to see if buf supports the buffer protocol.
view = memoryview(buf)
except TypeError as e:
raise TypeError("Buffer objects must support the buffer protocol.") from e
# memoryview.contiguous is new in 3.3,
# just skip the check on Python 2
if hasattr(view, 'contiguous') and not view.contiguous:
# zmq requires memoryviews to be contiguous
raise ValueError("Buffer %i (%r) is not contiguous" % (idx, buf))
if self.adapt_version:
msg = adapt(msg, self.adapt_version)
to_send = self.serialize(msg, ident)
to_send.extend(buffers)
longest = max([ len(s) for s in to_send ])
copy = (longest < self.copy_threshold)
if buffers and track and not copy:
# only really track when we are doing zero-copy buffers
tracker = stream.send_multipart(to_send, copy=False, track=True)
else:
# use dummy tracker, which will be done immediately
tracker = DONE
stream.send_multipart(to_send, copy=copy)
if self.debug:
pprint.pprint(msg)
pprint.pprint(to_send)
pprint.pprint(buffers)
msg['tracker'] = tracker
return msg
def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
"""Send a raw message via ident path.
This method is used to send a already serialized message.
Parameters
----------
stream : ZMQStream or Socket
The ZMQ stream or socket to use for sending the message.
msg_list : list
The serialized list of messages to send. This only includes the
[p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
the message.
ident : ident or list
A single ident or a list of idents to use in sending.
"""
to_send = []
if isinstance(ident, bytes):
ident = [ident]
if ident is not None:
to_send.extend(ident)
to_send.append(DELIM)
# Don't include buffers in signature (per spec).
to_send.append(self.sign(msg_list[0:4]))
to_send.extend(msg_list)
stream.send_multipart(to_send, flags, copy=copy)
def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
"""Receive and unpack a message.
Parameters
----------
socket : ZMQStream or Socket
The socket or stream to use in receiving.
Returns
-------
[idents], msg
[idents] is a list of idents and msg is a nested message dict of
same format as self.msg returns.
"""
if isinstance(socket, ZMQStream):
socket = socket.socket
try:
msg_list = socket.recv_multipart(mode, copy=copy)
except zmq.ZMQError as e:
if e.errno == zmq.EAGAIN:
# We can convert EAGAIN to None as we know in this case
# recv_multipart won't return None.
return None,None
else:
raise
# split multipart message into identity list and message dict
# invalid large messages can cause very expensive string comparisons
idents, msg_list = self.feed_identities(msg_list, copy)
try:
return idents, self.deserialize(msg_list, content=content, copy=copy)
except Exception as e:
# TODO: handle it
raise e
def feed_identities(self, msg_list, copy=True):
"""Split the identities from the rest of the message.
Feed until DELIM is reached, then return the prefix as idents and
remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
but that would be silly.
Parameters
----------
msg_list : a list of Message or bytes objects
The message to be split.
copy : bool
flag determining whether the arguments are bytes or Messages
Returns
-------
(idents, msg_list) : two lists
idents will always be a list of bytes, each of which is a ZMQ
identity. msg_list will be a list of bytes or zmq.Messages of the
form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
should be unpackable/unserializable via self.deserialize at this
point.
"""
if copy:
idx = msg_list.index(DELIM)
return msg_list[:idx], msg_list[idx+1:]
else:
failed = True
for idx,m in enumerate(msg_list):
if m.bytes == DELIM:
failed = False
break
if failed:
raise ValueError("DELIM not in msg_list")
idents, msg_list = msg_list[:idx], msg_list[idx+1:]
return [m.bytes for m in idents], msg_list
def _add_digest(self, signature):
"""add a digest to history to protect against replay attacks"""
if self.digest_history_size == 0:
# no history, never add digests
return
self.digest_history.add(signature)
if len(self.digest_history) > self.digest_history_size:
# threshold reached, cull 10%
self._cull_digest_history()
def _cull_digest_history(self):
"""cull the digest history
Removes a randomly selected 10% of the digest history
"""
current = len(self.digest_history)
n_to_cull = max(int(current // 10), current - self.digest_history_size)
if n_to_cull >= current:
self.digest_history = set()
return
to_cull = random.sample(self.digest_history, n_to_cull)
self.digest_history.difference_update(to_cull)
def deserialize(self, msg_list, content=True, copy=True):
"""Unserialize a msg_list to a nested message dict.
This is roughly the inverse of serialize. The serialize/deserialize
methods work with full message lists, whereas pack/unpack work with
the individual message parts in the message list.
Parameters
----------
msg_list : list of bytes or Message objects
The list of message parts of the form [HMAC,p_header,p_parent,
p_metadata,p_content,buffer1,buffer2,...].
content : bool (True)
Whether to unpack the content dict (True), or leave it packed
(False).
copy : bool (True)
Whether msg_list contains bytes (True) or the non-copying Message
objects in each place (False).
Returns
-------
msg : dict
The nested message dict with top-level keys [header, parent_header,
content, buffers]. The buffers are returned as memoryviews.
"""
minlen = 5
message = {}
if not copy:
# pyzmq didn't copy the first parts of the message, so we'll do it
for i in range(minlen):
msg_list[i] = msg_list[i].bytes
if self.auth is not None:
signature = msg_list[0]
if not signature:
raise ValueError("Unsigned Message")
if signature in self.digest_history:
raise ValueError("Duplicate Signature: %r" % signature)
if content:
# Only store signature if we are unpacking content, don't store if just peeking.
self._add_digest(signature)
check = self.sign(msg_list[1:5])
if not compare_digest(signature, check):
raise ValueError("Invalid Signature: %r" % signature)
if not len(msg_list) >= minlen:
raise TypeError("malformed message, must have at least %i elements"%minlen)
header = self.unpack(msg_list[1])
message['header'] = extract_dates(header)
message['msg_id'] = header['msg_id']
message['msg_type'] = header['msg_type']
message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
message['metadata'] = self.unpack(msg_list[3])
if content:
message['content'] = self.unpack(msg_list[4])
else:
message['content'] = msg_list[4]
buffers = [memoryview(b) for b in msg_list[5:]]
if buffers and buffers[0].shape is None:
# force copy to workaround pyzmq #646
buffers = [memoryview(b.bytes) for b in msg_list[5:]]
message['buffers'] = buffers
if self.debug:
pprint.pprint(message)
# adapt to the current version
return adapt(message)
def unserialize(self, *args, **kwargs):
warnings.warn(
"Session.unserialize is deprecated. Use Session.deserialize.",
DeprecationWarning,
)
return self.deserialize(*args, **kwargs)
def test_msg2obj():
am = dict(x=1)
ao = Message(am)
assert ao.x == am['x']
am['y'] = dict(z=1)
ao = Message(am)
assert ao.y.z == am['y']['z']
k1, k2 = 'y', 'z'
assert ao[k1][k2] == am[k1][k2]
am2 = dict(ao)
assert am['x'] == am2['x']
assert am['y']['z'] == am2['y']['z']

View file

@ -0,0 +1 @@
from jupyter_client.ssh.tunnel import *

View file

@ -0,0 +1,87 @@
#
# This file is adapted from a paramiko demo, and thus licensed under LGPL 2.1.
# Original Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
# Edits Copyright (C) 2010 The IPython Team
#
# Paramiko is free software; you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 2.1 of the License, or (at your option)
# any later version.
#
# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Paramiko; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02111-1301 USA.
"""
Sample script showing how to do local port forwarding over paramiko.
This script connects to the requested SSH server and sets up local port
forwarding (the openssh -L option) from a local port through a tunneled
connection to a destination reachable from the SSH server machine.
"""
import logging
import select
import socketserver
logger = logging.getLogger('ssh')
class ForwardServer (socketserver.ThreadingTCPServer):
daemon_threads = True
allow_reuse_address = True
class Handler (socketserver.BaseRequestHandler):
def handle(self):
try:
chan = self.ssh_transport.open_channel('direct-tcpip',
(self.chain_host, self.chain_port),
self.request.getpeername())
except Exception as e:
logger.debug('Incoming request to %s:%d failed: %s' % (self.chain_host,
self.chain_port,
repr(e)))
return
if chan is None:
logger.debug('Incoming request to %s:%d was rejected by the SSH server.' %
(self.chain_host, self.chain_port))
return
logger.debug('Connected! Tunnel open %r -> %r -> %r' % (self.request.getpeername(),
chan.getpeername(), (self.chain_host, self.chain_port)))
while True:
r, w, x = select.select([self.request, chan], [], [])
if self.request in r:
data = self.request.recv(1024)
if len(data) == 0:
break
chan.send(data)
if chan in r:
data = chan.recv(1024)
if len(data) == 0:
break
self.request.send(data)
chan.close()
self.request.close()
logger.debug('Tunnel closed ')
def forward_tunnel(local_port, remote_host, remote_port, transport):
# this is a little convoluted, but lets me configure things for the Handler
# object. (SocketServer doesn't give Handlers any way to access the outer
# server normally.)
class SubHander (Handler):
chain_host = remote_host
chain_port = remote_port
ssh_transport = transport
ForwardServer(('127.0.0.1', local_port), SubHander).serve_forever()
__all__ = ['forward_tunnel']

View file

@ -0,0 +1,372 @@
"""Basic ssh tunnel utilities, and convenience functions for tunneling
zeromq connections.
"""
# Copyright (C) 2010-2011 IPython Development Team
# Copyright (C) 2011- PyZMQ Developers
#
# Redistributed from IPython under the terms of the BSD License.
import atexit
import os
import re
import signal
import socket
import sys
import warnings
from getpass import getpass, getuser
from multiprocessing import Process
try:
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
import paramiko
SSHException = paramiko.ssh_exception.SSHException
except ImportError:
paramiko = None
class SSHException(Exception):
pass
else:
from .forward import forward_tunnel
try:
import pexpect
except ImportError:
pexpect = None
from zmq.utils.strtypes import b
def select_random_ports(n):
"""Select and return n random ports that are available."""
ports = []
sockets = []
for i in range(n):
sock = socket.socket()
sock.bind(('', 0))
ports.append(sock.getsockname()[1])
sockets.append(sock)
for sock in sockets:
sock.close()
return ports
#-----------------------------------------------------------------------------
# Check for passwordless login
#-----------------------------------------------------------------------------
_password_pat = re.compile(b(r'pass(word|phrase):'), re.IGNORECASE)
def try_passwordless_ssh(server, keyfile, paramiko=None):
"""Attempt to make an ssh connection without a password.
This is mainly used for requiring password input only once
when many tunnels may be connected to the same server.
If paramiko is None, the default for the platform is chosen.
"""
if paramiko is None:
paramiko = sys.platform == 'win32'
if not paramiko:
f = _try_passwordless_openssh
else:
f = _try_passwordless_paramiko
return f(server, keyfile)
def _try_passwordless_openssh(server, keyfile):
"""Try passwordless login with shell ssh command."""
if pexpect is None:
raise ImportError("pexpect unavailable, use paramiko")
cmd = 'ssh -f ' + server
if keyfile:
cmd += ' -i ' + keyfile
cmd += ' exit'
# pop SSH_ASKPASS from env
env = os.environ.copy()
env.pop('SSH_ASKPASS', None)
ssh_newkey = 'Are you sure you want to continue connecting'
p = pexpect.spawn(cmd, env=env)
while True:
try:
i = p.expect([ssh_newkey, _password_pat], timeout=.1)
if i == 0:
raise SSHException('The authenticity of the host can\'t be established.')
except pexpect.TIMEOUT:
continue
except pexpect.EOF:
return True
else:
return False
def _try_passwordless_paramiko(server, keyfile):
"""Try passwordless login with paramiko."""
if paramiko is None:
msg = "Paramiko unavailable, "
if sys.platform == 'win32':
msg += "Paramiko is required for ssh tunneled connections on Windows."
else:
msg += "use OpenSSH."
raise ImportError(msg)
username, server, port = _split_server(server)
client = paramiko.SSHClient()
client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.WarningPolicy())
try:
client.connect(server, port, username=username, key_filename=keyfile,
look_for_keys=True)
except paramiko.AuthenticationException:
return False
else:
client.close()
return True
def tunnel_connection(socket, addr, server, keyfile=None, password=None, paramiko=None, timeout=60):
"""Connect a socket to an address via an ssh tunnel.
This is a wrapper for socket.connect(addr), when addr is not accessible
from the local machine. It simply creates an ssh tunnel using the remaining args,
and calls socket.connect('tcp://localhost:lport') where lport is the randomly
selected local port of the tunnel.
"""
new_url, tunnel = open_tunnel(addr, server, keyfile=keyfile, password=password, paramiko=paramiko, timeout=timeout)
socket.connect(new_url)
return tunnel
def open_tunnel(addr, server, keyfile=None, password=None, paramiko=None, timeout=60):
"""Open a tunneled connection from a 0MQ url.
For use inside tunnel_connection.
Returns
-------
(url, tunnel) : (str, object)
The 0MQ url that has been forwarded, and the tunnel object
"""
lport = select_random_ports(1)[0]
transport, addr = addr.split('://')
ip, rport = addr.split(':')
rport = int(rport)
if paramiko is None:
paramiko = sys.platform == 'win32'
if paramiko:
tunnelf = paramiko_tunnel
else:
tunnelf = openssh_tunnel
tunnel = tunnelf(lport, rport, server, remoteip=ip, keyfile=keyfile, password=password, timeout=timeout)
return 'tcp://127.0.0.1:%i' % lport, tunnel
def openssh_tunnel(lport, rport, server, remoteip='127.0.0.1', keyfile=None, password=None, timeout=60):
"""Create an ssh tunnel using command-line ssh that connects port lport
on this machine to localhost:rport on server. The tunnel
will automatically close when not in use, remaining open
for a minimum of timeout seconds for an initial connection.
This creates a tunnel redirecting `localhost:lport` to `remoteip:rport`,
as seen from `server`.
keyfile and password may be specified, but ssh config is checked for defaults.
Parameters
----------
lport : int
local port for connecting to the tunnel from this machine.
rport : int
port on the remote machine to connect to.
server : str
The ssh server to connect to. The full ssh server string will be parsed.
user@server:port
remoteip : str [Default: 127.0.0.1]
The remote ip, specifying the destination of the tunnel.
Default is localhost, which means that the tunnel would redirect
localhost:lport on this machine to localhost:rport on the *server*.
keyfile : str; path to public key file
This specifies a key to be used in ssh login, default None.
Regular default ssh keys will be used without specifying this argument.
password : str;
Your ssh password to the ssh server. Note that if this is left None,
you will be prompted for it if passwordless key based login is unavailable.
timeout : int [default: 60]
The time (in seconds) after which no activity will result in the tunnel
closing. This prevents orphaned tunnels from running forever.
"""
if pexpect is None:
raise ImportError("pexpect unavailable, use paramiko_tunnel")
ssh = "ssh "
if keyfile:
ssh += "-i " + keyfile
if ':' in server:
server, port = server.split(':')
ssh += " -p %s" % port
cmd = "%s -O check %s" % (ssh, server)
(output, exitstatus) = pexpect.run(cmd, withexitstatus=True)
if not exitstatus:
pid = int(output[output.find(b"(pid=")+5:output.find(b")")])
cmd = "%s -O forward -L 127.0.0.1:%i:%s:%i %s" % (
ssh, lport, remoteip, rport, server)
(output, exitstatus) = pexpect.run(cmd, withexitstatus=True)
if not exitstatus:
atexit.register(_stop_tunnel, cmd.replace("-O forward", "-O cancel", 1))
return pid
cmd = "%s -f -S none -L 127.0.0.1:%i:%s:%i %s sleep %i" % (
ssh, lport, remoteip, rport, server, timeout)
# pop SSH_ASKPASS from env
env = os.environ.copy()
env.pop('SSH_ASKPASS', None)
ssh_newkey = 'Are you sure you want to continue connecting'
tunnel = pexpect.spawn(cmd, env=env)
failed = False
while True:
try:
i = tunnel.expect([ssh_newkey, _password_pat], timeout=.1)
if i == 0:
raise SSHException('The authenticity of the host can\'t be established.')
except pexpect.TIMEOUT:
continue
except pexpect.EOF as e:
if tunnel.exitstatus:
print(tunnel.exitstatus)
print(tunnel.before)
print(tunnel.after)
raise RuntimeError("tunnel '%s' failed to start" % (cmd)) from e
else:
return tunnel.pid
else:
if failed:
print("Password rejected, try again")
password = None
if password is None:
password = getpass("%s's password: " % (server))
tunnel.sendline(password)
failed = True
def _stop_tunnel(cmd):
pexpect.run(cmd)
def _split_server(server):
if '@' in server:
username, server = server.split('@', 1)
else:
username = getuser()
if ':' in server:
server, port = server.split(':')
port = int(port)
else:
port = 22
return username, server, port
def paramiko_tunnel(lport, rport, server, remoteip='127.0.0.1', keyfile=None, password=None, timeout=60):
"""launch a tunner with paramiko in a subprocess. This should only be used
when shell ssh is unavailable (e.g. Windows).
This creates a tunnel redirecting `localhost:lport` to `remoteip:rport`,
as seen from `server`.
If you are familiar with ssh tunnels, this creates the tunnel:
ssh server -L localhost:lport:remoteip:rport
keyfile and password may be specified, but ssh config is checked for defaults.
Parameters
----------
lport : int
local port for connecting to the tunnel from this machine.
rport : int
port on the remote machine to connect to.
server : str
The ssh server to connect to. The full ssh server string will be parsed.
user@server:port
remoteip : str [Default: 127.0.0.1]
The remote ip, specifying the destination of the tunnel.
Default is localhost, which means that the tunnel would redirect
localhost:lport on this machine to localhost:rport on the *server*.
keyfile : str; path to public key file
This specifies a key to be used in ssh login, default None.
Regular default ssh keys will be used without specifying this argument.
password : str;
Your ssh password to the ssh server. Note that if this is left None,
you will be prompted for it if passwordless key based login is unavailable.
timeout : int [default: 60]
The time (in seconds) after which no activity will result in the tunnel
closing. This prevents orphaned tunnels from running forever.
"""
if paramiko is None:
raise ImportError("Paramiko not available")
if password is None:
if not _try_passwordless_paramiko(server, keyfile):
password = getpass("%s's password: " % (server))
p = Process(target=_paramiko_tunnel,
args=(lport, rport, server, remoteip),
kwargs=dict(keyfile=keyfile, password=password))
p.daemon = True
p.start()
return p
def _paramiko_tunnel(lport, rport, server, remoteip, keyfile=None, password=None):
"""Function for actually starting a paramiko tunnel, to be passed
to multiprocessing.Process(target=this), and not called directly.
"""
username, server, port = _split_server(server)
client = paramiko.SSHClient()
client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.WarningPolicy())
try:
client.connect(server, port, username=username, key_filename=keyfile,
look_for_keys=True, password=password)
# except paramiko.AuthenticationException:
# if password is None:
# password = getpass("%s@%s's password: "%(username, server))
# client.connect(server, port, username=username, password=password)
# else:
# raise
except Exception as e:
print('*** Failed to connect to %s:%d: %r' % (server, port, e))
sys.exit(1)
# Don't let SIGINT kill the tunnel subprocess
signal.signal(signal.SIGINT, signal.SIG_IGN)
try:
forward_tunnel(lport, remoteip, rport, client.get_transport())
except KeyboardInterrupt:
print('SIGINT: Port forwarding stopped cleanly')
sys.exit(0)
except Exception as e:
print("Port forwarding stopped uncleanly: %s" % e)
sys.exit(255)
if sys.platform == 'win32':
ssh_tunnel = paramiko_tunnel
else:
ssh_tunnel = openssh_tunnel
__all__ = ['tunnel_connection', 'ssh_tunnel', 'openssh_tunnel', 'paramiko_tunnel', 'try_passwordless_ssh']

View file

@ -0,0 +1,77 @@
"""Test kernel for signalling subprocesses"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import os
from subprocess import Popen, PIPE
import sys
import time
from ipykernel.displayhook import ZMQDisplayHook
from ipykernel.kernelbase import Kernel
from ipykernel.kernelapp import IPKernelApp
class SignalTestKernel(Kernel):
"""Kernel for testing subprocess signaling"""
implementation = 'signaltest'
implementation_version = '0.0'
banner = ''
def __init__(self, **kwargs):
kwargs.pop('user_ns', None)
super().__init__(**kwargs)
self.children = []
def do_execute(self, code, silent, store_history=True, user_expressions=None,
allow_stdin=False):
code = code.strip()
reply = {
'status': 'ok',
'user_expressions': {},
}
if code == 'start':
child = Popen(['bash', '-i', '-c', 'sleep 30'], stderr=PIPE)
self.children.append(child)
reply['user_expressions']['pid'] = self.children[-1].pid
elif code == 'check':
reply['user_expressions']['poll'] = [ child.poll() for child in self.children ]
elif code == 'env':
reply['user_expressions']['env'] = os.getenv("TEST_VARS", "")
elif code == 'sleep':
try:
time.sleep(10)
except KeyboardInterrupt:
reply['user_expressions']['interrupted'] = True
else:
reply['user_expressions']['interrupted'] = False
else:
reply['status'] = 'error'
reply['ename'] = 'Error'
reply['evalue'] = code
reply['traceback'] = ['no such command: %s' % code]
return reply
def kernel_info_request(self, *args, **kwargs):
"""Add delay to kernel_info_request
triggers slow-response code in KernelClient.wait_for_ready
"""
return super().kernel_info_request(*args, **kwargs)
class SignalTestApp(IPKernelApp):
kernel_class = SignalTestKernel
def init_io(self):
# Overridden to disable stdout/stderr capture
self.displayhook = ZMQDisplayHook(self.session, self.iopub_socket)
if __name__ == '__main__':
# make startup artificially slow,
# so that we exercise client logic for slow-starting kernels
time.sleep(2)
SignalTestApp.launch_instance()

View file

@ -0,0 +1,404 @@
"""Tests for adapting Jupyter msg spec versions"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import copy
import json
from unittest import TestCase
from jupyter_client.adapter import adapt, V4toV5, V5toV4, code_to_line
from jupyter_client.session import Session
def test_default_version():
s = Session()
msg = s.msg("msg_type")
msg['header'].pop('version')
original = copy.deepcopy(msg)
adapted = adapt(original)
assert adapted['header']['version'] == V4toV5.version
def test_code_to_line_no_code():
line, pos = code_to_line("", 0)
assert line == ""
assert pos == 0
class AdapterTest(TestCase):
def setUp(self):
self.session = Session()
def adapt(self, msg, version=None):
original = copy.deepcopy(msg)
adapted = adapt(msg, version or self.to_version)
return original, adapted
def check_header(self, msg):
pass
class V4toV5TestCase(AdapterTest):
from_version = 4
to_version = 5
def msg(self, msg_type, content):
"""Create a v4 msg (same as v5, minus version header)"""
msg = self.session.msg(msg_type, content)
msg['header'].pop('version')
return msg
def test_same_version(self):
msg = self.msg("execute_result",
content={'status' : 'ok'}
)
original, adapted = self.adapt(msg, self.from_version)
self.assertEqual(original, adapted)
def test_no_adapt(self):
msg = self.msg("input_reply", {'value' : 'some text'})
v4, v5 = self.adapt(msg)
self.assertEqual(v5['header']['version'], V4toV5.version)
v5['header'].pop('version')
self.assertEqual(v4, v5)
def test_rename_type(self):
for v5_type, v4_type in [
('execute_result', 'pyout'),
('execute_input', 'pyin'),
('error', 'pyerr'),
]:
msg = self.msg(v4_type, {'key' : 'value'})
v4, v5 = self.adapt(msg)
self.assertEqual(v5['header']['version'], V4toV5.version)
self.assertEqual(v5['header']['msg_type'], v5_type)
self.assertEqual(v4['content'], v5['content'])
def test_execute_request(self):
msg = self.msg("execute_request", {
'code' : 'a=5',
'silent' : False,
'user_expressions' : {'a' : 'apple'},
'user_variables' : ['b'],
})
v4, v5 = self.adapt(msg)
self.assertEqual(v4['header']['msg_type'], v5['header']['msg_type'])
v4c = v4['content']
v5c = v5['content']
self.assertEqual(v5c['user_expressions'], {'a' : 'apple', 'b': 'b'})
self.assertNotIn('user_variables', v5c)
self.assertEqual(v5c['code'], v4c['code'])
def test_execute_reply(self):
msg = self.msg("execute_reply", {
'status': 'ok',
'execution_count': 7,
'user_variables': {'a': 1},
'user_expressions': {'a+a': 2},
'payload': [{'source':'page', 'text':'blah'}]
})
v4, v5 = self.adapt(msg)
v5c = v5['content']
self.assertNotIn('user_variables', v5c)
self.assertEqual(v5c['user_expressions'], {'a': 1, 'a+a': 2})
self.assertEqual(v5c['payload'], [{'source': 'page',
'data': {'text/plain': 'blah'}}
])
def test_complete_request(self):
msg = self.msg("complete_request", {
'text' : 'a.is',
'line' : 'foo = a.is',
'block' : None,
'cursor_pos' : 10,
})
v4, v5 = self.adapt(msg)
v4c = v4['content']
v5c = v5['content']
for key in ('text', 'line', 'block'):
self.assertNotIn(key, v5c)
self.assertEqual(v5c['cursor_pos'], v4c['cursor_pos'])
self.assertEqual(v5c['code'], v4c['line'])
def test_complete_reply(self):
msg = self.msg("complete_reply", {
'matched_text' : 'a.is',
'matches' : ['a.isalnum',
'a.isalpha',
'a.isdigit',
'a.islower',
],
})
v4, v5 = self.adapt(msg)
v4c = v4['content']
v5c = v5['content']
self.assertEqual(v5c['matches'], v4c['matches'])
self.assertEqual(v5c['metadata'], {})
self.assertEqual(v5c['cursor_start'], -4)
self.assertEqual(v5c['cursor_end'], None)
def test_object_info_request(self):
msg = self.msg("object_info_request", {
'oname' : 'foo',
'detail_level' : 1,
})
v4, v5 = self.adapt(msg)
self.assertEqual(v5['header']['msg_type'], 'inspect_request')
v4c = v4['content']
v5c = v5['content']
self.assertEqual(v5c['code'], v4c['oname'])
self.assertEqual(v5c['cursor_pos'], len(v4c['oname']))
self.assertEqual(v5c['detail_level'], v4c['detail_level'])
def test_object_info_reply(self):
msg = self.msg("object_info_reply", {
'name' : 'foo',
'found' : True,
'status' : 'ok',
'definition' : 'foo(a=5)',
'docstring' : "the docstring",
})
v4, v5 = self.adapt(msg)
self.assertEqual(v5['header']['msg_type'], 'inspect_reply')
v4c = v4['content']
v5c = v5['content']
self.assertEqual(sorted(v5c), [ 'data', 'found', 'metadata', 'status'])
text = v5c['data']['text/plain']
self.assertEqual(text, '\n'.join([v4c['definition'], v4c['docstring']]))
def test_object_info_reply_not_found(self):
msg = self.msg("object_info_reply", {
'name' : 'foo',
'found' : False,
})
v4, v5 = self.adapt(msg)
self.assertEqual(v5['header']['msg_type'], 'inspect_reply')
v4c = v4['content']
v5c = v5['content']
self.assertEqual(v5c, {
'status': 'ok',
'found': False,
'data': {},
'metadata': {},
})
def test_kernel_info_reply(self):
msg = self.msg("kernel_info_reply", {
'language': 'python',
'language_version': [2,8,0],
'ipython_version': [1,2,3],
})
v4, v5 = self.adapt(msg)
v4c = v4['content']
v5c = v5['content']
self.assertEqual(v5c, {
'protocol_version': '4.1',
'implementation': 'ipython',
'implementation_version': '1.2.3',
'language_info': {
'name': 'python',
'version': '2.8.0',
},
'banner' : '',
})
# iopub channel
def test_display_data(self):
jsondata = dict(a=5)
msg = self.msg("display_data", {
'data' : {
'text/plain' : 'some text',
'application/json' : json.dumps(jsondata)
},
'metadata' : {'text/plain' : { 'key' : 'value' }},
})
v4, v5 = self.adapt(msg)
v4c = v4['content']
v5c = v5['content']
self.assertEqual(v5c['metadata'], v4c['metadata'])
self.assertEqual(v5c['data']['text/plain'], v4c['data']['text/plain'])
self.assertEqual(v5c['data']['application/json'], jsondata)
# stdin channel
def test_input_request(self):
msg = self.msg('input_request', {'prompt': "$>"})
v4, v5 = self.adapt(msg)
self.assertEqual(v5['content']['prompt'], v4['content']['prompt'])
self.assertFalse(v5['content']['password'])
class V5toV4TestCase(AdapterTest):
from_version = 5
to_version = 4
def msg(self, msg_type, content):
return self.session.msg(msg_type, content)
def test_same_version(self):
msg = self.msg("execute_result",
content={'status' : 'ok'}
)
original, adapted = self.adapt(msg, self.from_version)
self.assertEqual(original, adapted)
def test_no_adapt(self):
msg = self.msg("input_reply", {'value' : 'some text'})
v5, v4 = self.adapt(msg)
self.assertNotIn('version', v4['header'])
v5['header'].pop('version')
self.assertEqual(v4, v5)
def test_rename_type(self):
for v5_type, v4_type in [
('execute_result', 'pyout'),
('execute_input', 'pyin'),
('error', 'pyerr'),
]:
msg = self.msg(v5_type, {'key' : 'value'})
v5, v4 = self.adapt(msg)
self.assertEqual(v4['header']['msg_type'], v4_type)
assert 'version' not in v4['header']
self.assertEqual(v4['content'], v5['content'])
def test_execute_request(self):
msg = self.msg("execute_request", {
'code' : 'a=5',
'silent' : False,
'user_expressions' : {'a' : 'apple'},
})
v5, v4 = self.adapt(msg)
self.assertEqual(v4['header']['msg_type'], v5['header']['msg_type'])
v4c = v4['content']
v5c = v5['content']
self.assertEqual(v4c['user_variables'], [])
self.assertEqual(v5c['code'], v4c['code'])
def test_complete_request(self):
msg = self.msg("complete_request", {
'code' : 'def foo():\n'
' a.is\n'
'foo()',
'cursor_pos': 19,
})
v5, v4 = self.adapt(msg)
v4c = v4['content']
v5c = v5['content']
self.assertNotIn('code', v4c)
self.assertEqual(v4c['line'], v5c['code'].splitlines(True)[1])
self.assertEqual(v4c['cursor_pos'], 8)
self.assertEqual(v4c['text'], '')
self.assertEqual(v4c['block'], None)
def test_complete_reply(self):
msg = self.msg("complete_reply", {
'cursor_start' : 10,
'cursor_end' : 14,
'matches' : ['a.isalnum',
'a.isalpha',
'a.isdigit',
'a.islower',
],
'metadata' : {},
})
v5, v4 = self.adapt(msg)
v4c = v4['content']
v5c = v5['content']
self.assertEqual(v4c['matched_text'], 'a.is')
self.assertEqual(v4c['matches'], v5c['matches'])
def test_inspect_request(self):
msg = self.msg("inspect_request", {
'code' : 'def foo():\n'
' apple\n'
'bar()',
'cursor_pos': 18,
'detail_level' : 1,
})
v5, v4 = self.adapt(msg)
self.assertEqual(v4['header']['msg_type'], 'object_info_request')
v4c = v4['content']
v5c = v5['content']
self.assertEqual(v4c['oname'], 'apple')
self.assertEqual(v5c['detail_level'], v4c['detail_level'])
def test_inspect_request_token(self):
line = 'something(range(10), kwarg=smth) ; xxx.xxx.xxx( firstarg, rand(234,23), kwarg1=2,'
msg = self.msg("inspect_request", {
'code' : line,
'cursor_pos': len(line)-1,
'detail_level' : 1,
})
v5, v4 = self.adapt(msg)
self.assertEqual(v4['header']['msg_type'], 'object_info_request')
v4c = v4['content']
v5c = v5['content']
self.assertEqual(v4c['oname'], 'xxx.xxx.xxx')
self.assertEqual(v5c['detail_level'], v4c['detail_level'])
def test_inspect_reply(self):
msg = self.msg("inspect_reply", {
'name' : 'foo',
'found' : True,
'data' : {'text/plain' : 'some text'},
'metadata' : {},
})
v5, v4 = self.adapt(msg)
self.assertEqual(v4['header']['msg_type'], 'object_info_reply')
v4c = v4['content']
v5c = v5['content']
self.assertEqual(sorted(v4c), ['found', 'oname'])
self.assertEqual(v4c['found'], False)
def test_kernel_info_reply(self):
msg = self.msg("kernel_info_reply", {
'protocol_version': '5.0',
'implementation': 'ipython',
'implementation_version': '1.2.3',
'language_info': {
'name': 'python',
'version': '2.8.0',
'mimetype': 'text/x-python',
},
'banner' : 'the banner',
})
v5, v4 = self.adapt(msg)
v4c = v4['content']
v5c = v5['content']
info = v5c['language_info']
self.assertEqual(v4c, {
'protocol_version': [5,0],
'language': 'python',
'language_version': [2,8,0],
'ipython_version': [1,2,3],
})
# iopub channel
def test_display_data(self):
jsondata = dict(a=5)
msg = self.msg("display_data", {
'data' : {
'text/plain' : 'some text',
'application/json' : jsondata,
},
'metadata' : {'text/plain' : { 'key' : 'value' }},
})
v5, v4 = self.adapt(msg)
v4c = v4['content']
v5c = v5['content']
self.assertEqual(v5c['metadata'], v4c['metadata'])
self.assertEqual(v5c['data']['text/plain'], v4c['data']['text/plain'])
self.assertEqual(v4c['data']['application/json'], json.dumps(jsondata))
# stdin channel
def test_input_request(self):
msg = self.msg('input_request', {'prompt': "$>", 'password' : True})
v5, v4 = self.adapt(msg)
self.assertEqual(v5['content']['prompt'], v4['content']['prompt'])
self.assertNotIn('password', v4['content'])

View file

@ -0,0 +1,87 @@
"""Tests for the KernelClient"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import os
pjoin = os.path.join
from unittest import TestCase
from jupyter_client.kernelspec import KernelSpecManager, NoSuchKernel, NATIVE_KERNEL_NAME
from ..manager import start_new_kernel
from .utils import test_env
import pytest
from IPython.utils.capture import capture_output
TIMEOUT = 30
class TestKernelClient(TestCase):
def setUp(self):
self.env_patch = test_env()
self.env_patch.start()
self.addCleanup(self.env_patch.stop)
try:
KernelSpecManager().get_kernel_spec(NATIVE_KERNEL_NAME)
except NoSuchKernel:
pytest.skip()
self.km, self.kc = start_new_kernel(kernel_name=NATIVE_KERNEL_NAME)
self.addCleanup(self.kc.stop_channels)
self.addCleanup(self.km.shutdown_kernel)
def test_execute_interactive(self):
kc = self.kc
with capture_output() as io:
reply = kc.execute_interactive("print('hello')", timeout=TIMEOUT)
assert 'hello' in io.stdout
assert reply['content']['status'] == 'ok'
def _check_reply(self, reply_type, reply):
self.assertIsInstance(reply, dict)
self.assertEqual(reply['header']['msg_type'], reply_type + '_reply')
self.assertEqual(reply['parent_header']['msg_type'], reply_type + '_request')
def test_history(self):
kc = self.kc
msg_id = kc.history(session=0)
self.assertIsInstance(msg_id, str)
reply = kc.history(session=0, reply=True, timeout=TIMEOUT)
self._check_reply('history', reply)
def test_inspect(self):
kc = self.kc
msg_id = kc.inspect('who cares')
self.assertIsInstance(msg_id, str)
reply = kc.inspect('code', reply=True, timeout=TIMEOUT)
self._check_reply('inspect', reply)
def test_complete(self):
kc = self.kc
msg_id = kc.complete('who cares')
self.assertIsInstance(msg_id, str)
reply = kc.complete('code', reply=True, timeout=TIMEOUT)
self._check_reply('complete', reply)
def test_kernel_info(self):
kc = self.kc
msg_id = kc.kernel_info()
self.assertIsInstance(msg_id, str)
reply = kc.kernel_info(reply=True, timeout=TIMEOUT)
self._check_reply('kernel_info', reply)
def test_comm_info(self):
kc = self.kc
msg_id = kc.comm_info()
self.assertIsInstance(msg_id, str)
reply = kc.comm_info(reply=True, timeout=TIMEOUT)
self._check_reply('comm_info', reply)
def test_shutdown(self):
kc = self.kc
msg_id = kc.shutdown()
self.assertIsInstance(msg_id, str)
reply = kc.shutdown(reply=True, timeout=TIMEOUT)
self._check_reply('shutdown', reply)

View file

@ -0,0 +1,259 @@
"""Tests for kernel connection utilities"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import json
import os
import re
import stat
import tempfile
import shutil
from traitlets.config import Config
from jupyter_core.application import JupyterApp
from jupyter_core.paths import jupyter_runtime_dir
from ipython_genutils.tempdir import TemporaryDirectory, TemporaryWorkingDirectory
from ipython_genutils.py3compat import str_to_bytes
from jupyter_client import connect, KernelClient
from jupyter_client.consoleapp import JupyterConsoleApp
from jupyter_client.session import Session
from jupyter_client.connect import secure_write
class DummyConsoleApp(JupyterApp, JupyterConsoleApp):
def initialize(self, argv=[]):
JupyterApp.initialize(self, argv=argv)
self.init_connection_file()
class DummyConfigurable(connect.ConnectionFileMixin):
def initialize(self):
pass
sample_info = dict(ip='1.2.3.4', transport='ipc',
shell_port=1, hb_port=2, iopub_port=3, stdin_port=4, control_port=5,
key=b'abc123', signature_scheme='hmac-md5', kernel_name='python'
)
sample_info_kn = dict(ip='1.2.3.4', transport='ipc',
shell_port=1, hb_port=2, iopub_port=3, stdin_port=4, control_port=5,
key=b'abc123', signature_scheme='hmac-md5', kernel_name='test'
)
def test_write_connection_file():
with TemporaryDirectory() as d:
cf = os.path.join(d, 'kernel.json')
connect.write_connection_file(cf, **sample_info)
assert os.path.exists(cf)
with open(cf, 'r') as f:
info = json.load(f)
info['key'] = str_to_bytes(info['key'])
assert info == sample_info
def test_load_connection_file_session():
"""test load_connection_file() after """
session = Session()
app = DummyConsoleApp(session=Session())
app.initialize(argv=[])
session = app.session
with TemporaryDirectory() as d:
cf = os.path.join(d, 'kernel.json')
connect.write_connection_file(cf, **sample_info)
app.connection_file = cf
app.load_connection_file()
assert session.key == sample_info['key']
assert session.signature_scheme == sample_info['signature_scheme']
def test_load_connection_file_session_with_kn():
"""test load_connection_file() after """
session = Session()
app = DummyConsoleApp(session=Session())
app.initialize(argv=[])
session = app.session
with TemporaryDirectory() as d:
cf = os.path.join(d, 'kernel.json')
connect.write_connection_file(cf, **sample_info_kn)
app.connection_file = cf
app.load_connection_file()
assert session.key == sample_info_kn['key']
assert session.signature_scheme == sample_info_kn['signature_scheme']
def test_app_load_connection_file():
"""test `ipython console --existing` loads a connection file"""
with TemporaryDirectory() as d:
cf = os.path.join(d, 'kernel.json')
connect.write_connection_file(cf, **sample_info)
app = DummyConsoleApp(connection_file=cf)
app.initialize(argv=[])
for attr, expected in sample_info.items():
if attr in ('key', 'signature_scheme'):
continue
value = getattr(app, attr)
assert value == expected, "app.%s = %s != %s" % (attr, value, expected)
def test_load_connection_info():
client = KernelClient()
info = {
'control_port': 53702,
'hb_port': 53705,
'iopub_port': 53703,
'ip': '0.0.0.0',
'key': 'secret',
'shell_port': 53700,
'signature_scheme': 'hmac-sha256',
'stdin_port': 53701,
'transport': 'tcp',
}
client.load_connection_info(info)
assert client.control_port == info['control_port']
assert client.session.key.decode('ascii') == info['key']
assert client.ip == info['ip']
def test_find_connection_file():
with TemporaryDirectory() as d:
cf = 'kernel.json'
app = DummyConsoleApp(runtime_dir=d, connection_file=cf)
app.initialize()
security_dir = app.runtime_dir
profile_cf = os.path.join(security_dir, cf)
with open(profile_cf, 'w') as f:
f.write("{}")
for query in (
'kernel.json',
'kern*',
'*ernel*',
'k*',
):
assert connect.find_connection_file(query, path=security_dir) == profile_cf
def test_find_connection_file_local():
with TemporaryWorkingDirectory() as d:
cf = 'test.json'
abs_cf = os.path.abspath(cf)
with open(cf, 'w') as f:
f.write('{}')
for query in (
'test.json',
'test',
abs_cf,
os.path.join('.', 'test.json'),
):
assert connect.find_connection_file(query, path=['.', jupyter_runtime_dir()]) == abs_cf
def test_find_connection_file_relative():
with TemporaryWorkingDirectory() as d:
cf = 'test.json'
os.mkdir('subdir')
cf = os.path.join('subdir', 'test.json')
abs_cf = os.path.abspath(cf)
with open(cf, 'w') as f:
f.write('{}')
for query in (
os.path.join('.', 'subdir', 'test.json'),
os.path.join('subdir', 'test.json'),
abs_cf,
):
assert connect.find_connection_file(query, path=['.', jupyter_runtime_dir()]) == abs_cf
def test_find_connection_file_abspath():
with TemporaryDirectory() as d:
cf = 'absolute.json'
abs_cf = os.path.abspath(cf)
with open(cf, 'w') as f:
f.write('{}')
assert connect.find_connection_file(abs_cf, path=jupyter_runtime_dir()) == abs_cf
os.remove(abs_cf)
def test_mixin_record_random_ports():
with TemporaryDirectory() as d:
dc = DummyConfigurable(data_dir=d, kernel_name='via-tcp', transport='tcp')
dc.write_connection_file()
assert dc._connection_file_written
assert os.path.exists(dc.connection_file)
assert dc._random_port_names == connect.port_names
def test_mixin_cleanup_random_ports():
with TemporaryDirectory() as d:
dc = DummyConfigurable(data_dir=d, kernel_name='via-tcp', transport='tcp')
dc.write_connection_file()
filename = dc.connection_file
dc.cleanup_random_ports()
assert not os.path.exists(filename)
for name in dc._random_port_names:
assert getattr(dc, name) == 0
def test_secure_write():
def fetch_win32_permissions(filename):
'''Extracts file permissions on windows using icacls'''
role_permissions = {}
for index, line in enumerate(os.popen("icacls %s" % filename).read().splitlines()):
if index == 0:
line = line.split(filename)[-1].strip().lower()
match = re.match(r'\s*([^:]+):\(([^\)]*)\)', line)
if match:
usergroup, permissions = match.groups()
usergroup = usergroup.lower().split('\\')[-1]
permissions = set(p.lower() for p in permissions.split(','))
role_permissions[usergroup] = permissions
elif not line.strip():
break
return role_permissions
def check_user_only_permissions(fname):
# Windows has it's own permissions ACL patterns
if os.name == 'nt':
import win32api
username = win32api.GetUserName().lower()
permissions = fetch_win32_permissions(fname)
assert username in permissions
assert permissions[username] == set(['r', 'w'])
assert 'administrators' in permissions
assert permissions['administrators'] == set(['f'])
assert 'everyone' not in permissions
assert len(permissions) == 2
else:
mode = os.stat(fname).st_mode
assert '0600' == oct(stat.S_IMODE(mode)).replace('0o', '0')
directory = tempfile.mkdtemp()
fname = os.path.join(directory, 'check_perms')
try:
with secure_write(fname) as f:
f.write('test 1')
check_user_only_permissions(fname)
with open(fname, 'r') as f:
assert f.read() == 'test 1'
# Try changing file permissions ahead of time
os.chmod(fname, 0o755)
with secure_write(fname) as f:
f.write('test 2')
check_user_only_permissions(fname)
with open(fname, 'r') as f:
assert f.read() == 'test 2'
finally:
shutil.rmtree(directory)

View file

@ -0,0 +1,85 @@
"""Test suite for our JSON utilities."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import json
import pytest
import datetime
from datetime import timedelta
from unittest import mock
from dateutil.tz import tzlocal, tzoffset
from jupyter_client import jsonutil
from jupyter_client.session import utcnow
REFERENCE_DATETIME = datetime.datetime(
2013, 7, 3, 16, 34, 52, 249482, tzlocal()
)
def test_extract_date_from_naive():
ref = REFERENCE_DATETIME
timestamp = '2013-07-03T16:34:52.249482'
with pytest.deprecated_call(match='Interpreting naive datetime as local'):
extracted = jsonutil.extract_dates(timestamp)
assert isinstance(extracted, datetime.datetime)
assert extracted.tzinfo is not None
assert extracted.tzinfo.utcoffset(ref) == tzlocal().utcoffset(ref)
assert extracted == ref
def test_extract_dates():
ref = REFERENCE_DATETIME
timestamps = [
'2013-07-03T16:34:52.249482Z',
'2013-07-03T16:34:52.249482-0800',
'2013-07-03T16:34:52.249482+0800',
'2013-07-03T16:34:52.249482-08:00',
'2013-07-03T16:34:52.249482+08:00',
]
extracted = jsonutil.extract_dates(timestamps)
for dt in extracted:
assert isinstance(dt, datetime.datetime)
assert dt.tzinfo is not None
assert extracted[0].tzinfo.utcoffset(ref) == timedelta(0)
assert extracted[1].tzinfo.utcoffset(ref) == timedelta(hours=-8)
assert extracted[2].tzinfo.utcoffset(ref) == timedelta(hours=8)
assert extracted[3].tzinfo.utcoffset(ref) == timedelta(hours=-8)
assert extracted[4].tzinfo.utcoffset(ref) == timedelta(hours=8)
def test_parse_ms_precision():
base = '2013-07-03T16:34:52'
digits = '1234567890'
parsed = jsonutil.parse_date(base+'Z')
assert isinstance(parsed, datetime.datetime)
for i in range(len(digits)):
ts = base + '.' + digits[:i]
parsed = jsonutil.parse_date(ts+'Z')
if i >= 1 and i <= 6:
assert isinstance(parsed, datetime.datetime)
else:
assert isinstance(parsed, str)
def test_date_default():
naive = datetime.datetime.now()
local = tzoffset('Local', -8 * 3600)
other = tzoffset('Other', 2 * 3600)
data = dict(naive=naive, utc=utcnow(), withtz=naive.replace(tzinfo=other))
with mock.patch.object(jsonutil, 'tzlocal', lambda : local):
with pytest.deprecated_call(match='Please add timezone info'):
jsondata = json.dumps(data, default=jsonutil.date_default)
assert "Z" in jsondata
assert jsondata.count("Z") == 1
extracted = jsonutil.extract_dates(json.loads(jsondata))
for dt in extracted.values():
assert isinstance(dt, datetime.datetime)
assert dt.tzinfo != None

View file

@ -0,0 +1,50 @@
import os
import sys
import shutil
import time
from subprocess import Popen, PIPE
from tempfile import mkdtemp
def _launch(extra_env):
env = os.environ.copy()
env.update(extra_env)
return Popen([sys.executable, '-c',
'from jupyter_client.kernelapp import main; main()'],
env=env, stderr=PIPE)
WAIT_TIME = 10
POLL_FREQ = 10
def test_kernelapp_lifecycle():
# Check that 'jupyter kernel' starts and terminates OK.
runtime_dir = mkdtemp()
startup_dir = mkdtemp()
started = os.path.join(startup_dir, 'started')
try:
p = _launch({'JUPYTER_RUNTIME_DIR': runtime_dir,
'JUPYTER_CLIENT_TEST_RECORD_STARTUP_PRIVATE': started,
})
# Wait for start
for _ in range(WAIT_TIME * POLL_FREQ):
if os.path.isfile(started):
break
time.sleep(1 / POLL_FREQ)
else:
raise AssertionError("No started file created in {} seconds"
.format(WAIT_TIME))
# Connection file should be there by now
files = os.listdir(runtime_dir)
assert len(files) == 1
cf = files[0]
assert cf.startswith('kernel')
assert cf.endswith('.json')
# Send SIGTERM to shut down
p.terminate()
_, stderr = p.communicate(timeout=WAIT_TIME)
assert cf in stderr.decode('utf-8', 'replace')
finally:
shutil.rmtree(runtime_dir)
shutil.rmtree(startup_dir)

View file

@ -0,0 +1,414 @@
"""Tests for the KernelManager"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
import json
import os
import signal
import sys
import time
import threading
import multiprocessing as mp
import pytest
from async_generator import async_generator, yield_
from traitlets.config.loader import Config
from jupyter_core import paths
from jupyter_client import KernelManager, AsyncKernelManager
from subprocess import PIPE
from ..manager import start_new_kernel, start_new_async_kernel
from .utils import test_env, skip_win32, AsyncKernelManagerSubclass, AsyncKernelManagerWithCleanup
pjoin = os.path.join
TIMEOUT = 30
@pytest.fixture(autouse=True)
def env():
env_patch = test_env()
env_patch.start()
yield
env_patch.stop()
@pytest.fixture(params=['tcp', 'ipc'])
def transport(request):
if sys.platform == 'win32' and request.param == 'ipc': #
pytest.skip("Transport 'ipc' not supported on Windows.")
return request.param
@pytest.fixture
def config(transport):
c = Config()
c.KernelManager.transport = transport
if transport == 'ipc':
c.KernelManager.ip = 'test'
return c
@pytest.fixture
def install_kernel():
kernel_dir = pjoin(paths.jupyter_data_dir(), 'kernels', 'signaltest')
os.makedirs(kernel_dir)
with open(pjoin(kernel_dir, 'kernel.json'), 'w') as f:
f.write(json.dumps({
'argv': [sys.executable,
'-m', 'jupyter_client.tests.signalkernel',
'-f', '{connection_file}'],
'display_name': "Signal Test Kernel",
'env': {'TEST_VARS': '${TEST_VARS}:test_var_2'},
}))
@pytest.fixture
def start_kernel():
km, kc = start_new_kernel(kernel_name='signaltest')
yield km, kc
kc.stop_channels()
km.shutdown_kernel()
assert km.context.closed
@pytest.fixture
def start_kernel_w_env():
kernel_cmd = [sys.executable,
'-m', 'jupyter_client.tests.signalkernel',
'-f', '{connection_file}']
extra_env = {'TEST_VARS': '${TEST_VARS}:test_var_2'}
km = KernelManager(kernel_name='signaltest')
km.kernel_cmd = kernel_cmd
km.extra_env = extra_env
km.start_kernel()
kc = km.client()
kc.start_channels()
kc.wait_for_ready(timeout=60)
yield km, kc
kc.stop_channels()
km.shutdown_kernel()
@pytest.fixture
def km(config):
km = KernelManager(config=config)
return km
@pytest.fixture
def zmq_context():
import zmq
ctx = zmq.Context()
yield ctx
ctx.term()
@pytest.fixture(params=[AsyncKernelManager, AsyncKernelManagerSubclass, AsyncKernelManagerWithCleanup])
def async_km(request, config):
km = request.param(config=config)
return km
@pytest.fixture
@async_generator # This is only necessary while Python 3.5 is support afterwhich both it and yield_() can be removed
async def start_async_kernel():
km, kc = await start_new_async_kernel(kernel_name='signaltest')
await yield_((km, kc))
kc.stop_channels()
await km.shutdown_kernel()
assert km.context.closed
class TestKernelManager:
def test_lifecycle(self, km):
km.start_kernel(stdout=PIPE, stderr=PIPE)
assert km.is_alive()
km.restart_kernel(now=True)
assert km.is_alive()
km.interrupt_kernel()
assert isinstance(km, KernelManager)
km.shutdown_kernel(now=True)
assert km.context.closed
def test_get_connect_info(self, km):
cinfo = km.get_connection_info()
keys = sorted(cinfo.keys())
expected = sorted([
'ip', 'transport',
'hb_port', 'shell_port', 'stdin_port', 'iopub_port', 'control_port',
'key', 'signature_scheme',
])
assert keys == expected
@pytest.mark.skipif(sys.platform == 'win32', reason="Windows doesn't support signals")
def test_signal_kernel_subprocesses(self, install_kernel, start_kernel):
km, kc = start_kernel
def execute(cmd):
kc.execute(cmd)
reply = kc.get_shell_msg(TIMEOUT)
content = reply['content']
assert content['status'] == 'ok'
return content
N = 5
for i in range(N):
execute("start")
time.sleep(1) # make sure subprocs stay up
reply = execute('check')
assert reply['user_expressions']['poll'] == [None] * N
# start a job on the kernel to be interrupted
kc.execute('sleep')
time.sleep(1) # ensure sleep message has been handled before we interrupt
km.interrupt_kernel()
reply = kc.get_shell_msg(TIMEOUT)
content = reply['content']
assert content['status'] == 'ok'
assert content['user_expressions']['interrupted']
# wait up to 5s for subprocesses to handle signal
for i in range(50):
reply = execute('check')
if reply['user_expressions']['poll'] != [-signal.SIGINT] * N:
time.sleep(0.1)
else:
break
# verify that subprocesses were interrupted
assert reply['user_expressions']['poll'] == [-signal.SIGINT] * N
def test_start_new_kernel(self, install_kernel, start_kernel):
km, kc = start_kernel
assert km.is_alive()
assert kc.is_alive()
assert km.context.closed is False
def _env_test_body(self, kc):
def execute(cmd):
kc.execute(cmd)
reply = kc.get_shell_msg(TIMEOUT)
content = reply['content']
assert content['status'] == 'ok'
return content
reply = execute('env')
assert reply is not None
assert reply['user_expressions']['env'] == 'test_var_1:test_var_2'
def test_templated_kspec_env(self, install_kernel, start_kernel):
km, kc = start_kernel
assert km.is_alive()
assert kc.is_alive()
assert km.context.closed is False
self._env_test_body(kc)
def test_templated_extra_env(self, install_kernel, start_kernel_w_env):
km, kc = start_kernel_w_env
assert km.is_alive()
assert kc.is_alive()
assert km.context.closed is False
self._env_test_body(kc)
def test_cleanup_context(self, km):
assert km.context is not None
km.cleanup_resources(restart=False)
assert km.context.closed
def test_no_cleanup_shared_context(self, zmq_context):
"""kernel manager does not terminate shared context"""
km = KernelManager(context=zmq_context)
assert km.context == zmq_context
assert km.context is not None
km.cleanup_resources(restart=False)
assert km.context.closed is False
assert zmq_context.closed is False
class TestParallel:
@pytest.mark.timeout(TIMEOUT)
def test_start_sequence_kernels(self, config, install_kernel):
"""Ensure that a sequence of kernel startups doesn't break anything."""
self._run_signaltest_lifecycle(config)
self._run_signaltest_lifecycle(config)
self._run_signaltest_lifecycle(config)
@pytest.mark.timeout(TIMEOUT)
def test_start_parallel_thread_kernels(self, config, install_kernel):
if config.KernelManager.transport == 'ipc': # FIXME
pytest.skip("IPC transport is currently not working for this test!")
self._run_signaltest_lifecycle(config)
thread = threading.Thread(target=self._run_signaltest_lifecycle, args=(config,))
thread2 = threading.Thread(target=self._run_signaltest_lifecycle, args=(config,))
try:
thread.start()
thread2.start()
finally:
thread.join()
thread2.join()
@pytest.mark.timeout(TIMEOUT)
def test_start_parallel_process_kernels(self, config, install_kernel):
if config.KernelManager.transport == 'ipc': # FIXME
pytest.skip("IPC transport is currently not working for this test!")
self._run_signaltest_lifecycle(config)
thread = threading.Thread(target=self._run_signaltest_lifecycle, args=(config,))
proc = mp.Process(target=self._run_signaltest_lifecycle, args=(config,))
try:
thread.start()
proc.start()
finally:
thread.join()
proc.join()
assert proc.exitcode == 0
@pytest.mark.timeout(TIMEOUT)
def test_start_sequence_process_kernels(self, config, install_kernel):
self._run_signaltest_lifecycle(config)
proc = mp.Process(target=self._run_signaltest_lifecycle, args=(config,))
try:
proc.start()
finally:
proc.join()
assert proc.exitcode == 0
def _prepare_kernel(self, km, startup_timeout=TIMEOUT, **kwargs):
km.start_kernel(**kwargs)
kc = km.client()
kc.start_channels()
try:
kc.wait_for_ready(timeout=startup_timeout)
except RuntimeError:
kc.stop_channels()
km.shutdown_kernel()
raise
return kc
def _run_signaltest_lifecycle(self, config=None):
km = KernelManager(config=config, kernel_name='signaltest')
kc = self._prepare_kernel(km, stdout=PIPE, stderr=PIPE)
def execute(cmd):
kc.execute(cmd)
reply = kc.get_shell_msg(TIMEOUT)
content = reply['content']
assert content['status'] == 'ok'
return content
execute("start")
assert km.is_alive()
execute('check')
assert km.is_alive()
km.restart_kernel(now=True)
assert km.is_alive()
execute('check')
km.shutdown_kernel()
assert km.context.closed
@pytest.mark.asyncio
class TestAsyncKernelManager:
async def test_lifecycle(self, async_km):
await async_km.start_kernel(stdout=PIPE, stderr=PIPE)
is_alive = await async_km.is_alive()
assert is_alive
await async_km.restart_kernel(now=True)
is_alive = await async_km.is_alive()
assert is_alive
await async_km.interrupt_kernel()
assert isinstance(async_km, AsyncKernelManager)
await async_km.shutdown_kernel(now=True)
is_alive = await async_km.is_alive()
assert is_alive is False
assert async_km.context.closed
async def test_get_connect_info(self, async_km):
cinfo = async_km.get_connection_info()
keys = sorted(cinfo.keys())
expected = sorted([
'ip', 'transport',
'hb_port', 'shell_port', 'stdin_port', 'iopub_port', 'control_port',
'key', 'signature_scheme',
])
assert keys == expected
async def test_subclasses(self, async_km):
await async_km.start_kernel(stdout=PIPE, stderr=PIPE)
is_alive = await async_km.is_alive()
assert is_alive
assert isinstance(async_km, AsyncKernelManager)
await async_km.shutdown_kernel(now=True)
is_alive = await async_km.is_alive()
assert is_alive is False
assert async_km.context.closed
if isinstance(async_km, AsyncKernelManagerWithCleanup):
assert async_km.which_cleanup == "cleanup"
elif isinstance(async_km, AsyncKernelManagerSubclass):
assert async_km.which_cleanup == "cleanup_resources"
else:
assert hasattr(async_km, "which_cleanup") is False
@pytest.mark.timeout(10)
@pytest.mark.skipif(sys.platform == 'win32', reason="Windows doesn't support signals")
async def test_signal_kernel_subprocesses(self, install_kernel, start_async_kernel):
km, kc = start_async_kernel
async def execute(cmd):
kc.execute(cmd)
reply = await kc.get_shell_msg(TIMEOUT)
content = reply['content']
assert content['status'] == 'ok'
return content
# Ensure that shutdown_kernel and stop_channels are called at the end of the test.
# Note: we cannot use addCleanup(<func>) for these since it doesn't prpperly handle
# coroutines - which km.shutdown_kernel now is.
N = 5
for i in range(N):
await execute("start")
await asyncio.sleep(1) # make sure subprocs stay up
reply = await execute('check')
assert reply['user_expressions']['poll'] == [None] * N
# start a job on the kernel to be interrupted
kc.execute('sleep')
await asyncio.sleep(1) # ensure sleep message has been handled before we interrupt
await km.interrupt_kernel()
reply = await kc.get_shell_msg(TIMEOUT)
content = reply['content']
assert content['status'] == 'ok'
assert content['user_expressions']['interrupted'] is True
# wait up to 5s for subprocesses to handle signal
for i in range(50):
reply = await execute('check')
if reply['user_expressions']['poll'] != [-signal.SIGINT] * N:
await asyncio.sleep(0.1)
else:
break
# verify that subprocesses were interrupted
assert reply['user_expressions']['poll'] == [-signal.SIGINT] * N
@pytest.mark.timeout(10)
async def test_start_new_async_kernel(self, install_kernel, start_async_kernel):
km, kc = start_async_kernel
is_alive = await km.is_alive()
assert is_alive
is_alive = await kc.is_alive()
assert is_alive

View file

@ -0,0 +1,193 @@
"""Tests for the KernelSpecManager"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import pytest
import copy
import io
import json
import os
import sys
import tempfile
import unittest
from io import StringIO
from os.path import join as pjoin
from subprocess import Popen, PIPE, STDOUT
from logging import StreamHandler
from ipython_genutils.tempdir import TemporaryDirectory
from jupyter_client import kernelspec
from jupyter_core import paths
from .utils import test_env
sample_kernel_json = {'argv':['cat', '{connection_file}'],
'display_name':'Test kernel',
}
class KernelSpecTests(unittest.TestCase):
def _install_sample_kernel(self, kernels_dir):
"""install a sample kernel in a kernels directory"""
sample_kernel_dir = pjoin(kernels_dir, 'sample')
os.makedirs(sample_kernel_dir)
json_file = pjoin(sample_kernel_dir, 'kernel.json')
with open(json_file, 'w') as f:
json.dump(sample_kernel_json, f)
return sample_kernel_dir
def setUp(self):
self.env_patch = test_env()
self.env_patch.start()
self.sample_kernel_dir = self._install_sample_kernel(
pjoin(paths.jupyter_data_dir(), 'kernels'))
self.ksm = kernelspec.KernelSpecManager()
td2 = TemporaryDirectory()
self.addCleanup(td2.cleanup)
self.installable_kernel = td2.name
with open(pjoin(self.installable_kernel, 'kernel.json'), 'w') as f:
json.dump(sample_kernel_json, f)
def tearDown(self):
self.env_patch.stop()
def test_find_kernel_specs(self):
kernels = self.ksm.find_kernel_specs()
self.assertEqual(kernels['sample'], self.sample_kernel_dir)
def test_get_kernel_spec(self):
ks = self.ksm.get_kernel_spec('SAMPLE') # Case insensitive
self.assertEqual(ks.resource_dir, self.sample_kernel_dir)
self.assertEqual(ks.argv, sample_kernel_json['argv'])
self.assertEqual(ks.display_name, sample_kernel_json['display_name'])
self.assertEqual(ks.env, {})
self.assertEqual(ks.metadata, {})
def test_find_all_specs(self):
kernels = self.ksm.get_all_specs()
self.assertEqual(kernels['sample']['resource_dir'], self.sample_kernel_dir)
self.assertIsNotNone(kernels['sample']['spec'])
def test_kernel_spec_priority(self):
td = TemporaryDirectory()
self.addCleanup(td.cleanup)
sample_kernel = self._install_sample_kernel(td.name)
self.ksm.kernel_dirs.append(td.name)
kernels = self.ksm.find_kernel_specs()
self.assertEqual(kernels['sample'], self.sample_kernel_dir)
self.ksm.kernel_dirs.insert(0, td.name)
kernels = self.ksm.find_kernel_specs()
self.assertEqual(kernels['sample'], sample_kernel)
def test_install_kernel_spec(self):
self.ksm.install_kernel_spec(self.installable_kernel,
kernel_name='tstinstalled',
user=True)
self.assertIn('tstinstalled', self.ksm.find_kernel_specs())
# install again works
self.ksm.install_kernel_spec(self.installable_kernel,
kernel_name='tstinstalled',
user=True)
def test_install_kernel_spec_prefix(self):
td = TemporaryDirectory()
self.addCleanup(td.cleanup)
capture = StringIO()
handler = StreamHandler(capture)
self.ksm.log.addHandler(handler)
self.ksm.install_kernel_spec(self.installable_kernel,
kernel_name='tstinstalled',
prefix=td.name)
captured = capture.getvalue()
self.ksm.log.removeHandler(handler)
self.assertIn("may not be found", captured)
self.assertNotIn('tstinstalled', self.ksm.find_kernel_specs())
# add prefix to path, so we find the spec
self.ksm.kernel_dirs.append(pjoin(td.name, 'share', 'jupyter', 'kernels'))
self.assertIn('tstinstalled', self.ksm.find_kernel_specs())
# Run it again, no warning this time because we've added it to the path
capture = StringIO()
handler = StreamHandler(capture)
self.ksm.log.addHandler(handler)
self.ksm.install_kernel_spec(self.installable_kernel,
kernel_name='tstinstalled',
prefix=td.name)
captured = capture.getvalue()
self.ksm.log.removeHandler(handler)
self.assertNotIn("may not be found", captured)
@pytest.mark.skipif(
not (os.name != 'nt' and not os.access('/usr/local/share', os.W_OK)),
reason="needs Unix system without root privileges")
def test_cant_install_kernel_spec(self):
with self.assertRaises(OSError):
self.ksm.install_kernel_spec(self.installable_kernel,
kernel_name='tstinstalled',
user=False)
def test_remove_kernel_spec(self):
path = self.ksm.remove_kernel_spec('sample')
self.assertEqual(path, self.sample_kernel_dir)
def test_remove_kernel_spec_app(self):
p = Popen(
[sys.executable, '-m', 'jupyter_client.kernelspecapp', 'remove', 'sample', '-f'],
stdout=PIPE, stderr=STDOUT,
env=os.environ,
)
out, _ = p.communicate()
self.assertEqual(p.returncode, 0, out.decode('utf8', 'replace'))
def test_validate_kernel_name(self):
for good in [
'julia-0.4',
'ipython',
'R',
'python_3',
'Haskell-1-2-3',
]:
assert kernelspec._is_valid_kernel_name(good)
for bad in [
'has space',
'ünicode',
'%percent',
'question?',
]:
assert not kernelspec._is_valid_kernel_name(bad)
def test_subclass(self):
"""Test get_all_specs in subclasses that override find_kernel_specs"""
ksm = self.ksm
resource_dir = tempfile.gettempdir()
native_name = kernelspec.NATIVE_KERNEL_NAME
native_kernel = ksm.get_kernel_spec(native_name)
class MyKSM(kernelspec.KernelSpecManager):
def get_kernel_spec(self, name):
spec = copy.copy(native_kernel)
if name == 'fake':
spec.name = name
spec.resource_dir = resource_dir
elif name == native_name:
pass
else:
raise KeyError(name)
return spec
def find_kernel_specs(self):
return {
'fake': resource_dir,
native_name: native_kernel.resource_dir,
}
# ensure that get_all_specs doesn't raise if only
# find_kernel_specs and get_kernel_spec are defined
myksm = MyKSM()
specs = myksm.get_all_specs()
assert sorted(specs) == ['fake', native_name]

View file

@ -0,0 +1,15 @@
#-----------------------------------------------------------------------------
# Copyright (c) The Jupyter Development Team
#
# Distributed under the terms of the BSD License. The full license is in
# the file COPYING, distributed as part of this software.
#-----------------------------------------------------------------------------
from .. import localinterfaces
def test_load_ips():
# Override the machinery that skips it if it was called before
localinterfaces._load_ips.called = False
# Just check this doesn't error
localinterfaces._load_ips(suppress_exceptions=False)

View file

@ -0,0 +1,38 @@
"""Tests for KernelManager"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from jupyter_client.kernelspec import KernelSpec
from unittest import mock
from jupyter_client.manager import KernelManager
import os
import tempfile
def test_connection_file_real_path():
""" Verify realpath is used when formatting connection file """
with mock.patch('os.path.realpath') as patched_realpath:
patched_realpath.return_value = 'foobar'
km = KernelManager(connection_file=os.path.join(
tempfile.gettempdir(), "kernel-test.json"),
kernel_name='test_kernel')
# KernelSpec and launch args have to be mocked as we don't have an actual kernel on disk
km._kernel_spec = KernelSpec(resource_dir='test',
**{
"argv": [
"python.exe",
"-m",
"test_kernel",
"-f",
"{connection_file}"
],
"env": {},
"display_name": "test_kernel",
"language": "python",
"metadata": {}
})
km._launch_args = {}
cmds = km.format_kernel_cmd()
assert cmds[4] is 'foobar'

View file

@ -0,0 +1,286 @@
"""Tests for the notebook kernel and session manager."""
import asyncio
import threading
import uuid
import multiprocessing as mp
from subprocess import PIPE
from unittest import TestCase
from tornado.testing import AsyncTestCase, gen_test
from traitlets.config.loader import Config
from jupyter_client import KernelManager, AsyncKernelManager
from jupyter_client.multikernelmanager import MultiKernelManager, AsyncMultiKernelManager
from .utils import skip_win32
from ..localinterfaces import localhost
TIMEOUT = 30
class TestKernelManager(TestCase):
def _get_tcp_km(self):
c = Config()
km = MultiKernelManager(config=c)
return km
def _get_ipc_km(self):
c = Config()
c.KernelManager.transport = 'ipc'
c.KernelManager.ip = 'test'
km = MultiKernelManager(config=c)
return km
def _run_lifecycle(self, km, test_kid=None):
if test_kid:
kid = km.start_kernel(stdout=PIPE, stderr=PIPE, kernel_id=test_kid)
self.assertTrue(kid == test_kid)
else:
kid = km.start_kernel(stdout=PIPE, stderr=PIPE)
self.assertTrue(km.is_alive(kid))
self.assertTrue(kid in km)
self.assertTrue(kid in km.list_kernel_ids())
self.assertEqual(len(km), 1)
km.restart_kernel(kid, now=True)
self.assertTrue(km.is_alive(kid))
self.assertTrue(kid in km.list_kernel_ids())
km.interrupt_kernel(kid)
k = km.get_kernel(kid)
self.assertTrue(isinstance(k, KernelManager))
km.shutdown_kernel(kid, now=True)
self.assertNotIn(kid, km)
def _run_cinfo(self, km, transport, ip):
kid = km.start_kernel(stdout=PIPE, stderr=PIPE)
k = km.get_kernel(kid)
cinfo = km.get_connection_info(kid)
self.assertEqual(transport, cinfo['transport'])
self.assertEqual(ip, cinfo['ip'])
self.assertTrue('stdin_port' in cinfo)
self.assertTrue('iopub_port' in cinfo)
stream = km.connect_iopub(kid)
stream.close()
self.assertTrue('shell_port' in cinfo)
stream = km.connect_shell(kid)
stream.close()
self.assertTrue('hb_port' in cinfo)
stream = km.connect_hb(kid)
stream.close()
km.shutdown_kernel(kid, now=True)
def test_tcp_lifecycle(self):
km = self._get_tcp_km()
self._run_lifecycle(km)
def test_tcp_lifecycle_with_kernel_id(self):
km = self._get_tcp_km()
self._run_lifecycle(km, test_kid=str(uuid.uuid4()))
def test_shutdown_all(self):
km = self._get_tcp_km()
kid = km.start_kernel(stdout=PIPE, stderr=PIPE)
self.assertIn(kid, km)
km.shutdown_all()
self.assertNotIn(kid, km)
# shutdown again is okay, because we have no kernels
km.shutdown_all()
def test_tcp_cinfo(self):
km = self._get_tcp_km()
self._run_cinfo(km, 'tcp', localhost())
@skip_win32
def test_ipc_lifecycle(self):
km = self._get_ipc_km()
self._run_lifecycle(km)
@skip_win32
def test_ipc_cinfo(self):
km = self._get_ipc_km()
self._run_cinfo(km, 'ipc', 'test')
def test_start_sequence_tcp_kernels(self):
"""Ensure that a sequence of kernel startups doesn't break anything."""
self._run_lifecycle(self._get_tcp_km())
self._run_lifecycle(self._get_tcp_km())
self._run_lifecycle(self._get_tcp_km())
def test_start_sequence_ipc_kernels(self):
"""Ensure that a sequence of kernel startups doesn't break anything."""
self._run_lifecycle(self._get_ipc_km())
self._run_lifecycle(self._get_ipc_km())
self._run_lifecycle(self._get_ipc_km())
def tcp_lifecycle_with_loop(self):
# Ensure each thread has an event loop
asyncio.set_event_loop(asyncio.new_event_loop())
self.test_tcp_lifecycle()
def test_start_parallel_thread_kernels(self):
self.test_tcp_lifecycle()
thread = threading.Thread(target=self.tcp_lifecycle_with_loop)
thread2 = threading.Thread(target=self.tcp_lifecycle_with_loop)
try:
thread.start()
thread2.start()
finally:
thread.join()
thread2.join()
def test_start_parallel_process_kernels(self):
self.test_tcp_lifecycle()
thread = threading.Thread(target=self.tcp_lifecycle_with_loop)
proc = mp.Process(target=self.test_tcp_lifecycle)
try:
thread.start()
proc.start()
finally:
thread.join()
proc.join()
assert proc.exitcode == 0
class TestAsyncKernelManager(AsyncTestCase):
def _get_tcp_km(self):
c = Config()
km = AsyncMultiKernelManager(config=c)
return km
def _get_ipc_km(self):
c = Config()
c.KernelManager.transport = 'ipc'
c.KernelManager.ip = 'test'
km = AsyncMultiKernelManager(config=c)
return km
async def _run_lifecycle(self, km, test_kid=None):
if test_kid:
kid = await km.start_kernel(stdout=PIPE, stderr=PIPE, kernel_id=test_kid)
self.assertTrue(kid == test_kid)
else:
kid = await km.start_kernel(stdout=PIPE, stderr=PIPE)
self.assertTrue(await km.is_alive(kid))
self.assertTrue(kid in km)
self.assertTrue(kid in km.list_kernel_ids())
self.assertEqual(len(km), 1)
await km.restart_kernel(kid, now=True)
self.assertTrue(await km.is_alive(kid))
self.assertTrue(kid in km.list_kernel_ids())
await km.interrupt_kernel(kid)
k = km.get_kernel(kid)
self.assertTrue(isinstance(k, AsyncKernelManager))
await km.shutdown_kernel(kid, now=True)
self.assertNotIn(kid, km)
async def _run_cinfo(self, km, transport, ip):
kid = await km.start_kernel(stdout=PIPE, stderr=PIPE)
k = km.get_kernel(kid)
cinfo = km.get_connection_info(kid)
self.assertEqual(transport, cinfo['transport'])
self.assertEqual(ip, cinfo['ip'])
self.assertTrue('stdin_port' in cinfo)
self.assertTrue('iopub_port' in cinfo)
stream = km.connect_iopub(kid)
stream.close()
self.assertTrue('shell_port' in cinfo)
stream = km.connect_shell(kid)
stream.close()
self.assertTrue('hb_port' in cinfo)
stream = km.connect_hb(kid)
stream.close()
await km.shutdown_kernel(kid, now=True)
self.assertNotIn(kid, km)
@gen_test
async def test_tcp_lifecycle(self):
await self.raw_tcp_lifecycle()
@gen_test
async def test_tcp_lifecycle_with_kernel_id(self):
await self.raw_tcp_lifecycle(test_kid=str(uuid.uuid4()))
@gen_test
async def test_shutdown_all(self):
km = self._get_tcp_km()
kid = await km.start_kernel(stdout=PIPE, stderr=PIPE)
self.assertIn(kid, km)
await km.shutdown_all()
self.assertNotIn(kid, km)
# shutdown again is okay, because we have no kernels
await km.shutdown_all()
@gen_test
async def test_tcp_cinfo(self):
km = self._get_tcp_km()
await self._run_cinfo(km, 'tcp', localhost())
@skip_win32
@gen_test
async def test_ipc_lifecycle(self):
km = self._get_ipc_km()
await self._run_lifecycle(km)
@skip_win32
@gen_test
async def test_ipc_cinfo(self):
km = self._get_ipc_km()
await self._run_cinfo(km, 'ipc', 'test')
@gen_test
async def test_start_sequence_tcp_kernels(self):
"""Ensure that a sequence of kernel startups doesn't break anything."""
await self._run_lifecycle(self._get_tcp_km())
await self._run_lifecycle(self._get_tcp_km())
await self._run_lifecycle(self._get_tcp_km())
@gen_test
async def test_start_sequence_ipc_kernels(self):
"""Ensure that a sequence of kernel startups doesn't break anything."""
await self._run_lifecycle(self._get_ipc_km())
await self._run_lifecycle(self._get_ipc_km())
await self._run_lifecycle(self._get_ipc_km())
def tcp_lifecycle_with_loop(self):
# Ensure each thread has an event loop
asyncio.set_event_loop(asyncio.new_event_loop())
asyncio.get_event_loop().run_until_complete(self.raw_tcp_lifecycle())
async def raw_tcp_lifecycle(self, test_kid=None):
# Since @gen_test creates an event loop, we need a raw form of
# test_tcp_lifecycle that assumes the loop already exists.
km = self._get_tcp_km()
await self._run_lifecycle(km, test_kid=test_kid)
@gen_test
async def test_start_parallel_thread_kernels(self):
await self.raw_tcp_lifecycle()
thread = threading.Thread(target=self.tcp_lifecycle_with_loop)
thread2 = threading.Thread(target=self.tcp_lifecycle_with_loop)
try:
thread.start()
thread2.start()
finally:
thread.join()
thread2.join()
@gen_test
async def test_start_parallel_process_kernels(self):
await self.raw_tcp_lifecycle()
thread = threading.Thread(target=self.tcp_lifecycle_with_loop)
proc = mp.Process(target=self.raw_tcp_lifecycle)
try:
thread.start()
proc.start()
finally:
proc.join()
thread.join()
assert proc.exitcode == 0

View file

@ -0,0 +1,27 @@
"""Test the jupyter_client public API
"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from jupyter_client import launcher, connect
import jupyter_client
def test_kms():
for base in ("", "Multi"):
KM = base + "KernelManager"
assert KM in dir(jupyter_client)
def test_kcs():
for base in ("", "Blocking"):
KM = base + "KernelClient"
assert KM in dir(jupyter_client)
def test_launcher():
for name in launcher.__all__:
assert name in dir(jupyter_client)
def test_connect():
for name in connect.__all__:
assert name in dir(jupyter_client)

View file

@ -0,0 +1,346 @@
"""test building messages with Session"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import hmac
import os
import uuid
from datetime import datetime
from unittest import mock
import pytest
import zmq
from zmq.tests import BaseZMQTestCase
from zmq.eventloop.zmqstream import ZMQStream
from jupyter_client import session as ss
from jupyter_client import jsonutil
def _bad_packer(obj):
raise TypeError("I don't work")
def _bad_unpacker(bytes):
raise TypeError("I don't work either")
class SessionTestCase(BaseZMQTestCase):
def setUp(self):
BaseZMQTestCase.setUp(self)
self.session = ss.Session()
@pytest.fixture
def no_copy_threshold():
"""Disable zero-copy optimizations in pyzmq >= 17"""
with mock.patch.object(zmq, 'COPY_THRESHOLD', 1, create=True):
yield
@pytest.mark.usefixtures('no_copy_threshold')
class TestSession(SessionTestCase):
def test_msg(self):
"""message format"""
msg = self.session.msg('execute')
thekeys = set('header parent_header metadata content msg_type msg_id'.split())
s = set(msg.keys())
self.assertEqual(s, thekeys)
self.assertTrue(isinstance(msg['content'],dict))
self.assertTrue(isinstance(msg['metadata'],dict))
self.assertTrue(isinstance(msg['header'],dict))
self.assertTrue(isinstance(msg['parent_header'],dict))
self.assertTrue(isinstance(msg['msg_id'], str))
self.assertTrue(isinstance(msg['msg_type'], str))
self.assertEqual(msg['header']['msg_type'], 'execute')
self.assertEqual(msg['msg_type'], 'execute')
def test_serialize(self):
msg = self.session.msg('execute', content=dict(a=10, b=1.1))
msg_list = self.session.serialize(msg, ident=b'foo')
ident, msg_list = self.session.feed_identities(msg_list)
new_msg = self.session.deserialize(msg_list)
self.assertEqual(ident[0], b'foo')
self.assertEqual(new_msg['msg_id'],msg['msg_id'])
self.assertEqual(new_msg['msg_type'],msg['msg_type'])
self.assertEqual(new_msg['header'],msg['header'])
self.assertEqual(new_msg['content'],msg['content'])
self.assertEqual(new_msg['parent_header'],msg['parent_header'])
self.assertEqual(new_msg['metadata'],msg['metadata'])
# ensure floats don't come out as Decimal:
self.assertEqual(type(new_msg['content']['b']),type(new_msg['content']['b']))
def test_default_secure(self):
self.assertIsInstance(self.session.key, bytes)
self.assertIsInstance(self.session.auth, hmac.HMAC)
def test_send(self):
ctx = zmq.Context()
A = ctx.socket(zmq.PAIR)
B = ctx.socket(zmq.PAIR)
A.bind("inproc://test")
B.connect("inproc://test")
msg = self.session.msg('execute', content=dict(a=10))
self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])
ident, msg_list = self.session.feed_identities(B.recv_multipart())
new_msg = self.session.deserialize(msg_list)
self.assertEqual(ident[0], b'foo')
self.assertEqual(new_msg['msg_id'],msg['msg_id'])
self.assertEqual(new_msg['msg_type'],msg['msg_type'])
self.assertEqual(new_msg['header'],msg['header'])
self.assertEqual(new_msg['content'],msg['content'])
self.assertEqual(new_msg['parent_header'],msg['parent_header'])
self.assertEqual(new_msg['metadata'],msg['metadata'])
self.assertEqual(new_msg['buffers'],[b'bar'])
content = msg['content']
header = msg['header']
header['msg_id'] = self.session.msg_id
parent = msg['parent_header']
metadata = msg['metadata']
msg_type = header['msg_type']
self.session.send(A, None, content=content, parent=parent,
header=header, metadata=metadata, ident=b'foo', buffers=[b'bar'])
ident, msg_list = self.session.feed_identities(B.recv_multipart())
new_msg = self.session.deserialize(msg_list)
self.assertEqual(ident[0], b'foo')
self.assertEqual(new_msg['msg_id'],header['msg_id'])
self.assertEqual(new_msg['msg_type'],msg['msg_type'])
self.assertEqual(new_msg['header'],msg['header'])
self.assertEqual(new_msg['content'],msg['content'])
self.assertEqual(new_msg['metadata'],msg['metadata'])
self.assertEqual(new_msg['parent_header'],msg['parent_header'])
self.assertEqual(new_msg['buffers'],[b'bar'])
header['msg_id'] = self.session.msg_id
self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])
ident, new_msg = self.session.recv(B)
self.assertEqual(ident[0], b'foo')
self.assertEqual(new_msg['msg_id'],header['msg_id'])
self.assertEqual(new_msg['msg_type'],msg['msg_type'])
self.assertEqual(new_msg['header'],msg['header'])
self.assertEqual(new_msg['content'],msg['content'])
self.assertEqual(new_msg['metadata'],msg['metadata'])
self.assertEqual(new_msg['parent_header'],msg['parent_header'])
self.assertEqual(new_msg['buffers'],[b'bar'])
# buffers must support the buffer protocol
with self.assertRaises(TypeError):
self.session.send(A, msg, ident=b'foo', buffers=[1])
# buffers must be contiguous
buf = memoryview(os.urandom(16))
with self.assertRaises(ValueError):
self.session.send(A, msg, ident=b'foo', buffers=[buf[::2]])
A.close()
B.close()
ctx.term()
def test_args(self):
"""initialization arguments for Session"""
s = self.session
self.assertTrue(s.pack is ss.default_packer)
self.assertTrue(s.unpack is ss.default_unpacker)
self.assertEqual(s.username, os.environ.get('USER', 'username'))
s = ss.Session()
self.assertEqual(s.username, os.environ.get('USER', 'username'))
self.assertRaises(TypeError, ss.Session, pack='hi')
self.assertRaises(TypeError, ss.Session, unpack='hi')
u = str(uuid.uuid4())
s = ss.Session(username='carrot', session=u)
self.assertEqual(s.session, u)
self.assertEqual(s.username, 'carrot')
def test_tracking(self):
"""test tracking messages"""
a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
s = self.session
s.copy_threshold = 1
stream = ZMQStream(a)
msg = s.send(a, 'hello', track=False)
self.assertTrue(msg['tracker'] is ss.DONE)
msg = s.send(a, 'hello', track=True)
self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
M = zmq.Message(b'hi there', track=True)
msg = s.send(a, 'hello', buffers=[M], track=True)
t = msg['tracker']
self.assertTrue(isinstance(t, zmq.MessageTracker))
self.assertRaises(zmq.NotDone, t.wait, .1)
del M
t.wait(1) # this will raise
def test_unique_msg_ids(self):
"""test that messages receive unique ids"""
ids = set()
for i in range(2**12):
h = self.session.msg_header('test')
msg_id = h['msg_id']
self.assertTrue(msg_id not in ids)
ids.add(msg_id)
def test_feed_identities(self):
"""scrub the front for zmq IDENTITIES"""
theids = "engine client other".split()
content = dict(code='whoda',stuff=object())
themsg = self.session.msg('execute',content=content)
pmsg = theids
def test_session_id(self):
session = ss.Session()
# get bs before us
bs = session.bsession
us = session.session
self.assertEqual(us.encode('ascii'), bs)
session = ss.Session()
# get us before bs
us = session.session
bs = session.bsession
self.assertEqual(us.encode('ascii'), bs)
# change propagates:
session.session = 'something else'
bs = session.bsession
us = session.session
self.assertEqual(us.encode('ascii'), bs)
session = ss.Session(session='stuff')
# get us before bs
self.assertEqual(session.bsession, session.session.encode('ascii'))
self.assertEqual(b'stuff', session.bsession)
def test_zero_digest_history(self):
session = ss.Session(digest_history_size=0)
for i in range(11):
session._add_digest(uuid.uuid4().bytes)
self.assertEqual(len(session.digest_history), 0)
def test_cull_digest_history(self):
session = ss.Session(digest_history_size=100)
for i in range(100):
session._add_digest(uuid.uuid4().bytes)
self.assertTrue(len(session.digest_history) == 100)
session._add_digest(uuid.uuid4().bytes)
self.assertTrue(len(session.digest_history) == 91)
for i in range(9):
session._add_digest(uuid.uuid4().bytes)
self.assertTrue(len(session.digest_history) == 100)
session._add_digest(uuid.uuid4().bytes)
self.assertTrue(len(session.digest_history) == 91)
def test_bad_pack(self):
try:
session = ss.Session(pack=_bad_packer)
except ValueError as e:
self.assertIn("could not serialize", str(e))
self.assertIn("don't work", str(e))
else:
self.fail("Should have raised ValueError")
def test_bad_unpack(self):
try:
session = ss.Session(unpack=_bad_unpacker)
except ValueError as e:
self.assertIn("could not handle output", str(e))
self.assertIn("don't work either", str(e))
else:
self.fail("Should have raised ValueError")
def test_bad_packer(self):
try:
session = ss.Session(packer=__name__ + '._bad_packer')
except ValueError as e:
self.assertIn("could not serialize", str(e))
self.assertIn("don't work", str(e))
else:
self.fail("Should have raised ValueError")
def test_bad_unpacker(self):
try:
session = ss.Session(unpacker=__name__ + '._bad_unpacker')
except ValueError as e:
self.assertIn("could not handle output", str(e))
self.assertIn("don't work either", str(e))
else:
self.fail("Should have raised ValueError")
def test_bad_roundtrip(self):
with self.assertRaises(ValueError):
session = ss.Session(unpack=lambda b: 5)
def _datetime_test(self, session):
content = dict(t=ss.utcnow())
metadata = dict(t=ss.utcnow())
p = session.msg('msg')
msg = session.msg('msg', content=content, metadata=metadata, parent=p['header'])
smsg = session.serialize(msg)
msg2 = session.deserialize(session.feed_identities(smsg)[1])
assert isinstance(msg2['header']['date'], datetime)
self.assertEqual(msg['header'], msg2['header'])
self.assertEqual(msg['parent_header'], msg2['parent_header'])
self.assertEqual(msg['parent_header'], msg2['parent_header'])
assert isinstance(msg['content']['t'], datetime)
assert isinstance(msg['metadata']['t'], datetime)
assert isinstance(msg2['content']['t'], str)
assert isinstance(msg2['metadata']['t'], str)
self.assertEqual(msg['content'], jsonutil.extract_dates(msg2['content']))
self.assertEqual(msg['content'], jsonutil.extract_dates(msg2['content']))
def test_datetimes(self):
self._datetime_test(self.session)
def test_datetimes_pickle(self):
session = ss.Session(packer='pickle')
self._datetime_test(session)
def test_datetimes_msgpack(self):
msgpack = pytest.importorskip('msgpack')
session = ss.Session(
pack=msgpack.packb,
unpack=lambda buf: msgpack.unpackb(buf, encoding='utf8'),
)
self._datetime_test(session)
def test_send_raw(self):
ctx = zmq.Context()
A = ctx.socket(zmq.PAIR)
B = ctx.socket(zmq.PAIR)
A.bind("inproc://test")
B.connect("inproc://test")
msg = self.session.msg('execute', content=dict(a=10))
msg_list = [self.session.pack(msg[part]) for part in
['header', 'parent_header', 'metadata', 'content']]
self.session.send_raw(A, msg_list, ident=b'foo')
ident, new_msg_list = self.session.feed_identities(B.recv_multipart())
new_msg = self.session.deserialize(new_msg_list)
self.assertEqual(ident[0], b'foo')
self.assertEqual(new_msg['msg_type'],msg['msg_type'])
self.assertEqual(new_msg['header'],msg['header'])
self.assertEqual(new_msg['parent_header'],msg['parent_header'])
self.assertEqual(new_msg['content'],msg['content'])
self.assertEqual(new_msg['metadata'],msg['metadata'])
A.close()
B.close()
ctx.term()
def test_clone(self):
s = self.session
s._add_digest('initial')
s2 = s.clone()
assert s2.session == s.session
assert s2.digest_history == s.digest_history
assert s2.digest_history is not s.digest_history
digest = 'abcdef'
s._add_digest(digest)
assert digest in s.digest_history
assert digest not in s2.digest_history

View file

@ -0,0 +1,8 @@
from jupyter_client.ssh.tunnel import select_random_ports
def test_random_ports():
for i in range(4096):
ports = select_random_ports(10)
assert len(ports) == 10
for p in ports:
assert ports.count(p) == 1

View file

@ -0,0 +1,89 @@
"""Testing utils for jupyter_client tests
"""
import os
pjoin = os.path.join
import sys
from unittest.mock import patch
import pytest
from jupyter_client import AsyncKernelManager
from ipython_genutils.tempdir import TemporaryDirectory
skip_win32 = pytest.mark.skipif(sys.platform.startswith('win'), reason="Windows")
class test_env(object):
"""Set Jupyter path variables to a temporary directory
Useful as a context manager or with explicit start/stop
"""
def start(self):
self.test_dir = td = TemporaryDirectory()
self.env_patch = patch.dict(os.environ, {
'JUPYTER_CONFIG_DIR': pjoin(td.name, 'jupyter'),
'JUPYTER_DATA_DIR': pjoin(td.name, 'jupyter_data'),
'JUPYTER_RUNTIME_DIR': pjoin(td.name, 'jupyter_runtime'),
'IPYTHONDIR': pjoin(td.name, 'ipython'),
'TEST_VARS': 'test_var_1',
})
self.env_patch.start()
def stop(self):
self.env_patch.stop()
self.test_dir.cleanup()
def __enter__(self):
self.start()
return self.test_dir.name
def __exit__(self, *exc_info):
self.stop()
def execute(code='', kc=None, **kwargs):
"""wrapper for doing common steps for validating an execution request"""
from .test_message_spec import validate_message
if kc is None:
kc = KC
msg_id = kc.execute(code=code, **kwargs)
reply = kc.get_shell_msg(timeout=TIMEOUT)
validate_message(reply, 'execute_reply', msg_id)
busy = kc.get_iopub_msg(timeout=TIMEOUT)
validate_message(busy, 'status', msg_id)
assert busy['content']['execution_state'] == 'busy'
if not kwargs.get('silent'):
execute_input = kc.get_iopub_msg(timeout=TIMEOUT)
validate_message(execute_input, 'execute_input', msg_id)
assert execute_input['content']['code'] == code
return msg_id, reply['content']
class AsyncKernelManagerSubclass(AsyncKernelManager):
"""Used to test deprecation "routes" that are determined by superclass' detection of methods.
This class represents a current subclass that overrides both cleanup() and cleanup_resources()
in order to be compatible with older jupyter_clients. We should find that cleanup_resources()
is called on these instances vix TestAsyncKernelManagerSubclass.
"""
def cleanup(self, connection_file=True):
super().cleanup(connection_file=connection_file)
self.which_cleanup = 'cleanup'
def cleanup_resources(self, restart=False):
super().cleanup_resources(restart=restart)
self.which_cleanup = 'cleanup_resources'
class AsyncKernelManagerWithCleanup(AsyncKernelManager):
"""Used to test deprecation "routes" that are determined by superclass' detection of methods.
This class represents the older subclass that overrides cleanup(). We should find that
cleanup() is called on these instances via TestAsyncKernelManagerWithCleanup.
"""
def cleanup(self, connection_file=True):
super().cleanup(connection_file=connection_file)
self.which_cleanup = 'cleanup'

Some files were not shown because too many files have changed in this diff Show more