Fixed database typo and removed unnecessary class identifier.
This commit is contained in:
parent
00ad49a143
commit
45fb349a7d
5098 changed files with 952558 additions and 85 deletions
197
venv/Lib/site-packages/pywt/tests/test_wp.py
Normal file
197
venv/Lib/site-packages/pywt/tests/test_wp.py
Normal file
|
@ -0,0 +1,197 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import (assert_allclose, assert_, assert_raises,
|
||||
assert_equal)
|
||||
|
||||
import pywt
|
||||
|
||||
|
||||
def test_wavelet_packet_structure():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_(wp.data == [1, 2, 3, 4, 5, 6, 7, 8])
|
||||
assert_(wp.path == '')
|
||||
assert_(wp.level == 0)
|
||||
assert_(wp['ad'].maxlevel == 3)
|
||||
|
||||
|
||||
def test_traversing_wp_tree():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_(wp.maxlevel == 3)
|
||||
|
||||
# First level
|
||||
assert_allclose(wp['a'].data, np.array([2.12132034356, 4.949747468306,
|
||||
7.778174593052, 10.606601717798]),
|
||||
rtol=1e-12)
|
||||
|
||||
# Second level
|
||||
assert_allclose(wp['aa'].data, np.array([5., 13.]), rtol=1e-12)
|
||||
|
||||
# Third level
|
||||
assert_allclose(wp['aaa'].data, np.array([12.727922061358]), rtol=1e-12)
|
||||
|
||||
|
||||
def test_acess_path():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_(wp['a'].path == 'a')
|
||||
assert_(wp['aa'].path == 'aa')
|
||||
assert_(wp['aaa'].path == 'aaa')
|
||||
|
||||
# Maximum level reached:
|
||||
assert_raises(IndexError, lambda: wp['aaaa'].path)
|
||||
|
||||
# Wrong path
|
||||
assert_raises(ValueError, lambda: wp['ac'].path)
|
||||
|
||||
|
||||
def test_access_node_atributes():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
assert_allclose(wp['ad'].data, np.array([-2., -2.]), rtol=1e-12)
|
||||
assert_(wp['ad'].path == 'ad')
|
||||
assert_(wp['ad'].node_name == 'd')
|
||||
assert_(wp['ad'].parent.path == 'a')
|
||||
assert_(wp['ad'].level == 2)
|
||||
assert_(wp['ad'].maxlevel == 3)
|
||||
assert_(wp['ad'].mode == 'symmetric')
|
||||
|
||||
|
||||
def test_collecting_nodes():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
# All nodes in natural order
|
||||
assert_([node.path for node in wp.get_level(3, 'natural')] ==
|
||||
['aaa', 'aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd'])
|
||||
|
||||
# and in frequency order.
|
||||
assert_([node.path for node in wp.get_level(3, 'freq')] ==
|
||||
['aaa', 'aad', 'add', 'ada', 'dda', 'ddd', 'dad', 'daa'])
|
||||
|
||||
|
||||
def test_reconstructing_data():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
|
||||
# Create another Wavelet Packet and feed it with some data.
|
||||
new_wp = pywt.WaveletPacket(data=None, wavelet='db1', mode='symmetric')
|
||||
new_wp['aa'] = wp['aa'].data
|
||||
new_wp['ad'] = [-2., -2.]
|
||||
|
||||
# For convenience, :attr:`Node.data` gets automatically extracted
|
||||
# from the :class:`Node` object:
|
||||
new_wp['d'] = wp['d']
|
||||
|
||||
# Reconstruct data from aa, ad, and d packets.
|
||||
assert_allclose(new_wp.reconstruct(update=False), x, rtol=1e-12)
|
||||
|
||||
# The node's :attr:`~Node.data` will not be updated
|
||||
assert_(new_wp.data is None)
|
||||
|
||||
# When `update` is True:
|
||||
assert_allclose(new_wp.reconstruct(update=True), x, rtol=1e-12)
|
||||
assert_allclose(new_wp.data, np.arange(1, 9), rtol=1e-12)
|
||||
|
||||
assert_([n.path for n in new_wp.get_leaf_nodes(False)] ==
|
||||
['aa', 'ad', 'd'])
|
||||
assert_([n.path for n in new_wp.get_leaf_nodes(True)] ==
|
||||
['aaa', 'aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd'])
|
||||
|
||||
|
||||
def test_removing_nodes():
|
||||
x = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
wp.get_level(2)
|
||||
|
||||
dataleafs = [n.data for n in wp.get_leaf_nodes(False)]
|
||||
expected = np.array([[5., 13.], [-2, -2], [-1, -1], [0, 0]])
|
||||
|
||||
for i in range(4):
|
||||
assert_allclose(dataleafs[i], expected[i, :], atol=1e-12)
|
||||
|
||||
node = wp['ad']
|
||||
del(wp['ad'])
|
||||
dataleafs = [n.data for n in wp.get_leaf_nodes(False)]
|
||||
expected = np.array([[5., 13.], [-1, -1], [0, 0]])
|
||||
|
||||
for i in range(3):
|
||||
assert_allclose(dataleafs[i], expected[i, :], atol=1e-12)
|
||||
|
||||
wp.reconstruct()
|
||||
# The reconstruction is:
|
||||
assert_allclose(wp.reconstruct(),
|
||||
np.array([2., 3., 2., 3., 6., 7., 6., 7.]), rtol=1e-12)
|
||||
|
||||
# Restore the data
|
||||
wp['ad'].data = node.data
|
||||
|
||||
dataleafs = [n.data for n in wp.get_leaf_nodes(False)]
|
||||
expected = np.array([[5., 13.], [-2, -2], [-1, -1], [0, 0]])
|
||||
|
||||
for i in range(4):
|
||||
assert_allclose(dataleafs[i], expected[i, :], atol=1e-12)
|
||||
|
||||
assert_allclose(wp.reconstruct(), np.arange(1, 9), rtol=1e-12)
|
||||
|
||||
|
||||
def test_wavelet_packet_dtypes():
|
||||
N = 32
|
||||
for dtype in [np.float32, np.float64, np.complex64, np.complex128]:
|
||||
x = np.random.randn(N).astype(dtype)
|
||||
if np.iscomplexobj(x):
|
||||
x = x + 1j*np.random.randn(N).astype(x.real.dtype)
|
||||
wp = pywt.WaveletPacket(data=x, wavelet='db1', mode='symmetric')
|
||||
# no unnecessary copy made
|
||||
assert_(wp.data is x)
|
||||
|
||||
# assiging to a node should not change supported dtypes
|
||||
wp['d'] = wp['d'].data
|
||||
assert_equal(wp['d'].data.dtype, x.dtype)
|
||||
|
||||
# full decomposition
|
||||
wp.get_level(wp.maxlevel)
|
||||
|
||||
# reconstruction from coefficients should preserve dtype
|
||||
r = wp.reconstruct(False)
|
||||
assert_equal(r.dtype, x.dtype)
|
||||
assert_allclose(r, x, atol=1e-5, rtol=1e-5)
|
||||
|
||||
# first element of the tuple is the input dtype
|
||||
# second element of the tuple is the transform dtype
|
||||
dtype_pairs = [(np.uint8, np.float64),
|
||||
(np.intp, np.float64), ]
|
||||
if hasattr(np, "complex256"):
|
||||
dtype_pairs += [(np.complex256, np.complex128), ]
|
||||
if hasattr(np, "half"):
|
||||
dtype_pairs += [(np.half, np.float32), ]
|
||||
for (dtype, transform_dtype) in dtype_pairs:
|
||||
x = np.arange(N, dtype=dtype)
|
||||
wp = pywt.WaveletPacket(x, wavelet='db1', mode='symmetric')
|
||||
|
||||
# no unnecessary copy made of top-level data
|
||||
assert_(wp.data is x)
|
||||
|
||||
# full decomposition
|
||||
wp.get_level(wp.maxlevel)
|
||||
|
||||
# reconstructed data will have modified dtype
|
||||
r = wp.reconstruct(False)
|
||||
assert_equal(r.dtype, transform_dtype)
|
||||
assert_allclose(r, x.astype(transform_dtype), atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
def test_db3_roundtrip():
|
||||
original = np.arange(512)
|
||||
wp = pywt.WaveletPacket(data=original, wavelet='db3', mode='smooth',
|
||||
maxlevel=3)
|
||||
r = wp.reconstruct()
|
||||
assert_allclose(original, r, atol=1e-12, rtol=1e-12)
|
Loading…
Add table
Add a link
Reference in a new issue