namespace UnityHelpers.Core.DataStructure
{
    using System;
    using System.Collections.Generic;
    using System.Collections.Immutable;
    using System.Linq;
    using Extension;
    using UnityEngine;
    using Utils;

    /// <summary>
    /// Two-dimensional k-d tree. Each element is mapped to a <see cref="Vector2"/> position by a
    /// caller-supplied transformer; tree levels alternate between splitting on X and on Y,
    /// starting with X at the root.
    /// </summary>
    /// <typeparam name="T">Type of the stored elements.</typeparam>
    [Serializable]
    public sealed class KDTree<T> : ISpatialTree<T>
    {
        /// <summary>Extracts a single coordinate (X or Y) from an element.</summary>
        public delegate float Axis<V>(V element);

        /// <summary>
        /// One node of the tree. Every node stores all elements of its subtree; a non-terminal
        /// node additionally owns two non-empty children that partition those elements.
        /// </summary>
        /// <typeparam name="V">Type of the stored elements.</typeparam>
        [Serializable]
        public sealed class KDTreeNode<V>
        {
            /// <summary>
            /// Bounds of this node's element positions as computed by <c>GetBounds</c>; falls back to
            /// <c>new Bounds()</c> when that yields null (presumably for an empty element set).
            /// </summary>
            public readonly Bounds boundary;

            /// <summary>Lower-side child; null when <see cref="isTerminal"/> is true.</summary>
            public readonly KDTreeNode<V> left;

            /// <summary>Upper-side child; null when <see cref="isTerminal"/> is true.</summary>
            public readonly KDTreeNode<V> right;

            /// <summary>All elements of this subtree, in the order of the list passed to the constructor.</summary>
            public readonly V[] elements;

            /// <summary>True when the node holds at most bucketSize elements and has no children.</summary>
            public readonly bool isTerminal;

            /// <summary>Builds this node and, recursively, its whole subtree.</summary>
            /// <param name="elements">Elements of this subtree. For split nodes the list may be sorted in place.</param>
            /// <param name="elementTransformer">Maps an element to its 2D position.</param>
            /// <param name="bucketSize">Maximum element count of a terminal node; must be at least 1.</param>
            /// <param name="isXAxis">True to split this node on X, false to split on Y.</param>
            /// <param name="balanced">
            /// True to split at the median element along the axis; false to split at the midpoint of
            /// this node's bounds (falling back to the median when that midpoint separates nothing).
            /// </param>
            /// <exception cref="ArgumentOutOfRangeException"><paramref name="bucketSize"/> is less than 1.</exception>
            public KDTreeNode(
                List<V> elements,
                Func<V, Vector2> elementTransformer,
                int bucketSize,
                bool isXAxis,
                bool balanced
            )
            {
                // With a bucket size below 1 a single-element node is non-terminal, and splitting it
                // never makes progress, so construction would recurse until the stack overflows.
                if (bucketSize < 1)
                {
                    throw new ArgumentOutOfRangeException(
                        nameof(bucketSize),
                        bucketSize,
                        "Bucket size must be at least 1."
                    );
                }

                boundary = elements.Select(elementTransformer).GetBounds() ?? new Bounds();
                // Snapshot taken before any in-place sort performed by the split below.
                this.elements = elements.ToArray();
                isTerminal = elements.Count <= bucketSize;
                if (isTerminal)
                {
                    return;
                }

                Axis<V> axisFunction = isXAxis
                    ? element => elementTransformer(element).x
                    : element => elementTransformer(element).y;

                List<V> leftElements;
                List<V> rightElements;
                if (balanced)
                {
                    SplitAtMedian(elements, axisFunction, out leftElements, out rightElements);
                }
                else if (
                    !TrySplitAtMidpoint(
                        elements,
                        axisFunction,
                        isXAxis ? boundary.center.x : boundary.center.y,
                        out leftElements,
                        out rightElements
                    )
                )
                {
                    // Every element fell on one side of the midpoint (e.g. coincident positions).
                    // Recursing on that split would never shrink the node, so use the median instead.
                    SplitAtMedian(elements, axisFunction, out leftElements, out rightElements);
                }

                left = new KDTreeNode<V>(leftElements, elementTransformer, bucketSize, !isXAxis, balanced);
                right = new KDTreeNode<V>(rightElements, elementTransformer, bucketSize, !isXAxis, balanced);
            }

            // Partitions elements around a coordinate cutoff (<= goes left, everything else right,
            // so no element is lost) and reports whether both sides ended up non-empty.
            private static bool TrySplitAtMidpoint(
                List<V> elements,
                Axis<V> axisFunction,
                float cutoff,
                out List<V> leftElements,
                out List<V> rightElements
            )
            {
                leftElements = new List<V>();
                rightElements = new List<V>();
                foreach (V element in elements)
                {
                    if (axisFunction(element) <= cutoff)
                    {
                        leftElements.Add(element);
                    }
                    else
                    {
                        rightElements.Add(element);
                    }
                }

                return 0 < leftElements.Count && 0 < rightElements.Count;
            }

