import { ACESFilmicToneMapping, AgXToneMapping, NeutralToneMapping, NoToneMapping, ShaderChunk, ToneMapping } from "three";

import { isDevEnvironment } from "./debug/index.js";
import type { Context } from "./engine_setup.js";

let patchedTonemapping = false;

/**
 * Installs the custom Neutral and AgX tone mapping shader patches.
 * Only the first invocation does any work; every later call is a no-op.
 * @param _ctx Not read by this function; accepted so callers may pass their context.
 */
export function patchTonemapping(_ctx?: Context) {
    if (!patchedTonemapping) {
        // The flag is raised before patching, so the patch functions run at most once.
        patchedTonemapping = true;
        patchNeutral();
        patchAGX();
    }
}

/**
 * Swaps the `NeutralToneMapping` GLSL function found in
 * `ShaderChunk.tonemapping_pars_fragment` for the variant defined below.
 *
 * The existing function is located by its signature line and by its final
 * `return mix( ... )` statement plus closing brace. If either marker is missing
 * the chunk is left untouched and an error is logged in dev environments.
 */
function patchNeutral() {
    // Replacement shader source, from https://github.com/google/model-viewer/pull/4495
    // (Emmett's new 3D Commerce tone mapping function).
    const patchedSource = `
float startCompression = 0.8;
float desaturation = 0.5;
// Patched tonemapping function
vec3 NeutralToneMapping( vec3 color ) {
    color *= toneMappingExposure;
    
    float d = 1. - startCompression;
    // float peak = dot(color, vec3(0.299, 0.587, 0.114));
    float peak = max(color.r, max(color.g, color.b));
    if (peak < startCompression) return color;
    float newPeak = 1. - d * d / (peak + d - startCompression);
    float invPeak = 1. / peak;
    
    float extraBrightness = dot(color * (1. - startCompression * invPeak), vec3(1, 1, 1));
    
    color *= newPeak * invPeak;
    float g = 1. - 3. / (desaturation * extraBrightness + 3.);
    return mix(color, vec3(1, 1, 1), g);
}
`;

    // Markers delimiting the function that gets replaced.
    const signature = "vec3 NeutralToneMapping( vec3 color ) {";
    const closing = "return mix( color, vec3( newPeak ), g );\n}";

    const chunk = ShaderChunk.tonemapping_pars_fragment;
    const from = chunk.indexOf(signature);
    const closingAt = chunk.indexOf(closing, from);

    if (from < 0 || closingAt < 0) {
        if (isDevEnvironment()) {
            console.error("Couldn't find NeutralToneMapping in ShaderChunk.tonemapping_pars_fragment");
        }
        return;
    }

    // Splice: everything before the old function + new source + everything after it.
    const to = closingAt + closing.length;
    ShaderChunk.tonemapping_pars_fragment = chunk.slice(0, from) + patchedSource + chunk.slice(to);
}


/**
 * Swaps the `AgXToneMapping` GLSL function found in
 * `ShaderChunk.tonemapping_pars_fragment` for the implementation defined below.
 *
 * The existing function is located from its signature line up to and including
 * its first following `return color;`. The replacement source intentionally stops
 * at its own `return color;` as well, so the closing brace already present in the
 * chunk keeps terminating the function. If either marker is missing the chunk is
 * left untouched and an error is logged in dev environments.
 */
