#!/usr/bin/env python

from __future__ import division, print_function, absolute_import

import numpy as np
from numpy.testing import (assert_allclose, assert_, assert_raises,
                           assert_equal)

import pywt


def test_wavelet_packet_structure():
    """Check the basic attributes of a freshly created wavelet packet root.

    The root node is expected to keep the input data, have an empty path,
    sit at level 0, and a second-level node ('ad') is expected to report a
    maximum decomposition level of 3 for this 8-sample 'db1' setup.
    """
    x = [1, 2, 3, 4, 5, 6, 7, 8]
    wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')

    # assert_equal reports both operands on failure and also copes with
    # array-like data, unlike assert_(a == b).
    assert_equal(wp.data, [1, 2, 3, 4, 5, 6, 7, 8])
    assert_equal(wp.path, '')
    assert_equal(wp.level, 0)
    assert_equal(wp['ad'].maxlevel, 3)


def test_traversing_wp_tree():
    """Walk down the all-approximation branch and compare coefficients.

    For an 8-sample signal decomposed with 'db1' the maximum level must be
    3, and the data stored at 'a', 'aa' and 'aaa' must match the reference
    values below.
    """
    samples = [1, 2, 3, 4, 5, 6, 7, 8]
    wp = pywt.WaveletPacket(data=samples, wavelet='db1', mode='symmetric')

    assert_(wp.maxlevel == 3)

    # (path, reference coefficients), listed from the first to third level.
    reference = [
        ('a', [2.12132034356, 4.949747468306,
               7.778174593052, 10.606601717798]),
        ('aa', [5., 13.]),
        ('aaa', [12.727922061358]),
    ]
    for path, coeffs in reference:
        assert_allclose(wp[path].data, np.array(coeffs), rtol=1e-12)


def test_acess_path():
    """Check node paths and the errors raised for unreachable paths.

    Valid paths must be echoed back by the node's ``path`` attribute.
    A path deeper than the maximum level must raise IndexError, and a path
    containing an unknown node name must raise ValueError.
    """
    samples = [1, 2, 3, 4, 5, 6, 7, 8]
    wp = pywt.WaveletPacket(data=samples, wavelet='db1', mode='symmetric')

    def path_of(key):
        # Look the node up and read back the path it reports.
        return wp[key].path

    for key in ('a', 'aa', 'aaa'):
        assert_(path_of(key) == key)

    # Maximum level reached:
    assert_raises(IndexError, path_of, 'aaaa')

    # Wrong path
    assert_raises(ValueError, path_of, 'ac')


def test_access_node_atributes():
    """Check the attributes exposed by the second-level node 'ad'."""
    samples = [1, 2, 3, 4, 5, 6, 7, 8]
    wp = pywt.WaveletPacket(data=samples, wavelet='db1', mode='symmetric')

    node = wp['ad']
    assert_allclose(node.data, np.array([-2., -2.]), rtol=1e-12)

    # Plain attributes of the node and the values they must hold.
    expected_attrs = [
        ('path', 'ad'),
        ('node_name', 'd'),
        ('level', 2),
        ('maxlevel', 3),
        ('mode', 'symmetric'),
    ]
    for attr, value in expected_attrs:
        assert_(getattr(node, attr) == value)

    # The parent of 'ad' is the first-level approximation node.
    assert_(node.parent.path == 'a')


def test_collecting_nodes():
    """Check the node ordering returned by ``get_level`` at level 3.

    Both the 'natural' and the 'freq' orderings are compared against the
    expected list of node paths.
    """
    x = [1, 2, 3, 4, 5, 6, 7, 8]
    wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')

    # All nodes in natural order
    natural_paths = [node.path for node in wp.get_level(3, 'natural')]
    # assert_equal pinpoints the mismatching item instead of a bare failure.
    assert_equal(natural_paths,
                 ['aaa', 'aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd'])

    # and in frequency order.
    freq_paths = [node.path for node in wp.get_level(3, 'freq')]
    assert_equal(freq_paths,
                 ['aaa', 'aad', 'add', 'ada', 'dda', 'ddd', 'dad', 'daa'])


def test_reconstructing_data():
    """Rebuild a signal from nodes assigned to an initially empty packet.

    A second wavelet packet is created without data, populated with the
    'aa', 'ad' and 'd' nodes, and then reconstructed both without and with
    updating the root data. Finally the leaf nodes are listed without and
    with further decomposition.
    """
    signal = [1, 2, 3, 4, 5, 6, 7, 8]
    source_wp = pywt.WaveletPacket(data=signal, wavelet='db1',
                                   mode='symmetric')

    # Start from an empty packet and fill in coefficients by hand.
    rebuilt = pywt.WaveletPacket(data=None, wavelet='db1', mode='symmetric')
    rebuilt['aa'] = source_wp['aa'].data
    rebuilt['ad'] = [-2., -2.]

    # A node object can be assigned directly instead of its data.
    rebuilt['d'] = source_wp['d']

    # Reconstruction from the aa, ad and d packets, leaving the root alone.
    without_update = rebuilt.reconstruct(update=False)
    assert_allclose(without_update, signal, rtol=1e-12)

    # update=False must not have stored the result on the root node.
    assert_(rebuilt.data is None)

    # With update=True the root data is filled in as well.
    with_update = rebuilt.reconstruct(update=True)
    assert_allclose(with_update, signal, rtol=1e-12)
    assert_allclose(rebuilt.data, np.arange(1, 9), rtol=1e-12)

    existing_leaves = [leaf.path for leaf in rebuilt.get_leaf_nodes(False)]
    assert_(existing_leaves == ['aa', 'ad', 'd'])

    decomposed_leaves = [leaf.path for leaf in rebuilt.get_leaf_nodes(True)]
    assert_(decomposed_leaves ==
            ['aaa', 'aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd'])


