/**
 * RFC 9497: Oblivious Pseudorandom Functions (OPRFs) Using Prime-Order Groups.
 * https://www.rfc-editor.org/rfc/rfc9497
 *

OPRF allows a client and a server to interactively compute `Output = PRF(Input, serverSecretKey)`:

- Server cannot calculate Output by itself: it doesn't know Input
- Client cannot calculate Output by itself: it doesn't know server secretKey
- An attacker intercepting the communication can't recover Input/Output/serverSecretKey and can't
  link Input to any value.

## Issues

- Low-entropy inputs (e.g. password '123') enable brute-force dictionary attacks by the server
  (solvable by domain separation in POPRF)
- High-level protocol needs to be constructed on top, because OPRF is low-level

## Use cases

1. **Password-Authenticated Key Exchange (PAKE):** Enables secure password login (e.g., OPAQUE)
   without revealing the password to the server.
2. **Private Set Intersection (PSI):** Allows two parties to compute the intersection of their
   private sets without revealing non-intersecting elements.
3. **Anonymous Credential Systems:** Supports issuance of anonymous, unlinkable credentials
   (e.g., Privacy Pass) using blind OPRF evaluation.
4. **Private Information Retrieval (PIR):** Helps users query databases without revealing which
   item they accessed.
5. **Encrypted Search / Secure Indexing:** Enables keyword search over encrypted data while keeping
   queries private.
6. **Spam Prevention and Rate-Limiting:** Issues anonymous tokens to prevent abuse
   (e.g., CAPTCHA bypass) without compromising user privacy.

## Modes

- OPRF: simple mode; the client doesn't need to know the server public key
- VOPRF: verifiable mode. It lets the client verify that the server used the
  secret key corresponding to a known public key
- POPRF: partially oblivious mode, VOPRF + domain separation

There is also a non-interactive mode (`evaluate`), which computes Output directly
for a party that knows both the input and the secret key.

Flow:
- (once) Server generates secret and public keys, distributes public keys to clients
  - deterministically: `deriveKeyPair` or just random: `generateKeyPair`
- Client blinds its input: `blind(secretInput)`, and sends the blinded element to the server
- Server evaluates the blinded element generated by the client: `blindEvaluate`, and sends the
  result back to the client
- Client computes the output from the evaluation result: `finalize`

 * @module
 */
/*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */
import {
  abytes,
  asciiToBytes,
  bytesToNumberBE,
  bytesToNumberLE,
  concatBytes,
  numberToBytesBE,
  randomBytes,
  validateObject,
  type TArg,
  type TRet,
} from '../utils.ts';
import { pippenger, validatePointCons, type CurvePoint, type CurvePointCons } from './curve.ts';
import { _DST_scalar, type H2CDSTOpts } from './hash-to-curve.ts';
import { getMinHashLength, mapHashToField } from './modular.ts';

// OPRF messages travel over the network, so the public API works with serialized values
// (byte arrays) rather than in-memory points and bigints.
/** Arbitrary byte string: protocol inputs, outputs, proofs and `info` values. */
export type Bytes = Uint8Array;
/** Encoded group element exchanged between client and server. */
export type PointBytes = Bytes;
/** Encoded scalar: either a client blind or a server secret key. */
export type ScalarBytes = Bytes;
/** Source of cryptographically secure random bytes, used for blinds and proof nonces. */
export type RNG = typeof randomBytes;

// Byte form of the hash-to-scalar DST prefix, encoded once at module load.
const _DST_scalarBytes = /* @__PURE__ */ asciiToBytes(_DST_scalar);

/**
 * Ciphersuite description consumed by `createOPRF`: a prime-order group plus the three
 * hash hooks RFC 9497 calls Hash, HashToScalar and HashToGroup.
 */
export type OPRFOpts<P extends CurvePoint<any, P>> = {
  /** Ciphersuite identifier (e.g. "P256-SHA256"); mixed into every context string. */
  name: string;
  /**
   * Constructor of the prime-order group.
   * Stays generic over the concrete point class: every public result is serialized anyway.
   */
  Point: CurvePointCons<P>;
  /**
   * Digest used for seeds, proof transcripts and final outputs.
   * @param msg - Bytes to digest.
   * @returns Digest bytes.
   */
  hash(msg: TArg<Bytes>): TRet<Bytes>;
  /**
   * Maps a byte string to a scalar of the group order.
   * @param msg - Bytes to map.
   * @param options - Hash-to-field domain-separation options. See {@link H2CDSTOpts}.
   * Implementations MUST NOT mutate `msg` or `options`.
   * @returns Scalar modulo the group order.
   */
  hashToScalar(msg: TArg<Uint8Array>, options: TArg<H2CDSTOpts>): bigint;
  /**
   * Maps a byte string straight to a group element.
   * @param msg - Bytes to map.
   * @param options - Hash-to-curve domain-separation options. See {@link H2CDSTOpts}.
   * Implementations MUST NOT mutate `msg` or `options`.
   * @returns Group element of the suite.
   */
  hashToGroup(msg: TArg<Uint8Array>, options: TArg<H2CDSTOpts>): P;
};

