// Applies the inverse rotation of `matrix` to a direction and renormalises it.
// Multiplying with the vector on the left is the same as multiplying by the
// transposed matrix, which equals the inverse only for an orthogonal matrix.
// w = 0.0 keeps the translation column out of the result.
vec3 inverseTransformDirection(in vec3 normal, in mat4 matrix) {
    vec4 rotated = vec4(normal, 0.0) * matrix;

    return normalize(rotated.xyz);
}


// Default precision for unsigned-integer samplers.
// NOTE(review): no usampler2D / usampler3D is declared in this part of the file;
// presumably they are used by code outside this section.
precision highp usampler2D;
precision highp usampler3D;

// Probe positions: one texel per probe, addressed row-major with 256 texels per
// row (see lpv_index_to_256_coordinate); the .rgb channels are read as the position.
uniform sampler2D lpv_t_probe_positions;
// Probe depth atlas. The code below addresses it as 4096 texels wide, split into
// square tiles of lpv_u_probe_depth_resolution texels, one tile per probe, laid
// out row-major. Each tile is indexed with octahedral coordinates of a direction.
uniform sampler2D lpv_t_probe_depth;

// Edge length, in texels, of a single probe tile inside lpv_t_probe_depth.
// It is used as a divisor, so it must not be 0.
uniform uint lpv_u_probe_depth_resolution;

// NOTE(review): neither constant is referenced in this part of the file.
// 1073741823u == 0x3FFFFFFFu (2^30 - 1); presumably a sentinel index - confirm
// against the code that uses it.
#define SEARCH_STEP_LIMIT 64u
#define INVALID_TET 1073741823u

// Converts a linear index into integer texel coordinates for a texture that is
// addressed row-major with 256 texels per row: x = index mod 256, y = index div 256.
ivec2 lpv_index_to_256_coordinate(uint index) {
    const uint ROW_LENGTH = 256u;

    return ivec2(int(index % ROW_LENGTH), int(index / ROW_LENGTH));
}

// Sign of `x` that never evaluates to zero:
// 1.0 when x >= 0.0 (zero included), -1.0 for every other input.
float sign_not_zero(float x) {
    if (x >= 0.0) {
        return 1.0;
    }

    return -1.0;
}

// Component-wise variant of the scalar sign_not_zero: every component of the
// result is either 1.0 or -1.0.
vec2 sign_not_zero(vec2 x) {
    vec2 signs;

    signs.x = sign_not_zero(x.x);
    signs.y = sign_not_zero(x.y);

    return signs;
}

// Component-wise variant of the scalar sign_not_zero: every component of the
// result is either 1.0 or -1.0.
vec3 sign_not_zero(vec3 x) {
    vec3 signs;

    signs.x = sign_not_zero(x.x);
    signs.y = sign_not_zero(x.y);
    signs.z = sign_not_zero(x.z);

    return signs;
}

// Blend weights for a point inside a unit grid cell that is split into two
// triangles along the (0,0)-(1,1) diagonal.
//
// coords - fractional position inside the cell, each component in [0, 1)
//
// returns
//   x - weight of the (0, 0) corner
//   y - weight of the off-diagonal corner of the triangle containing the point
//   z - weight of the (1, 1) corner
//   w - ceil(coords.x - coords.y): 1.0 when coords.x > coords.y, zero otherwise.
//       The depth samplers in this file use it to pick the off-diagonal corner:
//       (1, 0) when it is 1.0, (0, 1) when it is zero.
//
// x + y + z = (1 - max) + (max - min) + min = 1, so no renormalisation is needed.
vec4 quadBlendWieghts(vec2 coords)
{
    vec2 remainder = 1.0 - coords;
    float diagonal_offset = coords.x - coords.y;

    return vec4(
        min(remainder.x, remainder.y),
        abs(diagonal_offset),
        min(coords.x, coords.y),
        ceil(diagonal_offset)
    );
}

