/**
 * @license
 * Copyright 2021 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

// sampled from [@tensorflow/tfjs] tfjs-backend-webgpu/src/conv_backprop_webgpu.ts

import {DataType} from '../../../../wasm-common';
import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
import {ShapeUtil} from '../../../util';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common';
import {ConvTransposeAttributes} from '../conv-transpose';

/**
 * Generates the WGSL compute shader for a 2-D ConvTranspose.
 *
 * Two code paths are emitted depending on `isVec4`:
 *  - `codeSnippet4` (vec4): each invocation accumulates 4 output channels for `workPerThread` (= 2) adjacent
 *    output columns. It reads `Dy_shape[1]`/`[2]`/`[3]` as rows/cols/channels, i.e. it appears to assume
 *    NHWC regardless of `isChannelsLast`.
 *  - `codeSnippet` (scalar): each invocation computes the single output element at `global_idx`, with support
 *    for groups, dilations and both NCHW/NHWC layouts.
 *
 * Both paths compute the transposed convolution as a direct convolution over `dy` with a spatially flipped
 * kernel: taps whose `(corner + w) / stride` position is fractional or out of range are skipped.
 *
 * @param shaderHelper helper used to register uniforms, declare bindings and emit the entry point.
 * @param inputs `[dy, w]` or `[dy, w, bias]`; only each tensor's `dataType` and `dims.length` are read here.
 * @param outputShape shape of the `result` output; used for the output binding's rank.
 * @param hasBias whether a third `bias` binding is declared and added to the result.
 * @param is1DimensionDispatch vec4 path only: index by `global_id` instead of `workgroup_id`.
 * @param isVec4 selects the vec4 path (4-component bindings) over the scalar path.
 * @param dataType WGSL scalar type name used in casts and zero literals.
 * @param uniforms uniform declarations; expected to line up with the program uniforms supplied at run time.
 * @param isChannelsLast true for NHWC, false for NCHW; picks the H/W/C axis indices.
 * @returns the complete WGSL shader source.
 */
