/**********************************************************************
 *
 * Name:     cpl_hash_set.cpp
 * Project:  CPL - Common Portability Library
 * Purpose:  Hash set functions.
 * Author:   Even Rouault, <even dot rouault at spatialys.com>
 *
 **********************************************************************
 * Copyright (c) 2008-2009, Even Rouault <even dot rouault at spatialys.com>
 *
 * SPDX-License-Identifier: MIT
 ****************************************************************************/

#include "cpl_hash_set.h"

#include <cstring>

#include "cpl_conv.h"
#include "cpl_error.h"
#include "cpl_list.h"

/* Internal state of a hash set: a table of singly linked bucket chains. */
struct _CPLHashSet
{
    /* Hash callback; its result modulo nAllocatedSize selects the bucket. */
    CPLHashSetHashFunc fnHashFunc;
    /* Equality callback used to compare a stored element with a key. */
    CPLHashSetEqualFunc fnEqualFunc;
    /* Optional element destructor; may be NULL, in which case elements */
    /* are never freed by the set. */
    CPLHashSetFreeEltFunc fnFreeEltFunc;
    /* Array of nAllocatedSize chain heads (NULL for an empty bucket). */
    CPLList **tabList;
    /* Number of elements currently stored. */
    int nSize;
    /* Index into anPrimes of the bucket count the table should have. */
    /* After a deferred removal it can be lower than the index matching */
    /* nAllocatedSize until the next rehash. */
    int nIndiceAllocatedSize;
    /* Number of entries currently allocated in tabList. */
    int nAllocatedSize;
    /* Chain of spare list nodes kept for reuse by later insertions. */
    CPLList *psRecyclingList;
    /* Number of nodes in psRecyclingList (kept below 128 by */
    /* CPLHashSetReturnListElt). */
    int nRecyclingListSize;
    /* Set when a removal postponed a shrink of the table; cleared by */
    /* CPLHashSetRehash() and CPLHashSetClearInternal(). */
    bool bRehash;
#ifdef HASH_DEBUG
    /* Debug-only count of insertions that landed in a non-empty bucket. */
    int nCollisions;
#endif
};

/* Successive bucket counts for the table, each roughly twice the previous */
/* one. CPLHashSetRehash() picks the entry at set->nIndiceAllocatedSize; the */
/* first entry (53) is the size used for a new or cleared set. */
constexpr int anPrimes[] = {
    53,        97,        193,       389,       769,       1543,     3079,
    6151,      12289,     24593,     49157,     98317,     196613,   393241,
    786433,    1572869,   3145739,   6291469,   12582917,  25165843, 50331653,
    100663319, 201326611, 402653189, 805306457, 1610612741};

/************************************************************************/
/*                          CPLHashSetNew()                             */
/************************************************************************/

/**
 * Creates a new hash set
 *
 * The hash function must return a hash value for the elements to insert.
 * If fnHashFunc is NULL, CPLHashSetHashPointer will be used.
 *
 * The equal function must return whether two elements are equal.
 * If fnEqualFunc is NULL, CPLHashSetEqualPointer will be used.
 *
 * The free function is used to free elements inserted in the hash set,
 * when the hash set is destroyed, when elements are removed or replaced.
 * If fnFreeEltFunc is NULL, elements inserted into the hash set will not be
 * freed.
 *
 * @param fnHashFunc hash function. May be NULL.
 * @param fnEqualFunc equal function. May be NULL.
 * @param fnFreeEltFunc element free function. May be NULL.
 *
 * @return a new hash set
 */

CPLHashSet *CPLHashSetNew(CPLHashSetHashFunc fnHashFunc,
                          CPLHashSetEqualFunc fnEqualFunc,
                          CPLHashSetFreeEltFunc fnFreeEltFunc)
{
    // Bucket count of a freshly created set.
    constexpr int nInitialBuckets = 53;

    void *pRawSet = CPLMalloc(sizeof(CPLHashSet));
    CPLHashSet *psNewSet = static_cast<CPLHashSet *>(pRawSet);

    // Callbacks: fall back to pointer identity when none is supplied.
    psNewSet->fnHashFunc = fnHashFunc;
    if (!psNewSet->fnHashFunc)
        psNewSet->fnHashFunc = CPLHashSetHashPointer;
    psNewSet->fnEqualFunc = fnEqualFunc;
    if (!psNewSet->fnEqualFunc)
        psNewSet->fnEqualFunc = CPLHashSetEqualPointer;
    psNewSet->fnFreeEltFunc = fnFreeEltFunc;

    // Empty, zero-initialized bucket table.
    psNewSet->tabList = static_cast<CPLList **>(
        CPLCalloc(sizeof(CPLList *), nInitialBuckets));
    psNewSet->nAllocatedSize = nInitialBuckets;
    psNewSet->nIndiceAllocatedSize = 0;
    psNewSet->nSize = 0;

    // No spare nodes and no pending rehash yet.
    psNewSet->psRecyclingList = nullptr;
    psNewSet->nRecyclingListSize = 0;
    psNewSet->bRehash = false;
#ifdef HASH_DEBUG
    psNewSet->nCollisions = 0;
#endif
    return psNewSet;
}

