# Copyright (c) 2006-2012 Filip Wasilewski <http://en.ig.ma/>
# Copyright (c) 2012-2016 The PyWavelets Developers
#                         <https://github.com/PyWavelets/pywt>
# See COPYING for license details.

"""
2D and nD Discrete Wavelet Transforms and Inverse Discrete Wavelet Transforms.
"""

from __future__ import division, print_function, absolute_import

from itertools import product

import numpy as np

from ._c99_config import _have_c99_complex
from ._extensions._dwt import dwt_axis, idwt_axis
from ._utils import _wavelets_per_axis, _modes_per_axis


__all__ = ['dwt2', 'idwt2', 'dwtn', 'idwtn']


def dwt2(data, wavelet, mode='symmetric', axes=(-2, -1)):
    """
    2D Discrete Wavelet Transform.

    Thin wrapper around ``dwtn`` restricted to exactly two axes, returning the
    four subbands in the conventional ``(cA, (cH, cV, cD))`` layout.

    Parameters
    ----------
    data : array_like
        2D array with input data
    wavelet : Wavelet object or name string, or 2-tuple of wavelets
        Wavelet to use.  This can also be a tuple containing a wavelet to
        apply along each axis in ``axes``.
    mode : str or 2-tuple of strings, optional
        Signal extension mode, see :ref:`Modes <ref-modes>`. This can
        also be a tuple of modes specifying the mode to use on each axis in
        ``axes``.
    axes : 2-tuple of ints, optional
        Axes over which to compute the DWT. Repeated elements mean the DWT will
        be performed multiple times along these axes.

    Returns
    -------
    (cA, (cH, cV, cD)) : tuple
        Approximation, horizontal detail, vertical detail and diagonal
        detail coefficients respectively.  Horizontal refers to array axis 0
        (or ``axes[0]`` for user-specified ``axes``).

    Raises
    ------
    ValueError
        If ``axes`` does not have exactly two entries, or if ``data`` has
        fewer dimensions than the number of distinct entries in ``axes``.

    Examples
    --------
    >>> import numpy as np
    >>> import pywt
    >>> data = np.ones((4,4), dtype=np.float64)
    >>> coeffs = pywt.dwt2(data, 'haar')
    >>> cA, (cH, cV, cD) = coeffs
    >>> cA
    array([[ 2.,  2.],
           [ 2.,  2.]])
    >>> cV
    array([[ 0.,  0.],
           [ 0.,  0.]])

    """
    axes = tuple(axes)
    data = np.asarray(data)
    if len(axes) != 2:
        raise ValueError("Expected 2 axes")
    distinct_axes = np.unique(axes)
    if len(distinct_axes) > data.ndim:
        raise ValueError("Input array has fewer dimensions than the specified "
                         "axes")

    # Key characters follow the order of ``axes``: 'da' is detail along
    # axes[0] and approximation along axes[1], and so on.
    bands = dwtn(data, wavelet, mode, axes)
    approx = bands['aa']
    details = (bands['da'], bands['ad'], bands['dd'])
    return approx, details


def idwt2(coeffs, wavelet, mode='symmetric', axes=(-2, -1)):
    """
    2-D Inverse Discrete Wavelet Transform.

    Reconstructs data from coefficient arrays by repacking the
    ``(cA, (cH, cV, cD))`` layout produced by ``dwt2`` into the dictionary
    form expected by ``idwtn``.

    Parameters
    ----------
    coeffs : tuple
        (cA, (cH, cV, cD)) A tuple with approximation coefficients and three
        details coefficients 2D arrays like from ``dwt2``.  If any of these
        components are set to ``None``, it will be treated as zeros.
    wavelet : Wavelet object or name string, or 2-tuple of wavelets
        Wavelet to use.  This can also be a tuple containing a wavelet to
        apply along each axis in ``axes``.
    mode : str or 2-tuple of strings, optional
        Signal extension mode, see :ref:`Modes <ref-modes>`. This can
        also be a tuple of modes specifying the mode to use on each axis in
        ``axes``.
    axes : 2-tuple of ints, optional
        Axes over which to compute the IDWT. Repeated elements mean the IDWT
        will be performed multiple times along these axes.

    Returns
    -------
    data : ndarray
        The reconstructed array, as returned by ``idwtn``.

    Raises
    ------
    ValueError
        If ``coeffs`` does not unpack as ``(cA, (cH, cV, cD))`` or ``axes``
        does not have exactly two entries.

    Examples
    --------
    >>> import numpy as np
    >>> import pywt
    >>> data = np.array([[1,2], [3,4]], dtype=np.float64)
    >>> coeffs = pywt.dwt2(data, 'haar')
    >>> pywt.idwt2(coeffs, 'haar')
    array([[ 1.,  2.],
           [ 3.,  4.]])

    """
    # approximation (low-pass on both axes) and the three detail bands
    approx, details = coeffs
    detail_h, detail_v, detail_d = details

    axes = tuple(axes)
    if len(axes) != 2:
        raise ValueError("Expected 2 axes")

    subbands = {
        'aa': approx,
        'da': detail_h,
        'ad': detail_v,
        'dd': detail_d,
    }
    return idwtn(subbands, wavelet, mode, axes)


