"""
Tests used to verify running PyWavelets transforms in parallel via
concurrent.futures.ThreadPoolExecutor does not raise errors.
"""

from __future__ import division, print_function, absolute_import

import warnings
import numpy as np
from functools import partial
from numpy.testing import assert_array_equal, assert_allclose
from pywt._pytest import uses_futures, futures, max_workers

import pywt


def _assert_all_coeffs_equal(coefs1, coefs2):
    """Assert that two wavelet coefficient structures match exactly.

    The comparison recurses through the containers returned by the DWT/SWT
    families:

    - lists/tuples of arrays (swt, swt2, dwt, dwt2, wavedec, wavedec2);
    - dicts of arrays (swtn, dwtn, wavedecn);
    - the arrays stored inside them.

    Any mismatch raises instead of returning False. The callers in this
    module ignore the return value, so a False result would let a mismatch
    pass silently.

    Parameters
    ----------
    coefs1, coefs2 : list, tuple, dict or array_like
        The coefficient structures to compare.

    Returns
    -------
    bool
        True when everything matches. The return value is kept for
        backward compatibility.

    Raises
    ------
    AssertionError
        If the structures differ in container type, length, dict keys or
        array contents.
    """
    if isinstance(coefs1, dict):
        assert isinstance(coefs2, dict), (
            "expected a dict of coefficients, got %r" % (type(coefs2),))
        assert set(coefs1) == set(coefs2), (
            "coefficient keys differ: %r vs %r"
            % (list(coefs1), list(coefs2)))
        for key, value in coefs1.items():
            _assert_all_coeffs_equal(value, coefs2[key])
    elif isinstance(coefs1, (tuple, list)):
        assert isinstance(coefs2, (tuple, list)), (
            "expected a sequence of coefficients, got %r" % (type(coefs2),))
        # check lengths explicitly: zip() alone would silently truncate
        assert len(coefs1) == len(coefs2), (
            "coefficient counts differ: %d vs %d"
            % (len(coefs1), len(coefs2)))
        for c1, c2 in zip(coefs1, coefs2):
            _assert_all_coeffs_equal(c1, c2)
    else:
        # leaf: an individual coefficient array (e.g. an approximation
        # array at the top level of a wavedec* result)
        assert_array_equal(coefs1, coefs2)
    return True


@uses_futures
def test_concurrent_swt():
    """Run the stationary wavelet transforms from many threads at once.

    Tests error-free concurrent operation (see gh-288). ``swt`` on 1D data
    calls the Cython swt; the other cases call swt_axes.
    """
    with warnings.catch_warnings():
        # can remove catch_warnings once the swt2 FutureWarning is removed
        warnings.simplefilter('ignore', FutureWarning)
        for swt_func, x in zip([pywt.swt, pywt.swt2, pywt.swtn],
                               [np.ones(8), np.eye(16), np.eye(16)]):
            transform = partial(swt_func, wavelet='haar', level=3)
            for _ in range(10):
                arrs = [x.copy() for _ in range(100)]
                with futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
                    results = list(ex.map(transform, arrs))

            # Validate one of the concurrent results against a serial run.
            # This must happen inside the loop so that every transform
            # (swt, swt2, swtn) is checked, not only the last one.
            expected_result = transform(x)
            _assert_all_coeffs_equal(expected_result, results[-1])


@uses_futures
def test_concurrent_wavedec():
    """Run the multilevel DWT family from many threads at once.

    ``wavedec`` on 1D data calls the Cython dwt_single; ``wavedec2`` and
    ``wavedecn`` call dwt_axis.
    """
    cases = [(pywt.wavedec, np.ones(8)),
             (pywt.wavedec2, np.eye(16)),
             (pywt.wavedecn, np.eye(16))]
    for decompose, signal in cases:
        transform = partial(decompose, wavelet='haar', level=1)
        for _repeat in range(10):
            inputs = [signal.copy() for _copy in range(100)]
            with futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
                outputs = list(pool.map(transform, inputs))

        # compare the last threaded output with a serial reference run
        reference = transform(signal)
        _assert_all_coeffs_equal(reference, outputs[-1])


@uses_futures
def test_concurrent_dwt():
    """Run the single-level DWT family from many threads at once.

    ``dwt`` on 1D data calls the Cython dwt_single; ``dwt2`` and ``dwtn``
    call dwt_axis.
    """
    cases = [(pywt.dwt, np.ones(8)),
             (pywt.dwt2, np.eye(16)),
             (pywt.dwtn, np.eye(16))]
    for decompose, signal in cases:
        transform = partial(decompose, wavelet='haar')
        for _repeat in range(10):
            inputs = [signal.copy() for _copy in range(100)]
            with futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
                outputs = list(pool.map(transform, inputs))

        # A single-level transform yields one coefficient set. Wrap it in a
        # one-element list to fit the helper's per-level interface.
        reference = transform(signal)
        _assert_all_coeffs_equal([reference], [outputs[-1]])


@uses_futures
def test_concurrent_cwt():
    """Run the continuous wavelet transform from many threads at once."""
    tol = 1e-14
    t, sst = pywt.data.nino()
    transform = partial(pywt.cwt, scales=np.arange(1, 4), wavelet='cmor1.5-1',
                        sampling_period=t[1] - t[0])
    for _repeat in range(10):
        inputs = [sst.copy() for _copy in range(50)]
        with futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
            outputs = list(pool.map(transform, inputs))

    # compare each item of the last threaded output with a serial run
    reference = transform(sst)
    for expected, actual in zip(reference, outputs[-1]):
        assert_allclose(expected, actual, atol=tol, rtol=tol)