2510 lines
83 KiB
Python
2510 lines
83 KiB
Python
|
"""
|
||
|
Utility function to facilitate testing.
|
||
|
|
||
|
"""
|
||
|
import os
|
||
|
import sys
|
||
|
import platform
|
||
|
import re
|
||
|
import gc
|
||
|
import operator
|
||
|
import warnings
|
||
|
from functools import partial, wraps
|
||
|
import shutil
|
||
|
import contextlib
|
||
|
from tempfile import mkdtemp, mkstemp
|
||
|
from unittest.case import SkipTest
|
||
|
from warnings import WarningMessage
|
||
|
import pprint
|
||
|
|
||
|
from numpy.core import(
|
||
|
intp, float32, empty, arange, array_repr, ndarray, isnat, array)
|
||
|
import numpy.linalg.lapack_lite
|
||
|
|
||
|
from io import StringIO
|
||
|
|
||
|
# Names exported by ``from <this module> import *``; order is preserved
# from the historical listing.
__all__ = [
    'assert_equal',
    'assert_almost_equal',
    'assert_approx_equal',
    'assert_array_equal',
    'assert_array_less',
    'assert_string_equal',
    'assert_array_almost_equal',
    'assert_raises',
    'build_err_msg',
    'decorate_methods',
    'jiffies',
    'memusage',
    'print_assert_equal',
    'raises',
    'rundocs',
    'runstring',
    'verbose',
    'measure',
    'assert_',
    'assert_array_almost_equal_nulp',
    'assert_raises_regex',
    'assert_array_max_ulp',
    'assert_warns',
    'assert_no_warnings',
    'assert_allclose',
    'IgnoreException',
    'clear_and_catch_warnings',
    'SkipTest',
    'KnownFailureException',
    'temppath',
    'tempdir',
    'IS_PYPY',
    'HAS_REFCOUNT',
    'suppress_warnings',
    'assert_array_compare',
    '_assert_valid_refcount',
    '_gen_alignment_data',
    'assert_no_gc_cycles',
    'break_cycles',
    'HAS_LAPACK64',
]
|
||
|
|
||
|
|
||
|
class KnownFailureException(Exception):
    """Exception to raise from a test to mark it as known to fail."""
|
||
|
|
||
|
|
||
|
# Old name of KnownFailureException, kept as an alias for backwards
# compatibility.
KnownFailureTest = KnownFailureException

# Module-level verbosity flag (exported via ``__all__``).
verbose = 0

# Interpreter / build capability flags.
IS_PYPY = (platform.python_implementation() == 'PyPy')
HAS_REFCOUNT = getattr(sys, 'getrefcount', None) is not None
# Mirrors ``numpy.linalg.lapack_lite._ilp64``.
# NOTE(review): presumably True for 64-bit-integer (ILP64) LAPACK builds,
# going by the attribute name -- confirm.
HAS_LAPACK64 = numpy.linalg.lapack_lite._ilp64
|
||
|
|
||
|
|
||
|
def import_nose():
    """
    Import and return the ``nose`` package, enforcing a minimum version.

    The import is done lazily so that nose is only needed when this helper
    is actually called.

    Returns
    -------
    module
        The imported ``nose`` package.

    Raises
    ------
    ImportError
        If nose is not importable or ``nose.__versioninfo__`` is older than
        (1, 0, 0).
    """
    required = (1, 0, 0)
    try:
        import nose
    except ImportError:
        nose = None

    if nose is None or nose.__versioninfo__ < required:
        raise ImportError('Need nose >= %d.%d.%d for tests - see '
                          'https://nose.readthedocs.io' % required)
    return nose
