/**
 * @license
 * Copyright 2019 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {ENGINE} from '../engine';
import {getKernel} from '../kernel_registry';
import {Tensor, Tensor2D, Tensor3D} from '../tensor';
import {convertToTensor} from '../tensor_util_env';
import {PixelData, TensorLike} from '../types';

import {op} from './operation';
import {tensor3d} from './tensor_ops';

// Shared 2D canvas context used by `fromPixels_` to rasterize image and video
// elements when the active backend has no 'FromPixels' kernel. It is created
// lazily on first use and then reused (resized) for subsequent calls.
let fromPixels2DContext: CanvasRenderingContext2D;

/**
 * Creates a `tf.Tensor` from an image.
 *
 * ```js
 * const image = new ImageData(1, 1);
 * image.data[0] = 100;
 * image.data[1] = 150;
 * image.data[2] = 200;
 * image.data[3] = 255;
 *
 * tf.browser.fromPixels(image).print();
 * ```
 *
 * @param pixels The input image to construct the tensor from. The
 * supported image types are all 4-channel. You can also pass in an image
 * object with following attributes:
 * `{data: Uint8Array; width: number; height: number}`
 * @param numChannels The number of channels of the output tensor. A
 * numChannels value less than 4 allows you to ignore channels. Defaults to
 * 3 (ignores alpha channel of input image).
 */
/** @doc {heading: 'Browser', namespace: 'browser', ignoreCI: true} */
function fromPixels_(
    pixels: PixelData|ImageData|HTMLImageElement|HTMLCanvasElement|
    HTMLVideoElement,
    numChannels = 3): Tensor3D {
  // Sanity checks.
  if (numChannels > 4) {
    throw new Error(
        'Cannot construct Tensor with more than 4 channels from pixels.');
  }
  if (pixels == null) {
    throw new Error('pixels passed to tf.browser.fromPixels() can not be null');
  }

  // Classify the input. The `typeof X !== 'undefined'` guards keep the
  // `instanceof` checks from throwing where a DOM constructor is missing
  // (e.g. in a web worker).
  let isPixelData = false;
  let isImageData = false;
  let isVideo = false;
  let isImage = false;
  let isCanvasLike = false;
  if ((pixels as PixelData).data instanceof Uint8Array) {
    isPixelData = true;
  } else if (
      typeof (ImageData) !== 'undefined' && pixels instanceof ImageData) {
    isImageData = true;
  } else if (
      typeof (HTMLVideoElement) !== 'undefined' &&
      pixels instanceof HTMLVideoElement) {
    isVideo = true;
  } else if (
      typeof (HTMLImageElement) !== 'undefined' &&
      pixels instanceof HTMLImageElement) {
    isImage = true;
    // tslint:disable-next-line: no-any
  } else if ((pixels as any).getContext != null) {
    // Duck-typed so that OffscreenCanvas is accepted alongside
    // HTMLCanvasElement.
    isCanvasLike = true;
  } else {
    // Prototype-less objects (e.g. `Object.create(null)`) have no
    // `constructor`; fall back to `typeof` rather than crashing while
    // building the error message.
    const ctor = (pixels as {}).constructor;
    const typeName = ctor != null ? ctor.name : typeof pixels;
    throw new Error(
        'pixels passed to tf.browser.fromPixels() must be either an ' +
        `HTMLVideoElement, HTMLImageElement, HTMLCanvasElement, ImageData ` +
        `in browser, or OffscreenCanvas, ImageData in webworker` +
        ` or {data: Uint8Array, width: number, height: number}, ` +
        `but was ${typeName}`);
  }

  if (isVideo) {
    // readyState 2 is HAVE_CURRENT_DATA: a frame is available to be drawn.
    const HAVE_CURRENT_DATA_READY_STATE = 2;
    if ((pixels as HTMLVideoElement).readyState <
        HAVE_CURRENT_DATA_READY_STATE) {
      throw new Error(
          'The video element has not loaded data yet. Please wait for ' +
          '`loadeddata` event on the <video> element.');
    }
  }

  // If the current backend has 'FromPixels' registered, it has a more
  // efficient way of handling pixel uploads, so we call that.
  const kernel = getKernel('FromPixels', ENGINE.backendName);
  if (kernel != null) {
    return ENGINE.runKernel('FromPixels', {pixels} as {}, {numChannels}) as
        Tensor3D;
  }

  // Videos expose their frame size via videoWidth/videoHeight.
  const [width, height] = isVideo ?
      [
        (pixels as HTMLVideoElement).videoWidth,
        (pixels as HTMLVideoElement).videoHeight
      ] :
      [pixels.width, pixels.height];

  // Source pixels, read below as 4 bytes (RGBA) per pixel.
  let vals: Uint8ClampedArray|Uint8Array;
  if (isCanvasLike) {
    vals =
        // tslint:disable-next-line:no-any
        (pixels as any).getContext('2d').getImageData(0, 0, width, height).data;
  } else if (isImageData || isPixelData) {
    vals = (pixels as PixelData | ImageData).data;
  } else {
    // Only image and video elements remain (see classification above):
    // draw them into the shared 2D canvas and read the pixels back.
    if (fromPixels2DContext == null) {
      fromPixels2DContext = document.createElement('canvas').getContext('2d');
    }
    fromPixels2DContext.canvas.width = width;
    fromPixels2DContext.canvas.height = height;
    fromPixels2DContext.drawImage(
        pixels as HTMLVideoElement, 0, 0, width, height);
    vals = fromPixels2DContext.getImageData(0, 0, width, height).data;
  }

  let values: Int32Array;
  if (numChannels === 4) {
    values = new Int32Array(vals);
  } else {
    // Keep the first `numChannels` of every 4-byte RGBA group.
    const numPixels = width * height;
    values = new Int32Array(numPixels * numChannels);
    for (let i = 0; i < numPixels; i++) {
      for (let channel = 0; channel < numChannels; ++channel) {
        values[i * numChannels + channel] = vals[i * 4 + channel];
      }
    }
  }
  const outShape: [number, number, number] = [height, width, numChannels];
  return tensor3d(values, outShape, 'int32');
}

