#!/usr/bin/env python from __future__ import division, print_function, absolute_import import numpy as np from numpy.testing import assert_allclose, assert_, assert_raises import pywt def test_upcoef_reconstruct(): data = np.arange(3) a = pywt.downcoef('a', data, 'haar') d = pywt.downcoef('d', data, 'haar') rec = (pywt.upcoef('a', a, 'haar', take=3) + pywt.upcoef('d', d, 'haar', take=3)) assert_allclose(rec, data) def test_downcoef_multilevel(): rstate = np.random.RandomState(1234) r = rstate.randn(16) nlevels = 3 # calling with level=1 nlevels times a1 = r.copy() for i in range(nlevels): a1 = pywt.downcoef('a', a1, 'haar', level=1) # call with level=nlevels once a3 = pywt.downcoef('a', r, 'haar', level=nlevels) assert_allclose(a1, a3) def test_downcoef_complex(): rstate = np.random.RandomState(1234) r = rstate.randn(16) + 1j * rstate.randn(16) nlevels = 3 a = pywt.downcoef('a', r, 'haar', level=nlevels) a_ref = pywt.downcoef('a', r.real, 'haar', level=nlevels) a_ref = a_ref + 1j * pywt.downcoef('a', r.imag, 'haar', level=nlevels) assert_allclose(a, a_ref) def test_downcoef_errs(): # invalid part string (not 'a' or 'd') assert_raises(ValueError, pywt.downcoef, 'f', np.ones(16), 'haar') def test_compare_downcoef_coeffs(): rstate = np.random.RandomState(1234) r = rstate.randn(16) # compare downcoef against wavedec outputs for nlevels in [1, 2, 3]: for wavelet in pywt.wavelist(): if wavelet in ['cmor', 'shan', 'fbsp']: # skip these CWT families to avoid warnings continue wavelet = pywt.DiscreteContinuousWavelet(wavelet) if isinstance(wavelet, pywt.Wavelet): max_level = pywt.dwt_max_level(r.size, wavelet.dec_len) if nlevels <= max_level: a = pywt.downcoef('a', r, wavelet, level=nlevels) d = pywt.downcoef('d', r, wavelet, level=nlevels) coeffs = pywt.wavedec(r, wavelet, level=nlevels) assert_allclose(a, coeffs[0]) assert_allclose(d, coeffs[1]) def test_upcoef_multilevel(): rstate = np.random.RandomState(1234) r = rstate.randn(4) nlevels = 3 # calling with level=1 nlevels times a1 = r.copy() for i in range(nlevels): a1 = pywt.upcoef('a', a1, 'haar', level=1) # call with level=nlevels once a3 = pywt.upcoef('a', r, 'haar', level=nlevels) assert_allclose(a1, a3) def test_upcoef_complex(): rstate = np.random.RandomState(1234) r = rstate.randn(4) + 1j*rstate.randn(4) nlevels = 3 a = pywt.upcoef('a', r, 'haar', level=nlevels) a_ref = pywt.upcoef('a', r.real, 'haar', level=nlevels) a_ref = a_ref + 1j*pywt.upcoef('a', r.imag, 'haar', level=nlevels) assert_allclose(a, a_ref) def test_upcoef_errs(): # invalid part string (not 'a' or 'd') assert_raises(ValueError, pywt.upcoef, 'f', np.ones(4), 'haar') def test_upcoef_and_downcoef_1d_only(): # upcoef and downcoef raise a ValueError if data.ndim > 1d for ndim in [2, 3]: data = np.ones((8, )*ndim) assert_raises(ValueError, pywt.downcoef, 'a', data, 'haar') assert_raises(ValueError, pywt.upcoef, 'a', data, 'haar') def test_wavelet_repr(): from pywt._extensions import _pywt wavelet = _pywt.Wavelet('sym8') repr_wavelet = eval(wavelet.__repr__()) assert_(wavelet.__repr__() == repr_wavelet.__repr__()) def test_dwt_max_level(): assert_(pywt.dwt_max_level(16, 2) == 4) assert_(pywt.dwt_max_level(16, 8) == 1) assert_(pywt.dwt_max_level(16, 9) == 1) assert_(pywt.dwt_max_level(16, 10) == 0) assert_(pywt.dwt_max_level(16, np.int8(10)) == 0) assert_(pywt.dwt_max_level(16, 10.) == 0) assert_(pywt.dwt_max_level(16, 18) == 0) # accepts discrete Wavelet object or string as well assert_(pywt.dwt_max_level(32, pywt.Wavelet('sym5')) == 1) assert_(pywt.dwt_max_level(32, 'sym5') == 1) # string input that is not a discrete wavelet assert_raises(ValueError, pywt.dwt_max_level, 16, 'mexh') # filter_len must be an integer >= 2 assert_raises(ValueError, pywt.dwt_max_level, 16, 1) assert_raises(ValueError, pywt.dwt_max_level, 16, -1) assert_raises(ValueError, pywt.dwt_max_level, 16, 3.3) def test_ContinuousWavelet_errs(): assert_raises(ValueError, pywt.ContinuousWavelet, 'qwertz') def test_ContinuousWavelet_repr(): from pywt._extensions import _pywt wavelet = _pywt.ContinuousWavelet('gaus2') repr_wavelet = eval(wavelet.__repr__()) assert_(wavelet.__repr__() == repr_wavelet.__repr__()) def test_wavelist(): for name in pywt.wavelist(family='coif'): assert_(name.startswith('coif')) assert_('cgau7' in pywt.wavelist(kind='continuous')) assert_('sym20' in pywt.wavelist(kind='discrete')) assert_(len(pywt.wavelist(kind='continuous')) + len(pywt.wavelist(kind='discrete')) == len(pywt.wavelist(kind='all'))) assert_raises(ValueError, pywt.wavelist, kind='foobar') def test_wavelet_errormsgs(): try: pywt.Wavelet('gaus1') except ValueError as e: assert_(e.args[0].startswith('The `Wavelet` class')) try: pywt.Wavelet('cmord') except ValueError as e: assert_(e.args[0] == "Invalid wavelet name 'cmord'.")