"""
Eval Loop Runner — generated by aiwg nlp new
Pattern: simple-chain

Run: python eval/eval.py [--cases eval/cases.jsonl] [--threshold 0.85] [--max-attempts 3]
"""

from __future__ import annotations

import argparse
import json
import sys
import time
import uuid
from datetime import datetime, timezone
from pathlib import Path

import anthropic

# Directory anchors derived from this file's location.
_EVAL_DIR = Path(__file__).parent
_PROJECT_ROOT = _EVAL_DIR.parent

# Put <project>/src first on the import path; adjust if your pipeline lives elsewhere.
sys.path.insert(0, str(_PROJECT_ROOT / "src"))

# Where the evaluator prompt is read from, and the files this runner writes.
PROMPTS_DIR = _PROJECT_ROOT / "prompts"
RESULTS_FILE = _EVAL_DIR / "results.jsonl"
SUMMARY_FILE = _EVAL_DIR / "summary.json"

# Model name and response token cap used for evaluator calls.
EVAL_MODEL = "claude-haiku-4-5"
MAX_TOKENS_EVAL = 256


def load_evaluator_prompt() -> tuple[str, str]:
    """Load the evaluator prompt and split it into system and user parts.

    Reads ``evaluator.prompt.md`` from ``PROMPTS_DIR``. Strict isolation — only
    the evaluator prompt file is read, no generator context.

    A leading ``---`` frontmatter block is dropped when it is closed by a
    second ``---``. Lines following a ``## System`` heading are collected into
    the system prompt and lines following a ``## User`` heading into the user
    template; anything before the first of those headings is ignored.

    Returns:
        A ``(system, user)`` tuple, each stripped of surrounding whitespace.
        A section that does not appear in the file is returned as ``""``.

    Raises:
        FileNotFoundError: If the evaluator prompt file does not exist.
    """
    path = PROMPTS_DIR / "evaluator.prompt.md"
    if not path.exists():
        raise FileNotFoundError(
            f"Evaluator prompt not found at {path}. "
            "Create prompts/evaluator.prompt.md (separate from generator)."
        )
    content = path.read_text(encoding="utf-8")

    # Strip frontmatter, but only when it is actually terminated: an opening
    # "---" with no closing one splits into two parts, and unpacking that into
    # three names would raise an unhelpful ValueError.
    if content.startswith("---"):
        parts = content.split("---", 2)
        if len(parts) == 3:
            content = parts[2]

    headings = {"## System": "system", "## User": "user"}
    sections: dict[str, list[str]] = {"system": [], "user": []}
    current = None
    for line in content.splitlines():
        heading = headings.get(line.strip())
        if heading is not None:
            # A heading line switches the active section and is not kept.
            current = heading
        elif current is not None:
            sections[current].append(line)

    system = "\n".join(sections["system"]).strip()
    user = "\n".join(sections["user"]).strip()
    return system, user


def evaluate_output(
    client: anthropic.Anthropic,
    eval_system: str,
    eval_user_template: str,
    input_text: str,
    output: str,
) -> dict:
    """Run the isolated evaluator. ONLY passes input and output — no generator context.

    Args:
        client: Client whose ``messages.create`` is called to reach the evaluator.
        eval_system: System prompt for the evaluator.
        eval_user_template: User message template; every ``{{input}}`` and
            ``{{output}}`` placeholder in it is filled in.
        input_text: The case input, substituted for ``{{input}}``.
        output: The pipeline output, substituted for ``{{output}}``. A value
            that is not a string is JSON-encoded first.

    Returns:
        The evaluator's reply parsed as JSON, after removing a surrounding
        Markdown code fence if the reply starts with one.

    Raises:
        json.JSONDecodeError: If the reply (minus any fence) is not valid JSON.
    """
    output_text = output if isinstance(output, str) else json.dumps(output)

    # Fill both placeholders in a single pass over the *template*: split on
    # "{{input}}", fill "{{output}}" in each template piece, then rejoin with
    # the input text. Chaining two .replace() calls would rescan the inserted
    # input, so an input containing the literal "{{output}}" would have the
    # pipeline output spliced into it.
    user = input_text.join(
        piece.replace("{{output}}", output_text)
        for piece in eval_user_template.split("{{input}}")
    )

    response = client.messages.create(
        model=EVAL_MODEL,
        max_tokens=MAX_TOKENS_EVAL,
        system=eval_system,
        messages=[{"role": "user", "content": user}],
    )
    raw = response.content[0].text.strip()
    if raw.startswith("```"):
        # Drop the opening fence line (which may carry a language tag) and the
        # closing fence line when present.
        lines = raw.splitlines()
        body = lines[1:-1] if lines[-1].strip() == "```" else lines[1:]
        raw = "\n".join(body)
    return json.loads(raw)


