#!/usr/bin/env python3
from __future__ import annotations
import argparse
import glob
import hashlib
import os
import random
import re
import shlex
import socket
import subprocess
import sys
import textwrap
import threading
import time
import typing
import uuid

# Name prefix of the internal cache "template" databases; the cache digest is
# appended to it.
CACHE_DB_PREFIX = "ci_pg_restore_cache_"
# Name of the internal database which holds the lock table and the registry of
# restored/cache databases.
SYNC_DB = "ci_pg_restore"
# Table (in SYNC_DB) with the currently held exclusive locks.
LOCK_TABLE = "ci_pg_restore_lock"
# Table (in SYNC_DB) with the registry of databases and their touch/confirm
# timestamps; garbage collection reads it.
DB_TABLE = "ci_pg_restore_db"
# Base delay between attempts to acquire a busy lock.
LOCK_RETRY_SEC = 0.5
# Max random fraction of LOCK_RETRY_SEC added to every retry delay.
LOCK_JITTER = 0.1
# Locks not refreshed for that long are deleted by the next locking attempt.
LOCK_EXPIRATION_SEC = 30
# Sleep time between lock refreshes in the background thread.
LOCK_REFRESH_SEC = 5
# Passed to psql and to the restoration command as PGCONNECT_TIMEOUT.
CONNECT_TIMEOUT_SEC = 15
# Safety margin subtracted from the cache max age when deciding whether a cache
# database can still be cloned from without holding its lock.
MAX_CLONE_TIME_SEC = 30
# Lock name which serializes garbage collection runs.
DO_GC_LOCK_NAME = "#do_gc"


#
# Tool entry point.
#
def main():
    """Parse the command line, restore the database and run garbage collection.

    The positional remainder of the command line is the restoration command.
    A leading "--" separator in it (if argparse left it there) is dropped, so
    both `tool -d db --cache-name x psql ...` and `tool -d db --cache-name x --
    psql ...` run the same command.

    Raises:
        UserException: on an invalid time interval argument.
        subprocess.CalledProcessError: if psql or the restoration command fails.
    """
    parser = argparse.ArgumentParser(
        description="""
            A cached perf optimized restoration of a PostgreSQL database dump.

            To access the databases, the tool assumes that psql CLI is installed
            in the system, and PGUSER/PGPASSWORD are correctly configured.

            The tool uses the following observation: if we have a large rarely
            changed *.sql file with a database dump, and we need to restore it
            multiple times (in e.g. CI environment, when running integration
            test matrix), it's way faster to restore it only once in some cache
            "template" database and then run CREATE DATABASE WITH TEMPLATE to
            clone that cache database multiple times.

            While restoring, the tool behaves nicely in terms of concurrency: if
            multiple processes of the tool are running at the same time, they
            wait for each other.

            The tool also takes care of cleaning up the cache databases and the
            restored databases. This is typically useful in CI environment, when
            the databases are not needed anymore after the tests are finished.
            It also drops the databases having the name equals to the restored
            database name following a "~*" suffix (* means "any string"); this
            allows the tests in CI to create extra databases if needed.
        """,
        formatter_class=ParagraphFormatter,
    )
    parser.add_argument(
        "-d",
        "--dbname",
        type=str,
        required=True,
        help="the name of the database to restore; it will be pre-created empty prior to calling the restoration command",
    )
    parser.add_argument(
        "--cache-name",
        type=str,
        required=True,
        help='cache name; the tool will create an internal cached "template" database suffixed with the digest of that name plus any --deps-files or --deps-files-cmd involved',
    )
    parser.add_argument(
        "--deps-files",
        type=str,
        default=[],
        action="append",
        required=False,
        metavar='"some/**/file1.sql\\nfile2.sql\\n..."',
        help='a newline separated list of file name wildcards involved in the restoration (you may use this arg multiple times; all files must exist); the tool will create an internal cached "template" database suffixed with the digest of those files content',
    )
    parser.add_argument(
        "--deps-files-cmd",
        type=str,
        default=[],
        action="append",
        required=False,
        metavar="\"find . -name '*.sql'\"",
        help="a shell command which should return a newline separated list of all file names involved in the restoration; see --deps-files for details",
    )
    parser.add_argument(
        "--cache-db-max-age",
        type=str,
        required=False,
        metavar="42, 42s, 30m, 8h, 3d, 1w",
        help='if set, the cached "template" databases will be deleted if they are not used for that long',
    )
    parser.add_argument(
        "--restore-db-max-age",
        type=str,
        required=False,
        metavar="42, 42s, 30m, 8h, 3d, 1w",
        help="if set, the restored databases will be deleted if they are not used for that long",
    )
    parser.add_argument(
        "rest",
        type=str,
        nargs=argparse.REMAINDER,
        metavar="shell command",
        help="a shell command which will be run to restore the database if its cache is invalid; the database name will also be passed via PGDATABASE environment variable; you may call the regular psql here or use any other database migration tool",
    )
    args = parser.parse_args()

    dbname: str = args.dbname
    cache_name: str = args.cache_name
    # Repeated --deps-* arguments are merged into one newline separated text.
    deps_files: str = "\n".join(args.deps_files)
    deps_files_cmd: str = "\n".join(args.deps_files_cmd)
    cache_db_max_age: int | None = to_sec(args.cache_db_max_age)
    restore_db_max_age: int | None = to_sec(args.restore_db_max_age)
    rest: list[str] = args.rest

    # argparse may keep the "--" options terminator as the first element of a
    # REMAINDER positional; it is never a part of the command to run.
    if rest and rest[0] == "--":
        rest = rest[1:]

    if not rest:
        parser.error("The restoration command must be provided.")

    sync_db = SyncDB(name=SYNC_DB, cache_db_max_age=cache_db_max_age)
    sync_db.ensure_exists()

    digest, because = build_digest(
        cache_name=cache_name,
        deps_files=deps_files,
        deps_files_cmd=deps_files_cmd,
    )
    do_restore(
        sync_db=sync_db,
        dbname=dbname,
        digest=digest,
        because=because,
        rest=rest,
    )

    # Garbage collection runs only after a successful restoration.
    do_gc(
        sync_db=sync_db,
        cache_db_max_age=cache_db_max_age,
        restore_db_max_age=restore_db_max_age,
    )


