import { type PreTrainedModel, type PreTrainedTokenizer } from '@echogarden/transformers-nodejs-lite'
import { Logger } from '../utilities/Logger.js'
import { loadPackage } from '../utilities/PackageManager.js'
import { alignDTWWindowed } from './DTWSequenceAlignmentWindowed.js'
import { cosineDistance } from '../math/VectorMath.js'
import { includesPunctuation, isWord, parseText } from '../nlp/Segmentation.js'
import { Timeline, extractEntries } from '../utilities/Timeline.js'

/**
 * Transfers word timings from an existing timeline onto the words of a (possibly different-language)
 * text, using semantic word-to-word alignment (see `alignWordsToWordsSemantically`).
 *
 * The timeline's sentence entries supply the word groups on one side (punctuation-only entries are
 * dropped). `text` is segmented into sentences and their non-punctuation words for the other side.
 *
 * Each source word's time span is split into equal slices, one per distinct target word it was
 * mapped to, in order of first appearance in the alignment. Each target word then spans from the
 * start of its first slice to the end of its last slice.
 *
 * @param timeline Timeline containing sentence entries with nested word entries.
 * @param text Text whose words should receive timings.
 * @param textLangCode Language code of `text`, passed to the text segmenter.
 * @returns A flat timeline of `word` entries, one for each word of `text` that received at least one time slice.
 */
export async function alignTimelineToTextSemantically(timeline: Timeline, text: string, textLangCode: string) {
	const logger = new Logger()

	logger.start(`Prepare text for semantic alignment`)

	const timelineSentenceEntries = extractEntries(timeline, entry => entry.type === 'sentence')

	const timelineWordEntryGroups: Timeline[] = []
	const timelineWordGroups: string[][] = []

	for (const sentenceEntry of timelineSentenceEntries) {
		// A sentence entry without a nested timeline contributes an empty group instead of throwing
		const wordEntryGroup = (sentenceEntry.timeline ?? [])
			.filter(wordEntry => isWord(wordEntry.text))

		timelineWordEntryGroups.push(wordEntryGroup)
		timelineWordGroups.push(wordEntryGroup.map(wordEntry => wordEntry.text))
	}

	// Flattened in group order, so its indexes match the alignment's `wordIndex1` values
	const timelineWordEntriesFiltered = timelineWordEntryGroups.flat()

	const segmentedText = await parseText(text, textLangCode)

	const textWordGroups: string[][] = []

	for (const sentenceEntry of segmentedText.sentences) {
		textWordGroups.push(sentenceEntry.words.nonPunctuationWords)
	}

	// Flattened in group order, so its indexes match the alignment's `wordIndex2` values
	const textWords = textWordGroups.flat()

	logger.end()

	const wordMappingEntries = await alignWordsToWordsSemantically(timelineWordGroups, textWordGroups)

	logger.start(`Build timeline for translation`)

	// Source word index -> distinct target word indexes mapped to it, in order of first appearance
	const mappingGroups = new Map<number, number[]>()

	for (const { wordIndex1, wordIndex2 } of wordMappingEntries) {
		let group = mappingGroups.get(wordIndex1)

		if (!group) {
			group = []
			mappingGroups.set(wordIndex1, group)
		}

		if (!group.includes(wordIndex2)) {
			group.push(wordIndex2)
		}
	}

	type TimeSlice = { startTime: number, endTime: number }

	// Target word index -> time slices it received from the source words mapped to it
	const timeSlicesLookup = new Map<number, TimeSlice[]>()

	// Every group holds at least one index: a group is only created right before its first push
	for (const [wordIndex1, mappedWordIndexes] of mappingGroups) {
		const { startTime, endTime } = timelineWordEntriesFiltered[wordIndex1]

		// Split the source word's duration evenly among the target words mapped to it
		const sliceDuration = (endTime - startTime) / mappedWordIndexes.length

		let timeOffset = 0

		for (const wordIndex2 of mappedWordIndexes) {
			const timeSlice: TimeSlice = {
				startTime: startTime + timeOffset,
				endTime: startTime + timeOffset + sliceDuration
			}

			let timeSlicesForTargetWord = timeSlicesLookup.get(wordIndex2)

			if (!timeSlicesForTargetWord) {
				timeSlicesForTargetWord = []
				timeSlicesLookup.set(wordIndex2, timeSlicesForTargetWord)
			}

			timeSlicesForTargetWord.push(timeSlice)

			timeOffset += sliceDuration
		}
	}

	const resultTimeline: Timeline = []

	for (const [targetWordIndex, timeSlices] of timeSlicesLookup) {
		resultTimeline.push({
			type: 'word',
			text: textWords[targetWordIndex],
			startTime: timeSlices[0].startTime,
			endTime: timeSlices[timeSlices.length - 1].endTime
		})
	}

	logger.end()

	return resultTimeline
}

