///////////////////////////////////////////////////////////////////////
// File:        lstmrecognizer.cpp
// Description: Top-level line recognizer class for LSTM-based networks.
// Author:      Ray Smith
// Created:     Thu May 02 10:59:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

// Include automatically generated configuration file if running autoconf.
#ifdef HAVE_CONFIG_H
#  include "config_auto.h"
#endif

#include "lstmrecognizer.h"

#include "allheaders.h"
#include "callcpp.h"
#include "dict.h"
#include "genericheap.h"
#include "helpers.h"
#include "imagedata.h"
#include "input.h"
#include "lstm.h"
#include "normalis.h"
#include "pageres.h"
#include "ratngs.h"
#include "recodebeam.h"
#include "scrollview.h"
#include "statistc.h"
#include "tprintf.h"

namespace tesseract {

// Default ratio between dict and non-dict words. Passed to search_->Decode by
// the word-producing RecognizeLine overload.
const double kDictRatio = 2.25;
// Default certainty offset to give the dictionary a chance. Passed to
// search_->Decode by the word-producing RecognizeLine overload.
const double kCertOffset = -0.085;

// Constructs a recognizer in its default state (via the default constructor)
// and then records language_data_path_prefix in ccutil_.
LSTMRecognizer::LSTMRecognizer(const STRING language_data_path_prefix)
    : LSTMRecognizer() {
  ccutil_.language_data_path_prefix = language_data_path_prefix;
}

// Default constructor. The recognizer starts with no network, dictionary,
// beam search or debug window, with all iteration counters, training flags,
// learning rate, momentum and adam beta set to zero, and with null_char_ set
// to UNICHAR_BROKEN.
LSTMRecognizer::LSTMRecognizer()
    : network_(nullptr),
      training_flags_(0),
      training_iteration_(0),
      sample_iteration_(0),
      null_char_(UNICHAR_BROKEN),
      learning_rate_(0.0f),
      momentum_(0.0f),
      adam_beta_(0.0f),
      dict_(nullptr),
      search_(nullptr),
      debug_win_(nullptr) {}

// Destructor. Deletes the network, the dictionary and the beam search object.
// Note that debug_win_ is not deleted here.
LSTMRecognizer::~LSTMRecognizer() {
  delete network_;
  delete dict_;
  delete search_;
}

// Loads the LSTM model from mgr. Returns false if the LSTM component is
// missing or cannot be deserialized. When lang is non-null a dictionary load
// is also attempted, but its outcome does not affect the return value, as the
// recognizer is allowed to run without a dictionary.
bool LSTMRecognizer::Load(const ParamsVectors* params, const char* lang,
                          TessdataManager* mgr) {
  TFile fp;
  const bool model_loaded =
      mgr->GetComponent(TESSDATA_LSTM, &fp) && DeSerialize(mgr, &fp);
  if (!model_loaded) return false;
  // The dictionary is optional, so the result of loading it is ignored.
  if (lang != nullptr) LoadDictionary(params, lang, mgr);
  return true;
}

// Writes the recognizer to fp. Returns false in case of error.
// The unicharset (and the recoder, if recoding) are embedded in the output
// unless mgr is non-null and already has both the LSTM recoder and the LSTM
// unicharset components available.
bool LSTMRecognizer::Serialize(const TessdataManager* mgr, TFile* fp) const {
  const bool charsets_in_mgr =
      mgr != nullptr && mgr->IsComponentAvailable(TESSDATA_LSTM_RECODER) &&
      mgr->IsComponentAvailable(TESSDATA_LSTM_UNICHARSET);
  const bool include_charsets = !charsets_in_mgr;
  // Short-circuit chain: each write only happens if every earlier one
  // succeeded, and the order of the writes defines the file layout.
  return network_->Serialize(fp) &&
         (!include_charsets || GetUnicharset().save_to_file(fp)) &&
         network_str_.Serialize(fp) &&
         fp->Serialize(&training_flags_) &&
         fp->Serialize(&training_iteration_) &&
         fp->Serialize(&sample_iteration_) &&
         fp->Serialize(&null_char_) &&
         fp->Serialize(&adam_beta_) &&
         fp->Serialize(&learning_rate_) &&
         fp->Serialize(&momentum_) &&
         (!include_charsets || !IsRecoding() || recoder_.Serialize(fp));
}

// Reads from the given file. Returns false in case of error.
// Any existing network_ is deleted and replaced by the one created from fp.
// The unicharset and recoder are read from fp itself when mgr is null or
// lacks either the LSTM recoder or the LSTM unicharset component; otherwise
// they are loaded from mgr via LoadCharsets.
// On a false return the recognizer may have been partially updated.
bool LSTMRecognizer::DeSerialize(const TessdataManager* mgr, TFile* fp) {
  delete network_;
  network_ = Network::CreateFromFile(fp);
  if (network_ == nullptr) return false;
  // Mirrors the condition used by Serialize to decide whether the charsets
  // are embedded in the stream.
  bool include_charsets = mgr == nullptr ||
                          !mgr->IsComponentAvailable(TESSDATA_LSTM_RECODER) ||
                          !mgr->IsComponentAvailable(TESSDATA_LSTM_UNICHARSET);
  if (include_charsets && !ccutil_.unicharset.load_from_file(fp, false))
    return false;
  // The remaining fields are read in the same order that Serialize writes
  // them.
  if (!network_str_.DeSerialize(fp)) return false;
  if (!fp->DeSerialize(&training_flags_)) return false;
  if (!fp->DeSerialize(&training_iteration_)) return false;
  if (!fp->DeSerialize(&sample_iteration_)) return false;
  if (!fp->DeSerialize(&null_char_)) return false;
  if (!fp->DeSerialize(&adam_beta_)) return false;
  if (!fp->DeSerialize(&learning_rate_)) return false;
  if (!fp->DeSerialize(&momentum_)) return false;
  // Exactly one of these two loads is attempted: the recoder from fp, or both
  // charsets from mgr.
  if (include_charsets && !LoadRecoder(fp)) return false;
  if (!include_charsets && !LoadCharsets(mgr)) return false;
  network_->SetRandomizer(&randomizer_);
  network_->CacheXScaleFactor(network_->XScaleFactor());
  return true;
}

// Loads the unicharset and then the recoder from their components in mgr.
// Returns false if either component is missing or fails to load.
bool LSTMRecognizer::LoadCharsets(const TessdataManager* mgr) {
  TFile fp;
  const bool unicharset_loaded =
      mgr->GetComponent(TESSDATA_LSTM_UNICHARSET, &fp) &&
      ccutil_.unicharset.load_from_file(&fp, false);
  if (!unicharset_loaded) return false;
  // The same TFile is reused for the recoder component.
  return mgr->GetComponent(TESSDATA_LSTM_RECODER, &fp) && LoadRecoder(&fp);
}

// Sets up the recoder. If the model is not recoding, a pass-through recoder
// is built from the unicharset and TF_COMPRESS_UNICHARSET is added to the
// training flags. Otherwise the recoder is read from fp and checked to make
// sure that space still encodes to UNICHAR_SPACE.
// Returns false if reading fails or the space check fails.
bool LSTMRecognizer::LoadRecoder(TFile* fp) {
  if (!IsRecoding()) {
    recoder_.SetupPassThrough(GetUnicharset());
    training_flags_ |= TF_COMPRESS_UNICHARSET;
    return true;
  }
  if (!recoder_.DeSerialize(fp)) return false;
  RecodedCharID space_code;
  recoder_.EncodeUnichar(UNICHAR_SPACE, &space_code);
  if (space_code(0) == UNICHAR_SPACE) return true;
  tprintf("Space was garbled in recoding!!\n");
  return false;
}

// Loads the dictionary if possible from the traineddata file.
// Prints a warning message, and returns false but otherwise fails silently
// and continues to work without it if loading fails.
// Note that dictionary load is independent from DeSerialize, but dependent
// on the unicharset matching. This enables training to deserialize a model
// from checkpoint or restore without having to go back and reload the
// dictionary.
// Some parameters have to be passed in (from langdata/config/api via Tesseract)
bool LSTMRecognizer::LoadDictionary(const ParamsVectors* params,
                                    const char* lang, TessdataManager* mgr) {
  delete dict_;
  dict_ = new Dict(&ccutil_);
  dict_->user_words_file.ResetFrom(params);
  dict_->user_words_suffix.ResetFrom(params);
  dict_->user_patterns_file.ResetFrom(params);
  dict_->user_patterns_suffix.ResetFrom(params);
  dict_->SetupForLoad(Dict::GlobalDawgCache());
  dict_->LoadLSTM(lang, mgr);
  if (dict_->FinishLoad()) return true;  // Success.
  tprintf("Failed to load any lstm-specific dictionaries for lang %s!!\n",
          lang);
  delete dict_;
  dict_ = nullptr;
  return false;
}

// Recognizes the line image, contained within image_data, returning the
// ratings matrix and matching box_word for each WERD_RES in the output.
void LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert,
                                   bool debug, double worst_dict_cert,
                                   const TBOX& line_box,
                                   PointerVector<WERD_RES>* words,
                                   int lstm_choice_mode) {
  NetworkIO outputs;
  float scale_factor;
  NetworkIO inputs;
  if (!RecognizeLine(image_data, invert, debug, false, false, &scale_factor,
                     &inputs, &outputs))
    return;
  if (search_ == nullptr) {
    search_ =
        new RecodeBeamSearch(recoder_, null_char_, SimpleTextOutput(), dict_);
  }
  search_->Decode(outputs, kDictRatio, kCertOffset, worst_dict_cert,
                  &GetUnicharset(), lstm_choice_mode);
  search_->ExtractBestPathAsWords(line_box, scale_factor, debug,
                                  &GetUnicharset(), words, lstm_choice_mode);
}

