Fixed database typo and removed unnecessary class identifier.

This commit is contained in:
Batuhan Berk Başoğlu 2020-10-14 10:10:37 -04:00
parent 00ad49a143
commit 45fb349a7d
5098 changed files with 952558 additions and 85 deletions

View file

@ -0,0 +1,101 @@
""" Test for assert_deallocated context manager and gc utilities
"""
import gc
from scipy._lib._gcutils import (set_gc_state, gc_state, assert_deallocated,
ReferenceError, IS_PYPY)
from numpy.testing import assert_equal
import pytest
def test_set_gc_state():
gc_status = gc.isenabled()
try:
for state in (True, False):
gc.enable()
set_gc_state(state)
assert_equal(gc.isenabled(), state)
gc.disable()
set_gc_state(state)
assert_equal(gc.isenabled(), state)
finally:
if gc_status:
gc.enable()
def test_gc_state():
# Test gc_state context manager
gc_status = gc.isenabled()
try:
for pre_state in (True, False):
set_gc_state(pre_state)
for with_state in (True, False):
# Check the gc state is with_state in with block
with gc_state(with_state):
assert_equal(gc.isenabled(), with_state)
# And returns to previous state outside block
assert_equal(gc.isenabled(), pre_state)
# Even if the gc state is set explicitly within the block
with gc_state(with_state):
assert_equal(gc.isenabled(), with_state)
set_gc_state(not with_state)
assert_equal(gc.isenabled(), pre_state)
finally:
if gc_status:
gc.enable()
@pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
def test_assert_deallocated():
# Ordinary use
class C(object):
def __init__(self, arg0, arg1, name='myname'):
self.name = name
for gc_current in (True, False):
with gc_state(gc_current):
# We are deleting from with-block context, so that's OK
with assert_deallocated(C, 0, 2, 'another name') as c:
assert_equal(c.name, 'another name')
del c
# Or not using the thing in with-block context, also OK
with assert_deallocated(C, 0, 2, name='third name'):
pass
assert_equal(gc.isenabled(), gc_current)
@pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
def test_assert_deallocated_nodel():
class C(object):
pass
with pytest.raises(ReferenceError):
# Need to delete after using if in with-block context
# Note: assert_deallocated(C) needs to be assigned for the test
# to function correctly. It is assigned to c, but c itself is
# not referenced in the body of the with, it is only there for
# the refcount.
with assert_deallocated(C) as c:
pass
@pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
def test_assert_deallocated_circular():
class C(object):
def __init__(self):
self._circular = self
with pytest.raises(ReferenceError):
# Circular reference, no automatic garbage collection
with assert_deallocated(C) as c:
del c
@pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
def test_assert_deallocated_circular2():
class C(object):
def __init__(self):
self._circular = self
with pytest.raises(ReferenceError):
# Still circular reference, no automatic garbage collection
with assert_deallocated(C):
pass

View file

