import { mat4, vec3 } from 'gl-matrix'
import { NiivueObject3D } from '../niivue-object3D.js'
import { getExtents } from './utils.js'
import type { NVImage } from './index.js'

/**
 * Creates a NiivueObject3D representation for WebGL rendering from an NVImage.
 *
 * Geometry and buffers:
 * - The object is the image's bounding box. Its 8 corners sit on the outer voxel faces:
 *   voxel indices -0.5 and `dimsRAS[i] - 0.5` on each axis.
 * - Each corner is mapped to mm with `nvImage.vox2mm(..., matRAS)`.
 * - Each corner is paired with a UVW texture coordinate in {0, 1}^3.
 * - The corners are drawn as 12 indexed triangles (36 `Uint16` indices, `gl.TRIANGLES`).
 *
 * Vertex layout is interleaved, 6 floats per vertex:
 * - attribute location 0: position XYZ (mm)
 * - attribute location 1: texture coordinate UVW
 *
 * GL state on return:
 * - No vertex array object is bound (`bindVertexArray(null)`).
 * - `ARRAY_BUFFER` is left bound to the new vertex buffer.
 *
 * @param nvImage - The NVImage instance; must provide `dimsRAS`, `matRAS`, `pixDimsRAS` and `vox2mm`
 * @param id - Unique ID for the 3D object
 * @param gl - WebGL2 rendering context
 * @returns NiivueObject3D instance
 * @throws Error if required RAS properties are missing, or if a GL buffer/VAO cannot be created
 */
