#!/usr/bin/env python from __future__ import division, print_function, absolute_import import warnings from copy import deepcopy from itertools import combinations, permutations import numpy as np import pytest from numpy.testing import (assert_allclose, assert_, assert_equal, assert_raises, assert_array_equal, assert_warns) import pywt from pywt._extensions._swt import swt_axis # Check that float32 and complex64 are preserved. Other real types get # converted to float64. dtypes_in = [np.int8, np.float16, np.float32, np.float64, np.complex64, np.complex128] dtypes_out = [np.float64, np.float32, np.float32, np.float64, np.complex64, np.complex128] # tolerances used in accuracy comparisons tol_single = 1e-6 tol_double = 1e-13 #### # 1d multilevel swt tests #### def test_swt_decomposition(): x = [3, 7, 1, 3, -2, 6, 4, 6] db1 = pywt.Wavelet('db1') atol = tol_double (cA3, cD3), (cA2, cD2), (cA1, cD1) = pywt.swt(x, db1, level=3) expected_cA1 = [7.07106781, 5.65685425, 2.82842712, 0.70710678, 2.82842712, 7.07106781, 7.07106781, 6.36396103] assert_allclose(cA1, expected_cA1, rtol=1e-8, atol=atol) expected_cD1 = [-2.82842712, 4.24264069, -1.41421356, 3.53553391, -5.65685425, 1.41421356, -1.41421356, 2.12132034] assert_allclose(cD1, expected_cD1, rtol=1e-8, atol=atol) expected_cA2 = [7, 4.5, 4, 5.5, 7, 9.5, 10, 8.5] assert_allclose(cA2, expected_cA2, rtol=tol_double, atol=atol) expected_cD2 = [3, 3.5, 0, -4.5, -3, 0.5, 0, 0.5] assert_allclose(cD2, expected_cD2, rtol=tol_double, atol=atol) expected_cA3 = [9.89949494, ] * 8 assert_allclose(cA3, expected_cA3, rtol=1e-8, atol=atol) expected_cD3 = [0.00000000, -3.53553391, -4.24264069, -2.12132034, 0.00000000, 3.53553391, 4.24264069, 2.12132034] assert_allclose(cD3, expected_cD3, rtol=1e-8, atol=atol) # level=1, start_level=1 decomposition should match level=2 res = pywt.swt(cA1, db1, level=1, start_level=1) cA2, cD2 = res[0] assert_allclose(cA2, expected_cA2, rtol=tol_double, atol=atol) assert_allclose(cD2, expected_cD2, rtol=tol_double, atol=atol) coeffs = pywt.swt(x, db1) assert_(len(coeffs) == 3) assert_(pywt.swt_max_level(len(x)), 3) def test_swt_max_level(): # odd sized signal will warn about no levels of decomposition possible assert_warns(UserWarning, pywt.swt_max_level, 11) with warnings.catch_warnings(): warnings.simplefilter('ignore', UserWarning) assert_equal(pywt.swt_max_level(11), 0) # no warnings when >= 1 level of decomposition possible assert_equal(pywt.swt_max_level(2), 1) # divisible by 2**1 assert_equal(pywt.swt_max_level(4*3), 2) # divisible by 2**2 assert_equal(pywt.swt_max_level(16), 4) # divisible by 2**4 assert_equal(pywt.swt_max_level(16*3), 4) # divisible by 2**4 def test_swt_axis(): x = [3, 7, 1, 3, -2, 6, 4, 6] db1 = pywt.Wavelet('db1') (cA2, cD2), (cA1, cD1) = pywt.swt(x, db1, level=2) # test cases use 2D arrays based on tiling x along an axis and then # calling swt along the other axis. for order in ['C', 'F']: # test SWT of 2D data along default axis (-1) x_2d = np.asarray(x).reshape((1, -1)) x_2d = np.concatenate((x_2d, )*5, axis=0) if order == 'C': x_2d = np.ascontiguousarray(x_2d) elif order == 'F': x_2d = np.asfortranarray(x_2d) (cA2_2d, cD2_2d), (cA1_2d, cD1_2d) = pywt.swt(x_2d, db1, level=2) for c in [cA2_2d, cD2_2d, cA1_2d, cD1_2d]: assert_(c.shape == x_2d.shape) # each row should match the 1D result for row in cA1_2d: assert_array_equal(row, cA1) for row in cA2_2d: assert_array_equal(row, cA2) for row in cD1_2d: assert_array_equal(row, cD1) for row in cD2_2d: assert_array_equal(row, cD2) # test SWT of 2D data along other axis (0) x_2d = np.asarray(x).reshape((-1, 1)) x_2d = np.concatenate((x_2d, )*5, axis=1) if order == 'C': x_2d = np.ascontiguousarray(x_2d) elif order == 'F': x_2d = np.asfortranarray(x_2d) (cA2_2d, cD2_2d), (cA1_2d, cD1_2d) = pywt.swt(x_2d, db1, level=2, axis=0) for c in [cA2_2d, cD2_2d, cA1_2d, cD1_2d]: assert_(c.shape == x_2d.shape) # each column should match the 1D result for row in cA1_2d.transpose((1, 0)): assert_array_equal(row, cA1) for row in cA2_2d.transpose((1, 0)): assert_array_equal(row, cA2) for row in cD1_2d.transpose((1, 0)): assert_array_equal(row, cD1) for row in cD2_2d.transpose((1, 0)): assert_array_equal(row, cD2) # axis too large assert_raises(ValueError, pywt.swt, x, db1, level=2, axis=5) def test_swt_iswt_integration(): # This function performs a round-trip swt/iswt transform test on # all available types of wavelets in PyWavelets - except the # 'dmey' wavelet. The latter has been excluded because it does not # produce very precise results. This is likely due to the fact # that the 'dmey' wavelet is a discrete approximation of a # continuous wavelet. All wavelets are tested up to 3 levels. The # test validates neither swt or iswt as such, but it does ensure # that they are each other's inverse. max_level = 3 wavelets = pywt.wavelist(kind='discrete') if 'dmey' in wavelets: # The 'dmey' wavelet seems to be a bit special - disregard it for now wavelets.remove('dmey') for current_wavelet_str in wavelets: current_wavelet = pywt.Wavelet(current_wavelet_str) input_length_power = int(np.ceil(np.log2(max( current_wavelet.dec_len, current_wavelet.rec_len)))) input_length = 2**(input_length_power + max_level - 1) X = np.arange(input_length) for norm in [True, False]: if norm and not current_wavelet.orthogonal: # non-orthogonal wavelets to avoid warnings when norm=True continue for trim_approx in [True, False]: coeffs = pywt.swt(X, current_wavelet, max_level, trim_approx=trim_approx, norm=norm) Y = pywt.iswt(coeffs, current_wavelet, norm=norm) assert_allclose(Y, X, rtol=1e-5, atol=1e-7) def test_swt_dtypes(): wavelet = pywt.Wavelet('haar') for dt_in, dt_out in zip(dtypes_in, dtypes_out): errmsg = "wrong dtype returned for {0} input".format(dt_in) # swt x = np.ones(8, dtype=dt_in) (cA2, cD2), (cA1, cD1) = pywt.swt(x, wavelet, level=2) assert_(cA2.dtype == cD2.dtype == cA1.dtype == cD1.dtype == dt_out, "swt: " + errmsg) # swt2 x = np.ones((8, 8), dtype=dt_in) cA, (cH, cV, cD) = pywt.swt2(x, wavelet, level=1)[0] assert_(cA.dtype == cH.dtype == cV.dtype == cD.dtype == dt_out, "swt2: " + errmsg) def test_swt_roundtrip_dtypes(): # verify perfect reconstruction for all dtypes rstate = np.random.RandomState(5) wavelet = pywt.Wavelet('haar') for dt_in, dt_out in zip(dtypes_in, dtypes_out): # swt, iswt x = rstate.standard_normal((8, )).astype(dt_in) c = pywt.swt(x, wavelet, level=2) xr = pywt.iswt(c, wavelet) assert_allclose(x, xr, rtol=1e-6, atol=1e-7) # swt2, iswt2 x = rstate.standard_normal((8, 8)).astype(dt_in) c = pywt.swt2(x, wavelet, level=2) xr = pywt.iswt2(c, wavelet) assert_allclose(x, xr, rtol=1e-6, atol=1e-7) def test_swt_default_level_by_axis(): # make sure default number of levels matches the max level along the axis wav = 'db2' x = np.ones((2**3, 2**4, 2**5)) for axis in (0, 1, 2): sdec = pywt.swt(x, wav, level=None, start_level=0, axis=axis) assert_equal(len(sdec), pywt.swt_max_level(x.shape[axis])) def test_swt2_ndim_error(): x = np.ones(8) with warnings.catch_warnings(): warnings.simplefilter('ignore', FutureWarning) assert_raises(ValueError, pywt.swt2, x, 'haar', level=1) @pytest.mark.slow def test_swt2_iswt2_integration(wavelets=None): # This function performs a round-trip swt2/iswt2 transform test on # all available types of wavelets in PyWavelets - except the # 'dmey' wavelet. The latter has been excluded because it does not # produce very precise results. This is likely due to the fact # that the 'dmey' wavelet is a discrete approximation of a # continuous wavelet. All wavelets are tested up to 3 levels. The # test validates neither swt2 or iswt2 as such, but it does ensure # that they are each other's inverse. max_level = 3 if wavelets is None: wavelets = pywt.wavelist(kind='discrete') if 'dmey' in wavelets: # The 'dmey' wavelet is a special case - disregard it for now wavelets.remove('dmey') for current_wavelet_str in wavelets: current_wavelet = pywt.Wavelet(current_wavelet_str) input_length_power = int(np.ceil(np.log2(max( current_wavelet.dec_len, current_wavelet.rec_len)))) input_length = 2**(input_length_power + max_level - 1) X = np.arange(input_length**2).reshape(input_length, input_length) for norm in [True, False]: if norm and not current_wavelet.orthogonal: # non-orthogonal wavelets to avoid warnings when norm=True continue for trim_approx in [True, False]: coeffs = pywt.swt2(X, current_wavelet, max_level, trim_approx=trim_approx, norm=norm) Y = pywt.iswt2(coeffs, current_wavelet, norm=norm) assert_allclose(Y, X, rtol=1e-5, atol=1e-5) def test_swt2_iswt2_quick(): test_swt2_iswt2_integration(wavelets=['db1', ]) def test_swt2_iswt2_non_square(wavelets=None): for nrows in [8, 16, 48]: X = np.arange(nrows*32).reshape(nrows, 32) current_wavelet = 'db1' with warnings.catch_warnings(): warnings.simplefilter('ignore', FutureWarning) coeffs = pywt.swt2(X, current_wavelet, level=2) Y = pywt.iswt2(coeffs, current_wavelet) assert_allclose(Y, X, rtol=tol_single, atol=tol_single) def test_swt2_axes(): atol = 1e-14 current_wavelet = pywt.Wavelet('db2') input_length_power = int(np.ceil(np.log2(max( current_wavelet.dec_len, current_wavelet.rec_len)))) input_length = 2**(input_length_power) X = np.arange(input_length**2).reshape(input_length, input_length) (cA1, (cH1, cV1, cD1)) = pywt.swt2(X, current_wavelet, level=1)[0] # opposite order (cA2, (cH2, cV2, cD2)) = pywt.swt2(X, current_wavelet, level=1, axes=(1, 0))[0] assert_allclose(cA1, cA2, atol=atol) assert_allclose(cH1, cV2, atol=atol) assert_allclose(cV1, cH2, atol=atol) assert_allclose(cD1, cD2, atol=atol) # duplicate axes not allowed assert_raises(ValueError, pywt.swt2, X, current_wavelet, 1, axes=(0, 0)) # too few axes assert_raises(ValueError, pywt.swt2, X, current_wavelet, 1, axes=(0, )) def test_iswt2_2d_only(): # iswt2 is not currently compatible with data that is not 2D x_3d = np.ones((4, 4, 4)) c = pywt.swt2(x_3d, 'haar', level=1) assert_raises(ValueError, pywt.iswt2, c, 'haar') def test_swtn_axes(): atol = 1e-14 current_wavelet = pywt.Wavelet('db2') input_length_power = int(np.ceil(np.log2(max( current_wavelet.dec_len, current_wavelet.rec_len)))) input_length = 2**(input_length_power) X = np.arange(input_length**2).reshape(input_length, input_length) coeffs = pywt.swtn(X, current_wavelet, level=1, axes=None)[0] # opposite order coeffs2 = pywt.swtn(X, current_wavelet, level=1, axes=(1, 0))[0] assert_allclose(coeffs['aa'], coeffs2['aa'], atol=atol) assert_allclose(coeffs['ad'], coeffs2['da'], atol=atol) assert_allclose(coeffs['da'], coeffs2['ad'], atol=atol) assert_allclose(coeffs['dd'], coeffs2['dd'], atol=atol) # 0-level transform empty = pywt.swtn(X, current_wavelet, level=0) assert_equal(empty, []) # duplicate axes not allowed assert_raises(ValueError, pywt.swtn, X, current_wavelet, 1, axes=(0, 0)) # data.ndim = 0 assert_raises(ValueError, pywt.swtn, np.asarray([]), current_wavelet, 1) # start_level too large assert_raises(ValueError, pywt.swtn, X, current_wavelet, level=1, start_level=2) # level < 1 in swt_axis call assert_raises(ValueError, swt_axis, X, current_wavelet, level=0, start_level=0) # odd-sized data not allowed assert_raises(ValueError, swt_axis, X[:-1, :], current_wavelet, level=0, start_level=0, axis=0) @pytest.mark.slow def test_swtn_iswtn_integration(wavelets=None): # This function performs a round-trip swtn/iswtn transform for various # possible combinations of: # 1.) 1 out of 2 axes of a 2D array # 2.) 2 out of 3 axes of a 3D array # # To keep test time down, only wavelets of length <= 8 are run. # # This test does not validate swtn or iswtn individually, but only # confirms that iswtn yields an (almost) perfect reconstruction of swtn. max_level = 3 if wavelets is None: wavelets = pywt.wavelist(kind='discrete') if 'dmey' in wavelets: # The 'dmey' wavelet is a special case - disregard it for now wavelets.remove('dmey') for ndim_transform in range(1, 3): ndim = ndim_transform + 1 for axes in combinations(range(ndim), ndim_transform): for current_wavelet_str in wavelets: wav = pywt.Wavelet(current_wavelet_str) if wav.dec_len > 8: continue # avoid excessive test duration input_length_power = int(np.ceil(np.log2(max( wav.dec_len, wav.rec_len)))) N = 2**(input_length_power + max_level - 1) X = np.arange(N**ndim).reshape((N, )*ndim) for norm in [True, False]: if norm and not wav.orthogonal: # non-orthogonal wavelets to avoid warnings continue for trim_approx in [True, False]: coeffs = pywt.swtn(X, wav, max_level, axes=axes, trim_approx=trim_approx, norm=norm) coeffs_copy = deepcopy(coeffs) Y = pywt.iswtn(coeffs, wav, axes=axes, norm=norm) assert_allclose(Y, X, rtol=1e-5, atol=1e-5) # verify the inverse transform didn't modify any coeffs for c, c2 in zip(coeffs, coeffs_copy): for k, v in c.items(): assert_array_equal(c2[k], v) def test_swtn_iswtn_quick(): test_swtn_iswtn_integration(wavelets=['db1', ]) def test_iswtn_errors(): x = np.arange(8**3).reshape(8, 8, 8) max_level = 2 axes = (0, 1) w = pywt.Wavelet('db1') coeffs = pywt.swtn(x, w, max_level, axes=axes) # more axes than dimensions transformed assert_raises(ValueError, pywt.iswtn, coeffs, w, axes=(0, 1, 2)) # duplicate axes not allowed assert_raises(ValueError, pywt.iswtn, coeffs, w, axes=(0, 0)) # mismatched coefficient size coeffs[0]['da'] = coeffs[0]['da'][:-1, :] assert_raises(RuntimeError, pywt.iswtn, coeffs, w, axes=axes) def test_swtn_iswtn_unique_shape_per_axis(): # test case for gh-460 _shape = (1, 48, 32) # unique shape per axis wav = 'sym2' max_level = 3 rstate = np.random.RandomState(0) for shape in permutations(_shape): # transform only along the non-singleton axes axes = [ax for ax, s in enumerate(shape) if s != 1] x = rstate.standard_normal(shape) c = pywt.swtn(x, wav, max_level, axes=axes) r = pywt.iswtn(c, wav, axes=axes) assert_allclose(x, r, rtol=1e-10, atol=1e-10) def test_per_axis_wavelets(): # tests seperate wavelet for each axis. rstate = np.random.RandomState(1234) data = rstate.randn(16, 16, 16) level = 3 # wavelet can be a string or wavelet object wavelets = (pywt.Wavelet('haar'), 'sym2', 'db4') coefs = pywt.swtn(data, wavelets, level=level) assert_allclose(pywt.iswtn(coefs, wavelets), data, atol=1e-14) # 1-tuple also okay coefs = pywt.swtn(data, wavelets[:1], level=level) assert_allclose(pywt.iswtn(coefs, wavelets[:1]), data, atol=1e-14) # length of wavelets doesn't match the length of axes assert_raises(ValueError, pywt.swtn, data, wavelets[:2], level) assert_raises(ValueError, pywt.iswtn, coefs, wavelets[:2]) with warnings.catch_warnings(): warnings.simplefilter('ignore', FutureWarning) # swt2/iswt2 also support per-axis wavelets/modes data2 = data[..., 0] coefs2 = pywt.swt2(data2, wavelets[:2], level) assert_allclose(pywt.iswt2(coefs2, wavelets[:2]), data2, atol=1e-14) def test_error_on_continuous_wavelet(): # A ValueError is raised if a Continuous wavelet is selected data = np.ones((16, 16)) for dec_func, rec_func in zip([pywt.swt, pywt.swt2, pywt.swtn], [pywt.iswt, pywt.iswt2, pywt.iswtn]): for cwave in ['morl', pywt.DiscreteContinuousWavelet('morl')]: assert_raises(ValueError, dec_func, data, wavelet=cwave, level=3) c = dec_func(data, 'db1', level=3) assert_raises(ValueError, rec_func, c, wavelet=cwave) def test_iswt_mixed_dtypes(): # Mixed precision inputs give double precision output x_real = np.arange(16).astype(np.float64) x_complex = x_real + 1j*x_real wav = 'sym2' for dtype1, dtype2 in [(np.float64, np.float32), (np.float32, np.float64), (np.float16, np.float64), (np.complex128, np.complex64), (np.complex64, np.complex128)]: if dtype1 in [np.complex64, np.complex128]: x = x_complex output_dtype = np.complex128 else: x = x_real output_dtype = np.float64 coeffs = pywt.swt(x, wav, 2) # different precision for the approximation coefficients coeffs[0] = [coeffs[0][0].astype(dtype1), coeffs[0][1].astype(dtype2)] y = pywt.iswt(coeffs, wav) assert_equal(output_dtype, y.dtype) assert_allclose(y, x, rtol=1e-3, atol=1e-3) def test_iswt2_mixed_dtypes(): # Mixed precision inputs give double precision output rstate = np.random.RandomState(0) x_real = rstate.randn(8, 8) x_complex = x_real + 1j*x_real wav = 'sym2' for dtype1, dtype2 in [(np.float64, np.float32), (np.float32, np.float64), (np.float16, np.float64), (np.complex128, np.complex64), (np.complex64, np.complex128)]: if dtype1 in [np.complex64, np.complex128]: x = x_complex output_dtype = np.complex128 else: x = x_real output_dtype = np.float64 coeffs = pywt.swt2(x, wav, 2) # different precision for the approximation coefficients coeffs[0] = [coeffs[0][0].astype(dtype1), tuple([c.astype(dtype2) for c in coeffs[0][1]])] y = pywt.iswt2(coeffs, wav) assert_equal(output_dtype, y.dtype) assert_allclose(y, x, rtol=1e-3, atol=1e-3) def test_iswtn_mixed_dtypes(): # Mixed precision inputs give double precision output rstate = np.random.RandomState(0) x_real = rstate.randn(8, 8, 8) x_complex = x_real + 1j*x_real wav = 'sym2' for dtype1, dtype2 in [(np.float64, np.float32), (np.float32, np.float64), (np.float16, np.float64), (np.complex128, np.complex64), (np.complex64, np.complex128)]: if dtype1 in [np.complex64, np.complex128]: x = x_complex output_dtype = np.complex128 else: x = x_real output_dtype = np.float64 coeffs = pywt.swtn(x, wav, 2) # different precision for the approximation coefficients a = coeffs[0].pop('a' * x.ndim) a = a.astype(dtype1) coeffs[0] = {k: c.astype(dtype2) for k, c in coeffs[0].items()} coeffs[0]['a' * x.ndim] = a y = pywt.iswtn(coeffs, wav) assert_equal(output_dtype, y.dtype) assert_allclose(y, x, rtol=1e-3, atol=1e-3) def test_swt_zero_size_axes(): # raise on empty input array assert_raises(ValueError, pywt.swt, [], 'db2') # >1D case uses a different code path so check there as well x = np.ones((1, 4))[0:0, :] # 2D with a size zero axis assert_raises(ValueError, pywt.swtn, x, 'db2', level=1, axes=(0,)) def test_swt_variance_and_energy_preservation(): """Verify that the 1D SWT partitions variance among the coefficients.""" # When norm is True and the wavelet is orthogonal, the sum of the # variances of the coefficients should equal the variance of the signal. wav = 'db2' rstate = np.random.RandomState(5) x = rstate.randn(256) coeffs = pywt.swt(x, wav, trim_approx=True, norm=True) variances = [np.var(c) for c in coeffs] assert_allclose(np.sum(variances), np.var(x)) # also verify L2-norm energy preservation property assert_allclose(np.linalg.norm(x), np.linalg.norm(np.concatenate(coeffs))) # non-orthogonal wavelet with norm=True raises a warning assert_warns(UserWarning, pywt.swt, x, 'bior2.2', norm=True) def test_swt2_variance_and_energy_preservation(): """Verify that the 2D SWT partitions variance among the coefficients.""" # When norm is True and the wavelet is orthogonal, the sum of the # variances of the coefficients should equal the variance of the signal. wav = 'db2' rstate = np.random.RandomState(5) x = rstate.randn(64, 64) coeffs = pywt.swt2(x, wav, level=4, trim_approx=True, norm=True) coeff_list = [coeffs[0].ravel()] for d in coeffs[1:]: for v in d: coeff_list.append(v.ravel()) variances = [np.var(v) for v in coeff_list] assert_allclose(np.sum(variances), np.var(x)) # also verify L2-norm energy preservation property assert_allclose(np.linalg.norm(x), np.linalg.norm(np.concatenate(coeff_list))) # non-orthogonal wavelet with norm=True raises a warning assert_warns(UserWarning, pywt.swt2, x, 'bior2.2', level=4, norm=True) def test_swtn_variance_and_energy_preservation(): """Verify that the nD SWT partitions variance among the coefficients.""" # When norm is True and the wavelet is orthogonal, the sum of the # variances of the coefficients should equal the variance of the signal. wav = 'db2' rstate = np.random.RandomState(5) x = rstate.randn(64, 64) coeffs = pywt.swtn(x, wav, level=4, trim_approx=True, norm=True) coeff_list = [coeffs[0].ravel()] for d in coeffs[1:]: for k, v in d.items(): coeff_list.append(v.ravel()) variances = [np.var(v) for v in coeff_list] assert_allclose(np.sum(variances), np.var(x)) # also verify L2-norm energy preservation property assert_allclose(np.linalg.norm(x), np.linalg.norm(np.concatenate(coeff_list))) # non-orthogonal wavelet with norm=True raises a warning assert_warns(UserWarning, pywt.swtn, x, 'bior2.2', level=4, norm=True) def test_swt_ravel_and_unravel(): # When trim_approx=True, all swt functions can user pywt.ravel_coeffs for ndim, _swt, _iswt, ravel_type in [ (1, pywt.swt, pywt.iswt, 'swt'), (2, pywt.swt2, pywt.iswt2, 'swt2'), (3, pywt.swtn, pywt.iswtn, 'swtn')]: x = np.ones((16, ) * ndim) c = _swt(x, 'sym2', level=3, trim_approx=True) arr, slices, shapes = pywt.ravel_coeffs(c) c = pywt.unravel_coeffs(arr, slices, shapes, output_format=ravel_type) r = _iswt(c, 'sym2') assert_allclose(x, r)