161 lines
5.7 KiB
Python
161 lines
5.7 KiB
Python
|
"""
|
||
|
Test used to verify PyWavelets Discrete Wavelet Transform computation
|
||
|
accuracy against MathWorks Wavelet Toolbox.
|
||
|
"""
|
||
|
|
||
|
from __future__ import division, print_function, absolute_import
|
||
|
|
||
|
import numpy as np
|
||
|
import pytest
|
||
|
from numpy.testing import assert_
|
||
|
|
||
|
import pywt
|
||
|
from pywt._pytest import (uses_pymatbridge, uses_precomputed, size_set)
|
||
|
from pywt._pytest import matlab_result_dict_dwt as matlab_result_dict
|
||
|
|
||
|
# Pairs of (PyWavelets mode name, MATLAB ``dwt`` 'mode' name) for every
# signal extension mode that is compared.
modes = [('zero', 'zpd'),
         ('constant', 'sp0'),
         ('symmetric', 'sym'),
         ('reflect', 'symw'),
         ('periodic', 'ppd'),
         ('smooth', 'sp1'),
         ('periodization', 'per'),
         # The asymmetric modes are compared as well, so any precomputed
         # MATLAB data must also contain entries for them.
         ('antisymmetric', 'asym'),
         ('antireflect', 'asymw'),
         ]

families = ('db', 'sym', 'coif', 'bior', 'rbio')
# Flat list of every wavelet name in the families above.  A nested
# comprehension avoids the quadratic list concatenation of sum(..., []).
wavelets = [name for family in families for name in pywt.wavelist(family)]
|
||
|
|
||
|
|
||
|
def _get_data_sizes(w):
    """Return the signal lengths to test for wavelet ``w``.

    When ``size_set`` is ``'full'``, a list is returned containing every
    length from ``w.dec_len`` up to 39 followed by 100, 200, 500, 1000 and
    50000.  Otherwise a 2-tuple ``(w.dec_len, w.dec_len + 1)`` is returned.
    """
    filt_len = w.dec_len
    if size_set != 'full':
        return (filt_len, filt_len + 1)
    small_sizes = list(range(filt_len, 40))
    large_sizes = [100, 200, 500, 1000, 50000]
    return small_sizes + large_sizes
|
||
|
|
||
|
|
||
|
@uses_pymatbridge
@pytest.mark.slow
def test_accuracy_pymatbridge():
    """Compare ``pywt.dwt`` against MATLAB's ``dwt`` run live via pymatbridge.

    For every wavelet in ``wavelets``, every length from ``_get_data_sizes``
    and every mode pair in ``modes``, the PyWavelets result is checked
    against the result computed in MATLAB by ``_compute_matlab_result``
    (RMS tolerance 5e-5).  It is also checked against the result loaded by
    ``_load_matlab_result_pywt_coeffs`` (RMS tolerance 1e-10).

    The test is skipped when the ``pymatbridge`` package cannot be imported.
    """
    # ``importorskip`` takes a *module* name and returns the module.
    # ``Matlab`` is a class exported by the ``pymatbridge`` package, not a
    # submodule.  Importing "pymatbridge.Matlab" therefore always failed and
    # silently skipped this test, so import the package and get the class.
    pymatbridge = pytest.importorskip("pymatbridge")
    mlab = pymatbridge.Matlab()

    rstate = np.random.RandomState(1234)
    # max RMSE (was 1.0e-10, is reduced to 5.0e-5 due to different
    # coefficients)
    epsilon = 5.0e-5
    epsilon_pywt_coeffs = 1.0e-10
    mlab.start()
    try:
        for wavelet in wavelets:
            w = pywt.Wavelet(wavelet)
            mlab.set_variable('wavelet', wavelet)
            for N in _get_data_sizes(w):
                data = rstate.randn(N)
                mlab.set_variable('data', data)
                for pmode, mmode in modes:
                    ma, md = _compute_matlab_result(data, wavelet, mmode,
                                                    mlab)
                    _check_accuracy(data, w, pmode, ma, md, wavelet, epsilon)
                    # NOTE(review): this reads the precomputed results rather
                    # than computing them in the live MATLAB session; confirm
                    # that ``matlab_result_dict`` is populated (and built from
                    # the same random seed) when pymatbridge is used.
                    ma, md = _load_matlab_result_pywt_coeffs(data, wavelet,
                                                             mmode)
                    _check_accuracy(data, w, pmode, ma, md, wavelet,
                                    epsilon_pywt_coeffs)
    finally:
        # Always shut the MATLAB session down, even if a comparison fails.
        mlab.stop()
