// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';

import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';

// Attributes of the Einsum operator, as produced by parseEinsumAttributes.
export interface EinsumAttributes extends AttributeWithCacheKey {
  // The einsum equation with all whitespace removed, e.g. "ij,jk->ik".
  readonly equation: string;
}
// The equation attribute value is a string that consists of a left-hand side (LHS) and, optionally, a right-hand side
// (RHS) separated by '->'. Examples:
//     "ij,jk->ik" expresses matrix multiplication
//     "ij->ji"    expresses matrix transpose
//     "ii->i"     extracts the diagonal elements of a square matrix
// The LHS consists of a sequence of terms separated by commas. Each term corresponds to an input variable.
// Each symbol corresponds to a dimension in the input variable. A symbol is either a letter, 'a' to 'z' or 'A' to
// 'Z', or '...' (an ellipsis) representing an arbitrary number of dimensions.

// Regular-expression sources used to validate the equation. Whitespace has already been stripped from the equation by
// parseEinsumAttributes before these are applied.
const symbolPattern = '[a-zA-Z]|\\.\\.\\.';             // One symbol: a single letter or an ellipsis
const termPattern = `(${symbolPattern})+`;              // One term: a non-empty sequence of symbols
const termPatternOnly = `^${termPattern}$`;             // A string consisting of exactly one term
const lhsPattern = `(${termPattern},)*${termPattern}`;  // An LHS: one or more comma-separated terms
const lhsPatternOnly = `^${lhsPattern}$`;               // A string consisting of exactly one LHS

// Bookkeeping for one symbol of the equation. Ellipsis dimensions are tracked under the internal symbols '0', '1', ...
interface SymbolInfo {
  count: number;           // Number of occurrences of the symbol in the terms processed so far (repeats inside one term
                           // count separately, and the output term, processed last, adds to it as well)
  inputIndices: number[];  // Term index of each occurrence: 0, 1, 2, ... for inputs and -1 for the output term
  dimValue: number;        // Size of the dimension the symbol is bound to (taken from its first occurrence)
}

/** One term of the equation, recording which dimension positions of its tensor each symbol labels. */
class EinsumTerm {
  // Symbol -> positions of the dimensions it labels (several when the symbol repeats within the term, e.g. "ii").
  symbolToIndices = new Map<string, number[]>();
  // -1 for the output term; 0, 1, 2, ... for the input terms.
  inputIndex: number;

  constructor(inputIndex = -1) {
    this.inputIndex = inputIndex;
  }

  /** Records that `symbol` labels the dimension at position `index`. */
  addSymbol(symbol: string, index: number) {
    const positions = this.symbolToIndices.get(symbol);
    if (positions === undefined) {
      this.symbolToIndices.set(symbol, [index]);
      return;
    }
    positions.push(index);
  }
}

class EinsumEquation {
  /**
   * Parses `equation` and binds every symbol to a dimension of the corresponding input.
   *
   * @param inputs - Operator inputs, one per LHS term; only their `dims` are read.
   * @param equation - Einsum equation such as "ij,jk->ik", with whitespace already removed.
   * @throws Error if the equation is malformed, if the number of LHS terms differs from the number of inputs, if a
   *     term does not match the rank of its input, or if dimensions bound to the same symbol (or to the ellipsis)
   *     disagree.
   */
  constructor(inputs: readonly TensorView[], public readonly equation: string) {
    this.hasEllipsis = false;
    this.ellipsisDims = [];
    this.symbolToInfo = new Map<string, SymbolInfo>();
    this.lhs = new Array<EinsumTerm>();
    this.outputDims = [];
    // With '->' the equation is in explicit mode, even when nothing follows the arrow (the output is then a scalar).
    // Without it the equation is in implicit mode and the RHS is derived from the LHS below.
    const isExplicit = equation.includes('->');
    const [lhs, explicitRhs = '', ...extraSides] = equation.split('->');
    if (extraSides.length > 0) {
      throw new Error('Invalid equation');
    }
    if (!lhs.match(RegExp(lhsPatternOnly))) {
      throw new Error('Invalid LHS term');
    }
    const inputTerms = lhs.split(',');
    if (inputTerms.length !== inputs.length) {
      throw new Error('Number of LHS terms does not match the number of inputs');
    }
    inputTerms.forEach((inputTerm, index) => {
      const dims = inputs[index].dims.slice();
      if (!inputTerm.match(RegExp(termPatternOnly))) {
        throw new Error('Invalid LHS term');
      }
      this.lhs.push(this.processTerm(inputTerm, true, dims, index));
    });

    let rhs: string;
    if (isExplicit) {
      // An empty explicit RHS is valid and reduces the output to a scalar.
      if (explicitRhs !== '' && !explicitRhs.match(RegExp(termPatternOnly))) {
        throw new Error('Invalid RHS');
      }
      rhs = explicitRhs;
    } else {
      // Implicit mode: the ellipsis dimensions come first, followed by the letters that occur exactly once on the LHS,
      // sorted by character code. Ellipsis dimensions are stored under the internal digit symbols, so only letters
      // are considered in the filter.
      const singleLetters = [...this.symbolToInfo.entries()]
                                .filter(([symbol, info]) => info.count === 1 && /^[a-zA-Z]$/.test(symbol))
                                .map(([symbol]) => symbol)
                                .sort();
      rhs = (this.hasEllipsis ? '...' : '') + singleLetters.join('');
    }

    // Compute output dims
    const rhsSymbols = rhs.match(RegExp(symbolPattern, 'g'));
    rhsSymbols?.forEach((symbol) => {
      if (symbol === '...') {
        this.outputDims = this.outputDims.concat(this.ellipsisDims);
      } else {
        const info = this.symbolToInfo.get(symbol);
        if (info === undefined) {
          throw new Error('Invalid RHS symbol');
        }
        this.outputDims.push(info.dimValue);
      }
    });
    this.rhs = this.processTerm(rhs, false, this.outputDims);
  }  // End of EinsumEquation constructor