#
# The main restoration logic of the tool.
#
def do_restore(
    *,
    sync_db: SyncDB,
    dbname: str,
    digest: str,
    because: str,
    rest: list[str],
) -> None:
    """Restore the database `dbname`, using the cache template if possible.

    The function polls until it holds the exclusive lock for `dbname`. Then:
    - if the cache database (CACHE_DB_PREFIX + digest) is confirmed and is not
      about to be garbage collected, `dbname` is cloned from it (fast path);
    - otherwise the cache lock is acquired, `dbname` is dropped, re-created and
      filled by running the `rest` command (slow path), and then the cache
      database is cloned from `dbname` and marked as confirmed.

    Args:
        sync_db: the synchronization database used for locks and the registry.
        dbname: the name of the database to restore.
        digest: the suffix of the cache database name.
        because: a human-readable explanation printed after a cache hit.
        rest: the restoration command; a single-element list is run via shell.

    Raises:
        subprocess.CalledProcessError: if psql or the restoration command fails.
        UserException: if the cache lock was lost during the slow restoration.
    """
    db = WorkDB(name=dbname, sync_db=sync_db)
    cache_db = WorkDB(name=f"{CACHE_DB_PREFIX}{digest}", sync_db=sync_db)

    unique_printer = UniquePrinter()
    while True:
        with db.try_lock() as (locked, lock_owner):
            if not locked:
                unique_printer.print_stderr(
                    f"Someone else - {lock_owner} - is updating the database {db.name} right now. Waiting for them to finish.",
                )
                # Jittered sleep, then retry locking from scratch.
                time.sleep(LOCK_RETRY_SEC * (1 + LOCK_JITTER * random.random()))
                continue

            # Exclusively locked db now.
            unique_printer = UniquePrinter()
            while True:
                # We don't have the notion of shared (read) locks, we only have
                # exclusive (write) locking for simplicity and performance. To
                # simulate shared locking behavior ("many readers, one writer"),
                # we use the assumption that cache database, once it's
                # confirmed, is immutable. So we can safely restore from it
                # without acquiring a lock if there is no gc happening anytime
                # soon (i.e. it was touched not too long ago).
                if cache_db.is_confirmed_and_not_for_gc_soon():
                    with Measure(
                        f"Restoring {db.name} from cache template {cache_db.name} (fast)..."
                    ):
                        db.clone_from(template=cache_db)
                    print_stderr(f"Hint: restored from the cache, {because}")
                    return

                with cache_db.try_lock() as (locked, lock_owner):
                    if not locked:
                        unique_printer.print_stderr(
                            f"Someone else - {lock_owner} - is updating the cache {cache_db.name} right now. Waiting for them to finish.",
                        )
                        time.sleep(LOCK_RETRY_SEC * (1 + LOCK_JITTER * random.random()))
                        continue

                    # Someone else could have confirmed the cache while we were
                    # acquiring the lock; if so, go back to the fast path.
                    if cache_db.is_confirmed_and_not_for_gc_soon():
                        continue

                    # Exclusively locked cache_db now.
                    with Measure(
                        f"Restoring {db.name} from passed shell command (slow)..."
                    ):
                        db.ensure_absent()
                        db.ensure_exists()
                        subprocess.check_call(
                            rest,
                            env={
                                **os.environ,
                                "PGDATABASE": dbname,
                                "PGCONNECT_TIMEOUT": str(CONNECT_TIMEOUT_SEC),
                            },
                            # A single-element command is treated as a shell
                            # command line.
                            shell=len(rest) == 1,
                        )

                    # Re-check that we still hold the cache lock before writing
                    # to the cache database (raises if the lock was lost).
                    locked()
                    cache_db.clone_from(template=db)
                    cache_db.confirm()
                    return