/** Serialized server key pair for one ciphersuite. */
export type OPRFKeys = {
  /** Server secret scalar `skS`; it never leaves the server. */
  secretKey: TRet<ScalarBytes>;
  /** `skS * G`; clients need it to verify proofs in the verifiable modes. */
  publicKey: TRet<PointBytes>;
};
/** Output of the client `blind` step. */
export type OPRFBlind = {
  /** Random blinding scalar; the client keeps it secret until `finalize`. */
  blind: TRet<ScalarBytes>;
  /** `blind * HashToGroup(input)`; this is what the client sends to the server. */
  blinded: TRet<PointBytes>;
};
/** Server reply to a single verifiable evaluation request. */
export type OPRFBlindEval = {
  /** Server's evaluation of the blinded element. */
  evaluated: TRet<PointBytes>;
  /** Serialized DLEQ proof tying `evaluated` to the server key. */
  proof: TRet<Bytes>;
};
/** Server reply to a batch of verifiable evaluation requests. */
export type OPRFBlindEvalBatch = {
  /** Evaluations, in the same order as the blinded inputs. */
  evaluated: TRet<PointBytes[]>;
  /** One DLEQ proof that covers the whole batch. */
  proof: TRet<Bytes>;
};
/** Per-input client state needed to finalize (and verify) one evaluation. */
export type OPRFFinalizeItem = {
  /** Client input that was blinded. */
  input: Bytes;
  /** Blinding scalar returned by `blind` for this input. */
  blind: ScalarBytes;
  /** Server's evaluation of `blinded`. */
  evaluated: PointBytes;
  /** Blinded element that was sent to the server. */
  blinded: PointBytes;
};
/** POPRF `blind` output: the usual blind data plus the client-computed tweaked server key. */
export type OPRFBlindTweaked = OPRFBlind & {
  /** Server public key tweaked with the POPRF `info`; used to verify the server proof. */
  tweakedKey: TRet<PointBytes>;
};

/**
 * Represents a full OPRF ciphersuite implementation according to RFC 9497.
 * This object bundles the three protocol variants (OPRF, VOPRF, POPRF) for a specific
 * prime-order group and hash function combination.
 *
 * @see https://www.rfc-editor.org/rfc/rfc9497.html
 */
