Fixed database typo and removed unnecessary class identifier.

This commit is contained in:
Batuhan Berk Başoğlu 2020-10-14 10:10:37 -04:00
parent 00ad49a143
commit 45fb349a7d
5098 changed files with 952558 additions and 85 deletions

View file

@ -0,0 +1,25 @@
Copyright (C) 2010-2019 Max-Planck-Society
All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this
list of conditions and the following disclaimer in the documentation and/or
other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its contributors may
be used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View file

@ -0,0 +1,9 @@
""" FFT backend using pypocketfft """
from .basic import *
from .realtransforms import *
from .helper import *
from scipy._lib._testutils import PytestTester
test = PytestTester(__name__)
del PytestTester

View file

@ -0,0 +1,297 @@
"""
Discrete Fourier Transforms - basic.py
"""
import numpy as np
import functools
from . import pypocketfft as pfft
from .helper import (_asfarray, _init_nd_shape_and_axes, _datacopied,
_fix_shape, _fix_shape_1d, _normalization,
_workers)
def c2c(forward, x, n=None, axis=-1, norm=None, overwrite_x=False,
workers=None, *, plan=None):
""" Return discrete Fourier transform of real or complex sequence. """
if plan is not None:
raise NotImplementedError('Passing a precomputed plan is not yet '
'supported by scipy.fft functions')
tmp = _asfarray(x)
overwrite_x = overwrite_x or _datacopied(tmp, x)
norm = _normalization(norm, forward)
workers = _workers(workers)
if n is not None:
tmp, copied = _fix_shape_1d(tmp, n, axis)
overwrite_x = overwrite_x or copied
elif tmp.shape[axis] < 1:
raise ValueError("invalid number of data points ({0}) specified"
.format(tmp.shape[axis]))
out = (tmp if overwrite_x and tmp.dtype.kind == 'c' else None)
return pfft.c2c(tmp, (axis,), forward, norm, out, workers)
fft = functools.partial(c2c, True)
fft.__name__ = 'fft'
ifft = functools.partial(c2c, False)
ifft.__name__ = 'ifft'
def r2c(forward, x, n=None, axis=-1, norm=None, overwrite_x=False,
workers=None, *, plan=None):
"""
Discrete Fourier transform of a real sequence.
"""
if plan is not None:
raise NotImplementedError('Passing a precomputed plan is not yet '
'supported by scipy.fft functions')
tmp = _asfarray(x)
norm = _normalization(norm, forward)
workers = _workers(workers)
if not np.isrealobj(tmp):
raise TypeError("x must be a real sequence")
if n is not None:
tmp, _ = _fix_shape_1d(tmp, n, axis)
elif tmp.shape[axis] < 1:
raise ValueError("invalid number of data points ({0}) specified"
.format(tmp.shape[axis]))
# Note: overwrite_x is not utilised
return pfft.r2c(tmp, (axis,), forward, norm, None, workers)
rfft = functools.partial(r2c, True)
rfft.__name__ = 'rfft'
ihfft = functools.partial(r2c, False)
ihfft.__name__ = 'ihfft'
def c2r(forward, x, n=None, axis=-1, norm=None, overwrite_x=False,
workers=None, *, plan=None):
"""
Return inverse discrete Fourier transform of real sequence x.
"""
if plan is not None:
raise NotImplementedError('Passing a precomputed plan is not yet '
'supported by scipy.fft functions')
tmp = _asfarray(x)
norm = _normalization(norm, forward)
workers = _workers(workers)
# TODO: Optimize for hermitian and real?
if np.isrealobj(tmp):
tmp = tmp + 0.j
# Last axis utilizes hermitian symmetry
if n is None:
n = (tmp.shape[axis] - 1) * 2
if n < 1:
raise ValueError("Invalid number of data points ({0}) specified"
.format(n))
else:
tmp, _ = _fix_shape_1d(tmp, (n//2) + 1, axis)
# Note: overwrite_x is not utilized
return pfft.c2r(tmp, (axis,), n, forward, norm, None, workers)
hfft = functools.partial(c2r, True)
hfft.__name__ = 'hfft'
irfft = functools.partial(c2r, False)
irfft.__name__ = 'irfft'
def fft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None,
*, plan=None):
"""
2-D discrete Fourier transform.
"""
if plan is not None:
raise NotImplementedError('Passing a precomputed plan is not yet '
'supported by scipy.fft functions')
return fftn(x, s, axes, norm, overwrite_x, workers)
def ifft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None,
*, plan=None):
"""
2-D discrete inverse Fourier transform of real or complex sequence.
"""
if plan is not None:
raise NotImplementedError('Passing a precomputed plan is not yet '
'supported by scipy.fft functions')
return ifftn(x, s, axes, norm, overwrite_x, workers)
def rfft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None,
*, plan=None):
"""
2-D discrete Fourier transform of a real sequence
"""
if plan is not None:
raise NotImplementedError('Passing a precomputed plan is not yet '
'supported by scipy.fft functions')
return rfftn(x, s, axes, norm, overwrite_x, workers)
def irfft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None,
*, plan=None):
"""
2-D discrete inverse Fourier transform of a real sequence
"""
if plan is not None:
raise NotImplementedError('Passing a precomputed plan is not yet '
'supported by scipy.fft functions')
return irfftn(x, s, axes, norm, overwrite_x, workers)
def hfft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None,
*, plan=None):
"""
2-D discrete Fourier transform of a Hermitian sequence
"""
if plan is not None:
raise NotImplementedError('Passing a precomputed plan is not yet '
'supported by scipy.fft functions')
return hfftn(x, s, axes, norm, overwrite_x, workers)
def ihfft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None,
*, plan=None):
"""
2-D discrete inverse Fourier transform of a Hermitian sequence
"""
if plan is not None:
raise NotImplementedError('Passing a precomputed plan is not yet '
'supported by scipy.fft functions')
return ihfftn(x, s, axes, norm, overwrite_x, workers)
def c2cn(forward, x, s=None, axes=None, norm=None, overwrite_x=False,
workers=None, *, plan=None):
"""
Return multidimensional discrete Fourier transform.
"""
if plan is not None:
raise NotImplementedError('Passing a precomputed plan is not yet '
'supported by scipy.fft functions')
tmp = _asfarray(x)
shape, axes = _init_nd_shape_and_axes(tmp, s, axes)
overwrite_x = overwrite_x or _datacopied(tmp, x)
workers = _workers(workers)
if len(axes) == 0:
return x
tmp, copied = _fix_shape(tmp, shape, axes)
overwrite_x = overwrite_x or copied
norm = _normalization(norm, forward)
out = (tmp if overwrite_x and tmp.dtype.kind == 'c' else None)
return pfft.c2c(tmp, axes, forward, norm, out, workers)
fftn = functools.partial(c2cn, True)
fftn.__name__ = 'fftn'
ifftn = functools.partial(c2cn, False)
ifftn.__name__ = 'ifftn'
def r2cn(forward, x, s=None, axes=None, norm=None, overwrite_x=False,
workers=None, *, plan=None):
"""Return multidimensional discrete Fourier transform of real input"""
if plan is not None:
raise NotImplementedError('Passing a precomputed plan is not yet '
'supported by scipy.fft functions')
tmp = _asfarray(x)
if not np.isrealobj(tmp):
raise TypeError("x must be a real sequence")
shape, axes = _init_nd_shape_and_axes(tmp, s, axes)
tmp, _ = _fix_shape(tmp, shape, axes)
norm = _normalization(norm, forward)
workers = _workers(workers)
if len(axes) == 0:
raise ValueError("at least 1 axis must be transformed")
# Note: overwrite_x is not utilized
return pfft.r2c(tmp, axes, forward, norm, None, workers)
rfftn = functools.partial(r2cn, True)
rfftn.__name__ = 'rfftn'
ihfftn = functools.partial(r2cn, False)
ihfftn.__name__ = 'ihfftn'
def c2rn(forward, x, s=None, axes=None, norm=None, overwrite_x=False,
workers=None, *, plan=None):
"""Multidimensional inverse discrete fourier transform with real output"""
if plan is not None:
raise NotImplementedError('Passing a precomputed plan is not yet '
'supported by scipy.fft functions')
tmp = _asfarray(x)
# TODO: Optimize for hermitian and real?
if np.isrealobj(tmp):
tmp = tmp + 0.j
noshape = s is None
shape, axes = _init_nd_shape_and_axes(tmp, s, axes)
if len(axes) == 0:
raise ValueError("at least 1 axis must be transformed")
if noshape:
shape[-1] = (x.shape[axes[-1]] - 1) * 2
norm = _normalization(norm, forward)
workers = _workers(workers)
# Last axis utilizes hermitian symmetry
lastsize = shape[-1]
shape[-1] = (shape[-1] // 2) + 1
tmp, _ = _fix_shape(tmp, shape, axes)
# Note: overwrite_x is not utilized
return pfft.c2r(tmp, axes, lastsize, forward, norm, None, workers)
hfftn = functools.partial(c2rn, True)
hfftn.__name__ = 'hfftn'
irfftn = functools.partial(c2rn, False)
irfftn.__name__ = 'irfftn'
def r2r_fftpack(forward, x, n=None, axis=-1, norm=None, overwrite_x=False):
"""FFT of a real sequence, returning fftpack half complex format"""
tmp = _asfarray(x)
overwrite_x = overwrite_x or _datacopied(tmp, x)
norm = _normalization(norm, forward)
workers = _workers(None)
if tmp.dtype.kind == 'c':
raise TypeError('x must be a real sequence')
if n is not None:
tmp, copied = _fix_shape_1d(tmp, n, axis)
overwrite_x = overwrite_x or copied
elif tmp.shape[axis] < 1:
raise ValueError("invalid number of data points ({0}) specified"
.format(tmp.shape[axis]))
out = (tmp if overwrite_x else None)
return pfft.r2r_fftpack(tmp, (axis,), forward, forward, norm, out, workers)
rfft_fftpack = functools.partial(r2r_fftpack, True)
rfft_fftpack.__name__ = 'rfft_fftpack'
irfft_fftpack = functools.partial(r2r_fftpack, False)
irfft_fftpack.__name__ = 'irfft_fftpack'

View file

@ -0,0 +1,213 @@
from numbers import Number
import operator
import os
import threading
import contextlib
import numpy as np
# good_size is exposed (and used) from this import
from .pypocketfft import good_size
_config = threading.local()
_cpu_count = os.cpu_count()
def _iterable_of_int(x, name=None):
"""Convert ``x`` to an iterable sequence of int
Parameters
----------
x : value, or sequence of values, convertible to int
name : str, optional
Name of the argument being converted, only used in the error message
Returns
-------
y : ``List[int]``
"""
if isinstance(x, Number):
x = (x,)
try:
x = [operator.index(a) for a in x]
except TypeError as e:
name = name or "value"
raise ValueError("{} must be a scalar or iterable of integers"
.format(name)) from e
return x
def _init_nd_shape_and_axes(x, shape, axes):
"""Handles shape and axes arguments for nd transforms"""
noshape = shape is None
noaxes = axes is None
if not noaxes:
axes = _iterable_of_int(axes, 'axes')
axes = [a + x.ndim if a < 0 else a for a in axes]
if any(a >= x.ndim or a < 0 for a in axes):
raise ValueError("axes exceeds dimensionality of input")
if len(set(axes)) != len(axes):
raise ValueError("all axes must be unique")
if not noshape:
shape = _iterable_of_int(shape, 'shape')
if axes and len(axes) != len(shape):
raise ValueError("when given, axes and shape arguments"
" have to be of the same length")
if noaxes:
if len(shape) > x.ndim:
raise ValueError("shape requires more axes than are present")
axes = range(x.ndim - len(shape), x.ndim)
shape = [x.shape[a] if s == -1 else s for s, a in zip(shape, axes)]
elif noaxes:
shape = list(x.shape)
axes = range(x.ndim)
else:
shape = [x.shape[a] for a in axes]
if any(s < 1 for s in shape):
raise ValueError(
"invalid number of data points ({0}) specified".format(shape))
return shape, axes
def _asfarray(x):
"""
Convert to array with floating or complex dtype.
float16 values are also promoted to float32.
"""
if not hasattr(x, "dtype"):
x = np.asarray(x)
if x.dtype == np.float16:
return np.asarray(x, np.float32)
elif x.dtype.kind not in 'fc':
return np.asarray(x, np.float64)
# Require native byte order
dtype = x.dtype.newbyteorder('=')
# Always align input
copy = not x.flags['ALIGNED']
return np.array(x, dtype=dtype, copy=copy)
def _datacopied(arr, original):
"""
Strict check for `arr` not sharing any data with `original`,
under the assumption that arr = asarray(original)
"""
if arr is original:
return False
if not isinstance(original, np.ndarray) and hasattr(original, '__array__'):
return False
return arr.base is None
def _fix_shape(x, shape, axes):
"""Internal auxiliary function for _raw_fft, _raw_fftnd."""
must_copy = False
# Build an nd slice with the dimensions to be read from x
index = [slice(None)]*x.ndim
for n, ax in zip(shape, axes):
if x.shape[ax] >= n:
index[ax] = slice(0, n)
else:
index[ax] = slice(0, x.shape[ax])
must_copy = True
index = tuple(index)
if not must_copy:
return x[index], False
s = list(x.shape)
for n, axis in zip(shape, axes):
s[axis] = n
z = np.zeros(s, x.dtype)
z[index] = x[index]
return z, True
def _fix_shape_1d(x, n, axis):
if n < 1:
raise ValueError(
"invalid number of data points ({0}) specified".format(n))
return _fix_shape(x, (n,), (axis,))
def _normalization(norm, forward):
"""Returns the pypocketfft normalization mode from the norm argument"""
if norm is None:
return 0 if forward else 2
if norm == 'ortho':
return 1
raise ValueError(
"Invalid norm value {}, should be None or \"ortho\".".format(norm))
def _workers(workers):
if workers is None:
return getattr(_config, 'default_workers', 1)
if workers < 0:
if workers >= -_cpu_count:
workers += 1 + _cpu_count
else:
raise ValueError("workers value out of range; got {}, must not be"
" less than {}".format(workers, -_cpu_count))
elif workers == 0:
raise ValueError("workers must not be zero")
return workers
@contextlib.contextmanager
def set_workers(workers):
"""Context manager for the default number of workers used in `scipy.fft`
Parameters
----------
workers : int
The default number of workers to use
Examples
--------
>>> from scipy import fft, signal
>>> x = np.random.randn(128, 64)
>>> with fft.set_workers(4):
... y = signal.fftconvolve(x, x)
"""
old_workers = get_workers()
_config.default_workers = _workers(operator.index(workers))
try:
yield
finally:
_config.default_workers = old_workers
def get_workers():
"""Returns the default number of workers within the current context
Examples
--------
>>> from scipy import fft
>>> fft.get_workers()
1
>>> with fft.set_workers(4):
... fft.get_workers()
4
"""
return getattr(_config, 'default_workers', 1)

View file

@ -0,0 +1,110 @@
import numpy as np
from . import pypocketfft as pfft
from .helper import (_asfarray, _init_nd_shape_and_axes, _datacopied,
_fix_shape, _fix_shape_1d, _normalization, _workers)
import functools
def _r2r(forward, transform, x, type=2, n=None, axis=-1, norm=None,
overwrite_x=False, workers=None):
"""Forward or backward 1-D DCT/DST
Parameters
----------
forward: bool
Transform direction (determines type and normalisation)
transform: {pypocketfft.dct, pypocketfft.dst}
The transform to perform
"""
tmp = _asfarray(x)
overwrite_x = overwrite_x or _datacopied(tmp, x)
norm = _normalization(norm, forward)
workers = _workers(workers)
if not forward:
if type == 2:
type = 3
elif type == 3:
type = 2
if n is not None:
tmp, copied = _fix_shape_1d(tmp, n, axis)
overwrite_x = overwrite_x or copied
elif tmp.shape[axis] < 1:
raise ValueError("invalid number of data points ({0}) specified"
.format(tmp.shape[axis]))
out = (tmp if overwrite_x else None)
# For complex input, transform real and imaginary components separably
if np.iscomplexobj(x):
out = np.empty_like(tmp) if out is None else out
transform(tmp.real, type, (axis,), norm, out.real, workers)
transform(tmp.imag, type, (axis,), norm, out.imag, workers)
return out
return transform(tmp, type, (axis,), norm, out, workers)
dct = functools.partial(_r2r, True, pfft.dct)
dct.__name__ = 'dct'
idct = functools.partial(_r2r, False, pfft.dct)
idct.__name__ = 'idct'
dst = functools.partial(_r2r, True, pfft.dst)
dst.__name__ = 'dst'
idst = functools.partial(_r2r, False, pfft.dst)
idst.__name__ = 'idst'
def _r2rn(forward, transform, x, type=2, s=None, axes=None, norm=None,
overwrite_x=False, workers=None):
"""Forward or backward nd DCT/DST
Parameters
----------
forward: bool
Transform direction (determines type and normalisation)
transform: {pypocketfft.dct, pypocketfft.dst}
The transform to perform
"""
tmp = _asfarray(x)
shape, axes = _init_nd_shape_and_axes(tmp, s, axes)
overwrite_x = overwrite_x or _datacopied(tmp, x)
if len(axes) == 0:
return x
tmp, copied = _fix_shape(tmp, shape, axes)
overwrite_x = overwrite_x or copied
if not forward:
if type == 2:
type = 3
elif type == 3:
type = 2
norm = _normalization(norm, forward)
workers = _workers(workers)
out = (tmp if overwrite_x else None)
# For complex input, transform real and imaginary components separably
if np.iscomplexobj(x):
out = np.empty_like(tmp) if out is None else out
transform(tmp.real, type, axes, norm, out.real, workers)
transform(tmp.imag, type, axes, norm, out.imag, workers)
return out
return transform(tmp, type, axes, norm, out, workers)
dctn = functools.partial(_r2rn, True, pfft.dct)
dctn.__name__ = 'dctn'
idctn = functools.partial(_r2rn, False, pfft.dct)
idctn.__name__ = 'idctn'
dstn = functools.partial(_r2rn, True, pfft.dst)
dstn.__name__ = 'dstn'
idstn = functools.partial(_r2rn, False, pfft.dst)
idstn.__name__ = 'idstn'

View file

@ -0,0 +1,49 @@
def pre_build_hook(build_ext, ext):
from scipy._build_utils.compiler_helper import (
set_cxx_flags_hook, try_add_flag, try_compile, has_flag)
cc = build_ext._cxx_compiler
args = ext.extra_compile_args
set_cxx_flags_hook(build_ext, ext)
if cc.compiler_type == 'msvc':
args.append('/EHsc')
else:
# Use pthreads if available
has_pthreads = try_compile(cc, code='#include <pthread.h>\n'
'int main(int argc, char **argv) {}')
if has_pthreads:
ext.define_macros.append(('POCKETFFT_PTHREADS', None))
if has_flag(cc, '-pthread'):
args.append('-pthread')
ext.extra_link_args.append('-pthread')
else:
raise RuntimeError("Build failed: System has pthreads header "
"but could not compile with -pthread option")
# Don't export library symbols
try_add_flag(args, cc, '-fvisibility=hidden')
def configuration(parent_package='', top_path=None):
from numpy.distutils.misc_util import Configuration
import pybind11
include_dirs = [pybind11.get_include(True), pybind11.get_include(False)]
config = Configuration('_pocketfft', parent_package, top_path)
ext = config.add_extension('pypocketfft',
sources=['pypocketfft.cxx'],
depends=['pocketfft_hdronly.h'],
include_dirs=include_dirs,
language='c++')
ext._pre_build_hook = pre_build_hook
config.add_data_files('LICENSE.md')
config.add_data_dir('tests')
return config
if __name__ == '__main__':
from numpy.distutils.core import setup
setup(**configuration(top_path='').todict())

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,487 @@
from os.path import join, dirname
import numpy as np
from numpy.testing import (
assert_array_almost_equal, assert_equal, assert_allclose)
import pytest
from pytest import raises as assert_raises
from scipy.fft._pocketfft.realtransforms import (
dct, idct, dst, idst, dctn, idctn, dstn, idstn)
fftpack_test_dir = join(dirname(__file__), '..', '..', '..', 'fftpack', 'tests')
MDATA_COUNT = 8
FFTWDATA_COUNT = 14
def is_longdouble_binary_compatible():
try:
one = np.frombuffer(
b'\x00\x00\x00\x00\x00\x00\x00\x80\xff\x3f\x00\x00\x00\x00\x00\x00',
dtype='<f16')
return one == np.longfloat(1.)
except TypeError:
return False
def get_reference_data():
ref = getattr(globals(), '__reference_data', None)
if ref is not None:
return ref
# Matlab reference data
MDATA = np.load(join(fftpack_test_dir, 'test.npz'))
X = [MDATA['x%d' % i] for i in range(MDATA_COUNT)]
Y = [MDATA['y%d' % i] for i in range(MDATA_COUNT)]
# FFTW reference data: the data are organized as follows:
# * SIZES is an array containing all available sizes
# * for every type (1, 2, 3, 4) and every size, the array dct_type_size
# contains the output of the DCT applied to the input np.linspace(0, size-1,
# size)
FFTWDATA_DOUBLE = np.load(join(fftpack_test_dir, 'fftw_double_ref.npz'))
FFTWDATA_SINGLE = np.load(join(fftpack_test_dir, 'fftw_single_ref.npz'))
FFTWDATA_SIZES = FFTWDATA_DOUBLE['sizes']
assert len(FFTWDATA_SIZES) == FFTWDATA_COUNT
if is_longdouble_binary_compatible():
FFTWDATA_LONGDOUBLE = np.load(
join(fftpack_test_dir, 'fftw_longdouble_ref.npz'))
else:
FFTWDATA_LONGDOUBLE = {k: v.astype(np.longfloat)
for k,v in FFTWDATA_DOUBLE.items()}
ref = {
'FFTWDATA_LONGDOUBLE': FFTWDATA_LONGDOUBLE,
'FFTWDATA_DOUBLE': FFTWDATA_DOUBLE,
'FFTWDATA_SINGLE': FFTWDATA_SINGLE,
'FFTWDATA_SIZES': FFTWDATA_SIZES,
'X': X,
'Y': Y
}
globals()['__reference_data'] = ref
return ref
@pytest.fixture(params=range(FFTWDATA_COUNT))
def fftwdata_size(request):
return get_reference_data()['FFTWDATA_SIZES'][request.param]
@pytest.fixture(params=range(MDATA_COUNT))
def mdata_x(request):
return get_reference_data()['X'][request.param]
@pytest.fixture(params=range(MDATA_COUNT))
def mdata_xy(request):
ref = get_reference_data()
y = ref['Y'][request.param]
x = ref['X'][request.param]
return x, y
def fftw_dct_ref(type, size, dt):
x = np.linspace(0, size-1, size).astype(dt)
dt = np.result_type(np.float32, dt)
if dt == np.double:
data = get_reference_data()['FFTWDATA_DOUBLE']
elif dt == np.float32:
data = get_reference_data()['FFTWDATA_SINGLE']
elif dt == np.longfloat:
data = get_reference_data()['FFTWDATA_LONGDOUBLE']
else:
raise ValueError()
y = (data['dct_%d_%d' % (type, size)]).astype(dt)
return x, y, dt
def fftw_dst_ref(type, size, dt):
x = np.linspace(0, size-1, size).astype(dt)
dt = np.result_type(np.float32, dt)
if dt == np.double:
data = get_reference_data()['FFTWDATA_DOUBLE']
elif dt == np.float32:
data = get_reference_data()['FFTWDATA_SINGLE']
elif dt == np.longfloat:
data = get_reference_data()['FFTWDATA_LONGDOUBLE']
else:
raise ValueError()
y = (data['dst_%d_%d' % (type, size)]).astype(dt)
return x, y, dt
def ref_2d(func, x, **kwargs):
"""Calculate 2-D reference data from a 1d transform"""
x = np.array(x, copy=True)
for row in range(x.shape[0]):
x[row, :] = func(x[row, :], **kwargs)
for col in range(x.shape[1]):
x[:, col] = func(x[:, col], **kwargs)
return x
def naive_dct1(x, norm=None):
"""Calculate textbook definition version of DCT-I."""
x = np.array(x, copy=True)
N = len(x)
M = N-1
y = np.zeros(N)
m0, m = 1, 2
if norm == 'ortho':
m0 = np.sqrt(1.0/M)
m = np.sqrt(2.0/M)
for k in range(N):
for n in range(1, N-1):
y[k] += m*x[n]*np.cos(np.pi*n*k/M)
y[k] += m0 * x[0]
y[k] += m0 * x[N-1] * (1 if k % 2 == 0 else -1)
if norm == 'ortho':
y[0] *= 1/np.sqrt(2)
y[N-1] *= 1/np.sqrt(2)
return y
def naive_dst1(x, norm=None):
"""Calculate textbook definition version of DST-I."""
x = np.array(x, copy=True)
N = len(x)
M = N+1
y = np.zeros(N)
for k in range(N):
for n in range(N):
y[k] += 2*x[n]*np.sin(np.pi*(n+1.0)*(k+1.0)/M)
if norm == 'ortho':
y *= np.sqrt(0.5/M)
return y
def naive_dct4(x, norm=None):
"""Calculate textbook definition version of DCT-IV."""
x = np.array(x, copy=True)
N = len(x)
y = np.zeros(N)
for k in range(N):
for n in range(N):
y[k] += x[n]*np.cos(np.pi*(n+0.5)*(k+0.5)/(N))
if norm == 'ortho':
y *= np.sqrt(2.0/N)
else:
y *= 2
return y
def naive_dst4(x, norm=None):
"""Calculate textbook definition version of DST-IV."""
x = np.array(x, copy=True)
N = len(x)
y = np.zeros(N)
for k in range(N):
for n in range(N):
y[k] += x[n]*np.sin(np.pi*(n+0.5)*(k+0.5)/(N))
if norm == 'ortho':
y *= np.sqrt(2.0/N)
else:
y *= 2
return y
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128, np.longcomplex])
@pytest.mark.parametrize('transform', [dct, dst, idct, idst])
def test_complex(transform, dtype):
y = transform(1j*np.arange(5, dtype=dtype))
x = 1j*transform(np.arange(5))
assert_array_almost_equal(x, y)
# map (tranform, dtype, type) -> decimal
dec_map = {
# DCT
(dct, np.double, 1): 13,
(dct, np.float32, 1): 6,
(dct, np.double, 2): 14,
(dct, np.float32, 2): 5,
(dct, np.double, 3): 14,
(dct, np.float32, 3): 5,
(dct, np.double, 4): 13,
(dct, np.float32, 4): 6,
# IDCT
(idct, np.double, 1): 14,
(idct, np.float32, 1): 6,
(idct, np.double, 2): 14,
(idct, np.float32, 2): 5,
(idct, np.double, 3): 14,
(idct, np.float32, 3): 5,
(idct, np.double, 4): 14,
(idct, np.float32, 4): 6,
# DST
(dst, np.double, 1): 13,
(dst, np.float32, 1): 6,
(dst, np.double, 2): 14,
(dst, np.float32, 2): 6,
(dst, np.double, 3): 14,
(dst, np.float32, 3): 7,
(dst, np.double, 4): 13,
(dst, np.float32, 4): 6,
# IDST
(idst, np.double, 1): 14,
(idst, np.float32, 1): 6,
(idst, np.double, 2): 14,
(idst, np.float32, 2): 6,
(idst, np.double, 3): 14,
(idst, np.float32, 3): 6,
(idst, np.double, 4): 14,
(idst, np.float32, 4): 6,
}
for k,v in dec_map.copy().items():
if k[1] == np.double:
dec_map[(k[0], np.longdouble, k[2])] = v
elif k[1] == np.float32:
dec_map[(k[0], int, k[2])] = v
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
@pytest.mark.parametrize('type', [1, 2, 3, 4])
class TestDCT:
def test_definition(self, rdt, type, fftwdata_size):
x, yr, dt = fftw_dct_ref(type, fftwdata_size, rdt)
y = dct(x, type=type)
assert_equal(y.dtype, dt)
dec = dec_map[(dct, rdt, type)]
assert_allclose(y, yr, rtol=0., atol=np.max(yr)*10**(-dec))
@pytest.mark.parametrize('size', [7, 8, 9, 16, 32, 64])
def test_axis(self, rdt, type, size):
nt = 2
dec = dec_map[(dct, rdt, type)]
x = np.random.randn(nt, size)
y = dct(x, type=type)
for j in range(nt):
assert_array_almost_equal(y[j], dct(x[j], type=type),
decimal=dec)
x = x.T
y = dct(x, axis=0, type=type)
for j in range(nt):
assert_array_almost_equal(y[:,j], dct(x[:,j], type=type),
decimal=dec)
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
def test_dct1_definition_ortho(rdt, mdata_x):
# Test orthornomal mode.
dec = dec_map[(dct, rdt, 1)]
x = np.array(mdata_x, dtype=rdt)
dt = np.result_type(np.float32, rdt)
y = dct(x, norm='ortho', type=1)
y2 = naive_dct1(x, norm='ortho')
assert_equal(y.dtype, dt)
assert_allclose(y, y2, rtol=0., atol=np.max(y2)*10**(-dec))
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
def test_dct2_definition_matlab(mdata_xy, rdt):
# Test correspondence with matlab (orthornomal mode).
dt = np.result_type(np.float32, rdt)
x = np.array(mdata_xy[0], dtype=dt)
yr = mdata_xy[1]
y = dct(x, norm="ortho", type=2)
dec = dec_map[(dct, rdt, 2)]
assert_equal(y.dtype, dt)
assert_array_almost_equal(y, yr, decimal=dec)
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
def test_dct3_definition_ortho(mdata_x, rdt):
# Test orthornomal mode.
x = np.array(mdata_x, dtype=rdt)
dt = np.result_type(np.float32, rdt)
y = dct(x, norm='ortho', type=2)
xi = dct(y, norm="ortho", type=3)
dec = dec_map[(dct, rdt, 3)]
assert_equal(xi.dtype, dt)
assert_array_almost_equal(xi, x, decimal=dec)
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
def test_dct4_definition_ortho(mdata_x, rdt):
# Test orthornomal mode.
x = np.array(mdata_x, dtype=rdt)
dt = np.result_type(np.float32, rdt)
y = dct(x, norm='ortho', type=4)
y2 = naive_dct4(x, norm='ortho')
dec = dec_map[(dct, rdt, 4)]
assert_equal(y.dtype, dt)
assert_allclose(y, y2, rtol=0., atol=np.max(y2)*10**(-dec))
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
@pytest.mark.parametrize('type', [1, 2, 3, 4])
def test_idct_definition(fftwdata_size, rdt, type):
xr, yr, dt = fftw_dct_ref(type, fftwdata_size, rdt)
x = idct(yr, type=type)
dec = dec_map[(idct, rdt, type)]
assert_equal(x.dtype, dt)
assert_allclose(x, xr, rtol=0., atol=np.max(xr)*10**(-dec))
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
@pytest.mark.parametrize('type', [1, 2, 3, 4])
def test_definition(fftwdata_size, rdt, type):
xr, yr, dt = fftw_dst_ref(type, fftwdata_size, rdt)
y = dst(xr, type=type)
dec = dec_map[(dst, rdt, type)]
assert_equal(y.dtype, dt)
assert_allclose(y, yr, rtol=0., atol=np.max(yr)*10**(-dec))
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
def test_dst1_definition_ortho(rdt, mdata_x):
# Test orthornomal mode.
dec = dec_map[(dst, rdt, 1)]
x = np.array(mdata_x, dtype=rdt)
dt = np.result_type(np.float32, rdt)
y = dst(x, norm='ortho', type=1)
y2 = naive_dst1(x, norm='ortho')
assert_equal(y.dtype, dt)
assert_allclose(y, y2, rtol=0., atol=np.max(y2)*10**(-dec))
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
def test_dst4_definition_ortho(rdt, mdata_x):
# Test orthornomal mode.
dec = dec_map[(dst, rdt, 4)]
x = np.array(mdata_x, dtype=rdt)
dt = np.result_type(np.float32, rdt)
y = dst(x, norm='ortho', type=4)
y2 = naive_dst4(x, norm='ortho')
assert_equal(y.dtype, dt)
assert_array_almost_equal(y, y2, decimal=dec)
@pytest.mark.parametrize('rdt', [np.longfloat, np.double, np.float32, int])
@pytest.mark.parametrize('type', [1, 2, 3, 4])
def test_idst_definition(fftwdata_size, rdt, type):
xr, yr, dt = fftw_dst_ref(type, fftwdata_size, rdt)
x = idst(yr, type=type)
dec = dec_map[(idst, rdt, type)]
assert_equal(x.dtype, dt)
assert_allclose(x, xr, rtol=0., atol=np.max(xr)*10**(-dec))
@pytest.mark.parametrize('routine', [dct, dst, idct, idst])
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.longfloat])
@pytest.mark.parametrize('shape, axis', [
((16,), -1), ((16, 2), 0), ((2, 16), 1)
])
@pytest.mark.parametrize('type', [1, 2, 3, 4])
@pytest.mark.parametrize('overwrite_x', [True, False])
@pytest.mark.parametrize('norm', [None, 'ortho'])
def test_overwrite(routine, dtype, shape, axis, type, norm, overwrite_x):
# Check input overwrite behavior
np.random.seed(1234)
if np.issubdtype(dtype, np.complexfloating):
x = np.random.randn(*shape) + 1j*np.random.randn(*shape)
else:
x = np.random.randn(*shape)
x = x.astype(dtype)
x2 = x.copy()
routine(x2, type, None, axis, norm, overwrite_x=overwrite_x)
sig = "%s(%s%r, %r, axis=%r, overwrite_x=%r)" % (
routine.__name__, x.dtype, x.shape, None, axis, overwrite_x)
if not overwrite_x:
assert_equal(x2, x, err_msg="spurious overwrite in %s" % sig)
class Test_DCTN_IDCTN(object):
dec = 14
dct_type = [1, 2, 3, 4]
norms = [None, 'ortho']
rstate = np.random.RandomState(1234)
shape = (32, 16)
data = rstate.randn(*shape)
@pytest.mark.parametrize('fforward,finverse', [(dctn, idctn),
(dstn, idstn)])
@pytest.mark.parametrize('axes', [None,
1, (1,), [1],
0, (0,), [0],
(0, 1), [0, 1],
(-2, -1), [-2, -1]])
@pytest.mark.parametrize('dct_type', dct_type)
@pytest.mark.parametrize('norm', ['ortho'])
def test_axes_round_trip(self, fforward, finverse, axes, dct_type, norm):
tmp = fforward(self.data, type=dct_type, axes=axes, norm=norm)
tmp = finverse(tmp, type=dct_type, axes=axes, norm=norm)
assert_array_almost_equal(self.data, tmp, decimal=12)
@pytest.mark.parametrize('funcn,func', [(dctn, dct), (dstn, dst)])
@pytest.mark.parametrize('dct_type', dct_type)
@pytest.mark.parametrize('norm', norms)
def test_dctn_vs_2d_reference(self, funcn, func, dct_type, norm):
y1 = funcn(self.data, type=dct_type, axes=None, norm=norm)
y2 = ref_2d(func, self.data, type=dct_type, norm=norm)
assert_array_almost_equal(y1, y2, decimal=11)
@pytest.mark.parametrize('funcn,func', [(idctn, idct), (idstn, idst)])
@pytest.mark.parametrize('dct_type', dct_type)
@pytest.mark.parametrize('norm', [None, 'ortho'])
def test_idctn_vs_2d_reference(self, funcn, func, dct_type, norm):
fdata = dctn(self.data, type=dct_type, norm=norm)
y1 = funcn(fdata, type=dct_type, norm=norm)
y2 = ref_2d(func, fdata, type=dct_type, norm=norm)
assert_array_almost_equal(y1, y2, decimal=11)
@pytest.mark.parametrize('fforward,finverse', [(dctn, idctn),
(dstn, idstn)])
def test_axes_and_shape(self, fforward, finverse):
with assert_raises(ValueError,
match="when given, axes and shape arguments"
" have to be of the same length"):
fforward(self.data, s=self.data.shape[0], axes=(0, 1))
with assert_raises(ValueError,
match="when given, axes and shape arguments"
" have to be of the same length"):
fforward(self.data, s=self.data.shape, axes=0)
@pytest.mark.parametrize('fforward', [dctn, dstn])
def test_shape(self, fforward):
tmp = fforward(self.data, s=(128, 128), axes=None)
assert_equal(tmp.shape, (128, 128))
@pytest.mark.parametrize('fforward,finverse', [(dctn, idctn),
(dstn, idstn)])
@pytest.mark.parametrize('axes', [1, (1,), [1],
0, (0,), [0]])
def test_shape_is_none_with_axes(self, fforward, finverse, axes):
tmp = fforward(self.data, s=None, axes=axes, norm='ortho')
tmp = finverse(tmp, s=None, axes=axes, norm='ortho')
assert_array_almost_equal(self.data, tmp, decimal=self.dec)
@pytest.mark.parametrize('func', [dct, dctn, idct, idctn,
dst, dstn, idst, idstn])
def test_swapped_byte_order(func):
rng = np.random.RandomState(1234)
x = rng.rand(10)
swapped_dt = x.dtype.newbyteorder('S')
assert_allclose(func(x.astype(swapped_dt)), func(x))