Fixed database typo and removed unnecessary class identifier.
This commit is contained in:
parent
00ad49a143
commit
45fb349a7d
5098 changed files with 952558 additions and 85 deletions
733
venv/Lib/site-packages/pywt/_wavelet_packets.py
Normal file
733
venv/Lib/site-packages/pywt/_wavelet_packets.py
Normal file
|
@ -0,0 +1,733 @@
|
|||
# Copyright (c) 2006-2012 Filip Wasilewski <http://en.ig.ma/>
|
||||
# Copyright (c) 2012-2016 The PyWavelets Developers
|
||||
# <https://github.com/PyWavelets/pywt>
|
||||
# See COPYING for license details.
|
||||
|
||||
"""1D and 2D Wavelet packet transform module."""
|
||||
|
||||
from __future__ import division, print_function, absolute_import
|
||||
|
||||
__all__ = ["BaseNode", "Node", "WaveletPacket", "Node2D", "WaveletPacket2D"]
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ._extensions._pywt import Wavelet, _check_dtype
|
||||
from ._dwt import dwt, idwt, dwt_max_level
|
||||
from ._multidim import dwt2, idwt2
|
||||
|
||||
|
||||
def get_graycode_order(level, x='a', y='d'):
|
||||
graycode_order = [x, y]
|
||||
for i in range(level - 1):
|
||||
graycode_order = [x + path for path in graycode_order] + \
|
||||
[y + path for path in graycode_order[::-1]]
|
||||
return graycode_order
|
||||
|
||||
|
||||
class BaseNode(object):
|
||||
"""
|
||||
BaseNode for wavelet packet 1D and 2D tree nodes.
|
||||
|
||||
The BaseNode is a base class for `Node` and `Node2D`.
|
||||
It should not be used directly unless creating a new transformation
|
||||
type. It is included here to document the common interface of 1D
|
||||
and 2D node and wavelet packet transform classes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
parent :
|
||||
Parent node. If parent is None then the node is considered detached
|
||||
(ie root).
|
||||
data : 1D or 2D array
|
||||
Data associated with the node. 1D or 2D numeric array, depending on the
|
||||
transform type.
|
||||
node_name :
|
||||
A name identifying the coefficients type.
|
||||
See `Node.node_name` and `Node2D.node_name`
|
||||
for information on the accepted subnodes names.
|
||||
"""
|
||||
|
||||
# PART_LEN and PARTS attributes that define path tokens for node[] lookup
|
||||
# must be defined in subclasses.
|
||||
PART_LEN = None
|
||||
PARTS = None
|
||||
|
||||
def __init__(self, parent, data, node_name):
|
||||
self.parent = parent
|
||||
if parent is not None:
|
||||
self.wavelet = parent.wavelet
|
||||
self.mode = parent.mode
|
||||
self.level = parent.level + 1
|
||||
self._maxlevel = parent.maxlevel
|
||||
self.path = parent.path + node_name
|
||||
else:
|
||||
self.wavelet = None
|
||||
self.mode = None
|
||||
self.path = ""
|
||||
self.level = 0
|
||||
|
||||
# data - signal on level 0, coeffs on higher levels
|
||||
self.data = data
|
||||
# Need to retain original data size/shape so we can trim any excess
|
||||
# boundary coefficients from the inverse transform.
|
||||
if self.data is None:
|
||||
self._data_shape = None
|
||||
else:
|
||||
self._data_shape = np.asarray(data).shape
|
||||
|
||||
self._init_subnodes()
|
||||
|
||||
def _init_subnodes(self):
|
||||
for part in self.PARTS:
|
||||
self._set_node(part, None)
|
||||
|
||||
def _create_subnode(self, part, data=None, overwrite=True):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _create_subnode_base(self, node_cls, part, data=None, overwrite=True):
|
||||
self._validate_node_name(part)
|
||||
if not overwrite and self._get_node(part) is not None:
|
||||
return self._get_node(part)
|
||||
node = node_cls(self, data, part)
|
||||
self._set_node(part, node)
|
||||
return node
|
||||
|
||||
def _get_node(self, part):
|
||||
return getattr(self, part)
|
||||
|
||||
def _set_node(self, part, node):
|
||||
setattr(self, part, node)
|
||||
|
||||
def _delete_node(self, part):
|
||||
self._set_node(part, None)
|
||||
|
||||
def _validate_node_name(self, part):
|
||||
if part not in self.PARTS:
|
||||
raise ValueError("Subnode name must be in [%s], not '%s'." %
|
||||
(', '.join("'%s'" % p for p in self.PARTS), part))
|
||||
|
||||
def _evaluate_maxlevel(self, evaluate_from='parent'):
|
||||
"""
|
||||
Try to find the value of maximum decomposition level if it is not
|
||||
specified explicitly.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
evaluate_from : {'parent', 'subnodes'}
|
||||
"""
|
||||
assert evaluate_from in ('parent', 'subnodes')
|
||||
|
||||
if self._maxlevel is not None:
|
||||
return self._maxlevel
|
||||
elif self.data is not None:
|
||||
return self.level + dwt_max_level(
|
||||
min(self.data.shape), self.wavelet)
|
||||
|
||||
if evaluate_from == 'parent':
|
||||
if self.parent is not None:
|
||||
return self.parent._evaluate_maxlevel(evaluate_from)
|
||||
elif evaluate_from == 'subnodes':
|
||||
for node_name in self.PARTS:
|
||||
node = getattr(self, node_name, None)
|
||||
if node is not None:
|
||||
level = node._evaluate_maxlevel(evaluate_from)
|
||||
if level is not None:
|
||||
return level
|
||||
return None
|
||||
|
||||
@property
|
||||
def maxlevel(self):
|
||||
if self._maxlevel is not None:
|
||||
return self._maxlevel
|
||||
|
||||
# Try getting the maxlevel from parents first
|
||||
self._maxlevel = self._evaluate_maxlevel(evaluate_from='parent')
|
||||
|
||||
# If not found, check whether it can be evaluated from subnodes
|
||||
if self._maxlevel is None:
|
||||
self._maxlevel = self._evaluate_maxlevel(evaluate_from='subnodes')
|
||||
return self._maxlevel
|
||||
|
||||
@property
|
||||
def node_name(self):
|
||||
return self.path[-self.PART_LEN:]
|
||||
|
||||
def decompose(self):
|
||||
"""
|
||||
Decompose node data creating DWT coefficients subnodes.
|
||||
|
||||
Performs Discrete Wavelet Transform on the `~BaseNode.data` and
|
||||
returns transform coefficients.
|
||||
|
||||
Note
|
||||
----
|
||||
Descends to subnodes and recursively
|
||||
calls `~BaseNode.reconstruct` on them.
|
||||
|
||||
"""
|
||||
if self.level < self.maxlevel:
|
||||
return self._decompose()
|
||||
else:
|
||||
raise ValueError("Maximum decomposition level reached.")
|
||||
|
||||
def _decompose(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def reconstruct(self, update=False):
|
||||
"""
|
||||
Reconstruct node from subnodes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
update : bool, optional
|
||||
If True, then reconstructed data replaces the current
|
||||
node data (default: False).
|
||||
|
||||
Returns:
|
||||
- original node data if subnodes do not exist
|
||||
- IDWT of subnodes otherwise.
|
||||
"""
|
||||
if not self.has_any_subnode:
|
||||
return self.data
|
||||
return self._reconstruct(update)
|
||||
|
||||
def _reconstruct(self):
|
||||
raise NotImplementedError() # override this in subclasses
|
||||
|
||||
def get_subnode(self, part, decompose=True):
|
||||
"""
|
||||
Returns subnode or None (see `decomposition` flag description).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
part :
|
||||
Subnode name
|
||||
decompose : bool, optional
|
||||
If the param is True and corresponding subnode does not
|
||||
exist, the subnode will be created using coefficients
|
||||
from the DWT decomposition of the current node.
|
||||
(default: True)
|
||||
"""
|
||||
self._validate_node_name(part)
|
||||
subnode = self._get_node(part)
|
||||
if subnode is None and decompose and not self.is_empty:
|
||||
self.decompose()
|
||||
subnode = self._get_node(part)
|
||||
return subnode
|
||||
|
||||
def __getitem__(self, path):
|
||||
"""
|
||||
Find node represented by the given path.
|
||||
|
||||
Similar to `~BaseNode.get_subnode` method with `decompose=True`, but
|
||||
can access nodes on any level in the decomposition tree.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : str
|
||||
String composed of node names. See `Node.node_name` and
|
||||
`Node2D.node_name` for node naming convention.
|
||||
|
||||
Notes
|
||||
-----
|
||||
If node does not exist yet, it will be created by decomposition of its
|
||||
parent node.
|
||||
"""
|
||||
if isinstance(path, str):
|
||||
if (self.maxlevel is not None
|
||||
and len(path) > self.maxlevel * self.PART_LEN):
|
||||
raise IndexError("Path length is out of range.")
|
||||
if path:
|
||||
return self.get_subnode(path[0:self.PART_LEN], True)[
|
||||
path[self.PART_LEN:]]
|
||||
else:
|
||||
return self
|
||||
else:
|
||||
raise TypeError("Invalid path parameter type - expected string but"
|
||||
" got %s." % type(path))
|
||||
|
||||
def __setitem__(self, path, data):
|
||||
"""
|
||||
Set node or node's data in the decomposition tree. Nodes are
|
||||
identified by string `path`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : str
|
||||
String composed of node names.
|
||||
data : array or BaseNode subclass.
|
||||
"""
|
||||
|
||||
if isinstance(path, str):
|
||||
if (
|
||||
self.maxlevel is not None
|
||||
and len(self.path) + len(path) > self.maxlevel * self.PART_LEN
|
||||
):
|
||||
raise IndexError("Path length out of range.")
|
||||
if path:
|
||||
subnode = self.get_subnode(path[0:self.PART_LEN], False)
|
||||
if subnode is None:
|
||||
self._create_subnode(path[0:self.PART_LEN], None)
|
||||
subnode = self.get_subnode(path[0:self.PART_LEN], False)
|
||||
subnode[path[self.PART_LEN:]] = data
|
||||
else:
|
||||
if isinstance(data, BaseNode):
|
||||
self.data = np.asarray(data.data)
|
||||
else:
|
||||
self.data = np.asarray(data)
|
||||
# convert data to nearest supported dtype
|
||||
dtype = _check_dtype(data)
|
||||
if self.data.dtype != dtype:
|
||||
self.data = self.data.astype(dtype)
|
||||
else:
|
||||
raise TypeError("Invalid path parameter type - expected string but"
|
||||
" got %s." % type(path))
|
||||
|
||||
def __delitem__(self, path):
|
||||
"""
|
||||
Remove node from the tree.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : str
|
||||
String composed of node names.
|
||||
"""
|
||||
node = self[path]
|
||||
# don't clear node value and subnodes (node may still exist outside
|
||||
# the tree)
|
||||
# # node._init_subnodes()
|
||||
# # node.data = None
|
||||
parent = node.parent
|
||||
node.parent = None # TODO
|
||||
if parent and node.node_name:
|
||||
parent._delete_node(node.node_name)
|
||||
|
||||
@property
|
||||
def is_empty(self):
|
||||
return self.data is None
|
||||
|
||||
@property
|
||||
def has_any_subnode(self):
|
||||
for part in self.PARTS:
|
||||
if self._get_node(part) is not None: # and not .is_empty
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_leaf_nodes(self, decompose=False):
|
||||
"""
|
||||
Returns leaf nodes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
decompose : bool, optional
|
||||
(default: True)
|
||||
"""
|
||||
result = []
|
||||
|
||||
def collect(node):
|
||||
if node.level == node.maxlevel and not node.is_empty:
|
||||
result.append(node)
|
||||
return False
|
||||
if not decompose and not node.has_any_subnode:
|
||||
result.append(node)
|
||||
return False
|
||||
return True
|
||||
self.walk(collect, decompose=decompose)
|
||||
return result
|
||||
|
||||
def walk(self, func, args=(), kwargs=None, decompose=True):
|
||||
"""
|
||||
Traverses the decomposition tree and calls
|
||||
``func(node, *args, **kwargs)`` on every node. If `func` returns True,
|
||||
descending to subnodes will continue.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : callable
|
||||
Callable accepting `BaseNode` as the first param and
|
||||
optional positional and keyword arguments
|
||||
args :
|
||||
func params
|
||||
kwargs :
|
||||
func keyword params
|
||||
decompose : bool, optional
|
||||
If True (default), the method will also try to decompose the tree
|
||||
up to the `maximum level <BaseNode.maxlevel>`.
|
||||
"""
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
if func(self, *args, **kwargs) and self.level < self.maxlevel:
|
||||
for part in self.PARTS:
|
||||
subnode = self.get_subnode(part, decompose)
|
||||
if subnode is not None:
|
||||
subnode.walk(func, args, kwargs, decompose)
|
||||
|
||||
def walk_depth(self, func, args=(), kwargs=None, decompose=True):
|
||||
"""
|
||||
Walk tree and call func on every node starting from the bottom-most
|
||||
nodes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : callable
|
||||
Callable accepting :class:`BaseNode` as the first param and
|
||||
optional positional and keyword arguments
|
||||
args :
|
||||
func params
|
||||
kwargs :
|
||||
func keyword params
|
||||
decompose : bool, optional
|
||||
(default: False)
|
||||
"""
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
if self.level < self.maxlevel:
|
||||
for part in self.PARTS:
|
||||
subnode = self.get_subnode(part, decompose)
|
||||
if subnode is not None:
|
||||
subnode.walk_depth(func, args, kwargs, decompose)
|
||||
func(self, *args, **kwargs)
|
||||
|
||||
def __str__(self):
|
||||
return self.path + ": " + str(self.data)
|
||||
|
||||
|
||||
class Node(BaseNode):
|
||||
"""
|
||||
WaveletPacket tree node.
|
||||
|
||||
Subnodes are called `a` and `d`, just like approximation
|
||||
and detail coefficients in the Discrete Wavelet Transform.
|
||||
"""
|
||||
|
||||
A = 'a'
|
||||
D = 'd'
|
||||
PARTS = A, D
|
||||
PART_LEN = 1
|
||||
|
||||
def _create_subnode(self, part, data=None, overwrite=True):
|
||||
return self._create_subnode_base(node_cls=Node, part=part, data=data,
|
||||
overwrite=overwrite)
|
||||
|
||||
def _decompose(self):
|
||||
"""
|
||||
|
||||
See also
|
||||
--------
|
||||
dwt : for 1D Discrete Wavelet Transform output coefficients.
|
||||
"""
|
||||
if self.is_empty:
|
||||
data_a, data_d = None, None
|
||||
if self._get_node(self.A) is None:
|
||||
self._create_subnode(self.A, data_a)
|
||||
if self._get_node(self.D) is None:
|
||||
self._create_subnode(self.D, data_d)
|
||||
else:
|
||||
data_a, data_d = dwt(self.data, self.wavelet, self.mode)
|
||||
self._create_subnode(self.A, data_a)
|
||||
self._create_subnode(self.D, data_d)
|
||||
return self._get_node(self.A), self._get_node(self.D)
|
||||
|
||||
def _reconstruct(self, update):
|
||||
data_a, data_d = None, None
|
||||
node_a, node_d = self._get_node(self.A), self._get_node(self.D)
|
||||
|
||||
if node_a is not None:
|
||||
data_a = node_a.reconstruct() # TODO: (update) ???
|
||||
if node_d is not None:
|
||||
data_d = node_d.reconstruct() # TODO: (update) ???
|
||||
|
||||
if data_a is None and data_d is None:
|
||||
raise ValueError("Node is a leaf node and cannot be reconstructed"
|
||||
" from subnodes.")
|
||||
else:
|
||||
rec = idwt(data_a, data_d, self.wavelet, self.mode)
|
||||
if self._data_shape is not None and (
|
||||
rec.shape != self._data_shape):
|
||||
rec = rec[tuple([slice(sz) for sz in self._data_shape])]
|
||||
if update:
|
||||
self.data = rec
|
||||
return rec
|
||||
|
||||
|
||||
class Node2D(BaseNode):
|
||||
"""
|
||||
WaveletPacket tree node.
|
||||
|
||||
Subnodes are called 'a' (LL), 'h' (HL), 'v' (LH) and 'd' (HH), like
|
||||
approximation and detail coefficients in the 2D Discrete Wavelet Transform
|
||||
"""
|
||||
|
||||
LL = 'a'
|
||||
HL = 'h'
|
||||
LH = 'v'
|
||||
HH = 'd'
|
||||
|
||||
PARTS = LL, HL, LH, HH
|
||||
PART_LEN = 1
|
||||
|
||||
def _create_subnode(self, part, data=None, overwrite=True):
|
||||
return self._create_subnode_base(node_cls=Node2D, part=part, data=data,
|
||||
overwrite=overwrite)
|
||||
|
||||
def _decompose(self):
|
||||
"""
|
||||
See also
|
||||
--------
|
||||
dwt2 : for 2D Discrete Wavelet Transform output coefficients.
|
||||
"""
|
||||
if self.is_empty:
|
||||
data_ll, data_lh, data_hl, data_hh = None, None, None, None
|
||||
else:
|
||||
data_ll, (data_hl, data_lh, data_hh) =\
|
||||
dwt2(self.data, self.wavelet, self.mode)
|
||||
self._create_subnode(self.LL, data_ll)
|
||||
self._create_subnode(self.LH, data_lh)
|
||||
self._create_subnode(self.HL, data_hl)
|
||||
self._create_subnode(self.HH, data_hh)
|
||||
return (self._get_node(self.LL), self._get_node(self.HL),
|
||||
self._get_node(self.LH), self._get_node(self.HH))
|
||||
|
||||
def _reconstruct(self, update):
|
||||
data_ll, data_lh, data_hl, data_hh = None, None, None, None
|
||||
|
||||
node_ll, node_lh, node_hl, node_hh =\
|
||||
self._get_node(self.LL), self._get_node(self.LH),\
|
||||
self._get_node(self.HL), self._get_node(self.HH)
|
||||
|
||||
if node_ll is not None:
|
||||
data_ll = node_ll.reconstruct()
|
||||
if node_lh is not None:
|
||||
data_lh = node_lh.reconstruct()
|
||||
if node_hl is not None:
|
||||
data_hl = node_hl.reconstruct()
|
||||
if node_hh is not None:
|
||||
data_hh = node_hh.reconstruct()
|
||||
|
||||
if (data_ll is None and data_lh is None
|
||||
and data_hl is None and data_hh is None):
|
||||
raise ValueError(
|
||||
"Tree is missing data - all subnodes of `%s` node "
|
||||
"are None. Cannot reconstruct node." % self.path
|
||||
)
|
||||
else:
|
||||
coeffs = data_ll, (data_hl, data_lh, data_hh)
|
||||
rec = idwt2(coeffs, self.wavelet, self.mode)
|
||||
if self._data_shape is not None and (
|
||||
rec.shape != self._data_shape):
|
||||
rec = rec[tuple([slice(sz) for sz in self._data_shape])]
|
||||
if update:
|
||||
self.data = rec
|
||||
return rec
|
||||
|
||||
def expand_2d_path(self, path):
|
||||
expanded_paths = {
|
||||
self.HH: 'hh',
|
||||
self.HL: 'hl',
|
||||
self.LH: 'lh',
|
||||
self.LL: 'll'
|
||||
}
|
||||
return (''.join([expanded_paths[p][0] for p in path]),
|
||||
''.join([expanded_paths[p][1] for p in path]))
|
||||
|
||||
|
||||
class WaveletPacket(Node):
|
||||
"""
|
||||
Data structure representing Wavelet Packet decomposition of signal.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : 1D ndarray
|
||||
Original data (signal)
|
||||
wavelet : Wavelet object or name string
|
||||
Wavelet used in DWT decomposition and reconstruction
|
||||
mode : str, optional
|
||||
Signal extension mode for the `dwt` and `idwt` decomposition and
|
||||
reconstruction functions.
|
||||
maxlevel : int, optional
|
||||
Maximum level of decomposition.
|
||||
If None, it will be calculated based on the `wavelet` and `data`
|
||||
length using `pywt.dwt_max_level`.
|
||||
"""
|
||||
def __init__(self, data, wavelet, mode='symmetric', maxlevel=None):
|
||||
super(WaveletPacket, self).__init__(None, data, "")
|
||||
|
||||
if not isinstance(wavelet, Wavelet):
|
||||
wavelet = Wavelet(wavelet)
|
||||
self.wavelet = wavelet
|
||||
self.mode = mode
|
||||
|
||||
if data is not None:
|
||||
data = np.asarray(data)
|
||||
assert data.ndim == 1
|
||||
self.data_size = data.shape[0]
|
||||
if maxlevel is None:
|
||||
maxlevel = dwt_max_level(self.data_size, self.wavelet)
|
||||
else:
|
||||
self.data_size = None
|
||||
|
||||
self._maxlevel = maxlevel
|
||||
|
||||
def reconstruct(self, update=True):
|
||||
"""
|
||||
Reconstruct data value using coefficients from subnodes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
update : bool, optional
|
||||
If True (default), then data values will be replaced by
|
||||
reconstruction values, also in subnodes.
|
||||
"""
|
||||
if self.has_any_subnode:
|
||||
data = super(WaveletPacket, self).reconstruct(update)
|
||||
if update:
|
||||
self.data = data
|
||||
return data
|
||||
return self.data # return original data
|
||||
|
||||
def get_level(self, level, order="natural", decompose=True):
|
||||
"""
|
||||
Returns all nodes on the specified level.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
level : int
|
||||
Specifies decomposition `level` from which the nodes will be
|
||||
collected.
|
||||
order : {'natural', 'freq'}, optional
|
||||
- "natural" - left to right in tree (default)
|
||||
- "freq" - band ordered
|
||||
decompose : bool, optional
|
||||
If set then the method will try to decompose the data up
|
||||
to the specified `level` (default: True).
|
||||
|
||||
Notes
|
||||
-----
|
||||
If nodes at the given level are missing (i.e. the tree is partially
|
||||
decomposed) and the `decompose` is set to False, only existing nodes
|
||||
will be returned.
|
||||
"""
|
||||
assert order in ["natural", "freq"]
|
||||
if level > self.maxlevel:
|
||||
raise ValueError("The level cannot be greater than the maximum"
|
||||
" decomposition level value (%d)" % self.maxlevel)
|
||||
|
||||
result = []
|
||||
|
||||
def collect(node):
|
||||
if node.level == level:
|
||||
result.append(node)
|
||||
return False
|
||||
return True
|
||||
|
||||
self.walk(collect, decompose=decompose)
|
||||
if order == "natural":
|
||||
return result
|
||||
elif order == "freq":
|
||||
result = dict((node.path, node) for node in result)
|
||||
graycode_order = get_graycode_order(level)
|
||||
return [result[path] for path in graycode_order if path in result]
|
||||
else:
|
||||
raise ValueError("Invalid order name - %s." % order)
|
||||
|
||||
|
||||
class WaveletPacket2D(Node2D):
|
||||
"""
|
||||
Data structure representing 2D Wavelet Packet decomposition of signal.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : 2D ndarray
|
||||
Data associated with the node.
|
||||
wavelet : Wavelet object or name string
|
||||
Wavelet used in DWT decomposition and reconstruction
|
||||
mode : str, optional
|
||||
Signal extension mode for the `dwt` and `idwt` decomposition and
|
||||
reconstruction functions.
|
||||
maxlevel : int
|
||||
Maximum level of decomposition.
|
||||
If None, it will be calculated based on the `wavelet` and `data`
|
||||
length using `pywt.dwt_max_level`.
|
||||
"""
|
||||
def __init__(self, data, wavelet, mode='smooth', maxlevel=None):
|
||||
super(WaveletPacket2D, self).__init__(None, data, "")
|
||||
|
||||
if not isinstance(wavelet, Wavelet):
|
||||
wavelet = Wavelet(wavelet)
|
||||
self.wavelet = wavelet
|
||||
self.mode = mode
|
||||
|
||||
if data is not None:
|
||||
data = np.asarray(data)
|
||||
assert data.ndim == 2
|
||||
self.data_size = data.shape
|
||||
if maxlevel is None:
|
||||
maxlevel = dwt_max_level(min(self.data_size), self.wavelet)
|
||||
else:
|
||||
self.data_size = None
|
||||
self._maxlevel = maxlevel
|
||||
|
||||
def reconstruct(self, update=True):
|
||||
"""
|
||||
Reconstruct data using coefficients from subnodes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
update : bool, optional
|
||||
If True (default) then the coefficients of the current node
|
||||
and its subnodes will be replaced with values from reconstruction.
|
||||
"""
|
||||
if self.has_any_subnode:
|
||||
data = super(WaveletPacket2D, self).reconstruct(update)
|
||||
if update:
|
||||
self.data = data
|
||||
return data
|
||||
return self.data # return original data
|
||||
|
||||
def get_level(self, level, order="natural", decompose=True):
|
||||
"""
|
||||
Returns all nodes from specified level.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
level : int
|
||||
Decomposition `level` from which the nodes will be
|
||||
collected.
|
||||
order : {'natural', 'freq'}, optional
|
||||
If `natural` (default) a flat list is returned.
|
||||
If `freq`, a 2d structure with rows and cols
|
||||
sorted by corresponding dimension frequency of 2d
|
||||
coefficient array (adapted from 1d case).
|
||||
decompose : bool, optional
|
||||
If set then the method will try to decompose the data up
|
||||
to the specified `level` (default: True).
|
||||
"""
|
||||
assert order in ["natural", "freq"]
|
||||
if level > self.maxlevel:
|
||||
raise ValueError("The level cannot be greater than the maximum"
|
||||
" decomposition level value (%d)" % self.maxlevel)
|
||||
|
||||
result = []
|
||||
|
||||
def collect(node):
|
||||
if node.level == level:
|
||||
result.append(node)
|
||||
return False
|
||||
return True
|
||||
|
||||
self.walk(collect, decompose=decompose)
|
||||
|
||||
if order == "freq":
|
||||
nodes = {}
|
||||
for (row_path, col_path), node in [
|
||||
(self.expand_2d_path(node.path), node) for node in result
|
||||
]:
|
||||
nodes.setdefault(row_path, {})[col_path] = node
|
||||
graycode_order = get_graycode_order(level, x='l', y='h')
|
||||
nodes = [nodes[path] for path in graycode_order if path in nodes]
|
||||
result = []
|
||||
for row in nodes:
|
||||
result.append(
|
||||
[row[path] for path in graycode_order if path in row]
|
||||
)
|
||||
return result
|
Loading…
Add table
Add a link
Reference in a new issue