//vec2 VecToSphereOct(vec3 v)
//{
//    float l1norm = abs(v.x) + abs(v.y) + abs(v.z);
//
//    vec2 result = v.xz / l1norm;
//
//    if (v.y < 0.0) {
//        result = (1.0 - abs(result.yx)) * sign_not_zero(result.xy);
//    }
//
//    return result;
//}

// Maps an octahedral coordinate in [0, 1]^2 to a point on the unit octahedron
// (|x| + |y| + |z| = 1). The result is NOT normalised to unit length.
//
// The input is first remapped to [-1, 1]^2 and used as the x/z components; y is
// whatever is left of the L1 budget. A negative y marks the lower hemisphere,
// where x/z are folded back over the diagonals.
vec3 OctaSphereEnc(vec2 coord)
{
    vec2 planar = 2.0 * (coord - 0.5);
    vec2 magnitude = abs(planar);

    float height = 1.0 - magnitude.x - magnitude.y;

    if (height < 0.0)
    {
        // lower hemisphere: swap the magnitudes and keep the original signs
        planar = sign(planar) * vec2(1.0 - magnitude.y, 1.0 - magnitude.x);
    }

    return vec3(planar.x, height, planar.y);
}

// Octahedral projection of a direction: returns 2D coordinates in [-1, 1]^2.
//
// The vector is scaled onto the unit octahedron (division by its L1 norm) and
// its x/z components are taken. Directions with negative y are folded outwards
// over the diagonals so that the lower hemisphere fills the corners of the square.
//
// v does not have to be normalised, but it must not be the zero vector: there
// is no guard against the division by a zero L1 norm.
vec2 VecToSphereOct(vec3 v)
{
    vec3 magnitude = abs(v);
    float manhattan_length = magnitude.x + magnitude.y + magnitude.z;

    vec2 projected = v.xz / manhattan_length;

    if (v.y < 0.0) {
        vec2 folded = 1.0 - abs(projected.yx);

        projected = folded * sign_not_zero(projected.xy);
    }

    return projected;
}

// Weighted sum of the red channel of three samples taken at mip level 0.
// uv0 / uv1 / uv2 are paired with weights.x / weights.y / weights.z;
// weights.w is not used here.
float SampleBlended(sampler2D tex, vec2 uv0, vec2 uv1, vec2 uv2, vec4 weights) {

    float blended = textureLod(tex, uv0, 0.0).r * weights.x;

    blended += textureLod(tex, uv1, 0.0).r * weights.y;
    blended += textureLod(tex, uv2, 0.0).r * weights.z;

    return blended;
}

// Depth stored for a probe in the given direction, blended from the three
// texels of the grid triangle that contains the direction's octahedral
// coordinate.
//
// probe_index - index of the probe; selects the tile inside lpv_t_probe_depth
// direction   - lookup direction, must be non-zero
// returns     - weighted sum of the red channel of the three texels
float lpv_probe_getDepth(uint probe_index, vec3 direction) {
    // the atlas is addressed as 4096 texels wide, tiles are laid out row-major
    const uint ATLAS_WIDTH = 4096u;

    uint tile_size = lpv_u_probe_depth_resolution;
    uint tiles_per_row = ATLAS_WIDTH / tile_size;

    uvec2 tile = uvec2(probe_index % tiles_per_row, probe_index / tiles_per_row);
    vec2 tile_origin = vec2(tile * tile_size);

    // direction -> [0, 1]^2 -> texel-centre grid of the tile: [0, tile_size - 1]
    vec2 uv = clamp(VecToSphereOct(direction) * 0.5 + 0.5, 0.0, 1.0);
    vec2 cell_position = uv * float(tile_size - 1u);

    vec2 cell = floor(cell_position);
    vec4 weights = quadBlendWieghts(fract(cell_position));

    // weights.w chooses which off-diagonal corner completes the triangle
    vec2 off_diagonal = mix(vec2(0.0, 1.0), vec2(1.0, 0.0), weights.w);

    // triangle corners in atlas texel space
    vec2 corner0 = cell + tile_origin;
    vec2 corner1 = cell + off_diagonal + tile_origin;
    vec2 corner2 = cell + vec2(1.0, 1.0) + tile_origin;

    // texel centres as normalised atlas coordinates
    float atlas_extent = float(ATLAS_WIDTH);

    return SampleBlended(
        lpv_t_probe_depth,
        (corner0 + 0.5) / atlas_extent,
        (corner1 + 0.5) / atlas_extent,
        (corner2 + 0.5) / atlas_extent,
        weights
    );
}

