import { Box3, BufferAttribute, BufferGeometry, FrontSide, Group, Material, Matrix4, Mesh, MeshBasicMaterial, MeshNormalMaterial, Object3D, Vector3 } from "three";
import { VertexNormalsHelper } from 'three/examples/jsm/helpers/VertexNormalsHelper.js';

import { AssetReference } from "../../engine/engine_addressables.js";
import { disposeObjectResources } from "../../engine/engine_assetdatabase.js";
import { destroy } from "../../engine/engine_gameobject.js";
import { Gizmos } from "../../engine/engine_gizmos.js";
import { serializable } from "../../engine/engine_serialization.js";
import { setVisibleInCustomShadowRendering } from "../../engine/engine_three_utils.js";
import type { Vec3 } from "../../engine/engine_types.js";
import { getParam } from "../../engine/engine_utils.js";
import type { NeedleXREventArgs, NeedleXRSession } from "../../engine/engine_xr.js";
import { MeshCollider } from "../Collider.js";
import { Behaviour, GameObject } from "../Component.js";

// Enables verbose logging, debug materials, normals helpers and plane labels in this file.
// NOTE(review): the value comes from getParam("debugplanetracking") — presumably a URL parameter; confirm in engine_utils.
const debug = getParam("debugplanetracking");

/**
 * Local declaration of the WebXR mesh-detection `XRMesh` shape, limited to the members this file uses.
 */
declare type XRMesh = {
    /** Space used with `XRFrame.getPose` to position the mesh */
    meshSpace: XRSpace;
    /** Compared against the stored context timestamp to detect that the mesh changed */
    lastChangedTime: number;
    /** Flat vertex positions, consumed as xyz triplets (itemSize 3) */
    vertices: Float32Array;
    /** Index buffer into `vertices` (set as the geometry index) */
    indices: Uint32Array;
    /** Optional semantic label reported by the runtime (e.g. "wall") */
    semanticLabel?: string;
}

/**
 * XRFrame extended with the optional `detectedPlanes` / `detectedMeshes` sets.
 * Either may be undefined (presumably when the corresponding feature is unavailable — assumption), which callers check for.
 */
declare type XRFramePlanes = XRFrame & {
    detectedPlanes?: Set<XRPlane>;
    detectedMeshes?: Set<XRMesh>;
}

/**
 * Per plane/mesh bookkeeping kept by {@link WebXRPlaneTracking}; also passed to listeners via {@link WebXRPlaneTrackingEvent}.
 */
export declare type XRPlaneContext = {
    /** Incrementing id (per component instance) assigned when the data is first seen */
    id: number;
    /** The underlying WebXR plane or mesh */
    xrData: (XRPlane & { semanticLabel?: string }) | XRMesh;
    /** `lastChangedTime` of `xrData` at the time the visual was last (re)built */
    timestamp: number;
    /** The instantiated visual object (instance of the data template or the fallback mesh) */
    mesh?: Mesh | Group;
    /** MeshCollider found on the instantiated object, if any */
    collider?: MeshCollider;
}

/**
 * Payload (`CustomEvent.detail`) of the "plane-tracking" event dispatched by {@link WebXRPlaneTracking}
 * whenever tracked data is added, updated or removed.
 */
export declare type WebXRPlaneTrackingEvent = {
    type: "plane-added" | "plane-updated" | "plane-removed";
    context: XRPlaneContext;
}

/**
 * [WebXRPlaneTracking](https://engine.needle.tools/docs/api/WebXRPlaneTracking) tracks planes and meshes in the real world when in immersive-ar (e.g. on Oculus Quest).
 *
 * For every detected plane/mesh an instance of {@link WebXRPlaneTracking.dataTemplate} (or a simple fallback mesh)
 * is created and parented to the XR rig. Listen to the `"plane-tracking"` event to get notified when tracked data
 * is added, updated or removed (the event detail is a {@link WebXRPlaneTrackingEvent}).
 * @category XR
 * @group Components
 */
export class WebXRPlaneTracking extends Behaviour {

    /** 
     * Optional: if assigned it will be instantiated per tracked plane/tracked mesh.
     * If not assigned a simple mesh will be used. Use `occluder` to create occlusion meshes that don't render color but only depth.
     */
    @serializable(AssetReference)
    dataTemplate?: AssetReference;