const createConvTranspose2DOpProgramShaderSource =
    (shaderHelper: ShaderHelper, inputs: readonly TensorView[], outputShape: readonly number[], hasBias: boolean,
     is1DimensionDispatch: boolean, isVec4 = false, dataType: string, uniforms: UniformsArrayType,
     isChannelsLast = false): string => {
      // Axis positions of H, W and C inside the 4-D dy/output index tuples for the selected layout.
      const rowDim = isChannelsLast ? 1 : 2;
      const colDim = isChannelsLast ? 2 : 3;
      const channelDim = isChannelsLast ? 3 : 1;
      // The vec4 path produces this many adjacent output columns per invocation (dotProd[0] / dotProd[1]).
      const workPerThread = isVec4 ? 2 : 1;

      // WGSL helper functions. NOTE(review): neither setOutputAtIndex nor getBiasByOutputCoords is called by
      // the snippets below (they use output.set/setByOffset and index `bias[...]` directly).
      let declareFunctions = `
  fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? `vec4<${dataType}>` : dataType}) {
    result[flatIndex] = ${isVec4 ? `vec4<${dataType}>` : dataType}(value);
  }`;
      if (hasBias) {
        declareFunctions += `
    fn getBiasByOutputCoords(coords : vec4<u32>) -> ${isVec4 ? `vec4<${dataType}>` : dataType} {
      return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
    }`;
      }
      // Shader bindings: dy is binding 0, w binding 1, optional bias binding 2, then the output.
      const components = isVec4 ? 4 : 1;
      const w = inputVariable('W', inputs[1].dataType, inputs[1].dims.length, components);
      const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims.length, components);
      const inputVariables = [dy, w];
      if (hasBias) {
        // `[x].length` is always 1, so the bias binding is declared as rank 1.
        inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]].length, components));
      }
      // The output takes its element type from dy (inputs[0]).
      const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);

      // vec4 path: computes output channels d1..d1+3 for columns c and c+1 of row r. Channels are read from
      // Dy_shape[3] (or Dy_shape[channelDim]) in 4-wide steps.
      // NOTE(review): the epilogue adds `bias[c+i]`, which indexes bias by output column rather than by
      // channel. It also passes `d1` unscaled to a 4-component output. This looks wrong, so confirm it
      // before enabling isVec4.
      // NOTE(review): `wRPerm < 0` / `wCPerm < 0` compare u32 values and are always false.
      const codeSnippet4 = `{
        let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / uniforms.result_shape[1];
        let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % uniforms.result_shape[1];
        let c = ${is1DimensionDispatch ? 'global_id.y' : 'workgroup_id.y'} * ${workPerThread};
        let d1: u32 = ${is1DimensionDispatch ? 'global_id.x' : 'workgroup_id.x'} * 4;

        let dyCorner = vec2<i32>(i32(r), i32(c)) - vec2<i32>(uniforms.pads);

        // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).
        // ? = to be determined. : = across all values in that axis.
        var dotProd: array<vec4<${dataType}>, ${workPerThread}>;
        for (var i = 0; i < ${workPerThread}; i++) {
          dotProd[i] = vec4<${dataType}>(0.0);
        }
        for (var wR: u32 = 0; wR < uniforms.filter_dims[0]; wR = wR + 1) {
          var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(uniforms.strides.x);
          let wRPerm = uniforms.filter_dims[0] - 1 - wR;
          if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[1]) ||
              fract(dyR) > 0.0 || wRPerm < 0) {
            continue;
          }
          let idyR: u32 = u32(dyR);

          for (var wC: u32 = 0; wC < uniforms.filter_dims[1]; wC = wC + 1) {
            let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y);
            let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(uniforms.strides.y);
            let wCPerm = uniforms.filter_dims[1] - 1 - wC;
            if (wCPerm < 0) {
              continue;
            }
            var bDyCVal = true;
            var bDyCVal2 = true;
            if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[2]) ||
                fract(dyC) > 0.0) {
              bDyCVal = false;
            }
            if (dyC2 < 0.0 || dyC2 >= ${dataType}(uniforms.Dy_shape[2]) ||
                fract(dyC2) > 0.0) {
              bDyCVal2 = false;
            }

            let idyC: u32 = u32(dyC);
            let idyC2: u32 = u32(dyC2);
            if (bDyCVal && bDyCVal2) {
              let d2Length = uniforms.Dy_shape[3];
              for (var d2 :u32 = 0; d2 < d2Length; d2 = d2 + 4) {
                let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')};
                let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')};
                let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')};
                let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')};

                var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')};
                let tmpval = vec4<${dataType}>(dot(xValue, wValue0),
                                      dot(xValue, wValue1),
                                      dot(xValue, wValue2),
                                      dot(xValue, wValue3));
                dotProd[0] = dotProd[0] + tmpval;

                xValue =  ${dy.get('batch', 'idyR', 'idyC2', 'd2')};

                dotProd[1] = dotProd[1] + vec4<${dataType}>(dot(xValue, wValue0),
                                                    dot(xValue, wValue1),
                                                    dot(xValue, wValue2),
                                                    dot(xValue, wValue3));
              }
            } else if (bDyCVal) {
              let d2Length = uniforms.Dy_shape[${channelDim}];
              for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) {
                let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')};
                let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')};
                let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')};
                let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')};

                var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')};
                let tmpval = vec4<${dataType}>(dot(xValue, wValue0),
                                      dot(xValue, wValue1),
                                      dot(xValue, wValue2),
                                      dot(xValue, wValue3));
                dotProd[0] = dotProd[0] + tmpval;
              }
            } else if (bDyCVal2) {
              let d2Length = uniforms.Dy_shape[3];
              for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) {
                let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')};
                let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')};
                let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')};
                let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')};

                var xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')};
                let tmpval = vec4<${dataType}>(dot(xValue, wValue0),
                                      dot(xValue, wValue1),
                                      dot(xValue, wValue2),
                                      dot(xValue, wValue3));
                dotProd[1] = dotProd[1] + tmpval;
              }
            }
          }
        }

        for (var i: u32 = 0; i < ${workPerThread}; i = i + 1) {
          let value = dotProd[i] + ${hasBias ? 'bias[c+i]' : `vec4<${dataType}>(0.0)`};
          ${output.set('batch', 'r', 'c + i', 'd1', 'value')};
        }
      }`;
      // Scalar path: one output element per invocation, located from global_idx.
      // - The wR/wC loops walk the dilated kernel extent and skip positions that fall between dilated taps.
      // - wRPerm/wCPerm flip the kernel spatially.
      // - A fractional dy coordinate means that tap lands between strided dy samples, so it is skipped.
      // - For grouped convolution, channel d1 belongs to group d1 / output_channels_per_group. It sums over
      //   that group's input_channels_per_group dy channels, and the weight is read as
      //   w[inputChannel, wOutChannel, kR, kC].
      // NOTE(review): `wRPerm < 0` / `wCPerm < 0` compare u32 values and are always false.
      const codeSnippet = `
          let outputIndices = ${output.offsetToIndices('global_idx')};
          let batch = ${output.indicesGet('outputIndices', 0)};
          let d1 = ${output.indicesGet('outputIndices', channelDim)};
          let r = ${output.indicesGet('outputIndices', rowDim)};
          let c = ${output.indicesGet('outputIndices', colDim)};
          let dyCorner = vec2<i32>(i32(r), i32(c)) - uniforms.pads;
          let dyRCorner = dyCorner.x;
          let dyCCorner = dyCorner.y;
          let groupId = d1 / uniforms.output_channels_per_group;
          let wOutChannel = d1 - groupId * uniforms.output_channels_per_group;
          // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).
          // ? = to be determined. : = across all values in that axis.
          var dotProd = ${dataType}(0.0);
          for (var wR: u32 = 0; wR < uniforms.effective_filter_dims.x; wR = wR + 1) {
            if (wR % uniforms.dilations.x != 0) {
              continue;
            }
            let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(uniforms.strides[0]);
            let wRPerm = uniforms.filter_dims.x - 1 - wR / uniforms.dilations.x;
            if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[${rowDim}]) || fract(dyR) > 0.0 ||
                wRPerm < 0) {
              continue;
            }
            let idyR: u32 = u32(dyR);

            for (var wC: u32 = 0; wC < uniforms.effective_filter_dims.y; wC = wC + 1) {
              if (wC % uniforms.dilations.y != 0) {
                continue;
              }
              let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y);
              let wCPerm = uniforms.filter_dims.y - 1 - wC / uniforms.dilations.y;
              if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[${colDim}]) ||
                  fract(dyC) > 0.0 || wCPerm < 0) {
                continue;
              }
              let idyC: u32 = u32(dyC);
              var inputChannel = groupId * uniforms.input_channels_per_group;
              for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + 1) {
                let xValue = ${
          isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') :
                           dy.get('batch', 'inputChannel', 'idyR', 'idyC')};
                let wValue = ${w.get('inputChannel', 'wOutChannel', 'u32(wRPerm)', 'u32(wCPerm)')};
                dotProd = dotProd + xValue * wValue;
                inputChannel = inputChannel + 1;
              }
            }
          }
          let value = dotProd + ${hasBias ? 'bias[d1]' : `${dataType}(0.0)`};
          ${output.setByOffset('global_idx', 'value')};
        `;

      // Final shader: uniform and binding declarations, helper functions, then the entry point. The entry
      // point returns early for invocations beyond output_size and then runs the selected snippet.
      return `
  ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
  ${declareFunctions}

    ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')};
  ${isVec4 ? codeSnippet4 : codeSnippet}}`;
    };

