import { FileJp, FunctionJp, Loop, Varref } from "../../Joinpoints.js";
import ClavaCode from "../ClavaCode.js";
import ClavaJoinPoints from "../ClavaJoinPoints.js";
import MpiAccessPattern from "./MpiAccessPattern.js";
import MpiUtils from "./MpiUtils.js";
import MpiAccessPatterns from "./patterns/MpiAccessPatterns.js";

/**
 * Applies an MPI scatter-gather strategy to loops.
 */
/**
 * Applies an MPI scatter-gather strategy to loops.
 *
 * Usage: register the variables the loop reads with `addInput` and the
 * variables it writes with `addOutput`, then call `execute`.
 *
 * `execute()` rewrites the program so that:
 * - `main()` initializes MPI; every rank other than 0 runs a generated worker
 *   function, finalizes MPI and returns;
 * - the loop is replaced by a master routine that splits the iterations among
 *   the workers, sends the inputs and receives the outputs;
 * - the worker function runs a copy of the loop over its share of iterations.
 */
export default class MpiScatterGatherLoop {
  private readonly $loop: Loop;
  /** First reference (inside the loop) of each input variable; parallel to `inputAccesses`. */
  private readonly inputJps: Varref[] = [];
  private readonly inputAccesses: MpiAccessPattern[] = [];
  /** First reference (inside the loop) of each output variable; parallel to `outputAccesses`. */
  private readonly outputJps: Varref[] = [];
  private readonly outputAccesses: MpiAccessPattern[] = [];

  /**
   * @param $loop - Loop to distribute among MPI processes.
   * @throws string if the expression with the number of iterations of the loop
   *   cannot be determined.
   */
  constructor($loop: Loop) {
    this.$loop = $loop;

    // Check if the loop can be parallelized: the master routine needs the
    // number of iterations in order to split the work among the workers
    if (this.$loop.iterationsExpr === undefined) {
      throw "Could not determine expression with number of iterations of the loop. Check if the loop is in the Canonical Loop Form, according to the OpenMP standard.";
    }
  }

  /**
   * Registers a variable the loop reads. The master sends it to the workers,
   * which receive it before running their part of the loop.
   *
   * @param varName - Name of a variable referenced inside the loop.
   * @param accessPattern - How the variable is accessed. Defaults to
   *   `MpiAccessPatterns.SCALAR_PATTERN`.
   * @throws string if the loop contains no reference to `varName`.
   */
  addInput(varName: string, accessPattern?: MpiAccessPattern): void {
    this.addVariable(varName, accessPattern, this.inputJps, this.inputAccesses);
  }

  /**
   * Registers a variable the loop writes. Each worker declares it, and sends
   * it back after its part of the loop; the master receives it.
   *
   * @param varName - Name of a variable referenced inside the loop.
   * @param accessPattern - How the variable is accessed. Defaults to
   *   `MpiAccessPatterns.SCALAR_PATTERN`.
   * @throws string if the loop contains no reference to `varName`.
   */
  addOutput(varName: string, accessPattern?: MpiAccessPattern): void {
    this.addVariable(
      varName,
      accessPattern,
      this.outputJps,
      this.outputAccesses
    );
  }

  /**
   * Adapts code to use the MPI strategy.
   *
   * Adds the MPI includes, the MPI global variables and the worker function to
   * the file of `main()`. It then replaces the loop with the master routine and
   * inserts the MPI initialization at the beginning of `main()`.
   *
   * @throws string if the file of `main()` cannot be found, or if `main()`
   *   does not have exactly two parameters (after adding `argc`/`argv` when it
   *   has none).
   */
  execute(): void {
    const $mainFunction = ClavaCode.getFunctionDefinition(
      "main",
      true
    ) as FunctionJp;
    const $mainFile = $mainFunction.getAncestor("file") as FileJp | undefined;

    if ($mainFile == null) {
      throw "Could not find file of main function";
    }

    // Add include
    $mainFile.addInclude("mpi.h");
    $mainFile.addInclude("iostream", true);

    // Add global variables
    const $intType = ClavaJoinPoints.builtinType("int");
    $mainFile.addGlobal(MpiUtils.VAR_NUM_TASKS, $intType, "0");
    $mainFile.addGlobal(MpiUtils.VAR_NUM_WORKERS, $intType, "0");
    const $rankDecl = $mainFile.addGlobal(MpiUtils.VAR_RANK, $intType, "0");

    // Create the worker declaration. It is built from a copy of the loop, so
    // it must be generated before the loop is replaced below.
    const mpiWorkerFunction = ClavaJoinPoints.declLiteral(
      this.buildMpiWorker()
    );

    // Add MPI Worker
    $rankDecl.insertAfter(mpiWorkerFunction);

    // Replace loop with MPI Master routine
    this.replaceLoop();

    // Add MPI initialization
    this.addMpiInit($mainFunction);
  }

