266 lines
11 KiB
Python
266 lines
11 KiB
Python
#!/usr/bin/env python
|
|
from __future__ import division, print_function, absolute_import
|
|
|
|
import numpy as np
|
|
from numpy.testing import assert_allclose, assert_
|
|
|
|
import pywt
|
|
|
|
|
|
def test_wavelet_properties():
    """Check the names, filters and flags reported by ``pywt.Wavelet('db3')``."""
    w = pywt.Wavelet('db3')

    # Name.  Every check compares explicitly with ``==``: ``assert_(a, b)``
    # treats ``b`` as the failure message and only tests ``a`` for truthiness.
    assert_(w.name == 'db3')
    assert_(w.short_family_name == 'db')
    assert_(w.family_name == 'Daubechies')

    # String representation
    description = str(w)
    fields = ('Family name', 'Short name', 'Filters length', 'Orthogonal',
              'Biorthogonal', 'Symmetry')
    for field in fields:
        assert_(field in description)

    # Filter coefficients
    dec_lo = [0.03522629188210, -0.08544127388224, -0.13501102001039,
              0.45987750211933, 0.80689150931334, 0.33267055295096]
    dec_hi = [-0.33267055295096, 0.80689150931334, -0.45987750211933,
              -0.13501102001039, 0.08544127388224, 0.03522629188210]
    rec_lo = [0.33267055295096, 0.80689150931334, 0.45987750211933,
              -0.13501102001039, -0.08544127388224, 0.03522629188210]
    rec_hi = [0.03522629188210, 0.08544127388224, -0.13501102001039,
              -0.45987750211933, 0.80689150931334, -0.33267055295096]
    assert_allclose(w.dec_lo, dec_lo)
    assert_allclose(w.dec_hi, dec_hi)
    assert_allclose(w.rec_lo, rec_lo)
    assert_allclose(w.rec_hi, rec_hi)

    # filter_bank bundles the four filters above
    assert_(len(w.filter_bank) == 4)

    # Orthogonality
    assert_(w.orthogonal)
    assert_(w.biorthogonal)

    # Symmetry: ``symmetry`` is a descriptive string (e.g. 'unknown' for
    # custom wavelets), so a bare truthiness check would accept any value.
    # Daubechies wavelets are asymmetric.
    assert_(w.symmetry == 'asymmetric')

    # Vanishing moments
    assert_(w.vanishing_moments_phi == 0)
    assert_(w.vanishing_moments_psi == 3)
|
|
|
|
|
|
def test_wavelet_coefficients():
    """Run the coefficient sanity checks for every built-in wavelet family.

    Each wavelet is routed to the orthogonal, biorthogonal or generic checker
    according to its ``orthogonal`` / ``biorthogonal`` flags; the orthogonal
    flag takes precedence.
    """
    for family in ('db', 'sym', 'coif', 'bior', 'rbio'):
        for name in pywt.wavelist(family):
            wav = pywt.Wavelet(name)
            if wav.orthogonal:
                check_coefficients_orthogonal(name)
            elif wav.biorthogonal:
                check_coefficients_biorthogonal(name)
            else:
                check_coefficients(name)
