Vehicle-Anti-Theft-Face-Rec.../venv/Lib/site-packages/pywt/tests/test_cwt_wavelets.py

435 lines
12 KiB
Python
Raw Permalink Normal View History

#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
from itertools import product
from numpy.testing import (assert_allclose, assert_warns, assert_almost_equal,
assert_raises, assert_equal)
import pytest
import numpy as np
import pywt
def ref_gaus(LB, UB, N, num):
    """Reference samples of pywt's ``gausN`` wavelet, for ``num`` in 1..8.

    Each order is a fixed polynomial in ``X`` multiplied by the common
    envelope ``(2/pi)**(1/4) * exp(-X**2)``.

    Parameters
    ----------
    LB, UB : float
        Lower and upper bounds of the sampling grid.
    N : int
        Number of grid points.
    num : int
        Wavelet order, 1 through 8.

    Returns
    -------
    psi : ndarray
        Real-valued wavelet samples on the grid.
    X : ndarray
        The grid, ``np.linspace(LB, UB, N)``.

    Raises
    ------
    ValueError
        If ``num`` is not one of the supported orders 1..8.
    """
    X = np.linspace(LB, UB, N)
    F0 = (2./np.pi)**(1./4.)*np.exp(-(X**2))
    if num == 1:
        psi = -2.*X*F0
    elif num == 2:
        psi = -2/(3**(1/2))*(-1 + 2*X**2)*F0
    elif num == 3:
        psi = -4/(15**(1/2))*X*(3 - 2*X**2)*F0
    elif num == 4:
        psi = 4/(105**(1/2))*(3 - 12*X**2 + 4*X**4)*F0
    elif num == 5:
        psi = 8/(3*(105**(1/2)))*X*(-15 + 20*X**2 - 4*X**4)*F0
    elif num == 6:
        psi = -8/(3*(1155**(1/2)))*(-15 + 90*X**2 - 60*X**4 + 8*X**6)*F0
    elif num == 7:
        psi = -16/(3*(15015**(1/2)))*X*(105 - 210*X**2 + 84*X**4 - 8*X**6)*F0
    elif num == 8:
        psi = 16/(45*(1001**(1/2)))*(105 - 840*X**2 + 840*X**4 -
                                    224*X**6 + 16*X**8)*F0
    else:
        # Without this branch an unsupported order fell through every case
        # and surfaced as an UnboundLocalError on ``psi`` at the return.
        raise ValueError(
            "num must be an integer order in 1..8, got {!r}".format(num))
    return (psi, X)
def ref_cgau(LB, UB, N, num):
    """Reference samples of pywt's complex Gaussian ``cgauN`` wavelet.

    The order-``num`` polynomial is multiplied by the complex envelope
    ``exp(-X**2) * exp(-1j*X)`` (divided by a fixed constant). The result
    is then rescaled so that ``sum(|psi|**2) * dx == 1`` on the sampled
    grid.

    Parameters
    ----------
    LB, UB : float
        Lower and upper bounds of the sampling grid.
    N : int
        Number of grid points. At least 2 are needed, because the grid
        spacing ``X[1] - X[0]`` is used for the normalization.
    num : int
        Wavelet order, 1 through 8.

    Returns
    -------
    psi : ndarray of complex
        Normalized wavelet samples on the grid.
    X : ndarray
        The grid, ``np.linspace(LB, UB, N)``.

    Raises
    ------
    ValueError
        If ``num`` is not one of the supported orders 1..8.
    """
    X = np.linspace(LB, UB, N)
    F0 = np.exp(-X**2)
    F1 = np.exp(-1j*X)
    F2 = (F1*F0)/(np.exp(-1/2)*2**(1/2)*np.pi**(1/2))**(1/2)
    if num == 1:
        psi = F2*(-1j - 2*X)*2**(1/2)
    elif num == 2:
        psi = 1/3*F2*(-3 + 4j*X + 4*X**2)*6**(1/2)
    elif num == 3:
        psi = 1/15*F2*(7j + 18*X - 12j*X**2 - 8*X**3)*30**(1/2)
    elif num == 4:
        psi = 1/105*F2*(25 - 56j*X - 72*X**2 + 32j*X**3 + 16*X**4)*210**(1/2)
    elif num == 5:
        psi = 1/315*F2*(-81j - 250*X + 280j*X**2 + 240*X**3 -
                        80j*X**4 - 32*X**5)*210**(1/2)
    elif num == 6:
        psi = 1/3465*F2*(-331 + 972j*X + 1500*X**2 - 1120j*X**3 - 720*X**4 +
                         192j*X**5 + 64*X**6)*2310**(1/2)
    elif num == 7:
        psi = 1/45045*F2*(
            1303j + 4634*X - 6804j*X**2 - 7000*X**3 + 3920j*X**4 + 2016*X**5 -
            448j*X**6 - 128*X**7)*30030**(1/2)
    elif num == 8:
        psi = 1/45045*F2*(
            5937 - 20848j*X - 37072*X**2 + 36288j*X**3 + 28000*X**4 -
            12544j*X**5 - 5376*X**6 + 1024j*X**7 + 256*X**8)*2002**(1/2)
    else:
        # Without this branch an unsupported order fell through every case
        # and surfaced as an UnboundLocalError on ``psi`` below.
        raise ValueError(
            "num must be an integer order in 1..8, got {!r}".format(num))
    # Unit energy on the sampled grid: sum(|psi|**2) * dx == 1.
    psi = psi/np.real(np.sqrt(np.real(np.sum(psi*np.conj(psi)))*(X[1] - X[0])))
    return (psi, X)
