Fixed database typo and removed unnecessary class identifier.

This commit is contained in:
Batuhan Berk Başoğlu 2020-10-14 10:10:37 -04:00
parent 00ad49a143
commit 45fb349a7d
5098 changed files with 952558 additions and 85 deletions

View file

@ -0,0 +1,40 @@
# flake8: noqa
# Copyright (c) 2006-2012 Filip Wasilewski <http://en.ig.ma/>
# Copyright (c) 2012-2016 The PyWavelets Developers
# <https://github.com/PyWavelets/pywt>
# See COPYING for license details.
"""
Discrete forward and inverse wavelet transform, stationary wavelet transform,
wavelet packets signal decomposition and reconstruction module.
"""
from __future__ import division, print_function, absolute_import
from distutils.version import LooseVersion
from ._extensions._pywt import *
from ._functions import *
from ._multilevel import *
from ._multidim import *
from ._thresholding import *
from ._wavelet_packets import *
from ._dwt import *
from ._swt import *
from ._cwt import *
from . import data
__all__ = [s for s in dir() if not s.startswith('_')]
try:
# In Python 2.x the name of the tempvar leaks out of the list
# comprehension. Delete it to not make it show up in the main namespace.
del s
except NameError:
pass
from pywt.version import version as __version__
from ._pytesttester import PytestTester
test = PytestTester(__name__)
del PytestTester

View file

@ -0,0 +1,3 @@
# Autogenerated file containing compile-time definitions
_have_c99_complex = 0

View file

@ -0,0 +1,203 @@
from math import floor, ceil
from ._extensions._pywt import (DiscreteContinuousWavelet, ContinuousWavelet,
Wavelet, _check_dtype)
from ._functions import integrate_wavelet, scale2frequency
__all__ = ["cwt"]
import numpy as np
try:
# Prefer scipy.fft (new in SciPy 1.4)
import scipy.fft
fftmodule = scipy.fft
next_fast_len = fftmodule.next_fast_len
except ImportError:
try:
import scipy.fftpack
fftmodule = scipy.fftpack
next_fast_len = fftmodule.next_fast_len
except ImportError:
fftmodule = np.fft
# provide a fallback so scipy is an optional requirement
def next_fast_len(n):
"""Round up size to the nearest power of two.
Given a number of samples `n`, returns the next power of two
following this number to take advantage of FFT speedup.
This fallback is less efficient than `scipy.fftpack.next_fast_len`
"""
return 2**ceil(np.log2(n))
def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
"""
cwt(data, scales, wavelet)
One dimensional Continuous Wavelet Transform.
Parameters
----------
data : array_like
Input signal
scales : array_like
The wavelet scales to use. One can use
``f = scale2frequency(wavelet, scale)/sampling_period`` to determine
what physical frequency, ``f``. Here, ``f`` is in hertz when the
``sampling_period`` is given in seconds.
wavelet : Wavelet object or name
Wavelet to use
sampling_period : float
Sampling period for the frequencies output (optional).
The values computed for ``coefs`` are independent of the choice of
``sampling_period`` (i.e. ``scales`` is not scaled by the sampling
period).
method : {'conv', 'fft'}, optional
The method used to compute the CWT. Can be any of:
- ``conv`` uses ``numpy.convolve``.
- ``fft`` uses frequency domain convolution.
- ``auto`` uses automatic selection based on an estimate of the
computational complexity at each scale.
The ``conv`` method complexity is ``O(len(scale) * len(data))``.
The ``fft`` method is ``O(N * log2(N))`` with
``N = len(scale) + len(data) - 1``. It is well suited for large size
signals but slightly slower than ``conv`` on small ones.
axis: int, optional
Axis over which to compute the CWT. If not given, the last axis is
used.
Returns
-------
coefs : array_like
Continuous wavelet transform of the input signal for the given scales
and wavelet. The first axis of ``coefs`` corresponds to the scales.
The remaining axes match the shape of ``data``.
frequencies : array_like
If the unit of sampling period are seconds and given, than frequencies
are in hertz. Otherwise, a sampling period of 1 is assumed.
Notes
-----
Size of coefficients arrays depends on the length of the input array and
the length of given scales.
Examples
--------
>>> import pywt
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> x = np.arange(512)
>>> y = np.sin(2*np.pi*x/32)
>>> coef, freqs=pywt.cwt(y,np.arange(1,129),'gaus1')
>>> plt.matshow(coef) # doctest: +SKIP
>>> plt.show() # doctest: +SKIP
----------
>>> import pywt
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> t = np.linspace(-1, 1, 200, endpoint=False)
>>> sig = np.cos(2 * np.pi * 7 * t) + np.real(np.exp(-7*(t-0.4)**2)*np.exp(1j*2*np.pi*2*(t-0.4)))
>>> widths = np.arange(1, 31)
>>> cwtmatr, freqs = pywt.cwt(sig, widths, 'mexh')
>>> plt.imshow(cwtmatr, extent=[-1, 1, 1, 31], cmap='PRGn', aspect='auto',
... vmax=abs(cwtmatr).max(), vmin=-abs(cwtmatr).max()) # doctest: +SKIP
>>> plt.show() # doctest: +SKIP
"""
# accept array_like input; make a copy to ensure a contiguous array
dt = _check_dtype(data)
data = np.asarray(data, dtype=dt)
dt_cplx = np.result_type(dt, np.complex64)
if not isinstance(wavelet, (ContinuousWavelet, Wavelet)):
wavelet = DiscreteContinuousWavelet(wavelet)
if np.isscalar(scales):
scales = np.array([scales])
if not np.isscalar(axis):
raise ValueError("axis must be a scalar.")
dt_out = dt_cplx if wavelet.complex_cwt else dt
out = np.empty((np.size(scales),) + data.shape, dtype=dt_out)
precision = 10
int_psi, x = integrate_wavelet(wavelet, precision=precision)
int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi
# convert int_psi, x to the same precision as the data
dt_psi = dt_cplx if int_psi.dtype.kind == 'c' else dt
int_psi = np.asarray(int_psi, dtype=dt_psi)
x = np.asarray(x, dtype=data.real.dtype)
if method == 'fft':
size_scale0 = -1
fft_data = None
elif not method == 'conv':
raise ValueError("method must be 'conv' or 'fft'")
if data.ndim > 1:
# move axis to be transformed last (so it is contiguous)
data = data.swapaxes(-1, axis)
# reshape to (n_batch, data.shape[-1])
data_shape_pre = data.shape
data = data.reshape((-1, data.shape[-1]))
for i, scale in enumerate(scales):
step = x[1] - x[0]
j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step)
j = j.astype(int) # floor
if j[-1] >= int_psi.size:
j = np.extract(j < int_psi.size, j)
int_psi_scale = int_psi[j][::-1]
if method == 'conv':
if data.ndim == 1:
conv = np.convolve(data, int_psi_scale)
else:
# batch convolution via loop
conv_shape = list(data.shape)
conv_shape[-1] += int_psi_scale.size - 1
conv_shape = tuple(conv_shape)
conv = np.empty(conv_shape, dtype=dt_out)
for n in range(data.shape[0]):
conv[n, :] = np.convolve(data[n], int_psi_scale)
else:
# The padding is selected for:
# - optimal FFT complexity
# - to be larger than the two signals length to avoid circular
# convolution
size_scale = next_fast_len(
data.shape[-1] + int_psi_scale.size - 1
)
if size_scale != size_scale0:
# Must recompute fft_data when the padding size changes.
fft_data = fftmodule.fft(data, size_scale, axis=-1)
size_scale0 = size_scale
fft_wav = fftmodule.fft(int_psi_scale, size_scale, axis=-1)
conv = fftmodule.ifft(fft_wav * fft_data, axis=-1)
conv = conv[..., :data.shape[-1] + int_psi_scale.size - 1]
coef = - np.sqrt(scale) * np.diff(conv, axis=-1)
if out.dtype.kind != 'c':
coef = coef.real
# transform axis is always -1 due to the data reshape above
d = (coef.shape[-1] - data.shape[-1]) / 2.
if d > 0:
coef = coef[..., floor(d):-ceil(d)]
elif d < 0:
raise ValueError(
"Selected scale of {} too small.".format(scale))
if data.ndim > 1:
# restore original data shape and axis position
coef = coef.reshape(data_shape_pre)
coef = coef.swapaxes(axis, -1)
out[i, ...] = coef
frequencies = scale2frequency(wavelet, scales, precision)
if np.isscalar(frequencies):
frequencies = np.array([frequencies])
frequencies /= sampling_period
return out, frequencies

View file

@ -0,0 +1,187 @@
"""Utilities used to generate various figures in the documentation."""
from itertools import product
import numpy as np
from matplotlib import pyplot as plt
from ._dwt import pad
__all__ = ['wavedec_keys', 'wavedec2_keys', 'draw_2d_wp_basis',
'draw_2d_fswavedecn_basis', 'boundary_mode_subplot']
def wavedec_keys(level):
"""Subband keys corresponding to a wavedec decomposition."""
approx = ''
coeffs = {}
for lev in range(level):
for k in ['a', 'd']:
coeffs[approx + k] = None
approx = 'a' * (lev + 1)
if lev < level - 1:
coeffs.pop(approx)
return list(coeffs.keys())
def wavedec2_keys(level):
"""Subband keys corresponding to a wavedec2 decomposition."""
approx = ''
coeffs = {}
for lev in range(level):
for k in ['a', 'h', 'v', 'd']:
coeffs[approx + k] = None
approx = 'a' * (lev + 1)
if lev < level - 1:
coeffs.pop(approx)
return list(coeffs.keys())
def _box(bl, ur):
"""(x, y) coordinates for the 4 lines making up a rectangular box.
Parameters
==========
bl : float
The bottom left corner of the box
ur : float
The upper right corner of the box
Returns
=======
coords : 2-tuple
The first and second elements of the tuple are the x and y coordinates
of the box.
"""
xl, xr = bl[0], ur[0]
yb, yt = bl[1], ur[1]
box_x = [xl, xr,
xr, xr,
xr, xl,
xl, xl]
box_y = [yb, yb,
yb, yt,
yt, yt,
yt, yb]
return (box_x, box_y)
def _2d_wp_basis_coords(shape, keys):
# Coordinates of the lines to be drawn by draw_2d_wp_basis
coords = []
centers = {} # retain center of boxes for use in labeling
for key in keys:
offset_x = offset_y = 0
for n, char in enumerate(key):
if char in ['h', 'd']:
offset_x += shape[0] // 2**(n + 1)
if char in ['v', 'd']:
offset_y += shape[1] // 2**(n + 1)
sx = shape[0] // 2**(n + 1)
sy = shape[1] // 2**(n + 1)
xc, yc = _box((offset_x, -offset_y),
(offset_x + sx, -offset_y - sy))
coords.append((xc, yc))
centers[key] = (offset_x + sx // 2, -offset_y - sy // 2)
return coords, centers
def draw_2d_wp_basis(shape, keys, fmt='k', plot_kwargs={}, ax=None,
label_levels=0):
"""Plot a 2D representation of a WaveletPacket2D basis."""
coords, centers = _2d_wp_basis_coords(shape, keys)
if ax is None:
fig, ax = plt.subplots(1, 1)
else:
fig = ax.get_figure()
for coord in coords:
ax.plot(coord[0], coord[1], fmt)
ax.set_axis_off()
ax.axis('square')
if label_levels > 0:
for key, c in centers.items():
if len(key) <= label_levels:
ax.text(c[0], c[1], key,
horizontalalignment='center',
verticalalignment='center')
return fig, ax
def _2d_fswavedecn_coords(shape, levels):
coords = []
centers = {} # retain center of boxes for use in labeling
for key in product(wavedec_keys(levels), repeat=2):
(key0, key1) = key
offsets = [0, 0]
widths = list(shape)
for n0, char in enumerate(key0):
if char in ['d']:
offsets[0] += shape[0] // 2**(n0 + 1)
for n1, char in enumerate(key1):
if char in ['d']:
offsets[1] += shape[1] // 2**(n1 + 1)
widths[0] = shape[0] // 2**(n0 + 1)
widths[1] = shape[1] // 2**(n1 + 1)
xc, yc = _box((offsets[0], -offsets[1]),
(offsets[0] + widths[0], -offsets[1] - widths[1]))
coords.append((xc, yc))
centers[(key0, key1)] = (offsets[0] + widths[0] / 2,
-offsets[1] - widths[1] / 2)
return coords, centers
def draw_2d_fswavedecn_basis(shape, levels, fmt='k', plot_kwargs={}, ax=None,
label_levels=0):
"""Plot a 2D representation of a WaveletPacket2D basis."""
coords, centers = _2d_fswavedecn_coords(shape, levels)
if ax is None:
fig, ax = plt.subplots(1, 1)
else:
fig = ax.get_figure()
for coord in coords:
ax.plot(coord[0], coord[1], fmt)
ax.set_axis_off()
ax.axis('square')
if label_levels > 0:
for key, c in centers.items():
lev = np.max([len(k) for k in key])
if lev <= label_levels:
ax.text(c[0], c[1], key,
horizontalalignment='center',
verticalalignment='center')
return fig, ax
def boundary_mode_subplot(x, mode, ax, symw=True):
"""Plot an illustration of the boundary mode in a subplot axis."""
# if odd-length, periodization replicates the last sample to make it even
if mode == 'periodization' and len(x) % 2 == 1:
x = np.concatenate((x, (x[-1], )))
npad = 2 * len(x)
t = np.arange(len(x) + 2 * npad)
xp = pad(x, (npad, npad), mode=mode)
ax.plot(t, xp, 'k.')
ax.set_title(mode)
# plot the original signal in red
if mode == 'periodization':
ax.plot(t[npad:npad + len(x) - 1], x[:-1], 'r.')
else:
ax.plot(t[npad:npad + len(x)], x, 'r.')
# add vertical bars indicating points of symmetry or boundary extension
o2 = np.ones(2)
left = npad
if symw:
step = len(x) - 1
rng = range(-2, 4)
else:
left -= 0.5
step = len(x)
rng = range(-2, 4)
if mode in ['smooth', 'constant', 'zero']:
rng = range(0, 2)
for rep in rng:
ax.plot((left + rep * step) * o2, [xp.min() - .5, xp.max() + .5], 'k-')

View file

@ -0,0 +1,517 @@
from numbers import Number
import numpy as np
from ._c99_config import _have_c99_complex
from ._extensions._pywt import Wavelet, Modes, _check_dtype, wavelist
from ._extensions._dwt import (dwt_single, dwt_axis, idwt_single, idwt_axis,
upcoef as _upcoef, downcoef as _downcoef,
dwt_max_level as _dwt_max_level,
dwt_coeff_len as _dwt_coeff_len)
from ._utils import string_types, _as_wavelet
__all__ = ["dwt", "idwt", "downcoef", "upcoef", "dwt_max_level",
"dwt_coeff_len", "pad"]
def dwt_max_level(data_len, filter_len):
r"""
dwt_max_level(data_len, filter_len)
Compute the maximum useful level of decomposition.
Parameters
----------
data_len : int
Input data length.
filter_len : int, str or Wavelet
The wavelet filter length. Alternatively, the name of a discrete
wavelet or a Wavelet object can be specified.
Returns
-------
max_level : int
Maximum level.
Notes
-----
The rational for the choice of levels is the maximum level where at least
one coefficient in the output is uncorrupted by edge effects caused by
signal extension. Put another way, decomposition stops when the signal
becomes shorter than the FIR filter length for a given wavelet. This
corresponds to:
.. max_level = floor(log2(data_len/(filter_len - 1)))
.. math::
\mathtt{max\_level} = \left\lfloor\log_2\left(\mathtt{
\frac{data\_len}{filter\_len - 1}}\right)\right\rfloor
Examples
--------
>>> import pywt
>>> w = pywt.Wavelet('sym5')
>>> pywt.dwt_max_level(data_len=1000, filter_len=w.dec_len)
6
>>> pywt.dwt_max_level(1000, w)
6
>>> pywt.dwt_max_level(1000, 'sym5')
6
"""
if isinstance(filter_len, Wavelet):
filter_len = filter_len.dec_len
elif isinstance(filter_len, string_types):
if filter_len in wavelist(kind='discrete'):
filter_len = Wavelet(filter_len).dec_len
else:
raise ValueError(
("'{}', is not a recognized discrete wavelet. A list of "
"supported wavelet names can be obtained via "
"pywt.wavelist(kind='discrete')").format(filter_len))
elif not (isinstance(filter_len, Number) and filter_len % 1 == 0):
raise ValueError(
"filter_len must be an integer, discrete Wavelet object, or the "
"name of a discrete wavelet.")
if filter_len < 2:
raise ValueError("invalid wavelet filter length")
return _dwt_max_level(data_len, filter_len)
def dwt_coeff_len(data_len, filter_len, mode):
"""
dwt_coeff_len(data_len, filter_len, mode='symmetric')
Returns length of dwt output for given data length, filter length and mode
Parameters
----------
data_len : int
Data length.
filter_len : int
Filter length.
mode : str, optional
Signal extension mode, see :ref:`Modes <ref-modes>`.
Returns
-------
len : int
Length of dwt output.
Notes
-----
For all modes except periodization::
len(cA) == len(cD) == floor((len(data) + wavelet.dec_len - 1) / 2)
for periodization mode ("per")::
len(cA) == len(cD) == ceil(len(data) / 2)
"""
if isinstance(filter_len, Wavelet):
filter_len = filter_len.dec_len
return _dwt_coeff_len(data_len, filter_len, Modes.from_object(mode))
def dwt(data, wavelet, mode='symmetric', axis=-1):
"""
dwt(data, wavelet, mode='symmetric', axis=-1)
Single level Discrete Wavelet Transform.
Parameters
----------
data : array_like
Input signal
wavelet : Wavelet object or name
Wavelet to use
mode : str, optional
Signal extension mode, see :ref:`Modes <ref-modes>`.
axis: int, optional
Axis over which to compute the DWT. If not given, the
last axis is used.
Returns
-------
(cA, cD) : tuple
Approximation and detail coefficients.
Notes
-----
Length of coefficients arrays depends on the selected mode.
For all modes except periodization:
``len(cA) == len(cD) == floor((len(data) + wavelet.dec_len - 1) / 2)``
For periodization mode ("per"):
``len(cA) == len(cD) == ceil(len(data) / 2)``
Examples
--------
>>> import pywt
>>> (cA, cD) = pywt.dwt([1, 2, 3, 4, 5, 6], 'db1')
>>> cA
array([ 2.12132034, 4.94974747, 7.77817459])
>>> cD
array([-0.70710678, -0.70710678, -0.70710678])
"""
if not _have_c99_complex and np.iscomplexobj(data):
data = np.asarray(data)
cA_r, cD_r = dwt(data.real, wavelet, mode, axis)
cA_i, cD_i = dwt(data.imag, wavelet, mode, axis)
return (cA_r + 1j*cA_i, cD_r + 1j*cD_i)
# accept array_like input; make a copy to ensure a contiguous array
dt = _check_dtype(data)
data = np.asarray(data, dtype=dt, order='C')
mode = Modes.from_object(mode)
wavelet = _as_wavelet(wavelet)
if axis < 0:
axis = axis + data.ndim
if not 0 <= axis < data.ndim:
raise ValueError("Axis greater than data dimensions")
if data.ndim == 1:
cA, cD = dwt_single(data, wavelet, mode)
# TODO: Check whether this makes a copy
cA, cD = np.asarray(cA, dt), np.asarray(cD, dt)
else:
cA, cD = dwt_axis(data, wavelet, mode, axis=axis)
return (cA, cD)
def idwt(cA, cD, wavelet, mode='symmetric', axis=-1):
"""
idwt(cA, cD, wavelet, mode='symmetric', axis=-1)
Single level Inverse Discrete Wavelet Transform.
Parameters
----------
cA : array_like or None
Approximation coefficients. If None, will be set to array of zeros
with same shape as ``cD``.
cD : array_like or None
Detail coefficients. If None, will be set to array of zeros
with same shape as ``cA``.
wavelet : Wavelet object or name
Wavelet to use
mode : str, optional (default: 'symmetric')
Signal extension mode, see :ref:`Modes <ref-modes>`.
axis: int, optional
Axis over which to compute the inverse DWT. If not given, the
last axis is used.
Returns
-------
rec: array_like
Single level reconstruction of signal from given coefficients.
Examples
--------
>>> import pywt
>>> (cA, cD) = pywt.dwt([1,2,3,4,5,6], 'db2', 'smooth')
>>> pywt.idwt(cA, cD, 'db2', 'smooth')
array([ 1., 2., 3., 4., 5., 6.])
One of the neat features of ``idwt`` is that one of the ``cA`` and ``cD``
arguments can be set to None. In that situation the reconstruction will be
performed using only the other one. Mathematically speaking, this is
equivalent to passing a zero-filled array as one of the arguments.
>>> (cA, cD) = pywt.dwt([1,2,3,4,5,6], 'db2', 'smooth')
>>> A = pywt.idwt(cA, None, 'db2', 'smooth')
>>> D = pywt.idwt(None, cD, 'db2', 'smooth')
>>> A + D
array([ 1., 2., 3., 4., 5., 6.])
"""
# TODO: Lots of possible allocations to eliminate (zeros_like, asarray(rec))
# accept array_like input; make a copy to ensure a contiguous array
if cA is None and cD is None:
raise ValueError("At least one coefficient parameter must be "
"specified.")
# for complex inputs: compute real and imaginary separately then combine
if not _have_c99_complex and (np.iscomplexobj(cA) or np.iscomplexobj(cD)):
if cA is None:
cD = np.asarray(cD)
cA = np.zeros_like(cD)
elif cD is None:
cA = np.asarray(cA)
cD = np.zeros_like(cA)
return (idwt(cA.real, cD.real, wavelet, mode, axis) +
1j*idwt(cA.imag, cD.imag, wavelet, mode, axis))
if cA is not None:
dt = _check_dtype(cA)
cA = np.asarray(cA, dtype=dt, order='C')
if cD is not None:
dt = _check_dtype(cD)
cD = np.asarray(cD, dtype=dt, order='C')
if cA is not None and cD is not None:
if cA.dtype != cD.dtype:
# need to upcast to common type
if cA.dtype.kind == 'c' or cD.dtype.kind == 'c':
dtype = np.complex128
else:
dtype = np.float64
cA = cA.astype(dtype)
cD = cD.astype(dtype)
elif cA is None:
cA = np.zeros_like(cD)
elif cD is None:
cD = np.zeros_like(cA)
# cA and cD should be same dimension by here
ndim = cA.ndim
mode = Modes.from_object(mode)
wavelet = _as_wavelet(wavelet)
if axis < 0:
axis = axis + ndim
if not 0 <= axis < ndim:
raise ValueError("Axis greater than coefficient dimensions")
if ndim == 1:
rec = idwt_single(cA, cD, wavelet, mode)
else:
rec = idwt_axis(cA, cD, wavelet, mode, axis=axis)
return rec
def downcoef(part, data, wavelet, mode='symmetric', level=1):
"""
downcoef(part, data, wavelet, mode='symmetric', level=1)
Partial Discrete Wavelet Transform data decomposition.
Similar to ``pywt.dwt``, but computes only one set of coefficients.
Useful when you need only approximation or only details at the given level.
Parameters
----------
part : str
Coefficients type:
* 'a' - approximations reconstruction is performed
* 'd' - details reconstruction is performed
data : array_like
Input signal.
wavelet : Wavelet object or name
Wavelet to use
mode : str, optional
Signal extension mode, see :ref:`Modes <ref-modes>`.
level : int, optional
Decomposition level. Default is 1.
Returns
-------
coeffs : ndarray
1-D array of coefficients.
See Also
--------
upcoef
"""
if not _have_c99_complex and np.iscomplexobj(data):
return (downcoef(part, data.real, wavelet, mode, level) +
1j*downcoef(part, data.imag, wavelet, mode, level))
# accept array_like input; make a copy to ensure a contiguous array
dt = _check_dtype(data)
data = np.asarray(data, dtype=dt, order='C')
if data.ndim > 1:
raise ValueError("downcoef only supports 1d data.")
if part not in 'ad':
raise ValueError("Argument 1 must be 'a' or 'd', not '%s'." % part)
mode = Modes.from_object(mode)
wavelet = _as_wavelet(wavelet)
return np.asarray(_downcoef(part == 'a', data, wavelet, mode, level))
def upcoef(part, coeffs, wavelet, level=1, take=0):
"""
upcoef(part, coeffs, wavelet, level=1, take=0)
Direct reconstruction from coefficients.
Parameters
----------
part : str
Coefficients type:
* 'a' - approximations reconstruction is performed
* 'd' - details reconstruction is performed
coeffs : array_like
Coefficients array to recontruct
wavelet : Wavelet object or name
Wavelet to use
level : int, optional
Multilevel reconstruction level. Default is 1.
take : int, optional
Take central part of length equal to 'take' from the result.
Default is 0.
Returns
-------
rec : ndarray
1-D array with reconstructed data from coefficients.
See Also
--------
downcoef
Examples
--------
>>> import pywt
>>> data = [1,2,3,4,5,6]
>>> (cA, cD) = pywt.dwt(data, 'db2', 'smooth')
>>> pywt.upcoef('a', cA, 'db2') + pywt.upcoef('d', cD, 'db2')
array([-0.25 , -0.4330127 , 1. , 2. , 3. ,
4. , 5. , 6. , 1.78589838, -1.03108891])
>>> n = len(data)
>>> pywt.upcoef('a', cA, 'db2', take=n) + pywt.upcoef('d', cD, 'db2', take=n)
array([ 1., 2., 3., 4., 5., 6.])
"""
if not _have_c99_complex and np.iscomplexobj(coeffs):
return (upcoef(part, coeffs.real, wavelet, level, take) +
1j*upcoef(part, coeffs.imag, wavelet, level, take))
# accept array_like input; make a copy to ensure a contiguous array
dt = _check_dtype(coeffs)
coeffs = np.asarray(coeffs, dtype=dt, order='C')
if coeffs.ndim > 1:
raise ValueError("upcoef only supports 1d coeffs.")
wavelet = _as_wavelet(wavelet)
if part not in 'ad':
raise ValueError("Argument 1 must be 'a' or 'd', not '%s'." % part)
return np.asarray(_upcoef(part == 'a', coeffs, wavelet, level, take))
def pad(x, pad_widths, mode):
"""Extend a 1D signal using a given boundary mode.
This function operates like :func:`numpy.pad` but supports all signal
extension modes that can be used by PyWavelets discrete wavelet transforms.
Parameters
----------
x : ndarray
The array to pad
pad_widths : {sequence, array_like, int}
Number of values padded to the edges of each axis.
``((before_1, after_1), (before_N, after_N))`` unique pad widths for
each axis. ``((before, after),)`` yields same before and after pad for
each axis. ``(pad,)`` or int is a shortcut for
``before = after = pad width`` for all axes.
mode : str, optional
Signal extension mode, see :ref:`Modes <ref-modes>`.
Returns
-------
pad : ndarray
Padded array of rank equal to array with shape increased according to
``pad_widths``.
Notes
-----
The performance of padding in dimensions > 1 may be substantially slower
for modes ``'smooth'`` and ``'antisymmetric'`` as these modes are not
supported efficiently by the underlying :func:`numpy.pad` function.
Note that the behavior of the ``'constant'`` mode here follows the
PyWavelets convention which is different from NumPy (it is equivalent to
``mode='edge'`` in :func:`numpy.pad`).
"""
x = np.asanyarray(x)
# process pad_widths exactly as in numpy.pad
pad_widths = np.array(pad_widths)
pad_widths = np.round(pad_widths).astype(np.intp, copy=False)
if pad_widths.min() < 0:
raise ValueError("pad_widths must be > 0")
pad_widths = np.broadcast_to(pad_widths, (x.ndim, 2)).tolist()
if mode in ['symmetric', 'reflect']:
xp = np.pad(x, pad_widths, mode=mode)
elif mode in ['periodic', 'periodization']:
if mode == 'periodization':
# Promote odd-sized dimensions to even length by duplicating the
# last value.
edge_pad_widths = [(0, x.shape[ax] % 2)
for ax in range(x.ndim)]
x = np.pad(x, edge_pad_widths, mode='edge')
xp = np.pad(x, pad_widths, mode='wrap')
elif mode == 'zero':
xp = np.pad(x, pad_widths, mode='constant', constant_values=0)
elif mode == 'constant':
xp = np.pad(x, pad_widths, mode='edge')
elif mode == 'smooth':
def pad_smooth(vector, pad_width, iaxis, kwargs):
# smooth extension to left
left = vector[pad_width[0]]
slope_left = (left - vector[pad_width[0] + 1])
vector[:pad_width[0]] = \
left + np.arange(pad_width[0], 0, -1) * slope_left
# smooth extension to right
right = vector[-pad_width[1] - 1]
slope_right = (right - vector[-pad_width[1] - 2])
vector[-pad_width[1]:] = \
right + np.arange(1, pad_width[1] + 1) * slope_right
return vector
xp = np.pad(x, pad_widths, pad_smooth)
elif mode == 'antisymmetric':
def pad_antisymmetric(vector, pad_width, iaxis, kwargs):
# smooth extension to left
# implement by flipping portions symmetric padding
npad_l, npad_r = pad_width
vsize_nonpad = vector.size - npad_l - npad_r
# Note: must modify vector in-place
vector[:] = np.pad(vector[pad_width[0]:-pad_width[-1]],
pad_width, mode='symmetric')
vp = vector
r_edge = npad_l + vsize_nonpad - 1
l_edge = npad_l
# width of each reflected segment
seg_width = vsize_nonpad
# flip reflected segments on the right of the original signal
n = 1
while r_edge <= vp.size:
segment_slice = slice(r_edge + 1,
min(r_edge + 1 + seg_width, vp.size))
if n % 2:
vp[segment_slice] *= -1
r_edge += seg_width
n += 1
# flip reflected segments on the left of the original signal
n = 1
while l_edge >= 0:
segment_slice = slice(max(0, l_edge - seg_width), l_edge)
if n % 2:
vp[segment_slice] *= -1
l_edge -= seg_width
n += 1
return vector
xp = np.pad(x, pad_widths, pad_antisymmetric)
elif mode == 'antireflect':
xp = np.pad(x, pad_widths, mode='reflect', reflect_type='odd')
else:
raise ValueError(
("unsupported mode: {}. The supported modes are {}").format(
mode, Modes.modes))
return xp

View file

@ -0,0 +1,240 @@
# Copyright (c) 2006-2012 Filip Wasilewski <http://en.ig.ma/>
# Copyright (c) 2012-2016 The PyWavelets Developers
# <https://github.com/PyWavelets/pywt>
# See COPYING for license details.
"""
Other wavelet related functions.
"""
from __future__ import division, print_function, absolute_import
import warnings
import numpy as np
from numpy.fft import fft
from ._extensions._pywt import DiscreteContinuousWavelet, Wavelet, ContinuousWavelet
__all__ = ["integrate_wavelet", "central_frequency", "scale2frequency", "qmf",
"orthogonal_filter_bank",
"intwave", "centrfrq", "scal2frq", "orthfilt"]
_DEPRECATION_MSG = ("`{old}` has been renamed to `{new}` and will "
"be removed in a future version of pywt.")
def _integrate(arr, step):
integral = np.cumsum(arr)
integral *= step
return integral
def intwave(*args, **kwargs):
msg = _DEPRECATION_MSG.format(old='intwave', new='integrate_wavelet')
warnings.warn(msg, DeprecationWarning)
return integrate_wavelet(*args, **kwargs)
def centrfrq(*args, **kwargs):
msg = _DEPRECATION_MSG.format(old='centrfrq', new='central_frequency')
warnings.warn(msg, DeprecationWarning)
return central_frequency(*args, **kwargs)
def scal2frq(*args, **kwargs):
msg = _DEPRECATION_MSG.format(old='scal2frq', new='scale2frequency')
warnings.warn(msg, DeprecationWarning)
return scale2frequency(*args, **kwargs)
def orthfilt(*args, **kwargs):
msg = _DEPRECATION_MSG.format(old='orthfilt', new='orthogonal_filter_bank')
warnings.warn(msg, DeprecationWarning)
return orthogonal_filter_bank(*args, **kwargs)
def integrate_wavelet(wavelet, precision=8):
"""
Integrate `psi` wavelet function from -Inf to x using the rectangle
integration method.
Parameters
----------
wavelet : Wavelet instance or str
Wavelet to integrate. If a string, should be the name of a wavelet.
precision : int, optional
Precision that will be used for wavelet function
approximation computed with the wavefun(level=precision)
Wavelet's method (default: 8).
Returns
-------
[int_psi, x] :
for orthogonal wavelets
[int_psi_d, int_psi_r, x] :
for other wavelets
Examples
--------
>>> from pywt import Wavelet, integrate_wavelet
>>> wavelet1 = Wavelet('db2')
>>> [int_psi, x] = integrate_wavelet(wavelet1, precision=5)
>>> wavelet2 = Wavelet('bior1.3')
>>> [int_psi_d, int_psi_r, x] = integrate_wavelet(wavelet2, precision=5)
"""
# FIXME: this function should really use scipy.integrate.quad
if type(wavelet) in (tuple, list):
msg = ("Integration of a general signal is deprecated "
"and will be removed in a future version of pywt.")
warnings.warn(msg, DeprecationWarning)
elif not isinstance(wavelet, (Wavelet, ContinuousWavelet)):
wavelet = DiscreteContinuousWavelet(wavelet)
if type(wavelet) in (tuple, list):
psi, x = np.asarray(wavelet[0]), np.asarray(wavelet[1])
step = x[1] - x[0]
return _integrate(psi, step), x
functions_approximations = wavelet.wavefun(precision)
if len(functions_approximations) == 2: # continuous wavelet
psi, x = functions_approximations
step = x[1] - x[0]
return _integrate(psi, step), x
elif len(functions_approximations) == 3: # orthogonal wavelet
phi, psi, x = functions_approximations
step = x[1] - x[0]
return _integrate(psi, step), x
else: # biorthogonal wavelet
phi_d, psi_d, phi_r, psi_r, x = functions_approximations
step = x[1] - x[0]
return _integrate(psi_d, step), _integrate(psi_r, step), x
def central_frequency(wavelet, precision=8):
"""
Computes the central frequency of the `psi` wavelet function.
Parameters
----------
wavelet : Wavelet instance, str or tuple
Wavelet to integrate. If a string, should be the name of a wavelet.
precision : int, optional
Precision that will be used for wavelet function
approximation computed with the wavefun(level=precision)
Wavelet's method (default: 8).
Returns
-------
scalar
"""
if not isinstance(wavelet, (Wavelet, ContinuousWavelet)):
wavelet = DiscreteContinuousWavelet(wavelet)
functions_approximations = wavelet.wavefun(precision)
if len(functions_approximations) == 2:
psi, x = functions_approximations
else:
# (psi, x) for (phi, psi, x)
# (psi_d, x) for (phi_d, psi_d, phi_r, psi_r, x)
psi, x = functions_approximations[1], functions_approximations[-1]
domain = float(x[-1] - x[0])
assert domain > 0
index = np.argmax(abs(fft(psi)[1:])) + 2
if index > len(psi) / 2:
index = len(psi) - index + 2
return 1.0 / (domain / (index - 1))
def scale2frequency(wavelet, scale, precision=8):
"""
Parameters
----------
wavelet : Wavelet instance or str
Wavelet to integrate. If a string, should be the name of a wavelet.
scale : scalar
precision : int, optional
Precision that will be used for wavelet function approximation computed
with ``wavelet.wavefun(level=precision)``. Default is 8.
Returns
-------
freq : scalar
"""
return central_frequency(wavelet, precision=precision) / scale
def qmf(filt):
"""
Returns the Quadrature Mirror Filter(QMF).
The magnitude response of QMF is mirror image about `pi/2` of that of the
input filter.
Parameters
----------
filt : array_like
Input filter for which QMF needs to be computed.
Returns
-------
qm_filter : ndarray
Quadrature mirror of the input filter.
"""
qm_filter = np.array(filt)[::-1]
qm_filter[1::2] = -qm_filter[1::2]
return qm_filter
def orthogonal_filter_bank(scaling_filter):
"""
Returns the orthogonal filter bank.
The orthogonal filter bank consists of the HPFs and LPFs at
decomposition and reconstruction stage for the input scaling filter.
Parameters
----------
scaling_filter : array_like
Input scaling filter (father wavelet).
Returns
-------
orth_filt_bank : tuple of 4 ndarrays
The orthogonal filter bank of the input scaling filter in the order :
1] Decomposition LPF
2] Decomposition HPF
3] Reconstruction LPF
4] Reconstruction HPF
"""
if not (len(scaling_filter) % 2 == 0):
raise ValueError("`scaling_filter` length has to be even.")
scaling_filter = np.asarray(scaling_filter, dtype=np.float64)
rec_lo = np.sqrt(2) * scaling_filter / np.sum(scaling_filter)
dec_lo = rec_lo[::-1]
rec_hi = qmf(rec_lo)
dec_hi = rec_hi[::-1]
orth_filt_bank = (dec_lo, dec_hi, rec_lo, rec_hi)
return orth_filt_bank

View file

@ -0,0 +1,311 @@
# Copyright (c) 2006-2012 Filip Wasilewski <http://en.ig.ma/>
# Copyright (c) 2012-2016 The PyWavelets Developers
# <https://github.com/PyWavelets/pywt>
# See COPYING for license details.
"""
2D and nD Discrete Wavelet Transforms and Inverse Discrete Wavelet Transforms.
"""
from __future__ import division, print_function, absolute_import
from itertools import product
import numpy as np
from ._c99_config import _have_c99_complex
from ._extensions._dwt import dwt_axis, idwt_axis
from ._utils import _wavelets_per_axis, _modes_per_axis
__all__ = ['dwt2', 'idwt2', 'dwtn', 'idwtn']
def dwt2(data, wavelet, mode='symmetric', axes=(-2, -1)):
"""
2D Discrete Wavelet Transform.
Parameters
----------
data : array_like
2D array with input data
wavelet : Wavelet object or name string, or 2-tuple of wavelets
Wavelet to use. This can also be a tuple containing a wavelet to
apply along each axis in ``axes``.
mode : str or 2-tuple of strings, optional
Signal extension mode, see :ref:`Modes <ref-modes>`. This can
also be a tuple of modes specifying the mode to use on each axis in
``axes``.
axes : 2-tuple of ints, optional
Axes over which to compute the DWT. Repeated elements mean the DWT will
be performed multiple times along these axes.
Returns
-------
(cA, (cH, cV, cD)) : tuple
Approximation, horizontal detail, vertical detail and diagonal
detail coefficients respectively. Horizontal refers to array axis 0
(or ``axes[0]`` for user-specified ``axes``).
Examples
--------
>>> import numpy as np
>>> import pywt
>>> data = np.ones((4,4), dtype=np.float64)
>>> coeffs = pywt.dwt2(data, 'haar')
>>> cA, (cH, cV, cD) = coeffs
>>> cA
array([[ 2., 2.],
[ 2., 2.]])
>>> cV
array([[ 0., 0.],
[ 0., 0.]])
"""
axes = tuple(axes)
data = np.asarray(data)
if len(axes) != 2:
raise ValueError("Expected 2 axes")
if data.ndim < len(np.unique(axes)):
raise ValueError("Input array has fewer dimensions than the specified "
"axes")
coefs = dwtn(data, wavelet, mode, axes)
return coefs['aa'], (coefs['da'], coefs['ad'], coefs['dd'])
def idwt2(coeffs, wavelet, mode='symmetric', axes=(-2, -1)):
"""
2-D Inverse Discrete Wavelet Transform.
Reconstructs data from coefficient arrays.
Parameters
----------
coeffs : tuple
(cA, (cH, cV, cD)) A tuple with approximation coefficients and three
details coefficients 2D arrays like from ``dwt2``. If any of these
components are set to ``None``, it will be treated as zeros.
wavelet : Wavelet object or name string, or 2-tuple of wavelets
Wavelet to use. This can also be a tuple containing a wavelet to
apply along each axis in ``axes``.
mode : str or 2-tuple of strings, optional
Signal extension mode, see :ref:`Modes <ref-modes>`. This can
also be a tuple of modes specifying the mode to use on each axis in
``axes``.
axes : 2-tuple of ints, optional
Axes over which to compute the IDWT. Repeated elements mean the IDWT
will be performed multiple times along these axes.
Examples
--------
>>> import numpy as np
>>> import pywt
>>> data = np.array([[1,2], [3,4]], dtype=np.float64)
>>> coeffs = pywt.dwt2(data, 'haar')
>>> pywt.idwt2(coeffs, 'haar')
array([[ 1., 2.],
[ 3., 4.]])
"""
# L -low-pass data, H - high-pass data
LL, (HL, LH, HH) = coeffs
axes = tuple(axes)
if len(axes) != 2:
raise ValueError("Expected 2 axes")
coeffs = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH}
return idwtn(coeffs, wavelet, mode, axes)
def dwtn(data, wavelet, mode='symmetric', axes=None):
"""
Single-level n-dimensional Discrete Wavelet Transform.
Parameters
----------
data : array_like
n-dimensional array with input data.
wavelet : Wavelet object or name string, or tuple of wavelets
Wavelet to use. This can also be a tuple containing a wavelet to
apply along each axis in ``axes``.
mode : str or tuple of string, optional
Signal extension mode used in the decomposition,
see :ref:`Modes <ref-modes>`. This can also be a tuple of modes
specifying the mode to use on each axis in ``axes``.
axes : sequence of ints, optional
Axes over which to compute the DWT. Repeated elements mean the DWT will
be performed multiple times along these axes. A value of ``None`` (the
default) selects all axes.
Axes may be repeated, but information about the original size may be
lost if it is not divisible by ``2 ** nrepeats``. The reconstruction
will be larger, with additional values derived according to the
``mode`` parameter. ``pywt.wavedecn`` should be used for multilevel
decomposition.
Returns
-------
coeffs : dict
Results are arranged in a dictionary, where key specifies
the transform type on each dimension and value is a n-dimensional
coefficients array.
For example, for a 2D case the result will look something like this::
{'aa': <coeffs> # A(LL) - approx. on 1st dim, approx. on 2nd dim
'ad': <coeffs> # V(LH) - approx. on 1st dim, det. on 2nd dim
'da': <coeffs> # H(HL) - det. on 1st dim, approx. on 2nd dim
'dd': <coeffs> # D(HH) - det. on 1st dim, det. on 2nd dim
}
For user-specified ``axes``, the order of the characters in the
dictionary keys map to the specified ``axes``.
"""
data = np.asarray(data)
if not _have_c99_complex and np.iscomplexobj(data):
real = dwtn(data.real, wavelet, mode, axes)
imag = dwtn(data.imag, wavelet, mode, axes)
return dict((k, real[k] + 1j * imag[k]) for k in real.keys())
if data.dtype == np.dtype('object'):
raise TypeError("Input must be a numeric array-like")
if data.ndim < 1:
raise ValueError("Input data must be at least 1D")
if axes is None:
axes = range(data.ndim)
axes = [a + data.ndim if a < 0 else a for a in axes]
modes = _modes_per_axis(mode, axes)
wavelets = _wavelets_per_axis(wavelet, axes)
coeffs = [('', data)]
for axis, wav, mode in zip(axes, wavelets, modes):
new_coeffs = []
for subband, x in coeffs:
cA, cD = dwt_axis(x, wav, mode, axis)
new_coeffs.extend([(subband + 'a', cA),
(subband + 'd', cD)])
coeffs = new_coeffs
return dict(coeffs)
def _fix_coeffs(coeffs):
missing_keys = [k for k, v in coeffs.items() if v is None]
if missing_keys:
raise ValueError(
"The following detail coefficients were set to None:\n"
"{0}\n"
"For multilevel transforms, rather than setting\n"
"\tcoeffs[key] = None\n"
"use\n"
"\tcoeffs[key] = np.zeros_like(coeffs[key])\n".format(
missing_keys))
invalid_keys = [k for k, v in coeffs.items() if
not set(k) <= set('ad')]
if invalid_keys:
raise ValueError(
"The following invalid keys were found in the detail "
"coefficient dictionary: {}.".format(invalid_keys))
key_lengths = [len(k) for k in coeffs.keys()]
if len(np.unique(key_lengths)) > 1:
raise ValueError(
"All detail coefficient names must have equal length.")
return dict((k, np.asarray(v)) for k, v in coeffs.items())
def idwtn(coeffs, wavelet, mode='symmetric', axes=None):
"""
Single-level n-dimensional Inverse Discrete Wavelet Transform.
Parameters
----------
coeffs: dict
Dictionary as in output of ``dwtn``. Missing or ``None`` items
will be treated as zeros.
wavelet : Wavelet object or name string, or tuple of wavelets
Wavelet to use. This can also be a tuple containing a wavelet to
apply along each axis in ``axes``.
mode : str or list of string, optional
Signal extension mode used in the decomposition,
see :ref:`Modes <ref-modes>`. This can also be a tuple of modes
specifying the mode to use on each axis in ``axes``.
axes : sequence of ints, optional
Axes over which to compute the IDWT. Repeated elements mean the IDWT
will be performed multiple times along these axes. A value of ``None``
(the default) selects all axes.
For the most accurate reconstruction, the axes should be provided in
the same order as they were provided to ``dwtn``.
Returns
-------
data: ndarray
Original signal reconstructed from input data.
"""
# drop the keys corresponding to value = None
coeffs = dict((k, v) for k, v in coeffs.items() if v is not None)
# Raise error for invalid key combinations
coeffs = _fix_coeffs(coeffs)
if (not _have_c99_complex and
any(np.iscomplexobj(v) for v in coeffs.values())):
real_coeffs = dict((k, v.real) for k, v in coeffs.items())
imag_coeffs = dict((k, v.imag) for k, v in coeffs.items())
return (idwtn(real_coeffs, wavelet, mode, axes) +
1j * idwtn(imag_coeffs, wavelet, mode, axes))
# key length matches the number of axes transformed
ndim_transform = max(len(key) for key in coeffs.keys())
try:
coeff_shapes = (v.shape for k, v in coeffs.items()
if v is not None and len(k) == ndim_transform)
coeff_shape = next(coeff_shapes)
except StopIteration:
raise ValueError("`coeffs` must contain at least one non-null wavelet "
"band")
if any(s != coeff_shape for s in coeff_shapes):
raise ValueError("`coeffs` must all be of equal size (or None)")
if axes is None:
axes = range(ndim_transform)
ndim = ndim_transform
else:
ndim = len(coeff_shape)
axes = [a + ndim if a < 0 else a for a in axes]
modes = _modes_per_axis(mode, axes)
wavelets = _wavelets_per_axis(wavelet, axes)
for key_length, (axis, wav, mode) in reversed(
list(enumerate(zip(axes, wavelets, modes)))):
if axis < 0 or axis >= ndim:
raise ValueError("Axis greater than data dimensions")
new_coeffs = {}
new_keys = [''.join(coef) for coef in product('ad', repeat=key_length)]
for key in new_keys:
L = coeffs.get(key + 'a', None)
H = coeffs.get(key + 'd', None)
if L is not None and H is not None:
if L.dtype != H.dtype:
# upcast to a common dtype (float64 or complex128)
if L.dtype.kind == 'c' or H.dtype.kind == 'c':
dtype = np.complex128
else:
dtype = np.float64
L = np.asarray(L, dtype=dtype)
H = np.asarray(H, dtype=dtype)
new_coeffs[key] = idwt_axis(L, H, wav, mode, axis)
coeffs = new_coeffs
return coeffs['']

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,68 @@
"""common test-related code."""
import os
import sys
import multiprocessing
import numpy as np
import pytest
__all__ = ['uses_matlab', # skip if pymatbridge and Matlab unavailable
'uses_futures', # skip if futures unavailable
'uses_pymatbridge', # skip if no PYWT_XSLOW environment variable
'uses_precomputed', # skip if PYWT_XSLOW environment variable found
'matlab_result_dict_cwt', # dict with precomputed Matlab dwt data
'matlab_result_dict_dwt', # dict with precomputed Matlab cwt data
'futures', # the futures module or None
'max_workers', # the number of workers available to futures
'size_set', # the set of Matlab tests to run
]
try:
if sys.version_info[0] == 2:
import futures
else:
from concurrent import futures
max_workers = multiprocessing.cpu_count()
futures_available = True
except ImportError:
futures_available = False
futures = None
# check if pymatbridge + MATLAB tests should be run
matlab_result_dict_dwt = None
matlab_result_dict_cwt = None
matlab_missing = True
use_precomputed = True
size_set = 'reduced'
if 'PYWT_XSLOW' in os.environ:
try:
from pymatbridge import Matlab
mlab = Matlab()
matlab_missing = False
use_precomputed = False
size_set = 'full'
except ImportError:
print("To run Matlab compatibility tests you need to have MathWorks "
"MATLAB, MathWorks Wavelet Toolbox and the pymatbridge Python "
"package installed.")
if use_precomputed:
# load dictionaries of precomputed results
data_dir = os.path.join(os.path.dirname(__file__), 'tests', 'data')
matlab_data_file_cwt = os.path.join(
data_dir, 'cwt_matlabR2015b_result.npz')
matlab_result_dict_cwt = np.load(matlab_data_file_cwt)
matlab_data_file_dwt = os.path.join(
data_dir, 'dwt_matlabR2012a_result.npz')
matlab_result_dict_dwt = np.load(matlab_data_file_dwt)
uses_futures = pytest.mark.skipif(
not futures_available, reason='futures not available')
uses_matlab = pytest.mark.skipif(
matlab_missing, reason='pymatbridge and/or Matlab not available')
uses_pymatbridge = pytest.mark.skipif(
use_precomputed,
reason='PYWT_XSLOW set: skipping tests against precomputed Matlab results')
uses_precomputed = pytest.mark.skipif(
not use_precomputed,
reason='PYWT_XSLOW not set: test against precomputed matlab tests')

