import numpy as np

from matplotlib import cbook, ticker as mticker
from matplotlib.transforms import Bbox, Transform
from .clip_path import clip_line_to_rect


def _deprecate_factor_none(factor):
    # After the deprecation period, calls to _deprecate_factor_none can just be
    # removed.
    if factor is None:
        cbook.warn_deprecated(
            "3.2", message="factor=None is deprecated since %(since)s and "
            "support will be removed %(removal)s; use/return factor=1 instead")
        factor = 1
    return factor


class ExtremeFinderSimple:
    """
    Estimate the coordinate range spanned by a transformed rectangle, which
    determines the range of grid lines that need to be drawn.
    """

    def __init__(self, nx, ny):
        """
        Parameters
        ----------
        nx, ny : int
            Number of sample points along the x and y directions.
        """
        self.nx, self.ny = nx, ny

    def __call__(self, transform_xy, x1, y1, x2, y2):
        """
        Approximate the bounding box of the image of ``(x1, y1, x2, y2)``
        under *transform_xy*.

        Typically ``(x1, y1, x2, y2)`` is given in axes coordinates and
        *transform_xy* goes from axes to data coordinates, so that the result
        is the range of data coordinates covering the axes.

        A regular ``nx * ny`` grid of points is sampled over the box and
        transformed.  The extremal transformed coordinates are then widened
        by one sampling step (a fraction *1/nx* resp. *1/ny* of their span)
        to compensate for the finite sampling.

        Returns
        -------
        x_min, x_max, y_min, y_max
        """
        sample_x = np.linspace(x1, x2, self.nx)
        sample_y = np.linspace(y1, y2, self.ny)
        grid_x, grid_y = np.meshgrid(sample_x, sample_y)
        tx, ty = transform_xy(grid_x.ravel(), grid_y.ravel())
        return self._add_pad(tx.min(), tx.max(), ty.min(), ty.max())

    def _add_pad(self, x_min, x_max, y_min, y_max):
        """Widen each range by its span divided by the sample count."""
        pad_x = (x_max - x_min) / self.nx
        pad_y = (y_max - y_min) / self.ny
        return (x_min - pad_x, x_max + pad_x, y_min - pad_y, y_max + pad_y)