def sinc2(x):
    """Normalized sinc ``sin(pi*x) / (pi*x)``, with the limit value 1 at 0.

    Parameters
    ----------
    x : array_like
        Sample points of any shape, including 0-d.

    Returns
    -------
    ndarray
        Array with the same shape as ``x``. Integer/boolean input is
        promoted to float64; floating and complex dtypes are kept.
    """
    x = np.asarray(x)
    if not np.issubdtype(x.dtype, np.inexact):
        # Promote integer/bool input so the result is not truncated to an
        # integer dtype.
        x = x.astype(np.float64)
    one = np.ones_like(x)
    # Element-wise zero mask, valid for any rank. Zeros are replaced by 1
    # before dividing so 0/0 is never evaluated; those positions then get
    # the limit value 1.
    is_zero = x == 0
    safe = np.where(is_zero, one, x)
    return np.where(is_zero, one, np.sin(np.pi*safe)/(np.pi*safe))
def ref_shan(LB, UB, N, Fb, Fc):
    """Reference Shannon wavelet on ``N`` points evenly spaced in [LB, UB].

    ``psi(x) = sqrt(Fb) * sinc(Fb*x) * exp(2j*pi*Fc*x)``, where ``sinc`` is
    the normalized ``sinc2`` helper. Returns ``(psi, x)``.
    """
    grid = np.linspace(LB, UB, N)
    gain = np.sqrt(Fb)
    envelope = sinc2(Fb*grid)
    carrier = np.exp(2j*np.pi*Fc*grid)
    return gain*(envelope*carrier), grid
def ref_fbsp(LB, UB, N, m, Fb, Fc):
    """Reference frequency B-spline wavelet on ``N`` points in [LB, UB].

    ``psi(x) = sqrt(Fb) * sinc(Fb*x/m)**m * exp(2j*pi*Fc*x)``, where ``sinc``
    is the normalized ``sinc2`` helper and ``m`` is the spline order.
    Returns ``(psi, x)``.
    """
    grid = np.linspace(LB, UB, N)
    gain = np.sqrt(Fb)
    spline_envelope = sinc2(Fb*grid/m)**m
    carrier = np.exp(2j*np.pi*Fc*grid)
    return gain*(spline_envelope*carrier), grid
def ref_cmor(LB, UB, N, Fb, Fc):
    """Reference complex Morlet wavelet on ``N`` points in [LB, UB].

    ``psi(x) = (pi*Fb)**-0.5 * exp(2j*pi*Fc*x) * exp(-x**2 / Fb)``.
    Returns ``(psi, x)``.
    """
    grid = np.linspace(LB, UB, N)
    amplitude = (np.pi*Fb)**(-0.5)
    carrier = np.exp(2j*np.pi*Fc*grid)
    gaussian = np.exp(-(grid**2)/Fb)
    return amplitude*carrier*gaussian, grid
def ref_morl(LB, UB, N):
    """Reference Morlet wavelet ``exp(-x**2 / 2) * cos(5*x)``.

    Sampled on ``N`` points evenly spaced in [LB, UB]; returns ``(psi, x)``.
    """
    grid = np.linspace(LB, UB, N)
    envelope = np.exp(-(grid**2)/2)
    oscillation = np.cos(5*grid)
    return envelope*oscillation, grid
