// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from
// WebNN API specification.
// https://github.com/webmachinelearning/webnn/issues/677
/// <reference path="jsep/webnn/webnn.d.ts" />

import { Env, InferenceSession, Tensor, TRACE_EVENT_BEGIN, TRACE_EVENT_END } from 'onnxruntime-common';

import {
  SerializableInternalBuffer,
  SerializableSessionMetadata,
  SerializableTensorMetadata,
  TensorMetadata,
} from './proxy-messages';
import { setRunOptions } from './run-options';
import { setSessionOptions } from './session-options';
import {
  calculateTensorSizeInBytes,
  dataLocationStringToEnum,
  isGpuBufferSupportedType,
  isMLTensorSupportedType,
  logLevelStringToEnum,
  tensorDataTypeEnumToString,
  tensorDataTypeStringToEnum,
  tensorTypeToTypedArrayConstructor,
} from './wasm-common';
import { getInstance } from './wasm-factory';
import { allocWasmString, checkLastError } from './wasm-utils';
import { loadFile } from './wasm-utils-load-file';

// #region Initializations

/**
 * There are 4 different "initialization" steps for ORT. They happen in different places and different time.
 *
 * 1. JavaScript initialization for onnxruntime-common and onnxruntime-web.
 *    This is the first initialization step. In this step, onnxruntime-web calls onnxruntime-common's registerBackend()
 * function multiple times to register all the available backends. The backend registration is very fast. It only
 * registers the backend name with the uninitialized backend object. No heavy initialization is done in this step.
 *    Refer to web/lib/index.ts for the backend registration.
 *
 * 2. WebAssembly artifact initialization.
 *    This happens when any registered wasm backend is used for the first time (i.e. `ort.InferenceSession.create()` is
 * called). In this step, onnxruntime-web does the following:
 *     - create a proxy worker and make sure the proxy worker is ready to receive messages, if proxy is enabled.
 *     - perform feature detection, locate correct WebAssembly artifact path and call the Emscripten generated
 * JavaScript code to initialize the WebAssembly runtime.
 *         - if proxy is enabled, this step happens in the proxy worker using message 'init-wasm'.
 *         - downloading the 'ort-wasm{...}.wasm' file is done in this step.
 *         - if multi-thread is enabled, one or more web workers will be created to initialize the PThread threadpool.
 *
 * 3. ORT environment initialization.
 *    This happens after step 2. In this step, onnxruntime-web performs ONNX Runtime environment initialization.
 * Function `_OrtInit()` is called in this step.
 *     - if proxy is enabled, this step happens in the proxy worker using message 'init-ort'.
 *     - logging level (ort.env.logLevel) and thread number (ort.env.wasm.numThreads) are set in this step.
 *
 * 4. Session initialization.
 *    This happens when `ort.InferenceSession.create()` is called. Unlike the first 3 steps (which are only done once),
 * this step will be done for each session. In this step, onnxruntime-web does the following:
 *    If the parameter is a URL:
 *    - download the model data from the URL.
 *    - copy the model data to the WASM heap. (proxy: 'copy-from')
 *    - dereference the model buffer. This step allows the original ArrayBuffer to be garbage collected.
 *    - call `_OrtCreateSession()` to create the session. (proxy: 'create')
 *
 *    If the parameter is a Uint8Array object:
 *    - copy the model data to the WASM heap. (proxy: 'copy-from')
 *    - call `_OrtCreateSession()` to create the session. (proxy: 'create')
 *
 *
 */

/**
 * Initialize the ORT environment inside the WebAssembly instance.
 *
 * @param numThreads forwarded to SetGlobalIntraOpNumThreads(numThreads)
 * @param loggingLevel forwarded to CreateEnv(static_cast&lt;OrtLoggingLevel&gt;(logging_level))
 */
const initOrt = (numThreads: number, loggingLevel: number): void => {
  const status = getInstance()._OrtInit(numThreads, loggingLevel);
  if (status === 0) {
    return;
  }
  checkLastError("Can't initialize onnxruntime.");
};

/**
 * Initialize the runtime environment (ORT environment initialization, step 3 above).
 * @param env the environment configuration object passed in by the caller.
 */
export const initRuntime = async (env: Env): Promise<void> => {
  // resolve thread count and logging level from the environment config, then forward to the WASM module.
  const threadCount = env.wasm.numThreads!;
  const logSeverity = logLevelStringToEnum(env.logLevel);
  initOrt(threadCount, logSeverity);
};

/**
 * perform EP specific initialization.
 *
 * For 'webgpu': validates/requests a GPUAdapter. For 'webnn': checks `navigator.ml` availability. Then hands
 * control to either the JSEP initializer or the EP-specific init entry points, depending on build flags.
 *
 * @param env the environment configuration object; `env.webgpu.*` and `env.trace` are read here.
 * @param epName the name of the execution provider being initialized (e.g. 'webgpu', 'webnn').
 * @throws Error if the environment does not support the requested EP, or if EP options are invalid.
 */
