Fixed database typo and removed unnecessary class identifier.

This commit is contained in:
Batuhan Berk Başoğlu 2020-10-14 10:10:37 -04:00
parent 00ad49a143
commit 45fb349a7d
5098 changed files with 952558 additions and 85 deletions

View file

@ -0,0 +1,96 @@
""" This script was used to generate dwt_matlabR2012a_result.npz by storing
the outputs from Matlab R2012a. """
from __future__ import division, print_function, absolute_import
import numpy as np
import pywt
try:
from pymatbridge import Matlab
mlab = Matlab()
_matlab_missing = False
except ImportError:
print("To run Matlab compatibility tests you need to have MathWorks "
"MATLAB, MathWorks Wavelet Toolbox and the pymatbridge Python "
"package installed.")
_matlab_missing = True
if _matlab_missing:
raise EnvironmentError("Can't generate matlab data files without MATLAB")
size_set = 'reduced'
# list of mode names in pywt and matlab
modes = [('zero', 'zpd'),
('constant', 'sp0'),
('symmetric', 'sym'),
('reflect', 'symw'),
('periodic', 'ppd'),
('smooth', 'sp1'),
('periodization', 'per'),
('antisymmetric', 'asym'),
('antireflect', 'asymw')]
families = ('db', 'sym', 'coif', 'bior', 'rbio')
wavelets = sum([pywt.wavelist(name) for name in families], [])
rstate = np.random.RandomState(1234)
mlab.start()
try:
all_matlab_results = {}
for wavelet in wavelets:
w = pywt.Wavelet(wavelet)
mlab.set_variable('wavelet', wavelet)
if size_set == 'full':
data_sizes = list(range(w.dec_len, 40)) + \
[100, 200, 500, 1000, 50000]
else:
data_sizes = (w.dec_len, w.dec_len + 1)
for N in data_sizes:
data = rstate.randn(N)
mlab.set_variable('data', data)
for pmode, mmode in modes:
# Matlab result
if np.any((wavelet == np.array(['coif6', 'coif7', 'coif8', 'coif9', 'coif10', 'coif11', 'coif12', 'coif13', 'coif14', 'coif15', 'coif16', 'coif17'])),axis=0):
mlab.set_variable('Lo_D', w.dec_lo)
mlab.set_variable('Hi_D', w.dec_hi)
mlab_code = ("[ma, md] = dwt(data, Lo_D, Hi_D, "
"'mode', '%s');" % mmode)
else:
mlab_code = ("[ma, md] = dwt(data, wavelet, "
"'mode', '%s');" % mmode)
res = mlab.run_code(mlab_code)
if not res['success']:
raise RuntimeError(
"Matlab failed to execute the provided code. "
"Check that the wavelet toolbox is installed.")
# need np.asarray because sometimes the output is type float
ma = np.asarray(mlab.get_variable('ma'))
md = np.asarray(mlab.get_variable('md'))
ma_key = '_'.join([mmode, wavelet, str(N), 'ma'])
md_key = '_'.join([mmode, wavelet, str(N), 'md'])
all_matlab_results[ma_key] = ma
all_matlab_results[md_key] = md
# Matlab result
mlab.set_variable('Lo_D', w.dec_lo)
mlab.set_variable('Hi_D', w.dec_hi)
mlab_code = ("[ma, md] = dwt(data, Lo_D, Hi_D, "
"'mode', '%s');" % mmode)
res = mlab.run_code(mlab_code)
if not res['success']:
raise RuntimeError(
"Matlab failed to execute the provided code. "
"Check that the wavelet toolbox is installed.")
# need np.asarray because sometimes the output is type float
ma = np.asarray(mlab.get_variable('ma'))
md = np.asarray(mlab.get_variable('md'))
ma_key = '_'.join([mmode, wavelet, str(N), 'ma_pywtCoeffs'])
md_key = '_'.join([mmode, wavelet, str(N), 'md_pywtCoeffs'])
all_matlab_results[ma_key] = ma
all_matlab_results[md_key] = md
finally:
mlab.stop()
np.savez('dwt_matlabR2012a_result.npz', **all_matlab_results)

View file

@ -0,0 +1,86 @@
""" This script was used to generate dwt_matlabR2012a_result.npz by storing
the outputs from Matlab R2012a. """
from __future__ import division, print_function, absolute_import
import numpy as np
import pywt
try:
from pymatbridge import Matlab
mlab = Matlab()
_matlab_missing = False
except ImportError:
print("To run Matlab compatibility tests you need to have MathWorks "
"MATLAB, MathWorks Wavelet Toolbox and the pymatbridge Python "
"package installed.")
_matlab_missing = True
if _matlab_missing:
raise EnvironmentError("Can't generate matlab data files without MATLAB")
size_set = 'reduced'
# list of mode names in pywt and matlab
modes = [('zero', 'zpd'),
('constant', 'sp0'),
('symmetric', 'sym'),
('periodic', 'ppd'),
('smooth', 'sp1'),
('periodization', 'per')]
families = ('gaus', 'mexh', 'morl', 'cgau', 'shan', 'fbsp', 'cmor')
wavelets = sum([pywt.wavelist(name) for name in families], [])
rstate = np.random.RandomState(1234)
mlab.start()
try:
all_matlab_results = {}
for wavelet in wavelets:
w = pywt.ContinuousWavelet(wavelet)
if np.any((wavelet == np.array(['shan', 'cmor'])),axis=0):
mlab.set_variable('wavelet', wavelet+str(w.bandwidth_frequency)+'-'+str(w.center_frequency))
elif wavelet == 'fbsp':
mlab.set_variable('wavelet', wavelet+str(w.fbsp_order)+'-'+str(w.bandwidth_frequency)+'-'+str(w.center_frequency))
else:
mlab.set_variable('wavelet', wavelet)
if size_set == 'full':
data_sizes = list(range(100, 101)) + \
[100, 200, 500, 1000, 50000]
Scales = (1,np.arange(1,3),np.arange(1,4),np.arange(1,5))
else:
data_sizes = (1000, 1000 + 1)
Scales = (1,np.arange(1,3))
mlab_code = ("psi = wavefun(wavelet,10)")
res = mlab.run_code(mlab_code)
if not res['success']:
raise RuntimeError(
"Matlab failed to execute the provided code. "
"Check that the wavelet toolbox is installed.")
psi = np.asarray(mlab.get_variable('psi'))
psi_key = '_'.join([wavelet, 'psi'])
all_matlab_results[psi_key] = psi
for N in data_sizes:
data = rstate.randn(N)
mlab.set_variable('data', data)
# Matlab result
scale_count = 0
for scales in Scales:
scale_count += 1
mlab.set_variable('scales', scales)
mlab_code = ("coefs = cwt(data, scales, wavelet)")
res = mlab.run_code(mlab_code)
if not res['success']:
raise RuntimeError(
"Matlab failed to execute the provided code. "
"Check that the wavelet toolbox is installed.")
# need np.asarray because sometimes the output is type float
coefs = np.asarray(mlab.get_variable('coefs'))
coefs_key = '_'.join([str(scale_count), wavelet, str(N), 'coefs'])
all_matlab_results[coefs_key] = coefs
finally:
mlab.stop()
np.savez('cwt_matlabR2015b_result.npz', **all_matlab_results)

View file

@ -0,0 +1,170 @@
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
import numpy as np
from numpy.testing import assert_allclose, assert_, assert_raises
import pywt
def test_upcoef_reconstruct():
data = np.arange(3)
a = pywt.downcoef('a', data, 'haar')
d = pywt.downcoef('d', data, 'haar')
rec = (pywt.upcoef('a', a, 'haar', take=3) +
pywt.upcoef('d', d, 'haar', take=3))
assert_allclose(rec, data)
def test_downcoef_multilevel():
rstate = np.random.RandomState(1234)
r = rstate.randn(16)
nlevels = 3
# calling with level=1 nlevels times
a1 = r.copy()
for i in range(nlevels):
a1 = pywt.downcoef('a', a1, 'haar', level=1)
# call with level=nlevels once
a3 = pywt.downcoef('a', r, 'haar', level=nlevels)
assert_allclose(a1, a3)
def test_downcoef_complex():
rstate = np.random.RandomState(1234)
r = rstate.randn(16) + 1j * rstate.randn(16)
nlevels = 3
a = pywt.downcoef('a', r, 'haar', level=nlevels)
a_ref = pywt.downcoef('a', r.real, 'haar', level=nlevels)
a_ref = a_ref + 1j * pywt.downcoef('a', r.imag, 'haar', level=nlevels)
assert_allclose(a, a_ref)
def test_downcoef_errs():
# invalid part string (not 'a' or 'd')
assert_raises(ValueError, pywt.downcoef, 'f', np.ones(16), 'haar')
def test_compare_downcoef_coeffs():
rstate = np.random.RandomState(1234)
r = rstate.randn(16)
# compare downcoef against wavedec outputs
for nlevels in [1, 2, 3]:
for wavelet in pywt.wavelist():
if wavelet in ['cmor', 'shan', 'fbsp']:
# skip these CWT families to avoid warnings
continue
wavelet = pywt.DiscreteContinuousWavelet(wavelet)
if isinstance(wavelet, pywt.Wavelet):
max_level = pywt.dwt_max_level(r.size, wavelet.dec_len)
if nlevels <= max_level:
a = pywt.downcoef('a', r, wavelet, level=nlevels)
d = pywt.downcoef('d', r, wavelet, level=nlevels)
coeffs = pywt.wavedec(r, wavelet, level=nlevels)
assert_allclose(a, coeffs[0])
assert_allclose(d, coeffs[1])
def test_upcoef_multilevel():
rstate = np.random.RandomState(1234)
r = rstate.randn(4)
nlevels = 3
# calling with level=1 nlevels times
a1 = r.copy()
for i in range(nlevels):
a1 = pywt.upcoef('a', a1, 'haar', level=1)
# call with level=nlevels once
a3 = pywt.upcoef('a', r, 'haar', level=nlevels)
assert_allclose(a1, a3)
def test_upcoef_complex():
rstate = np.random.RandomState(1234)
r = rstate.randn(4) + 1j*rstate.randn(4)
nlevels = 3
a = pywt.upcoef('a', r, 'haar', level=nlevels)
a_ref = pywt.upcoef('a', r.real, 'haar', level=nlevels)
a_ref = a_ref + 1j*pywt.upcoef('a', r.imag, 'haar', level=nlevels)
assert_allclose(a, a_ref)
def test_upcoef_errs():
# invalid part string (not 'a' or 'd')
assert_raises(ValueError, pywt.upcoef, 'f', np.ones(4), 'haar')
def test_upcoef_and_downcoef_1d_only():
# upcoef and downcoef raise a ValueError if data.ndim > 1d
for ndim in [2, 3]:
data = np.ones((8, )*ndim)
assert_raises(ValueError, pywt.downcoef, 'a', data, 'haar')
assert_raises(ValueError, pywt.upcoef, 'a', data, 'haar')
def test_wavelet_repr():
from pywt._extensions import _pywt
wavelet = _pywt.Wavelet('sym8')
repr_wavelet = eval(wavelet.__repr__())
assert_(wavelet.__repr__() == repr_wavelet.__repr__())
def test_dwt_max_level():
assert_(pywt.dwt_max_level(16, 2) == 4)
assert_(pywt.dwt_max_level(16, 8) == 1)
assert_(pywt.dwt_max_level(16, 9) == 1)
assert_(pywt.dwt_max_level(16, 10) == 0)
assert_(pywt.dwt_max_level(16, np.int8(10)) == 0)
assert_(pywt.dwt_max_level(16, 10.) == 0)
assert_(pywt.dwt_max_level(16, 18) == 0)
# accepts discrete Wavelet object or string as well
assert_(pywt.dwt_max_level(32, pywt.Wavelet('sym5')) == 1)
assert_(pywt.dwt_max_level(32, 'sym5') == 1)
# string input that is not a discrete wavelet
assert_raises(ValueError, pywt.dwt_max_level, 16, 'mexh')
# filter_len must be an integer >= 2
assert_raises(ValueError, pywt.dwt_max_level, 16, 1)
assert_raises(ValueError, pywt.dwt_max_level, 16, -1)
assert_raises(ValueError, pywt.dwt_max_level, 16, 3.3)
def test_ContinuousWavelet_errs():
assert_raises(ValueError, pywt.ContinuousWavelet, 'qwertz')
def test_ContinuousWavelet_repr():
from pywt._extensions import _pywt
wavelet = _pywt.ContinuousWavelet('gaus2')
repr_wavelet = eval(wavelet.__repr__())
assert_(wavelet.__repr__() == repr_wavelet.__repr__())
def test_wavelist():
for name in pywt.wavelist(family='coif'):
assert_(name.startswith('coif'))
assert_('cgau7' in pywt.wavelist(kind='continuous'))
assert_('sym20' in pywt.wavelist(kind='discrete'))
assert_(len(pywt.wavelist(kind='continuous')) +
len(pywt.wavelist(kind='discrete')) ==
len(pywt.wavelist(kind='all')))
assert_raises(ValueError, pywt.wavelist, kind='foobar')
def test_wavelet_errormsgs():
try:
pywt.Wavelet('gaus1')
except ValueError as e:
assert_(e.args[0].startswith('The `Wavelet` class'))
try:
pywt.Wavelet('cmord')
except ValueError as e:
assert_(e.args[0] == "Invalid wavelet name 'cmord'.")

View file

@ -0,0 +1,105 @@
"""
Tests used to verify running PyWavelets transforms in parallel via
concurrent.futures.ThreadPoolExecutor does not raise errors.
"""
from __future__ import division, print_function, absolute_import
import warnings
import numpy as np
from functools import partial
from numpy.testing import assert_array_equal, assert_allclose
from pywt._pytest import uses_futures, futures, max_workers
import pywt
def _assert_all_coeffs_equal(coefs1, coefs2):
# return True only if all coefficients of SWT or DWT match over all levels
if len(coefs1) != len(coefs2):
return False
for (c1, c2) in zip(coefs1, coefs2):
if isinstance(c1, tuple):
# for swt, swt2, dwt, dwt2, wavedec, wavedec2
for a1, a2 in zip(c1, c2):
assert_array_equal(a1, a2)
elif isinstance(c1, dict):
# for swtn, dwtn, wavedecn
for k, v in c1.items():
assert_array_equal(v, c2[k])
else:
return False
return True
@uses_futures
def test_concurrent_swt():
# tests error-free concurrent operation (see gh-288)
# swt on 1D data calls the Cython swt
# other cases call swt_axes
with warnings.catch_warnings():
# can remove catch_warnings once the swt2 FutureWarning is removed
warnings.simplefilter('ignore', FutureWarning)
for swt_func, x in zip([pywt.swt, pywt.swt2, pywt.swtn],
[np.ones(8), np.eye(16), np.eye(16)]):
transform = partial(swt_func, wavelet='haar', level=3)
for _ in range(10):
arrs = [x.copy() for _ in range(100)]
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
results = list(ex.map(transform, arrs))
# validate result from one of the concurrent runs
expected_result = transform(x)
_assert_all_coeffs_equal(expected_result, results[-1])
@uses_futures
def test_concurrent_wavedec():
# wavedec on 1D data calls the Cython dwt_single
# other cases call dwt_axis
for wavedec_func, x in zip([pywt.wavedec, pywt.wavedec2, pywt.wavedecn],
[np.ones(8), np.eye(16), np.eye(16)]):
transform = partial(wavedec_func, wavelet='haar', level=1)
for _ in range(10):
arrs = [x.copy() for _ in range(100)]
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
results = list(ex.map(transform, arrs))
# validate result from one of the concurrent runs
expected_result = transform(x)
_assert_all_coeffs_equal(expected_result, results[-1])
@uses_futures
def test_concurrent_dwt():
# dwt on 1D data calls the Cython dwt_single
# other cases call dwt_axis
for dwt_func, x in zip([pywt.dwt, pywt.dwt2, pywt.dwtn],
[np.ones(8), np.eye(16), np.eye(16)]):
transform = partial(dwt_func, wavelet='haar')
for _ in range(10):
arrs = [x.copy() for _ in range(100)]
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
results = list(ex.map(transform, arrs))
# validate result from one of the concurrent runs
expected_result = transform(x)
_assert_all_coeffs_equal([expected_result, ], [results[-1], ])
@uses_futures
def test_concurrent_cwt():
atol = rtol = 1e-14
time, sst = pywt.data.nino()
dt = time[1]-time[0]
transform = partial(pywt.cwt, scales=np.arange(1, 4), wavelet='cmor1.5-1',
sampling_period=dt)
for _ in range(10):
arrs = [sst.copy() for _ in range(50)]
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
results = list(ex.map(transform, arrs))
# validate result from one of the concurrent runs
expected_result = transform(sst)
for a1, a2 in zip(expected_result, results[-1]):
assert_allclose(a1, a2, atol=atol, rtol=rtol)

View file

