// MIT License - Copyright (c) 2025 wallstop
// Full license text: https://github.com/wallstop/unity-helpers/blob/main/LICENSE

namespace WallstopStudios.UnityHelpers.Core.DataStructure
{
    using System;
    using System.Buffers;
    using System.Collections.Generic;
    using System.Collections.Immutable;
    using Extension;
    using UnityEngine;
    using Utils;

    /// <summary>
    /// Immutable 2D k-d tree for efficient nearest neighbor, range, and bounds queries.
    /// </summary>
    /// <example>
    /// <code><![CDATA[
    /// KdTree2D<Vector2> tree = new KdTree2D<Vector2>(points, p => p);
    /// List<Vector2> neighbors = new List<Vector2>();
    /// tree.GetElementsInRange(queryPosition, 3f, neighbors);
    /// ]]></code>
    /// </example>
    /// <typeparam name="T">Element type contained in the tree.</typeparam>
    /// <remarks>
    /// Pros: Very fast nearest neighbor performance; good for static or batched updates.
    /// Cons: Immutable structure by design; rebuild when positions change frequently.
    /// Semantics: For identical input data and queries, KdTree2D (balanced or unbalanced)
    /// returns the same set of results as QuadTree2D; they differ only in performance characteristics.
    /// </remarks>
    [Serializable]
    public sealed class KdTree2D<T> : ISpatialTree2D<T>
    {
        /// <summary>
        /// An element paired with the 2D position it was indexed under.
        /// </summary>
        [Serializable]
        public readonly struct Entry
        {
            public readonly T value;
            public readonly Vector2 position;

            public Entry(T value, Vector2 position)
            {
                this.value = value;
                this.position = position;
            }
        }

        /// <summary>
        /// Candidate produced by the approximate nearest neighbor search: an element and its
        /// squared distance to the query position.
        /// </summary>
        private readonly struct Neighbor
        {
            public readonly T value;
            public readonly float sqrDistance;

            public Neighbor(T value, float sqrDistance)
            {
                this.value = value;
                this.sqrDistance = sqrDistance;
            }
        }

        /// <summary>
        /// A node of the tree. Every node (leaf or internal) covers the contiguous slice
        /// <c>[_startIndex, _startIndex + _count)</c> of the tree's index array, so an internal
        /// node's slice is the concatenation of its children's slices.
        /// </summary>
        [Serializable]
        public sealed class KdTreeNode
        {
            public readonly Bounds boundary;
            public readonly KdTreeNode left;
            public readonly KdTreeNode right;
            internal readonly int _startIndex;
            internal readonly int _count;
            public readonly bool isTerminal;

            private KdTreeNode(
                Bounds boundary,
                KdTreeNode left,
                KdTreeNode right,
                int startIndex,
                int count,
                bool isTerminal
            )
            {
                this.boundary = boundary;
                this.left = left;
                this.right = right;
                _startIndex = startIndex;
                _count = count;
                this.isTerminal = isTerminal;
            }

            /// <summary>
            /// Creates a terminal node (no children) covering <paramref name="count"/> indices
            /// starting at <paramref name="startIndex"/>.
            /// </summary>
            internal static KdTreeNode CreateLeaf(Bounds boundary, int startIndex, int count)
            {
                return new KdTreeNode(boundary, null, null, startIndex, count, true);
            }

            /// <summary>
            /// Creates a non-terminal node with the given children.
            /// </summary>
            internal static KdTreeNode CreateInternal(
                Bounds boundary,
                KdTreeNode left,
                KdTreeNode right,
                int startIndex,
                int count
            )
            {
                return new KdTreeNode(boundary, left, right, startIndex, count, false);
            }
        }

        // Lower bound applied to the x/y size of every node's bounds.
        private const float MinimumNodeSize = 0.001f;

        // Sub-ranges at or below this width are insertion-sorted instead of partitioned further.
        private const int SmallPartitionThreshold = 32;

        /// <summary>
        /// Default bucket size for leaves before stopping recursion.
        /// </summary>
        public const int DefaultBucketSize = 12;

        public readonly ImmutableArray<T> elements;

        /// <summary>
        /// Gets the overall bounding box of the tree.
        /// </summary>
        public Bounds Boundary => _bounds;

