/*
 *     Copyright 2015 Couchbase, Inc.
 *
 *   Licensed under the Apache License, Version 2.0 (the "License");
 *   you may not use this file except in compliance with the License.
 *   You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 *   Unless required by applicable law or agreed to in writing, software
 *   distributed under the License is distributed on an "AS IS" BASIS,
 *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *   See the License for the specific language governing permissions and
 *   limitations under the License.
 */

#define LCB_NO_DEPR_CXX_CTORS

#include "config.h"
#include <sys/types.h>
#include <libcouchbase/couchbase.h>
#include <libcouchbase/api3.h>
#include <libcouchbase/n1ql.h>
#include <libcouchbase/vbucket.h>
#include <iostream>
#include <list>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cerrno>
#include <fstream>
#include <algorithm> // random_shuffle
#include <stdexcept>
#include <sstream>
#ifndef WIN32
#include <pthread.h>
#include <unistd.h> // isatty()
#endif
#include "common/options.h"

using namespace cbc;
using namespace cliopts;
using std::vector;
using std::string;
using std::cerr;
using std::endl;

/**
 * Throws std::runtime_error unless rc is LCB_SUCCESS.
 *
 * The exception text is "[<rc printed in hex>] " followed by the string
 * returned from lcb_strerror() for that code.
 */
static void do_or_die(lcb_error_t rc)
{
    if (rc == LCB_SUCCESS) {
        return;
    }

    std::stringstream msg;
    msg << "[" << std::hex << rc << "] ";
    msg << lcb_strerror(NULL, rc);
    throw std::runtime_error(msg.str());
}

/**
 * Row/query/error counters with an optional three-line live display.
 *
 * The update_*() methods do not take the lock themselves; callers that share
 * one instance between threads are expected to bracket them with
 * lock()/unlock().
 *
 * No destructor destroys the mutex: destroying a mutex that another thread
 * may still hold is undefined, and the object is meant to live for the whole
 * process.
 */
class Metrics {
public:
    Metrics()
    : n_rows(0), n_queries(0), n_errors(0), last_update(time(NULL))
    {
#ifndef _WIN32
        // Initialize the mutex explicitly instead of relying on zero-filled
        // storage happening to represent a valid, unlocked mutex.
        if (pthread_mutex_init(&m_lock, NULL) != 0) {
            fprintf(stderr, "Couldn't initialize metrics mutex\n");
            abort();
        }
#endif
    }

    /** Adds n to the row counter, then refreshes the display if one is due */
    void update_row(size_t n = 1) { n_rows += n; update_display(); }
    /** Adds n to the completed-query counter, then refreshes the display */
    void update_done(size_t n = 1) { n_queries += n; update_display(); }
    /** Adds n to the error counter, then refreshes the display */
    void update_error(size_t n = 1) { n_errors += n; update_display(); }

#ifndef _WIN32
    /** The display is written to stdout, so stdout is the stream to test */
    bool is_tty() const { return isatty(STDOUT_FILENO) != 0; }
    void lock() { pthread_mutex_lock(&m_lock); }
    void unlock() { pthread_mutex_unlock(&m_lock); }
#else
    void lock(){}
    void unlock(){}
    bool is_tty() const { return false; }
#endif

    /** Emits three newlines so that the first refresh has room to draw in */
    void prepare_screen()
    {
        if (is_tty()) {
            printf("\n\n\n");
        }
    }

private:
    /**
     * Redraws the display at most once per second (time() granularity) and
     * resets the per-interval row and query counters. The error counter is
     * cumulative and is never reset.
     */
    void update_display()
    {
        if (!is_tty()) {
            return;
        }

        const time_t now = time(NULL);
        if (now == last_update) {
            return;
        }

        // Whole seconds since the previous refresh. If the clock stepped
        // backwards, treat the interval as one second rather than dividing
        // by a negative value.
        const unsigned long elapsed = (now > last_update)
            ? static_cast<unsigned long>(now - last_update) : 1UL;
        last_update = now;

        // Move the cursor up two lines, back to the first line of the
        // display: the last line written ends in '\r', not '\n'.
        printf("\x1B[2A");
        printf("\x1B[KQUERIES/SEC: %lu\n",
               static_cast<unsigned long>(n_queries / elapsed));
        printf("\x1B[KROWS/SEC:    %lu\n",
               static_cast<unsigned long>(n_rows / elapsed));
        printf("\x1B[KERRORS:      %lu\r",
               static_cast<unsigned long>(n_errors));
        fflush(stdout);

        n_queries = 0;
        n_rows = 0;
    }

    size_t n_rows;
    size_t n_queries;
    size_t n_errors;
    time_t last_update;
#ifndef _WIN32
    pthread_mutex_t m_lock;
#endif
};

// The metrics instance that every ThreadContext reports into and that
// main() prepares the screen with.
Metrics GlobalMetrics;

