namespace UnityHelpers.Core.DataStructure
{
    using System;
    using System.Collections.Generic;
    using System.Collections.Immutable;
    using System.Linq;
    using Extension;
    using UnityEngine;
    using Utils;

    /// <summary>
    /// Immutable 2D quadtree over a fixed set of elements. Each element is mapped to a
    /// <see cref="Vector2"/> position by <see cref="ElementTransformer"/>; partitioning and
    /// queries use only the x/y extents of <see cref="Bounds"/>.
    /// </summary>
    /// <typeparam name="T">Type of the stored elements.</typeparam>
    [Serializable]
    public sealed class QuadTree<T> : ISpatialTree<T>
    {
        private const int NumChildren = 4;

        /// <summary>
        /// Maximum subdivision depth. Without a cap, more than bucketSize elements sharing one
        /// position (or a bucketSize below 1) can never be split below the bucket size, so
        /// construction could recurse until the stack overflows. Nodes at this depth are made
        /// terminal regardless of how many elements they hold; queries stay correct because
        /// terminal nodes test every element individually.
        /// </summary>
        private const int MaxDepth = 64;

        /// <summary>
        /// A quadtree node. Every node stores the full array of elements that fell inside its
        /// <see cref="boundary"/>; non-terminal nodes also own <c>NumChildren</c> children,
        /// one per quadrant of that boundary.
        /// </summary>
        /// <typeparam name="V">Type of the stored elements.</typeparam>
        [Serializable]
        public sealed class QuadTreeNode<V>
        {
            // Scratch list reused while partitioning elements into quadrants. Because it is
            // static it is shared by every node/tree built over the same V, so building trees
            // concurrently from several threads is not safe.
            private static readonly List<V> Buffer = new();

            public readonly Bounds boundary;
            internal readonly QuadTreeNode<V>[] children;
            public readonly V[] elements;
            public readonly bool isTerminal;

            /// <summary>
            /// Builds this node and, recursively, its whole subtree.
            /// </summary>
            /// <param name="elements">Elements inside <paramref name="boundary"/>; the array is stored by reference.</param>
            /// <param name="elementTransformer">Maps an element to its 2D position.</param>
            /// <param name="boundary">Region covered by this node.</param>
            /// <param name="bucketSize">Nodes holding at most this many elements are not subdivided.</param>
            public QuadTreeNode(
                V[] elements,
                Func<V, Vector2> elementTransformer,
                Bounds boundary,
                int bucketSize
            )
                : this(elements, elementTransformer, boundary, bucketSize, 0)
            {
            }

            /// <summary>
            /// Recursive constructor that tracks the node's depth so subdivision can stop at
            /// <c>MaxDepth</c>.
            /// </summary>
            private QuadTreeNode(
                V[] elements,
                Func<V, Vector2> elementTransformer,
                Bounds boundary,
                int bucketSize,
                int depth
            )
            {
                this.boundary = boundary;
                this.elements = elements;
                isTerminal = elements.Length <= bucketSize || MaxDepth <= depth;
                if (isTerminal)
                {
                    children = Array.Empty<QuadTreeNode<V>>();
                    return;
                }

                children = new QuadTreeNode<V>[NumChildren];
                Bounds[] quadrants = BuildQuadrants(boundary);
                for (int i = 0; i < quadrants.Length; ++i)
                {
                    Bounds quadrant = quadrants[i];
                    Buffer.Clear();
                    foreach (V element in elements)
                    {
                        // NOTE(review): if FastContains2D is inclusive on both edges, a point lying
                        // exactly on a split line is added to two children (and may be reported
                        // twice by GetElementsInBounds); if it is half-open, points on the root's
                        // max edge are dropped. Confirm against the extension.
                        if (quadrant.FastContains2D(elementTransformer(element)))
                        {
                            Buffer.Add(element);
                        }
                    }

                    // ToArray() is evaluated before the child constructor runs, so the recursive
                    // calls are free to clear and refill the shared Buffer.
                    children[i] = new QuadTreeNode<V>(
                        Buffer.ToArray(),
                        elementTransformer,
                        quadrant,
                        bucketSize,
                        depth + 1
                    );
                }
            }

            /// <summary>
            /// Splits <paramref name="boundary"/> into four equal quadrants in the x/y plane, in the
            /// order (-x,+y), (+x,+y), (+x,-y), (-x,-y) relative to its center. The z center is
            /// kept and the z size is halved along with x and y.
            /// </summary>
            private static Bounds[] BuildQuadrants(Bounds boundary)
            {
                Vector3 center = boundary.center;
                Vector3 quadrantSize = boundary.size / 2f;
                Vector2 halfQuadrantSize = quadrantSize / 2f;
                return new[]
                {
                    new Bounds(
                        new Vector3(center.x - halfQuadrantSize.x, center.y + halfQuadrantSize.y, center.z),
                        quadrantSize
                    ),
                    new Bounds(
                        new Vector3(center.x + halfQuadrantSize.x, center.y + halfQuadrantSize.y, center.z),
                        quadrantSize
                    ),
                    new Bounds(
                        new Vector3(center.x + halfQuadrantSize.x, center.y - halfQuadrantSize.y, center.z),
                        quadrantSize
                    ),
                    new Bounds(
                        new Vector3(center.x - halfQuadrantSize.x, center.y - halfQuadrantSize.y, center.z),
                        quadrantSize
                    ),
                };
            }
        }

