// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {BroadcastUtil, ShapeUtil} from '../../util';
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';

import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu';
import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common';
import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet, InternalActivationAttributes} from './fuse-utils';

/**
 * Builds the `ProgramInfo` for the 'MatMulNaive' WebGPU program.
 *
 * Each shader invocation handles one batch entry and produces `outputNumber` consecutive output rows, each stored as
 * one vector of `components` columns. The K dimension is walked in steps of `aComponents`.
 *
 * @param inputs - `[A, B]` and optionally a third bias tensor (bias is enabled when more than two inputs are given).
 * @param activationAttributes - activation to fuse; forwarded to the activation uniform/snippet helpers.
 * @param outputShape - dims reported for the output tensor; its leading dims are the batch dims unless
 *     `reshapedOutputShape` is given.
 * @param reshapedOutputShape - when provided, its leading dims (all but the last two) are used as the batch dims
 *     instead of those of `outputShape`.
 * @param isChannelsLast - only used for conv2dByMatMul. Selects bias indexing: by output column (vectorized) when
 *     true, by output row (scalar broadcast to the output vector type) when false.
 * @returns the program description (name, shader cache info, run data and shader source generator).
 */
export const createNaiveMatmulProgramInfo =
    (inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes, outputShape: readonly number[],
     reshapedOutputShape?: readonly number[],
     isChannelsLast = false /* only used for conv2dByMatMul*/): ProgramInfo => {
      const aShape = inputs[0].dims;
      const bShape = inputs[1].dims;

      const M = aShape[aShape.length - 2];
      const N = bShape[bShape.length - 1];
      const K = aShape[aShape.length - 1];
      // Vector width used for B and for the output (along N).
      const components = getMaxComponents(N);
      // Vector width used for A (along K); also the step of the K loop.
      const aComponents = getMaxComponents(K);
      // Number of output rows computed by a single invocation.
      const outputNumber = getMaxComponents(M);
      // Number of invocations: every invocation writes `outputNumber` vectors of `components` elements.
      const outputSize = ShapeUtil.size(outputShape) / components / outputNumber;
      const hasBias = inputs.length > 2;
      const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2);
      const batchSize = ShapeUtil.size(outerDims);
      // Inside the shader the output is always addressed as a rank-3 tensor [batch, M, N].
      const outputShapeInShader = [batchSize, M, N];

      // Batch axes of A / B that are broadcast against the output batch dims. The generated shader text hard-codes
      // these axes (their index is forced to 0), so they are computed up front and made part of the cache hint below.
      const outerDimsA = aShape.slice(0, -2);
      const outerDimsB = bShape.slice(0, -2);
      const broadCastADims = getBroadcastDims(outerDimsA, outerDims);
      const broadCastBDims = getBroadcastDims(outerDimsB, outerDims);

      // The leading entries mirror the `uniforms` list registered in getShaderSource (output_size, M, N, K, then the
      // activation uniforms); shape variables for batch dims, A, B, optional bias and the output are appended after.
      const programUniforms: ProgramUniform[] = [
        {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: M}, {type: DataType.uint32, data: N},
        {type: DataType.uint32, data: K}
      ];
      appendActivationUniformsData(activationAttributes, programUniforms);
      programUniforms.push(...createTensorShapeVariables(outerDims, aShape, bShape));
      if (hasBias) {
        programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
      }
      programUniforms.push(...createTensorShapeVariables(outputShapeInShader));

      const getShaderSource = (shaderHelper: ShaderHelper) => {
        const batchDims = internalVariable('batch_dims', inputs[0].dataType, outerDims.length);
        const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents);
        const b = inputVariable('b', inputs[1].dataType, bShape.length, components);
        const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components);
        const baseType = tensorTypeToWsglStorageType(output.type.tensor);
        const applyActivation = getActivationSnippet(activationAttributes, output.type.value, baseType);
        const inputVariables = [a, b];
        let processBias = '';
        if (hasBias) {
          // Channels-last bias is read as a vector per output column; otherwise one scalar per output row.
          const biasComponents = isChannelsLast ? components : 1;
          inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents));
          processBias = `${
              isChannelsLast ? `value += bias[col / ${biasComponents}];` :
                               `value += ${output.type.value}(bias[row + i]);`}`;
        }

        const uniforms: UniformsArrayType = [
          {name: 'output_size', type: 'u32'}, {name: 'M', type: 'u32'}, {name: 'N', type: 'u32'},
          {name: 'K', type: 'u32'}
        ];
        appendActivationUniforms(activationAttributes, uniforms);

        // Emits WGSL declaring `<name>_indices` for an input: batch indices are copied from `batch_indices` (aligned
        // on the trailing batch axis), broadcast axes are forced to 0 and the two matrix axes are zeroed, so the
        // resulting offset is used as the base to which row/column terms are added.
        const getIndices = (variable: IndicesHelper, broadCastDims: number[]) => {
          const rank = variable.rank;
          const name = variable.name;
          if (rank === 2) {
            return `var ${name}_indices = ${variable.type.indices}(0u, 0u);`;
          }
          const batchRank = batchDims.rank;
          let resStr = `var ${name}_indices: ${variable.type.indices};`;
          for (let i = rank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) {
            resStr += `\n${name}_indices[${i}] = ${batchRank > 1 ? `batch_indices[${j}]` : 'batch_indices'};`;
          }
          broadCastDims.forEach(i => {
            resStr += `\n${name}_indices[${i}] = 0;`;
          });
          resStr += `${name}_indices[${rank - 2}] = 0u;
                     ${name}_indices[${rank - 1}] = 0u;`;
          return resStr;
        };

        // Emits the body of the K loop: loads `aComponents` rows of B, then for each of the `outputNumber` output
        // rows loads one vector of A and accumulates its lanes against the B rows with fma.
        const calcResult = (): string => {
          let calcStr = `var a_data: ${a.type.value};`;
          for (let i = 0; i < aComponents; i++) {
            calcStr += `
              let b_data${i} = b[(b_offset + (k + ${i}) * uniforms.N + col) / ${components}];`;
          }
          for (let i = 0; i < outputNumber; i++) {
            calcStr += `a_data = a[(a_offset + (row + ${i}) * uniforms.K + k) / ${aComponents}];`;

            for (let j = 0; j < aComponents; j++) {
              calcStr += `
            values[${i}] = fma(${b.type.value}(a_data${aComponents === 1 ? '' : `[${j}]`}), b_data${j}, values[${
                  i}]);\n`;
            }
          }
          return calcStr;
        };

        return `
  ${
            shaderHelper.registerUniforms(uniforms).registerInternalVariables(batchDims).declareVariables(
                ...inputVariables, output)}
  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
    let col = (global_idx % (uniforms.N / ${components})) * ${components};
    var index1 = global_idx / (uniforms.N / ${components});
    let stride1 = uniforms.M / ${outputNumber};
    let row = (index1 % stride1) * ${outputNumber};
    let batch = index1 / stride1;

    ${outputShape.length === 2 ? '' : `let batch_indices = ${batchDims.offsetToIndices('batch')};`}
    ${getIndices(a, broadCastADims)}
    let a_offset = ${a.indicesToOffset('a_indices')};
    ${getIndices(b, broadCastBDims)}
    let b_offset = ${b.indicesToOffset('b_indices')};
    var values: array<${output.type.value}, ${outputNumber}>;
    for (var k: u32 = 0u; k < uniforms.K; k = k + ${aComponents}) {
      ${calcResult()}
    }
    for (var i = 0u; i < ${outputNumber}u; i++) {
      var value = values[i];
      ${processBias}
      ${applyActivation}
      let cur_indices = ${output.type.indices}(batch, row + i, col);
      let offset = ${output.indicesToOffset('cur_indices')};
      ${output.setByOffset(`offset / ${components}`, 'value')};
    }
  }
  `;
      };
      return {
        name: 'MatMulNaive',
        shaderCache: {
          // Everything that changes the generated shader text must be reflected here. The broadcast axes are baked
          // into the shader, and inputs of equal rank can still differ in which batch axes are broadcast, so the
          // axes are included in the hint.
          // NOTE(review): assumes the cache key is derived from `hint` plus `inputDependencies` - confirm.
          hint: `${activationAttributes.activation};${components};${aComponents};${outputNumber};${isChannelsLast};${
              broadCastADims.join(',')};${broadCastBDims.join(',')}`,
          inputDependencies: hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank']
        },
        getRunData: () => ({
          outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
          dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
          programUniforms
        }),
        getShaderSource
      };
    };