#
# Garbage collection: removes databases touched too long ago.
#
def do_gc(
    *,
    sync_db: SyncDB,
    cache_db_max_age: int | None,
    restore_db_max_age: int | None,
):
    """Drop the registered databases which were not touched for too long.

    Cache databases are compared against `cache_db_max_age`, all other
    databases against `restore_db_max_age`; a missing (or zero) max age
    disables removal for that kind. Only one process runs garbage collection
    at a time; the others skip it. Databases whose lock is held by someone
    else are skipped as well.
    """
    candidates = sync_db.db_list_with_touch_age()
    if not candidates:
        return
    with sync_db.try_lock(name=DO_GC_LOCK_NAME) as (gc_locked, gc_owner):
        if not gc_locked:
            print_stderr(
                f"Someone else - {gc_owner} - is running garbage collection now, so skipping.",
            )
            return
        for candidate, listed_age in candidates:
            if candidate.name.startswith(CACHE_DB_PREFIX):
                max_age = cache_db_max_age
            else:
                max_age = restore_db_max_age
            if not max_age or listed_age < max_age:
                continue
            with candidate.try_lock() as (db_locked, db_owner):
                if not db_locked:
                    print_stderr(
                        f"Someone else - {db_owner} - is holding {candidate.name} lock now, so skipping cache removal."
                    )
                    continue
                # The age must be re-read under the lock: do_restore() could
                # have completed a restoration after the list above was loaded
                # and before the lock was taken.
                reloaded = sync_db.db_list_with_touch_age(name=candidate.name)
                if not reloaded:
                    print_stderr(
                        f"Someone else deleted {candidate.name} already, so skipping cache removal."
                    )
                    continue
                current_db, current_age = reloaded[0]
                if current_age < max_age:
                    print_stderr(
                        f"Someone else touched {current_db.name} recently, so skipping cache removal."
                    )
                    continue
                if current_db.name.startswith(CACHE_DB_PREFIX):
                    suffix = ""
                else:
                    suffix = f" (and {current_db.name}~* if any)"
                with Measure(
                    f"Removing old database {current_db.name}{suffix} with age {current_age} sec...",
                    defer=True,
                ):
                    current_db.ensure_absent()


#
# Builds a combined digest of all files content and the passed cache name. If
# only cache_name is passed, and it's a short alphanumeric string, then it's
# returned directly.
#
def build_digest(
    *,
    cache_name: str,
    deps_files: str,
    deps_files_cmd: str,
) -> tuple[str, str]:
    """Build the cache digest and the explanation of a cache hit.

    If no dependency files are resolved and `cache_name` consists of 1..32
    characters from [a-zA-Z0-9_], the cache name itself is used as the digest.
    Otherwise, the digest is the first 16 hex characters of SHA-256 over the
    cache name followed by the name and the content of every resolved file.

    The files are hashed as raw bytes and in chunks, so dumps with bytes which
    are not valid in the locale encoding do not fail, and large dumps are not
    loaded into memory at once. For UTF-8 files without carriage returns, the
    digest is the same as when hashing the decoded text re-encoded to UTF-8.

    Returns:
        A tuple (digest, because) where `because` is a human-readable text
        explaining why the cache matched.

    Raises:
        OSError: if one of the resolved files cannot be read.
        subprocess.CalledProcessError: if the deps command fails.
    """
    files = resolve_deps(deps_files=deps_files, deps_files_cmd=deps_files_cmd)
    # fullmatch() (unlike match() with "$") does not accept a trailing newline.
    if not files and re.fullmatch(r"[a-zA-Z0-9_]{1,32}", cache_name):
        because = (
            f"as the cache with static name {cache_name} has been saved "
            "in an earlier database migration job."
        )
        return (cache_name, because)
    hasher = hashlib.sha256(cache_name.encode())
    for file in files:
        with open(file, "rb") as f:
            hasher.update(file.encode())
            for chunk in iter(lambda: f.read(1024 * 1024), b""):
                hasher.update(chunk)
    digest = hasher.hexdigest()[0:16]
    # Only the first few file names are mentioned in the explanation.
    num_files = 4
    shown = ", ".join(os.path.basename(f) for f in files[0:num_files])
    ellipsis = ", ..." if len(files) > num_files else ""
    because = (
        f"as the contents of files [{shown}{ellipsis}] "
        "match those used in an earlier database migration job. "
        "Change the files to invalidate the cache."
    )
    return (digest, because)