def run_eval(cases_path: Path, threshold: float, max_attempts: int) -> None:
    """Run every case through the pipeline, grade it, and record the results.

    Each case is generated and evaluated up to ``max_attempts`` times, stopping
    at the first attempt the evaluator marks as ``pass``; the last attempt is
    the one recorded. One record per case is appended to ``RESULTS_FILE`` and a
    run summary overwrites ``SUMMARY_FILE``.

    Args:
        cases_path: JSONL file with one case per non-blank line; each case
            must provide ``id`` and ``input``.
        threshold: Minimum overall pass rate. Whether an individual case
            passes is decided by the evaluator's ``pass`` field, not by this
            value; the threshold is only copied into each record and compared
            against the overall pass rate.
        max_attempts: Maximum generation attempts per case; must be >= 1.

    Raises:
        ValueError: If ``max_attempts`` is less than 1.
        SystemExit: With code 1 when the pass rate is below ``threshold``.
    """
    # With zero attempts there would be no evaluation result to report on.
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be at least 1, got {max_attempts}")

    from pipeline import run as pipeline_run  # type: ignore[import]

    client = anthropic.Anthropic()
    eval_system, eval_user_template = load_evaluator_prompt()

    cases = [
        json.loads(line)
        for line in cases_path.read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]
    # Resolve before walking up: for a relative path such as "eval/cases.jsonl"
    # the grandparent is ".", whose name is the empty string.
    project_name = cases_path.resolve().parent.parent.name
    run_id = f"eval-{project_name}-{uuid.uuid4().hex[:8]}"
    now = datetime.now(timezone.utc).isoformat()

    results = []
    passed = 0

    for case in cases:
        case_id = case["id"]
        input_text = case["input"]
        print(f"  {case_id}...", end=" ", flush=True)

        attempts = 0
        last_result = None
        last_output = None

        for attempt in range(1, max_attempts + 1):
            attempts = attempt
            output = pipeline_run(input_text)
            eval_result = evaluate_output(
                client, eval_system, eval_user_template, input_text, output
            )
            last_result = eval_result
            last_output = output

            # Stop retrying as soon as the evaluator accepts an output.
            if eval_result.get("pass", False):
                break

        score = last_result.get("score", 0.0)
        did_pass = last_result.get("pass", False)
        if did_pass:
            passed += 1
            print(f"✓ ({score:.2f})")
        else:
            # The evaluator may send an explicit null for feedback; guard the slice.
            preview = str(last_result.get("feedback") or "")[:60]
            print(f"✗ ({score:.2f}) — {preview}")

        record = {
            "version": "1.0.0",
            "run_id": run_id,
            "case_id": case_id,
            "input": input_text,
            "output": last_output,
            "score": score,
            "pass": did_pass,
            "pass_threshold": threshold,
            "feedback": last_result.get("feedback", ""),
            "rubric_scores": last_result.get("rubric_scores", {}),
            "failure_category": last_result.get("failure_category"),
            "suggested_fix": last_result.get("suggested_fix"),
            "attempts": attempts,
            "evaluated_at": datetime.now(timezone.utc).isoformat(),
            "contamination_warning": False,
        }
        results.append(record)

    # Write results (appended, so earlier runs are kept)
    with RESULTS_FILE.open("a", encoding="utf-8") as f:
        f.writelines(json.dumps(r) + "\n" for r in results)

    pass_rate = passed / len(cases) if cases else 0.0
    print(f"\n  {passed}/{len(cases)} passed ({pass_rate:.1%})")

    # Write summary; "top" failures are the first three failing cases in file order.
    top_failures = [r for r in results if not r["pass"]][:3]
    summary = {
        "run_id": run_id,
        "evaluated_at": now,
        "total_cases": len(cases),
        "passed": passed,
        "failed": len(cases) - passed,
        "pass_rate": pass_rate,
        "top_failures": [
            {"case_id": r["case_id"], "score": r["score"], "feedback": r["feedback"]}
            for r in top_failures
        ],
    }
    SUMMARY_FILE.write_text(json.dumps(summary, indent=2), encoding="utf-8")

    if pass_rate < threshold:
        print(f"\n  ⚠ Pass rate {pass_rate:.1%} < threshold {threshold:.1%}")
        if top_failures:
            print("  Top recommended fix:")
            fix = top_failures[0].get("suggested_fix", "")
            if fix:
                print(f"    {fix}")
        sys.exit(1)
    else:
        print(f"  ✓ Pass rate {pass_rate:.1%} meets threshold {threshold:.1%}")


if __name__ == "__main__":
    # Command-line entry point: parse options, validate them, run the eval.
    parser = argparse.ArgumentParser(description="Run eval loop")
    parser.add_argument("--cases", default="eval/cases.jsonl", help="Test cases JSONL file")
    parser.add_argument("--threshold", type=float, default=0.85, help="Pass threshold (0.0-1.0)")
    parser.add_argument("--max-attempts", type=int, default=3, help="Max generation attempts per case")
    args = parser.parse_args()

    # Reject unusable values up front with a usage error: the threshold is
    # compared against a pass rate in [0, 1], and with fewer than one attempt
    # there is no evaluation result for the run to report.
    if not 0.0 <= args.threshold <= 1.0:
        parser.error(f"--threshold must be between 0.0 and 1.0, got {args.threshold}")
    if args.max_attempts < 1:
        parser.error(f"--max-attempts must be at least 1, got {args.max_attempts}")

    print(f"Running eval ({args.cases}, threshold={args.threshold})...")
    run_eval(Path(args.cases), args.threshold, args.max_attempts)