        public const int DefaultBucketSize = 12;

        /// <summary>All elements the tree was built from, in input order.</summary>
        public readonly ImmutableArray<T> elements;

        /// <summary>Region covered by the root node.</summary>
        public Bounds Boundary => _bounds;

        /// <summary>Maps an element to its 2D position.</summary>
        public Func<T, Vector2> ElementTransformer => _elementTransformer;

        private readonly Bounds _bounds;
        private readonly Func<T, Vector2> _elementTransformer;
        private readonly QuadTreeNode<T> _head;

        /// <summary>
        /// Builds a quadtree over <paramref name="points"/>.
        /// </summary>
        /// <param name="points">Elements to index; enumerated once and copied.</param>
        /// <param name="elementTransformer">Maps an element to its 2D position.</param>
        /// <param name="boundary">
        /// Root region. When null, the bounds of all transformed positions are used, or a default
        /// <see cref="Bounds"/> when that yields nothing.
        /// </param>
        /// <param name="bucketSize">Nodes holding at most this many elements are not subdivided.</param>
        /// <exception cref="ArgumentNullException"><paramref name="points"/> or <paramref name="elementTransformer"/> is null.</exception>
        public QuadTree(
            IEnumerable<T> points,
            Func<T, Vector2> elementTransformer,
            Bounds? boundary = null,
            int bucketSize = DefaultBucketSize
        )
        {
            _elementTransformer =
                elementTransformer ?? throw new ArgumentNullException(nameof(elementTransformer));
            elements = points?.ToImmutableArray() ?? throw new ArgumentNullException(nameof(points));
            _bounds = boundary ?? elements.Select(elementTransformer).GetBounds() ?? new Bounds();
            _head = new QuadTreeNode<T>(elements.ToArray(), elementTransformer, _bounds, bucketSize);
        }

        /// <summary>
        /// Lazily yields every element whose position lies inside <paramref name="bounds"/>.
        /// </summary>
        /// <remarks>
        /// NOTE(review): this overload passes the shared <c>Buffers&lt;QuadTreeNode&lt;T&gt;&gt;.Stack</c>
        /// into a lazy iterator. Presumably that instance is shared, so nesting two enumerations of
        /// this overload (or enumerating from two threads) would corrupt the traversal. Pass an
        /// explicit stack to the other overload in those cases. Confirm the Buffers semantics.
        /// </remarks>
        public IEnumerable<T> GetElementsInBounds(Bounds bounds)
        {
            Stack<QuadTreeNode<T>> nodeBuffer = Buffers<QuadTreeNode<T>>.Stack;
            return GetElementsInBounds(bounds, nodeBuffer);
        }

        /// <summary>
        /// Lazily yields every element whose position lies inside <paramref name="bounds"/>,
        /// using <paramref name="nodeBuffer"/> (cleared on first enumeration) as the traversal
        /// stack, or a new stack when it is null.
        /// </summary>
        public IEnumerable<T> GetElementsInBounds(Bounds bounds, Stack<QuadTreeNode<T>> nodeBuffer)
        {
            if (!bounds.FastIntersects2D(_bounds))
            {
                yield break;
            }

            Stack<QuadTreeNode<T>> nodesToVisit = nodeBuffer ?? new Stack<QuadTreeNode<T>>();
            nodesToVisit.Clear();
            nodesToVisit.Push(_head);

            while (nodesToVisit.TryPop(out QuadTreeNode<T> currentNode))
            {
                if (currentNode.isTerminal)
                {
                    // Leaf: test each element individually.
                    foreach (T element in currentNode.elements)
                    {
                        if (bounds.FastContains2D(_elementTransformer(element)))
                        {
                            yield return element;
                        }
                    }

                    continue;
                }

                // NOTE(review): assumed to mean "bounds fully contains the node's boundary", since
                // every element is yielded without a per-element check. If Overlaps2D only tests
                // intersection, this branch would yield out-of-range elements. Confirm.
                if (bounds.Overlaps2D(currentNode.boundary))
                {
                    foreach (T element in currentNode.elements)
                    {
                        yield return element;
                    }

                    continue;
                }

                foreach (QuadTreeNode<T> child in currentNode.children)
                {
                    if (child.elements.Length <= 0)
                    {
                        continue;
                    }

                    if (!bounds.FastIntersects2D(child.boundary))
                    {
                        continue;
                    }

                    nodesToVisit.Push(child);
                }
            }
        }