#
# Resolves --deps-* arguments to the plain list of files.
#
def resolve_deps(
    *,
    deps_files: str,
    deps_files_cmd: str,
) -> list[str]:
    """Resolve the --deps-* arguments to a sorted list of unique file names.

    Args:
        deps_files: newline separated file names or wildcards. Lines with "*",
            "?" or "[" are expanded with a recursive glob; directories matched
            by a wildcard (e.g. by "dir/**") are skipped, since only files can
            be digested. Other lines are taken as is.
        deps_files_cmd: a shell command printing newline separated file names;
            not run if empty.

    Raises:
        subprocess.CalledProcessError: if the shell command fails.
    """
    files: set[str] = set()
    for raw_line in deps_files.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if any(ch in line for ch in "*?["):
            files.update(
                path
                for path in glob.glob(line, recursive=True)
                if not os.path.isdir(path)
            )
        else:
            files.add(line)
    if deps_files_cmd:
        output = subprocess.check_output(deps_files_cmd, text=True, shell=True)
        files.update(v.strip() for v in output.splitlines() if v.strip())
    return sorted(files)


#
# Converts a time interval to seconds.
#
def to_sec(text: str | None) -> int | None:
    """Convert a time interval like "42", "42s", "1h10m" or "3w" to seconds.

    The text is a sequence of "<number><unit>" tokens, optionally separated by
    whitespace; the unit is one of s, m, h, d, w (case-insensitive) and
    defaults to seconds. The whole text must consist of such tokens, so inputs
    like "42x" are rejected instead of being silently truncated.

    Returns:
        The total number of seconds, or None if the text is None or empty.

    Raises:
        UserException: if the text is not a valid time interval.
    """
    if not text:
        return None
    specs = {
        "s": 1,
        "m": 60,
        "h": 3600,
        "d": 3600 * 24,
        "w": 3600 * 24 * 7,
    }
    token_re = re.compile(rf"\s*(\d+)\s*([{''.join(specs)}]?)")
    remaining = text.lower().strip()
    sec = 0
    pos = 0
    # Consume the tokens one by one from left to right; stop at the first
    # position where no token starts.
    while pos < len(remaining):
        token = token_re.match(remaining, pos)
        if token is None:
            break
        value, unit = token.groups()
        sec += specs[unit or "s"] * int(value)
        pos = token.end()
    if pos == 0 or pos < len(remaining):
        raise UserException(
            f'Invalid time interval: "{text}"; valid examples: 42, 42s, 10m, 1h10m, 2d, 3w.'
        )
    return sec


#
# Prints a message to stderr.
#
def print_stderr(msg: str):
    """Write the message followed by a newline to the standard error stream."""
    sys.stderr.write(f"{msg}\n")


#
# Custom user exceptions.
#
class UserException(Exception):
    """An error whose message is meant to be shown to the user as is."""


