/*
 * librdkafka - The Apache Kafka C/C++ library
 *
 * Copyright (c) 2020-2022, Magnus Edenhill
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *    this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#include "rd.h"
#include "rdsysqueue.h"
#include "rdstring.h"
#include "rdmap.h"


static RD_INLINE int rd_map_elem_cmp(const rd_map_elem_t *a,
                                     const rd_map_elem_t *b,
                                     const rd_map_t *rmap) {
        int r = a->hash - b->hash;
        if (r != 0)
                return r;
        return rmap->rmap_cmp(a->key, b->key);
}

static void rd_map_elem_destroy(rd_map_t *rmap, rd_map_elem_t *elem) {
        rd_assert(rmap->rmap_cnt > 0);
        rmap->rmap_cnt--;
        if (rmap->rmap_destroy_key)
                rmap->rmap_destroy_key((void *)elem->key);
        if (rmap->rmap_destroy_value)
                rmap->rmap_destroy_value((void *)elem->value);
        LIST_REMOVE(elem, hlink);
        LIST_REMOVE(elem, link);
        rd_free(elem);
}

static rd_map_elem_t *
rd_map_find(const rd_map_t *rmap, int *bktp, const rd_map_elem_t *skel) {
        int bkt = skel->hash % rmap->rmap_buckets.cnt;
        rd_map_elem_t *elem;

        if (bktp)
                *bktp = bkt;

        LIST_FOREACH(elem, &rmap->rmap_buckets.p[bkt], hlink) {
                if (!rd_map_elem_cmp(skel, elem, rmap))
                        return elem;
        }

        return NULL;
}


/**
 * @brief Create and return new element based on \p skel without value set.
 */
static rd_map_elem_t *
rd_map_insert(rd_map_t *rmap, int bkt, const rd_map_elem_t *skel) {
        rd_map_elem_t *elem;

        elem       = rd_calloc(1, sizeof(*elem));
        elem->hash = skel->hash;
        elem->key  = skel->key; /* takes ownership of key */
        LIST_INSERT_HEAD(&rmap->rmap_buckets.p[bkt], elem, hlink);
        LIST_INSERT_HEAD(&rmap->rmap_iter, elem, link);
        rmap->rmap_cnt++;

        return elem;
}


rd_map_elem_t *rd_map_set(rd_map_t *rmap, void *key, void *value) {
        rd_map_elem_t skel = {.key = key, .hash = rmap->rmap_hash(key)};
        rd_map_elem_t *elem;
        int bkt;

        if (!(elem = rd_map_find(rmap, &bkt, &skel))) {
                elem = rd_map_insert(rmap, bkt, &skel);
        } else {
                if (elem->value && rmap->rmap_destroy_value)
                        rmap->rmap_destroy_value((void *)elem->value);
                if (rmap->rmap_destroy_key)
                        rmap->rmap_destroy_key(key);
        }

        elem->value = value; /* takes ownership of value */

        return elem;
}


void *rd_map_get(const rd_map_t *rmap, const void *key) {
        const rd_map_elem_t skel = {.key  = (void *)key,
                                    .hash = rmap->rmap_hash(key)};
        rd_map_elem_t *elem;

        if (!(elem = rd_map_find(rmap, NULL, &skel)))
                return NULL;

        return (void *)elem->value;
}


void rd_map_delete(rd_map_t *rmap, const void *key) {
        const rd_map_elem_t skel = {.key  = (void *)key,
                                    .hash = rmap->rmap_hash(key)};
        rd_map_elem_t *elem;
        int bkt;

        if (!(elem = rd_map_find(rmap, &bkt, &skel)))
                return;

        rd_map_elem_destroy(rmap, elem);
}


void rd_map_copy(rd_map_t *dst,
                 const rd_map_t *src,
                 rd_map_copy_t *key_copy,
                 rd_map_copy_t *value_copy) {
        const rd_map_elem_t *elem;

        RD_MAP_FOREACH_ELEM(elem, src) {
                rd_map_set(
                    dst, key_copy ? key_copy(elem->key) : (void *)elem->key,
                    value_copy ? value_copy(elem->value) : (void *)elem->value);
        }
}


void rd_map_iter_begin(const rd_map_t *rmap, const rd_map_elem_t **elem) {
        *elem = LIST_FIRST(&rmap->rmap_iter);
}

size_t rd_map_cnt(const rd_map_t *rmap) {
        return (size_t)rmap->rmap_cnt;
}

rd_bool_t rd_map_is_empty(const rd_map_t *rmap) {
        return rmap->rmap_cnt == 0;
}


/**
 * @brief Calculates the number of desired buckets and returns
 *        a struct with pre-allocated buckets.
 */
struct rd_map_buckets rd_map_alloc_buckets(size_t expected_cnt) {
        static const int max_depth      = 15;
        static const int bucket_sizes[] = {
            5,     11,    23,     47,     97,   199, /* default */
            409,   823,   1741,   3469,   6949, 14033,
            28411, 57557, 116731, 236897, -1};
        struct rd_map_buckets buckets = RD_ZERO_INIT;
        int i;

