import * as logSymbols from 'log-symbols';
import {
  CliFlagInput,
  CliLeaf,
  CliNumberArrayInput,
  CliNumberInput,
  CliOneOfInput,
  CliStringArrayInput,
  CliStringInput,
  CliTerseError,
  CliUsageError
} from '@alwaysai/alwayscli';
import { echo, keyMirror, logger } from '../../util';
import {
  ModelJson,
  ModelJsonParameters,
  modelPurposeValues,
  HailoArchitecture,
  hailoArchitectureEnum,
  hailoArchitectureValues,
  HailoFormat,
  hailoFormatValues,
  HailoPurpose,
  hailoPurposeEnum,
  modelFrameworkValues,
  OnnxArchitectureObjectDetection,
  onnxArchitectureObjectDetectionEnum,
  onnxArchitectureObjectDetectionValues,
  TensorrtArchitectureObjectDetection,
  tensorrtArchitectureObjectDetectionEnum,
  tensorrtArchitectureObjectDetectionValues,
  TensorrtDevice,
  tensorrtDeviceEnum,
  tensorrtDeviceValues,
  hailoPurposeValues,
  validateModel
} from '@alwaysai/model-configuration-schemas';
import { ModelPackageJsonFile } from '../../core/model/model-package-json-file';
import { CliAuthenticationClient } from '../../infrastructure';

// Devices accepted by --device: the key-mirrored names of the TensorRT device enum.
const mirroredTensorrtDevices = keyMirror(tensorrtDeviceEnum);
const ModelDeviceValues = Object.values(mirroredTensorrtDevices);

// Architectures accepted by --architecture: the union of every framework's
// architecture enum, key-mirrored so the CLI choice list uses the enum names.
const combinedArchitectureEnum = {
  ...tensorrtArchitectureObjectDetectionEnum,
  ...hailoArchitectureEnum,
  ...onnxArchitectureObjectDetectionEnum
};
const ModelArchitectureValues = Object.values(keyMirror(combinedArchitectureEnum));

/**
 * `model configure` leaf command.
 *
 * Builds a framework-specific `ModelJsonParameters` object from the CLI
 * options, wraps it in a `ModelJson` skeleton whose id is derived from the
 * authenticated username and `--name`, validates the result against the
 * model-configuration schema, and writes it to `alwaysai.model.json` in the
 * current working directory.
 *
 * Throws `CliUsageError` for unsupported option combinations and
 * `CliTerseError` when schema validation or the file write fails.
 */
