Fixed database typo and removed unnecessary class identifier.
This commit is contained in:
parent
00ad49a143
commit
45fb349a7d
5098 changed files with 952558 additions and 85 deletions
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,96 @@
|
|||
""" This script was used to generate dwt_matlabR2012a_result.npz by storing
|
||||
the outputs from Matlab R2012a. """
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
import pywt
|
||||
|
||||
try:
|
||||
from pymatbridge import Matlab
|
||||
mlab = Matlab()
|
||||
_matlab_missing = False
|
||||
except ImportError:
|
||||
print("To run Matlab compatibility tests you need to have MathWorks "
|
||||
"MATLAB, MathWorks Wavelet Toolbox and the pymatbridge Python "
|
||||
"package installed.")
|
||||
_matlab_missing = True
|
||||
|
||||
if _matlab_missing:
|
||||
raise EnvironmentError("Can't generate matlab data files without MATLAB")
|
||||
|
||||
size_set = 'reduced'
|
||||
|
||||
# list of mode names in pywt and matlab
|
||||
modes = [('zero', 'zpd'),
|
||||
('constant', 'sp0'),
|
||||
('symmetric', 'sym'),
|
||||
('reflect', 'symw'),
|
||||
('periodic', 'ppd'),
|
||||
('smooth', 'sp1'),
|
||||
('periodization', 'per'),
|
||||
('antisymmetric', 'asym'),
|
||||
('antireflect', 'asymw')]
|
||||
|
||||
families = ('db', 'sym', 'coif', 'bior', 'rbio')
|
||||
wavelets = sum([pywt.wavelist(name) for name in families], [])
|
||||
|
||||
rstate = np.random.RandomState(1234)
|
||||
mlab.start()
|
||||
try:
|
||||
all_matlab_results = {}
|
||||
for wavelet in wavelets:
|
||||
w = pywt.Wavelet(wavelet)
|
||||
mlab.set_variable('wavelet', wavelet)
|
||||
if size_set == 'full':
|
||||
data_sizes = list(range(w.dec_len, 40)) + \
|
||||
[100, 200, 500, 1000, 50000]
|
||||
else:
|
||||
data_sizes = (w.dec_len, w.dec_len + 1)
|
||||
for N in data_sizes:
|
||||
data = rstate.randn(N)
|
||||
mlab.set_variable('data', data)
|
||||
for pmode, mmode in modes:
|
||||
# Matlab result
|
||||
if np.any((wavelet == np.array(['coif6', 'coif7', 'coif8', 'coif9', 'coif10', 'coif11', 'coif12', 'coif13', 'coif14', 'coif15', 'coif16', 'coif17'])),axis=0):
|
||||
mlab.set_variable('Lo_D', w.dec_lo)
|
||||
mlab.set_variable('Hi_D', w.dec_hi)
|
||||
mlab_code = ("[ma, md] = dwt(data, Lo_D, Hi_D, "
|
||||
"'mode', '%s');" % mmode)
|
||||
else:
|
||||
mlab_code = ("[ma, md] = dwt(data, wavelet, "
|
||||
"'mode', '%s');" % mmode)
|
||||
res = mlab.run_code(mlab_code)
|
||||
if not res['success']:
|
||||
raise RuntimeError(
|
||||
"Matlab failed to execute the provided code. "
|
||||
"Check that the wavelet toolbox is installed.")
|
||||
# need np.asarray because sometimes the output is type float
|
||||
ma = np.asarray(mlab.get_variable('ma'))
|
||||
md = np.asarray(mlab.get_variable('md'))
|
||||
ma_key = '_'.join([mmode, wavelet, str(N), 'ma'])
|
||||
md_key = '_'.join([mmode, wavelet, str(N), 'md'])
|
||||
all_matlab_results[ma_key] = ma
|
||||
all_matlab_results[md_key] = md
|
||||
|
||||
# Matlab result
|
||||
mlab.set_variable('Lo_D', w.dec_lo)
|
||||
mlab.set_variable('Hi_D', w.dec_hi)
|
||||
mlab_code = ("[ma, md] = dwt(data, Lo_D, Hi_D, "
|
||||
"'mode', '%s');" % mmode)
|
||||
res = mlab.run_code(mlab_code)
|
||||
if not res['success']:
|
||||
raise RuntimeError(
|
||||
"Matlab failed to execute the provided code. "
|
||||
"Check that the wavelet toolbox is installed.")
|
||||
# need np.asarray because sometimes the output is type float
|
||||
ma = np.asarray(mlab.get_variable('ma'))
|
||||
md = np.asarray(mlab.get_variable('md'))
|
||||
ma_key = '_'.join([mmode, wavelet, str(N), 'ma_pywtCoeffs'])
|
||||
md_key = '_'.join([mmode, wavelet, str(N), 'md_pywtCoeffs'])
|
||||
all_matlab_results[ma_key] = ma
|
||||
all_matlab_results[md_key] = md
|
||||
finally:
|
||||
mlab.stop()
|
||||
|
||||
np.savez('dwt_matlabR2012a_result.npz', **all_matlab_results)
|
|
@ -0,0 +1,86 @@
|
|||
""" This script was used to generate dwt_matlabR2012a_result.npz by storing
|
||||
the outputs from Matlab R2012a. """
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
import pywt
|
||||
|
||||
try:
|
||||
from pymatbridge import Matlab
|
||||
mlab = Matlab()
|
||||
_matlab_missing = False
|
||||
except ImportError:
|
||||
print("To run Matlab compatibility tests you need to have MathWorks "
|
||||
"MATLAB, MathWorks Wavelet Toolbox and the pymatbridge Python "
|
||||
"package installed.")
|
||||
_matlab_missing = True
|
||||
|
||||
if _matlab_missing:
|
||||
raise EnvironmentError("Can't generate matlab data files without MATLAB")
|
||||
|
||||
size_set = 'reduced'
|
||||
|
||||
# list of mode names in pywt and matlab
|
||||
modes = [('zero', 'zpd'),
|
||||
('constant', 'sp0'),
|
||||
('symmetric', 'sym'),
|
||||
('periodic', 'ppd'),
|
||||
('smooth', 'sp1'),
|
||||
('periodization', 'per')]
|
||||
|
||||
families = ('gaus', 'mexh', 'morl', 'cgau', 'shan', 'fbsp', 'cmor')
|
||||
wavelets = sum([pywt.wavelist(name) for name in families], [])
|
||||
|
||||
rstate = np.random.RandomState(1234)
|
||||
mlab.start()
|
||||
try:
|
||||
all_matlab_results = {}
|
||||
for wavelet in wavelets:
|
||||
w = pywt.ContinuousWavelet(wavelet)
|
||||
if np.any((wavelet == np.array(['shan', 'cmor'])),axis=0):
|
||||
mlab.set_variable('wavelet', wavelet+str(w.bandwidth_frequency)+'-'+str(w.center_frequency))
|
||||
elif wavelet == 'fbsp':
|
||||
mlab.set_variable('wavelet', wavelet+str(w.fbsp_order)+'-'+str(w.bandwidth_frequency)+'-'+str(w.center_frequency))
|
||||
else:
|
||||
mlab.set_variable('wavelet', wavelet)
|
||||
if size_set == 'full':
|
||||
data_sizes = list(range(100, 101)) + \
|
||||
[100, 200, 500, 1000, 50000]
|
||||
Scales = (1,np.arange(1,3),np.arange(1,4),np.arange(1,5))
|
||||
else:
|
||||
data_sizes = (1000, 1000 + 1)
|
||||
Scales = (1,np.arange(1,3))
|
||||
mlab_code = ("psi = wavefun(wavelet,10)")
|
||||
res = mlab.run_code(mlab_code)
|
||||
if not res['success']:
|
||||
raise RuntimeError(
|
||||
"Matlab failed to execute the provided code. "
|
||||
"Check that the wavelet toolbox is installed.")
|
||||
psi = np.asarray(mlab.get_variable('psi'))
|
||||
psi_key = '_'.join([wavelet, 'psi'])
|
||||
all_matlab_results[psi_key] = psi
|
||||
for N in data_sizes:
|
||||
data = rstate.randn(N)
|
||||
mlab.set_variable('data', data)
|
||||
|
||||
# Matlab result
|
||||
scale_count = 0
|
||||
for scales in Scales:
|
||||
scale_count += 1
|
||||
mlab.set_variable('scales', scales)
|
||||
mlab_code = ("coefs = cwt(data, scales, wavelet)")
|
||||
res = mlab.run_code(mlab_code)
|
||||
if not res['success']:
|
||||
raise RuntimeError(
|
||||
"Matlab failed to execute the provided code. "
|
||||
"Check that the wavelet toolbox is installed.")
|
||||
# need np.asarray because sometimes the output is type float
|
||||
coefs = np.asarray(mlab.get_variable('coefs'))
|
||||
coefs_key = '_'.join([str(scale_count), wavelet, str(N), 'coefs'])
|
||||
all_matlab_results[coefs_key] = coefs
|
||||
|
||||
finally:
|
||||
mlab.stop()
|
||||
|
||||
np.savez('cwt_matlabR2015b_result.npz', **all_matlab_results)
|
BIN
venv/Lib/site-packages/pywt/tests/data/wavelab_test_signals.npz
Normal file
BIN
venv/Lib/site-packages/pywt/tests/data/wavelab_test_signals.npz
Normal file
Binary file not shown.
170
venv/Lib/site-packages/pywt/tests/test__pywt.py
Normal file
170
venv/Lib/site-packages/pywt/tests/test__pywt.py
Normal file
|
@ -0,0 +1,170 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose, assert_, assert_raises
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_upcoef_reconstruct():
|
||||
data = np.arange(3)
|
||||
a = pywt.downcoef('a', data, 'haar')
|
||||
d = pywt.downcoef('d', data, 'haar')
|
||||
|
||||
rec = (pywt.upcoef('a', a, 'haar', take=3) +
|
||||
pywt.upcoef('d', d, 'haar', take=3))
|
||||
assert_allclose(rec, data)
|
||||
|
||||
|
||||
def test_downcoef_multilevel():
|
||||
rstate = np.random.RandomState(1234)
|
||||
r = rstate.randn(16)
|
||||
nlevels = 3
|
||||
# calling with level=1 nlevels times
|
||||
a1 = r.copy()
|
||||
for i in range(nlevels):
|
||||
a1 = pywt.downcoef('a', a1, 'haar', level=1)
|
||||
# call with level=nlevels once
|
||||
a3 = pywt.downcoef('a', r, 'haar', level=nlevels)
|
||||
assert_allclose(a1, a3)
|
||||
|
||||
|
||||
def test_downcoef_complex():
|
||||
rstate = np.random.RandomState(1234)
|
||||
r = rstate.randn(16) + 1j * rstate.randn(16)
|
||||
nlevels = 3
|
||||
a = pywt.downcoef('a', r, 'haar', level=nlevels)
|
||||
a_ref = pywt.downcoef('a', r.real, 'haar', level=nlevels)
|
||||
a_ref = a_ref + 1j * pywt.downcoef('a', r.imag, 'haar', level=nlevels)
|
||||
assert_allclose(a, a_ref)
|
||||
|
||||
|
||||
def test_downcoef_errs():
|
||||
# invalid part string (not 'a' or 'd')
|
||||
assert_raises(ValueError, pywt.downcoef, 'f', np.ones(16), 'haar')
|
||||
|
||||
|
||||
def test_compare_downcoef_coeffs():
|
||||
rstate = np.random.RandomState(1234)
|
||||
r = rstate.randn(16)
|
||||
# compare downcoef against wavedec outputs
|
||||
for nlevels in [1, 2, 3]:
|
||||
for wavelet in pywt.wavelist():
|
||||
if wavelet in ['cmor', 'shan', 'fbsp']:
|
||||
# skip these CWT families to avoid warnings
|
||||
continue
|
||||
wavelet = pywt.DiscreteContinuousWavelet(wavelet)
|
||||
if isinstance(wavelet, pywt.Wavelet):
|
||||
max_level = pywt.dwt_max_level(r.size, wavelet.dec_len)
|
||||
if nlevels <= max_level:
|
||||
a = pywt.downcoef('a', r, wavelet, level=nlevels)
|
||||
d = pywt.downcoef('d', r, wavelet, level=nlevels)
|
||||
coeffs = pywt.wavedec(r, wavelet, level=nlevels)
|
||||
assert_allclose(a, coeffs[0])
|
||||
assert_allclose(d, coeffs[1])
|
||||
|
||||
|
||||
def test_upcoef_multilevel():
|
||||
rstate = np.random.RandomState(1234)
|
||||
r = rstate.randn(4)
|
||||
nlevels = 3
|
||||
# calling with level=1 nlevels times
|
||||
a1 = r.copy()
|
||||
for i in range(nlevels):
|
||||
a1 = pywt.upcoef('a', a1, 'haar', level=1)
|
||||
# call with level=nlevels once
|
||||
a3 = pywt.upcoef('a', r, 'haar', level=nlevels)
|
||||
assert_allclose(a1, a3)
|
||||
|
||||
|
||||
def test_upcoef_complex():
|
||||
rstate = np.random.RandomState(1234)
|
||||
r = rstate.randn(4) + 1j*rstate.randn(4)
|
||||
nlevels = 3
|
||||
a = pywt.upcoef('a', r, 'haar', level=nlevels)
|
||||
a_ref = pywt.upcoef('a', r.real, 'haar', level=nlevels)
|
||||
a_ref = a_ref + 1j*pywt.upcoef('a', r.imag, 'haar', level=nlevels)
|
||||
assert_allclose(a, a_ref)
|
||||
|
||||
|
||||
def test_upcoef_errs():
|
||||
# invalid part string (not 'a' or 'd')
|
||||
assert_raises(ValueError, pywt.upcoef, 'f', np.ones(4), 'haar')
|
||||
|
||||
|
||||
def test_upcoef_and_downcoef_1d_only():
|
||||
# upcoef and downcoef raise a ValueError if data.ndim > 1d
|
||||
for ndim in [2, 3]:
|
||||
data = np.ones((8, )*ndim)
|
||||
assert_raises(ValueError, pywt.downcoef, 'a', data, 'haar')
|
||||
assert_raises(ValueError, pywt.upcoef, 'a', data, 'haar')
|
||||
|
||||
|
||||
def test_wavelet_repr():
|
||||
from pywt._extensions import _pywt
|
||||
wavelet = _pywt.Wavelet('sym8')
|
||||
|
||||
repr_wavelet = eval(wavelet.__repr__())
|
||||
|
||||
assert_(wavelet.__repr__() == repr_wavelet.__repr__())
|
||||
|
||||
|
||||
def test_dwt_max_level():
|
||||
assert_(pywt.dwt_max_level(16, 2) == 4)
|
||||
assert_(pywt.dwt_max_level(16, 8) == 1)
|
||||
assert_(pywt.dwt_max_level(16, 9) == 1)
|
||||
assert_(pywt.dwt_max_level(16, 10) == 0)
|
||||
assert_(pywt.dwt_max_level(16, np.int8(10)) == 0)
|
||||
assert_(pywt.dwt_max_level(16, 10.) == 0)
|
||||
assert_(pywt.dwt_max_level(16, 18) == 0)
|
||||
|
||||
# accepts discrete Wavelet object or string as well
|
||||
assert_(pywt.dwt_max_level(32, pywt.Wavelet('sym5')) == 1)
|
||||
assert_(pywt.dwt_max_level(32, 'sym5') == 1)
|
||||
|
||||
# string input that is not a discrete wavelet
|
||||
assert_raises(ValueError, pywt.dwt_max_level, 16, 'mexh')
|
||||
|
||||
# filter_len must be an integer >= 2
|
||||
assert_raises(ValueError, pywt.dwt_max_level, 16, 1)
|
||||
assert_raises(ValueError, pywt.dwt_max_level, 16, -1)
|
||||
assert_raises(ValueError, pywt.dwt_max_level, 16, 3.3)
|
||||
|
||||
|
||||
def test_ContinuousWavelet_errs():
|
||||
assert_raises(ValueError, pywt.ContinuousWavelet, 'qwertz')
|
||||
|
||||
|
||||
def test_ContinuousWavelet_repr():
|
||||
from pywt._extensions import _pywt
|
||||
wavelet = _pywt.ContinuousWavelet('gaus2')
|
||||
|
||||
repr_wavelet = eval(wavelet.__repr__())
|
||||
|
||||
assert_(wavelet.__repr__() == repr_wavelet.__repr__())
|
||||
|
||||
|
||||
def test_wavelist():
|
||||
for name in pywt.wavelist(family='coif'):
|
||||
assert_(name.startswith('coif'))
|
||||
|
||||
assert_('cgau7' in pywt.wavelist(kind='continuous'))
|
||||
assert_('sym20' in pywt.wavelist(kind='discrete'))
|
||||
assert_(len(pywt.wavelist(kind='continuous')) +
|
||||
len(pywt.wavelist(kind='discrete')) ==
|
||||
len(pywt.wavelist(kind='all')))
|
||||
|
||||
assert_raises(ValueError, pywt.wavelist, kind='foobar')
|
||||
|
||||
|
||||
def test_wavelet_errormsgs():
|
||||
try:
|
||||
pywt.Wavelet('gaus1')
|
||||
except ValueError as e:
|
||||
assert_(e.args[0].startswith('The `Wavelet` class'))
|
||||
try:
|
||||
pywt.Wavelet('cmord')
|
||||
except ValueError as e:
|
||||
assert_(e.args[0] == "Invalid wavelet name 'cmord'.")
|
105
venv/Lib/site-packages/pywt/tests/test_concurrent.py
Normal file
105
venv/Lib/site-packages/pywt/tests/test_concurrent.py
Normal file
|
@ -0,0 +1,105 @@
|
|||
"""
|
||||
Tests used to verify running PyWavelets transforms in parallel via
|
||||
concurrent.futures.ThreadPoolExecutor does not raise errors.
|
||||
"""
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import warnings
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
from numpy.testing import assert_array_equal, assert_allclose
|
||||
from pywt._pytest import uses_futures, futures, max_workers
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def _assert_all_coeffs_equal(coefs1, coefs2):
|
||||
# return True only if all coefficients of SWT or DWT match over all levels
|
||||
if len(coefs1) != len(coefs2):
|
||||
return False
|
||||
for (c1, c2) in zip(coefs1, coefs2):
|
||||
if isinstance(c1, tuple):
|
||||
# for swt, swt2, dwt, dwt2, wavedec, wavedec2
|
||||
for a1, a2 in zip(c1, c2):
|
||||
assert_array_equal(a1, a2)
|
||||
elif isinstance(c1, dict):
|
||||
# for swtn, dwtn, wavedecn
|
||||
for k, v in c1.items():
|
||||
assert_array_equal(v, c2[k])
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@uses_futures
|
||||
def test_concurrent_swt():
|
||||
# tests error-free concurrent operation (see gh-288)
|
||||
# swt on 1D data calls the Cython swt
|
||||
# other cases call swt_axes
|
||||
with warnings.catch_warnings():
|
||||
# can remove catch_warnings once the swt2 FutureWarning is removed
|
||||
warnings.simplefilter('ignore', FutureWarning)
|
||||
for swt_func, x in zip([pywt.swt, pywt.swt2, pywt.swtn],
|
||||
[np.ones(8), np.eye(16), np.eye(16)]):
|
||||
transform = partial(swt_func, wavelet='haar', level=3)
|
||||
for _ in range(10):
|
||||
arrs = [x.copy() for _ in range(100)]
|
||||
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
|
||||
results = list(ex.map(transform, arrs))
|
||||
|
||||
# validate result from one of the concurrent runs
|
||||
expected_result = transform(x)
|
||||
_assert_all_coeffs_equal(expected_result, results[-1])
|
||||
|
||||
|
||||
@uses_futures
|
||||
def test_concurrent_wavedec():
|
||||
# wavedec on 1D data calls the Cython dwt_single
|
||||
# other cases call dwt_axis
|
||||
for wavedec_func, x in zip([pywt.wavedec, pywt.wavedec2, pywt.wavedecn],
|
||||
[np.ones(8), np.eye(16), np.eye(16)]):
|
||||
transform = partial(wavedec_func, wavelet='haar', level=1)
|
||||
for _ in range(10):
|
||||
arrs = [x.copy() for _ in range(100)]
|
||||
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
|
||||
results = list(ex.map(transform, arrs))
|
||||
|
||||
# validate result from one of the concurrent runs
|
||||
expected_result = transform(x)
|
||||
_assert_all_coeffs_equal(expected_result, results[-1])
|
||||
|
||||
|
||||
@uses_futures
|
||||
def test_concurrent_dwt():
|
||||
# dwt on 1D data calls the Cython dwt_single
|
||||
# other cases call dwt_axis
|
||||
for dwt_func, x in zip([pywt.dwt, pywt.dwt2, pywt.dwtn],
|
||||
[np.ones(8), np.eye(16), np.eye(16)]):
|
||||
transform = partial(dwt_func, wavelet='haar')
|
||||
for _ in range(10):
|
||||
arrs = [x.copy() for _ in range(100)]
|
||||
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
|
||||
results = list(ex.map(transform, arrs))
|
||||
|
||||
# validate result from one of the concurrent runs
|
||||
expected_result = transform(x)
|
||||
_assert_all_coeffs_equal([expected_result, ], [results[-1], ])
|
||||
|
||||
|
||||
@uses_futures
|
||||
def test_concurrent_cwt():
|
||||
atol = rtol = 1e-14
|
||||
time, sst = pywt.data.nino()
|
||||
dt = time[1]-time[0]
|
||||
transform = partial(pywt.cwt, scales=np.arange(1, 4), wavelet='cmor1.5-1',
|
||||
sampling_period=dt)
|
||||
for _ in range(10):
|
||||
arrs = [sst.copy() for _ in range(50)]
|
||||
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
|
||||
results = list(ex.map(transform, arrs))
|
||||
|
||||
# validate result from one of the concurrent runs
|
||||
expected_result = transform(sst)
|
||||
for a1, a2 in zip(expected_result, results[-1]):
|
||||
assert_allclose(a1, a2, atol=atol, rtol=rtol)
|
434
venv/Lib/site-packages/pywt/tests/test_cwt_wavelets.py
Normal file
434
venv/Lib/site-packages/pywt/tests/test_cwt_wavelets.py
Normal file
|
@ -0,0 +1,434 @@
|
|||
#!/usr/bin/env python
|
||||
from __future__ import division, print_function, absolute_import
|
||||
from itertools import product
|
||||
|
||||
from numpy.testing import (assert_allclose, assert_warns, assert_almost_equal,
|
||||
assert_raises, assert_equal)
|
||||
import pytest
|
||||
import numpy as np
|
||||
import pywt
|
||||
|
||||
|
||||
def ref_gaus(LB, UB, N, num):
|
||||
X = np.linspace(LB, UB, N)
|
||||
F0 = (2./np.pi)**(1./4.)*np.exp(-(X**2))
|
||||
if (num == 1):
|
||||
psi = -2.*X*F0
|
||||
elif (num == 2):
|
||||
psi = -2/(3**(1/2))*(-1 + 2*X**2)*F0
|
||||
elif (num == 3):
|
||||
psi = -4/(15**(1/2))*X*(3 - 2*X**2)*F0
|
||||
elif (num == 4):
|
||||
psi = 4/(105**(1/2))*(3 - 12*X**2 + 4*X**4)*F0
|
||||
elif (num == 5):
|
||||
psi = 8/(3*(105**(1/2)))*X*(-15 + 20*X**2 - 4*X**4)*F0
|
||||
elif (num == 6):
|
||||
psi = -8/(3*(1155**(1/2)))*(-15 + 90*X**2 - 60*X**4 + 8*X**6)*F0
|
||||
elif (num == 7):
|
||||
psi = -16/(3*(15015**(1/2)))*X*(105 - 210*X**2 + 84*X**4 - 8*X**6)*F0
|
||||
elif (num == 8):
|
||||
psi = 16/(45*(1001**(1/2)))*(105 - 840*X**2 + 840*X**4 -
|
||||
224*X**6 + 16*X**8)*F0
|
||||
return (psi, X)
|
||||
|
||||
|
||||
def ref_cgau(LB, UB, N, num):
|
||||
X = np.linspace(LB, UB, N)
|
||||
F0 = np.exp(-X**2)
|
||||
F1 = np.exp(-1j*X)
|
||||
F2 = (F1*F0)/(np.exp(-1/2)*2**(1/2)*np.pi**(1/2))**(1/2)
|
||||
if (num == 1):
|
||||
psi = F2*(-1j - 2*X)*2**(1/2)
|
||||
elif (num == 2):
|
||||
psi = 1/3*F2*(-3 + 4j*X + 4*X**2)*6**(1/2)
|
||||
elif (num == 3):
|
||||
psi = 1/15*F2*(7j + 18*X - 12j*X**2 - 8*X**3)*30**(1/2)
|
||||
elif (num == 4):
|
||||
psi = 1/105*F2*(25 - 56j*X - 72*X**2 + 32j*X**3 + 16*X**4)*210**(1/2)
|
||||
elif (num == 5):
|
||||
psi = 1/315*F2*(-81j - 250*X + 280j*X**2 + 240*X**3 -
|
||||
80j*X**4 - 32*X**5)*210**(1/2)
|
||||
elif (num == 6):
|
||||
psi = 1/3465*F2*(-331 + 972j*X + 1500*X**2 - 1120j*X**3 - 720*X**4 +
|
||||
192j*X**5 + 64*X**6)*2310**(1/2)
|
||||
elif (num == 7):
|
||||
psi = 1/45045*F2*(
|
||||
1303j + 4634*X - 6804j*X**2 - 7000*X**3 + 3920j*X**4 + 2016*X**5 -
|
||||
448j*X**6 - 128*X**7)*30030**(1/2)
|
||||
elif (num == 8):
|
||||
psi = 1/45045*F2*(
|
||||
5937 - 20848j*X - 37072*X**2 + 36288j*X**3 + 28000*X**4 -
|
||||
12544j*X**5 - 5376*X**6 + 1024j*X**7 + 256*X**8)*2002**(1/2)
|
||||
|
||||
psi = psi/np.real(np.sqrt(np.real(np.sum(psi*np.conj(psi)))*(X[1] - X[0])))
|
||||
return (psi, X)
|
||||
|
||||
|
||||
def sinc2(x):
|
||||
y = np.ones_like(x)
|
||||
k = np.where(x)[0]
|
||||
y[k] = np.sin(np.pi*x[k])/(np.pi*x[k])
|
||||
return y
|
||||
|
||||
|
||||
def ref_shan(LB, UB, N, Fb, Fc):
|
||||
x = np.linspace(LB, UB, N)
|
||||
psi = np.sqrt(Fb)*(sinc2(Fb*x)*np.exp(2j*np.pi*Fc*x))
|
||||
return (psi, x)
|
||||
|
||||
|
||||
def ref_fbsp(LB, UB, N, m, Fb, Fc):
|
||||
x = np.linspace(LB, UB, N)
|
||||
psi = np.sqrt(Fb)*((sinc2(Fb*x/m)**m)*np.exp(2j*np.pi*Fc*x))
|
||||
return (psi, x)
|
||||
|
||||
|
||||
def ref_cmor(LB, UB, N, Fb, Fc):
|
||||
x = np.linspace(LB, UB, N)
|
||||
psi = ((np.pi*Fb)**(-0.5))*np.exp(2j*np.pi*Fc*x)*np.exp(-(x**2)/Fb)
|
||||
return (psi, x)
|
||||
|
||||
|
||||
def ref_morl(LB, UB, N):
|
||||
x = np.linspace(LB, UB, N)
|
||||
psi = np.exp(-(x**2)/2)*np.cos(5*x)
|
||||
return (psi, x)
|
||||
|
||||
|
||||
def ref_mexh(LB, UB, N):
|
||||
x = np.linspace(LB, UB, N)
|
||||
psi = (2/(np.sqrt(3)*np.pi**0.25))*np.exp(-(x**2)/2)*(1 - (x**2))
|
||||
return (psi, x)
|
||||
|
||||
|
||||
def test_gaus():
|
||||
LB = -5
|
||||
UB = 5
|
||||
N = 1000
|
||||
for num in np.arange(1, 9):
|
||||
[psi, x] = ref_gaus(LB, UB, N, num)
|
||||
w = pywt.ContinuousWavelet("gaus" + str(num))
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi))
|
||||
assert_allclose(np.imag(PSI), np.imag(psi))
|
||||
assert_allclose(X, x)
|
||||
|
||||
|
||||
def test_cgau():
|
||||
LB = -5
|
||||
UB = 5
|
||||
N = 1000
|
||||
for num in np.arange(1, 9):
|
||||
[psi, x] = ref_cgau(LB, UB, N, num)
|
||||
w = pywt.ContinuousWavelet("cgau" + str(num))
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi))
|
||||
assert_allclose(np.imag(PSI), np.imag(psi))
|
||||
assert_allclose(X, x)
|
||||
|
||||
|
||||
def test_shan():
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
Fb = 1
|
||||
Fc = 1.5
|
||||
|
||||
[psi, x] = ref_shan(LB, UB, N, Fb, Fc)
|
||||
w = pywt.ContinuousWavelet("shan{}-{}".format(Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
Fb = 1.5
|
||||
Fc = 1
|
||||
|
||||
[psi, x] = ref_shan(LB, UB, N, Fb, Fc)
|
||||
w = pywt.ContinuousWavelet("shan{}-{}".format(Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
|
||||
def test_cmor():
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
Fb = 1
|
||||
Fc = 1.5
|
||||
|
||||
[psi, x] = ref_cmor(LB, UB, N, Fb, Fc)
|
||||
w = pywt.ContinuousWavelet("cmor{}-{}".format(Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
Fb = 1.5
|
||||
Fc = 1
|
||||
|
||||
[psi, x] = ref_cmor(LB, UB, N, Fb, Fc)
|
||||
w = pywt.ContinuousWavelet("cmor{}-{}".format(Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
|
||||
def test_fbsp():
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
M = 2
|
||||
Fb = 1
|
||||
Fc = 1.5
|
||||
|
||||
[psi, x] = ref_fbsp(LB, UB, N, M, Fb, Fc)
|
||||
|
||||
w = pywt.ContinuousWavelet("fbsp{}-{}-{}".format(M, Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.fbsp_order = M
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
M = 2
|
||||
Fb = 1.5
|
||||
Fc = 1
|
||||
|
||||
[psi, x] = ref_fbsp(LB, UB, N, M, Fb, Fc)
|
||||
w = pywt.ContinuousWavelet("fbsp{}-{}-{}".format(M, Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.fbsp_order = M
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
M = 3
|
||||
Fb = 1.5
|
||||
Fc = 1.2
|
||||
|
||||
[psi, x] = ref_fbsp(LB, UB, N, M, Fb, Fc)
|
||||
w = pywt.ContinuousWavelet("fbsp{}-{}-{}".format(M, Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.fbsp_order = M
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
# TODO: investigate why atol = 1e-5 is necessary
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-5)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-5)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
|
||||
def test_morl():
|
||||
LB = -5
|
||||
UB = 5
|
||||
N = 1000
|
||||
|
||||
[psi, x] = ref_morl(LB, UB, N)
|
||||
w = pywt.ContinuousWavelet("morl")
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi))
|
||||
assert_allclose(np.imag(PSI), np.imag(psi))
|
||||
assert_allclose(X, x)
|
||||
|
||||
|
||||
def test_mexh():
|
||||
LB = -5
|
||||
UB = 5
|
||||
N = 1000
|
||||
|
||||
[psi, x] = ref_mexh(LB, UB, N)
|
||||
w = pywt.ContinuousWavelet("mexh")
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi))
|
||||
assert_allclose(np.imag(PSI), np.imag(psi))
|
||||
assert_allclose(X, x)
|
||||
|
||||
LB = -5
|
||||
UB = 5
|
||||
N = 1001
|
||||
|
||||
[psi, x] = ref_mexh(LB, UB, N)
|
||||
w = pywt.ContinuousWavelet("mexh")
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi))
|
||||
assert_allclose(np.imag(PSI), np.imag(psi))
|
||||
assert_allclose(X, x)
|
||||
|
||||
|
||||
def test_cwt_parameters_in_names():
|
||||
|
||||
for func in [pywt.ContinuousWavelet, pywt.DiscreteContinuousWavelet]:
|
||||
for name in ['fbsp', 'cmor', 'shan']:
|
||||
# additional parameters should be specified within the name
|
||||
assert_warns(FutureWarning, func, name)
|
||||
|
||||
for name in ['cmor', 'shan']:
|
||||
# valid names
|
||||
func(name + '1.5-1.0')
|
||||
func(name + '1-4')
|
||||
|
||||
# invalid names
|
||||
assert_raises(ValueError, func, name + '1.0')
|
||||
assert_raises(ValueError, func, name + 'B-C')
|
||||
assert_raises(ValueError, func, name + '1.0-1.0-1.0')
|
||||
|
||||
# valid names
|
||||
func('fbsp1-1.5-1.0')
|
||||
func('fbsp1.0-1.5-1')
|
||||
func('fbsp2-5-1')
|
||||
|
||||
# invalid name (non-integer order)
|
||||
assert_raises(ValueError, func, 'fbsp1.5-1-1')
|
||||
assert_raises(ValueError, func, 'fbspM-B-C')
|
||||
|
||||
# invalid name (too few or too many params)
|
||||
assert_raises(ValueError, func, 'fbsp1.0')
|
||||
assert_raises(ValueError, func, 'fbsp1.0-0.4')
|
||||
assert_raises(ValueError, func, 'fbsp1-1-1-1')
|
||||
|
||||
|
||||
@pytest.mark.parametrize('dtype, tol, method',
|
||||
[(np.float32, 1e-5, 'conv'),
|
||||
(np.float32, 1e-5, 'fft'),
|
||||
(np.float64, 1e-13, 'conv'),
|
||||
(np.float64, 1e-13, 'fft')])
|
||||
def test_cwt_complex(dtype, tol, method):
|
||||
time, sst = pywt.data.nino()
|
||||
sst = np.asarray(sst, dtype=dtype)
|
||||
dt = time[1] - time[0]
|
||||
wavelet = 'cmor1.5-1.0'
|
||||
scales = np.arange(1, 32)
|
||||
|
||||
# real-valued tranfsorm as a reference
|
||||
[cfs, f] = pywt.cwt(sst, scales, wavelet, dt, method=method)
|
||||
|
||||
# verify same precision
|
||||
assert_equal(cfs.real.dtype, sst.dtype)
|
||||
|
||||
# complex-valued transform equals sum of the transforms of the real
|
||||
# and imaginary components
|
||||
sst_complex = sst + 1j*sst
|
||||
[cfs_complex, f] = pywt.cwt(sst_complex, scales, wavelet, dt,
|
||||
method=method)
|
||||
assert_allclose(cfs + 1j*cfs, cfs_complex, atol=tol, rtol=tol)
|
||||
# verify dtype is preserved
|
||||
assert_equal(cfs_complex.dtype, sst_complex.dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('axis, method', product([0, 1], ['conv', 'fft']))
|
||||
def test_cwt_batch(axis, method):
|
||||
dtype = np.float64
|
||||
time, sst = pywt.data.nino()
|
||||
n_batch = 8
|
||||
batch_axis = 1 - axis
|
||||
sst1 = np.asarray(sst, dtype=dtype)
|
||||
sst = np.stack((sst1, ) * n_batch, axis=batch_axis)
|
||||
dt = time[1] - time[0]
|
||||
wavelet = 'cmor1.5-1.0'
|
||||
scales = np.arange(1, 32)
|
||||
|
||||
# non-batch transform as reference
|
||||
[cfs1, f] = pywt.cwt(sst1, scales, wavelet, dt, method=method, axis=axis)
|
||||
|
||||
shape_in = sst.shape
|
||||
[cfs, f] = pywt.cwt(sst, scales, wavelet, dt, method=method, axis=axis)
|
||||
|
||||
# shape of input is not modified
|
||||
assert_equal(shape_in, sst.shape)
|
||||
|
||||
# verify same precision
|
||||
assert_equal(cfs.real.dtype, sst.dtype)
|
||||
|
||||
# verify expected shape
|
||||
assert_equal(cfs.shape[0], len(scales))
|
||||
assert_equal(cfs.shape[1 + batch_axis], n_batch)
|
||||
assert_equal(cfs.shape[1 + axis], sst.shape[axis])
|
||||
|
||||
# batch result on stacked input is the same as stacked 1d result
|
||||
assert_equal(cfs, np.stack((cfs1,) * n_batch, axis=batch_axis + 1))
|
||||
|
||||
|
||||
def test_cwt_small_scales():
|
||||
data = np.zeros(32)
|
||||
|
||||
# A scale of 0.1 was chosen specifically to give a filter of length 2 for
|
||||
# mexh. This corner case should not raise an error.
|
||||
cfs, f = pywt.cwt(data, scales=0.1, wavelet='mexh')
|
||||
assert_allclose(cfs, np.zeros_like(cfs))
|
||||
|
||||
# extremely short scale factors raise a ValueError
|
||||
assert_raises(ValueError, pywt.cwt, data, scales=0.01, wavelet='mexh')
|
||||
|
||||
|
||||
def test_cwt_method_fft():
|
||||
rstate = np.random.RandomState(1)
|
||||
data = rstate.randn(50)
|
||||
data[15] = 1.
|
||||
scales = np.arange(1, 64)
|
||||
wavelet = 'cmor1.5-1.0'
|
||||
|
||||
# build a reference cwt with the legacy np.conv() method
|
||||
cfs_conv, _ = pywt.cwt(data, scales, wavelet, method='conv')
|
||||
|
||||
# compare with the fft based convolution
|
||||
cfs_fft, _ = pywt.cwt(data, scales, wavelet, method='fft')
|
||||
assert_allclose(cfs_conv, cfs_fft, rtol=0, atol=1e-13)
|
77
venv/Lib/site-packages/pywt/tests/test_data.py
Normal file
77
venv/Lib/site-packages/pywt/tests/test_data.py
Normal file
|
@ -0,0 +1,77 @@
|
|||
import os
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose, assert_raises, assert_
|
||||
|
||||
import pywt.data
|
||||
|
||||
data_dir = os.path.join(os.path.dirname(__file__), 'data')
|
||||
wavelab_data_file = os.path.join(data_dir, 'wavelab_test_signals.npz')
|
||||
wavelab_result_dict = np.load(wavelab_data_file)
|
||||
|
||||
|
||||
def test_data_aero():
|
||||
aero = pywt.data.aero()
|
||||
|
||||
ref = np.array([[178, 178, 179],
|
||||
[170, 173, 171],
|
||||
[185, 174, 171]])
|
||||
|
||||
assert_allclose(aero[:3, :3], ref)
|
||||
|
||||
|
||||
def test_data_ascent():
|
||||
ascent = pywt.data.ascent()
|
||||
|
||||
ref = np.array([[83, 83, 83],
|
||||
[82, 82, 83],
|
||||
[80, 81, 83]])
|
||||
|
||||
assert_allclose(ascent[:3, :3], ref)
|
||||
|
||||
|
||||
def test_data_camera():
|
||||
ascent = pywt.data.camera()
|
||||
|
||||
ref = np.array([[156, 157, 160],
|
||||
[156, 157, 159],
|
||||
[158, 157, 156]])
|
||||
|
||||
assert_allclose(ascent[:3, :3], ref)
|
||||
|
||||
|
||||
def test_data_ecg():
|
||||
ecg = pywt.data.ecg()
|
||||
|
||||
ref = np.array([-86, -87, -87])
|
||||
|
||||
assert_allclose(ecg[:3], ref)
|
||||
|
||||
|
||||
def test_wavelab_signals():
|
||||
"""Comparison with results generated using WaveLab"""
|
||||
rtol = atol = 1e-12
|
||||
|
||||
# get a list of the available signals
|
||||
available_signals = pywt.data.demo_signal('list')
|
||||
assert_('Doppler' in available_signals)
|
||||
|
||||
for signal in available_signals:
|
||||
# reference dictionary has lowercase names for the keys
|
||||
key = signal.replace('-', '_').lower()
|
||||
val = wavelab_result_dict[key]
|
||||
if key in ['gabor', 'sineoneoverx']:
|
||||
# these functions do not allow a size to be provided
|
||||
assert_allclose(val, pywt.data.demo_signal(signal),
|
||||
rtol=rtol, atol=atol)
|
||||
assert_raises(ValueError, pywt.data.demo_signal, key, val.size)
|
||||
else:
|
||||
assert_allclose(val, pywt.data.demo_signal(signal, val.size),
|
||||
rtol=rtol, atol=atol)
|
||||
# these functions require a size to be provided
|
||||
assert_raises(ValueError, pywt.data.demo_signal, key)
|
||||
|
||||
# ValueError on unrecognized signal type
|
||||
assert_raises(ValueError, pywt.data.demo_signal, 'unknown_signal', 512)
|
||||
|
||||
# ValueError on invalid length
|
||||
assert_raises(ValueError, pywt.data.demo_signal, 'Doppler', 0)
|
89
venv/Lib/site-packages/pywt/tests/test_deprecations.py
Normal file
89
venv/Lib/site-packages/pywt/tests/test_deprecations.py
Normal file
|
@ -0,0 +1,89 @@
|
|||
import warnings
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import assert_warns, assert_array_equal
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_intwave_deprecation():
|
||||
wavelet = pywt.Wavelet('db3')
|
||||
assert_warns(DeprecationWarning, pywt.intwave, wavelet)
|
||||
|
||||
|
||||
def test_centrfrq_deprecation():
|
||||
wavelet = pywt.Wavelet('db3')
|
||||
assert_warns(DeprecationWarning, pywt.centrfrq, wavelet)
|
||||
|
||||
|
||||
def test_scal2frq_deprecation():
|
||||
wavelet = pywt.Wavelet('db3')
|
||||
assert_warns(DeprecationWarning, pywt.scal2frq, wavelet, 1)
|
||||
|
||||
|
||||
def test_orthfilt_deprecation():
|
||||
assert_warns(DeprecationWarning, pywt.orthfilt, range(6))
|
||||
|
||||
|
||||
def test_integrate_wave_tuple():
|
||||
sig = [0, 1, 2, 3]
|
||||
xgrid = [0, 1, 2, 3]
|
||||
assert_warns(DeprecationWarning, pywt.integrate_wavelet, (sig, xgrid))
|
||||
|
||||
|
||||
old_modes = ['zpd',
|
||||
'cpd',
|
||||
'sym',
|
||||
'ppd',
|
||||
'sp1',
|
||||
'per',
|
||||
]
|
||||
|
||||
|
||||
def test_MODES_from_object_deprecation():
|
||||
for mode in old_modes:
|
||||
assert_warns(DeprecationWarning, pywt.Modes.from_object, mode)
|
||||
|
||||
|
||||
def test_MODES_attributes_deprecation():
|
||||
def get_mode(Modes, name):
|
||||
return getattr(Modes, name)
|
||||
|
||||
for mode in old_modes:
|
||||
assert_warns(DeprecationWarning, get_mode, pywt.Modes, mode)
|
||||
|
||||
|
||||
def test_MODES_deprecation_new():
|
||||
def use_MODES_new():
|
||||
return pywt.MODES.symmetric
|
||||
|
||||
assert_warns(DeprecationWarning, use_MODES_new)
|
||||
|
||||
|
||||
def test_MODES_deprecation_old():
|
||||
def use_MODES_old():
|
||||
return pywt.MODES.sym
|
||||
|
||||
assert_warns(DeprecationWarning, use_MODES_old)
|
||||
|
||||
|
||||
def test_MODES_deprecation_getattr():
|
||||
def use_MODES_new():
|
||||
return getattr(pywt.MODES, 'symmetric')
|
||||
|
||||
assert_warns(DeprecationWarning, use_MODES_new)
|
||||
|
||||
|
||||
def test_mode_equivalence():
|
||||
old_new = [('zpd', 'zero'),
|
||||
('cpd', 'constant'),
|
||||
('sym', 'symmetric'),
|
||||
('ppd', 'periodic'),
|
||||
('sp1', 'smooth'),
|
||||
('per', 'periodization')]
|
||||
x = np.arange(8.)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', DeprecationWarning)
|
||||
for old, new in old_new:
|
||||
assert_array_equal(pywt.dwt(x, 'db2', mode=old),
|
||||
pywt.dwt(x, 'db2', mode=new))
|
25
venv/Lib/site-packages/pywt/tests/test_doc.py
Normal file
25
venv/Lib/site-packages/pywt/tests/test_doc.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import doctest
|
||||
import glob
|
||||
import os
|
||||
import unittest
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
np.set_printoptions(legacy='1.13')
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
pdir = os.path.pardir
|
||||
docs_base = os.path.abspath(os.path.join(os.path.dirname(__file__),
|
||||
pdir, pdir, "doc", "source"))
|
||||
|
||||
files = glob.glob(os.path.join(docs_base, "*.rst")) + \
|
||||
glob.glob(os.path.join(docs_base, "*", "*.rst"))
|
||||
|
||||
suite = doctest.DocFileSuite(*files, module_relative=False, encoding="utf-8")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.TextTestRunner().run(suite)
|
299
venv/Lib/site-packages/pywt/tests/test_dwt_idwt.py
Normal file
299
venv/Lib/site-packages/pywt/tests/test_dwt_idwt.py
Normal file
|
@ -0,0 +1,299 @@
|
|||
#!/usr/bin/env python
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import (assert_allclose, assert_, assert_raises,
|
||||
assert_array_equal)
|
||||
import pywt
|
||||
|
||||
# Check that float32, float64, complex64, complex128 are preserved.
|
||||
# Other real types get converted to float64.
|
||||
# complex256 gets converted to complex128
|
||||
dtypes_in = [np.int8, np.float16, np.float32, np.float64, np.complex64,
|
||||
np.complex128]
|
||||
dtypes_out = [np.float64, np.float32, np.float32, np.float64, np.complex64,
|
||||
np.complex128]
|
||||
|
||||
# test complex256 as well if it is available
|
||||
try:
|
||||
dtypes_in += [np.complex256, ]
|
||||
dtypes_out += [np.complex128, ]
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
def test_dwt_idwt_basic():
|
||||
x = [3, 7, 1, 1, -2, 5, 4, 6]
|
||||
cA, cD = pywt.dwt(x, 'db2')
|
||||
cA_expect = [5.65685425, 7.39923721, 0.22414387, 3.33677403, 7.77817459]
|
||||
cD_expect = [-2.44948974, -1.60368225, -4.44140056, -0.41361256,
|
||||
1.22474487]
|
||||
assert_allclose(cA, cA_expect)
|
||||
assert_allclose(cD, cD_expect)
|
||||
|
||||
x_roundtrip = pywt.idwt(cA, cD, 'db2')
|
||||
assert_allclose(x_roundtrip, x, rtol=1e-10)
|
||||
|
||||
# mismatched dtypes OK
|
||||
x_roundtrip2 = pywt.idwt(cA.astype(np.float64), cD.astype(np.float32),
|
||||
'db2')
|
||||
assert_allclose(x_roundtrip2, x, rtol=1e-7, atol=1e-7)
|
||||
assert_(x_roundtrip2.dtype == np.float64)
|
||||
|
||||
|
||||
def test_idwt_mixed_complex_dtype():
|
||||
x = np.arange(8).astype(float)
|
||||
x = x + 1j*x[::-1]
|
||||
cA, cD = pywt.dwt(x, 'db2')
|
||||
|
||||
x_roundtrip = pywt.idwt(cA, cD, 'db2')
|
||||
assert_allclose(x_roundtrip, x, rtol=1e-10)
|
||||
|
||||
# mismatched dtypes OK
|
||||
x_roundtrip2 = pywt.idwt(cA.astype(np.complex128), cD.astype(np.complex64),
|
||||
'db2')
|
||||
assert_allclose(x_roundtrip2, x, rtol=1e-7, atol=1e-7)
|
||||
assert_(x_roundtrip2.dtype == np.complex128)
|
||||
|
||||
|
||||
def test_dwt_idwt_dtypes():
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
|
||||
x = np.ones(4, dtype=dt_in)
|
||||
errmsg = "wrong dtype returned for {0} input".format(dt_in)
|
||||
|
||||
cA, cD = pywt.dwt(x, wavelet)
|
||||
assert_(cA.dtype == cD.dtype == dt_out, "dwt: " + errmsg)
|
||||
|
||||
x_roundtrip = pywt.idwt(cA, cD, wavelet)
|
||||
assert_(x_roundtrip.dtype == dt_out, "idwt: " + errmsg)
|
||||
|
||||
|
||||
def test_dwt_idwt_basic_complex():
|
||||
x = np.asarray([3, 7, 1, 1, -2, 5, 4, 6])
|
||||
x = x + 0.5j*x
|
||||
cA, cD = pywt.dwt(x, 'db2')
|
||||
cA_expect = np.asarray([5.65685425, 7.39923721, 0.22414387, 3.33677403,
|
||||
7.77817459])
|
||||
cA_expect = cA_expect + 0.5j*cA_expect
|
||||
cD_expect = np.asarray([-2.44948974, -1.60368225, -4.44140056, -0.41361256,
|
||||
1.22474487])
|
||||
cD_expect = cD_expect + 0.5j*cD_expect
|
||||
assert_allclose(cA, cA_expect)
|
||||
assert_allclose(cD, cD_expect)
|
||||
|
||||
x_roundtrip = pywt.idwt(cA, cD, 'db2')
|
||||
assert_allclose(x_roundtrip, x, rtol=1e-10)
|
||||
|
||||
|
||||
def test_dwt_idwt_partial_complex():
|
||||
x = np.asarray([3, 7, 1, 1, -2, 5, 4, 6])
|
||||
x = x + 0.5j*x
|
||||
|
||||
cA, cD = pywt.dwt(x, 'haar')
|
||||
cA_rec_expect = np.array([5.0+2.5j, 5.0+2.5j, 1.0+0.5j, 1.0+0.5j,
|
||||
1.5+0.75j, 1.5+0.75j, 5.0+2.5j, 5.0+2.5j])
|
||||
cA_rec = pywt.idwt(cA, None, 'haar')
|
||||
assert_allclose(cA_rec, cA_rec_expect)
|
||||
|
||||
cD_rec_expect = np.array([-2.0-1.0j, 2.0+1.0j, 0.0+0.0j, 0.0+0.0j,
|
||||
-3.5-1.75j, 3.5+1.75j, -1.0-0.5j, 1.0+0.5j])
|
||||
cD_rec = pywt.idwt(None, cD, 'haar')
|
||||
assert_allclose(cD_rec, cD_rec_expect)
|
||||
|
||||
assert_allclose(cA_rec + cD_rec, x)
|
||||
|
||||
|
||||
def test_dwt_wavelet_kwd():
|
||||
x = np.array([3, 7, 1, 1, -2, 5, 4, 6])
|
||||
w = pywt.Wavelet('sym3')
|
||||
cA, cD = pywt.dwt(x, wavelet=w, mode='constant')
|
||||
cA_expect = [4.38354585, 3.80302657, 7.31813271, -0.58565539, 4.09727044,
|
||||
7.81994027]
|
||||
cD_expect = [-1.33068221, -2.78795192, -3.16825651, -0.67715519,
|
||||
-0.09722957, -0.07045258]
|
||||
assert_allclose(cA, cA_expect)
|
||||
assert_allclose(cD, cD_expect)
|
||||
|
||||
|
||||
def test_dwt_coeff_len():
|
||||
x = np.array([3, 7, 1, 1, -2, 5, 4, 6])
|
||||
w = pywt.Wavelet('sym3')
|
||||
ln_modes = [pywt.dwt_coeff_len(len(x), w.dec_len, mode) for mode in
|
||||
pywt.Modes.modes]
|
||||
|
||||
expected_result = [6, ] * len(pywt.Modes.modes)
|
||||
expected_result[pywt.Modes.modes.index('periodization')] = 4
|
||||
|
||||
assert_allclose(ln_modes, expected_result)
|
||||
ln_modes = [pywt.dwt_coeff_len(len(x), w, mode) for mode in
|
||||
pywt.Modes.modes]
|
||||
assert_allclose(ln_modes, expected_result)
|
||||
|
||||
|
||||
def test_idwt_none_input():
|
||||
# None input equals arrays of zeros of the right length
|
||||
res1 = pywt.idwt([1, 2, 0, 1], None, 'db2', 'symmetric')
|
||||
res2 = pywt.idwt([1, 2, 0, 1], [0, 0, 0, 0], 'db2', 'symmetric')
|
||||
assert_allclose(res1, res2, rtol=1e-15, atol=1e-15)
|
||||
|
||||
res1 = pywt.idwt(None, [1, 2, 0, 1], 'db2', 'symmetric')
|
||||
res2 = pywt.idwt([0, 0, 0, 0], [1, 2, 0, 1], 'db2', 'symmetric')
|
||||
assert_allclose(res1, res2, rtol=1e-15, atol=1e-15)
|
||||
|
||||
# Only one argument at a time can be None
|
||||
assert_raises(ValueError, pywt.idwt, None, None, 'db2', 'symmetric')
|
||||
|
||||
|
||||
def test_idwt_invalid_input():
|
||||
# Too short, min length is 4 for 'db4':
|
||||
assert_raises(ValueError, pywt.idwt, [1, 2, 4], [4, 1, 3], 'db4', 'symmetric')
|
||||
|
||||
|
||||
def test_dwt_single_axis():
|
||||
x = [[3, 7, 1, 1],
|
||||
[-2, 5, 4, 6]]
|
||||
|
||||
cA, cD = pywt.dwt(x, 'db2', axis=-1)
|
||||
|
||||
cA0, cD0 = pywt.dwt(x[0], 'db2')
|
||||
cA1, cD1 = pywt.dwt(x[1], 'db2')
|
||||
|
||||
assert_allclose(cA[0], cA0)
|
||||
assert_allclose(cA[1], cA1)
|
||||
|
||||
assert_allclose(cD[0], cD0)
|
||||
assert_allclose(cD[1], cD1)
|
||||
|
||||
|
||||
def test_idwt_single_axis():
|
||||
x = [[3, 7, 1, 1],
|
||||
[-2, 5, 4, 6]]
|
||||
|
||||
x = np.asarray(x)
|
||||
x = x + 1j*x # test with complex data
|
||||
cA, cD = pywt.dwt(x, 'db2', axis=-1)
|
||||
|
||||
x0 = pywt.idwt(cA[0], cD[0], 'db2', axis=-1)
|
||||
x1 = pywt.idwt(cA[1], cD[1], 'db2', axis=-1)
|
||||
|
||||
assert_allclose(x[0], x0)
|
||||
assert_allclose(x[1], x1)
|
||||
|
||||
|
||||
def test_dwt_axis_arg():
|
||||
x = [[3, 7, 1, 1],
|
||||
[-2, 5, 4, 6]]
|
||||
|
||||
cA_, cD_ = pywt.dwt(x, 'db2', axis=-1)
|
||||
cA, cD = pywt.dwt(x, 'db2', axis=1)
|
||||
|
||||
assert_allclose(cA_, cA)
|
||||
assert_allclose(cD_, cD)
|
||||
|
||||
|
||||
def test_idwt_axis_arg():
|
||||
x = [[3, 7, 1, 1],
|
||||
[-2, 5, 4, 6]]
|
||||
|
||||
cA, cD = pywt.dwt(x, 'db2', axis=1)
|
||||
|
||||
x_ = pywt.idwt(cA, cD, 'db2', axis=-1)
|
||||
x = pywt.idwt(cA, cD, 'db2', axis=1)
|
||||
|
||||
assert_allclose(x_, x)
|
||||
|
||||
|
||||
def test_dwt_idwt_axis_excess():
|
||||
x = [[3, 7, 1, 1],
|
||||
[-2, 5, 4, 6]]
|
||||
# can't transform over axes that aren't there
|
||||
assert_raises(ValueError,
|
||||
pywt.dwt, x, 'db2', 'symmetric', axis=2)
|
||||
|
||||
assert_raises(ValueError,
|
||||
pywt.idwt, [1, 2, 4], [4, 1, 3], 'db2', 'symmetric', axis=1)
|
||||
|
||||
|
||||
def test_error_on_continuous_wavelet():
|
||||
# A ValueError is raised if a Continuous wavelet is selected
|
||||
data = np.ones((32, ))
|
||||
for cwave in ['morl', pywt.DiscreteContinuousWavelet('morl')]:
|
||||
assert_raises(ValueError, pywt.dwt, data, cwave)
|
||||
|
||||
cA, cD = pywt.dwt(data, 'db1')
|
||||
assert_raises(ValueError, pywt.idwt, cA, cD, cwave)
|
||||
|
||||
|
||||
def test_dwt_zero_size_axes():
|
||||
# raise on empty input array
|
||||
assert_raises(ValueError, pywt.dwt, [], 'db2')
|
||||
|
||||
# >1D case uses a different code path so check there as well
|
||||
x = np.ones((1, 4))[0:0, :] # 2D with a size zero axis
|
||||
assert_raises(ValueError, pywt.dwt, x, 'db2', axis=0)
|
||||
|
||||
|
||||
def test_pad_1d():
|
||||
x = [1, 2, 3]
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'periodization'),
|
||||
[1, 2, 3, 3, 1, 2, 3, 3, 1, 2, 3, 3, 1, 2])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'periodic'),
|
||||
[3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'constant'),
|
||||
[1, 1, 1, 1, 1, 2, 3, 3, 3, 3, 3, 3, 3])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'zero'),
|
||||
[0, 0, 0, 0, 1, 2, 3, 0, 0, 0, 0, 0, 0])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'smooth'),
|
||||
[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'symmetric'),
|
||||
[3, 3, 2, 1, 1, 2, 3, 3, 2, 1, 1, 2, 3])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'antisymmetric'),
|
||||
[3, -3, -2, -1, 1, 2, 3, -3, -2, -1, 1, 2, 3])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'reflect'),
|
||||
[1, 2, 3, 2, 1, 2, 3, 2, 1, 2, 3, 2, 1])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'antireflect'),
|
||||
[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||
|
||||
# equivalence of various pad_width formats
|
||||
assert_array_equal(pywt.pad(x, 4, 'periodic'),
|
||||
pywt.pad(x, (4, 4), 'periodic'))
|
||||
|
||||
assert_array_equal(pywt.pad(x, (4, ), 'periodic'),
|
||||
pywt.pad(x, (4, 4), 'periodic'))
|
||||
|
||||
assert_array_equal(pywt.pad(x, [(4, 4)], 'periodic'),
|
||||
pywt.pad(x, (4, 4), 'periodic'))
|
||||
|
||||
|
||||
def test_pad_errors():
|
||||
# negative pad width
|
||||
x = [1, 2, 3]
|
||||
assert_raises(ValueError, pywt.pad, x, -2, 'periodic')
|
||||
|
||||
# wrong length pad width
|
||||
assert_raises(ValueError, pywt.pad, x, (1, 1, 1), 'periodic')
|
||||
|
||||
# invalid mode name
|
||||
assert_raises(ValueError, pywt.pad, x, 2, 'bad_mode')
|
||||
|
||||
|
||||
def test_pad_nd():
|
||||
for ndim in [2, 3]:
|
||||
x = np.arange(4**ndim).reshape((4, ) * ndim)
|
||||
if ndim == 2:
|
||||
pad_widths = [(2, 1), (2, 3)]
|
||||
else:
|
||||
pad_widths = [(2, 1), ] * ndim
|
||||
for mode in pywt.Modes.modes:
|
||||
xp = pywt.pad(x, pad_widths, mode)
|
||||
|
||||
# expected result is the same as applying along axes separably
|
||||
xp_expected = x.copy()
|
||||
for ax in range(ndim):
|
||||
xp_expected = np.apply_along_axis(pywt.pad,
|
||||
ax,
|
||||
xp_expected,
|
||||
pad_widths=[pad_widths[ax]],
|
||||
mode=mode)
|
||||
assert_array_equal(xp, xp_expected)
|
38
venv/Lib/site-packages/pywt/tests/test_functions.py
Normal file
38
venv/Lib/site-packages/pywt/tests/test_functions.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
#!/usr/bin/env python
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
from numpy.testing import assert_almost_equal, assert_allclose
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_centrfreq():
|
||||
# db1 is Haar function, frequency=1
|
||||
w = pywt.Wavelet('db1')
|
||||
expected = 1
|
||||
result = pywt.central_frequency(w, precision=12)
|
||||
assert_almost_equal(result, expected, decimal=3)
|
||||
# db2, frequency=2/3
|
||||
w = pywt.Wavelet('db2')
|
||||
expected = 2/3.
|
||||
result = pywt.central_frequency(w, precision=12)
|
||||
assert_almost_equal(result, expected)
|
||||
|
||||
|
||||
def test_scal2frq_scale():
|
||||
scale = 2
|
||||
w = pywt.Wavelet('db1')
|
||||
expected = 1. / scale
|
||||
result = pywt.scale2frequency(w, scale, precision=12)
|
||||
assert_almost_equal(result, expected, decimal=3)
|
||||
|
||||
|
||||
def test_intwave_orthogonal():
|
||||
w = pywt.Wavelet('db1')
|
||||
int_psi, x = pywt.integrate_wavelet(w, precision=12)
|
||||
ix = x < 0.5
|
||||
# For x < 0.5, the integral is equal to x
|
||||
assert_allclose(int_psi[ix], x[ix])
|
||||
# For x > 0.5, the integral is equal to (1 - x)
|
||||
# Ignore last point here, there x > 1 and something goes wrong
|
||||
assert_allclose(int_psi[~ix][:-1], 1 - x[~ix][:-1], atol=1e-10)
|
160
venv/Lib/site-packages/pywt/tests/test_matlab_compatibility.py
Normal file
160
venv/Lib/site-packages/pywt/tests/test_matlab_compatibility.py
Normal file
|
@ -0,0 +1,160 @@
|
|||
"""
|
||||
Test used to verify PyWavelets Discrete Wavelet Transform computation
|
||||
accuracy against MathWorks Wavelet Toolbox.
|
||||
"""
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from numpy.testing import assert_
|
||||
|
||||
import pywt
|
||||
from pywt._pytest import (uses_pymatbridge, uses_precomputed, size_set)
|
||||
from pywt._pytest import matlab_result_dict_dwt as matlab_result_dict
|
||||
|
||||
# list of mode names in pywt and matlab
|
||||
modes = [('zero', 'zpd'),
|
||||
('constant', 'sp0'),
|
||||
('symmetric', 'sym'),
|
||||
('reflect', 'symw'),
|
||||
('periodic', 'ppd'),
|
||||
('smooth', 'sp1'),
|
||||
('periodization', 'per'),
|
||||
# TODO: Now have implemented asymmetric modes too.
|
||||
# Would be nice to update the Matlab data to test these as well.
|
||||
('antisymmetric', 'asym'),
|
||||
('antireflect', 'asymw'),
|
||||
]
|
||||
|
||||
families = ('db', 'sym', 'coif', 'bior', 'rbio')
|
||||
wavelets = sum([pywt.wavelist(name) for name in families], [])
|
||||
|
||||
|
||||
def _get_data_sizes(w):
|
||||
""" Return the sizes to test for wavelet w. """
|
||||
if size_set == 'full':
|
||||
data_sizes = list(range(w.dec_len, 40)) + \
|
||||
[100, 200, 500, 1000, 50000]
|
||||
else:
|
||||
data_sizes = (w.dec_len, w.dec_len + 1)
|
||||
return data_sizes
|
||||
|
||||
|
||||
@uses_pymatbridge
|
||||
@pytest.mark.slow
|
||||
def test_accuracy_pymatbridge():
|
||||
Matlab = pytest.importorskip("pymatbridge.Matlab")
|
||||
mlab = Matlab()
|
||||
|
||||
rstate = np.random.RandomState(1234)
|
||||
# max RMSE (was 1.0e-10, is reduced to 5.0e-5 due to different coefficents)
|
||||
epsilon = 5.0e-5
|
||||
epsilon_pywt_coeffs = 1.0e-10
|
||||
mlab.start()
|
||||
try:
|
||||
for wavelet in wavelets:
|
||||
w = pywt.Wavelet(wavelet)
|
||||
mlab.set_variable('wavelet', wavelet)
|
||||
for N in _get_data_sizes(w):
|
||||
data = rstate.randn(N)
|
||||
mlab.set_variable('data', data)
|
||||
for pmode, mmode in modes:
|
||||
ma, md = _compute_matlab_result(data, wavelet, mmode, mlab)
|
||||
_check_accuracy(data, w, pmode, ma, md, wavelet, epsilon)
|
||||
ma, md = _load_matlab_result_pywt_coeffs(data, wavelet, mmode)
|
||||
_check_accuracy(data, w, pmode, ma, md, wavelet, epsilon_pywt_coeffs)
|
||||
|
||||
finally:
|
||||
mlab.stop()
|
||||
|
||||
|
||||
@uses_precomputed
|
||||
@pytest.mark.slow
|
||||
def test_accuracy_precomputed():
|
||||
# Keep this specific random seed to match the precomputed Matlab result.
|
||||
rstate = np.random.RandomState(1234)
|
||||
# max RMSE (was 1.0e-10, is reduced to 5.0e-5 due to different coefficents)
|
||||
epsilon = 5.0e-5
|
||||
epsilon_pywt_coeffs = 1.0e-10
|
||||
for wavelet in wavelets:
|
||||
w = pywt.Wavelet(wavelet)
|
||||
for N in _get_data_sizes(w):
|
||||
data = rstate.randn(N)
|
||||
for pmode, mmode in modes:
|
||||
ma, md = _load_matlab_result(data, wavelet, mmode)
|
||||
_check_accuracy(data, w, pmode, ma, md, wavelet, epsilon)
|
||||
ma, md = _load_matlab_result_pywt_coeffs(data, wavelet, mmode)
|
||||
_check_accuracy(data, w, pmode, ma, md, wavelet, epsilon_pywt_coeffs)
|
||||
|
||||
|
||||
def _compute_matlab_result(data, wavelet, mmode, mlab):
|
||||
""" Compute the result using MATLAB.
|
||||
|
||||
This function assumes that the Matlab variables `wavelet` and `data` have
|
||||
already been set externally.
|
||||
"""
|
||||
if np.any((wavelet == np.array(['coif6', 'coif7', 'coif8', 'coif9', 'coif10', 'coif11', 'coif12', 'coif13', 'coif14', 'coif15', 'coif16', 'coif17'])),axis=0):
|
||||
w = pywt.Wavelet(wavelet)
|
||||
mlab.set_variable('Lo_D', w.dec_lo)
|
||||
mlab.set_variable('Hi_D', w.dec_hi)
|
||||
mlab_code = ("[ma, md] = dwt(data, Lo_D, Hi_D, 'mode', '%s');" % mmode)
|
||||
else:
|
||||
mlab_code = "[ma, md] = dwt(data, wavelet, 'mode', '%s');" % mmode
|
||||
res = mlab.run_code(mlab_code)
|
||||
if not res['success']:
|
||||
raise RuntimeError("Matlab failed to execute the provided code. "
|
||||
"Check that the wavelet toolbox is installed.")
|
||||
# need np.asarray because sometimes the output is a single float64
|
||||
ma = np.asarray(mlab.get_variable('ma'))
|
||||
md = np.asarray(mlab.get_variable('md'))
|
||||
return ma, md
|
||||
|
||||
|
||||
def _load_matlab_result(data, wavelet, mmode):
|
||||
""" Load the precomputed result.
|
||||
"""
|
||||
N = len(data)
|
||||
ma_key = '_'.join([mmode, wavelet, str(N), 'ma'])
|
||||
md_key = '_'.join([mmode, wavelet, str(N), 'md'])
|
||||
if (ma_key not in matlab_result_dict) or \
|
||||
(md_key not in matlab_result_dict):
|
||||
raise KeyError(
|
||||
"Precompted Matlab result not found for wavelet: "
|
||||
"{0}, mode: {1}, size: {2}".format(wavelet, mmode, N))
|
||||
ma = matlab_result_dict[ma_key]
|
||||
md = matlab_result_dict[md_key]
|
||||
return ma, md
|
||||
|
||||
|
||||
def _load_matlab_result_pywt_coeffs(data, wavelet, mmode):
|
||||
""" Load the precomputed result.
|
||||
"""
|
||||
N = len(data)
|
||||
ma_key = '_'.join([mmode, wavelet, str(N), 'ma_pywtCoeffs'])
|
||||
md_key = '_'.join([mmode, wavelet, str(N), 'md_pywtCoeffs'])
|
||||
if (ma_key not in matlab_result_dict) or \
|
||||
(md_key not in matlab_result_dict):
|
||||
raise KeyError(
|
||||
"Precompted Matlab result not found for wavelet: "
|
||||
"{0}, mode: {1}, size: {2}".format(wavelet, mmode, N))
|
||||
ma = matlab_result_dict[ma_key]
|
||||
md = matlab_result_dict[md_key]
|
||||
return ma, md
|
||||
|
||||
|
||||
def _check_accuracy(data, w, pmode, ma, md, wavelet, epsilon):
|
||||
# PyWavelets result
|
||||
pa, pd = pywt.dwt(data, w, pmode)
|
||||
|
||||
# calculate error measures
|
||||
rms_a = np.sqrt(np.mean((pa - ma) ** 2))
|
||||
rms_d = np.sqrt(np.mean((pd - md) ** 2))
|
||||
|
||||
msg = ('[RMS_A > EPSILON] for Mode: %s, Wavelet: %s, '
|
||||
'Length: %d, rms=%.3g' % (pmode, wavelet, len(data), rms_a))
|
||||
assert_(rms_a < epsilon, msg=msg)
|
||||
|
||||
msg = ('[RMS_D > EPSILON] for Mode: %s, Wavelet: %s, '
|
||||
'Length: %d, rms=%.3g' % (pmode, wavelet, len(data), rms_d))
|
||||
assert_(rms_d < epsilon, msg=msg)
|
|
@ -0,0 +1,174 @@
|
|||
"""
|
||||
Test used to verify PyWavelets Continuous Wavelet Transform computation
|
||||
accuracy against MathWorks Wavelet Toolbox.
|
||||
"""
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pytest
|
||||
from numpy.testing import assert_
|
||||
|
||||
import pywt
|
||||
from pywt._pytest import (uses_pymatbridge, uses_precomputed, size_set,
|
||||
matlab_result_dict_cwt)
|
||||
|
||||
families = ('gaus', 'mexh', 'morl', 'cgau', 'shan', 'fbsp', 'cmor')
|
||||
wavelets = sum([pywt.wavelist(name) for name in families], [])
|
||||
|
||||
|
||||
def _get_data_sizes(w):
|
||||
""" Return the sizes to test for wavelet w. """
|
||||
if size_set == 'full':
|
||||
data_sizes = list(range(100, 101)) + \
|
||||
[100, 200, 500, 1000, 50000]
|
||||
else:
|
||||
data_sizes = (1000, 1000 + 1)
|
||||
return data_sizes
|
||||
|
||||
|
||||
def _get_scales(w):
|
||||
""" Return the scales to test for wavelet w. """
|
||||
if size_set == 'full':
|
||||
scales = (1, np.arange(1, 3), np.arange(1, 4), np.arange(1, 5))
|
||||
else:
|
||||
scales = (1, np.arange(1, 3))
|
||||
return scales
|
||||
|
||||
|
||||
@uses_pymatbridge # skip this case if precomputed results are used instead
|
||||
@pytest.mark.slow
|
||||
def test_accuracy_pymatbridge_cwt():
|
||||
Matlab = pytest.importorskip("pymatbridge.Matlab")
|
||||
mlab = Matlab()
|
||||
rstate = np.random.RandomState(1234)
|
||||
# max RMSE (was 1.0e-10, is reduced to 5.0e-5 due to different coefficents)
|
||||
epsilon = 1e-15
|
||||
epsilon_psi = 1e-15
|
||||
mlab.start()
|
||||
try:
|
||||
for wavelet in wavelets:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', FutureWarning)
|
||||
w = pywt.ContinuousWavelet(wavelet)
|
||||
if np.any((wavelet == np.array(['shan', 'cmor'])),axis=0):
|
||||
mlab.set_variable('wavelet', wavelet+str(w.bandwidth_frequency)+'-'+str(w.center_frequency))
|
||||
elif wavelet == 'fbsp':
|
||||
mlab.set_variable('wavelet', wavelet+str(w.fbsp_order)+'-'+str(w.bandwidth_frequency)+'-'+str(w.center_frequency))
|
||||
else:
|
||||
mlab.set_variable('wavelet', wavelet)
|
||||
mlab_code = ("psi = wavefun(wavelet,10)")
|
||||
res = mlab.run_code(mlab_code)
|
||||
psi = np.asarray(mlab.get_variable('psi'))
|
||||
_check_accuracy_psi(w, psi, wavelet, epsilon_psi)
|
||||
for N in _get_data_sizes(w):
|
||||
data = rstate.randn(N)
|
||||
mlab.set_variable('data', data)
|
||||
for scales in _get_scales(w):
|
||||
coefs = _compute_matlab_result(data, wavelet, scales, mlab)
|
||||
_check_accuracy(data, w, scales, coefs, wavelet, epsilon)
|
||||
|
||||
finally:
|
||||
mlab.stop()
|
||||
|
||||
|
||||
@uses_precomputed # skip this case if pymatbridge + Matlab are being used
|
||||
@pytest.mark.slow
|
||||
def test_accuracy_precomputed_cwt():
|
||||
# Keep this specific random seed to match the precomputed Matlab result.
|
||||
rstate = np.random.RandomState(1234)
|
||||
# has to be improved
|
||||
epsilon = 2e-15
|
||||
epsilon32 = 1e-5
|
||||
epsilon_psi = 1e-15
|
||||
for wavelet in wavelets:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', FutureWarning)
|
||||
w = pywt.ContinuousWavelet(wavelet)
|
||||
w32 = pywt.ContinuousWavelet(wavelet,dtype=np.float32)
|
||||
psi = _load_matlab_result_psi(wavelet)
|
||||
_check_accuracy_psi(w, psi, wavelet, epsilon_psi)
|
||||
|
||||
for N in _get_data_sizes(w):
|
||||
data = rstate.randn(N)
|
||||
data32 = data.astype(np.float32)
|
||||
scales_count = 0
|
||||
for scales in _get_scales(w):
|
||||
scales_count += 1
|
||||
coefs = _load_matlab_result(data, wavelet, scales_count)
|
||||
_check_accuracy(data, w, scales, coefs, wavelet, epsilon)
|
||||
_check_accuracy(data32, w32, scales, coefs, wavelet, epsilon32)
|
||||
|
||||
|
||||
def _compute_matlab_result(data, wavelet, scales, mlab):
|
||||
""" Compute the result using MATLAB.
|
||||
|
||||
This function assumes that the Matlab variables `wavelet` and `data` have
|
||||
already been set externally.
|
||||
"""
|
||||
mlab.set_variable('scales', scales)
|
||||
mlab_code = ("coefs = cwt(data, scales, wavelet)")
|
||||
res = mlab.run_code(mlab_code)
|
||||
if not res['success']:
|
||||
raise RuntimeError("Matlab failed to execute the provided code. "
|
||||
"Check that the wavelet toolbox is installed.")
|
||||
# need np.asarray because sometimes the output is a single float64
|
||||
coefs = np.asarray(mlab.get_variable('coefs'))
|
||||
return coefs
|
||||
|
||||
|
||||
def _load_matlab_result(data, wavelet, scales):
|
||||
""" Load the precomputed result.
|
||||
"""
|
||||
N = len(data)
|
||||
coefs_key = '_'.join([str(scales), wavelet, str(N), 'coefs'])
|
||||
if (coefs_key not in matlab_result_dict_cwt):
|
||||
raise KeyError(
|
||||
"Precompted Matlab result not found for wavelet: "
|
||||
"{0}, mode: {1}, size: {2}".format(wavelet, scales, N))
|
||||
coefs = matlab_result_dict_cwt[coefs_key]
|
||||
return coefs
|
||||
|
||||
|
||||
def _load_matlab_result_psi(wavelet):
|
||||
""" Load the precomputed result.
|
||||
"""
|
||||
psi_key = '_'.join([wavelet, 'psi'])
|
||||
if (psi_key not in matlab_result_dict_cwt):
|
||||
raise KeyError(
|
||||
"Precompted Matlab psi result not found for wavelet: "
|
||||
"{0}}".format(wavelet))
|
||||
psi = matlab_result_dict_cwt[psi_key]
|
||||
return psi
|
||||
|
||||
|
||||
def _check_accuracy(data, w, scales, coefs, wavelet, epsilon):
|
||||
# PyWavelets result
|
||||
coefs_pywt, freq = pywt.cwt(data, scales, w)
|
||||
|
||||
# coefs from Matlab are from R2012a which is missing the complex conjugate
|
||||
# as shown in Eq. 2 of Torrence and Compo. We take the complex conjugate of
|
||||
# the precomputed Matlab result to account for this.
|
||||
coefs = np.conj(coefs)
|
||||
|
||||
# calculate error measures
|
||||
err = coefs_pywt - coefs
|
||||
rms = np.real(np.sqrt(np.mean(np.conj(err) * err)))
|
||||
|
||||
msg = ('[RMS > EPSILON] for Scale: %s, Wavelet: %s, '
|
||||
'Length: %d, rms=%.3g' % (scales, wavelet, len(data), rms))
|
||||
assert_(rms < epsilon, msg=msg)
|
||||
|
||||
|
||||
def _check_accuracy_psi(w, psi, wavelet, epsilon):
|
||||
# PyWavelets result
|
||||
psi_pywt, x = w.wavefun(length=1024)
|
||||
|
||||
# calculate error measures
|
||||
err = psi_pywt.flatten() - psi.flatten()
|
||||
rms = np.real(np.sqrt(np.mean(np.conj(err) * err)))
|
||||
|
||||
msg = ('[RMS > EPSILON] for Wavelet: %s, '
|
||||
'rms=%.3g' % (wavelet, rms))
|
||||
assert_(rms < epsilon, msg=msg)
|
109
venv/Lib/site-packages/pywt/tests/test_modes.py
Normal file
109
venv/Lib/site-packages/pywt/tests/test_modes.py
Normal file
|
@ -0,0 +1,109 @@
|
|||
#!/usr/bin/env python
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import assert_raises, assert_equal, assert_allclose
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_available_modes():
|
||||
modes = ['zero', 'constant', 'symmetric', 'periodic', 'smooth',
|
||||
'periodization', 'reflect', 'antisymmetric', 'antireflect']
|
||||
assert_equal(pywt.Modes.modes, modes)
|
||||
assert_equal(pywt.Modes.from_object('constant'), 2)
|
||||
|
||||
|
||||
def test_invalid_modes():
|
||||
x = np.arange(4)
|
||||
assert_raises(ValueError, pywt.dwt, x, 'db2', 'unknown')
|
||||
assert_raises(ValueError, pywt.dwt, x, 'db2', -1)
|
||||
assert_raises(ValueError, pywt.dwt, x, 'db2', 9)
|
||||
assert_raises(TypeError, pywt.dwt, x, 'db2', None)
|
||||
|
||||
assert_raises(ValueError, pywt.Modes.from_object, 'unknown')
|
||||
assert_raises(ValueError, pywt.Modes.from_object, -1)
|
||||
assert_raises(ValueError, pywt.Modes.from_object, 9)
|
||||
assert_raises(TypeError, pywt.Modes.from_object, None)
|
||||
|
||||
|
||||
def test_dwt_idwt_allmodes():
|
||||
# Test that :func:`dwt` and :func:`idwt` can be performed using every mode
|
||||
x = [1, 2, 1, 5, -1, 8, 4, 6]
|
||||
dwt_results = {
|
||||
'zero': ([-0.03467518, 1.73309178, 3.40612438, 6.32928585, 6.95094948],
|
||||
[-0.12940952, -2.15599552, -5.95034847, -1.21545369,
|
||||
-1.8625013]),
|
||||
'constant': ([1.28480404, 1.73309178, 3.40612438, 6.32928585,
|
||||
7.51935555],
|
||||
[-0.48296291, -2.15599552, -5.95034847, -1.21545369,
|
||||
0.25881905]),
|
||||
'symmetric': ([1.76776695, 1.73309178, 3.40612438, 6.32928585,
|
||||
7.77817459],
|
||||
[-0.61237244, -2.15599552, -5.95034847, -1.21545369,
|
||||
1.22474487]),
|
||||
'reflect': ([2.12132034, 1.73309178, 3.40612438, 6.32928585,
|
||||
6.81224877],
|
||||
[-0.70710678, -2.15599552, -5.95034847, -1.21545369,
|
||||
-2.38013939]),
|
||||
'periodic': ([6.9162743, 1.73309178, 3.40612438, 6.32928585,
|
||||
6.9162743],
|
||||
[-1.99191082, -2.15599552, -5.95034847, -1.21545369,
|
||||
-1.99191082]),
|
||||
'smooth': ([-0.51763809, 1.73309178, 3.40612438, 6.32928585,
|
||||
7.45000519],
|
||||
[0, -2.15599552, -5.95034847, -1.21545369, 0]),
|
||||
'periodization': ([4.053172, 3.05257099, 2.85381112, 8.42522221],
|
||||
[0.18946869, 4.18258152, 4.33737503, 2.60428326]),
|
||||
'antisymmetric': ([-1.83711731, 1.73309178, 3.40612438, 6.32928585,
|
||||
6.12372436],
|
||||
[0.353553391, -2.15599552, -5.95034847, -1.21545369,
|
||||
-4.94974747]),
|
||||
'antireflect': ([0.44828774, 1.73309178, 3.40612438, 6.32928585,
|
||||
8.22646233],
|
||||
[-0.25881905, -2.15599552, -5.95034847, -1.21545369,
|
||||
2.89777748])
|
||||
}
|
||||
|
||||
for mode in pywt.Modes.modes:
|
||||
cA, cD = pywt.dwt(x, 'db2', mode)
|
||||
assert_allclose(cA, dwt_results[mode][0], rtol=1e-7, atol=1e-8)
|
||||
assert_allclose(cD, dwt_results[mode][1], rtol=1e-7, atol=1e-8)
|
||||
assert_allclose(pywt.idwt(cA, cD, 'db2', mode), x, rtol=1e-10)
|
||||
|
||||
|
||||
def test_dwt_short_input_allmodes():
|
||||
# some test cases where the input is shorter than the DWT filter
|
||||
x = [1, 3, 2]
|
||||
wavelet = 'db2'
|
||||
# manually pad each end by the filter size (4 for 'db2' used here)
|
||||
padded_x = {'zero': [0, 0, 0, 0, 1, 3, 2, 0, 0, 0, 0],
|
||||
'constant': [1, 1, 1, 1, 1, 3, 2, 2, 2, 2, 2],
|
||||
'symmetric': [2, 2, 3, 1, 1, 3, 2, 2, 3, 1, 1],
|
||||
'reflect': [1, 3, 2, 3, 1, 3, 2, 3, 1, 3, 2],
|
||||
'periodic': [2, 1, 3, 2, 1, 3, 2, 1, 3, 2, 1],
|
||||
'smooth': [-7, -5, -3, -1, 1, 3, 2, 1, 0, -1, -2],
|
||||
'antisymmetric': [2, -2, -3, -1, 1, 3, 2, -2, -3, -1, 1],
|
||||
'antireflect': [1, -1, 0, -1, 1, 3, 2, 1, 3, 5, 4],
|
||||
}
|
||||
for mode, xpad in padded_x.items():
|
||||
# DWT of the manually padded array. will discard edges later so
|
||||
# symmetric mode used here doesn't matter.
|
||||
cApad, cDpad = pywt.dwt(xpad, wavelet, mode='symmetric')
|
||||
|
||||
# central region of the padded output (unaffected by mode )
|
||||
expected_result = (cApad[2:-2], cDpad[2:-2])
|
||||
|
||||
cA, cD = pywt.dwt(x, wavelet, mode)
|
||||
assert_allclose(cA, expected_result[0], rtol=1e-7, atol=1e-8)
|
||||
assert_allclose(cD, expected_result[1], rtol=1e-7, atol=1e-8)
|
||||
|
||||
|
||||
def test_default_mode():
|
||||
# The default mode should be 'symmetric'
|
||||
x = [1, 2, 1, 5, -1, 8, 4, 6]
|
||||
cA, cD = pywt.dwt(x, 'db2')
|
||||
cA2, cD2 = pywt.dwt(x, 'db2', mode='symmetric')
|
||||
assert_allclose(cA, cA2)
|
||||
assert_allclose(cD, cD2)
|
||||
assert_allclose(pywt.idwt(cA, cD, 'db2'), x)
|
443
venv/Lib/site-packages/pywt/tests/test_multidim.py
Normal file
443
venv/Lib/site-packages/pywt/tests/test_multidim.py
Normal file
|
@ -0,0 +1,443 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
from itertools import combinations
|
||||
from numpy.testing import assert_allclose, assert_, assert_raises, assert_equal
|
||||
|
||||
import pywt
|
||||
# Check that float32, float64, complex64, complex128 are preserved.
|
||||
# Other real types get converted to float64.
|
||||
# complex256 gets converted to complex128
|
||||
dtypes_in = [np.int8, np.float16, np.float32, np.float64, np.complex64,
|
||||
np.complex128]
|
||||
dtypes_out = [np.float64, np.float32, np.float32, np.float64, np.complex64,
|
||||
np.complex128]
|
||||
|
||||
# test complex256 as well if it is available
|
||||
try:
|
||||
dtypes_in += [np.complex256, ]
|
||||
dtypes_out += [np.complex128, ]
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
def test_dwtn_input():
|
||||
# Array-like must be accepted
|
||||
pywt.dwtn([1, 2, 3, 4], 'haar')
|
||||
# Others must not
|
||||
data = dict()
|
||||
assert_raises(TypeError, pywt.dwtn, data, 'haar')
|
||||
# Must be at least 1D
|
||||
assert_raises(ValueError, pywt.dwtn, 2, 'haar')
|
||||
|
||||
|
||||
def test_3D_reconstruct():
|
||||
data = np.array([
|
||||
[[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 26, 3, 2, 1],
|
||||
[5, 8, 2, 33, 4, 9],
|
||||
[2, 5, 19, 4, 19, 1]],
|
||||
[[1, 5, 1, 2, 3, 4],
|
||||
[7, 12, 6, 52, 7, 8],
|
||||
[2, 12, 3, 52, 6, 8],
|
||||
[5, 2, 6, 78, 12, 2]]])
|
||||
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for mode in pywt.Modes.modes:
|
||||
d = pywt.dwtn(data, wavelet, mode=mode)
|
||||
assert_allclose(data, pywt.idwtn(d, wavelet, mode=mode),
|
||||
rtol=1e-13, atol=1e-13)
|
||||
|
||||
|
||||
def test_dwdtn_idwtn_allwavelets():
|
||||
rstate = np.random.RandomState(1234)
|
||||
r = rstate.randn(16, 16)
|
||||
# test 2D case only for all wavelet types
|
||||
wavelist = pywt.wavelist()
|
||||
if 'dmey' in wavelist:
|
||||
wavelist.remove('dmey')
|
||||
for wavelet in wavelist:
|
||||
if wavelet in ['cmor', 'shan', 'fbsp']:
|
||||
# skip these CWT families to avoid warnings
|
||||
continue
|
||||
if isinstance(pywt.DiscreteContinuousWavelet(wavelet), pywt.Wavelet):
|
||||
for mode in pywt.Modes.modes:
|
||||
coeffs = pywt.dwtn(r, wavelet, mode=mode)
|
||||
assert_allclose(pywt.idwtn(coeffs, wavelet, mode=mode),
|
||||
r, rtol=1e-7, atol=1e-7)
|
||||
|
||||
|
||||
def test_stride():
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
for dtype in ('float32', 'float64'):
|
||||
data = np.array([[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]],
|
||||
dtype=dtype)
|
||||
|
||||
for mode in pywt.Modes.modes:
|
||||
expected = pywt.dwtn(data, wavelet)
|
||||
strided = np.ones((3, 12), dtype=data.dtype)
|
||||
strided[::-1, ::2] = data
|
||||
strided_dwtn = pywt.dwtn(strided[::-1, ::2], wavelet)
|
||||
for key in expected.keys():
|
||||
assert_allclose(strided_dwtn[key], expected[key])
|
||||
|
||||
|
||||
def test_byte_offset():
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for dtype in ('float32', 'float64'):
|
||||
data = np.array([[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]],
|
||||
dtype=dtype)
|
||||
|
||||
for mode in pywt.Modes.modes:
|
||||
expected = pywt.dwtn(data, wavelet)
|
||||
padded = np.ones((3, 6), dtype=np.dtype({'data': (data.dtype, 0),
|
||||
'pad': ('byte', data.dtype.itemsize)},
|
||||
align=True))
|
||||
padded[:] = data
|
||||
padded_dwtn = pywt.dwtn(padded['data'], wavelet)
|
||||
for key in expected.keys():
|
||||
assert_allclose(padded_dwtn[key], expected[key])
|
||||
|
||||
|
||||
def test_3D_reconstruct_complex():
|
||||
# All dimensions even length so `take` does not need to be specified
|
||||
data = np.array([
|
||||
[[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 26, 3, 2, 1],
|
||||
[5, 8, 2, 33, 4, 9],
|
||||
[2, 5, 19, 4, 19, 1]],
|
||||
[[1, 5, 1, 2, 3, 4],
|
||||
[7, 12, 6, 52, 7, 8],
|
||||
[2, 12, 3, 52, 6, 8],
|
||||
[5, 2, 6, 78, 12, 2]]])
|
||||
data = data + 1j
|
||||
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
d = pywt.dwtn(data, wavelet)
|
||||
# idwtn creates even-length shapes (2x dwtn size)
|
||||
original_shape = tuple([slice(None, s) for s in data.shape])
|
||||
assert_allclose(data, pywt.idwtn(d, wavelet)[original_shape],
|
||||
rtol=1e-13, atol=1e-13)
|
||||
|
||||
|
||||
def test_idwtn_idwt2():
|
||||
data = np.array([
|
||||
[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]])
|
||||
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
LL, (HL, LH, HH) = pywt.dwt2(data, wavelet)
|
||||
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH}
|
||||
|
||||
for mode in pywt.Modes.modes:
|
||||
assert_allclose(pywt.idwt2((LL, (HL, LH, HH)), wavelet, mode=mode),
|
||||
pywt.idwtn(d, wavelet, mode=mode),
|
||||
rtol=1e-14, atol=1e-14)
|
||||
|
||||
|
||||
def test_idwtn_idwt2_complex():
|
||||
data = np.array([
|
||||
[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]])
|
||||
data = data + 1j
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
LL, (HL, LH, HH) = pywt.dwt2(data, wavelet)
|
||||
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH}
|
||||
|
||||
for mode in pywt.Modes.modes:
|
||||
assert_allclose(pywt.idwt2((LL, (HL, LH, HH)), wavelet, mode=mode),
|
||||
pywt.idwtn(d, wavelet, mode=mode),
|
||||
rtol=1e-14, atol=1e-14)
|
||||
|
||||
|
||||
def test_idwtn_missing():
|
||||
# Test to confirm missing data behave as zeroes
|
||||
data = np.array([
|
||||
[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]])
|
||||
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
coefs = pywt.dwtn(data, wavelet)
|
||||
|
||||
# No point removing zero, or all
|
||||
for num_missing in range(1, len(coefs)):
|
||||
for missing in combinations(coefs.keys(), num_missing):
|
||||
missing_coefs = coefs.copy()
|
||||
for key in missing:
|
||||
del missing_coefs[key]
|
||||
LL = missing_coefs.get('aa', None)
|
||||
HL = missing_coefs.get('da', None)
|
||||
LH = missing_coefs.get('ad', None)
|
||||
HH = missing_coefs.get('dd', None)
|
||||
|
||||
assert_allclose(pywt.idwt2((LL, (HL, LH, HH)), wavelet),
|
||||
pywt.idwtn(missing_coefs, 'haar'), atol=1e-15)
|
||||
|
||||
|
||||
def test_idwtn_all_coeffs_None():
|
||||
coefs = dict(aa=None, da=None, ad=None, dd=None)
|
||||
assert_raises(ValueError, pywt.idwtn, coefs, 'haar')
|
||||
|
||||
|
||||
def test_error_on_invalid_keys():
|
||||
data = np.array([
|
||||
[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]])
|
||||
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
LL, (HL, LH, HH) = pywt.dwt2(data, wavelet)
|
||||
|
||||
# unexpected key
|
||||
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH, 'ff': LH}
|
||||
assert_raises(ValueError, pywt.idwtn, d, wavelet)
|
||||
|
||||
# mismatched key lengths
|
||||
d = {'a': LL, 'da': HL, 'ad': LH, 'dd': HH}
|
||||
assert_raises(ValueError, pywt.idwtn, d, wavelet)
|
||||
|
||||
|
||||
def test_error_mismatched_size():
|
||||
data = np.array([
|
||||
[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]])
|
||||
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
LL, (HL, LH, HH) = pywt.dwt2(data, wavelet)
|
||||
|
||||
# Pass/fail depends on first element being shorter than remaining ones so
|
||||
# set 3/4 to an incorrect size to maximize chances. Order of dict items
|
||||
# is random so may not trigger on every test run. Dict is constructed
|
||||
# inside idwtn function so no use using an OrderedDict here.
|
||||
LL = LL[:, :-1]
|
||||
LH = LH[:, :-1]
|
||||
HH = HH[:, :-1]
|
||||
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH}
|
||||
|
||||
assert_raises(ValueError, pywt.idwtn, d, wavelet)
|
||||
|
||||
|
||||
def test_dwt2_idwt2_dtypes():
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
|
||||
x = np.ones((4, 4), dtype=dt_in)
|
||||
errmsg = "wrong dtype returned for {0} input".format(dt_in)
|
||||
|
||||
cA, (cH, cV, cD) = pywt.dwt2(x, wavelet)
|
||||
assert_(cA.dtype == cH.dtype == cV.dtype == cD.dtype,
|
||||
"dwt2: " + errmsg)
|
||||
|
||||
x_roundtrip = pywt.idwt2((cA, (cH, cV, cD)), wavelet)
|
||||
assert_(x_roundtrip.dtype == dt_out, "idwt2: " + errmsg)
|
||||
|
||||
|
||||
def test_dwtn_axes():
|
||||
data = np.array([[0, 1, 2, 3],
|
||||
[1, 1, 1, 1],
|
||||
[1, 4, 2, 8]])
|
||||
data = data + 1j*data # test with complex data
|
||||
coefs = pywt.dwtn(data, 'haar', axes=(1,))
|
||||
expected_a = list(map(lambda x: pywt.dwt(x, 'haar')[0], data))
|
||||
assert_equal(coefs['a'], expected_a)
|
||||
expected_d = list(map(lambda x: pywt.dwt(x, 'haar')[1], data))
|
||||
assert_equal(coefs['d'], expected_d)
|
||||
|
||||
coefs = pywt.dwtn(data, 'haar', axes=(1, 1))
|
||||
expected_aa = list(map(lambda x: pywt.dwt(x, 'haar')[0], expected_a))
|
||||
assert_equal(coefs['aa'], expected_aa)
|
||||
expected_ad = list(map(lambda x: pywt.dwt(x, 'haar')[1], expected_a))
|
||||
assert_equal(coefs['ad'], expected_ad)
|
||||
|
||||
|
||||
def test_idwtn_axes():
|
||||
data = np.array([[0, 1, 2, 3],
|
||||
[1, 1, 1, 1],
|
||||
[1, 4, 2, 8]])
|
||||
data = data + 1j*data # test with complex data
|
||||
coefs = pywt.dwtn(data, 'haar', axes=(1, 1))
|
||||
assert_allclose(pywt.idwtn(coefs, 'haar', axes=(1, 1)), data, atol=1e-14)
|
||||
|
||||
|
||||
def test_idwt2_none_coeffs():
|
||||
data = np.array([[0, 1, 2, 3],
|
||||
[1, 1, 1, 1],
|
||||
[1, 4, 2, 8]])
|
||||
data = data + 1j*data # test with complex data
|
||||
cA, (cH, cV, cD) = pywt.dwt2(data, 'haar', axes=(1, 1))
|
||||
|
||||
# verify setting coefficients to None is the same as zeroing them
|
||||
cD = np.zeros_like(cD)
|
||||
result_zeros = pywt.idwt2((cA, (cH, cV, cD)), 'haar', axes=(1, 1))
|
||||
|
||||
cD = None
|
||||
result_none = pywt.idwt2((cA, (cH, cV, cD)), 'haar', axes=(1, 1))
|
||||
|
||||
assert_equal(result_zeros, result_none)
|
||||
|
||||
|
||||
def test_idwtn_none_coeffs():
|
||||
data = np.array([[0, 1, 2, 3],
|
||||
[1, 1, 1, 1],
|
||||
[1, 4, 2, 8]])
|
||||
data = data + 1j*data # test with complex data
|
||||
coefs = pywt.dwtn(data, 'haar', axes=(1, 1))
|
||||
|
||||
# verify setting coefficients to None is the same as zeroing them
|
||||
coefs['dd'] = np.zeros_like(coefs['dd'])
|
||||
result_zeros = pywt.idwtn(coefs, 'haar', axes=(1, 1))
|
||||
|
||||
coefs['dd'] = None
|
||||
result_none = pywt.idwtn(coefs, 'haar', axes=(1, 1))
|
||||
|
||||
assert_equal(result_zeros, result_none)
|
||||
|
||||
|
||||
def test_idwt2_axes():
|
||||
data = np.array([[0, 1, 2, 3],
|
||||
[1, 1, 1, 1],
|
||||
[1, 4, 2, 8]])
|
||||
coefs = pywt.dwt2(data, 'haar', axes=(1, 1))
|
||||
assert_allclose(pywt.idwt2(coefs, 'haar', axes=(1, 1)), data, atol=1e-14)
|
||||
|
||||
# too many axes
|
||||
assert_raises(ValueError, pywt.idwt2, coefs, 'haar', axes=(0, 1, 1))
|
||||
|
||||
|
||||
def test_idwt2_axes_subsets():
|
||||
data = np.array(np.random.standard_normal((4, 4, 4)))
|
||||
# test all combinations of 2 out of 3 axes transformed
|
||||
for axes in combinations((0, 1, 2), 2):
|
||||
coefs = pywt.dwt2(data, 'haar', axes=axes)
|
||||
assert_allclose(pywt.idwt2(coefs, 'haar', axes=axes), data, atol=1e-14)
|
||||
|
||||
|
||||
def test_idwtn_axes_subsets():
|
||||
data = np.array(np.random.standard_normal((4, 4, 4, 4)))
|
||||
# test all combinations of 3 out of 4 axes transformed
|
||||
for axes in combinations((0, 1, 2, 3), 3):
|
||||
coefs = pywt.dwtn(data, 'haar', axes=axes)
|
||||
assert_allclose(pywt.idwtn(coefs, 'haar', axes=axes), data, atol=1e-14)
|
||||
|
||||
|
||||
def test_negative_axes():
|
||||
data = np.array([[0, 1, 2, 3],
|
||||
[1, 1, 1, 1],
|
||||
[1, 4, 2, 8]])
|
||||
coefs1 = pywt.dwtn(data, 'haar', axes=(1, 1))
|
||||
coefs2 = pywt.dwtn(data, 'haar', axes=(-1, -1))
|
||||
assert_equal(coefs1, coefs2)
|
||||
|
||||
rec1 = pywt.idwtn(coefs1, 'haar', axes=(1, 1))
|
||||
rec2 = pywt.idwtn(coefs1, 'haar', axes=(-1, -1))
|
||||
assert_equal(rec1, rec2)
|
||||
|
||||
|
||||
def test_dwtn_idwtn_dtypes():
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
|
||||
x = np.ones((4, 4), dtype=dt_in)
|
||||
errmsg = "wrong dtype returned for {0} input".format(dt_in)
|
||||
|
||||
coeffs = pywt.dwtn(x, wavelet)
|
||||
for k, v in coeffs.items():
|
||||
assert_(v.dtype == dt_out, "dwtn: " + errmsg)
|
||||
|
||||
x_roundtrip = pywt.idwtn(coeffs, wavelet)
|
||||
assert_(x_roundtrip.dtype == dt_out, "idwtn: " + errmsg)
|
||||
|
||||
|
||||
def test_idwtn_mixed_complex_dtype():
|
||||
rstate = np.random.RandomState(0)
|
||||
x = rstate.randn(8, 8, 8)
|
||||
x = x + 1j*x
|
||||
coeffs = pywt.dwtn(x, 'db2')
|
||||
|
||||
x_roundtrip = pywt.idwtn(coeffs, 'db2')
|
||||
assert_allclose(x_roundtrip, x, rtol=1e-10)
|
||||
|
||||
# mismatched dtypes OK
|
||||
coeffs['a' * x.ndim] = coeffs['a' * x.ndim].astype(np.complex64)
|
||||
x_roundtrip2 = pywt.idwtn(coeffs, 'db2')
|
||||
assert_allclose(x_roundtrip2, x, rtol=1e-7, atol=1e-7)
|
||||
assert_(x_roundtrip2.dtype == np.complex128)
|
||||
|
||||
|
||||
def test_idwt2_size_mismatch_error():
|
||||
LL = np.zeros((6, 6))
|
||||
LH = HL = HH = np.zeros((5, 5))
|
||||
|
||||
assert_raises(ValueError, pywt.idwt2, (LL, (LH, HL, HH)), wavelet='haar')
|
||||
|
||||
|
||||
def test_dwt2_dimension_error():
|
||||
data = np.ones(16)
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
# wrong number of input dimensions
|
||||
assert_raises(ValueError, pywt.dwt2, data, wavelet)
|
||||
|
||||
# too many axes
|
||||
data2 = np.ones((8, 8))
|
||||
assert_raises(ValueError, pywt.dwt2, data2, wavelet, axes=(0, 1, 1))
|
||||
|
||||
|
||||
def test_per_axis_wavelets_and_modes():
|
||||
# tests seperate wavelet and edge mode for each axis.
|
||||
rstate = np.random.RandomState(1234)
|
||||
data = rstate.randn(16, 16, 16)
|
||||
|
||||
# wavelet can be a string or wavelet object
|
||||
wavelets = (pywt.Wavelet('haar'), 'sym2', 'db4')
|
||||
|
||||
# mode can be a string or a Modes enum
|
||||
modes = ('symmetric', 'periodization',
|
||||
pywt._extensions._pywt.Modes.reflect)
|
||||
|
||||
coefs = pywt.dwtn(data, wavelets, modes)
|
||||
assert_allclose(pywt.idwtn(coefs, wavelets, modes), data, atol=1e-14)
|
||||
|
||||
coefs = pywt.dwtn(data, wavelets[:1], modes)
|
||||
assert_allclose(pywt.idwtn(coefs, wavelets[:1], modes), data, atol=1e-14)
|
||||
|
||||
coefs = pywt.dwtn(data, wavelets, modes[:1])
|
||||
assert_allclose(pywt.idwtn(coefs, wavelets, modes[:1]), data, atol=1e-14)
|
||||
|
||||
# length of wavelets or modes doesn't match the length of axes
|
||||
assert_raises(ValueError, pywt.dwtn, data, wavelets[:2])
|
||||
assert_raises(ValueError, pywt.dwtn, data, wavelets, mode=modes[:2])
|
||||
assert_raises(ValueError, pywt.idwtn, coefs, wavelets[:2])
|
||||
assert_raises(ValueError, pywt.idwtn, coefs, wavelets, mode=modes[:2])
|
||||
|
||||
# dwt2/idwt2 also support per-axis wavelets/modes
|
||||
data2 = data[..., 0]
|
||||
coefs2 = pywt.dwt2(data2, wavelets[:2], modes[:2])
|
||||
assert_allclose(pywt.idwt2(coefs2, wavelets[:2], modes[:2]), data2,
|
||||
atol=1e-14)
|
||||
|
||||
|
||||
def test_error_on_continuous_wavelet():
|
||||
# A ValueError is raised if a Continuous wavelet is selected
|
||||
data = np.ones((16, 16))
|
||||
for dec_fun, rec_fun in zip([pywt.dwt2, pywt.dwtn],
|
||||
[pywt.idwt2, pywt.idwtn]):
|
||||
for cwave in ['morl', pywt.DiscreteContinuousWavelet('morl')]:
|
||||
assert_raises(ValueError, dec_fun, data, wavelet=cwave)
|
||||
|
||||
c = dec_fun(data, 'db1')
|
||||
assert_raises(ValueError, rec_fun, c, wavelet=cwave)
|
1033
venv/Lib/site-packages/pywt/tests/test_multilevel.py
Normal file
1033
venv/Lib/site-packages/pywt/tests/test_multilevel.py
Normal file
File diff suppressed because it is too large
Load diff
|
@ -0,0 +1,61 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
Verify DWT perfect reconstruction.
|
||||
"""
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import assert_
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_perfect_reconstruction():
|
||||
families = ('db', 'sym', 'coif', 'bior', 'rbio')
|
||||
wavelets = sum([pywt.wavelist(name) for name in families], [])
|
||||
# list of mode names in pywt and matlab
|
||||
modes = [('zero', 'zpd'),
|
||||
('constant', 'sp0'),
|
||||
('symmetric', 'sym'),
|
||||
('periodic', 'ppd'),
|
||||
('smooth', 'sp1'),
|
||||
('periodization', 'per')]
|
||||
|
||||
dtypes = (np.float32, np.float64)
|
||||
|
||||
for wavelet in wavelets:
|
||||
for pmode, mmode in modes:
|
||||
for dt in dtypes:
|
||||
check_reconstruction(pmode, mmode, wavelet, dt)
|
||||
|
||||
|
||||
def check_reconstruction(pmode, mmode, wavelet, dtype):
|
||||
data_size = list(range(2, 40)) + [100, 200, 500, 1000, 2000, 10000,
|
||||
50000, 100000]
|
||||
np.random.seed(12345)
|
||||
# TODO: smoke testing - more failures for different seeds
|
||||
|
||||
if dtype == np.float32:
|
||||
# was 3e-7 has to be lowered as db21, db29, db33, db35, coif14, coif16 were failing
|
||||
epsilon = 6e-7
|
||||
else:
|
||||
epsilon = 5e-11
|
||||
|
||||
for N in data_size:
|
||||
data = np.asarray(np.random.random(N), dtype)
|
||||
|
||||
# compute dwt coefficients
|
||||
pa, pd = pywt.dwt(data, wavelet, pmode)
|
||||
|
||||
# compute reconstruction
|
||||
rec = pywt.idwt(pa, pd, wavelet, pmode)
|
||||
|
||||
if len(data) % 2:
|
||||
rec = rec[:len(data)]
|
||||
|
||||
rms_rec = np.sqrt(np.mean((data-rec)**2))
|
||||
msg = ('[RMS_REC > EPSILON] for Mode: %s, Wavelet: %s, '
|
||||
'Length: %d, rms=%.3g' % (pmode, wavelet, len(data), rms_rec))
|
||||
assert_(rms_rec < epsilon, msg=msg)
|
633
venv/Lib/site-packages/pywt/tests/test_swt.py
Normal file
633
venv/Lib/site-packages/pywt/tests/test_swt.py
Normal file
|
@ -0,0 +1,633 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from itertools import combinations, permutations
|
||||
import numpy as np
|
||||
import pytest
|
||||
from numpy.testing import (assert_allclose, assert_, assert_equal,
|
||||
assert_raises, assert_array_equal, assert_warns)
|
||||
|
||||
import pywt
|
||||
from pywt._extensions._swt import swt_axis
|
||||
|
||||
# Check that float32 and complex64 are preserved. Other real types get
|
||||
# converted to float64.
|
||||
dtypes_in = [np.int8, np.float16, np.float32, np.float64, np.complex64,
|
||||
np.complex128]
|
||||
dtypes_out = [np.float64, np.float32, np.float32, np.float64, np.complex64,
|
||||
np.complex128]
|
||||
|
||||
# tolerances used in accuracy comparisons
|
||||
tol_single = 1e-6
|
||||
tol_double = 1e-13
|
||||
|
||||
####
|
||||
# 1d multilevel swt tests
|
||||
####
|
||||
|
||||
|
||||
def test_swt_decomposition():
|
||||
x = [3, 7, 1, 3, -2, 6, 4, 6]
|
||||
db1 = pywt.Wavelet('db1')
|
||||
atol = tol_double
|
||||
(cA3, cD3), (cA2, cD2), (cA1, cD1) = pywt.swt(x, db1, level=3)
|
||||
expected_cA1 = [7.07106781, 5.65685425, 2.82842712, 0.70710678,
|
||||
2.82842712, 7.07106781, 7.07106781, 6.36396103]
|
||||
assert_allclose(cA1, expected_cA1, rtol=1e-8, atol=atol)
|
||||
expected_cD1 = [-2.82842712, 4.24264069, -1.41421356, 3.53553391,
|
||||
-5.65685425, 1.41421356, -1.41421356, 2.12132034]
|
||||
assert_allclose(cD1, expected_cD1, rtol=1e-8, atol=atol)
|
||||
expected_cA2 = [7, 4.5, 4, 5.5, 7, 9.5, 10, 8.5]
|
||||
assert_allclose(cA2, expected_cA2, rtol=tol_double, atol=atol)
|
||||
expected_cD2 = [3, 3.5, 0, -4.5, -3, 0.5, 0, 0.5]
|
||||
assert_allclose(cD2, expected_cD2, rtol=tol_double, atol=atol)
|
||||
expected_cA3 = [9.89949494, ] * 8
|
||||
assert_allclose(cA3, expected_cA3, rtol=1e-8, atol=atol)
|
||||
expected_cD3 = [0.00000000, -3.53553391, -4.24264069, -2.12132034,
|
||||
0.00000000, 3.53553391, 4.24264069, 2.12132034]
|
||||
assert_allclose(cD3, expected_cD3, rtol=1e-8, atol=atol)
|
||||
|
||||
# level=1, start_level=1 decomposition should match level=2
|
||||
res = pywt.swt(cA1, db1, level=1, start_level=1)
|
||||
cA2, cD2 = res[0]
|
||||
assert_allclose(cA2, expected_cA2, rtol=tol_double, atol=atol)
|
||||
assert_allclose(cD2, expected_cD2, rtol=tol_double, atol=atol)
|
||||
|
||||
coeffs = pywt.swt(x, db1)
|
||||
assert_(len(coeffs) == 3)
|
||||
assert_(pywt.swt_max_level(len(x)), 3)
|
||||
|
||||
|
||||
def test_swt_max_level():
|
||||
# odd sized signal will warn about no levels of decomposition possible
|
||||
assert_warns(UserWarning, pywt.swt_max_level, 11)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', UserWarning)
|
||||
assert_equal(pywt.swt_max_level(11), 0)
|
||||
|
||||
# no warnings when >= 1 level of decomposition possible
|
||||
assert_equal(pywt.swt_max_level(2), 1) # divisible by 2**1
|
||||
assert_equal(pywt.swt_max_level(4*3), 2) # divisible by 2**2
|
||||
assert_equal(pywt.swt_max_level(16), 4) # divisible by 2**4
|
||||
assert_equal(pywt.swt_max_level(16*3), 4) # divisible by 2**4
|
||||
|
||||
|
||||
def test_swt_axis():
|
||||
x = [3, 7, 1, 3, -2, 6, 4, 6]
|
||||
|
||||
db1 = pywt.Wavelet('db1')
|
||||
(cA2, cD2), (cA1, cD1) = pywt.swt(x, db1, level=2)
|
||||
|
||||
# test cases use 2D arrays based on tiling x along an axis and then
|
||||
# calling swt along the other axis.
|
||||
for order in ['C', 'F']:
|
||||
# test SWT of 2D data along default axis (-1)
|
||||
x_2d = np.asarray(x).reshape((1, -1))
|
||||
x_2d = np.concatenate((x_2d, )*5, axis=0)
|
||||
if order == 'C':
|
||||
x_2d = np.ascontiguousarray(x_2d)
|
||||
elif order == 'F':
|
||||
x_2d = np.asfortranarray(x_2d)
|
||||
(cA2_2d, cD2_2d), (cA1_2d, cD1_2d) = pywt.swt(x_2d, db1, level=2)
|
||||
|
||||
for c in [cA2_2d, cD2_2d, cA1_2d, cD1_2d]:
|
||||
assert_(c.shape == x_2d.shape)
|
||||
# each row should match the 1D result
|
||||
for row in cA1_2d:
|
||||
assert_array_equal(row, cA1)
|
||||
for row in cA2_2d:
|
||||
assert_array_equal(row, cA2)
|
||||
for row in cD1_2d:
|
||||
assert_array_equal(row, cD1)
|
||||
for row in cD2_2d:
|
||||
assert_array_equal(row, cD2)
|
||||
|
||||
# test SWT of 2D data along other axis (0)
|
||||
x_2d = np.asarray(x).reshape((-1, 1))
|
||||
x_2d = np.concatenate((x_2d, )*5, axis=1)
|
||||
if order == 'C':
|
||||
x_2d = np.ascontiguousarray(x_2d)
|
||||
elif order == 'F':
|
||||
x_2d = np.asfortranarray(x_2d)
|
||||
(cA2_2d, cD2_2d), (cA1_2d, cD1_2d) = pywt.swt(x_2d, db1, level=2,
|
||||
axis=0)
|
||||
|
||||
for c in [cA2_2d, cD2_2d, cA1_2d, cD1_2d]:
|
||||
assert_(c.shape == x_2d.shape)
|
||||
# each column should match the 1D result
|
||||
for row in cA1_2d.transpose((1, 0)):
|
||||
assert_array_equal(row, cA1)
|
||||
for row in cA2_2d.transpose((1, 0)):
|
||||
assert_array_equal(row, cA2)
|
||||
for row in cD1_2d.transpose((1, 0)):
|
||||
assert_array_equal(row, cD1)
|
||||
for row in cD2_2d.transpose((1, 0)):
|
||||
assert_array_equal(row, cD2)
|
||||
|
||||
# axis too large
|
||||
assert_raises(ValueError, pywt.swt, x, db1, level=2, axis=5)
|
||||
|
||||
|
||||
def test_swt_iswt_integration():
|
||||
# This function performs a round-trip swt/iswt transform test on
|
||||
# all available types of wavelets in PyWavelets - except the
|
||||
# 'dmey' wavelet. The latter has been excluded because it does not
|
||||
# produce very precise results. This is likely due to the fact
|
||||
# that the 'dmey' wavelet is a discrete approximation of a
|
||||
# continuous wavelet. All wavelets are tested up to 3 levels. The
|
||||
# test validates neither swt or iswt as such, but it does ensure
|
||||
# that they are each other's inverse.
|
||||
|
||||
max_level = 3
|
||||
wavelets = pywt.wavelist(kind='discrete')
|
||||
if 'dmey' in wavelets:
|
||||
# The 'dmey' wavelet seems to be a bit special - disregard it for now
|
||||
wavelets.remove('dmey')
|
||||
for current_wavelet_str in wavelets:
|
||||
current_wavelet = pywt.Wavelet(current_wavelet_str)
|
||||
input_length_power = int(np.ceil(np.log2(max(
|
||||
current_wavelet.dec_len,
|
||||
current_wavelet.rec_len))))
|
||||
input_length = 2**(input_length_power + max_level - 1)
|
||||
X = np.arange(input_length)
|
||||
for norm in [True, False]:
|
||||
if norm and not current_wavelet.orthogonal:
|
||||
# non-orthogonal wavelets to avoid warnings when norm=True
|
||||
continue
|
||||
for trim_approx in [True, False]:
|
||||
coeffs = pywt.swt(X, current_wavelet, max_level,
|
||||
trim_approx=trim_approx, norm=norm)
|
||||
Y = pywt.iswt(coeffs, current_wavelet, norm=norm)
|
||||
assert_allclose(Y, X, rtol=1e-5, atol=1e-7)
|
||||
|
||||
|
||||
def test_swt_dtypes():
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
|
||||
errmsg = "wrong dtype returned for {0} input".format(dt_in)
|
||||
|
||||
# swt
|
||||
x = np.ones(8, dtype=dt_in)
|
||||
(cA2, cD2), (cA1, cD1) = pywt.swt(x, wavelet, level=2)
|
||||
assert_(cA2.dtype == cD2.dtype == cA1.dtype == cD1.dtype == dt_out,
|
||||
"swt: " + errmsg)
|
||||
|
||||
# swt2
|
||||
x = np.ones((8, 8), dtype=dt_in)
|
||||
cA, (cH, cV, cD) = pywt.swt2(x, wavelet, level=1)[0]
|
||||
assert_(cA.dtype == cH.dtype == cV.dtype == cD.dtype == dt_out,
|
||||
"swt2: " + errmsg)
|
||||
|
||||
|
||||
def test_swt_roundtrip_dtypes():
|
||||
# verify perfect reconstruction for all dtypes
|
||||
rstate = np.random.RandomState(5)
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
|
||||
# swt, iswt
|
||||
x = rstate.standard_normal((8, )).astype(dt_in)
|
||||
c = pywt.swt(x, wavelet, level=2)
|
||||
xr = pywt.iswt(c, wavelet)
|
||||
assert_allclose(x, xr, rtol=1e-6, atol=1e-7)
|
||||
|
||||
# swt2, iswt2
|
||||
x = rstate.standard_normal((8, 8)).astype(dt_in)
|
||||
c = pywt.swt2(x, wavelet, level=2)
|
||||
xr = pywt.iswt2(c, wavelet)
|
||||
assert_allclose(x, xr, rtol=1e-6, atol=1e-7)
|
||||
|
||||
|
||||
def test_swt_default_level_by_axis():
|
||||
# make sure default number of levels matches the max level along the axis
|
||||
wav = 'db2'
|
||||
x = np.ones((2**3, 2**4, 2**5))
|
||||
for axis in (0, 1, 2):
|
||||
sdec = pywt.swt(x, wav, level=None, start_level=0, axis=axis)
|
||||
assert_equal(len(sdec), pywt.swt_max_level(x.shape[axis]))
|
||||
|
||||
|
||||
def test_swt2_ndim_error():
|
||||
x = np.ones(8)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', FutureWarning)
|
||||
assert_raises(ValueError, pywt.swt2, x, 'haar', level=1)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_swt2_iswt2_integration(wavelets=None):
|
||||
# This function performs a round-trip swt2/iswt2 transform test on
|
||||
# all available types of wavelets in PyWavelets - except the
|
||||
# 'dmey' wavelet. The latter has been excluded because it does not
|
||||
# produce very precise results. This is likely due to the fact
|
||||
# that the 'dmey' wavelet is a discrete approximation of a
|
||||
# continuous wavelet. All wavelets are tested up to 3 levels. The
|
||||
# test validates neither swt2 or iswt2 as such, but it does ensure
|
||||
# that they are each other's inverse.
|
||||
|
||||
max_level = 3
|
||||
if wavelets is None:
|
||||
wavelets = pywt.wavelist(kind='discrete')
|
||||
if 'dmey' in wavelets:
|
||||
# The 'dmey' wavelet is a special case - disregard it for now
|
||||
wavelets.remove('dmey')
|
||||
for current_wavelet_str in wavelets:
|
||||
current_wavelet = pywt.Wavelet(current_wavelet_str)
|
||||
input_length_power = int(np.ceil(np.log2(max(
|
||||
current_wavelet.dec_len,
|
||||
current_wavelet.rec_len))))
|
||||
input_length = 2**(input_length_power + max_level - 1)
|
||||
X = np.arange(input_length**2).reshape(input_length, input_length)
|
||||
|
||||
for norm in [True, False]:
|
||||
if norm and not current_wavelet.orthogonal:
|
||||
# non-orthogonal wavelets to avoid warnings when norm=True
|
||||
continue
|
||||
for trim_approx in [True, False]:
|
||||
coeffs = pywt.swt2(X, current_wavelet, max_level,
|
||||
trim_approx=trim_approx, norm=norm)
|
||||
Y = pywt.iswt2(coeffs, current_wavelet, norm=norm)
|
||||
assert_allclose(Y, X, rtol=1e-5, atol=1e-5)
|
||||
|
||||
|
||||
def test_swt2_iswt2_quick():
|
||||
test_swt2_iswt2_integration(wavelets=['db1', ])
|
||||
|
||||
|
||||
def test_swt2_iswt2_non_square(wavelets=None):
|
||||
for nrows in [8, 16, 48]:
|
||||
X = np.arange(nrows*32).reshape(nrows, 32)
|
||||
current_wavelet = 'db1'
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', FutureWarning)
|
||||
coeffs = pywt.swt2(X, current_wavelet, level=2)
|
||||
Y = pywt.iswt2(coeffs, current_wavelet)
|
||||
assert_allclose(Y, X, rtol=tol_single, atol=tol_single)
|
||||
|
||||
|
||||
def test_swt2_axes():
|
||||
atol = 1e-14
|
||||
current_wavelet = pywt.Wavelet('db2')
|
||||
input_length_power = int(np.ceil(np.log2(max(
|
||||
current_wavelet.dec_len,
|
||||
current_wavelet.rec_len))))
|
||||
input_length = 2**(input_length_power)
|
||||
X = np.arange(input_length**2).reshape(input_length, input_length)
|
||||
|
||||
(cA1, (cH1, cV1, cD1)) = pywt.swt2(X, current_wavelet, level=1)[0]
|
||||
# opposite order
|
||||
(cA2, (cH2, cV2, cD2)) = pywt.swt2(X, current_wavelet, level=1,
|
||||
axes=(1, 0))[0]
|
||||
assert_allclose(cA1, cA2, atol=atol)
|
||||
assert_allclose(cH1, cV2, atol=atol)
|
||||
assert_allclose(cV1, cH2, atol=atol)
|
||||
assert_allclose(cD1, cD2, atol=atol)
|
||||
|
||||
# duplicate axes not allowed
|
||||
assert_raises(ValueError, pywt.swt2, X, current_wavelet, 1,
|
||||
axes=(0, 0))
|
||||
# too few axes
|
||||
assert_raises(ValueError, pywt.swt2, X, current_wavelet, 1, axes=(0, ))
|
||||
|
||||
|
||||
def test_iswt2_2d_only():
|
||||
# iswt2 is not currently compatible with data that is not 2D
|
||||
x_3d = np.ones((4, 4, 4))
|
||||
c = pywt.swt2(x_3d, 'haar', level=1)
|
||||
assert_raises(ValueError, pywt.iswt2, c, 'haar')
|
||||
|
||||
|
||||
def test_swtn_axes():
|
||||
atol = 1e-14
|
||||
current_wavelet = pywt.Wavelet('db2')
|
||||
input_length_power = int(np.ceil(np.log2(max(
|
||||
current_wavelet.dec_len,
|
||||
current_wavelet.rec_len))))
|
||||
input_length = 2**(input_length_power)
|
||||
X = np.arange(input_length**2).reshape(input_length, input_length)
|
||||
coeffs = pywt.swtn(X, current_wavelet, level=1, axes=None)[0]
|
||||
# opposite order
|
||||
coeffs2 = pywt.swtn(X, current_wavelet, level=1, axes=(1, 0))[0]
|
||||
assert_allclose(coeffs['aa'], coeffs2['aa'], atol=atol)
|
||||
assert_allclose(coeffs['ad'], coeffs2['da'], atol=atol)
|
||||
assert_allclose(coeffs['da'], coeffs2['ad'], atol=atol)
|
||||
assert_allclose(coeffs['dd'], coeffs2['dd'], atol=atol)
|
||||
|
||||
# 0-level transform
|
||||
empty = pywt.swtn(X, current_wavelet, level=0)
|
||||
assert_equal(empty, [])
|
||||
|
||||
# duplicate axes not allowed
|
||||
assert_raises(ValueError, pywt.swtn, X, current_wavelet, 1, axes=(0, 0))
|
||||
|
||||
# data.ndim = 0
|
||||
assert_raises(ValueError, pywt.swtn, np.asarray([]), current_wavelet, 1)
|
||||
|
||||
# start_level too large
|
||||
assert_raises(ValueError, pywt.swtn, X, current_wavelet,
|
||||
level=1, start_level=2)
|
||||
|
||||
# level < 1 in swt_axis call
|
||||
assert_raises(ValueError, swt_axis, X, current_wavelet, level=0,
|
||||
start_level=0)
|
||||
# odd-sized data not allowed
|
||||
assert_raises(ValueError, swt_axis, X[:-1, :], current_wavelet, level=0,
|
||||
start_level=0, axis=0)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_swtn_iswtn_integration(wavelets=None):
|
||||
# This function performs a round-trip swtn/iswtn transform for various
|
||||
# possible combinations of:
|
||||
# 1.) 1 out of 2 axes of a 2D array
|
||||
# 2.) 2 out of 3 axes of a 3D array
|
||||
#
|
||||
# To keep test time down, only wavelets of length <= 8 are run.
|
||||
#
|
||||
# This test does not validate swtn or iswtn individually, but only
|
||||
# confirms that iswtn yields an (almost) perfect reconstruction of swtn.
|
||||
max_level = 3
|
||||
if wavelets is None:
|
||||
wavelets = pywt.wavelist(kind='discrete')
|
||||
if 'dmey' in wavelets:
|
||||
# The 'dmey' wavelet is a special case - disregard it for now
|
||||
wavelets.remove('dmey')
|
||||
for ndim_transform in range(1, 3):
|
||||
ndim = ndim_transform + 1
|
||||
for axes in combinations(range(ndim), ndim_transform):
|
||||
for current_wavelet_str in wavelets:
|
||||
wav = pywt.Wavelet(current_wavelet_str)
|
||||
if wav.dec_len > 8:
|
||||
continue # avoid excessive test duration
|
||||
input_length_power = int(np.ceil(np.log2(max(
|
||||
wav.dec_len,
|
||||
wav.rec_len))))
|
||||
N = 2**(input_length_power + max_level - 1)
|
||||
X = np.arange(N**ndim).reshape((N, )*ndim)
|
||||
|
||||
for norm in [True, False]:
|
||||
if norm and not wav.orthogonal:
|
||||
# non-orthogonal wavelets to avoid warnings
|
||||
continue
|
||||
for trim_approx in [True, False]:
|
||||
coeffs = pywt.swtn(X, wav, max_level, axes=axes,
|
||||
trim_approx=trim_approx, norm=norm)
|
||||
coeffs_copy = deepcopy(coeffs)
|
||||
Y = pywt.iswtn(coeffs, wav, axes=axes, norm=norm)
|
||||
assert_allclose(Y, X, rtol=1e-5, atol=1e-5)
|
||||
|
||||
# verify the inverse transform didn't modify any coeffs
|
||||
for c, c2 in zip(coeffs, coeffs_copy):
|
||||
for k, v in c.items():
|
||||
assert_array_equal(c2[k], v)
|
||||
|
||||
|
||||
def test_swtn_iswtn_quick():
|
||||
test_swtn_iswtn_integration(wavelets=['db1', ])
|
||||
|
||||
|
||||
def test_iswtn_errors():
|
||||
x = np.arange(8**3).reshape(8, 8, 8)
|
||||
max_level = 2
|
||||
axes = (0, 1)
|
||||
w = pywt.Wavelet('db1')
|
||||
coeffs = pywt.swtn(x, w, max_level, axes=axes)
|
||||
|
||||
# more axes than dimensions transformed
|
||||
assert_raises(ValueError, pywt.iswtn, coeffs, w, axes=(0, 1, 2))
|
||||
# duplicate axes not allowed
|
||||
assert_raises(ValueError, pywt.iswtn, coeffs, w, axes=(0, 0))
|
||||
# mismatched coefficient size
|
||||
coeffs[0]['da'] = coeffs[0]['da'][:-1, :]
|
||||
assert_raises(RuntimeError, pywt.iswtn, coeffs, w, axes=axes)
|
||||
|
||||
|
||||
def test_swtn_iswtn_unique_shape_per_axis():
|
||||
# test case for gh-460
|
||||
_shape = (1, 48, 32) # unique shape per axis
|
||||
wav = 'sym2'
|
||||
max_level = 3
|
||||
rstate = np.random.RandomState(0)
|
||||
for shape in permutations(_shape):
|
||||
# transform only along the non-singleton axes
|
||||
axes = [ax for ax, s in enumerate(shape) if s != 1]
|
||||
x = rstate.standard_normal(shape)
|
||||
c = pywt.swtn(x, wav, max_level, axes=axes)
|
||||
r = pywt.iswtn(c, wav, axes=axes)
|
||||
assert_allclose(x, r, rtol=1e-10, atol=1e-10)
|
||||
|
||||
|
||||
def test_per_axis_wavelets():
|
||||
# tests seperate wavelet for each axis.
|
||||
rstate = np.random.RandomState(1234)
|
||||
data = rstate.randn(16, 16, 16)
|
||||
level = 3
|
||||
|
||||
# wavelet can be a string or wavelet object
|
||||
wavelets = (pywt.Wavelet('haar'), 'sym2', 'db4')
|
||||
|
||||
coefs = pywt.swtn(data, wavelets, level=level)
|
||||
assert_allclose(pywt.iswtn(coefs, wavelets), data, atol=1e-14)
|
||||
|
||||
# 1-tuple also okay
|
||||
coefs = pywt.swtn(data, wavelets[:1], level=level)
|
||||
assert_allclose(pywt.iswtn(coefs, wavelets[:1]), data, atol=1e-14)
|
||||
|
||||
# length of wavelets doesn't match the length of axes
|
||||
assert_raises(ValueError, pywt.swtn, data, wavelets[:2], level)
|
||||
assert_raises(ValueError, pywt.iswtn, coefs, wavelets[:2])
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', FutureWarning)
|
||||
# swt2/iswt2 also support per-axis wavelets/modes
|
||||
data2 = data[..., 0]
|
||||
coefs2 = pywt.swt2(data2, wavelets[:2], level)
|
||||
assert_allclose(pywt.iswt2(coefs2, wavelets[:2]), data2, atol=1e-14)
|
||||
|
||||
|
||||
def test_error_on_continuous_wavelet():
|
||||
# A ValueError is raised if a Continuous wavelet is selected
|
||||
data = np.ones((16, 16))
|
||||
for dec_func, rec_func in zip([pywt.swt, pywt.swt2, pywt.swtn],
|
||||
[pywt.iswt, pywt.iswt2, pywt.iswtn]):
|
||||
for cwave in ['morl', pywt.DiscreteContinuousWavelet('morl')]:
|
||||
assert_raises(ValueError, dec_func, data, wavelet=cwave,
|
||||
level=3)
|
||||
|
||||
c = dec_func(data, 'db1', level=3)
|
||||
assert_raises(ValueError, rec_func, c, wavelet=cwave)
|
||||
|
||||
|
||||
def test_iswt_mixed_dtypes():
|
||||
# Mixed precision inputs give double precision output
|
||||
x_real = np.arange(16).astype(np.float64)
|
||||
x_complex = x_real + 1j*x_real
|
||||
wav = 'sym2'
|
||||
for dtype1, dtype2 in [(np.float64, np.float32),
|
||||
(np.float32, np.float64),
|
||||
(np.float16, np.float64),
|
||||
(np.complex128, np.complex64),
|
||||
(np.complex64, np.complex128)]:
|
||||
|
||||
if dtype1 in [np.complex64, np.complex128]:
|
||||
x = x_complex
|
||||
output_dtype = np.complex128
|
||||
else:
|
||||
x = x_real
|
||||
output_dtype = np.float64
|
||||
|
||||
coeffs = pywt.swt(x, wav, 2)
|
||||
# different precision for the approximation coefficients
|
||||
coeffs[0] = [coeffs[0][0].astype(dtype1),
|
||||
coeffs[0][1].astype(dtype2)]
|
||||
y = pywt.iswt(coeffs, wav)
|
||||
assert_equal(output_dtype, y.dtype)
|
||||
assert_allclose(y, x, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
||||
def test_iswt2_mixed_dtypes():
|
||||
# Mixed precision inputs give double precision output
|
||||
rstate = np.random.RandomState(0)
|
||||
x_real = rstate.randn(8, 8)
|
||||
x_complex = x_real + 1j*x_real
|
||||
wav = 'sym2'
|
||||
for dtype1, dtype2 in [(np.float64, np.float32),
|
||||
(np.float32, np.float64),
|
||||
(np.float16, np.float64),
|
||||
(np.complex128, np.complex64),
|
||||
(np.complex64, np.complex128)]:
|
||||
|
||||
if dtype1 in [np.complex64, np.complex128]:
|
||||
x = x_complex
|
||||
output_dtype = np.complex128
|
||||
else:
|
||||
x = x_real
|
||||
output_dtype = np.float64
|
||||
|
||||
coeffs = pywt.swt2(x, wav, 2)
|
||||
# different precision for the approximation coefficients
|
||||
coeffs[0] = [coeffs[0][0].astype(dtype1),
|
||||
tuple([c.astype(dtype2) for c in coeffs[0][1]])]
|
||||
y = pywt.iswt2(coeffs, wav)
|
||||
assert_equal(output_dtype, y.dtype)
|
||||
assert_allclose(y, x, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
||||
def test_iswtn_mixed_dtypes():
|
||||
# Mixed precision inputs give double precision output
|
||||
rstate = np.random.RandomState(0)
|
||||
x_real = rstate.randn(8, 8, 8)
|
||||
x_complex = x_real + 1j*x_real
|
||||
wav = 'sym2'
|
||||
for dtype1, dtype2 in [(np.float64, np.float32),
|
||||
(np.float32, np.float64),
|
||||
(np.float16, np.float64),
|
||||
(np.complex128, np.complex64),
|
||||
(np.complex64, np.complex128)]:
|
||||
|
||||
if dtype1 in [np.complex64, np.complex128]:
|
||||
x = x_complex
|
||||
output_dtype = np.complex128
|
||||
else:
|
||||
x = x_real
|
||||
output_dtype = np.float64
|
||||
|
||||
coeffs = pywt.swtn(x, wav, 2)
|
||||
# different precision for the approximation coefficients
|
||||
a = coeffs[0].pop('a' * x.ndim)
|
||||
a = a.astype(dtype1)
|
||||
coeffs[0] = {k: c.astype(dtype2) for k, c in coeffs[0].items()}
|
||||
coeffs[0]['a' * x.ndim] = a
|
||||
y = pywt.iswtn(coeffs, wav)
|
||||
assert_equal(output_dtype, y.dtype)
|
||||
assert_allclose(y, x, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
||||
def test_swt_zero_size_axes():
|
||||
# raise on empty input array
|
||||
assert_raises(ValueError, pywt.swt, [], 'db2')
|
||||
|
||||
# >1D case uses a different code path so check there as well
|
||||
x = np.ones((1, 4))[0:0, :] # 2D with a size zero axis
|
||||
assert_raises(ValueError, pywt.swtn, x, 'db2', level=1, axes=(0,))
|
||||
|
||||
|
||||
def test_swt_variance_and_energy_preservation():
|
||||
"""Verify that the 1D SWT partitions variance among the coefficients."""
|
||||
# When norm is True and the wavelet is orthogonal, the sum of the
|
||||
# variances of the coefficients should equal the variance of the signal.
|
||||
wav = 'db2'
|
||||
rstate = np.random.RandomState(5)
|
||||
x = rstate.randn(256)
|
||||
coeffs = pywt.swt(x, wav, trim_approx=True, norm=True)
|
||||
variances = [np.var(c) for c in coeffs]
|
||||
assert_allclose(np.sum(variances), np.var(x))
|
||||
|
||||
# also verify L2-norm energy preservation property
|
||||
assert_allclose(np.linalg.norm(x),
|
||||
np.linalg.norm(np.concatenate(coeffs)))
|
||||
|
||||
# non-orthogonal wavelet with norm=True raises a warning
|
||||
assert_warns(UserWarning, pywt.swt, x, 'bior2.2', norm=True)
|
||||
|
||||
|
||||
def test_swt2_variance_and_energy_preservation():
|
||||
"""Verify that the 2D SWT partitions variance among the coefficients."""
|
||||
# When norm is True and the wavelet is orthogonal, the sum of the
|
||||
# variances of the coefficients should equal the variance of the signal.
|
||||
wav = 'db2'
|
||||
rstate = np.random.RandomState(5)
|
||||
x = rstate.randn(64, 64)
|
||||
coeffs = pywt.swt2(x, wav, level=4, trim_approx=True, norm=True)
|
||||
coeff_list = [coeffs[0].ravel()]
|
||||
for d in coeffs[1:]:
|
||||
for v in d:
|
||||
coeff_list.append(v.ravel())
|
||||
variances = [np.var(v) for v in coeff_list]
|
||||
assert_allclose(np.sum(variances), np.var(x))
|
||||
|
||||
# also verify L2-norm energy preservation property
|
||||
assert_allclose(np.linalg.norm(x),
|
||||
np.linalg.norm(np.concatenate(coeff_list)))
|
||||
|
||||
# non-orthogonal wavelet with norm=True raises a warning
|
||||
assert_warns(UserWarning, pywt.swt2, x, 'bior2.2', level=4, norm=True)
|
||||
|
||||
|
||||
def test_swtn_variance_and_energy_preservation():
|
||||
"""Verify that the nD SWT partitions variance among the coefficients."""
|
||||
# When norm is True and the wavelet is orthogonal, the sum of the
|
||||
# variances of the coefficients should equal the variance of the signal.
|
||||
wav = 'db2'
|
||||
rstate = np.random.RandomState(5)
|
||||
x = rstate.randn(64, 64)
|
||||
coeffs = pywt.swtn(x, wav, level=4, trim_approx=True, norm=True)
|
||||
coeff_list = [coeffs[0].ravel()]
|
||||
for d in coeffs[1:]:
|
||||
for k, v in d.items():
|
||||
coeff_list.append(v.ravel())
|
||||
variances = [np.var(v) for v in coeff_list]
|
||||
assert_allclose(np.sum(variances), np.var(x))
|
||||
|
||||
# also verify L2-norm energy preservation property
|
||||
assert_allclose(np.linalg.norm(x),
|
||||
np.linalg.norm(np.concatenate(coeff_list)))
|
||||
|
||||
# non-orthogonal wavelet with norm=True raises a warning
|
||||
assert_warns(UserWarning, pywt.swtn, x, 'bior2.2', level=4, norm=True)
|
||||
|
||||
|
||||
def test_swt_ravel_and_unravel():
|
||||
# When trim_approx=True, all swt functions can user pywt.ravel_coeffs
|
||||
for ndim, _swt, _iswt, ravel_type in [
|
||||
(1, pywt.swt, pywt.iswt, 'swt'),
|
||||
(2, pywt.swt2, pywt.iswt2, 'swt2'),
|
||||
(3, pywt.swtn, pywt.iswtn, 'swtn')]:
|
||||
x = np.ones((16, ) * ndim)
|
||||
c = _swt(x, 'sym2', level=3, trim_approx=True)
|
||||
arr, slices, shapes = pywt.ravel_coeffs(c)
|
||||
c = pywt.unravel_coeffs(arr, slices, shapes, output_format=ravel_type)
|
||||
r = _iswt(c, 'sym2')
|
||||
assert_allclose(x, r)
|
169
venv/Lib/site-packages/pywt/tests/test_thresholding.py
Normal file
169
venv/Lib/site-packages/pywt/tests/test_thresholding.py
Normal file
|
@ -0,0 +1,169 @@
|
|||
from __future__ import division, print_function, absolute_import
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose, assert_raises, assert_, assert_equal
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
float_dtypes = [np.float32, np.float64, np.complex64, np.complex128]
|
||||
real_dtypes = [np.float32, np.float64]
|
||||
|
||||
|
||||
def _sign(x):
|
||||
# Matlab-like sign function (numpy uses a different convention).
|
||||
return x / np.abs(x)
|
||||
|
||||
|
||||
def _soft(x, thresh):
|
||||
"""soft thresholding supporting complex values.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This version is not robust to zeros in x.
|
||||
"""
|
||||
return _sign(x) * np.maximum(np.abs(x) - thresh, 0)
|
||||
|
||||
|
||||
def test_threshold():
|
||||
data = np.linspace(1, 4, 7)
|
||||
|
||||
# soft
|
||||
soft_result = [0., 0., 0., 0.5, 1., 1.5, 2.]
|
||||
assert_allclose(pywt.threshold(data, 2, 'soft'),
|
||||
np.array(soft_result), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold(-data, 2, 'soft'),
|
||||
-np.array(soft_result), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'soft'),
|
||||
[[0, 1]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'soft'),
|
||||
[[0, 0]] * 2, rtol=1e-12)
|
||||
|
||||
# soft thresholding complex values
|
||||
assert_allclose(pywt.threshold([[1j, 2j]] * 2, 1, 'soft'),
|
||||
[[0j, 1j]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1+1j, 2+2j]] * 2, 6, 'soft'),
|
||||
[[0, 0]] * 2, rtol=1e-12)
|
||||
complex_data = [[1+2j, 2+2j]]*2
|
||||
for thresh in [1, 2]:
|
||||
assert_allclose(pywt.threshold(complex_data, thresh, 'soft'),
|
||||
_soft(complex_data, thresh), rtol=1e-12)
|
||||
|
||||
# test soft thresholding with non-default substitute argument
|
||||
s = 5
|
||||
assert_allclose(pywt.threshold([[1j, 2]] * 2, 1.5, 'soft', substitute=s),
|
||||
[[s, 0.5]] * 2, rtol=1e-12)
|
||||
|
||||
# soft: no divide by zero warnings when input contains zeros
|
||||
assert_allclose(pywt.threshold(np.zeros(16), 2, 'soft'),
|
||||
np.zeros(16), rtol=1e-12)
|
||||
|
||||
# hard
|
||||
hard_result = [0., 0., 2., 2.5, 3., 3.5, 4.]
|
||||
assert_allclose(pywt.threshold(data, 2, 'hard'),
|
||||
np.array(hard_result), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold(-data, 2, 'hard'),
|
||||
-np.array(hard_result), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'hard'),
|
||||
[[1, 2]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'hard'),
|
||||
[[0, 2]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'hard', substitute=s),
|
||||
[[s, 2]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1+1j, 2+2j]] * 2, 2, 'hard'),
|
||||
[[0, 2+2j]] * 2, rtol=1e-12)
|
||||
|
||||
# greater
|
||||
greater_result = [0., 0., 2., 2.5, 3., 3.5, 4.]
|
||||
assert_allclose(pywt.threshold(data, 2, 'greater'),
|
||||
np.array(greater_result), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'greater'),
|
||||
[[1, 2]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'greater'),
|
||||
[[0, 2]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'greater', substitute=s),
|
||||
[[s, 2]] * 2, rtol=1e-12)
|
||||
# greater doesn't allow complex-valued inputs
|
||||
assert_raises(ValueError, pywt.threshold, [1j, 2j], 2, 'greater')
|
||||
|
||||
# less
|
||||
assert_allclose(pywt.threshold(data, 2, 'less'),
|
||||
np.array([1., 1.5, 2., 0., 0., 0., 0.]), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'less'),
|
||||
[[1, 0]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'less', substitute=s),
|
||||
[[1, s]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'less'),
|
||||
[[1, 2]] * 2, rtol=1e-12)
|
||||
|
||||
# less doesn't allow complex-valued inputs
|
||||
assert_raises(ValueError, pywt.threshold, [1j, 2j], 2, 'less')
|
||||
|
||||
# invalid
|
||||
assert_raises(ValueError, pywt.threshold, data, 2, 'foo')
|
||||
|
||||
|
||||
def test_nonnegative_garotte():
|
||||
thresh = 0.3
|
||||
data_real = np.linspace(-1, 1, 100)
|
||||
for dtype in float_dtypes:
|
||||
if dtype in real_dtypes:
|
||||
data = np.asarray(data_real, dtype=dtype)
|
||||
else:
|
||||
data = np.asarray(data_real + 0.1j, dtype=dtype)
|
||||
d_hard = pywt.threshold(data, thresh, 'hard')
|
||||
d_soft = pywt.threshold(data, thresh, 'soft')
|
||||
d_garotte = pywt.threshold(data, thresh, 'garotte')
|
||||
|
||||
# check dtypes
|
||||
assert_equal(d_hard.dtype, data.dtype)
|
||||
assert_equal(d_soft.dtype, data.dtype)
|
||||
assert_equal(d_garotte.dtype, data.dtype)
|
||||
|
||||
# values < threshold are zero
|
||||
lt = np.where(np.abs(data) < thresh)
|
||||
assert_(np.all(d_garotte[lt] == 0))
|
||||
|
||||
# values > than the threshold are intermediate between soft and hard
|
||||
gt = np.where(np.abs(data) > thresh)
|
||||
gt_abs_garotte = np.abs(d_garotte[gt])
|
||||
assert_(np.all(gt_abs_garotte < np.abs(d_hard[gt])))
|
||||
assert_(np.all(gt_abs_garotte > np.abs(d_soft[gt])))
|
||||
|
||||
|
||||
def test_threshold_firm():
|
||||
thresh = 0.2
|
||||
thresh2 = 3 * thresh
|
||||
data_real = np.linspace(-1, 1, 100)
|
||||
for dtype in float_dtypes:
|
||||
if dtype in real_dtypes:
|
||||
data = np.asarray(data_real, dtype=dtype)
|
||||
else:
|
||||
data = np.asarray(data_real + 0.1j, dtype=dtype)
|
||||
if data.real.dtype == np.float32:
|
||||
rtol = atol = 1e-6
|
||||
else:
|
||||
rtol = atol = 1e-14
|
||||
d_hard = pywt.threshold(data, thresh, 'hard')
|
||||
d_soft = pywt.threshold(data, thresh, 'soft')
|
||||
d_firm = pywt.threshold_firm(data, thresh, thresh2)
|
||||
|
||||
# check dtypes
|
||||
assert_equal(d_hard.dtype, data.dtype)
|
||||
assert_equal(d_soft.dtype, data.dtype)
|
||||
assert_equal(d_firm.dtype, data.dtype)
|
||||
|
||||
# values < threshold are zero
|
||||
lt = np.where(np.abs(data) < thresh)
|
||||
assert_(np.all(d_firm[lt] == 0))
|
||||
|
||||
# values > than the threshold are equal to hard-thresholding
|
||||
gt = np.where(np.abs(data) >= thresh2)
|
||||
assert_allclose(np.abs(d_hard[gt]), np.abs(d_firm[gt]),
|
||||
rtol=rtol, atol=atol)
|
||||
|
||||
# other values are intermediate between soft and hard thresholding
|
||||
mt = np.where(np.logical_and(np.abs(data) > thresh,
|
||||
np.abs(data) < thresh2))
|
||||
mt_abs_firm = np.abs(d_firm[mt])
|
||||
assert_(np.all(mt_abs_firm < np.abs(d_hard[mt])))
|
||||
assert_(np.all(mt_abs_firm > np.abs(d_soft[mt])))
|
266
venv/Lib/site-packages/pywt/tests/test_wavelet.py
Normal file
266
venv/Lib/site-packages/pywt/tests/test_wavelet.py
Normal file
|
@ -0,0 +1,266 @@
|
|||
#!/usr/bin/env python
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose, assert_
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_wavelet_properties():
|
||||
w = pywt.Wavelet('db3')
|
||||
|
||||
# Name
|
||||
assert_(w.name == 'db3')
|
||||
assert_(w.short_family_name == 'db')
|
||||
assert_(w.family_name, 'Daubechies')
|
||||
|
||||
# String representation
|
||||
fields = ('Family name', 'Short name', 'Filters length', 'Orthogonal',
|
||||
'Biorthogonal', 'Symmetry')
|
||||
for field in fields:
|
||||
assert_(field in str(w))
|
||||
|
||||
# Filter coefficients
|
||||
dec_lo = [0.03522629188210, -0.08544127388224, -0.13501102001039,
|
||||
0.45987750211933, 0.80689150931334, 0.33267055295096]
|
||||
dec_hi = [-0.33267055295096, 0.80689150931334, -0.45987750211933,
|
||||
-0.13501102001039, 0.08544127388224, 0.03522629188210]
|
||||
rec_lo = [0.33267055295096, 0.80689150931334, 0.45987750211933,
|
||||
-0.13501102001039, -0.08544127388224, 0.03522629188210]
|
||||
rec_hi = [0.03522629188210, 0.08544127388224, -0.13501102001039,
|
||||
-0.45987750211933, 0.80689150931334, -0.33267055295096]
|
||||
assert_allclose(w.dec_lo, dec_lo)
|
||||
assert_allclose(w.dec_hi, dec_hi)
|
||||
assert_allclose(w.rec_lo, rec_lo)
|
||||
assert_allclose(w.rec_hi, rec_hi)
|
||||
|
||||
assert_(len(w.filter_bank) == 4)
|
||||
|
||||
# Orthogonality
|
||||
assert_(w.orthogonal)
|
||||
assert_(w.biorthogonal)
|
||||
|
||||
# Symmetry
|
||||
assert_(w.symmetry)
|
||||
|
||||
# Vanishing moments
|
||||
assert_(w.vanishing_moments_phi == 0)
|
||||
assert_(w.vanishing_moments_psi == 3)
|
||||
|
||||
|
||||
def test_wavelet_coefficients():
|
||||
families = ('db', 'sym', 'coif', 'bior', 'rbio')
|
||||
wavelets = sum([pywt.wavelist(name) for name in families], [])
|
||||
for wavelet in wavelets:
|
||||
if (pywt.Wavelet(wavelet).orthogonal):
|
||||
check_coefficients_orthogonal(wavelet)
|
||||
elif(pywt.Wavelet(wavelet).biorthogonal):
|
||||
check_coefficients_biorthogonal(wavelet)
|
||||
else:
|
||||
check_coefficients(wavelet)
|
||||
|
||||
|
||||
def check_coefficients_orthogonal(wavelet):
|
||||
|
||||
epsilon = 5e-11
|
||||
level = 5
|
||||
w = pywt.Wavelet(wavelet)
|
||||
phi, psi, x = w.wavefun(level=level)
|
||||
|
||||
# Lowpass filter coefficients sum to sqrt2
|
||||
res = np.sum(w.dec_lo)-np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# sum even coef = sum odd coef = 1 / sqrt(2)
|
||||
res = np.sum(w.dec_lo[::2])-1./np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
|
||||
res = np.sum(w.dec_lo[1::2])-1./np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# Highpass filter coefficients sum to zero
|
||||
res = np.sum(w.dec_hi)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# Scaling function integrates to unity
|
||||
|
||||
res = np.sum(phi) - 2**level
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# Wavelet function is orthogonal to the scaling function at the same scale
|
||||
res = np.sum(phi*psi)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# The lowpass and highpass filter coefficients are orthogonal
|
||||
res = np.sum(np.array(w.dec_lo)*np.array(w.dec_hi))
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
|
||||
|
||||
def check_coefficients_biorthogonal(wavelet):
|
||||
|
||||
epsilon = 5e-11
|
||||
level = 5
|
||||
w = pywt.Wavelet(wavelet)
|
||||
phi_d, psi_d, phi_r, psi_r, x = w.wavefun(level=level)
|
||||
|
||||
# Lowpass filter coefficients sum to sqrt2
|
||||
res = np.sum(w.dec_lo)-np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# sum even coef = sum odd coef = 1 / sqrt(2)
|
||||
res = np.sum(w.dec_lo[::2])-1./np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
res = np.sum(w.dec_lo[1::2])-1./np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# Highpass filter coefficients sum to zero
|
||||
res = np.sum(w.dec_hi)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# Scaling function integrates to unity
|
||||
res = np.sum(phi_d) - 2**level
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
res = np.sum(phi_r) - 2**level
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
|
||||
|
||||
def check_coefficients(wavelet):
|
||||
epsilon = 5e-11
|
||||
level = 10
|
||||
w = pywt.Wavelet(wavelet)
|
||||
# Lowpass filter coefficients sum to sqrt2
|
||||
res = np.sum(w.dec_lo)-np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# sum even coef = sum odd coef = 1 / sqrt(2)
|
||||
res = np.sum(w.dec_lo[::2])-1./np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
|
||||
res = np.sum(w.dec_lo[1::2])-1./np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# Highpass filter coefficients sum to zero
|
||||
res = np.sum(w.dec_hi)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
|
||||
|
||||
class _CustomHaarFilterBank(object):
|
||||
@property
|
||||
def filter_bank(self):
|
||||
val = np.sqrt(2) / 2
|
||||
return ([val]*2, [-val, val], [val]*2, [val, -val])
|
||||
|
||||
|
||||
def test_custom_wavelet():
|
||||
haar_custom1 = pywt.Wavelet('Custom Haar Wavelet',
|
||||
filter_bank=_CustomHaarFilterBank())
|
||||
haar_custom1.orthogonal = True
|
||||
haar_custom1.biorthogonal = True
|
||||
|
||||
val = np.sqrt(2) / 2
|
||||
filter_bank = ([val]*2, [-val, val], [val]*2, [val, -val])
|
||||
haar_custom2 = pywt.Wavelet('Custom Haar Wavelet',
|
||||
filter_bank=filter_bank)
|
||||
|
||||
# check expected default wavelet properties
|
||||
assert_(~haar_custom2.orthogonal)
|
||||
assert_(~haar_custom2.biorthogonal)
|
||||
assert_(haar_custom2.symmetry == 'unknown')
|
||||
assert_(haar_custom2.family_name == '')
|
||||
assert_(haar_custom2.short_family_name == '')
|
||||
assert_(haar_custom2.vanishing_moments_phi == 0)
|
||||
assert_(haar_custom2.vanishing_moments_psi == 0)
|
||||
|
||||
# Some properties can be set by the user
|
||||
haar_custom2.orthogonal = True
|
||||
haar_custom2.biorthogonal = True
|
||||
|
||||
|
||||
def test_wavefun_sym3():
|
||||
w = pywt.Wavelet('sym3')
|
||||
# sym3 is an orthogonal wavelet, so 3 outputs from wavefun
|
||||
phi, psi, x = w.wavefun(level=3)
|
||||
assert_(phi.size == 41)
|
||||
assert_(psi.size == 41)
|
||||
assert_(x.size == 41)
|
||||
|
||||
assert_allclose(x, np.linspace(0, 5, num=x.size))
|
||||
phi_expect = np.array([0.00000000e+00, 1.04132926e-01, 2.52574126e-01,
|
||||
3.96525521e-01, 5.70356539e-01, 7.18934305e-01,
|
||||
8.70293448e-01, 1.05363620e+00, 1.24921722e+00,
|
||||
1.15296888e+00, 9.41669683e-01, 7.55875887e-01,
|
||||
4.96118565e-01, 3.28293151e-01, 1.67624969e-01,
|
||||
-7.33690312e-02, -3.35452855e-01, -3.31221131e-01,
|
||||
-2.32061503e-01, -1.66854239e-01, -4.34091324e-02,
|
||||
-2.86152390e-02, -3.63563035e-02, 2.06034491e-02,
|
||||
8.30280254e-02, 7.17779073e-02, 3.85914311e-02,
|
||||
1.47527100e-02, -2.31896077e-02, -1.86122172e-02,
|
||||
-1.56211329e-03, -8.70615088e-04, 3.20760857e-03,
|
||||
2.34142153e-03, -7.73737194e-04, -2.99879354e-04,
|
||||
1.23636238e-04, 0.00000000e+00, 0.00000000e+00,
|
||||
0.00000000e+00, 0.00000000e+00])
|
||||
|
||||
psi_expect = np.array([0.00000000e+00, 1.10265752e-02, 2.67449277e-02,
|
||||
4.19878574e-02, 6.03947231e-02, 7.61275365e-02,
|
||||
9.21548684e-02, 1.11568926e-01, 1.32278887e-01,
|
||||
6.45829680e-02, -3.97635130e-02, -1.38929884e-01,
|
||||
-2.62428322e-01, -3.62246804e-01, -4.62843343e-01,
|
||||
-5.89607507e-01, -7.25363076e-01, -3.36865858e-01,
|
||||
2.67715108e-01, 8.40176767e-01, 1.55574430e+00,
|
||||
1.18688954e+00, 4.20276324e-01, -1.51697311e-01,
|
||||
-9.42076108e-01, -7.93172332e-01, -3.26343710e-01,
|
||||
-1.24552779e-01, 2.12909254e-01, 1.75770320e-01,
|
||||
1.47523075e-02, 8.22192707e-03, -3.02920592e-02,
|
||||
-2.21119497e-02, 7.30703025e-03, 2.83200488e-03,
|
||||
-1.16759765e-03, 0.00000000e+00, 0.00000000e+00,
|
||||
0.00000000e+00, 0.00000000e+00])
|
||||
|
||||
assert_allclose(phi, phi_expect)
|
||||
assert_allclose(psi, psi_expect)
|
||||
|
||||
|
||||
def test_wavefun_bior13():
|
||||
w = pywt.Wavelet('bior1.3')
|
||||
# bior1.3 is not an orthogonal wavelet, so 5 outputs from wavefun
|
||||
phi_d, psi_d, phi_r, psi_r, x = w.wavefun(level=3)
|
||||
for arr in [phi_d, psi_d, phi_r, psi_r]:
|
||||
assert_(arr.size == 40)
|
||||
|
||||
phi_d_expect = np.array([0., -0.00195313, 0.00195313, 0.01757813,
|
||||
0.01367188, 0.00390625, -0.03515625, -0.12890625,
|
||||
-0.15234375, -0.125, -0.09375, -0.0625, 0.03125,
|
||||
0.15234375, 0.37890625, 0.78515625, 0.99609375,
|
||||
1.08203125, 1.13671875, 1.13671875, 1.08203125,
|
||||
0.99609375, 0.78515625, 0.37890625, 0.15234375,
|
||||
0.03125, -0.0625, -0.09375, -0.125, -0.15234375,
|
||||
-0.12890625, -0.03515625, 0.00390625, 0.01367188,
|
||||
0.01757813, 0.00195313, -0.00195313, 0., 0., 0.])
|
||||
phi_r_expect = np.zeros(x.size, dtype=np.float)
|
||||
phi_r_expect[15:23] = 1
|
||||
|
||||
psi_d_expect = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0.015625, -0.015625, -0.140625, -0.109375,
|
||||
-0.03125, 0.28125, 1.03125, 1.21875, 1.125, 0.625,
|
||||
-0.625, -1.125, -1.21875, -1.03125, -0.28125,
|
||||
0.03125, 0.109375, 0.140625, 0.015625, -0.015625,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
|
||||
|
||||
psi_r_expect = np.zeros(x.size, dtype=np.float)
|
||||
psi_r_expect[7:15] = -0.125
|
||||
psi_r_expect[15:19] = 1
|
||||
psi_r_expect[19:23] = -1
|
||||
psi_r_expect[23:31] = 0.125
|
||||
|
||||
assert_allclose(x, np.linspace(0, 5, x.size, endpoint=False))
|
||||
assert_allclose(phi_d, phi_d_expect, rtol=1e-5, atol=1e-9)
|
||||
assert_allclose(phi_r, phi_r_expect, rtol=1e-10, atol=1e-12)
|
||||
assert_allclose(psi_d, psi_d_expect, rtol=1e-10, atol=1e-12)
|
||||
assert_allclose(psi_r, psi_r_expect, rtol=1e-10, atol=1e-12)
|
197
venv/Lib/site-packages/pywt/tests/test_wp.py
Normal file
197
venv/Lib/site-packages/pywt/tests/test_wp.py
Normal file
|
@ -0,0 +1,197 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import (assert_allclose, assert_, assert_raises,
|
||||
assert_equal)
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_wavelet_packet_structure():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_(wp.data == [1, 2, 3, 4, 5, 6, 7, 8])
|
||||
assert_(wp.path == '')
|
||||
assert_(wp.level == 0)
|
||||
assert_(wp['ad'].maxlevel == 3)
|
||||
|
||||
|
||||
def test_traversing_wp_tree():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_(wp.maxlevel == 3)
|
||||
|
||||
# First level
|
||||
assert_allclose(wp['a'].data, np.array([2.12132034356, 4.949747468306,
|
||||
7.778174593052, 10.606601717798]),
|
||||
rtol=1e-12)
|
||||
|
||||
# Second level
|
||||
assert_allclose(wp['aa'].data, np.array([5., 13.]), rtol=1e-12)
|
||||
|
||||
# Third level
|
||||
assert_allclose(wp['aaa'].data, np.array([12.727922061358]), rtol=1e-12)
|
||||
|
||||
|
||||
def test_acess_path():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_(wp['a'].path == 'a')
|
||||
assert_(wp['aa'].path == 'aa')
|
||||
assert_(wp['aaa'].path == 'aaa')
|
||||
|
||||
# Maximum level reached:
|
||||
assert_raises(IndexError, lambda: wp['aaaa'].path)
|
||||
|
||||
# Wrong path
|
||||
assert_raises(ValueError, lambda: wp['ac'].path)
|
||||
|
||||
|
||||
def test_access_node_atributes():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_allclose(wp['ad'].data, np.array([-2., -2.]), rtol=1e-12)
|
||||
assert_(wp['ad'].path == 'ad')
|
||||
assert_(wp['ad'].node_name == 'd')
|
||||
assert_(wp['ad'].parent.path == 'a')
|
||||
assert_(wp['ad'].level == 2)
|
||||
assert_(wp['ad'].maxlevel == 3)
|
||||
assert_(wp['ad'].mode == 'symmetric')
|
||||
|
||||
|
||||
def test_collecting_nodes():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
# All nodes in natural order
|
||||
assert_([node.path for node in wp.get_level(3, 'natural')] ==
|
||||
['aaa', 'aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd'])
|
||||
|
||||
# and in frequency order.
|
||||
assert_([node.path for node in wp.get_level(3, 'freq')] ==
|
||||
['aaa', 'aad', 'add', 'ada', 'dda', 'ddd', 'dad', 'daa'])
|
||||
|
||||
|
||||
def test_reconstructing_data():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
# Create another Wavelet Packet and feed it with some data.
|
||||
new_wp = pywt.WaveletPacket(data=None, wavelet='db1', mode='symmetric')
|
||||
new_wp['aa'] = wp['aa'].data
|
||||
new_wp['ad'] = [-2., -2.]
|
||||
|
||||
# For convenience, :attr:`Node.data` gets automatically extracted
|
||||
# from the :class:`Node` object:
|
||||
new_wp['d'] = wp['d']
|
||||
|
||||
# Reconstruct data from aa, ad, and d packets.
|
||||
assert_allclose(new_wp.reconstruct(update=False), x, rtol=1e-12)
|
||||
|
||||
# The node's :attr:`~Node.data` will not be updated
|
||||
assert_(new_wp.data is None)
|
||||
|
||||
# When `update` is True:
|
||||
assert_allclose(new_wp.reconstruct(update=True), x, rtol=1e-12)
|
||||
assert_allclose(new_wp.data, np.arange(1, 9), rtol=1e-12)
|
||||
|
||||
assert_([n.path for n in new_wp.get_leaf_nodes(False)] ==
|
||||
['aa', 'ad', 'd'])
|
||||
assert_([n.path for n in new_wp.get_leaf_nodes(True)] ==
|
||||
['aaa', 'aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd'])
|
||||
|
||||
|
||||
def test_removing_nodes():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
wp.get_level(2)
|
||||
|
||||
dataleafs = [n.data for n in wp.get_leaf_nodes(False)]
|
||||
expected = np.array([[5., 13.], [-2, -2], [-1, -1], [0, 0]])
|
||||
|
||||
for i in range(4):
|
||||
assert_allclose(dataleafs[i], expected[i, :], atol=1e-12)
|
||||
|
||||
node = wp['ad']
|
||||
del(wp['ad'])
|
||||
dataleafs = [n.data for n in wp.get_leaf_nodes(False)]
|
||||
expected = np.array([[5., 13.], [-1, -1], [0, 0]])
|
||||
|
||||
for i in range(3):
|
||||
assert_allclose(dataleafs[i], expected[i, :], atol=1e-12)
|
||||
|
||||
wp.reconstruct()
|
||||
# The reconstruction is:
|
||||
assert_allclose(wp.reconstruct(),
|
||||
np.array([2., 3., 2., 3., 6., 7., 6., 7.]), rtol=1e-12)
|
||||
|
||||
# Restore the data
|
||||
wp['ad'].data = node.data
|
||||
|
||||
dataleafs = [n.data for n in wp.get_leaf_nodes(False)]
|
||||
expected = np.array([[5., 13.], [-2, -2], [-1, -1], [0, 0]])
|
||||
|
||||
for i in range(4):
|
||||
assert_allclose(dataleafs[i], expected[i, :], atol=1e-12)
|
||||
|
||||
assert_allclose(wp.reconstruct(), np.arange(1, 9), rtol=1e-12)
|
||||
|
||||
|
||||
def test_wavelet_packet_dtypes():
|
||||
N = 32
|
||||
for dtype in [np.float32, np.float64, np.complex64, np.complex128]:
|
||||
x = np.random.randn(N).astype(dtype)
|
||||
if np.iscomplexobj(x):
|
||||
x = x + 1j*np.random.randn(N).astype(x.real.dtype)
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
# no unnecessary copy made
|
||||
assert_(wp.data is x)
|
||||
|
||||
# assiging to a node should not change supported dtypes
|
||||
wp['d'] = wp['d'].data
|
||||
assert_equal(wp['d'].data.dtype, x.dtype)
|
||||
|
||||
# full decomposition
|
||||
wp.get_level(wp.maxlevel)
|
||||
|
||||
# reconstruction from coefficients should preserve dtype
|
||||
r = wp.reconstruct(False)
|
||||
assert_equal(r.dtype, x.dtype)
|
||||
assert_allclose(r, x, atol=1e-5, rtol=1e-5)
|
||||
|
||||
# first element of the tuple is the input dtype
|
||||
# second element of the tuple is the transform dtype
|
||||
dtype_pairs = [(np.uint8, np.float64),
|
||||
(np.intp, np.float64), ]
|
||||
if hasattr(np, "complex256"):
|
||||
dtype_pairs += [(np.complex256, np.complex128), ]
|
||||
if hasattr(np, "half"):
|
||||
dtype_pairs += [(np.half, np.float32), ]
|
||||
for (dtype, transform_dtype) in dtype_pairs:
|
||||
x = np.arange(N, dtype=dtype)
|
||||
wp = pywt.WaveletPacket(x, wavelet='db1', mode='symmetric')
|
||||
|
||||
# no unnecessary copy made of top-level data
|
||||
assert_(wp.data is x)
|
||||
|
||||
# full decomposition
|
||||
wp.get_level(wp.maxlevel)
|
||||
|
||||
# reconstructed data will have modified dtype
|
||||
r = wp.reconstruct(False)
|
||||
assert_equal(r.dtype, transform_dtype)
|
||||
assert_allclose(r, x.astype(transform_dtype), atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
def test_db3_roundtrip():
|
||||
original = np.arange(512)
|
||||
wp = pywt.WaveletPacket(data=original, wavelet='db3', mode='smooth',
|
||||
maxlevel=3)
|
||||
r = wp.reconstruct()
|
||||
assert_allclose(original, r, atol=1e-12, rtol=1e-12)
|
177
venv/Lib/site-packages/pywt/tests/test_wp2d.py
Normal file
177
venv/Lib/site-packages/pywt/tests/test_wp2d.py
Normal file
|
@ -0,0 +1,177 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import (assert_allclose, assert_, assert_raises,
|
||||
assert_equal)
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_traversing_tree_2d():
|
||||
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
|
||||
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_(np.all(wp.data == x))
|
||||
assert_(wp.path == '')
|
||||
assert_(wp.level == 0)
|
||||
assert_(wp.maxlevel == 3)
|
||||
|
||||
assert_allclose(wp['a'].data, np.array([[3., 7., 11., 15.]] * 4),
|
||||
rtol=1e-12)
|
||||
assert_allclose(wp['h'].data, np.zeros((4, 4)), rtol=1e-12, atol=1e-14)
|
||||
assert_allclose(wp['v'].data, -np.ones((4, 4)), rtol=1e-12, atol=1e-14)
|
||||
assert_allclose(wp['d'].data, np.zeros((4, 4)), rtol=1e-12, atol=1e-14)
|
||||
|
||||
assert_allclose(wp['aa'].data, np.array([[10., 26.]] * 2), rtol=1e-12)
|
||||
|
||||
assert_(wp['a']['a'].data is wp['aa'].data)
|
||||
assert_allclose(wp['aaa'].data, np.array([[36.]]), rtol=1e-12)
|
||||
|
||||
assert_raises(IndexError, lambda: wp['aaaa'])
|
||||
assert_raises(ValueError, lambda: wp['f'])
|
||||
|
||||
|
||||
def test_accessing_node_atributes_2d():
|
||||
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
|
||||
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_allclose(wp['av'].data, np.zeros((2, 2)) - 4, rtol=1e-12)
|
||||
assert_(wp['av'].path == 'av')
|
||||
assert_(wp['av'].node_name == 'v')
|
||||
assert_(wp['av'].parent.path == 'a')
|
||||
|
||||
assert_allclose(wp['av'].parent.data, np.array([[3., 7., 11., 15.]] * 4),
|
||||
rtol=1e-12)
|
||||
assert_(wp['av'].level == 2)
|
||||
assert_(wp['av'].maxlevel == 3)
|
||||
assert_(wp['av'].mode == 'symmetric')
|
||||
|
||||
|
||||
def test_collecting_nodes_2d():
|
||||
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
|
||||
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_(len(wp.get_level(0)) == 1)
|
||||
assert_(wp.get_level(0)[0].path == '')
|
||||
|
||||
# First level
|
||||
assert_(len(wp.get_level(1)) == 4)
|
||||
assert_([node.path for node in wp.get_level(1)] == ['a', 'h', 'v', 'd'])
|
||||
|
||||
# Second level
|
||||
assert_(len(wp.get_level(2)) == 16)
|
||||
paths = [node.path for node in wp.get_level(2)]
|
||||
expected_paths = ['aa', 'ah', 'av', 'ad', 'ha', 'hh', 'hv', 'hd', 'va',
|
||||
'vh', 'vv', 'vd', 'da', 'dh', 'dv', 'dd']
|
||||
assert_(paths == expected_paths)
|
||||
|
||||
# Third level.
|
||||
assert_(len(wp.get_level(3)) == 64)
|
||||
paths = [node.path for node in wp.get_level(3)]
|
||||
expected_paths = ['aaa', 'aah', 'aav', 'aad', 'aha', 'ahh', 'ahv', 'ahd',
|
||||
'ava', 'avh', 'avv', 'avd', 'ada', 'adh', 'adv', 'add',
|
||||
'haa', 'hah', 'hav', 'had', 'hha', 'hhh', 'hhv', 'hhd',
|
||||
'hva', 'hvh', 'hvv', 'hvd', 'hda', 'hdh', 'hdv', 'hdd',
|
||||
'vaa', 'vah', 'vav', 'vad', 'vha', 'vhh', 'vhv', 'vhd',
|
||||
'vva', 'vvh', 'vvv', 'vvd', 'vda', 'vdh', 'vdv', 'vdd',
|
||||
'daa', 'dah', 'dav', 'dad', 'dha', 'dhh', 'dhv', 'dhd',
|
||||
'dva', 'dvh', 'dvv', 'dvd', 'dda', 'ddh', 'ddv', 'ddd']
|
||||
|
||||
assert_(paths == expected_paths)
|
||||
|
||||
|
||||
def test_data_reconstruction_2d():
|
||||
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
|
||||
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
new_wp = pywt.WaveletPacket2D(data=None, wavelet='db1', mode='symmetric')
|
||||
new_wp['vh'] = wp['vh'].data
|
||||
new_wp['vv'] = wp['vh'].data
|
||||
new_wp['vd'] = np.zeros((2, 2), dtype=np.float64)
|
||||
new_wp['a'] = [[3.0, 7.0, 11.0, 15.0]] * 4
|
||||
new_wp['d'] = np.zeros((4, 4), dtype=np.float64)
|
||||
new_wp['h'] = wp['h'] # all zeros
|
||||
|
||||
assert_allclose(new_wp.reconstruct(update=False),
|
||||
np.array([[1.5, 1.5, 3.5, 3.5, 5.5, 5.5, 7.5, 7.5]] * 8),
|
||||
rtol=1e-12)
|
||||
assert_allclose(wp['va'].data, np.zeros((2, 2)) - 2, rtol=1e-12)
|
||||
|
||||
new_wp['va'] = wp['va'].data
|
||||
assert_allclose(new_wp.reconstruct(update=False), x, rtol=1e-12)
|
||||
|
||||
|
||||
def test_data_reconstruction_delete_nodes_2d():
|
||||
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
|
||||
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
new_wp = pywt.WaveletPacket2D(data=None, wavelet='db1', mode='symmetric')
|
||||
new_wp['vh'] = wp['vh'].data
|
||||
new_wp['vv'] = wp['vh'].data
|
||||
new_wp['vd'] = np.zeros((2, 2), dtype=np.float64)
|
||||
new_wp['a'] = [[3.0, 7.0, 11.0, 15.0]] * 4
|
||||
new_wp['d'] = np.zeros((4, 4), dtype=np.float64)
|
||||
new_wp['h'] = wp['h'] # all zeros
|
||||
|
||||
assert_allclose(new_wp.reconstruct(update=False),
|
||||
np.array([[1.5, 1.5, 3.5, 3.5, 5.5, 5.5, 7.5, 7.5]] * 8),
|
||||
rtol=1e-12)
|
||||
|
||||
new_wp['va'] = wp['va'].data
|
||||
assert_allclose(new_wp.reconstruct(update=False), x, rtol=1e-12)
|
||||
|
||||
del(new_wp['va'])
|
||||
new_wp['va'] = wp['va'].data
|
||||
assert_(new_wp.data is None)
|
||||
|
||||
assert_allclose(new_wp.reconstruct(update=True), x, rtol=1e-12)
|
||||
assert_allclose(new_wp.data, x, rtol=1e-12)
|
||||
|
||||
# TODO: decompose=True
|
||||
|
||||
|
||||
def test_lazy_evaluation_2D():
|
||||
# Note: internal implementation detail not to be relied on. Testing for
|
||||
# now for backwards compatibility, but this test may be broken in needed.
|
||||
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8)
|
||||
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_(wp.a is None)
|
||||
assert_allclose(wp['a'].data, np.array([[3., 7., 11., 15.]] * 4),
|
||||
rtol=1e-12)
|
||||
assert_allclose(wp.a.data, np.array([[3., 7., 11., 15.]] * 4), rtol=1e-12)
|
||||
assert_allclose(wp.d.data, np.zeros((4, 4)), rtol=1e-12, atol=1e-12)
|
||||
|
||||
|
||||
def test_wavelet_packet_dtypes():
|
||||
shape = (16, 16)
|
||||
for dtype in [np.float32, np.float64, np.complex64, np.complex128]:
|
||||
x = np.random.randn(*shape).astype(dtype)
|
||||
if np.iscomplexobj(x):
|
||||
x = x + 1j*np.random.randn(*shape).astype(x.real.dtype)
|
||||
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
|
||||
# no unnecessary copy made
|
||||
assert_(wp.data is x)
|
||||
|
||||
# assiging to a node should not change supported dtypes
|
||||
wp['d'] = wp['d'].data
|
||||
assert_equal(wp['d'].data.dtype, x.dtype)
|
||||
|
||||
# full decomposition
|
||||
wp.get_level(wp.maxlevel)
|
||||
|
||||
# reconstruction from coefficients should preserve dtype
|
||||
r = wp.reconstruct(False)
|
||||
assert_equal(r.dtype, x.dtype)
|
||||
assert_allclose(r, x, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
def test_2d_roundtrip():
|
||||
# test case corresponding to PyWavelets issue 447
|
||||
original = pywt.data.camera()
|
||||
wp = pywt.WaveletPacket2D(data=original, wavelet='db3', mode='smooth',
|
||||
maxlevel=3)
|
||||
r = wp.reconstruct()
|
||||
assert_allclose(original, r, atol=1e-12, rtol=1e-12)
|
Loading…
Add table
Add a link
Reference in a new issue