import * as tf from '@tensorflow/tfjs/dist/tf.es2017.js';

import { ConvParams, extractWeightsFactory, ExtractWeightsFunction, ParamMapping } from '../common';
import { isFloat } from '../utils';
import { ConvLayerParams, NetParams, ResidualLayerParams, ScaleLayerParams } from './types';

/**
 * Creates the layer-level weight extractors used by `extractParams`.
 *
 * Each extractor pulls values from `extractWeights` in a fixed sequence and appends one
 * `{ paramPath }` entry to `paramMappings` for every tensor it creates. Both sequences are
 * order-sensitive: the calls below consume the weight buffer strictly front to back.
 *
 * @param extractWeights - returns the next `n` values of the weight buffer
 * @param paramMappings - receives one mapping per extracted tensor (mutated in place)
 */
function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {

  /** Appends `${mappedPrefix}/${name}` mappings, in the given order. */
  const registerParams = (mappedPrefix: string, ...names: string[]): void => {
    for (const name of names) {
      paramMappings.push({ paramPath: `${mappedPrefix}/${name}` })
    }
  }

  /**
   * Reads `numFilterValues` weights as a filter stored as [numFilters, depth, filterSize, filterSize]
   * and returns it transposed to [filterSize, filterSize, depth, numFilters].
   * The depth (input channel count) is derived from how many values were read.
   *
   * @throws Error when the values do not split evenly into `numFilters * filterSize * filterSize` chunks
   */
  const extractFilterValues = (numFilterValues: number, numFilters: number, filterSize: number): tf.Tensor4D => {
    const weights = extractWeights(numFilterValues)
    const valuesPerDepthSlice = numFilters * filterSize * filterSize
    const depth = weights.length / valuesPerDepthSlice

    if (isFloat(depth)) {
      throw new Error(`depth has to be an integer: ${depth}, weights.length: ${weights.length}, numFilters: ${numFilters}, filterSize: ${filterSize}`)
    }

    const targetAxisOrder = [2, 3, 1, 0]
    // tidy releases the intermediate (untransposed) tensor
    return tf.tidy(() => {
      const asStored = tf.tensor4d(weights, [numFilters, depth, filterSize, filterSize])
      return tf.transpose(asStored, targetAxisOrder)
    })
  }

  /** Extracts a conv filter followed by its bias vector (`numFilters` values). */
  const extractConvParams = (
    numFilterValues: number,
    numFilters: number,
    filterSize: number,
    mappedPrefix: string
  ): ConvParams => {
    // property initializers run top to bottom: filters are read before bias
    const convParams: ConvParams = {
      filters: extractFilterValues(numFilterValues, numFilters, filterSize),
      bias: tf.tensor1d(extractWeights(numFilters))
    }
    registerParams(mappedPrefix, 'filters', 'bias')
    return convParams
  }

  /** Extracts a scale layer: `numWeights` weights, then `numWeights` biases. */
  const extractScaleLayerParams = (numWeights: number, mappedPrefix: string): ScaleLayerParams => {
    const scaleParams: ScaleLayerParams = {
      weights: tf.tensor1d(extractWeights(numWeights)),
      biases: tf.tensor1d(extractWeights(numWeights))
    }
    registerParams(mappedPrefix, 'weights', 'biases')
    return scaleParams
  }

  /** Extracts a conv (under `<prefix>/conv`) followed by its scale layer (under `<prefix>/scale`). */
  const extractConvLayerParams = (
    numFilterValues: number,
    numFilters: number,
    filterSize: number,
    mappedPrefix: string
  ): ConvLayerParams => ({
    conv: extractConvParams(numFilterValues, numFilters, filterSize, `${mappedPrefix}/conv`),
    scale: extractScaleLayerParams(numFilters, `${mappedPrefix}/scale`)
  })

  /**
   * Extracts the two conv layers of a residual block (`<prefix>/conv1`, then `<prefix>/conv2`).
   * When `isDown` is set, conv1 is read with half the filter values. Because the depth is derived
   * from the value count, conv1 then has half the input depth of conv2.
   */
  const extractResidualLayerParams = (
    numFilterValues: number,
    numFilters: number,
    filterSize: number,
    mappedPrefix: string,
    isDown: boolean = false
  ): ResidualLayerParams => {
    const conv1FilterValues = isDown ? numFilterValues / 2 : numFilterValues
    const conv1 = extractConvLayerParams(conv1FilterValues, numFilters, filterSize, `${mappedPrefix}/conv1`)
    const conv2 = extractConvLayerParams(numFilterValues, numFilters, filterSize, `${mappedPrefix}/conv2`)
    return { conv1, conv2 }
  }

  return {
    extractConvLayerParams,
    extractResidualLayerParams
  }

}