  /**
   * Records one occurrence of `symbol`, bound to a dimension of size `dimValue`, in the term `inputIndex`
   * (-1 for the output term).
   *
   * @throws Error if the symbol was previously bound to a dimension of a different size.
   */
  addSymbol(symbol: string, dimValue: number, inputIndex: number) {
    const info = this.symbolToInfo.get(symbol);
    if (info === undefined) {
      this.symbolToInfo.set(symbol, {count: 1, dimValue, inputIndices: [inputIndex]});
      return;
    }
    // Every occurrence of a symbol must label dimensions of the same size; the generated shader iterates over a
    // single extent per symbol, so a mismatch would silently produce wrong results.
    if (info.dimValue !== dimValue) {
      throw new Error('Dimension mismatch');
    }
    info.count++;
    info.inputIndices.push(inputIndex);
  }

  /**
   * Parses one input or output term and registers its symbols with the equation.
   *
   * @param term - The term, e.g. "ij" or "...ij". May be empty only for the output (scalar result).
   * @param isInput - Whether the term belongs to the LHS.
   * @param dims - Shape described by the term: an input shape, or the already computed output shape.
   * @param index - Input index of the term, or -1 for the output.
   * @returns The term with every symbol mapped to the dimension positions it labels.
   */
  processTerm(term: string, isInput: boolean, dims: readonly number[], index = -1): EinsumTerm {
    const rank = dims.length;
    let ellipsis = false;
    let ellipsisDims: number[] = [];
    let nextDim = 0;
    // For the output an empty string is allowed because the output may be reduced to a scalar value.
    if (!isInput && term !== '' && !term.match(RegExp(termPatternOnly))) {
      throw new Error('Invalid RHS term');
    }
    const indexSymbols = term.match(RegExp(symbolPattern, 'g'));
    // Without an ellipsis every symbol labels exactly one dimension, so there must be one symbol per dimension.
    if (!term.includes('...') && (indexSymbols?.length ?? 0) !== rank) {
      throw new Error('Number of symbols in term does not match the tensor rank');
    }
    const einsumTerm = new EinsumTerm(index);
    // A symbol is either a letter, 'a' to 'z' or 'A' to 'Z', or '...'.
    indexSymbols?.forEach((symbol: string) => {
      if (symbol === '...') {
        if (ellipsis) {
          throw new Error('Only one ellipsis is allowed per input term');
        }
        ellipsis = true;
        const ellipsisDimLength = rank - indexSymbols.length + 1;
        if (ellipsisDimLength < 0) {
          throw new Error('Ellipsis out of bounds');
        }
        ellipsisDims = dims.slice(nextDim, nextDim + ellipsisDimLength);
        if (this.hasEllipsis) {
          if (this.ellipsisDims.length !== ellipsisDims.length ||
              this.ellipsisDims.toString() !== ellipsisDims.toString()) {
            throw new Error('Ellipsis dimensions mismatch');
          }
        } else if (isInput) {
          this.hasEllipsis = true;
          this.ellipsisDims = ellipsisDims;
        } else {
          throw new Error('Ellipsis must be specified in the LHS');
        }
        // Add '0', '1', '2', '3', '4', etc. to represent the ellipsis dimensions so they need no special handling.
        for (let j = 0; j < ellipsisDims.length; j++) {
          const ellipsisSymbol = String.fromCharCode('0'.charCodeAt(0) + j);
          einsumTerm.addSymbol(ellipsisSymbol, nextDim);
          this.addSymbol(ellipsisSymbol, dims[nextDim++], index);
        }
      } else {
        // nextDim is the actual position of this symbol's dimension in the tensor, independent of where (or whether)
        // an ellipsis appears in this or any other term.
        einsumTerm.addSymbol(symbol, nextDim);
        this.addSymbol(symbol, dims[nextDim++], index);
      }
    });
    return einsumTerm;
  }