export const initEp = async (env: Env, epName: string): Promise<void> => {
  // initialize ASYNCIFY support
  getInstance().asyncInit?.();

  // perform WebGPU availability check ( either JSEP or WebGPU EP )
  let webgpuAdapter = env.webgpu.adapter as GPUAdapter | null;
  if (epName === 'webgpu') {
    if (typeof navigator === 'undefined' || !navigator.gpu) {
      throw new Error('WebGPU is not supported in current environment');
    }
    if (!webgpuAdapter) {
      // if adapter is not set, request a new adapter.
      const powerPreference = env.webgpu.powerPreference;
      if (powerPreference !== undefined && powerPreference !== 'low-power' && powerPreference !== 'high-performance') {
        throw new Error(`Invalid powerPreference setting: "${powerPreference}"`);
      }
      const forceFallbackAdapter = env.webgpu.forceFallbackAdapter;
      if (forceFallbackAdapter !== undefined && typeof forceFallbackAdapter !== 'boolean') {
        throw new Error(`Invalid forceFallbackAdapter setting: "${forceFallbackAdapter}"`);
      }
      webgpuAdapter = await navigator.gpu.requestAdapter({ powerPreference, forceFallbackAdapter });
      if (!webgpuAdapter) {
        throw new Error(
          'Failed to get GPU adapter. ' +
            'You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.',
        );
      }
    } else {
      // if adapter is set, validate it.
      // duck-type check only: a real GPUAdapter exposes `limits`, `features` and `requestDevice`.
      if (
        typeof webgpuAdapter.limits !== 'object' ||
        typeof webgpuAdapter.features !== 'object' ||
        typeof webgpuAdapter.requestDevice !== 'function'
      ) {
        throw new Error('Invalid GPU adapter set in `env.webgpu.adapter`. It must be a GPUAdapter object.');
      }
    }
  }

  // perform WebNN availability check ( either JSEP or WebNN EP )
  // `navigator.ml` is not in the default TS DOM lib, hence the cast.
  if (epName === 'webnn') {
    if (typeof navigator === 'undefined' || !(navigator as unknown as { ml: unknown }).ml) {
      throw new Error('WebNN is not supported in current environment');
    }
  }

  if (!BUILD_DEFS.DISABLE_JSEP) {
    // JSEP build: both EPs are initialized through the shared JSEP init entry point.
    // NOTE: the conditional require() is intentional — it allows the bundler to drop ./jsep/init
    // from non-JSEP builds. Do not convert to a top-level import.
    // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires
    const initJsep = require('./jsep/init').init;

    if (epName === 'webgpu') {
      await initJsep('webgpu', getInstance(), env, webgpuAdapter);
    }
    if (epName === 'webnn') {
      await initJsep('webnn', getInstance(), env);
    }
  } else {
    // non-JSEP build: use the EP-specific init entry points exported by the WASM module.
    if (!BUILD_DEFS.DISABLE_WEBGPU && epName === 'webgpu') {
      // register a callback so the module can publish the created GPUDevice back to `env.webgpu.device`.
      getInstance().webgpuInit!((device) => {
        env.webgpu.device = device;
      });
    }
    if (!BUILD_DEFS.DISABLE_WEBNN && epName === 'webnn') {
      // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires
      const backend = new (require('./jsep/backend-webnn').WebNNBackend)(env);
      // hand the backend and its callback table to the WASM module. The positional comments below
      // name each slot of the tuple.
      getInstance().webnnInit!([
        backend,
        // webnnReserveTensorId
        () => backend.reserveTensorId(),
        // webnnReleaseTensorId,
        (tensorId: number) => backend.releaseTensorId(tensorId),
        // webnnEnsureTensor
        async (sessionId: number | undefined, tensorId: number, onnxDataType: number, shape: number[], copyOld) =>
          backend.ensureTensor(sessionId, tensorId, onnxDataType, shape, copyOld),
        // webnnUploadTensor
        (tensorId: number, data: Uint8Array) => {
          backend.uploadTensor(tensorId, data);
        },
        // webnnDownloadTensor
        async (tensorId: number, dstBuffer: ArrayBufferView | ArrayBuffer) =>
          backend.downloadTensor(tensorId, dstBuffer),
        // webnnRegisterMLContext
        (sessionId: number, mlContext: MLContext) => backend.registerMLContext(sessionId, mlContext),
        // webnnEnableTraceEvent
        !!env.trace,
      ]);
    }
  }
};

// #endregion Initializations

/**
 * valid data locations for input/output tensors.
 *
 * 'cpu' / 'cpu-pinned' are WASM-heap backed; 'gpu-buffer' is a WebGPU GPUBuffer; 'ml-tensor' is a WebNN MLTensor.
 */
type SupportedTensorDataLocationForInputOutput =
  | 'cpu'
  | 'cpu-pinned'
  | 'gpu-buffer'
  | 'ml-tensor'
  // Use 'ml-tensor' during inference, but output a tensor located on the CPU.
  | 'ml-tensor-cpu-output';

type IOBindingState = {
  /**
   * the handle of IO binding.
   */
  readonly handle: number;

  /**
   * the preferred location for each output tensor.
   *
   * value is one of 'cpu', 'cpu-pinned', 'gpu-buffer', 'ml-tensor', 'ml-tensor-cpu-output'.
   */
  readonly outputPreferredLocations: readonly SupportedTensorDataLocationForInputOutput[];

  /**
   * enum value of the preferred location for each output tensor.
   *
   * 'ml-tensor-cpu-output' is encoded as 'ml-tensor' (see createSession).
   */
  readonly outputPreferredLocationsEncoded: readonly number[];
};

/**
 *  tuple elements are: InferenceSession ID; inputNamesUTF8Encoded; outputNamesUTF8Encoded; bindingState;
 *  enableGraphCapture; inputOutputBound.
 *
 *  The name arrays hold WASM-heap pointers to the UTF-8 encoded names; they are freed in releaseSession.
 */
type SessionMetadata = [
  inferenceSessionId: number,
  inputNamesUTF8Encoded: number[],
  outputNamesUTF8Encoded: number[],
  bindingState: IOBindingState | null,
  enableGraphCapture: boolean,
  inputOutputBound: boolean,
];

// all currently-live sessions, keyed by session handle (the same value stored as the tuple's first element).
const activeSessions = new Map<number, SessionMetadata>();

/**
 * Query the number of inputs and outputs of a created session.
 * @param sessionHandle the handle representing the session. should be non-zero.
 * @returns a tuple of [inputCount, outputCount].
 */
const getSessionInputOutputCount = (sessionHandle: number): [number, number] => {
  const wasm = getInstance();
  const savedStack = wasm.stackSave();
  try {
    const ptrSize = wasm.PTR_SIZE;
    // scratch space for two pointer-sized out-parameters: [inputCount, outputCount].
    const countsPtr = wasm.stackAlloc(2 * ptrSize);
    if (wasm._OrtGetInputOutputCount(sessionHandle, countsPtr, countsPtr + ptrSize) !== 0) {
      checkLastError("Can't get session input/output count.");
    }
    // the out-parameters are native-pointer-sized integers.
    const valueType = ptrSize === 4 ? 'i32' : 'i64';
    const inputCount = Number(wasm.getValue(countsPtr, valueType));
    const outputCount = Number(wasm.getValue(countsPtr + ptrSize, valueType));
    return [inputCount, outputCount];
  } finally {
    wasm.stackRestore(savedStack);
  }
};

/**
 * get the metadata (name, element type, shape) of one input/output of a session.
 *
 * @param sessionHandle the handle representing the session.
 * @param index the flattened input/output index (callers pass `i + inputCount` for outputs).
 * @returns a tuple [nameOffset, elementType, dims?]. `nameOffset` is a WASM-heap pointer to the UTF-8
 *     encoded name; the caller takes ownership and must free it with `_OrtFree`. `elementType` is 0 for
 *     non-tensor values (then `dims` is omitted). Each dim is a number, or a string for a symbolic dimension.
 */