|
||
|
|
||
|
|
||
|
def assert_(val, msg=''):
    """
    Check that `val` is truthy, even when Python runs with ``-O``.

    The ``assert`` statement is compiled away in optimized mode (the ``-O``
    flag); this function always performs the check.

    Parameters
    ----------
    val : object
        Value whose truthiness is tested.
    msg : str or callable, optional
        Failure message.  If it is callable it is invoked with no arguments,
        and only on failure, so costly messages can be deferred.  If calling
        it raises ``TypeError`` (e.g. because it is a plain string), `msg`
        itself is used as the message.

    Raises
    ------
    AssertionError
        If `val` is falsy.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    if val:
        return
    try:
        text = msg()
    except TypeError:
        text = msg
    raise AssertionError(text)
|
||
|
|
||
|
|
||
|
def gisnan(x):
    """Like `isnan`, but raise instead of returning ``NotImplemented``.

    Notes
    -----
    `isnan` and other ufuncs sometimes return the ``NotImplemented``
    singleton instead of raising for unsupported input types.  This wrapper
    converts that result into a ``TypeError`` so an exception is always
    raised.  It should be removed once the ufunc itself raises.

    Raises
    ------
    TypeError
        If `isnan` returns ``NotImplemented`` for `x`.
    """
    from numpy.core import isnan
    result = isnan(x)
    if result is NotImplemented:
        raise TypeError("isnan not supported for this type")
    return result
|
||
|
|
||
|
|
||
|
def gisfinite(x):
    """Like `isfinite`, but raise instead of returning ``NotImplemented``.

    Notes
    -----
    `isfinite` and other ufuncs sometimes return the ``NotImplemented``
    singleton instead of raising for unsupported input types.  This wrapper
    converts that result into a ``TypeError`` so an exception is always
    raised.  Invalid-value floating point warnings are suppressed during the
    call.  It should be removed once the ufunc itself raises.

    Raises
    ------
    TypeError
        If `isfinite` returns ``NotImplemented`` for `x`.
    """
    from numpy.core import isfinite, errstate
    with errstate(invalid='ignore'):
        result = isfinite(x)
    if result is NotImplemented:
        raise TypeError("isfinite not supported for this type")
    return result
|
||
|
|
||
|
|
||
|
def gisinf(x):
    """Like `isinf`, but raise instead of returning ``NotImplemented``.

    Notes
    -----
    `isinf` and other ufuncs sometimes return the ``NotImplemented``
    singleton instead of raising for unsupported input types.  This wrapper
    converts that result into a ``TypeError`` so an exception is always
    raised.  Invalid-value floating point warnings are suppressed during the
    call.  It should be removed once the ufunc itself raises.

    Raises
    ------
    TypeError
        If `isinf` returns ``NotImplemented`` for `x`.
    """
    from numpy.core import isinf, errstate
    with errstate(invalid='ignore'):
        result = isinf(x)
    if result is NotImplemented:
        raise TypeError("isinf not supported for this type")
    return result
|
||
|
|
||
|
|
||
|
if os.name == 'nt':
    # Code "stolen" from enthought/debug/memusage.py
    def GetPerformanceAttributes(object, counter, instance=None,
                                 inum=-1, format=None, machine=None):
        """
        Read one Windows performance counter and return its formatted value.

        A counter path is built from (`machine`, `object`, `instance`,
        `inum`, `counter`), one sample is collected through ``win32pdh``, and
        the counter and query handles are always released.  `format`
        defaults to ``win32pdh.PDH_FMT_LONG``.
        """
        # NOTE: Many counters require 2 samples to give accurate results,
        # including "% Processor Time" (as by definition, at any instant, a
        # thread's CPU usage is either 0 or 100). To read counters like this,
        # you should copy this function, but keep the counter open, and call
        # CollectQueryData() each time you need to know.
        # See http://msdn.microsoft.com/library/en-us/dnperfmo/html/perfmonpt2.asp (dead link)
        # My older explanation for this was that the "AddCounter" process forced
        # the CPU to 100%, but the above makes more sense :)
        import win32pdh
        if format is None:
            format = win32pdh.PDH_FMT_LONG
        path = win32pdh.MakeCounterPath(
            (machine, object, instance, None, inum, counter))
        hq = win32pdh.OpenQuery()
        try:
            hc = win32pdh.AddCounter(hq, path)
            try:
                win32pdh.CollectQueryData(hq)
                # First element is the counter type; only the value is used.
                _, val = win32pdh.GetFormattedCounterValue(hc, format)
                return val
            finally:
                win32pdh.RemoveCounter(hc)
        finally:
            win32pdh.CloseQuery(hq)

    def memusage(processName="python", instance=0):
        """
        Return the "Virtual Bytes" performance counter of a process.

        Parameters
        ----------
        processName : str, optional
            Performance-counter instance name of the process.
        instance : int, optional
            Instance index, passed through as the counter's `inum`.
        """
        # from win32pdhutil, part of the win32all package
        import win32pdh
        return GetPerformanceAttributes("Process", "Virtual Bytes",
                                        processName, instance,
                                        win32pdh.PDH_FMT_LONG, None)
elif sys.platform[:5] == 'linux':

    def memusage(_proc_pid_stat=None):
        """
        Return virtual memory size in bytes of the running python.

        The value is the ``vsize`` field (field 23) of ``/proc/<pid>/stat``
        (see ``man 5 proc``).  Returns None if the file cannot be read or
        parsed.

        Parameters
        ----------
        _proc_pid_stat : str, optional
            Path of the stat file to read.  Defaults to the stat file of the
            *current* process, resolved at call time so that the result stays
            correct in a forked child.
        """
        if _proc_pid_stat is None:
            _proc_pid_stat = '/proc/%s/stat' % os.getpid()
        try:
            with open(_proc_pid_stat, 'r') as f:
                stat = f.readline()
            # Field 2 (comm) is parenthesised and may itself contain spaces,
            # so only split what follows its closing ')'.  fields[0] is then
            # field 3 (state), and field N is fields[N - 3].
            fields = stat[stat.rindex(')') + 1:].split()
            return int(fields[20])  # field 23: vsize
        except Exception:
            # Best effort: any read/parse problem yields None.
            return None
else:
    def memusage():
        """
        Return memory usage of running python. [Not implemented]

        Raises
        ------
        NotImplementedError
            Always; this platform has no supported implementation.
        """
        raise NotImplementedError
|
||
|
|
||
|
|
||
|
if sys.platform[:5] == 'linux':
    def jiffies(_proc_pid_stat=None, _load_time=[]):
        """
        Return number of jiffies elapsed.

        Returns the ``utime`` field (field 14) of ``/proc/<pid>/stat``: the
        clock ticks this process has been scheduled in user mode (see
        ``man 5 proc``).  If the file cannot be read or parsed, falls back to
        the wall-clock time since the first call, at 100 ticks per second.

        Parameters
        ----------
        _proc_pid_stat : str, optional
            Path of the stat file to read.  Defaults to the stat file of the
            *current* process, resolved at call time so that the result stays
            correct in a forked child.
        _load_time : list, optional
            Private cache: the mutable default deliberately persists across
            calls and holds the time of the first call for the fallback.
        """
        import time
        if not _load_time:
            _load_time.append(time.time())
        if _proc_pid_stat is None:
            _proc_pid_stat = '/proc/%s/stat' % os.getpid()
        try:
            with open(_proc_pid_stat, 'r') as f:
                stat = f.readline()
            # Field 2 (comm) is parenthesised and may itself contain spaces,
            # so only split what follows its closing ')'.  fields[0] is then
            # field 3 (state), and field N is fields[N - 3].
            fields = stat[stat.rindex(')') + 1:].split()
            return int(fields[11])  # field 14: utime
        except Exception:
            return int(100*(time.time()-_load_time[0]))
else:
    # os.getpid is not in all platforms available.
    # Using time is safe but inaccurate, especially when process
    # was suspended or sleeping.
    def jiffies(_load_time=[]):
        """
        Return number of jiffies elapsed.

        Approximated as hundredths of a second of wall-clock time since the
        first call; this includes time the process was not running.

        Parameters
        ----------
        _load_time : list, optional
            Private cache: the mutable default deliberately persists across
            calls and holds the time of the first call.
        """
        import time
        if not _load_time:
            _load_time.append(time.time())
        return int(100*(time.time()-_load_time[0]))
|
||
|
|
||
|
|
||
|
def build_err_msg(arrays, err_msg, header='Items are not equal:',
                  verbose=True, names=('ACTUAL', 'DESIRED'), precision=8):
    """
    Format the failure message used by the assertion helpers.

    The message starts with a newline and `header`.  A short single-line
    `err_msg` is appended to the header line; a longer or multi-line one
    goes on its own line.  With `verbose`, one ``' <name>: <repr>'`` line per
    entry of `arrays` follows.  Each repr is truncated to its first three
    lines (plus ``'...'``) when it has more than three newlines.

    Parameters
    ----------
    arrays : sequence
        Objects to show in the message.
    err_msg : str
        Extra, user-supplied message text (may be empty).
    header : str, optional
        First line of the message.
    verbose : bool, optional
        Whether to include the reprs of `arrays`.
    names : sequence of str, optional
        Labels for the entries of `arrays`, matched by position.
    precision : int, optional
        Print precision used for ndarray entries.

    Returns
    -------
    str
        The assembled message.
    """
    head = '\n' + header
    body = []
    if err_msg:
        fits_on_header = ('\n' not in err_msg
                          and len(err_msg) < 79 - len(header))
        if fits_on_header:
            head = head + ' ' + err_msg
        else:
            body.append(err_msg)

    if verbose:
        for idx, obj in enumerate(arrays):
            # Only ndarrays accept a precision argument for their repr.
            if isinstance(obj, ndarray):
                to_text = partial(array_repr, precision=precision)
            else:
                to_text = repr

            try:
                text = to_text(obj)
            except Exception as exc:
                text = '[repr failed for <{}>: {}]'.format(type(obj).__name__, exc)
            if text.count('\n') > 3:
                text = '\n'.join(text.splitlines()[:3]) + '...'
            body.append(' %s: %s' % (names[idx], text))

    return '\n'.join([head] + body)
|
||
|
|
||
|
|
||
|
def assert_equal(actual, desired, err_msg='', verbose=True):
    """
    Raises an AssertionError if two objects are not equal.

    Given two objects (scalars, lists, tuples, dictionaries or numpy arrays),
    check that all elements of these objects are equal. An exception is raised
    at the first conflicting values.

    Containers are handled recursively:

    * dictionaries must have the same length, every key of `desired` must be
      present in `actual`, and the values are compared key by key;
    * lists and tuples (only when *both* arguments are lists or tuples) must
      have the same length and are compared item by item;
    * if either argument is an ndarray, the comparison is delegated to
      `assert_array_equal`.

    When one of `actual` and `desired` is a scalar and the other is an
    ndarray, the function checks that each element of the array is equal to
    the scalar.  A scalar compared against a list or tuple is not broadcast;
    that combination raises AssertionError.

    This function handles NaN comparisons as if NaN was a "normal" number.
    That is, AssertionError is not raised if both objects have NaNs in the same
    positions.  This is in contrast to the IEEE standard on NaNs, which says
    that NaN compared to anything must return False.  Similarly, two NaT
    values are considered equal when their dtype types match, and for scalar
    floats ``0.0`` and ``-0.0`` are considered different.

    Parameters
    ----------
    actual : array_like
        The object to check.
    desired : array_like
        The expected object.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If actual and desired are not equal.

    Examples
    --------
    >>> np.testing.assert_equal([4,5], [4,6])
    Traceback (most recent call last):
        ...
    AssertionError:
    Items are not equal:
    item=1
     ACTUAL: 5
     DESIRED: 6

    The following comparison does not raise an exception.  There are NaNs
    in the inputs, but they are in the same positions.

    >>> np.testing.assert_equal(np.array([1.0, 2.0, np.nan]), [1, 2, np.nan])

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # Dict: same length, every key of `desired` present in `actual`, values
    # compared recursively with the key prefixed to err_msg.
    if isinstance(desired, dict):
        if not isinstance(actual, dict):
            raise AssertionError(repr(type(actual)))
        assert_equal(len(actual), len(desired), err_msg, verbose)
        # NOTE(review): the loop value `i` is unused; desired[k] is read
        # again below.
        for k, i in desired.items():
            if k not in actual:
                raise AssertionError(repr(k))
            assert_equal(actual[k], desired[k], 'key=%r\n%s' % (k, err_msg), verbose)
        return
    # List/tuple on BOTH sides: same length, items compared recursively with
    # the index prefixed to err_msg.
    if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)):
        assert_equal(len(actual), len(desired), err_msg, verbose)
        for k in range(len(desired)):
            assert_equal(actual[k], desired[k], 'item=%r\n%s' % (k, err_msg), verbose)
        return
    from numpy.core import ndarray, isscalar, signbit
    from numpy.lib import iscomplexobj, real, imag
    # Any ndarray operand: defer to the element-wise array comparison.
    if isinstance(actual, ndarray) or isinstance(desired, ndarray):
        return assert_array_equal(actual, desired, err_msg, verbose)
    msg = build_err_msg([actual, desired], err_msg, verbose=verbose)

    # Handle complex numbers: separate into real/imag to handle
    # nan/inf/negative zero correctly
    # XXX: catch ValueError for subclasses of ndarray where iscomplex fail
    try:
        usecomplex = iscomplexobj(actual) or iscomplexobj(desired)
    except (ValueError, TypeError):
        usecomplex = False

    if usecomplex:
        if iscomplexobj(actual):
            actualr = real(actual)
            actuali = imag(actual)
        else:
            actualr = actual
            actuali = 0
        if iscomplexobj(desired):
            desiredr = real(desired)
            desiredi = imag(desired)
        else:
            desiredr = desired
            desiredi = 0
        try:
            assert_equal(actualr, desiredr)
            assert_equal(actuali, desiredi)
        except AssertionError:
            raise AssertionError(msg)
        # No return here: matching complex parts fall through to the
        # scalar checks below.

    # isscalar test to check cases such as [np.nan] != np.nan
    if isscalar(desired) != isscalar(actual):
        raise AssertionError(msg)

    # NaT handling; isnat raises for non-datetime inputs, which is ignored.
    try:
        isdesnat = isnat(desired)
        isactnat = isnat(actual)
        dtypes_match = array(desired).dtype.type == array(actual).dtype.type
        if isdesnat and isactnat:
            # If both are NaT (and have the same dtype -- datetime or
            # timedelta) they are considered equal.
            if dtypes_match:
                return
            else:
                raise AssertionError(msg)

    except (TypeError, ValueError, NotImplementedError):
        pass

    # Inf/nan/negative zero handling
    try:
        isdesnan = gisnan(desired)
        isactnan = gisnan(actual)
        if isdesnan and isactnan:
            return  # both nan, so equal

        # handle signed zero specially for floats
        array_actual = array(actual)
        array_desired = array(desired)
        if (array_actual.dtype.char in 'Mm' or
                array_desired.dtype.char in 'Mm'):
            # version 1.18
            # until this version, gisnan failed for datetime64 and timedelta64.
            # Now it succeeds but comparison to scalar with a different type
            # emits a DeprecationWarning.
            # Avoid that by skipping the next check
            raise NotImplementedError('cannot compare to a scalar '
                                      'with a different type')

        if desired == 0 and actual == 0:
            if not signbit(desired) == signbit(actual):
                raise AssertionError(msg)

    except (TypeError, ValueError, NotImplementedError):
        pass

    try:
        # Explicitly use __eq__ for comparison, gh-2552
        if not (desired == actual):
            raise AssertionError(msg)

    # These warning classes only arrive here as exceptions when a warnings
    # filter turns them into errors.
    except (DeprecationWarning, FutureWarning) as e:
        # this handles the case when the two types are not even comparable
        if 'elementwise == comparison' in e.args[0]:
            raise AssertionError(msg)
        else:
            raise
|
||
|
|
||
|
|
||
|
def print_assert_equal(test_string, actual, desired):
    """
    Check ``actual == desired`` and raise a pretty-printed report on failure.

    Parameters
    ----------
    test_string : str
        Label placed at the start of the AssertionError message.
    actual : object
        Value under test, compared against `desired`.
    desired : object
        Expected value.

    Raises
    ------
    AssertionError
        If ``actual == desired`` is falsy.  The message contains
        `test_string`, followed by ``pprint``-formatted ACTUAL and DESIRED
        sections.

    Examples
    --------
    >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 1])
    >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 2])
    Traceback (most recent call last):
    ...
    AssertionError: Test XYZ of func xyz failed
    ACTUAL:
    [0, 1]
    DESIRED:
    [0, 2]

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    if actual == desired:
        return
    report = ''.join([
        test_string,
        ' failed\nACTUAL: \n',
        pprint.pformat(actual), '\n',
        'DESIRED: \n',
        pprint.pformat(desired), '\n',
    ])
    raise AssertionError(report)
|
||
|
|
||
|
|
||
|
def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True):
    """
    Raises an AssertionError if two items are not equal up to desired
    precision.

    .. note:: It is recommended to use one of `assert_allclose`,
              `assert_array_almost_equal_nulp` or `assert_array_max_ulp`
              instead of this function for more consistent floating point
              comparisons.

    The test verifies that the elements of ``actual`` and ``desired`` satisfy:

        ``abs(desired-actual) < 1.5 * 10**(-decimal)``

    That is a looser test than originally documented, but agrees with what the
    actual implementation in `assert_array_almost_equal` did up to rounding
    vagaries.  An exception is raised at conflicting values.  If either
    argument is an ndarray, list or tuple, this delegates to
    `assert_array_almost_equal` (forwarding `err_msg` and `verbose`).
    Complex values are compared by their real and imaginary parts separately.
    For non-finite scalars, a NaN only matches a NaN and an infinity must
    compare equal.

    Parameters
    ----------
    actual : array_like
        The object to check.
    desired : array_like
        The expected object.
    decimal : int, optional
        Desired precision, default is 7.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
      If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Examples
    --------
    >>> import numpy.testing as npt
    >>> npt.assert_almost_equal(2.3333333333333, 2.33333334)
    >>> npt.assert_almost_equal(2.3333333333333, 2.33333334, decimal=10)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 10 decimals
     ACTUAL: 2.3333333333333
     DESIRED: 2.33333334

    >>> npt.assert_almost_equal(np.array([1.0,2.3333333333333]),
    ...                         np.array([1.0,2.33333334]), decimal=9)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 9 decimals
    <BLANKLINE>
    Mismatched elements: 1 / 2 (50%)
    Max absolute difference: 6.66669964e-09
    Max relative difference: 2.85715698e-09
     x: array([1.         , 2.333333333])
     y: array([1.        , 2.33333334])

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    from numpy.core import ndarray
    from numpy.lib import iscomplexobj, real, imag

    # Handle complex numbers: separate into real/imag to handle
    # nan/inf/negative zero correctly
    # XXX: catch ValueError for subclasses of ndarray where iscomplex fail.
    # TypeError is caught as well, matching assert_equal.
    try:
        usecomplex = iscomplexobj(actual) or iscomplexobj(desired)
    except (ValueError, TypeError):
        usecomplex = False

    def _build_err_msg():
        # Built lazily: only needed when the assertion actually fails.
        header = ('Arrays are not almost equal to %d decimals' % decimal)
        return build_err_msg([actual, desired], err_msg, verbose=verbose,
                             header=header)

    if usecomplex:
        if iscomplexobj(actual):
            actualr = real(actual)
            actuali = imag(actual)
        else:
            actualr = actual
            actuali = 0
        if iscomplexobj(desired):
            desiredr = real(desired)
            desiredi = imag(desired)
        else:
            desiredr = desired
            desiredi = 0
        try:
            assert_almost_equal(actualr, desiredr, decimal=decimal)
            assert_almost_equal(actuali, desiredi, decimal=decimal)
        except AssertionError:
            raise AssertionError(_build_err_msg())

    if isinstance(actual, (ndarray, tuple, list)) \
            or isinstance(desired, (ndarray, tuple, list)):
        # Forward `verbose` so that verbose=False is honoured for array-like
        # inputs as well (it used to be dropped here).
        return assert_array_almost_equal(actual, desired, decimal, err_msg,
                                         verbose=verbose)
    try:
        # If one of desired/actual is not finite, handle it specially here:
        # check that both are nan if any is a nan, and test for equality
        # otherwise
        if not (gisfinite(desired) and gisfinite(actual)):
            if gisnan(desired) or gisnan(actual):
                if not (gisnan(desired) and gisnan(actual)):
                    raise AssertionError(_build_err_msg())
            else:
                if not desired == actual:
                    raise AssertionError(_build_err_msg())
            return
    except (NotImplementedError, TypeError):
        # Types unsupported by the g* helpers fall back to the plain
        # difference check below.
        pass
    if abs(desired - actual) >= 1.5 * 10.0**(-decimal):
        raise AssertionError(_build_err_msg())