/**
 * Builds the WebGPU `ProgramInfo` for a 2-D ConvTranspose (transposed convolution).
 *
 * The vec4 shader path is currently disabled. Every invocation therefore computes one output element, and
 * the dispatch is a 1-D grid of `ceil(outputSize / 64)` workgroups.
 *
 * @param inputs `[dy, w]` or `[dy, w, bias]`. The shader reads `w` as
 *     `[inChannels, outChannelsPerGroup, kH, kW]`.
 * @param attributes ConvTranspose attributes; `outputShape`, `kernelShape`, `pads`, `strides`, `dilations`,
 *     `group` and `format` are read here.
 * @param squeezeOutputShapeFunction optional mapping applied to the reported output dims. The shader itself
 *     always works on the full 4-D `outputShape`.
 * @returns program info whose uniform list matches the uniform declarations in the generated shader.
 */
export const createConvTranspose2DProgramInfo =
    (inputs: readonly TensorView[], attributes: ConvTransposeAttributes,
     squeezeOutputShapeFunction?: (shape: readonly number[]) => number[]): ProgramInfo => {
      const hasBias = inputs.length > 2;
      const isChannelsLast = attributes.format === 'NHWC';
      const outputShape = attributes.outputShape;
      const outputSize = ShapeUtil.size(outputShape);

      // TODO Enable isVec4 for performance.
      // Disabled due to a weight matrix layout issue. Intended condition once fixed:
      //   group === 1 && isChannelsLast && inChannels % 4 === 0 && outChannels % 4 === 0
      const isVec4 = false;

      // One output element per invocation, 64 invocations per workgroup.
      const dispatch = [
        Math.ceil(outputSize / 64),
        1,
        1,
      ];
      LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`);

      const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
      const strides = [attributes.strides[0], attributes.strides[1]];
      const filterDims =
          [attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]];
      const dilations = [attributes.dilations[0], attributes.dilations[1]];
      // Spatial extent of the dilated kernel: k + (k - 1) * (d - 1).
      const effectiveFilterDims = [
        filterDims[0] + (dilations[0] <= 1 ? 0 : (filterDims[0] - 1) * (dilations[0] - 1)),
        filterDims[1] + (dilations[1] <= 1 ? 0 : (filterDims[1] - 1) * (dilations[1] - 1)),
      ];
      // Padding of the equivalent direct convolution over dy. `attributes.pads` is
      // [hBegin, wBegin, hEnd, wEnd]. The halving must happen inside Math.floor for both axes. Flooring only
      // the sum gives a fractional pad when the total is odd, and the int32 uniform would then truncate it.
      const pads = [
        effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2),
        effectiveFilterDims[1] - 1 - Math.floor((attributes.pads[1] + attributes.pads[3]) / 2),
      ];

      const group = attributes.group;
      const wShape = inputs[1].dims;
      const inputChannelsPerGroup = wShape[0] / group;
      const outputChannelsPerGroup = wShape[1];

      // Order and types must match the `uniforms` declarations in getShaderSource below.
      const programUniforms: ProgramUniform[] = [
        {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: strides},
        {type: DataType.uint32, data: filterDims}, {type: DataType.uint32, data: dilations},
        {type: DataType.uint32, data: effectiveFilterDims}, {type: DataType.int32, data: pads},
        {type: DataType.uint32, data: inputChannelsPerGroup}, {type: DataType.uint32, data: outputChannelsPerGroup},
        ...createTensorShapeVariables(inputs[0].dims, inputs[1].dims)
      ];
      if (hasBias) {
        programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
        inputDependencies.push('rank');
      }
      programUniforms.push(...createTensorShapeVariables(outputShape));

      const is1DimensionDispatch = dispatch[1] === 1 && dispatch[2] === 1;
      const getShaderSource = (shaderHelper: ShaderHelper) => {
        const uniforms: UniformsArrayType = [
          {name: 'output_size', type: 'u32'}, {name: 'strides', type: 'u32', length: strides.length},
          {name: 'filter_dims', type: 'u32', length: filterDims.length},
          {name: 'dilations', type: 'u32', length: dilations.length},
          {name: 'effective_filter_dims', type: 'u32', length: effectiveFilterDims.length},
          {name: 'pads', type: 'i32', length: pads.length}, {name: 'input_channels_per_group', type: 'u32'},
          {name: 'output_channels_per_group', type: 'u32'}
        ];
        const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
        return `${
            createConvTranspose2DOpProgramShaderSource(
                shaderHelper, inputs, outputShape, hasBias, is1DimensionDispatch, isVec4, dataType, uniforms,
                isChannelsLast)}`;
      };
      return {
        name: 'ConvTranspose2D',
        shaderCache: {hint: `${attributes.cacheKey};`, inputDependencies},
        getRunData: () => ({
          dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
          outputs: [{
            dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape,
            dataType: inputs[0].dataType
          }],
          programUniforms
        }),
        getShaderSource
      };
    };
