// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';

import {
  createTensorShapeVariables,
  getMaxComponents,
  inputVariable,
  outputVariable,
  ShaderHelper,
  tensorTypeToWsglStorageType,
} from './common';

/**
 * Attributes of the MatMulNBits operator as consumed by the WebGPU kernels in this file.
 *
 * TODO: the shader generators below only handle `bits` = 2 and `bits` = 4; other widths are rejected
 * by input validation.
 */
export interface MatMulNBitsAttributes extends AttributeWithCacheKey {
  /** Size of the shared (inner) dimension; must equal the last dimension of input A. */
  k: number;
  /** Number of output columns; also the first dimension of the packed B input. */
  n: number;
  /** Not read by the kernels in this file. */
  accuracyLevel: number;
  /** Bit width of each packed quantized value in input B. */
  bits: number;
  /** Number of consecutive values along `k` that share one scale (and one zero point). */
  blockSize: number;
}

/**
 * Validates MatMulNBits inputs against the attributes before any shader is generated.
 *
 * Expected inputs:
 *  - A: last dimension equal to `k`;
 *  - B: packed weights of shape [n, nBlocksPerCol, blobSize];
 *  - scales: `n * nBlocksPerCol` elements;
 *  - zero points (optional): packed, `n * ceil(nBlocksPerCol * bits / 8)` bytes.
 *
 * @throws Error when the input count, shapes, or `bits`/`blockSize` combination is not supported.
 */
const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): void => {
  if (inputs.length < 3 || inputs.length > 4) {
    throw new Error('MatMulNBits requires 3 or 4 inputs');
  }
  // The shader generators only know how to unpack 2-bit and 4-bit values. Any other width would emit
  // incorrect code (e.g. bits=8 yields zero unpack passes per word and an all-zero result), so reject it here.
  if (attributes.bits !== 2 && attributes.bits !== 4) {
    throw new Error(`MatMulNBits: bits=${attributes.bits} is not supported; only 2 and 4 are supported`);
  }
  // B is read as whole u32 words, so every block must occupy an integral number of 32-bit words.
  if ((attributes.blockSize * attributes.bits) % 32 !== 0) {
    throw new Error(
      `MatMulNBits: blockSize=${attributes.blockSize} with bits=${attributes.bits} does not fill whole 32-bit words`,
    );
  }
  const a = inputs[0];
  const aRank = a.dims.length;
  if (a.dims[aRank - 1] !== attributes.k) {
    throw new Error('The last dim of input shape does not match the k value');
  }
  const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize);
  const blobSize = (attributes.blockSize / 8) * attributes.bits;
  const b = inputs[1];
  if (!ShapeUtil.areEqual(b.dims, [attributes.n, nBlocksPerCol, blobSize])) {
    throw new Error('The second inputs must be 3D tensor with shape N X nBlocksPerCol X blobSize');
  }
  const scales = inputs[2];
  const scalesShape = scales.dims;
  if (ShapeUtil.size(scalesShape) !== attributes.n * nBlocksPerCol) {
    throw new Error('scales input size error.');
  }
  if (inputs.length === 4) {
    const zeroPoints = inputs[3];
    const zeroPointsShape = zeroPoints.dims;

    // This assumes zero points are packed.
    // Unpack format (zero point has same data type and shape as scale) is not supported by webgpu.
    // The bits === 8 branch is kept so the formula stays correct if 8-bit support is added later.
    const expectedZeroPointsSize =
      attributes.n * (attributes.bits === 8 ? nBlocksPerCol : Math.floor((nBlocksPerCol * attributes.bits + 7) / 8));
    if (ShapeUtil.size(zeroPointsShape) !== expectedZeroPointsSize) {
      throw new Error('zeroPoints input size error.');
    }
  }
};