@ -0,0 +1,67 @@
from pytest import raises as assert_raises
from scipy._lib._pep440 import Version, parse
def test_main_versions():
assert Version('1.8.0') == Version('1.8.0')
for ver in ['1.9.0', '2.0.0', '1.8.1']:
assert Version('1.8.0') < Version(ver)
for ver in ['1.7.0', '1.7.1', '0.9.9']:
assert Version('1.8.0') > Version(ver)
def test_version_1_point_10():
# regression test for gh-2998.
assert Version('1.9.0') < Version('1.10.0')
assert Version('1.11.0') < Version('1.11.1')
assert Version('1.11.0') == Version('1.11.0')
assert Version('1.99.11') < Version('1.99.12')
def test_alpha_beta_rc():
assert Version('1.8.0rc1') == Version('1.8.0rc1')
for ver in ['1.8.0', '1.8.0rc2']:
assert Version('1.8.0rc1') < Version(ver)
for ver in ['1.8.0a2', '1.8.0b3', '1.7.2rc4']:
assert Version('1.8.0rc1') > Version(ver)
assert Version('1.8.0b1') > Version('1.8.0a2')
def test_dev_version():
assert Version('1.9.0.dev+Unknown') < Version('1.9.0')
for ver in ['1.9.0', '1.9.0a1', '1.9.0b2', '1.9.0b2.dev+ffffffff', '1.9.0.dev1']:
assert Version('1.9.0.dev+f16acvda') < Version(ver)
assert Version('1.9.0.dev+f16acvda') == Version('1.9.0.dev+f16acvda')
def test_dev_a_b_rc_mixed():
assert Version('1.9.0a2.dev+f16acvda') == Version('1.9.0a2.dev+f16acvda')
assert Version('1.9.0a2.dev+6acvda54') < Version('1.9.0a2')
def test_dev0_version():
assert Version('1.9.0.dev0+Unknown') < Version('1.9.0')
for ver in ['1.9.0', '1.9.0a1', '1.9.0b2', '1.9.0b2.dev0+ffffffff']:
assert Version('1.9.0.dev0+f16acvda') < Version(ver)
assert Version('1.9.0.dev0+f16acvda') == Version('1.9.0.dev0+f16acvda')
def test_dev0_a_b_rc_mixed():
assert Version('1.9.0a2.dev0+f16acvda') == Version('1.9.0a2.dev0+f16acvda')
assert Version('1.9.0a2.dev0+6acvda54') < Version('1.9.0a2')
def test_raises():
for ver in ['1,9.0', '1.7.x']:
assert_raises(ValueError, Version, ver)
def test_legacy_version():
# Non-PEP-440 version identifiers always compare less. For NumPy this only
# occurs on dev builds prior to 1.10.0 which are unsupported anyway.
assert parse('invalid') < Version('0.0.0')
assert parse('1.9.0-f16acvda') < Version('1.0.0')

View file

@ -0,0 +1,32 @@
import sys
from scipy._lib._testutils import _parse_size, _get_mem_available
import pytest
def test__parse_size():
expected = {
'12': 12e6,
'12 b': 12,
'12k': 12e3,
' 12 M ': 12e6,
' 12 G ': 12e9,
' 12Tb ': 12e12,
'12 Mib ': 12 * 1024.0**2,
'12Tib': 12 * 1024.0**4,
}
for inp, outp in sorted(expected.items()):
if outp is None:
with pytest.raises(ValueError):
_parse_size(inp)
else:
assert _parse_size(inp) == outp
def test__mem_available():
# May return None on non-Linux platforms
available = _get_mem_available()
if sys.platform.startswith('linux'):
assert available >= 0
else:
assert available is None or available >= 0

View file

@ -0,0 +1,51 @@
import threading
import time
import traceback
from numpy.testing import assert_
from pytest import raises as assert_raises
from scipy._lib._threadsafety import ReentrancyLock, non_reentrant, ReentrancyError
def test_parallel_threads():
# Check that ReentrancyLock serializes work in parallel threads.
#
# The test is not fully deterministic, and may succeed falsely if
# the timings go wrong.
lock = ReentrancyLock("failure")
failflag = [False]
exceptions_raised = []
def worker(k):
try:
with lock:
assert_(not failflag[0])
failflag[0] = True
time.sleep(0.1 * k)
assert_(failflag[0])
failflag[0] = False
except Exception:
exceptions_raised.append(traceback.format_exc(2))
threads = [threading.Thread(target=lambda k=k: worker(k))
for k in range(3)]
for t in threads:
t.start()
for t in threads:
t.join()
exceptions_raised = "\n".join(exceptions_raised)
assert_(not exceptions_raised, exceptions_raised)
def test_reentering():
# Check that ReentrancyLock prevents re-entering from the same thread.
@non_reentrant()
def func(x):
return func(x)
assert_raises(ReentrancyError, func, 0)

View file

