import {ChainForkConfig} from "@lodestar/config";
import {IForkChoice, ProtoBlock} from "@lodestar/fork-choice";
import {SLOTS_PER_EPOCH} from "@lodestar/params";
import {
  DataAvailabilityStatus,
  ExecutionPayloadStatus,
  IBeaconStateView,
  StateHashTreeRootSource,
  computeEpochAtSlot,
  computeStartSlotAtEpoch,
} from "@lodestar/state-transition";
import {BeaconBlock, RootHex, SignedBeaconBlock, Slot} from "@lodestar/types";
import {Logger, fromHex, toRootHex} from "@lodestar/utils";
import {IBeaconDb} from "../../db/index.js";
import {Metrics} from "../../metrics/index.js";
import {nextEventLoop} from "../../util/eventLoop.js";
import {getCheckpointFromState} from "../blocks/utils/checkpoint.js";
import {ChainEvent, ChainEventEmitter} from "../emitter.js";
import {SeenBlockInput} from "../seenCache/seenGossipBlockInput.js";
import {BlockStateCache, CheckpointStateCache} from "../stateCache/types.js";
import {ValidatorMonitor} from "../validatorMonitor.js";
import {RegenError, RegenErrorCode} from "./errors.js";
import {IStateRegeneratorInternal, RegenCaller, StateRegenerationOpts} from "./interface.js";

export type RegenModules = {
  // Fork choice: locates the blocks (and their post-state roots) that regen walks and replays
  forkChoice: IForkChoice;
  // State caches consulted before replaying anything:
  //   - block states looked up by state root
  //   - checkpoint states looked up by block root + epoch
  blockStateCache: BlockStateCache;
  checkpointStateCache: CheckpointStateCache;
  // Block sources for replay: the seen block input cache is checked first, the database second
  seenBlockInputCache: SeenBlockInput;
  db: IBeaconDb;
  config: ChainForkConfig;
  // Receives a "checkpoint" event each time an epoch boundary is processed
  emitter: ChainEventEmitter;
  logger: Logger;
  // Optional instrumentation, null when disabled
  metrics: Metrics | null;
  validatorMonitor: ValidatorMonitor | null;
};

/**
 * Regenerates states that have already been processed by the fork choice.
 *
 * Since Feb 2024, checkpoint states can be reloaded from disk via the `allowDiskReload` flag. Because of its
 * performance impact, this flag is only set to true in these cases:
 *    - getPreState: used for block processing, where it's important to be able to reload states during periods of
 *      non-finality
 *    - updateHeadState: rarely happens, but it's important to make sure we can always regen the head state
 */
export class StateRegenerator implements IStateRegeneratorInternal {
  constructor(private readonly modules: RegenModules) {}

  /**
   * Get the state to run with `block`. May be:
   * - If parent is in same epoch -> Exact state at `block.parentRoot`
   * - If parent is in prev epoch -> State after `block.parentRoot` dialed forward through epoch transition
   * - reload state if needed in this flow (always called with `allowDiskReload = true`)
   *
   * @throws RegenError BLOCK_NOT_IN_FORKCHOICE if the parent block is unknown to the fork choice
   */
  async getPreState(
    block: BeaconBlock,
    opts: StateRegenerationOpts,
    regenCaller: RegenCaller
  ): Promise<IBeaconStateView> {
    const parentRoot = toRootHex(block.parentRoot);
    const parentBlock = this.modules.forkChoice.getBlockHexDefaultStatus(parentRoot);
    if (!parentBlock) {
      throw new RegenError({
        code: RegenErrorCode.BLOCK_NOT_IN_FORKCHOICE,
        blockRoot: block.parentRoot,
      });
    }

    const parentEpoch = computeEpochAtSlot(parentBlock.slot);
    const blockEpoch = computeEpochAtSlot(block.slot);
    const allowDiskReload = true;

    // This may save us at least one epoch transition.
    // If the requested state crosses an epoch boundary
    // then we may use the checkpoint state before the block
    // We may have the checkpoint state with parent root inside the checkpoint state cache
    // through gossip validation.
    if (parentEpoch < blockEpoch) {
      return this.getBlockSlotState(parentBlock, block.slot, opts, regenCaller, allowDiskReload);
    }

    // Otherwise, get the state normally.
    return this.getState(parentBlock.stateRoot, regenCaller, allowDiskReload);
  }