|
|
|
|
|
|
def check_coefficients_orthogonal(wavelet):
    """Check filter and wavefunction identities of an orthogonal wavelet.

    Parameters
    ----------
    wavelet : str
        Name of an orthogonal wavelet accepted by ``pywt.Wavelet``.

    Every identity is expressed as a residual whose magnitude must stay
    below ``epsilon``.
    """
    epsilon = 5e-11
    level = 5
    w = pywt.Wavelet(wavelet)
    phi, psi, _ = w.wavefun(level=level)

    def _assert_small(res):
        # One message format for every identity checked below.
        msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
        # Compare the magnitude: a signed ``res < epsilon`` would accept any
        # negative residual, however large.
        assert_(abs(res) < epsilon, msg=msg)

    dec_lo = np.asarray(w.dec_lo)
    dec_hi = np.asarray(w.dec_hi)

    # Lowpass filter coefficients sum to sqrt2
    _assert_small(np.sum(dec_lo) - np.sqrt(2))
    # sum even coef = sum odd coef = 1 / sqrt(2)
    _assert_small(np.sum(dec_lo[::2]) - 1. / np.sqrt(2))
    _assert_small(np.sum(dec_lo[1::2]) - 1. / np.sqrt(2))
    # Highpass filter coefficients sum to zero
    _assert_small(np.sum(dec_hi))
    # Scaling function integrates to unity: phi is sampled with step
    # 2**-level, so its samples sum to 2**level.
    _assert_small(np.sum(phi) - 2**level)
    # Wavelet function is orthogonal to the scaling function at the same scale
    _assert_small(np.sum(phi * psi))
    # The lowpass and highpass filter coefficients are orthogonal
    _assert_small(np.sum(dec_lo * dec_hi))
|
|
|
|
|
|
def check_coefficients_biorthogonal(wavelet):
    """Check filter and scaling-function identities of a biorthogonal wavelet.

    Parameters
    ----------
    wavelet : str
        Name of a biorthogonal wavelet accepted by ``pywt.Wavelet``.

    Every identity is expressed as a residual whose magnitude must stay
    below ``epsilon``.
    """
    epsilon = 5e-11
    level = 5
    w = pywt.Wavelet(wavelet)
    # Non-orthogonal wavefun yields (phi_d, psi_d, phi_r, psi_r, x); only the
    # decomposition and reconstruction scaling functions are checked here.
    phi_d, _, phi_r, _, _ = w.wavefun(level=level)

    def _assert_small(res):
        # One message format for every identity checked below.
        msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
        # Compare the magnitude: a signed ``res < epsilon`` would accept any
        # negative residual, however large.
        assert_(abs(res) < epsilon, msg=msg)

    dec_lo = np.asarray(w.dec_lo)

    # Lowpass filter coefficients sum to sqrt2
    _assert_small(np.sum(dec_lo) - np.sqrt(2))
    # sum even coef = sum odd coef = 1 / sqrt(2)
    _assert_small(np.sum(dec_lo[::2]) - 1. / np.sqrt(2))
    _assert_small(np.sum(dec_lo[1::2]) - 1. / np.sqrt(2))
    # Highpass filter coefficients sum to zero
    _assert_small(np.sum(w.dec_hi))
    # Scaling functions integrate to unity (samples spaced 2**-level apart)
    _assert_small(np.sum(phi_d) - 2**level)
    _assert_small(np.sum(phi_r) - 2**level)
|
|
|
|
|
|
def check_coefficients(wavelet):
    """Check the basic filter identities of a wavelet with no orthogonality flags.

    Parameters
    ----------
    wavelet : str
        Name of a wavelet accepted by ``pywt.Wavelet``.

    Only the filter sums are checked (no wavefunction is computed); each
    residual's magnitude must stay below ``epsilon``.
    """
    epsilon = 5e-11
    w = pywt.Wavelet(wavelet)

    def _assert_small(res):
        # One message format for every identity checked below.
        msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
        # Compare the magnitude: a signed ``res < epsilon`` would accept any
        # negative residual, however large.
        assert_(abs(res) < epsilon, msg=msg)

    dec_lo = np.asarray(w.dec_lo)

    # Lowpass filter coefficients sum to sqrt2
    _assert_small(np.sum(dec_lo) - np.sqrt(2))
    # sum even coef = sum odd coef = 1 / sqrt(2)
    _assert_small(np.sum(dec_lo[::2]) - 1. / np.sqrt(2))
    _assert_small(np.sum(dec_lo[1::2]) - 1. / np.sqrt(2))
    # Highpass filter coefficients sum to zero
    _assert_small(np.sum(w.dec_hi))
