import * as tf from '../../dist/tfjs.esm';

import { ConvParams, extractWeightsFactory, ExtractWeightsFunction, ParamMapping } from '../common/index';
import { isFloat } from '../utils/index';
import { ConvLayerParams, NetParams, ResidualLayerParams, ScaleLayerParams } from './types';

function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {
  function extractFilterValues(numFilterValues: number, numFilters: number, filterSize: number): tf.Tensor4D {
    const weights = extractWeights(numFilterValues);
    const depth = weights.length / (numFilters * filterSize * filterSize);

    if (isFloat(depth)) {
      throw new Error(`depth has to be an integer: ${depth}, weights.length: ${weights.length}, numFilters: ${numFilters}, filterSize: ${filterSize}`);
    }

    return tf.tidy(
      () => tf.transpose(
        tf.tensor4d(weights, [numFilters, depth, filterSize, filterSize]),
        [2, 3, 1, 0],
      ),
    );
  }

  function extractConvParams(
    numFilterValues: number,
    numFilters: number,
    filterSize: number,
    mappedPrefix: string,
  ): ConvParams {
    const filters = extractFilterValues(numFilterValues, numFilters, filterSize);
    const bias = tf.tensor1d(extractWeights(numFilters));

    paramMappings.push(
      { paramPath: `${mappedPrefix}/filters` },
      { paramPath: `${mappedPrefix}/bias` },
    );

    return { filters, bias };
  }

  function extractScaleLayerParams(numWeights: number, mappedPrefix: string): ScaleLayerParams {
    const weights = tf.tensor1d(extractWeights(numWeights));
    const biases = tf.tensor1d(extractWeights(numWeights));

    paramMappings.push(
      { paramPath: `${mappedPrefix}/weights` },
      { paramPath: `${mappedPrefix}/biases` },
    );

    return {
      weights,
      biases,
    };
  }

  function extractConvLayerParams(
    numFilterValues: number,
    numFilters: number,
    filterSize: number,
    mappedPrefix: string,
  ): ConvLayerParams {
    const conv = extractConvParams(numFilterValues, numFilters, filterSize, `${mappedPrefix}/conv`);
    const scale = extractScaleLayerParams(numFilters, `${mappedPrefix}/scale`);

    return { conv, scale };
  }

  function extractResidualLayerParams(
    numFilterValues: number,
    numFilters: number,
    filterSize: number,
    mappedPrefix: string,
    isDown = false,
  ): ResidualLayerParams {
    const conv1 = extractConvLayerParams((isDown ? 0.5 : 1) * numFilterValues, numFilters, filterSize, `${mappedPrefix}/conv1`);
    const conv2 = extractConvLayerParams(numFilterValues, numFilters, filterSize, `${mappedPrefix}/conv2`);

    return { conv1, conv2 };
  }

  return {
    extractConvLayerParams,
    extractResidualLayerParams,
  };
}

export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: ParamMapping[] } {
  const {
    extractWeights,
    getRemainingWeights,
  } = extractWeightsFactory(weights);

  const paramMappings: ParamMapping[] = [];

  const {
    extractConvLayerParams,
    extractResidualLayerParams,
  } = extractorsFactory(extractWeights, paramMappings);

  const conv32_down = extractConvLayerParams(4704, 32, 7, 'conv32_down');
  const conv32_1 = extractResidualLayerParams(9216, 32, 3, 'conv32_1');
  const conv32_2 = extractResidualLayerParams(9216, 32, 3, 'conv32_2');
  const conv32_3 = extractResidualLayerParams(9216, 32, 3, 'conv32_3');

  const conv64_down = extractResidualLayerParams(36864, 64, 3, 'conv64_down', true);
  const conv64_1 = extractResidualLayerParams(36864, 64, 3, 'conv64_1');
  const conv64_2 = extractResidualLayerParams(36864, 64, 3, 'conv64_2');
  const conv64_3 = extractResidualLayerParams(36864, 64, 3, 'conv64_3');

  const conv128_down = extractResidualLayerParams(147456, 128, 3, 'conv128_down', true);
  const conv128_1 = extractResidualLayerParams(147456, 128, 3, 'conv128_1');
  const conv128_2 = extractResidualLayerParams(147456, 128, 3, 'conv128_2');

  const conv256_down = extractResidualLayerParams(589824, 256, 3, 'conv256_down', true);
  const conv256_1 = extractResidualLayerParams(589824, 256, 3, 'conv256_1');
  const conv256_2 = extractResidualLayerParams(589824, 256, 3, 'conv256_2');
  const conv256_down_out = extractResidualLayerParams(589824, 256, 3, 'conv256_down_out');

  const fc = tf.tidy(
    () => tf.transpose(tf.tensor2d(extractWeights(256 * 128), [128, 256]), [1, 0]),
  );
  paramMappings.push({ paramPath: 'fc' });

  if (getRemainingWeights().length !== 0) {
    throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`);
  }

  const params = {
    conv32_down,
    conv32_1,
    conv32_2,
    conv32_3,
    conv64_down,
    conv64_1,
    conv64_2,
    conv64_3,
    conv128_down,
    conv128_1,
    conv128_2,
    conv256_down,
    conv256_1,
    conv256_2,
    conv256_down_out,
    fc,
  };

  return { params, paramMappings };
}
