import * as tf from '@tensorflow/tfjs/dist/tf.es2017.js';

import { ConvParams, ExtractWeightsFunction, ParamMapping } from './types';

export function extractConvParamsFactory(
  extractWeights: ExtractWeightsFunction,
  paramMappings: ParamMapping[]
) {

  return function(
    channelsIn: number,
    channelsOut: number,
    filterSize: number,
    mappedPrefix: string
  ): ConvParams {

    const filters = tf.tensor4d(
      extractWeights(channelsIn * channelsOut * filterSize * filterSize),
      [filterSize, filterSize, channelsIn, channelsOut]
    )
    const bias = tf.tensor1d(extractWeights(channelsOut))

    paramMappings.push(
      { paramPath: `${mappedPrefix}/filters` },
      { paramPath: `${mappedPrefix}/bias` }
    )

    return { filters, bias }
  }

}