  /**
   * Get state after block `block` dialed forward to `slot`.
   *   - allowDiskReload should be used with care, as it will cause the state to be reloaded from disk
   *
   * @throws RegenError SLOT_BEFORE_BLOCK_SLOT if `slot` is earlier than `block.slot`
   */
  async getBlockSlotState(
    block: ProtoBlock,
    slot: Slot,
    opts: StateRegenerationOpts,
    regenCaller: RegenCaller,
    allowDiskReload = false
  ): Promise<IBeaconStateView> {
    if (slot < block.slot) {
      throw new RegenError({
        code: RegenErrorCode.SLOT_BEFORE_BLOCK_SLOT,
        slot,
        blockSlot: block.slot,
      });
    }

    const {blockRoot} = block;
    const {checkpointStateCache} = this.modules;
    const epoch = computeEpochAtSlot(slot);

    const latestCheckpointStateCtx = allowDiskReload
      ? await checkpointStateCache.getOrReloadLatest(blockRoot, epoch)
      : checkpointStateCache.getLatest(blockRoot, epoch);

    // If a checkpoint state exists with the given checkpoint root, it either is in requested epoch
    // or needs to have empty slots processed until the requested epoch
    if (latestCheckpointStateCtx) {
      return processSlotsByCheckpoint(this.modules, latestCheckpointStateCtx, slot, regenCaller, opts);
    }

    // Otherwise, use the fork choice to get the stateRoot from block at the checkpoint root
    // regenerate that state,
    // then process empty slots until the requested epoch
    const blockStateCtx = await this.getState(block.stateRoot, regenCaller, allowDiskReload);
    return processSlotsByCheckpoint(this.modules, blockStateCtx, slot, regenCaller, opts);
  }

