/**
 * @license
 * Copyright 2017 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {Tensor} from './tensor';
import {NamedTensorMap} from './tensor_types';
import * as util from './util';

/**
 * A single recorded operation on the gradient tape: the tensors it consumed,
 * the tensors it produced, and (optionally) how to backpropagate through it.
 */
export interface TapeNode {
  // Identifier of the node; used as the lookup key when filtering the tape.
  id: number;
  // Name of the recorded op; reported in gradient-related error messages.
  name: string;
  // Tensors produced by the op.
  outputs: Tensor[];
  // Tensors consumed by the op, keyed by input name.
  inputs: NamedTensorMap;
  // Optional params, defined only for ops with gradient impl.
  // Given dy (a single Tensor for single-output ops, an array of Tensors for
  // multi-output ops), returns one lazy gradient function per input name.
  gradient?: (dy: Tensor|Tensor[]) => NamedGradientMap;
  // NOTE(review): not read anywhere in this file; presumably tensors retained
  // for use by `gradient` -- confirm against the code that records the tape.
  saved?: Tensor[];
}

/**
 * Maps an input name to a zero-argument function that returns the gradient
 * with respect to that input. The functions are lazy: backpropagation only
 * invokes the entries whose input names are present on the (pruned) tape
 * node, so gradients for inputs that are not needed are never computed.
 */
export type NamedGradientMap = {
  [inputName: string]: () => Tensor;
};

/**
 * Computes the list of TapeNodes that connect xs to y, filtering everything
 * else out and preserving the order of the original tape elements.
 *
 * A node is kept only when at least one of its inputs is (transitively) a
 * function of some x AND at least one of its outputs (transitively) leads to
 * y. Each kept node is returned as a shallow copy whose `inputs` hold only the
 * inputs that depend on xs; the nodes in `tape` themselves are not modified.
 *
 * @param tape The tape elements to filter, in the order they were recorded.
 * @param xs The input Tensors.
 * @param y The output Tensor.
 * @returns Pruned copies of the nodes lying on a path from xs to y, in tape
 *     order. Empty when no such path exists (e.g. when `xs` is empty).
 */
export function getFilteredNodesXToY(
    tape: TapeNode[], xs: Tensor[], y: Tensor): TapeNode[] {
  // Forward pass to compute all the nodes and Tensors that are transitively a
  // function of x.
  const tensorsFromX: {[tensorId: number]: boolean} = {};
  const nodesFromX: {[nodeId: number]: boolean} = {};
  for (const x of xs) {
    tensorsFromX[x.id] = true;
  }

  for (const node of tape) {
    for (const inputName in node.inputs) {
      if (tensorsFromX[node.inputs[inputName].id]) {
        // A single dependent input is enough: the node, and therefore every
        // one of its outputs, is a function of x.
        nodesFromX[node.id] = true;
        for (const output of node.outputs) {
          tensorsFromX[output.id] = true;
        }
        break;
      }
    }
  }

  // Backward pass to find all of the nodes and Tensors that lead to y.
  const tensorsLeadToY: {[tensorId: number]: boolean} = {};
  tensorsLeadToY[y.id] = true;
  const nodesToY: {[nodeId: number]: boolean} = {};

  for (let i = tape.length - 1; i >= 0; i--) {
    const node = tape[i];

    let anyOutputLeadsToY = false;
    for (const output of node.outputs) {
      if (tensorsLeadToY[output.id]) {
        anyOutputLeadsToY = true;
        break;
      }
    }
    if (!anyOutputLeadsToY) {
      continue;
    }

    // If any of the outputs lead to y, mark all of the inputs as leading to y.
    nodesToY[node.id] = true;
    for (const inputName in node.inputs) {
      tensorsLeadToY[node.inputs[inputName].id] = true;
    }
  }

  // Return the paths that come from x and lead to y.
  const filteredTape: TapeNode[] = [];
  for (const node of tape) {
    if (!nodesFromX[node.id] || !nodesToY[node.id]) {
      continue;
    }

    // Prune the inputs from the node that aren't a function of x.
    const prunedInputs: {[inputName: string]: Tensor} = {};
    for (const inputName in node.inputs) {
      const nodeInput = node.inputs[inputName];
      if (tensorsFromX[nodeInput.id]) {
        prunedInputs[inputName] = nodeInput;
      }
    }

    // Shallow-copy the node (so the caller's tape is left untouched) and
    // replace its inputs with the pruned version. All other fields, including
    // `outputs`, are carried over by the copy.
    const prunedNode = Object.assign({}, node) as TapeNode;
    prunedNode.inputs = prunedInputs;

    filteredTape.push(prunedNode);
  }

  return filteredTape;
}