@ -0,0 +1,434 @@
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
from itertools import product
from numpy.testing import (assert_allclose, assert_warns, assert_almost_equal,
assert_raises, assert_equal)
import pytest
import numpy as np
import pywt
def ref_gaus(LB, UB, N, num):
X = np.linspace(LB, UB, N)
F0 = (2./np.pi)**(1./4.)*np.exp(-(X**2))
if (num == 1):
psi = -2.*X*F0
elif (num == 2):
psi = -2/(3**(1/2))*(-1 + 2*X**2)*F0
elif (num == 3):
psi = -4/(15**(1/2))*X*(3 - 2*X**2)*F0
elif (num == 4):
psi = 4/(105**(1/2))*(3 - 12*X**2 + 4*X**4)*F0
elif (num == 5):
psi = 8/(3*(105**(1/2)))*X*(-15 + 20*X**2 - 4*X**4)*F0
elif (num == 6):
psi = -8/(3*(1155**(1/2)))*(-15 + 90*X**2 - 60*X**4 + 8*X**6)*F0
elif (num == 7):
psi = -16/(3*(15015**(1/2)))*X*(105 - 210*X**2 + 84*X**4 - 8*X**6)*F0
elif (num == 8):
psi = 16/(45*(1001**(1/2)))*(105 - 840*X**2 + 840*X**4 -
224*X**6 + 16*X**8)*F0
return (psi, X)
def ref_cgau(LB, UB, N, num):
X = np.linspace(LB, UB, N)
F0 = np.exp(-X**2)
F1 = np.exp(-1j*X)
F2 = (F1*F0)/(np.exp(-1/2)*2**(1/2)*np.pi**(1/2))**(1/2)
if (num == 1):
psi = F2*(-1j - 2*X)*2**(1/2)
elif (num == 2):
psi = 1/3*F2*(-3 + 4j*X + 4*X**2)*6**(1/2)
elif (num == 3):
psi = 1/15*F2*(7j + 18*X - 12j*X**2 - 8*X**3)*30**(1/2)
elif (num == 4):
psi = 1/105*F2*(25 - 56j*X - 72*X**2 + 32j*X**3 + 16*X**4)*210**(1/2)
elif (num == 5):
psi = 1/315*F2*(-81j - 250*X + 280j*X**2 + 240*X**3 -
80j*X**4 - 32*X**5)*210**(1/2)
elif (num == 6):
psi = 1/3465*F2*(-331 + 972j*X + 1500*X**2 - 1120j*X**3 - 720*X**4 +
192j*X**5 + 64*X**6)*2310**(1/2)
elif (num == 7):
psi = 1/45045*F2*(
1303j + 4634*X - 6804j*X**2 - 7000*X**3 + 3920j*X**4 + 2016*X**5 -
448j*X**6 - 128*X**7)*30030**(1/2)
elif (num == 8):
psi = 1/45045*F2*(
5937 - 20848j*X - 37072*X**2 + 36288j*X**3 + 28000*X**4 -
12544j*X**5 - 5376*X**6 + 1024j*X**7 + 256*X**8)*2002**(1/2)
psi = psi/np.real(np.sqrt(np.real(np.sum(psi*np.conj(psi)))*(X[1] - X[0])))
return (psi, X)
def sinc2(x):
y = np.ones_like(x)
k = np.where(x)[0]
y[k] = np.sin(np.pi*x[k])/(np.pi*x[k])
return y
def ref_shan(LB, UB, N, Fb, Fc):
x = np.linspace(LB, UB, N)
psi = np.sqrt(Fb)*(sinc2(Fb*x)*np.exp(2j*np.pi*Fc*x))
return (psi, x)
def ref_fbsp(LB, UB, N, m, Fb, Fc):
x = np.linspace(LB, UB, N)
psi = np.sqrt(Fb)*((sinc2(Fb*x/m)**m)*np.exp(2j*np.pi*Fc*x))
return (psi, x)
def ref_cmor(LB, UB, N, Fb, Fc):
x = np.linspace(LB, UB, N)
psi = ((np.pi*Fb)**(-0.5))*np.exp(2j*np.pi*Fc*x)*np.exp(-(x**2)/Fb)
return (psi, x)
def ref_morl(LB, UB, N):
x = np.linspace(LB, UB, N)
psi = np.exp(-(x**2)/2)*np.cos(5*x)
return (psi, x)
def ref_mexh(LB, UB, N):
x = np.linspace(LB, UB, N)
psi = (2/(np.sqrt(3)*np.pi**0.25))*np.exp(-(x**2)/2)*(1 - (x**2))
return (psi, x)
def test_gaus():
LB = -5
UB = 5
N = 1000
for num in np.arange(1, 9):
[psi, x] = ref_gaus(LB, UB, N, num)
w = pywt.ContinuousWavelet("gaus" + str(num))
PSI, X = w.wavefun(length=N)
assert_allclose(np.real(PSI), np.real(psi))
assert_allclose(np.imag(PSI), np.imag(psi))
assert_allclose(X, x)
def test_cgau():
LB = -5
UB = 5
N = 1000
for num in np.arange(1, 9):
[psi, x] = ref_cgau(LB, UB, N, num)
w = pywt.ContinuousWavelet("cgau" + str(num))
PSI, X = w.wavefun(length=N)
assert_allclose(np.real(PSI), np.real(psi))
assert_allclose(np.imag(PSI), np.imag(psi))
assert_allclose(X, x)
def test_shan():
LB = -20
UB = 20
N = 1000
Fb = 1
Fc = 1.5
[psi, x] = ref_shan(LB, UB, N, Fb, Fc)
w = pywt.ContinuousWavelet("shan{}-{}".format(Fb, Fc))
assert_almost_equal(w.center_frequency, Fc)
assert_almost_equal(w.bandwidth_frequency, Fb)
w.upper_bound = UB
w.lower_bound = LB
PSI, X = w.wavefun(length=N)
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
assert_allclose(X, x, atol=1e-15)
LB = -20
UB = 20
N = 1000
Fb = 1.5
Fc = 1
[psi, x] = ref_shan(LB, UB, N, Fb, Fc)
w = pywt.ContinuousWavelet("shan{}-{}".format(Fb, Fc))
assert_almost_equal(w.center_frequency, Fc)
assert_almost_equal(w.bandwidth_frequency, Fb)
w.upper_bound = UB
w.lower_bound = LB
PSI, X = w.wavefun(length=N)
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
assert_allclose(X, x, atol=1e-15)
def test_cmor():
LB = -20
UB = 20
N = 1000
Fb = 1
Fc = 1.5
[psi, x] = ref_cmor(LB, UB, N, Fb, Fc)
w = pywt.ContinuousWavelet("cmor{}-{}".format(Fb, Fc))
assert_almost_equal(w.center_frequency, Fc)
assert_almost_equal(w.bandwidth_frequency, Fb)
w.upper_bound = UB
w.lower_bound = LB
PSI, X = w.wavefun(length=N)
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
assert_allclose(X, x, atol=1e-15)
LB = -20
UB = 20
N = 1000
Fb = 1.5
Fc = 1
[psi, x] = ref_cmor(LB, UB, N, Fb, Fc)
w = pywt.ContinuousWavelet("cmor{}-{}".format(Fb, Fc))
assert_almost_equal(w.center_frequency, Fc)
assert_almost_equal(w.bandwidth_frequency, Fb)
w.upper_bound = UB
w.lower_bound = LB
PSI, X = w.wavefun(length=N)
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
assert_allclose(X, x, atol=1e-15)
def test_fbsp():
LB = -20
UB = 20
N = 1000
M = 2
Fb = 1
Fc = 1.5
[psi, x] = ref_fbsp(LB, UB, N, M, Fb, Fc)
w = pywt.ContinuousWavelet("fbsp{}-{}-{}".format(M, Fb, Fc))
assert_almost_equal(w.center_frequency, Fc)
assert_almost_equal(w.bandwidth_frequency, Fb)
w.fbsp_order = M
w.upper_bound = UB
w.lower_bound = LB
PSI, X = w.wavefun(length=N)
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
assert_allclose(X, x, atol=1e-15)
LB = -20
UB = 20
N = 1000
M = 2
Fb = 1.5
Fc = 1
[psi, x] = ref_fbsp(LB, UB, N, M, Fb, Fc)
w = pywt.ContinuousWavelet("fbsp{}-{}-{}".format(M, Fb, Fc))
assert_almost_equal(w.center_frequency, Fc)
assert_almost_equal(w.bandwidth_frequency, Fb)
w.fbsp_order = M
w.upper_bound = UB
w.lower_bound = LB
PSI, X = w.wavefun(length=N)
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
assert_allclose(X, x, atol=1e-15)
LB = -20
UB = 20
N = 1000
M = 3
Fb = 1.5
Fc = 1.2
[psi, x] = ref_fbsp(LB, UB, N, M, Fb, Fc)
w = pywt.ContinuousWavelet("fbsp{}-{}-{}".format(M, Fb, Fc))
assert_almost_equal(w.center_frequency, Fc)
assert_almost_equal(w.bandwidth_frequency, Fb)
w.fbsp_order = M
w.upper_bound = UB
w.lower_bound = LB
PSI, X = w.wavefun(length=N)
# TODO: investigate why atol = 1e-5 is necessary
assert_allclose(np.real(PSI), np.real(psi), atol=1e-5)
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-5)
assert_allclose(X, x, atol=1e-15)
def test_morl():
LB = -5
UB = 5
N = 1000
[psi, x] = ref_morl(LB, UB, N)
w = pywt.ContinuousWavelet("morl")
w.upper_bound = UB
w.lower_bound = LB
PSI, X = w.wavefun(length=N)
assert_allclose(np.real(PSI), np.real(psi))
assert_allclose(np.imag(PSI), np.imag(psi))
assert_allclose(X, x)
def test_mexh():
LB = -5
UB = 5
N = 1000
[psi, x] = ref_mexh(LB, UB, N)
w = pywt.ContinuousWavelet("mexh")
w.upper_bound = UB
w.lower_bound = LB
PSI, X = w.wavefun(length=N)
assert_allclose(np.real(PSI), np.real(psi))
assert_allclose(np.imag(PSI), np.imag(psi))
assert_allclose(X, x)
LB = -5
UB = 5
N = 1001
[psi, x] = ref_mexh(LB, UB, N)
w = pywt.ContinuousWavelet("mexh")
w.upper_bound = UB
w.lower_bound = LB
PSI, X = w.wavefun(length=N)
assert_allclose(np.real(PSI), np.real(psi))
assert_allclose(np.imag(PSI), np.imag(psi))
assert_allclose(X, x)
def test_cwt_parameters_in_names():
for func in [pywt.ContinuousWavelet, pywt.DiscreteContinuousWavelet]:
for name in ['fbsp', 'cmor', 'shan']:
# additional parameters should be specified within the name
assert_warns(FutureWarning, func, name)
for name in ['cmor', 'shan']:
# valid names
func(name + '1.5-1.0')
func(name + '1-4')
# invalid names
assert_raises(ValueError, func, name + '1.0')
assert_raises(ValueError, func, name + 'B-C')
assert_raises(ValueError, func, name + '1.0-1.0-1.0')
# valid names
func('fbsp1-1.5-1.0')
func('fbsp1.0-1.5-1')
func('fbsp2-5-1')
# invalid name (non-integer order)
assert_raises(ValueError, func, 'fbsp1.5-1-1')
assert_raises(ValueError, func, 'fbspM-B-C')
# invalid name (too few or too many params)
assert_raises(ValueError, func, 'fbsp1.0')
assert_raises(ValueError, func, 'fbsp1.0-0.4')
assert_raises(ValueError, func, 'fbsp1-1-1-1')
@pytest.mark.parametrize('dtype, tol, method',
[(np.float32, 1e-5, 'conv'),
(np.float32, 1e-5, 'fft'),
(np.float64, 1e-13, 'conv'),
(np.float64, 1e-13, 'fft')])
def test_cwt_complex(dtype, tol, method):
time, sst = pywt.data.nino()
sst = np.asarray(sst, dtype=dtype)
dt = time[1] - time[0]
wavelet = 'cmor1.5-1.0'
scales = np.arange(1, 32)
# real-valued tranfsorm as a reference
[cfs, f] = pywt.cwt(sst, scales, wavelet, dt, method=method)
# verify same precision
assert_equal(cfs.real.dtype, sst.dtype)
# complex-valued transform equals sum of the transforms of the real
# and imaginary components
sst_complex = sst + 1j*sst
[cfs_complex, f] = pywt.cwt(sst_complex, scales, wavelet, dt,
method=method)
assert_allclose(cfs + 1j*cfs, cfs_complex, atol=tol, rtol=tol)
# verify dtype is preserved
assert_equal(cfs_complex.dtype, sst_complex.dtype)
@pytest.mark.parametrize('axis, method', product([0, 1], ['conv', 'fft']))
def test_cwt_batch(axis, method):
dtype = np.float64
time, sst = pywt.data.nino()
n_batch = 8
batch_axis = 1 - axis
sst1 = np.asarray(sst, dtype=dtype)
sst = np.stack((sst1, ) * n_batch, axis=batch_axis)
dt = time[1] - time[0]
wavelet = 'cmor1.5-1.0'
scales = np.arange(1, 32)
# non-batch transform as reference
[cfs1, f] = pywt.cwt(sst1, scales, wavelet, dt, method=method, axis=axis)
shape_in = sst.shape
[cfs, f] = pywt.cwt(sst, scales, wavelet, dt, method=method, axis=axis)
# shape of input is not modified
assert_equal(shape_in, sst.shape)
# verify same precision
assert_equal(cfs.real.dtype, sst.dtype)
# verify expected shape
assert_equal(cfs.shape[0], len(scales))
assert_equal(cfs.shape[1 + batch_axis], n_batch)
assert_equal(cfs.shape[1 + axis], sst.shape[axis])
# batch result on stacked input is the same as stacked 1d result
assert_equal(cfs, np.stack((cfs1,) * n_batch, axis=batch_axis + 1))
def test_cwt_small_scales():
data = np.zeros(32)
# A scale of 0.1 was chosen specifically to give a filter of length 2 for
# mexh. This corner case should not raise an error.
cfs, f = pywt.cwt(data, scales=0.1, wavelet='mexh')
assert_allclose(cfs, np.zeros_like(cfs))
# extremely short scale factors raise a ValueError
assert_raises(ValueError, pywt.cwt, data, scales=0.01, wavelet='mexh')
def test_cwt_method_fft():
rstate = np.random.RandomState(1)
data = rstate.randn(50)
data[15] = 1.
scales = np.arange(1, 64)
wavelet = 'cmor1.5-1.0'
# build a reference cwt with the legacy np.conv() method
cfs_conv, _ = pywt.cwt(data, scales, wavelet, method='conv')
# compare with the fft based convolution
cfs_fft, _ = pywt.cwt(data, scales, wavelet, method='fft')
assert_allclose(cfs_conv, cfs_fft, rtol=0, atol=1e-13)

View file

@ -0,0 +1,77 @@
import os
import numpy as np
from numpy.testing import assert_allclose, assert_raises, assert_
import pywt.data
data_dir = os.path.join(os.path.dirname(__file__), 'data')
wavelab_data_file = os.path.join(data_dir, 'wavelab_test_signals.npz')
wavelab_result_dict = np.load(wavelab_data_file)
def test_data_aero():
aero = pywt.data.aero()
ref = np.array([[178, 178, 179],
[170, 173, 171],
[185, 174, 171]])
assert_allclose(aero[:3, :3], ref)
def test_data_ascent():
ascent = pywt.data.ascent()
ref = np.array([[83, 83, 83],
[82, 82, 83],
[80, 81, 83]])
assert_allclose(ascent[:3, :3], ref)
def test_data_camera():
ascent = pywt.data.camera()
ref = np.array([[156, 157, 160],
[156, 157, 159],
[158, 157, 156]])
assert_allclose(ascent[:3, :3], ref)
def test_data_ecg():
ecg = pywt.data.ecg()
ref = np.array([-86, -87, -87])
assert_allclose(ecg[:3], ref)
def test_wavelab_signals():
"""Comparison with results generated using WaveLab"""
rtol = atol = 1e-12
# get a list of the available signals
available_signals = pywt.data.demo_signal('list')
assert_('Doppler' in available_signals)
for signal in available_signals:
# reference dictionary has lowercase names for the keys
key = signal.replace('-', '_').lower()
val = wavelab_result_dict[key]
if key in ['gabor', 'sineoneoverx']:
# these functions do not allow a size to be provided
assert_allclose(val, pywt.data.demo_signal(signal),
rtol=rtol, atol=atol)
assert_raises(ValueError, pywt.data.demo_signal, key, val.size)
else:
assert_allclose(val, pywt.data.demo_signal(signal, val.size),
rtol=rtol, atol=atol)
# these functions require a size to be provided
assert_raises(ValueError, pywt.data.demo_signal, key)
# ValueError on unrecognized signal type
assert_raises(ValueError, pywt.data.demo_signal, 'unknown_signal', 512)
# ValueError on invalid length
assert_raises(ValueError, pywt.data.demo_signal, 'Doppler', 0)

View file

@ -0,0 +1,89 @@
import warnings
import numpy as np
from numpy.testing import assert_warns, assert_array_equal
import pywt
def test_intwave_deprecation():
wavelet = pywt.Wavelet('db3')
assert_warns(DeprecationWarning, pywt.intwave, wavelet)
def test_centrfrq_deprecation():
wavelet = pywt.Wavelet('db3')
assert_warns(DeprecationWarning, pywt.centrfrq, wavelet)
def test_scal2frq_deprecation():
wavelet = pywt.Wavelet('db3')
assert_warns(DeprecationWarning, pywt.scal2frq, wavelet, 1)
def test_orthfilt_deprecation():
assert_warns(DeprecationWarning, pywt.orthfilt, range(6))
def test_integrate_wave_tuple():
sig = [0, 1, 2, 3]
xgrid = [0, 1, 2, 3]
assert_warns(DeprecationWarning, pywt.integrate_wavelet, (sig, xgrid))
old_modes = ['zpd',
'cpd',
'sym',
'ppd',
'sp1',
'per',
]
def test_MODES_from_object_deprecation():
for mode in old_modes:
assert_warns(DeprecationWarning, pywt.Modes.from_object, mode)
def test_MODES_attributes_deprecation():
def get_mode(Modes, name):
return getattr(Modes, name)
for mode in old_modes:
assert_warns(DeprecationWarning, get_mode, pywt.Modes, mode)
def test_MODES_deprecation_new():
def use_MODES_new():
return pywt.MODES.symmetric
assert_warns(DeprecationWarning, use_MODES_new)
def test_MODES_deprecation_old():
def use_MODES_old():
return pywt.MODES.sym
assert_warns(DeprecationWarning, use_MODES_old)
def test_MODES_deprecation_getattr():
def use_MODES_new():
return getattr(pywt.MODES, 'symmetric')
assert_warns(DeprecationWarning, use_MODES_new)
def test_mode_equivalence():
old_new = [('zpd', 'zero'),
('cpd', 'constant'),
('sym', 'symmetric'),
('ppd', 'periodic'),
('sp1', 'smooth'),
('per', 'periodization')]
x = np.arange(8.)
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
for old, new in old_new:
assert_array_equal(pywt.dwt(x, 'db2', mode=old),
pywt.dwt(x, 'db2', mode=new))