// Helper computes the min, mean and standard deviation of the best output at
// each timestep, ignoring timesteps where the best label is null_char_.
// The outputs are quantized to integer buckets scaled by INT8_MAX to collect
// the statistics and scaled back for the results.
void LSTMRecognizer::OutputStats(const NetworkIO& outputs, float* min_output,
                                 float* mean_output, float* sd) {
  const int kOutputScale = INT8_MAX;
  STATS stats(0, kOutputScale + 1);
  for (int t = 0; t < outputs.Width(); ++t) {
    const int winner = outputs.BestLabel(t, nullptr);
    if (winner == null_char_) continue;
    const float winner_output = outputs.f(t)[winner];
    stats.add(static_cast<int>(kOutputScale * winner_output), 1);
  }
  if (stats.get_total() == 0) {
    // If the output is all nulls it could be that the photometric
    // interpretation is wrong, so make it look bad, so the other way can win,
    // even if not great.
    *min_output = 0.0f;
    *mean_output = 0.0f;
    *sd = 1.0f;
    return;
  }
  *min_output = static_cast<float>(stats.min_bucket()) / kOutputScale;
  *mean_output = stats.mean() / kOutputScale;
  *sd = stats.sd() / kOutputScale;
}

// Runs the network forward on the line image in image_data.
// On success returns true, with *inputs holding the network input, *outputs
// the network output, and *scale_factor the reduction factor from image to
// coords.
// Returns false, after printing a message, if no input image could be
// prepared, or if the network is training and the prepared image is wider
// than kMaxImageWidth.
// If invert is true and the min value from OutputStats for the first pass is
// below 0.5, the image is inverted and run forward again. The inverted result
// replaces *inputs and *outputs only if its min and mean are higher and its
// sd is lower than those of the first pass. Otherwise, if re_invert is true,
// the original inputs are run forward once more.
// If upside_down is true, the image is rotated by 180 degrees before use.
// If debug is true, the forward result is displayed and the activation path
// is printed.
bool LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert,
                                   bool debug, bool re_invert, bool upside_down,
                                   float* scale_factor, NetworkIO* inputs,
                                   NetworkIO* outputs) {
  // Maximum width of image to train on.
  const int kMaxImageWidth = 2560;
  // This ensures consistent recognition results.
  SetRandomSeed();
  int min_width = network_->XScaleFactor();
  Pix* pix = Input::PrepareLSTMInputs(image_data, network_, min_width,
                                      &randomizer_, scale_factor);
  if (pix == nullptr) {
    tprintf("Line cannot be recognized!!\n");
    return false;
  }
  // The width limit only applies while the network is training.
  if (network_->IsTraining() && pixGetWidth(pix) > kMaxImageWidth) {
    tprintf("Image too large to learn!! Size = %dx%d\n", pixGetWidth(pix),
            pixGetHeight(pix));
    pixDestroy(&pix);
    return false;
  }
  if (upside_down) pixRotate180(pix, pix);
  // Reduction factor from image to coords.
  *scale_factor = min_width / *scale_factor;
  inputs->set_int_mode(IsIntMode());
  // Re-seed before each input preparation / forward pass.
  SetRandomSeed();
  Input::PreparePixInput(network_->InputShape(), pix, &randomizer_, inputs);
  network_->Forward(debug, *inputs, nullptr, &scratch_space_, outputs);
  // Check for auto inversion.
  float pos_min, pos_mean, pos_sd;
  OutputStats(*outputs, &pos_min, &pos_mean, &pos_sd);
  if (invert && pos_min < 0.5) {
    // Run again inverted and see if it is any better.
    NetworkIO inv_inputs, inv_outputs;
    inv_inputs.set_int_mode(IsIntMode());
    SetRandomSeed();
    pixInvert(pix, pix);
    Input::PreparePixInput(network_->InputShape(), pix, &randomizer_,
                           &inv_inputs);
    network_->Forward(debug, inv_inputs, nullptr, &scratch_space_,
                      &inv_outputs);
    float inv_min, inv_mean, inv_sd;
    OutputStats(inv_outputs, &inv_min, &inv_mean, &inv_sd);
    // The inverted result must win on all three statistics to be used.
    if (inv_min > pos_min && inv_mean > pos_mean && inv_sd < pos_sd) {
      // Inverted did better. Use inverted data.
      if (debug) {
        tprintf("Inverting image: old min=%g, mean=%g, sd=%g, inv %g,%g,%g\n",
                pos_min, pos_mean, pos_sd, inv_min, inv_mean, inv_sd);
      }
      *outputs = inv_outputs;
      *inputs = inv_inputs;
    } else if (re_invert) {
      // Inverting was not an improvement, so undo and run again, so the
      // outputs match the best forward result.
      SetRandomSeed();
      network_->Forward(debug, *inputs, nullptr, &scratch_space_, outputs);
    }
  }
  // The image is no longer needed once the network inputs have been built.
  pixDestroy(&pix);
  if (debug) {
    GenericVector<int> labels, coords;
    LabelsFromOutputs(*outputs, &labels, &coords);
    DisplayForward(*inputs, labels, coords, "LSTMForward", &debug_win_);
    DebugActivationPath(*outputs, labels, coords);
  }
  return true;
}

