"""Generate LoCoMo benchmark comparison report across multiple runs.
Usage:
uv run python tests/e2e/generate_report.py --runs run1 run2 run3 run4 run5 run6 run7 run8
uv run python tests/e2e/generate_report.py --runs run3 run5 run8
"""
import argparse
import csv
import json
import os
import sys
from collections import defaultdict
from datetime import datetime
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import numpy as np
# Optional CJK font: several report/plot labels are Chinese, so when this
# system font file exists it is registered and made the default family.
# NOTE(review): the path looks macOS-specific; on other hosts the branch is
# simply skipped and matplotlib keeps its default font.
_cjk_font = "/System/Library/Fonts/STHeiti Medium.ttc"
if os.path.exists(_cjk_font):
    fm.fontManager.addfont(_cjk_font)
    plt.rcParams["font.family"] = fm.FontProperties(fname=_cjk_font).get_name()
# Draw minus signs with the ASCII hyphen instead of U+2212.
plt.rcParams["axes.unicode_minus"] = False

# Run folders are looked up under ./result next to this script.
RESULT_DIR = Path(__file__).parent / "result"

# Category id (compared against each row's "category" column) -> display label.
CATEGORY_NAMES = {
    "1": "单会话",  # single-session
    "2": "跨会话",  # cross-session
    "3": "推理",  # reasoning
    "4": "遗忘",  # forgetting
}
def load_run(run_name: str) -> dict[str, list[dict]]:
    """Load every ``*_judged.csv`` of one run, keyed by conversation id.

    Args:
        run_name: Name of the run folder under ``RESULT_DIR``.

    Returns:
        Mapping of conversation id to that file's rows (one dict per CSV
        row). The id is the ``sample_id`` cell of the file's first row, or
        the file stem when that cell is missing or empty. Files without data
        rows are skipped. An empty dict is returned when the run folder does
        not exist.
    """
    run_dir = RESULT_DIR / run_name
    if not run_dir.exists():
        return {}
    conv_data = {}
    for csv_path in sorted(run_dir.glob("*_judged.csv")):
        # newline="" is what the csv module requires so that quoted fields
        # containing line breaks (e.g. multi-line evidence) are read intact.
        with open(csv_path, encoding="utf-8", newline="") as fh:
            rows = list(csv.DictReader(fh))
        if not rows:
            continue
        # Fall back to the file stem for a missing *or blank* sample_id so
        # several files cannot collide under the "" key.
        conv = rows[0].get("sample_id") or csv_path.stem
        conv_data[conv] = rows
    return conv_data
def get_accuracy(rows: list[dict]) -> tuple[int, int, float]:
    """Return ``(correct, total, percent)`` for a list of judged rows.

    A row counts as correct when its ``is_correct`` value is exactly
    ``"CORRECT"``. For an empty list the percentage is 0.
    """
    total = len(rows)
    correct = 0
    for row in rows:
        if row.get("is_correct") == "CORRECT":
            correct += 1
    if not total:
        return correct, total, 0
    return correct, total, correct / total * 100
def classify_error(row: dict) -> str:
    """Classify why a wrongly answered row failed.

    The checks run in priority order on the lower-cased ``response`` and on
    whether ``retrieved_evidence`` contains the "Retrieved Memories" marker.

    Args:
        row: One judged CSV row. Missing or ``None`` cells (which
            ``csv.DictReader`` produces for short rows) are treated as
            empty text instead of raising.

    Returns:
        One of "网络错误", "安全过滤", "有检索但不知道", "无检索", "答错".
    """
    response = (row.get("response") or "").lower()
    evidence = row.get("retrieved_evidence") or ""
    has_retrieval = "Retrieved Memories" in evidence

    def mentions(*markers: str) -> bool:
        return any(marker in response for marker in markers)

    if mentions("网络", "network", "timed out", "connection error"):
        return "网络错误"
    if "400" in response and mentions("安全", "敏感", "unsafe"):
        return "安全过滤"
    if mentions("don't", "doesn't", "no information", "no record", "无法", "没有"):
        # The model declined to answer: tell apart "had memories but still
        # did not know" from "nothing was retrieved".
        return "有检索但不知道" if has_retrieval else "无检索"
    if not has_retrieval:
        return "无检索"
    return "答错"
