Fixed database typo and removed unnecessary class identifier.
This commit is contained in:
parent
00ad49a143
commit
45fb349a7d
5098 changed files with 952558 additions and 85 deletions
40
venv/Lib/site-packages/pywt/__init__.py
Normal file
40
venv/Lib/site-packages/pywt/__init__.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
# flake8: noqa
|
||||
|
||||
# Copyright (c) 2006-2012 Filip Wasilewski <http://en.ig.ma/>
|
||||
# Copyright (c) 2012-2016 The PyWavelets Developers
|
||||
# <https://github.com/PyWavelets/pywt>
|
||||
# See COPYING for license details.
|
||||
|
||||
"""
|
||||
Discrete forward and inverse wavelet transform, stationary wavelet transform,
|
||||
wavelet packets signal decomposition and reconstruction module.
|
||||
"""
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
from ._extensions._pywt import *
|
||||
from ._functions import *
|
||||
from ._multilevel import *
|
||||
from ._multidim import *
|
||||
from ._thresholding import *
|
||||
from ._wavelet_packets import *
|
||||
from ._dwt import *
|
||||
from ._swt import *
|
||||
from ._cwt import *
|
||||
|
||||
from . import data
|
||||
|
||||
__all__ = [s for s in dir() if not s.startswith('_')]
|
||||
try:
|
||||
# In Python 2.x the name of the tempvar leaks out of the list
|
||||
# comprehension. Delete it to not make it show up in the main namespace.
|
||||
del s
|
||||
except NameError:
|
||||
pass
|
||||
|
||||
from pywt.version import version as __version__
|
||||
|
||||
from ._pytesttester import PytestTester
|
||||
test = PytestTester(__name__)
|
||||
del PytestTester
|
BIN
venv/Lib/site-packages/pywt/__pycache__/__init__.cpython-36.pyc
Normal file
BIN
venv/Lib/site-packages/pywt/__pycache__/__init__.cpython-36.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
venv/Lib/site-packages/pywt/__pycache__/_cwt.cpython-36.pyc
Normal file
BIN
venv/Lib/site-packages/pywt/__pycache__/_cwt.cpython-36.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
venv/Lib/site-packages/pywt/__pycache__/_dwt.cpython-36.pyc
Normal file
BIN
venv/Lib/site-packages/pywt/__pycache__/_dwt.cpython-36.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
venv/Lib/site-packages/pywt/__pycache__/_multidim.cpython-36.pyc
Normal file
BIN
venv/Lib/site-packages/pywt/__pycache__/_multidim.cpython-36.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
venv/Lib/site-packages/pywt/__pycache__/_pytest.cpython-36.pyc
Normal file
BIN
venv/Lib/site-packages/pywt/__pycache__/_pytest.cpython-36.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
venv/Lib/site-packages/pywt/__pycache__/_swt.cpython-36.pyc
Normal file
BIN
venv/Lib/site-packages/pywt/__pycache__/_swt.cpython-36.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
venv/Lib/site-packages/pywt/__pycache__/_utils.cpython-36.pyc
Normal file
BIN
venv/Lib/site-packages/pywt/__pycache__/_utils.cpython-36.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
venv/Lib/site-packages/pywt/__pycache__/conftest.cpython-36.pyc
Normal file
BIN
venv/Lib/site-packages/pywt/__pycache__/conftest.cpython-36.pyc
Normal file
Binary file not shown.
BIN
venv/Lib/site-packages/pywt/__pycache__/version.cpython-36.pyc
Normal file
BIN
venv/Lib/site-packages/pywt/__pycache__/version.cpython-36.pyc
Normal file
Binary file not shown.
3
venv/Lib/site-packages/pywt/_c99_config.py
Normal file
3
venv/Lib/site-packages/pywt/_c99_config.py
Normal file
|
@ -0,0 +1,3 @@
|
|||
# Autogenerated file containing compile-time definitions
|
||||
|
||||
_have_c99_complex = 0
|
203
venv/Lib/site-packages/pywt/_cwt.py
Normal file
203
venv/Lib/site-packages/pywt/_cwt.py
Normal file
|
@ -0,0 +1,203 @@
|
|||
from math import floor, ceil
|
||||
|
||||
from ._extensions._pywt import (DiscreteContinuousWavelet, ContinuousWavelet,
|
||||
Wavelet, _check_dtype)
|
||||
from ._functions import integrate_wavelet, scale2frequency
|
||||
|
||||
|
||||
__all__ = ["cwt"]
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
# Prefer scipy.fft (new in SciPy 1.4)
|
||||
import scipy.fft
|
||||
fftmodule = scipy.fft
|
||||
next_fast_len = fftmodule.next_fast_len
|
||||
except ImportError:
|
||||
try:
|
||||
import scipy.fftpack
|
||||
fftmodule = scipy.fftpack
|
||||
next_fast_len = fftmodule.next_fast_len
|
||||
except ImportError:
|
||||
fftmodule = np.fft
|
||||
|
||||
# provide a fallback so scipy is an optional requirement
|
||||
def next_fast_len(n):
|
||||
"""Round up size to the nearest power of two.
|
||||
|
||||
Given a number of samples `n`, returns the next power of two
|
||||
following this number to take advantage of FFT speedup.
|
||||
This fallback is less efficient than `scipy.fftpack.next_fast_len`
|
||||
"""
|
||||
return 2**ceil(np.log2(n))
|
||||
|
||||
|
||||
def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
|
||||
"""
|
||||
cwt(data, scales, wavelet)
|
||||
|
||||
One dimensional Continuous Wavelet Transform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : array_like
|
||||
Input signal
|
||||
scales : array_like
|
||||
The wavelet scales to use. One can use
|
||||
``f = scale2frequency(wavelet, scale)/sampling_period`` to determine
|
||||
what physical frequency, ``f``. Here, ``f`` is in hertz when the
|
||||
``sampling_period`` is given in seconds.
|
||||
wavelet : Wavelet object or name
|
||||
Wavelet to use
|
||||
sampling_period : float
|
||||
Sampling period for the frequencies output (optional).
|
||||
The values computed for ``coefs`` are independent of the choice of
|
||||
``sampling_period`` (i.e. ``scales`` is not scaled by the sampling
|
||||
period).
|
||||
method : {'conv', 'fft'}, optional
|
||||
The method used to compute the CWT. Can be any of:
|
||||
- ``conv`` uses ``numpy.convolve``.
|
||||
- ``fft`` uses frequency domain convolution.
|
||||
- ``auto`` uses automatic selection based on an estimate of the
|
||||
computational complexity at each scale.
|
||||
|
||||
The ``conv`` method complexity is ``O(len(scale) * len(data))``.
|
||||
The ``fft`` method is ``O(N * log2(N))`` with
|
||||
``N = len(scale) + len(data) - 1``. It is well suited for large size
|
||||
signals but slightly slower than ``conv`` on small ones.
|
||||
axis: int, optional
|
||||
Axis over which to compute the CWT. If not given, the last axis is
|
||||
used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
coefs : array_like
|
||||
Continuous wavelet transform of the input signal for the given scales
|
||||
and wavelet. The first axis of ``coefs`` corresponds to the scales.
|
||||
The remaining axes match the shape of ``data``.
|
||||
frequencies : array_like
|
||||
If the unit of sampling period are seconds and given, than frequencies
|
||||
are in hertz. Otherwise, a sampling period of 1 is assumed.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Size of coefficients arrays depends on the length of the input array and
|
||||
the length of given scales.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import pywt
|
||||
>>> import numpy as np
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> x = np.arange(512)
|
||||
>>> y = np.sin(2*np.pi*x/32)
|
||||
>>> coef, freqs=pywt.cwt(y,np.arange(1,129),'gaus1')
|
||||
>>> plt.matshow(coef) # doctest: +SKIP
|
||||
>>> plt.show() # doctest: +SKIP
|
||||
----------
|
||||
>>> import pywt
|
||||
>>> import numpy as np
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> t = np.linspace(-1, 1, 200, endpoint=False)
|
||||
>>> sig = np.cos(2 * np.pi * 7 * t) + np.real(np.exp(-7*(t-0.4)**2)*np.exp(1j*2*np.pi*2*(t-0.4)))
|
||||
>>> widths = np.arange(1, 31)
|
||||
>>> cwtmatr, freqs = pywt.cwt(sig, widths, 'mexh')
|
||||
>>> plt.imshow(cwtmatr, extent=[-1, 1, 1, 31], cmap='PRGn', aspect='auto',
|
||||
... vmax=abs(cwtmatr).max(), vmin=-abs(cwtmatr).max()) # doctest: +SKIP
|
||||
>>> plt.show() # doctest: +SKIP
|
||||
"""
|
||||
|
||||
# accept array_like input; make a copy to ensure a contiguous array
|
||||
dt = _check_dtype(data)
|
||||
data = np.asarray(data, dtype=dt)
|
||||
dt_cplx = np.result_type(dt, np.complex64)
|
||||
if not isinstance(wavelet, (ContinuousWavelet, Wavelet)):
|
||||
wavelet = DiscreteContinuousWavelet(wavelet)
|
||||
if np.isscalar(scales):
|
||||
scales = np.array([scales])
|
||||
if not np.isscalar(axis):
|
||||
raise ValueError("axis must be a scalar.")
|
||||
|
||||
dt_out = dt_cplx if wavelet.complex_cwt else dt
|
||||
out = np.empty((np.size(scales),) + data.shape, dtype=dt_out)
|
||||
precision = 10
|
||||
int_psi, x = integrate_wavelet(wavelet, precision=precision)
|
||||
int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi
|
||||
|
||||
# convert int_psi, x to the same precision as the data
|
||||
dt_psi = dt_cplx if int_psi.dtype.kind == 'c' else dt
|
||||
int_psi = np.asarray(int_psi, dtype=dt_psi)
|
||||
x = np.asarray(x, dtype=data.real.dtype)
|
||||
|
||||
if method == 'fft':
|
||||
size_scale0 = -1
|
||||
fft_data = None
|
||||
elif not method == 'conv':
|
||||
raise ValueError("method must be 'conv' or 'fft'")
|
||||
|
||||
if data.ndim > 1:
|
||||
# move axis to be transformed last (so it is contiguous)
|
||||
data = data.swapaxes(-1, axis)
|
||||
|
||||
# reshape to (n_batch, data.shape[-1])
|
||||
data_shape_pre = data.shape
|
||||
data = data.reshape((-1, data.shape[-1]))
|
||||
|
||||
for i, scale in enumerate(scales):
|
||||
step = x[1] - x[0]
|
||||
j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step)
|
||||
j = j.astype(int) # floor
|
||||
if j[-1] >= int_psi.size:
|
||||
j = np.extract(j < int_psi.size, j)
|
||||
int_psi_scale = int_psi[j][::-1]
|
||||
|
||||
if method == 'conv':
|
||||
if data.ndim == 1:
|
||||
conv = np.convolve(data, int_psi_scale)
|
||||
else:
|
||||
# batch convolution via loop
|
||||
conv_shape = list(data.shape)
|
||||
conv_shape[-1] += int_psi_scale.size - 1
|
||||
conv_shape = tuple(conv_shape)
|
||||
conv = np.empty(conv_shape, dtype=dt_out)
|
||||
for n in range(data.shape[0]):
|
||||
conv[n, :] = np.convolve(data[n], int_psi_scale)
|
||||
else:
|
||||
# The padding is selected for:
|
||||
# - optimal FFT complexity
|
||||
# - to be larger than the two signals length to avoid circular
|
||||
# convolution
|
||||
size_scale = next_fast_len(
|
||||
data.shape[-1] + int_psi_scale.size - 1
|
||||
)
|
||||
if size_scale != size_scale0:
|
||||
# Must recompute fft_data when the padding size changes.
|
||||
fft_data = fftmodule.fft(data, size_scale, axis=-1)
|
||||
size_scale0 = size_scale
|
||||
fft_wav = fftmodule.fft(int_psi_scale, size_scale, axis=-1)
|
||||
conv = fftmodule.ifft(fft_wav * fft_data, axis=-1)
|
||||
conv = conv[..., :data.shape[-1] + int_psi_scale.size - 1]
|
||||
|
||||
coef = - np.sqrt(scale) * np.diff(conv, axis=-1)
|
||||
if out.dtype.kind != 'c':
|
||||
coef = coef.real
|
||||
# transform axis is always -1 due to the data reshape above
|
||||
d = (coef.shape[-1] - data.shape[-1]) / 2.
|
||||
if d > 0:
|
||||
coef = coef[..., floor(d):-ceil(d)]
|
||||
elif d < 0:
|
||||
raise ValueError(
|
||||
"Selected scale of {} too small.".format(scale))
|
||||
if data.ndim > 1:
|
||||
# restore original data shape and axis position
|
||||
coef = coef.reshape(data_shape_pre)
|
||||
coef = coef.swapaxes(axis, -1)
|
||||
out[i, ...] = coef
|
||||
|
||||
frequencies = scale2frequency(wavelet, scales, precision)
|
||||
if np.isscalar(frequencies):
|
||||
frequencies = np.array([frequencies])
|
||||
frequencies /= sampling_period
|
||||
return out, frequencies
|
187
venv/Lib/site-packages/pywt/_doc_utils.py
Normal file
187
venv/Lib/site-packages/pywt/_doc_utils.py
Normal file
|
@ -0,0 +1,187 @@
|
|||
"""Utilities used to generate various figures in the documentation."""
|
||||
from itertools import product
|
||||
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from ._dwt import pad
|
||||
|
||||
__all__ = ['wavedec_keys', 'wavedec2_keys', 'draw_2d_wp_basis',
|
||||
'draw_2d_fswavedecn_basis', 'boundary_mode_subplot']
|
||||
|
||||
|
||||
def wavedec_keys(level):
|
||||
"""Subband keys corresponding to a wavedec decomposition."""
|
||||
approx = ''
|
||||
coeffs = {}
|
||||
for lev in range(level):
|
||||
for k in ['a', 'd']:
|
||||
coeffs[approx + k] = None
|
||||
approx = 'a' * (lev + 1)
|
||||
if lev < level - 1:
|
||||
coeffs.pop(approx)
|
||||
return list(coeffs.keys())
|
||||
|
||||
|
||||
def wavedec2_keys(level):
|
||||
"""Subband keys corresponding to a wavedec2 decomposition."""
|
||||
approx = ''
|
||||
coeffs = {}
|
||||
for lev in range(level):
|
||||
for k in ['a', 'h', 'v', 'd']:
|
||||
coeffs[approx + k] = None
|
||||
approx = 'a' * (lev + 1)
|
||||
if lev < level - 1:
|
||||
coeffs.pop(approx)
|
||||
return list(coeffs.keys())
|
||||
|
||||
|
||||
def _box(bl, ur):
|
||||
"""(x, y) coordinates for the 4 lines making up a rectangular box.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
bl : float
|
||||
The bottom left corner of the box
|
||||
ur : float
|
||||
The upper right corner of the box
|
||||
|
||||
Returns
|
||||
=======
|
||||
coords : 2-tuple
|
||||
The first and second elements of the tuple are the x and y coordinates
|
||||
of the box.
|
||||
"""
|
||||
xl, xr = bl[0], ur[0]
|
||||
yb, yt = bl[1], ur[1]
|
||||
box_x = [xl, xr,
|
||||
xr, xr,
|
||||
xr, xl,
|
||||
xl, xl]
|
||||
box_y = [yb, yb,
|
||||
yb, yt,
|
||||
yt, yt,
|
||||
yt, yb]
|
||||
return (box_x, box_y)
|
||||
|
||||
|
||||
def _2d_wp_basis_coords(shape, keys):
|
||||
# Coordinates of the lines to be drawn by draw_2d_wp_basis
|
||||
coords = []
|
||||
centers = {} # retain center of boxes for use in labeling
|
||||
for key in keys:
|
||||
offset_x = offset_y = 0
|
||||
for n, char in enumerate(key):
|
||||
if char in ['h', 'd']:
|
||||
offset_x += shape[0] // 2**(n + 1)
|
||||
if char in ['v', 'd']:
|
||||
offset_y += shape[1] // 2**(n + 1)
|
||||
sx = shape[0] // 2**(n + 1)
|
||||
sy = shape[1] // 2**(n + 1)
|
||||
xc, yc = _box((offset_x, -offset_y),
|
||||
(offset_x + sx, -offset_y - sy))
|
||||
coords.append((xc, yc))
|
||||
centers[key] = (offset_x + sx // 2, -offset_y - sy // 2)
|
||||
return coords, centers
|
||||
|
||||
|
||||
def draw_2d_wp_basis(shape, keys, fmt='k', plot_kwargs={}, ax=None,
|
||||
label_levels=0):
|
||||
"""Plot a 2D representation of a WaveletPacket2D basis."""
|
||||
coords, centers = _2d_wp_basis_coords(shape, keys)
|
||||
if ax is None:
|
||||
fig, ax = plt.subplots(1, 1)
|
||||
else:
|
||||
fig = ax.get_figure()
|
||||
for coord in coords:
|
||||
ax.plot(coord[0], coord[1], fmt)
|
||||
ax.set_axis_off()
|
||||
ax.axis('square')
|
||||
if label_levels > 0:
|
||||
for key, c in centers.items():
|
||||
if len(key) <= label_levels:
|
||||
ax.text(c[0], c[1], key,
|
||||
horizontalalignment='center',
|
||||
verticalalignment='center')
|
||||
return fig, ax
|
||||
|
||||
|
||||
def _2d_fswavedecn_coords(shape, levels):
|
||||
coords = []
|
||||
centers = {} # retain center of boxes for use in labeling
|
||||
for key in product(wavedec_keys(levels), repeat=2):
|
||||
(key0, key1) = key
|
||||
offsets = [0, 0]
|
||||
widths = list(shape)
|
||||
for n0, char in enumerate(key0):
|
||||
if char in ['d']:
|
||||
offsets[0] += shape[0] // 2**(n0 + 1)
|
||||
for n1, char in enumerate(key1):
|
||||
if char in ['d']:
|
||||
offsets[1] += shape[1] // 2**(n1 + 1)
|
||||
widths[0] = shape[0] // 2**(n0 + 1)
|
||||
widths[1] = shape[1] // 2**(n1 + 1)
|
||||
xc, yc = _box((offsets[0], -offsets[1]),
|
||||
(offsets[0] + widths[0], -offsets[1] - widths[1]))
|
||||
coords.append((xc, yc))
|
||||
centers[(key0, key1)] = (offsets[0] + widths[0] / 2,
|
||||
-offsets[1] - widths[1] / 2)
|
||||
return coords, centers
|
||||
|
||||
|
||||
def draw_2d_fswavedecn_basis(shape, levels, fmt='k', plot_kwargs={}, ax=None,
|
||||
label_levels=0):
|
||||
"""Plot a 2D representation of a WaveletPacket2D basis."""
|
||||
coords, centers = _2d_fswavedecn_coords(shape, levels)
|
||||
if ax is None:
|
||||
fig, ax = plt.subplots(1, 1)
|
||||
else:
|
||||
fig = ax.get_figure()
|
||||
for coord in coords:
|
||||
ax.plot(coord[0], coord[1], fmt)
|
||||
ax.set_axis_off()
|
||||
ax.axis('square')
|
||||
if label_levels > 0:
|
||||
for key, c in centers.items():
|
||||
lev = np.max([len(k) for k in key])
|
||||
if lev <= label_levels:
|
||||
ax.text(c[0], c[1], key,
|
||||
horizontalalignment='center',
|
||||
verticalalignment='center')
|
||||
return fig, ax
|
||||
|
||||
|
||||
def boundary_mode_subplot(x, mode, ax, symw=True):
|
||||
"""Plot an illustration of the boundary mode in a subplot axis."""
|
||||
|
||||
# if odd-length, periodization replicates the last sample to make it even
|
||||
if mode == 'periodization' and len(x) % 2 == 1:
|
||||
x = np.concatenate((x, (x[-1], )))
|
||||
|
||||
npad = 2 * len(x)
|
||||
t = np.arange(len(x) + 2 * npad)
|
||||
xp = pad(x, (npad, npad), mode=mode)
|
||||
|
||||
ax.plot(t, xp, 'k.')
|
||||
ax.set_title(mode)
|
||||
|
||||
# plot the original signal in red
|
||||
if mode == 'periodization':
|
||||
ax.plot(t[npad:npad + len(x) - 1], x[:-1], 'r.')
|
||||
else:
|
||||
ax.plot(t[npad:npad + len(x)], x, 'r.')
|
||||
|
||||
# add vertical bars indicating points of symmetry or boundary extension
|
||||
o2 = np.ones(2)
|
||||
left = npad
|
||||
if symw:
|
||||
step = len(x) - 1
|
||||
rng = range(-2, 4)
|
||||
else:
|
||||
left -= 0.5
|
||||
step = len(x)
|
||||
rng = range(-2, 4)
|
||||
if mode in ['smooth', 'constant', 'zero']:
|
||||
rng = range(0, 2)
|
||||
for rep in rng:
|
||||
ax.plot((left + rep * step) * o2, [xp.min() - .5, xp.max() + .5], 'k-')
|
517
venv/Lib/site-packages/pywt/_dwt.py
Normal file
517
venv/Lib/site-packages/pywt/_dwt.py
Normal file
|
@ -0,0 +1,517 @@
|
|||
from numbers import Number
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ._c99_config import _have_c99_complex
|
||||
from ._extensions._pywt import Wavelet, Modes, _check_dtype, wavelist
|
||||
from ._extensions._dwt import (dwt_single, dwt_axis, idwt_single, idwt_axis,
|
||||
upcoef as _upcoef, downcoef as _downcoef,
|
||||
dwt_max_level as _dwt_max_level,
|
||||
dwt_coeff_len as _dwt_coeff_len)
|
||||
from ._utils import string_types, _as_wavelet
|
||||
|
||||
|
||||
__all__ = ["dwt", "idwt", "downcoef", "upcoef", "dwt_max_level",
|
||||
"dwt_coeff_len", "pad"]
|
||||
|
||||
|
||||
def dwt_max_level(data_len, filter_len):
|
||||
r"""
|
||||
dwt_max_level(data_len, filter_len)
|
||||
|
||||
Compute the maximum useful level of decomposition.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data_len : int
|
||||
Input data length.
|
||||
filter_len : int, str or Wavelet
|
||||
The wavelet filter length. Alternatively, the name of a discrete
|
||||
wavelet or a Wavelet object can be specified.
|
||||
|
||||
Returns
|
||||
-------
|
||||
max_level : int
|
||||
Maximum level.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The rational for the choice of levels is the maximum level where at least
|
||||
one coefficient in the output is uncorrupted by edge effects caused by
|
||||
signal extension. Put another way, decomposition stops when the signal
|
||||
becomes shorter than the FIR filter length for a given wavelet. This
|
||||
corresponds to:
|
||||
|
||||
.. max_level = floor(log2(data_len/(filter_len - 1)))
|
||||
|
||||
.. math::
|
||||
\mathtt{max\_level} = \left\lfloor\log_2\left(\mathtt{
|
||||
\frac{data\_len}{filter\_len - 1}}\right)\right\rfloor
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import pywt
|
||||
>>> w = pywt.Wavelet('sym5')
|
||||
>>> pywt.dwt_max_level(data_len=1000, filter_len=w.dec_len)
|
||||
6
|
||||
>>> pywt.dwt_max_level(1000, w)
|
||||
6
|
||||
>>> pywt.dwt_max_level(1000, 'sym5')
|
||||
6
|
||||
"""
|
||||
if isinstance(filter_len, Wavelet):
|
||||
filter_len = filter_len.dec_len
|
||||
elif isinstance(filter_len, string_types):
|
||||
if filter_len in wavelist(kind='discrete'):
|
||||
filter_len = Wavelet(filter_len).dec_len
|
||||
else:
|
||||
raise ValueError(
|
||||
("'{}', is not a recognized discrete wavelet. A list of "
|
||||
"supported wavelet names can be obtained via "
|
||||
"pywt.wavelist(kind='discrete')").format(filter_len))
|
||||
elif not (isinstance(filter_len, Number) and filter_len % 1 == 0):
|
||||
raise ValueError(
|
||||
"filter_len must be an integer, discrete Wavelet object, or the "
|
||||
"name of a discrete wavelet.")
|
||||
|
||||
if filter_len < 2:
|
||||
raise ValueError("invalid wavelet filter length")
|
||||
|
||||
return _dwt_max_level(data_len, filter_len)
|
||||
|
||||
|
||||
def dwt_coeff_len(data_len, filter_len, mode):
|
||||
"""
|
||||
dwt_coeff_len(data_len, filter_len, mode='symmetric')
|
||||
|
||||
Returns length of dwt output for given data length, filter length and mode
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data_len : int
|
||||
Data length.
|
||||
filter_len : int
|
||||
Filter length.
|
||||
mode : str, optional
|
||||
Signal extension mode, see :ref:`Modes <ref-modes>`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
len : int
|
||||
Length of dwt output.
|
||||
|
||||
Notes
|
||||
-----
|
||||
For all modes except periodization::
|
||||
|
||||
len(cA) == len(cD) == floor((len(data) + wavelet.dec_len - 1) / 2)
|
||||
|
||||
for periodization mode ("per")::
|
||||
|
||||
len(cA) == len(cD) == ceil(len(data) / 2)
|
||||
|
||||
"""
|
||||
if isinstance(filter_len, Wavelet):
|
||||
filter_len = filter_len.dec_len
|
||||
|
||||
return _dwt_coeff_len(data_len, filter_len, Modes.from_object(mode))
|
||||
|
||||
|
||||
def dwt(data, wavelet, mode='symmetric', axis=-1):
|
||||
"""
|
||||
dwt(data, wavelet, mode='symmetric', axis=-1)
|
||||
|
||||
Single level Discrete Wavelet Transform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : array_like
|
||||
Input signal
|
||||
wavelet : Wavelet object or name
|
||||
Wavelet to use
|
||||
mode : str, optional
|
||||
Signal extension mode, see :ref:`Modes <ref-modes>`.
|
||||
axis: int, optional
|
||||
Axis over which to compute the DWT. If not given, the
|
||||
last axis is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(cA, cD) : tuple
|
||||
Approximation and detail coefficients.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Length of coefficients arrays depends on the selected mode.
|
||||
For all modes except periodization:
|
||||
|
||||
``len(cA) == len(cD) == floor((len(data) + wavelet.dec_len - 1) / 2)``
|
||||
|
||||
For periodization mode ("per"):
|
||||
|
||||
``len(cA) == len(cD) == ceil(len(data) / 2)``
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import pywt
|
||||
>>> (cA, cD) = pywt.dwt([1, 2, 3, 4, 5, 6], 'db1')
|
||||
>>> cA
|
||||
array([ 2.12132034, 4.94974747, 7.77817459])
|
||||
>>> cD
|
||||
array([-0.70710678, -0.70710678, -0.70710678])
|
||||
|
||||
"""
|
||||
if not _have_c99_complex and np.iscomplexobj(data):
|
||||
data = np.asarray(data)
|
||||
cA_r, cD_r = dwt(data.real, wavelet, mode, axis)
|
||||
cA_i, cD_i = dwt(data.imag, wavelet, mode, axis)
|
||||
return (cA_r + 1j*cA_i, cD_r + 1j*cD_i)
|
||||
|
||||
# accept array_like input; make a copy to ensure a contiguous array
|
||||
dt = _check_dtype(data)
|
||||
data = np.asarray(data, dtype=dt, order='C')
|
||||
mode = Modes.from_object(mode)
|
||||
wavelet = _as_wavelet(wavelet)
|
||||
|
||||
if axis < 0:
|
||||
axis = axis + data.ndim
|
||||
if not 0 <= axis < data.ndim:
|
||||
raise ValueError("Axis greater than data dimensions")
|
||||
|
||||
if data.ndim == 1:
|
||||
cA, cD = dwt_single(data, wavelet, mode)
|
||||
# TODO: Check whether this makes a copy
|
||||
cA, cD = np.asarray(cA, dt), np.asarray(cD, dt)
|
||||
else:
|
||||
cA, cD = dwt_axis(data, wavelet, mode, axis=axis)
|
||||
|
||||
return (cA, cD)
|
||||
|
||||
|
||||
def idwt(cA, cD, wavelet, mode='symmetric', axis=-1):
|
||||
"""
|
||||
idwt(cA, cD, wavelet, mode='symmetric', axis=-1)
|
||||
|
||||
Single level Inverse Discrete Wavelet Transform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cA : array_like or None
|
||||
Approximation coefficients. If None, will be set to array of zeros
|
||||
with same shape as ``cD``.
|
||||
cD : array_like or None
|
||||
Detail coefficients. If None, will be set to array of zeros
|
||||
with same shape as ``cA``.
|
||||
wavelet : Wavelet object or name
|
||||
Wavelet to use
|
||||
mode : str, optional (default: 'symmetric')
|
||||
Signal extension mode, see :ref:`Modes <ref-modes>`.
|
||||
axis: int, optional
|
||||
Axis over which to compute the inverse DWT. If not given, the
|
||||
last axis is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
rec: array_like
|
||||
Single level reconstruction of signal from given coefficients.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import pywt
|
||||
>>> (cA, cD) = pywt.dwt([1,2,3,4,5,6], 'db2', 'smooth')
|
||||
>>> pywt.idwt(cA, cD, 'db2', 'smooth')
|
||||
array([ 1., 2., 3., 4., 5., 6.])
|
||||
|
||||
One of the neat features of ``idwt`` is that one of the ``cA`` and ``cD``
|
||||
arguments can be set to None. In that situation the reconstruction will be
|
||||
performed using only the other one. Mathematically speaking, this is
|
||||
equivalent to passing a zero-filled array as one of the arguments.
|
||||
|
||||
>>> (cA, cD) = pywt.dwt([1,2,3,4,5,6], 'db2', 'smooth')
|
||||
>>> A = pywt.idwt(cA, None, 'db2', 'smooth')
|
||||
>>> D = pywt.idwt(None, cD, 'db2', 'smooth')
|
||||
>>> A + D
|
||||
array([ 1., 2., 3., 4., 5., 6.])
|
||||
|
||||
"""
|
||||
# TODO: Lots of possible allocations to eliminate (zeros_like, asarray(rec))
|
||||
# accept array_like input; make a copy to ensure a contiguous array
|
||||
|
||||
if cA is None and cD is None:
|
||||
raise ValueError("At least one coefficient parameter must be "
|
||||
"specified.")
|
||||
|
||||
# for complex inputs: compute real and imaginary separately then combine
|
||||
if not _have_c99_complex and (np.iscomplexobj(cA) or np.iscomplexobj(cD)):
|
||||
if cA is None:
|
||||
cD = np.asarray(cD)
|
||||
cA = np.zeros_like(cD)
|
||||
elif cD is None:
|
||||
cA = np.asarray(cA)
|
||||
cD = np.zeros_like(cA)
|
||||
return (idwt(cA.real, cD.real, wavelet, mode, axis) +
|
||||
1j*idwt(cA.imag, cD.imag, wavelet, mode, axis))
|
||||
|
||||
if cA is not None:
|
||||
dt = _check_dtype(cA)
|
||||
cA = np.asarray(cA, dtype=dt, order='C')
|
||||
if cD is not None:
|
||||
dt = _check_dtype(cD)
|
||||
cD = np.asarray(cD, dtype=dt, order='C')
|
||||
|
||||
if cA is not None and cD is not None:
|
||||
if cA.dtype != cD.dtype:
|
||||
# need to upcast to common type
|
||||
if cA.dtype.kind == 'c' or cD.dtype.kind == 'c':
|
||||
dtype = np.complex128
|
||||
else:
|
||||
dtype = np.float64
|
||||
cA = cA.astype(dtype)
|
||||
cD = cD.astype(dtype)
|
||||
elif cA is None:
|
||||
cA = np.zeros_like(cD)
|
||||
elif cD is None:
|
||||
cD = np.zeros_like(cA)
|
||||
|
||||
# cA and cD should be same dimension by here
|
||||
ndim = cA.ndim
|
||||
|
||||
mode = Modes.from_object(mode)
|
||||
wavelet = _as_wavelet(wavelet)
|
||||
|
||||
if axis < 0:
|
||||
axis = axis + ndim
|
||||
if not 0 <= axis < ndim:
|
||||
raise ValueError("Axis greater than coefficient dimensions")
|
||||
|
||||
if ndim == 1:
|
||||
rec = idwt_single(cA, cD, wavelet, mode)
|
||||
else:
|
||||
rec = idwt_axis(cA, cD, wavelet, mode, axis=axis)
|
||||
|
||||
return rec
|
||||
|
||||
|
||||
def downcoef(part, data, wavelet, mode='symmetric', level=1):
|
||||
"""
|
||||
downcoef(part, data, wavelet, mode='symmetric', level=1)
|
||||
|
||||
Partial Discrete Wavelet Transform data decomposition.
|
||||
|
||||
Similar to ``pywt.dwt``, but computes only one set of coefficients.
|
||||
Useful when you need only approximation or only details at the given level.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
part : str
|
||||
Coefficients type:
|
||||
|
||||
* 'a' - approximations reconstruction is performed
|
||||
* 'd' - details reconstruction is performed
|
||||
|
||||
data : array_like
|
||||
Input signal.
|
||||
wavelet : Wavelet object or name
|
||||
Wavelet to use
|
||||
mode : str, optional
|
||||
Signal extension mode, see :ref:`Modes <ref-modes>`.
|
||||
level : int, optional
|
||||
Decomposition level. Default is 1.
|
||||
|
||||
Returns
|
||||
-------
|
||||
coeffs : ndarray
|
||||
1-D array of coefficients.
|
||||
|
||||
See Also
|
||||
--------
|
||||
upcoef
|
||||
|
||||
"""
|
||||
if not _have_c99_complex and np.iscomplexobj(data):
|
||||
return (downcoef(part, data.real, wavelet, mode, level) +
|
||||
1j*downcoef(part, data.imag, wavelet, mode, level))
|
||||
# accept array_like input; make a copy to ensure a contiguous array
|
||||
dt = _check_dtype(data)
|
||||
data = np.asarray(data, dtype=dt, order='C')
|
||||
if data.ndim > 1:
|
||||
raise ValueError("downcoef only supports 1d data.")
|
||||
if part not in 'ad':
|
||||
raise ValueError("Argument 1 must be 'a' or 'd', not '%s'." % part)
|
||||
mode = Modes.from_object(mode)
|
||||
wavelet = _as_wavelet(wavelet)
|
||||
return np.asarray(_downcoef(part == 'a', data, wavelet, mode, level))
|
||||
|
||||
|
||||
def upcoef(part, coeffs, wavelet, level=1, take=0):
|
||||
"""
|
||||
upcoef(part, coeffs, wavelet, level=1, take=0)
|
||||
|
||||
Direct reconstruction from coefficients.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
part : str
|
||||
Coefficients type:
|
||||
* 'a' - approximations reconstruction is performed
|
||||
* 'd' - details reconstruction is performed
|
||||
coeffs : array_like
|
||||
Coefficients array to recontruct
|
||||
wavelet : Wavelet object or name
|
||||
Wavelet to use
|
||||
level : int, optional
|
||||
Multilevel reconstruction level. Default is 1.
|
||||
take : int, optional
|
||||
Take central part of length equal to 'take' from the result.
|
||||
Default is 0.
|
||||
|
||||
Returns
|
||||
-------
|
||||
rec : ndarray
|
||||
1-D array with reconstructed data from coefficients.
|
||||
|
||||
See Also
|
||||
--------
|
||||
downcoef
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import pywt
|
||||
>>> data = [1,2,3,4,5,6]
|
||||
>>> (cA, cD) = pywt.dwt(data, 'db2', 'smooth')
|
||||
>>> pywt.upcoef('a', cA, 'db2') + pywt.upcoef('d', cD, 'db2')
|
||||
array([-0.25 , -0.4330127 , 1. , 2. , 3. ,
|
||||
4. , 5. , 6. , 1.78589838, -1.03108891])
|
||||
>>> n = len(data)
|
||||
>>> pywt.upcoef('a', cA, 'db2', take=n) + pywt.upcoef('d', cD, 'db2', take=n)
|
||||
array([ 1., 2., 3., 4., 5., 6.])
|
||||
|
||||
"""
|
||||
if not _have_c99_complex and np.iscomplexobj(coeffs):
|
||||
return (upcoef(part, coeffs.real, wavelet, level, take) +
|
||||
1j*upcoef(part, coeffs.imag, wavelet, level, take))
|
||||
# accept array_like input; make a copy to ensure a contiguous array
|
||||
dt = _check_dtype(coeffs)
|
||||
coeffs = np.asarray(coeffs, dtype=dt, order='C')
|
||||
if coeffs.ndim > 1:
|
||||
raise ValueError("upcoef only supports 1d coeffs.")
|
||||
wavelet = _as_wavelet(wavelet)
|
||||
if part not in 'ad':
|
||||
raise ValueError("Argument 1 must be 'a' or 'd', not '%s'." % part)
|
||||
return np.asarray(_upcoef(part == 'a', coeffs, wavelet, level, take))
|
||||
|
||||
|
||||
def pad(x, pad_widths, mode):
|
||||
"""Extend a 1D signal using a given boundary mode.
|
||||
|
||||
This function operates like :func:`numpy.pad` but supports all signal
|
||||
extension modes that can be used by PyWavelets discrete wavelet transforms.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : ndarray
|
||||
The array to pad
|
||||
pad_widths : {sequence, array_like, int}
|
||||
Number of values padded to the edges of each axis.
|
||||
``((before_1, after_1), … (before_N, after_N))`` unique pad widths for
|
||||
each axis. ``((before, after),)`` yields same before and after pad for
|
||||
each axis. ``(pad,)`` or int is a shortcut for
|
||||
``before = after = pad width`` for all axes.
|
||||
mode : str, optional
|
||||
Signal extension mode, see :ref:`Modes <ref-modes>`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pad : ndarray
|
||||
Padded array of rank equal to array with shape increased according to
|
||||
``pad_widths``.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The performance of padding in dimensions > 1 may be substantially slower
|
||||
for modes ``'smooth'`` and ``'antisymmetric'`` as these modes are not
|
||||
supported efficiently by the underlying :func:`numpy.pad` function.
|
||||
|
||||
Note that the behavior of the ``'constant'`` mode here follows the
|
||||
PyWavelets convention which is different from NumPy (it is equivalent to
|
||||
``mode='edge'`` in :func:`numpy.pad`).
|
||||
"""
|
||||
x = np.asanyarray(x)
|
||||
|
||||
# process pad_widths exactly as in numpy.pad
|
||||
pad_widths = np.array(pad_widths)
|
||||
pad_widths = np.round(pad_widths).astype(np.intp, copy=False)
|
||||
if pad_widths.min() < 0:
|
||||
raise ValueError("pad_widths must be > 0")
|
||||
pad_widths = np.broadcast_to(pad_widths, (x.ndim, 2)).tolist()
|
||||
|
||||
if mode in ['symmetric', 'reflect']:
|
||||
xp = np.pad(x, pad_widths, mode=mode)
|
||||
elif mode in ['periodic', 'periodization']:
|
||||
if mode == 'periodization':
|
||||
# Promote odd-sized dimensions to even length by duplicating the
|
||||
# last value.
|
||||
edge_pad_widths = [(0, x.shape[ax] % 2)
|
||||
for ax in range(x.ndim)]
|
||||
x = np.pad(x, edge_pad_widths, mode='edge')
|
||||
xp = np.pad(x, pad_widths, mode='wrap')
|
||||
elif mode == 'zero':
|
||||
xp = np.pad(x, pad_widths, mode='constant', constant_values=0)
|
||||
elif mode == 'constant':
|
||||
xp = np.pad(x, pad_widths, mode='edge')
|
||||
elif mode == 'smooth':
|
||||
def pad_smooth(vector, pad_width, iaxis, kwargs):
|
||||
# smooth extension to left
|
||||
left = vector[pad_width[0]]
|
||||
slope_left = (left - vector[pad_width[0] + 1])
|
||||
vector[:pad_width[0]] = \
|
||||
left + np.arange(pad_width[0], 0, -1) * slope_left
|
||||
|
||||
# smooth extension to right
|
||||
right = vector[-pad_width[1] - 1]
|
||||
slope_right = (right - vector[-pad_width[1] - 2])
|
||||
vector[-pad_width[1]:] = \
|
||||
right + np.arange(1, pad_width[1] + 1) * slope_right
|
||||
return vector
|
||||
xp = np.pad(x, pad_widths, pad_smooth)
|
||||
elif mode == 'antisymmetric':
|
||||
def pad_antisymmetric(vector, pad_width, iaxis, kwargs):
|
||||
# smooth extension to left
|
||||
# implement by flipping portions symmetric padding
|
||||
npad_l, npad_r = pad_width
|
||||
vsize_nonpad = vector.size - npad_l - npad_r
|
||||
# Note: must modify vector in-place
|
||||
vector[:] = np.pad(vector[pad_width[0]:-pad_width[-1]],
|
||||
pad_width, mode='symmetric')
|
||||
vp = vector
|
||||
r_edge = npad_l + vsize_nonpad - 1
|
||||
l_edge = npad_l
|
||||
# width of each reflected segment
|
||||
seg_width = vsize_nonpad
|
||||
# flip reflected segments on the right of the original signal
|
||||
n = 1
|
||||
while r_edge <= vp.size:
|
||||
segment_slice = slice(r_edge + 1,
|
||||
min(r_edge + 1 + seg_width, vp.size))
|
||||
if n % 2:
|
||||
vp[segment_slice] *= -1
|
||||
r_edge += seg_width
|
||||
n += 1
|
||||
|
||||
# flip reflected segments on the left of the original signal
|
||||
n = 1
|
||||
while l_edge >= 0:
|
||||
segment_slice = slice(max(0, l_edge - seg_width), l_edge)
|
||||
if n % 2:
|
||||
vp[segment_slice] *= -1
|
||||
l_edge -= seg_width
|
||||
n += 1
|
||||
return vector
|
||||
xp = np.pad(x, pad_widths, pad_antisymmetric)
|
||||
elif mode == 'antireflect':
|
||||
xp = np.pad(x, pad_widths, mode='reflect', reflect_type='odd')
|
||||
else:
|
||||
raise ValueError(
|
||||
("unsupported mode: {}. The supported modes are {}").format(
|
||||
mode, Modes.modes))
|
||||
return xp
|
0
venv/Lib/site-packages/pywt/_extensions/__init__.py
Normal file
0
venv/Lib/site-packages/pywt/_extensions/__init__.py
Normal file
Binary file not shown.
BIN
venv/Lib/site-packages/pywt/_extensions/_cwt.cp36-win32.pyd
Normal file
BIN
venv/Lib/site-packages/pywt/_extensions/_cwt.cp36-win32.pyd
Normal file
Binary file not shown.
BIN
venv/Lib/site-packages/pywt/_extensions/_dwt.cp36-win32.pyd
Normal file
BIN
venv/Lib/site-packages/pywt/_extensions/_dwt.cp36-win32.pyd
Normal file
Binary file not shown.
BIN
venv/Lib/site-packages/pywt/_extensions/_pywt.cp36-win32.pyd
Normal file
BIN
venv/Lib/site-packages/pywt/_extensions/_pywt.cp36-win32.pyd
Normal file
Binary file not shown.
BIN
venv/Lib/site-packages/pywt/_extensions/_swt.cp36-win32.pyd
Normal file
BIN
venv/Lib/site-packages/pywt/_extensions/_swt.cp36-win32.pyd
Normal file
Binary file not shown.
240
venv/Lib/site-packages/pywt/_functions.py
Normal file
240
venv/Lib/site-packages/pywt/_functions.py
Normal file
|
@ -0,0 +1,240 @@
|
|||
# Copyright (c) 2006-2012 Filip Wasilewski <http://en.ig.ma/>
|
||||
# Copyright (c) 2012-2016 The PyWavelets Developers
|
||||
# <https://github.com/PyWavelets/pywt>
|
||||
# See COPYING for license details.
|
||||
|
||||
"""
|
||||
Other wavelet related functions.
|
||||
"""
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
from numpy.fft import fft
|
||||
|
||||
from ._extensions._pywt import DiscreteContinuousWavelet, Wavelet, ContinuousWavelet
|
||||
|
||||
|
||||
__all__ = ["integrate_wavelet", "central_frequency", "scale2frequency", "qmf",
|
||||
"orthogonal_filter_bank",
|
||||
"intwave", "centrfrq", "scal2frq", "orthfilt"]
|
||||
|
||||
|
||||
_DEPRECATION_MSG = ("`{old}` has been renamed to `{new}` and will "
|
||||
"be removed in a future version of pywt.")
|
||||
|
||||
|
||||
def _integrate(arr, step):
|
||||
integral = np.cumsum(arr)
|
||||
integral *= step
|
||||
return integral
|
||||
|
||||
|
||||
def intwave(*args, **kwargs):
|
||||
msg = _DEPRECATION_MSG.format(old='intwave', new='integrate_wavelet')
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
return integrate_wavelet(*args, **kwargs)
|
||||
|
||||
|
||||
def centrfrq(*args, **kwargs):
|
||||
msg = _DEPRECATION_MSG.format(old='centrfrq', new='central_frequency')
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
return central_frequency(*args, **kwargs)
|
||||
|
||||
|
||||
def scal2frq(*args, **kwargs):
|
||||
msg = _DEPRECATION_MSG.format(old='scal2frq', new='scale2frequency')
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
return scale2frequency(*args, **kwargs)
|
||||
|
||||
|
||||
def orthfilt(*args, **kwargs):
|
||||
msg = _DEPRECATION_MSG.format(old='orthfilt', new='orthogonal_filter_bank')
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
return orthogonal_filter_bank(*args, **kwargs)
|
||||
|
||||
|
||||
def integrate_wavelet(wavelet, precision=8):
|
||||
"""
|
||||
Integrate `psi` wavelet function from -Inf to x using the rectangle
|
||||
integration method.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wavelet : Wavelet instance or str
|
||||
Wavelet to integrate. If a string, should be the name of a wavelet.
|
||||
precision : int, optional
|
||||
Precision that will be used for wavelet function
|
||||
approximation computed with the wavefun(level=precision)
|
||||
Wavelet's method (default: 8).
|
||||
|
||||
Returns
|
||||
-------
|
||||
[int_psi, x] :
|
||||
for orthogonal wavelets
|
||||
[int_psi_d, int_psi_r, x] :
|
||||
for other wavelets
|
||||
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from pywt import Wavelet, integrate_wavelet
|
||||
>>> wavelet1 = Wavelet('db2')
|
||||
>>> [int_psi, x] = integrate_wavelet(wavelet1, precision=5)
|
||||
>>> wavelet2 = Wavelet('bior1.3')
|
||||
>>> [int_psi_d, int_psi_r, x] = integrate_wavelet(wavelet2, precision=5)
|
||||
|
||||
"""
|
||||
# FIXME: this function should really use scipy.integrate.quad
|
||||
|
||||
if type(wavelet) in (tuple, list):
|
||||
msg = ("Integration of a general signal is deprecated "
|
||||
"and will be removed in a future version of pywt.")
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
elif not isinstance(wavelet, (Wavelet, ContinuousWavelet)):
|
||||
wavelet = DiscreteContinuousWavelet(wavelet)
|
||||
|
||||
if type(wavelet) in (tuple, list):
|
||||
psi, x = np.asarray(wavelet[0]), np.asarray(wavelet[1])
|
||||
step = x[1] - x[0]
|
||||
return _integrate(psi, step), x
|
||||
|
||||
functions_approximations = wavelet.wavefun(precision)
|
||||
|
||||
if len(functions_approximations) == 2: # continuous wavelet
|
||||
psi, x = functions_approximations
|
||||
step = x[1] - x[0]
|
||||
return _integrate(psi, step), x
|
||||
|
||||
elif len(functions_approximations) == 3: # orthogonal wavelet
|
||||
phi, psi, x = functions_approximations
|
||||
step = x[1] - x[0]
|
||||
return _integrate(psi, step), x
|
||||
|
||||
else: # biorthogonal wavelet
|
||||
phi_d, psi_d, phi_r, psi_r, x = functions_approximations
|
||||
step = x[1] - x[0]
|
||||
return _integrate(psi_d, step), _integrate(psi_r, step), x
|
||||
|
||||
|
||||
def central_frequency(wavelet, precision=8):
|
||||
"""
|
||||
Computes the central frequency of the `psi` wavelet function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wavelet : Wavelet instance, str or tuple
|
||||
Wavelet to integrate. If a string, should be the name of a wavelet.
|
||||
precision : int, optional
|
||||
Precision that will be used for wavelet function
|
||||
approximation computed with the wavefun(level=precision)
|
||||
Wavelet's method (default: 8).
|
||||
|
||||
Returns
|
||||
-------
|
||||
scalar
|
||||
|
||||
"""
|
||||
|
||||
if not isinstance(wavelet, (Wavelet, ContinuousWavelet)):
|
||||
wavelet = DiscreteContinuousWavelet(wavelet)
|
||||
|
||||
functions_approximations = wavelet.wavefun(precision)
|
||||
|
||||
if len(functions_approximations) == 2:
|
||||
psi, x = functions_approximations
|
||||
else:
|
||||
# (psi, x) for (phi, psi, x)
|
||||
# (psi_d, x) for (phi_d, psi_d, phi_r, psi_r, x)
|
||||
psi, x = functions_approximations[1], functions_approximations[-1]
|
||||
|
||||
domain = float(x[-1] - x[0])
|
||||
assert domain > 0
|
||||
|
||||
index = np.argmax(abs(fft(psi)[1:])) + 2
|
||||
if index > len(psi) / 2:
|
||||
index = len(psi) - index + 2
|
||||
|
||||
return 1.0 / (domain / (index - 1))
|
||||
|
||||
|
||||
def scale2frequency(wavelet, scale, precision=8):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wavelet : Wavelet instance or str
|
||||
Wavelet to integrate. If a string, should be the name of a wavelet.
|
||||
scale : scalar
|
||||
precision : int, optional
|
||||
Precision that will be used for wavelet function approximation computed
|
||||
with ``wavelet.wavefun(level=precision)``. Default is 8.
|
||||
|
||||
Returns
|
||||
-------
|
||||
freq : scalar
|
||||
|
||||
"""
|
||||
return central_frequency(wavelet, precision=precision) / scale
|
||||
|
||||
|
||||
def qmf(filt):
|
||||
"""
|
||||
Returns the Quadrature Mirror Filter(QMF).
|
||||
|
||||
The magnitude response of QMF is mirror image about `pi/2` of that of the
|
||||
input filter.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filt : array_like
|
||||
Input filter for which QMF needs to be computed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
qm_filter : ndarray
|
||||
Quadrature mirror of the input filter.
|
||||
|
||||
"""
|
||||
qm_filter = np.array(filt)[::-1]
|
||||
qm_filter[1::2] = -qm_filter[1::2]
|
||||
return qm_filter
|
||||
|
||||
|
||||
def orthogonal_filter_bank(scaling_filter):
|
||||
"""
|
||||
Returns the orthogonal filter bank.
|
||||
|
||||
The orthogonal filter bank consists of the HPFs and LPFs at
|
||||
decomposition and reconstruction stage for the input scaling filter.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
scaling_filter : array_like
|
||||
Input scaling filter (father wavelet).
|
||||
|
||||
Returns
|
||||
-------
|
||||
orth_filt_bank : tuple of 4 ndarrays
|
||||
The orthogonal filter bank of the input scaling filter in the order :
|
||||
1] Decomposition LPF
|
||||
2] Decomposition HPF
|
||||
3] Reconstruction LPF
|
||||
4] Reconstruction HPF
|
||||
|
||||
"""
|
||||
if not (len(scaling_filter) % 2 == 0):
|
||||
raise ValueError("`scaling_filter` length has to be even.")
|
||||
|
||||
scaling_filter = np.asarray(scaling_filter, dtype=np.float64)
|
||||
|
||||
rec_lo = np.sqrt(2) * scaling_filter / np.sum(scaling_filter)
|
||||
dec_lo = rec_lo[::-1]
|
||||
|
||||
rec_hi = qmf(rec_lo)
|
||||
dec_hi = rec_hi[::-1]
|
||||
|
||||
orth_filt_bank = (dec_lo, dec_hi, rec_lo, rec_hi)
|
||||
return orth_filt_bank
|
311
venv/Lib/site-packages/pywt/_multidim.py
Normal file
311
venv/Lib/site-packages/pywt/_multidim.py
Normal file
|
@ -0,0 +1,311 @@
|
|||
# Copyright (c) 2006-2012 Filip Wasilewski <http://en.ig.ma/>
|
||||
# Copyright (c) 2012-2016 The PyWavelets Developers
|
||||
# <https://github.com/PyWavelets/pywt>
|
||||
# See COPYING for license details.
|
||||
|
||||
"""
|
||||
2D and nD Discrete Wavelet Transforms and Inverse Discrete Wavelet Transforms.
|
||||
"""
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
from itertools import product
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ._c99_config import _have_c99_complex
|
||||
from ._extensions._dwt import dwt_axis, idwt_axis
|
||||
from ._utils import _wavelets_per_axis, _modes_per_axis
|
||||
|
||||
|
||||
__all__ = ['dwt2', 'idwt2', 'dwtn', 'idwtn']
|
||||
|
||||
|
||||
def dwt2(data, wavelet, mode='symmetric', axes=(-2, -1)):
|
||||
"""
|
||||
2D Discrete Wavelet Transform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : array_like
|
||||
2D array with input data
|
||||
wavelet : Wavelet object or name string, or 2-tuple of wavelets
|
||||
Wavelet to use. This can also be a tuple containing a wavelet to
|
||||
apply along each axis in ``axes``.
|
||||
mode : str or 2-tuple of strings, optional
|
||||
Signal extension mode, see :ref:`Modes <ref-modes>`. This can
|
||||
also be a tuple of modes specifying the mode to use on each axis in
|
||||
``axes``.
|
||||
axes : 2-tuple of ints, optional
|
||||
Axes over which to compute the DWT. Repeated elements mean the DWT will
|
||||
be performed multiple times along these axes.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(cA, (cH, cV, cD)) : tuple
|
||||
Approximation, horizontal detail, vertical detail and diagonal
|
||||
detail coefficients respectively. Horizontal refers to array axis 0
|
||||
(or ``axes[0]`` for user-specified ``axes``).
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import numpy as np
|
||||
>>> import pywt
|
||||
>>> data = np.ones((4,4), dtype=np.float64)
|
||||
>>> coeffs = pywt.dwt2(data, 'haar')
|
||||
>>> cA, (cH, cV, cD) = coeffs
|
||||
>>> cA
|
||||
array([[ 2., 2.],
|
||||
[ 2., 2.]])
|
||||
>>> cV
|
||||
array([[ 0., 0.],
|
||||
[ 0., 0.]])
|
||||
|
||||
"""
|
||||
axes = tuple(axes)
|
||||
data = np.asarray(data)
|
||||
if len(axes) != 2:
|
||||
raise ValueError("Expected 2 axes")
|
||||
if data.ndim < len(np.unique(axes)):
|
||||
raise ValueError("Input array has fewer dimensions than the specified "
|
||||
"axes")
|
||||
|
||||
coefs = dwtn(data, wavelet, mode, axes)
|
||||
return coefs['aa'], (coefs['da'], coefs['ad'], coefs['dd'])
|
||||
|
||||
|
||||
def idwt2(coeffs, wavelet, mode='symmetric', axes=(-2, -1)):
|
||||
"""
|
||||
2-D Inverse Discrete Wavelet Transform.
|
||||
|
||||
Reconstructs data from coefficient arrays.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
coeffs : tuple
|
||||
(cA, (cH, cV, cD)) A tuple with approximation coefficients and three
|
||||
details coefficients 2D arrays like from ``dwt2``. If any of these
|
||||
components are set to ``None``, it will be treated as zeros.
|
||||
wavelet : Wavelet object or name string, or 2-tuple of wavelets
|
||||
Wavelet to use. This can also be a tuple containing a wavelet to
|
||||
apply along each axis in ``axes``.
|
||||
mode : str or 2-tuple of strings, optional
|
||||
Signal extension mode, see :ref:`Modes <ref-modes>`. This can
|
||||
also be a tuple of modes specifying the mode to use on each axis in
|
||||
``axes``.
|
||||
axes : 2-tuple of ints, optional
|
||||
Axes over which to compute the IDWT. Repeated elements mean the IDWT
|
||||
will be performed multiple times along these axes.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import numpy as np
|
||||
>>> import pywt
|
||||
>>> data = np.array([[1,2], [3,4]], dtype=np.float64)
|
||||
>>> coeffs = pywt.dwt2(data, 'haar')
|
||||
>>> pywt.idwt2(coeffs, 'haar')
|
||||
array([[ 1., 2.],
|
||||
[ 3., 4.]])
|
||||
|
||||
"""
|
||||
# L -low-pass data, H - high-pass data
|
||||
LL, (HL, LH, HH) = coeffs
|
||||
axes = tuple(axes)
|
||||
if len(axes) != 2:
|
||||
raise ValueError("Expected 2 axes")
|
||||
|
||||
coeffs = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH}
|
||||
return idwtn(coeffs, wavelet, mode, axes)
|
||||
|
||||
|
||||
def dwtn(data, wavelet, mode='symmetric', axes=None):
|
||||
"""
|
||||
Single-level n-dimensional Discrete Wavelet Transform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : array_like
|
||||
n-dimensional array with input data.
|
||||
wavelet : Wavelet object or name string, or tuple of wavelets
|
||||
Wavelet to use. This can also be a tuple containing a wavelet to
|
||||
apply along each axis in ``axes``.
|
||||
mode : str or tuple of string, optional
|
||||
Signal extension mode used in the decomposition,
|
||||
see :ref:`Modes <ref-modes>`. This can also be a tuple of modes
|
||||
specifying the mode to use on each axis in ``axes``.
|
||||
axes : sequence of ints, optional
|
||||
Axes over which to compute the DWT. Repeated elements mean the DWT will
|
||||
be performed multiple times along these axes. A value of ``None`` (the
|
||||
default) selects all axes.
|
||||
|
||||
Axes may be repeated, but information about the original size may be
|
||||
lost if it is not divisible by ``2 ** nrepeats``. The reconstruction
|
||||
will be larger, with additional values derived according to the
|
||||
``mode`` parameter. ``pywt.wavedecn`` should be used for multilevel
|
||||
decomposition.
|
||||
|
||||
Returns
|
||||
-------
|
||||
coeffs : dict
|
||||
Results are arranged in a dictionary, where key specifies
|
||||
the transform type on each dimension and value is a n-dimensional
|
||||
coefficients array.
|
||||
|
||||
For example, for a 2D case the result will look something like this::
|
||||
|
||||
{'aa': <coeffs> # A(LL) - approx. on 1st dim, approx. on 2nd dim
|
||||
'ad': <coeffs> # V(LH) - approx. on 1st dim, det. on 2nd dim
|
||||
'da': <coeffs> # H(HL) - det. on 1st dim, approx. on 2nd dim
|
||||
'dd': <coeffs> # D(HH) - det. on 1st dim, det. on 2nd dim
|
||||
}
|
||||
|
||||
For user-specified ``axes``, the order of the characters in the
|
||||
dictionary keys map to the specified ``axes``.
|
||||
|
||||
"""
|
||||
data = np.asarray(data)
|
||||
if not _have_c99_complex and np.iscomplexobj(data):
|
||||
real = dwtn(data.real, wavelet, mode, axes)
|
||||
imag = dwtn(data.imag, wavelet, mode, axes)
|
||||
return dict((k, real[k] + 1j * imag[k]) for k in real.keys())
|
||||
|
||||
if data.dtype == np.dtype('object'):
|
||||
raise TypeError("Input must be a numeric array-like")
|
||||
if data.ndim < 1:
|
||||
raise ValueError("Input data must be at least 1D")
|
||||
|
||||
if axes is None:
|
||||
axes = range(data.ndim)
|
||||
axes = [a + data.ndim if a < 0 else a for a in axes]
|
||||
|
||||
modes = _modes_per_axis(mode, axes)
|
||||
wavelets = _wavelets_per_axis(wavelet, axes)
|
||||
|
||||
coeffs = [('', data)]
|
||||
for axis, wav, mode in zip(axes, wavelets, modes):
|
||||
new_coeffs = []
|
||||
for subband, x in coeffs:
|
||||
cA, cD = dwt_axis(x, wav, mode, axis)
|
||||
new_coeffs.extend([(subband + 'a', cA),
|
||||
(subband + 'd', cD)])
|
||||
coeffs = new_coeffs
|
||||
return dict(coeffs)
|
||||
|
||||
|
||||
def _fix_coeffs(coeffs):
|
||||
missing_keys = [k for k, v in coeffs.items() if v is None]
|
||||
if missing_keys:
|
||||
raise ValueError(
|
||||
"The following detail coefficients were set to None:\n"
|
||||
"{0}\n"
|
||||
"For multilevel transforms, rather than setting\n"
|
||||
"\tcoeffs[key] = None\n"
|
||||
"use\n"
|
||||
"\tcoeffs[key] = np.zeros_like(coeffs[key])\n".format(
|
||||
missing_keys))
|
||||
|
||||
invalid_keys = [k for k, v in coeffs.items() if
|
||||
not set(k) <= set('ad')]
|
||||
if invalid_keys:
|
||||
raise ValueError(
|
||||
"The following invalid keys were found in the detail "
|
||||
"coefficient dictionary: {}.".format(invalid_keys))
|
||||
|
||||
key_lengths = [len(k) for k in coeffs.keys()]
|
||||
if len(np.unique(key_lengths)) > 1:
|
||||
raise ValueError(
|
||||
"All detail coefficient names must have equal length.")
|
||||
|
||||
return dict((k, np.asarray(v)) for k, v in coeffs.items())
|
||||
|
||||
|
||||
def idwtn(coeffs, wavelet, mode='symmetric', axes=None):
|
||||
"""
|
||||
Single-level n-dimensional Inverse Discrete Wavelet Transform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
coeffs: dict
|
||||
Dictionary as in output of ``dwtn``. Missing or ``None`` items
|
||||
will be treated as zeros.
|
||||
wavelet : Wavelet object or name string, or tuple of wavelets
|
||||
Wavelet to use. This can also be a tuple containing a wavelet to
|
||||
apply along each axis in ``axes``.
|
||||
mode : str or list of string, optional
|
||||
Signal extension mode used in the decomposition,
|
||||
see :ref:`Modes <ref-modes>`. This can also be a tuple of modes
|
||||
specifying the mode to use on each axis in ``axes``.
|
||||
axes : sequence of ints, optional
|
||||
Axes over which to compute the IDWT. Repeated elements mean the IDWT
|
||||
will be performed multiple times along these axes. A value of ``None``
|
||||
(the default) selects all axes.
|
||||
|
||||
For the most accurate reconstruction, the axes should be provided in
|
||||
the same order as they were provided to ``dwtn``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
data: ndarray
|
||||
Original signal reconstructed from input data.
|
||||
|
||||
"""
|
||||
|
||||
# drop the keys corresponding to value = None
|
||||
coeffs = dict((k, v) for k, v in coeffs.items() if v is not None)
|
||||
|
||||
# Raise error for invalid key combinations
|
||||
coeffs = _fix_coeffs(coeffs)
|
||||
|
||||
if (not _have_c99_complex and
|
||||
any(np.iscomplexobj(v) for v in coeffs.values())):
|
||||
real_coeffs = dict((k, v.real) for k, v in coeffs.items())
|
||||
imag_coeffs = dict((k, v.imag) for k, v in coeffs.items())
|
||||
return (idwtn(real_coeffs, wavelet, mode, axes) +
|
||||
1j * idwtn(imag_coeffs, wavelet, mode, axes))
|
||||
|
||||
# key length matches the number of axes transformed
|
||||
ndim_transform = max(len(key) for key in coeffs.keys())
|
||||
|
||||
try:
|
||||
coeff_shapes = (v.shape for k, v in coeffs.items()
|
||||
if v is not None and len(k) == ndim_transform)
|
||||
coeff_shape = next(coeff_shapes)
|
||||
except StopIteration:
|
||||
raise ValueError("`coeffs` must contain at least one non-null wavelet "
|
||||
"band")
|
||||
if any(s != coeff_shape for s in coeff_shapes):
|
||||
raise ValueError("`coeffs` must all be of equal size (or None)")
|
||||
|
||||
if axes is None:
|
||||
axes = range(ndim_transform)
|
||||
ndim = ndim_transform
|
||||
else:
|
||||
ndim = len(coeff_shape)
|
||||
axes = [a + ndim if a < 0 else a for a in axes]
|
||||
|
||||
modes = _modes_per_axis(mode, axes)
|
||||
wavelets = _wavelets_per_axis(wavelet, axes)
|
||||
for key_length, (axis, wav, mode) in reversed(
|
||||
list(enumerate(zip(axes, wavelets, modes)))):
|
||||
if axis < 0 or axis >= ndim:
|
||||
raise ValueError("Axis greater than data dimensions")
|
||||
|
||||
new_coeffs = {}
|
||||
new_keys = [''.join(coef) for coef in product('ad', repeat=key_length)]
|
||||
|
||||
for key in new_keys:
|
||||
L = coeffs.get(key + 'a', None)
|
||||
H = coeffs.get(key + 'd', None)
|
||||
if L is not None and H is not None:
|
||||
if L.dtype != H.dtype:
|
||||
# upcast to a common dtype (float64 or complex128)
|
||||
if L.dtype.kind == 'c' or H.dtype.kind == 'c':
|
||||
dtype = np.complex128
|
||||
else:
|
||||
dtype = np.float64
|
||||
L = np.asarray(L, dtype=dtype)
|
||||
H = np.asarray(H, dtype=dtype)
|
||||
new_coeffs[key] = idwt_axis(L, H, wav, mode, axis)
|
||||
coeffs = new_coeffs
|
||||
|
||||
return coeffs['']
|
1551
venv/Lib/site-packages/pywt/_multilevel.py
Normal file
1551
venv/Lib/site-packages/pywt/_multilevel.py
Normal file
File diff suppressed because it is too large
Load diff
68
venv/Lib/site-packages/pywt/_pytest.py
Normal file
68
venv/Lib/site-packages/pywt/_pytest.py
Normal file
|
@ -0,0 +1,68 @@
|
|||
"""common test-related code."""
|
||||
import os
|
||||
import sys
|
||||
import multiprocessing
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
__all__ = ['uses_matlab', # skip if pymatbridge and Matlab unavailable
|
||||
'uses_futures', # skip if futures unavailable
|
||||
'uses_pymatbridge', # skip if no PYWT_XSLOW environment variable
|
||||
'uses_precomputed', # skip if PYWT_XSLOW environment variable found
|
||||
'matlab_result_dict_cwt', # dict with precomputed Matlab dwt data
|
||||
'matlab_result_dict_dwt', # dict with precomputed Matlab cwt data
|
||||
'futures', # the futures module or None
|
||||
'max_workers', # the number of workers available to futures
|
||||
'size_set', # the set of Matlab tests to run
|
||||
]
|
||||
|
||||
try:
|
||||
if sys.version_info[0] == 2:
|
||||
import futures
|
||||
else:
|
||||
from concurrent import futures
|
||||
max_workers = multiprocessing.cpu_count()
|
||||
futures_available = True
|
||||
except ImportError:
|
||||
futures_available = False
|
||||
futures = None
|
||||
|
||||
# check if pymatbridge + MATLAB tests should be run
|
||||
matlab_result_dict_dwt = None
|
||||
matlab_result_dict_cwt = None
|
||||
matlab_missing = True
|
||||
use_precomputed = True
|
||||
size_set = 'reduced'
|
||||
if 'PYWT_XSLOW' in os.environ:
|
||||
try:
|
||||
from pymatbridge import Matlab
|
||||
mlab = Matlab()
|
||||
matlab_missing = False
|
||||
use_precomputed = False
|
||||
size_set = 'full'
|
||||
except ImportError:
|
||||
print("To run Matlab compatibility tests you need to have MathWorks "
|
||||
"MATLAB, MathWorks Wavelet Toolbox and the pymatbridge Python "
|
||||
"package installed.")
|
||||
if use_precomputed:
|
||||
# load dictionaries of precomputed results
|
||||
data_dir = os.path.join(os.path.dirname(__file__), 'tests', 'data')
|
||||
matlab_data_file_cwt = os.path.join(
|
||||
data_dir, 'cwt_matlabR2015b_result.npz')
|
||||
matlab_result_dict_cwt = np.load(matlab_data_file_cwt)
|
||||
|
||||
matlab_data_file_dwt = os.path.join(
|
||||
data_dir, 'dwt_matlabR2012a_result.npz')
|
||||
matlab_result_dict_dwt = np.load(matlab_data_file_dwt)
|
||||
|
||||
uses_futures = pytest.mark.skipif(
|
||||
not futures_available, reason='futures not available')
|
||||
uses_matlab = pytest.mark.skipif(
|
||||
matlab_missing, reason='pymatbridge and/or Matlab not available')
|
||||
uses_pymatbridge = pytest.mark.skipif(
|
||||
use_precomputed,
|
||||
reason='PYWT_XSLOW set: skipping tests against precomputed Matlab results')
|
||||
uses_precomputed = pytest.mark.skipif(
|
||||
not use_precomputed,
|
||||
reason='PYWT_XSLOW not set: test against precomputed matlab tests')
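

# Illustrative sketch (hypothetical test, not part of the original module):
# test modules in this package import the markers above and apply them as
# decorators, e.g.::
#
#     from pywt._pytest import uses_precomputed, matlab_result_dict_dwt
#
#     @uses_precomputed
#     def test_dwt_matches_precomputed_matlab():
#         # compare pywt results against entries of the precomputed
#         # dwt_matlabR2012a_result.npz dictionary loaded above
#         ...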
|
164
venv/Lib/site-packages/pywt/_pytesttester.py
Normal file
@@ -0,0 +1,164 @@
"""
|
||||
Pytest test running.
|
||||
|
||||
This module implements the ``test()`` function for pywt modules. The usual
|
||||
boilerplate for doing that is to put the following in the module
|
||||
``__init__.py`` file::
|
||||
|
||||
from pywt._pytesttester import PytestTester
|
||||
test = PytestTester(__name__).test
|
||||
del PytestTester
|
||||
|
||||
|
||||
Warnings filtering and other runtime settings should be dealt with in the
|
||||
``pytest.ini`` file in the pywt repo root. The behavior of the test depends on
|
||||
whether or not that file is found as follows:
|
||||
|
||||
* ``pytest.ini`` is present (develop mode)
|
||||
    All warnings except those explicitly filtered out are raised as errors.
|
||||
* ``pytest.ini`` is absent (release mode)
|
||||
DeprecationWarnings and PendingDeprecationWarnings are ignored, other
|
||||
warnings are passed through.
|
||||
|
||||
In practice, tests run from the PyWavelets repo are run in develop mode. That
|
||||
includes the standard ``python runtests.py`` invocation.
|
||||
|
||||
"""
|
||||
from __future__ import division, absolute_import, print_function
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
__all__ = ['PytestTester']
|
||||
|
||||
|
||||
def _show_pywt_info():
|
||||
import pywt
|
||||
from pywt._c99_config import _have_c99_complex
|
||||
print("PyWavelets version %s" % pywt.__version__)
|
||||
if _have_c99_complex:
|
||||
print("Compiled with C99 complex support.")
|
||||
else:
|
||||
print("Compiled without C99 complex support.")
|
||||
|
||||
|
||||
class PytestTester(object):
|
||||
"""
|
||||
Pytest test runner.
|
||||
|
||||
    This class is made available in ``pywt._pytesttester``, and a test function
|
||||
is typically added to a package's __init__.py like so::
|
||||
|
||||
from pywt.testing import PytestTester
|
||||
test = PytestTester(__name__).test
|
||||
del PytestTester
|
||||
|
||||
Calling this test function finds and runs all tests associated with the
|
||||
module and all its sub-modules.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
module_name : str
|
||||
Full path to the package to test.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
module_name : module name
|
||||
The name of the module to test.
|
||||
|
||||
"""
|
||||
def __init__(self, module_name):
|
||||
self.module_name = module_name
|
||||
|
||||
def __call__(self, label='fast', verbose=1, extra_argv=None,
|
||||
doctests=False, coverage=False, durations=-1, tests=None):
|
||||
"""
|
||||
Run tests for module using pytest.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
label : {'fast', 'full'}, optional
|
||||
Identifies the tests to run. When set to 'fast', tests decorated
|
||||
            with `pytest.mark.slow` are skipped; when 'full', the slow marker
|
||||
is ignored.
|
||||
verbose : int, optional
|
||||
Verbosity value for test outputs, in the range 1-3. Default is 1.
|
||||
extra_argv : list, optional
|
||||
            List with any extra arguments to pass to pytest.
|
||||
doctests : bool, optional
|
||||
.. note:: Not supported
|
||||
coverage : bool, optional
|
||||
            If True, report coverage of pywt code. Default is False.
|
||||
Requires installation of (pip) pytest-cov.
|
||||
durations : int, optional
|
||||
            If < 0, do nothing; if 0, report the time of all tests; if > 0,
            report the time of the ``durations`` slowest tests. Default is -1.
|
||||
tests : test or list of tests
|
||||
Tests to be executed with pytest '--pyargs'
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : bool
|
||||
            Return True on success, False otherwise.
|
||||
|
||||
Examples
|
||||
--------
|
||||
        >>> result = pywt.test() #doctest: +SKIP
|
||||
...
|
||||
1023 passed, 2 skipped, 6 deselected, 1 xfailed in 10.39 seconds
|
||||
>>> result
|
||||
True
|
||||
|
||||
"""
|
||||
import pytest
|
||||
|
||||
module = sys.modules[self.module_name]
|
||||
module_path = os.path.abspath(module.__path__[0])
|
||||
|
||||
# setup the pytest arguments
|
||||
pytest_args = ["-l"]
|
||||
|
||||
# offset verbosity. The "-q" cancels a "-v".
|
||||
pytest_args += ["-q"]
|
||||
|
||||
# Filter out annoying import messages. Want these in both develop and
|
||||
# release mode.
|
||||
pytest_args += [
|
||||
"-W ignore:Not importing directory",
|
||||
"-W ignore:numpy.dtype size changed",
|
||||
"-W ignore:numpy.ufunc size changed", ]
|
||||
|
||||
if doctests:
|
||||
raise ValueError("Doctests not supported")
|
||||
|
||||
if extra_argv:
|
||||
pytest_args += list(extra_argv)
|
||||
|
||||
if verbose > 1:
|
||||
pytest_args += ["-" + "v"*(verbose - 1)]
|
||||
|
||||
if coverage:
|
||||
pytest_args += ["--cov=" + module_path]
|
||||
|
||||
if label == "fast":
|
||||
pytest_args += ["-m", "not slow"]
|
||||
elif label != "full":
|
||||
pytest_args += ["-m", label]
|
||||
|
||||
if durations >= 0:
|
||||
pytest_args += ["--durations=%s" % durations]
|
||||
|
||||
if tests is None:
|
||||
tests = [self.module_name]
|
||||
|
||||
pytest_args += ["--pyargs"] + list(tests)
|
||||
|
||||
# run tests.
|
||||
_show_pywt_info()
|
||||
|
||||
try:
|
||||
code = pytest.main(pytest_args)
|
||||
except SystemExit as exc:
|
||||
code = exc.code
|
||||
|
||||
return code == 0
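

# Illustrative sketch (not part of the original module): running the test
# suite through this runner once it has been bound to a package namespace.
def _example_run_pywt_tests():
    import pywt  # make sure 'pywt' is in sys.modules for the lookup in __call__
    runner = PytestTester('pywt')
    # roughly equivalent to ``pytest --pyargs pywt -m "not slow"``
    return runner(label='fast', verbose=2)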
|
774
venv/Lib/site-packages/pywt/_swt.py
Normal file
@@ -0,0 +1,774 @@
import warnings
|
||||
from itertools import product
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ._c99_config import _have_c99_complex
|
||||
from ._extensions._dwt import idwt_single
|
||||
from ._extensions._swt import swt_max_level, swt as _swt, swt_axis as _swt_axis
|
||||
from ._extensions._pywt import Wavelet, Modes, _check_dtype
|
||||
from ._multidim import idwt2, idwtn
|
||||
from ._utils import _as_wavelet, _wavelets_per_axis
|
||||
|
||||
|
||||
__all__ = ["swt", "swt_max_level", 'iswt', 'swt2', 'iswt2', 'swtn', 'iswtn']
|
||||
|
||||
|
||||
def _rescale_wavelet_filterbank(wavelet, sf):
|
||||
wav = Wavelet(wavelet.name + 'r',
|
||||
[np.asarray(f) * sf for f in wavelet.filter_bank])
|
||||
|
||||
# copy attributes from the original wavelet
|
||||
wav.orthogonal = wavelet.orthogonal
|
||||
wav.biorthogonal = wavelet.biorthogonal
|
||||
return wav
|
||||
|
||||
|
||||
def swt(data, wavelet, level=None, start_level=0, axis=-1,
|
||||
trim_approx=False, norm=False):
|
||||
"""
|
||||
Multilevel 1D stationary wavelet transform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
    data : array_like
|
||||
Input signal
|
||||
    wavelet : Wavelet object or name string
|
||||
Wavelet to use (Wavelet object or name)
|
||||
level : int, optional
|
||||
The number of decomposition steps to perform.
|
||||
start_level : int, optional
|
||||
The level at which the decomposition will begin (it allows one to
|
||||
skip a given number of transform steps and compute
|
||||
coefficients starting from start_level) (default: 0)
|
||||
axis: int, optional
|
||||
Axis over which to compute the SWT. If not given, the
|
||||
last axis is used.
|
||||
trim_approx : bool, optional
|
||||
If True, approximation coefficients at the final level are retained.
|
||||
norm : bool, optional
|
||||
If True, transform is normalized so that the energy of the coefficients
|
||||
will be equal to the energy of ``data``. In other words,
|
||||
``np.linalg.norm(data.ravel())`` will equal the norm of the
|
||||
concatenated transform coefficients when ``trim_approx`` is True.
|
||||
|
||||
Returns
|
||||
-------
|
||||
coeffs : list
|
||||
List of approximation and details coefficients pairs in order
|
||||
similar to wavedec function::
|
||||
|
||||
[(cAn, cDn), ..., (cA2, cD2), (cA1, cD1)]
|
||||
|
||||
where n equals input parameter ``level``.
|
||||
|
||||
If ``start_level = m`` is given, then the beginning m steps are
|
||||
skipped::
|
||||
|
||||
[(cAm+n, cDm+n), ..., (cAm+1, cDm+1), (cAm, cDm)]
|
||||
|
||||
If ``trim_approx`` is ``True``, then the output list is exactly as in
|
||||
``pywt.wavedec``, where the first coefficient in the list is the
|
||||
approximation coefficient at the final level and the rest are the
|
||||
detail coefficients::
|
||||
|
||||
[cAn, cDn, ..., cD2, cD1]
|
||||
|
||||
Notes
|
||||
-----
|
||||
The implementation here follows the "algorithm a-trous" and requires that
|
||||
the signal length along the transformed axis be a multiple of ``2**level``.
|
||||
If this is not the case, the user should pad up to an appropriate size
|
||||
using a function such as ``numpy.pad``.
|
||||
|
||||
    A primary benefit of this transform in comparison to its decimated
    counterpart (``pywt.wavedecn``) is that it is shift-invariant. This comes
    at the cost of redundancy in the transform (the size of the output
    coefficients is larger than that of the input).
|
||||
|
||||
When the following three conditions are true:
|
||||
|
||||
1. The wavelet is orthogonal
|
||||
2. ``swt`` is called with ``norm=True``
|
||||
3. ``swt`` is called with ``trim_approx=True``
|
||||
|
||||
the transform has the following additional properties that may be
|
||||
desirable in applications:
|
||||
|
||||
1. energy is conserved
|
||||
2. variance is partitioned across scales
|
||||
|
||||
When used with ``norm=True``, this transform is closely related to the
|
||||
    maximal overlap DWT (MODWT) as popularized for time-series analysis,
|
||||
although the underlying implementation is slightly different from the one
|
||||
published in [1]_. Specifically, the implementation used here requires a
|
||||
signal that is a multiple of ``2**level`` in length.
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] DB Percival and AT Walden. Wavelet Methods for Time Series Analysis.
|
||||
Cambridge University Press, 2000.
|
||||
"""
|
||||
|
||||
if not _have_c99_complex and np.iscomplexobj(data):
|
||||
data = np.asarray(data)
|
||||
        coeffs_real = swt(data.real, wavelet, level, start_level, axis,
                          trim_approx, norm)
        coeffs_imag = swt(data.imag, wavelet, level, start_level, axis,
                          trim_approx, norm)
|
||||
if not trim_approx:
|
||||
coeffs_cplx = []
|
||||
for (cA_r, cD_r), (cA_i, cD_i) in zip(coeffs_real, coeffs_imag):
|
||||
coeffs_cplx.append((cA_r + 1j*cA_i, cD_r + 1j*cD_i))
|
||||
else:
|
||||
coeffs_cplx = [cr + 1j*ci
|
||||
for (cr, ci) in zip(coeffs_real, coeffs_imag)]
|
||||
return coeffs_cplx
|
||||
|
||||
# accept array_like input; make a copy to ensure a contiguous array
|
||||
dt = _check_dtype(data)
|
||||
data = np.array(data, dtype=dt)
|
||||
|
||||
wavelet = _as_wavelet(wavelet)
|
||||
if norm:
|
||||
if not wavelet.orthogonal:
|
||||
warnings.warn(
|
||||
"norm=True, but the wavelet is not orthogonal: \n"
|
||||
"\tThe conditions for energy preservation are not satisfied.")
|
||||
wavelet = _rescale_wavelet_filterbank(wavelet, 1/np.sqrt(2))
|
||||
|
||||
if axis < 0:
|
||||
axis = axis + data.ndim
|
||||
if not 0 <= axis < data.ndim:
|
||||
raise ValueError("Axis greater than data dimensions")
|
||||
|
||||
if level is None:
|
||||
level = swt_max_level(data.shape[axis])
|
||||
|
||||
if data.ndim == 1:
|
||||
ret = _swt(data, wavelet, level, start_level, trim_approx)
|
||||
else:
|
||||
ret = _swt_axis(data, wavelet, level, start_level, axis, trim_approx)
|
||||
return ret
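

# Illustrative sketch (not part of the original module): checks the energy
# claim made in the docstring above for an orthogonal wavelet when both
# ``norm=True`` and ``trim_approx=True``. The signal length (32) is a
# multiple of ``2**level`` as required.
def _example_swt_energy_preservation():
    import numpy as np
    x = np.cos(np.linspace(0, 8 * np.pi, 32))
    coeffs = swt(x, 'db2', level=3, trim_approx=True, norm=True)
    # coeffs is [cA3, cD3, cD2, cD1]
    coeff_energy = sum(np.sum(c ** 2) for c in coeffs)
    return np.allclose(np.sum(x ** 2), coeff_energy)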
|
||||
|
||||
|
||||
def iswt(coeffs, wavelet, norm=False):
|
||||
"""
|
||||
Multilevel 1D inverse discrete stationary wavelet transform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
coeffs : array_like
|
||||
Coefficients list of tuples::
|
||||
|
||||
[(cAn, cDn), ..., (cA2, cD2), (cA1, cD1)]
|
||||
|
||||
where cA is approximation, cD is details. Index 1 corresponds to
|
||||
``start_level`` from ``pywt.swt``.
|
||||
wavelet : Wavelet object or name string
|
||||
Wavelet to use
|
||||
norm : bool, optional
|
||||
Controls the normalization used by the inverse transform. This must
|
||||
be set equal to the value that was used by ``pywt.swt`` to preserve the
|
||||
energy of a round-trip transform.
|
||||
|
||||
Returns
|
||||
-------
|
||||
1D array of reconstructed data.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import pywt
|
||||
>>> coeffs = pywt.swt([1,2,3,4,5,6,7,8], 'db2', level=2)
|
||||
>>> pywt.iswt(coeffs, 'db2')
|
||||
array([ 1., 2., 3., 4., 5., 6., 7., 8.])
|
||||
"""
|
||||
# copy to avoid modification of input data
|
||||
|
||||
# If swt was called with trim_approx=False, first element is a tuple
|
||||
trim_approx = not isinstance(coeffs[0], (tuple, list))
|
||||
|
||||
if trim_approx:
|
||||
cA = coeffs[0]
|
||||
coeffs = coeffs[1:]
|
||||
else:
|
||||
cA = coeffs[0][0]
|
||||
|
||||
dt = _check_dtype(cA)
|
||||
output = np.array(cA, dtype=dt, copy=True)
|
||||
if not _have_c99_complex and np.iscomplexobj(output):
|
||||
# compute real and imaginary separately then combine
|
||||
if trim_approx:
|
||||
coeffs_real = [c.real for c in coeffs]
|
||||
coeffs_imag = [c.imag for c in coeffs]
|
||||
else:
|
||||
coeffs_real = [(cA.real, cD.real) for (cA, cD) in coeffs]
|
||||
coeffs_imag = [(cA.imag, cD.imag) for (cA, cD) in coeffs]
|
||||
return iswt(coeffs_real, wavelet) + 1j*iswt(coeffs_imag, wavelet)
|
||||
|
||||
# num_levels, equivalent to the decomposition level, n
|
||||
num_levels = len(coeffs)
|
||||
wavelet = _as_wavelet(wavelet)
|
||||
if norm:
|
||||
wavelet = _rescale_wavelet_filterbank(wavelet, np.sqrt(2))
|
||||
mode = Modes.from_object('periodization')
|
||||
for j in range(num_levels, 0, -1):
|
||||
step_size = int(pow(2, j-1))
|
||||
last_index = step_size
|
||||
if trim_approx:
|
||||
cD = coeffs[-j]
|
||||
else:
|
||||
_, cD = coeffs[-j]
|
||||
cD = np.asarray(cD, dtype=_check_dtype(cD))
|
||||
if cD.dtype != output.dtype:
|
||||
# upcast to a common dtype (float64 or complex128)
|
||||
if output.dtype.kind == 'c' or cD.dtype.kind == 'c':
|
||||
dtype = np.complex128
|
||||
else:
|
||||
dtype = np.float64
|
||||
output = np.asarray(output, dtype=dtype)
|
||||
cD = np.asarray(cD, dtype=dtype)
|
||||
for first in range(last_index): # 0 to last_index - 1
|
||||
|
||||
# Getting the indices that we will transform
|
||||
indices = np.arange(first, len(cD), step_size)
|
||||
|
||||
# select the even indices
|
||||
even_indices = indices[0::2]
|
||||
# select the odd indices
|
||||
odd_indices = indices[1::2]
|
||||
|
||||
# perform the inverse dwt on the selected indices,
|
||||
# making sure to use periodic boundary conditions
|
||||
# Note: indexing with an array of ints returns a contiguous
|
||||
# copy as required by idwt_single.
|
||||
x1 = idwt_single(output[even_indices],
|
||||
cD[even_indices],
|
||||
wavelet, mode)
|
||||
x2 = idwt_single(output[odd_indices],
|
||||
cD[odd_indices],
|
||||
wavelet, mode)
|
||||
|
||||
# perform a circular shift right
|
||||
x2 = np.roll(x2, 1)
|
||||
|
||||
# average and insert into the correct indices
|
||||
output[indices] = (x1 + x2)/2.
|
||||
|
||||
return output
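

# Illustrative sketch (not part of the original module): swt/iswt round trips
# in both output layouts; when ``norm=True`` was used for the forward
# transform, the inverse must be called with ``norm=True`` as well.
def _example_iswt_roundtrip():
    import numpy as np
    x = np.arange(16, dtype=np.float64)
    rec_plain = iswt(swt(x, 'db1', level=2), 'db1')
    rec_norm = iswt(swt(x, 'db1', level=2, trim_approx=True, norm=True),
                    'db1', norm=True)
    return np.allclose(x, rec_plain) and np.allclose(x, rec_norm)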
|
||||
|
||||
|
||||
def swt2(data, wavelet, level, start_level=0, axes=(-2, -1),
|
||||
trim_approx=False, norm=False):
|
||||
"""
|
||||
Multilevel 2D stationary wavelet transform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : array_like
|
||||
2D array with input data
|
||||
wavelet : Wavelet object or name string, or 2-tuple of wavelets
|
||||
Wavelet to use. This can also be a tuple of wavelets to apply per
|
||||
axis in ``axes``.
|
||||
level : int
|
||||
The number of decomposition steps to perform.
|
||||
start_level : int, optional
|
||||
The level at which the decomposition will start (default: 0)
|
||||
axes : 2-tuple of ints, optional
|
||||
Axes over which to compute the SWT. Repeated elements are not allowed.
|
||||
trim_approx : bool, optional
|
||||
If True, approximation coefficients at the final level are retained.
|
||||
norm : bool, optional
|
||||
If True, transform is normalized so that the energy of the coefficients
|
||||
will be equal to the energy of ``data``. In other words,
|
||||
``np.linalg.norm(data.ravel())`` will equal the norm of the
|
||||
concatenated transform coefficients when ``trim_approx`` is True.
|
||||
|
||||
Returns
|
||||
-------
|
||||
coeffs : list
|
||||
Approximation and details coefficients (for ``start_level = m``).
|
||||
        If ``trim_approx`` is ``False``, approximation coefficients are
|
||||
retained for all levels::
|
||||
|
||||
[
|
||||
(cA_m+level,
|
||||
(cH_m+level, cV_m+level, cD_m+level)
|
||||
),
|
||||
...,
|
||||
(cA_m+1,
|
||||
(cH_m+1, cV_m+1, cD_m+1)
|
||||
),
|
||||
(cA_m,
|
||||
(cH_m, cV_m, cD_m)
|
||||
)
|
||||
]
|
||||
|
||||
where cA is approximation, cH is horizontal details, cV is
|
||||
vertical details, cD is diagonal details and m is ``start_level``.
|
||||
|
||||
        If ``trim_approx`` is ``True``, approximation coefficients are only
|
||||
retained at the final level of decomposition. This matches the format
|
||||
used by ``pywt.wavedec2``::
|
||||
|
||||
[
|
||||
cA_m+level,
|
||||
(cH_m+level, cV_m+level, cD_m+level),
|
||||
...,
|
||||
(cH_m+1, cV_m+1, cD_m+1),
|
||||
(cH_m, cV_m, cD_m),
|
||||
]
|
||||
|
||||
Notes
|
||||
-----
|
||||
The implementation here follows the "algorithm a-trous" and requires that
|
||||
the signal length along the transformed axes be a multiple of ``2**level``.
|
||||
If this is not the case, the user should pad up to an appropriate size
|
||||
using a function such as ``numpy.pad``.
|
||||
|
||||
    A primary benefit of this transform in comparison to its decimated
    counterpart (``pywt.wavedecn``) is that it is shift-invariant. This comes
    at the cost of redundancy in the transform (the size of the output
    coefficients is larger than that of the input).
|
||||
|
||||
When the following three conditions are true:
|
||||
|
||||
1. The wavelet is orthogonal
|
||||
2. ``swt2`` is called with ``norm=True``
|
||||
3. ``swt2`` is called with ``trim_approx=True``
|
||||
|
||||
the transform has the following additional properties that may be
|
||||
desirable in applications:
|
||||
|
||||
1. energy is conserved
|
||||
2. variance is partitioned across scales
|
||||
|
||||
"""
|
||||
axes = tuple(axes)
|
||||
data = np.asarray(data)
|
||||
if len(axes) != 2:
|
||||
raise ValueError("Expected 2 axes")
|
||||
if len(axes) != len(set(axes)):
|
||||
raise ValueError("The axes passed to swt2 must be unique.")
|
||||
if data.ndim < len(np.unique(axes)):
|
||||
raise ValueError("Input array has fewer dimensions than the specified "
|
||||
"axes")
|
||||
|
||||
coefs = swtn(data, wavelet, level, start_level, axes, trim_approx, norm)
|
||||
ret = []
|
||||
if trim_approx:
|
||||
ret.append(coefs[0])
|
||||
coefs = coefs[1:]
|
||||
for c in coefs:
|
||||
if trim_approx:
|
||||
ret.append((c['da'], c['ad'], c['dd']))
|
||||
else:
|
||||
ret.append((c['aa'], (c['da'], c['ad'], c['dd'])))
|
||||
return ret
|
||||
|
||||
|
||||
def iswt2(coeffs, wavelet, norm=False):
|
||||
"""
|
||||
Multilevel 2D inverse discrete stationary wavelet transform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
coeffs : list
|
||||
Approximation and details coefficients::
|
||||
|
||||
[
|
||||
(cA_n,
|
||||
(cH_n, cV_n, cD_n)
|
||||
),
|
||||
...,
|
||||
(cA_2,
|
||||
(cH_2, cV_2, cD_2)
|
||||
),
|
||||
(cA_1,
|
||||
(cH_1, cV_1, cD_1)
|
||||
)
|
||||
]
|
||||
|
||||
where cA is approximation, cH is horizontal details, cV is
|
||||
vertical details, cD is diagonal details and n is the number of
|
||||
levels. Index 1 corresponds to ``start_level`` from ``pywt.swt2``.
|
||||
wavelet : Wavelet object or name string, or 2-tuple of wavelets
|
||||
Wavelet to use. This can also be a 2-tuple of wavelets to apply per
|
||||
axis.
|
||||
norm : bool, optional
|
||||
Controls the normalization used by the inverse transform. This must
|
||||
be set equal to the value that was used by ``pywt.swt2`` to preserve
|
||||
the energy of a round-trip transform.
|
||||
|
||||
Returns
|
||||
-------
|
||||
2D array of reconstructed data.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import pywt
|
||||
>>> coeffs = pywt.swt2([[1,2,3,4],[5,6,7,8],
|
||||
... [9,10,11,12],[13,14,15,16]],
|
||||
... 'db1', level=2)
|
||||
>>> pywt.iswt2(coeffs, 'db1')
|
||||
array([[ 1., 2., 3., 4.],
|
||||
[ 5., 6., 7., 8.],
|
||||
[ 9., 10., 11., 12.],
|
||||
[ 13., 14., 15., 16.]])
|
||||
|
||||
"""
|
||||
|
||||
# If swt was called with trim_approx=False, first element is a tuple
|
||||
trim_approx = not isinstance(coeffs[0], (tuple, list))
|
||||
if trim_approx:
|
||||
cA = coeffs[0]
|
||||
coeffs = coeffs[1:]
|
||||
else:
|
||||
cA = coeffs[0][0]
|
||||
|
||||
# copy to avoid modification of input data
|
||||
dt = _check_dtype(cA)
|
||||
output = np.array(cA, dtype=dt, copy=True)
|
||||
|
||||
if output.ndim != 2:
|
||||
raise ValueError(
|
||||
"iswt2 only supports 2D arrays. see iswtn for a general "
|
||||
"n-dimensionsal ISWT")
|
||||
# num_levels, equivalent to the decomposition level, n
|
||||
num_levels = len(coeffs)
|
||||
wavelets = _wavelets_per_axis(wavelet, axes=(0, 1))
|
||||
if norm:
|
||||
wavelets = [_rescale_wavelet_filterbank(wav, np.sqrt(2))
|
||||
for wav in wavelets]
|
||||
|
||||
for j in range(num_levels):
|
||||
step_size = int(pow(2, num_levels-j-1))
|
||||
last_index = step_size
|
||||
if trim_approx:
|
||||
(cH, cV, cD) = coeffs[j]
|
||||
else:
|
||||
_, (cH, cV, cD) = coeffs[j]
|
||||
# We are going to assume cH, cV, and cD are of equal size
|
||||
if (cH.shape != cV.shape) or (cH.shape != cD.shape):
|
||||
raise RuntimeError(
|
||||
"Mismatch in shape of intermediate coefficient arrays")
|
||||
|
||||
# make sure output shares the common dtype
|
||||
# (conversion of dtype for individual coeffs is handled within idwt2 )
|
||||
common_dtype = np.result_type(*(
|
||||
[dt, ] + [_check_dtype(c) for c in [cH, cV, cD]]))
|
||||
if output.dtype != common_dtype:
|
||||
output = output.astype(common_dtype)
|
||||
|
||||
for first_h in range(last_index): # 0 to last_index - 1
|
||||
for first_w in range(last_index): # 0 to last_index - 1
|
||||
# Getting the indices that we will transform
|
||||
indices_h = slice(first_h, cH.shape[0], step_size)
|
||||
indices_w = slice(first_w, cH.shape[1], step_size)
|
||||
|
||||
even_idx_h = slice(first_h, cH.shape[0], 2*step_size)
|
||||
even_idx_w = slice(first_w, cH.shape[1], 2*step_size)
|
||||
odd_idx_h = slice(first_h + step_size, cH.shape[0], 2*step_size)
|
||||
odd_idx_w = slice(first_w + step_size, cH.shape[1], 2*step_size)
|
||||
|
||||
# perform the inverse dwt on the selected indices,
|
||||
# making sure to use periodic boundary conditions
|
||||
x1 = idwt2((output[even_idx_h, even_idx_w],
|
||||
(cH[even_idx_h, even_idx_w],
|
||||
cV[even_idx_h, even_idx_w],
|
||||
cD[even_idx_h, even_idx_w])),
|
||||
wavelets, 'periodization')
|
||||
x2 = idwt2((output[even_idx_h, odd_idx_w],
|
||||
(cH[even_idx_h, odd_idx_w],
|
||||
cV[even_idx_h, odd_idx_w],
|
||||
cD[even_idx_h, odd_idx_w])),
|
||||
wavelets, 'periodization')
|
||||
x3 = idwt2((output[odd_idx_h, even_idx_w],
|
||||
(cH[odd_idx_h, even_idx_w],
|
||||
cV[odd_idx_h, even_idx_w],
|
||||
cD[odd_idx_h, even_idx_w])),
|
||||
wavelets, 'periodization')
|
||||
x4 = idwt2((output[odd_idx_h, odd_idx_w],
|
||||
(cH[odd_idx_h, odd_idx_w],
|
||||
cV[odd_idx_h, odd_idx_w],
|
||||
cD[odd_idx_h, odd_idx_w])),
|
||||
wavelets, 'periodization')
|
||||
|
||||
                # perform circular shifts
|
||||
x2 = np.roll(x2, 1, axis=1)
|
||||
x3 = np.roll(x3, 1, axis=0)
|
||||
x4 = np.roll(x4, 1, axis=0)
|
||||
x4 = np.roll(x4, 1, axis=1)
|
||||
output[indices_h, indices_w] = (x1 + x2 + x3 + x4) / 4
|
||||
|
||||
return output
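

# Illustrative sketch (not part of the original module): unpacking the 2D
# coefficient layout documented for ``swt2`` and inverting it with ``iswt2``.
# The stationary transform is undecimated, so every subband has the shape of
# the input image.
def _example_swt2_roundtrip():
    import numpy as np
    img = np.arange(64, dtype=np.float64).reshape(8, 8)
    coeffs = swt2(img, 'db1', level=2)
    # with the default trim_approx=False each entry is (cA, (cH, cV, cD))
    cA, (cH, cV, cD) = coeffs[0]
    rec = iswt2(coeffs, 'db1')
    return cA.shape == img.shape and np.allclose(img, rec)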
|
||||
|
||||
|
||||
def swtn(data, wavelet, level, start_level=0, axes=None, trim_approx=False,
|
||||
norm=False):
|
||||
"""
|
||||
n-dimensional stationary wavelet transform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : array_like
|
||||
n-dimensional array with input data.
|
||||
wavelet : Wavelet object or name string, or tuple of wavelets
|
||||
Wavelet to use. This can also be a tuple of wavelets to apply per
|
||||
axis in ``axes``.
|
||||
level : int
|
||||
The number of decomposition steps to perform.
|
||||
start_level : int, optional
|
||||
The level at which the decomposition will start (default: 0)
|
||||
axes : sequence of ints, optional
|
||||
Axes over which to compute the SWT. A value of ``None`` (the
|
||||
default) selects all axes. Axes may not be repeated.
|
||||
trim_approx : bool, optional
|
||||
If True, approximation coefficients at the final level are retained.
|
||||
norm : bool, optional
|
||||
If True, transform is normalized so that the energy of the coefficients
|
||||
will be equal to the energy of ``data``. In other words,
|
||||
``np.linalg.norm(data.ravel())`` will equal the norm of the
|
||||
concatenated transform coefficients when ``trim_approx`` is True.
|
||||
|
||||
Returns
|
||||
-------
|
||||
[{coeffs_level_n}, ..., {coeffs_level_1}]: list of dict
|
||||
Results for each level are arranged in a dictionary, where the key
|
||||
specifies the transform type on each dimension and value is a
|
||||
n-dimensional coefficients array.
|
||||
|
||||
For example, for a 2D case the result at a given level will look
|
||||
something like this::
|
||||
|
||||
{'aa': <coeffs> # A(LL) - approx. on 1st dim, approx. on 2nd dim
|
||||
'ad': <coeffs> # V(LH) - approx. on 1st dim, det. on 2nd dim
|
||||
'da': <coeffs> # H(HL) - det. on 1st dim, approx. on 2nd dim
|
||||
'dd': <coeffs> # D(HH) - det. on 1st dim, det. on 2nd dim
|
||||
}
|
||||
|
||||
For user-specified ``axes``, the order of the characters in the
|
||||
dictionary keys map to the specified ``axes``.
|
||||
|
||||
If ``trim_approx`` is ``True``, the first element of the list contains
|
||||
the array of approximation coefficients from the final level of
|
||||
decomposition, while the remaining coefficient dictionaries contain
|
||||
only detail coefficients. This matches the behavior of `pywt.wavedecn`.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The implementation here follows the "algorithm a-trous" and requires that
|
||||
the signal length along the transformed axes be a multiple of ``2**level``.
|
||||
If this is not the case, the user should pad up to an appropriate size
|
||||
using a function such as ``numpy.pad``.
|
||||
|
||||
    A primary benefit of this transform in comparison to its decimated
    counterpart (``pywt.wavedecn``) is that it is shift-invariant. This comes
    at the cost of redundancy in the transform (the size of the output
    coefficients is larger than that of the input).
|
||||
|
||||
When the following three conditions are true:
|
||||
|
||||
1. The wavelet is orthogonal
|
||||
2. ``swtn`` is called with ``norm=True``
|
||||
3. ``swtn`` is called with ``trim_approx=True``
|
||||
|
||||
the transform has the following additional properties that may be
|
||||
desirable in applications:
|
||||
|
||||
1. energy is conserved
|
||||
2. variance is partitioned across scales
|
||||
|
||||
"""
|
||||
data = np.asarray(data)
|
||||
if not _have_c99_complex and np.iscomplexobj(data):
|
||||
        real = swtn(data.real, wavelet, level, start_level, axes, trim_approx,
                    norm)
        imag = swtn(data.imag, wavelet, level, start_level, axes, trim_approx,
                    norm)
|
||||
if trim_approx:
|
||||
cplx = [real[0] + 1j * imag[0]]
|
||||
offset = 1
|
||||
else:
|
||||
cplx = []
|
||||
offset = 0
|
||||
for rdict, idict in zip(real[offset:], imag[offset:]):
|
||||
cplx.append(
|
||||
dict((k, rdict[k] + 1j * idict[k]) for k in rdict.keys()))
|
||||
return cplx
|
||||
|
||||
if data.dtype == np.dtype('object'):
|
||||
raise TypeError("Input must be a numeric array-like")
|
||||
if data.ndim < 1:
|
||||
raise ValueError("Input data must be at least 1D")
|
||||
|
||||
if axes is None:
|
||||
axes = range(data.ndim)
|
||||
axes = [a + data.ndim if a < 0 else a for a in axes]
|
||||
if len(axes) != len(set(axes)):
|
||||
raise ValueError("The axes passed to swtn must be unique.")
|
||||
num_axes = len(axes)
|
||||
|
||||
wavelets = _wavelets_per_axis(wavelet, axes)
|
||||
if norm:
|
||||
if not np.all([wav.orthogonal for wav in wavelets]):
|
||||
warnings.warn(
|
||||
"norm=True, but the wavelets used are not orthogonal: \n"
|
||||
"\tThe conditions for energy preservation are not satisfied.")
|
||||
wavelets = [_rescale_wavelet_filterbank(wav, 1/np.sqrt(2))
|
||||
for wav in wavelets]
|
||||
ret = []
|
||||
for i in range(start_level, start_level + level):
|
||||
coeffs = [('', data)]
|
||||
for axis, wavelet in zip(axes, wavelets):
|
||||
new_coeffs = []
|
||||
for subband, x in coeffs:
|
||||
cA, cD = _swt_axis(x, wavelet, level=1, start_level=i,
|
||||
axis=axis)[0]
|
||||
new_coeffs.extend([(subband + 'a', cA),
|
||||
(subband + 'd', cD)])
|
||||
coeffs = new_coeffs
|
||||
|
||||
coeffs = dict(coeffs)
|
||||
ret.append(coeffs)
|
||||
|
||||
# data for the next level is the approximation coeffs from this level
|
||||
data = coeffs['a' * num_axes]
|
||||
if trim_approx:
|
||||
coeffs.pop('a' * num_axes)
|
||||
if trim_approx:
|
||||
ret.append(data)
|
||||
ret.reverse()
|
||||
return ret
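

# Illustrative sketch (not part of the original module): the per-level
# dictionaries returned by ``swtn`` for a 2D input use the subband keys
# described in the docstring above ('aa', 'ad', 'da', 'dd').
def _example_swtn_keys():
    import numpy as np
    data = np.ones((4, 4), dtype=np.float64)
    level1 = swtn(data, 'db1', level=1)[0]
    return sorted(level1.keys()) == ['aa', 'ad', 'da', 'dd']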
|
||||
|
||||
|
||||
def iswtn(coeffs, wavelet, axes=None, norm=False):
|
||||
"""
|
||||
Multilevel nD inverse discrete stationary wavelet transform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
coeffs : list
|
||||
[{coeffs_level_n}, ..., {coeffs_level_1}]: list of dict
|
||||
wavelet : Wavelet object or name string, or tuple of wavelets
|
||||
Wavelet to use. This can also be a tuple of wavelets to apply per
|
||||
axis in ``axes``.
|
||||
axes : sequence of ints, optional
|
||||
Axes over which to compute the inverse SWT. Axes may not be repeated.
|
||||
The default is ``None``, which means transform all axes
|
||||
(``axes = range(data.ndim)``).
|
||||
norm : bool, optional
|
||||
Controls the normalization used by the inverse transform. This must
|
||||
be set equal to the value that was used by ``pywt.swtn`` to preserve
|
||||
the energy of a round-trip transform.
|
||||
|
||||
Returns
|
||||
-------
|
||||
nD array of reconstructed data.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import pywt
|
||||
>>> coeffs = pywt.swtn([[1,2,3,4],[5,6,7,8],
|
||||
... [9,10,11,12],[13,14,15,16]],
|
||||
... 'db1', level=2)
|
||||
>>> pywt.iswtn(coeffs, 'db1')
|
||||
array([[ 1., 2., 3., 4.],
|
||||
[ 5., 6., 7., 8.],
|
||||
[ 9., 10., 11., 12.],
|
||||
[ 13., 14., 15., 16.]])
|
||||
|
||||
"""
|
||||
|
||||
# key length matches the number of axes transformed
|
||||
ndim_transform = max(len(key) for key in coeffs[-1].keys())
|
||||
|
||||
trim_approx = not isinstance(coeffs[0], dict)
|
||||
if trim_approx:
|
||||
cA = coeffs[0]
|
||||
coeffs = coeffs[1:]
|
||||
else:
|
||||
cA = coeffs[0]['a'*ndim_transform]
|
||||
|
||||
# copy to avoid modification of input data
|
||||
dt = _check_dtype(cA)
|
||||
output = np.array(cA, dtype=dt, copy=True)
|
||||
ndim = output.ndim
|
||||
|
||||
if axes is None:
|
||||
axes = range(output.ndim)
|
||||
axes = [a + ndim if a < 0 else a for a in axes]
|
||||
if len(axes) != len(set(axes)):
|
||||
raise ValueError("The axes passed to swtn must be unique.")
|
||||
if ndim_transform != len(axes):
|
||||
raise ValueError("The number of axes used in iswtn must match the "
|
||||
"number of dimensions transformed in swtn.")
|
||||
|
||||
# num_levels, equivalent to the decomposition level, n
|
||||
num_levels = len(coeffs)
|
||||
wavelets = _wavelets_per_axis(wavelet, axes)
|
||||
if norm:
|
||||
wavelets = [_rescale_wavelet_filterbank(wav, np.sqrt(2))
|
||||
for wav in wavelets]
|
||||
|
||||
# initialize various slice objects used in the loops below
|
||||
# these will remain slice(None) only on axes that aren't transformed
|
||||
indices = [slice(None), ]*ndim
|
||||
even_indices = [slice(None), ]*ndim
|
||||
odd_indices = [slice(None), ]*ndim
|
||||
odd_even_slices = [slice(None), ]*ndim
|
||||
|
||||
for j in range(num_levels):
|
||||
step_size = int(pow(2, num_levels-j-1))
|
||||
last_index = step_size
|
||||
if not trim_approx:
|
||||
a = coeffs[j].pop('a'*ndim_transform) # will restore later
|
||||
details = coeffs[j]
|
||||
# make sure dtype matches the coarsest level approximation coefficients
|
||||
common_dtype = np.result_type(*(
|
||||
[dt, ] + [v.dtype for v in details.values()]))
|
||||
if output.dtype != common_dtype:
|
||||
output = output.astype(common_dtype)
|
||||
|
||||
# We assume all coefficient arrays are of equal size
|
||||
shapes = [v.shape for k, v in details.items()]
|
||||
if len(set(shapes)) != 1:
|
||||
raise RuntimeError(
|
||||
"Mismatch in shape of intermediate coefficient arrays")
|
||||
|
||||
# shape of a single coefficient array, excluding non-transformed axes
|
||||
coeff_trans_shape = tuple([shapes[0][ax] for ax in axes])
|
||||
|
||||
# nested loop over all combinations of axis offsets at this level
|
||||
for firsts in product(*([range(last_index), ]*ndim_transform)):
|
||||
for first, sh, ax in zip(firsts, coeff_trans_shape, axes):
|
||||
indices[ax] = slice(first, sh, step_size)
|
||||
even_indices[ax] = slice(first, sh, 2*step_size)
|
||||
odd_indices[ax] = slice(first+step_size, sh, 2*step_size)
|
||||
|
||||
            # nested loop over all combinations of odd/even indices
|
||||
approx = output.copy()
|
||||
output[tuple(indices)] = 0
|
||||
ntransforms = 0
|
||||
for odds in product(*([(0, 1), ]*ndim_transform)):
|
||||
for o, ax in zip(odds, axes):
|
||||
if o:
|
||||
odd_even_slices[ax] = odd_indices[ax]
|
||||
else:
|
||||
odd_even_slices[ax] = even_indices[ax]
|
||||
# extract the odd/even indices for all detail coefficients
|
||||
details_slice = {}
|
||||
for key, value in details.items():
|
||||
details_slice[key] = value[tuple(odd_even_slices)]
|
||||
details_slice['a'*ndim_transform] = approx[
|
||||
tuple(odd_even_slices)]
|
||||
|
||||
# perform the inverse dwt on the selected indices,
|
||||
# making sure to use periodic boundary conditions
|
||||
x = idwtn(details_slice, wavelets, 'periodization', axes=axes)
|
||||
for o, ax in zip(odds, axes):
|
||||
# circular shift along any odd indexed axis
|
||||
if o:
|
||||
x = np.roll(x, 1, axis=ax)
|
||||
output[tuple(indices)] += x
|
||||
ntransforms += 1
|
||||
output[tuple(indices)] /= ntransforms # normalize
|
||||
if not trim_approx:
|
||||
coeffs[j]['a'*ndim_transform] = a # restore approx coeffs to dict
|
||||
return output
|
250
venv/Lib/site-packages/pywt/_thresholding.py
Normal file
@@ -0,0 +1,250 @@
# Copyright (c) 2006-2012 Filip Wasilewski <http://en.ig.ma/>
|
||||
# Copyright (c) 2012-2016 The PyWavelets Developers
|
||||
# <https://github.com/PyWavelets/pywt>
|
||||
# See COPYING for license details.
|
||||
|
||||
"""
|
||||
The thresholding helper module implements the most popular signal thresholding
|
||||
functions.
|
||||
"""
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
import numpy as np
|
||||
|
||||
__all__ = ['threshold', 'threshold_firm']
|
||||
|
||||
|
||||
def soft(data, value, substitute=0):
|
||||
data = np.asarray(data)
|
||||
magnitude = np.absolute(data)
|
||||
|
||||
with np.errstate(divide='ignore'):
|
||||
# divide by zero okay as np.inf values get clipped, so ignore warning.
|
||||
thresholded = (1 - value/magnitude)
|
||||
thresholded.clip(min=0, max=None, out=thresholded)
|
||||
thresholded = data * thresholded
|
||||
|
||||
if substitute == 0:
|
||||
return thresholded
|
||||
else:
|
||||
cond = np.less(magnitude, value)
|
||||
return np.where(cond, substitute, thresholded)
|
||||
|
||||
|
||||
def nn_garrote(data, value, substitute=0):
|
||||
"""Non-negative Garrote."""
|
||||
data = np.asarray(data)
|
||||
magnitude = np.absolute(data)
|
||||
|
||||
with np.errstate(divide='ignore'):
|
||||
# divide by zero okay as np.inf values get clipped, so ignore warning.
|
||||
thresholded = (1 - value**2/magnitude**2)
|
||||
thresholded.clip(min=0, max=None, out=thresholded)
|
||||
thresholded = data * thresholded
|
||||
|
||||
if substitute == 0:
|
||||
return thresholded
|
||||
else:
|
||||
cond = np.less(magnitude, value)
|
||||
return np.where(cond, substitute, thresholded)
|
||||
|
||||
|
||||
def hard(data, value, substitute=0):
|
||||
data = np.asarray(data)
|
||||
cond = np.less(np.absolute(data), value)
|
||||
return np.where(cond, substitute, data)
|
||||
|
||||
|
||||
def greater(data, value, substitute=0):
|
||||
data = np.asarray(data)
|
||||
if np.iscomplexobj(data):
|
||||
raise ValueError("greater thresholding only supports real data")
|
||||
return np.where(np.less(data, value), substitute, data)
|
||||
|
||||
|
||||
def less(data, value, substitute=0):
|
||||
data = np.asarray(data)
|
||||
if np.iscomplexobj(data):
|
||||
raise ValueError("less thresholding only supports real data")
|
||||
return np.where(np.greater(data, value), substitute, data)
|
||||
|
||||
|
||||
thresholding_options = {'soft': soft,
|
||||
'hard': hard,
|
||||
'greater': greater,
|
||||
'less': less,
|
||||
'garrote': nn_garrote,
|
||||
# misspelled garrote for backwards compatibility
|
||||
'garotte': nn_garrote,
|
||||
}
|
||||
|
||||
|
||||
def threshold(data, value, mode='soft', substitute=0):
|
||||
"""
|
||||
Thresholds the input data depending on the mode argument.
|
||||
|
||||
    In ``soft`` thresholding [1]_, data values with absolute value less than
    `value` are replaced with `substitute`. Data values with absolute value
    greater than or equal to the thresholding value are shrunk toward zero
    by `value`. In other words, the new value is
    ``data/np.abs(data) * np.maximum(np.abs(data) - value, 0)``.
|
||||
|
||||
    In ``hard`` thresholding, data values with absolute value less than
    `value` are replaced with `substitute`. Data values with absolute value
    greater than or equal to the thresholding value stay untouched.
|
||||
|
||||
``garrote`` corresponds to the Non-negative garrote threshold [2]_, [3]_.
|
||||
It is intermediate between ``hard`` and ``soft`` thresholding. It behaves
|
||||
like soft thresholding for small data values and approaches hard
|
||||
thresholding for large data values.
|
||||
|
||||
In ``greater`` thresholding, the data is replaced with `substitute` where
|
||||
data is below the thresholding value. Greater data values pass untouched.
|
||||
|
||||
In ``less`` thresholding, the data is replaced with `substitute` where data
|
||||
is above the thresholding value. Lesser data values pass untouched.
|
||||
|
||||
Both ``hard`` and ``soft`` thresholding also support complex-valued data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : array_like
|
||||
Numeric data.
|
||||
value : scalar
|
||||
Thresholding value.
|
||||
mode : {'soft', 'hard', 'garrote', 'greater', 'less'}
|
||||
Decides the type of thresholding to be applied on input data. Default
|
||||
is 'soft'.
|
||||
substitute : float, optional
|
||||
Substitute value (default: 0).
|
||||
|
||||
Returns
|
||||
-------
|
||||
output : array
|
||||
Thresholded array.
|
||||
|
||||
See Also
|
||||
--------
|
||||
threshold_firm
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] D.L. Donoho and I.M. Johnstone. Ideal Spatial Adaptation via
|
||||
Wavelet Shrinkage. Biometrika. Vol. 81, No. 3, pp.425-455, 1994.
|
||||
DOI:10.1093/biomet/81.3.425
|
||||
.. [2] L. Breiman. Better Subset Regression Using the Nonnegative Garrote.
|
||||
Technometrics, Vol. 37, pp. 373-384, 1995.
|
||||
DOI:10.2307/1269730
|
||||
.. [3] H-Y. Gao. Wavelet Shrinkage Denoising Using the Non-Negative
|
||||
Garrote. Journal of Computational and Graphical Statistics Vol. 7,
|
||||
No. 4, pp.469-488. 1998.
|
||||
DOI:10.1080/10618600.1998.10474789
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import numpy as np
|
||||
>>> import pywt
|
||||
>>> data = np.linspace(1, 4, 7)
|
||||
>>> data
|
||||
array([ 1. , 1.5, 2. , 2.5, 3. , 3.5, 4. ])
|
||||
>>> pywt.threshold(data, 2, 'soft')
|
||||
array([ 0. , 0. , 0. , 0.5, 1. , 1.5, 2. ])
|
||||
>>> pywt.threshold(data, 2, 'hard')
|
||||
array([ 0. , 0. , 2. , 2.5, 3. , 3.5, 4. ])
|
||||
>>> pywt.threshold(data, 2, 'garrote')
|
||||
array([ 0. , 0. , 0. , 0.9 , 1.66666667,
|
||||
2.35714286, 3. ])
|
||||
>>> pywt.threshold(data, 2, 'greater')
|
||||
array([ 0. , 0. , 2. , 2.5, 3. , 3.5, 4. ])
|
||||
>>> pywt.threshold(data, 2, 'less')
|
||||
array([ 1. , 1.5, 2. , 0. , 0. , 0. , 0. ])
|
||||
|
||||
"""
|
||||
|
||||
try:
|
||||
return thresholding_options[mode](data, value, substitute)
|
||||
except KeyError:
|
||||
# Make sure error is always identical by sorting keys
|
||||
keys = ("'{0}'".format(key) for key in
|
||||
sorted(thresholding_options.keys()))
|
||||
raise ValueError("The mode parameter only takes values from: {0}."
|
||||
.format(', '.join(keys)))
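

# Illustrative sketch (not part of the original module): verifies the soft
# thresholding identity quoted in the docstring above on a small array with
# no zero entries.
def _example_soft_threshold_identity():
    import numpy as np
    data = np.array([-3.0, -1.5, -0.5, 0.5, 1.5, 3.0])
    expected = data / np.abs(data) * np.maximum(np.abs(data) - 1.0, 0)
    return np.allclose(threshold(data, 1.0, mode='soft'), expected)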
|
||||
|
||||
|
||||
def threshold_firm(data, value_low, value_high):
|
||||
"""Firm threshold.
|
||||
|
||||
The approach is intermediate between soft and hard thresholding [1]_. It
|
||||
behaves the same as soft-thresholding for values below `value_low` and
|
||||
    the same as hard-thresholding for values above `value_high`. For
|
||||
intermediate values, the thresholded value is in between that corresponding
|
||||
to soft or hard thresholding.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : array-like
|
||||
The data to threshold. This can be either real or complex-valued.
|
||||
value_low : float
|
||||
        Any values smaller than `value_low` will be set to zero.
|
||||
value_high : float
|
||||
Any values larger than `value_high` will not be modified.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This thresholding technique is also known as semi-soft thresholding [2]_.
|
||||
|
||||
    For each value `x` in `data`, this function computes::
|
||||
|
||||
if np.abs(x) <= value_low:
|
||||
return 0
|
||||
elif np.abs(x) > value_high:
|
||||
return x
|
||||
elif value_low < np.abs(x) and np.abs(x) <= value_high:
|
||||
            return (x * value_high * (1 - value_low/np.abs(x)) /
                    (value_high - value_low))
|
||||
|
||||
``firm`` is a continuous function (like soft thresholding), but is
|
||||
unbiased for large values (like hard thresholding).
|
||||
|
||||
If ``value_high == value_low`` this function becomes hard-thresholding.
|
||||
If ``value_high`` is infinity, this function becomes soft-thresholding.
|
||||
|
||||
Returns
|
||||
-------
|
||||
val_new : array-like
|
||||
The values after firm thresholding at the specified thresholds.
|
||||
|
||||
See Also
|
||||
--------
|
||||
threshold
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] H.-Y. Gao and A.G. Bruce. Waveshrink with firm shrinkage.
|
||||
Statistica Sinica, Vol. 7, pp. 855-874, 1997.
|
||||
.. [2] A. Bruce and H-Y. Gao. WaveShrink: Shrinkage Functions and
|
||||
Thresholds. Proc. SPIE 2569, Wavelet Applications in Signal and
|
||||
Image Processing III, 1995.
|
||||
DOI:10.1117/12.217582
|
||||
"""
|
||||
|
||||
if value_low < 0:
|
||||
raise ValueError("value_low must be non-negative.")
|
||||
|
||||
if value_high < value_low:
|
||||
raise ValueError(
|
||||
"value_high must be greater than or equal to value_low.")
|
||||
|
||||
data = np.asarray(data)
|
||||
magnitude = np.absolute(data)
|
||||
with np.errstate(divide='ignore'):
|
||||
# divide by zero okay as np.inf values get clipped, so ignore warning.
|
||||
vdiff = value_high - value_low
|
||||
thresholded = value_high * (1 - value_low/magnitude) / vdiff
|
||||
thresholded.clip(min=0, max=None, out=thresholded)
|
||||
thresholded = data * thresholded
|
||||
|
||||
# restore hard-thresholding behavior for values > value_high
|
||||
large_vals = np.where(magnitude > value_high)
|
||||
if np.any(large_vals[0]):
|
||||
thresholded[large_vals] = data[large_vals]
|
||||
return thresholded
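

# Illustrative sketch (not part of the original module): spot-checks the
# limiting behaviour described above -- magnitudes at or below ``value_low``
# are zeroed, while magnitudes above ``value_high`` pass through unchanged.
def _example_threshold_firm_limits():
    import numpy as np
    data = np.array([0.5, 2.0, 3.0, 5.0, -5.0])
    out = threshold_firm(data, 2.0, 4.0)
    zeroed = np.allclose(out[:2], 0.0)
    passed_through = np.allclose(out[3:], data[3:])
    return zeroed and passed_through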
|
101
venv/Lib/site-packages/pywt/_utils.py
Normal file
@@ -0,0 +1,101 @@
# Copyright (c) 2017 The PyWavelets Developers
|
||||
# <https://github.com/PyWavelets/pywt>
|
||||
# See COPYING for license details.
|
||||
import inspect
|
||||
import numpy as np
|
||||
import sys
|
||||
from collections.abc import Iterable
|
||||
|
||||
from ._extensions._pywt import (Wavelet, ContinuousWavelet,
|
||||
DiscreteContinuousWavelet, Modes)
|
||||
|
||||
|
||||
# define string_types as in six for Python 2/3 compatibility
|
||||
if sys.version_info[0] == 3:
|
||||
string_types = str,
|
||||
else:
|
||||
string_types = basestring,
|
||||
|
||||
|
||||
def _as_wavelet(wavelet):
|
||||
"""Convert wavelet name to a Wavelet object."""
|
||||
if not isinstance(wavelet, (ContinuousWavelet, Wavelet)):
|
||||
wavelet = DiscreteContinuousWavelet(wavelet)
|
||||
if isinstance(wavelet, ContinuousWavelet):
|
||||
raise ValueError(
|
||||
"A ContinuousWavelet object was provided, but only discrete "
|
||||
"Wavelet objects are supported by this function. A list of all "
|
||||
"supported discrete wavelets can be obtained by running:\n"
|
||||
"print(pywt.wavelist(kind='discrete'))")
|
||||
return wavelet
|
||||
|
||||
|
||||
def _wavelets_per_axis(wavelet, axes):
|
||||
"""Initialize Wavelets for each axis to be transformed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wavelet : Wavelet or tuple of Wavelets
|
||||
        If a single Wavelet is provided, it will be used for all axes. Otherwise
|
||||
one Wavelet per axis must be provided.
|
||||
axes : list
|
||||
The tuple of axes to be transformed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
wavelets : list of Wavelet objects
|
||||
        A list of Wavelets equal in length to ``axes``.
|
||||
|
||||
"""
|
||||
axes = tuple(axes)
|
||||
if isinstance(wavelet, string_types + (Wavelet, )):
|
||||
# same wavelet on all axes
|
||||
wavelets = [_as_wavelet(wavelet), ] * len(axes)
|
||||
elif isinstance(wavelet, Iterable):
|
||||
# (potentially) unique wavelet per axis (e.g. for dual-tree DWT)
|
||||
if len(wavelet) == 1:
|
||||
wavelets = [_as_wavelet(wavelet[0]), ] * len(axes)
|
||||
else:
|
||||
if len(wavelet) != len(axes):
|
||||
raise ValueError((
|
||||
"The number of wavelets must match the number of axes "
|
||||
"to be transformed."))
|
||||
wavelets = [_as_wavelet(w) for w in wavelet]
|
||||
else:
|
||||
raise ValueError("wavelet must be a str, Wavelet or iterable")
|
||||
return wavelets
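

# Illustrative sketch (not part of the original module): the broadcasting
# behaviour of the helper above -- a single wavelet is repeated for every
# axis, while a sequence must match the number of axes.
def _example_wavelets_per_axis():
    same = _wavelets_per_axis('db2', axes=(0, 1))        # db2 on both axes
    mixed = _wavelets_per_axis(['db1', 'db2'], (0, 1))   # one wavelet per axis
    return [w.name for w in same], [w.name for w in mixed]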
|
||||
|
||||
|
||||
def _modes_per_axis(modes, axes):
|
||||
"""Initialize mode for each axis to be transformed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
modes : str or tuple of strings
|
||||
        If a single mode is provided, it will be used for all axes. Otherwise
|
||||
one mode per axis must be provided.
|
||||
axes : tuple
|
||||
The tuple of axes to be transformed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
    modes : list of Modes
        A list of Modes equal in length to ``axes``.
|
||||
|
||||
"""
|
||||
axes = tuple(axes)
|
||||
if isinstance(modes, string_types + (int, )):
|
||||
        # same mode on all axes
|
||||
modes = [Modes.from_object(modes), ] * len(axes)
|
||||
elif isinstance(modes, Iterable):
|
||||
if len(modes) == 1:
|
||||
modes = [Modes.from_object(modes[0]), ] * len(axes)
|
||||
else:
|
||||
            # (potentially) unique mode per axis
|
||||
if len(modes) != len(axes):
|
||||
raise ValueError(("The number of modes must match the number "
|
||||
"of axes to be transformed."))
|
||||
modes = [Modes.from_object(mode) for mode in modes]
|
||||
else:
|
||||
raise ValueError("modes must be a str, Mode enum or iterable")
|
||||
return modes
|
733
venv/Lib/site-packages/pywt/_wavelet_packets.py
Normal file
@@ -0,0 +1,733 @@
# Copyright (c) 2006-2012 Filip Wasilewski <http://en.ig.ma/>
|
||||
# Copyright (c) 2012-2016 The PyWavelets Developers
|
||||
# <https://github.com/PyWavelets/pywt>
|
||||
# See COPYING for license details.
|
||||
|
||||
"""1D and 2D Wavelet packet transform module."""
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
__all__ = ["BaseNode", "Node", "WaveletPacket", "Node2D", "WaveletPacket2D"]
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ._extensions._pywt import Wavelet, _check_dtype
|
||||
from ._dwt import dwt, idwt, dwt_max_level
|
||||
from ._multidim import dwt2, idwt2
|
||||
|
||||
|
||||
def get_graycode_order(level, x='a', y='d'):
|
||||
graycode_order = [x, y]
|
||||
for i in range(level - 1):
|
||||
graycode_order = [x + path for path in graycode_order] + \
|
||||
[y + path for path in graycode_order[::-1]]
|
||||
return graycode_order
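

# Illustrative sketch (not part of the original module): the gray-code
# ordering produced for two decomposition levels; this ordering is used to
# arrange wavelet packet nodes in natural frequency order.
def _example_graycode_order():
    # level 1 is ['a', 'd']; level 2 prefixes 'a', then 'd' over the reversed
    # previous order, giving ['aa', 'ad', 'dd', 'da']
    return get_graycode_order(2) == ['aa', 'ad', 'dd', 'da']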
|
||||
|
||||
|
||||
class BaseNode(object):
|
||||
"""
|
||||
BaseNode for wavelet packet 1D and 2D tree nodes.
|
||||
|
||||
The BaseNode is a base class for `Node` and `Node2D`.
|
||||
It should not be used directly unless creating a new transformation
|
||||
type. It is included here to document the common interface of 1D
|
||||
and 2D node and wavelet packet transform classes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
parent :
|
||||
Parent node. If parent is None then the node is considered detached
|
||||
        (i.e. root).
|
||||
data : 1D or 2D array
|
||||
Data associated with the node. 1D or 2D numeric array, depending on the
|
||||
transform type.
|
||||
node_name :
|
||||
A name identifying the coefficients type.
|
||||
See `Node.node_name` and `Node2D.node_name`
|
||||
for information on the accepted subnodes names.
|
||||
"""
|
||||
|
||||
# PART_LEN and PARTS attributes that define path tokens for node[] lookup
|
||||
# must be defined in subclasses.
|
||||
PART_LEN = None
|
||||
PARTS = None
|
||||
|
||||
def __init__(self, parent, data, node_name):
|
||||
self.parent = parent
|
||||
if parent is not None:
|
||||
self.wavelet = parent.wavelet
|
||||
self.mode = parent.mode
|
||||
self.level = parent.level + 1
|
||||
self._maxlevel = parent.maxlevel
|
||||
self.path = parent.path + node_name
|
||||
else:
|
||||
self.wavelet = None
|
||||
self.mode = None
|
||||
self.path = ""
|
||||
self.level = 0
|
||||
|
||||
# data - signal on level 0, coeffs on higher levels
|
||||
self.data = data
|
||||
# Need to retain original data size/shape so we can trim any excess
|
||||
# boundary coefficients from the inverse transform.
|
||||
if self.data is None:
|
||||
self._data_shape = None
|
||||
else:
|
||||
self._data_shape = np.asarray(data).shape
|
||||
|
||||
self._init_subnodes()
|
||||
|
||||
def _init_subnodes(self):
|
||||
for part in self.PARTS:
|
||||
self._set_node(part, None)
|
||||
|
||||
def _create_subnode(self, part, data=None, overwrite=True):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _create_subnode_base(self, node_cls, part, data=None, overwrite=True):
|
||||
self._validate_node_name(part)
|
||||
if not overwrite and self._get_node(part) is not None:
|
||||
return self._get_node(part)
|
||||
node = node_cls(self, data, part)
|
||||
self._set_node(part, node)
|
||||
return node
|
||||
|
||||
def _get_node(self, part):
|
||||
return getattr(self, part)
|
||||
|
||||
def _set_node(self, part, node):
|
||||
setattr(self, part, node)
|
||||
|
||||
def _delete_node(self, part):
|
||||
self._set_node(part, None)
|
||||
|
||||
def _validate_node_name(self, part):
|
||||
if part not in self.PARTS:
|
||||
raise ValueError("Subnode name must be in [%s], not '%s'." %
|
||||
(', '.join("'%s'" % p for p in self.PARTS), part))
|
||||
|
||||
def _evaluate_maxlevel(self, evaluate_from='parent'):
|
||||
"""
|
||||
Try to find the value of maximum decomposition level if it is not
|
||||
specified explicitly.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
evaluate_from : {'parent', 'subnodes'}
|
||||
"""
|
||||
assert evaluate_from in ('parent', 'subnodes')
|
||||
|
||||
if self._maxlevel is not None:
|
||||
return self._maxlevel
|
||||
elif self.data is not None:
|
||||
return self.level + dwt_max_level(
|
||||
min(self.data.shape), self.wavelet)
|
||||
|
||||
if evaluate_from == 'parent':
|
||||
if self.parent is not None:
|
||||
return self.parent._evaluate_maxlevel(evaluate_from)
|
||||
elif evaluate_from == 'subnodes':
|
||||
for node_name in self.PARTS:
|
||||
node = getattr(self, node_name, None)
|
||||
if node is not None:
|
||||
level = node._evaluate_maxlevel(evaluate_from)
|
||||
if level is not None:
|
||||
return level
|
||||
return None
|
||||
|
||||
@property
|
||||
def maxlevel(self):
|
||||
if self._maxlevel is not None:
|
||||
return self._maxlevel
|
||||
|
||||
# Try getting the maxlevel from parents first
|
||||
self._maxlevel = self._evaluate_maxlevel(evaluate_from='parent')
|
||||
|
||||
# If not found, check whether it can be evaluated from subnodes
|
||||
if self._maxlevel is None:
|
||||
self._maxlevel = self._evaluate_maxlevel(evaluate_from='subnodes')
|
||||
return self._maxlevel
|
||||
|
||||
@property
|
||||
def node_name(self):
|
||||
return self.path[-self.PART_LEN:]
|
||||
|
||||
def decompose(self):
|
||||
"""
|
||||
Decompose node data creating DWT coefficients subnodes.
|
||||
|
||||
Performs Discrete Wavelet Transform on the `~BaseNode.data` and
|
||||
returns transform coefficients.
|
||||
|
||||
Note
|
||||
----
|
||||
Descends to subnodes and recursively
|
||||
calls `~BaseNode.reconstruct` on them.
|
||||
|
||||
"""
|
||||
if self.level < self.maxlevel:
|
||||
return self._decompose()
|
||||
else:
|
||||
raise ValueError("Maximum decomposition level reached.")
|
||||
|
||||
def _decompose(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def reconstruct(self, update=False):
|
||||
"""
|
||||
Reconstruct node from subnodes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
update : bool, optional
|
||||
If True, then reconstructed data replaces the current
|
||||
node data (default: False).
|
||||
|
||||
Returns:
|
||||
- original node data if subnodes do not exist
|
||||
- IDWT of subnodes otherwise.
|
||||
"""
|
||||
if not self.has_any_subnode:
|
||||
return self.data
|
||||
return self._reconstruct(update)
|
||||
|
||||
def _reconstruct(self):
|
||||
raise NotImplementedError() # override this in subclasses
|
||||
|
||||
def get_subnode(self, part, decompose=True):
|
||||
"""
|
||||
        Returns subnode or None (see the `decompose` flag description).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
part :
|
||||
Subnode name
|
||||
decompose : bool, optional
|
||||
If the param is True and corresponding subnode does not
|
||||
exist, the subnode will be created using coefficients
|
||||
from the DWT decomposition of the current node.
|
||||
(default: True)
|
||||
"""
|
||||
self._validate_node_name(part)
|
||||
subnode = self._get_node(part)
|
||||
if subnode is None and decompose and not self.is_empty:
|
||||
self.decompose()
|
||||
subnode = self._get_node(part)
|
||||
return subnode
|
||||
|
||||
def __getitem__(self, path):
|
||||
"""
|
||||
Find node represented by the given path.
|
||||
|
||||
Similar to `~BaseNode.get_subnode` method with `decompose=True`, but
|
||||
can access nodes on any level in the decomposition tree.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : str
|
||||
String composed of node names. See `Node.node_name` and
|
||||
`Node2D.node_name` for node naming convention.
|
||||
|
||||
Notes
|
||||
-----
|
||||
If node does not exist yet, it will be created by decomposition of its
|
||||
parent node.
|
||||
"""
|
||||
if isinstance(path, str):
|
||||
if (self.maxlevel is not None
|
||||
and len(path) > self.maxlevel * self.PART_LEN):
|
||||
raise IndexError("Path length is out of range.")
|
||||
if path:
|
||||
return self.get_subnode(path[0:self.PART_LEN], True)[
|
||||
path[self.PART_LEN:]]
|
||||
else:
|
||||
return self
|
||||
else:
|
||||
raise TypeError("Invalid path parameter type - expected string but"
|
||||
" got %s." % type(path))
|
||||
|
||||
def __setitem__(self, path, data):
|
||||
"""
|
||||
Set node or node's data in the decomposition tree. Nodes are
|
||||
identified by string `path`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : str
|
||||
String composed of node names.
|
||||
data : array or BaseNode subclass.
|
||||
"""
|
||||
|
||||
if isinstance(path, str):
|
||||
if (
|
||||
self.maxlevel is not None
|
||||
and len(self.path) + len(path) > self.maxlevel * self.PART_LEN
|
||||
):
|
||||
raise IndexError("Path length out of range.")
|
||||
if path:
|
||||
subnode = self.get_subnode(path[0:self.PART_LEN], False)
|
||||
if subnode is None:
|
||||
self._create_subnode(path[0:self.PART_LEN], None)
|
||||
subnode = self.get_subnode(path[0:self.PART_LEN], False)
|
||||
subnode[path[self.PART_LEN:]] = data
|
||||
else:
|
||||
if isinstance(data, BaseNode):
|
||||
self.data = np.asarray(data.data)
|
||||
else:
|
||||
self.data = np.asarray(data)
|
||||
# convert data to nearest supported dtype
|
||||
dtype = _check_dtype(data)
|
||||
if self.data.dtype != dtype:
|
||||
self.data = self.data.astype(dtype)
|
||||
else:
|
||||
raise TypeError("Invalid path parameter type - expected string but"
|
||||
" got %s." % type(path))
|
||||
|
||||
def __delitem__(self, path):
|
||||
"""
|
||||
Remove node from the tree.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : str
|
||||
String composed of node names.
|
||||
"""
|
||||
node = self[path]
|
||||
# don't clear node value and subnodes (node may still exist outside
|
||||
# the tree)
|
||||
# # node._init_subnodes()
|
||||
# # node.data = None
|
||||
parent = node.parent
|
||||
node.parent = None # TODO
|
||||
if parent and node.node_name:
|
||||
parent._delete_node(node.node_name)
|
||||
|
||||
@property
|
||||
def is_empty(self):
|
||||
return self.data is None
|
||||
|
||||
@property
|
||||
def has_any_subnode(self):
|
||||
for part in self.PARTS:
|
||||
if self._get_node(part) is not None: # and not .is_empty
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_leaf_nodes(self, decompose=False):
|
||||
"""
|
||||
Returns leaf nodes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
decompose : bool, optional
|
||||
(default: False)
|
||||
"""
|
||||
result = []
|
||||
|
||||
def collect(node):
|
||||
if node.level == node.maxlevel and not node.is_empty:
|
||||
result.append(node)
|
||||
return False
|
||||
if not decompose and not node.has_any_subnode:
|
||||
result.append(node)
|
||||
return False
|
||||
return True
|
||||
self.walk(collect, decompose=decompose)
|
||||
return result
|
||||
|
||||
def walk(self, func, args=(), kwargs=None, decompose=True):
|
||||
"""
|
||||
Traverses the decomposition tree and calls
|
||||
``func(node, *args, **kwargs)`` on every node. If `func` returns True,
|
||||
descending to subnodes will continue.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : callable
|
||||
Callable accepting `BaseNode` as the first param and
|
||||
optional positional and keyword arguments
|
||||
args :
|
||||
func params
|
||||
kwargs :
|
||||
func keyword params
|
||||
decompose : bool, optional
|
||||
If True (default), the method will also try to decompose the tree
|
||||
up to the `maximum level <BaseNode.maxlevel>`.
|
||||
"""
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
if func(self, *args, **kwargs) and self.level < self.maxlevel:
|
||||
for part in self.PARTS:
|
||||
subnode = self.get_subnode(part, decompose)
|
||||
if subnode is not None:
|
||||
subnode.walk(func, args, kwargs, decompose)
|
||||
|
||||
def walk_depth(self, func, args=(), kwargs=None, decompose=True):
|
||||
"""
|
||||
Walk tree and call func on every node starting from the bottom-most
|
||||
nodes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : callable
|
||||
Callable accepting :class:`BaseNode` as the first param and
|
||||
optional positional and keyword arguments
|
||||
args :
|
||||
func params
|
||||
kwargs :
|
||||
func keyword params
|
||||
decompose : bool, optional
|
||||
(default: True)
|
||||
"""
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
if self.level < self.maxlevel:
|
||||
for part in self.PARTS:
|
||||
subnode = self.get_subnode(part, decompose)
|
||||
if subnode is not None:
|
||||
subnode.walk_depth(func, args, kwargs, decompose)
|
||||
func(self, *args, **kwargs)
|
||||
|
||||
def __str__(self):
|
||||
return self.path + ": " + str(self.data)
|
||||
|
||||
|
||||
class Node(BaseNode):
|
||||
"""
|
||||
WaveletPacket tree node.
|
||||
|
||||
Subnodes are called `a` and `d`, just like approximation
|
||||
and detail coefficients in the Discrete Wavelet Transform.
|
||||
"""
|
||||
|
||||
A = 'a'
|
||||
D = 'd'
|
||||
PARTS = A, D
|
||||
PART_LEN = 1
|
||||
|
||||
def _create_subnode(self, part, data=None, overwrite=True):
|
||||
return self._create_subnode_base(node_cls=Node, part=part, data=data,
|
||||
overwrite=overwrite)
|
||||
|
||||
def _decompose(self):
|
||||
"""
|
||||
|
||||
See also
|
||||
--------
|
||||
dwt : for 1D Discrete Wavelet Transform output coefficients.
|
||||
"""
|
||||
if self.is_empty:
|
||||
data_a, data_d = None, None
|
||||
if self._get_node(self.A) is None:
|
||||
self._create_subnode(self.A, data_a)
|
||||
if self._get_node(self.D) is None:
|
||||
self._create_subnode(self.D, data_d)
|
||||
else:
|
||||
data_a, data_d = dwt(self.data, self.wavelet, self.mode)
|
||||
self._create_subnode(self.A, data_a)
|
||||
self._create_subnode(self.D, data_d)
|
||||
return self._get_node(self.A), self._get_node(self.D)
|
||||
|
||||
def _reconstruct(self, update):
|
||||
data_a, data_d = None, None
|
||||
node_a, node_d = self._get_node(self.A), self._get_node(self.D)
|
||||
|
||||
if node_a is not None:
|
||||
data_a = node_a.reconstruct() # TODO: (update) ???
|
||||
if node_d is not None:
|
||||
data_d = node_d.reconstruct() # TODO: (update) ???
|
||||
|
||||
if data_a is None and data_d is None:
|
||||
raise ValueError("Node is a leaf node and cannot be reconstructed"
|
||||
" from subnodes.")
|
||||
else:
|
||||
rec = idwt(data_a, data_d, self.wavelet, self.mode)
|
||||
if self._data_shape is not None and (
|
||||
rec.shape != self._data_shape):
|
||||
rec = rec[tuple([slice(sz) for sz in self._data_shape])]
|
||||
if update:
|
||||
self.data = rec
|
||||
return rec
|
||||
|
||||
|
||||
class Node2D(BaseNode):
|
||||
"""
|
||||
WaveletPacket2D tree node.
|
||||
|
||||
Subnodes are called 'a' (LL), 'h' (HL), 'v' (LH) and 'd' (HH), like
|
||||
approximation and detail coefficients in the 2D Discrete Wavelet Transform
|
||||
"""
|
||||
|
||||
LL = 'a'
|
||||
HL = 'h'
|
||||
LH = 'v'
|
||||
HH = 'd'
|
||||
|
||||
PARTS = LL, HL, LH, HH
|
||||
PART_LEN = 1
|
||||
|
||||
def _create_subnode(self, part, data=None, overwrite=True):
|
||||
return self._create_subnode_base(node_cls=Node2D, part=part, data=data,
|
||||
overwrite=overwrite)
|
||||
|
||||
def _decompose(self):
|
||||
"""
|
||||
See also
|
||||
--------
|
||||
dwt2 : for 2D Discrete Wavelet Transform output coefficients.
|
||||
"""
|
||||
if self.is_empty:
|
||||
data_ll, data_lh, data_hl, data_hh = None, None, None, None
|
||||
else:
|
||||
data_ll, (data_hl, data_lh, data_hh) =\
|
||||
dwt2(self.data, self.wavelet, self.mode)
|
||||
self._create_subnode(self.LL, data_ll)
|
||||
self._create_subnode(self.LH, data_lh)
|
||||
self._create_subnode(self.HL, data_hl)
|
||||
self._create_subnode(self.HH, data_hh)
|
||||
return (self._get_node(self.LL), self._get_node(self.HL),
|
||||
self._get_node(self.LH), self._get_node(self.HH))
|
||||
|
||||
def _reconstruct(self, update):
|
||||
data_ll, data_lh, data_hl, data_hh = None, None, None, None
|
||||
|
||||
node_ll, node_lh, node_hl, node_hh =\
|
||||
self._get_node(self.LL), self._get_node(self.LH),\
|
||||
self._get_node(self.HL), self._get_node(self.HH)
|
||||
|
||||
if node_ll is not None:
|
||||
data_ll = node_ll.reconstruct()
|
||||
if node_lh is not None:
|
||||
data_lh = node_lh.reconstruct()
|
||||
if node_hl is not None:
|
||||
data_hl = node_hl.reconstruct()
|
||||
if node_hh is not None:
|
||||
data_hh = node_hh.reconstruct()
|
||||
|
||||
if (data_ll is None and data_lh is None
|
||||
and data_hl is None and data_hh is None):
|
||||
raise ValueError(
|
||||
"Tree is missing data - all subnodes of `%s` node "
|
||||
"are None. Cannot reconstruct node." % self.path
|
||||
)
|
||||
else:
|
||||
coeffs = data_ll, (data_hl, data_lh, data_hh)
|
||||
rec = idwt2(coeffs, self.wavelet, self.mode)
|
||||
if self._data_shape is not None and (
|
||||
rec.shape != self._data_shape):
|
||||
rec = rec[tuple([slice(sz) for sz in self._data_shape])]
|
||||
if update:
|
||||
self.data = rec
|
||||
return rec
|
||||
|
||||
def expand_2d_path(self, path):
|
||||
expanded_paths = {
|
||||
self.HH: 'hh',
|
||||
self.HL: 'hl',
|
||||
self.LH: 'lh',
|
||||
self.LL: 'll'
|
||||
}
|
||||
return (''.join([expanded_paths[p][0] for p in path]),
|
||||
''.join([expanded_paths[p][1] for p in path]))
|
||||
|
||||
|
||||
class WaveletPacket(Node):
|
||||
"""
|
||||
Data structure representing Wavelet Packet decomposition of signal.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : 1D ndarray
|
||||
Original data (signal)
|
||||
wavelet : Wavelet object or name string
|
||||
Wavelet used in DWT decomposition and reconstruction
|
||||
mode : str, optional
|
||||
Signal extension mode for the `dwt` and `idwt` decomposition and
|
||||
reconstruction functions.
|
||||
maxlevel : int, optional
|
||||
Maximum level of decomposition.
|
||||
If None, it will be calculated based on the `wavelet` and `data`
|
||||
length using `pywt.dwt_max_level`.
|
||||
"""
|
||||
def __init__(self, data, wavelet, mode='symmetric', maxlevel=None):
|
||||
super(WaveletPacket, self).__init__(None, data, "")
|
||||
|
||||
if not isinstance(wavelet, Wavelet):
|
||||
wavelet = Wavelet(wavelet)
|
||||
self.wavelet = wavelet
|
||||
self.mode = mode
|
||||
|
||||
if data is not None:
|
||||
data = np.asarray(data)
|
||||
assert data.ndim == 1
|
||||
self.data_size = data.shape[0]
|
||||
if maxlevel is None:
|
||||
maxlevel = dwt_max_level(self.data_size, self.wavelet)
|
||||
else:
|
||||
self.data_size = None
|
||||
|
||||
self._maxlevel = maxlevel
|
||||
|
||||
def reconstruct(self, update=True):
|
||||
"""
|
||||
Reconstruct data value using coefficients from subnodes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
update : bool, optional
|
||||
If True (default), then data values will be replaced by
|
||||
reconstruction values, also in subnodes.
|
||||
"""
|
||||
if self.has_any_subnode:
|
||||
data = super(WaveletPacket, self).reconstruct(update)
|
||||
if update:
|
||||
self.data = data
|
||||
return data
|
||||
return self.data # return original data
|
||||
|
||||
def get_level(self, level, order="natural", decompose=True):
|
||||
"""
|
||||
Returns all nodes on the specified level.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
level : int
|
||||
Specifies decomposition `level` from which the nodes will be
|
||||
collected.
|
||||
order : {'natural', 'freq'}, optional
|
||||
- "natural" - left to right in tree (default)
|
||||
- "freq" - band ordered
|
||||
decompose : bool, optional
|
||||
If set then the method will try to decompose the data up
|
||||
to the specified `level` (default: True).
|
||||
|
||||
Notes
|
||||
-----
|
||||
If nodes at the given level are missing (i.e. the tree is partially
|
||||
decomposed) and the `decompose` is set to False, only existing nodes
|
||||
will be returned.
|
||||
"""
|
||||
assert order in ["natural", "freq"]
|
||||
if level > self.maxlevel:
|
||||
raise ValueError("The level cannot be greater than the maximum"
|
||||
" decomposition level value (%d)" % self.maxlevel)
|
||||
|
||||
result = []
|
||||
|
||||
def collect(node):
|
||||
if node.level == level:
|
||||
result.append(node)
|
||||
return False
|
||||
return True
|
||||
|
||||
self.walk(collect, decompose=decompose)
|
||||
if order == "natural":
|
||||
return result
|
||||
elif order == "freq":
|
||||
result = dict((node.path, node) for node in result)
|
||||
graycode_order = get_graycode_order(level)
|
||||
return [result[path] for path in graycode_order if path in result]
|
||||
else:
|
||||
raise ValueError("Invalid order name - %s." % order)
|
||||
|
||||
|
||||
class WaveletPacket2D(Node2D):
|
||||
"""
|
||||
Data structure representing 2D Wavelet Packet decomposition of signal.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : 2D ndarray
|
||||
Original 2D data (signal or image).
|
||||
wavelet : Wavelet object or name string
|
||||
Wavelet used in DWT decomposition and reconstruction
|
||||
mode : str, optional
|
||||
Signal extension mode for the `dwt` and `idwt` decomposition and
|
||||
reconstruction functions.
|
||||
maxlevel : int, optional
|
||||
Maximum level of decomposition.
|
||||
If None, it will be calculated based on the `wavelet` and `data`
|
||||
length using `pywt.dwt_max_level`.
|
||||
"""
|
||||
def __init__(self, data, wavelet, mode='smooth', maxlevel=None):
|
||||
super(WaveletPacket2D, self).__init__(None, data, "")
|
||||
|
||||
if not isinstance(wavelet, Wavelet):
|
||||
wavelet = Wavelet(wavelet)
|
||||
self.wavelet = wavelet
|
||||
self.mode = mode
|
||||
|
||||
if data is not None:
|
||||
data = np.asarray(data)
|
||||
assert data.ndim == 2
|
||||
self.data_size = data.shape
|
||||
if maxlevel is None:
|
||||
maxlevel = dwt_max_level(min(self.data_size), self.wavelet)
|
||||
else:
|
||||
self.data_size = None
|
||||
self._maxlevel = maxlevel
|
||||
|
||||
def reconstruct(self, update=True):
|
||||
"""
|
||||
Reconstruct data using coefficients from subnodes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
update : bool, optional
|
||||
If True (default) then the coefficients of the current node
|
||||
and its subnodes will be replaced with values from reconstruction.
|
||||
"""
|
||||
if self.has_any_subnode:
|
||||
data = super(WaveletPacket2D, self).reconstruct(update)
|
||||
if update:
|
||||
self.data = data
|
||||
return data
|
||||
return self.data # return original data
|
||||
|
||||
def get_level(self, level, order="natural", decompose=True):
|
||||
"""
|
||||
Returns all nodes from specified level.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
level : int
|
||||
Decomposition `level` from which the nodes will be
|
||||
collected.
|
||||
order : {'natural', 'freq'}, optional
|
||||
If `natural` (default) a flat list is returned.
|
||||
If `freq`, a 2D structure (list of lists) with rows and columns
|
||||
ordered by the frequency along the corresponding dimension of the 2D
|
||||
coefficient array (adapted from the 1D case).
|
||||
decompose : bool, optional
|
||||
If set then the method will try to decompose the data up
|
||||
to the specified `level` (default: True).
|
||||
"""
|
||||
assert order in ["natural", "freq"]
|
||||
if level > self.maxlevel:
|
||||
raise ValueError("The level cannot be greater than the maximum"
|
||||
" decomposition level value (%d)" % self.maxlevel)
|
||||
|
||||
result = []
|
||||
|
||||
def collect(node):
|
||||
if node.level == level:
|
||||
result.append(node)
|
||||
return False
|
||||
return True
|
||||
|
||||
self.walk(collect, decompose=decompose)
|
||||
|
||||
if order == "freq":
|
||||
nodes = {}
|
||||
for (row_path, col_path), node in [
|
||||
(self.expand_2d_path(node.path), node) for node in result
|
||||
]:
|
||||
nodes.setdefault(row_path, {})[col_path] = node
|
||||
graycode_order = get_graycode_order(level, x='l', y='h')
|
||||
nodes = [nodes[path] for path in graycode_order if path in nodes]
|
||||
result = []
|
||||
for row in nodes:
|
||||
result.append(
|
||||
[row[path] for path in graycode_order if path in row]
|
||||
)
|
||||
return result
|
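# Minimal usage sketch of the WaveletPacket class above, assuming only numpy
# and the public pywt package: decompose a short signal into a packet tree,
# list one level of nodes in frequency (graycode) order, then zero a single
# subband and reconstruct.
import numpy as np
import pywt

x = np.sin(np.linspace(0, 8 * np.pi, 128))
wp = pywt.WaveletPacket(data=x, wavelet='db2', mode='symmetric', maxlevel=3)

# Nodes are addressed by paths built from 'a' (approximation) and 'd' (detail).
names = [node.path for node in wp.get_level(2, order='freq')]
# names == ['aa', 'ad', 'dd', 'da']  -- band ordering from get_graycode_order

wp['aad'].data = np.zeros_like(wp['aad'].data)   # suppress one subband
x_rec = wp.reconstruct(update=False)             # IDWT back up the tree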
6
venv/Lib/site-packages/pywt/conftest.py
Normal file
|
@@ -0,0 +1,6 @@
|
|||
import pytest
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
config.addinivalue_line("markers",
|
||||
"slow: Tests that are slow.")
|
2
venv/Lib/site-packages/pywt/data/__init__.py
Normal file
|
@@ -0,0 +1,2 @@
|
|||
from ._readers import ascent, aero, ecg, camera, nino
|
||||
from ._wavelab_signals import demo_signal
|
185
venv/Lib/site-packages/pywt/data/_readers.py
Normal file
|
@@ -0,0 +1,185 @@
|
|||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def ascent():
|
||||
"""
|
||||
Get an 8-bit grayscale bit-depth, 512 x 512 derived image for
|
||||
easy use in demos
|
||||
|
||||
The image is derived from accent-to-the-top.jpg at
|
||||
http://www.public-domain-image.com/people-public-domain-images-pictures/
|
||||
|
||||
Parameters
|
||||
----------
|
||||
None
|
||||
|
||||
Returns
|
||||
-------
|
||||
ascent : ndarray
|
||||
convenient image to use for testing and demonstration
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import pywt.data
|
||||
>>> ascent = pywt.data.ascent()
|
||||
>>> ascent.shape == (512, 512)
|
||||
True
|
||||
>>> ascent.max()
|
||||
255
|
||||
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> plt.gray()
|
||||
>>> plt.imshow(ascent) # doctest: +ELLIPSIS
|
||||
<matplotlib.image.AxesImage object at ...>
|
||||
>>> plt.show() # doctest: +SKIP
|
||||
|
||||
"""
|
||||
fname = os.path.join(os.path.dirname(__file__), 'ascent.npz')
|
||||
ascent = np.load(fname)['data']
|
||||
return ascent
|
||||
|
||||
|
||||
def aero():
|
||||
"""
|
||||
Get an 8-bit grayscale bit-depth, 512 x 512 derived image for
|
||||
easy use in demos
|
||||
|
||||
Parameters
|
||||
----------
|
||||
None
|
||||
|
||||
Returns
|
||||
-------
|
||||
aero : ndarray
|
||||
convenient image to use for testing and demonstration
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import pywt.data
|
||||
>>> aero = pywt.data.aero()
|
||||
>>> aero.shape == (512, 512)
|
||||
True
|
||||
>>> aero.max()
|
||||
255
|
||||
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> plt.gray()
|
||||
>>> plt.imshow(aero) # doctest: +ELLIPSIS
|
||||
<matplotlib.image.AxesImage object at ...>
|
||||
>>> plt.show() # doctest: +SKIP
|
||||
|
||||
"""
|
||||
fname = os.path.join(os.path.dirname(__file__), 'aero.npz')
|
||||
aero = np.load(fname)['data']
|
||||
return aero
|
||||
|
||||
|
||||
def camera():
|
||||
"""
|
||||
Get an 8-bit grayscale bit-depth, 512 x 512 derived image for
|
||||
easy use in demos
|
||||
|
||||
Parameters
|
||||
----------
|
||||
None
|
||||
|
||||
Returns
|
||||
-------
|
||||
camera : ndarray
|
||||
convenient image to use for testing and demonstration
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import pywt.data
|
||||
>>> camera = pywt.data.camera()
|
||||
>>> camera.shape == (512, 512)
|
||||
True
|
||||
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> plt.gray()
|
||||
>>> plt.imshow(camera) # doctest: +ELLIPSIS
|
||||
<matplotlib.image.AxesImage object at ...>
|
||||
>>> plt.show() # doctest: +SKIP
|
||||
|
||||
"""
|
||||
fname = os.path.join(os.path.dirname(__file__), 'camera.npz')
|
||||
camera = np.load(fname)['data']
|
||||
return camera
|
||||
|
||||
|
||||
def ecg():
|
||||
"""
|
||||
Get 1024 points of an ECG timeseries.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
None
|
||||
|
||||
Returns
|
||||
-------
|
||||
ecg : ndarray
|
||||
convenient timeseries to use for testing and demonstration
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import pywt.data
|
||||
>>> ecg = pywt.data.ecg()
|
||||
>>> ecg.shape == (1024,)
|
||||
True
|
||||
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> plt.plot(ecg) # doctest: +ELLIPSIS
|
||||
[<matplotlib.lines.Line2D object at ...>]
|
||||
>>> plt.show() # doctest: +SKIP
|
||||
"""
|
||||
fname = os.path.join(os.path.dirname(__file__), 'ecg.npy')
|
||||
ecg = np.load(fname)
|
||||
return ecg
|
||||
|
||||
|
||||
def nino():
|
||||
"""
|
||||
This data contains the averaged monthly sea surface temperature in degrees
|
||||
Celsius of the Pacific Ocean, between 0-10 degrees South and 90-80 degrees West, from 1950 to 2016.
|
||||
This dataset is in the public domain and was obtained from NOAA.
|
||||
National Oceanic and Atmospheric Administration's National Weather Service
|
||||
ERSSTv4 dataset, nino 3, http://www.cpc.ncep.noaa.gov/data/indices/
|
||||
|
||||
Parameters
|
||||
----------
|
||||
None
|
||||
|
||||
Returns
|
||||
-------
|
||||
time : ndarray
|
||||
time axis in fractional years (quarterly sampling starting in 1950)
|
||||
sst : ndarray
|
||||
standardized (zero-mean, unit-variance) quarterly NINO3 sea surface temperature series
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import pywt.data
|
||||
>>> time, sst = pywt.data.nino()
|
||||
>>> sst.shape == (264,)
|
||||
True
|
||||
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> plt.plot(time,sst) # doctest: +ELLIPSIS
|
||||
[<matplotlib.lines.Line2D object at ...>]
|
||||
>>> plt.show() # doctest: +SKIP
|
||||
"""
|
||||
fname = os.path.join(os.path.dirname(__file__), 'sst_nino3.npz')
|
||||
sst_csv = np.load(fname)['sst_csv']
|
||||
# sst_csv = pd.read_csv("http://www.cpc.ncep.noaa.gov/data/indices/ersst4.nino.mth.81-10.ascii", sep=' ', skipinitialspace=True)
|
||||
# take only full years
|
||||
n = int(np.floor(sst_csv.shape[0]/12.)*12.)
|
||||
# build the mean over three-month (quarterly) blocks
|
||||
# column index 4 holds the NINO3 values
|
||||
sst = np.mean(np.reshape(np.array(sst_csv)[:n, 4], (n//3, -1)), axis=1)
|
||||
sst = (sst - np.mean(sst)) / np.std(sst, ddof=1)
|
||||
|
||||
dt = 0.25
|
||||
time = np.arange(len(sst)) * dt + 1950.0 # construct time array
|
||||
return time, sst
|
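# Minimal usage sketch of the readers above, assuming only the public
# pywt.data namespace; each reader simply loads an array bundled with the
# package, so no network access is needed.
import pywt.data

img = pywt.data.camera()        # 512 x 512 8-bit grayscale test image
ecg = pywt.data.ecg()           # 1024-sample ECG trace
time, sst = pywt.data.nino()    # quarterly time axis and standardized NINO3 SST
print(img.shape, ecg.shape, sst.shape)   # (512, 512) (1024,) (264,)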
259
venv/Lib/site-packages/pywt/data/_wavelab_signals.py
Normal file
|
@@ -0,0 +1,259 @@
|
|||
# -*- coding:utf-8 -*-
|
||||
from __future__ import division
|
||||
|
||||
import numpy as np
|
||||
|
||||
__all__ = ['demo_signal']
|
||||
|
||||
_implemented_signals = [
|
||||
'Blocks',
|
||||
'Bumps',
|
||||
'HeaviSine',
|
||||
'Doppler',
|
||||
'Ramp',
|
||||
'HiSine',
|
||||
'LoSine',
|
||||
'LinChirp',
|
||||
'TwoChirp',
|
||||
'QuadChirp',
|
||||
'MishMash',
|
||||
'WernerSorrows',
|
||||
'HypChirps',
|
||||
'LinChirps',
|
||||
'Chirps',
|
||||
'Gabor',
|
||||
'sineoneoverx',
|
||||
'Piece-Regular',
|
||||
'Piece-Polynomial',
|
||||
'Riemann']
|
||||
|
||||
|
||||
def demo_signal(name='Bumps', n=None):
|
||||
"""Simple 1D wavelet test functions.
|
||||
|
||||
This function can generate a number of common 1D test signals used in
|
||||
papers by David Donoho and colleagues (e.g. [1]_) as well as the wavelet
|
||||
book by Stéphane Mallat [2]_.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : {'Blocks', 'Bumps', 'HeaviSine', 'Doppler', ...}
|
||||
The type of test signal to generate (`name` is case-insensitive). If
|
||||
`name` is set to `'list'`, a list of the available test functions is
|
||||
returned.
|
||||
n : int or None
|
||||
The length of the test signal. This should be provided for all test
|
||||
signals except `'Gabor'` and `'sineoneoverx'` which have a fixed
|
||||
length.
|
||||
|
||||
Returns
|
||||
-------
|
||||
f : np.ndarray
|
||||
Array of length ``n`` corresponding to the specified test signal type.
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] D.L. Donoho and I.M. Johnstone. Ideal spatial adaptation by
|
||||
wavelet shrinkage. Biometrika, vol. 81, pp. 425–455, 1994.
|
||||
.. [2] S. Mallat. A Wavelet Tour of Signal Processing: The Sparse Way.
|
||||
Academic Press. 2009.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This function is a partial reimplementation of the `MakeSignal` function
|
||||
from the [Wavelab](https://statweb.stanford.edu/~wavelab/) toolbox. These
|
||||
test signals are provided with permission of Dr. Donoho to encourage
|
||||
reproducible research.
|
||||
|
||||
"""
|
||||
if name.lower() == 'list':
|
||||
return _implemented_signals
|
||||
|
||||
if n is not None:
|
||||
if n < 1 or (n % 1) != 0:
|
||||
raise ValueError("n must be an integer >= 1")
|
||||
t = np.arange(1/n, 1 + 1/n, 1/n)
|
||||
|
||||
# The following function types don't allow user-specified `n`.
|
||||
n_hard_coded = ['gabor', 'sineoneoverx']
|
||||
|
||||
name = name.lower()
|
||||
if name in n_hard_coded and n is not None:
|
||||
raise ValueError(
|
||||
"Parameter n must be set to None when name is {}".format(name))
|
||||
elif n is None and name not in n_hard_coded:
|
||||
raise ValueError(
|
||||
"Parameter n must be provided when name is {}".format(name))
|
||||
|
||||
if name == 'blocks':
|
||||
t0s = [.1, .13, .15, .23, .25, .4, .44, .65, .76, .78, .81]
|
||||
hs = [4, -5, 3, -4, 5, -4.2, 2.1, 4.3, -3.1, 2.1, -4.2]
|
||||
f = 0
|
||||
for (t0, h) in zip(t0s, hs):
|
||||
f += h * (1 + np.sign(t - t0)) / 2
|
||||
elif name == 'bumps':
|
||||
t0s = [.1, .13, .15, .23, .25, .4, .44, .65, .76, .78, .81]
|
||||
hs = [4, 5, 3, 4, 5, 4.2, 2.1, 4.3, 3.1, 5.1, 4.2]
|
||||
ws = [.005, .005, .006, .01, .01, .03, .01, .01, .005, .008, .005]
|
||||
f = 0
|
||||
for (t0, h, w) in zip(t0s, hs, ws):
|
||||
f += h / (1 + np.abs((t - t0) / w))**4
|
||||
elif name == 'heavisine':
|
||||
f = 4 * np.sin(4 * np.pi * t) - np.sign(t - 0.3) - np.sign(0.72 - t)
|
||||
elif name == 'doppler':
|
||||
f = np.sqrt(t * (1 - t)) * np.sin(2 * np.pi * 1.05 / (t + 0.05))
|
||||
elif name == 'ramp':
|
||||
f = t - (t >= .37)
|
||||
elif name == 'hisine':
|
||||
f = np.sin(np.pi * (n * .6902) * t)
|
||||
elif name == 'losine':
|
||||
f = np.sin(np.pi * (n * .3333) * t)
|
||||
elif name == 'linchirp':
|
||||
f = np.sin(np.pi * t * ((n * .500) * t))
|
||||
elif name == 'twochirp':
|
||||
f = np.sin(np.pi * t * (n * t)) + np.sin((np.pi / 3) * t * (n * t))
|
||||
elif name == 'quadchirp':
|
||||
f = np.sin((np.pi / 3) * t * (n * t**2))
|
||||
elif name == 'mishmash': # QuadChirp + LinChirp + HiSine
|
||||
f = np.sin((np.pi / 3) * t * (n * t**2))
|
||||
f += np.sin(np.pi * (n * .6902) * t)
|
||||
f += np.sin(np.pi * t * (n * .125 * t))
|
||||
elif name == 'wernersorrows':
|
||||
f = np.sin(np.pi * t * (n / 2 * t**2))
|
||||
f = f + np.sin(np.pi * (n * .6902) * t)
|
||||
f = f + np.sin(np.pi * t * (n * t))
|
||||
pos = [.1, .13, .15, .23, .25, .40, .44, .65, .76, .78, .81]
|
||||
hgt = [4, 5, 3, 4, 5, 4.2, 2.1, 4.3, 3.1, 5.1, 4.2]
|
||||
wth = [.005, .005, .006, .01, .01, .03, .01, .01, .005, .008, .005]
|
||||
for p, h, w in zip(pos, hgt, wth):
|
||||
f += h / (1 + np.abs((t - p) / w))**4
|
||||
elif name == 'hypchirps': # Hyperbolic Chirps of Mallat's book
|
||||
alpha = 15 * n * np.pi / 1024
|
||||
beta = 5 * n * np.pi / 1024
|
||||
t = np.arange(1.001, n + .001 + 1) / n
|
||||
f1 = np.zeros(n)
|
||||
f2 = np.zeros(n)
|
||||
f1 = np.sin(alpha / (.8 - t)) * (0.1 < t) * (t < 0.68)
|
||||
f2 = np.sin(beta / (.8 - t)) * (0.1 < t) * (t < 0.75)
|
||||
m = int(np.round(0.65 * n))
|
||||
p = m // 4
|
||||
envelope = np.ones(m)  # the rising cutoff function
|
||||
tmp = np.arange(1, p + 1)-np.ones(p)
|
||||
envelope[:p] = (1 + np.sin(-np.pi / 2 + tmp / (p - 1) * np.pi)) / 2
|
||||
envelope[m-p:m] = envelope[:p][::-1]
|
||||
env = np.zeros(n)
|
||||
env[int(np.ceil(n / 10)) - 1:m + int(np.ceil(n / 10)) - 1] = \
|
||||
envelope[:m]
|
||||
f = (f1 + f2) * env
|
||||
elif name == 'linchirps': # Linear Chirps of Mallat's book
|
||||
b = 100 * n * np.pi / 1024
|
||||
a = 250 * n * np.pi / 1024
|
||||
t = np.arange(1, n + 1) / n
|
||||
A1 = np.sqrt((t - 1 / n) * (1 - t))
|
||||
f = A1 * (np.cos(a * t**2) + np.cos(b * t + a * t**2))
|
||||
elif name == 'chirps': # Mixture of Chirps of Mallat's book
|
||||
t = np.arange(1, n + 1)/n * 10 * np.pi
|
||||
f1 = np.cos(t**2 * n / 1024)
|
||||
a = 30 * n / 1024
|
||||
t = np.arange(1, n + 1)/n * np.pi
|
||||
f2 = np.cos(a * (t**3))
|
||||
f2 = f2[::-1]
|
||||
ix = np.arange(-n, n + 1) / n * 20
|
||||
g = np.exp(-ix**2 * 4 * n / 1024)
|
||||
i1 = slice(n // 2, n // 2 + n)
|
||||
i2 = slice(n // 8, n // 8 + n)
|
||||
j = np.arange(1, n + 1) / n
|
||||
f3 = g[i1] * np.cos(50 * np.pi * j * n / 1024)
|
||||
f4 = g[i2] * np.cos(350 * np.pi * j * n / 1024)
|
||||
f = f1 + f2 + f3 + f4
|
||||
envelope = np.ones(n)  # the rising cutoff function
|
||||
tmp = np.arange(1, n // 8 + 1) - np.ones(n // 8)
|
||||
envelope[:n // 8] = (
|
||||
1 + np.sin(-np.pi / 2 + tmp / (n / 8 - 1) * np.pi)) / 2
|
||||
envelope[7 * n // 8:n] = envelope[:n // 8][::-1]
|
||||
f = f*envelope
|
||||
elif name == 'gabor': # two modulated Gabor functions in Mallat's book
|
||||
n = 512
|
||||
t = np.arange(-n, n + 1)*5 / n
|
||||
j = np.arange(1, n + 1) / n
|
||||
g = np.exp(-t**2 * 20)
|
||||
i1 = slice(2*n // 4, 2 * n // 4 + n)
|
||||
i2 = slice(n // 4, n // 4 + n)
|
||||
f1 = 3 * g[i1] * np.exp(1j * (n // 16) * np.pi * j)
|
||||
f2 = 3 * g[i2] * np.exp(1j * (n // 4) * np.pi * j)
|
||||
f = f1 + f2
|
||||
elif name == 'sineoneoverx': # np.sin(1/x) in Mallat's book
|
||||
n = 1024
|
||||
i1 = np.arange(-n + 1, n + 1, dtype=float)
|
||||
i1[i1 == 0] = 1 / 100
|
||||
i1 = i1 / (n - 1)
|
||||
f = np.sin(1.5 / i1)
|
||||
f = f[512:1536]
|
||||
elif name == 'piece-regular':
|
||||
f = np.zeros(n)
|
||||
n_12 = int(np.fix(n / 12))
|
||||
n_7 = int(np.fix(n / 7))
|
||||
n_5 = int(np.fix(n / 5))
|
||||
n_3 = int(np.fix(n / 3))
|
||||
n_2 = int(np.fix(n / 2))
|
||||
n_20 = int(np.fix(n / 20))
|
||||
f1 = -15 * demo_signal('bumps', n)
|
||||
t = np.arange(1, n_12 + 1) / n_12
|
||||
f2 = -np.exp(4 * t)
|
||||
t = np.arange(1, n_7 + 1) / n_7
|
||||
f5 = np.exp(4 * t)-np.exp(4)
|
||||
t = np.arange(1, n_3 + 1) / n_3
|
||||
fma = 6 / 40
|
||||
f6 = -70 * np.exp(-((t - 0.5) * (t - 0.5)) / (2 * fma**2))
|
||||
f[:n_7] = f6[:n_7]
|
||||
f[n_7:n_5] = 0.5 * f6[n_7:n_5]
|
||||
f[n_5:n_3] = f6[n_5:n_3]
|
||||
f[n_3:n_2] = f1[n_3:n_2]
|
||||
f[n_2:n_2 + n_12] = f2
|
||||
f[n_2 + 2 * n_12 - 1:n_2 + n_12 - 1:-1] = f2
|
||||
f[n_2 + 2 * n_12 + n_20:n_2 + 2 * n_12 + 3 * n_20] = -np.ones(
|
||||
n_2 + 2*n_12 + 3*n_20 - n_2 - 2*n_12 - n_20) * 25
|
||||
k = n_2 + 2 * n_12 + 3 * n_20
|
||||
f[k:k + n_7] = f5
|
||||
diff = n - 5 * n_5
|
||||
f[5 * n_5:n] = f[diff - 1::-1]
|
||||
# zero-mean
|
||||
bias = np.sum(f) / n
|
||||
f = bias - f
|
||||
elif name == 'piece-polynomial':
|
||||
f = np.zeros(n)
|
||||
n_5 = int(np.fix(n / 5))
|
||||
n_10 = int(np.fix(n / 10))
|
||||
n_20 = int(np.fix(n / 20))
|
||||
t = np.arange(1, n_5 + 1) / n_5
|
||||
f1 = 20 * (t**3 + t**2 + 4)
|
||||
f3 = 40 * (2 * t**3 + t) + 100
|
||||
f2 = 10 * t**3 + 45
|
||||
f4 = 16 * t**2 + 8 * t + 16
|
||||
f5 = 20 * (t + 4)
|
||||
f6 = np.ones(n_10) * 20
|
||||
f[:n_5] = f1
|
||||
f[2 * n_5 - 1:n_5 - 1:-1] = f2
|
||||
f[2 * n_5:3 * n_5] = f3
|
||||
f[3 * n_5:4 * n_5] = f4
|
||||
f[4 * n_5:5 * n_5] = f5[n_5::-1]
|
||||
diff = n - 5*n_5
|
||||
f[5 * n_5:n] = f[diff - 1::-1]
|
||||
f[n_20:n_20 + n_10] = np.ones(n_10) * 10
|
||||
f[n - n_10:n + n_20 - n_10] = np.ones(n_20) * 150
|
||||
# zero-mean
|
||||
bias = np.sum(f) / n
|
||||
f = f - bias
|
||||
elif name == 'riemann':
|
||||
# Riemann's Non-differentiable Function
|
||||
sqn = int(np.round(np.sqrt(n)))
|
||||
idx = np.arange(1, sqn + 1)
|
||||
idx *= idx
|
||||
f = np.zeros_like(t)
|
||||
f[idx - 1] = 1. / np.arange(1, sqn + 1)
|
||||
f = np.real(np.fft.ifft(f))
|
||||
else:
|
||||
raise ValueError(
|
||||
"unknown name: {}. name must be one of: {}".format(
|
||||
name, _implemented_signals))
|
||||
return f
|
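# Minimal usage sketch of demo_signal above, assuming only the public
# pywt.data namespace: list the available names, then generate two of the
# test signals (a length must be given except for the fixed-length 'Gabor'
# and 'sineoneoverx' signals).
import pywt.data

print(pywt.data.demo_signal('list'))              # names of available signals
blocks = pywt.data.demo_signal('Blocks', n=512)   # piecewise-constant signal
gabor = pywt.data.demo_signal('Gabor')            # fixed length; n must be None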
BIN
venv/Lib/site-packages/pywt/data/aero.npz
Normal file
Binary file not shown.
BIN
venv/Lib/site-packages/pywt/data/ascent.npz
Normal file
Binary file not shown.
BIN
venv/Lib/site-packages/pywt/data/camera.npz
Normal file
Binary file not shown.
39
venv/Lib/site-packages/pywt/data/create_dat.py
Normal file
|
@@ -0,0 +1,39 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
"""Helper script for creating image .dat files by numpy.save
|
||||
|
||||
Usage:
|
||||
|
||||
python create_dat.py <name of image file> <name of dat file>
|
||||
|
||||
Example (to create aero.dat):
|
||||
|
||||
python create_dat.py aero.png aero.dat
|
||||
|
||||
Requires Scipy and PIL.
|
||||
"""
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def main():
|
||||
from scipy.misc import imread
|
||||
|
||||
if len(sys.argv) != 3:
|
||||
print(__doc__)
|
||||
exit()
|
||||
|
||||
image_fname = sys.argv[1]
|
||||
dat_fname = sys.argv[2]
|
||||
|
||||
data = imread(image_fname)
|
||||
|
||||
np.savez_compressed(dat_fname, data=data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
BIN
venv/Lib/site-packages/pywt/data/ecg.npy
Normal file
Binary file not shown.
BIN
venv/Lib/site-packages/pywt/data/sst_nino3.npz
Normal file
Binary file not shown.
|
@@ -0,0 +1,96 @@
|
|||
""" This script was used to generate dwt_matlabR2012a_result.npz by storing
|
||||
the outputs from Matlab R2012a. """
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
import pywt
|
||||
|
||||
try:
|
||||
from pymatbridge import Matlab
|
||||
mlab = Matlab()
|
||||
_matlab_missing = False
|
||||
except ImportError:
|
||||
print("To run Matlab compatibility tests you need to have MathWorks "
|
||||
"MATLAB, MathWorks Wavelet Toolbox and the pymatbridge Python "
|
||||
"package installed.")
|
||||
_matlab_missing = True
|
||||
|
||||
if _matlab_missing:
|
||||
raise EnvironmentError("Can't generate matlab data files without MATLAB")
|
||||
|
||||
size_set = 'reduced'
|
||||
|
||||
# list of mode names in pywt and matlab
|
||||
modes = [('zero', 'zpd'),
|
||||
('constant', 'sp0'),
|
||||
('symmetric', 'sym'),
|
||||
('reflect', 'symw'),
|
||||
('periodic', 'ppd'),
|
||||
('smooth', 'sp1'),
|
||||
('periodization', 'per'),
|
||||
('antisymmetric', 'asym'),
|
||||
('antireflect', 'asymw')]
|
||||
|
||||
families = ('db', 'sym', 'coif', 'bior', 'rbio')
|
||||
wavelets = sum([pywt.wavelist(name) for name in families], [])
|
||||
|
||||
rstate = np.random.RandomState(1234)
|
||||
mlab.start()
|
||||
try:
|
||||
all_matlab_results = {}
|
||||
for wavelet in wavelets:
|
||||
w = pywt.Wavelet(wavelet)
|
||||
mlab.set_variable('wavelet', wavelet)
|
||||
if size_set == 'full':
|
||||
data_sizes = list(range(w.dec_len, 40)) + \
|
||||
[100, 200, 500, 1000, 50000]
|
||||
else:
|
||||
data_sizes = (w.dec_len, w.dec_len + 1)
|
||||
for N in data_sizes:
|
||||
data = rstate.randn(N)
|
||||
mlab.set_variable('data', data)
|
||||
for pmode, mmode in modes:
|
||||
# Matlab result
|
||||
if np.any((wavelet == np.array(['coif6', 'coif7', 'coif8', 'coif9', 'coif10', 'coif11', 'coif12', 'coif13', 'coif14', 'coif15', 'coif16', 'coif17'])),axis=0):
|
||||
mlab.set_variable('Lo_D', w.dec_lo)
|
||||
mlab.set_variable('Hi_D', w.dec_hi)
|
||||
mlab_code = ("[ma, md] = dwt(data, Lo_D, Hi_D, "
|
||||
"'mode', '%s');" % mmode)
|
||||
else:
|
||||
mlab_code = ("[ma, md] = dwt(data, wavelet, "
|
||||
"'mode', '%s');" % mmode)
|
||||
res = mlab.run_code(mlab_code)
|
||||
if not res['success']:
|
||||
raise RuntimeError(
|
||||
"Matlab failed to execute the provided code. "
|
||||
"Check that the wavelet toolbox is installed.")
|
||||
# need np.asarray because sometimes the output is type float
|
||||
ma = np.asarray(mlab.get_variable('ma'))
|
||||
md = np.asarray(mlab.get_variable('md'))
|
||||
ma_key = '_'.join([mmode, wavelet, str(N), 'ma'])
|
||||
md_key = '_'.join([mmode, wavelet, str(N), 'md'])
|
||||
all_matlab_results[ma_key] = ma
|
||||
all_matlab_results[md_key] = md
|
||||
|
||||
# Matlab result
|
||||
mlab.set_variable('Lo_D', w.dec_lo)
|
||||
mlab.set_variable('Hi_D', w.dec_hi)
|
||||
mlab_code = ("[ma, md] = dwt(data, Lo_D, Hi_D, "
|
||||
"'mode', '%s');" % mmode)
|
||||
res = mlab.run_code(mlab_code)
|
||||
if not res['success']:
|
||||
raise RuntimeError(
|
||||
"Matlab failed to execute the provided code. "
|
||||
"Check that the wavelet toolbox is installed.")
|
||||
# need np.asarray because sometimes the output is type float
|
||||
ma = np.asarray(mlab.get_variable('ma'))
|
||||
md = np.asarray(mlab.get_variable('md'))
|
||||
ma_key = '_'.join([mmode, wavelet, str(N), 'ma_pywtCoeffs'])
|
||||
md_key = '_'.join([mmode, wavelet, str(N), 'md_pywtCoeffs'])
|
||||
all_matlab_results[ma_key] = ma
|
||||
all_matlab_results[md_key] = md
|
||||
finally:
|
||||
mlab.stop()
|
||||
|
||||
np.savez('dwt_matlabR2012a_result.npz', **all_matlab_results)
|
|
@@ -0,0 +1,86 @@
|
|||
""" This script was used to generate dwt_matlabR2012a_result.npz by storing
|
||||
the outputs from Matlab R2015b. """
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
import pywt
|
||||
|
||||
try:
|
||||
from pymatbridge import Matlab
|
||||
mlab = Matlab()
|
||||
_matlab_missing = False
|
||||
except ImportError:
|
||||
print("To run Matlab compatibility tests you need to have MathWorks "
|
||||
"MATLAB, MathWorks Wavelet Toolbox and the pymatbridge Python "
|
||||
"package installed.")
|
||||
_matlab_missing = True
|
||||
|
||||
if _matlab_missing:
|
||||
raise EnvironmentError("Can't generate matlab data files without MATLAB")
|
||||
|
||||
size_set = 'reduced'
|
||||
|
||||
# list of mode names in pywt and matlab
|
||||
modes = [('zero', 'zpd'),
|
||||
('constant', 'sp0'),
|
||||
('symmetric', 'sym'),
|
||||
('periodic', 'ppd'),
|
||||
('smooth', 'sp1'),
|
||||
('periodization', 'per')]
|
||||
|
||||
families = ('gaus', 'mexh', 'morl', 'cgau', 'shan', 'fbsp', 'cmor')
|
||||
wavelets = sum([pywt.wavelist(name) for name in families], [])
|
||||
|
||||
rstate = np.random.RandomState(1234)
|
||||
mlab.start()
|
||||
try:
|
||||
all_matlab_results = {}
|
||||
for wavelet in wavelets:
|
||||
w = pywt.ContinuousWavelet(wavelet)
|
||||
if np.any((wavelet == np.array(['shan', 'cmor'])),axis=0):
|
||||
mlab.set_variable('wavelet', wavelet+str(w.bandwidth_frequency)+'-'+str(w.center_frequency))
|
||||
elif wavelet == 'fbsp':
|
||||
mlab.set_variable('wavelet', wavelet+str(w.fbsp_order)+'-'+str(w.bandwidth_frequency)+'-'+str(w.center_frequency))
|
||||
else:
|
||||
mlab.set_variable('wavelet', wavelet)
|
||||
if size_set == 'full':
|
||||
data_sizes = list(range(100, 101)) + \
|
||||
[100, 200, 500, 1000, 50000]
|
||||
Scales = (1,np.arange(1,3),np.arange(1,4),np.arange(1,5))
|
||||
else:
|
||||
data_sizes = (1000, 1000 + 1)
|
||||
Scales = (1,np.arange(1,3))
|
||||
mlab_code = ("psi = wavefun(wavelet,10)")
|
||||
res = mlab.run_code(mlab_code)
|
||||
if not res['success']:
|
||||
raise RuntimeError(
|
||||
"Matlab failed to execute the provided code. "
|
||||
"Check that the wavelet toolbox is installed.")
|
||||
psi = np.asarray(mlab.get_variable('psi'))
|
||||
psi_key = '_'.join([wavelet, 'psi'])
|
||||
all_matlab_results[psi_key] = psi
|
||||
for N in data_sizes:
|
||||
data = rstate.randn(N)
|
||||
mlab.set_variable('data', data)
|
||||
|
||||
# Matlab result
|
||||
scale_count = 0
|
||||
for scales in Scales:
|
||||
scale_count += 1
|
||||
mlab.set_variable('scales', scales)
|
||||
mlab_code = ("coefs = cwt(data, scales, wavelet)")
|
||||
res = mlab.run_code(mlab_code)
|
||||
if not res['success']:
|
||||
raise RuntimeError(
|
||||
"Matlab failed to execute the provided code. "
|
||||
"Check that the wavelet toolbox is installed.")
|
||||
# need np.asarray because sometimes the output is type float
|
||||
coefs = np.asarray(mlab.get_variable('coefs'))
|
||||
coefs_key = '_'.join([str(scale_count), wavelet, str(N), 'coefs'])
|
||||
all_matlab_results[coefs_key] = coefs
|
||||
|
||||
finally:
|
||||
mlab.stop()
|
||||
|
||||
np.savez('cwt_matlabR2015b_result.npz', **all_matlab_results)
|
BIN
venv/Lib/site-packages/pywt/tests/data/wavelab_test_signals.npz
Normal file
Binary file not shown.
170
venv/Lib/site-packages/pywt/tests/test__pywt.py
Normal file
|
@@ -0,0 +1,170 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose, assert_, assert_raises
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_upcoef_reconstruct():
|
||||
data = np.arange(3)
|
||||
a = pywt.downcoef('a', data, 'haar')
|
||||
d = pywt.downcoef('d', data, 'haar')
|
||||
|
||||
rec = (pywt.upcoef('a', a, 'haar', take=3) +
|
||||
pywt.upcoef('d', d, 'haar', take=3))
|
||||
assert_allclose(rec, data)
|
||||
|
||||
|
||||
def test_downcoef_multilevel():
|
||||
rstate = np.random.RandomState(1234)
|
||||
r = rstate.randn(16)
|
||||
nlevels = 3
|
||||
# calling with level=1 nlevels times
|
||||
a1 = r.copy()
|
||||
for i in range(nlevels):
|
||||
a1 = pywt.downcoef('a', a1, 'haar', level=1)
|
||||
# call with level=nlevels once
|
||||
a3 = pywt.downcoef('a', r, 'haar', level=nlevels)
|
||||
assert_allclose(a1, a3)
|
||||
|
||||
|
||||
def test_downcoef_complex():
|
||||
rstate = np.random.RandomState(1234)
|
||||
r = rstate.randn(16) + 1j * rstate.randn(16)
|
||||
nlevels = 3
|
||||
a = pywt.downcoef('a', r, 'haar', level=nlevels)
|
||||
a_ref = pywt.downcoef('a', r.real, 'haar', level=nlevels)
|
||||
a_ref = a_ref + 1j * pywt.downcoef('a', r.imag, 'haar', level=nlevels)
|
||||
assert_allclose(a, a_ref)
|
||||
|
||||
|
||||
def test_downcoef_errs():
|
||||
# invalid part string (not 'a' or 'd')
|
||||
assert_raises(ValueError, pywt.downcoef, 'f', np.ones(16), 'haar')
|
||||
|
||||
|
||||
def test_compare_downcoef_coeffs():
|
||||
rstate = np.random.RandomState(1234)
|
||||
r = rstate.randn(16)
|
||||
# compare downcoef against wavedec outputs
|
||||
for nlevels in [1, 2, 3]:
|
||||
for wavelet in pywt.wavelist():
|
||||
if wavelet in ['cmor', 'shan', 'fbsp']:
|
||||
# skip these CWT families to avoid warnings
|
||||
continue
|
||||
wavelet = pywt.DiscreteContinuousWavelet(wavelet)
|
||||
if isinstance(wavelet, pywt.Wavelet):
|
||||
max_level = pywt.dwt_max_level(r.size, wavelet.dec_len)
|
||||
if nlevels <= max_level:
|
||||
a = pywt.downcoef('a', r, wavelet, level=nlevels)
|
||||
d = pywt.downcoef('d', r, wavelet, level=nlevels)
|
||||
coeffs = pywt.wavedec(r, wavelet, level=nlevels)
|
||||
assert_allclose(a, coeffs[0])
|
||||
assert_allclose(d, coeffs[1])
|
||||
|
||||
|
||||
def test_upcoef_multilevel():
|
||||
rstate = np.random.RandomState(1234)
|
||||
r = rstate.randn(4)
|
||||
nlevels = 3
|
||||
# calling with level=1 nlevels times
|
||||
a1 = r.copy()
|
||||
for i in range(nlevels):
|
||||
a1 = pywt.upcoef('a', a1, 'haar', level=1)
|
||||
# call with level=nlevels once
|
||||
a3 = pywt.upcoef('a', r, 'haar', level=nlevels)
|
||||
assert_allclose(a1, a3)
|
||||
|
||||
|
||||
def test_upcoef_complex():
|
||||
rstate = np.random.RandomState(1234)
|
||||
r = rstate.randn(4) + 1j*rstate.randn(4)
|
||||
nlevels = 3
|
||||
a = pywt.upcoef('a', r, 'haar', level=nlevels)
|
||||
a_ref = pywt.upcoef('a', r.real, 'haar', level=nlevels)
|
||||
a_ref = a_ref + 1j*pywt.upcoef('a', r.imag, 'haar', level=nlevels)
|
||||
assert_allclose(a, a_ref)
|
||||
|
||||
|
||||
def test_upcoef_errs():
|
||||
# invalid part string (not 'a' or 'd')
|
||||
assert_raises(ValueError, pywt.upcoef, 'f', np.ones(4), 'haar')
|
||||
|
||||
|
||||
def test_upcoef_and_downcoef_1d_only():
|
||||
# upcoef and downcoef raise a ValueError if data.ndim > 1d
|
||||
for ndim in [2, 3]:
|
||||
data = np.ones((8, )*ndim)
|
||||
assert_raises(ValueError, pywt.downcoef, 'a', data, 'haar')
|
||||
assert_raises(ValueError, pywt.upcoef, 'a', data, 'haar')
|
||||
|
||||
|
||||
def test_wavelet_repr():
|
||||
from pywt._extensions import _pywt
|
||||
wavelet = _pywt.Wavelet('sym8')
|
||||
|
||||
repr_wavelet = eval(wavelet.__repr__())
|
||||
|
||||
assert_(wavelet.__repr__() == repr_wavelet.__repr__())
|
||||
|
||||
|
||||
def test_dwt_max_level():
|
||||
assert_(pywt.dwt_max_level(16, 2) == 4)
|
||||
assert_(pywt.dwt_max_level(16, 8) == 1)
|
||||
assert_(pywt.dwt_max_level(16, 9) == 1)
|
||||
assert_(pywt.dwt_max_level(16, 10) == 0)
|
||||
assert_(pywt.dwt_max_level(16, np.int8(10)) == 0)
|
||||
assert_(pywt.dwt_max_level(16, 10.) == 0)
|
||||
assert_(pywt.dwt_max_level(16, 18) == 0)
|
||||
|
||||
# accepts discrete Wavelet object or string as well
|
||||
assert_(pywt.dwt_max_level(32, pywt.Wavelet('sym5')) == 1)
|
||||
assert_(pywt.dwt_max_level(32, 'sym5') == 1)
|
||||
|
||||
# string input that is not a discrete wavelet
|
||||
assert_raises(ValueError, pywt.dwt_max_level, 16, 'mexh')
|
||||
|
||||
# filter_len must be an integer >= 2
|
||||
assert_raises(ValueError, pywt.dwt_max_level, 16, 1)
|
||||
assert_raises(ValueError, pywt.dwt_max_level, 16, -1)
|
||||
assert_raises(ValueError, pywt.dwt_max_level, 16, 3.3)
|
||||
|
||||
|
||||
def test_ContinuousWavelet_errs():
|
||||
assert_raises(ValueError, pywt.ContinuousWavelet, 'qwertz')
|
||||
|
||||
|
||||
def test_ContinuousWavelet_repr():
|
||||
from pywt._extensions import _pywt
|
||||
wavelet = _pywt.ContinuousWavelet('gaus2')
|
||||
|
||||
repr_wavelet = eval(wavelet.__repr__())
|
||||
|
||||
assert_(wavelet.__repr__() == repr_wavelet.__repr__())
|
||||
|
||||
|
||||
def test_wavelist():
|
||||
for name in pywt.wavelist(family='coif'):
|
||||
assert_(name.startswith('coif'))
|
||||
|
||||
assert_('cgau7' in pywt.wavelist(kind='continuous'))
|
||||
assert_('sym20' in pywt.wavelist(kind='discrete'))
|
||||
assert_(len(pywt.wavelist(kind='continuous')) +
|
||||
len(pywt.wavelist(kind='discrete')) ==
|
||||
len(pywt.wavelist(kind='all')))
|
||||
|
||||
assert_raises(ValueError, pywt.wavelist, kind='foobar')
|
||||
|
||||
|
||||
def test_wavelet_errormsgs():
|
||||
try:
|
||||
pywt.Wavelet('gaus1')
|
||||
except ValueError as e:
|
||||
assert_(e.args[0].startswith('The `Wavelet` class'))
|
||||
try:
|
||||
pywt.Wavelet('cmord')
|
||||
except ValueError as e:
|
||||
assert_(e.args[0] == "Invalid wavelet name 'cmord'.")
|
105
venv/Lib/site-packages/pywt/tests/test_concurrent.py
Normal file
|
@@ -0,0 +1,105 @@
|
|||
"""
|
||||
Tests used to verify running PyWavelets transforms in parallel via
|
||||
concurrent.futures.ThreadPoolExecutor does not raise errors.
|
||||
"""
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import warnings
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
from numpy.testing import assert_array_equal, assert_allclose
|
||||
from pywt._pytest import uses_futures, futures, max_workers
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def _assert_all_coeffs_equal(coefs1, coefs2):
|
||||
# return True only if all coefficients of SWT or DWT match over all levels
|
||||
if len(coefs1) != len(coefs2):
|
||||
return False
|
||||
for (c1, c2) in zip(coefs1, coefs2):
|
||||
if isinstance(c1, tuple):
|
||||
# for swt, swt2, dwt, dwt2, wavedec, wavedec2
|
||||
for a1, a2 in zip(c1, c2):
|
||||
assert_array_equal(a1, a2)
|
||||
elif isinstance(c1, dict):
|
||||
# for swtn, dwtn, wavedecn
|
||||
for k, v in c1.items():
|
||||
assert_array_equal(v, c2[k])
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@uses_futures
|
||||
def test_concurrent_swt():
|
||||
# tests error-free concurrent operation (see gh-288)
|
||||
# swt on 1D data calls the Cython swt
|
||||
# other cases call swt_axes
|
||||
with warnings.catch_warnings():
|
||||
# can remove catch_warnings once the swt2 FutureWarning is removed
|
||||
warnings.simplefilter('ignore', FutureWarning)
|
||||
for swt_func, x in zip([pywt.swt, pywt.swt2, pywt.swtn],
|
||||
[np.ones(8), np.eye(16), np.eye(16)]):
|
||||
transform = partial(swt_func, wavelet='haar', level=3)
|
||||
for _ in range(10):
|
||||
arrs = [x.copy() for _ in range(100)]
|
||||
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
|
||||
results = list(ex.map(transform, arrs))
|
||||
|
||||
# validate result from one of the concurrent runs
|
||||
expected_result = transform(x)
|
||||
_assert_all_coeffs_equal(expected_result, results[-1])
|
||||
|
||||
|
||||
@uses_futures
|
||||
def test_concurrent_wavedec():
|
||||
# wavedec on 1D data calls the Cython dwt_single
|
||||
# other cases call dwt_axis
|
||||
for wavedec_func, x in zip([pywt.wavedec, pywt.wavedec2, pywt.wavedecn],
|
||||
[np.ones(8), np.eye(16), np.eye(16)]):
|
||||
transform = partial(wavedec_func, wavelet='haar', level=1)
|
||||
for _ in range(10):
|
||||
arrs = [x.copy() for _ in range(100)]
|
||||
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
|
||||
results = list(ex.map(transform, arrs))
|
||||
|
||||
# validate result from one of the concurrent runs
|
||||
expected_result = transform(x)
|
||||
_assert_all_coeffs_equal(expected_result, results[-1])
|
||||
|
||||
|
||||
@uses_futures
|
||||
def test_concurrent_dwt():
|
||||
# dwt on 1D data calls the Cython dwt_single
|
||||
# other cases call dwt_axis
|
||||
for dwt_func, x in zip([pywt.dwt, pywt.dwt2, pywt.dwtn],
|
||||
[np.ones(8), np.eye(16), np.eye(16)]):
|
||||
transform = partial(dwt_func, wavelet='haar')
|
||||
for _ in range(10):
|
||||
arrs = [x.copy() for _ in range(100)]
|
||||
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
|
||||
results = list(ex.map(transform, arrs))
|
||||
|
||||
# validate result from one of the concurrent runs
|
||||
expected_result = transform(x)
|
||||
_assert_all_coeffs_equal([expected_result, ], [results[-1], ])
|
||||
|
||||
|
||||
@uses_futures
|
||||
def test_concurrent_cwt():
|
||||
atol = rtol = 1e-14
|
||||
time, sst = pywt.data.nino()
|
||||
dt = time[1]-time[0]
|
||||
transform = partial(pywt.cwt, scales=np.arange(1, 4), wavelet='cmor1.5-1',
|
||||
sampling_period=dt)
|
||||
for _ in range(10):
|
||||
arrs = [sst.copy() for _ in range(50)]
|
||||
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
|
||||
results = list(ex.map(transform, arrs))
|
||||
|
||||
# validate result from one of the concurrent runs
|
||||
expected_result = transform(sst)
|
||||
for a1, a2 in zip(expected_result, results[-1]):
|
||||
assert_allclose(a1, a2, atol=atol, rtol=rtol)
|
434
venv/Lib/site-packages/pywt/tests/test_cwt_wavelets.py
Normal file
|
@@ -0,0 +1,434 @@
|
|||
#!/usr/bin/env python
|
||||
from __future__ import division, print_function, absolute_import
|
||||
from itertools import product
|
||||
|
||||
from numpy.testing import (assert_allclose, assert_warns, assert_almost_equal,
|
||||
assert_raises, assert_equal)
|
||||
import pytest
|
||||
import numpy as np
|
||||
import pywt
|
||||
|
||||
|
||||
def ref_gaus(LB, UB, N, num):
|
||||
X = np.linspace(LB, UB, N)
|
||||
F0 = (2./np.pi)**(1./4.)*np.exp(-(X**2))
|
||||
if (num == 1):
|
||||
psi = -2.*X*F0
|
||||
elif (num == 2):
|
||||
psi = -2/(3**(1/2))*(-1 + 2*X**2)*F0
|
||||
elif (num == 3):
|
||||
psi = -4/(15**(1/2))*X*(3 - 2*X**2)*F0
|
||||
elif (num == 4):
|
||||
psi = 4/(105**(1/2))*(3 - 12*X**2 + 4*X**4)*F0
|
||||
elif (num == 5):
|
||||
psi = 8/(3*(105**(1/2)))*X*(-15 + 20*X**2 - 4*X**4)*F0
|
||||
elif (num == 6):
|
||||
psi = -8/(3*(1155**(1/2)))*(-15 + 90*X**2 - 60*X**4 + 8*X**6)*F0
|
||||
elif (num == 7):
|
||||
psi = -16/(3*(15015**(1/2)))*X*(105 - 210*X**2 + 84*X**4 - 8*X**6)*F0
|
||||
elif (num == 8):
|
||||
psi = 16/(45*(1001**(1/2)))*(105 - 840*X**2 + 840*X**4 -
|
||||
224*X**6 + 16*X**8)*F0
|
||||
return (psi, X)
|
||||
|
||||
|
||||
def ref_cgau(LB, UB, N, num):
|
||||
X = np.linspace(LB, UB, N)
|
||||
F0 = np.exp(-X**2)
|
||||
F1 = np.exp(-1j*X)
|
||||
F2 = (F1*F0)/(np.exp(-1/2)*2**(1/2)*np.pi**(1/2))**(1/2)
|
||||
if (num == 1):
|
||||
psi = F2*(-1j - 2*X)*2**(1/2)
|
||||
elif (num == 2):
|
||||
psi = 1/3*F2*(-3 + 4j*X + 4*X**2)*6**(1/2)
|
||||
elif (num == 3):
|
||||
psi = 1/15*F2*(7j + 18*X - 12j*X**2 - 8*X**3)*30**(1/2)
|
||||
elif (num == 4):
|
||||
psi = 1/105*F2*(25 - 56j*X - 72*X**2 + 32j*X**3 + 16*X**4)*210**(1/2)
|
||||
elif (num == 5):
|
||||
psi = 1/315*F2*(-81j - 250*X + 280j*X**2 + 240*X**3 -
|
||||
80j*X**4 - 32*X**5)*210**(1/2)
|
||||
elif (num == 6):
|
||||
psi = 1/3465*F2*(-331 + 972j*X + 1500*X**2 - 1120j*X**3 - 720*X**4 +
|
||||
192j*X**5 + 64*X**6)*2310**(1/2)
|
||||
elif (num == 7):
|
||||
psi = 1/45045*F2*(
|
||||
1303j + 4634*X - 6804j*X**2 - 7000*X**3 + 3920j*X**4 + 2016*X**5 -
|
||||
448j*X**6 - 128*X**7)*30030**(1/2)
|
||||
elif (num == 8):
|
||||
psi = 1/45045*F2*(
|
||||
5937 - 20848j*X - 37072*X**2 + 36288j*X**3 + 28000*X**4 -
|
||||
12544j*X**5 - 5376*X**6 + 1024j*X**7 + 256*X**8)*2002**(1/2)
|
||||
|
||||
psi = psi/np.real(np.sqrt(np.real(np.sum(psi*np.conj(psi)))*(X[1] - X[0])))
|
||||
return (psi, X)
|
||||
|
||||
|
||||
def sinc2(x):
|
||||
y = np.ones_like(x)
|
||||
k = np.where(x)[0]
|
||||
y[k] = np.sin(np.pi*x[k])/(np.pi*x[k])
|
||||
return y
|
||||
|
||||
|
||||
def ref_shan(LB, UB, N, Fb, Fc):
|
||||
x = np.linspace(LB, UB, N)
|
||||
psi = np.sqrt(Fb)*(sinc2(Fb*x)*np.exp(2j*np.pi*Fc*x))
|
||||
return (psi, x)
|
||||
|
||||
|
||||
def ref_fbsp(LB, UB, N, m, Fb, Fc):
|
||||
x = np.linspace(LB, UB, N)
|
||||
psi = np.sqrt(Fb)*((sinc2(Fb*x/m)**m)*np.exp(2j*np.pi*Fc*x))
|
||||
return (psi, x)
|
||||
|
||||
|
||||
def ref_cmor(LB, UB, N, Fb, Fc):
|
||||
x = np.linspace(LB, UB, N)
|
||||
psi = ((np.pi*Fb)**(-0.5))*np.exp(2j*np.pi*Fc*x)*np.exp(-(x**2)/Fb)
|
||||
return (psi, x)
|
||||
|
||||
|
||||
def ref_morl(LB, UB, N):
|
||||
x = np.linspace(LB, UB, N)
|
||||
psi = np.exp(-(x**2)/2)*np.cos(5*x)
|
||||
return (psi, x)
|
||||
|
||||
|
||||
def ref_mexh(LB, UB, N):
|
||||
x = np.linspace(LB, UB, N)
|
||||
psi = (2/(np.sqrt(3)*np.pi**0.25))*np.exp(-(x**2)/2)*(1 - (x**2))
|
||||
return (psi, x)
|
||||
|
||||
|
||||
def test_gaus():
|
||||
LB = -5
|
||||
UB = 5
|
||||
N = 1000
|
||||
for num in np.arange(1, 9):
|
||||
[psi, x] = ref_gaus(LB, UB, N, num)
|
||||
w = pywt.ContinuousWavelet("gaus" + str(num))
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi))
|
||||
assert_allclose(np.imag(PSI), np.imag(psi))
|
||||
assert_allclose(X, x)
|
||||
|
||||
|
||||
def test_cgau():
|
||||
LB = -5
|
||||
UB = 5
|
||||
N = 1000
|
||||
for num in np.arange(1, 9):
|
||||
[psi, x] = ref_cgau(LB, UB, N, num)
|
||||
w = pywt.ContinuousWavelet("cgau" + str(num))
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi))
|
||||
assert_allclose(np.imag(PSI), np.imag(psi))
|
||||
assert_allclose(X, x)
|
||||
|
||||
|
||||
def test_shan():
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
Fb = 1
|
||||
Fc = 1.5
|
||||
|
||||
[psi, x] = ref_shan(LB, UB, N, Fb, Fc)
|
||||
w = pywt.ContinuousWavelet("shan{}-{}".format(Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
Fb = 1.5
|
||||
Fc = 1
|
||||
|
||||
[psi, x] = ref_shan(LB, UB, N, Fb, Fc)
|
||||
w = pywt.ContinuousWavelet("shan{}-{}".format(Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
|
||||
def test_cmor():
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
Fb = 1
|
||||
Fc = 1.5
|
||||
|
||||
[psi, x] = ref_cmor(LB, UB, N, Fb, Fc)
|
||||
w = pywt.ContinuousWavelet("cmor{}-{}".format(Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
Fb = 1.5
|
||||
Fc = 1
|
||||
|
||||
[psi, x] = ref_cmor(LB, UB, N, Fb, Fc)
|
||||
w = pywt.ContinuousWavelet("cmor{}-{}".format(Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
|
||||
def test_fbsp():
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
M = 2
|
||||
Fb = 1
|
||||
Fc = 1.5
|
||||
|
||||
[psi, x] = ref_fbsp(LB, UB, N, M, Fb, Fc)
|
||||
|
||||
w = pywt.ContinuousWavelet("fbsp{}-{}-{}".format(M, Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.fbsp_order = M
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
M = 2
|
||||
Fb = 1.5
|
||||
Fc = 1
|
||||
|
||||
[psi, x] = ref_fbsp(LB, UB, N, M, Fb, Fc)
|
||||
w = pywt.ContinuousWavelet("fbsp{}-{}-{}".format(M, Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.fbsp_order = M
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
LB = -20
|
||||
UB = 20
|
||||
N = 1000
|
||||
M = 3
|
||||
Fb = 1.5
|
||||
Fc = 1.2
|
||||
|
||||
[psi, x] = ref_fbsp(LB, UB, N, M, Fb, Fc)
|
||||
w = pywt.ContinuousWavelet("fbsp{}-{}-{}".format(M, Fb, Fc))
|
||||
assert_almost_equal(w.center_frequency, Fc)
|
||||
assert_almost_equal(w.bandwidth_frequency, Fb)
|
||||
w.fbsp_order = M
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
# TODO: investigate why atol = 1e-5 is necessary
|
||||
assert_allclose(np.real(PSI), np.real(psi), atol=1e-5)
|
||||
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-5)
|
||||
assert_allclose(X, x, atol=1e-15)
|
||||
|
||||
|
||||
def test_morl():
|
||||
LB = -5
|
||||
UB = 5
|
||||
N = 1000
|
||||
|
||||
[psi, x] = ref_morl(LB, UB, N)
|
||||
w = pywt.ContinuousWavelet("morl")
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi))
|
||||
assert_allclose(np.imag(PSI), np.imag(psi))
|
||||
assert_allclose(X, x)
|
||||
|
||||
|
||||
def test_mexh():
|
||||
LB = -5
|
||||
UB = 5
|
||||
N = 1000
|
||||
|
||||
[psi, x] = ref_mexh(LB, UB, N)
|
||||
w = pywt.ContinuousWavelet("mexh")
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi))
|
||||
assert_allclose(np.imag(PSI), np.imag(psi))
|
||||
assert_allclose(X, x)
|
||||
|
||||
LB = -5
|
||||
UB = 5
|
||||
N = 1001
|
||||
|
||||
[psi, x] = ref_mexh(LB, UB, N)
|
||||
w = pywt.ContinuousWavelet("mexh")
|
||||
w.upper_bound = UB
|
||||
w.lower_bound = LB
|
||||
PSI, X = w.wavefun(length=N)
|
||||
|
||||
assert_allclose(np.real(PSI), np.real(psi))
|
||||
assert_allclose(np.imag(PSI), np.imag(psi))
|
||||
assert_allclose(X, x)
|
||||
|
||||
|
||||
def test_cwt_parameters_in_names():
|
||||
|
||||
for func in [pywt.ContinuousWavelet, pywt.DiscreteContinuousWavelet]:
|
||||
for name in ['fbsp', 'cmor', 'shan']:
|
||||
# additional parameters should be specified within the name
|
||||
assert_warns(FutureWarning, func, name)
|
||||
|
||||
for name in ['cmor', 'shan']:
|
||||
# valid names
|
||||
func(name + '1.5-1.0')
|
||||
func(name + '1-4')
|
||||
|
||||
# invalid names
|
||||
assert_raises(ValueError, func, name + '1.0')
|
||||
assert_raises(ValueError, func, name + 'B-C')
|
||||
assert_raises(ValueError, func, name + '1.0-1.0-1.0')
|
||||
|
||||
# valid names
|
||||
func('fbsp1-1.5-1.0')
|
||||
func('fbsp1.0-1.5-1')
|
||||
func('fbsp2-5-1')
|
||||
|
||||
# invalid name (non-integer order)
|
||||
assert_raises(ValueError, func, 'fbsp1.5-1-1')
|
||||
assert_raises(ValueError, func, 'fbspM-B-C')
|
||||
|
||||
# invalid name (too few or too many params)
|
||||
assert_raises(ValueError, func, 'fbsp1.0')
|
||||
assert_raises(ValueError, func, 'fbsp1.0-0.4')
|
||||
assert_raises(ValueError, func, 'fbsp1-1-1-1')
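    # To summarize the naming convention exercised above (illustrative):
    #     pywt.ContinuousWavelet('cmor1.5-1.0')    # Fb=1.5, Fc=1.0
    #     pywt.ContinuousWavelet('shan1.5-1.0')    # Fb=1.5, Fc=1.0
    #     pywt.ContinuousWavelet('fbsp2-1.5-1.0')  # M=2, Fb=1.5, Fc=1.0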
|
||||
|
||||
|
||||
@pytest.mark.parametrize('dtype, tol, method',
|
||||
[(np.float32, 1e-5, 'conv'),
|
||||
(np.float32, 1e-5, 'fft'),
|
||||
(np.float64, 1e-13, 'conv'),
|
||||
(np.float64, 1e-13, 'fft')])
|
||||
def test_cwt_complex(dtype, tol, method):
|
||||
time, sst = pywt.data.nino()
|
||||
sst = np.asarray(sst, dtype=dtype)
|
||||
dt = time[1] - time[0]
|
||||
wavelet = 'cmor1.5-1.0'
|
||||
scales = np.arange(1, 32)
|
||||
|
||||
    # real-valued transform as a reference
|
||||
[cfs, f] = pywt.cwt(sst, scales, wavelet, dt, method=method)
|
||||
|
||||
# verify same precision
|
||||
assert_equal(cfs.real.dtype, sst.dtype)
|
||||
|
||||
# complex-valued transform equals sum of the transforms of the real
|
||||
# and imaginary components
|
||||
sst_complex = sst + 1j*sst
|
||||
[cfs_complex, f] = pywt.cwt(sst_complex, scales, wavelet, dt,
|
||||
method=method)
|
||||
assert_allclose(cfs + 1j*cfs, cfs_complex, atol=tol, rtol=tol)
|
||||
# verify dtype is preserved
|
||||
assert_equal(cfs_complex.dtype, sst_complex.dtype)
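

# Illustrative sketch (hypothetical helper, deliberately not collected by
# pytest): since cwt() is linear in its input, the complex-data check above
# can equivalently be written against two independent real signals.
def _example_cwt_linearity_sketch():
    rstate = np.random.RandomState(0)
    a = rstate.randn(64)
    b = rstate.randn(64)
    scales = np.arange(1, 8)
    # transform each real signal separately, then their complex combination
    cfs_a, _ = pywt.cwt(a, scales, 'cmor1.5-1.0')
    cfs_b, _ = pywt.cwt(b, scales, 'cmor1.5-1.0')
    cfs_ab, _ = pywt.cwt(a + 1j*b, scales, 'cmor1.5-1.0')
    # linearity: cwt(a + 1j*b) == cwt(a) + 1j*cwt(b)
    assert_allclose(cfs_a + 1j*cfs_b, cfs_ab, rtol=1e-12, atol=1e-12)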
|
||||
|
||||
|
||||
@pytest.mark.parametrize('axis, method', product([0, 1], ['conv', 'fft']))
|
||||
def test_cwt_batch(axis, method):
|
||||
dtype = np.float64
|
||||
time, sst = pywt.data.nino()
|
||||
n_batch = 8
|
||||
batch_axis = 1 - axis
|
||||
sst1 = np.asarray(sst, dtype=dtype)
|
||||
sst = np.stack((sst1, ) * n_batch, axis=batch_axis)
|
||||
dt = time[1] - time[0]
|
||||
wavelet = 'cmor1.5-1.0'
|
||||
scales = np.arange(1, 32)
|
||||
|
||||
# non-batch transform as reference
|
||||
[cfs1, f] = pywt.cwt(sst1, scales, wavelet, dt, method=method, axis=axis)
|
||||
|
||||
shape_in = sst.shape
|
||||
[cfs, f] = pywt.cwt(sst, scales, wavelet, dt, method=method, axis=axis)
|
||||
|
||||
# shape of input is not modified
|
||||
assert_equal(shape_in, sst.shape)
|
||||
|
||||
# verify same precision
|
||||
assert_equal(cfs.real.dtype, sst.dtype)
|
||||
|
||||
# verify expected shape
|
||||
assert_equal(cfs.shape[0], len(scales))
|
||||
assert_equal(cfs.shape[1 + batch_axis], n_batch)
|
||||
assert_equal(cfs.shape[1 + axis], sst.shape[axis])
|
||||
|
||||
# batch result on stacked input is the same as stacked 1d result
|
||||
assert_equal(cfs, np.stack((cfs1,) * n_batch, axis=batch_axis + 1))
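    # in other words, the expected output shape is (len(scales),) + sst.shape,
    # with the transformed axis left where `axis` points and the batch axis
    # carried through unchanged.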
|
||||
|
||||
|
||||
def test_cwt_small_scales():
|
||||
data = np.zeros(32)
|
||||
|
||||
# A scale of 0.1 was chosen specifically to give a filter of length 2 for
|
||||
# mexh. This corner case should not raise an error.
|
||||
cfs, f = pywt.cwt(data, scales=0.1, wavelet='mexh')
|
||||
assert_allclose(cfs, np.zeros_like(cfs))
|
||||
|
||||
# extremely short scale factors raise a ValueError
|
||||
assert_raises(ValueError, pywt.cwt, data, scales=0.01, wavelet='mexh')
|
||||
|
||||
|
||||
def test_cwt_method_fft():
|
||||
rstate = np.random.RandomState(1)
|
||||
data = rstate.randn(50)
|
||||
data[15] = 1.
|
||||
scales = np.arange(1, 64)
|
||||
wavelet = 'cmor1.5-1.0'
|
||||
|
||||
    # build a reference cwt with the legacy np.convolve()-based method
|
||||
cfs_conv, _ = pywt.cwt(data, scales, wavelet, method='conv')
|
||||
|
||||
# compare with the fft based convolution
|
||||
cfs_fft, _ = pywt.cwt(data, scales, wavelet, method='fft')
|
||||
assert_allclose(cfs_conv, cfs_fft, rtol=0, atol=1e-13)
|
77
venv/Lib/site-packages/pywt/tests/test_data.py
Normal file
@@ -0,0 +1,77 @@
import os
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose, assert_raises, assert_
|
||||
|
||||
import pywt.data
|
||||
|
||||
data_dir = os.path.join(os.path.dirname(__file__), 'data')
|
||||
wavelab_data_file = os.path.join(data_dir, 'wavelab_test_signals.npz')
|
||||
wavelab_result_dict = np.load(wavelab_data_file)
|
||||
|
||||
|
||||
def test_data_aero():
|
||||
aero = pywt.data.aero()
|
||||
|
||||
ref = np.array([[178, 178, 179],
|
||||
[170, 173, 171],
|
||||
[185, 174, 171]])
|
||||
|
||||
assert_allclose(aero[:3, :3], ref)
|
||||
|
||||
|
||||
def test_data_ascent():
|
||||
ascent = pywt.data.ascent()
|
||||
|
||||
ref = np.array([[83, 83, 83],
|
||||
[82, 82, 83],
|
||||
[80, 81, 83]])
|
||||
|
||||
assert_allclose(ascent[:3, :3], ref)
|
||||
|
||||
|
||||
def test_data_camera():
|
||||
ascent = pywt.data.camera()
|
||||
|
||||
ref = np.array([[156, 157, 160],
|
||||
[156, 157, 159],
|
||||
[158, 157, 156]])
|
||||
|
||||
assert_allclose(ascent[:3, :3], ref)
|
||||
|
||||
|
||||
def test_data_ecg():
|
||||
ecg = pywt.data.ecg()
|
||||
|
||||
ref = np.array([-86, -87, -87])
|
||||
|
||||
assert_allclose(ecg[:3], ref)
|
||||
|
||||
|
||||
def test_wavelab_signals():
|
||||
"""Comparison with results generated using WaveLab"""
|
||||
rtol = atol = 1e-12
|
||||
|
||||
# get a list of the available signals
|
||||
available_signals = pywt.data.demo_signal('list')
|
||||
assert_('Doppler' in available_signals)
|
||||
|
||||
for signal in available_signals:
|
||||
# reference dictionary has lowercase names for the keys
|
||||
key = signal.replace('-', '_').lower()
|
||||
val = wavelab_result_dict[key]
|
||||
if key in ['gabor', 'sineoneoverx']:
|
||||
# these functions do not allow a size to be provided
|
||||
assert_allclose(val, pywt.data.demo_signal(signal),
|
||||
rtol=rtol, atol=atol)
|
||||
assert_raises(ValueError, pywt.data.demo_signal, key, val.size)
|
||||
else:
|
||||
assert_allclose(val, pywt.data.demo_signal(signal, val.size),
|
||||
rtol=rtol, atol=atol)
|
||||
# these functions require a size to be provided
|
||||
assert_raises(ValueError, pywt.data.demo_signal, key)
|
||||
|
||||
# ValueError on unrecognized signal type
|
||||
assert_raises(ValueError, pywt.data.demo_signal, 'unknown_signal', 512)
|
||||
|
||||
# ValueError on invalid length
|
||||
assert_raises(ValueError, pywt.data.demo_signal, 'Doppler', 0)
|
89
venv/Lib/site-packages/pywt/tests/test_deprecations.py
Normal file
@@ -0,0 +1,89 @@
import warnings
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import assert_warns, assert_array_equal
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_intwave_deprecation():
|
||||
wavelet = pywt.Wavelet('db3')
|
||||
assert_warns(DeprecationWarning, pywt.intwave, wavelet)
|
||||
|
||||
|
||||
def test_centrfrq_deprecation():
|
||||
wavelet = pywt.Wavelet('db3')
|
||||
assert_warns(DeprecationWarning, pywt.centrfrq, wavelet)
|
||||
|
||||
|
||||
def test_scal2frq_deprecation():
|
||||
wavelet = pywt.Wavelet('db3')
|
||||
assert_warns(DeprecationWarning, pywt.scal2frq, wavelet, 1)
|
||||
|
||||
|
||||
def test_orthfilt_deprecation():
|
||||
assert_warns(DeprecationWarning, pywt.orthfilt, range(6))
|
||||
|
||||
|
||||
def test_integrate_wave_tuple():
|
||||
sig = [0, 1, 2, 3]
|
||||
xgrid = [0, 1, 2, 3]
|
||||
assert_warns(DeprecationWarning, pywt.integrate_wavelet, (sig, xgrid))
|
||||
|
||||
|
||||
old_modes = ['zpd',
|
||||
'cpd',
|
||||
'sym',
|
||||
'ppd',
|
||||
'sp1',
|
||||
'per',
|
||||
]
|
||||
|
||||
|
||||
def test_MODES_from_object_deprecation():
|
||||
for mode in old_modes:
|
||||
assert_warns(DeprecationWarning, pywt.Modes.from_object, mode)
|
||||
|
||||
|
||||
def test_MODES_attributes_deprecation():
|
||||
def get_mode(Modes, name):
|
||||
return getattr(Modes, name)
|
||||
|
||||
for mode in old_modes:
|
||||
assert_warns(DeprecationWarning, get_mode, pywt.Modes, mode)
|
||||
|
||||
|
||||
def test_MODES_deprecation_new():
|
||||
def use_MODES_new():
|
||||
return pywt.MODES.symmetric
|
||||
|
||||
assert_warns(DeprecationWarning, use_MODES_new)
|
||||
|
||||
|
||||
def test_MODES_deprecation_old():
|
||||
def use_MODES_old():
|
||||
return pywt.MODES.sym
|
||||
|
||||
assert_warns(DeprecationWarning, use_MODES_old)
|
||||
|
||||
|
||||
def test_MODES_deprecation_getattr():
|
||||
def use_MODES_new():
|
||||
return getattr(pywt.MODES, 'symmetric')
|
||||
|
||||
assert_warns(DeprecationWarning, use_MODES_new)
|
||||
|
||||
|
||||
def test_mode_equivalence():
|
||||
old_new = [('zpd', 'zero'),
|
||||
('cpd', 'constant'),
|
||||
('sym', 'symmetric'),
|
||||
('ppd', 'periodic'),
|
||||
('sp1', 'smooth'),
|
||||
('per', 'periodization')]
|
||||
x = np.arange(8.)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', DeprecationWarning)
|
||||
for old, new in old_new:
|
||||
assert_array_equal(pywt.dwt(x, 'db2', mode=old),
|
||||
pywt.dwt(x, 'db2', mode=new))
|
25
venv/Lib/site-packages/pywt/tests/test_doc.py
Normal file
@@ -0,0 +1,25 @@
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import doctest
|
||||
import glob
|
||||
import os
|
||||
import unittest
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
np.set_printoptions(legacy='1.13')
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
pdir = os.path.pardir
|
||||
docs_base = os.path.abspath(os.path.join(os.path.dirname(__file__),
|
||||
pdir, pdir, "doc", "source"))
|
||||
|
||||
files = glob.glob(os.path.join(docs_base, "*.rst")) + \
|
||||
glob.glob(os.path.join(docs_base, "*", "*.rst"))
|
||||
|
||||
suite = doctest.DocFileSuite(*files, module_relative=False, encoding="utf-8")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.TextTestRunner().run(suite)
|
299
venv/Lib/site-packages/pywt/tests/test_dwt_idwt.py
Normal file
@@ -0,0 +1,299 @@
#!/usr/bin/env python
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import (assert_allclose, assert_, assert_raises,
|
||||
assert_array_equal)
|
||||
import pywt
|
||||
|
||||
# Check that float32, float64, complex64, complex128 are preserved.
|
||||
# Other real types get converted to float64.
|
||||
# complex256 gets converted to complex128
|
||||
dtypes_in = [np.int8, np.float16, np.float32, np.float64, np.complex64,
|
||||
np.complex128]
|
||||
dtypes_out = [np.float64, np.float32, np.float32, np.float64, np.complex64,
|
||||
np.complex128]
|
||||
|
||||
# test complex256 as well if it is available
|
||||
try:
|
||||
dtypes_in += [np.complex256, ]
|
||||
dtypes_out += [np.complex128, ]
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
def test_dwt_idwt_basic():
|
||||
x = [3, 7, 1, 1, -2, 5, 4, 6]
|
||||
cA, cD = pywt.dwt(x, 'db2')
|
||||
cA_expect = [5.65685425, 7.39923721, 0.22414387, 3.33677403, 7.77817459]
|
||||
cD_expect = [-2.44948974, -1.60368225, -4.44140056, -0.41361256,
|
||||
1.22474487]
|
||||
assert_allclose(cA, cA_expect)
|
||||
assert_allclose(cD, cD_expect)
|
||||
|
||||
x_roundtrip = pywt.idwt(cA, cD, 'db2')
|
||||
assert_allclose(x_roundtrip, x, rtol=1e-10)
|
||||
|
||||
# mismatched dtypes OK
|
||||
x_roundtrip2 = pywt.idwt(cA.astype(np.float64), cD.astype(np.float32),
|
||||
'db2')
|
||||
assert_allclose(x_roundtrip2, x, rtol=1e-7, atol=1e-7)
|
||||
assert_(x_roundtrip2.dtype == np.float64)
|
||||
|
||||
|
||||
def test_idwt_mixed_complex_dtype():
|
||||
x = np.arange(8).astype(float)
|
||||
x = x + 1j*x[::-1]
|
||||
cA, cD = pywt.dwt(x, 'db2')
|
||||
|
||||
x_roundtrip = pywt.idwt(cA, cD, 'db2')
|
||||
assert_allclose(x_roundtrip, x, rtol=1e-10)
|
||||
|
||||
# mismatched dtypes OK
|
||||
x_roundtrip2 = pywt.idwt(cA.astype(np.complex128), cD.astype(np.complex64),
|
||||
'db2')
|
||||
assert_allclose(x_roundtrip2, x, rtol=1e-7, atol=1e-7)
|
||||
assert_(x_roundtrip2.dtype == np.complex128)
|
||||
|
||||
|
||||
def test_dwt_idwt_dtypes():
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
|
||||
x = np.ones(4, dtype=dt_in)
|
||||
errmsg = "wrong dtype returned for {0} input".format(dt_in)
|
||||
|
||||
cA, cD = pywt.dwt(x, wavelet)
|
||||
assert_(cA.dtype == cD.dtype == dt_out, "dwt: " + errmsg)
|
||||
|
||||
x_roundtrip = pywt.idwt(cA, cD, wavelet)
|
||||
assert_(x_roundtrip.dtype == dt_out, "idwt: " + errmsg)
|
||||
|
||||
|
||||
def test_dwt_idwt_basic_complex():
|
||||
x = np.asarray([3, 7, 1, 1, -2, 5, 4, 6])
|
||||
x = x + 0.5j*x
|
||||
cA, cD = pywt.dwt(x, 'db2')
|
||||
cA_expect = np.asarray([5.65685425, 7.39923721, 0.22414387, 3.33677403,
|
||||
7.77817459])
|
||||
cA_expect = cA_expect + 0.5j*cA_expect
|
||||
cD_expect = np.asarray([-2.44948974, -1.60368225, -4.44140056, -0.41361256,
|
||||
1.22474487])
|
||||
cD_expect = cD_expect + 0.5j*cD_expect
|
||||
assert_allclose(cA, cA_expect)
|
||||
assert_allclose(cD, cD_expect)
|
||||
|
||||
x_roundtrip = pywt.idwt(cA, cD, 'db2')
|
||||
assert_allclose(x_roundtrip, x, rtol=1e-10)
|
||||
|
||||
|
||||
def test_dwt_idwt_partial_complex():
|
||||
x = np.asarray([3, 7, 1, 1, -2, 5, 4, 6])
|
||||
x = x + 0.5j*x
|
||||
|
||||
cA, cD = pywt.dwt(x, 'haar')
|
||||
cA_rec_expect = np.array([5.0+2.5j, 5.0+2.5j, 1.0+0.5j, 1.0+0.5j,
|
||||
1.5+0.75j, 1.5+0.75j, 5.0+2.5j, 5.0+2.5j])
|
||||
cA_rec = pywt.idwt(cA, None, 'haar')
|
||||
assert_allclose(cA_rec, cA_rec_expect)
|
||||
|
||||
cD_rec_expect = np.array([-2.0-1.0j, 2.0+1.0j, 0.0+0.0j, 0.0+0.0j,
|
||||
-3.5-1.75j, 3.5+1.75j, -1.0-0.5j, 1.0+0.5j])
|
||||
cD_rec = pywt.idwt(None, cD, 'haar')
|
||||
assert_allclose(cD_rec, cD_rec_expect)
|
||||
|
||||
assert_allclose(cA_rec + cD_rec, x)
|
||||
|
||||
|
||||
def test_dwt_wavelet_kwd():
|
||||
x = np.array([3, 7, 1, 1, -2, 5, 4, 6])
|
||||
w = pywt.Wavelet('sym3')
|
||||
cA, cD = pywt.dwt(x, wavelet=w, mode='constant')
|
||||
cA_expect = [4.38354585, 3.80302657, 7.31813271, -0.58565539, 4.09727044,
|
||||
7.81994027]
|
||||
cD_expect = [-1.33068221, -2.78795192, -3.16825651, -0.67715519,
|
||||
-0.09722957, -0.07045258]
|
||||
assert_allclose(cA, cA_expect)
|
||||
assert_allclose(cD, cD_expect)
|
||||
|
||||
|
||||
def test_dwt_coeff_len():
|
||||
x = np.array([3, 7, 1, 1, -2, 5, 4, 6])
|
||||
w = pywt.Wavelet('sym3')
|
||||
ln_modes = [pywt.dwt_coeff_len(len(x), w.dec_len, mode) for mode in
|
||||
pywt.Modes.modes]
|
||||
|
||||
expected_result = [6, ] * len(pywt.Modes.modes)
|
||||
expected_result[pywt.Modes.modes.index('periodization')] = 4
|
||||
|
||||
assert_allclose(ln_modes, expected_result)
|
||||
ln_modes = [pywt.dwt_coeff_len(len(x), w, mode) for mode in
|
||||
pywt.Modes.modes]
|
||||
assert_allclose(ln_modes, expected_result)
|
||||
|
||||
|
||||
def test_idwt_none_input():
|
||||
# None input equals arrays of zeros of the right length
|
||||
res1 = pywt.idwt([1, 2, 0, 1], None, 'db2', 'symmetric')
|
||||
res2 = pywt.idwt([1, 2, 0, 1], [0, 0, 0, 0], 'db2', 'symmetric')
|
||||
assert_allclose(res1, res2, rtol=1e-15, atol=1e-15)
|
||||
|
||||
res1 = pywt.idwt(None, [1, 2, 0, 1], 'db2', 'symmetric')
|
||||
res2 = pywt.idwt([0, 0, 0, 0], [1, 2, 0, 1], 'db2', 'symmetric')
|
||||
assert_allclose(res1, res2, rtol=1e-15, atol=1e-15)
|
||||
|
||||
# Only one argument at a time can be None
|
||||
assert_raises(ValueError, pywt.idwt, None, None, 'db2', 'symmetric')
|
||||
|
||||
|
||||
def test_idwt_invalid_input():
|
||||
# Too short, min length is 4 for 'db4':
|
||||
assert_raises(ValueError, pywt.idwt, [1, 2, 4], [4, 1, 3], 'db4', 'symmetric')
|
||||
|
||||
|
||||
def test_dwt_single_axis():
|
||||
x = [[3, 7, 1, 1],
|
||||
[-2, 5, 4, 6]]
|
||||
|
||||
cA, cD = pywt.dwt(x, 'db2', axis=-1)
|
||||
|
||||
cA0, cD0 = pywt.dwt(x[0], 'db2')
|
||||
cA1, cD1 = pywt.dwt(x[1], 'db2')
|
||||
|
||||
assert_allclose(cA[0], cA0)
|
||||
assert_allclose(cA[1], cA1)
|
||||
|
||||
assert_allclose(cD[0], cD0)
|
||||
assert_allclose(cD[1], cD1)
|
||||
|
||||
|
||||
def test_idwt_single_axis():
|
||||
x = [[3, 7, 1, 1],
|
||||
[-2, 5, 4, 6]]
|
||||
|
||||
x = np.asarray(x)
|
||||
x = x + 1j*x # test with complex data
|
||||
cA, cD = pywt.dwt(x, 'db2', axis=-1)
|
||||
|
||||
x0 = pywt.idwt(cA[0], cD[0], 'db2', axis=-1)
|
||||
x1 = pywt.idwt(cA[1], cD[1], 'db2', axis=-1)
|
||||
|
||||
assert_allclose(x[0], x0)
|
||||
assert_allclose(x[1], x1)
|
||||
|
||||
|
||||
def test_dwt_axis_arg():
|
||||
x = [[3, 7, 1, 1],
|
||||
[-2, 5, 4, 6]]
|
||||
|
||||
cA_, cD_ = pywt.dwt(x, 'db2', axis=-1)
|
||||
cA, cD = pywt.dwt(x, 'db2', axis=1)
|
||||
|
||||
assert_allclose(cA_, cA)
|
||||
assert_allclose(cD_, cD)
|
||||
|
||||
|
||||
def test_idwt_axis_arg():
|
||||
x = [[3, 7, 1, 1],
|
||||
[-2, 5, 4, 6]]
|
||||
|
||||
cA, cD = pywt.dwt(x, 'db2', axis=1)
|
||||
|
||||
x_ = pywt.idwt(cA, cD, 'db2', axis=-1)
|
||||
x = pywt.idwt(cA, cD, 'db2', axis=1)
|
||||
|
||||
assert_allclose(x_, x)
|
||||
|
||||
|
||||
def test_dwt_idwt_axis_excess():
|
||||
x = [[3, 7, 1, 1],
|
||||
[-2, 5, 4, 6]]
|
||||
# can't transform over axes that aren't there
|
||||
assert_raises(ValueError,
|
||||
pywt.dwt, x, 'db2', 'symmetric', axis=2)
|
||||
|
||||
assert_raises(ValueError,
|
||||
pywt.idwt, [1, 2, 4], [4, 1, 3], 'db2', 'symmetric', axis=1)
|
||||
|
||||
|
||||
def test_error_on_continuous_wavelet():
|
||||
# A ValueError is raised if a Continuous wavelet is selected
|
||||
data = np.ones((32, ))
|
||||
for cwave in ['morl', pywt.DiscreteContinuousWavelet('morl')]:
|
||||
assert_raises(ValueError, pywt.dwt, data, cwave)
|
||||
|
||||
cA, cD = pywt.dwt(data, 'db1')
|
||||
assert_raises(ValueError, pywt.idwt, cA, cD, cwave)
|
||||
|
||||
|
||||
def test_dwt_zero_size_axes():
|
||||
# raise on empty input array
|
||||
assert_raises(ValueError, pywt.dwt, [], 'db2')
|
||||
|
||||
# >1D case uses a different code path so check there as well
|
||||
x = np.ones((1, 4))[0:0, :] # 2D with a size zero axis
|
||||
assert_raises(ValueError, pywt.dwt, x, 'db2', axis=0)
|
||||
|
||||
|
||||
def test_pad_1d():
|
||||
x = [1, 2, 3]
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'periodization'),
|
||||
[1, 2, 3, 3, 1, 2, 3, 3, 1, 2, 3, 3, 1, 2])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'periodic'),
|
||||
[3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'constant'),
|
||||
[1, 1, 1, 1, 1, 2, 3, 3, 3, 3, 3, 3, 3])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'zero'),
|
||||
[0, 0, 0, 0, 1, 2, 3, 0, 0, 0, 0, 0, 0])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'smooth'),
|
||||
[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'symmetric'),
|
||||
[3, 3, 2, 1, 1, 2, 3, 3, 2, 1, 1, 2, 3])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'antisymmetric'),
|
||||
[3, -3, -2, -1, 1, 2, 3, -3, -2, -1, 1, 2, 3])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'reflect'),
|
||||
[1, 2, 3, 2, 1, 2, 3, 2, 1, 2, 3, 2, 1])
|
||||
assert_array_equal(pywt.pad(x, (4, 6), 'antireflect'),
|
||||
[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||
|
||||
# equivalence of various pad_width formats
|
||||
assert_array_equal(pywt.pad(x, 4, 'periodic'),
|
||||
pywt.pad(x, (4, 4), 'periodic'))
|
||||
|
||||
assert_array_equal(pywt.pad(x, (4, ), 'periodic'),
|
||||
pywt.pad(x, (4, 4), 'periodic'))
|
||||
|
||||
assert_array_equal(pywt.pad(x, [(4, 4)], 'periodic'),
|
||||
pywt.pad(x, (4, 4), 'periodic'))
|
||||
|
||||
|
||||
def test_pad_errors():
|
||||
# negative pad width
|
||||
x = [1, 2, 3]
|
||||
assert_raises(ValueError, pywt.pad, x, -2, 'periodic')
|
||||
|
||||
# wrong length pad width
|
||||
assert_raises(ValueError, pywt.pad, x, (1, 1, 1), 'periodic')
|
||||
|
||||
# invalid mode name
|
||||
assert_raises(ValueError, pywt.pad, x, 2, 'bad_mode')
|
||||
|
||||
|
||||
def test_pad_nd():
|
||||
for ndim in [2, 3]:
|
||||
x = np.arange(4**ndim).reshape((4, ) * ndim)
|
||||
if ndim == 2:
|
||||
pad_widths = [(2, 1), (2, 3)]
|
||||
else:
|
||||
pad_widths = [(2, 1), ] * ndim
|
||||
for mode in pywt.Modes.modes:
|
||||
xp = pywt.pad(x, pad_widths, mode)
|
||||
|
||||
# expected result is the same as applying along axes separably
|
||||
xp_expected = x.copy()
|
||||
for ax in range(ndim):
|
||||
xp_expected = np.apply_along_axis(pywt.pad,
|
||||
ax,
|
||||
xp_expected,
|
||||
pad_widths=[pad_widths[ax]],
|
||||
mode=mode)
|
||||
assert_array_equal(xp, xp_expected)
|
38
venv/Lib/site-packages/pywt/tests/test_functions.py
Normal file
@@ -0,0 +1,38 @@
#!/usr/bin/env python
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
from numpy.testing import assert_almost_equal, assert_allclose
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_centrfreq():
|
||||
# db1 is Haar function, frequency=1
|
||||
w = pywt.Wavelet('db1')
|
||||
expected = 1
|
||||
result = pywt.central_frequency(w, precision=12)
|
||||
assert_almost_equal(result, expected, decimal=3)
|
||||
# db2, frequency=2/3
|
||||
w = pywt.Wavelet('db2')
|
||||
expected = 2/3.
|
||||
result = pywt.central_frequency(w, precision=12)
|
||||
assert_almost_equal(result, expected)
|
||||
|
||||
|
||||
def test_scal2frq_scale():
|
||||
scale = 2
|
||||
w = pywt.Wavelet('db1')
|
||||
expected = 1. / scale
|
||||
result = pywt.scale2frequency(w, scale, precision=12)
|
||||
assert_almost_equal(result, expected, decimal=3)
|
||||
|
||||
|
||||
def test_intwave_orthogonal():
|
||||
w = pywt.Wavelet('db1')
|
||||
int_psi, x = pywt.integrate_wavelet(w, precision=12)
|
||||
ix = x < 0.5
|
||||
# For x < 0.5, the integral is equal to x
|
||||
assert_allclose(int_psi[ix], x[ix])
|
||||
# For x > 0.5, the integral is equal to (1 - x)
|
||||
    # Ignore the last point: there x > 1, outside the wavelet support, so the
    # (1 - x) formula no longer applies
|
||||
assert_allclose(int_psi[~ix][:-1], 1 - x[~ix][:-1], atol=1e-10)
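    # Worked check for the Haar case: psi is +1 on [0, 0.5) and -1 on
    # [0.5, 1), so the running integral rises linearly to 0.5 and then falls
    # back to 0, giving int_psi(x) = x for x < 0.5 and 1 - x for 0.5 <= x <= 1.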
|
160
venv/Lib/site-packages/pywt/tests/test_matlab_compatibility.py
Normal file
@@ -0,0 +1,160 @@
"""
|
||||
Test used to verify PyWavelets Discrete Wavelet Transform computation
|
||||
accuracy against MathWorks Wavelet Toolbox.
|
||||
"""
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from numpy.testing import assert_
|
||||
|
||||
import pywt
|
||||
from pywt._pytest import (uses_pymatbridge, uses_precomputed, size_set)
|
||||
from pywt._pytest import matlab_result_dict_dwt as matlab_result_dict
|
||||
|
||||
# list of mode names in pywt and matlab
|
||||
modes = [('zero', 'zpd'),
|
||||
('constant', 'sp0'),
|
||||
('symmetric', 'sym'),
|
||||
('reflect', 'symw'),
|
||||
('periodic', 'ppd'),
|
||||
('smooth', 'sp1'),
|
||||
('periodization', 'per'),
|
||||
         # TODO: asymmetric modes are now implemented too.
|
||||
# Would be nice to update the Matlab data to test these as well.
|
||||
('antisymmetric', 'asym'),
|
||||
('antireflect', 'asymw'),
|
||||
]
|
||||
|
||||
families = ('db', 'sym', 'coif', 'bior', 'rbio')
|
||||
wavelets = sum([pywt.wavelist(name) for name in families], [])
|
||||
|
||||
|
||||
def _get_data_sizes(w):
|
||||
""" Return the sizes to test for wavelet w. """
|
||||
if size_set == 'full':
|
||||
data_sizes = list(range(w.dec_len, 40)) + \
|
||||
[100, 200, 500, 1000, 50000]
|
||||
else:
|
||||
data_sizes = (w.dec_len, w.dec_len + 1)
|
||||
return data_sizes
|
||||
|
||||
|
||||
@uses_pymatbridge
|
||||
@pytest.mark.slow
|
||||
def test_accuracy_pymatbridge():
|
||||
Matlab = pytest.importorskip("pymatbridge.Matlab")
|
||||
mlab = Matlab()
|
||||
|
||||
rstate = np.random.RandomState(1234)
|
||||
    # max RMSE (was 1.0e-10, is reduced to 5.0e-5 due to different coefficients)
|
||||
epsilon = 5.0e-5
|
||||
epsilon_pywt_coeffs = 1.0e-10
|
||||
mlab.start()
|
||||
try:
|
||||
for wavelet in wavelets:
|
||||
w = pywt.Wavelet(wavelet)
|
||||
mlab.set_variable('wavelet', wavelet)
|
||||
for N in _get_data_sizes(w):
|
||||
data = rstate.randn(N)
|
||||
mlab.set_variable('data', data)
|
||||
for pmode, mmode in modes:
|
||||
ma, md = _compute_matlab_result(data, wavelet, mmode, mlab)
|
||||
_check_accuracy(data, w, pmode, ma, md, wavelet, epsilon)
|
||||
ma, md = _load_matlab_result_pywt_coeffs(data, wavelet, mmode)
|
||||
_check_accuracy(data, w, pmode, ma, md, wavelet, epsilon_pywt_coeffs)
|
||||
|
||||
finally:
|
||||
mlab.stop()
|
||||
|
||||
|
||||
@uses_precomputed
|
||||
@pytest.mark.slow
|
||||
def test_accuracy_precomputed():
|
||||
# Keep this specific random seed to match the precomputed Matlab result.
|
||||
rstate = np.random.RandomState(1234)
|
||||
    # max RMSE (was 1.0e-10, is reduced to 5.0e-5 due to different coefficients)
|
||||
epsilon = 5.0e-5
|
||||
epsilon_pywt_coeffs = 1.0e-10
|
||||
for wavelet in wavelets:
|
||||
w = pywt.Wavelet(wavelet)
|
||||
for N in _get_data_sizes(w):
|
||||
data = rstate.randn(N)
|
||||
for pmode, mmode in modes:
|
||||
ma, md = _load_matlab_result(data, wavelet, mmode)
|
||||
_check_accuracy(data, w, pmode, ma, md, wavelet, epsilon)
|
||||
ma, md = _load_matlab_result_pywt_coeffs(data, wavelet, mmode)
|
||||
_check_accuracy(data, w, pmode, ma, md, wavelet, epsilon_pywt_coeffs)
|
||||
|
||||
|
||||
def _compute_matlab_result(data, wavelet, mmode, mlab):
|
||||
""" Compute the result using MATLAB.
|
||||
|
||||
This function assumes that the Matlab variables `wavelet` and `data` have
|
||||
already been set externally.
|
||||
"""
|
||||
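    # Assumption: the MATLAB toolbox build used for the reference data does
    # not accept these higher-order coiflet names, so their filter banks are
    # passed to dwt() explicitly instead.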
    if wavelet in ('coif6', 'coif7', 'coif8', 'coif9', 'coif10', 'coif11',
                   'coif12', 'coif13', 'coif14', 'coif15', 'coif16', 'coif17'):
|
||||
w = pywt.Wavelet(wavelet)
|
||||
mlab.set_variable('Lo_D', w.dec_lo)
|
||||
mlab.set_variable('Hi_D', w.dec_hi)
|
||||
mlab_code = ("[ma, md] = dwt(data, Lo_D, Hi_D, 'mode', '%s');" % mmode)
|
||||
else:
|
||||
mlab_code = "[ma, md] = dwt(data, wavelet, 'mode', '%s');" % mmode
|
||||
res = mlab.run_code(mlab_code)
|
||||
if not res['success']:
|
||||
raise RuntimeError("Matlab failed to execute the provided code. "
|
||||
"Check that the wavelet toolbox is installed.")
|
||||
# need np.asarray because sometimes the output is a single float64
|
||||
ma = np.asarray(mlab.get_variable('ma'))
|
||||
md = np.asarray(mlab.get_variable('md'))
|
||||
return ma, md
|
||||
|
||||
|
||||
def _load_matlab_result(data, wavelet, mmode):
|
||||
""" Load the precomputed result.
|
||||
"""
|
||||
N = len(data)
|
||||
ma_key = '_'.join([mmode, wavelet, str(N), 'ma'])
|
||||
md_key = '_'.join([mmode, wavelet, str(N), 'md'])
|
||||
if (ma_key not in matlab_result_dict) or \
|
||||
(md_key not in matlab_result_dict):
|
||||
raise KeyError(
|
||||
"Precompted Matlab result not found for wavelet: "
|
||||
"{0}, mode: {1}, size: {2}".format(wavelet, mmode, N))
|
||||
ma = matlab_result_dict[ma_key]
|
||||
md = matlab_result_dict[md_key]
|
||||
return ma, md
|
||||
|
||||
|
||||
def _load_matlab_result_pywt_coeffs(data, wavelet, mmode):
|
||||
""" Load the precomputed result.
|
||||
"""
|
||||
N = len(data)
|
||||
ma_key = '_'.join([mmode, wavelet, str(N), 'ma_pywtCoeffs'])
|
||||
md_key = '_'.join([mmode, wavelet, str(N), 'md_pywtCoeffs'])
|
||||
if (ma_key not in matlab_result_dict) or \
|
||||
(md_key not in matlab_result_dict):
|
||||
raise KeyError(
|
||||
"Precompted Matlab result not found for wavelet: "
|
||||
"{0}, mode: {1}, size: {2}".format(wavelet, mmode, N))
|
||||
ma = matlab_result_dict[ma_key]
|
||||
md = matlab_result_dict[md_key]
|
||||
return ma, md
|
||||
|
||||
|
||||
def _check_accuracy(data, w, pmode, ma, md, wavelet, epsilon):
|
||||
# PyWavelets result
|
||||
pa, pd = pywt.dwt(data, w, pmode)
|
||||
|
||||
# calculate error measures
|
||||
rms_a = np.sqrt(np.mean((pa - ma) ** 2))
|
||||
rms_d = np.sqrt(np.mean((pd - md) ** 2))
|
||||
|
||||
msg = ('[RMS_A > EPSILON] for Mode: %s, Wavelet: %s, '
|
||||
'Length: %d, rms=%.3g' % (pmode, wavelet, len(data), rms_a))
|
||||
assert_(rms_a < epsilon, msg=msg)
|
||||
|
||||
msg = ('[RMS_D > EPSILON] for Mode: %s, Wavelet: %s, '
|
||||
'Length: %d, rms=%.3g' % (pmode, wavelet, len(data), rms_d))
|
||||
assert_(rms_d < epsilon, msg=msg)
|
@@ -0,0 +1,174 @@
"""
|
||||
Test used to verify PyWavelets Continuous Wavelet Transform computation
|
||||
accuracy against MathWorks Wavelet Toolbox.
|
||||
"""
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pytest
|
||||
from numpy.testing import assert_
|
||||
|
||||
import pywt
|
||||
from pywt._pytest import (uses_pymatbridge, uses_precomputed, size_set,
|
||||
matlab_result_dict_cwt)
|
||||
|
||||
families = ('gaus', 'mexh', 'morl', 'cgau', 'shan', 'fbsp', 'cmor')
|
||||
wavelets = sum([pywt.wavelist(name) for name in families], [])
|
||||
|
||||
|
||||
def _get_data_sizes(w):
|
||||
""" Return the sizes to test for wavelet w. """
|
||||
if size_set == 'full':
|
||||
data_sizes = list(range(100, 101)) + \
|
||||
[100, 200, 500, 1000, 50000]
|
||||
else:
|
||||
data_sizes = (1000, 1000 + 1)
|
||||
return data_sizes
|
||||
|
||||
|
||||
def _get_scales(w):
|
||||
""" Return the scales to test for wavelet w. """
|
||||
if size_set == 'full':
|
||||
scales = (1, np.arange(1, 3), np.arange(1, 4), np.arange(1, 5))
|
||||
else:
|
||||
scales = (1, np.arange(1, 3))
|
||||
return scales
|
||||
|
||||
|
||||
@uses_pymatbridge # skip this case if precomputed results are used instead
|
||||
@pytest.mark.slow
|
||||
def test_accuracy_pymatbridge_cwt():
|
||||
Matlab = pytest.importorskip("pymatbridge.Matlab")
|
||||
mlab = Matlab()
|
||||
rstate = np.random.RandomState(1234)
|
||||
    # max RMSE (was 1.0e-10, is reduced to 5.0e-5 due to different coefficients)
|
||||
epsilon = 1e-15
|
||||
epsilon_psi = 1e-15
|
||||
mlab.start()
|
||||
try:
|
||||
for wavelet in wavelets:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', FutureWarning)
|
||||
w = pywt.ContinuousWavelet(wavelet)
|
||||
            if wavelet in ('shan', 'cmor'):
|
||||
mlab.set_variable('wavelet', wavelet+str(w.bandwidth_frequency)+'-'+str(w.center_frequency))
|
||||
elif wavelet == 'fbsp':
|
||||
mlab.set_variable('wavelet', wavelet+str(w.fbsp_order)+'-'+str(w.bandwidth_frequency)+'-'+str(w.center_frequency))
|
||||
else:
|
||||
mlab.set_variable('wavelet', wavelet)
|
||||
mlab_code = ("psi = wavefun(wavelet,10)")
|
||||
res = mlab.run_code(mlab_code)
|
||||
psi = np.asarray(mlab.get_variable('psi'))
|
||||
_check_accuracy_psi(w, psi, wavelet, epsilon_psi)
|
||||
for N in _get_data_sizes(w):
|
||||
data = rstate.randn(N)
|
||||
mlab.set_variable('data', data)
|
||||
for scales in _get_scales(w):
|
||||
coefs = _compute_matlab_result(data, wavelet, scales, mlab)
|
||||
_check_accuracy(data, w, scales, coefs, wavelet, epsilon)
|
||||
|
||||
finally:
|
||||
mlab.stop()
|
||||
|
||||
|
||||
@uses_precomputed # skip this case if pymatbridge + Matlab are being used
|
||||
@pytest.mark.slow
|
||||
def test_accuracy_precomputed_cwt():
|
||||
# Keep this specific random seed to match the precomputed Matlab result.
|
||||
rstate = np.random.RandomState(1234)
|
||||
    # note: the tolerances below are loose and should eventually be tightened
|
||||
epsilon = 2e-15
|
||||
epsilon32 = 1e-5
|
||||
epsilon_psi = 1e-15
|
||||
for wavelet in wavelets:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', FutureWarning)
|
||||
w = pywt.ContinuousWavelet(wavelet)
|
||||
w32 = pywt.ContinuousWavelet(wavelet,dtype=np.float32)
|
||||
psi = _load_matlab_result_psi(wavelet)
|
||||
_check_accuracy_psi(w, psi, wavelet, epsilon_psi)
|
||||
|
||||
for N in _get_data_sizes(w):
|
||||
data = rstate.randn(N)
|
||||
data32 = data.astype(np.float32)
|
||||
scales_count = 0
|
||||
for scales in _get_scales(w):
|
||||
scales_count += 1
|
||||
coefs = _load_matlab_result(data, wavelet, scales_count)
|
||||
_check_accuracy(data, w, scales, coefs, wavelet, epsilon)
|
||||
_check_accuracy(data32, w32, scales, coefs, wavelet, epsilon32)
|
||||
|
||||
|
||||
def _compute_matlab_result(data, wavelet, scales, mlab):
|
||||
""" Compute the result using MATLAB.
|
||||
|
||||
This function assumes that the Matlab variables `wavelet` and `data` have
|
||||
already been set externally.
|
||||
"""
|
||||
mlab.set_variable('scales', scales)
|
||||
mlab_code = ("coefs = cwt(data, scales, wavelet)")
|
||||
res = mlab.run_code(mlab_code)
|
||||
if not res['success']:
|
||||
raise RuntimeError("Matlab failed to execute the provided code. "
|
||||
"Check that the wavelet toolbox is installed.")
|
||||
# need np.asarray because sometimes the output is a single float64
|
||||
coefs = np.asarray(mlab.get_variable('coefs'))
|
||||
return coefs
|
||||
|
||||
|
||||
def _load_matlab_result(data, wavelet, scales):
|
||||
""" Load the precomputed result.
|
||||
"""
|
||||
N = len(data)
|
||||
coefs_key = '_'.join([str(scales), wavelet, str(N), 'coefs'])
|
||||
if (coefs_key not in matlab_result_dict_cwt):
|
||||
raise KeyError(
|
||||
"Precompted Matlab result not found for wavelet: "
|
||||
"{0}, mode: {1}, size: {2}".format(wavelet, scales, N))
|
||||
coefs = matlab_result_dict_cwt[coefs_key]
|
||||
return coefs
|
||||
|
||||
|
||||
def _load_matlab_result_psi(wavelet):
|
||||
""" Load the precomputed result.
|
||||
"""
|
||||
psi_key = '_'.join([wavelet, 'psi'])
|
||||
if (psi_key not in matlab_result_dict_cwt):
|
||||
raise KeyError(
|
||||
"Precompted Matlab psi result not found for wavelet: "
|
||||
"{0}}".format(wavelet))
|
||||
psi = matlab_result_dict_cwt[psi_key]
|
||||
return psi
|
||||
|
||||
|
||||
def _check_accuracy(data, w, scales, coefs, wavelet, epsilon):
|
||||
# PyWavelets result
|
||||
coefs_pywt, freq = pywt.cwt(data, scales, w)
|
||||
|
||||
# coefs from Matlab are from R2012a which is missing the complex conjugate
|
||||
# as shown in Eq. 2 of Torrence and Compo. We take the complex conjugate of
|
||||
# the precomputed Matlab result to account for this.
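    # Schematically, the convention used here computes
    #     W(s, n) = sum_k data[k] * conj(psi((k - n) * dt / s))
    # so conjugating the R2012a coefficients brings both results in line.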
|
||||
coefs = np.conj(coefs)
|
||||
|
||||
# calculate error measures
|
||||
err = coefs_pywt - coefs
|
||||
rms = np.real(np.sqrt(np.mean(np.conj(err) * err)))
|
||||
|
||||
msg = ('[RMS > EPSILON] for Scale: %s, Wavelet: %s, '
|
||||
'Length: %d, rms=%.3g' % (scales, wavelet, len(data), rms))
|
||||
assert_(rms < epsilon, msg=msg)
|
||||
|
||||
|
||||
def _check_accuracy_psi(w, psi, wavelet, epsilon):
|
||||
# PyWavelets result
|
||||
psi_pywt, x = w.wavefun(length=1024)
|
||||
|
||||
# calculate error measures
|
||||
err = psi_pywt.flatten() - psi.flatten()
|
||||
rms = np.real(np.sqrt(np.mean(np.conj(err) * err)))
|
||||
|
||||
msg = ('[RMS > EPSILON] for Wavelet: %s, '
|
||||
'rms=%.3g' % (wavelet, rms))
|
||||
assert_(rms < epsilon, msg=msg)
|
109
venv/Lib/site-packages/pywt/tests/test_modes.py
Normal file
@@ -0,0 +1,109 @@
#!/usr/bin/env python
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import assert_raises, assert_equal, assert_allclose
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_available_modes():
|
||||
modes = ['zero', 'constant', 'symmetric', 'periodic', 'smooth',
|
||||
'periodization', 'reflect', 'antisymmetric', 'antireflect']
|
||||
assert_equal(pywt.Modes.modes, modes)
|
||||
assert_equal(pywt.Modes.from_object('constant'), 2)
|
||||
|
||||
|
||||
def test_invalid_modes():
|
||||
x = np.arange(4)
|
||||
assert_raises(ValueError, pywt.dwt, x, 'db2', 'unknown')
|
||||
assert_raises(ValueError, pywt.dwt, x, 'db2', -1)
|
||||
assert_raises(ValueError, pywt.dwt, x, 'db2', 9)
|
||||
assert_raises(TypeError, pywt.dwt, x, 'db2', None)
|
||||
|
||||
assert_raises(ValueError, pywt.Modes.from_object, 'unknown')
|
||||
assert_raises(ValueError, pywt.Modes.from_object, -1)
|
||||
assert_raises(ValueError, pywt.Modes.from_object, 9)
|
||||
assert_raises(TypeError, pywt.Modes.from_object, None)
|
||||
|
||||
|
||||
def test_dwt_idwt_allmodes():
|
||||
# Test that :func:`dwt` and :func:`idwt` can be performed using every mode
|
||||
x = [1, 2, 1, 5, -1, 8, 4, 6]
|
||||
dwt_results = {
|
||||
'zero': ([-0.03467518, 1.73309178, 3.40612438, 6.32928585, 6.95094948],
|
||||
[-0.12940952, -2.15599552, -5.95034847, -1.21545369,
|
||||
-1.8625013]),
|
||||
'constant': ([1.28480404, 1.73309178, 3.40612438, 6.32928585,
|
||||
7.51935555],
|
||||
[-0.48296291, -2.15599552, -5.95034847, -1.21545369,
|
||||
0.25881905]),
|
||||
'symmetric': ([1.76776695, 1.73309178, 3.40612438, 6.32928585,
|
||||
7.77817459],
|
||||
[-0.61237244, -2.15599552, -5.95034847, -1.21545369,
|
||||
1.22474487]),
|
||||
'reflect': ([2.12132034, 1.73309178, 3.40612438, 6.32928585,
|
||||
6.81224877],
|
||||
[-0.70710678, -2.15599552, -5.95034847, -1.21545369,
|
||||
-2.38013939]),
|
||||
'periodic': ([6.9162743, 1.73309178, 3.40612438, 6.32928585,
|
||||
6.9162743],
|
||||
[-1.99191082, -2.15599552, -5.95034847, -1.21545369,
|
||||
-1.99191082]),
|
||||
'smooth': ([-0.51763809, 1.73309178, 3.40612438, 6.32928585,
|
||||
7.45000519],
|
||||
[0, -2.15599552, -5.95034847, -1.21545369, 0]),
|
||||
'periodization': ([4.053172, 3.05257099, 2.85381112, 8.42522221],
|
||||
[0.18946869, 4.18258152, 4.33737503, 2.60428326]),
|
||||
'antisymmetric': ([-1.83711731, 1.73309178, 3.40612438, 6.32928585,
|
||||
6.12372436],
|
||||
[0.353553391, -2.15599552, -5.95034847, -1.21545369,
|
||||
-4.94974747]),
|
||||
'antireflect': ([0.44828774, 1.73309178, 3.40612438, 6.32928585,
|
||||
8.22646233],
|
||||
[-0.25881905, -2.15599552, -5.95034847, -1.21545369,
|
||||
2.89777748])
|
||||
}
|
||||
|
||||
for mode in pywt.Modes.modes:
|
||||
cA, cD = pywt.dwt(x, 'db2', mode)
|
||||
assert_allclose(cA, dwt_results[mode][0], rtol=1e-7, atol=1e-8)
|
||||
assert_allclose(cD, dwt_results[mode][1], rtol=1e-7, atol=1e-8)
|
||||
assert_allclose(pywt.idwt(cA, cD, 'db2', mode), x, rtol=1e-10)
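        # note that only the border coefficients differ between modes; the
        # three interior values in dwt_results are identical for every mode.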
|
||||
|
||||
|
||||
def test_dwt_short_input_allmodes():
|
||||
# some test cases where the input is shorter than the DWT filter
|
||||
x = [1, 3, 2]
|
||||
wavelet = 'db2'
|
||||
# manually pad each end by the filter size (4 for 'db2' used here)
|
||||
padded_x = {'zero': [0, 0, 0, 0, 1, 3, 2, 0, 0, 0, 0],
|
||||
'constant': [1, 1, 1, 1, 1, 3, 2, 2, 2, 2, 2],
|
||||
'symmetric': [2, 2, 3, 1, 1, 3, 2, 2, 3, 1, 1],
|
||||
'reflect': [1, 3, 2, 3, 1, 3, 2, 3, 1, 3, 2],
|
||||
'periodic': [2, 1, 3, 2, 1, 3, 2, 1, 3, 2, 1],
|
||||
'smooth': [-7, -5, -3, -1, 1, 3, 2, 1, 0, -1, -2],
|
||||
'antisymmetric': [2, -2, -3, -1, 1, 3, 2, -2, -3, -1, 1],
|
||||
'antireflect': [1, -1, 0, -1, 1, 3, 2, 1, 3, 5, 4],
|
||||
}
|
||||
for mode, xpad in padded_x.items():
|
||||
        # DWT of the manually padded array. The edges will be discarded
        # below, so the symmetric mode used here does not matter.
|
||||
cApad, cDpad = pywt.dwt(xpad, wavelet, mode='symmetric')
|
||||
|
||||
        # central region of the padded output (unaffected by mode)
|
||||
expected_result = (cApad[2:-2], cDpad[2:-2])
|
||||
|
||||
cA, cD = pywt.dwt(x, wavelet, mode)
|
||||
assert_allclose(cA, expected_result[0], rtol=1e-7, atol=1e-8)
|
||||
assert_allclose(cD, expected_result[1], rtol=1e-7, atol=1e-8)
|
||||
|
||||
|
||||
def test_default_mode():
|
||||
# The default mode should be 'symmetric'
|
||||
x = [1, 2, 1, 5, -1, 8, 4, 6]
|
||||
cA, cD = pywt.dwt(x, 'db2')
|
||||
cA2, cD2 = pywt.dwt(x, 'db2', mode='symmetric')
|
||||
assert_allclose(cA, cA2)
|
||||
assert_allclose(cD, cD2)
|
||||
assert_allclose(pywt.idwt(cA, cD, 'db2'), x)
|
443
venv/Lib/site-packages/pywt/tests/test_multidim.py
Normal file
@@ -0,0 +1,443 @@
#!/usr/bin/env python
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
from itertools import combinations
|
||||
from numpy.testing import assert_allclose, assert_, assert_raises, assert_equal
|
||||
|
||||
import pywt
|
||||
# Check that float32, float64, complex64, complex128 are preserved.
|
||||
# Other real types get converted to float64.
|
||||
# complex256 gets converted to complex128
|
||||
dtypes_in = [np.int8, np.float16, np.float32, np.float64, np.complex64,
|
||||
np.complex128]
|
||||
dtypes_out = [np.float64, np.float32, np.float32, np.float64, np.complex64,
|
||||
np.complex128]
|
||||
|
||||
# test complex256 as well if it is available
|
||||
try:
|
||||
dtypes_in += [np.complex256, ]
|
||||
dtypes_out += [np.complex128, ]
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
def test_dwtn_input():
|
||||
# Array-like must be accepted
|
||||
pywt.dwtn([1, 2, 3, 4], 'haar')
|
||||
# Others must not
|
||||
data = dict()
|
||||
assert_raises(TypeError, pywt.dwtn, data, 'haar')
|
||||
# Must be at least 1D
|
||||
assert_raises(ValueError, pywt.dwtn, 2, 'haar')
|
||||
|
||||
|
||||
def test_3D_reconstruct():
|
||||
data = np.array([
|
||||
[[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 26, 3, 2, 1],
|
||||
[5, 8, 2, 33, 4, 9],
|
||||
[2, 5, 19, 4, 19, 1]],
|
||||
[[1, 5, 1, 2, 3, 4],
|
||||
[7, 12, 6, 52, 7, 8],
|
||||
[2, 12, 3, 52, 6, 8],
|
||||
[5, 2, 6, 78, 12, 2]]])
|
||||
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for mode in pywt.Modes.modes:
|
||||
d = pywt.dwtn(data, wavelet, mode=mode)
|
||||
assert_allclose(data, pywt.idwtn(d, wavelet, mode=mode),
|
||||
rtol=1e-13, atol=1e-13)
|
||||
|
||||
|
||||
def test_dwdtn_idwtn_allwavelets():
|
||||
rstate = np.random.RandomState(1234)
|
||||
r = rstate.randn(16, 16)
|
||||
# test 2D case only for all wavelet types
|
||||
wavelist = pywt.wavelist()
|
||||
if 'dmey' in wavelist:
|
||||
wavelist.remove('dmey')
|
||||
for wavelet in wavelist:
|
||||
if wavelet in ['cmor', 'shan', 'fbsp']:
|
||||
# skip these CWT families to avoid warnings
|
||||
continue
|
||||
if isinstance(pywt.DiscreteContinuousWavelet(wavelet), pywt.Wavelet):
|
||||
for mode in pywt.Modes.modes:
|
||||
coeffs = pywt.dwtn(r, wavelet, mode=mode)
|
||||
assert_allclose(pywt.idwtn(coeffs, wavelet, mode=mode),
|
||||
r, rtol=1e-7, atol=1e-7)
|
||||
|
||||
|
||||
def test_stride():
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
for dtype in ('float32', 'float64'):
|
||||
data = np.array([[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]],
|
||||
dtype=dtype)
|
||||
|
||||
for mode in pywt.Modes.modes:
|
||||
expected = pywt.dwtn(data, wavelet)
|
||||
strided = np.ones((3, 12), dtype=data.dtype)
|
||||
strided[::-1, ::2] = data
|
||||
strided_dwtn = pywt.dwtn(strided[::-1, ::2], wavelet)
|
||||
for key in expected.keys():
|
||||
assert_allclose(strided_dwtn[key], expected[key])
|
||||
|
||||
|
||||
def test_byte_offset():
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for dtype in ('float32', 'float64'):
|
||||
data = np.array([[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]],
|
||||
dtype=dtype)
|
||||
|
||||
for mode in pywt.Modes.modes:
|
||||
expected = pywt.dwtn(data, wavelet)
|
||||
padded = np.ones((3, 6), dtype=np.dtype({'data': (data.dtype, 0),
|
||||
'pad': ('byte', data.dtype.itemsize)},
|
||||
align=True))
|
||||
padded[:] = data
|
||||
padded_dwtn = pywt.dwtn(padded['data'], wavelet)
|
||||
for key in expected.keys():
|
||||
assert_allclose(padded_dwtn[key], expected[key])
|
||||
|
||||
|
||||
def test_3D_reconstruct_complex():
|
||||
# All dimensions even length so `take` does not need to be specified
|
||||
data = np.array([
|
||||
[[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 26, 3, 2, 1],
|
||||
[5, 8, 2, 33, 4, 9],
|
||||
[2, 5, 19, 4, 19, 1]],
|
||||
[[1, 5, 1, 2, 3, 4],
|
||||
[7, 12, 6, 52, 7, 8],
|
||||
[2, 12, 3, 52, 6, 8],
|
||||
[5, 2, 6, 78, 12, 2]]])
|
||||
data = data + 1j
|
||||
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
d = pywt.dwtn(data, wavelet)
|
||||
# idwtn creates even-length shapes (2x dwtn size)
|
||||
original_shape = tuple([slice(None, s) for s in data.shape])
|
||||
assert_allclose(data, pywt.idwtn(d, wavelet)[original_shape],
|
||||
rtol=1e-13, atol=1e-13)
|
||||
|
||||
|
||||
def test_idwtn_idwt2():
|
||||
data = np.array([
|
||||
[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]])
|
||||
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
LL, (HL, LH, HH) = pywt.dwt2(data, wavelet)
|
||||
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH}
|
||||
|
||||
for mode in pywt.Modes.modes:
|
||||
assert_allclose(pywt.idwt2((LL, (HL, LH, HH)), wavelet, mode=mode),
|
||||
pywt.idwtn(d, wavelet, mode=mode),
|
||||
rtol=1e-14, atol=1e-14)
|
||||
|
||||
|
||||
def test_idwtn_idwt2_complex():
|
||||
data = np.array([
|
||||
[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]])
|
||||
data = data + 1j
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
LL, (HL, LH, HH) = pywt.dwt2(data, wavelet)
|
||||
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH}
|
||||
|
||||
for mode in pywt.Modes.modes:
|
||||
assert_allclose(pywt.idwt2((LL, (HL, LH, HH)), wavelet, mode=mode),
|
||||
pywt.idwtn(d, wavelet, mode=mode),
|
||||
rtol=1e-14, atol=1e-14)
|
||||
|
||||
|
||||
def test_idwtn_missing():
|
||||
# Test to confirm missing data behave as zeroes
|
||||
data = np.array([
|
||||
[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]])
|
||||
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
coefs = pywt.dwtn(data, wavelet)
|
||||
|
||||
# No point removing zero, or all
|
||||
for num_missing in range(1, len(coefs)):
|
||||
for missing in combinations(coefs.keys(), num_missing):
|
||||
missing_coefs = coefs.copy()
|
||||
for key in missing:
|
||||
del missing_coefs[key]
|
||||
LL = missing_coefs.get('aa', None)
|
||||
HL = missing_coefs.get('da', None)
|
||||
LH = missing_coefs.get('ad', None)
|
||||
HH = missing_coefs.get('dd', None)
|
||||
|
||||
assert_allclose(pywt.idwt2((LL, (HL, LH, HH)), wavelet),
|
||||
pywt.idwtn(missing_coefs, 'haar'), atol=1e-15)
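            # i.e. omitting a subband from the dict behaves like passing None
            # (or an all-zero array) for it, matching the idwt2 call above.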
|
||||
|
||||
|
||||
def test_idwtn_all_coeffs_None():
|
||||
coefs = dict(aa=None, da=None, ad=None, dd=None)
|
||||
assert_raises(ValueError, pywt.idwtn, coefs, 'haar')
|
||||
|
||||
|
||||
def test_error_on_invalid_keys():
|
||||
data = np.array([
|
||||
[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]])
|
||||
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
LL, (HL, LH, HH) = pywt.dwt2(data, wavelet)
|
||||
|
||||
# unexpected key
|
||||
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH, 'ff': LH}
|
||||
assert_raises(ValueError, pywt.idwtn, d, wavelet)
|
||||
|
||||
# mismatched key lengths
|
||||
d = {'a': LL, 'da': HL, 'ad': LH, 'dd': HH}
|
||||
assert_raises(ValueError, pywt.idwtn, d, wavelet)
|
||||
|
||||
|
||||
def test_error_mismatched_size():
|
||||
data = np.array([
|
||||
[0, 4, 1, 5, 1, 4],
|
||||
[0, 5, 6, 3, 2, 1],
|
||||
[2, 5, 19, 4, 19, 1]])
|
||||
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
LL, (HL, LH, HH) = pywt.dwt2(data, wavelet)
|
||||
|
||||
    # Pass/fail depends on the first element being shorter than the remaining
    # ones, so set 3 of the 4 coefficient arrays to an incorrect size to
    # maximize the chance of triggering the error. The order of dict items is
    # random, so this may not trigger on every run. The dict is rebuilt inside
    # idwtn, so using an OrderedDict here would not help.
|
||||
LL = LL[:, :-1]
|
||||
LH = LH[:, :-1]
|
||||
HH = HH[:, :-1]
|
||||
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH}
|
||||
|
||||
assert_raises(ValueError, pywt.idwtn, d, wavelet)
|
||||
|
||||
|
||||
def test_dwt2_idwt2_dtypes():
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
|
||||
x = np.ones((4, 4), dtype=dt_in)
|
||||
errmsg = "wrong dtype returned for {0} input".format(dt_in)
|
||||
|
||||
cA, (cH, cV, cD) = pywt.dwt2(x, wavelet)
|
||||
assert_(cA.dtype == cH.dtype == cV.dtype == cD.dtype,
|
||||
"dwt2: " + errmsg)
|
||||
|
||||
x_roundtrip = pywt.idwt2((cA, (cH, cV, cD)), wavelet)
|
||||
assert_(x_roundtrip.dtype == dt_out, "idwt2: " + errmsg)
|
||||
|
||||
|
||||
def test_dwtn_axes():
|
||||
data = np.array([[0, 1, 2, 3],
|
||||
[1, 1, 1, 1],
|
||||
[1, 4, 2, 8]])
|
||||
data = data + 1j*data # test with complex data
|
||||
coefs = pywt.dwtn(data, 'haar', axes=(1,))
|
||||
expected_a = list(map(lambda x: pywt.dwt(x, 'haar')[0], data))
|
||||
assert_equal(coefs['a'], expected_a)
|
||||
expected_d = list(map(lambda x: pywt.dwt(x, 'haar')[1], data))
|
||||
assert_equal(coefs['d'], expected_d)
|
||||
|
||||
coefs = pywt.dwtn(data, 'haar', axes=(1, 1))
|
||||
expected_aa = list(map(lambda x: pywt.dwt(x, 'haar')[0], expected_a))
|
||||
assert_equal(coefs['aa'], expected_aa)
|
||||
expected_ad = list(map(lambda x: pywt.dwt(x, 'haar')[1], expected_a))
|
||||
assert_equal(coefs['ad'], expected_ad)
|
||||
|
||||
|
||||
def test_idwtn_axes():
|
||||
data = np.array([[0, 1, 2, 3],
|
||||
[1, 1, 1, 1],
|
||||
[1, 4, 2, 8]])
|
||||
data = data + 1j*data # test with complex data
|
||||
coefs = pywt.dwtn(data, 'haar', axes=(1, 1))
|
||||
assert_allclose(pywt.idwtn(coefs, 'haar', axes=(1, 1)), data, atol=1e-14)
|
||||
|
||||
|
||||
def test_idwt2_none_coeffs():
|
||||
data = np.array([[0, 1, 2, 3],
|
||||
[1, 1, 1, 1],
|
||||
[1, 4, 2, 8]])
|
||||
data = data + 1j*data # test with complex data
|
||||
cA, (cH, cV, cD) = pywt.dwt2(data, 'haar', axes=(1, 1))
|
||||
|
||||
# verify setting coefficients to None is the same as zeroing them
|
||||
cD = np.zeros_like(cD)
|
||||
result_zeros = pywt.idwt2((cA, (cH, cV, cD)), 'haar', axes=(1, 1))
|
||||
|
||||
cD = None
|
||||
result_none = pywt.idwt2((cA, (cH, cV, cD)), 'haar', axes=(1, 1))
|
||||
|
||||
assert_equal(result_zeros, result_none)
|
||||
|
||||
|
||||
def test_idwtn_none_coeffs():
|
||||
data = np.array([[0, 1, 2, 3],
|
||||
[1, 1, 1, 1],
|
||||
[1, 4, 2, 8]])
|
||||
data = data + 1j*data # test with complex data
|
||||
coefs = pywt.dwtn(data, 'haar', axes=(1, 1))
|
||||
|
||||
# verify setting coefficients to None is the same as zeroing them
|
||||
coefs['dd'] = np.zeros_like(coefs['dd'])
|
||||
result_zeros = pywt.idwtn(coefs, 'haar', axes=(1, 1))
|
||||
|
||||
coefs['dd'] = None
|
||||
result_none = pywt.idwtn(coefs, 'haar', axes=(1, 1))
|
||||
|
||||
assert_equal(result_zeros, result_none)
|
||||
|
||||
|
||||
def test_idwt2_axes():
|
||||
data = np.array([[0, 1, 2, 3],
|
||||
[1, 1, 1, 1],
|
||||
[1, 4, 2, 8]])
|
||||
coefs = pywt.dwt2(data, 'haar', axes=(1, 1))
|
||||
assert_allclose(pywt.idwt2(coefs, 'haar', axes=(1, 1)), data, atol=1e-14)
|
||||
|
||||
# too many axes
|
||||
assert_raises(ValueError, pywt.idwt2, coefs, 'haar', axes=(0, 1, 1))
|
||||
|
||||
|
||||
def test_idwt2_axes_subsets():
|
||||
data = np.array(np.random.standard_normal((4, 4, 4)))
|
||||
# test all combinations of 2 out of 3 axes transformed
|
||||
for axes in combinations((0, 1, 2), 2):
|
||||
coefs = pywt.dwt2(data, 'haar', axes=axes)
|
||||
assert_allclose(pywt.idwt2(coefs, 'haar', axes=axes), data, atol=1e-14)
|
||||
|
||||
|
||||
def test_idwtn_axes_subsets():
|
||||
data = np.array(np.random.standard_normal((4, 4, 4, 4)))
|
||||
# test all combinations of 3 out of 4 axes transformed
|
||||
for axes in combinations((0, 1, 2, 3), 3):
|
||||
coefs = pywt.dwtn(data, 'haar', axes=axes)
|
||||
assert_allclose(pywt.idwtn(coefs, 'haar', axes=axes), data, atol=1e-14)
|
||||
|
||||
|
||||
def test_negative_axes():
|
||||
data = np.array([[0, 1, 2, 3],
|
||||
[1, 1, 1, 1],
|
||||
[1, 4, 2, 8]])
|
||||
coefs1 = pywt.dwtn(data, 'haar', axes=(1, 1))
|
||||
coefs2 = pywt.dwtn(data, 'haar', axes=(-1, -1))
|
||||
assert_equal(coefs1, coefs2)
|
||||
|
||||
rec1 = pywt.idwtn(coefs1, 'haar', axes=(1, 1))
|
||||
rec2 = pywt.idwtn(coefs1, 'haar', axes=(-1, -1))
|
||||
assert_equal(rec1, rec2)
|
||||
|
||||
|
||||
def test_dwtn_idwtn_dtypes():
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
|
||||
x = np.ones((4, 4), dtype=dt_in)
|
||||
errmsg = "wrong dtype returned for {0} input".format(dt_in)
|
||||
|
||||
coeffs = pywt.dwtn(x, wavelet)
|
||||
for k, v in coeffs.items():
|
||||
assert_(v.dtype == dt_out, "dwtn: " + errmsg)
|
||||
|
||||
x_roundtrip = pywt.idwtn(coeffs, wavelet)
|
||||
assert_(x_roundtrip.dtype == dt_out, "idwtn: " + errmsg)
|
||||
|
||||
|
||||
def test_idwtn_mixed_complex_dtype():
|
||||
rstate = np.random.RandomState(0)
|
||||
x = rstate.randn(8, 8, 8)
|
||||
x = x + 1j*x
|
||||
coeffs = pywt.dwtn(x, 'db2')
|
||||
|
||||
x_roundtrip = pywt.idwtn(coeffs, 'db2')
|
||||
assert_allclose(x_roundtrip, x, rtol=1e-10)
|
||||
|
||||
# mismatched dtypes OK
|
||||
coeffs['a' * x.ndim] = coeffs['a' * x.ndim].astype(np.complex64)
|
||||
x_roundtrip2 = pywt.idwtn(coeffs, 'db2')
|
||||
assert_allclose(x_roundtrip2, x, rtol=1e-7, atol=1e-7)
|
||||
assert_(x_roundtrip2.dtype == np.complex128)
|
||||
|
||||
|
||||
def test_idwt2_size_mismatch_error():
|
||||
LL = np.zeros((6, 6))
|
||||
LH = HL = HH = np.zeros((5, 5))
|
||||
|
||||
assert_raises(ValueError, pywt.idwt2, (LL, (LH, HL, HH)), wavelet='haar')
|
||||
|
||||
|
||||
def test_dwt2_dimension_error():
|
||||
data = np.ones(16)
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
|
||||
# wrong number of input dimensions
|
||||
assert_raises(ValueError, pywt.dwt2, data, wavelet)
|
||||
|
||||
# too many axes
|
||||
data2 = np.ones((8, 8))
|
||||
assert_raises(ValueError, pywt.dwt2, data2, wavelet, axes=(0, 1, 1))
|
||||
|
||||
|
||||
def test_per_axis_wavelets_and_modes():
|
||||
# tests separate wavelet and edge mode for each axis.
|
||||
rstate = np.random.RandomState(1234)
|
||||
data = rstate.randn(16, 16, 16)
|
||||
|
||||
# wavelet can be a string or wavelet object
|
||||
wavelets = (pywt.Wavelet('haar'), 'sym2', 'db4')
|
||||
|
||||
# mode can be a string or a Modes enum
|
||||
modes = ('symmetric', 'periodization',
|
||||
pywt._extensions._pywt.Modes.reflect)
|
||||
|
||||
coefs = pywt.dwtn(data, wavelets, modes)
|
||||
assert_allclose(pywt.idwtn(coefs, wavelets, modes), data, atol=1e-14)
|
||||
|
||||
coefs = pywt.dwtn(data, wavelets[:1], modes)
|
||||
assert_allclose(pywt.idwtn(coefs, wavelets[:1], modes), data, atol=1e-14)
|
||||
|
||||
coefs = pywt.dwtn(data, wavelets, modes[:1])
|
||||
assert_allclose(pywt.idwtn(coefs, wavelets, modes[:1]), data, atol=1e-14)
|
||||
|
||||
# length of wavelets or modes doesn't match the length of axes
|
||||
assert_raises(ValueError, pywt.dwtn, data, wavelets[:2])
|
||||
assert_raises(ValueError, pywt.dwtn, data, wavelets, mode=modes[:2])
|
||||
assert_raises(ValueError, pywt.idwtn, coefs, wavelets[:2])
|
||||
assert_raises(ValueError, pywt.idwtn, coefs, wavelets, mode=modes[:2])
|
||||
|
||||
# dwt2/idwt2 also support per-axis wavelets/modes
|
||||
data2 = data[..., 0]
|
||||
coefs2 = pywt.dwt2(data2, wavelets[:2], modes[:2])
|
||||
assert_allclose(pywt.idwt2(coefs2, wavelets[:2], modes[:2]), data2,
|
||||
atol=1e-14)
|
||||
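The per-axis wavelet/mode mechanism tested above, shown on its own (an editor's illustrative sketch, not part of the diffed test file; it uses only the pywt.dwt2/idwt2 calls exercised in this test):

import numpy as np
import pywt

data = np.random.randn(8, 8)
# a different wavelet and boundary mode for each of the two transformed axes
wavelets = ('haar', 'db2')
modes = ('symmetric', 'periodization')
coeffs = pywt.dwt2(data, wavelets, modes)
rec = pywt.idwt2(coeffs, wavelets, modes)
assert np.allclose(rec, data)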
|
||||
|
||||
def test_error_on_continuous_wavelet():
|
||||
# A ValueError is raised if a Continuous wavelet is selected
|
||||
data = np.ones((16, 16))
|
||||
for dec_fun, rec_fun in zip([pywt.dwt2, pywt.dwtn],
|
||||
[pywt.idwt2, pywt.idwtn]):
|
||||
for cwave in ['morl', pywt.DiscreteContinuousWavelet('morl')]:
|
||||
assert_raises(ValueError, dec_fun, data, wavelet=cwave)
|
||||
|
||||
c = dec_fun(data, 'db1')
|
||||
assert_raises(ValueError, rec_fun, c, wavelet=cwave)
|
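A minimal sketch of the axes-aware dwtn/idwtn round trip that the tests above rely on (editor's addition, not part of the diffed file):

import numpy as np
import pywt

data = np.arange(12.0).reshape(3, 4)
coeffs = pywt.dwtn(data, 'haar', axes=(-1,))   # transform only the last axis
rec = pywt.idwtn(coeffs, 'haar', axes=(-1,))   # invert along the same axis
assert np.allclose(rec, data)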
1033
venv/Lib/site-packages/pywt/tests/test_multilevel.py
Normal file
File diff suppressed because it is too large
|
@@ -0,0 +1,61 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
Verify DWT perfect reconstruction.
|
||||
"""
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import assert_
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_perfect_reconstruction():
|
||||
families = ('db', 'sym', 'coif', 'bior', 'rbio')
|
||||
wavelets = sum([pywt.wavelist(name) for name in families], [])
|
||||
# list of mode names in pywt and matlab
|
||||
modes = [('zero', 'zpd'),
|
||||
('constant', 'sp0'),
|
||||
('symmetric', 'sym'),
|
||||
('periodic', 'ppd'),
|
||||
('smooth', 'sp1'),
|
||||
('periodization', 'per')]
|
||||
|
||||
dtypes = (np.float32, np.float64)
|
||||
|
||||
for wavelet in wavelets:
|
||||
for pmode, mmode in modes:
|
||||
for dt in dtypes:
|
||||
check_reconstruction(pmode, mmode, wavelet, dt)
|
||||
|
||||
|
||||
def check_reconstruction(pmode, mmode, wavelet, dtype):
|
||||
data_size = list(range(2, 40)) + [100, 200, 500, 1000, 2000, 10000,
|
||||
50000, 100000]
|
||||
np.random.seed(12345)
|
||||
# TODO: smoke testing - more failures for different seeds
|
||||
|
||||
if dtype == np.float32:
|
||||
# was 3e-7; had to be lowered because db21, db29, db33, db35, coif14 and coif16 were failing
|
||||
epsilon = 6e-7
|
||||
else:
|
||||
epsilon = 5e-11
|
||||
|
||||
for N in data_size:
|
||||
data = np.asarray(np.random.random(N), dtype)
|
||||
|
||||
# compute dwt coefficients
|
||||
pa, pd = pywt.dwt(data, wavelet, pmode)
|
||||
|
||||
# compute reconstruction
|
||||
rec = pywt.idwt(pa, pd, wavelet, pmode)
|
||||
|
||||
if len(data) % 2:
|
||||
rec = rec[:len(data)]
|
||||
|
||||
rms_rec = np.sqrt(np.mean((data-rec)**2))
|
||||
msg = ('[RMS_REC > EPSILON] for Mode: %s, Wavelet: %s, '
|
||||
'Length: %d, rms=%.3g' % (pmode, wavelet, len(data), rms_rec))
|
||||
assert_(rms_rec < epsilon, msg=msg)
|
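The single-level version of the reconstruction check above (an illustrative sketch, not part of the test file; the 5e-11 bound mirrors the double-precision epsilon used above):

import numpy as np
import pywt

data = np.random.random(1000)                  # float64 input
cA, cD = pywt.dwt(data, 'db4', 'symmetric')    # one level of decomposition
rec = pywt.idwt(cA, cD, 'db4', 'symmetric')    # and its inverse
assert np.sqrt(np.mean((data - rec) ** 2)) < 5e-11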
633
venv/Lib/site-packages/pywt/tests/test_swt.py
Normal file
|
@@ -0,0 +1,633 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from itertools import combinations, permutations
|
||||
import numpy as np
|
||||
import pytest
|
||||
from numpy.testing import (assert_allclose, assert_, assert_equal,
|
||||
assert_raises, assert_array_equal, assert_warns)
|
||||
|
||||
import pywt
|
||||
from pywt._extensions._swt import swt_axis
|
||||
|
||||
# Check that float32 and complex64 are preserved. Other real types get
|
||||
# converted to float64.
|
||||
dtypes_in = [np.int8, np.float16, np.float32, np.float64, np.complex64,
|
||||
np.complex128]
|
||||
dtypes_out = [np.float64, np.float32, np.float32, np.float64, np.complex64,
|
||||
np.complex128]
|
||||
|
||||
# tolerances used in accuracy comparisons
|
||||
tol_single = 1e-6
|
||||
tol_double = 1e-13
|
||||
|
||||
####
|
||||
# 1d multilevel swt tests
|
||||
####
|
||||
|
||||
|
||||
def test_swt_decomposition():
|
||||
x = [3, 7, 1, 3, -2, 6, 4, 6]
|
||||
db1 = pywt.Wavelet('db1')
|
||||
atol = tol_double
|
||||
(cA3, cD3), (cA2, cD2), (cA1, cD1) = pywt.swt(x, db1, level=3)
|
||||
expected_cA1 = [7.07106781, 5.65685425, 2.82842712, 0.70710678,
|
||||
2.82842712, 7.07106781, 7.07106781, 6.36396103]
|
||||
assert_allclose(cA1, expected_cA1, rtol=1e-8, atol=atol)
|
||||
expected_cD1 = [-2.82842712, 4.24264069, -1.41421356, 3.53553391,
|
||||
-5.65685425, 1.41421356, -1.41421356, 2.12132034]
|
||||
assert_allclose(cD1, expected_cD1, rtol=1e-8, atol=atol)
|
||||
expected_cA2 = [7, 4.5, 4, 5.5, 7, 9.5, 10, 8.5]
|
||||
assert_allclose(cA2, expected_cA2, rtol=tol_double, atol=atol)
|
||||
expected_cD2 = [3, 3.5, 0, -4.5, -3, 0.5, 0, 0.5]
|
||||
assert_allclose(cD2, expected_cD2, rtol=tol_double, atol=atol)
|
||||
expected_cA3 = [9.89949494, ] * 8
|
||||
assert_allclose(cA3, expected_cA3, rtol=1e-8, atol=atol)
|
||||
expected_cD3 = [0.00000000, -3.53553391, -4.24264069, -2.12132034,
|
||||
0.00000000, 3.53553391, 4.24264069, 2.12132034]
|
||||
assert_allclose(cD3, expected_cD3, rtol=1e-8, atol=atol)
|
||||
|
||||
# level=1, start_level=1 decomposition should match level=2
|
||||
res = pywt.swt(cA1, db1, level=1, start_level=1)
|
||||
cA2, cD2 = res[0]
|
||||
assert_allclose(cA2, expected_cA2, rtol=tol_double, atol=atol)
|
||||
assert_allclose(cD2, expected_cD2, rtol=tol_double, atol=atol)
|
||||
|
||||
coeffs = pywt.swt(x, db1)
|
||||
assert_(len(coeffs) == 3)
|
||||
assert_equal(pywt.swt_max_level(len(x)), 3)
|
||||
|
||||
|
||||
def test_swt_max_level():
|
||||
# an odd-sized signal warns that no levels of decomposition are possible
|
||||
assert_warns(UserWarning, pywt.swt_max_level, 11)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', UserWarning)
|
||||
assert_equal(pywt.swt_max_level(11), 0)
|
||||
|
||||
# no warnings when >= 1 level of decomposition possible
|
||||
assert_equal(pywt.swt_max_level(2), 1) # divisible by 2**1
|
||||
assert_equal(pywt.swt_max_level(4*3), 2) # divisible by 2**2
|
||||
assert_equal(pywt.swt_max_level(16), 4) # divisible by 2**4
|
||||
assert_equal(pywt.swt_max_level(16*3), 4) # divisible by 2**4
|
||||
|
||||
|
||||
def test_swt_axis():
|
||||
x = [3, 7, 1, 3, -2, 6, 4, 6]
|
||||
|
||||
db1 = pywt.Wavelet('db1')
|
||||
(cA2, cD2), (cA1, cD1) = pywt.swt(x, db1, level=2)
|
||||
|
||||
# test cases use 2D arrays based on tiling x along an axis and then
|
||||
# calling swt along the other axis.
|
||||
for order in ['C', 'F']:
|
||||
# test SWT of 2D data along default axis (-1)
|
||||
x_2d = np.asarray(x).reshape((1, -1))
|
||||
x_2d = np.concatenate((x_2d, )*5, axis=0)
|
||||
if order == 'C':
|
||||
x_2d = np.ascontiguousarray(x_2d)
|
||||
elif order == 'F':
|
||||
x_2d = np.asfortranarray(x_2d)
|
||||
(cA2_2d, cD2_2d), (cA1_2d, cD1_2d) = pywt.swt(x_2d, db1, level=2)
|
||||
|
||||
for c in [cA2_2d, cD2_2d, cA1_2d, cD1_2d]:
|
||||
assert_(c.shape == x_2d.shape)
|
||||
# each row should match the 1D result
|
||||
for row in cA1_2d:
|
||||
assert_array_equal(row, cA1)
|
||||
for row in cA2_2d:
|
||||
assert_array_equal(row, cA2)
|
||||
for row in cD1_2d:
|
||||
assert_array_equal(row, cD1)
|
||||
for row in cD2_2d:
|
||||
assert_array_equal(row, cD2)
|
||||
|
||||
# test SWT of 2D data along other axis (0)
|
||||
x_2d = np.asarray(x).reshape((-1, 1))
|
||||
x_2d = np.concatenate((x_2d, )*5, axis=1)
|
||||
if order == 'C':
|
||||
x_2d = np.ascontiguousarray(x_2d)
|
||||
elif order == 'F':
|
||||
x_2d = np.asfortranarray(x_2d)
|
||||
(cA2_2d, cD2_2d), (cA1_2d, cD1_2d) = pywt.swt(x_2d, db1, level=2,
|
||||
axis=0)
|
||||
|
||||
for c in [cA2_2d, cD2_2d, cA1_2d, cD1_2d]:
|
||||
assert_(c.shape == x_2d.shape)
|
||||
# each column should match the 1D result
|
||||
for row in cA1_2d.transpose((1, 0)):
|
||||
assert_array_equal(row, cA1)
|
||||
for row in cA2_2d.transpose((1, 0)):
|
||||
assert_array_equal(row, cA2)
|
||||
for row in cD1_2d.transpose((1, 0)):
|
||||
assert_array_equal(row, cD1)
|
||||
for row in cD2_2d.transpose((1, 0)):
|
||||
assert_array_equal(row, cD2)
|
||||
|
||||
# axis too large
|
||||
assert_raises(ValueError, pywt.swt, x, db1, level=2, axis=5)
|
||||
|
||||
|
||||
def test_swt_iswt_integration():
|
||||
# This function performs a round-trip swt/iswt transform test on
|
||||
# all available types of wavelets in PyWavelets - except the
|
||||
# 'dmey' wavelet. The latter has been excluded because it does not
|
||||
# produce very precise results. This is likely due to the fact
|
||||
# that the 'dmey' wavelet is a discrete approximation of a
|
||||
# continuous wavelet. All wavelets are tested up to 3 levels. The
|
||||
# test validates neither swt nor iswt as such, but it does ensure
|
||||
# that they are each other's inverse.
|
||||
|
||||
max_level = 3
|
||||
wavelets = pywt.wavelist(kind='discrete')
|
||||
if 'dmey' in wavelets:
|
||||
# The 'dmey' wavelet seems to be a bit special - disregard it for now
|
||||
wavelets.remove('dmey')
|
||||
for current_wavelet_str in wavelets:
|
||||
current_wavelet = pywt.Wavelet(current_wavelet_str)
|
||||
input_length_power = int(np.ceil(np.log2(max(
|
||||
current_wavelet.dec_len,
|
||||
current_wavelet.rec_len))))
|
||||
input_length = 2**(input_length_power + max_level - 1)
|
||||
X = np.arange(input_length)
|
||||
for norm in [True, False]:
|
||||
if norm and not current_wavelet.orthogonal:
|
||||
# skip non-orthogonal wavelets to avoid warnings when norm=True
|
||||
continue
|
||||
for trim_approx in [True, False]:
|
||||
coeffs = pywt.swt(X, current_wavelet, max_level,
|
||||
trim_approx=trim_approx, norm=norm)
|
||||
Y = pywt.iswt(coeffs, current_wavelet, norm=norm)
|
||||
assert_allclose(Y, X, rtol=1e-5, atol=1e-7)
|
||||
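The swt/iswt inverse relationship described in the comment above, reduced to its simplest form (editor's sketch, not part of the diffed test file):

import numpy as np
import pywt

x = np.arange(32.0)                   # length divisible by 2**level
coeffs = pywt.swt(x, 'db2', level=3)
y = pywt.iswt(coeffs, 'db2')
assert np.allclose(y, x)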
|
||||
|
||||
def test_swt_dtypes():
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
|
||||
errmsg = "wrong dtype returned for {0} input".format(dt_in)
|
||||
|
||||
# swt
|
||||
x = np.ones(8, dtype=dt_in)
|
||||
(cA2, cD2), (cA1, cD1) = pywt.swt(x, wavelet, level=2)
|
||||
assert_(cA2.dtype == cD2.dtype == cA1.dtype == cD1.dtype == dt_out,
|
||||
"swt: " + errmsg)
|
||||
|
||||
# swt2
|
||||
x = np.ones((8, 8), dtype=dt_in)
|
||||
cA, (cH, cV, cD) = pywt.swt2(x, wavelet, level=1)[0]
|
||||
assert_(cA.dtype == cH.dtype == cV.dtype == cD.dtype == dt_out,
|
||||
"swt2: " + errmsg)
|
||||
|
||||
|
||||
def test_swt_roundtrip_dtypes():
|
||||
# verify perfect reconstruction for all dtypes
|
||||
rstate = np.random.RandomState(5)
|
||||
wavelet = pywt.Wavelet('haar')
|
||||
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
|
||||
# swt, iswt
|
||||
x = rstate.standard_normal((8, )).astype(dt_in)
|
||||
c = pywt.swt(x, wavelet, level=2)
|
||||
xr = pywt.iswt(c, wavelet)
|
||||
assert_allclose(x, xr, rtol=1e-6, atol=1e-7)
|
||||
|
||||
# swt2, iswt2
|
||||
x = rstate.standard_normal((8, 8)).astype(dt_in)
|
||||
c = pywt.swt2(x, wavelet, level=2)
|
||||
xr = pywt.iswt2(c, wavelet)
|
||||
assert_allclose(x, xr, rtol=1e-6, atol=1e-7)
|
||||
|
||||
|
||||
def test_swt_default_level_by_axis():
|
||||
# make sure default number of levels matches the max level along the axis
|
||||
wav = 'db2'
|
||||
x = np.ones((2**3, 2**4, 2**5))
|
||||
for axis in (0, 1, 2):
|
||||
sdec = pywt.swt(x, wav, level=None, start_level=0, axis=axis)
|
||||
assert_equal(len(sdec), pywt.swt_max_level(x.shape[axis]))
|
||||
|
||||
|
||||
def test_swt2_ndim_error():
|
||||
x = np.ones(8)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', FutureWarning)
|
||||
assert_raises(ValueError, pywt.swt2, x, 'haar', level=1)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_swt2_iswt2_integration(wavelets=None):
|
||||
# This function performs a round-trip swt2/iswt2 transform test on
|
||||
# all available types of wavelets in PyWavelets - except the
|
||||
# 'dmey' wavelet. The latter has been excluded because it does not
|
||||
# produce very precise results. This is likely due to the fact
|
||||
# that the 'dmey' wavelet is a discrete approximation of a
|
||||
# continuous wavelet. All wavelets are tested up to 3 levels. The
|
||||
# test validates neither swt2 nor iswt2 as such, but it does ensure
|
||||
# that they are each other's inverse.
|
||||
|
||||
max_level = 3
|
||||
if wavelets is None:
|
||||
wavelets = pywt.wavelist(kind='discrete')
|
||||
if 'dmey' in wavelets:
|
||||
# The 'dmey' wavelet is a special case - disregard it for now
|
||||
wavelets.remove('dmey')
|
||||
for current_wavelet_str in wavelets:
|
||||
current_wavelet = pywt.Wavelet(current_wavelet_str)
|
||||
input_length_power = int(np.ceil(np.log2(max(
|
||||
current_wavelet.dec_len,
|
||||
current_wavelet.rec_len))))
|
||||
input_length = 2**(input_length_power + max_level - 1)
|
||||
X = np.arange(input_length**2).reshape(input_length, input_length)
|
||||
|
||||
for norm in [True, False]:
|
||||
if norm and not current_wavelet.orthogonal:
|
||||
# skip non-orthogonal wavelets to avoid warnings when norm=True
|
||||
continue
|
||||
for trim_approx in [True, False]:
|
||||
coeffs = pywt.swt2(X, current_wavelet, max_level,
|
||||
trim_approx=trim_approx, norm=norm)
|
||||
Y = pywt.iswt2(coeffs, current_wavelet, norm=norm)
|
||||
assert_allclose(Y, X, rtol=1e-5, atol=1e-5)
|
||||
|
||||
|
||||
def test_swt2_iswt2_quick():
|
||||
test_swt2_iswt2_integration(wavelets=['db1', ])
|
||||
|
||||
|
||||
def test_swt2_iswt2_non_square(wavelets=None):
|
||||
for nrows in [8, 16, 48]:
|
||||
X = np.arange(nrows*32).reshape(nrows, 32)
|
||||
current_wavelet = 'db1'
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', FutureWarning)
|
||||
coeffs = pywt.swt2(X, current_wavelet, level=2)
|
||||
Y = pywt.iswt2(coeffs, current_wavelet)
|
||||
assert_allclose(Y, X, rtol=tol_single, atol=tol_single)
|
||||
|
||||
|
||||
def test_swt2_axes():
|
||||
atol = 1e-14
|
||||
current_wavelet = pywt.Wavelet('db2')
|
||||
input_length_power = int(np.ceil(np.log2(max(
|
||||
current_wavelet.dec_len,
|
||||
current_wavelet.rec_len))))
|
||||
input_length = 2**(input_length_power)
|
||||
X = np.arange(input_length**2).reshape(input_length, input_length)
|
||||
|
||||
(cA1, (cH1, cV1, cD1)) = pywt.swt2(X, current_wavelet, level=1)[0]
|
||||
# opposite order
|
||||
(cA2, (cH2, cV2, cD2)) = pywt.swt2(X, current_wavelet, level=1,
|
||||
axes=(1, 0))[0]
|
||||
assert_allclose(cA1, cA2, atol=atol)
|
||||
assert_allclose(cH1, cV2, atol=atol)
|
||||
assert_allclose(cV1, cH2, atol=atol)
|
||||
assert_allclose(cD1, cD2, atol=atol)
|
||||
|
||||
# duplicate axes not allowed
|
||||
assert_raises(ValueError, pywt.swt2, X, current_wavelet, 1,
|
||||
axes=(0, 0))
|
||||
# too few axes
|
||||
assert_raises(ValueError, pywt.swt2, X, current_wavelet, 1, axes=(0, ))
|
||||
|
||||
|
||||
def test_iswt2_2d_only():
|
||||
# iswt2 is not currently compatible with data that is not 2D
|
||||
x_3d = np.ones((4, 4, 4))
|
||||
c = pywt.swt2(x_3d, 'haar', level=1)
|
||||
assert_raises(ValueError, pywt.iswt2, c, 'haar')
|
||||
|
||||
|
||||
def test_swtn_axes():
|
||||
atol = 1e-14
|
||||
current_wavelet = pywt.Wavelet('db2')
|
||||
input_length_power = int(np.ceil(np.log2(max(
|
||||
current_wavelet.dec_len,
|
||||
current_wavelet.rec_len))))
|
||||
input_length = 2**(input_length_power)
|
||||
X = np.arange(input_length**2).reshape(input_length, input_length)
|
||||
coeffs = pywt.swtn(X, current_wavelet, level=1, axes=None)[0]
|
||||
# opposite order
|
||||
coeffs2 = pywt.swtn(X, current_wavelet, level=1, axes=(1, 0))[0]
|
||||
assert_allclose(coeffs['aa'], coeffs2['aa'], atol=atol)
|
||||
assert_allclose(coeffs['ad'], coeffs2['da'], atol=atol)
|
||||
assert_allclose(coeffs['da'], coeffs2['ad'], atol=atol)
|
||||
assert_allclose(coeffs['dd'], coeffs2['dd'], atol=atol)
|
||||
|
||||
# 0-level transform
|
||||
empty = pywt.swtn(X, current_wavelet, level=0)
|
||||
assert_equal(empty, [])
|
||||
|
||||
# duplicate axes not allowed
|
||||
assert_raises(ValueError, pywt.swtn, X, current_wavelet, 1, axes=(0, 0))
|
||||
|
||||
# data.ndim = 0
|
||||
assert_raises(ValueError, pywt.swtn, np.asarray([]), current_wavelet, 1)
|
||||
|
||||
# start_level too large
|
||||
assert_raises(ValueError, pywt.swtn, X, current_wavelet,
|
||||
level=1, start_level=2)
|
||||
|
||||
# level < 1 in swt_axis call
|
||||
assert_raises(ValueError, swt_axis, X, current_wavelet, level=0,
|
||||
start_level=0)
|
||||
# odd-sized data not allowed
|
||||
assert_raises(ValueError, swt_axis, X[:-1, :], current_wavelet, level=0,
|
||||
start_level=0, axis=0)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_swtn_iswtn_integration(wavelets=None):
|
||||
# This function performs a round-trip swtn/iswtn transform for various
|
||||
# possible combinations of:
|
||||
# 1.) 1 out of 2 axes of a 2D array
|
||||
# 2.) 2 out of 3 axes of a 3D array
|
||||
#
|
||||
# To keep test time down, only wavelets of length <= 8 are run.
|
||||
#
|
||||
# This test does not validate swtn or iswtn individually, but only
|
||||
# confirms that iswtn yields an (almost) perfect reconstruction of swtn.
|
||||
max_level = 3
|
||||
if wavelets is None:
|
||||
wavelets = pywt.wavelist(kind='discrete')
|
||||
if 'dmey' in wavelets:
|
||||
# The 'dmey' wavelet is a special case - disregard it for now
|
||||
wavelets.remove('dmey')
|
||||
for ndim_transform in range(1, 3):
|
||||
ndim = ndim_transform + 1
|
||||
for axes in combinations(range(ndim), ndim_transform):
|
||||
for current_wavelet_str in wavelets:
|
||||
wav = pywt.Wavelet(current_wavelet_str)
|
||||
if wav.dec_len > 8:
|
||||
continue # avoid excessive test duration
|
||||
input_length_power = int(np.ceil(np.log2(max(
|
||||
wav.dec_len,
|
||||
wav.rec_len))))
|
||||
N = 2**(input_length_power + max_level - 1)
|
||||
X = np.arange(N**ndim).reshape((N, )*ndim)
|
||||
|
||||
for norm in [True, False]:
|
||||
if norm and not wav.orthogonal:
|
||||
# skip non-orthogonal wavelets to avoid warnings
|
||||
continue
|
||||
for trim_approx in [True, False]:
|
||||
coeffs = pywt.swtn(X, wav, max_level, axes=axes,
|
||||
trim_approx=trim_approx, norm=norm)
|
||||
coeffs_copy = deepcopy(coeffs)
|
||||
Y = pywt.iswtn(coeffs, wav, axes=axes, norm=norm)
|
||||
assert_allclose(Y, X, rtol=1e-5, atol=1e-5)
|
||||
|
||||
# verify the inverse transform didn't modify any coeffs
|
||||
for c, c2 in zip(coeffs, coeffs_copy):
|
||||
for k, v in c.items():
|
||||
assert_array_equal(c2[k], v)
|
||||
|
||||
|
||||
def test_swtn_iswtn_quick():
|
||||
test_swtn_iswtn_integration(wavelets=['db1', ])
|
||||
|
||||
|
||||
def test_iswtn_errors():
|
||||
x = np.arange(8**3).reshape(8, 8, 8)
|
||||
max_level = 2
|
||||
axes = (0, 1)
|
||||
w = pywt.Wavelet('db1')
|
||||
coeffs = pywt.swtn(x, w, max_level, axes=axes)
|
||||
|
||||
# more axes than dimensions transformed
|
||||
assert_raises(ValueError, pywt.iswtn, coeffs, w, axes=(0, 1, 2))
|
||||
# duplicate axes not allowed
|
||||
assert_raises(ValueError, pywt.iswtn, coeffs, w, axes=(0, 0))
|
||||
# mismatched coefficient size
|
||||
coeffs[0]['da'] = coeffs[0]['da'][:-1, :]
|
||||
assert_raises(RuntimeError, pywt.iswtn, coeffs, w, axes=axes)
|
||||
|
||||
|
||||
def test_swtn_iswtn_unique_shape_per_axis():
|
||||
# test case for gh-460
|
||||
_shape = (1, 48, 32) # unique shape per axis
|
||||
wav = 'sym2'
|
||||
max_level = 3
|
||||
rstate = np.random.RandomState(0)
|
||||
for shape in permutations(_shape):
|
||||
# transform only along the non-singleton axes
|
||||
axes = [ax for ax, s in enumerate(shape) if s != 1]
|
||||
x = rstate.standard_normal(shape)
|
||||
c = pywt.swtn(x, wav, max_level, axes=axes)
|
||||
r = pywt.iswtn(c, wav, axes=axes)
|
||||
assert_allclose(x, r, rtol=1e-10, atol=1e-10)
|
||||
|
||||
|
||||
def test_per_axis_wavelets():
|
||||
# tests separate wavelet for each axis.
|
||||
rstate = np.random.RandomState(1234)
|
||||
data = rstate.randn(16, 16, 16)
|
||||
level = 3
|
||||
|
||||
# wavelet can be a string or wavelet object
|
||||
wavelets = (pywt.Wavelet('haar'), 'sym2', 'db4')
|
||||
|
||||
coefs = pywt.swtn(data, wavelets, level=level)
|
||||
assert_allclose(pywt.iswtn(coefs, wavelets), data, atol=1e-14)
|
||||
|
||||
# 1-tuple also okay
|
||||
coefs = pywt.swtn(data, wavelets[:1], level=level)
|
||||
assert_allclose(pywt.iswtn(coefs, wavelets[:1]), data, atol=1e-14)
|
||||
|
||||
# length of wavelets doesn't match the length of axes
|
||||
assert_raises(ValueError, pywt.swtn, data, wavelets[:2], level)
|
||||
assert_raises(ValueError, pywt.iswtn, coefs, wavelets[:2])
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', FutureWarning)
|
||||
# swt2/iswt2 also support per-axis wavelets/modes
|
||||
data2 = data[..., 0]
|
||||
coefs2 = pywt.swt2(data2, wavelets[:2], level)
|
||||
assert_allclose(pywt.iswt2(coefs2, wavelets[:2]), data2, atol=1e-14)
|
||||
|
||||
|
||||
def test_error_on_continuous_wavelet():
|
||||
# A ValueError is raised if a Continuous wavelet is selected
|
||||
data = np.ones((16, 16))
|
||||
for dec_func, rec_func in zip([pywt.swt, pywt.swt2, pywt.swtn],
|
||||
[pywt.iswt, pywt.iswt2, pywt.iswtn]):
|
||||
for cwave in ['morl', pywt.DiscreteContinuousWavelet('morl')]:
|
||||
assert_raises(ValueError, dec_func, data, wavelet=cwave,
|
||||
level=3)
|
||||
|
||||
c = dec_func(data, 'db1', level=3)
|
||||
assert_raises(ValueError, rec_func, c, wavelet=cwave)
|
||||
|
||||
|
||||
def test_iswt_mixed_dtypes():
|
||||
# Mixed precision inputs give double precision output
|
||||
x_real = np.arange(16).astype(np.float64)
|
||||
x_complex = x_real + 1j*x_real
|
||||
wav = 'sym2'
|
||||
for dtype1, dtype2 in [(np.float64, np.float32),
|
||||
(np.float32, np.float64),
|
||||
(np.float16, np.float64),
|
||||
(np.complex128, np.complex64),
|
||||
(np.complex64, np.complex128)]:
|
||||
|
||||
if dtype1 in [np.complex64, np.complex128]:
|
||||
x = x_complex
|
||||
output_dtype = np.complex128
|
||||
else:
|
||||
x = x_real
|
||||
output_dtype = np.float64
|
||||
|
||||
coeffs = pywt.swt(x, wav, 2)
|
||||
# different precision for the approximation coefficients
|
||||
coeffs[0] = [coeffs[0][0].astype(dtype1),
|
||||
coeffs[0][1].astype(dtype2)]
|
||||
y = pywt.iswt(coeffs, wav)
|
||||
assert_equal(output_dtype, y.dtype)
|
||||
assert_allclose(y, x, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
||||
def test_iswt2_mixed_dtypes():
|
||||
# Mixed precision inputs give double precision output
|
||||
rstate = np.random.RandomState(0)
|
||||
x_real = rstate.randn(8, 8)
|
||||
x_complex = x_real + 1j*x_real
|
||||
wav = 'sym2'
|
||||
for dtype1, dtype2 in [(np.float64, np.float32),
|
||||
(np.float32, np.float64),
|
||||
(np.float16, np.float64),
|
||||
(np.complex128, np.complex64),
|
||||
(np.complex64, np.complex128)]:
|
||||
|
||||
if dtype1 in [np.complex64, np.complex128]:
|
||||
x = x_complex
|
||||
output_dtype = np.complex128
|
||||
else:
|
||||
x = x_real
|
||||
output_dtype = np.float64
|
||||
|
||||
coeffs = pywt.swt2(x, wav, 2)
|
||||
# different precision for the approximation coefficients
|
||||
coeffs[0] = [coeffs[0][0].astype(dtype1),
|
||||
tuple([c.astype(dtype2) for c in coeffs[0][1]])]
|
||||
y = pywt.iswt2(coeffs, wav)
|
||||
assert_equal(output_dtype, y.dtype)
|
||||
assert_allclose(y, x, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
||||
def test_iswtn_mixed_dtypes():
|
||||
# Mixed precision inputs give double precision output
|
||||
rstate = np.random.RandomState(0)
|
||||
x_real = rstate.randn(8, 8, 8)
|
||||
x_complex = x_real + 1j*x_real
|
||||
wav = 'sym2'
|
||||
for dtype1, dtype2 in [(np.float64, np.float32),
|
||||
(np.float32, np.float64),
|
||||
(np.float16, np.float64),
|
||||
(np.complex128, np.complex64),
|
||||
(np.complex64, np.complex128)]:
|
||||
|
||||
if dtype1 in [np.complex64, np.complex128]:
|
||||
x = x_complex
|
||||
output_dtype = np.complex128
|
||||
else:
|
||||
x = x_real
|
||||
output_dtype = np.float64
|
||||
|
||||
coeffs = pywt.swtn(x, wav, 2)
|
||||
# different precision for the approximation coefficients
|
||||
a = coeffs[0].pop('a' * x.ndim)
|
||||
a = a.astype(dtype1)
|
||||
coeffs[0] = {k: c.astype(dtype2) for k, c in coeffs[0].items()}
|
||||
coeffs[0]['a' * x.ndim] = a
|
||||
y = pywt.iswtn(coeffs, wav)
|
||||
assert_equal(output_dtype, y.dtype)
|
||||
assert_allclose(y, x, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
||||
def test_swt_zero_size_axes():
|
||||
# raise on empty input array
|
||||
assert_raises(ValueError, pywt.swt, [], 'db2')
|
||||
|
||||
# >1D case uses a different code path so check there as well
|
||||
x = np.ones((1, 4))[0:0, :] # 2D with a size zero axis
|
||||
assert_raises(ValueError, pywt.swtn, x, 'db2', level=1, axes=(0,))
|
||||
|
||||
|
||||
def test_swt_variance_and_energy_preservation():
|
||||
"""Verify that the 1D SWT partitions variance among the coefficients."""
|
||||
# When norm is True and the wavelet is orthogonal, the sum of the
|
||||
# variances of the coefficients should equal the variance of the signal.
|
||||
wav = 'db2'
|
||||
rstate = np.random.RandomState(5)
|
||||
x = rstate.randn(256)
|
||||
coeffs = pywt.swt(x, wav, trim_approx=True, norm=True)
|
||||
variances = [np.var(c) for c in coeffs]
|
||||
assert_allclose(np.sum(variances), np.var(x))
|
||||
|
||||
# also verify L2-norm energy preservation property
|
||||
assert_allclose(np.linalg.norm(x),
|
||||
np.linalg.norm(np.concatenate(coeffs)))
|
||||
|
||||
# non-orthogonal wavelet with norm=True raises a warning
|
||||
assert_warns(UserWarning, pywt.swt, x, 'bior2.2', norm=True)
|
||||
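The variance-partition property that this test asserts, evaluated directly (an illustrative sketch; it assumes an orthogonal wavelet and norm=True, as the comment above requires):

import numpy as np
import pywt

x = np.random.randn(256)
coeffs = pywt.swt(x, 'db2', trim_approx=True, norm=True)
print(np.var(x))                        # total signal variance ...
print(sum(np.var(c) for c in coeffs))   # ... is split among the coefficient arrays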
|
||||
|
||||
def test_swt2_variance_and_energy_preservation():
|
||||
"""Verify that the 2D SWT partitions variance among the coefficients."""
|
||||
# When norm is True and the wavelet is orthogonal, the sum of the
|
||||
# variances of the coefficients should equal the variance of the signal.
|
||||
wav = 'db2'
|
||||
rstate = np.random.RandomState(5)
|
||||
x = rstate.randn(64, 64)
|
||||
coeffs = pywt.swt2(x, wav, level=4, trim_approx=True, norm=True)
|
||||
coeff_list = [coeffs[0].ravel()]
|
||||
for d in coeffs[1:]:
|
||||
for v in d:
|
||||
coeff_list.append(v.ravel())
|
||||
variances = [np.var(v) for v in coeff_list]
|
||||
assert_allclose(np.sum(variances), np.var(x))
|
||||
|
||||
# also verify L2-norm energy preservation property
|
||||
assert_allclose(np.linalg.norm(x),
|
||||
np.linalg.norm(np.concatenate(coeff_list)))
|
||||
|
||||
# non-orthogonal wavelet with norm=True raises a warning
|
||||
assert_warns(UserWarning, pywt.swt2, x, 'bior2.2', level=4, norm=True)
|
||||
|
||||
|
||||
def test_swtn_variance_and_energy_preservation():
|
||||
"""Verify that the nD SWT partitions variance among the coefficients."""
|
||||
# When norm is True and the wavelet is orthogonal, the sum of the
|
||||
# variances of the coefficients should equal the variance of the signal.
|
||||
wav = 'db2'
|
||||
rstate = np.random.RandomState(5)
|
||||
x = rstate.randn(64, 64)
|
||||
coeffs = pywt.swtn(x, wav, level=4, trim_approx=True, norm=True)
|
||||
coeff_list = [coeffs[0].ravel()]
|
||||
for d in coeffs[1:]:
|
||||
for k, v in d.items():
|
||||
coeff_list.append(v.ravel())
|
||||
variances = [np.var(v) for v in coeff_list]
|
||||
assert_allclose(np.sum(variances), np.var(x))
|
||||
|
||||
# also verify L2-norm energy preservation property
|
||||
assert_allclose(np.linalg.norm(x),
|
||||
np.linalg.norm(np.concatenate(coeff_list)))
|
||||
|
||||
# non-orthogonal wavelet with norm=True raises a warning
|
||||
assert_warns(UserWarning, pywt.swtn, x, 'bior2.2', level=4, norm=True)
|
||||
|
||||
|
||||
def test_swt_ravel_and_unravel():
|
||||
# When trim_approx=True, all swt functions can use pywt.ravel_coeffs
|
||||
for ndim, _swt, _iswt, ravel_type in [
|
||||
(1, pywt.swt, pywt.iswt, 'swt'),
|
||||
(2, pywt.swt2, pywt.iswt2, 'swt2'),
|
||||
(3, pywt.swtn, pywt.iswtn, 'swtn')]:
|
||||
x = np.ones((16, ) * ndim)
|
||||
c = _swt(x, 'sym2', level=3, trim_approx=True)
|
||||
arr, slices, shapes = pywt.ravel_coeffs(c)
|
||||
c = pywt.unravel_coeffs(arr, slices, shapes, output_format=ravel_type)
|
||||
r = _iswt(c, 'sym2')
|
||||
assert_allclose(x, r)
|
169
venv/Lib/site-packages/pywt/tests/test_thresholding.py
Normal file
|
@@ -0,0 +1,169 @@
|
|||
from __future__ import division, print_function, absolute_import
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose, assert_raises, assert_, assert_equal
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
float_dtypes = [np.float32, np.float64, np.complex64, np.complex128]
|
||||
real_dtypes = [np.float32, np.float64]
|
||||
|
||||
|
||||
def _sign(x):
|
||||
# Matlab-like sign function (numpy uses a different convention).
|
||||
return x / np.abs(x)
|
||||
|
||||
|
||||
def _soft(x, thresh):
|
||||
"""soft thresholding supporting complex values.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This version is not robust to zeros in x.
|
||||
"""
|
||||
return _sign(x) * np.maximum(np.abs(x) - thresh, 0)
|
||||
|
||||
|
||||
def test_threshold():
|
||||
data = np.linspace(1, 4, 7)
|
||||
|
||||
# soft
|
||||
soft_result = [0., 0., 0., 0.5, 1., 1.5, 2.]
|
||||
assert_allclose(pywt.threshold(data, 2, 'soft'),
|
||||
np.array(soft_result), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold(-data, 2, 'soft'),
|
||||
-np.array(soft_result), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'soft'),
|
||||
[[0, 1]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'soft'),
|
||||
[[0, 0]] * 2, rtol=1e-12)
|
||||
|
||||
# soft thresholding complex values
|
||||
assert_allclose(pywt.threshold([[1j, 2j]] * 2, 1, 'soft'),
|
||||
[[0j, 1j]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1+1j, 2+2j]] * 2, 6, 'soft'),
|
||||
[[0, 0]] * 2, rtol=1e-12)
|
||||
complex_data = [[1+2j, 2+2j]]*2
|
||||
for thresh in [1, 2]:
|
||||
assert_allclose(pywt.threshold(complex_data, thresh, 'soft'),
|
||||
_soft(complex_data, thresh), rtol=1e-12)
|
||||
|
||||
# test soft thresholding with non-default substitute argument
|
||||
s = 5
|
||||
assert_allclose(pywt.threshold([[1j, 2]] * 2, 1.5, 'soft', substitute=s),
|
||||
[[s, 0.5]] * 2, rtol=1e-12)
|
||||
|
||||
# soft: no divide by zero warnings when input contains zeros
|
||||
assert_allclose(pywt.threshold(np.zeros(16), 2, 'soft'),
|
||||
np.zeros(16), rtol=1e-12)
|
||||
|
||||
# hard
|
||||
hard_result = [0., 0., 2., 2.5, 3., 3.5, 4.]
|
||||
assert_allclose(pywt.threshold(data, 2, 'hard'),
|
||||
np.array(hard_result), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold(-data, 2, 'hard'),
|
||||
-np.array(hard_result), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'hard'),
|
||||
[[1, 2]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'hard'),
|
||||
[[0, 2]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'hard', substitute=s),
|
||||
[[s, 2]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1+1j, 2+2j]] * 2, 2, 'hard'),
|
||||
[[0, 2+2j]] * 2, rtol=1e-12)
|
||||
|
||||
# greater
|
||||
greater_result = [0., 0., 2., 2.5, 3., 3.5, 4.]
|
||||
assert_allclose(pywt.threshold(data, 2, 'greater'),
|
||||
np.array(greater_result), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'greater'),
|
||||
[[1, 2]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'greater'),
|
||||
[[0, 2]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'greater', substitute=s),
|
||||
[[s, 2]] * 2, rtol=1e-12)
|
||||
# greater doesn't allow complex-valued inputs
|
||||
assert_raises(ValueError, pywt.threshold, [1j, 2j], 2, 'greater')
|
||||
|
||||
# less
|
||||
assert_allclose(pywt.threshold(data, 2, 'less'),
|
||||
np.array([1., 1.5, 2., 0., 0., 0., 0.]), rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'less'),
|
||||
[[1, 0]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'less', substitute=s),
|
||||
[[1, s]] * 2, rtol=1e-12)
|
||||
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'less'),
|
||||
[[1, 2]] * 2, rtol=1e-12)
|
||||
|
||||
# less doesn't allow complex-valued inputs
|
||||
assert_raises(ValueError, pywt.threshold, [1j, 2j], 2, 'less')
|
||||
|
||||
# invalid
|
||||
assert_raises(ValueError, pywt.threshold, data, 2, 'foo')
|
||||
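How the thresholding modes exercised above differ on a single value (editor's sketch of the pywt.threshold API; the 'garotte' spelling follows the test itself):

import pywt

x = [3.0]
print(pywt.threshold(x, 2, 'soft'))      # [1.]  shrunk toward zero by the threshold
print(pywt.threshold(x, 2, 'hard'))      # [3.]  kept, since |x| exceeds the threshold
print(pywt.threshold(x, 2, 'garotte'))   # ~[1.67], between the soft and hard results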
|
||||
|
||||
def test_nonnegative_garotte():
|
||||
thresh = 0.3
|
||||
data_real = np.linspace(-1, 1, 100)
|
||||
for dtype in float_dtypes:
|
||||
if dtype in real_dtypes:
|
||||
data = np.asarray(data_real, dtype=dtype)
|
||||
else:
|
||||
data = np.asarray(data_real + 0.1j, dtype=dtype)
|
||||
d_hard = pywt.threshold(data, thresh, 'hard')
|
||||
d_soft = pywt.threshold(data, thresh, 'soft')
|
||||
d_garotte = pywt.threshold(data, thresh, 'garotte')
|
||||
|
||||
# check dtypes
|
||||
assert_equal(d_hard.dtype, data.dtype)
|
||||
assert_equal(d_soft.dtype, data.dtype)
|
||||
assert_equal(d_garotte.dtype, data.dtype)
|
||||
|
||||
# values < threshold are zero
|
||||
lt = np.where(np.abs(data) < thresh)
|
||||
assert_(np.all(d_garotte[lt] == 0))
|
||||
|
||||
# values greater than the threshold are intermediate between soft and hard
|
||||
gt = np.where(np.abs(data) > thresh)
|
||||
gt_abs_garotte = np.abs(d_garotte[gt])
|
||||
assert_(np.all(gt_abs_garotte < np.abs(d_hard[gt])))
|
||||
assert_(np.all(gt_abs_garotte > np.abs(d_soft[gt])))
|
||||
|
||||
|
||||
def test_threshold_firm():
|
||||
thresh = 0.2
|
||||
thresh2 = 3 * thresh
|
||||
data_real = np.linspace(-1, 1, 100)
|
||||
for dtype in float_dtypes:
|
||||
if dtype in real_dtypes:
|
||||
data = np.asarray(data_real, dtype=dtype)
|
||||
else:
|
||||
data = np.asarray(data_real + 0.1j, dtype=dtype)
|
||||
if data.real.dtype == np.float32:
|
||||
rtol = atol = 1e-6
|
||||
else:
|
||||
rtol = atol = 1e-14
|
||||
d_hard = pywt.threshold(data, thresh, 'hard')
|
||||
d_soft = pywt.threshold(data, thresh, 'soft')
|
||||
d_firm = pywt.threshold_firm(data, thresh, thresh2)
|
||||
|
||||
# check dtypes
|
||||
assert_equal(d_hard.dtype, data.dtype)
|
||||
assert_equal(d_soft.dtype, data.dtype)
|
||||
assert_equal(d_firm.dtype, data.dtype)
|
||||
|
||||
# values < threshold are zero
|
||||
lt = np.where(np.abs(data) < thresh)
|
||||
assert_(np.all(d_firm[lt] == 0))
|
||||
|
||||
# values at or above thresh2 are equal to hard thresholding
|
||||
gt = np.where(np.abs(data) >= thresh2)
|
||||
assert_allclose(np.abs(d_hard[gt]), np.abs(d_firm[gt]),
|
||||
rtol=rtol, atol=atol)
|
||||
|
||||
# other values are intermediate between soft and hard thresholding
|
||||
mt = np.where(np.logical_and(np.abs(data) > thresh,
|
||||
np.abs(data) < thresh2))
|
||||
mt_abs_firm = np.abs(d_firm[mt])
|
||||
assert_(np.all(mt_abs_firm < np.abs(d_hard[mt])))
|
||||
assert_(np.all(mt_abs_firm > np.abs(d_soft[mt])))
|
266
venv/Lib/site-packages/pywt/tests/test_wavelet.py
Normal file
|
@@ -0,0 +1,266 @@
|
|||
#!/usr/bin/env python
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose, assert_
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_wavelet_properties():
|
||||
w = pywt.Wavelet('db3')
|
||||
|
||||
# Name
|
||||
assert_(w.name == 'db3')
|
||||
assert_(w.short_family_name == 'db')
|
||||
assert_(w.family_name == 'Daubechies')
|
||||
|
||||
# String representation
|
||||
fields = ('Family name', 'Short name', 'Filters length', 'Orthogonal',
|
||||
'Biorthogonal', 'Symmetry')
|
||||
for field in fields:
|
||||
assert_(field in str(w))
|
||||
|
||||
# Filter coefficients
|
||||
dec_lo = [0.03522629188210, -0.08544127388224, -0.13501102001039,
|
||||
0.45987750211933, 0.80689150931334, 0.33267055295096]
|
||||
dec_hi = [-0.33267055295096, 0.80689150931334, -0.45987750211933,
|
||||
-0.13501102001039, 0.08544127388224, 0.03522629188210]
|
||||
rec_lo = [0.33267055295096, 0.80689150931334, 0.45987750211933,
|
||||
-0.13501102001039, -0.08544127388224, 0.03522629188210]
|
||||
rec_hi = [0.03522629188210, 0.08544127388224, -0.13501102001039,
|
||||
-0.45987750211933, 0.80689150931334, -0.33267055295096]
|
||||
assert_allclose(w.dec_lo, dec_lo)
|
||||
assert_allclose(w.dec_hi, dec_hi)
|
||||
assert_allclose(w.rec_lo, rec_lo)
|
||||
assert_allclose(w.rec_hi, rec_hi)
|
||||
|
||||
assert_(len(w.filter_bank) == 4)
|
||||
|
||||
# Orthogonality
|
||||
assert_(w.orthogonal)
|
||||
assert_(w.biorthogonal)
|
||||
|
||||
# Symmetry
|
||||
assert_(w.symmetry)
|
||||
|
||||
# Vanishing moments
|
||||
assert_(w.vanishing_moments_phi == 0)
|
||||
assert_(w.vanishing_moments_psi == 3)
|
||||
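Querying the same wavelet properties interactively (a brief sketch using only the pywt.Wavelet attributes asserted above):

import pywt

w = pywt.Wavelet('db3')
print(w.family_name, w.short_family_name)                    # Daubechies db
print(len(w.dec_lo), w.orthogonal, w.vanishing_moments_psi)  # 6 True 3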
|
||||
|
||||
def test_wavelet_coefficients():
|
||||
families = ('db', 'sym', 'coif', 'bior', 'rbio')
|
||||
wavelets = sum([pywt.wavelist(name) for name in families], [])
|
||||
for wavelet in wavelets:
|
||||
if pywt.Wavelet(wavelet).orthogonal:
|
||||
check_coefficients_orthogonal(wavelet)
|
||||
elif pywt.Wavelet(wavelet).biorthogonal:
|
||||
check_coefficients_biorthogonal(wavelet)
|
||||
else:
|
||||
check_coefficients(wavelet)
|
||||
|
||||
|
||||
def check_coefficients_orthogonal(wavelet):
|
||||
|
||||
epsilon = 5e-11
|
||||
level = 5
|
||||
w = pywt.Wavelet(wavelet)
|
||||
phi, psi, x = w.wavefun(level=level)
|
||||
|
||||
# Lowpass filter coefficients sum to sqrt2
|
||||
res = np.sum(w.dec_lo)-np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# sum even coef = sum odd coef = 1 / sqrt(2)
|
||||
res = np.sum(w.dec_lo[::2])-1./np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
|
||||
res = np.sum(w.dec_lo[1::2])-1./np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# Highpass filter coefficients sum to zero
|
||||
res = np.sum(w.dec_hi)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# Scaling function integrates to unity
|
||||
|
||||
res = np.sum(phi) - 2**level
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# Wavelet function is orthogonal to the scaling function at the same scale
|
||||
res = np.sum(phi*psi)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# The lowpass and highpass filter coefficients are orthogonal
|
||||
res = np.sum(np.array(w.dec_lo)*np.array(w.dec_hi))
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
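The filter identities checked above, evaluated directly for one wavelet (an illustrative sketch):

import numpy as np
import pywt

w = pywt.Wavelet('db4')
print(np.sum(w.dec_lo) - np.sqrt(2))   # ~0: lowpass coefficients sum to sqrt(2)
print(np.sum(w.dec_hi))                # ~0: highpass coefficients sum to zero
print(np.dot(w.dec_lo, w.dec_hi))      # ~0: lowpass and highpass filters are orthogonal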
|
||||
|
||||
def check_coefficients_biorthogonal(wavelet):
|
||||
|
||||
epsilon = 5e-11
|
||||
level = 5
|
||||
w = pywt.Wavelet(wavelet)
|
||||
phi_d, psi_d, phi_r, psi_r, x = w.wavefun(level=level)
|
||||
|
||||
# Lowpass filter coefficients sum to sqrt2
|
||||
res = np.sum(w.dec_lo)-np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# sum even coef = sum odd coef = 1 / sqrt(2)
|
||||
res = np.sum(w.dec_lo[::2])-1./np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
res = np.sum(w.dec_lo[1::2])-1./np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# Highpass filter coefficients sum to zero
|
||||
res = np.sum(w.dec_hi)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# Scaling function integrates to unity
|
||||
res = np.sum(phi_d) - 2**level
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
res = np.sum(phi_r) - 2**level
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
|
||||
|
||||
def check_coefficients(wavelet):
|
||||
epsilon = 5e-11
|
||||
level = 10
|
||||
w = pywt.Wavelet(wavelet)
|
||||
# Lowpass filter coefficients sum to sqrt2
|
||||
res = np.sum(w.dec_lo)-np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# sum even coef = sum odd coef = 1 / sqrt(2)
|
||||
res = np.sum(w.dec_lo[::2])-1./np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
|
||||
res = np.sum(w.dec_lo[1::2])-1./np.sqrt(2)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
# Highpass filter coefficients sum to zero
|
||||
res = np.sum(w.dec_hi)
|
||||
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
|
||||
assert_(res < epsilon, msg=msg)
|
||||
|
||||
|
||||
class _CustomHaarFilterBank(object):
|
||||
@property
|
||||
def filter_bank(self):
|
||||
val = np.sqrt(2) / 2
|
||||
return ([val]*2, [-val, val], [val]*2, [val, -val])
|
||||
|
||||
|
||||
def test_custom_wavelet():
|
||||
haar_custom1 = pywt.Wavelet('Custom Haar Wavelet',
|
||||
filter_bank=_CustomHaarFilterBank())
|
||||
haar_custom1.orthogonal = True
|
||||
haar_custom1.biorthogonal = True
|
||||
|
||||
val = np.sqrt(2) / 2
|
||||
filter_bank = ([val]*2, [-val, val], [val]*2, [val, -val])
|
||||
haar_custom2 = pywt.Wavelet('Custom Haar Wavelet',
|
||||
filter_bank=filter_bank)
|
||||
|
||||
# check expected default wavelet properties
|
||||
assert_(~haar_custom2.orthogonal)
|
||||
assert_(~haar_custom2.biorthogonal)
|
||||
assert_(haar_custom2.symmetry == 'unknown')
|
||||
assert_(haar_custom2.family_name == '')
|
||||
assert_(haar_custom2.short_family_name == '')
|
||||
assert_(haar_custom2.vanishing_moments_phi == 0)
|
||||
assert_(haar_custom2.vanishing_moments_psi == 0)
|
||||
|
||||
# Some properties can be set by the user
|
||||
haar_custom2.orthogonal = True
|
||||
haar_custom2.biorthogonal = True
|
||||
|
||||
|
||||
def test_wavefun_sym3():
|
||||
w = pywt.Wavelet('sym3')
|
||||
# sym3 is an orthogonal wavelet, so 3 outputs from wavefun
|
||||
phi, psi, x = w.wavefun(level=3)
|
||||
assert_(phi.size == 41)
|
||||
assert_(psi.size == 41)
|
||||
assert_(x.size == 41)
|
||||
|
||||
assert_allclose(x, np.linspace(0, 5, num=x.size))
|
||||
phi_expect = np.array([0.00000000e+00, 1.04132926e-01, 2.52574126e-01,
|
||||
3.96525521e-01, 5.70356539e-01, 7.18934305e-01,
|
||||
8.70293448e-01, 1.05363620e+00, 1.24921722e+00,
|
||||
1.15296888e+00, 9.41669683e-01, 7.55875887e-01,
|
||||
4.96118565e-01, 3.28293151e-01, 1.67624969e-01,
|
||||
-7.33690312e-02, -3.35452855e-01, -3.31221131e-01,
|
||||
-2.32061503e-01, -1.66854239e-01, -4.34091324e-02,
|
||||
-2.86152390e-02, -3.63563035e-02, 2.06034491e-02,
|
||||
8.30280254e-02, 7.17779073e-02, 3.85914311e-02,
|
||||
1.47527100e-02, -2.31896077e-02, -1.86122172e-02,
|
||||
-1.56211329e-03, -8.70615088e-04, 3.20760857e-03,
|
||||
2.34142153e-03, -7.73737194e-04, -2.99879354e-04,
|
||||
1.23636238e-04, 0.00000000e+00, 0.00000000e+00,
|
||||
0.00000000e+00, 0.00000000e+00])
|
||||
|
||||
psi_expect = np.array([0.00000000e+00, 1.10265752e-02, 2.67449277e-02,
|
||||
4.19878574e-02, 6.03947231e-02, 7.61275365e-02,
|
||||
9.21548684e-02, 1.11568926e-01, 1.32278887e-01,
|
||||
6.45829680e-02, -3.97635130e-02, -1.38929884e-01,
|
||||
-2.62428322e-01, -3.62246804e-01, -4.62843343e-01,
|
||||
-5.89607507e-01, -7.25363076e-01, -3.36865858e-01,
|
||||
2.67715108e-01, 8.40176767e-01, 1.55574430e+00,
|
||||
1.18688954e+00, 4.20276324e-01, -1.51697311e-01,
|
||||
-9.42076108e-01, -7.93172332e-01, -3.26343710e-01,
|
||||
-1.24552779e-01, 2.12909254e-01, 1.75770320e-01,
|
||||
1.47523075e-02, 8.22192707e-03, -3.02920592e-02,
|
||||
-2.21119497e-02, 7.30703025e-03, 2.83200488e-03,
|
||||
-1.16759765e-03, 0.00000000e+00, 0.00000000e+00,
|
||||
0.00000000e+00, 0.00000000e+00])
|
||||
|
||||
assert_allclose(phi, phi_expect)
|
||||
assert_allclose(psi, psi_expect)
|
||||
|
||||
|
||||
def test_wavefun_bior13():
|
||||
w = pywt.Wavelet('bior1.3')
|
||||
# bior1.3 is not an orthogonal wavelet, so 5 outputs from wavefun
|
||||
phi_d, psi_d, phi_r, psi_r, x = w.wavefun(level=3)
|
||||
for arr in [phi_d, psi_d, phi_r, psi_r]:
|
||||
assert_(arr.size == 40)
|
||||
|
||||
phi_d_expect = np.array([0., -0.00195313, 0.00195313, 0.01757813,
|
||||
0.01367188, 0.00390625, -0.03515625, -0.12890625,
|
||||
-0.15234375, -0.125, -0.09375, -0.0625, 0.03125,
|
||||
0.15234375, 0.37890625, 0.78515625, 0.99609375,
|
||||
1.08203125, 1.13671875, 1.13671875, 1.08203125,
|
||||
0.99609375, 0.78515625, 0.37890625, 0.15234375,
|
||||
0.03125, -0.0625, -0.09375, -0.125, -0.15234375,
|
||||
-0.12890625, -0.03515625, 0.00390625, 0.01367188,
|
||||
0.01757813, 0.00195313, -0.00195313, 0., 0., 0.])
|
||||
phi_r_expect = np.zeros(x.size, dtype=np.float64)
|
||||
phi_r_expect[15:23] = 1
|
||||
|
||||
psi_d_expect = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0.015625, -0.015625, -0.140625, -0.109375,
|
||||
-0.03125, 0.28125, 1.03125, 1.21875, 1.125, 0.625,
|
||||
-0.625, -1.125, -1.21875, -1.03125, -0.28125,
|
||||
0.03125, 0.109375, 0.140625, 0.015625, -0.015625,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
|
||||
|
||||
psi_r_expect = np.zeros(x.size, dtype=np.float64)
|
||||
psi_r_expect[7:15] = -0.125
|
||||
psi_r_expect[15:19] = 1
|
||||
psi_r_expect[19:23] = -1
|
||||
psi_r_expect[23:31] = 0.125
|
||||
|
||||
assert_allclose(x, np.linspace(0, 5, x.size, endpoint=False))
|
||||
assert_allclose(phi_d, phi_d_expect, rtol=1e-5, atol=1e-9)
|
||||
assert_allclose(phi_r, phi_r_expect, rtol=1e-10, atol=1e-12)
|
||||
assert_allclose(psi_d, psi_d_expect, rtol=1e-10, atol=1e-12)
|
||||
assert_allclose(psi_r, psi_r_expect, rtol=1e-10, atol=1e-12)
|
197
venv/Lib/site-packages/pywt/tests/test_wp.py
Normal file
|
@@ -0,0 +1,197 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import (assert_allclose, assert_, assert_raises,
|
||||
assert_equal)
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_wavelet_packet_structure():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_(wp.data == [1, 2, 3, 4, 5, 6, 7, 8])
|
||||
assert_(wp.path == '')
|
||||
assert_(wp.level == 0)
|
||||
assert_(wp['ad'].maxlevel == 3)
|
||||
|
||||
|
||||
def test_traversing_wp_tree():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_(wp.maxlevel == 3)
|
||||
|
||||
# First level
|
||||
assert_allclose(wp['a'].data, np.array([2.12132034356, 4.949747468306,
|
||||
7.778174593052, 10.606601717798]),
|
||||
rtol=1e-12)
|
||||
|
||||
# Second level
|
||||
assert_allclose(wp['aa'].data, np.array([5., 13.]), rtol=1e-12)
|
||||
|
||||
# Third level
|
||||
assert_allclose(wp['aaa'].data, np.array([12.727922061358]), rtol=1e-12)
|
||||
|
||||
|
||||
def test_acess_path():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_(wp['a'].path == 'a')
|
||||
assert_(wp['aa'].path == 'aa')
|
||||
assert_(wp['aaa'].path == 'aaa')
|
||||
|
||||
# Maximum level reached:
|
||||
assert_raises(IndexError, lambda: wp['aaaa'].path)
|
||||
|
||||
# Wrong path
|
||||
assert_raises(ValueError, lambda: wp['ac'].path)
|
||||
|
||||
|
||||
def test_access_node_atributes():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_allclose(wp['ad'].data, np.array([-2., -2.]), rtol=1e-12)
|
||||
assert_(wp['ad'].path == 'ad')
|
||||
assert_(wp['ad'].node_name == 'd')
|
||||
assert_(wp['ad'].parent.path == 'a')
|
||||
assert_(wp['ad'].level == 2)
|
||||
assert_(wp['ad'].maxlevel == 3)
|
||||
assert_(wp['ad'].mode == 'symmetric')
|
||||
|
||||
|
||||
def test_collecting_nodes():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
# All nodes in natural order
|
||||
assert_([node.path for node in wp.get_level(3, 'natural')] ==
|
||||
['aaa', 'aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd'])
|
||||
|
||||
# and in frequency order.
|
||||
assert_([node.path for node in wp.get_level(3, 'freq')] ==
|
||||
['aaa', 'aad', 'add', 'ada', 'dda', 'ddd', 'dad', 'daa'])
|
||||
|
||||
|
||||
def test_reconstructing_data():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
# Create another Wavelet Packet and feed it with some data.
|
||||
new_wp = pywt.WaveletPacket(data=None, wavelet='db1', mode='symmetric')
|
||||
new_wp['aa'] = wp['aa'].data
|
||||
new_wp['ad'] = [-2., -2.]
|
||||
|
||||
# For convenience, :attr:`Node.data` gets automatically extracted
|
||||
# from the :class:`Node` object:
|
||||
new_wp['d'] = wp['d']
|
||||
|
||||
# Reconstruct data from aa, ad, and d packets.
|
||||
assert_allclose(new_wp.reconstruct(update=False), x, rtol=1e-12)
|
||||
|
||||
# The node's :attr:`~Node.data` will not be updated
|
||||
assert_(new_wp.data is None)
|
||||
|
||||
# When `update` is True:
|
||||
assert_allclose(new_wp.reconstruct(update=True), x, rtol=1e-12)
|
||||
assert_allclose(new_wp.data, np.arange(1, 9), rtol=1e-12)
|
||||
|
||||
assert_([n.path for n in new_wp.get_leaf_nodes(False)] ==
|
||||
['aa', 'ad', 'd'])
|
||||
assert_([n.path for n in new_wp.get_leaf_nodes(True)] ==
|
||||
['aaa', 'aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd'])
|
||||
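The decompose/reassign/reconstruct flow from the test above in miniature (editor's sketch of the WaveletPacket API used here):

import numpy as np
import pywt

x = np.arange(1.0, 9.0)
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
new_wp = pywt.WaveletPacket(data=None, wavelet='db1', mode='symmetric')
new_wp['a'] = wp['a'].data          # copy the first-level approximation coefficients
new_wp['d'] = wp['d'].data          # and the first-level detail coefficients
print(new_wp.reconstruct(update=False))   # recovers [1. 2. 3. 4. 5. 6. 7. 8.]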
|
||||
|
||||
def test_removing_nodes():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
wp.get_level(2)
|
||||
|
||||
dataleafs = [n.data for n in wp.get_leaf_nodes(False)]
|
||||
expected = np.array([[5., 13.], [-2, -2], [-1, -1], [0, 0]])
|
||||
|
||||
for i in range(4):
|
||||
assert_allclose(dataleafs[i], expected[i, :], atol=1e-12)
|
||||
|
||||
node = wp['ad']
|
||||
del wp['ad']
|
||||
dataleafs = [n.data for n in wp.get_leaf_nodes(False)]
|
||||
expected = np.array([[5., 13.], [-1, -1], [0, 0]])
|
||||
|
||||
for i in range(3):
|
||||
assert_allclose(dataleafs[i], expected[i, :], atol=1e-12)
|
||||
|
||||
wp.reconstruct()
|
||||
# The reconstruction is:
|
||||
assert_allclose(wp.reconstruct(),
|
||||
np.array([2., 3., 2., 3., 6., 7., 6., 7.]), rtol=1e-12)
|
||||
|
||||
# Restore the data
|
||||
wp['ad'].data = node.data
|
||||
|
||||
dataleafs = [n.data for n in wp.get_leaf_nodes(False)]
|
||||
expected = np.array([[5., 13.], [-2, -2], [-1, -1], [0, 0]])
|
||||
|
||||
for i in range(4):
|
||||
assert_allclose(dataleafs[i], expected[i, :], atol=1e-12)
|
||||
|
||||
assert_allclose(wp.reconstruct(), np.arange(1, 9), rtol=1e-12)
|
||||
|
||||
|
||||
def test_wavelet_packet_dtypes():
|
||||
N = 32
|
||||
for dtype in [np.float32, np.float64, np.complex64, np.complex128]:
|
||||
x = np.random.randn(N).astype(dtype)
|
||||
if np.iscomplexobj(x):
|
||||
x = x + 1j*np.random.randn(N).astype(x.real.dtype)
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
# no unnecessary copy made
|
||||
assert_(wp.data is x)
|
||||
|
||||
# assigning to a node should not change supported dtypes
|
||||
wp['d'] = wp['d'].data
|
||||
assert_equal(wp['d'].data.dtype, x.dtype)
|
||||
|
||||
# full decomposition
|
||||
wp.get_level(wp.maxlevel)
|
||||
|
||||
# reconstruction from coefficients should preserve dtype
|
||||
r = wp.reconstruct(False)
|
||||
assert_equal(r.dtype, x.dtype)
|
||||
assert_allclose(r, x, atol=1e-5, rtol=1e-5)
|
||||
|
||||
# first element of the tuple is the input dtype
|
||||
# second element of the tuple is the transform dtype
|
||||
dtype_pairs = [(np.uint8, np.float64),
|
||||
(np.intp, np.float64), ]
|
||||
if hasattr(np, "complex256"):
|
||||
dtype_pairs += [(np.complex256, np.complex128), ]
|
||||
if hasattr(np, "half"):
|
||||
dtype_pairs += [(np.half, np.float32), ]
|
||||
for (dtype, transform_dtype) in dtype_pairs:
|
||||
x = np.arange(N, dtype=dtype)
|
||||
wp = pywt.WaveletPacket(x, wavelet='db1', mode='symmetric')
|
||||
|
||||
# no unnecessary copy made of top-level data
|
||||
assert_(wp.data is x)
|
||||
|
||||
# full decomposition
|
||||
wp.get_level(wp.maxlevel)
|
||||
|
||||
# reconstructed data will have modified dtype
|
||||
r = wp.reconstruct(False)
|
||||
assert_equal(r.dtype, transform_dtype)
|
||||
assert_allclose(r, x.astype(transform_dtype), atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
def test_db3_roundtrip():
|
||||
original = np.arange(512)
|
||||
wp = pywt.WaveletPacket(data=original, wavelet='db3', mode='smooth',
|
||||
maxlevel=3)
|
||||
r = wp.reconstruct()
|
||||
assert_allclose(original, r, atol=1e-12, rtol=1e-12)
|
177
venv/Lib/site-packages/pywt/tests/test_wp2d.py
Normal file
|
@@ -0,0 +1,177 @@
|
|||
#!/usr/bin/env python

from __future__ import division, print_function, absolute_import

import numpy as np
from numpy.testing import (assert_allclose, assert_, assert_raises,
                           assert_equal)

import pywt


def test_traversing_tree_2d():
    x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
    wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')

    assert_(np.all(wp.data == x))
    assert_(wp.path == '')
    assert_(wp.level == 0)
    assert_(wp.maxlevel == 3)

    assert_allclose(wp['a'].data, np.array([[3., 7., 11., 15.]] * 4),
                    rtol=1e-12)
    assert_allclose(wp['h'].data, np.zeros((4, 4)), rtol=1e-12, atol=1e-14)
    assert_allclose(wp['v'].data, -np.ones((4, 4)), rtol=1e-12, atol=1e-14)
    assert_allclose(wp['d'].data, np.zeros((4, 4)), rtol=1e-12, atol=1e-14)

    assert_allclose(wp['aa'].data, np.array([[10., 26.]] * 2), rtol=1e-12)

    assert_(wp['a']['a'].data is wp['aa'].data)
    assert_allclose(wp['aaa'].data, np.array([[36.]]), rtol=1e-12)

    assert_raises(IndexError, lambda: wp['aaaa'])
    assert_raises(ValueError, lambda: wp['f'])


def test_accessing_node_atributes_2d():
    x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
    wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')

    assert_allclose(wp['av'].data, np.zeros((2, 2)) - 4, rtol=1e-12)
    assert_(wp['av'].path == 'av')
    assert_(wp['av'].node_name == 'v')
    assert_(wp['av'].parent.path == 'a')

    assert_allclose(wp['av'].parent.data, np.array([[3., 7., 11., 15.]] * 4),
                    rtol=1e-12)
    assert_(wp['av'].level == 2)
    assert_(wp['av'].maxlevel == 3)
    assert_(wp['av'].mode == 'symmetric')


def test_collecting_nodes_2d():
    x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
    wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')

    assert_(len(wp.get_level(0)) == 1)
    assert_(wp.get_level(0)[0].path == '')

    # First level
    assert_(len(wp.get_level(1)) == 4)
    assert_([node.path for node in wp.get_level(1)] == ['a', 'h', 'v', 'd'])
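    # Each 2D node has four children ('a', 'h', 'v', 'd'), so a full
    # decomposition has 4**level nodes per level: 16 at level 2, 64 at level 3.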

    # Second level
    assert_(len(wp.get_level(2)) == 16)
    paths = [node.path for node in wp.get_level(2)]
    expected_paths = ['aa', 'ah', 'av', 'ad', 'ha', 'hh', 'hv', 'hd', 'va',
                      'vh', 'vv', 'vd', 'da', 'dh', 'dv', 'dd']
    assert_(paths == expected_paths)

    # Third level.
    assert_(len(wp.get_level(3)) == 64)
    paths = [node.path for node in wp.get_level(3)]
    expected_paths = ['aaa', 'aah', 'aav', 'aad', 'aha', 'ahh', 'ahv', 'ahd',
                      'ava', 'avh', 'avv', 'avd', 'ada', 'adh', 'adv', 'add',
                      'haa', 'hah', 'hav', 'had', 'hha', 'hhh', 'hhv', 'hhd',
                      'hva', 'hvh', 'hvv', 'hvd', 'hda', 'hdh', 'hdv', 'hdd',
                      'vaa', 'vah', 'vav', 'vad', 'vha', 'vhh', 'vhv', 'vhd',
                      'vva', 'vvh', 'vvv', 'vvd', 'vda', 'vdh', 'vdv', 'vdd',
                      'daa', 'dah', 'dav', 'dad', 'dha', 'dhh', 'dhv', 'dhd',
                      'dva', 'dvh', 'dvv', 'dvd', 'dda', 'ddh', 'ddv', 'ddd']

    assert_(paths == expected_paths)


def test_data_reconstruction_2d():
    x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
    wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')

    new_wp = pywt.WaveletPacket2D(data=None, wavelet='db1', mode='symmetric')
    new_wp['vh'] = wp['vh'].data
    new_wp['vv'] = wp['vh'].data
    new_wp['vd'] = np.zeros((2, 2), dtype=np.float64)
    new_wp['a'] = [[3.0, 7.0, 11.0, 15.0]] * 4
    new_wp['d'] = np.zeros((4, 4), dtype=np.float64)
    new_wp['h'] = wp['h']  # all zeros

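    # With the 'va' coefficients still missing, the reconstruction is only an
    # approximation of x; it becomes exact once 'va' is filled in below.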
    assert_allclose(new_wp.reconstruct(update=False),
                    np.array([[1.5, 1.5, 3.5, 3.5, 5.5, 5.5, 7.5, 7.5]] * 8),
                    rtol=1e-12)
    assert_allclose(wp['va'].data, np.zeros((2, 2)) - 2, rtol=1e-12)

    new_wp['va'] = wp['va'].data
    assert_allclose(new_wp.reconstruct(update=False), x, rtol=1e-12)


def test_data_reconstruction_delete_nodes_2d():
    x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
    wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')

    new_wp = pywt.WaveletPacket2D(data=None, wavelet='db1', mode='symmetric')
    new_wp['vh'] = wp['vh'].data
    new_wp['vv'] = wp['vh'].data
    new_wp['vd'] = np.zeros((2, 2), dtype=np.float64)
    new_wp['a'] = [[3.0, 7.0, 11.0, 15.0]] * 4
    new_wp['d'] = np.zeros((4, 4), dtype=np.float64)
    new_wp['h'] = wp['h']  # all zeros

    assert_allclose(new_wp.reconstruct(update=False),
                    np.array([[1.5, 1.5, 3.5, 3.5, 5.5, 5.5, 7.5, 7.5]] * 8),
                    rtol=1e-12)

    new_wp['va'] = wp['va'].data
    assert_allclose(new_wp.reconstruct(update=False), x, rtol=1e-12)

    del(new_wp['va'])
    new_wp['va'] = wp['va'].data
    assert_(new_wp.data is None)

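    # update=True also writes the reconstruction back to the root node's data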
    assert_allclose(new_wp.reconstruct(update=True), x, rtol=1e-12)
    assert_allclose(new_wp.data, x, rtol=1e-12)

    # TODO: decompose=True


def test_lazy_evaluation_2D():
    # Note: internal implementation detail not to be relied on. Testing for
    # now for backwards compatibility, but this test may be broken if needed.
    x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8)
    wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')

    assert_(wp.a is None)
    assert_allclose(wp['a'].data, np.array([[3., 7., 11., 15.]] * 4),
                    rtol=1e-12)
    assert_allclose(wp.a.data, np.array([[3., 7., 11., 15.]] * 4), rtol=1e-12)
    assert_allclose(wp.d.data, np.zeros((4, 4)), rtol=1e-12, atol=1e-12)


def test_wavelet_packet_dtypes():
    shape = (16, 16)
    for dtype in [np.float32, np.float64, np.complex64, np.complex128]:
        x = np.random.randn(*shape).astype(dtype)
        if np.iscomplexobj(x):
            x = x + 1j*np.random.randn(*shape).astype(x.real.dtype)
        wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
        # no unnecessary copy made
        assert_(wp.data is x)

        # assigning to a node should not change supported dtypes
        wp['d'] = wp['d'].data
        assert_equal(wp['d'].data.dtype, x.dtype)

        # full decomposition
        wp.get_level(wp.maxlevel)

        # reconstruction from coefficients should preserve dtype
        r = wp.reconstruct(False)
        assert_equal(r.dtype, x.dtype)
        assert_allclose(r, x, atol=1e-5, rtol=1e-5)


def test_2d_roundtrip():
    # test case corresponding to PyWavelets issue 447
    original = pywt.data.camera()
    wp = pywt.WaveletPacket2D(data=original, wavelet='db3', mode='smooth',
                              maxlevel=3)
    r = wp.reconstruct()
    assert_allclose(original, r, atol=1e-12, rtol=1e-12)
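A matching 2-D sketch, again illustrative only and based on the calls in the tests above: node paths are strings over 'a' (approximation) and 'h'/'v'/'d' (horizontal/vertical/diagonal detail), so 'av' is the vertical-detail child of the level-1 approximation.

import numpy as np
import pywt

x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
wp2 = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
print(wp2['av'].data)              # 2x2 block of -4's, as asserted in the tests
print(wp2['av'].parent.path)       # 'a'
r = wp2.reconstruct(update=False)  # does not overwrite wp2.data
assert np.allclose(r, x)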
10
venv/Lib/site-packages/pywt/version.py
Normal file
@@ -0,0 +1,10 @@

# THIS FILE IS GENERATED FROM PYWAVELETS SETUP.PY
short_version = '1.1.1'
version = '1.1.1'
full_version = '1.1.1'
git_revision = '7b2f66b0ef9196fa91bba550a81d1870f5933a36'
release = True

if not release:
    version = full_version