// Converts an array of labels to a utf-8 string, whether or not the labels
// are augmented with character boundaries. Labels equal to null_char_
// contribute nothing to the result.
STRING LSTMRecognizer::DecodeLabels(const GenericVector<int>& labels) {
  STRING text;
  int pos = 0;
  while (pos < labels.size()) {
    // DecodeLabel advances next past every label it consumed.
    int next = pos + 1;
    if (labels[pos] != null_char_) {
      text += DecodeLabel(labels, pos, &next, nullptr);
    }
    pos = next;
  }
  return text;
}

// Displays the forward results in a window with the characters and
// boundaries as determined by the labels and label_coords.
void LSTMRecognizer::DisplayForward(const NetworkIO& inputs,
                                    const GenericVector<int>& labels,
                                    const GenericVector<int>& label_coords,
                                    const char* window_name,
                                    ScrollView** window) {
#ifndef GRAPHICS_DISABLED  // do nothing if there's no graphics
  Pix* input_pix = inputs.ToPix();
  Network::ClearWindow(false, window_name, pixGetWidth(input_pix),
                       pixGetHeight(input_pix), window);
  int line_height = Network::DisplayImage(input_pix, *window);
  DisplayLSTMOutput(labels, label_coords, line_height, *window);
#endif  // GRAPHICS_DISABLED
}