@ -0,0 +1,249 @@
from multiprocessing import Pool
from multiprocessing.pool import Pool as PWL
import os
import math
import numpy as np
from numpy.testing import assert_equal, assert_
import pytest
from pytest import raises as assert_raises, deprecated_call
import scipy
from scipy._lib._util import (_aligned_zeros, check_random_state, MapWrapper,
getfullargspec_no_self, FullArgSpec,
rng_integers)
def test__aligned_zeros():
niter = 10
def check(shape, dtype, order, align):
err_msg = repr((shape, dtype, order, align))
x = _aligned_zeros(shape, dtype, order, align=align)
if align is None:
align = np.dtype(dtype).alignment
assert_equal(x.__array_interface__['data'][0] % align, 0)
if hasattr(shape, '__len__'):
assert_equal(x.shape, shape, err_msg)
else:
assert_equal(x.shape, (shape,), err_msg)
assert_equal(x.dtype, dtype)
if order == "C":
assert_(x.flags.c_contiguous, err_msg)
elif order == "F":
if x.size > 0:
# Size-0 arrays get invalid flags on NumPy 1.5
assert_(x.flags.f_contiguous, err_msg)
elif order is None:
assert_(x.flags.c_contiguous, err_msg)
else:
raise ValueError()
# try various alignments
for align in [1, 2, 3, 4, 8, 16, 32, 64, None]:
for n in [0, 1, 3, 11]:
for order in ["C", "F", None]:
for dtype in [np.uint8, np.float64]:
for shape in [n, (1, 2, 3, n)]:
for j in range(niter):
check(shape, dtype, order, align)
def test_check_random_state():
# If seed is None, return the RandomState singleton used by np.random.
# If seed is an int, return a new RandomState instance seeded with seed.
# If seed is already a RandomState instance, return it.
# Otherwise raise ValueError.
rsi = check_random_state(1)
assert_equal(type(rsi), np.random.RandomState)
rsi = check_random_state(rsi)
assert_equal(type(rsi), np.random.RandomState)
rsi = check_random_state(None)
assert_equal(type(rsi), np.random.RandomState)
assert_raises(ValueError, check_random_state, 'a')
if hasattr(np.random, 'Generator'):
# np.random.Generator is only available in NumPy >= 1.17
rg = np.random.Generator(np.random.PCG64())
rsi = check_random_state(rg)
assert_equal(type(rsi), np.random.Generator)
def test_getfullargspec_no_self():
p = MapWrapper(1)
argspec = getfullargspec_no_self(p.__init__)
assert_equal(argspec, FullArgSpec(['pool'], None, None, (1,), [], None, {}))
argspec = getfullargspec_no_self(p.__call__)
assert_equal(argspec, FullArgSpec(['func', 'iterable'], None, None, None, [], None, {}))
class _rv_generic(object):
def _rvs(self, a, b=2, c=3, *args, size=None, **kwargs):
return None
rv_obj = _rv_generic()
argspec = getfullargspec_no_self(rv_obj._rvs)
assert_equal(argspec, FullArgSpec(['a', 'b', 'c'], 'args', 'kwargs', (2, 3), ['size'], {'size': None}, {}))
def test_mapwrapper_serial():
in_arg = np.arange(10.)
out_arg = np.sin(in_arg)
p = MapWrapper(1)
assert_(p._mapfunc is map)
assert_(p.pool is None)
assert_(p._own_pool is False)
out = list(p(np.sin, in_arg))
assert_equal(out, out_arg)
with assert_raises(RuntimeError):
p = MapWrapper(0)
def test_pool():
with Pool(2) as p:
p.map(math.sin, [1,2,3, 4])
def test_mapwrapper_parallel():
in_arg = np.arange(10.)
out_arg = np.sin(in_arg)
with MapWrapper(2) as p:
out = p(np.sin, in_arg)
assert_equal(list(out), out_arg)
assert_(p._own_pool is True)
assert_(isinstance(p.pool, PWL))
assert_(p._mapfunc is not None)
# the context manager should've closed the internal pool
# check that it has by asking it to calculate again.
with assert_raises(Exception) as excinfo:
p(np.sin, in_arg)
assert_(excinfo.type is ValueError)
# can also set a PoolWrapper up with a map-like callable instance
try:
p = Pool(2)
q = MapWrapper(p.map)
assert_(q._own_pool is False)
q.close()
# closing the PoolWrapper shouldn't close the internal pool
# because it didn't create it
out = p.map(np.sin, in_arg)
assert_equal(list(out), out_arg)
finally:
p.close()
# get our custom ones and a few from the "import *" cases
@pytest.mark.parametrize(
'key', ('fft', 'ifft', 'diag', 'arccos',
'randn', 'rand', 'array'))
def test_numpy_deprecation(key):
"""Test that 'from numpy import *' functions are deprecated."""
if key in ('fft', 'ifft', 'diag', 'arccos'):
arg = [1.0, 0.]
elif key == 'finfo':
arg = float
else:
arg = 2
func = getattr(scipy, key)
if key == 'fft':
match = r'scipy\.fft.*deprecated.*1.5.0.*'
else:
match = r'scipy\.%s is deprecated.*2\.0\.0' % key
with deprecated_call(match=match) as dep:
func(arg) # deprecated
# in case we catch more than one dep warning
fnames = [os.path.splitext(d.filename)[0] for d in dep.list]
basenames = [os.path.basename(fname) for fname in fnames]
assert 'test__util' in basenames
if key in ('rand', 'randn'):
root = np.random
elif key in ('fft', 'ifft'):
root = np.fft
else:
root = np
func_np = getattr(root, key)
func_np(arg) # not deprecated
assert func_np is not func
# classes should remain classes
if isinstance(func_np, type):
assert isinstance(func, type)
def test_numpy_deprecation_functionality():
# Check that the deprecation wrappers don't break basic NumPy
# functionality
with deprecated_call():
x = scipy.array([1, 2, 3], dtype=scipy.float64)
assert x.dtype == scipy.float64
assert x.dtype == np.float64
x = scipy.finfo(scipy.float32)
assert x.eps == np.finfo(np.float32).eps
assert scipy.float64 == np.float64
assert issubclass(np.float64, scipy.float64)
def test_rng_integers():
rng = np.random.RandomState()
# test that numbers are inclusive of high point
arr = rng_integers(rng, low=2, high=5, size=100, endpoint=True)
assert np.max(arr) == 5
assert np.min(arr) == 2
assert arr.shape == (100, )
# test that numbers are inclusive of high point
arr = rng_integers(rng, low=5, size=100, endpoint=True)
assert np.max(arr) == 5
assert np.min(arr) == 0
assert arr.shape == (100, )
# test that numbers are exclusive of high point
arr = rng_integers(rng, low=2, high=5, size=100, endpoint=False)
assert np.max(arr) == 4
assert np.min(arr) == 2
assert arr.shape == (100, )
# test that numbers are exclusive of high point
arr = rng_integers(rng, low=5, size=100, endpoint=False)
assert np.max(arr) == 4
assert np.min(arr) == 0
assert arr.shape == (100, )
# now try with np.random.Generator
try:
rng = np.random.default_rng()
except AttributeError:
return
# test that numbers are inclusive of high point
arr = rng_integers(rng, low=2, high=5, size=100, endpoint=True)
assert np.max(arr) == 5
assert np.min(arr) == 2
assert arr.shape == (100, )
# test that numbers are inclusive of high point
arr = rng_integers(rng, low=5, size=100, endpoint=True)
assert np.max(arr) == 5
assert np.min(arr) == 0
assert arr.shape == (100, )
# test that numbers are exclusive of high point
arr = rng_integers(rng, low=2, high=5, size=100, endpoint=False)
assert np.max(arr) == 4
assert np.min(arr) == 2
assert arr.shape == (100, )
# test that numbers are exclusive of high point
arr = rng_integers(rng, low=5, size=100, endpoint=False)
assert np.max(arr) == 4
assert np.min(arr) == 0
assert arr.shape == (100, )