def dwtn(data, wavelet, mode='symmetric', axes=None):
    """
    Single-level n-dimensional Discrete Wavelet Transform.

    Parameters
    ----------
    data : array_like
        n-dimensional array with input data.
    wavelet : Wavelet object or name string, or tuple of wavelets
        Wavelet to use.  This can also be a tuple containing a wavelet to
        apply along each axis in ``axes``.
    mode : str or tuple of string, optional
        Signal extension mode used in the decomposition,
        see :ref:`Modes <ref-modes>`. This can also be a tuple of modes
        specifying the mode to use on each axis in ``axes``.
    axes : sequence of ints, optional
        Axes over which to compute the DWT. Repeated elements mean the DWT will
        be performed multiple times along these axes. A value of ``None`` (the
        default) selects all axes.

        Axes may be repeated, but information about the original size may be
        lost if it is not divisible by ``2 ** nrepeats``. The reconstruction
        will be larger, with additional values derived according to the
        ``mode`` parameter. ``pywt.wavedecn`` should be used for multilevel
        decomposition.

    Returns
    -------
    coeffs : dict
        Results are arranged in a dictionary, where key specifies
        the transform type on each dimension and value is a n-dimensional
        coefficients array.

        For example, for a 2D case the result will look something like this::

            {'aa': <coeffs>  # A(LL) - approx. on 1st dim, approx. on 2nd dim
             'ad': <coeffs>  # V(LH) - approx. on 1st dim, det. on 2nd dim
             'da': <coeffs>  # H(HL) - det. on 1st dim, approx. on 2nd dim
             'dd': <coeffs>  # D(HH) - det. on 1st dim, det. on 2nd dim
            }

        For user-specified ``axes``, the order of the characters in the
        dictionary keys map to the specified ``axes``.

    Raises
    ------
    TypeError
        If ``data`` has an object dtype.
    ValueError
        If ``data`` is 0-dimensional, or if any entry of ``axes`` lies outside
        ``[-data.ndim, data.ndim)``.

    """
    data = np.asarray(data)
    if np.iscomplexobj(data) and not _have_c99_complex:
        # Without C99 complex support, transform the real and imaginary
        # parts separately and recombine them band by band.
        real = dwtn(data.real, wavelet, mode, axes)
        imag = dwtn(data.imag, wavelet, mode, axes)
        return dict((k, real[k] + 1j * imag[k]) for k in real.keys())

    if data.dtype == np.dtype('object'):
        raise TypeError("Input must be a numeric array-like")
    if data.ndim < 1:
        raise ValueError("Input data must be at least 1D")

    if axes is None:
        axes = range(data.ndim)
    axes = [a + data.ndim if a < 0 else a for a in axes]
    # Validate here (mirroring the check in ``idwtn``) so an invalid axis is
    # reported as a ValueError before any transform work is attempted.
    if any(a < 0 or a >= data.ndim for a in axes):
        raise ValueError("Axis greater than data dimensions")

    modes = _modes_per_axis(mode, axes)
    wavelets = _wavelets_per_axis(wavelet, axes)

    # Start from the untransformed data under the empty key; each axis splits
    # every current band into an approximation ('a') and a detail ('d') band,
    # appending the corresponding character to the band's key.
    coeffs = [('', data)]
    for axis, wav, axis_mode in zip(axes, wavelets, modes):
        new_coeffs = []
        for subband, x in coeffs:
            cA, cD = dwt_axis(x, wav, axis_mode, axis)
            new_coeffs.extend([(subband + 'a', cA),
                               (subband + 'd', cD)])
        coeffs = new_coeffs
    return dict(coeffs)


def _fix_coeffs(coeffs):
    """
    Validate a dictionary of wavelet coefficients and convert it to arrays.

    Parameters
    ----------
    coeffs : dict
        Mapping from subband keys to coefficient arrays (or array_likes).

    Returns
    -------
    coeffs : dict
        A new dictionary with the same keys, each value passed through
        ``np.asarray``.

    Raises
    ------
    ValueError
        If any value is ``None``, if any key contains characters other than
        ``'a'`` and ``'d'``, or if the keys do not all have the same length.
    """
    none_keys = []
    for key, value in coeffs.items():
        if value is None:
            none_keys.append(key)
    if none_keys:
        message = ("The following detail coefficients were set to None:\n"
                   "{0}\n"
                   "For multilevel transforms, rather than setting\n"
                   "\tcoeffs[key] = None\n"
                   "use\n"
                   "\tcoeffs[key] = np.zeros_like(coeffs[key])\n")
        raise ValueError(message.format(none_keys))

    # keys may only be built from 'a' (approximation) and 'd' (detail)
    allowed_chars = set('ad')
    bad_keys = [key for key in coeffs
                if not set(key).issubset(allowed_chars)]
    if bad_keys:
        raise ValueError(
            "The following invalid keys were found in the detail "
            "coefficient dictionary: {}.".format(bad_keys))

    distinct_lengths = set(len(key) for key in coeffs)
    if len(distinct_lengths) > 1:
        raise ValueError(
            "All detail coefficient names must have equal length.")

    return {key: np.asarray(value) for key, value in coeffs.items()}