|
|
|
|
|
|
class _CustomHaarFilterBank(object):
|
|
@property
|
|
def filter_bank(self):
|
|
val = np.sqrt(2) / 2
|
|
return ([val]*2, [-val, val], [val]*2, [val, -val])
|
|
|
|
|
|
def test_custom_wavelet():
    """Check custom wavelets built from a user-supplied Haar filter bank.

    The filter bank is passed once via an object with a ``filter_bank``
    property and once as a plain tuple of four filters.
    """
    haar_custom1 = pywt.Wavelet('Custom Haar Wavelet',
                                filter_bank=_CustomHaarFilterBank())
    haar_custom1.orthogonal = True
    haar_custom1.biorthogonal = True
    assert_(haar_custom1.orthogonal)
    assert_(haar_custom1.biorthogonal)

    val = np.sqrt(2) / 2
    filter_bank = ([val]*2, [-val, val], [val]*2, [val, -val])
    haar_custom2 = pywt.Wavelet('Custom Haar Wavelet',
                                filter_bank=filter_bank)

    # check expected default wavelet properties.  Use ``not`` rather than
    # ``~``: bitwise inversion of a bool/int (``~False == -1``) is always
    # truthy, so a ``~`` check could never fail.
    assert_(not haar_custom2.orthogonal)
    assert_(not haar_custom2.biorthogonal)
    assert_(haar_custom2.symmetry == 'unknown')
    assert_(haar_custom2.family_name == '')
    assert_(haar_custom2.short_family_name == '')
    assert_(haar_custom2.vanishing_moments_phi == 0)
    assert_(haar_custom2.vanishing_moments_psi == 0)

    # Some properties can be set by the user; read them back to confirm.
    haar_custom2.orthogonal = True
    haar_custom2.biorthogonal = True
    assert_(haar_custom2.orthogonal)
    assert_(haar_custom2.biorthogonal)
|
|
|
|
|
|
def test_wavefun_sym3():
    """Compare ``wavefun`` output for ``sym3`` at level 3 against reference samples.

    sym3 is orthogonal, so ``wavefun`` returns three arrays ``(phi, psi, x)``.
    """
    wavelet = pywt.Wavelet('sym3')
    phi, psi, x = wavelet.wavefun(level=3)

    expected_size = 41
    for samples in (phi, psi, x):
        assert_(samples.size == expected_size)

    # Sample grid covers [0, 5] with both endpoints included
    assert_allclose(x, np.linspace(0, 5, num=x.size))

    phi_expect = np.array([
        0.00000000e+00, 1.04132926e-01, 2.52574126e-01,
        3.96525521e-01, 5.70356539e-01, 7.18934305e-01,
        8.70293448e-01, 1.05363620e+00, 1.24921722e+00,
        1.15296888e+00, 9.41669683e-01, 7.55875887e-01,
        4.96118565e-01, 3.28293151e-01, 1.67624969e-01,
        -7.33690312e-02, -3.35452855e-01, -3.31221131e-01,
        -2.32061503e-01, -1.66854239e-01, -4.34091324e-02,
        -2.86152390e-02, -3.63563035e-02, 2.06034491e-02,
        8.30280254e-02, 7.17779073e-02, 3.85914311e-02,
        1.47527100e-02, -2.31896077e-02, -1.86122172e-02,
        -1.56211329e-03, -8.70615088e-04, 3.20760857e-03,
        2.34142153e-03, -7.73737194e-04, -2.99879354e-04,
        1.23636238e-04, 0.00000000e+00, 0.00000000e+00,
        0.00000000e+00, 0.00000000e+00,
    ])

    psi_expect = np.array([
        0.00000000e+00, 1.10265752e-02, 2.67449277e-02,
        4.19878574e-02, 6.03947231e-02, 7.61275365e-02,
        9.21548684e-02, 1.11568926e-01, 1.32278887e-01,
        6.45829680e-02, -3.97635130e-02, -1.38929884e-01,
        -2.62428322e-01, -3.62246804e-01, -4.62843343e-01,
        -5.89607507e-01, -7.25363076e-01, -3.36865858e-01,
        2.67715108e-01, 8.40176767e-01, 1.55574430e+00,
        1.18688954e+00, 4.20276324e-01, -1.51697311e-01,
        -9.42076108e-01, -7.93172332e-01, -3.26343710e-01,
        -1.24552779e-01, 2.12909254e-01, 1.75770320e-01,
        1.47523075e-02, 8.22192707e-03, -3.02920592e-02,
        -2.21119497e-02, 7.30703025e-03, 2.83200488e-03,
        -1.16759765e-03, 0.00000000e+00, 0.00000000e+00,
        0.00000000e+00, 0.00000000e+00,
    ])

    assert_allclose(phi, phi_expect)
    assert_allclose(psi, psi_expect)
