# Copyright (c) 2006-2012 Filip Wasilewski <http://en.ig.ma/> # Copyright (c) 2012-2016 The PyWavelets Developers # <https://github.com/PyWavelets/pywt> # See COPYING for license details. """1D and 2D Wavelet packet transform module.""" from __future__ import division, print_function, absolute_import __all__ = ["BaseNode", "Node", "WaveletPacket", "Node2D", "WaveletPacket2D"] import numpy as np from ._extensions._pywt import Wavelet, _check_dtype from ._dwt import dwt, idwt, dwt_max_level from ._multidim import dwt2, idwt2 def get_graycode_order(level, x='a', y='d'): graycode_order = [x, y] for i in range(level - 1): graycode_order = [x + path for path in graycode_order] + \ [y + path for path in graycode_order[::-1]] return graycode_order class BaseNode(object): """ BaseNode for wavelet packet 1D and 2D tree nodes. The BaseNode is a base class for `Node` and `Node2D`. It should not be used directly unless creating a new transformation type. It is included here to document the common interface of 1D and 2D node and wavelet packet transform classes. Parameters ---------- parent : Parent node. If parent is None then the node is considered detached (ie root). data : 1D or 2D array Data associated with the node. 1D or 2D numeric array, depending on the transform type. node_name : A name identifying the coefficients type. See `Node.node_name` and `Node2D.node_name` for information on the accepted subnodes names. """ # PART_LEN and PARTS attributes that define path tokens for node[] lookup # must be defined in subclasses. PART_LEN = None PARTS = None def __init__(self, parent, data, node_name): self.parent = parent if parent is not None: self.wavelet = parent.wavelet self.mode = parent.mode self.level = parent.level + 1 self._maxlevel = parent.maxlevel self.path = parent.path + node_name else: self.wavelet = None self.mode = None self.path = "" self.level = 0 # data - signal on level 0, coeffs on higher levels self.data = data # Need to retain original data size/shape so we can trim any excess # boundary coefficients from the inverse transform. if self.data is None: self._data_shape = None else: self._data_shape = np.asarray(data).shape self._init_subnodes() def _init_subnodes(self): for part in self.PARTS: self._set_node(part, None) def _create_subnode(self, part, data=None, overwrite=True): raise NotImplementedError() def _create_subnode_base(self, node_cls, part, data=None, overwrite=True): self._validate_node_name(part) if not overwrite and self._get_node(part) is not None: return self._get_node(part) node = node_cls(self, data, part) self._set_node(part, node) return node def _get_node(self, part): return getattr(self, part) def _set_node(self, part, node): setattr(self, part, node) def _delete_node(self, part): self._set_node(part, None) def _validate_node_name(self, part): if part not in self.PARTS: raise ValueError("Subnode name must be in [%s], not '%s'." % (', '.join("'%s'" % p for p in self.PARTS), part)) def _evaluate_maxlevel(self, evaluate_from='parent'): """ Try to find the value of maximum decomposition level if it is not specified explicitly. Parameters ---------- evaluate_from : {'parent', 'subnodes'} """ assert evaluate_from in ('parent', 'subnodes') if self._maxlevel is not None: return self._maxlevel elif self.data is not None: return self.level + dwt_max_level( min(self.data.shape), self.wavelet) if evaluate_from == 'parent': if self.parent is not None: return self.parent._evaluate_maxlevel(evaluate_from) elif evaluate_from == 'subnodes': for node_name in self.PARTS: node = getattr(self, node_name, None) if node is not None: level = node._evaluate_maxlevel(evaluate_from) if level is not None: return level return None @property def maxlevel(self): if self._maxlevel is not None: return self._maxlevel # Try getting the maxlevel from parents first self._maxlevel = self._evaluate_maxlevel(evaluate_from='parent') # If not found, check whether it can be evaluated from subnodes if self._maxlevel is None: self._maxlevel = self._evaluate_maxlevel(evaluate_from='subnodes') return self._maxlevel @property def node_name(self): return self.path[-self.PART_LEN:] def decompose(self): """ Decompose node data creating DWT coefficients subnodes. Performs Discrete Wavelet Transform on the `~BaseNode.data` and returns transform coefficients. Note ---- Descends to subnodes and recursively calls `~BaseNode.reconstruct` on them. """ if self.level < self.maxlevel: return self._decompose() else: raise ValueError("Maximum decomposition level reached.") def _decompose(self): raise NotImplementedError() def reconstruct(self, update=False): """ Reconstruct node from subnodes. Parameters ---------- update : bool, optional If True, then reconstructed data replaces the current node data (default: False). Returns: - original node data if subnodes do not exist - IDWT of subnodes otherwise. """ if not self.has_any_subnode: return self.data return self._reconstruct(update) def _reconstruct(self): raise NotImplementedError() # override this in subclasses def get_subnode(self, part, decompose=True): """ Returns subnode or None (see `decomposition` flag description). Parameters ---------- part : Subnode name decompose : bool, optional If the param is True and corresponding subnode does not exist, the subnode will be created using coefficients from the DWT decomposition of the current node. (default: True) """ self._validate_node_name(part) subnode = self._get_node(part) if subnode is None and decompose and not self.is_empty: self.decompose() subnode = self._get_node(part) return subnode def __getitem__(self, path): """ Find node represented by the given path. Similar to `~BaseNode.get_subnode` method with `decompose=True`, but can access nodes on any level in the decomposition tree. Parameters ---------- path : str String composed of node names. See `Node.node_name` and `Node2D.node_name` for node naming convention. Notes ----- If node does not exist yet, it will be created by decomposition of its parent node. """ if isinstance(path, str): if (self.maxlevel is not None and len(path) > self.maxlevel * self.PART_LEN): raise IndexError("Path length is out of range.") if path: return self.get_subnode(path[0:self.PART_LEN], True)[ path[self.PART_LEN:]] else: return self else: raise TypeError("Invalid path parameter type - expected string but" " got %s." % type(path)) def __setitem__(self, path, data): """ Set node or node's data in the decomposition tree. Nodes are identified by string `path`. Parameters ---------- path : str String composed of node names. data : array or BaseNode subclass. """ if isinstance(path, str): if ( self.maxlevel is not None and len(self.path) + len(path) > self.maxlevel * self.PART_LEN ): raise IndexError("Path length out of range.") if path: subnode = self.get_subnode(path[0:self.PART_LEN], False) if subnode is None: self._create_subnode(path[0:self.PART_LEN], None) subnode = self.get_subnode(path[0:self.PART_LEN], False) subnode[path[self.PART_LEN:]] = data else: if isinstance(data, BaseNode): self.data = np.asarray(data.data) else: self.data = np.asarray(data) # convert data to nearest supported dtype dtype = _check_dtype(data) if self.data.dtype != dtype: self.data = self.data.astype(dtype) else: raise TypeError("Invalid path parameter type - expected string but" " got %s." % type(path)) def __delitem__(self, path): """ Remove node from the tree. Parameters ---------- path : str String composed of node names. """ node = self[path] # don't clear node value and subnodes (node may still exist outside # the tree) # # node._init_subnodes() # # node.data = None parent = node.parent node.parent = None # TODO if parent and node.node_name: parent._delete_node(node.node_name) @property def is_empty(self): return self.data is None @property def has_any_subnode(self): for part in self.PARTS: if self._get_node(part) is not None: # and not .is_empty return True return False def get_leaf_nodes(self, decompose=False): """ Returns leaf nodes. Parameters ---------- decompose : bool, optional (default: True) """ result = [] def collect(node): if node.level == node.maxlevel and not node.is_empty: result.append(node) return False if not decompose and not node.has_any_subnode: result.append(node) return False return True self.walk(collect, decompose=decompose) return result def walk(self, func, args=(), kwargs=None, decompose=True): """ Traverses the decomposition tree and calls ``func(node, *args, **kwargs)`` on every node. If `func` returns True, descending to subnodes will continue. Parameters ---------- func : callable Callable accepting `BaseNode` as the first param and optional positional and keyword arguments args : func params kwargs : func keyword params decompose : bool, optional If True (default), the method will also try to decompose the tree up to the `maximum level <BaseNode.maxlevel>`. """ if kwargs is None: kwargs = {} if func(self, *args, **kwargs) and self.level < self.maxlevel: for part in self.PARTS: subnode = self.get_subnode(part, decompose) if subnode is not None: subnode.walk(func, args, kwargs, decompose) def walk_depth(self, func, args=(), kwargs=None, decompose=True): """ Walk tree and call func on every node starting from the bottom-most nodes. Parameters ---------- func : callable Callable accepting :class:`BaseNode` as the first param and optional positional and keyword arguments args : func params kwargs : func keyword params decompose : bool, optional (default: False) """ if kwargs is None: kwargs = {} if self.level < self.maxlevel: for part in self.PARTS: subnode = self.get_subnode(part, decompose) if subnode is not None: subnode.walk_depth(func, args, kwargs, decompose) func(self, *args, **kwargs) def __str__(self): return self.path + ": " + str(self.data) class Node(BaseNode): """ WaveletPacket tree node. Subnodes are called `a` and `d`, just like approximation and detail coefficients in the Discrete Wavelet Transform. """ A = 'a' D = 'd' PARTS = A, D PART_LEN = 1 def _create_subnode(self, part, data=None, overwrite=True): return self._create_subnode_base(node_cls=Node, part=part, data=data, overwrite=overwrite) def _decompose(self): """ See also -------- dwt : for 1D Discrete Wavelet Transform output coefficients. """ if self.is_empty: data_a, data_d = None, None if self._get_node(self.A) is None: self._create_subnode(self.A, data_a) if self._get_node(self.D) is None: self._create_subnode(self.D, data_d) else: data_a, data_d = dwt(self.data, self.wavelet, self.mode) self._create_subnode(self.A, data_a) self._create_subnode(self.D, data_d) return self._get_node(self.A), self._get_node(self.D) def _reconstruct(self, update): data_a, data_d = None, None node_a, node_d = self._get_node(self.A), self._get_node(self.D) if node_a is not None: data_a = node_a.reconstruct() # TODO: (update) ??? if node_d is not None: data_d = node_d.reconstruct() # TODO: (update) ??? if data_a is None and data_d is None: raise ValueError("Node is a leaf node and cannot be reconstructed" " from subnodes.") else: rec = idwt(data_a, data_d, self.wavelet, self.mode) if self._data_shape is not None and ( rec.shape != self._data_shape): rec = rec[tuple([slice(sz) for sz in self._data_shape])] if update: self.data = rec return rec class Node2D(BaseNode): """ WaveletPacket tree node. Subnodes are called 'a' (LL), 'h' (HL), 'v' (LH) and 'd' (HH), like approximation and detail coefficients in the 2D Discrete Wavelet Transform """ LL = 'a' HL = 'h' LH = 'v' HH = 'd' PARTS = LL, HL, LH, HH PART_LEN = 1 def _create_subnode(self, part, data=None, overwrite=True): return self._create_subnode_base(node_cls=Node2D, part=part, data=data, overwrite=overwrite) def _decompose(self): """ See also -------- dwt2 : for 2D Discrete Wavelet Transform output coefficients. """ if self.is_empty: data_ll, data_lh, data_hl, data_hh = None, None, None, None else: data_ll, (data_hl, data_lh, data_hh) =\ dwt2(self.data, self.wavelet, self.mode) self._create_subnode(self.LL, data_ll) self._create_subnode(self.LH, data_lh) self._create_subnode(self.HL, data_hl) self._create_subnode(self.HH, data_hh) return (self._get_node(self.LL), self._get_node(self.HL), self._get_node(self.LH), self._get_node(self.HH)) def _reconstruct(self, update): data_ll, data_lh, data_hl, data_hh = None, None, None, None node_ll, node_lh, node_hl, node_hh =\ self._get_node(self.LL), self._get_node(self.LH),\ self._get_node(self.HL), self._get_node(self.HH) if node_ll is not None: data_ll = node_ll.reconstruct() if node_lh is not None: data_lh = node_lh.reconstruct() if node_hl is not None: data_hl = node_hl.reconstruct() if node_hh is not None: data_hh = node_hh.reconstruct() if (data_ll is None and data_lh is None and data_hl is None and data_hh is None): raise ValueError( "Tree is missing data - all subnodes of `%s` node " "are None. Cannot reconstruct node." % self.path ) else: coeffs = data_ll, (data_hl, data_lh, data_hh) rec = idwt2(coeffs, self.wavelet, self.mode) if self._data_shape is not None and ( rec.shape != self._data_shape): rec = rec[tuple([slice(sz) for sz in self._data_shape])] if update: self.data = rec return rec def expand_2d_path(self, path): expanded_paths = { self.HH: 'hh', self.HL: 'hl', self.LH: 'lh', self.LL: 'll' } return (''.join([expanded_paths[p][0] for p in path]), ''.join([expanded_paths[p][1] for p in path])) class WaveletPacket(Node): """ Data structure representing Wavelet Packet decomposition of signal. Parameters ---------- data : 1D ndarray Original data (signal) wavelet : Wavelet object or name string Wavelet used in DWT decomposition and reconstruction mode : str, optional Signal extension mode for the `dwt` and `idwt` decomposition and reconstruction functions. maxlevel : int, optional Maximum level of decomposition. If None, it will be calculated based on the `wavelet` and `data` length using `pywt.dwt_max_level`. """ def __init__(self, data, wavelet, mode='symmetric', maxlevel=None): super(WaveletPacket, self).__init__(None, data, "") if not isinstance(wavelet, Wavelet): wavelet = Wavelet(wavelet) self.wavelet = wavelet self.mode = mode if data is not None: data = np.asarray(data) assert data.ndim == 1 self.data_size = data.shape[0] if maxlevel is None: maxlevel = dwt_max_level(self.data_size, self.wavelet) else: self.data_size = None self._maxlevel = maxlevel def reconstruct(self, update=True): """ Reconstruct data value using coefficients from subnodes. Parameters ---------- update : bool, optional If True (default), then data values will be replaced by reconstruction values, also in subnodes. """ if self.has_any_subnode: data = super(WaveletPacket, self).reconstruct(update) if update: self.data = data return data return self.data # return original data def get_level(self, level, order="natural", decompose=True): """ Returns all nodes on the specified level. Parameters ---------- level : int Specifies decomposition `level` from which the nodes will be collected. order : {'natural', 'freq'}, optional - "natural" - left to right in tree (default) - "freq" - band ordered decompose : bool, optional If set then the method will try to decompose the data up to the specified `level` (default: True). Notes ----- If nodes at the given level are missing (i.e. the tree is partially decomposed) and the `decompose` is set to False, only existing nodes will be returned. """ assert order in ["natural", "freq"] if level > self.maxlevel: raise ValueError("The level cannot be greater than the maximum" " decomposition level value (%d)" % self.maxlevel) result = [] def collect(node): if node.level == level: result.append(node) return False return True self.walk(collect, decompose=decompose) if order == "natural": return result elif order == "freq": result = dict((node.path, node) for node in result) graycode_order = get_graycode_order(level) return [result[path] for path in graycode_order if path in result] else: raise ValueError("Invalid order name - %s." % order) class WaveletPacket2D(Node2D): """ Data structure representing 2D Wavelet Packet decomposition of signal. Parameters ---------- data : 2D ndarray Data associated with the node. wavelet : Wavelet object or name string Wavelet used in DWT decomposition and reconstruction mode : str, optional Signal extension mode for the `dwt` and `idwt` decomposition and reconstruction functions. maxlevel : int Maximum level of decomposition. If None, it will be calculated based on the `wavelet` and `data` length using `pywt.dwt_max_level`. """ def __init__(self, data, wavelet, mode='smooth', maxlevel=None): super(WaveletPacket2D, self).__init__(None, data, "") if not isinstance(wavelet, Wavelet): wavelet = Wavelet(wavelet) self.wavelet = wavelet self.mode = mode if data is not None: data = np.asarray(data) assert data.ndim == 2 self.data_size = data.shape if maxlevel is None: maxlevel = dwt_max_level(min(self.data_size), self.wavelet) else: self.data_size = None self._maxlevel = maxlevel def reconstruct(self, update=True): """ Reconstruct data using coefficients from subnodes. Parameters ---------- update : bool, optional If True (default) then the coefficients of the current node and its subnodes will be replaced with values from reconstruction. """ if self.has_any_subnode: data = super(WaveletPacket2D, self).reconstruct(update) if update: self.data = data return data return self.data # return original data def get_level(self, level, order="natural", decompose=True): """ Returns all nodes from specified level. Parameters ---------- level : int Decomposition `level` from which the nodes will be collected. order : {'natural', 'freq'}, optional If `natural` (default) a flat list is returned. If `freq`, a 2d structure with rows and cols sorted by corresponding dimension frequency of 2d coefficient array (adapted from 1d case). decompose : bool, optional If set then the method will try to decompose the data up to the specified `level` (default: True). """ assert order in ["natural", "freq"] if level > self.maxlevel: raise ValueError("The level cannot be greater than the maximum" " decomposition level value (%d)" % self.maxlevel) result = [] def collect(node): if node.level == level: result.append(node) return False return True self.walk(collect, decompose=decompose) if order == "freq": nodes = {} for (row_path, col_path), node in [ (self.expand_2d_path(node.path), node) for node in result ]: nodes.setdefault(row_path, {})[col_path] = node graycode_order = get_graycode_order(level, x='l', y='h') nodes = [nodes[path] for path in graycode_order if path in nodes] result = [] for row in nodes: result.append( [row[path] for path in graycode_order if path in row] ) return result