  symbolToInfo: Map<string, SymbolInfo>;  // All symbols in the equation
  hasEllipsis: boolean;                   // Whether any LHS term contains an ellipsis
  ellipsisDims: number[];                 // The dimensions the equation's ellipsis stands for ([] if none)
  lhs: EinsumTerm[];                      // Terms on the left-hand side of the equation
  rhs: EinsumTerm;                        // Term on the right-hand side of the equation
  outputDims: number[];                   // Output dimensions of the equation
}  // End of class EinsumEquation

// Name of the uniform holding the loop bound (extent) of a summed-over symbol.
const appendMax = (name: string): string => `${name}_max`;

/**
 * Builds the WebGPU program that evaluates a parsed einsum equation.
 *
 * Each invocation computes one output element. Symbols present in the output are copied from the output indices into
 * the indices of every input they label. Every other (summed-over) symbol becomes a nested loop whose bound is passed
 * as the uniform `<symbol>_max`. The innermost loop body multiplies every input exactly once and accumulates the
 * product into `sum`.
 *
 * @param inputShapes - Shapes of the inputs, in the same order as the LHS terms.
 * @param dataType - Data type of the inputs and the output.
 * @param einsumEquation - The parsed equation.
 * @param outputShape - Shape of the output tensor.
 */
const createEinsumProgramInfo =
    (inputShapes: Array<readonly number[]>, dataType: number, einsumEquation: EinsumEquation,
     outputShape: readonly number[]): ProgramInfo => {
      const ranks = inputShapes.map((dims) => dims.length);
      const inputVars = ranks.map((rank, index) => inputVariable(`input${index}`, dataType, rank));
      const outputSize = ShapeUtil.size(outputShape);
      const output = outputVariable('output', dataType, outputShape.length);
      // Symbols absent from the output are summed over; each needs its extent as a uniform loop bound.
      const uniformsSymbols =
          [...einsumEquation.symbolToInfo.keys()].filter((symbol) => !einsumEquation.rhs.symbolToIndices.has(symbol));
      // Ellipsis dimensions are tracked under the internal symbols '0', '1', ..., which are not valid WGSL identifiers
      // (they reach the shader when the ellipsis is summed over, e.g. "...ij->ij"). Letters are used unchanged.
      const toIdentifier = (symbol: string): string =>
          /^[a-zA-Z]$/.test(symbol) ? symbol : `ellipsis${symbol.charCodeAt(0) - '0'.charCodeAt(0)}`;
      // Calls `emit(inputNumber, dimPosition)` for every dimension of an input that `symbol` labels.
      const forEachInputDim =
          (symbol: string, inputIndices: readonly number[], emit: (input: number, dim: number) => void) => {
            einsumEquation.lhs.forEach((term, i) => {
              if (inputIndices.includes(i)) {
                const indices = term.symbolToIndices.get(symbol);
                if (indices === undefined) {
                  throw new Error('Invalid symbol error');
                }
                indices.forEach((index) => emit(i, index));
              }
            });
          };

      const getShaderSource = (shaderHelper: ShaderHelper) => {
        // Accumulate in the output's own value type (e.g. f16) so the WGSL types agree with the tensor data.
        const valueType = output.type.value;
        const idxCopy: string[] = [];
        const reduceOpsSetIndices: string[] = [];
        const reduceOpsLoopHeaders: string[] = [];
        const reduceOpsLoopFooters: string[] = [];
        const isReduceOpsWithoutLoop = einsumEquation.symbolToInfo.size === einsumEquation.rhs.symbolToIndices.size;
        einsumEquation.symbolToInfo.forEach((info, symbol) => {
          if (einsumEquation.rhs.symbolToIndices.has(symbol)) {
            // Output symbol: copy its coordinate from the output indices into every input it labels.
            const outputIndex = einsumEquation.rhs.symbolToIndices.get(symbol)?.[0];
            if (outputIndex !== undefined) {
              forEachInputDim(symbol, info.inputIndices, (i, index) => {
                idxCopy.push(`${
                    inputVars[i].indicesSet(
                        `input${i}Indices`, index, output.indicesGet('outputIndices', outputIndex))}`);
              });
            }
          } else {
            // Summed-over symbol: loop over its extent and use the loop variable as the coordinate.
            const loopVar = toIdentifier(symbol);
            forEachInputDim(symbol, info.inputIndices, (i, index) => {
              reduceOpsSetIndices.push(`${inputVars[i].indicesSet(`input${i}Indices`, index, loopVar)}`);
            });
            reduceOpsLoopHeaders.push(
                `for(var ${loopVar}: u32 = 0; ${loopVar} < uniforms.${appendMax(loopVar)}; ${loopVar}++) {`);
            reduceOpsLoopFooters.push('}');
          }
        });
        // Every input contributes exactly one factor to the product, whether or not a summed-over symbol labels it.
        const inputValues = inputVars.map((inputVar, i) => inputVar.getByIndices(`input${i}Indices`));
        const reduceOps = isReduceOpsWithoutLoop ?
            [...idxCopy, `let sum = ${inputValues.join(' * ')};`] :
            [
              ...idxCopy,
              `var sum = ${valueType}(0);`,
              ...reduceOpsLoopHeaders,
              ...reduceOpsSetIndices,
              `var prod = ${valueType}(1);`,
              ...inputValues.map((value) => `prod *= ${value};`),
              'sum += prod;',
              ...reduceOpsLoopFooters,
            ];
        return `
            ${
            shaderHelper
                .registerUniforms(
                    uniformsSymbols.map((symbol) => ({name: `${appendMax(toIdentifier(symbol))}`, type: 'u32'})))
                .registerUniform('outputSize', 'u32')
                .declareVariables(...inputVars, output)}

            ${shaderHelper.mainStart()}
            ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
            var outputIndices = ${output.offsetToIndices('global_idx')};
            ${inputVars.map((inputVar, i) => `var input${i}Indices: ${inputVar.type.indices};`).join('\n')}
            ${reduceOps.join('\n')}
            ${output.setByOffset('global_idx', 'sum')};
          }`;
      };
      return {
        name: 'Einsum',
        shaderCache: {hint: einsumEquation.equation, inputDependencies: inputShapes.map(() => 'rank')},
        getRunData: () => {
          // Uniforms must follow the registration order in the shader: the loop bounds of the summed-over symbols,
          // then outputSize, then the shape uniforms of every input and of the output. uniformsSymbols are keys of
          // symbolToInfo, so the lookups below always succeed.
          const programUniforms: ProgramUniform[] =
              uniformsSymbols.filter((symbol) => einsumEquation.symbolToInfo.has(symbol))
                  .map(
                      (symbol) =>
                          ({type: DataType.uint32, data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0}));
          programUniforms.push({type: DataType.uint32, data: outputSize});
          inputShapes.forEach((dims) => programUniforms.push(...createTensorShapeVariables(dims)));
          programUniforms.push(...createTensorShapeVariables(outputShape));
          return ({
            outputs: [{dims: outputShape, dataType}],
            dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
            programUniforms
          });
        },
        getShaderSource,
      };
    };

/**
 * Einsum kernel entry point: parses the equation against the current inputs and dispatches the generated program.
 * The data type of the first input is used for the program.
 */
export const einsum = (context: ComputeContext, attributes: EinsumAttributes): void => {
  const {inputs} = context;
  const parsedEquation = new EinsumEquation(inputs, attributes.equation);
  const shapes = inputs.map((input) => input.dims);
  const programInfo = createEinsumProgramInfo(shapes, inputs[0].dataType, parsedEquation, parsedEquation.outputDims);
  context.compute(programInfo);
};

/**
 * Reads the Einsum attributes. All whitespace is stripped from the equation, so the parser only ever sees symbols,
 * ',' and '->'.
 */
export const parseEinsumAttributes = (attributes: Record<string, unknown>): EinsumAttributes =>
    createAttributeWithCacheKey({equation: (attributes.equation as string).replace(/\s+/g, '')});