  /** PRIVATE SECTION **/

  private static readonly FUNCTION_MPI_WORKER = "mpi_worker";
  private static readonly VAR_WORKER_NUM_ELEMS = "mpi_loop_num_elems";
  private static readonly VAR_MASTER_TOTAL_ITER = "clava_mpi_total_iter";

  /**
   * Concatenates the code produced by `generator` for each registered variable.
   * Each snippet is followed by a newline.
   *
   * @param $jps - Variable references; parallel to `accesses`.
   * @param accesses - Access pattern of each variable in `$jps`.
   * @param generator - Produces the code for one variable.
   */
  private static generateForEach(
    $jps: Varref[],
    accesses: MpiAccessPattern[],
    generator: ($jp: Varref, access: MpiAccessPattern) => string
  ): string {
    let code = "";

    for (let i = 0; i < $jps.length; i++) {
      code += generator($jps[i], accesses[i]);
      code += "\n";
    }

    return code;
  }

  /**
   * Replaces the loop with the master routine. The routine splits the
   * iterations among the workers, sends the inputs and receives the outputs.
   */
  private replaceLoop(): void {
    const totalIter = MpiScatterGatherLoop.VAR_MASTER_TOTAL_ITER;

    const masterSend = MpiScatterGatherLoop.generateForEach(
      this.inputJps,
      this.inputAccesses,
      ($jp, access) => access.sendMaster($jp, totalIter)
    );

    const masterReceive = MpiScatterGatherLoop.generateForEach(
      this.outputJps,
      this.outputAccesses,
      ($jp, access) => access.receiveMaster($jp, totalIter)
    );

    this.$loop.replaceWith(
      MpiScatterGatherLoop.MpiMaster(
        MpiUtils.VAR_NUM_WORKERS,
        this.$loop.iterationsExpr.code,
        masterSend,
        masterReceive,
        MpiUtils._VAR_MPI_STATUS,
        // Same name the access patterns were given above, so the variable
        // declared by the master routine matches the generated send/receive code
        totalIter
      )
    );
  }

  /**
   * Builds the source code of the worker function. The worker receives its
   * number of iterations and the inputs, declares the outputs, runs the
   * adjusted copy of the loop and sends the outputs back.
   */
  private buildMpiWorker(): string {
    const workerLoopCode = this.getWorkerLoopCode();
    const numElems = MpiScatterGatherLoop.VAR_WORKER_NUM_ELEMS;

    const workerReceive = MpiScatterGatherLoop.generateForEach(
      this.inputJps,
      this.inputAccesses,
      ($jp, access) => access.receiveWorker($jp, numElems)
    );

    const outputDecl = MpiScatterGatherLoop.generateForEach(
      this.outputJps,
      this.outputAccesses,
      ($jp, access) => access.outputDeclWorker($jp, numElems)
    );

    const workerSend = MpiScatterGatherLoop.generateForEach(
      this.outputJps,
      this.outputAccesses,
      ($jp, access) => access.sendWorker($jp, numElems)
    );

    return MpiScatterGatherLoop.MpiWorker(
      MpiScatterGatherLoop.FUNCTION_MPI_WORKER,
      MpiUtils._VAR_MPI_STATUS,
      numElems,
      workerReceive,
      outputDecl,
      workerLoopCode,
      workerSend
    );
  }

  /**
   * Returns the code of a copy of the loop whose start value is 0 and whose end
   * value is the number of elements the worker received.
   */
  private getWorkerLoopCode(): string {
    // Copy loop
    const $workerLoop = this.$loop.copy() as Loop;

    // Adjust start and end of loop
    $workerLoop.initValue = "0";
    $workerLoop.endValue = MpiScatterGatherLoop.VAR_WORKER_NUM_ELEMS;

    // TODO: Adapt loop body, if needed

    return $workerLoop.code;
  }

  /**
   * Inserts the MPI initialization code at the beginning of `main()`. If
   * `main()` has no parameters, `int argc, char** argv` are added first.
   *
   * @throws string if `main()` does not have exactly two parameters.
   */
  private addMpiInit($mainFunction: FunctionJp): void {
    // Add params to main, if no params
    if ($mainFunction.params.length === 0) {
      $mainFunction.setParamsFromStrings(["int argc", "char** argv"]);
    }

    const numMainParams = $mainFunction.params.length;
    if (numMainParams !== 2) {
      throw `Expected main() function to have 2 parameters, has '${numMainParams}'`;
    }

    const argc = $mainFunction.params[0].name;
    const argv = $mainFunction.params[1].name;

    $mainFunction.body.insertBegin(
      MpiScatterGatherLoop.MpiInit(
        argc,
        argv,
        MpiUtils.VAR_RANK,
        MpiUtils.VAR_NUM_TASKS,
        MpiUtils.VAR_NUM_WORKERS,
        MpiScatterGatherLoop.FUNCTION_MPI_WORKER
      )
    );
  }