|
||
|
|
||
|
|
||
|
@uses_precomputed
@pytest.mark.slow
def test_accuracy_precomputed():
    """Compare ``pywt.dwt`` with the stored MATLAB reference results.

    Every (wavelet, signal length, mode) combination is checked twice:
    first against the reference loaded by ``_load_matlab_result`` (RMS
    tolerance 5e-5), then against the one loaded by
    ``_load_matlab_result_pywt_coeffs`` (RMS tolerance 1e-10).
    """
    # The seed and the order of the randn draws below must stay fixed so
    # that the generated signals match the precomputed Matlab results.
    rng = np.random.RandomState(1234)
    # (loader, max RMS error) pairs, applied in this order.  The first bound
    # was 1.0e-10 and is relaxed to 5.0e-5 due to different coefficients.
    reference_checks = (
        (_load_matlab_result, 5.0e-5),
        (_load_matlab_result_pywt_coeffs, 1.0e-10),
    )
    for wavelet_name in wavelets:
        wavelet_obj = pywt.Wavelet(wavelet_name)
        for length in _get_data_sizes(wavelet_obj):
            signal = rng.randn(length)
            for pmode, mmode in modes:
                for load_reference, tolerance in reference_checks:
                    ma, md = load_reference(signal, wavelet_name, mmode)
                    _check_accuracy(signal, wavelet_obj, pmode, ma, md,
                                    wavelet_name, tolerance)
|
||
|
|
||
|
|
||
|
def _compute_matlab_result(data, wavelet, mmode, mlab):
    """ Compute the result using MATLAB.

    This function assumes that the Matlab variables `wavelet` and `data` have
    already been set externally.

    Parameters
    ----------
    data : ndarray
        The input signal.  Not read here: the MATLAB code uses the ``data``
        variable already present in the MATLAB workspace.  Kept for
        signature compatibility.
    wavelet : str
        PyWavelets wavelet name.
    mmode : str
        MATLAB ``dwt`` extension-mode name.
    mlab : object
        A started MATLAB session providing ``set_variable``, ``run_code``
        and ``get_variable`` (presumably a pymatbridge ``Matlab`` instance).

    Returns
    -------
    ma, md : ndarray
        The two outputs of MATLAB's ``dwt`` call.

    Raises
    ------
    RuntimeError
        If MATLAB reports that the code did not run successfully.
    """
    # For coif6 ... coif17 the filters are passed to MATLAB explicitly
    # rather than by name (presumably the MATLAB toolbox lacks them).
    explicit_filter_wavelets = frozenset(
        'coif%d' % order for order in range(6, 18))
    if wavelet in explicit_filter_wavelets:
        w = pywt.Wavelet(wavelet)
        mlab.set_variable('Lo_D', w.dec_lo)
        mlab.set_variable('Hi_D', w.dec_hi)
        mlab_code = "[ma, md] = dwt(data, Lo_D, Hi_D, 'mode', '%s');" % mmode
    else:
        mlab_code = "[ma, md] = dwt(data, wavelet, 'mode', '%s');" % mmode
    res = mlab.run_code(mlab_code)
    if not res['success']:
        raise RuntimeError("Matlab failed to execute the provided code. "
                           "Check that the wavelet toolbox is installed.")
    # need np.asarray because sometimes the output is a single float64
    ma = np.asarray(mlab.get_variable('ma'))
    md = np.asarray(mlab.get_variable('md'))
    return ma, md
