Fixed database typo and removed unnecessary class identifier.
This commit is contained in:
parent
00ad49a143
commit
45fb349a7d
5098 changed files with 952558 additions and 85 deletions
25
venv/Lib/site-packages/scipy/fft/_pocketfft/LICENSE.md
Normal file
25
venv/Lib/site-packages/scipy/fft/_pocketfft/LICENSE.md
Normal file
|
@ -0,0 +1,25 @@
|
|||
Copyright (C) 2010-2019 Max-Planck-Society
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification,
|
||||
are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above copyright notice, this
|
||||
list of conditions and the following disclaimer in the documentation and/or
|
||||
other materials provided with the distribution.
|
||||
* Neither the name of the copyright holder nor the names of its contributors may
|
||||
be used to endorse or promote products derived from this software without
|
||||
specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
|
||||
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
||||
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
9
venv/Lib/site-packages/scipy/fft/_pocketfft/__init__.py
Normal file
9
venv/Lib/site-packages/scipy/fft/_pocketfft/__init__.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
""" FFT backend using pypocketfft """
|
||||
|
||||
from .basic import *
|
||||
from .realtransforms import *
|
||||
from .helper import *
|
||||
|
||||
from scipy._lib._testutils import PytestTester
|
||||
test = PytestTester(__name__)
|
||||
del PytestTester
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
297
venv/Lib/site-packages/scipy/fft/_pocketfft/basic.py
Normal file
297
venv/Lib/site-packages/scipy/fft/_pocketfft/basic.py
Normal file
|
@ -0,0 +1,297 @@
|
|||
"""
|
||||
Discrete Fourier Transforms - basic.py
|
||||
"""
|
||||
import numpy as np
|
||||
import functools
|
||||
from . import pypocketfft as pfft
|
||||
from .helper import (_asfarray, _init_nd_shape_and_axes, _datacopied,
|
||||
_fix_shape, _fix_shape_1d, _normalization,
|
||||
_workers)
|
||||
|
||||
def c2c(forward, x, n=None, axis=-1, norm=None, overwrite_x=False,
|
||||
workers=None, *, plan=None):
|
||||
""" Return discrete Fourier transform of real or complex sequence. """
|
||||
if plan is not None:
|
||||
raise NotImplementedError('Passing a precomputed plan is not yet '
|
||||
'supported by scipy.fft functions')
|
||||
tmp = _asfarray(x)
|
||||
overwrite_x = overwrite_x or _datacopied(tmp, x)
|
||||
norm = _normalization(norm, forward)
|
||||
workers = _workers(workers)
|
||||
|
||||
if n is not None:
|
||||
tmp, copied = _fix_shape_1d(tmp, n, axis)
|
||||
overwrite_x = overwrite_x or copied
|
||||
elif tmp.shape[axis] < 1:
|
||||
raise ValueError("invalid number of data points ({0}) specified"
|
||||
.format(tmp.shape[axis]))
|
||||
|
||||
out = (tmp if overwrite_x and tmp.dtype.kind == 'c' else None)
|
||||
|
||||
return pfft.c2c(tmp, (axis,), forward, norm, out, workers)
|
||||
|
||||
|
||||
fft = functools.partial(c2c, True)
|
||||
fft.__name__ = 'fft'
|
||||
ifft = functools.partial(c2c, False)
|
||||
ifft.__name__ = 'ifft'
|
||||
|
||||
|
||||
def r2c(forward, x, n=None, axis=-1, norm=None, overwrite_x=False,
|
||||
workers=None, *, plan=None):
|
||||
"""
|
||||
Discrete Fourier transform of a real sequence.
|
||||
"""
|
||||
if plan is not None:
|
||||
raise NotImplementedError('Passing a precomputed plan is not yet '
|
||||
'supported by scipy.fft functions')
|
||||
tmp = _asfarray(x)
|
||||
norm = _normalization(norm, forward)
|
||||
workers = _workers(workers)
|
||||
|
||||
if not np.isrealobj(tmp):
|
||||
raise TypeError("x must be a real sequence")
|
||||
|
||||
if n is not None:
|
||||
tmp, _ = _fix_shape_1d(tmp, n, axis)
|
||||
elif tmp.shape[axis] < 1:
|
||||
raise ValueError("invalid number of data points ({0}) specified"
|
||||
.format(tmp.shape[axis]))
|
||||
|
||||
# Note: overwrite_x is not utilised
|
||||
return pfft.r2c(tmp, (axis,), forward, norm, None, workers)
|
||||
|
||||
|
||||
rfft = functools.partial(r2c, True)
|
||||
rfft.__name__ = 'rfft'
|
||||
ihfft = functools.partial(r2c, False)
|
||||
ihfft.__name__ = 'ihfft'
|
||||
|
||||
|
||||
def c2r(forward, x, n=None, axis=-1, norm=None, overwrite_x=False,
|
||||
workers=None, *, plan=None):
|
||||
"""
|
||||
Return inverse discrete Fourier transform of real sequence x.
|
||||
"""
|
||||
if plan is not None:
|
||||
raise NotImplementedError('Passing a precomputed plan is not yet '
|
||||
'supported by scipy.fft functions')
|
||||
tmp = _asfarray(x)
|
||||
norm = _normalization(norm, forward)
|
||||
workers = _workers(workers)
|
||||
|
||||
# TODO: Optimize for hermitian and real?
|
||||
if np.isrealobj(tmp):
|
||||
tmp = tmp + 0.j
|
||||
|
||||
# Last axis utilizes hermitian symmetry
|
||||
if n is None:
|
||||
n = (tmp.shape[axis] - 1) * 2
|
||||
if n < 1:
|
||||
raise ValueError("Invalid number of data points ({0}) specified"
|
||||
.format(n))
|
||||
else:
|
||||
tmp, _ = _fix_shape_1d(tmp, (n//2) + 1, axis)
|
||||
|
||||
# Note: overwrite_x is not utilized
|
||||
return pfft.c2r(tmp, (axis,), n, forward, norm, None, workers)
|
||||
|
||||
|
||||
hfft = functools.partial(c2r, True)
|
||||
hfft.__name__ = 'hfft'
|
||||
irfft = functools.partial(c2r, False)
|
||||
irfft.__name__ = 'irfft'
|
||||
|
||||
|
||||
def fft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None,
|
||||
*, plan=None):
|
||||
"""
|
||||
2-D discrete Fourier transform.
|
||||
"""
|
||||
if plan is not None:
|
||||
raise NotImplementedError('Passing a precomputed plan is not yet '
|
||||
'supported by scipy.fft functions')
|
||||
return fftn(x, s, axes, norm, overwrite_x, workers)
|
||||
|
||||
|
||||
def ifft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None,
|
||||
*, plan=None):
|
||||
"""
|
||||
2-D discrete inverse Fourier transform of real or complex sequence.
|
||||
"""
|
||||
if plan is not None:
|
||||
raise NotImplementedError('Passing a precomputed plan is not yet '
|
||||
'supported by scipy.fft functions')
|
||||
return ifftn(x, s, axes, norm, overwrite_x, workers)
|
||||
|
||||
|
||||
def rfft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None,
|
||||
*, plan=None):
|
||||
"""
|
||||
2-D discrete Fourier transform of a real sequence
|
||||
"""
|
||||
if plan is not None:
|
||||
raise NotImplementedError('Passing a precomputed plan is not yet '
|
||||
'supported by scipy.fft functions')
|
||||
return rfftn(x, s, axes, norm, overwrite_x, workers)
|
||||
|
||||
|
||||
def irfft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None,
|
||||
*, plan=None):
|
||||
"""
|
||||
2-D discrete inverse Fourier transform of a real sequence
|
||||
"""
|
||||
if plan is not None:
|
||||
raise NotImplementedError('Passing a precomputed plan is not yet '
|
||||
'supported by scipy.fft functions')
|
||||
return irfftn(x, s, axes, norm, overwrite_x, workers)
|
||||
|
||||
|
||||
def hfft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None,
|
||||
*, plan=None):
|
||||
"""
|
||||
2-D discrete Fourier transform of a Hermitian sequence
|
||||
"""
|
||||
if plan is not None:
|
||||
raise NotImplementedError('Passing a precomputed plan is not yet '
|
||||
'supported by scipy.fft functions')
|
||||
return hfftn(x, s, axes, norm, overwrite_x, workers)
|
||||
|
||||
|
||||
def ihfft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None,
|
||||
*, plan=None):
|
||||
"""
|
||||
2-D discrete inverse Fourier transform of a Hermitian sequence
|
||||
"""
|
||||
if plan is not None:
|
||||
raise NotImplementedError('Passing a precomputed plan is not yet '
|
||||
'supported by scipy.fft functions')
|
||||
return ihfftn(x, s, axes, norm, overwrite_x, workers)
|
||||
|
||||
|
||||
def c2cn(forward, x, s=None, axes=None, norm=None, overwrite_x=False,
|
||||
workers=None, *, plan=None):
|
||||
"""
|
||||
Return multidimensional discrete Fourier transform.
|
||||
"""
|
||||
if plan is not None:
|
||||
raise NotImplementedError('Passing a precomputed plan is not yet '
|
||||
'supported by scipy.fft functions')
|
||||
tmp = _asfarray(x)
|
||||
|
||||
shape, axes = _init_nd_shape_and_axes(tmp, s, axes)
|
||||
overwrite_x = overwrite_x or _datacopied(tmp, x)
|
||||
workers = _workers(workers)
|
||||
|
||||
if len(axes) == 0:
|
||||
return x
|
||||
|
||||
tmp, copied = _fix_shape(tmp, shape, axes)
|
||||
overwrite_x = overwrite_x or copied
|
||||
|
||||
norm = _normalization(norm, forward)
|
||||
out = (tmp if overwrite_x and tmp.dtype.kind == 'c' else None)
|
||||
|
||||
return pfft.c2c(tmp, axes, forward, norm, out, workers)
|
||||
|
||||
|
||||
fftn = functools.partial(c2cn, True)
|
||||
fftn.__name__ = 'fftn'
|
||||
ifftn = functools.partial(c2cn, False)
|
||||
ifftn.__name__ = 'ifftn'
|
||||
|
||||
def r2cn(forward, x, s=None, axes=None, norm=None, overwrite_x=False,
|
||||
workers=None, *, plan=None):
|
||||
"""Return multidimensional discrete Fourier transform of real input"""
|
||||
if plan is not None:
|
||||
raise NotImplementedError('Passing a precomputed plan is not yet '
|
||||
'supported by scipy.fft functions')
|
||||
tmp = _asfarray(x)
|
||||
|
||||
if not np.isrealobj(tmp):
|
||||
raise TypeError("x must be a real sequence")
|
||||
|
||||
shape, axes = _init_nd_shape_and_axes(tmp, s, axes)
|
||||
tmp, _ = _fix_shape(tmp, shape, axes)
|
||||
norm = _normalization(norm, forward)
|
||||
workers = _workers(workers)
|
||||
|
||||
if len(axes) == 0:
|
||||
raise ValueError("at least 1 axis must be transformed")
|
||||
|
||||
# Note: overwrite_x is not utilized
|
||||
return pfft.r2c(tmp, axes, forward, norm, None, workers)
|
||||
|
||||
|
||||
rfftn = functools.partial(r2cn, True)
|
||||
rfftn.__name__ = 'rfftn'
|
||||
ihfftn = functools.partial(r2cn, False)
|
||||
ihfftn.__name__ = 'ihfftn'
|
||||
|
||||
|
||||
def c2rn(forward, x, s=None, axes=None, norm=None, overwrite_x=False,
|
||||
workers=None, *, plan=None):
|
||||
"""Multidimensional inverse discrete fourier transform with real output"""
|
||||
if plan is not None:
|
||||
raise NotImplementedError('Passing a precomputed plan is not yet '
|
||||
'supported by scipy.fft functions')
|
||||
tmp = _asfarray(x)
|
||||
|
||||
# TODO: Optimize for hermitian and real?
|
||||
if np.isrealobj(tmp):
|
||||
tmp = tmp + 0.j
|
||||
|
||||
noshape = s is None
|
||||
shape, axes = _init_nd_shape_and_axes(tmp, s, axes)
|
||||
|
||||
if len(axes) == 0:
|
||||
raise ValueError("at least 1 axis must be transformed")
|
||||
|
||||
if noshape:
|
||||
shape[-1] = (x.shape[axes[-1]] - 1) * 2
|
||||
|
||||
norm = _normalization(norm, forward)
|
||||
workers = _workers(workers)
|
||||
|
||||
# Last axis utilizes hermitian symmetry
|
||||
lastsize = shape[-1]
|
||||
shape[-1] = (shape[-1] // 2) + 1
|
||||
|
||||
tmp, _ = _fix_shape(tmp, shape, axes)
|
||||
|
||||
# Note: overwrite_x is not utilized
|
||||
return pfft.c2r(tmp, axes, lastsize, forward, norm, None, workers)
|
||||
|
||||
|
||||
hfftn = functools.partial(c2rn, True)
|
||||
hfftn.__name__ = 'hfftn'
|
||||
irfftn = functools.partial(c2rn, False)
|
||||
irfftn.__name__ = 'irfftn'
|
||||
|
||||
|
||||
def r2r_fftpack(forward, x, n=None, axis=-1, norm=None, overwrite_x=False):
|
||||
"""FFT of a real sequence, returning fftpack half complex format"""
|
||||
tmp = _asfarray(x)
|
||||
overwrite_x = overwrite_x or _datacopied(tmp, x)
|
||||
norm = _normalization(norm, forward)
|
||||
workers = _workers(None)
|
||||
|
||||
if tmp.dtype.kind == 'c':
|
||||
raise TypeError('x must be a real sequence')
|
||||
|
||||
if n is not None:
|
||||
tmp, copied = _fix_shape_1d(tmp, n, axis)
|
||||
overwrite_x = overwrite_x or copied
|
||||
elif tmp.shape[axis] < 1:
|
||||
raise ValueError("invalid number of data points ({0}) specified"
|
||||
.format(tmp.shape[axis]))
|
||||
|
||||
out = (tmp if overwrite_x else None)
|
||||
|
||||
return pfft.r2r_fftpack(tmp, (axis,), forward, forward, norm, out, workers)
|
||||
|
||||
|
||||
rfft_fftpack = functools.partial(r2r_fftpack, True)
|
||||
rfft_fftpack.__name__ = 'rfft_fftpack'
|
||||
irfft_fftpack = functools.partial(r2r_fftpack, False)
|
||||
irfft_fftpack.__name__ = 'irfft_fftpack'
|
213
venv/Lib/site-packages/scipy/fft/_pocketfft/helper.py
Normal file
213
venv/Lib/site-packages/scipy/fft/_pocketfft/helper.py
Normal file
|
@ -0,0 +1,213 @@
|
|||
from numbers import Number
|
||||
import operator
|
||||
import os
|
||||
import threading
|
||||
import contextlib
|
||||
|
||||
import numpy as np
|
||||
# good_size is exposed (and used) from this import
|
||||
from .pypocketfft import good_size
|
||||
|
||||
_config = threading.local()
|
||||
_cpu_count = os.cpu_count()
|
||||
|
||||
|
||||
def _iterable_of_int(x, name=None):
|
||||
"""Convert ``x`` to an iterable sequence of int
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : value, or sequence of values, convertible to int
|
||||
name : str, optional
|
||||
Name of the argument being converted, only used in the error message
|
||||
|
||||
Returns
|
||||
-------
|
||||
y : ``List[int]``
|
||||
"""
|
||||
if isinstance(x, Number):
|
||||
x = (x,)
|
||||
|
||||
try:
|
||||
x = [operator.index(a) for a in x]
|
||||
except TypeError as e:
|
||||
name = name or "value"
|
||||
raise ValueError("{} must be a scalar or iterable of integers"
|
||||
.format(name)) from e
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def _init_nd_shape_and_axes(x, shape, axes):
|
||||
"""Handles shape and axes arguments for nd transforms"""
|
||||
noshape = shape is None
|
||||
noaxes = axes is None
|
||||
|
||||
if not noaxes:
|
||||
axes = _iterable_of_int(axes, 'axes')
|
||||
axes = [a + x.ndim if a < 0 else a for a in axes]
|
||||
|
||||
if any(a >= x.ndim or a < 0 for a in axes):
|
||||
raise ValueError("axes exceeds dimensionality of input")
|
||||
if len(set(axes)) != len(axes):
|
||||
raise ValueError("all axes must be unique")
|
||||
|
||||
if not noshape:
|
||||
shape = _iterable_of_int(shape, 'shape')
|
||||
|
||||
if axes and len(axes) != len(shape):
|
||||
raise ValueError("when given, axes and shape arguments"
|
||||
" have to be of the same length")
|
||||
if noaxes:
|
||||
if len(shape) > x.ndim:
|
||||
raise ValueError("shape requires more axes than are present")
|
||||
axes = range(x.ndim - len(shape), x.ndim)
|
||||
|
||||
shape = [x.shape[a] if s == -1 else s for s, a in zip(shape, axes)]
|
||||
elif noaxes:
|
||||
shape = list(x.shape)
|
||||
axes = range(x.ndim)
|
||||
else:
|
||||
shape = [x.shape[a] for a in axes]
|
||||
|
||||
if any(s < 1 for s in shape):
|
||||
raise ValueError(
|
||||
"invalid number of data points ({0}) specified".format(shape))
|
||||
|
||||
return shape, axes
|
||||
|
||||
|
||||
def _asfarray(x):
|
||||
"""
|
||||
Convert to array with floating or complex dtype.
|
||||
|
||||
float16 values are also promoted to float32.
|
||||
"""
|
||||
if not hasattr(x, "dtype"):
|
||||
x = np.asarray(x)
|
||||
|
||||
if x.dtype == np.float16:
|
||||
return np.asarray(x, np.float32)
|
||||
elif x.dtype.kind not in 'fc':
|
||||
return np.asarray(x, np.float64)
|
||||
|
||||
# Require native byte order
|
||||
dtype = x.dtype.newbyteorder('=')
|
||||
# Always align input
|
||||
copy = not x.flags['ALIGNED']
|
||||
return np.array(x, dtype=dtype, copy=copy)
|
||||
|
||||
def _datacopied(arr, original):
|
||||
"""
|
||||
Strict check for `arr` not sharing any data with `original`,
|
||||
under the assumption that arr = asarray(original)
|
||||
"""
|
||||
if arr is original:
|
||||
return False
|
||||
if not isinstance(original, np.ndarray) and hasattr(original, '__array__'):
|
||||
return False
|
||||
return arr.base is None
|
||||
|
||||
|
||||
def _fix_shape(x, shape, axes):
|
||||
"""Internal auxiliary function for _raw_fft, _raw_fftnd."""
|
||||
must_copy = False
|
||||
|
||||
# Build an nd slice with the dimensions to be read from x
|
||||
index = [slice(None)]*x.ndim
|
||||
for n, ax in zip(shape, axes):
|
||||
if x.shape[ax] >= n:
|
||||
index[ax] = slice(0, n)
|
||||
else:
|
||||
index[ax] = slice(0, x.shape[ax])
|
||||
must_copy = True
|
||||
|
||||
index = tuple(index)
|
||||
|
||||
if not must_copy:
|
||||
return x[index], False
|
||||
|
||||
s = list(x.shape)
|
||||
for n, axis in zip(shape, axes):
|
||||
s[axis] = n
|
||||
|
||||
z = np.zeros(s, x.dtype)
|
||||
z[index] = x[index]
|
||||
return z, True
|
||||
|
||||
|
||||
def _fix_shape_1d(x, n, axis):
|
||||
if n < 1:
|
||||
raise ValueError(
|
||||
"invalid number of data points ({0}) specified".format(n))
|
||||
|
||||
return _fix_shape(x, (n,), (axis,))
|
||||
|
||||
|
||||
def _normalization(norm, forward):
|
||||
"""Returns the pypocketfft normalization mode from the norm argument"""
|
||||
|
||||
if norm is None:
|
||||
return 0 if forward else 2
|
||||
|
||||
if norm == 'ortho':
|
||||
return 1
|
||||
|
||||
raise ValueError(
|
||||
"Invalid norm value {}, should be None or \"ortho\".".format(norm))
|
||||
|
||||
|
||||
def _workers(workers):
|
||||
if workers is None:
|
||||
return getattr(_config, 'default_workers', 1)
|
||||
|
||||
if workers < 0:
|
||||
if workers >= -_cpu_count:
|
||||
workers += 1 + _cpu_count
|
||||
else:
|
||||
raise ValueError("workers value out of range; got {}, must not be"
|
||||
" less than {}".format(workers, -_cpu_count))
|
||||
elif workers == 0:
|
||||
raise ValueError("workers must not be zero")
|
||||
|
||||
return workers
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_workers(workers):
|
||||
"""Context manager for the default number of workers used in `scipy.fft`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
workers : int
|
||||
The default number of workers to use
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from scipy import fft, signal
|
||||
>>> x = np.random.randn(128, 64)
|
||||
>>> with fft.set_workers(4):
|
||||
... y = signal.fftconvolve(x, x)
|
||||
|
||||
"""
|
||||
old_workers = get_workers()
|
||||
_config.default_workers = _workers(operator.index(workers))
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_config.default_workers = old_workers
|
||||
|
||||
|
||||
def get_workers():
|
||||
"""Returns the default number of workers within the current context
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from scipy import fft
|
||||
>>> fft.get_workers()
|
||||
1
|
||||
>>> with fft.set_workers(4):
|
||||
... fft.get_workers()
|
||||
4
|
||||
"""
|
||||
return getattr(_config, 'default_workers', 1)
|
Binary file not shown.
110
venv/Lib/site-packages/scipy/fft/_pocketfft/realtransforms.py
Normal file
110
venv/Lib/site-packages/scipy/fft/_pocketfft/realtransforms.py
Normal file
|
@ -0,0 +1,110 @@
|
|||
import numpy as np
|
||||
from . import pypocketfft as pfft
|
||||
from .helper import (_asfarray, _init_nd_shape_and_axes, _datacopied,
|
||||
_fix_shape, _fix_shape_1d, _normalization, _workers)
|
||||
import functools
|
||||
|
||||
|
||||
def _r2r(forward, transform, x, type=2, n=None, axis=-1, norm=None,
|
||||
overwrite_x=False, workers=None):
|
||||
"""Forward or backward 1-D DCT/DST
|
||||
|
||||
Parameters
|
||||
----------
|
||||
forward: bool
|
||||
Transform direction (determines type and normalisation)
|
||||
transform: {pypocketfft.dct, pypocketfft.dst}
|
||||
The transform to perform
|
||||
"""
|
||||
tmp = _asfarray(x)
|
||||
overwrite_x = overwrite_x or _datacopied(tmp, x)
|
||||
norm = _normalization(norm, forward)
|
||||
workers = _workers(workers)
|
||||
|
||||
if not forward:
|
||||
if type == 2:
|
||||
type = 3
|
||||
elif type == 3:
|
||||
type = 2
|
||||
|
||||
if n is not None:
|
||||
tmp, copied = _fix_shape_1d(tmp, n, axis)
|
||||
overwrite_x = overwrite_x or copied
|
||||
elif tmp.shape[axis] < 1:
|
||||
raise ValueError("invalid number of data points ({0}) specified"
|
||||
.format(tmp.shape[axis]))
|
||||
|
||||
out = (tmp if overwrite_x else None)
|
||||
|
||||
# For complex input, transform real and imaginary components separably
|
||||
if np.iscomplexobj(x):
|
||||
out = np.empty_like(tmp) if out is None else out
|
||||
transform(tmp.real, type, (axis,), norm, out.real, workers)
|
||||
transform(tmp.imag, type, (axis,), norm, out.imag, workers)
|
||||
return out
|
||||
|
||||
return transform(tmp, type, (axis,), norm, out, workers)
|
||||
|
||||
|
||||
dct = functools.partial(_r2r, True, pfft.dct)
|
||||
dct.__name__ = 'dct'
|
||||
idct = functools.partial(_r2r, False, pfft.dct)
|
||||
idct.__name__ = 'idct'
|
||||
|
||||
dst = functools.partial(_r2r, True, pfft.dst)
|
||||
dst.__name__ = 'dst'
|
||||
idst = functools.partial(_r2r, False, pfft.dst)
|
||||
idst.__name__ = 'idst'
|
||||
|
||||
|
||||
def _r2rn(forward, transform, x, type=2, s=None, axes=None, norm=None,
|
||||
overwrite_x=False, workers=None):
|
||||
"""Forward or backward nd DCT/DST
|
||||
|
||||
Parameters
|
||||
----------
|
||||
forward: bool
|
||||
Transform direction (determines type and normalisation)
|
||||
transform: {pypocketfft.dct, pypocketfft.dst}
|
||||
The transform to perform
|
||||
"""
|
||||
tmp = _asfarray(x)
|
||||
|
||||
shape, axes = _init_nd_shape_and_axes(tmp, s, axes)
|
||||
overwrite_x = overwrite_x or _datacopied(tmp, x)
|
||||
|
||||
if len(axes) == 0:
|
||||
return x
|
||||
|
||||
tmp, copied = _fix_shape(tmp, shape, axes)
|
||||
overwrite_x = overwrite_x or copied
|
||||
|
||||
if not forward:
|
||||
if type == 2:
|
||||
type = 3
|
||||
elif type == 3:
|
||||
type = 2
|
||||
|
||||
norm = _normalization(norm, forward)
|
||||
workers = _workers(workers)
|
||||
out = (tmp if overwrite_x else None)
|
||||
|
||||
# For complex input, transform real and imaginary components separably
|
||||
if np.iscomplexobj(x):
|
||||
out = np.empty_like(tmp) if out is None else out
|
||||
transform(tmp.real, type, axes, norm, out.real, workers)
|
||||
transform(tmp.imag, type, axes, norm, out.imag, workers)
|
||||
return out
|
||||
|
||||
return transform(tmp, type, axes, norm, out, workers)
|
||||
|
||||
|
||||
dctn = functools.partial(_r2rn, True, pfft.dct)
|
||||
dctn.__name__ = 'dctn'
|
||||
idctn = functools.partial(_r2rn, False, pfft.dct)
|
||||
idctn.__name__ = 'idctn'
|
||||
|
||||
dstn = functools.partial(_r2rn, True, pfft.dst)
|
||||
dstn.__name__ = 'dstn'
|
||||
idstn = functools.partial(_r2rn, False, pfft.dst)
|
||||
idstn.__name__ = 'idstn'
|
49
venv/Lib/site-packages/scipy/fft/_pocketfft/setup.py
Normal file
49
venv/Lib/site-packages/scipy/fft/_pocketfft/setup.py
Normal file
|
@ -0,0 +1,49 @@
|
|||
|
||||
def pre_build_hook(build_ext, ext):
|
||||
from scipy._build_utils.compiler_helper import (
|
||||
set_cxx_flags_hook, try_add_flag, try_compile, has_flag)
|
||||
cc = build_ext._cxx_compiler
|
||||
args = ext.extra_compile_args
|
||||
|
||||
set_cxx_flags_hook(build_ext, ext)
|
||||
|
||||
if cc.compiler_type == 'msvc':
|
||||
args.append('/EHsc')
|
||||
else:
|
||||
# Use pthreads if available
|
||||
has_pthreads = try_compile(cc, code='#include <pthread.h>\n'
|
||||
'int main(int argc, char **argv) {}')
|
||||
if has_pthreads:
|
||||
ext.define_macros.append(('POCKETFFT_PTHREADS', None))
|
||||
if has_flag(cc, '-pthread'):
|
||||
args.append('-pthread')
|
||||
ext.extra_link_args.append('-pthread')
|
||||
else:
|
||||
raise RuntimeError("Build failed: System has pthreads header "
|
||||
"but could not compile with -pthread option")
|
||||
|
||||
# Don't export library symbols
|
||||
try_add_flag(args, cc, '-fvisibility=hidden')
|
||||
|
||||
|
||||
def configuration(parent_package='', top_path=None):
|
||||
from numpy.distutils.misc_util import Configuration
|
||||
import pybind11
|
||||
include_dirs = [pybind11.get_include(True), pybind11.get_include(False)]
|
||||
|
||||
config = Configuration('_pocketfft', parent_package, top_path)
|
||||
ext = config.add_extension('pypocketfft',
|
||||
sources=['pypocketfft.cxx'],
|
||||
depends=['pocketfft_hdronly.h'],
|
||||
include_dirs=include_dirs,
|
||||
language='c++')
|
||||
ext._pre_build_hook = pre_build_hook
|
||||
|
||||
config.add_data_files('LICENSE.md')
|
||||
config.add_data_dir('tests')
|
||||
return config
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from numpy.distutils.core import setup
|
||||
setup(**configuration(top_path='').todict())
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
1021
venv/Lib/site-packages/scipy/fft/_pocketfft/tests/test_basic.py
Normal file
1021
venv/Lib/site-packages/scipy/fft/_pocketfft/tests/test_basic.py
Normal file
File diff suppressed because it is too large
Load diff
|
@ -0,0 +1,487 @@
|
|||
from os.path import join, dirname
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import (
|
||||
assert_array_almost_equal, assert_equal, assert_allclose)
|
||||
import pytest
|
||||
from pytest import raises as assert_raises
|
||||
|
||||
from scipy.fft._pocketfft.realtransforms import (
|
||||
dct, idct, dst, idst, dctn, idctn, dstn, idstn)
|
||||
|
||||
fftpack_test_dir = join(dirname(__file__), '..', '..', '..', 'fftpack', 'tests')
|
||||
|
||||
MDATA_COUNT = 8
|
||||
FFTWDATA_COUNT = 14
|
||||
|
||||
def is_longdouble_binary_compatible():
|
||||
try:
|
||||
one = np.frombuffer(
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x80\xff\x3f\x00\x00\x00\x00\x00\x00',
|
||||
dtype='<f16')
|
||||
return one == np.longfloat(1.)
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
|
||||
def get_reference_data():
|
||||
ref = getattr(globals(), '__reference_data', None)
|
||||
if ref is not None:
|
||||
return ref
|
||||
|
||||
# Matlab reference data
|
||||
MDATA = np.load(join(fftpack_test_dir, 'test.npz'))
|
||||
X = [MDATA['x%d' % i] for i in range(MDATA_COUNT)]
|
||||
Y = [MDATA['y%d' % i] for i in range(MDATA_COUNT)]
|
||||
|
||||
# FFTW reference data: the data are organized as follows:
|
||||
# * SIZES is an array containing all available sizes
|
||||
# * for every type (1, 2, 3, 4) and every size, the array dct_type_size
|
||||
# contains the output of the DCT applied to the input np.linspace(0, size-1,
|
||||
# size)
|
||||
FFTWDATA_DOUBLE = np.load(join(fftpack_test_dir, 'fftw_double_ref.npz'))
|
||||
FFTWDATA_SINGLE = np.load(join(fftpack_test_dir, 'fftw_single_ref.npz'))
|
||||
FFTWDATA_SIZES = FFTWDATA_DOUBLE['sizes']
|
||||
assert len(FFTWDATA_SIZES) == FFTWDATA_COUNT
|
||||
|
||||
if is_longdouble_binary_compatible():
|
||||
FFTWDATA_LONGDOUBLE = np.load(
|
||||
join(fftpack_test_dir, 'fftw_longdouble_ref.npz'))
|
||||
else:
|
||||
FFTWDATA_LONGDOUBLE = {k: v.astype(np.longfloat)
|
||||
for k,v in FFTWDATA_DOUBLE.items()}
|
||||
|
||||
ref = {
|
||||
'FFTWDATA_LONGDOUBLE': FFTWDATA_LONGDOUBLE,
|
||||
'FFTWDATA_DOUBLE': FFTWDATA_DOUBLE,
|
||||
'FFTWDATA_SINGLE': FFTWDATA_SINGLE,
|
||||
'FFTWDATA_SIZES': FFTWDATA_SIZES,
|
||||
'X': X,
|
||||
'Y': Y
|
||||
}
|
||||
|
||||
globals()['__reference_data'] = ref
|
||||
return ref
|
||||
|
||||
|
||||
@pytest.fixture(params=range(FFTWDATA_COUNT))
|
||||
def fftwdata_size(request):
|
||||
return get_reference_data()['FFTWDATA_SIZES'][request.param]
|
||||
|
||||
@pytest.fixture(params=range(MDATA_COUNT))
|
||||
def mdata_x(request):
|
||||
return get_reference_data()['X'][request.param]
|
||||
|
||||
|
||||
@pytest.fixture(params=range(MDATA_COUNT))
|
||||
def mdata_xy(request):
|
||||
ref = get_reference_data()
|
||||
y = ref['Y'][request.param]
|
||||
x = ref['X'][request.param]
|
||||
return x, y
|
||||
|
||||
|
||||
def fftw_dct_ref(type, size, dt):
|
||||
x = np.linspace(0, size-1, size).astype(dt)
|
||||
dt = np.result_type(np.float32, dt)
|
||||
if dt == np.double:
|
||||
data = get_reference_data()['FFTWDATA_DOUBLE']
|
||||
elif dt == np.float32:
|
||||
data = get_reference_data()['FFTWDATA_SINGLE']
|
||||
elif dt == np.longfloat:
|
||||
data = get_reference_data()['FFTWDATA_LONGDOUBLE']
|
||||
else:
|
||||
raise ValueError()
|
||||
y = (data['dct_%d_%d' % (type, size)]).astype(dt)
|
||||
return x, y, dt
|
||||
|
||||
|
||||
def fftw_dst_ref(type, size, dt):
|
||||
x = np.linspace(0, size-1, size).astype(dt)
|
||||
dt = np.result_type(np.float32, dt)
|
||||
if dt == np.double:
|
||||
data = get_reference_data()['FFTWDATA_DOUBLE']
|
||||
elif dt == np.float32:
|
||||
data = get_reference_data()['FFTWDATA_SINGLE']
|
||||
elif dt == np.longfloat:
|
||||
data = get_reference_data()['FFTWDATA_LONGDOUBLE']
|
||||
else:
|
||||
raise ValueError()
|
||||
y = (data['dst_%d_%d' % (type, size)]).astype(dt)
|
||||
return x, y, dt
|
||||
|
||||
|
||||
def ref_2d(func, x, **kwargs):
|
||||
"""Calculate 2-D reference data from a 1d transform"""
|
||||
x = np.array(x, copy=True)
|
||||
for row in range(x.shape[0]):
|
||||
x[row, :] = func(x[row, :], **kwargs)
|
||||
for col in range(x.shape[1]):
|
||||
x[:, col] = func(x[:, col], **kwargs)
|
||||
return x
|
||||
|
||||
|
||||
def naive_dct1(x, norm=None):
|
||||
"""Calculate textbook definition version of DCT-I."""
|
||||
x = np.array(x, copy=True)
|
||||
N = len(x)
|
||||
M = N-1
|
||||
y = np.zeros(N)
|
||||
m0, m = 1, 2
|
||||
if norm == 'ortho':
|
||||
m0 = np.sqrt(1.0/M)
|
||||
m = np.sqrt(2.0/M)
|
||||
for k in range(N):
|
||||
for n in range(1, N-1):
|
||||
y[k] += m*x[n]*np.cos(np.pi*n*k/M)
|
||||
y[k] += m0 * x[0]
|
||||
y[k] += m0 * x[N-1] * (1 if k % 2 == 0 else -1)
|
||||
if norm == 'ortho':
|
||||
y[0] *= 1/np.sqrt(2)
|
||||
y[N-1] *= 1/np.sqrt(2)
|
||||
return y
|
||||
|
||||
|
||||
def naive_dst1(x, norm=None):
|
||||
"""Calculate textbook definition version of DST-I."""
|
||||
x = np.array(x, copy=True)
|
||||
N = len(x)
|
||||
M = N+1
|
||||
y = np.zeros(N)
|
||||
for k in range(N):
|
||||
for n in range(N):
|
||||
y[k] += 2*x[n]*np.sin(np.pi*(n+1.0)*(k+1.0)/M)
|
||||
if norm == 'ortho':
|
||||
y *= np.sqrt(0.5/M)
|
||||
return y
|
||||
|
||||
|
||||
def naive_dct4(x, norm=None):
|
||||
"""Calculate textbook definition version of DCT-IV."""
|
||||
x = np.array(x, copy=True)
|
||||
N = len(x)
|
||||
y = np.zeros(N)
|
||||
for k in range(N):
|
||||
for n in range(N):
|
||||
y[k] += x[n]*np.cos(np.pi*(n+0.5)*(k+0.5)/(N))
|
||||
if norm == 'ortho':
|
||||
y *= np.sqrt(2.0/N)
|
||||
else:
|
||||
y *= 2
|
||||
return y
|
||||
|
||||
|
||||
def naive_dst4(x, norm=None):
|
||||
"""Calculate textbook definition version of DST-IV."""
|
||||
x = np.array(x, copy=True)
|
||||
N = len(x)
|
||||
y = np.zeros(N)
|
||||
for k in range(N):
|
||||
for n in range(N):
|
||||
y[k] += x[n]*np.sin(np.pi*(n+0.5)*(k+0.5)/(N))
|
||||
if norm == 'ortho':
|
||||
y *= np.sqrt(2.0/N)
|
||||
else:
|
||||
y *= 2
|
||||
return y
|
||||
|
||||
|
||||
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128, np.longcomplex])
|
||||
@pytest.mark.parametrize('transform', [dct, dst, idct, idst])
|
||||
def test_complex(transform, dtype):
|
||||
y = transform(1j*np.arange(5, dtype=dtype))
|
||||
x = 1j*transform(np.arange(5))
|
||||
assert_array_almost_equal(x, y)
|
||||
|
||||
|
||||
# map (tranform, dtype, type) -> decimal
|
||||
dec_map = {
|
||||
# DCT
|
||||
(dct, np.double, 1): 13,
|
||||
(dct, np.float32, 1): 6,
|
||||
|
||||
(dct, np.double, 2): 14,
|
||||
(dct, np.float32, 2): 5,
|
||||
|
||||
(dct, np.double, 3): 14,
|
||||
(dct, np.float32, 3): 5,
|
||||
|
||||
(dct, np.double, 4): 13,
|
||||
(dct, np.float32, 4): 6,
|
||||
|
||||
# IDCT
|
||||
(idct, np.double, 1): 14,
|
||||
(idct, np.float32, 1): 6,
|
||||
|
||||
(idct, np.double, 2): 14,
|
||||
(idct, np.float32, 2): 5,
|
||||
|
||||
(idct, np.double, 3): 14,
|
||||
(idct, np.float32, 3): 5,
|
||||
|
||||
(idct, np.double, 4): 14,
|
||||
(idct, np.float32, 4): 6,
|
||||
|
||||
# DST
|
||||
(dst, np.double, 1): 13,
|
||||
(dst, np.float32, 1): 6,
|
||||
|
||||
(dst, np.double, 2): 14,
|
||||
(dst, np.float32, 2): 6,
|
||||
|
||||
(dst, np.double, 3): 14,
|
||||
(dst, np.float32, 3): 7,
|
||||
|
||||
(dst, np.double, 4): 13,
|
||||
(dst, np.float32, 4): 6,
|
||||
|
||||
# IDST
|
||||
(idst, np.double, 1): 14,
|
||||
(idst, np.float32, 1): 6,
|
||||
|
||||
(idst, np.double, 2): 14,
|
||||
(idst, np.float32, 2): 6,
|
||||
|
||||
(idst, np.double, 3): 14,
|
||||
(idst, np.float32, 3): 6,
|
||||
|
||||
(idst, np.double, 4): 14,
|
||||
(idst, np.float32, 4): 6,
|
||||
}
|
||||
|
||||
for k,v in dec_map.copy().items():
|
||||
if k[1] == np.double:
|
||||
dec_map[(k[0], np.longdouble, k[2])] = v
|
||||
elif k[1] == np.float32:
|
||||
dec_map[(k[0], int, k[2])] = v
|
||||
|
||||
|
||||
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
|
||||
@pytest.mark.parametrize('type', [1, 2, 3, 4])
|
||||
class TestDCT:
|
||||
def test_definition(self, rdt, type, fftwdata_size):
|
||||
x, yr, dt = fftw_dct_ref(type, fftwdata_size, rdt)
|
||||
y = dct(x, type=type)
|
||||
assert_equal(y.dtype, dt)
|
||||
dec = dec_map[(dct, rdt, type)]
|
||||
assert_allclose(y, yr, rtol=0., atol=np.max(yr)*10**(-dec))
|
||||
|
||||
@pytest.mark.parametrize('size', [7, 8, 9, 16, 32, 64])
|
||||
def test_axis(self, rdt, type, size):
|
||||
nt = 2
|
||||
dec = dec_map[(dct, rdt, type)]
|
||||
x = np.random.randn(nt, size)
|
||||
y = dct(x, type=type)
|
||||
for j in range(nt):
|
||||
assert_array_almost_equal(y[j], dct(x[j], type=type),
|
||||
decimal=dec)
|
||||
|
||||
x = x.T
|
||||
y = dct(x, axis=0, type=type)
|
||||
for j in range(nt):
|
||||
assert_array_almost_equal(y[:,j], dct(x[:,j], type=type),
|
||||
decimal=dec)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
|
||||
def test_dct1_definition_ortho(rdt, mdata_x):
|
||||
# Test orthornomal mode.
|
||||
dec = dec_map[(dct, rdt, 1)]
|
||||
x = np.array(mdata_x, dtype=rdt)
|
||||
dt = np.result_type(np.float32, rdt)
|
||||
y = dct(x, norm='ortho', type=1)
|
||||
y2 = naive_dct1(x, norm='ortho')
|
||||
assert_equal(y.dtype, dt)
|
||||
assert_allclose(y, y2, rtol=0., atol=np.max(y2)*10**(-dec))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
|
||||
def test_dct2_definition_matlab(mdata_xy, rdt):
|
||||
# Test correspondence with matlab (orthornomal mode).
|
||||
dt = np.result_type(np.float32, rdt)
|
||||
x = np.array(mdata_xy[0], dtype=dt)
|
||||
|
||||
yr = mdata_xy[1]
|
||||
y = dct(x, norm="ortho", type=2)
|
||||
dec = dec_map[(dct, rdt, 2)]
|
||||
assert_equal(y.dtype, dt)
|
||||
assert_array_almost_equal(y, yr, decimal=dec)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
|
||||
def test_dct3_definition_ortho(mdata_x, rdt):
|
||||
# Test orthornomal mode.
|
||||
x = np.array(mdata_x, dtype=rdt)
|
||||
dt = np.result_type(np.float32, rdt)
|
||||
y = dct(x, norm='ortho', type=2)
|
||||
xi = dct(y, norm="ortho", type=3)
|
||||
dec = dec_map[(dct, rdt, 3)]
|
||||
assert_equal(xi.dtype, dt)
|
||||
assert_array_almost_equal(xi, x, decimal=dec)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
|
||||
def test_dct4_definition_ortho(mdata_x, rdt):
|
||||
# Test orthornomal mode.
|
||||
x = np.array(mdata_x, dtype=rdt)
|
||||
dt = np.result_type(np.float32, rdt)
|
||||
y = dct(x, norm='ortho', type=4)
|
||||
y2 = naive_dct4(x, norm='ortho')
|
||||
dec = dec_map[(dct, rdt, 4)]
|
||||
assert_equal(y.dtype, dt)
|
||||
assert_allclose(y, y2, rtol=0., atol=np.max(y2)*10**(-dec))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
|
||||
@pytest.mark.parametrize('type', [1, 2, 3, 4])
|
||||
def test_idct_definition(fftwdata_size, rdt, type):
|
||||
xr, yr, dt = fftw_dct_ref(type, fftwdata_size, rdt)
|
||||
x = idct(yr, type=type)
|
||||
dec = dec_map[(idct, rdt, type)]
|
||||
assert_equal(x.dtype, dt)
|
||||
assert_allclose(x, xr, rtol=0., atol=np.max(xr)*10**(-dec))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
|
||||
@pytest.mark.parametrize('type', [1, 2, 3, 4])
|
||||
def test_definition(fftwdata_size, rdt, type):
|
||||
xr, yr, dt = fftw_dst_ref(type, fftwdata_size, rdt)
|
||||
y = dst(xr, type=type)
|
||||
dec = dec_map[(dst, rdt, type)]
|
||||
assert_equal(y.dtype, dt)
|
||||
assert_allclose(y, yr, rtol=0., atol=np.max(yr)*10**(-dec))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
|
||||
def test_dst1_definition_ortho(rdt, mdata_x):
|
||||
# Test orthornomal mode.
|
||||
dec = dec_map[(dst, rdt, 1)]
|
||||
x = np.array(mdata_x, dtype=rdt)
|
||||
dt = np.result_type(np.float32, rdt)
|
||||
y = dst(x, norm='ortho', type=1)
|
||||
y2 = naive_dst1(x, norm='ortho')
|
||||
assert_equal(y.dtype, dt)
|
||||
assert_allclose(y, y2, rtol=0., atol=np.max(y2)*10**(-dec))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
|
||||
def test_dst4_definition_ortho(rdt, mdata_x):
|
||||
# Test orthornomal mode.
|
||||
dec = dec_map[(dst, rdt, 4)]
|
||||
x = np.array(mdata_x, dtype=rdt)
|
||||
dt = np.result_type(np.float32, rdt)
|
||||
y = dst(x, norm='ortho', type=4)
|
||||
y2 = naive_dst4(x, norm='ortho')
|
||||
assert_equal(y.dtype, dt)
|
||||
assert_array_almost_equal(y, y2, decimal=dec)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
|
||||
@pytest.mark.parametrize('type', [1, 2, 3, 4])
|
||||
def test_idst_definition(fftwdata_size, rdt, type):
|
||||
xr, yr, dt = fftw_dst_ref(type, fftwdata_size, rdt)
|
||||
x = idst(yr, type=type)
|
||||
dec = dec_map[(idst, rdt, type)]
|
||||
assert_equal(x.dtype, dt)
|
||||
assert_allclose(x, xr, rtol=0., atol=np.max(xr)*10**(-dec))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('routine', [dct, dst, idct, idst])
|
||||
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.longfloat])
|
||||
@pytest.mark.parametrize('shape, axis', [
|
||||
((16,), -1), ((16, 2), 0), ((2, 16), 1)
|
||||
])
|
||||
@pytest.mark.parametrize('type', [1, 2, 3, 4])
|
||||
@pytest.mark.parametrize('overwrite_x', [True, False])
|
||||
@pytest.mark.parametrize('norm', [None, 'ortho'])
|
||||
def test_overwrite(routine, dtype, shape, axis, type, norm, overwrite_x):
|
||||
# Check input overwrite behavior
|
||||
np.random.seed(1234)
|
||||
if np.issubdtype(dtype, np.complexfloating):
|
||||
x = np.random.randn(*shape) + 1j*np.random.randn(*shape)
|
||||
else:
|
||||
x = np.random.randn(*shape)
|
||||
x = x.astype(dtype)
|
||||
x2 = x.copy()
|
||||
routine(x2, type, None, axis, norm, overwrite_x=overwrite_x)
|
||||
|
||||
sig = "%s(%s%r, %r, axis=%r, overwrite_x=%r)" % (
|
||||
routine.__name__, x.dtype, x.shape, None, axis, overwrite_x)
|
||||
if not overwrite_x:
|
||||
assert_equal(x2, x, err_msg="spurious overwrite in %s" % sig)
|
||||
|
||||
|
||||
class Test_DCTN_IDCTN(object):
|
||||
dec = 14
|
||||
dct_type = [1, 2, 3, 4]
|
||||
norms = [None, 'ortho']
|
||||
rstate = np.random.RandomState(1234)
|
||||
shape = (32, 16)
|
||||
data = rstate.randn(*shape)
|
||||
|
||||
@pytest.mark.parametrize('fforward,finverse', [(dctn, idctn),
|
||||
(dstn, idstn)])
|
||||
@pytest.mark.parametrize('axes', [None,
|
||||
1, (1,), [1],
|
||||
0, (0,), [0],
|
||||
(0, 1), [0, 1],
|
||||
(-2, -1), [-2, -1]])
|
||||
@pytest.mark.parametrize('dct_type', dct_type)
|
||||
@pytest.mark.parametrize('norm', ['ortho'])
|
||||
def test_axes_round_trip(self, fforward, finverse, axes, dct_type, norm):
|
||||
tmp = fforward(self.data, type=dct_type, axes=axes, norm=norm)
|
||||
tmp = finverse(tmp, type=dct_type, axes=axes, norm=norm)
|
||||
assert_array_almost_equal(self.data, tmp, decimal=12)
|
||||
|
||||
@pytest.mark.parametrize('funcn,func', [(dctn, dct), (dstn, dst)])
|
||||
@pytest.mark.parametrize('dct_type', dct_type)
|
||||
@pytest.mark.parametrize('norm', norms)
|
||||
def test_dctn_vs_2d_reference(self, funcn, func, dct_type, norm):
|
||||
y1 = funcn(self.data, type=dct_type, axes=None, norm=norm)
|
||||
y2 = ref_2d(func, self.data, type=dct_type, norm=norm)
|
||||
assert_array_almost_equal(y1, y2, decimal=11)
|
||||
|
||||
@pytest.mark.parametrize('funcn,func', [(idctn, idct), (idstn, idst)])
|
||||
@pytest.mark.parametrize('dct_type', dct_type)
|
||||
@pytest.mark.parametrize('norm', [None, 'ortho'])
|
||||
def test_idctn_vs_2d_reference(self, funcn, func, dct_type, norm):
|
||||
fdata = dctn(self.data, type=dct_type, norm=norm)
|
||||
y1 = funcn(fdata, type=dct_type, norm=norm)
|
||||
y2 = ref_2d(func, fdata, type=dct_type, norm=norm)
|
||||
assert_array_almost_equal(y1, y2, decimal=11)
|
||||
|
||||
@pytest.mark.parametrize('fforward,finverse', [(dctn, idctn),
|
||||
(dstn, idstn)])
|
||||
def test_axes_and_shape(self, fforward, finverse):
|
||||
with assert_raises(ValueError,
|
||||
match="when given, axes and shape arguments"
|
||||
" have to be of the same length"):
|
||||
fforward(self.data, s=self.data.shape[0], axes=(0, 1))
|
||||
|
||||
with assert_raises(ValueError,
|
||||
match="when given, axes and shape arguments"
|
||||
" have to be of the same length"):
|
||||
fforward(self.data, s=self.data.shape, axes=0)
|
||||
|
||||
@pytest.mark.parametrize('fforward', [dctn, dstn])
|
||||
def test_shape(self, fforward):
|
||||
tmp = fforward(self.data, s=(128, 128), axes=None)
|
||||
assert_equal(tmp.shape, (128, 128))
|
||||
|
||||
@pytest.mark.parametrize('fforward,finverse', [(dctn, idctn),
|
||||
(dstn, idstn)])
|
||||
@pytest.mark.parametrize('axes', [1, (1,), [1],
|
||||
0, (0,), [0]])
|
||||
def test_shape_is_none_with_axes(self, fforward, finverse, axes):
|
||||
tmp = fforward(self.data, s=None, axes=axes, norm='ortho')
|
||||
tmp = finverse(tmp, s=None, axes=axes, norm='ortho')
|
||||
assert_array_almost_equal(self.data, tmp, decimal=self.dec)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('func', [dct, dctn, idct, idctn,
|
||||
dst, dstn, idst, idstn])
|
||||
def test_swapped_byte_order(func):
|
||||
rng = np.random.RandomState(1234)
|
||||
x = rng.rand(10)
|
||||
swapped_dt = x.dtype.newbyteorder('S')
|
||||
assert_allclose(func(x.astype(swapped_dt)), func(x))
|
Loading…
Add table
Add a link
Reference in a new issue