            // Sorts elements along the axis (in place) and splits them at the median index; both
            // halves are non-empty whenever there are at least two elements.
            private static void SplitAtMedian(
                List<V> elements,
                Axis<V> axisFunction,
                out List<V> leftElements,
                out List<V> rightElements
            )
            {
                elements.Sort((lhs, rhs) => axisFunction(lhs).CompareTo(axisFunction(rhs)));
                int cutoff = elements.Count / 2;
                leftElements = elements.GetRange(0, cutoff);
                rightElements = elements.GetRange(cutoff, elements.Count - cutoff);
            }
        }

        /// <summary>Terminal-node capacity used when the caller does not specify one.</summary>
        public const int DefaultBucketSize = 12;

        /// <summary>The elements the tree was built from, in their original order.</summary>
        public readonly ImmutableArray<T> elements;

        /// <summary>Bounds of all element positions.</summary>
        public Bounds Boundary => _bounds;

        /// <summary>The function used to map elements to 2D positions.</summary>
        public Func<T, Vector2> ElementTransformer => _elementTransformer;

        private readonly Bounds _bounds;
        private readonly Func<T, Vector2> _elementTransformer;
        private readonly KDTreeNode<T> _head;

        /// <summary>Builds a tree over <paramref name="points"/>.</summary>
        /// <param name="points">Elements to index; enumerated once.</param>
        /// <param name="elementTransformer">Maps an element to its 2D position.</param>
        /// <param name="bucketSize">Maximum element count of a terminal node; must be at least 1.</param>
        /// <param name="balanced">True for median splits, false for bounds-midpoint splits.</param>
        /// <exception cref="ArgumentNullException"><paramref name="points"/> or <paramref name="elementTransformer"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="bucketSize"/> is less than 1.</exception>
        public KDTree(
            IEnumerable<T> points,
            Func<T, Vector2> elementTransformer,
            int bucketSize = DefaultBucketSize,
            bool balanced = true
        )
        {
            _elementTransformer =
                elementTransformer ?? throw new ArgumentNullException(nameof(elementTransformer));
            elements = points?.ToImmutableArray() ?? throw new ArgumentNullException(nameof(points));
            if (bucketSize < 1)
            {
                throw new ArgumentOutOfRangeException(
                    nameof(bucketSize),
                    bucketSize,
                    "Bucket size must be at least 1."
                );
            }

            _bounds = elements.Select(elementTransformer).GetBounds() ?? new Bounds();
            _head = new KDTreeNode<T>(
                elements.ToList(),
                elementTransformer,
                bucketSize: bucketSize,
                isXAxis: true,
                balanced: balanced
            );
        }

        /// <summary>
        /// Lazily yields the elements whose positions lie within <paramref name="bounds"/>, using the
        /// node stack from <c>Buffers&lt;KDTreeNode&lt;T&gt;&gt;.Stack</c>.
        /// </summary>
        /// <remarks>
        /// NOTE(review): if that buffer is a shared instance, two enumerations from this overload
        /// that are interleaved (e.g. nested loops) would clear each other's traversal state —
        /// confirm, and pass separate buffers to the other overload in that case.
        /// </remarks>
        public IEnumerable<T> GetElementsInBounds(Bounds bounds)
        {
            Stack<KDTreeNode<T>> buffer = Buffers<KDTreeNode<T>>.Stack;
            return GetElementsInBounds(bounds, buffer);
        }

        /// <summary>Lazily yields the elements whose positions lie within <paramref name="bounds"/>.</summary>
        /// <param name="bounds">Query region (only X/Y are considered by the 2D helpers).</param>
        /// <param name="nodeBuffer">
        /// Scratch stack for the traversal; cleared when enumeration starts. A new stack is
        /// allocated when null.
        /// </param>
        public IEnumerable<T> GetElementsInBounds(Bounds bounds, Stack<KDTreeNode<T>> nodeBuffer)
        {
            if (!bounds.FastIntersects2D(_bounds))
            {
                yield break;
            }

            Stack<KDTreeNode<T>> nodesToVisit = nodeBuffer ?? new Stack<KDTreeNode<T>>();
            nodesToVisit.Clear();
            nodesToVisit.Push(_head);

            while (nodesToVisit.TryPop(out KDTreeNode<T> currentNode))
            {
                if (currentNode.isTerminal)
                {
                    // Leaf: test every element individually.
                    foreach (T element in currentNode.elements)
                    {
                        if (bounds.FastContains2D(_elementTransformer(element)))
                        {
                            yield return element;
                        }
                    }

                    continue;
                }

                // Presumably Overlaps2D is true when `bounds` fully covers the node's boundary, so the
                // whole subtree can be emitted without per-element tests — confirm against its definition.
                if (bounds.Overlaps2D(currentNode.boundary))
                {
                    foreach (T element in currentNode.elements)
                    {
                        yield return element;
                    }

                    continue;
                }

                KDTreeNode<T> leftNode = currentNode.left;
                if (0 < leftNode.elements.Length && bounds.FastIntersects2D(leftNode.boundary))
                {
                    nodesToVisit.Push(leftNode);
                }

                KDTreeNode<T> rightNode = currentNode.right;
                if (0 < rightNode.elements.Length && bounds.FastIntersects2D(rightNode.boundary))
                {
                    nodesToVisit.Push(rightNode);
                }
            }
        }