View file

@ -0,0 +1,164 @@
"""
Pytest test running.
This module implements the ``test()`` function for NumPy modules. The usual
boiler plate for doing that is to put the following in the module
``__init__.py`` file::
from pywt._pytesttester import PytestTester
test = PytestTester(__name__).test
del PytestTester
Warnings filtering and other runtime settings should be dealt with in the
``pytest.ini`` file in the pywt repo root. The behavior of the test depends on
whether or not that file is found as follows:
* ``pytest.ini`` is present (develop mode)
All warnings except those explicily filtered out are raised as error.
* ``pytest.ini`` is absent (release mode)
DeprecationWarnings and PendingDeprecationWarnings are ignored, other
warnings are passed through.
In practice, tests run from the PyWavelets repo are run in develop mode. That
includes the standard ``python runtests.py`` invocation.
"""
from __future__ import division, absolute_import, print_function
import sys
import os
__all__ = ['PytestTester']
def _show_pywt_info():
import pywt
from pywt._c99_config import _have_c99_complex
print("PyWavelets version %s" % pywt.__version__)
if _have_c99_complex:
print("Compiled with C99 complex support.")
else:
print("Compiled without C99 complex support.")
class PytestTester(object):
"""
Pytest test runner.
This class is made available in ``pywt.testing``, and a test function
is typically added to a package's __init__.py like so::
from pywt.testing import PytestTester
test = PytestTester(__name__).test
del PytestTester
Calling this test function finds and runs all tests associated with the
module and all its sub-modules.
Attributes
----------
module_name : str
Full path to the package to test.
Parameters
----------
module_name : module name
The name of the module to test.
"""
def __init__(self, module_name):
self.module_name = module_name
def __call__(self, label='fast', verbose=1, extra_argv=None,
doctests=False, coverage=False, durations=-1, tests=None):
"""
Run tests for module using pytest.
Parameters
----------
label : {'fast', 'full'}, optional
Identifies the tests to run. When set to 'fast', tests decorated
with `pytest.mark.slow` are skipped, when 'full', the slow marker
is ignored.
verbose : int, optional
Verbosity value for test outputs, in the range 1-3. Default is 1.
extra_argv : list, optional
List with any extra arguments to pass to pytests.
doctests : bool, optional
.. note:: Not supported
coverage : bool, optional
If True, report coverage of NumPy code. Default is False.
Requires installation of (pip) pytest-cov.
durations : int, optional
If < 0, do nothing, If 0, report time of all tests, if > 0,
report the time of the slowest `timer` tests. Default is -1.
tests : test or list of tests
Tests to be executed with pytest '--pyargs'
Returns
-------
result : bool
Return True on success, false otherwise.
Examples
--------
>>> result = np.lib.test() #doctest: +SKIP
...
1023 passed, 2 skipped, 6 deselected, 1 xfailed in 10.39 seconds
>>> result
True
"""
import pytest
module = sys.modules[self.module_name]
module_path = os.path.abspath(module.__path__[0])
# setup the pytest arguments
pytest_args = ["-l"]
# offset verbosity. The "-q" cancels a "-v".
pytest_args += ["-q"]
# Filter out annoying import messages. Want these in both develop and
# release mode.
pytest_args += [
"-W ignore:Not importing directory",
"-W ignore:numpy.dtype size changed",
"-W ignore:numpy.ufunc size changed", ]
if doctests:
raise ValueError("Doctests not supported")
if extra_argv:
pytest_args += list(extra_argv)
if verbose > 1:
pytest_args += ["-" + "v"*(verbose - 1)]
if coverage:
pytest_args += ["--cov=" + module_path]
if label == "fast":
pytest_args += ["-m", "not slow"]
elif label != "full":
pytest_args += ["-m", label]
if durations >= 0:
pytest_args += ["--durations=%s" % durations]
if tests is None:
tests = [self.module_name]
pytest_args += ["--pyargs"] + list(tests)
# run tests.
_show_pywt_info()
try:
code = pytest.main(pytest_args)
except SystemExit as exc:
code = exc.code
return code == 0

View file

@ -0,0 +1,774 @@
import warnings
from itertools import product
import numpy as np
from ._c99_config import _have_c99_complex
from ._extensions._dwt import idwt_single
from ._extensions._swt import swt_max_level, swt as _swt, swt_axis as _swt_axis
from ._extensions._pywt import Wavelet, Modes, _check_dtype
from ._multidim import idwt2, idwtn
from ._utils import _as_wavelet, _wavelets_per_axis
__all__ = ["swt", "swt_max_level", 'iswt', 'swt2', 'iswt2', 'swtn', 'iswtn']
def _rescale_wavelet_filterbank(wavelet, sf):
wav = Wavelet(wavelet.name + 'r',
[np.asarray(f) * sf for f in wavelet.filter_bank])
# copy attributes from the original wavelet
wav.orthogonal = wavelet.orthogonal
wav.biorthogonal = wavelet.biorthogonal
return wav
def swt(data, wavelet, level=None, start_level=0, axis=-1,
trim_approx=False, norm=False):
"""
Multilevel 1D stationary wavelet transform.
Parameters
----------
data :
Input signal
wavelet :
Wavelet to use (Wavelet object or name)
level : int, optional
The number of decomposition steps to perform.
start_level : int, optional
The level at which the decomposition will begin (it allows one to
skip a given number of transform steps and compute
coefficients starting from start_level) (default: 0)
axis: int, optional
Axis over which to compute the SWT. If not given, the
last axis is used.
trim_approx : bool, optional
If True, approximation coefficients at the final level are retained.
norm : bool, optional
If True, transform is normalized so that the energy of the coefficients
will be equal to the energy of ``data``. In other words,
``np.linalg.norm(data.ravel())`` will equal the norm of the
concatenated transform coefficients when ``trim_approx`` is True.
Returns
-------
coeffs : list
List of approximation and details coefficients pairs in order
similar to wavedec function::
[(cAn, cDn), ..., (cA2, cD2), (cA1, cD1)]
where n equals input parameter ``level``.
If ``start_level = m`` is given, then the beginning m steps are
skipped::
[(cAm+n, cDm+n), ..., (cAm+1, cDm+1), (cAm, cDm)]
If ``trim_approx`` is ``True``, then the output list is exactly as in
``pywt.wavedec``, where the first coefficient in the list is the
approximation coefficient at the final level and the rest are the
detail coefficients::
[cAn, cDn, ..., cD2, cD1]
Notes
-----
The implementation here follows the "algorithm a-trous" and requires that
the signal length along the transformed axis be a multiple of ``2**level``.
If this is not the case, the user should pad up to an appropriate size
using a function such as ``numpy.pad``.
A primary benefit of this transform in comparison to its decimated
counterpart (``pywt.wavedecn``), is that it is shift-invariant. This comes
at cost of redundancy in the transform (the size of the output coefficients
is larger than the input).
When the following three conditions are true:
1. The wavelet is orthogonal
2. ``swt`` is called with ``norm=True``
3. ``swt`` is called with ``trim_approx=True``
the transform has the following additional properties that may be
desirable in applications:
1. energy is conserved
2. variance is partitioned across scales
When used with ``norm=True``, this transform is closely related to the
multiple-overlap DWT (MODWT) as popularized for time-series analysis,
although the underlying implementation is slightly different from the one
published in [1]_. Specifically, the implementation used here requires a
signal that is a multiple of ``2**level`` in length.
References
----------
.. [1] DB Percival and AT Walden. Wavelet Methods for Time Series Analysis.
Cambridge University Press, 2000.
"""
if not _have_c99_complex and np.iscomplexobj(data):
data = np.asarray(data)
coeffs_real = swt(data.real, wavelet, level, start_level, trim_approx)
coeffs_imag = swt(data.imag, wavelet, level, start_level, trim_approx)
if not trim_approx:
coeffs_cplx = []
for (cA_r, cD_r), (cA_i, cD_i) in zip(coeffs_real, coeffs_imag):
coeffs_cplx.append((cA_r + 1j*cA_i, cD_r + 1j*cD_i))
else:
coeffs_cplx = [cr + 1j*ci
for (cr, ci) in zip(coeffs_real, coeffs_imag)]
return coeffs_cplx
# accept array_like input; make a copy to ensure a contiguous array
dt = _check_dtype(data)
data = np.array(data, dtype=dt)
wavelet = _as_wavelet(wavelet)
if norm:
if not wavelet.orthogonal:
warnings.warn(
"norm=True, but the wavelet is not orthogonal: \n"
"\tThe conditions for energy preservation are not satisfied.")
wavelet = _rescale_wavelet_filterbank(wavelet, 1/np.sqrt(2))
if axis < 0:
axis = axis + data.ndim
if not 0 <= axis < data.ndim:
raise ValueError("Axis greater than data dimensions")
if level is None:
level = swt_max_level(data.shape[axis])
if data.ndim == 1:
ret = _swt(data, wavelet, level, start_level, trim_approx)
else:
ret = _swt_axis(data, wavelet, level, start_level, axis, trim_approx)
return ret
def iswt(coeffs, wavelet, norm=False):
"""
Multilevel 1D inverse discrete stationary wavelet transform.
Parameters
----------
coeffs : array_like
Coefficients list of tuples::
[(cAn, cDn), ..., (cA2, cD2), (cA1, cD1)]
where cA is approximation, cD is details. Index 1 corresponds to
``start_level`` from ``pywt.swt``.
wavelet : Wavelet object or name string
Wavelet to use
norm : bool, optional
Controls the normalization used by the inverse transform. This must
be set equal to the value that was used by ``pywt.swt`` to preserve the
energy of a round-trip transform.
Returns
-------
1D array of reconstructed data.
Examples
--------
>>> import pywt
>>> coeffs = pywt.swt([1,2,3,4,5,6,7,8], 'db2', level=2)
>>> pywt.iswt(coeffs, 'db2')
array([ 1., 2., 3., 4., 5., 6., 7., 8.])
"""
# copy to avoid modification of input data
# If swt was called with trim_approx=False, first element is a tuple
trim_approx = not isinstance(coeffs[0], (tuple, list))
if trim_approx:
cA = coeffs[0]
coeffs = coeffs[1:]
else:
cA = coeffs[0][0]
dt = _check_dtype(cA)
output = np.array(cA, dtype=dt, copy=True)
if not _have_c99_complex and np.iscomplexobj(output):
# compute real and imaginary separately then combine
if trim_approx:
coeffs_real = [c.real for c in coeffs]
coeffs_imag = [c.imag for c in coeffs]
else:
coeffs_real = [(cA.real, cD.real) for (cA, cD) in coeffs]
coeffs_imag = [(cA.imag, cD.imag) for (cA, cD) in coeffs]
return iswt(coeffs_real, wavelet) + 1j*iswt(coeffs_imag, wavelet)
# num_levels, equivalent to the decomposition level, n
num_levels = len(coeffs)
wavelet = _as_wavelet(wavelet)
if norm:
wavelet = _rescale_wavelet_filterbank(wavelet, np.sqrt(2))
mode = Modes.from_object('periodization')
for j in range(num_levels, 0, -1):
step_size = int(pow(2, j-1))
last_index = step_size
if trim_approx:
cD = coeffs[-j]
else:
_, cD = coeffs[-j]
cD = np.asarray(cD, dtype=_check_dtype(cD))
if cD.dtype != output.dtype:
# upcast to a common dtype (float64 or complex128)
if output.dtype.kind == 'c' or cD.dtype.kind == 'c':
dtype = np.complex128
else:
dtype = np.float64
output = np.asarray(output, dtype=dtype)
cD = np.asarray(cD, dtype=dtype)
for first in range(last_index): # 0 to last_index - 1
# Getting the indices that we will transform
indices = np.arange(first, len(cD), step_size)
# select the even indices
even_indices = indices[0::2]
# select the odd indices
odd_indices = indices[1::2]
# perform the inverse dwt on the selected indices,
# making sure to use periodic boundary conditions
# Note: indexing with an array of ints returns a contiguous
# copy as required by idwt_single.
x1 = idwt_single(output[even_indices],
cD[even_indices],
wavelet, mode)
x2 = idwt_single(output[odd_indices],
cD[odd_indices],
wavelet, mode)
# perform a circular shift right
x2 = np.roll(x2, 1)
# average and insert into the correct indices
output[indices] = (x1 + x2)/2.
return output
def swt2(data, wavelet, level, start_level=0, axes=(-2, -1),
trim_approx=False, norm=False):
"""
Multilevel 2D stationary wavelet transform.
Parameters
----------
data : array_like
2D array with input data
wavelet : Wavelet object or name string, or 2-tuple of wavelets
Wavelet to use. This can also be a tuple of wavelets to apply per
axis in ``axes``.
level : int
The number of decomposition steps to perform.
start_level : int, optional
The level at which the decomposition will start (default: 0)
axes : 2-tuple of ints, optional
Axes over which to compute the SWT. Repeated elements are not allowed.
trim_approx : bool, optional
If True, approximation coefficients at the final level are retained.
norm : bool, optional
If True, transform is normalized so that the energy of the coefficients
will be equal to the energy of ``data``. In other words,
``np.linalg.norm(data.ravel())`` will equal the norm of the
concatenated transform coefficients when ``trim_approx`` is True.
Returns
-------
coeffs : list
Approximation and details coefficients (for ``start_level = m``).
If ``trim_approx`` is ``True``, approximation coefficients are
retained for all levels::
[
(cA_m+level,
(cH_m+level, cV_m+level, cD_m+level)
),
...,
(cA_m+1,
(cH_m+1, cV_m+1, cD_m+1)
),
(cA_m,
(cH_m, cV_m, cD_m)
)
]
where cA is approximation, cH is horizontal details, cV is
vertical details, cD is diagonal details and m is ``start_level``.
If ``trim_approx`` is ``False``, approximation coefficients are only
retained at the final level of decomposition. This matches the format
used by ``pywt.wavedec2``::
[
cA_m+level,
(cH_m+level, cV_m+level, cD_m+level),
...,
(cH_m+1, cV_m+1, cD_m+1),
(cH_m, cV_m, cD_m),
]
Notes
-----
The implementation here follows the "algorithm a-trous" and requires that
the signal length along the transformed axes be a multiple of ``2**level``.
If this is not the case, the user should pad up to an appropriate size
using a function such as ``numpy.pad``.
A primary benefit of this transform in comparison to its decimated
counterpart (``pywt.wavedecn``), is that it is shift-invariant. This comes
at cost of redundancy in the transform (the size of the output coefficients
is larger than the input).
When the following three conditions are true:
1. The wavelet is orthogonal
2. ``swt2`` is called with ``norm=True``
3. ``swt2`` is called with ``trim_approx=True``
the transform has the following additional properties that may be
desirable in applications:
1. energy is conserved
2. variance is partitioned across scales
"""
axes = tuple(axes)
data = np.asarray(data)
if len(axes) != 2:
raise ValueError("Expected 2 axes")
if len(axes) != len(set(axes)):
raise ValueError("The axes passed to swt2 must be unique.")
if data.ndim < len(np.unique(axes)):
raise ValueError("Input array has fewer dimensions than the specified "
"axes")
coefs = swtn(data, wavelet, level, start_level, axes, trim_approx, norm)
ret = []
if trim_approx:
ret.append(coefs[0])
coefs = coefs[1:]
for c in coefs:
if trim_approx:
ret.append((c['da'], c['ad'], c['dd']))
else:
ret.append((c['aa'], (c['da'], c['ad'], c['dd'])))
return ret
def iswt2(coeffs, wavelet, norm=False):
"""
Multilevel 2D inverse discrete stationary wavelet transform.
Parameters
----------
coeffs : list
Approximation and details coefficients::
[
(cA_n,
(cH_n, cV_n, cD_n)
),
...,
(cA_2,
(cH_2, cV_2, cD_2)
),
(cA_1,
(cH_1, cV_1, cD_1)
)
]
where cA is approximation, cH is horizontal details, cV is
vertical details, cD is diagonal details and n is the number of
levels. Index 1 corresponds to ``start_level`` from ``pywt.swt2``.
wavelet : Wavelet object or name string, or 2-tuple of wavelets
Wavelet to use. This can also be a 2-tuple of wavelets to apply per
axis.
norm : bool, optional
Controls the normalization used by the inverse transform. This must
be set equal to the value that was used by ``pywt.swt2`` to preserve
the energy of a round-trip transform.
Returns
-------
2D array of reconstructed data.
Examples
--------
>>> import pywt
>>> coeffs = pywt.swt2([[1,2,3,4],[5,6,7,8],
... [9,10,11,12],[13,14,15,16]],
... 'db1', level=2)
>>> pywt.iswt2(coeffs, 'db1')
array([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[ 13., 14., 15., 16.]])
"""
# If swt was called with trim_approx=False, first element is a tuple
trim_approx = not isinstance(coeffs[0], (tuple, list))
if trim_approx:
cA = coeffs[0]
coeffs = coeffs[1:]
else:
cA = coeffs[0][0]
# copy to avoid modification of input data
dt = _check_dtype(cA)
output = np.array(cA, dtype=dt, copy=True)
if output.ndim != 2:
raise ValueError(
"iswt2 only supports 2D arrays. see iswtn for a general "
"n-dimensionsal ISWT")
# num_levels, equivalent to the decomposition level, n
num_levels = len(coeffs)
wavelets = _wavelets_per_axis(wavelet, axes=(0, 1))
if norm:
wavelets = [_rescale_wavelet_filterbank(wav, np.sqrt(2))
for wav in wavelets]
for j in range(num_levels):
step_size = int(pow(2, num_levels-j-1))
last_index = step_size
if trim_approx:
(cH, cV, cD) = coeffs[j]
else:
_, (cH, cV, cD) = coeffs[j]
# We are going to assume cH, cV, and cD are of equal size
if (cH.shape != cV.shape) or (cH.shape != cD.shape):
raise RuntimeError(
"Mismatch in shape of intermediate coefficient arrays")
# make sure output shares the common dtype
# (conversion of dtype for individual coeffs is handled within idwt2 )
common_dtype = np.result_type(*(
[dt, ] + [_check_dtype(c) for c in [cH, cV, cD]]))
if output.dtype != common_dtype:
output = output.astype(common_dtype)
for first_h in range(last_index): # 0 to last_index - 1
for first_w in range(last_index): # 0 to last_index - 1
# Getting the indices that we will transform
indices_h = slice(first_h, cH.shape[0], step_size)
indices_w = slice(first_w, cH.shape[1], step_size)
even_idx_h = slice(first_h, cH.shape[0], 2*step_size)
even_idx_w = slice(first_w, cH.shape[1], 2*step_size)
odd_idx_h = slice(first_h + step_size, cH.shape[0], 2*step_size)
odd_idx_w = slice(first_w + step_size, cH.shape[1], 2*step_size)
# perform the inverse dwt on the selected indices,
# making sure to use periodic boundary conditions
x1 = idwt2((output[even_idx_h, even_idx_w],
(cH[even_idx_h, even_idx_w],
cV[even_idx_h, even_idx_w],
cD[even_idx_h, even_idx_w])),
wavelets, 'periodization')
x2 = idwt2((output[even_idx_h, odd_idx_w],
(cH[even_idx_h, odd_idx_w],
cV[even_idx_h, odd_idx_w],
cD[even_idx_h, odd_idx_w])),
wavelets, 'periodization')
x3 = idwt2((output[odd_idx_h, even_idx_w],
(cH[odd_idx_h, even_idx_w],
cV[odd_idx_h, even_idx_w],
cD[odd_idx_h, even_idx_w])),
wavelets, 'periodization')
x4 = idwt2((output[odd_idx_h, odd_idx_w],
(cH[odd_idx_h, odd_idx_w],
cV[odd_idx_h, odd_idx_w],
cD[odd_idx_h, odd_idx_w])),
wavelets, 'periodization')
# perform a circular shifts
x2 = np.roll(x2, 1, axis=1)
x3 = np.roll(x3, 1, axis=0)
x4 = np.roll(x4, 1, axis=0)
x4 = np.roll(x4, 1, axis=1)
output[indices_h, indices_w] = (x1 + x2 + x3 + x4) / 4
return output
def swtn(data, wavelet, level, start_level=0, axes=None, trim_approx=False,
norm=False):
"""
n-dimensional stationary wavelet transform.
Parameters
----------
data : array_like
n-dimensional array with input data.
wavelet : Wavelet object or name string, or tuple of wavelets
Wavelet to use. This can also be a tuple of wavelets to apply per
axis in ``axes``.
level : int
The number of decomposition steps to perform.
start_level : int, optional
The level at which the decomposition will start (default: 0)
axes : sequence of ints, optional
Axes over which to compute the SWT. A value of ``None`` (the
default) selects all axes. Axes may not be repeated.
trim_approx : bool, optional
If True, approximation coefficients at the final level are retained.
norm : bool, optional
If True, transform is normalized so that the energy of the coefficients
will be equal to the energy of ``data``. In other words,
``np.linalg.norm(data.ravel())`` will equal the norm of the
concatenated transform coefficients when ``trim_approx`` is True.
Returns
-------
[{coeffs_level_n}, ..., {coeffs_level_1}]: list of dict
Results for each level are arranged in a dictionary, where the key
specifies the transform type on each dimension and value is a
n-dimensional coefficients array.
For example, for a 2D case the result at a given level will look
something like this::
{'aa': <coeffs> # A(LL) - approx. on 1st dim, approx. on 2nd dim
'ad': <coeffs> # V(LH) - approx. on 1st dim, det. on 2nd dim
'da': <coeffs> # H(HL) - det. on 1st dim, approx. on 2nd dim
'dd': <coeffs> # D(HH) - det. on 1st dim, det. on 2nd dim
}
For user-specified ``axes``, the order of the characters in the
dictionary keys map to the specified ``axes``.
If ``trim_approx`` is ``True``, the first element of the list contains
the array of approximation coefficients from the final level of
decomposition, while the remaining coefficient dictionaries contain
only detail coefficients. This matches the behavior of `pywt.wavedecn`.
Notes
-----
The implementation here follows the "algorithm a-trous" and requires that
the signal length along the transformed axes be a multiple of ``2**level``.
If this is not the case, the user should pad up to an appropriate size
using a function such as ``numpy.pad``.
A primary benefit of this transform in comparison to its decimated
counterpart (``pywt.wavedecn``), is that it is shift-invariant. This comes
at cost of redundancy in the transform (the size of the output coefficients
is larger than the input).
When the following three conditions are true:
1. The wavelet is orthogonal
2. ``swtn`` is called with ``norm=True``
3. ``swtn`` is called with ``trim_approx=True``
the transform has the following additional properties that may be
desirable in applications:
1. energy is conserved
2. variance is partitioned across scales
"""
data = np.asarray(data)
if not _have_c99_complex and np.iscomplexobj(data):
real = swtn(data.real, wavelet, level, start_level, axes, trim_approx)
imag = swtn(data.imag, wavelet, level, start_level, axes, trim_approx)
if trim_approx:
cplx = [real[0] + 1j * imag[0]]
offset = 1
else:
cplx = []
offset = 0
for rdict, idict in zip(real[offset:], imag[offset:]):
cplx.append(
dict((k, rdict[k] + 1j * idict[k]) for k in rdict.keys()))
return cplx
if data.dtype == np.dtype('object'):
raise TypeError("Input must be a numeric array-like")
if data.ndim < 1:
raise ValueError("Input data must be at least 1D")
if axes is None:
axes = range(data.ndim)
axes = [a + data.ndim if a < 0 else a for a in axes]
if len(axes) != len(set(axes)):
raise ValueError("The axes passed to swtn must be unique.")
num_axes = len(axes)
wavelets = _wavelets_per_axis(wavelet, axes)
if norm:
if not np.all([wav.orthogonal for wav in wavelets]):
warnings.warn(
"norm=True, but the wavelets used are not orthogonal: \n"
"\tThe conditions for energy preservation are not satisfied.")
wavelets = [_rescale_wavelet_filterbank(wav, 1/np.sqrt(2))
for wav in wavelets]
ret = []
for i in range(start_level, start_level + level):
coeffs = [('', data)]
for axis, wavelet in zip(axes, wavelets):
new_coeffs = []
for subband, x in coeffs:
cA, cD = _swt_axis(x, wavelet, level=1, start_level=i,
axis=axis)[0]
new_coeffs.extend([(subband + 'a', cA),
(subband + 'd', cD)])
coeffs = new_coeffs
coeffs = dict(coeffs)
ret.append(coeffs)
# data for the next level is the approximation coeffs from this level
data = coeffs['a' * num_axes]
if trim_approx:
coeffs.pop('a' * num_axes)
if trim_approx:
ret.append(data)
ret.reverse()
return ret
def iswtn(coeffs, wavelet, axes=None, norm=False):
"""
Multilevel nD inverse discrete stationary wavelet transform.
Parameters
----------
coeffs : list
[{coeffs_level_n}, ..., {coeffs_level_1}]: list of dict
wavelet : Wavelet object or name string, or tuple of wavelets
Wavelet to use. This can also be a tuple of wavelets to apply per
axis in ``axes``.
axes : sequence of ints, optional
Axes over which to compute the inverse SWT. Axes may not be repeated.
The default is ``None``, which means transform all axes
(``axes = range(data.ndim)``).
norm : bool, optional
Controls the normalization used by the inverse transform. This must
be set equal to the value that was used by ``pywt.swtn`` to preserve
the energy of a round-trip transform.
Returns
-------
nD array of reconstructed data.
Examples
--------
>>> import pywt
>>> coeffs = pywt.swtn([[1,2,3,4],[5,6,7,8],
... [9,10,11,12],[13,14,15,16]],
... 'db1', level=2)
>>> pywt.iswtn(coeffs, 'db1')
array([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[ 13., 14., 15., 16.]])
"""
# key length matches the number of axes transformed
ndim_transform = max(len(key) for key in coeffs[-1].keys())
trim_approx = not isinstance(coeffs[0], dict)
if trim_approx:
cA = coeffs[0]
coeffs = coeffs[1:]
else:
cA = coeffs[0]['a'*ndim_transform]
# copy to avoid modification of input data
dt = _check_dtype(cA)
output = np.array(cA, dtype=dt, copy=True)
ndim = output.ndim
if axes is None:
axes = range(output.ndim)
axes = [a + ndim if a < 0 else a for a in axes]
if len(axes) != len(set(axes)):
raise ValueError("The axes passed to swtn must be unique.")
if ndim_transform != len(axes):
raise ValueError("The number of axes used in iswtn must match the "
"number of dimensions transformed in swtn.")
# num_levels, equivalent to the decomposition level, n
num_levels = len(coeffs)
wavelets = _wavelets_per_axis(wavelet, axes)
if norm:
wavelets = [_rescale_wavelet_filterbank(wav, np.sqrt(2))
for wav in wavelets]
# initialize various slice objects used in the loops below
# these will remain slice(None) only on axes that aren't transformed
indices = [slice(None), ]*ndim
even_indices = [slice(None), ]*ndim
odd_indices = [slice(None), ]*ndim
odd_even_slices = [slice(None), ]*ndim
for j in range(num_levels):
step_size = int(pow(2, num_levels-j-1))
last_index = step_size
if not trim_approx:
a = coeffs[j].pop('a'*ndim_transform) # will restore later
details = coeffs[j]
# make sure dtype matches the coarsest level approximation coefficients
common_dtype = np.result_type(*(
[dt, ] + [v.dtype for v in details.values()]))
if output.dtype != common_dtype:
output = output.astype(common_dtype)
# We assume all coefficient arrays are of equal size
shapes = [v.shape for k, v in details.items()]
if len(set(shapes)) != 1:
raise RuntimeError(
"Mismatch in shape of intermediate coefficient arrays")
# shape of a single coefficient array, excluding non-transformed axes
coeff_trans_shape = tuple([shapes[0][ax] for ax in axes])
# nested loop over all combinations of axis offsets at this level
for firsts in product(*([range(last_index), ]*ndim_transform)):
for first, sh, ax in zip(firsts, coeff_trans_shape, axes):
indices[ax] = slice(first, sh, step_size)
even_indices[ax] = slice(first, sh, 2*step_size)
odd_indices[ax] = slice(first+step_size, sh, 2*step_size)
# nested loop over all combinations of odd/even inidices
approx = output.copy()
output[tuple(indices)] = 0
ntransforms = 0
for odds in product(*([(0, 1), ]*ndim_transform)):
for o, ax in zip(odds, axes):
if o:
odd_even_slices[ax] = odd_indices[ax]
else:
odd_even_slices[ax] = even_indices[ax]
# extract the odd/even indices for all detail coefficients
details_slice = {}
for key, value in details.items():
details_slice[key] = value[tuple(odd_even_slices)]
details_slice['a'*ndim_transform] = approx[
tuple(odd_even_slices)]
# perform the inverse dwt on the selected indices,
# making sure to use periodic boundary conditions
x = idwtn(details_slice, wavelets, 'periodization', axes=axes)
for o, ax in zip(odds, axes):
# circular shift along any odd indexed axis
if o:
x = np.roll(x, 1, axis=ax)
output[tuple(indices)] += x
ntransforms += 1
output[tuple(indices)] /= ntransforms # normalize
if not trim_approx:
coeffs[j]['a'*ndim_transform] = a # restore approx coeffs to dict
return output

