import { applyPatch } from 'fast-json-patch'
import { DataSource } from 'typeorm'
import { BaseLanguageModel } from '@langchain/core/language_models/base'
import { BaseRetriever } from '@langchain/core/retrievers'
import { PromptTemplate, ChatPromptTemplate, MessagesPlaceholder } from '@langchain/core/prompts'
import { Runnable, RunnableSequence, RunnableMap, RunnableBranch, RunnableLambda } from '@langchain/core/runnables'
import { BaseMessage, HumanMessage, AIMessage } from '@langchain/core/messages'
import { ConsoleCallbackHandler as LCConsoleCallbackHandler } from '@langchain/core/tracers/console'
import { checkInputs, Moderation, streamResponse } from '../../moderation/Moderation'
import { formatResponse } from '../../outputparsers/OutputParserHelpers'
import { StringOutputParser } from '@langchain/core/output_parsers'
import type { Document } from '@langchain/core/documents'
import { BufferMemoryInput } from 'langchain/memory'
import { ConversationalRetrievalQAChain } from 'langchain/chains'
import { getBaseClasses, mapChatMessageToBaseMessage } from '../../../src/utils'
import { ConsoleCallbackHandler, additionalCallbacks } from '../../../src/handler'
import {
    DtamindMemory,
    ICommonObject,
    IMessage,
    INode,
    INodeData,
    INodeParams,
    IDatabaseEntity,
    MemoryMethods,
    IServerSideEventStreamer
} from '../../../src/Interface'
import { QA_TEMPLATE, REPHRASE_TEMPLATE, RESPONSE_TEMPLATE } from './prompts'

/**
 * Parameter type used by the lambdas in the retrieval and answer chains below.
 * `chat_history` is declared as a string here; the retriever branch tests its
 * length to decide whether the question needs to be rephrased.
 */
type RetrievalChainInput = {
    chat_history: string
    question: string
}

// Run name assigned to the retrieval branch. The same name is passed to
// `streamLog` via `includeNames` and used to look up the branch's output in the
// streamed logs, so both usages must stay in sync with this constant.
const sourceRunnableName = 'FindDocs'

/**
 * Picks the system prompt for the answer step.
 *
 * The deprecated `systemMessagePrompt` input, when still present on the node,
 * takes precedence and is combined with `QA_TEMPLATE`; otherwise the
 * `responsePrompt` input is returned unchanged (possibly undefined, in which
 * case `createChain` applies its own default).
 */
const resolveResponsePrompt = (nodeData: INodeData): string => {
    const systemMessagePrompt = nodeData.inputs?.systemMessagePrompt as string
    if (systemMessagePrompt) {
        return `${systemMessagePrompt}\n${QA_TEMPLATE}`
    }
    return nodeData.inputs?.responsePrompt as string
}

/**
 * Node that answers a question from retrieved documents while taking the
 * previous chat messages into account.
 *
 * `init` only builds the chain; `run` additionally applies input moderation,
 * loads chat history from memory, streams tokens / source documents through the
 * SSE streamer when requested and returns the final text.
 */
class ConversationalRetrievalQAChain_Chains implements INode {
    label: string
    name: string
    version: number
    type: string
    icon: string
    category: string
    baseClasses: string[]
    description: string
    inputs: INodeParams[]
    sessionId?: string

