// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// The WebNN API does not yet ship an official TypeScript definition file. As a workaround, the file referenced below
// contains types generated from the WebNN API specification.
// https://github.com/webmachinelearning/webnn/issues/677
/// <reference path="webnn/webnn.d.ts" />

import { Env, Tensor } from 'onnxruntime-common';

import { DataType, tensorDataTypeStringToEnum } from '../wasm-common';
import { getInstance } from '../wasm-factory';

import { createView } from './tensor-view';
import { TensorId, createTensorManager, convertDataToInt32 } from './webnn/tensor-manager';
import { configureLogger, LOG_DEBUG } from './log';

/**
 * Lookup table from ONNX `TensorProto::data_type` values to WebNN `MLOperandDataType` strings.
 * ONNX types without an entry here have no WebNN representation in this backend.
 */
const onnxDataTypeToWebnnDataType = new Map<DataType, MLOperandDataType>([
  // Floating-point types.
  [DataType.float, 'float32'],
  [DataType.float16, 'float16'],
  // 32- and 64-bit integer types.
  [DataType.int32, 'int32'],
  [DataType.uint32, 'uint32'],
  [DataType.int64, 'int64'],
  [DataType.uint64, 'uint64'],
  // 4-bit integer types.
  [DataType.int4, 'int4'],
  [DataType.uint4, 'uint4'],
  // 8-bit integer types.
  [DataType.int8, 'int8'],
  [DataType.uint8, 'uint8'],
  // ONNX bool has no dedicated counterpart here and is represented as uint8.
  [DataType.bool, 'uint8'],
]);

/**
 * A cached MLContext together with the input it was created from: a GPUDevice, an MLContextOptions dictionary, or
 * neither (the default context).
 */
interface MLContextEntry {
  /** The WebGPU device the context was created from, if any. */
  gpuDevice?: GPUDevice;
  /** The options the context was created with, if any. */
  options?: MLContextOptions;
  /** The cached context itself. */
  mlContext: MLContext;
}

/**
 * Structurally compares two MLContextOptions dictionaries to decide whether a cached MLContext can be reused.
 *
 * Two option bags are equal when they are the same reference, or when both are defined and carry the same set of
 * keys with strictly-equal (`===`) values. Keys whose value is `undefined` are ignored, because WebIDL dictionary
 * conversion treats an explicit `undefined` member exactly like an absent one. For example,
 * `{ deviceType: 'gpu', numThreads: undefined }` and `{ deviceType: 'gpu' }` request the same context.
 *
 * `undefined` options are only equal to `undefined`. They are deliberately not treated like `{}`, because cache
 * entries created from a GPUDevice also have no options.
 *
 * @returns true if both arguments describe the same context configuration.
 */
const compareMLContextOptions = (a?: MLContextOptions, b?: MLContextOptions): boolean => {
  if (a === b) {
    return true;
  }
  if (a === undefined || b === undefined) {
    return false;
  }
  // Only members with a defined value influence context creation; sort so key order does not matter.
  const definedKeys = (options: MLContextOptions): Array<keyof MLContextOptions> =>
    (Object.keys(options) as Array<keyof MLContextOptions>).filter((key) => options[key] !== undefined).sort();
  const aKeys = definedKeys(a);
  const bKeys = definedKeys(b);
  return aKeys.length === bKeys.length && aKeys.every((key, index) => key === bKeys[index] && a[key] === b[key]);
};

/**
 * WebNN backend implementation. This class keeps track of the MLTensors created by the backend and of the MLContext
 * used by each session.
 */
