"""Catch all for categorical functions"""
import pytest
import numpy as np

import matplotlib as mpl
from matplotlib._api import MatplotlibDeprecationWarning
from matplotlib.axes import Axes
import matplotlib.pyplot as plt
import matplotlib.category as cat
from matplotlib.testing.decorators import check_figures_equal


class TestUnitData:
    """Tests for the value -> location mapping held by `cat.UnitData`."""

    # (id, (input strings, expected locations)) for valid, all-string input.
    test_cases = [('single', (["hello world"], [0])),
                  ('unicode', (["Здравствуйте мир"], [0])),
                  ('mixed', (['A', "np.nan", 'B', "3.14", "мир"],
                             [0, 1, 2, 3, 4]))]
    ids, data = zip(*test_cases)

    @pytest.mark.parametrize("data, locs", data, ids=ids)
    def test_unit(self, data, locs):
        """The mapping keys/values match the input strings and *locs*."""
        unit = cat.UnitData(data)
        assert list(unit._mapping.keys()) == data
        assert list(unit._mapping.values()) == locs

    def test_update(self):
        """Updating appends only unseen values and keeps earlier entries."""
        data = ['a', 'd']
        locs = [0, 1]

        data_update = ['b', 'd', 'e']
        unique_data = ['a', 'd', 'b', 'e']
        updated_locs = [0, 1, 2, 3]

        unit = cat.UnitData(data)
        assert list(unit._mapping.keys()) == data
        assert list(unit._mapping.values()) == locs

        unit.update(data_update)
        assert list(unit._mapping.keys()) == unique_data
        assert list(unit._mapping.values()) == updated_locs

    # (id, input) pairs where at least one value is not a string.
    failing_test_cases = [("number", 3.14), ("nan", np.nan),
                          ("list", [3.14, 12]), ("mixed type", ["A", 2])]

    # Unpack the *failing* cases: unpacking ``test_cases`` here would feed
    # the (data, locs) tuples of the valid cases to the tests below and
    # leave the non-string inputs above untested.
    fids, fdata = zip(*failing_test_cases)

    @pytest.mark.parametrize("fdata", fdata, ids=fids)
    def test_non_string_fails(self, fdata):
        """Constructing `cat.UnitData` from non-string data is a TypeError."""
        with pytest.raises(TypeError):
            cat.UnitData(fdata)

    @pytest.mark.parametrize("fdata", fdata, ids=fids)
    def test_non_string_update_fails(self, fdata):
        """Updating an empty `cat.UnitData` with non-string data raises."""
        unitdata = cat.UnitData()
        with pytest.raises(TypeError):
            unitdata.update(fdata)


class FakeAxis:
    """Bare stand-in for an axis: it only carries a ``units`` attribute."""

    def __init__(self, units):
        # Store the unit object exactly as it was handed in.
        self.units = units


