import functools from matplotlib import artist as martist, cbook, transforms as mtransforms from matplotlib.axes import subplot_class_factory from matplotlib.transforms import Bbox from .mpl_axes import Axes import numpy as np class ParasiteAxesBase: def get_images_artists(self): artists = {a for a in self.get_children() if a.get_visible()} images = {a for a in self.images if a.get_visible()} return list(images), list(artists - images) def __init__(self, parent_axes, **kwargs): self._parent_axes = parent_axes kwargs["frameon"] = False super().__init__(parent_axes.figure, parent_axes._position, **kwargs) def cla(self): super().cla() martist.setp(self.get_children(), visible=False) self._get_lines = self._parent_axes._get_lines # In mpl's Axes, zorders of x- and y-axis are originally set # within Axes.draw(). if self._axisbelow: self.xaxis.set_zorder(0.5) self.yaxis.set_zorder(0.5) else: self.xaxis.set_zorder(2.5) self.yaxis.set_zorder(2.5) def pick(self, mouseevent): # This most likely goes to Artist.pick (depending on axes_class given # to the factory), which only handles pick events registered on the # axes associated with each child: super().pick(mouseevent) # But parasite axes are additionally given pick events from their host # axes (cf. HostAxesBase.pick), which we handle here: for a in self.get_children(): if (hasattr(mouseevent.inaxes, "parasites") and self in mouseevent.inaxes.parasites): a.pick(mouseevent) @functools.lru_cache(None) def parasite_axes_class_factory(axes_class=None): if axes_class is None: cbook.warn_deprecated( "3.3", message="Support for passing None to " "parasite_axes_class_factory is deprecated since %(since)s and " "will be removed %(removal)s; explicitly pass the default Axes " "class instead.") axes_class = Axes return type("%sParasite" % axes_class.__name__, (ParasiteAxesBase, axes_class), {}) ParasiteAxes = parasite_axes_class_factory(Axes) class ParasiteAxesAuxTransBase: def __init__(self, parent_axes, aux_transform, viewlim_mode=None, **kwargs): self.transAux = aux_transform self.set_viewlim_mode(viewlim_mode) super().__init__(parent_axes, **kwargs) def _set_lim_and_transforms(self): self.transAxes = self._parent_axes.transAxes self.transData = \ self.transAux + \ self._parent_axes.transData self._xaxis_transform = mtransforms.blended_transform_factory( self.transData, self.transAxes) self._yaxis_transform = mtransforms.blended_transform_factory( self.transAxes, self.transData) def set_viewlim_mode(self, mode): cbook._check_in_list([None, "equal", "transform"], mode=mode) self._viewlim_mode = mode def get_viewlim_mode(self): return self._viewlim_mode def update_viewlim(self): viewlim = self._parent_axes.viewLim.frozen() mode = self.get_viewlim_mode() if mode is None: pass elif mode == "equal": self.axes.viewLim.set(viewlim) elif mode == "transform": self.axes.viewLim.set( viewlim.transformed(self.transAux.inverted())) else: cbook._check_in_list([None, "equal", "transform"], mode=mode) def _pcolor(self, super_pcolor, *XYC, **kwargs): if len(XYC) == 1: C = XYC[0] ny, nx = C.shape gx = np.arange(-0.5, nx) gy = np.arange(-0.5, ny) X, Y = np.meshgrid(gx, gy) else: X, Y, C = XYC if "transform" in kwargs: mesh = super_pcolor(X, Y, C, **kwargs) else: orig_shape = X.shape xyt = np.column_stack([X.flat, Y.flat]) wxy = self.transAux.transform(xyt) gx = wxy[:, 0].reshape(orig_shape) gy = wxy[:, 1].reshape(orig_shape) mesh = super_pcolor(gx, gy, C, **kwargs) mesh.set_transform(self._parent_axes.transData) return mesh def pcolormesh(self, *XYC, **kwargs): return self._pcolor(super().pcolormesh, *XYC, **kwargs) def pcolor(self, *XYC, **kwargs): return self._pcolor(super().pcolor, *XYC, **kwargs) def _contour(self, super_contour, *XYCL, **kwargs): if len(XYCL) <= 2: C = XYCL[0] ny, nx = C.shape gx = np.arange(0., nx) gy = np.arange(0., ny) X, Y = np.meshgrid(gx, gy) CL = XYCL else: X, Y = XYCL[:2] CL = XYCL[2:] if "transform" in kwargs: cont = super_contour(X, Y, *CL, **kwargs) else: orig_shape = X.shape xyt = np.column_stack([X.flat, Y.flat]) wxy = self.transAux.transform(xyt) gx = wxy[:, 0].reshape(orig_shape) gy = wxy[:, 1].reshape(orig_shape) cont = super_contour(gx, gy, *CL, **kwargs) for c in cont.collections: c.set_transform(self._parent_axes.transData) return cont def contour(self, *XYCL, **kwargs): return self._contour(super().contour, *XYCL, **kwargs) def contourf(self, *XYCL, **kwargs): return self._contour(super().contourf, *XYCL, **kwargs) def apply_aspect(self, position=None): self.update_viewlim() super().apply_aspect() @functools.lru_cache(None) def parasite_axes_auxtrans_class_factory(axes_class=None): if axes_class is None: cbook.warn_deprecated( "3.3", message="Support for passing None to " "parasite_axes_auxtrans_class_factory is deprecated since " "%(since)s and will be removed %(removal)s; explicitly pass the " "default ParasiteAxes class instead.") parasite_axes_class = ParasiteAxes elif not issubclass(axes_class, ParasiteAxesBase): parasite_axes_class = parasite_axes_class_factory(axes_class) else: parasite_axes_class = axes_class return type("%sParasiteAuxTrans" % parasite_axes_class.__name__, (ParasiteAxesAuxTransBase, parasite_axes_class), {'name': 'parasite_axes'}) ParasiteAxesAuxTrans = parasite_axes_auxtrans_class_factory(ParasiteAxes) class HostAxesBase: def __init__(self, *args, **kwargs): self.parasites = [] super().__init__(*args, **kwargs) def get_aux_axes(self, tr, viewlim_mode="equal", axes_class=ParasiteAxes): parasite_axes_class = parasite_axes_auxtrans_class_factory(axes_class) ax2 = parasite_axes_class(self, tr, viewlim_mode) # note that ax2.transData == tr + ax1.transData # Anything you draw in ax2 will match the ticks and grids of ax1. self.parasites.append(ax2) ax2._remove_method = self.parasites.remove return ax2 def _get_legend_handles(self, legend_handler_map=None): all_handles = super()._get_legend_handles() for ax in self.parasites: all_handles.extend(ax._get_legend_handles(legend_handler_map)) return all_handles def draw(self, renderer): orig_artists = list(self.artists) orig_images = list(self.images) if hasattr(self, "get_axes_locator"): locator = self.get_axes_locator() if locator: pos = locator(self, renderer) self.set_position(pos, which="active") self.apply_aspect(pos) else: self.apply_aspect() else: self.apply_aspect() rect = self.get_position() for ax in self.parasites: ax.apply_aspect(rect) images, artists = ax.get_images_artists() self.images.extend(images) self.artists.extend(artists) super().draw(renderer) self.artists = orig_artists self.images = orig_images def cla(self): for ax in self.parasites: ax.cla() super().cla() def pick(self, mouseevent): super().pick(mouseevent) # Also pass pick events on to parasite axes and, in turn, their # children (cf. ParasiteAxesBase.pick) for a in self.parasites: a.pick(mouseevent) def twinx(self, axes_class=None): """ Create a twin of Axes with a shared x-axis but independent y-axis. The y-axis of self will have ticks on the left and the returned axes will have ticks on the right. """ if axes_class is None: axes_class = self._get_base_axes() parasite_axes_class = parasite_axes_class_factory(axes_class) ax2 = parasite_axes_class(self, sharex=self, frameon=False) self.parasites.append(ax2) ax2._remove_method = self._remove_twinx self.axis["right"].set_visible(False) ax2.axis["right"].set_visible(True) ax2.axis["left", "top", "bottom"].set_visible(False) return ax2 def _remove_twinx(self, ax): self.parasites.remove(ax) self.axis["right"].set_visible(True) self.axis["right"].toggle(ticklabels=False, label=False) def twiny(self, axes_class=None): """ Create a twin of Axes with a shared y-axis but independent x-axis. The x-axis of self will have ticks on the bottom and the returned axes will have ticks on the top. """ if axes_class is None: axes_class = self._get_base_axes() parasite_axes_class = parasite_axes_class_factory(axes_class) ax2 = parasite_axes_class(self, sharey=self, frameon=False) self.parasites.append(ax2) ax2._remove_method = self._remove_twiny self.axis["top"].set_visible(False) ax2.axis["top"].set_visible(True) ax2.axis["left", "right", "bottom"].set_visible(False) return ax2 def _remove_twiny(self, ax): self.parasites.remove(ax) self.axis["top"].set_visible(True) self.axis["top"].toggle(ticklabels=False, label=False) def twin(self, aux_trans=None, axes_class=None): """ Create a twin of Axes with no shared axis. While self will have ticks on the left and bottom axis, the returned axes will have ticks on the top and right axis. """ if axes_class is None: axes_class = self._get_base_axes() parasite_axes_auxtrans_class = \ parasite_axes_auxtrans_class_factory(axes_class) if aux_trans is None: ax2 = parasite_axes_auxtrans_class( self, mtransforms.IdentityTransform(), viewlim_mode="equal") else: ax2 = parasite_axes_auxtrans_class( self, aux_trans, viewlim_mode="transform") self.parasites.append(ax2) ax2._remove_method = self.parasites.remove self.axis["top", "right"].set_visible(False) ax2.axis["top", "right"].set_visible(True) ax2.axis["left", "bottom"].set_visible(False) def _remove_method(h): self.parasites.remove(h) self.axis["top", "right"].set_visible(True) self.axis["top", "right"].toggle(ticklabels=False, label=False) ax2._remove_method = _remove_method return ax2 def get_tightbbox(self, renderer, call_axes_locator=True, bbox_extra_artists=None): bbs = [ *[ax.get_tightbbox(renderer, call_axes_locator=call_axes_locator) for ax in self.parasites], super().get_tightbbox(renderer, call_axes_locator=call_axes_locator, bbox_extra_artists=bbox_extra_artists)] return Bbox.union([b for b in bbs if b.width != 0 or b.height != 0]) @functools.lru_cache(None) def host_axes_class_factory(axes_class=None): if axes_class is None: cbook.warn_deprecated( "3.3", message="Support for passing None to host_axes_class is " "deprecated since %(since)s and will be removed %(removed)s; " "explicitly pass the default Axes class instead.") axes_class = Axes def _get_base_axes(self): return axes_class return type("%sHostAxes" % axes_class.__name__, (HostAxesBase, axes_class), {'_get_base_axes': _get_base_axes}) def host_subplot_class_factory(axes_class): host_axes_class = host_axes_class_factory(axes_class) subplot_host_class = subplot_class_factory(host_axes_class) return subplot_host_class HostAxes = host_axes_class_factory(Axes) SubplotHost = subplot_class_factory(HostAxes) def host_axes(*args, axes_class=Axes, figure=None, **kwargs): """ Create axes that can act as a hosts to parasitic axes. Parameters ---------- figure : `matplotlib.figure.Figure` Figure to which the axes will be added. Defaults to the current figure `.pyplot.gcf()`. *args, **kwargs Will be passed on to the underlying ``Axes`` object creation. """ import matplotlib.pyplot as plt host_axes_class = host_axes_class_factory(axes_class) if figure is None: figure = plt.gcf() ax = host_axes_class(figure, *args, **kwargs) figure.add_axes(ax) plt.draw_if_interactive() return ax def host_subplot(*args, axes_class=Axes, figure=None, **kwargs): """ Create a subplot that can act as a host to parasitic axes. Parameters ---------- figure : `matplotlib.figure.Figure` Figure to which the subplot will be added. Defaults to the current figure `.pyplot.gcf()`. *args, **kwargs Will be passed on to the underlying ``Axes`` object creation. """ import matplotlib.pyplot as plt host_subplot_class = host_subplot_class_factory(axes_class) if figure is None: figure = plt.gcf() ax = host_subplot_class(figure, *args, **kwargs) figure.add_subplot(ax) plt.draw_if_interactive() return ax