#
# Some generic database.
#
class DB:
    """A PostgreSQL database identified by its name.

    All statements are executed by spawning the psql CLI (see psql() below).
    """

    name: str

    def __init__(self, *, name: str):
        self.name = name

    def __repr__(self) -> str:
        return f"DB({self.name})"

    def global_db(self) -> DB:
        """Return the database to connect to for statements about this one."""
        # In case we are calling ensure_absent() for the SYNC_DB itself (it is
        # not a no-op only during the very 1st run of ci-pg-restore, and is a
        # no-op further), use template1 as a DB to connect to. Otherwise, use
        # SYNC_DB (since template1 must have 0 connections: otherwise, no-one is
        # able to clone from it).
        return DB(name="template1") if self.name == SYNC_DB else DB(name=SYNC_DB)

    def ensure_exists(self):
        """Create the database unless pg_database already lists it."""
        self.global_db().psql(
            "CREATE DATABASE %",
            self.name,
            idempotence_check=[
                "SELECT true FROM pg_database WHERE datname=?",
                self.name,
            ],
        )

    def ensure_absent(self):
        """Drop the database and, unless it is a cache database, also all the
        databases whose names start with this database name followed by "~".

        Before dropping each database, its prepared transactions are rolled
        back, and its subscriptions and replication slots are dropped.
        """
        # Cache databases never have "~" companions, so skip the lookup.
        extra_names = (
            []
            if self.name.startswith(CACHE_DB_PREFIX)
            else self.global_db().psql(
                "SELECT datname FROM pg_database WHERE starts_with(datname, ?)",
                self.name + "~",
            )
        )
        for name in [self.name, *[row[0] for row in extra_names if row[0]]]:
            # See https://www.postgresql.org/docs/current/sql-dropdatabase.html
            db = DB(name=name)

            prepared_xacts = self.global_db().psql(
                "SELECT gid FROM pg_prepared_xacts WHERE database=?",
                name,
            )
            for xact in [row[0] for row in prepared_xacts]:
                db.psql("ROLLBACK PREPARED ?", xact)

            subscriptions = self.global_db().psql(
                """
                SELECT subname, datname
                FROM pg_subscription
                JOIN pg_database ON pg_database.oid=subdbid
                WHERE datname=?
                """,
                name,
            )
            for sub in [row[0] for row in subscriptions]:
                # The subscription is detached from its slot before dropping.
                db.psql(
                    """
                    SET statement_timeout TO '20s';
                    ALTER SUBSCRIPTION % DISABLE;
                    ALTER SUBSCRIPTION % SET (slot_name=NONE);
                    DROP SUBSCRIPTION % CASCADE;
                    """,
                    sub,
                    sub,
                    sub,
                )

            replication_slots = self.global_db().psql(
                "SELECT slot_name FROM pg_replication_slots WHERE database=?",
                name,
            )
            for slot in [row[0] for row in replication_slots]:
                # Terminate the backend using the slot (if any), then drop it.
                db.psql(
                    """
                    SELECT pg_terminate_backend(active_pid)
                        FROM pg_replication_slots
                        WHERE slot_name=? AND active_pid IS NOT NULL;
                    SELECT pg_drop_replication_slot(?);
                    """,
                    slot,
                    slot,
                )

            self.global_db().psql(
                "DROP DATABASE IF EXISTS % WITH (FORCE)",
                name,
                idempotence_check=[
                    "SELECT count(1)=0 FROM pg_database WHERE datname=?",
                    name,
                ],
            )

    # Runs a psql command against the specified database.
    # - Replaces ? with the quoted literal values from the args list.
    # - Replaces % with the quoted identifier values from the args list.
    # - If idempotence_check is passed, then it is first called. If it returns
    #   true or 1, then the query is skipped.
    # - In case of any error, runs idempotence_check query, and if it returns
    #   one row with one column with 1 or true, then treats the original query
    #   as successful.
    # - Returns raw psql response as a list of rows, each row is a list of
    #   string values (or if idempotence_check succeeded after an error, an
    #   empty list).
    # - NULLs in the returned data are returned as empty strings.
    # - Raises subprocess.CalledProcessError if psql fails and the error is not
    #   compensated by idempotence_check.
    def psql(
        self,
        sql: str,
        *args: str,
        idempotence_check: list[str] | None = None,
    ) -> list[list[str]]:
        if idempotence_check:
            res = self.psql(*idempotence_check)
            if res == [["1"]] or res == [["t"]]:
                return []
        args_list = list(args)
        # Drop the leading blank line of a triple-quoted SQL and its common
        # indentation.
        sql = textwrap.dedent(re.sub(r"^ *\n", "", sql)).rstrip()
        try:
            res = subprocess.check_output(
                [
                    "psql",
                    "-d",
                    self.name,
                    "-vON_ERROR_STOP=1",
                    "-Pnull=",
                    "-tAq",
                    "-c",
                    # Placeholders consume the args strictly from left to right.
                    re.sub(
                        r"[?%]",
                        lambda m: (
                            self.quote_literal(args_list.pop(0))
                            if m.group() == "?"
                            else self.quote_ident(args_list.pop(0))
                        ),
                        sql,
                    ),
                ],
                text=True,
                stderr=subprocess.PIPE,
                env={**os.environ, "PGCONNECT_TIMEOUT": str(CONNECT_TIMEOUT_SEC)},
            )
            # - no rows         -> ""          -> []
            # - one empty value -> "\n"        -> [[""]]
            # - non-empty value -> "one|two\n" -> [["one", "two"]]
            return [line.split("|") for line in res.splitlines()]
        except subprocess.CalledProcessError as e:
            if not idempotence_check:
                raise
            try:
                res = self.psql(idempotence_check[0], *idempotence_check[1:])
                if res != [["1"]] and res != [["t"]]:
                    raise e
                return []
            except subprocess.CalledProcessError:
                # Always report the original error, not the one of the check.
                raise e from None

    @staticmethod
    def quote_ident(s: str) -> str:
        """Quote a string as an SQL identifier (doubles the double quotes)."""
        return '"' + s.replace('"', '""') + '"'

    @staticmethod
    def quote_literal(s: str) -> str:
        """Quote a string as an SQL literal (doubles the single quotes)."""
        return "'" + s.replace("'", "''") + "'"


