from __future__ import absolute_import

from plotly import colors, exceptions, optional_imports
from plotly.figure_factory import utils
from plotly.tools import make_subplots

import math
from numbers import Number

pd = optional_imports.get_module('pandas')

# Styling defaults for the facet-grid figure factory, grouped by the part
# of the figure they relate to. Colors are hex strings.

# Axes
AXIS_TITLE_COLOR = '#0f0f0f'
AXIS_TITLE_SIZE = 12
TICK_COLOR = '#969696'
MAX_TICKS_PER_AXIS = 5

# Grid, plot background and facet-title boxes
GRID_COLOR = '#ffffff'
GRID_WIDTH = 1
PLOT_BGCOLOR = '#ededed'
ANNOT_RECT_COLOR = '#d0d0d0'

# Legend
LEGEND_COLOR = '#efefef'
LEGEND_BORDER_WIDTH = 1
LEGEND_ANNOT_X = 1.05
LEGEND_ANNOT_Y = 0.5

# Number of facet lanes from which facet titles are laid out "flipped"
# (see _is_flipped).
THRES_FOR_FLIPPED_FACET_TITLES = 10

# Trace types create_facet_grid accepts for its `trace_type` argument.
VALID_TRACE_TYPES = ['scatter', 'scattergl', 'histogram', 'bar', 'box']

# Error template; the single `{}` placeholder receives the required keys.
CUSTOM_LABEL_ERROR = (
    "If you are using a dictionary for custom labels for the facet "
    "row/col, make sure each key in that column of the dataframe is in "
    "your facet labels. The keys you need are {}"
)


def _is_flipped(num):
    """
    Tell whether facet titles should use the flipped layout.

    :param (int) num: number of facet lanes (rows or columns).
    :rtype (bool): True once `num` reaches THRES_FOR_FLIPPED_FACET_TITLES,
        False otherwise.
    """
    return True if num >= THRES_FOR_FLIPPED_FACET_TITLES else False


def _return_label(original_label, facet_labels, facet_var):
    if isinstance(facet_labels, dict):
        label = facet_labels[original_label]
    elif isinstance(facet_labels, str):
        label = '{}: {}'.format(facet_var, original_label)
    else:
        label = original_label
    return label


def _legend_annotation(color_name):
    """
    Build the annotation used as the legend title.

    :param (str) color_name: name of the color column; rendered as
        'factor(<color_name>)'.
    :rtype (dict): an annotation dict positioned in paper coordinates at
        x=LEGEND_ANNOT_X, y=1.03.
    """
    return {
        'textangle': 0,
        'xanchor': 'left',
        'yanchor': 'middle',
        'x': LEGEND_ANNOT_X,
        'y': 1.03,
        'showarrow': False,
        'xref': 'paper',
        'yref': 'paper',
        'text': 'factor({})'.format(color_name),
        'font': {'size': 13, 'color': '#000000'},
    }


def _annotation_dict(text, lane, num_of_lanes, SUBPLOT_SPACING, row_col='col',
                     flipped=True):
    l = (1 - (num_of_lanes - 1) * SUBPLOT_SPACING) / (num_of_lanes)
    if not flipped:
        xanchor = 'center'
        yanchor = 'middle'
        if row_col == 'col':
            x = (lane - 1) * (l + SUBPLOT_SPACING) + 0.5 * l
            y = 1.03
            textangle = 0
        elif row_col == 'row':
            y = (lane - 1) * (l + SUBPLOT_SPACING) + 0.5 * l
            x = 1.03
            textangle = 90
    else:
        if row_col == 'col':
            xanchor = 'center'
            yanchor = 'bottom'
            x = (lane - 1) * (l + SUBPLOT_SPACING) + 0.5 * l
            y = 1.0
            textangle = 270
        elif row_col == 'row':
            xanchor = 'left'
            yanchor = 'middle'
            y = (lane - 1) * (l + SUBPLOT_SPACING) + 0.5 * l
            x = 1.0
            textangle = 0

    annotation_dict = dict(
        textangle=textangle,
        xanchor=xanchor,
        yanchor=yanchor,
        x=x,
        y=y,
        showarrow=False,
        xref='paper',
        yref='paper',
        text=str(text),
        font=dict(
            size=13,
            color=AXIS_TITLE_COLOR
        )
    )
    return annotation_dict


def _axis_title_annotation(text, x_or_y_axis):
    if x_or_y_axis == 'x':
        x_pos = 0.5
        y_pos = -0.1
        textangle = 0
    elif x_or_y_axis == 'y':
        x_pos = -0.1
        y_pos = 0.5
        textangle = 270

    if not text:
        text = ''

    annot = {'font': {'color': '#000000', 'size': AXIS_TITLE_SIZE},
             'showarrow': False,
             'text': text,
             'textangle': textangle,
             'x': x_pos,
             'xanchor': 'center',
             'xref': 'paper',
             'y': y_pos,
             'yanchor': 'middle',
             'yref': 'paper'}
    return annot


