/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

/**
 * IOHandlers related to files, such as browser-triggered file downloads,
 * user-selected files in browser.
 */

import {env} from '../environment';

import {basename, concatenateArrayBuffers, getModelArtifactsInfoForJSON} from './io_utils';
import {IORouter, IORouterRegistry} from './router_registry';
import {IOHandler, ModelArtifacts, ModelJSON, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';

const DEFAULT_FILE_NAME_PREFIX = 'model';
const DEFAULT_JSON_EXTENSION_NAME = '.json';
const DEFAULT_WEIGHT_DATA_EXTENSION_NAME = '.weights.bin';

/**
 * Runs `f` on a later turn of the event loop instead of synchronously.
 *
 * @param f Callback to invoke once a zero-delay timer has fired.
 * @returns A promise that settles with the outcome of `f`: its return value
 *   on success, or a rejection if `f` throws.
 */
async function defer<T>(f: () => T): Promise<T> {
  // Yield to the timer queue first so that `f` never runs in the caller's
  // synchronous stack.
  await new Promise<void>(timerFired => {
    setTimeout(timerFired);
  });
  return f();
}

/**
 * IOHandler that saves model artifacts by triggering file downloads in the
 * browser.
 *
 * A save produces a JSON file holding the model topology and the weights
 * manifest and, when weight data is present, a binary file holding the weight
 * values.
 */
export class BrowserDownloads implements IOHandler {
  private readonly modelTopologyFileName: string;
  private readonly weightDataFileName: string;
  // Optional pre-existing anchors. Nothing in this class assigns them; when
  // they are null/undefined, `save()` creates detached anchors instead.
  private readonly jsonAnchor: HTMLAnchorElement;
  private readonly weightDataAnchor: HTMLAnchorElement;

  static readonly URL_SCHEME = 'downloads://';

  /**
   * @param fileNamePrefix Prefix of the downloaded file names. A leading
   *   `downloads://` scheme is stripped. When `null`, `undefined` or empty
   *   (also after stripping the scheme), the default prefix is used.
   * @throws Error if the current environment is not a browser.
   */
  constructor(fileNamePrefix?: string) {
    if (!env().getBool('IS_BROWSER')) {
      // TODO(cais): Provide info on what IOHandlers are available under the
      //   current environment.
      throw new Error(
          'browserDownloads() cannot proceed because the current environment ' +
          'is not a browser.');
    }

    // Guard against null/undefined before touching string methods; the
    // prefix is optional and must fall back to the default rather than throw.
    if (fileNamePrefix != null &&
        fileNamePrefix.startsWith(BrowserDownloads.URL_SCHEME)) {
      fileNamePrefix = fileNamePrefix.slice(BrowserDownloads.URL_SCHEME.length);
    }
    if (fileNamePrefix == null || fileNamePrefix.length === 0) {
      fileNamePrefix = DEFAULT_FILE_NAME_PREFIX;
    }

    this.modelTopologyFileName = fileNamePrefix + DEFAULT_JSON_EXTENSION_NAME;
    this.weightDataFileName =
        fileNamePrefix + DEFAULT_WEIGHT_DATA_EXTENSION_NAME;
  }

  /**
   * Triggers the download of the JSON file and, if `weightData` is present,
   * of the binary weights file.
   *
   * @param modelArtifacts Artifacts to save. The topology must not be an
   *   `ArrayBuffer`.
   * @returns A `SaveResult` whose `modelArtifactsInfo` is computed by
   *   `getModelArtifactsInfoForJSON`.
   * @throws Error if `document` is not defined, or if the model topology is
   *   in a binary (`ArrayBuffer`) format.
   */
  async save(modelArtifacts: ModelArtifacts): Promise<SaveResult> {
    if (typeof (document) === 'undefined') {
      throw new Error(
          'Browser downloads are not supported in ' +
          'this environment since `document` is not present');
    }

    // Validate before allocating any object URL so that the error path does
    // not leave an unreferenced blob URL behind.
    if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
      throw new Error(
          'BrowserDownloads.save() does not support saving model topology ' +
          'in binary formats yet.');
    }

    const weightsManifest: WeightsManifestConfig = [{
      paths: ['./' + this.weightDataFileName],
      weights: modelArtifacts.weightSpecs
    }];
    const modelTopologyAndWeightManifest: ModelJSON = {
      modelTopology: modelArtifacts.modelTopology,
      format: modelArtifacts.format,
      generatedBy: modelArtifacts.generatedBy,
      convertedBy: modelArtifacts.convertedBy,
      // Kept so that metadata survives a save/load round trip; omitted from
      // the JSON by `JSON.stringify` when undefined.
      userDefinedMetadata: modelArtifacts.userDefinedMetadata,
      weightsManifest
    };
    const modelTopologyAndWeightManifestURL =
        window.URL.createObjectURL(new Blob(
            [JSON.stringify(modelTopologyAndWeightManifest)],
            {type: 'application/json'}));

    // If anchor elements are not provided, create them without attaching them
    // to parents, so that the downloaded file names can be controlled.
    const jsonAnchor = this.jsonAnchor == null ? document.createElement('a') :
                                                 this.jsonAnchor;
    jsonAnchor.download = this.modelTopologyFileName;
    jsonAnchor.href = modelTopologyAndWeightManifestURL;
    // Trigger downloads by evoking a click event on the download anchors.
    // When multiple downloads are started synchronously, Firefox will only
    // save the last one.
    await defer(() => jsonAnchor.dispatchEvent(new MouseEvent('click')));

    if (modelArtifacts.weightData != null) {
      // The weights blob is only needed when there is weight data to save.
      const weightsURL = window.URL.createObjectURL(new Blob(
          [modelArtifacts.weightData], {type: 'application/octet-stream'}));
      const weightDataAnchor = this.weightDataAnchor == null ?
          document.createElement('a') :
          this.weightDataAnchor;
      weightDataAnchor.download = this.weightDataFileName;
      weightDataAnchor.href = weightsURL;
      await defer(
          () => weightDataAnchor.dispatchEvent(new MouseEvent('click')));
    }

    return {modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts)};
  }
}