View file

@ -0,0 +1,25 @@
from __future__ import division, print_function, absolute_import
import doctest
import glob
import os
import unittest
try:
import numpy as np
np.set_printoptions(legacy='1.13')
except TypeError:
pass
pdir = os.path.pardir
docs_base = os.path.abspath(os.path.join(os.path.dirname(__file__),
pdir, pdir, "doc", "source"))
files = glob.glob(os.path.join(docs_base, "*.rst")) + \
glob.glob(os.path.join(docs_base, "*", "*.rst"))
suite = doctest.DocFileSuite(*files, module_relative=False, encoding="utf-8")
if __name__ == "__main__":
unittest.TextTestRunner().run(suite)

View file

@ -0,0 +1,299 @@
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
import numpy as np
from numpy.testing import (assert_allclose, assert_, assert_raises,
assert_array_equal)
import pywt
# Check that float32, float64, complex64, complex128 are preserved.
# Other real types get converted to float64.
# complex256 gets converted to complex128
dtypes_in = [np.int8, np.float16, np.float32, np.float64, np.complex64,
np.complex128]
dtypes_out = [np.float64, np.float32, np.float32, np.float64, np.complex64,
np.complex128]
# test complex256 as well if it is available
try:
dtypes_in += [np.complex256, ]
dtypes_out += [np.complex128, ]
except AttributeError:
pass
def test_dwt_idwt_basic():
x = [3, 7, 1, 1, -2, 5, 4, 6]
cA, cD = pywt.dwt(x, 'db2')
cA_expect = [5.65685425, 7.39923721, 0.22414387, 3.33677403, 7.77817459]
cD_expect = [-2.44948974, -1.60368225, -4.44140056, -0.41361256,
1.22474487]
assert_allclose(cA, cA_expect)
assert_allclose(cD, cD_expect)
x_roundtrip = pywt.idwt(cA, cD, 'db2')
assert_allclose(x_roundtrip, x, rtol=1e-10)
# mismatched dtypes OK
x_roundtrip2 = pywt.idwt(cA.astype(np.float64), cD.astype(np.float32),
'db2')
assert_allclose(x_roundtrip2, x, rtol=1e-7, atol=1e-7)
assert_(x_roundtrip2.dtype == np.float64)
def test_idwt_mixed_complex_dtype():
x = np.arange(8).astype(float)
x = x + 1j*x[::-1]
cA, cD = pywt.dwt(x, 'db2')
x_roundtrip = pywt.idwt(cA, cD, 'db2')
assert_allclose(x_roundtrip, x, rtol=1e-10)
# mismatched dtypes OK
x_roundtrip2 = pywt.idwt(cA.astype(np.complex128), cD.astype(np.complex64),
'db2')
assert_allclose(x_roundtrip2, x, rtol=1e-7, atol=1e-7)
assert_(x_roundtrip2.dtype == np.complex128)
def test_dwt_idwt_dtypes():
wavelet = pywt.Wavelet('haar')
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
x = np.ones(4, dtype=dt_in)
errmsg = "wrong dtype returned for {0} input".format(dt_in)
cA, cD = pywt.dwt(x, wavelet)
assert_(cA.dtype == cD.dtype == dt_out, "dwt: " + errmsg)
x_roundtrip = pywt.idwt(cA, cD, wavelet)
assert_(x_roundtrip.dtype == dt_out, "idwt: " + errmsg)
def test_dwt_idwt_basic_complex():
x = np.asarray([3, 7, 1, 1, -2, 5, 4, 6])
x = x + 0.5j*x
cA, cD = pywt.dwt(x, 'db2')
cA_expect = np.asarray([5.65685425, 7.39923721, 0.22414387, 3.33677403,
7.77817459])
cA_expect = cA_expect + 0.5j*cA_expect
cD_expect = np.asarray([-2.44948974, -1.60368225, -4.44140056, -0.41361256,
1.22474487])
cD_expect = cD_expect + 0.5j*cD_expect
assert_allclose(cA, cA_expect)
assert_allclose(cD, cD_expect)
x_roundtrip = pywt.idwt(cA, cD, 'db2')
assert_allclose(x_roundtrip, x, rtol=1e-10)
def test_dwt_idwt_partial_complex():
x = np.asarray([3, 7, 1, 1, -2, 5, 4, 6])
x = x + 0.5j*x
cA, cD = pywt.dwt(x, 'haar')
cA_rec_expect = np.array([5.0+2.5j, 5.0+2.5j, 1.0+0.5j, 1.0+0.5j,
1.5+0.75j, 1.5+0.75j, 5.0+2.5j, 5.0+2.5j])
cA_rec = pywt.idwt(cA, None, 'haar')
assert_allclose(cA_rec, cA_rec_expect)
cD_rec_expect = np.array([-2.0-1.0j, 2.0+1.0j, 0.0+0.0j, 0.0+0.0j,
-3.5-1.75j, 3.5+1.75j, -1.0-0.5j, 1.0+0.5j])
cD_rec = pywt.idwt(None, cD, 'haar')
assert_allclose(cD_rec, cD_rec_expect)
assert_allclose(cA_rec + cD_rec, x)
def test_dwt_wavelet_kwd():
x = np.array([3, 7, 1, 1, -2, 5, 4, 6])
w = pywt.Wavelet('sym3')
cA, cD = pywt.dwt(x, wavelet=w, mode='constant')
cA_expect = [4.38354585, 3.80302657, 7.31813271, -0.58565539, 4.09727044,
7.81994027]
cD_expect = [-1.33068221, -2.78795192, -3.16825651, -0.67715519,
-0.09722957, -0.07045258]
assert_allclose(cA, cA_expect)
assert_allclose(cD, cD_expect)
def test_dwt_coeff_len():
x = np.array([3, 7, 1, 1, -2, 5, 4, 6])
w = pywt.Wavelet('sym3')
ln_modes = [pywt.dwt_coeff_len(len(x), w.dec_len, mode) for mode in
pywt.Modes.modes]
expected_result = [6, ] * len(pywt.Modes.modes)
expected_result[pywt.Modes.modes.index('periodization')] = 4
assert_allclose(ln_modes, expected_result)
ln_modes = [pywt.dwt_coeff_len(len(x), w, mode) for mode in
pywt.Modes.modes]
assert_allclose(ln_modes, expected_result)
def test_idwt_none_input():
# None input equals arrays of zeros of the right length
res1 = pywt.idwt([1, 2, 0, 1], None, 'db2', 'symmetric')
res2 = pywt.idwt([1, 2, 0, 1], [0, 0, 0, 0], 'db2', 'symmetric')
assert_allclose(res1, res2, rtol=1e-15, atol=1e-15)
res1 = pywt.idwt(None, [1, 2, 0, 1], 'db2', 'symmetric')
res2 = pywt.idwt([0, 0, 0, 0], [1, 2, 0, 1], 'db2', 'symmetric')
assert_allclose(res1, res2, rtol=1e-15, atol=1e-15)
# Only one argument at a time can be None
assert_raises(ValueError, pywt.idwt, None, None, 'db2', 'symmetric')
def test_idwt_invalid_input():
# Too short, min length is 4 for 'db4':
assert_raises(ValueError, pywt.idwt, [1, 2, 4], [4, 1, 3], 'db4', 'symmetric')
def test_dwt_single_axis():
x = [[3, 7, 1, 1],
[-2, 5, 4, 6]]
cA, cD = pywt.dwt(x, 'db2', axis=-1)
cA0, cD0 = pywt.dwt(x[0], 'db2')
cA1, cD1 = pywt.dwt(x[1], 'db2')
assert_allclose(cA[0], cA0)
assert_allclose(cA[1], cA1)
assert_allclose(cD[0], cD0)
assert_allclose(cD[1], cD1)
def test_idwt_single_axis():
x = [[3, 7, 1, 1],
[-2, 5, 4, 6]]
x = np.asarray(x)
x = x + 1j*x # test with complex data
cA, cD = pywt.dwt(x, 'db2', axis=-1)
x0 = pywt.idwt(cA[0], cD[0], 'db2', axis=-1)
x1 = pywt.idwt(cA[1], cD[1], 'db2', axis=-1)
assert_allclose(x[0], x0)
assert_allclose(x[1], x1)
def test_dwt_axis_arg():
x = [[3, 7, 1, 1],
[-2, 5, 4, 6]]
cA_, cD_ = pywt.dwt(x, 'db2', axis=-1)
cA, cD = pywt.dwt(x, 'db2', axis=1)
assert_allclose(cA_, cA)
assert_allclose(cD_, cD)
def test_idwt_axis_arg():
x = [[3, 7, 1, 1],
[-2, 5, 4, 6]]
cA, cD = pywt.dwt(x, 'db2', axis=1)
x_ = pywt.idwt(cA, cD, 'db2', axis=-1)
x = pywt.idwt(cA, cD, 'db2', axis=1)
assert_allclose(x_, x)
def test_dwt_idwt_axis_excess():
x = [[3, 7, 1, 1],
[-2, 5, 4, 6]]
# can't transform over axes that aren't there
assert_raises(ValueError,
pywt.dwt, x, 'db2', 'symmetric', axis=2)
assert_raises(ValueError,
pywt.idwt, [1, 2, 4], [4, 1, 3], 'db2', 'symmetric', axis=1)
def test_error_on_continuous_wavelet():
# A ValueError is raised if a Continuous wavelet is selected
data = np.ones((32, ))
for cwave in ['morl', pywt.DiscreteContinuousWavelet('morl')]:
assert_raises(ValueError, pywt.dwt, data, cwave)
cA, cD = pywt.dwt(data, 'db1')
assert_raises(ValueError, pywt.idwt, cA, cD, cwave)
def test_dwt_zero_size_axes():
# raise on empty input array
assert_raises(ValueError, pywt.dwt, [], 'db2')
# >1D case uses a different code path so check there as well
x = np.ones((1, 4))[0:0, :] # 2D with a size zero axis
assert_raises(ValueError, pywt.dwt, x, 'db2', axis=0)
def test_pad_1d():
x = [1, 2, 3]
assert_array_equal(pywt.pad(x, (4, 6), 'periodization'),
[1, 2, 3, 3, 1, 2, 3, 3, 1, 2, 3, 3, 1, 2])
assert_array_equal(pywt.pad(x, (4, 6), 'periodic'),
[3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3])
assert_array_equal(pywt.pad(x, (4, 6), 'constant'),
[1, 1, 1, 1, 1, 2, 3, 3, 3, 3, 3, 3, 3])
assert_array_equal(pywt.pad(x, (4, 6), 'zero'),
[0, 0, 0, 0, 1, 2, 3, 0, 0, 0, 0, 0, 0])
assert_array_equal(pywt.pad(x, (4, 6), 'smooth'),
[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
assert_array_equal(pywt.pad(x, (4, 6), 'symmetric'),
[3, 3, 2, 1, 1, 2, 3, 3, 2, 1, 1, 2, 3])
assert_array_equal(pywt.pad(x, (4, 6), 'antisymmetric'),
[3, -3, -2, -1, 1, 2, 3, -3, -2, -1, 1, 2, 3])
assert_array_equal(pywt.pad(x, (4, 6), 'reflect'),
[1, 2, 3, 2, 1, 2, 3, 2, 1, 2, 3, 2, 1])
assert_array_equal(pywt.pad(x, (4, 6), 'antireflect'),
[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
# equivalence of various pad_width formats
assert_array_equal(pywt.pad(x, 4, 'periodic'),
pywt.pad(x, (4, 4), 'periodic'))
assert_array_equal(pywt.pad(x, (4, ), 'periodic'),
pywt.pad(x, (4, 4), 'periodic'))
assert_array_equal(pywt.pad(x, [(4, 4)], 'periodic'),
pywt.pad(x, (4, 4), 'periodic'))
def test_pad_errors():
# negative pad width
x = [1, 2, 3]
assert_raises(ValueError, pywt.pad, x, -2, 'periodic')
# wrong length pad width
assert_raises(ValueError, pywt.pad, x, (1, 1, 1), 'periodic')
# invalid mode name
assert_raises(ValueError, pywt.pad, x, 2, 'bad_mode')
def test_pad_nd():
for ndim in [2, 3]:
x = np.arange(4**ndim).reshape((4, ) * ndim)
if ndim == 2:
pad_widths = [(2, 1), (2, 3)]
else:
pad_widths = [(2, 1), ] * ndim
for mode in pywt.Modes.modes:
xp = pywt.pad(x, pad_widths, mode)
# expected result is the same as applying along axes separably
xp_expected = x.copy()
for ax in range(ndim):
xp_expected = np.apply_along_axis(pywt.pad,
ax,
xp_expected,
pad_widths=[pad_widths[ax]],
mode=mode)
assert_array_equal(xp, xp_expected)

View file

@ -0,0 +1,38 @@
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
from numpy.testing import assert_almost_equal, assert_allclose
import pywt
def test_centrfreq():
# db1 is Haar function, frequency=1
w = pywt.Wavelet('db1')
expected = 1
result = pywt.central_frequency(w, precision=12)
assert_almost_equal(result, expected, decimal=3)
# db2, frequency=2/3
w = pywt.Wavelet('db2')
expected = 2/3.
result = pywt.central_frequency(w, precision=12)
assert_almost_equal(result, expected)
def test_scal2frq_scale():
scale = 2
w = pywt.Wavelet('db1')
expected = 1. / scale
result = pywt.scale2frequency(w, scale, precision=12)
assert_almost_equal(result, expected, decimal=3)
def test_intwave_orthogonal():
w = pywt.Wavelet('db1')
int_psi, x = pywt.integrate_wavelet(w, precision=12)
ix = x < 0.5
# For x < 0.5, the integral is equal to x
assert_allclose(int_psi[ix], x[ix])
# For x > 0.5, the integral is equal to (1 - x)
# Ignore last point here, there x > 1 and something goes wrong
assert_allclose(int_psi[~ix][:-1], 1 - x[~ix][:-1], atol=1e-10)

View file

@ -0,0 +1,160 @@
"""
Test used to verify PyWavelets Discrete Wavelet Transform computation
accuracy against MathWorks Wavelet Toolbox.
"""
from __future__ import division, print_function, absolute_import
import numpy as np
import pytest
from numpy.testing import assert_
import pywt
from pywt._pytest import (uses_pymatbridge, uses_precomputed, size_set)
from pywt._pytest import matlab_result_dict_dwt as matlab_result_dict
# list of mode names in pywt and matlab
modes = [('zero', 'zpd'),
('constant', 'sp0'),
('symmetric', 'sym'),
('reflect', 'symw'),
('periodic', 'ppd'),
('smooth', 'sp1'),
('periodization', 'per'),
# TODO: Now have implemented asymmetric modes too.
# Would be nice to update the Matlab data to test these as well.
('antisymmetric', 'asym'),
('antireflect', 'asymw'),
]
families = ('db', 'sym', 'coif', 'bior', 'rbio')
wavelets = sum([pywt.wavelist(name) for name in families], [])
def _get_data_sizes(w):
""" Return the sizes to test for wavelet w. """
if size_set == 'full':
data_sizes = list(range(w.dec_len, 40)) + \
[100, 200, 500, 1000, 50000]
else:
data_sizes = (w.dec_len, w.dec_len + 1)
return data_sizes
@uses_pymatbridge
@pytest.mark.slow
def test_accuracy_pymatbridge():
Matlab = pytest.importorskip("pymatbridge.Matlab")
mlab = Matlab()
rstate = np.random.RandomState(1234)
# max RMSE (was 1.0e-10, is reduced to 5.0e-5 due to different coefficents)
epsilon = 5.0e-5
epsilon_pywt_coeffs = 1.0e-10
mlab.start()
try:
for wavelet in wavelets:
w = pywt.Wavelet(wavelet)
mlab.set_variable('wavelet', wavelet)
for N in _get_data_sizes(w):
data = rstate.randn(N)
mlab.set_variable('data', data)
for pmode, mmode in modes:
ma, md = _compute_matlab_result(data, wavelet, mmode, mlab)
_check_accuracy(data, w, pmode, ma, md, wavelet, epsilon)
ma, md = _load_matlab_result_pywt_coeffs(data, wavelet, mmode)
_check_accuracy(data, w, pmode, ma, md, wavelet, epsilon_pywt_coeffs)
finally:
mlab.stop()
@uses_precomputed
@pytest.mark.slow
def test_accuracy_precomputed():
# Keep this specific random seed to match the precomputed Matlab result.
rstate = np.random.RandomState(1234)
# max RMSE (was 1.0e-10, is reduced to 5.0e-5 due to different coefficents)
epsilon = 5.0e-5
epsilon_pywt_coeffs = 1.0e-10
for wavelet in wavelets:
w = pywt.Wavelet(wavelet)
for N in _get_data_sizes(w):
data = rstate.randn(N)
for pmode, mmode in modes:
ma, md = _load_matlab_result(data, wavelet, mmode)
_check_accuracy(data, w, pmode, ma, md, wavelet, epsilon)
ma, md = _load_matlab_result_pywt_coeffs(data, wavelet, mmode)
_check_accuracy(data, w, pmode, ma, md, wavelet, epsilon_pywt_coeffs)
def _compute_matlab_result(data, wavelet, mmode, mlab):
""" Compute the result using MATLAB.
This function assumes that the Matlab variables `wavelet` and `data` have
already been set externally.
"""
if np.any((wavelet == np.array(['coif6', 'coif7', 'coif8', 'coif9', 'coif10', 'coif11', 'coif12', 'coif13', 'coif14', 'coif15', 'coif16', 'coif17'])),axis=0):
w = pywt.Wavelet(wavelet)
mlab.set_variable('Lo_D', w.dec_lo)
mlab.set_variable('Hi_D', w.dec_hi)
mlab_code = ("[ma, md] = dwt(data, Lo_D, Hi_D, 'mode', '%s');" % mmode)
else:
mlab_code = "[ma, md] = dwt(data, wavelet, 'mode', '%s');" % mmode
res = mlab.run_code(mlab_code)
if not res['success']:
raise RuntimeError("Matlab failed to execute the provided code. "
"Check that the wavelet toolbox is installed.")
# need np.asarray because sometimes the output is a single float64
ma = np.asarray(mlab.get_variable('ma'))
md = np.asarray(mlab.get_variable('md'))
return ma, md
def _load_matlab_result(data, wavelet, mmode):
""" Load the precomputed result.
"""
N = len(data)
ma_key = '_'.join([mmode, wavelet, str(N), 'ma'])
md_key = '_'.join([mmode, wavelet, str(N), 'md'])
if (ma_key not in matlab_result_dict) or \
(md_key not in matlab_result_dict):
raise KeyError(
"Precompted Matlab result not found for wavelet: "
"{0}, mode: {1}, size: {2}".format(wavelet, mmode, N))
ma = matlab_result_dict[ma_key]
md = matlab_result_dict[md_key]
return ma, md
def _load_matlab_result_pywt_coeffs(data, wavelet, mmode):
""" Load the precomputed result.
"""
N = len(data)
ma_key = '_'.join([mmode, wavelet, str(N), 'ma_pywtCoeffs'])
md_key = '_'.join([mmode, wavelet, str(N), 'md_pywtCoeffs'])
if (ma_key not in matlab_result_dict) or \
(md_key not in matlab_result_dict):
raise KeyError(
"Precompted Matlab result not found for wavelet: "
"{0}, mode: {1}, size: {2}".format(wavelet, mmode, N))
ma = matlab_result_dict[ma_key]
md = matlab_result_dict[md_key]
return ma, md
def _check_accuracy(data, w, pmode, ma, md, wavelet, epsilon):
# PyWavelets result
pa, pd = pywt.dwt(data, w, pmode)
# calculate error measures
rms_a = np.sqrt(np.mean((pa - ma) ** 2))
rms_d = np.sqrt(np.mean((pd - md) ** 2))
msg = ('[RMS_A > EPSILON] for Mode: %s, Wavelet: %s, '
'Length: %d, rms=%.3g' % (pmode, wavelet, len(data), rms_a))
assert_(rms_a < epsilon, msg=msg)
msg = ('[RMS_D > EPSILON] for Mode: %s, Wavelet: %s, '
'Length: %d, rms=%.3g' % (pmode, wavelet, len(data), rms_d))
assert_(rms_d < epsilon, msg=msg)

View file

@ -0,0 +1,174 @@
"""
Test used to verify PyWavelets Continuous Wavelet Transform computation
accuracy against MathWorks Wavelet Toolbox.
"""
from __future__ import division, print_function, absolute_import
import warnings
import numpy as np
import pytest
from numpy.testing import assert_
import pywt
from pywt._pytest import (uses_pymatbridge, uses_precomputed, size_set,
matlab_result_dict_cwt)
families = ('gaus', 'mexh', 'morl', 'cgau', 'shan', 'fbsp', 'cmor')
wavelets = sum([pywt.wavelist(name) for name in families], [])
def _get_data_sizes(w):
""" Return the sizes to test for wavelet w. """
if size_set == 'full':
data_sizes = list(range(100, 101)) + \
[100, 200, 500, 1000, 50000]
else:
data_sizes = (1000, 1000 + 1)
return data_sizes
def _get_scales(w):
""" Return the scales to test for wavelet w. """
if size_set == 'full':
scales = (1, np.arange(1, 3), np.arange(1, 4), np.arange(1, 5))
else:
scales = (1, np.arange(1, 3))
return scales
@uses_pymatbridge # skip this case if precomputed results are used instead
@pytest.mark.slow
def test_accuracy_pymatbridge_cwt():
Matlab = pytest.importorskip("pymatbridge.Matlab")
mlab = Matlab()
rstate = np.random.RandomState(1234)
# max RMSE (was 1.0e-10, is reduced to 5.0e-5 due to different coefficents)
epsilon = 1e-15
epsilon_psi = 1e-15
mlab.start()
try:
for wavelet in wavelets:
with warnings.catch_warnings():
warnings.simplefilter('ignore', FutureWarning)
w = pywt.ContinuousWavelet(wavelet)
if np.any((wavelet == np.array(['shan', 'cmor'])),axis=0):
mlab.set_variable('wavelet', wavelet+str(w.bandwidth_frequency)+'-'+str(w.center_frequency))
elif wavelet == 'fbsp':
mlab.set_variable('wavelet', wavelet+str(w.fbsp_order)+'-'+str(w.bandwidth_frequency)+'-'+str(w.center_frequency))
else:
mlab.set_variable('wavelet', wavelet)
mlab_code = ("psi = wavefun(wavelet,10)")
res = mlab.run_code(mlab_code)
psi = np.asarray(mlab.get_variable('psi'))
_check_accuracy_psi(w, psi, wavelet, epsilon_psi)
for N in _get_data_sizes(w):
data = rstate.randn(N)
mlab.set_variable('data', data)
for scales in _get_scales(w):
coefs = _compute_matlab_result(data, wavelet, scales, mlab)
_check_accuracy(data, w, scales, coefs, wavelet, epsilon)
finally:
mlab.stop()
@uses_precomputed # skip this case if pymatbridge + Matlab are being used
@pytest.mark.slow
def test_accuracy_precomputed_cwt():
# Keep this specific random seed to match the precomputed Matlab result.
rstate = np.random.RandomState(1234)
# has to be improved
epsilon = 2e-15
epsilon32 = 1e-5
epsilon_psi = 1e-15
for wavelet in wavelets:
with warnings.catch_warnings():
warnings.simplefilter('ignore', FutureWarning)
w = pywt.ContinuousWavelet(wavelet)
w32 = pywt.ContinuousWavelet(wavelet,dtype=np.float32)
psi = _load_matlab_result_psi(wavelet)
_check_accuracy_psi(w, psi, wavelet, epsilon_psi)
for N in _get_data_sizes(w):
data = rstate.randn(N)
data32 = data.astype(np.float32)
scales_count = 0
for scales in _get_scales(w):
scales_count += 1
coefs = _load_matlab_result(data, wavelet, scales_count)
_check_accuracy(data, w, scales, coefs, wavelet, epsilon)
_check_accuracy(data32, w32, scales, coefs, wavelet, epsilon32)
def _compute_matlab_result(data, wavelet, scales, mlab):
""" Compute the result using MATLAB.
This function assumes that the Matlab variables `wavelet` and `data` have
already been set externally.
"""
mlab.set_variable('scales', scales)
mlab_code = ("coefs = cwt(data, scales, wavelet)")
res = mlab.run_code(mlab_code)
if not res['success']:
raise RuntimeError("Matlab failed to execute the provided code. "
"Check that the wavelet toolbox is installed.")
# need np.asarray because sometimes the output is a single float64
coefs = np.asarray(mlab.get_variable('coefs'))
return coefs
def _load_matlab_result(data, wavelet, scales):
""" Load the precomputed result.
"""
N = len(data)
coefs_key = '_'.join([str(scales), wavelet, str(N), 'coefs'])
if (coefs_key not in matlab_result_dict_cwt):
raise KeyError(
"Precompted Matlab result not found for wavelet: "
"{0}, mode: {1}, size: {2}".format(wavelet, scales, N))
coefs = matlab_result_dict_cwt[coefs_key]
return coefs
def _load_matlab_result_psi(wavelet):
""" Load the precomputed result.
"""
psi_key = '_'.join([wavelet, 'psi'])
if (psi_key not in matlab_result_dict_cwt):
raise KeyError(
"Precompted Matlab psi result not found for wavelet: "
"{0}}".format(wavelet))
psi = matlab_result_dict_cwt[psi_key]
return psi
def _check_accuracy(data, w, scales, coefs, wavelet, epsilon):
# PyWavelets result
coefs_pywt, freq = pywt.cwt(data, scales, w)
# coefs from Matlab are from R2012a which is missing the complex conjugate
# as shown in Eq. 2 of Torrence and Compo. We take the complex conjugate of
# the precomputed Matlab result to account for this.
coefs = np.conj(coefs)
# calculate error measures
err = coefs_pywt - coefs
rms = np.real(np.sqrt(np.mean(np.conj(err) * err)))
msg = ('[RMS > EPSILON] for Scale: %s, Wavelet: %s, '
'Length: %d, rms=%.3g' % (scales, wavelet, len(data), rms))
assert_(rms < epsilon, msg=msg)
def _check_accuracy_psi(w, psi, wavelet, epsilon):
# PyWavelets result
psi_pywt, x = w.wavefun(length=1024)
# calculate error measures
err = psi_pywt.flatten() - psi.flatten()
rms = np.real(np.sqrt(np.mean(np.conj(err) * err)))
msg = ('[RMS > EPSILON] for Wavelet: %s, '
'rms=%.3g' % (wavelet, rms))
assert_(rms < epsilon, msg=msg)

View file

@ -0,0 +1,109 @@
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
import numpy as np
from numpy.testing import assert_raises, assert_equal, assert_allclose
import pywt
def test_available_modes():
modes = ['zero', 'constant', 'symmetric', 'periodic', 'smooth',
'periodization', 'reflect', 'antisymmetric', 'antireflect']
assert_equal(pywt.Modes.modes, modes)
assert_equal(pywt.Modes.from_object('constant'), 2)
def test_invalid_modes():
x = np.arange(4)
assert_raises(ValueError, pywt.dwt, x, 'db2', 'unknown')
assert_raises(ValueError, pywt.dwt, x, 'db2', -1)
assert_raises(ValueError, pywt.dwt, x, 'db2', 9)
assert_raises(TypeError, pywt.dwt, x, 'db2', None)
assert_raises(ValueError, pywt.Modes.from_object, 'unknown')
assert_raises(ValueError, pywt.Modes.from_object, -1)
assert_raises(ValueError, pywt.Modes.from_object, 9)
assert_raises(TypeError, pywt.Modes.from_object, None)
def test_dwt_idwt_allmodes():
# Test that :func:`dwt` and :func:`idwt` can be performed using every mode
x = [1, 2, 1, 5, -1, 8, 4, 6]
dwt_results = {
'zero': ([-0.03467518, 1.73309178, 3.40612438, 6.32928585, 6.95094948],
[-0.12940952, -2.15599552, -5.95034847, -1.21545369,
-1.8625013]),
'constant': ([1.28480404, 1.73309178, 3.40612438, 6.32928585,
7.51935555],
[-0.48296291, -2.15599552, -5.95034847, -1.21545369,
0.25881905]),
'symmetric': ([1.76776695, 1.73309178, 3.40612438, 6.32928585,
7.77817459],
[-0.61237244, -2.15599552, -5.95034847, -1.21545369,
1.22474487]),
'reflect': ([2.12132034, 1.73309178, 3.40612438, 6.32928585,
6.81224877],
[-0.70710678, -2.15599552, -5.95034847, -1.21545369,
-2.38013939]),
'periodic': ([6.9162743, 1.73309178, 3.40612438, 6.32928585,
6.9162743],
[-1.99191082, -2.15599552, -5.95034847, -1.21545369,
-1.99191082]),
'smooth': ([-0.51763809, 1.73309178, 3.40612438, 6.32928585,
7.45000519],
[0, -2.15599552, -5.95034847, -1.21545369, 0]),
'periodization': ([4.053172, 3.05257099, 2.85381112, 8.42522221],
[0.18946869, 4.18258152, 4.33737503, 2.60428326]),
'antisymmetric': ([-1.83711731, 1.73309178, 3.40612438, 6.32928585,
6.12372436],
[0.353553391, -2.15599552, -5.95034847, -1.21545369,
-4.94974747]),
'antireflect': ([0.44828774, 1.73309178, 3.40612438, 6.32928585,
8.22646233],
[-0.25881905, -2.15599552, -5.95034847, -1.21545369,
2.89777748])
}
for mode in pywt.Modes.modes:
cA, cD = pywt.dwt(x, 'db2', mode)
assert_allclose(cA, dwt_results[mode][0], rtol=1e-7, atol=1e-8)
assert_allclose(cD, dwt_results[mode][1], rtol=1e-7, atol=1e-8)
assert_allclose(pywt.idwt(cA, cD, 'db2', mode), x, rtol=1e-10)
def test_dwt_short_input_allmodes():
# some test cases where the input is shorter than the DWT filter
x = [1, 3, 2]
wavelet = 'db2'
# manually pad each end by the filter size (4 for 'db2' used here)
padded_x = {'zero': [0, 0, 0, 0, 1, 3, 2, 0, 0, 0, 0],
'constant': [1, 1, 1, 1, 1, 3, 2, 2, 2, 2, 2],
'symmetric': [2, 2, 3, 1, 1, 3, 2, 2, 3, 1, 1],
'reflect': [1, 3, 2, 3, 1, 3, 2, 3, 1, 3, 2],
'periodic': [2, 1, 3, 2, 1, 3, 2, 1, 3, 2, 1],
'smooth': [-7, -5, -3, -1, 1, 3, 2, 1, 0, -1, -2],
'antisymmetric': [2, -2, -3, -1, 1, 3, 2, -2, -3, -1, 1],
'antireflect': [1, -1, 0, -1, 1, 3, 2, 1, 3, 5, 4],
}
for mode, xpad in padded_x.items():
# DWT of the manually padded array. will discard edges later so
# symmetric mode used here doesn't matter.
cApad, cDpad = pywt.dwt(xpad, wavelet, mode='symmetric')
# central region of the padded output (unaffected by mode )
expected_result = (cApad[2:-2], cDpad[2:-2])
cA, cD = pywt.dwt(x, wavelet, mode)
assert_allclose(cA, expected_result[0], rtol=1e-7, atol=1e-8)
assert_allclose(cD, expected_result[1], rtol=1e-7, atol=1e-8)
def test_default_mode():
# The default mode should be 'symmetric'
x = [1, 2, 1, 5, -1, 8, 4, 6]
cA, cD = pywt.dwt(x, 'db2')
cA2, cD2 = pywt.dwt(x, 'db2', mode='symmetric')
assert_allclose(cA, cA2)
assert_allclose(cD, cD2)
assert_allclose(pywt.idwt(cA, cD, 'db2'), x)

View file

@ -0,0 +1,443 @@
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
import numpy as np
from itertools import combinations
from numpy.testing import assert_allclose, assert_, assert_raises, assert_equal
import pywt
# Check that float32, float64, complex64, complex128 are preserved.
# Other real types get converted to float64.
# complex256 gets converted to complex128
dtypes_in = [np.int8, np.float16, np.float32, np.float64, np.complex64,
np.complex128]
dtypes_out = [np.float64, np.float32, np.float32, np.float64, np.complex64,
np.complex128]
# test complex256 as well if it is available
try:
dtypes_in += [np.complex256, ]
dtypes_out += [np.complex128, ]
except AttributeError:
pass
def test_dwtn_input():
# Array-like must be accepted
pywt.dwtn([1, 2, 3, 4], 'haar')
# Others must not
data = dict()
assert_raises(TypeError, pywt.dwtn, data, 'haar')
# Must be at least 1D
assert_raises(ValueError, pywt.dwtn, 2, 'haar')
def test_3D_reconstruct():
data = np.array([
[[0, 4, 1, 5, 1, 4],
[0, 5, 26, 3, 2, 1],
[5, 8, 2, 33, 4, 9],
[2, 5, 19, 4, 19, 1]],
[[1, 5, 1, 2, 3, 4],
[7, 12, 6, 52, 7, 8],
[2, 12, 3, 52, 6, 8],
[5, 2, 6, 78, 12, 2]]])
wavelet = pywt.Wavelet('haar')
for mode in pywt.Modes.modes:
d = pywt.dwtn(data, wavelet, mode=mode)
assert_allclose(data, pywt.idwtn(d, wavelet, mode=mode),
rtol=1e-13, atol=1e-13)
def test_dwdtn_idwtn_allwavelets():
rstate = np.random.RandomState(1234)
r = rstate.randn(16, 16)
# test 2D case only for all wavelet types
wavelist = pywt.wavelist()
if 'dmey' in wavelist:
wavelist.remove('dmey')
for wavelet in wavelist:
if wavelet in ['cmor', 'shan', 'fbsp']:
# skip these CWT families to avoid warnings
continue
if isinstance(pywt.DiscreteContinuousWavelet(wavelet), pywt.Wavelet):
for mode in pywt.Modes.modes:
coeffs = pywt.dwtn(r, wavelet, mode=mode)
assert_allclose(pywt.idwtn(coeffs, wavelet, mode=mode),
r, rtol=1e-7, atol=1e-7)
def test_stride():
wavelet = pywt.Wavelet('haar')
for dtype in ('float32', 'float64'):
data = np.array([[0, 4, 1, 5, 1, 4],
[0, 5, 6, 3, 2, 1],
[2, 5, 19, 4, 19, 1]],
dtype=dtype)
for mode in pywt.Modes.modes:
expected = pywt.dwtn(data, wavelet)
strided = np.ones((3, 12), dtype=data.dtype)
strided[::-1, ::2] = data
strided_dwtn = pywt.dwtn(strided[::-1, ::2], wavelet)
for key in expected.keys():
assert_allclose(strided_dwtn[key], expected[key])
def test_byte_offset():
wavelet = pywt.Wavelet('haar')
for dtype in ('float32', 'float64'):
data = np.array([[0, 4, 1, 5, 1, 4],
[0, 5, 6, 3, 2, 1],
[2, 5, 19, 4, 19, 1]],
dtype=dtype)
for mode in pywt.Modes.modes:
expected = pywt.dwtn(data, wavelet)
padded = np.ones((3, 6), dtype=np.dtype({'data': (data.dtype, 0),
'pad': ('byte', data.dtype.itemsize)},
align=True))
padded[:] = data
padded_dwtn = pywt.dwtn(padded['data'], wavelet)
for key in expected.keys():
assert_allclose(padded_dwtn[key], expected[key])
def test_3D_reconstruct_complex():
# All dimensions even length so `take` does not need to be specified
data = np.array([
[[0, 4, 1, 5, 1, 4],
[0, 5, 26, 3, 2, 1],
[5, 8, 2, 33, 4, 9],
[2, 5, 19, 4, 19, 1]],
[[1, 5, 1, 2, 3, 4],
[7, 12, 6, 52, 7, 8],
[2, 12, 3, 52, 6, 8],
[5, 2, 6, 78, 12, 2]]])
data = data + 1j
wavelet = pywt.Wavelet('haar')
d = pywt.dwtn(data, wavelet)
# idwtn creates even-length shapes (2x dwtn size)
original_shape = tuple([slice(None, s) for s in data.shape])
assert_allclose(data, pywt.idwtn(d, wavelet)[original_shape],
rtol=1e-13, atol=1e-13)
def test_idwtn_idwt2():
data = np.array([
[0, 4, 1, 5, 1, 4],
[0, 5, 6, 3, 2, 1],
[2, 5, 19, 4, 19, 1]])
wavelet = pywt.Wavelet('haar')
LL, (HL, LH, HH) = pywt.dwt2(data, wavelet)
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH}
for mode in pywt.Modes.modes:
assert_allclose(pywt.idwt2((LL, (HL, LH, HH)), wavelet, mode=mode),
pywt.idwtn(d, wavelet, mode=mode),
rtol=1e-14, atol=1e-14)
def test_idwtn_idwt2_complex():
data = np.array([
[0, 4, 1, 5, 1, 4],
[0, 5, 6, 3, 2, 1],
[2, 5, 19, 4, 19, 1]])
data = data + 1j
wavelet = pywt.Wavelet('haar')
LL, (HL, LH, HH) = pywt.dwt2(data, wavelet)
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH}
for mode in pywt.Modes.modes:
assert_allclose(pywt.idwt2((LL, (HL, LH, HH)), wavelet, mode=mode),
pywt.idwtn(d, wavelet, mode=mode),
rtol=1e-14, atol=1e-14)
def test_idwtn_missing():
# Test to confirm missing data behave as zeroes
data = np.array([
[0, 4, 1, 5, 1, 4],
[0, 5, 6, 3, 2, 1],
[2, 5, 19, 4, 19, 1]])
wavelet = pywt.Wavelet('haar')
coefs = pywt.dwtn(data, wavelet)
# No point removing zero, or all
for num_missing in range(1, len(coefs)):
for missing in combinations(coefs.keys(), num_missing):
missing_coefs = coefs.copy()
for key in missing:
del missing_coefs[key]
LL = missing_coefs.get('aa', None)
HL = missing_coefs.get('da', None)
LH = missing_coefs.get('ad', None)
HH = missing_coefs.get('dd', None)
assert_allclose(pywt.idwt2((LL, (HL, LH, HH)), wavelet),
pywt.idwtn(missing_coefs, 'haar'), atol=1e-15)
def test_idwtn_all_coeffs_None():
coefs = dict(aa=None, da=None, ad=None, dd=None)
assert_raises(ValueError, pywt.idwtn, coefs, 'haar')
def test_error_on_invalid_keys():
data = np.array([
[0, 4, 1, 5, 1, 4],
[0, 5, 6, 3, 2, 1],
[2, 5, 19, 4, 19, 1]])
wavelet = pywt.Wavelet('haar')
LL, (HL, LH, HH) = pywt.dwt2(data, wavelet)
# unexpected key
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH, 'ff': LH}
assert_raises(ValueError, pywt.idwtn, d, wavelet)
# mismatched key lengths
d = {'a': LL, 'da': HL, 'ad': LH, 'dd': HH}
assert_raises(ValueError, pywt.idwtn, d, wavelet)
def test_error_mismatched_size():
data = np.array([
[0, 4, 1, 5, 1, 4],
[0, 5, 6, 3, 2, 1],
[2, 5, 19, 4, 19, 1]])
wavelet = pywt.Wavelet('haar')
LL, (HL, LH, HH) = pywt.dwt2(data, wavelet)
# Pass/fail depends on first element being shorter than remaining ones so
# set 3/4 to an incorrect size to maximize chances. Order of dict items
# is random so may not trigger on every test run. Dict is constructed
# inside idwtn function so no use using an OrderedDict here.
LL = LL[:, :-1]
LH = LH[:, :-1]
HH = HH[:, :-1]
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH}
assert_raises(ValueError, pywt.idwtn, d, wavelet)
def test_dwt2_idwt2_dtypes():
wavelet = pywt.Wavelet('haar')
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
x = np.ones((4, 4), dtype=dt_in)
errmsg = "wrong dtype returned for {0} input".format(dt_in)
cA, (cH, cV, cD) = pywt.dwt2(x, wavelet)
assert_(cA.dtype == cH.dtype == cV.dtype == cD.dtype,
"dwt2: " + errmsg)
x_roundtrip = pywt.idwt2((cA, (cH, cV, cD)), wavelet)
assert_(x_roundtrip.dtype == dt_out, "idwt2: " + errmsg)
def test_dwtn_axes():
data = np.array([[0, 1, 2, 3],
[1, 1, 1, 1],
[1, 4, 2, 8]])
data = data + 1j*data # test with complex data
coefs = pywt.dwtn(data, 'haar', axes=(1,))
expected_a = list(map(lambda x: pywt.dwt(x, 'haar')[0], data))
assert_equal(coefs['a'], expected_a)
expected_d = list(map(lambda x: pywt.dwt(x, 'haar')[1], data))
assert_equal(coefs['d'], expected_d)
coefs = pywt.dwtn(data, 'haar', axes=(1, 1))
expected_aa = list(map(lambda x: pywt.dwt(x, 'haar')[0], expected_a))
assert_equal(coefs['aa'], expected_aa)
expected_ad = list(map(lambda x: pywt.dwt(x, 'haar')[1], expected_a))
assert_equal(coefs['ad'], expected_ad)
def test_idwtn_axes():
data = np.array([[0, 1, 2, 3],
[1, 1, 1, 1],
[1, 4, 2, 8]])
data = data + 1j*data # test with complex data
coefs = pywt.dwtn(data, 'haar', axes=(1, 1))
assert_allclose(pywt.idwtn(coefs, 'haar', axes=(1, 1)), data, atol=1e-14)
def test_idwt2_none_coeffs():
data = np.array([[0, 1, 2, 3],
[1, 1, 1, 1],
[1, 4, 2, 8]])
data = data + 1j*data # test with complex data
cA, (cH, cV, cD) = pywt.dwt2(data, 'haar', axes=(1, 1))
# verify setting coefficients to None is the same as zeroing them
cD = np.zeros_like(cD)
result_zeros = pywt.idwt2((cA, (cH, cV, cD)), 'haar', axes=(1, 1))
cD = None
result_none = pywt.idwt2((cA, (cH, cV, cD)), 'haar', axes=(1, 1))
assert_equal(result_zeros, result_none)
def test_idwtn_none_coeffs():
data = np.array([[0, 1, 2, 3],
[1, 1, 1, 1],
[1, 4, 2, 8]])
data = data + 1j*data # test with complex data
coefs = pywt.dwtn(data, 'haar', axes=(1, 1))
# verify setting coefficients to None is the same as zeroing them
coefs['dd'] = np.zeros_like(coefs['dd'])
result_zeros = pywt.idwtn(coefs, 'haar', axes=(1, 1))
coefs['dd'] = None
result_none = pywt.idwtn(coefs, 'haar', axes=(1, 1))
assert_equal(result_zeros, result_none)
def test_idwt2_axes():
data = np.array([[0, 1, 2, 3],
[1, 1, 1, 1],
[1, 4, 2, 8]])
coefs = pywt.dwt2(data, 'haar', axes=(1, 1))
assert_allclose(pywt.idwt2(coefs, 'haar', axes=(1, 1)), data, atol=1e-14)
# too many axes
assert_raises(ValueError, pywt.idwt2, coefs, 'haar', axes=(0, 1, 1))
def test_idwt2_axes_subsets():
data = np.array(np.random.standard_normal((4, 4, 4)))
# test all combinations of 2 out of 3 axes transformed
for axes in combinations((0, 1, 2), 2):
coefs = pywt.dwt2(data, 'haar', axes=axes)
assert_allclose(pywt.idwt2(coefs, 'haar', axes=axes), data, atol=1e-14)
def test_idwtn_axes_subsets():
data = np.array(np.random.standard_normal((4, 4, 4, 4)))
# test all combinations of 3 out of 4 axes transformed
for axes in combinations((0, 1, 2, 3), 3):
coefs = pywt.dwtn(data, 'haar', axes=axes)
assert_allclose(pywt.idwtn(coefs, 'haar', axes=axes), data, atol=1e-14)
def test_negative_axes():
data = np.array([[0, 1, 2, 3],
[1, 1, 1, 1],
[1, 4, 2, 8]])
coefs1 = pywt.dwtn(data, 'haar', axes=(1, 1))
coefs2 = pywt.dwtn(data, 'haar', axes=(-1, -1))
assert_equal(coefs1, coefs2)
rec1 = pywt.idwtn(coefs1, 'haar', axes=(1, 1))
rec2 = pywt.idwtn(coefs1, 'haar', axes=(-1, -1))
assert_equal(rec1, rec2)
def test_dwtn_idwtn_dtypes():
wavelet = pywt.Wavelet('haar')
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
x = np.ones((4, 4), dtype=dt_in)
errmsg = "wrong dtype returned for {0} input".format(dt_in)
coeffs = pywt.dwtn(x, wavelet)
for k, v in coeffs.items():
assert_(v.dtype == dt_out, "dwtn: " + errmsg)
x_roundtrip = pywt.idwtn(coeffs, wavelet)
assert_(x_roundtrip.dtype == dt_out, "idwtn: " + errmsg)
def test_idwtn_mixed_complex_dtype():
rstate = np.random.RandomState(0)
x = rstate.randn(8, 8, 8)
x = x + 1j*x
coeffs = pywt.dwtn(x, 'db2')
x_roundtrip = pywt.idwtn(coeffs, 'db2')
assert_allclose(x_roundtrip, x, rtol=1e-10)
# mismatched dtypes OK
coeffs['a' * x.ndim] = coeffs['a' * x.ndim].astype(np.complex64)
x_roundtrip2 = pywt.idwtn(coeffs, 'db2')
assert_allclose(x_roundtrip2, x, rtol=1e-7, atol=1e-7)
assert_(x_roundtrip2.dtype == np.complex128)
def test_idwt2_size_mismatch_error():
LL = np.zeros((6, 6))
LH = HL = HH = np.zeros((5, 5))
assert_raises(ValueError, pywt.idwt2, (LL, (LH, HL, HH)), wavelet='haar')
def test_dwt2_dimension_error():
data = np.ones(16)
wavelet = pywt.Wavelet('haar')
# wrong number of input dimensions
assert_raises(ValueError, pywt.dwt2, data, wavelet)
# too many axes
data2 = np.ones((8, 8))
assert_raises(ValueError, pywt.dwt2, data2, wavelet, axes=(0, 1, 1))
def test_per_axis_wavelets_and_modes():
# tests seperate wavelet and edge mode for each axis.
rstate = np.random.RandomState(1234)
data = rstate.randn(16, 16, 16)
# wavelet can be a string or wavelet object
wavelets = (pywt.Wavelet('haar'), 'sym2', 'db4')
# mode can be a string or a Modes enum
modes = ('symmetric', 'periodization',
pywt._extensions._pywt.Modes.reflect)
coefs = pywt.dwtn(data, wavelets, modes)
assert_allclose(pywt.idwtn(coefs, wavelets, modes), data, atol=1e-14)
coefs = pywt.dwtn(data, wavelets[:1], modes)
assert_allclose(pywt.idwtn(coefs, wavelets[:1], modes), data, atol=1e-14)
coefs = pywt.dwtn(data, wavelets, modes[:1])
assert_allclose(pywt.idwtn(coefs, wavelets, modes[:1]), data, atol=1e-14)
# length of wavelets or modes doesn't match the length of axes
assert_raises(ValueError, pywt.dwtn, data, wavelets[:2])
assert_raises(ValueError, pywt.dwtn, data, wavelets, mode=modes[:2])
assert_raises(ValueError, pywt.idwtn, coefs, wavelets[:2])
assert_raises(ValueError, pywt.idwtn, coefs, wavelets, mode=modes[:2])
# dwt2/idwt2 also support per-axis wavelets/modes
data2 = data[..., 0]
coefs2 = pywt.dwt2(data2, wavelets[:2], modes[:2])
assert_allclose(pywt.idwt2(coefs2, wavelets[:2], modes[:2]), data2,
atol=1e-14)
def test_error_on_continuous_wavelet():
# A ValueError is raised if a Continuous wavelet is selected
data = np.ones((16, 16))
for dec_fun, rec_fun in zip([pywt.dwt2, pywt.dwtn],
[pywt.idwt2, pywt.idwtn]):
for cwave in ['morl', pywt.DiscreteContinuousWavelet('morl')]:
assert_raises(ValueError, dec_fun, data, wavelet=cwave)
c = dec_fun(data, 'db1')
assert_raises(ValueError, rec_fun, c, wavelet=cwave)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,61 @@
#!/usr/bin/env python
"""
Verify DWT perfect reconstruction.
"""
from __future__ import division, print_function, absolute_import
import numpy as np
from numpy.testing import assert_
import pywt
def test_perfect_reconstruction():
families = ('db', 'sym', 'coif', 'bior', 'rbio')
wavelets = sum([pywt.wavelist(name) for name in families], [])
# list of mode names in pywt and matlab
modes = [('zero', 'zpd'),
('constant', 'sp0'),
('symmetric', 'sym'),
('periodic', 'ppd'),
('smooth', 'sp1'),
('periodization', 'per')]
dtypes = (np.float32, np.float64)
for wavelet in wavelets:
for pmode, mmode in modes:
for dt in dtypes:
check_reconstruction(pmode, mmode, wavelet, dt)
def check_reconstruction(pmode, mmode, wavelet, dtype):
data_size = list(range(2, 40)) + [100, 200, 500, 1000, 2000, 10000,
50000, 100000]
np.random.seed(12345)
# TODO: smoke testing - more failures for different seeds
if dtype == np.float32:
# was 3e-7 has to be lowered as db21, db29, db33, db35, coif14, coif16 were failing
epsilon = 6e-7
else:
epsilon = 5e-11
for N in data_size:
data = np.asarray(np.random.random(N), dtype)
# compute dwt coefficients
pa, pd = pywt.dwt(data, wavelet, pmode)
# compute reconstruction
rec = pywt.idwt(pa, pd, wavelet, pmode)
if len(data) % 2:
rec = rec[:len(data)]
rms_rec = np.sqrt(np.mean((data-rec)**2))
msg = ('[RMS_REC > EPSILON] for Mode: %s, Wavelet: %s, '
'Length: %d, rms=%.3g' % (pmode, wavelet, len(data), rms_rec))
assert_(rms_rec < epsilon, msg=msg)

