"""
Test used to verify PyWavelets Discrete Wavelet Transform computation
accuracy against MathWorks Wavelet Toolbox.
"""

from __future__ import division, print_function, absolute_import

import numpy as np
import pytest
from numpy.testing import assert_

import pywt
from pywt._pytest import (uses_pymatbridge, uses_precomputed, size_set)
from pywt._pytest import matlab_result_dict_dwt as matlab_result_dict

# Signal extension modes under test, given as
# (PyWavelets mode name, Matlab mode name) pairs.
modes = [
    ('zero', 'zpd'),
    ('constant', 'sp0'),
    ('symmetric', 'sym'),
    ('reflect', 'symw'),
    ('periodic', 'ppd'),
    ('smooth', 'sp1'),
    ('periodization', 'per'),
    # TODO: Now have implemented asymmetric modes too.
    #       Would be nice to update the Matlab data to test these as well.
    ('antisymmetric', 'asym'),
    ('antireflect', 'asymw'),
]

# Wavelet families whose members are tested, and the flat list of every
# wavelet name that pywt.wavelist() reports for those families (in family
# order).
families = ('db', 'sym', 'coif', 'bior', 'rbio')
wavelets = [name for family in families for name in pywt.wavelist(family)]


def _get_data_sizes(w):
    """Return the signal lengths to test for wavelet ``w``.

    The shortest length is always ``w.dec_len``.  When the module-level
    ``size_set`` is ``'full'`` a list covering every length from
    ``w.dec_len`` up to 39 plus a handful of larger sizes is returned;
    for any other value only the tuple ``(w.dec_len, w.dec_len + 1)`` is
    returned.
    """
    shortest = w.dec_len
    if size_set != 'full':
        # Reduced set: just the two smallest lengths.
        return (shortest, shortest + 1)
    large_sizes = [100, 200, 500, 1000, 50000]
    return list(range(shortest, 40)) + large_sizes


@uses_pymatbridge
@pytest.mark.slow
def test_accuracy_pymatbridge():
    """Compare ``pywt.dwt`` with results computed by a live Matlab session.

    For every wavelet, data size and extension mode the PyWavelets output is
    checked twice: once against coefficients computed on the fly in Matlab
    (via pymatbridge), and once against the precomputed ``pywtCoeffs``
    reference results.  The test is skipped when pymatbridge is unavailable.
    """
    # importorskip() expects a *module* name.  ``Matlab`` is a class exposed
    # by the ``pymatbridge`` package, so asking for "pymatbridge.Matlab" can
    # never be imported and would skip the test unconditionally.  Import the
    # package and look the class up instead.
    pymatbridge = pytest.importorskip("pymatbridge")
    mlab = pymatbridge.Matlab()

    # The precomputed results loaded below were generated for this seed
    # (see test_accuracy_precomputed), so keep it in sync.
    rstate = np.random.RandomState(1234)
    # max RMSE (was 1.0e-10, is reduced to 5.0e-5 due to different
    # coefficients)
    epsilon = 5.0e-5
    # Tolerance for the reference results stored under the ``pywtCoeffs``
    # keys.
    epsilon_pywt_coeffs = 1.0e-10
    mlab.start()
    try:
        for wavelet in wavelets:
            w = pywt.Wavelet(wavelet)
            # _compute_matlab_result relies on these Matlab-side variables.
            mlab.set_variable('wavelet', wavelet)
            for N in _get_data_sizes(w):
                data = rstate.randn(N)
                mlab.set_variable('data', data)
                for pmode, mmode in modes:
                    ma, md = _compute_matlab_result(data, wavelet, mmode, mlab)
                    _check_accuracy(data, w, pmode, ma, md, wavelet, epsilon)
                    ma, md = _load_matlab_result_pywt_coeffs(data, wavelet,
                                                             mmode)
                    _check_accuracy(data, w, pmode, ma, md, wavelet,
                                    epsilon_pywt_coeffs)

    finally:
        # Always shut the Matlab session down, even if a check failed.
        mlab.stop()


@uses_precomputed
@pytest.mark.slow
def test_accuracy_precomputed():
    """Compare ``pywt.dwt`` with the precomputed Matlab reference results.

    Every combination of wavelet, data size and extension mode is checked
    against both sets of stored results, each with its own RMSE tolerance.
    """
    # (loader, max RMSE) pairs, applied in this order for every case.
    # The first tolerance was 1.0e-10 and is reduced to 5.0e-5 due to
    # different coefficients; the ``pywtCoeffs`` results keep 1.0e-10.
    reference_checks = (
        (_load_matlab_result, 5.0e-5),
        (_load_matlab_result_pywt_coeffs, 1.0e-10),
    )

    # Keep this specific random seed to match the precomputed Matlab result.
    rng = np.random.RandomState(1234)

    for name in wavelets:
        wav = pywt.Wavelet(name)
        for n_samples in _get_data_sizes(wav):
            signal = rng.randn(n_samples)
            for pywt_mode, matlab_mode in modes:
                for load_reference, tolerance in reference_checks:
                    ref_a, ref_d = load_reference(signal, name, matlab_mode)
                    _check_accuracy(signal, wav, pywt_mode, ref_a, ref_d,
                                    name, tolerance)


