Fixed database typo and removed unnecessary class identifier.

This commit is contained in:
Batuhan Berk Başoğlu 2020-10-14 10:10:37 -04:00
parent 00ad49a143
commit 45fb349a7d
5098 changed files with 952558 additions and 85 deletions

View file

@ -0,0 +1,52 @@
"""
Helper functions for testing.
"""
import locale
import logging
import matplotlib as mpl
from matplotlib import cbook
_log = logging.getLogger(__name__)
@cbook.deprecated("3.2")
def is_called_from_pytest():
"""Whether we are in a pytest run."""
return getattr(mpl, '_called_from_pytest', False)
def set_font_settings_for_testing():
mpl.rcParams['font.family'] = 'DejaVu Sans'
mpl.rcParams['text.hinting'] = 'none'
mpl.rcParams['text.hinting_factor'] = 8
def set_reproducibility_for_testing():
mpl.rcParams['svg.hashsalt'] = 'matplotlib'
def setup():
# The baseline images are created in this locale, so we should use
# it during all of the tests.
try:
locale.setlocale(locale.LC_ALL, 'en_US.UTF-8')
except locale.Error:
try:
locale.setlocale(locale.LC_ALL, 'English_United States.1252')
except locale.Error:
_log.warning(
"Could not set locale to English/United States. "
"Some date-related tests may fail.")
mpl.use('Agg')
with cbook._suppress_matplotlib_deprecation_warning():
mpl.rcdefaults() # Start with all defaults
# These settings *must* be hardcoded for running the comparison tests and
# are not necessarily the default values as specified in rcsetup.py.
set_font_settings_for_testing()
set_reproducibility_for_testing()

View file

