# Copyright (c) 2006-2012 Filip Wasilewski
# Copyright (c) 2012-2016 The PyWavelets Developers
#
# See COPYING for license details.
"""1D and 2D Wavelet packet transform module."""
from __future__ import division, print_function, absolute_import
__all__ = ["BaseNode", "Node", "WaveletPacket", "Node2D", "WaveletPacket2D"]
import numpy as np
from ._extensions._pywt import Wavelet, _check_dtype
from ._dwt import dwt, idwt, dwt_max_level
from ._multidim import dwt2, idwt2
def get_graycode_order(level, x='a', y='d'):
    """
    Return node paths of a given level in Gray code order.

    Parameters
    ----------
    level : int
        Decomposition level, i.e. the number of tokens in each path.
    x : str, optional
        Token used for the first branch (default: 'a').
    y : str, optional
        Token used for the second branch (default: 'd').

    Returns
    -------
    list of str
        ``2 ** level`` paths built from `x` and `y`. Each following level
        is built by prefixing the previous ordering with `x` and the
        reversed previous ordering with `y`.

    Notes
    -----
    Level 0 yields a single empty path, which is the path of a root node.
    Negative levels behave like level 1 because the refinement loop does
    not execute.
    """
    if level == 0:
        # Only the root lives on level 0 and its path is the empty string.
        return ['']
    graycode_order = [x, y]
    for _ in range(level - 1):
        graycode_order = ([x + path for path in graycode_order] +
                          [y + path for path in graycode_order[::-1]])
    return graycode_order