def _add_shapes_to_fig(fig, annot_rect_color, flipped_rows=False,
                       flipped_cols=False):
    shapes_list = []
    for key in fig['layout'].to_plotly_json().keys():
        if 'axis' in key and fig['layout'][key]['domain'] != [0.0, 1.0]:
            shape = {
               'fillcolor': annot_rect_color,
               'layer': 'below',
               'line': {'color': annot_rect_color, 'width': 1},
               'type': 'rect',
               'xref': 'paper',
               'yref': 'paper'
            }

            if 'xaxis' in key:
                shape['x0'] = fig['layout'][key]['domain'][0]
                shape['x1'] = fig['layout'][key]['domain'][1]
                shape['y0'] = 1.005
                shape['y1'] = 1.05

                if flipped_cols:
                    shape['y1'] += 0.5
                shapes_list.append(shape)

            elif 'yaxis' in key:
                shape['x0'] = 1.005
                shape['x1'] = 1.05
                shape['y0'] = fig['layout'][key]['domain'][0]
                shape['y1'] = fig['layout'][key]['domain'][1]

                if flipped_rows:
                    shape['x1'] += 1
                shapes_list.append(shape)

    fig['layout']['shapes'] = shapes_list


def _make_trace_for_scatter(trace, trace_type, color, **kwargs_marker):
    if trace_type in ['scatter', 'scattergl']:
        trace['mode'] = 'markers'
        trace['marker'] = dict(color=color, **kwargs_marker)
    return trace


def _facet_grid_color_categorical(df, x, y, facet_row, facet_col, color_name,
                                  colormap, num_of_rows, num_of_cols,
                                  facet_row_labels, facet_col_labels,
                                  trace_type, flipped_rows, flipped_cols,
                                  show_boxes, SUBPLOT_SPACING, marker_color,
                                  kwargs_trace, kwargs_marker):
    """
    Build the subplot figure and facet-title annotations when the color
    column is treated categorically: one trace per color value per facet.

    :param (pd.DataFrame) df: the data.
    :param (str) x: column used for x data, or None.
    :param (str) y: column used for y data, or None.
    :param (str) facet_row: column that splits the grid into rows, or None.
    :param (str) facet_col: column that splits the grid into columns, or
        None.
    :param (str) color_name: column whose values select the trace color.
    :param (dict) colormap: maps each value of `color_name` to a color.
    :param (int) num_of_rows: number of subplot rows.
    :param (int) num_of_cols: number of subplot columns.
    :param facet_row_labels: label option for rows (see _return_label).
    :param facet_col_labels: label option for columns (see _return_label).
    :param (str) trace_type: plotly trace type for every trace.
    :param (bool) flipped_rows: flipped layout for row titles.
    :param (bool) flipped_cols: flipped layout for column titles.
    :param show_boxes: not used by this function.
    :param (float) SUBPLOT_SPACING: spacing between subplots.
    :param marker_color: not used by this function.
    :param (dict) kwargs_trace: extra properties applied to every trace.
    :param (dict) kwargs_marker: marker properties for scatter traces.
    :return: (fig, annotations) - the figure returned by make_subplots with
        all traces appended, and the list of facet-title annotation dicts.
    """

    fig = make_subplots(rows=num_of_rows, cols=num_of_cols,
                        shared_xaxes=True, shared_yaxes=True,
                        horizontal_spacing=SUBPLOT_SPACING,
                        vertical_spacing=SUBPLOT_SPACING, print_grid=False)

    annotations = []
    if not facet_row and not facet_col:
        # No faceting: a single subplot with one trace per color group.
        for color_val, color_df in df.groupby(color_name):
            trace = dict(
                type=trace_type,
                name=color_val,
                marker=dict(
                    color=colormap[color_val],
                ),
                **kwargs_trace
            )
            if x:
                trace['x'] = color_df[x]
            if y:
                trace['y'] = color_df[y]
            trace = _make_trace_for_scatter(
                trace, trace_type, colormap[color_val], **kwargs_marker
            )

            fig.append_trace(trace, 1, 1)

    elif (facet_row and not facet_col) or (not facet_row and facet_col):
        # One-way faceting along either rows or columns.
        facet_var = facet_row if facet_row else facet_col
        facet_labels = facet_row_labels if facet_row else facet_col_labels
        color_vals = df[color_name].unique()

        for j, (facet_val, facet_df) in enumerate(df.groupby(facet_var)):
            for color_val in color_vals:
                data_by_color = facet_df[facet_df[color_name] == color_val]
                trace = dict(
                    type=trace_type,
                    name=color_val,
                    marker=dict(
                        color=colormap[color_val],
                    ),
                    **kwargs_trace
                )
                if x:
                    trace['x'] = data_by_color[x]
                if y:
                    trace['y'] = data_by_color[y]
                trace = _make_trace_for_scatter(
                    trace, trace_type, colormap[color_val], **kwargs_marker
                )

                fig.append_trace(trace,
                                 j + 1 if facet_row else 1,
                                 1 if facet_row else j + 1)

            label = _return_label(facet_val, facet_labels, facet_var)

            annotations.append(
                _annotation_dict(
                    label,
                    # Subplot rows are appended top-down while paper y grows
                    # upwards, hence the reversed lane index for rows.
                    num_of_rows - j if facet_row else j + 1,
                    num_of_rows if facet_row else num_of_cols,
                    SUBPLOT_SPACING,
                    'row' if facet_row else 'col',
                    # Use the flag matching the faceted direction, as the
                    # two-way branch below does.
                    flipped_rows if facet_row else flipped_cols)
            )

    elif facet_row and facet_col:
        # Two-way faceting: one subplot per (row value, column value) pair.
        tuple_to_facet_group = {
            key: facet_df
            for key, facet_df in df.groupby([facet_row, facet_col])
        }

        row_values = df[facet_row].unique()
        col_values = df[facet_col].unique()
        color_vals = df[color_name].unique()
        for row_count, x_val in enumerate(row_values):
            for col_count, y_val in enumerate(col_values):
                try:
                    group = tuple_to_facet_group[(x_val, y_val)]
                except KeyError:
                    # No data for this combination: use a one-row
                    # placeholder so a trace is still appended to the cell.
                    group = pd.DataFrame([[None, None, None]],
                                         columns=[x, y, color_name])
                    has_data = False
                else:
                    has_data = True

                for color_val in color_vals:
                    trace = dict(
                        type=trace_type,
                        name=color_val,
                        marker=dict(
                            color=colormap[color_val],
                        ),
                        **kwargs_trace
                    )
                    if has_data:
                        data = group[group[color_name] == color_val]
                    else:
                        # Placeholder traces must not show up in the legend.
                        trace['showlegend'] = False
                        data = group

                    # Index only the columns that were actually requested;
                    # x or y may be None for e.g. histograms.
                    if x:
                        trace['x'] = data[x]
                    if y:
                        trace['y'] = data[y]
                    trace = _make_trace_for_scatter(
                        trace, trace_type, colormap[color_val],
                        **kwargs_marker
                    )

                    fig.append_trace(trace, row_count + 1, col_count + 1)
                if row_count == 0:
                    # Column titles are added once, while on the first row.
                    label = _return_label(col_values[col_count],
                                          facet_col_labels, facet_col)
                    annotations.append(
                        _annotation_dict(label, col_count + 1, num_of_cols,
                                         SUBPLOT_SPACING,
                                         row_col='col', flipped=flipped_cols)
                        )
            label = _return_label(row_values[row_count],
                                  facet_row_labels, facet_row)
            annotations.append(
                _annotation_dict(label, num_of_rows - row_count, num_of_rows,
                                 SUBPLOT_SPACING,
                                 row_col='row', flipped=flipped_rows)
            )

    return fig, annotations