/**
 * Backpropagate gradients through the filtered TapeNodes.
 *
 * The tape is walked from its last node to its first. For every node the
 * gradients already accumulated for its outputs are gathered (zeros are used
 * for outputs that have none), passed to the node's gradient function, and the
 * resulting per-input gradients are added into the accumulation map.
 *
 * @param tensorAccumulatedGradientMap A map of Tensor id to the gradient
 *     accumulated for that Tensor. This map is mutated by this method.
 * @param filteredTape The filtered TapeNodes to backprop through.
 * @param tidy Wrapper that each per-input gradient function is invoked
 *     through; the Tensor it returns is taken as that input's gradient.
 * @throws Error if a node has no gradient function, if the gradient function
 *     does not provide an entry for one of the node's inputs, or if a computed
 *     gradient is not 'float32' or does not match the shape of its input.
 */
export function backpropagateGradients(
    tensorAccumulatedGradientMap: {[tensorId: number]: Tensor},
    filteredTape: TapeNode[], tidy: (f: Function) => Tensor) {
  // Walk the tape backward and keep a map of Tensor to its gradient.
  for (let i = filteredTape.length - 1; i >= 0; i--) {
    const node = filteredTape[i];

    // Fail before building dys, so that no zero-filled tensors are allocated
    // for a node that cannot be differentiated anyway.
    if (node.gradient == null) {
      throw new Error(
          `Cannot compute gradient: gradient function not found ` +
          `for ${node.name}.`);
    }

    const dys: Tensor[] = [];
    for (const o of node.outputs) {
      const gradTensor = tensorAccumulatedGradientMap[o.id];
      if (gradTensor != null) {
        dys.push(gradTensor);
      } else {
        // This particular output is not in the back-propagation subgraph, so it
        // does not affect the final output, thus we put zeros for its dy.
        const dy = Tensor.make(
            o.shape, {values: util.makeZerosTypedArray(o.size, o.dtype)},
            o.dtype);
        dys.push(dy);
      }
    }

    // Backprop dy through this node and accumulate gradients over the inputs.
    const inputGradients =
        // Grad functions of ops with single outputs expect a dy, while ops
        // with multiple outputs expect dys (array of dy).
        node.gradient(node.outputs.length === 1 ? dys[0] : dys);
    for (const inputName in node.inputs) {
      if (!(inputName in inputGradients)) {
        throw new Error(
            `Cannot backprop through input ${inputName}. ` +
            `Available gradients found: ${Object.keys(inputGradients)}.`);
      }

      // Call the gradient function. Only inputs present on the node are
      // evaluated; other entries of inputGradients are never invoked.
      const dx = tidy(() => inputGradients[inputName]());
      if (dx.dtype !== 'float32') {
        throw new Error(
            `Error in gradient for op ${node.name}. The gradient of input ` +
            `${inputName} must have 'float32' dtype, but has '${dx.dtype}'`);
      }
      const x = node.inputs[inputName];
      if (!util.arraysEqual(dx.shape, x.shape)) {
        throw new Error(
            `Error in gradient for op ${node.name}. The gradient of input ` +
            `'${inputName}' has shape '${dx.shape}', which does not match ` +
            `the shape of the input '${x.shape}'`);
      }

      if (tensorAccumulatedGradientMap[x.id] == null) {
        tensorAccumulatedGradientMap[x.id] = dx;
      } else {
        // The input feeds more than one node: sum the contributions and
        // release the previous partial sum, which the new sum replaces.
        const curGradient = tensorAccumulatedGradientMap[x.id];
        tensorAccumulatedGradientMap[x.id] = curGradient.add(dx);
        curGradient.dispose();
      }
    }
  }
}
