342 lines
10 KiB
Python
342 lines
10 KiB
Python
|
# -*- coding: utf-8 -*-
|
||
|
"""
|
||
|
Provides a reload() function that acts recursively.
|
||
|
|
||
|
Python's normal :func:`python:reload` function only reloads the module that it's
|
||
|
passed. The :func:`reload` function in this module also reloads everything
|
||
|
imported from that module, which is useful when you're changing files deep
|
||
|
inside a package.
|
||
|
|
||
|
To use this as your default reload function, type this::

    import builtins
    from IPython.lib import deepreload
    builtins.reload = deepreload.reload
|
||
|
|
||
|
A reference to the original :func:`python:reload` is stored in this module as
|
||
|
:data:`original_reload`, so you can restore it later.
|
||
|
|
||
|
This code is almost entirely based on knee.py, which is a Python
|
||
|
re-implementation of hierarchical module import.
|
||
|
"""
|
||
|
#*****************************************************************************
|
||
|
# Copyright (C) 2001 Nathaniel Gray <n8gray@caltech.edu>
|
||
|
#
|
||
|
# Distributed under the terms of the BSD License. The full license is in
|
||
|
# the file COPYING, distributed as part of this software.
|
||
|
#*****************************************************************************
|
||
|
|
||
|
import builtins as builtin_mod
import imp
import importlib
import sys
import types
from contextlib import contextmanager
from types import ModuleType
from warnings import warn
|
||
|
|
||
|
# The interpreter's own __import__, captured when this module is loaded so
# code that must bypass deep_import_hook (the imp.find_module calls) can
# temporarily reinstate it with replace_import_hook(original_import).
original_import = getattr(builtin_mod, '__import__')
|
||
|
|
||
|
@contextmanager
def replace_import_hook(new_import):
    """Temporarily install *new_import* as ``builtins.__import__``.

    Whatever hook is active on entry is remembered and put back on exit,
    including when the body of the ``with`` statement raises.
    """
    previous_hook = getattr(builtin_mod, '__import__')
    setattr(builtin_mod, '__import__', new_import)
    try:
        yield
    finally:
        # Restore unconditionally so a failing import cannot leave the
        # interpreter running with the wrong hook.
        setattr(builtin_mod, '__import__', previous_hook)
|
||
|
|
||
|
def get_parent(globals, level):
    """
    parent, name = get_parent(globals, level)

    Return the package that an import is being performed in. If globals comes
    from the module foo.bar.bat (not itself a package), this returns the
    sys.modules entry for foo.bar. If globals is from a package's __init__.py,
    the package's entry in sys.modules is returned.

    *level* follows ``__import__``: 0 means absolute (no parent), a positive
    value is the number of leading dots of an explicit relative import, and a
    negative value means "implicit relative" (Python 2 semantics).

    If globals doesn't come from a package or a module in a package,
    ``(None, '')`` is returned. When ``__package__`` is missing from *globals*
    it is computed and stored back into *globals*.

    If the resolved package is not in sys.modules, a negative *level* emits a
    warning and returns ``(None, name)``; a positive *level* raises
    SystemError.

    Raises ValueError for a non-string ``__package__``, a relative import
    from a non-package, or a relative import beyond the top-level package.
    """
    if not level or not isinstance(globals, dict):
        return None, ''

    pkgname = globals.get('__package__', None)

    if pkgname is not None:
        # __package__ is set, so use it
        if not hasattr(pkgname, 'rindex'):
            raise ValueError('__package__ set to non-string')
        if len(pkgname) == 0:
            if level > 0:
                raise ValueError('Attempted relative import in non-package')
            return None, ''
        name = pkgname
    else:
        # __package__ not set, so figure it out and set it
        if '__name__' not in globals:
            return None, ''
        modname = globals['__name__']

        if '__path__' in globals:
            # __path__ is set, so modname is already the package name
            globals['__package__'] = name = modname
        else:
            # Normal module, so work out the package name if any
            lastdot = modname.rfind('.')
            if lastdot < 0 < level:
                raise ValueError("Attempted relative import in non-package")
            if lastdot < 0:
                globals['__package__'] = None
                return None, ''
            globals['__package__'] = name = modname[:lastdot]

    # Each dot beyond the first climbs one package level up.
    dot = len(name)
    for _ in range(level, 1, -1):
        try:
            dot = name.rindex('.', 0, dot)
        except ValueError:
            raise ValueError("attempted relative import beyond top-level "
                             "package") from None
    name = name[:dot]

    try:
        parent = sys.modules[name]
    except KeyError:
        # Only a missing entry is handled here; KeyboardInterrupt and other
        # unrelated exceptions must propagate unchanged.
        if level < 1:
            warn("Parent module '%.200s' not found while handling absolute "
                 "import" % name)
            parent = None
        else:
            raise SystemError("Parent module '%.200s' not loaded, cannot "
                              "perform relative import" % name)

    # We expect, but can't guarantee, if parent != None, that:
    # - parent.__name__ == name
    # - parent.__dict__ is globals
    # If this is violated... Who cares?
    return parent, name
