import { ResourceSource } from '../types/common';
import { ResourceFetcher } from '../utils/ResourceFetcher';
import { Template } from '@huggingface/jinja';
import { DEFAULT_CHAT_CONFIG } from '../constants/llmDefaults';
import {
  ChatConfig,
  GenerationConfig,
  LLMCapability,
  LLMTool,
  Message,
  SPECIAL_TOKENS,
  ToolsConfig,
} from '../types/llm';
import { parseToolCall } from '../utils/llm';
import { Logger } from '../common/Logger';
import { RnExecutorchError, parseUnknownError } from '../errors/errorUtils';
import { RnExecutorchErrorCode } from '../errors/ErrorCodes';

/**
 * Orchestrates the lifecycle of an on-device LLM: fetching model/tokenizer
 * resources, applying the tokenizer's Jinja chat template, running generation
 * through the native module, and keeping readiness / generating / message
 * history state in sync with user-supplied callbacks.
 */
export class LLMController {
  // Native bridge handle returned by global.loadLLM; its shape is defined on
  // the native side, hence `any`.
  private nativeModule: any;
  private chatConfig: ChatConfig = DEFAULT_CHAT_CONFIG;
  private toolsConfig: ToolsConfig | undefined;
  // Parsed tokenizer_config.json; schema varies per model, hence `any`.
  private tokenizerConfig: any;
  // Streaming listener handed to native generate calls; set once load() succeeds.
  private onToken?: (token: string) => void;
  private _isReady = false;
  private _isGenerating = false;
  private _messageHistory: Message[] = [];
  // User callbacks, wrapped so internal mirrors (_isReady, etc.) stay in sync.
  private tokenCallback: (token: string) => void;
  private messageHistoryCallback: (messageHistory: Message[]) => void;
  private isReadyCallback: (isReady: boolean) => void;
  private isGeneratingCallback: (isGenerating: boolean) => void;

  /**
   * All callbacks are optional; each is wrapped so the controller's internal
   * state is updated before the user callback (if any) is invoked.
   */
  constructor({
    tokenCallback,
    messageHistoryCallback,
    isReadyCallback,
    isGeneratingCallback,
  }: {
    tokenCallback?: (token: string) => void;
    messageHistoryCallback?: (messageHistory: Message[]) => void;
    isReadyCallback?: (isReady: boolean) => void;
    isGeneratingCallback?: (isGenerating: boolean) => void;
  }) {
    this.tokenCallback = (token) => {
      tokenCallback?.(token);
    };
    this.messageHistoryCallback = (messageHistory) => {
      this._messageHistory = messageHistory;
      messageHistoryCallback?.(messageHistory);
    };
    this.isReadyCallback = (isReady) => {
      this._isReady = isReady;
      isReadyCallback?.(isReady);
    };
    this.isGeneratingCallback = (isGenerating) => {
      this._isGenerating = isGenerating;
      isGeneratingCallback?.(isGenerating);
    };
  }

  /** Whether a model is loaded and ready to generate. */
  public get isReady() {
    return this._isReady;
  }

  /** Whether a generation is currently in progress. */
  public get isGenerating() {
    return this._isGenerating;
  }

  /** The current conversation history managed by sendMessage()/deleteMessage(). */
  public get messageHistory() {
    return this._messageHistory;
  }