#
# A database which is a target for restoration, or a restoration cache. Such
# databases are subject for garbage collection, so right before they are used,
# we touch them in sync_db.
#
class WorkDB(DB):
    """A restoration target or a cache "template" database.

    Its lifecycle is mirrored in the registry of sync_db: it is touched when
    created or cloned, and deregistered when dropped.
    """

    sync_db: SyncDB

    def __init__(self, *, name: str, sync_db: SyncDB):
        super().__init__(name=name)
        self.sync_db = sync_db

    def ensure_exists(self):
        """Touch the database in the registry, then create it if missing."""
        self.sync_db.db_touch(name=self.name)
        super().ensure_exists()

    def ensure_absent(self):
        """Drop the database, then remove it from the registry."""
        super().ensure_absent()
        self.sync_db.db_deregister(name=self.name)

    def confirm(self):
        """Mark the database as confirmed in the registry."""
        self.sync_db.db_confirm(name=self.name)

    def is_confirmed_and_not_for_gc_soon(self) -> bool:
        """Tell whether the database is confirmed and was touched recently
        enough (as decided by sync_db)."""
        return self.sync_db.db_is_confirmed_and_not_for_gc_soon(name=self.name)

    def clone_from(self, *, template: WorkDB):
        """Re-create this database as a copy of the template database.

        Both databases are touched in the registry, and all the connections to
        the template are terminated before cloning.
        """
        self.ensure_absent()
        self.sync_db.db_touch(name=self.name)
        self.sync_db.db_touch(name=template.name)
        # It may be pgbouncer who keeps the connection open after the
        # restoration shell script has finished.
        self.global_db().psql(
            "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname=?",
            template.name,
        )
        self.global_db().psql(
            "CREATE DATABASE % WITH TEMPLATE = %",
            self.name,
            template.name,
        )

    def try_lock(self):
        """Return a context manager locking this database name in sync_db."""
        return self.sync_db.try_lock(name=self.name)