class TestStrCategoryConverter:
    """
    Tests for `cat.StrCategoryConverter`.

    Based on the pandas conversion and factorization tests:

    ref: /pandas/tseries/tests/test_converter.py
         /pandas/tests/test_algos.py:TestFactorize
    """
    # (id, values) pairs that are expected to convert cleanly.
    test_cases = [
        ("unicode", ["Здравствуйте мир"]),
        ("ascii", ["hello world"]),
        ("single", ['a', 'b', 'c']),
        ("integer string", ["1", "2"]),
        ("single + values>10", [
            "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M",
            "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z",
        ]),
    ]

    ids, values = zip(*test_cases)

    # (id, values) pairs mixing strings with non-strings.
    failing_test_cases = [
        ("mixed", [3.14, 'A', np.inf]),
        ("string integer", ['42', 42]),
    ]

    fids, fvalues = zip(*failing_test_cases)

    @pytest.fixture(autouse=True)
    def mock_axis(self, request):
        """Give every test a fresh converter, unit data and stand-in axis."""
        self.cc = cat.StrCategoryConverter()
        # self.unit should probably be replaced with a real mock unit
        self.unit = cat.UnitData()
        self.ax = FakeAxis(self.unit)

    @pytest.mark.parametrize("vals", values, ids=ids)
    def test_convert(self, vals):
        """Strings convert to consecutive locations starting at zero."""
        converted = self.cc.convert(vals, self.ax.units, self.ax)
        expected = range(len(vals))
        np.testing.assert_allclose(converted, expected)

    @pytest.mark.parametrize("value", ["hi", "мир"], ids=["ascii", "unicode"])
    def test_convert_one_string(self, value):
        """A lone string converts to location 0."""
        converted = self.cc.convert(value, self.unit, self.ax)
        assert converted == 0

    def test_convert_one_number(self):
        """A scalar number comes back unchanged, with a deprecation warning."""
        with pytest.warns(MatplotlibDeprecationWarning):
            converted = self.cc.convert(0.0, self.unit, self.ax)
        np.testing.assert_allclose(converted, np.array([0.]))

    def test_convert_float_array(self):
        """A float array comes back unchanged, with a deprecation warning."""
        floats = np.array([1, 2, 3], dtype=float)
        with pytest.warns(MatplotlibDeprecationWarning):
            converted = self.cc.convert(floats, self.unit, self.ax)
        np.testing.assert_allclose(converted, np.array([1., 2., 3.]))

    @pytest.mark.parametrize("fvals", fvalues, ids=fids)
    def test_convert_fail(self, fvals):
        """Mixed string/non-string input raises TypeError."""
        with pytest.raises(TypeError):
            self.cc.convert(fvals, self.unit, self.ax)

    def test_axisinfo(self):
        """The axis info carries the categorical locator and formatter."""
        info = self.cc.axisinfo(self.unit, self.ax)
        assert isinstance(info.majloc, cat.StrCategoryLocator)
        assert isinstance(info.majfmt, cat.StrCategoryFormatter)

    def test_default_units(self):
        """The default units for string data are a `cat.UnitData`."""
        units = self.cc.default_units(["a"], self.ax)
        assert isinstance(units, cat.UnitData)


# Axes plotting methods used to parametrize the plotting tests, paired
# index-for-index with the ids pytest reports for them.
PLOT_IDS = ["scatter", "plot", "bar"]
PLOT_LIST = [getattr(Axes, plot_name) for plot_name in PLOT_IDS]


class TestStrCategoryLocator:
    """Tests for the tick locations used with categorical data."""

    def test_StrCategoryLocator(self):
        """Tick values equal the locations stored in the unit mapping."""
        expected = list(range(11))
        unit = cat.UnitData([str(loc) for loc in expected])
        locator = cat.StrCategoryLocator(unit._mapping)
        tick_values = locator.tick_values(None, None)
        np.testing.assert_array_equal(tick_values, expected)

    @pytest.mark.parametrize("plotter", PLOT_LIST, ids=PLOT_IDS)
    def test_StrCategoryLocatorPlot(self, plotter):
        """Plotting three y categories puts major ticks at 0, 1 and 2."""
        fig = plt.figure()
        ax = fig.subplots()
        plotter(ax, [1, 2, 3], ["a", "b", "c"])
        tick_locs = ax.yaxis.major.locator()
        np.testing.assert_array_equal(tick_locs, range(3))


class TestStrCategoryFormatter:
    """Tests for the tick labels used with categorical data."""

    # (id, category labels) pairs.
    test_cases = [
        ("ascii", ["hello", "world", "hi"]),
        ("unicode", ["Здравствуйте", "привет"]),
    ]

    ids, cases = zip(*test_cases)

    @pytest.mark.parametrize("ydata", cases, ids=ids)
    def test_StrCategoryFormatter(self, ydata):
        """Each location formats to its label, with or without a position."""
        unit = cat.UnitData(ydata)
        formatter = cat.StrCategoryFormatter(unit._mapping)
        for pos, label in enumerate(ydata):
            assert formatter(pos, pos) == label
            assert formatter(pos, None) == label

    @pytest.mark.parametrize("ydata", cases, ids=ids)
    @pytest.mark.parametrize("plotter", PLOT_LIST, ids=PLOT_IDS)
    def test_StrCategoryFormatterPlot(self, ydata, plotter):
        """After plotting, the y-axis formatter returns the plotted labels."""
        fig = plt.figure()
        ax = fig.subplots()
        plotter(ax, range(len(ydata)), ydata)
        for pos, label in enumerate(ydata):
            assert ax.yaxis.major.formatter(pos) == label
        # One step past the last category formats to an empty label.
        assert ax.yaxis.major.formatter(pos + 1) == ""