// Bilinear interpolation of four scalars.
// v00 -> v01 and v10 -> v11 are blended along fraction.x, the two results are
// then blended along fraction.y. The second digit of a name is therefore the x
// step and the first digit the y step.
float lpv_bilinear_lerp(float v00, float v01, float v10, float v11, vec2 fraction) {
    return mix(
        mix(v00, v01, fraction.x),
        mix(v10, v11, fraction.x),
        fraction.y
    );
}

// Bilinear interpolation of four 2-component values, applied per component.
// v00 -> v01 and v10 -> v11 are blended along fraction.x, the two results are
// then blended along fraction.y.
vec2 lpv_bilinear_lerp(vec2 v00, vec2 v01, vec2 v10, vec2 v11, vec2 fraction) {
    return mix(
        mix(v00, v01, fraction.x),
        mix(v10, v11, fraction.x),
        fraction.y
    );
}

// Manual bilinear filter over the 2x2 texel block whose lowest corner is
// texel_position, reading the red channel at mip level 0.
//
// returns vec2(interpolated value, interpolated squared value)
//
// No bounds handling is done: texel_position + 1 must still be a valid texel.
vec2 lpv_sample_bilinear(sampler2D tex, ivec2 texel_position, vec2 fraction) {

    float d00 = texelFetch(tex, texel_position, 0).r;
    float d01 = texelFetch(tex, texel_position + ivec2(1, 0), 0).r;
    float d10 = texelFetch(tex, texel_position + ivec2(0, 1), 0).r;
    float d11 = texelFetch(tex, texel_position + ivec2(1, 1), 0).r;

    // interpolate the value and its square side by side
    return lpv_bilinear_lerp(
        vec2(d00, d00 * d00), vec2(d01, d01 * d01),
        vec2(d10, d10 * d10), vec2(d11, d11 * d11),
        fraction
    );
}

// Pass-through variant of the octahedral wrap: returns texelCoord unchanged.
// texture_size is ignored; the parameter only keeps the signature identical to
// the other wrapOctahedralTexelCoordinates* functions so they can be swapped.
ivec2 wrapOctahedralTexelCoordinatesNone(in ivec2 texelCoord, in int texture_size) {
    return texelCoord;
}

// Earlier wrap variant. It only reacts to coordinates past the upper bound;
// negative inputs are returned as they are.
//
//  - x beyond the last texel is pinned to the last texel and y moves down one row
//  - y beyond the last texel (possibly as a result of the step above) becomes
//    (texture_size - 1) - y, which is a NEGATIVE value
//
// NOTE(review): because of the negative result this does not produce a valid
// in-tile coordinate for y overflow. Nothing in this part of the file calls it.
ivec2 wrapOctahedralTexelCoordinatesV0(in ivec2 texelCoord, in int texture_size) {
    int last = texture_size - 1;

    int x = texelCoord.x;
    int y = texelCoord.y;

    if (x > last) {
        x = last;
        y += 1;
    }

    if (y > last) {
        y = last - y;
    }

    return ivec2(x, y);
}