        /// <summary>
        /// Fills <paramref name="nearestNeighbors"/> with approximately the <paramref name="count"/>
        /// elements nearest to <paramref name="position"/>, using shared scratch buffers.
        /// </summary>
        public void GetApproximateNearestNeighbors(
            Vector2 position,
            int count,
            List<T> nearestNeighbors
        )
        {
            Stack<QuadTreeNode<T>> nodeBuffer = Buffers<QuadTreeNode<T>>.Stack;
            List<QuadTreeNode<T>> childrenBuffer = Buffers<QuadTreeNode<T>>.List;
            HashSet<T> nearestNeighborBuffer = Buffers<T>.HashSet;
            GetApproximateNearestNeighbors(
                position,
                count,
                nearestNeighbors,
                nodeBuffer,
                childrenBuffer,
                nearestNeighborBuffer
            );
        }

        // Heavily adapted http://homepage.divms.uiowa.edu/%7Ekvaradar/sp2012/daa/ann.pdf
        /// <summary>
        /// Fills <paramref name="nearestNeighbors"/> (cleared first) with approximately the
        /// <paramref name="count"/> elements nearest to <paramref name="position"/>.
        /// </summary>
        /// <remarks>
        /// The search descends toward the child whose center is closest to <paramref name="position"/>
        /// until it reaches a terminal node or a node holding at most <paramref name="count"/>
        /// elements. Every visited sibling is stacked nearest-on-top. Nodes are then popped and
        /// their elements collected (deduplicated) until at least <paramref name="count"/> are
        /// gathered. The result is sorted by true distance and truncated only when more than
        /// <paramref name="count"/> candidates were gathered; otherwise its order is unspecified.
        /// Null buffer arguments are replaced with fresh instances.
        /// </remarks>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="count"/> is negative.</exception>
        public void GetApproximateNearestNeighbors(
            Vector2 position,
            int count,
            List<T> nearestNeighbors,
            Stack<QuadTreeNode<T>> nodeBuffer,
            List<QuadTreeNode<T>> childrenBuffer,
            HashSet<T> nearestNeighborBuffer
        )
        {
            nearestNeighbors.Clear();
            if (count < 0)
            {
                throw new ArgumentOutOfRangeException(nameof(count), count, "Count must be non-negative.");
            }

            if (count == 0)
            {
                return;
            }

            QuadTreeNode<T> current = _head;

            Stack<QuadTreeNode<T>> stack = nodeBuffer ?? new Stack<QuadTreeNode<T>>();
            stack.Clear();
            stack.Push(_head);
            List<QuadTreeNode<T>> childrenCopy =
                childrenBuffer ?? new List<QuadTreeNode<T>>(NumChildren);
            childrenCopy.Clear();
            HashSet<T> nearestNeighborsSet = nearestNeighborBuffer ?? new HashSet<T>(count);
            nearestNeighborsSet.Clear();

            Comparison<QuadTreeNode<T>> byCenterDistance = CompareByCenterDistance;

            while (!current.isTerminal)
            {
                childrenCopy.Clear();
                childrenCopy.AddRange(current.children);
                childrenCopy.Sort(byCenterDistance);
                // Push farthest first so the nearest child ends up on top of the stack.
                for (int i = childrenCopy.Count - 1; 0 <= i; --i)
                {
                    stack.Push(childrenCopy[i]);
                }

                current = childrenCopy[0];
                if (current.elements.Length <= count)
                {
                    break;
                }
            }

            // Parents remain on the stack beneath their children, so elements may be seen twice;
            // the set deduplicates them.
            while (nearestNeighborsSet.Count < count && stack.TryPop(out QuadTreeNode<T> selected))
            {
                foreach (T element in selected.elements)
                {
                    _ = nearestNeighborsSet.Add(element);
                }
            }

            foreach (T element in nearestNeighborsSet)
            {
                nearestNeighbors.Add(element);
            }

            if (count < nearestNeighbors.Count)
            {
                nearestNeighbors.Sort(CompareByElementDistance);
                nearestNeighbors.RemoveRange(count, nearestNeighbors.Count - count);
            }

            return;

            int CompareByCenterDistance(QuadTreeNode<T> lhs, QuadTreeNode<T> rhs) =>
                ((Vector2)lhs.boundary.center - position).sqrMagnitude.CompareTo(
                    ((Vector2)rhs.boundary.center - position).sqrMagnitude
                );

            int CompareByElementDistance(T lhs, T rhs) =>
                (_elementTransformer(lhs) - position).sqrMagnitude.CompareTo(
                    (_elementTransformer(rhs) - position).sqrMagnitude
                );
        }
    }
}