        private readonly Bounds _bounds;
        private readonly Entry[] _entries;
        private readonly int[] _indices;
        private readonly KdTreeNode _head;
        private readonly bool _balanced;
        private readonly int _bucketSize;

        /// <summary>
        /// Builds a 2D k-d tree from elements using a transformer to extract 2D positions.
        /// </summary>
        /// <param name="points">Source elements.</param>
        /// <param name="elementTransformer">Maps element to its 2D position.</param>
        /// <param name="bucketSize">Max elements per leaf. Minimum 1.</param>
        /// <param name="balanced">If true, builds a balanced tree by median selection; otherwise uses a quick split strategy.</param>
        /// <exception cref="ArgumentNullException">Thrown when points or elementTransformer are null.</exception>
        public KdTree2D(
            IEnumerable<T> points,
            Func<T, Vector2> elementTransformer,
            int bucketSize = DefaultBucketSize,
            bool balanced = true
        )
        {
            if (elementTransformer is null)
            {
                throw new ArgumentNullException(nameof(elementTransformer));
            }

            elements =
                points?.ToImmutableArray() ?? throw new ArgumentNullException(nameof(points));

            int elementCount = elements.Length;
            _balanced = balanced;
            _bucketSize = Math.Max(1, bucketSize);

            if (elementCount == 0)
            {
                // Empty tree: a single empty leaf with default (zero) bounds.
                _entries = Array.Empty<Entry>();
                _indices = Array.Empty<int>();
                _bounds = new Bounds();
                _head = KdTreeNode.CreateLeaf(_bounds, 0, 0);
                return;
            }

            _entries = new Entry[elementCount];
            _indices = new int[elementCount];
            for (int i = 0; i < elementCount; ++i)
            {
                T element = elements[i];
                _entries[i] = new Entry(element, elementTransformer(element));
                _indices[i] = i;
            }

            KdTreeNode root;
            if (_balanced)
            {
                root = BuildBalanced(0, elementCount, depth: 0);
            }
            else
            {
                int[] scratch = ArrayPool<int>.Shared.Rent(elementCount);
                try
                {
                    root = BuildUnbalanced(0, elementCount, splitOnXAxis: true, scratch);
                }
                finally
                {
                    ArrayPool<int>.Shared.Return(scratch, clearArray: true);
                }
            }

            // The tree's bounds are whatever the build produced for the root node.
            _head = root;
            _bounds = root.boundary;
        }

        /// <summary>
        /// Recursively builds a balanced subtree over <c>_indices[startIndex, startIndex + count)</c>,
        /// splitting at the median and alternating the split axis by depth (even = x, odd = y).
        /// </summary>
        private KdTreeNode BuildBalanced(int startIndex, int count, int depth)
        {
            if (count <= _bucketSize)
            {
                Bounds leafBounds = CalculateLeafBounds(startIndex, count);
                return KdTreeNode.CreateLeaf(leafBounds, startIndex, count);
            }

            // _bucketSize >= 1, so count >= 2 here and both halves are guaranteed non-empty.
            int axis = (depth & 1) == 0 ? 0 : 1;
            int leftCount = count / 2;
            int rightCount = count - leftCount;

            SelectKth(_indices.AsSpan(startIndex, count), leftCount, axis);

            KdTreeNode left = BuildBalanced(startIndex, leftCount, depth + 1);
            KdTreeNode right = BuildBalanced(startIndex + leftCount, rightCount, depth + 1);
            Bounds boundary = CombineChildBounds(left.boundary, right.boundary);
            return KdTreeNode.CreateInternal(boundary, left, right, startIndex, count);
        }