const getSessionInputOutputMetadata = (
  sessionHandle: number,
  index: number,
): [nameOffset: number, elementType: number, dims?: Array<number | string>] => {
  const wasm = getInstance();
  const stack = wasm.stackSave();
  // pointer to the natively-allocated metadata struct; freed in the finally block once parsed.
  let metadataOffset = 0;
  try {
    const ptrSize = wasm.PTR_SIZE;
    // two pointer-sized out-parameters: [nameOffset, metadataOffset].
    const dataOffset = wasm.stackAlloc(2 * ptrSize);
    const errorCode = wasm._OrtGetInputOutputMetadata(sessionHandle, index, dataOffset, dataOffset + ptrSize);
    if (errorCode !== 0) {
      checkLastError("Can't get session input/output metadata.");
    }
    const nameOffset = Number(wasm.getValue(dataOffset, '*'));
    metadataOffset = Number(wasm.getValue(dataOffset + ptrSize, '*'));
    // get element type (i32 at struct offset 0)
    const elementType = wasm.HEAP32[metadataOffset / 4];
    if (elementType === 0) {
      return [nameOffset, 0]; // non-tensor
    }

    // get dims count (u32 at struct offset 4)
    const dimsCount = wasm.HEAPU32[metadataOffset / 4 + 1];
    // get dims: starting at struct offset 8 there are `dimsCount` pointers to symbolic dim names,
    // followed by `dimsCount` numeric dim values. A non-null name pointer means the dim is symbolic.
    const dims: Array<number | string> = [];
    for (let i = 0; i < dimsCount; i++) {
      const symbolicDimNameOffset = Number(wasm.getValue(metadataOffset + 8 + i * ptrSize, '*'));
      dims.push(
        symbolicDimNameOffset !== 0
          ? wasm.UTF8ToString(symbolicDimNameOffset)
          : Number(wasm.getValue(metadataOffset + 8 + (i + dimsCount) * ptrSize, '*')),
      );
    }
    return [nameOffset, elementType, dims];
  } finally {
    wasm.stackRestore(stack);
    // the metadata struct is heap-allocated by the native side; release it here.
    // NOTE: the name buffer (nameOffset) is deliberately NOT freed here — ownership passes to the caller.
    if (metadataOffset !== 0) {
      wasm._OrtFree(metadataOffset);
    }
  }
};

/**
 * allocate the memory and memcpy the external buffer.
 *
 * @param model - the external buffer containing the model data. Must not be the same buffer as the WASM heap.
 * @returns a 2-elements tuple - the pointer and size of the allocated buffer
 */
export const copyFromExternalBuffer = (model: Uint8Array): [number, number] => {
  const wasm = getInstance();
  const size = model.byteLength;
  const heapOffset = wasm._malloc(size);
  if (heapOffset === 0) {
    throw new Error(`Can't create a session. failed to allocate a buffer of size ${model.byteLength}.`);
  }
  // memcpy the model bytes into the freshly allocated WASM heap region.
  wasm.HEAPU8.set(model, heapOffset);
  return [heapOffset, size];
};

/**
 * create an inference session from a model data buffer.
 *
 * Ownership note: the model data pointer is always freed in the finally block, whether it was passed in as a
 * tuple, was already a view on the WASM heap, or was copied here via copyFromExternalBuffer.
 *
 * @param modelData - either a Uint8Array object representing the model data, or a 2-elements tuple containing the
 *     pointer and size of the model data buffer.
 * @param options an optional session options object.
 * @returns a 5-elements tuple containing [session handle, input names, output names, input metadata,
 *     output metadata]
 */
