203 lines
7.5 KiB
Python
203 lines
7.5 KiB
Python
from math import floor, ceil
|
|
|
|
from ._extensions._pywt import (DiscreteContinuousWavelet, ContinuousWavelet,
|
|
Wavelet, _check_dtype)
|
|
from ._functions import integrate_wavelet, scale2frequency
|
|
|
|
|
|
__all__ = ["cwt"]
|
|
|
|
|
|
import numpy as np
|
|
|
|
# Select the FFT backend once, at import time.  Preference order:
# scipy.fft, then scipy.fftpack, then numpy.fft (with a local
# ``next_fast_len`` replacement).
try:
    # Prefer scipy.fft (new in SciPy 1.4)
    import scipy.fft
except ImportError:
    try:
        import scipy.fftpack
    except ImportError:
        # provide a fallback so scipy is an optional requirement
        fftmodule = np.fft

        def next_fast_len(n):
            """Return the smallest power of two that is not below `n`.

            Power-of-two lengths let the FFT run quickly.  This fallback is
            less efficient than `scipy.fftpack.next_fast_len`, which may
            pick a shorter padded length.
            """
            exponent = ceil(np.log2(n))
            return 2 ** exponent
    else:
        fftmodule = scipy.fftpack
        next_fast_len = fftmodule.next_fast_len
else:
    fftmodule = scipy.fft
    next_fast_len = fftmodule.next_fast_len
