// Bundle of per-imposter values.
// NOTE(review): this struct is not referenced in this section; the field
// meanings below are inferred from their names and should be confirmed
// against the vertex/fragment code that fills and reads it.
struct ImposterData
{
    vec2 uv;     // presumably the billboard / atlas UV
    vec2 grid;   // presumably the octahedral grid coordinate (cf. VectorToGrid)
    vec4 frame0; // presumably data for the first of the three blended frames
    vec4 frame1; // presumably data for the second blended frame
    vec4 frame2; // presumably data for the third blended frame
    vec4 vertex; // presumably the vertex position
};

// A ray in parametric form: points along it are Origin + t * Direction.
// VirtualPlaneUV intersects it with a plane in exactly this form and does not
// normalize Direction first.
struct Ray
{
    vec3 Origin;    // start point of the ray
    vec3 Direction; // direction of travel (not necessarily unit length)
};

// Hemisphere layout: decodes an octahedral grid coordinate into a direction
// with y up. The hemisphere grid is rotated 45 degrees relative to the
// octahedron's diamond, so the coordinate is first rotated back into the
// diamond |x| + |z| <= 1; the remaining L1 budget becomes the height.
// coord: grid position, nominally in [-1, 1] on each axis (which keeps the
//        height >= 0).
// Returns a direction with L1 length 1 (not unit L2 length in general).
vec3 OctaHemiEnc(vec2 coord)
{
    vec2 diamond = vec2(coord.x + coord.y, coord.x - coord.y) * 0.5;
    float height = 1.0 - dot(vec2(1.0, 1.0), abs(diamond));
    return vec3(diamond.x, height, diamond.y);
}
// Full-sphere layout: decodes an octahedral grid coordinate into a direction
// with y up. Points inside the central diamond (|x| + |y| <= 1) map to the
// upper half; points outside it map to the lower half and are folded back
// across the diamond edges.
// coord: grid position, nominally in [-1, 1] on each axis.
// Returns a direction that is not unit length in general.
vec3 OctaSphereEnc(vec2 coord)
{
    float height = 1.0 - dot(vec2(1.0), abs(coord));
    vec3 dir = vec3(coord.x, height, coord.y);

    if (height < 0.0)
    {
        // Per-component sign of the unfolded point (zero counts as positive).
        vec2 signs = vec2(dir.x >= 0.0 ? 1.0 : -1.0,
                          dir.z >= 0.0 ? 1.0 : -1.0);
        dir.xz = (1 - abs(dir.zx)) * signs;
    }
    return dir;
}

// Decodes an octahedral grid coordinate into a direction, choosing the
// full-sphere or hemisphere layout from the _ImposterFullSphere setting.
vec3 GridToVector(vec2 coord)
{
    if (_ImposterFullSphere)
    {
        return OctaSphereEnc(coord);
    }
    return OctaHemiEnc(coord);
}

// Hemisphere layout: encodes a direction (y up) into octahedral grid
// coordinates. For directions with y >= 0 this inverts OctaHemiEnc: the x/z
// components are L1-normalized and then rotated 45 degrees into grid space.
// The sign of y is not consulted.
vec2 VecToHemiOct(vec3 vec)
{
    float l1Length = dot(1.0, abs(vec));
    vec2 planar = vec.xz / l1Length;
    return vec2(planar.x + planar.y, planar.x - planar.y);
}

// Full-sphere layout: encodes a direction (y up) into octahedral grid
// coordinates; the inverse of OctaSphereEnc. The x/z components are
// L1-normalized; directions with y <= 0 are folded out toward the grid corners.
vec2 VecToSphereOct(vec3 vec)
{
    vec2 oct = vec.xz / dot(1, abs(vec));

    if (vec.y <= 0)
    {
        // Per-component sign (zero counts as positive).
        vec2 signs = vec2(oct.x >= 0 ? 1.0 : -1.0,
                          oct.y >= 0 ? 1.0 : -1.0);
        oct = (1 - abs(oct.yx)) * signs;
    }
    return oct;
}

