// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { Attribute } from './attribute';
import * as ortFbs from './ort-schema/flatbuffers/ort-generated';
import { onnx } from './ort-schema/protobuf/onnx';
import { Tensor } from './tensor';
import { LongUtil, MAX_CLIP, MIN_CLIP, ProtoUtil } from './util';

/**
 * Types describing the in-memory graph that is built from an ONNX protobuf or ORT flatbuffer model.
 */
export declare namespace Graph {
  // shape of a tensor value
  export interface Shape {
    readonly dims: readonly number[];
  }
  // element type and shape of a tensor value
  export interface ValueType {
    readonly tensorType: Tensor.DataType;
    readonly shape: Shape;
  }
  // a value (edge) of the graph: a graph input, an initializer, a graph output or an intermediate tensor
  export interface Value {
    // the tensor data. set for initializers and for outputs of folded 'Constant' nodes; empty otherwise
    readonly tensor?: Tensor;

    // index to the Node where the value comes from. -1 for initializers and for outputs of folded 'Constant' nodes.
    // graph inputs have no producing node.
    readonly from: number;

    // indices to the Nodes where the values go to.
    readonly to: readonly number[];

    // value type specification. set for graph inputs and initializers (and, for ONNX-format models, graph outputs);
    // empty for intermediate values.
    readonly type?: ValueType;
  }
  // an operator node of the graph
  export interface Node {
    // name of the node
    readonly name: string;

    // the operator type
    readonly opType: string;

    // indices to the Values where the inputs come from.
    readonly inputs: readonly number[];

    // indices to the Values where the outputs go to.
    readonly outputs: readonly number[];

    // the attributes used by the operator
    readonly attributes: Attribute;
  }

  /**
   * a Transformer is an instance that exposes the transformation operations that can be applied to a graph
   */
  export interface Transformer {
    removeAllIdentityNodes(): void;
    removeAllDropoutNodes(): void;
    fuseConvActivationNodes(): void;
    // TODO: add generic functions to manipulate the graph
  }

  // an initializer can use a transformer to apply additional transformations to the graph
  export interface Initializer {
    transformGraph(transformer: Transformer): void;
  }
}

// eslint-disable-next-line @typescript-eslint/no-redeclare
export interface Graph {
  // read-only view of a built graph.
  // indices (into getValues()) of the graph inputs that are not backed by an initializer tensor
  getInputIndices(): readonly number[];
  // names of those inputs, in the same order as getInputIndices()
  getInputNames(): readonly string[];
  // indices (into getValues()) of the graph outputs
  getOutputIndices(): readonly number[];
  // names of the graph outputs, in the same order as getOutputIndices()
  getOutputNames(): readonly string[];
  // all values of the graph; Node.inputs and Node.outputs are indices into this array
  getValues(): readonly Graph.Value[];
  // all nodes of the graph
  getNodes(): readonly Graph.Node[];
}

// eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-redeclare
export const Graph = {
  /**
   * Builds a graph from an ONNX protobuf `GraphProto` or an ORT flatbuffer `Graph`.
   *
   * @param graphProto the serialized graph to build from
   * @param initializer optional hook that can apply extra transformations before the graph is finalized
   * @returns the built, transformed and validated graph
   */
  from: (graphProto: onnx.IGraphProto | ortFbs.Graph, initializer?: Graph.Initializer) => {
    const built = new GraphImpl(graphProto, initializer);
    return built;
  },
};

/**
 * Mutable implementation of {@link Graph.Value} used while the graph is built and transformed.
 */
class Value implements Graph.Value {
  // index of the producing node: -1 for initializers and folded constants, -2 once the producing node has been
  // removed during finalization; left undefined while no node produces this value
  _from?: number;
  // indices of the nodes consuming this value
  _to: number[];
  type?: Graph.ValueType;
  tensor?: Tensor;

  /**
   * @param valueInfo optional value info; when present, its tensor type determines `type`
   */
  constructor(valueInfo?: onnx.IValueInfoProto) {
    this._from = undefined;
    this._to = [];
    this.tensor = undefined;
    this.type = valueInfo ? ProtoUtil.tensorValueTypeFromProto(valueInfo.type!.tensorType!) : undefined;
  }

  get from(): number {
    return this._from!;
  }

  get to(): number[] {
    return this._to;
  }
}

