// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {Tensor} from '../../../tensor';
import {WebGLInferenceHandler} from '../inference-handler';

import {calculateOutputShape, ConvAttributes} from './conv';
import {createPackedIm2ColProgramInfoLoader} from './im2col-pack';
import {createPackedMatmulProgramInfoLoader} from './matmul-pack';

/**
 * Runs a packed 2D convolution by expressing it directly as a single packed matrix multiplication.
 *
 * The input tensor is reshaped to `[xshape[1], xshape[2] * xshape[3]]` and the kernel to
 * `[kshape[0], kshape[1]]`; the two are multiplied (kernel on the left) and the product is reshaped
 * to the convolution output shape.
 *
 * NOTE(review): the reshapes drop `xshape[0]`, `kshape[2]` and `kshape[3]`, so this presumably expects
 * a batch size of 1 and a 1x1 kernel - confirm against the caller that selects this path.
 *
 * @param inferenceHandler handler used to reshape packed tensors and to run the matmul program.
 * @param inputs `inputs[0]` is the convolution input and `inputs[1]` the kernel; when more than two
 *     tensors are supplied, `inputs[2]` is forwarded as the third matmul input (presumably the bias).
 * @param attributes convolution attributes; dilations, pads and strides determine the output shape,
 *     and the whole object is forwarded to the matmul program loader.
 * @returns the matmul result reshaped to the computed convolution output shape.
 */
export const conv2DPackedPointwise =
    (inferenceHandler: WebGLInferenceHandler, inputs: readonly Tensor[], attributes: ConvAttributes): Tensor => {
      const x = inputs[0];
      const kernel = inputs[1];
      const xDims = x.dims;
      const kernelDims = kernel.dims;
      const {dilations, pads, strides} = attributes;
      const convOutputDims = calculateOutputShape(xDims, kernelDims, dilations, pads, strides);

      // Collapse both operands to 2-D so the convolution becomes kernel2D * x2D.
      const x2D = inferenceHandler.reshapePacked(x, [xDims[1], xDims[2] * xDims[3]]);
      const kernel2D = inferenceHandler.reshapePacked(kernel, [kernelDims[0], kernelDims[1]]);

      // The kernel is the left-hand operand; an optional third input rides along unchanged.
      const matmulOperands = [kernel2D, x2D];
      if (inputs.length > 2) {
        matmulOperands.push(inputs[2]);
      }

      const matmulLoader = createPackedMatmulProgramInfoLoader(inferenceHandler, matmulOperands, attributes);
      const product = inferenceHandler.run(matmulLoader, matmulOperands);
      return inferenceHandler.reshapePacked(product, convOutputDims);
    };

/**
 * Runs a packed 2D convolution as im2col followed by a packed matrix multiplication.
 *
 * Steps, in order:
 *  1. compute the convolution output shape from the input/kernel dims and the attributes;
 *  2. run the packed im2col program on the input tensor;
 *  3. reshape the kernel to `[kshape[0], kshape[1] * kshape[2] * kshape[3]]`;
 *  4. multiply the reshaped kernel (left operand) by the im2col result;
 *  5. reshape the product to the convolution output shape.
 *
 * @param inferenceHandler handler used to run the im2col and matmul programs and to reshape packed
 *     tensors.
 * @param inputs `inputs[0]` is the convolution input and `inputs[1]` the kernel; when exactly three
 *     tensors are supplied, `inputs[2]` is forwarded as the third matmul input (presumably the bias).
 * @param attributes convolution attributes; dilations, pads and strides determine the output shape,
 *     and the whole object is forwarded to both program loaders.
 * @returns the matmul result reshaped to the computed convolution output shape.
 */
export const conv2DPacked =
    (inferenceHandler: WebGLInferenceHandler, inputs: readonly Tensor[], attributes: ConvAttributes): Tensor => {
      const x = inputs[0];
      const kernel = inputs[1];
      const xDims = x.dims;
      const kernelDims = kernel.dims;
      const {dilations, pads, strides} = attributes;
      const convOutputDims = calculateOutputShape(xDims, kernelDims, dilations, pads, strides);

      // Step 1: im2col over the input tensor; only the input is bound as a program input.
      const im2colLoader = createPackedIm2ColProgramInfoLoader(inferenceHandler, x, kernel, convOutputDims, attributes);
      const columns = inferenceHandler.run(im2colLoader, [x]);

      // Step 2: flatten every kernel dimension after the first into a single axis.
      const flattenedKernelSize = kernelDims[1] * kernelDims[2] * kernelDims[3];
      const kernel2D = inferenceHandler.reshapePacked(kernel, [kernelDims[0], flattenedKernelSize]);

      // Step 3: kernel2D * columns, with the third input forwarded only when exactly three inputs are given.
      const hasThirdInput = inputs.length === 3;
      const matmulOperands = hasThirdInput ? [kernel2D, columns, inputs[2]] : [kernel2D, columns];
      const matmulLoader = createPackedMatmulProgramInfoLoader(inferenceHandler, matmulOperands, attributes);
      const product = inferenceHandler.run(matmulLoader, matmulOperands);

      // Step 4: restore the convolution output layout.
      return inferenceHandler.reshapePacked(product, convOutputDims);
    };
