/*
 * Copyright 2012-2014 Luke Dashjr
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the standard MIT license.  See COPYING for more details.
 */

#ifndef WIN32
#include <arpa/inet.h>
#else
#include <winsock2.h>
#endif

#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include <sys/types.h>
#include "sha2.h"

#include "libbase58.h"
#include "macros.h"
#include "ripemd160.h"

/*
 * Pluggable SHA-256 hook for the checksum helpers in this file.
 * Signature: hash (data, datasz) into digest (a 32-byte buffer is passed by
 * the callers here) and return true on success.
 * Starts out NULL; nothing in this file assigns it, so presumably the
 * embedding application installs it. TODO(review): confirm.
 */
bool (*b58_sha256_impl)(void *digest, const void *data, size_t datasz) = NULL;

static const int8_t b58digits_map[] = {
    -1,-1,-1,-1,-1,-1,-1,-1, -1,-1,-1,-1,-1,-1,-1,-1,
    -1,-1,-1,-1,-1,-1,-1,-1, -1,-1,-1,-1,-1,-1,-1,-1,
    -1,-1,-1,-1,-1,-1,-1,-1, -1,-1,-1,-1,-1,-1,-1,-1,
    -1, 0, 1, 2, 3, 4, 5, 6,  7, 8,-1,-1,-1,-1,-1,-1,
    -1, 9,10,11,12,13,14,15, 16,-1,17,18,19,20,21,-1,
    22,23,24,25,26,27,28,29, 30,31,32,-1,-1,-1,-1,-1,
    -1,33,34,35,36,37,38,39, 40,41,42,43,-1,44,45,46,
    47,48,49,50,51,52,53,54, 55,56,57,-1,-1,-1,-1,-1,
};

typedef uint64_t b58_maxint_t;
typedef uint32_t b58_almostmaxint_t;
#define b58_almostmaxint_bits (sizeof(b58_almostmaxint_t) * 8)
static const b58_almostmaxint_t b58_almostmaxint_mask = ((((b58_maxint_t)1) << b58_almostmaxint_bits) - 1);

/*
 * Decode the base58 string b58 into the caller's buffer bin.
 *
 * bin / *binszp: output buffer and its size in bytes. The decoded number is
 *   written big-endian and right-aligned across all *binszp bytes, so its
 *   canonical bytes start at bin + (size on entry) - (size on return).
 * b58 / b58sz: input string and its length; b58sz == 0 means b58 is
 *   NUL-terminated and strlen() is used.
 *
 * On success returns true and sets *binszp to the canonical decoded length:
 * the number's bytes without leading zeros, plus one zero byte per leading
 * '1' in b58 (always <= the size on entry).
 * Returns false, leaving *binszp untouched, for an invalid digit or when the
 * result does not fit in the buffer (bin may have been partially written).
 */
bool se_b58tobin(void *bin, size_t *binszp, const char *b58, size_t b58sz)
{
    size_t binsz = *binszp;
    const unsigned char *b58u = (const void *)b58;
    unsigned char *binu = bin;
    size_t outisz = (binsz + sizeof(b58_almostmaxint_t) - 1) / sizeof(b58_almostmaxint_t);
    /* Big-endian limbs (outi[0] most significant). Always allocate at least
     * one limb: a zero-length VLA is undefined, and the overflow check below
     * reads outi[0] even when binsz == 0. */
    b58_almostmaxint_t outi[outisz ? outisz : 1];
    b58_maxint_t t;
    b58_almostmaxint_t c;
    size_t i, j, leading;
    uint8_t bytesleft = binsz % sizeof(b58_almostmaxint_t);
    /* Bits of outi[0] that lie beyond the buffer when binsz is not a multiple
     * of the limb size; any of them becoming set means overflow. */
    b58_almostmaxint_t zeromask = bytesleft ? (b58_almostmaxint_mask << (bytesleft * 8)) : 0;
    size_t zerocount = 0;

    if (!b58sz)
        b58sz = strlen(b58);

    memset(outi, 0, sizeof(outi));

    // Leading '1's stand for leading zero bytes; just count them.
    for (i = 0; i < b58sz && b58u[i] == '1'; ++i)
        ++zerocount;

    for ( ; i < b58sz; ++i)
    {
        if (b58u[i] & 0x80)
            // High-bit set on invalid digit
            return false;
        if (b58digits_map[b58u[i]] == -1)
            // Invalid base58 digit
            return false;
        c = (unsigned)b58digits_map[b58u[i]];
        // outi = outi * 58 + digit, carrying from the low limb upwards.
        for (j = outisz; j--; )
        {
            t = ((b58_maxint_t)outi[j]) * 58 + c;
            c = t >> b58_almostmaxint_bits;
            outi[j] = t & b58_almostmaxint_mask;
        }
        if (c)
            // Output number too big (carry to the next int32)
            return false;
        if (outi[0] & zeromask)
            // Output number too big (last int32 filled too far)
            return false;
    }

    // Serialise the limbs big-endian; outi[0] contributes only bytesleft
    // bytes when binsz is not a multiple of the limb size.
    j = 0;
    if (bytesleft) {
        for (i = bytesleft; i > 0; --i) {
            *(binu++) = (outi[0] >> (8 * (i - 1))) & 0xff;
        }
        ++j;
    }

    for (; j < outisz; ++j)
    {
        for (i = sizeof(*outi); i > 0; --i) {
            *(binu++) = (outi[j] >> (8 * (i - 1))) & 0xff;
        }
    }

    // Count the zero bytes in front of the decoded number.
    binu = bin;
    for (leading = 0; leading < binsz && !binu[leading]; ++leading)
        ;

    // Each leading '1' needs a zero byte of its own. With more of them than
    // zero bytes available, the canonical length would exceed binsz and a
    // caller computing bin + binsz - *binszp would point before bin.
    if (zerocount > leading)
        return false;

    *binszp = binsz - leading + zerocount;
    return true;
}