/**
 * IOHandler that loads model artifacts from `File` objects.
 *
 * The first file is read as the JSON file containing the model topology and
 * (optionally) the weights manifest; any remaining files are read as binary
 * weight files referenced by that manifest.
 */
class BrowserFiles implements IOHandler {
  private readonly files: File[];

  /**
   * @param files The JSON file followed by zero or more weight files.
   * @throws Error if `files` is null, undefined or empty.
   */
  constructor(files: File[]) {
    if (files == null || files.length < 1) {
      throw new Error(
          `When calling browserFiles, at least 1 file is required, ` +
          `but received ${files}`);
    }
    this.files = files;
  }

  /**
   * Reads the JSON file and all weight files and assembles `ModelArtifacts`.
   *
   * When only the JSON file was provided, the result contains only
   * `modelTopology`.
   *
   * @returns A promise resolving to the loaded artifacts. It rejects if the
   *   JSON cannot be parsed, if `modelTopology` or (when weight files are
   *   given) `weightsManifest` is missing, if the manifest does not match the
   *   provided weight files, or if any file fails to be read.
   */
  async load(): Promise<ModelArtifacts> {
    const jsonFile = this.files[0];
    const weightFiles = this.files.slice(1);

    return new Promise<ModelArtifacts>((resolve, reject) => {
      const jsonReader = new FileReader();
      jsonReader.onload = (event: Event) => {
        // This callback runs outside of the promise executor, so anything
        // thrown here would be an uncaught error and leave the promise
        // pending forever. Convert parse failures into rejections instead.
        let modelJSON: ModelJSON;
        try {
          // tslint:disable-next-line:no-any
          modelJSON = JSON.parse((event.target as any).result) as ModelJSON;
        } catch (err) {
          reject(err);
          return;
        }

        // `JSON.parse` may legitimately return null (for the text "null").
        const modelTopology =
            modelJSON == null ? null : modelJSON.modelTopology;
        if (modelTopology == null) {
          reject(new Error(
              `modelTopology field is missing from file ${jsonFile.name}`));
          return;
        }

        if (weightFiles.length === 0) {
          resolve({modelTopology});
          // Nothing else to load; do not fall through to manifest checks.
          return;
        }

        const weightsManifest = modelJSON.weightsManifest;
        if (weightsManifest == null) {
          reject(new Error(
              `weightManifest field is missing from file ${jsonFile.name}`));
          return;
        }

        let pathToFile: {[path: string]: File};
        const weightSpecs: WeightsManifestEntry[] = [];
        const paths: string[] = [];
        const perFileBuffers: ArrayBuffer[] = [];
        try {
          pathToFile =
              this.checkManifestAndWeightFiles(weightsManifest, weightFiles);
          // Flatten the manifest: one buffer slot per path, in manifest
          // order. Kept inside the try block so that a malformed group
          // rejects the promise rather than throwing uncaught.
          weightsManifest.forEach(weightsGroup => {
            weightsGroup.paths.forEach(path => {
              paths.push(path);
              perFileBuffers.push(null);
            });
            weightSpecs.push(...weightsGroup.weights);
          });
        } catch (err) {
          reject(err);
          return;
        }

        // Manifest basenames are verified to be unique, so every path owns
        // exactly one slot and its index can be used directly.
        paths.forEach((path, index) => {
          const weightFileReader = new FileReader();
          weightFileReader.onload = (weightEvent: Event) => {
            // tslint:disable-next-line:no-any
            const weightData = (weightEvent.target as any).result as
                ArrayBuffer;
            perFileBuffers[index] = weightData;
            // Resolve once the last outstanding file has been read.
            if (perFileBuffers.indexOf(null) === -1) {
              resolve({
                modelTopology,
                weightSpecs,
                weightData: concatenateArrayBuffers(perFileBuffers),
                format: modelJSON.format,
                generatedBy: modelJSON.generatedBy,
                convertedBy: modelJSON.convertedBy,
                userDefinedMetadata: modelJSON.userDefinedMetadata
              });
            }
          };
          weightFileReader.onerror = () =>
              reject(`Failed to weights data from file of path '${path}'.`);
          weightFileReader.readAsArrayBuffer(pathToFile[path]);
        });
      };
      jsonReader.onerror = () => reject(
          `Failed to read model topology and weights manifest JSON ` +
          `from file '${jsonFile.name}'. BrowserFiles supports loading ` +
          `Keras-style tf.Model artifacts only.`);
      jsonReader.readAsText(jsonFile);
    });
  }