def axis_test(axis, labels):
    """
    Assert that *axis* is set up for exactly the categories in *labels*.

    Checks the major tick locations, the text the major formatter returns
    for each location, and the keys/values of the axis' unit mapping.
    """
    expected_locs = list(range(len(labels)))
    np.testing.assert_array_equal(axis.get_majorticklocs(), expected_locs)

    shown_text = [axis.major.formatter(loc, loc) for loc in expected_locs]
    # _text also decodes bytes as utf-8.
    expected_text = [cat.StrCategoryFormatter._text(label)
                     for label in labels]
    assert shown_text == expected_text

    mapping = axis.units._mapping
    assert list(mapping.keys()) == list(labels)
    assert list(mapping.values()) == expected_locs


class TestPlotBytes:
    """Plotting str and bytes categories on the x axis."""

    # (id, x categories) pairs.
    bytes_cases = [
        ('string list', ['a', 'b', 'c']),
        ('bytes list', [b'a', b'b', b'c']),
        ('bytes ndarray', np.array([b'a', b'b', b'c'])),
    ]

    bytes_ids, bytes_data = zip(*bytes_cases)

    @pytest.mark.parametrize("plotter", PLOT_LIST, ids=PLOT_IDS)
    @pytest.mark.parametrize("bdata", bytes_data, ids=bytes_ids)
    def test_plot_bytes(self, plotter, bdata):
        """The x axis reflects *bdata* after plotting it against counts."""
        fig = plt.figure()
        ax = fig.subplots()
        heights = np.array([4, 6, 5])
        plotter(ax, bdata, heights)
        axis_test(ax.xaxis, bdata)


class TestPlotNumlike:
    """Plotting number-like str/bytes values as categories on the x axis."""

    # (id, x categories) pairs; every value looks like a number.
    numlike_cases = [
        ('string list', ['1', '11', '3']),
        ('string ndarray', np.array(['1', '11', '3'])),
        ('bytes list', [b'1', b'11', b'3']),
        ('bytes ndarray', np.array([b'1', b'11', b'3'])),
    ]
    numlike_ids, numlike_data = zip(*numlike_cases)

    @pytest.mark.parametrize("plotter", PLOT_LIST, ids=PLOT_IDS)
    @pytest.mark.parametrize("ndata", numlike_data, ids=numlike_ids)
    def test_plot_numlike(self, plotter, ndata):
        """The x axis reflects *ndata* after plotting it against counts."""
        fig = plt.figure()
        ax = fig.subplots()
        heights = np.array([4, 6, 5])
        plotter(ax, ndata, heights)
        axis_test(ax.xaxis, ndata)


