import {
  scaleLinear,
  scaleTime,
  scaleBand,
  scaleOrdinal,
  scaleSequential,
} from 'd3-scale'
import { extent } from 'd3-array'
import type { XYScaleProps, VisualEncodingProps } from '../atoms/scales/types'
import type {
  DataValue,
  XYScaleTypes,
  VisualEncodingTypes,
  GGProps,
} from '../gg'
import {
  defaultScheme,
  defaultInterpolator,
  defaultDasharrays,
  createSequentialScheme,
} from './scaleDefaults'
import { defineGroupAccessor } from './defineGroupAccessor'
import { isDate } from './dates'

/**
 * The set of scales and grouping information produced by `autoScale`.
 */
export interface IScale<Datum> {
  /** Scale for the horizontal position. */
  xScale: XYScaleTypes
  /** Scale for the vertical position. */
  yScale: XYScaleTypes
  /** Fill scale; only created by `autoScale` when a fill accessor is available. */
  fillScale?: VisualEncodingTypes
  /** Stroke scale; only created by `autoScale` when a stroke accessor is available. */
  strokeScale?: VisualEncodingTypes
  /** Dash-pattern scale; only created by `autoScale` when a strokeDasharray accessor is available. */
  strokeDasharrayScale?: VisualEncodingTypes
  /** Accessor used to split the data into groups, or `undefined` when none could be determined. */
  groupAccessor: DataValue<Datum> | undefined
  /** Distinct group values; `autoScale` leaves this `undefined` when no categorical variable is in play. */
  groups?: string[]
}

/**
 * Input accepted by `autoScale`: the chart-level props (`GGProps`) plus the
 * scale-related state collected from the chart and its geoms.
 */
export interface AutoScale<Datum> extends GGProps<Datum> {
  scalesState: {
    /** x scale configuration; `autoScale` reads its `domain`, `type` and `reverse`. */
    x: XYScaleProps
    /** y scale configuration; `autoScale` reads its `domain`, `type` and `reverse`. */
    y: XYScaleProps
    /** When true and no x domain is configured, a numeric x domain starts at 0 rather than at the data minimum. */
    hasZeroXBaseLine: boolean
    /** When true and no y domain is configured, a numeric y domain starts at 0 rather than at the data minimum. */
    hasZeroYBaseLine: boolean
    /** Group accessors from geoms; the first entry is the fallback when `aes` defines no fill/stroke/strokeDasharray/group. */
    geomGroupAccessors: DataValue<Datum>[]
    /** Fallback y accessor, consulted after `y1Aes` when neither `aes.y` nor a geom y accessor exists. */
    y0Aes?: DataValue<Datum>
    /** Fallback y accessor, consulted before `y0Aes` when neither `aes.y` nor a geom y accessor exists. */
    y1Aes?: DataValue<Datum>
    /** y accessors from geoms; the first entry is used when `aes.y` is not set. */
    geomAesYs: DataValue<Datum>[]
    /** Stroke accessors from geoms; the first entry is used when `aes.stroke` is not set. */
    geomAesStrokes: DataValue<Datum>[]
    /** strokeDasharray accessors from geoms; the first entry is used when `aes.strokeDasharray` is not set. */
    geomAesStrokeDasharrays: DataValue<Datum>[]
    /** Fill accessors from geoms; the first entry is used when `aes.fill` is not set. */
    geomAesFills: DataValue<Datum>[]
    /** Fill scale configuration; `autoScale` reads its `domain`, `type`, `values` and `reverse`. */
    fill?: VisualEncodingProps
    /** Stroke scale configuration; `autoScale` reads its `domain`, `type`, `values` and `reverse`. */
    stroke?: VisualEncodingProps
    /** strokeDasharray scale configuration; `autoScale` reads its `domain` and `values`. */
    strokeDasharray?: VisualEncodingProps
  }
  /** Used by `autoScale` to recover the member order of categorical x/y domains. */
  copiedData: Datum[]
  /** When true, `null`/`undefined` members are dropped from computed categorical x/y domains. */
  shouldExcludeMissingXYFromDomains?: boolean
}

