// MIT License - Copyright (c) 2025 wallstop
// Full license text: https://github.com/wallstop/unity-helpers/blob/main/LICENSE
namespace WallstopStudios.UnityHelpers.Core.DataStructure
{
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Runtime.CompilerServices;
using UnityEngine;
using Utils;
///
/// Immutable 3D k-d tree for efficient nearest neighbor, range, and bounds queries in 3D space.
///
///
/// points = SamplePoints();
/// KdTree3D tree = new KdTree3D(points, p => p);
/// List neighbors = new List();
/// tree.GetElementsInRange(queryPosition, 4f, neighbors);
/// ]]>
///
/// Element type contained in the tree.
///
/// Pros: Very fast nearest neighbor performance; good for static or batched updates.
/// Cons: Immutable structure by design; rebuild when positions change frequently.
/// Semantics: Due to algorithmic choices (axis-aligned splitting, half-open containment checks,
/// minimum node-size enforcement, and tie-handling on split planes), KdTree3D (balanced and unbalanced)
/// may return different edge-case results compared to OctTree3D for identical inputs/queries—especially for
/// points lying exactly on query boundaries or split planes. See docs/features/spatial/spatial-tree-semantics.md for details.
///
[Serializable]
public sealed class KdTree3D : ISpatialTree3D
{
private readonly struct Neighbor
{
public readonly int index;
public readonly float sqrDistance;
public Neighbor(int index, float sqrDistance)
{
this.index = index;
this.sqrDistance = sqrDistance;
}
}
[Serializable]
public sealed class KdTreeNode
{
public readonly Bounds boundary;
public readonly KdTreeNode left;
public readonly KdTreeNode right;
internal readonly int _startIndex;
internal readonly int _count;
public readonly bool isTerminal;
private KdTreeNode(
Bounds boundary,
KdTreeNode left,
KdTreeNode right,
int startIndex,
int count,
bool isTerminal
)
{
this.boundary = boundary;
this.left = left;
this.right = right;
_startIndex = startIndex;
_count = count;
this.isTerminal = isTerminal;
}
internal static KdTreeNode CreateLeaf(Bounds boundary, int startIndex, int count)
{
return new KdTreeNode(boundary, null, null, startIndex, count, true);
}
internal static KdTreeNode CreateInternal(
Bounds boundary,
KdTreeNode left,
KdTreeNode right,
int startIndex,
int count
)
{
return new KdTreeNode(boundary, left, right, startIndex, count, false);
}
}
private const float MinimumNodeSize = 0.001f;
private const int SmallPartitionThreshold = 32;
///
/// Default bucket size for leaves before stopping recursion.
///
public const int DefaultBucketSize = 12;
public readonly ImmutableArray elements;
///
/// Gets the overall bounding box of the tree.
///
public Bounds Boundary => _bounds;
private readonly Bounds _bounds;
private readonly float[] _positionsX;
private readonly float[] _positionsY;
private readonly float[] _positionsZ;
private readonly int[] _indices;
private readonly KdTreeNode _head;
private readonly bool _balanced;
private readonly int _bucketSize;
///
/// Builds a 3D k-d tree from elements using a transformer to extract 3D positions.
///
/// Source elements.
/// Maps element to its 3D position.
/// Max elements per leaf. Minimum 1.
/// If true, builds a balanced tree by median selection; otherwise uses a quick split strategy.
/// Thrown when points or elementTransformer are null.
public KdTree3D(
IEnumerable points,
Func elementTransformer,
int bucketSize = DefaultBucketSize,
bool balanced = true
)
{
if (elementTransformer is null)
{
throw new ArgumentNullException(nameof(elementTransformer));
}
elements =
points?.ToImmutableArray() ?? throw new ArgumentNullException(nameof(points));
int elementCount = elements.Length;
_positionsX = elementCount == 0 ? Array.Empty() : new float[elementCount];
_positionsY = elementCount == 0 ? Array.Empty() : new float[elementCount];
_positionsZ = elementCount == 0 ? Array.Empty() : new float[elementCount];
_indices = elementCount == 0 ? Array.Empty() : new int[elementCount];
_balanced = balanced;
_bucketSize = Math.Max(1, bucketSize);
float minX = float.PositiveInfinity;
float minY = float.PositiveInfinity;
float minZ = float.PositiveInfinity;
float maxX = float.NegativeInfinity;
float maxY = float.NegativeInfinity;
float maxZ = float.NegativeInfinity;
for (int i = 0; i < elementCount; ++i)
{
T element = elements[i];
Vector3 position = elementTransformer(element);
_positionsX[i] = position.x;
_positionsY[i] = position.y;
_positionsZ[i] = position.z;
if (position.x < minX)
{
minX = position.x;
}
if (position.y < minY)
{
minY = position.y;
}
if (position.z < minZ)
{
minZ = position.z;
}
if (position.x > maxX)
{
maxX = position.x;
}
if (position.y > maxY)
{
maxY = position.y;
}
if (position.z > maxZ)
{
maxZ = position.z;
}
_indices[i] = i;
}
Bounds bounds = CreateBounds(minX, maxX, minY, maxY, minZ, maxZ);
if (elementCount == 0)
{
_bounds = bounds;
_head = KdTreeNode.CreateLeaf(_bounds, 0, 0);
return;
}
KdTreeNode root;
if (_balanced)
{
root = BuildBalanced(0, elementCount, depth: 0);
}
else
{
int[] scratch = ArrayPool.Shared.Rent(elementCount);
try
{
root = BuildUnbalanced(0, elementCount, axis: 0, scratch);
}
finally
{
ArrayPool.Shared.Return(scratch, clearArray: true);
}
}
_head = root;
_bounds = root.boundary;
}
private KdTreeNode BuildBalanced(int startIndex, int count, int depth)
{
if (count <= _bucketSize)
{
Bounds leafBounds = CalculateLeafBounds(startIndex, count);
return KdTreeNode.CreateLeaf(leafBounds, startIndex, count);
}
int axis = depth % 3;
Span span = _indices.AsSpan(startIndex, count);
int leftCount = count / 2;
if (leftCount == 0)
{
Bounds leafBounds = CalculateLeafBounds(startIndex, count);
return KdTreeNode.CreateLeaf(leafBounds, startIndex, count);
}
SelectKth(span, leftCount, axis);
int rightCount = count - leftCount;
if (rightCount == 0)
{
Bounds leafBounds = CalculateLeafBounds(startIndex, count);
return KdTreeNode.CreateLeaf(leafBounds, startIndex, count);
}
KdTreeNode left = BuildBalanced(startIndex, leftCount, depth + 1);
KdTreeNode right = BuildBalanced(startIndex + leftCount, rightCount, depth + 1);
Bounds boundary = CombineChildBounds(left.boundary, right.boundary);
return KdTreeNode.CreateInternal(boundary, left, right, startIndex, count);
}
private KdTreeNode BuildUnbalanced(int startIndex, int count, int axis, int[] scratch)
{
Span source = _indices.AsSpan(startIndex, count);
Span temp = scratch.AsSpan(0, count);
float[] positionsX = _positionsX;
float[] positionsY = _positionsY;
float[] positionsZ = _positionsZ;
float minX = float.PositiveInfinity;
float minY = float.PositiveInfinity;
float minZ = float.PositiveInfinity;
float maxX = float.NegativeInfinity;
float maxY = float.NegativeInfinity;
float maxZ = float.NegativeInfinity;
for (int i = 0; i < count; ++i)
{
int entryIndex = source[i];
float px = positionsX[entryIndex];
float py = positionsY[entryIndex];
float pz = positionsZ[entryIndex];
if (px < minX)
{
minX = px;
}
if (py < minY)
{
minY = py;
}
if (pz < minZ)
{
minZ = pz;
}
if (px > maxX)
{
maxX = px;
}
if (py > maxY)
{
maxY = py;
}
if (pz > maxZ)
{
maxZ = pz;
}
}
Bounds nodeBounds = CreateBounds(minX, maxX, minY, maxY, minZ, maxZ);
if (count <= _bucketSize)
{
return KdTreeNode.CreateLeaf(nodeBounds, startIndex, count);
}
float cutoff = GetAxisValue(nodeBounds.center, axis);
int leftWrite = 0;
int rightWrite = count - 1;
float[] axisArray = GetAxisArray(axis);
for (int i = 0; i < count; ++i)
{
int entryIndex = source[i];
float value = axisArray[entryIndex];
if (value <= cutoff)
{
temp[leftWrite++] = entryIndex;
}
else
{
temp[rightWrite--] = entryIndex;
}
}
int leftCount = leftWrite;
int rightCount = count - leftCount;
if (leftCount == 0 || rightCount == 0)
{
return KdTreeNode.CreateLeaf(nodeBounds, startIndex, count);
}
temp.CopyTo(source);
int nextAxis = (axis + 1) % 3;
KdTreeNode left = BuildUnbalanced(startIndex, leftCount, nextAxis, scratch);
KdTreeNode right = BuildUnbalanced(
startIndex + leftCount,
rightCount,
nextAxis,
scratch
);
Bounds boundary = CombineChildBounds(left.boundary, right.boundary);
return KdTreeNode.CreateInternal(boundary, left, right, startIndex, count);
}
private Bounds CalculateLeafBounds(int startIndex, int count)
{
if (count <= 0)
{
return new Bounds();
}
int[] indices = _indices;
float[] positionsX = _positionsX;
float[] positionsY = _positionsY;
float[] positionsZ = _positionsZ;
float minX = float.PositiveInfinity;
float minY = float.PositiveInfinity;
float minZ = float.PositiveInfinity;
float maxX = float.NegativeInfinity;
float maxY = float.NegativeInfinity;
float maxZ = float.NegativeInfinity;
int end = startIndex + count;
for (int i = startIndex; i < end; ++i)
{
int entryIndex = indices[i];
float px = positionsX[entryIndex];
float py = positionsY[entryIndex];
float pz = positionsZ[entryIndex];
if (px < minX)
{
minX = px;
}
if (py < minY)
{
minY = py;
}
if (pz < minZ)
{
minZ = pz;
}
if (px > maxX)
{
maxX = px;
}
if (py > maxY)
{
maxY = py;
}
if (pz > maxZ)
{
maxZ = pz;
}
}
return CreateBounds(minX, maxX, minY, maxY, minZ, maxZ);
}
private void SelectKth(Span span, int k, int axis)
{
float[] axisValues = GetAxisArray(axis);
int left = 0;
int right = span.Length - 1;
while (left < right)
{
if (right - left <= SmallPartitionThreshold)
{
InsertionSort(span.Slice(left, right - left + 1), axisValues);
return;
}
int pivotIndex = SelectPivot(span, left, right, axisValues);
float pivot = axisValues[span[pivotIndex]];
int i = left;
int j = right;
while (i <= j)
{
while (i <= j && axisValues[span[i]] < pivot)
{
i++;
}
while (i <= j && axisValues[span[j]] > pivot)
{
j--;
}
if (i <= j)
{
(span[i], span[j]) = (span[j], span[i]);
i++;
j--;
}
}
if (k <= j)
{
right = j;
continue;
}
if (k >= i)
{
left = i;
continue;
}
return;
}
}
private static int SelectPivot(Span span, int left, int right, float[] axisValues)
{
int mid = left + ((right - left) >> 1);
float leftValue = axisValues[span[left]];
float midValue = axisValues[span[mid]];
float rightValue = axisValues[span[right]];
if (leftValue > midValue)
{
(span[left], span[mid]) = (span[mid], span[left]);
(leftValue, midValue) = (midValue, leftValue);
}
if (midValue > rightValue)
{
(span[mid], span[right]) = (span[right], span[mid]);
(midValue, rightValue) = (rightValue, midValue);
if (leftValue > midValue)
{
(span[left], span[mid]) = (span[mid], span[left]);
(leftValue, midValue) = (midValue, leftValue);
}
}
return mid;
}
private static void InsertionSort(Span span, float[] axisValues)
{
if (span.Length <= 1)
{
return;
}
for (int i = 1; i < span.Length; ++i)
{
int currentIndex = span[i];
float currentValue = axisValues[currentIndex];
int j = i - 1;
while (j >= 0 && axisValues[span[j]] > currentValue)
{
span[j + 1] = span[j];
j--;
}
span[j + 1] = currentIndex;
}
}
private static float GetAxisValue(Vector3 position, int axis)
{
return axis switch
{
0 => position.x,
1 => position.y,
_ => position.z,
};
}
private static Bounds CombineChildBounds(Bounds left, Bounds right)
{
Bounds combined = left;
combined.Encapsulate(right);
EnsureMinimumBounds(ref combined);
return combined;
}
private static Bounds CreateBounds(
float minX,
float maxX,
float minY,
float maxY,
float minZ,
float maxZ
)
{
if (float.IsInfinity(minX) || float.IsInfinity(minY) || float.IsInfinity(minZ))
{
return new Bounds();
}
Vector3 min = new(minX, minY, minZ);
Vector3 max = new(maxX, maxY, maxZ);
Vector3 center = (min + max) * 0.5f;
Vector3 size = max - min;
Bounds bounds = new(center, size);
EnsureMinimumBounds(ref bounds);
return bounds;
}
private static void EnsureMinimumBounds(ref Bounds bounds)
{
Vector3 size = bounds.size;
if (size.x < MinimumNodeSize)
{
size.x = MinimumNodeSize;
}
if (size.y < MinimumNodeSize)
{
size.y = MinimumNodeSize;
}
if (size.z < MinimumNodeSize)
{
size.z = MinimumNodeSize;
}
bounds.size = size;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private float[] GetAxisArray(int axis)
{
return axis switch
{
0 => _positionsX,
1 => _positionsY,
_ => _positionsZ,
};
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private Vector3 GetPosition(int index)
{
return new Vector3(_positionsX[index], _positionsY[index], _positionsZ[index]);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private float GetDistanceSquared(int index, Vector3 point)
{
float dx = _positionsX[index] - point.x;
float dy = _positionsY[index] - point.y;
float dz = _positionsZ[index] - point.z;
return dx * dx + dy * dy + dz * dz;
}
public List GetElementsInRange(
Vector3 position,
float range,
List elementsInRange,
float minimumRange = 0
)
{
elementsInRange.Clear();
// Allow zero range to return only exact matches (distance == 0)
if (range < 0f || _head._count <= 0)
{
return elementsInRange;
}
Sphere querySphere = new(position, range);
Bounds bounds = new(position, new Vector3(range * 2f, range * 2f, range * 2f));
if (!bounds.Intersects(_bounds))
{
return elementsInRange;
}
using PooledResource> stackResource = Buffers.Stack.Get(
out Stack nodesToVisit
);
nodesToVisit.Push(_head);
ImmutableArray values = elements;
int[] indices = _indices;
float rangeSquared = range * range;
bool hasMinimumRange = 0f < minimumRange;
float minimumRangeSquared = minimumRange * minimumRange;
Sphere minimumSphere = hasMinimumRange ? new Sphere(position, minimumRange) : default;
while (nodesToVisit.TryPop(out KdTreeNode currentNode))
{
if (currentNode is null || currentNode._count <= 0)
{
continue;
}
if (!bounds.Intersects(currentNode.boundary))
{
continue;
}
// Use Sphere.Overlaps to check if the sphere fully contains the node's boundary
BoundingBox3D nodeBoundary = BoundingBox3D.FromClosedBounds(currentNode.boundary);
bool nodeFullyContained = querySphere.Overlaps(nodeBoundary);
if (currentNode.isTerminal || nodeFullyContained)
{
int start = currentNode._startIndex;
int end = start + currentNode._count;
// If the node is fully contained, we can skip distance checks for points
if (nodeFullyContained)
{
if (!hasMinimumRange)
{
// Fast path: all points in this node are within range
for (int i = start; i < end; ++i)
{
elementsInRange.Add(values[indices[i]]);
}
}
else
{
// Node is fully in outer sphere, but need to check minimum range
// Check if node is fully outside minimum sphere
bool nodeFullyOutsideMinimum = !minimumSphere.Intersects(nodeBoundary);
if (nodeFullyOutsideMinimum)
{
// Fast path: all points are in the annulus
for (int i = start; i < end; ++i)
{
elementsInRange.Add(values[indices[i]]);
}
}
else
{
// Need to check each point against minimum range
for (int i = start; i < end; ++i)
{
int elementIndex = indices[i];
float squareDistance = GetDistanceSquared(
elementIndex,
position
);
if (squareDistance > minimumRangeSquared)
{
elementsInRange.Add(values[elementIndex]);
}
}
}
}
}
else
{
// Terminal node but not fully contained: check each point
for (int i = start; i < end; ++i)
{
int elementIndex = indices[i];
float squareDistance = GetDistanceSquared(elementIndex, position);
if (squareDistance > rangeSquared)
{
continue;
}
if (hasMinimumRange && squareDistance <= minimumRangeSquared)
{
continue;
}
elementsInRange.Add(values[elementIndex]);
}
}
continue;
}
KdTreeNode left = currentNode.left;
if (left is not null && left._count > 0 && bounds.Intersects(left.boundary))
{
nodesToVisit.Push(left);
}
KdTreeNode right = currentNode.right;
if (right is not null && right._count > 0 && bounds.Intersects(right.boundary))
{
nodesToVisit.Push(right);
}
}
return elementsInRange;
}
public List GetElementsInBounds(Bounds bounds, List elementsInBounds)
{
elementsInBounds.Clear();
if (_head._count <= 0)
{
return elementsInBounds;
}
// Use closed Unity Bounds intersection for traversal to avoid pruning
// legitimate edge cases; final per-point checks use closed semantics.
if (!bounds.Intersects(_bounds))
{
return elementsInBounds;
}
// Build inclusive half-open query box for robust per-point checks
BoundingBox3D queryBox = BoundingBox3D.FromClosedBoundsInclusiveMax(bounds);
using PooledResource> stackResource = Buffers.Stack.Get(
out Stack nodesToVisit
);
nodesToVisit.Push(_head);
ImmutableArray values = elements;
int[] indices = _indices;
while (nodesToVisit.TryPop(out KdTreeNode currentNode))
{
if (currentNode is null || currentNode._count <= 0)
{
continue;
}
if (currentNode.isTerminal)
{
int start = currentNode._startIndex;
int end = start + currentNode._count;
for (int i = start; i < end; ++i)
{
int elementIndex = indices[i];
Vector3 entryPosition = GetPosition(elementIndex);
// Use inclusive half-open check for robust closed semantics
if (queryBox.Contains(entryPosition))
{
elementsInBounds.Add(values[elementIndex]);
}
}
continue;
}
// Once we've reached an internal node that intersects the query,
// visit both non-empty children and rely on per-point checks.
KdTreeNode left = currentNode.left;
if (left is not null && left._count > 0)
{
nodesToVisit.Push(left);
}
KdTreeNode right = currentNode.right;
if (right is not null && right._count > 0)
{
nodesToVisit.Push(right);
}
}
return elementsInBounds;
}
public List GetApproximateNearestNeighbors(
Vector3 position,
int count,
List nearestNeighbors
)
{
nearestNeighbors.Clear();
if (count <= 0 || _head._count == 0)
{
return nearestNeighbors;
}
using PooledResource> nodeBufferResource =
Buffers.Stack.Get(out Stack nodeBuffer);
nodeBuffer.Push(_head);
using PooledResource> nearestNeighborBufferResource = Buffers.HashSet.Get(
out HashSet nearestNeighborBuffer
);
using PooledResource> neighborCandidatesResource =
Buffers.List.Get(out List neighborCandidates);
ImmutableArray values = elements;
int[] indices = _indices;
KdTreeNode current = _head;
while (!current.isTerminal)
{
KdTreeNode left = current.left;
KdTreeNode right = current.right;
if (left is null || left._count == 0)
{
if (right is null || right._count == 0)
{
break;
}
nodeBuffer.Push(right);
current = right;
if (right._count <= count)
{
break;
}
continue;
}
if (right is null || right._count == 0)
{
nodeBuffer.Push(left);
current = left;
if (left._count <= count)
{
break;
}
continue;
}
float leftDistance = (left.boundary.center - position).sqrMagnitude;
float rightDistance = (right.boundary.center - position).sqrMagnitude;
if (leftDistance < rightDistance)
{
if (right._count > 0)
{
nodeBuffer.Push(right);
}
nodeBuffer.Push(left);
current = left;
if (left._count <= count)
{
break;
}
}
else
{
if (left._count > 0)
{
nodeBuffer.Push(left);
}
nodeBuffer.Push(right);
current = right;
if (right._count <= count)
{
break;
}
}
}
while (
nearestNeighborBuffer.Count < count && nodeBuffer.TryPop(out KdTreeNode selected)
)
{
if (selected is null || selected._count <= 0)
{
continue;
}
int startIndex = selected._startIndex;
int endIndex = startIndex + selected._count;
for (int i = startIndex; i < endIndex; ++i)
{
int elementIndex = indices[i];
T value = values[elementIndex];
if (!nearestNeighborBuffer.Add(value))
{
continue;
}
float sqrDistance = GetDistanceSquared(elementIndex, position);
neighborCandidates.Add(new Neighbor(elementIndex, sqrDistance));
}
}
if (neighborCandidates.Count > 1)
{
neighborCandidates.Sort(NeighborComparer.Instance);
}
nearestNeighbors.Clear();
for (int i = 0; i < neighborCandidates.Count && i < count; ++i)
{
nearestNeighbors.Add(values[neighborCandidates[i].index]);
}
return nearestNeighbors;
}
private sealed class NeighborComparer : IComparer
{
internal static readonly NeighborComparer Instance = new();
public int Compare(Neighbor x, Neighbor y)
{
return x.sqrDistance.CompareTo(y.sqrDistance);
}
}
}
}