// Draws the decoded labels, and a vertical cut line for each, in window.
// Nulls get a red line at their own x position; decoded characters get green
// text and a green line at the x position of the middle of their label run.
// xcoords are multiplied by the network's x scale factor to give window
// coordinates. Size of labels should match xcoords.
// Does nothing when built with GRAPHICS_DISABLED.
void LSTMRecognizer::DisplayLSTMOutput(const GenericVector<int>& labels,
                                       const GenericVector<int>& xcoords,
                                       int height, ScrollView* window) {
#ifndef GRAPHICS_DISABLED  // do nothing if there's no graphics
  const int x_scale = network_->XScaleFactor();
  window->TextAttributes("Arial", height / 4, false, false, false);
  int pos = 0;
  int next = 1;
  while (pos < labels.size()) {
    int cut_x = xcoords[pos] * x_scale;
    if (labels[pos] != null_char_) {
      window->Pen(ScrollView::GREEN);
      const char* text = DecodeLabel(labels, pos, &next, nullptr);
      // A lone backslash is doubled before being sent to the window.
      if (*text == '\\') text = "\\\\";
      cut_x = xcoords[(pos + next) / 2] * x_scale;
      window->Text(cut_x, height, text);
    } else {
      next = pos + 1;
      window->Pen(ScrollView::RED);
    }
    window->Line(cut_x, 0, cut_x, height * 3 / 2);
    pos = next;
  }
  window->Update();
#endif  // GRAPHICS_DISABLED
}