def _compute_matlab_result(data, wavelet, mmode, mlab):
    """ Compute the result using MATLAB.

    This function assumes that the Matlab variables `wavelet` and `data` have
    already been set externally.  The ``data`` argument is not read here; the
    transform runs on the Matlab-side variable.

    Parameters
    ----------
    data : array_like
        Unused; kept for a signature parallel to the ``_load_*`` helpers.
    wavelet : str
        Wavelet name.
    mmode : str
        Matlab name of the signal extension mode.
    mlab : object
        Started Matlab session providing ``set_variable``, ``run_code`` and
        ``get_variable``.

    Returns
    -------
    ma, md : ndarray
        Approximation and detail coefficients returned by Matlab's ``dwt``.

    Raises
    ------
    RuntimeError
        If Matlab reports that the code did not execute successfully.
    """
    # coif6 ... coif17: for these the PyWavelets filters are handed to Matlab
    # explicitly instead of the wavelet name.
    # NOTE(review): presumably Matlab does not provide these wavelets by
    # name -- confirm.
    explicit_filter_wavelets = tuple(
        'coif%d' % order for order in range(6, 18))
    if wavelet in explicit_filter_wavelets:
        w = pywt.Wavelet(wavelet)
        mlab.set_variable('Lo_D', w.dec_lo)
        mlab.set_variable('Hi_D', w.dec_hi)
        mlab_code = ("[ma, md] = dwt(data, Lo_D, Hi_D, 'mode', '%s');" % mmode)
    else:
        mlab_code = "[ma, md] = dwt(data, wavelet, 'mode', '%s');" % mmode
    res = mlab.run_code(mlab_code)
    if not res['success']:
        raise RuntimeError("Matlab failed to execute the provided code. "
                           "Check that the wavelet toolbox is installed.")
    # need np.asarray because sometimes the output is a single float64
    ma = np.asarray(mlab.get_variable('ma'))
    md = np.asarray(mlab.get_variable('md'))
    return ma, md


def _load_matlab_result(data, wavelet, mmode):
    """ Load the precomputed Matlab result for one test case.

    Results are looked up in ``matlab_result_dict`` under the keys
    ``<mmode>_<wavelet>_<N>_ma`` and ``<mmode>_<wavelet>_<N>_md``, where
    ``N`` is ``len(data)``.

    Parameters
    ----------
    data : sequence
        Input signal; only its length is used.
    wavelet : str
        Wavelet name.
    mmode : str
        Matlab name of the signal extension mode.

    Returns
    -------
    ma, md
        The stored approximation and detail coefficients.

    Raises
    ------
    KeyError
        If either of the two entries is missing.
    """
    N = len(data)
    key_prefix = '_'.join([mmode, wavelet, str(N)])
    ma_key = key_prefix + '_ma'
    md_key = key_prefix + '_md'
    # Check both keys up front so the error names the test case rather than
    # a raw dictionary key.
    if any(key not in matlab_result_dict for key in (ma_key, md_key)):
        raise KeyError(
            "Precomputed Matlab result not found for wavelet: "
            "{0}, mode: {1}, size: {2}".format(wavelet, mmode, N))
    return matlab_result_dict[ma_key], matlab_result_dict[md_key]


def _load_matlab_result_pywt_coeffs(data, wavelet, mmode):
    """ Load the precomputed ``pywtCoeffs`` Matlab result for one test case.

    Results are looked up in ``matlab_result_dict`` under the keys
    ``<mmode>_<wavelet>_<N>_ma_pywtCoeffs`` and
    ``<mmode>_<wavelet>_<N>_md_pywtCoeffs``, where ``N`` is ``len(data)``.

    NOTE(review): judging by the key suffix these are presumably Matlab
    results obtained with PyWavelets' filter coefficients -- confirm against
    the script that generated the data.

    Parameters
    ----------
    data : sequence
        Input signal; only its length is used.
    wavelet : str
        Wavelet name.
    mmode : str
        Matlab name of the signal extension mode.

    Returns
    -------
    ma, md
        The stored approximation and detail coefficients.

    Raises
    ------
    KeyError
        If either of the two entries is missing.
    """
    N = len(data)
    key_prefix = '_'.join([mmode, wavelet, str(N)])
    ma_key = key_prefix + '_ma_pywtCoeffs'
    md_key = key_prefix + '_md_pywtCoeffs'
    # Check both keys up front so the error names the test case rather than
    # a raw dictionary key.
    if any(key not in matlab_result_dict for key in (ma_key, md_key)):
        raise KeyError(
            "Precomputed Matlab result not found for wavelet: "
            "{0}, mode: {1}, size: {2}".format(wavelet, mmode, N))
    return matlab_result_dict[ma_key], matlab_result_dict[md_key]


def _check_accuracy(data, w, pmode, ma, md, wavelet, epsilon):
    """Assert that ``pywt.dwt`` matches the reference coefficients.

    Runs ``pywt.dwt(data, w, pmode)`` and requires the root-mean-square
    difference to the reference approximation (``ma``) and detail (``md``)
    coefficients to each be strictly below ``epsilon``.  ``wavelet`` is only
    used in the failure messages.
    """
    def _rms_error(actual, reference):
        # Root-mean-square of the element-wise difference.
        return np.sqrt(np.mean((actual - reference) ** 2))

    # PyWavelets result
    approx, detail = pywt.dwt(data, w, pmode)

    # Both error measures are computed before anything is asserted.
    err_approx = _rms_error(approx, ma)
    err_detail = _rms_error(detail, md)
    n_samples = len(data)

    approx_msg = ('[RMS_A > EPSILON] for Mode: %s, Wavelet: %s, '
                  'Length: %d, rms=%.3g'
                  % (pmode, wavelet, n_samples, err_approx))
    assert_(err_approx < epsilon, msg=approx_msg)

    detail_msg = ('[RMS_D > EPSILON] for Mode: %s, Wavelet: %s, '
                  'Length: %d, rms=%.3g'
                  % (pmode, wavelet, n_samples, err_detail))
    assert_(err_detail < epsilon, msg=detail_msg)