        /// <summary>
        /// Recursively builds a subtree over <c>_indices[startIndex, startIndex + count)</c> by
        /// splitting at the center of the node's bounding box along the requested axis.
        /// </summary>
        /// <remarks>
        /// If every point lands on one side of the cut (for example, all points share the same
        /// coordinate on the split axis), the other axis is tried before the node is emitted as
        /// an oversized leaf.
        /// </remarks>
        /// <param name="startIndex">First position in the index array covered by this node.</param>
        /// <param name="count">Number of indices covered by this node.</param>
        /// <param name="splitOnXAxis">Preferred split axis for this node.</param>
        /// <param name="scratch">Temporary buffer with at least <paramref name="count"/> slots.</param>
        private KdTreeNode BuildUnbalanced(
            int startIndex,
            int count,
            bool splitOnXAxis,
            int[] scratch
        )
        {
            Span<int> source = _indices.AsSpan(startIndex, count);
            Span<int> temp = scratch.AsSpan(0, count);
            Entry[] entries = _entries;

            float minX = float.PositiveInfinity;
            float minY = float.PositiveInfinity;
            float maxX = float.NegativeInfinity;
            float maxY = float.NegativeInfinity;
            for (int i = 0; i < count; ++i)
            {
                Vector2 position = entries[source[i]].position;
                if (position.x < minX)
                {
                    minX = position.x;
                }
                if (position.y < minY)
                {
                    minY = position.y;
                }
                if (position.x > maxX)
                {
                    maxX = position.x;
                }
                if (position.y > maxY)
                {
                    maxY = position.y;
                }
            }

            Bounds nodeBounds = CreateBounds(minX, maxX, minY, maxY);
            if (count <= _bucketSize)
            {
                return KdTreeNode.CreateLeaf(nodeBounds, startIndex, count);
            }

            Vector3 center = nodeBounds.center;
            bool useXAxis = splitOnXAxis;
            int leftCount = PartitionByCutoff(
                source,
                temp,
                entries,
                useXAxis ? center.x : center.y,
                useXAxis
            );
            if (leftCount == 0 || leftCount == count)
            {
                // Degenerate on the preferred axis; retry on the other one.
                useXAxis = !useXAxis;
                leftCount = PartitionByCutoff(
                    source,
                    temp,
                    entries,
                    useXAxis ? center.x : center.y,
                    useXAxis
                );
                if (leftCount == 0 || leftCount == count)
                {
                    // Neither axis separates the points: keep them together in one leaf.
                    return KdTreeNode.CreateLeaf(nodeBounds, startIndex, count);
                }
            }

            int rightCount = count - leftCount;

            // Commit the partition before recursing; the children reuse the same scratch buffer.
            temp.CopyTo(source);

            KdTreeNode left = BuildUnbalanced(startIndex, leftCount, !useXAxis, scratch);
            KdTreeNode right = BuildUnbalanced(
                startIndex + leftCount,
                rightCount,
                !useXAxis,
                scratch
            );
            Bounds boundary = CombineChildBounds(left.boundary, right.boundary);
            return KdTreeNode.CreateInternal(boundary, left, right, startIndex, count);
        }

        /// <summary>
        /// Copies <paramref name="source"/> into <paramref name="destination"/> so that indices
        /// whose coordinate is &lt;= <paramref name="cutoff"/> fill the front (in source order)
        /// and the rest fill the back (in reverse source order).
        /// </summary>
        /// <returns>The number of indices placed in the front (left) part.</returns>
        private static int PartitionByCutoff(
            Span<int> source,
            Span<int> destination,
            Entry[] entries,
            float cutoff,
            bool splitOnXAxis
        )
        {
            int leftWrite = 0;
            int rightWrite = source.Length - 1;
            for (int i = 0; i < source.Length; ++i)
            {
                int entryIndex = source[i];
                Vector2 position = entries[entryIndex].position;
                float value = splitOnXAxis ? position.x : position.y;
                if (value <= cutoff)
                {
                    destination[leftWrite++] = entryIndex;
                }
                else
                {
                    destination[rightWrite--] = entryIndex;
                }
            }

            return leftWrite;
        }

        /// <summary>
        /// Computes the bounding box of the entries referenced by
        /// <c>_indices[startIndex, startIndex + count)</c>. Returns default bounds when
        /// <paramref name="count"/> is not positive.
        /// </summary>
        private Bounds CalculateLeafBounds(int startIndex, int count)
        {
            if (count <= 0)
            {
                return new Bounds();
            }

            Entry[] entries = _entries;
            int[] indices = _indices;
            float minX = float.PositiveInfinity;
            float minY = float.PositiveInfinity;
            float maxX = float.NegativeInfinity;
            float maxY = float.NegativeInfinity;
            int end = startIndex + count;
            for (int i = startIndex; i < end; ++i)
            {
                Vector2 position = entries[indices[i]].position;
                if (position.x < minX)
                {
                    minX = position.x;
                }
                if (position.y < minY)
                {
                    minY = position.y;
                }
                if (position.x > maxX)
                {
                    maxX = position.x;
                }
                if (position.y > maxY)
                {
                    maxY = position.y;
                }
            }

            return CreateBounds(minX, maxX, minY, maxY);
        }