static
bool my_dblsha256(void *hash, const void *data, size_t datasz)
{
    uint8_t buf[0x20];
    return b58_sha256_impl(buf, data, datasz) && b58_sha256_impl(hash, buf, sizeof(buf));
}

/*
 * Verify a decoded Base58Check buffer against the string it came from.
 *
 * bin / binsz: decoded bytes, i.e. payload followed by a 4-byte checksum
 *   (the first 4 bytes of the double SHA-256 of the payload).
 * base58str / b58sz: the source string; b58sz == 0 means base58str is
 *   NUL-terminated and strlen() is used.
 *
 * Returns bin[0] (the version byte for data built by se_b58check_enc) on
 * success, or a negative code:
 *   -4  binsz < 4, no room for a checksum
 *   -2  the hash function failed
 *   -1  checksum mismatch
 *   -3  the number of leading zero bytes in bin differs from the number of
 *       leading '1' characters in base58str
 */
int se_b58check(const void *bin, size_t binsz, const char *base58str, size_t b58sz)
{
    unsigned char buf[32];
    const uint8_t *binc = bin;
    size_t i;

    if (binsz < 4)
        return -4;
    if (!my_dblsha256(buf, bin, binsz - 4))
        return -2;
    if (memcmp(&binc[binsz - 4], buf, 4))
        return -1;

    if (!b58sz)
        b58sz = strlen(base58str);

    // Walk the zero bytes and '1' digits in lockstep, staying inside both
    // buffers so an all-zero bin cannot be read past binsz.
    for (i = 0; i < binsz && i < b58sz && binc[i] == '\0' && base58str[i] == '1'; ++i)
        ;
    // Either side still showing a zero / '1' means the counts differ.
    if ((i < binsz && binc[i] == '\0') || (i < b58sz && base58str[i] == '1'))
        return -3;

    return binc[0];
}

static const char b58digits_ordered[] = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz";

/*
 * Encode binsz bytes at data as a NUL-terminated base58 string in b58.
 *
 * On entry *b58sz is the capacity of b58 in bytes. On success the string is
 * written, *b58sz is set to the number of bytes written including the
 * terminating NUL, and true is returned. If the capacity is too small,
 * nothing is written, *b58sz is set to the required capacity (NUL included)
 * and false is returned, so a call with *b58sz == 0 queries the size.
 *
 * Each leading zero byte of the input becomes one leading '1'.
 */