  /**
   * Check the compatibility between weights manifest and weight files.
   *
   * @param manifest Weights manifest read from the JSON file.
   * @param files Weight files provided by the caller.
   * @returns A map from each manifest path to the file with the same
   *   basename.
   * @throws Error if the manifest contains duplicate basenames, if a
   *   manifest path has no matching file, or if the number of manifest paths
   *   differs from the number of files.
   */
  private checkManifestAndWeightFiles(
      manifest: WeightsManifestConfig, files: File[]): {[path: string]: File} {
    const basenames: string[] = [];
    const fileNames = files.map(file => basename(file.name));
    const pathToFile: {[path: string]: File} = {};
    for (const group of manifest) {
      group.paths.forEach(path => {
        const pathBasename = basename(path);
        if (basenames.indexOf(pathBasename) !== -1) {
          throw new Error(
              `Duplicate file basename found in weights manifest: ` +
              `'${pathBasename}'`);
        }
        basenames.push(pathBasename);
        const fileIndex = fileNames.indexOf(pathBasename);
        if (fileIndex === -1) {
          throw new Error(
              `Weight file with basename '${pathBasename}' is not provided.`);
        }
        pathToFile[path] = files[fileIndex];
      });
    }

    if (basenames.length !== files.length) {
      throw new Error(
          `Mismatch in the number of files in weights manifest ` +
          `(${basenames.length}) and the number of weight files provided ` +
          `(${files.length}).`);
    }
    return pathToFile;
  }
}

/**
 * Router that maps a single `downloads://`-prefixed URL string to a
 * browser-downloads IOHandler. Returns `null` outside of a browser, for array
 * inputs, and for URLs that do not start with the scheme.
 */