def sub_classify_wrong(row: dict) -> str:
    """Split a plain wrong answer into a finer sub-type.

    Works on whitespace-separated, lower-cased tokens of ``response`` and
    ``expected``:

    * "部分正确" - at least two non-stopword tokens are shared;
    * "时间错误" - otherwise, both texts contain a month / weekday /
      2022-2024 year token;
    * "幻觉/其他" - everything else.
    """
    answer_tokens = set(row.get("response", "").lower().split())
    gold_tokens = set(row.get("expected", "").lower().split())

    stopwords = {"the", "a", "an", "in", "on", "at", "to", "of", "and", "is", "was", "for"}
    shared = (gold_tokens & answer_tokens) - stopwords
    if len(shared) >= 2:
        return "部分正确"

    calendar_tokens = {
        "january", "february", "march", "april", "may", "june", "july",
        "august", "september", "october", "november", "december",
        "2022", "2023", "2024",
        "monday", "tuesday", "wednesday", "thursday", "friday", "saturday", "sunday",
    }
    if answer_tokens & calendar_tokens and gold_tokens & calendar_tokens:
        return "时间错误"
    return "幻觉/其他"
def check_evidence_has_answer(row: dict) -> str:
    """Estimate whether the retrieved evidence contains the expected answer.

    Keywords are the whitespace-separated tokens of the lower-cased
    ``expected`` text that are longer than three characters and not in a
    small stopword list; each is looked up as a substring of the lower-cased
    evidence.

    Args:
        row: One judged CSV row; missing or ``None`` cells count as empty.

    Returns:
        * "无Evidence" - evidence lacks the "Retrieved Memories" marker;
        * "不确定" - no expected answer, or it yields no usable keyword;
        * "Evidence含完整答案" - at least two keywords and >= 70% found;
        * "Evidence含部分答案" - >= 30% of the keywords found;
        * "Evidence不含答案" - otherwise.
    """
    evidence = row.get("retrieved_evidence") or ""
    expected = (row.get("expected") or "").strip().lower()
    if "Retrieved Memories" not in evidence:
        return "无Evidence"
    if not expected:
        return "不确定"

    # Only words longer than three characters can pass the length filter, so
    # shorter stopwords do not need to be listed.
    stopwords = {
        "with", "from", "that", "this", "were", "have", "been", "into",
        "about", "which", "their", "they", "them", "what", "when", "where",
    }
    key_words = [w for w in expected.split() if len(w) > 3 and w not in stopwords]
    if not key_words:
        # Short answers such as "yes" give nothing to match on. Without this
        # guard the 30% test below is `0 >= 0` and would wrongly report a
        # partial match.
        return "不确定"

    ev_lower = evidence.lower()
    matches = sum(1 for w in key_words if w in ev_lower)
    if len(key_words) >= 2 and matches >= len(key_words) * 0.7:
        return "Evidence含完整答案"
    if matches >= len(key_words) * 0.3:
        return "Evidence含部分答案"
    return "Evidence不含答案"
