317 lines
10 KiB
Python
317 lines
10 KiB
Python
import itertools
|
|
import functools
|
|
|
|
import numpy as np
|
|
from scipy import ndimage as ndi
|
|
|
|
from ..metrics import mean_squared_error
|
|
from ..util import img_as_float
|
|
|
|
|
|
def _interpolate_image(image, *, multichannel=False):
|
|
"""Replacing each pixel in ``image`` with the average of its neighbors.
|
|
|
|
Parameters
|
|
----------
|
|
image : ndarray
|
|
Input data to be interpolated.
|
|
multichannel : bool, optional
|
|
Whether the last axis of the image is to be interpreted as multiple
|
|
channels or another spatial dimension.
|
|
|
|
Returns
|
|
-------
|
|
interp : ndarray
|
|
Interpolated version of `image`.
|
|
"""
|
|
spatialdims = image.ndim if not multichannel else image.ndim - 1
|
|
conv_filter = ndi.generate_binary_structure(spatialdims, 1).astype(image.dtype)
|
|
conv_filter.ravel()[conv_filter.size // 2] = 0
|
|
conv_filter /= conv_filter.sum()
|
|
|
|
if multichannel:
|
|
interp = np.zeros_like(image)
|
|
for i in range(image.shape[-1]):
|
|
interp[..., i] = ndi.convolve(image[..., i], conv_filter,
|
|
mode='mirror')
|
|
else:
|
|
interp = ndi.convolve(image, conv_filter, mode='mirror')
|
|
return interp
|
|
|
|
|
|
def _generate_grid_slice(shape, *, offset, stride=3):
|
|
"""Generate slices of uniformly-spaced points in an array.
|
|
|
|
Parameters
|
|
----------
|
|
shape : tuple of int
|
|
Shape of the mask.
|
|
offset : int
|
|
The offset of the grid of ones. Iterating over ``offset`` will cover
|
|
the entire array. It should be between 0 and ``stride ** ndim``, not
|
|
inclusive, where ``ndim = len(shape)``.
|
|
stride : int, optional
|
|
The spacing between ones, used in each dimension.
|
|
|
|
Returns
|
|
-------
|
|
mask : ndarray
|
|
The mask.
|
|
|
|
Examples
|
|
--------
|
|
>>> shape = (4, 4)
|
|
>>> array = np.zeros(shape, dtype=int)
|
|
>>> grid_slice = _generate_grid_slice(shape, offset=0, stride=2)
|
|
>>> array[grid_slice] = 1
|
|
>>> print(array)
|
|
[[1 0 1 0]
|
|
[0 0 0 0]
|
|
[1 0 1 0]
|
|
[0 0 0 0]]
|
|
|
|
Changing the offset moves the location of the 1s:
|
|
|
|
>>> array = np.zeros(shape, dtype=int)
|
|
>>> grid_slice = _generate_grid_slice(shape, offset=3, stride=2)
|
|
>>> array[grid_slice] = 1
|
|
>>> print(array)
|
|
[[0 0 0 0]
|
|
[0 1 0 1]
|
|
[0 0 0 0]
|
|
[0 1 0 1]]
|
|
"""
|
|
phases = np.unravel_index(offset, (stride,) * len(shape))
|
|
mask = tuple(slice(p, None, stride) for p in phases)
|
|
|
|
return mask
|
|
|
|
|
|
def _invariant_denoise(image, denoise_function, *, stride=4,
|
|
masks=None, denoiser_kwargs=None):
|
|
"""Apply a J-invariant version of `denoise_function`.
|
|
|
|
Parameters
|
|
----------
|
|
image : ndarray
|
|
Input data to be denoised (converted using `img_as_float`).
|
|
denoise_function : function
|
|
Original denoising function.
|
|
stride : int, optional
|
|
Stride used in masking procedure that converts `denoise_function`
|
|
to J-invariance.
|
|
masks : list of ndarray, optional
|
|
Set of masks to use for computing J-invariant output. If `None`,
|
|
a full set of masks covering the image will be used.
|
|
denoiser_kwargs:
|
|
Keyword arguments passed to `denoise_function`.
|
|
|
|
Returns
|
|
-------
|
|
output : ndarray
|
|
Denoised image, of same shape as `image`.
|
|
"""
|
|
image = img_as_float(image)
|
|
if denoiser_kwargs is None:
|
|
denoiser_kwargs = {}
|
|
|
|
if 'multichannel' in denoiser_kwargs:
|
|
multichannel = denoiser_kwargs['multichannel']
|
|
else:
|
|
multichannel = False
|
|
interp = _interpolate_image(image, multichannel=multichannel)
|
|
output = np.zeros_like(image)
|
|
|
|
if masks is None:
|
|
spatialdims = image.ndim if not multichannel else image.ndim - 1
|
|
n_masks = stride ** spatialdims
|
|
masks = (_generate_grid_slice(image.shape[:spatialdims],
|
|
offset=idx, stride=stride)
|
|
for idx in range(n_masks))
|
|
|
|
for mask in masks:
|
|
input_image = image.copy()
|
|
input_image[mask] = interp[mask]
|
|
output[mask] = denoise_function(input_image, **denoiser_kwargs)[mask]
|
|
return output
|
|
|
|
|
|
def _product_from_dict(dictionary):
|
|
"""Utility function to convert parameter ranges to parameter combinations.
|
|
|
|
Converts a dict of lists into a list of dicts whose values consist of the
|
|
cartesian product of the values in the original dict.
|
|
|
|
Parameters
|
|
----------
|
|
dictionary : dict of lists
|
|
Dictionary of lists to be multiplied.
|
|
|
|
Yields
|
|
------
|
|
selections : dicts of values
|
|
Dicts containing individual combinations of the values in the input
|
|
dict.
|
|
"""
|
|
keys = dictionary.keys()
|
|
for element in itertools.product(*dictionary.values()):
|
|
yield dict(zip(keys, element))
|
|
|
|
|
|
def calibrate_denoiser(image, denoise_function, denoise_parameters, *,
|
|
stride=4, approximate_loss=True,
|
|
extra_output=False):
|
|
"""Calibrate a denoising function and return optimal J-invariant version.
|
|
|
|
The returned function is partially evaluated with optimal parameter values
|
|
set for denoising the input image.
|
|
|
|
Parameters
|
|
----------
|
|
image : ndarray
|
|
Input data to be denoised (converted using `img_as_float`).
|
|
denoise_function : function
|
|
Denoising function to be calibrated.
|
|
denoise_parameters : dict of list
|
|
Ranges of parameters for `denoise_function` to be calibrated over.
|
|
stride : int, optional
|
|
Stride used in masking procedure that converts `denoise_function`
|
|
to J-invariance.
|
|
approximate_loss : bool, optional
|
|
Whether to approximate the self-supervised loss used to evaluate the
|
|
denoiser by only computing it on one masked version of the image.
|
|
If False, the runtime will be a factor of `stride**image.ndim` longer.
|
|
extra_output : bool, optional
|
|
If True, return parameters and losses in addition to the calibrated
|
|
denoising function
|
|
|
|
Returns
|
|
-------
|
|
best_denoise_function : function
|
|
The optimal J-invariant version of `denoise_function`.
|
|
|
|
If `extra_output` is True, the following tuple is also returned:
|
|
|
|
(parameters_tested, losses) : tuple (list of dict, list of int)
|
|
List of parameters tested for `denoise_function`, as a dictionary of
|
|
kwargs
|
|
Self-supervised loss for each set of parameters in `parameters_tested`.
|
|
|
|
|
|
Notes
|
|
-----
|
|
|
|
The calibration procedure uses a self-supervised mean-square-error loss
|
|
to evaluate the performance of J-invariant versions of `denoise_function`.
|
|
The minimizer of the self-supervised loss is also the minimizer of the
|
|
ground-truth loss (i.e., the true MSE error) [1]. The returned function
|
|
can be used on the original noisy image, or other images with similar
|
|
characteristics.
|
|
|
|
Increasing the stride increases the performance of `best_denoise_function`
|
|
at the expense of increasing its runtime. It has no effect on the runtime
|
|
of the calibration.
|
|
|
|
References
|
|
----------
|
|
.. [1] J. Batson & L. Royer. Noise2Self: Blind Denoising by Self-Supervision,
|
|
International Conference on Machine Learning, p. 524-533 (2019).
|
|
|
|
Examples
|
|
--------
|
|
|
|
>>> from skimage import color, data
|
|
>>> from skimage.restoration import denoise_wavelet
|
|
>>> import numpy as np
|
|
>>> img = color.rgb2gray(data.astronaut()[:50, :50])
|
|
>>> noisy = img + 0.5 * img.std() * np.random.randn(*img.shape)
|
|
>>> parameters = {'sigma': np.arange(0.1, 0.4, 0.02)}
|
|
>>> denoising_function = calibrate_denoiser(noisy, denoise_wavelet,
|
|
... denoise_parameters=parameters)
|
|
>>> denoised_img = denoising_function(img)
|
|
|
|
"""
|
|
parameters_tested, losses = _calibrate_denoiser_search(
|
|
image, denoise_function,
|
|
denoise_parameters=denoise_parameters,
|
|
stride=stride,
|
|
approximate_loss=approximate_loss
|
|
)
|
|
|
|
idx = np.argmin(losses)
|
|
best_parameters = parameters_tested[idx]
|
|
|
|
best_denoise_function = functools.partial(
|
|
_invariant_denoise,
|
|
denoise_function=denoise_function,
|
|
stride=stride,
|
|
denoiser_kwargs=best_parameters,
|
|
)
|
|
|
|
if extra_output:
|
|
return best_denoise_function, (parameters_tested, losses)
|
|
else:
|
|
return best_denoise_function
|
|
|
|
|
|
def _calibrate_denoiser_search(image, denoise_function, denoise_parameters, *,
|
|
stride=4, approximate_loss=True):
|
|
"""Return a parameter search history with losses for a denoise function.
|
|
|
|
Parameters
|
|
----------
|
|
image : ndarray
|
|
Input data to be denoised (converted using `img_as_float`).
|
|
denoise_function : function
|
|
Denoising function to be calibrated.
|
|
denoise_parameters : dict of list
|
|
Ranges of parameters for `denoise_function` to be calibrated over.
|
|
stride : int, optional
|
|
Stride used in masking procedure that converts `denoise_function`
|
|
to J-invariance.
|
|
approximate_loss : bool, optional
|
|
Whether to approximate the self-supervised loss used to evaluate the
|
|
denoiser by only computing it on one masked version of the image.
|
|
If False, the runtime will be a factor of `stride**image.ndim` longer.
|
|
|
|
Returns
|
|
-------
|
|
parameters_tested : list of dict
|
|
List of parameters tested for `denoise_function`, as a dictionary of
|
|
kwargs.
|
|
losses : list of int
|
|
Self-supervised loss for each set of parameters in `parameters_tested`.
|
|
"""
|
|
image = img_as_float(image)
|
|
parameters_tested = list(_product_from_dict(denoise_parameters))
|
|
losses = []
|
|
|
|
for denoiser_kwargs in parameters_tested:
|
|
if 'multichannel' in denoiser_kwargs:
|
|
multichannel = denoiser_kwargs['multichannel']
|
|
else:
|
|
multichannel = False
|
|
if not approximate_loss:
|
|
denoised = _invariant_denoise(
|
|
image, denoise_function,
|
|
stride=stride,
|
|
denoiser_kwargs=denoiser_kwargs
|
|
)
|
|
loss = mean_squared_error(image, denoised)
|
|
else:
|
|
spatialdims = image.ndim if not multichannel else image.ndim - 1
|
|
n_masks = stride ** spatialdims
|
|
mask = _generate_grid_slice(image.shape[:spatialdims],
|
|
offset=n_masks // 2, stride=stride)
|
|
|
|
masked_denoised = _invariant_denoise(
|
|
image, denoise_function,
|
|
masks=[mask],
|
|
denoiser_kwargs=denoiser_kwargs
|
|
)
|
|
|
|
loss = mean_squared_error(image[mask], masked_denoised[mask])
|
|
|
|
losses.append(loss)
|
|
|
|
return parameters_tested, losses
|