/**
 * Pedersen Hash over babyjubjub elliptic curve, defined in
 * {@link https://eips.ethereum.org/EIPS/eip-2494 | EIP-2494}.
 * jubjub     - twisted Edwards curve over the BLS12-381 scalar field
 * babyjubjub - twisted Edwards curve over the BN254 scalar field
 * Using another curve's scalar field as the base field allows these curves to be used inside zk-circuits.
 * @module
 */
import { type EdwardsPoint as ExtPointType } from '@noble/curves/abstract/edwards.js';
import { asciiToBytes } from '@noble/curves/utils.js';
import { babyjubjub } from '@noble/curves/misc.js';
import { blake256 } from '@noble/hashes/blake1.js';
import type { TArg, TRet } from '@noble/hashes/utils.js';

// Base field of babyjubjub; all coordinate arithmetic in this module goes through it.
const Fp = babyjubjub.Point.Fp;
// BigInt constants; `@__PURE__` marks the calls as side-effect free so bundlers can drop them.
const _0n = /* @__PURE__ */ BigInt(0);
const _1n = /* @__PURE__ */ BigInt(1);

// Concrete babyjubjub point type. Note: the generic `EdwardsPoint` interface from
// '@noble/curves/abstract/edwards.js' is imported under the alias `ExtPointType` to avoid a clash.
type EdwardsPoint = typeof babyjubjub.Point.BASE;
// Shape of the exported `Point` codec: `encode` packs a point into bytes, `decode` unpacks it.
type PointCodec = {
  encode: (p: any) => Uint8Array;
  decode: (bytes: Uint8Array) => ExtPointType;
};

// Similar to twistedEdwards fromBytes/toBytes, but x is flagged by 'x > Fp.ORDER >> 1n' instead of by its parity.
// NOTE: this must match the original implementation exactly, otherwise hashes will change!
/**
 * Pedersen point encoder/decoder for babyjubjub points.
 *
 * Encoding: the affine `y` coordinate serialized with `Fp.toBytes`. The top bit (0x80) of byte 31
 * is set when `x` lies in the upper half of the field (`x > Fp.ORDER >> 1`). This replaces the
 * parity-of-x flag used by the usual twisted Edwards encoding.
 *
 * Decoding recovers `x` from the curve equation. It first normalizes the square root to the lower
 * of the two roots (`x <= Fp.ORDER >> 1`) and then negates it if the flag is set. Because of this
 * normalization, the result does not depend on which root `Fp.sqrt` happens to return. The rule
 * must stay exactly as is: generator derivation, and therefore every hash, depends on it.
 * @example
 * Encode a babyjubjub point to bytes and decode it back.
 * ```ts
 * const { babyjubjub } = await import('@noble/curves/misc.js');
 * const point = babyjubjub.Point.BASE;
 * const encoded = Point.encode(point);
 * Point.decode(encoded);
 * ```
 */
export const Point: TRet<PointCodec> = Object.freeze({
  encode: (p: any): TRet<Uint8Array> => {
    const affine = p.toAffine();
    const out = Fp.toBytes(affine.y);
    // Flag the most significant bit of the last byte when x is in the upper half of the field.
    const isUpperHalf = affine.x > Fp.ORDER >> _1n;
    if (isUpperHalf) out[31] |= 0x80;
    return out as TRet<Uint8Array>;
  },
  decode: (bytes: TArg<Uint8Array>): TRet<ExtPointType> => {
    const wantUpper = !!(bytes[31] & 0x80);
    // Work on a copy: callers may pass a Buffer whose slices alias memory, so never clear the
    // flag bit in the caller's array.
    const buf = Uint8Array.from(bytes);
    buf[31] &= 0x7f; // drop the x flag, leaving only y
    const y = Fp.fromBytes(buf);
    const yy = Fp.sqr(y);
    const { a, d } = babyjubjub.Point.CURVE();
    // From a*x^2 + y^2 = 1 + d*x^2*y^2:  x^2 = (1 - y^2) / (a - d*y^2)
    const numerator = Fp.sub(Fp.ONE, yy);
    const denominator = Fp.sub(a, Fp.mul(d, yy));
    const root = Fp.sqrt(Fp.div(numerator, denominator));
    const half = Fp.ORDER >> _1n;
    // Canonical (lower) root first, then apply the encoded flag.
    const lower = root > half ? Fp.neg(root) : root;
    const x = wantUpper ? Fp.neg(lower) : lower;
    return babyjubjub.Point.fromAffine({ x, y }) as TRet<ExtPointType>;
  },
}) as unknown as TRet<PointCodec>;