    /**
     * @param fields optional settings; `sessionId` is the session whose history
     * is read from and written to memory in `run`.
     */
    constructor(fields?: { sessionId?: string }) {
        this.label = 'Conversational Retrieval QA Chain'
        this.name = 'conversationalRetrievalQAChain'
        this.version = 3.0
        this.type = 'ConversationalRetrievalQAChain'
        this.icon = 'qa.svg'
        this.category = 'Chains'
        this.description = 'Document QA - built on RetrievalQAChain to provide a chat history component'
        this.baseClasses = [this.type, ...getBaseClasses(ConversationalRetrievalQAChain)]
        this.inputs = [
            {
                label: 'Chat Model',
                name: 'model',
                type: 'BaseChatModel'
            },
            {
                label: 'Vector Store Retriever',
                name: 'vectorStoreRetriever',
                type: 'BaseRetriever'
            },
            {
                label: 'Memory',
                name: 'memory',
                type: 'BaseMemory',
                optional: true,
                description: 'If left empty, a default BufferMemory will be used'
            },
            {
                label: 'Return Source Documents',
                name: 'returnSourceDocuments',
                type: 'boolean',
                optional: true
            },
            {
                label: 'Rephrase Prompt',
                name: 'rephrasePrompt',
                type: 'string',
                description: 'Using previous chat history, rephrase question into a standalone question',
                warning: 'Prompt must include input variables: {chat_history} and {question}',
                rows: 4,
                additionalParams: true,
                optional: true,
                default: REPHRASE_TEMPLATE
            },
            {
                label: 'Response Prompt',
                name: 'responsePrompt',
                type: 'string',
                description: 'Taking the rephrased question, search for answer from the provided context',
                warning: 'Prompt must include input variable: {context}',
                rows: 4,
                additionalParams: true,
                optional: true,
                default: RESPONSE_TEMPLATE
            },
            {
                label: 'Input Moderation',
                description: 'Detect text that could generate harmful output and prevent it from being sent to the language model',
                name: 'inputModeration',
                type: 'Moderation',
                optional: true,
                list: true
            }
            /** Deprecated
            {
                label: 'System Message',
                name: 'systemMessagePrompt',
                type: 'string',
                rows: 4,
                additionalParams: true,
                optional: true,
                placeholder:
                    'I want you to act as a document that I am having a conversation with. Your name is "AI Assistant". You will provide me with answers from the given info. If the answer is not included, say exactly "Hmm, I am not sure." and stop after that. Refuse to answer any question not about the info. Never break character.'
            },
            // TODO: create standalone chains for these 3 modes as they are not compatible with memory
            {
                label: 'Chain Option',
                name: 'chainOption',
                type: 'options',
                options: [
                    {
                        label: 'MapReduceDocumentsChain',
                        name: 'map_reduce',
                        description:
                            'Suitable for QA tasks over larger documents and can run the preprocessing step in parallel, reducing the running time'
                    },
                    {
                        label: 'RefineDocumentsChain',
                        name: 'refine',
                        description: 'Suitable for QA tasks over a large number of documents.'
                    },
                    {
                        label: 'StuffDocumentsChain',
                        name: 'stuff',
                        description: 'Suitable for QA tasks over a small number of documents.'
                    }
                ],
                additionalParams: true,
                optional: true
            }
            */
        ]
        this.sessionId = fields?.sessionId
    }

    /**
     * Builds the question-answering chain from the node inputs without running it.
     *
     * @param nodeData node configuration holding the model, retriever and prompts
     * @returns the runnable produced by `createChain`
     */
    async init(nodeData: INodeData): Promise<any> {
        const model = nodeData.inputs?.model as BaseLanguageModel
        const vectorStoreRetriever = nodeData.inputs?.vectorStoreRetriever as BaseRetriever
        const rephrasePrompt = nodeData.inputs?.rephrasePrompt as string

        return createChain(model, vectorStoreRetriever, rephrasePrompt, resolveResponsePrompt(nodeData))
    }

