//
// Postprocessing compute shader for the face mesh pipeline
//

#pragma kernel FaceTransformation
#pragma kernel EyesToFaceRefinement
#pragma kernel LowPassFilter

//
// Kernel 0 - FaceTransformation
//
// Applies a transformation to the face vertices and determines the bounding
// box of the transformed face mesh.
//

// Number of face vertices handled by the FaceTransformation kernel.
#define FX_VERTEX_COUNT 468
// Thread group size used in the kernel's numthreads attribute.
#define FX_THREAD_COUNT 52
// Consecutive vertices processed by each thread (468 / 52 = 9, no remainder).
#define FX_VERTEX_PER_THREAD (FX_VERTEX_COUNT / FX_THREAD_COUNT)

// Source vertices; the kernel reads only .xyz and substitutes w = 1.
StructuredBuffer<float4> _fx_input;
// Matrix applied to each vertex as mul(_fx_xform, float4(xyz, 1)).
float4x4 _fx_xform;

// Transformed vertices, written at the same index as the source vertex.
RWStructuredBuffer<float4> _fx_output;
// Element 0 receives the bounding box packed as (min.x, min.y, max.x, max.y).
RWStructuredBuffer<float4> _fx_bbox;

// One partial bounding box per thread; thread 0 folds them into the result.
groupshared float4 fx_bbox_agg[FX_THREAD_COUNT];

// Transforms FX_VERTEX_PER_THREAD consecutive vertices per thread, then has
// thread 0 reduce the per-thread XY bounds into _fx_bbox[0], packed as
// (min.x, min.y, max.x, max.y).
//
// `id` indexes both the vertex range and the groupshared array, so the kernel
// is only valid when dispatched as a single thread group.
[numthreads(FX_THREAD_COUNT, 1, 1)]
void FaceTransformation(uint id : SV_DispatchThreadID)
{
    uint ref = id * FX_VERTEX_PER_THREAD;

    // Min starts at 1 and max at 0, so the box is only tight for points that
    // fall inside [0, 1] -- presumably the transformed coordinates are
    // normalized; confirm against the caller.
    float4 bbox = float4(1, 1, 0, 0);

    // Transformation
    for (uint i = 0; i < FX_VERTEX_PER_THREAD; i++, ref++)
    {
        float4 p = mul(_fx_xform, float4(_fx_input[ref].xyz, 1));
        bbox = float4(min(bbox.xy, p.xy), max(bbox.zw, p.xy));
        _fx_output[ref] = p;
    }

    // Per-thread bounding box
    fx_bbox_agg[id] = bbox;

    // Every thread must have published its partial box before thread 0 reads
    // the array. The barrier has to sit before the early return so that all
    // threads in the group reach it.
    GroupMemoryBarrierWithGroupSync();

    // Only the first thread continues.
    if (id > 0) return;

    // Bounding box aggregation
    bbox = fx_bbox_agg[0];

    // Uses its own counter instead of relying on the first loop's variable
    // leaking into function scope.
    for (uint t = 1; t < FX_THREAD_COUNT; t++)
    {
        bbox.xy = min(bbox.xy, fx_bbox_agg[t].xy);
        bbox.zw = max(bbox.zw, fx_bbox_agg[t].zw);
    }

    _fx_bbox[0] = bbox;
}

//
// Kernel 1 - EyesToFaceRefinement
//
// Overwrites the face vertices with the eye meshes.
//

// Number of vertices copied from each eye mesh; also the thread group size.
#define E2F_VERTEX_COUNT 71

// Destination face-vertex indices: entries [0, 71) are paired with
// _e2f_eye_l, entries [71, 142) with _e2f_eye_r.
StructuredBuffer<uint> _e2f_index_table;
// Eye vertices; the kernel reads .xy from elements [5, 76) of each buffer.
StructuredBuffer<float4> _e2f_eye_l;
StructuredBuffer<float4> _e2f_eye_r;
// Matrices applied to the corresponding eye vertex as mul(xform, (x, y, 0, 1)).
float4x4 _e2f_xform_l;
float4x4 _e2f_xform_r;

