import { Computed, computed, isUninitialized, RESET_VALUE } from '@tldraw/state'
import { CollectionDiff, RecordsDiff } from '@tldraw/store'
import { isShape, TLParentId, TLRecord, TLShape, TLShapeId, TLStore } from '@tldraw/tlschema'
import { sortByIndex } from '@tldraw/utils'

type ParentShapeIdsToChildShapeIds = Record<TLParentId, TLShapeId[]>

/**
 * Builds the complete parent → children lookup by reading every shape id
 * currently reported by the query.
 *
 * @param shapeIdsQuery - Computed set of shape ids to include.
 * @param store - Store the shape records are read from.
 * @returns An object keyed by parent id whose values are the ids of that
 * parent's children, in the order produced by `sortByIndex`.
 */
function fromScratch(
	shapeIdsQuery: Computed<Set<TLShapeId>, CollectionDiff<TLShapeId>>,
	store: TLStore
) {
	const childrenByParent: ParentShapeIdsToChildShapeIds = {}

	// Resolve each id to its record, then order the whole list once so that every
	// per-parent array below is filled in already-sorted order.
	const shapes: TLShape[] = []
	for (const id of shapeIdsQuery.get()) {
		shapes.push(store.get(id)!)
	}
	shapes.sort(sortByIndex)

	// Group the ordered shapes under their parent, creating the bucket on first use.
	for (const shape of shapes) {
		const siblings = childrenByParent[shape.parentId]
		if (siblings) {
			siblings.push(shape.id)
		} else {
			childrenByParent[shape.parentId] = [shape.id]
		}
	}

	return childrenByParent
}

/**
 * Creates a computed signal that maps each parent id to the ids of its child
 * shapes, with each child array ordered by `sortByIndex`.
 *
 * The first evaluation, and any evaluation where the shape history reports
 * `RESET_VALUE`, rebuilds the mapping with `fromScratch`. Every other
 * evaluation patches the previous mapping using the shape history diffs since
 * the last computed epoch. The previous mapping and its arrays are never
 * mutated: the top-level object and each touched parent array are copied the
 * first time they are changed, and the previous object is returned as-is when
 * nothing changed.
 *
 * @param store - The store whose shape records and shape history are read.
 * @returns A computed signal holding the parent id → child shape ids mapping.
 */
export const parentsToChildren = (store: TLStore) => {
	const shapeIdsQuery = store.query.ids<'shape'>('shape')
	const shapeHistory = store.query.filterHistory('shape')

	return computed<ParentShapeIdsToChildShapeIds>(
		'parentsToChildrenWithIndexes',
		(lastValue, lastComputedEpoch) => {
			if (isUninitialized(lastValue)) {
				return fromScratch(shapeIdsQuery, store)
			}

			const diff = shapeHistory.getDiffSince(lastComputedEpoch)

			if (diff === RESET_VALUE) {
				return fromScratch(shapeIdsQuery, store)
			}

			if (diff.length === 0) return lastValue

			// Stays null until the first change, so an evaluation that touches nothing
			// returns the previous object unchanged.
			let newValue: Record<TLParentId, TLShapeId[]> | null = null

			// Makes sure `newValue[parentId]` exists and is an array that is safe to
			// mutate (i.e. not the same array instance held by `lastValue`).
			const ensureNewArray = (parentId: TLParentId) => {
				if (!newValue) {
					newValue = { ...lastValue }
				}
				if (!newValue[parentId]) {
					newValue[parentId] = []
				} else if (newValue[parentId] === lastValue[parentId]) {
					newValue[parentId] = [...newValue[parentId]!]
				}
			}

			// Removes `childId` from the parent's array if it is present.
			const removeChild = (parentId: TLParentId, childId: TLShapeId) => {
				ensureNewArray(parentId)
				const children = newValue![parentId]
				const idx = children.indexOf(childId)
				// indexOf returns -1 when the id is absent, and splice(-1, 1) would
				// remove the last (unrelated) child, so only splice on a real hit.
				if (idx !== -1) {
					children.splice(idx, 1)
				}
			}

			const toSort = new Set<TLShapeId[]>()

			let changes: RecordsDiff<TLRecord>

			for (let i = 0, n = diff.length; i < n; i++) {
				changes = diff[i]

				// Iterate through the added shapes, add them to the new value and mark them for sorting
				for (const record of Object.values(changes.added)) {
					if (!isShape(record)) continue
					ensureNewArray(record.parentId)
					newValue![record.parentId].push(record.id)
					toSort.add(newValue![record.parentId])
				}

				// Iterate through the updated shapes, add them to their parents in the new value and mark them for sorting
				for (const [from, to] of Object.values(changes.updated)) {
					if (!isShape(to)) continue
					if (!isShape(from)) continue

					if (from.parentId !== to.parentId) {
						// If the parents have changed, remove the shape from the old parent and add it to the new parent
						removeChild(from.parentId, to.id)
						ensureNewArray(to.parentId)
						newValue![to.parentId].push(to.id)
						toSort.add(newValue![to.parentId])
					} else if (from.index !== to.index) {
						// Same parent but a different index (e.g. a reorder): membership is
						// unchanged, so the parent's array only needs to be re-sorted.
						ensureNewArray(to.parentId)
						toSort.add(newValue![to.parentId])
					}
				}

				// Iterate through the removed shapes, remove them from their parents in new value
				for (const record of Object.values(changes.removed)) {
					if (!isShape(record)) continue
					removeChild(record.parentId, record.id)
				}
			}

			// Sort the arrays that have been marked for sorting (in-place to avoid intermediate arrays)
			for (const arr of toSort) {
				// Filter out any ids that no longer resolve to a record in the store, in-place
				let writeIdx = 0
				for (let readIdx = 0; readIdx < arr.length; readIdx++) {
					if (store.get(arr[readIdx])) {
						arr[writeIdx++] = arr[readIdx]
					}
				}
				arr.length = writeIdx

				// Sort in-place by index; every remaining id resolved to a record just above
				arr.sort((a, b) => {
					const shapeA = store.get(a) as TLShape
					const shapeB = store.get(b) as TLShape
					return sortByIndex(shapeA, shapeB)
				})
			}

			return newValue ?? lastValue
		}
	)
}