class Configuration
{
public:
    Configuration() : o_file("queryfile"), o_threads("num-threads") {
        o_file.mandatory(true);
        o_file.description(
            "Path to a file containing all the queries to execute. "
            "Each line should contain the full query body");
        o_file.abbrev('f');

        o_threads.description("Number of threads to run");
        o_threads.abbrev('t');
        o_threads.setDefault(1);
    }

    void addToParser(Parser& parser)
    {
        parser.addOption(o_file);
        parser.addOption(o_threads);
        m_params.addToParser(parser);
    }

    void processOptions()
    {
        std::ifstream ifs(o_file.const_result().c_str());
        if (!ifs.is_open()) {
            int ec_save = errno;
            string errstr(o_file.const_result());
            errstr += ": ";
            errstr += strerror(ec_save);
            throw std::runtime_error(errstr);
        }

        string curline;
        while (std::getline(ifs, curline).good() && !ifs.eof()) {
            m_queries.push_back(curline);
        }
    }

    void set_cropts(lcb_create_st &opts) { m_params.fillCropts(opts); }
    const vector<string>& queries() const { return m_queries; }
    size_t nthreads() { return o_threads.result(); }

private:
    vector<string> m_queries;
    StringOption o_file;
    UIntOption o_threads;
    ConnParams m_params;
};

// Forward declarations, with C linkage, of the two trampolines that
// ThreadContext hands to C APIs: n1qlcb is installed as the N1QL command
// callback and pthrfunc is passed to pthread_create(). Both are defined
// after the class.
extern "C" { static void n1qlcb(lcb_t, int, const lcb_RESPN1QL *resp); }
extern "C" { static void* pthrfunc(void*); }

/**
 * One worker: issues every query of its (shuffled) list against a single
 * lcb_t, over and over, reporting into GlobalMetrics.
 *
 * The lcb_t is borrowed; this class never destroys it. Where pthreads are
 * available the work runs on a thread owned by this object, which is joined
 * by join() or, failing that, by the destructor.
 */
class ThreadContext {
public:
    /**
     * @param instance handle used for every query issued by this context
     * @param initial_queries query bodies; copied, and the copy is shuffled
     */
    ThreadContext(lcb_t instance, const vector<string>& initial_queries)
    : m_instance(instance), m_queries(initial_queries),
      last_nerr(0), last_nrow(0),
      m_metrics(&GlobalMetrics), m_cancelled(false), m_thr(NULL)
    {
        memset(&m_cmd, 0, sizeof m_cmd);
        m_cmd.content_type = "application/json";
        m_cmd.callback = n1qlcb;

        // Shuffle this context's private copy of the list
        std::random_shuffle(m_queries.begin(), m_queries.end());
    }

    /**
     * Executes the whole query list repeatedly until m_cancelled is set
     * (nothing visible in this file sets it). Returns immediately when
     * there are no queries, instead of spinning in an empty loop.
     */
    void run()
    {
        if (m_queries.empty()) {
            return;
        }

        while (!m_cancelled) {
            vector<string>::const_iterator ii = m_queries.begin();
            for (; ii != m_queries.end(); ++ii) {
                run_one_query(*ii);
            }
        }
    }

#ifndef _WIN32
    /**
     * Spawns the worker thread.
     * @throws std::runtime_error if pthread_create() fails
     */
    void start()
    {
        assert(m_thr == NULL);
        m_thr = new pthread_t;
        int rc = pthread_create(m_thr, NULL, pthrfunc, this);
        if (rc != 0) {
            // No thread exists, so drop the handle: otherwise the
            // destructor would join an uninitialized pthread_t.
            delete m_thr;
            m_thr = NULL;
            throw std::runtime_error(strerror(rc));
        }
    }

    /**
     * Waits for the worker thread and releases its handle. Does nothing if
     * there is no thread to wait for, so that an explicit join() followed
     * by destruction never joins the same thread twice.
     */
    void join()
    {
        if (m_thr == NULL) {
            return;
        }
        void *arg = NULL;
        pthread_join(*m_thr, &arg);
        delete m_thr;
        m_thr = NULL;
    }

    ~ThreadContext()
    {
        join();
    }
#else
    void start() { run(); }
    void join() {}
#endif

    /**
     * Per-response hook invoked through n1qlcb. The response flagged
     * LCB_RESP_F_FINAL ends the query and reports an error if its status
     * is not LCB_SUCCESS; every other response is counted as one row.
     */
    void handle_response(const lcb_RESPN1QL *resp)
    {
        if (resp->rflags & LCB_RESP_F_FINAL) {
            if (resp->rc != LCB_SUCCESS) {
                log_error(resp->rc);
            }
        } else {
            last_nrow++;
        }
    }

private:
    // Not copyable: a copy would share the thread handle and free it twice.
    // Declared but intentionally left undefined.
    ThreadContext(const ThreadContext&);
    ThreadContext& operator=(const ThreadContext&);

    /** Records one error in the shared metrics; the code itself is unused */
    void log_error(lcb_error_t)
    {
        m_metrics->lock();
        m_metrics->update_error();
        m_metrics->unlock();
    }