// Face vertices; only the .xy of the indexed elements is overwritten.
RWStructuredBuffer<float4> _e2f_face;

// One thread per eye vertex: looks up the destination face index for each eye
// in the index table, transforms the eye vertex's XY with that eye's matrix,
// and stores the result into the XY of the face vertex.
[numthreads(E2F_VERTEX_COUNT, 1, 1)]
void EyesToFaceRefinement(uint id : SV_DispatchThreadID)
{
    // The eye vertices used here start at element 5 of each eye buffer.
    const uint src = id + 5;

    // First eye: table entries [0, E2F_VERTEX_COUNT).
    const uint dst_l = _e2f_index_table[id];
    // Second eye: table entries that follow the first eye's.
    const uint dst_r = _e2f_index_table[id + E2F_VERTEX_COUNT];

    // Only XY of the source is used (z = 0, w = 1), and only XY of the
    // destination is replaced.
    const float4 pos_l = float4(_e2f_eye_l[src].xy, 0, 1);
    _e2f_face[dst_l].xy = mul(_e2f_xform_l, pos_l).xy;

    const float4 pos_r = float4(_e2f_eye_r[src].xy, 0, 1);
    _e2f_face[dst_r].xy = mul(_e2f_xform_r, pos_r).xy;
}

//
// Kernel 2 - LowPassFilter
//
// Low pass filter for vertex positions
//

// Number of filtered vertices; also the offset of the derivative half of
// _lpf_output.
#define LPF_VERTEX_COUNT 468

// New, unfiltered positions; only .xyz is read.
StructuredBuffer<float4> _lpf_input;
// Filter state, read and then rewritten by the kernel:
// [0, 468) filtered positions, [468, 936) filtered derivatives.
RWStructuredBuffer<float4> _lpf_output;

// Scales how much the position cutoff grows with the filtered derivative's
// magnitude.
float _lpf_beta;
// Position cutoff used when the filtered derivative is zero.
float _lpf_cutoff_min;
// Divides the position difference and scales the smoothing coefficient;
// presumably the time elapsed between samples -- confirm with the caller.
float _lpf_t_e;

// Smoothing coefficient for a given cutoff:
// alpha = r / (r + 1), where r = 2 * pi * cutoff * _lpf_t_e.
float lpf_alpha(float cutoff)
{
    const float scaled = 2 * 3.141592 * cutoff * _lpf_t_e;
    const float denom = scaled + 1;
    return scaled / denom;
}

// Filters one vertex per thread. The previous filtered position and derivative
// are read from _lpf_output, blended with the new input, and written back to
// the same two slots.
//
// The derivative is smoothed with a fixed cutoff of 1; the position is
// smoothed with a cutoff that rises with the smoothed derivative's magnitude.
[numthreads(52, 1, 1)]
void LowPassFilter(uint id : SV_DispatchThreadID)
{
    // The derivative state lives LPF_VERTEX_COUNT elements after the position.
    const uint deriv_slot = id + LPF_VERTEX_COUNT;

    // Current input and previous filter state
    const float3 pos_in = _lpf_input[id].xyz;
    const float3 pos_prev = _lpf_output[id].xyz;
    const float3 deriv_prev = _lpf_output[deriv_slot].xyz;

    // Finite-difference derivative, smoothed with cutoff = 1
    const float3 deriv_raw = (pos_in - pos_prev) / _lpf_t_e;
    const float3 deriv_new = lerp(deriv_prev, deriv_raw, lpf_alpha(1));

    // Position smoothed with a cutoff driven by the smoothed derivative
    const float pos_cutoff = _lpf_cutoff_min + _lpf_beta * length(deriv_new);
    const float3 pos_new = lerp(pos_prev, pos_in, lpf_alpha(pos_cutoff));

    // Store the updated state (w is always written as 1).
    _lpf_output[id] = float4(pos_new, 1);
    _lpf_output[deriv_slot] = float4(deriv_new, 1);
}
