#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = ["pydantic>=2.0"]
# ///
"""Score JSONL datasets with the reward function."""
from __future__ import annotations
import argparse
import statistics
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from dataset.schema import load_examples, output_items_to_text
from reward import score_expansion_detailed
def score_file(path: Path) -> tuple[int, int, list[float], dict]:
    """Score every example in one dataset file.

    Args:
        path: File handed to ``load_examples``.

    Returns:
        A ``(total, errors, scores, ratings)`` tuple where ``total`` is the
        number of examples iterated, ``errors`` counts examples whose output
        rendered to empty text, ``scores`` holds the ``"percentage"`` value of
        each scored example and ``ratings`` maps each ``"rating"`` label to
        how many times it occurred. If ``load_examples`` raises ``ValueError``
        the failure is printed and ``(0, 1, [], {})`` is returned.
    """
    # Loading is the only step guarded here: a ValueError is reported as a
    # single error for the whole file instead of propagating.
    try:
        loaded = load_examples(path)
    except ValueError as exc:
        print(f" Error loading {path}: {exc}")
        return 0, 1, [], {}

    seen = 0
    skipped = 0
    percentages: list[float] = []
    rating_counts: dict[str, int] = {}

    # enumerate(start=1) leaves `seen` equal to the number of examples
    # visited; it stays 0 when the file yields no examples.
    for seen, example in enumerate(loaded, start=1):
        text = output_items_to_text(example.output)
        if text:
            result = score_expansion_detailed(example.query, text)
            percentages.append(result["percentage"])
            label = result["rating"]
            rating_counts[label] = rating_counts.get(label, 0) + 1
        else:
            # Nothing to score for this example; count it as an error.
            skipped += 1

    return seen, skipped, percentages, rating_counts
def _expand_pattern(root: Path, pattern: str) -> list[Path]:
    """Turn one command-line argument into candidate paths (existence unchecked).

    Args:
        root: Directory that relative arguments are resolved against.
        pattern: A literal file path or a glob pattern (``*``, ``?``, ``[...]``).

    Returns:
        The literal path joined onto ``root`` for plain arguments, otherwise
        every path matched by the glob. Literal results may not exist; the
        caller is expected to filter them.
    """
    literal = root / pattern
    has_star = "*" in pattern
    has_other_magic = any(ch in pattern for ch in "?[")

    # Arguments without "*" keep their historical meaning as a literal path
    # whenever that is still sensible: either they contain no glob characters
    # at all, or a file with that exact name (e.g. "run[1].jsonl") exists.
    if not has_star and (not has_other_magic or literal.exists()):
        return [literal]

    as_path = Path(pattern)
    anchor = as_path.anchor
    if anchor:
        # Path.glob() raises NotImplementedError for non-relative patterns,
        # so glob from the pattern's own anchor with the remainder instead.
        return list(Path(anchor).glob(str(as_path.relative_to(anchor))))
    return list(root.glob(pattern))


def main() -> int:
    """Score the requested dataset files and print one summary per file.

    Positional arguments are file paths or glob patterns; relative ones are
    resolved against the directory three levels above this script. Each
    existing file is scored once, in sorted order, even if several arguments
    match it.

    Returns:
        ``1`` when no existing file matches the arguments, ``0`` otherwise.
    """
    parser = argparse.ArgumentParser(description="Score QMD datasets")
    parser.add_argument(
        "paths",
        nargs="*",
        default=["finetune/data/*.jsonl"],
        help="JSONL files or glob patterns (default: finetune/data/*.jsonl)",
    )
    args = parser.parse_args()

    repo_root = Path(__file__).parent.parent.parent
    candidates: list[Path] = []
    for pattern in args.paths:
        candidates.extend(_expand_pattern(repo_root, pattern))

    # The set drops duplicates produced by overlapping arguments so that no
    # file is scored and reported twice.
    files = sorted({p for p in candidates if p.exists()})
    if not files:
        print("No files found to score.")
        return 1

    for path in files:
        _total, errors, scores, ratings = score_file(path)
        if scores:
            avg = statistics.mean(scores)
            median = statistics.median(scores)
            min_score = min(scores)
            max_score = max(scores)
            # Share of scored examples whose score is at least 70.
            above_70 = sum(1 for s in scores if s >= 70.0)
            pct_70 = above_70 / len(scores) * 100
            print(
                f"{path}: {len(scores)} scored, {errors} errors, "
                f"avg {avg:.1f}, median {median:.1f}, min {min_score:.1f}, "
                f"max {max_score:.1f}, >=70 {pct_70:.1f}%"
            )
        else:
            print(f"{path}: 0 scored, {errors} errors")
        if ratings:
            rating_parts = [f"{k}:{v}" for k, v in sorted(ratings.items())]
            print(f" ratings: {', '.join(rating_parts)}")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())