Fixed database typo and removed unnecessary class identifier.

This commit is contained in:
Batuhan Berk Başoğlu 2020-10-14 10:10:37 -04:00
parent 00ad49a143
commit 45fb349a7d
5098 changed files with 952558 additions and 85 deletions

View file

@ -0,0 +1,6 @@
from .._shared.utils import warn
from .viewers import ImageViewer, CollectionViewer
from .qt import has_qt
if not has_qt:
warn('Viewer requires Qt')

View file

@ -0,0 +1,3 @@
from .linetool import LineTool, ThickLineTool
from .recttool import RectangleTool
from .painttool import PaintTool

View file

@ -0,0 +1,135 @@
import numpy as np
from matplotlib import lines
__all__ = ['CanvasToolBase', 'ToolHandles']
def _pass(*args):
pass
class CanvasToolBase(object):
"""Base canvas tool for matplotlib axes.
Parameters
----------
manager : Viewer or PlotPlugin.
Skimage viewer or plot plugin object.
on_move : function
Function called whenever a control handle is moved.
This function must accept the end points of line as the only argument.
on_release : function
Function called whenever the control handle is released.
on_enter : function
Function called whenever the "enter" key is pressed.
"""
def __init__(self, manager, on_move=None, on_enter=None, on_release=None,
useblit=True, ax=None):
self.manager = manager
self.ax = manager.ax
self.artists = []
self.active = True
self.callback_on_move = _pass if on_move is None else on_move
self.callback_on_enter = _pass if on_enter is None else on_enter
self.callback_on_release = _pass if on_release is None else on_release
def ignore(self, event):
"""Return True if event should be ignored.
This method (or a version of it) should be called at the beginning
of any event callback.
"""
return not self.active
def hit_test(self, event):
return False
def redraw(self):
self.manager.redraw()
def set_visible(self, val):
for artist in self.artists:
artist.set_visible(val)
def on_key_press(self, event):
if event.key == 'enter':
self.callback_on_enter(self.geometry)
self.set_visible(False)
self.manager.redraw()
def on_mouse_press(self, event):
pass
def on_mouse_release(self, event):
pass
def on_move(self, event):
pass
def on_scroll(self, event):
pass
def remove(self):
self.manager.remove_tool(self)
@property
def geometry(self):
"""Geometry information that gets passed to callback functions."""
return None
class ToolHandles(object):
"""Control handles for canvas tools.
Parameters
----------
ax : :class:`matplotlib.axes.Axes`
Matplotlib axes where tool handles are displayed.
x, y : 1D arrays
Coordinates of control handles.
marker : str
Shape of marker used to display handle. See `matplotlib.pyplot.plot`.
marker_props : dict
Additional marker properties. See :class:`matplotlib.lines.Line2D`.
"""
def __init__(self, ax, x, y, marker='o', marker_props=None):
self.ax = ax
props = dict(marker=marker, markersize=7, mfc='w', ls='none',
alpha=0.5, visible=False)
props.update(marker_props if marker_props is not None else {})
self._markers = lines.Line2D(x, y, animated=True, **props)
self.ax.add_line(self._markers)
self.artist = self._markers
@property
def x(self):
return self._markers.get_xdata()
@property
def y(self):
return self._markers.get_ydata()
def set_data(self, pts, y=None):
"""Set x and y positions of handles"""
if y is not None:
x = pts
pts = np.array([x, y])
self._markers.set_data(pts)
def set_visible(self, val):
self._markers.set_visible(val)
def set_animated(self, val):
self._markers.set_animated(val)
def closest(self, x, y):
"""Return index and pixel distance to closest index."""
pts = np.transpose((self.x, self.y))
# Transform data coordinates to pixel coordinates.
pts = self.ax.transData.transform(pts)
diff = pts - ((x, y))
dist = np.sqrt(np.sum(diff**2, axis=1))
return np.argmin(dist), np.min(dist)

View file

@ -0,0 +1,212 @@
import numpy as np
from matplotlib import lines
from ...viewer.canvastools.base import CanvasToolBase, ToolHandles
__all__ = ['LineTool', 'ThickLineTool']
class LineTool(CanvasToolBase):
"""Widget for line selection in a plot.
Parameters
----------
manager : Viewer or PlotPlugin.
Skimage viewer or plot plugin object.
on_move : function
Function called whenever a control handle is moved.
This function must accept the end points of line as the only argument.
on_release : function
Function called whenever the control handle is released.
on_enter : function
Function called whenever the "enter" key is pressed.
maxdist : float
Maximum pixel distance allowed when selecting control handle.
line_props : dict
Properties for :class:`matplotlib.lines.Line2D`.
handle_props : dict
Marker properties for the handles (also see
:class:`matplotlib.lines.Line2D`).
Attributes
----------
end_points : 2D array
End points of line ((x1, y1), (x2, y2)).
"""
def __init__(self, manager, on_move=None, on_release=None, on_enter=None,
maxdist=10, line_props=None, handle_props=None,
**kwargs):
super(LineTool, self).__init__(manager, on_move=on_move,
on_enter=on_enter,
on_release=on_release, **kwargs)
props = dict(color='r', linewidth=1, alpha=0.4, solid_capstyle='butt')
props.update(line_props if line_props is not None else {})
self.linewidth = props['linewidth']
self.maxdist = maxdist
self._active_pt = None
x = (0, 0)
y = (0, 0)
self._end_pts = np.transpose([x, y])
self._line = lines.Line2D(x, y, visible=False, animated=True, **props)
self.ax.add_line(self._line)
self._handles = ToolHandles(self.ax, x, y,
marker_props=handle_props)
self._handles.set_visible(False)
self.artists = [self._line, self._handles.artist]
if on_enter is None:
def on_enter(pts):
x, y = np.transpose(pts)
print("length = %0.2f" %
np.sqrt(np.diff(x)**2 + np.diff(y)**2))
self.callback_on_enter = on_enter
self.manager.add_tool(self)
@property
def end_points(self):
return self._end_pts.astype(int)
@end_points.setter
def end_points(self, pts):
self._end_pts = np.asarray(pts)
self._line.set_data(np.transpose(pts))
self._handles.set_data(np.transpose(pts))
self._line.set_linewidth(self.linewidth)
self.set_visible(True)
self.redraw()
def hit_test(self, event):
if event.button != 1 or not self.ax.in_axes(event):
return False
idx, px_dist = self._handles.closest(event.x, event.y)
if px_dist < self.maxdist:
self._active_pt = idx
return True
else:
self._active_pt = None
return False
def on_mouse_press(self, event):
self.set_visible(True)
if self._active_pt is None:
self._active_pt = 0
x, y = event.xdata, event.ydata
self._end_pts = np.array([[x, y], [x, y]])
def on_mouse_release(self, event):
if event.button != 1:
return
self._active_pt = None
self.callback_on_release(self.geometry)
self.redraw()
def on_move(self, event):
if event.button != 1 or self._active_pt is None:
return
if not self.ax.in_axes(event):
return
self.update(event.xdata, event.ydata)
self.callback_on_move(self.geometry)
def update(self, x=None, y=None):
if x is not None:
self._end_pts[self._active_pt, :] = x, y
self.end_points = self._end_pts
@property
def geometry(self):
return self.end_points
class ThickLineTool(LineTool):
"""Widget for line selection in a plot.
The thickness of the line can be varied using the mouse scroll wheel, or
with the '+' and '-' keys.
Parameters
----------
manager : Viewer or PlotPlugin.
Skimage viewer or plot plugin object.
on_move : function
Function called whenever a control handle is moved.
This function must accept the end points of line as the only argument.
on_release : function
Function called whenever the control handle is released.
on_enter : function
Function called whenever the "enter" key is pressed.
on_change : function
Function called whenever the line thickness is changed.
maxdist : float
Maximum pixel distance allowed when selecting control handle.
line_props : dict
Properties for :class:`matplotlib.lines.Line2D`.
handle_props : dict
Marker properties for the handles (also see
:class:`matplotlib.lines.Line2D`).
Attributes
----------
end_points : 2D array
End points of line ((x1, y1), (x2, y2)).
"""
def __init__(self, manager, on_move=None, on_enter=None, on_release=None,
on_change=None, maxdist=10, line_props=None, handle_props=None):
super(ThickLineTool, self).__init__(manager,
on_move=on_move,
on_enter=on_enter,
on_release=on_release,
maxdist=maxdist,
line_props=line_props,
handle_props=handle_props)
if on_change is None:
def on_change(*args):
pass
self.callback_on_change = on_change
def on_scroll(self, event):
if not event.inaxes:
return
if event.button == 'up':
self._thicken_scan_line()
elif event.button == 'down':
self._shrink_scan_line()
def on_key_press(self, event):
if event.key == '+':
self._thicken_scan_line()
elif event.key == '-':
self._shrink_scan_line()
def _thicken_scan_line(self):
self.linewidth += 1
self.update()
self.callback_on_change(self.geometry)
def _shrink_scan_line(self):
if self.linewidth > 1:
self.linewidth -= 1
self.update()
self.callback_on_change(self.geometry)
if __name__ == '__main__': # pragma: no cover
from ... import data
from ...viewer import ImageViewer
image = data.camera()
viewer = ImageViewer(image)
h, w = image.shape
line_tool = ThickLineTool(viewer)
line_tool.end_points = ([w/3, h/2], [2*w/3, h/2])
viewer.show()

View file

@ -0,0 +1,238 @@
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
LABELS_CMAP = mcolors.ListedColormap(['white', 'red', 'dodgerblue', 'gold',
'greenyellow', 'blueviolet'])
from ...viewer.canvastools.base import CanvasToolBase
__all__ = ['PaintTool']
class PaintTool(CanvasToolBase):
"""Widget for painting on top of a plot.
Parameters
----------
manager : Viewer or PlotPlugin.
Skimage viewer or plot plugin object.
overlay_shape : shape tuple
2D shape tuple used to initialize overlay image.
radius : int
The size of the paint cursor.
alpha : float (between [0, 1])
Opacity of overlay.
on_move : function
Function called whenever a control handle is moved.
This function must accept the end points of line as the only argument.
on_release : function
Function called whenever the control handle is released.
on_enter : function
Function called whenever the "enter" key is pressed.
rect_props : dict
Properties for :class:`matplotlib.patches.Rectangle`. This class
redefines defaults in :class:`matplotlib.widgets.RectangleSelector`.
Attributes
----------
overlay : array
Overlay of painted labels displayed on top of image.
label : int
Current paint color.
Examples
----------
>>> from skimage.data import camera
>>> import matplotlib.pyplot as plt
>>> from skimage.viewer.canvastools import PaintTool
>>> import numpy as np
>>> img = camera() #doctest: +SKIP
>>> ax = plt.subplot(111) #doctest: +SKIP
>>> plt.imshow(img, cmap=plt.cm.gray) #doctest: +SKIP
>>> p = PaintTool(ax,np.shape(img[:-1]),10,0.2) #doctest: +SKIP
>>> plt.show() #doctest: +SKIP
>>> mask = p.overlay #doctest: +SKIP
>>> plt.imshow(mask,cmap=plt.cm.gray) #doctest: +SKIP
>>> plt.show() #doctest: +SKIP
"""
def __init__(self, manager, overlay_shape, radius=5, alpha=0.3,
on_move=None, on_release=None, on_enter=None,
rect_props=None):
super(PaintTool, self).__init__(manager, on_move=on_move,
on_enter=on_enter,
on_release=on_release)
props = dict(edgecolor='r', facecolor='0.7', alpha=0.5, animated=True)
props.update(rect_props if rect_props is not None else {})
self.alpha = alpha
self.cmap = LABELS_CMAP
self._overlay_plot = None
self.shape = overlay_shape[:2]
self._cursor = plt.Rectangle((0, 0), 0, 0, **props)
self._cursor.set_visible(False)
self.ax.add_patch(self._cursor)
# `label` and `radius` can only be set after initializing `_cursor`
self.label = 1
self.radius = radius
# Note that the order is important: Redraw cursor *after* overlay
self.artists = [self._overlay_plot, self._cursor]
self.manager.add_tool(self)
@property
def label(self):
return self._label
@label.setter
def label(self, value):
if value >= self.cmap.N:
raise ValueError('Maximum label value = %s' % len(self.cmap - 1))
self._label = value
self._cursor.set_edgecolor(self.cmap(value))
@property
def radius(self):
return self._radius
@radius.setter
def radius(self, r):
self._radius = r
self._width = 2 * r + 1
self._cursor.set_width(self._width)
self._cursor.set_height(self._width)
self.window = CenteredWindow(r, self._shape)
@property
def overlay(self):
return self._overlay
@overlay.setter
def overlay(self, image):
self._overlay = image
if image is None:
self.ax.images.remove(self._overlay_plot)
self._overlay_plot = None
elif self._overlay_plot is None:
props = dict(cmap=self.cmap, alpha=self.alpha,
norm=mcolors.NoNorm(vmin=0, vmax=self.cmap.N),
animated=True)
self._overlay_plot = self.ax.imshow(image, **props)
else:
self._overlay_plot.set_data(image)
self.redraw()
@property
def shape(self):
return self._shape
@shape.setter
def shape(self, shape):
self._shape = shape
if not self._overlay_plot is None:
self._overlay_plot.set_extent((-0.5, shape[1] + 0.5,
shape[0] + 0.5, -0.5))
self.radius = self._radius
self.overlay = np.zeros(shape, dtype='uint8')
def on_key_press(self, event):
if event.key == 'enter':
self.callback_on_enter(self.geometry)
self.redraw()
def on_mouse_press(self, event):
if event.button != 1 or not self.ax.in_axes(event):
return
self.update_cursor(event.xdata, event.ydata)
self.update_overlay(event.xdata, event.ydata)
def on_mouse_release(self, event):
if event.button != 1:
return
self.callback_on_release(self.geometry)
def on_move(self, event):
if not self.ax.in_axes(event):
self._cursor.set_visible(False)
self.redraw() # make sure cursor is not visible
return
self._cursor.set_visible(True)
self.update_cursor(event.xdata, event.ydata)
if event.button != 1:
self.redraw() # update cursor position
return
self.update_overlay(event.xdata, event.ydata)
self.callback_on_move(self.geometry)
def update_overlay(self, x, y):
overlay = self.overlay
overlay[self.window.at(y, x)] = self.label
# Note that overlay calls `redraw`
self.overlay = overlay
def update_cursor(self, x, y):
x = x - self.radius - 1
y = y - self.radius - 1
self._cursor.set_xy((x, y))
@property
def geometry(self):
return self.overlay
class CenteredWindow(object):
"""Window that create slices numpy arrays over 2D windows.
Examples
--------
>>> a = np.arange(16).reshape(4, 4)
>>> w = CenteredWindow(1, a.shape)
>>> a[w.at(1, 1)]
array([[ 0, 1, 2],
[ 4, 5, 6],
[ 8, 9, 10]])
>>> a[w.at(0, 0)]
array([[0, 1],
[4, 5]])
>>> a[w.at(4, 3)]
array([[14, 15]])
"""
def __init__(self, radius, array_shape):
self.radius = radius
self.array_shape = array_shape
def at(self, row, col):
h, w = self.array_shape
r = round(self.radius)
# Note: the int() cast is necessary because row and col are np.float64,
# which does not get cast by round(), unlike a normal Python float:
# >>> round(4.5)
# 4
# >>> round(np.float64(4.5))
# 4.0
# >>> int(round(np.float64(4.5)))
# 4
row, col = int(round(row)), int(round(col))
xmin = max(0, col - r)
xmax = min(w, col + r + 1)
ymin = max(0, row - r)
ymax = min(h, row + r + 1)
return (slice(ymin, ymax), slice(xmin, xmax))
if __name__ == '__main__': # pragma: no cover
np.testing.rundocs()
from ... import data
from ...viewer import ImageViewer
image = data.camera()
viewer = ImageViewer(image)
paint_tool = PaintTool(viewer, image.shape)
viewer.show()