// Maps a texel coordinate that may lie outside a square octahedral tile onto
// the texel inside the tile that represents the same direction.
//
// The plane is treated as a grid of tile copies. Inside a copy the coordinate
// is taken modulo texture_size. When the number of tile borders crossed along x
// and along y differ in parity, the position is additionally mirrored through
// the tile centre (p -> texture_size - 1 - p on both axes). As a result:
//   - stepping one texel over an edge lands on the same edge, reflected about
//     the middle of that edge
//   - stepping diagonally over a corner lands on the opposite corner
//
// texel        - coordinate relative to the tile origin, any sign
// texture_size - tile edge length in texels, must be > 0
// returns      - coordinate with both components in [0, texture_size - 1]
ivec2 wrapOctahedralTexelCoordinates(const in ivec2 texel, const in int texture_size) {
    bvec2 is_negative = lessThan(texel, ivec2(0));

    // Fold negative inputs onto the non-negative range (-1 -> 0, -2 -> 1, ...)
    // so that `/` and `%` never see a negative operand; GLSL ES leaves the
    // result of `%` undefined in that case.
    ivec2 folded = ivec2(
        is_negative.x ? -(texel.x + 1) : texel.x,
        is_negative.y ? -(texel.y + 1) : texel.y
    );

    ivec2 tile = folded / texture_size;
    ivec2 offset = folded % texture_size;

    // Position inside the tile copy. The fold reversed the direction of travel
    // for negative inputs, so the offset is counted from the far end there.
    ivec2 wrapped = ivec2(
        is_negative.x ? texture_size - 1 - offset.x : offset.x,
        is_negative.y ? texture_size - 1 - offset.y : offset.y
    );

    // Number of tile borders crossed on each axis: |floor(texel / texture_size)|
    ivec2 crossings = tile + ivec2(is_negative);

    if (((crossings.x ^ crossings.y) & 1) != 0) {
        return ivec2(texture_size - 1) - wrapped;
    }

    return wrapped;
}


// Bilinearly filtered lookup into a probe's octahedral depth tile.
//
// The four texels surrounding the direction's octahedral coordinate are fetched
// with texelFetch. Neighbours that fall outside the tile are redirected with
// wrapOctahedralTexelCoordinates, so filtering never reads another probe's tile.
//
// probe_index - index of the probe; selects the tile inside lpv_t_probe_depth
// direction   - lookup direction, must be non-zero
// returns     - interpolated .rg channels of the depth atlas
vec2 lpv_probe_getDepthBilinear(uint probe_index, vec3 direction) {
    // the atlas is addressed as 4096 texels wide, tiles are laid out row-major
    const int ATLAS_WIDTH = 4096;

    int tile_size = int(lpv_u_probe_depth_resolution);
    int tiles_per_row = ATLAS_WIDTH / tile_size;

    int probe = int(probe_index);
    ivec2 tile_origin = ivec2(probe % tiles_per_row, probe / tiles_per_row) * tile_size;

    // direction -> [0, 1]^2 -> tile texel space; the half texel shift puts texel
    // centres on integer positions, so floor() gives the lower neighbour (which
    // can be -1) and fract() the blend factors
    vec2 uv = clamp(VecToSphereOct(direction) * 0.5 + 0.5, 0.0, 1.0);
    vec2 sample_position = uv * float(tile_size) - 0.5;

    ivec2 base = ivec2(floor(sample_position));

    ivec2 p00 = wrapOctahedralTexelCoordinates(base, tile_size);
    ivec2 p01 = wrapOctahedralTexelCoordinates(base + ivec2(1, 0), tile_size);
    ivec2 p10 = wrapOctahedralTexelCoordinates(base + ivec2(0, 1), tile_size);
    ivec2 p11 = wrapOctahedralTexelCoordinates(base + ivec2(1, 1), tile_size);

    vec2 d00 = texelFetch(lpv_t_probe_depth, tile_origin + p00, 0).rg;
    vec2 d01 = texelFetch(lpv_t_probe_depth, tile_origin + p01, 0).rg;
    vec2 d10 = texelFetch(lpv_t_probe_depth, tile_origin + p10, 0).rg;
    vec2 d11 = texelFetch(lpv_t_probe_depth, tile_origin + p11, 0).rg;

    return lpv_bilinear_lerp(d00, d01, d10, d11, fract(sample_position));
}