// Generators are derived lazily and memoized by index. A fixed precomputed table is not possible
// because the number of generators needed grows with the (unbounded) message length.
const POINT_CACHE: EdwardsPoint[] = [];
/**
 * Returns the `idx`-th Pedersen generator, deriving it on first request and caching it.
 *
 * Derivation uses try-and-increment. For attempt = 0, 1, 2, ... it hashes
 * `PedersenGenerator_<idx zero-padded to 32>_<attempt zero-padded to 32>` with blake256 and clears
 * bit 0x40 of byte 31. It then tries to decode the digest with `Point.decode`; the first attempt
 * that decodes wins. The point is multiplied by the cofactor and checked with `assertValidity()`.
 * The exact strings and bit manipulation must not change, otherwise every hash changes.
 * @param idx - Generator (segment) index, starting at 0.
 * @returns The cached or freshly derived generator point.
 */
function basePoint(idx: number): EdwardsPoint {
  // Check the slot itself instead of `idx < POINT_CACHE.length`. If indices were ever requested
  // out of order, the length check would hand back an unfilled hole (`undefined`).
  const cached = POINT_CACHE[idx];
  if (cached !== undefined) return cached;
  const idxPart = ('' + idx).padStart(32, '0');
  let p = undefined;
  for (let i = 0; !p; i++) {
    const s = `PedersenGenerator_${idxPart}_${('' + i).padStart(32, '0')}`;
    const h = blake256(asciiToBytes(s));
    // Clear bit 0x40 of byte 31 (the "255th bit"). Bit 0x80 is kept because Point.decode reads
    // it as the x flag.
    h[31] &= 0b1011_1111;
    try {
      p = Point.decode(h);
    } catch {
      // Deliberate: this digest does not decode to a curve point, so try the next attempt.
    }
  }
  p = p.clearCofactor();
  p.assertValidity();
  POINT_CACHE[idx] = p;
  return p;
}

/**
 * Splits `msg` into 25-byte chunks and turns each chunk into one scalar.
 *
 * Each 4-bit nibble (low nibble of a byte first) becomes a signed digit. The low 3 bits give a
 * magnitude in 1..8 and bit 3 makes it negative, so no digit is ever zero. Digits are weighted by
 * successive powers of 32.
 *
 * A 25-byte chunk yields 50 digits, so |scalar| <= 8 * (32^50 - 1) / 31 < 2^249, which keeps it
 * below the subgroup order. Negative results are lifted into range by adding the subgroup order.
 * The lowest digit is nonzero modulo 32, so a chunk's scalar is never zero.
 * @param msg - Message bytes.
 * @returns One non-negative scalar per chunk (an empty array for an empty message).
 */
function getScalars(msg: TArg<Uint8Array>) {
  // noble-curves exposes BabyJubJub's subgroup order directly as Fn.ORDER.
  const SUBORDER = babyjubjub.Point.Fn.ORDER;
  const CHUNK_BYTES = 25;
  const DIGIT_SHIFT = BigInt(5); // digits are base-32
  // Maps a 4-bit window to a signed digit in {-8..-1, 1..8}; bit 3 is the sign.
  const toDigit = (nibble: number): bigint => {
    const magnitude = BigInt((nibble & 0b0111) + 1);
    return nibble & 0b1000 ? -magnitude : magnitude;
  };
  const scalars: bigint[] = [];
  for (let start = 0; start < msg.length; start += CHUNK_BYTES) {
    let acc = _0n;
    let weight = _1n;
    for (const byte of msg.subarray(start, start + CHUNK_BYTES)) {
      // Signed digits require multiplication by the weight; bit-or would not work.
      acc += toDigit(byte & 0xf) * weight;
      weight <<= DIGIT_SHIFT;
      acc += toDigit((byte >>> 4) & 0xf) * weight;
      weight <<= DIGIT_SHIFT;
    }
    scalars.push(acc < _0n ? SUBORDER + acc : acc);
  }
  return scalars;
}

/**
 * Computes the Pedersen hash for the input bytes.
 *
 * The message is cut into 25-byte segments, and each segment becomes a scalar. Segment `j`'s
 * scalar multiplies the `j`-th generator, and all products are summed starting from the identity
 * point. The sum is returned in the packed `Point.encode` format. An empty message hashes to the
 * encoding of the identity point.
 * @param msg - Message bytes to hash.
 * @returns Encoded babyjubjub point bytes.
 * @example
 * Hash message bytes into an encoded babyjubjub point.
 * ```ts
 * const digest = pedersenHash(new Uint8Array([1, 2, 3]));
 * ```
 */
export function pedersenHash(msg: TArg<Uint8Array>): TRet<Uint8Array> {
  const scalars = getScalars(msg);
  let sum = babyjubjub.Point.ZERO;
  for (const [segment, scalar] of scalars.entries()) {
    sum = sum.add(basePoint(segment).multiply(scalar));
  }
  return Point.encode(sum) as TRet<Uint8Array>;
}