  /**
   * Downloads (or resolves from cache) the model, tokenizer and tokenizer
   * config, then loads the model into the native runtime. Resets history,
   * generating and readiness state first, so a failed load leaves the
   * controller in a consistent "not ready" state.
   *
   * @throws RnExecutorchError (DownloadInterrupted) if any file is missing
   *         after fetching; rethrows parsed native/JS errors otherwise.
   */
  public async load({
    modelSource,
    tokenizerSource,
    tokenizerConfigSource,
    capabilities,
    onDownloadProgressCallback,
  }: {
    modelSource: ResourceSource;
    tokenizerSource: ResourceSource;
    tokenizerConfigSource: ResourceSource;
    capabilities?: readonly LLMCapability[];
    onDownloadProgressCallback?: (downloadProgress: number) => void;
  }) {
    // reset inner state when loading new model
    this.messageHistoryCallback(this.chatConfig.initialMessageHistory);
    this.isGeneratingCallback(false);
    this.isReadyCallback(false);

    try {
      // Tokenizer files and the (typically much larger) model are fetched in
      // parallel; only the model download reports progress.
      const tokenizersPromise = ResourceFetcher.fetch(
        undefined,
        tokenizerSource,
        tokenizerConfigSource
      );

      const modelPromise = ResourceFetcher.fetch(
        onDownloadProgressCallback,
        modelSource
      );

      const [tokenizersResults, modelResult] = await Promise.all([
        tokenizersPromise,
        modelPromise,
      ]);

      const tokenizerPath = tokenizersResults?.[0];
      const tokenizerConfigPath = tokenizersResults?.[1];
      const modelPath = modelResult?.[0];

      if (!tokenizerPath || !tokenizerConfigPath || !modelPath) {
        throw new RnExecutorchError(
          RnExecutorchErrorCode.DownloadInterrupted,
          'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
        );
      }

      // NOTE: guarded above, so no non-null assertion is needed here.
      this.tokenizerConfig = JSON.parse(
        await ResourceFetcher.fs.readAsString(tokenizerConfigPath)
      );

      // Release a previously loaded model before replacing it.
      if (this.nativeModule) {
        this.nativeModule.unload();
      }

      this.nativeModule = await global.loadLLM(
        modelPath,
        tokenizerPath,
        capabilities ?? []
      );
      this.isReadyCallback(true);
      this.onToken = (data: string) => {
        if (!data) {
          return;
        }

        const filtered = this.filterSpecialTokens(data);

        if (filtered.length === 0) {
          return;
        }
        this.tokenCallback(filtered);
      };
    } catch (e) {
      Logger.error('Load failed:', e);
      this.isReadyCallback(false);
      throw parseUnknownError(e);
    }
  }

  /** Replaces the user token callback installed at construction time. */
  public setTokenCallback(tokenCallback: (token: string) => void) {
    this.tokenCallback = tokenCallback;
  }

  /**
   * Applies chat/tools configuration and forwards generation settings to the
   * native module. Resets message history and generating state.
   *
   * @throws RnExecutorchError (ModuleNotLoaded) if generation settings are
   *         provided before a model has been loaded.
   * @throws RnExecutorchError (InvalidConfig) if topp is outside [0, 1].
   */
  public configure({
    chatConfig,
    toolsConfig,
    generationConfig,
  }: {
    chatConfig?: Partial<ChatConfig>;
    toolsConfig?: ToolsConfig;
    generationConfig?: GenerationConfig;
  }) {
    this.chatConfig = { ...DEFAULT_CHAT_CONFIG, ...chatConfig };
    this.toolsConfig = toolsConfig;

    // Generation settings are applied on the native side, so a model must
    // already be loaded; fail with a descriptive error instead of a TypeError.
    const hasGenerationSettings =
      generationConfig !== undefined &&
      (generationConfig.outputTokenBatchSize !== undefined ||
        generationConfig.batchTimeInterval !== undefined ||
        generationConfig.temperature !== undefined ||
        generationConfig.topp !== undefined);
    if (hasGenerationSettings && !this.nativeModule) {
      throw new RnExecutorchError(
        RnExecutorchErrorCode.ModuleNotLoaded,
        'The model is currently not loaded. Please load the model before configuring generation settings.'
      );
    }

    // Explicit undefined checks: 0 is meaningful here (e.g. temperature 0 for
    // greedy decoding, topp 0 is within the valid range) and must not be
    // skipped by a truthiness test.
    if (generationConfig?.outputTokenBatchSize !== undefined) {
      this.nativeModule.setCountInterval(generationConfig.outputTokenBatchSize);
    }
    if (generationConfig?.batchTimeInterval !== undefined) {
      this.nativeModule.setTimeInterval(generationConfig.batchTimeInterval);
    }
    if (generationConfig?.temperature !== undefined) {
      this.nativeModule.setTemperature(generationConfig.temperature);
    }
    if (generationConfig?.topp !== undefined) {
      if (generationConfig.topp < 0 || generationConfig.topp > 1) {
        throw new RnExecutorchError(
          RnExecutorchErrorCode.InvalidConfig,
          'Top P has to be in range [0, 1]'
        );
      }
      this.nativeModule.setTopp(generationConfig.topp);
    }

    // reset inner state when loading new configuration
    this.messageHistoryCallback(this.chatConfig.initialMessageHistory);
    this.isGeneratingCallback(false);
  }