View file

@ -0,0 +1,245 @@
from matplotlib.widgets import RectangleSelector
from ...viewer.canvastools.base import CanvasToolBase
from ...viewer.canvastools.base import ToolHandles
__all__ = ['RectangleTool']
class RectangleTool(CanvasToolBase, RectangleSelector):
"""Widget for selecting a rectangular region in a plot.
After making the desired selection, press "Enter" to accept the selection
and call the `on_enter` callback function.
Parameters
----------
manager : Viewer or PlotPlugin.
Skimage viewer or plot plugin object.
on_move : function
Function called whenever a control handle is moved.
This function must accept the rectangle extents as the only argument.
on_release : function
Function called whenever the control handle is released.
on_enter : function
Function called whenever the "enter" key is pressed.
maxdist : float
Maximum pixel distance allowed when selecting control handle.
rect_props : dict
Properties for :class:`matplotlib.patches.Rectangle`. This class
redefines defaults in :class:`matplotlib.widgets.RectangleSelector`.
Attributes
----------
extents : tuple
Rectangle extents: (xmin, xmax, ymin, ymax).
Examples
----------
>>> from skimage import data
>>> from skimage.viewer import ImageViewer
>>> from skimage.viewer.canvastools import RectangleTool
>>> from skimage.draw import line
>>> from skimage.draw import set_color
>>> viewer = ImageViewer(data.coffee()) # doctest: +SKIP
>>> def print_the_rect(extents):
... global viewer
... im = viewer.image
... coord = np.int64(extents)
... [rr1, cc1] = line(coord[2],coord[0],coord[2],coord[1])
... [rr2, cc2] = line(coord[2],coord[1],coord[3],coord[1])
... [rr3, cc3] = line(coord[3],coord[1],coord[3],coord[0])
... [rr4, cc4] = line(coord[3],coord[0],coord[2],coord[0])
... set_color(im, (rr1, cc1), [255, 255, 0])
... set_color(im, (rr2, cc2), [0, 255, 255])
... set_color(im, (rr3, cc3), [255, 0, 255])
... set_color(im, (rr4, cc4), [0, 0, 0])
... viewer.image=im
>>> rect_tool = RectangleTool(viewer, on_enter=print_the_rect) # doctest: +SKIP
>>> viewer.show() # doctest: +SKIP
"""
def __init__(self, manager, on_move=None, on_release=None, on_enter=None,
maxdist=10, rect_props=None):
self._rect = None
props = dict(edgecolor=None, facecolor='r', alpha=0.15)
props.update(rect_props if rect_props is not None else {})
if props['edgecolor'] is None:
props['edgecolor'] = props['facecolor']
RectangleSelector.__init__(self, manager.ax, lambda *args: None,
rectprops=props)
CanvasToolBase.__init__(self, manager, on_move=on_move,
on_enter=on_enter, on_release=on_release)
# Events are handled by the viewer
try:
self.disconnect_events()
except AttributeError:
# disconnect the events manually (hack for older mpl versions)
[self.canvas.mpl_disconnect(i) for i in range(10)]
# Alias rectangle attribute, which is initialized in RectangleSelector.
self._rect = self.to_draw
self._rect.set_animated(True)
self.maxdist = maxdist
self.active_handle = None
self._extents_on_press = None
if on_enter is None:
def on_enter(extents):
print("(xmin=%.3g, xmax=%.3g, ymin=%.3g, ymax=%.3g)" % extents)
self.callback_on_enter = on_enter
props = dict(mec=props['edgecolor'])
self._corner_order = ['NW', 'NE', 'SE', 'SW']
xc, yc = self.corners
self._corner_handles = ToolHandles(self.ax, xc, yc, marker_props=props)
self._edge_order = ['W', 'N', 'E', 'S']
xe, ye = self.edge_centers
self._edge_handles = ToolHandles(self.ax, xe, ye, marker='s',
marker_props=props)
self.artists = [self._rect,
self._corner_handles.artist,
self._edge_handles.artist]
self.manager.add_tool(self)
@property
def _rect_bbox(self):
if not self._rect:
return 0, 0, 0, 0
x0 = self._rect.get_x()
y0 = self._rect.get_y()
width = self._rect.get_width()
height = self._rect.get_height()
return x0, y0, width, height
@property
def corners(self):
"""Corners of rectangle from lower left, moving clockwise."""
x0, y0, width, height = self._rect_bbox
xc = x0, x0 + width, x0 + width, x0
yc = y0, y0, y0 + height, y0 + height
return xc, yc
@property
def edge_centers(self):
"""Midpoint of rectangle edges from left, moving clockwise."""
x0, y0, width, height = self._rect_bbox
w = width / 2.
h = height / 2.
xe = x0, x0 + w, x0 + width, x0 + w
ye = y0 + h, y0, y0 + h, y0 + height
return xe, ye
@property
def extents(self):
"""Return (xmin, xmax, ymin, ymax)."""
x0, y0, width, height = self._rect_bbox
xmin, xmax = sorted([x0, x0 + width])
ymin, ymax = sorted([y0, y0 + height])
return xmin, xmax, ymin, ymax
@extents.setter
def extents(self, extents):
x1, x2, y1, y2 = extents
xmin, xmax = sorted([x1, x2])
ymin, ymax = sorted([y1, y2])
# Update displayed rectangle
self._rect.set_x(xmin)
self._rect.set_y(ymin)
self._rect.set_width(xmax - xmin)
self._rect.set_height(ymax - ymin)
# Update displayed handles
self._corner_handles.set_data(*self.corners)
self._edge_handles.set_data(*self.edge_centers)
self.set_visible(True)
self.redraw()
def on_mouse_release(self, event):
if event.button != 1:
return
if not self.ax.in_axes(event):
self.eventpress = None
return
RectangleSelector.release(self, event)
self._extents_on_press = None
# Undo hiding of rectangle and redraw.
self.set_visible(True)
self.redraw()
self.callback_on_release(self.geometry)
def on_mouse_press(self, event):
if event.button != 1 or not self.ax.in_axes(event):
return
self._set_active_handle(event)
if self.active_handle is None:
# Clear previous rectangle before drawing new rectangle.
self.set_visible(False)
self.redraw()
self.set_visible(True)
RectangleSelector.press(self, event)
def _set_active_handle(self, event):
"""Set active handle based on the location of the mouse event"""
# Note: event.xdata/ydata in data coordinates, event.x/y in pixels
c_idx, c_dist = self._corner_handles.closest(event.x, event.y)
e_idx, e_dist = self._edge_handles.closest(event.x, event.y)
# Set active handle as closest handle, if mouse click is close enough.
if c_dist > self.maxdist and e_dist > self.maxdist:
self.active_handle = None
return
elif c_dist < e_dist:
self.active_handle = self._corner_order[c_idx]
else:
self.active_handle = self._edge_order[e_idx]
# Save coordinates of rectangle at the start of handle movement.
x1, x2, y1, y2 = self.extents
# Switch variables so that only x2 and/or y2 are updated on move.
if self.active_handle in ['W', 'SW', 'NW']:
x1, x2 = x2, event.xdata
if self.active_handle in ['N', 'NW', 'NE']:
y1, y2 = y2, event.ydata
self._extents_on_press = x1, x2, y1, y2
def on_move(self, event):
if self.eventpress is None or not self.ax.in_axes(event):
return
if self.active_handle is None:
# New rectangle
x1 = self.eventpress.xdata
y1 = self.eventpress.ydata
x2, y2 = event.xdata, event.ydata
else:
x1, x2, y1, y2 = self._extents_on_press
if self.active_handle in ['E', 'W'] + self._corner_order:
x2 = event.xdata
if self.active_handle in ['N', 'S'] + self._corner_order:
y2 = event.ydata
self.extents = (x1, x2, y1, y2)
self.callback_on_move(self.geometry)
@property
def geometry(self):
return self.extents
if __name__ == '__main__': # pragma: no cover
from ...viewer import ImageViewer
from ... import data
viewer = ImageViewer(data.camera())
rect_tool = RectangleTool(viewer)
viewer.show()
print("Final selection:")
rect_tool.callback_on_enter(rect_tool.extents)

View file

@ -0,0 +1,9 @@
from .base import Plugin
from .canny import CannyPlugin
from .color_histogram import ColorHistogram
from .crop import Crop
from .labelplugin import LabelPainter
from .lineprofile import LineProfile
from .measure import Measure
from .overlayplugin import OverlayPlugin
from .plotplugin import PlotPlugin

View file

