// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {Env, Tensor, TRACE, TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common';

import {DataType, tensorDataTypeEnumToString} from '../wasm-common';

import {configureLogger, LOG_DEBUG} from './log';
import {createView, TensorView} from './tensor-view';
import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager';
import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules';
import {ProgramManager} from './webgpu/program-manager';
import {AdapterInfo, ComputeContext, GpuArchitecture, GpuData, GpuVendor, ProgramInfo, ProgramInputTensorInfoDependency, SessionState, TimestampQuery} from './webgpu/types';

/**
 * A recorded GPU compute command.
 *
 * Lists of these are stored per session in `WebGpuBackend.capturedCommandList` and are re-encoded by
 * `WebGpuBackend.replay()`.
 */
interface CommandInfo {
  // ID of the kernel associated with this command (not read by the code visible in this file).
  readonly kernelId: number;
  // pipeline passed to `setPipeline()` when the command is replayed.
  readonly computePipeline: GPUComputePipeline;
  // bind group bound at index 0 when the command is replayed.
  readonly bindGroup: GPUBindGroup;
  // the 3 arguments passed to `dispatchWorkgroups()` when the command is replayed.
  readonly dispatchGroup: [number, number, number];
}

/**
 * Information of a kernel registered via `WebGpuBackend.createKernel()`.
 */
interface KernelInfo {
  // the key that was used to look up the implementation in WEBGPU_OP_RESOLVE_RULES.
  readonly kernelType: string;
  // the name given to createKernel(); used in log, error and profiling messages.
  readonly kernelName: string;
  // the function invoked by computeKernel() with the compute context and the attributes.
  readonly kernelEntry: RunFunction;
  // a tuple of [attribute parser, attribute value]. On the first computeKernel() call, a defined parser is applied to
  // the value, the result replaces element 1, and element 0 is cleared to `undefined`.
  readonly attributes: [((attribute: unknown) => unknown)|undefined, unknown];
}

/**
 * Information of a program run that was recorded by `WebGpuBackend.run()`.
 *
 * It is used by `WebGpuBackend.flush()` to build the profiling report once the timestamps are read back.
 */
interface PendingKernelInfo {
  // ID of the kernel that was being computed when the program was run.
  readonly kernelId: number;
  // name of the program (taken from the program artifact).
  readonly programName: string;
  // input tensor views of the program run; only dims and data type are reported.
  readonly inputTensorViews: readonly TensorView[];
  // output tensor views of the program run; only dims and data type are reported.
  readonly outputTensorViews: readonly TensorView[];
}

/**
 * build the part of a program cache key that describes the inputs.
 *
 * Each input contributes one segment, and the segments are joined by '|'. What goes into a segment depends on the
 * dependency declared for that input:
 * - 'none': an empty string
 * - 'type': the data type
 * - 'rank': the data type and the number of dimensions, separated by ';'
 * - 'dims': the data type and the comma separated dimensions, separated by ';'
 *
 * @param inputTensors the input tensor views.
 * @param inputDependencies the dependency of each input. must have the same length as `inputTensors`.
 * @returns the joined key segments.
 * @throws Error if the lengths do not match, or a dependency value is not one of the above.
 */
const getProgramInputTensorInfoDependencyKey =
    (inputTensors: readonly TensorView[], inputDependencies: readonly ProgramInputTensorInfoDependency[]): string => {
      if (inputTensors.length !== inputDependencies.length) {
        throw new Error(`inputDependencies length ${inputDependencies.length} is not equal to inputTensors length ${
            inputTensors.length}.`);
      }

      const segments = inputTensors.map((tensor, index) => {
        const dataType = tensor.dataType;
        const dependency = inputDependencies[index];
        if (dependency === 'none') {
          return '';
        }
        if (dependency === 'type') {
          return `${dataType}`;
        }
        if (dependency === 'rank') {
          return `${dataType};${tensor.dims.length}`;
        }
        if (dependency === 'dims') {
          return `${dataType};${tensor.dims.join(',')}`;
        }
        throw new Error(`unsupported input dependency: ${inputDependencies[index]}`);
      });

      return segments.join('|');
    };

/**
 * get a unique key that identifies a program, built from the program info and the input shapes and types.
 *
 * The key is a much shorter string than the shader source, yet contains all the information needed to identify a
 * program: if two keys are equal, the shader sources should be equal too, so the program can be reused.
 *
 * @param programInfo the program info. its name, optional cache hint and optional input dependencies are used.
 * @param inputTensors the input tensor views of the program.
 * @param is1DimensionDispatch whether the (normalized) dispatch group uses only the first dimension.
 * @returns the key, in the format:
 * `<PROGRAM_NAME>[<PROGRAM_CUSTOM_CACHE_HINT>]:is1DimensionDispatch:<INPUTS_INFO_0>|<INPUTS_INFO_1>|...`
 */
const getProgramInfoUniqueKey =
    (programInfo: ProgramInfo, inputTensors: readonly TensorView[], is1DimensionDispatch: boolean): string => {
      // the "[hint]" part is only present when a non-empty hint is specified.
      const cacheHint = programInfo.shaderCache?.hint;
      const hintPart = cacheHint ? `[${cacheHint}]` : '';

      // when no input dependencies are specified, every input is treated as 'dims'.
      const inputDependencies = programInfo.shaderCache?.inputDependencies ??
          inputTensors.map((): ProgramInputTensorInfoDependency => 'dims');
      const inputsPart = getProgramInputTensorInfoDependencyKey(inputTensors, inputDependencies);

      return `${programInfo.name}${hintPart}:${is1DimensionDispatch}:${inputsPart}`;
    };

/**
 * A snapshot of the GPU adapter's architecture and vendor, with helpers to compare them against known values.
 */