|
||
|
|
||
|
|
||
|
def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True):
    """
    Raises an AssertionError if two items are not equal up to significant
    digits.

    .. note:: It is recommended to use one of `assert_allclose`,
              `assert_array_almost_equal_nulp` or `assert_array_max_ulp`
              instead of this function for more consistent floating point
              comparisons.

    Given two numbers, check that they are approximately equal.
    Approximately equal is defined as the number of significant digits
    that agree.  Both values are first converted with ``float``.  If either
    value is not finite, a NaN only matches a NaN and an infinity must
    compare equal.

    Parameters
    ----------
    actual : scalar
        The object to check.
    desired : scalar
        The expected object.
    significant : int, optional
        Desired precision, default is 7.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
      If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Examples
    --------
    >>> np.testing.assert_approx_equal(0.12345677777777e-20, 0.1234567e-20)
    >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345671e-20,
    ...                                significant=8)
    >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345672e-20,
    ...                                significant=8)
    Traceback (most recent call last):
        ...
    AssertionError:
    Items are not equal to 8 significant digits:
     ACTUAL: 1.234567e-21
     DESIRED: 1.2345672e-21

    the evaluated condition that raises the exception is

    >>> abs(0.12345670e-20/1e-21 - 0.12345672e-20/1e-21) >= 10**-(8-1)
    True

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np

    (actual, desired) = map(float, (actual, desired))
    if desired == actual:
        return
    msg = build_err_msg(
        [actual, desired], err_msg,
        header='Items are not equal to %d significant digits:' % significant,
        verbose=verbose)
    try:
        # If one of desired/actual is not finite, handle it specially here:
        # check that both are nan if any is a nan, and test for equality
        # otherwise.  This runs before the scaling below so non-finite
        # inputs never reach the divisions (inf/inf would otherwise emit a
        # RuntimeWarning before the assertion is raised).
        if not (gisfinite(desired) and gisfinite(actual)):
            if gisnan(desired) or gisnan(actual):
                if not (gisnan(desired) and gisnan(actual)):
                    raise AssertionError(msg)
            else:
                if not desired == actual:
                    raise AssertionError(msg)
            return
    except (TypeError, NotImplementedError):
        pass

    # Normalize both numbers by a common power of ten so the comparison is
    # on values in the range (-10.0, 10.0).  `scale` is a NumPy float, so the
    # divisions never raise ZeroDivisionError: if `scale` underflows to 0
    # they produce inf/nan, and the related FP warnings are silenced here.
    with np.errstate(invalid='ignore', divide='ignore'):
        scale = 0.5*(np.abs(desired) + np.abs(actual))
        scale = np.power(10, np.floor(np.log10(scale)))
        sc_desired = desired/scale
        sc_actual = actual/scale
    if np.abs(sc_desired - sc_actual) >= np.power(10., -(significant-1)):
        raise AssertionError(msg)
|
||
|
|
||
|
|
||
|
def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
                         header='', precision=6, equal_nan=True,
                         equal_inf=True):
    """
    Assert that ``comparison(x, y)`` holds for every element.

    This is the comparison engine used by `assert_array_equal`.

    Parameters
    ----------
    comparison : callable
        Binary function ``comparison(x, y)`` returning a boolean array (or a
        Python bool) that is True where the elements match.
    x, y : array_like
        Objects to compare.  Their shapes must be equal unless one of them
        is 0-d.
    err_msg : str, optional
        Extra text included in the failure message.
    verbose : bool, optional
        If True, the operands are appended to the failure message.
    header : str, optional
        First line of the failure message.
    precision : int, optional
        Print precision used for ndarray operands in the failure message.
    equal_nan : bool, optional
        For numeric inputs, NaNs must be at the same positions and are then
        excluded from `comparison`.  For datetime/timedelta inputs of the
        same type, the same applies to NaT.
    equal_inf : bool, optional
        For numeric inputs, +inf and -inf must be at the same positions and
        are then excluded from `comparison`.

    Raises
    ------
    AssertionError
        On shape mismatch, NaN/NaT/inf position mismatch, or if `comparison`
        is not True everywhere.  The message reports the mismatch count and,
        where computable, the max absolute and relative differences.
    ValueError
        If a ValueError occurs while comparing.  The original traceback is
        embedded in the message.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    from numpy.core import (array, array2string, asanyarray, isnan, inf,
                            bool_, errstate, all, max, object_)

    # asanyarray keeps ndarray subclasses and avoids copies when possible.
    # This is the same as ``array(x, copy=False, subok=True)`` on NumPy < 2.
    # Unlike that spelling, it does not raise on NumPy >= 2 when a copy is
    # unavoidable (e.g. for list inputs).
    x = asanyarray(x)
    y = asanyarray(y)

    # original array for output formatting
    ox, oy = x, y

    def isnumber(x):
        return x.dtype.char in '?bhilqpBHILQPefdgFDG'

    def istime(x):
        return x.dtype.char in "Mm"

    def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
        """Handling nan/inf.

        Combine results of running func on x and y, checking that they are True
        at the same locations.

        """
        x_id = func(x)
        y_id = func(y)
        # We include work-arounds here to handle three types of slightly
        # pathological ndarray subclasses:
        # (1) all() on `masked` array scalars can return masked arrays, so we
        #     use != True
        # (2) __eq__ on some ndarray subclasses returns Python booleans
        #     instead of element-wise comparisons, so we cast to bool_() and
        #     use isinstance(..., bool) checks
        # (3) subclasses with bare-bones __array_function__ implementations may
        #     not implement np.all(), so favor using the .all() method
        # We are not committed to supporting such subclasses, but it's nice to
        # support them if possible.
        if bool_(x_id == y_id).all() != True:
            msg = build_err_msg([x, y],
                                err_msg + '\nx and y %s location mismatch:'
                                % (hasval), verbose=verbose, header=header,
                                names=('x', 'y'), precision=precision)
            raise AssertionError(msg)
        # If there is a scalar, then here we know the array has the same
        # flag as it everywhere, so we should return the scalar flag.
        if isinstance(x_id, bool) or x_id.ndim == 0:
            return bool_(x_id)
        elif isinstance(y_id, bool) or y_id.ndim == 0:
            # Must test y_id (not x_id): a Python-bool y_id has no .ndim.
            return bool_(y_id)
        else:
            return y_id

    try:
        cond = (x.shape == () or y.shape == ()) or x.shape == y.shape
        if not cond:
            msg = build_err_msg([x, y],
                                err_msg
                                + '\n(shapes %s, %s mismatch)' % (x.shape,
                                                                 y.shape),
                                verbose=verbose, header=header,
                                names=('x', 'y'), precision=precision)
            raise AssertionError(msg)

        flagged = bool_(False)
        if isnumber(x) and isnumber(y):
            if equal_nan:
                flagged = func_assert_same_pos(x, y, func=isnan, hasval='nan')

            if equal_inf:
                flagged |= func_assert_same_pos(x, y,
                                                func=lambda xy: xy == +inf,
                                                hasval='+inf')
                flagged |= func_assert_same_pos(x, y,
                                                func=lambda xy: xy == -inf,
                                                hasval='-inf')

        elif istime(x) and istime(y):
            # If one is datetime64 and the other timedelta64 there is no point
            if equal_nan and x.dtype.type == y.dtype.type:
                flagged = func_assert_same_pos(x, y, func=isnat, hasval="NaT")

        if flagged.ndim > 0:
            x, y = x[~flagged], y[~flagged]
            # Only do the comparison if actual values are left
            if x.size == 0:
                return
        elif flagged:
            # no sense doing comparison if everything is flagged.
            return

        val = comparison(x, y)

        if isinstance(val, bool):
            cond = val
            reduced = array([val])
        else:
            reduced = val.ravel()
            cond = reduced.all()

        # The below comparison is a hack to ensure that fully masked
        # results, for which val.ravel().all() returns np.ma.masked,
        # do not trigger a failure (np.ma.masked != True evaluates as
        # np.ma.masked, which is falsy).
        if cond != True:
            n_mismatch = reduced.size - reduced.sum(dtype=intp)
            n_elements = flagged.size if flagged.ndim != 0 else reduced.size
            percent_mismatch = 100 * n_mismatch / n_elements
            remarks = [
                'Mismatched elements: {} / {} ({:.3g}%)'.format(
                    n_mismatch, n_elements, percent_mismatch)]

            with errstate(invalid='ignore', divide='ignore'):
                # ignore errors for non-numeric types
                with contextlib.suppress(TypeError):
                    error = abs(x - y)
                    max_abs_error = max(error)
                    if getattr(error, 'dtype', object_) == object_:
                        remarks.append('Max absolute difference: '
                                       + str(max_abs_error))
                    else:
                        remarks.append('Max absolute difference: '
                                       + array2string(max_abs_error))

                    # note: this definition of relative error matches that one
                    # used by assert_allclose (found in np.isclose)
                    # Filter values where the divisor would be zero
                    nonzero = bool_(y != 0)
                    if all(~nonzero):
                        max_rel_error = array(inf)
                    else:
                        max_rel_error = max(error[nonzero] / abs(y[nonzero]))
                    if getattr(error, 'dtype', object_) == object_:
                        remarks.append('Max relative difference: '
                                       + str(max_rel_error))
                    else:
                        remarks.append('Max relative difference: '
                                       + array2string(max_rel_error))

            err_msg += '\n' + '\n'.join(remarks)
            msg = build_err_msg([ox, oy], err_msg,
                                verbose=verbose, header=header,
                                names=('x', 'y'), precision=precision)
            raise AssertionError(msg)
    except ValueError:
        import traceback
        efmt = traceback.format_exc()
        header = 'error during assertion:\n\n%s\n\n%s' % (efmt, header)

        msg = build_err_msg([x, y], err_msg, verbose=verbose, header=header,
                            names=('x', 'y'), precision=precision)
        raise ValueError(msg)
