#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = ["pydantic>=2.0"]
# ///
"""Score JSONL datasets with the reward function."""

from __future__ import annotations

import argparse
import statistics
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))
from dataset.schema import load_examples, output_items_to_text
from reward import score_expansion_detailed


def score_file(path: Path) -> tuple[int, int, list[float], dict[str, int]]:
    """Score every example in one JSONL dataset file.

    Each example's output items are flattened to text with
    ``output_items_to_text`` and scored against the example's query with
    ``score_expansion_detailed``.

    Args:
        path: Dataset file, loaded with ``load_examples``.

    Returns:
        ``(total, errors, scores, ratings)`` where *total* is the number of
        examples iterated, *errors* counts examples whose output text is empty
        (or is 1 when the file itself could not be loaded), *scores* holds the
        ``"percentage"`` value of every scored example, and *ratings* maps each
        ``"rating"`` value to the number of examples that received it.
    """
    try:
        examples = load_examples(path)
    except (ValueError, OSError) as e:
        # A malformed, unreadable or non-regular file (e.g. a directory matched
        # by a glob) is reported and skipped so the remaining files still get
        # scored instead of the whole run aborting.
        print(f"  Error loading {path}: {e}", file=sys.stderr)
        return 0, 1, [], {}

    total = 0
    errors = 0
    scores: list[float] = []
    ratings: dict[str, int] = {}

    for ex in examples:
        total += 1
        output_text = output_items_to_text(ex.output)
        if not output_text:
            # Nothing to score; count it as an error rather than a zero score.
            errors += 1
            continue

        detail = score_expansion_detailed(ex.query, output_text)
        scores.append(detail["percentage"])
        rating = detail["rating"]
        ratings[rating] = ratings.get(rating, 0) + 1

    return total, errors, scores, ratings


def _has_glob_magic(pattern: str) -> bool:
    """Return True if *pattern* contains a glob wildcard (``*``, ``?`` or ``[``)."""
    return any(ch in pattern for ch in "*?[")


def _resolve_files(patterns: list[str], root: Path) -> list[Path]:
    """Expand *patterns* into a sorted, de-duplicated list of regular files.

    Relative patterns and paths are resolved against *root*; absolute ones are
    used as given. A pattern or path that matches no regular file produces a
    warning on stderr instead of being dropped silently.
    """
    found: set[Path] = set()
    for pattern in patterns:
        candidate = Path(pattern)
        if _has_glob_magic(pattern):
            if candidate.is_absolute():
                # Path.glob() rejects absolute patterns, so glob the part below
                # the anchor ("/" or a drive root) starting from the anchor.
                base = Path(candidate.anchor)
                relative_pattern = str(candidate.relative_to(base))
            else:
                base = root
                relative_pattern = pattern
            matches = [p for p in base.glob(relative_pattern) if p.is_file()]
        else:
            # Joining an absolute path onto root yields the absolute path.
            path = root / candidate
            matches = [path] if path.is_file() else []

        if not matches:
            print(f"Warning: no files match {pattern}", file=sys.stderr)
        # A set ensures files hit by overlapping patterns are scored once.
        found.update(matches)
    return sorted(found)


def _print_summary(
    path: Path, errors: int, scores: list[float], ratings: dict[str, int]
) -> None:
    """Print the score statistics and rating histogram for one file."""
    if scores:
        avg = statistics.mean(scores)
        median = statistics.median(scores)
        min_score = min(scores)
        max_score = max(scores)
        above_70 = sum(1 for s in scores if s >= 70.0)
        pct_70 = above_70 / len(scores) * 100
        print(
            f"{path}: {len(scores)} scored, {errors} errors, "
            f"avg {avg:.1f}, median {median:.1f}, min {min_score:.1f}, "
            f"max {max_score:.1f}, >=70 {pct_70:.1f}%"
        )
    else:
        print(f"{path}: 0 scored, {errors} errors")

    if ratings:
        rating_parts = [f"{k}:{v}" for k, v in sorted(ratings.items())]
        print(f"  ratings: {', '.join(rating_parts)}")


def main(argv: list[str] | None = None) -> int:
    """Score every requested JSONL dataset and print per-file statistics.

    Args:
        argv: Command-line arguments excluding the program name. When None,
            argparse reads ``sys.argv[1:]``, so ``main()`` behaves as before.

    Returns:
        Process exit code: 1 when no files were found, otherwise 0.
    """
    parser = argparse.ArgumentParser(description="Score QMD datasets")
    parser.add_argument(
        "paths",
        nargs="*",
        default=["finetune/data/*.jsonl"],
        help="JSONL files or glob patterns (default: finetune/data/*.jsonl)",
    )
    args = parser.parse_args(argv)

    # Relative paths/patterns are resolved from the repository root
    # (Path(__file__).parent.parent.parent), not the current directory.
    repo_root = Path(__file__).parent.parent.parent
    files = _resolve_files(args.paths, repo_root)
    if not files:
        print("No files found to score.")
        return 1

    for path in files:
        _total, errors, scores, ratings = score_file(path)
        _print_summary(path, errors, scores, ratings)

    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())