/**
 * Aligns two sequences of word groups semantically. Each word group is embedded with the
 * multilingual E5 model, and the resulting token embedding sequences are aligned with windowed DTW.
 *
 * Each group (e.g. a sentence) is tokenized and embedded on its own. Each group's tokens are then
 * concatenated into one sequence per side before alignment.
 *
 * @param wordsGroups1 Word groups of the first sequence.
 * @param wordsGroups2 Word groups of the second sequence.
 * @param windowTokenCount Window size passed to `alignDTWWindowed`, counted in tokens.
 * @returns One `WordMapping` per step of the DTW path. Word indexes refer to the flattened word
 *   list of each side. The same word pair can appear several times, since a word may span several tokens.
 */
export async function alignWordsToWordsSemantically(wordsGroups1: string[][], wordsGroups2: string[][], windowTokenCount = 20000) {
	const logger = new Logger()

	// Load embedding model
	const modelPath = await loadPackage(`xenova-multilingual-e5-small-fp16`)

	const embeddingModel = new E5TextEmbedding(modelPath)

	logger.start(`Initialize E5 embedding model`)
	await embeddingModel.initializeIfNeeded()

	// Tokenizes and embeds each word group separately. Returns the concatenated token embeddings,
	// the flattened words, and a mapping from each token to its index in the flattened words.
	async function extractEmbeddingsFromWordGroups(wordGroups: string[][]) {
		// Separate logger for per-group progress, distinct from the outer `logger`
		const groupLogger = new Logger()

		// Token sequences are passed to the model in fragments of at most this many tokens
		const maxTokensPerFragment = 512
		const { Tensor } = await import('@echogarden/transformers-nodejs-lite')

		const words: string[] = []

		const embeddings: TokenEmbeddingData[] = []
		const tokenToWordIndexMapping: number[] = []

		for (const wordGroup of wordGroups) {
			const { joinedText: joinedTextForGroup, offsets: offsetsForGroup } = joinAndGetOffsets(wordGroup)

			groupLogger.start(`Tokenize text`)
			const inputsForGroup = await embeddingModel.tokenizeToModelInputs(joinedTextForGroup)

			groupLogger.start(`Infer embeddings for text`)

			const allTokenIds = inputsForGroup['input_ids'].data
			const allAttentionMask = inputsForGroup['attention_mask'].data

			const embeddingsForGroup: TokenEmbeddingData[] = []

			for (let tokenStart = 0; tokenStart < allTokenIds.length; tokenStart += maxTokensPerFragment) {
				const tokenEnd = Math.min(tokenStart + maxTokensPerFragment, allTokenIds.length)
				const fragmentTokenCount = tokenEnd - tokenStart

				const fragmentInputIdsTensor = new Tensor('int64', allTokenIds.slice(tokenStart, tokenEnd), [1, fragmentTokenCount])
				const fragmentAttentionMaskTensor = new Tensor('int64', allAttentionMask.slice(tokenStart, tokenEnd), [1, fragmentTokenCount])

				const inputsForFragment = { input_ids: fragmentInputIdsTensor, attention_mask: fragmentAttentionMaskTensor }

				const embeddingsForFragment = await embeddingModel.inferTokenEmbeddings(inputsForFragment)

				embeddingsForGroup.push(...embeddingsForFragment)
			}

			groupLogger.start(`Compute token to word mapping for text`)

			// Drop sentence boundary tokens and standalone '▁' word-boundary tokens before mapping to words
			const filteredEmbeddingsForGroup = embeddingsForGroup.filter((embedding) => embedding.text !== '▁' && embedding.text !== '<s>' && embedding.text !== '</s>')
			const tokenToWordIndexMappingForGroup = mapTokenEmbeddingsToWordIndexes(filteredEmbeddingsForGroup, joinedTextForGroup, offsetsForGroup)

			// Shift group-local word indexes so they index into the flattened `words` list
			const tokenToWordIndexMappingForGroupWithOffset = tokenToWordIndexMappingForGroup.map(value => words.length + value)

			embeddings.push(...filteredEmbeddingsForGroup)
			tokenToWordIndexMapping.push(...tokenToWordIndexMappingForGroupWithOffset)

			words.push(...wordGroup)
		}

		// Close the last timed section, which was previously left open
		groupLogger.end()

		return { words, embeddings, tokenToWordIndexMapping }
	}

	logger.start(`Extract embeddings from source 1`)
	const {
		words: words1,
		embeddings: embeddings1,
		tokenToWordIndexMapping: tokenToWordIndexMapping1
	} = await extractEmbeddingsFromWordGroups(wordsGroups1)

	logger.start(`Extract embeddings from source 2`)
	const {
		words: words2,
		embeddings: embeddings2,
		tokenToWordIndexMapping: tokenToWordIndexMapping2
	} = await extractEmbeddingsFromWordGroups(wordsGroups2)

	// Align. A token containing punctuation never matches a punctuation-free token (maximal cost 1.0).
	function costFunction(a: TokenEmbeddingData, b: TokenEmbeddingData) {
		const aIsPunctuation = includesPunctuation(a.text)
		const bIsPunctuation = includesPunctuation(b.text)

		if (aIsPunctuation === bIsPunctuation) {
			return cosineDistance(a.embeddingVector, b.embeddingVector)
		} else {
			return 1.0
		}
	}

	logger.start(`Align token embedding vectors using DTW`)

	const { path } = alignDTWWindowed(embeddings1, embeddings2, costFunction, windowTokenCount)

	// Use the token alignment path to map words to words
	logger.start(`Map tokens to words`)

	const wordMapping: WordMapping[] = []

	for (const pathEntry of path) {
		const mappedWordIndex1 = tokenToWordIndexMapping1[pathEntry.source]
		const mappedWordIndex2 = tokenToWordIndexMapping2[pathEntry.dest]

		wordMapping.push({
			wordIndex1: mappedWordIndex1,
			word1: words1[mappedWordIndex1],
			wordIndex2: mappedWordIndex2,
			word2: words2[mappedWordIndex2],
		})
	}

	logger.end()

	return wordMapping
}