// Encodes a direction into octahedral grid coordinates, choosing the
// full-sphere or hemisphere layout from the _ImposterFullSphere setting.
// In hemisphere mode, y is clamped to at least 0.001 (before normalizing), so
// directions at or below the horizon land just above it.
vec2 VectorToGrid(vec3 vec)
{
    if (_ImposterFullSphere)
    {
        return VecToSphereOct(vec);
    }

    vec3 aboveHorizon = normalize(vec3(vec.x, max(0.001, vec.y), vec.z));
    return VecToHemiOct(aboveHorizon);
}

// Computes a billboard vertex's offset from the imposter pivot, lying in the
// plane perpendicular to the pivot-to-camera ray (y is the up axis).
// pivotToCameraRayLocal: ray from the pivot toward the camera; need not be
//                        unit length, but must be non-zero.
// frames: frames per grid axis; coord * frames is remapped from [0, 1] to
//         [-1, 1] across the sprite.
// size:   full sprite width/height; the offset spans +/- size / 2.
// coord:  the vertex's UV. NOTE(review): assumed to lie in [0, 1 / frames] —
//         confirm against the mesh/atlas setup.
// Returns the offset vector (not an absolute position).
vec3 SpriteProjection(vec3 pivotToCameraRayLocal, float frames, vec2 size, vec2 coord){
    // Billboard basis: y faces the camera, x is horizontal, z completes it.
    vec3 y = normalize(pivotToCameraRayLocal);

    // cross(y, up) vanishes when the camera is straight above or below the
    // pivot, and normalizing a zero vector yields NaN, which collapses the
    // sprite. Fall back to a fixed horizontal axis so the basis stays finite.
    vec3 side = cross(y, vec3(0.0, 1.0, 0.0));
    if (dot(side, side) < 1e-12)
    {
        side = vec3(1.0, 0.0, 0.0);
    }
    vec3 x = normalize(side);
    vec3 z = normalize(cross(x, y));

    vec2 uv = ((coord * frames) - 0.5) * 2.0; // -1 to 1
    vec2 halfSize = size * 0.5;

    vec3 newX = x * uv.x * halfSize.x;
    vec3 newZ = z * uv.y * halfSize.y;

    return newX + newZ;
}

// Computes blend weights for the three frames around a point in the grid.
// Only the fractional part of uv is used. The cell's diagonal (uv.x == uv.y)
// selects which triangle of the cell the point lies in.
// Returns: x, y, z = weights for frame 0, 1, 2 (they sum to 1);
//          w = 1 when frac(uv).x > frac(uv).y, otherwise 0.
vec4 TriangleInterpolate(vec2 uv){
    vec2 cell = frac(uv);
    vec2 inv = vec2(1.0, 1.0) - cell;

    float weight0 = min(inv.x, inv.y);
    float weight1 = abs(cell.x - cell.y);
    float weight2 = min(cell.x, cell.y);
    // Within a cell |x - y| < 1, so ceil() gives 1 for x > y and 0 (or -0,
    // clamped by saturate) otherwise.
    float mask = saturate(ceil(cell.x - cell.y));

    return vec4(weight0, weight1, weight2, mask);
}

// Converts frame indices in the imposter grid into a unit direction —
// presumably the direction that frame was captured from.
// frame:              frame index along x and y, expected in [0, frameCountMinusOne].
// frameCountMinusOne: number of frames per axis, minus one.
// Returns the normalized output of GridToVector for that frame's grid position.
vec3 FrameXYToRay(vec2 frame, vec2 frameCountMinusOne)
{
    // Frame index -> [0, 1] -> [-1, 1] grid coordinate.
    vec2 unitGrid = frame.xy / frameCountMinusOne;
    vec2 gridCoord = (unitGrid - 0.5) * 2.0;

    // Full-sphere or hemisphere decoding is chosen inside GridToVector.
    return normalize(GridToVector(gridCoord));
}