    /**
     * If true an occluder material will be applied to the tracked planes/meshes.   
     * Note: this will only be applied if dataTemplate is not assigned
     */
    @serializable()
    occluder: boolean = true;

    /**
     * If true the system will try to initiate room capture if no planes are detected.
     * This is attempted at most once per XR session, roughly 2.5 seconds after the session started without any
     * plane or mesh data, and only when the session exposes `initiateRoomCapture` (e.g. on Meta Quest).
     */
    @serializable()
    initiateRoomCaptureIfNoData: boolean = true;

    /**
     * If true plane tracking will be enabled (requests the "plane-detection" optional feature)
     */
    @serializable()
    usePlaneData: boolean = true;

    /**
     * If true mesh tracking will be enabled (requests the "mesh-detection" optional feature)
     */
    @serializable()
    useMeshData: boolean = true;

    /** when enabled mesh or plane tracking will also be used in VR */
    @serializable()
    runInVR: boolean = true;

    /**
     * Returns all tracked planes
     */
    get trackedPlanes() { return this._allPlanes.values(); }
    /**
     * Returns all tracked meshes
     */
    get trackedMeshes() { return this._allMeshes.values(); }

    /** @internal */
    onBeforeXR(_mode: XRSessionMode, args: XRSessionInit): void {
        if (_mode === "immersive-vr" && !this.runInVR) return;
        args.optionalFeatures = args.optionalFeatures || [];
        if (this.usePlaneData && !args.optionalFeatures.includes("plane-detection"))
            args.optionalFeatures.push("plane-detection");
        if (this.useMeshData && !args.optionalFeatures.includes("mesh-detection"))
            args.optionalFeatures.push("mesh-detection");
    }

    /** @internal */
    onEnterXR(_evt: NeedleXREventArgs) {
        // remove all previously added data from the scene again
        this.clearTrackedData();
        // room capture may be initiated once per session, so allow it again for this new session
        this.firstTimeNoPlanesDetected = -100;
    }

    /** @internal */
    onLeaveXR(_args: NeedleXREventArgs): void {
        this.clearTrackedData();
    }

    /** @internal */
    onUpdateXR(args: NeedleXREventArgs): void {

        if (!this.runInVR && args.xr.isVR) return;

        // parenting tracked planes to the XR rig ensures that they synced with the real-world user data;
        // otherwise they would "swim away" when the user rotates / moves / teleports and so on. 
        // There may be cases where we want that! E.g. a user walks around on their own table in castle builder
        const rig = args.xr.rig;
        if (!rig) {
            console.warn("No XR rig found, cannot parent tracked planes to it");
            return;
        }

        const frame = args.xr.frame as XRFramePlanes;
        const renderer = this.context.renderer;
        const referenceSpace = renderer.xr.getReferenceSpace();
        if (!referenceSpace) return;

        const planes = frame.detectedPlanes;
        const meshes = frame.detectedMeshes;
        const hasAnyPlanes = planes !== undefined && planes.size > 0;
        const hasAnyMeshes = meshes !== undefined && meshes.size > 0;

        // When no planes are found and we haven't already tried, we start a timer
        // and after 2.5s request the room capture flow.
        // This only works on Quest through a magic method on the session,
        // see https://developer.oculus.com/documentation/web/webxr-mixed-reality/#:~:text=Because%20few%20people,once%20per%20session.
        if (this.initiateRoomCaptureIfNoData) {
            if (hasAnyPlanes || hasAnyMeshes) {
                this.firstTimeNoPlanesDetected = -1; // we're done
            }
            else if (this.firstTimeNoPlanesDetected < -10) {
                this.firstTimeNoPlanesDetected = Date.now();
            }

            if (this.firstTimeNoPlanesDetected > 0 && Date.now() - this.firstTimeNoPlanesDetected > 2500) {
                const session = frame.session as XRSession & { initiateRoomCapture?: () => unknown };
                if (typeof session.initiateRoomCapture === "function") {
                    session.initiateRoomCapture();
                }
                // done either way: it may only be called once per session, and if it's unavailable there's nothing to retry
                this.firstTimeNoPlanesDetected = -1;
            }
        }

        if (planes !== undefined)
            this.processFrameData(args.xr, rig.gameObject, frame, planes, this._allPlanes);

        if (meshes !== undefined)
            this.processFrameData(args.xr, rig.gameObject, frame, meshes, this._allMeshes);

        if (debug) this.drawDebugLabels();
    }