@ -0,0 +1,261 @@
"""
Base class for Plugins that interact with ImageViewer.
"""
from warnings import warn
import numpy as np
from ..qt import QtWidgets, QtCore, Signal
from ..utils import RequiredAttr, init_qtapp
class Plugin(QtWidgets.QDialog):
"""Base class for plugins that interact with an ImageViewer.
A plugin connects an image filter (or another function) to an image viewer.
Note that a Plugin is initialized *without* an image viewer and attached in
a later step. See example below for details.
Parameters
----------
image_viewer : ImageViewer
Window containing image used in measurement/manipulation.
image_filter : function
Function that gets called to update image in image viewer. This value
can be `None` if, for example, you have a plugin that extracts
information from an image and doesn't manipulate it. Alternatively,
this function can be defined as a method in a Plugin subclass.
height, width : int
Size of plugin window in pixels. Note that Qt will automatically resize
a window to fit components. So if you're adding rows of components, you
can leave `height = 0` and just let Qt determine the final height.
useblit : bool
If True, use blitting to speed up animation. Only available on some
Matplotlib backends. If None, set to True when using Agg backend.
This only has an effect if you draw on top of an image viewer.
Attributes
----------
image_viewer : ImageViewer
Window containing image used in measurement.
name : str
Name of plugin. This is displayed as the window title.
artist : list
List of Matplotlib artists and canvastools. Any artists created by the
plugin should be added to this list so that it gets cleaned up on
close.
Examples
--------
>>> from skimage.viewer import ImageViewer
>>> from skimage.viewer.widgets import Slider
>>> from skimage import data
>>>
>>> plugin = Plugin(image_filter=lambda img,
... threshold: img > threshold) # doctest: +SKIP
>>> plugin += Slider('threshold', 0, 255) # doctest: +SKIP
>>>
>>> image = data.coins()
>>> viewer = ImageViewer(image) # doctest: +SKIP
>>> viewer += plugin # doctest: +SKIP
>>> thresholded = viewer.show()[0][0] # doctest: +SKIP
The plugin will automatically delegate parameters to `image_filter` based
on its parameter type, i.e., `ptype` (widgets for required arguments must
be added in the order they appear in the function). The image attached
to the viewer is **automatically passed as the first argument** to the
filter function.
#TODO: Add flag so image is not passed to filter function by default.
`ptype = 'kwarg'` is the default for most widgets so it's unnecessary here.
"""
name = 'Plugin'
image_viewer = RequiredAttr("%s is not attached to ImageViewer" % name)
# Signals used when viewers are linked to the Plugin output.
image_changed = Signal(np.ndarray)
_started = Signal(int)
def __init__(self, image_filter=None, height=0, width=400, useblit=True,
dock='bottom'):
init_qtapp()
super(Plugin, self).__init__()
self.dock = dock
self.image_viewer = None
# If subclass defines `image_filter` method ignore input.
if not hasattr(self, 'image_filter'):
self.image_filter = image_filter
elif image_filter is not None:
warn("If the Plugin class defines an `image_filter` method, "
"then the `image_filter` argument is ignored.")
self.setWindowTitle(self.name)
self.layout = QtWidgets.QGridLayout(self)
self.resize(width, height)
self.row = 0
self.arguments = []
self.keyword_arguments = {}
self.useblit = useblit
self.cids = []
self.artists = []
def attach(self, image_viewer):
"""Attach the plugin to an ImageViewer.
Note that the ImageViewer will automatically call this method when the
plugin is added to the ImageViewer. For example::
viewer += Plugin(...)
Also note that `attach` automatically calls the filter function so that
the image matches the filtered value specified by attached widgets.
"""
self.setParent(image_viewer)
self.setWindowFlags(QtCore.Qt.Dialog)
self.image_viewer = image_viewer
self.image_viewer.plugins.append(self)
#TODO: Always passing image as first argument may be bad assumption.
self.arguments = [self.image_viewer.original_image]
# Call filter so that filtered image matches widget values
self.filter_image()
def add_widget(self, widget):
"""Add widget to plugin.
Alternatively, Plugin's `__add__` method is overloaded to add widgets::
plugin += Widget(...)
Widgets can adjust required or optional arguments of filter function or
parameters for the plugin. This is specified by the Widget's `ptype`.
"""
if widget.ptype == 'kwarg':
name = widget.name.replace(' ', '_')
self.keyword_arguments[name] = widget
widget.callback = self.filter_image
elif widget.ptype == 'arg':
self.arguments.append(widget)
widget.callback = self.filter_image
elif widget.ptype == 'plugin':
widget.callback = self.update_plugin
widget.plugin = self
self.layout.addWidget(widget, self.row, 0)
self.row += 1
def __add__(self, widget):
self.add_widget(widget)
return self
def filter_image(self, *widget_arg):
"""Call `image_filter` with widget args and kwargs
Note: `display_filtered_image` is automatically called.
"""
# `widget_arg` is passed by the active widget but is unused since all
# filter arguments are pulled directly from attached the widgets.
if self.image_filter is None:
return
arguments = [self._get_value(a) for a in self.arguments]
kwargs = {name: self._get_value(a)
for name, a in self.keyword_arguments.items()}
filtered = self.image_filter(*arguments, **kwargs)
self.display_filtered_image(filtered)
self.image_changed.emit(filtered)
def _get_value(self, param):
# If param is a widget, return its `val` attribute.
return param if not hasattr(param, 'val') else param.val
def _update_original_image(self, image):
"""Update the original image argument passed to the filter function.
This method is called by the viewer when the original image is updated.
"""
self.arguments[0] = image
self._on_new_image(image)
self.filter_image()
def _on_new_image(self, image):
"""Override this method to update your plugin for new images."""
pass
@property
def filtered_image(self):
"""Return filtered image."""
return self.image_viewer.image
def display_filtered_image(self, image):
"""Display the filtered image on image viewer.
If you don't want to simply replace the displayed image with the
filtered image (e.g., you want to display a transparent overlay),
you can override this method.
"""
self.image_viewer.image = image
def update_plugin(self, name, value):
"""Update keyword parameters of the plugin itself.
These parameters will typically be implemented as class properties so
that they update the image or some other component.
"""
setattr(self, name, value)
def show(self, main_window=True):
"""Show plugin."""
super(Plugin, self).show()
self.activateWindow()
self.raise_()
# Emit signal with x-hint so new windows can be displayed w/o overlap.
size = self.frameGeometry()
x_hint = size.x() + size.width()
self._started.emit(x_hint)
def closeEvent(self, event):
"""On close disconnect all artists and events from ImageViewer.
Note that artists must be appended to `self.artists`.
"""
self.clean_up()
self.close()
def clean_up(self):
self.remove_image_artists()
if self in self.image_viewer.plugins:
self.image_viewer.plugins.remove(self)
self.image_viewer.reset_image()
self.image_viewer.redraw()
def remove_image_artists(self):
"""Remove artists that are connected to the image viewer."""
for a in self.artists:
a.remove()
def output(self):
"""Return the plugin's representation and data.
Returns
-------
image : array, same shape as ``self.image_viewer.image``, or None
The filtered image.
data : None
Any data associated with the plugin.
Notes
-----
Derived classes should override this method to return a tuple
containing an *overlay* of the same shape of the image, and a
*data* object. Either of these is optional: return ``None`` if
you don't want to return a value.
"""
return (self.image_viewer.image, None)

View file

@ -0,0 +1,30 @@
import numpy as np
import skimage
from ...feature import canny
from .overlayplugin import OverlayPlugin
from ..widgets import Slider, ComboBox
class CannyPlugin(OverlayPlugin):
"""Canny filter plugin to show edges of an image."""
name = 'Canny Filter'
def __init__(self, *args, **kwargs):
super(CannyPlugin, self).__init__(image_filter=canny, **kwargs)
def attach(self, image_viewer):
image = image_viewer.image
imin, imax = skimage.dtype_limits(image, clip_negative=False)
itype = 'float' if np.issubdtype(image.dtype, np.floating) else 'int'
self.add_widget(Slider('sigma', 0, 5, update_on='release'))
self.add_widget(Slider('low threshold', imin, imax, value_type=itype,
update_on='release'))
self.add_widget(Slider('high threshold', imin, imax, value_type=itype,
update_on='release'))
self.add_widget(ComboBox('color', self.color_names, ptype='plugin'))
# Call parent method at end b/c it calls `filter_image`, which needs
# the values specified by the widgets. Alternatively, move call to
# parent method to beginning and add a call to `self.filter_image()`
super(CannyPlugin,self).attach(image_viewer)

View file

@ -0,0 +1,93 @@
import numpy as np
import matplotlib.pyplot as plt
from ... import color, exposure
from .plotplugin import PlotPlugin
from ..canvastools import RectangleTool
class ColorHistogram(PlotPlugin):
name = 'Color Histogram'
def __init__(self, max_pct=0.99, **kwargs):
super(ColorHistogram, self).__init__(height=400, **kwargs)
self.max_pct = max_pct
print(self.help())
def attach(self, image_viewer):
super(ColorHistogram, self).attach(image_viewer)
self.rect_tool = RectangleTool(self,
on_release=self.ab_selected)
self._on_new_image(image_viewer.image)
def _on_new_image(self, image):
self.lab_image = color.rgb2lab(image)
# Calculate color histogram in the Lab colorspace:
L, a, b = self.lab_image.T
left, right = -100, 100
ab_extents = [left, right, right, left]
self.mask = np.ones(L.shape, bool)
bins = np.arange(left, right)
hist, x_edges, y_edges = np.histogram2d(a.flatten(), b.flatten(),
bins, normed=True)
self.data = {'bins': bins, 'hist': hist, 'edges': (x_edges, y_edges),
'extents': (left, right, left, right)}
# Clip bin heights that dominate a-b histogram
max_val = pct_total_area(hist, percentile=self.max_pct)
hist = exposure.rescale_intensity(hist, in_range=(0, max_val))
self.ax.imshow(hist, extent=ab_extents, cmap=plt.cm.gray)
self.ax.set_title('Color Histogram')
self.ax.set_xlabel('b')
self.ax.set_ylabel('a')
def help(self):
helpstr = ("Color Histogram tool:",
"Select region of a-b colorspace to highlight on image.")
return '\n'.join(helpstr)
def ab_selected(self, extents):
x0, x1, y0, y1 = extents
self.data['extents'] = extents
lab_masked = self.lab_image.copy()
L, a, b = lab_masked.T
self.mask = ((a > y0) & (a < y1)) & ((b > x0) & (b < x1))
lab_masked[..., 1:][~self.mask.T] = 0
self.image_viewer.image = color.lab2rgb(lab_masked)
def output(self):
"""Return the image mask and the histogram data.
Returns
-------
mask : array of bool, same shape as image
The selected pixels.
data : dict
The data describing the histogram and the selected region.
The dictionary contains:
- 'bins' : array of float
The bin boundaries for both `a` and `b` channels.
- 'hist' : 2D array of float
The normalized histogram.
- 'edges' : tuple of array of float
The bin edges along each dimension
- 'extents' : tuple of float
The left and right and top and bottom of the selected region.
"""
return (self.mask, self.data)
def pct_total_area(image, percentile=0.80):
"""Return threshold value based on percentage of total area.
The specified percent of pixels less than the given intensity threshold.
"""
idx = int((image.size - 1) * percentile)
sorted_pixels = np.sort(image.flat)
return sorted_pixels[idx]

View file

@ -0,0 +1,45 @@
from .base import Plugin
from ..canvastools import RectangleTool
from ...viewer.widgets import SaveButtons, Button
__all__ = ['Crop']
class Crop(Plugin):
name = 'Crop'
def __init__(self, maxdist=10, **kwargs):
super(Crop, self).__init__(**kwargs)
self.maxdist = maxdist
self.add_widget(SaveButtons())
print(self.help())
def attach(self, image_viewer):
super(Crop, self).attach(image_viewer)
self.rect_tool = RectangleTool(image_viewer,
maxdist=self.maxdist,
on_enter=self.crop)
self.artists.append(self.rect_tool)
self.reset_button = Button('Reset', self.reset)
self.add_widget(self.reset_button)
def help(self):
helpstr = ("Crop tool",
"Select rectangular region and press enter to crop.")
return '\n'.join(helpstr)
def crop(self, extents):
xmin, xmax, ymin, ymax = extents
if xmin == xmax or ymin == ymax:
return
image = self.image_viewer.image[ymin:ymax+1, xmin:xmax+1]
self.image_viewer.image = image
self.image_viewer.ax.relim()
def reset(self):
self.rect_tool.extents = -10, -10, -10, -10
self.image_viewer.image = self.image_viewer.original_image
self.image_viewer.ax.relim()

View file

@ -0,0 +1,67 @@
import numpy as np
from .base import Plugin
from ..widgets import ComboBox, Slider
from ..canvastools import PaintTool
__all__ = ['LabelPainter']
rad2deg = 180 / np.pi
class LabelPainter(Plugin):
name = 'LabelPainter'
def __init__(self, max_radius=20, **kwargs):
super(LabelPainter, self).__init__(**kwargs)
# These widgets adjust plugin properties instead of an image filter.
self._radius_widget = Slider('radius', low=1, high=max_radius,
value=5, value_type='int', ptype='plugin')
labels = [str(i) for i in range(6)]
labels[0] = 'Erase'
self._label_widget = ComboBox('label', labels, ptype='plugin')
self.add_widget(self._radius_widget)
self.add_widget(self._label_widget)
print(self.help())
def help(self):
helpstr = ("Label painter",
"Hold left-mouse button and paint on canvas.")
return '\n'.join(helpstr)
def attach(self, image_viewer):
super(LabelPainter, self).attach(image_viewer)
image = image_viewer.original_image
self.paint_tool = PaintTool(image_viewer, image.shape,
on_enter=self.on_enter)
self.paint_tool.radius = self.radius
self.paint_tool.label = self._label_widget.index = 1
self.artists.append(self.paint_tool)
def _on_new_image(self, image):
"""Update plugin for new images."""
self.paint_tool.shape = image.shape
def on_enter(self, overlay):
pass
@property
def radius(self):
return self._radius_widget.val
@radius.setter
def radius(self, val):
self.paint_tool.radius = val
@property
def label(self):
return self._label_widget.val
@label.setter
def label(self, val):
self.paint_tool.label = val

View file