class GridFinder:
    """
    Compute grid lines, tick locations and tick labels for a curvilinear grid.

    The grid is defined in "world" coordinates (called lon/lat below, for the
    first and second world coordinate).  ``transform_xy`` maps world
    coordinates to image coordinates, and ``inv_transform_xy`` maps image
    coordinates back to world coordinates.
    """

    def __init__(self,
                 transform,
                 extreme_finder=None,
                 grid_locator1=None,
                 grid_locator2=None,
                 tick_formatter1=None,
                 tick_formatter2=None):
        """
        Parameters
        ----------
        transform : `~matplotlib.transforms.Transform` or (callable, callable)
            Transform from the world (grid) coordinates to the image
            coordinates of the axes.  Alternatively, a pair
            ``(transform_xy, inv_transform_xy)`` of functions, each taking and
            returning two coordinate arrays.
        extreme_finder : callable, optional
            Called as ``extreme_finder(inv_transform_xy, x1, y1, x2, y2)``.
            It returns the world-coordinate range
            ``(lon_min, lon_max, lat_min, lat_max)`` covered by an
            image-coordinate box.  Defaults to ``ExtremeFinderSimple(20, 20)``.
        grid_locator1, grid_locator2 : callable, optional
            Locators for the first and second world coordinate.  Each is
            called as ``locator(vmin, vmax)`` and returns
            ``(levels, n, factor)``.  Default to ``MaxNLocator()``.
        tick_formatter1, tick_formatter2 : callable, optional
            Formatters for the first and second world coordinate.  Each is
            called as ``formatter(direction, factor, levels)`` and returns a
            list of labels.  Default to ``FormatterPrettyPrint()``.
        """
        if extreme_finder is None:
            extreme_finder = ExtremeFinderSimple(20, 20)
        if grid_locator1 is None:
            grid_locator1 = MaxNLocator()
        if grid_locator2 is None:
            grid_locator2 = MaxNLocator()
        if tick_formatter1 is None:
            tick_formatter1 = FormatterPrettyPrint()
        if tick_formatter2 is None:
            tick_formatter2 = FormatterPrettyPrint()
        self.extreme_finder = extreme_finder
        self.grid_locator1 = grid_locator1
        self.grid_locator2 = grid_locator2
        self.tick_formatter1 = tick_formatter1
        self.tick_formatter2 = tick_formatter2
        self.update_transform(transform)

    def get_grid_info(self, x1, y1, x2, y2):
        """
        Compute the grid lines and ticks covering an image-coordinate box.

        Parameters
        ----------
        x1, y1, x2, y2 : float
            Extents of the box, in image coordinates.

        Returns
        -------
        dict
            The dict has these keys:

            - ``"extremes"``: the world-coordinate range returned by the
              extreme finder.
            - ``"lon_lines"`` / ``"lat_lines"``: the unclipped grid lines, in
              image coordinates.
            - ``"lon"`` / ``"lat"``: per-coordinate dicts as built by
              `_clip_grid_lines_and_find_ticks`, plus a ``"tick_labels"``
              entry mapping each side ("left", "bottom", "right", "top") to
              the formatted labels of that side's ticks.
        """
        extremes = self.extreme_finder(self.inv_transform_xy, x1, y1, x2, y2)

        # Range of lat (resp. lon) over which each lon (resp. lat) grid line
        # is drawn; e.g. the lon=0 grid line goes from lat_min to lat_max.
        lon_min, lon_max, lat_min, lat_max = extremes
        lon_levs, lon_n, lon_factor = self.grid_locator1(lon_min, lon_max)
        lat_levs, lat_n, lat_factor = self.grid_locator2(lat_min, lat_max)

        # Locators work on bounds multiplied by their factor, so divide it
        # back out to get grid values in world coordinates.
        lon_values = lon_levs[:lon_n] / _deprecate_factor_none(lon_factor)
        lat_values = lat_levs[:lat_n] / _deprecate_factor_none(lat_factor)

        lon_lines, lat_lines = self._get_raw_grid_lines(
            lon_values, lat_values, lon_min, lon_max, lat_min, lat_max)

        # Enlarge the clipping box by a tiny relative margin.  This is
        # presumably so that lines lying exactly on its edges survive
        # rounding errors.
        ddx = (x2 - x1) * 1.e-10
        ddy = (y2 - y1) * 1.e-10
        bb = Bbox.from_extents(x1 - ddx, y1 - ddy, x2 + ddx, y2 + ddy)

        grid_info = {
            "extremes": extremes,
            "lon_lines": lon_lines,
            "lat_lines": lat_lines,
            "lon": self._clip_grid_lines_and_find_ticks(
                lon_lines, lon_values, lon_levs, bb),
            "lat": self._clip_grid_lines_and_find_ticks(
                lat_lines, lat_values, lat_levs, bb),
        }

        # Format the tick labels of each side: all lon sides first, then all
        # lat sides, in the order left, bottom, right, top.
        for key, formatter, factor in [
                ("lon", self.tick_formatter1, lon_factor),
                ("lat", self.tick_formatter2, lat_factor)]:
            info = grid_info[key]
            info["tick_labels"] = {
                direction: formatter(
                    direction, factor, info["tick_levels"][direction])
                for direction in ["left", "bottom", "right", "top"]}

        return grid_info

    def _get_raw_grid_lines(self,
                            lon_values, lat_values,
                            lon_min, lon_max, lat_min, lat_max):
        """
        Return the unclipped lon and lat grid lines in image coordinates.

        Each lon line is sampled at 100 lat values spanning
        ``[lat_min, lat_max]``, and each lat line at 100 lon values spanning
        ``[lon_min, lon_max]``.  The samples are mapped through
        `transform_xy`, and each line is the ``(x, y)`` pair it returns.
        """
        lons_i = np.linspace(lon_min, lon_max, 100)  # for interpolation
        lats_i = np.linspace(lat_min, lat_max, 100)

        lon_lines = [self.transform_xy(np.full_like(lats_i, lon), lats_i)
                     for lon in lon_values]
        lat_lines = [self.transform_xy(lons_i, np.full_like(lons_i, lat))
                     for lat in lat_values]

        return lon_lines, lat_lines

    def _clip_grid_lines_and_find_ticks(self, lines, values, levs, bb):
        """
        Clip grid *lines* to the box *bb* and collect their ticks.

        Parameters
        ----------
        lines : list of (x, y) pairs
            Grid lines in image coordinates.
        values : array-like
            World-coordinate value of each line.
        levs : array-like
            Locator level of each line, recorded as its tick level.
        bb : `~matplotlib.transforms.Bbox`
            Clipping box.

        Returns
        -------
        dict
            ``"levels"`` and ``"lines"`` hold the value and the clipped
            segments of each line that intersects *bb*.  ``"tick_levels"``
            and ``"tick_locs"`` map each of "left", "bottom", "right" and
            "top" to the levels and tick entries that `clip_line_to_rect`
            reports for that side.
        """
        gi = {
            # NOTE(review): "values" is never filled in here; it is kept only
            # because callers may read the key.
            "values": [],
            "levels": [],
            "tick_levels": dict(left=[], bottom=[], right=[], top=[]),
            "tick_locs": dict(left=[], bottom=[], right=[], top=[]),
            "lines": [],
        }

        tck_levels = gi["tick_levels"]
        tck_locs = gi["tick_locs"]
        for (lx, ly), v, lev in zip(lines, values, levs):
            xy, tcks = clip_line_to_rect(lx, ly, bb)
            if not xy:
                # The line lies entirely outside the box.
                continue
            gi["levels"].append(v)
            gi["lines"].append(xy)

            for tck, direction in zip(tcks,
                                      ["left", "bottom", "right", "top"]):
                for t in tck:
                    tck_levels[direction].append(lev)
                    tck_locs[direction].append(t)

        return gi

    def update_transform(self, aux_trans):
        """
        Set the world-to-image transform.

        *aux_trans* is either a `~matplotlib.transforms.Transform` or a pair
        ``(transform_xy, inv_transform_xy)`` of functions, each taking and
        returning two coordinate arrays.
        """
        if isinstance(aux_trans, Transform):
            def transform_xy(x, y):
                xy = aux_trans.transform(np.column_stack([x, y]))
                return xy[:, 0], xy[:, 1]

            def inv_transform_xy(x, y):
                # The inverse is recomputed on each call rather than cached.
                # Presumably this is so that later changes to a mutable
                # aux_trans are picked up.
                xy = aux_trans.inverted().transform(np.column_stack([x, y]))
                return xy[:, 0], xy[:, 1]

        else:
            transform_xy, inv_transform_xy = aux_trans

        self.transform_xy = transform_xy
        self.inv_transform_xy = inv_transform_xy

    def update(self, **kw):
        """
        Replace some of the helpers used to compute the grid.

        Accepted keywords are ``extreme_finder``, ``grid_locator1``,
        ``grid_locator2``, ``tick_formatter1`` and ``tick_formatter2``.

        Raises
        ------
        ValueError
            If any keyword is not one of the above.  All keywords are checked
            before anything is set, so a failing call leaves the finder
            unchanged.
        """
        allowed = ("extreme_finder",
                   "grid_locator1",
                   "grid_locator2",
                   "tick_formatter1",
                   "tick_formatter2")
        # Validate everything first so that an invalid keyword cannot leave
        # the finder half-updated.
        for k in kw:
            if k not in allowed:
                raise ValueError("Unknown update property '%s'" % k)
        for k, v in kw.items():
            setattr(self, k, v)