        /// <summary>
        /// Fills <paramref name="nearestNeighbors"/> with approximately the <paramref name="count"/>
        /// elements nearest to <paramref name="position"/>, using scratch buffers from
        /// <c>Buffers&lt;...&gt;</c>.
        /// </summary>
        public void GetApproximateNearestNeighbors(
            Vector2 position,
            int count,
            List<T> nearestNeighbors
        )
        {
            Stack<KDTreeNode<T>> nodeBuffer = Buffers<KDTreeNode<T>>.Stack;
            HashSet<T> nearestNeighborBuffer = Buffers<T>.HashSet;
            GetApproximateNearestNeighbors(
                position,
                count,
                nearestNeighbors,
                nodeBuffer,
                nearestNeighborBuffer
            );
        }

        /// <summary>
        /// Fills <paramref name="nearestNeighbors"/> with approximately the <paramref name="count"/>
        /// elements nearest to <paramref name="position"/>.
        /// </summary>
        /// <remarks>
        /// Heavily adapted from http://homepage.divms.uiowa.edu/%7Ekvaradar/sp2012/daa/ann.pdf.
        /// Descends greedily toward the child whose bounds center is closer to
        /// <paramref name="position"/>, stopping at a leaf or at the first node holding at most
        /// <paramref name="count"/> elements. It then unwinds that path, deepest node first, pooling
        /// elements until at least <paramref name="count"/> distinct ones are collected. Elements in
        /// subtrees off the descent path are only reached through ancestors, so true nearest
        /// neighbors can be missed. The result is sorted by distance only when trimming was needed;
        /// otherwise it is in the hash set's enumeration order.
        /// </remarks>
        /// <param name="position">Query point.</param>
        /// <param name="count">Maximum number of neighbors to return; values &lt;= 0 yield an empty result.</param>
        /// <param name="nearestNeighbors">Output list; cleared first.</param>
        /// <param name="nodeBuffer">Scratch stack; allocated when null.</param>
        /// <param name="nearestNeighborBuffer">Scratch set used to deduplicate candidates; allocated when null.</param>
        /// <exception cref="ArgumentNullException"><paramref name="nearestNeighbors"/> is null.</exception>
        public void GetApproximateNearestNeighbors(
            Vector2 position,
            int count,
            List<T> nearestNeighbors,
            Stack<KDTreeNode<T>> nodeBuffer,
            HashSet<T> nearestNeighborBuffer
        )
        {
            if (nearestNeighbors == null)
            {
                throw new ArgumentNullException(nameof(nearestNeighbors));
            }

            nearestNeighbors.Clear();
            if (count <= 0)
            {
                return;
            }

            Stack<KDTreeNode<T>> stack = nodeBuffer ?? new Stack<KDTreeNode<T>>();
            stack.Clear();
            stack.Push(_head);
            HashSet<T> nearestNeighborsSet = nearestNeighborBuffer ?? new HashSet<T>(count);
            nearestNeighborsSet.Clear();

            // Greedy descent toward the child whose bounds center is nearer, recording the path.
            KDTreeNode<T> current = _head;
            while (!current.isTerminal)
            {
                KDTreeNode<T> left = current.left;
                KDTreeNode<T> right = current.right;
                KDTreeNode<T> closer =
                    ((Vector2)left.boundary.center - position).sqrMagnitude
                    < ((Vector2)right.boundary.center - position).sqrMagnitude
                        ? left
                        : right;

                stack.Push(closer);
                current = closer;
                if (closer.elements.Length <= count)
                {
                    break;
                }
            }

            // Unwind from the deepest node upward. Ancestors contain their descendants' elements, so
            // each pop widens the candidate pool; the set discards repeats.
            while (nearestNeighborsSet.Count < count && stack.TryPop(out KDTreeNode<T> selected))
            {
                foreach (T element in selected.elements)
                {
                    _ = nearestNeighborsSet.Add(element);
                }
            }

            nearestNeighbors.AddRange(nearestNeighborsSet);

            if (count < nearestNeighbors.Count)
            {
                // Keep only the `count` candidates closest to the query point.
                nearestNeighbors.Sort(
                    (lhs, rhs) =>
                        (_elementTransformer(lhs) - position).sqrMagnitude.CompareTo(
                            (_elementTransformer(rhs) - position).sqrMagnitude
                        )
                );
                nearestNeighbors.RemoveRange(count, nearestNeighbors.Count - count);
            }
        }
    }
}