from __future__ import absolute_import

import numbers

from plotly import exceptions, optional_imports
from plotly.figure_factory import utils
from plotly.graph_objs import graph_objs

# Optional imports, may be None for users that only use our core functionality.
np = optional_imports.get_module('numpy')
pd = optional_imports.get_module('pandas')
scipy = optional_imports.get_module('scipy')
scipy_stats = optional_imports.get_module('scipy.stats')


# Default value of create_distplot()'s `histnorm` argument; it is passed
# straight through to the histogram traces' `histnorm` attribute.
DEFAULT_HISTNORM = 'probability density'
# Alternative `histnorm`. When it is selected, the kde / normal curves are
# multiplied by the dataset's bin size so they are drawn on the same scale as
# the histogram bars.
ALTERNATIVE_HISTNORM = 'probability'


def validate_distplot(hist_data, curve_type):
    """
    Distplot-specific validations

    :param (list) hist_data: the datasets handed to create_distplot(); must
        be a non-empty list whose first element is itself a list (or a numpy
        array / pandas Series when those libraries are installed).
    :param (str) curve_type: requested curve type, 'kde' or 'normal'.
    :raises: (PlotlyError) If hist_data is empty or is not a list of lists
    :raises: (PlotlyError) If curve_type is not valid (i.e. not 'kde' or
        'normal').
    :raises: (ImportError) If scipy (and scipy.stats) is not installed.
    """
    hist_data_types = (list,)
    if np:
        hist_data_types += (np.ndarray,)
    if pd:
        hist_data_types += (pd.Series,)

    # len() rather than truthiness: a numpy array has an ambiguous truth
    # value. Without this guard an empty container would fail below with a
    # bare IndexError instead of a readable PlotlyError.
    if len(hist_data) == 0:
        raise exceptions.PlotlyError("hist_data must contain at least one "
                                     "dataset, i.e. x = [[1, 2, 3]]")

    if not isinstance(hist_data[0], hist_data_types):
        raise exceptions.PlotlyError("Oops, this function was written "
                                     "to handle multiple datasets, if "
                                     "you want to plot just one, make "
                                     "sure your hist_data variable is "
                                     "still a list of lists, i.e. x = "
                                     "[1, 2, 3] -> x = [[1, 2, 3]]")

    curve_opts = ('kde', 'normal')
    if curve_type not in curve_opts:
        raise exceptions.PlotlyError("curve_type must be defined as "
                                     "'kde' or 'normal'")

    # The curve builders use scipy.stats specifically, so require both.
    if not scipy or not scipy_stats:
        raise ImportError("FigureFactory.create_distplot requires scipy")