@ -0,0 +1,165 @@
import numpy as np
from ...util.dtype import dtype_range
from ... import draw, measure
from .plotplugin import PlotPlugin
from ..canvastools import ThickLineTool
__all__ = ['LineProfile']
class LineProfile(PlotPlugin):
"""Plugin to compute interpolated intensity under a scan line on an image.
See PlotPlugin and Plugin classes for additional details.
Parameters
----------
maxdist : float
Maximum pixel distance allowed when selecting end point of scan line.
limits : tuple or {None, 'image', 'dtype'}
(minimum, maximum) intensity limits for plotted profile. The following
special values are defined:
None : rescale based on min/max intensity along selected scan line.
'image' : fixed scale based on min/max intensity in image.
'dtype' : fixed scale based on min/max intensity of image dtype.
"""
name = 'Line Profile'
def __init__(self, maxdist=10, epsilon='deprecated',
limits='image', **kwargs):
super(LineProfile, self).__init__(**kwargs)
self.maxdist = maxdist
self._limit_type = limits
print(self.help())
def attach(self, image_viewer):
super(LineProfile, self).attach(image_viewer)
image = image_viewer.original_image
if self._limit_type == 'image':
self.limits = (np.min(image), np.max(image))
elif self._limit_type == 'dtype':
self.limits = dtype_range[image.dtype.type]
elif self._limit_type is None or len(self._limit_type) == 2:
self.limits = self._limit_type
else:
raise ValueError("Unrecognized `limits`: %s" % self._limit_type)
if not self._limit_type is None:
self.ax.set_ylim(self.limits)
h, w = image.shape[0:2]
x = [w / 3, 2 * w / 3]
y = [h / 2] * 2
self.line_tool = ThickLineTool(self.image_viewer,
maxdist=self.maxdist,
on_move=self.line_changed,
on_change=self.line_changed)
self.line_tool.end_points = np.transpose([x, y])
scan_data = measure.profile_line(image,
*self.line_tool.end_points[:, ::-1])
self.scan_data = scan_data
if scan_data.ndim == 1:
scan_data = scan_data[:, np.newaxis]
self.reset_axes(scan_data)
self._autoscale_view()
def help(self):
helpstr = ("Line profile tool",
"+ and - keys or mouse scroll changes width of scan line.",
"Select and drag ends of the scan line to adjust it.")
return '\n'.join(helpstr)
def get_profiles(self):
"""Return intensity profile of the selected line.
Returns
-------
end_points: (2, 2) array
The positions ((x1, y1), (x2, y2)) of the line ends.
profile: list of 1d arrays
Profile of intensity values. Length 1 (grayscale) or 3 (rgb).
"""
self._update_data()
profiles = [data.get_ydata() for data in self.profile]
return self.line_tool.end_points, profiles
def _autoscale_view(self):
if self.limits is None:
self.ax.autoscale_view(tight=True)
else:
self.ax.autoscale_view(scaley=False, tight=True)
def line_changed(self, end_points):
x, y = np.transpose(end_points)
self.line_tool.end_points = end_points
self._update_data()
self.ax.relim()
self._autoscale_view()
self.redraw()
def _update_data(self):
scan = measure.profile_line(self.image_viewer.image,
*self.line_tool.end_points[:, ::-1],
linewidth=self.line_tool.linewidth)
self.scan_data = scan
if scan.ndim == 1:
scan = scan[:, np.newaxis]
if scan.shape[1] != len(self.profile):
self.reset_axes(scan)
for i in range(len(scan[0])):
self.profile[i].set_xdata(np.arange(scan.shape[0]))
self.profile[i].set_ydata(scan[:, i])
def reset_axes(self, scan_data):
# Clear lines out
for line in self.ax.lines:
self.ax.lines = []
if scan_data.shape[1] == 1:
self.profile = self.ax.plot(scan_data, 'k-')
else:
self.profile = self.ax.plot(scan_data[:, 0], 'r-',
scan_data[:, 1], 'g-',
scan_data[:, 2], 'b-')
def output(self):
"""Return the drawn line and the resulting scan.
Returns
-------
line_image : (M, N) uint8 array, same shape as image
An array of 0s with the scanned line set to 255.
If the linewidth of the line tool is greater than 1,
sets the values within the profiled polygon to 128.
scan : (P,) or (P, 3) array of int or float
The line scan values across the image.
"""
end_points = self.line_tool.end_points
line_image = np.zeros(self.image_viewer.image.shape[:2],
np.uint8)
width = self.line_tool.linewidth
if width > 1:
rp, cp = measure.profile._line_profile_coordinates(
*end_points[:, ::-1], linewidth=width)
# the points are aliased, so create a polygon using the corners
yp = np.rint(rp[[0, 0, -1, -1],[0, -1, -1, 0]]).astype(int)
xp = np.rint(cp[[0, 0, -1, -1],[0, -1, -1, 0]]).astype(int)
rp, cp = draw.polygon(yp, xp, line_image.shape)
line_image[rp, cp] = 128
(x1, y1), (x2, y2) = end_points.astype(int)
rr, cc = draw.line(y1, x1, y2, x2)
line_image[rr, cc] = 255
return line_image, self.scan_data

View file

@ -0,0 +1,49 @@
import numpy as np
from .base import Plugin
from ..widgets import Text
from ..canvastools import LineTool
__all__ = ['Measure']
rad2deg = 180 / np.pi
class Measure(Plugin):
name = 'Measure'
def __init__(self, maxdist=10, **kwargs):
super(Measure, self).__init__(**kwargs)
self.maxdist = maxdist
self._length = Text('Length:')
self._angle = Text('Angle:')
self.add_widget(self._length)
self.add_widget(self._angle)
print(self.help())
def attach(self, image_viewer):
super(Measure, self).attach(image_viewer)
image = image_viewer.original_image
h, w = image.shape
self.line_tool = LineTool(self.image_viewer,
maxdist=self.maxdist,
on_move=self.line_changed)
self.artists.append(self.line_tool)
def help(self):
helpstr = ("Measure tool",
"Select line to measure distance and angle.")
return '\n'.join(helpstr)
def line_changed(self, end_points):
x, y = np.transpose(end_points)
dx = np.diff(x)[0]
dy = np.diff(y)[0]
self._length.text = '%.1f' % np.hypot(dx, dy)
self._angle.text = '%.1f°' % (180 - np.arctan2(dy, dx) * rad2deg)

View file

@ -0,0 +1,115 @@
from ...util.dtype import dtype_range
from .base import Plugin
from ..utils import ClearColormap, update_axes_image
from ..._shared.version_requirements import is_installed
__all__ = ['OverlayPlugin']
class OverlayPlugin(Plugin):
"""Plugin for ImageViewer that displays an overlay on top of main image.
The base Plugin class displays the filtered image directly on the viewer.
OverlayPlugin will instead overlay an image with a transparent colormap.
See base Plugin class for additional details.
Attributes
----------
overlay : array
Overlay displayed on top of image. This overlay defaults to a color map
with alpha values varying linearly from 0 to 1.
color : int
Color of overlay.
"""
colors = {'red': (1, 0, 0),
'yellow': (1, 1, 0),
'green': (0, 1, 0),
'cyan': (0, 1, 1)}
def __init__(self, **kwargs):
super(OverlayPlugin, self).__init__(**kwargs)
self._overlay_plot = None
self._overlay = None
self.cmap = None
self.color_names = sorted(list(self.colors.keys()))
def attach(self, image_viewer):
super(OverlayPlugin, self).attach(image_viewer)
#TODO: `color` doesn't update GUI widget when set manually.
self.color = 0
@property
def overlay(self):
return self._overlay
@overlay.setter
def overlay(self, image):
self._overlay = image
ax = self.image_viewer.ax
if image is None:
ax.images.remove(self._overlay_plot)
self._overlay_plot = None
elif self._overlay_plot is None:
vmin, vmax = dtype_range[image.dtype.type]
self._overlay_plot = ax.imshow(image, cmap=self.cmap,
vmin=vmin, vmax=vmax)
else:
update_axes_image(self._overlay_plot, image)
if self.image_viewer.useblit:
self.image_viewer._blit_manager.background = None
self.image_viewer.redraw()
@property
def color(self):
return self._color
@color.setter
def color(self, index):
# Update colormap whenever color is changed.
if isinstance(index, str) and \
index not in self.color_names:
raise ValueError("%s not defined in OverlayPlugin.colors" % index)
else:
name = self.color_names[index]
self._color = name
rgb = self.colors[name]
self.cmap = ClearColormap(rgb)
if self._overlay_plot is not None:
self._overlay_plot.set_cmap(self.cmap)
self.image_viewer.redraw()
@property
def filtered_image(self):
"""Return filtered image.
This "filtered image" is used when saving from the plugin.
"""
return self.overlay
def display_filtered_image(self, image):
"""Display filtered image as an overlay on top of image in viewer."""
self.overlay = image
def closeEvent(self, event):
# clear overlay from ImageViewer on close
self.overlay = None
super(OverlayPlugin, self).closeEvent(event)
def output(self):
"""Return the overlaid image.
Returns
-------
overlay : array, same shape as image
The overlay currently displayed.
data : None
"""
return (self.overlay, None)

View file

@ -0,0 +1,74 @@
import numpy as np
from ..qt import QtGui
from ..utils import new_plot
from ..utils.canvas import BlitManager, EventManager
from .base import Plugin
__all__ = ['PlotPlugin']
class PlotPlugin(Plugin):
"""Plugin for ImageViewer that contains a plot canvas.
Base class for plugins that contain a Matplotlib plot canvas, which can,
for example, display an image histogram.
See base Plugin class for additional details.
"""
def __init__(self, image_filter=None, height=150, width=400, **kwargs):
super(PlotPlugin, self).__init__(image_filter=image_filter,
height=height, width=width, **kwargs)
self._height = height
self._width = width
self._blit_manager = None
self._tools = []
self._event_manager = None
def attach(self, image_viewer):
super(PlotPlugin, self).attach(image_viewer)
# Add plot for displaying intensity profile.
self.add_plot()
if image_viewer.useblit:
self._blit_manager = BlitManager(self.ax)
self._event_manager = EventManager(self.ax)
def redraw(self):
"""Redraw plot."""
self.canvas.draw_idle()
def add_plot(self):
self.fig, self.ax = new_plot()
self.fig.set_figwidth(self._width / float(self.fig.dpi))
self.fig.set_figheight(self._height / float(self.fig.dpi))
self.canvas = self.fig.canvas
#TODO: Converted color is slightly different than Qt background.
qpalette = QtGui.QPalette()
qcolor = qpalette.color(QtGui.QPalette.Window)
bgcolor = qcolor.toRgb().value()
if np.isscalar(bgcolor):
bgcolor = str(bgcolor / 255.)
self.fig.patch.set_facecolor(bgcolor)
self.layout.addWidget(self.canvas, self.row, 0)
def _update_original_image(self, image):
super(PlotPlugin, self)._update_original_image(image)
self.redraw()
def add_tool(self, tool):
if self._blit_manager:
self._blit_manager.add_artists(tool.artists)
self._tools.append(tool)
self._event_manager.attach(tool)
def remove_tool(self, tool):
if tool not in self._tools:
return
if self._blit_manager:
self._blit_manager.remove_artists(tool.artists)
self._tools.remove(tool)
self._event_manager.detach(tool)

View file

@ -0,0 +1,44 @@
_qt_version = None
has_qt = True
try:
from matplotlib.backends.qt_compat import QtGui, QtCore, QtWidgets, QT_RC_MAJOR_VERSION as _qt_version
except ImportError:
try:
from matplotlib.backends.qt4_compat import QtGui, QtCore
QtWidgets = QtGui
_qt_version = 4
except ImportError:
# Mock objects
class QtGui_cls(object):
QMainWindow = object
QDialog = object
QWidget = object
class QtCore_cls(object):
class Qt(object):
TopDockWidgetArea = None
BottomDockWidgetArea = None
LeftDockWidgetArea = None
RightDockWidgetArea = None
def Signal(self, *args, **kwargs):
pass
QtGui = QtWidgets = QtGui_cls()
QtCore = QtCore_cls()
has_qt = False
if _qt_version == 5:
from matplotlib.backends.backend_qt5 import FigureManagerQT
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
elif _qt_version == 4:
from matplotlib.backends.backend_qt4 import FigureManagerQT
from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg
else:
FigureManagerQT = object
FigureCanvasQTAgg = object
Qt = QtCore.Qt
Signal = QtCore.Signal

View file

@ -0,0 +1,9 @@
from ..._shared.testing import setup_test, teardown_test
def setup():
setup_test()
def teardown():
teardown_test()

View file

