// MIT License - Copyright (c) 2025 wallstop
// Full license text: https://github.com/wallstop/unity-helpers/blob/main/LICENSE
namespace WallstopStudios.UnityHelpers.Core.DataStructure
{
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Runtime.CompilerServices;
using UnityEngine;
using Utils;
///
/// Immutable 3D R-tree for efficient spatial indexing of 3D bounds.
///
///
/// .Entry[] entries = volumes.Select(v => new RTree3D.Entry(v, v.Bounds)).ToArray();
/// RTree3D tree = RTree3D.Build(entries);
/// List overlaps = new List();
/// tree.GetElementsInRange(origin, 8f, overlaps);
/// ]]>
///
/// Element type.
///
/// Pros: Great for sized 3D objects (meshes, volumes) with fast box and radius intersection queries.
/// Cons: Immutable; rebuild when element bounds change.
/// Semantics: RTree3D indexes 3D bounds (AABBs), not points, and aggregates at node level using bounding volumes.
/// As such, results differ by design from point-based structures like KdTree3D/OctTree3D for the same scene.
///
[Serializable]
public sealed class RTree3D : ISpatialTree3D
{
internal const float MinimumNodeSize = 0.001f;
[Serializable]
internal struct ElementData
{
internal T _value;
internal BoundingBox3D _bounds;
internal Vector3 _center;
internal ulong _sortKey;
}
[Serializable]
public sealed class RTreeNode
{
public readonly BoundingBox3D boundary;
internal readonly RTreeNode[] _children;
internal readonly int _startIndex;
internal readonly int _count;
public readonly bool isTerminal;
private RTreeNode(
int startIndex,
int count,
BoundingBox3D boundary,
RTreeNode[] children
)
{
_startIndex = startIndex;
_count = count;
this.boundary = boundary;
_children = children ?? Array.Empty();
isTerminal = _children.Length == 0;
}
internal static RTreeNode CreateEmpty()
{
return new RTreeNode(0, 0, BoundingBox3D.Empty, Array.Empty());
}
internal static RTreeNode CreateLeaf(ElementData[] elements, int startIndex, int count)
{
BoundingBox3D nodeBounds = CalculateBounds(elements, startIndex, count);
return new RTreeNode(startIndex, count, nodeBounds, Array.Empty());
}
internal static RTreeNode CreateInternal(RTreeNode[] children)
{
if (children.Length == 0)
{
return CreateEmpty();
}
int startIndex = children[0]._startIndex;
int lastChildIndex = children.Length - 1;
RTreeNode lastChild = children[lastChildIndex];
int endIndex = lastChild._startIndex + lastChild._count;
BoundingBox3D nodeBounds = children[0].boundary;
for (int i = 1; i < children.Length; ++i)
{
nodeBounds = nodeBounds.ExpandToInclude(children[i].boundary);
}
nodeBounds = EnsureMinimumBounds(nodeBounds);
return new RTreeNode(startIndex, endIndex - startIndex, nodeBounds, children);
}
}
private readonly struct NodeDistance
{
internal readonly RTreeNode _node;
internal readonly float _distanceSquared;
internal NodeDistance(RTreeNode node, float distanceSquared)
{
_node = node;
_distanceSquared = distanceSquared;
}
}
private sealed class CandidateComparer : IComparer<(int index, float distanceSquared)>
{
internal static readonly CandidateComparer Instance = new();
public int Compare(
(int index, float distanceSquared) x,
(int index, float distanceSquared) y
)
{
return x.distanceSquared.CompareTo(y.distanceSquared);
}
}
/// Default number of elements per leaf node.
public const int DefaultBucketSize = 10;
public const int DefaultBranchFactor = 4;
public readonly ImmutableArray elements;
///
/// Gets the overall bounding box of the tree (as Unity Bounds).
///
public Bounds Boundary => _bounds.ToBounds();
private readonly BoundingBox3D _bounds;
private readonly ElementData[] _elementData;
private readonly RTreeNode _head;
///
/// Builds an R-Tree from elements using a transformer that returns each element's 3D bounds.
///
/// Source elements.
/// Maps element to an axis-aligned bounding box in world space.
/// Max elements per leaf.
/// Approximate number of children per internal node (≥2).
/// Thrown when points or elementTransformer are null.
public RTree3D(
IEnumerable points,
Func elementTransformer,
int bucketSize = DefaultBucketSize,
int branchFactor = DefaultBranchFactor
)
{
elements =
points?.ToImmutableArray() ?? throw new ArgumentNullException(nameof(points));
Func transformer =
elementTransformer ?? throw new ArgumentNullException(nameof(elementTransformer));
int elementCount = elements.Length;
_elementData = new ElementData[elementCount];
ElementData[] elementData = _elementData;
bucketSize = Math.Max(1, bucketSize);
branchFactor = Math.Max(2, branchFactor);
float minX = float.MaxValue;
float minY = float.MaxValue;
float minZ = float.MaxValue;
float maxX = float.MinValue;
float maxY = float.MinValue;
float maxZ = float.MinValue;
bool hasElements = false;
for (int i = 0; i < elementCount; ++i)
{
T element = elements[i];
Bounds elementBounds = transformer(element);
BoundingBox3D elementBox = BoundingBox3D.FromClosedBounds(elementBounds);
ElementData data = default;
data._value = element;
data._bounds = elementBox;
data._center = elementBox.Center;
elementData[i] = data;
Vector3 min = elementBox.min;
Vector3 max = elementBox.max;
if (!hasElements)
{
hasElements = true;
}
if (min.x < minX)
{
minX = min.x;
}
if (min.y < minY)
{
minY = min.y;
}
if (min.z < minZ)
{
minZ = min.z;
}
if (max.x > maxX)
{
maxX = max.x;
}
if (max.y > maxY)
{
maxY = max.y;
}
if (max.z > maxZ)
{
maxZ = max.z;
}
}
BoundingBox3D bounds = hasElements
? new BoundingBox3D(new Vector3(minX, minY, minZ), new Vector3(maxX, maxY, maxZ))
: BoundingBox3D.Empty;
if (hasElements)
{
bounds = bounds.EnsureMinimumSize(MinimumNodeSize);
}
_bounds = bounds;
if (!hasElements)
{
_head = RTreeNode.CreateEmpty();
return;
}
float rangeX = maxX - minX;
float rangeY = maxY - minY;
float rangeZ = maxZ - minZ;
float inverseRangeX = rangeX > float.Epsilon ? 1f / rangeX : 0f;
float inverseRangeY = rangeY > float.Epsilon ? 1f / rangeY : 0f;
float inverseRangeZ = rangeZ > float.Epsilon ? 1f / rangeZ : 0f;
for (int i = 0; i < elementCount; ++i)
{
ref ElementData data = ref elementData[i];
Vector3 center = data._center;
float normalizedX = (center.x - minX) * inverseRangeX;
float normalizedY = (center.y - minY) * inverseRangeY;
float normalizedZ = (center.z - minZ) * inverseRangeZ;
ushort quantizedX = QuantizeNormalized(normalizedX);
ushort quantizedY = QuantizeNormalized(normalizedY);
ushort quantizedZ = QuantizeNormalized(normalizedZ);
uint mortonKey = EncodeMorton(quantizedX, quantizedY, quantizedZ);
data._sortKey = ComposeSortKey(mortonKey, quantizedX, quantizedY, quantizedZ);
}
if (elementCount > 1)
{
RadixSort(elementData, elementCount);
}
using PooledResource> currentLevelResource =
Buffers.List.Get(out List currentLevel);
for (int startIndex = 0; startIndex < elementCount; startIndex += bucketSize)
{
int count = Math.Min(bucketSize, elementCount - startIndex);
currentLevel.Add(RTreeNode.CreateLeaf(elementData, startIndex, count));
}
while (currentLevel.Count > 1)
{
using PooledResource> nextLevelResource =
Buffers.List.Get(out List nextLevel);
for (int i = 0; i < currentLevel.Count; i += branchFactor)
{
int childCount = Math.Min(branchFactor, currentLevel.Count - i);
RTreeNode[] children = new RTreeNode[childCount];
currentLevel.CopyTo(i, children, 0, childCount);
nextLevel.Add(RTreeNode.CreateInternal(children));
}
currentLevel.Clear();
currentLevel.AddRange(nextLevel);
}
RTreeNode head = currentLevel.Count > 0 ? currentLevel[0] : RTreeNode.CreateEmpty();
_head = head;
_bounds = _head.boundary;
}
private void CollectElementIndicesInBounds(BoundingBox3D bounds, List indices)
{
indices.Clear();
if (_head._count == 0)
{
return;
}
if (!bounds.Intersects(_bounds))
{
return;
}
using PooledResource> nodeBufferResource =
Buffers.Stack.Get(out Stack nodesToVisit);
nodesToVisit.Push(_head);
while (nodesToVisit.TryPop(out RTreeNode currentNode))
{
if (!bounds.Intersects(currentNode.boundary))
{
continue;
}
if (currentNode.isTerminal)
{
int start = currentNode._startIndex;
int end = start + currentNode._count;
for (int i = start; i < end; ++i)
{
ElementData elementData = _elementData[i];
if (bounds.Intersects(elementData._bounds))
{
indices.Add(i);
}
}
continue;
}
RTreeNode[] childNodes = currentNode._children;
foreach (RTreeNode child in childNodes)
{
if (child._count <= 0)
{
continue;
}
if (!bounds.Intersects(child.boundary))
{
continue;
}
nodesToVisit.Push(child);
}
}
}
public List GetElementsInRange(
Vector3 position,
float range,
List elementsInRange,
float minimumRange = 0f
)
{
elementsInRange.Clear();
if (range < 0f)
{
return elementsInRange;
}
BoundingBox3D queryBounds = BoundingBox3D.FromCenterAndSize(
position,
new Vector3(range * 2f, range * 2f, range * 2f)
);
if (!queryBounds.Intersects(_bounds))
{
return elementsInRange;
}
using PooledResource> candidateIndicesResource = Buffers.List.Get(
out List candidateIndices
);
CollectElementIndicesInBounds(queryBounds, candidateIndices);
if (candidateIndices.Count == 0)
{
return elementsInRange;
}
Sphere area = new(position, range);
bool hasMinimumRange = 0f < minimumRange;
Sphere minimumArea = default;
if (hasMinimumRange)
{
minimumArea = new Sphere(position, minimumRange);
}
foreach (int index in candidateIndices)
{
ElementData elementData = _elementData[index];
BoundingBox3D elementBoundary = elementData._bounds;
if (!area.Intersects(elementBoundary))
{
continue;
}
if (hasMinimumRange && minimumArea.Intersects(elementBoundary))
{
continue;
}
elementsInRange.Add(elementData._value);
}
return elementsInRange;
}
public List GetElementsInBounds(Bounds bounds, List elementsInBounds)
{
elementsInBounds.Clear();
BoundingBox3D queryBounds = BoundingBox3D.FromClosedBounds(bounds);
if (!queryBounds.Intersects(_bounds))
{
return elementsInBounds;
}
using PooledResource> indicesResource = Buffers.List.Get(
out List indices
);
CollectElementIndicesInBounds(queryBounds, indices);
foreach (int index in indices)
{
ElementData elementData = _elementData[index];
if (!queryBounds.Contains(elementData._center))
{
continue;
}
elementsInBounds.Add(elementData._value);
}
return elementsInBounds;
}
public List GetApproximateNearestNeighbors(
Vector3 position,
int count,
List nearestNeighbors
)
{
nearestNeighbors.Clear();
if (count <= 0 || _head._count == 0)
{
return nearestNeighbors;
}
using PooledResource> nodeHeapResource =
Buffers.List.Get(out List nodeHeap);
PushNode(nodeHeap, _head, position);
using PooledResource> nearestNeighborBufferResource = Buffers.HashSet.Get(
out HashSet nearestNeighborsSet
);
using PooledResource> candidateBufferResource =
Buffers<(int index, float distanceSquared)>.List.Get(
out List<(int index, float distanceSquared)> candidates
);
float currentWorstDistanceSquared = float.PositiveInfinity;
while (nodeHeap.Count > 0)
{
NodeDistance best = PopNode(nodeHeap);
if (
candidates.Count >= count
&& best._distanceSquared >= currentWorstDistanceSquared
)
{
break;
}
RTreeNode currentNode = best._node;
if (!currentNode.isTerminal)
{
RTreeNode[] childNodes = currentNode._children;
for (int i = 0; i < childNodes.Length; ++i)
{
RTreeNode child = childNodes[i];
if (child._count > 0)
{
PushNode(nodeHeap, child, position);
}
}
continue;
}
int startIndex = currentNode._startIndex;
int endIndex = startIndex + currentNode._count;
for (int i = startIndex; i < endIndex; ++i)
{
ElementData elementData = _elementData[i];
T value = elementData._value;
if (!nearestNeighborsSet.Add(value))
{
continue;
}
float distanceSquared = (elementData._center - position).sqrMagnitude;
if (candidates.Count < count)
{
candidates.Add((i, distanceSquared));
if (candidates.Count == count)
{
currentWorstDistanceSquared = FindWorstDistance(candidates);
}
continue;
}
if (distanceSquared >= currentWorstDistanceSquared)
{
nearestNeighborsSet.Remove(value);
continue;
}
int worstCandidateIndex = FindIndexOfWorstCandidate(candidates);
T removedValue = _elementData[candidates[worstCandidateIndex].index]._value;
nearestNeighborsSet.Remove(removedValue);
candidates[worstCandidateIndex] = (i, distanceSquared);
currentWorstDistanceSquared = FindWorstDistance(candidates);
}
}
if (candidates.Count == 0)
{
return nearestNeighbors;
}
candidates.Sort(CandidateComparer.Instance);
int resultCount = Math.Min(count, candidates.Count);
for (int i = 0; i < resultCount; ++i)
{
nearestNeighbors.Add(_elementData[candidates[i].index]._value);
}
return nearestNeighbors;
}
private static void PushNode(List heap, RTreeNode node, Vector3 point)
{
NodeDistance entry = new(node, node.boundary.DistanceSquaredTo(point));
heap.Add(entry);
int index = heap.Count - 1;
while (index > 0)
{
int parent = (index - 1) >> 1;
NodeDistance parentEntry = heap[parent];
if (parentEntry._distanceSquared <= entry._distanceSquared)
{
break;
}
heap[index] = parentEntry;
index = parent;
}
heap[index] = entry;
}
private static NodeDistance PopNode(List heap)
{
int lastIndex = heap.Count - 1;
NodeDistance result = heap[0];
NodeDistance last = heap[lastIndex];
heap.RemoveAt(lastIndex);
int index = 0;
int count = heap.Count;
while (true)
{
int left = (index << 1) + 1;
if (left >= count)
{
break;
}
int right = left + 1;
int smallest =
right < count && heap[right]._distanceSquared < heap[left]._distanceSquared
? right
: left;
if (last._distanceSquared <= heap[smallest]._distanceSquared)
{
break;
}
heap[index] = heap[smallest];
index = smallest;
}
if (count > 0)
{
heap[index] = last;
}
return result;
}
private static float FindWorstDistance(List<(int index, float distanceSquared)> list)
{
float worst = 0f;
for (int i = 0; i < list.Count; ++i)
{
float distance = list[i].distanceSquared;
if (distance > worst)
{
worst = distance;
}
}
return worst;
}
private static int FindIndexOfWorstCandidate(List<(int index, float distanceSquared)> list)
{
int worstIndex = 0;
float worstDistance = list[0].distanceSquared;
for (int i = 1; i < list.Count; ++i)
{
float distance = list[i].distanceSquared;
if (distance > worstDistance)
{
worstDistance = distance;
worstIndex = i;
}
}
return worstIndex;
}
private static void RadixSort(ElementData[] elements, int length)
{
if (length <= 1)
{
return;
}
const int BitsPerPass = 8;
const int BucketCount = 1 << BitsPerPass;
Span counts = stackalloc int[BucketCount];
using PooledArray scratchResource = SystemArrayPool.Get(
length,
out ElementData[] scratch
);
ElementData[] source = elements;
ElementData[] destination = scratch;
bool dataInScratch = false;
for (int shift = 0; shift < 64; shift += BitsPerPass)
{
counts.Clear();
ref ElementData sourceRef = ref source[0];
for (int i = 0; i < length; ++i)
{
ulong key = Unsafe.Add(ref sourceRef, i)._sortKey;
counts[(int)((key >> shift) & (BucketCount - 1))]++;
}
int total = 0;
for (int bucket = 0; bucket < BucketCount; ++bucket)
{
int count = counts[bucket];
counts[bucket] = total;
total += count;
}
ref ElementData destinationRef = ref destination[0];
for (int i = 0; i < length; ++i)
{
ElementData value = Unsafe.Add(ref sourceRef, i);
int bucket = (int)((value._sortKey >> shift) & (BucketCount - 1));
Unsafe.Add(ref destinationRef, counts[bucket]++) = value;
}
(source, destination) = (destination, source);
dataInScratch = !dataInScratch;
}
if (dataInScratch)
{
Array.Copy(source, elements, length);
}
}
private static BoundingBox3D CalculateBounds(
ElementData[] elements,
int startIndex,
int count
)
{
float minX = float.MaxValue;
float minY = float.MaxValue;
float minZ = float.MaxValue;
float maxX = float.MinValue;
float maxY = float.MinValue;
float maxZ = float.MinValue;
int endIndex = startIndex + count;
for (int i = startIndex; i < endIndex; ++i)
{
BoundingBox3D bounds = elements[i]._bounds;
Vector3 min = bounds.min;
Vector3 max = bounds.max;
minX = Math.Min(minX, min.x);
maxX = Math.Max(maxX, max.x);
minY = Math.Min(minY, min.y);
maxY = Math.Max(maxY, max.y);
minZ = Math.Min(minZ, min.z);
maxZ = Math.Max(maxZ, max.z);
}
BoundingBox3D nodeBounds = new(
new Vector3(minX, minY, minZ),
new Vector3(maxX, maxY, maxZ)
);
return EnsureMinimumBounds(nodeBounds);
}
private static BoundingBox3D EnsureMinimumBounds(BoundingBox3D bounds)
{
return bounds.EnsureMinimumSize(MinimumNodeSize);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static uint EncodeMorton(ushort quantizedX, ushort quantizedY, ushort quantizedZ)
{
uint mortonX = Part1By2(quantizedX);
uint mortonY = Part1By2(quantizedY);
uint mortonZ = Part1By2(quantizedZ);
return mortonX | (mortonY << 1) | (mortonZ << 2);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static ushort QuantizeNormalized(float normalized)
{
if (normalized <= 0f)
{
return 0;
}
if (normalized >= 1f)
{
return 1023;
}
return (ushort)(normalized * 1023f + 0.5f);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static ulong ComposeSortKey(
uint mortonKey,
ushort quantizedX,
ushort quantizedY,
ushort quantizedZ
)
{
return ((ulong)mortonKey << 32)
| ((ulong)quantizedX << 20)
| ((ulong)quantizedY << 10)
| quantizedZ;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static uint Part1By2(uint value)
{
value &= 0x000003ff;
value = (value | (value << 16)) & 0xFF0000FF;
value = (value | (value << 8)) & 0x0F00F00F;
value = (value | (value << 4)) & 0xC30C30C3;
value = (value | (value << 2)) & 0x49249249;
return value;
}
}
}