  /**
   * Get state by exact root. If not in cache directly, requires finding the block that references the state from the
   * forkchoice and replaying blocks to get to it.
   *   - allowDiskReload should be used with care, as it will cause the state to be reloaded from disk
   *
   * @throws RegenError STATE_NOT_IN_FORKCHOICE if no fork choice block has `stateRoot` as its state root
   * @throws RegenError NO_SEED_STATE if no cached ancestor state can be found to replay from
   * @throws RegenError TOO_MANY_BLOCK_PROCESSED if more than MAX_EPOCH_TO_PROCESS epochs worth of blocks must be replayed
   * @throws RegenError BLOCK_NOT_IN_DB if a block to replay is neither in the seen block input cache nor the db
   * @throws RegenError STATE_TRANSITION_ERROR if replaying a block fails or yields an unexpected state root
   */
  async getState(
    stateRoot: RootHex,
    caller: RegenCaller,
    // internal option, don't want to expose to external caller
    allowDiskReload = false
  ): Promise<IBeaconStateView> {
    // Trivial case, state at stateRoot is already cached
    const cachedStateCtx = this.modules.blockStateCache.get(stateRoot);
    if (cachedStateCtx) {
      return cachedStateCtx;
    }

    // Otherwise we have to use the fork choice to traverse backwards, block by block,
    // searching the state caches
    // then replay blocks forward to the desired stateRoot
    const block = this.findFirstStateBlock(stateRoot);

    // blocks to replay, ordered highest to lowest
    // gets reversed when replayed
    const blocksToReplay = [block];
    let state: IBeaconStateView | null = null;
    const {checkpointStateCache} = this.modules;

    const getSeedStateTimer = this.modules.metrics?.regenGetState.getSeedState.startTimer({caller});
    // iterateAncestorBlocks only returns ancestor blocks, not the block itself
    for (const b of this.modules.forkChoice.iterateAncestorBlocks(block.blockRoot, block.payloadStatus)) {
      state = this.modules.blockStateCache.get(b.stateRoot);
      if (state) {
        break;
      }
      const lastBlockToReplay = blocksToReplay.at(-1);
      if (!lastBlockToReplay) continue;
      const epoch = computeEpochAtSlot(lastBlockToReplay.slot - 1);

      state = allowDiskReload
        ? await checkpointStateCache.getOrReloadLatest(b.blockRoot, epoch)
        : checkpointStateCache.getLatest(b.blockRoot, epoch);
      if (state) {
        break;
      }
      blocksToReplay.push(b);
    }
    getSeedStateTimer?.();

    if (state === null) {
      throw new RegenError({
        code: RegenErrorCode.NO_SEED_STATE,
      });
    }

    const blockCount = blocksToReplay.length;
    const MAX_EPOCH_TO_PROCESS = 5;
    if (blockCount > MAX_EPOCH_TO_PROCESS * SLOTS_PER_EPOCH) {
      throw new RegenError({
        code: RegenErrorCode.TOO_MANY_BLOCK_PROCESSED,
        stateRoot,
      });
    }

    this.modules.metrics?.regenGetState.blockCount.observe({caller}, blockCount);

    const replaySlots = new Array<Slot>(blockCount);
    const blockPromises = new Array<Promise<SignedBeaconBlock | null>>(blockCount);

    const protoBlocksAsc = blocksToReplay.reverse();
    for (const [i, protoBlock] of protoBlocksAsc.entries()) {
      replaySlots[i] = protoBlock.slot;
      // Prefer the in-memory seen block input; fall back to the database
      const blockInput = this.modules.seenBlockInputCache.get(protoBlock.blockRoot);
      blockPromises[i] = blockInput?.hasBlock()
        ? Promise.resolve(blockInput.getBlock())
        : this.modules.db.block.get(fromHex(protoBlock.blockRoot));
    }

    const logCtx = {stateRoot, caller, replaySlots: replaySlots.join(",")};
    this.modules.logger.debug("Replaying blocks to get state", logCtx);

    const loadBlocksTimer = this.modules.metrics?.regenGetState.loadBlocks.startTimer({caller});
    const blockOrNulls = await Promise.all(blockPromises);
    loadBlocksTimer?.();

    const blocksByRoot = new Map<RootHex, SignedBeaconBlock>();
    for (const [i, blockOrNull] of blockOrNulls.entries()) {
      // checking early here helps prevent unneccessary state transition below
      if (blockOrNull === null) {
        throw new RegenError({
          code: RegenErrorCode.BLOCK_NOT_IN_DB,
          blockRoot: protoBlocksAsc[i].blockRoot,
        });
      }
      blocksByRoot.set(protoBlocksAsc[i].blockRoot, blockOrNull);
    }

    const stateTransitionTimer = this.modules.metrics?.regenGetState.stateTransition.startTimer({caller});
    for (const b of protoBlocksAsc) {
      const signedBlock = blocksByRoot.get(b.blockRoot);
      // just to make compiler happy, we checked in the above for loop already
      if (signedBlock === undefined) {
        throw new RegenError({
          code: RegenErrorCode.BLOCK_NOT_IN_DB,
          blockRoot: b.blockRoot,
        });
      }

      try {
        // Only advances state trusting block's signature and hashes.
        // We are only running the state transition to get a specific state's data.
        // stateTransition() does the clone() inside, transfer cache to make the regen faster
        state = state.stateTransition(
          signedBlock,
          {
            // Replay previously imported blocks, assume valid and available
            executionPayloadStatus: ExecutionPayloadStatus.valid,
            dataAvailabilityStatus: DataAvailabilityStatus.Available,
            verifyStateRoot: false,
            verifyProposer: false,
            verifySignatures: false,
            dontTransferCache: false,
          },
          this.modules
        );

        const hashTreeRootTimer = this.modules.metrics?.stateHashTreeRootTime.startTimer({
          source: StateHashTreeRootSource.regenState,
        });
        const replayedStateRoot = toRootHex(state.hashTreeRoot());
        hashTreeRootTimer?.();

        if (b.stateRoot !== replayedStateRoot) {
          throw new RegenError({
            slot: b.slot,
            code: RegenErrorCode.INVALID_STATE_ROOT,
            actual: replayedStateRoot,
            expected: b.stateRoot,
          });
        }

        if (allowDiskReload) {
          // also with allowDiskReload flag, we "reload" it to the state cache too
          this.modules.blockStateCache.add(state);
        }
      } catch (e) {
        // Any replay failure, including the INVALID_STATE_ROOT error thrown above, is surfaced to callers
        // as STATE_TRANSITION_ERROR with the original error attached
        throw new RegenError({
          code: RegenErrorCode.STATE_TRANSITION_ERROR,
          error: e as Error,
        });
      }

      // Yield between replayed blocks: a replay can cover up to MAX_EPOCH_TO_PROCESS epochs of blocks, each
      // running a full state transition + hashTreeRoot, which would otherwise keep the node busy for the
      // whole replay without serving any other task
      await nextEventLoop();
    }
    stateTransitionTimer?.();

    this.modules.logger.debug("Replayed blocks to get state", {...logCtx, stateSlot: state.slot});

    return state;
  }

  /**
   * Return the first block yielded by `forkChoice.forwarditerateAncestorBlocks()` whose post-state root equals
   * `stateRoot`.
   *
   * @throws RegenError STATE_NOT_IN_FORKCHOICE if no such block is found
   */
  private findFirstStateBlock(stateRoot: RootHex): ProtoBlock {
    for (const block of this.modules.forkChoice.forwarditerateAncestorBlocks()) {
      if (block.stateRoot === stateRoot) {
        return block;
      }
    }

    throw new RegenError({
      code: RegenErrorCode.STATE_NOT_IN_FORKCHOICE,
      stateRoot,
    });
  }
}