// Prints debug output detailing the activation path that is implied by the
// labels and xcoords: one DebugActivationRange line per label, plus a leading
// null range if the first label does not start at x = 0.
void LSTMRecognizer::DebugActivationPath(const NetworkIO& outputs,
                                         const GenericVector<int>& labels,
                                         const GenericVector<int>& xcoords) {
  if (xcoords[0] > 0) {
    DebugActivationRange(outputs, "<null>", null_char_, 0, xcoords[0]);
  }
  int next = 1;
  for (int pos = 0; pos < labels.size(); pos = next) {
    if (labels[pos] == null_char_) {
      next = pos + 1;
      DebugActivationRange(outputs, "<null>", null_char_, xcoords[pos],
                           xcoords[next]);
      continue;
    }
    // The first label of the run is shown with the fully decoded text; any
    // further labels of the same run are shown individually.
    const char* decoded_text = DecodeLabel(labels, pos, &next, nullptr);
    for (int i = pos; i < next; ++i) {
      const char* name =
          (i == pos) ? decoded_text : DecodeSingleLabel(labels[i]);
      DebugActivationRange(outputs, name, labels[i], xcoords[i],
                           xcoords[i + 1]);
    }
  }
}

// Prints, for each position in [x_start, x_end), the activation of
// best_choice as a percentage together with the strongest other label and its
// activation, followed by the mean and max activation of best_choice over
// the range.
void LSTMRecognizer::DebugActivationRange(const NetworkIO& outputs,
                                          const char* label, int best_choice,
                                          int x_start, int x_end) {
  tprintf("%s=%d On [%d, %d), scores=", label, best_choice, x_start, x_end);
  const int num_positions = x_end - x_start;
  double peak = 0.0;
  double average = 0.0;
  for (int x = x_start; x < x_end; ++x) {
    const float* activations = outputs.f(x);
    const double chosen_score = activations[best_choice] * 100.0;
    if (chosen_score > peak) peak = chosen_score;
    average += chosen_score / num_positions;
    // Find the strongest competitor to best_choice at this position.
    int rival = 0;
    double rival_score = 0.0;
    for (int c = 0; c < outputs.NumFeatures(); ++c) {
      if (c == best_choice) continue;
      if (activations[c] > rival_score) {
        rival = c;
        rival_score = activations[c];
      }
    }
    tprintf(" %.3g(%s=%d=%.3g)", chosen_score, DecodeSingleLabel(rival), rival,
            rival_score * 100.0);
  }
  tprintf(", Mean=%g, max=%g\n", average, peak);
}

// Helper (currently compiled out) that decides whether null_char should be
// used at time t. Returns true if the null_char output at t is at least
// null_thr. Otherwise, returns false unless the best label other than
// null_char is space, in which case the result is whether the null_char
// output beats the space output.
#if 0  // TODO: unused, remove if still unused after 2020.
static bool NullIsBest(const NetworkIO& output, float null_thr,
                       int null_char, int t) {
  if (output.f(t)[null_char] >= null_thr) return true;
  if (output.BestLabel(t, null_char, null_char, nullptr) != UNICHAR_SPACE)
    return false;
  return output.f(t)[null_char] > output.f(t)[UNICHAR_SPACE];
}
#endif

// Converts the network output to a sequence of labels. Outputs labels and the
// start xcoord of each, with an additional final xcoord for the end of the
// output.
// The conversion method is determined by internal state: the simple text
// model is used if SimpleTextOutput() is true, otherwise the beam search.
void LSTMRecognizer::LabelsFromOutputs(const NetworkIO& outputs,
                                       GenericVector<int>* labels,
                                       GenericVector<int>* xcoords) {
  if (!SimpleTextOutput()) {
    LabelsViaReEncode(outputs, labels, xcoords);
    return;
  }
  LabelsViaSimpleText(outputs, labels, xcoords);
}