class BaseNode(object):
    """
    BaseNode for wavelet packet 1D and 2D tree nodes.

    The BaseNode is a base class for `Node` and `Node2D`.
    It should not be used directly unless creating a new transformation
    type. It is included here to document the common interface of 1D
    and 2D node and wavelet packet transform classes.

    Parameters
    ----------
    parent :
        Parent node. If parent is None then the node is considered detached
        (ie root).
    data : 1D or 2D array
        Data associated with the node. 1D or 2D numeric array, depending on
        the transform type.
    node_name :
        A name identifying the coefficients type.
        See `Node.node_name` and `Node2D.node_name`
        for information on the accepted subnodes names.
    """

    # PART_LEN and PARTS attributes that define path tokens for node[] lookup
    # must be defined in subclasses.
    PART_LEN = None
    PARTS = None

    def __init__(self, parent, data, node_name):
        self.parent = parent
        if parent is not None:
            self.wavelet = parent.wavelet
            self.mode = parent.mode
            self.level = parent.level + 1
            self._maxlevel = parent.maxlevel
            self.path = parent.path + node_name
        else:
            self.wavelet = None
            self.mode = None
            self.path = ""
            self.level = 0
            # A root has no parent to inherit the limit from. Mark it as
            # unknown so that the `maxlevel` property can evaluate it lazily
            # instead of failing on a missing attribute.
            self._maxlevel = None

        # data - signal on level 0, coeffs on higher levels
        self.data = data
        # Need to retain original data size/shape so we can trim any excess
        # boundary coefficients from the inverse transform.
        if self.data is None:
            self._data_shape = None
        else:
            self._data_shape = np.asarray(data).shape

        self._init_subnodes()

    def _init_subnodes(self):
        """Reset every subnode slot listed in `PARTS` to None."""
        for part in self.PARTS:
            self._set_node(part, None)

    def _create_subnode(self, part, data=None, overwrite=True):
        """Create the subnode `part`; must be implemented by subclasses."""
        raise NotImplementedError()

    def _create_subnode_base(self, node_cls, part, data=None, overwrite=True):
        """
        Create a `node_cls` subnode named `part` and attach it to this node.

        If `overwrite` is False and the subnode already exists, the existing
        subnode is returned unchanged.
        """
        self._validate_node_name(part)
        if not overwrite and self._get_node(part) is not None:
            return self._get_node(part)
        node = node_cls(self, data, part)
        self._set_node(part, node)
        return node

    def _get_node(self, part):
        # Subnodes are stored as instance attributes named after the part.
        return getattr(self, part)

    def _set_node(self, part, node):
        setattr(self, part, node)

    def _delete_node(self, part):
        self._set_node(part, None)

    def _validate_node_name(self, part):
        """Raise ValueError if `part` is not one of `PARTS`."""
        if part not in self.PARTS:
            raise ValueError("Subnode name must be in [%s], not '%s'." %
                             (', '.join("'%s'" % p for p in self.PARTS), part))

    def _evaluate_maxlevel(self, evaluate_from='parent'):
        """
        Try to find the value of maximum decomposition level if it is not
        specified explicitly.

        Parameters
        ----------
        evaluate_from : {'parent', 'subnodes'}
            Direction in which the tree is searched when this node has
            neither an explicit maximum level nor data.

        Returns
        -------
        int or None
            The maximum level, or None if it cannot be determined.
        """
        assert evaluate_from in ('parent', 'subnodes')

        if self._maxlevel is not None:
            return self._maxlevel
        elif self.data is not None:
            # np.shape also handles data stored as a plain sequence, which
            # has no `shape` attribute of its own.
            return self.level + dwt_max_level(
                min(np.shape(self.data)), self.wavelet)

        if evaluate_from == 'parent':
            if self.parent is not None:
                return self.parent._evaluate_maxlevel(evaluate_from)
        elif evaluate_from == 'subnodes':
            for node_name in self.PARTS:
                node = getattr(self, node_name, None)
                if node is not None:
                    level = node._evaluate_maxlevel(evaluate_from)
                    if level is not None:
                        return level
        return None

    @property
    def maxlevel(self):
        """Maximum decomposition level, evaluated lazily and then cached."""
        if self._maxlevel is not None:
            return self._maxlevel

        # Try getting the maxlevel from parents first
        self._maxlevel = self._evaluate_maxlevel(evaluate_from='parent')

        # If not found, check whether it can be evaluated from subnodes
        if self._maxlevel is None:
            self._maxlevel = self._evaluate_maxlevel(evaluate_from='subnodes')
        return self._maxlevel

    @property
    def node_name(self):
        """Last path token of this node (empty string for a root)."""
        return self.path[-self.PART_LEN:]

    def decompose(self):
        """
        Decompose node data creating DWT coefficients subnodes.

        Performs Discrete Wavelet Transform on the `~BaseNode.data` and
        returns transform coefficients.

        Returns
        -------
        tuple
            The subnodes produced by the subclass `_decompose`
            implementation.

        Raises
        ------
        ValueError
            If the node is already on the maximum decomposition level.
        """
        if self.level < self.maxlevel:
            return self._decompose()
        else:
            raise ValueError("Maximum decomposition level reached.")

    def _decompose(self):
        """Perform the actual decomposition; implemented by subclasses."""
        raise NotImplementedError()

    def reconstruct(self, update=False):
        """
        Reconstruct node from subnodes.

        Parameters
        ----------
        update : bool, optional
            If True, then reconstructed data replaces the current
            node data (default: False).

        Returns
        -------
        data
            - original node data if subnodes do not exist
            - IDWT of subnodes otherwise.
        """
        if not self.has_any_subnode:
            return self.data
        return self._reconstruct(update)

    def _reconstruct(self, update=False):
        """
        Perform the actual reconstruction; implemented by subclasses.

        The signature matches the call made by `reconstruct`, so an
        unimplemented subclass fails with NotImplementedError.
        """
        raise NotImplementedError()  # override this in subclasses

    def get_subnode(self, part, decompose=True):
        """
        Returns subnode or None (see `decompose` flag description).

        Parameters
        ----------
        part :
            Subnode name
        decompose : bool, optional
            If the param is True and corresponding subnode does not
            exist, the subnode will be created using coefficients
            from the DWT decomposition of the current node.
            (default: True)

        Notes
        -----
        A node without data is never decomposed here, so None is returned
        for its missing subnodes even when `decompose` is True.
        """
        self._validate_node_name(part)
        subnode = self._get_node(part)
        if subnode is None and decompose and not self.is_empty:
            self.decompose()
            subnode = self._get_node(part)
        return subnode

    def __getitem__(self, path):
        """
        Find node represented by the given path.

        Similar to `~BaseNode.get_subnode` method with `decompose=True`, but
        can access nodes on any level in the decomposition tree.

        Parameters
        ----------
        path : str
            String composed of node names. See `Node.node_name` and
            `Node2D.node_name` for node naming convention.

        Raises
        ------
        IndexError
            If `path` is longer than the maximum level allows.
        TypeError
            If `path` is not a string.

        Notes
        -----
        If node does not exist yet, it will be created by decomposition of its
        parent node.
        """
        if isinstance(path, str):
            if (self.maxlevel is not None
                    and len(path) > self.maxlevel * self.PART_LEN):
                raise IndexError("Path length is out of range.")
            if path:
                # Resolve the first token here and let the subnode resolve
                # the remainder of the path.
                return self.get_subnode(path[0:self.PART_LEN], True)[
                    path[self.PART_LEN:]]
            else:
                return self
        else:
            raise TypeError("Invalid path parameter type - expected string but"
                            " got %s." % type(path))

    def __setitem__(self, path, data):
        """
        Set node or node's data in the decomposition tree. Nodes are
        identified by string `path`.

        Parameters
        ----------
        path : str
            String composed of node names.
        data : array or BaseNode subclass.
            New data. When a node is given, its `data` is used.

        Raises
        ------
        IndexError
            If the addressed node would lie below the maximum level.
        TypeError
            If `path` is not a string.
        """
        if isinstance(path, str):
            if (
                self.maxlevel is not None
                and len(self.path) + len(path) > self.maxlevel * self.PART_LEN
            ):
                raise IndexError("Path length out of range.")
            if path:
                subnode = self.get_subnode(path[0:self.PART_LEN], False)
                if subnode is None:
                    # Create an empty intermediate node without decomposing.
                    self._create_subnode(path[0:self.PART_LEN], None)
                    subnode = self.get_subnode(path[0:self.PART_LEN], False)
                subnode[path[self.PART_LEN:]] = data
            else:
                if isinstance(data, BaseNode):
                    self.data = np.asarray(data.data)
                else:
                    self.data = np.asarray(data)
                # convert data to nearest supported dtype
                # The dtype is taken from the stored array, not from the raw
                # argument: the argument may be a node or a plain sequence,
                # neither of which exposes the dtype of the values.
                # NOTE(review): assumes `_check_dtype` maps an array's dtype
                # to a supported one - confirm against its implementation.
                dtype = _check_dtype(self.data)
                if self.data.dtype != dtype:
                    self.data = self.data.astype(dtype)
        else:
            raise TypeError("Invalid path parameter type - expected string but"
                            " got %s." % type(path))

    def __delitem__(self, path):
        """
        Remove node from the tree.

        Parameters
        ----------
        path : str
            String composed of node names.
        """
        node = self[path]
        # don't clear node value and subnodes (node may still exist outside
        # the tree)
        parent = node.parent
        node.parent = None  # TODO
        if parent is not None and node.node_name:
            parent._delete_node(node.node_name)

    @property
    def is_empty(self):
        """True if the node holds no data."""
        return self.data is None

    @property
    def has_any_subnode(self):
        """True if at least one subnode slot is populated."""
        for part in self.PARTS:
            if self._get_node(part) is not None:  # and not .is_empty
                return True
        return False

    def get_leaf_nodes(self, decompose=False):
        """
        Returns leaf nodes.

        Parameters
        ----------
        decompose : bool, optional
            If True, the tree is decomposed while it is traversed
            (default: False).
        """
        result = []

        def collect(node):
            # Non-empty nodes on the deepest level are always leaves.
            if node.level == node.maxlevel and not node.is_empty:
                result.append(node)
                return False
            # Without decomposition a node with no subnodes is a leaf too.
            if not decompose and not node.has_any_subnode:
                result.append(node)
                return False
            return True
        self.walk(collect, decompose=decompose)
        return result

    def walk(self, func, args=(), kwargs=None, decompose=True):
        """
        Traverses the decomposition tree and calls
        ``func(node, *args, **kwargs)`` on every node. If `func` returns True,
        descending to subnodes will continue.

        Parameters
        ----------
        func : callable
            Callable accepting `BaseNode` as the first param and
            optional positional and keyword arguments
        args :
            func params
        kwargs :
            func keyword params
        decompose : bool, optional
            If True (default), the method will also try to decompose the tree
            up to the `maximum level`.
        """
        if kwargs is None:
            kwargs = {}
        if func(self, *args, **kwargs) and self.level < self.maxlevel:
            for part in self.PARTS:
                subnode = self.get_subnode(part, decompose)
                if subnode is not None:
                    subnode.walk(func, args, kwargs, decompose)

    def walk_depth(self, func, args=(), kwargs=None, decompose=True):
        """
        Walk tree and call func on every node starting from the bottom-most
        nodes.

        Parameters
        ----------
        func : callable
            Callable accepting :class:`BaseNode` as the first param and
            optional positional and keyword arguments
        args :
            func params
        kwargs :
            func keyword params
        decompose : bool, optional
            If True (default), missing subnodes are created by decomposition
            while descending.
        """
        if kwargs is None:
            kwargs = {}
        if self.level < self.maxlevel:
            for part in self.PARTS:
                subnode = self.get_subnode(part, decompose)
                if subnode is not None:
                    subnode.walk_depth(func, args, kwargs, decompose)
        # Subnodes first, then the node itself.
        func(self, *args, **kwargs)

    def __str__(self):
        return self.path + ": " + str(self.data)