/**
 * Mutable implementation of {@link Graph.Node}, created from either an ONNX `NodeProto` or an ORT flatbuffer `Node`.
 */
class Node implements Graph.Node {
  name: string;
  opType: string;
  inputs: number[];
  outputs: number[];
  attributes: Attribute;
  // cleared for nodes that are folded or removed; such nodes are dropped when the graph is finalized
  executeNode: boolean;

  /**
   * @param _nodeProto the serialized node
   * @param name optional name override; only used for ORT-format nodes
   */
  constructor(_nodeProto: onnx.INodeProto | ortFbs.Node, name?: string) {
    if (_nodeProto instanceof onnx.NodeProto) {
      const onnxNode = _nodeProto;
      this.name = onnxNode.name;
      this.opType = onnxNode.opType;
      this.attributes = new Attribute(onnxNode.attribute);
    } else if (_nodeProto instanceof ortFbs.Node) {
      const ortNode = _nodeProto;
      this.name = name ?? ortNode.name()!;
      this.opType = ortNode.opType()!;
      this.attributes = new Attribute(ProtoUtil.tensorAttributesFromORTFormat(ortNode));
    }

    // edges are wired up later by the graph builder
    this.inputs = [];
    this.outputs = [];
    this.executeNode = true;
  }
}

/**
 * Concrete implementation of {@link Graph} that also exposes the {@link Graph.Transformer} operations.
 *
 * The graph is held in two arrays:
 * - `_allData`: every value (graph inputs, initializers, graph outputs and intermediate values). Each `Value` records
 *   the index of the node producing it (`_from`) and the indices of the nodes consuming it (`_to`).
 * - `_nodes`: every operator node. `Node.inputs` / `Node.outputs` hold indices into `_allData`.
 *
 * The constructor builds both arrays from an ONNX protobuf graph or an ORT flatbuffer graph, runs the built-in
 * transformations plus any initializer-provided ones, compacts the arrays, and finally checks the graph for cycles.
 */
class GraphImpl implements Graph, Graph.Transformer {
  // all values of the graph; Node.inputs / Node.outputs and the index lists below point into this array
  private _allData: Value[];

  // graph inputs that are not backed by an initializer tensor
  private _allInputIndices: number[];
  private _allInputNames: string[];

  // graph outputs
  private _allOutputIndices: number[];
  private _allOutputNames: string[];

  // operator nodes; only nodes with executeNode === true survive finalizeGraph()
  private _nodes: Node[];

  /**
   * Builds, transforms and validates a graph.
   *
   * @param graph an ONNX `GraphProto` instance or an ORT flatbuffer `Graph`
   * @param graphInitializer optional hook whose `transformGraph` runs after the built-in transformations
   * @throws TypeError if `graph` is falsy or of an unsupported type
   * @throws Error if the graph is malformed (duplicate names, unrecognized inputs, cycles, ...)
   */
  constructor(graph: onnx.IGraphProto | ortFbs.Graph, graphInitializer?: Graph.Initializer) {
    if (!graph) {
      throw new TypeError('graph is empty');
    }

    // build the graph - will throw exceptions if something fatal is detected
    this.buildGraph(graph);

    // execute any transformation logic for the graph (if applicable)
    this.transformGraph(graphInitializer);

    // check for cycles and other inconsistencies - will throw exceptions if something fatal is detected
    this.checkIsAcyclic();
  }

  getInputIndices(): readonly number[] {
    return this._allInputIndices;
  }

  getInputNames(): readonly string[] {
    return this._allInputNames;
  }

  getOutputIndices(): readonly number[] {
    return this._allOutputIndices;
  }

  getOutputNames(): readonly string[] {
    return this._allOutputNames;
  }

  getValues(): readonly Graph.Value[] {
    return this._allData;
  }

  getNodes(): readonly Graph.Node[] {
    return this._nodes;
  }

  /**
   * Dispatches to the format-specific builder.
   * @throws TypeError if `graph` is neither an `onnx.GraphProto` nor an `ortFbs.Graph` instance
   */
  private buildGraph(graph: onnx.IGraphProto | ortFbs.Graph) {
    // build the graph - will throw exceptions if something fatal is detected
    if (graph instanceof onnx.GraphProto) {
      this.buildGraphFromOnnxFormat(graph);
    } else if (graph instanceof ortFbs.Graph) {
      this.buildGraphFromOrtFormat(graph);
    } else {
      throw new TypeError('Graph type is not supported.');
    }
  }

