import * as tf from '../../dist/tfjs.esm';

import { disposeUnusedWeightTensors, extractWeightEntryFactory, ParamMapping } from '../common/index';
import { isTensor2D } from '../utils/index';
import { ConvLayerParams, NetParams, ResidualLayerParams, ScaleLayerParams } from './types';

/**
 * Builds the layer-level parameter extractors used while reading a weight map.
 *
 * All extractors share one `extractWeightEntry` function created from
 * `weightMap` and `paramMappings`. The second argument passed to it appears to
 * be the expected tensor rank (1 for vectors, 4 for conv filters). It
 * presumably records each entry it reads in `paramMappings`; confirm against
 * `extractWeightEntryFactory`.
 *
 * @param weightMap - the raw weight map; forwarded unchanged to `extractWeightEntryFactory`.
 * @param paramMappings - collector shared with the caller for the mappings of extracted tensors.
 * @returns extractors for plain conv layers and for residual (two-conv) layers.
 */
function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) {
  const extractEntry = extractWeightEntryFactory(weightMap, paramMappings);

  // `<prefix>/scale/weights` then `<prefix>/scale/biases`. Object-literal
  // properties are evaluated in source order, so the read order is fixed.
  const extractScaleLayerParams = (prefix: string): ScaleLayerParams => ({
    weights: extractEntry(`${prefix}/scale/weights`, 1),
    biases: extractEntry(`${prefix}/scale/biases`, 1),
  });

  // Conv filters and bias are read first, then the scale layer under the same prefix.
  const extractConvLayerParams = (prefix: string): ConvLayerParams => ({
    conv: {
      filters: extractEntry(`${prefix}/conv/filters`, 4),
      bias: extractEntry(`${prefix}/conv/bias`, 1),
    },
    scale: extractScaleLayerParams(prefix),
  });

  // A residual layer consists of two conv layers nested under `conv1` and `conv2`.
  const extractResidualLayerParams = (prefix: string): ResidualLayerParams => ({
    conv1: extractConvLayerParams(`${prefix}/conv1`),
    conv2: extractConvLayerParams(`${prefix}/conv2`),
  });

  return { extractConvLayerParams, extractResidualLayerParams };
}

/**
 * Produces a short description of an unexpected weight-map entry for an error
 * message. It reports the entry's shape when it has one, since that is what
 * identifies a mismatched weight file. Otherwise it falls back to `String(value)`.
 */
function describeWeightEntry(value: unknown): string {
  if (value === undefined || value === null) {
    return String(value);
  }
  const { shape } = value as { shape?: unknown };
  return Array.isArray(shape) ? `a tensor of shape [${shape.join(', ')}]` : String(value);
}

/**
 * Assembles the network parameters from a map of named tensors.
 *
 * Conv/residual layer tensors are looked up under paths such as
 * `conv32_down/conv/filters` or `conv32_1/conv1/scale/weights`, and the fully
 * connected weights under `fc`. Every mapping collected along the way is then
 * passed to `disposeUnusedWeightTensors` together with `weightMap`. The name
 * suggests it disposes entries that were not mapped; confirm in `../common`.
 *
 * @param weightMap - the loaded weights, keyed by layer path.
 * @returns the assembled `params` and the `paramMappings` recorded while reading them.
 * @throws Error if `weightMap.fc` does not pass `isTensor2D`.
 */
export function extractParamsFromWeightMap(
  weightMap: tf.NamedTensorMap,
): { params: NetParams, paramMappings: ParamMapping[] } {
  const paramMappings: ParamMapping[] = [];

  const {
    extractConvLayerParams,
    extractResidualLayerParams,
  } = extractorsFactory(weightMap, paramMappings);

  // Extraction order is kept stable: it determines the order of paramMappings.
  const conv32_down = extractConvLayerParams('conv32_down');
  const conv32_1 = extractResidualLayerParams('conv32_1');
  const conv32_2 = extractResidualLayerParams('conv32_2');
  const conv32_3 = extractResidualLayerParams('conv32_3');

  const conv64_down = extractResidualLayerParams('conv64_down');
  const conv64_1 = extractResidualLayerParams('conv64_1');
  const conv64_2 = extractResidualLayerParams('conv64_2');
  const conv64_3 = extractResidualLayerParams('conv64_3');

  const conv128_down = extractResidualLayerParams('conv128_down');
  const conv128_1 = extractResidualLayerParams('conv128_1');
  const conv128_2 = extractResidualLayerParams('conv128_2');

  const conv256_down = extractResidualLayerParams('conv256_down');
  const conv256_1 = extractResidualLayerParams('conv256_1');
  const conv256_2 = extractResidualLayerParams('conv256_2');
  const conv256_down_out = extractResidualLayerParams('conv256_down_out');

  // `fc` is read directly rather than through the extractors, so its mapping
  // is recorded by hand.
  const { fc } = weightMap;
  paramMappings.push({ originalPath: 'fc', paramPath: 'fc' });

  if (!isTensor2D(fc)) {
    // Report what was actually found (absent, or a tensor of the wrong shape)
    // instead of interpolating the raw entry.
    throw new Error(`expected weightMap[fc] to be a Tensor2D, instead have ${describeWeightEntry(fc)}`);
  }

  const params: NetParams = {
    conv32_down,
    conv32_1,
    conv32_2,
    conv32_3,
    conv64_down,
    conv64_1,
    conv64_2,
    conv64_3,
    conv128_down,
    conv128_1,
    conv128_2,
    conv256_down,
    conv256_1,
    conv256_2,
    conv256_down_out,
    fc,
  };

  disposeUnusedWeightTensors(weightMap, paramMappings);

  return { params, paramMappings };
}
