#!/usr/bin/env python3
"""Voice transcription for the GoNext local worker (Audio support, task #21).

Reads a JSON job on stdin: {"audioPath": str, "model": str?, "language": str?}
("model" defaults to "whisper-large-v3-turbo" when omitted or empty).
Transcribes with mlx-whisper and prints {"text": "..."} on stdout.

Design notes:
- The browser sends a 16 kHz mono 16-bit PCM WAV, so we decode it with the
  stdlib `wave` module + numpy and hand mlx-whisper a float32 array. Passing an
  array (not a file path) skips mlx-whisper's ffmpeg-based loader, so NO ffmpeg
  install is required on the worker.
- The model is resolved by HF repo id ("mlx-community/<model>") from the HF hub
  cache (~/.cache/huggingface/hub, honoring HF_HOME/HF_HUB_CACHE). This is the
  same location the wizard installs to and the readiness probe checks.
- On failure we exit non-zero and print a machine-parseable reason token as the
  first stderr line: "lib-missing" | "model-missing" | "audio-error" | "error".
"""

import json
import sys
import wave


# Sample rate, in Hz, that decoded audio is resampled to before transcription.
TARGET_SR: int = 16_000


def fail(reason: str, detail: str = "") -> None:
    """Report a failure on stderr and terminate the process with status 1.

    The reason token is always written as the first stderr line so the caller
    can parse it; ``detail`` follows on its own line only when it is non-empty.
    This function never returns normally: it raises ``SystemExit(1)``.
    """
    lines = [reason, detail] if detail else [reason]
    for line in lines:
        sys.stderr.write(line + "\n")
    raise SystemExit(1)


def load_wav_mono_16k(path: str):
    """Decode a PCM WAV to a float32 mono array at 16 kHz without ffmpeg.

    Args:
        path: Filesystem path of the WAV file to read.

    Returns:
        A C-contiguous 1-D float32 numpy array, scaled to roughly [-1.0, 1.0),
        downmixed to mono and resampled to ``TARGET_SR`` when the file's rate
        differs.

    Raises:
        ValueError: If the sample width is not 1, 2, 3 or 4 bytes.
        Exception: Whatever ``wave.open`` raises for a missing, unreadable or
            non-PCM file propagates unchanged.
    """
    import numpy as np

    with wave.open(path, "rb") as w:
        n_channels = w.getnchannels()
        sampwidth = w.getsampwidth()
        framerate = w.getframerate()
        raw = w.readframes(w.getnframes())

    if sampwidth not in (1, 2, 3, 4):
        raise ValueError(f"unsupported WAV sample width: {sampwidth} bytes")

    # A truncated file can end in the middle of a frame. Drop the partial
    # frame so the complete ones still decode instead of failing in
    # frombuffer/reshape.
    frame_bytes = sampwidth * n_channels
    raw = raw[: len(raw) - len(raw) % frame_bytes]

    # WAV sample data is little-endian, so the multi-byte dtypes are spelled
    # with an explicit "<" rather than relying on the host byte order.
    if sampwidth == 1:
        # 8-bit WAV is unsigned with the zero level at 128.
        audio = (np.frombuffer(raw, dtype=np.uint8).astype(np.float32) - 128.0) / 128.0
    elif sampwidth == 2:
        audio = np.frombuffer(raw, dtype="<i2").astype(np.float32) / 32768.0
    elif sampwidth == 3:
        # numpy has no 24-bit integer type: place each 3-byte sample in the
        # upper three bytes of a little-endian int32 (i.e. value << 8), which
        # keeps the sign bit in place, then scale as a 32-bit sample.
        packed = np.frombuffer(raw, dtype=np.uint8).reshape(-1, 3)
        widened = np.zeros((packed.shape[0], 4), dtype=np.uint8)
        widened[:, 1:] = packed
        audio = widened.view("<i4").reshape(-1).astype(np.float32) / 2147483648.0
    else:
        audio = np.frombuffer(raw, dtype="<i4").astype(np.float32) / 2147483648.0

    if n_channels > 1:
        # Samples are interleaved per frame; average the channels to get mono.
        audio = audio.reshape(-1, n_channels).mean(axis=1)

    if framerate != TARGET_SR and audio.size > 0:
        # Linear resample; the browser already sends 16 kHz so this is a safety net.
        duration = audio.size / float(framerate)
        target_len = int(round(duration * TARGET_SR))
        if target_len > 0:
            # Both grids span [0, 1) so output sample j reads the source at
            # position j * len(src) / len(dst).
            src_x = np.linspace(0.0, 1.0, num=audio.size, endpoint=False)
            dst_x = np.linspace(0.0, 1.0, num=target_len, endpoint=False)
            audio = np.interp(dst_x, src_x, audio).astype(np.float32)

    return np.ascontiguousarray(audio, dtype=np.float32)


def main() -> None:
    """Run one transcription job: JSON job on stdin, {"text": ...} on stdout.

    Steps: parse the job, import mlx-whisper, decode the WAV, transcribe, and
    print the result. Every failure path goes through ``fail`` so that the
    first stderr line is one of the reason tokens "error", "lib-missing",
    "audio-error" or "model-missing" and the process exits non-zero.
    """
    try:
        payload = json.loads(sys.stdin.read() or "{}")
    except Exception as e:  # noqa: BLE001
        fail("error", f"invalid job json: {e}")
        return
    # Valid JSON is not necessarily an object. Without this check a list,
    # string, number or null job would raise AttributeError on .get() and
    # put a traceback on stderr instead of a reason token.
    if not isinstance(payload, dict):
        fail(
            "error",
            f"invalid job json: expected an object, got {type(payload).__name__}",
        )
        return

    audio_path = str(payload.get("audioPath") or "").strip()
    # Fall back to the default after stripping too, so a whitespace-only
    # model name does not produce the bare repo id "mlx-community/".
    model = str(payload.get("model") or "").strip() or "whisper-large-v3-turbo"
    language = str(payload.get("language") or "").strip() or None
    if not audio_path:
        fail("error", "missing audioPath")
        return

    # Import lazily so a missing or broken install is reported as
    # "lib-missing" rather than as an import-time traceback.
    try:
        import mlx_whisper  # noqa: F401
    except Exception as e:  # noqa: BLE001
        fail("lib-missing", str(e))
        return

    try:
        audio = load_wav_mono_16k(audio_path)
    except Exception as e:  # noqa: BLE001
        fail("audio-error", str(e))
        return

    repo = f"mlx-community/{model}"
    try:
        result = mlx_whisper.transcribe(
            audio,
            path_or_hf_repo=repo,
            task="transcribe",
            language=language,
        )
    except Exception as e:  # noqa: BLE001
        # Heuristic: classify the failure by substrings of the message.
        # NOTE(review): the network-flavoured markers presumably cover a hub
        # download attempted for a model that is not cached - confirm against
        # the errors mlx-whisper / huggingface_hub actually raise.
        missing_markers = (
            "not found",
            "repository",
            "couldn't find",
            "could not find",
            "no such file",
            "connection",
            "offline",
            "resolve",
            "404",
        )
        msg = str(e).lower()
        if any(marker in msg for marker in missing_markers):
            fail("model-missing", str(e))
        else:
            fail("error", str(e))
        return

    # Anything other than a dict result yields an empty transcript; a None
    # text is treated as empty rather than serialised as the string "None".
    text = str(result.get("text") or "").strip() if isinstance(result, dict) else ""
    sys.stdout.write(json.dumps({"text": text}))
    sys.stdout.flush()


# Script entry point: run the job only when executed directly, not on import.
if __name__ == "__main__":
    main()