@cbook.deprecated("3.2")
class GridFinderBase(GridFinder):
    """Deprecated `GridFinder` variant that is created without a transform."""

    def __init__(self,
                 extreme_finder,
                 grid_locator1=None,
                 grid_locator2=None,
                 tick_formatter1=None,
                 tick_formatter2=None):
        # No transform is known yet.  A (None, None) pair makes both
        # transform_xy and inv_transform_xy start out as None.
        no_transform = (None, None)
        super().__init__(no_transform,
                         extreme_finder=extreme_finder,
                         grid_locator1=grid_locator1,
                         grid_locator2=grid_locator2,
                         tick_formatter1=tick_formatter1,
                         tick_formatter2=tick_formatter2)


class MaxNLocator(mticker.MaxNLocator):
    """
    `matplotlib.ticker.MaxNLocator` adapted to the grid-locator protocol.

    Instances are called as ``locator(v1, v2)`` and return
    ``(levels, number_of_levels, factor)``.
    """

    def __init__(self, nbins=10, steps=None,
                 trim=True,
                 integer=False,
                 symmetric=False,
                 prune=None):
        # *trim* has no effect; it is accepted only for API compatibility.
        super().__init__(nbins, steps=steps, integer=integer,
                         symmetric=symmetric, prune=prune)
        # A dummy axis is created, presumably because this locator is used
        # without being attached to a real Axis.
        self.create_dummy_axis()
        self._factor = 1

    def __call__(self, v1, v2):
        factor = self._factor
        self.set_bounds(v1 * factor, v2 * factor)
        tick_locs = super().__call__()
        return np.array(tick_locs), len(tick_locs), factor

    @cbook.deprecated("3.3")
    def set_factor(self, f):
        self._factor = _deprecate_factor_none(f)