/************************************************************************/
/*                          CPLHashSetSize()                            */
/************************************************************************/

/**
 * Returns the number of elements inserted in the hash set
 *
 * Note: this is not the internal size of the hash set
 *
 * @param set the hash set
 *
 * @return the number of elements in the hash set
 */

int CPLHashSetSize(const CPLHashSet *set)
{
    CPLAssert(set != nullptr);
    // Element count, not the number of buckets.
    const int nEltCount = set->nSize;
    return nEltCount;
}

/************************************************************************/
/*                       CPLHashSetGetNewListElt()                      */
/************************************************************************/

// Returns a list node for a new element: a spare one popped from the
// recycling list when available (with pData reset), otherwise a freshly
// allocated, uninitialized one.
static CPLList *CPLHashSetGetNewListElt(CPLHashSet *set)
{
    CPLList *psSpare = set->psRecyclingList;
    if (psSpare == nullptr)
        return static_cast<CPLList *>(CPLMalloc(sizeof(CPLList)));

    // Pop the head of the recycling list.
    set->psRecyclingList = psSpare->psNext;
    set->nRecyclingListSize--;
    psSpare->pData = nullptr;
    return psSpare;
}

/************************************************************************/
/*                       CPLHashSetReturnListElt()                      */
/************************************************************************/

// Disposes of a list node that is no longer in a bucket: it is pushed onto
// the recycling list while that list holds fewer than 128 nodes, and freed
// otherwise.
static void CPLHashSetReturnListElt(CPLHashSet *set, CPLList *psList)
{
    if (set->nRecyclingListSize >= 128)
    {
        CPLFree(psList);
        return;
    }

    psList->psNext = set->psRecyclingList;
    set->psRecyclingList = psList;
    set->nRecyclingListSize++;
}

/************************************************************************/
/*                   CPLHashSetClearInternal()                          */
/************************************************************************/

// Empties every bucket, calling the element free function (if any) on each
// stored element. Nodes are freed when bFinalize is true and handed to the
// recycling list otherwise. The bucket table itself and nSize are left for
// the caller to deal with.
static void CPLHashSetClearInternal(CPLHashSet *set, bool bFinalize)
{
    CPLAssert(set != nullptr);
    for (int iBucket = 0; iBucket < set->nAllocatedSize; ++iBucket)
    {
        CPLList *psNode = set->tabList[iBucket];
        while (psNode != nullptr)
        {
            if (set->fnFreeEltFunc)
                set->fnFreeEltFunc(psNode->pData);

            // Remember the successor before the node is released.
            CPLList *const psFollowing = psNode->psNext;
            if (!bFinalize)
                CPLHashSetReturnListElt(set, psNode);
            else
                CPLFree(psNode);
            psNode = psFollowing;
        }
        set->tabList[iBucket] = nullptr;
    }
    set->bRehash = false;
}

/************************************************************************/
/*                        CPLHashSetDestroy()                           */
/************************************************************************/

/**
 * Destroys an allocated hash set.
 *
 * This function also frees the elements if a free function was
 * provided at the creation of the hash set.
 *
 * @param set the hash set. May be NULL, in which case nothing is done.
 */

void CPLHashSetDestroy(CPLHashSet *set)
{
    // Behave like CPLFree(): destroying a NULL set is a no-op rather than a
    // NULL dereference.
    if (set == nullptr)
        return;

    // Free the elements and the nodes still linked in the buckets...
    CPLHashSetClearInternal(set, true);
    // ... then the bucket table, the spare nodes and the set itself.
    CPLFree(set->tabList);
    CPLListDestroy(set->psRecyclingList);
    CPLFree(set);
}

/************************************************************************/
/*                        CPLHashSetClear()                             */
/************************************************************************/

/**
 * Clear all elements from a hash set.
 *
 * This function also frees the elements if a free function was
 * provided at the creation of the hash set.
 *
 * @param set the hash set
 */

