import {OpNode,Model, tensorInfo} from './interface/interface'
import Kernel from './kernel'
import Tensor from './tensor'

/**
 * String-keyed dictionary used while wiring a graph.
 * Values are left untyped (`any`); this file stores one tensor object per
 * edge name in it.
 */
type tensorPool = Record<string, any>;

/**
 * Backend-independent base class of a computation graph.
 *
 * It owns the op nodes taken from the model, one kernel per op node (matched
 * by index), the graph level input/output tensors and an optional temp tensor
 * that is handed to every kernel reporting a non-zero temp buffer size.
 * Backends implement the abstract methods and reuse the shared helpers
 * `assignTensorToKernel()`, `assignMemoryToTensor()` and `releaseResources()`.
 */
export default abstract class Graph {
    /** Op nodes of the model, taken from `model.ops` (not copied). */
    protected totalOpNode_: OpNode[];
    /** Kernels of the graph; `totalKernel_[i]` belongs to `totalOpNode_[i]`. */
    protected totalKernel_: Kernel[];
    /** Input tensors of this graph (edges whose name contains 'ppljs_input'). */
    protected inputTensor_: Tensor[];
    /** Output tensors of this graph (edges whose name contains 'ppljs_output'). */
    protected outputTensor_: Tensor[];
    /** Temp tensor shared by the kernels; stays null while no kernel needs one. */
    protected tmpBufferTensor_: Tensor = null as unknown as Tensor;
    /**
     * Input data handed in through `setInputFromArrayBuffer()`.
     * NOTE(review): the original comment suggests this host data may be
     * released after it has been copied to the GPU - that is up to the backend.
     */
    protected inputData:ArrayBuffer[] = [];
    /** Whether key log messages should be printed. */
    protected ifDebug:boolean = false;
    /** Set to true by `setInputFromArrayBuffer()`; this class never resets it. */
    protected inputChanged_:boolean = false;

    /**
     * @param model model description; only `model.ops` is read here.
     */
    constructor(model:Model){
        this.totalOpNode_ = model.ops;
        this.totalKernel_ = [];
        this.inputTensor_ = [];
        this.outputTensor_ = [];
    }

    /**
     * Creates the kernels according to the op nodes and assigns the tensor
     * information to each kernel.
     * @returns status code (presumably 0 on success, like the shared helpers).
     */
    abstract buildGraph():number;

    /**
     * Allocates the memory of each kernel tensor and every other resource
     * needed before forwarding.
     * @returns status code (presumably 0 on success, like the shared helpers).
     */
    abstract prepareGraph():number;

    /**
     * Gives the input data to the network and forwards all kernels.
     * @returns status code (presumably 0 on success, like the shared helpers).
     */
    abstract runGraph():number;

    /** Backend specific completion hook; the meaning of the result is not visible here. */
    abstract finish():Promise<undefined[]>;

    /** Gets the output of the network. Each backend may have a different way to get it. */
    abstract getArrayBufferOutput():Promise<ArrayBuffer[]>;

    /** Configures whether key log messages are printed. */
    setIfDebug(ifDebug:boolean ){this.ifDebug = ifDebug;}
    /** @returns the flag set through `setIfDebug()`. */
    getIfDebug():boolean { return this.ifDebug; }

    /**
     * Creates a backend specific tensor (runtimeImp->graphImp->TensorImp->BufferImp).
     * Each backend has its own tensor implementation management.
     * @param info tensor description (name, shape, precision)
     * @param isInput true when the tensor is an input of the whole graph
     * @param isOutput true when the tensor is an output of the whole graph
     */
    abstract createTensor(info:tensorInfo,isInput:boolean,isOutput:boolean):any;

    /**
     * Shared wiring code extracted from the graph implementations. Call it in
     * `buildGraph()` after the kernels are created; op node `i` corresponds to
     * kernel `i`.
     *
     * For every input/output edge the edge shape is recorded on the kernel and
     * the tensor registered under the edge name is attached. The tensor is
     * created the first time a name is seen and reused afterwards.
     * @returns always 0
     */
    assignTensorToKernel():number{
        // Prototype-less dictionary: with a plain `{}` an edge named e.g.
        // "toString" or "constructor" would resolve to an inherited
        // Object.prototype member and be mistaken for an existing tensor.
        const createdTensors:tensorPool = Object.create(null);

        // Returns the tensor registered for `edgeName`, creating and
        // registering it on first use. An edge becomes a graph input/output
        // only when it is first seen on the matching side of an op node.
        const lookupOrCreate = (edgeName:string, info:tensorInfo, asInput:boolean):any => {
            const known = createdTensors[edgeName];
            if(known != undefined){
                return known;
            }
            const marker = asInput ? 'ppljs_input' : 'ppljs_output';
            const isBoundary = edgeName.includes(marker);
            const created = this.createTensor(info, isBoundary && asInput, isBoundary && !asInput);
            if(isBoundary){
                (asInput ? this.inputTensor_ : this.outputTensor_).push(created);
            }
            createdTensors[edgeName] = created;
            return created;
        };

        this.totalOpNode_.forEach((opNode, idx)=>{
            const kernel = this.totalKernel_[idx];
            for(const edge of opNode.inputs){
                // A tensor can be reused by several ops with different shapes,
                // so the shape of every edge is recorded separately.
                kernel.addInShape(edge.tensorInfo.shape);
                kernel.addInTensor(lookupOrCreate(edge.name, edge.tensorInfo, true));
            }
            for(const edge of opNode.outputs){
                kernel.addOutShape(edge.tensorInfo.shape);
                kernel.addOutTensor(lookupOrCreate(edge.name, edge.tensorInfo, false));
            }
        });
        return 0;
    }