/**
 * Maps each token embedding to the index of the word in `text` it was produced from.
 *
 * Tokens are located with a forward-only substring search, after stripping a leading '▁'
 * word-boundary marker, so tokens are expected in text order. A token that can't be found
 * literally is attributed to the next unconsumed word instead of aborting the whole alignment.
 * Such tokens include '<unk>', which stands for out-of-vocabulary characters, and tokens whose
 * text was altered by the tokenizer.
 *
 * @param embeddings Token embeddings for `text`, in tokenizer order.
 * @param text The joined text the tokens were produced from (words separated by single spaces).
 * @param textWordOffsets Start offset of each word in `text`, followed by a final end offset (the text length).
 * @returns One word index per embedding; `-1` for `<s>` / `</s>` tokens.
 * @throws Error if a token can't be located and `text` has no words to attribute it to.
 */
function mapTokenEmbeddingsToWordIndexes(embeddings: TokenEmbeddingData[], text: string, textWordOffsets: number[]) {
	const tokenToWordIndex: number[] = []

	// The final offset entry is an end marker, not a word
	const wordCount = textWordOffsets.length - 1

	// Index of the word whose span (including its trailing separator) contains `textOffset`,
	// or -1 if the offset is at or past the end marker
	function wordIndexAtOffset(textOffset: number) {
		const followingWordIndex = textWordOffsets.findIndex(wordOffset => wordOffset > textOffset)

		if (followingWordIndex === -1) {
			return -1
		}

		return Math.max(followingWordIndex - 1, 0)
	}

	let currentTextOffset = 0

	for (const embedding of embeddings) {
		let tokenText = embedding.text

		if (tokenText === '<s>' || tokenText === '</s>') {
			tokenToWordIndex.push(-1)

			continue
		}

		if (tokenText.startsWith('▁')) {
			tokenText = tokenText.substring(1)
		}

		// '<unk>' never occurs literally in the text, so don't search for it
		const matchPosition = tokenText === '<unk>' ? -1 : text.indexOf(tokenText, currentTextOffset)

		if (matchPosition === -1) {
			if (wordCount < 1) {
				throw new Error(`Token '${tokenText}' not found in text`)
			}

			// Attribute the token to the next unconsumed word. The search offset isn't advanced,
			// since the number of characters the token covers is unknown.
			let fallbackOffset = currentTextOffset

			while (fallbackOffset < text.length && text[fallbackOffset] === ' ') {
				fallbackOffset++
			}

			const fallbackWordIndex = wordIndexAtOffset(fallbackOffset)

			tokenToWordIndex.push(fallbackWordIndex === -1 ? wordCount - 1 : fallbackWordIndex)

			continue
		}

		currentTextOffset = matchPosition + tokenText.length

		const tokenMatchingWordIndex = wordIndexAtOffset(matchPosition)

		if (tokenMatchingWordIndex === -1) {
			throw new Error(`Token '${tokenText}' not found in text`)
		}

		tokenToWordIndex.push(tokenMatchingWordIndex)
	}

	return tokenToWordIndex
}