export const createSession = async (
  modelData: Uint8Array | SerializableInternalBuffer,
  options?: InferenceSession.SessionOptions,
): Promise<SerializableSessionMetadata> => {
  let modelDataOffset: number, modelDataLength: number;
  const wasm = getInstance();

  if (Array.isArray(modelData)) {
    // if model data is an array, it must be a 2-elements tuple containing the pointer and size of the model data
    [modelDataOffset, modelDataLength] = modelData;
  } else if (modelData.buffer === wasm.HEAPU8.buffer) {
    // if model data uses the same buffer as the WASM heap, we don't need to copy it.
    [modelDataOffset, modelDataLength] = [modelData.byteOffset, modelData.byteLength];
  } else {
    // otherwise, copy the model data to the WASM heap.
    [modelDataOffset, modelDataLength] = copyFromExternalBuffer(modelData);
  }

  let sessionHandle = 0;
  let sessionOptionsHandle = 0;
  let ioBindingHandle = 0;
  // WASM heap allocations made by setSessionOptions; freed in the finally block.
  let allocs: number[] = [];
  const inputNamesUTF8Encoded = [];
  const outputNamesUTF8Encoded = [];

  try {
    [sessionOptionsHandle, allocs] = await setSessionOptions(options);

    // mount external data files (for models with external weights) before creating the session.
    if (options?.externalData && wasm.mountExternalData) {
      const loadingPromises = [];
      for (const file of options.externalData) {
        const path = typeof file === 'string' ? file : file.path;
        loadingPromises.push(
          loadFile(typeof file === 'string' ? file : file.data).then((data) => {
            wasm.mountExternalData(path, data);
          }),
        );
      }

      // wait for all external data files to be loaded
      await Promise.all(loadingPromises);
    }

    // if the WebNN EP is requested, create (or adopt) an MLContext before session creation so that the
    // native side can pick it up via `wasm.currentContext`.
    for (const provider of options?.executionProviders ?? []) {
      const providerName = typeof provider === 'string' ? provider : provider.name;
      if (providerName === 'webnn') {
        wasm.shouldTransferToMLTensor = false;
        if (typeof provider !== 'string') {
          const webnnOptions = provider as InferenceSession.WebNNExecutionProviderOption;
          const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context;
          const gpuDevice = (webnnOptions as InferenceSession.WebNNOptionsWebGpu)?.gpuDevice;
          const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType;
          const powerPreference = (webnnOptions as InferenceSession.WebNNContextOptions)?.powerPreference;
          if (context) {
            // caller supplied an MLContext directly.
            wasm.currentContext = context as MLContext;
          } else if (gpuDevice) {
            // create an MLContext from a WebGPU device.
            wasm.currentContext = await wasm.webnnCreateMLContext!(gpuDevice);
          } else {
            // create an MLContext from context options.
            wasm.currentContext = await wasm.webnnCreateMLContext!({ deviceType, powerPreference });
          }
        } else {
          // provider given as plain string: create a default MLContext.
          wasm.currentContext = await wasm.webnnCreateMLContext!();
        }
        break;
      }
    }

    sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
    wasm.webgpuOnCreateSession?.(sessionHandle);
    if (sessionHandle === 0) {
      checkLastError("Can't create a session.");
    }

    wasm.jsepOnCreateSession?.();

    // clear current MLContext after session creation
    if (wasm.currentContext) {
      wasm.webnnRegisterMLContext!(sessionHandle, wasm.currentContext);
      wasm.currentContext = undefined;
      wasm.shouldTransferToMLTensor = true;
    }

    const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle);

    const enableGraphCapture = !!options?.enableGraphCapture;

    const inputNames = [];
    const outputNames = [];
    const inputMetadata: InferenceSession.ValueMetadata[] = [];
    const outputMetadata: InferenceSession.ValueMetadata[] = [];
    const outputPreferredLocations: SupportedTensorDataLocationForInputOutput[] = [];
    // collect input names and metadata. The name pointers are kept (and freed on release/error).
    for (let i = 0; i < inputCount; i++) {
      const [nameOffset, elementType, shape] = getSessionInputOutputMetadata(sessionHandle, i);
      if (nameOffset === 0) {
        checkLastError("Can't get an input name.");
      }
      inputNamesUTF8Encoded.push(nameOffset);
      const name = wasm.UTF8ToString(nameOffset);
      inputNames.push(name);
      inputMetadata.push(
        elementType === 0
          ? { name, isTensor: false }
          : { name, isTensor: true, type: tensorDataTypeEnumToString(elementType), shape: shape! },
      );
    }
    // collect output names/metadata and resolve the preferred data location of each output.
    for (let i = 0; i < outputCount; i++) {
      const [nameOffset, elementType, shape] = getSessionInputOutputMetadata(sessionHandle, i + inputCount);
      if (nameOffset === 0) {
        checkLastError("Can't get an output name.");
      }
      outputNamesUTF8Encoded.push(nameOffset);
      const nameString = wasm.UTF8ToString(nameOffset);
      outputNames.push(nameString);
      outputMetadata.push(
        elementType === 0
          ? { name: nameString, isTensor: false }
          : { name: nameString, isTensor: true, type: tensorDataTypeEnumToString(elementType), shape: shape! },
      );

      if (!BUILD_DEFS.DISABLE_JSEP || !BUILD_DEFS.DISABLE_WEBGPU) {
        // graph capture with no explicit preference forces outputs to stay on the GPU.
        if (enableGraphCapture && options?.preferredOutputLocation === undefined) {
          outputPreferredLocations.push('gpu-buffer');
          continue;
        }
        const location =
          typeof options?.preferredOutputLocation === 'string'
            ? options.preferredOutputLocation
            : (options?.preferredOutputLocation?.[nameString] ?? 'cpu');
        // a WebNN graph output requested on CPU still runs through an MLTensor, then is copied out.
        const isGraphOutput = wasm.webnnIsGraphOutput;
        if (location === 'cpu' && isGraphOutput && isGraphOutput(sessionHandle, nameString)) {
          outputPreferredLocations.push('ml-tensor-cpu-output');
          continue;
        }
        if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer' && location !== 'ml-tensor') {
          throw new Error(`Not supported preferred output location: ${location}.`);
        }
        if (enableGraphCapture && location !== 'gpu-buffer') {
          throw new Error(
            `Not supported preferred output location: ${location}. Only 'gpu-buffer' location is supported when enableGraphCapture is true.`,
          );
        }
        outputPreferredLocations.push(location);
      }
    }

    // use IO binding only when at least one output is preferred to be on GPU.
    let bindingState: IOBindingState | null = null;
    if (
      (!BUILD_DEFS.DISABLE_JSEP || !BUILD_DEFS.DISABLE_WEBGPU) &&
      outputPreferredLocations.some((l) => l === 'gpu-buffer' || l === 'ml-tensor' || l === 'ml-tensor-cpu-output')
    ) {
      ioBindingHandle = wasm._OrtCreateBinding(sessionHandle);
      if (ioBindingHandle === 0) {
        checkLastError("Can't create IO binding.");
      }

      bindingState = {
        handle: ioBindingHandle,
        outputPreferredLocations,
        outputPreferredLocationsEncoded: outputPreferredLocations
          // 'ml-tensor-cpu-output' is treated as 'ml-tensor' for the purpose of IO binding.
          .map((l) => (l === 'ml-tensor-cpu-output' ? 'ml-tensor' : l))
          .map((l) => dataLocationStringToEnum(l)),
      };
    }

    // record the new session as active; inputOutputBound starts as false.
    activeSessions.set(sessionHandle, [
      sessionHandle,
      inputNamesUTF8Encoded,
      outputNamesUTF8Encoded,
      bindingState,
      enableGraphCapture,
      false,
    ]);
    return [sessionHandle, inputNames, outputNames, inputMetadata, outputMetadata];
  } catch (e) {
    // roll back everything acquired so far, then rethrow.
    inputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf));
    outputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf));

    if (ioBindingHandle !== 0) {
      if (wasm._OrtReleaseBinding(ioBindingHandle) !== 0) {
        checkLastError("Can't release IO binding.");
      }
    }

    if (sessionHandle !== 0) {
      if (wasm._OrtReleaseSession(sessionHandle) !== 0) {
        checkLastError("Can't release session.");
      }
    }
    throw e;
  } finally {
    // the model buffer, session options and option-string allocations are always released,
    // on both the success and the failure path.
    wasm._free(modelDataOffset);
    if (sessionOptionsHandle !== 0) {
      if (wasm._OrtReleaseSessionOptions(sessionOptionsHandle) !== 0) {
        checkLastError("Can't release session options.");
      }
    }
    allocs.forEach((alloc) => wasm._free(alloc));

    // unmount external data if necessary
    wasm.unmountExternalData?.();
  }
};

/**
 * Release a session and all native resources associated with it.
 * @param sessionId the ID of the session to release.
 * @throws Error if the session ID is unknown.
 */
export const releaseSession = (sessionId: number): void => {
  const wasm = getInstance();
  const session = activeSessions.get(sessionId);
  if (!session) {
    throw new Error(`cannot release session. invalid session id: ${sessionId}`);
  }
  const sessionHandle = session[0];
  const inputNamesUTF8Encoded = session[1];
  const outputNamesUTF8Encoded = session[2];
  const ioBindingState = session[3];
  const enableGraphCapture = session[4];

  // tear down the IO binding first, clearing bound outputs when graph capture was enabled.
  if (ioBindingState) {
    if (enableGraphCapture && wasm._OrtClearBoundOutputs(ioBindingState.handle) !== 0) {
      checkLastError("Can't clear bound outputs.");
    }
    if (wasm._OrtReleaseBinding(ioBindingState.handle) !== 0) {
      checkLastError("Can't release IO binding.");
    }
  }

  // notify EP-specific release hooks.
  wasm.jsepOnReleaseSession?.(sessionId);
  wasm.webnnOnReleaseSession?.(sessionId);
  wasm.webgpuOnReleaseSession?.(sessionId);

  // free the cached UTF-8 encoded input/output name buffers.
  for (const namePtr of inputNamesUTF8Encoded) {
    wasm._OrtFree(namePtr);
  }
  for (const namePtr of outputNamesUTF8Encoded) {
    wasm._OrtFree(namePtr);
  }

  if (wasm._OrtReleaseSession(sessionHandle) !== 0) {
    checkLastError("Can't release session.");
  }
  activeSessions.delete(sessionId);
};