class AdapterInfoImpl implements AdapterInfo {
  readonly architecture?: string;
  readonly vendor?: string;

  /**
   * @param adapterInfo the adapter info to copy from. if it is falsy, both fields are left unset.
   */
  constructor(adapterInfo: GPUAdapterInfo) {
    if (!adapterInfo) {
      return;
    }
    const {architecture, vendor} = adapterInfo;
    this.architecture = architecture;
    this.vendor = vendor;
  }

  /**
   * @returns true if the adapter's architecture is exactly the given one.
   */
  isArchitecture(architecture: GpuArchitecture): boolean {
    return architecture === this.architecture;
  }

  /**
   * @returns true if the adapter's vendor is exactly the given one.
   */
  isVendor(vendor: GpuVendor): boolean {
    return vendor === this.vendor;
  }
}

/**
 * This class stores the state of the WebGPU backend and is used as a singleton for JSEP. It is passed to jsepInit()
 * as the first parameter so that it is stored for future use.
 */
export class WebGpuBackend {
  adapterInfo: AdapterInfoImpl;
  device: GPUDevice;
  /**
   * an instance of GpuDataManager to manage a GpuDataId -> GpuBuffer mapping
   */
  gpuDataManager: GpuDataManager;
  /**
   * an instance of ProgramManager to build and run WebGPU compute shader program, and manage a ProgramKey -> Program
   * artifacts mapping
   */
  programManager: ProgramManager;

  /**
   * the ID of the session that is currently being run.
   * `null` means no session is being run.
   * only valid while session.run is being executed.
   */
  currentSessionId: number|null = null;

  /**
   * the ID of the kernel that is currently being computed (from the CPU code's perspective).
   * `null` means no kernel is being computed.
   * only one kernel can be computed at a time.
   */
  currentKernelId: number|null = null;
  /**
   * a list of temporary GPU data for the current kernel. it is released when the kernel finishes its computation.
   */
  private temporaryData: GpuData[];
  /**
   * a KernelID -> a GPU data list, which stores persistent GPU data owned by the specific kernel.
   */
  private kernelPersistentData: Map<number, GpuData[]>;
  /**
   * a KernelID -> a custom data, which stores custom data owned by the specific kernel.
   */
  private kernelCustomData: Map<number, {[key: string]: unknown}>;
  /**
   * get the custom data of the current kernel. an empty object is created and stored on first access.
   *
   * @throws Error if no kernel is currently being computed.
   */
  get currentKernelCustomData(): {[key: string]: unknown} {
    const kernelId = this.currentKernelId;
    if (kernelId === null) {
      throw new Error('currentKernelCustomData(): currentKernelId is null. (should not happen)');
    }

    const existing = this.kernelCustomData.get(kernelId);
    if (existing) {
      return existing;
    }

    const created: {[key: string]: unknown} = {};
    this.kernelCustomData.set(kernelId, created);
    return created;
  }

  // KernelID -> kernelInfo mapping
  kernels: Map<number, KernelInfo>;
  private commandEncoder: GPUCommandEncoder|null = null;
  private computePassEncoder: GPUComputePassEncoder|null = null;
  maxDispatchNumber = 16;
  pendingDispatchNumber = 0;

  // info of kernels pending submission for a single batch
  private pendingKernels: PendingKernelInfo[] = [];
  // queryReadBuffer -> pendingKernels mapping for all the batches
  private pendingQueries: Map<GPUBuffer, PendingKernelInfo[]> = new Map();
  private queryResolveBuffer?: GPUBuffer;
  private querySet?: GPUQuerySet;
  private queryTimeBase?: bigint;
  queryType: TimestampQuery;

  env: Env;
  sessionStatus: SessionState = 'default';
  /**
   * a SessionID -> CommandInfo[] mapping. It's used to record all GPU commands for corresponding session.
   */
  capturedCommandList: Map<number, CommandInfo[]> = new Map();

  /**
   * a SessionID -> PendingKernelInfo[] mapping for profiling.
   */
  private capturedPendingKernels: Map<number, PendingKernelInfo[]> = new Map();

  /**
   * a SessionID -> a Map of (InputOutputIndex -> [ID, GPUBuffer]) mapping.
   */
  sessionExternalDataMapping: Map<number, Map<number, [number, GPUBuffer]>> = new Map();

