#!/usr/bin/env python from __future__ import division, print_function, absolute_import import warnings from itertools import combinations import numpy as np import pytest from numpy.testing import (assert_almost_equal, assert_allclose, assert_, assert_equal, assert_raises, assert_raises_regex, assert_array_equal, assert_warns) import pywt # Check that float32, float64, complex64, complex128 are preserved. # Other real types get converted to float64. # complex256 gets converted to complex128 dtypes_in = [np.int8, np.float16, np.float32, np.float64, np.complex64, np.complex128] dtypes_out = [np.float64, np.float32, np.float32, np.float64, np.complex64, np.complex128] # tolerances used in accuracy comparisons tol_single = 1e-6 tol_double = 1e-13 dtypes_and_tolerances = [(np.float16, tol_single), (np.float32, tol_single), (np.float64, tol_double), (np.int8, tol_double), (np.complex64, tol_single), (np.complex128, tol_double)] # test complex256 as well if it is available try: dtypes_in += [np.complex256, ] dtypes_out += [np.complex128, ] dtypes_and_tolerances += [(np.complex256, tol_double), ] except AttributeError: pass # determine which wavelets to test wavelist = pywt.wavelist() if 'dmey' in wavelist: # accuracy is very low for dmey, so omit it wavelist.remove('dmey') # removing wavelets with dwt_possible == False del_list = [] for wavelet in wavelist: with warnings.catch_warnings(): warnings.simplefilter('ignore', FutureWarning) if not isinstance(pywt.DiscreteContinuousWavelet(wavelet), pywt.Wavelet): del_list.append(wavelet) for del_ind in del_list: wavelist.remove(del_ind) #### # 1d multilevel dwt tests #### def test_wavedec(): x = [3, 7, 1, 1, -2, 5, 4, 6] db1 = pywt.Wavelet('db1') cA3, cD3, cD2, cD1 = pywt.wavedec(x, db1) assert_almost_equal(cA3, [8.83883476]) assert_almost_equal(cD3, [-0.35355339]) assert_allclose(cD2, [4., -3.5]) assert_allclose(cD1, [-2.82842712, 0, -4.94974747, -1.41421356]) assert_(pywt.dwt_max_level(len(x), db1) == 3) def test_waverec_invalid_inputs(): # input must be list or tuple assert_raises(ValueError, pywt.waverec, np.ones(8), 'haar') # input list cannot be empty assert_raises(ValueError, pywt.waverec, [], 'haar') # 'array_to_coeffs must specify 'output_format' to perform waverec x = [3, 7, 1, 1, -2, 5, 4, 6] coeffs = pywt.wavedec(x, 'db1') arr, coeff_slices = pywt.coeffs_to_array(coeffs) coeffs_from_arr = pywt.array_to_coeffs(arr, coeff_slices) message = "Unexpected detail coefficient type" assert_raises_regex(ValueError, message, pywt.waverec, coeffs_from_arr, 'haar') def test_waverec_accuracies(): rstate = np.random.RandomState(1234) x0 = rstate.randn(8) for dt, tol in dtypes_and_tolerances: x = x0.astype(dt) if np.iscomplexobj(x): x += 1j*rstate.randn(8).astype(x.real.dtype) coeffs = pywt.wavedec(x, 'db1') assert_allclose(pywt.waverec(coeffs, 'db1'), x, atol=tol, rtol=tol) def test_waverec_none(): x = [3, 7, 1, 1, -2, 5, 4, 6] coeffs = pywt.wavedec(x, 'db1') # set some coefficients to None coeffs[2] = None coeffs[0] = None assert_(pywt.waverec(coeffs, 'db1').size, len(x)) def test_waverec_odd_length(): x = [3, 7, 1, 1, -2, 5] coeffs = pywt.wavedec(x, 'db1') assert_allclose(pywt.waverec(coeffs, 'db1'), x, rtol=1e-12) def test_waverec_complex(): x = np.array([3, 7, 1, 1, -2, 5, 4, 6]) x = x + 1j coeffs = pywt.wavedec(x, 'db1') assert_allclose(pywt.waverec(coeffs, 'db1'), x, rtol=1e-12) def test_multilevel_dtypes_1d(): # only checks that the result is of the expected type wavelet = pywt.Wavelet('haar') for dt_in, dt_out in zip(dtypes_in, dtypes_out): # wavedec, waverec x = np.ones(8, dtype=dt_in) errmsg = "wrong dtype returned for {0} input".format(dt_in) coeffs = pywt.wavedec(x, wavelet, level=2) for c in coeffs: assert_(c.dtype == dt_out, "wavedec: " + errmsg) x_roundtrip = pywt.waverec(coeffs, wavelet) assert_(x_roundtrip.dtype == dt_out, "waverec: " + errmsg) def test_waverec_all_wavelets_modes(): # test 2D case using all wavelets and modes rstate = np.random.RandomState(1234) r = rstate.randn(80) for wavelet in wavelist: for mode in pywt.Modes.modes: coeffs = pywt.wavedec(r, wavelet, mode=mode) assert_allclose(pywt.waverec(coeffs, wavelet, mode=mode), r, rtol=tol_single, atol=tol_single) #### # 2d multilevel dwt function tests #### def test_waverec2_accuracies(): rstate = np.random.RandomState(1234) x0 = rstate.randn(4, 4) for dt, tol in dtypes_and_tolerances: x = x0.astype(dt) if np.iscomplexobj(x): x += 1j*rstate.randn(4, 4).astype(x.real.dtype) coeffs = pywt.wavedec2(x, 'db1') assert_(len(coeffs) == 3) assert_allclose(pywt.waverec2(coeffs, 'db1'), x, atol=tol, rtol=tol) def test_multilevel_dtypes_2d(): wavelet = pywt.Wavelet('haar') for dt_in, dt_out in zip(dtypes_in, dtypes_out): # wavedec2, waverec2 x = np.ones((8, 8), dtype=dt_in) errmsg = "wrong dtype returned for {0} input".format(dt_in) cA, coeffsD2, coeffsD1 = pywt.wavedec2(x, wavelet, level=2) assert_(cA.dtype == dt_out, "wavedec2: " + errmsg) for c in coeffsD1: assert_(c.dtype == dt_out, "wavedec2: " + errmsg) for c in coeffsD2: assert_(c.dtype == dt_out, "wavedec2: " + errmsg) x_roundtrip = pywt.waverec2([cA, coeffsD2, coeffsD1], wavelet) assert_(x_roundtrip.dtype == dt_out, "waverec2: " + errmsg) @pytest.mark.slow def test_waverec2_all_wavelets_modes(): # test 2D case using all wavelets and modes rstate = np.random.RandomState(1234) r = rstate.randn(80, 96) for wavelet in wavelist: for mode in pywt.Modes.modes: coeffs = pywt.wavedec2(r, wavelet, mode=mode) assert_allclose(pywt.waverec2(coeffs, wavelet, mode=mode), r, rtol=tol_single, atol=tol_single) def test_wavedec2_complex(): data = np.ones((4, 4)) + 1j coeffs = pywt.wavedec2(data, 'db1') assert_(len(coeffs) == 3) assert_allclose(pywt.waverec2(coeffs, 'db1'), data, rtol=1e-12) def test_wavedec2_invalid_inputs(): # input array has too few dimensions data = np.ones(4) assert_raises(ValueError, pywt.wavedec2, data, 'haar') def test_waverec2_invalid_inputs(): # input must be list or tuple assert_raises(ValueError, pywt.waverec2, np.ones((8, 8)), 'haar') # input list cannot be empty assert_raises(ValueError, pywt.waverec2, [], 'haar') # coefficients from a difference decomposition used as input for dec_func in [pywt.wavedec, pywt.wavedecn]: coeffs = dec_func(np.ones((8, 8)), 'haar') message = "Unexpected detail coefficient type" assert_raises_regex(ValueError, message, pywt.waverec2, coeffs, 'haar') def test_waverec2_coeff_shape_mismatch(): x = np.ones((8, 8)) coeffs = pywt.wavedec2(x, 'db1') # introduce a shape mismatch in the coefficients coeffs = list(coeffs) coeffs[1] = list(coeffs[1]) coeffs[1][1] = np.zeros((16, 1)) assert_raises(ValueError, pywt.waverec2, coeffs, 'db1') def test_waverec2_odd_length(): x = np.ones((10, 6)) coeffs = pywt.wavedec2(x, 'db1') assert_allclose(pywt.waverec2(coeffs, 'db1'), x, rtol=1e-12) def test_waverec2_none_coeffs(): x = np.arange(24).reshape(6, 4) coeffs = pywt.wavedec2(x, 'db1') coeffs[1] = (None, None, None) assert_(x.shape == pywt.waverec2(coeffs, 'db1').shape) #### # nd multilevel dwt function tests #### def test_waverecn(): rstate = np.random.RandomState(1234) # test 1D through 4D cases for nd in range(1, 5): x = rstate.randn(*(4, )*nd) coeffs = pywt.wavedecn(x, 'db1') assert_(len(coeffs) == 3) assert_allclose(pywt.waverecn(coeffs, 'db1'), x, rtol=tol_double) def test_waverecn_empty_coeff(): coeffs = [np.ones((2, 2, 2)), {}, {}] assert_equal(pywt.waverecn(coeffs, 'db1').shape, (8, 8, 8)) assert_equal(pywt.waverecn(coeffs, 'db1').shape, (8, 8, 8)) coeffs = [np.ones((2, 2, 2)), {}, {'daa': np.ones((4, 4, 4))}] coeffs = [np.ones((2, 2, 2)), {}, {}, {'daa': np.ones((8, 8, 8))}] assert_equal(pywt.waverecn(coeffs, 'db1').shape, (16, 16, 16)) def test_waverecn_invalid_coeffs(): # approximation coeffs as None and no valid detail oeffs coeffs = [None, {}] assert_raises(ValueError, pywt.waverecn, coeffs, 'db1') # use of None for a coefficient value coeffs = [np.ones((2, 2, 2)), {}, {'daa': None}, ] assert_raises(ValueError, pywt.waverecn, coeffs, 'db1') # invalid key names in coefficient list coeffs = [np.ones((4, 4, 4)), {'daa': np.ones((4, 4, 4)), 'foo': np.ones((4, 4, 4))}] assert_raises(ValueError, pywt.waverecn, coeffs, 'db1') # mismatched key name lengths coeffs = [np.ones((4, 4, 4)), {'daa': np.ones((4, 4, 4)), 'da': np.ones((4, 4, 4))}] assert_raises(ValueError, pywt.waverecn, coeffs, 'db1') # key name lengths don't match the array dimensions coeffs = [[[[1.0]]], {'ad': [[[0.0]]], 'da': [[[0.0]]], 'dd': [[[0.0]]]}] assert_raises(ValueError, pywt.waverecn, coeffs, 'db1') # input list cannot be empty assert_raises(ValueError, pywt.waverecn, [], 'haar') def test_waverecn_invalid_inputs(): # coefficients from a difference decomposition used as input for dec_func in [pywt.wavedec, pywt.wavedec2]: coeffs = dec_func(np.ones((8, 8)), 'haar') message = "Unexpected detail coefficient type" assert_raises_regex(ValueError, message, pywt.waverecn, coeffs, 'haar') def test_waverecn_lists(): # support coefficient arrays specified as lists instead of arrays coeffs = [[[1.0]], {'ad': [[0.0]], 'da': [[0.0]], 'dd': [[0.0]]}] assert_equal(pywt.waverecn(coeffs, 'db1').shape, (2, 2)) def test_waverecn_invalid_coeffs2(): # shape mismatch should raise an error coeffs = [np.ones((4, 4, 4)), {'ada': np.ones((4, 4))}] assert_raises(ValueError, pywt.waverecn, coeffs, 'db1') def test_wavedecn_invalid_inputs(): # input array has too few dimensions data = np.array(0) assert_raises(ValueError, pywt.wavedecn, data, 'haar') # invalid number of levels data = np.ones(16) assert_raises(ValueError, pywt.wavedecn, data, 'haar', level=-1) def test_wavedecn_many_levels(): # perfect reconstruction even when level > pywt.dwt_max_level data = np.arange(64).reshape(8, 8) tol = 1e-12 dec_funcs = [pywt.wavedec, pywt.wavedec2, pywt.wavedecn] rec_funcs = [pywt.waverec, pywt.waverec2, pywt.waverecn] with warnings.catch_warnings(): warnings.simplefilter('ignore', UserWarning) for dec_func, rec_func in zip(dec_funcs, rec_funcs): for mode in ['periodization', 'symmetric']: coeffs = dec_func(data, 'haar', mode=mode, level=20) r = rec_func(coeffs, 'haar', mode=mode) assert_allclose(data, r, atol=tol, rtol=tol) def test_waverecn_accuracies(): # testing 3D only here rstate = np.random.RandomState(1234) x0 = rstate.randn(4, 4, 4) for dt, tol in dtypes_and_tolerances: x = x0.astype(dt) if np.iscomplexobj(x): x += 1j*rstate.randn(4, 4, 4).astype(x.real.dtype) coeffs = pywt.wavedecn(x.astype(dt), 'db1') assert_allclose(pywt.waverecn(coeffs, 'db1'), x, atol=tol, rtol=tol) def test_multilevel_dtypes_nd(): wavelet = pywt.Wavelet('haar') for dt_in, dt_out in zip(dtypes_in, dtypes_out): # wavedecn, waverecn x = np.ones((8, 8), dtype=dt_in) errmsg = "wrong dtype returned for {0} input".format(dt_in) cA, coeffsD2, coeffsD1 = pywt.wavedecn(x, wavelet, level=2) assert_(cA.dtype == dt_out, "wavedecn: " + errmsg) for key, c in coeffsD1.items(): assert_(c.dtype == dt_out, "wavedecn: " + errmsg) for key, c in coeffsD2.items(): assert_(c.dtype == dt_out, "wavedecn: " + errmsg) x_roundtrip = pywt.waverecn([cA, coeffsD2, coeffsD1], wavelet) assert_(x_roundtrip.dtype == dt_out, "waverecn: " + errmsg) def test_wavedecn_complex(): data = np.ones((4, 4, 4)) + 1j coeffs = pywt.wavedecn(data, 'db1') assert_allclose(pywt.waverecn(coeffs, 'db1'), data, rtol=1e-12) def test_waverecn_dtypes(): x = np.ones((4, 4, 4)) for dt, tol in dtypes_and_tolerances: coeffs = pywt.wavedecn(x.astype(dt), 'db1') assert_allclose(pywt.waverecn(coeffs, 'db1'), x, atol=tol, rtol=tol) @pytest.mark.slow def test_waverecn_all_wavelets_modes(): # test 2D case using all wavelets and modes rstate = np.random.RandomState(1234) r = rstate.randn(80, 96) for wavelet in wavelist: for mode in pywt.Modes.modes: coeffs = pywt.wavedecn(r, wavelet, mode=mode) assert_allclose(pywt.waverecn(coeffs, wavelet, mode=mode), r, rtol=tol_single, atol=tol_single) def test_coeffs_to_array(): # single element list returns the first element a_coeffs = [np.arange(8).reshape(2, 4), ] arr, arr_slices = pywt.coeffs_to_array(a_coeffs) assert_allclose(arr, a_coeffs[0]) assert_allclose(arr, arr[arr_slices[0]]) assert_raises(ValueError, pywt.coeffs_to_array, []) # invalid second element: array as in wavedec, but not 1D assert_raises(ValueError, pywt.coeffs_to_array, [a_coeffs[0], ] * 2) # invalid second element: tuple as in wavedec2, but not a 3-tuple assert_raises(ValueError, pywt.coeffs_to_array, [a_coeffs[0], (a_coeffs[0], )]) # coefficients as None is not supported assert_raises(ValueError, pywt.coeffs_to_array, [None, ]) assert_raises(ValueError, pywt.coeffs_to_array, [a_coeffs, (None, None, None)]) # invalid type for second coefficient list element assert_raises(ValueError, pywt.coeffs_to_array, [a_coeffs, None]) # use an invalid key name in the coef dictionary coeffs = [np.array([0]), dict(d=np.array([0]), c=np.array([0]))] assert_raises(ValueError, pywt.coeffs_to_array, coeffs) def test_wavedecn_coeff_reshape_even(): # verify round trip is correct: # wavedecn - >coeffs_to_array-> array_to_coeffs -> waverecn # This is done for wavedec{1, 2, n} rng = np.random.RandomState(1234) params = {'wavedec': {'d': 1, 'dec': pywt.wavedec, 'rec': pywt.waverec}, 'wavedec2': {'d': 2, 'dec': pywt.wavedec2, 'rec': pywt.waverec2}, 'wavedecn': {'d': 3, 'dec': pywt.wavedecn, 'rec': pywt.waverecn}} N = 28 for f in params: x1 = rng.randn(*([N] * params[f]['d'])) for mode in pywt.Modes.modes: for wave in wavelist: w = pywt.Wavelet(wave) maxlevel = pywt.dwt_max_level(np.min(x1.shape), w.dec_len) if maxlevel == 0: continue coeffs = params[f]['dec'](x1, w, mode=mode) coeff_arr, coeff_slices = pywt.coeffs_to_array(coeffs) coeffs2 = pywt.array_to_coeffs(coeff_arr, coeff_slices, output_format=f) x1r = params[f]['rec'](coeffs2, w, mode=mode) assert_allclose(x1, x1r, rtol=1e-4, atol=1e-4) def test_wavedecn_coeff_reshape_axes_subset(): # verify round trip is correct when only a subset of axes are transformed: # wavedecn - >coeffs_to_array-> array_to_coeffs -> waverecn # This is done for wavedec{1, 2, n} rng = np.random.RandomState(1234) mode = 'symmetric' w = pywt.Wavelet('db2') N = 16 ndim = 3 for axes in [(-1, ), (0, ), (1, ), (0, 1), (1, 2), (0, 2), None]: x1 = rng.randn(*([N] * ndim)) coeffs = pywt.wavedecn(x1, w, mode=mode, axes=axes) coeff_arr, coeff_slices = pywt.coeffs_to_array(coeffs, axes=axes) if axes is not None: # if axes is not None, it must be provided to coeffs_to_array assert_raises(ValueError, pywt.coeffs_to_array, coeffs) # mismatched axes size assert_raises(ValueError, pywt.coeffs_to_array, coeffs, axes=(0, 1, 2, 3)) assert_raises(ValueError, pywt.coeffs_to_array, coeffs, axes=()) coeffs2 = pywt.array_to_coeffs(coeff_arr, coeff_slices) x1r = pywt.waverecn(coeffs2, w, mode=mode, axes=axes) assert_allclose(x1, x1r, rtol=1e-4, atol=1e-4) def test_coeffs_to_array_padding(): rng = np.random.RandomState(1234) x1 = rng.randn(32, 32) mode = 'symmetric' coeffs = pywt.wavedecn(x1, 'db2', mode=mode) # padding=None raises a ValueError when tight packing is not possible assert_raises(ValueError, pywt.coeffs_to_array, coeffs, padding=None) # set padded values to nan coeff_arr, coeff_slices = pywt.coeffs_to_array(coeffs, padding=np.nan) npad = np.sum(np.isnan(coeff_arr)) assert_(npad > 0) # pad with zeros coeff_arr, coeff_slices = pywt.coeffs_to_array(coeffs, padding=0) assert_(np.sum(np.isnan(coeff_arr)) == 0) assert_(np.sum(coeff_arr == 0) == npad) # Haar case with N as a power of 2 can be tightly packed coeffs_haar = pywt.wavedecn(x1, 'haar', mode=mode) coeff_arr, coeff_slices = pywt.coeffs_to_array(coeffs_haar, padding=None) # shape of coeff_arr will match in this case, but not in general assert_equal(coeff_arr.shape, x1.shape) def test_waverecn_coeff_reshape_odd(): # verify round trip is correct: # wavedecn - >coeffs_to_array-> array_to_coeffs -> waverecn rng = np.random.RandomState(1234) x1 = rng.randn(35, 33) for mode in pywt.Modes.modes: for wave in ['haar', ]: w = pywt.Wavelet(wave) maxlevel = pywt.dwt_max_level(np.min(x1.shape), w.dec_len) if maxlevel == 0: continue coeffs = pywt.wavedecn(x1, w, mode=mode) coeff_arr, coeff_slices = pywt.coeffs_to_array(coeffs) coeffs2 = pywt.array_to_coeffs(coeff_arr, coeff_slices) x1r = pywt.waverecn(coeffs2, w, mode=mode) # truncate reconstructed values to original shape x1r = x1r[tuple([slice(s) for s in x1.shape])] assert_allclose(x1, x1r, rtol=1e-4, atol=1e-4) def test_array_to_coeffs_invalid_inputs(): coeffs = pywt.wavedecn(np.ones(2), 'haar') arr, arr_slices = pywt.coeffs_to_array(coeffs) # empty list of array slices assert_raises(ValueError, pywt.array_to_coeffs, arr, []) # invalid format name assert_raises(ValueError, pywt.array_to_coeffs, arr, arr_slices, 'foo') def test_wavedecn_coeff_ravel(): # verify round trip is correct: # wavedecn - >ravel_coeffs-> unravel_coeffs -> waverecn # This is done for wavedec{1, 2, n} rng = np.random.RandomState(1234) params = {'wavedec': {'d': 1, 'dec': pywt.wavedec, 'rec': pywt.waverec}, 'wavedec2': {'d': 2, 'dec': pywt.wavedec2, 'rec': pywt.waverec2}, 'wavedecn': {'d': 3, 'dec': pywt.wavedecn, 'rec': pywt.waverecn}} N = 12 for f in params: x1 = rng.randn(*([N] * params[f]['d'])) for mode in pywt.Modes.modes: for wave in wavelist: w = pywt.Wavelet(wave) maxlevel = pywt.dwt_max_level(np.min(x1.shape), w.dec_len) if maxlevel == 0: continue coeffs = params[f]['dec'](x1, w, mode=mode) coeff_arr, slices, shapes = pywt.ravel_coeffs(coeffs) coeffs2 = pywt.unravel_coeffs(coeff_arr, slices, shapes, output_format=f) x1r = params[f]['rec'](coeffs2, w, mode=mode) assert_allclose(x1, x1r, rtol=1e-4, atol=1e-4) def test_wavedecn_coeff_ravel_zero_level(): # verify round trip is correct: # wavedecn - >ravel_coeffs-> unravel_coeffs -> waverecn # This is done for wavedec{1, 2, n} rng = np.random.RandomState(1234) params = {'wavedec': {'d': 1, 'dec': pywt.wavedec, 'rec': pywt.waverec}, 'wavedec2': {'d': 2, 'dec': pywt.wavedec2, 'rec': pywt.waverec2}, 'wavedecn': {'d': 3, 'dec': pywt.wavedecn, 'rec': pywt.waverecn}} N = 16 for f in params: x1 = rng.randn(*([N] * params[f]['d'])) for mode in pywt.Modes.modes: w = pywt.Wavelet('db2') coeffs = params[f]['dec'](x1, w, mode=mode, level=0) coeff_arr, slices, shapes = pywt.ravel_coeffs(coeffs) coeffs2 = pywt.unravel_coeffs(coeff_arr, slices, shapes, output_format=f) x1r = params[f]['rec'](coeffs2, w, mode=mode) assert_allclose(x1, x1r, rtol=1e-4, atol=1e-4) def test_waverecn_coeff_ravel_odd(): # verify round trip is correct: # wavedecn - >ravel_coeffs-> unravel_coeffs -> waverecn rng = np.random.RandomState(1234) x1 = rng.randn(35, 33) for mode in pywt.Modes.modes: for wave in ['haar', ]: w = pywt.Wavelet(wave) maxlevel = pywt.dwt_max_level(np.min(x1.shape), w.dec_len) if maxlevel == 0: continue coeffs = pywt.wavedecn(x1, w, mode=mode) coeff_arr, slices, shapes = pywt.ravel_coeffs(coeffs) coeffs2 = pywt.unravel_coeffs(coeff_arr, slices, shapes) x1r = pywt.waverecn(coeffs2, w, mode=mode) # truncate reconstructed values to original shape x1r = x1r[tuple([slice(s) for s in x1.shape])] assert_allclose(x1, x1r, rtol=1e-4, atol=1e-4) def test_ravel_wavedec2_with_lists(): x1 = np.ones((8, 8)) wav = pywt.Wavelet('haar') coeffs = pywt.wavedec2(x1, wav) # list [cHn, cVn, cDn] instead of tuple is okay coeffs[1:] = [list(c) for c in coeffs[1:]] coeff_arr, slices, shapes = pywt.ravel_coeffs(coeffs) coeffs2 = pywt.unravel_coeffs(coeff_arr, slices, shapes, output_format='wavedec2') x1r = pywt.waverec2(coeffs2, wav) assert_allclose(x1, x1r, rtol=1e-4, atol=1e-4) # wrong length list will cause a ValueError coeffs[1:] = [list(c[:-1]) for c in coeffs[1:]] # truncate diag coeffs assert_raises(ValueError, pywt.ravel_coeffs, coeffs) def test_ravel_invalid_input(): # wavedec ravel does not support any coefficient arrays being set to None coeffs = pywt.wavedec(np.ones(8), 'haar') coeffs[1] = None assert_raises(ValueError, pywt.ravel_coeffs, coeffs) # wavedec2 ravel cannot have None or a tuple/list of None coeffs = pywt.wavedec2(np.ones((8, 8)), 'haar') coeffs[1] = (None, None, None) assert_raises(ValueError, pywt.ravel_coeffs, coeffs) coeffs[1] = [None, None, None] assert_raises(ValueError, pywt.ravel_coeffs, coeffs) coeffs[1] = None assert_raises(ValueError, pywt.ravel_coeffs, coeffs) # wavedecn ravel cannot have any dictionary elements as None coeffs = pywt.wavedecn(np.ones((8, 8, 8)), 'haar') coeffs[1]['ddd'] = None assert_raises(ValueError, pywt.ravel_coeffs, coeffs) def test_unravel_invalid_inputs(): coeffs = pywt.wavedecn(np.ones(2), 'haar') arr, slices, shapes = pywt.ravel_coeffs(coeffs) # empty list for slices or shapes assert_raises(ValueError, pywt.unravel_coeffs, arr, slices, []) assert_raises(ValueError, pywt.unravel_coeffs, arr, [], shapes) # unequal length for slices/shapes assert_raises(ValueError, pywt.unravel_coeffs, arr, slices[:-1], shapes) # invalid format name assert_raises(ValueError, pywt.unravel_coeffs, arr, slices, shapes, 'foo') def test_wavedecn_shapes_and_size(): wav = pywt.Wavelet('db2') for data_shape in [(33, ), (64, 32), (1, 15, 30)]: for axes in [None, 0, -1]: for mode in pywt.Modes.modes: coeffs = pywt.wavedecn(np.ones(data_shape), wav, mode=mode, axes=axes) # verify that the shapes match the coefficient shapes shapes = pywt.wavedecn_shapes(data_shape, wav, mode=mode, axes=axes) assert_equal(coeffs[0].shape, shapes[0]) expected_size = coeffs[0].size for level in range(1, len(coeffs)): for k, v in coeffs[level].items(): expected_size += v.size assert_equal(shapes[level][k], v.shape) # size can be determined from either the shapes or coeffs size = pywt.wavedecn_size(shapes) assert_equal(size, expected_size) size = pywt.wavedecn_size(coeffs) assert_equal(size, expected_size) def test_dwtn_max_level(): # predicted and empirical dwtn_max_level match for wav in [pywt.Wavelet('db2'), 'sym8']: for data_shape in [(33, ), (64, 32), (1, 15, 30)]: for axes in [None, 0, -1]: for mode in pywt.Modes.modes: coeffs = pywt.wavedecn(np.ones(data_shape), wav, mode=mode, axes=axes) max_lev = pywt.dwtn_max_level(data_shape, wav, axes) assert_equal(len(coeffs[1:]), max_lev) def test_waverec_axes_subsets(): rstate = np.random.RandomState(0) data = rstate.standard_normal((8, 8, 8)) # test all combinations of 1 out of 3 axes transformed for axis in [0, 1, 2]: coefs = pywt.wavedec(data, 'haar', axis=axis) rec = pywt.waverec(coefs, 'haar', axis=axis) assert_allclose(rec, data, atol=1e-14) def test_waverec_axis_db2(): # test for fix to issue gh-293 rstate = np.random.RandomState(0) data = rstate.standard_normal((16, 16)) for axis in [0, 1]: coefs = pywt.wavedec(data, 'db2', axis=axis) rec = pywt.waverec(coefs, 'db2', axis=axis) assert_allclose(rec, data, atol=1e-14) def test_waverec2_axes_subsets(): rstate = np.random.RandomState(0) data = rstate.standard_normal((8, 8, 8)) # test all combinations of 2 out of 3 axes transformed for axes in combinations((0, 1, 2), 2): coefs = pywt.wavedec2(data, 'haar', axes=axes) rec = pywt.waverec2(coefs, 'haar', axes=axes) assert_allclose(rec, data, atol=1e-14) def test_waverecn_axes_subsets(): rstate = np.random.RandomState(0) data = rstate.standard_normal((8, 8, 8, 8)) # test all combinations of 3 out of 4 axes transformed for axes in combinations((0, 1, 2, 3), 3): coefs = pywt.wavedecn(data, 'haar', axes=axes) rec = pywt.waverecn(coefs, 'haar', axes=axes) assert_allclose(rec, data, atol=1e-14) def test_waverecn_int_axis(): # waverecn should also work for axes as an integer rstate = np.random.RandomState(0) data = rstate.standard_normal((8, 8)) for axis in [0, 1]: coefs = pywt.wavedecn(data, 'haar', axes=axis) rec = pywt.waverecn(coefs, 'haar', axes=axis) assert_allclose(rec, data, atol=1e-14) def test_wavedec_axis_error(): data = np.ones(4) # out of range axis not allowed assert_raises(ValueError, pywt.wavedec, data, 'haar', axis=1) def test_waverec_axis_error(): c = pywt.wavedec(np.ones(4), 'haar') # out of range axis not allowed assert_raises(ValueError, pywt.waverec, c, 'haar', axis=1) def test_waverec_shape_mismatch_error(): c = pywt.wavedec(np.ones(16), 'haar') # truncate a detail coefficient to an incorrect shape c[3] = c[3][:-1] assert_raises(ValueError, pywt.waverec, c, 'haar', axis=1) def test_wavedec2_axes_errors(): data = np.ones((4, 4)) # integer axes not allowed assert_raises(TypeError, pywt.wavedec2, data, 'haar', axes=1) # non-unique axes not allowed assert_raises(ValueError, pywt.wavedec2, data, 'haar', axes=(0, 0)) # out of range axis not allowed assert_raises(ValueError, pywt.wavedec2, data, 'haar', axes=(0, 2)) def test_waverec2_axes_errors(): data = np.ones((4, 4)) c = pywt.wavedec2(data, 'haar') # integer axes not allowed assert_raises(TypeError, pywt.waverec2, c, 'haar', axes=1) # non-unique axes not allowed assert_raises(ValueError, pywt.waverec2, c, 'haar', axes=(0, 0)) # out of range axis not allowed assert_raises(ValueError, pywt.waverec2, c, 'haar', axes=(0, 2)) def test_wavedecn_axes_errors(): data = np.ones((8, 8, 8)) # repeated axes not allowed assert_raises(ValueError, pywt.wavedecn, data, 'haar', axes=(1, 1)) # out of range axis not allowed assert_raises(ValueError, pywt.wavedecn, data, 'haar', axes=(0, 1, 3)) def test_waverecn_axes_errors(): data = np.ones((8, 8, 8)) c = pywt.wavedecn(data, 'haar') # repeated axes not allowed assert_raises(ValueError, pywt.waverecn, c, 'haar', axes=(1, 1)) # out of range axis not allowed assert_raises(ValueError, pywt.waverecn, c, 'haar', axes=(0, 1, 3)) def test_per_axis_wavelets_and_modes(): # tests seperate wavelet and edge mode for each axis. rstate = np.random.RandomState(1234) data = rstate.randn(24, 24, 16) # wavelet can be a string or wavelet object wavelets = (pywt.Wavelet('haar'), 'sym2', 'db2') # The default number of levels should be the minimum over this list max_levels = [pywt._dwt.dwt_max_level(nd, nf) for nd, nf in zip(data.shape, wavelets)] # mode can be a string or a Modes enum modes = ('symmetric', 'periodization', pywt._extensions._pywt.Modes.reflect) coefs = pywt.wavedecn(data, wavelets, modes) assert_allclose(pywt.waverecn(coefs, wavelets, modes), data, atol=1e-14) assert_equal(min(max_levels), len(coefs[1:])) coefs = pywt.wavedecn(data, wavelets[:1], modes) assert_allclose(pywt.waverecn(coefs, wavelets[:1], modes), data, atol=1e-14) coefs = pywt.wavedecn(data, wavelets, modes[:1]) assert_allclose(pywt.waverecn(coefs, wavelets, modes[:1]), data, atol=1e-14) # length of wavelets or modes doesn't match the length of axes assert_raises(ValueError, pywt.wavedecn, data, wavelets[:2]) assert_raises(ValueError, pywt.wavedecn, data, wavelets, mode=modes[:2]) assert_raises(ValueError, pywt.waverecn, coefs, wavelets[:2]) assert_raises(ValueError, pywt.waverecn, coefs, wavelets, mode=modes[:2]) # dwt2/idwt2 also support per-axis wavelets/modes data2 = data[..., 0] coefs2 = pywt.wavedec2(data2, wavelets[:2], modes[:2]) assert_allclose(pywt.waverec2(coefs2, wavelets[:2], modes[:2]), data2, atol=1e-14) assert_equal(min(max_levels[:2]), len(coefs2[1:])) # Tests for fully separable multi-level transforms def test_fswavedecn_fswaverecn_roundtrip(): # verify proper round trip result for 1D through 4D data # same DWT as wavedecn/waverecn so don't need to test all modes/wavelets rstate = np.random.RandomState(0) for ndim in range(1, 5): for dt_in, dt_out in zip(dtypes_in, dtypes_out): for levels in (1, None): data = rstate.standard_normal((8, )*ndim) data = data.astype(dt_in) T = pywt.fswavedecn(data, 'haar', levels=levels) rec = pywt.fswaverecn(T) if data.real.dtype in [np.float32, np.float16]: assert_allclose(rec, data, rtol=1e-6, atol=1e-6) else: assert_allclose(rec, data, rtol=1e-14, atol=1e-14) assert_(T.coeffs.dtype == dt_out) assert_(rec.dtype == dt_out) def test_fswavedecn_fswaverecn_zero_levels(): # zero level transform gives coefs matching the original data rstate = np.random.RandomState(0) ndim = 2 data = rstate.standard_normal((8, )*ndim) T = pywt.fswavedecn(data, 'haar', levels=0) assert_array_equal(T.coeffs, data) rec = pywt.fswaverecn(T) assert_array_equal(T.coeffs, rec) def test_fswavedecn_fswaverecn_variable_levels(): # test with differing number of transform levels per axis rstate = np.random.RandomState(0) ndim = 3 data = rstate.standard_normal((16, )*ndim) T = pywt.fswavedecn(data, 'haar', levels=(1, 2, 3)) rec = pywt.fswaverecn(T) assert_allclose(rec, data, atol=1e-14) # levels doesn't match number of axes assert_raises(ValueError, pywt.fswavedecn, data, 'haar', levels=(1, 1)) assert_raises(ValueError, pywt.fswavedecn, data, 'haar', levels=(1, 1, 1, 1)) # levels too large for array size assert_warns(UserWarning, pywt.fswavedecn, data, 'haar', levels=int(np.log2(np.min(data.shape)))+1) def test_fswavedecn_fswaverecn_variable_wavelets_and_modes(): # test with differing number of transform levels per axis rstate = np.random.RandomState(0) ndim = 3 data = rstate.standard_normal((16, )*ndim) wavelets = ('haar', 'db2', 'sym3') modes = ('periodic', 'symmetric', 'periodization') T = pywt.fswavedecn(data, wavelet=wavelets, mode=modes) for ax in range(ndim): # expect approx + dwt_max_level detail coeffs along each axis assert_equal(len(T.coeff_slices[ax]), pywt.dwt_max_level(data.shape[ax], wavelets[ax])+1) rec = pywt.fswaverecn(T) assert_allclose(rec, data, atol=1e-14) # number of wavelets doesn't match number of axes assert_raises(ValueError, pywt.fswavedecn, data, wavelets[:2]) # number of modes doesn't match number of axes assert_raises(ValueError, pywt.fswavedecn, data, wavelets[0], mode=modes[:2]) def test_fswavedecn_fswaverecn_axes_subsets(): """Fully separable DWT over only a subset of axes""" rstate = np.random.RandomState(0) # use anisotropic data to result in unique number of levels per axis data = rstate.standard_normal((4, 8, 16, 32)) # test all combinations of 3 out of 4 axes transformed for axes in combinations((0, 1, 2, 3), 3): T = pywt.fswavedecn(data, 'haar', axes=axes) rec = pywt.fswaverecn(T) assert_allclose(rec, data, atol=1e-14) # some axes exceed data dimensions assert_raises(ValueError, pywt.fswavedecn, data, 'haar', axes=(1, 5)) def test_fswavedecnresult(): data = np.ones((32, 32)) levels = (1, 2) result = pywt.fswavedecn(data, 'sym2', levels=levels) # can access the lowpass band via .approx or via __getitem__ approx_key = (0, ) * data.ndim assert_array_equal(result[approx_key], result.approx) dkeys = result.detail_keys() # the approximation key shouldn't be present in the detail_keys assert_(approx_key not in dkeys) # can access all detail coefficients and they have matching ndim for k in dkeys: d = result[k] assert_equal(d.ndim, data.ndim) # can assign modified coefficients result[k] = np.zeros_like(d) # assigning a differently sized array raises a ValueError assert_raises(ValueError, result.__setitem__, k, np.zeros(tuple([s + 1 for s in d.shape]))) # warns on assigning with a non-matching dtype assert_warns(UserWarning, result.__setitem__, k, np.zeros_like(d).astype(np.float32)) # all coefficients are stacked into result.coeffs (same ndim) assert_equal(result.coeffs.ndim, data.ndim) def test_error_on_continuous_wavelet(): # A ValueError is raised if a Continuous wavelet is selected data = np.ones((16, 16)) for dec_fun, rec_fun in zip([pywt.wavedec, pywt.wavedec2, pywt.wavedecn], [pywt.waverec, pywt.waverec2, pywt.waverecn]): for cwave in ['morl', pywt.DiscreteContinuousWavelet('morl')]: assert_raises(ValueError, dec_fun, data, wavelet=cwave) c = dec_fun(data, 'db1') assert_raises(ValueError, rec_fun, c, wavelet=cwave) def test_default_level(): # default level is the maximum permissible for the transformed axes data = np.ones((128, 32, 4)) wavelet = ('db8', 'db1') for dec_func in [pywt.wavedec2, pywt.wavedecn]: for axes in [(0, 1), (2, 1), (0, 2)]: c = dec_func(data, wavelet, axes=axes) max_lev = np.min([pywt.dwt_max_level(data.shape[ax], wav) for ax, wav in zip(axes, wavelet)]) assert_equal(len(c[1:]), max_lev) for ax in [0, 1]: c = pywt.wavedecn(data, wavelet[ax], axes=(ax, )) assert_equal(len(c[1:]), pywt.dwt_max_level(data.shape[ax], wavelet[ax])) def test_waverec_mixed_precision(): rstate = np.random.RandomState(0) for func, ifunc, shape in [(pywt.wavedec, pywt.waverec, (8, )), (pywt.wavedec2, pywt.waverec2, (8, 8)), (pywt.wavedecn, pywt.waverecn, (8, 8, 8))]: x = rstate.randn(*shape) coeffs_real = func(x, 'db1') # real: single precision approx, double precision details coeffs_real[0] = coeffs_real[0].astype(np.float32) r = ifunc(coeffs_real, 'db1') assert_allclose(r, x, rtol=1e-7, atol=1e-7) assert_equal(r.dtype, np.float64) x = x + 1j*x coeffs = func(x, 'db1') # complex: single precision approx, double precision details coeffs[0] = coeffs[0].astype(np.complex64) r = ifunc(coeffs, 'db1') assert_allclose(r, x, rtol=1e-7, atol=1e-7) assert_equal(r.dtype, np.complex128) # complex: double precision approx, single precision details if x.ndim == 1: coeffs[0] = coeffs[0].astype(np.complex128) coeffs[1] = coeffs[1].astype(np.complex64) if x.ndim == 2: coeffs[0] = coeffs[0].astype(np.complex128) coeffs[1] = tuple([v.astype(np.complex64) for v in coeffs[1]]) if x.ndim == 3: coeffs[0] = coeffs[0].astype(np.complex128) coeffs[1] = {k: v.astype(np.complex64) for k, v in coeffs[1].items()} r = ifunc(coeffs, 'db1') assert_allclose(r, x, rtol=1e-7, atol=1e-7) assert_equal(r.dtype, np.complex128)