// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';

import {createTensorShapeVariables, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common';

export interface InstanceNormAttributes {
  epsilon: number;
  format: 'NHWC'|'NCHW';
}

const createInstanceNormProgramInfo =
    (inputs: readonly TensorView[], attributes: InstanceNormAttributes): ProgramInfo => {
      const xShape = inputs[0].dims;
      const outputShape = xShape;
      const axis = 2;
      const normCount = ShapeUtil.sizeToDimension(xShape, axis);
      const normSize = ShapeUtil.sizeFromDimension(xShape, axis);
      const components = getMaxComponents(normSize);
      const normPackedSize = normSize / components;
      const inputShape = [xShape[0], xShape[1], normPackedSize];
      const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type'];
      const programUniforms: ProgramUniform[] =
          [{type: DataType.uint32, data: normSize}, {type: DataType.uint32, data: normPackedSize}];
      programUniforms.push(...createTensorShapeVariables(inputShape, inputShape));

      const getShaderSource = (shaderHelper: ShaderHelper) => {
        const x = inputVariable('x', inputs[0].dataType, inputShape.length, components);
        const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims);
        const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims);
        const output = outputVariable('output', inputs[0].dataType, inputShape.length, components);
        const variables = [x, scale, bias, output];
        const dataType = x.type.value;
        const f32Type = components === 1 ? 'f32' : `vec${components}<f32>`;
        const workgroupSize = 64;

        const uniforms: UniformsArrayType = [{name: 'normSize', type: 'u32'}, {name: 'normPackedSize', type: 'u32'}];
        return `
  var<workgroup> meanShared : f32;
  var<workgroup> squaredNormShared : f32;
  var<workgroup> workgroupShared : array<${f32Type}, ${workgroupSize}>;
  const workgroupSize = ${workgroupSize}u;
  ${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)}
  ${shaderHelper.mainStart(workgroupSize)}
    let norm = global_idx / workgroupSize;
    let batch = norm / uniforms.x_shape[1];
    let channel = norm % uniforms.x_shape[1];
    let localIndex = local_id.x;

    // initialize workgroup memory
    var initial = ${f32Type}(0);
    for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) {
      initial = initial + ${f32Type}(${x.get('batch', 'channel', 'h')});
    }
    workgroupShared[localIndex] = initial;
    workgroupBarrier();

    // Calculate the mean of current channel data.
    for (var currSize = workgroupSize >> 1;  currSize > 0; currSize = currSize >> 1) {
      if (localIndex < currSize) {
        workgroupShared[localIndex] = workgroupShared[localIndex] + workgroupShared[localIndex + currSize];
      }
      workgroupBarrier();
    }
    if (localIndex == 0) {
      meanShared = ${sumVector('workgroupShared[0]', components)} / f32(uniforms.normSize);
    }
    workgroupBarrier();

    // reinitialize workgroup memory.
    initial = ${f32Type}(0);
    for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) {
      let deviation =  ${f32Type}(${x.get('batch', 'channel', 'h')}) - ${f32Type}(meanShared);
      initial = initial + deviation * deviation;
    }
    workgroupShared[localIndex] = initial;
    workgroupBarrier();

    // Calculate the sum of square of deviation of current channel data.
    for (var currSize = workgroupSize >> 1;  currSize > 0; currSize = currSize >> 1) {
      if (localIndex < currSize) {
        workgroupShared[localIndex] = workgroupShared[localIndex] + workgroupShared[localIndex + currSize];
      }
      workgroupBarrier();
    }
    if (localIndex == 0) {
      squaredNormShared = ${sumVector('workgroupShared[0]', components)};
    }
    workgroupBarrier();

    let invStdDev = inverseSqrt(squaredNormShared / f32(uniforms.normSize) + f32(${attributes.epsilon}));
    let channelScale = invStdDev * f32(${scale.getByOffset('channel')});
    let channelShift = f32(${bias.getByOffset('channel')}) - meanShared * channelScale;
    for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) {
      let value = ${x.get('batch', 'channel', 'h')} * ${dataType}(${f32Type}(channelScale)) + ${dataType}(${
            f32Type}(channelShift));
      ${output.set('batch', 'channel', 'h', 'value')};
    }
  }`;
      };
      return {
        ...{name: 'InstanceNormalization'},
        // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon.
        shaderCache: {hint: `${attributes.epsilon};${components}`, inputDependencies},
        getRunData: () => ({
          outputs: [
            {dims: outputShape, dataType: inputs[0].dataType},
          ],
          dispatchGroup: {x: normCount},
          programUniforms
        }),
        getShaderSource,
      };
    };