|
||
|
|
||
|
|
||
|
def assert_array_equal(x, y, err_msg='', verbose=True):
    """
    Raises an AssertionError if two array_like objects are not equal.

    Checks that `x` and `y` have the same shape and that every pair of
    corresponding elements compares equal.  The Notes describe how a scalar
    operand is handled.  A shape mismatch or any unequal element raises.
    Unlike ordinary NumPy comparisons, NaNs are treated like regular numbers:
    nothing is raised when both objects have NaNs at the same positions.

    Be careful when checking floating point values for exact equality.

    Parameters
    ----------
    x : array_like
        The actual object to check.
    y : array_like
        The desired, expected object.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If actual and desired objects are not equal.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Notes
    -----
    If exactly one of `x` and `y` is a scalar, each element of the other
    (array_like) argument is compared against that scalar.

    Examples
    --------
    The first assert does not raise an exception:

    >>> np.testing.assert_array_equal([1.0,2.33333,np.nan],
    ...                               [np.exp(0),2.33333, np.nan])

    Assert fails with numerical imprecision with floats:

    >>> np.testing.assert_array_equal([1.0,np.pi,np.nan],
    ...                               [1, np.sqrt(np.pi)**2, np.nan])
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not equal
    <BLANKLINE>
    Mismatched elements: 1 / 3 (33.3%)
    Max absolute difference: 4.4408921e-16
    Max relative difference: 1.41357986e-16
     x: array([1.      , 3.141593,      nan])
     y: array([1.      , 3.141593,      nan])

    Use `assert_allclose` or one of the nulp (number of floating point values)
    functions for these cases instead:

    >>> np.testing.assert_allclose([1.0,np.pi,np.nan],
    ...                            [1, np.sqrt(np.pi)**2, np.nan],
    ...                            rtol=1e-10, atol=0)

    As mentioned in the Notes section, `assert_array_equal` has special
    handling for scalars. Here the test checks that each value in `x` is 3:

    >>> x = np.full((2, 5), fill_value=3)
    >>> np.testing.assert_array_equal(x, 3)

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # operator.eq is the same function object as operator.__eq__.
    assert_array_compare(operator.eq, x, y,
                         header='Arrays are not equal',
                         err_msg=err_msg, verbose=verbose)
|
||
|
|
||
|
|
||
|
def assert_array_almost_equal(x, y, decimal=6, err_msg='', verbose=True):
    """
    Raises an AssertionError if two objects are not equal up to desired
    precision.

    .. note:: It is recommended to use one of `assert_allclose`,
              `assert_array_almost_equal_nulp` or `assert_array_max_ulp`
              instead of this function for more consistent floating point
              comparisons.

    The shapes must be identical, and each pair of elements of ``actual`` and
    ``desired`` must satisfy::

        abs(desired-actual) < 1.5 * 10**(-decimal)

    That is a looser test than originally documented, but agrees with what the
    actual implementation did up to rounding vagaries. A shape mismatch or a
    conflicting value raises. Unlike the usual numpy semantics, NaNs are
    treated like ordinary numbers: NaNs in the same positions of both objects
    do not cause a failure.

    Parameters
    ----------
    x : array_like
        The actual object to check.
    y : array_like
        The desired, expected object.
    decimal : int, optional
        Desired precision, default is 6.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_allclose: Compare two array_like objects for equality with desired
                     relative and/or absolute precision.
    assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal

    Examples
    --------
    the first assert does not raise an exception

    >>> np.testing.assert_array_almost_equal([1.0,2.333,np.nan],
    ...                                      [1.0,2.333,np.nan])

    >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan],
    ...                                      [1.0,2.33339,np.nan], decimal=5)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 5 decimals
    <BLANKLINE>
    Mismatched elements: 1 / 3 (33.3%)
    Max absolute difference: 6.e-05
    Max relative difference: 2.57136612e-05
     x: array([1.     , 2.33333,     nan])
     y: array([1.     , 2.33339,     nan])

    >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan],
    ...                                      [1.0,2.33333, 5], decimal=5)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not almost equal to 5 decimals
    <BLANKLINE>
    x and y nan location mismatch:
     x: array([1.     , 2.33333,     nan])
     y: array([1.     , 2.33333, 5.     ])

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    from numpy.core import number, float_, result_type, array
    from numpy.core.numerictypes import issubdtype
    from numpy.core.fromnumeric import any as npany

    def compare(x, y):
        try:
            # gisinf may reject some inputs (TypeError/NotImplementedError);
            # in that case fall through to the plain tolerance test.
            x_inf = gisinf(x)
            y_inf = gisinf(y)
            if npany(x_inf) or npany(y_inf):
                if not (x_inf == y_inf).all():
                    return False
                # A lone +/-inf on each side: compare them directly.
                if x.size == y.size == 1:
                    return x == y
                # Drop the (matching) infinities before the tolerance test.
                x = x[~x_inf]
                y = y[~y_inf]
        except (TypeError, NotImplementedError):
            pass

        # Promote y to an inexact dtype so abs() never sees abs(MIN_INT);
        # the subtraction then casts x as well.
        inexact = result_type(y, 1.)
        y = array(y, dtype=inexact, copy=False, subok=True)
        delta = abs(x - y)
        if not issubdtype(delta.dtype, number):
            delta = delta.astype(float_)  # handle object arrays
        return delta < 1.5 * 10.0**(-decimal)

    header = 'Arrays are not almost equal to %d decimals' % decimal
    assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
                         header=header, precision=decimal)
|
||
|
|
||
|
|
||
|
def assert_array_less(x, y, err_msg='', verbose=True):
    """
    Raises an AssertionError if two array_like objects are not ordered by less
    than.

    Given two array_like objects, check that the shape is equal and all
    elements of the first object are strictly smaller than those of the
    second object. An exception is raised at shape mismatch or incorrectly
    ordered values. Shape mismatch does not raise if an object has zero
    dimension. In contrast to the standard usage in numpy, NaNs are
    compared, no assertion is raised if both objects have NaNs in the same
    positions.

    Parameters
    ----------
    x : array_like
        The smaller object to check.
    y : array_like
        The larger object to compare.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If the shapes mismatch or if any element of `x` is not strictly
        smaller than the corresponding element of `y`.

    See Also
    --------
    assert_array_equal: tests objects for equality
    assert_array_almost_equal: test objects for equality up to precision

    Examples
    --------
    >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1.1, 2.0, np.nan])
    >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1, 2.0, np.nan])
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not less-ordered
    <BLANKLINE>
    Mismatched elements: 1 / 3 (33.3%)
    Max absolute difference: 1.
    Max relative difference: 0.5
     x: array([ 1.,  1., nan])
     y: array([ 1.,  2., nan])

    >>> np.testing.assert_array_less([1.0, 4.0], 3)
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not less-ordered
    <BLANKLINE>
    Mismatched elements: 1 / 2 (50%)
    Max absolute difference: 2.
    Max relative difference: 0.66666667
     x: array([1., 4.])
     y: array(3)

    >>> np.testing.assert_array_less([1.0, 2.0, 3.0], [4])
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not less-ordered
    <BLANKLINE>
    (shapes (3,), (1,) mismatch)
     x: array([1., 2., 3.])
     y: array([4])

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # equal_inf=False: presumably so that matching infinities are not accepted
    # as "ordered" (inf < inf is False) -- see assert_array_compare.
    assert_array_compare(operator.__lt__, x, y, err_msg=err_msg,
                         verbose=verbose,
                         header='Arrays are not less-ordered',
                         equal_inf=False)
|
||
|
|
||
|
|
||
|
def runstring(astr, dict):
    """Execute the source string `astr` with `dict` as its global namespace.

    Names defined by the executed code end up in `dict`; nothing is returned.
    """
    namespace = dict
    exec(astr, namespace)
|
||
|
|
||
|
|
||
|
def assert_string_equal(actual, desired):
    """
    Test if two strings are equal.

    If the given strings are equal, `assert_string_equal` does nothing.
    If they are not equal, an AssertionError is raised, and the diff
    between the strings is shown.

    Parameters
    ----------
    actual : str
        The string to test for equality against the expected string.
    desired : str
        The expected string.

    Raises
    ------
    AssertionError
        If either argument is not a `str` (the message is the offending
        type), or if the strings differ (the message lists the differing
        lines as produced by `difflib.Differ`).

    Examples
    --------
    >>> np.testing.assert_string_equal('abc', 'abc')
    >>> np.testing.assert_string_equal('abc', 'abcd')
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    ...
    AssertionError: Differences in strings:
    - abc+ abcd?    +

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # delay import of difflib to reduce startup time
    import difflib
    from collections import deque

    if not isinstance(actual, str):
        raise AssertionError(repr(type(actual)))
    if not isinstance(desired, str):
        raise AssertionError(repr(type(desired)))
    if desired == actual:
        return

    # A deque gives O(1) pops from the front (list.pop(0) is O(n)).
    diff = deque(difflib.Differ().compare(actual.splitlines(True),
                                          desired.splitlines(True)))
    diff_list = []
    while diff:
        d1 = diff.popleft()
        if d1.startswith(' '):
            # Line common to both strings.
            continue
        if d1.startswith('- '):
            l = [d1]
            d2 = diff.popleft() if diff else None
            if d2 is not None and d2.startswith('? '):
                l.append(d2)
                d2 = diff.popleft() if diff else None
            if d2 is None or not d2.startswith('+ '):
                # Pure deletion (e.g. trailing lines missing from `desired`):
                # report it and leave the next entry for the next iteration.
                if d2 is not None:
                    diff.appendleft(d2)
                diff_list.extend(l)
                continue
            l.append(d2)
            if diff:
                d3 = diff.popleft()
                if d3.startswith('? '):
                    l.append(d3)
                else:
                    diff.appendleft(d3)
            if d2[2:] == d1[2:]:
                continue
            diff_list.extend(l)
            continue
        # Pure insertion ('+ ') or a stray hint line: report it as-is.
        diff_list.append(d1)

    # The strings are known to differ at this point, so always fail.
    msg = 'Differences in strings:\n%s' % (''.join(diff_list)).rstrip()
    raise AssertionError(msg)