/**
 * Checks the operands of a MatMul before a program is built.
 *
 * @param inputs - the operator inputs; exactly two tensors are expected.
 * @throws Error when the input count is not 2, or when the last dim of the first input differs from the
 *     second-to-last dim of the second input.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  const hasTwoInputs = inputs != null && inputs.length === 2;
  if (!hasTwoInputs) {
    throw new Error('MatMul requires 2 inputs.');
  }

  const lhsDims = inputs[0].dims;
  const rhsDims = inputs[1].dims;
  // The contraction axis: columns of the left operand against rows of the right operand.
  const lhsShared = lhsDims[lhsDims.length - 1];
  const rhsShared = rhsDims[rhsDims.length - 2];
  if (lhsShared !== rhsShared) {
    throw new Error('shared dimension does not match.');
  }
};

/**
 * Entry point of the MatMul operator: validates the inputs, derives the broadcast output shape and runs either the
 * naive program (when both N and K are below 8) or the packed matmul program, without any fused activation.
 *
 * @param context - compute context providing the inputs and the `compute` dispatcher.
 * @throws Error when the inputs are invalid or their shapes cannot be broadcast for a matmul.
 */
export const matMul = (context: ComputeContext): void => {
  const inputs = context.inputs;
  validateInputs(inputs);

  const lhsDims = inputs[0].dims;
  const outputShape = BroadcastUtil.calcShape(lhsDims, inputs[1].dims, true);
  if (!outputShape) {
    throw new Error('Can\'t use matmul on the given tensors');
  }

  const outputCols = outputShape[outputShape.length - 1];
  const sharedDim = lhsDims[lhsDims.length - 1];
  // Small N and K are routed to the naive program; everything else uses the packed one.
  const useNaive = outputCols < 8 && sharedDim < 8;
  const programInfo = useNaive ? createNaiveMatmulProgramInfo(inputs, {activation: ''}, outputShape) :
                                 createMatmulProgramInfo(inputs, {activation: ''}, outputShape);
  context.compute(programInfo);
};