// Converts the network output to labels and xcoords by running the beam
// search and extracting its best path, which contains only legal sequences of
// subcodes for CJK. The decode uses a dict ratio of 1.0, a certainty offset
// of 0.0, RecodeBeamSearch::kMinCertainty as the worst dict certainty and no
// unicharset.
void LSTMRecognizer::LabelsViaReEncode(const NetworkIO& output,
                                       GenericVector<int>* labels,
                                       GenericVector<int>* xcoords) {
  // The beam search is created on first use and then kept.
  if (!search_) {
    search_ =
        new RecodeBeamSearch(recoder_, null_char_, SimpleTextOutput(), dict_);
  }
  const double dict_ratio = 1.0;
  const double cert_offset = 0.0;
  search_->Decode(output, dict_ratio, cert_offset,
                  RecodeBeamSearch::kMinCertainty, nullptr);
  search_->ExtractBestPathAsLabels(labels, xcoords);
}

// Converts the network output to a sequence of labels using the simple
// character model (each position is a char, and the null_char_ is mainly
// intended for tail padding.) Every timestep whose best label is not
// null_char_ yields one label with its timestep as xcoord; the output width
// is appended as the final xcoord.
void LSTMRecognizer::LabelsViaSimpleText(const NetworkIO& output,
                                         GenericVector<int>* labels,
                                         GenericVector<int>* xcoords) {
  labels->truncate(0);
  xcoords->truncate(0);
  const int num_timesteps = output.Width();
  for (int t = 0; t < num_timesteps; ++t) {
    float best_score = 0.0f;
    const int best = output.BestLabel(t, &best_score);
    if (best == null_char_) continue;
    labels->push_back(best);
    xcoords->push_back(t);
  }
  xcoords->push_back(num_timesteps);
}

// Returns a string corresponding to the label starting at start. Sets *end
// to the next start and if non-null, *decoded to the unichar id.
// Special return values: "<null>" for null_char_, " " for space, and
// "<Undecodable>" when recoding and no valid code sequence is found, in which
// case *end is start + 1 and *decoded is not written.
// When not recoding, a single label is always consumed and *decoded is the
// label itself.
const char* LSTMRecognizer::DecodeLabel(const GenericVector<int>& labels,
                                        int start, int* end, int* decoded) {
  // Default to consuming a single label; only a successful multi-label decode
  // below moves *end further.
  *end = start + 1;
  if (IsRecoding()) {
    // Decode labels via recoder_.
    RecodedCharID code;
    if (labels[start] == null_char_) {
      if (decoded != nullptr) {
        code.Set(0, null_char_);
        *decoded = recoder_.DecodeUnichar(code);
      }
      return "<null>";
    }
    int index = start;
    // Extend the code one label at a time, up to the maximum code length.
    while (index < labels.size() &&
           code.length() < RecodedCharID::kMaxCodeLen) {
      code.Set(code.length(), labels[index++]);
      // Nulls between the parts of a code are skipped.
      while (index < labels.size() && labels[index] == null_char_) ++index;
      int uni_id = recoder_.DecodeUnichar(code);
      // If the next label isn't a valid first code, then we need to continue
      // extending even if we have a valid uni_id from this prefix.
      if (uni_id != INVALID_UNICHAR_ID &&
          (index == labels.size() ||
           code.length() == RecodedCharID::kMaxCodeLen ||
           recoder_.IsValidFirstCode(labels[index]))) {
        *end = index;
        if (decoded != nullptr) *decoded = uni_id;
        if (uni_id == UNICHAR_SPACE) return " ";
        return GetUnicharset().get_normed_unichar(uni_id);
      }
    }
    return "<Undecodable>";
  } else {
    // Without recoding, each label is a unichar id in its own right.
    if (decoded != nullptr) *decoded = labels[start];
    if (labels[start] == null_char_) return "<null>";
    if (labels[start] == UNICHAR_SPACE) return " ";
    return GetUnicharset().get_normed_unichar(labels[start]);
  }
}

// Returns a string corresponding to a given single label id, falling back to
// a default of ".." for part of a multi-label unichar-id.
const char* LSTMRecognizer::DecodeSingleLabel(int label) {
  if (label == null_char_) return "<null>";
  if (IsRecoding()) {
    // Decode label via recoder_.
    RecodedCharID code;
    code.Set(0, label);
    label = recoder_.DecodeUnichar(code);
    if (label == INVALID_UNICHAR_ID) return "..";  // Part of a bigger code.
  }
  if (label == UNICHAR_SPACE) return " ";
  return GetUnicharset().get_normed_unichar(label);
}

}  // namespace tesseract.