@ -0,0 +1,181 @@
import numpy as np
from skimage import util
import skimage.data as data
from skimage.filters.rank import median
from skimage.morphology import disk
from skimage.viewer import ImageViewer, has_qt
from skimage.viewer.plugins.base import Plugin
from skimage.viewer.widgets import Slider
from skimage.viewer.plugins import (
LineProfile, Measure, CannyPlugin, LabelPainter, Crop, ColorHistogram,
PlotPlugin)
from skimage._shared import testing
from skimage._shared.testing import (assert_equal, assert_allclose,
assert_almost_equal)
def setup_line_profile(image, limits='image'):
viewer = ImageViewer(util.img_as_float(image))
plugin = LineProfile(limits=limits)
viewer += plugin
return plugin
@testing.skipif(not has_qt, reason="Qt not installed")
def test_line_profile():
""" Test a line profile using an ndim=2 image"""
plugin = setup_line_profile(data.camera())
line_image, scan_data = plugin.output()
for inp in [line_image.nonzero()[0].size,
line_image.sum() / line_image.max(),
scan_data.size]:
assert_equal(inp, 172)
assert_equal(line_image.shape, (512, 512))
assert_allclose(scan_data.max(), 0.9176, rtol=1e-3)
assert_allclose(scan_data.mean(), 0.2812, rtol=1e-3)
@testing.skipif(not has_qt, reason="Qt not installed")
def test_line_profile_rgb():
""" Test a line profile using an ndim=3 image"""
plugin = setup_line_profile(data.chelsea(), limits=None)
for i in range(6):
plugin.line_tool._thicken_scan_line()
line_image, scan_data = plugin.output()
assert_equal(line_image[line_image == 128].size, 750)
assert_equal(line_image[line_image == 255].size, 151)
assert_equal(line_image.shape, (300, 451))
assert_equal(scan_data.shape, (151, 3))
assert_allclose(scan_data.max(), 0.772, rtol=1e-3)
assert_allclose(scan_data.mean(), 0.4359, rtol=1e-3)
@testing.skipif(not has_qt, reason="Qt not installed")
def test_line_profile_dynamic():
"""Test a line profile updating after an image transform"""
image = data.coins()[:-50, :] # shave some off to make the line lower
image = util.img_as_float(image)
viewer = ImageViewer(image)
lp = LineProfile(limits='dtype')
viewer += lp
line = lp.get_profiles()[-1][0]
assert line.size == 129
assert_almost_equal(np.std(viewer.image), 0.208, 3)
assert_almost_equal(np.std(line), 0.229, 3)
assert_almost_equal(np.max(line) - np.min(line), 0.725, 1)
viewer.image = util.img_as_float(
median(util.img_as_ubyte(image), selem=disk(radius=3)))
line = lp.get_profiles()[-1][0]
assert_almost_equal(np.std(viewer.image), 0.198, 3)
assert_almost_equal(np.std(line), 0.220, 3)
assert_almost_equal(np.max(line) - np.min(line), 0.639, 1)
@testing.skipif(not has_qt, reason="Qt not installed")
def test_measure():
image = data.camera()
viewer = ImageViewer(image)
m = Measure()
viewer += m
m.line_changed([(0, 0), (10, 10)])
assert_equal(str(m._length.text), '14.1')
assert_equal(str(m._angle.text[:5]), '135.0')
@testing.skipif(not has_qt, reason="Qt not installed")
def test_canny():
image = data.camera()
viewer = ImageViewer(image)
c = CannyPlugin()
viewer += c
canny_edges = viewer.show(False)
viewer.close()
edges = canny_edges[0][0]
assert edges.sum() == 2846
@testing.skipif(not has_qt, reason="Qt not installed")
def test_label_painter():
image = data.camera()
moon = data.moon()
viewer = ImageViewer(image)
lp = LabelPainter()
viewer += lp
assert_equal(lp.radius, 5)
lp.label = 1
assert_equal(str(lp.label), '1')
lp.label = 2
assert_equal(str(lp.paint_tool.label), '2')
assert_equal(lp.paint_tool.radius, 5)
lp._on_new_image(moon)
assert_equal(lp.paint_tool.shape, moon.shape)
@testing.skipif(not has_qt, reason="Qt not installed")
def test_crop():
image = data.camera()
viewer = ImageViewer(image)
c = Crop()
viewer += c
c.crop((0, 100, 0, 100))
assert_equal(viewer.image.shape, (101, 101))
@testing.skipif(not has_qt, reason="Qt not installed")
def test_color_histogram():
image = util.img_as_float(data.colorwheel())
viewer = ImageViewer(image)
ch = ColorHistogram(dock='right')
viewer += ch
assert_almost_equal(viewer.image.std(), 0.352, 3),
ch.ab_selected((0, 100, 0, 100)),
assert_almost_equal(viewer.image.std(), 0.325, 3)
@testing.skipif(not has_qt, reason="Qt not installed")
def test_plot_plugin():
viewer = ImageViewer(data.moon())
plugin = PlotPlugin(image_filter=lambda x: x)
viewer += plugin
assert_equal(viewer.image, data.moon())
plugin._update_original_image(data.coins())
assert_equal(viewer.image, data.coins())
viewer.close()
@testing.skipif(not has_qt, reason="Qt not installed")
def test_plugin():
img = util.img_as_float(data.moon())
viewer = ImageViewer(img)
def median_filter(img, radius=3):
return median(
util.img_as_ubyte(img), selem=disk(radius=radius))
plugin = Plugin(image_filter=median_filter)
viewer += plugin
plugin += Slider('radius', 1, 5)
assert_almost_equal(np.std(viewer.image), 12.556, 3)
plugin.filter_image()
assert_almost_equal(np.std(viewer.image), 12.931, 3)
plugin.show()
plugin.close()
plugin.clean_up()
img, _ = plugin.output()
assert_equal(img, viewer.image)

View file

@ -0,0 +1,212 @@
from collections import namedtuple
import numpy as np
from skimage import data
from skimage.viewer import ImageViewer, has_qt
from skimage.viewer.canvastools import (
LineTool, ThickLineTool, RectangleTool, PaintTool)
from skimage.viewer.canvastools.base import CanvasToolBase
from skimage._shared import testing
from skimage._shared.testing import assert_equal, parametrize
try:
from matplotlib.testing.decorators import cleanup
except ImportError:
def cleanup(func):
return func
def get_end_points(image):
h, w = image.shape[0:2]
x = [w / 3, 2 * w / 3]
y = [h / 2] * 2
return np.transpose([x, y])
def do_event(viewer, etype, button=1, xdata=0, ydata=0, key=None):
"""
*name*
the event name
*canvas*
the FigureCanvas instance generating the event
*guiEvent*
the GUI event that triggered the matplotlib event
*x*
x position - pixels from left of canvas
*y*
y position - pixels from bottom of canvas
*inaxes*
the :class:`~matplotlib.axes.Axes` instance if mouse is over axes
*xdata*
x coord of mouse in data coords
*ydata*
y coord of mouse in data coords
*button*
button pressed None, 1, 2, 3, 'up', 'down' (up and down are used
for scroll events)
*key*
the key depressed when the mouse event triggered (see
:class:`KeyEvent`)
*step*
number of scroll steps (positive for 'up', negative for 'down')
"""
ax = viewer.ax
event = namedtuple('Event',
('name canvas guiEvent x y inaxes xdata ydata '
'button key step'))
event.button = button
event.x, event.y = ax.transData.transform((xdata, ydata))
event.xdata, event.ydata = xdata, ydata
event.inaxes = ax
event.canvas = ax.figure.canvas
event.key = key
event.step = 1
event.guiEvent = None
event.name = 'Custom'
func = getattr(viewer._event_manager, 'on_%s' % etype)
func(event)
@cleanup
@testing.skipif(not has_qt, reason="Qt not installed")
def test_line_tool():
img = data.camera()
viewer = ImageViewer(img)
tool = LineTool(viewer, maxdist=10, line_props=dict(linewidth=3),
handle_props=dict(markersize=5))
tool.end_points = get_end_points(img)
assert_equal(tool.end_points, np.array([[170, 256], [341, 256]]))
# grab a handle and move it
do_event(viewer, 'mouse_press', xdata=170, ydata=256)
do_event(viewer, 'move', xdata=180, ydata=260)
do_event(viewer, 'mouse_release')
assert_equal(tool.geometry, np.array([[180, 260], [341, 256]]))
# create a new line
do_event(viewer, 'mouse_press', xdata=10, ydata=10)
do_event(viewer, 'move', xdata=100, ydata=100)
do_event(viewer, 'mouse_release')
assert_equal(tool.geometry, np.array([[100, 100], [10, 10]]))
@cleanup
@testing.skipif(not has_qt, reason="Qt not installed")
def test_thick_line_tool():
img = data.camera()
viewer = ImageViewer(img)
tool = ThickLineTool(viewer, maxdist=10, line_props=dict(color='red'),
handle_props=dict(markersize=5))
tool.end_points = get_end_points(img)
do_event(viewer, 'scroll', button='up')
assert_equal(tool.linewidth, 2)
do_event(viewer, 'scroll', button='down')
assert_equal(tool.linewidth, 1)
do_event(viewer, 'key_press', key='+')
assert_equal(tool.linewidth, 2)
do_event(viewer, 'key_press', key='-')
assert_equal(tool.linewidth, 1)
@cleanup
@testing.skipif(not has_qt, reason="Qt not installed")
def test_rect_tool():
img = data.camera()
viewer = ImageViewer(img)
tool = RectangleTool(viewer, maxdist=10)
tool.extents = (100, 150, 100, 150)
assert_equal(tool.corners,
((100, 150, 150, 100), (100, 100, 150, 150)))
assert_equal(tool.extents, (100, 150, 100, 150))
assert_equal(tool.edge_centers,
((100, 125.0, 150, 125.0), (125.0, 100, 125.0, 150)))
assert_equal(tool.geometry, (100, 150, 100, 150))
# grab a corner and move it
do_event(viewer, 'mouse_press', xdata=100, ydata=100)
do_event(viewer, 'move', xdata=120, ydata=120)
do_event(viewer, 'mouse_release')
# assert_equal(tool.geometry, [120, 150, 120, 150])
# create a new line
do_event(viewer, 'mouse_press', xdata=10, ydata=10)
do_event(viewer, 'move', xdata=100, ydata=100)
do_event(viewer, 'mouse_release')
assert_equal(tool.geometry, [10, 100, 10, 100])
@cleanup
@testing.skipif(not has_qt, reason="Qt not installed")
@parametrize('img', [data.moon(), data.astronaut()])
def test_paint_tool(img):
viewer = ImageViewer(img)
tool = PaintTool(viewer, img.shape)
tool.radius = 10
assert_equal(tool.radius, 10)
tool.label = 2
assert_equal(tool.label, 2)
assert_equal(tool.shape, img.shape[:2])
do_event(viewer, 'mouse_press', xdata=100, ydata=100)
do_event(viewer, 'move', xdata=110, ydata=110)
do_event(viewer, 'mouse_release')
assert_equal(tool.overlay[tool.overlay == 2].size, 761)
tool.label = 5
do_event(viewer, 'mouse_press', xdata=20, ydata=20)
do_event(viewer, 'move', xdata=40, ydata=40)
do_event(viewer, 'mouse_release')
assert_equal(tool.overlay[tool.overlay == 5].size, 881)
assert_equal(tool.overlay[tool.overlay == 2].size, 761)
do_event(viewer, 'key_press', key='enter')
tool.overlay = tool.overlay * 0
assert_equal(tool.overlay.sum(), 0)
assert_equal(tool.cmap.N, tool._overlay_plot.norm.vmax)
@cleanup
@testing.skipif(not has_qt, reason="Qt not installed")
def test_base_tool():
img = data.moon()
viewer = ImageViewer(img)
tool = CanvasToolBase(viewer)
tool.set_visible(False)
tool.set_visible(True)
do_event(viewer, 'key_press', key='enter')
tool.redraw()
tool.remove()
tool = CanvasToolBase(viewer, useblit=False)
tool.redraw()

View file

@ -0,0 +1,40 @@
from skimage.viewer import utils
from skimage.viewer.utils import dialogs
from skimage.viewer.qt import QtCore, QtWidgets, has_qt
from skimage._shared import testing
@testing.skipif(not has_qt, reason="Qt not installed")
def test_event_loop():
utils.init_qtapp()
timer = QtCore.QTimer()
timer.singleShot(10, QtWidgets.QApplication.quit)
utils.start_qtapp()
@testing.skipif(not has_qt, reason="Qt not installed")
def test_format_filename():
fname = dialogs._format_filename(('apple', 2))
assert fname == 'apple'
fname = dialogs._format_filename('')
assert fname is None
@testing.skipif(True, reason="Can't automatically close window. See #3081.")
@testing.skipif(not has_qt, reason="Qt not installed")
def test_open_file_dialog():
QApp = utils.init_qtapp()
timer = QtCore.QTimer()
timer.singleShot(100, lambda: QApp.quit())
filename = dialogs.open_file_dialog()
assert filename is None
@testing.skipif(True, reason="Can't automatically close window. See #3081.")
@testing.skipif(not has_qt, reason="Qt not installed")
def test_save_file_dialog():
QApp = utils.init_qtapp()
timer = QtCore.QTimer()
timer.singleShot(100, lambda: QApp.quit())
filename = dialogs.save_file_dialog()
assert filename is None

View file