export class WebNNBackend {
  /**
   * Tensor manager shared by every session of this backend; most calls into it pass the owning session id.
   */
  private readonly tensorManager = createTensorManager(this);
  /**
   * Maps from session id to the MLContext registered for that session.
   */
  private readonly mlContextBySessionId = new Map<number, MLContext>();
  /**
   * Maps from MLContext to the ids of the sessions registered with it. When the set becomes empty, the context is
   * dropped from `mlContextCache`.
   */
  private readonly sessionIdsByMLContext = new Map<MLContext, Set<number>>();
  /**
   * Cache of MLContexts together with the GPUDevice or MLContextOptions each was created from.
   */
  private readonly mlContextCache: MLContextEntry[] = [];
  /**
   * Id of the session whose run is in progress. It is set by onRunStart and cleared by onRunEnd.
   */
  private activeSessionId?: number;
  /**
   * Maps from session id to the names of its graph inputs.
   */
  private readonly sessionGraphInputs: Map<number, string[]> = new Map();
  /**
   * Maps from session id to the names of its graph outputs.
   */
  private readonly sessionGraphOutputs: Map<number, string[]> = new Map();
  /**
   * Graph input names registered before the owning session id is known. registerMLContext attaches them to the
   * session and starts a fresh list.
   */
  private temporaryGraphInputs: string[] = [];
  /**
   * Graph output names registered before the owning session id is known. registerMLContext attaches them to the
   * session and starts a fresh list.
   */
  private temporaryGraphOutputs: string[] = [];
  /**
   * Maps from session id to the ids of temporary tensors created by createTemporaryTensor. They are released when the
   * session's run ends, or when the session itself is released.
   */
  private readonly temporarySessionTensorIds: Map<number, TensorId[]> = new Map();
  /**
   * Maps from session id to the MLOpSupportLimits of its MLContext, captured when the session is first registered.
   */
  private readonly mlOpSupportLimitsBySessionId = new Map<number, MLOpSupportLimits>();

  /**
   * @param env ONNX Runtime environment; its `logLevel` and `debug` settings configure the backend logger.
   */
  constructor(env: Env) {
    configureLogger(env.logLevel!, !!env.debug);
  }

  /**
   * Id of the session whose run is currently in progress.
   *
   * @throws Error if no run is active.
   */
  public get currentSessionId(): number {
    if (this.activeSessionId === undefined) {
      throw new Error('No active session');
    }
    return this.activeSessionId;
  }

  /**
   * Marks `sessionId` as the active session for the run that is starting.
   */
  public onRunStart(sessionId: number): void {
    LOG_DEBUG('verbose', () => `[WebNN] onRunStart {sessionId: ${sessionId}}`);
    this.activeSessionId = sessionId;
  }

  /**
   * Releases the temporary tensors created during the run of `sessionId` and clears the active session.
   */
  public onRunEnd(sessionId: number): void {
    LOG_DEBUG('verbose', () => `[WebNN] onRunEnd {sessionId: ${sessionId}}`);
    this.releaseTemporaryTensors(sessionId);
    // Cleared unconditionally so that a run which created no temporary tensors does not leave a stale active session.
    this.activeSessionId = undefined;
  }

  /**
   * Returns a cached MLContext for `optionsOrDevice`, or creates and caches a new one on a cache miss.
   *
   * - No argument: matches the cached default context (created with neither options nor a device).
   * - GPUDevice: matches a context created from the same device object.
   * - MLContextOptions: matches a context whose options compare equal via compareMLContextOptions.
   */
  public async createMLContext(optionsOrDevice?: MLContextOptions | GPUDevice): Promise<MLContext> {
    if (optionsOrDevice === undefined) {
      const defaultEntry = this.mlContextCache.find(
        (entry) => entry.options === undefined && entry.gpuDevice === undefined,
      );
      if (defaultEntry) {
        return defaultEntry.mlContext;
      }
      const mlContext = await navigator.ml.createContext();
      this.mlContextCache.push({ mlContext });
      return mlContext;
    }

    if (WebNNBackend.isGPUDevice(optionsOrDevice)) {
      const gpuDevice = optionsOrDevice;
      const deviceEntry = this.mlContextCache.find((entry) => entry.gpuDevice === gpuDevice);
      if (deviceEntry) {
        return deviceEntry.mlContext;
      }
      const mlContext = await navigator.ml.createContext(gpuDevice);
      this.mlContextCache.push({ gpuDevice, mlContext });
      return mlContext;
    }

    // Captured in a const so the narrowed MLContextOptions type is kept inside the callback below.
    const options = optionsOrDevice;
    const optionsEntry = this.mlContextCache.find((entry) => compareMLContextOptions(entry.options, options));
    if (optionsEntry) {
      return optionsEntry.mlContext;
    }
    const mlContext = await navigator.ml.createContext(options);
    this.mlContextCache.push({ options, mlContext });
    return mlContext;
  }