#
# A database used for internal purposes, to sync against race conditions in
# concurrent processes. Exposes the notion of exclusive lock. Only one process
# will be able to hold that lock and modify the cache "template" database. Also,
# it holds a registry for all WorkDBs to be able to garbage collect them.
#
class SyncDB(DB):
    """The internal database with the lock table and the registry of WorkDBs.

    Creating an instance starts a daemon thread which periodically refreshes
    all the locks held by this instance, so they do not expire while the
    process is alive.
    """

    # Identity of this process in the lock and registry tables; built once,
    # when the class is defined, from the host name, the process id and a
    # random suffix.
    owner: str = f"{socket.gethostname()}/{os.getpid()}/{uuid.uuid4().hex[0:8]}"
    # Names of the locks currently held by this instance. It is created per
    # instance (in __init__) and is accessed from both the main thread and the
    # refresh thread.
    locked_names: set[str]
    cache_db_max_age: int | None
    thread: threading.Thread | None = None

    def __init__(self, *, name: str, cache_db_max_age: int | None):
        super().__init__(name=name)
        self.cache_db_max_age = cache_db_max_age
        # Must be assigned before the refresh thread starts reading it.
        self.locked_names = set()
        self.thread = threading.Thread(
            target=self._lock_refresh_loop,
            daemon=True,
        )
        self.thread.start()

    def ensure_exists(self):
        """Create the database and its lock and registry tables if missing."""
        super().ensure_exists()
        self.psql(
            """
            CREATE TABLE % (
                name varchar(1024) NOT NULL UNIQUE,
                owner varchar(1024) NOT NULL,
                acquired_at timestamptz NOT NULL
            );
            CREATE TABLE % (
                name varchar(1024) NOT NULL UNIQUE,
                created_by varchar(1024) NOT NULL,
                created_at timestamptz NOT NULL,
                touched_by varchar(1024) NOT NULL,
                touched_at timestamptz NOT NULL,
                confirmed_by varchar(1024),
                confirmed_at timestamptz
            );
            """,
            LOCK_TABLE,
            DB_TABLE,
            idempotence_check=["SELECT true FROM pg_class WHERE relname=?", LOCK_TABLE],
        )

    def db_touch(self, *, name: str):
        """Register the database (if not yet registered) and refresh its touch
        time and toucher."""
        self.psql(
            """
            INSERT INTO % (name, created_by, created_at, touched_by, touched_at)
                VALUES (?, ?, now(), ?, now())
            ON CONFLICT (name) DO
                UPDATE SET touched_by=EXCLUDED.touched_by, touched_at=now()
                WHERE %.name=EXCLUDED.name
            """,
            DB_TABLE,
            name,
            self.owner,
            self.owner,
            DB_TABLE,
        )

    def db_deregister(self, *, name: str):
        """Remove the database from the registry."""
        self.psql("DELETE FROM % WHERE name=?", DB_TABLE, name)

    def db_confirm(self, *, name: str):
        """Mark the database as confirmed (registering it if needed); this also
        refreshes its touch time."""
        self.psql(
            """
            INSERT INTO % (name, created_by, created_at, touched_by, touched_at, confirmed_by, confirmed_at)
                VALUES (?, ?, now(), ?, now(), ?, now())
            ON CONFLICT (name) DO
                UPDATE SET touched_by=EXCLUDED.touched_by, touched_at=EXCLUDED.touched_at,
                    confirmed_by=EXCLUDED.confirmed_by, confirmed_at=EXCLUDED.confirmed_at
                WHERE %.name=EXCLUDED.name
            """,
            DB_TABLE,
            name,
            self.owner,
            self.owner,
            self.owner,
            DB_TABLE,
        )

    # If the database is touched too long ago and, thus, is too close to being
    # deleted by garbage collection, we don't treat it as confirmed, to
    # eliminate races with the garbage collector.
    def db_is_confirmed_and_not_for_gc_soon(self, *, name: str) -> bool:
        # Without a max age there is no garbage collection of caches, so any
        # practically reachable touch age is fine.
        if self.cache_db_max_age:
            max_touch_age = f"{self.cache_db_max_age - MAX_CLONE_TIME_SEC} seconds"
        else:
            max_touch_age = "10 years"
        res = self.psql(
            """
            SELECT confirmed_at IS NOT NULL
            FROM %
            WHERE name=? AND (now() - touched_at) < interval ?
            """,
            DB_TABLE,
            name,
            max_touch_age,
        )
        return res == [["t"]]

    def db_list_with_touch_age(
        self, *, name: str | None = None
    ) -> list[tuple[WorkDB, int]]:
        """Return the registered databases (or only the one with the passed
        name) with the number of seconds since their last touch, ordered by
        name."""
        res = self.psql(
            f"""
            SELECT name, EXTRACT(EPOCH FROM (now() - touched_at))::integer
            FROM % {'WHERE name=?' if name else ''}
            ORDER BY name
            """,
            DB_TABLE,
            *([name] if name else []),
        )
        return [(WorkDB(name=row[0], sync_db=self), int(row[1])) for row in res]

    # Returns a with-statement Locker context manager which locks (if it can)
    # the specified name and unlocks it on cleanup (if it was locked). The lock
    # is extended from time to time in a background thread. The context return
    # object is a tuple: (locked, lock_owner) where locked is either None
    # (locking did not succeed) or a callable to recheck the lock and fail if it
    # was lost for some reason.
    def try_lock(self, *, name: str):
        class Locker:
            @staticmethod
            def __enter__() -> tuple[None | typing.Callable[[], None], str]:
                locked, lock_owner = self._lock_try(name=name)
                return (
                    (lambda: self._lock_assert_relocked(name=name)) if locked else None,
                    lock_owner,
                )

            @staticmethod
            def __exit__(*_: typing.Any):
                self._lock_release(name=name)

        return Locker()

    def _lock_try(self, *, name: str) -> tuple[bool, str]:
        """Try to acquire the lock; return (acquired, current lock owner)."""
        # Expired locks of dead processes are removed by whoever locks next.
        self.psql(
            "DELETE FROM % WHERE acquired_at < now() - interval ?",
            LOCK_TABLE,
            f"{LOCK_EXPIRATION_SEC} seconds",
        )
        # On conflict, the no-op UPDATE makes RETURNING yield the owner of the
        # existing row.
        lock_owner = self.psql(
            """
            INSERT INTO % (name, owner, acquired_at)
                VALUES (?, ?, now())
            ON CONFLICT (name) DO
                UPDATE SET name=EXCLUDED.name WHERE %.name=EXCLUDED.name
            RETURNING owner
            """,
            LOCK_TABLE,
            name,
            self.owner,
            LOCK_TABLE,
        )[0][0]
        if lock_owner != self.owner:
            return (False, lock_owner)
        self.locked_names.add(name)
        return (True, lock_owner)

    def _lock_release(self, *, name: str):
        """Release the lock if this instance holds it; otherwise do nothing."""
        if name not in self.locked_names:
            return
        # Stop refreshing BEFORE deleting the row: otherwise the refresh thread
        # may see the deleted row while the name is still tracked, report a
        # bogus lock loss and remove the name concurrently with this method.
        self.locked_names.discard(name)
        self.psql(
            "DELETE FROM % WHERE name=? AND owner=?",
            LOCK_TABLE,
            name,
            self.owner,
        )

    def _lock_assert_relocked(self, *, name: str):
        """Extend the lock; raise UserException if it is not ours anymore."""
        res = self.psql(
            "UPDATE % SET acquired_at=now() WHERE name=? AND owner=? RETURNING true",
            LOCK_TABLE,
            name,
            self.owner,
        )
        if res != [["t"]] and name in self.locked_names:
            # discard(): the other thread may have untracked the name already.
            self.locked_names.discard(name)
            raise UserException("The process has suddenly lost the lock.")

    def _lock_refresh_loop(self):
        """Body of the background thread: extend all the held locks forever."""
        while True:
            # Iterate over a snapshot: the set is modified by the main thread.
            for name in list(self.locked_names):
                try:
                    self._lock_assert_relocked(name=name)
                except Exception as e:
                    print_stderr(f"Error while re-locking {name}: {e}")
            time.sleep(LOCK_REFRESH_SEC)