void CPLHashSetClear(CPLHashSet *set)
{
    // Bucket count an emptied set goes back to.
    constexpr int nInitialBuckets = 53;

    // Release the elements; nodes go to the recycling list.
    CPLHashSetClearInternal(set, false);

    // Shrink the (now all-NULL) bucket table back to its initial size.
    void *pShrunkTable =
        CPLRealloc(set->tabList, sizeof(CPLList *) * nInitialBuckets);
    set->tabList = static_cast<CPLList **>(pShrunkTable);
    set->nAllocatedSize = nInitialBuckets;
    set->nIndiceAllocatedSize = 0;
    set->nSize = 0;
#ifdef HASH_DEBUG
    set->nCollisions = 0;
#endif
}

/************************************************************************/
/*                       CPLHashSetForeach()                            */
/************************************************************************/

/**
 * Walk through the hash set and runs the provided function on all the
 * elements
 *
 * This function is provided the user_data argument of CPLHashSetForeach.
 * It must return TRUE to go on the walk through the hash set, or FALSE to
 * make it stop.
 *
 * Note : the structure of the hash set must *NOT* be modified during the
 * walk.
 *
 * @param set the hash set.
 * @param fnIterFunc the function called on each element.
 * @param user_data the user data provided to the function.
 */

void CPLHashSetForeach(CPLHashSet *set, CPLHashSetIterEltFunc fnIterFunc,
                       void *user_data)
{
    CPLAssert(set != nullptr);
    if (fnIterFunc == nullptr)
        return;

    // Visit bucket after bucket, each chain from head to tail, until the
    // callback asks to stop.
    for (int iBucket = 0; iBucket < set->nAllocatedSize; ++iBucket)
    {
        for (CPLList *psNode = set->tabList[iBucket]; psNode != nullptr;
             psNode = psNode->psNext)
        {
            const bool bContinue = fnIterFunc(psNode->pData, user_data) != 0;
            if (!bContinue)
                return;
        }
    }
}

/************************************************************************/
/*                        CPLHashSetRehash()                            */
/************************************************************************/

/**
 * Moves every node of the set into a new bucket table whose size is
 * anPrimes[set->nIndiceAllocatedSize].
 *
 * Callers adjust set->nIndiceAllocatedSize before calling. Existing list
 * nodes are relinked into the new table (no node is allocated or freed), the
 * old table is freed, and set->bRehash is cleared.
 *
 * If the requested index is past the end of anPrimes, i.e. the set was asked
 * to grow beyond the largest supported table, the index is clamped to the
 * last entry instead of reading out of bounds. When the table already has
 * that size it is left untouched and chains are simply allowed to get longer.
 */
static void CPLHashSetRehash(CPLHashSet *set)
{
    constexpr int nPrimesCount =
        static_cast<int>(sizeof(anPrimes) / sizeof(anPrimes[0]));
    if (set->nIndiceAllocatedSize >= nPrimesCount)
    {
        set->nIndiceAllocatedSize = nPrimesCount - 1;
        if (set->nAllocatedSize == anPrimes[nPrimesCount - 1])
        {
            // Already at the maximum size: nothing to rebuild.
            set->bRehash = false;
            return;
        }
    }

    int nNewAllocatedSize = anPrimes[set->nIndiceAllocatedSize];
    CPLList **newTabList = static_cast<CPLList **>(
        CPLCalloc(sizeof(CPLList *), nNewAllocatedSize));
#ifdef HASH_DEBUG
    CPLDebug("CPLHASH",
             "hashSet=%p, nSize=%d, nCollisions=%d, "
             "fCollisionRate=%.02f",
             set, set->nSize, set->nCollisions,
             set->nCollisions * 100.0 / set->nSize);
    set->nCollisions = 0;
#endif
    for (int i = 0; i < set->nAllocatedSize; i++)
    {
        CPLList *cur = set->tabList[i];
        while (cur)
        {
            // Bucket of this element in the new table.
            const unsigned long nNewHashVal =
                set->fnHashFunc(cur->pData) % nNewAllocatedSize;
#ifdef HASH_DEBUG
            if (newTabList[nNewHashVal])
                set->nCollisions++;
#endif
            // Unlink from the old chain and push at the head of the new one.
            CPLList *psNext = cur->psNext;
            cur->psNext = newTabList[nNewHashVal];
            newTabList[nNewHashVal] = cur;
            cur = psNext;
        }
    }
    CPLFree(set->tabList);
    set->tabList = newTabList;
    set->nAllocatedSize = nNewAllocatedSize;
    set->bRehash = false;
}

/************************************************************************/
/*                        CPLHashSetFindPtr()                           */
/************************************************************************/