def ref_mexh(LB, UB, N):
    """Reference Mexican-hat wavelet on ``N`` points in [LB, UB].

    ``psi(x) = 2 / (sqrt(3) * pi**0.25) * exp(-x**2 / 2) * (1 - x**2)``.
    Returns ``(psi, x)``.
    """
    grid = np.linspace(LB, UB, N)
    scale = 2/(np.sqrt(3)*np.pi**0.25)
    bump = np.exp(-(grid**2)/2)
    return scale*bump*(1 - (grid**2)), grid
def test_gaus():
    """Compare pywt's ``gaus1`` ... ``gaus8`` wavelets with ``ref_gaus``.

    The wavelet's default bounds are not overridden, so comparing the
    returned grid also checks that those bounds match the reference
    grid on [-5, 5].
    """
    lower, upper, n_points = -5, 5, 1000
    for order in range(1, 9):
        ref_psi, ref_x = ref_gaus(lower, upper, n_points, order)
        wavelet = pywt.ContinuousWavelet("gaus" + str(order))
        psi, x = wavelet.wavefun(length=n_points)
        assert_allclose(np.real(psi), np.real(ref_psi))
        assert_allclose(np.imag(psi), np.imag(ref_psi))
        assert_allclose(x, ref_x)
def test_cgau():
    """Compare pywt's ``cgau1`` ... ``cgau8`` wavelets with ``ref_cgau``.

    The default bounds are used, so the grid comparison also checks them
    against the reference grid on [-5, 5].
    """
    lower, upper, n_points = -5, 5, 1000
    for order in range(1, 9):
        ref_psi, ref_x = ref_cgau(lower, upper, n_points, order)
        wavelet = pywt.ContinuousWavelet("cgau" + str(order))
        psi, x = wavelet.wavefun(length=n_points)
        # Real and imaginary parts are checked separately.
        assert_allclose(np.real(psi), np.real(ref_psi))
        assert_allclose(np.imag(psi), np.imag(ref_psi))
        assert_allclose(x, ref_x)
def test_shan():
    """Check pywt's Shannon wavelet against ``ref_shan`` on [-20, 20].

    Two (bandwidth, center) frequency pairs are exercised. For each pair,
    the frequencies parsed from the wavelet name are verified before the
    sampled wavelet and its grid are compared.
    """
    lower, upper, n_points = -20, 20, 1000
    for bandwidth, center in [(1, 1.5), (1.5, 1)]:
        ref_psi, ref_x = ref_shan(lower, upper, n_points, bandwidth, center)
        wavelet = pywt.ContinuousWavelet(
            "shan{}-{}".format(bandwidth, center))
        assert_almost_equal(wavelet.center_frequency, center)
        assert_almost_equal(wavelet.bandwidth_frequency, bandwidth)
        wavelet.upper_bound = upper
        wavelet.lower_bound = lower
        psi, x = wavelet.wavefun(length=n_points)
        assert_allclose(np.real(psi), np.real(ref_psi), atol=1e-15)
        assert_allclose(np.imag(psi), np.imag(ref_psi), atol=1e-15)
        assert_allclose(x, ref_x, atol=1e-15)
def test_cmor():
    """Check pywt's complex Morlet wavelet against ``ref_cmor`` on [-20, 20].

    Two (bandwidth, center) frequency pairs are exercised. For each pair,
    the frequencies parsed from the wavelet name are verified before the
    sampled wavelet and its grid are compared.
    """
    lower, upper, n_points = -20, 20, 1000
    for bandwidth, center in [(1, 1.5), (1.5, 1)]:
        ref_psi, ref_x = ref_cmor(lower, upper, n_points, bandwidth, center)
        wavelet = pywt.ContinuousWavelet(
            "cmor{}-{}".format(bandwidth, center))
        assert_almost_equal(wavelet.center_frequency, center)
        assert_almost_equal(wavelet.bandwidth_frequency, bandwidth)
        wavelet.upper_bound = upper
        wavelet.lower_bound = lower
        psi, x = wavelet.wavefun(length=n_points)
        assert_allclose(np.real(psi), np.real(ref_psi), atol=1e-15)
        assert_allclose(np.imag(psi), np.imag(ref_psi), atol=1e-15)
        assert_allclose(x, ref_x, atol=1e-15)