    private bounds = new Box3();
    private center = new Vector3();
    private labelOffset = new Vector3();

    private _dataId = 1;
    private readonly _allPlanes = new Map<XRPlane, XRPlaneContext>();
    private readonly _allMeshes = new Map<XRMesh, XRPlaneContext>();
    /**
     * Room capture state: values < -10 mean "not started yet" (initially -100),
     * -1 means "done for this session", positive values are the `Date.now()` when we first saw no data.
     */
    private firstTimeNoPlanesDetected = -100;

    /** Debug only: draws the semantic label and last-changed time at the bounds center of every visible tracked plane. */
    private drawDebugLabels() {
        const camera = this.context.mainCameraComponent;
        if (!camera) return;
        const camPos = camera.gameObject.worldPosition;
        for (const plane of this._allPlanes.values()) {
            if (!plane.mesh || !plane.mesh.visible) continue;
            this.bounds.makeEmpty();
            plane.mesh.traverse(x => {
                if (!(x instanceof Mesh)) return;
                this.bounds.expandByObject(x);
            });
            this.bounds.getCenter(this.center);

            // move the label slightly towards the camera so it isn't hidden inside the plane
            this.labelOffset.copy(camPos).sub(this.center).normalize().multiplyScalar(0.1);
            Gizmos.DrawLabel(
                this.center.add(this.labelOffset),
                (plane.xrData.semanticLabel || "plane").toUpperCase() + "\n" +
                plane.xrData.lastChangedTime.toFixed(2),
                0.02,
            );
        }
    }

    /** Removes every tracked plane and mesh and releases the cached mesh geometries. */
    private clearTrackedData() {
        for (const data of this._allPlanes.keys()) {
            this.removeData(data, this._allPlanes);
        }
        for (const data of this._allMeshes.keys()) {
            this.removeData(data, this._allMeshes);
        }
        // the cache would otherwise grow without bound across updates and sessions
        for (const geometry of this._verticesCache.values()) {
            geometry.dispose();
        }
        this._verticesCache.clear();
    }

    /**
     * Stops tracking `data`: detaches and destroys its visual object, cleans up debug normals helpers
     * and dispatches a "plane-removed" event. Does nothing if `data` is not tracked in `_all`.
     */
    private removeData(data: XRPlane | XRMesh, _all: Map<XRPlane | XRMesh, XRPlaneContext>) {
        const dataContext = _all.get(data);
        if (!dataContext) return;
        _all.delete(data);
        if (debug) console.log("Plane no longer tracked, id=" + dataContext.id);
        if (dataContext.mesh) {
            dataContext.mesh.removeFromParent();
            dataContext.mesh.traverse(x => {
                const nc = x.userData["normalsHelper"];
                if (nc) {
                    // normals helpers are added to the scene root, so they must be removed explicitly
                    nc.dispose();
                    nc.removeFromParent();
                    delete x.userData["normalsHelper"];
                }
                else if (debug && x instanceof Mesh) {
                    console.warn("No normals helper found for mesh", dataContext.mesh);
                }
            });
            destroy(dataContext.mesh, true, true);
        }

        this.dispatchPlaneEvent("plane-removed", dataContext);
    }

    /** Dispatches a "plane-tracking" CustomEvent with the given type and context as detail. */
    private dispatchPlaneEvent(type: WebXRPlaneTrackingEvent["type"], context: XRPlaneContext) {
        const evt = new CustomEvent<WebXRPlaneTrackingEvent>("plane-tracking", {
            detail: {
                type,
                context
            }
        });
        this.dispatchEvent(evt);
    }

    /**
     * Turns the material(s) of `mesh` into a depth-only occluder.
     * Without `force` only materials whose name contains "occlu" are modified.
     */
    private makeOccluder = (mesh: Mesh, m: Material | Array<Material>, force: boolean = false) => {
        if (!m) return;
        if (m instanceof Array) {
            for (const m0 of m)
                this.makeOccluder(mesh, m0, force);
            return;
        }
        if (!force && !m.name.toLowerCase().includes("occlu")) return;
        m.colorWrite = false;
        m.depthTest = true;
        m.depthWrite = true;
        m.transparent = false;
        m.polygonOffset = true;
        // positive values are below
        m.polygonOffsetFactor = 1;
        m.polygonOffsetUnits = .1;
        // render before everything else so the depth is already written
        mesh.renderOrder = -1000;
    }