  /**
   * Returns the tokenizer's image placeholder token, required by multimodal
   * generation.
   *
   * @throws RnExecutorchError (InvalidConfig) when the tokenizer config lacks
   *         an 'image_token' field.
   */
  private getImageToken(): string {
    const token = this.tokenizerConfig.image_token;
    if (!token) {
      throw new RnExecutorchError(
        RnExecutorchErrorCode.InvalidConfig,
        "Tokenizer config is missing 'image_token'. Vision models require tokenizerConfigSource with an 'image_token' field."
      );
    }
    return token;
  }

  /** Strips EOS/PAD special tokens (when defined by the tokenizer) from text. */
  private filterSpecialTokens(text: string): string {
    let filtered = text;
    if (
      SPECIAL_TOKENS.EOS_TOKEN in this.tokenizerConfig &&
      this.tokenizerConfig.eos_token
    ) {
      filtered = filtered.replaceAll(this.tokenizerConfig.eos_token, '');
    }
    if (
      SPECIAL_TOKENS.PAD_TOKEN in this.tokenizerConfig &&
      this.tokenizerConfig.pad_token
    ) {
      filtered = filtered.replaceAll(this.tokenizerConfig.pad_token, '');
    }
    return filtered;
  }

  /**
   * Unloads the native model and resets readiness/generating state.
   *
   * @throws RnExecutorchError (ModelGenerating) if a generation is running;
   *         interrupt() first.
   */
  public delete() {
    if (this._isGenerating) {
      throw new RnExecutorchError(
        RnExecutorchErrorCode.ModelGenerating,
        'You cannot delete the model now. You need to interrupt it first.'
      );
    }

    this.onToken = () => {};
    if (this.nativeModule) {
      this.nativeModule.unload();
    }
    this.isReadyCallback(false);
    this.isGeneratingCallback(false);
  }

  /**
   * Runs a single raw (already-templated) prompt through the model, streaming
   * tokens via the installed callback.
   *
   * @param input Fully rendered prompt text.
   * @param imagePaths Optional image paths; when present, the multimodal
   *        native path is used with the tokenizer's image token.
   * @returns The complete response with special tokens stripped.
   * @throws RnExecutorchError (ModuleNotLoaded | ModelGenerating) on invalid
   *         state; native errors are rethrown via parseUnknownError.
   */
  public async forward(input: string, imagePaths?: string[]): Promise<string> {
    if (!this._isReady) {
      throw new RnExecutorchError(
        RnExecutorchErrorCode.ModuleNotLoaded,
        'The model is currently not loaded. Please load the model before calling forward().'
      );
    }
    if (this._isGenerating) {
      throw new RnExecutorchError(
        RnExecutorchErrorCode.ModelGenerating,
        'The model is currently generating. Please wait until previous model run is complete.'
      );
    }
    try {
      this.isGeneratingCallback(true);
      this.nativeModule.reset();
      const response =
        imagePaths && imagePaths.length > 0
          ? await this.nativeModule.generateMultimodal(
              input,
              imagePaths,
              this.getImageToken(),
              this.onToken
            )
          : await this.nativeModule.generate(input, this.onToken);
      return this.filterSpecialTokens(response);
    } catch (e) {
      throw parseUnknownError(e);
    } finally {
      // Always clear the generating flag, even on error/interrupt.
      this.isGeneratingCallback(false);
    }
  }