    /**
     * Answers `input` using the configured retriever, model and chat history.
     *
     * @param nodeData node configuration (model, retriever, memory, prompts, moderation)
     * @param input the user question; replaced by the moderation output when moderation is configured
     * @param options runtime context (data source, streaming flags, SSE streamer, ids, logger)
     * @returns `{ text, sourceDocuments }` when source documents are requested, otherwise `{ text }`;
     * when moderation rejects the input, the formatted rejection message is returned instead
     */
    async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | ICommonObject> {
        const model = nodeData.inputs?.model as BaseLanguageModel
        const externalMemory = nodeData.inputs?.memory
        const vectorStoreRetriever = nodeData.inputs?.vectorStoreRetriever as BaseRetriever
        const rephrasePrompt = nodeData.inputs?.rephrasePrompt as string
        const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
        const moderations = nodeData.inputs?.inputModeration as Moderation[]
        const prependMessages = options?.prependMessages

        const appDataSource = options.appDataSource as DataSource
        const databaseEntities = options.databaseEntities as IDatabaseEntity
        const chatflowid = options.chatflowid as string

        const shouldStreamResponse = options.shouldStreamResponse
        const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer
        const chatId = options.chatId
        const orgId = options.orgId

        // Streaming events are only emitted when streaming was requested AND a streamer exists.
        const streamer = shouldStreamResponse && sseStreamer ? sseStreamer : undefined

        const customResponsePrompt = resolveResponsePrompt(nodeData)

        // Fall back to the database-backed BufferMemory when no memory node is connected.
        let memory: DtamindMemory | undefined = externalMemory
        if (!memory) {
            memory = new BufferMemory({
                returnMessages: true,
                memoryKey: 'chat_history',
                appDataSource,
                databaseEntities,
                chatflowid,
                orgId
            })
        }

        if (moderations && moderations.length > 0) {
            try {
                // Use the output of the moderation chain as input for the Conversational Retrieval QA Chain
                input = await checkInputs(moderations, input)
            } catch (e) {
                // Use the thrown value's message when it has one; otherwise stringify the value
                // so a non-Error throw still produces a readable response instead of `undefined`.
                const thrownMessage = (e as { message?: unknown } | null | undefined)?.message
                const message = typeof thrownMessage === 'string' ? thrownMessage : String(e)

                await new Promise((resolve) => setTimeout(resolve, 500))
                if (shouldStreamResponse) {
                    streamResponse(sseStreamer, chatId, message)
                }
                return formatResponse(message)
            }
        }
        const answerChain = createChain(model, vectorStoreRetriever, rephrasePrompt, customResponsePrompt)

        const history = ((await memory.getChatMessages(this.sessionId, false, prependMessages)) as IMessage[]) ?? []

        const loggerHandler = new ConsoleCallbackHandler(options.logger, options?.orgId)
        const additionalCallback = await additionalCallbacks(nodeData, options)

        const callbacks = [loggerHandler, ...additionalCallback]

        if (process.env.DEBUG === 'true') {
            callbacks.push(new LCConsoleCallbackHandler())
        }

        // Only the retrieval branch is included in the logs so its output (the documents) can be read back.
        const stream = answerChain.streamLog(
            { question: input, chat_history: history },
            { callbacks },
            {
                includeNames: [sourceRunnableName]
            }
        )

        let streamedResponse: Record<string, any> = {}
        let sourceDocuments: ICommonObject[] = []
        let text = ''
        let isStreamingStarted = false

        for await (const chunk of stream) {
            // Each chunk is a JSON patch against the accumulated run state.
            streamedResponse = applyPatch(streamedResponse, chunk.ops).newDocument

            if (streamedResponse.final_output) {
                text = streamedResponse.final_output?.output

                const retrievedDocs = streamedResponse?.logs?.[sourceRunnableName]?.final_output?.output
                if (Array.isArray(retrievedDocs)) {
                    sourceDocuments = retrievedDocs
                    if (returnSourceDocuments) {
                        streamer?.streamSourceDocumentsEvent(chatId, sourceDocuments)
                    }
                }
                streamer?.streamEndEvent(chatId)
            } else if (Array.isArray(streamedResponse?.streamed_output) && streamedResponse.streamed_output.length) {
                // The most recently appended entry is the token produced by this chunk.
                const token = streamedResponse.streamed_output[streamedResponse.streamed_output.length - 1]

                if (!isStreamingStarted) {
                    isStreamingStarted = true
                    streamer?.streamStartEvent(chatId, token)
                }
                streamer?.streamTokenEvent(chatId, token)
            }
        }

        await memory.addChatMessages(
            [
                {
                    text: input,
                    type: 'userMessage'
                },
                {
                    text: text,
                    type: 'apiMessage'
                }
            ],
            this.sessionId
        )

        if (returnSourceDocuments) return { text, sourceDocuments }
        else return { text }
    }
}

/**
 * Builds the document-retrieval branch of the QA chain.
 *
 * When the incoming `chat_history` string is non-empty, the question is first
 * condensed by the LLM using `rephrasePrompt` and the result is sent to the
 * retriever; otherwise the raw question goes to the retriever directly.
 *
 * @param llm model used to rephrase the question
 * @param retriever runnable that receives the (possibly rephrased) question
 * @param rephrasePrompt template text for the rephrasing prompt
 * @returns a branch runnable named after `sourceRunnableName`
 */
