import { hash } from 'three/src/nodes/core/NodeUtils.js'
import {
  add,
  Fn,
  If,
  mix,
  positionGeometry,
  remapClamp,
  uv,
  vec2,
  vec3,
  vec4
} from 'three/tsl'
import { TempNode, type NodeBuilder } from 'three/webgpu'

import {
  depthToViewZ,
  inverseProjectionMatrix,
  inverseViewMatrix,
  projectionMatrix,
  rayEllipsoidIntersection,
  screenToPositionView,
  type Node
} from '@takram/three-geospatial/webgpu'

import { getAtmosphereContext } from './AtmosphereContext'
import { getIndirectLuminanceToPoint, getSplitIlluminance } from './runtime'
import { sky } from './SkyNode'

export class AerialPerspectiveNode extends TempNode {
  static override get type(): string {
    return 'AerialPerspectiveNode'
  }

  colorNode: Node<'vec4'>
  depthNode: Node<'float'>
  normalNode?: Node<'vec3'> | null
  skyNode?: Node<'vec3'> | null
  shadowLengthNode?: Node<'vec2'> | null

  correctGeometricError = true
  lighting = false
  transmittance = true
  inscattering = true
  moonScattering = false

  constructor(
    colorNode: Node<'vec4'>,
    depthNode: Node<'float'>,
    normalNode?: Node<'vec3'> | null,
    shadowLengthNode?: Node<'vec2'> | null
  ) {
    super('vec4')
    this.colorNode = colorNode
    this.depthNode = depthNode
    this.normalNode = normalNode
    this.shadowLengthNode = shadowLengthNode
    this.skyNode = sky(shadowLengthNode)

    this.lighting = normalNode != null
  }

  override customCacheKey(): number {
    return hash(
      +this.correctGeometricError,
      +this.lighting,
      +this.transmittance,
      +this.inscattering,
      +this.moonScattering
    )
  }