class TestPlotTypes:
    """Categorical data on the x and/or y axis of each plotting method."""

    @pytest.mark.parametrize("plotter", PLOT_LIST, ids=PLOT_IDS)
    def test_plot_unicode(self, plotter):
        """Non-ASCII categories end up on the x axis."""
        ax = plt.figure().subplots()
        words = ['Здравствуйте', 'привет']
        plotter(ax, words, [0, 1])
        axis_test(ax.xaxis, words)

    @pytest.fixture
    def test_data(self):
        """Attach sample categorical and numeric data to the test instance."""
        self.x = ["hello", "happy", "world"]
        self.xy = [2, 6, 3]
        self.y = ["Python", "is", "fun"]
        self.yx = [3, 4, 5]

    # The tests below request ``test_data`` as an argument, which is enough
    # to run the fixture; no extra ``usefixtures`` mark is needed.
    @pytest.mark.parametrize("plotter", PLOT_LIST, ids=PLOT_IDS)
    def test_plot_xaxis(self, test_data, plotter):
        """Categories on x, numbers on y."""
        ax = plt.figure().subplots()
        plotter(ax, self.x, self.xy)
        axis_test(ax.xaxis, self.x)

    @pytest.mark.parametrize("plotter", PLOT_LIST, ids=PLOT_IDS)
    def test_plot_yaxis(self, test_data, plotter):
        """Numbers on x, categories on y."""
        ax = plt.figure().subplots()
        plotter(ax, self.yx, self.y)
        axis_test(ax.yaxis, self.y)

    @pytest.mark.parametrize("plotter", PLOT_LIST, ids=PLOT_IDS)
    def test_plot_xyaxis(self, test_data, plotter):
        """Categories on both axes."""
        ax = plt.figure().subplots()
        plotter(ax, self.x, self.y)
        axis_test(ax.xaxis, self.x)
        axis_test(ax.yaxis, self.y)

    @pytest.mark.parametrize("plotter", PLOT_LIST, ids=PLOT_IDS)
    def test_update_plot(self, plotter):
        """Repeated plotting appends new categories in first-seen order."""
        ax = plt.figure().subplots()
        plotter(ax, ['a', 'b'], ['e', 'g'])
        plotter(ax, ['a', 'b', 'd'], ['f', 'a', 'b'])
        plotter(ax, ['b', 'c', 'd'], ['g', 'e', 'd'])
        axis_test(ax.xaxis, ['a', 'b', 'd', 'c'])
        axis_test(ax.yaxis, ['e', 'g', 'f', 'a', 'b', 'd'])

    # (id, x data) pairs mixing a string with a non-string value.
    failing_test_cases = [("mixed", ['A', 3.14]),
                          ("number integer", ['1', 1]),
                          ("string integer", ['42', 42]),
                          ("missing", ['12', np.nan])]

    fids, fvalues = zip(*failing_test_cases)

    # Axes.plot is marked as an expected failure for the tests below.
    plotters = [Axes.scatter, Axes.bar,
                pytest.param(Axes.plot, marks=pytest.mark.xfail)]

    @pytest.mark.parametrize("plotter", plotters)
    @pytest.mark.parametrize("xdata", fvalues, ids=fids)
    def test_mixed_type_exception(self, plotter, xdata):
        """Plotting mixed-type x data on fresh axes raises TypeError."""
        ax = plt.figure().subplots()
        with pytest.raises(TypeError):
            plotter(ax, xdata, [1, 2])

    @pytest.mark.parametrize("plotter", plotters)
    @pytest.mark.parametrize("xdata", fvalues, ids=fids)
    def test_mixed_type_update_exception(self, plotter, xdata):
        """Plotting mixed-type x data after numeric data raises TypeError."""
        ax = plt.figure().subplots()
        # Set-up call with plain numbers. It stays outside the ``raises``
        # block so that a TypeError from it cannot satisfy the check below.
        plotter(ax, [0, 3], [1, 3])
        with pytest.raises(TypeError):
            plotter(ax, xdata, [1, 2])


@mpl.style.context('default')
@check_figures_equal(extensions=["png"])
def test_overriding_units_in_plot(fig_test, fig_ref):
    """
    Plot twice on the same axes and check the axis units are not replaced.

    The test figure plots without unit keywords; the reference figure passes
    ``xunits=None, yunits=None`` explicitly.
    """
    from datetime import datetime

    first_dates = [datetime(2018, 3, 1), datetime(2018, 3, 2)]
    second_dates = [datetime(2018, 3, 3), datetime(2018, 3, 4)]

    cases = [
        (fig_test.subplots(), {}),
        (fig_ref.subplots(), {"xunits": None, "yunits": None}),
    ]
    for ax, kwargs in cases:
        # First call works
        ax.plot(first_dates, ["V1", "V2"], **kwargs)
        units_before = (ax.xaxis.units, ax.yaxis.units)
        # this should not raise
        ax.plot(second_dates, ["V1", "V2"], **kwargs)
        # assert that we have not re-set the units attribute at all
        assert units_before[0] is ax.xaxis.units
        assert units_before[1] is ax.yaxis.units


def test_hist():
    """Histogram of string data yields ten bins with the expected counts."""
    _fig, ax = plt.subplots()
    counts, _bins, _patches = ax.hist(['a', 'b', 'a', 'c', 'ff'])
    expected = [2., 0., 0., 1., 0., 0., 1., 0., 0., 1.]
    assert counts.shape == (10,)
    np.testing.assert_allclose(counts, expected)