// Depth lookup for a probe using a 3-texel triangular blend: the grid cell that
// contains the direction's octahedral coordinate is split along its diagonal
// and the three corners of the containing triangle are blended.
//
// probe_index - index of the probe; selects the tile inside lpv_t_probe_depth
// direction   - lookup direction, must be non-zero
// returns     - vec2(blended depth, blended squared depth), both taken from the
//               red channel of the atlas
vec2 lpv_probe_getDepthTriangular(uint probe_index, vec3 direction) {
    // get offset
    uint depth_tile_resolution = lpv_u_probe_depth_resolution;
    uvec2 atlas_size = uvec2(4096u);

    uint tiles_per_row = atlas_size.x / depth_tile_resolution;

    uint tile_x = probe_index % tiles_per_row;
    uint tile_y = probe_index / tiles_per_row;

    vec2 tile_offset = vec2(
        tile_x * depth_tile_resolution,
        tile_y * depth_tile_resolution
    );

    // convert direction to UV
    vec2 octahedral_uv = clamp(VecToSphereOct(direction) * 0.5 + 0.5, 0.0, 1.0);
    vec2 grid = octahedral_uv * vec2(depth_tile_resolution - 1u);

    vec2 gridFrac = fract(grid);
    vec2 gridFloor = floor(grid);

    vec4 weights = quadBlendWieghts(gridFrac);

    // When the direction maps exactly onto the last row or column of the tile,
    // gridFloor already is the last texel and the "+1" corners fall outside the
    // tile - and, for tiles on the atlas border, outside the texture, where the
    // result of texelFetch is undefined. Those corners always carry a weight of
    // zero, so clamping them to the tile leaves the blend unchanged.
    vec2 last_texel = vec2(depth_tile_resolution - 1u);

    //3 nearest frames
    vec2 frame0 = gridFloor;
    vec2 frame1 = min(gridFloor + mix(vec2(0, 1), vec2(1, 0), weights.w), last_texel);
    vec2 frame2 = min(gridFloor + vec2(1.0, 1.0), last_texel);

    // move frames to atlas space
    frame0 += tile_offset;
    frame1 += tile_offset;
    frame2 += tile_offset;

    float samp0 = texelFetch(lpv_t_probe_depth, ivec2(frame0), 0).r;
    float samp1 = texelFetch(lpv_t_probe_depth, ivec2(frame1), 0).r;
    float samp2 = texelFetch(lpv_t_probe_depth, ivec2(frame2), 0).r;

    // weighted depth per corner
    float d0 = samp0 * weights.x;
    float d1 = samp1 * weights.y;
    float d2 = samp2 * weights.z;

    float mean = d0 + d1 + d2;
    // same weights applied to the squared depth
    float mean2 = samp0 * d0 + samp1 * d1 + samp2 * d2;

    return vec2(mean, mean2);
}
// Debug visualisation helper: returns a colour for a probe and direction.
//
// As written, the function always leaves through the first `return` below.
// Everything after it is unreachable and only kept as alternative
// visualisations that can be enabled by moving or removing the earlier return.
//
// NOTE(review): unlike lpv_probe_getDepthBilinear, `grid` is not shifted by half
// a texel here.
vec3 lpv_DEBUG(uint probe_index, vec3 direction) {
    // get offset
    int depth_tile_resolution = int(lpv_u_probe_depth_resolution);
    const ivec2 atlas_size = ivec2(4096);

    int tiles_per_row = atlas_size.x / depth_tile_resolution;

    int tile_x = int(probe_index) % tiles_per_row;
    int tile_y = int(probe_index) / tiles_per_row;

    ivec2 tile_offset = ivec2(
        tile_x * depth_tile_resolution,
        tile_y * depth_tile_resolution
    );

    // convert direction to UV
    vec2 octahedral_uv = clamp(VecToSphereOct(direction) * 0.5 + 0.5, 0.0, 1.0);
    vec2 grid = octahedral_uv * vec2(depth_tile_resolution);

    vec2 gridFrac = fract(grid);

    ivec2 texel_position = ivec2(floor(grid));

    // neighbours of the base texel, redirected back into the tile
    ivec2 tile_p_01 = wrapOctahedralTexelCoordinates(texel_position + ivec2(1, 0), depth_tile_resolution);
    ivec2 tile_p_10 = wrapOctahedralTexelCoordinates(texel_position + ivec2(0, 1), depth_tile_resolution);
    ivec2 tile_p_11 = wrapOctahedralTexelCoordinates(texel_position + ivec2(1, 1), depth_tile_resolution);

    // Active output. The blend factor of 0.0 selects the first argument, i.e.
    // red / green flag a wrapped neighbour whose x / y equals the tile size.
    // Presumably that never happens while the wrap keeps its results inside
    // [0, tile size - 1], which would make this a "wrap is broken" indicator.
    return mix(vec3(
                   tile_p_01.x == depth_tile_resolution ? 1.0 : 0.0,
                   tile_p_10.y == depth_tile_resolution ? 1.0 : 0.0,
                   0.0
               ), vec3(octahedral_uv, 0.0), 0.0);

    // ---- unreachable from here on ----

    // NOTE(review): the four depth fetches below are not used by any of the
    // following statements.
    float texel_00 = texelFetch(lpv_t_probe_depth, tile_offset + texel_position, 0).r;
    float texel_01 = texelFetch(lpv_t_probe_depth, tile_offset + tile_p_01, 0).r;
    float texel_10 = texelFetch(lpv_t_probe_depth, tile_offset + tile_p_10, 0).r;
    float texel_11 = texelFetch(lpv_t_probe_depth, tile_offset + tile_p_11, 0).r;


    // octahedron points of the four texels
    vec3 dir_00 = OctaSphereEnc(vec2(texel_position) / float(depth_tile_resolution));
    vec3 dir_01 = OctaSphereEnc(vec2(tile_p_01) / float(depth_tile_resolution));
    vec3 dir_10 = OctaSphereEnc(vec2(tile_p_10) / float(depth_tile_resolution));
    vec3 dir_11 = OctaSphereEnc(vec2(tile_p_11) / float(depth_tile_resolution));


    // bilinear blend of those points per component, remapped from [-1, 1] to [0, 1]
    vec3 dir = vec3(
        lpv_bilinear_lerp(
            dir_00.x, dir_01.x,
            dir_10.x, dir_11.x,
            gridFrac
        ), lpv_bilinear_lerp(
            dir_00.y, dir_01.y,
            dir_10.y, dir_11.y,
            gridFrac
        ), lpv_bilinear_lerp(
            dir_00.z, dir_01.z,
            dir_10.z, dir_11.z,
            gridFrac
        )
    ) * 0.5 + 0.5;

    return dir;

    return vec3(octahedral_uv, 1.0);
}