/**
 * prepare an input/output tensor for a run: place/register its data for the native side and create an
 * OrtValue handle for it.
 *
 * BUG FIX: the dims scratch buffer was allocated as `4 * dims.length` bytes, but each dim is written as a
 * pointer-sized value (`i64` when PTR_SIZE === 8), overflowing the stack allocation under wasm64. It now
 * allocates `ptrSize * dims.length`, matching the element size actually written (and the `count * ptrSize`
 * pattern used elsewhere in this file).
 *
 * @param tensor the tensor metadata, or null to push a 0 (null) handle.
 * @param tensorHandles out-array: the created OrtValue handle is appended here.
 * @param allocs out-array: WASM heap allocations made here are appended so the caller can free them.
 * @param sessionId the ID of the session this tensor belongs to.
 * @param tensorNameUTF8Encoded WASM-heap pointer to the UTF-8 encoded tensor name.
 * @param index the input/output index, used in error messages and JSEP buffer registration.
 * @param enableGraphCapture when true, only 'gpu-buffer' location is accepted.
 */
export const prepareInputOutputTensor = async (
  tensor: TensorMetadata | null,
  tensorHandles: number[],
  allocs: number[],
  sessionId: number,
  tensorNameUTF8Encoded: number,
  index: number,
  enableGraphCapture = false,
): Promise<void> => {
  if (!tensor) {
    // null tensor: push a null OrtValue handle (no data to prepare).
    tensorHandles.push(0);
    return;
  }

  const wasm = getInstance();
  const ptrSize = wasm.PTR_SIZE;

  const dataType = tensor[0];
  const dims = tensor[1];
  const location = tensor[3];
  // may be promoted to 'ml-tensor' below for WebNN graph inputs/outputs.
  let actualLocation = location;

  let rawData: number;
  let dataByteLength: number;

  if (dataType === 'string' && (location === 'gpu-buffer' || location === 'ml-tensor')) {
    throw new Error('String tensor is not supported on GPU.');
  }

  if (enableGraphCapture && location !== 'gpu-buffer') {
    throw new Error(
      `External buffer must be provided for input/output index ${index} when enableGraphCapture is true.`,
    );
  }

  if (location === 'gpu-buffer') {
    const gpuBuffer = tensor[2].gpuBuffer;
    dataByteLength = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(dataType), dims)!;

    if (!BUILD_DEFS.DISABLE_WEBGPU) {
      // WebGPU EP build: register the buffer with the WebGPU EP.
      const registerBuffer = wasm.webgpuRegisterBuffer;
      if (!registerBuffer) {
        throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.');
      }

      rawData = registerBuffer(gpuBuffer, sessionId);
    } else {
      // JSEP build: register the buffer with JSEP.
      const registerBuffer = wasm.jsepRegisterBuffer;
      if (!registerBuffer) {
        throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.');
      }
      rawData = registerBuffer(sessionId, index, gpuBuffer, dataByteLength);
    }
  } else if (location === 'ml-tensor') {
    const mlTensor = tensor[2].mlTensor as MLTensor;
    dataByteLength = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(dataType), dims)!;

    const registerMLTensor = wasm.webnnRegisterMLTensor;
    if (!registerMLTensor) {
      throw new Error('Tensor location "ml-tensor" is not supported without using WebNN.');
    }
    rawData = registerMLTensor(sessionId, mlTensor, tensorDataTypeStringToEnum(dataType), dims);
  } else {
    const data = tensor[2];

    if (Array.isArray(data)) {
      // string tensor: allocate an array of pointers, each pointing to an allocated UTF-8 string.
      dataByteLength = ptrSize * data.length;
      rawData = wasm._malloc(dataByteLength);
      allocs.push(rawData);
      for (let i = 0; i < data.length; i++) {
        if (typeof data[i] !== 'string') {
          throw new TypeError(`tensor data at index ${i} is not a string`);
        }
        wasm.setValue(rawData + i * ptrSize, allocWasmString(data[i], allocs), '*');
      }
    } else {
      const isGraphInput = wasm.webnnIsGraphInput;
      const isGraphOutput = wasm.webnnIsGraphOutput;
      if (dataType !== 'string' && isGraphInput && isGraphOutput) {
        const tensorName = wasm.UTF8ToString(tensorNameUTF8Encoded);
        // Promote the tensor to 'ml-tensor' if it is a graph input.
        if (isGraphInput(sessionId, tensorName) || isGraphOutput(sessionId, tensorName)) {
          const dataTypeEnum = tensorDataTypeStringToEnum(dataType);
          dataByteLength = calculateTensorSizeInBytes(dataTypeEnum, dims)!;
          actualLocation = 'ml-tensor';
          const createTemporaryTensor = wasm.webnnCreateTemporaryTensor;
          const uploadTensor = wasm.webnnUploadTensor;
          if (!createTemporaryTensor || !uploadTensor) {
            throw new Error('Tensor location "ml-tensor" is not supported without using WebNN.');
          }
          const tensorId = await createTemporaryTensor(sessionId, dataTypeEnum, dims as number[]);
          uploadTensor(tensorId, new Uint8Array(data.buffer, data.byteOffset, data.byteLength));
          rawData = tensorId;
        } else {
          // plain CPU tensor: copy the data to the WASM heap.
          dataByteLength = data.byteLength;
          rawData = wasm._malloc(dataByteLength);
          allocs.push(rawData);
          wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData);
        }
      } else {
        // plain CPU tensor: copy the data to the WASM heap.
        dataByteLength = data.byteLength;
        rawData = wasm._malloc(dataByteLength);
        allocs.push(rawData);
        wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData);
      }
    }
  }

  const stack = wasm.stackSave();
  // each dim is written as a pointer-sized integer below, so allocate ptrSize bytes per dim
  // (the previous `4 * dims.length` under-allocated when PTR_SIZE === 8).
  const dimsOffset = wasm.stackAlloc(ptrSize * dims.length);
  try {
    dims.forEach((d, index) => wasm.setValue(dimsOffset + index * ptrSize, d, ptrSize === 4 ? 'i32' : 'i64'));
    const tensor = wasm._OrtCreateTensor(
      tensorDataTypeStringToEnum(dataType),
      rawData,
      dataByteLength,
      dimsOffset,
      dims.length,
      dataLocationStringToEnum(actualLocation),
    );
    if (tensor === 0) {
      checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`);
    }
    tensorHandles.push(tensor);
  } finally {
    wasm.stackRestore(stack);
  }
};

/**
 * perform inference run
 *
 * Creates OrtValue tensors for all inputs (and any pre-allocated outputs), invokes `_OrtRun` (or
 * `_OrtRunWithBinding` when an IO binding is active), then converts every output OrtValue back into a
 * `TensorMetadata` tuple — copying to CPU, or wrapping as a GPU buffer / ML tensor depending on each
 * output's preferred location.
 *
 * @param sessionId - session ID of a session previously registered in `activeSessions`.
 * @param inputIndices - indices into the model's input list, one per entry of `inputTensors`.
 * @param inputTensors - metadata (type, dims, data, location) for each input.
 * @param outputIndices - indices into the model's output list, one per entry of `outputTensors`.
 * @param outputTensors - pre-allocated output metadata, or `null` for outputs ORT should allocate.
 * @param options - run options forwarded to `setRunOptions`.
 * @returns metadata for all outputs, in the order given by `outputIndices`.
 */
export const run = async (
  sessionId: number,
  inputIndices: number[],
  inputTensors: TensorMetadata[],
  outputIndices: number[],
  outputTensors: Array<TensorMetadata | null>,
  options: InferenceSession.RunOptions,
): Promise<TensorMetadata[]> => {
  const wasm = getInstance();
  const ptrSize = wasm.PTR_SIZE;
  const session = activeSessions.get(sessionId);
  if (!session) {
    throw new Error(`cannot run inference. invalid session id: ${sessionId}`);
  }
  // Session state tuple: [handle, encoded input names, encoded output names, IO binding state,
  // graph-capture flag, whether inputs/outputs are already bound from a previous run].
  const sessionHandle = session[0];
  const inputNamesUTF8Encoded = session[1];
  const outputNamesUTF8Encoded = session[2];
  const ioBindingState = session[3];
  const enableGraphCapture = session[4];
  const inputOutputBound = session[5];

  const inputCount = inputIndices.length;
  const outputCount = outputIndices.length;

  let runOptionsHandle = 0;
  let runOptionsAllocs: number[] = [];

  const inputTensorHandles: number[] = [];
  const outputTensorHandles: number[] = [];
  const inputOutputAllocs: number[] = [];
  // Handles of outputs the caller pre-allocated and we bound via _OrtBindOutput; used later to
  // decide whether an output's data copy can be skipped.
  const preAllocatedOutputs: number[] = [];

  // Stack-allocate pointer arrays for the names/values passed to _OrtRun. Restored in the finally block.
  const beforeRunStack = wasm.stackSave();
  const inputValuesOffset = wasm.stackAlloc(inputCount * ptrSize);
  const inputNamesOffset = wasm.stackAlloc(inputCount * ptrSize);
  const outputValuesOffset = wasm.stackAlloc(outputCount * ptrSize);
  const outputNamesOffset = wasm.stackAlloc(outputCount * ptrSize);

  try {
    [runOptionsHandle, runOptionsAllocs] = setRunOptions(options);

    TRACE_EVENT_BEGIN('wasm prepareInputOutputTensor');
    // create input tensors
    for (let i = 0; i < inputCount; i++) {
      await prepareInputOutputTensor(
        inputTensors[i],
        inputTensorHandles,
        inputOutputAllocs,
        sessionId,
        inputNamesUTF8Encoded[inputIndices[i]],
        inputIndices[i],
        enableGraphCapture,
      );
    }

    // create output tensors
    for (let i = 0; i < outputCount; i++) {
      await prepareInputOutputTensor(
        outputTensors[i],
        outputTensorHandles,
        inputOutputAllocs,
        sessionId,
        outputNamesUTF8Encoded[outputIndices[i]],
        // offset by inputCount so the diagnostic index is unique across inputs and outputs
        inputCount + outputIndices[i],
        enableGraphCapture,
      );
    }
    TRACE_EVENT_END('wasm prepareInputOutputTensor');

    // Fill the stack-allocated pointer arrays with the tensor handles and name pointers.
    for (let i = 0; i < inputCount; i++) {
      wasm.setValue(inputValuesOffset + i * ptrSize, inputTensorHandles[i], '*');
      wasm.setValue(inputNamesOffset + i * ptrSize, inputNamesUTF8Encoded[inputIndices[i]], '*');
    }
    for (let i = 0; i < outputCount; i++) {
      wasm.setValue(outputValuesOffset + i * ptrSize, outputTensorHandles[i], '*');
      wasm.setValue(outputNamesOffset + i * ptrSize, outputNamesUTF8Encoded[outputIndices[i]], '*');
    }

    // Bind inputs/outputs to the IO binding, unless they are still bound from a previous run
    // (graph-capture sessions keep bindings alive across runs).
    if ((!BUILD_DEFS.DISABLE_JSEP || !BUILD_DEFS.DISABLE_WEBGPU) && ioBindingState && !inputOutputBound) {
      const { handle, outputPreferredLocations, outputPreferredLocationsEncoded } = ioBindingState;

      if (inputNamesUTF8Encoded.length !== inputCount) {
        throw new Error(
          `input count from feeds (${inputCount}) is expected to be always equal to model's input count (${inputNamesUTF8Encoded.length}).`,
        );
      }

      TRACE_EVENT_BEGIN('wasm bindInputsOutputs');
      // process inputs
      for (let i = 0; i < inputCount; i++) {
        const index = inputIndices[i];
        const errorCode = await wasm._OrtBindInput(handle, inputNamesUTF8Encoded[index], inputTensorHandles[i]);
        if (errorCode !== 0) {
          checkLastError(`Can't bind input[${i}] for session=${sessionId}.`);
        }
      }

      // process pre-allocated outputs
      for (let i = 0; i < outputCount; i++) {
        const index = outputIndices[i];
        const location = outputTensors[i]?.[3]; // undefined means output is not pre-allocated.

        if (location) {
          // output is pre-allocated, store and bind the tensor.
          preAllocatedOutputs.push(outputTensorHandles[i]);
          const errorCode = wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], outputTensorHandles[i], 0);
          if (errorCode !== 0) {
            checkLastError(`Can't bind pre-allocated output[${i}] for session=${sessionId}.`);
          }
        } else {
          // output is not pre-allocated. reset preferred location.
          const errorCode = wasm._OrtBindOutput(
            handle,
            outputNamesUTF8Encoded[index],
            0,
            outputPreferredLocationsEncoded[index],
          );
          if (errorCode !== 0) {
            checkLastError(`Can't bind output[${i}] to ${outputPreferredLocations[i]} for session=${sessionId}.`);
          }
        }
      }
      TRACE_EVENT_END('wasm bindInputsOutputs');
      // Remember that inputs/outputs are now bound (last tuple element = true) so a subsequent run
      // can skip re-binding when graph capture keeps the binding alive.
      activeSessions.set(sessionId, [
        sessionHandle,
        inputNamesUTF8Encoded,
        outputNamesUTF8Encoded,
        ioBindingState,
        enableGraphCapture,
        true,
      ]);
    }

    // Notify EP backends (JSEP / WebNN) that a run is starting, if those hooks exist in this build.
    wasm.jsepOnRunStart?.(sessionHandle);
    wasm.webnnOnRunStart?.(sessionHandle);

    let errorCode: number;
    if ((!BUILD_DEFS.DISABLE_JSEP || !BUILD_DEFS.DISABLE_WEBGPU) && ioBindingState) {
      errorCode = await wasm._OrtRunWithBinding(
        sessionHandle,
        ioBindingState.handle,
        outputCount,
        outputValuesOffset,
        runOptionsHandle,
      );
    } else {
      errorCode = await wasm._OrtRun(
        sessionHandle,
        inputNamesOffset,
        inputValuesOffset,
        inputCount,
        outputNamesOffset,
        outputCount,
        outputValuesOffset,
        runOptionsHandle,
      );
    }

    if (errorCode !== 0) {
      checkLastError('failed to call OrtRun().');
    }

    const output: TensorMetadata[] = [];
    // Deferred downloads for 'ml-tensor-cpu-output' outputs; resolved together after the loop.
    const outputPromises: Array<Promise<[number, Tensor.DataType]>> = [];

    TRACE_EVENT_BEGIN('wasm ProcessOutputTensor');
    for (let i = 0; i < outputCount; i++) {
      const tensor = Number(wasm.getValue(outputValuesOffset + i * ptrSize, '*'));
      // TODO: revisit this part to ensure it works for WebGPU when both pre-allocated outputs and
      // preferred location are specified.
      // Certain pre-allocated tensors may already be bound in the IO binding. e.g. the WebNN backend
      // always binds its tensor to 'ml-tensor'. In such cases, the tensor ID might change after binding,
      // but copying data for these tensors should still be avoided.
      if (tensor === outputTensorHandles[i] || preAllocatedOutputs.includes(outputTensorHandles[i])) {
        // output tensor is pre-allocated. no need to copy data.
        output.push(outputTensors[i]!);
        if (tensor !== outputTensorHandles[i]) {
          // release redundant tensor earlier.
          if (wasm._OrtReleaseTensor(tensor) !== 0) {
            checkLastError("Can't release tensor.");
          }
        }
        continue;
      }

      const beforeGetTensorDataStack = wasm.stackSave();
      // stack allocate 4 pointer value
      const tensorDataOffset = wasm.stackAlloc(4 * ptrSize);

      // When true, ownership of the OrtValue transfers to the returned tensor's dispose() callback
      // instead of being released in the finally block below.
      let keepOutputTensor = false;
      let type: Tensor.Type | undefined,
        dataOffset = 0;
      try {
        // Out-params written by _OrtGetTensorData: [dataType, dataOffset, dimsOffset, dimsLength].
        const errorCode = wasm._OrtGetTensorData(
          tensor,
          tensorDataOffset,
          tensorDataOffset + ptrSize,
          tensorDataOffset + 2 * ptrSize,

          tensorDataOffset + 3 * ptrSize,
        );
        if (errorCode !== 0) {
          checkLastError(`Can't access output tensor data on index ${i}.`);
        }
        const valueType = ptrSize === 4 ? 'i32' : 'i64';
        const dataType = Number(wasm.getValue(tensorDataOffset, valueType));
        dataOffset = wasm.getValue(tensorDataOffset + ptrSize, '*');
        const dimsOffset = wasm.getValue(tensorDataOffset + ptrSize * 2, '*');
        const dimsLength = Number(wasm.getValue(tensorDataOffset + ptrSize * 3, valueType));
        const dims = [];
        for (let i = 0; i < dimsLength; i++) {
          dims.push(Number(wasm.getValue(dimsOffset + i * ptrSize, valueType)));
        }
        // The dims array was heap-allocated by _OrtGetTensorData; free it once copied out.
        if (wasm._OrtFree(dimsOffset) !== 0) {
          checkLastError("Can't free memory for tensor dims.");
        }
        const size = dims.reduce((a, b) => a * b, 1);
        type = tensorDataTypeEnumToString(dataType);

        const preferredLocation = ioBindingState?.outputPreferredLocations[outputIndices[i]];

        if (type === 'string') {
          if (preferredLocation === 'gpu-buffer' || preferredLocation === 'ml-tensor') {
            throw new Error('String tensor is not supported on GPU.');
          }
          // String tensor data layout: an array of `size` pointers to NUL-terminated UTF-8 strings.
          // Consecutive pointers bound each string's byte length, except the last one.
          const stringData: string[] = [];
          for (let i = 0; i < size; i++) {
            const offset = wasm.getValue(dataOffset + i * ptrSize, '*');
            const nextOffset = wasm.getValue(dataOffset + (i + 1) * ptrSize, '*');
            const maxBytesToRead = i === size - 1 ? undefined : nextOffset - offset;
            stringData.push(wasm.UTF8ToString(offset, maxBytesToRead));
          }
          output.push([type, dims, stringData, 'cpu']);
        } else {
          // If a certain output's preferred location is GPU but the tensor is empty, we still need to create a CPU
          // tensor for it. There is no mapping GPU buffer for an empty tensor.
          if (preferredLocation === 'gpu-buffer' && size > 0) {
            const getBuffer = !BUILD_DEFS.DISABLE_WEBGPU ? wasm.webgpuGetBuffer : wasm.jsepGetBuffer;
            if (!getBuffer) {
              throw new Error('preferredLocation "gpu-buffer" is not supported without using WebGPU.');
            }
            const gpuBuffer = getBuffer(dataOffset);
            const bufferSize = calculateTensorSizeInBytes(dataType, size);
            if (bufferSize === undefined || !isGpuBufferSupportedType(type)) {
              throw new Error(`Unsupported data type: ${type}`);
            }

            // do not release the tensor right now. it will be released when user calls tensor.dispose().
            keepOutputTensor = true;

            if (!BUILD_DEFS.DISABLE_WEBGPU) {
              wasm.webgpuRegisterBuffer!(gpuBuffer, sessionId, dataOffset);
              const downloadDataFunction = wasm.webgpuCreateDownloader!(gpuBuffer, bufferSize, sessionId);
              output.push([
                type,
                dims,
                {
                  gpuBuffer,
                  download: async () => {
                    const arrayBuffer = await downloadDataFunction();
                    const data = new (tensorTypeToTypedArrayConstructor(type!))(arrayBuffer);
                    return data as Tensor.DataTypeMap[Tensor.GpuBufferDataTypes];
                  },
                  dispose: () => {
                    if (wasm._OrtReleaseTensor(tensor) !== 0) {
                      checkLastError("Can't release tensor.");
                    }
                  },
                },
                'gpu-buffer',
              ]);
            } else {
              output.push([
                type,
                dims,
                {
                  gpuBuffer,
                  download: wasm.jsepCreateDownloader!(gpuBuffer, bufferSize, type),
                  dispose: () => {
                    if (wasm._OrtReleaseTensor(tensor) !== 0) {
                      checkLastError("Can't release tensor.");
                    }
                  },
                },
                'gpu-buffer',
              ]);
            }
          } else if (preferredLocation === 'ml-tensor' && size > 0) {
            const ensureTensor = wasm.webnnEnsureTensor;
            const isGraphInputOutputTypeSupported = wasm.webnnIsGraphInputOutputTypeSupported;
            if (!ensureTensor || !isGraphInputOutputTypeSupported) {
              throw new Error('preferredLocation "ml-tensor" is not supported without using WebNN.');
            }
            const tensorSize = calculateTensorSizeInBytes(dataType, size);
            if (tensorSize === undefined || !isMLTensorSupportedType(type)) {
              throw new Error(`Unsupported data type: ${type}`);
            }
            if (!isGraphInputOutputTypeSupported(sessionId, type, false)) {
              throw new Error(
                `preferredLocation "ml-tensor" for ${type} output is not supported by current WebNN Context.`,
              );
            }

            // If the graph has been partitioned, the output tensor may have not been created. For this reason, we use
            // ensureTensor to get/create the MLTensor. In which case, we don't need to copy the data if a new tensor
            // has been created.
            const mlTensor = await ensureTensor(sessionId, dataOffset, dataType, dims, false);

            // do not release the tensor right now. it will be released when user calls tensor.dispose().
            keepOutputTensor = true;

            output.push([
              type,
              dims,
              {
                mlTensor,
                download: wasm.webnnCreateMLTensorDownloader!(dataOffset, type),
                dispose: () => {
                  wasm.webnnReleaseTensorId!(dataOffset);
                  wasm._OrtReleaseTensor(tensor);
                },
              },
              'ml-tensor',
            ]);
          } else if (preferredLocation === 'ml-tensor-cpu-output' && size > 0) {
            const data = wasm.webnnCreateMLTensorDownloader!(dataOffset, type as Tensor.MLTensorDataTypes)();
            const index = output.length;
            // Delay the data download and releasing the tensor until we can wait for all output tensors to be downloaded.
            keepOutputTensor = true;
            outputPromises.push(
              (async () => {
                const result: [number, Tensor.DataType] = [index, await data];
                wasm.webnnReleaseTensorId!(dataOffset);
                wasm._OrtReleaseTensor(tensor);
                return result;
              })(),
            );
            // Placeholder data; replaced with the downloaded result after Promise.all below.
            output.push([type, dims, [], 'cpu']);
          } else {
            // Default path: copy the tensor data out of the wasm heap into a fresh typed array.
            const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type);
            const data = new typedArrayConstructor(size);
            new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set(
              wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength),
            );
            output.push([type, dims, data, 'cpu']);
          }
        }
      } finally {
        wasm.stackRestore(beforeGetTensorDataStack);
        // String tensor data was copied to a heap buffer owned by us; free it here.
        if (type === 'string' && dataOffset) {
          wasm._free(dataOffset);
        }
        if (!keepOutputTensor) {
          wasm._OrtReleaseTensor(tensor);
        }
      }
    }

    if (ioBindingState && !enableGraphCapture) {
      if (wasm._OrtClearBoundOutputs(ioBindingState.handle) !== 0) {
        checkLastError("Can't clear bound outputs.");
      }
      // Bindings were cleared; mark the session as not-bound so the next run re-binds.
      activeSessions.set(sessionId, [
        sessionHandle,
        inputNamesUTF8Encoded,
        outputNamesUTF8Encoded,
        ioBindingState,
        enableGraphCapture,
        false,
      ]);
    }
    // Wait for all output tensor data to be downloaded.
    for (const [index, data] of await Promise.all(outputPromises)) {
      output[index][2] = data;
    }
    TRACE_EVENT_END('wasm ProcessOutputTensor');
    return output;
  } finally {
    wasm.webnnOnRunEnd?.(sessionHandle);

    wasm.stackRestore(beforeRunStack);

    // Unregister any caller-provided GPU buffers that were registered for this run.
    if (!BUILD_DEFS.DISABLE_WEBGPU) {
      inputTensors.forEach((t) => {
        if (t && t[3] === 'gpu-buffer') {
          wasm.webgpuUnregisterBuffer!(t[2].gpuBuffer);
        }
      });
      outputTensors.forEach((t) => {
        if (t && t[3] === 'gpu-buffer') {
          wasm.webgpuUnregisterBuffer!(t[2].gpuBuffer);
        }
      });
    }
    // Release all input/output OrtValues created for this run and free their heap allocations.
    inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
    outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
    inputOutputAllocs.forEach((p) => wasm._free(p));

    if (runOptionsHandle !== 0) {
      wasm._OrtReleaseRunOptions(runOptionsHandle);
    }
    runOptionsAllocs.forEach((p) => wasm._free(p));
  }
};