  /**
   * Whether `value` is a GPUDevice. The `typeof` guard keeps this safe where WebGPU is not exposed and the `GPUDevice`
   * global is therefore undefined; an unguarded `instanceof GPUDevice` would throw a ReferenceError there.
   */
  private static isGPUDevice(value: MLContextOptions | GPUDevice): value is GPUDevice {
    return typeof GPUDevice !== 'undefined' && value instanceof GPUDevice;
  }

  /**
   * Associates `mlContext` with `sessionId` and captures the context's op support limits on first registration.
   * Also attaches to the session any pending graph input/output names collected by registerGraphInput and
   * registerGraphOutput.
   */
  public registerMLContext(sessionId: number, mlContext: MLContext): void {
    this.mlContextBySessionId.set(sessionId, mlContext);
    let sessionIds = this.sessionIdsByMLContext.get(mlContext);
    if (!sessionIds) {
      sessionIds = new Set();
      this.sessionIdsByMLContext.set(mlContext, sessionIds);
    }
    sessionIds.add(sessionId);

    if (!this.mlOpSupportLimitsBySessionId.has(sessionId)) {
      this.mlOpSupportLimitsBySessionId.set(sessionId, mlContext.opSupportLimits());
    }

    if (this.temporaryGraphInputs.length > 0) {
      this.sessionGraphInputs.set(sessionId, this.temporaryGraphInputs);
      this.temporaryGraphInputs = [];
    }
    if (this.temporaryGraphOutputs.length > 0) {
      this.sessionGraphOutputs.set(sessionId, this.temporaryGraphOutputs);
      this.temporaryGraphOutputs = [];
    }
  }

  /**
   * Drops all per-session state for `sessionId`. For a WebNN session, this also releases its tensors. When no other
   * session uses the session's MLContext, that context is evicted from the cache.
   */
  public onReleaseSession(sessionId: number): void {
    this.sessionGraphInputs.delete(sessionId);
    this.sessionGraphOutputs.delete(sessionId);
    const mlContext = this.mlContextBySessionId.get(sessionId);
    if (!mlContext) {
      // Current session is not a WebNN session.
      return;
    }
    // Temporary tensors are normally released by onRunEnd. Release any that are still outstanding (for example, if
    // onRunEnd was never called for the last run) so they do not outlive the session.
    this.releaseTemporaryTensors(sessionId);
    this.tensorManager.releaseTensorsForSession(sessionId);
    this.mlContextBySessionId.delete(sessionId);
    this.mlOpSupportLimitsBySessionId.delete(sessionId);
    const sessionIds = this.sessionIdsByMLContext.get(mlContext);
    sessionIds?.delete(sessionId);
    if (sessionIds === undefined || sessionIds.size === 0) {
      this.sessionIdsByMLContext.delete(mlContext);
      const mlContextIndex = this.mlContextCache.findIndex((entry) => entry.mlContext === mlContext);
      if (mlContextIndex !== -1) {
        this.mlContextCache.splice(mlContextIndex, 1);
      }
    }
  }

  /**
   * Returns the MLContext registered for `sessionId`, or undefined if it is not a WebNN session.
   */
  public getMLContext(sessionId: number): MLContext | undefined {
    return this.mlContextBySessionId.get(sessionId);
  }