  /**
   * Records every value that has no tensor attached as a graph input.
   *
   * `inputValueNames[i]` is used as the name of value `i`. This relies on the builders pushing the declared graph
   * inputs to `_allData` first and in order, so the data index of a tensor-less value is also the position of its
   * name in `inputValueNames`. Values appended afterwards (initializers not listed as inputs) always carry a tensor
   * and are skipped. Must be called after initializers are attached and before outputs/intermediates are added.
   */
  private collectGraphInputs(inputValueNames: readonly string[]) {
    for (let i = 0; i < this._allData.length; i++) {
      if (!this._allData[i].tensor) {
        this._allInputIndices.push(i);
        this._allInputNames.push(inputValueNames[i]);
      }
    }
  }

  /**
   * Populates `_allData`, `_nodes` and the input/output index lists from an ONNX protobuf graph.
   *
   * Nodes without a name are assigned `unnamed_<opType>_<n>`; note that the name is written back to
   * `nodeProto.name`, i.e. the passed-in proto is mutated. `Constant` nodes are folded: their output value receives
   * the constant tensor and `_from = -1`, and the node is flagged as not executed.
   *
   * @throws Error on missing graph sections, duplicate input/output/node names, values produced by more than one
   *   node, malformed `Constant` nodes, or node inputs that reference unknown values
   */
  private buildGraphFromOnnxFormat(graph: onnx.IGraphProto) {
    const dataIndices = new Map<string, number>();
    this._allData = [];

    this._allInputIndices = [];
    this._allInputNames = [];

    this._allOutputIndices = [];
    this._allOutputNames = [];

    this._nodes = [];

    const nodesIndices = new Map<string, number>();

    // scan all inputs
    if (!graph.input) {
      throw new Error('missing information in graph: input');
    }
    const inputValueNames: string[] = [];
    for (const i of graph.input) {
      if (dataIndices.has(i.name!)) {
        throw new Error(`duplicated input name: ${i.name}`);
      }
      const currentIndex = this._allData.push(new Value(i)) - 1;
      dataIndices.set(i.name!, currentIndex);
      inputValueNames.push(i.name!);
    }

    // scan all initializers. An initializer sharing its name with a graph input attaches its tensor to that input's
    // value (which is then no longer reported as a graph input); otherwise a new value is appended.
    if (!graph.initializer) {
      throw new Error('missing information in graph: initializer');
    }
    for (const i of graph.initializer) {
      let index = dataIndices.get(i.name!);
      if (index === undefined) {
        const value = new Value();
        value.type = {
          shape: { dims: ProtoUtil.tensorDimsFromProto(i.dims!) },
          tensorType: ProtoUtil.tensorDataTypeFromProto(i.dataType!),
        };
        index = this._allData.push(value) - 1;
        dataIndices.set(i.name!, index);
      }
      this._allData[index]._from = -1;
      this._allData[index].tensor = Tensor.fromProto(i);
    }

    // filter out input indices
    this.collectGraphInputs(inputValueNames);

    // scan all outputs
    if (!graph.output) {
      throw new Error('missing information in graph: output');
    }
    for (const i of graph.output) {
      if (dataIndices.has(i.name!)) {
        throw new Error(`duplicated output name: ${i.name}`);
      }
      const currentIndex = this._allData.push(new Value(i)) - 1;
      dataIndices.set(i.name!, currentIndex);
      this._allOutputIndices.push(currentIndex);
      this._allOutputNames.push(i.name!);
    }

    // scan all nodes
    if (!graph.node) {
      throw new Error('missing information in graph: node');
    }
    for (const nodeProto of graph.node) {
      if (!nodeProto.name) {
        // assign a name to the node if it doesn't have one
        for (let pick = 0; ; pick++) {
          const name = `unnamed_${nodeProto.opType}_${pick}`;
          if (!nodesIndices.has(name)) {
            nodeProto.name = name;
            break;
          }
        }
      }

      if (nodesIndices.has(nodeProto.name)) {
        throw new Error(`duplicated node name: ${nodeProto.name}`);
      }
      const currentIndex = this._nodes.push(new Node(nodeProto)) - 1;
      nodesIndices.set(nodeProto.name, currentIndex);
    }

    // scan node's outputs
    for (let i = 0; i < this._nodes.length; i++) {
      const node = this._nodes[i];
      const nodeProto = graph.node[i];
      if (!nodeProto.output) {
        throw new Error(`missing output for node: ${nodeProto.name}`);
      }
      for (const output of nodeProto.output) {
        let dataIndex = dataIndices.get(output);
        if (typeof dataIndex === 'undefined') {
          dataIndex = this._allData.push(new Value()) - 1;
          dataIndices.set(output, dataIndex);
        }
        node.outputs.push(dataIndex);

        if (this._allData[dataIndex]._from !== undefined) {
          throw new Error(`multiple nodes output to one data value: ${dataIndex}`);
        }
        this._allData[dataIndex]._from = i;

        // for the 'Constant' operator, just create a new edge in the graph corresponding to the 'output' of the
        // operator and ignore the node from the graph
        if (nodeProto.opType === 'Constant') {
          if (!nodeProto.attribute || nodeProto.attribute.length !== 1 || !nodeProto.attribute[0].t) {
            throw new Error('missing attributes or missing tensor value in attributes for this Constant operator');
          }
          if (!nodeProto.output || nodeProto.output.length !== 1) {
            throw new Error('missing output or incorrect number of outputs for this Constant operator');
          }
          node.outputs.pop();
          node.executeNode = false;

          this._allData[dataIndex]._from = -1;
          this._allData[dataIndex].tensor = Tensor.fromProto(nodeProto.attribute[0].t);
        }
      }
    }

    // scan node's inputs
    for (let i = 0; i < this._nodes.length; i++) {
      const node = this._nodes[i];
      const nodeProto = graph.node[i];

      if (!nodeProto.input) {
        throw new Error(`missing input for node: ${nodeProto.name}`);
      }
      for (const input of nodeProto.input) {
        const dataIndex = dataIndices.get(input);
        if (typeof dataIndex === 'undefined') {
          // handle exception when opset > 9 and roi / scales not given
          if (
            input === '' &&
            (nodeProto.input.length === 3 || nodeProto.input.length === 4) &&
            nodeProto.opType === 'Resize'
          ) {
            continue;
          }
          throw new Error(`unrecognized input '${input}' for node: ${nodeProto.name}`);
        }
        node.inputs.push(dataIndex);

        this._allData[dataIndex]._to.push(i);
      }
    }

    return true;
  }