const createRetrieverChain = (llm: BaseLanguageModel, retriever: Runnable, rephrasePrompt: string) => {
    const rephraseTemplate = PromptTemplate.fromTemplate(rephrasePrompt)

    // Path taken when there is prior conversation: rephrase, then retrieve.
    const standaloneQuestionChain = RunnableSequence.from([rephraseTemplate, llm, new StringOutputParser()]).withConfig({
        runName: 'CondenseQuestion'
    })
    const retrieveWithHistory = standaloneQuestionChain.pipe(retriever).withConfig({
        runName: 'RetrievalChainWithHistory'
    })

    // Path taken for a first question: there is no history to refer back to,
    // so rephrasing is skipped and the question is retrieved as-is.
    const retrieveWithoutHistory = RunnableLambda.from((input: RetrievalChainInput) => {
        return input.question
    })
        .withConfig({
            runName: 'Itemgetter:question'
        })
        .pipe(retriever)
        .withConfig({ runName: 'RetrievalChainWithNoHistory' })

    const historyIsPresent = RunnableLambda.from((input: RetrievalChainInput) => {
        return input.chat_history.length > 0
    }).withConfig({
        runName: 'HasChatHistoryCheck'
    })

    const branch = RunnableBranch.from([[historyIsPresent, retrieveWithHistory], retrieveWithoutHistory])
    return branch.withConfig({ runName: sourceRunnableName })
}

/**
 * Renders retrieved documents as newline-separated `<doc>` elements, each
 * tagged with its position in the list.
 *
 * @param docs documents whose `pageContent` is embedded
 * @returns the joined markup, or an empty string for an empty list
 */
const formatDocs = (docs: Document[]) => {
    const wrapInDocTag = (entry: Document, position: number) => {
        return `<doc id='${position}'>${entry.pageContent}</doc>`
    }
    const renderedDocs = docs.map(wrapInDocTag)
    return renderedDocs.join('\n')
}

/**
 * Flattens chat messages into one string, one `type: content` line per message.
 *
 * @param history messages to render, in order
 * @returns the newline-joined transcript, or an empty string for no messages
 */
const formatChatHistoryAsString = (history: BaseMessage[]) => {
    const transcriptLines = history.map((entry) => {
        const speaker = entry._getType()
        return `${speaker}: ${entry.content}`
    })
    return transcriptLines.join('\n')
}

/**
 * Converts the stored chat history on the chain input into LangChain messages.
 *
 * `userMessage` entries become `HumanMessage`, `apiMessage` entries become
 * `AIMessage`; entries of any other type are dropped. A missing or falsy
 * `chat_history` yields an empty list.
 *
 * @param input chain input that may carry a `chat_history` array
 * @returns the converted messages in their original order
 */
const serializeHistory = (input: any) => {
    const storedMessages: IMessage[] = input.chat_history || []
    const baseMessages = []
    for (const entry of storedMessages) {
        switch (entry.type) {
            case 'userMessage':
                baseMessages.push(new HumanMessage({ content: entry.message }))
                break
            case 'apiMessage':
                baseMessages.push(new AIMessage({ content: entry.message }))
                break
        }
    }
    return baseMessages
}

/**
 * Assembles the full conversational QA runnable.
 *
 * The chain runs three stages in order:
 * 1. pick the question and convert the stored history into messages;
 * 2. retrieve documents (rephrasing the question when history exists) and
 *    format them into `context`, while passing `question` and `chat_history`
 *    through;
 * 3. fill the response prompt, call the model and parse the output as a string.
 *
 * @param llm model used for both rephrasing and answering
 * @param retriever runnable that returns documents for a question
 * @param rephrasePrompt template for rephrasing; defaults to `REPHRASE_TEMPLATE`
 * @param responsePrompt system prompt for answering; defaults to `RESPONSE_TEMPLATE`
 * @returns the composed runnable sequence
 */