const computeMean =
    (context: ComputeContext, input: TensorView, scale: TensorView, bias: TensorView, n: number, h: number, c: number,
     epsilon: number) => {
      const components = getMaxComponents(c);
      const WG = 64;
      // we will store channel scale and channel shift in [2, components] matrix
      // or in vec2 when components == 1
      const outputType = components === 1 ? 'vec2f' : `mat2x${components}f`;
      const sumCastType = components === 1 ? 'f32' : `vec${components}f`;
      const setOutputValue = (var1: string, var2: string) => `${outputType}(${var1}, ${var2})`;
      const unitsOfWork = n * c / components;
      const wgSize = Math.ceil(h / WG);

      const meanInputDependencies: ProgramInputTensorInfoDependency[] = ['type'];
      const meanProgramUniforms: ProgramUniform[] = [
        {type: DataType.uint32, data: wgSize}, {type: DataType.uint32, data: h},
        {type: DataType.uint32, data: Math.floor(c / components)},
        {type: DataType.uint32, data: Math.floor(h * c / components)}
      ];

      const getMeanShaderSource = (shaderHelper: ShaderHelper) => {
        const inputHelper = inputVariable('input', input.dataType, input.dims, components);
        return `
  ${shaderHelper.declareVariables(inputHelper)}
  @group(0) @binding(1) var<storage, read_write> output : array<${outputType}>;
  struct Uniforms {wg_size:u32, H:u32, C:u32, image_size:u32};
  @group(0) @binding(2) var<uniform> uniforms: Uniforms;

  ${shaderHelper.mainStart(WG)}
    let currentImageNumber = global_idx / ${WG} / uniforms.C;
    let currentChannelNumber = (global_idx / ${WG}) % uniforms.C;
    let wgId = global_idx % ${WG};
    let wgOffset = wgId * uniforms.wg_size;
    if (wgOffset >= uniforms.H) {
        return;
    }
    let wgMax = min(wgOffset + uniforms.wg_size, uniforms.H);

    let offset = currentImageNumber * uniforms.image_size + currentChannelNumber;
    var sum = ${fillVector('f32', components)};
    var squaredSum = ${fillVector('f32', components)};
    for (var i: u32 = wgOffset; i < wgMax; i++) {
        let value = ${sumCastType}(input[offset + i * uniforms.C]);
        sum += value;
        squaredSum += value * value;
    }
    output[global_idx] = ${setOutputValue('sum', 'squaredSum')};
  }`;
      };

      const meanValues = context.compute(
          {
            name: 'InstanceNormComputeMean',
            shaderCache: {hint: `${components}`, inputDependencies: meanInputDependencies},
            getRunData: () => ({
              outputs: [
                {dims: [n, c, WG, 2], dataType: DataType.float},
              ],
              dispatchGroup: {x: n * c / components},
              programUniforms: meanProgramUniforms
            }),
            getShaderSource: getMeanShaderSource,
          },
          {inputs: [input], outputs: [-1]})[0];

      const programUniforms: ProgramUniform[] = [
        {type: DataType.uint32, data: unitsOfWork}, {type: DataType.uint32, data: h},
        {type: DataType.uint32, data: Math.floor(c / components)},
        {type: DataType.uint32, data: Math.floor(WG * c / components)}
      ];
      const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type', 'type'];
      const getShaderSource = (shaderHelper: ShaderHelper) => {
        const scaleHelper = inputVariable('scale', scale.dataType, scale.dims, components);
        const biasHelper = inputVariable('bias', bias.dataType, bias.dims, components);
        return `
  @group(0) @binding(0) var<storage, read> input : array<${outputType}>;
  @group(0) @binding(1) var<storage, read> scale : array<${scaleHelper.type.storage}>;
  @group(0) @binding(2) var<storage, read> bias : array<${biasHelper.type.storage}>;
  @group(0) @binding(3) var<storage, read_write> output : array<${outputType}>;
  struct Uniforms {units_of_work : u32, H: u32, C : u32, image_size : u32};
  @group(0) @binding(4) var<uniform> uniforms: Uniforms;

  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.units_of_work')}
    let currentImageNumber = global_idx / uniforms.C;
    let currentChannelNumber = global_idx % uniforms.C;

    let offset = currentImageNumber * uniforms.image_size;
    var sum = ${fillVector('f32', components)};
    var squaredSum = ${fillVector('f32', components)};
    for (var i: u32 = 0; i < min(${WG}, uniforms.H); i++) {
        let value = input[offset + i + currentChannelNumber * ${WG}];
        sum += value[0];
        squaredSum += value[1];
    }
    sum = sum / f32(uniforms.H);
    squaredSum = squaredSum / f32(uniforms.H);
    let invStdDev = inverseSqrt(squaredSum - sum * sum + f32(${epsilon}));
    let channelScale = invStdDev * ${sumCastType}(scale[currentChannelNumber]);
    let channelShift = ${sumCastType}(bias[currentChannelNumber]) - sum * channelScale;

    output[global_idx] = ${setOutputValue('channelScale', 'channelShift')};
  }`;
      };
      return context.compute(
          {
            name: 'InstanceNormComputeChannelScaleShift',
            // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon.
            shaderCache: {hint: `${components};${epsilon}`, inputDependencies},
            getRunData: () => ({
              outputs: [
                {dims: [n, c, 2], dataType: DataType.float},
              ],
              dispatchGroup: {x: Math.ceil(unitsOfWork / 64 /* workgroup size */)},
              programUniforms
            }),
            getShaderSource,
          },
          {inputs: [meanValues, scale, bias], outputs: [-1]})[0];
    };