def _facet_grid_color_numerical(df, x, y, facet_row, facet_col, color_name,
                                colormap, num_of_rows,
                                num_of_cols, facet_row_labels,
                                facet_col_labels, trace_type,
                                flipped_rows, flipped_cols, show_boxes,
                                SUBPLOT_SPACING, marker_color, kwargs_trace,
                                kwargs_marker):
    """
    Build the subplot figure and facet-title annotations when the color
    column is treated numerically: one trace per facet, colored through a
    colorscale.

    :param (pd.DataFrame) df: the data.
    :param (str) x: column used for x data, or None.
    :param (str) y: column used for y data, or None.
    :param (str) facet_row: column that splits the grid into rows, or None.
    :param (str) facet_col: column that splits the grid into columns, or
        None.
    :param (str) color_name: column providing the marker color values.
    :param colormap: value passed as the marker `colorscale`.
    :param (int) num_of_rows: number of subplot rows.
    :param (int) num_of_cols: number of subplot columns.
    :param facet_row_labels: label option for rows (see _return_label).
    :param facet_col_labels: label option for columns (see _return_label).
    :param (str) trace_type: plotly trace type for every trace.
    :param (bool) flipped_rows: flipped layout for row titles.
    :param (bool) flipped_cols: flipped layout for column titles.
    :param show_boxes: not used by this function.
    :param (float) SUBPLOT_SPACING: spacing between subplots.
    :param marker_color: not used by this function.
    :param (dict) kwargs_trace: extra properties applied to every trace.
    :param (dict) kwargs_marker: marker properties for scatter traces.
    :return: (fig, annotations) - the figure returned by make_subplots with
        all traces appended, and the list of facet-title annotation dicts.
    """
    # NOTE(review): every trace below receives the full df[color_name]
    # column as marker color even when its x/y come from a single facet
    # group, and for scatter types _make_trace_for_scatter replaces the
    # marker dict built here (including colorscale/showscale/colorbar).
    # Left unchanged; confirm the intended behavior before altering it.

    fig = make_subplots(rows=num_of_rows, cols=num_of_cols,
                        shared_xaxes=True, shared_yaxes=True,
                        horizontal_spacing=SUBPLOT_SPACING,
                        vertical_spacing=SUBPLOT_SPACING, print_grid=False)

    annotations = []
    if not facet_row and not facet_col:
        # No faceting: a single trace holding all the data.
        trace = dict(
            type=trace_type,
            marker=dict(
                color=df[color_name],
                colorscale=colormap,
                showscale=True,
            ),
            **kwargs_trace
        )
        if x:
            trace['x'] = df[x]
        if y:
            trace['y'] = df[y]
        trace = _make_trace_for_scatter(
            trace, trace_type, df[color_name], **kwargs_marker
        )

        fig.append_trace(trace, 1, 1)

    elif (facet_row and not facet_col) or (not facet_row and facet_col):
        # One-way faceting along either rows or columns.
        facet_var = facet_row if facet_row else facet_col
        labels = facet_row_labels if facet_row else facet_col_labels

        for j, (facet_val, facet_df) in enumerate(df.groupby(facet_var)):
            trace = dict(
                type=trace_type,
                marker=dict(
                    color=df[color_name],
                    colorscale=colormap,
                    showscale=True,
                    colorbar=dict(x=1.15),
                ),
                **kwargs_trace
            )
            if x:
                trace['x'] = facet_df[x]
            if y:
                trace['y'] = facet_df[y]
            trace = _make_trace_for_scatter(
                trace, trace_type, df[color_name], **kwargs_marker
            )

            fig.append_trace(
                trace,
                j + 1 if facet_row else 1,
                1 if facet_row else j + 1
            )

            label = _return_label(facet_val, labels, facet_var)

            annotations.append(
                _annotation_dict(
                    label,
                    # Subplot rows are appended top-down while paper y grows
                    # upwards, hence the reversed lane index for rows.
                    num_of_rows - j if facet_row else j + 1,
                    num_of_rows if facet_row else num_of_cols,
                    SUBPLOT_SPACING,
                    'row' if facet_row else 'col',
                    # Use the flag matching the faceted direction, as the
                    # two-way branch below does.
                    flipped=flipped_rows if facet_row else flipped_cols)
            )

    elif facet_row and facet_col:
        # Two-way faceting: one subplot per (row value, column value) pair.
        tuple_to_facet_group = {
            key: facet_df
            for key, facet_df in df.groupby([facet_row, facet_col])
        }

        row_values = df[facet_row].unique()
        col_values = df[facet_col].unique()
        for row_count, x_val in enumerate(row_values):
            for col_count, y_val in enumerate(col_values):
                try:
                    group = tuple_to_facet_group[(x_val, y_val)]
                except KeyError:
                    # No data for this combination: use a one-row
                    # placeholder so a trace is still appended to the cell.
                    group = pd.DataFrame([[None, None, None]],
                                         columns=[x, y, color_name])
                    has_data = False
                else:
                    has_data = True

                if has_data:
                    trace = dict(
                        type=trace_type,
                        marker=dict(
                            color=df[color_name],
                            colorscale=colormap,
                            # Only first-row traces request a colorbar.
                            showscale=(row_count == 0),
                            colorbar=dict(x=1.15),
                        ),
                        **kwargs_trace
                    )

                else:
                    trace = dict(
                        type=trace_type,
                        showlegend=False,
                        **kwargs_trace
                    )

                if x:
                    trace['x'] = group[x]
                if y:
                    trace['y'] = group[y]
                trace = _make_trace_for_scatter(
                    trace, trace_type, df[color_name], **kwargs_marker
                )

                fig.append_trace(trace, row_count + 1, col_count + 1)
                if row_count == 0:
                    # Column titles are added once, while on the first row.
                    label = _return_label(col_values[col_count],
                                          facet_col_labels, facet_col)
                    annotations.append(
                        _annotation_dict(label, col_count + 1, num_of_cols,
                                         SUBPLOT_SPACING,
                                         row_col='col', flipped=flipped_cols)
                        )
            label = _return_label(row_values[row_count],
                                  facet_row_labels, facet_row)
            annotations.append(
                # Pass the resolved label (not the raw row value) so that
                # custom facet_row_labels are honoured.
                _annotation_dict(label,
                                 num_of_rows - row_count, num_of_rows,
                                 SUBPLOT_SPACING,
                                 row_col='row', flipped=flipped_rows)
            )

    return fig, annotations