def create_distplot(hist_data, group_labels, bin_size=1., curve_type='kde',
                    colors=None, rug_text=None, histnorm=DEFAULT_HISTNORM,
                    show_hist=True, show_curve=True, show_rug=True):
    """
    BETA function that creates a distplot similar to seaborn.distplot

    The distplot can be composed of all or any combination of the following
    3 components: (1) histogram, (2) curve: (a) kernel density estimation
    or (b) normal curve, and (3) rug plot. Additionally, multiple distplots
    (from multiple datasets) can be created in the same plot.

    :param (list[list]) hist_data: Use list of lists to plot multiple data
        sets on the same plot.
    :param (list[str]) group_labels: Names for each data set.
    :param (list[float]|float) bin_size: Size of histogram bins. A single
        number is used for every data set; a list gives one size per data
        set. Default = 1.
    :param (str) curve_type: 'kde' or 'normal'. Default = 'kde'
    :param (str) histnorm: 'probability density' or 'probability'
        Default = 'probability density'
    :param (bool) show_hist: Add histogram to distplot? Default = True
    :param (bool) show_curve: Add curve to distplot? Default = True
    :param (bool) show_rug: Add rug to distplot? Default = True
    :param (list[str]) colors: Colors for traces. Cycled if there are more
        data sets than colors; a built-in palette is used when omitted.
    :param (list[list]) rug_text: Hovertext values for rug_plot, one list
        per data set.
    :return (plotly.graph_objs.Figure): Representation of a distplot figure.

    Example 1: Simple distplot of 1 data set
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_distplot

    hist_data = [[1.1, 1.1, 2.5, 3.0, 3.5,
                  3.5, 4.1, 4.4, 4.5, 4.5,
                  5.0, 5.0, 5.2, 5.5, 5.5,
                  5.5, 5.5, 5.5, 6.1, 7.0]]

    group_labels = ['distplot example']

    fig = create_distplot(hist_data, group_labels)

    url = py.plot(fig, filename='Simple distplot', validate=False)
    ```

    Example 2: Two data sets and added rug text
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_distplot

    # Add histogram data
    hist1_x = [0.8, 1.2, 0.2, 0.6, 1.6,
               -0.9, -0.07, 1.95, 0.9, -0.2,
               -0.5, 0.3, 0.4, -0.37, 0.6]
    hist2_x = [0.8, 1.5, 1.5, 0.6, 0.59,
               1.0, 0.8, 1.7, 0.5, 0.8,
               -0.3, 1.2, 0.56, 0.3, 2.2]

    # Group data together
    hist_data = [hist1_x, hist2_x]

    group_labels = ['2012', '2013']

    # Add text
    rug_text_1 = ['a1', 'b1', 'c1', 'd1', 'e1',
          'f1', 'g1', 'h1', 'i1', 'j1',
          'k1', 'l1', 'm1', 'n1', 'o1']

    rug_text_2 = ['a2', 'b2', 'c2', 'd2', 'e2',
          'f2', 'g2', 'h2', 'i2', 'j2',
          'k2', 'l2', 'm2', 'n2', 'o2']

    # Group text together
    rug_text_all = [rug_text_1, rug_text_2]

    # Create distplot
    fig = create_distplot(
        hist_data, group_labels, rug_text=rug_text_all, bin_size=.2)

    # Add title
    fig['layout'].update(title='Dist Plot')

    # Plot!
    url = py.plot(fig, filename='Distplot with rug text', validate=False)
    ```

    Example 3: Plot with normal curve and hide rug plot
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_distplot
    import numpy as np

    x1 = np.random.randn(190)
    x2 = np.random.randn(200)+1
    x3 = np.random.randn(200)-1
    x4 = np.random.randn(210)+2

    hist_data = [x1, x2, x3, x4]
    group_labels = ['2012', '2013', '2014', '2015']

    fig = create_distplot(
        hist_data, group_labels, curve_type='normal',
        show_rug=False, bin_size=.4)

    url = py.plot(fig, filename='hist and normal curve', validate=False)
    ```

    Example 4: Distplot with Pandas
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_distplot
    import numpy as np
    import pandas as pd

    df = pd.DataFrame({'2012': np.random.randn(200),
                       '2013': np.random.randn(200)+1})
    py.iplot(create_distplot([df[c] for c in df.columns], df.columns),
                             filename='examples/distplot with pandas',
                             validate=False)
    ```
    """
    if colors is None:
        colors = []
    if rug_text is None:
        rug_text = []

    validate_distplot(hist_data, curve_type)
    utils.validate_equal_length(hist_data, group_labels)

    # A single number applies to every dataset. numbers.Real also accepts
    # numpy scalars such as np.int64, which are not subclasses of int.
    if isinstance(bin_size, numbers.Real):
        bin_size = [bin_size] * len(hist_data)

    # One builder is enough: the per-dataset start/end ranges computed by its
    # constructor are shared by every component.
    distplot = _Distplot(
        hist_data, histnorm, group_labels, bin_size,
        curve_type, colors, rug_text,
        show_hist, show_curve)

    # Only build the components that are displayed. In particular, the curve
    # fit is skipped when show_curve is False, so it cannot fail on data that
    # is only being drawn as a histogram / rug.
    data = []
    if show_hist:
        data.extend(distplot.make_hist())
    if show_curve:
        if curve_type == 'normal':
            data.extend(distplot.make_normal())
        else:
            data.extend(distplot.make_kde())
    if show_rug:
        data.extend(distplot.make_rug())

    layout_kwargs = dict(
        barmode='overlay',
        hovermode='closest',
        legend=dict(traceorder='reversed'),
        xaxis1=dict(domain=[0.0, 1.0],
                    anchor='y2',
                    zeroline=False),
        # With a rug plot the main axis shares vertical space with yaxis2.
        yaxis1=dict(domain=[0.35, 1] if show_rug else [0., 1],
                    anchor='free',
                    position=0.0))
    if show_rug:
        layout_kwargs['yaxis2'] = dict(domain=[0, 0.25],
                                       anchor='x1',
                                       dtick=1,
                                       showticklabels=False)
    layout = graph_objs.Layout(**layout_kwargs)

    return graph_objs.Figure(data=data, layout=layout)