@ -0,0 +1,79 @@
from skimage import data
from skimage.transform import pyramid_gaussian
from skimage.filters import sobel
from skimage.viewer.qt import QtGui, QtCore, has_qt
from skimage.viewer import ImageViewer, CollectionViewer
from skimage.viewer.plugins import OverlayPlugin
from skimage._shared.version_requirements import is_installed
from skimage._shared import testing
from skimage._shared.testing import assert_equal
@testing.skipif(not has_qt, reason="Qt not installed")
def test_viewer():
astro = data.astronaut()
coins = data.coins()
view = ImageViewer(astro)
import tempfile
_, filename = tempfile.mkstemp(suffix='.png')
view.show(False)
view.close()
view.save_to_file(filename)
view.open_file(filename)
assert_equal(view.image, astro)
view.image = coins
assert_equal(view.image, coins),
view.save_to_file(filename),
view.open_file(filename),
view.reset_image(),
assert_equal(view.image, coins)
def make_key_event(key):
return QtGui.QKeyEvent(QtCore.QEvent.KeyPress, key,
QtCore.Qt.NoModifier)
@testing.skipif(not has_qt, reason="Qt not installed")
def test_collection_viewer():
img = data.astronaut()
img_collection = tuple(pyramid_gaussian(img, multichannel=True))
view = CollectionViewer(img_collection)
make_key_event(48)
view.update_index('', 2),
assert_equal(view.image, img_collection[2])
view.keyPressEvent(make_key_event(53))
assert_equal(view.image, img_collection[5])
view._format_coord(10, 10)
@testing.skipif(not has_qt, reason="Qt not installed")
@testing.skipif(not is_installed('matplotlib', '>=1.2'),
reason="matplotlib < 1.2")
def test_viewer_with_overlay():
img = data.coins()
ov = OverlayPlugin(image_filter=sobel)
viewer = ImageViewer(img)
viewer += ov
import tempfile
_, filename = tempfile.mkstemp(suffix='.png')
ov.color = 3
assert_equal(ov.color, 'yellow')
viewer.save_to_file(filename)
ov.display_filtered_image(img)
assert_equal(ov.overlay, img)
ov.overlay = None
assert_equal(ov.overlay, None)
ov.overlay = img
assert_equal(ov.overlay, img)
assert_equal(ov.filtered_image, img)

View file

@ -0,0 +1,129 @@
import os
from skimage import data, img_as_float, io, img_as_uint
from skimage.viewer import ImageViewer
from skimage.viewer.qt import QtWidgets, QtCore, has_qt
from skimage.viewer.widgets import (
Slider, OKCancelButtons, SaveButtons, ComboBox, CheckBox, Text)
from skimage.viewer.plugins.base import Plugin
from skimage._shared import testing
from skimage._shared.testing import assert_almost_equal, assert_equal
def get_image_viewer():
image = data.coins()
viewer = ImageViewer(img_as_float(image))
viewer += Plugin()
return viewer
@testing.skipif(not has_qt, reason="Qt not installed")
def test_check_box():
viewer = get_image_viewer()
cb = CheckBox('hello', value=True, alignment='left')
viewer.plugins[0] += cb
assert_equal(cb.val, True)
cb.val = False
assert_equal(cb.val, False)
cb.val = 1
assert_equal(cb.val, True)
cb.val = 0
assert_equal(cb.val, False)
@testing.skipif(not has_qt, reason="Qt not installed")
def test_combo_box():
viewer = get_image_viewer()
cb = ComboBox('hello', ('a', 'b', 'c'))
viewer.plugins[0] += cb
assert_equal(str(cb.val), 'a')
assert_equal(cb.index, 0)
cb.index = 2
assert_equal(str(cb.val), 'c'),
assert_equal(cb.index, 2)
@testing.skipif(not has_qt, reason="Qt not installed")
def test_text_widget():
viewer = get_image_viewer()
txt = Text('hello', 'hello, world!')
viewer.plugins[0] += txt
assert_equal(str(txt.text), 'hello, world!')
txt.text = 'goodbye, world!'
assert_equal(str(txt.text), 'goodbye, world!')
@testing.skipif(not has_qt, reason="Qt not installed")
def test_slider_int():
viewer = get_image_viewer()
sld = Slider('radius', 2, 10, value_type='int')
viewer.plugins[0] += sld
assert_equal(sld.val, 4)
sld.val = 6
assert_equal(sld.val, 6)
sld.editbox.setText('5')
sld._on_editbox_changed()
assert_equal(sld.val, 5)
@testing.skipif(not has_qt, reason="Qt not installed")
def test_slider_float():
viewer = get_image_viewer()
sld = Slider('alpha', 2.1, 3.1, value=2.1, value_type='float',
orientation='vertical', update_on='move')
viewer.plugins[0] += sld
assert_equal(sld.val, 2.1)
sld.val = 2.5
assert_almost_equal(sld.val, 2.5, 2)
sld.editbox.setText('0.1')
sld._on_editbox_changed()
assert_almost_equal(sld.val, 2.5, 2)
@testing.skipif(True, reason="Can't automatically close window. See #3081.")
@testing.skipif(not has_qt, reason="Qt not installed")
def test_save_buttons():
viewer = get_image_viewer()
sv = SaveButtons()
viewer.plugins[0] += sv
import tempfile
fid, filename = tempfile.mkstemp(suffix='.png')
os.close(fid)
timer = QtCore.QTimer()
timer.singleShot(100, QtWidgets.QApplication.quit)
# exercise the button clicks
sv.save_stack.click()
sv.save_file.click()
# call the save functions directly
sv.save_to_stack()
sv.save_to_file(filename)
img = data.imread(filename)
assert_almost_equal(img, img_as_uint(viewer.image))
img = io.pop()
assert_almost_equal(img, viewer.image)
os.remove(filename)
@testing.skipif(not has_qt, reason="Qt not installed")
def test_ok_buttons():
viewer = get_image_viewer()
ok = OKCancelButtons()
viewer.plugins[0] += ok
ok.update_original_image(),
ok.close_plugin()

View file

@ -0,0 +1 @@
from .core import *

View file

@ -0,0 +1,106 @@
class BlitManager(object):
"""Object that manages blits on an axes"""
def __init__(self, ax):
self.ax = ax
self.canvas = ax.figure.canvas
self.canvas.mpl_connect('draw_event', self.on_draw_event)
self.ax = ax
self.background = None
self.artists = []
def add_artists(self, artists):
self.artists.extend(artists)
self.redraw()
def remove_artists(self, artists):
for artist in artists:
self.artists.remove(artist)
def on_draw_event(self, event=None):
self.background = self.canvas.copy_from_bbox(self.ax.bbox)
self.draw_artists()
def redraw(self):
if self.background is not None:
self.canvas.restore_region(self.background)
self.draw_artists()
self.canvas.blit(self.ax.bbox)
else:
self.canvas.draw_idle()
def draw_artists(self):
for artist in self.artists:
self.ax.draw_artist(artist)
class EventManager(object):
"""Object that manages events on a canvas"""
def __init__(self, ax):
self.canvas = ax.figure.canvas
self.connect_event('button_press_event', self.on_mouse_press)
self.connect_event('key_press_event', self.on_key_press)
self.connect_event('button_release_event', self.on_mouse_release)
self.connect_event('motion_notify_event', self.on_move)
self.connect_event('scroll_event', self.on_scroll)
self.tools = []
self.active_tool = None
def connect_event(self, name, handler):
self.canvas.mpl_connect(name, handler)
def attach(self, tool):
self.tools.append(tool)
self.active_tool = tool
def detach(self, tool):
self.tools.remove(tool)
if self.tools:
self.active_tool = self.tools[-1]
else:
self.active_tool = None
def on_mouse_press(self, event):
for tool in self.tools:
if not tool.ignore(event) and tool.hit_test(event):
self.active_tool = tool
break
if self.active_tool and not self.active_tool.ignore(event):
self.active_tool.on_mouse_press(event)
return
for tool in reversed(self.tools):
if not tool.ignore(event):
self.active_tool = tool
tool.on_mouse_press(event)
return
def on_key_press(self, event):
tool = self._get_tool(event)
if tool is not None:
tool.on_key_press(event)
def _get_tool(self, event):
if not self.tools or self.active_tool.ignore(event):
return None
return self.active_tool
def on_mouse_release(self, event):
tool = self._get_tool(event)
if tool is not None:
tool.on_mouse_release(event)
def on_move(self, event):
tool = self._get_tool(event)
if tool is not None:
tool.on_move(event)
def on_scroll(self, event):
tool = self._get_tool(event)
if tool is not None:
tool.on_scroll(event)

View file

@ -0,0 +1,212 @@
import numpy as np
from ..qt import QtWidgets, has_qt, FigureManagerQT, FigureCanvasQTAgg
from ..._shared.utils import warn
import matplotlib as mpl
from matplotlib.figure import Figure
from matplotlib import _pylab_helpers
from matplotlib.colors import LinearSegmentedColormap
if has_qt and 'agg' not in mpl.get_backend().lower():
warn("Recommended matplotlib backend is `Agg` for full "
"skimage.viewer functionality.")
__all__ = ['init_qtapp', 'start_qtapp', 'RequiredAttr', 'figimage',
'LinearColormap', 'ClearColormap', 'FigureCanvas', 'new_plot',
'update_axes_image']
QApp = None
def init_qtapp():
"""Initialize QAppliction.
The QApplication needs to be initialized before creating any QWidgets
"""
global QApp
QApp = QtWidgets.QApplication.instance()
if QApp is None:
QApp = QtWidgets.QApplication([])
return QApp
def is_event_loop_running(app=None):
"""Return True if event loop is running."""
if app is None:
app = init_qtapp()
if hasattr(app, '_in_event_loop'):
return app._in_event_loop
else:
return False
def start_qtapp(app=None):
"""Start Qt mainloop"""
if app is None:
app = init_qtapp()
if not is_event_loop_running(app):
app._in_event_loop = True
app.exec_()
app._in_event_loop = False
else:
app._in_event_loop = True
class RequiredAttr(object):
"""A class attribute that must be set before use."""
instances = dict()
def __init__(self, init_val=None):
self.instances[self, None] = init_val
def __get__(self, obj, objtype):
value = self.instances[self, obj]
if value is None:
raise AttributeError('Required attribute not set')
return value
def __set__(self, obj, value):
self.instances[self, obj] = value
class LinearColormap(LinearSegmentedColormap):
"""LinearSegmentedColormap in which color varies smoothly.
This class is a simplification of LinearSegmentedColormap, which doesn't
support jumps in color intensities.
Parameters
----------
name : str
Name of colormap.
segmented_data : dict
Dictionary of 'red', 'green', 'blue', and (optionally) 'alpha' values.
Each color key contains a list of `x`, `y` tuples. `x` must increase
monotonically from 0 to 1 and corresponds to input values for a
mappable object (e.g. an image). `y` corresponds to the color
intensity.
"""
def __init__(self, name, segmented_data, **kwargs):
segmented_data = {key: [(x, y, y) for x, y in value]
for key, value in segmented_data.items()}
LinearSegmentedColormap.__init__(self, name, segmented_data, **kwargs)
class ClearColormap(LinearColormap):
"""Color map that varies linearly from alpha = 0 to 1
"""
def __init__(self, rgb, max_alpha=1, name='clear_color'):
r, g, b = rgb
cg_speq = {'blue': [(0.0, b), (1.0, b)],
'green': [(0.0, g), (1.0, g)],
'red': [(0.0, r), (1.0, r)],
'alpha': [(0.0, 0.0), (1.0, max_alpha)]}
LinearColormap.__init__(self, name, cg_speq)
class FigureCanvas(FigureCanvasQTAgg):
"""Canvas for displaying images."""
def __init__(self, figure, **kwargs):
self.fig = figure
FigureCanvasQTAgg.__init__(self, self.fig)
FigureCanvasQTAgg.setSizePolicy(self,
QtWidgets.QSizePolicy.Expanding,
QtWidgets.QSizePolicy.Expanding)
FigureCanvasQTAgg.updateGeometry(self)
def resizeEvent(self, event):
FigureCanvasQTAgg.resizeEvent(self, event)
# Call to `resize_event` missing in FigureManagerQT.
# See https://github.com/matplotlib/matplotlib/pull/1585
self.resize_event()
def new_canvas(*args, **kwargs):
"""Return a new figure canvas."""
allnums = _pylab_helpers.Gcf.figs.keys()
num = max(allnums) + 1 if allnums else 1
FigureClass = kwargs.pop('FigureClass', Figure)
figure = FigureClass(*args, **kwargs)
canvas = FigureCanvas(figure)
fig_manager = FigureManagerQT(canvas, num)
return fig_manager.canvas
def new_plot(parent=None, subplot_kw=None, **fig_kw):
"""Return new figure and axes.
Parameters
----------
parent : QtWidget
Qt widget that displays the plot objects. If None, you must manually
call ``canvas.setParent`` and pass the parent widget.
subplot_kw : dict
Keyword arguments passed ``matplotlib.figure.Figure.add_subplot``.
fig_kw : dict
Keyword arguments passed ``matplotlib.figure.Figure``.
"""
if subplot_kw is None:
subplot_kw = {}
canvas = new_canvas(**fig_kw)
canvas.setParent(parent)
fig = canvas.figure
ax = fig.add_subplot(1, 1, 1, **subplot_kw)
return fig, ax
def figimage(image, scale=1, dpi=None, **kwargs):
"""Return figure and axes with figure tightly surrounding image.
Unlike pyplot.figimage, this actually plots onto an axes object, which
fills the figure. Plotting the image onto an axes allows for subsequent
overlays of axes artists.
Parameters
----------
image : array
image to plot
scale : float
If scale is 1, the figure and axes have the same dimension as the
image. Smaller values of `scale` will shrink the figure.
dpi : int
Dots per inch for figure. If None, use the default rcParam.
"""
dpi = dpi if dpi is not None else mpl.rcParams['figure.dpi']
kwargs.setdefault('interpolation', 'nearest')
kwargs.setdefault('cmap', 'gray')
h, w, d = np.atleast_3d(image).shape
figsize = np.array((w, h), dtype=float) / dpi * scale
fig, ax = new_plot(figsize=figsize, dpi=dpi)
fig.subplots_adjust(left=0, bottom=0, right=1, top=1)
ax.set_axis_off()
ax.imshow(image, **kwargs)
ax.figure.canvas.draw()
return fig, ax
def update_axes_image(image_axes, image):
"""Update the image displayed by an image plot.
This sets the image plot's array and updates its shape appropriately
Parameters
----------
image_axes : `matplotlib.image.AxesImage`
Image axes to update.
image : array
Image array.
"""
image_axes.set_array(image)
# Adjust size if new image shape doesn't match the original
h, w = image.shape[:2]
image_axes.set_extent((0, w, h, 0))