View file

@ -0,0 +1,633 @@
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
import warnings
from copy import deepcopy
from itertools import combinations, permutations
import numpy as np
import pytest
from numpy.testing import (assert_allclose, assert_, assert_equal,
assert_raises, assert_array_equal, assert_warns)
import pywt
from pywt._extensions._swt import swt_axis
# Check that float32 and complex64 are preserved. Other real types get
# converted to float64.
dtypes_in = [np.int8, np.float16, np.float32, np.float64, np.complex64,
np.complex128]
dtypes_out = [np.float64, np.float32, np.float32, np.float64, np.complex64,
np.complex128]
# tolerances used in accuracy comparisons
tol_single = 1e-6
tol_double = 1e-13
####
# 1d multilevel swt tests
####
def test_swt_decomposition():
x = [3, 7, 1, 3, -2, 6, 4, 6]
db1 = pywt.Wavelet('db1')
atol = tol_double
(cA3, cD3), (cA2, cD2), (cA1, cD1) = pywt.swt(x, db1, level=3)
expected_cA1 = [7.07106781, 5.65685425, 2.82842712, 0.70710678,
2.82842712, 7.07106781, 7.07106781, 6.36396103]
assert_allclose(cA1, expected_cA1, rtol=1e-8, atol=atol)
expected_cD1 = [-2.82842712, 4.24264069, -1.41421356, 3.53553391,
-5.65685425, 1.41421356, -1.41421356, 2.12132034]
assert_allclose(cD1, expected_cD1, rtol=1e-8, atol=atol)
expected_cA2 = [7, 4.5, 4, 5.5, 7, 9.5, 10, 8.5]
assert_allclose(cA2, expected_cA2, rtol=tol_double, atol=atol)
expected_cD2 = [3, 3.5, 0, -4.5, -3, 0.5, 0, 0.5]
assert_allclose(cD2, expected_cD2, rtol=tol_double, atol=atol)
expected_cA3 = [9.89949494, ] * 8
assert_allclose(cA3, expected_cA3, rtol=1e-8, atol=atol)
expected_cD3 = [0.00000000, -3.53553391, -4.24264069, -2.12132034,
0.00000000, 3.53553391, 4.24264069, 2.12132034]
assert_allclose(cD3, expected_cD3, rtol=1e-8, atol=atol)
# level=1, start_level=1 decomposition should match level=2
res = pywt.swt(cA1, db1, level=1, start_level=1)
cA2, cD2 = res[0]
assert_allclose(cA2, expected_cA2, rtol=tol_double, atol=atol)
assert_allclose(cD2, expected_cD2, rtol=tol_double, atol=atol)
coeffs = pywt.swt(x, db1)
assert_(len(coeffs) == 3)
assert_(pywt.swt_max_level(len(x)), 3)
def test_swt_max_level():
# odd sized signal will warn about no levels of decomposition possible
assert_warns(UserWarning, pywt.swt_max_level, 11)
with warnings.catch_warnings():
warnings.simplefilter('ignore', UserWarning)
assert_equal(pywt.swt_max_level(11), 0)
# no warnings when >= 1 level of decomposition possible
assert_equal(pywt.swt_max_level(2), 1) # divisible by 2**1
assert_equal(pywt.swt_max_level(4*3), 2) # divisible by 2**2
assert_equal(pywt.swt_max_level(16), 4) # divisible by 2**4
assert_equal(pywt.swt_max_level(16*3), 4) # divisible by 2**4
def test_swt_axis():
x = [3, 7, 1, 3, -2, 6, 4, 6]
db1 = pywt.Wavelet('db1')
(cA2, cD2), (cA1, cD1) = pywt.swt(x, db1, level=2)
# test cases use 2D arrays based on tiling x along an axis and then
# calling swt along the other axis.
for order in ['C', 'F']:
# test SWT of 2D data along default axis (-1)
x_2d = np.asarray(x).reshape((1, -1))
x_2d = np.concatenate((x_2d, )*5, axis=0)
if order == 'C':
x_2d = np.ascontiguousarray(x_2d)
elif order == 'F':
x_2d = np.asfortranarray(x_2d)
(cA2_2d, cD2_2d), (cA1_2d, cD1_2d) = pywt.swt(x_2d, db1, level=2)
for c in [cA2_2d, cD2_2d, cA1_2d, cD1_2d]:
assert_(c.shape == x_2d.shape)
# each row should match the 1D result
for row in cA1_2d:
assert_array_equal(row, cA1)
for row in cA2_2d:
assert_array_equal(row, cA2)
for row in cD1_2d:
assert_array_equal(row, cD1)
for row in cD2_2d:
assert_array_equal(row, cD2)
# test SWT of 2D data along other axis (0)
x_2d = np.asarray(x).reshape((-1, 1))
x_2d = np.concatenate((x_2d, )*5, axis=1)
if order == 'C':
x_2d = np.ascontiguousarray(x_2d)
elif order == 'F':
x_2d = np.asfortranarray(x_2d)
(cA2_2d, cD2_2d), (cA1_2d, cD1_2d) = pywt.swt(x_2d, db1, level=2,
axis=0)
for c in [cA2_2d, cD2_2d, cA1_2d, cD1_2d]:
assert_(c.shape == x_2d.shape)
# each column should match the 1D result
for row in cA1_2d.transpose((1, 0)):
assert_array_equal(row, cA1)
for row in cA2_2d.transpose((1, 0)):
assert_array_equal(row, cA2)
for row in cD1_2d.transpose((1, 0)):
assert_array_equal(row, cD1)
for row in cD2_2d.transpose((1, 0)):
assert_array_equal(row, cD2)
# axis too large
assert_raises(ValueError, pywt.swt, x, db1, level=2, axis=5)
def test_swt_iswt_integration():
# This function performs a round-trip swt/iswt transform test on
# all available types of wavelets in PyWavelets - except the
# 'dmey' wavelet. The latter has been excluded because it does not
# produce very precise results. This is likely due to the fact
# that the 'dmey' wavelet is a discrete approximation of a
# continuous wavelet. All wavelets are tested up to 3 levels. The
# test validates neither swt or iswt as such, but it does ensure
# that they are each other's inverse.
max_level = 3
wavelets = pywt.wavelist(kind='discrete')
if 'dmey' in wavelets:
# The 'dmey' wavelet seems to be a bit special - disregard it for now
wavelets.remove('dmey')
for current_wavelet_str in wavelets:
current_wavelet = pywt.Wavelet(current_wavelet_str)
input_length_power = int(np.ceil(np.log2(max(
current_wavelet.dec_len,
current_wavelet.rec_len))))
input_length = 2**(input_length_power + max_level - 1)
X = np.arange(input_length)
for norm in [True, False]:
if norm and not current_wavelet.orthogonal:
# non-orthogonal wavelets to avoid warnings when norm=True
continue
for trim_approx in [True, False]:
coeffs = pywt.swt(X, current_wavelet, max_level,
trim_approx=trim_approx, norm=norm)
Y = pywt.iswt(coeffs, current_wavelet, norm=norm)
assert_allclose(Y, X, rtol=1e-5, atol=1e-7)
def test_swt_dtypes():
wavelet = pywt.Wavelet('haar')
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
errmsg = "wrong dtype returned for {0} input".format(dt_in)
# swt
x = np.ones(8, dtype=dt_in)
(cA2, cD2), (cA1, cD1) = pywt.swt(x, wavelet, level=2)
assert_(cA2.dtype == cD2.dtype == cA1.dtype == cD1.dtype == dt_out,
"swt: " + errmsg)
# swt2
x = np.ones((8, 8), dtype=dt_in)
cA, (cH, cV, cD) = pywt.swt2(x, wavelet, level=1)[0]
assert_(cA.dtype == cH.dtype == cV.dtype == cD.dtype == dt_out,
"swt2: " + errmsg)
def test_swt_roundtrip_dtypes():
# verify perfect reconstruction for all dtypes
rstate = np.random.RandomState(5)
wavelet = pywt.Wavelet('haar')
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
# swt, iswt
x = rstate.standard_normal((8, )).astype(dt_in)
c = pywt.swt(x, wavelet, level=2)
xr = pywt.iswt(c, wavelet)
assert_allclose(x, xr, rtol=1e-6, atol=1e-7)
# swt2, iswt2
x = rstate.standard_normal((8, 8)).astype(dt_in)
c = pywt.swt2(x, wavelet, level=2)
xr = pywt.iswt2(c, wavelet)
assert_allclose(x, xr, rtol=1e-6, atol=1e-7)
def test_swt_default_level_by_axis():
# make sure default number of levels matches the max level along the axis
wav = 'db2'
x = np.ones((2**3, 2**4, 2**5))
for axis in (0, 1, 2):
sdec = pywt.swt(x, wav, level=None, start_level=0, axis=axis)
assert_equal(len(sdec), pywt.swt_max_level(x.shape[axis]))
def test_swt2_ndim_error():
x = np.ones(8)
with warnings.catch_warnings():
warnings.simplefilter('ignore', FutureWarning)
assert_raises(ValueError, pywt.swt2, x, 'haar', level=1)
@pytest.mark.slow
def test_swt2_iswt2_integration(wavelets=None):
# This function performs a round-trip swt2/iswt2 transform test on
# all available types of wavelets in PyWavelets - except the
# 'dmey' wavelet. The latter has been excluded because it does not
# produce very precise results. This is likely due to the fact
# that the 'dmey' wavelet is a discrete approximation of a
# continuous wavelet. All wavelets are tested up to 3 levels. The
# test validates neither swt2 or iswt2 as such, but it does ensure
# that they are each other's inverse.
max_level = 3
if wavelets is None:
wavelets = pywt.wavelist(kind='discrete')
if 'dmey' in wavelets:
# The 'dmey' wavelet is a special case - disregard it for now
wavelets.remove('dmey')
for current_wavelet_str in wavelets:
current_wavelet = pywt.Wavelet(current_wavelet_str)
input_length_power = int(np.ceil(np.log2(max(
current_wavelet.dec_len,
current_wavelet.rec_len))))
input_length = 2**(input_length_power + max_level - 1)
X = np.arange(input_length**2).reshape(input_length, input_length)
for norm in [True, False]:
if norm and not current_wavelet.orthogonal:
# non-orthogonal wavelets to avoid warnings when norm=True
continue
for trim_approx in [True, False]:
coeffs = pywt.swt2(X, current_wavelet, max_level,
trim_approx=trim_approx, norm=norm)
Y = pywt.iswt2(coeffs, current_wavelet, norm=norm)
assert_allclose(Y, X, rtol=1e-5, atol=1e-5)
def test_swt2_iswt2_quick():
test_swt2_iswt2_integration(wavelets=['db1', ])
def test_swt2_iswt2_non_square(wavelets=None):
for nrows in [8, 16, 48]:
X = np.arange(nrows*32).reshape(nrows, 32)
current_wavelet = 'db1'
with warnings.catch_warnings():
warnings.simplefilter('ignore', FutureWarning)
coeffs = pywt.swt2(X, current_wavelet, level=2)
Y = pywt.iswt2(coeffs, current_wavelet)
assert_allclose(Y, X, rtol=tol_single, atol=tol_single)
def test_swt2_axes():
atol = 1e-14
current_wavelet = pywt.Wavelet('db2')
input_length_power = int(np.ceil(np.log2(max(
current_wavelet.dec_len,
current_wavelet.rec_len))))
input_length = 2**(input_length_power)
X = np.arange(input_length**2).reshape(input_length, input_length)
(cA1, (cH1, cV1, cD1)) = pywt.swt2(X, current_wavelet, level=1)[0]
# opposite order
(cA2, (cH2, cV2, cD2)) = pywt.swt2(X, current_wavelet, level=1,
axes=(1, 0))[0]
assert_allclose(cA1, cA2, atol=atol)
assert_allclose(cH1, cV2, atol=atol)
assert_allclose(cV1, cH2, atol=atol)
assert_allclose(cD1, cD2, atol=atol)
# duplicate axes not allowed
assert_raises(ValueError, pywt.swt2, X, current_wavelet, 1,
axes=(0, 0))
# too few axes
assert_raises(ValueError, pywt.swt2, X, current_wavelet, 1, axes=(0, ))
def test_iswt2_2d_only():
# iswt2 is not currently compatible with data that is not 2D
x_3d = np.ones((4, 4, 4))
c = pywt.swt2(x_3d, 'haar', level=1)
assert_raises(ValueError, pywt.iswt2, c, 'haar')
def test_swtn_axes():
atol = 1e-14
current_wavelet = pywt.Wavelet('db2')
input_length_power = int(np.ceil(np.log2(max(
current_wavelet.dec_len,
current_wavelet.rec_len))))
input_length = 2**(input_length_power)
X = np.arange(input_length**2).reshape(input_length, input_length)
coeffs = pywt.swtn(X, current_wavelet, level=1, axes=None)[0]
# opposite order
coeffs2 = pywt.swtn(X, current_wavelet, level=1, axes=(1, 0))[0]
assert_allclose(coeffs['aa'], coeffs2['aa'], atol=atol)
assert_allclose(coeffs['ad'], coeffs2['da'], atol=atol)
assert_allclose(coeffs['da'], coeffs2['ad'], atol=atol)
assert_allclose(coeffs['dd'], coeffs2['dd'], atol=atol)
# 0-level transform
empty = pywt.swtn(X, current_wavelet, level=0)
assert_equal(empty, [])
# duplicate axes not allowed
assert_raises(ValueError, pywt.swtn, X, current_wavelet, 1, axes=(0, 0))
# data.ndim = 0
assert_raises(ValueError, pywt.swtn, np.asarray([]), current_wavelet, 1)
# start_level too large
assert_raises(ValueError, pywt.swtn, X, current_wavelet,
level=1, start_level=2)
# level < 1 in swt_axis call
assert_raises(ValueError, swt_axis, X, current_wavelet, level=0,
start_level=0)
# odd-sized data not allowed
assert_raises(ValueError, swt_axis, X[:-1, :], current_wavelet, level=0,
start_level=0, axis=0)
@pytest.mark.slow
def test_swtn_iswtn_integration(wavelets=None):
# This function performs a round-trip swtn/iswtn transform for various
# possible combinations of:
# 1.) 1 out of 2 axes of a 2D array
# 2.) 2 out of 3 axes of a 3D array
#
# To keep test time down, only wavelets of length <= 8 are run.
#
# This test does not validate swtn or iswtn individually, but only
# confirms that iswtn yields an (almost) perfect reconstruction of swtn.
max_level = 3
if wavelets is None:
wavelets = pywt.wavelist(kind='discrete')
if 'dmey' in wavelets:
# The 'dmey' wavelet is a special case - disregard it for now
wavelets.remove('dmey')
for ndim_transform in range(1, 3):
ndim = ndim_transform + 1
for axes in combinations(range(ndim), ndim_transform):
for current_wavelet_str in wavelets:
wav = pywt.Wavelet(current_wavelet_str)
if wav.dec_len > 8:
continue # avoid excessive test duration
input_length_power = int(np.ceil(np.log2(max(
wav.dec_len,
wav.rec_len))))
N = 2**(input_length_power + max_level - 1)
X = np.arange(N**ndim).reshape((N, )*ndim)
for norm in [True, False]:
if norm and not wav.orthogonal:
# non-orthogonal wavelets to avoid warnings
continue
for trim_approx in [True, False]:
coeffs = pywt.swtn(X, wav, max_level, axes=axes,
trim_approx=trim_approx, norm=norm)
coeffs_copy = deepcopy(coeffs)
Y = pywt.iswtn(coeffs, wav, axes=axes, norm=norm)
assert_allclose(Y, X, rtol=1e-5, atol=1e-5)
# verify the inverse transform didn't modify any coeffs
for c, c2 in zip(coeffs, coeffs_copy):
for k, v in c.items():
assert_array_equal(c2[k], v)
def test_swtn_iswtn_quick():
test_swtn_iswtn_integration(wavelets=['db1', ])
def test_iswtn_errors():
x = np.arange(8**3).reshape(8, 8, 8)
max_level = 2
axes = (0, 1)
w = pywt.Wavelet('db1')
coeffs = pywt.swtn(x, w, max_level, axes=axes)
# more axes than dimensions transformed
assert_raises(ValueError, pywt.iswtn, coeffs, w, axes=(0, 1, 2))
# duplicate axes not allowed
assert_raises(ValueError, pywt.iswtn, coeffs, w, axes=(0, 0))
# mismatched coefficient size
coeffs[0]['da'] = coeffs[0]['da'][:-1, :]
assert_raises(RuntimeError, pywt.iswtn, coeffs, w, axes=axes)
def test_swtn_iswtn_unique_shape_per_axis():
# test case for gh-460
_shape = (1, 48, 32) # unique shape per axis
wav = 'sym2'
max_level = 3
rstate = np.random.RandomState(0)
for shape in permutations(_shape):
# transform only along the non-singleton axes
axes = [ax for ax, s in enumerate(shape) if s != 1]
x = rstate.standard_normal(shape)
c = pywt.swtn(x, wav, max_level, axes=axes)
r = pywt.iswtn(c, wav, axes=axes)
assert_allclose(x, r, rtol=1e-10, atol=1e-10)
def test_per_axis_wavelets():
# tests seperate wavelet for each axis.
rstate = np.random.RandomState(1234)
data = rstate.randn(16, 16, 16)
level = 3
# wavelet can be a string or wavelet object
wavelets = (pywt.Wavelet('haar'), 'sym2', 'db4')
coefs = pywt.swtn(data, wavelets, level=level)
assert_allclose(pywt.iswtn(coefs, wavelets), data, atol=1e-14)
# 1-tuple also okay
coefs = pywt.swtn(data, wavelets[:1], level=level)
assert_allclose(pywt.iswtn(coefs, wavelets[:1]), data, atol=1e-14)
# length of wavelets doesn't match the length of axes
assert_raises(ValueError, pywt.swtn, data, wavelets[:2], level)
assert_raises(ValueError, pywt.iswtn, coefs, wavelets[:2])
with warnings.catch_warnings():
warnings.simplefilter('ignore', FutureWarning)
# swt2/iswt2 also support per-axis wavelets/modes
data2 = data[..., 0]
coefs2 = pywt.swt2(data2, wavelets[:2], level)
assert_allclose(pywt.iswt2(coefs2, wavelets[:2]), data2, atol=1e-14)
def test_error_on_continuous_wavelet():
# A ValueError is raised if a Continuous wavelet is selected
data = np.ones((16, 16))
for dec_func, rec_func in zip([pywt.swt, pywt.swt2, pywt.swtn],
[pywt.iswt, pywt.iswt2, pywt.iswtn]):
for cwave in ['morl', pywt.DiscreteContinuousWavelet('morl')]:
assert_raises(ValueError, dec_func, data, wavelet=cwave,
level=3)
c = dec_func(data, 'db1', level=3)
assert_raises(ValueError, rec_func, c, wavelet=cwave)
def test_iswt_mixed_dtypes():
# Mixed precision inputs give double precision output
x_real = np.arange(16).astype(np.float64)
x_complex = x_real + 1j*x_real
wav = 'sym2'
for dtype1, dtype2 in [(np.float64, np.float32),
(np.float32, np.float64),
(np.float16, np.float64),
(np.complex128, np.complex64),
(np.complex64, np.complex128)]:
if dtype1 in [np.complex64, np.complex128]:
x = x_complex
output_dtype = np.complex128
else:
x = x_real
output_dtype = np.float64
coeffs = pywt.swt(x, wav, 2)
# different precision for the approximation coefficients
coeffs[0] = [coeffs[0][0].astype(dtype1),
coeffs[0][1].astype(dtype2)]
y = pywt.iswt(coeffs, wav)
assert_equal(output_dtype, y.dtype)
assert_allclose(y, x, rtol=1e-3, atol=1e-3)
def test_iswt2_mixed_dtypes():
# Mixed precision inputs give double precision output
rstate = np.random.RandomState(0)
x_real = rstate.randn(8, 8)
x_complex = x_real + 1j*x_real
wav = 'sym2'
for dtype1, dtype2 in [(np.float64, np.float32),
(np.float32, np.float64),
(np.float16, np.float64),
(np.complex128, np.complex64),
(np.complex64, np.complex128)]:
if dtype1 in [np.complex64, np.complex128]:
x = x_complex
output_dtype = np.complex128
else:
x = x_real
output_dtype = np.float64
coeffs = pywt.swt2(x, wav, 2)
# different precision for the approximation coefficients
coeffs[0] = [coeffs[0][0].astype(dtype1),
tuple([c.astype(dtype2) for c in coeffs[0][1]])]
y = pywt.iswt2(coeffs, wav)
assert_equal(output_dtype, y.dtype)
assert_allclose(y, x, rtol=1e-3, atol=1e-3)
def test_iswtn_mixed_dtypes():
# Mixed precision inputs give double precision output
rstate = np.random.RandomState(0)
x_real = rstate.randn(8, 8, 8)
x_complex = x_real + 1j*x_real
wav = 'sym2'
for dtype1, dtype2 in [(np.float64, np.float32),
(np.float32, np.float64),
(np.float16, np.float64),
(np.complex128, np.complex64),
(np.complex64, np.complex128)]:
if dtype1 in [np.complex64, np.complex128]:
x = x_complex
output_dtype = np.complex128
else:
x = x_real
output_dtype = np.float64
coeffs = pywt.swtn(x, wav, 2)
# different precision for the approximation coefficients
a = coeffs[0].pop('a' * x.ndim)
a = a.astype(dtype1)
coeffs[0] = {k: c.astype(dtype2) for k, c in coeffs[0].items()}
coeffs[0]['a' * x.ndim] = a
y = pywt.iswtn(coeffs, wav)
assert_equal(output_dtype, y.dtype)
assert_allclose(y, x, rtol=1e-3, atol=1e-3)
def test_swt_zero_size_axes():
# raise on empty input array
assert_raises(ValueError, pywt.swt, [], 'db2')
# >1D case uses a different code path so check there as well
x = np.ones((1, 4))[0:0, :] # 2D with a size zero axis
assert_raises(ValueError, pywt.swtn, x, 'db2', level=1, axes=(0,))
def test_swt_variance_and_energy_preservation():
"""Verify that the 1D SWT partitions variance among the coefficients."""
# When norm is True and the wavelet is orthogonal, the sum of the
# variances of the coefficients should equal the variance of the signal.
wav = 'db2'
rstate = np.random.RandomState(5)
x = rstate.randn(256)
coeffs = pywt.swt(x, wav, trim_approx=True, norm=True)
variances = [np.var(c) for c in coeffs]
assert_allclose(np.sum(variances), np.var(x))
# also verify L2-norm energy preservation property
assert_allclose(np.linalg.norm(x),
np.linalg.norm(np.concatenate(coeffs)))
# non-orthogonal wavelet with norm=True raises a warning
assert_warns(UserWarning, pywt.swt, x, 'bior2.2', norm=True)
def test_swt2_variance_and_energy_preservation():
"""Verify that the 2D SWT partitions variance among the coefficients."""
# When norm is True and the wavelet is orthogonal, the sum of the
# variances of the coefficients should equal the variance of the signal.
wav = 'db2'
rstate = np.random.RandomState(5)
x = rstate.randn(64, 64)
coeffs = pywt.swt2(x, wav, level=4, trim_approx=True, norm=True)
coeff_list = [coeffs[0].ravel()]
for d in coeffs[1:]:
for v in d:
coeff_list.append(v.ravel())
variances = [np.var(v) for v in coeff_list]
assert_allclose(np.sum(variances), np.var(x))
# also verify L2-norm energy preservation property
assert_allclose(np.linalg.norm(x),
np.linalg.norm(np.concatenate(coeff_list)))
# non-orthogonal wavelet with norm=True raises a warning
assert_warns(UserWarning, pywt.swt2, x, 'bior2.2', level=4, norm=True)
def test_swtn_variance_and_energy_preservation():
"""Verify that the nD SWT partitions variance among the coefficients."""
# When norm is True and the wavelet is orthogonal, the sum of the
# variances of the coefficients should equal the variance of the signal.
wav = 'db2'
rstate = np.random.RandomState(5)
x = rstate.randn(64, 64)
coeffs = pywt.swtn(x, wav, level=4, trim_approx=True, norm=True)
coeff_list = [coeffs[0].ravel()]
for d in coeffs[1:]:
for k, v in d.items():
coeff_list.append(v.ravel())
variances = [np.var(v) for v in coeff_list]
assert_allclose(np.sum(variances), np.var(x))
# also verify L2-norm energy preservation property
assert_allclose(np.linalg.norm(x),
np.linalg.norm(np.concatenate(coeff_list)))
# non-orthogonal wavelet with norm=True raises a warning
assert_warns(UserWarning, pywt.swtn, x, 'bior2.2', level=4, norm=True)
def test_swt_ravel_and_unravel():
# When trim_approx=True, all swt functions can user pywt.ravel_coeffs
for ndim, _swt, _iswt, ravel_type in [
(1, pywt.swt, pywt.iswt, 'swt'),
(2, pywt.swt2, pywt.iswt2, 'swt2'),
(3, pywt.swtn, pywt.iswtn, 'swtn')]:
x = np.ones((16, ) * ndim)
c = _swt(x, 'sym2', level=3, trim_approx=True)
arr, slices, shapes = pywt.ravel_coeffs(c)
c = pywt.unravel_coeffs(arr, slices, shapes, output_format=ravel_type)
r = _iswt(c, 'sym2')
assert_allclose(x, r)

