// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';

import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common';

//  TODO support quantization bits not equal to 4
/**
 * Attributes of the MatMulNBits operator as consumed by this file.
 */
export interface MatMulNBitsAttributes extends AttributeWithCacheKey {
  /** Reduction size; `validateInputs` requires it to equal the last dimension of the first input. */
  k: number;
  /** First dimension of the packed second input; also used as the last dimension of the output shape. */
  n: number;
  /** Forwarded to the shader as the `accuracy_level` uniform. */
  accuracyLevel: number;
  /** Bits per quantized value. The unpack routine generated in this file extracts 4-bit values only. */
  bits: number;
  /** Number of values along K that form one block; scales (and zero points) are looked up per block. */
  blockSize: number;
}

/**
 * Checks the operator inputs against the attributes and throws on the first mismatch.
 *
 * Expected inputs are `[a, b, scales]` with an optional fourth `zeroPoints` tensor:
 * - the last dimension of `a` must equal `attributes.k`;
 * - `b` must have shape `[n, nBlocksPerCol, blobSize]`;
 * - `scales` must hold `n * nBlocksPerCol` elements;
 * - `zeroPoints`, when present, must hold one element per block for `bits > 4`, otherwise one element per
 *   pair of blocks (rounded up) for each of the `n` columns.
 *
 * @param inputs - Input tensors of the operator.
 * @param attributes - Operator attributes.
 * @throws Error when the input count or any of the shapes/sizes above does not match.
 */
const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): void => {
  const inputCount = inputs.length;
  if (!(inputCount === 3 || inputCount === 4)) {
    throw new Error('MatMulNBits requires 3 or 4 inputs');
  }

  const [a, b, scales, zeroPoints] = inputs;
  const {k, n, bits, blockSize} = attributes;

  const lastDimOfA = a.dims[a.dims.length - 1];
  if (lastDimOfA !== k) {
    throw new Error('The last dim of input shape does not match the k value');
  }

  // Number of blocks needed to cover K (rounded up) and the byte size of one packed block.
  const blocksPerColumn = Math.floor((k + blockSize - 1) / blockSize);
  const bytesPerBlock = blockSize / 8 * bits;

  const expectedBShape = [n, blocksPerColumn, bytesPerBlock];
  if (!ShapeUtil.areEqual(b.dims, expectedBShape)) {
    throw new Error('The second inputs must be 3D tensor with shape N X nBlocksPerCol X blobSize');
  }

  const expectedScaleCount = n * blocksPerColumn;
  if (ShapeUtil.size(scales.dims) !== expectedScaleCount) {
    throw new Error('scales input size error.');
  }

  if (inputCount !== 4) {
    return;
  }

  // For bits <= 4 two zero points share one element, so each column needs ceil(blocks / 2) elements.
  let expectedZeroPointsSize: number;
  if (bits > 4) {
    expectedZeroPointsSize = n * blocksPerColumn;
  } else {
    expectedZeroPointsSize = n * Math.floor((blocksPerColumn + 1) / 2);
  }
  if (ShapeUtil.size(zeroPoints.dims) !== expectedZeroPointsSize) {
    throw new Error('zeroPoints input size error.');
  }
};

/**
 * Builds the WebGPU program description (uniforms, WGSL source, dispatch size and output shape) for
 * MatMulNBits with 4-bit quantized weights.
 *
 * Each shader invocation produces `outputNumber` rows by `components` columns of the output. For every
 * column it walks the blocks along K, unpacks eight 4-bit values from each 32-bit word of `b`,
 * dequantizes them with the per-block scale and zero point, and accumulates the product with `a`.
 *
 * @param inputs - `[a, b, scales]` or `[a, b, scales, zeroPoints]`; expected to have passed `validateInputs`.
 * @param attributes - Operator attributes.
 * @returns The program info consumed by `ComputeContext.compute`.
 */