def _facet_grid(df, x, y, facet_row, facet_col, num_of_rows,
                num_of_cols, facet_row_labels, facet_col_labels,
                trace_type, flipped_rows, flipped_cols, show_boxes,
                SUBPLOT_SPACING, marker_color, kwargs_trace, kwargs_marker):
    """
    Build the subplot figure and facet-title annotations when no color
    column is used: one trace per facet, all drawn in `marker_color`.

    :param (pd.DataFrame) df: the data.
    :param (str) x: column used for x data, or None.
    :param (str) y: column used for y data, or None.
    :param (str) facet_row: column that splits the grid into rows, or None.
    :param (str) facet_col: column that splits the grid into columns, or
        None.
    :param (int) num_of_rows: number of subplot rows.
    :param (int) num_of_cols: number of subplot columns.
    :param facet_row_labels: label option for rows (see _return_label).
    :param facet_col_labels: label option for columns (see _return_label).
    :param (str) trace_type: plotly trace type for every trace.
    :param (bool) flipped_rows: flipped layout for row titles.
    :param (bool) flipped_cols: flipped layout for column titles.
    :param show_boxes: not used by this function.
    :param (float) SUBPLOT_SPACING: spacing between subplots.
    :param marker_color: color applied to every trace's marker.
    :param (dict) kwargs_trace: extra properties applied to every trace.
    :param (dict) kwargs_marker: marker properties; must contain a 'line'
        entry.
    :return: (fig, annotations) - the figure returned by make_subplots with
        all traces appended, and the list of facet-title annotation dicts.
    """

    fig = make_subplots(rows=num_of_rows, cols=num_of_cols,
                        shared_xaxes=True, shared_yaxes=True,
                        horizontal_spacing=SUBPLOT_SPACING,
                        vertical_spacing=SUBPLOT_SPACING, print_grid=False)
    annotations = []
    if not facet_row and not facet_col:
        # No faceting: a single trace holding all the data.
        trace = dict(
            type=trace_type,
            marker=dict(
                color=marker_color,
                line=kwargs_marker['line'],
            ),
            **kwargs_trace
        )

        if x:
            trace['x'] = df[x]
        if y:
            trace['y'] = df[y]
        trace = _make_trace_for_scatter(
            trace, trace_type, marker_color, **kwargs_marker
        )

        fig.append_trace(trace, 1, 1)

    elif (facet_row and not facet_col) or (not facet_row and facet_col):
        # One-way faceting along either rows or columns.
        facet_var = facet_row if facet_row else facet_col
        facet_labels = facet_row_labels if facet_row else facet_col_labels

        for j, (facet_val, facet_df) in enumerate(df.groupby(facet_var)):
            trace = dict(
                type=trace_type,
                marker=dict(
                    color=marker_color,
                    line=kwargs_marker['line'],
                ),
                **kwargs_trace
            )

            if x:
                trace['x'] = facet_df[x]
            if y:
                trace['y'] = facet_df[y]
            trace = _make_trace_for_scatter(
                trace, trace_type, marker_color, **kwargs_marker
            )

            fig.append_trace(trace,
                             j + 1 if facet_row else 1,
                             1 if facet_row else j + 1)

            label = _return_label(facet_val, facet_labels, facet_var)

            annotations.append(
                _annotation_dict(
                    label,
                    # Subplot rows are appended top-down while paper y grows
                    # upwards, hence the reversed lane index for rows.
                    num_of_rows - j if facet_row else j + 1,
                    num_of_rows if facet_row else num_of_cols,
                    SUBPLOT_SPACING,
                    'row' if facet_row else 'col',
                    # Use the flag matching the faceted direction, as the
                    # two-way branch below does.
                    flipped_rows if facet_row else flipped_cols
                )
            )

    elif facet_row and facet_col:
        # Two-way faceting: one subplot per (row value, column value) pair.
        tuple_to_facet_group = {
            key: facet_df
            for key, facet_df in df.groupby([facet_row, facet_col])
        }

        row_values = df[facet_row].unique()
        col_values = df[facet_col].unique()
        for row_count, x_val in enumerate(row_values):
            for col_count, y_val in enumerate(col_values):
                try:
                    group = tuple_to_facet_group[(x_val, y_val)]
                except KeyError:
                    # No data for this combination: use a one-row
                    # placeholder so a trace is still appended to the cell.
                    group = pd.DataFrame([[None, None]], columns=[x, y])
                trace = dict(
                    type=trace_type,
                    marker=dict(
                        color=marker_color,
                        line=kwargs_marker['line'],
                    ),
                    **kwargs_trace
                )
                if x:
                    trace['x'] = group[x]
                if y:
                    trace['y'] = group[y]
                trace = _make_trace_for_scatter(
                    trace, trace_type, marker_color, **kwargs_marker
                )

                fig.append_trace(trace, row_count + 1, col_count + 1)
                if row_count == 0:
                    # Column titles are added once, while on the first row.
                    label = _return_label(col_values[col_count],
                                          facet_col_labels,
                                          facet_col)
                    annotations.append(
                        _annotation_dict(label, col_count + 1, num_of_cols,
                                         SUBPLOT_SPACING,
                                         row_col='col', flipped=flipped_cols)
                        )

            label = _return_label(row_values[row_count],
                                  facet_row_labels,
                                  facet_row)
            annotations.append(
                _annotation_dict(label, num_of_rows - row_count, num_of_rows,
                                 SUBPLOT_SPACING,
                                 row_col='row', flipped=flipped_rows)
            )

    return fig, annotations