  /**
   * Populates `_allData`, `_nodes` and the input/output index lists from an ORT flatbuffer graph.
   *
   * Graph inputs take their type from the matching entry in `graph.nodeArgs`; an input without a matching node arg
   * is not added. Unnamed nodes get a generated `unnamed_<opType>_<n>` name (the flatbuffer is not modified).
   * `Constant` nodes are folded as in the ONNX path.
   *
   * @throws Error on duplicate input/output/node names, non-tensor input types, values produced by more than one
   *   node, malformed `Constant` nodes, nodes without inputs/outputs, or inputs that reference unknown values
   */
  private buildGraphFromOrtFormat(graph: ortFbs.Graph) {
    const dataIndices = new Map<string, number>();
    this._allData = [];

    this._allInputIndices = [];
    this._allInputNames = [];

    this._allOutputIndices = [];
    this._allOutputNames = [];

    this._nodes = [];

    const nodesIndices = new Map<string, number>();

    // scan all inputs
    const inputValueNames: string[] = [];
    for (let i = 0; i < graph.inputsLength(); i++) {
      const inputName = graph.inputs(i);
      if (dataIndices.has(inputName)) {
        throw new Error(`duplicated input name: ${inputName}`);
      }
      // Find the input typeInfo from nodeargs
      for (let j = 0; j < graph.nodeArgsLength(); j++) {
        if (graph.nodeArgs(j)?.name() === inputName) {
          const value = new Value();
          const valueType = graph.nodeArgs(j)?.type()?.valueType();
          if (valueType !== ortFbs.TypeInfoValue.tensor_type) {
            throw new Error('Unexpected value type for the nodeArg.');
          }
          const valueInfo = graph.nodeArgs(j)!.type()!.value(new ortFbs.TensorTypeAndShape())!;
          const type = ProtoUtil.tensorDataTypeFromProto(valueInfo.elemType());
          const shape = valueInfo.shape()!;
          const dims = [];
          for (let k = 0; k < shape.dimLength()!; k++) {
            dims.push(LongUtil.longToNumber(shape.dim(k)!.value()!.dimValue()!));
          }
          value.type = { shape: { dims }, tensorType: type };
          const currentIndex = this._allData.push(value) - 1;
          dataIndices.set(inputName, currentIndex);
          inputValueNames.push(inputName);
        }
      }
    }
    // check initializers
    for (let i = 0; i < graph.initializersLength(); i++) {
      const initializer = graph.initializers(i)!;
      let index = dataIndices.get(initializer.name()!);
      if (index === undefined) {
        const value = new Value();
        const dims = ProtoUtil.tensorDimsFromORTFormat(initializer);
        const type = ProtoUtil.tensorDataTypeFromProto(initializer.dataType());
        value.type = { shape: { dims }, tensorType: type };
        index = this._allData.push(value) - 1;
        dataIndices.set(initializer.name()!, index);
      }
      this._allData[index]._from = -1;
      this._allData[index].tensor = Tensor.fromOrtTensor(initializer);
    }

    // filter out input indices
    this.collectGraphInputs(inputValueNames);

    // scan all outputs
    for (let i = 0; i < graph.outputsLength(); i++) {
      const outputName = graph.outputs(i);
      if (dataIndices.has(outputName)) {
        throw new Error(`duplicated output name: ${outputName}`);
      }
      const currentIndex = this._allData.push(new Value()) - 1;
      dataIndices.set(outputName, currentIndex);
      this._allOutputIndices.push(currentIndex);
      this._allOutputNames.push(outputName);
    }

    // scan all nodes
    for (let i = 0; i < graph.nodesLength(); i++) {
      const nodeProto = graph.nodes(i);
      let name = nodeProto!.name();
      if (!name) {
        // assign a name to the node if it doesn't have one
        for (let pick = 0; ; pick++) {
          name = `unnamed_${nodeProto!.opType()}_${pick}`;
          if (!nodesIndices.has(name)) {
            // an unique name is found. break.
            break;
          }
        }
      }

      if (nodesIndices.has(name)) {
        throw new Error(`duplicated node name: ${name}`);
      }
      const currentIndex = this._nodes.push(new Node(nodeProto!, name)) - 1;
      nodesIndices.set(name, currentIndex);
    }

    // scan node's outputs
    for (let i = 0; i < this._nodes.length; i++) {
      const node = this._nodes[i];
      const nodeProto = graph.nodes(i);
      if (nodeProto == null) {
        throw new Error(`No node exists at index ${i}`);
      }
      if (nodeProto.outputsLength() === 0) {
        // name is an accessor method on flatbuffer nodes; it must be called to get the string
        throw new Error(`missing output for node: ${nodeProto.name()}`);
      }
      for (let j = 0; j < nodeProto.outputsLength(); j++) {
        const output = nodeProto.outputs(j);
        let dataIndex = dataIndices.get(output);
        if (typeof dataIndex === 'undefined') {
          dataIndex = this._allData.push(new Value()) - 1;
          dataIndices.set(output, dataIndex);
        }
        node.outputs.push(dataIndex);

        if (this._allData[dataIndex]._from !== undefined) {
          throw new Error(`multiple nodes output to one data value: ${dataIndex}`);
        }
        this._allData[dataIndex]._from = i;

        // for the 'Constant' operator, just create a new edge in the graph corresponding to the 'output' of the
        // operator and ignore the node from the graph
        if (nodeProto.opType() === 'Constant') {
          if (nodeProto.attributesLength() !== 1 || !nodeProto.attributes(0)!.t()) {
            throw new Error('missing attributes or missing tensor value in attributes for this Constant operator');
          }
          if (nodeProto.outputsLength() !== 1) {
            throw new Error('missing output or incorrect number of outputs for this Constant operator');
          }
          node.outputs.pop();
          node.executeNode = false;

          this._allData[dataIndex]._from = -1;
          this._allData[dataIndex].tensor = Tensor.fromOrtTensor(nodeProto.attributes(0)!.t()!);
        }
      }
    }

    // scan node's inputs
    for (let i = 0; i < this._nodes.length; i++) {
      const node = this._nodes[i];
      const nodeProto = graph.nodes(i)!;

      if (nodeProto.inputsLength() === 0) {
        throw new Error(`missing input for node: ${nodeProto.name()}`);
      }
      for (let j = 0; j < nodeProto.inputsLength(); j++) {
        const input = nodeProto.inputs(j)!;
        const dataIndex = dataIndices.get(input);
        if (typeof dataIndex === 'undefined') {
          throw new Error(`unrecognized input '${input}' for node: ${nodeProto.name()}`);
        }
        node.inputs.push(dataIndex);

        this._allData[dataIndex]._to.push(i);
      }
    }
  }