    /**
     * Schedules one query, waits for it to finish, then publishes the row
     * count gathered by handle_response(). A failure to schedule is
     * recorded as an error and nothing is waited for.
     */
    void run_one_query(const string& txt)
    {
        // Reset counters
        last_nrow = 0;
        last_nerr = 0;

        m_cmd.query = txt.c_str();
        m_cmd.nquery = txt.size();

        lcb_error_t rc = lcb_n1ql_query(m_instance, this, &m_cmd);
        if (rc != LCB_SUCCESS) {
            log_error(rc);
        } else {
            lcb_wait(m_instance);
            m_metrics->lock();
            m_metrics->update_row(last_nrow);
            m_metrics->update_done(1);
            m_metrics->unlock();
        }
    }

    lcb_t m_instance;
    vector<string> m_queries;
    size_t last_nerr;
    size_t last_nrow;
    lcb_CMDN1QL m_cmd;
    Metrics *m_metrics;
    volatile bool m_cancelled;
    #ifndef _WIN32
    pthread_t *m_thr;
    #else
    void *m_thr;
    #endif
};

/**
 * N1QL response callback. Treats the response cookie as a ThreadContext
 * pointer (run_one_query passes `this` when scheduling) and forwards the
 * response to that context.
 */
static void n1qlcb(lcb_t, int, const lcb_RESPN1QL *resp)
{
    ThreadContext *ctx = reinterpret_cast<ThreadContext*>(resp->cookie);
    ctx->handle_response(resp);
}

static void* pthrfunc(void *arg)
{
    reinterpret_cast<ThreadContext*>(arg)->run();
    return NULL;
}

/**
 * Returns true when lcbvb_get_randhost() yields a non-negative host index
 * for the N1QL service type in the instance's current cluster configuration.
 * The SSL service mode is used when the instance's SSL mode has
 * LCB_SSL_ENABLED set, the plain mode otherwise.
 *
 * @throws std::runtime_error (via do_or_die) if either lcb_cntl() call fails
 */
static bool
instance_has_n1ql(lcb_t instance) {
    lcbvb_CONFIG *config = NULL;
    do_or_die(lcb_cntl(instance, LCB_CNTL_GET, LCB_CNTL_VBCONFIG, &config));

    int ssl_flags = 0;
    do_or_die(lcb_cntl(instance, LCB_CNTL_GET, LCB_CNTL_SSL_MODE, &ssl_flags));

    const lcbvb_SVCMODE mode =
        (ssl_flags & LCB_SSL_ENABLED) ? LCBVB_SVCMODE_SSL : LCBVB_SVCMODE_PLAIN;

    const int host_index = lcbvb_get_randhost(config, LCBVB_SVCTYPE_N1QL, mode);
    return host_index >= 0;
}

/**
 * Parses the command line, connects one lcb_t per requested thread, checks
 * on the first instance that a N1QL service is available, then runs one
 * ThreadContext per instance until all of them finish.
 *
 * @throws std::runtime_error on any configuration, connection or thread
 *         creation failure. Nothing is cleaned up on that path; main()
 *         terminates the process.
 */
static void real_main(int argc, char **argv) {
    Configuration config;
    Parser parser;
    config.addToParser(parser);
    parser.parse(argc, argv);

    vector<ThreadContext*> threads;
    vector<lcb_t> instances;

    lcb_create_st cropts = { 0 };
    config.set_cropts(cropts);
    config.processOptions();

    const size_t nthreads = config.nthreads();
    for (size_t ii = 0; ii < nthreads; ii++) {
        lcb_t instance;
        do_or_die(lcb_create(&instance, &cropts));
        do_or_die(lcb_connect(instance));
        lcb_wait(instance);
        do_or_die(lcb_get_bootstrap_status(instance));

        // Every instance targets the same cluster, so checking one suffices
        if (ii == 0 && !instance_has_n1ql(instance)) {
            throw std::runtime_error("Cluster does not support N1QL!");
        }

        ThreadContext* cx = new ThreadContext(instance, config.queries());
        threads.push_back(cx);
        instances.push_back(instance);
    }

    for (size_t ii = 0; ii < threads.size(); ++ii) {
        threads[ii]->start();
    }

    // Deleting a context waits for its worker (the destructor joins) and
    // then frees it. join() is deliberately not called here as well, so a
    // thread is never joined twice.
    for (size_t ii = 0; ii < threads.size(); ++ii) {
        delete threads[ii];
    }
    threads.clear();

    // Only now that no worker is left can the instances be released
    for (size_t ii = 0; ii < instances.size(); ++ii) {
        lcb_destroy(instances[ii]);
    }
}

/**
 * Program entry point. Reserves the screen lines for the metrics display,
 * then runs real_main(). Any std::exception is reported on stderr and the
 * process exits with EXIT_FAILURE; otherwise the exit status is 0.
 */
int main(int argc, char **argv)
{
    GlobalMetrics.prepare_screen();

    try {
        real_main(argc, argv);
    } catch (const std::exception& failure) {
        cerr << failure.what() << endl;
        exit(EXIT_FAILURE);
    }

    return 0;
}