export function toNiivueObject3D(nvImage: NVImage, id: number, gl: WebGL2RenderingContext): NiivueObject3D {
  // Ensure necessary RAS properties are available on the nvImage object
  if (!nvImage.dimsRAS || !nvImage.matRAS || !nvImage.pixDimsRAS || !nvImage.vox2mm) {
    throw new Error('Cannot create NiivueObject3D: Missing required RAS properties or vox2mm access on NVImage.')
  }

  const dimsRAS = nvImage.dimsRAS as number[]
  const matRAS = nvImage.matRAS as mat4
  const pixDimsRAS = nvImage.pixDimsRAS as number[]

  // Voxel-index bounds of the volume's outer faces (voxel centres are at integer indices,
  // so the box extends half a voxel beyond the first and last centres on each axis).
  const L = -0.5
  const P = -0.5
  const I = -0.5
  const R = dimsRAS[1] - 1 + 0.5
  const A = dimsRAS[2] - 1 + 0.5
  const S = dimsRAS[3] - 1 + 0.5

  const vox2mmFn = nvImage.vox2mm

  // Calculate corner coordinates in mm space
  const LPI = vox2mmFn.call(nvImage, [L, P, I], matRAS)
  const LAI = vox2mmFn.call(nvImage, [L, A, I], matRAS)
  const LPS = vox2mmFn.call(nvImage, [L, P, S], matRAS)
  const LAS = vox2mmFn.call(nvImage, [L, A, S], matRAS)
  const RPI = vox2mmFn.call(nvImage, [R, P, I], matRAS)
  const RAI = vox2mmFn.call(nvImage, [R, A, I], matRAS)
  const RPS = vox2mmFn.call(nvImage, [R, P, S], matRAS)
  const RAS = vox2mmFn.call(nvImage, [R, A, S], matRAS)

  // Interleaved vertex positions (XYZ, mm) and texture coordinates (UVW).
  // U runs L->R, V runs P->A, W runs I->S.
  const posTex = [
    // Superior face vertices (Indices 0-3)
    ...LPS,
    ...[0.0, 0.0, 1.0], // 0
    ...RPS,
    ...[1.0, 0.0, 1.0], // 1
    ...RAS,
    ...[1.0, 1.0, 1.0], // 2
    ...LAS,
    ...[0.0, 1.0, 1.0], // 3

    // Inferior face vertices (Indices 4-7)
    ...LPI,
    ...[0.0, 0.0, 0.0], // 4
    ...LAI,
    ...[0.0, 1.0, 0.0], // 5
    ...RAI,
    ...[1.0, 1.0, 0.0], // 6
    ...RPI,
    ...[1.0, 0.0, 0.0] // 7
  ]

  // Two triangles per box face, referencing the vertex indices above.
  const indices = [
    0, 3, 2, 2, 1, 0, // Top (superior): LPS, LAS, RAS / RAS, RPS, LPS
    4, 7, 6, 6, 5, 4, // Bottom (inferior): LPI, RPI, RAI / RAI, LAI, LPI
    5, 6, 2, 2, 3, 5, // Front -> Corresponds to LAI(5), RAI(6), RAS(2) / RAS(2), LAS(3), LAI(5)
    4, 0, 1, 1, 7, 4, // Back -> Corresponds to LPI(4), LPS(0), RPS(1) / RPS(1), RPI(7), LPI(4)
    7, 1, 2, 2, 6, 7, // Right -> Corresponds to RPI(7), RPS(1), RAS(2) / RAS(2), RAI(6), RPI(7)
    4, 5, 3, 3, 0, 4 // Left -> Corresponds to LPI(4), LAI(5), LAS(3) / LAS(3), LPS(0), LPI(4)
  ]

  // Create and bind the VAO *before* touching ELEMENT_ARRAY_BUFFER: that binding is part of
  // VAO state, so binding it earlier would overwrite the index buffer of whichever VAO the
  // caller happened to leave bound.
  const vao = gl.createVertexArray()
  if (!vao) {
    throw new Error('Failed to create GL VAO')
  }
  gl.bindVertexArray(vao)

  // Index buffer: recorded in the VAO as soon as it is bound here.
  const indexBuffer = gl.createBuffer()
  if (!indexBuffer) {
    // Release what was already allocated so a failed call does not leak GL objects.
    gl.bindVertexArray(null)
    gl.deleteVertexArray(vao)
    throw new Error('Failed to create GL index buffer')
  }
  gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, indexBuffer)
  gl.bufferData(gl.ELEMENT_ARRAY_BUFFER, new Uint16Array(indices), gl.STATIC_DRAW)

  // Create buffer for position and texture coordinates
  const posTexBuffer = gl.createBuffer()
  if (!posTexBuffer) {
    gl.bindVertexArray(null)
    gl.deleteBuffer(indexBuffer)
    gl.deleteVertexArray(vao)
    throw new Error('Failed to create GL vertex buffer')
  }
  gl.bindBuffer(gl.ARRAY_BUFFER, posTexBuffer)
  gl.bufferData(gl.ARRAY_BUFFER, new Float32Array(posTex), gl.STATIC_DRAW)

  // Configure vertex attribute pointers (captured by the bound VAO).
  const floatBytes = Float32Array.BYTES_PER_ELEMENT
  const stride = 6 * floatBytes // XYZ + UVW per vertex
  // Vertex spatial position (XYZ) - location 0
  gl.enableVertexAttribArray(0)
  gl.vertexAttribPointer(0, 3, gl.FLOAT, false, stride, 0)
  // Texture coordinates (UVW) - location 1, immediately after the 3 position floats
  gl.enableVertexAttribArray(1)
  gl.vertexAttribPointer(1, 3, gl.FLOAT, false, stride, 3 * floatBytes)

  gl.bindVertexArray(null) // Unbind VAO

  // Create the NiivueObject3D instance
  const obj3D = new NiivueObject3D(id, posTexBuffer, gl.TRIANGLES, indices.length, indexBuffer, vao)

  // World-space bounds of the box, computed from all 8 mm-space corners.
  const allCorners = [...LPS, ...RPS, ...RAS, ...LAS, ...LPI, ...LAI, ...RAI, ...RPI]
  const extents = getExtents(allCorners)

  obj3D.extentsMin = extents.min.slice() // copy so the object does not share storage with `extents`
  obj3D.extentsMax = extents.max.slice()
  obj3D.furthestVertexFromOrigin = extents.furthestVertexFromOrigin

  obj3D.originNegate = vec3.clone(extents.origin)
  vec3.negate(obj3D.originNegate, obj3D.originNegate)

  // Field of view per RAS axis: voxel count times voxel size.
  obj3D.fieldOfViewDeObliqueMM = [dimsRAS[1] * pixDimsRAS[1], dimsRAS[2] * pixDimsRAS[2], dimsRAS[3] * pixDimsRAS[3]]

  return obj3D
}