/**
 * Advance `preState` so that it ends exactly at `slot`.
 *
 * Every epoch boundary up to `slot` is crossed through processSlotsToNearestCheckpoint, which caches each
 * checkpoint state and emits a "checkpoint" event for it. Any slots still remaining after the last boundary are
 * then processed directly, without emitting further events. If `preState` already sits at `slot`, it is returned
 * unchanged.
 */
async function processSlotsByCheckpoint(
  modules: {
    checkpointStateCache: CheckpointStateCache;
    metrics: Metrics | null;
    validatorMonitor: ValidatorMonitor | null;
    emitter: ChainEventEmitter;
    config: ChainForkConfig;
    logger: Logger;
  },
  preState: IBeaconStateView,
  slot: Slot,
  regenCaller: RegenCaller,
  opts: StateRegenerationOpts
): Promise<IBeaconStateView> {
  const advancedState = await processSlotsToNearestCheckpoint(modules, preState, slot, regenCaller, opts);
  // Finish off the tail of slots that did not reach another epoch boundary
  return advancedState.slot < slot ? advancedState.processSlots(slot, opts, modules) : advancedState;
}

/**
 * Advance `preState` one epoch boundary at a time, stopping at the last epoch start slot that is <= `slot`.
 *
 * For every boundary crossed:
 *   - the resulting state is added to the checkpoint state cache and a "checkpoint" event is emitted
 *   - from the second boundary onwards, `checkpointStateCache.processState()` is called to prune the cache;
 *     failures there are only logged (debug) and never abort the walk
 *   - control is yielded to the event loop before moving on
 *
 * If no epoch boundary lies in (preState.slot, slot], `preState` is returned untouched. The returned state may
 * therefore still be behind `slot` by less than one epoch.
 */
export async function processSlotsToNearestCheckpoint(
  modules: {
    checkpointStateCache: CheckpointStateCache;
    metrics: Metrics | null;
    validatorMonitor: ValidatorMonitor | null;
    emitter: ChainEventEmitter | null;
    config: ChainForkConfig;
    logger: Logger | null;
  },
  preState: IBeaconStateView,
  slot: Slot,
  regenCaller: RegenCaller,
  opts: StateRegenerationOpts
): Promise<IBeaconStateView> {
  const {checkpointStateCache, emitter, metrics, logger} = modules;
  let currentState = preState;
  let isFirstEpoch = true;
  let epochStartSlot = computeStartSlotAtEpoch(computeEpochAtSlot(preState.slot) + 1);

  while (epochStartSlot <= slot) {
    logger?.verbose("Processing slots over epochs", {
      slot: currentState.slot,
      nextEpochSlot: epochStartSlot,
      postSlot: slot,
      caller: regenCaller,
    });
    // processSlots clones before mutating, so earlier states (possibly cached) are never modified
    currentState = currentState.processSlots(epochStartSlot, opts, modules);
    metrics?.epochTransitionByCaller.inc({caller: regenCaller});

    // Cache the state at the start of the new epoch so the 1st block of the epoch does not need another
    // epoch transition. It becomes the "official" checkpoint state if that 1st block is skipped.
    const checkpoint = getCheckpointFromState(currentState);
    checkpointStateCache.add(checkpoint, currentState);
    // consumers should not mutate state ever
    emitter?.emit(ChainEvent.checkpoint, checkpoint, currentState);

    if (!isFirstEpoch) {
      // Normally only a single epoch is processed and this branch is never taken; pruning then happens later,
      // after the 1st block of the epoch is imported. During non-finality many epochs may be processed in one go,
      // so the cache is pruned here to keep the node healthy.
      // See https://github.com/ChainSafe/lodestar/issues/7495#issuecomment-2680800898 (holesky, Feb 2025).
      // The checkpoint root is used rather than getBlockRootAtSlot() because the state's own slot is
      // epochStartSlot, presumably making that lookup unavailable for the current slot.
      const checkpointRootHex = toRootHex(checkpoint.root);
      try {
        const persistCount = await checkpointStateCache.processState(checkpointRootHex, currentState);
        logger?.verbose("pruning checkpointStateCache during processSlotsToNearestCheckpoint", {
          root: checkpointRootHex,
          epoch: checkpoint.epoch,
          persistCount,
        });
      } catch (e) {
        logger?.debug(
          "CheckpointStateCache failed to process checkpoint state",
          {root: checkpointRootHex, epoch: checkpoint.epoch},
          e as Error
        );
      }
    }
    isFirstEpoch = false;

    // Give other tasks a chance to run between epoch transitions
    await nextEventLoop();
    epochStartSlot += SLOTS_PER_EPOCH;
  }

  return currentState;
}