        /// <summary>
        /// Quickselect: reorders <paramref name="span"/> so the entry at position
        /// <paramref name="k"/> is the one that would be there if the span were sorted by the
        /// chosen axis, with no larger value before it and no smaller value after it.
        /// </summary>
        /// <param name="span">Indices into the entry array to reorder in place.</param>
        /// <param name="k">Target position within <paramref name="span"/>.</param>
        /// <param name="axis">0 to compare x coordinates, anything else to compare y.</param>
        private void SelectKth(Span<int> span, int k, int axis)
        {
            Entry[] entries = _entries;
            int left = 0;
            int right = span.Length - 1;

            while (left < right)
            {
                if (right - left <= SmallPartitionThreshold)
                {
                    // The window still contains k, so fully sorting it finishes the selection.
                    InsertionSort(span.Slice(left, right - left + 1), axis, entries);
                    return;
                }

                int pivotIndex = SelectPivot(span, left, right, axis, entries);
                float pivot = GetAxis(entries[span[pivotIndex]], axis);
                int i = left;
                int j = right;

                // The two branches are the same two-pointer partition with the coordinate
                // access written out per axis.
                if (axis == 0)
                {
                    while (i <= j)
                    {
                        while (i <= j && entries[span[i]].position.x < pivot)
                        {
                            i++;
                        }
                        while (i <= j && entries[span[j]].position.x > pivot)
                        {
                            j--;
                        }
                        if (i <= j)
                        {
                            (span[i], span[j]) = (span[j], span[i]);
                            i++;
                            j--;
                        }
                    }
                }
                else
                {
                    while (i <= j)
                    {
                        while (i <= j && entries[span[i]].position.y < pivot)
                        {
                            i++;
                        }
                        while (i <= j && entries[span[j]].position.y > pivot)
                        {
                            j--;
                        }
                        if (i <= j)
                        {
                            (span[i], span[j]) = (span[j], span[i]);
                            i++;
                            j--;
                        }
                    }
                }

                if (k <= j)
                {
                    right = j;
                    continue;
                }
                if (k >= i)
                {
                    left = i;
                    continue;
                }

                // j < k < i: position k already holds a pivot-valued entry.
                return;
            }
        }

        /// <summary>
        /// Median-of-three pivot selection: orders the entries at <paramref name="left"/>, the
        /// midpoint, and <paramref name="right"/> by the chosen axis (swapping them in place)
        /// and returns the midpoint position.
        /// </summary>
        private static int SelectPivot(
            Span<int> span,
            int left,
            int right,
            int axis,
            Entry[] entries
        )
        {
            int mid = left + ((right - left) >> 1);
            float leftValue = GetAxis(entries[span[left]], axis);
            float midValue = GetAxis(entries[span[mid]], axis);
            float rightValue = GetAxis(entries[span[right]], axis);

            if (leftValue > midValue)
            {
                (span[left], span[mid]) = (span[mid], span[left]);
                (leftValue, midValue) = (midValue, leftValue);
            }

            if (midValue > rightValue)
            {
                (span[mid], span[right]) = (span[right], span[mid]);
                (midValue, rightValue) = (rightValue, midValue);
                if (leftValue > midValue)
                {
                    (span[left], span[mid]) = (span[mid], span[left]);
                    (leftValue, midValue) = (midValue, leftValue);
                }
            }

            return mid;
        }