View file

@ -0,0 +1,169 @@
from __future__ import division, print_function, absolute_import
import numpy as np
from numpy.testing import assert_allclose, assert_raises, assert_, assert_equal
import pywt
float_dtypes = [np.float32, np.float64, np.complex64, np.complex128]
real_dtypes = [np.float32, np.float64]
def _sign(x):
# Matlab-like sign function (numpy uses a different convention).
return x / np.abs(x)
def _soft(x, thresh):
"""soft thresholding supporting complex values.
Notes
-----
This version is not robust to zeros in x.
"""
return _sign(x) * np.maximum(np.abs(x) - thresh, 0)
def test_threshold():
data = np.linspace(1, 4, 7)
# soft
soft_result = [0., 0., 0., 0.5, 1., 1.5, 2.]
assert_allclose(pywt.threshold(data, 2, 'soft'),
np.array(soft_result), rtol=1e-12)
assert_allclose(pywt.threshold(-data, 2, 'soft'),
-np.array(soft_result), rtol=1e-12)
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'soft'),
[[0, 1]] * 2, rtol=1e-12)
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'soft'),
[[0, 0]] * 2, rtol=1e-12)
# soft thresholding complex values
assert_allclose(pywt.threshold([[1j, 2j]] * 2, 1, 'soft'),
[[0j, 1j]] * 2, rtol=1e-12)
assert_allclose(pywt.threshold([[1+1j, 2+2j]] * 2, 6, 'soft'),
[[0, 0]] * 2, rtol=1e-12)
complex_data = [[1+2j, 2+2j]]*2
for thresh in [1, 2]:
assert_allclose(pywt.threshold(complex_data, thresh, 'soft'),
_soft(complex_data, thresh), rtol=1e-12)
# test soft thresholding with non-default substitute argument
s = 5
assert_allclose(pywt.threshold([[1j, 2]] * 2, 1.5, 'soft', substitute=s),
[[s, 0.5]] * 2, rtol=1e-12)
# soft: no divide by zero warnings when input contains zeros
assert_allclose(pywt.threshold(np.zeros(16), 2, 'soft'),
np.zeros(16), rtol=1e-12)
# hard
hard_result = [0., 0., 2., 2.5, 3., 3.5, 4.]
assert_allclose(pywt.threshold(data, 2, 'hard'),
np.array(hard_result), rtol=1e-12)
assert_allclose(pywt.threshold(-data, 2, 'hard'),
-np.array(hard_result), rtol=1e-12)
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'hard'),
[[1, 2]] * 2, rtol=1e-12)
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'hard'),
[[0, 2]] * 2, rtol=1e-12)
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'hard', substitute=s),
[[s, 2]] * 2, rtol=1e-12)
assert_allclose(pywt.threshold([[1+1j, 2+2j]] * 2, 2, 'hard'),
[[0, 2+2j]] * 2, rtol=1e-12)
# greater
greater_result = [0., 0., 2., 2.5, 3., 3.5, 4.]
assert_allclose(pywt.threshold(data, 2, 'greater'),
np.array(greater_result), rtol=1e-12)
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'greater'),
[[1, 2]] * 2, rtol=1e-12)
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'greater'),
[[0, 2]] * 2, rtol=1e-12)
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'greater', substitute=s),
[[s, 2]] * 2, rtol=1e-12)
# greater doesn't allow complex-valued inputs
assert_raises(ValueError, pywt.threshold, [1j, 2j], 2, 'greater')
# less
assert_allclose(pywt.threshold(data, 2, 'less'),
np.array([1., 1.5, 2., 0., 0., 0., 0.]), rtol=1e-12)
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'less'),
[[1, 0]] * 2, rtol=1e-12)
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'less', substitute=s),
[[1, s]] * 2, rtol=1e-12)
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'less'),
[[1, 2]] * 2, rtol=1e-12)
# less doesn't allow complex-valued inputs
assert_raises(ValueError, pywt.threshold, [1j, 2j], 2, 'less')
# invalid
assert_raises(ValueError, pywt.threshold, data, 2, 'foo')
def test_nonnegative_garotte():
thresh = 0.3
data_real = np.linspace(-1, 1, 100)
for dtype in float_dtypes:
if dtype in real_dtypes:
data = np.asarray(data_real, dtype=dtype)
else:
data = np.asarray(data_real + 0.1j, dtype=dtype)
d_hard = pywt.threshold(data, thresh, 'hard')
d_soft = pywt.threshold(data, thresh, 'soft')
d_garotte = pywt.threshold(data, thresh, 'garotte')
# check dtypes
assert_equal(d_hard.dtype, data.dtype)
assert_equal(d_soft.dtype, data.dtype)
assert_equal(d_garotte.dtype, data.dtype)
# values < threshold are zero
lt = np.where(np.abs(data) < thresh)
assert_(np.all(d_garotte[lt] == 0))
# values > than the threshold are intermediate between soft and hard
gt = np.where(np.abs(data) > thresh)
gt_abs_garotte = np.abs(d_garotte[gt])
assert_(np.all(gt_abs_garotte < np.abs(d_hard[gt])))
assert_(np.all(gt_abs_garotte > np.abs(d_soft[gt])))
def test_threshold_firm():
thresh = 0.2
thresh2 = 3 * thresh
data_real = np.linspace(-1, 1, 100)
for dtype in float_dtypes:
if dtype in real_dtypes:
data = np.asarray(data_real, dtype=dtype)
else:
data = np.asarray(data_real + 0.1j, dtype=dtype)
if data.real.dtype == np.float32:
rtol = atol = 1e-6
else:
rtol = atol = 1e-14
d_hard = pywt.threshold(data, thresh, 'hard')
d_soft = pywt.threshold(data, thresh, 'soft')
d_firm = pywt.threshold_firm(data, thresh, thresh2)
# check dtypes
assert_equal(d_hard.dtype, data.dtype)
assert_equal(d_soft.dtype, data.dtype)
assert_equal(d_firm.dtype, data.dtype)
# values < threshold are zero
lt = np.where(np.abs(data) < thresh)
assert_(np.all(d_firm[lt] == 0))
# values > than the threshold are equal to hard-thresholding
gt = np.where(np.abs(data) >= thresh2)
assert_allclose(np.abs(d_hard[gt]), np.abs(d_firm[gt]),
rtol=rtol, atol=atol)
# other values are intermediate between soft and hard thresholding
mt = np.where(np.logical_and(np.abs(data) > thresh,
np.abs(data) < thresh2))
mt_abs_firm = np.abs(d_firm[mt])
assert_(np.all(mt_abs_firm < np.abs(d_hard[mt])))
assert_(np.all(mt_abs_firm > np.abs(d_soft[mt])))