def plot_overall(runs_data: dict, run_names: list, out_dir: Path):
    """Save a bar chart of each run's overall accuracy.

    Args:
        runs_data: ``{run name: {conv-id: rows}}``.
        run_names: Runs to draw, in order; runs that are missing from
            ``runs_data`` or empty are skipped.
        out_dir: Folder that receives ``01_overall_accuracy.png``.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    names, accs = [], []
    for rn in run_names:
        run = runs_data.get(rn)
        if not run:
            continue
        all_rows = [r for rows in run.values() for r in rows]
        names.append(rn)
        accs.append(get_accuracy(all_rows)[2])

    bars = ax.bar(names, accs, color=plt.cm.viridis(np.linspace(0.3, 0.9, len(names))))
    ax.set_ylabel("Accuracy (%)")
    ax.set_title("Overall Accuracy by Run")

    # Keep the usual 60-90 zoom window, but widen it when a run falls outside
    # so that its bar and value label are not hidden or clipped.
    y_lo, y_hi = 60, 90
    if accs:
        if min(accs) < y_lo:
            y_lo = max(0, min(accs) - 5)
        if max(accs) > y_hi:
            y_hi = min(105, max(accs) + 5)
    ax.set_ylim(y_lo, y_hi)

    for bar, val in zip(bars, accs):
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.3,
                f"{val:.1f}%", ha="center", va="bottom", fontsize=10)
    plt.tight_layout()
    fig.savefig(out_dir / "01_overall_accuracy.png", dpi=150)
    plt.close(fig)
def plot_per_sample(runs_data: dict, run_names: list, out_dir: Path):
    """Save a grouped bar chart of accuracy per conversation and run.

    Args:
        runs_data: ``{run name: {conv-id: rows}}``.
        run_names: Runs to draw (one bar per run inside each group).
        out_dir: Folder that receives ``02_per_sample_accuracy.png``.

    A conversation that a run has no rows for is drawn as a zero-height bar.
    """
    all_convs = sorted({conv for rn in run_names for conv in runs_data.get(rn, {})})
    fig, ax = plt.subplots(figsize=(14, 6))
    x = np.arange(len(all_convs))
    width = 0.8 / len(run_names)
    observed = []  # accuracies backed by real rows; drives the y-range
    for i, rn in enumerate(run_names):
        run = runs_data.get(rn, {})
        accs = []
        for conv in all_convs:
            rows = run.get(conv, [])
            if rows:
                pct = get_accuracy(rows)[2]
                observed.append(pct)
            else:
                pct = 0
            accs.append(pct)
        offset = (i - len(run_names) / 2 + 0.5) * width
        ax.bar(x + offset, accs, width, label=rn)
    ax.set_xticks(x)
    ax.set_xticklabels(all_convs, rotation=45)
    ax.set_ylabel("Accuracy (%)")
    ax.set_title("Per-Sample Accuracy")

    # Keep the usual 55-95 zoom window, but widen it when real data falls
    # outside so those bars are not hidden or clipped.
    y_lo, y_hi = 55, 95
    if observed:
        if min(observed) < y_lo:
            y_lo = max(0, min(observed) - 5)
        if max(observed) > y_hi:
            y_hi = 100
    ax.set_ylim(y_lo, y_hi)

    ax.legend()
    ax.grid(axis="y", alpha=0.3)
    plt.tight_layout()
    fig.savefig(out_dir / "02_per_sample_accuracy.png", dpi=150)
    plt.close(fig)
def plot_per_category(runs_data: dict, run_names: list, out_dir: Path):
    """Save a grouped bar chart of accuracy per question category and run.

    Args:
        runs_data: ``{run name: {conv-id: rows}}``.
        run_names: Runs to draw (one bar per run inside each group).
        out_dir: Folder that receives ``03_per_category_accuracy.png``.

    Categories are the keys of ``CATEGORY_NAMES``. A category a run has no
    rows for is drawn as a zero-height bar.
    """
    categories = list(CATEGORY_NAMES.keys())
    cat_labels = [CATEGORY_NAMES.get(c, c) for c in categories]
    fig, ax = plt.subplots(figsize=(10, 5))
    x = np.arange(len(categories))
    width = 0.8 / len(run_names)
    observed = []  # accuracies backed by real rows; drives the y-range
    for i, rn in enumerate(run_names):
        # Bucket the run's rows once instead of rescanning them per category.
        # A missing/None category cell is treated as empty text.
        by_cat = defaultdict(list)
        for rows in runs_data.get(rn, {}).values():
            for r in rows:
                key = (r.get("category") or "").lower().replace(" ", "_")
                by_cat[key].append(r)
        accs = []
        for cat in categories:
            cat_rows = by_cat.get(cat, [])
            if cat_rows:
                pct = get_accuracy(cat_rows)[2]
                observed.append(pct)
            else:
                pct = 0
            accs.append(pct)
        offset = (i - len(run_names) / 2 + 0.5) * width
        ax.bar(x + offset, accs, width, label=rn)
    ax.set_xticks(x)
    ax.set_xticklabels(cat_labels)
    ax.set_ylabel("Accuracy (%)")
    ax.set_title("Per-Category Accuracy")

    # Keep the usual 50-95 zoom window, but widen it when real data falls
    # outside so those bars are not hidden or clipped.
    y_lo, y_hi = 50, 95
    if observed:
        if min(observed) < y_lo:
            y_lo = max(0, min(observed) - 5)
        if max(observed) > y_hi:
            y_hi = 100
    ax.set_ylim(y_lo, y_hi)

    ax.legend()
    ax.grid(axis="y", alpha=0.3)
    plt.tight_layout()
    fig.savefig(out_dir / "03_per_category_accuracy.png", dpi=150)
    plt.close(fig)
def plot_error_attribution(runs_data: dict, run_names: list, out_dir: Path):
    """Save a grouped bar chart counting each error type per run.

    Every row whose ``is_correct`` is not ``"CORRECT"`` is labelled with
    ``classify_error``; the chart is written to ``04_error_attribution.png``
    in ``out_dir``.
    """
    error_types = ["安全过滤", "网络错误", "无检索", "有检索但不知道", "答错"]
    n_runs = len(run_names)
    fig, ax = plt.subplots(figsize=(10, 5))
    positions = np.arange(len(error_types))
    bar_width = 0.8 / n_runs

    for idx, run in enumerate(run_names):
        tally = defaultdict(int)
        for conv_rows in runs_data.get(run, {}).values():
            for row in conv_rows:
                if row.get("is_correct") != "CORRECT":
                    tally[classify_error(row)] += 1
        heights = [tally.get(kind, 0) for kind in error_types]
        # Centre the group of bars on the tick of each error type.
        shift = (idx - n_runs / 2 + 0.5) * bar_width
        ax.bar(positions + shift, heights, bar_width, label=run)

    ax.set_xticks(positions)
    ax.set_xticklabels(error_types, rotation=15)
    ax.set_ylabel("Count")
    ax.set_title("Error Attribution")
    ax.legend()
    ax.grid(axis="y", alpha=0.3)
    plt.tight_layout()
    fig.savefig(out_dir / "04_error_attribution.png", dpi=150)
    plt.close(fig)
def plot_cross_run_delta(runs_data: dict, run_names: list, out_dir: Path):
    """Save a bar chart of per-conversation accuracy change, first run -> last run.

    Does nothing when fewer than two runs are given. Only conversations
    that have rows in both runs are drawn; gains (and zero) are green,
    losses red. The chart is written to ``07_cross_run_delta.png``.
    """
    if len(run_names) < 2:
        return
    first, last = run_names[0], run_names[-1]
    baseline = runs_data.get(first, {})
    latest = runs_data.get(last, {})

    fig, ax = plt.subplots(figsize=(14, 5))
    labels, deltas = [], []
    for conv in sorted(set(baseline.keys()) | set(latest.keys())):
        before = baseline.get(conv, [])
        after = latest.get(conv, [])
        if not (before and after):
            continue
        pct_before = get_accuracy(before)[2]
        pct_after = get_accuracy(after)[2]
        labels.append(conv)
        deltas.append(pct_after - pct_before)

    colors = ["red" if delta < 0 else "green" for delta in deltas]
    ax.bar(labels, deltas, color=colors, alpha=0.7)
    ax.axhline(y=0, color="black", linewidth=0.5)
    ax.set_ylabel(f"Accuracy Delta ({last} - {first})")
    ax.set_title(f"Accuracy Change: {first} → {last}")
    ax.tick_params(axis="x", rotation=45)
    ax.grid(axis="y", alpha=0.3)
    plt.tight_layout()
    fig.savefig(out_dir / "07_cross_run_delta.png", dpi=150)
    plt.close(fig)
def plot_question_consistency(runs_data: dict, run_names: list, out_dir: Path):
    """Save a pie chart of questions that are always right, always wrong, or mixed.

    Questions are identified by ``(conv-id, stripped question text)`` and
    their verdicts are collected over all given runs. The chart is written
    to ``08_question_consistency.png``.
    """
    verdicts_by_question = defaultdict(list)
    for run in run_names:
        for conv, conv_rows in runs_data.get(run, {}).items():
            for row in conv_rows:
                key = (conv, row.get("question", "").strip())
                verdicts_by_question[key].append(row.get("is_correct") == "CORRECT")

    # Each list holds at least one verdict, so the three buckets are disjoint.
    always_correct = always_wrong = flippy = 0
    for verdicts in verdicts_by_question.values():
        if all(verdicts):
            always_correct += 1
        elif not any(verdicts):
            always_wrong += 1
        else:
            flippy += 1

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.pie(
        [always_correct, always_wrong, flippy],
        labels=["Always Correct", "Always Wrong", "Flippy"],
        colors=["#2ecc71", "#e74c3c", "#f39c12"],
        autopct="%1.1f%%",
        startangle=90,
    )
    ax.set_title(f"Question Consistency across {len(run_names)} runs")
    plt.tight_layout()
    fig.savefig(out_dir / "08_question_consistency.png", dpi=150)
    plt.close(fig)
def plot_evidence_length(runs_data: dict, run_names: list, out_dir: Path):
    """Save a bar chart of median evidence length for correct vs wrong rows.

    Length is the character count of ``retrieved_evidence``. A run with no
    rows in a group gets 0 for that group. The chart is written to
    ``09_evidence_length.png``.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    median_correct, median_wrong = {}, {}
    for run in run_names:
        correct_lengths, wrong_lengths = [], []
        for conv_rows in runs_data.get(run, {}).values():
            for row in conv_rows:
                length = len(row.get("retrieved_evidence", ""))
                bucket = correct_lengths if row.get("is_correct") == "CORRECT" else wrong_lengths
                bucket.append(length)
        median_correct[run] = np.median(correct_lengths) if correct_lengths else 0
        median_wrong[run] = np.median(wrong_lengths) if wrong_lengths else 0

    positions = np.arange(len(run_names))
    bar_width = 0.35
    correct_heights = [median_correct[run] for run in run_names]
    wrong_heights = [median_wrong[run] for run in run_names]
    ax.bar(positions - bar_width / 2, correct_heights, bar_width, label="CORRECT", color="green", alpha=0.6)
    ax.bar(positions + bar_width / 2, wrong_heights, bar_width, label="WRONG", color="red", alpha=0.6)
    ax.set_xticks(positions)
    ax.set_xticklabels(run_names)
    ax.set_ylabel("Median Evidence Length (chars)")
    ax.set_title("Evidence Length: Correct vs Wrong")
    ax.legend()
    ax.grid(axis="y", alpha=0.3)
    plt.tight_layout()
    fig.savefig(out_dir / "09_evidence_length.png", dpi=150)
    plt.close(fig)