        /// <summary>
        /// Sorts <paramref name="span"/> in place, ascending by the chosen axis of the entries
        /// it references.
        /// </summary>
        private static void InsertionSort(Span<int> span, int axis, Entry[] entries)
        {
            if (span.Length <= 1)
            {
                return;
            }

            if (axis == 0)
            {
                for (int i = 1; i < span.Length; ++i)
                {
                    int currentIndex = span[i];
                    float currentValue = entries[currentIndex].position.x;
                    int j = i - 1;
                    while (j >= 0 && entries[span[j]].position.x > currentValue)
                    {
                        span[j + 1] = span[j];
                        j--;
                    }
                    span[j + 1] = currentIndex;
                }
            }
            else
            {
                for (int i = 1; i < span.Length; ++i)
                {
                    int currentIndex = span[i];
                    float currentValue = entries[currentIndex].position.y;
                    int j = i - 1;
                    while (j >= 0 && entries[span[j]].position.y > currentValue)
                    {
                        span[j + 1] = span[j];
                        j--;
                    }
                    span[j + 1] = currentIndex;
                }
            }
        }

        /// <summary>
        /// Returns the entry's x coordinate when <paramref name="axis"/> is 0, otherwise its y.
        /// </summary>
        private static float GetAxis(in Entry entry, int axis) =>
            axis == 0 ? entry.position.x : entry.position.y;

        /// <summary>
        /// Returns the union of two child bounds, with the minimum node size enforced.
        /// </summary>
        private static Bounds CombineChildBounds(Bounds left, Bounds right)
        {
            Bounds combined = left;
            combined.Encapsulate(right);
            EnsureMinimumBounds(ref combined);
            return combined;
        }

        /// <summary>
        /// Creates bounds spanning the given extents (z size fixed at 1), with the minimum node
        /// size enforced. Returns default bounds when <paramref name="minX"/> or
        /// <paramref name="minY"/> is infinite, which is the state of the running minimums
        /// when no finite coordinate was seen.
        /// </summary>
        private static Bounds CreateBounds(float minX, float maxX, float minY, float maxY)
        {
            if (float.IsInfinity(minX) || float.IsInfinity(minY))
            {
                return new Bounds();
            }

            Vector3 min = new(minX, minY, 0f);
            Vector3 max = new(maxX, maxY, 0f);
            Vector3 center = (min + max) * 0.5f;
            Vector3 size = max - min;
            Bounds bounds = new(center, new Vector3(size.x, size.y, 1f));
            EnsureMinimumBounds(ref bounds);
            return bounds;
        }

        /// <summary>
        /// Raises the x/y size of <paramref name="bounds"/> to at least
        /// <see cref="MinimumNodeSize"/> and forces the z size to 1.
        /// </summary>
        private static void EnsureMinimumBounds(ref Bounds bounds)
        {
            Vector3 size = bounds.size;
            if (size.x < MinimumNodeSize)
            {
                size.x = MinimumNodeSize;
            }
            if (size.y < MinimumNodeSize)
            {
                size.y = MinimumNodeSize;
            }
            size.z = 1f;
            bounds.size = size;
        }

