208 lines
6.3 KiB
Python
208 lines
6.3 KiB
Python
from collections import namedtuple
|
|
import numpy as np
|
|
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
|
import matplotlib.image
|
|
from ...util import dtype as dtypes
|
|
from ...exposure import is_low_contrast
|
|
from ..._shared.utils import warn
|
|
from math import floor, ceil
|
|
|
|
|
|
# Colormap names chosen by `_get_display_range`:
# - `_default_colormap` when none of the nonstandard image properties is set,
# - `_nonstandard_colormap` when a (non-signed) nonstandard property is set,
# - `_diverging_colormap` when the image contains negative values.
_default_colormap = 'gray'
_nonstandard_colormap = 'viridis'
_diverging_colormap = 'RdBu'
|
|
|
|
|
|
# Flags computed by `_get_image_properties`; the field order here is the
# order in which that function passes its results to the constructor.
ImageProperties = namedtuple('ImageProperties',
                             ['signed', 'out_of_range_float',
                              'low_data_range', 'unsupported_dtype'])
|
|
|
|
|
|
def _get_image_properties(image):
    """Inspect an image for properties that call for nonstandard display.

    Parameters
    ----------
    image : array
        The image to inspect.

    Returns
    -------
    ip : ImageProperties named tuple
        Flags describing the image:

        - signed: true when the minimum value is negative.
        - out_of_range_float: true when the image has a floating point
          dtype and its values fall outside the range registered for
          that dtype in ``dtypes.dtype_range``.
        - low_data_range: true when the image is not constant and
          ``is_low_contrast`` reports it as low contrast.
        - unsupported_dtype: true when the dtype is not listed in
          ``dtypes._supported_types``.
    """
    smallest, largest = np.min(image), np.max(image)
    scalar_type = image.dtype.type

    try:
        type_lo, type_hi = dtypes.dtype_range[scalar_type]
    except KeyError:
        # No registered range for this dtype: use the data extent itself,
        # so the out-of-range comparison below can never be true.
        type_lo, type_hi = smallest, largest

    has_negative = smallest < 0
    # Keep the short-circuit: the range comparison only runs for floats.
    float_outside_range = np.issubdtype(image.dtype, np.floating) and (
        smallest < type_lo or largest > type_hi
    )
    # A constant image is never reported as low contrast.
    narrow_range = smallest != largest and is_low_contrast(image)
    dtype_unsupported = image.dtype not in dtypes._supported_types

    return ImageProperties(
        signed=has_negative,
        out_of_range_float=float_outside_range,
        low_data_range=narrow_range,
        unsupported_dtype=dtype_unsupported,
    )
|
|
|
|
|
|
def _raise_warnings(image_properties):
    """Emit one warning for each nonstandard property set on the image.

    The ``signed`` property has no associated warning.

    Parameters
    ----------
    image_properties : ImageProperties named tuple
        The properties of the considered image.
    """
    # (flag, message) pairs, in the order the warnings are issued.
    notices = (
        (image_properties.unsupported_dtype,
         "Non-standard image type; displaying image with "
         "stretched contrast."),
        (image_properties.low_data_range,
         "Low image data range; displaying image with "
         "stretched contrast."),
        (image_properties.out_of_range_float,
         "Float image out of standard range; displaying "
         "image with stretched contrast."),
    )
    for flagged, message in notices:
        if flagged:
            warn(message, stacklevel=3)
|
|
|
|
|
|
def _get_display_range(image):
    """Return the display range and colormap to use for an image.

    Parameters
    ----------
    image : array
        The input image.

    Returns
    -------
    lo, hi : scalar
        The display range to be used for the input image. For images
        with negative values the range is symmetric around zero and is
        returned as Python scalars; for other nonstandard images it is
        the image's own minimum and maximum; otherwise it runs from 0 to
        the upper bound registered for the dtype in
        ``dtypes.dtype_range``.
    cmap : string
        The name of the colormap to use.
    """
    ip = _get_image_properties(image)
    immin, immax = np.min(image), np.max(image)
    if ip.signed:
        # Take magnitudes on Python scalars: abs() on a fixed-width NumPy
        # integer wraps for the most negative value (for example
        # abs(np.int8(-128)) == -128), which would shrink or collapse the
        # symmetric range.
        magnitude = max(abs(np.asarray(immin).item()),
                        abs(np.asarray(immax).item()))
        lo, hi = -magnitude, magnitude
        cmap = _diverging_colormap
    elif any(ip):
        # Some nonstandard property is set: warn and stretch to the data.
        _raise_warnings(ip)
        lo, hi = immin, immax
        cmap = _nonstandard_colormap
    else:
        lo = 0
        imtype = image.dtype.type
        hi = dtypes.dtype_range[imtype][1]
        cmap = _default_colormap
    return lo, hi, cmap