def create_facet_grid(df, x=None, y=None, facet_row=None, facet_col=None,
                      color_name=None, colormap=None, color_is_cat=False,
                      facet_row_labels=None, facet_col_labels=None,
                      height=None, width=None, trace_type='scatter',
                      scales='fixed', dtick_x=None, dtick_y=None,
                      show_boxes=True, ggplot2=False, binsize=1, **kwargs):
    """
    Returns figure for facet grid.

    :param (pd.DataFrame) df: the dataframe of columns for the facet grid.
    :param (str) x: the name of the dataframe column for the x axis data.
    :param (str) y: the name of the dataframe column for the y axis data.
    :param (str) facet_row: the name of the dataframe column that is used to
        facet the grid into row panels.
    :param (str) facet_col: the name of the dataframe column that is used to
        facet the grid into column panels.
    :param (str) color_name: the name of your dataframe column that will
        function as the colormap variable.
    :param (str|list|dict) colormap: the param that determines how the
        color_name column colors the data. If the dataframe contains numeric
        data, then a dictionary of colors will group the data categorically
        while a Plotly Colorscale name or a custom colorscale will treat it
        numerically. To learn more about colors and types of colormap, run
        `help(plotly.colors)`.
    :param (bool) color_is_cat: determines whether a numerical column for the
        colormap will be treated as categorical (True) or sequential (False).
            Default = False.
    :param (str|dict) facet_row_labels: set to either 'name' or a dictionary
        of all the unique values in the faceting row mapped to some text to
        show up in the label annotations. If None, labeling works like usual.
    :param (str|dict) facet_col_labels: set to either 'name' or a dictionary
        of all the unique values in the faceting column mapped to some text
        to show up in the label annotations. If None, labeling works like
        usual.
    :param (int) height: the height of the facet grid figure.
    :param (int) width: the width of the facet grid figure.
    :param (str) trace_type: decides the type of plot to appear in the
        facet grid. The options are 'scatter', 'scattergl', 'histogram',
        'bar', and 'box'.
        Default = 'scatter'.
    :param (str) scales: determines if axes have fixed ranges or not. Valid
        settings are 'fixed' (all axes fixed), 'free_x' (x axis free only),
        'free_y' (y axis free only) or 'free' (both axes free).
    :param (float) dtick_x: determines the distance between each tick on the
        x-axis. Default is None which means dtick_x is set automatically.
    :param (float) dtick_y: determines the distance between each tick on the
        y-axis. Default is None which means dtick_y is set automatically.
    :param (bool) show_boxes: draws grey boxes behind the facet titles.
    :param (bool) ggplot2: draws the facet grid in the style of `ggplot2`. See
        http://ggplot2.tidyverse.org/reference/facet_grid.html for reference.
        Default = False
    :param (int) binsize: groups all data into bins of a given length.
    :param (dict) kwargs: a dictionary of scatterplot arguments.

    Example 1: One Way Faceting
    ```
    import plotly.plotly as py
    import plotly.figure_factory as ff

    import pandas as pd

    mpg = pd.read_table('https://raw.githubusercontent.com/plotly/datasets/master/mpg_2017.txt')

    fig = ff.create_facet_grid(
        mpg,
        x='displ',
        y='cty',
        facet_col='cyl',
    )
    py.iplot(fig, filename='facet_grid_mpg_one_way_facet')
    ```

    Example 2: Two Way Faceting
    ```
    import plotly.plotly as py
    import plotly.figure_factory as ff

    import pandas as pd

    mpg = pd.read_table('https://raw.githubusercontent.com/plotly/datasets/master/mpg_2017.txt')

    fig = ff.create_facet_grid(
        mpg,
        x='displ',
        y='cty',
        facet_row='drv',
        facet_col='cyl',
    )
    py.iplot(fig, filename='facet_grid_mpg_two_way_facet')
    ```

    Example 3: Categorical Coloring
    ```
    import plotly.plotly as py
    import plotly.figure_factory as ff

    import pandas as pd

    mtcars = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/mtcars.csv')

    fig = ff.create_facet_grid(
        mtcars,
        x='mpg',
        y='wt',
        facet_col='cyl',
        color_name='cyl',
        color_is_cat=True,
    )
    py.iplot(fig, filename='facet_grid_mpg_default_colors')
    ```

    Example 4: Sequential Coloring
    ```
    import plotly.plotly as py
    import plotly.figure_factory as ff

    import pandas as pd

    tips = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/tips.csv')

    fig = ff.create_facet_grid(
        tips,
        x='total_bill',
        y='tip',
        facet_row='sex',
        facet_col='smoker',
        color_name='size',
        colormap='Viridis',
    )
    py.iplot(fig, filename='facet_grid_tips_sequential_colors')
    ```

    Example 5: Custom labels
    ```
    import plotly.plotly as py
    import plotly.figure_factory as ff

    import pandas as pd

    mtcars = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/mtcars.csv')

    fig = ff.create_facet_grid(
        mtcars,
        x='wt',
        y='mpg',
        facet_col='cyl',
        facet_col_labels={4: "$\\alpha$", 6: '$\\beta$', 8: '$\sqrt[y]{x}$'},
    )

    py.iplot(fig, filename='facet_grid_mtcars_custom_labels')
    ```

    Example 6: Other Trace Type
    ```
    import plotly.plotly as py
    import plotly.figure_factory as ff

    import pandas as pd

    mtcars = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/mtcars.csv')

    fig = ff.create_facet_grid(
        mtcars,
        x='wt',
        facet_col='cyl',
        trace_type='histogram',
    )

    py.iplot(fig, filename='facet_grid_mtcars_other_trace_type')
    ```
    """
    if not pd:
        raise exceptions.ImportError(
            "'pandas' must be installed for this figure_factory."
        )

    if not isinstance(df, pd.DataFrame):
        raise exceptions.PlotlyError(
            "You must input a pandas DataFrame."
        )

    # make sure all columns are of homogenous datatype
    utils.validate_dataframe(df)

    if trace_type in ['scatter', 'scattergl']:
        if not x or not y:
            raise exceptions.PlotlyError(
                "You need to input 'x' and 'y' if you are you are using a "
                "trace_type of 'scatter' or 'scattergl'."
            )

    for key in [x, y, facet_row, facet_col, color_name]:
        if key is not None:
            try:
                df[key]
            except KeyError:
                raise exceptions.PlotlyError(
                    "x, y, facet_row, facet_col and color_name must be keys "
                    "in your dataframe."
                )
    # autoscale histogram bars
    if trace_type not in ['scatter', 'scattergl']:
        scales = 'free'

    # validate scales
    if scales not in ['fixed', 'free_x', 'free_y', 'free']:
        raise exceptions.PlotlyError(
            "'scales' must be set to 'fixed', 'free_x', 'free_y' and 'free'."
        )

    if trace_type not in VALID_TRACE_TYPES:
        raise exceptions.PlotlyError(
            "'trace_type' must be in {}".format(VALID_TRACE_TYPES)
        )

    if trace_type == 'histogram':
        SUBPLOT_SPACING = 0.06
    else:
        SUBPLOT_SPACING = 0.015

    # separate the marker kwargs from the remaining trace kwargs
    if 'marker' in kwargs:
        kwargs_marker = kwargs['marker']
    else:
        kwargs_marker = {}
    marker_color = kwargs_marker.pop('color', None)
    kwargs.pop('marker', None)
    kwargs_trace = kwargs

    if 'size' not in kwargs_marker:
        if ggplot2:
            kwargs_marker['size'] = 5
        else:
            kwargs_marker['size'] = 8

    if 'opacity' not in kwargs_marker:
        if not ggplot2:
            kwargs_trace['opacity'] = 0.6

    if 'line' not in kwargs_marker:
        if not ggplot2:
            kwargs_marker['line'] = {'color': 'darkgrey', 'width': 1}
        else:
            kwargs_marker['line'] = {}

    # default marker color
    if not ggplot2:
        if not marker_color:
            marker_color = 'rgb(31, 119, 180)'
    else:
        marker_color = 'rgb(0, 0, 0)'

    num_of_rows = 1
    num_of_cols = 1
    flipped_rows = False
    flipped_cols = False
    if facet_row:
        num_of_rows = len(df[facet_row].unique())
        flipped_rows = _is_flipped(num_of_rows)
        if isinstance(facet_row_labels, dict):
            for key in df[facet_row].unique():
                if key not in facet_row_labels.keys():
                    unique_keys = df[facet_row].unique().tolist()
                    raise exceptions.PlotlyError(
                        CUSTOM_LABEL_ERROR.format(unique_keys)
                    )
    if facet_col:
        num_of_cols = len(df[facet_col].unique())
        flipped_cols = _is_flipped(num_of_cols)
        if isinstance(facet_col_labels, dict):
            for key in df[facet_col].unique():
                if key not in facet_col_labels.keys():
                    unique_keys = df[facet_col].unique().tolist()
                    raise exceptions.PlotlyError(
                        CUSTOM_LABEL_ERROR.format(unique_keys)
                    )
    show_legend = False
    if color_name:
        if isinstance(df[color_name].iloc[0], str) or color_is_cat:
            show_legend = True
            if isinstance(colormap, dict):
                utils.validate_colors_dict(colormap, 'rgb')

                for val in df[color_name].unique():
                    if val not in colormap.keys():
                        raise exceptions.PlotlyError(
                            "If using 'colormap' as a dictionary, make sure "
                            "all the values of the colormap column are in "
                            "the keys of your dictionary."
                        )
            else:
                # use default plotly colors for dictionary
                default_colors = utils.DEFAULT_PLOTLY_COLORS
                colormap = {}
                j = 0
                for val in df[color_name].unique():
                    if j >= len(default_colors):
                        j = 0
                    colormap[val] = default_colors[j]
                    j += 1
            fig, annotations = _facet_grid_color_categorical(
                df, x, y, facet_row, facet_col, color_name, colormap,
                num_of_rows, num_of_cols, facet_row_labels, facet_col_labels,
                trace_type, flipped_rows, flipped_cols, show_boxes,
                SUBPLOT_SPACING, marker_color, kwargs_trace, kwargs_marker
            )

        elif isinstance(df[color_name].iloc[0], Number):
            if isinstance(colormap, dict):
                show_legend = True
                utils.validate_colors_dict(colormap, 'rgb')

                for val in df[color_name].unique():
                    if val not in colormap.keys():
                        raise exceptions.PlotlyError(
                            "If using 'colormap' as a dictionary, make sure "
                            "all the values of the colormap column are in "
                            "the keys of your dictionary."
                        )
                fig, annotations = _facet_grid_color_categorical(
                    df, x, y, facet_row, facet_col, color_name, colormap,
                    num_of_rows, num_of_cols, facet_row_labels,
                    facet_col_labels, trace_type, flipped_rows,
                    flipped_cols, show_boxes, SUBPLOT_SPACING, marker_color,
                    kwargs_trace, kwargs_marker
                )

            elif isinstance(colormap, list):
                colorscale_list = colormap
                utils.validate_colorscale(colorscale_list)

                fig, annotations = _facet_grid_color_numerical(
                    df, x, y, facet_row, facet_col, color_name,
                    colorscale_list, num_of_rows, num_of_cols,
                    facet_row_labels, facet_col_labels, trace_type,
                    flipped_rows, flipped_cols, show_boxes, SUBPLOT_SPACING,
                    marker_color, kwargs_trace, kwargs_marker
                )
            elif isinstance(colormap, str):
                if colormap in colors.PLOTLY_SCALES.keys():
                    colorscale_list = colors.PLOTLY_SCALES[colormap]
                else:
                    raise exceptions.PlotlyError(
                        "If 'colormap' is a string, it must be the name "
                        "of a Plotly Colorscale. The available colorscale "
                        "names are {}".format(colors.PLOTLY_SCALES.keys())
                    )
                fig, annotations = _facet_grid_color_numerical(
                    df, x, y, facet_row, facet_col, color_name,
                    colorscale_list, num_of_rows, num_of_cols,
                    facet_row_labels, facet_col_labels, trace_type,
                    flipped_rows, flipped_cols, show_boxes, SUBPLOT_SPACING,
                    marker_color, kwargs_trace, kwargs_marker
                )
            else:
                colorscale_list = colors.PLOTLY_SCALES['Reds']
                fig, annotations = _facet_grid_color_numerical(
                    df, x, y, facet_row, facet_col, color_name,
                    colorscale_list, num_of_rows, num_of_cols,
                    facet_row_labels, facet_col_labels, trace_type,
                    flipped_rows, flipped_cols, show_boxes, SUBPLOT_SPACING,
                    marker_color, kwargs_trace, kwargs_marker
                )

    else:
        fig, annotations = _facet_grid(
            df, x, y, facet_row, facet_col, num_of_rows, num_of_cols,
            facet_row_labels, facet_col_labels, trace_type, flipped_rows,
            flipped_cols, show_boxes, SUBPLOT_SPACING, marker_color,
            kwargs_trace, kwargs_marker
        )

    if not height:
        height = max(600, 100 * num_of_rows)
    if not width:
        width = max(600, 100 * num_of_cols)

    fig['layout'].update(height=height, width=width, title='',
                         paper_bgcolor='rgb(251, 251, 251)')
    if ggplot2:
        fig['layout'].update(plot_bgcolor=PLOT_BGCOLOR,
                             paper_bgcolor='rgb(255, 255, 255)',
                             hovermode='closest')

    # axis titles
    x_title_annot = _axis_title_annotation(x, 'x')
    y_title_annot = _axis_title_annotation(y, 'y')

    # annotations
    annotations.append(x_title_annot)
    annotations.append(y_title_annot)

    # legend
    fig['layout']['showlegend'] = show_legend
    fig['layout']['legend']['bgcolor'] = LEGEND_COLOR
    fig['layout']['legend']['borderwidth'] = LEGEND_BORDER_WIDTH
    fig['layout']['legend']['x'] = 1.05
    fig['layout']['legend']['y'] = 1
    fig['layout']['legend']['yanchor'] = 'top'

    if show_legend:
        fig['layout']['showlegend'] = show_legend
        if ggplot2:
            if color_name:
                legend_annot = _legend_annotation(color_name)
                annotations.append(legend_annot)
            fig['layout']['margin']['r'] = 150

    # assign annotations to figure
    fig['layout']['annotations'] = annotations

    # add shaded boxes behind axis titles
    if show_boxes and ggplot2:
        _add_shapes_to_fig(fig, ANNOT_RECT_COLOR, flipped_rows, flipped_cols)

    # all xaxis and yaxis labels
    axis_labels = {'x': [], 'y': []}
    for key in fig['layout']:
        if 'xaxis' in key:
            axis_labels['x'].append(key)
        elif 'yaxis' in key:
            axis_labels['y'].append(key)

    string_number_in_data = False
    for var in [v for v in [x, y] if v]:
        if isinstance(df[var].tolist()[0], str):
            for item in df[var]:
                try:
                    int(item)
                    string_number_in_data = True
                except ValueError:
                    pass

    if string_number_in_data:
        for x_y in axis_labels.keys():
            for axis_name in axis_labels[x_y]:
                fig['layout'][axis_name]['type'] = 'category'

    if scales == 'fixed':
        fixed_axes = ['x', 'y']
    elif scales == 'free_x':
        fixed_axes = ['y']
    elif scales == 'free_y':
        fixed_axes = ['x']
    elif scales == 'free':
        fixed_axes = []

    # fixed ranges
    for x_y in fixed_axes:
        min_ranges = []
        max_ranges = []
        for trace in fig['data']:
            if trace[x_y] is not None and len(trace[x_y]) > 0:
                min_ranges.append(min(trace[x_y]))
                max_ranges.append(max(trace[x_y]))
        while None in min_ranges:
            min_ranges.remove(None)
        while None in max_ranges:
            max_ranges.remove(None)

        min_range = min(min_ranges)
        max_range = max(max_ranges)

        range_are_numbers = (isinstance(min_range, Number) and
                             isinstance(max_range, Number))

        if range_are_numbers:
            min_range = math.floor(min_range)
            max_range = math.ceil(max_range)

            # widen the range by roughly 5% on each side; the upper pad is
            # computed after the minimum was lowered, so it is slightly larger
            min_range -= 0.05 * (max_range - min_range)
            max_range += 0.05 * (max_range - min_range)

            if x_y == 'x':
                if dtick_x:
                    dtick = dtick_x
                else:
                    dtick = math.floor(
                        (max_range - min_range) / MAX_TICKS_PER_AXIS
                    )
            elif x_y == 'y':
                if dtick_y:
                    dtick = dtick_y
                else:
                    dtick = math.floor(
                        (max_range - min_range) / MAX_TICKS_PER_AXIS
                    )
        else:
            dtick = 1

        for axis_title in axis_labels[x_y]:
            fig['layout'][axis_title]['dtick'] = dtick
            fig['layout'][axis_title]['ticklen'] = 0
            fig['layout'][axis_title]['zeroline'] = False
            if ggplot2:
                fig['layout'][axis_title]['tickwidth'] = 1
                fig['layout'][axis_title]['ticklen'] = 4
                fig['layout'][axis_title]['gridwidth'] = GRID_WIDTH

                fig['layout'][axis_title]['gridcolor'] = GRID_COLOR
                fig['layout'][axis_title]['gridwidth'] = 2
                fig['layout'][axis_title]['tickfont'] = {
                    'color': TICK_COLOR, 'size': 10
                }

        # insert ranges into fig
        if x_y in fixed_axes:
            for key in fig['layout']:
                if '{}axis'.format(x_y) in key and range_are_numbers:
                    fig['layout'][key]['range'] = [min_range, max_range]

    return fig