export const createMatMulNBitsProgramInfo =
    (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): ProgramInfo => {
      const inputShape = inputs[0].dims;
      const aRank = inputShape.length;
      const outputShape = inputShape.slice(0, aRank - 1).concat(attributes.n);
      const m = inputShape[aRank - 2];
      const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize);
      // Size of one packed block of B in bytes and in 32-bit words.
      const blobSize = attributes.blockSize / 8 * attributes.bits;
      const blobSizeInWords = blobSize / 4;
      // Vectorization factors: rows of output per invocation, and vector widths for output, A and B.
      const outputNumber = getMaxComponents(m);
      const components = getMaxComponents(attributes.n);
      const aComponents = getMaxComponents(attributes.k);
      const bComponents = getMaxComponents(blobSizeInWords);
      const outputSize = ShapeUtil.size(outputShape) / components / outputNumber;
      const programUniforms: ProgramUniform[] = [
        {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.k},
        {type: DataType.uint32, data: attributes.n}, {type: DataType.uint32, data: attributes.accuracyLevel},
        {type: DataType.uint32, data: attributes.bits}, {type: DataType.uint32, data: attributes.blockSize}
      ];
      // Shapes as seen by the shader: the innermost dimension is divided by the vector width.
      const aShape = inputShape.slice();
      aShape.splice(-1, 1, attributes.k / aComponents);
      const bShape = ShapeUtil.convertShape(inputs[1].dims).slice();
      bShape.splice(-1, 1, blobSizeInWords / bComponents);
      programUniforms.push(...createTensorShapeVariables(aShape));
      programUniforms.push(...createTensorShapeVariables(bShape));
      programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
      if (inputs.length === 4) {
        programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims)));
      }
      const oShape = outputShape.slice();
      oShape.splice(-1, 1, attributes.n / components);
      programUniforms.push(...createTensorShapeVariables(oShape));
      const getShaderSource = (shaderHelper: ShaderHelper) => {
        const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents);
        const b = inputVariable('b', DataType.uint32, bShape.length, bComponents);
        const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length);
        const inputVariables = [a, b, scales];
        const zeroPoints =
            inputs.length === 4 ? inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined;
        if (zeroPoints) {
          inputVariables.push(zeroPoints);
        }
        const output = outputVariable('output', inputs[0].dataType, outputShape.length, components);
        const uniforms: UniformsArrayType = [
          {name: 'output_size', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'},
          {name: 'accuracy_level', type: 'u32'}, {name: 'bits', type: 'u32'}, {name: 'block_size', type: 'u32'}
        ];
        // Bytes of packed zero points per column: two 4-bit zero points per byte, rounded up.
        const zeroPointBytesPerCol = Math.floor((nBlocksPerCol + 1) / 2);
        const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);

        // Container for the 8 values unpacked from one 32-bit word, laid out so that indexing it with `j`
        // yields a value compatible with one element of `a` (scalar, vec2 or vec4).
        const qDqDataType = (() => {
          switch (aComponents) {
            case 1:
              return `array<${dataType}, 8>`;
            case 2:
              return `mat4x2<${dataType}>`;
            case 4:
              return `mat2x4<${dataType}>`;
            default:
              throw new Error(`${aComponents}-component is not supported.`);
          }
        })();

        const dequantizeImpl = `
        fn dequantize(quantized: ${qDqDataType}, zero_point: ${dataType}, scale: ${dataType}) -> ${qDqDataType} {
          ${(() => {
          if (aComponents === 1) {
            return `var dequantized = ${qDqDataType}(${
                Array.from({length: 8}, (_, i) => `(quantized[${i}] - zero_point) * scale`).join(', ')});
              return dequantized;`;
          } else {
            return `var zero_points: ${qDqDataType} = ${qDqDataType}(${Array(8).fill('zero_point').join(',')});
              return (quantized - zero_points) * scale;`;
          }
        })()}
        }`;
        const ortUnpack8x4snormImpl = `
        fn ortUnpack8x4snorm(value: u32) -> ${qDqDataType} {
          var quantized: ${qDqDataType};
          var offset: u32 = 0;
          let count: u32 = 4;
          for (var i: u32 = 0; i < 8u; i++) {
            var result = ${dataType}(extractBits(value, offset, count));
            ${(() => {
          switch (aComponents) {
            case 1:
              return 'quantized[i] = result;';
            case 2:
              return 'quantized[i / 2][i % 2] = result;';
            case 4:
              return 'quantized[i / 4][i % 4] = result;';
            default:
              throw new Error(`${aComponents}-component is not supported.`);
          }
        })()}
            offset += count;
          }
          return quantized;
        }`;

        // Advances to the next 4-bit zero point, loading the next 32-bit word when the current one is used up.
        const updateZeroPointIndex = zeroPoints ? `
          zero_point_offset += 4;
          if (zero_point_offset == 32) {
            zero_point_offset = 0;
            zero_point_index++;
            zero_point_word = ${zeroPoints.getByOffset('zero_point_index')};
          }` :
                                                  '';

        return `
        ${dequantizeImpl};
        ${ortUnpack8x4snormImpl};
        ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
        ${shaderHelper.mainStart()}
          ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
          var output_values: array<${output.type.value}, ${outputNumber}>;
          // NOTE(review): global_idx is decoded with the un-grouped row dimension while output_size is divided
          // by ${outputNumber}; this looks correct only while the batch count does not exceed the row grouping
          // factor - confirm for batched inputs.
          var output_indices = ${output.offsetToIndices('global_idx')};
          var n = ${output.indicesGet('output_indices', aRank - 1)};
          var m = ${output.indicesGet('output_indices', aRank - 2)};
          var a_indices: ${a.type.indices} = output_indices;
          // Two zero points are packed into one byte because uniforms.bits <= 4.
          // zero_point_offset is the bit offset of the current zero point within zero_point_word. The first
          // column handled here starts at a byte that is not necessarily 4-byte aligned, so the initial bit
          // offset is derived from the byte position inside the word.
          // TODO support zero_point_offset for bits > 4
          ${
            zeroPoints ? `
          var zero_point_byte_index: u32 = n * ${components * zeroPointBytesPerCol};
          var zero_point_index: u32 = zero_point_byte_index / 4;
          var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_index')};
          var zero_point_offset: u32 = (zero_point_byte_index % 4) * 8;` :
                         ''}
          var scale_index = n * ${nBlocksPerCol * components};
          var b_indices: ${b.type.indices};
          for (var c: u32 = 0; c < ${components}; c++) {
            ${b.indicesSet('b_indices', '0', `n * ${components} + c`)};
            var block_offset: u32 = 0;
            for (var block: u32 = 0; block < ${nBlocksPerCol}; block++) {
              // The scale and zero points are computed per block.
              let scale = ${scales.getByOffset('scale_index')};
              // The default zero point is 8 for unsigned 4-bit quantization.
              let zero_point = ${dataType}(${zeroPoints ? 'extractBits(zero_point_word, zero_point_offset, 4)' : 8.0});
              ${b.indicesSet('b_indices', '1', 'block')};
              var word_offset: u32 = block_offset;
              for (var word: u32 = 0; word < ${blobSizeInWords}; word += ${bComponents}) {
                // The last dimension of b counts ${bComponents}-word elements, so convert the word index.
                ${b.indicesSet('b_indices', '2', `word / ${bComponents}`)};
                let b_data = ${b.getByIndices('b_indices')};
                for (var i: u32 = 0; i < ${bComponents}; i++) {
                  let b_value = ${bComponents === 1 ? 'b_data' : 'b_data[i]'};
                  let b_quantized_values: ${qDqDataType} = ortUnpack8x4snorm(b_value);
                  let b_dequantized_values = dequantize(b_quantized_values, zero_point, scale);
                  // Number of B elements per 32-bit word is 32/bits = 32/4 = 8
                  var offset: u32 = word_offset;
                  for (var j: u32 = 0; j < 8/${aComponents}; j++) {
                    ${a.indicesSet('a_indices', aRank - 1, `offset/${aComponents}`)};
                    for (var k: u32 = 0; k < ${outputNumber}u; k++) {
                      ${a.indicesSet('a_indices', aRank - 2, `m * ${outputNumber} + k`)};
                      let a_data = ${a.getByIndices('a_indices')};
                      output_values[k]${components > 1 ? '[c]' : ''} += ${
            aComponents === 1 ? 'a_data * b_dequantized_values[j]' : 'dot(a_data, b_dequantized_values[j])'};
                    }
                    offset += ${aComponents};
                  }
                  word_offset += 8;
                }
              }
              scale_index++;
              ${updateZeroPointIndex}
              block_offset += uniforms.block_size;
            }
            // Drop the trailing 4 bits if the zero_point_offset is not a byte boundary to align with the next byte.
            ${
            zeroPoints ? `if (zero_point_offset % 8 > 0) {
                ${updateZeroPointIndex}
              }` :
                         ''}
            }
            for (var k: u32 = 0u; k < ${outputNumber}u; k++) {
              ${output.indicesSet('output_indices', aRank - 2, `${outputNumber} * m + k`)};
              ${output.setByIndices('output_indices', 'output_values[k]')}
            }
        }`;
      };
      return {
        name: 'MatMulNBits',
        // The generated WGSL embeds outputNumber, which is derived from the M dimension of the first input
        // and is covered neither by the attribute cache key nor by the 'rank' dependency, so it has to be
        // part of the hint to avoid reusing a shader generated for a different row grouping.
        shaderCache: {
          hint: `${attributes.cacheKey};${inputs.length};${outputNumber}`,
          inputDependencies: Array(inputs.length).fill('rank')
        },
        getRunData: () => ({
          outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
          dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
          programUniforms
        }),
        getShaderSource
      };
    };

/**
 * Operator entry point: validates the inputs held by `context`, then builds the MatMulNBits program and
 * hands it to `context.compute`.
 *
 * @param context - Compute context providing the input tensors and the `compute` method.
 * @param attributes - Operator attributes.
 * @throws Error when `validateInputs` rejects the inputs.
 */
export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => {
  const {inputs} = context;
  validateInputs(inputs, attributes);
  const programInfo = createMatMulNBitsProgramInfo(inputs, attributes);
  context.compute(programInfo);
};

/**
 * Converts a raw attribute record into `MatMulNBitsAttributes` by passing it through
 * `createAttributeWithCacheKey`. The record is cast, not validated.
 *
 * @param attributes - Raw operator attributes.
 * @returns The attributes as returned by `createAttributeWithCacheKey`.
 */
export const parseMatMulNBitsAttributes = (attributes: Record<string, unknown>): MatMulNBitsAttributes => {
  const rawAttributes = attributes as Omit<MatMulNBitsAttributes, keyof AttributeWithCacheKey>;
  return createAttributeWithCacheKey(rawAttributes);
};