  /**
   * Walks the graph depth-first (iteratively), starting from the nodes that consume a non-initializer graph input.
   *
   * NOTE(review): nodes that are not reachable from such an input (e.g. fed only by initializers) are not visited,
   * so cycles among them are not detected.
   *
   * @throws Error if a back edge (cycle) is found, if a node output already carries a tensor, or if a value's `from`
   *   does not match the node listing it as an output
   */
  private checkIsAcyclic() {
    // go through the graph and check for cycles or other fatal inconsistencies
    const starters: Set<number> = new Set<number>();
    this._allInputIndices.forEach((i) => {
      const data = this._allData[i];
      data._to.forEach((j) => {
        starters.add(j);
      });
    });

    // Iterative DFS to check for cycles
    const nodesStack = Array.from(starters);
    const nodesState = new Array<string>(this._nodes.length).fill('white');

    while (nodesStack.length > 0) {
      const nodeIndex = nodesStack.pop()!;
      const state = nodesState[nodeIndex];
      if (state === 'black') {
        // a duplicate stack entry for a node already processed via another path; its edges were checked then
        continue;
      }
      if (state === 'gray') {
        // this node has now been processed completely. Mark this node 'black' to denote this.
        nodesState[nodeIndex] = 'black';
        continue;
      }

      // this node is under processing stage. mark this node 'gray' to denote this, and keep it on the stack so it is
      // popped again (and turned 'black') after all of its descendants.
      nodesStack.push(nodeIndex);
      nodesState[nodeIndex] = 'gray';

      this._nodes[nodeIndex].outputs.forEach((outgoingEdgeIndex) => {
        const data = this._allData[outgoingEdgeIndex];
        if (typeof data.tensor !== 'undefined') {
          throw new Error('node outputs should not be initialized');
        }
        if (data._from !== nodeIndex) {
          throw new Error("from property of the Value object doesn't match index of Node being processed");
        }
        data._to.forEach((downstreamNodeIndex) => {
          // back edge found - cyclic
          if (nodesState[downstreamNodeIndex] === 'gray') {
            throw new Error('model graph is cyclic');
          }
          // tree edge found - continue processing by adding it to stack
          else if (nodesState[downstreamNodeIndex] === 'white') {
            nodesStack.push(downstreamNodeIndex);
          }
        });
      });
    }
  }