View file

@ -0,0 +1,250 @@
# Copyright (c) 2006-2012 Filip Wasilewski <http://en.ig.ma/>
# Copyright (c) 2012-2016 The PyWavelets Developers
# <https://github.com/PyWavelets/pywt>
# See COPYING for license details.
"""
The thresholding helper module implements the most popular signal thresholding
functions.
"""
from __future__ import division, print_function, absolute_import
import numpy as np
__all__ = ['threshold', 'threshold_firm']
def soft(data, value, substitute=0):
data = np.asarray(data)
magnitude = np.absolute(data)
with np.errstate(divide='ignore'):
# divide by zero okay as np.inf values get clipped, so ignore warning.
thresholded = (1 - value/magnitude)
thresholded.clip(min=0, max=None, out=thresholded)
thresholded = data * thresholded
if substitute == 0:
return thresholded
else:
cond = np.less(magnitude, value)
return np.where(cond, substitute, thresholded)
def nn_garrote(data, value, substitute=0):
"""Non-negative Garrote."""
data = np.asarray(data)
magnitude = np.absolute(data)
with np.errstate(divide='ignore'):
# divide by zero okay as np.inf values get clipped, so ignore warning.
thresholded = (1 - value**2/magnitude**2)
thresholded.clip(min=0, max=None, out=thresholded)
thresholded = data * thresholded
if substitute == 0:
return thresholded
else:
cond = np.less(magnitude, value)
return np.where(cond, substitute, thresholded)
def hard(data, value, substitute=0):
data = np.asarray(data)
cond = np.less(np.absolute(data), value)
return np.where(cond, substitute, data)
def greater(data, value, substitute=0):
data = np.asarray(data)
if np.iscomplexobj(data):
raise ValueError("greater thresholding only supports real data")
return np.where(np.less(data, value), substitute, data)
def less(data, value, substitute=0):
data = np.asarray(data)
if np.iscomplexobj(data):
raise ValueError("less thresholding only supports real data")
return np.where(np.greater(data, value), substitute, data)
thresholding_options = {'soft': soft,
'hard': hard,
'greater': greater,
'less': less,
'garrote': nn_garrote,
# misspelled garrote for backwards compatibility
'garotte': nn_garrote,
}
def threshold(data, value, mode='soft', substitute=0):
"""
Thresholds the input data depending on the mode argument.
In ``soft`` thresholding [1]_, data values with absolute value less than
`param` are replaced with `substitute`. Data values with absolute value
greater or equal to the thresholding value are shrunk toward zero
by `value`. In other words, the new value is
``data/np.abs(data) * np.maximum(np.abs(data) - value, 0)``.
In ``hard`` thresholding, the data values where their absolute value is
less than the value param are replaced with `substitute`. Data values with
absolute value greater or equal to the thresholding value stay untouched.
``garrote`` corresponds to the Non-negative garrote threshold [2]_, [3]_.
It is intermediate between ``hard`` and ``soft`` thresholding. It behaves
like soft thresholding for small data values and approaches hard
thresholding for large data values.
In ``greater`` thresholding, the data is replaced with `substitute` where
data is below the thresholding value. Greater data values pass untouched.
In ``less`` thresholding, the data is replaced with `substitute` where data
is above the thresholding value. Lesser data values pass untouched.
Both ``hard`` and ``soft`` thresholding also support complex-valued data.
Parameters
----------
data : array_like
Numeric data.
value : scalar
Thresholding value.
mode : {'soft', 'hard', 'garrote', 'greater', 'less'}
Decides the type of thresholding to be applied on input data. Default
is 'soft'.
substitute : float, optional
Substitute value (default: 0).
Returns
-------
output : array
Thresholded array.
See Also
--------
threshold_firm
References
----------
.. [1] D.L. Donoho and I.M. Johnstone. Ideal Spatial Adaptation via
Wavelet Shrinkage. Biometrika. Vol. 81, No. 3, pp.425-455, 1994.
DOI:10.1093/biomet/81.3.425
.. [2] L. Breiman. Better Subset Regression Using the Nonnegative Garrote.
Technometrics, Vol. 37, pp. 373-384, 1995.
DOI:10.2307/1269730
.. [3] H-Y. Gao. Wavelet Shrinkage Denoising Using the Non-Negative
Garrote. Journal of Computational and Graphical Statistics Vol. 7,
No. 4, pp.469-488. 1998.
DOI:10.1080/10618600.1998.10474789
Examples
--------
>>> import numpy as np
>>> import pywt
>>> data = np.linspace(1, 4, 7)
>>> data
array([ 1. , 1.5, 2. , 2.5, 3. , 3.5, 4. ])
>>> pywt.threshold(data, 2, 'soft')
array([ 0. , 0. , 0. , 0.5, 1. , 1.5, 2. ])
>>> pywt.threshold(data, 2, 'hard')
array([ 0. , 0. , 2. , 2.5, 3. , 3.5, 4. ])
>>> pywt.threshold(data, 2, 'garrote')
array([ 0. , 0. , 0. , 0.9 , 1.66666667,
2.35714286, 3. ])
>>> pywt.threshold(data, 2, 'greater')
array([ 0. , 0. , 2. , 2.5, 3. , 3.5, 4. ])
>>> pywt.threshold(data, 2, 'less')
array([ 1. , 1.5, 2. , 0. , 0. , 0. , 0. ])
"""
try:
return thresholding_options[mode](data, value, substitute)
except KeyError:
# Make sure error is always identical by sorting keys
keys = ("'{0}'".format(key) for key in
sorted(thresholding_options.keys()))
raise ValueError("The mode parameter only takes values from: {0}."
.format(', '.join(keys)))
def threshold_firm(data, value_low, value_high):
"""Firm threshold.
The approach is intermediate between soft and hard thresholding [1]_. It
behaves the same as soft-thresholding for values below `value_low` and
the same as hard-thresholding for values above `thresh_high`. For
intermediate values, the thresholded value is in between that corresponding
to soft or hard thresholding.
Parameters
----------
data : array-like
The data to threshold. This can be either real or complex-valued.
value_low : float
Any values smaller then `value_low` will be set to zero.
value_high : float
Any values larger than `value_high` will not be modified.
Notes
-----
This thresholding technique is also known as semi-soft thresholding [2]_.
For each value, `x`, in `data`. This function computes::
if np.abs(x) <= value_low:
return 0
elif np.abs(x) > value_high:
return x
elif value_low < np.abs(x) and np.abs(x) <= value_high:
return x * value_high * (1 - value_low/x)/(value_high - value_low)
``firm`` is a continuous function (like soft thresholding), but is
unbiased for large values (like hard thresholding).
If ``value_high == value_low`` this function becomes hard-thresholding.
If ``value_high`` is infinity, this function becomes soft-thresholding.
Returns
-------
val_new : array-like
The values after firm thresholding at the specified thresholds.
See Also
--------
threshold
References
----------
.. [1] H.-Y. Gao and A.G. Bruce. Waveshrink with firm shrinkage.
Statistica Sinica, Vol. 7, pp. 855-874, 1997.
.. [2] A. Bruce and H-Y. Gao. WaveShrink: Shrinkage Functions and
Thresholds. Proc. SPIE 2569, Wavelet Applications in Signal and
Image Processing III, 1995.
DOI:10.1117/12.217582
"""
if value_low < 0:
raise ValueError("value_low must be non-negative.")
if value_high < value_low:
raise ValueError(
"value_high must be greater than or equal to value_low.")
data = np.asarray(data)
magnitude = np.absolute(data)
with np.errstate(divide='ignore'):
# divide by zero okay as np.inf values get clipped, so ignore warning.
vdiff = value_high - value_low
thresholded = value_high * (1 - value_low/magnitude) / vdiff
thresholded.clip(min=0, max=None, out=thresholded)
thresholded = data * thresholded
# restore hard-thresholding behavior for values > value_high
large_vals = np.where(magnitude > value_high)
if np.any(large_vals[0]):
thresholded[large_vals] = data[large_vals]
return thresholded

View file

@ -0,0 +1,101 @@
# Copyright (c) 2017 The PyWavelets Developers
# <https://github.com/PyWavelets/pywt>
# See COPYING for license details.
import inspect
import numpy as np
import sys
from collections.abc import Iterable
from ._extensions._pywt import (Wavelet, ContinuousWavelet,
DiscreteContinuousWavelet, Modes)
# define string_types as in six for Python 2/3 compatibility
if sys.version_info[0] == 3:
string_types = str,
else:
string_types = basestring,
def _as_wavelet(wavelet):
"""Convert wavelet name to a Wavelet object."""
if not isinstance(wavelet, (ContinuousWavelet, Wavelet)):
wavelet = DiscreteContinuousWavelet(wavelet)
if isinstance(wavelet, ContinuousWavelet):
raise ValueError(
"A ContinuousWavelet object was provided, but only discrete "
"Wavelet objects are supported by this function. A list of all "
"supported discrete wavelets can be obtained by running:\n"
"print(pywt.wavelist(kind='discrete'))")
return wavelet
def _wavelets_per_axis(wavelet, axes):
"""Initialize Wavelets for each axis to be transformed.
Parameters
----------
wavelet : Wavelet or tuple of Wavelets
If a single Wavelet is provided, it will used for all axes. Otherwise
one Wavelet per axis must be provided.
axes : list
The tuple of axes to be transformed.
Returns
-------
wavelets : list of Wavelet objects
A tuple of Wavelets equal in length to ``axes``.
"""
axes = tuple(axes)
if isinstance(wavelet, string_types + (Wavelet, )):
# same wavelet on all axes
wavelets = [_as_wavelet(wavelet), ] * len(axes)
elif isinstance(wavelet, Iterable):
# (potentially) unique wavelet per axis (e.g. for dual-tree DWT)
if len(wavelet) == 1:
wavelets = [_as_wavelet(wavelet[0]), ] * len(axes)
else:
if len(wavelet) != len(axes):
raise ValueError((
"The number of wavelets must match the number of axes "
"to be transformed."))
wavelets = [_as_wavelet(w) for w in wavelet]
else:
raise ValueError("wavelet must be a str, Wavelet or iterable")
return wavelets
def _modes_per_axis(modes, axes):
"""Initialize mode for each axis to be transformed.
Parameters
----------
modes : str or tuple of strings
If a single mode is provided, it will used for all axes. Otherwise
one mode per axis must be provided.
axes : tuple
The tuple of axes to be transformed.
Returns
-------
modes : tuple of int
A tuple of Modes equal in length to ``axes``.
"""
axes = tuple(axes)
if isinstance(modes, string_types + (int, )):
# same wavelet on all axes
modes = [Modes.from_object(modes), ] * len(axes)
elif isinstance(modes, Iterable):
if len(modes) == 1:
modes = [Modes.from_object(modes[0]), ] * len(axes)
else:
# (potentially) unique wavelet per axis (e.g. for dual-tree DWT)
if len(modes) != len(axes):
raise ValueError(("The number of modes must match the number "
"of axes to be transformed."))
modes = [Modes.from_object(mode) for mode in modes]
else:
raise ValueError("modes must be a str, Mode enum or iterable")
return modes

View file

@ -0,0 +1,733 @@
# Copyright (c) 2006-2012 Filip Wasilewski <http://en.ig.ma/>
# Copyright (c) 2012-2016 The PyWavelets Developers
# <https://github.com/PyWavelets/pywt>
# See COPYING for license details.
"""1D and 2D Wavelet packet transform module."""
from __future__ import division, print_function, absolute_import
__all__ = ["BaseNode", "Node", "WaveletPacket", "Node2D", "WaveletPacket2D"]
import numpy as np
from ._extensions._pywt import Wavelet, _check_dtype
from ._dwt import dwt, idwt, dwt_max_level
from ._multidim import dwt2, idwt2
def get_graycode_order(level, x='a', y='d'):
graycode_order = [x, y]
for i in range(level - 1):
graycode_order = [x + path for path in graycode_order] + \
[y + path for path in graycode_order[::-1]]
return graycode_order
class BaseNode(object):
"""
BaseNode for wavelet packet 1D and 2D tree nodes.
The BaseNode is a base class for `Node` and `Node2D`.
It should not be used directly unless creating a new transformation
type. It is included here to document the common interface of 1D
and 2D node and wavelet packet transform classes.
Parameters
----------
parent :
Parent node. If parent is None then the node is considered detached
(ie root).
data : 1D or 2D array
Data associated with the node. 1D or 2D numeric array, depending on the
transform type.
node_name :
A name identifying the coefficients type.
See `Node.node_name` and `Node2D.node_name`
for information on the accepted subnodes names.
"""
# PART_LEN and PARTS attributes that define path tokens for node[] lookup
# must be defined in subclasses.
PART_LEN = None
PARTS = None
def __init__(self, parent, data, node_name):
self.parent = parent
if parent is not None:
self.wavelet = parent.wavelet
self.mode = parent.mode
self.level = parent.level + 1
self._maxlevel = parent.maxlevel
self.path = parent.path + node_name
else:
self.wavelet = None
self.mode = None
self.path = ""
self.level = 0
# data - signal on level 0, coeffs on higher levels
self.data = data
# Need to retain original data size/shape so we can trim any excess
# boundary coefficients from the inverse transform.
if self.data is None:
self._data_shape = None
else:
self._data_shape = np.asarray(data).shape
self._init_subnodes()
def _init_subnodes(self):
for part in self.PARTS:
self._set_node(part, None)
def _create_subnode(self, part, data=None, overwrite=True):
raise NotImplementedError()
def _create_subnode_base(self, node_cls, part, data=None, overwrite=True):
self._validate_node_name(part)
if not overwrite and self._get_node(part) is not None:
return self._get_node(part)
node = node_cls(self, data, part)
self._set_node(part, node)
return node
def _get_node(self, part):
return getattr(self, part)
def _set_node(self, part, node):
setattr(self, part, node)
def _delete_node(self, part):
self._set_node(part, None)
def _validate_node_name(self, part):
if part not in self.PARTS:
raise ValueError("Subnode name must be in [%s], not '%s'." %
(', '.join("'%s'" % p for p in self.PARTS), part))
def _evaluate_maxlevel(self, evaluate_from='parent'):
"""
Try to find the value of maximum decomposition level if it is not
specified explicitly.
Parameters
----------
evaluate_from : {'parent', 'subnodes'}
"""
assert evaluate_from in ('parent', 'subnodes')
if self._maxlevel is not None:
return self._maxlevel
elif self.data is not None:
return self.level + dwt_max_level(
min(self.data.shape), self.wavelet)
if evaluate_from == 'parent':
if self.parent is not None:
return self.parent._evaluate_maxlevel(evaluate_from)
elif evaluate_from == 'subnodes':
for node_name in self.PARTS:
node = getattr(self, node_name, None)
if node is not None:
level = node._evaluate_maxlevel(evaluate_from)
if level is not None:
return level
return None
@property
def maxlevel(self):
if self._maxlevel is not None:
return self._maxlevel
# Try getting the maxlevel from parents first
self._maxlevel = self._evaluate_maxlevel(evaluate_from='parent')
# If not found, check whether it can be evaluated from subnodes
if self._maxlevel is None:
self._maxlevel = self._evaluate_maxlevel(evaluate_from='subnodes')
return self._maxlevel
@property
def node_name(self):
return self.path[-self.PART_LEN:]
def decompose(self):
"""
Decompose node data creating DWT coefficients subnodes.
Performs Discrete Wavelet Transform on the `~BaseNode.data` and
returns transform coefficients.
Note
----
Descends to subnodes and recursively
calls `~BaseNode.reconstruct` on them.
"""
if self.level < self.maxlevel:
return self._decompose()
else:
raise ValueError("Maximum decomposition level reached.")
def _decompose(self):
raise NotImplementedError()
def reconstruct(self, update=False):
"""
Reconstruct node from subnodes.
Parameters
----------
update : bool, optional
If True, then reconstructed data replaces the current
node data (default: False).
Returns:
- original node data if subnodes do not exist
- IDWT of subnodes otherwise.
"""
if not self.has_any_subnode:
return self.data
return self._reconstruct(update)
def _reconstruct(self):
raise NotImplementedError() # override this in subclasses
def get_subnode(self, part, decompose=True):
"""
Returns subnode or None (see `decomposition` flag description).
Parameters
----------
part :
Subnode name
decompose : bool, optional
If the param is True and corresponding subnode does not
exist, the subnode will be created using coefficients
from the DWT decomposition of the current node.
(default: True)
"""
self._validate_node_name(part)
subnode = self._get_node(part)
if subnode is None and decompose and not self.is_empty:
self.decompose()
subnode = self._get_node(part)
return subnode
def __getitem__(self, path):
"""
Find node represented by the given path.
Similar to `~BaseNode.get_subnode` method with `decompose=True`, but
can access nodes on any level in the decomposition tree.
Parameters
----------
path : str
String composed of node names. See `Node.node_name` and
`Node2D.node_name` for node naming convention.
Notes
-----
If node does not exist yet, it will be created by decomposition of its
parent node.
"""
if isinstance(path, str):
if (self.maxlevel is not None
and len(path) > self.maxlevel * self.PART_LEN):
raise IndexError("Path length is out of range.")
if path:
return self.get_subnode(path[0:self.PART_LEN], True)[
path[self.PART_LEN:]]
else:
return self
else:
raise TypeError("Invalid path parameter type - expected string but"
" got %s." % type(path))
def __setitem__(self, path, data):
"""
Set node or node's data in the decomposition tree. Nodes are
identified by string `path`.
Parameters
----------
path : str
String composed of node names.
data : array or BaseNode subclass.
"""
if isinstance(path, str):
if (
self.maxlevel is not None
and len(self.path) + len(path) > self.maxlevel * self.PART_LEN
):
raise IndexError("Path length out of range.")
if path:
subnode = self.get_subnode(path[0:self.PART_LEN], False)
if subnode is None:
self._create_subnode(path[0:self.PART_LEN], None)
subnode = self.get_subnode(path[0:self.PART_LEN], False)
subnode[path[self.PART_LEN:]] = data
else:
if isinstance(data, BaseNode):
self.data = np.asarray(data.data)
else:
self.data = np.asarray(data)
# convert data to nearest supported dtype
dtype = _check_dtype(data)
if self.data.dtype != dtype:
self.data = self.data.astype(dtype)
else:
raise TypeError("Invalid path parameter type - expected string but"
" got %s." % type(path))
def __delitem__(self, path):
"""
Remove node from the tree.
Parameters
----------
path : str
String composed of node names.
"""
node = self[path]
# don't clear node value and subnodes (node may still exist outside
# the tree)
# # node._init_subnodes()
# # node.data = None
parent = node.parent
node.parent = None # TODO
if parent and node.node_name:
parent._delete_node(node.node_name)
@property
def is_empty(self):
return self.data is None
@property
def has_any_subnode(self):
for part in self.PARTS:
if self._get_node(part) is not None: # and not .is_empty
return True
return False
def get_leaf_nodes(self, decompose=False):
"""
Returns leaf nodes.
Parameters
----------
decompose : bool, optional
(default: True)
"""
result = []
def collect(node):
if node.level == node.maxlevel and not node.is_empty:
result.append(node)
return False
if not decompose and not node.has_any_subnode:
result.append(node)
return False
return True
self.walk(collect, decompose=decompose)
return result
def walk(self, func, args=(), kwargs=None, decompose=True):
"""
Traverses the decomposition tree and calls
``func(node, *args, **kwargs)`` on every node. If `func` returns True,
descending to subnodes will continue.
Parameters
----------
func : callable
Callable accepting `BaseNode` as the first param and
optional positional and keyword arguments
args :
func params
kwargs :
func keyword params
decompose : bool, optional
If True (default), the method will also try to decompose the tree
up to the `maximum level <BaseNode.maxlevel>`.
"""
if kwargs is None:
kwargs = {}
if func(self, *args, **kwargs) and self.level < self.maxlevel:
for part in self.PARTS:
subnode = self.get_subnode(part, decompose)
if subnode is not None:
subnode.walk(func, args, kwargs, decompose)
def walk_depth(self, func, args=(), kwargs=None, decompose=True):
"""
Walk tree and call func on every node starting from the bottom-most
nodes.
Parameters
----------
func : callable
Callable accepting :class:`BaseNode` as the first param and
optional positional and keyword arguments
args :
func params
kwargs :
func keyword params
decompose : bool, optional
(default: False)
"""
if kwargs is None:
kwargs = {}
if self.level < self.maxlevel:
for part in self.PARTS:
subnode = self.get_subnode(part, decompose)
if subnode is not None:
subnode.walk_depth(func, args, kwargs, decompose)
func(self, *args, **kwargs)
def __str__(self):
return self.path + ": " + str(self.data)
class Node(BaseNode):
"""
WaveletPacket tree node.
Subnodes are called `a` and `d`, just like approximation
and detail coefficients in the Discrete Wavelet Transform.
"""
A = 'a'
D = 'd'
PARTS = A, D
PART_LEN = 1
def _create_subnode(self, part, data=None, overwrite=True):
return self._create_subnode_base(node_cls=Node, part=part, data=data,
overwrite=overwrite)
def _decompose(self):
"""
See also
--------
dwt : for 1D Discrete Wavelet Transform output coefficients.
"""
if self.is_empty:
data_a, data_d = None, None
if self._get_node(self.A) is None:
self._create_subnode(self.A, data_a)
if self._get_node(self.D) is None:
self._create_subnode(self.D, data_d)
else:
data_a, data_d = dwt(self.data, self.wavelet, self.mode)
self._create_subnode(self.A, data_a)
self._create_subnode(self.D, data_d)
return self._get_node(self.A), self._get_node(self.D)
def _reconstruct(self, update):
data_a, data_d = None, None
node_a, node_d = self._get_node(self.A), self._get_node(self.D)
if node_a is not None:
data_a = node_a.reconstruct() # TODO: (update) ???
if node_d is not None:
data_d = node_d.reconstruct() # TODO: (update) ???
if data_a is None and data_d is None:
raise ValueError("Node is a leaf node and cannot be reconstructed"
" from subnodes.")
else:
rec = idwt(data_a, data_d, self.wavelet, self.mode)
if self._data_shape is not None and (
rec.shape != self._data_shape):
rec = rec[tuple([slice(sz) for sz in self._data_shape])]
if update:
self.data = rec
return rec
class Node2D(BaseNode):
"""
WaveletPacket tree node.
Subnodes are called 'a' (LL), 'h' (HL), 'v' (LH) and 'd' (HH), like
approximation and detail coefficients in the 2D Discrete Wavelet Transform
"""
LL = 'a'
HL = 'h'
LH = 'v'
HH = 'd'
PARTS = LL, HL, LH, HH
PART_LEN = 1
def _create_subnode(self, part, data=None, overwrite=True):
return self._create_subnode_base(node_cls=Node2D, part=part, data=data,
overwrite=overwrite)
def _decompose(self):
"""
See also
--------
dwt2 : for 2D Discrete Wavelet Transform output coefficients.
"""
if self.is_empty:
data_ll, data_lh, data_hl, data_hh = None, None, None, None
else:
data_ll, (data_hl, data_lh, data_hh) =\
dwt2(self.data, self.wavelet, self.mode)
self._create_subnode(self.LL, data_ll)
self._create_subnode(self.LH, data_lh)
self._create_subnode(self.HL, data_hl)
self._create_subnode(self.HH, data_hh)
return (self._get_node(self.LL), self._get_node(self.HL),
self._get_node(self.LH), self._get_node(self.HH))
def _reconstruct(self, update):
data_ll, data_lh, data_hl, data_hh = None, None, None, None
node_ll, node_lh, node_hl, node_hh =\
self._get_node(self.LL), self._get_node(self.LH),\
self._get_node(self.HL), self._get_node(self.HH)
if node_ll is not None:
data_ll = node_ll.reconstruct()
if node_lh is not None:
data_lh = node_lh.reconstruct()
if node_hl is not None:
data_hl = node_hl.reconstruct()
if node_hh is not None:
data_hh = node_hh.reconstruct()
if (data_ll is None and data_lh is None
and data_hl is None and data_hh is None):
raise ValueError(
"Tree is missing data - all subnodes of `%s` node "
"are None. Cannot reconstruct node." % self.path
)
else:
coeffs = data_ll, (data_hl, data_lh, data_hh)
rec = idwt2(coeffs, self.wavelet, self.mode)
if self._data_shape is not None and (
rec.shape != self._data_shape):
rec = rec[tuple([slice(sz) for sz in self._data_shape])]
if update:
self.data = rec
return rec
def expand_2d_path(self, path):
expanded_paths = {
self.HH: 'hh',
self.HL: 'hl',
self.LH: 'lh',
self.LL: 'll'
}
return (''.join([expanded_paths[p][0] for p in path]),
''.join([expanded_paths[p][1] for p in path]))
class WaveletPacket(Node):
"""
Data structure representing Wavelet Packet decomposition of signal.
Parameters
----------
data : 1D ndarray
Original data (signal)
wavelet : Wavelet object or name string
Wavelet used in DWT decomposition and reconstruction
mode : str, optional
Signal extension mode for the `dwt` and `idwt` decomposition and
reconstruction functions.
maxlevel : int, optional
Maximum level of decomposition.
If None, it will be calculated based on the `wavelet` and `data`
length using `pywt.dwt_max_level`.
"""
def __init__(self, data, wavelet, mode='symmetric', maxlevel=None):
super(WaveletPacket, self).__init__(None, data, "")
if not isinstance(wavelet, Wavelet):
wavelet = Wavelet(wavelet)
self.wavelet = wavelet
self.mode = mode
if data is not None:
data = np.asarray(data)
assert data.ndim == 1
self.data_size = data.shape[0]
if maxlevel is None:
maxlevel = dwt_max_level(self.data_size, self.wavelet)
else:
self.data_size = None
self._maxlevel = maxlevel
def reconstruct(self, update=True):
"""
Reconstruct data value using coefficients from subnodes.
Parameters
----------
update : bool, optional
If True (default), then data values will be replaced by
reconstruction values, also in subnodes.
"""
if self.has_any_subnode:
data = super(WaveletPacket, self).reconstruct(update)
if update:
self.data = data
return data
return self.data # return original data
def get_level(self, level, order="natural", decompose=True):
"""
Returns all nodes on the specified level.
Parameters
----------
level : int
Specifies decomposition `level` from which the nodes will be
collected.
order : {'natural', 'freq'}, optional
- "natural" - left to right in tree (default)
- "freq" - band ordered
decompose : bool, optional
If set then the method will try to decompose the data up
to the specified `level` (default: True).
Notes
-----
If nodes at the given level are missing (i.e. the tree is partially
decomposed) and the `decompose` is set to False, only existing nodes
will be returned.
"""
assert order in ["natural", "freq"]
if level > self.maxlevel:
raise ValueError("The level cannot be greater than the maximum"
" decomposition level value (%d)" % self.maxlevel)
result = []
def collect(node):
if node.level == level:
result.append(node)
return False
return True
self.walk(collect, decompose=decompose)
if order == "natural":
return result
elif order == "freq":
result = dict((node.path, node) for node in result)
graycode_order = get_graycode_order(level)
return [result[path] for path in graycode_order if path in result]
else:
raise ValueError("Invalid order name - %s." % order)
class WaveletPacket2D(Node2D):
"""
Data structure representing 2D Wavelet Packet decomposition of signal.
Parameters
----------
data : 2D ndarray
Data associated with the node.
wavelet : Wavelet object or name string
Wavelet used in DWT decomposition and reconstruction
mode : str, optional
Signal extension mode for the `dwt` and `idwt` decomposition and
reconstruction functions.
maxlevel : int
Maximum level of decomposition.
If None, it will be calculated based on the `wavelet` and `data`
length using `pywt.dwt_max_level`.
"""
def __init__(self, data, wavelet, mode='smooth', maxlevel=None):
super(WaveletPacket2D, self).__init__(None, data, "")
if not isinstance(wavelet, Wavelet):
wavelet = Wavelet(wavelet)
self.wavelet = wavelet
self.mode = mode
if data is not None:
data = np.asarray(data)
assert data.ndim == 2
self.data_size = data.shape
if maxlevel is None:
maxlevel = dwt_max_level(min(self.data_size), self.wavelet)
else:
self.data_size = None
self._maxlevel = maxlevel
def reconstruct(self, update=True):
"""
Reconstruct data using coefficients from subnodes.
Parameters
----------
update : bool, optional
If True (default) then the coefficients of the current node
and its subnodes will be replaced with values from reconstruction.
"""
if self.has_any_subnode:
data = super(WaveletPacket2D, self).reconstruct(update)
if update:
self.data = data
return data
return self.data # return original data
def get_level(self, level, order="natural", decompose=True):
"""
Returns all nodes from specified level.
Parameters
----------
level : int
Decomposition `level` from which the nodes will be
collected.
order : {'natural', 'freq'}, optional
If `natural` (default) a flat list is returned.
If `freq`, a 2d structure with rows and cols
sorted by corresponding dimension frequency of 2d
coefficient array (adapted from 1d case).
decompose : bool, optional
If set then the method will try to decompose the data up
to the specified `level` (default: True).
"""
assert order in ["natural", "freq"]
if level > self.maxlevel:
raise ValueError("The level cannot be greater than the maximum"
" decomposition level value (%d)" % self.maxlevel)
result = []
def collect(node):
if node.level == level:
result.append(node)
return False
return True
self.walk(collect, decompose=decompose)
if order == "freq":
nodes = {}
for (row_path, col_path), node in [
(self.expand_2d_path(node.path), node) for node in result
]:
nodes.setdefault(row_path, {})[col_path] = node
graycode_order = get_graycode_order(level, x='l', y='h')
nodes = [nodes[path] for path in graycode_order if path in nodes]
result = []
for row in nodes:
result.append(
[row[path] for path in graycode_order if path in row]
)
return result