const createInstanceNormNHWCProgramInfo =
    (context: ComputeContext, inputs: readonly TensorView[], attributes: InstanceNormAttributes) => {
      const xShape = inputs[0].dims;
      const outputShape = xShape;
      const N = xShape[0];
      const C = xShape[xShape.length - 1];
      const H = ShapeUtil.sizeFromDimension(xShape, 1) / C;
      const components = getMaxComponents(C);
      const outputSize = ShapeUtil.size(outputShape) / components;
      const programUniforms: ProgramUniform[] =
          [{type: DataType.uint32, data: H}, {type: DataType.uint32, data: Math.floor(C / components)}];
      const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
      // first compute mean
      const channelScaleShift = computeMean(context, inputs[0], inputs[1], inputs[2], N, H, C, attributes.epsilon);
      const getShaderSource = (shaderHelper: ShaderHelper) => {
        const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
        const scaleType = components === 1 ? 'vec2f' : `mat2x${components}f`;
        const scaleCastType = components === 1 ? dataType : `vec${components}<${dataType}>`;

        const inputHelper = inputVariable('input', inputs[0].dataType, inputs[0].dims, components);
        const outputHelper = outputVariable('output', inputs[0].dataType, outputShape, components);

        return `
  @group(0) @binding(0) var<storage, read> input : array<${inputHelper.type.storage}>;
  @group(0) @binding(1) var<storage, read> scaleInput : array<${scaleType}>;
  @group(0) @binding(2) var<storage, read_write> output : array<${outputHelper.type.storage}>;
  struct Uniforms {H: u32, C : u32};
  @group(0) @binding(3) var<uniform> uniforms: Uniforms;

  ${shaderHelper.mainStart()}
    let currentImageNumber = global_idx / (uniforms.C * uniforms.H);
    let currentChannelNumber = global_idx % uniforms.C;

    let scaleOffset = currentImageNumber * uniforms.C + currentChannelNumber;
    let scale = scaleInput[scaleOffset];
    output[global_idx] = fma(input[global_idx], ${scaleCastType}(scale[0]), ${scaleCastType}(scale[1]));
  }`;
      };
      context.compute(
          {
            name: 'InstanceNormalizationNHWC',
            shaderCache: {hint: `${components}`, inputDependencies},
            getRunData: () => ({
              outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
              dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
              programUniforms
            }),
            getShaderSource,
          },
          {inputs: [inputs[0], channelScaleShift]});
    };

export const instanceNorm = (context: ComputeContext, attributes: InstanceNormAttributes): void => {
  if (attributes.format === 'NHWC') {
    createInstanceNormNHWCProgramInfo(context, context.inputs, attributes);
  } else {
    context.compute(createInstanceNormProgramInfo(context.inputs, attributes));
  }
};
