import numpy as np
from . import pypocketfft as pfft
from .helper import (_asfarray, _init_nd_shape_and_axes, _datacopied,
                     _fix_shape, _fix_shape_1d, _normalization, _workers)
import functools


def _r2r(forward, transform, x, type=2, n=None, axis=-1, norm=None,
         overwrite_x=False, workers=None):
    """Apply a forward or backward 1-D DCT/DST along a single axis.

    Parameters
    ----------
    forward : bool
        Transform direction. For the backward direction, types 2 and 3
        are exchanged before calling `transform`; the direction is also
        handed to ``_normalization`` together with `norm`.
    transform : {pypocketfft.dct, pypocketfft.dst}
        The low-level transform to perform.
    x : array_like
        Input data, converted with ``_asfarray``.
    type : int, optional
        Transform type passed on to `transform` (after the 2 <-> 3
        exchange described above). Default is 2.
    n : int, optional
        If given, the data is passed through ``_fix_shape_1d`` with this
        length along `axis` (presumably padding/truncating -- see that
        helper). If omitted, the existing length along `axis` is used
        and must be at least 1.
    axis : int, optional
        Axis along which the transform is computed. Default is -1.
    norm : optional
        Normalisation mode, resolved by ``_normalization``.
    overwrite_x : bool, optional
        If true, the working array may be reused as the output buffer.
        The working array is also reused whenever the helpers report
        that it is already a copy of `x`.
    workers : optional
        Worker specification, resolved by ``_workers``.

    Returns
    -------
    out : ndarray
        The transformed data. For complex input the real and imaginary
        parts are transformed independently into one output array.

    Raises
    ------
    ValueError
        If `n` is None and the input has fewer than one point along
        `axis`.
    """
    data = _asfarray(x)
    # Once the working array is a copy, clobbering it cannot affect `x`.
    may_overwrite = overwrite_x or _datacopied(data, x)
    norm_arg = _normalization(norm, forward)
    nthreads = _workers(workers)

    # The backward transform of type 2 is computed as type 3 and vice
    # versa; every other type is passed through unchanged.
    if forward:
        kind = type
    elif type == 2:
        kind = 3
    elif type == 3:
        kind = 2
    else:
        kind = type

    if n is None:
        npts = data.shape[axis]
        if npts < 1:
            raise ValueError("invalid number of data points ({0}) specified"
                             .format(npts))
    else:
        data, resized = _fix_shape_1d(data, n, axis)
        may_overwrite = may_overwrite or resized

    axes = (axis,)

    if not np.iscomplexobj(x):
        buffer = data if may_overwrite else None
        return transform(data, kind, axes, norm_arg, buffer, nthreads)

    # Complex input: run the real transform on each component, writing
    # into the matching component view of the result array.
    result = data if may_overwrite else np.empty_like(data)
    transform(data.real, kind, axes, norm_arg, result.real, nthreads)
    transform(data.imag, kind, axes, norm_arg, result.imag, nthreads)
    return result


# Public 1-D transforms: `_r2r` with the direction and the low-level
# transform pre-bound. The generator yields them in the order
# (dct, idct, dst, idst).
dct, idct, dst, idst = (
    functools.partial(_r2r, is_forward, kernel)
    for kernel in (pfft.dct, pfft.dst)
    for is_forward in (True, False)
)

# functools.partial objects have no __name__ of their own, so set one.
dct.__name__, idct.__name__ = 'dct', 'idct'
dst.__name__, idst.__name__ = 'dst', 'idst'


def _r2rn(forward, transform, x, type=2, s=None, axes=None, norm=None,
          overwrite_x=False, workers=None):
    """Apply a forward or backward N-D DCT/DST over the given axes.

    Parameters
    ----------
    forward : bool
        Transform direction. For the backward direction, types 2 and 3
        are exchanged before calling `transform`; the direction is also
        handed to ``_normalization`` together with `norm`.
    transform : {pypocketfft.dct, pypocketfft.dst}
        The low-level transform to perform.
    x : array_like
        Input data, converted with ``_asfarray``.
    type : int, optional
        Transform type passed on to `transform` (after the 2 <-> 3
        exchange described above). Default is 2.
    s : optional
        Requested shape, resolved together with `axes` by
        ``_init_nd_shape_and_axes`` and applied with ``_fix_shape``.
    axes : optional
        Axes over which the transform is computed, resolved by
        ``_init_nd_shape_and_axes``.
    norm : optional
        Normalisation mode, resolved by ``_normalization``.
    overwrite_x : bool, optional
        If true, the working array may be reused as the output buffer.
        The working array is also reused whenever the helpers report
        that it is already a copy of `x`.
    workers : optional
        Worker specification, resolved by ``_workers``.

    Returns
    -------
    out : ndarray
        The transformed data. For complex input the real and imaginary
        parts are transformed independently into one output array.
        If the resolved `axes` is empty, `x` itself is returned
        unchanged.
    """
    data = _asfarray(x)

    shape, axes = _init_nd_shape_and_axes(data, s, axes)
    # Once the working array is a copy, clobbering it cannot affect `x`.
    may_overwrite = overwrite_x or _datacopied(data, x)

    # Nothing to transform: hand the input straight back.
    if not len(axes):
        return x

    data, resized = _fix_shape(data, shape, axes)
    may_overwrite = may_overwrite or resized

    # The backward transform of type 2 is computed as type 3 and vice
    # versa; every other type is passed through unchanged.
    if forward:
        kind = type
    elif type == 2:
        kind = 3
    elif type == 3:
        kind = 2
    else:
        kind = type

    norm_arg = _normalization(norm, forward)
    nthreads = _workers(workers)

    if not np.iscomplexobj(x):
        buffer = data if may_overwrite else None
        return transform(data, kind, axes, norm_arg, buffer, nthreads)

    # Complex input: run the real transform on each component, writing
    # into the matching component view of the result array.
    result = data if may_overwrite else np.empty_like(data)
    transform(data.real, kind, axes, norm_arg, result.real, nthreads)
    transform(data.imag, kind, axes, norm_arg, result.imag, nthreads)
    return result


# Public N-D transforms: `_r2rn` with the direction and the low-level
# transform pre-bound. The generator yields them in the order
# (dctn, idctn, dstn, idstn).
dctn, idctn, dstn, idstn = (
    functools.partial(_r2rn, is_forward, kernel)
    for kernel in (pfft.dct, pfft.dst)
    for is_forward in (True, False)
)

# functools.partial objects have no __name__ of their own, so set one.
dctn.__name__, idctn.__name__ = 'dctn', 'idctn'
dstn.__name__, idstn.__name__ = 'dstn', 'idstn'