|
|
|
|
|
|
def imshow(image, ax=None, show_cbar=None, **kwargs):
    """Display an image on a matplotlib axes and return the image artist.

    Unless overridden through `kwargs`, the colormap and value range
    come from ``_get_display_range``:

    - An image with none of the nonstandard properties is drawn in
      greyscale over the standard range of its dtype, so a floating
      point image with values in [0, 0.5] tops out at mid-gray rather
      than white.
    - An image that exceeds the standard range, or whose range is too
      small to display, is drawn over exactly its own data range, and a
      colorbar is added to make that rescaling visible.
    - An image with negative values is drawn with a diverging colormap
      centered at 0, also with a colorbar.

    Parameters
    ----------
    image : array, shape (M, N[, 3])
        The image to display.
    ax : `matplotlib.axes.Axes`, optional
        The axis to use for the image, defaults to plt.gca().
    show_cbar : boolean, optional.
        Whether to show the colorbar (used to override default behavior).
    **kwargs : Keyword arguments
        Passed to the axes' ``imshow`` call. ``interpolation``, ``cmap``,
        ``vmin`` and ``vmax`` are filled in only when not supplied.

    Returns
    -------
    ax_im : `matplotlib.image.AxesImage`
        The object returned by the axes' ``imshow`` call.
    """
    import matplotlib.pyplot as plt

    lo, hi, cmap = _get_display_range(image)

    # Caller-supplied values win; these are only fallbacks.
    fallbacks = (
        ('interpolation', 'nearest'),
        ('cmap', cmap),
        ('vmin', lo),
        ('vmax', hi),
    )
    for option, value in fallbacks:
        kwargs.setdefault(option, value)

    ax = ax or plt.gca()
    ax_im = ax.imshow(image, **kwargs)

    # The colorbar appears by default for any non-default colormap, and
    # `show_cbar` can force it on (truthy) or off (exactly False).
    nonstandard = cmap != _default_colormap
    want_cbar = (nonstandard and show_cbar is not False) or show_cbar
    if want_cbar:
        divider = make_axes_locatable(ax)
        cbar_axes = divider.append_axes("right", size="5%", pad=0.05)
        plt.colorbar(ax_im, cax=cbar_axes)

    ax.get_figure().tight_layout()

    return ax_im
|
|
|
|
|
|
def imshow_collection(ic, *args, **kwargs):
    """Display all images in the collection on a grid of subplots.

    The grid shape is chosen so that its rows-to-columns ratio is as
    close as possible to 3:4.

    Parameters
    ----------
    ic : sized iterable of arrays
        The images to display; must support ``len()`` and iteration.
    *args, **kwargs
        Passed unchanged to the ``imshow`` method of each subplot axes.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The `Figure` object returned by `plt.subplots`.

    Raises
    ------
    ValueError
        If `ic` contains no images.
    """
    # Validate before importing pyplot so that bad input fails fast.
    num_images = len(ic)
    if num_images < 1:
        raise ValueError('Number of images to plot must be greater than 0')

    import matplotlib.pyplot as plt

    # The target is to plot images on a grid with aspect ratio 4:3.
    # With rows / cols = 3 / 4 and rows * cols = num_images, the ideal
    # (fractional) row count is sqrt(12 * num_images) / 4.
    k = (num_images * 12)**0.5
    # Two pairs of `nrows, ncols` are possible: round the row count down
    # (but never below one row) or up, and fit the columns to match.
    r1 = max(1, floor(k / 4))
    r2 = ceil(k / 4)
    c1 = ceil(num_images / r1)
    c2 = ceil(num_images / r2)
    # Select the one which is closer to 4:3
    if abs(r1 / c1 - 0.75) < abs(r2 / c2 - 0.75):
        nrows, ncols = r1, c1
    else:
        nrows, ncols = r2, c2

    fig, axes = plt.subplots(nrows=nrows, ncols=ncols)
    # `axes` may be a single Axes or an array of them; flatten either
    # form so images can be placed by running index.
    flat_axes = np.asarray(axes).ravel()
    for n, image in enumerate(ic):
        flat_axes[n].imshow(image, *args, **kwargs)
    return fig
|
|
|
|
|
|
# Re-export matplotlib's image reader under this module's `imread` name.
imread = matplotlib.image.imread
|
|
|
|
|
|
def _app_show():
    """Call ``matplotlib.pyplot.show``."""
    # pyplot is imported lazily, only when something is actually shown.
    import matplotlib.pyplot as plt

    plt.show()
|