|
||
|
|
||
|
|
||
|
def rundocs(filename=None, raise_on_error=True):
    """
    Run doctests found in the given file.

    By default `rundocs` raises an AssertionError on failure.

    Parameters
    ----------
    filename : str
        The path to the file for which the doctests are run. When omitted,
        the ``__file__`` of the calling module is used.
    raise_on_error : bool
        Whether to raise an AssertionError when a doctest fails. Default is
        True. When False, the doctest runner prints its report instead.

    Notes
    -----
    The doctests can be run by the user/developer by adding the ``doctests``
    argument to the ``test()`` call. For example, to run all tests (including
    doctests) for `numpy.lib`:

    >>> np.lib.test(doctests=True) # doctest: +SKIP
    """
    from numpy.compat import npy_load_module
    import doctest

    if filename is None:
        # Fall back to the module that called us.
        filename = sys._getframe(1).f_globals['__file__']
    module_name = os.path.splitext(os.path.basename(filename))[0]
    module = npy_load_module(module_name, filename)

    found = doctest.DocTestFinder().find(module)
    runner = doctest.DocTestRunner(verbose=False)

    reports = []
    # Capture the runner's output when we intend to raise; otherwise let it print.
    writer = reports.append if raise_on_error else None
    for test in found:
        runner.run(test, out=writer)

    if raise_on_error and runner.failures > 0:
        raise AssertionError("Some doctests failed:\n%s" % "\n".join(reports))
|
||
|
|
||
|
|
||
|
def raises(*args):
    """Decorator to check for raised exceptions.

    A test decorated with it passes only if it raises one of the given
    exception types. To check several exception conditions inside a single
    test, consider `assert_raises` instead.

    .. warning::
       This decorator is nose specific, do not use it if you are using a
       different test framework.

    Parameters
    ----------
    args : exceptions
        The test passes if any of the passed exceptions is raised.

    Raises
    ------
    AssertionError

    Examples
    --------

    Usage::

        @raises(TypeError, ValueError)
        def test_raises_type_error():
            raise TypeError("This test passes")

        @raises(Exception)
        def test_that_fails_by_passing():
            pass

    """
    nose_module = import_nose()
    decorator = nose_module.tools.raises(*args)
    return decorator
|
||
|
|
||
|
#
|
||
|
# assert_raises and assert_raises_regex are taken from unittest.
|
||
|
#
|
||
|
import unittest
|
||
|
|
||
|
|
||
|
class _Dummy(unittest.TestCase):
    """Throwaway TestCase whose bound assert* methods are borrowed."""

    def nop(self):
        """Placeholder test method so the case can be instantiated by name."""


# Shared instance; assert_raises and assert_raises_regex call its
# assertRaises / assertRaisesRegex methods.
_d = _Dummy('nop')
|
||
|
|
||
|
def assert_raises(*args, **kwargs):
    """
    assert_raises(exception_class, callable, *args, **kwargs)
    assert_raises(exception_class)

    Fail unless `callable`, called with positional arguments `args` and
    keyword arguments `kwargs`, raises an exception of class
    `exception_class`. Any other exception type is not caught and
    propagates, so the test case errors out exactly as it would for an
    unexpected exception.

    Alternatively, `assert_raises` can be used as a context manager:

    >>> from numpy.testing import assert_raises
    >>> with assert_raises(ZeroDivisionError):
    ...     1 / 0

    is equivalent to

    >>> def div(x, y):
    ...     return x / y
    >>> assert_raises(ZeroDivisionError, div, 1, 0)

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    # Delegate to unittest's implementation via the shared dummy TestCase.
    check = _d.assertRaises
    return check(*args, **kwargs)
|
||
|
|
||
|
|
||
|
def assert_raises_regex(exception_class, expected_regexp, *args, **kwargs):
    """
    assert_raises_regex(exception_class, expected_regexp, callable, *args,
                        **kwargs)
    assert_raises_regex(exception_class, expected_regexp)

    Fail unless `callable`, called with positional arguments `args` and
    keyword arguments `kwargs`, raises an exception of class
    `exception_class` whose message matches `expected_regexp`.

    Alternatively, can be used as a context manager like `assert_raises`.

    The name follows the Python 3.2+ ``unittest`` naming
    (``assertRaisesRegex``), to which this function delegates.

    Notes
    -----
    .. versionadded:: 1.9.0

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    check = _d.assertRaisesRegex
    return check(exception_class, expected_regexp, *args, **kwargs)
|
||
|
|
||
|
|
||
|
def decorate_methods(cls, decorator, testmatch=None):
|
||
|
"""
|
||
|
Apply a decorator to all methods in a class matching a regular expression.
|
||
|
|
||
|
The given decorator is applied to all public methods of `cls` that are
|
||
|
matched by the regular expression `testmatch`
|
||
|
(``testmatch.search(methodname)``). Methods that are private, i.e. start
|
||
|
with an underscore, are ignored.
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
cls : class
|
||
|
Class whose methods to decorate.
|
||
|
decorator : function
|
||
|
Decorator to apply to methods
|
||
|
testmatch : compiled regexp or str, optional
|
||
|
The regular expression. Default value is None, in which case the
|
||
|
nose default (``re.compile(r'(?:^|[\\b_\\.%s-])[Tt]est' % os.sep)``)
|
||
|
is used.
|
||
|
If `testmatch` is a string, it is compiled to a regular expression
|
||
|
first.
|
||
|
|
||
|
"""
|
||
|
if testmatch is None:
|
||
|
testmatch = re.compile(r'(?:^|[\\b_\\.%s-])[Tt]est' % os.sep)
|
||
|
else:
|
||
|
testmatch = re.compile(testmatch)
|
||
|
cls_attr = cls.__dict__
|
||
|
|
||
|
# delayed import to reduce startup time
|
||
|
from inspect import isfunction
|
||
|
|
||
|
methods = [_m for _m in cls_attr.values() if isfunction(_m)]
|
||
|
for function in methods:
|
||
|
try:
|
||
|
if hasattr(function, 'compat_func_name'):
|
||
|
funcname = function.compat_func_name
|
||
|
else:
|
||
|
funcname = function.__name__
|
||
|
except AttributeError:
|
||
|
# not a function
|
||
|
continue
|
||
|
if testmatch.search(funcname) and not funcname.startswith('_'):
|
||
|
setattr(cls, funcname, decorator(function))
|
||
|
return
|
||
|
|
||
|
|
||
|
def measure(code_str, times=1, label=None):
    """
    Return elapsed time for executing code in the namespace of the caller.

    `code_str` is compiled once with the builtin ``compile`` and then
    executed `times` times in the caller's globals and locals. Timing has
    10 millisecond precision, so fast snippets should be repeated many
    times to get a meaningful measurement.

    Parameters
    ----------
    code_str : str
        The code to be timed.
    times : int, optional
        The number of times the code is executed. Default is 1. The code is
        only compiled once.
    label : str, optional
        A label to identify `code_str` with. This is passed into ``compile``
        as the second argument (for run-time error messages).

    Returns
    -------
    elapsed : float
        Total elapsed time in seconds for executing `code_str` `times` times.

    Examples
    --------
    >>> times = 10
    >>> etime = np.testing.measure('for i in range(1000): np.sqrt(i**2)', times=times)
    >>> print("Time for a single execution : ", etime / times, "s")  # doctest: +SKIP
    Time for a single execution :  0.005 s

    """
    # Must be called directly here (not from a helper) so that frame 1 is
    # the caller whose namespace the code runs in.
    caller = sys._getframe(1)
    caller_locals = caller.f_locals
    caller_globals = caller.f_globals

    compiled = compile(code_str, 'Test name: %s ' % label, 'exec')
    runs = 0
    start = jiffies()
    while runs < times:
        runs += 1
        exec(compiled, caller_globals, caller_locals)
    # The 0.01 factor converts jiffies to seconds, matching the 10 ms
    # precision stated above (jiffies is defined elsewhere in this module).
    return 0.01 * (jiffies() - start)
|
||
|
|
||
|
|
||
|
def _assert_valid_refcount(op):
    """
    Check that ufuncs don't mishandle refcount of object `1`.

    `op` is applied 15 times to a 100x100 integer array (passed as both
    operands) with the garbage collector disabled. Afterwards, the refcount of
    the int ``1`` must not have dropped below its starting value. Used in a
    few regression tests. Returns True immediately when the interpreter has
    no refcounts.
    """
    if not HAS_REFCOUNT:
        return True

    import gc
    import numpy as np

    arr = np.arange(100*100).reshape(100, 100)
    same = arr
    one = 1

    gc.disable()
    try:
        baseline = sys.getrefcount(one)
        for _ in range(15):
            result = op(arr, same)
        assert_(sys.getrefcount(one) >= baseline)
    finally:
        gc.enable()
    del result  # for pyflakes
|
||
|
|
||
|
|
||
|
def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
                    err_msg='', verbose=True):
    """
    Raises an AssertionError if two objects are not equal up to desired
    tolerance.

    The check is equivalent to ``allclose(actual, desired, rtol, atol)``,
    although ``allclose`` uses different defaults. The difference between
    `actual` and `desired` is compared against
    ``atol + rtol * abs(desired)``.

    .. versionadded:: 1.5.0

    Parameters
    ----------
    actual : array_like
        Array obtained.
    desired : array_like
        Array desired.
    rtol : float, optional
        Relative tolerance.
    atol : float, optional
        Absolute tolerance.
    equal_nan : bool, optional.
        If True, NaNs will compare equal.
    err_msg : str, optional
        The error message to be printed in case of failure.
    verbose : bool, optional
        If True, the conflicting values are appended to the error message.

    Raises
    ------
    AssertionError
        If actual and desired are not equal up to specified precision.

    See Also
    --------
    assert_array_almost_equal_nulp, assert_array_max_ulp

    Examples
    --------
    >>> x = [1e-5, 1e-3, 1e-1]
    >>> y = np.arccos(np.cos(x))
    >>> np.testing.assert_allclose(x, y, rtol=1e-5, atol=0)

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np

    actual = np.asanyarray(actual)
    desired = np.asanyarray(desired)

    def compare(x, y):
        return np.core.numeric.isclose(x, y, rtol=rtol, atol=atol,
                                       equal_nan=equal_nan)

    header = 'Not equal to tolerance rtol=%g, atol=%g' % (rtol, atol)
    assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
                         verbose=verbose, header=header, equal_nan=equal_nan)
|
||
|
|
||
|
|
||
|
def assert_array_almost_equal_nulp(x, y, nulp=1):
    """
    Compare two arrays relatively to their spacing.

    This is a relatively robust method to compare two arrays whose amplitude
    is variable.

    Parameters
    ----------
    x, y : array_like
        Input arrays.
    nulp : int, optional
        The maximum number of unit in the last place for tolerance (see Notes).
        Default is 1.

    Returns
    -------
    None

    Raises
    ------
    AssertionError
        If the spacing between `x` and `y` for one or more elements is larger
        than `nulp`.

    See Also
    --------
    assert_array_max_ulp : Check that all items of arrays differ in at most
        N Units in the Last Place.
    spacing : Return the distance between x and the nearest adjacent number.

    Notes
    -----
    An assertion is raised if the following condition is not met::

        abs(x - y) <= nulp * spacing(maximum(abs(x), abs(y)))

    Examples
    --------
    >>> x = np.array([1., 1e-10, 1e-20])
    >>> eps = np.finfo(x.dtype).eps
    >>> np.testing.assert_array_almost_equal_nulp(x, x*eps/2 + x)

    >>> np.testing.assert_array_almost_equal_nulp(x, x*eps + x)
    Traceback (most recent call last):
      ...
    AssertionError: X and Y are not equal to 1 ULP (max is 2)

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np

    abs_x = np.abs(x)
    abs_y = np.abs(y)
    # Tolerance scales with the spacing of the larger magnitude of each pair.
    tolerance = nulp * np.spacing(np.where(abs_x > abs_y, abs_x, abs_y))
    if np.all(np.abs(x - y) <= tolerance):
        return

    if np.iscomplexobj(x) or np.iscomplexobj(y):
        # nulp_diff does not support complex input, so no max is reported.
        msg = "X and Y are not equal to %d ULP" % nulp
    else:
        max_nulp = np.max(nulp_diff(x, y))
        msg = "X and Y are not equal to %d ULP (max is %g)" % (nulp, max_nulp)
    raise AssertionError(msg)