View file

@ -0,0 +1,6 @@
import pytest
def pytest_configure(config):
config.addinivalue_line("markers",
"slow: Tests that are slow.")

View file

@ -0,0 +1,2 @@
from ._readers import ascent, aero, ecg, camera, nino
from ._wavelab_signals import demo_signal

View file

@ -0,0 +1,185 @@
import os
import numpy as np
def ascent():
"""
Get an 8-bit grayscale bit-depth, 512 x 512 derived image for
easy use in demos
The image is derived from accent-to-the-top.jpg at
http://www.public-domain-image.com/people-public-domain-images-pictures/
Parameters
----------
None
Returns
-------
ascent : ndarray
convenient image to use for testing and demonstration
Examples
--------
>>> import pywt.data
>>> ascent = pywt.data.ascent()
>>> ascent.shape == (512, 512)
True
>>> ascent.max()
255
>>> import matplotlib.pyplot as plt
>>> plt.gray()
>>> plt.imshow(ascent) # doctest: +ELLIPSIS
<matplotlib.image.AxesImage object at ...>
>>> plt.show() # doctest: +SKIP
"""
fname = os.path.join(os.path.dirname(__file__), 'ascent.npz')
ascent = np.load(fname)['data']
return ascent
def aero():
"""
Get an 8-bit grayscale bit-depth, 512 x 512 derived image for
easy use in demos
Parameters
----------
None
Returns
-------
aero : ndarray
convenient image to use for testing and demonstration
Examples
--------
>>> import pywt.data
>>> aero = pywt.data.ascent()
>>> aero.shape == (512, 512)
True
>>> aero.max()
255
>>> import matplotlib.pyplot as plt
>>> plt.gray()
>>> plt.imshow(aero) # doctest: +ELLIPSIS
<matplotlib.image.AxesImage object at ...>
>>> plt.show() # doctest: +SKIP
"""
fname = os.path.join(os.path.dirname(__file__), 'aero.npz')
aero = np.load(fname)['data']
return aero
def camera():
"""
Get an 8-bit grayscale bit-depth, 512 x 512 derived image for
easy use in demos
Parameters
----------
None
Returns
-------
camera : ndarray
convenient image to use for testing and demonstration
Examples
--------
>>> import pywt.data
>>> camera = pywt.data.ascent()
>>> camera.shape == (512, 512)
True
>>> import matplotlib.pyplot as plt
>>> plt.gray()
>>> plt.imshow(camera) # doctest: +ELLIPSIS
<matplotlib.image.AxesImage object at ...>
>>> plt.show() # doctest: +SKIP
"""
fname = os.path.join(os.path.dirname(__file__), 'camera.npz')
camera = np.load(fname)['data']
return camera
def ecg():
"""
Get 1024 points of an ECG timeseries.
Parameters
----------
None
Returns
-------
ecg : ndarray
convenient timeseries to use for testing and demonstration
Examples
--------
>>> import pywt.data
>>> ecg = pywt.data.ecg()
>>> ecg.shape == (1024,)
True
>>> import matplotlib.pyplot as plt
>>> plt.plot(ecg) # doctest: +ELLIPSIS
[<matplotlib.lines.Line2D object at ...>]
>>> plt.show() # doctest: +SKIP
"""
fname = os.path.join(os.path.dirname(__file__), 'ecg.npy')
ecg = np.load(fname)
return ecg
def nino():
"""
This data contains the averaged monthly sea surface temperature in degrees
Celcius of the Pacific Ocean, between 0-10 degrees South and 90-80 degrees West, from 1950 to 2016.
This dataset is in the public domain and was obtained from NOAA.
National Oceanic and Atmospheric Administration's National Weather Service
ERSSTv4 dataset, nino 3, http://www.cpc.ncep.noaa.gov/data/indices/
Parameters
----------
None
Returns
-------
time : ndarray
convenient timeseries to use for testing and demonstration
sst : ndarray
convenient timeseries to use for testing and demonstration
Examples
--------
>>> import pywt.data
>>> time, sst = pywt.data.nino()
>>> sst.shape == (264,)
True
>>> import matplotlib.pyplot as plt
>>> plt.plot(time,sst) # doctest: +ELLIPSIS
[<matplotlib.lines.Line2D object at ...>]
>>> plt.show() # doctest: +SKIP
"""
fname = os.path.join(os.path.dirname(__file__), 'sst_nino3.npz')
sst_csv = np.load(fname)['sst_csv']
# sst_csv = pd.read_csv("http://www.cpc.ncep.noaa.gov/data/indices/ersst4.nino.mth.81-10.ascii", sep=' ', skipinitialspace=True)
# take only full years
n = int(np.floor(sst_csv.shape[0]/12.)*12.)
# Building the mean of three mounth
# the 4. column is nino 3
sst = np.mean(np.reshape(np.array(sst_csv)[:n, 4], (n//3, -1)), axis=1)
sst = (sst - np.mean(sst)) / np.std(sst, ddof=1)
dt = 0.25
time = np.arange(len(sst)) * dt + 1950.0 # construct time array
return time, sst

View file

@ -0,0 +1,259 @@
# -*- coding:utf-8 -*-
from __future__ import division
import numpy as np
__all__ = ['demo_signal']
_implemented_signals = [
'Blocks',
'Bumps',
'HeaviSine',
'Doppler',
'Ramp',
'HiSine',
'LoSine',
'LinChirp',
'TwoChirp',
'QuadChirp',
'MishMash',
'WernerSorrows',
'HypChirps',
'LinChirps',
'Chirps',
'Gabor',
'sineoneoverx',
'Piece-Regular',
'Piece-Polynomial',
'Riemann']
def demo_signal(name='Bumps', n=None):
"""Simple 1D wavelet test functions.
This function can generate a number of common 1D test signals used in
papers by David Donoho and colleagues (e.g. [1]_) as well as the wavelet
book by Stéphane Mallat [2]_.
Parameters
----------
name : {'Blocks', 'Bumps', 'HeaviSine', 'Doppler', ...}
The type of test signal to generate (`name` is case-insensitive). If
`name` is set to `'list'`, a list of the avialable test functions is
returned.
n : int or None
The length of the test signal. This should be provided for all test
signals except `'Gabor'` and `'sineoneoverx'` which have a fixed
length.
Returns
-------
f : np.ndarray
Array of length ``n`` corresponding to the specified test signal type.
References
----------
.. [1] D.L. Donoho and I.M. Johnstone. Ideal spatial adaptation by
wavelet shrinkage. Biometrika, vol. 81, pp. 425455, 1994.
.. [2] S. Mallat. A Wavelet Tour of Signal Processing: The Sparse Way.
Academic Press. 2009.
Notes
-----
This function is a partial reimplementation of the `MakeSignal` function
from the [Wavelab](https://statweb.stanford.edu/~wavelab/) toolbox. These
test signals are provided with permission of Dr. Donoho to encourage
reproducible research.
"""
if name.lower() == 'list':
return _implemented_signals
if n is not None:
if n < 1 or (n % 1) != 0:
raise ValueError("n must be an integer >= 1")
t = np.arange(1/n, 1 + 1/n, 1/n)
# The following function types don't allow user-specified `n`.
n_hard_coded = ['gabor', 'sineoneoverx']
name = name.lower()
if name in n_hard_coded and n is not None:
raise ValueError(
"Parameter n must be set to None when name is {}".format(name))
elif n is None and name not in n_hard_coded:
raise ValueError(
"Parameter n must be provided when name is {}".format(name))
if name == 'blocks':
t0s = [.1, .13, .15, .23, .25, .4, .44, .65, .76, .78, .81]
hs = [4, -5, 3, -4, 5, -4.2, 2.1, 4.3, -3.1, 2.1, -4.2]
f = 0
for (t0, h) in zip(t0s, hs):
f += h * (1 + np.sign(t - t0)) / 2
elif name == 'bumps':
t0s = [.1, .13, .15, .23, .25, .4, .44, .65, .76, .78, .81]
hs = [4, 5, 3, 4, 5, 4.2, 2.1, 4.3, 3.1, 5.1, 4.2]
ws = [.005, .005, .006, .01, .01, .03, .01, .01, .005, .008, .005]
f = 0
for (t0, h, w) in zip(t0s, hs, ws):
f += h / (1 + np.abs((t - t0) / w))**4
elif name == 'heavisine':
f = 4 * np.sin(4 * np.pi * t) - np.sign(t - 0.3) - np.sign(0.72 - t)
elif name == 'doppler':
f = np.sqrt(t * (1 - t)) * np.sin(2 * np.pi * 1.05 / (t + 0.05))
elif name == 'ramp':
f = t - (t >= .37)
elif name == 'hisine':
f = np.sin(np.pi * (n * .6902) * t)
elif name == 'losine':
f = np.sin(np.pi * (n * .3333) * t)
elif name == 'linchirp':
f = np.sin(np.pi * t * ((n * .500) * t))
elif name == 'twochirp':
f = np.sin(np.pi * t * (n * t)) + np.sin((np.pi / 3) * t * (n * t))
elif name == 'quadchirp':
f = np.sin((np.pi / 3) * t * (n * t**2))
elif name == 'mishmash': # QuadChirp + LinChirp + HiSine
f = np.sin((np.pi / 3) * t * (n * t**2))
f += np.sin(np.pi * (n * .6902) * t)
f += np.sin(np.pi * t * (n * .125 * t))
elif name == 'wernersorrows':
f = np.sin(np.pi * t * (n / 2 * t**2))
f = f + np.sin(np.pi * (n * .6902) * t)
f = f + np.sin(np.pi * t * (n * t))
pos = [.1, .13, .15, .23, .25, .40, .44, .65, .76, .78, .81]
hgt = [4, 5, 3, 4, 5, 4.2, 2.1, 4.3, 3.1, 5.1, 4.2]
wth = [.005, .005, .006, .01, .01, .03, .01, .01, .005, .008, .005]
for p, h, w in zip(pos, hgt, wth):
f += h / (1 + np.abs((t - p) / w))**4
elif name == 'hypchirps': # Hyperbolic Chirps of Mallat's book
alpha = 15 * n * np.pi / 1024
beta = 5 * n * np.pi / 1024
t = np.arange(1.001, n + .001 + 1) / n
f1 = np.zeros(n)
f2 = np.zeros(n)
f1 = np.sin(alpha / (.8 - t)) * (0.1 < t) * (t < 0.68)
f2 = np.sin(beta / (.8 - t)) * (0.1 < t) * (t < 0.75)
m = int(np.round(0.65 * n))
p = m // 4
envelope = np.ones(m) # the rinp.sing cutoff function
tmp = np.arange(1, p + 1)-np.ones(p)
envelope[:p] = (1 + np.sin(-np.pi / 2 + tmp / (p - 1) * np.pi)) / 2
envelope[m-p:m] = envelope[:p][::-1]
env = np.zeros(n)
env[int(np.ceil(n / 10)) - 1:m + int(np.ceil(n / 10)) - 1] = \
envelope[:m]
f = (f1 + f2) * env
elif name == 'linchirps': # Linear Chirps of Mallat's book
b = 100 * n * np.pi / 1024
a = 250 * n * np.pi / 1024
t = np.arange(1, n + 1) / n
A1 = np.sqrt((t - 1 / n) * (1 - t))
f = A1 * (np.cos(a * t**2) + np.cos(b * t + a * t**2))
elif name == 'chirps': # Mixture of Chirps of Mallat's book
t = np.arange(1, n + 1)/n * 10 * np.pi
f1 = np.cos(t**2 * n / 1024)
a = 30 * n / 1024
t = np.arange(1, n + 1)/n * np.pi
f2 = np.cos(a * (t**3))
f2 = f2[::-1]
ix = np.arange(-n, n + 1) / n * 20
g = np.exp(-ix**2 * 4 * n / 1024)
i1 = slice(n // 2, n // 2 + n)
i2 = slice(n // 8, n // 8 + n)
j = np.arange(1, n + 1) / n
f3 = g[i1] * np.cos(50 * np.pi * j * n / 1024)
f4 = g[i2] * np.cos(350 * np.pi * j * n / 1024)
f = f1 + f2 + f3 + f4
envelope = np.ones(n) # the rinp.sing cutoff function
tmp = np.arange(1, n // 8 + 1) - np.ones(n // 8)
envelope[:n // 8] = (
1 + np.sin(-np.pi / 2 + tmp / (n / 8 - 1) * np.pi)) / 2
envelope[7 * n // 8:n] = envelope[:n // 8][::-1]
f = f*envelope
elif name == 'gabor': # two modulated Gabor functions in Mallat's book
n = 512
t = np.arange(-n, n + 1)*5 / n
j = np.arange(1, n + 1) / n
g = np.exp(-t**2 * 20)
i1 = slice(2*n // 4, 2 * n // 4 + n)
i2 = slice(n // 4, n // 4 + n)
f1 = 3 * g[i1] * np.exp(1j * (n // 16) * np.pi * j)
f2 = 3 * g[i2] * np.exp(1j * (n // 4) * np.pi * j)
f = f1 + f2
elif name == 'sineoneoverx': # np.sin(1/x) in Mallat's book
n = 1024
i1 = np.arange(-n + 1, n + 1, dtype=float)
i1[i1 == 0] = 1 / 100
i1 = i1 / (n - 1)
f = np.sin(1.5 / i1)
f = f[512:1536]
elif name == 'piece-regular':
f = np.zeros(n)
n_12 = int(np.fix(n / 12))
n_7 = int(np.fix(n / 7))
n_5 = int(np.fix(n / 5))
n_3 = int(np.fix(n / 3))
n_2 = int(np.fix(n / 2))
n_20 = int(np.fix(n / 20))
f1 = -15 * demo_signal('bumps', n)
t = np.arange(1, n_12 + 1) / n_12
f2 = -np.exp(4 * t)
t = np.arange(1, n_7 + 1) / n_7
f5 = np.exp(4 * t)-np.exp(4)
t = np.arange(1, n_3 + 1) / n_3
fma = 6 / 40
f6 = -70 * np.exp(-((t - 0.5) * (t - 0.5)) / (2 * fma**2))
f[:n_7] = f6[:n_7]
f[n_7:n_5] = 0.5 * f6[n_7:n_5]
f[n_5:n_3] = f6[n_5:n_3]
f[n_3:n_2] = f1[n_3:n_2]
f[n_2:n_2 + n_12] = f2
f[n_2 + 2 * n_12 - 1:n_2 + n_12 - 1:-1] = f2
f[n_2 + 2 * n_12 + n_20:n_2 + 2 * n_12 + 3 * n_20] = -np.ones(
n_2 + 2*n_12 + 3*n_20 - n_2 - 2*n_12 - n_20) * 25
k = n_2 + 2 * n_12 + 3 * n_20
f[k:k + n_7] = f5
diff = n - 5 * n_5
f[5 * n_5:n] = f[diff - 1::-1]
# zero-mean
bias = np.sum(f) / n
f = bias - f
elif name == 'piece-polynomial':
f = np.zeros(n)
n_5 = int(np.fix(n / 5))
n_10 = int(np.fix(n / 10))
n_20 = int(np.fix(n / 20))
t = np.arange(1, n_5 + 1) / n_5
f1 = 20 * (t**3 + t**2 + 4)
f3 = 40 * (2 * t**3 + t) + 100
f2 = 10 * t**3 + 45
f4 = 16 * t**2 + 8 * t + 16
f5 = 20 * (t + 4)
f6 = np.ones(n_10) * 20
f[:n_5] = f1
f[2 * n_5 - 1:n_5 - 1:-1] = f2
f[2 * n_5:3 * n_5] = f3
f[3 * n_5:4 * n_5] = f4
f[4 * n_5:5 * n_5] = f5[n_5::-1]
diff = n - 5*n_5
f[5 * n_5:n] = f[diff - 1::-1]
f[n_20:n_20 + n_10] = np.ones(n_10) * 10
f[n - n_10:n + n_20 - n_10] = np.ones(n_20) * 150
# zero-mean
bias = np.sum(f) / n
f = f - bias
elif name == 'riemann':
# Riemann's Non-differentiable Function
sqn = int(np.round(np.sqrt(n)))
idx = np.arange(1, sqn + 1)
idx *= idx
f = np.zeros_like(t)
f[idx - 1] = 1. / np.arange(1, sqn + 1)
f = np.real(np.fft.ifft(f))
else:
raise ValueError(
"unknown name: {}. name must be one of: {}".format(
name, _implemented_signals))
return f

Binary file not shown.

Binary file not shown.

Binary file not shown.

View file

@ -0,0 +1,39 @@
#!/usr/bin/env python
"""Helper script for creating image .dat files by numpy.save
Usage:
python create_dat.py <name of image file> <name of dat file>
Example (to create aero.dat):
python create_dat.py aero.png aero.dat
Requires Scipy and PIL.
"""
from __future__ import print_function
import sys
import numpy as np
def main():
from scipy.misc import imread
if len(sys.argv) != 3:
print(__doc__)
exit()
image_fname = sys.argv[1]
dat_fname = sys.argv[2]
data = imread(image_fname)
np.savez_compressed(dat_fname, data=data)
if __name__ == "__main__":
main()

Binary file not shown.

Binary file not shown.

View file

@ -0,0 +1,96 @@
""" This script was used to generate dwt_matlabR2012a_result.npz by storing
the outputs from Matlab R2012a. """
from __future__ import division, print_function, absolute_import
import numpy as np
import pywt
try:
from pymatbridge import Matlab
mlab = Matlab()
_matlab_missing = False
except ImportError:
print("To run Matlab compatibility tests you need to have MathWorks "
"MATLAB, MathWorks Wavelet Toolbox and the pymatbridge Python "
"package installed.")
_matlab_missing = True
if _matlab_missing:
raise EnvironmentError("Can't generate matlab data files without MATLAB")
size_set = 'reduced'
# list of mode names in pywt and matlab
modes = [('zero', 'zpd'),
('constant', 'sp0'),
('symmetric', 'sym'),
('reflect', 'symw'),
('periodic', 'ppd'),
('smooth', 'sp1'),
('periodization', 'per'),
('antisymmetric', 'asym'),
('antireflect', 'asymw')]
families = ('db', 'sym', 'coif', 'bior', 'rbio')
wavelets = sum([pywt.wavelist(name) for name in families], [])
rstate = np.random.RandomState(1234)
mlab.start()
try:
all_matlab_results = {}
for wavelet in wavelets:
w = pywt.Wavelet(wavelet)
mlab.set_variable('wavelet', wavelet)
if size_set == 'full':
data_sizes = list(range(w.dec_len, 40)) + \
[100, 200, 500, 1000, 50000]
else:
data_sizes = (w.dec_len, w.dec_len + 1)
for N in data_sizes:
data = rstate.randn(N)
mlab.set_variable('data', data)
for pmode, mmode in modes:
# Matlab result
if np.any((wavelet == np.array(['coif6', 'coif7', 'coif8', 'coif9', 'coif10', 'coif11', 'coif12', 'coif13', 'coif14', 'coif15', 'coif16', 'coif17'])),axis=0):
mlab.set_variable('Lo_D', w.dec_lo)
mlab.set_variable('Hi_D', w.dec_hi)
mlab_code = ("[ma, md] = dwt(data, Lo_D, Hi_D, "
"'mode', '%s');" % mmode)
else:
mlab_code = ("[ma, md] = dwt(data, wavelet, "
"'mode', '%s');" % mmode)
res = mlab.run_code(mlab_code)
if not res['success']:
raise RuntimeError(
"Matlab failed to execute the provided code. "
"Check that the wavelet toolbox is installed.")
# need np.asarray because sometimes the output is type float
ma = np.asarray(mlab.get_variable('ma'))
md = np.asarray(mlab.get_variable('md'))
ma_key = '_'.join([mmode, wavelet, str(N), 'ma'])
md_key = '_'.join([mmode, wavelet, str(N), 'md'])
all_matlab_results[ma_key] = ma
all_matlab_results[md_key] = md
# Matlab result
mlab.set_variable('Lo_D', w.dec_lo)
mlab.set_variable('Hi_D', w.dec_hi)
mlab_code = ("[ma, md] = dwt(data, Lo_D, Hi_D, "
"'mode', '%s');" % mmode)
res = mlab.run_code(mlab_code)
if not res['success']:
raise RuntimeError(
"Matlab failed to execute the provided code. "
"Check that the wavelet toolbox is installed.")
# need np.asarray because sometimes the output is type float
ma = np.asarray(mlab.get_variable('ma'))
md = np.asarray(mlab.get_variable('md'))
ma_key = '_'.join([mmode, wavelet, str(N), 'ma_pywtCoeffs'])
md_key = '_'.join([mmode, wavelet, str(N), 'md_pywtCoeffs'])
all_matlab_results[ma_key] = ma
all_matlab_results[md_key] = md
finally:
mlab.stop()
np.savez('dwt_matlabR2012a_result.npz', **all_matlab_results)

View file

@ -0,0 +1,86 @@
""" This script was used to generate dwt_matlabR2012a_result.npz by storing
the outputs from Matlab R2012a. """
from __future__ import division, print_function, absolute_import
import numpy as np
import pywt
try:
from pymatbridge import Matlab
mlab = Matlab()
_matlab_missing = False
except ImportError:
print("To run Matlab compatibility tests you need to have MathWorks "
"MATLAB, MathWorks Wavelet Toolbox and the pymatbridge Python "
"package installed.")
_matlab_missing = True
if _matlab_missing:
raise EnvironmentError("Can't generate matlab data files without MATLAB")
size_set = 'reduced'
# list of mode names in pywt and matlab
modes = [('zero', 'zpd'),
('constant', 'sp0'),
('symmetric', 'sym'),
('periodic', 'ppd'),
('smooth', 'sp1'),
('periodization', 'per')]
families = ('gaus', 'mexh', 'morl', 'cgau', 'shan', 'fbsp', 'cmor')
wavelets = sum([pywt.wavelist(name) for name in families], [])
rstate = np.random.RandomState(1234)
mlab.start()
try:
all_matlab_results = {}
for wavelet in wavelets:
w = pywt.ContinuousWavelet(wavelet)
if np.any((wavelet == np.array(['shan', 'cmor'])),axis=0):
mlab.set_variable('wavelet', wavelet+str(w.bandwidth_frequency)+'-'+str(w.center_frequency))
elif wavelet == 'fbsp':
mlab.set_variable('wavelet', wavelet+str(w.fbsp_order)+'-'+str(w.bandwidth_frequency)+'-'+str(w.center_frequency))
else:
mlab.set_variable('wavelet', wavelet)
if size_set == 'full':
data_sizes = list(range(100, 101)) + \
[100, 200, 500, 1000, 50000]
Scales = (1,np.arange(1,3),np.arange(1,4),np.arange(1,5))
else:
data_sizes = (1000, 1000 + 1)
Scales = (1,np.arange(1,3))
mlab_code = ("psi = wavefun(wavelet,10)")
res = mlab.run_code(mlab_code)
if not res['success']:
raise RuntimeError(
"Matlab failed to execute the provided code. "
"Check that the wavelet toolbox is installed.")
psi = np.asarray(mlab.get_variable('psi'))
psi_key = '_'.join([wavelet, 'psi'])
all_matlab_results[psi_key] = psi
for N in data_sizes:
data = rstate.randn(N)
mlab.set_variable('data', data)
# Matlab result
scale_count = 0
for scales in Scales:
scale_count += 1
mlab.set_variable('scales', scales)
mlab_code = ("coefs = cwt(data, scales, wavelet)")
res = mlab.run_code(mlab_code)
if not res['success']:
raise RuntimeError(
"Matlab failed to execute the provided code. "
"Check that the wavelet toolbox is installed.")
# need np.asarray because sometimes the output is type float
coefs = np.asarray(mlab.get_variable('coefs'))
coefs_key = '_'.join([str(scale_count), wavelet, str(N), 'coefs'])
all_matlab_results[coefs_key] = coefs
finally:
mlab.stop()
np.savez('cwt_matlabR2015b_result.npz', **all_matlab_results)

View file

@ -0,0 +1,170 @@
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
import numpy as np
from numpy.testing import assert_allclose, assert_, assert_raises
import pywt
def test_upcoef_reconstruct():
data = np.arange(3)
a = pywt.downcoef('a', data, 'haar')
d = pywt.downcoef('d', data, 'haar')
rec = (pywt.upcoef('a', a, 'haar', take=3) +
pywt.upcoef('d', d, 'haar', take=3))
assert_allclose(rec, data)
def test_downcoef_multilevel():
rstate = np.random.RandomState(1234)
r = rstate.randn(16)
nlevels = 3
# calling with level=1 nlevels times
a1 = r.copy()
for i in range(nlevels):
a1 = pywt.downcoef('a', a1, 'haar', level=1)
# call with level=nlevels once
a3 = pywt.downcoef('a', r, 'haar', level=nlevels)
assert_allclose(a1, a3)
def test_downcoef_complex():
rstate = np.random.RandomState(1234)
r = rstate.randn(16) + 1j * rstate.randn(16)
nlevels = 3
a = pywt.downcoef('a', r, 'haar', level=nlevels)
a_ref = pywt.downcoef('a', r.real, 'haar', level=nlevels)
a_ref = a_ref + 1j * pywt.downcoef('a', r.imag, 'haar', level=nlevels)
assert_allclose(a, a_ref)
def test_downcoef_errs():
# invalid part string (not 'a' or 'd')
assert_raises(ValueError, pywt.downcoef, 'f', np.ones(16), 'haar')
def test_compare_downcoef_coeffs():
rstate = np.random.RandomState(1234)
r = rstate.randn(16)
# compare downcoef against wavedec outputs
for nlevels in [1, 2, 3]:
for wavelet in pywt.wavelist():
if wavelet in ['cmor', 'shan', 'fbsp']:
# skip these CWT families to avoid warnings
continue
wavelet = pywt.DiscreteContinuousWavelet(wavelet)
if isinstance(wavelet, pywt.Wavelet):
max_level = pywt.dwt_max_level(r.size, wavelet.dec_len)
if nlevels <= max_level:
a = pywt.downcoef('a', r, wavelet, level=nlevels)
d = pywt.downcoef('d', r, wavelet, level=nlevels)
coeffs = pywt.wavedec(r, wavelet, level=nlevels)
assert_allclose(a, coeffs[0])
assert_allclose(d, coeffs[1])
def test_upcoef_multilevel():
rstate = np.random.RandomState(1234)
r = rstate.randn(4)
nlevels = 3
# calling with level=1 nlevels times
a1 = r.copy()
for i in range(nlevels):
a1 = pywt.upcoef('a', a1, 'haar', level=1)
# call with level=nlevels once
a3 = pywt.upcoef('a', r, 'haar', level=nlevels)
assert_allclose(a1, a3)
def test_upcoef_complex():
rstate = np.random.RandomState(1234)
r = rstate.randn(4) + 1j*rstate.randn(4)
nlevels = 3
a = pywt.upcoef('a', r, 'haar', level=nlevels)
a_ref = pywt.upcoef('a', r.real, 'haar', level=nlevels)
a_ref = a_ref + 1j*pywt.upcoef('a', r.imag, 'haar', level=nlevels)
assert_allclose(a, a_ref)
def test_upcoef_errs():
# invalid part string (not 'a' or 'd')
assert_raises(ValueError, pywt.upcoef, 'f', np.ones(4), 'haar')
def test_upcoef_and_downcoef_1d_only():
# upcoef and downcoef raise a ValueError if data.ndim > 1d
for ndim in [2, 3]:
data = np.ones((8, )*ndim)
assert_raises(ValueError, pywt.downcoef, 'a', data, 'haar')
assert_raises(ValueError, pywt.upcoef, 'a', data, 'haar')
def test_wavelet_repr():
from pywt._extensions import _pywt
wavelet = _pywt.Wavelet('sym8')
repr_wavelet = eval(wavelet.__repr__())
assert_(wavelet.__repr__() == repr_wavelet.__repr__())
def test_dwt_max_level():
assert_(pywt.dwt_max_level(16, 2) == 4)
assert_(pywt.dwt_max_level(16, 8) == 1)
assert_(pywt.dwt_max_level(16, 9) == 1)
assert_(pywt.dwt_max_level(16, 10) == 0)
assert_(pywt.dwt_max_level(16, np.int8(10)) == 0)
assert_(pywt.dwt_max_level(16, 10.) == 0)
assert_(pywt.dwt_max_level(16, 18) == 0)
# accepts discrete Wavelet object or string as well
assert_(pywt.dwt_max_level(32, pywt.Wavelet('sym5')) == 1)
assert_(pywt.dwt_max_level(32, 'sym5') == 1)
# string input that is not a discrete wavelet
assert_raises(ValueError, pywt.dwt_max_level, 16, 'mexh')
# filter_len must be an integer >= 2
assert_raises(ValueError, pywt.dwt_max_level, 16, 1)
assert_raises(ValueError, pywt.dwt_max_level, 16, -1)
assert_raises(ValueError, pywt.dwt_max_level, 16, 3.3)
def test_ContinuousWavelet_errs():
assert_raises(ValueError, pywt.ContinuousWavelet, 'qwertz')
def test_ContinuousWavelet_repr():
from pywt._extensions import _pywt
wavelet = _pywt.ContinuousWavelet('gaus2')
repr_wavelet = eval(wavelet.__repr__())
assert_(wavelet.__repr__() == repr_wavelet.__repr__())
def test_wavelist():
for name in pywt.wavelist(family='coif'):
assert_(name.startswith('coif'))
assert_('cgau7' in pywt.wavelist(kind='continuous'))
assert_('sym20' in pywt.wavelist(kind='discrete'))
assert_(len(pywt.wavelist(kind='continuous')) +
len(pywt.wavelist(kind='discrete')) ==
len(pywt.wavelist(kind='all')))
assert_raises(ValueError, pywt.wavelist, kind='foobar')
def test_wavelet_errormsgs():
try:
pywt.Wavelet('gaus1')
except ValueError as e:
assert_(e.args[0].startswith('The `Wavelet` class'))
try:
pywt.Wavelet('cmord')
except ValueError as e:
assert_(e.args[0] == "Invalid wavelet name 'cmord'.")

View file

@ -0,0 +1,105 @@
"""
Tests used to verify running PyWavelets transforms in parallel via
concurrent.futures.ThreadPoolExecutor does not raise errors.
"""
from __future__ import division, print_function, absolute_import
import warnings
import numpy as np
from functools import partial
from numpy.testing import assert_array_equal, assert_allclose
from pywt._pytest import uses_futures, futures, max_workers
import pywt
def _assert_all_coeffs_equal(coefs1, coefs2):
# return True only if all coefficients of SWT or DWT match over all levels
if len(coefs1) != len(coefs2):
return False
for (c1, c2) in zip(coefs1, coefs2):
if isinstance(c1, tuple):
# for swt, swt2, dwt, dwt2, wavedec, wavedec2
for a1, a2 in zip(c1, c2):
assert_array_equal(a1, a2)
elif isinstance(c1, dict):
# for swtn, dwtn, wavedecn
for k, v in c1.items():
assert_array_equal(v, c2[k])
else:
return False
return True
@uses_futures
def test_concurrent_swt():
# tests error-free concurrent operation (see gh-288)
# swt on 1D data calls the Cython swt
# other cases call swt_axes
with warnings.catch_warnings():
# can remove catch_warnings once the swt2 FutureWarning is removed
warnings.simplefilter('ignore', FutureWarning)
for swt_func, x in zip([pywt.swt, pywt.swt2, pywt.swtn],
[np.ones(8), np.eye(16), np.eye(16)]):
transform = partial(swt_func, wavelet='haar', level=3)
for _ in range(10):
arrs = [x.copy() for _ in range(100)]
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
results = list(ex.map(transform, arrs))
# validate result from one of the concurrent runs
expected_result = transform(x)
_assert_all_coeffs_equal(expected_result, results[-1])
@uses_futures
def test_concurrent_wavedec():
# wavedec on 1D data calls the Cython dwt_single
# other cases call dwt_axis
for wavedec_func, x in zip([pywt.wavedec, pywt.wavedec2, pywt.wavedecn],
[np.ones(8), np.eye(16), np.eye(16)]):
transform = partial(wavedec_func, wavelet='haar', level=1)
for _ in range(10):
arrs = [x.copy() for _ in range(100)]
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
results = list(ex.map(transform, arrs))
# validate result from one of the concurrent runs
expected_result = transform(x)
_assert_all_coeffs_equal(expected_result, results[-1])
@uses_futures
def test_concurrent_dwt():
# dwt on 1D data calls the Cython dwt_single
# other cases call dwt_axis
for dwt_func, x in zip([pywt.dwt, pywt.dwt2, pywt.dwtn],
[np.ones(8), np.eye(16), np.eye(16)]):
transform = partial(dwt_func, wavelet='haar')
for _ in range(10):
arrs = [x.copy() for _ in range(100)]
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
results = list(ex.map(transform, arrs))
# validate result from one of the concurrent runs
expected_result = transform(x)
_assert_all_coeffs_equal([expected_result, ], [results[-1], ])
@uses_futures
def test_concurrent_cwt():
atol = rtol = 1e-14
time, sst = pywt.data.nino()
dt = time[1]-time[0]
transform = partial(pywt.cwt, scales=np.arange(1, 4), wavelet='cmor1.5-1',
sampling_period=dt)
for _ in range(10):
arrs = [sst.copy() for _ in range(50)]
with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
results = list(ex.map(transform, arrs))
# validate result from one of the concurrent runs
expected_result = transform(sst)
for a1, a2 in zip(expected_result, results[-1]):
assert_allclose(a1, a2, atol=atol, rtol=rtol)

View file

@ -0,0 +1,434 @@
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
from itertools import product
from numpy.testing import (assert_allclose, assert_warns, assert_almost_equal,
assert_raises, assert_equal)
import pytest
import numpy as np
import pywt
def ref_gaus(LB, UB, N, num):
X = np.linspace(LB, UB, N)
F0 = (2./np.pi)**(1./4.)*np.exp(-(X**2))
if (num == 1):
psi = -2.*X*F0
elif (num == 2):
psi = -2/(3**(1/2))*(-1 + 2*X**2)*F0
elif (num == 3):
psi = -4/(15**(1/2))*X*(3 - 2*X**2)*F0
elif (num == 4):
psi = 4/(105**(1/2))*(3 - 12*X**2 + 4*X**4)*F0
elif (num == 5):
psi = 8/(3*(105**(1/2)))*X*(-15 + 20*X**2 - 4*X**4)*F0
elif (num == 6):
psi = -8/(3*(1155**(1/2)))*(-15 + 90*X**2 - 60*X**4 + 8*X**6)*F0
elif (num == 7):
psi = -16/(3*(15015**(1/2)))*X*(105 - 210*X**2 + 84*X**4 - 8*X**6)*F0
elif (num == 8):
psi = 16/(45*(1001**(1/2)))*(105 - 840*X**2 + 840*X**4 -
224*X**6 + 16*X**8)*F0
return (psi, X)
def ref_cgau(LB, UB, N, num):
X = np.linspace(LB, UB, N)
F0 = np.exp(-X**2)
F1 = np.exp(-1j*X)
F2 = (F1*F0)/(np.exp(-1/2)*2**(1/2)*np.pi**(1/2))**(1/2)
if (num == 1):
psi = F2*(-1j - 2*X)*2**(1/2)
elif (num == 2):
psi = 1/3*F2*(-3 + 4j*X + 4*X**2)*6**(1/2)
elif (num == 3):
psi = 1/15*F2*(7j + 18*X - 12j*X**2 - 8*X**3)*30**(1/2)
elif (num == 4):
psi = 1/105*F2*(25 - 56j*X - 72*X**2 + 32j*X**3 + 16*X**4)*210**(1/2)
elif (num == 5):
psi = 1/315*F2*(-81j - 250*X + 280j*X**2 + 240*X**3 -
80j*X**4 - 32*X**5)*210**(1/2)
elif (num == 6):
psi = 1/3465*F2*(-331 + 972j*X + 1500*X**2 - 1120j*X**3 - 720*X**4 +
192j*X**5 + 64*X**6)*2310**(1/2)
elif (num == 7):
psi = 1/45045*F2*(
1303j + 4634*X - 6804j*X**2 - 7000*X**3 + 3920j*X**4 + 2016*X**5 -
448j*X**6 - 128*X**7)*30030**(1/2)
elif (num == 8):
psi = 1/45045*F2*(
5937 - 20848j*X - 37072*X**2 + 36288j*X**3 + 28000*X**4 -
12544j*X**5 - 5376*X**6 + 1024j*X**7 + 256*X**8)*2002**(1/2)
psi = psi/np.real(np.sqrt(np.real(np.sum(psi*np.conj(psi)))*(X[1] - X[0])))
return (psi, X)
def sinc2(x):
y = np.ones_like(x)
k = np.where(x)[0]
y[k] = np.sin(np.pi*x[k])/(np.pi*x[k])
return y
def ref_shan(LB, UB, N, Fb, Fc):
x = np.linspace(LB, UB, N)
psi = np.sqrt(Fb)*(sinc2(Fb*x)*np.exp(2j*np.pi*Fc*x))
return (psi, x)
def ref_fbsp(LB, UB, N, m, Fb, Fc):
x = np.linspace(LB, UB, N)
psi = np.sqrt(Fb)*((sinc2(Fb*x/m)**m)*np.exp(2j*np.pi*Fc*x))
return (psi, x)
def ref_cmor(LB, UB, N, Fb, Fc):
x = np.linspace(LB, UB, N)
psi = ((np.pi*Fb)**(-0.5))*np.exp(2j*np.pi*Fc*x)*np.exp(-(x**2)/Fb)
return (psi, x)
def ref_morl(LB, UB, N):
x = np.linspace(LB, UB, N)
psi = np.exp(-(x**2)/2)*np.cos(5*x)
return (psi, x)
def ref_mexh(LB, UB, N):
x = np.linspace(LB, UB, N)
psi = (2/(np.sqrt(3)*np.pi**0.25))*np.exp(-(x**2)/2)*(1 - (x**2))
return (psi, x)
def test_gaus():
LB = -5
UB = 5
N = 1000
for num in np.arange(1, 9):
[psi, x] = ref_gaus(LB, UB, N, num)
w = pywt.ContinuousWavelet("gaus" + str(num))
PSI, X = w.wavefun(length=N)
assert_allclose(np.real(PSI), np.real(psi))
assert_allclose(np.imag(PSI), np.imag(psi))
assert_allclose(X, x)
def test_cgau():
LB = -5
UB = 5
N = 1000
for num in np.arange(1, 9):
[psi, x] = ref_cgau(LB, UB, N, num)
w = pywt.ContinuousWavelet("cgau" + str(num))
PSI, X = w.wavefun(length=N)
assert_allclose(np.real(PSI), np.real(psi))
assert_allclose(np.imag(PSI), np.imag(psi))
assert_allclose(X, x)
def test_shan():
LB = -20
UB = 20
N = 1000
Fb = 1
Fc = 1.5
[psi, x] = ref_shan(LB, UB, N, Fb, Fc)
w = pywt.ContinuousWavelet("shan{}-{}".format(Fb, Fc))
assert_almost_equal(w.center_frequency, Fc)
assert_almost_equal(w.bandwidth_frequency, Fb)
w.upper_bound = UB
w.lower_bound = LB
PSI, X = w.wavefun(length=N)
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
assert_allclose(X, x, atol=1e-15)
LB = -20
UB = 20
N = 1000
Fb = 1.5
Fc = 1
[psi, x] = ref_shan(LB, UB, N, Fb, Fc)
w = pywt.ContinuousWavelet("shan{}-{}".format(Fb, Fc))
assert_almost_equal(w.center_frequency, Fc)
assert_almost_equal(w.bandwidth_frequency, Fb)
w.upper_bound = UB
w.lower_bound = LB
PSI, X = w.wavefun(length=N)
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
assert_allclose(X, x, atol=1e-15)
def test_cmor():
LB = -20
UB = 20
N = 1000
Fb = 1
Fc = 1.5
[psi, x] = ref_cmor(LB, UB, N, Fb, Fc)
w = pywt.ContinuousWavelet("cmor{}-{}".format(Fb, Fc))
assert_almost_equal(w.center_frequency, Fc)
assert_almost_equal(w.bandwidth_frequency, Fb)
w.upper_bound = UB
w.lower_bound = LB
PSI, X = w.wavefun(length=N)
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
assert_allclose(X, x, atol=1e-15)
LB = -20
UB = 20
N = 1000
Fb = 1.5
Fc = 1
[psi, x] = ref_cmor(LB, UB, N, Fb, Fc)
w = pywt.ContinuousWavelet("cmor{}-{}".format(Fb, Fc))
assert_almost_equal(w.center_frequency, Fc)
assert_almost_equal(w.bandwidth_frequency, Fb)
w.upper_bound = UB
w.lower_bound = LB
PSI, X = w.wavefun(length=N)
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
assert_allclose(X, x, atol=1e-15)
def test_fbsp():
LB = -20
UB = 20
N = 1000
M = 2
Fb = 1
Fc = 1.5
[psi, x] = ref_fbsp(LB, UB, N, M, Fb, Fc)
w = pywt.ContinuousWavelet("fbsp{}-{}-{}".format(M, Fb, Fc))
assert_almost_equal(w.center_frequency, Fc)
assert_almost_equal(w.bandwidth_frequency, Fb)
w.fbsp_order = M
w.upper_bound = UB
w.lower_bound = LB
PSI, X = w.wavefun(length=N)
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
assert_allclose(X, x, atol=1e-15)
LB = -20
UB = 20
N = 1000
M = 2
Fb = 1.5
Fc = 1
[psi, x] = ref_fbsp(LB, UB, N, M, Fb, Fc)
w = pywt.ContinuousWavelet("fbsp{}-{}-{}".format(M, Fb, Fc))
assert_almost_equal(w.center_frequency, Fc)
assert_almost_equal(w.bandwidth_frequency, Fb)
w.fbsp_order = M
w.upper_bound = UB
w.lower_bound = LB
PSI, X = w.wavefun(length=N)
assert_allclose(np.real(PSI), np.real(psi), atol=1e-15)
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-15)
assert_allclose(X, x, atol=1e-15)
LB = -20
UB = 20
N = 1000
M = 3
Fb = 1.5
Fc = 1.2
[psi, x] = ref_fbsp(LB, UB, N, M, Fb, Fc)
w = pywt.ContinuousWavelet("fbsp{}-{}-{}".format(M, Fb, Fc))
assert_almost_equal(w.center_frequency, Fc)
assert_almost_equal(w.bandwidth_frequency, Fb)
w.fbsp_order = M
w.upper_bound = UB
w.lower_bound = LB
PSI, X = w.wavefun(length=N)
# TODO: investigate why atol = 1e-5 is necessary
assert_allclose(np.real(PSI), np.real(psi), atol=1e-5)
assert_allclose(np.imag(PSI), np.imag(psi), atol=1e-5)
assert_allclose(X, x, atol=1e-15)
def test_morl():
LB = -5
UB = 5
N = 1000
[psi, x] = ref_morl(LB, UB, N)
w = pywt.ContinuousWavelet("morl")
w.upper_bound = UB
w.lower_bound = LB
PSI, X = w.wavefun(length=N)
assert_allclose(np.real(PSI), np.real(psi))
assert_allclose(np.imag(PSI), np.imag(psi))
assert_allclose(X, x)
def test_mexh():
LB = -5
UB = 5
N = 1000
[psi, x] = ref_mexh(LB, UB, N)
w = pywt.ContinuousWavelet("mexh")
w.upper_bound = UB
w.lower_bound = LB
PSI, X = w.wavefun(length=N)
assert_allclose(np.real(PSI), np.real(psi))
assert_allclose(np.imag(PSI), np.imag(psi))
assert_allclose(X, x)
LB = -5
UB = 5
N = 1001
[psi, x] = ref_mexh(LB, UB, N)
w = pywt.ContinuousWavelet("mexh")
w.upper_bound = UB
w.lower_bound = LB
PSI, X = w.wavefun(length=N)
assert_allclose(np.real(PSI), np.real(psi))
assert_allclose(np.imag(PSI), np.imag(psi))
assert_allclose(X, x)
def test_cwt_parameters_in_names():
for func in [pywt.ContinuousWavelet, pywt.DiscreteContinuousWavelet]:
for name in ['fbsp', 'cmor', 'shan']:
# additional parameters should be specified within the name
assert_warns(FutureWarning, func, name)
for name in ['cmor', 'shan']:
# valid names
func(name + '1.5-1.0')
func(name + '1-4')
# invalid names
assert_raises(ValueError, func, name + '1.0')
assert_raises(ValueError, func, name + 'B-C')
assert_raises(ValueError, func, name + '1.0-1.0-1.0')
# valid names
func('fbsp1-1.5-1.0')
func('fbsp1.0-1.5-1')
func('fbsp2-5-1')
# invalid name (non-integer order)
assert_raises(ValueError, func, 'fbsp1.5-1-1')
assert_raises(ValueError, func, 'fbspM-B-C')
# invalid name (too few or too many params)
assert_raises(ValueError, func, 'fbsp1.0')
assert_raises(ValueError, func, 'fbsp1.0-0.4')
assert_raises(ValueError, func, 'fbsp1-1-1-1')
@pytest.mark.parametrize('dtype, tol, method',
[(np.float32, 1e-5, 'conv'),
(np.float32, 1e-5, 'fft'),
(np.float64, 1e-13, 'conv'),
(np.float64, 1e-13, 'fft')])
def test_cwt_complex(dtype, tol, method):
time, sst = pywt.data.nino()
sst = np.asarray(sst, dtype=dtype)
dt = time[1] - time[0]
wavelet = 'cmor1.5-1.0'
scales = np.arange(1, 32)
# real-valued tranfsorm as a reference
[cfs, f] = pywt.cwt(sst, scales, wavelet, dt, method=method)
# verify same precision
assert_equal(cfs.real.dtype, sst.dtype)
# complex-valued transform equals sum of the transforms of the real
# and imaginary components
sst_complex = sst + 1j*sst
[cfs_complex, f] = pywt.cwt(sst_complex, scales, wavelet, dt,
method=method)
assert_allclose(cfs + 1j*cfs, cfs_complex, atol=tol, rtol=tol)
# verify dtype is preserved
assert_equal(cfs_complex.dtype, sst_complex.dtype)
@pytest.mark.parametrize('axis, method', product([0, 1], ['conv', 'fft']))
def test_cwt_batch(axis, method):
dtype = np.float64
time, sst = pywt.data.nino()
n_batch = 8
batch_axis = 1 - axis
sst1 = np.asarray(sst, dtype=dtype)
sst = np.stack((sst1, ) * n_batch, axis=batch_axis)
dt = time[1] - time[0]
wavelet = 'cmor1.5-1.0'
scales = np.arange(1, 32)
# non-batch transform as reference
[cfs1, f] = pywt.cwt(sst1, scales, wavelet, dt, method=method, axis=axis)
shape_in = sst.shape
[cfs, f] = pywt.cwt(sst, scales, wavelet, dt, method=method, axis=axis)
# shape of input is not modified
assert_equal(shape_in, sst.shape)
# verify same precision
assert_equal(cfs.real.dtype, sst.dtype)
# verify expected shape
assert_equal(cfs.shape[0], len(scales))
assert_equal(cfs.shape[1 + batch_axis], n_batch)
assert_equal(cfs.shape[1 + axis], sst.shape[axis])
# batch result on stacked input is the same as stacked 1d result
assert_equal(cfs, np.stack((cfs1,) * n_batch, axis=batch_axis + 1))
def test_cwt_small_scales():
data = np.zeros(32)
# A scale of 0.1 was chosen specifically to give a filter of length 2 for
# mexh. This corner case should not raise an error.
cfs, f = pywt.cwt(data, scales=0.1, wavelet='mexh')
assert_allclose(cfs, np.zeros_like(cfs))
# extremely short scale factors raise a ValueError
assert_raises(ValueError, pywt.cwt, data, scales=0.01, wavelet='mexh')
def test_cwt_method_fft():
rstate = np.random.RandomState(1)
data = rstate.randn(50)
data[15] = 1.
scales = np.arange(1, 64)
wavelet = 'cmor1.5-1.0'
# build a reference cwt with the legacy np.conv() method
cfs_conv, _ = pywt.cwt(data, scales, wavelet, method='conv')
# compare with the fft based convolution
cfs_fft, _ = pywt.cwt(data, scales, wavelet, method='fft')
assert_allclose(cfs_conv, cfs_fft, rtol=0, atol=1e-13)

View file

@ -0,0 +1,77 @@
import os
import numpy as np
from numpy.testing import assert_allclose, assert_raises, assert_
import pywt.data
data_dir = os.path.join(os.path.dirname(__file__), 'data')
wavelab_data_file = os.path.join(data_dir, 'wavelab_test_signals.npz')
wavelab_result_dict = np.load(wavelab_data_file)
def test_data_aero():
aero = pywt.data.aero()
ref = np.array([[178, 178, 179],
[170, 173, 171],
[185, 174, 171]])
assert_allclose(aero[:3, :3], ref)
def test_data_ascent():
ascent = pywt.data.ascent()
ref = np.array([[83, 83, 83],
[82, 82, 83],
[80, 81, 83]])
assert_allclose(ascent[:3, :3], ref)
def test_data_camera():
ascent = pywt.data.camera()
ref = np.array([[156, 157, 160],
[156, 157, 159],
[158, 157, 156]])
assert_allclose(ascent[:3, :3], ref)
def test_data_ecg():
ecg = pywt.data.ecg()
ref = np.array([-86, -87, -87])
assert_allclose(ecg[:3], ref)
def test_wavelab_signals():
"""Comparison with results generated using WaveLab"""
rtol = atol = 1e-12
# get a list of the available signals
available_signals = pywt.data.demo_signal('list')
assert_('Doppler' in available_signals)
for signal in available_signals:
# reference dictionary has lowercase names for the keys
key = signal.replace('-', '_').lower()
val = wavelab_result_dict[key]
if key in ['gabor', 'sineoneoverx']:
# these functions do not allow a size to be provided
assert_allclose(val, pywt.data.demo_signal(signal),
rtol=rtol, atol=atol)
assert_raises(ValueError, pywt.data.demo_signal, key, val.size)
else:
assert_allclose(val, pywt.data.demo_signal(signal, val.size),
rtol=rtol, atol=atol)
# these functions require a size to be provided
assert_raises(ValueError, pywt.data.demo_signal, key)
# ValueError on unrecognized signal type
assert_raises(ValueError, pywt.data.demo_signal, 'unknown_signal', 512)
# ValueError on invalid length
assert_raises(ValueError, pywt.data.demo_signal, 'Doppler', 0)

View file

@ -0,0 +1,89 @@
import warnings
import numpy as np
from numpy.testing import assert_warns, assert_array_equal
import pywt
def test_intwave_deprecation():
wavelet = pywt.Wavelet('db3')
assert_warns(DeprecationWarning, pywt.intwave, wavelet)
def test_centrfrq_deprecation():
wavelet = pywt.Wavelet('db3')
assert_warns(DeprecationWarning, pywt.centrfrq, wavelet)
def test_scal2frq_deprecation():
wavelet = pywt.Wavelet('db3')
assert_warns(DeprecationWarning, pywt.scal2frq, wavelet, 1)
def test_orthfilt_deprecation():
assert_warns(DeprecationWarning, pywt.orthfilt, range(6))
def test_integrate_wave_tuple():
sig = [0, 1, 2, 3]
xgrid = [0, 1, 2, 3]
assert_warns(DeprecationWarning, pywt.integrate_wavelet, (sig, xgrid))
old_modes = ['zpd',
'cpd',
'sym',
'ppd',
'sp1',
'per',
]
def test_MODES_from_object_deprecation():
for mode in old_modes:
assert_warns(DeprecationWarning, pywt.Modes.from_object, mode)
def test_MODES_attributes_deprecation():
def get_mode(Modes, name):
return getattr(Modes, name)
for mode in old_modes:
assert_warns(DeprecationWarning, get_mode, pywt.Modes, mode)
def test_MODES_deprecation_new():
def use_MODES_new():
return pywt.MODES.symmetric
assert_warns(DeprecationWarning, use_MODES_new)
def test_MODES_deprecation_old():
def use_MODES_old():
return pywt.MODES.sym
assert_warns(DeprecationWarning, use_MODES_old)
def test_MODES_deprecation_getattr():
def use_MODES_new():
return getattr(pywt.MODES, 'symmetric')
assert_warns(DeprecationWarning, use_MODES_new)
def test_mode_equivalence():
old_new = [('zpd', 'zero'),
('cpd', 'constant'),
('sym', 'symmetric'),
('ppd', 'periodic'),
('sp1', 'smooth'),
('per', 'periodization')]
x = np.arange(8.)
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
for old, new in old_new:
assert_array_equal(pywt.dwt(x, 'db2', mode=old),
pywt.dwt(x, 'db2', mode=new))

View file

@ -0,0 +1,25 @@
from __future__ import division, print_function, absolute_import
import doctest
import glob
import os
import unittest
try:
import numpy as np
np.set_printoptions(legacy='1.13')
except TypeError:
pass
pdir = os.path.pardir
docs_base = os.path.abspath(os.path.join(os.path.dirname(__file__),
pdir, pdir, "doc", "source"))
files = glob.glob(os.path.join(docs_base, "*.rst")) + \
glob.glob(os.path.join(docs_base, "*", "*.rst"))
suite = doctest.DocFileSuite(*files, module_relative=False, encoding="utf-8")
if __name__ == "__main__":
unittest.TextTestRunner().run(suite)

View file

@ -0,0 +1,299 @@
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
import numpy as np
from numpy.testing import (assert_allclose, assert_, assert_raises,
assert_array_equal)
import pywt
# Check that float32, float64, complex64, complex128 are preserved.
# Other real types get converted to float64.
# complex256 gets converted to complex128
dtypes_in = [np.int8, np.float16, np.float32, np.float64, np.complex64,
np.complex128]
dtypes_out = [np.float64, np.float32, np.float32, np.float64, np.complex64,
np.complex128]
# test complex256 as well if it is available
try:
dtypes_in += [np.complex256, ]
dtypes_out += [np.complex128, ]
except AttributeError:
pass
def test_dwt_idwt_basic():
x = [3, 7, 1, 1, -2, 5, 4, 6]
cA, cD = pywt.dwt(x, 'db2')
cA_expect = [5.65685425, 7.39923721, 0.22414387, 3.33677403, 7.77817459]
cD_expect = [-2.44948974, -1.60368225, -4.44140056, -0.41361256,
1.22474487]
assert_allclose(cA, cA_expect)
assert_allclose(cD, cD_expect)
x_roundtrip = pywt.idwt(cA, cD, 'db2')
assert_allclose(x_roundtrip, x, rtol=1e-10)
# mismatched dtypes OK
x_roundtrip2 = pywt.idwt(cA.astype(np.float64), cD.astype(np.float32),
'db2')
assert_allclose(x_roundtrip2, x, rtol=1e-7, atol=1e-7)
assert_(x_roundtrip2.dtype == np.float64)
def test_idwt_mixed_complex_dtype():
x = np.arange(8).astype(float)
x = x + 1j*x[::-1]
cA, cD = pywt.dwt(x, 'db2')
x_roundtrip = pywt.idwt(cA, cD, 'db2')
assert_allclose(x_roundtrip, x, rtol=1e-10)
# mismatched dtypes OK
x_roundtrip2 = pywt.idwt(cA.astype(np.complex128), cD.astype(np.complex64),
'db2')
assert_allclose(x_roundtrip2, x, rtol=1e-7, atol=1e-7)
assert_(x_roundtrip2.dtype == np.complex128)
def test_dwt_idwt_dtypes():
wavelet = pywt.Wavelet('haar')
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
x = np.ones(4, dtype=dt_in)
errmsg = "wrong dtype returned for {0} input".format(dt_in)
cA, cD = pywt.dwt(x, wavelet)
assert_(cA.dtype == cD.dtype == dt_out, "dwt: " + errmsg)
x_roundtrip = pywt.idwt(cA, cD, wavelet)
assert_(x_roundtrip.dtype == dt_out, "idwt: " + errmsg)
def test_dwt_idwt_basic_complex():
x = np.asarray([3, 7, 1, 1, -2, 5, 4, 6])
x = x + 0.5j*x
cA, cD = pywt.dwt(x, 'db2')
cA_expect = np.asarray([5.65685425, 7.39923721, 0.22414387, 3.33677403,
7.77817459])
cA_expect = cA_expect + 0.5j*cA_expect
cD_expect = np.asarray([-2.44948974, -1.60368225, -4.44140056, -0.41361256,
1.22474487])
cD_expect = cD_expect + 0.5j*cD_expect
assert_allclose(cA, cA_expect)
assert_allclose(cD, cD_expect)
x_roundtrip = pywt.idwt(cA, cD, 'db2')
assert_allclose(x_roundtrip, x, rtol=1e-10)
def test_dwt_idwt_partial_complex():
x = np.asarray([3, 7, 1, 1, -2, 5, 4, 6])
x = x + 0.5j*x
cA, cD = pywt.dwt(x, 'haar')
cA_rec_expect = np.array([5.0+2.5j, 5.0+2.5j, 1.0+0.5j, 1.0+0.5j,
1.5+0.75j, 1.5+0.75j, 5.0+2.5j, 5.0+2.5j])
cA_rec = pywt.idwt(cA, None, 'haar')
assert_allclose(cA_rec, cA_rec_expect)
cD_rec_expect = np.array([-2.0-1.0j, 2.0+1.0j, 0.0+0.0j, 0.0+0.0j,
-3.5-1.75j, 3.5+1.75j, -1.0-0.5j, 1.0+0.5j])
cD_rec = pywt.idwt(None, cD, 'haar')
assert_allclose(cD_rec, cD_rec_expect)
assert_allclose(cA_rec + cD_rec, x)
def test_dwt_wavelet_kwd():
x = np.array([3, 7, 1, 1, -2, 5, 4, 6])
w = pywt.Wavelet('sym3')
cA, cD = pywt.dwt(x, wavelet=w, mode='constant')
cA_expect = [4.38354585, 3.80302657, 7.31813271, -0.58565539, 4.09727044,
7.81994027]
cD_expect = [-1.33068221, -2.78795192, -3.16825651, -0.67715519,
-0.09722957, -0.07045258]
assert_allclose(cA, cA_expect)
assert_allclose(cD, cD_expect)
def test_dwt_coeff_len():
x = np.array([3, 7, 1, 1, -2, 5, 4, 6])
w = pywt.Wavelet('sym3')
ln_modes = [pywt.dwt_coeff_len(len(x), w.dec_len, mode) for mode in
pywt.Modes.modes]
expected_result = [6, ] * len(pywt.Modes.modes)
expected_result[pywt.Modes.modes.index('periodization')] = 4
assert_allclose(ln_modes, expected_result)
ln_modes = [pywt.dwt_coeff_len(len(x), w, mode) for mode in
pywt.Modes.modes]
assert_allclose(ln_modes, expected_result)
def test_idwt_none_input():
# None input equals arrays of zeros of the right length
res1 = pywt.idwt([1, 2, 0, 1], None, 'db2', 'symmetric')
res2 = pywt.idwt([1, 2, 0, 1], [0, 0, 0, 0], 'db2', 'symmetric')
assert_allclose(res1, res2, rtol=1e-15, atol=1e-15)
res1 = pywt.idwt(None, [1, 2, 0, 1], 'db2', 'symmetric')
res2 = pywt.idwt([0, 0, 0, 0], [1, 2, 0, 1], 'db2', 'symmetric')
assert_allclose(res1, res2, rtol=1e-15, atol=1e-15)
# Only one argument at a time can be None
assert_raises(ValueError, pywt.idwt, None, None, 'db2', 'symmetric')
def test_idwt_invalid_input():
# Too short, min length is 4 for 'db4':
assert_raises(ValueError, pywt.idwt, [1, 2, 4], [4, 1, 3], 'db4', 'symmetric')
def test_dwt_single_axis():
x = [[3, 7, 1, 1],
[-2, 5, 4, 6]]
cA, cD = pywt.dwt(x, 'db2', axis=-1)
cA0, cD0 = pywt.dwt(x[0], 'db2')
cA1, cD1 = pywt.dwt(x[1], 'db2')
assert_allclose(cA[0], cA0)
assert_allclose(cA[1], cA1)
assert_allclose(cD[0], cD0)
assert_allclose(cD[1], cD1)
def test_idwt_single_axis():
x = [[3, 7, 1, 1],
[-2, 5, 4, 6]]
x = np.asarray(x)
x = x + 1j*x # test with complex data
cA, cD = pywt.dwt(x, 'db2', axis=-1)
x0 = pywt.idwt(cA[0], cD[0], 'db2', axis=-1)
x1 = pywt.idwt(cA[1], cD[1], 'db2', axis=-1)
assert_allclose(x[0], x0)
assert_allclose(x[1], x1)
def test_dwt_axis_arg():
x = [[3, 7, 1, 1],
[-2, 5, 4, 6]]
cA_, cD_ = pywt.dwt(x, 'db2', axis=-1)
cA, cD = pywt.dwt(x, 'db2', axis=1)
assert_allclose(cA_, cA)
assert_allclose(cD_, cD)
def test_idwt_axis_arg():
x = [[3, 7, 1, 1],
[-2, 5, 4, 6]]
cA, cD = pywt.dwt(x, 'db2', axis=1)
x_ = pywt.idwt(cA, cD, 'db2', axis=-1)
x = pywt.idwt(cA, cD, 'db2', axis=1)
assert_allclose(x_, x)
def test_dwt_idwt_axis_excess():
x = [[3, 7, 1, 1],
[-2, 5, 4, 6]]
# can't transform over axes that aren't there
assert_raises(ValueError,
pywt.dwt, x, 'db2', 'symmetric', axis=2)
assert_raises(ValueError,
pywt.idwt, [1, 2, 4], [4, 1, 3], 'db2', 'symmetric', axis=1)
def test_error_on_continuous_wavelet():
# A ValueError is raised if a Continuous wavelet is selected
data = np.ones((32, ))
for cwave in ['morl', pywt.DiscreteContinuousWavelet('morl')]:
assert_raises(ValueError, pywt.dwt, data, cwave)
cA, cD = pywt.dwt(data, 'db1')
assert_raises(ValueError, pywt.idwt, cA, cD, cwave)
def test_dwt_zero_size_axes():
# raise on empty input array
assert_raises(ValueError, pywt.dwt, [], 'db2')
# >1D case uses a different code path so check there as well
x = np.ones((1, 4))[0:0, :] # 2D with a size zero axis
assert_raises(ValueError, pywt.dwt, x, 'db2', axis=0)
def test_pad_1d():
x = [1, 2, 3]
assert_array_equal(pywt.pad(x, (4, 6), 'periodization'),
[1, 2, 3, 3, 1, 2, 3, 3, 1, 2, 3, 3, 1, 2])
assert_array_equal(pywt.pad(x, (4, 6), 'periodic'),
[3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3])
assert_array_equal(pywt.pad(x, (4, 6), 'constant'),
[1, 1, 1, 1, 1, 2, 3, 3, 3, 3, 3, 3, 3])
assert_array_equal(pywt.pad(x, (4, 6), 'zero'),
[0, 0, 0, 0, 1, 2, 3, 0, 0, 0, 0, 0, 0])
assert_array_equal(pywt.pad(x, (4, 6), 'smooth'),
[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
assert_array_equal(pywt.pad(x, (4, 6), 'symmetric'),
[3, 3, 2, 1, 1, 2, 3, 3, 2, 1, 1, 2, 3])
assert_array_equal(pywt.pad(x, (4, 6), 'antisymmetric'),
[3, -3, -2, -1, 1, 2, 3, -3, -2, -1, 1, 2, 3])
assert_array_equal(pywt.pad(x, (4, 6), 'reflect'),
[1, 2, 3, 2, 1, 2, 3, 2, 1, 2, 3, 2, 1])
assert_array_equal(pywt.pad(x, (4, 6), 'antireflect'),
[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
# equivalence of various pad_width formats
assert_array_equal(pywt.pad(x, 4, 'periodic'),
pywt.pad(x, (4, 4), 'periodic'))
assert_array_equal(pywt.pad(x, (4, ), 'periodic'),
pywt.pad(x, (4, 4), 'periodic'))
assert_array_equal(pywt.pad(x, [(4, 4)], 'periodic'),
pywt.pad(x, (4, 4), 'periodic'))
def test_pad_errors():
# negative pad width
x = [1, 2, 3]
assert_raises(ValueError, pywt.pad, x, -2, 'periodic')
# wrong length pad width
assert_raises(ValueError, pywt.pad, x, (1, 1, 1), 'periodic')
# invalid mode name
assert_raises(ValueError, pywt.pad, x, 2, 'bad_mode')
def test_pad_nd():
for ndim in [2, 3]:
x = np.arange(4**ndim).reshape((4, ) * ndim)
if ndim == 2:
pad_widths = [(2, 1), (2, 3)]
else:
pad_widths = [(2, 1), ] * ndim
for mode in pywt.Modes.modes:
xp = pywt.pad(x, pad_widths, mode)
# expected result is the same as applying along axes separably
xp_expected = x.copy()
for ax in range(ndim):
xp_expected = np.apply_along_axis(pywt.pad,
ax,
xp_expected,
pad_widths=[pad_widths[ax]],
mode=mode)
assert_array_equal(xp, xp_expected)

View file

@ -0,0 +1,38 @@
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
from numpy.testing import assert_almost_equal, assert_allclose
import pywt
def test_centrfreq():
# db1 is Haar function, frequency=1
w = pywt.Wavelet('db1')
expected = 1
result = pywt.central_frequency(w, precision=12)
assert_almost_equal(result, expected, decimal=3)
# db2, frequency=2/3
w = pywt.Wavelet('db2')
expected = 2/3.
result = pywt.central_frequency(w, precision=12)
assert_almost_equal(result, expected)
def test_scal2frq_scale():
scale = 2
w = pywt.Wavelet('db1')
expected = 1. / scale
result = pywt.scale2frequency(w, scale, precision=12)
assert_almost_equal(result, expected, decimal=3)
def test_intwave_orthogonal():
w = pywt.Wavelet('db1')
int_psi, x = pywt.integrate_wavelet(w, precision=12)
ix = x < 0.5
# For x < 0.5, the integral is equal to x
assert_allclose(int_psi[ix], x[ix])
# For x > 0.5, the integral is equal to (1 - x)
# Ignore last point here, there x > 1 and something goes wrong
assert_allclose(int_psi[~ix][:-1], 1 - x[~ix][:-1], atol=1e-10)

View file

@ -0,0 +1,160 @@
"""
Test used to verify PyWavelets Discrete Wavelet Transform computation
accuracy against MathWorks Wavelet Toolbox.
"""
from __future__ import division, print_function, absolute_import
import numpy as np
import pytest
from numpy.testing import assert_
import pywt
from pywt._pytest import (uses_pymatbridge, uses_precomputed, size_set)
from pywt._pytest import matlab_result_dict_dwt as matlab_result_dict
# list of mode names in pywt and matlab
modes = [('zero', 'zpd'),
('constant', 'sp0'),
('symmetric', 'sym'),
('reflect', 'symw'),
('periodic', 'ppd'),
('smooth', 'sp1'),
('periodization', 'per'),
# TODO: Now have implemented asymmetric modes too.
# Would be nice to update the Matlab data to test these as well.
('antisymmetric', 'asym'),
('antireflect', 'asymw'),
]
families = ('db', 'sym', 'coif', 'bior', 'rbio')
wavelets = sum([pywt.wavelist(name) for name in families], [])
def _get_data_sizes(w):
""" Return the sizes to test for wavelet w. """
if size_set == 'full':
data_sizes = list(range(w.dec_len, 40)) + \
[100, 200, 500, 1000, 50000]
else:
data_sizes = (w.dec_len, w.dec_len + 1)
return data_sizes
@uses_pymatbridge
@pytest.mark.slow
def test_accuracy_pymatbridge():
Matlab = pytest.importorskip("pymatbridge.Matlab")
mlab = Matlab()
rstate = np.random.RandomState(1234)
# max RMSE (was 1.0e-10, is reduced to 5.0e-5 due to different coefficents)
epsilon = 5.0e-5
epsilon_pywt_coeffs = 1.0e-10
mlab.start()
try:
for wavelet in wavelets:
w = pywt.Wavelet(wavelet)
mlab.set_variable('wavelet', wavelet)
for N in _get_data_sizes(w):
data = rstate.randn(N)
mlab.set_variable('data', data)
for pmode, mmode in modes:
ma, md = _compute_matlab_result(data, wavelet, mmode, mlab)
_check_accuracy(data, w, pmode, ma, md, wavelet, epsilon)
ma, md = _load_matlab_result_pywt_coeffs(data, wavelet, mmode)
_check_accuracy(data, w, pmode, ma, md, wavelet, epsilon_pywt_coeffs)
finally:
mlab.stop()
@uses_precomputed
@pytest.mark.slow
def test_accuracy_precomputed():
# Keep this specific random seed to match the precomputed Matlab result.
rstate = np.random.RandomState(1234)
# max RMSE (was 1.0e-10, is reduced to 5.0e-5 due to different coefficents)
epsilon = 5.0e-5
epsilon_pywt_coeffs = 1.0e-10
for wavelet in wavelets:
w = pywt.Wavelet(wavelet)
for N in _get_data_sizes(w):
data = rstate.randn(N)
for pmode, mmode in modes:
ma, md = _load_matlab_result(data, wavelet, mmode)
_check_accuracy(data, w, pmode, ma, md, wavelet, epsilon)
ma, md = _load_matlab_result_pywt_coeffs(data, wavelet, mmode)
_check_accuracy(data, w, pmode, ma, md, wavelet, epsilon_pywt_coeffs)
def _compute_matlab_result(data, wavelet, mmode, mlab):
""" Compute the result using MATLAB.
This function assumes that the Matlab variables `wavelet` and `data` have
already been set externally.
"""
if np.any((wavelet == np.array(['coif6', 'coif7', 'coif8', 'coif9', 'coif10', 'coif11', 'coif12', 'coif13', 'coif14', 'coif15', 'coif16', 'coif17'])),axis=0):
w = pywt.Wavelet(wavelet)
mlab.set_variable('Lo_D', w.dec_lo)
mlab.set_variable('Hi_D', w.dec_hi)
mlab_code = ("[ma, md] = dwt(data, Lo_D, Hi_D, 'mode', '%s');" % mmode)
else:
mlab_code = "[ma, md] = dwt(data, wavelet, 'mode', '%s');" % mmode
res = mlab.run_code(mlab_code)
if not res['success']:
raise RuntimeError("Matlab failed to execute the provided code. "
"Check that the wavelet toolbox is installed.")
# need np.asarray because sometimes the output is a single float64
ma = np.asarray(mlab.get_variable('ma'))
md = np.asarray(mlab.get_variable('md'))
return ma, md
def _load_matlab_result(data, wavelet, mmode):
""" Load the precomputed result.
"""
N = len(data)
ma_key = '_'.join([mmode, wavelet, str(N), 'ma'])
md_key = '_'.join([mmode, wavelet, str(N), 'md'])
if (ma_key not in matlab_result_dict) or \
(md_key not in matlab_result_dict):
raise KeyError(
"Precompted Matlab result not found for wavelet: "
"{0}, mode: {1}, size: {2}".format(wavelet, mmode, N))
ma = matlab_result_dict[ma_key]
md = matlab_result_dict[md_key]
return ma, md
def _load_matlab_result_pywt_coeffs(data, wavelet, mmode):
""" Load the precomputed result.
"""
N = len(data)
ma_key = '_'.join([mmode, wavelet, str(N), 'ma_pywtCoeffs'])
md_key = '_'.join([mmode, wavelet, str(N), 'md_pywtCoeffs'])
if (ma_key not in matlab_result_dict) or \
(md_key not in matlab_result_dict):
raise KeyError(
"Precompted Matlab result not found for wavelet: "
"{0}, mode: {1}, size: {2}".format(wavelet, mmode, N))
ma = matlab_result_dict[ma_key]
md = matlab_result_dict[md_key]
return ma, md
def _check_accuracy(data, w, pmode, ma, md, wavelet, epsilon):
# PyWavelets result
pa, pd = pywt.dwt(data, w, pmode)
# calculate error measures
rms_a = np.sqrt(np.mean((pa - ma) ** 2))
rms_d = np.sqrt(np.mean((pd - md) ** 2))
msg = ('[RMS_A > EPSILON] for Mode: %s, Wavelet: %s, '
'Length: %d, rms=%.3g' % (pmode, wavelet, len(data), rms_a))
assert_(rms_a < epsilon, msg=msg)
msg = ('[RMS_D > EPSILON] for Mode: %s, Wavelet: %s, '
'Length: %d, rms=%.3g' % (pmode, wavelet, len(data), rms_d))
assert_(rms_d < epsilon, msg=msg)

View file

@ -0,0 +1,174 @@
"""
Test used to verify PyWavelets Continuous Wavelet Transform computation
accuracy against MathWorks Wavelet Toolbox.
"""
from __future__ import division, print_function, absolute_import
import warnings
import numpy as np
import pytest
from numpy.testing import assert_
import pywt
from pywt._pytest import (uses_pymatbridge, uses_precomputed, size_set,
matlab_result_dict_cwt)
families = ('gaus', 'mexh', 'morl', 'cgau', 'shan', 'fbsp', 'cmor')
wavelets = sum([pywt.wavelist(name) for name in families], [])
def _get_data_sizes(w):
""" Return the sizes to test for wavelet w. """
if size_set == 'full':
data_sizes = list(range(100, 101)) + \
[100, 200, 500, 1000, 50000]
else:
data_sizes = (1000, 1000 + 1)
return data_sizes
def _get_scales(w):
""" Return the scales to test for wavelet w. """
if size_set == 'full':
scales = (1, np.arange(1, 3), np.arange(1, 4), np.arange(1, 5))
else:
scales = (1, np.arange(1, 3))
return scales
@uses_pymatbridge # skip this case if precomputed results are used instead
@pytest.mark.slow
def test_accuracy_pymatbridge_cwt():
Matlab = pytest.importorskip("pymatbridge.Matlab")
mlab = Matlab()
rstate = np.random.RandomState(1234)
# max RMSE (was 1.0e-10, is reduced to 5.0e-5 due to different coefficents)
epsilon = 1e-15
epsilon_psi = 1e-15
mlab.start()
try:
for wavelet in wavelets:
with warnings.catch_warnings():
warnings.simplefilter('ignore', FutureWarning)
w = pywt.ContinuousWavelet(wavelet)
if np.any((wavelet == np.array(['shan', 'cmor'])),axis=0):
mlab.set_variable('wavelet', wavelet+str(w.bandwidth_frequency)+'-'+str(w.center_frequency))
elif wavelet == 'fbsp':
mlab.set_variable('wavelet', wavelet+str(w.fbsp_order)+'-'+str(w.bandwidth_frequency)+'-'+str(w.center_frequency))
else:
mlab.set_variable('wavelet', wavelet)
mlab_code = ("psi = wavefun(wavelet,10)")
res = mlab.run_code(mlab_code)
psi = np.asarray(mlab.get_variable('psi'))
_check_accuracy_psi(w, psi, wavelet, epsilon_psi)
for N in _get_data_sizes(w):
data = rstate.randn(N)
mlab.set_variable('data', data)
for scales in _get_scales(w):
coefs = _compute_matlab_result(data, wavelet, scales, mlab)
_check_accuracy(data, w, scales, coefs, wavelet, epsilon)
finally:
mlab.stop()
@uses_precomputed # skip this case if pymatbridge + Matlab are being used
@pytest.mark.slow
def test_accuracy_precomputed_cwt():
# Keep this specific random seed to match the precomputed Matlab result.
rstate = np.random.RandomState(1234)
# has to be improved
epsilon = 2e-15
epsilon32 = 1e-5
epsilon_psi = 1e-15
for wavelet in wavelets:
with warnings.catch_warnings():
warnings.simplefilter('ignore', FutureWarning)
w = pywt.ContinuousWavelet(wavelet)
w32 = pywt.ContinuousWavelet(wavelet,dtype=np.float32)
psi = _load_matlab_result_psi(wavelet)
_check_accuracy_psi(w, psi, wavelet, epsilon_psi)
for N in _get_data_sizes(w):
data = rstate.randn(N)
data32 = data.astype(np.float32)
scales_count = 0
for scales in _get_scales(w):
scales_count += 1
coefs = _load_matlab_result(data, wavelet, scales_count)
_check_accuracy(data, w, scales, coefs, wavelet, epsilon)
_check_accuracy(data32, w32, scales, coefs, wavelet, epsilon32)
def _compute_matlab_result(data, wavelet, scales, mlab):
""" Compute the result using MATLAB.
This function assumes that the Matlab variables `wavelet` and `data` have
already been set externally.
"""
mlab.set_variable('scales', scales)
mlab_code = ("coefs = cwt(data, scales, wavelet)")
res = mlab.run_code(mlab_code)
if not res['success']:
raise RuntimeError("Matlab failed to execute the provided code. "
"Check that the wavelet toolbox is installed.")
# need np.asarray because sometimes the output is a single float64
coefs = np.asarray(mlab.get_variable('coefs'))
return coefs
def _load_matlab_result(data, wavelet, scales):
""" Load the precomputed result.
"""
N = len(data)
coefs_key = '_'.join([str(scales), wavelet, str(N), 'coefs'])
if (coefs_key not in matlab_result_dict_cwt):
raise KeyError(
"Precompted Matlab result not found for wavelet: "
"{0}, mode: {1}, size: {2}".format(wavelet, scales, N))
coefs = matlab_result_dict_cwt[coefs_key]
return coefs
def _load_matlab_result_psi(wavelet):
""" Load the precomputed result.
"""
psi_key = '_'.join([wavelet, 'psi'])
if (psi_key not in matlab_result_dict_cwt):
raise KeyError(
"Precompted Matlab psi result not found for wavelet: "
"{0}}".format(wavelet))
psi = matlab_result_dict_cwt[psi_key]
return psi
def _check_accuracy(data, w, scales, coefs, wavelet, epsilon):
# PyWavelets result
coefs_pywt, freq = pywt.cwt(data, scales, w)
# coefs from Matlab are from R2012a which is missing the complex conjugate
# as shown in Eq. 2 of Torrence and Compo. We take the complex conjugate of
# the precomputed Matlab result to account for this.
coefs = np.conj(coefs)
# calculate error measures
err = coefs_pywt - coefs
rms = np.real(np.sqrt(np.mean(np.conj(err) * err)))
msg = ('[RMS > EPSILON] for Scale: %s, Wavelet: %s, '
'Length: %d, rms=%.3g' % (scales, wavelet, len(data), rms))
assert_(rms < epsilon, msg=msg)
def _check_accuracy_psi(w, psi, wavelet, epsilon):
# PyWavelets result
psi_pywt, x = w.wavefun(length=1024)
# calculate error measures
err = psi_pywt.flatten() - psi.flatten()
rms = np.real(np.sqrt(np.mean(np.conj(err) * err)))
msg = ('[RMS > EPSILON] for Wavelet: %s, '
'rms=%.3g' % (wavelet, rms))
assert_(rms < epsilon, msg=msg)

View file

@ -0,0 +1,109 @@
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
import numpy as np
from numpy.testing import assert_raises, assert_equal, assert_allclose
import pywt
def test_available_modes():
modes = ['zero', 'constant', 'symmetric', 'periodic', 'smooth',
'periodization', 'reflect', 'antisymmetric', 'antireflect']
assert_equal(pywt.Modes.modes, modes)
assert_equal(pywt.Modes.from_object('constant'), 2)
def test_invalid_modes():
x = np.arange(4)
assert_raises(ValueError, pywt.dwt, x, 'db2', 'unknown')
assert_raises(ValueError, pywt.dwt, x, 'db2', -1)
assert_raises(ValueError, pywt.dwt, x, 'db2', 9)
assert_raises(TypeError, pywt.dwt, x, 'db2', None)
assert_raises(ValueError, pywt.Modes.from_object, 'unknown')
assert_raises(ValueError, pywt.Modes.from_object, -1)
assert_raises(ValueError, pywt.Modes.from_object, 9)
assert_raises(TypeError, pywt.Modes.from_object, None)
def test_dwt_idwt_allmodes():
# Test that :func:`dwt` and :func:`idwt` can be performed using every mode
x = [1, 2, 1, 5, -1, 8, 4, 6]
dwt_results = {
'zero': ([-0.03467518, 1.73309178, 3.40612438, 6.32928585, 6.95094948],
[-0.12940952, -2.15599552, -5.95034847, -1.21545369,
-1.8625013]),
'constant': ([1.28480404, 1.73309178, 3.40612438, 6.32928585,
7.51935555],
[-0.48296291, -2.15599552, -5.95034847, -1.21545369,
0.25881905]),
'symmetric': ([1.76776695, 1.73309178, 3.40612438, 6.32928585,
7.77817459],
[-0.61237244, -2.15599552, -5.95034847, -1.21545369,
1.22474487]),
'reflect': ([2.12132034, 1.73309178, 3.40612438, 6.32928585,
6.81224877],
[-0.70710678, -2.15599552, -5.95034847, -1.21545369,
-2.38013939]),
'periodic': ([6.9162743, 1.73309178, 3.40612438, 6.32928585,
6.9162743],
[-1.99191082, -2.15599552, -5.95034847, -1.21545369,
-1.99191082]),
'smooth': ([-0.51763809, 1.73309178, 3.40612438, 6.32928585,
7.45000519],
[0, -2.15599552, -5.95034847, -1.21545369, 0]),
'periodization': ([4.053172, 3.05257099, 2.85381112, 8.42522221],
[0.18946869, 4.18258152, 4.33737503, 2.60428326]),
'antisymmetric': ([-1.83711731, 1.73309178, 3.40612438, 6.32928585,
6.12372436],
[0.353553391, -2.15599552, -5.95034847, -1.21545369,
-4.94974747]),
'antireflect': ([0.44828774, 1.73309178, 3.40612438, 6.32928585,
8.22646233],
[-0.25881905, -2.15599552, -5.95034847, -1.21545369,
2.89777748])
}
for mode in pywt.Modes.modes:
cA, cD = pywt.dwt(x, 'db2', mode)
assert_allclose(cA, dwt_results[mode][0], rtol=1e-7, atol=1e-8)
assert_allclose(cD, dwt_results[mode][1], rtol=1e-7, atol=1e-8)
assert_allclose(pywt.idwt(cA, cD, 'db2', mode), x, rtol=1e-10)
def test_dwt_short_input_allmodes():
# some test cases where the input is shorter than the DWT filter
x = [1, 3, 2]
wavelet = 'db2'
# manually pad each end by the filter size (4 for 'db2' used here)
padded_x = {'zero': [0, 0, 0, 0, 1, 3, 2, 0, 0, 0, 0],
'constant': [1, 1, 1, 1, 1, 3, 2, 2, 2, 2, 2],
'symmetric': [2, 2, 3, 1, 1, 3, 2, 2, 3, 1, 1],
'reflect': [1, 3, 2, 3, 1, 3, 2, 3, 1, 3, 2],
'periodic': [2, 1, 3, 2, 1, 3, 2, 1, 3, 2, 1],
'smooth': [-7, -5, -3, -1, 1, 3, 2, 1, 0, -1, -2],
'antisymmetric': [2, -2, -3, -1, 1, 3, 2, -2, -3, -1, 1],
'antireflect': [1, -1, 0, -1, 1, 3, 2, 1, 3, 5, 4],
}
for mode, xpad in padded_x.items():
# DWT of the manually padded array. will discard edges later so
# symmetric mode used here doesn't matter.
cApad, cDpad = pywt.dwt(xpad, wavelet, mode='symmetric')
# central region of the padded output (unaffected by mode )
expected_result = (cApad[2:-2], cDpad[2:-2])
cA, cD = pywt.dwt(x, wavelet, mode)
assert_allclose(cA, expected_result[0], rtol=1e-7, atol=1e-8)
assert_allclose(cD, expected_result[1], rtol=1e-7, atol=1e-8)
def test_default_mode():
# The default mode should be 'symmetric'
x = [1, 2, 1, 5, -1, 8, 4, 6]
cA, cD = pywt.dwt(x, 'db2')
cA2, cD2 = pywt.dwt(x, 'db2', mode='symmetric')
assert_allclose(cA, cA2)
assert_allclose(cD, cD2)
assert_allclose(pywt.idwt(cA, cD, 'db2'), x)

View file

@ -0,0 +1,443 @@
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
import numpy as np
from itertools import combinations
from numpy.testing import assert_allclose, assert_, assert_raises, assert_equal
import pywt
# Check that float32, float64, complex64, complex128 are preserved.
# Other real types get converted to float64.
# complex256 gets converted to complex128
dtypes_in = [np.int8, np.float16, np.float32, np.float64, np.complex64,
np.complex128]
dtypes_out = [np.float64, np.float32, np.float32, np.float64, np.complex64,
np.complex128]
# test complex256 as well if it is available
try:
dtypes_in += [np.complex256, ]
dtypes_out += [np.complex128, ]
except AttributeError:
pass
def test_dwtn_input():
# Array-like must be accepted
pywt.dwtn([1, 2, 3, 4], 'haar')
# Others must not
data = dict()
assert_raises(TypeError, pywt.dwtn, data, 'haar')
# Must be at least 1D
assert_raises(ValueError, pywt.dwtn, 2, 'haar')
def test_3D_reconstruct():
data = np.array([
[[0, 4, 1, 5, 1, 4],
[0, 5, 26, 3, 2, 1],
[5, 8, 2, 33, 4, 9],
[2, 5, 19, 4, 19, 1]],
[[1, 5, 1, 2, 3, 4],
[7, 12, 6, 52, 7, 8],
[2, 12, 3, 52, 6, 8],
[5, 2, 6, 78, 12, 2]]])
wavelet = pywt.Wavelet('haar')
for mode in pywt.Modes.modes:
d = pywt.dwtn(data, wavelet, mode=mode)
assert_allclose(data, pywt.idwtn(d, wavelet, mode=mode),
rtol=1e-13, atol=1e-13)
def test_dwdtn_idwtn_allwavelets():
rstate = np.random.RandomState(1234)
r = rstate.randn(16, 16)
# test 2D case only for all wavelet types
wavelist = pywt.wavelist()
if 'dmey' in wavelist:
wavelist.remove('dmey')
for wavelet in wavelist:
if wavelet in ['cmor', 'shan', 'fbsp']:
# skip these CWT families to avoid warnings
continue
if isinstance(pywt.DiscreteContinuousWavelet(wavelet), pywt.Wavelet):
for mode in pywt.Modes.modes:
coeffs = pywt.dwtn(r, wavelet, mode=mode)
assert_allclose(pywt.idwtn(coeffs, wavelet, mode=mode),
r, rtol=1e-7, atol=1e-7)
def test_stride():
wavelet = pywt.Wavelet('haar')
for dtype in ('float32', 'float64'):
data = np.array([[0, 4, 1, 5, 1, 4],
[0, 5, 6, 3, 2, 1],
[2, 5, 19, 4, 19, 1]],
dtype=dtype)
for mode in pywt.Modes.modes:
expected = pywt.dwtn(data, wavelet)
strided = np.ones((3, 12), dtype=data.dtype)
strided[::-1, ::2] = data
strided_dwtn = pywt.dwtn(strided[::-1, ::2], wavelet)
for key in expected.keys():
assert_allclose(strided_dwtn[key], expected[key])
def test_byte_offset():
wavelet = pywt.Wavelet('haar')
for dtype in ('float32', 'float64'):
data = np.array([[0, 4, 1, 5, 1, 4],
[0, 5, 6, 3, 2, 1],
[2, 5, 19, 4, 19, 1]],
dtype=dtype)
for mode in pywt.Modes.modes:
expected = pywt.dwtn(data, wavelet)
padded = np.ones((3, 6), dtype=np.dtype({'data': (data.dtype, 0),
'pad': ('byte', data.dtype.itemsize)},
align=True))
padded[:] = data
padded_dwtn = pywt.dwtn(padded['data'], wavelet)
for key in expected.keys():
assert_allclose(padded_dwtn[key], expected[key])
def test_3D_reconstruct_complex():
# All dimensions even length so `take` does not need to be specified
data = np.array([
[[0, 4, 1, 5, 1, 4],
[0, 5, 26, 3, 2, 1],
[5, 8, 2, 33, 4, 9],
[2, 5, 19, 4, 19, 1]],
[[1, 5, 1, 2, 3, 4],
[7, 12, 6, 52, 7, 8],
[2, 12, 3, 52, 6, 8],
[5, 2, 6, 78, 12, 2]]])
data = data + 1j
wavelet = pywt.Wavelet('haar')
d = pywt.dwtn(data, wavelet)
# idwtn creates even-length shapes (2x dwtn size)
original_shape = tuple([slice(None, s) for s in data.shape])
assert_allclose(data, pywt.idwtn(d, wavelet)[original_shape],
rtol=1e-13, atol=1e-13)
def test_idwtn_idwt2():
data = np.array([
[0, 4, 1, 5, 1, 4],
[0, 5, 6, 3, 2, 1],
[2, 5, 19, 4, 19, 1]])
wavelet = pywt.Wavelet('haar')
LL, (HL, LH, HH) = pywt.dwt2(data, wavelet)
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH}
for mode in pywt.Modes.modes:
assert_allclose(pywt.idwt2((LL, (HL, LH, HH)), wavelet, mode=mode),
pywt.idwtn(d, wavelet, mode=mode),
rtol=1e-14, atol=1e-14)
def test_idwtn_idwt2_complex():
data = np.array([
[0, 4, 1, 5, 1, 4],
[0, 5, 6, 3, 2, 1],
[2, 5, 19, 4, 19, 1]])
data = data + 1j
wavelet = pywt.Wavelet('haar')
LL, (HL, LH, HH) = pywt.dwt2(data, wavelet)
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH}
for mode in pywt.Modes.modes:
assert_allclose(pywt.idwt2((LL, (HL, LH, HH)), wavelet, mode=mode),
pywt.idwtn(d, wavelet, mode=mode),
rtol=1e-14, atol=1e-14)
def test_idwtn_missing():
# Test to confirm missing data behave as zeroes
data = np.array([
[0, 4, 1, 5, 1, 4],
[0, 5, 6, 3, 2, 1],
[2, 5, 19, 4, 19, 1]])
wavelet = pywt.Wavelet('haar')
coefs = pywt.dwtn(data, wavelet)
# No point removing zero, or all
for num_missing in range(1, len(coefs)):
for missing in combinations(coefs.keys(), num_missing):
missing_coefs = coefs.copy()
for key in missing:
del missing_coefs[key]
LL = missing_coefs.get('aa', None)
HL = missing_coefs.get('da', None)
LH = missing_coefs.get('ad', None)
HH = missing_coefs.get('dd', None)
assert_allclose(pywt.idwt2((LL, (HL, LH, HH)), wavelet),
pywt.idwtn(missing_coefs, 'haar'), atol=1e-15)
def test_idwtn_all_coeffs_None():
coefs = dict(aa=None, da=None, ad=None, dd=None)
assert_raises(ValueError, pywt.idwtn, coefs, 'haar')
def test_error_on_invalid_keys():
data = np.array([
[0, 4, 1, 5, 1, 4],
[0, 5, 6, 3, 2, 1],
[2, 5, 19, 4, 19, 1]])
wavelet = pywt.Wavelet('haar')
LL, (HL, LH, HH) = pywt.dwt2(data, wavelet)
# unexpected key
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH, 'ff': LH}
assert_raises(ValueError, pywt.idwtn, d, wavelet)
# mismatched key lengths
d = {'a': LL, 'da': HL, 'ad': LH, 'dd': HH}
assert_raises(ValueError, pywt.idwtn, d, wavelet)
def test_error_mismatched_size():
data = np.array([
[0, 4, 1, 5, 1, 4],
[0, 5, 6, 3, 2, 1],
[2, 5, 19, 4, 19, 1]])
wavelet = pywt.Wavelet('haar')
LL, (HL, LH, HH) = pywt.dwt2(data, wavelet)
# Pass/fail depends on first element being shorter than remaining ones so
# set 3/4 to an incorrect size to maximize chances. Order of dict items
# is random so may not trigger on every test run. Dict is constructed
# inside idwtn function so no use using an OrderedDict here.
LL = LL[:, :-1]
LH = LH[:, :-1]
HH = HH[:, :-1]
d = {'aa': LL, 'da': HL, 'ad': LH, 'dd': HH}
assert_raises(ValueError, pywt.idwtn, d, wavelet)
def test_dwt2_idwt2_dtypes():
wavelet = pywt.Wavelet('haar')
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
x = np.ones((4, 4), dtype=dt_in)
errmsg = "wrong dtype returned for {0} input".format(dt_in)
cA, (cH, cV, cD) = pywt.dwt2(x, wavelet)
assert_(cA.dtype == cH.dtype == cV.dtype == cD.dtype,
"dwt2: " + errmsg)
x_roundtrip = pywt.idwt2((cA, (cH, cV, cD)), wavelet)
assert_(x_roundtrip.dtype == dt_out, "idwt2: " + errmsg)
def test_dwtn_axes():
data = np.array([[0, 1, 2, 3],
[1, 1, 1, 1],
[1, 4, 2, 8]])
data = data + 1j*data # test with complex data
coefs = pywt.dwtn(data, 'haar', axes=(1,))
expected_a = list(map(lambda x: pywt.dwt(x, 'haar')[0], data))
assert_equal(coefs['a'], expected_a)
expected_d = list(map(lambda x: pywt.dwt(x, 'haar')[1], data))
assert_equal(coefs['d'], expected_d)
coefs = pywt.dwtn(data, 'haar', axes=(1, 1))
expected_aa = list(map(lambda x: pywt.dwt(x, 'haar')[0], expected_a))
assert_equal(coefs['aa'], expected_aa)
expected_ad = list(map(lambda x: pywt.dwt(x, 'haar')[1], expected_a))
assert_equal(coefs['ad'], expected_ad)
def test_idwtn_axes():
data = np.array([[0, 1, 2, 3],
[1, 1, 1, 1],
[1, 4, 2, 8]])
data = data + 1j*data # test with complex data
coefs = pywt.dwtn(data, 'haar', axes=(1, 1))
assert_allclose(pywt.idwtn(coefs, 'haar', axes=(1, 1)), data, atol=1e-14)
def test_idwt2_none_coeffs():
data = np.array([[0, 1, 2, 3],
[1, 1, 1, 1],
[1, 4, 2, 8]])
data = data + 1j*data # test with complex data
cA, (cH, cV, cD) = pywt.dwt2(data, 'haar', axes=(1, 1))
# verify setting coefficients to None is the same as zeroing them
cD = np.zeros_like(cD)
result_zeros = pywt.idwt2((cA, (cH, cV, cD)), 'haar', axes=(1, 1))
cD = None
result_none = pywt.idwt2((cA, (cH, cV, cD)), 'haar', axes=(1, 1))
assert_equal(result_zeros, result_none)
def test_idwtn_none_coeffs():
data = np.array([[0, 1, 2, 3],
[1, 1, 1, 1],
[1, 4, 2, 8]])
data = data + 1j*data # test with complex data
coefs = pywt.dwtn(data, 'haar', axes=(1, 1))
# verify setting coefficients to None is the same as zeroing them
coefs['dd'] = np.zeros_like(coefs['dd'])
result_zeros = pywt.idwtn(coefs, 'haar', axes=(1, 1))
coefs['dd'] = None
result_none = pywt.idwtn(coefs, 'haar', axes=(1, 1))
assert_equal(result_zeros, result_none)
def test_idwt2_axes():
data = np.array([[0, 1, 2, 3],
[1, 1, 1, 1],
[1, 4, 2, 8]])
coefs = pywt.dwt2(data, 'haar', axes=(1, 1))
assert_allclose(pywt.idwt2(coefs, 'haar', axes=(1, 1)), data, atol=1e-14)
# too many axes
assert_raises(ValueError, pywt.idwt2, coefs, 'haar', axes=(0, 1, 1))
def test_idwt2_axes_subsets():
data = np.array(np.random.standard_normal((4, 4, 4)))
# test all combinations of 2 out of 3 axes transformed
for axes in combinations((0, 1, 2), 2):
coefs = pywt.dwt2(data, 'haar', axes=axes)
assert_allclose(pywt.idwt2(coefs, 'haar', axes=axes), data, atol=1e-14)
def test_idwtn_axes_subsets():
data = np.array(np.random.standard_normal((4, 4, 4, 4)))
# test all combinations of 3 out of 4 axes transformed
for axes in combinations((0, 1, 2, 3), 3):
coefs = pywt.dwtn(data, 'haar', axes=axes)
assert_allclose(pywt.idwtn(coefs, 'haar', axes=axes), data, atol=1e-14)
def test_negative_axes():
data = np.array([[0, 1, 2, 3],
[1, 1, 1, 1],
[1, 4, 2, 8]])
coefs1 = pywt.dwtn(data, 'haar', axes=(1, 1))
coefs2 = pywt.dwtn(data, 'haar', axes=(-1, -1))
assert_equal(coefs1, coefs2)
rec1 = pywt.idwtn(coefs1, 'haar', axes=(1, 1))
rec2 = pywt.idwtn(coefs1, 'haar', axes=(-1, -1))
assert_equal(rec1, rec2)
def test_dwtn_idwtn_dtypes():
wavelet = pywt.Wavelet('haar')
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
x = np.ones((4, 4), dtype=dt_in)
errmsg = "wrong dtype returned for {0} input".format(dt_in)
coeffs = pywt.dwtn(x, wavelet)
for k, v in coeffs.items():
assert_(v.dtype == dt_out, "dwtn: " + errmsg)
x_roundtrip = pywt.idwtn(coeffs, wavelet)
assert_(x_roundtrip.dtype == dt_out, "idwtn: " + errmsg)
def test_idwtn_mixed_complex_dtype():
rstate = np.random.RandomState(0)
x = rstate.randn(8, 8, 8)
x = x + 1j*x
coeffs = pywt.dwtn(x, 'db2')
x_roundtrip = pywt.idwtn(coeffs, 'db2')
assert_allclose(x_roundtrip, x, rtol=1e-10)
# mismatched dtypes OK
coeffs['a' * x.ndim] = coeffs['a' * x.ndim].astype(np.complex64)
x_roundtrip2 = pywt.idwtn(coeffs, 'db2')
assert_allclose(x_roundtrip2, x, rtol=1e-7, atol=1e-7)
assert_(x_roundtrip2.dtype == np.complex128)
def test_idwt2_size_mismatch_error():
LL = np.zeros((6, 6))
LH = HL = HH = np.zeros((5, 5))
assert_raises(ValueError, pywt.idwt2, (LL, (LH, HL, HH)), wavelet='haar')
def test_dwt2_dimension_error():
data = np.ones(16)
wavelet = pywt.Wavelet('haar')
# wrong number of input dimensions
assert_raises(ValueError, pywt.dwt2, data, wavelet)
# too many axes
data2 = np.ones((8, 8))
assert_raises(ValueError, pywt.dwt2, data2, wavelet, axes=(0, 1, 1))
def test_per_axis_wavelets_and_modes():
# tests seperate wavelet and edge mode for each axis.
rstate = np.random.RandomState(1234)
data = rstate.randn(16, 16, 16)
# wavelet can be a string or wavelet object
wavelets = (pywt.Wavelet('haar'), 'sym2', 'db4')
# mode can be a string or a Modes enum
modes = ('symmetric', 'periodization',
pywt._extensions._pywt.Modes.reflect)
coefs = pywt.dwtn(data, wavelets, modes)
assert_allclose(pywt.idwtn(coefs, wavelets, modes), data, atol=1e-14)
coefs = pywt.dwtn(data, wavelets[:1], modes)
assert_allclose(pywt.idwtn(coefs, wavelets[:1], modes), data, atol=1e-14)
coefs = pywt.dwtn(data, wavelets, modes[:1])
assert_allclose(pywt.idwtn(coefs, wavelets, modes[:1]), data, atol=1e-14)
# length of wavelets or modes doesn't match the length of axes
assert_raises(ValueError, pywt.dwtn, data, wavelets[:2])
assert_raises(ValueError, pywt.dwtn, data, wavelets, mode=modes[:2])
assert_raises(ValueError, pywt.idwtn, coefs, wavelets[:2])
assert_raises(ValueError, pywt.idwtn, coefs, wavelets, mode=modes[:2])
# dwt2/idwt2 also support per-axis wavelets/modes
data2 = data[..., 0]
coefs2 = pywt.dwt2(data2, wavelets[:2], modes[:2])
assert_allclose(pywt.idwt2(coefs2, wavelets[:2], modes[:2]), data2,
atol=1e-14)
def test_error_on_continuous_wavelet():
# A ValueError is raised if a Continuous wavelet is selected
data = np.ones((16, 16))
for dec_fun, rec_fun in zip([pywt.dwt2, pywt.dwtn],
[pywt.idwt2, pywt.idwtn]):
for cwave in ['morl', pywt.DiscreteContinuousWavelet('morl')]:
assert_raises(ValueError, dec_fun, data, wavelet=cwave)
c = dec_fun(data, 'db1')
assert_raises(ValueError, rec_fun, c, wavelet=cwave)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,61 @@
#!/usr/bin/env python
"""
Verify DWT perfect reconstruction.
"""
from __future__ import division, print_function, absolute_import
import numpy as np
from numpy.testing import assert_
import pywt
def test_perfect_reconstruction():
families = ('db', 'sym', 'coif', 'bior', 'rbio')
wavelets = sum([pywt.wavelist(name) for name in families], [])
# list of mode names in pywt and matlab
modes = [('zero', 'zpd'),
('constant', 'sp0'),
('symmetric', 'sym'),
('periodic', 'ppd'),
('smooth', 'sp1'),
('periodization', 'per')]
dtypes = (np.float32, np.float64)
for wavelet in wavelets:
for pmode, mmode in modes:
for dt in dtypes:
check_reconstruction(pmode, mmode, wavelet, dt)
def check_reconstruction(pmode, mmode, wavelet, dtype):
data_size = list(range(2, 40)) + [100, 200, 500, 1000, 2000, 10000,
50000, 100000]
np.random.seed(12345)
# TODO: smoke testing - more failures for different seeds
if dtype == np.float32:
# was 3e-7 has to be lowered as db21, db29, db33, db35, coif14, coif16 were failing
epsilon = 6e-7
else:
epsilon = 5e-11
for N in data_size:
data = np.asarray(np.random.random(N), dtype)
# compute dwt coefficients
pa, pd = pywt.dwt(data, wavelet, pmode)
# compute reconstruction
rec = pywt.idwt(pa, pd, wavelet, pmode)
if len(data) % 2:
rec = rec[:len(data)]
rms_rec = np.sqrt(np.mean((data-rec)**2))
msg = ('[RMS_REC > EPSILON] for Mode: %s, Wavelet: %s, '
'Length: %d, rms=%.3g' % (pmode, wavelet, len(data), rms_rec))
assert_(rms_rec < epsilon, msg=msg)

View file

@ -0,0 +1,633 @@
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
import warnings
from copy import deepcopy
from itertools import combinations, permutations
import numpy as np
import pytest
from numpy.testing import (assert_allclose, assert_, assert_equal,
assert_raises, assert_array_equal, assert_warns)
import pywt
from pywt._extensions._swt import swt_axis
# Check that float32 and complex64 are preserved. Other real types get
# converted to float64.
dtypes_in = [np.int8, np.float16, np.float32, np.float64, np.complex64,
np.complex128]
dtypes_out = [np.float64, np.float32, np.float32, np.float64, np.complex64,
np.complex128]
# tolerances used in accuracy comparisons
tol_single = 1e-6
tol_double = 1e-13
####
# 1d multilevel swt tests
####
def test_swt_decomposition():
x = [3, 7, 1, 3, -2, 6, 4, 6]
db1 = pywt.Wavelet('db1')
atol = tol_double
(cA3, cD3), (cA2, cD2), (cA1, cD1) = pywt.swt(x, db1, level=3)
expected_cA1 = [7.07106781, 5.65685425, 2.82842712, 0.70710678,
2.82842712, 7.07106781, 7.07106781, 6.36396103]
assert_allclose(cA1, expected_cA1, rtol=1e-8, atol=atol)
expected_cD1 = [-2.82842712, 4.24264069, -1.41421356, 3.53553391,
-5.65685425, 1.41421356, -1.41421356, 2.12132034]
assert_allclose(cD1, expected_cD1, rtol=1e-8, atol=atol)
expected_cA2 = [7, 4.5, 4, 5.5, 7, 9.5, 10, 8.5]
assert_allclose(cA2, expected_cA2, rtol=tol_double, atol=atol)
expected_cD2 = [3, 3.5, 0, -4.5, -3, 0.5, 0, 0.5]
assert_allclose(cD2, expected_cD2, rtol=tol_double, atol=atol)
expected_cA3 = [9.89949494, ] * 8
assert_allclose(cA3, expected_cA3, rtol=1e-8, atol=atol)
expected_cD3 = [0.00000000, -3.53553391, -4.24264069, -2.12132034,
0.00000000, 3.53553391, 4.24264069, 2.12132034]
assert_allclose(cD3, expected_cD3, rtol=1e-8, atol=atol)
# level=1, start_level=1 decomposition should match level=2
res = pywt.swt(cA1, db1, level=1, start_level=1)
cA2, cD2 = res[0]
assert_allclose(cA2, expected_cA2, rtol=tol_double, atol=atol)
assert_allclose(cD2, expected_cD2, rtol=tol_double, atol=atol)
coeffs = pywt.swt(x, db1)
assert_(len(coeffs) == 3)
assert_(pywt.swt_max_level(len(x)), 3)
def test_swt_max_level():
# odd sized signal will warn about no levels of decomposition possible
assert_warns(UserWarning, pywt.swt_max_level, 11)
with warnings.catch_warnings():
warnings.simplefilter('ignore', UserWarning)
assert_equal(pywt.swt_max_level(11), 0)
# no warnings when >= 1 level of decomposition possible
assert_equal(pywt.swt_max_level(2), 1) # divisible by 2**1
assert_equal(pywt.swt_max_level(4*3), 2) # divisible by 2**2
assert_equal(pywt.swt_max_level(16), 4) # divisible by 2**4
assert_equal(pywt.swt_max_level(16*3), 4) # divisible by 2**4
def test_swt_axis():
x = [3, 7, 1, 3, -2, 6, 4, 6]
db1 = pywt.Wavelet('db1')
(cA2, cD2), (cA1, cD1) = pywt.swt(x, db1, level=2)
# test cases use 2D arrays based on tiling x along an axis and then
# calling swt along the other axis.
for order in ['C', 'F']:
# test SWT of 2D data along default axis (-1)
x_2d = np.asarray(x).reshape((1, -1))
x_2d = np.concatenate((x_2d, )*5, axis=0)
if order == 'C':
x_2d = np.ascontiguousarray(x_2d)
elif order == 'F':
x_2d = np.asfortranarray(x_2d)
(cA2_2d, cD2_2d), (cA1_2d, cD1_2d) = pywt.swt(x_2d, db1, level=2)
for c in [cA2_2d, cD2_2d, cA1_2d, cD1_2d]:
assert_(c.shape == x_2d.shape)
# each row should match the 1D result
for row in cA1_2d:
assert_array_equal(row, cA1)
for row in cA2_2d:
assert_array_equal(row, cA2)
for row in cD1_2d:
assert_array_equal(row, cD1)
for row in cD2_2d:
assert_array_equal(row, cD2)
# test SWT of 2D data along other axis (0)
x_2d = np.asarray(x).reshape((-1, 1))
x_2d = np.concatenate((x_2d, )*5, axis=1)
if order == 'C':
x_2d = np.ascontiguousarray(x_2d)
elif order == 'F':
x_2d = np.asfortranarray(x_2d)
(cA2_2d, cD2_2d), (cA1_2d, cD1_2d) = pywt.swt(x_2d, db1, level=2,
axis=0)
for c in [cA2_2d, cD2_2d, cA1_2d, cD1_2d]:
assert_(c.shape == x_2d.shape)
# each column should match the 1D result
for row in cA1_2d.transpose((1, 0)):
assert_array_equal(row, cA1)
for row in cA2_2d.transpose((1, 0)):
assert_array_equal(row, cA2)
for row in cD1_2d.transpose((1, 0)):
assert_array_equal(row, cD1)
for row in cD2_2d.transpose((1, 0)):
assert_array_equal(row, cD2)
# axis too large
assert_raises(ValueError, pywt.swt, x, db1, level=2, axis=5)
def test_swt_iswt_integration():
# This function performs a round-trip swt/iswt transform test on
# all available types of wavelets in PyWavelets - except the
# 'dmey' wavelet. The latter has been excluded because it does not
# produce very precise results. This is likely due to the fact
# that the 'dmey' wavelet is a discrete approximation of a
# continuous wavelet. All wavelets are tested up to 3 levels. The
# test validates neither swt or iswt as such, but it does ensure
# that they are each other's inverse.
max_level = 3
wavelets = pywt.wavelist(kind='discrete')
if 'dmey' in wavelets:
# The 'dmey' wavelet seems to be a bit special - disregard it for now
wavelets.remove('dmey')
for current_wavelet_str in wavelets:
current_wavelet = pywt.Wavelet(current_wavelet_str)
input_length_power = int(np.ceil(np.log2(max(
current_wavelet.dec_len,
current_wavelet.rec_len))))
input_length = 2**(input_length_power + max_level - 1)
X = np.arange(input_length)
for norm in [True, False]:
if norm and not current_wavelet.orthogonal:
# non-orthogonal wavelets to avoid warnings when norm=True
continue
for trim_approx in [True, False]:
coeffs = pywt.swt(X, current_wavelet, max_level,
trim_approx=trim_approx, norm=norm)
Y = pywt.iswt(coeffs, current_wavelet, norm=norm)
assert_allclose(Y, X, rtol=1e-5, atol=1e-7)
def test_swt_dtypes():
wavelet = pywt.Wavelet('haar')
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
errmsg = "wrong dtype returned for {0} input".format(dt_in)
# swt
x = np.ones(8, dtype=dt_in)
(cA2, cD2), (cA1, cD1) = pywt.swt(x, wavelet, level=2)
assert_(cA2.dtype == cD2.dtype == cA1.dtype == cD1.dtype == dt_out,
"swt: " + errmsg)
# swt2
x = np.ones((8, 8), dtype=dt_in)
cA, (cH, cV, cD) = pywt.swt2(x, wavelet, level=1)[0]
assert_(cA.dtype == cH.dtype == cV.dtype == cD.dtype == dt_out,
"swt2: " + errmsg)
def test_swt_roundtrip_dtypes():
# verify perfect reconstruction for all dtypes
rstate = np.random.RandomState(5)
wavelet = pywt.Wavelet('haar')
for dt_in, dt_out in zip(dtypes_in, dtypes_out):
# swt, iswt
x = rstate.standard_normal((8, )).astype(dt_in)
c = pywt.swt(x, wavelet, level=2)
xr = pywt.iswt(c, wavelet)
assert_allclose(x, xr, rtol=1e-6, atol=1e-7)
# swt2, iswt2
x = rstate.standard_normal((8, 8)).astype(dt_in)
c = pywt.swt2(x, wavelet, level=2)
xr = pywt.iswt2(c, wavelet)
assert_allclose(x, xr, rtol=1e-6, atol=1e-7)
def test_swt_default_level_by_axis():
# make sure default number of levels matches the max level along the axis
wav = 'db2'
x = np.ones((2**3, 2**4, 2**5))
for axis in (0, 1, 2):
sdec = pywt.swt(x, wav, level=None, start_level=0, axis=axis)
assert_equal(len(sdec), pywt.swt_max_level(x.shape[axis]))
def test_swt2_ndim_error():
x = np.ones(8)
with warnings.catch_warnings():
warnings.simplefilter('ignore', FutureWarning)
assert_raises(ValueError, pywt.swt2, x, 'haar', level=1)
@pytest.mark.slow
def test_swt2_iswt2_integration(wavelets=None):
# This function performs a round-trip swt2/iswt2 transform test on
# all available types of wavelets in PyWavelets - except the
# 'dmey' wavelet. The latter has been excluded because it does not
# produce very precise results. This is likely due to the fact
# that the 'dmey' wavelet is a discrete approximation of a
# continuous wavelet. All wavelets are tested up to 3 levels. The
# test validates neither swt2 or iswt2 as such, but it does ensure
# that they are each other's inverse.
max_level = 3
if wavelets is None:
wavelets = pywt.wavelist(kind='discrete')
if 'dmey' in wavelets:
# The 'dmey' wavelet is a special case - disregard it for now
wavelets.remove('dmey')
for current_wavelet_str in wavelets:
current_wavelet = pywt.Wavelet(current_wavelet_str)
input_length_power = int(np.ceil(np.log2(max(
current_wavelet.dec_len,
current_wavelet.rec_len))))
input_length = 2**(input_length_power + max_level - 1)
X = np.arange(input_length**2).reshape(input_length, input_length)
for norm in [True, False]:
if norm and not current_wavelet.orthogonal:
# non-orthogonal wavelets to avoid warnings when norm=True
continue
for trim_approx in [True, False]:
coeffs = pywt.swt2(X, current_wavelet, max_level,
trim_approx=trim_approx, norm=norm)
Y = pywt.iswt2(coeffs, current_wavelet, norm=norm)
assert_allclose(Y, X, rtol=1e-5, atol=1e-5)
def test_swt2_iswt2_quick():
test_swt2_iswt2_integration(wavelets=['db1', ])
def test_swt2_iswt2_non_square(wavelets=None):
for nrows in [8, 16, 48]:
X = np.arange(nrows*32).reshape(nrows, 32)
current_wavelet = 'db1'
with warnings.catch_warnings():
warnings.simplefilter('ignore', FutureWarning)
coeffs = pywt.swt2(X, current_wavelet, level=2)
Y = pywt.iswt2(coeffs, current_wavelet)
assert_allclose(Y, X, rtol=tol_single, atol=tol_single)
def test_swt2_axes():
atol = 1e-14
current_wavelet = pywt.Wavelet('db2')
input_length_power = int(np.ceil(np.log2(max(
current_wavelet.dec_len,
current_wavelet.rec_len))))
input_length = 2**(input_length_power)
X = np.arange(input_length**2).reshape(input_length, input_length)
(cA1, (cH1, cV1, cD1)) = pywt.swt2(X, current_wavelet, level=1)[0]
# opposite order
(cA2, (cH2, cV2, cD2)) = pywt.swt2(X, current_wavelet, level=1,
axes=(1, 0))[0]
assert_allclose(cA1, cA2, atol=atol)
assert_allclose(cH1, cV2, atol=atol)
assert_allclose(cV1, cH2, atol=atol)
assert_allclose(cD1, cD2, atol=atol)
# duplicate axes not allowed
assert_raises(ValueError, pywt.swt2, X, current_wavelet, 1,
axes=(0, 0))
# too few axes
assert_raises(ValueError, pywt.swt2, X, current_wavelet, 1, axes=(0, ))
def test_iswt2_2d_only():
# iswt2 is not currently compatible with data that is not 2D
x_3d = np.ones((4, 4, 4))
c = pywt.swt2(x_3d, 'haar', level=1)
assert_raises(ValueError, pywt.iswt2, c, 'haar')
def test_swtn_axes():
atol = 1e-14
current_wavelet = pywt.Wavelet('db2')
input_length_power = int(np.ceil(np.log2(max(
current_wavelet.dec_len,
current_wavelet.rec_len))))
input_length = 2**(input_length_power)
X = np.arange(input_length**2).reshape(input_length, input_length)
coeffs = pywt.swtn(X, current_wavelet, level=1, axes=None)[0]
# opposite order
coeffs2 = pywt.swtn(X, current_wavelet, level=1, axes=(1, 0))[0]
assert_allclose(coeffs['aa'], coeffs2['aa'], atol=atol)
assert_allclose(coeffs['ad'], coeffs2['da'], atol=atol)
assert_allclose(coeffs['da'], coeffs2['ad'], atol=atol)
assert_allclose(coeffs['dd'], coeffs2['dd'], atol=atol)
# 0-level transform
empty = pywt.swtn(X, current_wavelet, level=0)
assert_equal(empty, [])
# duplicate axes not allowed
assert_raises(ValueError, pywt.swtn, X, current_wavelet, 1, axes=(0, 0))
# data.ndim = 0
assert_raises(ValueError, pywt.swtn, np.asarray([]), current_wavelet, 1)
# start_level too large
assert_raises(ValueError, pywt.swtn, X, current_wavelet,
level=1, start_level=2)
# level < 1 in swt_axis call
assert_raises(ValueError, swt_axis, X, current_wavelet, level=0,
start_level=0)
# odd-sized data not allowed
assert_raises(ValueError, swt_axis, X[:-1, :], current_wavelet, level=0,
start_level=0, axis=0)
@pytest.mark.slow
def test_swtn_iswtn_integration(wavelets=None):
# This function performs a round-trip swtn/iswtn transform for various
# possible combinations of:
# 1.) 1 out of 2 axes of a 2D array
# 2.) 2 out of 3 axes of a 3D array
#
# To keep test time down, only wavelets of length <= 8 are run.
#
# This test does not validate swtn or iswtn individually, but only
# confirms that iswtn yields an (almost) perfect reconstruction of swtn.
max_level = 3
if wavelets is None:
wavelets = pywt.wavelist(kind='discrete')
if 'dmey' in wavelets:
# The 'dmey' wavelet is a special case - disregard it for now
wavelets.remove('dmey')
for ndim_transform in range(1, 3):
ndim = ndim_transform + 1
for axes in combinations(range(ndim), ndim_transform):
for current_wavelet_str in wavelets:
wav = pywt.Wavelet(current_wavelet_str)
if wav.dec_len > 8:
continue # avoid excessive test duration
input_length_power = int(np.ceil(np.log2(max(
wav.dec_len,
wav.rec_len))))
N = 2**(input_length_power + max_level - 1)
X = np.arange(N**ndim).reshape((N, )*ndim)
for norm in [True, False]:
if norm and not wav.orthogonal:
# non-orthogonal wavelets to avoid warnings
continue
for trim_approx in [True, False]:
coeffs = pywt.swtn(X, wav, max_level, axes=axes,
trim_approx=trim_approx, norm=norm)
coeffs_copy = deepcopy(coeffs)
Y = pywt.iswtn(coeffs, wav, axes=axes, norm=norm)
assert_allclose(Y, X, rtol=1e-5, atol=1e-5)
# verify the inverse transform didn't modify any coeffs
for c, c2 in zip(coeffs, coeffs_copy):
for k, v in c.items():
assert_array_equal(c2[k], v)
def test_swtn_iswtn_quick():
test_swtn_iswtn_integration(wavelets=['db1', ])
def test_iswtn_errors():
x = np.arange(8**3).reshape(8, 8, 8)
max_level = 2
axes = (0, 1)
w = pywt.Wavelet('db1')
coeffs = pywt.swtn(x, w, max_level, axes=axes)
# more axes than dimensions transformed
assert_raises(ValueError, pywt.iswtn, coeffs, w, axes=(0, 1, 2))
# duplicate axes not allowed
assert_raises(ValueError, pywt.iswtn, coeffs, w, axes=(0, 0))
# mismatched coefficient size
coeffs[0]['da'] = coeffs[0]['da'][:-1, :]
assert_raises(RuntimeError, pywt.iswtn, coeffs, w, axes=axes)
def test_swtn_iswtn_unique_shape_per_axis():
# test case for gh-460
_shape = (1, 48, 32) # unique shape per axis
wav = 'sym2'
max_level = 3
rstate = np.random.RandomState(0)
for shape in permutations(_shape):
# transform only along the non-singleton axes
axes = [ax for ax, s in enumerate(shape) if s != 1]
x = rstate.standard_normal(shape)
c = pywt.swtn(x, wav, max_level, axes=axes)
r = pywt.iswtn(c, wav, axes=axes)
assert_allclose(x, r, rtol=1e-10, atol=1e-10)
def test_per_axis_wavelets():
# tests seperate wavelet for each axis.
rstate = np.random.RandomState(1234)
data = rstate.randn(16, 16, 16)
level = 3
# wavelet can be a string or wavelet object
wavelets = (pywt.Wavelet('haar'), 'sym2', 'db4')
coefs = pywt.swtn(data, wavelets, level=level)
assert_allclose(pywt.iswtn(coefs, wavelets), data, atol=1e-14)
# 1-tuple also okay
coefs = pywt.swtn(data, wavelets[:1], level=level)
assert_allclose(pywt.iswtn(coefs, wavelets[:1]), data, atol=1e-14)
# length of wavelets doesn't match the length of axes
assert_raises(ValueError, pywt.swtn, data, wavelets[:2], level)
assert_raises(ValueError, pywt.iswtn, coefs, wavelets[:2])
with warnings.catch_warnings():
warnings.simplefilter('ignore', FutureWarning)
# swt2/iswt2 also support per-axis wavelets/modes
data2 = data[..., 0]
coefs2 = pywt.swt2(data2, wavelets[:2], level)
assert_allclose(pywt.iswt2(coefs2, wavelets[:2]), data2, atol=1e-14)
def test_error_on_continuous_wavelet():
# A ValueError is raised if a Continuous wavelet is selected
data = np.ones((16, 16))
for dec_func, rec_func in zip([pywt.swt, pywt.swt2, pywt.swtn],
[pywt.iswt, pywt.iswt2, pywt.iswtn]):
for cwave in ['morl', pywt.DiscreteContinuousWavelet('morl')]:
assert_raises(ValueError, dec_func, data, wavelet=cwave,
level=3)
c = dec_func(data, 'db1', level=3)
assert_raises(ValueError, rec_func, c, wavelet=cwave)
def test_iswt_mixed_dtypes():
# Mixed precision inputs give double precision output
x_real = np.arange(16).astype(np.float64)
x_complex = x_real + 1j*x_real
wav = 'sym2'
for dtype1, dtype2 in [(np.float64, np.float32),
(np.float32, np.float64),
(np.float16, np.float64),
(np.complex128, np.complex64),
(np.complex64, np.complex128)]:
if dtype1 in [np.complex64, np.complex128]:
x = x_complex
output_dtype = np.complex128
else:
x = x_real
output_dtype = np.float64
coeffs = pywt.swt(x, wav, 2)
# different precision for the approximation coefficients
coeffs[0] = [coeffs[0][0].astype(dtype1),
coeffs[0][1].astype(dtype2)]
y = pywt.iswt(coeffs, wav)
assert_equal(output_dtype, y.dtype)
assert_allclose(y, x, rtol=1e-3, atol=1e-3)
def test_iswt2_mixed_dtypes():
# Mixed precision inputs give double precision output
rstate = np.random.RandomState(0)
x_real = rstate.randn(8, 8)
x_complex = x_real + 1j*x_real
wav = 'sym2'
for dtype1, dtype2 in [(np.float64, np.float32),
(np.float32, np.float64),
(np.float16, np.float64),
(np.complex128, np.complex64),
(np.complex64, np.complex128)]:
if dtype1 in [np.complex64, np.complex128]:
x = x_complex
output_dtype = np.complex128
else:
x = x_real
output_dtype = np.float64
coeffs = pywt.swt2(x, wav, 2)
# different precision for the approximation coefficients
coeffs[0] = [coeffs[0][0].astype(dtype1),
tuple([c.astype(dtype2) for c in coeffs[0][1]])]
y = pywt.iswt2(coeffs, wav)
assert_equal(output_dtype, y.dtype)
assert_allclose(y, x, rtol=1e-3, atol=1e-3)
def test_iswtn_mixed_dtypes():
# Mixed precision inputs give double precision output
rstate = np.random.RandomState(0)
x_real = rstate.randn(8, 8, 8)
x_complex = x_real + 1j*x_real
wav = 'sym2'
for dtype1, dtype2 in [(np.float64, np.float32),
(np.float32, np.float64),
(np.float16, np.float64),
(np.complex128, np.complex64),
(np.complex64, np.complex128)]:
if dtype1 in [np.complex64, np.complex128]:
x = x_complex
output_dtype = np.complex128
else:
x = x_real
output_dtype = np.float64
coeffs = pywt.swtn(x, wav, 2)
# different precision for the approximation coefficients
a = coeffs[0].pop('a' * x.ndim)
a = a.astype(dtype1)
coeffs[0] = {k: c.astype(dtype2) for k, c in coeffs[0].items()}
coeffs[0]['a' * x.ndim] = a
y = pywt.iswtn(coeffs, wav)
assert_equal(output_dtype, y.dtype)
assert_allclose(y, x, rtol=1e-3, atol=1e-3)
def test_swt_zero_size_axes():
# raise on empty input array
assert_raises(ValueError, pywt.swt, [], 'db2')
# >1D case uses a different code path so check there as well
x = np.ones((1, 4))[0:0, :] # 2D with a size zero axis
assert_raises(ValueError, pywt.swtn, x, 'db2', level=1, axes=(0,))
def test_swt_variance_and_energy_preservation():
"""Verify that the 1D SWT partitions variance among the coefficients."""
# When norm is True and the wavelet is orthogonal, the sum of the
# variances of the coefficients should equal the variance of the signal.
wav = 'db2'
rstate = np.random.RandomState(5)
x = rstate.randn(256)
coeffs = pywt.swt(x, wav, trim_approx=True, norm=True)
variances = [np.var(c) for c in coeffs]
assert_allclose(np.sum(variances), np.var(x))
# also verify L2-norm energy preservation property
assert_allclose(np.linalg.norm(x),
np.linalg.norm(np.concatenate(coeffs)))
# non-orthogonal wavelet with norm=True raises a warning
assert_warns(UserWarning, pywt.swt, x, 'bior2.2', norm=True)
def test_swt2_variance_and_energy_preservation():
"""Verify that the 2D SWT partitions variance among the coefficients."""
# When norm is True and the wavelet is orthogonal, the sum of the
# variances of the coefficients should equal the variance of the signal.
wav = 'db2'
rstate = np.random.RandomState(5)
x = rstate.randn(64, 64)
coeffs = pywt.swt2(x, wav, level=4, trim_approx=True, norm=True)
coeff_list = [coeffs[0].ravel()]
for d in coeffs[1:]:
for v in d:
coeff_list.append(v.ravel())
variances = [np.var(v) for v in coeff_list]
assert_allclose(np.sum(variances), np.var(x))
# also verify L2-norm energy preservation property
assert_allclose(np.linalg.norm(x),
np.linalg.norm(np.concatenate(coeff_list)))
# non-orthogonal wavelet with norm=True raises a warning
assert_warns(UserWarning, pywt.swt2, x, 'bior2.2', level=4, norm=True)
def test_swtn_variance_and_energy_preservation():
"""Verify that the nD SWT partitions variance among the coefficients."""
# When norm is True and the wavelet is orthogonal, the sum of the
# variances of the coefficients should equal the variance of the signal.
wav = 'db2'
rstate = np.random.RandomState(5)
x = rstate.randn(64, 64)
coeffs = pywt.swtn(x, wav, level=4, trim_approx=True, norm=True)
coeff_list = [coeffs[0].ravel()]
for d in coeffs[1:]:
for k, v in d.items():
coeff_list.append(v.ravel())
variances = [np.var(v) for v in coeff_list]
assert_allclose(np.sum(variances), np.var(x))
# also verify L2-norm energy preservation property
assert_allclose(np.linalg.norm(x),
np.linalg.norm(np.concatenate(coeff_list)))
# non-orthogonal wavelet with norm=True raises a warning
assert_warns(UserWarning, pywt.swtn, x, 'bior2.2', level=4, norm=True)
def test_swt_ravel_and_unravel():
# When trim_approx=True, all swt functions can user pywt.ravel_coeffs
for ndim, _swt, _iswt, ravel_type in [
(1, pywt.swt, pywt.iswt, 'swt'),
(2, pywt.swt2, pywt.iswt2, 'swt2'),
(3, pywt.swtn, pywt.iswtn, 'swtn')]:
x = np.ones((16, ) * ndim)
c = _swt(x, 'sym2', level=3, trim_approx=True)
arr, slices, shapes = pywt.ravel_coeffs(c)
c = pywt.unravel_coeffs(arr, slices, shapes, output_format=ravel_type)
r = _iswt(c, 'sym2')
assert_allclose(x, r)

View file

@ -0,0 +1,169 @@
from __future__ import division, print_function, absolute_import
import numpy as np
from numpy.testing import assert_allclose, assert_raises, assert_, assert_equal
import pywt
float_dtypes = [np.float32, np.float64, np.complex64, np.complex128]
real_dtypes = [np.float32, np.float64]
def _sign(x):
# Matlab-like sign function (numpy uses a different convention).
return x / np.abs(x)
def _soft(x, thresh):
"""soft thresholding supporting complex values.
Notes
-----
This version is not robust to zeros in x.
"""
return _sign(x) * np.maximum(np.abs(x) - thresh, 0)
def test_threshold():
data = np.linspace(1, 4, 7)
# soft
soft_result = [0., 0., 0., 0.5, 1., 1.5, 2.]
assert_allclose(pywt.threshold(data, 2, 'soft'),
np.array(soft_result), rtol=1e-12)
assert_allclose(pywt.threshold(-data, 2, 'soft'),
-np.array(soft_result), rtol=1e-12)
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'soft'),
[[0, 1]] * 2, rtol=1e-12)
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'soft'),
[[0, 0]] * 2, rtol=1e-12)
# soft thresholding complex values
assert_allclose(pywt.threshold([[1j, 2j]] * 2, 1, 'soft'),
[[0j, 1j]] * 2, rtol=1e-12)
assert_allclose(pywt.threshold([[1+1j, 2+2j]] * 2, 6, 'soft'),
[[0, 0]] * 2, rtol=1e-12)
complex_data = [[1+2j, 2+2j]]*2
for thresh in [1, 2]:
assert_allclose(pywt.threshold(complex_data, thresh, 'soft'),
_soft(complex_data, thresh), rtol=1e-12)
# test soft thresholding with non-default substitute argument
s = 5
assert_allclose(pywt.threshold([[1j, 2]] * 2, 1.5, 'soft', substitute=s),
[[s, 0.5]] * 2, rtol=1e-12)
# soft: no divide by zero warnings when input contains zeros
assert_allclose(pywt.threshold(np.zeros(16), 2, 'soft'),
np.zeros(16), rtol=1e-12)
# hard
hard_result = [0., 0., 2., 2.5, 3., 3.5, 4.]
assert_allclose(pywt.threshold(data, 2, 'hard'),
np.array(hard_result), rtol=1e-12)
assert_allclose(pywt.threshold(-data, 2, 'hard'),
-np.array(hard_result), rtol=1e-12)
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'hard'),
[[1, 2]] * 2, rtol=1e-12)
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'hard'),
[[0, 2]] * 2, rtol=1e-12)
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'hard', substitute=s),
[[s, 2]] * 2, rtol=1e-12)
assert_allclose(pywt.threshold([[1+1j, 2+2j]] * 2, 2, 'hard'),
[[0, 2+2j]] * 2, rtol=1e-12)
# greater
greater_result = [0., 0., 2., 2.5, 3., 3.5, 4.]
assert_allclose(pywt.threshold(data, 2, 'greater'),
np.array(greater_result), rtol=1e-12)
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'greater'),
[[1, 2]] * 2, rtol=1e-12)
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'greater'),
[[0, 2]] * 2, rtol=1e-12)
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'greater', substitute=s),
[[s, 2]] * 2, rtol=1e-12)
# greater doesn't allow complex-valued inputs
assert_raises(ValueError, pywt.threshold, [1j, 2j], 2, 'greater')
# less
assert_allclose(pywt.threshold(data, 2, 'less'),
np.array([1., 1.5, 2., 0., 0., 0., 0.]), rtol=1e-12)
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'less'),
[[1, 0]] * 2, rtol=1e-12)
assert_allclose(pywt.threshold([[1, 2]] * 2, 1, 'less', substitute=s),
[[1, s]] * 2, rtol=1e-12)
assert_allclose(pywt.threshold([[1, 2]] * 2, 2, 'less'),
[[1, 2]] * 2, rtol=1e-12)
# less doesn't allow complex-valued inputs
assert_raises(ValueError, pywt.threshold, [1j, 2j], 2, 'less')
# invalid
assert_raises(ValueError, pywt.threshold, data, 2, 'foo')
def test_nonnegative_garotte():
thresh = 0.3
data_real = np.linspace(-1, 1, 100)
for dtype in float_dtypes:
if dtype in real_dtypes:
data = np.asarray(data_real, dtype=dtype)
else:
data = np.asarray(data_real + 0.1j, dtype=dtype)
d_hard = pywt.threshold(data, thresh, 'hard')
d_soft = pywt.threshold(data, thresh, 'soft')
d_garotte = pywt.threshold(data, thresh, 'garotte')
# check dtypes
assert_equal(d_hard.dtype, data.dtype)
assert_equal(d_soft.dtype, data.dtype)
assert_equal(d_garotte.dtype, data.dtype)
# values < threshold are zero
lt = np.where(np.abs(data) < thresh)
assert_(np.all(d_garotte[lt] == 0))
# values > than the threshold are intermediate between soft and hard
gt = np.where(np.abs(data) > thresh)
gt_abs_garotte = np.abs(d_garotte[gt])
assert_(np.all(gt_abs_garotte < np.abs(d_hard[gt])))
assert_(np.all(gt_abs_garotte > np.abs(d_soft[gt])))
def test_threshold_firm():
thresh = 0.2
thresh2 = 3 * thresh
data_real = np.linspace(-1, 1, 100)
for dtype in float_dtypes:
if dtype in real_dtypes:
data = np.asarray(data_real, dtype=dtype)
else:
data = np.asarray(data_real + 0.1j, dtype=dtype)
if data.real.dtype == np.float32:
rtol = atol = 1e-6
else:
rtol = atol = 1e-14
d_hard = pywt.threshold(data, thresh, 'hard')
d_soft = pywt.threshold(data, thresh, 'soft')
d_firm = pywt.threshold_firm(data, thresh, thresh2)
# check dtypes
assert_equal(d_hard.dtype, data.dtype)
assert_equal(d_soft.dtype, data.dtype)
assert_equal(d_firm.dtype, data.dtype)
# values < threshold are zero
lt = np.where(np.abs(data) < thresh)
assert_(np.all(d_firm[lt] == 0))
# values > than the threshold are equal to hard-thresholding
gt = np.where(np.abs(data) >= thresh2)
assert_allclose(np.abs(d_hard[gt]), np.abs(d_firm[gt]),
rtol=rtol, atol=atol)
# other values are intermediate between soft and hard thresholding
mt = np.where(np.logical_and(np.abs(data) > thresh,
np.abs(data) < thresh2))
mt_abs_firm = np.abs(d_firm[mt])
assert_(np.all(mt_abs_firm < np.abs(d_hard[mt])))
assert_(np.all(mt_abs_firm > np.abs(d_soft[mt])))

