import Chunk from './chunk.ts'
import IndexFile, { memoizeByRefId } from './indexFile.ts'
import { findFirstData, parsePseudoBin } from './util.ts'
import { fromBytes } from './virtualOffset.ts'

import type { ParsedIndexBase, RefIndex } from './indexFile.ts'
import type { BaseOpts } from './util.ts'
import type { VirtualOffset } from './virtualOffset.ts'

// Per-reference index data decoded from a BAI file. Besides the base
// RefIndex fields, BAI carries a linear index: one virtual offset per
// BAI_LINEAR_INTERVAL-sized window of the reference, filled in by
// `getIndices` inside `BAI._parse`.
interface BaiRefIndex extends RefIndex {
  linearIndex: VirtualOffset[]
}

// Result of `BAI._parse`. The `bai: true` literal presumably lets callers
// tell a BAI parse result apart from other index formats' results — confirm
// against consumers of ParsedIndexBase.
interface BaiParsed extends ParsedIndexBase<BaiRefIndex> {
  bai: true
}

// The BAI magic "BAI\1" (bytes 42 41 49 01) read as a little-endian uint32,
// i.e. 21578050.
const BAI_MAGIC = 0x01494142

// BAI uses a fixed 5-level binning scheme whose linear index has a 14-bit
// (16 kbp of reference sequence) resolution. See SAMv1.pdf §5.1.3 (hts-specs).
// https://github.com/samtools/hts-specs/blob/master/SAMv1.pdf
const BAI_LINEAR_SHIFT = 14
// Width, in reference bases, of one linear-index window: 2^14 = 16384.
const BAI_LINEAR_INTERVAL = 2 ** BAI_LINEAR_SHIFT

/**
 * Drop the remainder of `n` modulo `multiple`. For non-negative `n` this
 * rounds down to a multiple; since JS `%` takes the dividend's sign, a
 * negative `n` is rounded toward zero instead.
 */
function roundDown(n: number, multiple: number) {
  const remainder = n % multiple
  return n - remainder
}
/**
 * Return the multiple of `multiple` that follows `n`'s truncated value.
 * This is NOT a ceiling: an exact multiple is still advanced, e.g.
 * roundUp(16384, 16384) === 32768.
 */
function roundUp(n: number, multiple: number) {
  const truncated = n - (n % multiple)
  return truncated + multiple
}

/**
 * One window of coarse, index-derived coverage as produced by
 * `BAI.indexCov`. Each window spans one linear-index interval.
 */
export interface IndexCovEntry {
  /** Start coordinate of the window (a multiple of the linear interval). */
  start: number
  /** End coordinate of the window: `start` plus one linear interval. */
  end: number
  /**
   * Estimated coverage: the compressed-byte distance between this window's
   * linear-index entry and the next, scaled by the pseudo-bin `lineCount`
   * over the last entry's block position (presumably approximating the
   * read count in the window — confirm against consumers).
   */
  score: number
}

// Return, for each of the six binning levels, the inclusive [first, last] bin
// numbers that overlap the half-open interval [beg, end). Level L's bins are
// numbered from (8^L - 1) / 7 (0, 1, 9, 73, 585, 4681), and each level is 3
// bits finer than the previous one. See SAMv1.pdf §5.1.1 for the derivation.
function reg2bins(beg: number, end: number) {
  // last base actually covered by the half-open interval
  const last = end - 1
  const level = (firstBin: number, shift: number) =>
    [firstBin + (beg >> shift), firstBin + (last >> shift)] as const
  return [
    [0, 0],
    level(1, 26),
    level(9, 23),
    level(73, 20),
    level(585, 17),
    level(4681, BAI_LINEAR_SHIFT),
  ] as const
}