  /**
   * initialize the backend: request the GPU device from the adapter, create the data manager and program manager,
   * configure the logger, expose `device` and `adapter` on `env.webgpu` as read-only properties and initialize the
   * timestamp query type.
   *
   * @param env the environment flags object.
   * @param adapter the GPU adapter that the device is requested from.
   */
  async initialize(env: Env, adapter: GPUAdapter): Promise<void> {
    this.env = env;
    const requiredFeatures: GPUFeatureName[] = [];
    const deviceDescriptor: GPUDeviceDescriptor = {
      // request the limits that the adapter reports for the following entries.
      requiredLimits: {
        maxComputeWorkgroupStorageSize: adapter.limits.maxComputeWorkgroupStorageSize,
        maxComputeWorkgroupsPerDimension: adapter.limits.maxComputeWorkgroupsPerDimension,
        maxStorageBufferBindingSize: adapter.limits.maxStorageBufferBindingSize,
        maxBufferSize: adapter.limits.maxBufferSize,
        maxComputeInvocationsPerWorkgroup: adapter.limits.maxComputeInvocationsPerWorkgroup,
        maxComputeWorkgroupSizeX: adapter.limits.maxComputeWorkgroupSizeX,
        maxComputeWorkgroupSizeY: adapter.limits.maxComputeWorkgroupSizeY,
        maxComputeWorkgroupSizeZ: adapter.limits.maxComputeWorkgroupSizeZ,
      },
      // NOTE: this array is filled below, before the descriptor is used.
      requiredFeatures,
    };

    // prefer the timestamp query inside passes when it is available; otherwise use the standard timestamp query.
    if (adapter.features.has('chromium-experimental-timestamp-query-inside-passes')) {
      requiredFeatures.push('chromium-experimental-timestamp-query-inside-passes' as GPUFeatureName);
    } else if (adapter.features.has('timestamp-query')) {
      requiredFeatures.push('timestamp-query');
    }
    if (adapter.features.has('shader-f16')) {
      requiredFeatures.push('shader-f16');
    }

    this.device = await adapter.requestDevice(deviceDescriptor);

    // `GPUAdapter.requestAdapterInfo()` was removed from the WebGPU specification in favor of the synchronous
    // `GPUAdapter.info` attribute. Use `info` when the browser exposes it, and only fall back to the legacy method
    // when it still exists, so that initialization works on both old and new implementations.
    const adapterWithInfo = adapter as GPUAdapter & {readonly info?: GPUAdapterInfo};
    let adapterInfo: GPUAdapterInfo|undefined = adapterWithInfo.info;
    if (!adapterInfo && typeof adapter.requestAdapterInfo === 'function') {
      adapterInfo = await adapter.requestAdapterInfo();
    }
    // AdapterInfoImpl tolerates a missing adapter info (architecture and vendor are left unset).
    this.adapterInfo = new AdapterInfoImpl(adapterInfo as GPUAdapterInfo);

    this.gpuDataManager = createGpuDataManager(this);
    this.programManager = new ProgramManager(this);
    this.kernels = new Map();
    this.kernelPersistentData = new Map();
    this.kernelCustomData = new Map();

    // set up flags for logger
    configureLogger(env.logLevel!, !!env.debug);

    // TODO: set up flags

    // only validation errors are reported here; other uncaptured error types are ignored by this handler.
    this.device.onuncapturederror = ev => {
      if (ev.error instanceof GPUValidationError) {
        // eslint-disable-next-line no-console
        console.error(`An uncaught WebGPU validation error was raised: ${ev.error.message}`);
      }
    };

    Object.defineProperty(
        this.env.webgpu, 'device', {value: this.device, writable: false, enumerable: true, configurable: false});
    Object.defineProperty(
        this.env.webgpu, 'adapter', {value: adapter, writable: false, enumerable: true, configurable: false});

    // init queryType, which is necessary for InferenceSession.create
    this.setQueryType();
  }

  /**
   * release the GPU resources owned by the backend: the timestamp query set and its resolve buffer (if they were
   * created) and everything managed by the GPU data manager.
   */
  dispose(): void {
    if (typeof this.querySet !== 'undefined') {
      this.querySet.destroy();
    }
    // the resolve buffer is created together with the query set in setQueryType(), so it has to be released as well.
    if (typeof this.queryResolveBuffer !== 'undefined') {
      this.queryResolveBuffer.destroy();
    }
    this.gpuDataManager.dispose();
  }

  /**
   * get the current command encoder. a new one is created from the device if there is none.
   */
  getCommandEncoder(): GPUCommandEncoder {
    let encoder = this.commandEncoder;
    if (!encoder) {
      encoder = this.device.createCommandEncoder();
      this.commandEncoder = encoder;
    }
    return encoder;
  }

  /**
   * get the current compute pass encoder. a new compute pass is begun on the current command encoder if there is none.
   *
   * When the query type is 'at-passes', the new pass writes its begin/end timestamps into the query set at the 2
   * slots that belong to the current pending dispatch.
   */
  getComputePassEncoder(): GPUComputePassEncoder {
    if (this.computePassEncoder) {
      return this.computePassEncoder;
    }

    const encoder = this.getCommandEncoder();
    const descriptor: GPUComputePassDescriptor = {};
    if (this.queryType === 'at-passes') {
      const firstQueryIndex = this.pendingDispatchNumber * 2;
      descriptor.timestampWrites = {
        querySet: this.querySet!,
        beginningOfPassWriteIndex: firstQueryIndex,
        endOfPassWriteIndex: firstQueryIndex + 1,
      };
    }

    const pass = encoder.beginComputePass(descriptor);
    this.computePassEncoder = pass;
    return pass;
  }

  /**
   * end the current compute pass, if there is one.
   */
  endComputePass(): void {
    const activePass = this.computePassEncoder;
    if (!activePass) {
      return;
    }
    activePass.end();
    this.computePassEncoder = null;
  }