    /**
     * Synchronizes one set of detected planes or meshes with the tracked objects in `_all`:
     * removes objects whose data disappeared, rebuilds changed data, creates objects for new data
     * and applies the latest pose (objects without a pose this frame are hidden).
     */
    private processFrameData(_xr: NeedleXRSession, rig: Object3D, frame: XRFramePlanes, detected: Set<XRPlane | XRMesh>, _all: Map<XRPlane | XRMesh, XRPlaneContext>) {
        const referenceSpace = this.context.renderer.xr.getReferenceSpace();
        if (!referenceSpace) return;

        for (const data of _all.keys()) {
            if (!detected.has(data)) {
                this.removeData(data, _all);
            }
        }

        for (const data of detected) {
            const space = "planeSpace" in data ? data.planeSpace
                : ("meshSpace" in data ? data.meshSpace
                    : undefined);
            if (!space) continue;
            const planePose = frame.getPose(space, referenceSpace);

            let planeMesh: Object3D | undefined;
            const existing = _all.get(data);
            if (existing) {
                // If the plane already existed just update it (only when the runtime reports a newer change)
                planeMesh = existing.mesh;
                if (existing.timestamp < data.lastChangedTime) {
                    this.updateData(existing, data);
                }
            }
            else {
                // Otherwise we create a new plane instance (may be undefined while the template is loading)
                planeMesh = this.createData(rig, data, _all);
            }

            if (!planeMesh) continue;

            if (planePose) {
                planeMesh.visible = true;
                planeMesh.matrix.fromArray(planePose.transform.matrix);
                planeMesh.matrix.premultiply(this._flipForwardMatrix);
            } else {
                planeMesh.visible = false;
            }

            if (debug) this.updateNormalsHelpers(planeMesh);
        }
    }

    /** Rebuilds geometry and collider of an already tracked plane/mesh and dispatches "plane-updated". */
    private updateData(planeContext: XRPlaneContext, data: XRPlane | XRMesh) {
        planeContext.timestamp = data.lastChangedTime;

        if (planeContext.mesh) {
            const geometry = this.createGeometry(data);
            if (planeContext.mesh instanceof Mesh) {
                this.replaceGeometry(planeContext.mesh, geometry);
                this.makeOccluder(planeContext.mesh, planeContext.mesh.material);
            }
            else if (planeContext.mesh instanceof Group) {
                for (const ch of planeContext.mesh.children) {
                    if (ch instanceof Mesh) {
                        this.replaceGeometry(ch, geometry);
                        this.makeOccluder(ch, ch.material);
                    }
                }
            }

            // Update the mesh collider if it exists
            if (planeContext.collider) {
                const mesh = planeContext.mesh as unknown as Mesh;
                planeContext.collider.sharedMesh = mesh;
                planeContext.collider.convex = this.checkIfContextShouldBeConvex(mesh, planeContext.xrData);
                // re-enable so the collider picks up the new geometry
                planeContext.collider.onDisable();
                planeContext.collider.onEnable();
            }

            if (debug) {
                console.log("Plane updated, id=" + planeContext.id, planeContext);
            }
        }

        this.dispatchPlaneEvent("plane-updated", planeContext);
    }

    /** Assigns `geometry` to `mesh`, disposing the previous geometry unless it is the very same (cached) instance. */
    private replaceGeometry(mesh: Mesh, geometry: BufferGeometry) {
        if (mesh.geometry !== geometry) mesh.geometry.dispose();
        mesh.geometry = geometry;
    }

    /** Builds the fallback template mesh used when no dataTemplate is assigned. */
    private createDefaultTemplate(): Mesh {
        const mesh = new Mesh();
        if (debug) mesh.material = new MeshNormalMaterial();
        else if (this.occluder) {
            mesh.material = new MeshBasicMaterial();
            this.makeOccluder(mesh, mesh.material, true);
        }
        else {
            mesh.material = new MeshBasicMaterial({ wireframe: true, opacity: .5, transparent: true, color: 0x333333 });
        }
        return mesh;
    }