class Node(BaseNode):
    """
    1D wavelet packet tree node.

    The two subnodes are named `a` and `d`, following the approximation
    and detail coefficients of the Discrete Wavelet Transform.
    """

    A = 'a'
    D = 'd'
    PARTS = A, D
    PART_LEN = 1

    def _create_subnode(self, part, data=None, overwrite=True):
        """Attach a `Node` subnode named `part` holding `data`."""
        return self._create_subnode_base(node_cls=Node, part=part, data=data,
                                         overwrite=overwrite)

    def _decompose(self):
        """
        Create the `a` and `d` subnodes and return them as a pair.

        An empty node only gets empty placeholders for the subnodes that are
        still missing; otherwise both subnodes are created from the DWT of
        the node data.

        See also
        --------
        dwt : for 1D Discrete Wavelet Transform output coefficients.
        """
        if not self.is_empty:
            approx, detail = dwt(self.data, self.wavelet, self.mode)
            self._create_subnode(self.A, approx)
            self._create_subnode(self.D, detail)
        else:
            for part in (self.A, self.D):
                if self._get_node(part) is None:
                    self._create_subnode(part, None)
        return self._get_node(self.A), self._get_node(self.D)

    def _reconstruct(self, update):
        """
        Rebuild the node data from its subnodes with the inverse DWT.

        The result is trimmed to the shape the node data had originally and
        stored in `data` when `update` is true.
        """
        node_a = self._get_node(self.A)
        node_d = self._get_node(self.D)

        # TODO: `update` is not propagated to the subnodes.
        data_a = None if node_a is None else node_a.reconstruct()
        data_d = None if node_d is None else node_d.reconstruct()

        if data_a is None and data_d is None:
            raise ValueError("Node is a leaf node and cannot be reconstructed"
                             " from subnodes.")

        rec = idwt(data_a, data_d, self.wavelet, self.mode)
        shape = self._data_shape
        if shape is not None and rec.shape != shape:
            # drop the excess boundary coefficients
            rec = rec[tuple([slice(sz) for sz in shape])]
        if update:
            self.data = rec
        return rec