export const browserDownloadsRouter: IORouter = (url: string|string[]) => {
  const inBrowser = env().getBool('IS_BROWSER');
  if (!inBrowser) {
    return null;
  }
  if (Array.isArray(url)) {
    return null;
  }
  if (!url.startsWith(BrowserDownloads.URL_SCHEME)) {
    return null;
  }
  // Hand the handler factory the part of the URL that follows the scheme.
  const prefix = url.slice(BrowserDownloads.URL_SCHEME.length);
  return browserDownloads(prefix);
};
// Module-load side effect: register `browserDownloadsRouter` with the
// IORouterRegistry as a router for save operations.
IORouterRegistry.registerSaveRouter(browserDownloadsRouter);

/**
 * Creates an IOHandler that triggers file downloads from the browser.
 *
 * The returned `IOHandler` instance can be used with model exporting methods
 * such as `tf.Model.save`, and supports only saving.
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.dense(
 *     {units: 1, inputShape: [10], activation: 'sigmoid'}));
 * const saveResult = await model.save('downloads://mymodel');
 * // This will trigger downloading of two files:
 * //   'mymodel.json' and 'mymodel.weights.bin'.
 * console.log(saveResult);
 * ```
 *
 * @param fileNamePrefix Prefix name of the files to be downloaded. For use with
 *   `tf.Model`, `fileNamePrefix` should follow either of the following two
 *   formats:
 *   1. `null` or `undefined`, in which case the default file
 *      names will be used:
 *      - 'model.json' for the JSON file containing the model topology and
 *        weights manifest.
 *      - 'model.weights.bin' for the binary file containing the binary weight
 *        values.
 *   2. A single string, as the file name prefix.
 *      For example, if `'foo'` is provided, the downloaded JSON
 *      file and binary weights file will be named 'foo.json' and
 *      'foo.weights.bin', respectively.
 * @returns An instance of `BrowserDownloads` `IOHandler`.
 */
/**
 * @doc {
 *   heading: 'Models',
 *   subheading: 'Loading',
 *   namespace: 'io',
 *   ignoreCI: true
 * }
 */
export function browserDownloads(fileNamePrefix = 'model'): IOHandler {
  // Thin factory: all validation and file naming happens in the
  // `BrowserDownloads` constructor.
  const downloadsHandler: IOHandler = new BrowserDownloads(fileNamePrefix);
  return downloadsHandler;
}

/**
 * Creates an IOHandler that loads model artifacts from user-selected files.
 *
 * This method can be used for loading from files such as user-selected files
 * in the browser.
 * When used in conjunction with `tf.loadLayersModel`, an instance of
 * `tf.LayersModel` (Keras-style) can be constructed from the loaded artifacts.
 *
 * ```js
 * // Note: This code snippet won't run properly without the actual file input
 * //   elements in the HTML DOM.
 *
 * // Suppose there are two HTML file input (`<input type="file" ...>`)
 * // elements.
 * const uploadJSONInput = document.getElementById('upload-json');
 * const uploadWeightsInput = document.getElementById('upload-weights');
 * const model = await tf.loadLayersModel(tf.io.browserFiles(
 *     [uploadJSONInput.files[0], uploadWeightsInput.files[0]]));
 * ```
 *
 * @param files `File`s to load from. Currently, this function supports only
 *   loading from files that contain Keras-style models (i.e., `tf.Model`s), for
 *   which an `Array` of `File`s is expected (in that order):
 *   - A JSON file containing the model topology and weight manifest.
 *   - Optionally, one or more binary files containing the binary weights.
 *     These files must have names that match the paths in the `weightsManifest`
 *     contained by the aforementioned JSON file, or errors will be thrown
 *     during loading. These weights files have the same format as the ones
 *     generated by `tensorflowjs_converter` that comes with the `tensorflowjs`
 *     Python PIP package. If no weights files are provided, only the model
 *     topology will be loaded from the JSON file above.
 * @returns An instance of `BrowserFiles` `IOHandler`.
 */
/**
 * @doc {
 *   heading: 'Models',
 *   subheading: 'Loading',
 *   namespace: 'io',
 *   ignoreCI: true
 * }
 */
export function browserFiles(files: File[]): IOHandler {
  // Thin factory: the `BrowserFiles` constructor validates `files`.
  const filesHandler: IOHandler = new BrowserFiles(files);
  return filesHandler;
}