# Row order of the error tables in the markdown report.
_REPORT_ERROR_TYPES = ["无检索", "网络错误", "安全过滤", "有检索但不知道", "答错"]


def _report_rows(run: dict) -> list:
    """Flatten one run's ``{conv-id: rows}`` mapping into a single row list."""
    return [row for rows in run.values() for row in rows]


def _report_pct(part, whole):
    """Return *part* as a percentage of *whole*; 0 when *whole* is 0."""
    return part / whole * 100 if whole else 0


def _report_int(value) -> int:
    """Parse a token-count cell; blank, missing or unparsable cells count as 0."""
    text = "" if value is None else str(value).strip()
    if not text:
        return 0
    try:
        return int(text)
    except ValueError:
        pass
    try:
        return int(float(text))  # tolerate "123.0"
    except (ValueError, OverflowError):
        return 0


def _report_category(row: dict) -> str:
    """Return the row's category normalised the way the report compares it."""
    return (row.get("category") or "").lower().replace(" ", "_")


def _report_run_table(first_col: str, active_runs: list, body: list) -> list:
    """Build a markdown table with one column per run from ``(label, cells)`` pairs."""
    lines = [
        f"| {first_col} | " + " | ".join(active_runs) + " |",
        "|------|" + "|".join("--------" for _ in active_runs) + "|",
    ]
    lines.extend(f"| {label} | " + " | ".join(cells) + " |" for label, cells in body)
    return lines