export default class BAI extends IndexFile<BaiParsed> {
  /**
   * Parse a BAI file.
   *
   * The whole file is read up front. A first pass walks each reference's
   * bins and linear index only to record where that reference's section
   * starts and to find `firstDataLine` (the smallest linear-index virtual
   * offset). Full decoding of a reference's bins/linear index is deferred to
   * `getIndices`, which is wrapped with `memoizeByRefId`.
   *
   * @throws Error('Not a BAI file') if the file is too short or the magic
   *   number does not match
   * @throws Error if a bin number exceeds the BAI pseudo-bin (such indexes
   *   need CSI)
   */
  async _parse(opts: BaseOpts): Promise<BaiParsed> {
    const bytes = await this.filehandle.readFile(opts)
    // `bytes` may be a view into a larger ArrayBuffer (e.g. a pooled Node
    // Buffer or a subarray). Honor its byteOffset/byteLength so DataView
    // reads agree with the `fromBytes(bytes, …)` reads, which index relative
    // to `bytes` itself.
    const dataView = new DataView(
      bytes.buffer,
      bytes.byteOffset,
      bytes.byteLength,
    )

    // check BAI magic numbers; the 8-byte header (magic + n_ref) must exist
    if (dataView.byteLength < 8 || dataView.getUint32(0, true) !== BAI_MAGIC) {
      throw new Error('Not a BAI file')
    }

    const refCount = dataView.getInt32(4, true)
    const depth = 5
    // (8^6 - 1) / 7 = 37449; bin number binLimit + 1 (37450) is the BAI
    // metadata pseudo-bin
    const binLimit = ((1 << ((depth + 1) * 3)) - 1) / 7

    // read the indexes for each reference sequence
    let curr = 8
    let firstDataLine: VirtualOffset | undefined

    // byte offset of each reference's section, consumed by getIndices
    const offsets = [] as number[]
    for (let i = 0; i < refCount; i++) {
      offsets.push(curr)
      const binCount = dataView.getInt32(curr, true)

      curr += 4

      for (let j = 0; j < binCount; j += 1) {
        const bin = dataView.getUint32(curr, true)
        curr += 4
        if (bin === binLimit + 1) {
          // pseudo-bin: n_chunk (4 bytes) followed by two 16-byte "chunks"
          curr += 4
          curr += 32
        } else if (bin > binLimit + 1) {
          throw new Error('bai index contains too many bins, please use CSI')
        } else {
          const chunkCount = dataView.getInt32(curr, true)
          curr += 4
          // each chunk is a pair of 8-byte virtual offsets
          for (let k = 0; k < chunkCount; k++) {
            curr += 16
          }
        }
      }

      // walk the linear index to find the smallest virtual offset, which
      // marks where the BAM header ends and data begins
      const linearCount = dataView.getInt32(curr, true)
      curr += 4
      for (let j = 0; j < linearCount; j++) {
        firstDataLine = findFirstData(firstDataLine, fromBytes(bytes, curr))
        curr += 8
      }
    }

    // Decode the bin index, linear index and pseudo-bin stats for one
    // reference, or return undefined for an unknown refId.
    function getIndices(refId: number) {
      let curr = offsets[refId]
      if (curr === undefined) {
        return undefined
      }
      const binCount = dataView.getInt32(curr, true)
      let stats

      curr += 4
      const binIndex: Record<number, Chunk[]> = {}

      for (let j = 0; j < binCount; j += 1) {
        const bin = dataView.getUint32(curr, true)
        curr += 4
        if (bin === binLimit + 1) {
          curr += 4
          // the second pseudo-chunk (16 bytes in) carries the read counts
          stats = parsePseudoBin(bytes, curr + 16)
          curr += 32
        } else if (bin > binLimit + 1) {
          throw new Error('bai index contains too many bins, please use CSI')
        } else {
          const chunkCount = dataView.getInt32(curr, true)
          curr += 4
          const chunks = new Array<Chunk>(chunkCount)
          for (let k = 0; k < chunkCount; k++) {
            const u = fromBytes(bytes, curr)
            curr += 8
            const v = fromBytes(bytes, curr)
            curr += 8
            chunks[k] = new Chunk(u, v, bin)
          }
          binIndex[bin] = chunks
        }
      }

      const linearCount = dataView.getInt32(curr, true)
      curr += 4
      const linearIndex = new Array<VirtualOffset>(linearCount)
      for (let j = 0; j < linearCount; j++) {
        linearIndex[j] = fromBytes(bytes, curr)
        curr += 8
      }

      return {
        binIndex,
        linearIndex,
        stats,
      }
    }

    return {
      bai: true,
      firstDataLine,
      maxBlockSize: 1 << 16,
      indices: memoizeByRefId(getIndices),
      refCount,
    }
  }

  /**
   * Estimate coverage of reference `seqId` from the linear index alone.
   *
   * Each window spans BAI_LINEAR_INTERVAL bases. Its score is the
   * compressed-byte distance between its linear-index entry and the next
   * entry, times the pseudo-bin `lineCount` (0 if absent), divided by the
   * block position of the last linear-index entry. If that position is 0,
   * every score is 0 rather than NaN.
   *
   * @param seqId reference id
   * @param start optional query start; rounded down to a window boundary
   *   (and clamped at 0). Defaults to 0.
   * @param end optional query end; passed through `roundUp`, which advances
   *   even an exact multiple by one window. Defaults to the window of the
   *   last linear-index entry.
   * @returns one dense entry per window in the rounded range; empty when the
   *   reference is unknown, has no linear index, or the range is empty
   * @throws Error if the rounded end lies beyond the linear index
   */
  async indexCov(
    seqId: number,
    start?: number,
    end?: number,
    opts?: BaseOpts,
  ): Promise<IndexCovEntry[]> {
    const v = BAI_LINEAR_INTERVAL
    const indexData = await this.parse(opts)
    const seqIdx = indexData.indices(seqId)

    if (!seqIdx) {
      return []
    }
    const { linearIndex, stats } = seqIdx
    if (linearIndex.length === 0) {
      return []
    }
    const lastIdx = linearIndex.length - 1
    const e = end === undefined ? lastIdx * v : roundUp(end, v)
    const s = start === undefined ? 0 : Math.max(0, roundDown(start, v))
    if (e > lastIdx * v) {
      throw new Error('query outside of range of linear index')
    }
    const totalSize = linearIndex[lastIdx]!.blockPosition
    const lineCount = stats?.lineCount ?? 0

    // Build densely with push: preallocating a fixed length can leave holes
    // when only `end` is given, and `map` preserves those holes.
    const depths: IndexCovEntry[] = []
    // i + 1 <= e / v <= lastIdx, so both lookups below are in bounds
    for (let i = s / v; i < e / v; i++) {
      const bytesInWindow =
        linearIndex[i + 1]!.blockPosition - linearIndex[i]!.blockPosition
      depths.push({
        start: i * v,
        end: i * v + v,
        score: totalSize > 0 ? (bytesInWindow * lineCount) / totalSize : 0,
      })
    }
    return depths
  }

  // Bin ranges overlapping [min, max) using the fixed BAI binning scheme.
  protected reg2bins(min: number, max: number) {
    return reg2bins(min, max)
  }

  // Use the linear index to find minimum file position of chunks that could
  // contain alignments in the region. Linear index entries are monotonically
  // non-decreasing, so the first entry at minLin is the minimum. Positions
  // past the end of the linear index fall back to its last entry; an empty
  // linear index yields undefined.
  protected getLowestChunk(refIndex: BaiRefIndex, min: number) {
    const { linearIndex } = refIndex
    const nintv = linearIndex.length
    return linearIndex[Math.min(min >> BAI_LINEAR_SHIFT, nintv - 1)]
  }
}