def idwtn(coeffs, wavelet, mode='symmetric', axes=None):
    """
    Single-level n-dimensional Inverse Discrete Wavelet Transform.

    Parameters
    ----------
    coeffs : dict
        Dictionary as in output of ``dwtn``. Missing or ``None`` items
        will be treated as zeros.
    wavelet : Wavelet object or name string, or tuple of wavelets
        Wavelet to use.  This can also be a tuple containing a wavelet to
        apply along each axis in ``axes``.
    mode : str or tuple of string, optional
        Signal extension mode used in the decomposition,
        see :ref:`Modes <ref-modes>`. This can also be a tuple of modes
        specifying the mode to use on each axis in ``axes``.
    axes : sequence of ints, optional
        Axes over which to compute the IDWT. Repeated elements mean the IDWT
        will be performed multiple times along these axes. A value of ``None``
        (the default) selects all axes. The number of axes must equal the
        length of the keys in ``coeffs``.

        For the most accurate reconstruction, the axes should be provided in
        the same order as they were provided to ``dwtn``.

    Returns
    -------
    data : ndarray
        Original signal reconstructed from input data.

    Raises
    ------
    ValueError
        If ``coeffs`` is empty or all of its entries are ``None``, if its keys
        or array shapes are inconsistent, if the number of ``axes`` does not
        match the key length, or if an axis is out of range.

    """
    # drop the keys corresponding to value = None
    coeffs = dict((k, v) for k, v in coeffs.items() if v is not None)
    # Checked up front: with no bands left, the key-length computation below
    # would otherwise fail with an uninformative "max() of empty sequence".
    if not coeffs:
        raise ValueError("`coeffs` must contain at least one non-null wavelet "
                         "band")

    # Raise error for invalid key combinations; values become ndarrays
    coeffs = _fix_coeffs(coeffs)

    if (not _have_c99_complex and
            any(np.iscomplexobj(v) for v in coeffs.values())):
        # Without C99 complex support, reconstruct the real and imaginary
        # parts separately and recombine them.
        real_coeffs = dict((k, v.real) for k, v in coeffs.items())
        imag_coeffs = dict((k, v.imag) for k, v in coeffs.items())
        return (idwtn(real_coeffs, wavelet, mode, axes) +
                1j * idwtn(imag_coeffs, wavelet, mode, axes))

    # key length matches the number of axes transformed
    ndim_transform = max(len(key) for key in coeffs)

    coeff_shapes = [v.shape for k, v in coeffs.items()
                    if len(k) == ndim_transform]
    coeff_shape = coeff_shapes[0]
    if any(s != coeff_shape for s in coeff_shapes[1:]):
        raise ValueError("`coeffs` must all be of equal size (or None)")

    if axes is None:
        axes = range(ndim_transform)
        ndim = ndim_transform
    else:
        ndim = len(coeff_shape)
    axes = [a + ndim if a < 0 else a for a in axes]
    # The first reconstruction level looks up keys of length len(axes); if
    # that differs from the key length no band would be found at all.
    if len(axes) != ndim_transform:
        raise ValueError(
            "The number of axes ({0}) must match the length of the "
            "coefficient keys ({1}).".format(len(axes), ndim_transform))

    modes = _modes_per_axis(mode, axes)
    wavelets = _wavelets_per_axis(wavelet, axes)
    # Undo the axes in reverse order; at each step, pairs of bands whose keys
    # differ only in their last character ('a'/'d') are merged along ``axis``.
    for key_length, (axis, wav, axis_mode) in reversed(
            list(enumerate(zip(axes, wavelets, modes)))):
        if axis < 0 or axis >= ndim:
            raise ValueError("Axis greater than data dimensions")

        new_coeffs = {}
        new_keys = [''.join(coef) for coef in product('ad', repeat=key_length)]

        for key in new_keys:
            L = coeffs.get(key + 'a', None)
            H = coeffs.get(key + 'd', None)
            if L is not None and H is not None:
                if L.dtype != H.dtype:
                    # upcast to a common dtype (float64 or complex128)
                    if L.dtype.kind == 'c' or H.dtype.kind == 'c':
                        dtype = np.complex128
                    else:
                        dtype = np.float64
                    L = np.asarray(L, dtype=dtype)
                    H = np.asarray(H, dtype=dtype)
            new_coeffs[key] = idwt_axis(L, H, wav, axis_mode, axis)
        coeffs = new_coeffs

    return coeffs['']