  override setup(builder: NodeBuilder): unknown {
    const atmosphereContext = getAtmosphereContext(builder)

    const camera = atmosphereContext.camera ?? builder.camera
    if (camera == null) {
      return
    }

    const { worldToUnit } = atmosphereContext.parametersNode
    const {
      ellipsoid,
      matrixWorldToECEF,
      sunDirectionECEF,
      moonDirectionECEF,
      cameraPositionUnit,
      altitudeCorrectionUnit
    } = atmosphereContext

    const { colorNode, depthNode, normalNode } = this
    const depth = depthNode.r.toConst()

    const getSurfacePositionECEF = (): Node<'vec3'> => {
      const viewZ = depthToViewZ(depth, camera)
      const positionView = screenToPositionView(
        uv(),
        depth,
        viewZ,
        projectionMatrix(camera),
        inverseProjectionMatrix(camera)
      )
      const positionWorld = inverseViewMatrix(camera).mul(
        vec4(positionView, 1)
      ).xyz
      return matrixWorldToECEF.mul(vec4(positionWorld, 1)).xyz
    }

    const getRayDirectionECEF = (): Node<'vec3'> => {
      const positionView = inverseProjectionMatrix(camera).mul(
        vec4(positionGeometry, 1)
      ).xyz
      const directionWorld = inverseViewMatrix(camera).mul(
        vec4(positionView, 0)
      ).xyz
      const directionECEF = matrixWorldToECEF.mul(vec4(directionWorld, 0)).xyz
      return directionECEF.toVertexStage().normalize()
    }

    const surfaceLuminance = Fn(() => {
      const positionUnit = getSurfacePositionECEF().mul(worldToUnit).toVar()

      // Changed our strategy on the geometric error correction, because we no
      // longer have LightingMask to exclude objects in space.
      const geometryCorrectionAmount = remapClamp(
        positionUnit.distance(cameraPositionUnit),
        // The distance to the horizon from the highest point on the earth,
        worldToUnit.mul(336_000),
        // The distance to the horizon at the top atmosphere
        worldToUnit.mul(876_000)
      )

      // Geometry normal can be trivially corrected:
      const radiiUnit = vec3(ellipsoid.radii).mul(worldToUnit)
      const normalCorrected = positionUnit.div(radiiUnit.pow2()).normalize()

      if (this.correctGeometricError) {
        const rayDirectionECEF = getRayDirectionECEF()
        const intersection = rayEllipsoidIntersection(
          cameraPositionUnit,
          rayDirectionECEF,
          radiiUnit
        ).x // Near side

        const positionCorrected = intersection.greaterThanEqual(0).select(
          rayDirectionECEF.mul(intersection).add(cameraPositionUnit),
          // Fallback to radial projection:
          normalCorrected.mul(radiiUnit)
        )
        positionUnit.assign(
          mix(positionUnit, positionCorrected, geometryCorrectionAmount)
        )
      }

      const illuminance = Fn(() => {
        // Normal vector of the surface:
        let normalECEF
        if (normalNode != null) {
          const normalView = normalNode.xyz
          const normalWorld = inverseViewMatrix(camera).mul(
            vec4(normalView, 0)
          ).xyz
          normalECEF = matrixWorldToECEF.mul(vec4(normalWorld, 0)).xyz

          if (this.correctGeometricError) {
            normalECEF.assign(
              mix(normalECEF, normalCorrected, geometryCorrectionAmount)
            )
          }
        } else {
          normalECEF = positionUnit.normalize()
        }

        // Direct and indirect illuminance on the surface:
        const solarIlluminance = getSplitIlluminance(
          positionUnit.add(altitudeCorrectionUnit),
          normalECEF,
          sunDirectionECEF
        )
        let illuminance = add(
          solarIlluminance.get('direct'),
          solarIlluminance.get('indirect')
        )
        if (this.moonScattering) {
          const lunarIlluminance = getSplitIlluminance(
            positionUnit.add(altitudeCorrectionUnit),
            normalECEF,
            moonDirectionECEF
          )
          illuminance = add(
            illuminance,
            lunarIlluminance.get('direct'),
            lunarIlluminance.get('indirect')
          )
        }
        return illuminance
      })()

      const diffuse = this.lighting
        ? colorNode.rgb.mul(illuminance).mul(1 / Math.PI) // Lambertian
        : colorNode.rgb

      // Scattering between the camera to the surface:
      const solarLuminanceTransfer = getIndirectLuminanceToPoint(
        cameraPositionUnit.add(altitudeCorrectionUnit),
        positionUnit.add(altitudeCorrectionUnit),
        this.shadowLengthNode ?? vec2(0),
        sunDirectionECEF
      ).toConst()
      const transmittance = solarLuminanceTransfer.get('transmittance')
      let inscattering = solarLuminanceTransfer.get('luminance')

      if (this.moonScattering) {
        // TODO: Combine raymarch when raymarchScattering.
        const lunarLuminanceTransfer = getIndirectLuminanceToPoint(
          cameraPositionUnit.add(altitudeCorrectionUnit),
          positionUnit.add(altitudeCorrectionUnit),
          this.shadowLengthNode ?? vec2(0),
          moonDirectionECEF
        ).toConst()

        // TODO: Consider moon phase
        inscattering = inscattering.add(
          lunarLuminanceTransfer.get('luminance').mul(2.5e-6)
        )
      }

      let output = diffuse
      if (this.transmittance) {
        output = output.mul(transmittance)
      }
      if (this.inscattering) {
        output = output.add(inscattering)
      }
      return output
    })()

    return Fn(() => {
      const luminance = colorNode.toVar()
      If(
        builder.renderer.reversedDepthBuffer
          ? depth.lessThanEqual(0)
          : depth.greaterThanEqual(1),
        () => {
          if (this.skyNode != null) {
            // Render the sky (the scattering seen from the camera to an infinite
            // distance) for very far depths.
            luminance.rgb.assign(this.skyNode)
          }
        }
      ).Else(() => {
        luminance.rgb.assign(surfaceLuminance)
      })
      return luminance
    })()
  }

  override dispose(): void {
    this.skyNode?.dispose() // TODO: Conditionally depending on the owner.
    super.dispose()
  }

  /** @deprecated Use inscattering instead. */
  get inscatter(): boolean {
    return this.inscattering
  }

  /** @deprecated Use inscattering instead. */
  set inscatter(value: boolean) {
    this.inscattering = value
  }
}

export const aerialPerspective = (
  ...args: ConstructorParameters<typeof AerialPerspectiveNode>
): AerialPerspectiveNode => new AerialPerspectiveNode(...args)