function patchAGX() {
    // Replacement shader source.
    // From https://iolite-engine.com/blog_posts/minimal_agx_implementation
    // Found via https://github.com/google/filament/pull/7236/files#diff-5fe2be2d109db1d5040bc1d96d167c58db4db1b4793816d3a1e90a5d5304af2cR260
    const patchedSource = `
// 0: Default, 1: Golden, 2: Punchy
#define AGX_LOOK 0        

vec3 userSlope = vec3(1.0);
vec3 userOffset = vec3(0.0);
vec3 userPower = vec3(1.0);
float userSaturation = 1.0;

// Mean error^2: 3.6705141e-06
vec3 _agxDefaultContrastApprox(vec3 x) {
    vec3 x2 = x * x;
    vec3 x4 = x2 * x2;
    
    return  + 15.5     * x4 * x2
            - 40.14    * x4 * x
            + 31.96    * x4
            - 6.868    * x2 * x
            + 0.4298   * x2
            + 0.1191   * x
            - 0.00232;
}

vec3 _agx(vec3 val) {
    const mat3 agx_mat = mat3(
        0.842479062253094, 0.0423282422610123, 0.0423756549057051,
        0.0784335999999992,  0.878468636469772,  0.0784336,
        0.0792237451477643, 0.0791661274605434, 0.879142973793104);
    
    const float min_ev = -12.47393;
    const float max_ev = 4.026069;

    // val = pow(val, vec3(2.2)); 

    // Input transform (inset)
    val = agx_mat * val;
    
    // Log2 space encoding
    val = clamp(log2(val), min_ev, max_ev);
    val = (val - min_ev) / (max_ev - min_ev);
    
    // Apply sigmoid function approximation
    val = _agxDefaultContrastApprox(val);

    return val;
}

vec3 _agxEotf(vec3 val) {
    const mat3 agx_mat_inv = mat3(
        1.19687900512017, -0.0528968517574562, -0.0529716355144438,
        -0.0980208811401368, 1.15190312990417, -0.0980434501171241,
        -0.0990297440797205, -0.0989611768448433, 1.15107367264116);
        
    // Inverse input transform (outset)
    val = agx_mat_inv * val;
    
    // sRGB IEC 61966-2-1 2.2 Exponent Reference EOTF Display
    // NOTE: We're linearizing the output here. Comment/adjust when
    // *not* using a sRGB render target
    val = pow(val, vec3(2.2)); 

    return val;
}

vec3 _agxLook(vec3 val) {
    const vec3 lw = vec3(0.2126, 0.7152, 0.0722);
    float luma = dot(val, lw);
    
    // Default
    vec3 offset = vec3(0.0);
    vec3 slope = vec3(1.0);
    vec3 power = vec3(1.0);
    float sat = 1.0;
    
    #if AGX_LOOK == 1
    // Golden
    slope = vec3(1.0, 0.9, 0.5);
    power = vec3(0.8);
    sat = 0.8;
    #elif AGX_LOOK == 2
    // Punchy
    slope = vec3(1.0);
    power = vec3(1.35, 1.35, 1.35);
    sat = 1.4;
    #endif        
    
    // Needle
    slope = vec3(1.05);
    power = vec3(1.10, 1.10, 1.10);
    sat = 1.15;

    // User
    // slope = userSlope;
    // offset = userOffset;
    // power = userPower;
    // sat = userSaturation;
    
    // ASC CDL
    val = pow(val * slope + offset, power);
    return luma + sat * (val - luma);
}


vec3 AgXToneMapping( vec3 color ) {
    // apply AGX
    color *= toneMappingExposure;
    color = max(color, vec3(0.001)); // Prevent NaN
    color = _agx(color);
    color = _agxLook(color); // Optional
    color = _agxEotf(color);
    return color;
`;

    // Markers delimiting the part of the existing function that gets replaced.
    const signature = "vec3 AgXToneMapping( vec3 color ) {";
    const lastStatement = "return color;";

    const chunk = ShaderChunk.tonemapping_pars_fragment;
    const from = chunk.indexOf(signature);
    const lastStatementAt = chunk.indexOf(lastStatement, from);

    if (from < 0 || lastStatementAt < 0) {
        if (isDevEnvironment()) {
            console.error("Couldn't find AgXToneMapping in ShaderChunk.tonemapping_pars_fragment");
        }
        return;
    }

    // Splice: everything before the old function + new source + the untouched remainder
    // (which begins right after the old `return color;`).
    const to = lastStatementAt + lastStatement.length;
    ShaderChunk.tonemapping_pars_fragment = chunk.slice(0, from) + patchedSource + chunk.slice(to);
}


/** Tone mapping names recognised by `nameToThreeTonemapping` (compared case-insensitively there). */
type TonemappingName = "none" | "neutral" | "aces" | "agx" | "khronos_neutral";

/**
 * Resolves a tone mapping name to the corresponding three.js `ToneMapping` constant.
 * The comparison ignores letter case.
 * @param str Tone mapping name; any non-string value is accepted and yields `undefined`.
 * @returns The matching constant, or `undefined` when `str` is not a string or names
 * no known mode. In the latter case a warning is logged.
 */
export function nameToThreeTonemapping(str: null | undefined | TonemappingName | ({} & string)): undefined | ToneMapping {
    if (typeof str !== "string") {
        return undefined;
    }

    const mode = str.toLowerCase();

    if (mode === "none") return NoToneMapping;
    // Both spellings select the same constant.
    if (mode === "neutral" || mode === "khronos_neutral") return NeutralToneMapping;
    if (mode === "aces") return ACESFilmicToneMapping;
    if (mode === "agx") return AgXToneMapping;

    console.warn("[PostProcessing] Unknown tone mapping mode", mode);
    return undefined;
}