Fixed database typo and removed unnecessary class identifier.
This commit is contained in:
parent
00ad49a143
commit
45fb349a7d
5098 changed files with 952558 additions and 85 deletions
169
venv/Lib/site-packages/pywt/tests/test_thresholding.py
Normal file
169
venv/Lib/site-packages/pywt/tests/test_thresholding.py
Normal file
|
@ -0,0 +1,169 @@
|
|||
from __future__ import division, print_function, absolute_import
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose, assert_raises, assert_, assert_equal
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
float_dtypes = [np.float32, np.float64, np.complex64, np.complex128]
|
||||
real_dtypes = [np.float32, np.float64]
|
||||
|
||||
|
||||
def _sign(x):
|
||||
# Matlab-like sign function (numpy uses a different convention).
|
||||
return x / np.abs(x)
|
||||
|
||||
|
||||
def _soft(x, thresh):
|
||||
"""soft thresholding supporting complex values.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This version is not robust to zeros in x.
|
||||
"""
|
||||
return _sign(x) * np.maximum(np.abs(x) - thresh, 0)
|
||||
|
||||
|
||||
def test_threshold():
|
||||
data = np.linspace(1, 4, 7)
|
||||
|
||||
# soft
|
||||
soft_result = [0., 0., 0., 0.5, 1., 1.5, 2.]
|
||||
assert_allclose(pywt.threshold(data, 2, 'soft'),
|
||||
np.array(soft_result), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold(-data, 2, 'soft'),
|
||||
-np.array(soft_result), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'soft'),
|
||||
[[0, 1]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'soft'),
|
||||
[[0, 0]] * 2, rtol=1e-12)
|
||||
|
||||
# soft thresholding complex values
|
||||
assert_allclose(pywt.threshold([[1j, 2j]] * 2, 1, 'soft'),
|
||||
[[0j, 1j]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1+1j, 2+2j]] * 2, 6, 'soft'),
|
||||
[[0, 0]] * 2, rtol=1e-12)
|
||||
complex_data = [[1+2j, 2+2j]]*2
|
||||
for thresh in [1, 2]:
|
||||
assert_allclose(pywt.threshold(complex_data, thresh, 'soft'),
|
||||
_soft(complex_data, thresh), rtol=1e-12)
|
||||
|
||||
# test soft thresholding with non-default substitute argument
|
||||
s = 5
|
||||
assert_allclose(pywt.threshold([[1j, 2]] * 2, 1.5, 'soft', substitute=s),
|
||||
[[s, 0.5]] * 2, rtol=1e-12)
|
||||
|
||||
# soft: no divide by zero warnings when input contains zeros
|
||||
assert_allclose(pywt.threshold(np.zeros(16), 2, 'soft'),
|
||||
np.zeros(16), rtol=1e-12)
|
||||
|
||||
# hard
|
||||
hard_result = [0., 0., 2., 2.5, 3., 3.5, 4.]
|
||||
assert_allclose(pywt.threshold(data, 2, 'hard'),
|
||||
np.array(hard_result), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold(-data, 2, 'hard'),
|
||||
-np.array(hard_result), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'hard'),
|
||||
[[1, 2]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'hard'),
|
||||
[[0, 2]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'hard', substitute=s),
|
||||
[[s, 2]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1+1j, 2+2j]] * 2, 2, 'hard'),
|
||||
[[0, 2+2j]] * 2, rtol=1e-12)
|
||||
|
||||
# greater
|
||||
greater_result = [0., 0., 2., 2.5, 3., 3.5, 4.]
|
||||
assert_allclose(pywt.threshold(data, 2, 'greater'),
|
||||
np.array(greater_result), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'greater'),
|
||||
[[1, 2]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'greater'),
|
||||
[[0, 2]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'greater', substitute=s),
|
||||
[[s, 2]] * 2, rtol=1e-12)
|
||||
# greater doesn't allow complex-valued inputs
|
||||
assert_raises(ValueError, pywt.threshold, [1j, 2j], 2, 'greater')
|
||||
|
||||
# less
|
||||
assert_allclose(pywt.threshold(data, 2, 'less'),
|
||||
np.array([1., 1.5, 2., 0., 0., 0., 0.]), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'less'),
|
||||
[[1, 0]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'less', substitute=s),
|
||||
[[1, s]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'less'),
|
||||
[[1, 2]] * 2, rtol=1e-12)
|
||||
|
||||
# less doesn't allow complex-valued inputs
|
||||
assert_raises(ValueError, pywt.threshold, [1j, 2j], 2, 'less')
|
||||
|
||||
# invalid
|
||||
assert_raises(ValueError, pywt.threshold, data, 2, 'foo')
|
||||
|
||||
|
||||
def test_nonnegative_garotte():
|
||||
thresh = 0.3
|
||||
data_real = np.linspace(-1, 1, 100)
|
||||
for dtype in float_dtypes:
|
||||
if dtype in real_dtypes:
|
||||
data = np.asarray(data_real, dtype=dtype)
|
||||
else:
|
||||
data = np.asarray(data_real + 0.1j, dtype=dtype)
|
||||
d_hard = pywt.threshold(data, thresh, 'hard')
|
||||
d_soft = pywt.threshold(data, thresh, 'soft')
|
||||
d_garotte = pywt.threshold(data, thresh, 'garotte')
|
||||
|
||||
# check dtypes
|
||||
assert_equal(d_hard.dtype, data.dtype)
|
||||
assert_equal(d_soft.dtype, data.dtype)
|
||||
assert_equal(d_garotte.dtype, data.dtype)
|
||||
|
||||
# values < threshold are zero
|
||||
lt = np.where(np.abs(data) < thresh)
|
||||
assert_(np.all(d_garotte[lt] == 0))
|
||||
|
||||
# values > than the threshold are intermediate between soft and hard
|
||||
gt = np.where(np.abs(data) > thresh)
|
||||
gt_abs_garotte = np.abs(d_garotte[gt])
|
||||
assert_(np.all(gt_abs_garotte < np.abs(d_hard[gt])))
|
||||
assert_(np.all(gt_abs_garotte > np.abs(d_soft[gt])))
|
||||
|
||||
|
||||
def test_threshold_firm():
|
||||
thresh = 0.2
|
||||
thresh2 = 3 * thresh
|
||||
data_real = np.linspace(-1, 1, 100)
|
||||
for dtype in float_dtypes:
|
||||
if dtype in real_dtypes:
|
||||
data = np.asarray(data_real, dtype=dtype)
|
||||
else:
|
||||
data = np.asarray(data_real + 0.1j, dtype=dtype)
|
||||
if data.real.dtype == np.float32:
|
||||
rtol = atol = 1e-6
|
||||
else:
|
||||
rtol = atol = 1e-14
|
||||
d_hard = pywt.threshold(data, thresh, 'hard')
|
||||
d_soft = pywt.threshold(data, thresh, 'soft')
|
||||
d_firm = pywt.threshold_firm(data, thresh, thresh2)
|
||||
|
||||
# check dtypes
|
||||
assert_equal(d_hard.dtype, data.dtype)
|
||||
assert_equal(d_soft.dtype, data.dtype)
|
||||
assert_equal(d_firm.dtype, data.dtype)
|
||||
|
||||
# values < threshold are zero
|
||||
lt = np.where(np.abs(data) < thresh)
|
||||
assert_(np.all(d_firm[lt] == 0))
|
||||
|
||||
# values > than the threshold are equal to hard-thresholding
|
||||
gt = np.where(np.abs(data) >= thresh2)
|
||||
assert_allclose(np.abs(d_hard[gt]), np.abs(d_firm[gt]),
|
||||
rtol=rtol, atol=atol)
|
||||
|
||||
# other values are intermediate between soft and hard thresholding
|
||||
mt = np.where(np.logical_and(np.abs(data) > thresh,
|
||||
np.abs(data) < thresh2))
|
||||
mt_abs_firm = np.abs(d_firm[mt])
|
||||
assert_(np.all(mt_abs_firm < np.abs(d_hard[mt])))
|
||||
assert_(np.all(mt_abs_firm > np.abs(d_soft[mt])))
|
Loading…
Add table
Add a link
Reference in a new issue