253 lines
9.9 KiB
Python
253 lines
9.9 KiB
Python
# Copyright (c) Jupyter Development Team.
|
|
# Distributed under the terms of the Modified BSD License.
|
|
|
|
import os
|
|
import logging
|
|
import mimetypes
|
|
|
|
from ..base.handlers import APIHandler, IPythonHandler
|
|
from ..utils import url_path_join
|
|
|
|
from tornado import gen, web
|
|
from tornado.concurrent import Future
|
|
from tornado.ioloop import IOLoop, PeriodicCallback
|
|
from tornado.websocket import WebSocketHandler, websocket_connect
|
|
from tornado.httpclient import HTTPRequest
|
|
from tornado.escape import url_escape, json_decode, utf8
|
|
|
|
from ipython_genutils.py3compat import cast_unicode
|
|
from jupyter_client.session import Session
|
|
from traitlets.config.configurable import LoggingConfigurable
|
|
|
|
from .managers import GatewayClient
|
|
|
|
# Keepalive ping interval (default: 30 seconds)
|
|
GATEWAY_WS_PING_INTERVAL_SECS = int(os.getenv('GATEWAY_WS_PING_INTERVAL_SECS', 30))
|
|
|
|
|
|
class WebSocketChannelsHandler(WebSocketHandler, IPythonHandler):
|
|
|
|
session = None
|
|
gateway = None
|
|
kernel_id = None
|
|
ping_callback = None
|
|
|
|
def check_origin(self, origin=None):
|
|
return IPythonHandler.check_origin(self, origin)
|
|
|
|
def set_default_headers(self):
|
|
"""Undo the set_default_headers in IPythonHandler which doesn't make sense for websockets"""
|
|
pass
|
|
|
|
def get_compression_options(self):
|
|
# use deflate compress websocket
|
|
return {}
|
|
|
|
def authenticate(self):
|
|
"""Run before finishing the GET request
|
|
|
|
Extend this method to add logic that should fire before
|
|
the websocket finishes completing.
|
|
"""
|
|
# authenticate the request before opening the websocket
|
|
if self.get_current_user() is None:
|
|
self.log.warning("Couldn't authenticate WebSocket connection")
|
|
raise web.HTTPError(403)
|
|
|
|
if self.get_argument('session_id', False):
|
|
self.session.session = cast_unicode(self.get_argument('session_id'))
|
|
else:
|
|
self.log.warning("No session ID specified")
|
|
|
|
def initialize(self):
|
|
self.log.debug("Initializing websocket connection %s", self.request.path)
|
|
self.session = Session(config=self.config)
|
|
self.gateway = GatewayWebSocketClient(gateway_url=GatewayClient.instance().url)
|
|
|
|
@gen.coroutine
|
|
def get(self, kernel_id, *args, **kwargs):
|
|
self.authenticate()
|
|
self.kernel_id = cast_unicode(kernel_id, 'ascii')
|
|
yield super(WebSocketChannelsHandler, self).get(kernel_id=kernel_id, *args, **kwargs)
|
|
|
|
def send_ping(self):
|
|
if self.ws_connection is None and self.ping_callback is not None:
|
|
self.ping_callback.stop()
|
|
return
|
|
|
|
self.ping(b'')
|
|
|
|
def open(self, kernel_id, *args, **kwargs):
|
|
"""Handle web socket connection open to notebook server and delegate to gateway web socket handler """
|
|
self.ping_callback = PeriodicCallback(self.send_ping, GATEWAY_WS_PING_INTERVAL_SECS * 1000)
|
|
self.ping_callback.start()
|
|
|
|
self.gateway.on_open(
|
|
kernel_id=kernel_id,
|
|
message_callback=self.write_message,
|
|
compression_options=self.get_compression_options()
|
|
)
|
|
|
|
def on_message(self, message):
|
|
"""Forward message to gateway web socket handler."""
|
|
self.gateway.on_message(message)
|
|
|
|
def write_message(self, message, binary=False):
|
|
"""Send message back to notebook client. This is called via callback from self.gateway._read_messages."""
|
|
if self.ws_connection: # prevent WebSocketClosedError
|
|
if isinstance(message, bytes):
|
|
binary = True
|
|
super(WebSocketChannelsHandler, self).write_message(message, binary=binary)
|
|
elif self.log.isEnabledFor(logging.DEBUG):
|
|
msg_summary = WebSocketChannelsHandler._get_message_summary(json_decode(utf8(message)))
|
|
self.log.debug("Notebook client closed websocket connection - message dropped: {}".format(msg_summary))
|
|
|
|
def on_close(self):
|
|
self.log.debug("Closing websocket connection %s", self.request.path)
|
|
self.gateway.on_close()
|
|
super(WebSocketChannelsHandler, self).on_close()
|
|
|
|
@staticmethod
|
|
def _get_message_summary(message):
|
|
summary = []
|
|
message_type = message['msg_type']
|
|
summary.append('type: {}'.format(message_type))
|
|
|
|
if message_type == 'status':
|
|
summary.append(', state: {}'.format(message['content']['execution_state']))
|
|
elif message_type == 'error':
|
|
summary.append(', {}:{}:{}'.format(message['content']['ename'],
|
|
message['content']['evalue'],
|
|
message['content']['traceback']))
|
|
else:
|
|
summary.append(', ...') # don't display potentially sensitive data
|
|
|
|
return ''.join(summary)
|
|
|
|
|
|
class GatewayWebSocketClient(LoggingConfigurable):
|
|
"""Proxy web socket connection to a kernel/enterprise gateway."""
|
|
|
|
def __init__(self, **kwargs):
|
|
super(GatewayWebSocketClient, self).__init__(**kwargs)
|
|
self.kernel_id = None
|
|
self.ws = None
|
|
self.ws_future = Future()
|
|
self.disconnected = False
|
|
|
|
@gen.coroutine
|
|
def _connect(self, kernel_id):
|
|
# websocket is initialized before connection
|
|
self.ws = None
|
|
self.kernel_id = kernel_id
|
|
ws_url = url_path_join(
|
|
GatewayClient.instance().ws_url,
|
|
GatewayClient.instance().kernels_endpoint, url_escape(kernel_id), 'channels'
|
|
)
|
|
self.log.info('Connecting to {}'.format(ws_url))
|
|
kwargs = {}
|
|
kwargs = GatewayClient.instance().load_connection_args(**kwargs)
|
|
|
|
request = HTTPRequest(ws_url, **kwargs)
|
|
self.ws_future = websocket_connect(request)
|
|
self.ws_future.add_done_callback(self._connection_done)
|
|
|
|
def _connection_done(self, fut):
|
|
if not self.disconnected and fut.exception() is None: # prevent concurrent.futures._base.CancelledError
|
|
self.ws = fut.result()
|
|
self.log.debug("Connection is ready: ws: {}".format(self.ws))
|
|
else:
|
|
self.log.warning("Websocket connection has been closed via client disconnect or due to error. "
|
|
"Kernel with ID '{}' may not be terminated on GatewayClient: {}".
|
|
format(self.kernel_id, GatewayClient.instance().url))
|
|
|
|
def _disconnect(self):
|
|
self.disconnected = True
|
|
if self.ws is not None:
|
|
# Close connection
|
|
self.ws.close()
|
|
elif not self.ws_future.done():
|
|
# Cancel pending connection. Since future.cancel() is a noop on tornado, we'll track cancellation locally
|
|
self.ws_future.cancel()
|
|
self.log.debug("_disconnect: future cancelled, disconnected: {}".format(self.disconnected))
|
|
|
|
@gen.coroutine
|
|
def _read_messages(self, callback):
|
|
"""Read messages from gateway server."""
|
|
while self.ws is not None:
|
|
message = None
|
|
if not self.disconnected:
|
|
try:
|
|
message = yield self.ws.read_message()
|
|
except Exception as e:
|
|
self.log.error("Exception reading message from websocket: {}".format(e)) # , exc_info=True)
|
|
if message is None:
|
|
if not self.disconnected:
|
|
self.log.warning("Lost connection to Gateway: {}".format(self.kernel_id))
|
|
break
|
|
callback(message) # pass back to notebook client (see self.on_open and WebSocketChannelsHandler.open)
|
|
else: # ws cancelled - stop reading
|
|
break
|
|
|
|
if not self.disconnected: # if websocket is not disconnected by client, attept to reconnect to Gateway
|
|
self.log.info("Attempting to re-establish the connection to Gateway: {}".format(self.kernel_id))
|
|
self._connect(self.kernel_id)
|
|
loop = IOLoop.current()
|
|
loop.add_future(self.ws_future, lambda future: self._read_messages(callback))
|
|
|
|
def on_open(self, kernel_id, message_callback, **kwargs):
|
|
"""Web socket connection open against gateway server."""
|
|
self._connect(kernel_id)
|
|
loop = IOLoop.current()
|
|
loop.add_future(
|
|
self.ws_future,
|
|
lambda future: self._read_messages(message_callback)
|
|
)
|
|
|
|
def on_message(self, message):
|
|
"""Send message to gateway server."""
|
|
if self.ws is None:
|
|
loop = IOLoop.current()
|
|
loop.add_future(
|
|
self.ws_future,
|
|
lambda future: self._write_message(message)
|
|
)
|
|
else:
|
|
self._write_message(message)
|
|
|
|
def _write_message(self, message):
|
|
"""Send message to gateway server."""
|
|
try:
|
|
if not self.disconnected and self.ws is not None:
|
|
self.ws.write_message(message)
|
|
except Exception as e:
|
|
self.log.error("Exception writing message to websocket: {}".format(e)) # , exc_info=True)
|
|
|
|
def on_close(self):
|
|
"""Web socket closed event."""
|
|
self._disconnect()
|
|
|
|
|
|
class GatewayResourceHandler(APIHandler):
|
|
"""Retrieves resources for specific kernelspec definitions from kernel/enterprise gateway."""
|
|
|
|
@web.authenticated
|
|
@gen.coroutine
|
|
def get(self, kernel_name, path, include_body=True):
|
|
ksm = self.kernel_spec_manager
|
|
kernel_spec_res = yield ksm.get_kernel_spec_resource(kernel_name, path)
|
|
if kernel_spec_res is None:
|
|
self.log.warning("Kernelspec resource '{}' for '{}' not found. Gateway may not support"
|
|
" resource serving.".format(path, kernel_name))
|
|
else:
|
|
self.set_header("Content-Type", mimetypes.guess_type(path)[0])
|
|
self.finish(kernel_spec_res)
|
|
|
|
|
|
from ..services.kernels.handlers import _kernel_id_regex
|
|
from ..services.kernelspecs.handlers import kernel_name_regex
|
|
|
|
default_handlers = [
|
|
(r"/api/kernels/%s/channels" % _kernel_id_regex, WebSocketChannelsHandler),
|
|
(r"/kernelspecs/%s/(?P<path>.*)" % kernel_name_regex, GatewayResourceHandler),
|
|
]
|