|
||
|
|
||
|
|
||
|
def _load_matlab_result(data, wavelet, mmode, result_dict=None):
    """ Load the precomputed result stored under the ``ma``/``md`` keys.

    Parameters
    ----------
    data : sequence
        Input signal.  Only its length is used, to build the lookup keys.
    wavelet : str
        Wavelet name.
    mmode : str
        MATLAB extension-mode name.
    result_dict : mapping, optional
        Mapping holding the precomputed results.  Defaults to the
        module-level ``matlab_result_dict``.

    Returns
    -------
    ma, md
        The entries stored under ``'<mmode>_<wavelet>_<N>_ma'`` and
        ``'<mmode>_<wavelet>_<N>_md'``.

    Raises
    ------
    KeyError
        If either entry is missing for this wavelet, mode and size.
    """
    if result_dict is None:
        result_dict = matlab_result_dict
    N = len(data)
    prefix = '_'.join([mmode, wavelet, str(N)])
    ma_key = prefix + '_ma'
    md_key = prefix + '_md'
    if ma_key not in result_dict or md_key not in result_dict:
        raise KeyError(
            "Precomputed Matlab result not found for wavelet: "
            "{0}, mode: {1}, size: {2}".format(wavelet, mmode, N))
    return result_dict[ma_key], result_dict[md_key]
|
||
|
|
||
|
|
||
|
def _load_matlab_result_pywt_coeffs(data, wavelet, mmode, result_dict=None):
    """ Load the precomputed result stored under the ``*_pywtCoeffs`` keys.

    Given the tight 1e-10 tolerance these are checked with, these are
    presumably MATLAB results computed with PyWavelets' own filter
    coefficients.

    Parameters
    ----------
    data : sequence
        Input signal.  Only its length is used, to build the lookup keys.
    wavelet : str
        Wavelet name.
    mmode : str
        MATLAB extension-mode name.
    result_dict : mapping, optional
        Mapping holding the precomputed results.  Defaults to the
        module-level ``matlab_result_dict``.

    Returns
    -------
    ma, md
        The entries stored under ``'<mmode>_<wavelet>_<N>_ma_pywtCoeffs'``
        and ``'<mmode>_<wavelet>_<N>_md_pywtCoeffs'``.

    Raises
    ------
    KeyError
        If either entry is missing for this wavelet, mode and size.
    """
    if result_dict is None:
        result_dict = matlab_result_dict
    N = len(data)
    prefix = '_'.join([mmode, wavelet, str(N)])
    ma_key = prefix + '_ma_pywtCoeffs'
    md_key = prefix + '_md_pywtCoeffs'
    if ma_key not in result_dict or md_key not in result_dict:
        raise KeyError(
            "Precomputed Matlab result not found for wavelet: "
            "{0}, mode: {1}, size: {2}".format(wavelet, mmode, N))
    return result_dict[ma_key], result_dict[md_key]
|
||
|
|
||
|
|
||
|
def _check_accuracy(data, w, pmode, ma, md, wavelet, epsilon):
    """Assert that ``pywt.dwt`` matches a MATLAB reference within ``epsilon``.

    Parameters
    ----------
    data : ndarray
        Input signal passed to ``pywt.dwt``.
    w : pywt.Wavelet
        Wavelet object used for the PyWavelets transform.
    pmode : str
        PyWavelets signal extension mode.
    ma, md : array_like
        Reference outputs compared with the first and second outputs of
        ``pywt.dwt`` respectively.
    wavelet : str
        Wavelet name, used only in failure messages.
    epsilon : float
        Strict upper bound on the root-mean-square error of each output.

    Raises
    ------
    AssertionError
        If a reference's length differs from the PyWavelets output, or if
        either RMS error is not below ``epsilon``.
    """
    # PyWavelets result
    pa, pd = pywt.dwt(data, w, pmode)

    for label, ours, ref in (('A', pa, ma), ('D', pd, md)):
        # Flatten the reference (it may be 0-d, e.g. when MATLAB returns a
        # single float64) and require equal lengths.  Otherwise a size-1
        # reference would broadcast against the PyWavelets output and a
        # length mismatch would pass unnoticed.
        ours = np.ravel(ours)
        ref = np.ravel(ref)
        msg = ('[SIZE_%s MISMATCH] for Mode: %s, Wavelet: %s, '
               'Length: %d, pywt=%d, ref=%d'
               % (label, pmode, wavelet, len(data), ours.size, ref.size))
        assert_(ours.size == ref.size, msg=msg)

        # calculate error measure
        rms = np.sqrt(np.mean((ours - ref) ** 2))
        msg = ('[RMS_%s >= EPSILON] for Mode: %s, Wavelet: %s, '
               'Length: %d, rms=%.3g'
               % (label, pmode, wavelet, len(data), rms))
        assert_(rms < epsilon, msg=msg)
|