        /// <summary>
        /// Collects every element whose distance to <paramref name="position"/> is at most
        /// <paramref name="range"/> and, when <paramref name="minimumRange"/> is positive,
        /// strictly greater than <paramref name="minimumRange"/>.
        /// </summary>
        /// <param name="position">Center of the query.</param>
        /// <param name="range">Maximum distance. Negative values yield no results.</param>
        /// <param name="elementsInRange">Output list; cleared before results are added.</param>
        /// <param name="minimumRange">Exclusive inner radius; ignored unless greater than zero.</param>
        /// <returns>The <paramref name="elementsInRange"/> instance.</returns>
        public List<T> GetElementsInRange(
            Vector2 position,
            float range,
            List<T> elementsInRange,
            float minimumRange = 0
        )
        {
            elementsInRange.Clear();

            // Allow zero range to return only exact matches (distance == 0)
            if (range < 0f || _head._count <= 0)
            {
                return elementsInRange;
            }

            // Axis-aligned square enclosing the query circle, used for node pruning.
            Bounds bounds = new(position, new Vector3(range * 2f, range * 2f, 1f));
            if (!bounds.FastIntersects2D(_bounds))
            {
                return elementsInRange;
            }

            using PooledResource<Stack<KdTreeNode>> stackResource = Buffers<KdTreeNode>.Stack.Get(
                out Stack<KdTreeNode> nodesToVisit
            );
            nodesToVisit.Push(_head);

            Entry[] entries = _entries;
            int[] indices = _indices;
            float rangeSquared = range * range;
            bool hasMinimumRange = 0f < minimumRange;
            float minimumRangeSquared = minimumRange * minimumRange;

            while (nodesToVisit.TryPop(out KdTreeNode currentNode))
            {
                if (currentNode is null || currentNode._count <= 0)
                {
                    continue;
                }

                if (!bounds.FastIntersects2D(currentNode.boundary))
                {
                    continue;
                }

                // A node's index slice covers its whole subtree, so a node fully inside the
                // query square can be scanned directly without descending.
                if (currentNode.isTerminal || bounds.FastContains2D(currentNode.boundary))
                {
                    int start = currentNode._startIndex;
                    int end = start + currentNode._count;
                    for (int i = start; i < end; ++i)
                    {
                        Entry entry = entries[indices[i]];
                        float squareDistance = (entry.position - position).sqrMagnitude;
                        if (squareDistance > rangeSquared)
                        {
                            continue;
                        }

                        if (hasMinimumRange && squareDistance <= minimumRangeSquared)
                        {
                            continue;
                        }

                        elementsInRange.Add(entry.value);
                    }

                    continue;
                }

                KdTreeNode left = currentNode.left;
                if (left is not null && left._count > 0 && bounds.FastIntersects2D(left.boundary))
                {
                    nodesToVisit.Push(left);
                }

                KdTreeNode right = currentNode.right;
                if (
                    right is not null
                    && right._count > 0
                    && bounds.FastIntersects2D(right.boundary)
                )
                {
                    nodesToVisit.Push(right);
                }
            }

            return elementsInRange;
        }

        /// <summary>
        /// Collects every element whose position lies inside <paramref name="bounds"/>, as
        /// decided by the <c>FastContains2D</c> / <c>FastIntersects2D</c> extension methods.
        /// </summary>
        /// <param name="bounds">Query region.</param>
        /// <param name="elementsInBounds">Output list; cleared before results are added.</param>
        /// <returns>The <paramref name="elementsInBounds"/> instance.</returns>
        public List<T> GetElementsInBounds(Bounds bounds, List<T> elementsInBounds)
        {
            elementsInBounds.Clear();
            if (_head._count <= 0 || !bounds.FastIntersects2D(_bounds))
            {
                return elementsInBounds;
            }

            using PooledResource<Stack<KdTreeNode>> stackResource = Buffers<KdTreeNode>.Stack.Get(
                out Stack<KdTreeNode> nodesToVisit
            );
            nodesToVisit.Push(_head);

            Entry[] entries = _entries;
            int[] indices = _indices;

            while (nodesToVisit.TryPop(out KdTreeNode currentNode))
            {
                if (currentNode is null || currentNode._count <= 0)
                {
                    continue;
                }

                // Node entirely inside the query: take its whole slice without per-point tests.
                if (bounds.FastContains2D(currentNode.boundary))
                {
                    int start = currentNode._startIndex;
                    int end = start + currentNode._count;
                    for (int i = start; i < end; ++i)
                    {
                        elementsInBounds.Add(entries[indices[i]].value);
                    }

                    continue;
                }

                // Partially overlapping leaf: test each point individually.
                if (currentNode.isTerminal)
                {
                    int start = currentNode._startIndex;
                    int end = start + currentNode._count;
                    for (int i = start; i < end; ++i)
                    {
                        Entry entry = entries[indices[i]];
                        if (bounds.FastContains2D(entry.position))
                        {
                            elementsInBounds.Add(entry.value);
                        }
                    }

                    continue;
                }

                KdTreeNode left = currentNode.left;
                if (left is not null && left._count > 0 && bounds.FastIntersects2D(left.boundary))
                {
                    nodesToVisit.Push(left);
                }

                KdTreeNode right = currentNode.right;
                if (
                    right is not null
                    && right._count > 0
                    && bounds.FastIntersects2D(right.boundary)
                )
                {
                    nodesToVisit.Push(right);
                }
            }

            return elementsInBounds;
        }