View file

@ -0,0 +1,266 @@
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
import numpy as np
from numpy.testing import assert_allclose, assert_
import pywt
def test_wavelet_properties():
w = pywt.Wavelet('db3')
# Name
assert_(w.name == 'db3')
assert_(w.short_family_name == 'db')
assert_(w.family_name, 'Daubechies')
# String representation
fields = ('Family name', 'Short name', 'Filters length', 'Orthogonal',
'Biorthogonal', 'Symmetry')
for field in fields:
assert_(field in str(w))
# Filter coefficients
dec_lo = [0.03522629188210, -0.08544127388224, -0.13501102001039,
0.45987750211933, 0.80689150931334, 0.33267055295096]
dec_hi = [-0.33267055295096, 0.80689150931334, -0.45987750211933,
-0.13501102001039, 0.08544127388224, 0.03522629188210]
rec_lo = [0.33267055295096, 0.80689150931334, 0.45987750211933,
-0.13501102001039, -0.08544127388224, 0.03522629188210]
rec_hi = [0.03522629188210, 0.08544127388224, -0.13501102001039,
-0.45987750211933, 0.80689150931334, -0.33267055295096]
assert_allclose(w.dec_lo, dec_lo)
assert_allclose(w.dec_hi, dec_hi)
assert_allclose(w.rec_lo, rec_lo)
assert_allclose(w.rec_hi, rec_hi)
assert_(len(w.filter_bank) == 4)
# Orthogonality
assert_(w.orthogonal)
assert_(w.biorthogonal)
# Symmetry
assert_(w.symmetry)
# Vanishing moments
assert_(w.vanishing_moments_phi == 0)
assert_(w.vanishing_moments_psi == 3)
def test_wavelet_coefficients():
families = ('db', 'sym', 'coif', 'bior', 'rbio')
wavelets = sum([pywt.wavelist(name) for name in families], [])
for wavelet in wavelets:
if (pywt.Wavelet(wavelet).orthogonal):
check_coefficients_orthogonal(wavelet)
elif(pywt.Wavelet(wavelet).biorthogonal):
check_coefficients_biorthogonal(wavelet)
else:
check_coefficients(wavelet)
def check_coefficients_orthogonal(wavelet):
epsilon = 5e-11
level = 5
w = pywt.Wavelet(wavelet)
phi, psi, x = w.wavefun(level=level)
# Lowpass filter coefficients sum to sqrt2
res = np.sum(w.dec_lo)-np.sqrt(2)
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
# sum even coef = sum odd coef = 1 / sqrt(2)
res = np.sum(w.dec_lo[::2])-1./np.sqrt(2)
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
res = np.sum(w.dec_lo[1::2])-1./np.sqrt(2)
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
# Highpass filter coefficients sum to zero
res = np.sum(w.dec_hi)
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
# Scaling function integrates to unity
res = np.sum(phi) - 2**level
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
# Wavelet function is orthogonal to the scaling function at the same scale
res = np.sum(phi*psi)
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
# The lowpass and highpass filter coefficients are orthogonal
res = np.sum(np.array(w.dec_lo)*np.array(w.dec_hi))
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
def check_coefficients_biorthogonal(wavelet):
epsilon = 5e-11
level = 5
w = pywt.Wavelet(wavelet)
phi_d, psi_d, phi_r, psi_r, x = w.wavefun(level=level)
# Lowpass filter coefficients sum to sqrt2
res = np.sum(w.dec_lo)-np.sqrt(2)
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
# sum even coef = sum odd coef = 1 / sqrt(2)
res = np.sum(w.dec_lo[::2])-1./np.sqrt(2)
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
res = np.sum(w.dec_lo[1::2])-1./np.sqrt(2)
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
# Highpass filter coefficients sum to zero
res = np.sum(w.dec_hi)
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
# Scaling function integrates to unity
res = np.sum(phi_d) - 2**level
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
res = np.sum(phi_r) - 2**level
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
def check_coefficients(wavelet):
epsilon = 5e-11
level = 10
w = pywt.Wavelet(wavelet)
# Lowpass filter coefficients sum to sqrt2
res = np.sum(w.dec_lo)-np.sqrt(2)
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
# sum even coef = sum odd coef = 1 / sqrt(2)
res = np.sum(w.dec_lo[::2])-1./np.sqrt(2)
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
res = np.sum(w.dec_lo[1::2])-1./np.sqrt(2)
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
# Highpass filter coefficients sum to zero
res = np.sum(w.dec_hi)
msg = ('[RMS_REC > EPSILON] for Wavelet: %s, rms=%.3g' % (wavelet, res))
assert_(res < epsilon, msg=msg)
class _CustomHaarFilterBank(object):
@property
def filter_bank(self):
val = np.sqrt(2) / 2
return ([val]*2, [-val, val], [val]*2, [val, -val])
def test_custom_wavelet():
haar_custom1 = pywt.Wavelet('Custom Haar Wavelet',
filter_bank=_CustomHaarFilterBank())
haar_custom1.orthogonal = True
haar_custom1.biorthogonal = True
val = np.sqrt(2) / 2
filter_bank = ([val]*2, [-val, val], [val]*2, [val, -val])
haar_custom2 = pywt.Wavelet('Custom Haar Wavelet',
filter_bank=filter_bank)
# check expected default wavelet properties
assert_(~haar_custom2.orthogonal)
assert_(~haar_custom2.biorthogonal)
assert_(haar_custom2.symmetry == 'unknown')
assert_(haar_custom2.family_name == '')
assert_(haar_custom2.short_family_name == '')
assert_(haar_custom2.vanishing_moments_phi == 0)
assert_(haar_custom2.vanishing_moments_psi == 0)
# Some properties can be set by the user
haar_custom2.orthogonal = True
haar_custom2.biorthogonal = True
def test_wavefun_sym3():
w = pywt.Wavelet('sym3')
# sym3 is an orthogonal wavelet, so 3 outputs from wavefun
phi, psi, x = w.wavefun(level=3)
assert_(phi.size == 41)
assert_(psi.size == 41)
assert_(x.size == 41)
assert_allclose(x, np.linspace(0, 5, num=x.size))
phi_expect = np.array([0.00000000e+00, 1.04132926e-01, 2.52574126e-01,
3.96525521e-01, 5.70356539e-01, 7.18934305e-01,
8.70293448e-01, 1.05363620e+00, 1.24921722e+00,
1.15296888e+00, 9.41669683e-01, 7.55875887e-01,
4.96118565e-01, 3.28293151e-01, 1.67624969e-01,
-7.33690312e-02, -3.35452855e-01, -3.31221131e-01,
-2.32061503e-01, -1.66854239e-01, -4.34091324e-02,
-2.86152390e-02, -3.63563035e-02, 2.06034491e-02,
8.30280254e-02, 7.17779073e-02, 3.85914311e-02,
1.47527100e-02, -2.31896077e-02, -1.86122172e-02,
-1.56211329e-03, -8.70615088e-04, 3.20760857e-03,
2.34142153e-03, -7.73737194e-04, -2.99879354e-04,
1.23636238e-04, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00])
psi_expect = np.array([0.00000000e+00, 1.10265752e-02, 2.67449277e-02,
4.19878574e-02, 6.03947231e-02, 7.61275365e-02,
9.21548684e-02, 1.11568926e-01, 1.32278887e-01,
6.45829680e-02, -3.97635130e-02, -1.38929884e-01,
-2.62428322e-01, -3.62246804e-01, -4.62843343e-01,
-5.89607507e-01, -7.25363076e-01, -3.36865858e-01,
2.67715108e-01, 8.40176767e-01, 1.55574430e+00,
1.18688954e+00, 4.20276324e-01, -1.51697311e-01,
-9.42076108e-01, -7.93172332e-01, -3.26343710e-01,
-1.24552779e-01, 2.12909254e-01, 1.75770320e-01,
1.47523075e-02, 8.22192707e-03, -3.02920592e-02,
-2.21119497e-02, 7.30703025e-03, 2.83200488e-03,
-1.16759765e-03, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00])
assert_allclose(phi, phi_expect)
assert_allclose(psi, psi_expect)
def test_wavefun_bior13():
w = pywt.Wavelet('bior1.3')
# bior1.3 is not an orthogonal wavelet, so 5 outputs from wavefun
phi_d, psi_d, phi_r, psi_r, x = w.wavefun(level=3)
for arr in [phi_d, psi_d, phi_r, psi_r]:
assert_(arr.size == 40)
phi_d_expect = np.array([0., -0.00195313, 0.00195313, 0.01757813,
0.01367188, 0.00390625, -0.03515625, -0.12890625,
-0.15234375, -0.125, -0.09375, -0.0625, 0.03125,
0.15234375, 0.37890625, 0.78515625, 0.99609375,
1.08203125, 1.13671875, 1.13671875, 1.08203125,
0.99609375, 0.78515625, 0.37890625, 0.15234375,
0.03125, -0.0625, -0.09375, -0.125, -0.15234375,
-0.12890625, -0.03515625, 0.00390625, 0.01367188,
0.01757813, 0.00195313, -0.00195313, 0., 0., 0.])
phi_r_expect = np.zeros(x.size, dtype=np.float)
phi_r_expect[15:23] = 1
psi_d_expect = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0,
0.015625, -0.015625, -0.140625, -0.109375,
-0.03125, 0.28125, 1.03125, 1.21875, 1.125, 0.625,
-0.625, -1.125, -1.21875, -1.03125, -0.28125,
0.03125, 0.109375, 0.140625, 0.015625, -0.015625,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
psi_r_expect = np.zeros(x.size, dtype=np.float)
psi_r_expect[7:15] = -0.125
psi_r_expect[15:19] = 1
psi_r_expect[19:23] = -1
psi_r_expect[23:31] = 0.125
assert_allclose(x, np.linspace(0, 5, x.size, endpoint=False))
assert_allclose(phi_d, phi_d_expect, rtol=1e-5, atol=1e-9)
assert_allclose(phi_r, phi_r_expect, rtol=1e-10, atol=1e-12)
assert_allclose(psi_d, psi_d_expect, rtol=1e-10, atol=1e-12)
assert_allclose(psi_r, psi_r_expect, rtol=1e-10, atol=1e-12)