|
||
|
|
||
|
|
||
|
def assert_array_max_ulp(a, b, maxulp=1, dtype=None):
    """
    Check that all items of arrays differ in at most N Units in the Last Place.

    Parameters
    ----------
    a, b : array_like
        Input arrays to be compared.
    maxulp : int, optional
        The maximum number of units in the last place that elements of `a` and
        `b` can differ. Default is 1.
    dtype : dtype, optional
        Data-type to convert `a` and `b` to if given. Default is None.

    Returns
    -------
    ret : ndarray
        Array containing number of representable floating point numbers between
        items in `a` and `b`.

    Raises
    ------
    AssertionError
        If one or more elements differ by more than `maxulp`.

    Notes
    -----
    For computing the ULP difference, this API does not differentiate between
    various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000
    is zero).

    See Also
    --------
    assert_array_almost_equal_nulp : Compare two arrays relatively to their
        spacing.

    Examples
    --------
    >>> a = np.linspace(0., 1., 100)
    >>> res = np.testing.assert_array_max_ulp(a, np.arcsin(np.sin(a)))

    """
    __tracebackhide__ = True  # Hide traceback for py.test
    import numpy as np

    ulp_diffs = nulp_diff(a, b, dtype)
    if np.all(ulp_diffs <= maxulp):
        return ulp_diffs
    raise AssertionError("Arrays are not almost equal up to %g "
                         "ULP (max difference is %g ULP)" %
                         (maxulp, np.max(ulp_diffs)))
|
||
|
|
||
|
|
||
|
def nulp_diff(x, y, dtype=None):
    """For each item in x and y, return the number of representable floating
    points between them.

    Parameters
    ----------
    x : array_like
        first input array
    y : array_like
        second input array
    dtype : dtype, optional
        Data-type to convert `x` and `y` to if given. Default is None.

    Returns
    -------
    nulp : array_like
        number of representable floating point numbers between each item in x
        and y. The result has an extra leading axis of length 1.

    Raises
    ------
    NotImplementedError
        If either input is complex.
    ValueError
        If `x` and `y` do not have the same shape.

    Notes
    -----
    For computing the ULP difference, this API does not differentiate between
    various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000
    is zero).

    Examples
    --------
    By definition, epsilon is the smallest number such as 1 + eps != 1, so
    there should be exactly one ULP between 1 and 1 + eps:

    >>> nulp_diff(1, 1 + np.finfo(np.float64).eps)
    array([1.])
    """
    import numpy as np

    # Compare against None explicitly: a numpy dtype instance such as
    # np.dtype('float32') has len() == 0 (no fields) and is therefore falsy,
    # so `if dtype:` would silently ignore it.
    if dtype is not None:
        x = np.array(x, dtype=dtype)
        y = np.array(y, dtype=dtype)
    else:
        x = np.array(x)
        y = np.array(y)

    t = np.common_type(x, y)
    if np.iscomplexobj(x) or np.iscomplexobj(y):
        raise NotImplementedError("_nulp not implemented for complex array")

    # Fresh copies in the common float type (with a leading length-1 axis).
    x = np.array([x], dtype=t)
    y = np.array([y], dtype=t)

    # Replace every NaN with the canonical np.nan so NaN payloads/signs do not
    # produce spurious ULP differences.
    x[np.isnan(x)] = np.nan
    y[np.isnan(y)] = np.nan

    if not x.shape == y.shape:
        raise ValueError("x and y do not have the same shape: %s - %s" %
                         (x.shape, y.shape))

    rx = integer_repr(x)
    ry = integer_repr(y)
    return np.abs(np.array(rx - ry, dtype=t))
|
||
|
|
||
|
|
||
|
def _integer_repr(x, vdt, comp):
|
||
|
# Reinterpret binary representation of the float as sign-magnitude:
|
||
|
# take into account two-complement representation
|
||
|
# See also
|
||
|
# https://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/
|
||
|
rx = x.view(vdt)
|
||
|
if not (rx.size == 1):
|
||
|
rx[rx < 0] = comp - rx[rx < 0]
|
||
|
else:
|
||
|
if rx < 0:
|
||
|
rx = comp - rx
|
||
|
|
||
|
return rx
|
||
|
|
||
|
|
||
|
def integer_repr(x):
    """Return the signed-magnitude interpretation of the binary representation
    of x.

    Only float16, float32 and float64 arrays are supported; any other dtype
    raises ValueError.
    """
    import numpy as np

    # (float type, same-width signed int type, most negative value of that int)
    candidates = (
        (np.float16, np.int16, np.int16(-2**15)),
        (np.float32, np.int32, np.int32(-2**31)),
        (np.float64, np.int64, np.int64(-2**63)),
    )
    for float_type, int_type, comp in candidates:
        if x.dtype == float_type:
            return _integer_repr(x, int_type, comp)
    raise ValueError("Unsupported dtype %s" % x.dtype)
|
||
|
|
||
|
|
||
|
@contextlib.contextmanager
def _assert_warns_context(warning_class, name=None):
    """Context manager failing unless `warning_class` is emitted in its body.

    `name`, if given, is mentioned in the failure message as the callable
    that was expected to warn.
    """
    __tracebackhide__ = True  # Hide traceback for py.test
    with suppress_warnings() as sup:
        recorded = sup.record(warning_class)
        yield
        if not recorded:
            suffix = "" if name is None else " when calling %s" % name
            raise AssertionError("No warning raised" + suffix)
|
||
|
|
||
|
|
||
|
def assert_warns(warning_class, *args, **kwargs):
    """
    Fail unless the given callable throws the specified warning.

    A warning of class warning_class should be thrown by the callable when
    invoked with arguments args and keyword arguments kwargs.
    If a different type of warning is thrown, it will not be caught.

    If called with all arguments other than the warning class omitted, may be
    used as a context manager:

        with assert_warns(SomeWarning):
            do_something()

    The ability to be used as a context manager is new in NumPy v1.11.0.

    .. versionadded:: 1.4.0

    Parameters
    ----------
    warning_class : class
        The class defining the warning that `func` is expected to throw.
    func : callable
        The callable to test.
    \\*args : Arguments
        Arguments passed to `func`.
    \\*\\*kwargs : Kwargs
        Keyword arguments passed to `func`.

    Returns
    -------
    The value returned by `func`.

    """
    if args:
        func, *call_args = args
        with _assert_warns_context(warning_class, name=func.__name__):
            return func(*call_args, **kwargs)
    # No callable given: hand back the context manager itself.
    return _assert_warns_context(warning_class)
|
||
|
|
||
|
|
||
|
@contextlib.contextmanager
|
||
|
def _assert_no_warnings_context(name=None):
|
||
|
__tracebackhide__ = True # Hide traceback for py.test
|
||
|
with warnings.catch_warnings(record=True) as l:
|
||
|
warnings.simplefilter('always')
|
||
|
yield
|
||
|
if len(l) > 0:
|
||
|
name_str = " when calling %s" % name if name is not None else ""
|
||
|
raise AssertionError("Got warnings%s: %s" % (name_str, l))
|
||
|
|
||
|
|
||
|
def assert_no_warnings(*args, **kwargs):
    """
    Fail if the given callable produces any warnings.

    If called with all arguments omitted, may be used as a context manager:

        with assert_no_warnings():
            do_something()

    The ability to be used as a context manager is new in NumPy v1.11.0.

    .. versionadded:: 1.7.0

    Parameters
    ----------
    func : callable
        The callable to test.
    \\*args : Arguments
        Arguments passed to `func`.
    \\*\\*kwargs : Kwargs
        Keyword arguments passed to `func`.

    Returns
    -------
    The value returned by `func`.

    """
    if args:
        func, *call_args = args
        with _assert_no_warnings_context(name=func.__name__):
            return func(*call_args, **kwargs)
    # No callable given: hand back the context manager itself.
    return _assert_no_warnings_context()
|
||
|
|
||
|
|
||
|
def _gen_alignment_data(dtype=float32, type='binary', max_size=24):
|
||
|
"""
|
||
|
generator producing data with different alignment and offsets
|
||
|
to test simd vectorization
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
dtype : dtype
|
||
|
data type to produce
|
||
|
type : string
|
||
|
'unary': create data for unary operations, creates one input
|
||
|
and output array
|
||
|
'binary': create data for unary operations, creates two input
|
||
|
and output array
|
||
|
max_size : integer
|
||
|
maximum size of data to produce
|
||
|
|
||
|
Returns
|
||
|
-------
|
||
|
if type is 'unary' yields one output, one input array and a message
|
||
|
containing information on the data
|
||
|
if type is 'binary' yields one output array, two input array and a message
|
||
|
containing information on the data
|
||
|
|
||
|
"""
|
||
|
ufmt = 'unary offset=(%d, %d), size=%d, dtype=%r, %s'
|
||
|
bfmt = 'binary offset=(%d, %d, %d), size=%d, dtype=%r, %s'
|
||
|
for o in range(3):
|
||
|
for s in range(o + 2, max(o + 3, max_size)):
|
||
|
if type == 'unary':
|
||
|
inp = lambda: arange(s, dtype=dtype)[o:]
|
||
|
out = empty((s,), dtype=dtype)[o:]
|
||
|
yield out, inp(), ufmt % (o, o, s, dtype, 'out of place')
|
||
|
d = inp()
|
||
|
yield d, d, ufmt % (o, o, s, dtype, 'in place')
|
||
|
yield out[1:], inp()[:-1], ufmt % \
|
||
|
(o + 1, o, s - 1, dtype, 'out of place')
|
||
|
yield out[:-1], inp()[1:], ufmt % \
|
||
|
(o, o + 1, s - 1, dtype, 'out of place')
|
||
|
yield inp()[:-1], inp()[1:], ufmt % \
|
||
|
(o, o + 1, s - 1, dtype, 'aliased')
|
||
|
yield inp()[1:], inp()[:-1], ufmt % \
|
||
|
(o + 1, o, s - 1, dtype, 'aliased')
|
||
|
if type == 'binary':
|
||
|
inp1 = lambda: arange(s, dtype=dtype)[o:]
|
||
|
inp2 = lambda: arange(s, dtype=dtype)[o:]
|
||
|
out = empty((s,), dtype=dtype)[o:]
|
||
|
yield out, inp1(), inp2(), bfmt % \
|
||
|
(o, o, o, s, dtype, 'out of place')
|
||
|
d = inp1()
|
||
|
yield d, d, inp2(), bfmt % \
|
||
|
(o, o, o, s, dtype, 'in place1')
|
||
|
d = inp2()
|
||
|
yield d, inp1(), d, bfmt % \
|
||
|
(o, o, o, s, dtype, 'in place2')
|
||
|
yield out[1:], inp1()[:-1], inp2()[:-1], bfmt % \
|
||
|
(o + 1, o, o, s - 1, dtype, 'out of place')
|
||
|
yield out[:-1], inp1()[1:], inp2()[:-1], bfmt % \
|
||
|
(o, o + 1, o, s - 1, dtype, 'out of place')
|
||
|
yield out[:-1], inp1()[:-1], inp2()[1:], bfmt % \
|
||
|
(o, o, o + 1, s - 1, dtype, 'out of place')
|
||
|
yield inp1()[1:], inp1()[:-1], inp2()[:-1], bfmt % \
|
||
|
(o + 1, o, o, s - 1, dtype, 'aliased')
|
||
|
yield inp1()[:-1], inp1()[1:], inp2()[:-1], bfmt % \
|
||
|
(o, o + 1, o, s - 1, dtype, 'aliased')
|
||
|
yield inp1()[:-1], inp1()[:-1], inp2()[1:], bfmt % \
|
||
|
(o, o, o + 1, s - 1, dtype, 'aliased')
|
||
|
|
||
|
|
||
|
class IgnoreException(Exception):
    """Exception that is deliberately ignored because a feature is disabled."""