class _Distplot(object):
    """
    Refer to TraceFactory.create_distplot() for docstring
    """
    def __init__(self, hist_data, histnorm, group_labels,
                 bin_size, curve_type, colors,
                 rug_text, show_hist, show_curve):
        self.hist_data = hist_data
        self.histnorm = histnorm
        self.group_labels = group_labels
        self.bin_size = bin_size
        self.show_hist = show_hist
        self.show_curve = show_curve
        self.trace_number = len(hist_data)
        if rug_text:
            self.rug_text = rug_text
        else:
            self.rug_text = [None] * self.trace_number

        self.start = []
        self.end = []
        if colors:
            self.colors = colors
        else:
            self.colors = [
                "rgb(31, 119, 180)", "rgb(255, 127, 14)",
                "rgb(44, 160, 44)", "rgb(214, 39, 40)",
                "rgb(148, 103, 189)", "rgb(140, 86, 75)",
                "rgb(227, 119, 194)", "rgb(127, 127, 127)",
                "rgb(188, 189, 34)", "rgb(23, 190, 207)"]
        self.curve_x = [None] * self.trace_number
        self.curve_y = [None] * self.trace_number

        for trace in self.hist_data:
            self.start.append(min(trace) * 1.)
            self.end.append(max(trace) * 1.)

    def make_hist(self):
        """
        Makes the histogram(s) for FigureFactory.create_distplot().

        :rtype (list) hist: list of histogram representations
        """
        hist = [None] * self.trace_number

        for index in range(self.trace_number):
            hist[index] = dict(type='histogram',
                               x=self.hist_data[index],
                               xaxis='x1',
                               yaxis='y1',
                               histnorm=self.histnorm,
                               name=self.group_labels[index],
                               legendgroup=self.group_labels[index],
                               marker=dict(color=self.colors[index % len(self.colors)]),
                               autobinx=False,
                               xbins=dict(start=self.start[index],
                                          end=self.end[index],
                                          size=self.bin_size[index]),
                               opacity=.7)
        return hist

    def make_kde(self):
        """
        Makes the kernel density estimation(s) for create_distplot().

        This is called when curve_type = 'kde' in create_distplot().

        :rtype (list) curve: list of kde representations
        """
        curve = [None] * self.trace_number
        for index in range(self.trace_number):
            self.curve_x[index] = [self.start[index] +
                                   x * (self.end[index] - self.start[index])
                                   / 500 for x in range(500)]
            self.curve_y[index] = (scipy_stats.gaussian_kde
                                   (self.hist_data[index])
                                   (self.curve_x[index]))

            if self.histnorm == ALTERNATIVE_HISTNORM:
                self.curve_y[index] *= self.bin_size[index]

        for index in range(self.trace_number):
            curve[index] = dict(type='scatter',
                                x=self.curve_x[index],
                                y=self.curve_y[index],
                                xaxis='x1',
                                yaxis='y1',
                                mode='lines',
                                name=self.group_labels[index],
                                legendgroup=self.group_labels[index],
                                showlegend=False if self.show_hist else True,
                                marker=dict(color=self.colors[index % len(self.colors)]))
        return curve

    def make_normal(self):
        """
        Makes the normal curve(s) for create_distplot().

        This is called when curve_type = 'normal' in create_distplot().

        :rtype (list) curve: list of normal curve representations
        """
        curve = [None] * self.trace_number
        mean = [None] * self.trace_number
        sd = [None] * self.trace_number

        for index in range(self.trace_number):
            mean[index], sd[index] = (scipy_stats.norm.fit
                                      (self.hist_data[index]))
            self.curve_x[index] = [self.start[index] +
                                   x * (self.end[index] - self.start[index])
                                   / 500 for x in range(500)]
            self.curve_y[index] = scipy_stats.norm.pdf(
                self.curve_x[index], loc=mean[index], scale=sd[index])

            if self.histnorm == ALTERNATIVE_HISTNORM:
                self.curve_y[index] *= self.bin_size[index]

        for index in range(self.trace_number):
            curve[index] = dict(type='scatter',
                                x=self.curve_x[index],
                                y=self.curve_y[index],
                                xaxis='x1',
                                yaxis='y1',
                                mode='lines',
                                name=self.group_labels[index],
                                legendgroup=self.group_labels[index],
                                showlegend=False if self.show_hist else True,
                                marker=dict(color=self.colors[index % len(self.colors)]))
        return curve

    def make_rug(self):
        """
        Makes the rug plot(s) for create_distplot().

        :rtype (list) rug: list of rug plot representations
        """
        rug = [None] * self.trace_number
        for index in range(self.trace_number):

            rug[index] = dict(type='scatter',
                              x=self.hist_data[index],
                              y=([self.group_labels[index]] *
                                 len(self.hist_data[index])),
                              xaxis='x1',
                              yaxis='y2',
                              mode='markers',
                              name=self.group_labels[index],
                              legendgroup=self.group_labels[index],
                              showlegend=(False if self.show_hist or
                                          self.show_curve else True),
                              text=self.rug_text[index],
                              marker=dict(color=self.colors[index % len(self.colors)],
                                          symbol='line-ns-open'))
        return rug