View file

@ -0,0 +1,266 @@
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
import numpy as np
from numpy.testing import assert_allclose, assert_
import pywt
def test_wavelet_properties():
w = pywt.Wavelet('db3')
# Name
assert_(w.name == 'db3')
assert_(w.short_family_name == 'db')
assert_(w.family_name, 'Daubechies')
# String representation
fields = ('Family name', 'Short name', 'Filters length', 'Orthogonal',
'Biorthogonal', 'Symmetry')
for field in fields:
assert_(field in str(w))
# Filter coefficients
dec_lo = [0.03522629188210, -0.08544127388224, -0.13501102001039,
0.45987750211933, 0.80689150931334, 0.33267055295096]
dec_hi = [-0.33267055295096, 0.80689150931334, -0.45987750211933,
-0.13501102001039, 0.08544127388224, 0.03522629188210]
rec_lo = [0.33267055295096, 0.80689150931334, 0.45987750211933,
-0.13501102001039, -0.08544127388224, 0.03522629188210]
rec_hi = [0.03522629188210, 0.08544127388224, -0.13501102001039,
-0.45987750211933, 0.80689150931334, -0.33267055295096]
assert_allclose(w.dec_lo, dec_lo)
assert_allclose(w.dec_hi, dec_hi)
assert_allclose(w.rec_lo, rec_lo)
assert_allclose(w.rec_hi, rec_hi)
assert_(len(w.filter_bank) == 4)
# Orthogonality
assert_(w.orthogonal)
assert_(w.biorthogonal)
# Symmetry
assert_(w.symmetry)
# Vanishing moments
assert_(w.vanishing_moments_phi == 0)
assert_(w.vanishing_moments_psi == 3)
def test_wavelet_coefficients():
families = ('db', 'sym', 'coif', 'bior', 'rbio')
wavelets = sum([pywt.wavelist(name) for name in families], [])
for wavelet in wavelets:
if (pywt.Wavelet(wavelet).orthogonal):
check_coefficients_orthogonal(wavelet)
elif(pywt.Wavelet(wavelet).biorthogonal):
check_coefficients_biorthogonal(wavelet)
else:
check_coefficients(wavelet)
def check_coefficients_orthogonal(wavelet):
epsilon = 5e-11
level = 5
w = pywt.Wavelet(wavelet)
phi, psi, x = w.wavefun(level=level)
# Lowpass filter coefficients sum to sqrt2
res = np.sum(w.dec_lo)-np.sqrt(2)
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
# sum even coef = sum odd coef = 1 / sqrt(2)
res = np.sum(w.dec_lo[::2])-1./np.sqrt(2)
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
res = np.sum(w.dec_lo[1::2])-1./np.sqrt(2)
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
# Highpass filter coefficients sum to zero
res = np.sum(w.dec_hi)
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
# Scaling function integrates to unity
res = np.sum(phi) - 2**level
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
# Wavelet function is orthogonal to the scaling function at the same scale
res = np.sum(phi*psi)
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
# The lowpass and highpass filter coefficients are orthogonal
res = np.sum(np.array(w.dec_lo)*np.array(w.dec_hi))
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
def check_coefficients_biorthogonal(wavelet):
epsilon = 5e-11
level = 5
w = pywt.Wavelet(wavelet)
phi_d, psi_d, phi_r, psi_r, x = w.wavefun(level=level)
# Lowpass filter coefficients sum to sqrt2
res = np.sum(w.dec_lo)-np.sqrt(2)
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
# sum even coef = sum odd coef = 1 / sqrt(2)
res = np.sum(w.dec_lo[::2])-1./np.sqrt(2)
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
res = np.sum(w.dec_lo[1::2])-1./np.sqrt(2)
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
# Highpass filter coefficients sum to zero
res = np.sum(w.dec_hi)
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
# Scaling function integrates to unity
res = np.sum(phi_d) - 2**level
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
res = np.sum(phi_r) - 2**level
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
def check_coefficients(wavelet):
epsilon = 5e-11
level = 10
w = pywt.Wavelet(wavelet)
# Lowpass filter coefficients sum to sqrt2
res = np.sum(w.dec_lo)-np.sqrt(2)
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
# sum even coef = sum odd coef = 1 / sqrt(2)
res = np.sum(w.dec_lo[::2])-1./np.sqrt(2)
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
res = np.sum(w.dec_lo[1::2])-1./np.sqrt(2)
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
# Highpass filter coefficients sum to zero
res = np.sum(w.dec_hi)
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
class _CustomHaarFilterBank(object):
@property
def filter_bank(self):
val = np.sqrt(2) / 2
return ([val]*2, [-val, val], [val]*2, [val, -val])
def test_custom_wavelet():
haar_custom1 = pywt.Wavelet('Custom Haar Wavelet',
filter_bank=_CustomHaarFilterBank())
haar_custom1.orthogonal = True
haar_custom1.biorthogonal = True
val = np.sqrt(2) / 2
filter_bank = ([val]*2, [-val, val], [val]*2, [val, -val])
haar_custom2 = pywt.Wavelet('Custom Haar Wavelet',
filter_bank=filter_bank)
# check expected default wavelet properties
assert_(~haar_custom2.orthogonal)
assert_(~haar_custom2.biorthogonal)
assert_(haar_custom2.symmetry == 'unknown')
assert_(haar_custom2.family_name == '')
assert_(haar_custom2.short_family_name == '')
assert_(haar_custom2.vanishing_moments_phi == 0)
assert_(haar_custom2.vanishing_moments_psi == 0)
# Some properties can be set by the user
haar_custom2.orthogonal = True
haar_custom2.biorthogonal = True
def test_wavefun_sym3():
w = pywt.Wavelet('sym3')
# sym3 is an orthogonal wavelet, so 3 outputs from wavefun
phi, psi, x = w.wavefun(level=3)
assert_(phi.size == 41)
assert_(psi.size == 41)
assert_(x.size == 41)
assert_allclose(x, np.linspace(0, 5, num=x.size))
phi_expect = np.array([0.00000000e+00, 1.04132926e-01, 2.52574126e-01,
3.96525521e-01, 5.70356539e-01, 7.18934305e-01,
8.70293448e-01, 1.05363620e+00, 1.24921722e+00,
1.15296888e+00, 9.41669683e-01, 7.55875887e-01,
4.96118565e-01, 3.28293151e-01, 1.67624969e-01,
-7.33690312e-02, -3.35452855e-01, -3.31221131e-01,
-2.32061503e-01, -1.66854239e-01, -4.34091324e-02,
-2.86152390e-02, -3.63563035e-02, 2.06034491e-02,
8.30280254e-02, 7.17779073e-02, 3.85914311e-02,
1.47527100e-02, -2.31896077e-02, -1.86122172e-02,
-1.56211329e-03, -8.70615088e-04, 3.20760857e-03,
2.34142153e-03, -7.73737194e-04, -2.99879354e-04,
1.23636238e-04, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00])
psi_expect = np.array([0.00000000e+00, 1.10265752e-02, 2.67449277e-02,
4.19878574e-02, 6.03947231e-02, 7.61275365e-02,
9.21548684e-02, 1.11568926e-01, 1.32278887e-01,
6.45829680e-02, -3.97635130e-02, -1.38929884e-01,
-2.62428322e-01, -3.62246804e-01, -4.62843343e-01,
-5.89607507e-01, -7.25363076e-01, -3.36865858e-01,
2.67715108e-01, 8.40176767e-01, 1.55574430e+00,
1.18688954e+00, 4.20276324e-01, -1.51697311e-01,
-9.42076108e-01, -7.93172332e-01, -3.26343710e-01,
-1.24552779e-01, 2.12909254e-01, 1.75770320e-01,
1.47523075e-02, 8.22192707e-03, -3.02920592e-02,
-2.21119497e-02, 7.30703025e-03, 2.83200488e-03,
-1.16759765e-03, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00])
assert_allclose(phi, phi_expect)
assert_allclose(psi, psi_expect)
def test_wavefun_bior13():
w = pywt.Wavelet('bior1.3')
# bior1.3 is not an orthogonal wavelet, so 5 outputs from wavefun
phi_d, psi_d, phi_r, psi_r, x = w.wavefun(level=3)
for arr in [phi_d, psi_d, phi_r, psi_r]:
assert_(arr.size == 40)
phi_d_expect = np.array([0., -0.00195313, 0.00195313, 0.01757813,
0.01367188, 0.00390625, -0.03515625, -0.12890625,
-0.15234375, -0.125, -0.09375, -0.0625, 0.03125,
0.15234375, 0.37890625, 0.78515625, 0.99609375,
1.08203125, 1.13671875, 1.13671875, 1.08203125,
0.99609375, 0.78515625, 0.37890625, 0.15234375,
0.03125, -0.0625, -0.09375, -0.125, -0.15234375,
-0.12890625, -0.03515625, 0.00390625, 0.01367188,
0.01757813, 0.00195313, -0.00195313, 0., 0., 0.])
phi_r_expect = np.zeros(x.size, dtype=np.float)
phi_r_expect[15:23] = 1
psi_d_expect = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0,
0.015625, -0.015625, -0.140625, -0.109375,
-0.03125, 0.28125, 1.03125, 1.21875, 1.125, 0.625,
-0.625, -1.125, -1.21875, -1.03125, -0.28125,
0.03125, 0.109375, 0.140625, 0.015625, -0.015625,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
psi_r_expect = np.zeros(x.size, dtype=np.float)
psi_r_expect[7:15] = -0.125
psi_r_expect[15:19] = 1
psi_r_expect[19:23] = -1
psi_r_expect[23:31] = 0.125
assert_allclose(x, np.linspace(0, 5, x.size, endpoint=False))
assert_allclose(phi_d, phi_d_expect, rtol=1e-5, atol=1e-9)
assert_allclose(phi_r, phi_r_expect, rtol=1e-10, atol=1e-12)
assert_allclose(psi_d, psi_d_expect, rtol=1e-10, atol=1e-12)
assert_allclose(psi_r, psi_r_expect, rtol=1e-10, atol=1e-12)