/**
 * Creates the general MatMulNBits program (any block size).
 *
 * For each output element it accumulates
 *   sum_k A[batch, row, k] * (q[n, k] - zeroPoint[n, block(k)]) * scale[n, block(k)]
 * where `q` is unpacked from the u32 words of input B (2 or 4 bits per value) and block `b` covers
 * k in [b * blockSize, (b + 1) * blockSize). Without a zero-point input, the zero point is 2^(bits - 1).
 *
 * Each 64-thread workgroup produces `components * outputNumber` consecutive output columns of one row.
 * Threads stride over the K blocks and store partial sums in workgroup memory. After a barrier, the first
 * `outputNumber` threads reduce those sums and write the result.
 *
 * Callers are expected to have validated the inputs (matMulNBits does so through validateInputs).
 */
export const createMatMulNBitsProgramInfo = (
  inputs: readonly TensorView[],
  attributes: MatMulNBitsAttributes,
): ProgramInfo => {
  const inputShape = inputs[0].dims;
  const aRank = inputShape.length;
  const dimAOuter = inputShape[aRank - 2];
  const dimInner = attributes.k;
  const dimBOuter = attributes.n;
  const batchDims = inputShape.slice(0, aRank - 2);
  const batchSize = ShapeUtil.size(batchDims);
  const blobSize = inputs[1].dims[2];
  const blobSizeInWords = blobSize / 4;
  const dataType = inputs[0].dataType;
  const aComponents = getMaxComponents(attributes.k);
  const bComponents = getMaxComponents(blobSizeInWords);
  // B is bound as vec<bComponents, u32>, so one block spans this many elements of B's last dimension.
  const bVectorsPerBlock = blobSizeInWords / bComponents;
  const components = getMaxComponents(dimBOuter);
  const outputShape = batchDims.concat([dimAOuter, dimBOuter]);
  // Two output vectors per workgroup when A has more than one row and the vectorized N is even; otherwise one.
  const outputNumber = dimAOuter > 1 && (dimBOuter / components) % 2 === 0 ? 2 : 1;
  const dispatchSize = ShapeUtil.size(outputShape) / components / outputNumber;

  const workgroupSize = 64;

  const programUniforms: ProgramUniform[] = [];
  const inputShapeTemp = [batchSize, dimAOuter, dimInner / aComponents];
  const bShape = ShapeUtil.convertShape(inputs[1].dims).slice();
  bShape.splice(-1, 1, bVectorsPerBlock);
  programUniforms.push(...createTensorShapeVariables(inputShapeTemp));
  programUniforms.push(...createTensorShapeVariables(bShape));
  programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
  if (inputs.length === 4) {
    programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims)));
  }
  const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components];
  programUniforms.push(...createTensorShapeVariables(outputShapeTemp));

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const inputRank = inputShapeTemp.length;
    const a = inputVariable('a', inputs[0].dataType, inputRank, aComponents);
    const b = inputVariable('b', DataType.uint32, bShape.length, bComponents);
    const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length);
    const inputVariables = [a, b, scales];
    const zeroPoints =
      inputs.length === 4 ? inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined;
    if (zeroPoints) {
      inputVariables.push(zeroPoints);
    }
    const outputRank = outputShapeTemp.length;
    const output = outputVariable('output', inputs[0].dataType, outputRank, components);
    // WGSL storage type name of A/output (shadows the numeric `dataType` of the outer scope).
    const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);

    // WGSL type holding 8 values of A (or of dequantized B) for one pass.
    const qDqDataType = (() => {
      switch (aComponents) {
        case 1:
          return `array<${dataType}, 8>`;
        case 2:
          return `mat4x2<${dataType}>`;
        case 4:
          return `mat2x4<${dataType}>`;
        default:
          throw new Error(`${aComponents}-component is not supported.`);
      }
    })();

    // Number of quantized values per u32 word and passes needed (each pass extracts 8 values).
    const valuesPerWord = Math.floor(32 / attributes.bits); // Q4=8, Q2=16
    const passesPerWord = Math.floor(valuesPerWord / 8); // Q4=1, Q2=2

    // Emits the WGSL that consumes one u32 word of B for every output column handled by this workgroup.
    const processOneWord = (): string => {
      let calcStr = '';
      for (let pass = 0; pass < passesPerWord; pass++) {
        // All passes share one WGSL scope, so later passes get suffixed local names to avoid redeclaration.
        const sfx = pass > 0 ? `${pass}` : '';
        // For Q4 (pass=0): shift by 0 and 4. For Q2 the 2-bit path below re-spreads the bits instead.
        const lowerShift = pass * attributes.bits * 4;
        const upperShift = lowerShift + attributes.bits;
        calcStr += `
          // reuse a data (pass ${pass})
            var input_offset${sfx} = ${pass === 0 ? a.indicesToOffset(`${a.type.indices}(batch, row, word_offset)`) : `input_offset`};
            var a_data${sfx}: ${qDqDataType};
            for (var j${sfx}: u32 = 0; j${sfx} < ${8 / aComponents}; j${sfx}++) {
              a_data${sfx}[j${sfx}] = ${a.getByOffset(`input_offset${sfx}`)};
              input_offset${sfx}++;
            }
          `;
        for (let c = 0; c < components * outputNumber; c++) {
          const zeroPoint = zeroPoints ? `zero_point${c}` : 'zero_point';
          // Produces b_value_lower = values 0,2,4,6 and b_value_upper = values 1,3,5,7 of this pass.
          const unpackB =
            attributes.bits === 2
              ? `{
              let half_word = b_value >> ${pass * 16}u;
              let byte_lo = half_word & 0xFFu;
              let byte_hi = (half_word >> 8u) & 0xFFu;
              let spread_word = (byte_lo & 0xFu) | ((byte_lo >> 4u) << 8u) | ((byte_hi & 0xFu) << 16u) | ((byte_hi >> 4u) << 24u);
              b_value_lower = unpack4xU8(spread_word & b_mask);
              b_value_upper = unpack4xU8((spread_word >> 2u) & b_mask);
            }`
              : `b_value_lower = unpack4xU8((b_value >> ${lowerShift}u) & b_mask);
            b_value_upper = unpack4xU8((b_value >> ${upperShift}u) & b_mask);`;
          const quantized = Array.from(
            { length: 4 },
            (_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`,
          ).join(', ');
          const dequantize =
            aComponents === 1
              ? `${qDqDataType}(${Array.from(
                  { length: 8 },
                  (_, i) => `(b_quantized_values[${i}] - ${zeroPoint}) * scale${c}`,
                ).join(', ')});`
              : `(b_quantized_values - ${qDqDataType}(${Array(8).fill(zeroPoint).join(',')})) * scale${c};`;
          const products = Array.from({ length: 8 / aComponents }, (_, i) =>
            aComponents === 1
              ? `a_data${sfx}[${i}] * b_dequantized_values[${i}]`
              : `dot(a_data${sfx}[${i}], b_dequantized_values[${i}])`,
          ).join(' + ');
          calcStr += `
            b_value = ${bComponents === 1 ? `b${c}_data` : `b${c}_data[i]`};
            ${unpackB}
            b_quantized_values = ${qDqDataType}(${quantized});
            b_dequantized_values = ${dequantize}
            workgroup_shared[local_id.x * ${outputNumber} + ${Math.floor(c / components)}]${components > 1 ? `[${c % components}]` : ''} += ${products};
          `;
        }
      }
      return calcStr;
    };
    // Emits scale${c} (and zero_point${c}, or a shared default zero_point) for every handled column.
    const prepareScaleAndZeroPoint = (): string => {
      let calcStr = `
            var col_index = col * ${components};
            ${
              zeroPoints
                ? `
            let zero_point_values_per_byte: u32 = ${Math.floor(8 / attributes.bits)}u;
            let zero_point_bytes_per_col = (nBlocksPerCol + zero_point_values_per_byte - 1u) / zero_point_values_per_byte;
            var zero_point_byte_count: u32;
            var zero_point_word_index: u32;
            var zero_point_byte_offset: u32;
            let zero_point_sub_offset: u32 = block % zero_point_values_per_byte;
            var zero_point_bits_offset: u32;
            var zero_point_word: u32;`
                : `
            // The default zero point is ${Math.pow(2, attributes.bits - 1)} for unsigned ${attributes.bits}-bit quantization.
            let zero_point = ${dataType}(${Math.pow(2, attributes.bits - 1).toFixed(1)});`
            }
            `;
      for (let c = 0; c < components * outputNumber; c++) {
        calcStr += `
            let scale${c} = ${scales.getByOffset(`col_index * nBlocksPerCol + block`)};
            ${
              zeroPoints
                ? `
            zero_point_byte_count = col_index * zero_point_bytes_per_col + (block / zero_point_values_per_byte);
            zero_point_word_index = zero_point_byte_count >> 0x2u;
            zero_point_byte_offset = zero_point_byte_count & 0x3u;
            zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_sub_offset * ${attributes.bits}u);
            zero_point_word = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset;
            let zero_point${c} = ${dataType}((zero_point_word) & ${attributes.bits === 2 ? '0x3u' : '0xFu'});`
                : ''
            }
            col_index += 1;`;
      }
      return calcStr;
    };
    // Emits the loads of the current B vector (index `word` in B's last dimension) for every handled column.
    const prepareBData = (): string => {
      let calcStr = `col_index = col * ${components};`;
      for (let c = 0; c < components * outputNumber; c++) {
        calcStr += `
            let b${c}_data = ${b.getByIndices(`${b.type.indices}(col_index, block, word)`)};
            col_index += 1;`;
      }
      calcStr += `
            var b_value: u32;
            let b_mask: u32 = ${attributes.bits === 2 ? '0x03030303u' : '0x0F0F0F0Fu'};
            var b_value_lower: vec4<u32>;
            var b_value_upper: vec4<u32>;
            var b_quantized_values: ${qDqDataType};
            var b_dequantized_values: ${qDqDataType};`;
      return calcStr;
    };
    return `
        var<workgroup> workgroup_shared: array<${output.type.value}, ${outputNumber * workgroupSize}>;
        ${shaderHelper.declareVariables(...inputVariables, output)}
        ${shaderHelper.mainStart([workgroupSize, 1, 1])}
          let output_indices = ${output.offsetToIndices(`(global_idx / ${workgroupSize}) * ${outputNumber}`)};
          let col = output_indices[2];
          let row = output_indices[1];
          let batch = output_indices[0];
          let nBlocksPerCol = uniforms.b_shape[1];

          for (var block = local_id.x; block < nBlocksPerCol; block += ${workgroupSize}) {
            //process one block
            var word_offset: u32 = block * ${attributes.blockSize / aComponents};
            ${prepareScaleAndZeroPoint()}
            // 'word' indexes B's last dimension: one vec<${bComponents}, u32> per iteration.
            for (var word: u32 = 0; word < ${bVectorsPerBlock}; word++) {
              ${prepareBData()}
              for (var i: u32 = 0; i < ${bComponents}; i++) {
                ${processOneWord()}
                word_offset += ${valuesPerWord / aComponents};
              }
            }
          }
          workgroupBarrier();

          if (local_id.x < ${outputNumber}) {
            var output_value: ${output.type.value} = ${output.type.value}(0);
            var workgroup_shared_offset: u32 = local_id.x;
            for (var b: u32 = 0u; b < ${workgroupSize}u; b++) {
              output_value += workgroup_shared[workgroup_shared_offset];
              workgroup_shared_offset += ${outputNumber};
            }
            ${output.setByIndices(`${output.type.indices}(batch, row, col + local_id.x)`, 'output_value')};
          }
        }`;
  };
  return {
    name: 'MatMulNBits',
    shaderCache: {
      hint: `${attributes.blockSize};${attributes.bits};${aComponents};${bComponents};${components};${outputNumber};${workgroupSize}`,
      inputDependencies: Array(inputs.length).fill('rank'),
    },
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType }],
      dispatchGroup: { x: dispatchSize },
      programUniforms,
    }),
    getShaderSource,
  };
};

/**
 * Creates a tiled MatMulNBits program. It currently supports only blockSize = 32.
 *
 * Each workgroup is `workgroupX x workgroupY` threads (128 total) and produces `workgroupY` consecutive
 * output columns of one row. For every tile of K, the workgroup first copies the A values into `sub_a`
 * in workgroup memory, zero-padding past the end of A. Each thread then dequantizes one block of B for
 * column `col + local_id.y` and accumulates into `inter_results`. Finally, the first `workgroupY` threads
 * sum their row of `inter_results` and write the result.
 *
 * Callers are expected to have validated the inputs (matMulNBits does so through validateInputs).
 */
export const createMatMulNBitsBlockSize32ProgramInfo = (
  inputs: readonly TensorView[],
  attributes: MatMulNBitsAttributes,
): ProgramInfo => {
  const inputShape = inputs[0].dims;
  const aRank = inputShape.length;
  const dimAOuter = inputShape[aRank - 2];
  const dimInner = attributes.k;
  const dimBOuter = attributes.n;
  const batchDims = inputShape.slice(0, aRank - 2);
  const batchSize = ShapeUtil.size(batchDims);
  const blobSize = inputs[1].dims[2];
  const blobSizeInWords = blobSize / 4;
  const dataType = inputs[0].dataType;
  const aComponents = getMaxComponents(attributes.k);
  const bComponents = getMaxComponents(blobSizeInWords);
  const outputShape = batchDims.concat([dimAOuter, dimBOuter]);

  const workgroupSize = 128;
  const workgroupY = dimBOuter % 8 === 0 ? 8 : dimBOuter % 4 === 0 ? 4 : 1;
  const workgroupX = workgroupSize / workgroupY;
  const valuesPerWordBs32 = Math.floor(32 / attributes.bits); // Q4=8, Q2=16
  // Each pass below consumes 8 quantized values of one u32 word.
  const passesPerWordBs32 = Math.floor(valuesPerWordBs32 / 8); // Q4=1, Q2=2
  const tileSize = workgroupX * bComponents * valuesPerWordBs32; // each uint32 has valuesPerWord data.
  const aLengthPerTile = tileSize / aComponents;
  const blocksPerTile = tileSize / attributes.blockSize;
  const dispatchSize = ShapeUtil.size(outputShape) / workgroupY;

  const programUniforms: ProgramUniform[] = [];
  const inputShapeTemp = [batchSize, dimAOuter, dimInner / aComponents];
  const bShape = ShapeUtil.convertShape(inputs[1].dims).slice();
  bShape.splice(-1, 1, blobSizeInWords / bComponents);
  programUniforms.push(...createTensorShapeVariables(inputShapeTemp));
  programUniforms.push(...createTensorShapeVariables(bShape));
  programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
  if (inputs.length === 4) {
    programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims)));
  }
  const outputShapeTemp = [batchSize, dimAOuter, dimBOuter];
  programUniforms.push(...createTensorShapeVariables(outputShapeTemp));

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const inputRank = inputShapeTemp.length;
    const a = inputVariable('a', inputs[0].dataType, inputRank, aComponents);
    const b = inputVariable('b', DataType.uint32, bShape.length, bComponents);
    const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length);
    const inputVariables = [a, b, scales];
    const zeroPoints =
      inputs.length === 4 ? inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined;
    if (zeroPoints) {
      inputVariables.push(zeroPoints);
    }
    const outputRank = outputShapeTemp.length;
    const output = outputVariable('output', inputs[0].dataType, outputRank);
    // WGSL storage type name of A/output (shadows the numeric `dataType` of the outer scope).
    const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
    // Reads 8 consecutive A values starting at `word_offset` in sub_a as two vec4s.
    const readA = () => {
      switch (aComponents) {
        case 1:
          return `
          let a_data0 = vec4<${dataType}>(sub_a[word_offset], sub_a[word_offset + 1], sub_a[word_offset + 2], sub_a[word_offset + 3]);
          let a_data1 = vec4<${dataType}>(sub_a[word_offset + 4], sub_a[word_offset + 5], sub_a[word_offset + 6], sub_a[word_offset + 7]);`;
        case 2:
          return `
          let a_data0 = vec4<${dataType}>(sub_a[word_offset], sub_a[word_offset + 1]);
          let a_data1 = vec4<${dataType}>(sub_a[word_offset + 2], sub_a[word_offset + 3]);`;
        case 4:
          return `
          let a_data0 = sub_a[word_offset];
          let a_data1 = sub_a[word_offset + 1];`;
        default:
          throw new Error(`${aComponents}-component is not supported.`);
      }
    };

    // Emits one 8-value pass over `b_value`. The whole pass, including the A reads, lives in its own WGSL
    // block so that several passes per word (2-bit weights) do not redeclare a_data0/a_data1 in one scope.
    const processPass = (pass: number): string => {
      const lowerShift = pass * attributes.bits * 4;
      const upperShift = lowerShift + attributes.bits;
      // Produces b_value_lower = values 0,2,4,6 and b_value_upper = values 1,3,5,7 of this pass.
      const unpackB =
        attributes.bits === 2
          ? `
                let half_word = b_value >> ${pass * 16}u;
                let byte_lo = half_word & 0xFFu;
                let byte_hi = (half_word >> 8u) & 0xFFu;
                let spread_word = (byte_lo & 0xFu) | ((byte_lo >> 4u) << 8u) | ((byte_hi & 0xFu) << 16u) | ((byte_hi >> 4u) << 24u);
                let b_value_lower = unpack4xU8(spread_word & 0x03030303u);
                let b_value_upper = unpack4xU8((spread_word >> 2u) & 0x03030303u);`
          : `
                let b_value_lower = unpack4xU8((b_value >> ${lowerShift}u) & 0x0F0F0F0Fu);
                let b_value_upper = unpack4xU8((b_value >> ${upperShift}u) & 0x0F0F0F0Fu);`;
      const quantized = Array.from(
        { length: 4 },
        (_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`,
      ).join(', ');
      const zeroPointMatrix = Array(8).fill('zero_point').join(',');
      const dotSum = Array.from({ length: 2 }, (_, i) => `dot(a_data${i}, b_dequantized_values[${i}])`).join(' + ');
      return `
              {
                ${readA()}
                ${unpackB}
                let b_quantized_values = mat2x4<${dataType}>(${quantized});
                let b_dequantized_values = (b_quantized_values - mat2x4<${dataType}>(${zeroPointMatrix})) * scale;
                inter_results[local_id.y][local_id.x] += ${dotSum};
              }
              word_offset += ${8 / aComponents};`;
    };
    const processWord = Array.from({ length: passesPerWordBs32 }, (_, pass) => processPass(pass)).join('');

    return `
        var<workgroup> sub_a: array<${a.type.value}, ${aLengthPerTile}>;
        var<workgroup> inter_results: array<array<${output.type.value}, ${workgroupX}>, ${workgroupY}>;
        ${shaderHelper.declareVariables(...inputVariables, output)}
        ${shaderHelper.mainStart([workgroupX, workgroupY, 1])}
          let output_indices = ${output.offsetToIndices(`workgroup_index * ${workgroupY}`)};
          let col = output_indices[2];
          let row = output_indices[1];
          let batch = output_indices[0];
          let n_blocks_per_col = uniforms.b_shape[1];
          let num_tiles =  (n_blocks_per_col - 1) / ${blocksPerTile} + 1;

          // Loop over shared dimension.
          for (var tile: u32 = 0; tile < num_tiles; tile += 1) {
            let a_col_start = tile * ${aLengthPerTile};
            // load one tile A data into shared memory.
            for (var a_offset = local_idx; a_offset < ${aLengthPerTile}; a_offset += ${workgroupSize})
            {
              let a_col = a_col_start + a_offset;
              if (a_col < uniforms.a_shape[2])
              {
                sub_a[a_offset] = ${a.getByIndices(`${a.type.indices}(batch, row, a_col)`)};
              } else {
                sub_a[a_offset] = ${a.type.value}(0);
              }
            }
            workgroupBarrier();

            // each thread process one block
            let b_row = col + local_id.y;
            let block = tile * ${blocksPerTile} + local_id.x;
            ${
              zeroPoints
                ? `
            let zero_point_values_per_byte: u32 = ${Math.floor(8 / attributes.bits)}u;
            let zero_point_bytes_per_col = (n_blocks_per_col + zero_point_values_per_byte - 1u) / zero_point_values_per_byte;
            let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block / zero_point_values_per_byte);
            let zero_point_word_index = zero_point_byte_count >> 0x2u;
            let zero_point_byte_offset = zero_point_byte_count & 0x3u;
            let zero_point_sub_offset: u32 = block % zero_point_values_per_byte;
            let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_sub_offset * ${attributes.bits}u);
            let zero_point_word = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset;
            let zero_point = ${dataType}((zero_point_word) & ${attributes.bits === 2 ? '0x3u' : '0xFu'});`
                : `
            // The default zero point is ${Math.pow(2, attributes.bits - 1)} for unsigned ${attributes.bits}-bit quantization.
            let zero_point = ${dataType}(${Math.pow(2, attributes.bits - 1).toFixed(1)});`
            }
            let scale = ${scales.getByOffset(`b_row * n_blocks_per_col + block`)};
            let b_data = ${b.getByIndices(`${b.type.indices}(b_row, block, 0)`)};
            var word_offset = local_id.x * ${attributes.blockSize / aComponents};
            for (var i: u32 = 0; i < ${bComponents}; i++) {
              let b_value = ${bComponents === 1 ? `b_data` : `b_data[i]`};
              ${processWord}
            }
            workgroupBarrier();
          }

          if (local_idx < ${workgroupY}) {
            var output_value: ${output.type.value} = ${output.type.value}(0);
            for (var b = 0u; b < ${workgroupX}; b++) {
              output_value += inter_results[local_idx][b];
            }
            if (col + local_idx < uniforms.output_shape[2])
            {
              ${output.setByIndices(`${output.type.indices}(batch, row, col + local_idx)`, 'output_value')}
            }
          }
        }`;
  };
  return {
    name: 'BlockwiseMatMulNBits32',
    shaderCache: {
      // `bits` changes the generated shader (masks, passes, tile size), so it must be part of the key.
      hint: `${attributes.blockSize};${attributes.bits};${aComponents};${bComponents};${workgroupX};${workgroupY}`,
      inputDependencies: Array(inputs.length).fill('rank'),
    },
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType }],
      dispatchGroup: { x: dispatchSize },
      programUniforms,
    }),
    getShaderSource,
  };
};

/**
 * Runs MatMulNBits on the given compute context.
 *
 * Inputs are validated first. The tiled blockSize-32 kernel is chosen only for blockSize 32 on Intel
 * gen-12lp adapters; every other case uses the general kernel.
 */
export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => {
  validateInputs(context.inputs, attributes);
  const useBlockSize32Kernel =
    attributes.blockSize === 32 &&
    context.adapterInfo.isVendor('intel') &&
    context.adapterInfo.isArchitecture('gen-12lp');
  const programInfo = useBlockSize32Kernel
    ? createMatMulNBitsBlockSize32ProgramInfo(context.inputs, attributes)
    : createMatMulNBitsProgramInfo(context.inputs, attributes);
  context.compute(programInfo);
};

/**
 * Treats the raw attribute record as MatMulNBits attributes and wraps it with createAttributeWithCacheKey.
 * The cast is unchecked: field presence and types are not verified here.
 */
export const parseMatMulNBitsAttributes = (attributes: Record<string, unknown>): MatMulNBitsAttributes => {
  const rawAttributes = attributes as Omit<MatMulNBitsAttributes, keyof AttributeWithCacheKey>;
  return createAttributeWithCacheKey(rawAttributes);
};