  /**
   * submit all pending GPU commands to the queue. does nothing if there is no command encoder.
   *
   * When timestamp query is enabled, the timestamps of the pending dispatches are resolved and copied into a
   * dedicated read buffer as part of the submission. Once that buffer is mapped, one profiling entry per dispatch is
   * reported, either to `env.webgpu.profiling.ondata` or to the console.
   */
  flush(): void {
    if (!this.commandEncoder) {
      return;
    }

    TRACE_FUNC_BEGIN();

    this.endComputePass();

    // Take ownership of the kernels recorded for this batch and always start a new list. The list is also filled
    // while capturing with timestamp query disabled; if it were kept in that case, its stale entries would be
    // associated with the timestamps of a later batch once timestamp query gets enabled.
    const batchKernels = this.pendingKernels;
    this.pendingKernels = [];

    let queryReadBuffer: GPUBuffer|undefined;
    if (this.queryType !== 'none') {
      // each dispatch uses 2 timestamps (begin and end), and each timestamp is a 64-bit (8 bytes) value.
      const queryByteLength = this.pendingDispatchNumber * 2 * 8;

      this.commandEncoder.resolveQuerySet(
          this.querySet!, 0, this.pendingDispatchNumber * 2, this.queryResolveBuffer!, 0);

      queryReadBuffer = this.device.createBuffer(
          // eslint-disable-next-line no-bitwise
          {size: queryByteLength, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST});

      this.pendingQueries.set(queryReadBuffer, batchKernels);
      this.commandEncoder.copyBufferToBuffer(this.queryResolveBuffer!, 0, queryReadBuffer, 0, queryByteLength);
    }

    this.device.queue.submit([this.commandEncoder.finish()]);
    this.gpuDataManager.refreshPendingBuffers();
    this.commandEncoder = null;
    this.pendingDispatchNumber = 0;

    if (queryReadBuffer) {
      const readBuffer = queryReadBuffer;
      void readBuffer.mapAsync(GPUMapMode.READ).then(() => {
        try {
          const mappedData = new BigUint64Array(readBuffer.getMappedRange());
          const pendingKernels = this.pendingQueries.get(readBuffer)!;
          for (let i = 0; i < mappedData.length / 2; i++) {
            const pendingKernelInfo = pendingKernels[i];
            const kernelId = pendingKernelInfo.kernelId;
            const kernelInfo = this.kernels.get(kernelId)!;
            const kernelType = kernelInfo.kernelType;
            const kernelName = kernelInfo.kernelName;
            const programName = pendingKernelInfo.programName;
            const inputTensorViews = pendingKernelInfo.inputTensorViews;
            const outputTensorViews = pendingKernelInfo.outputTensorViews;
            const startTimeU64 = mappedData[i * 2];
            const endTimeU64 = mappedData[i * 2 + 1];

            // the first timestamp ever read becomes the time base, so that the reported times fit in a number.
            if (typeof this.queryTimeBase === 'undefined') {
              this.queryTimeBase = startTimeU64;
            }

            const startTime = Number(startTimeU64 - this.queryTimeBase);
            const endTime = Number(endTimeU64 - this.queryTimeBase);

            if (!Number.isSafeInteger(startTime) || !Number.isSafeInteger(endTime)) {
              throw new RangeError('incorrect timestamp range');
            }

            if (this.env.webgpu.profiling?.ondata) {
              this.env.webgpu.profiling.ondata({
                version: 1,
                inputsMetadata: inputTensorViews.map(
                    value => ({dims: value.dims, dataType: tensorDataTypeEnumToString(value.dataType)})),
                outputsMetadata: outputTensorViews.map(
                    value => ({dims: value.dims, dataType: tensorDataTypeEnumToString(value.dataType)})),
                kernelId,
                kernelType,
                kernelName,
                programName,
                startTime,
                endTime,
              });
            } else {
              // if no callback is provided, print the profiling message to console
              let inputShapes = '';
              inputTensorViews.forEach((value, i) => {
                inputShapes += `input[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `;
              });
              let outputShapes = '';
              outputTensorViews.forEach((value, i) => {
                outputShapes += `output[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `;
              });
              // eslint-disable-next-line no-console
              console.log(`[profiling] kernel "${kernelId}|${kernelType}|${kernelName}|${programName}" ${inputShapes}${
                  outputShapes}execution time: ${endTime - startTime} ns`);
            }
            TRACE('GPU', `${programName}::${startTimeU64}::${endTimeU64}`);
          }
        } finally {
          // always clean up, even when reporting failed: the read buffer is created per flush and is not used again,
          // so it is destroyed explicitly instead of waiting for garbage collection.
          readBuffer.unmap();
          this.pendingQueries.delete(readBuffer);
          readBuffer.destroy();
        }
      });
    }
    TRACE_FUNC_END();
  }