@ -0,0 +1,472 @@
"""
Utilities for comparing image results.
"""
import atexit
import hashlib
import os
from pathlib import Path
import re
import shutil
import subprocess
import sys
from tempfile import TemporaryDirectory, TemporaryFile
import numpy as np
from PIL import Image
import matplotlib as mpl
from matplotlib import cbook
from matplotlib.testing.exceptions import ImageComparisonFailure
__all__ = ['compare_images', 'comparable_formats']
def make_test_filename(fname, purpose):
"""
Make a new filename by inserting *purpose* before the file's extension.
"""
base, ext = os.path.splitext(fname)
return '%s-%s%s' % (base, purpose, ext)
def get_cache_dir():
cache_dir = Path(mpl.get_cachedir(), 'test_cache')
cache_dir.mkdir(parents=True, exist_ok=True)
return str(cache_dir)
def get_file_hash(path, block_size=2 ** 20):
md5 = hashlib.md5()
with open(path, 'rb') as fd:
while True:
data = fd.read(block_size)
if not data:
break
md5.update(data)
if Path(path).suffix == '.pdf':
md5.update(str(mpl._get_executable_info("gs").version)
.encode('utf-8'))
elif Path(path).suffix == '.svg':
md5.update(str(mpl._get_executable_info("inkscape").version)
.encode('utf-8'))
return md5.hexdigest()
@cbook.deprecated("3.3")
def make_external_conversion_command(cmd):
def convert(old, new):
cmdline = cmd(old, new)
pipe = subprocess.Popen(cmdline, universal_newlines=True,
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = pipe.communicate()
errcode = pipe.wait()
if not os.path.exists(new) or errcode:
msg = "Conversion command failed:\n%s\n" % ' '.join(cmdline)
if stdout:
msg += "Standard output:\n%s\n" % stdout
if stderr:
msg += "Standard error:\n%s\n" % stderr
raise IOError(msg)
return convert
# Modified from https://bugs.python.org/issue25567.
_find_unsafe_bytes = re.compile(br'[^a-zA-Z0-9_@%+=:,./-]').search
def _shlex_quote_bytes(b):
return (b if _find_unsafe_bytes(b) is None
else b"'" + b.replace(b"'", b"'\"'\"'") + b"'")
class _ConverterError(Exception):
pass
class _Converter:
def __init__(self):
self._proc = None
# Explicitly register deletion from an atexit handler because if we
# wait until the object is GC'd (which occurs later), then some module
# globals (e.g. signal.SIGKILL) has already been set to None, and
# kill() doesn't work anymore...
atexit.register(self.__del__)
def __del__(self):
if self._proc:
self._proc.kill()
self._proc.wait()
for stream in filter(None, [self._proc.stdin,
self._proc.stdout,
self._proc.stderr]):
stream.close()
self._proc = None
def _read_until(self, terminator):
"""Read until the prompt is reached."""
buf = bytearray()
while True:
c = self._proc.stdout.read(1)
if not c:
raise _ConverterError
buf.extend(c)
if buf.endswith(terminator):
return bytes(buf[:-len(terminator)])
class _GSConverter(_Converter):
def __call__(self, orig, dest):
if not self._proc:
self._proc = subprocess.Popen(
[mpl._get_executable_info("gs").executable,
"-dNOSAFER", "-dNOPAUSE", "-sDEVICE=png16m"],
# As far as I can see, ghostscript never outputs to stderr.
stdin=subprocess.PIPE, stdout=subprocess.PIPE)
try:
self._read_until(b"\nGS")
except _ConverterError as err:
raise OSError("Failed to start Ghostscript") from err
def encode_and_escape(name):
return (os.fsencode(name)
.replace(b"\\", b"\\\\")
.replace(b"(", br"\(")
.replace(b")", br"\)"))
self._proc.stdin.write(
b"<< /OutputFile ("
+ encode_and_escape(dest)
+ b") >> setpagedevice ("
+ encode_and_escape(orig)
+ b") run flush\n")
self._proc.stdin.flush()
# GS> if nothing left on the stack; GS<n> if n items left on the stack.
err = self._read_until(b"GS")
stack = self._read_until(b">")
if stack or not os.path.exists(dest):
stack_size = int(stack[1:]) if stack else 0
self._proc.stdin.write(b"pop\n" * stack_size)
# Using the systemencoding should at least get the filenames right.
raise ImageComparisonFailure(
(err + b"GS" + stack + b">")
.decode(sys.getfilesystemencoding(), "replace"))
class _SVGConverter(_Converter):
def __call__(self, orig, dest):
old_inkscape = mpl._get_executable_info("inkscape").version < "1"
terminator = b"\n>" if old_inkscape else b"> "
if not hasattr(self, "_tmpdir"):
self._tmpdir = TemporaryDirectory()
if (not self._proc # First run.
or self._proc.poll() is not None): # Inkscape terminated.
env = {
**os.environ,
# If one passes e.g. a png file to Inkscape, it will try to
# query the user for conversion options via a GUI (even with
# `--without-gui`). Unsetting `DISPLAY` prevents this (and
# causes GTK to crash and Inkscape to terminate, but that'll
# just be reported as a regular exception below).
"DISPLAY": "",
# Do not load any user options.
"INKSCAPE_PROFILE_DIR": os.devnull,
}
# Old versions of Inkscape (e.g. 0.48.3.1) seem to sometimes
# deadlock when stderr is redirected to a pipe, so we redirect it
# to a temporary file instead. This is not necessary anymore as of
# Inkscape 0.92.1.
stderr = TemporaryFile()
self._proc = subprocess.Popen(
["inkscape", "--without-gui", "--shell"] if old_inkscape else
["inkscape", "--shell"],
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=stderr,
env=env, cwd=self._tmpdir.name)
# Slight abuse, but makes shutdown handling easier.
self._proc.stderr = stderr
try:
self._read_until(terminator)
except _ConverterError as err:
raise OSError("Failed to start Inkscape in interactive "
"mode") from err
# Inkscape's shell mode does not support escaping metacharacters in the
# filename ("\n", and ":;" for inkscape>=1). Avoid any problems by
# running from a temporary directory and using fixed filenames.
inkscape_orig = Path(self._tmpdir.name, os.fsdecode(b"f.svg"))
inkscape_dest = Path(self._tmpdir.name, os.fsdecode(b"f.png"))
try:
inkscape_orig.symlink_to(Path(orig).resolve())
except OSError:
shutil.copyfile(orig, inkscape_orig)
self._proc.stdin.write(
b"f.svg --export-png=f.png\n" if old_inkscape else
b"file-open:f.svg;export-filename:f.png;export-do;file-close\n")
self._proc.stdin.flush()
try:
self._read_until(terminator)
except _ConverterError as err:
# Inkscape's output is not localized but gtk's is, so the output
# stream probably has a mixed encoding. Using the filesystem
# encoding should at least get the filenames right...
self._proc.stderr.seek(0)
raise ImageComparisonFailure(
self._proc.stderr.read().decode(
sys.getfilesystemencoding(), "replace")) from err
os.remove(inkscape_orig)
shutil.move(inkscape_dest, dest)
def __del__(self):
super().__del__()
if hasattr(self, "_tmpdir"):
self._tmpdir.cleanup()
def _update_converter():
try:
mpl._get_executable_info("gs")
except mpl.ExecutableNotFoundError:
pass
else:
converter['pdf'] = converter['eps'] = _GSConverter()
try:
mpl._get_executable_info("inkscape")
except mpl.ExecutableNotFoundError:
pass
else:
converter['svg'] = _SVGConverter()
#: A dictionary that maps filename extensions to functions which
#: themselves map arguments `old` and `new` (filenames) to a list of strings.
#: The list can then be passed to Popen to convert files with that
#: extension to png format.
converter = {}
_update_converter()
def comparable_formats():
"""
Return the list of file formats that `.compare_images` can compare
on this system.
Returns
-------
list of str
E.g. ``['png', 'pdf', 'svg', 'eps']``.
"""
return ['png', *converter]
def convert(filename, cache):
"""
Convert the named file to png; return the name of the created file.
If *cache* is True, the result of the conversion is cached in
`matplotlib.get_cachedir() + '/test_cache/'`. The caching is based on a
hash of the exact contents of the input file. There is no limit on the
size of the cache, so it may need to be manually cleared periodically.
"""
path = Path(filename)
if not path.exists():
raise IOError(f"{path} does not exist")
if path.suffix[1:] not in converter:
import pytest
pytest.skip(f"Don't know how to convert {path.suffix} files to png")
newpath = path.parent / f"{path.stem}_{path.suffix[1:]}.png"
# Only convert the file if the destination doesn't already exist or
# is out of date.
if not newpath.exists() or newpath.stat().st_mtime < path.stat().st_mtime:
cache_dir = Path(get_cache_dir()) if cache else None
if cache_dir is not None:
hash_value = get_file_hash(path)
cached_path = cache_dir / (hash_value + newpath.suffix)
if cached_path.exists():
shutil.copyfile(cached_path, newpath)
return str(newpath)
converter[path.suffix[1:]](path, newpath)
if cache_dir is not None:
shutil.copyfile(newpath, cached_path)
return str(newpath)
def crop_to_same(actual_path, actual_image, expected_path, expected_image):
# clip the images to the same size -- this is useful only when
# comparing eps to pdf
if actual_path[-7:-4] == 'eps' and expected_path[-7:-4] == 'pdf':
aw, ah, ad = actual_image.shape
ew, eh, ed = expected_image.shape
actual_image = actual_image[int(aw / 2 - ew / 2):int(
aw / 2 + ew / 2), int(ah / 2 - eh / 2):int(ah / 2 + eh / 2)]
return actual_image, expected_image
def calculate_rms(expected_image, actual_image):
"""
Calculate the per-pixel errors, then compute the root mean square error.
"""
if expected_image.shape != actual_image.shape:
raise ImageComparisonFailure(
"Image sizes do not match expected size: {} "
"actual size {}".format(expected_image.shape, actual_image.shape))
# Convert to float to avoid overflowing finite integer types.
return np.sqrt(((expected_image - actual_image).astype(float) ** 2).mean())
# NOTE: compare_image and save_diff_image assume that the image does not have
# 16-bit depth, as Pillow converts these to RGB incorrectly.
def compare_images(expected, actual, tol, in_decorator=False):
"""
Compare two "image" files checking differences within a tolerance.
The two given filenames may point to files which are convertible to
PNG via the `.converter` dictionary. The underlying RMS is calculated
with the `.calculate_rms` function.
Parameters
----------
expected : str
The filename of the expected image.
actual : str
The filename of the actual image.
tol : float
The tolerance (a color value difference, where 255 is the
maximal difference). The test fails if the average pixel
difference is greater than this value.
in_decorator : bool
Determines the output format. If called from image_comparison
decorator, this should be True. (default=False)
Returns
-------
None or dict or str
Return *None* if the images are equal within the given tolerance.
If the images differ, the return value depends on *in_decorator*.
If *in_decorator* is true, a dict with the following entries is
returned:
- *rms*: The RMS of the image difference.
- *expected*: The filename of the expected image.
- *actual*: The filename of the actual image.
- *diff_image*: The filename of the difference image.
- *tol*: The comparison tolerance.
Otherwise, a human-readable multi-line string representation of this
information is returned.
Examples
--------
::
img1 = "./baseline/plot.png"
img2 = "./output/plot.png"
compare_images(img1, img2, 0.001)
"""
actual = os.fspath(actual)
if not os.path.exists(actual):
raise Exception("Output image %s does not exist." % actual)
if os.stat(actual).st_size == 0:
raise Exception("Output image file %s is empty." % actual)
# Convert the image to png
expected = os.fspath(expected)
if not os.path.exists(expected):
raise IOError('Baseline image %r does not exist.' % expected)
extension = expected.split('.')[-1]
if extension != 'png':
actual = convert(actual, cache=False)
expected = convert(expected, cache=True)
# open the image files and remove the alpha channel (if it exists)
expected_image = np.asarray(Image.open(expected).convert("RGB"))
actual_image = np.asarray(Image.open(actual).convert("RGB"))
actual_image, expected_image = crop_to_same(
actual, actual_image, expected, expected_image)
diff_image = make_test_filename(actual, 'failed-diff')
if tol <= 0:
if np.array_equal(expected_image, actual_image):
return None
# convert to signed integers, so that the images can be subtracted without
# overflow
expected_image = expected_image.astype(np.int16)
actual_image = actual_image.astype(np.int16)
rms = calculate_rms(expected_image, actual_image)
if rms <= tol:
return None
save_diff_image(expected, actual, diff_image)
results = dict(rms=rms, expected=str(expected),
actual=str(actual), diff=str(diff_image), tol=tol)
if not in_decorator:
# Then the results should be a string suitable for stdout.
template = ['Error: Image files did not match.',
'RMS Value: {rms}',
'Expected: \n {expected}',
'Actual: \n {actual}',
'Difference:\n {diff}',
'Tolerance: \n {tol}', ]
results = '\n '.join([line.format(**results) for line in template])
return results
def save_diff_image(expected, actual, output):
"""
Parameters
----------
expected : str
File path of expected image.
actual : str
File path of actual image.
output : str
File path to save difference image to.
"""
# Drop alpha channels, similarly to compare_images.
expected_image = np.asarray(Image.open(expected).convert("RGB"))
actual_image = np.asarray(Image.open(actual).convert("RGB"))
actual_image, expected_image = crop_to_same(
actual, actual_image, expected, expected_image)
expected_image = np.array(expected_image).astype(float)
actual_image = np.array(actual_image).astype(float)
if expected_image.shape != actual_image.shape:
raise ImageComparisonFailure(
"Image sizes do not match expected size: {} "
"actual size {}".format(expected_image.shape, actual_image.shape))
abs_diff_image = np.abs(expected_image - actual_image)
# expand differences in luminance domain
abs_diff_image *= 255 * 10
save_image_np = np.clip(abs_diff_image, 0, 255).astype(np.uint8)
height, width, depth = save_image_np.shape
# The PDF renderer doesn't produce an alpha channel, but the
# matplotlib PNG writer requires one, so expand the array
if depth == 3:
with_alpha = np.empty((height, width, 4), dtype=np.uint8)
with_alpha[:, :, 0:3] = save_image_np
save_image_np = with_alpha
# Hard-code the alpha channel to fully solid
save_image_np[:, :, 3] = 255
Image.fromarray(save_image_np).save(output, format="png")

View file

@ -0,0 +1,137 @@
import pytest
import sys
import matplotlib
from matplotlib import cbook
def pytest_configure(config):
# config is initialized here rather than in pytest.ini so that `pytest
# --pyargs matplotlib` (which would not find pytest.ini) works. The only
# entries in pytest.ini set minversion (which is checked earlier),
# testpaths/python_files, as they are required to properly find the tests
for key, value in [
("markers", "flaky: (Provided by pytest-rerunfailures.)"),
("markers", "timeout: (Provided by pytest-timeout.)"),
("markers", "backend: Set alternate Matplotlib backend temporarily."),
("markers", "style: Set alternate Matplotlib style temporarily."),
("markers", "baseline_images: Compare output against references."),
("markers", "pytz: Tests that require pytz to be installed."),
("markers", "network: Tests that reach out to the network."),
("filterwarnings", "error"),
]:
config.addinivalue_line(key, value)
matplotlib.use('agg', force=True)
matplotlib._called_from_pytest = True
matplotlib._init_tests()
def pytest_unconfigure(config):
matplotlib._called_from_pytest = False
@pytest.fixture(autouse=True)
def mpl_test_settings(request):
from matplotlib.testing.decorators import _cleanup_cm
with _cleanup_cm():
backend = None
backend_marker = request.node.get_closest_marker('backend')
if backend_marker is not None:
assert len(backend_marker.args) == 1, \
"Marker 'backend' must specify 1 backend."
backend, = backend_marker.args
skip_on_importerror = backend_marker.kwargs.get(
'skip_on_importerror', False)
prev_backend = matplotlib.get_backend()
# special case Qt backend importing to avoid conflicts
if backend.lower().startswith('qt4'):
if any(k in sys.modules for k in ('PyQt5', 'PySide2')):
pytest.skip('Qt5 binding already imported')
try:
import PyQt4
# RuntimeError if PyQt5 already imported.
except (ImportError, RuntimeError):
try:
import PySide
except ImportError:
pytest.skip("Failed to import a Qt4 binding.")
elif backend.lower().startswith('qt5'):
if any(k in sys.modules for k in ('PyQt4', 'PySide')):
pytest.skip('Qt4 binding already imported')
try:
import PyQt5
# RuntimeError if PyQt4 already imported.
except (ImportError, RuntimeError):
try:
import PySide2
except ImportError:
pytest.skip("Failed to import a Qt5 binding.")
# Default of cleanup and image_comparison too.
style = ["classic", "_classic_test_patch"]
style_marker = request.node.get_closest_marker('style')
if style_marker is not None:
assert len(style_marker.args) == 1, \
"Marker 'style' must specify 1 style."
style, = style_marker.args
matplotlib.testing.setup()
with cbook._suppress_matplotlib_deprecation_warning():
if backend is not None:
# This import must come after setup() so it doesn't load the
# default backend prematurely.
import matplotlib.pyplot as plt
try:
plt.switch_backend(backend)
except ImportError as exc:
# Should only occur for the cairo backend tests, if neither
# pycairo nor cairocffi are installed.
if 'cairo' in backend.lower() or skip_on_importerror:
pytest.skip("Failed to switch to backend {} ({})."
.format(backend, exc))
else:
raise
matplotlib.style.use(style)
try:
yield
finally:
if backend is not None:
plt.switch_backend(prev_backend)
@pytest.fixture
def mpl_image_comparison_parameters(request, extension):
# This fixture is applied automatically by the image_comparison decorator.
#
# The sole purpose of this fixture is to provide an indirect method of
# obtaining parameters *without* modifying the decorated function
# signature. In this way, the function signature can stay the same and
# pytest won't get confused.
# We annotate the decorated function with any parameters captured by this
# fixture so that they can be used by the wrapper in image_comparison.
baseline_images, = request.node.get_closest_marker('baseline_images').args
if baseline_images is None:
# Allow baseline image list to be produced on the fly based on current
# parametrization.
baseline_images = request.getfixturevalue('baseline_images')
func = request.function
with cbook._setattr_cm(func.__wrapped__,
parameters=(baseline_images, extension)):
yield
@pytest.fixture
def pd():
"""Fixture to import and configure pandas."""
pd = pytest.importorskip('pandas')
try:
from pandas.plotting import (
deregister_matplotlib_converters as deregister)
deregister()
except ImportError:
pass
return pd

View file

@ -0,0 +1,493 @@
import contextlib
from distutils.version import StrictVersion
import functools
import inspect
import os
from pathlib import Path
import shutil
import string
import sys
import unittest
import warnings
try:
from contextlib import nullcontext
except ImportError:
from contextlib import ExitStack as nullcontext # Py3.6.
import matplotlib as mpl
import matplotlib.style
import matplotlib.units
import matplotlib.testing
from matplotlib import cbook
from matplotlib import ft2font
from matplotlib import pyplot as plt
from matplotlib import ticker
from .compare import comparable_formats, compare_images, make_test_filename
from .exceptions import ImageComparisonFailure
@contextlib.contextmanager
def _cleanup_cm():
orig_units_registry = matplotlib.units.registry.copy()
try:
with warnings.catch_warnings(), matplotlib.rc_context():
yield
finally:
matplotlib.units.registry.clear()
matplotlib.units.registry.update(orig_units_registry)
plt.close("all")
class CleanupTestCase(unittest.TestCase):
"""A wrapper for unittest.TestCase that includes cleanup operations."""
@classmethod
def setUpClass(cls):
cls._cm = _cleanup_cm().__enter__()
@classmethod
def tearDownClass(cls):
cls._cm.__exit__(None, None, None)
def cleanup(style=None):
"""
A decorator to ensure that any global state is reset before
running a test.
Parameters
----------
style : str, dict, or list, optional
The style(s) to apply. Defaults to ``["classic",
"_classic_test_patch"]``.
"""
# If cleanup is used without arguments, *style* will be a callable, and we
# pass it directly to the wrapper generator. If cleanup if called with an
# argument, it is a string naming a style, and the function will be passed
# as an argument to what we return. This is a confusing, but somewhat
# standard, pattern for writing a decorator with optional arguments.
def make_cleanup(func):
if inspect.isgeneratorfunction(func):
@functools.wraps(func)
def wrapped_callable(*args, **kwargs):
with _cleanup_cm(), matplotlib.style.context(style):
yield from func(*args, **kwargs)
else:
@functools.wraps(func)
def wrapped_callable(*args, **kwargs):
with _cleanup_cm(), matplotlib.style.context(style):
func(*args, **kwargs)
return wrapped_callable
if callable(style):
result = make_cleanup(style)
# Default of mpl_test_settings fixture and image_comparison too.
style = ["classic", "_classic_test_patch"]
return result
else:
return make_cleanup
def check_freetype_version(ver):
if ver is None:
return True
if isinstance(ver, str):
ver = (ver, ver)
ver = [StrictVersion(x) for x in ver]
found = StrictVersion(ft2font.__freetype_version__)
return ver[0] <= found <= ver[1]
def _checked_on_freetype_version(required_freetype_version):
import pytest
reason = ("Mismatched version of freetype. "
"Test requires '%s', you have '%s'" %
(required_freetype_version, ft2font.__freetype_version__))
return pytest.mark.xfail(
not check_freetype_version(required_freetype_version),
reason=reason, raises=ImageComparisonFailure, strict=False)
def remove_ticks_and_titles(figure):
figure.suptitle("")
null_formatter = ticker.NullFormatter()
for ax in figure.get_axes():
ax.set_title("")
ax.xaxis.set_major_formatter(null_formatter)
ax.xaxis.set_minor_formatter(null_formatter)
ax.yaxis.set_major_formatter(null_formatter)
ax.yaxis.set_minor_formatter(null_formatter)
try:
ax.zaxis.set_major_formatter(null_formatter)
ax.zaxis.set_minor_formatter(null_formatter)
except AttributeError:
pass
def _raise_on_image_difference(expected, actual, tol):
__tracebackhide__ = True
err = compare_images(expected, actual, tol, in_decorator=True)
if err:
for key in ["actual", "expected", "diff"]:
err[key] = os.path.relpath(err[key])
raise ImageComparisonFailure(
('images not close (RMS %(rms).3f):'
'\n\t%(actual)s\n\t%(expected)s\n\t%(diff)s') % err)
def _skip_if_format_is_uncomparable(extension):
import pytest
return pytest.mark.skipif(
extension not in comparable_formats(),
reason='Cannot compare {} files on this system'.format(extension))
def _mark_skip_if_format_is_uncomparable(extension):
import pytest
if isinstance(extension, str):
name = extension
marks = []
elif isinstance(extension, tuple):
# Extension might be a pytest ParameterSet instead of a plain string.
# Unfortunately, this type is not exposed, so since it's a namedtuple,
# check for a tuple instead.
name, = extension.values
marks = [*extension.marks]
else:
# Extension might be a pytest marker instead of a plain string.
name, = extension.args
marks = [extension.mark]
return pytest.param(name,
marks=[*marks, _skip_if_format_is_uncomparable(name)])
class _ImageComparisonBase:
"""
Image comparison base class
This class provides *just* the comparison-related functionality and avoids
any code that would be specific to any testing framework.
"""
def __init__(self, func, tol, remove_text, savefig_kwargs):
self.func = func
self.baseline_dir, self.result_dir = _image_directories(func)
self.tol = tol
self.remove_text = remove_text
self.savefig_kwargs = savefig_kwargs
def copy_baseline(self, baseline, extension):
baseline_path = self.baseline_dir / baseline
orig_expected_path = baseline_path.with_suffix(f'.{extension}')
if extension == 'eps' and not orig_expected_path.exists():
orig_expected_path = orig_expected_path.with_suffix('.pdf')
expected_fname = make_test_filename(
self.result_dir / orig_expected_path.name, 'expected')
try:
# os.symlink errors if the target already exists.
with contextlib.suppress(OSError):
os.remove(expected_fname)
try:
os.symlink(orig_expected_path, expected_fname)
except OSError: # On Windows, symlink *may* be unavailable.
shutil.copyfile(orig_expected_path, expected_fname)
except OSError as err:
raise ImageComparisonFailure(
f"Missing baseline image {expected_fname} because the "
f"following file cannot be accessed: "
f"{orig_expected_path}") from err
return expected_fname
def compare(self, idx, baseline, extension, *, _lock=False):
__tracebackhide__ = True
fignum = plt.get_fignums()[idx]
fig = plt.figure(fignum)
if self.remove_text:
remove_ticks_and_titles(fig)
actual_path = (self.result_dir / baseline).with_suffix(f'.{extension}')
kwargs = self.savefig_kwargs.copy()
if extension == 'pdf':
kwargs.setdefault('metadata',
{'Creator': None, 'Producer': None,
'CreationDate': None})
lock = cbook._lock_path(actual_path) if _lock else nullcontext()
with lock:
fig.savefig(actual_path, **kwargs)
expected_path = self.copy_baseline(baseline, extension)
_raise_on_image_difference(expected_path, actual_path, self.tol)
def _pytest_image_comparison(baseline_images, extensions, tol,
freetype_version, remove_text, savefig_kwargs,
style):
"""
Decorate function with image comparison for pytest.
This function creates a decorator that wraps a figure-generating function
with image comparison code.
"""
import pytest
extensions = map(_mark_skip_if_format_is_uncomparable, extensions)
KEYWORD_ONLY = inspect.Parameter.KEYWORD_ONLY
def decorator(func):
old_sig = inspect.signature(func)
@functools.wraps(func)
@pytest.mark.parametrize('extension', extensions)
@pytest.mark.style(style)
@_checked_on_freetype_version(freetype_version)
@functools.wraps(func)
def wrapper(*args, extension, request, **kwargs):
__tracebackhide__ = True
if 'extension' in old_sig.parameters:
kwargs['extension'] = extension
if 'request' in old_sig.parameters:
kwargs['request'] = request
img = _ImageComparisonBase(func, tol=tol, remove_text=remove_text,
savefig_kwargs=savefig_kwargs)
matplotlib.testing.set_font_settings_for_testing()
func(*args, **kwargs)
# If the test is parametrized in any way other than applied via
# this decorator, then we need to use a lock to prevent two
# processes from touching the same output file.
needs_lock = any(
marker.args[0] != 'extension'
for marker in request.node.iter_markers('parametrize'))
if baseline_images is not None:
our_baseline_images = baseline_images
else:
# Allow baseline image list to be produced on the fly based on
# current parametrization.
our_baseline_images = request.getfixturevalue(
'baseline_images')
assert len(plt.get_fignums()) == len(our_baseline_images), (
"Test generated {} images but there are {} baseline images"
.format(len(plt.get_fignums()), len(our_baseline_images)))
for idx, baseline in enumerate(our_baseline_images):
img.compare(idx, baseline, extension, _lock=needs_lock)
parameters = list(old_sig.parameters.values())
if 'extension' not in old_sig.parameters:
parameters += [inspect.Parameter('extension', KEYWORD_ONLY)]
if 'request' not in old_sig.parameters:
parameters += [inspect.Parameter("request", KEYWORD_ONLY)]
new_sig = old_sig.replace(parameters=parameters)
wrapper.__signature__ = new_sig
# Reach a bit into pytest internals to hoist the marks from our wrapped
# function.
new_marks = getattr(func, 'pytestmark', []) + wrapper.pytestmark
wrapper.pytestmark = new_marks
return wrapper
return decorator
def image_comparison(baseline_images, extensions=None, tol=0,
freetype_version=None, remove_text=False,
savefig_kwarg=None,
# Default of mpl_test_settings fixture and cleanup too.
style=("classic", "_classic_test_patch")):
"""
Compare images generated by the test with those specified in
*baseline_images*, which must correspond, else an `ImageComparisonFailure`
exception will be raised.
Parameters
----------
baseline_images : list or None
A list of strings specifying the names of the images generated by
calls to `.Figure.savefig`.
If *None*, the test function must use the ``baseline_images`` fixture,
either as a parameter or with `pytest.mark.usefixtures`. This value is
only allowed when using pytest.
extensions : None or list of str
The list of extensions to test, e.g. ``['png', 'pdf']``.
If *None*, defaults to all supported extensions: png, pdf, and svg.
When testing a single extension, it can be directly included in the
names passed to *baseline_images*. In that case, *extensions* must not
be set.
In order to keep the size of the test suite from ballooning, we only
include the ``svg`` or ``pdf`` outputs if the test is explicitly
exercising a feature dependent on that backend (see also the
`check_figures_equal` decorator for that purpose).
tol : float, default: 0
The RMS threshold above which the test is considered failed.
Due to expected small differences in floating-point calculations, on
32-bit systems an additional 0.06 is added to this threshold.
freetype_version : str or tuple
The expected freetype version or range of versions for this test to
pass.
remove_text : bool
Remove the title and tick text from the figure before comparison. This
is useful to make the baseline images independent of variations in text
rendering between different versions of FreeType.
This does not remove other, more deliberate, text, such as legends and
annotations.
savefig_kwarg : dict
Optional arguments that are passed to the savefig method.
style : str, dict, or list
The optional style(s) to apply to the image test. The test itself
can also apply additional styles if desired. Defaults to ``["classic",
"_classic_test_patch"]``.
"""
if baseline_images is not None:
# List of non-empty filename extensions.
baseline_exts = [*filter(None, {Path(baseline).suffix[1:]
for baseline in baseline_images})]
if baseline_exts:
if extensions is not None:
raise ValueError(
"When including extensions directly in 'baseline_images', "
"'extensions' cannot be set as well")
if len(baseline_exts) > 1:
raise ValueError(
"When including extensions directly in 'baseline_images', "
"all baselines must share the same suffix")
extensions = baseline_exts
baseline_images = [ # Chop suffix out from baseline_images.
Path(baseline).stem for baseline in baseline_images]
if extensions is None:
# Default extensions to test, if not set via baseline_images.
extensions = ['png', 'pdf', 'svg']
if savefig_kwarg is None:
savefig_kwarg = dict() # default no kwargs to savefig
if sys.maxsize <= 2**32:
tol += 0.06
return _pytest_image_comparison(
baseline_images=baseline_images, extensions=extensions, tol=tol,
freetype_version=freetype_version, remove_text=remove_text,
savefig_kwargs=savefig_kwarg, style=style)
def check_figures_equal(*, extensions=("png", "pdf", "svg"), tol=0):
"""
Decorator for test cases that generate and compare two figures.
The decorated function must take two keyword arguments, *fig_test*
and *fig_ref*, and draw the test and reference images on them.
After the function returns, the figures are saved and compared.
This decorator should be preferred over `image_comparison` when possible in
order to keep the size of the test suite from ballooning.
Parameters
----------
extensions : list, default: ["png", "pdf", "svg"]
The extensions to test.
tol : float
The RMS threshold above which the test is considered failed.
Examples
--------
Check that calling `.Axes.plot` with a single argument plots it against
``[0, 1, 2, ...]``::
@check_figures_equal()
def test_plot(fig_test, fig_ref):
fig_test.subplots().plot([1, 3, 5])
fig_ref.subplots().plot([0, 1, 2], [1, 3, 5])
"""
ALLOWED_CHARS = set(string.digits + string.ascii_letters + '_-[]()')
KEYWORD_ONLY = inspect.Parameter.KEYWORD_ONLY
def decorator(func):
import pytest
_, result_dir = _image_directories(func)
old_sig = inspect.signature(func)
if not {"fig_test", "fig_ref"}.issubset(old_sig.parameters):
raise ValueError("The decorated function must have at least the "
"parameters 'fig_ref' and 'fig_test', but your "
f"function has the signature {old_sig}")
@pytest.mark.parametrize("ext", extensions)
def wrapper(*args, ext, request, **kwargs):
if 'ext' in old_sig.parameters:
kwargs['ext'] = ext
if 'request' in old_sig.parameters:
kwargs['request'] = request
file_name = "".join(c for c in request.node.name
if c in ALLOWED_CHARS)
try:
fig_test = plt.figure("test")
fig_ref = plt.figure("reference")
func(*args, fig_test=fig_test, fig_ref=fig_ref, **kwargs)
test_image_path = result_dir / (file_name + "." + ext)
ref_image_path = result_dir / (file_name + "-expected." + ext)
fig_test.savefig(test_image_path)
fig_ref.savefig(ref_image_path)
_raise_on_image_difference(
ref_image_path, test_image_path, tol=tol
)
finally:
plt.close(fig_test)
plt.close(fig_ref)
parameters = [
param
for param in old_sig.parameters.values()
if param.name not in {"fig_test", "fig_ref"}
]
if 'ext' not in old_sig.parameters:
parameters += [inspect.Parameter("ext", KEYWORD_ONLY)]
if 'request' not in old_sig.parameters:
parameters += [inspect.Parameter("request", KEYWORD_ONLY)]
new_sig = old_sig.replace(parameters=parameters)
wrapper.__signature__ = new_sig
# reach a bit into pytest internals to hoist the marks from
# our wrapped function
new_marks = getattr(func, "pytestmark", []) + wrapper.pytestmark
wrapper.pytestmark = new_marks
return wrapper
return decorator
def _image_directories(func):
"""
Compute the baseline and result image directories for testing *func*.
For test module ``foo.bar.test_baz``, the baseline directory is at
``foo/bar/baseline_images/test_baz`` and the result directory at
``$(pwd)/result_images/test_baz``. The result directory is created if it
doesn't exist.
"""
module_path = Path(sys.modules[func.__module__].__file__)
baseline_dir = module_path.parent / "baseline_images" / module_path.stem
result_dir = Path().resolve() / "result_images" / module_path.stem
result_dir.mkdir(parents=True, exist_ok=True)
return baseline_dir, result_dir

View file

@ -0,0 +1,153 @@
# Originally from astropy project (http://astropy.org), under BSD
# 3-clause license.
import contextlib
import socket
from matplotlib import cbook
cbook.warn_deprecated("3.2", name=__name__, obj_type="module",
alternative="pytest-remotedata")
# save original socket method for restoration
# These are global so that re-calling the turn_off_internet function doesn't
# overwrite them again
socket_original = socket.socket
socket_create_connection = socket.create_connection
socket_bind = socket.socket.bind
socket_connect = socket.socket.connect
INTERNET_OFF = False
# urllib2 uses a global variable to cache its default "opener" for opening
# connections for various protocols; we store it off here so we can restore to
# the default after re-enabling internet use
_orig_opener = None
# ::1 is apparently another valid name for localhost?
# it is returned by getaddrinfo when that function is given localhost
def check_internet_off(original_function):
"""
Wrap ``original_function``, which in most cases is assumed
to be a `socket.socket` method, to raise an `IOError` for any operations
on non-local AF_INET sockets.
"""
def new_function(*args, **kwargs):
if isinstance(args[0], socket.socket):
if not args[0].family in (socket.AF_INET, socket.AF_INET6):
# Should be fine in all but some very obscure cases
# More to the point, we don't want to affect AF_UNIX
# sockets.
return original_function(*args, **kwargs)
host = args[1][0]
addr_arg = 1
valid_hosts = ('localhost', '127.0.0.1', '::1')
else:
# The only other function this is used to wrap currently is
# socket.create_connection, which should be passed a 2-tuple, but
# we'll check just in case
if not (isinstance(args[0], tuple) and len(args[0]) == 2):
return original_function(*args, **kwargs)
host = args[0][0]
addr_arg = 0
valid_hosts = ('localhost', '127.0.0.1')
hostname = socket.gethostname()
fqdn = socket.getfqdn()
if host in (hostname, fqdn):
host = 'localhost'
new_addr = (host, args[addr_arg][1])
args = args[:addr_arg] + (new_addr,) + args[addr_arg + 1:]
if any(h in host for h in valid_hosts):
return original_function(*args, **kwargs)
else:
raise IOError("An attempt was made to connect to the internet "
"by a test that was not marked `remote_data`.")
return new_function
def turn_off_internet(verbose=False):
"""
Disable internet access via python by preventing connections from being
created using the socket module. Presumably this could be worked around by
using some other means of accessing the internet, but all default python
modules (urllib, requests, etc.) use socket [citation needed].
"""
import urllib.request
global INTERNET_OFF
global _orig_opener
if INTERNET_OFF:
return
INTERNET_OFF = True
__tracebackhide__ = True
if verbose:
print("Internet access disabled")
# Update urllib2 to force it not to use any proxies
# Must use {} here (the default of None will kick off an automatic search
# for proxies)
_orig_opener = urllib.request.build_opener()
no_proxy_handler = urllib.request.ProxyHandler({})
opener = urllib.request.build_opener(no_proxy_handler)
urllib.request.install_opener(opener)
socket.create_connection = check_internet_off(socket_create_connection)
socket.socket.bind = check_internet_off(socket_bind)
socket.socket.connect = check_internet_off(socket_connect)
return socket
def turn_on_internet(verbose=False):
"""
Restore internet access. Not used, but kept in case it is needed.
"""
import urllib.request
global INTERNET_OFF
global _orig_opener
if not INTERNET_OFF:
return
INTERNET_OFF = False
if verbose:
print("Internet access enabled")
urllib.request.install_opener(_orig_opener)
socket.create_connection = socket_create_connection
socket.socket.bind = socket_bind
socket.socket.connect = socket_connect
return socket
@contextlib.contextmanager
def no_internet(verbose=False):
"""
Temporarily disables internet access (if not already disabled).
If it was already disabled before entering the context manager
(i.e. `turn_off_internet` was called previously) then this is a no-op and
leaves internet access disabled until a manual call to `turn_on_internet`.
"""
already_disabled = INTERNET_OFF
turn_off_internet(verbose=verbose)
try:
yield
finally:
if not already_disabled:
turn_on_internet(verbose=verbose)

View file

@ -0,0 +1,4 @@
class ImageComparisonFailure(AssertionError):
"""
Raise this exception to mark a test as a comparison between two images.
"""

View file

@ -0,0 +1,165 @@
"""Duration module."""
import operator
from matplotlib import cbook
class Duration:
"""Class Duration in development."""
allowed = ["ET", "UTC"]
def __init__(self, frame, seconds):
"""
Create a new Duration object.
= ERROR CONDITIONS
- If the input frame is not in the allowed list, an error is thrown.
= INPUT VARIABLES
- frame The frame of the duration. Must be 'ET' or 'UTC'
- seconds The number of seconds in the Duration.
"""
cbook._check_in_list(self.allowed, frame=frame)
self._frame = frame
self._seconds = seconds
def frame(self):
"""Return the frame the duration is in."""
return self._frame
def __abs__(self):
"""Return the absolute value of the duration."""
return Duration(self._frame, abs(self._seconds))
def __neg__(self):
"""Return the negative value of this Duration."""
return Duration(self._frame, -self._seconds)
def seconds(self):
"""Return the number of seconds in the Duration."""
return self._seconds
def __bool__(self):
return self._seconds != 0
def __eq__(self, rhs):
return self._cmp(rhs, operator.eq)
def __ne__(self, rhs):
return self._cmp(rhs, operator.ne)
def __lt__(self, rhs):
return self._cmp(rhs, operator.lt)
def __le__(self, rhs):
return self._cmp(rhs, operator.le)
def __gt__(self, rhs):
return self._cmp(rhs, operator.gt)
def __ge__(self, rhs):
return self._cmp(rhs, operator.ge)
def _cmp(self, rhs, op):
"""
Compare two Durations.
= INPUT VARIABLES
- rhs The Duration to compare against.
- op The function to do the comparison
= RETURN VALUE
- Returns op(self, rhs)
"""
self.checkSameFrame(rhs, "compare")
return op(self._seconds, rhs._seconds)
def __add__(self, rhs):
"""
Add two Durations.
= ERROR CONDITIONS
- If the input rhs is not in the same frame, an error is thrown.
= INPUT VARIABLES
- rhs The Duration to add.
= RETURN VALUE
- Returns the sum of ourselves and the input Duration.
"""
# Delay-load due to circular dependencies.
import matplotlib.testing.jpl_units as U
if isinstance(rhs, U.Epoch):
return rhs + self
self.checkSameFrame(rhs, "add")
return Duration(self._frame, self._seconds + rhs._seconds)
def __sub__(self, rhs):
"""
Subtract two Durations.
= ERROR CONDITIONS
- If the input rhs is not in the same frame, an error is thrown.
= INPUT VARIABLES
- rhs The Duration to subtract.
= RETURN VALUE
- Returns the difference of ourselves and the input Duration.
"""
self.checkSameFrame(rhs, "sub")
return Duration(self._frame, self._seconds - rhs._seconds)
def __mul__(self, rhs):
"""
Scale a UnitDbl by a value.
= INPUT VARIABLES
- rhs The scalar to multiply by.
= RETURN VALUE
- Returns the scaled Duration.
"""
return Duration(self._frame, self._seconds * float(rhs))
def __rmul__(self, lhs):
"""
Scale a Duration by a value.
= INPUT VARIABLES
- lhs The scalar to multiply by.
= RETURN VALUE
- Returns the scaled Duration.
"""
return Duration(self._frame, self._seconds * float(lhs))
def __str__(self):
"""Print the Duration."""
return "%g %s" % (self._seconds, self._frame)
def __repr__(self):
"""Print the Duration."""
return "Duration('%s', %g)" % (self._frame, self._seconds)
def checkSameFrame(self, rhs, func):
"""
Check to see if frames are the same.
= ERROR CONDITIONS
- If the frame of the rhs Duration is not the same as our frame,
an error is thrown.
= INPUT VARIABLES
- rhs The Duration to check for the same frame
- func The name of the function doing the check.
"""
if self._frame != rhs._frame:
raise ValueError(
f"Cannot {func} Durations with different frames.\n"
f"LHS: {self._frame}\n"
f"RHS: {rhs._frame}")

View file

@ -0,0 +1,232 @@
"""Epoch module."""
import operator
import math
import datetime as DT
from matplotlib import cbook
from matplotlib.dates import date2num
class Epoch:
# Frame conversion offsets in seconds
# t(TO) = t(FROM) + allowed[ FROM ][ TO ]
allowed = {
"ET": {
"UTC": +64.1839,
},
"UTC": {
"ET": -64.1839,
},
}
def __init__(self, frame, sec=None, jd=None, daynum=None, dt=None):
"""
Create a new Epoch object.
Build an epoch 1 of 2 ways:
Using seconds past a Julian date:
# Epoch('ET', sec=1e8, jd=2451545)
or using a matplotlib day number
# Epoch('ET', daynum=730119.5)
= ERROR CONDITIONS
- If the input units are not in the allowed list, an error is thrown.
= INPUT VARIABLES
- frame The frame of the epoch. Must be 'ET' or 'UTC'
- sec The number of seconds past the input JD.
- jd The Julian date of the epoch.
- daynum The matplotlib day number of the epoch.
- dt A python datetime instance.
"""
if ((sec is None and jd is not None) or
(sec is not None and jd is None) or
(daynum is not None and
(sec is not None or jd is not None)) or
(daynum is None and dt is None and
(sec is None or jd is None)) or
(daynum is not None and dt is not None) or
(dt is not None and (sec is not None or jd is not None)) or
((dt is not None) and not isinstance(dt, DT.datetime))):
raise ValueError(
"Invalid inputs. Must enter sec and jd together, "
"daynum by itself, or dt (must be a python datetime).\n"
"Sec = %s\n"
"JD = %s\n"
"dnum= %s\n"
"dt = %s" % (sec, jd, daynum, dt))
cbook._check_in_list(self.allowed, frame=frame)
self._frame = frame
if dt is not None:
daynum = date2num(dt)
if daynum is not None:
# 1-JAN-0001 in JD = 1721425.5
jd = float(daynum) + 1721425.5
self._jd = math.floor(jd)
self._seconds = (jd - self._jd) * 86400.0
else:
self._seconds = float(sec)
self._jd = float(jd)
# Resolve seconds down to [ 0, 86400)
deltaDays = math.floor(self._seconds / 86400)
self._jd += deltaDays
self._seconds -= deltaDays * 86400.0
def convert(self, frame):
if self._frame == frame:
return self
offset = self.allowed[self._frame][frame]
return Epoch(frame, self._seconds + offset, self._jd)
def frame(self):
return self._frame
def julianDate(self, frame):
t = self
if frame != self._frame:
t = self.convert(frame)
return t._jd + t._seconds / 86400.0
def secondsPast(self, frame, jd):
t = self
if frame != self._frame:
t = self.convert(frame)
delta = t._jd - jd
return t._seconds + delta * 86400
def __eq__(self, rhs):
return self._cmp(rhs, operator.eq)
def __ne__(self, rhs):
return self._cmp(rhs, operator.ne)
def __lt__(self, rhs):
return self._cmp(rhs, operator.lt)
def __le__(self, rhs):
return self._cmp(rhs, operator.le)
def __gt__(self, rhs):
return self._cmp(rhs, operator.gt)
def __ge__(self, rhs):
return self._cmp(rhs, operator.ge)
def _cmp(self, rhs, op):
"""
Compare two Epoch's.
= INPUT VARIABLES
- rhs The Epoch to compare against.
- op The function to do the comparison
= RETURN VALUE
- Returns op(self, rhs)
"""
t = self
if self._frame != rhs._frame:
t = self.convert(rhs._frame)
if t._jd != rhs._jd:
return op(t._jd, rhs._jd)
return op(t._seconds, rhs._seconds)
def __add__(self, rhs):
"""
Add a duration to an Epoch.
= INPUT VARIABLES
- rhs The Epoch to subtract.
= RETURN VALUE
- Returns the difference of ourselves and the input Epoch.
"""
t = self
if self._frame != rhs.frame():
t = self.convert(rhs._frame)
sec = t._seconds + rhs.seconds()
return Epoch(t._frame, sec, t._jd)
def __sub__(self, rhs):
"""
Subtract two Epoch's or a Duration from an Epoch.
Valid:
Duration = Epoch - Epoch
Epoch = Epoch - Duration
= INPUT VARIABLES
- rhs The Epoch to subtract.
= RETURN VALUE
- Returns either the duration between to Epoch's or the a new
Epoch that is the result of subtracting a duration from an epoch.
"""
# Delay-load due to circular dependencies.
import matplotlib.testing.jpl_units as U
# Handle Epoch - Duration
if isinstance(rhs, U.Duration):
return self + -rhs
t = self
if self._frame != rhs._frame:
t = self.convert(rhs._frame)
days = t._jd - rhs._jd
sec = t._seconds - rhs._seconds
return U.Duration(rhs._frame, days*86400 + sec)
def __str__(self):
"""Print the Epoch."""
return "%22.15e %s" % (self.julianDate(self._frame), self._frame)
def __repr__(self):
"""Print the Epoch."""
return str(self)
@staticmethod
def range(start, stop, step):
"""
Generate a range of Epoch objects.
Similar to the Python range() method. Returns the range [
start, stop) at the requested step. Each element will be a
Epoch object.
= INPUT VARIABLES
- start The starting value of the range.
- stop The stop value of the range.
- step Step to use.
= RETURN VALUE
- Returns a list containing the requested Epoch values.
"""
elems = []
i = 0
while True:
d = start + i * step
if d >= stop:
break
elems.append(d)
i += 1
return elems

View file

@ -0,0 +1,99 @@
"""EpochConverter module containing class EpochConverter."""
from matplotlib import cbook
import matplotlib.units as units
import matplotlib.dates as date_ticker
__all__ = ['EpochConverter']
class EpochConverter(units.ConversionInterface):
"""
Provides Matplotlib conversion functionality for Monte Epoch and Duration
classes.
"""
# julian date reference for "Jan 1, 0001" minus 1 day because
# Matplotlib really wants "Jan 0, 0001"
jdRef = 1721425.5 - 1
@staticmethod
def axisinfo(unit, axis):
# docstring inherited
majloc = date_ticker.AutoDateLocator()
majfmt = date_ticker.AutoDateFormatter(majloc)
return units.AxisInfo(majloc=majloc, majfmt=majfmt, label=unit)
@staticmethod
def float2epoch(value, unit):
"""
Convert a Matplotlib floating-point date into an Epoch of the specified
units.
= INPUT VARIABLES
- value The Matplotlib floating-point date.
- unit The unit system to use for the Epoch.
= RETURN VALUE
- Returns the value converted to an Epoch in the specified time system.
"""
# Delay-load due to circular dependencies.
import matplotlib.testing.jpl_units as U
secPastRef = value * 86400.0 * U.UnitDbl(1.0, 'sec')
return U.Epoch(unit, secPastRef, EpochConverter.jdRef)
@staticmethod
def epoch2float(value, unit):
"""
Convert an Epoch value to a float suitable for plotting as a python
datetime object.
= INPUT VARIABLES
- value An Epoch or list of Epochs that need to be converted.
- unit The units to use for an axis with Epoch data.
= RETURN VALUE
- Returns the value parameter converted to floats.
"""
return value.julianDate(unit) - EpochConverter.jdRef
@staticmethod
def duration2float(value):
"""
Convert a Duration value to a float suitable for plotting as a python
datetime object.
= INPUT VARIABLES
- value A Duration or list of Durations that need to be converted.
= RETURN VALUE
- Returns the value parameter converted to floats.
"""
return value.seconds() / 86400.0
@staticmethod
def convert(value, unit, axis):
# docstring inherited
# Delay-load due to circular dependencies.
import matplotlib.testing.jpl_units as U
if not cbook.is_scalar_or_string(value):
return [EpochConverter.convert(x, unit, axis) for x in value]
if units.ConversionInterface.is_numlike(value):
return value
if unit is None:
unit = EpochConverter.default_units(value, axis)
if isinstance(value, U.Duration):
return EpochConverter.duration2float(value)
else:
return EpochConverter.epoch2float(value, unit)
@staticmethod
def default_units(value, axis):
# docstring inherited
if cbook.is_scalar_or_string(value):
return value.frame()
else:
return EpochConverter.default_units(value[0], axis)

View file

@ -0,0 +1,100 @@
"""StrConverter module containing class StrConverter."""
import numpy as np
import matplotlib.units as units
__all__ = ['StrConverter']
class StrConverter(units.ConversionInterface):
"""
A Matplotlib converter class for string data values.
Valid units for string are:
- 'indexed' : Values are indexed as they are specified for plotting.
- 'sorted' : Values are sorted alphanumerically.
- 'inverted' : Values are inverted so that the first value is on top.
- 'sorted-inverted' : A combination of 'sorted' and 'inverted'
"""
@staticmethod
def axisinfo(unit, axis):
# docstring inherited
return None
@staticmethod
def convert(value, unit, axis):
# docstring inherited
if units.ConversionInterface.is_numlike(value):
return value
if value == []:
return []
# we delay loading to make matplotlib happy
ax = axis.axes
if axis is ax.get_xaxis():
isXAxis = True
else:
isXAxis = False
axis.get_major_ticks()
ticks = axis.get_ticklocs()
labels = axis.get_ticklabels()
labels = [l.get_text() for l in labels if l.get_text()]
if not labels:
ticks = []
labels = []
if not np.iterable(value):
value = [value]
newValues = []
for v in value:
if v not in labels and v not in newValues:
newValues.append(v)
labels.extend(newValues)
# DISABLED: This is disabled because matplotlib bar plots do not
# DISABLED: recalculate the unit conversion of the data values
# DISABLED: this is due to design and is not really a bug.
# DISABLED: If this gets changed, then we can activate the following
# DISABLED: block of code. Note that this works for line plots.
# DISABLED if unit:
# DISABLED if unit.find("sorted") > -1:
# DISABLED labels.sort()
# DISABLED if unit.find("inverted") > -1:
# DISABLED labels = labels[::-1]
# add padding (so they do not appear on the axes themselves)
labels = [''] + labels + ['']
ticks = list(range(len(labels)))
ticks[0] = 0.5
ticks[-1] = ticks[-1] - 0.5
axis.set_ticks(ticks)
axis.set_ticklabels(labels)
# we have to do the following lines to make ax.autoscale_view work
loc = axis.get_major_locator()
loc.set_bounds(ticks[0], ticks[-1])
if isXAxis:
ax.set_xlim(ticks[0], ticks[-1])
else:
ax.set_ylim(ticks[0], ticks[-1])
result = [ticks[labels.index(v)] for v in value]
ax.viewLim.ignore(-1)
return result
@staticmethod
def default_units(value, axis):
# docstring inherited
# The default behavior for string indexing.
return "indexed"

View file

@ -0,0 +1,262 @@
"""UnitDbl module."""
import operator
from matplotlib import cbook
class UnitDbl:
"""Class UnitDbl in development."""
# Unit conversion table. Small subset of the full one but enough
# to test the required functions. First field is a scale factor to
# convert the input units to the units of the second field. Only
# units in this table are allowed.
allowed = {
"m": (0.001, "km"),
"km": (1, "km"),
"mile": (1.609344, "km"),
"rad": (1, "rad"),
"deg": (1.745329251994330e-02, "rad"),
"sec": (1, "sec"),
"min": (60.0, "sec"),
"hour": (3600, "sec"),
}
_types = {
"km": "distance",
"rad": "angle",
"sec": "time",
}
def __init__(self, value, units):
"""
Create a new UnitDbl object.
Units are internally converted to km, rad, and sec. The only
valid inputs for units are [m, km, mile, rad, deg, sec, min, hour].
The field UnitDbl.value will contain the converted value. Use
the convert() method to get a specific type of units back.
= ERROR CONDITIONS
- If the input units are not in the allowed list, an error is thrown.
= INPUT VARIABLES
- value The numeric value of the UnitDbl.
- units The string name of the units the value is in.
"""
data = cbook._check_getitem(self.allowed, units=units)
self._value = float(value * data[0])
self._units = data[1]
def convert(self, units):
"""
Convert the UnitDbl to a specific set of units.
= ERROR CONDITIONS
- If the input units are not in the allowed list, an error is thrown.
= INPUT VARIABLES
- units The string name of the units to convert to.
= RETURN VALUE
- Returns the value of the UnitDbl in the requested units as a floating
point number.
"""
if self._units == units:
return self._value
data = cbook._check_getitem(self.allowed, units=units)
if self._units != data[1]:
raise ValueError(f"Error trying to convert to different units.\n"
f" Invalid conversion requested.\n"
f" UnitDbl: {self}\n"
f" Units: {units}\n")
return self._value / data[0]
def __abs__(self):
"""Return the absolute value of this UnitDbl."""
return UnitDbl(abs(self._value), self._units)
def __neg__(self):
"""Return the negative value of this UnitDbl."""
return UnitDbl(-self._value, self._units)
def __bool__(self):
"""Return the truth value of a UnitDbl."""
return bool(self._value)
def __eq__(self, rhs):
return self._cmp(rhs, operator.eq)
def __ne__(self, rhs):
return self._cmp(rhs, operator.ne)
def __lt__(self, rhs):
return self._cmp(rhs, operator.lt)
def __le__(self, rhs):
return self._cmp(rhs, operator.le)
def __gt__(self, rhs):
return self._cmp(rhs, operator.gt)
def __ge__(self, rhs):
return self._cmp(rhs, operator.ge)
def _cmp(self, rhs, op):
"""
Compare two UnitDbl's.
= ERROR CONDITIONS
- If the input rhs units are not the same as our units,
an error is thrown.
= INPUT VARIABLES
- rhs The UnitDbl to compare against.
- op The function to do the comparison
= RETURN VALUE
- Returns op(self, rhs)
"""
self.checkSameUnits(rhs, "compare")
return op(self._value, rhs._value)
def __add__(self, rhs):
"""
Add two UnitDbl's.
= ERROR CONDITIONS
- If the input rhs units are not the same as our units,
an error is thrown.
= INPUT VARIABLES
- rhs The UnitDbl to add.
= RETURN VALUE
- Returns the sum of ourselves and the input UnitDbl.
"""
self.checkSameUnits(rhs, "add")
return UnitDbl(self._value + rhs._value, self._units)
def __sub__(self, rhs):
"""
Subtract two UnitDbl's.
= ERROR CONDITIONS
- If the input rhs units are not the same as our units,
an error is thrown.
= INPUT VARIABLES
- rhs The UnitDbl to subtract.
= RETURN VALUE
- Returns the difference of ourselves and the input UnitDbl.
"""
self.checkSameUnits(rhs, "subtract")
return UnitDbl(self._value - rhs._value, self._units)
def __mul__(self, rhs):
"""
Scale a UnitDbl by a value.
= INPUT VARIABLES
- rhs The scalar to multiply by.
= RETURN VALUE
- Returns the scaled UnitDbl.
"""
return UnitDbl(self._value * rhs, self._units)
def __rmul__(self, lhs):
"""
Scale a UnitDbl by a value.
= INPUT VARIABLES
- lhs The scalar to multiply by.
= RETURN VALUE
- Returns the scaled UnitDbl.
"""
return UnitDbl(self._value * lhs, self._units)
def __str__(self):
"""Print the UnitDbl."""
return "%g *%s" % (self._value, self._units)
def __repr__(self):
"""Print the UnitDbl."""
return "UnitDbl(%g, '%s')" % (self._value, self._units)
def type(self):
"""Return the type of UnitDbl data."""
return self._types[self._units]
@staticmethod
def range(start, stop, step=None):
"""
Generate a range of UnitDbl objects.
Similar to the Python range() method. Returns the range [
start, stop) at the requested step. Each element will be a
UnitDbl object.
= INPUT VARIABLES
- start The starting value of the range.
- stop The stop value of the range.
- step Optional step to use. If set to None, then a UnitDbl of
value 1 w/ the units of the start is used.
= RETURN VALUE
- Returns a list containing the requested UnitDbl values.
"""
if step is None:
step = UnitDbl(1, start._units)
elems = []
i = 0
while True:
d = start + i * step
if d >= stop:
break
elems.append(d)
i += 1
return elems
@cbook.deprecated("3.2")
def checkUnits(self, units):
"""
Check to see if some units are valid.
= ERROR CONDITIONS
- If the input units are not in the allowed list, an error is thrown.
= INPUT VARIABLES
- units The string name of the units to check.
"""
if units not in self.allowed:
raise ValueError("Input units '%s' are not one of the supported "
"types of %s" % (
units, list(self.allowed.keys())))
def checkSameUnits(self, rhs, func):
"""
Check to see if units are the same.
= ERROR CONDITIONS
- If the units of the rhs UnitDbl are not the same as our units,
an error is thrown.
= INPUT VARIABLES
- rhs The UnitDbl to check for the same units
- func The name of the function doing the check.
"""
if self._units != rhs._units:
raise ValueError(f"Cannot {func} units of different types.\n"
f"LHS: {self._units}\n"
f"RHS: {rhs._units}")

View file

@ -0,0 +1,91 @@
"""UnitDblConverter module containing class UnitDblConverter."""
import numpy as np
from matplotlib import cbook
import matplotlib.units as units
import matplotlib.projections.polar as polar
__all__ = ['UnitDblConverter']
# A special function for use with the matplotlib FuncFormatter class
# for formatting axes with radian units.
# This was copied from matplotlib example code.
def rad_fn(x, pos=None):
"""Radian function formatter."""
n = int((x / np.pi) * 2.0 + 0.25)
if n == 0:
return str(x)
elif n == 1:
return r'$\pi/2$'
elif n == 2:
return r'$\pi$'
elif n % 2 == 0:
return fr'${n//2}\pi$'
else:
return fr'${n}\pi/2$'
class UnitDblConverter(units.ConversionInterface):
"""
Provides Matplotlib conversion functionality for the Monte UnitDbl class.
"""
# default for plotting
defaults = {
"distance": 'km',
"angle": 'deg',
"time": 'sec',
}
@staticmethod
def axisinfo(unit, axis):
# docstring inherited
# Delay-load due to circular dependencies.
import matplotlib.testing.jpl_units as U
# Check to see if the value used for units is a string unit value
# or an actual instance of a UnitDbl so that we can use the unit
# value for the default axis label value.
if unit:
label = unit if isinstance(unit, str) else unit.label()
else:
label = None
if label == "deg" and isinstance(axis.axes, polar.PolarAxes):
# If we want degrees for a polar plot, use the PolarPlotFormatter
majfmt = polar.PolarAxes.ThetaFormatter()
else:
majfmt = U.UnitDblFormatter(useOffset=False)
return units.AxisInfo(majfmt=majfmt, label=label)
@staticmethod
def convert(value, unit, axis):
# docstring inherited
if not cbook.is_scalar_or_string(value):
return [UnitDblConverter.convert(x, unit, axis) for x in value]
# If the incoming value behaves like a number,
# then just return it because we don't know how to convert it
# (or it is already converted)
if units.ConversionInterface.is_numlike(value):
return value
# If no units were specified, then get the default units to use.
if unit is None:
unit = UnitDblConverter.default_units(value, axis)
# Convert the incoming UnitDbl value/values to float/floats
if isinstance(axis.axes, polar.PolarAxes) and value.type() == "angle":
# Guarantee that units are radians for polar plots.
return value.convert("rad")
return value.convert(unit)
@staticmethod
def default_units(value, axis):
# docstring inherited
# Determine the default units based on the user preferences set for
# default units when printing a UnitDbl.
if cbook.is_scalar_or_string(value):
return UnitDblConverter.defaults[value.type()]
else:
return UnitDblConverter.default_units(value[0], axis)

View file

@ -0,0 +1,28 @@
"""UnitDblFormatter module containing class UnitDblFormatter."""
import matplotlib.ticker as ticker
__all__ = ['UnitDblFormatter']
class UnitDblFormatter(ticker.ScalarFormatter):
"""
The formatter for UnitDbl data types.
This allows for formatting with the unit string.
"""
def __call__(self, x, pos=None):
# docstring inherited
if len(self.locs) == 0:
return ''
else:
return '{:.12}'.format(x)
def format_data_short(self, value):
# docstring inherited
return '{:.12}'.format(value)
def format_data(self, value):
# docstring inherited
return '{:.12}'.format(value)

View file

@ -0,0 +1,76 @@
"""
A sample set of units for use with testing unit conversion
of Matplotlib routines. These are used because they use very strict
enforcement of unitized data which will test the entire spectrum of how
unitized data might be used (it is not always meaningful to convert to
a float without specific units given).
UnitDbl is essentially a unitized floating point number. It has a
minimal set of supported units (enough for testing purposes). All
of the mathematical operation are provided to fully test any behaviour
that might occur with unitized data. Remember that unitized data has
rules as to how it can be applied to one another (a value of distance
cannot be added to a value of time). Thus we need to guard against any
accidental "default" conversion that will strip away the meaning of the
data and render it neutered.
Epoch is different than a UnitDbl of time. Time is something that can be
measured where an Epoch is a specific moment in time. Epochs are typically
referenced as an offset from some predetermined epoch.
A difference of two epochs is a Duration. The distinction between a Duration
and a UnitDbl of time is made because an Epoch can have different frames (or
units). In the case of our test Epoch class the two allowed frames are 'UTC'
and 'ET' (Note that these are rough estimates provided for testing purposes
and should not be used in production code where accuracy of time frames is
desired). As such a Duration also has a frame of reference and therefore needs
to be called out as different that a simple measurement of time since a delta-t
in one frame may not be the same in another.
"""
from .Duration import Duration
from .Epoch import Epoch
from .UnitDbl import UnitDbl
from .StrConverter import StrConverter
from .EpochConverter import EpochConverter
from .UnitDblConverter import UnitDblConverter
from .UnitDblFormatter import UnitDblFormatter
__version__ = "1.0"
__all__ = [
'register',
'Duration',
'Epoch',
'UnitDbl',
'UnitDblFormatter',
]
def register():
"""Register the unit conversion classes with matplotlib."""
import matplotlib.units as mplU
mplU.registry[str] = StrConverter()
mplU.registry[Epoch] = EpochConverter()
mplU.registry[Duration] = EpochConverter()
mplU.registry[UnitDbl] = UnitDblConverter()
# Some default unit instances
# Distances
m = UnitDbl(1.0, "m")
km = UnitDbl(1.0, "km")
mile = UnitDbl(1.0, "mile")
# Angles
deg = UnitDbl(1.0, "deg")
rad = UnitDbl(1.0, "rad")
# Time
sec = UnitDbl(1.0, "sec")
min = UnitDbl(1.0, "min")
hr = UnitDbl(1.0, "hour")
day = UnitDbl(24.0, "hour")
sec = UnitDbl(1.0, "sec")

View file

@ -0,0 +1,57 @@
"""
========================
Widget testing utilities
========================
Functions that are useful for testing widgets.
See also matplotlib.tests.test_widgets
"""
import matplotlib.pyplot as plt
from unittest import mock
def get_ax():
"""Creates plot and returns its axes"""
fig, ax = plt.subplots(1, 1)
ax.plot([0, 200], [0, 200])
ax.set_aspect(1.0)
ax.figure.canvas.draw()
return ax
def do_event(tool, etype, button=1, xdata=0, ydata=0, key=None, step=1):
"""
Trigger an event
Parameters
----------
tool : matplotlib.widgets.RectangleSelector
etype
the event to trigger
xdata : int
x coord of mouse in data coords
ydata : int
y coord of mouse in data coords
button : int or str
button pressed None, 1, 2, 3, 'up', 'down' (up and down are used
for scroll events)
key
the key depressed when the mouse event triggered (see
:class:`KeyEvent`)
step : int
number of scroll steps (positive for 'up', negative for 'down')
"""
event = mock.Mock()
event.button = button
ax = tool.ax
event.x, event.y = ax.transData.transform([(xdata, ydata),
(xdata, ydata)])[0]
event.xdata, event.ydata = xdata, ydata
event.inaxes = ax
event.canvas = ax.figure.canvas
event.key = key
event.step = step
event.guiEvent = None
event.name = 'Custom'
func = getattr(tool, etype)
func(event)