class Node2D(BaseNode):
    """
    2D wavelet packet tree node.

    Subnodes are called 'a' (LL), 'h' (HL), 'v' (LH) and 'd' (HH), like
    approximation and detail coefficients in the 2D Discrete Wavelet Transform
    """

    LL = 'a'
    HL = 'h'
    LH = 'v'
    HH = 'd'

    PARTS = LL, HL, LH, HH
    PART_LEN = 1

    def _create_subnode(self, part, data=None, overwrite=True):
        """Attach a `Node2D` subnode named `part` holding `data`."""
        return self._create_subnode_base(node_cls=Node2D, part=part, data=data,
                                         overwrite=overwrite)

    def _decompose(self):
        """
        Create the four subnodes and return them as (LL, HL, LH, HH).

        For an empty node only the missing subnodes are created, as empty
        placeholders, so subnodes assigned earlier are kept. This matches
        the behaviour of `Node._decompose`. Otherwise all four subnodes are
        created from the 2D DWT of the node data.

        See also
        --------
        dwt2 : for 2D Discrete Wavelet Transform output coefficients.
        """
        if self.is_empty:
            # There is nothing to transform; do not clobber existing subnodes.
            for part in (self.LL, self.LH, self.HL, self.HH):
                if self._get_node(part) is None:
                    self._create_subnode(part, None)
        else:
            data_ll, (data_hl, data_lh, data_hh) =\
                dwt2(self.data, self.wavelet, self.mode)
            self._create_subnode(self.LL, data_ll)
            self._create_subnode(self.LH, data_lh)
            self._create_subnode(self.HL, data_hl)
            self._create_subnode(self.HH, data_hh)
        return (self._get_node(self.LL), self._get_node(self.HL),
                self._get_node(self.LH), self._get_node(self.HH))

    def _reconstruct(self, update):
        """
        Rebuild the node data from its subnodes with the inverse 2D DWT.

        The result is trimmed to the shape the node data had originally and
        stored in `data` when `update` is true.

        Raises
        ------
        ValueError
            If none of the subnodes provides data.
        """
        data_ll, data_lh, data_hl, data_hh = None, None, None, None

        node_ll, node_lh, node_hl, node_hh =\
            self._get_node(self.LL), self._get_node(self.LH),\
            self._get_node(self.HL), self._get_node(self.HH)

        # Subnodes are reconstructed with their default `update` value.
        if node_ll is not None:
            data_ll = node_ll.reconstruct()
        if node_lh is not None:
            data_lh = node_lh.reconstruct()
        if node_hl is not None:
            data_hl = node_hl.reconstruct()
        if node_hh is not None:
            data_hh = node_hh.reconstruct()

        if (data_ll is None and data_lh is None
                and data_hl is None and data_hh is None):
            raise ValueError(
                "Tree is missing data - all subnodes of `%s` node "
                "are None. Cannot reconstruct node." % self.path
            )
        else:
            coeffs = data_ll, (data_hl, data_lh, data_hh)
            rec = idwt2(coeffs, self.wavelet, self.mode)
            if self._data_shape is not None and (
                    rec.shape != self._data_shape):
                # drop the excess boundary coefficients
                rec = rec[tuple([slice(sz) for sz in self._data_shape])]
            if update:
                self.data = rec
            return rec

    def expand_2d_path(self, path):
        """
        Split a 2D node path into a pair of per-dimension paths.

        Every path token is mapped to two letters ('a' -> 'll', 'h' -> 'hl',
        'v' -> 'lh', 'd' -> 'hh'). The first letters form the first returned
        string and the second letters form the second one, e.g. 'ah' expands
        to ``('lh', 'll')``.

        Parameters
        ----------
        path : str
            String composed of 2D node names.

        Returns
        -------
        tuple of str
            The two expanded paths. `WaveletPacket2D.get_level` uses them as
            row and column keys respectively.

        Raises
        ------
        KeyError
            If `path` contains a character that is not a 2D node name.
        """
        expanded_paths = {
            self.HH: 'hh',
            self.HL: 'hl',
            self.LH: 'lh',
            self.LL: 'll'
        }
        return (''.join([expanded_paths[p][0] for p in path]),
                ''.join([expanded_paths[p][1] for p in path]))