|
||
|
|
||
|
def load_next(mod, altmod, name, buf):
    """Import the next dotted component of *name*.

    mod, name, buf = load_next(mod, altmod, name, buf)

    The first component of *name* is imported as a child of *mod*, fully
    qualified as ``buf + '.' + component`` (or just the component when *buf*
    is empty). If that yields nothing and *altmod* differs from *mod*
    (altmod is either None or same as mod), the component is retried as a
    top-level module.

    Returns ``(module, rest, buf)``. *rest* is the text after the first dot,
    or None when *name* had no dot. *buf* is the updated qualified name.
    An empty *name* returns ``(mod, None, buf)`` untouched; this should only
    happen for ``from . import ...`` or ``__import__("")``.

    Raises ValueError if *name* starts with '.', and ImportError if the
    component cannot be imported.
    """
    if len(name) == 0:
        return mod, None, buf

    head, sep, tail = name.partition('.')
    if not head:
        # name starts with a dot
        raise ValueError('Empty module name')
    remainder = tail if sep else None

    qualified = head if buf == '' else buf + '.' + head

    found = import_submodule(mod, head, qualified)
    if found is None and mod != altmod:
        # Fall back to importing the component as a top-level module.
        found = import_submodule(altmod, head, head)
        if found is not None:
            qualified = head

    if found is None:
        raise ImportError("No module named %.200s" % name)

    return found, remainder, qualified
|
||
|
|
||
|
|
||
|
# Names of modules already handled (reloaded, or explicitly excluded) during
# the current deep reload. import_submodule() reuses the sys.modules entry for
# any name found here instead of reloading it again, which is what stops
# import cycles from recursing forever. reload() resets it after each run.
found_now = dict()
|
||
|
|
||
|
def import_submodule(mod, subname, fullname):
    """Import *fullname* (child *subname* of *mod*), reloading it if needed.

    m = import_submodule(mod, subname, fullname)

    Caller contract: when *mod* is None, *subname* equals *fullname*;
    otherwise ``mod.__name__ + "." + subname == fullname``.

    If *fullname* is already recorded in ``found_now`` and present in
    ``sys.modules``, that module is returned as-is. Otherwise a
    "Reloading" line is printed and *fullname* is recorded in ``found_now``.
    The module is then located with ``imp.find_module``, using *mod*'s
    ``__path__``, or the default search when *mod* is None. It is
    re-executed with ``imp.load_module`` and bound on *mod* as attribute
    *subname*.

    Returns None when *mod* is not a package or the module cannot be found.
    If executing the module raises, the previous ``sys.modules`` entry (if
    any) is restored and the exception propagates.
    """
    if fullname in found_now and fullname in sys.modules:
        return sys.modules[fullname]

    print('Reloading', fullname)
    found_now[fullname] = 1
    previous = sys.modules.get(fullname, None)

    if mod is None:
        search_path = None
    elif hasattr(mod, '__path__'):
        search_path = mod.__path__
    else:
        # A plain (non-package) module cannot contain submodules.
        return None

    try:
        # This appears to be necessary on Python 3, because imp.find_module()
        # tries to import standard libraries (like io) itself, and we don't
        # want them to be processed by our deep_import_hook.
        with replace_import_hook(original_import):
            handle, filename, description = imp.find_module(subname, search_path)
    except ImportError:
        return None

    try:
        module = imp.load_module(fullname, handle, filename, description)
    except BaseException:
        # load_module probably removed name from modules because of
        # the error. Put back the original module object.
        if previous:
            sys.modules[fullname] = previous
        raise
    finally:
        if handle:
            handle.close()

    add_submodule(mod, module, fullname, subname)
    return module