  /**
   * Interrupts an in-flight generation.
   *
   * @throws RnExecutorchError (ModuleNotLoaded) if no model is loaded.
   */
  public interrupt() {
    if (!this.nativeModule) {
      throw new RnExecutorchError(
        RnExecutorchErrorCode.ModuleNotLoaded,
        "Cannot interrupt a model that's not loaded."
      );
    }
    this.nativeModule.interrupt();
  }

  /**
   * Number of tokens generated in the most recent run.
   *
   * @throws RnExecutorchError (ModuleNotLoaded) if no model is loaded.
   */
  public getGeneratedTokenCount(): number {
    if (!this.nativeModule) {
      throw new RnExecutorchError(
        RnExecutorchErrorCode.ModuleNotLoaded,
        "Cannot get token count for a model that's not loaded."
      );
    }
    return this.nativeModule.getGeneratedTokenCount();
  }

  /**
   * Number of prompt tokens consumed in the most recent run.
   *
   * @throws RnExecutorchError (ModuleNotLoaded) if no model is loaded.
   */
  public getPromptTokenCount(): number {
    if (!this.nativeModule) {
      throw new RnExecutorchError(
        RnExecutorchErrorCode.ModuleNotLoaded,
        "Cannot get prompt token count for a model that's not loaded."
      );
    }
    return this.nativeModule.getPromptTokenCount();
  }

  /** Prompt tokens + generated tokens for the most recent run. */
  public getTotalTokenCount(): number {
    return this.getGeneratedTokenCount() + this.getPromptTokenCount();
  }

  /**
   * Renders the message list through the tokenizer's chat template and runs
   * generation. Does NOT update message history; use sendMessage() for
   * history-managed conversations.
   *
   * @throws RnExecutorchError (ModuleNotLoaded | InvalidUserInput) on invalid
   *         state or an empty messages array.
   */
  public async generate(
    messages: Message[],
    tools?: LLMTool[]
  ): Promise<string> {
    if (!this._isReady) {
      throw new RnExecutorchError(
        RnExecutorchErrorCode.ModuleNotLoaded,
        'The model is currently not loaded. Please load the model before calling generate().'
      );
    }
    if (messages.length === 0) {
      throw new RnExecutorchError(
        RnExecutorchErrorCode.InvalidUserInput,
        'Messages array is empty!'
      );
    }
    if (messages[0] && messages[0].role !== 'system') {
      Logger.warn(
        `You are not providing system prompt. You can pass it in the first message using { role: 'system', content: YOUR_PROMPT }. Otherwise prompt from your model's chat template will be used.`
      );
    }

    const imagePaths = messages
      .filter((m) => m.mediaPath)
      .map((m) => m.mediaPath!);

    const renderedChat: string = this.applyChatTemplate(
      messages,
      this.tokenizerConfig,
      tools,
      // eslint-disable-next-line camelcase
      { tools_in_user_message: false, add_generation_prompt: true }
    );

    return await this.forward(
      renderedChat,
      imagePaths.length > 0 ? imagePaths : undefined
    );
  }