class WaveletPacket(Node):
    """
    Data structure representing Wavelet Packet decomposition of signal.

    Parameters
    ----------
    data : 1D ndarray
        Original data (signal)
    wavelet : Wavelet object or name string
        Wavelet used in DWT decomposition and reconstruction
    mode : str, optional
        Signal extension mode for the `dwt` and `idwt` decomposition and
        reconstruction functions.
    maxlevel : int, optional
        Maximum level of decomposition.
        If None, it will be calculated based on the `wavelet` and `data`
        length using `pywt.dwt_max_level`.

    Raises
    ------
    ValueError
        If `data` is given and is not one-dimensional.
    """

    def __init__(self, data, wavelet, mode='symmetric', maxlevel=None):
        super(WaveletPacket, self).__init__(None, data, "")

        if not isinstance(wavelet, Wavelet):
            wavelet = Wavelet(wavelet)
        self.wavelet = wavelet
        self.mode = mode

        if data is not None:
            data = np.asarray(data)
            # Explicit check instead of `assert`, which is stripped when
            # Python runs with -O.
            if data.ndim != 1:
                raise ValueError("Expected 1D input data.")
            self.data_size = data.shape[0]
            if maxlevel is None:
                maxlevel = dwt_max_level(self.data_size, self.wavelet)
        else:
            self.data_size = None

        self._maxlevel = maxlevel

    def reconstruct(self, update=True):
        """
        Reconstruct data value using coefficients from subnodes.

        Parameters
        ----------
        update : bool, optional
            If True (default), then the data of this node is replaced by
            the reconstructed values.

        Returns
        -------
        data
            Reconstructed data, or the original data if the node has no
            subnodes.
        """
        if self.has_any_subnode:
            data = super(WaveletPacket, self).reconstruct(update)
            if update:
                self.data = data
            return data
        return self.data  # return original data

    def get_level(self, level, order="natural", decompose=True):
        """
        Returns all nodes on the specified level.

        Parameters
        ----------
        level : int
            Specifies decomposition `level` from which the nodes will be
            collected.
        order : {'natural', 'freq'}, optional
            - "natural" - left to right in tree (default)
            - "freq" - band ordered
        decompose : bool, optional
            If set then the method will try to decompose the data up
            to the specified `level` (default: True).

        Raises
        ------
        ValueError
            If `order` is not a valid order name or `level` is greater than
            the maximum decomposition level.

        Notes
        -----
        If nodes at the given level are missing (i.e. the tree is partially
        decomposed) and the `decompose` is set to False, only existing nodes
        will be returned.
        """
        # Validate up front, before the tree is walked (and possibly
        # decomposed) for an order that cannot be served.
        if order not in ("natural", "freq"):
            raise ValueError("Invalid order name - %s." % order)
        if level > self.maxlevel:
            raise ValueError("The level cannot be greater than the maximum"
                             " decomposition level value (%d)" % self.maxlevel)

        result = []

        def collect(node):
            # Stop descending once the requested level is reached.
            if node.level == level:
                result.append(node)
                return False
            return True

        self.walk(collect, decompose=decompose)
        if order == "natural":
            return result

        # order == "freq": reorder the nodes by their Gray code path order.
        nodes_by_path = dict((node.path, node) for node in result)
        graycode_order = get_graycode_order(level)
        return [nodes_by_path[path] for path in graycode_order
                if path in nodes_by_path]
