import { Computed, RESET_VALUE, computed, isUninitialized } from '@tldraw/state'
import { TLBinding, TLShapeId } from '@tldraw/tlschema'
import { objectMapValues } from '@tldraw/utils'
import { Editor } from '../Editor'

/** Lookup from a shape id to the bindings that reference that shape (as `fromId` or `toId`). */
type TLBindingsIndex = Map<TLShapeId, Array<TLBinding>>

/**
 * Creates a reactive index from shape id to every binding that references that
 * shape, either as its `fromId` or as its `toId`.
 *
 * The first computation, and any computation where the binding history reports
 * `RESET_VALUE`, builds the index from the full set of binding records. Later
 * computations replay the binding diffs recorded since the last computed epoch
 * onto a lazily created copy of the previous index.
 *
 * The previous `Map` and the arrays inside it are never mutated. If no binding
 * was added, updated or removed, the previous index object is returned as-is.
 *
 * NOTE(review): a binding whose `fromId` equals its `toId` is listed twice under
 * that shape, in both the rebuild and the incremental path. Confirm that callers
 * expect this.
 *
 * @param editor - Editor whose store supplies the binding records and history.
 * @returns A computed signal holding the shape-to-bindings index.
 */
export const bindingsIndex = (editor: Editor): Computed<TLBindingsIndex> => {
	const store = editor.store
	const bindingsHistory = store.query.filterHistory('binding')
	const bindingsQuery = store.query.records('binding')

	// Appends `binding` to the bucket for `shapeId`, creating the bucket on first use.
	const appendTo = (index: TLBindingsIndex, shapeId: TLShapeId, binding: TLBinding): void => {
		const bucket = index.get(shapeId)
		if (bucket === undefined) {
			index.set(shapeId, [binding])
		} else {
			bucket.push(binding)
		}
	}

	// Rebuilds the whole index from the current set of binding records.
	const buildFromScratch = (): TLBindingsIndex => {
		const index: TLBindingsIndex = new Map()
		for (const binding of bindingsQuery.get() as TLBinding[]) {
			appendTo(index, binding.fromId, binding)
			appendTo(index, binding.toId, binding)
		}
		return index
	}

	return computed<TLBindingsIndex>('arrowBindingsIndex', (prevValue, prevEpoch) => {
		if (isUninitialized(prevValue)) return buildFromScratch()

		// Alias the narrowed value so the nested helpers below see the initialized type.
		const previous: TLBindingsIndex = prevValue

		const diff = bindingsHistory.getDiffSince(prevEpoch)
		if (diff === RESET_VALUE) return buildFromScratch()

		let draft: TLBindingsIndex | undefined = undefined

		// Clones `previous` the first time a change is applied, so epochs with no
		// binding changes hand back the previous index object untouched.
		const getDraft = (): TLBindingsIndex => (draft ??= new Map(previous))

		// Removes `binding` from the bucket for `shapeId`; an emptied (or missing) bucket is deleted.
		const dropFrom = (index: TLBindingsIndex, shapeId: TLShapeId, binding: TLBinding): void => {
			const remaining = index.get(shapeId)?.filter((b) => b.id !== binding.id)
			if (remaining !== undefined && remaining.length > 0) {
				index.set(shapeId, remaining)
			} else {
				index.delete(shapeId)
			}
		}

		const removeBinding = (binding: TLBinding): void => {
			const index = getDraft()
			dropFrom(index, binding.fromId, binding)
			dropFrom(index, binding.toId, binding)
		}

		// Returns a bucket for `shapeId` that is safe to push into. If the shape has no
		// bucket, a new empty one is created. If the draft still shares its array with
		// `previous`, a private copy is returned so that `previous` is never mutated.
		const writableBucket = (shapeId: TLShapeId): TLBinding[] => {
			const index = getDraft()
			const current = index.get(shapeId)
			if (current === undefined) {
				const created: TLBinding[] = []
				index.set(shapeId, created)
				return created
			}
			if (current !== previous.get(shapeId)) return current
			const copy = current.slice()
			index.set(shapeId, copy)
			return copy
		}

		const addBinding = (binding: TLBinding): void => {
			writableBucket(binding.fromId).push(binding)
			writableBucket(binding.toId).push(binding)
		}

		for (const { added, updated, removed } of diff) {
			objectMapValues(added).forEach((binding) => addBinding(binding))

			// An update is treated as removing the old record, then adding the new one.
			for (const [before, after] of objectMapValues(updated)) {
				removeBinding(before)
				addBinding(after)
			}

			objectMapValues(removed).forEach((binding) => removeBinding(binding))
		}

		// TODO: add diff entries if we need them
		return draft ?? previous
	})
}