        if (!expected_cnt) {
                buckets.cnt = 199;
        } else {
                /* Strive for an average (at expected element count) depth
                 * of 15 elements per bucket, but limit the maximum
                 * bucket count to the maximum value in bucket_sizes above.
                 * When a real need arise we'll change this to a dynamically
                 * growing hash map instead, but this will do for now. */
                buckets.cnt = bucket_sizes[0];
                for (i = 1; bucket_sizes[i] != -1 &&
                            (int)expected_cnt / max_depth > bucket_sizes[i];
                     i++)
                        buckets.cnt = bucket_sizes[i];
        }

        rd_assert(buckets.cnt > 0);

        buckets.p = rd_calloc(buckets.cnt, sizeof(*buckets.p));

        return buckets;
}


void rd_map_init(rd_map_t *rmap,
                 size_t expected_cnt,
                 int (*cmp)(const void *a, const void *b),
                 unsigned int (*hash)(const void *key),
                 void (*destroy_key)(void *key),
                 void (*destroy_value)(void *value)) {

        memset(rmap, 0, sizeof(*rmap));
        rmap->rmap_buckets       = rd_map_alloc_buckets(expected_cnt);
        rmap->rmap_cmp           = cmp;
        rmap->rmap_hash          = hash;
        rmap->rmap_destroy_key   = destroy_key;
        rmap->rmap_destroy_value = destroy_value;
}

void rd_map_clear(rd_map_t *rmap) {
        rd_map_elem_t *elem;

        while ((elem = LIST_FIRST(&rmap->rmap_iter)))
                rd_map_elem_destroy(rmap, elem);
}

void rd_map_destroy(rd_map_t *rmap) {
        rd_map_clear(rmap);
        rd_free(rmap->rmap_buckets.p);
}


int rd_map_str_cmp(const void *a, const void *b) {
        return strcmp((const char *)a, (const char *)b);
}

/**
 * @brief A djb2 string hasher.
 */
unsigned int rd_map_str_hash(const void *key) {
        const char *str = key;
        return rd_string_hash(str, -1);
}



/**
 * @name Unit tests
 *
 */
#include "rdtime.h"
#include "rdunittest.h"
#include "rdcrc32.h"


/**
 * Typed hash maps
 */

/* Complex key type */
struct mykey {
        int k;
        int something_else; /* Ignored by comparator and hasher below */
};

/* Key comparator */
static int mykey_cmp(const void *_a, const void *_b) {
        const struct mykey *a = _a, *b = _b;
        return a->k - b->k;
}

/* Key hasher */
static unsigned int mykey_hash(const void *_key) {
        const struct mykey *key = _key;
        return (unsigned int)key->k;
}

/* Complex value type */
struct person {
        char *name;
        char *surname;
};

/* Define typed hash map type */
typedef RD_MAP_TYPE(const struct mykey *,
                    const struct person *) ut_my_typed_map_t;


/**
 * @brief Test typed hash map with pre-defined type.
 */
static int unittest_typed_map(void) {
        ut_my_typed_map_t rmap =
            RD_MAP_INITIALIZER(0, mykey_cmp, mykey_hash, NULL, NULL);
        ut_my_typed_map_t dup =
            RD_MAP_INITIALIZER(0, mykey_cmp, mykey_hash, NULL, NULL);
        struct mykey k1  = {1};
        struct mykey k2  = {2};
        struct person v1 = {"Roy", "McPhearsome"};
        struct person v2 = {"Hedvig", "Lindahl"};
        const struct mykey *key;
        const struct person *value;

        RD_MAP_SET(&rmap, &k1, &v1);
        RD_MAP_SET(&rmap, &k2, &v2);

        value = RD_MAP_GET(&rmap, &k2);
        RD_UT_ASSERT(value == &v2, "mismatch");

        RD_MAP_FOREACH(key, value, &rmap) {
                RD_UT_SAY("enumerated key %d person %s %s", key->k, value->name,
                          value->surname);
        }

        RD_MAP_COPY(&dup, &rmap, NULL, NULL);

        RD_MAP_DELETE(&rmap, &k1);
        value = RD_MAP_GET(&rmap, &k1);
        RD_UT_ASSERT(value == NULL, "expected no k1");

        value = RD_MAP_GET(&dup, &k1);
        RD_UT_ASSERT(value == &v1, "copied map: k1 mismatch");
        value = RD_MAP_GET(&dup, &k2);
        RD_UT_ASSERT(value == &v2, "copied map: k2 mismatch");

        RD_MAP_DESTROY(&rmap);
        RD_MAP_DESTROY(&dup);

        RD_UT_PASS();
}


static int person_cmp(const void *_a, const void *_b) {
        const struct person *a = _a, *b = _b;
        int r;
        if ((r = strcmp(a->name, b->name)))
                return r;
        return strcmp(a->surname, b->surname);
}
static unsigned int person_hash(const void *_key) {
        const struct person *key = _key;
        return 31 * rd_map_str_hash(key->name) + rd_map_str_hash(key->surname);
}

