from __future__ import absolute_import

from plotly import exceptions, optional_imports
from plotly.graph_objs import graph_objs

pd = optional_imports.get_module('pandas')


def validate_table(table_text, font_colors):
    """
    Table-specific validations

    Check that font_colors is supplied correctly: 1 color, 3 colors, or one
        color per table row.

    When table_text is a pandas DataFrame its column names are rendered as
    an additional header row, so "one color per row" means
    len(table_text) + 1 colors. That length is accepted here in addition to
    len(table_text).

    :param (pandas.Dataframe | list[list]) table_text: data for table.
    :param (list) font_colors: font colors to validate.

    :raises: (PlotlyError) If font_colors is supplied incorrectly.

    See FigureFactory.create_table() for params
    """
    row_count = len(table_text)
    font_colors_len_options = [1, 3, row_count]
    # A DataFrame gains a header row (its column names) when it is turned
    # into table rows, so the rendered table has one more row than
    # len(table_text). The plain len(table_text) stays accepted so that
    # calls which passed validation before keep passing.
    if pd and isinstance(table_text, pd.DataFrame):
        font_colors_len_options.append(row_count + 1)
    if len(font_colors) not in font_colors_len_options:
        raise exceptions.PlotlyError("Oops, font_colors should be a list "
                                     "of length 1, 3 or len(text)")


def create_table(table_text, colorscale=None, font_colors=None,
                 index=False, index_title='', annotation_offset=.45,
                 height_constant=30, hoverinfo='none', **kwargs):
    """
    BETA function that creates data tables

    The table is drawn as a heatmap trace (which supplies the cell
    background colors) overlaid with one text annotation per cell.

    :param (pandas.Dataframe | list[list]) table_text: data for table.
    :param (str|list[list]) colorscale: Colorscale for table where the
        color at value 0 is the header color, .5 is the first table color
        and 1 is the second table color. (Set .5 and 1 to avoid the striped
        table effect). Default=[[0, '#00083e'], [.5, '#ededee'],
        [1, '#ffffff']]
    :param (list) font_colors: Color for fonts in table. Can be a single
        color, three colors, or a color for each row in the table.
        Default=['#ffffff', '#000000', '#000000'] (white header text, black
        text for the remaining rows)
    :param (bool) index: Create (header-colored) index column index from
        Pandas dataframe or list[0] for each list in text. Default=False.
    :param (string) index_title: Title for index column. Default=''.
    :param (float) annotation_offset: Amount subtracted from each column's
        x coordinate when positioning the (left-anchored) cell text.
        Default=.45.
    :param (int) height_constant: Constant multiplied by # of rows to
        create table height (50 is then added). Default=30.
    :param (string) hoverinfo: hoverinfo value set on the heatmap trace.
        Default='none'.
    :param kwargs: kwargs passed through plotly.graph_objs.Heatmap.
        These kwargs describe other attributes about the annotated Heatmap
        trace. For more information on valid kwargs
        call help(plotly.graph_objs.Heatmap)

    :raises: (PlotlyError) If font_colors has an unsupported length.

    :rtype (plotly.graph_objs.Figure): figure holding the heatmap trace and
        the cell annotations.

    Example 1: Simple Plotly Table
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_table

    text = [['Country', 'Year', 'Population'],
            ['US', 2000, 282200000],
            ['Canada', 2000, 27790000],
            ['US', 2010, 309000000],
            ['Canada', 2010, 34000000]]

    table = create_table(text)
    py.iplot(table)
    ```

    Example 2: Table with Custom Coloring
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_table

    text = [['Country', 'Year', 'Population'],
            ['US', 2000, 282200000],
            ['Canada', 2000, 27790000],
            ['US', 2010, 309000000],
            ['Canada', 2010, 34000000]]

    table = create_table(text,
                         colorscale=[[0, '#000000'],
                                     [.5, '#80beff'],
                                     [1, '#cce5ff']],
                         font_colors=['#ffffff', '#000000',
                                      '#000000'])
    py.iplot(table)
    ```
    Example 3: Simple Plotly Table with Pandas
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_table

    import pandas as pd

    df = pd.read_csv('http://www.stat.ubc.ca/~jenny/notOcto/STAT545A/examples/gapminder/data/gapminderDataFiveYear.txt', sep='\t')
    df_p = df[0:25]

    table_simple = create_table(df_p)
    py.iplot(table_simple)
    ```
    """

    # Avoiding mutables in the call signature
    if colorscale is None:
        colorscale = [[0, '#00083e'], [.5, '#ededee'], [1, '#ffffff']]
    if font_colors is None:
        font_colors = ['#ffffff', '#000000', '#000000']

    validate_table(table_text, font_colors)

    # Build the helper once; both the z matrix and the annotations are
    # derived from the same normalized table rows.
    table = _Table(table_text, colorscale, font_colors, index,
                   index_title, annotation_offset, **kwargs)
    table_matrix = table.get_table_matrix()
    annotations = table.make_table_annotations()

    trace = dict(type='heatmap', z=table_matrix, opacity=.75,
                 colorscale=colorscale, showscale=False,
                 hoverinfo=hoverinfo, **kwargs)

    # Hide axis decorations and use the grid lines (one per unit, offset by
    # half a cell) as the visible cell borders; the y axis is reversed so
    # the header row is drawn at the top.
    axis_common = dict(zeroline=False, gridwidth=2, ticks='', dtick=1,
                       showticklabels=False)
    layout = dict(annotations=annotations,
                  height=len(table_matrix) * height_constant + 50,
                  margin=dict(t=0, b=0, r=0, l=0),
                  yaxis=dict(autorange='reversed', tick0=.5, **axis_common),
                  xaxis=dict(tick0=-0.5, **axis_common))
    return graph_objs.Figure(data=[trace], layout=layout)