  /**
   * run a WebGPU program.
   * @param program a ProgramInfo instance
   * @param inputTensorViews a TensorView array. each element represents a value already exists in GPU.
   * @param outputIndices an indices array. each element can be either -1 (temporary data), -2 (persistent data), -3
   * (placeholder, no output is created) or an index to the kernel's output. an empty array means that the program's
   * outputs map to the kernel's outputs 0, 1, 2, ... in order.
   * @param createKernelOutput a callback function that create a value to kernel's output with the given index
   * @param createIntermediateOutput a callback function that create a value as a intermediate value, either temporary
   * or persistent (owned by the current kernel)
   * @returns a TensorView array representing the result.
   * @throws Error if an input or output has no GPU data, if the output indices are invalid, if only some of the
   * tensors are zero-sized, or if a uniform has an unsupported type.
   */
  run(program: ProgramInfo, inputTensorViews: readonly TensorView[], outputIndices: readonly number[],
      createKernelOutput: (index: number, dataType: number, dims: readonly number[]) => TensorView,
      createIntermediateOutput: (dataType: number, dims: readonly number[]) => TensorView): TensorView[] {
    TRACE_FUNC_BEGIN(program.name);
    // create info for inputs
    const inputDatas: GpuData[] = [];
    for (let i = 0; i < inputTensorViews.length; ++i) {
      const data = inputTensorViews[i].data;
      // if tensor view data is 0, it means the output is zero-sized tensor, and there is no GPU data for it.
      if (data === 0) {
        continue;
      }
      const gpuData = this.gpuDataManager.get(data);
      if (!gpuData) {
        throw new Error(`no GPU data for input: ${data}`);
      }
      inputDatas.push(gpuData);
    }

    const {outputs, dispatchGroup, programUniforms} = program.getRunData(inputTensorViews);

    // check output indices
    const validatedOutputIndices = outputIndices.length === 0 ? outputs.map((_, i) => i) : outputIndices;
    if (validatedOutputIndices.length !== outputs.length) {
      throw new Error(`Output size ${validatedOutputIndices.length} must be equal to ${outputs.length}.`);
    }

    // create info for outputs
    const outputTensorViews: TensorView[] = [];
    const outputDatas: GpuData[] = [];
    for (let i = 0; i < outputs.length; ++i) {
      // value -1 and -2 are used for creating temporary and persistent outputs.
      // value -3 is used for placeholder output. So -3, -2, -1 and 0, 1, 2, ... are valid
      // output indices. see type definition of ComputeContextInputsOutputsMapping for more details.
      if (!Number.isInteger(validatedOutputIndices[i]) || validatedOutputIndices[i] < -3 ||
          validatedOutputIndices[i] >= outputs.length) {
        throw new Error(`Invalid output index: ${validatedOutputIndices[i]}`);
      }
      if (validatedOutputIndices[i] === -3) {
        continue;
      }
      const isTemporary = validatedOutputIndices[i] === -1;
      const isPersistent = validatedOutputIndices[i] === -2;
      const tensorView = (isTemporary || isPersistent) ?
          createIntermediateOutput(outputs[i].dataType, outputs[i].dims) :
          createKernelOutput(validatedOutputIndices[i], outputs[i].dataType, outputs[i].dims);
      outputTensorViews.push(tensorView);
      // if tensor view data is 0, it means the output is zero-sized tensor, and there is no GPU data for it.
      if (tensorView.data === 0) {
        continue;
      }
      const gpuData = this.gpuDataManager.get(tensorView.data);
      if (!gpuData) {
        throw new Error(`no GPU data for output: ${tensorView.data}`);
      }
      if (isTemporary) {
        // temporary data is released when the current kernel finishes (see computeKernel()).
        this.temporaryData.push(gpuData);
      }
      if (isPersistent) {
        // persistent data is owned by the current kernel and released in releaseKernel().
        let persistentData = this.kernelPersistentData.get(this.currentKernelId!);
        if (!persistentData) {
          persistentData = [];
          this.kernelPersistentData.set(this.currentKernelId!, persistentData);
        }
        persistentData.push(gpuData);
      }
      outputDatas.push(gpuData);
    }

    // when there are any zero-sized tensor in the inputs or outputs, we should report error unless all outputs are
    // zero-sized tensors.
    if (inputDatas.length !== inputTensorViews.length || outputDatas.length !== outputTensorViews.length) {
      // if all outputs are zero-sized tensors, there is no need to run the program.
      if (outputDatas.length === 0) {
        TRACE_FUNC_END(program.name);
        return outputTensorViews;
      }
      // if some outputs are zero-sized tensors, report an error.
      //
      // TODO: so far we don't see any use case that outputs include both zero-sized tensors and non-zero-sized tensors.
      // If we see such use case, we need to make a change here to support it.
      throw new Error(
          `Program ${program.name} has zero-sized tensor(s) in inputs or outputs. This is not supported now.`);
    }

    // load uniforms
    // TODO: add cache for uniform (is it necessary?)
    //
    let uniformBufferBinding: GPUBindingResource|undefined;
    if (programUniforms) {
      let currentOffset = 0;
      // the byte offset of each uniform in the uniform buffer. this array is indexed in the same way as
      // `programUniforms`, so it must get exactly one entry per uniform.
      const offsets: number[] = [];

      programUniforms.forEach(v => {
        const data = typeof v.data === 'number' ? [v.data] : v.data;
        if (data.length === 0) {
          // a uniform with no data takes no space in the buffer. still record a (harmless) offset of 0 for it, so
          // that the offsets of the following uniforms stay aligned with their indices in `programUniforms`.
          offsets.push(0);
          return;
        }
        // https://www.w3.org/TR/WGSL/#alignof
        const sizeOfElement = v.type === DataType.float16 ? 2 : 4;
        let sizeOfVecOrMat;
        let baseAlignment;
        if (v.type === DataType.float16) {
          baseAlignment = data.length > 4 ? 16 : (data.length > 2 ? 8 : data.length * sizeOfElement);
          sizeOfVecOrMat = data.length > 4 ? 16 : sizeOfElement * data.length;
        } else {
          baseAlignment = data.length <= 2 ? data.length * sizeOfElement : 16;
          sizeOfVecOrMat = 16;
        }
        currentOffset = Math.ceil(currentOffset / baseAlignment) * baseAlignment;
        offsets.push(currentOffset);
        // For non-float16 type, when data.length > 4, the uniform variable is of type array<vec4<i32|u32|f32>,N>, where
        // N = Math.ceil(data.length / 4) and SizeOf(vec4<i32|u32|f32>) = 16. The total byte length is N *
        // SizeOf(vec4<i32|u32|f32>). For float16 type, when data.length > 4, the uniform variable is of type
        // array<mat2x4<f16>,N>, where N = Math.ceil(data.length / 8) and SizeOf(mat2x4<f16>) = 16. The total byte
        // length is N * SizeOf(mat2x4<f16>).
        const elementPerVecOrMat = v.type === DataType.float16 ? 8 : 4;
        currentOffset += data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat :
                                           data.length * sizeOfElement;
      });

      // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set
      // maxAlignmentOfField to 16 since the underlying buffer has been rounded up to 16.
      const maxAlignmentOfField = 16;
      currentOffset = Math.ceil(currentOffset / maxAlignmentOfField) * maxAlignmentOfField;
      const arrayBuffer = new ArrayBuffer(currentOffset);
      programUniforms.forEach((v, i) => {
        const offset = offsets[i];
        const data = typeof v.data === 'number' ? [v.data] : v.data;
        // for a uniform with no data, the views below have length 0 and nothing is written.
        if (v.type === DataType.int32) {
          new Int32Array(arrayBuffer, offset, data.length).set(data);
        } else if (v.type === DataType.uint32) {
          new Uint32Array(arrayBuffer, offset, data.length).set(data);
        } else if (v.type === DataType.float16) {
          // TODO: use Float16Array.
          new Uint16Array(arrayBuffer, offset, data.length).set(data);
        } else if (v.type === DataType.float) {
          new Float32Array(arrayBuffer, offset, data.length).set(data);
        } else {
          throw new Error(`Unsupported uniform type: ${tensorDataTypeEnumToString(v.type)}`);
        }
      });

      const uniformBufferData =
          // eslint-disable-next-line no-bitwise
          this.gpuDataManager.create(currentOffset, GPUBufferUsage.COPY_DST | GPUBufferUsage.UNIFORM);
      this.device.queue.writeBuffer(uniformBufferData.buffer, 0, arrayBuffer, 0, currentOffset);
      // NOTE(review): the buffer is released right away although it is still referenced by the binding below;
      // presumably the data manager defers the actual release until the pending commands are submitted - confirm.
      this.gpuDataManager.release(uniformBufferData.id);
      uniformBufferBinding = {offset: 0, size: currentOffset, buffer: uniformBufferData.buffer};
    }

    const normalizedDispatchGroup = this.programManager.normalizeDispatchGroupSize(dispatchGroup);
    const is1DimensionDispatch = normalizedDispatchGroup[1] === 1 && normalizedDispatchGroup[2] === 1;
    // get program info
    const key = getProgramInfoUniqueKey(program, inputTensorViews, is1DimensionDispatch);
    let artifact = this.programManager.getArtifact(key);
    if (!artifact) {
      artifact = this.programManager.build(program, normalizedDispatchGroup);
      this.programManager.setArtifact(key, artifact);
      LOG_DEBUG('info', () => `[artifact] key: ${key}, programName: ${program.name}`);
    }

    LOG_DEBUG(
        'info',
        () => `[ProgramManager] run "${program.name}" (key=${key}) with ${normalizedDispatchGroup[0]}x${
            normalizedDispatchGroup[1]}x${normalizedDispatchGroup[2]}`);

    // record the program run for profiling, and for replaying the profiling info of a captured session.
    if (this.queryType !== 'none' || this.sessionStatus === 'capturing') {
      const pendingKernelInfo: PendingKernelInfo = {
        kernelId: this.currentKernelId!,
        programName: artifact.programInfo.name,
        inputTensorViews,
        outputTensorViews,
      };
      this.pendingKernels.push(pendingKernelInfo);

      if (this.sessionStatus === 'capturing') {
        const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!);
        sessionPendingKernels!.push(pendingKernelInfo);
      }
    }