/**
 * @brief Test typed hash map with locally defined type.
 */
static int unittest_typed_map2(void) {
        RD_MAP_LOCAL_INITIALIZER(usermap, 3, const char *,
                                 const struct person *, rd_map_str_cmp,
                                 rd_map_str_hash, NULL, NULL);
        RD_MAP_LOCAL_INITIALIZER(personmap, 3, const struct person *,
                                 const char *, person_cmp, person_hash, NULL,
                                 NULL);
        struct person p1 = {"Magnus", "Lundstrom"};
        struct person p2 = {"Peppy", "Popperpappies"};
        const char *user;
        const struct person *person;

        /* Populate user -> person map */
        RD_MAP_SET(&usermap, "user1234", &p1);
        RD_MAP_SET(&usermap, "user9999999999", &p2);

        person = RD_MAP_GET(&usermap, "user1234");


        RD_UT_ASSERT(person == &p1, "mismatch");

        RD_MAP_FOREACH(user, person, &usermap) {
                /* Populate reverse name -> user map */
                RD_MAP_SET(&personmap, person, user);
        }

        RD_MAP_FOREACH(person, user, &personmap) {
                /* Just reference the memory to catch memory errors.*/
                RD_UT_ASSERT(strlen(person->name) > 0 &&
                                 strlen(person->surname) > 0 &&
                                 strlen(user) > 0,
                             "bug");
        }

        RD_MAP_DESTROY(&usermap);
        RD_MAP_DESTROY(&personmap);

        return 0;
}


/**
 * @brief Untyped hash map.
 *
 * This is a more thorough test of the underlying hash map implementation.
 */
static int unittest_untyped_map(void) {
        rd_map_t rmap;
        int pass, i, r;
        int cnt     = 100000;
        int exp_cnt = 0, get_cnt = 0, iter_cnt = 0;
        const rd_map_elem_t *elem;
        rd_ts_t ts     = rd_clock();
        rd_ts_t ts_get = 0;

        rd_map_init(&rmap, cnt, rd_map_str_cmp, rd_map_str_hash, rd_free,
                    rd_free);

        /* pass 0 is set,delete,overwrite,get
         * pass 1-5 is get */
        for (pass = 0; pass < 6; pass++) {
                if (pass == 1)
                        ts_get = rd_clock();

                for (i = 1; i < cnt; i++) {
                        char key[10];
                        char val[64];
                        const char *val2;
                        rd_bool_t do_delete = !(i % 13);
                        rd_bool_t overwrite = !do_delete && !(i % 5);

                        rd_snprintf(key, sizeof(key), "key%d", i);
                        rd_snprintf(val, sizeof(val), "VALUE=%d!", i);

                        if (pass == 0) {
                                rd_map_set(&rmap, rd_strdup(key),
                                           rd_strdup(val));

                                if (do_delete)
                                        rd_map_delete(&rmap, key);
                        }

                        if (overwrite) {
                                rd_snprintf(val, sizeof(val), "OVERWRITE=%d!",
                                            i);
                                if (pass == 0)
                                        rd_map_set(&rmap, rd_strdup(key),
                                                   rd_strdup(val));
                        }

                        val2 = rd_map_get(&rmap, key);

                        if (do_delete)
                                RD_UT_ASSERT(!val2,
                                             "map_get pass %d "
                                             "returned value %s "
                                             "for deleted key %s",
                                             pass, val2, key);
                        else
                                RD_UT_ASSERT(val2 && !strcmp(val, val2),
                                             "map_get pass %d: "
                                             "expected value %s, not %s, "
                                             "for key %s",
                                             pass, val, val2 ? val2 : "NULL",
                                             key);

                        if (pass == 0 && !do_delete)
                                exp_cnt++;
                }

                if (pass >= 1)
                        get_cnt += cnt;
        }

        ts_get = rd_clock() - ts_get;
        RD_UT_SAY("%d map_get iterations took %.3fms = %" PRId64 "us/get",
                  get_cnt, (float)ts_get / 1000.0, ts_get / get_cnt);

        RD_MAP_FOREACH_ELEM(elem, &rmap) {
                iter_cnt++;
        }

        r = (int)rd_map_cnt(&rmap);
        RD_UT_ASSERT(r == exp_cnt, "expected %d map entries, not %d", exp_cnt,
                     r);

        RD_UT_ASSERT(r == iter_cnt,
                     "map_cnt() = %d, iteration gave %d elements", r, iter_cnt);

        rd_map_destroy(&rmap);

        ts = rd_clock() - ts;
        RD_UT_SAY("Total time over %d entries took %.3fms", cnt,
                  (float)ts / 1000.0);

        RD_UT_PASS();
}


int unittest_map(void) {
        int fails = 0;
        fails += unittest_untyped_map();
        fails += unittest_typed_map();
        fails += unittest_typed_map2();
        return 0;
}