export type OPRF = {
  /**
   * The unique identifier for the ciphersuite, e.g., "ristretto255-SHA512".
   * This name is used for domain separation to prevent cross-protocol attacks.
   */
  readonly name: string;

  /**
   * The base Oblivious Pseudorandom Function (OPRF) mode (mode 0x00).
   * This is a two-party protocol between a client and a server to compute F(k, x)
   * where 'k' is the server's key and 'x' is the client's input.
   *
   * The client learns the output F(k, x) but nothing about 'k'.
   * The server learns nothing about 'x' or F(k, x).
   * This mode is NOT verifiable: the client cannot verify that the server used a specific key.
   */
  readonly oprf: {
    /**
     * (Server-side) Generates a new random private/public key pair for the server.
     * @returns A new key pair.
     */
    generateKeyPair(): TRet<OPRFKeys>;

    /**
     * (Server-side) Deterministically derives a private/public key pair from a seed.
     * @param seed - A 32-byte cryptographically secure random seed.
     * @param keyInfo - A byte string for domain separation (may be empty, at most 65535 bytes).
     * @returns The derived key pair.
     * @throws If `seed` is not exactly 32 bytes or `keyInfo` is too long. {@link Error}
     */
    deriveKeyPair(seed: TArg<Bytes>, keyInfo: TArg<Bytes>): TRet<OPRFKeys>;

    /**
     * (Client-side) The first step of the protocol. The client blinds its private input.
     * @param input - The client's private input bytes.
     * @param rng - An optional cryptographically secure random number generator.
     * @returns An object containing the `blind` scalar (which the client MUST keep secret)
     * and the `blinded` element (which the client sends to the server).
     */
    blind(input: TArg<Bytes>, rng?: RNG): TRet<OPRFBlind>;

    /**
     * (Server-side) The second step. The server evaluates the client's blinded element
     * using its secret key.
     * @param secretKey - The server's private key.
     * @param blinded - The blinded group element received from the client.
     * @returns The evaluated group element, to be sent back to the client.
     * @throws If `blinded` decodes to the identity element. {@link Error}
     */
    blindEvaluate(secretKey: TArg<ScalarBytes>, blinded: TArg<PointBytes>): TRet<PointBytes>;

    /**
     * (Client-side) The final step. The client unblinds the server's response to
     * compute the final OPRF output.
     * @param input - The original private input from the `blind` step.
     * @param blind - The secret scalar from the `blind` step.
     * @param evaluated - The evaluated group element received from the server.
     * @returns The final OPRF output, `Hash(len(input)||input||len(unblinded)||unblinded||"Finalize")`.
     */
    finalize(
      input: TArg<Bytes>,
      blind: TArg<ScalarBytes>,
      evaluated: TArg<PointBytes>
    ): TRet<Bytes>;

    /**
     * Non-interactive evaluation for a party that knows both the input and the secret key.
     * Produces the same output `finalize` yields for that input and key, without the
     * blind/evaluate round trip.
     * @param secretKey - The server's private key.
     * @param input - The private input.
     * @returns The final OPRF output.
     * @throws If the input hashes to the identity element. {@link Error}
     */
    evaluate(secretKey: TArg<ScalarBytes>, input: TArg<Bytes>): TRet<Bytes>;
  };

  /**
   * The Verifiable Oblivious Pseudorandom Function (VOPRF) mode (mode 0x01).
   * This mode extends the base OPRF by providing a proof that the server used the
   * secret key corresponding to its known public key.
   */
  readonly voprf: {
    /** (Server-side) Generates a key pair for the VOPRF mode. */
    generateKeyPair(): TRet<OPRFKeys>;
    /** (Server-side) Deterministically derives a key pair for the VOPRF mode. */
    deriveKeyPair(seed: TArg<Bytes>, keyInfo: TArg<Bytes>): TRet<OPRFKeys>;
    /** (Client-side) Blinds the client's private input for the VOPRF protocol. */
    blind(input: TArg<Bytes>, rng?: RNG): TRet<OPRFBlind>;

    /**
     * (Server-side) Evaluates the client's blinded element and generates a DLEQ proof
     * of correctness.
     * @param secretKey - The server's private key.
     * @param publicKey - The server's public key, used in proof generation.
     * @param blinded - The blinded group element received from the client.
     * @param rng - An optional cryptographically secure random number generator for the proof.
     * @returns The evaluated element and a proof of correct computation.
     */
    blindEvaluate(
      secretKey: TArg<ScalarBytes>,
      publicKey: TArg<PointBytes>,
      blinded: TArg<PointBytes>,
      rng?: RNG
    ): TRet<OPRFBlindEval>;

    /**
     * (Server-side) An optimized batch version of `blindEvaluate`. It evaluates multiple
     * blinded elements and produces a single, constant-size proof for the entire batch,
     * amortizing the cost of proof generation.
     * @param secretKey - The server's private key.
     * @param publicKey - The server's public key.
     * @param blinded - An array of blinded group elements from one or more clients.
     * @param rng - An optional cryptographically secure random number generator for the proof.
     * @returns An array of evaluated elements and a single proof for the batch.
     */
    blindEvaluateBatch(
      secretKey: TArg<ScalarBytes>,
      publicKey: TArg<PointBytes>,
      blinded: TArg<PointBytes[]>,
      rng?: RNG
    ): TRet<OPRFBlindEvalBatch>;

    /**
     * (Client-side) The final step. The client verifies the server's proof, and if valid,
     * unblinds the result to compute the final VOPRF output.
     * @param input - The original private input.
     * @param blind - The secret scalar from the `blind` step.
     * @param evaluated - The evaluated element from the server.
     * @param blinded - The blinded element sent to the server (needed for proof verification).
     * @param publicKey - The server's public key against which the proof is verified.
     * @param proof - The DLEQ proof from the server.
     * @returns The final VOPRF output.
     * @throws If the proof verification fails. {@link Error}
     */
    finalize(
      input: TArg<Bytes>,
      blind: TArg<ScalarBytes>,
      evaluated: TArg<PointBytes>,
      blinded: TArg<PointBytes>,
      publicKey: TArg<PointBytes>,
      proof: TArg<Bytes>
    ): TRet<Bytes>;

    /**
     * (Client-side) The batch-aware version of `finalize`. It verifies a single batch proof
     * against a list of corresponding inputs and outputs.
     * @param items - An array of objects, each containing the parameters for a single finalization.
     * @param publicKey - The server's public key.
     * @param proof - The single DLEQ proof for the entire batch.
     * @returns An array of final VOPRF outputs, one for each item in the input.
     * @throws If the proof verification fails. {@link Error}
     */
    finalizeBatch(
      items: TArg<OPRFFinalizeItem[]>,
      publicKey: TArg<PointBytes>,
      proof: TArg<Bytes>
    ): TRet<Bytes[]>;

    /**
     * Non-interactive VOPRF evaluation for a party that knows both the input and the secret
     * key. Produces the same output `finalize` yields for that input and key.
     * @param secretKey - The server's private key.
     * @param input - The private input.
     * @returns The final VOPRF output.
     * @throws If the input hashes to the identity element. {@link Error}
     */
    evaluate(secretKey: TArg<ScalarBytes>, input: TArg<Bytes>): TRet<Bytes>;
  };

  /**
   * A factory for the Partially Oblivious Pseudorandom Function (POPRF) mode (mode 0x02).
   * This mode extends VOPRF to include a public `info` parameter, known to both client and
   * server, which is cryptographically bound to the final output.
   * This is useful for domain separation at the application level.
   * @param info - A public byte string to be mixed into the computation.
   * @returns An object with the POPRF protocol functions.
   */
  readonly poprf: (info: TArg<Bytes>) => {
    /** (Server-side) Generates a key pair for the POPRF mode. */
    generateKeyPair(): TRet<OPRFKeys>;
    /** (Server-side) Deterministically derives a key pair for the POPRF mode. */
    deriveKeyPair(seed: TArg<Bytes>, keyInfo: TArg<Bytes>): TRet<OPRFKeys>;

    /**
     * (Client-side) Blinds the client's private input and computes the "tweaked key".
     * The tweaked key is a public value derived from the server's public key and the public `info`.
     * @param input - The client's private input.
     * @param publicKey - The server's public key.
     * @param rng - An optional cryptographically secure random number generator.
     * @returns The `blind`, `blinded` element, and the `tweakedKey`
     *   the client uses for verification.
     */
    blind(input: TArg<Bytes>, publicKey: TArg<PointBytes>, rng?: RNG): TRet<OPRFBlindTweaked>;

    /**
     * (Server-side) Evaluates the blinded element using a key derived from
     * its secret key and the public `info`.
     * It generates a DLEQ proof against the tweaked key.
     * @param secretKey - The server's private key.
     * @param blinded - The blinded element from the client.
     * @param rng - An optional RNG for the proof.
     * @returns The evaluated element and a proof of correct computation.
     */
    blindEvaluate(
      secretKey: TArg<ScalarBytes>,
      blinded: TArg<PointBytes>,
      rng?: RNG
    ): TRet<OPRFBlindEval>;

    /**
     * (Server-side) A batch-aware version of `blindEvaluate` for the POPRF mode.
     * @param secretKey - The server's private key.
     * @param blinded - An array of blinded elements.
     * @param rng - An optional RNG for the proof (defaults to the built-in secure RNG).
     * @returns An array of evaluated elements and a single proof for the batch.
     */
    blindEvaluateBatch(
      secretKey: TArg<ScalarBytes>,
      blinded: TArg<PointBytes[]>,
      rng?: RNG
    ): TRet<OPRFBlindEvalBatch>;

    /**
     * (Client-side) A batch-aware version of `finalize` for the POPRF mode.
     * It verifies the proof against the tweaked key.
     * @param items - An array containing the parameters for each finalization.
     * @param proof - The single DLEQ proof for the batch.
     * @param tweakedKey - The tweaked key corresponding to the proof.
     *   All items must share the same `info` and `publicKey`.
     * @returns An array of final POPRF outputs.
     * @throws If proof verification fails. {@link Error}
     */
    finalizeBatch(
      items: TArg<OPRFFinalizeItem[]>,
      proof: TArg<Bytes>,
      tweakedKey: TArg<PointBytes>
    ): TRet<Bytes[]>;

    /**
     * (Client-side) Finalizes the POPRF protocol. It verifies the server's proof against the
     * `tweakedKey` computed in the `blind` step. The final output is bound to the public `info`.
     * @param input - The original private input.
     * @param blind - The secret scalar.
     * @param evaluated - The evaluated element from the server.
     * @param blinded - The blinded element sent to the server.
     * @param proof - The DLEQ proof from the server.
     * @param tweakedKey - The public tweaked key computed by the client during the `blind` step.
     * @returns The final POPRF output.
     * @throws If proof verification fails. {@link Error}
     */
    finalize(
      input: TArg<Bytes>,
      blind: TArg<ScalarBytes>,
      evaluated: TArg<PointBytes>,
      blinded: TArg<PointBytes>,
      proof: TArg<Bytes>,
      tweakedKey: TArg<PointBytes>
    ): TRet<Bytes>;

    /**
     * A non-interactive evaluation function for an entity that knows all inputs.
     * Computes the final POPRF output directly. Useful for testing or specific applications
     * where the server needs to compute the output for a known input.
     * @param secretKey - The server's private key.
     * @param input - The client's private input.
     * @returns The final POPRF output.
     */
    evaluate(secretKey: TArg<ScalarBytes>, input: TArg<Bytes>): TRet<Bytes>;
  };
};