|
|
|
|
|
|
def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
    """
    cwt(data, scales, wavelet)

    One dimensional Continuous Wavelet Transform.

    Parameters
    ----------
    data : array_like
        Input signal
    scales : array_like
        The wavelet scales to use. A scalar (or 0-d array) is treated as a
        single scale. All scales must be strictly positive. One can use
        ``f = scale2frequency(wavelet, scale)/sampling_period`` to determine
        the physical frequency ``f`` matching a scale. Here, ``f`` is in
        hertz when the ``sampling_period`` is given in seconds.
    wavelet : Wavelet object or name
        Wavelet to use
    sampling_period : float
        Sampling period for the frequencies output (optional).
        The values computed for ``coefs`` are independent of the choice of
        ``sampling_period`` (i.e. ``scales`` is not scaled by the sampling
        period).
    method : {'conv', 'fft'}, optional
        The method used to compute the CWT. Can be any of:
            - ``conv`` uses ``numpy.convolve``.
            - ``fft`` uses frequency domain convolution.
        Any other value (including ``auto``, which is not implemented)
        raises a ``ValueError``.

        The ``conv`` method complexity is ``O(len(scale) * len(data))``.
        The ``fft`` method is ``O(N * log2(N))`` with
        ``N = len(scale) + len(data) - 1``. It is well suited for large size
        signals but slightly slower than ``conv`` on small ones.
    axis: int, optional
        Axis over which to compute the CWT. If not given, the last axis is
        used.

    Returns
    -------
    coefs : array_like
        Continuous wavelet transform of the input signal for the given scales
        and wavelet. The first axis of ``coefs`` corresponds to the scales.
        The remaining axes match the shape of ``data``.
    frequencies : array_like
        If the unit of sampling period are seconds and given, then frequencies
        are in hertz. Otherwise, a sampling period of 1 is assumed.

    Raises
    ------
    ValueError
        If ``scales`` contains a non-positive value, if ``axis`` is not a
        scalar, if ``method`` is neither ``'conv'`` nor ``'fft'``, or if a
        scale is too small for the resulting filter to cover the data.

    Notes
    -----
    Size of coefficients arrays depends on the length of the input array and
    the length of given scales.

    Examples
    --------
    >>> import pywt
    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> x = np.arange(512)
    >>> y = np.sin(2*np.pi*x/32)
    >>> coef, freqs=pywt.cwt(y,np.arange(1,129),'gaus1')
    >>> plt.matshow(coef) # doctest: +SKIP
    >>> plt.show() # doctest: +SKIP
    ----------
    >>> import pywt
    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> t = np.linspace(-1, 1, 200, endpoint=False)
    >>> sig  = np.cos(2 * np.pi * 7 * t) + np.real(np.exp(-7*(t-0.4)**2)*np.exp(1j*2*np.pi*2*(t-0.4)))
    >>> widths = np.arange(1, 31)
    >>> cwtmatr, freqs = pywt.cwt(sig, widths, 'mexh')
    >>> plt.imshow(cwtmatr, extent=[-1, 1, 1, 31], cmap='PRGn', aspect='auto',
    ...            vmax=abs(cwtmatr).max(), vmin=-abs(cwtmatr).max())  # doctest: +SKIP
    >>> plt.show() # doctest: +SKIP
    """

    # accept array_like input; np.asarray converts to the working dtype and
    # only copies when a conversion is actually required
    dt = _check_dtype(data)
    data = np.asarray(data, dtype=dt)
    dt_cplx = np.result_type(dt, np.complex64)
    if not isinstance(wavelet, (ContinuousWavelet, Wavelet)):
        wavelet = DiscreteContinuousWavelet(wavelet)

    # Normalise scales to an ndarray with at least one dimension.  This
    # covers Python scalars, 0-d arrays (which np.isscalar does not
    # recognise) and plain sequences such as lists or tuples.
    scales = np.atleast_1d(scales)
    if np.any(scales <= 0):
        # A non-positive scale would otherwise fail with an obscure indexing
        # error or silently produce NaN coefficients.
        raise ValueError("`scales` must only include positive values")
    if not np.isscalar(axis):
        raise ValueError("axis must be a scalar.")

    # complex wavelets always produce complex coefficients
    dt_out = dt_cplx if wavelet.complex_cwt else dt
    out = np.empty((np.size(scales),) + data.shape, dtype=dt_out)
    # precision handed to both integrate_wavelet and scale2frequency
    precision = 10
    int_psi, x = integrate_wavelet(wavelet, precision=precision)
    int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi

    # convert int_psi, x to the same precision as the data
    dt_psi = dt_cplx if int_psi.dtype.kind == 'c' else dt
    int_psi = np.asarray(int_psi, dtype=dt_psi)
    x = np.asarray(x, dtype=data.real.dtype)

    if method == 'fft':
        # cache for the FFT of the data, reused while the padded size is
        # unchanged from one scale to the next
        size_scale0 = -1
        fft_data = None
    elif method != 'conv':
        raise ValueError("method must be 'conv' or 'fft'")

    if data.ndim > 1:
        # move axis to be transformed last (so it is contiguous)
        data = data.swapaxes(-1, axis)

        # reshape to (n_batch, data.shape[-1])
        data_shape_pre = data.shape
        data = data.reshape((-1, data.shape[-1]))

    # Grid spacing and extent of the wavelet support do not depend on the
    # scale, so compute them once outside the loop.
    step = x[1] - x[0]
    x_range = x[-1] - x[0]

    for i, scale in enumerate(scales):
        # indices into int_psi that resample it for the current scale
        j = np.arange(scale * x_range + 1) / (scale * step)
        j = j.astype(int)  # floor
        if j[-1] >= int_psi.size:
            j = np.extract(j < int_psi.size, j)
        # reversed so that the convolution below uses the flipped filter
        int_psi_scale = int_psi[j][::-1]

        if method == 'conv':
            if data.ndim == 1:
                conv = np.convolve(data, int_psi_scale)
            else:
                # batch convolution via loop
                conv_shape = list(data.shape)
                conv_shape[-1] += int_psi_scale.size - 1
                conv_shape = tuple(conv_shape)
                conv = np.empty(conv_shape, dtype=dt_out)
                for n in range(data.shape[0]):
                    conv[n, :] = np.convolve(data[n], int_psi_scale)
        else:
            # The padding is selected for:
            # - optimal FFT complexity
            # - to be larger than the two signals length to avoid circular
            #   convolution
            size_scale = next_fast_len(
                data.shape[-1] + int_psi_scale.size - 1
            )
            if size_scale != size_scale0:
                # Must recompute fft_data when the padding size changes.
                fft_data = fftmodule.fft(data, size_scale, axis=-1)
            size_scale0 = size_scale
            fft_wav = fftmodule.fft(int_psi_scale, size_scale, axis=-1)
            conv = fftmodule.ifft(fft_wav * fft_data, axis=-1)
            # drop the padding so the length matches the 'conv' result
            conv = conv[..., :data.shape[-1] + int_psi_scale.size - 1]

        # finite difference of the convolution, weighted by sqrt(scale)
        coef = - np.sqrt(scale) * np.diff(conv, axis=-1)
        if out.dtype.kind != 'c':
            coef = coef.real
        # transform axis is always -1 due to the data reshape above
        d = (coef.shape[-1] - data.shape[-1]) / 2.
        if d > 0:
            # trim both ends so coef has the same length as the data
            coef = coef[..., floor(d):-ceil(d)]
        elif d < 0:
            raise ValueError(
                "Selected scale of {} too small.".format(scale))
        if data.ndim > 1:
            # restore original data shape and axis position
            coef = coef.reshape(data_shape_pre)
            coef = coef.swapaxes(axis, -1)
        out[i, ...] = coef

    frequencies = scale2frequency(wavelet, scales, precision)
    if np.isscalar(frequencies):
        frequencies = np.array([frequencies])
    frequencies /= sampling_period
    return out, frequencies
|