|
||
|
|
||
|
def add_submodule(mod, submod, fullname, subname):
    """Bind a submodule onto its parent: ``mod.<subname> = submod``.

    Does nothing when *mod* is None (a top-level module has no parent).
    When *submod* is None, the module registered in ``sys.modules`` under
    *fullname* is bound instead.
    """
    if mod is not None:
        target = sys.modules[fullname] if submod is None else submod
        setattr(mod, subname, target)
|
||
|
|
||
|
def ensure_fromlist(mod, fromlist, buf, recursive):
    """Handle 'from module import a, b, c' imports.

    For a package *mod* (one with ``__path__``), every name in *fromlist*
    that is not yet an attribute of *mod* is imported as a submodule via
    import_submodule(), fully qualified as ``buf + '.' + name``. Non-package
    modules are left alone.

    ``'*'`` expands to ``mod.__all__`` when it is defined. *recursive* is
    true for that nested expansion, so a ``'*'`` inside ``__all__`` is
    skipped. Processing then continues with the remaining *fromlist* items.

    Raises TypeError for a non-string item. Returns None.
    """
    if not hasattr(mod, '__path__'):
        return
    for item in fromlist:
        if not hasattr(item, 'rindex'):
            raise TypeError("Item in ``from list'' not a string")
        if item == '*':
            if recursive:
                continue  # avoid endless recursion
            try:
                all_names = mod.__all__
            except AttributeError:
                pass
            else:
                # The recursive call has no meaningful return value; keep
                # going so items listed after '*' are still handled.
                ensure_fromlist(mod, all_names, buf, 1)
        elif not hasattr(mod, item):
            import_submodule(mod, item, buf + '.' + item)
|
||
|
|
||
|
def deep_import_hook(name, globals=None, locals=None, fromlist=None, level=-1):
    """Replacement for __import__() used while a deep reload is running.

    The importing package is resolved with get_parent(). Each dotted
    component of *name* is then imported through load_next(), which goes
    through import_submodule() and so reloads anything not yet handled in
    this pass. *fromlist* is handled by ensure_fromlist().

    Returns the top-level module when *fromlist* is empty (``import a.b.c``
    yields ``a``), otherwise the innermost module. *locals* is accepted only
    for signature compatibility.

    NOTE: *level* defaults to -1, the Python 2 "implicit relative" mode.
    In that mode load_next() gets no fallback module. This differs from
    Python 3's ``__import__`` default of 0.

    Raises ValueError when the resolved module name is empty.
    """
    parent, buf = get_parent(globals, level)

    fallback = parent if level >= 0 else None
    head, remaining, buf = load_next(parent, fallback, name, buf)

    tail = head
    while remaining:
        tail, remaining, buf = load_next(tail, tail, remaining, buf)

    # If tail is None, both get_parent and load_next found an empty module
    # name: someone called __import__("") or doctored faulty bytecode.
    if tail is None:
        raise ValueError('Empty module name')

    if fromlist:
        ensure_fromlist(tail, fromlist, buf, 0)
        return tail
    return head