/**
 * Builds the positional (x / y) and visual-encoding (fill / stroke /
 * strokeDasharray) scales for a chart, together with the group accessor and
 * the list of distinct groups.
 *
 * For each channel the scale kind is chosen from the first non-null value the
 * channel's accessor yields for `data`, unless an explicit scale `type` is
 * configured in `scalesState`:
 * - Date values produce a time scale,
 * - numbers produce the configured continuous scale (linear by default),
 * - everything else produces a band scale (x / y) or an ordinal scale
 *   (fill / stroke).
 *
 * @param scalesState - Scale configuration and accessors gathered from the
 *   chart and its geoms.
 * @param data - The data the domains are computed from.
 * @param copiedData - Data whose member order is used to order categorical
 *   x / y domains.
 * @param aes - Chart-level aesthetic accessors; take precedence over the
 *   geom-level accessors in `scalesState`.
 * @param width - Outer width used for the x range (defaults to 500).
 * @param height - Outer height used for the y range (defaults to 450).
 * @param margin - Partial margin merged over `{ top: 10, right: 20,
 *   bottom: 10, left: 30 }`.
 * @param shouldExcludeMissingXYFromDomains - When true, `null` / `undefined`
 *   members are removed from computed categorical x / y domains.
 * @returns The created scales, the group accessor, and the distinct groups
 *   (only when a categorical variable is in play).
 */