  /**
   * Returns the op support limits captured for `sessionId` at registration, if any.
   */
  public getMLOpSupportLimits(sessionId: number): MLOpSupportLimits | undefined {
    return this.mlOpSupportLimitsBySessionId.get(sessionId);
  }

  /**
   * Reserves a new tensor id from the tensor manager.
   */
  public reserveTensorId(): TensorId {
    return this.tensorManager.reserveTensorId();
  }

  /**
   * Releases `tensorId` in the tensor manager.
   */
  public releaseTensorId(tensorId: TensorId): void {
    LOG_DEBUG('verbose', () => `[WebNN] releaseTensorId {tensorId: ${tensorId}}`);
    this.tensorManager.releaseTensorId(tensorId);
  }

  /**
   * Asks the tensor manager to ensure that `tensorId` is backed by an MLTensor with the WebNN equivalent of
   * `onnxDataType` and the given `dimensions`. `copyOld` is forwarded to the manager unchanged.
   *
   * @param sessionId Owning session; when undefined, the current session is used.
   * @throws Error if `onnxDataType` has no WebNN equivalent, or if `sessionId` is undefined while no run is active.
   */
  public async ensureTensor(
    sessionId: number | undefined,
    tensorId: TensorId,
    onnxDataType: DataType,
    dimensions: number[],
    copyOld: boolean,
  ): Promise<MLTensor> {
    const webnnDataType = this.toWebnnDataType(onnxDataType);
    return this.tensorManager.ensureTensor(
      sessionId ?? this.currentSessionId,
      tensorId,
      webnnDataType,
      dimensions,
      copyOld,
    );
  }

  /**
   * Creates a tensor that lives only for the current run of `sessionId`. It is released by onRunEnd, or by
   * onReleaseSession if the run never ended.
   *
   * @returns The id of the new tensor.
   * @throws Error if `onnxDataType` has no WebNN equivalent.
   */
  public async createTemporaryTensor(
    sessionId: number,
    onnxDataType: DataType,
    shape: readonly number[],
  ): Promise<TensorId> {
    LOG_DEBUG('verbose', () => `[WebNN] createTemporaryTensor {onnxDataType: ${onnxDataType}, shape: ${shape}}`);
    const dataType = this.toWebnnDataType(onnxDataType);
    const tensorId = this.tensorManager.reserveTensorId();
    await this.tensorManager.ensureTensor(sessionId, tensorId, dataType, shape, false);
    const tensorIds = this.temporarySessionTensorIds.get(sessionId);
    if (tensorIds) {
      tensorIds.push(tensorId);
    } else {
      this.temporarySessionTensorIds.set(sessionId, [tensorId]);
    }
    return tensorId;
  }

  /**
   * Uploads `data` into the MLTensor behind `tensorId`.
   *
   * @throws Error if the wasm module's `shouldTransferToMLTensor` flag is false.
   */
  public uploadTensor(tensorId: TensorId, data: Uint8Array): void {
    const wasm = getInstance();
    if (!wasm.shouldTransferToMLTensor) {
      throw new Error('Trying to upload to a MLTensor while shouldTransferToMLTensor is false');
    }
    LOG_DEBUG('verbose', () => `[WebNN] uploadTensor {tensorId: ${tensorId}, data: ${data.byteLength}}`);
    this.tensorManager.upload(tensorId, data);
  }

  /**
   * Downloads the contents of `tensorId` into `dstBuffer`.
   */
  public async downloadTensor(tensorId: TensorId, dstBuffer: ArrayBufferView | ArrayBuffer): Promise<undefined> {
    return this.tensorManager.download(tensorId, dstBuffer);
  }

  /**
   * Returns a lazy downloader that reads the data of `tensorId` and wraps it in a typed view of `type`.
   */
  public createMLTensorDownloader(tensorId: TensorId, type: Tensor.MLTensorDataTypes): () => Promise<Tensor.DataType> {
    return async () => {
      const data = await this.tensorManager.download(tensorId);
      return createView(data, type);
    };
  }