/**
 * Joins words into a single string, appending one space after every word (including the last).
 *
 * @param words Words to join.
 * @returns `joinedText`, the joined string, and `offsets`, the start offset of each word in
 *   `joinedText`, followed by one final entry equal to `joinedText.length`.
 */
function joinAndGetOffsets(words: string[]) {
	const offsets: number[] = []
	let position = 0

	for (const word of words) {
		offsets.push(position)

		// Advance past the word and its trailing space
		position += word.length + 1
	}

	const joinedText = words.map(word => `${word} `).join('')

	// End marker, so the last word's span can be found like any other
	offsets.push(joinedText.length)

	return { joinedText, offsets }
}

/**
 * Wrapper around a transformers.js tokenizer and model loaded from `modelPath`. It produces
 * one embedding per input token, taken from the model's last hidden state.
 *
 * The tokenizer and model are loaded lazily, on the first call that needs them. Concurrent first
 * calls are not deduplicated, so each may start its own load.
 */
export class E5TextEmbedding {
	tokenizer?: PreTrainedTokenizer
	model?: PreTrainedModel

	constructor(public readonly modelPath: string) {
	}

	/** Tokenizes `text` into model inputs (as returned by the tokenizer), loading the tokenizer if needed. */
	async tokenizeToModelInputs(text: string) {
		await this.initializeIfNeeded()

		return await this.tokenizer!(text)
	}

	/**
	 * Runs the model on already-tokenized inputs. Returns one entry per token position: the token id,
	 * its token string, and the slice of the last hidden state at that position.
	 */
	async inferTokenEmbeddings(inputs: any) {
		await this.initializeIfNeeded()

		const tokenTexts = this.tokenizer!.model.convert_ids_to_tokens(Array.from(inputs.input_ids.data))

		const { last_hidden_state: hiddenState } = await this.model!(inputs)

		// NOTE(review): assumes dims are [batch, tokens, embedding size] with batch = 1, as the indexing below implies
		const [, tokenCount, embeddingSize] = hiddenState.dims

		return Array.from({ length: tokenCount }, (_, tokenIndex): TokenEmbeddingData => {
			const vectorStart = tokenIndex * embeddingSize

			return {
				id: Number(inputs.input_ids.data[tokenIndex]),
				text: tokenTexts[tokenIndex],
				embeddingVector: hiddenState.data.slice(vectorStart, vectorStart + embeddingSize)
			}
		})
	}