export const autoScale = <Datum>({
  scalesState,
  data,
  copiedData,
  aes,
  width = 500,
  height = 450,
  margin: suppliedMargin,
  shouldExcludeMissingXYFromDomains,
}: AutoScale<Datum>): IScale<Datum> => {
  const margin = {
    top: 10,
    right: 20,
    bottom: 10,
    left: 30,
    ...suppliedMargin,
  }

  const {
    x: xScaleState,
    y: yScaleState,
    fill: fillScaleState,
    stroke: strokeScaleState,
    strokeDasharray: strokeDasharrayState,
    hasZeroXBaseLine,
    hasZeroYBaseLine,
    geomGroupAccessors,
    y0Aes,
    y1Aes,
    geomAesYs,
    geomAesStrokes,
    geomAesFills,
    geomAesStrokeDasharrays,
  } = scalesState
  const {
    domain: xScaleDomain,
    type: xScaleType,
    reverse: reverseX,
  } = xScaleState || {}
  const {
    domain: yScaleDomain,
    type: yScaleType,
    reverse: reverseY,
  } = yScaleState || {}
  const {
    domain: fillScaleDomain,
    type: fillScaleType,
    values: fillScaleColors,
    reverse: fillScaleReverse,
  } = fillScaleState || {}
  const {
    domain: strokeScaleDomain,
    type: strokeScaleType,
    values: strokeScaleColors,
    reverse: strokeScaleReverse,
  } = strokeScaleState || {}
  const { domain: strokeDasharrayDomain, values: strokeDasharrays } =
    strokeDasharrayState || {}

  // used for maintaining the member order (domain) in categorical axes
  const sortDomain = (a: string, b: string, initialDomain: string[]) =>
    initialDomain.indexOf(a) - initialDomain.indexOf(b)

  // first value that is neither null nor undefined; drives scale-kind detection
  const firstDefined = <Value>(values: Value[]): Value | undefined =>
    values.find((d) => d !== null && d !== undefined)

  const geomGroupAccessor = geomGroupAccessors.length
    ? geomGroupAccessors[0]
    : undefined

  // identify groups: chart-level aesthetics win over the geoms' accessor
  const group =
    aes?.fill || aes?.stroke || aes?.strokeDasharray || aes?.group
      ? defineGroupAccessor(aes)
      : geomGroupAccessor

  let hasCategoricalVar = Boolean(aes.group) || geomGroupAccessors.length > 0
  const calculatedGroups = group
    ? (Array.from(new Set(data.map(group))) as string[])
    : ['__group']

  const thisYAes = aes.y || (geomAesYs.length ? geomAesYs[0] : undefined)
  // falls back to the y1 / y0 accessors when no plain y accessor exists
  const resolvedYAes = thisYAes ?? y1Aes ?? y0Aes
  const thisStrokeAes =
    aes.stroke || (geomAesStrokes.length ? geomAesStrokes[0] : undefined)
  const thisFillAes =
    aes.fill || (geomAesFills.length ? geomAesFills[0] : undefined)
  const thisStrokeDasharrayAes =
    aes.strokeDasharray ||
    (geomAesStrokeDasharrays.length ? geomAesStrokeDasharrays[0] : undefined)

  /// SCALING ///

  // x
  let xScale
  const firstX = firstDefined(data.map(aes.x))
  if (isDate(firstX)) {
    const domain =
      (xScaleDomain as Date[]) || extent(data, aes.x as (d: unknown) => Date)

    const hasDomain =
      typeof domain[0] !== 'undefined' &&
      typeof domain[1] !== 'undefined' &&
      // check for only null Dates (a boundary at epoch 0 is treated as
      // "no usable dates")
      domain[0].valueOf() !== 0 &&
      domain[1].valueOf() !== 0

    xScale = scaleTime()
      .range([margin.left, width - margin.right])
      .domain(hasDomain ? domain : [0, 0])
  } else if (typeof firstX === 'number') {
    const defaultDomain = extent(data, aes.x as (d: unknown) => number)

    const domain = (xScaleDomain as number[]) || [
      hasZeroXBaseLine ? 0 : defaultDomain[0],
      defaultDomain[1],
    ]

    const hasDomain =
      typeof domain[0] !== 'undefined' && typeof domain[1] !== 'undefined'

    const xType: any = xScaleType || scaleLinear
    xScale = xType()
      .range([margin.left, width - margin.right])
      .domain(hasDomain ? domain : [0, 1])
  } else {
    // Everything that is neither a Date nor a number (strings, booleans, or
    // no defined value at all) is treated as categorical.
    // maintain the existing order
    const initialDomain = Array.from(new Set(copiedData.map(aes.x))) as string[]
    const computedDomain = Array.from(new Set(data.map(aes.x))) as string[]

    const domain =
      (xScaleDomain as string[]) ||
      computedDomain
        .filter((d) =>
          shouldExcludeMissingXYFromDomains
            ? d !== null && typeof d !== 'undefined'
            : true,
        )
        .sort((a, b) => sortDomain(a, b, initialDomain))

    xScale = scaleBand()
      .range([margin.left, width - margin.right])
      .domain(domain)
  }
  if (reverseX) xScale?.domain(xScale.domain().reverse())

  // y
  let yScale
  if (resolvedYAes) {
    const firstY = firstDefined(data.map(resolvedYAes))

    if (isDate(firstY)) {
      // The domain must come from the same accessor that was used to detect
      // the value type; `thisYAes` can be undefined when only y0 / y1 exist.
      const domain =
        (yScaleDomain as Date[]) ||
        extent(data, resolvedYAes as (d: unknown) => Date)

      const hasDomain =
        typeof domain[0] !== 'undefined' && typeof domain[1] !== 'undefined'

      yScale = scaleTime()
        .range([height - margin.bottom, margin.top])
        .domain(hasDomain ? domain : [0, 1])
    } else if (typeof firstY === 'number') {
      const defaultDomain = extent(
        data,
        resolvedYAes as (d: unknown) => number,
      )

      const domain = yScaleDomain ?? [
        hasZeroYBaseLine ? 0 : defaultDomain[0],
        defaultDomain[1],
      ]

      const hasDomain =
        typeof domain[0] !== 'undefined' && typeof domain[1] !== 'undefined'

      const yType: any = yScaleType || scaleLinear

      yScale = yType()
        .range([height - margin.bottom, margin.top])
        .domain(hasDomain ? domain : [0, 1])
    } else {
      // categorical, same rule as for x
      // maintain the existing order
      const initialDomain = Array.from(
        new Set(copiedData.map(resolvedYAes)),
      ) as string[]
      const computedDomain = Array.from(
        new Set(data.map(resolvedYAes)),
      ) as string[]

      const domain =
        (yScaleDomain as string[]) ||
        computedDomain
          .filter((d) =>
            shouldExcludeMissingXYFromDomains
              ? d !== null && typeof d !== 'undefined'
              : true,
          )
          .sort((a, b) => sortDomain(a, b, initialDomain))

      // band range runs top-to-bottom, unlike the continuous y scales
      yScale = scaleBand()
        .range([margin.top, height - margin.bottom])
        .domain(domain)
    }
  } else {
    // no y accessor at all: placeholder unit scale
    yScale = scaleLinear()
      .range([height - margin.bottom, margin.top])
      .domain([0, 1])
  }
  if (reverseY) yScale?.domain(yScale.domain().reverse())

  // fill
  let fillScale
  if (thisFillAes) {
    const firstFill = firstDefined(data.map(thisFillAes))

    const continuousDomain =
      (fillScaleDomain as number[]) ||
      (extent(data, thisFillAes as (d: unknown) => number) as number[])

    const continuousInterpolator =
      (fillScaleColors as (t: number) => string) || defaultInterpolator

    const categoricalDomain = fillScaleDomain || calculatedGroups

    const discreteDomain =
      fillScaleDomain || data.map((d) => (group ? group(d) : '__group'))
    const discreteColors = fillScaleColors || defaultScheme
    const discreteSequential =
      fillScaleColors || createSequentialScheme(continuousInterpolator)

    if (fillScaleType) {
      const fillType = fillScaleType as any
      // instantiate once and classify the scale by the methods it exposes
      const typedFillScale = fillType()

      if (typedFillScale?.invertExtent) {
        // Quantize scales take a continuous [min, max] domain. Besides the
        // factory name (unreliable once a bundler mangles function names),
        // recognise them by `thresholds()`, which d3's quantile and threshold
        // scales do not expose.
        const isQuantize =
          fillType.name === 'quantize' ||
          typeof typedFillScale.thresholds === 'function'

        fillScale = typedFillScale
          .domain(isQuantize ? continuousDomain : discreteDomain)
          .range(discreteSequential) as VisualEncodingTypes
      } else if (typedFillScale?.interpolator) {
        fillScale = typedFillScale
          .domain(continuousDomain)
          .interpolator(continuousInterpolator) as VisualEncodingTypes
      } else {
        hasCategoricalVar = true

        fillScale = typedFillScale
          .domain(categoricalDomain)
          .range(discreteColors) as VisualEncodingTypes
      }
    } else if (!Number.isFinite(firstFill) || typeof firstFill === 'string') {
      hasCategoricalVar = true

      fillScale = scaleOrdinal()
        .domain(categoricalDomain)
        .range(discreteColors as string[]) as VisualEncodingTypes
    } else if (isDate(firstFill) || typeof firstFill === 'number') {
      // NOTE(review): assuming isDate only accepts Date objects, Dates are
      // already caught by the branch above (a Date is not a finite number),
      // so only finite numbers reach this point — confirm whether Date fills
      // were meant to be continuous.
      hasCategoricalVar = false

      fillScale = scaleSequential()
        .domain(continuousDomain)
        .interpolator(continuousInterpolator) as VisualEncodingTypes
    }
  }
  // only interpolator-based fill scales are reversed
  if (fillScaleReverse && fillScale?.interpolator)
    fillScale?.domain(fillScale.domain().reverse())

  // stroke
  let strokeScale
  if (thisStrokeAes) {
    const firstStroke = firstDefined(data.map(thisStrokeAes))

    // computed lazily: the extent is only meaningful for continuous scales
    const continuousStrokeDomain = () =>
      (strokeScaleDomain as number[]) ||
      (extent(data, thisStrokeAes as (d: unknown) => number) as number[])
    const strokeInterpolator =
      (strokeScaleColors as (t: number) => string) || defaultInterpolator
    const categoricalStrokeDomain =
      (strokeScaleDomain as string[]) || calculatedGroups
    const categoricalStrokeColors =
      (strokeScaleColors as string[]) || defaultScheme

    if (strokeScaleType) {
      const typedStrokeScale = (strokeScaleType as any)()

      // Classify by capability rather than by the factory's `name`: names do
      // not survive minification, and every interpolator-based scale needs
      // the continuous treatment.
      if (typeof typedStrokeScale?.interpolator === 'function') {
        strokeScale = typedStrokeScale
          .domain(continuousStrokeDomain())
          .interpolator(strokeInterpolator) as VisualEncodingTypes
      } else {
        hasCategoricalVar = true

        strokeScale = typedStrokeScale
          .domain(categoricalStrokeDomain)
          .range(categoricalStrokeColors) as VisualEncodingTypes
      }
    } else if (
      !Number.isFinite(firstStroke) ||
      typeof firstStroke === 'string'
    ) {
      hasCategoricalVar = true

      strokeScale = scaleOrdinal()
        .domain(categoricalStrokeDomain)
        .range(categoricalStrokeColors) as VisualEncodingTypes
    } else if (isDate(firstStroke) || typeof firstStroke === 'number') {
      // NOTE(review): as for fill, Date values appear to be handled by the
      // categorical branch above — confirm the intended behaviour.
      strokeScale = scaleSequential()
        .domain(continuousStrokeDomain())
        .interpolator(strokeInterpolator) as VisualEncodingTypes
    }
  }
  if (strokeScaleReverse) strokeScale?.domain(strokeScale.domain().reverse())

  // strokeDasharray (always categorical)
  let strokeDasharrayScale
  if (thisStrokeDasharrayAes) {
    hasCategoricalVar = true
    const domain = (strokeDasharrayDomain as string[]) || calculatedGroups

    strokeDasharrayScale = scaleOrdinal()
      .domain(domain)
      .range(
        (strokeDasharrays as string[]) || defaultDasharrays,
      ) as VisualEncodingTypes
  }

  return {
    xScale,
    yScale,
    fillScale,
    strokeScale,
    strokeDasharrayScale,
    groupAccessor: group,
    // groups are only reported when a categorical variable is in play
    groups: hasCategoricalVar ? calculatedGroups : undefined,
  }
}