View file

@ -0,0 +1,197 @@
from numpy.testing import assert_equal, assert_
from pytest import raises as assert_raises
import time
import pytest
import ctypes
import threading
from scipy._lib import _ccallback_c as _test_ccallback_cython
from scipy._lib import _test_ccallback
from scipy._lib._ccallback import LowLevelCallable
try:
import cffi
HAVE_CFFI = True
except ImportError:
HAVE_CFFI = False
ERROR_VALUE = 2.0
def callback_python(a, user_data=None):
if a == ERROR_VALUE:
raise ValueError("bad value")
if user_data is None:
return a + 1
else:
return a + user_data
def _get_cffi_func(base, signature):
if not HAVE_CFFI:
pytest.skip("cffi not installed")
# Get function address
voidp = ctypes.cast(base, ctypes.c_void_p)
address = voidp.value
# Create corresponding cffi handle
ffi = cffi.FFI()
func = ffi.cast(signature, address)
return func
def _get_ctypes_data():
value = ctypes.c_double(2.0)
return ctypes.cast(ctypes.pointer(value), ctypes.c_voidp)
def _get_cffi_data():
if not HAVE_CFFI:
pytest.skip("cffi not installed")
ffi = cffi.FFI()
return ffi.new('double *', 2.0)
CALLERS = {
'simple': _test_ccallback.test_call_simple,
'nodata': _test_ccallback.test_call_nodata,
'nonlocal': _test_ccallback.test_call_nonlocal,
'cython': _test_ccallback_cython.test_call_cython,
}
# These functions have signatures known to the callers
FUNCS = {
'python': lambda: callback_python,
'capsule': lambda: _test_ccallback.test_get_plus1_capsule(),
'cython': lambda: LowLevelCallable.from_cython(_test_ccallback_cython, "plus1_cython"),
'ctypes': lambda: _test_ccallback_cython.plus1_ctypes,
'cffi': lambda: _get_cffi_func(_test_ccallback_cython.plus1_ctypes,
'double (*)(double, int *, void *)'),
'capsule_b': lambda: _test_ccallback.test_get_plus1b_capsule(),
'cython_b': lambda: LowLevelCallable.from_cython(_test_ccallback_cython, "plus1b_cython"),
'ctypes_b': lambda: _test_ccallback_cython.plus1b_ctypes,
'cffi_b': lambda: _get_cffi_func(_test_ccallback_cython.plus1b_ctypes,
'double (*)(double, double, int *, void *)'),
}
# These functions have signatures the callers don't know
BAD_FUNCS = {
'capsule_bc': lambda: _test_ccallback.test_get_plus1bc_capsule(),
'cython_bc': lambda: LowLevelCallable.from_cython(_test_ccallback_cython, "plus1bc_cython"),
'ctypes_bc': lambda: _test_ccallback_cython.plus1bc_ctypes,
'cffi_bc': lambda: _get_cffi_func(_test_ccallback_cython.plus1bc_ctypes,
'double (*)(double, double, double, int *, void *)'),
}
USER_DATAS = {
'ctypes': _get_ctypes_data,
'cffi': _get_cffi_data,
'capsule': _test_ccallback.test_get_data_capsule,
}
def test_callbacks():
def check(caller, func, user_data):
caller = CALLERS[caller]
func = FUNCS[func]()
user_data = USER_DATAS[user_data]()
if func is callback_python:
func2 = lambda x: func(x, 2.0)
else:
func2 = LowLevelCallable(func, user_data)
func = LowLevelCallable(func)
# Test basic call
assert_equal(caller(func, 1.0), 2.0)
# Test 'bad' value resulting to an error
assert_raises(ValueError, caller, func, ERROR_VALUE)
# Test passing in user_data
assert_equal(caller(func2, 1.0), 3.0)
for caller in sorted(CALLERS.keys()):
for func in sorted(FUNCS.keys()):
for user_data in sorted(USER_DATAS.keys()):
check(caller, func, user_data)
def test_bad_callbacks():
def check(caller, func, user_data):
caller = CALLERS[caller]
user_data = USER_DATAS[user_data]()
func = BAD_FUNCS[func]()
if func is callback_python:
func2 = lambda x: func(x, 2.0)
else:
func2 = LowLevelCallable(func, user_data)
func = LowLevelCallable(func)
# Test that basic call fails
assert_raises(ValueError, caller, LowLevelCallable(func), 1.0)
# Test that passing in user_data also fails
assert_raises(ValueError, caller, func2, 1.0)
# Test error message
llfunc = LowLevelCallable(func)
try:
caller(llfunc, 1.0)
except ValueError as err:
msg = str(err)
assert_(llfunc.signature in msg, msg)
assert_('double (double, double, int *, void *)' in msg, msg)
for caller in sorted(CALLERS.keys()):
for func in sorted(BAD_FUNCS.keys()):
for user_data in sorted(USER_DATAS.keys()):
check(caller, func, user_data)
def test_signature_override():
caller = _test_ccallback.test_call_simple
func = _test_ccallback.test_get_plus1_capsule()
llcallable = LowLevelCallable(func, signature="bad signature")
assert_equal(llcallable.signature, "bad signature")
assert_raises(ValueError, caller, llcallable, 3)
llcallable = LowLevelCallable(func, signature="double (double, int *, void *)")
assert_equal(llcallable.signature, "double (double, int *, void *)")
assert_equal(caller(llcallable, 3), 4)
def test_threadsafety():
def callback(a, caller):
if a <= 0:
return 1
else:
res = caller(lambda x: callback(x, caller), a - 1)
return 2*res
def check(caller):
caller = CALLERS[caller]
results = []
count = 10
def run():
time.sleep(0.01)
r = caller(lambda x: callback(x, caller), count)
results.append(r)
threads = [threading.Thread(target=run) for j in range(20)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
assert_equal(results, [2.0**count]*len(threads))
for caller in CALLERS.keys():
check(caller)

View file

@ -0,0 +1,10 @@
import pytest
def test_cython_api_deprecation():
match = ("`scipy._lib._test_deprecation_def.foo_deprecated` "
"is deprecated, use `foo` instead!\n"
"Deprecated in Scipy 42.0.0")
with pytest.warns(DeprecationWarning, match=match):
from .. import _test_deprecation_call
assert _test_deprecation_call.call() == (1, 1)

View file

@ -0,0 +1,50 @@
import sys
import subprocess
MODULES = [
"scipy.cluster",
"scipy.cluster.vq",
"scipy.cluster.hierarchy",
"scipy.constants",
"scipy.fft",
"scipy.fftpack",
"scipy.integrate",
"scipy.interpolate",
"scipy.io",
"scipy.io.arff",
"scipy.io.harwell_boeing",
"scipy.io.idl",
"scipy.io.matlab",
"scipy.io.netcdf",
"scipy.io.wavfile",
"scipy.linalg",
"scipy.linalg.blas",
"scipy.linalg.cython_blas",
"scipy.linalg.lapack",
"scipy.linalg.cython_lapack",
"scipy.linalg.interpolative",
"scipy.misc",
"scipy.ndimage",
"scipy.odr",
"scipy.optimize",
"scipy.signal",
"scipy.signal.windows",
"scipy.sparse",
"scipy.sparse.linalg",
"scipy.sparse.csgraph",
"scipy.spatial",
"scipy.spatial.distance",
"scipy.special",
"scipy.stats",
"scipy.stats.distributions",
"scipy.stats.mstats",
]
def test_modules_importable():
# Check that all modules are importable in a new Python process.
#This is not necessarily true if there are import cycles present.
for module in MODULES:
cmd = 'import {}'.format(module)
subprocess.check_call([sys.executable, '-c', cmd])

View file

@ -0,0 +1,42 @@
""" Test tmpdirs module """
from os import getcwd
from os.path import realpath, abspath, dirname, isfile, join as pjoin, exists
from scipy._lib._tmpdirs import tempdir, in_tempdir, in_dir
from numpy.testing import assert_, assert_equal
MY_PATH = abspath(__file__)
MY_DIR = dirname(MY_PATH)
def test_tempdir():
with tempdir() as tmpdir:
fname = pjoin(tmpdir, 'example_file.txt')
with open(fname, 'wt') as fobj:
fobj.write('a string\\n')
assert_(not exists(tmpdir))
def test_in_tempdir():
my_cwd = getcwd()
with in_tempdir() as tmpdir:
with open('test.txt', 'wt') as f:
f.write('some text')
assert_(isfile('test.txt'))
assert_(isfile(pjoin(tmpdir, 'test.txt')))
assert_(not exists(tmpdir))
assert_equal(getcwd(), my_cwd)
def test_given_directory():
# Test InGivenDirectory
cwd = getcwd()
with in_dir() as tmpdir:
assert_equal(tmpdir, abspath(cwd))
assert_equal(tmpdir, abspath(getcwd()))
with in_dir(MY_DIR) as tmpdir:
assert_equal(tmpdir, MY_DIR)
assert_equal(realpath(MY_DIR), realpath(abspath(getcwd())))
# We were deleting the given directory! Check not so now.
assert_(isfile(MY_PATH))

View file

@ -0,0 +1,121 @@
"""
Tests which scan for certain occurrences in the code, they may not find
all of these occurrences but should catch almost all. This file was adapted
from NumPy.
"""
import os
from pathlib import Path
import ast
import tokenize
import scipy
import pytest
class ParseCall(ast.NodeVisitor):
def __init__(self):
self.ls = []
def visit_Attribute(self, node):
ast.NodeVisitor.generic_visit(self, node)
self.ls.append(node.attr)
def visit_Name(self, node):
self.ls.append(node.id)
class FindFuncs(ast.NodeVisitor):
def __init__(self, filename):
super().__init__()
self.__filename = filename
self.bad_filters = []
self.bad_stacklevels = []
def visit_Call(self, node):
p = ParseCall()
p.visit(node.func)
ast.NodeVisitor.generic_visit(self, node)
if p.ls[-1] == 'simplefilter' or p.ls[-1] == 'filterwarnings':
if node.args[0].s == "ignore":
self.bad_filters.append(
"{}:{}".format(self.__filename, node.lineno))
if p.ls[-1] == 'warn' and (
len(p.ls) == 1 or p.ls[-2] == 'warnings'):
if self.__filename == "_lib/tests/test_warnings.py":
# This file
return
# See if stacklevel exists:
if len(node.args) == 3:
return
args = {kw.arg for kw in node.keywords}
if "stacklevel" not in args:
self.bad_stacklevels.append(
"{}:{}".format(self.__filename, node.lineno))
@pytest.fixture(scope="session")
def warning_calls():
# combined "ignore" and stacklevel error
base = Path(scipy.__file__).parent
bad_filters = []
bad_stacklevels = []
for path in base.rglob("*.py"):
# use tokenize to auto-detect encoding on systems where no
# default encoding is defined (e.g., LANG='C')
with tokenize.open(str(path)) as file:
tree = ast.parse(file.read(), filename=str(path))
finder = FindFuncs(path.relative_to(base))
finder.visit(tree)
bad_filters.extend(finder.bad_filters)
bad_stacklevels.extend(finder.bad_stacklevels)
return bad_filters, bad_stacklevels
@pytest.mark.slow
def test_warning_calls_filters(warning_calls):
bad_filters, bad_stacklevels = warning_calls
# There is still one simplefilter occurrence in optimize.py that could be removed.
bad_filters = [item for item in bad_filters
if 'optimize.py' not in item]
# The filterwarnings calls in sparse are needed.
bad_filters = [item for item in bad_filters
if os.path.join('sparse', '__init__.py') not in item
and os.path.join('sparse', 'sputils.py') not in item]
if bad_filters:
raise AssertionError(
"warning ignore filter should not be used, instead, use\n"
"numpy.testing.suppress_warnings (in tests only);\n"
"found in:\n {}".format(
"\n ".join(bad_filters)))
@pytest.mark.slow
@pytest.mark.xfail(reason="stacklevels currently missing")
def test_warning_calls_stacklevels(warning_calls):
bad_filters, bad_stacklevels = warning_calls
msg = ""
if bad_filters:
msg += ("warning ignore filter should not be used, instead, use\n"
"numpy.testing.suppress_warnings (in tests only);\n"
"found in:\n {}".format("\n ".join(bad_filters)))
msg += "\n\n"
if bad_stacklevels:
msg += "warnings should have an appropriate stacklevel:\n {}".format(
"\n ".join(bad_stacklevels))
if msg:
raise AssertionError(msg)