    /**
     * Instantiates the data template for newly detected data, parents it to `rig`, registers it in `_all`
     * and dispatches "plane-added". Returns undefined while the template asset is not loaded yet
     * (loading is started and the creation is retried on the next update).
     */
    private createData(rig: Object3D, data: XRPlane | XRMesh, _all: Map<XRPlane | XRMesh, XRPlaneContext>): Object3D | undefined {
        // if we don't have any template assigned we just use a simple mesh object
        if (!this.dataTemplate) {
            this.dataTemplate = new AssetReference("", "", this.createDefaultTemplate());
        }

        if (!this.dataTemplate.asset) {
            this.dataTemplate.loadAssetAsync();
            return undefined;
        }

        // Create instance
        const newPlane = GameObject.instantiate(this.dataTemplate.asset) as GameObject;
        newPlane.name = "xr-tracked-plane";
        setVisibleInCustomShadowRendering(newPlane, false);

        // The generated fallback template already received its occluder setup in createDefaultTemplate,
        // so here only materials whose name contains "occlu" are turned into occluders (no forcing).
        if (newPlane instanceof Mesh) {
            disposeObjectResources(newPlane.geometry);
            newPlane.geometry = this.createGeometry(data);
            this.makeOccluder(newPlane, newPlane.material);
        }
        else if (newPlane instanceof Group) {
            // We want to process only one level of children on purpose here
            for (const ch of newPlane.children) {
                if (ch instanceof Mesh) {
                    disposeObjectResources(ch.geometry);
                    ch.geometry = this.createGeometry(data);
                    this.makeOccluder(ch, ch.material);
                }
            }
        }

        const mc = newPlane.getComponent(MeshCollider) as MeshCollider;
        if (mc) {
            const mesh = newPlane as unknown as Mesh;
            mc.sharedMesh = mesh;
            mc.convex = this.checkIfContextShouldBeConvex(mesh, data);
            // re-enable so the collider picks up the new geometry
            mc.onDisable();
            mc.onEnable();
        }

        // the pose is written directly into `matrix` every frame, so disable automatic matrix composition
        newPlane.matrixAutoUpdate = false;
        newPlane.matrixWorldNeedsUpdate = true; // force update of rendering settings and so on
        // TODO: in VR this has issues when the rig is moved
        // newPlane.matrixWorld.multiply(rig.matrix.invert());
        // this.context.scene.add(newPlane);
        rig.add(newPlane);

        const planeContext: XRPlaneContext = {
            id: this._dataId++,
            xrData: data,
            timestamp: data.lastChangedTime,
            mesh: newPlane as unknown as Mesh,
            collider: mc
        };
        _all.set(data, planeContext);

        if (debug) {
            console.log("New plane detected, id=" + planeContext.id, planeContext, { hasCollider: !!mc, isGroup: newPlane instanceof Group });
        }

        try {
            this.dispatchPlaneEvent("plane-added", planeContext);
        }
        catch (e) {
            console.error(e);
        }

        return newPlane;
    }

    /** Debug only: creates (or refreshes) a VertexNormalsHelper for every mesh below `obj`. */
    private updateNormalsHelpers(obj: Object3D) {
        obj.traverse(x => {
            if (!(x instanceof Mesh)) return;
            const existingHelper = x.userData["normalsHelper"] as VertexNormalsHelper | undefined;
            if (existingHelper) {
                existingHelper.update();
                return;
            }
            const normalsHelper = new VertexNormalsHelper(x, 0.05, 0x0000ff);
            normalsHelper.layers.disableAll();
            normalsHelper.layers.set(2);
            this.context.scene.add(normalsHelper);
            x.userData["normalsHelper"] = normalsHelper;
        });
    }

    // rotates tracked poses by 180° around Y before applying them to the visual objects
    private _flipForwardMatrix = new Matrix4().makeRotationY(Math.PI);