|
||
|
|
||
|
# Module name -> module object for reloads started by deep_reload_hook().
# A name already present here is returned as-is instead of being reloaded
# again; deep_reload_hook() clears the mapping when it is done.
modules_reloading = dict()
|
||
|
|
||
|
def deep_reload_hook(m):
    """Replacement for reload().

    Re-executes module *m* with ``imp.load_module``. The module is found via
    ``imp.find_module`` on its parent package's ``__path__``, or the default
    search for a top-level module. Returns the module object produced by
    ``load_module``. The ``types`` module is returned unchanged.

    While *m* is being reloaded it is registered in ``modules_reloading``.
    A re-entrant request to reload the same name, for example from code run
    by the module itself, returns that object instead of recursing. The
    registry is always cleared when this call finishes, whether it succeeds
    or fails.

    Raises:
        TypeError: *m* is not a module.
        ImportError: *m* or its parent package is not in ``sys.modules``,
            or ``imp.find_module`` cannot locate it.
        Any exception raised while executing the module is re-raised after
        the original module object is put back into ``sys.modules``.
    """
    # Hardcode this one: reloading it would raise a NotImplementedError from
    # the bowels of Python and screw up the import machinery afterwards.
    # Unlike other imports, the `exclude` list already in place is not enough.
    if m is types:
        return m
    if not isinstance(m, ModuleType):
        raise TypeError("reload() argument must be module")

    name = m.__name__

    if name not in sys.modules:
        raise ImportError("reload(): module %.200s not in sys.modules" % name)

    if name in modules_reloading:
        return modules_reloading[name]
    modules_reloading[name] = m

    try:
        dot = name.rfind('.')
        if dot < 0:
            subname = name
            path = None
        else:
            try:
                parent = sys.modules[name[:dot]]
            except KeyError:
                raise ImportError("reload(): parent %.200s not in sys.modules" % name[:dot])
            subname = name[dot+1:]
            path = getattr(parent, "__path__", None)

        # This appears to be necessary on Python 3, because imp.find_module()
        # tries to import standard libraries (like io) itself, and we don't
        # want them to be processed by our deep_import_hook.
        with replace_import_hook(original_import):
            fp, filename, stuff = imp.find_module(subname, path)

        try:
            newm = imp.load_module(name, fp, filename, stuff)
        except BaseException:
            # load_module probably removed name from modules because of
            # the error. Put back the original module object.
            sys.modules[name] = m
            raise
        finally:
            if fp:
                fp.close()
    finally:
        # Keep the re-entrancy guard for the whole reload, including the
        # execution done by load_module, then reset it on every exit path.
        modules_reloading.clear()

    return newm
|
||
|
|
||
|
# Save the original hooks. ``importlib.reload`` is the standard-library
# reload function; ``imp.reload`` is only a deprecated wrapper around it.
# Keeping a reference lets callers restore the normal reload later.
original_reload = importlib.reload
|
||
|
|
||
|
# Replacement for reload()
def reload(module, exclude=('sys', 'os.path', 'builtins', '__main__',
                            'numpy', 'numpy._globals')):
    """Recursively reload all modules used in the given module.

    Optionally takes a list of modules to exclude from reloading. The default
    exclude list contains sys, os.path, builtins, __main__, numpy and
    numpy._globals, to prevent, e.g., resetting display, exception, and io
    hooks.

    Excluded names are marked as already handled for this pass. Any of them
    present in sys.modules is therefore reused rather than re-executed when
    imported during the reload.

    While the reload runs, ``builtins.__import__`` is replaced by
    deep_import_hook. The previous hook and the per-run bookkeeping are
    restored afterwards, even on error.

    Returns the reloaded module.
    """
    global found_now
    for excluded_name in exclude:
        found_now[excluded_name] = 1
    try:
        with replace_import_hook(deep_import_hook):
            return deep_reload_hook(module)
    finally:
        # Start every deep reload from a clean slate.
        found_now = {}
|