// Position of a probe: the .rgb channels of the probe's texel in
// lpv_t_probe_positions (256 probes per texture row).
vec3 lpv_probe_getPosition(uint probe_index) {
    ivec2 texel = lpv_index_to_256_coordinate(probe_index);

    return texelFetch(lpv_t_probe_positions, texel, 0).rgb;
}

// Interpolated surface normal. main() renormalises it and rotates it with the
// inverse of viewMatrix - presumably it arrives in view space; confirm against
// the vertex shader.
varying vec3 vNormal;
// Per-primitive (non-interpolated) index; main() uses it as the probe index
// for the depth lookup.
flat varying uint instance_index;

// Probe depth visualisation: shades the fragment with the depth its probe
// stores in the direction of the surface normal, as a grey value.
void main() {
    // interpolation does not preserve length, so renormalise first
    vec3 surface_normal = normalize(vNormal);
    vec3 lookup_direction = normalize(inverseTransformDirection(surface_normal, viewMatrix));

    // Alternative outputs for debugging:
    //   lpv_probe_getDepthTriangular(instance_index, lookup_direction).r / 7.0
    //   lpv_DEBUG(instance_index, lookup_direction) used directly as RGB
    vec2 stored_depth = lpv_probe_getDepthBilinear(instance_index, lookup_direction);

    // NOTE(review): 7.0 scales the depth for display; presumably the expected
    // maximum depth - confirm.
    float shade = stored_depth.r / 7.0;

    gl_FragColor = vec4(shade, shade, shade, 1.0);
}