View file

@ -0,0 +1,197 @@
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
import numpy as np
from numpy.testing import (assert_allclose, assert_, assert_raises,
assert_equal)
import pywt
def test_wavelet_packet_structure():
x = [1, 2, 3, 4, 5, 6, 7, 8]
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
assert_(wp.data == [1, 2, 3, 4, 5, 6, 7, 8])
assert_(wp.path == '')
assert_(wp.level == 0)
assert_(wp['ad'].maxlevel == 3)
def test_traversing_wp_tree():
x = [1, 2, 3, 4, 5, 6, 7, 8]
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
assert_(wp.maxlevel == 3)
# First level
assert_allclose(wp['a'].data, np.array([2.12132034356, 4.949747468306,
7.778174593052, 10.606601717798]),
rtol=1e-12)
# Second level
assert_allclose(wp['aa'].data, np.array([5., 13.]), rtol=1e-12)
# Third level
assert_allclose(wp['aaa'].data, np.array([12.727922061358]), rtol=1e-12)
def test_acess_path():
x = [1, 2, 3, 4, 5, 6, 7, 8]
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
assert_(wp['a'].path == 'a')
assert_(wp['aa'].path == 'aa')
assert_(wp['aaa'].path == 'aaa')
# Maximum level reached:
assert_raises(IndexError, lambda: wp['aaaa'].path)
# Wrong path
assert_raises(ValueError, lambda: wp['ac'].path)
def test_access_node_atributes():
x = [1, 2, 3, 4, 5, 6, 7, 8]
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
assert_allclose(wp['ad'].data, np.array([-2., -2.]), rtol=1e-12)
assert_(wp['ad'].path == 'ad')
assert_(wp['ad'].node_name == 'd')
assert_(wp['ad'].parent.path == 'a')
assert_(wp['ad'].level == 2)
assert_(wp['ad'].maxlevel == 3)
assert_(wp['ad'].mode == 'symmetric')
def test_collecting_nodes():
x = [1, 2, 3, 4, 5, 6, 7, 8]
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
# All nodes in natural order
assert_([node.path for node in wp.get_level(3, 'natural')] ==
['aaa', 'aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd'])
# and in frequency order.
assert_([node.path for node in wp.get_level(3, 'freq')] ==
['aaa', 'aad', 'add', 'ada', 'dda', 'ddd', 'dad', 'daa'])
def test_reconstructing_data():
x = [1, 2, 3, 4, 5, 6, 7, 8]
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
# Create another Wavelet Packet and feed it with some data.
new_wp = pywt.WaveletPacket(data=None, wavelet='db1', mode='symmetric')
new_wp['aa'] = wp['aa'].data
new_wp['ad'] = [-2., -2.]
# For convenience, :attr:`Node.data` gets automatically extracted
# from the :class:`Node` object:
new_wp['d'] = wp['d']
# Reconstruct data from aa, ad, and d packets.
assert_allclose(new_wp.reconstruct(update=False), x, rtol=1e-12)
# The node's :attr:`~Node.data` will not be updated
assert_(new_wp.data is None)
# When `update` is True:
assert_allclose(new_wp.reconstruct(update=True), x, rtol=1e-12)
assert_allclose(new_wp.data, np.arange(1, 9), rtol=1e-12)
assert_([n.path for n in new_wp.get_leaf_nodes(False)] ==
['aa', 'ad', 'd'])
assert_([n.path for n in new_wp.get_leaf_nodes(True)] ==
['aaa', 'aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd'])
def test_removing_nodes():
x = [1, 2, 3, 4, 5, 6, 7, 8]
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
wp.get_level(2)
dataleafs = [n.data for n in wp.get_leaf_nodes(False)]
expected = np.array([[5., 13.], [-2, -2], [-1, -1], [0, 0]])
for i in range(4):
assert_allclose(dataleafs[i], expected[i, :], atol=1e-12)
node = wp['ad']
del(wp['ad'])
dataleafs = [n.data for n in wp.get_leaf_nodes(False)]
expected = np.array([[5., 13.], [-1, -1], [0, 0]])
for i in range(3):
assert_allclose(dataleafs[i], expected[i, :], atol=1e-12)
wp.reconstruct()
# The reconstruction is:
assert_allclose(wp.reconstruct(),
np.array([2., 3., 2., 3., 6., 7., 6., 7.]), rtol=1e-12)
# Restore the data
wp['ad'].data = node.data
dataleafs = [n.data for n in wp.get_leaf_nodes(False)]
expected = np.array([[5., 13.], [-2, -2], [-1, -1], [0, 0]])
for i in range(4):
assert_allclose(dataleafs[i], expected[i, :], atol=1e-12)
assert_allclose(wp.reconstruct(), np.arange(1, 9), rtol=1e-12)
def test_wavelet_packet_dtypes():
N = 32
for dtype in [np.float32, np.float64, np.complex64, np.complex128]:
x = np.random.randn(N).astype(dtype)
if np.iscomplexobj(x):
x = x + 1j*np.random.randn(N).astype(x.real.dtype)
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
# no unnecessary copy made
assert_(wp.data is x)
# assiging to a node should not change supported dtypes
wp['d'] = wp['d'].data
assert_equal(wp['d'].data.dtype, x.dtype)
# full decomposition
wp.get_level(wp.maxlevel)
# reconstruction from coefficients should preserve dtype
r = wp.reconstruct(False)
assert_equal(r.dtype, x.dtype)
assert_allclose(r, x, atol=1e-5, rtol=1e-5)
# first element of the tuple is the input dtype
# second element of the tuple is the transform dtype
dtype_pairs = [(np.uint8, np.float64),
(np.intp, np.float64), ]
if hasattr(np, "complex256"):
dtype_pairs += [(np.complex256, np.complex128), ]
if hasattr(np, "half"):
dtype_pairs += [(np.half, np.float32), ]
for (dtype, transform_dtype) in dtype_pairs:
x = np.arange(N, dtype=dtype)
wp = pywt.WaveletPacket(x, wavelet='db1', mode='symmetric')
# no unnecessary copy made of top-level data
assert_(wp.data is x)
# full decomposition
wp.get_level(wp.maxlevel)
# reconstructed data will have modified dtype
r = wp.reconstruct(False)
assert_equal(r.dtype, transform_dtype)
assert_allclose(r, x.astype(transform_dtype), atol=1e-5, rtol=1e-5)
def test_db3_roundtrip():
original = np.arange(512)
wp = pywt.WaveletPacket(data=original, wavelet='db3', mode='smooth',
maxlevel=3)
r = wp.reconstruct()
assert_allclose(original, r, atol=1e-12, rtol=1e-12)