def _report_count_table(header: str, sep: str, kinds: list, counts: dict, total: int) -> list:
    """Build a "kind | count | share" markdown table followed by a blank line."""
    lines = [header, sep]
    for kind in kinds:
        cnt = counts.get(kind, 0)
        lines.append(f"| {kind} | {cnt} | {_report_pct(cnt, total):.1f}% |")
    lines.append("")
    return lines


def _report_accuracy_sections(runs_data: dict, active_runs: list) -> list:
    """Render sections 1-3: overall, per-conversation and per-category accuracy."""
    md = ["## 1. 总览\n", "| Run | 总题数 | 正确 | 准确率 |", "|-----|--------|------|--------|"]
    for rn in active_runs:
        c, t, pct = get_accuracy(_report_rows(runs_data[rn]))
        md.append(f"| {rn} | {t} | {c} | **{pct:.1f}%** |")
    md += ["", "\n"]

    md.append("## 2. 按 Sample 分析\n")
    all_convs = sorted({conv for rn in active_runs for conv in runs_data[rn]})
    body = []
    for conv in all_convs:
        cells = []
        for rn in active_runs:
            rows = runs_data[rn].get(conv, [])
            cells.append(f"{get_accuracy(rows)[2]:.1f}%" if rows else "-")
        body.append((conv, cells))
    md += _report_run_table("Conv", active_runs, body)
    md += ["", "\n"]

    md.append("## 3. 按类别分析\n")
    # Bucket each run's rows by category once, then read the buckets per cell.
    buckets = {}
    for rn in active_runs:
        by_cat = defaultdict(list)
        for r in _report_rows(runs_data[rn]):
            by_cat[_report_category(r)].append(r)
        buckets[rn] = by_cat
    body = []
    for cat, label in CATEGORY_NAMES.items():
        cells = []
        for rn in active_runs:
            cat_rows = buckets[rn].get(cat, [])
            cells.append(f"{get_accuracy(cat_rows)[2]:.1f}%" if cat_rows else "-")
        body.append((label, cells))
    md += _report_run_table("类别", active_runs, body)
    md += ["", "\n"]
    return md