// Looks for an element equal to elt (according to the set's equality
// callback) in the bucket selected by the set's hash callback. Returns the
// address of the stored element pointer, or NULL when there is no match.
static void **CPLHashSetFindPtr(CPLHashSet *set, const void *elt)
{
    const unsigned long nBucket = set->fnHashFunc(elt) % set->nAllocatedSize;
    for (CPLList *psNode = set->tabList[nBucket]; psNode != nullptr;
         psNode = psNode->psNext)
    {
        if (set->fnEqualFunc(psNode->pData, elt))
            return &psNode->pData;
    }
    return nullptr;
}

/************************************************************************/
/*                         CPLHashSetInsert()                           */
/************************************************************************/

/**
 * Inserts an element into a hash set.
 *
 * If the element was already inserted in the hash set, the previous
 * element is replaced by the new element. If a free function was provided,
 * it is used to free the previously inserted element
 *
 * @param set the hash set
 * @param elt the new element to insert in the hash set
 *
 * @return TRUE if the element was not already in the hash set
 */

int CPLHashSetInsert(CPLHashSet *set, void *elt)
{
    CPLAssert(set != nullptr);

    // An equal element is already stored: replace it in place, freeing the
    // previous one when a free function was provided.
    void **pElt = CPLHashSetFindPtr(set, elt);
    if (pElt)
    {
        if (set->fnFreeEltFunc)
            set->fnFreeEltFunc(*pElt);

        *pElt = elt;
        return FALSE;
    }

    // Grow once the table is two-thirds full. The threshold is computed in
    // 64-bit arithmetic because 2 * nAllocatedSize does not fit in an int for
    // bucket counts above INT_MAX / 2 (the last entry of the size table is
    // such a value), which would be signed overflow.
    const long long nGrowThreshold =
        2 * static_cast<long long>(set->nAllocatedSize) / 3;
    const bool bTooFull = set->nSize >= nGrowThreshold;
    // A removal postponed a resize and the table is still at most half full.
    const bool bPendingRehash = set->bRehash &&
                                set->nIndiceAllocatedSize > 0 &&
                                set->nSize <= set->nAllocatedSize / 2;
    if (bTooFull || bPendingRehash)
    {
        set->nIndiceAllocatedSize++;
        CPLHashSetRehash(set);
    }

    // The bucket must be computed after the potential rehash.
    const unsigned long nHashVal = set->fnHashFunc(elt) % set->nAllocatedSize;
#ifdef HASH_DEBUG
    if (set->tabList[nHashVal])
        set->nCollisions++;
#endif

    // Link a node holding the element at the head of its bucket chain.
    CPLList *new_elt = CPLHashSetGetNewListElt(set);
    new_elt->pData = elt;
    new_elt->psNext = set->tabList[nHashVal];
    set->tabList[nHashVal] = new_elt;
    set->nSize++;

    return TRUE;
}

/************************************************************************/
/*                        CPLHashSetLookup()                            */
/************************************************************************/

/**
 * Returns the element found in the hash set corresponding to the element to
 * look up. The returned element must not be modified.
 *
 * @param set the hash set
 * @param elt the element to look up in the hash set
 *
 * @return the element found in the hash set or NULL
 */

void *CPLHashSetLookup(CPLHashSet *set, const void *elt)
{
    CPLAssert(set != nullptr);
    // Slot holding the matching element, if there is one.
    void **ppMatch = CPLHashSetFindPtr(set, elt);
    return ppMatch != nullptr ? *ppMatch : nullptr;
}

/************************************************************************/
/*                     CPLHashSetRemoveInternal()                       */
/************************************************************************/

// Removes the element equal to elt, if any, and returns whether one was
// removed. Before searching, a table that is at most half full is scheduled
// for shrinking: immediately, or only flagged (set->bRehash) when
// bDeferRehash is true.
static bool CPLHashSetRemoveInternal(CPLHashSet *set, const void *elt,
                                     bool bDeferRehash)
{
    CPLAssert(set != nullptr);

    const bool bCanShrink = set->nIndiceAllocatedSize > 0 &&
                            set->nSize <= set->nAllocatedSize / 2;
    if (bCanShrink)
    {
        set->nIndiceAllocatedSize--;
        if (!bDeferRehash)
            CPLHashSetRehash(set);
        else
            set->bRehash = true;
    }

    // The bucket is computed after the potential rehash above.
    const int nBucket =
        static_cast<int>(set->fnHashFunc(elt) % set->nAllocatedSize);

    // Walk the chain through the address of the link pointing at the current
    // node, so that unlinking is the same for the head and for inner nodes.
    CPLList **ppsLink = &set->tabList[nBucket];
    while (*ppsLink != nullptr)
    {
        CPLList *psNode = *ppsLink;
        if (!set->fnEqualFunc(psNode->pData, elt))
        {
            ppsLink = &psNode->psNext;
            continue;
        }

        // Unlink, release the element, then recycle the node.
        *ppsLink = psNode->psNext;
        if (set->fnFreeEltFunc)
            set->fnFreeEltFunc(psNode->pData);
        CPLHashSetReturnListElt(set, psNode);
#ifdef HASH_DEBUG
        if (set->tabList[nBucket])
            set->nCollisions--;
#endif
        set->nSize--;
        return true;
    }
    return false;
}