  /**
   * Applies the built-in transformations, then the initializer-specific ones, and finally compacts the graph.
   */
  private transformGraph(graphInitializer?: Graph.Initializer): void {
    // apply common transform
    this.removeAllIdentityNodes();
    this.removeAllDropoutNodes();
    this.fuseConvActivationNodes();
    // apply initializer specific transform
    if (graphInitializer) {
      graphInitializer.transformGraph(this);
    }

    // finalize graph
    this.finalizeGraph();
  }

  /**
   * finalize the graph.
   *
   * this function should be called after all the transformation completed.
   * this function removes all unnecessary nodes and values from the graph
   *
   * It runs in three phases:
   * 1. compact `this._nodes`, dropping nodes whose `executeNode` flag is false. `newIndices` maps each old node index
   *    to its new position; outputs of dropped nodes are tagged with `_from = -2`.
   * 2. remap every value's `_from` / `_to` to the new node positions.
   * 3. delete values tagged `-2` that are not graph outputs, and re-point every remaining reference (node inputs,
   *    node outputs, graph input and output index lists) to the shifted value positions.
   */
  finalizeGraph() {
    // phase 1: compact the node array
    const newIndices = new Array<number>(this._nodes.length).fill(0);
    let nodePosition = 0;

    for (let i = 0; i < this._nodes.length; i++) {
      // giving new indexes to the nodes based on execution flag
      newIndices[i] = nodePosition;
      if (this._nodes[i].executeNode) {
        if (nodePosition !== i) {
          this._nodes[nodePosition] = this._nodes[i];
        }
        nodePosition++;
      } else {
        // mark all output values of the removed node for deletion
        this._nodes[i].outputs.forEach((ind) => {
          this._allData[ind]._from = -2;
        });
      }
    }

    // removing the unused nodes
    this._nodes.splice(nodePosition, this._nodes.length - nodePosition);

    // phase 2: update this._allData according to the new this._nodes
    for (let i = 0; i < this._allData.length; i++) {
      const currentData = this._allData[i];
      if (currentData._from !== undefined && currentData._from !== -1 && currentData._from !== -2) {
        currentData._from = newIndices[currentData._from];
      }

      for (let j = 0; j < currentData._to.length; j++) {
        if (currentData._to[j] >= 0) {
          currentData._to[j] = newIndices[currentData._to[j]];
        } else {
          throw new Error('Trying to update a removed node');
        }
      }
    }

    // phase 3: delete all values that are not being referenced
    let offset = 0;
    for (let i = 0; i < this._allData.length; i++) {
      // index this value had before any deletion in this phase; all references still use it
      const oldIndex = i + offset;

      // if current value is neither produced by a kept node, nor an output value, remove it.
      if (this._allData[i].from === -2 && this._allOutputIndices.indexOf(oldIndex) === -1) {
        offset++;
        this._allData.splice(i, 1);
        i--;
        continue;
      }
      if (offset === 0) {
        // nothing was removed before this value, so its index is unchanged
        continue;
      }

      const value = this._allData[i];
      if (value.from !== undefined && value.from >= 0) {
        // the value is produced by a kept node: update that node's output reference
        const producer = this._nodes[value.from];
        const ind = producer.outputs.indexOf(oldIndex);
        if (ind !== -1) {
          producer.outputs[ind] = i;
        }
      } else {
        // graph input, initializer/constant, or a kept output of a removed node: update the input index list
        const ind = this._allInputIndices.indexOf(oldIndex);
        if (ind !== -1) {
          this._allInputIndices[ind] = i;
        }
      }

      // find the nodes that the current value is linking to and update their input reference. A node consuming the
      // value twice appears twice in `to`, so each call updates the next remaining occurrence.
      value.to.forEach((node) => {
        const ind = this._nodes[node].inputs.indexOf(oldIndex);
        if (ind !== -1) {
          this._nodes[node].inputs[ind] = i;
        }
      });

      // a graph output may also be consumed by other nodes, so the output list is always checked
      for (let k = 0; k < this._allOutputIndices.length; k++) {
        if (this._allOutputIndices[k] === oldIndex) {
          this._allOutputIndices[k] = i;
        }
      }
    }
  }