View file

@ -0,0 +1,35 @@
import os
from ..qt import QtWidgets
__all__ = ['open_file_dialog', 'save_file_dialog']
def _format_filename(filename):
if isinstance(filename, tuple):
# Handle discrepancy between PyQt4 and PySide APIs.
filename = filename[0]
if len(filename) == 0:
return None
return str(filename)
def open_file_dialog():
"""Return user-selected file path."""
filename = QtWidgets.QFileDialog.getOpenFileName()
filename = _format_filename(filename)
return filename
def save_file_dialog(default_format='png'):
"""Return user-selected file path."""
filename = QtWidgets.QFileDialog.getSaveFileName()
filename = _format_filename(filename)
if filename is None:
return None
# TODO: io plugins should assign default image formats
basename, ext = os.path.splitext(filename)
if not ext:
filename = '%s.%s' % (filename, default_format)
return filename

View file

@ -0,0 +1 @@
from .core import ImageViewer, CollectionViewer

View file

@ -0,0 +1,395 @@
"""
ImageViewer class for viewing and interacting with images.
"""
import numpy as np
from ... import io, img_as_float
from ...util.dtype import dtype_range
from ...exposure import rescale_intensity
from ..qt import QtWidgets, QtGui, Qt, Signal
from ..widgets import Slider
from ..utils import (dialogs, init_qtapp, figimage, start_qtapp,
update_axes_image)
from ..utils.canvas import BlitManager, EventManager
from ..plugins.base import Plugin
__all__ = ['ImageViewer', 'CollectionViewer']
def mpl_image_to_rgba(mpl_image):
"""Return RGB image from the given matplotlib image object.
Each image in a matplotlib figure has its own colormap and normalization
function. Return RGBA (RGB + alpha channel) image with float dtype.
Parameters
----------
mpl_image : matplotlib.image.AxesImage object
The image being converted.
Returns
-------
img : array of float, shape (M, N, 4)
An image of float values in [0, 1].
"""
image = mpl_image.get_array()
if image.ndim == 2:
input_range = (mpl_image.norm.vmin, mpl_image.norm.vmax)
image = rescale_intensity(image, in_range=input_range)
# cmap complains on bool arrays
image = mpl_image.cmap(img_as_float(image))
elif image.ndim == 3 and image.shape[2] == 3:
# add alpha channel if it's missing
image = np.dstack((image, np.ones_like(image)))
return img_as_float(image)
class ImageViewer(QtWidgets.QMainWindow):
"""Viewer for displaying images.
This viewer is a simple container object that holds a Matplotlib axes
for showing images. `ImageViewer` doesn't subclass the Matplotlib axes (or
figure) because of the high probability of name collisions.
Subclasses and plugins will likely extend the `update_image` method to add
custom overlays or filter the displayed image.
Parameters
----------
image : array
Image being viewed.
Attributes
----------
canvas, fig, ax : Matplotlib canvas, figure, and axes
Matplotlib canvas, figure, and axes used to display image.
image : array
Image being viewed. Setting this value will update the displayed frame.
original_image : array
Plugins typically operate on (but don't change) the *original* image.
plugins : list
List of attached plugins.
Examples
--------
>>> from skimage import data
>>> image = data.coins()
>>> viewer = ImageViewer(image) # doctest: +SKIP
>>> viewer.show() # doctest: +SKIP
"""
dock_areas = {'top': Qt.TopDockWidgetArea,
'bottom': Qt.BottomDockWidgetArea,
'left': Qt.LeftDockWidgetArea,
'right': Qt.RightDockWidgetArea}
# Signal that the original image has been changed
original_image_changed = Signal(np.ndarray)
def __init__(self, image, useblit=True):
# Start main loop
init_qtapp()
super(ImageViewer, self).__init__()
# TODO: Add ImageViewer to skimage.io window manager
self.setAttribute(Qt.WA_DeleteOnClose)
self.setWindowTitle("Image Viewer")
self.file_menu = QtWidgets.QMenu('&File', self)
self.file_menu.addAction('Open file', self.open_file,
Qt.CTRL + Qt.Key_O)
self.file_menu.addAction('Save to file', self.save_to_file,
Qt.CTRL + Qt.Key_S)
self.file_menu.addAction('Quit', self.close,
Qt.CTRL + Qt.Key_Q)
self.menuBar().addMenu(self.file_menu)
self.main_widget = QtWidgets.QWidget()
self.setCentralWidget(self.main_widget)
if isinstance(image, Plugin):
plugin = image
image = plugin.filtered_image
plugin.image_changed.connect(self._update_original_image)
# When plugin is started, start
plugin._started.connect(self._show)
self.fig, self.ax = figimage(image)
self.canvas = self.fig.canvas
self.canvas.setParent(self)
self.ax.autoscale(enable=False)
self._tools = []
self.useblit = useblit
if useblit:
self._blit_manager = BlitManager(self.ax)
self._event_manager = EventManager(self.ax)
self._image_plot = self.ax.images[0]
self._update_original_image(image)
self.plugins = []
self.layout = QtWidgets.QVBoxLayout(self.main_widget)
self.layout.addWidget(self.canvas)
status_bar = self.statusBar()
self.status_message = status_bar.showMessage
sb_size = status_bar.sizeHint()
cs_size = self.canvas.sizeHint()
self.resize(cs_size.width(), cs_size.height() + sb_size.height())
self.connect_event('motion_notify_event', self._update_status_bar)
def __add__(self, plugin):
"""Add plugin to ImageViewer"""
plugin.attach(self)
self.original_image_changed.connect(plugin._update_original_image)
if plugin.dock:
location = self.dock_areas[plugin.dock]
dock_location = Qt.DockWidgetArea(location)
dock = QtWidgets.QDockWidget()
dock.setWidget(plugin)
dock.setWindowTitle(plugin.name)
self.addDockWidget(dock_location, dock)
horiz = (self.dock_areas['left'], self.dock_areas['right'])
dimension = 'width' if location in horiz else 'height'
self._add_widget_size(plugin, dimension=dimension)
return self
def _add_widget_size(self, widget, dimension='width'):
widget_size = widget.sizeHint()
viewer_size = self.frameGeometry()
dx = dy = 0
if dimension == 'width':
dx = widget_size.width()
elif dimension == 'height':
dy = widget_size.height()
w = viewer_size.width()
h = viewer_size.height()
self.resize(w + dx, h + dy)
def open_file(self, filename=None):
"""Open image file and display in viewer."""
if filename is None:
filename = dialogs.open_file_dialog()
if filename is None:
return
image = io.imread(filename)
self._update_original_image(image)
def update_image(self, image):
"""Update displayed image.
This method can be overridden or extended in subclasses and plugins to
react to image changes.
"""
self._update_original_image(image)
def _update_original_image(self, image):
self.original_image = image # update saved image
self.image = image.copy() # update displayed image
self.original_image_changed.emit(image)
def save_to_file(self, filename=None):
"""Save current image to file.
The current behavior is not ideal: It saves the image displayed on
screen, so all images will be converted to RGB, and the image size is
not preserved (resizing the viewer window will alter the size of the
saved image).
"""
if filename is None:
filename = dialogs.save_file_dialog()
if filename is None:
return
if len(self.ax.images) == 1:
io.imsave(filename, self.image)
else:
underlay = mpl_image_to_rgba(self.ax.images[0])
overlay = mpl_image_to_rgba(self.ax.images[1])
alpha = overlay[:, :, 3]
# alpha can be set by channel of array or by a scalar value.
# Prefer the alpha channel, but fall back to scalar value.
if np.all(alpha == 1):
alpha = np.ones_like(alpha) * self.ax.images[1].get_alpha()
alpha = alpha[:, :, np.newaxis]
composite = (overlay[:, :, :3] * alpha +
underlay[:, :, :3] * (1 - alpha))
io.imsave(filename, composite)
def closeEvent(self, event):
self.close()
def _show(self, x=0):
self.move(x, 0)
for p in self.plugins:
p.show()
super(ImageViewer, self).show()
self.activateWindow()
self.raise_()
def show(self, main_window=True):
"""Show ImageViewer and attached plugins.
This behaves much like `matplotlib.pyplot.show` and `QWidget.show`.
"""
self._show()
if main_window:
start_qtapp()
return [p.output() for p in self.plugins]
def redraw(self):
if self.useblit:
self._blit_manager.redraw()
else:
self.canvas.draw_idle()
@property
def image(self):
return self._img
@image.setter
def image(self, image):
self._img = image
update_axes_image(self._image_plot, image)
# update display (otherwise image doesn't fill the canvas)
h, w = image.shape[:2]
self.ax.set_xlim(0, w)
self.ax.set_ylim(h, 0)
# update color range
clim = dtype_range[image.dtype.type]
if clim[0] < 0 and image.min() >= 0:
clim = (0, clim[1])
self._image_plot.set_clim(clim)
if self.useblit:
self._blit_manager.background = None
self.redraw()
def reset_image(self):
self.image = self.original_image.copy()
def connect_event(self, event, callback):
"""Connect callback function to matplotlib event and return id."""
cid = self.canvas.mpl_connect(event, callback)
return cid
def disconnect_event(self, callback_id):
"""Disconnect callback by its id (returned by `connect_event`)."""
self.canvas.mpl_disconnect(callback_id)
def _update_status_bar(self, event):
if event.inaxes and event.inaxes.get_navigate():
self.status_message(self._format_coord(event.xdata, event.ydata))
else:
self.status_message('')
def add_tool(self, tool):
if self.useblit:
self._blit_manager.add_artists(tool.artists)
self._tools.append(tool)
self._event_manager.attach(tool)
def remove_tool(self, tool):
if tool not in self._tools:
return
if self.useblit:
self._blit_manager.remove_artists(tool.artists)
self._tools.remove(tool)
self._event_manager.detach(tool)
def _format_coord(self, x, y):
# callback function to format coordinate display in status bar
x = int(x + 0.5)
y = int(y + 0.5)
try:
return "%4s @ [%4s, %4s]" % (self.image[y, x], x, y)
except IndexError:
return ""
class CollectionViewer(ImageViewer):
"""Viewer for displaying image collections.
Select the displayed frame of the image collection using the slider or
with the following keyboard shortcuts:
left/right arrows
Previous/next image in collection.
number keys, 0--9
0% to 90% of collection. For example, "5" goes to the image in the
middle (i.e. 50%) of the collection.
home/end keys
First/last image in collection.
Parameters
----------
image_collection : list of images
List of images to be displayed.
update_on : {'move' | 'release'}
Control whether image is updated on slide or release of the image
slider. Using 'on_release' will give smoother behavior when displaying
large images or when writing a plugin/subclass that requires heavy
computation.
"""
def __init__(self, image_collection, update_on='move', **kwargs):
self.image_collection = image_collection
self.index = 0
self.num_images = len(self.image_collection)
first_image = image_collection[0]
super(CollectionViewer, self).__init__(first_image)
slider_kws = dict(value=0, low=0, high=self.num_images - 1)
slider_kws['update_on'] = update_on
slider_kws['callback'] = self.update_index
slider_kws['value_type'] = 'int'
self.slider = Slider('frame', **slider_kws)
self.layout.addWidget(self.slider)
# TODO: Adjust height to accommodate slider; the following doesn't work
# s_size = self.slider.sizeHint()
# cs_size = self.canvas.sizeHint()
# self.resize(cs_size.width(), cs_size.height() + s_size.height())
def update_index(self, name, index):
"""Select image on display using index into image collection."""
index = int(round(index))
if index == self.index:
return
# clip index value to collection limits
index = max(index, 0)
index = min(index, self.num_images - 1)
self.index = index
self.slider.val = index
self.update_image(self.image_collection[index])
def keyPressEvent(self, event):
if type(event) == QtGui.QKeyEvent:
key = event.key()
# Number keys (code: 0 = key 48, 9 = key 57) move to deciles
if 48 <= key < 58:
index = int(0.1 * (key - 48) * self.num_images)
self.update_index('', index)
event.accept()
else:
event.ignore()
else:
event.ignore()