View file

@ -0,0 +1,197 @@
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
import numpy as np
from numpy.testing import (assert_allclose, assert_, assert_raises,
assert_equal)
import pywt
def test_wavelet_packet_structure():
x = [1, 2, 3, 4, 5, 6, 7, 8]
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
assert_(wp.data == [1, 2, 3, 4, 5, 6, 7, 8])
assert_(wp.path == '')
assert_(wp.level == 0)
assert_(wp['ad'].maxlevel == 3)
def test_traversing_wp_tree():
x = [1, 2, 3, 4, 5, 6, 7, 8]
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
assert_(wp.maxlevel == 3)
# First level
assert_allclose(wp['a'].data, np.array([2.12132034356, 4.949747468306,
7.778174593052, 10.606601717798]),
rtol=1e-12)
# Second level
assert_allclose(wp['aa'].data, np.array([5., 13.]), rtol=1e-12)
# Third level
assert_allclose(wp['aaa'].data, np.array([12.727922061358]), rtol=1e-12)
def test_acess_path():
x = [1, 2, 3, 4, 5, 6, 7, 8]
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
assert_(wp['a'].path == 'a')
assert_(wp['aa'].path == 'aa')
assert_(wp['aaa'].path == 'aaa')
# Maximum level reached:
assert_raises(IndexError, lambda: wp['aaaa'].path)
# Wrong path
assert_raises(ValueError, lambda: wp['ac'].path)
def test_access_node_atributes():
x = [1, 2, 3, 4, 5, 6, 7, 8]
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
assert_allclose(wp['ad'].data, np.array([-2., -2.]), rtol=1e-12)
assert_(wp['ad'].path == 'ad')
assert_(wp['ad'].node_name == 'd')
assert_(wp['ad'].parent.path == 'a')
assert_(wp['ad'].level == 2)
assert_(wp['ad'].maxlevel == 3)
assert_(wp['ad'].mode == 'symmetric')
def test_collecting_nodes():
x = [1, 2, 3, 4, 5, 6, 7, 8]
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
# All nodes in natural order
assert_([node.path for node in wp.get_level(3, 'natural')] ==
['aaa', 'aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd'])
# and in frequency order.
assert_([node.path for node in wp.get_level(3, 'freq')] ==
['aaa', 'aad', 'add', 'ada', 'dda', 'ddd', 'dad', 'daa'])
def test_reconstructing_data():
x = [1, 2, 3, 4, 5, 6, 7, 8]
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
# Create another Wavelet Packet and feed it with some data.
new_wp = pywt.WaveletPacket(data=None, wavelet='db1', mode='symmetric')
new_wp['aa'] = wp['aa'].data
new_wp['ad'] = [-2., -2.]
# For convenience, :attr:`Node.data` gets automatically extracted
# from the :class:`Node` object:
new_wp['d'] = wp['d']
# Reconstruct data from aa, ad, and d packets.
assert_allclose(new_wp.reconstruct(update=False), x, rtol=1e-12)
# The node's :attr:`~Node.data` will not be updated
assert_(new_wp.data is None)
# When `update` is True:
assert_allclose(new_wp.reconstruct(update=True), x, rtol=1e-12)
assert_allclose(new_wp.data, np.arange(1, 9), rtol=1e-12)
assert_([n.path for n in new_wp.get_leaf_nodes(False)] ==
['aa', 'ad', 'd'])
assert_([n.path for n in new_wp.get_leaf_nodes(True)] ==
['aaa', 'aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd'])
def test_removing_nodes():
x = [1, 2, 3, 4, 5, 6, 7, 8]
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
wp.get_level(2)
dataleafs = [n.data for n in wp.get_leaf_nodes(False)]
expected = np.array([[5., 13.], [-2, -2], [-1, -1], [0, 0]])
for i in range(4):
assert_allclose(dataleafs[i], expected[i, :], atol=1e-12)
node = wp['ad']
del(wp['ad'])
dataleafs = [n.data for n in wp.get_leaf_nodes(False)]
expected = np.array([[5., 13.], [-1, -1], [0, 0]])
for i in range(3):
assert_allclose(dataleafs[i], expected[i, :], atol=1e-12)
wp.reconstruct()
# The reconstruction is:
assert_allclose(wp.reconstruct(),
np.array([2., 3., 2., 3., 6., 7., 6., 7.]), rtol=1e-12)
# Restore the data
wp['ad'].data = node.data
dataleafs = [n.data for n in wp.get_leaf_nodes(False)]
expected = np.array([[5., 13.], [-2, -2], [-1, -1], [0, 0]])
for i in range(4):
assert_allclose(dataleafs[i], expected[i, :], atol=1e-12)
assert_allclose(wp.reconstruct(), np.arange(1, 9), rtol=1e-12)
def test_wavelet_packet_dtypes():
N = 32
for dtype in [np.float32, np.float64, np.complex64, np.complex128]:
x = np.random.randn(N).astype(dtype)
if np.iscomplexobj(x):
x = x + 1j*np.random.randn(N).astype(x.real.dtype)
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
# no unnecessary copy made
assert_(wp.data is x)
# assiging to a node should not change supported dtypes
wp['d'] = wp['d'].data
assert_equal(wp['d'].data.dtype, x.dtype)
# full decomposition
wp.get_level(wp.maxlevel)
# reconstruction from coefficients should preserve dtype
r = wp.reconstruct(False)
assert_equal(r.dtype, x.dtype)
assert_allclose(r, x, atol=1e-5, rtol=1e-5)
# first element of the tuple is the input dtype
# second element of the tuple is the transform dtype
dtype_pairs = [(np.uint8, np.float64),
(np.intp, np.float64), ]
if hasattr(np, "complex256"):
dtype_pairs += [(np.complex256, np.complex128), ]
if hasattr(np, "half"):
dtype_pairs += [(np.half, np.float32), ]
for (dtype, transform_dtype) in dtype_pairs:
x = np.arange(N, dtype=dtype)
wp = pywt.WaveletPacket(x, wavelet='db1', mode='symmetric')
# no unnecessary copy made of top-level data
assert_(wp.data is x)
# full decomposition
wp.get_level(wp.maxlevel)
# reconstructed data will have modified dtype
r = wp.reconstruct(False)
assert_equal(r.dtype, transform_dtype)
assert_allclose(r, x.astype(transform_dtype), atol=1e-5, rtol=1e-5)
def test_db3_roundtrip():
original = np.arange(512)
wp = pywt.WaveletPacket(data=original, wavelet='db3', mode='smooth',
maxlevel=3)
r = wp.reconstruct()
assert_allclose(original, r, atol=1e-12, rtol=1e-12)