  /**
   * Appends a user message (optionally with an image) to the history, trims
   * the context with the configured context strategy, generates a response and
   * records it, and dispatches any tool calls found in the response.
   *
   * @returns The raw model response (before tool-call handling).
   */
  public async sendMessage(
    message: string,
    media?: { imagePath?: string }
  ): Promise<string> {
    const mediaPath = media?.imagePath;
    const newMessage: Message = {
      content: message,
      role: 'user',
      ...(mediaPath ? { mediaPath } : {}),
    };
    const updatedHistory = [...this._messageHistory, newMessage];
    this.messageHistoryCallback(updatedHistory);

    // Messages carrying media are expanded into multimodal content parts so
    // the chat template emits the image placeholder.
    const historyForTemplate = updatedHistory.map((m) =>
      m.mediaPath
        ? {
            ...m,
            content: [
              { type: 'image' },
              { type: 'text', text: m.content },
            ] as any,
          }
        : m
    );

    const visualTokenCount = this.nativeModule.getVisualTokenCount();
    // Estimates token usage for a candidate history: rendered text tokens plus
    // the visual tokens each image expands to (minus the placeholder already
    // counted in the text).
    const countTokensCallback = (messages: Message[]) => {
      const rendered = this.applyChatTemplate(
        messages,
        this.tokenizerConfig,
        this.toolsConfig?.tools,
        // eslint-disable-next-line camelcase
        { tools_in_user_message: false, add_generation_prompt: true }
      );
      const textTokens = this.nativeModule.countTextTokens(rendered);
      const imageCount = messages.filter((m) => m.mediaPath).length;
      return textTokens + imageCount * (visualTokenCount - 1);
    };
    const maxContextLength = this.nativeModule.getMaxContextLength();
    const messageHistoryWithPrompt =
      this.chatConfig.contextStrategy.buildContext(
        this.chatConfig.systemPrompt,
        historyForTemplate,
        maxContextLength,
        countTokensCallback
      );

    const response = await this.generate(
      messageHistoryWithPrompt,
      this.toolsConfig?.tools
    );

    if (!this.toolsConfig || this.toolsConfig.displayToolCalls) {
      this.messageHistoryCallback([
        ...this._messageHistory,
        { content: response, role: 'assistant' },
      ]);
    }

    if (this.toolsConfig) {
      const toolCalls = parseToolCall(response);
      for (const toolCall of toolCalls) {
        // Fire-and-forget on purpose: tool results are appended to the history
        // whenever they resolve. Failures are logged rather than becoming
        // unhandled promise rejections.
        this.toolsConfig
          .executeToolCallback(toolCall)
          .then((toolResponse: string | null) => {
            if (toolResponse) {
              this.messageHistoryCallback([
                ...this._messageHistory,
                { content: toolResponse, role: 'assistant' },
              ]);
            }
          })
          .catch((e: unknown) => {
            Logger.error('Tool call failed:', e);
          });
      }
    }

    return response;
  }

  /** Deletes the message at `index` and everything after it. */
  public deleteMessage(index: number) {
    // we delete referenced message and all messages after it
    // so the model responses that used them are deleted as well
    const newMessageHistory = this._messageHistory.slice(0, index);

    this.messageHistoryCallback(newMessageHistory);
  }

  /**
   * Renders messages (and optional tools) through the tokenizer's Jinja
   * chat_template, exposing the tokenizer's special tokens to the template.
   *
   * @throws RnExecutorchError (TokenizerError) when the tokenizer config has
   *         no chat_template.
   */
  private applyChatTemplate(
    messages: Message[],
    tokenizerConfig: any,
    tools?: LLMTool[],
    templateFlags?: Record<string, unknown>
  ): string {
    if (!tokenizerConfig.chat_template) {
      throw new RnExecutorchError(
        RnExecutorchErrorCode.TokenizerError,
        "Tokenizer config doesn't include chat_template"
      );
    }
    const template = new Template(tokenizerConfig.chat_template);

    // Only expose special tokens the tokenizer actually defines.
    const specialTokens = Object.fromEntries(
      Object.values(SPECIAL_TOKENS)
        .filter((key) => key in tokenizerConfig)
        .map((key) => [key, tokenizerConfig[key]])
    );

    const result = template.render({
      messages,
      tools,
      ...templateFlags,
      ...specialTokens,
    });
    return result;
  }
}
