# Copyright (c) 2006-2012 Filip Wasilewski
# Copyright (c) 2012-2016 The PyWavelets Developers
#
# See COPYING for license details.
"""
Other wavelet related functions.
"""
from __future__ import division, print_function, absolute_import
import warnings
import numpy as np
from numpy.fft import fft
from ._extensions._pywt import DiscreteContinuousWavelet, Wavelet, ContinuousWavelet
# Names exported by ``from <module> import *``.  The last four entries
# (`intwave`, `centrfrq`, `scal2frq`, `orthfilt`) are the deprecated aliases
# defined below; each one warns and forwards to its renamed counterpart.
__all__ = ["integrate_wavelet", "central_frequency", "scale2frequency", "qmf",
"orthogonal_filter_bank",
"intwave", "centrfrq", "scal2frq", "orthfilt"]
# Template for the DeprecationWarning text used by the deprecated aliases;
# filled in with ``.format(old=..., new=...)``.
_DEPRECATION_MSG = ("`{old}` has been renamed to `{new}` and will "
"be removed in a future version of pywt.")
def _integrate(arr, step):
    """
    Cumulative rectangle-rule integral of uniformly sampled values.

    Parameters
    ----------
    arr : array_like
        Samples of the function to integrate.
    step : scalar
        Spacing between consecutive samples.

    Returns
    -------
    ndarray
        Running sum of `arr` multiplied by `step`.
    """
    # Multiply out of place: an in-place `*=` on the cumsum result raises a
    # casting error when `arr` is integer-valued and `step` is a float,
    # because the float product cannot be stored back into an integer array.
    return np.cumsum(arr) * step
def intwave(*args, **kwargs):
    """
    Deprecated alias of `integrate_wavelet`.

    Emits a ``DeprecationWarning`` and forwards all positional and keyword
    arguments unchanged to `integrate_wavelet`, returning its result.
    """
    msg = _DEPRECATION_MSG.format(old='intwave', new='integrate_wavelet')
    # stacklevel=2 attributes the warning to the caller's line rather than
    # to this wrapper, so it is reported against the code that needs updating.
    warnings.warn(msg, DeprecationWarning, stacklevel=2)
    return integrate_wavelet(*args, **kwargs)
def centrfrq(*args, **kwargs):
    """
    Deprecated alias of `central_frequency`.

    Emits a ``DeprecationWarning`` and forwards all positional and keyword
    arguments unchanged to `central_frequency`, returning its result.
    """
    msg = _DEPRECATION_MSG.format(old='centrfrq', new='central_frequency')
    # stacklevel=2 attributes the warning to the caller's line rather than
    # to this wrapper, so it is reported against the code that needs updating.
    warnings.warn(msg, DeprecationWarning, stacklevel=2)
    return central_frequency(*args, **kwargs)
def scal2frq(*args, **kwargs):
    """
    Deprecated alias of `scale2frequency`.

    Emits a ``DeprecationWarning`` and forwards all positional and keyword
    arguments unchanged to `scale2frequency`, returning its result.
    """
    msg = _DEPRECATION_MSG.format(old='scal2frq', new='scale2frequency')
    # stacklevel=2 attributes the warning to the caller's line rather than
    # to this wrapper, so it is reported against the code that needs updating.
    warnings.warn(msg, DeprecationWarning, stacklevel=2)
    return scale2frequency(*args, **kwargs)
def orthfilt(*args, **kwargs):
    """
    Deprecated alias of `orthogonal_filter_bank`.

    Emits a ``DeprecationWarning`` and forwards all positional and keyword
    arguments unchanged to `orthogonal_filter_bank`, returning its result.
    """
    msg = _DEPRECATION_MSG.format(old='orthfilt', new='orthogonal_filter_bank')
    # stacklevel=2 attributes the warning to the caller's line rather than
    # to this wrapper, so it is reported against the code that needs updating.
    warnings.warn(msg, DeprecationWarning, stacklevel=2)
    return orthogonal_filter_bank(*args, **kwargs)
def integrate_wavelet(wavelet, precision=8):
    """
    Integrate `psi` wavelet function from -Inf to x using the rectangle
    integration method.

    Parameters
    ----------
    wavelet : Wavelet instance or str
        Wavelet to integrate.  If a string, should be the name of a wavelet.
        A ``(psi, x)`` tuple or list of sampled values is still accepted,
        but integrating a general signal this way is deprecated.
    precision : int, optional
        Precision that will be used for wavelet function
        approximation computed with the wavefun(level=precision)
        Wavelet's method (default: 8).

    Returns
    -------
    [int_psi, x] :
        for orthogonal wavelets (also returned when ``wavefun`` yields only
        ``(psi, x)`` and for the deprecated ``(psi, x)`` input)
    [int_psi_d, int_psi_r, x] :
        for other wavelets

    Examples
    --------
    >>> from pywt import Wavelet, integrate_wavelet
    >>> wavelet1 = Wavelet('db2')
    >>> [int_psi, x] = integrate_wavelet(wavelet1, precision=5)
    >>> wavelet2 = Wavelet('bior1.3')
    >>> [int_psi_d, int_psi_r, x] = integrate_wavelet(wavelet2, precision=5)
    """
    # FIXME: this function should really use scipy.integrate.quad

    # Deprecated path: the caller passed sampled values instead of a wavelet.
    # isinstance (rather than an exact type match) also accepts tuple/list
    # subclasses such as namedtuples.
    if isinstance(wavelet, (tuple, list)):
        msg = ("Integration of a general signal is deprecated "
               "and will be removed in a future version of pywt.")
        # stacklevel=2 reports the warning against the calling code.
        warnings.warn(msg, DeprecationWarning, stacklevel=2)
        psi, x = np.asarray(wavelet[0]), np.asarray(wavelet[1])
        # The grid is taken to be uniform: the spacing is read from the
        # first two points only.
        step = x[1] - x[0]
        return _integrate(psi, step), x

    if not isinstance(wavelet, (Wavelet, ContinuousWavelet)):
        wavelet = DiscreteContinuousWavelet(wavelet)

    functions_approximations = wavelet.wavefun(precision)

    if len(functions_approximations) == 2:      # continuous wavelet
        psi, x = functions_approximations
        step = x[1] - x[0]
        return _integrate(psi, step), x

    elif len(functions_approximations) == 3:    # orthogonal wavelet
        phi, psi, x = functions_approximations
        step = x[1] - x[0]
        return _integrate(psi, step), x

    else:                                       # biorthogonal wavelet
        phi_d, psi_d, phi_r, psi_r, x = functions_approximations
        step = x[1] - x[0]
        return _integrate(psi_d, step), _integrate(psi_r, step), x