def _report_error_section(runs_data: dict, active_runs: list) -> list:
    """Render section 4: error attribution tables for every run."""
    categories = list(CATEGORY_NAMES)
    md = ["## 4. 错误归因\n"]
    for rn in active_runs:
        wrong_rows = [r for r in _report_rows(runs_data[rn]) if r.get("is_correct") != "CORRECT"]
        # Classify each wrong row once and reuse the label for every table.
        labelled = [(classify_error(r), r) for r in wrong_rows]
        error_counts = defaultdict(int)
        per_category = defaultdict(int)
        for kind, r in labelled:
            error_counts[kind] += 1
            per_category[(kind, _report_category(r))] += 1

        md.append(f"### {rn}\n")
        md.append(f"总错误: {len(wrong_rows)}\n")
        md += _report_count_table(
            "| 错误类型 | 数量 | 占比 |", "|----------|------|------|",
            _REPORT_ERROR_TYPES, error_counts, len(wrong_rows),
        )

        wrong_answer_rows = [r for kind, r in labelled if kind == "答错"]
        if wrong_answer_rows:
            sub_counts = defaultdict(int)
            for r in wrong_answer_rows:
                sub_counts[sub_classify_wrong(r)] += 1
            md.append("#### 答错细分\n")
            md += _report_count_table(
                "| 子类型 | 数量 | 占比 |", "|--------|------|------|",
                ["部分正确", "时间错误", "幻觉/其他"], sub_counts, len(wrong_answer_rows),
            )

        dk_rows = [r for kind, r in labelled if kind == "有检索但不知道"]
        if dk_rows:
            ev_counts = defaultdict(int)
            for r in dk_rows:
                ev_counts[check_evidence_has_answer(r)] += 1
            md.append('#### "不知道"题目的 Evidence 分析\n')
            md += _report_count_table(
                "| Evidence 含答案情况 | 数量 | 占比 |", "|-------------------|------|------|",
                ["Evidence含完整答案", "Evidence含部分答案", "Evidence不含答案", "无Evidence", "不确定"],
                ev_counts, len(dk_rows),
            )

        md.append("#### 按 Category 错误分布\n")
        md.append("| 错误类型 | " + " | ".join(CATEGORY_NAMES[c] for c in categories) + " |")
        # One "---" cell per category column. Joining "---|" items with "|"
        # would double the pipes and break the table.
        md.append("|----------|" + "|".join("---" for _ in categories) + "|")
        for kind in _REPORT_ERROR_TYPES:
            cells = [str(per_category.get((kind, cat), 0)) for cat in categories]
            md.append(f"| {kind} | " + " | ".join(cells) + " |")
        md.append("")
    md.append("\n")
    return md