|
|
|
|
|
|
def test_wavefun_bior13():
    """Compare ``wavefun`` output for ``bior1.3`` at level 3 against references.

    bior1.3 is not orthogonal, so ``wavefun`` returns five arrays:
    ``(phi_d, psi_d, phi_r, psi_r, x)``.
    """
    w = pywt.Wavelet('bior1.3')
    # bior1.3 is not an orthogonal wavelet, so 5 outputs from wavefun
    phi_d, psi_d, phi_r, psi_r, x = w.wavefun(level=3)
    # x is included: the expected arrays below are sized from x.size
    for arr in [phi_d, psi_d, phi_r, psi_r, x]:
        assert_(arr.size == 40)

    phi_d_expect = np.array([0., -0.00195313, 0.00195313, 0.01757813,
                             0.01367188, 0.00390625, -0.03515625, -0.12890625,
                             -0.15234375, -0.125, -0.09375, -0.0625, 0.03125,
                             0.15234375, 0.37890625, 0.78515625, 0.99609375,
                             1.08203125, 1.13671875, 1.13671875, 1.08203125,
                             0.99609375, 0.78515625, 0.37890625, 0.15234375,
                             0.03125, -0.0625, -0.09375, -0.125, -0.15234375,
                             -0.12890625, -0.03515625, 0.00390625, 0.01367188,
                             0.01757813, 0.00195313, -0.00195313, 0., 0., 0.])
    # Builtin ``float`` (float64) instead of the ``np.float`` alias, which
    # was removed in NumPy 1.24.
    phi_r_expect = np.zeros(x.size, dtype=float)
    # Reconstruction scaling function: a box of height 1
    phi_r_expect[15:23] = 1

    psi_d_expect = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0,
                             0.015625, -0.015625, -0.140625, -0.109375,
                             -0.03125, 0.28125, 1.03125, 1.21875, 1.125, 0.625,
                             -0.625, -1.125, -1.21875, -1.03125, -0.28125,
                             0.03125, 0.109375, 0.140625, 0.015625, -0.015625,
                             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

    psi_r_expect = np.zeros(x.size, dtype=float)
    # Reconstruction wavelet: +1/-1 halves with -0.125/+0.125 side lobes
    psi_r_expect[7:15] = -0.125
    psi_r_expect[15:19] = 1
    psi_r_expect[19:23] = -1
    psi_r_expect[23:31] = 0.125

    # Sample grid covers [0, 5) with the right endpoint excluded
    assert_allclose(x, np.linspace(0, 5, x.size, endpoint=False))
    # phi_d reference values are rounded to 8 decimals, hence the looser rtol
    assert_allclose(phi_d, phi_d_expect, rtol=1e-5, atol=1e-9)
    assert_allclose(phi_r, phi_r_expect, rtol=1e-10, atol=1e-12)
    assert_allclose(psi_d, psi_d_expect, rtol=1e-10, atol=1e-12)
    assert_allclose(psi_r, psi_r_expect, rtol=1e-10, atol=1e-12)
|