|
||
|
|
||
|
|
||
|
@contextlib.contextmanager
def tempdir(*args, **kwargs):
    """Yield the path of a freshly created temporary directory.

    Every positional and keyword argument is forwarded unchanged to
    :func:`tempfile.mkdtemp`. When the ``with`` block ends, the directory
    and all of its contents are deleted with :func:`shutil.rmtree`. This
    happens whether the block exits normally or through an exception.
    """
    directory = mkdtemp(*args, **kwargs)
    try:
        yield directory
    finally:
        # Always clean up, even when the body of the ``with`` block raised.
        shutil.rmtree(directory)
|
||
|
|
||
|
|
||
|
@contextlib.contextmanager
def temppath(*args, **kwargs):
    """Yield the path of a closed temporary file and remove it afterwards.

    All arguments are forwarded to :func:`tempfile.mkstemp`. The OS-level
    descriptor that ``mkstemp`` returns is closed immediately, so the caller
    only receives the path. When the ``with`` block exits, the file is
    deleted with :func:`os.remove`, so any handle the caller opened on it
    should be closed by then.

    The descriptor is closed up front because Windows does not allow a
    temporary file to be opened again while it is still open.
    """
    handle, filename = mkstemp(*args, **kwargs)
    os.close(handle)
    try:
        yield filename
    finally:
        os.remove(filename)
|
||
|
|
||
|
|
||
|
class clear_and_catch_warnings(warnings.catch_warnings):
    """ Context manager that resets warning registry for catching warnings

    Warnings can be slippery. Whenever a warning is triggered, Python adds a
    ``__warningregistry__`` member to the *calling* module. After that, the
    warning cannot be triggered again from that module, whatever you put in
    the warnings filters. This context manager accepts a sequence of
    `modules` as a keyword argument to its constructor and:

    * stores and removes any ``__warningregistry__`` entries in given `modules`
      on entry;
    * resets ``__warningregistry__`` to its previous state on exit.

    This makes it possible to trigger any warning afresh inside the context
    manager without disturbing the state of warnings outside.

    For compatibility with Python 3.0, please consider all arguments to be
    keyword-only.

    Parameters
    ----------
    record : bool, optional
        Specifies whether warnings should be captured by a custom
        implementation of ``warnings.showwarning()`` and be appended to a list
        returned by the context manager. Otherwise None is returned by the
        context manager. The objects appended to the list are arguments whose
        attributes mirror the arguments to ``showwarning()``.
    modules : sequence, optional
        Sequence of modules for which to reset warnings registry on entry and
        restore on exit. To work correctly, all 'ignore' filters should
        filter by one of these modules.

    Examples
    --------
    >>> import warnings
    >>> with np.testing.clear_and_catch_warnings(
    ...         modules=[np.core.fromnumeric]):
    ...     warnings.simplefilter('always')
    ...     # ``module`` is a regex matched against the full module name,
    ...     # so it must spell out ``numpy``, not the ``np`` alias.
    ...     warnings.filterwarnings('ignore', module='numpy.core.fromnumeric')
    ...     # do something that raises a warning but ignore those in
    ...     # np.core.fromnumeric
    """
    # Subclasses may list modules here that are always reset, in addition to
    # those passed to the constructor.
    class_modules = ()

    def __init__(self, record=False, modules=()):
        self.modules = set(modules).union(self.class_modules)
        # module -> copy of its __warningregistry__ taken on entry
        self._warnreg_copies = {}
        super().__init__(record=record)

    def __enter__(self):
        # Save and empty each registry so warnings can fire again inside.
        for mod in self.modules:
            if hasattr(mod, '__warningregistry__'):
                mod_reg = mod.__warningregistry__
                self._warnreg_copies[mod] = mod_reg.copy()
                mod_reg.clear()
        return super().__enter__()

    def __exit__(self, *exc_info):
        super().__exit__(*exc_info)
        for mod in self.modules:
            registry = getattr(mod, '__warningregistry__', None)
            if registry is not None:
                # Drop anything recorded while inside the context.
                registry.clear()
            saved = self._warnreg_copies.get(mod)
            if saved is not None:
                # The registry may have been deleted inside the context;
                # recreate it rather than failing with AttributeError.
                if registry is None:
                    registry = mod.__warningregistry__ = {}
                registry.update(saved)