def _report_token_section(runs_data: dict, active_runs: list) -> list:
    """Render section 5: summed token columns for every run."""
    md = ["## 5. Token 消耗\n"]
    for rn in active_runs:
        all_rows = _report_rows(runs_data[rn])
        totals = {
            column: sum(_report_int(r.get(column)) for r in all_rows)
            for column in ("input_tokens", "output_tokens", "cacheRead", "total_tokens")
        }
        md.append(f"### {rn}\n")
        md.append(f"- Input: {totals['input_tokens']:,}")
        md.append(f"- Output: {totals['output_tokens']:,}")
        md.append(f"- Cache Read: {totals['cacheRead']:,}")
        md.append(f"- Total: {totals['total_tokens']:,}\n")
    return md


def _report_flip_section(runs_data: dict, active_runs: list) -> list:
    """Render section 6: verdict stability between the first and the last run."""
    md = ["## 6. 跨 Run 波动\n"]
    if len(active_runs) >= 2:
        first, last = active_runs[0], active_runs[-1]
        first_map = {}
        for conv, rows in runs_data[first].items():
            for r in rows:
                first_map[(conv, (r.get("question") or "").strip())] = r.get("is_correct") == "CORRECT"
        always_correct = always_wrong = flippy = 0
        for conv, rows in runs_data[last].items():
            for r in rows:
                key = (conv, (r.get("question") or "").strip())
                if key not in first_map:
                    continue  # only questions present in both runs are compared
                last_correct = r.get("is_correct") == "CORRECT"
                if first_map[key] and last_correct:
                    always_correct += 1
                elif not first_map[key] and not last_correct:
                    always_wrong += 1
                else:
                    flippy += 1
        # total is 0 when the two runs share no question; _report_pct then
        # yields 0 instead of dividing by zero.
        total = always_correct + always_wrong + flippy
        md.append(f"对比: {first} → {last}\n")
        md.append(f"- 始终正确: {always_correct} ({_report_pct(always_correct, total):.1f}%)")
        md.append(f"- 始终错误: {always_wrong} ({_report_pct(always_wrong, total):.1f}%)")
        md.append(f"- 翻转: {flippy} ({_report_pct(flippy, total):.1f}%)\n")
        md.append("\n")
    md.append("\n")
    return md