def test_fbsp():
    """Check pywt's frequency B-spline wavelet against ``ref_fbsp``.

    Each case sets the spline order, bandwidth and center frequency, and
    gives the absolute tolerance used for the wavelet samples. The grid
    is always compared to 1e-15.
    """
    lower, upper, n_points = -20, 20, 1000
    cases = [
        # (order, bandwidth, center, atol for psi)
        (2, 1, 1.5, 1e-15),
        (2, 1.5, 1, 1e-15),
        # TODO: investigate why atol = 1e-5 is necessary for this case
        (3, 1.5, 1.2, 1e-5),
    ]
    for order, bandwidth, center, psi_atol in cases:
        ref_psi, ref_x = ref_fbsp(lower, upper, n_points,
                                  order, bandwidth, center)
        wavelet = pywt.ContinuousWavelet(
            "fbsp{}-{}-{}".format(order, bandwidth, center))
        assert_almost_equal(wavelet.center_frequency, center)
        assert_almost_equal(wavelet.bandwidth_frequency, bandwidth)
        wavelet.fbsp_order = order
        wavelet.upper_bound = upper
        wavelet.lower_bound = lower
        psi, x = wavelet.wavefun(length=n_points)
        assert_allclose(np.real(psi), np.real(ref_psi), atol=psi_atol)
        assert_allclose(np.imag(psi), np.imag(ref_psi), atol=psi_atol)
        assert_allclose(x, ref_x, atol=1e-15)
def test_morl():
    """Check pywt's Morlet wavelet against ``ref_morl`` on [-5, 5]."""
    lower, upper, n_points = -5, 5, 1000
    ref_psi, ref_x = ref_morl(lower, upper, n_points)
    wavelet = pywt.ContinuousWavelet("morl")
    wavelet.upper_bound = upper
    wavelet.lower_bound = lower
    psi, x = wavelet.wavefun(length=n_points)
    assert_allclose(np.real(psi), np.real(ref_psi))
    assert_allclose(np.imag(psi), np.imag(ref_psi))
    assert_allclose(x, ref_x)
def test_mexh():
    """Check pywt's Mexican-hat wavelet against ``ref_mexh`` on [-5, 5].

    Both an even and an odd number of samples are exercised.
    """
    lower, upper = -5, 5
    for n_points in (1000, 1001):
        ref_psi, ref_x = ref_mexh(lower, upper, n_points)
        wavelet = pywt.ContinuousWavelet("mexh")
        wavelet.upper_bound = upper
        wavelet.lower_bound = lower
        psi, x = wavelet.wavefun(length=n_points)
        assert_allclose(np.real(psi), np.real(ref_psi))
        assert_allclose(np.imag(psi), np.imag(ref_psi))
        assert_allclose(x, ref_x)
def test_cwt_parameters_in_names():
    """Wavelet parameters for fbsp/cmor/shan must be encoded in the name.

    Both constructors must warn on a bare family name, accept well-formed
    parameterized names and reject malformed ones with ValueError.
    """
    constructors = [pywt.ContinuousWavelet, pywt.DiscreteContinuousWavelet]
    for make in constructors:
        # A bare family name (no parameters) only triggers a FutureWarning.
        for family in ['fbsp', 'cmor', 'shan']:
            assert_warns(FutureWarning, make, family)

        # cmor/shan take exactly two numeric parameters: "B-C".
        for family in ['cmor', 'shan']:
            for good_suffix in ['1.5-1.0', '1-4']:
                make(family + good_suffix)
            for bad_suffix in ['1.0', 'B-C', '1.0-1.0-1.0']:
                assert_raises(ValueError, make, family + bad_suffix)

        # fbsp takes "M-B-C" with an integer order M.
        for good_name in ['fbsp1-1.5-1.0', 'fbsp1.0-1.5-1', 'fbsp2-5-1']:
            make(good_name)
        bad_names = [
            # non-integer or non-numeric order
            'fbsp1.5-1-1',
            'fbspM-B-C',
            # too few or too many parameters
            'fbsp1.0',
            'fbsp1.0-0.4',
            'fbsp1-1-1-1',
        ]
        for bad_name in bad_names:
            assert_raises(ValueError, make, bad_name)