View file

@ -0,0 +1,177 @@
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
import numpy as np
from numpy.testing import (assert_allclose, assert_, assert_raises,
assert_equal)
import pywt
def test_traversing_tree_2d():
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
assert_(np.all(wp.data == x))
assert_(wp.path == '')
assert_(wp.level == 0)
assert_(wp.maxlevel == 3)
assert_allclose(wp['a'].data, np.array([[3., 7., 11., 15.]] * 4),
rtol=1e-12)
assert_allclose(wp['h'].data, np.zeros((4, 4)), rtol=1e-12, atol=1e-14)
assert_allclose(wp['v'].data, -np.ones((4, 4)), rtol=1e-12, atol=1e-14)
assert_allclose(wp['d'].data, np.zeros((4, 4)), rtol=1e-12, atol=1e-14)
assert_allclose(wp['aa'].data, np.array([[10., 26.]] * 2), rtol=1e-12)
assert_(wp['a']['a'].data is wp['aa'].data)
assert_allclose(wp['aaa'].data, np.array([[36.]]), rtol=1e-12)
assert_raises(IndexError, lambda: wp['aaaa'])
assert_raises(ValueError, lambda: wp['f'])
def test_accessing_node_atributes_2d():
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
assert_allclose(wp['av'].data, np.zeros((2, 2)) - 4, rtol=1e-12)
assert_(wp['av'].path == 'av')
assert_(wp['av'].node_name == 'v')
assert_(wp['av'].parent.path == 'a')
assert_allclose(wp['av'].parent.data, np.array([[3., 7., 11., 15.]] * 4),
rtol=1e-12)
assert_(wp['av'].level == 2)
assert_(wp['av'].maxlevel == 3)
assert_(wp['av'].mode == 'symmetric')
def test_collecting_nodes_2d():
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
assert_(len(wp.get_level(0)) == 1)
assert_(wp.get_level(0)[0].path == '')
# First level
assert_(len(wp.get_level(1)) == 4)
assert_([node.path for node in wp.get_level(1)] == ['a', 'h', 'v', 'd'])
# Second level
assert_(len(wp.get_level(2)) == 16)
paths = [node.path for node in wp.get_level(2)]
expected_paths = ['aa', 'ah', 'av', 'ad', 'ha', 'hh', 'hv', 'hd', 'va',
'vh', 'vv', 'vd', 'da', 'dh', 'dv', 'dd']
assert_(paths == expected_paths)
# Third level.
assert_(len(wp.get_level(3)) == 64)
paths = [node.path for node in wp.get_level(3)]
expected_paths = ['aaa', 'aah', 'aav', 'aad', 'aha', 'ahh', 'ahv', 'ahd',
'ava', 'avh', 'avv', 'avd', 'ada', 'adh', 'adv', 'add',
'haa', 'hah', 'hav', 'had', 'hha', 'hhh', 'hhv', 'hhd',
'hva', 'hvh', 'hvv', 'hvd', 'hda', 'hdh', 'hdv', 'hdd',
'vaa', 'vah', 'vav', 'vad', 'vha', 'vhh', 'vhv', 'vhd',
'vva', 'vvh', 'vvv', 'vvd', 'vda', 'vdh', 'vdv', 'vdd',
'daa', 'dah', 'dav', 'dad', 'dha', 'dhh', 'dhv', 'dhd',
'dva', 'dvh', 'dvv', 'dvd', 'dda', 'ddh', 'ddv', 'ddd']
assert_(paths == expected_paths)
def test_data_reconstruction_2d():
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
new_wp = pywt.WaveletPacket2D(data=None, wavelet='db1', mode='symmetric')
new_wp['vh'] = wp['vh'].data
new_wp['vv'] = wp['vh'].data
new_wp['vd'] = np.zeros((2, 2), dtype=np.float64)
new_wp['a'] = [[3.0, 7.0, 11.0, 15.0]] * 4
new_wp['d'] = np.zeros((4, 4), dtype=np.float64)
new_wp['h'] = wp['h'] # all zeros
assert_allclose(new_wp.reconstruct(update=False),
np.array([[1.5, 1.5, 3.5, 3.5, 5.5, 5.5, 7.5, 7.5]] * 8),
rtol=1e-12)
assert_allclose(wp['va'].data, np.zeros((2, 2)) - 2, rtol=1e-12)
new_wp['va'] = wp['va'].data
assert_allclose(new_wp.reconstruct(update=False), x, rtol=1e-12)
def test_data_reconstruction_delete_nodes_2d():
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
new_wp = pywt.WaveletPacket2D(data=None, wavelet='db1', mode='symmetric')
new_wp['vh'] = wp['vh'].data
new_wp['vv'] = wp['vh'].data
new_wp['vd'] = np.zeros((2, 2), dtype=np.float64)
new_wp['a'] = [[3.0, 7.0, 11.0, 15.0]] * 4
new_wp['d'] = np.zeros((4, 4), dtype=np.float64)
new_wp['h'] = wp['h'] # all zeros
assert_allclose(new_wp.reconstruct(update=False),
np.array([[1.5, 1.5, 3.5, 3.5, 5.5, 5.5, 7.5, 7.5]] * 8),
rtol=1e-12)
new_wp['va'] = wp['va'].data
assert_allclose(new_wp.reconstruct(update=False), x, rtol=1e-12)
del(new_wp['va'])
new_wp['va'] = wp['va'].data
assert_(new_wp.data is None)
assert_allclose(new_wp.reconstruct(update=True), x, rtol=1e-12)
assert_allclose(new_wp.data, x, rtol=1e-12)
# TODO: decompose=True
def test_lazy_evaluation_2D():
# Note: internal implementation detail not to be relied on. Testing for
# now for backwards compatibility, but this test may be broken in needed.
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8)
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
assert_(wp.a is None)
assert_allclose(wp['a'].data, np.array([[3., 7., 11., 15.]] * 4),
rtol=1e-12)
assert_allclose(wp.a.data, np.array([[3., 7., 11., 15.]] * 4), rtol=1e-12)
assert_allclose(wp.d.data, np.zeros((4, 4)), rtol=1e-12, atol=1e-12)
def test_wavelet_packet_dtypes():
shape = (16, 16)
for dtype in [np.float32, np.float64, np.complex64, np.complex128]:
x = np.random.randn(*shape).astype(dtype)
if np.iscomplexobj(x):
x = x + 1j*np.random.randn(*shape).astype(x.real.dtype)
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
# no unnecessary copy made
assert_(wp.data is x)
# assiging to a node should not change supported dtypes
wp['d'] = wp['d'].data
assert_equal(wp['d'].data.dtype, x.dtype)
# full decomposition
wp.get_level(wp.maxlevel)
# reconstruction from coefficients should preserve dtype
r = wp.reconstruct(False)
assert_equal(r.dtype, x.dtype)
assert_allclose(r, x, atol=1e-5, rtol=1e-5)
def test_2d_roundtrip():
# test case corresponding to PyWavelets issue 447
original = pywt.data.camera()
wp = pywt.WaveletPacket2D(data=original, wavelet='db3', mode='smooth',
maxlevel=3)
r = wp.reconstruct()
assert_allclose(original, r, atol=1e-12, rtol=1e-12)