// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {env} from 'onnxruntime-common';

import {DataType} from '../../../wasm-common';
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';

import {createTensorShapeVariables, outputVariable, ShaderHelper, UniformDataElementType, UniformsArrayType} from './common';

/**
 * Optional sanity check on the scalar Range inputs (this file calls it only when
 * `env.webgpu.validateInputContent` is set).
 *
 * Rejects inputs that would yield an empty or ill-defined sequence:
 * - `start === limit` (nothing to generate),
 * - `delta === 0` (the element count `(limit - start) / delta` would be infinite or NaN),
 * - a step whose sign points away from `limit`.
 *
 * @param start first value of the sequence.
 * @param limit exclusive bound of the sequence.
 * @param delta step between consecutive values.
 * @throws Error when any of the conditions above holds.
 */
const validateInputsContent = (start: number, limit: number, delta: number): void => {
  const sameStartLimit = start === limit;
  // Without this check a zero step slips through whenever start !== limit.
  const zeroStep = delta === 0;
  const increasingRangeNegativeStep = start < limit && delta < 0;
  const decreasingRangePositiveStep = start > limit && delta > 0;

  if (sameStartLimit || zeroStep || increasingRangeNegativeStep || decreasingRangePositiveStep) {
    throw new Error('Range these inputs\' contents are invalid.');
  }
};

/**
 * Builds the WebGPU program that fills a 1-D output with `start + i * delta` for `i` in `[0, numElements)`.
 *
 * The element count follows the ONNX Range definition, `max(ceil((limit - start) / delta), 0)`. A range whose
 * step points away from `limit` therefore yields an empty tensor, not a sequence running in the wrong direction.
 *
 * @param start first value of the sequence (passed to the shader as a uniform).
 * @param limit exclusive bound of the sequence (only used to size the output).
 * @param delta step between consecutive values (passed to the shader as a uniform); must be non-zero.
 * @param dataType element type of the output, also used for the `start`/`delta` uniforms.
 * @returns the program description consumed by `ComputeContext.compute`.
 * @throws Error when `delta` is zero, because the element count would be infinite or NaN.
 */
const createRangeProgramInfo = (start: number, limit: number, delta: number, dataType: DataType): ProgramInfo => {
  if (delta === 0) {
    throw new Error('Range: delta must be non-zero.');
  }
  // ONNX: number_of_elements = max(ceil((limit - start) / delta), 0). Math.abs would turn an inverted range
  // (e.g. start=0, limit=5, delta=-1) into 5 elements instead of 0.
  const numElements = Math.max(Math.ceil((limit - start) / delta), 0);
  const outputShape: number[] = [numElements];
  const outputSize = numElements;
  const programUniforms: ProgramUniform[] = [
    {type: DataType.uint32, data: outputSize}, {type: dataType, data: start}, {type: dataType, data: delta},
    ...createTensorShapeVariables(outputShape)
  ];

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const output = outputVariable('output', dataType, outputShape.length);
    const wgslType = output.type.value;
    // Uniform order must match the order of `programUniforms` above.
    const uniforms: UniformsArrayType = [
      {name: 'outputSize', type: 'u32'}, {name: 'start', type: wgslType as UniformDataElementType},
      {name: 'delta', type: wgslType as UniformDataElementType}
    ];
    // One invocation per output element; the guard presumably returns early for invocations past outputSize.
    return `
        ${shaderHelper.registerUniforms(uniforms).declareVariables(output)}
        ${shaderHelper.mainStart()}
        ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
        output[global_idx] = uniforms.start + ${wgslType}(global_idx) * uniforms.delta;
      }`;
  };

  return {
    name: 'Range',
    // start/delta/size are uniforms, so the compiled shader depends only on the element type.
    shaderCache: {hint: `${dataType}`},
    getShaderSource,
    getRunData: () => ({
      outputs: [{dims: outputShape, dataType}],
      dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
      programUniforms
    })
  };
};

/**
 * Entry point for the Range operator: reads the scalar `start`, `limit` and `delta` inputs and dispatches a
 * program that generates the sequence on the GPU.
 *
 * All three scalars are read with the element type of input 0. The inputs are read directly on the host via
 * `getInt32Array`/`getFloat32Array`; presumably the kernel registers them as CPU-resident (confirm against
 * the kernel registration).
 *
 * @param context compute context whose inputs 0, 1, 2 are `start`, `limit`, `delta`.
 * @throws Error when the input element type is neither int32 nor float, or (when
 *   `env.webgpu.validateInputContent` is enabled) when the input values are rejected by validation.
 */
export const range = (context: ComputeContext): void => {
  const dataType = context.inputs[0].dataType;
  // Previously an unsupported type silently left all three values at 0, producing a misleading failure later.
  if (dataType !== DataType.int32 && dataType !== DataType.float) {
    throw new Error(`Range: unsupported input data type ${dataType}.`);
  }

  // Reads the first element of input `index` using the operator's element type.
  const readScalar = (index: number): number => dataType === DataType.int32 ?
      context.inputs[index].getInt32Array()[0] :
      context.inputs[index].getFloat32Array()[0];

  const start = readScalar(0);
  const limit = readScalar(1);
  const delta = readScalar(2);

  if (env.webgpu.validateInputContent) {
    validateInputsContent(start, limit, delta);
  }

  // The program takes no GPU inputs: every value it needs is passed as a uniform.
  context.compute(createRangeProgramInfo(start, limit, delta, dataType), {inputs: []});
};
