import * as tf from '../../dist/tfjs.esm';

import { disposeUnusedWeightTensors, extractWeightEntryFactory, FCParams, ParamMapping } from '../common/index';
import { NetParams } from './types';

/**
 * Builds the age and gender fully-connected head parameters from a loaded weight map.
 *
 * For each head, looks up `<scope>/weights` (requested with rank 2) and
 * `<scope>/bias` (requested with rank 1) through `extractWeightEntryFactory`.
 * The extractor records every lookup in the returned `paramMappings` array.
 * The weight map and the collected mappings are then passed to
 * `disposeUnusedWeightTensors`. NOTE(review): that helper presumably frees the
 * tensors that were not extracted here; confirm against its implementation.
 *
 * @param weightMap - Named tensors loaded from the model's weight files.
 * @returns The extracted `params` and the `paramMappings` recorded during extraction.
 */
export function extractParamsFromWeightMap(
  weightMap: tf.NamedTensorMap,
): { params: NetParams, paramMappings: ParamMapping[] } {
  const mappings: ParamMapping[] = [];
  const getEntry = extractWeightEntryFactory(weightMap, mappings);

  // Object-literal properties evaluate in source order, so weights are
  // extracted before bias. This keeps the mapping order identical per layer.
  const fcLayer = (scope: string): FCParams => ({
    weights: getEntry(`${scope}/weights`, 2),
    bias: getEntry(`${scope}/bias`, 1),
  });

  // The age head is extracted before the gender head; the mapping order depends on this.
  const age = fcLayer('fc/age');
  const gender = fcLayer('fc/gender');
  const params = { fc: { age, gender } };

  disposeUnusedWeightTensors(weightMap, mappings);

  return { params, paramMappings: mappings };
}
