/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {getChannels} from '../packing_util';

import {GPGPUProgram} from './gpgpu_math';
import {getCoordsDataType} from './shader_compiler';

/**
 * Shader program that reads the input `A` and writes it to a packed output
 * (`packedInputs` is false, `packedOutput` is true).
 *
 * For rank >= 1 each output vec4 is filled from up to four reads of `A`
 * (see `getOutput`); reads past the row/column edge are replaced by zero and
 * output coordinates outside the shape produce an all-zero vec4. A rank-0
 * output stores the single value in the first channel.
 */
export class PackProgram implements GPGPUProgram {
  variableNames = ['A'];
  outputShape: number[];
  userCode: string;
  packedInputs = false;
  packedOutput = true;

  /**
   * @param outputShape Shape of the packed output; its length selects the
   *     coordinate type and which shader variant is generated.
   */
  constructor(outputShape: number[]) {
    // TODO(https://github.com/tensorflow/tfjs/issues/893):
    // Only input / output 3D tensors.
    this.outputShape = outputShape;
    const rank = outputShape.length;

    if (rank === 0) {
      // Scalar: the value goes in the first channel, the rest are zero.
      this.userCode = `
        void main() {
          setOutput(vec4(getA(), 0., 0., 0.));
        }
      `;
      return;
    }

    const coordNames = getChannels('rc', rank);
    const coordsType = getCoordsDataType(rank);
    const isOutOfBounds =
        getOutOfBoundsCondition(rank, outputShape, coordNames);
    // The two innermost dimensions are passed as (cols, rows).
    const edgeSetup = getSetup(
        rank, outputShape[rank - 1], outputShape[rank - 2], coordNames);
    const texelValues = getOutput(outputShape, coordNames);

    this.userCode = `
        void main() {
          ${coordsType} rc = getOutputCoords();

          if(${isOutOfBounds}) {
            setOutput(vec4(0));
          } else {
            ${edgeSetup}

            setOutput(vec4(${texelValues}));
          }
        }
      `;
  }
}

/**
 * Builds the four GLSL argument lists used to read the source values that
 * make up one packed output texel.
 *
 * The entries are ordered (r, c), (r, cp1), (rp1, c), (rp1, cp1). For
 * rank > 2, each entry is prefixed with the outer coordinate names taken
 * from `dims`, in their original order.
 *
 * @param rank Rank of the tensor being packed.
 * @param dims Names of the output coordinate components.
 * @returns Four comma-separated coordinate strings.
 */
function getSourceCoordsArr(rank: number, dims: string[]): string[] {
  // The outer-dimension prefix is the same for all four entries, so build it
  // once. Each step prepends, which leaves the names in ascending order.
  let outerPrefix = '';
  for (let d = 2; d < rank; d++) {
    outerPrefix = `${dims[dims.length - 1 - d]},${outerPrefix}`;
  }

  const sourceCoords: string[] = [];
  for (const rowName of ['r', 'rp1']) {
    for (const colName of ['c', 'cp1']) {
      sourceCoords.push(`${outerPrefix}${rowName}, ${colName}`);
    }
  }
  return sourceCoords;
}

/**
 * Builds the GLSL boolean expression that is true when the current output
 * coordinate lies outside the tensor's shape.
 *
 * For rank 1 the coordinate is the scalar `rc` and `dims` is not consulted.
 * For higher ranks only the two innermost dimensions are tested, joined
 * with `||`.
 *
 * @param rank Rank of the tensor being packed.
 * @param shape Shape of the tensor being packed.
 * @param dims Names of the output coordinate components.
 * @returns A GLSL expression suitable for use inside `if (...)`.
 */
function getOutOfBoundsCondition(
    rank: number, shape: number[], dims: string[]): string {
  if (rank === 1) {
    // Valid indices are 0..shape[0]-1, so rc == shape[0] is already out of
    // bounds. Using `>=` matches the check applied to higher ranks below.
    return `rc >= ${shape[0]}`;
  }

  const innerChecks: string[] = [];
  for (let i = rank - 2; i < rank; i++) {
    innerChecks.push(`${dims[i]} >= ${shape[i]}`);
  }
  return innerChecks.join('||');
}

/**
 * Emits the GLSL declarations shared by the four source reads: the row and
 * column coordinate (`r`, `c`), their successors (`rp1`, `cp1`), and the
 * flags `rEdge` / `cEdge` that mark a successor as past the last row/column.
 *
 * @param rank Rank of the tensor being packed; rank 1 needs no setup.
 * @param cols Size of the innermost dimension.
 * @param rows Size of the second-innermost dimension.
 * @param dims Names of the output coordinate components; the last two are
 *     used as row and column.
 * @returns GLSL source for the declarations, or an empty string for rank 1.
 */
function getSetup(
    rank: number, cols: number, rows: number, dims: string[]): string {
  if (rank === 1) {
    return '';
  }

  const [rowCoord, colCoord] = dims.slice(-2);

  // The leading empty entry and the trailing indentation reproduce the
  // surrounding whitespace of the generated snippet exactly.
  const lines = [
    '',
    `    int r = ${rowCoord};`,
    `    int c = ${colCoord};`,
    '    int rp1 = r + 1;',
    '    int cp1 = c + 1;',
    '',
    `    bool cEdge = cp1 >= ${cols};`,
    `    bool rEdge = rp1 >= ${rows};`,
    '  ',
  ];
  return lines.join('\n');
}

/**
 * Builds the comma-separated argument list for the `vec4(...)` that is
 * written as one packed output texel.
 *
 * For rank 1 the texel holds `getA(rc)` and `getA(rc + 1)` (zero when
 * `rc + 1` is past the end), followed by two zeros. For higher ranks it holds
 * the four reads at (r, c), (r, cp1), (rp1, c), (rp1, cp1), with any read
 * that crosses the row or column edge replaced by zero via `rEdge` / `cEdge`.
 *
 * @param shape Shape of the tensor being packed.
 * @param dims Names of the output coordinate components.
 * @returns GLSL source for the four vec4 components.
 */
function getOutput(shape: number[], dims: string[]): string {
  const [topLeft, topRight, bottomLeft, bottomRight] =
      getSourceCoordsArr(shape.length, dims);

  if (shape.length === 1) {
    return [
      'getA(rc),',
      `            rc + 1 >= ${shape[0]} ? 0. : getA(rc + 1),`,
      '            0, 0',
    ].join('\n');
  }

  return [
    `getA(${topLeft}),`,
    `          cEdge ? 0. : getA(${topRight}),`,
    `          rEdge ? 0. : getA(${bottomLeft}),`,
    `          rEdge || cEdge ? 0. : getA(${bottomRight})`,
  ].join('\n');
}