        /// <summary>
        /// Collects up to <paramref name="count"/> elements near <paramref name="position"/>.
        /// </summary>
        /// <remarks>
        /// The search descends from the root towards the child whose bounds center is closer to
        /// the query, stopping at a leaf or at the first node holding no more than
        /// <paramref name="count"/> entries. It then gathers entries from that node and, if
        /// needed, from its ancestors. When more than <paramref name="count"/> candidates were
        /// gathered they are sorted by distance and trimmed; otherwise they are returned in
        /// gathering order. Candidates are de-duplicated through a hash set of element values,
        /// so elements that compare equal to each other are returned at most once.
        /// </remarks>
        /// <param name="position">Query position.</param>
        /// <param name="count">Maximum number of results. Non-positive values yield no results.</param>
        /// <param name="nearestNeighbors">Output list; cleared before results are added.</param>
        /// <returns>The <paramref name="nearestNeighbors"/> instance.</returns>
        public List<T> GetApproximateNearestNeighbors(
            Vector2 position,
            int count,
            List<T> nearestNeighbors
        )
        {
            nearestNeighbors.Clear();

            if (count <= 0 || _head._count == 0)
            {
                return nearestNeighbors;
            }

            using PooledResource<Stack<KdTreeNode>> nodeBufferResource =
                Buffers<KdTreeNode>.Stack.Get(out Stack<KdTreeNode> nodeBuffer);
            nodeBuffer.Push(_head);
            using PooledResource<HashSet<T>> nearestNeighborBufferResource = Buffers<T>.HashSet.Get(
                out HashSet<T> nearestNeighborBuffer
            );
            using PooledResource<List<Neighbor>> neighborCandidatesResource =
                Buffers<Neighbor>.List.Get(out List<Neighbor> neighborCandidates);

            Entry[] entries = _entries;
            int[] indices = _indices;

            // Descend towards the query, recording the path so ancestors can be revisited.
            KdTreeNode current = _head;
            while (!current.isTerminal)
            {
                KdTreeNode left = current.left;
                KdTreeNode right = current.right;
                bool hasLeft = left is not null && left._count != 0;
                bool hasRight = right is not null && right._count != 0;

                KdTreeNode next;
                if (!hasLeft)
                {
                    if (!hasRight)
                    {
                        break;
                    }
                    next = right;
                }
                else if (!hasRight)
                {
                    next = left;
                }
                else
                {
                    float leftDistance = ((Vector2)left.boundary.center - position).sqrMagnitude;
                    float rightDistance = ((Vector2)right.boundary.center - position).sqrMagnitude;
                    next = leftDistance < rightDistance ? left : right;
                }

                nodeBuffer.Push(next);
                current = next;
                if (next._count <= count)
                {
                    break;
                }
            }

            // Pop deepest-first. Each ancestor's slice includes its descendants' entries, so
            // the hash set filters out values that were already collected.
            while (
                nearestNeighborBuffer.Count < count
                && nodeBuffer.TryPop(out KdTreeNode selected)
            )
            {
                if (selected is null || selected._count <= 0)
                {
                    continue;
                }

                int startIndex = selected._startIndex;
                int endIndex = startIndex + selected._count;
                for (int i = startIndex; i < endIndex; ++i)
                {
                    Entry entry = entries[indices[i]];
                    if (!nearestNeighborBuffer.Add(entry.value))
                    {
                        continue;
                    }

                    float sqrDistance = (entry.position - position).sqrMagnitude;
                    neighborCandidates.Add(new Neighbor(entry.value, sqrDistance));
                }
            }

            if (count < neighborCandidates.Count)
            {
                neighborCandidates.Sort(NeighborComparer.Instance);
                neighborCandidates.RemoveRange(count, neighborCandidates.Count - count);
            }

            // At this point the candidate list never holds more than `count` items.
            for (int i = 0; i < neighborCandidates.Count; ++i)
            {
                nearestNeighbors.Add(neighborCandidates[i].value);
            }

            return nearestNeighbors;
        }

        /// <summary>
        /// Orders <see cref="Neighbor"/> values by ascending squared distance.
        /// </summary>
        private sealed class NeighborComparer : IComparer<Neighbor>
        {
            internal static readonly NeighborComparer Instance = new();

            public int Compare(Neighbor x, Neighbor y)
            {
                return x.sqrDistance.CompareTo(y.sqrDistance);
            }
        }
    }
}