/************************************************************************/
/*                         CPLHashSetRemove()                           */
/************************************************************************/

/**
 * Removes an element from a hash set
 *
 * @param set the hash set
 * @param elt the element to remove from the hash set
 *
 * @return TRUE if the element was in the hash set
 */

int CPLHashSetRemove(CPLHashSet *set, const void *elt)
{
    // Immediate (non-deferred) rehash when the table becomes sparse.
    const bool bWasPresent = CPLHashSetRemoveInternal(set, elt, false);
    return bWasPresent;
}

/************************************************************************/
/*                     CPLHashSetRemoveDeferRehash()                    */
/************************************************************************/

/**
 * Removes an element from a hash set.
 *
 * This will defer potential rehashing of the set to later calls to
 * CPLHashSetInsert() or CPLHashSetRemove().
 *
 * @param set the hash set
 * @param elt the element to remove from the hash set
 *
 * @return TRUE if the element was in the hash set
 */

int CPLHashSetRemoveDeferRehash(CPLHashSet *set, const void *elt)
{
    // Same removal as CPLHashSetRemove(), but a needed rehash is only flagged.
    const bool bWasPresent = CPLHashSetRemoveInternal(set, elt, true);
    return bWasPresent;
}

/************************************************************************/
/*                    CPLHashSetHashPointer()                           */
/************************************************************************/

/**
 * Hash function for an arbitrary pointer
 *
 * @param elt the arbitrary pointer to hash
 *
 * @return the hash value of the pointer
 */

unsigned long CPLHashSetHashPointer(const void *elt)
{
    // The hash is simply the address converted to an integer.
    void *pAddress = const_cast<void *>(elt);
    const GUIntptr_t nAddress = reinterpret_cast<GUIntptr_t>(pAddress);
    return static_cast<unsigned long>(nAddress);
}

/************************************************************************/
/*                   CPLHashSetEqualPointer()                           */
/************************************************************************/

/**
 * Equality function for arbitrary pointers
 *
 * @param elt1 the first arbitrary pointer to compare
 * @param elt2 the second arbitrary pointer to compare
 *
 * @return TRUE if the pointers are equal
 */

int CPLHashSetEqualPointer(const void *elt1, const void *elt2)
{
    // Identity comparison: the pointed-to contents are never inspected.
    const bool bSameAddress = (elt1 == elt2);
    return bSameAddress ? 1 : 0;
}

/************************************************************************/
/*                        CPLHashSetHashStr()                           */
/************************************************************************/

/**
 * Hash function for a zero-terminated string
 *
 * @param elt the string to hash. May be NULL.
 *
 * @return the hash value of the string
 */

CPL_NOSANITIZE_UNSIGNED_INT_OVERFLOW
unsigned long CPLHashSetHashStr(const void *elt)
{
    unsigned long nHash = 0;
    if (elt != nullptr)
    {
        // Fold in one byte at a time, read as unsigned char, up to the
        // terminating zero. Wrap-around of the unsigned accumulator is
        // intended.
        for (const unsigned char *pabyCur =
                 static_cast<const unsigned char *>(elt);
             *pabyCur != '\0'; ++pabyCur)
        {
            const int nByte = *pabyCur;
            nHash = nByte + (nHash << 6) + (nHash << 16) - nHash;
        }
    }
    return nHash;
}

/************************************************************************/
/*                     CPLHashSetEqualStr()                             */
/************************************************************************/

/**
 * Equality function for strings
 *
 * @param elt1 the first string to compare. May be NULL.
 * @param elt2 the second string to compare. May be NULL.
 *
 * @return TRUE if the strings are equal
 */

int CPLHashSetEqualStr(const void *elt1, const void *elt2)
{
    const char *pszLeft = static_cast<const char *>(elt1);
    const char *pszRight = static_cast<const char *>(elt2);

    // With at least one NULL operand, the strings are equal only when both
    // are NULL.
    if (pszLeft == nullptr || pszRight == nullptr)
        return (pszLeft == pszRight) ? TRUE : FALSE;

    return strcmp(pszLeft, pszRight) == 0;
}