bool se_b58enc(char *b58, size_t *b58sz, const void *data, size_t binsz)
{
    const uint8_t *bin = data;
    size_t i, j, high, zcount = 0;
    size_t size;

    while (zcount < binsz && !bin[zcount])
        ++zcount;

    // log(256) / log(58) ~= 1.366 < 1.38, so this many base58 digits always
    // hold the non-zero part of the input.
    size = (binsz - zcount) * 138 / 100 + 1;
    uint8_t buf[size];
    memset(buf, 0, size);

    // Big-number conversion: for each input byte, buf = buf * 256 + byte,
    // with buf[] holding base-58 digits, most significant first. j indexes
    // one past the digit being updated (unsigned, never wraps), and
    // buf[0..high) is known to still be zero, so once the carry is spent
    // the inner loop can stop at high.
    for (i = zcount, high = size; i < binsz; ++i, high = j)
    {
        int carry = bin[i];

        for (j = size; j > 0 && (j > high || carry); --j)
        {
            carry += 256 * buf[j - 1];
            buf[j - 1] = carry % 58;
            carry /= 58;
        }
    }

    // Skip the unused leading digits of the estimate.
    for (j = 0; j < size && !buf[j]; ++j)
        ;

    if (*b58sz <= zcount + size - j)
    {
        *b58sz = zcount + size - j + 1;
        return false;
    }

    if (zcount)
        memset(b58, '1', zcount);
    for (i = zcount; j < size; ++i, ++j)
        b58[i] = b58digits_ordered[buf[j]];
    b58[i] = '\0';
    *b58sz = i + 1;

    return true;
}

/*
 * Base58Check-encode datasz bytes of data under version byte ver.
 *
 * The bytes encoded are: ver, the datasz payload bytes, then the first 4
 * bytes of the double SHA-256 of (ver || payload). Output conventions for
 * b58c / *b58c_sz are those of se_b58enc(). If hashing fails, *b58c_sz is
 * set to 0 and false is returned.
 */
bool se_b58check_enc(char *b58c, size_t *b58c_sz, uint8_t ver, const void *data, size_t datasz)
{
    const size_t prefixed_len = 1 + datasz;   /* version byte + payload */
    uint8_t scratch[prefixed_len + 0x20];     /* room for the full 32-byte digest */
    uint8_t *digest = scratch + prefixed_len;

    scratch[0] = ver;
    memcpy(scratch + 1, data, datasz);

    if (my_dblsha256(digest, scratch, prefixed_len))
        return se_b58enc(b58c, b58c_sz, scratch, prefixed_len + 4);

    *b58c_sz = 0;
    return false;
}

/*
 * Base58Check-encode datalen bytes of data into str.
 *
 * data is encoded as given (no version byte is added here), followed by the
 * first 4 bytes of SHA256(SHA256(data)). strsize is the capacity of str.
 *
 * Returns the number of bytes written to str including the terminating NUL,
 * or 0 if datalen is outside [0, 128], strsize is negative, or str is too
 * small.
 */
int base58_encode_check(const uint8_t *data, int datalen, char *str, int strsize)
{
    // A negative datalen would size the VLA below negatively (UB), and a
    // negative strsize would become a huge size_t capacity for se_b58enc.
    if (datalen < 0 || datalen > 128 || strsize < 0) {
        return 0;
    }
    uint8_t buf[datalen + 32];
    uint8_t *hash = buf + datalen;
    memcpy(buf, data, datalen);
    sha256_Raw(data, datalen, hash);
    sha256_Raw(hash, 32, hash);
    size_t res = strsize;
    // Only the first 4 digest bytes are encoded, as the checksum.
    bool success = se_b58enc(str, &res, buf, datalen + 4);
    MEMSET_BZERO(buf, sizeof(buf));
    return success ? (int)res : 0;
}

/*
 * Decode the NUL-terminated Base58Check string str into data.
 *
 * data has room for datalen bytes. Everything except the trailing 4-byte
 * checksum (including any leading version byte) is copied to data.
 * Returns the number of bytes copied, or 0 on any error: datalen outside
 * [0, 128], invalid base58, decoded value too large for datalen + 4 bytes,
 * checksum mismatch, or mismatched leading zeros (see se_b58check).
 */
int base58_decode_check(const char *str, uint8_t *data, int datalen)
{
    // A negative datalen would give the VLA below a non-positive size (UB).
    if (datalen < 0 || datalen > 128) {
        return 0;
    }
    uint8_t d[datalen + 4];
    size_t res = sizeof(d);
    const uint8_t *nd;
    int ret = 0;

    if (!se_b58tobin(d, &res, str, 0)) {
        goto out;
    }
    // The decoded number is right-aligned in d. A canonical length larger
    // than d (too many leading '1's) would place nd before the buffer.
    if (res > sizeof(d)) {
        goto out;
    }
    nd = d + sizeof(d) - res;
    // se_b58check returns -4 when res < 4, so res - 4 below cannot wrap.
    if (se_b58check(nd, res, str, 0) < 0) {
        goto out;
    }
    memcpy(data, nd, res - 4);
    ret = (int)(res - 4);

out:
    // d holds the decoded payload, which may be secret; scrub it the same
    // way base58_encode_check() scrubs its working buffer.
    MEMSET_BZERO(d, sizeof(d));
    return ret;
}