class FixedLocator:
    """
    Grid locator that returns those of the fixed *locs* lying within the
    requested (inclusive) range.
    """

    def __init__(self, locs):
        self._locs = locs
        self._factor = 1

    def __call__(self, v1, v2):
        scaled = [v1 * self._factor, v2 * self._factor]
        lo, hi = sorted(scaled)
        inside = np.array([loc for loc in self._locs if lo <= loc <= hi])
        return inside, len(inside), self._factor

    @cbook.deprecated("3.3")
    def set_factor(self, f):
        self._factor = _deprecate_factor_none(f)


# Tick Formatter

class FormatterPrettyPrint:
    """Tick formatter delegating to `matplotlib.ticker.ScalarFormatter`."""

    def __init__(self, useMathText=True):
        scalar_fmt = mticker.ScalarFormatter(
            useMathText=useMathText, useOffset=False)
        scalar_fmt.create_dummy_axis()
        self._fmt = scalar_fmt

    def __call__(self, direction, factor, values):
        """
        Return the tick labels for *values*.

        *direction* and *factor* are part of the tick-formatter call
        signature but are not used here.
        """
        return self._fmt.format_ticks(values)


class DictFormatter:
    def __init__(self, format_dict, formatter=None):
        """
        Parameters
        ----------
        format_dict : dict
            Mapping from tick values to the label strings to use for them.
        formatter : callable, optional
            Fall-back formatter, called as
            ``formatter(direction, factor, values)``.  It labels values that
            are missing from *format_dict*; without it, such values get empty
            labels.
        """
        super().__init__()
        self._format_dict = format_dict
        self._fallback_formatter = formatter

    def __call__(self, direction, factor, values):
        """
        Return one label per entry of *values*.

        A value found in the format dictionary uses that entry, and *factor*
        plays no role for it.  Any other value uses the fall-back label.
        """
        fallback = self._fallback_formatter
        defaults = (fallback(direction, factor, values) if fallback
                    else [""] * len(values))
        lookup = self._format_dict.get
        return [lookup(value, default)
                for value, default in zip(values, defaults)]