    /**
     * Allocates the buffers of the tensors created by `assignTensorToKernel()`.
     * This code can be shared by the backends.
     *
     * Tensors whose `buffer` is already set are skipped. When at least one
     * kernel reports a non-zero temp buffer size, one temp tensor of shape
     * [1,1,1,maxTempSize] is created and handed to each of those kernels.
     * @returns always 0
     */
    assignMemoryToTensor():number{
        const ensureAllocated = (tensor:Tensor):void => {
            if(tensor.buffer == undefined){
                tensor.mallocTensorBuffer();
            }
        };
        this.inputTensor_.forEach((tensor)=>ensureAllocated(tensor));
        this.outputTensor_.forEach((tensor)=>ensureAllocated(tensor));

        // Allocate the kernel outputs and find the largest temp buffer request.
        let maxTempSize = 0;
        for(const kernel of this.totalKernel_){
            for(let n = 0; n < kernel.getOutTensorCount(); n++){
                ensureAllocated(kernel.getOutTensor(n));
            }
            const tempSize = kernel.tempBufferSize();
            if(tempSize > maxTempSize){
                maxTempSize = tempSize;
            }
        }

        if(maxTempSize != 0){
            // A repeated call replaces the temp tensor; release the previous
            // one first so its buffer is not leaked.
            if(this.tmpBufferTensor_ != null){
                this.tmpBufferTensor_.releaseTensorBuffer();
            }
            this.tmpBufferTensor_ = this.createTensor(
                {name:"tempBuffer",shape:[1,1,1,maxTempSize],precision:1},false,false);
            this.tmpBufferTensor_.mallocTensorBuffer();
            for(const kernel of this.totalKernel_){
                if(kernel.tempBufferSize() != 0){
                    kernel.setTmpTensor(this.tmpBufferTensor_);
                }
            }
        }
        return 0;
    }

    /**
     * Passes the input of the network after the runtime is created.
     * Stores the array and marks the input as changed.
     * @returns always 0
     */
    setInputFromArrayBuffer(inData:ArrayBuffer[]):number{
        this.inputData = inData;
        this.inputChanged_ = true;
        return 0;
    }

    getInputTensorCount():number { return this.inputTensor_.length; }
    getInputTensor(i:number):Tensor { return this.inputTensor_[i]; }
    getOutputTensorCount():number { return this.outputTensor_.length; }
    getOutputTensor(i:number):Tensor { return this.outputTensor_[i]; }

    getKernelCount():number { return this.totalKernel_.length; }
    getKernel(i:number):Kernel { return this.totalKernel_[i]; }

    /**
     * @param idx index into the graph input tensors
     * @returns the shape of that tensor, or null (after logging an error)
     *          when `idx` is negative, fractional or past the end.
     */
    getInputTensorShape(idx:number):any{
        if(!Number.isInteger(idx) || idx < 0 || idx >= this.getInputTensorCount()){
            console.error("idx exceed inputTensor size");
            return null;
        }
        return this.inputTensor_[idx].shape();
    }

    /**
     * @param idx index into the graph output tensors
     * @returns the shape of that tensor, or null (after logging an error)
     *          when `idx` is negative, fractional or past the end.
     */
    getOutputTensorShape(idx:number):any{
        if(!Number.isInteger(idx) || idx < 0 || idx >= this.getOutputTensorCount()){
            console.error("idx exceed outputTensor size");
            return null;
        }
        return this.outputTensor_[idx].shape();
    }

    /**
     * Releases the buffers of all tensors, the resources of all kernels and
     * finally the temp tensor.
     *
     * The same tensor is usually referenced by the graph input/output lists
     * and by several kernels; every distinct tensor is released exactly once.
     * @returns always 0
     */
    releaseResources():number{
        const released = new Set<Tensor>();
        const releaseOnce = (tensor:Tensor):void => {
            if(released.has(tensor)){
                return;
            }
            released.add(tensor);
            tensor.releaseTensorBuffer();
        };

        // release all the tensors
        for(const tensor of this.inputTensor_){
            releaseOnce(tensor);
        }
        for(const tensor of this.outputTensor_){
            releaseOnce(tensor);
        }
        for(const kernel of this.totalKernel_){
            for(let i = 0; i < kernel.getInTensorCount(); i++){
                releaseOnce(kernel.getInTensor(i));
            }
            for(let i = 0; i < kernel.getOutTensorCount(); i++){
                releaseOnce(kernel.getOutTensor(i));
            }
        }
        // release all the GPUBUFFER in the kernel
        for(const kernel of this.totalKernel_){
            kernel.releaseKernelResource();
        }
        // release the temp buffer and forget it, so that a later
        // assignMemoryToTensor() does not release it a second time
        if(this.tmpBufferTensor_ != null){
            releaseOnce(this.tmpBufferTensor_);
            this.tmpBufferTensor_ = null as unknown as Tensor;
        }
        return 0;
    }
}