const createChain = (
    llm: BaseLanguageModel,
    retriever: Runnable,
    rephrasePrompt = REPHRASE_TEMPLATE,
    responsePrompt = RESPONSE_TEMPLATE
) => {
    // Each stage that forwards the question gets its own lambda with the same run name.
    const pickQuestion = () =>
        RunnableLambda.from((input: RetrievalChainInput) => input.question).withConfig({
            runName: 'Itemgetter:question'
        })

    const documentLookup = createRetrieverChain(llm, retriever, rephrasePrompt)

    // Stage 2: the retriever branch expects the history as a plain string, so it
    // is flattened before the lookup; the answer prompt receives it untouched.
    const retrieveStage = RunnableMap.from({
        context: RunnableSequence.from([
            (input) => ({
                question: input.question,
                chat_history: formatChatHistoryAsString(input.chat_history)
            }),
            documentLookup,
            RunnableLambda.from(formatDocs).withConfig({
                runName: 'FormatDocumentChunks'
            })
        ]),
        question: pickQuestion(),
        chat_history: RunnableLambda.from((input: RetrievalChainInput) => input.chat_history).withConfig({
            runName: 'Itemgetter:chat_history'
        })
    }).withConfig({ tags: ['RetrieveDocs'] })

    // Stage 3: system prompt, then prior messages, then the current question.
    const answerPrompt = ChatPromptTemplate.fromMessages([
        ['system', responsePrompt],
        new MessagesPlaceholder('chat_history'),
        ['human', `{question}`]
    ])
    const answerStage = RunnableSequence.from([answerPrompt, llm, new StringOutputParser()]).withConfig({
        tags: ['GenerateResponse']
    })

    // Stage 1 is declared inline as a map of runnables.
    return RunnableSequence.from([
        {
            question: pickQuestion(),
            chat_history: RunnableLambda.from(serializeHistory).withConfig({
                runName: 'SerializeHistory'
            })
        },
        retrieveStage,
        answerStage
    ])
}

/**
 * Extra constructor fields required by the default `BufferMemory` in this file,
 * on top of LangChain's `BufferMemoryInput`.
 */
interface BufferMemoryExtendedInput {
    // Data source whose repository is queried for stored chat messages
    appDataSource: DataSource
    // Entity lookup table; the 'ChatMessage' entry is the one read by BufferMemory
    databaseEntities: IDatabaseEntity
    // Chatflow id used to filter stored chat messages
    chatflowid: string
    // Organisation id forwarded to mapChatMessageToBaseMessage
    orgId: string
}

/**
 * Default memory used when no memory node is connected.
 *
 * History is read from the `ChatMessage` table of the application database;
 * writing and clearing are no-ops here because they are handled at server level.
 */
class BufferMemory extends DtamindMemory implements MemoryMethods {
    appDataSource: DataSource
    databaseEntities: IDatabaseEntity
    chatflowid: string
    orgId: string

    /**
     * @param fields LangChain buffer-memory options plus the database handles
     * and ids needed to look up stored chat messages
     */
    constructor(fields: BufferMemoryInput & BufferMemoryExtendedInput) {
        super(fields)
        this.appDataSource = fields.appDataSource
        this.databaseEntities = fields.databaseEntities
        this.chatflowid = fields.chatflowid
        this.orgId = fields.orgId
    }

    /**
     * Loads the stored messages of a session, oldest first.
     *
     * @param overrideSessionId session to load; an empty value returns an empty list
     * @param returnBaseMessages when true, messages are converted with `mapChatMessageToBaseMessage`
     * @param prependMessages optional messages placed in front of the stored ones
     * @returns the history as `IMessage[]`, or as `BaseMessage[]` when `returnBaseMessages` is set
     */
    async getChatMessages(
        overrideSessionId = '',
        returnBaseMessages = false,
        prependMessages?: IMessage[]
    ): Promise<IMessage[] | BaseMessage[]> {
        if (!overrideSessionId) return []

        const chatMessage = await this.appDataSource.getRepository(this.databaseEntities['ChatMessage']).find({
            where: {
                sessionId: overrideSessionId,
                chatflowid: this.chatflowid
            },
            order: {
                createdDate: 'ASC'
            }
        })

        if (prependMessages?.length) {
            chatMessage.unshift(...prependMessages)
        }

        if (returnBaseMessages) {
            return await mapChatMessageToBaseMessage(chatMessage, this.orgId)
        }

        // Stored rows expose `content`/`role`, while prepended entries are typed as IMessage and
        // may only carry `message`/`type`. Fall back to the IMessage fields so prepended
        // messages keep their text and type instead of turning into undefined values.
        return chatMessage.map(
            (m): IMessage => ({
                message: (m.content ?? m.message) as string,
                type: m.role ?? m.type
            })
        )
    }

    /** Intentionally does nothing. */
    async addChatMessages(): Promise<void> {
        // adding chat messages is done on server level
        return
    }

    /** Intentionally does nothing. */
    async clearChatMessages(): Promise<void> {
        // clearing chat messages is done on server level
        return
    }
}

// Expose the node class under the `nodeClass` key (presumably the key the node loader looks up — confirm against the loader).
module.exports = { nodeClass: ConversationalRetrievalQAChain_Chains }