    this.programManager.run(artifact, inputDatas, outputDatas, normalizedDispatchGroup, uniformBufferBinding);

    TRACE_FUNC_END(program.name);
    return outputTensorViews;
  }

  /**
   * upload data to the GPU data with the given ID. forwards to the GPU data manager.
   * @param gpuDataId ID of the target GPU data.
   * @param data the bytes to upload.
   */
  upload(gpuDataId: number, data: Uint8Array): void {
    this.gpuDataManager.upload(gpuDataId, data);
  }

  /**
   * copy from one GPU data to another. forwards to the GPU data manager.
   * @param src ID of the source GPU data.
   * @param dst ID of the destination GPU data.
   */
  memcpy(src: number, dst: number): void {
    this.gpuDataManager.memcpy(src, dst);
  }

  /**
   * download the GPU data with the given ID. forwards to the GPU data manager and resolves when it is done.
   * @param gpuDataId ID of the GPU data to download.
   * @param getTargetBuffer a getter function that returns the buffer to receive the data.
   */
  async download(gpuDataId: number, getTargetBuffer: () => Uint8Array): Promise<void> {
    // the underlying buffer may be changed after the async function is called. so we use a getter function to make sure
    // the buffer is up-to-date.
    await this.gpuDataManager.download(gpuDataId, getTargetBuffer);
  }

  /**
   * create a GPU data of the given size via the GPU data manager.
   * @param size the size passed to the GPU data manager (presumably in bytes - confirm against GpuDataManager).
   * @returns the ID of the created GPU data.
   */
  alloc(size: number): number {
    return this.gpuDataManager.create(size).id;
  }

  /**
   * release the GPU data with the given ID via the GPU data manager.
   * @param ptr ID of the GPU data to release.
   * @returns the value returned by the GPU data manager's release() (its meaning is not visible in this file).
   */
  free(ptr: number): number {
    return this.gpuDataManager.release(ptr);
  }

  /**
   * register a kernel, so that it can be computed by computeKernel() later.
   *
   * @param kernelType the type of the kernel. it is used to look up the implementation in WEBGPU_OP_RESOLVE_RULES.
   * @param kernelId the ID to register the kernel with.
   * @param attribute the raw attribute of the kernel. it is parsed lazily, when the kernel is computed the first time.
   * @param kernelName the name of the kernel.
   * @throws Error if there is no implementation for the kernel type.
   */
  createKernel(kernelType: string, kernelId: number, attribute: unknown, kernelName: string): void {
    const resolved = WEBGPU_OP_RESOLVE_RULES.get(kernelType);
    if (!resolved) {
      throw new Error(`kernel not implemented: ${kernelType}`);
    }

    // resolved[0] is the run function, resolved[1] is the (optional) attribute parser.
    this.kernels.set(kernelId, {
      kernelType,
      kernelName,
      kernelEntry: resolved[0],
      attributes: [resolved[1], attribute],
    });
  }

  /**
   * release a kernel: its persistent GPU data is released, and its custom data and registration are removed.
   *
   * @param kernelId the ID of the kernel to release.
   */
  releaseKernel(kernelId: number): void {
    const ownedGpuData = this.kernelPersistentData.get(kernelId);
    if (ownedGpuData) {
      ownedGpuData.forEach(gpuData => {
        this.gpuDataManager.release(gpuData.id);
      });
      this.kernelPersistentData.delete(kernelId);
    }

    this.kernelCustomData.delete(kernelId);
    this.kernels.delete(kernelId);
  }

  /**
   * compute a kernel that was registered by createKernel().
   *
   * @param kernelId the ID of the kernel to compute.
   * @param context the compute context that is passed to the kernel implementation.
   * @param errors an array that collects promises of error messages. a failure of the kernel implementation is pushed
   * as a resolved promise; when `env.debug` is set, a promise of the GPU validation error message (or null if there
   * is none) is pushed as well.
   * @returns 0 (ORT_OK) if the kernel implementation returned normally, or 1 (ORT_FAIL) if it threw.
   * @throws Error if the kernel is not created, or if another kernel is still being computed. an error thrown by the
   * attribute parser is propagated as well.
   */
  computeKernel(kernelId: number, context: ComputeContext, errors: Array<Promise<string|null>>): number {
    const kernel = this.kernels.get(kernelId);
    if (!kernel) {
      throw new Error(`kernel not created: ${kernelId}`);
    }
    const kernelType = kernel.kernelType;
    const kernelName = kernel.kernelName;
    const kernelEntry = kernel.kernelEntry;
    const attributes = kernel.attributes;
    if (this.currentKernelId !== null) {
      throw new Error(`kernel "[${kernelType}] ${kernelName}" is not allowed to be called recursively`);
    }

    // parse attributes if necessary. the parser is cleared afterwards so that it only runs once.
    //
    // this is done BEFORE the kernel is marked as the current kernel: if the parser throws, the backend must not be
    // left with a stale `currentKernelId`, which would make every following computeKernel() call fail with the
    // "recursively" error above.
    if (attributes[0]) {
      attributes[1] = attributes[0](attributes[1]);
      attributes[0] = undefined;
    }

    this.currentKernelId = kernelId;

    LOG_DEBUG('info', () => `[WebGPU] Start to run kernel "[${kernelType}] ${kernelName}"...`);

    const useErrorScope = this.env.debug;

    this.temporaryData = [];
    try {
      if (useErrorScope) {
        this.device.pushErrorScope('validation');
      }

      kernelEntry(context, attributes[1]);
      return 0;  // ORT_OK
    } catch (e) {
      errors.push(Promise.resolve(`[WebGPU] Kernel "[${kernelType}] ${kernelName}" failed. ${e}`));
      return 1;  // ORT_FAIL
    } finally {
      if (useErrorScope) {
        errors.push(this.device.popErrorScope().then(
            err => err ? `GPU validation error for kernel "[${kernelType}] ${kernelName}": ${err.message}` : null));
      }

      // the temporary data is only valid during the computation of the kernel. release it in any case.
      for (const data of this.temporaryData) {
        this.gpuDataManager.release(data.id);
      }
      this.temporaryData = [];
      this.currentKernelId = null;
    }
  }

  // #region external buffer
  /**
   * register an external GPU buffer for an input/output of a session.
   *
   * @param sessionId the ID of the session.
   * @param index the input/output index that the buffer is registered for.
   * @param buffer the external GPU buffer.
   * @param size the size passed to the GPU data manager along with the buffer.
   * @returns the ID that the GPU data manager assigned to the buffer.
   */
  registerBuffer(sessionId: number, index: number, buffer: GPUBuffer, size: number): number {
    let mapping = this.sessionExternalDataMapping.get(sessionId);
    if (!mapping) {
      mapping = new Map();
      this.sessionExternalDataMapping.set(sessionId, mapping);
    }

    // if a buffer was registered for this index before, hand it to the data manager as the previous buffer.
    const previousEntry = mapping.get(index);
    const previousBuffer = previousEntry ? previousEntry[1] : undefined;
    const id = this.gpuDataManager.registerExternalBuffer(buffer, size, previousBuffer);
    mapping.set(index, [id, buffer]);
    return id;
  }
  /**
   * unregister all external GPU buffers that were registered for a session.
   *
   * @param sessionId the ID of the session.
   */
  unregisterBuffers(sessionId: number): void {
    const mapping = this.sessionExternalDataMapping.get(sessionId);
    if (!mapping) {
      return;
    }

    mapping.forEach(([, buffer]) => {
      this.gpuDataManager.unregisterExternalBuffer(buffer);
    });
    this.sessionExternalDataMapping.delete(sessionId);
  }
  /**
   * get the GPU buffer of a GPU data.
   *
   * @param gpuDataId the ID of the GPU data.
   * @returns the GPU buffer.
   * @throws Error if the GPU data manager has no GPU data with the given ID.
   */
  getBuffer(gpuDataId: number): GPUBuffer {
    const found = this.gpuDataManager.get(gpuDataId);
    if (found) {
      return found.buffer;
    }
    throw new Error(`no GPU data for buffer: ${gpuDataId}`);
  }
  /**
   * create a function that downloads the content of a GPU buffer.
   *
   * @param gpuBuffer the GPU buffer to download from.
   * @param size the size passed to downloadGpuData().
   * @param type the data type used to create the view on the downloaded data.
   * @returns an async function. each call downloads the data and resolves to a view of the given type on it.
   */
  createDownloader(gpuBuffer: GPUBuffer, size: number, type: Tensor.GpuBufferDataTypes):
      () => Promise<Tensor.DataType> {
    const downloader = async(): Promise<Tensor.DataType> => {
      const downloaded = await downloadGpuData(this, gpuBuffer, size);
      return createView(downloaded.buffer, type);
    };
    return downloader;
  }
  // #endregion
  /**
   * write a timestamp into the query set from the current compute pass. does nothing unless the query type is
   * 'inside-passes'.
   *
   * @param index the index in the query set to write the timestamp to.
   */
  writeTimestamp(index: number): void {
    if (this.queryType === 'inside-passes') {
      // writeTimestamp() is accessed via `any` because it is not part of the compute pass encoder's type.
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      (this.computePassEncoder as any).writeTimestamp(this.querySet, index);
    }
  }
  /**
   * determine the timestamp query type from the current environment flags and the device features.
   *
   * The query type stays 'none' unless profiling mode is 'default' or tracing is enabled. Otherwise it is
   * 'inside-passes' or 'at-passes', depending on which timestamp query feature the device has (if any). The query set
   * and its resolve buffer are created the first time a query type other than 'none' is selected.
   */
  setQueryType(): void {
    this.queryType = 'none';

    const timingRequested = this.env.webgpu.profiling?.mode === 'default' ||
        (typeof this.env.trace === 'undefined' ? this.env.wasm.trace : this.env.trace);
    if (!timingRequested) {
      return;
    }

    if (this.device.features.has('chromium-experimental-timestamp-query-inside-passes')) {
      this.queryType = 'inside-passes';
    } else if (this.device.features.has('timestamp-query')) {
      this.queryType = 'at-passes';
    }

    if (this.queryType !== 'none' && typeof this.querySet === 'undefined') {
      // 2 timestamps (begin and end) per dispatch, 8 bytes per timestamp.
      const queryCount = this.maxDispatchNumber * 2;
      this.querySet = this.device.createQuerySet({
        type: 'timestamp',
        count: queryCount,
      });
      this.queryResolveBuffer = this.device.createBuffer(
          // eslint-disable-next-line no-bitwise
          {size: queryCount * 8, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE});
    }
  }

  /**
   * begin capturing for the current session: make sure that the session has a command list and a pending kernel list,
   * submit the commands that are still pending, and switch the session status to 'capturing'.
   */
  captureBegin(): void {
    LOG_DEBUG('info', 'captureBegin');
    const sessionId = this.currentSessionId!;
    if (!this.capturedCommandList.get(sessionId)) {
      this.capturedCommandList.set(sessionId, []);
    }
    if (!this.capturedPendingKernels.get(sessionId)) {
      this.capturedPendingKernels.set(sessionId, []);
    }
    // commands that were encoded before capturing started must be submitted under the previous status.
    this.flush();
    this.sessionStatus = 'capturing';
  }
  /**
   * end capturing: submit the pending commands and switch the session status back to 'default'.
   */
  captureEnd(): void {
    LOG_DEBUG('info', 'captureEnd');
    // flush the left commands before we change the status.
    this.flush();
    this.sessionStatus = 'default';
  }
  /**
   * replay the commands that were captured for the current session.
   *
   * The session status is 'replaying' while the commands are encoded, and is set back to 'default' at the end.
   * NOTE(review): the non-null assertions below assume that commands were captured for the current session before -
   * confirm against the callers.
   */
  replay(): void {
    LOG_DEBUG('info', 'replay');
    this.sessionStatus = 'replaying';
    const sessionCommandList = this.capturedCommandList.get(this.currentSessionId!);
    const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!);
    const length = sessionCommandList!.length;
    // start with an empty batch; the kernel info of the replayed commands is added below.
    this.pendingKernels = [];
    for (let i = 0; i < length; i++) {
      const computePassEncoder = this.getComputePassEncoder();
      const command = sessionCommandList![i];
      // the timestamps are only written when the query type is 'inside-passes' (see writeTimestamp()).
      this.writeTimestamp(this.pendingDispatchNumber * 2);
      computePassEncoder.setPipeline(command.computePipeline);
      computePassEncoder.setBindGroup(0, command.bindGroup);
      computePassEncoder.dispatchWorkgroups(...command.dispatchGroup);
      this.writeTimestamp(this.pendingDispatchNumber * 2 + 1);
      this.pendingDispatchNumber++;
      if (this.queryType !== 'none') {
        // the i-th captured kernel info belongs to the i-th captured command.
        this.pendingKernels.push(sessionPendingKernels![i]);
      }
      // for 'at-passes', the timestamps are written at the pass boundaries, so each dispatch gets its own pass.
      if (this.pendingDispatchNumber >= this.maxDispatchNumber || this.queryType === 'at-passes') {
        this.endComputePass();
      }
      // submit the batch once it reaches the maximum number of dispatches.
      if (this.pendingDispatchNumber >= this.maxDispatchNumber) {
        this.flush();
      }
    }
    // flush the left commands before we change the status.
    this.flush();
    this.sessionStatus = 'default';
  }

  /**
   * clean up everything that is kept for a session: its external buffers, its captured commands and pending kernels,
   * and whatever the GPU data manager keeps for it.
   *
   * @param sessionId the ID of the session that is being released.
   */
  onReleaseSession(sessionId: number): void {
    this.unregisterBuffers(sessionId);
    // Map.delete() is a no-op for a key that does not exist, so no check is needed.
    this.capturedCommandList.delete(sessionId);
    this.capturedPendingKernels.delete(sessionId);
    this.gpuDataManager.onReleaseSession(sessionId);
  }

  /**
   * called when a session run starts: remember the session as the current one and re-evaluate the query type, so that
   * changes of the profiling/trace flags take effect for this run.
   *
   * @param sessionId the ID of the session that starts running.
   */
  onRunStart(sessionId: number): void {
    this.currentSessionId = sessionId;
    this.setQueryType();
  }
}