  /**
   * Registers an existing MLTensor with the tensor manager for `sessionId`.
   *
   * @returns The tensor id assigned to `tensor`.
   * @throws Error if `onnxDataType` has no WebNN equivalent.
   */
  public registerMLTensor(sessionId: number, tensor: MLTensor, onnxDataType: DataType, dimensions: number[]): TensorId {
    const webnnDataType = this.toWebnnDataType(onnxDataType);

    const id = this.tensorManager.registerTensor(sessionId, tensor, webnnDataType, dimensions);
    LOG_DEBUG(
      'verbose',
      () =>
        `[WebNN] registerMLTensor {tensor: ${tensor}, dataType: ${webnnDataType}, dimensions: ${
          dimensions
        }} -> {tensorId: ${id}}`,
    );
    return id;
  }

  /**
   * Registers a WebNN constant operand backed by a slice of a preloaded external data file.
   *
   * @param externalFilePath Path of the external data file; a leading `./` is stripped before lookup.
   * @param dataOffset Byte offset of the constant's data within the file.
   * @param dataLength Byte length of the constant's data.
   * @param builder Graph builder used to create the constant.
   * @param desc Operand descriptor. When the int64 → int32 workaround is applied, `desc.dataType` is rewritten to
   *     `'int32'` in place.
   * @param mountedFiles Preloaded files keyed by path (the "Module.MountedFiles" map, if available).
   * @param shouldConvertInt64ToInt32 When true, int64 data is converted with convertDataToInt32. Used when the
   *     current context does not support int64.
   * @returns The operand created by `builder.constant`.
   * @throws Error if `mountedFiles` is unavailable, the file is not preloaded, the byte range exceeds the file, or
   *     `desc.dataType` is unsupported.
   */
  public registerMLConstant(
    externalFilePath: string,
    dataOffset: number,
    dataLength: number,
    builder: MLGraphBuilder,
    desc: MLOperandDescriptor,
    mountedFiles: Map<string, Uint8Array> | undefined,
    shouldConvertInt64ToInt32 = false,
  ): MLOperand {
    // If available, "Module.MountedFiles" is a Map for all preloaded files.
    if (!mountedFiles) {
      throw new Error('External mounted files are not available.');
    }

    const filePath = externalFilePath.startsWith('./') ? externalFilePath.substring(2) : externalFilePath;
    const fileData = mountedFiles.get(filePath);
    if (!fileData) {
      throw new Error(`File with name ${filePath} not found in preloaded files.`);
    }

    if (dataOffset + dataLength > fileData.byteLength) {
      throw new Error('Out of bounds: data offset and length exceed the external file data size.');
    }

    // `slice` copies the bytes into a fresh ArrayBuffer starting at offset 0, so the typed-array views below are
    // aligned regardless of `dataOffset`.
    const buffer = fileData.slice(dataOffset, dataOffset + dataLength).buffer;
    let bufferView: ArrayBufferView;
    // Tracks whether the int64 -> int32 workaround was actually applied, so the log only mentions it when it was.
    let convertedInt64ToInt32 = false;
    switch (desc.dataType) {
      case 'float32':
        bufferView = new Float32Array(buffer);
        break;
      case 'float16':
        bufferView =
          typeof Float16Array !== 'undefined' && Float16Array.from ? new Float16Array(buffer) : new Uint16Array(buffer);
        break;
      case 'int32':
        bufferView = new Int32Array(buffer);
        break;
      case 'uint32':
        bufferView = new Uint32Array(buffer);
        break;
      case 'int64':
        if (shouldConvertInt64ToInt32) {
          // Int64 is not supported by current context, use int32 instead.
          const int32Buffer = convertDataToInt32(new Uint8Array(buffer), 'int64');
          bufferView = new Int32Array(int32Buffer.buffer);
          desc.dataType = 'int32';
          convertedInt64ToInt32 = true;
        } else {
          bufferView = new BigInt64Array(buffer);
        }
        break;
      case 'uint64':
        bufferView = new BigUint64Array(buffer);
        break;
      case 'int8':
        bufferView = new Int8Array(buffer);
        break;
      case 'int4':
      case 'uint4':
      case 'uint8':
        bufferView = new Uint8Array(buffer);
        break;
      default:
        throw new Error(`Unsupported data type: ${desc.dataType} in creating WebNN Constant from external data.`);
    }

    LOG_DEBUG(
      'verbose',
      () =>
        `[WebNN] registerMLConstant {dataType: ${desc.dataType}, shape: ${desc.shape}} ${
          convertedInt64ToInt32 ? '(Note: it was int64 data type and registered to int32 as workaround)' : ''
        }`,
    );

    return builder.constant(desc, bufferView);
  }