  /**
   * Delete the specified node. Assume the node has one incoming input and the first output connected to other nodes.
   * An input validation must be done before calling this function.
   *
   * The node is flagged as not executed, detached from the `to` list of every value it consumes, and every consumer
   * of its first output is rewired to read its first input instead. If the first output is a graph output, the
   * graph output index is redirected to the first input as well.
   *
   * @param nodeIndex The index of node to be deleted
   * @throws Error if any output other than the first one has consumers, or if the graph links are inconsistent
   */
  private deleteNode(nodeIndex: number) {
    const node = this._nodes[nodeIndex];
    for (let i = 1; i < node.outputs.length; i++) {
      if (this._allData[node.outputs[i]].to.length > 0) {
        throw new Error('Node deletion with more than one output connected to other nodes is not supported. ');
      }
    }

    // this node will not be executed
    node.executeNode = false;
    const inputValueIndex = node.inputs[0];
    const outputValueIndex = node.outputs[0];
    const nodesConsumingOutput = this._allData[outputValueIndex].to;

    // remove this node from the to property of every input Value
    for (const consumedValueIndex of node.inputs) {
      const consumers = this._allData[consumedValueIndex].to;
      const delIndex = consumers.indexOf(nodeIndex);
      // should not happen
      if (delIndex === -1) {
        throw new Error("The Value object doesn't have the current Node in it's 'to' property ");
      }
      consumers.splice(delIndex, 1);
    }

    // clear node indices consuming this output Value (nodesConsumingOutput keeps the old array)
    this._allData[outputValueIndex]._to = [];

    // if the output of this node is a graph output, adjust the index appropriately
    const index = this._allOutputIndices.indexOf(outputValueIndex);
    if (index !== -1) {
      this._allOutputIndices[index] = inputValueIndex;
    }

    // override the inputs for nodes consuming this node's output with the input to this node
    for (const consumerIndex of nodesConsumingOutput) {
      const replaceIndex = this._nodes[consumerIndex].inputs.indexOf(outputValueIndex);
      // should not happen
      if (replaceIndex === -1) {
        throw new Error("The Node object doesn't have the output Value in it's 'inputs' property ");
      }
      this._nodes[consumerIndex].inputs[replaceIndex] = inputValueIndex;
      this._allData[inputValueIndex].to.push(consumerIndex);
    }
  }