export const modelConfigure = CliLeaf({
  name: 'configure',
  description: 'Generate or modify the model configuration of the application.',
  namedInputs: {
    name: CliStringInput({
      description: 'The model name to be used to generate the model ID',
      required: true
    }),
    // Required fields for all frameworks
    framework: CliOneOfInput({
      description: 'The machine learning framework of the model.',
      required: true,
      values: modelFrameworkValues
    }),
    model_file: CliStringInput({
      description: 'Path to model binary file.'
    }),
    mean: CliNumberArrayInput({
      description:
        'The average pixel intensity in the red, green, and blue channels of the training dataset.'
    }),
    scalefactor: CliNumberInput({
      description: 'Factor to scale pixel intensities by.'
    }),
    size: CliNumberArrayInput({
      description: 'The input size of the neural network.'
    }),
    purpose: CliOneOfInput({
      description: 'Computer vision purpose of the model.',
      required: true,
      values: modelPurposeValues
    }),
    crop: CliFlagInput({
      description: 'Crop before resize?'
    }),

    // Optional fields
    config_file: CliStringInput({
      description: 'Path to model structure.'
    }),
    label_file: CliStringInput({
      description: 'File containing labels for each class index.'
    }),
    colors_file: CliStringInput({
      description: 'File containing colors to be used by each class index.'
    }),
    swaprb: CliFlagInput({
      description: 'Swap red and blue channels after image blob generation'
    }),
    softmax: CliFlagInput({
      description:
        'Apply softmax to the output of the neural network? Boolean true/false'
    }),
    batch_size: CliNumberInput({
      description: 'The inference batch size of the model'
    }),
    output_layer_names: CliStringArrayInput({
      description: 'List of output layers provided in advance',
      placeholder: '<>'
    }),
    device: CliOneOfInput({
      description: 'Define the device model is intended to be used on.',
      values: ModelDeviceValues
    }),
    architecture: CliOneOfInput({
      description: 'Define the architecture type intended to be used.',
      values: ModelArchitectureValues
    }),
    quantize_input: CliFlagInput({
      description: 'Quantize input? Boolean true/false.'
    }),
    quantize_output: CliFlagInput({
      description: 'Quantize output? Boolean true/false.'
    }),
    input_format: CliOneOfInput({
      description: 'Define the input format of the data.',
      values: hailoFormatValues
    }),
    output_format: CliOneOfInput({
      description: 'Define the output format of the data.',
      values: hailoFormatValues
    })
  },
  async action(_, opts) {
    const {
      name,
      framework,
      model_file,
      config_file,
      mean,
      scalefactor,
      size,
      purpose,
      crop,
      label_file,
      colors_file,
      swaprb,
      softmax,
      batch_size,
      output_layer_names,
      device,
      architecture,
      quantize_input,
      quantize_output,
      input_format,
      output_format
    } = opts;

    let modelParameters: ModelJsonParameters;

    // Each framework accepts a different parameter set; unset optional
    // options fall back to schema-friendly defaults via `??` (NOT `||`,
    // which would silently discard legitimate falsy values such as
    // --scalefactor 0).
    switch (framework) {
      case 'tensorflow': {
        modelParameters = {
          framework_type: 'tensorflow',
          model_file: model_file ?? '',
          label_file: label_file ?? '',
          mean: mean ?? [0, 0, 0],
          scalefactor: scalefactor ?? 1,
          size: size ?? [300, 300],
          purpose,
          crop,
          config_file: config_file ?? '',
          colors_file: colors_file ?? '',
          swaprb,
          softmax
        };
        break;
      }

      case 'caffe': {
        modelParameters = {
          framework_type: 'caffe',
          config_file: config_file ?? '',
          size: size ?? [300, 300],
          model_file: model_file ?? '',
          label_file: label_file ?? '',
          scalefactor: scalefactor ?? 1,
          mean: mean ?? [0, 0, 0],
          crop,
          swaprb,
          softmax,
          purpose,
          output_layer_names: output_layer_names ?? ['', '']
        };
        break;
      }

      case 'enet': {
        modelParameters = {
          framework_type: 'enet',
          size: size ?? [300, 300],
          model_file: model_file ?? '',
          label_file: label_file ?? '',
          colors_file: colors_file ?? '',
          scalefactor: scalefactor ?? 1,
          mean: mean ?? [0, 0, 0],
          crop,
          swaprb,
          purpose
        };
        break;
      }

      case 'darknet': {
        modelParameters = {
          framework_type: 'darknet',
          config_file: config_file ?? '',
          size: size ?? [300, 300],
          model_file: model_file ?? '',
          label_file: label_file ?? '',
          colors_file: colors_file ?? '',
          scalefactor: scalefactor ?? 1,
          mean: mean ?? [0, 0, 0],
          crop,
          swaprb,
          purpose,
          output_layer_names: output_layer_names ?? null
        };
        break;
      }

      case 'onnx': {
        // --architecture is only meaningful for ObjectDetection, and even
        // then only values from the ONNX object-detection enum are valid.
        if (purpose === 'ObjectDetection') {
          if (
            architecture &&
            !(architecture in onnxArchitectureObjectDetectionEnum)
          ) {
            throw new CliUsageError(
              `Architecture not supported! (${onnxArchitectureObjectDetectionValues})`
            );
          }
          modelParameters = {
            framework_type: 'onnx',
            size: size ?? [300, 300],
            model_file: model_file ?? '',
            label_file: label_file ?? '',
            colors_file: colors_file ?? '',
            scalefactor: scalefactor ?? 1,
            crop,
            swaprb,
            purpose,
            mean: mean ?? [0, 0, 0],
            output_layer_names: output_layer_names ?? null,
            architecture: architecture as OnnxArchitectureObjectDetection
          };
        } else {
          // Purpose other than ObjectDetection
          if (architecture) {
            throw new CliUsageError(
              `Parameter --architecture not supported for purpose ${purpose}`
            );
          }
          modelParameters = {
            framework_type: 'onnx',
            size: size ?? [300, 300],
            model_file: model_file ?? '',
            label_file: label_file ?? '',
            colors_file: colors_file ?? '',
            scalefactor: scalefactor ?? 1,
            crop,
            swaprb,
            purpose,
            mean: mean ?? [0, 0, 0],
            output_layer_names: output_layer_names ?? null
          };
        }
        // batch_size is optional for ONNX; only persisted when provided.
        if (batch_size) {
          modelParameters.batch_size = batch_size;
        }
        break;
      }

      case 'tensor-rt': {
        // batch_size is mandatory for TensorRT models.
        if (!batch_size) {
          throw new CliUsageError(`Parameter --batch_size required!`);
        }
        if (device && !(device in tensorrtDeviceEnum)) {
          throw new CliUsageError(
            `Device not supported! (${tensorrtDeviceValues})`
          );
        }
        if (purpose === 'ObjectDetection') {
          if (
            architecture &&
            !(architecture in tensorrtArchitectureObjectDetectionEnum)
          ) {
            throw new CliUsageError(
              `Architecture not supported! (${tensorrtArchitectureObjectDetectionValues})`
            );
          }
          modelParameters = {
            framework_type: 'tensor-rt',
            size: size ?? [300, 300],
            model_file: model_file ?? '',
            label_file: label_file ?? '',
            scalefactor: scalefactor ?? 1,
            mean: mean ?? [0, 0, 0],
            crop,
            swaprb,
            purpose,
            batch_size,
            colors_file: colors_file ?? '',
            device: device as TensorrtDevice,
            architecture: architecture as TensorrtArchitectureObjectDetection
          };
        } else {
          // Purpose other than ObjectDetection
          if (architecture) {
            throw new CliUsageError(
              `Parameter --architecture not supported for purpose ${purpose}`
            );
          }
          modelParameters = {
            framework_type: 'tensor-rt',
            size: size ?? [300, 300],
            model_file: model_file ?? '',
            label_file: label_file ?? '',
            scalefactor: scalefactor ?? 1,
            mean: mean ?? [0, 0, 0],
            crop,
            swaprb,
            purpose,
            batch_size,
            colors_file: colors_file ?? '',
            device: device as TensorrtDevice
          };
        }
        break;
      }

      case 'hailo': {
        // Hailo requires both an architecture and a purpose drawn from its
        // own (narrower) enums.
        if (!architecture || !(architecture in hailoArchitectureEnum)) {
          throw new CliUsageError(
            `Parameter --architecture required! (${hailoArchitectureValues})`
          );
        }
        if (!purpose || !(purpose in hailoPurposeEnum)) {
          throw new CliUsageError(
            `Parameter --purpose required! (${hailoPurposeValues})`
          );
        }

        modelParameters = {
          framework_type: 'hailo',
          architecture: architecture as HailoArchitecture,
          // NOTE(review): was `quantize_input || true`, which is the
          // constant `true` for any boolean/undefined input, so the flags
          // could never disable quantization. `??` keeps the default of
          // `true` while honoring an explicit false — confirm CliFlagInput
          // yields `undefined` (not `false`) when the flag is omitted.
          quantize_input: quantize_input ?? true,
          quantize_output: quantize_output ?? true,
          input_format: (input_format as HailoFormat) ?? 'auto',
          output_format: (output_format as HailoFormat) ?? 'auto',
          size: size ?? [300, 300],
          model_file: model_file ?? '',
          label_file: label_file ?? '',
          purpose: purpose as HailoPurpose,
          crop,
          swaprb,
          mean: mean ?? [0, 0, 0],
          scalefactor: scalefactor ?? 1
        };
        break;
      }
      default: {
        throw new Error('Unsupported framework.');
      }
    }

    // The model id is namespaced by the authenticated user's name.
    const { username } = await CliAuthenticationClient().getInfo();

    const newModel: ModelJson = {
      accuracy: '',
      dataset: '',
      description: '',
      id: `${username}/${name}`,
      inference_time: null,
      license: '',
      mean_average_precision_top_1: null,
      mean_average_precision_top_5: null,
      public: false,
      website_url: '',
      model_parameters: modelParameters
    };

    // Ajv-style validation: errors (if any) are attached to the validator.
    validateModel(newModel);
    if (validateModel.errors) {
      // Was `JSON.stringify(..., _, 2)` — passing the positional-args
      // parameter as the JSON replacer was accidental; `null` means
      // "no replacer".
      echo(JSON.stringify(validateModel.errors, null, 2));
      throw new CliTerseError('Model package contents are invalid!');
    }

    const message = `Write alwaysai.model.json file`;
    const modelPkg = ModelPackageJsonFile(process.cwd());
    try {
      modelPkg.write(newModel);
      echo(`${logSymbols.success} ${message}`);
    } catch (exception) {
      echo(`${logSymbols.error} ${message}`);
      logger.error(exception);
      throw new CliTerseError(`Failed to write model package! ${exception}`);
    }
  }
});