  /**
   * Stores the first reference to `varName` found inside the loop, together
   * with its access pattern, in the given parallel arrays.
   *
   * @param accessPattern - When undefined, `MpiAccessPatterns.SCALAR_PATTERN`
   *   is used.
   * @throws string if the loop contains no reference to `varName`.
   */
  private addVariable(
    varName: string,
    accessPattern: MpiAccessPattern | undefined,
    namesArray: Varref[],
    accessesArray: MpiAccessPattern[]
  ): void {
    // Check if loop contains a reference to the variable
    let $firstVarref: Varref | undefined = undefined;
    for (const $v of this.$loop.getDescendants("varref")) {
      const $varref = $v as Varref;
      if ($varref.name === varName) {
        $firstVarref = $varref;
        break;
      }
    }

    if ($firstVarref === undefined) {
      throw `Could not find a reference to the variable '${varName}' in the loop located at ${this.$loop.location}`;
    }

    namesArray.push($firstVarref);
    accessesArray.push(accessPattern ?? MpiAccessPatterns.SCALAR_PATTERN);
  }

  /** CODEDEFS **/

  /**
   * MPI initialization inserted at the beginning of `main()`. With a single
   * process it prints an error, finalizes MPI and returns 1. Ranks > 0 run the
   * worker function, finalize MPI and return 0. Rank 0 continues with `main()`.
   */
  // TODO: std::cerr should not be hardcoded, lara.code.Logger should be used instead
  private static MpiInit(
    argc: string,
    argv: string,
    rank: string,
    numTasks: string,
    numWorkers: string,
    mpiWorker: string
  ): string {
    // MPI requires every process to call MPI_Finalize before exiting, including
    // on the early-error path
    return `
    MPI_Init(&${argc}, &${argv});
    MPI_Comm_rank(MPI_COMM_WORLD, &${rank});
    MPI_Comm_size(MPI_COMM_WORLD, &${numTasks});
    ${numWorkers} = ${numTasks} - 1;

    if(${numWorkers} == 0) {
        std::cerr << "This program does not support working with a single process." << std::endl;
        MPI_Finalize();
        return 1;
    }

	if(${rank} > 0) {
		${mpiWorker}();
		MPI_Finalize();
		return 0;
	}
`;
  }

  /**
   * Worker function definition. The worker first receives its number of
   * iterations from rank 0 (tag 1), then runs the given receive, declaration,
   * loop and send code.
   */
  private static MpiWorker(
    functionName: string,
    status: string,
    numElems: string,
    receiveData: string,
    outputDecl: string,
    loop: string,
    sendData: string
  ): string {
    return `
void ${functionName}() {
    MPI_Status ${status};

	// Number of loop iterations
    int ${numElems};

    MPI_Recv(&${numElems}, 1, MPI_INT, 0, 1, MPI_COMM_WORLD, &${status});
	
	${receiveData}
	
	${outputDecl}
	
	${loop}
	
	${sendData}	
}
`;
  }

  /**
   * Master routine that replaces the loop. Each worker gets
   * `total / numWorkers` iterations; the last worker (rank `numWorkers`) also
   * gets the remainder `total % numWorkers`. The counts are sent with tag 1.
   *
   * @param totalIter - Name of the variable holding the total number of
   *   iterations. It must match the name given to the access patterns that
   *   produced `masterSend`/`masterReceive`.
   */
  private static MpiMaster(
    numWorkers: string,
    numIterations: string,
    masterSend: string,
    masterReceive: string,
    status: string,
    totalIter: string = MpiScatterGatherLoop.VAR_MASTER_TOTAL_ITER
  ): string {
    return `
	// Master routine
	
	// split iterations of the loop
	int ${totalIter} = ${numIterations};
	int clava_mpi_loop_limit = ${totalIter};
	// A better distribution calculation could be used
	int clava_mpi_num_iter = ${totalIter} / ${numWorkers};
	int clava_mpi_num_iter_last = clava_mpi_num_iter + ${totalIter} % ${numWorkers};
	// int clava_mpi_num_iter_last = clava_mpi_num_iter + (clava_mpi_loop_limit - (clava_mpi_num_iter * ${numWorkers}));
	
	// send number of iterations
	for(int i=0; i<${numWorkers}-1; i++) {
		MPI_Send(&clava_mpi_num_iter, 1, MPI_INT, i+1, 1, MPI_COMM_WORLD);
	}
	MPI_Send(&clava_mpi_num_iter_last, 1, MPI_INT, ${numWorkers}, 1, MPI_COMM_WORLD);
	
	
	${masterSend}
	
	MPI_Status ${status};
	
	${masterReceive}
	
	MPI_Finalize();
`;
  }
}