class WaveletPacket2D(Node2D):
    """
    Data structure representing 2D Wavelet Packet decomposition of signal.

    Parameters
    ----------
    data : 2D ndarray
        Data associated with the node.
    wavelet : Wavelet object or name string
        Wavelet used in DWT decomposition and reconstruction
    mode : str, optional
        Signal extension mode for the `dwt` and `idwt` decomposition and
        reconstruction functions.
    maxlevel : int
        Maximum level of decomposition.
        If None, it will be calculated based on the `wavelet` and the
        smallest dimension of `data` using `pywt.dwt_max_level`.

    Raises
    ------
    ValueError
        If `data` is given and is not two-dimensional.
    """

    def __init__(self, data, wavelet, mode='smooth', maxlevel=None):
        super(WaveletPacket2D, self).__init__(None, data, "")

        if not isinstance(wavelet, Wavelet):
            wavelet = Wavelet(wavelet)
        self.wavelet = wavelet
        self.mode = mode

        if data is not None:
            data = np.asarray(data)
            # Explicit check instead of `assert`, which is stripped when
            # Python runs with -O.
            if data.ndim != 2:
                raise ValueError("Expected 2D input data.")
            self.data_size = data.shape
            if maxlevel is None:
                maxlevel = dwt_max_level(min(self.data_size), self.wavelet)
        else:
            self.data_size = None
        self._maxlevel = maxlevel

    def reconstruct(self, update=True):
        """
        Reconstruct data using coefficients from subnodes.

        Parameters
        ----------
        update : bool, optional
            If True (default) then the data of the current node will be
            replaced with values from reconstruction.

        Returns
        -------
        data
            Reconstructed data, or the original data if the node has no
            subnodes.
        """
        if self.has_any_subnode:
            data = super(WaveletPacket2D, self).reconstruct(update)
            if update:
                self.data = data
            return data
        return self.data  # return original data

    def get_level(self, level, order="natural", decompose=True):
        """
        Returns all nodes from specified level.

        Parameters
        ----------
        level : int
            Decomposition `level` from which the nodes will be
            collected.
        order : {'natural', 'freq'}, optional
            If `natural` (default) a flat list is returned.
            If `freq`, a 2d structure with rows and cols
            sorted by corresponding dimension frequency of 2d
            coefficient array (adapted from 1d case).
        decompose : bool, optional
            If set then the method will try to decompose the data up
            to the specified `level` (default: True).

        Raises
        ------
        ValueError
            If `order` is not a valid order name or `level` is greater than
            the maximum decomposition level.
        """
        # Validate up front, before the tree is walked (and possibly
        # decomposed) for an order that cannot be served.
        if order not in ("natural", "freq"):
            raise ValueError("Invalid order name - %s." % order)
        if level > self.maxlevel:
            raise ValueError("The level cannot be greater than the maximum"
                             " decomposition level value (%d)" % self.maxlevel)

        result = []

        def collect(node):
            # Stop descending once the requested level is reached.
            if node.level == level:
                result.append(node)
                return False
            return True

        self.walk(collect, decompose=decompose)

        if order == "freq":
            # Group the nodes by expanded row path, then by column path.
            nodes = {}
            for (row_path, col_path), node in [
                (self.expand_2d_path(node.path), node) for node in result
            ]:
                nodes.setdefault(row_path, {})[col_path] = node
            graycode_order = get_graycode_order(level, x='l', y='h')
            nodes = [nodes[path] for path in graycode_order if path in nodes]
            result = []
            for row in nodes:
                result.append(
                    [row[path] for path in graycode_order if path in row]
                )
        return result