// Expresses vec in the basis (basedX, basedY, basedZ) by projecting it onto
// each axis. For an orthonormal basis this is multiplication by the inverse
// (transpose) of the basis matrix.
vec3 ITBasis(vec3 vec, vec3 basedX, vec3 basedY, vec3 basedZ)
{
    float alongX = dot(basedX, vec);
    float alongY = dot(basedY, vec);
    float alongZ = dot(basedZ, vec);
    return vec3(alongX, alongY, alongZ);
}

// Builds an orthonormal frame around frameRay (y up) and expresses the
// reversed projRay in it.
// projRay:  ray to transform; it is negated before projection.
// frameRay: frame axis; assumed unit length so the basis is orthonormal.
// worldX:   [out] horizontal axis perpendicular to frameRay.
// worldZ:   [out] normalize(cross(worldX, frameRay)), completing the frame.
// Returns normalize(-projRay) in (worldX, frameRay, worldZ) coordinates.
// TODO: something might be wrong here (original author's note).
// NOTE(review): when frameRay has zero x and z (e.g. (0, +/-1, 0)), worldX
// normalizes a zero vector, which typically yields NaN for every output —
// confirm how the baker orients such frames before changing this.
vec3 FrameTransform(vec3 projRay, vec3 frameRay, out vec3 worldX, out vec3 worldZ)
{
    worldX = normalize(vec3(-frameRay.z, 0, frameRay.x));
    worldZ = normalize(cross(worldX, frameRay));

    vec3 reversedRay = -projRay;
    return normalize(ITBasis(reversedRay, worldX, frameRay, worldZ));
}


// Samples three atlas frames with explicit gradients and blends them.
// tex:            imposter atlas.
// uv:             unused.
// frame0..frame2: atlas UVs of the three frames to blend.
// weights:        blend weights for frames 0, 1, 2 in x, y, z (w is ignored).
// ddxy:           ddxy.x is passed as the ddx gradient and ddxy.y as the ddy
//                 gradient; each scalar is promoted to both UV components.
vec4 ImposterBlendWeights(sampler2D tex, vec2 uv, vec2 frame0, vec2 frame1, vec2 frame2, vec4 weights, vec2 ddxy)
{
    vec4 blended = tex2Dgrad(tex, frame0, ddxy.x, ddxy.y) * weights.x;
    blended += tex2Dgrad(tex, frame1, ddxy.x, ddxy.y) * weights.y;
    blended += tex2Dgrad(tex, frame2, ddxy.x, ddxy.y) * weights.z;

    // Gradient-free alternative, sampling mip 0 directly:
    //vec4 samp0 = tex2Dlod( tex, float4(frame0,0,0) );
    //vec4 samp1 = tex2Dlod( tex, float4(frame1,0,0) );
    //vec4 samp2 = tex2Dlod( tex, float4(frame2,0,0) );

    return blended;
}

// Intersects a ray with a plane and returns the hit point's UV on that plane.
// planeNormal:    plane normal.
// planeX, planeZ: in-plane axes along which the hit point is measured.
// center:         a point on the plane; it maps to UV (0.5, 0.5).
// uvScale:        plane extent covered by UV 0..1 along planeX / planeZ.
// rayLocal:       ray expressed in the same space as the plane.
// If the hit is not in front of the ray origin (t <= 0, or t is NaN), the
// measured offset is zeroed, so the result is (0.5, 0.5) for non-zero uvScale.
// A ray parallel to the plane divides by zero and may produce non-finite UVs.
vec2 VirtualPlaneUV(vec3 planeNormal, vec3 planeX, vec3 planeZ, vec3 center, vec2 uvScale, Ray rayLocal){
    // Solve dot(n, origin + t * dir) == dot(n, center) for t.
    float distToPlane = dot(planeNormal, center) - dot(planeNormal, rayLocal.Origin);
    float t = distToPlane / dot(planeNormal, rayLocal.Direction);

    vec3 hitOffset = (rayLocal.Origin + rayLocal.Direction * t) - center;
    vec2 planar = vec2(dot(planeX, hitOffset), dot(planeZ, hitOffset));

    vec2 uv = (t > 0) ? planar : vec2(0, 0);
    return uv / uvScale + vec2(0.5, 0.5);
}
