/**
 * @license
 * Copyright 2018 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {whereImpl} from '../backends/where_impl';
import {ENGINE} from '../engine';
import {Tensor, Tensor2D} from '../tensor';
import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';
import {assert, assertShapesMatch} from '../util';
import {assertAndGetBroadcastShape} from './broadcast_util';
import {op} from './operation';
import {zerosLike} from './tensor_ops';

/**
 * Returns the truth value of `NOT x` element-wise.
 *
 * ```js
 * const a = tf.tensor1d([false, true], 'bool');
 *
 * a.logicalNot().print();
 * ```
 *
 * @param x The input tensor. Must be of dtype 'bool'.
 */
/** @doc {heading: 'Operations', subheading: 'Logical'} */
function logicalNot_<T extends Tensor>(x: T|TensorLike): T {
  const $x = convertToTensor(x, 'x', 'logicalNot', 'bool');
  return ENGINE.runKernel(backend => backend.logicalNot($x), {$x});
}

/**
 * Returns the truth value of `a AND b` element-wise. Supports broadcasting.
 *
 * ```js
 * const a = tf.tensor1d([false, false, true, true], 'bool');
 * const b = tf.tensor1d([false, true, false, true], 'bool');
 *
 * a.logicalAnd(b).print();
 * ```
 *
 * @param a The first input tensor. Must be of dtype bool.
 * @param b The second input tensor. Must be of dtype bool.
 */
/** @doc {heading: 'Operations', subheading: 'Logical'} */
function logicalAnd_<T extends Tensor>(
    a: Tensor|TensorLike, b: Tensor|TensorLike): T {
  const $a = convertToTensor(a, 'a', 'logicalAnd', 'bool');
  const $b = convertToTensor(b, 'b', 'logicalAnd', 'bool');
  assertAndGetBroadcastShape($a.shape, $b.shape);

  return ENGINE.runKernel(backend => backend.logicalAnd($a, $b), {$a, $b}) as T;
}

/**
 * Returns the truth value of `a OR b` element-wise. Supports broadcasting.
 *
 * ```js
 * const a = tf.tensor1d([false, false, true, true], 'bool');
 * const b = tf.tensor1d([false, true, false, true], 'bool');
 *
 * a.logicalOr(b).print();
 * ```
 * @param a The first input tensor. Must be of dtype bool.
 * @param b The second input tensor. Must be of dtype bool.
 */
/** @doc {heading: 'Operations', subheading: 'Logical'} */
function logicalOr_<T extends Tensor>(
    a: Tensor|TensorLike, b: Tensor|TensorLike): T {
  const $a = convertToTensor(a, 'a', 'logicalOr', 'bool');
  const $b = convertToTensor(b, 'b', 'logicalOr', 'bool');
  assertAndGetBroadcastShape($a.shape, $b.shape);

  return ENGINE.runKernel(backend => backend.logicalOr($a, $b), {$a, $b}) as T;
}

/**
 * Returns the truth value of `a XOR b` element-wise. Supports broadcasting.
 *
 * ```js
 * const a = tf.tensor1d([false, false, true, true], 'bool');
 * const b = tf.tensor1d([false, true, false, true], 'bool');
 *
 * a.logicalXor(b).print();
 * ```
 *
 * @param a The first input tensor. Must be of dtype bool.
 * @param b The second input tensor. Must be of dtype bool.
 */
/** @doc {heading: 'Operations', subheading: 'Logical'} */
function logicalXor_<T extends Tensor>(
    a: Tensor|TensorLike, b: Tensor|TensorLike): T {
  const $a = convertToTensor(a, 'a', 'logicalXor', 'bool');
  const $b = convertToTensor(b, 'b', 'logicalXor', 'bool');
  assertAndGetBroadcastShape($a.shape, $b.shape);

  // x ^ y = (x | y) & ~(x & y)
  return logicalOr(a, b).logicalAnd(logicalAnd(a, b).logicalNot()) as T;
}

/**
 * Returns the elements, either `a` or `b` depending on the `condition`.
 *
 * If the condition is true, select from `a`, otherwise select from `b`.
 *
 * ```js
 * const cond = tf.tensor1d([false, false, true], 'bool');
 * const a = tf.tensor1d([1 , 2, 3]);
 * const b = tf.tensor1d([-1, -2, -3]);
 *
 * a.where(cond, b).print();
 * ```
 *
 * @param condition The input condition. Must be of dtype bool.
 * @param a If `condition` is rank 1, `a` may have a higher rank but
 *     its first dimension must match the size of `condition`.
 * @param b A tensor with the same shape and type as `a`.
 */
/** @doc {heading: 'Operations', subheading: 'Logical'} */
function where_<T extends Tensor>(
    condition: Tensor|TensorLike, a: T|TensorLike, b: T|TensorLike): T {
  const $a = convertToTensor(a, 'a', 'where');
  const $b = convertToTensor(b, 'b', 'where');
  const $condition = convertToTensor(condition, 'condition', 'where', 'bool');

  assertShapesMatch($a.shape, $b.shape, 'Error in where: ');

  if ($condition.rank === 1) {
    // If condition rank is 1, then the first dimension must match the size of
    // condition.
    assert(
        $condition.shape[0] === $a.shape[0],
        () => 'The first dimension of `a` must match the size of `condition`.');
  } else {
    // A must have the same shape as condition.
    assertShapesMatch($condition.shape, $b.shape, 'Error in where: ');
  }

  // TODO(julianoks): Return null for condition gradient
  // when backprop supports it.
  const grad = (dy: T, saved: Tensor[]) => {
    const [$condition] = saved;
    return {
      $condition: () => zerosLike($condition).toFloat(),
      $a: () => dy.mul($condition.cast(dy.dtype)) as T,
      $b: () => dy.mul($condition.logicalNot().cast(dy.dtype)) as T
    };
  };

  return ENGINE.runKernel((backend, save) => {
    const res = backend.select($condition, $a, $b);
    save([$condition]);
    return res;
  }, {$condition, $a, $b}, grad) as T;
}

/**
 * Returns the coordinates of true elements of condition.
 *
 * The coordinates are returned in a 2-D tensor where the first dimension (rows)
 * represents the number of true elements, and the second dimension (columns)
 * represents the coordinates of the true elements. Keep in mind, the shape of
 * the output tensor can vary depending on how many true values there are in
 * input. Indices are output in row-major order. The resulting tensor has the
 * shape `[numTrueElems, condition.rank]`.
 *
 * This is analogous to calling the python `tf.where(cond)` without an x or y.
 *
 * ```js
 * const cond = tf.tensor1d([false, false, true], 'bool');
 * const result = await tf.whereAsync(cond);
 * result.print();
 * ```
 */
/** @doc {heading: 'Operations', subheading: 'Logical'} */
async function whereAsync_(condition: Tensor|TensorLike): Promise<Tensor2D> {
  const $condition =
      convertToTensor(condition, 'condition', 'whereAsync', 'bool');
  const vals = await $condition.data();
  const res = whereImpl($condition.shape, vals);
  if (condition !== $condition) {
    $condition.dispose();
  }
  return res;
}

export const logicalAnd = op({logicalAnd_});
export const logicalNot = op({logicalNot_});
export const logicalOr = op({logicalOr_});
export const logicalXor = op({logicalXor_});
export const where = op({where_});
export const whereAsync = whereAsync_;