/**
 * Deserializes the flat weight buffer of the face recognition net into `NetParams`.
 *
 * Layers are read strictly in the order below. Each `extract*` call consumes the next chunk
 * of `weights`, so this sequence must match the layout of the serialized buffer.
 * `paramMappings` receives one `{ paramPath }` entry per created tensor, in the same order.
 *
 * @param weights - the complete serialized weight buffer
 * @returns the extracted parameters and their parameter-path mappings
 * @throws Error if any weights are left over after all layers have been extracted
 */
export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: ParamMapping[] } {

  const {
    extractWeights,
    getRemainingWeights
  } = extractWeightsFactory(weights)

  const paramMappings: ParamMapping[] = []

  const {
    extractConvLayerParams,
    extractResidualLayerParams
  } = extractorsFactory(extractWeights, paramMappings)

  // 4704 = 32 filters * 3 input channels * 7 * 7
  const conv32_down = extractConvLayerParams(4704, 32, 7, 'conv32_down')
  // 9216 = 32 * 32 * 3 * 3
  const conv32_1 = extractResidualLayerParams(9216, 32, 3, 'conv32_1')
  const conv32_2 = extractResidualLayerParams(9216, 32, 3, 'conv32_2')
  const conv32_3 = extractResidualLayerParams(9216, 32, 3, 'conv32_3')

  // 36864 = 64 * 64 * 3 * 3; the `true` flag halves conv1's values (32 input channels)
  const conv64_down = extractResidualLayerParams(36864, 64, 3, 'conv64_down', true)
  const conv64_1 = extractResidualLayerParams(36864, 64, 3, 'conv64_1')
  const conv64_2 = extractResidualLayerParams(36864, 64, 3, 'conv64_2')
  const conv64_3 = extractResidualLayerParams(36864, 64, 3, 'conv64_3')

  // 147456 = 128 * 128 * 3 * 3
  const conv128_down = extractResidualLayerParams(147456, 128, 3, 'conv128_down', true)
  const conv128_1 = extractResidualLayerParams(147456, 128, 3, 'conv128_1')
  const conv128_2 = extractResidualLayerParams(147456, 128, 3, 'conv128_2')

  // 589824 = 256 * 256 * 3 * 3
  const conv256_down = extractResidualLayerParams(589824, 256, 3, 'conv256_down', true)
  const conv256_1 = extractResidualLayerParams(589824, 256, 3, 'conv256_1')
  const conv256_2 = extractResidualLayerParams(589824, 256, 3, 'conv256_2')
  const conv256_down_out = extractResidualLayerParams(589824, 256, 3, 'conv256_down_out')

  // The fully connected weights are stored as [128, 256] and transposed to [256, 128].
  const fc = tf.tidy(
    () => tf.transpose(tf.tensor2d(extractWeights(256 * 128), [128, 256]), [1, 0])
  )
  paramMappings.push({ paramPath: `fc` })

  // Leftover values mean the buffer does not match the expected architecture.
  const remainingWeightCount = getRemainingWeights().length
  if (remainingWeightCount !== 0) {
    throw new Error(`weights remaining after extract: ${remainingWeightCount}`)
  }

  const params = {
    conv32_down,
    conv32_1,
    conv32_2,
    conv32_3,
    conv64_down,
    conv64_1,
    conv64_2,
    conv64_3,
    conv128_down,
    conv128_1,
    conv128_2,
    conv256_down,
    conv256_1,
    conv256_2,
    conv256_down_out,
    fc
  }

  return { params, paramMappings }
}