View file

@ -0,0 +1,177 @@
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
import numpy as np
from numpy.testing import (assert_allclose, assert_, assert_raises,
assert_equal)
import pywt
def test_traversing_tree_2d():
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
assert_(np.all(wp.data == x))
assert_(wp.path == '')
assert_(wp.level == 0)
assert_(wp.maxlevel == 3)
assert_allclose(wp['a'].data, np.array([[3., 7., 11., 15.]] * 4),
rtol=1e-12)
assert_allclose(wp['h'].data, np.zeros((4, 4)), rtol=1e-12, atol=1e-14)
assert_allclose(wp['v'].data, -np.ones((4, 4)), rtol=1e-12, atol=1e-14)
assert_allclose(wp['d'].data, np.zeros((4, 4)), rtol=1e-12, atol=1e-14)
assert_allclose(wp['aa'].data, np.array([[10., 26.]] * 2), rtol=1e-12)
assert_(wp['a']['a'].data is wp['aa'].data)
assert_allclose(wp['aaa'].data, np.array([[36.]]), rtol=1e-12)
assert_raises(IndexError, lambda: wp['aaaa'])
assert_raises(ValueError, lambda: wp['f'])
def test_accessing_node_atributes_2d():
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
assert_allclose(wp['av'].data, np.zeros((2, 2)) - 4, rtol=1e-12)
assert_(wp['av'].path == 'av')
assert_(wp['av'].node_name == 'v')
assert_(wp['av'].parent.path == 'a')
assert_allclose(wp['av'].parent.data, np.array([[3., 7., 11., 15.]] * 4),
rtol=1e-12)
assert_(wp['av'].level == 2)
assert_(wp['av'].maxlevel == 3)
assert_(wp['av'].mode == 'symmetric')
def test_collecting_nodes_2d():
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
assert_(len(wp.get_level(0)) == 1)
assert_(wp.get_level(0)[0].path == '')
# First level
assert_(len(wp.get_level(1)) == 4)
assert_([node.path for node in wp.get_level(1)] == ['a', 'h', 'v', 'd'])
# Second level
assert_(len(wp.get_level(2)) == 16)
paths = [node.path for node in wp.get_level(2)]
expected_paths = ['aa', 'ah', 'av', 'ad', 'ha', 'hh', 'hv', 'hd', 'va',
'vh', 'vv', 'vd', 'da', 'dh', 'dv', 'dd']
assert_(paths == expected_paths)
# Third level.
assert_(len(wp.get_level(3)) == 64)
paths = [node.path for node in wp.get_level(3)]
expected_paths = ['aaa', 'aah', 'aav', 'aad', 'aha', 'ahh', 'ahv', 'ahd',
'ava', 'avh', 'avv', 'avd', 'ada', 'adh', 'adv', 'add',
'haa', 'hah', 'hav', 'had', 'hha', 'hhh', 'hhv', 'hhd',
'hva', 'hvh', 'hvv', 'hvd', 'hda', 'hdh', 'hdv', 'hdd',
'vaa', 'vah', 'vav', 'vad', 'vha', 'vhh', 'vhv', 'vhd',
'vva', 'vvh', 'vvv', 'vvd', 'vda', 'vdh', 'vdv', 'vdd',
'daa', 'dah', 'dav', 'dad', 'dha', 'dhh', 'dhv', 'dhd',
'dva', 'dvh', 'dvv', 'dvd', 'dda', 'ddh', 'ddv', 'ddd']
assert_(paths == expected_paths)
def test_data_reconstruction_2d():
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
new_wp = pywt.WaveletPacket2D(data=None, wavelet='db1', mode='symmetric')
new_wp['vh'] = wp['vh'].data
new_wp['vv'] = wp['vh'].data
new_wp['vd'] = np.zeros((2, 2), dtype=np.float64)
new_wp['a'] = [[3.0, 7.0, 11.0, 15.0]] * 4
new_wp['d'] = np.zeros((4, 4), dtype=np.float64)
new_wp['h'] = wp['h'] # all zeros
assert_allclose(new_wp.reconstruct(update=False),
np.array([[1.5, 1.5, 3.5, 3.5, 5.5, 5.5, 7.5, 7.5]] * 8),
rtol=1e-12)
assert_allclose(wp['va'].data, np.zeros((2, 2)) - 2, rtol=1e-12)
new_wp['va'] = wp['va'].data
assert_allclose(new_wp.reconstruct(update=False), x, rtol=1e-12)
def test_data_reconstruction_delete_nodes_2d():
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
new_wp = pywt.WaveletPacket2D(data=None, wavelet='db1', mode='symmetric')
new_wp['vh'] = wp['vh'].data
new_wp['vv'] = wp['vh'].data
new_wp['vd'] = np.zeros((2, 2), dtype=np.float64)
new_wp['a'] = [[3.0, 7.0, 11.0, 15.0]] * 4
new_wp['d'] = np.zeros((4, 4), dtype=np.float64)
new_wp['h'] = wp['h'] # all zeros
assert_allclose(new_wp.reconstruct(update=False),
np.array([[1.5, 1.5, 3.5, 3.5, 5.5, 5.5, 7.5, 7.5]] * 8),
rtol=1e-12)
new_wp['va'] = wp['va'].data
assert_allclose(new_wp.reconstruct(update=False), x, rtol=1e-12)
del(new_wp['va'])
new_wp['va'] = wp['va'].data
assert_(new_wp.data is None)
assert_allclose(new_wp.reconstruct(update=True), x, rtol=1e-12)
assert_allclose(new_wp.data, x, rtol=1e-12)
# TODO: decompose=True
def test_lazy_evaluation_2D():
# Note: internal implementation detail not to be relied on. Testing for
# now for backwards compatibility, but this test may be broken in needed.
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8)
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
assert_(wp.a is None)
assert_allclose(wp['a'].data, np.array([[3., 7., 11., 15.]] * 4),
rtol=1e-12)
assert_allclose(wp.a.data, np.array([[3., 7., 11., 15.]] * 4), rtol=1e-12)
assert_allclose(wp.d.data, np.zeros((4, 4)), rtol=1e-12, atol=1e-12)
def test_wavelet_packet_dtypes():
shape = (16, 16)
for dtype in [np.float32, np.float64, np.complex64, np.complex128]:
x = np.random.randn(*shape).astype(dtype)
if np.iscomplexobj(x):
x = x + 1j*np.random.randn(*shape).astype(x.real.dtype)
wp = pywt.WaveletPacket2D(data=x, wavelet='db1', mode='symmetric')
# no unnecessary copy made
assert_(wp.data is x)
# assiging to a node should not change supported dtypes
wp['d'] = wp['d'].data
assert_equal(wp['d'].data.dtype, x.dtype)
# full decomposition
wp.get_level(wp.maxlevel)
# reconstruction from coefficients should preserve dtype
r = wp.reconstruct(False)
assert_equal(r.dtype, x.dtype)
assert_allclose(r, x, atol=1e-5, rtol=1e-5)
def test_2d_roundtrip():
# test case corresponding to PyWavelets issue 447
original = pywt.data.camera()
wp = pywt.WaveletPacket2D(data=original, wavelet='db3', mode='smooth',
maxlevel=3)
r = wp.reconstruct()
assert_allclose(original, r, atol=1e-12, rtol=1e-12)

View file

@ -0,0 +1,10 @@
# THIS FILE IS GENERATED FROM PYWAVELETS SETUP.PY
short_version = '1.1.1'
version = '1.1.1'
full_version = '1.1.1'
git_revision = '7b2f66b0ef9196fa91bba550a81d1870f5933a36'
release = True
if not release:
version = full_version