  /**
   * Records a graph input name for the session being created. registerMLContext attaches it to that session.
   */
  public registerGraphInput(inputName: string): void {
    this.temporaryGraphInputs.push(inputName);
  }

  /**
   * Records a graph output name for the session being created. registerMLContext attaches it to that session.
   */
  public registerGraphOutput(outputName: string): void {
    this.temporaryGraphOutputs.push(outputName);
  }

  /**
   * Whether `inputName` is one of the graph inputs registered for `sessionId`.
   */
  public isGraphInput(sessionId: number, inputName: string): boolean {
    return this.sessionGraphInputs.get(sessionId)?.includes(inputName) ?? false;
  }

  /**
   * Whether `outputName` is one of the graph outputs registered for `sessionId`.
   */
  public isGraphOutput(sessionId: number, outputName: string): boolean {
    return this.sessionGraphOutputs.get(sessionId)?.includes(outputName) ?? false;
  }

  /**
   * Whether the session's MLContext supports `type` as a graph input (or as a graph output when `isInput` is false),
   * according to the op support limits captured at registration.
   *
   * @returns false if `type` has no WebNN equivalent or no limits are recorded for the session.
   */
  public isGraphInputOutputTypeSupported(sessionId: number, type: Tensor.Type, isInput = true): boolean {
    const dataType = onnxDataTypeToWebnnDataType.get(tensorDataTypeStringToEnum(type));
    const opLimits = this.mlOpSupportLimitsBySessionId.get(sessionId);

    if (typeof dataType === 'undefined') {
      return false;
    }

    if (isInput) {
      return !!opLimits?.input.dataTypes.includes(dataType);
    } else {
      return !!opLimits?.output.dataTypes.includes(dataType);
    }
  }

  /**
   * No-op kept for interface parity with other backends.
   */
  public flush(): void {
    // Unlike the WebGPU backend, the WebNN backend does not need to flush any pending operations.
  }

  /**
   * Releases every temporary tensor recorded for `sessionId` and forgets the bookkeeping entry. Does nothing if the
   * session has no temporary tensors.
   */
  private releaseTemporaryTensors(sessionId: number): void {
    const tensorIds = this.temporarySessionTensorIds.get(sessionId);
    if (!tensorIds) {
      return;
    }
    for (const tensorId of tensorIds) {
      LOG_DEBUG('verbose', () => `[WebNN] releasing temporary tensor {tensorId: ${tensorId}}`);
      this.tensorManager.releaseTensorId(tensorId);
    }
    this.temporarySessionTensorIds.delete(sessionId);
  }

  /**
   * Maps an ONNX data type to its WebNN operand data type.
   *
   * @throws Error if the type has no WebNN equivalent.
   */
  private toWebnnDataType(onnxDataType: DataType): MLOperandDataType {
    const webnnDataType = onnxDataTypeToWebnnDataType.get(onnxDataType);
    if (!webnnDataType) {
      throw new Error(`Unsupported ONNX data type: ${onnxDataType}`);
    }
    return webnnDataType;
  }
}