/**
 * end profiling
 *
 * Looks up the session by ID, asks ORT to finish profiling for it, and frees the returned
 * profile-file-name string (the name itself is not used).
 *
 * @param sessionId - session ID of a session previously registered in `activeSessions`.
 * @throws if the session ID is unknown.
 */
export const endProfiling = (sessionId: number): void => {
  const session = activeSessions.get(sessionId);
  if (!session) {
    throw new Error('invalid session id');
  }

  const [sessionHandle] = session;
  const wasm = getInstance();

  // profile file name is not used yet, but it must be freed.
  const profileFileName = wasm._OrtEndProfiling(sessionHandle);
  if (profileFileName === 0) {
    checkLastError("Can't get an profile file name.");
  }
  wasm._OrtFree(profileFileName);
};

/**
 * Collect the underlying ArrayBuffers of the given tensor metadata entries so they can be passed
 * as transferable objects to `postMessage`. Entries whose data is a plain array (string tensors)
 * or has no `buffer` property are skipped.
 */
export const extractTransferableBuffers = (tensors: readonly SerializableTensorMetadata[]): ArrayBufferLike[] => {
  const buffers: ArrayBufferLike[] = [];
  tensors.forEach(([, , data]) => {
    if (!Array.isArray(data) && 'buffer' in data) {
      buffers.push(data.buffer);
    }
  });
  return buffers;
};