    /**
     * Heuristic to determine if a collider should be convex or not:
     * very large meshes (presumably the room-scale "global mesh") should be non-convex,
     * everything else — and anything labeled "wall" — should be convex.
     */
    private checkIfContextShouldBeConvex(mesh: Mesh | Group | undefined, xrData: XRPlane | XRMesh) {
        if (!mesh) return true;

        // if the semantic label is "wall" we make it convex, regardless of its size
        if ("semanticLabel" in xrData && xrData.semanticLabel === "wall")
            return true;

        // get bounding box of the mesh
        const bbox = new Box3().expandByObject(mesh);
        const size = bbox.getSize(new Vector3());

        // if the mesh is too big we make it non-convex
        const isTooBig = size.x > 2 && size.y > 2 && size.z > 1.5;
        return !isTooBig;
    }

    /** Creates a geometry for a plane (from its polygon) or a mesh (from its vertices/indices). */
    private createGeometry(data: XRPlane | XRMesh) {
        if ("polygon" in data) {
            return this.createPlaneGeometry(data.polygon);
        }
        else if ("vertices" in data && "indices" in data) {
            return this.createMeshGeometry(data.vertices, data.indices);
        }
        return new BufferGeometry();
    }

    // we cache vertices-to-geometry, because it looks like when we get an update sometimes the geometry stays the same.
    // so we don't want to re-create the geometry every time. Cleared (and disposed) in clearTrackedData.
    private _verticesCache = new Map<string, BufferGeometry>();

    /** Builds (or returns a cached) indexed geometry for mesh-detection data. */
    private createMeshGeometry(vertices: Float32Array, indices: Uint32Array) {
        const key = vertices.toString() + "_" + indices.toString();
        const cached = this._verticesCache.get(key);
        if (cached) return cached;

        const geometry = new BufferGeometry();
        geometry.setIndex(new BufferAttribute(indices, 1));
        geometry.setAttribute('position', new BufferAttribute(vertices, 3));
        // planar-projected UVs from the x/z coordinate of each vertex (same projection as createPlaneGeometry)
        const uvs: number[] = [];
        for (let i = 0; i + 2 < vertices.length; i += 3) {
            uvs.push(vertices[i], vertices[i + 2]);
        }
        geometry.setAttribute('uv', new BufferAttribute(new Float32Array(uvs), 2));
        geometry.computeVertexNormals();
        // no tangents for now, since we'd need proper UVs for that
        // geometry.computeTangents();

        // simplify - too slow, would need to be on a worker it seems...
        /*
        const modifier = new SimplifyModifier();
        const simplified = modifier.modify(geometry, Math.floor(indices.length / 3) * 0.1);
        geometry.dispose();
        geometry.copy(simplified);
        geometry.computeVertexNormals();
        */

        this._verticesCache.set(key, geometry);
        return geometry;
    }

    /** Builds a triangle-fan geometry with a flat normal and x/z UVs from a plane polygon. */
    private createPlaneGeometry(polygon: Vec3[]) {
        const geometry = new BufferGeometry();

        const vertices: number[] = [];
        const uvs: number[] = [];
        polygon.forEach(point => {
            vertices.push(point.x, point.y, point.z);
            uvs.push(point.x, point.z);
        })

        // get the normal of the plane by using the cross product of B-A and C-A
        const a = new Vector3(vertices[0], vertices[1], vertices[2]);
        const b = new Vector3(vertices[3], vertices[4], vertices[5]);
        const c = new Vector3(vertices[6], vertices[7], vertices[8]);
        const ab = new Vector3();
        const ac = new Vector3();
        ab.subVectors(b, a);
        ac.subVectors(c, a);
        ab.cross(ac);
        ab.normalize();

        const normals: number[] = [];
        for (let i = 0; i < vertices.length / 3; i++) {
            normals.push(ab.x, ab.y, ab.z);
        }

        // triangle fan around the first vertex
        const indices: number[] = [];
        for (let i = 2; i < polygon.length; ++i) {
            indices.push(0, i - 1, i);
        }

        geometry.setAttribute('position', new BufferAttribute(new Float32Array(vertices), 3));
        geometry.setAttribute('uv', new BufferAttribute(new Float32Array(uvs), 2))
        geometry.setAttribute('normal', new BufferAttribute(new Float32Array(normals), 3));
        geometry.setIndex(indices);

        // update bounds
        geometry.computeBoundingBox();
        geometry.computeBoundingSphere();

        return geometry;
    }
}