Vehicle-Anti-Theft-Face-Rec.../venv/Lib/site-packages/pywt/_doc_utils.py

188 lines
5.7 KiB
Python
Raw Normal View History

"""Utilities used to generate various figures in the documentation."""
from itertools import product
import numpy as np
from matplotlib import pyplot as plt
from ._dwt import pad
# Public helpers used when building the documentation figures.
__all__ = [
    'wavedec_keys',
    'wavedec2_keys',
    'draw_2d_wp_basis',
    'draw_2d_fswavedecn_basis',
    'boundary_mode_subplot',
]
def wavedec_keys(level):
    """Subband keys corresponding to a wavedec decomposition.

    Parameters
    ----------
    level : int
        Number of decomposition levels.

    Returns
    -------
    keys : list of str
        One detail key per level, finest first (``'d'``, ``'ad'``,
        ``'aad'``, ...). The approximation key of the coarsest level
        (``'a' * level``) is placed just before that level's detail key.
        An empty list is returned when ``level <= 0``.
    """
    keys = []
    for depth in range(level):
        prefix = 'a' * depth
        if depth == level - 1:
            # only the coarsest level keeps its approximation subband
            keys.append(prefix + 'a')
        keys.append(prefix + 'd')
    return keys
def wavedec2_keys(level):
    """Subband keys corresponding to a wavedec2 decomposition.

    Parameters
    ----------
    level : int
        Number of decomposition levels.

    Returns
    -------
    keys : list of str
        For every level, finest first, the three detail keys with suffixes
        ``'h'``, ``'v'`` and ``'d'`` (prefixed by ``'a' * depth``). The
        approximation key of the coarsest level (``'a' * level``) is placed
        just before that level's detail keys. An empty list is returned
        when ``level <= 0``.
    """
    keys = []
    for depth in range(level):
        prefix = 'a' * depth
        if depth == level - 1:
            # only the coarsest level keeps its approximation subband
            keys.append(prefix + 'a')
        keys.extend(prefix + band for band in ('h', 'v', 'd'))
    return keys
