import { isTensor } from '../utils/index';
import { ParamMapping } from './types';

/**
 * Builds an extractor that pulls entries out of `weightMap` by path while
 * logging each successful lookup in `paramMappings`.
 *
 * @param weightMap - Object indexed by original weight path.
 * @param paramMappings - Array that receives one `{ originalPath, paramPath }`
 *   record per successful extraction; it is mutated in place.
 * @returns A function `(originalPath, paramRank, mappedPath?)` that returns
 *   `weightMap[originalPath]` once `isTensor` has accepted it for `paramRank`.
 * @throws Error from the returned function when `isTensor` rejects the entry
 *   (including when the path is absent from `weightMap`).
 */
export function extractWeightEntryFactory(weightMap: any, paramMappings: ParamMapping[]) {
  return (originalPath: string, paramRank: number, mappedPath?: string) => {
    const entry = weightMap[originalPath];
    const hasExpectedRank = isTensor(entry, paramRank);

    // Fail before recording anything, so rejected lookups leave paramMappings untouched.
    if (hasExpectedRank === false || !hasExpectedRank) {
      const message = `expected weightMap[${originalPath}] to be a Tensor${paramRank}D, instead have ${entry}`;
      throw new Error(message);
    }

    // Any falsy mappedPath (omitted or empty string) falls back to the original path.
    const paramPath = mappedPath ? mappedPath : originalPath;
    paramMappings.push({ originalPath, paramPath });

    return entry;
  };
}