  /**
   * Removes every 'Dropout' node, wiring its input directly to the consumers of its first output.
   * @throws Error if a Dropout node does not have exactly one input, has other than 1 or 2 outputs, or its second
   *   output is consumed by another node
   */
  removeAllDropoutNodes() {
    for (let nodeIndex = 0; nodeIndex < this._nodes.length; nodeIndex++) {
      const node = this._nodes[nodeIndex];
      // weed out 'Dropout' nodes so that no time is wasted in execution
      if (node.opType !== 'Dropout') {
        continue;
      }
      // the node should have exactly 1 input and 1 or 2 outputs
      if (node.inputs.length !== 1) {
        throw new Error('Dropout nodes should only contain one input. ');
      }
      if (node.outputs.length !== 1 && node.outputs.length !== 2) {
        throw new Error('Dropout nodes should contain either 1 or 2 output(s)');
      }
      // the second output should not be referenced by any other node
      if (node.outputs.length === 2 && this._allData[node.outputs[1]]._to.length !== 0) {
        throw new Error("Dropout nodes's second output should not be referenced by other nodes");
      }
      this.deleteNode(nodeIndex);
    }
  }

  /**
   * Removes every 'Identity' node, wiring its input directly to the consumers of its output.
   */
  removeAllIdentityNodes() {
    for (let nodeIndex = 0; nodeIndex < this._nodes.length; nodeIndex++) {
      // weed out 'Identity' nodes so that no time is wasted in execution
      if (this._nodes[nodeIndex].opType === 'Identity') {
        this.deleteNode(nodeIndex);
      }
    }
  }

  /**
   * @returns true if the node's operator can be fused into a preceding Conv as its activation
   */
  isActivation(n: Node): boolean {
    switch (n.opType) {
      // TODO: add other activation methods
      case 'Relu':
      case 'Sigmoid':
      case 'Clip':
        return true;
      default:
        return false;
    }
  }

  /**
   * Fuses a Conv node with the activation node that is the sole consumer of its output.
   *
   * The Conv node gets an 'activation' attribute (the activation's op type) and, for Clip, 'activation_params'
   * ([min, max]); the activation node is then deleted. Fusion is skipped when the Conv output is itself a graph
   * output, because the fused node would overwrite that pre-activation result. It is also skipped for a Clip whose
   * bounds are inputs not backed by initializer tensors.
   */
  fuseConvActivationNodes() {
    for (const node of this._nodes) {
      if (node.opType !== 'Conv') {
        continue;
      }
      const convOutput = node.outputs[0];
      const next = this._allData[convOutput]._to;
      if (next.length !== 1 || !this.isActivation(this._nodes[next[0]])) {
        continue;
      }
      if (this._allOutputIndices.indexOf(convOutput) !== -1) {
        continue;
      }

      const child = this._nodes[next[0]];
      if (child.opType === 'Clip') {
        if (child.inputs.length === 1) {
          // bounds are optional attributes here; each missing one falls back to its own default independently
          const readBound = (key: string, fallback: number): number => {
            try {
              return child.attributes.getFloat(key);
            } catch {
              return fallback;
            }
          };
          node.attributes.set('activation_params', 'floats', [readBound('min', MIN_CLIP), readBound('max', MAX_CLIP)]);
        } else if (
          child.inputs.length >= 3 &&
          this._allData[child.inputs[1]].tensor !== undefined &&
          this._allData[child.inputs[2]].tensor !== undefined
        ) {
          node.attributes.set('activation_params', 'floats', [
            this._allData[child.inputs[1]].tensor!.floatData[0],
            this._allData[child.inputs[2]].tensor!.floatData[0],
          ]);
        } else {
          // Skip fusion with clip node since clip min and clip max are not coming from initializer
          continue;
        }
      }
      node.attributes.set('activation', 'string', child.opType);
      this.deleteNode(next[0]);
    }
  }
}