def central_frequency(wavelet, precision=8):
    """
    Computes the central frequency of the `psi` wavelet function.

    Parameters
    ----------
    wavelet : Wavelet instance or str
        Wavelet whose central frequency is computed.  Anything that is not
        already a `Wavelet` or `ContinuousWavelet` instance is passed to
        `DiscreteContinuousWavelet` to build one.
    precision : int, optional
        Level passed to ``wavelet.wavefun`` to obtain the sampled `psi`
        function (default: 8).

    Returns
    -------
    scalar
        Frequency of the strongest non-constant FFT component of `psi`,
        expressed in cycles per unit of the `x` grid returned by ``wavefun``.
    """
    if not isinstance(wavelet, (Wavelet, ContinuousWavelet)):
        wavelet = DiscreteContinuousWavelet(wavelet)

    approximations = wavelet.wavefun(precision)

    # The sample grid is always the last item.  `psi` comes first in the
    # two-item result (psi, x) and second in the longer results
    # (phi, psi, x) and (phi_d, psi_d, phi_r, psi_r, x).
    x = approximations[-1]
    psi = approximations[0] if len(approximations) == 2 else approximations[1]

    domain = float(x[-1] - x[0])
    assert domain > 0

    # Strongest FFT bin, skipping bin 0 (the constant term).  The "+ 2" turns
    # the position in the sliced spectrum into a 1-based bin number.
    magnitudes = abs(fft(psi)[1:])
    index = np.argmax(magnitudes) + 2

    # A peak in the upper half of the spectrum is mirrored back onto the
    # lower half.
    n_samples = len(psi)
    if index > n_samples / 2:
        index = n_samples - index + 2

    period = domain / (index - 1)
    return 1.0 / period
def scale2frequency(wavelet, scale, precision=8):
    """
    Convert a wavelet scale to a frequency.

    The result is the wavelet's central frequency divided by `scale`.

    Parameters
    ----------
    wavelet : Wavelet instance or str
        Wavelet whose central frequency is used.  If a string, should be
        the name of a wavelet.
    scale : scalar
        Scale to convert.
    precision : int, optional
        Precision that will be used for wavelet function approximation
        computed with ``wavelet.wavefun(level=precision)``.  Default is 8.

    Returns
    -------
    freq : scalar
    """
    center = central_frequency(wavelet, precision=precision)
    return center / scale
def qmf(filt):
    """
    Returns the Quadrature Mirror Filter(QMF).

    The magnitude response of QMF is mirror image about `pi/2` of that of the
    input filter.

    Parameters
    ----------
    filt : array_like
        Input filter for which QMF needs to be computed.  It is copied, so
        the caller's data is left untouched.

    Returns
    -------
    qm_filter : ndarray
        Quadrature mirror of the input filter: the coefficients in reverse
        order with every second one (odd positions) sign-flipped.
    """
    # Work on a private copy, viewed back to front.
    mirrored = np.array(filt)[::-1]
    # Flip the sign of the odd-position coefficients in place.
    odd_taps = mirrored[1::2]
    np.negative(odd_taps, out=odd_taps)
    return mirrored
def orthogonal_filter_bank(scaling_filter):
    """
    Returns the orthogonal filter bank.

    The orthogonal filter bank consists of the HPFs and LPFs at
    decomposition and reconstruction stage for the input scaling filter.

    Parameters
    ----------
    scaling_filter : array_like
        Input scaling filter (father wavelet).  Must have even length.

    Returns
    -------
    orth_filt_bank : tuple of 4 ndarrays
        The orthogonal filter bank of the input scaling filter in the order :
        1] Decomposition LPF
        2] Decomposition HPF
        3] Reconstruction LPF
        4] Reconstruction HPF

    Raises
    ------
    ValueError
        If the length of `scaling_filter` is odd.
    """
    if len(scaling_filter) % 2 != 0:
        raise ValueError("`scaling_filter` length has to be even.")

    coeffs = np.asarray(scaling_filter, dtype=np.float64)

    # Reconstruction low-pass: the input rescaled so that it sums to sqrt(2).
    rec_lo = np.sqrt(2) * coeffs / np.sum(coeffs)
    # Reconstruction high-pass: quadrature mirror of the low-pass filter.
    rec_hi = qmf(rec_lo)

    # Each decomposition filter is its reconstruction filter reversed.
    return (rec_lo[::-1], rec_hi[::-1], rec_lo, rec_hi)