class _Table(object):
    """
    Helper that converts table text into the heatmap z matrix and the
    per-cell annotations of a table figure.

    Refer to create_table() for a description of the parameters.
    """
    def __init__(self, table_text, colorscale, font_colors, index,
                 index_title, annotation_offset, **kwargs):
        # Normalize a DataFrame into a list of rows: the column names become
        # the header row and, when requested, the DataFrame index becomes
        # the first column (titled index_title).
        if pd and isinstance(table_text, pd.DataFrame):
            headers = table_text.columns.tolist()
            table_text_index = table_text.index.tolist()
            table_text = table_text.values.tolist()
            table_text.insert(0, headers)
            if index:
                table_text_index.insert(0, index_title)
                for row, label in zip(table_text, table_text_index):
                    row.insert(0, label)
        self.table_text = table_text
        self.colorscale = colorscale
        self.font_colors = font_colors
        self.index = index
        self.annotation_offset = annotation_offset
        # Cell coordinates: one x per column of the first row, one y per row
        self.x = range(len(table_text[0]))
        self.y = range(len(table_text))

    def get_table_matrix(self):
        """
        Create z matrix to make heatmap with striped table coloring

        The header row is filled with 0, odd rows with .5 and even rows
        with 1. When an index column is used, the first cell of every row
        is set to 0 so that it takes the header color. Every row of the
        result is an independent list.

        :rtype (list[list]) table_matrix: z matrix to make heatmap with striped
            table coloring.
        """
        n_cols = len(self.table_text[0])
        table_matrix = [[0] * n_cols]
        for i in range(1, len(self.table_text)):
            stripe = .5 if i % 2 else 1
            # A fresh list per row: sharing one list between rows would
            # make an edit to one row show up in every row of that parity.
            table_matrix.append([stripe] * n_cols)
        if self.index:
            for row in table_matrix:
                row[0] = 0
        return table_matrix

    def get_table_font_color(self):
        """
        Fill font-color array.

        Table text color can vary by row so this extends a single color or
        creates an array to set a header color and two alternating colors to
        create the striped table pattern. A font_colors sequence with one
        entry per row is returned as is; any other length falls back to
        black ('#000000') for every row.

        :rtype (list) all_font_colors: font color for each row in table.
        """
        n_rows = len(self.table_text)
        n_colors = len(self.font_colors)
        if n_colors == 1:
            return self.font_colors * n_rows
        if n_colors == 3:
            header_color, odd_color, even_color = self.font_colors
            return [header_color] + [odd_color if i % 2 else even_color
                                     for i in range(1, n_rows)]
        if n_colors == n_rows:
            return self.font_colors
        return ['#000000'] * n_rows

    def make_table_annotations(self):
        """
        Generate annotations to fill in table text

        :rtype (list) annotations: list of annotations for each cell of the
            table.
        """
        all_font_colors = self.get_table_font_color()
        annotations = []
        for n, row in enumerate(self.table_text):
            for m, val in enumerate(row):
                is_index_cell = self.index and m == 0
                # Bold text in header and index
                format_text = str(val)
                if n == 0 or is_index_cell:
                    format_text = '<b>' + format_text + '</b>'
                # Match font color of index to font color of header
                font_color = (self.font_colors[0] if is_index_cell
                              else all_font_colors[n])
                annotations.append(
                    graph_objs.Annotation(
                        text=format_text,
                        x=self.x[m] - self.annotation_offset,
                        y=self.y[n],
                        xref='x1',
                        yref='y1',
                        align="left",
                        xanchor="left",
                        font=dict(color=font_color),
                        showarrow=False))
        return annotations