def _report_retrieval_section(runs_data: dict, active_runs: list) -> list:
    """Render section 7: missing-evidence rate and median evidence length."""
    md = ["## 7. 检索质量\n"]
    for rn in active_runs:
        all_rows = _report_rows(runs_data[rn])
        no_evid = 0
        c_lens, w_lens = [], []
        for r in all_rows:
            evidence = r.get("retrieved_evidence") or ""
            if "Retrieved Memories" not in evidence:
                no_evid += 1
            (c_lens if r.get("is_correct") == "CORRECT" else w_lens).append(len(evidence))
        # np.median([]) is nan (with a warning); report 0 for an empty group.
        c_median = np.median(c_lens) if c_lens else 0
        w_median = np.median(w_lens) if w_lens else 0
        md.append(f"### {rn}\n")
        md.append(f"- 无 Evidence 率: {_report_pct(no_evid, len(all_rows)):.1f}%")
        md.append(f" - CORRECT: median evidence = {c_median:.0f} chars")
        md.append(f" - WRONG: median evidence = {w_median:.0f} chars\n")
    md.append("\n")
    return md


def generate_report(run_names: list[str], output_dir: str | None = None) -> str:
    """Load the given runs, write all plots and ``report.md``, return the output folder.

    Args:
        run_names: Run folder names to compare; runs without data are
            skipped with a warning.
        output_dir: Folder to write into. When omitted, a timestamped
            ``report_<YYYYmmdd_HHMMSS>`` folder under ``RESULT_DIR`` is used.

    Returns:
        The output folder as a string.

    Exits the process with status 1 when none of the runs has any data.
    """
    runs_data = {}
    for rn in run_names:
        data = load_run(rn)
        if data:
            runs_data[rn] = data
            print(f"Loaded {rn}: {len(data)} convs, {sum(len(v) for v in data.values())} questions")
        else:
            print(f"WARNING: {rn} has no data, skipping")
    active_runs = [rn for rn in run_names if rn in runs_data]
    if not active_runs:
        print("No data found!")
        sys.exit(1)

    if output_dir:
        out_dir = Path(output_dir)
    else:
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        out_dir = RESULT_DIR / f"report_{ts}"
    out_dir.mkdir(parents=True, exist_ok=True)

    print("Generating plots...")
    for plot in (
        plot_overall,
        plot_per_sample,
        plot_per_category,
        plot_error_attribution,
        plot_cross_run_delta,
        plot_question_consistency,
        plot_evidence_length,
    ):
        plot(runs_data, active_runs, out_dir)

    print("Generating report...")
    md = [
        "# LoCoMo Benchmark 分析报告\n",
        f"Runs: {', '.join(active_runs)}\n",
        f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n",
    ]
    md += _report_accuracy_sections(runs_data, active_runs)
    md += _report_error_section(runs_data, active_runs)
    md += _report_token_section(runs_data, active_runs)
    md += _report_flip_section(runs_data, active_runs)
    md += _report_retrieval_section(runs_data, active_runs)
    md.append("## 8. 数据源\n")
    md.extend(f"- `{RESULT_DIR / rn}/`" for rn in active_runs)

    report_path = out_dir / "report.md"
    with open(report_path, "w", encoding="utf-8") as f:
        f.write("\n".join(md))
    print(f"\nReport generated: {report_path}")
    return str(out_dir)
def main():
    """Parse the command line and generate the comparison report."""
    cli = argparse.ArgumentParser(
        description="Generate LoCoMo benchmark comparison report",
    )
    cli.add_argument(
        "--runs",
        nargs="+",
        required=True,
        help="Run names to compare (e.g. run1 run2 run3)",
    )
    cli.add_argument(
        "--output",
        "-o",
        help="Output directory (default: auto-generated)",
    )
    options = cli.parse_args()
    generate_report(options.runs, options.output)


if __name__ == "__main__":
    main()