// welcome to generic hell
/**
 * Instantiates one RFC 9497 ciphersuite: the base OPRF, verifiable VOPRF and partially
 * oblivious POPRF protocol namespaces over `opts.Point` and the given hash hooks.
 *
 * No returned function depends on `this`, so namespace members can be destructured
 * (`const { blindEvaluate } = suite.voprf`) or passed around as callbacks.
 * @param opts - OPRF ciphersuite options. See {@link OPRFOpts}.
 * @returns Frozen OPRF helper namespace.
 * @example
 * Instantiate an OPRF suite from curve-specific hashing hooks.
 *
 * ```ts
 * import { createOPRF } from '@noble/curves/abstract/oprf.js';
 * import { p256, p256_hasher } from '@noble/curves/nist.js';
 * import { sha256 } from '@noble/hashes/sha2.js';
 * const oprf = createOPRF({
 *   name: 'P256-SHA256',
 *   Point: p256.Point,
 *   hash: sha256,
 *   hashToGroup: p256_hasher.hashToCurve,
 *   hashToScalar: p256_hasher.hashToScalar,
 * });
 * const keys = oprf.oprf.generateKeyPair();
 * ```
 */
export function createOPRF<P extends CurvePoint<any, P>>(opts: OPRFOpts<P>): TRet<OPRF> {
  validateObject(opts, {
    name: 'string',
    hash: 'function',
    hashToScalar: 'function',
    hashToGroup: 'function',
  });
  // Cheap constructor-surface sanity check only: this verifies the generic static hooks/fields that
  // OPRF consumes, but it does not certify point semantics like BASE/ZERO correctness.
  validatePointCons(opts.Point);
  const { name, Point, hash } = opts;
  const { Fn } = Point;

  // Domain-separated hash hooks: the DST is a fixed prefix followed by the mode's context string.
  const hashToGroup = (msg: TArg<Uint8Array>, ctx: TArg<Uint8Array>) =>
    opts.hashToGroup(msg, {
      DST: concatBytes(asciiToBytes('HashToGroup-'), ctx),
    }) as P;
  const hashToScalarPrefixed = (msg: TArg<Uint8Array>, ctx: TArg<Uint8Array>) =>
    opts.hashToScalar(msg, { DST: concatBytes(_DST_scalarBytes, ctx) });
  const randomScalar = (rng: RNG = randomBytes) => {
    // RFC 9497 §2.1 defines RandomScalar as nonzero; blind inversion and generated public keys
    // both rely on keeping this helper in the `1..n-1` range.
    const t = mapHashToField(rng(getMinHashLength(Fn.ORDER)), Fn.ORDER, Fn.isLE);
    // We cannot use Fn.fromBytes here, because field
    // can have different number of bytes (like ed448)
    return Fn.isLE ? bytesToNumberLE(t) : bytesToNumberBE(t);
  };

  const msm = (points: P[], scalars: bigint[]) => pippenger(Point, points, scalars);

  // RFC 9497 contextString = "OPRFV1-" || I2OSP(mode, 1) || "-" || identifier.
  const getCtx = (mode: number) =>
    concatBytes(asciiToBytes('OPRFV1-'), new Uint8Array([mode]), asciiToBytes('-' + name));
  const ctxOPRF = getCtx(0x00);
  const ctxVOPRF = getCtx(0x01);
  const ctxPOPRF = getCtx(0x02);

  // RFC 9497 I2OSP(n, 2): every counter and length prefix in the protocol transcripts is exactly
  // two bytes. Reject values that do not fit (e.g. a batch index past 65535) instead of relying on
  // `numberToBytesBE` overflow behaviour, which cannot yield an RFC-conformant transcript.
  const i2osp2 = (n: number) => {
    if (!Number.isSafeInteger(n) || n < 0 || n > 0xffff)
      throw new Error('expected integer 0 <= n <= 65535, got ' + n);
    return numberToBytesBE(n, 2);
  };

  // Transcript encoding: numbers become I2OSP(n, 2), strings are appended as raw ASCII labels,
  // and byte arrays are written as I2OSP(len, 2) || bytes.
  function encode(...args: TArg<(Uint8Array | number | string)[]>): TRet<Bytes> {
    const res = [];
    for (const a of args) {
      if (typeof a === 'number') res.push(i2osp2(a));
      else if (typeof a === 'string') res.push(asciiToBytes(a));
      else {
        abytes(a);
        res.push(i2osp2(a.length), a);
      }
    }
    // No wipe here: the parts include caller-owned buffers, and zeroing them would corrupt them.
    return concatBytes(...res) as TRet<Bytes>;
  }
  const inputBytes = (title: string, bytes: TArg<Uint8Array>) => {
    abytes(bytes, undefined, title);
    // RFC 9497 §1.2 limits PrivateInput/PublicInput to 2^16 - 1 bytes because these values are
    // length-prefixed with two bytes before use throughout the protocol.
    if (bytes.length > 0xffff)
      throw new Error(
        `"${title}" expected Uint8Array of length <= 65535, got length=${bytes.length}`
      );
    return bytes;
  };
  // Final output: Hash(len(x1)||x1||...||len(xn)||xn||"Finalize").
  const hashInput = (...bytes: TArg<Uint8Array[]>): TRet<Bytes> =>
    hash(encode(...bytes, 'Finalize')) as TRet<Bytes>;

  // Per-element composite weights d_i derived from a seed bound to B and the context string.
  function getTranscripts(B: P, C: P[], D: P[], ctx: TArg<Bytes>) {
    const Bm = B.toBytes();
    const seed = hash(encode(Bm, concatBytes(asciiToBytes('Seed-'), ctx)));
    const res: bigint[] = [];
    for (let i = 0; i < C.length; i++) {
      const Ci = C[i].toBytes();
      const Di = D[i].toBytes();
      const di = hashToScalarPrefixed(encode(seed, i, Ci, Di, 'Composite'), ctx);
      res.push(di);
    }
    return res;
  }

  // Verifier side: M = sum(d_i * C_i), Z = sum(d_i * D_i).
  function computeComposites(B: P, C: P[], D: P[], ctx: TArg<Bytes>) {
    const T = getTranscripts(B, C, D, ctx);
    const M = msm(C, T);
    const Z = msm(D, T);
    return { M, Z };
  }

  function computeCompositesFast(
    k: bigint,
    B: P,
    C: P[],
    D: P[],
    ctx: TArg<Bytes>
  ): { M: P; Z: P } {
    const T = getTranscripts(B, C, D, ctx);
    const M = msm(C, T);
    // RFC 9497 §2.2.1 ComputeCompositesFast derives weights from both C and D in getTranscripts(),
    // then uses the server shortcut Z = k * M instead of a second MSM over D.
    const Z = M.multiply(k);
    return { M, Z };
  }

  function challengeTranscript(B: P, M: P, Z: P, t2: P, t3: P, ctx: TArg<Bytes>) {
    const [Bm, a0, a1, a2, a3] = [B, M, Z, t2, t3].map((i) => i.toBytes());
    return hashToScalarPrefixed(encode(Bm, a0, a1, a2, a3, 'Challenge'), ctx);
  }

  // DLEQ proof that log_G(B) == log_M(Z), serialized as c || s.
  function generateProof(ctx: TArg<Bytes>, k: bigint, B: P, C: P[], D: P[], rng: RNG): TRet<Bytes> {
    const { M, Z } = computeCompositesFast(k, B, C, D, ctx);
    const r = randomScalar(rng);
    const t2 = Point.BASE.multiply(r);
    const t3 = M.multiply(r);
    const c = challengeTranscript(B, M, Z, t2, t3, ctx);
    const s = Fn.sub(r, Fn.mul(c, k)); // r - c*k
    return concatBytes(...[c, s].map((i) => Fn.toBytes(i))) as TRet<Bytes>;
  }

  function verifyProof(ctx: TArg<Bytes>, B: P, C: P[], D: P[], proof: TArg<Bytes>) {
    abytes(proof, 2 * Fn.BYTES);
    const { M, Z } = computeComposites(B, C, D, ctx);
    const [c, s] = [proof.subarray(0, Fn.BYTES), proof.subarray(Fn.BYTES)].map((f) =>
      Fn.fromBytes(f)
    );
    const t2 = Point.BASE.multiply(s).add(B.multiply(c)); // s*G + c*B
    const t3 = M.multiply(s).add(Z.multiply(c)); // s*M + c*Z
    const expectedC = challengeTranscript(B, M, Z, t2, t3, ctx);
    if (!Fn.eql(c, expectedC)) throw new Error('proof verification failed');
  }

  function generateKeyPair(): TRet<OPRFKeys> {
    const skS = randomScalar();
    const pkS = Point.BASE.multiply(skS);
    return { secretKey: Fn.toBytes(skS), publicKey: pkS.toBytes() } as TRet<OPRFKeys>;
  }

  function deriveKeyPair(ctx: TArg<Bytes>, seed: TArg<Bytes>, info: TArg<Bytes>): TRet<OPRFKeys> {
    // RFC 9497 §3.2.1 defines `seed[32]`; reject other sizes here because this public API already
    // documents a 32-byte seed instead of generic input keying material.
    abytes(seed, 32, 'seed');
    info = inputBytes('keyInfo', info);
    const dst = concatBytes(asciiToBytes('DeriveKeyPair'), ctx);
    const msg = concatBytes(seed, encode(info), Uint8Array.of(0));
    // The trailing byte is the I2OSP(counter, 1) retry counter, rewritten in place each round.
    for (let counter = 0; counter <= 255; counter++) {
      msg[msg.length - 1] = counter;
      const skS = opts.hashToScalar(msg, { DST: dst });
      if (Fn.is0(skS)) continue; // should not happen
      return {
        secretKey: Fn.toBytes(skS),
        publicKey: Point.BASE.multiply(skS).toBytes(),
      } as TRet<OPRFKeys>;
    }
    throw new Error('Cannot derive key');
  }
  const wirePoint = (label: string, bytes: TArg<Uint8Array>) => {
    const point = Point.fromBytes(bytes);
    // RFC 9497 §3.3 says applications MUST reject group-identity Elements received over the wire
    // after deserialization, even if the suite decoder itself accepts the identity encoding.
    if (point.equals(Point.ZERO)) throw new Error(label + ' point at infinity');
    return point;
  };
  function blind(
    ctx: TArg<Bytes>,
    input: TArg<Uint8Array>,
    rng: RNG = randomBytes
  ): TRet<OPRFBlind> {
    input = inputBytes('input', input);
    const blind = randomScalar(rng);
    const inputPoint = hashToGroup(input, ctx);
    if (inputPoint.equals(Point.ZERO)) throw new Error('Input point at infinity');
    const blinded = inputPoint.multiply(blind);
    return { blind: Fn.toBytes(blind), blinded: blinded.toBytes() } as TRet<OPRFBlind>;
  }
  function evaluate(
    ctx: TArg<Bytes>,
    secretKey: TArg<ScalarBytes>,
    input: TArg<Bytes>
  ): TRet<Bytes> {
    input = inputBytes('input', input);
    const skS = Fn.fromBytes(secretKey);
    const inputPoint = hashToGroup(input, ctx);
    if (inputPoint.equals(Point.ZERO)) throw new Error('Input point at infinity');
    const unblinded = inputPoint.multiply(skS).toBytes();
    return hashInput(input, unblinded);
  }
  const oprf = Object.freeze({
    generateKeyPair,
    deriveKeyPair: (seed: TArg<Bytes>, keyInfo: TArg<Bytes>) =>
      deriveKeyPair(ctxOPRF, seed, keyInfo),
    blind: (input: TArg<Bytes>, rng: RNG = randomBytes) => blind(ctxOPRF, input, rng),
    blindEvaluate(secretKey: TArg<ScalarBytes>, blindedPoint: TArg<PointBytes>): TRet<PointBytes> {
      const skS = Fn.fromBytes(secretKey);
      const elm = wirePoint('blinded', blindedPoint);
      return elm.multiply(skS).toBytes() as TRet<PointBytes>;
    },
    finalize(
      input: TArg<Bytes>,
      blindBytes: TArg<ScalarBytes>,
      evaluatedBytes: TArg<PointBytes>
    ): TRet<Bytes> {
      input = inputBytes('input', input);
      const blind = Fn.fromBytes(blindBytes);
      const evalPoint = wirePoint('evaluated', evaluatedBytes);
      const unblinded = evalPoint.multiply(Fn.inv(blind)).toBytes();
      return hashInput(input, unblinded);
    },
    evaluate: (secretKey: TArg<ScalarBytes>, input: TArg<Bytes>) =>
      evaluate(ctxOPRF, secretKey, input),
  });

  // VOPRF batch operations are standalone closures rather than `this`-based methods, so the
  // single-item wrappers keep working when the frozen namespace is destructured.
  function voprfBlindEvaluateBatch(
    secretKey: TArg<ScalarBytes>,
    publicKey: TArg<PointBytes>,
    blinded: TArg<PointBytes[]>,
    rng: RNG = randomBytes
  ): TRet<OPRFBlindEvalBatch> {
    if (!Array.isArray(blinded)) throw new Error('expected array');
    const skS = Fn.fromBytes(secretKey);
    const pkS = wirePoint('public key', publicKey);
    const blindedPoints = blinded.map((i) => wirePoint('blinded', i));
    const evaluated = blindedPoints.map((i) => i.multiply(skS));
    // VOPRF proof: C = blinded, D = evaluated, proving log_G(pkS) == log_C(D) == skS.
    const proof = generateProof(ctxVOPRF, skS, pkS, blindedPoints, evaluated, rng);
    return { evaluated: evaluated.map((i) => i.toBytes()), proof } as TRet<OPRFBlindEvalBatch>;
  }
  function voprfFinalizeBatch(
    items: TArg<OPRFFinalizeItem[]>,
    publicKey: TArg<PointBytes>,
    proof: TArg<Bytes>
  ): TRet<Bytes[]> {
    if (!Array.isArray(items)) throw new Error('expected array');
    const pkS = wirePoint('public key', publicKey);
    const blindedPoints = items.map((i) => wirePoint('blinded', i.blinded));
    const evalPoints = items.map((i) => wirePoint('evaluated', i.evaluated));
    verifyProof(ctxVOPRF, pkS, blindedPoints, evalPoints, proof);
    // Unblinding is identical to the base mode once the proof has been checked.
    return items.map((i) => oprf.finalize(i.input, i.blind, i.evaluated)) as TRet<Bytes[]>;
  }
  const voprf = Object.freeze({
    generateKeyPair,
    deriveKeyPair: (seed: TArg<Bytes>, keyInfo: TArg<Bytes>) =>
      deriveKeyPair(ctxVOPRF, seed, keyInfo),
    blind: (input: TArg<Bytes>, rng: RNG = randomBytes) => blind(ctxVOPRF, input, rng),
    blindEvaluateBatch: voprfBlindEvaluateBatch,
    blindEvaluate(
      secretKey: TArg<ScalarBytes>,
      publicKey: TArg<PointBytes>,
      blinded: TArg<PointBytes>,
      rng: RNG = randomBytes
    ): TRet<OPRFBlindEval> {
      const res = voprfBlindEvaluateBatch(secretKey, publicKey, [blinded], rng);
      return { evaluated: res.evaluated[0], proof: res.proof } as TRet<OPRFBlindEval>;
    },
    finalizeBatch: voprfFinalizeBatch,
    finalize(
      input: TArg<Bytes>,
      blind: TArg<ScalarBytes>,
      evaluated: TArg<PointBytes>,
      blinded: TArg<PointBytes>,
      publicKey: TArg<PointBytes>,
      proof: TArg<Bytes>
    ): TRet<Bytes> {
      return voprfFinalizeBatch([{ input, blind, evaluated, blinded }], publicKey, proof)[0];
    },
    evaluate: (secretKey: TArg<ScalarBytes>, input: TArg<Bytes>) =>
      evaluate(ctxVOPRF, secretKey, input),
  });
  // NOTE: info is domain separation
  const poprf = (info: TArg<Bytes>) => {
    info = inputBytes('info', info);
    // m = HashToScalar("Info" || I2OSP(len(info), 2) || info)
    const m = hashToScalarPrefixed(encode('Info', info), ctxPOPRF);
    // Client-side share of the tweak: tweakedKey = m*G + pkS = (skS + m)*G.
    const T = Point.BASE.multiply(m);
    // Batch operations are standalone closures rather than `this`-based methods, so the
    // single-item wrappers keep working when the frozen namespace is destructured.
    function poprfBlindEvaluateBatch(
      secretKey: TArg<ScalarBytes>,
      blinded: TArg<PointBytes[]>,
      rng: RNG = randomBytes
    ): TRet<OPRFBlindEvalBatch> {
      if (!Array.isArray(blinded)) throw new Error('expected array');
      const skS = Fn.fromBytes(secretKey);
      const t = Fn.add(skS, m);
      // "Hence, this error can be a signal for the server to replace its
      // private key". We throw inside; this should be impossible.
      const invT = Fn.inv(t);
      const blindedPoints = blinded.map((i) => wirePoint('blinded', i));
      const evalPoints = blindedPoints.map((i) => i.multiply(invT));
      const tweakedKey = Point.BASE.multiply(t);
      // Evaluation multiplies by t^-1, so the proof uses C = evaluated and D = blinded
      // (the reverse of VOPRF): log_G(tweakedKey) == log_evaluated(blinded) == t.
      const proof = generateProof(ctxPOPRF, t, tweakedKey, evalPoints, blindedPoints, rng);
      return { evaluated: evalPoints.map((i) => i.toBytes()), proof } as TRet<OPRFBlindEvalBatch>;
    }
    function poprfFinalizeBatch(
      items: TArg<OPRFFinalizeItem[]>,
      proof: TArg<Bytes>,
      tweakedKey: TArg<PointBytes>
    ): TRet<Bytes[]> {
      if (!Array.isArray(items)) throw new Error('expected array');
      const inputs = items.map((i) => inputBytes('input', i.input));
      const evalPoints = items.map((i) => wirePoint('evaluated', i.evaluated));
      verifyProof(
        ctxPOPRF,
        wirePoint('tweakedKey', tweakedKey),
        evalPoints,
        items.map((i) => wirePoint('blinded', i.blinded)),
        proof
      );
      // Output binds the public info: Hash(input || info || unblinded || "Finalize"), each
      // length-prefixed.
      return items.map((i, j) => {
        const blind = Fn.fromBytes(i.blind);
        const point = evalPoints[j].multiply(Fn.inv(blind)).toBytes();
        return hashInput(inputs[j], info, point);
      }) as TRet<Bytes[]>;
    }
    return Object.freeze({
      generateKeyPair,
      deriveKeyPair: (seed: TArg<Bytes>, keyInfo: TArg<Bytes>) =>
        deriveKeyPair(ctxPOPRF, seed, keyInfo),
      blind(
        input: TArg<Bytes>,
        publicKey: TArg<PointBytes>,
        rng: RNG = randomBytes
      ): TRet<OPRFBlindTweaked> {
        input = inputBytes('input', input);
        const pkS = wirePoint('public key', publicKey);
        const tweakedKey = T.add(pkS);
        if (tweakedKey.equals(Point.ZERO)) throw new Error('tweakedKey point at infinity');
        const blind = randomScalar(rng);
        const inputPoint = hashToGroup(input, ctxPOPRF);
        if (inputPoint.equals(Point.ZERO)) throw new Error('Input point at infinity');
        const blindedPoint = inputPoint.multiply(blind);
        return {
          blind: Fn.toBytes(blind),
          blinded: blindedPoint.toBytes(),
          tweakedKey: tweakedKey.toBytes(),
        } as TRet<OPRFBlindTweaked>;
      },
      blindEvaluateBatch: poprfBlindEvaluateBatch,
      blindEvaluate(
        secretKey: TArg<ScalarBytes>,
        blinded: TArg<PointBytes>,
        rng: RNG = randomBytes
      ): TRet<OPRFBlindEval> {
        const res = poprfBlindEvaluateBatch(secretKey, [blinded], rng);
        return { evaluated: res.evaluated[0], proof: res.proof } as TRet<OPRFBlindEval>;
      },
      finalizeBatch: poprfFinalizeBatch,
      finalize(
        input: TArg<Bytes>,
        blind: TArg<ScalarBytes>,
        evaluated: TArg<PointBytes>,
        blinded: TArg<PointBytes>,
        proof: TArg<Bytes>,
        tweakedKey: TArg<PointBytes>
      ): TRet<Bytes> {
        return poprfFinalizeBatch([{ input, blind, evaluated, blinded }], proof, tweakedKey)[0];
      },
      evaluate(secretKey: TArg<ScalarBytes>, input: TArg<Bytes>): TRet<Bytes> {
        input = inputBytes('input', input);
        const skS = Fn.fromBytes(secretKey);
        const inputPoint = hashToGroup(input, ctxPOPRF);
        if (inputPoint.equals(Point.ZERO)) throw new Error('Input point at infinity');
        const t = Fn.add(skS, m);
        const invT = Fn.inv(t);
        const unblinded = inputPoint.multiply(invT).toBytes();
        return hashInput(input, info, unblinded);
      },
    });
  };
  const res = { name, oprf, voprf, poprf, __tests: Object.freeze({ Fn }) };
  return Object.freeze(res) as TRet<OPRF>;
}