def test_removing_nodes():
    """Delete a node, reconstruct without it, then restore it.

    The leaf data at level 2 is checked before the deletion of 'ad', after
    the deletion, and again after the node's data has been put back. The
    reconstruction is checked both with the node missing and restored.
    """
    signal = [1, 2, 3, 4, 5, 6, 7, 8]
    wp = pywt.WaveletPacket(data=signal, wavelet='db1', mode='symmetric')
    wp.get_level(2)

    def check_leaf_data(expected_rows):
        # Compare the current leaves, in order, with one expected row each.
        leaf_data = [leaf.data for leaf in wp.get_leaf_nodes(False)]
        for idx, row in enumerate(expected_rows):
            assert_allclose(leaf_data[idx], row, atol=1e-12)

    all_leaves = np.array([[5., 13.], [-2, -2], [-1, -1], [0, 0]])
    check_leaf_data(all_leaves)

    removed = wp['ad']
    del wp['ad']
    check_leaf_data(np.array([[5., 13.], [-1, -1], [0, 0]]))

    wp.reconstruct()
    # The reconstruction is:
    assert_allclose(wp.reconstruct(),
                    np.array([2., 3., 2., 3., 6., 7., 6., 7.]), rtol=1e-12)

    # Restore the data
    wp['ad'].data = removed.data
    check_leaf_data(all_leaves)

    assert_allclose(wp.reconstruct(), np.arange(1, 9), rtol=1e-12)


def test_wavelet_packet_dtypes():
    """Check dtype handling through decomposition and reconstruction.

    For float32/float64/complex64/complex128 input the test expects the
    node data and the reconstruction to keep the input dtype. For the other
    listed input dtypes it expects the reconstruction to come back in the
    paired transform dtype. In every case the root data must be the very
    same object that was passed in.
    """
    N = 32
    # Seeded generator so that any failure is reproducible.
    rstate = np.random.RandomState(0)
    for dtype in [np.float32, np.float64, np.complex64, np.complex128]:
        x = rstate.randn(N).astype(dtype)
        if np.iscomplexobj(x):
            x = x + 1j*rstate.randn(N).astype(x.real.dtype)
        wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
        # no unnecessary copy made
        assert_(wp.data is x)

        # assigning to a node should not change supported dtypes
        wp['d'] = wp['d'].data
        assert_equal(wp['d'].data.dtype, x.dtype)

        # full decomposition
        wp.get_level(wp.maxlevel)

        # reconstruction from coefficients should preserve dtype
        r = wp.reconstruct(False)
        assert_equal(r.dtype, x.dtype)
        assert_allclose(r, x, atol=1e-5, rtol=1e-5)

    # first element of the tuple is the input dtype
    # second element of the tuple is the transform dtype
    dtype_pairs = [(np.uint8, np.float64),
                   (np.intp, np.float64), ]
    # Extended-precision and half-precision types only exist on some builds.
    if hasattr(np, "complex256"):
        dtype_pairs += [(np.complex256, np.complex128), ]
    if hasattr(np, "half"):
        dtype_pairs += [(np.half, np.float32), ]
    for (dtype, transform_dtype) in dtype_pairs:
        x = np.arange(N, dtype=dtype)
        wp = pywt.WaveletPacket(x, wavelet='db1', mode='symmetric')

        # no unnecessary copy made of top-level data
        assert_(wp.data is x)

        # full decomposition
        wp.get_level(wp.maxlevel)

        # reconstructed data will have modified dtype
        r = wp.reconstruct(False)
        assert_equal(r.dtype, transform_dtype)
        assert_allclose(r, x.astype(transform_dtype), atol=1e-5, rtol=1e-5)


def test_db3_roundtrip():
    """Check that a 'db3' packet built in 'smooth' mode reconstructs its input.

    A 512-sample ramp is wrapped in a wavelet packet limited to three
    levels and the reconstruction is compared with the input.
    """
    ramp = np.arange(512)
    packet = pywt.WaveletPacket(
        data=ramp, wavelet='db3', mode='smooth', maxlevel=3)
    restored = packet.reconstruct()
    assert_allclose(ramp, restored, atol=1e-12, rtol=1e-12)