@pytest.mark.parametrize(
    'dtype, tol, method',
    [
        (np.float32, 1e-5, 'conv'),
        (np.float32, 1e-5, 'fft'),
        (np.float64, 1e-13, 'conv'),
        (np.float64, 1e-13, 'fft'),
    ])
def test_cwt_complex(dtype, tol, method):
    """The CWT of ``x + 1j*x`` must equal ``cwt(x) + 1j*cwt(x)``.

    The test also checks that the real transform keeps the input precision
    and that the complex transform keeps the complex input's dtype.
    """
    t, raw = pywt.data.nino()
    signal = np.asarray(raw, dtype=dtype)
    sampling_period = t[1] - t[0]
    wavelet = 'cmor1.5-1.0'
    scales = np.arange(1, 32)

    # Real-valued transform, used as the reference.
    coefs, _ = pywt.cwt(signal, scales, wavelet, sampling_period,
                        method=method)
    # Output precision follows the input precision.
    assert_equal(coefs.real.dtype, signal.dtype)

    # By linearity, the transform of a complex input is the sum of the
    # transforms of its real and imaginary components.
    signal_complex = signal + 1j*signal
    coefs_complex, _ = pywt.cwt(signal_complex, scales, wavelet,
                                sampling_period, method=method)
    assert_allclose(coefs + 1j*coefs, coefs_complex, atol=tol, rtol=tol)
    # The complex input's dtype is preserved in the output.
    assert_equal(coefs_complex.dtype, signal_complex.dtype)
@pytest.mark.parametrize('axis, method', product([0, 1], ['conv', 'fft']))
def test_cwt_batch(axis, method):
    """A batched CWT must equal the 1-D CWT repeated along the batch axis.

    ``axis`` is the transform axis. The other axis of the 2-D input holds
    ``n_batch`` identical copies of the signal.
    """
    t, raw = pywt.data.nino()
    n_batch = 8
    batch_axis = 1 - axis
    single = np.asarray(raw, dtype=np.float64)
    batch = np.stack([single] * n_batch, axis=batch_axis)
    sampling_period = t[1] - t[0]
    wavelet = 'cmor1.5-1.0'
    scales = np.arange(1, 32)

    # Reference: transform of the single (non-batched) signal.
    coefs_single, _ = pywt.cwt(single, scales, wavelet, sampling_period,
                               method=method, axis=axis)
    shape_before = batch.shape
    coefs, _ = pywt.cwt(batch, scales, wavelet, sampling_period,
                        method=method, axis=axis)

    # The input's shape must be left unchanged.
    assert_equal(shape_before, batch.shape)
    # Output precision follows the input precision.
    assert_equal(coefs.real.dtype, batch.dtype)
    # Scales come first, followed by the input's axes in their order.
    assert_equal(coefs.shape[0], len(scales))
    assert_equal(coefs.shape[1 + batch_axis], n_batch)
    assert_equal(coefs.shape[1 + axis], batch.shape[axis])
    # Every batch entry equals the single-signal result.
    expected = np.stack([coefs_single] * n_batch, axis=batch_axis + 1)
    assert_equal(coefs, expected)
def test_cwt_small_scales():
    """A scale of 0.1 must work for mexh; a scale of 0.01 must be rejected."""
    signal = np.zeros(32)
    # A scale of 0.1 was chosen specifically to give a filter of length 2
    # for mexh; this corner case must not raise.
    coefs, _ = pywt.cwt(signal, scales=0.1, wavelet='mexh')
    # The input is all zeros, so every coefficient must be zero as well.
    assert_allclose(coefs, np.zeros_like(coefs))
    # An extremely small scale factor is expected to raise ValueError.
    assert_raises(ValueError, pywt.cwt, signal, scales=0.01, wavelet='mexh')
def test_cwt_method_fft():
    """``method='fft'`` must match ``method='conv'`` to an atol of 1e-13."""
    rng = np.random.RandomState(1)
    signal = rng.randn(50)
    signal[15] = 1.
    scales = np.arange(1, 64)
    wavelet = 'cmor1.5-1.0'
    # Reference: the convolution-based transform.
    coefs_conv, _ = pywt.cwt(signal, scales, wavelet, method='conv')
    # Candidate: the FFT-based transform.
    coefs_fft, _ = pywt.cwt(signal, scales, wavelet, method='fft')
    assert_allclose(coefs_conv, coefs_fft, rtol=0, atol=1e-13)