/**
 * Draws a `tf.Tensor` of pixel values to a byte array or optionally a
 * canvas.
 *
 * When the dtype of the input is 'float32', we assume values in the range
 * [0-1]. Otherwise, when input is 'int32', we assume values in the range
 * [0-255].
 *
 * Returns a promise that resolves with the RGBA bytes (4 per pixel) once they
 * have been computed and, if a canvas was given, drawn to it.
 *
 * @param img A rank-2 or rank-3 tensor. If rank-2, draws grayscale. If
 *     rank-3, must have depth of 1, 3 or 4. When depth of 1, draws
 *     grayscale. When depth of 3, we draw with the first three components of
 *     the depth dimension corresponding to r, g, b and alpha = 1. When depth
 *     of 4, all four components of the depth dimension correspond to r, g,
 *     b, a.
 * @param canvas The canvas to draw to.
 */
/** @doc {heading: 'Browser', namespace: 'browser'} */
export async function toPixels(
    img: Tensor2D|Tensor3D|TensorLike,
    canvas?: HTMLCanvasElement): Promise<Uint8ClampedArray> {
  let $img = convertToTensor(img, 'img', 'toPixels');
  if (!(img instanceof Tensor)) {
    // Assume int32 if user passed a native array. `convertToTensor` created
    // a new tensor from the native input; release it once the int32 version
    // exists so it does not leak.
    const originalImgTensor = $img;
    $img = originalImgTensor.toInt();
    originalImgTensor.dispose();
  }

  // Validation below may throw; the `finally` guarantees any tensor this
  // function created is released on every exit path.
  try {
    if ($img.rank !== 2 && $img.rank !== 3) {
      throw new Error(
          `toPixels only supports rank 2 or 3 tensors, got rank ${$img.rank}.`);
    }
    const [height, width] = $img.shape.slice(0, 2);
    const depth = $img.rank === 2 ? 1 : $img.shape[2];

    if (depth !== 1 && depth !== 3 && depth !== 4) {
      throw new Error(
          `toPixels only supports depth of size ` +
          `1, 3 or 4 but got ${depth}`);
    }

    const data = await $img.data();
    const minTensor = $img.min();
    const maxTensor = $img.max();
    const vals = await Promise.all([minTensor.data(), maxTensor.data()]);
    const min = vals[0][0];
    const max = vals[1][0];
    minTensor.dispose();
    maxTensor.dispose();
    if ($img.dtype === 'float32') {
      if (min < 0 || max > 1) {
        throw new Error(
            `Tensor values for a float32 Tensor must be in the ` +
            `range [0 - 1] but got range [${min} - ${max}].`);
      }
    } else if ($img.dtype === 'int32') {
      if (min < 0 || max > 255) {
        throw new Error(
            `Tensor values for a int32 Tensor must be in the ` +
            `range [0 - 255] but got range [${min} - ${max}].`);
      }
    } else {
      throw new Error(
          `Unsupported type for toPixels: ${$img.dtype}.` +
          ` Please use float32 or int32 tensors.`);
    }

    // float32 values in [0, 1] are scaled up to the byte range [0, 255].
    const multiplier = $img.dtype === 'float32' ? 255 : 1;
    const numPixels = height * width;
    const bytes = new Uint8ClampedArray(numPixels * 4);

    for (let i = 0; i < numPixels; ++i) {
      let r: number;
      let g: number;
      let b: number;
      // Fully opaque unless the tensor provides its own alpha channel.
      let a = 255;
      if (depth === 1) {
        // Grayscale: replicate the single channel into r, g and b.
        r = data[i] * multiplier;
        g = data[i] * multiplier;
        b = data[i] * multiplier;
      } else if (depth === 3) {
        r = data[i * 3] * multiplier;
        g = data[i * 3 + 1] * multiplier;
        b = data[i * 3 + 2] * multiplier;
      } else {
        // depth === 4, guaranteed by the validation above.
        r = data[i * 4] * multiplier;
        g = data[i * 4 + 1] * multiplier;
        b = data[i * 4 + 2] * multiplier;
        a = data[i * 4 + 3] * multiplier;
      }

      const j = i * 4;
      bytes[j + 0] = Math.round(r);
      bytes[j + 1] = Math.round(g);
      bytes[j + 2] = Math.round(b);
      bytes[j + 3] = Math.round(a);
    }

    if (canvas != null) {
      canvas.width = width;
      canvas.height = height;
      const ctx = canvas.getContext('2d');
      const imageData = new ImageData(bytes, width, height);
      ctx.putImageData(imageData, 0, 0);
    }
    return bytes;
  } finally {
    // Only dispose tensors created here; never the caller's own tensor.
    if ($img !== img) {
      $img.dispose();
    }
  }
}

// Public entry point: `fromPixels_` wrapped by `op()`, registered under the
// key 'fromPixels_'.
export const fromPixels = op({fromPixels_: fromPixels_});