#
# A helper class for ArgumentParser.
#
class ParagraphFormatter(argparse.HelpFormatter):
    """A help formatter which keeps blank-line separated paragraphs apart."""

    def _fill_text(self, text: str, width: int, indent: str) -> str:
        # Drop the leading blank line of a triple-quoted text, remove the
        # common indentation, then wrap and indent every paragraph on its own.
        body = textwrap.dedent(re.sub(r"^ *\n", "", text))
        wrapped: list[str] = []
        for paragraph in body.split("\n\n"):
            filled = textwrap.fill(paragraph, width)
            wrapped.append(textwrap.indent(filled, indent))
        return "\n\n".join(wrapped)


#
# A helper context manager to print measured time.
#
class Measure:
    """A context manager which reports the duration of its body to stderr.

    By default, the `pre` message is printed on enter and "...succeeded" or
    "...failed" with the timing on exit. With defer=True, nothing is printed
    on enter, and a single line with the message and the timing is printed on
    exit.
    """

    pre: str
    defer: bool
    start: float

    def __init__(self, pre: str, *, defer: bool = False):
        self.pre, self.defer = pre, defer

    def __enter__(self):
        self.start = time.monotonic()
        if self.defer:
            return self
        print_stderr(self.pre)
        return self

    def __exit__(self, e: typing.Any, *_: typing.Any):
        elapsed = time.monotonic() - self.start
        outcome = "failed" if e else "succeeded"
        suffix = f"{outcome} in {elapsed:.2f} sec."
        if not self.defer:
            print_stderr(f"...{suffix}")
            return
        if e:
            # On failure, remind what was running (with its 1st letter lowered).
            lead = f"Was running: {self.pre[0].lower()}{self.pre[1:]}"
        else:
            lead = self.pre
        print_stderr(f"{lead} {suffix}")


#
# Prints a potentially repetitive message only once.
#
class UniquePrinter:
    """Prints a message to stderr unless it repeats the previous message."""

    prev_msg: str | None = None

    def print_stderr(self, msg: str):
        if msg == self.prev_msg:
            return
        print_stderr(msg)
        self.prev_msg = msg


#
# Script entry point.
#
if __name__ == "__main__":
    # Exit codes: 1 for an interruption or a user error, 2 for a failed
    # external command (psql, the deps command or the restoration command).
    try:
        sys.exit(main())
    except KeyboardInterrupt:
        sys.exit(1)
    except UserException as e:
        print_stderr(f"Error: {e}")
        sys.exit(1)
    except subprocess.CalledProcessError as e:
        # The command is a plain string when it was passed to the shell as a
        # string (--deps-files-cmd); shlex.join() would split it into
        # characters, so it is printed as is.
        cmd = e.cmd if isinstance(e.cmd, str) else shlex.join(e.cmd)
        print_stderr(
            f"$ {cmd.strip()}\n"
            + textwrap.indent(
                f"Error: command returned status {e.returncode}."
                + (f"\n{e.stdout}" if e.stdout else "")
                + (f"\n{e.stderr}" if e.stderr else ""),
                prefix="  ",
            ),
        )
        sys.exit(2)