View file

@ -0,0 +1,20 @@
"""
Widgets for interacting with ImageViewer.
These widgets should be added to a Plugin subclass using its `add_widget`
method or calling::
plugin += Widget(...)
on a Plugin instance. The Plugin will delegate action based on the widget's
parameter type specified by its `ptype` attribute, which can be::
'arg' : positional argument passed to Plugin's `filter_image` method.
'kwarg' : keyword argument passed to Plugin's `filter_image` method.
'plugin' : attribute of Plugin. You'll probably need to add a class
property of the same name that updates the display.
"""
from .core import *
from .history import *

View file

@ -0,0 +1,309 @@
from ..qt import QtWidgets, QtCore, Qt, QtGui
from ..utils import RequiredAttr
__all__ = ['BaseWidget', 'Slider', 'ComboBox', 'CheckBox', 'Text', 'Button']
class BaseWidget(QtWidgets.QWidget):
plugin = RequiredAttr("Widget is not attached to a Plugin.")
def __init__(self, name, ptype=None, callback=None):
super(BaseWidget, self).__init__()
self.name = name
self.ptype = ptype
self.callback = callback
self.plugin = None
@property
def val(self):
msg = "Subclass of BaseWidget requires `val` property"
raise NotImplementedError(msg)
def _value_changed(self, value):
self.callback(self.name, value)
class Text(BaseWidget):
def __init__(self, name=None, text=''):
super(Text, self).__init__(name)
self._label = QtWidgets.QLabel()
self.text = text
self.layout = QtWidgets.QHBoxLayout(self)
if name is not None:
name_label = QtWidgets.QLabel()
name_label.setText(name)
self.layout.addWidget(name_label)
self.layout.addWidget(self._label)
@property
def text(self):
return self._label.text()
@text.setter
def text(self, text_str):
self._label.setText(text_str)
class Slider(BaseWidget):
"""Slider widget for adjusting numeric parameters.
Parameters
----------
name : str
Name of slider parameter. If this parameter is passed as a keyword
argument, it must match the name of that keyword argument (spaces are
replaced with underscores). In addition, this name is displayed as the
name of the slider.
low, high : float
Range of slider values.
value : float
Default slider value. If None, use midpoint between `low` and `high`.
value_type : {'float' | 'int'}, optional
Numeric type of slider value.
ptype : {'kwarg' | 'arg' | 'plugin'}, optional
Parameter type.
callback : callable f(widget_name, value), optional
Callback function called in response to slider changes.
*Note:* This function is typically set (overridden) when the widget is
added to a plugin.
orientation : {'horizontal' | 'vertical'}, optional
Slider orientation.
update_on : {'release' | 'move'}, optional
Control when callback function is called: on slider move or release.
"""
def __init__(self, name, low=0.0, high=1.0, value=None, value_type='float',
ptype='kwarg', callback=None, max_edit_width=60,
orientation='horizontal', update_on='release'):
super(Slider, self).__init__(name, ptype, callback)
if value is None:
value = (high - low) / 2.
# Set widget orientation
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if orientation == 'vertical':
self.slider = QtWidgets.QSlider(Qt.Vertical)
alignment = QtCore.Qt.AlignHCenter
align_text = QtCore.Qt.AlignHCenter
align_value = QtCore.Qt.AlignHCenter
self.layout = QtWidgets.QVBoxLayout(self)
elif orientation == 'horizontal':
self.slider = QtWidgets.QSlider(Qt.Horizontal)
alignment = QtCore.Qt.AlignVCenter
align_text = QtCore.Qt.AlignLeft
align_value = QtCore.Qt.AlignRight
self.layout = QtWidgets.QHBoxLayout(self)
else:
msg = "Unexpected value %s for 'orientation'"
raise ValueError(msg % orientation)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Set slider behavior for float and int values.
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if value_type == 'float':
# divide slider into 1000 discrete values
slider_max = 1000
self._scale = float(high - low) / slider_max
self.slider.setRange(0, slider_max)
self.value_fmt = '%2.2f'
elif value_type == 'int':
self.slider.setRange(low, high)
self.value_fmt = '%d'
else:
msg = "Expected `value_type` to be 'float' or 'int'; received: %s"
raise ValueError(msg % value_type)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
self.value_type = value_type
self._low = low
self._high = high
# Update slider position to default value
self.val = value
if update_on == 'move':
self.slider.valueChanged.connect(self._on_slider_changed)
elif update_on == 'release':
self.slider.sliderReleased.connect(self._on_slider_changed)
else:
raise ValueError("Unexpected value %s for 'update_on'" % update_on)
self.slider.setFocusPolicy(QtCore.Qt.StrongFocus)
self.name_label = QtWidgets.QLabel()
self.name_label.setText(self.name)
self.name_label.setAlignment(align_text)
self.editbox = QtWidgets.QLineEdit()
self.editbox.setMaximumWidth(max_edit_width)
self.editbox.setText(self.value_fmt % self.val)
self.editbox.setAlignment(align_value)
self.editbox.editingFinished.connect(self._on_editbox_changed)
self.layout.addWidget(self.name_label)
self.layout.addWidget(self.slider)
self.layout.addWidget(self.editbox)
def _on_slider_changed(self):
"""Call callback function with slider's name and value as parameters"""
value = self.val
self.editbox.setText(str(value)[:4])
self.callback(self.name, value)
def _on_editbox_changed(self):
"""Validate input and set slider value"""
try:
value = float(self.editbox.text())
except ValueError:
self._bad_editbox_input()
return
if not self._low <= value <= self._high:
self._bad_editbox_input()
return
self.val = value
self._good_editbox_input()
self.callback(self.name, value)
def _good_editbox_input(self):
self.editbox.setStyleSheet("background-color: rgb(255, 255, 255)")
def _bad_editbox_input(self):
self.editbox.setStyleSheet("background-color: rgb(255, 200, 200)")
@property
def val(self):
value = self.slider.value()
if self.value_type == 'float':
value = value * self._scale + self._low
return value
@val.setter
def val(self, value):
if self.value_type == 'float':
value = (value - self._low) / self._scale
self.slider.setValue(value)
class ComboBox(BaseWidget):
"""ComboBox widget for selecting among a list of choices.
Parameters
----------
name : str
Name of ComboBox parameter. If this parameter is passed as a keyword
argument, it must match the name of that keyword argument (spaces are
replaced with underscores). In addition, this name is displayed as the
name of the ComboBox.
items: list of str
Allowed parameter values.
ptype : {'arg' | 'kwarg' | 'plugin'}, optional
Parameter type.
callback : callable f(widget_name, value), optional
Callback function called in response to combobox changes.
*Note:* This function is typically set (overridden) when the widget is
added to a plugin.
"""
def __init__(self, name, items, ptype='kwarg', callback=None):
super(ComboBox, self).__init__(name, ptype, callback)
self.name_label = QtWidgets.QLabel()
self.name_label.setText(self.name)
self.name_label.setAlignment(QtCore.Qt.AlignLeft)
self._combo_box = QtWidgets.QComboBox()
self._combo_box.addItems(list(items))
self.layout = QtWidgets.QHBoxLayout(self)
self.layout.addWidget(self.name_label)
self.layout.addWidget(self._combo_box)
self._combo_box.currentIndexChanged.connect(self._value_changed)
@property
def val(self):
return self._combo_box.currentText()
@property
def index(self):
return self._combo_box.currentIndex()
@index.setter
def index(self, i):
self._combo_box.setCurrentIndex(i)
class CheckBox(BaseWidget):
"""CheckBox widget
Parameters
----------
name : str
Name of CheckBox parameter. If this parameter is passed as a keyword
argument, it must match the name of that keyword argument (spaces are
replaced with underscores). In addition, this name is displayed as the
name of the CheckBox.
value: {False, True}, optional
Initial state of the CheckBox.
alignment: {'center','left','right'}, optional
Checkbox alignment
ptype : {'arg' | 'kwarg' | 'plugin'}, optional
Parameter type
callback : callable f(widget_name, value), optional
Callback function called in response to checkbox changes.
*Note:* This function is typically set (overridden) when the widget is
added to a plugin.
"""
def __init__(self, name, value=False, alignment='center', ptype='kwarg',
callback=None):
super(CheckBox, self).__init__(name, ptype, callback)
self._check_box = QtWidgets.QCheckBox()
self._check_box.setChecked(value)
self._check_box.setText(self.name)
self.layout = QtWidgets.QHBoxLayout(self)
if alignment == 'center':
self.layout.setAlignment(QtCore.Qt.AlignCenter)
elif alignment == 'left':
self.layout.setAlignment(QtCore.Qt.AlignLeft)
elif alignment == 'right':
self.layout.setAlignment(QtCore.Qt.AlignRight)
else:
raise ValueError("Unexpected value %s for 'alignment'" % alignment)
self.layout.addWidget(self._check_box)
self._check_box.stateChanged.connect(self._value_changed)
@property
def val(self):
return self._check_box.isChecked()
@val.setter
def val(self, i):
self._check_box.setChecked(i)
class Button(BaseWidget):
"""Button which calls callback upon click.
Parameters
----------
name : str
Name of button.
callback : callable f()
Function to call when button is clicked.
"""
def __init__(self, name, callback):
super(Button, self).__init__(self)
self._button = QtWidgets.QPushButton(name)
self._button.clicked.connect(callback)
self.layout = QtWidgets.QHBoxLayout(self)
self.layout.addWidget(self._button)

View file

@ -0,0 +1,104 @@
from textwrap import dedent
from ..qt import QtGui, QtCore, QtWidgets
import numpy as np
from ... import io
from ...util import img_as_ubyte
from .core import BaseWidget
from ..utils import dialogs
__all__ = ['OKCancelButtons', 'SaveButtons']
class OKCancelButtons(BaseWidget):
"""Buttons that close the parent plugin.
OK will replace the original image with the current (filtered) image.
Cancel will just close the plugin.
"""
def __init__(self, button_width=80):
name = 'OK/Cancel'
super(OKCancelButtons, self).__init__(name)
self.ok = QtWidgets.QPushButton('OK')
self.ok.clicked.connect(self.update_original_image)
self.ok.setMaximumWidth(button_width)
self.ok.setFocusPolicy(QtCore.Qt.NoFocus)
self.cancel = QtWidgets.QPushButton('Cancel')
self.cancel.clicked.connect(self.close_plugin)
self.cancel.setMaximumWidth(button_width)
self.cancel.setFocusPolicy(QtCore.Qt.NoFocus)
self.layout = QtWidgets.QHBoxLayout(self)
self.layout.addStretch()
self.layout.addWidget(self.cancel)
self.layout.addWidget(self.ok)
def update_original_image(self):
image = self.plugin.image_viewer.image
self.plugin.image_viewer.original_image = image
self.plugin.close()
def close_plugin(self):
# Image viewer will restore original image on close.
self.plugin.close()
class SaveButtons(BaseWidget):
"""Buttons to save image to io.stack or to a file."""
def __init__(self, name='Save to:', default_format='png'):
super(SaveButtons, self).__init__(name)
self.default_format = default_format
self.name_label = QtWidgets.QLabel()
self.name_label.setText(name)
self.save_file = QtWidgets.QPushButton('File')
self.save_file.clicked.connect(self.save_to_file)
self.save_file.setFocusPolicy(QtCore.Qt.NoFocus)
self.save_stack = QtWidgets.QPushButton('Stack')
self.save_stack.clicked.connect(self.save_to_stack)
self.save_stack.setFocusPolicy(QtCore.Qt.NoFocus)
self.layout = QtWidgets.QHBoxLayout(self)
self.layout.addWidget(self.name_label)
self.layout.addWidget(self.save_stack)
self.layout.addWidget(self.save_file)
def save_to_stack(self):
image = self.plugin.filtered_image.copy()
io.push(image)
msg = dedent('''\
The image has been pushed to the io stack.
Use io.pop() to retrieve the most recently pushed image.
NOTE: The io stack only works in interactive sessions.''')
notify(msg)
def save_to_file(self, filename=None):
if not filename:
filename = dialogs.save_file_dialog()
if not filename:
return
image = self.plugin.filtered_image
if image.dtype == np.bool:
# TODO: This check/conversion should probably be in `imsave`.
image = img_as_ubyte(image)
io.imsave(filename, image)
def notify(msg):
msglabel = QtWidgets.QLabel(msg)
dialog = QtWidgets.QDialog()
ok = QtWidgets.QPushButton('OK', dialog)
ok.clicked.connect(dialog.accept)
ok.setDefault(True)
dialog.layout = QtWidgets.QGridLayout(dialog)
dialog.layout.addWidget(msglabel, 0, 0, 1, 3)
dialog.layout.addWidget(ok, 1, 1)
dialog.exec_()