import { BlendFunction, Effect } from "postprocessing";
import { Uniform } from "three";

import { serializable } from "../../../engine/engine_serialization.js";
import { PostProcessingEffect } from "../PostProcessingEffect.js";

/**
 * Post-processing component that exposes the sharpening shader effect.
 *
 * The component lazily owns a single {@link _SharpeningEffect} instance and
 * forwards its serialized `amount` and `radius` properties to the uniforms of
 * the same name on that effect.
 *
 * @category Effects
 * @group Components
 */
export class SharpeningEffect extends PostProcessingEffect {

    /** Name under which this effect type is identified. */
    get typeName() {
        return "Sharpening";
    }

    /** Backing effect instance; stays undefined until something first asks for it. */
    private _effect?: _SharpeningEffect;

    /** Returns the (lazily created) effect instance owned by this component. */
    onCreateEffect() {
        return this.effect;
    }

    /**
     * Lazily creates the underlying effect on first access so the property
     * setters below can write to its uniforms even before `onCreateEffect` ran.
     */
    private get effect() {
        if (this._effect == null) {
            this._effect = new _SharpeningEffect();
        }
        return this._effect;
    }

    /** Value of the shader's `amount` uniform. */
    @serializable()
    set amount(value: number) {
        const amountUniform = this.effect.uniforms.get("amount")!;
        amountUniform.value = value;
    }
    get amount() {
        const amountUniform = this.effect.uniforms.get("amount")!;
        return amountUniform.value;
    }

    /** Value of the shader's `radius` uniform. */
    @serializable()
    set radius(value: number) {
        const radiusUniform = this.effect.uniforms.get("radius")!;
        radiusUniform.value = value;
    }
    get radius() {
        const radiusUniform = this.effect.uniforms.get("radius")!;
        return radiusUniform.value;
    }

    // The `threshold` property is currently disabled:
    // @serializable()
    // set threshold(value: number) {
    //     this.effect.uniforms.get("threshold")!.value = value;
    // }
    // get threshold() {
    //     return this.effect.uniforms.get("threshold")!.value;
    // }

}


/**
 * GLSL vertex shader snippet that is handed to the effect as its
 * `vertexShader` option.
 *
 * It defines `mainSupport()`, which copies `uv` into `vUv` and sets
 * `gl_Position` to `projectionMatrix * modelViewMatrix * vec4(position, 1.0)`.
 *
 * NOTE(review): `vUv`, `uv`, `position` and both matrices are not declared in
 * this snippet — presumably they come from the shader template the snippet is
 * merged into; confirm against the postprocessing version in use.
 */
const vert = `
  void mainSupport() {
    vUv = uv;
    gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0);
  }
`

/**
 * GLSL fragment shader source for the sharpening effect.
 *
 * `mainImage` builds a weighted average (`blurred`) of the texels in a square
 * window of `radius` texels around the current pixel, each sample weighted by
 * `exp(-length(offset) * amount)`. The positive part of
 * `inputColor - blurred`, scaled by `amount`, is then added to `inputColor`.
 * Alpha is copied from `inputColor` and the result is clamped to [0, 1].
 *
 * The neighbourhood is read from `inputBuffer`, the input texture that
 * postprocessing makes available to effect fragment shaders. The shader must
 * not declare a sampler of its own for this: only the uniforms registered on
 * the effect (`amount`, `radius`) are bound, so a custom sampler would stay
 * unbound and only work if it happened to share a texture unit with the input.
 *
 * `resolution` is not declared here either; it is expected to be supplied by
 * the surrounding postprocessing shader template.
 */
const frag = `
uniform float amount;
uniform float radius;

void mainImage(const in vec4 inputColor, const in vec2 uv, out vec4 outputColor) {
    float tx = 1.0 / resolution.x;
    float ty = 1.0 / resolution.y;
    vec2 texelSize = vec2(tx, ty);

    vec4 blurred = vec4(0.0);
    float total = 0.0;

    for (float x = -radius; x <= radius; x++) {
        for (float y = -radius; y <= radius; y++) {
            vec2 offset = vec2(x, y) * texelSize;
            vec4 diffuse = texture2D(inputBuffer, uv + offset);
            float weight = exp(-length(offset) * amount);
            blurred += diffuse * weight;
            total += weight;
        }
    }

    if (total > 0.0) {
        blurred /= total;
    }

    // Calculate the sharpened color using inputColor
    vec4 sharp = inputColor + clamp(inputColor - blurred, 0.0, 1.0) * amount;
    // Keep original alpha
    sharp.a = inputColor.a;

    // Ensure the sharp color does not go below 0 or above 1
    // This means: sharpening must happen AFTER tonemapping.
    sharp = clamp(sharp, 0.0, 1.0);

    outputColor = sharp;
}

`;

/**
 * postprocessing `Effect` that bundles the sharpening shaders.
 *
 * It is registered under the name "Sharpening", uses `frag` as fragment shader
 * and `vert` as vertex shader, blends with `BlendFunction.NORMAL`, and exposes
 * two uniforms, `amount` and `radius`, both starting at 1.
 *
 * NOTE(review): the fragment shader reads neighbouring texels of the input;
 * postprocessing documents `EffectAttribute.CONVOLUTION` for effects that do
 * so — confirm whether that attribute should be set here.
 */
export class _SharpeningEffect extends Effect {
    constructor() {
        // Uniforms are looked up by these exact names via `uniforms.get(...)`.
        const uniforms = new Map<string, Uniform<any>>();
        uniforms.set("amount", new Uniform(1));
        uniforms.set("radius", new Uniform(1));
        // uniforms.set("threshold", new Uniform(0));

        super("Sharpening", frag, {
            vertexShader: vert,
            blendFunction: BlendFunction.NORMAL,
            uniforms,
        });
    }
}
