import { TensorScriptModelInterface, TensorScriptOptions, TensorScriptProperties, Matrix, Vector, PredictionOptions, InputTextArray, } from './model_interface';
import * as UniversalSentenceEncoder from '@tensorflow-models/universal-sentence-encoder';


// Module-level cache shared by every TextEmbedding instance. TextEmbedding#train()
// only calls UniversalSentenceEncoder.load()/loadTokenizer() while these are still
// unset, so later train() calls (from any instance) reuse the already-loaded objects.
let model:UniversalSentenceEncoder.UniversalSentenceEncoder;
let tokenizer:UniversalSentenceEncoder.Tokenizer;
/**
 * Text Embedding with Tensorflow Universal Sentence Encoder (USE)
 * @class TextEmbedding
 * @implements {TensorScriptModelInterface}
 */
export class TextEmbedding extends TensorScriptModelInterface {
  /**
   * @param {Object} options - Options for USE (shallow-copied before being handed to the base class)
   * @param {{model:Object,tf:Object,}} properties - extra instance properties
   */
  constructor(options:TensorScriptOptions = {}, properties?:TensorScriptProperties) {
    const config = Object.assign({
    }, options);
    super(config, properties);
    this.type = 'TextEmbedding';
  }
  /**
   * Asynchronously loads the Universal Sentence Encoder model and its tokenizer.
   * Both are cached at module level, so only the first successful call actually
   * loads them; subsequent calls (from any instance) reuse the cached objects.
   * @override
   * @return {Promise<Object>} resolves to the loaded UniversalSentenceEncoder model
   */
  async train() {
    // Promise.all accepts plain values too, so cached objects pass straight through.
    const [ loadedModel, loadedTokenizer, ] = await Promise.all([
      model ? model : UniversalSentenceEncoder.load(),
      tokenizer ? tokenizer : UniversalSentenceEncoder.loadTokenizer(),
    ]);
    // Re-check after awaiting: if another train() call populated the cache first, keep that one.
    if (!model) model = loadedModel;
    if (!tokenizer) tokenizer = loadedTokenizer;
    this.model = model;
    this.tokenizer = tokenizer;
    this.trained = true;
    this.compiled = true;

    return this.model;
  }
  /**
   * Calculates sentence embeddings.
   * @override
   * @param {Array<string>} input_array - sentences to embed
   * @param {Object} options - currently unused; kept for interface compatibility
   * @return {Object} the result of `this.model.embed(input_array)` (awaited by predict); the
   *   caller owns the resulting tensor and is responsible for disposing it
   * @throws {Error} if input_array is not an array, or no model has been loaded yet
   */
  calculate(input_array:InputTextArray, options = {}) {
    if (!input_array || Array.isArray(input_array) === false) throw new Error('invalid input array of sentences');
    // Fail with a clear message instead of "cannot read property 'embed' of undefined".
    if (!this.model) throw new Error('model is not loaded, call train() before calculating embeddings');
    return this.model.embed(input_array);
  }
  /**
   * Returns sentence embeddings computed by the tensorflow model.
   * @param {Array<string>} input_array - array of sentences to embed
   * @param {Boolean} [options.json=true] - return a nested array (one row per sentence) instead of the flat typed array
   * @param {Boolean} [options.probability=true] - return real values instead of rounding each value to an integer
   * @return {Promise<Array<Array<number>>|Array<number>>} embedding rows, or the flat typed array when json is false
   */
  async predict(input_array:InputTextArray, options:PredictionOptions = {}): Promise<Matrix|Vector> {
    const config = Object.assign({
      json: true,
      probability: true,
    }, options);
    const embeddings = await this.calculate(input_array, options);
    try {
      const predictions:number[] = await embeddings.data();
      if (config.json === false) {
        return predictions;
      }
      // Use the tensor's own width rather than assuming USE's 512 dimensions.
      const shape = [ input_array.length, embeddings.shape[ 1 ], ];
      const predictionValues = (config.probability === false)
        ? Array.from(predictions).map(Math.round)
        : Array.from(predictions);
      return this.reshape(predictionValues, shape);
    } finally {
      // tfjs tensors are not garbage collected; release the embedding tensor once its values are read.
      embeddings.dispose();
    }
  }
}