	/** Loads the tokenizer and model from `modelPath`, unless both are already set. */
	async initializeIfNeeded() {
		if (!this.tokenizer || !this.model) {
			const transformers = await import('@echogarden/transformers-nodejs-lite')

			this.tokenizer = await transformers.AutoTokenizer.from_pretrained(this.modelPath)
			this.model = await transformers.AutoModel.from_pretrained(this.modelPath)
		}
	}
}

/**
 * A single token produced by the E5 tokenizer, together with its contextual embedding.
 */
export interface TokenEmbeddingData {
	/** Token id, converted to a number from the model's `input_ids` */
	id: number
	/** Token string from the tokenizer's `convert_ids_to_tokens`; may start with the '▁' word-boundary marker */
	text: string
	/** This token's slice of the model's last hidden state */
	embeddingVector: Float32Array
}

/**
 * A correspondence between a word of the first aligned sequence and a word of the second,
 * as produced by `alignWordsToWordsSemantically`.
 */
export interface WordMapping {
	/** Index into the flattened word list of the first sequence */
	wordIndex1: number
	/** The word at `wordIndex1` */
	word1: string

	/** Index into the flattened word list of the second sequence */
	wordIndex2: number
	/** The word at `wordIndex2` */
	word2: string
}

/**
 * Two-letter language codes treated as supported by the multilingual E5 embedding model.
 * Codes are sorted alphabetically and grouped by initial letter; each line's comment names its languages in order.
 */
export const e5SupportedLanguages: string[] = [
	'af', 'am', 'ar', 'as', 'az', // Afrikaans, Amharic, Arabic, Assamese, Azerbaijani
	'be', 'bg', 'bn', 'br', 'bs', // Belarusian, Bulgarian, Bengali, Breton, Bosnian
	'ca', 'cs', 'cy', // Catalan, Czech, Welsh
	'da', 'de', // Danish, German
	'el', 'en', 'eo', 'es', 'et', 'eu', // Greek, English, Esperanto, Spanish, Estonian, Basque
	'fa', 'fi', 'fr', 'fy', // Persian, Finnish, French, Western Frisian
	'ga', 'gd', 'gl', 'gu', // Irish, Scottish Gaelic, Galician, Gujarati
	'ha', 'he', 'hi', 'hr', 'hu', 'hy', // Hausa, Hebrew, Hindi, Croatian, Hungarian, Armenian
	'id', 'is', 'it', // Indonesian, Icelandic, Italian
	'ja', 'jv', // Japanese, Javanese
	'ka', 'kk', 'km', 'kn', 'ko', 'ku', 'ky', // Georgian, Kazakh, Khmer, Kannada, Korean, Kurdish, Kyrgyz
	'la', 'lo', 'lt', 'lv', // Latin, Lao, Lithuanian, Latvian
	'mg', 'mk', 'ml', 'mn', 'mr', 'ms', 'my', // Malagasy, Macedonian, Malayalam, Mongolian, Marathi, Malay, Burmese
	'ne', 'nl', 'no', // Nepali, Dutch, Norwegian
	'om', 'or', // Oromo, Oriya
	'pa', 'pl', 'ps', 'pt', // Panjabi, Polish, Pashto, Portuguese
	'ro', 'ru', // Romanian, Russian
	'sa', 'sd', 'si', 'sk', 'sl', 'so', // Sanskrit, Sindhi, Sinhala, Slovak, Slovenian, Somali
	'sq', 'sr', 'su', 'sv', 'sw', // Albanian, Serbian, Sundanese, Swedish, Swahili
	'ta', 'te', 'th', 'tl', 'tr', // Tamil, Telugu, Thai, Tagalog, Turkish
	'ug', 'uk', 'ur', 'uz', // Uyghur, Ukrainian, Urdu, Uzbek
	'vi', // Vietnamese
	'xh', // Xhosa
	'yi', // Yiddish
	'zh', // Chinese
]