|
||
|
|
||
|
|
||
|
class suppress_warnings:
    """
    Context manager and decorator doing much the same as
    ``warnings.catch_warnings``.

    However, it also provides a filter mechanism to work around
    https://bugs.python.org/issue4180.

    This bug causes Python before 3.4 to not reliably show warnings again
    after they have been ignored once (even within catch_warnings). It
    means that no "ignore" filter can be used easily, since following
    tests might need to see the warning. Additionally it allows easier
    specificity for testing warnings and can be nested.

    Parameters
    ----------
    forwarding_rule : str, optional
        One of "always", "once", "module", or "location". Analogous to
        the usual warnings module filter mode, it is useful to reduce
        noise mostly on the outmost level. Unsuppressed and unrecorded
        warnings will be forwarded based on this rule. Defaults to "always".
        "location" is equivalent to the warnings "default", match by exact
        location the warning originated from.

    Notes
    -----
    Filters added inside the context manager will be discarded again
    when leaving it. Upon entering all filters defined outside a
    context will be applied automatically.

    When a recording filter is added, matching warnings are stored in the
    ``log`` attribute as well as in the list returned by ``record``.

    If filters are added and the ``module`` keyword is given, the
    warning registry of this module will additionally be cleared when
    applying it, entering the context, or exiting it. This could cause
    warnings to appear a second time after leaving the context if they
    were configured to be printed once (default) and were already
    printed before the context was entered.

    Nesting this context manager will work as expected when the
    forwarding rule is "always" (default). Unfiltered and unrecorded
    warnings will be passed out and be matched by the outer level.
    On the outmost level they will be printed (or caught by another
    warnings context). The forwarding rule argument can modify this
    behaviour.

    Like ``catch_warnings`` this context manager is not threadsafe.

    Examples
    --------

    With a context manager::

        with np.testing.suppress_warnings() as sup:
            sup.filter(DeprecationWarning, "Some text")
            sup.filter(module=np.ma.core)
            log = sup.record(FutureWarning, "Does this occur?")
            command_giving_warnings()
            # The FutureWarning was given once, the filtered warnings were
            # ignored. All other warnings abide outside settings (may be
            # printed/error)
            assert_(len(log) == 1)
            assert_(len(sup.log) == 1)  # also stored in log attribute

    Or as a decorator::

        sup = np.testing.suppress_warnings()
        sup.filter(module=np.ma.core)  # module must match exactly
        @sup
        def some_function():
            # do something which causes a warning in np.ma.core
            pass
    """
    def __init__(self, forwarding_rule="always"):
        self._entered = False

        # Suppressions are either instance or defined inside one with block:
        # each entry is (category, message, compiled_regex, module, record)
        # where ``record`` is a list to append to, or None to just ignore.
        self._suppressions = []

        if forwarding_rule not in {"always", "module", "once", "location"}:
            raise ValueError("unsupported forwarding rule.")
        self._forwarding_rule = forwarding_rule

    def _clear_registries(self):
        if hasattr(warnings, "_filters_mutated"):
            # clearing the registry should not be necessary on new pythons,
            # instead the filters should be mutated.
            warnings._filters_mutated()
            return
        # Simply clear the registry, this should normally be harmless,
        # note that on new pythons it would be invalidated anyway.
        for module in self._tmp_modules:
            if hasattr(module, "__warningregistry__"):
                module.__warningregistry__.clear()

    def _filter(self, category=Warning, message="", module=None, record=False):
        if record:
            record = []  # The log where to store warnings
        else:
            record = None
        if self._entered:
            # Inside the context: install the filter right away; it is
            # discarded again on exit.
            if module is None:
                warnings.filterwarnings(
                    "always", category=category, message=message)
            else:
                module_regex = module.__name__.replace('.', r'\.') + '$'
                warnings.filterwarnings(
                    "always", category=category, message=message,
                    module=module_regex)
                self._tmp_modules.add(module)
                self._clear_registries()

            self._tmp_suppressions.append(
                (category, message, re.compile(message, re.I), module, record))
        else:
            # Outside the context: remember it, __enter__ installs it.
            self._suppressions.append(
                (category, message, re.compile(message, re.I), module, record))

        return record

    def filter(self, category=Warning, message="", module=None):
        """
        Add a new suppressing filter or apply it if the state is entered.

        Parameters
        ----------
        category : class, optional
            Warning class to filter
        message : string, optional
            Regular expression matching the warning message.
        module : module, optional
            Module to filter for. Note that the module (and its file)
            must match exactly and cannot be a submodule. This may make
            it unreliable for external modules.

        Notes
        -----
        When added within a context, filters are only added inside
        the context and will be forgotten when the context is exited.
        """
        self._filter(category=category, message=message, module=module,
                     record=False)

    def record(self, category=Warning, message="", module=None):
        """
        Append a new recording filter or apply it if the state is entered.

        All warnings matching will be appended to the ``log`` attribute.

        Parameters
        ----------
        category : class, optional
            Warning class to filter
        message : string, optional
            Regular expression matching the warning message.
        module : module, optional
            Module to filter for. Note that the module (and its file)
            must match exactly and cannot be a submodule. This may make
            it unreliable for external modules.

        Returns
        -------
        log : list
            A list which will be filled with all matched warnings.

        Notes
        -----
        When added within a context, filters are only added inside
        the context and will be forgotten when the context is exited.
        """
        return self._filter(category=category, message=message, module=module,
                            record=True)

    def __enter__(self):
        if self._entered:
            raise RuntimeError("cannot enter suppress_warnings twice.")

        self._orig_show = warnings.showwarning
        self._filters = warnings.filters
        warnings.filters = self._filters[:]

        self._entered = True
        self._tmp_suppressions = []
        self._tmp_modules = set()
        self._forwarded = set()

        self.log = []  # reset global log (no need to keep same list)

        for cat, mess, _, mod, log in self._suppressions:
            if log is not None:
                del log[:]  # clear the log
            if mod is None:
                warnings.filterwarnings(
                    "always", category=cat, message=mess)
            else:
                module_regex = mod.__name__.replace('.', r'\.') + '$'
                warnings.filterwarnings(
                    "always", category=cat, message=mess,
                    module=module_regex)
                self._tmp_modules.add(mod)
        warnings.showwarning = self._showwarning
        self._clear_registries()

        return self

    def __exit__(self, *exc_info):
        warnings.showwarning = self._orig_show
        warnings.filters = self._filters
        self._clear_registries()
        self._entered = False
        del self._orig_show
        del self._filters

    def _forward(self, message, category, filename, lineno, args, kwargs,
                 use_warnmsg):
        """Pass an unsuppressed warning on to the handler saved on entry."""
        if use_warnmsg is None:
            self._orig_show(message, category, filename, lineno,
                            *args, **kwargs)
        else:
            # Only ``warnings.showwarning`` is saved in ``__enter__``, so
            # unpack the WarningMessage into that function's signature.
            self._orig_show(use_warnmsg.message, use_warnmsg.category,
                            use_warnmsg.filename, use_warnmsg.lineno,
                            use_warnmsg.file, use_warnmsg.line)

    def _showwarning(self, message, category, filename, lineno,
                     *args, use_warnmsg=None, **kwargs):
        # Newest suppressions take precedence, hence the reversed order.
        for cat, _, pattern, mod, rec in (
                self._suppressions + self._tmp_suppressions)[::-1]:
            if (issubclass(category, cat) and
                    pattern.match(message.args[0]) is not None):
                # Use startswith, because warnings strips the c or o from
                # .pyc/.pyo files.
                if mod is not None and not mod.__file__.startswith(filename):
                    continue
                # Message, category and (if given) module match: the
                # warning is either recorded or ignored.
                if rec is not None:
                    msg = WarningMessage(message, category, filename,
                                         lineno, **kwargs)
                    self.log.append(msg)
                    rec.append(msg)
                return

        # There is no filter in place, so pass to the outside handler
        # unless we should only pass it once
        if self._forwarding_rule == "always":
            self._forward(message, category, filename, lineno,
                          args, kwargs, use_warnmsg)
            return

        if self._forwarding_rule == "once":
            signature = (message.args, category)
        elif self._forwarding_rule == "module":
            signature = (message.args, category, filename)
        elif self._forwarding_rule == "location":
            signature = (message.args, category, filename, lineno)

        if signature in self._forwarded:
            return
        self._forwarded.add(signature)
        self._forward(message, category, filename, lineno,
                      args, kwargs, use_warnmsg)

    def __call__(self, func):
        """
        Function decorator to apply certain suppressions to a whole
        function.
        """
        @wraps(func)
        def new_func(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return new_func
|
||
|
|
||
|
|
||
|
@contextlib.contextmanager
def _assert_no_gc_cycles_context(name=None):
    """Fail with AssertionError if the ``with`` body leaves reference cycles.

    Before the body runs, the garbage collector is disabled and the existing
    garbage is collected, up to 100 passes. ``gc.DEBUG_SAVEALL`` is then set
    so that whatever the final ``gc.collect()`` finds is kept in
    ``gc.garbage`` and can be shown in the failure message. The previous
    debug flags are restored and collection is re-enabled afterwards. When
    ``HAS_REFCOUNT`` is false, the body simply runs unchecked.

    Parameters
    ----------
    name : str, optional
        Included in the failure message as " when calling <name>".
    """
    __tracebackhide__ = True  # Hide traceback for py.test

    # Cycle detection is only meaningful on refcounting interpreters.
    if not HAS_REFCOUNT:
        yield
        return

    assert_(gc.isenabled())
    gc.disable()
    saved_debug_flags = gc.get_debug()
    try:
        # Keep collecting until a pass finds nothing; give up after 100.
        if not any(gc.collect() == 0 for _ in range(100)):
            raise RuntimeError(
                "Unable to fully collect garbage - perhaps a __del__ method is "
                "creating more reference cycles?")

        gc.set_debug(gc.DEBUG_SAVEALL)
        yield
        # The return value counts unreachable objects found, i.e. objects
        # caught in cycles created inside the context.
        collected = gc.collect()
        leaked = gc.garbage[:]
    finally:
        del gc.garbage[:]
        gc.set_debug(saved_debug_flags)
        gc.enable()

    if not collected:
        return

    suffix = "" if name is None else " when calling %s" % name
    details = ''.join(
        "\n {} object with id={}:\n {}".format(
            type(obj).__name__,
            id(obj),
            pprint.pformat(obj).replace('\n', '\n '),
        )
        for obj in leaked
    )
    raise AssertionError(
        "Reference cycles were found{}: {} objects were collected, "
        "of which {} are shown below:{}"
        .format(suffix, collected, len(leaked), details)
    )
|
||
|
|
||
|
|
||
|
def assert_no_gc_cycles(*args, **kwargs):
    """
    Fail if the given callable produces any reference cycles.

    If called with all arguments omitted, may be used as a context manager:

        with assert_no_gc_cycles():
            do_something()

    .. versionadded:: 1.15.0

    Parameters
    ----------
    func : callable
        The callable to test.
    \\*args : Arguments
        Arguments passed to `func`.
    \\*\\*kwargs : Kwargs
        Keyword arguments passed to `func`.

    Returns
    -------
    Nothing. The result is deliberately discarded to ensure that all cycles
    are found.

    """
    if not args:
        return _assert_no_gc_cycles_context()

    func, *func_args = args
    # Not every callable has ``__name__`` (e.g. functools.partial objects);
    # fall back to its repr so the failure message still identifies it.
    name = getattr(func, "__name__", repr(func))
    with _assert_no_gc_cycles_context(name=name):
        func(*func_args, **kwargs)
|
||
|
|
||
|
def break_cycles():
    """
    Break reference cycles by calling gc.collect

    Objects can call other objects' methods (for instance, another object's
    __del__) inside their own __del__. On PyPy, the interpreter only runs
    between calls to gc.collect, so multiple calls are needed to completely
    release all cycles.
    """
    # One pass normally suffices. On PyPy, a second pass lets the interpreter
    # run the deleted objects' __del__ methods, and a third makes sure
    # anything they released is gone too.
    passes = 3 if IS_PYPY else 1
    for _ in range(passes):
        gc.collect()
|
||
|
|
||
|
|
||
|
def requires_memory(free_bytes):
    """Decorator to skip a test if not enough memory is available.

    On each call, the wrapped test is skipped (``pytest.skip``) with the
    message from ``check_free_memory(free_bytes)`` if that message is not
    None. If the test itself raises MemoryError, it is marked as an expected
    failure (``pytest.xfail``) rather than a failure.
    """
    import pytest

    def decorator(func):
        @wraps(func)
        def wrapper(*a, **kw):
            shortage = check_free_memory(free_bytes)
            if shortage is not None:
                pytest.skip(shortage)

            try:
                return func(*a, **kw)
            except MemoryError:
                # Probably ran out of memory regardless: don't regard as failure
                pytest.xfail("MemoryError raised")

        return wrapper

    return decorator
|
||
|
|
||
|
|
||
|
def check_free_memory(free_bytes):
    """
    Check whether `free_bytes` amount of memory is currently free.

    The free amount comes from the ``NPY_AVAILABLE_MEM`` environment variable
    when it is set (parsed with ``_parse_size``), otherwise from
    ``_get_mem_available()``.

    Returns
    -------
    str or None
        None if enough memory is available, otherwise an explanatory message.

    Raises
    ------
    ValueError
        If ``NPY_AVAILABLE_MEM`` is set to something ``_parse_size`` rejects.
    """
    env_var = 'NPY_AVAILABLE_MEM'
    env_value = os.environ.get(env_var)

    if env_value is None:
        available = _get_mem_available()
        if available is None:
            reason = ("Could not determine available memory; set NPY_AVAILABLE_MEM "
                      "environment variable (e.g. NPY_AVAILABLE_MEM=16GB) to run "
                      "the test.")
            # -1 makes any non-negative requirement count as unmet.
            available = -1
        else:
            reason = '{0} GB memory required, but {1} GB available'.format(
                free_bytes/1e9, available/1e9)
    else:
        try:
            available = _parse_size(env_value)
        except ValueError as exc:
            raise ValueError('Invalid environment variable {}: {!s}'.format(
                env_var, exc))
        reason = ('{0} GB memory required, but environment variable '
                  'NPY_AVAILABLE_MEM={1} set'.format(
                      free_bytes/1e9, env_value))

    if available < free_bytes:
        return reason
    return None
|
||
|
|
||
|
|
||
|
def _parse_size(size_str):
|
||
|
"""Convert memory size strings ('12 GB' etc.) to float"""
|
||
|
suffixes = {'': 1, 'b': 1,
|
||
|
'k': 1000, 'm': 1000**2, 'g': 1000**3, 't': 1000**4,
|
||
|
'kb': 1000, 'mb': 1000**2, 'gb': 1000**3, 'tb': 1000**4,
|
||
|
'kib': 1024, 'mib': 1024**2, 'gib': 1024**3, 'tib': 1024**4}
|
||
|
|
||
|
size_re = re.compile(r'^\s*(\d+|\d+\.\d+)\s*({0})\s*$'.format(
|
||
|
'|'.join(suffixes.keys())), re.I)
|
||
|
|
||
|
m = size_re.match(size_str.lower())
|
||
|
if not m or m.group(2) not in suffixes:
|
||
|
raise ValueError("value {!r} not a valid size".format(size_str))
|
||
|
return int(float(m.group(1)) * suffixes[m.group(2)])
|
||
|
|
||
|
|
||
|
def _get_mem_available():
|
||
|
"""Return available memory in bytes, or None if unknown."""
|
||
|
try:
|
||
|
import psutil
|
||
|
return psutil.virtual_memory().available
|
||
|
except (ImportError, AttributeError):
|
||
|
pass
|
||
|
|
||
|
if sys.platform.startswith('linux'):
|
||
|
info = {}
|
||
|
with open('/proc/meminfo', 'r') as f:
|
||
|
for line in f:
|
||
|
p = line.split()
|
||
|
info[p[0].strip(':').lower()] = int(p[1]) * 1024
|
||
|
|
||
|
if 'memavailable' in info:
|
||
|
# Linux >= 3.14
|
||
|
return info['memavailable']
|
||
|
else:
|
||
|
return info['memfree'] + info['cached']
|
||
|
|
||
|
return None
|
||
|
|
||
|
|
||
|
def _no_tracing(func):
|
||
|
"""
|
||
|
Decorator to temporarily turn off tracing for the duration of a test.
|
||
|
Needed in tests that check refcounting, otherwise the tracing itself
|
||
|
influences the refcounts
|
||
|
"""
|
||
|
if not hasattr(sys, 'gettrace'):
|
||
|
return func
|
||
|
else:
|
||
|
@wraps(func)
|
||
|
def wrapper(*args, **kwargs):
|
||
|
original_trace = sys.gettrace()
|
||
|
try:
|
||
|
sys.settrace(None)
|
||
|
return func(*args, **kwargs)
|
||
|
finally:
|
||
|
sys.settrace(original_trace)
|
||
|
return wrapper
|
||
|
|