def _box(bl, ur):
"""(x, y) coordinates for the 4 lines making up a rectangular box.
Parameters
==========
bl : float
The bottom left corner of the box
ur : float
The upper right corner of the box
Returns
=======
coords : 2-tuple
The first and second elements of the tuple are the x and y coordinates
of the box.
"""
xl, xr = bl[0], ur[0]
yb, yt = bl[1], ur[1]
box_x = [xl, xr,
xr, xr,
xr, xl,
xl, xl]
box_y = [yb, yb,
yb, yt,
yt, yt,
yt, yb]
return (box_x, box_y)
def _2d_wp_basis_coords(shape, keys):
    """Coordinates of the lines to be drawn by draw_2d_wp_basis.

    Parameters
    ----------
    shape : sequence of int
        Extent of the full 2D domain; ``shape[0]`` is used along x and
        ``shape[1]`` along y.
    keys : iterable of str
        Node paths made of the characters ``'a'``, ``'h'``, ``'v'`` and
        ``'d'``. An empty key yields a box covering the whole ``shape``.

    Returns
    -------
    coords : list of 2-tuple
        ``(x, y)`` line coordinates (as returned by ``_box``), one per key.
    centers : dict
        Maps each key to the (x, y) center of its box, for labeling.
    """
    coords = []
    centers = {}  # retain center of boxes for use in labeling
    for key in keys:
        offset_x = offset_y = 0
        for n, char in enumerate(key):
            if char in ('h', 'd'):
                offset_x += shape[0] // 2**(n + 1)
            if char in ('v', 'd'):
                offset_y += shape[1] // 2**(n + 1)
        # The box size depends only on the node depth. Using len(key) rather
        # than the leftover loop index keeps the empty (root) key valid
        # instead of raising UnboundLocalError or reusing a stale depth.
        depth = len(key)
        sx = shape[0] // 2**depth
        sy = shape[1] // 2**depth
        # y offsets are negated so deeper 'v'/'d' subbands are drawn below
        xc, yc = _box((offset_x, -offset_y),
                      (offset_x + sx, -offset_y - sy))
        coords.append((xc, yc))
        centers[key] = (offset_x + sx // 2, -offset_y - sy // 2)
    return coords, centers
def draw_2d_wp_basis(shape, keys, fmt='k', plot_kwargs=None, ax=None,
                     label_levels=0):
    """Plot a 2D representation of a WaveletPacket2D basis.

    Parameters
    ----------
    shape : sequence of int
        Extent of the full 2D domain being partitioned.
    keys : iterable of str
        Node paths (characters ``'a'``, ``'h'``, ``'v'``, ``'d'``) of the
        nodes making up the basis.
    fmt : str, optional
        Matplotlib format string used for the box outlines.
    plot_kwargs : dict, optional
        Extra keyword arguments forwarded to ``ax.plot`` for every box.
    ax : matplotlib Axes, optional
        Axes to draw into. A new figure and axes are created when omitted.
    label_levels : int, optional
        Keys with at most this many characters are written at the center
        of their box. No labels are drawn when ``label_levels <= 0``.

    Returns
    -------
    fig : matplotlib Figure
    ax : matplotlib Axes
    """
    # None instead of a shared mutable {} default
    if plot_kwargs is None:
        plot_kwargs = {}
    coords, centers = _2d_wp_basis_coords(shape, keys)
    if ax is None:
        fig, ax = plt.subplots(1, 1)
    else:
        fig = ax.get_figure()
    for coord in coords:
        # forward user styling to every box outline
        ax.plot(coord[0], coord[1], fmt, **plot_kwargs)
    ax.set_axis_off()
    ax.axis('square')
    if label_levels > 0:
        for key, c in centers.items():
            if len(key) <= label_levels:
                ax.text(c[0], c[1], key,
                        horizontalalignment='center',
                        verticalalignment='center')
    return fig, ax
def _2d_fswavedecn_coords(shape, levels):
    """Coordinates of the lines drawn by draw_2d_fswavedecn_basis.

    Each axis is split according to the 1D ``wavedec_keys(levels)``
    subbands, and one box is produced for every pair of per-axis keys.

    Parameters
    ----------
    shape : sequence of int
        Extent of the 2D domain; ``shape[0]`` is used along x and
        ``shape[1]`` along y.
    levels : int
        Number of decomposition levels along each axis.

    Returns
    -------
    coords : list of 2-tuple
        ``(x, y)`` line coordinates (as returned by ``_box``), one per box.
    centers : dict
        Maps each ``(key0, key1)`` pair to the (x, y) center of its box.
    """

    def _span(extent, key):
        # Offset: sum of extent // 2**(n + 1) for every 'd' at position n.
        # Width: a key of length k spans extent // 2**k.
        start = sum(extent // 2**(pos + 1)
                    for pos, char in enumerate(key) if char == 'd')
        return start, extent // 2**len(key)

    coords = []
    centers = {}  # box centers keyed by (key0, key1), used for labeling
    for key0, key1 in product(wavedec_keys(levels), repeat=2):
        x0, width_x = _span(shape[0], key0)
        y0, width_y = _span(shape[1], key1)
        # y values are negated so increasing offsets are drawn downward
        xc, yc = _box((x0, -y0), (x0 + width_x, -y0 - width_y))
        coords.append((xc, yc))
        centers[(key0, key1)] = (x0 + width_x / 2, -y0 - width_y / 2)
    return coords, centers
def draw_2d_fswavedecn_basis(shape, levels, fmt='k', plot_kwargs=None,
                             ax=None, label_levels=0):
    """Plot a 2D representation of a fully separable wavelet decomposition.

    Each axis is split into the 1D ``wavedec`` subbands for ``levels``
    levels, and a box is drawn for every pair of per-axis subbands.

    Parameters
    ----------
    shape : sequence of int
        Extent of the full 2D domain being partitioned.
    levels : int
        Number of decomposition levels along each axis.
    fmt : str, optional
        Matplotlib format string used for the box outlines.
    plot_kwargs : dict, optional
        Extra keyword arguments forwarded to ``ax.plot`` for every box.
    ax : matplotlib Axes, optional
        Axes to draw into. A new figure and axes are created when omitted.
    label_levels : int, optional
        A box is labeled with its ``(key0, key1)`` pair when neither key is
        longer than this. No labels are drawn when ``label_levels <= 0``.

    Returns
    -------
    fig : matplotlib Figure
    ax : matplotlib Axes
    """
    # None instead of a shared mutable {} default
    if plot_kwargs is None:
        plot_kwargs = {}
    coords, centers = _2d_fswavedecn_coords(shape, levels)
    if ax is None:
        fig, ax = plt.subplots(1, 1)
    else:
        fig = ax.get_figure()
    for coord in coords:
        # forward user styling to every box outline
        ax.plot(coord[0], coord[1], fmt, **plot_kwargs)
    ax.set_axis_off()
    ax.axis('square')
    if label_levels > 0:
        for key, c in centers.items():
            # deepest decomposition level along either axis
            lev = max(len(k) for k in key)
            if lev <= label_levels:
                ax.text(c[0], c[1], key,
                        horizontalalignment='center',
                        verticalalignment='center')
    return fig, ax
def boundary_mode_subplot(x, mode, ax, symw=True):
    """Plot an illustration of the boundary mode in a subplot axis.

    The signal is padded by ``2 * len(x)`` samples on each side with
    ``pad(..., mode=mode)`` and drawn in black, with the original samples
    overdrawn in red and vertical bars marking the extension boundaries.

    Parameters
    ----------
    x : 1D array_like
        Signal to extend.
    mode : str
        Signal extension mode passed to ``pad``.
    ax : matplotlib Axes
        Axes to draw into.
    symw : bool, optional
        When True, the vertical bars are ``len(x) - 1`` samples apart,
        starting on the first signal sample. When False, they are
        ``len(x)`` samples apart and shifted half a sample to the left
        (presumably whole- vs half-sample symmetry; confirm with callers).
    """
    n_orig = len(x)
    # if odd-length, periodization replicates the last sample to make it even
    if mode == 'periodization' and n_orig % 2 == 1:
        x = np.concatenate((x, (x[-1], )))

    npad = 2 * len(x)
    t = np.arange(len(x) + 2 * npad)
    xp = pad(x, (npad, npad), mode=mode)

    ax.plot(t, xp, 'k.')
    ax.set_title(mode)

    # Plot the original signal in red. Slicing to n_orig excludes only the
    # sample replicated above, so even-length periodization inputs (which
    # were not extended) are shown in full.
    ax.plot(t[npad:npad + n_orig], x[:n_orig], 'r.')

    # add vertical bars indicating points of symmetry or boundary extension
    o2 = np.ones(2)
    if symw:
        left = npad
        step = len(x) - 1
    else:
        left = npad - 0.5
        step = len(x)
    if mode in ['smooth', 'constant', 'zero']:
        rng = range(0, 2)
    else:
        rng = range(-2, 4)
    y_range = [xp.min() - .5, xp.max() + .5]
    for rep in rng:
        ax.plot((left + rep * step) * o2, y_range, 'k-')