"""
Flash Attention Forward Competition — Judge Script
====================================================

Server-side grading script. DO NOT distribute to students.

Usage:
    python judge.py <submission_file.py>
    python judge.py <submission_file.py> --skip-perf
    python judge.py <submission_file.py> --score-file score.txt --log-file judge.log

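Outputs:
    stdout       — single JSON line:
                   {"verdict": "correctness C + performance P", "score": N,
                    "comment": "<per-test summary>"}
                   (if the judge itself crashes: verdict "judge crashed" and the
                   full log as the comment)
    score.txt    — single integer grade (readable by web UI)
    judge.log    — full judging log with all messages
    results.json — (optional) structured results via --output-json

Grading logic:
    1. Anti-cheat static analysis  -> reject (grade 0) on any violation
    2. Correctness score = (tests_passed / tests_total) x 40
    3. Performance score = avg( reference_time_i / student_time_i ) x 60,
       awarded only when ALL correctness tests pass (and --skip-perf is not set)
    Final grade = int(correctness + performance). Ratios are not capped, so the
    grade can exceed 100.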
"""

import argparse
import ast
import csv
import importlib.machinery
import importlib.util
import io
import json
import os
import re
import shutil
import sys
import time
import tokenize
import traceback
from dataclasses import dataclass, field, asdict
from typing import List, Optional, Tuple

import torch
import torch_npu
import triton

# NPU device string used for every tensor the judge allocates.
DEVICE: str = "npu"


# ============================================================================
# LOGGING: Capture stdout/stderr to a log file + in-memory buffer.
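# Real stdout is preserved but NOT written to during the run, so the only
# thing the grader sees on stdout is the final JSON emitted by main().
# Only Python-level writes (print, sys.stdout / sys.stderr .write) are
# captured; anything native code writes directly to file descriptors 1/2
# bypasses this redirection.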
# ============================================================================

class CapturingLogger:
    """Stand-in for sys.stdout / sys.stderr that records everything written.

    Each string passed to write() is appended to the log file at *log_path*
    (flushed immediately) and to an in-memory buffer returned by get_log().
    The sys.stdout / sys.stderr objects current at construction time are
    remembered and reinstated by close(). The caller installs the logger by
    assigning it to sys.stdout / sys.stderr.
    """

    def __init__(self, log_path: str):
        self._stdout = sys.stdout
        self._stderr = sys.stderr
        # The judge prints non-ASCII glyphs (box drawing, check marks, bars);
        # pin UTF-8 so writing them never depends on the host locale.
        self._log_file = open(log_path, "w", encoding="utf-8")
        self._buffer: List[str] = []

    def write(self, msg: str) -> int:
        """Record *msg* and return the number of characters written."""
        # After close() (e.g. a stale handler still holding this object) keep
        # buffering instead of raising on the closed file.
        if not self._log_file.closed:
            self._log_file.write(msg)
            self._log_file.flush()
        self._buffer.append(msg)
        return len(msg)

    def flush(self) -> None:
        if not self._log_file.closed:
            self._log_file.flush()

    def isatty(self) -> bool:
        # Libraries probe this before emitting colour / progress output;
        # a log file is never a terminal.
        return False

    def get_log(self) -> str:
        """Return everything written so far as one string."""
        return "".join(self._buffer)

    def close(self) -> None:
        """Close the log file and restore the original stdout / stderr."""
        self._log_file.close()
        sys.stdout = self._stdout
        sys.stderr = self._stderr


# ============================================================================
# 1. ANTI-CHEAT: Static Analysis
# ============================================================================

def _dotted_name(node) -> Optional[str]:
    """Reconstruct the dotted name of an AST ``Name`` / ``Attribute`` chain.

    ``torch.nn.functional.softmax`` -> ``"torch.nn.functional.softmax"``.
    Returns ``None`` when the chain is rooted in anything other than a plain
    name (a call, subscript, literal, ...), e.g. ``f().x``.
    """
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Attribute):
        parent = _dotted_name(node.value)
        if parent is not None:
            return f"{parent}.{node.attr}"
    return None


# torch_npu: block any attribute whose name contains 'attention' (case-insensitive)
# This catches npu_fusion_attention, npu_multi_head_attention, and any future variants.
TORCH_NPU_BLOCKED_PATTERN = "attention"

# torch / torch.nn.functional operations that trivialise the kernel.
# Entries are fully qualified. CheatDetector resolves import aliases such as
# ``import torch.nn.functional as F`` or ``from torch import matmul`` before
# comparing against this set, so ``F.softmax(...)`` is caught too.
BLOCKED_TORCH_OPS = {
    "torch.matmul",
    "torch.bmm",
    "torch.softmax",
    "torch.einsum",
    "torch.nn.functional.softmax",
    "torch.nn.functional.scaled_dot_product_attention",
    "torch.nn.functional.multi_head_attention_forward",
}

# Third-party attention libraries. Matched against the last component of the
# called name after alias resolution. A bare ``flash_attn_func(...)``,
# ``flash_attn.flash_attn_func(...)`` and an aliased import are all caught.
BLOCKED_DIRECT_CALLS = {
    "flash_attn_func",
    "flash_attn_varlen_func",
    "flash_attn_qkvpacked_func",
}

# Submission functions allowed to call the reference ops
# (presumably the student's local self-test helpers).
_PROTECTED_FUNCTIONS = frozenset({"test_op", "profiling"})

# Text-level fallback patterns, applied to the submission with comments removed.
# ``(?<!\w)`` stops identifiers that merely end in "eval"/"exec" (``retrieval(``)
# from matching, and ``\s*`` also catches ``eval (x)``.
_SUSPICIOUS_PATTERNS = [
    (re.compile(r"(?<!\w)eval\s*\("), "Use of eval() is not allowed"),
    (re.compile(r"(?<!\w)exec\s*\("), "Use of exec() is not allowed"),
    (re.compile(r"__import__"), "Dynamic import is not allowed"),
]


class CheatDetector(ast.NodeVisitor):
    """
    AST visitor that enforces, outside protected functions:
      1. No torch_npu.*attention* references
      2. No torch.matmul / torch.softmax / torch.bmm / torch.einsum or the
         torch.nn.functional variants listed in BLOCKED_TORCH_OPS
      3. No third-party attention library calls (BLOCKED_DIRECT_CALLS)

    Import aliases (``import torch.nn.functional as F``, ``from torch import
    matmul``, ``import torch_npu as tn``) are resolved to fully-qualified names
    before rules 1-3 are applied. eval/exec/dynamic-import checks live in the
    text scan of check_submission, not here.

    Protected functions (allowed to use reference ops): test_op, profiling.

    NOTE(review): protection is keyed on the function *name* only, so any def
    named test_op/profiling is exempt, including one nested in, or called from,
    attention(). Tensor methods (``x.softmax()``), the ``@`` operator and
    getattr-by-string are not detected either.
    """

    def __init__(self):
        self.violations: List[str] = []
        self._protected = False
        # local name -> fully-qualified import target, e.g. "F" -> "torch.nn.functional"
        self._aliases = {}

    # -- import alias tracking -------------------------------------------

    def visit_Module(self, node):
        # Collect every import (module level or nested) before visiting, so a
        # use that appears above its import in the file still resolves.
        for child in ast.walk(node):
            if isinstance(child, (ast.Import, ast.ImportFrom)):
                self._register_import(child)
        self.generic_visit(node)

    def _register_import(self, node):
        """Record the local names bound by an Import / ImportFrom node."""
        if isinstance(node, ast.Import):
            for alias in node.names:
                # `import torch.nn.functional` binds `torch`, which is already
                # canonical; only `import x.y as z` introduces a new name.
                if alias.asname:
                    self._aliases[alias.asname] = alias.name
        elif node.module and not node.level:  # absolute `from x import y [as z]`
            for alias in node.names:
                if alias.name != "*":
                    self._aliases[alias.asname or alias.name] = f"{node.module}.{alias.name}"

    def _canonical(self, dotted: str) -> str:
        """Rewrite the first component of *dotted* through the alias table."""
        head, sep, rest = dotted.partition(".")
        target = self._aliases.get(head)
        return dotted if target is None else f"{target}{sep}{rest}"

    def _check_reference(self, node, written: str) -> None:
        """Apply rules 1 and 2 to the name chain *written* found at *node*."""
        canonical = self._canonical(written)
        if canonical == written:
            label = f"'{written}'"
        else:
            label = f"'{written}' (alias of '{canonical}')"
        leaf = canonical.rsplit(".", 1)[-1]
        # Rule 1: block torch_npu.*attention*
        if canonical.startswith("torch_npu.") and TORCH_NPU_BLOCKED_PATTERN in leaf.lower():
            self.violations.append(
                f"Line {node.lineno}: Disallowed torch_npu attention call {label}"
            )
        # Rule 2: block torch matmul/softmax/bmm/einsum and F variants
        elif canonical in BLOCKED_TORCH_OPS:
            self.violations.append(
                f"Line {node.lineno}: Disallowed torch operation {label} — use Triton instead"
            )

    # -- visitors ----------------------------------------------------------

    def visit_FunctionDef(self, node):
        # test_op and profiling may call torch_npu reference ops
        if node.name in _PROTECTED_FUNCTIONS:
            old = self._protected
            self._protected = True
            self.generic_visit(node)
            self._protected = old
        else:
            self.generic_visit(node)

    def visit_Attribute(self, node):
        if not self._protected:
            written = _dotted_name(node)
            if written is not None:
                self._check_reference(node, written)
        self.generic_visit(node)

    def visit_Name(self, node):
        # A bare name matters only if an import bound it (``from torch import
        # matmul``); all other identifiers are left alone.
        if (not self._protected and isinstance(node.ctx, ast.Load)
                and node.id in self._aliases):
            self._check_reference(node, node.id)
        self.generic_visit(node)

    def visit_Call(self, node):
        if not self._protected:
            # Rule 3: block third-party attention calls
            callee = _dotted_name(node.func)
            if callee is not None:
                leaf = self._canonical(callee).rsplit(".", 1)[-1]
                if leaf in BLOCKED_DIRECT_CALLS:
                    self.violations.append(
                        f"Line {node.lineno}: Disallowed call to '{callee}'"
                    )
        self.generic_visit(node)


def _strip_comments(source_code: str) -> str:
    """Return the token text of *source_code* with every ``#`` comment dropped.

    Tokens are joined with single spaces, so ``eval(`` becomes ``eval (``.
    The patterns in _SUSPICIOUS_PATTERNS allow whitespace before ``(``.
    String literals are kept and still scanned.
    """
    try:
        tokens = tokenize.generate_tokens(io.StringIO(source_code).readline)
        return " ".join(tok.string for tok in tokens if tok.type != tokenize.COMMENT)
    except (tokenize.TokenError, SyntaxError):
        # Should not happen for source ast.parse accepted. Fall back to
        # dropping whole-line comments only.
        return "\n".join(
            line.strip() for line in source_code.splitlines()
            if not line.strip().startswith("#")
        )


def check_submission(source_code: str) -> List[str]:
    """Run static analysis. Returns list of violations (empty = clean).

    1. AST rules via CheatDetector (see its docstring).
    2. A text scan for eval( / exec( / __import__ over the source with
       comments removed.
    Source that cannot be parsed yields a single "Syntax error" violation.
    """
    violations: List[str] = []

    try:
        tree = ast.parse(source_code)
    except (SyntaxError, ValueError) as e:
        # ValueError: e.g. NUL bytes in the source on older Python versions.
        violations.append(f"Syntax error in submission: {e}")
        return violations

    detector = CheatDetector()
    detector.visit(tree)
    violations.extend(detector.violations)

    code_text = _strip_comments(source_code)
    for pattern, reason in _SUSPICIOUS_PATTERNS:
        if pattern.search(code_text):
            violations.append(f"Suspicious pattern: {reason}")

    return violations


# ============================================================================
# 2. CORRECTNESS GATE
# ============================================================================

# Correctness suite run by run_correctness_gate(). The correctness score is
# proportional to the number of entries passed, and every entry must pass for
# the performance phase to run.
# Order matters: each test's seed is derived from its index (42 + 7 * i), so
# reordering or inserting entries changes the random inputs.
# NOTE(review): some shapes appear more than once, e.g. (1, 2, 1024, 64, False)
# and (1, 1, 128, 128, True). This is presumably intentional, since each copy
# runs with a different seed. Confirm.
CORRECTNESS_TESTS = [
    # (Z, H, N_CTX, HEAD_DIM, causal, dtype)
    # BM/BN are obtained from the student's get_tiling() — not fixed here.
    # Non-causal, float16
    (1, 2, 1024, 64,  False, torch.float16),
    (1, 1, 64,   64,  False, torch.float16),
    (4, 32, 64,  64,  False, torch.float16),
    (1, 1, 128,  128, False, torch.float16),
    (4, 32, 128, 128, False, torch.float16),
    (1, 2, 1024, 64,  False, torch.float16),

    # Causal, float16
    (1, 1, 64,   64,  True,  torch.float16),
    (4, 32, 64,  64,  True,  torch.float16),
    (1, 1, 128,  128, True,  torch.float16),
    (1, 2, 1024, 64,  True,  torch.float16),
    (1, 1, 128,  128, True,  torch.float16),
    (1, 2, 1024, 64,  True,  torch.float16),

    # Larger batch/head counts (Z=4, H=32), mixed causal, float16
    (4, 32, 2048, 128, False, torch.float16),
    (4, 32, 4096, 64,  False, torch.float16),
    (4, 32, 1024, 128, False, torch.float16),
    (4, 32, 1024, 64,  False, torch.float16),
    (4, 32, 2048, 64,  True,  torch.float16),
    (4, 32, 1024, 64,  True,  torch.float16),
    (4, 32, 512,  64,  True,  torch.float16),
    (4, 32, 256,  128, True,  torch.float16),

    # bfloat16
    (128, 8, 1024, 64,  False, torch.bfloat16),
]


def run_single_test(attention_fn, get_tiling_fn,
                    Z, H, N_CTX, HEAD_DIM, causal, dtype,
                    seed) -> Tuple[bool, str]:
    """Run one correctness case: the student kernel vs the torch_npu reference.

    Inputs q, k, v of shape (Z, H, N_CTX, HEAD_DIM) are drawn from
    N(0, 0.5^2) after ``torch.manual_seed(seed)``. The reference is
    ``torch_npu.npu_fusion_attention`` in BNSD layout with scale 0.5. The
    student output must match the reference shape exactly and satisfy
    ``torch.allclose(atol=1e-2, rtol=1e-2)``.

    Tiling (BM, BN) is obtained from the student's ``get_tiling()``.

    Returns:
        (passed, message). message is a short PASS / FAIL / error description.
    """
    # Get student's tiling for this shape
    try:
        tiling = get_tiling_fn(Z, H, N_CTX, HEAD_DIM, causal)
        if not isinstance(tiling, (tuple, list)) or len(tiling) != 2:
            return False, f"get_tiling must return (BM, BN), got: {tiling}"
        BM, BN = int(tiling[0]), int(tiling[1])
    except Exception as e:
        return False, f"get_tiling error: {e}"

    torch.manual_seed(seed)
    q = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5)
    k = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5)
    v = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5)

    sm_scale = 0.5
    compressed_len = 2048
    atten_golden_mask = None
    sparse_mode = 0

    if causal:
        # A fixed 2048x2048 upper-triangular mask plus sparse_mode=2, regardless
        # of N_CTX. NOTE(review): presumably torch_npu's compressed causal-mask
        # convention for npu_fusion_attention. Confirm against its docs.
        atten_golden_mask = torch.triu(
            torch.ones(compressed_len, compressed_len, device=DEVICE), diagonal=1
        ).bool()
        sparse_mode = 2

    try:
        ref_out = torch_npu.npu_fusion_attention(
            q, k, v, H, padding_mask=None, atten_mask=atten_golden_mask,
            scale=sm_scale, keep_prob=1.0, input_layout='BNSD',
            pre_tockens=65535, next_tockens=65535, sparse_mode=sparse_mode,
        )[0]
    except Exception as e:
        return False, f"Reference kernel error: {e}"

    try:
        tri_out = attention_fn(q, k, v, causal, sm_scale, BM, BN).to(dtype)
    except Exception as e:
        return False, f"Student kernel error (BM={BM}, BN={BN}): {e}"

    # torch.allclose broadcasts its arguments, so a wrongly-shaped output could
    # slip through (or crash the max_diff computation below). Require an exact
    # shape match first.
    if tri_out.shape != ref_out.shape:
        return False, (f"FAIL (BM={BM}, BN={BN}, output shape {tuple(tri_out.shape)} "
                       f"!= expected {tuple(ref_out.shape)})")

    if torch.allclose(ref_out, tri_out, atol=1e-2, rtol=1e-2):
        return True, f"PASS (BM={BM}, BN={BN})"
    max_diff = (ref_out - tri_out).abs().max().item()
    return False, f"FAIL (BM={BM}, BN={BN}, max_diff={max_diff:.6f})"


def run_correctness_gate(attention_fn, get_tiling_fn) -> Tuple[bool, int, int, List[dict]]:
    """Run every CORRECTNESS_TESTS case and collect the outcomes.

    Each case gets its own seed (42 + 7 * index). An exception escaping
    run_single_test counts as a failure with a "CRASH: ..." message. One
    PASS/FAIL line per case is printed.

    Returns:
        (all_passed, n_passed, n_total, details). details holds one
        {"test", "config", "result"} dict per case, in suite order.
    """
    total = len(CORRECTNESS_TESTS)
    outcomes: List[bool] = []
    details: List[dict] = []

    for index, case in enumerate(CORRECTNESS_TESTS):
        Z, H, N_CTX, HEAD_DIM, causal, dtype = case
        number = index + 1
        seed = 42 + 7 * index  # varied seeds to prevent output hardcoding
        config_str = (f"Z={Z} H={H} N_CTX={N_CTX} HEAD_DIM={HEAD_DIM} "
                      f"causal={causal} dtype={dtype}")
        try:
            ok, msg = run_single_test(attention_fn, get_tiling_fn,
                                      Z, H, N_CTX, HEAD_DIM, causal, dtype, seed)
        except Exception as exc:
            ok, msg = False, f"CRASH: {exc}"

        details.append({"test": number, "config": config_str, "result": msg})
        outcomes.append(bool(ok))
        status = "PASS" if ok else "FAIL"
        print(f"  [{status}] Test {number:2d}/{total}: {msg}")

    return all(outcomes), sum(outcomes), total, details


# ============================================================================
# 3. PERFORMANCE GRADING
# ============================================================================

# Performance suite: each entry is profiled once for the torch_npu reference and
# once for the student kernel. The per-config time ratios (ref / student) are
# averaged by run_performance_grading() into the performance score.
PERF_CONFIGS = [
    # (Z, H, N_CTX, HEAD_DIM, causal, dtype)
    (128, 8, 8192, 64,  False, torch.float16),
    (128, 8, 4096, 128, False, torch.float16),
    (128, 8, 2048, 256, False, torch.float16),
    (128, 8, 2048, 128, True,  torch.float16),
    (128, 8, 1024, 128, True,  torch.float16),
    (128, 8, 1024, 256, True,  torch.float16),
]


# OP Type names counted in each profile. The reference profile only runs
# npu_fusion_attention. The student profile is timed on the student's Triton
# kernel, which is expected to be named `_attn_fwd`, so a FlashAttentionScore op
# launched from attention() earns no performance credit.
REFERENCE_OP_TYPES = ("FlashAttentionScore",)
STUDENT_OP_TYPES = ("_attn_fwd",)


def fetch_op_time(path: str,
                  op_types: Tuple[str, ...] = ("_attn_fwd", "FlashAttentionScore"),
                  ) -> Optional[float]:
    """Extract the average op time (us) of attention kernels from a profiler dump.

    Searches *path* recursively for the first ``op_statistic.csv`` and averages
    the ``Avg Time(us)`` column over rows whose ``OP Type`` is in *op_types*
    and whose ``Core Type`` is ``MIX_AIC``. Rows with a missing or
    non-numeric time are skipped.

    Args:
        path: profiler output directory.
        op_types: tuple of accepted OP Type names. The default accepts both the
            student kernel and the reference op.

    Returns:
        The mean time in microseconds, or None (with a printed reason) when the
        directory, the CSV, or any matching row is missing.
    """
    if not os.path.exists(path):
        print(f"    fetch_op_time: path '{path}' does not exist")
        return None

    target_file = None
    for root, _dirs, files in os.walk(path):
        if "op_statistic.csv" in files:
            target_file = os.path.join(root, "op_statistic.csv")
            break

    if target_file is None:
        print(f"    fetch_op_time: 'op_statistic.csv' not found under '{path}'")
        return None

    results = []
    with open(target_file, "r", newline="") as f:  # newline="" as the csv module requires
        for row in csv.DictReader(f):
            if row.get("OP Type", "") in op_types and row.get("Core Type", "") == "MIX_AIC":
                try:
                    results.append(float(row["Avg Time(us)"]))
                except (KeyError, ValueError, TypeError):
                    # TypeError: short rows yield None for missing columns.
                    continue

    if not results:
        print(f"    fetch_op_time: no matching rows (OP Type in {'/'.join(op_types)}, Core Type=MIX_AIC)")
        return None

    return sum(results) / len(results)


def _reference_attention(q, k, v, H, atten_mask, sm_scale, sparse_mode):
    """Call torch_npu.npu_fusion_attention with the judge's fixed settings (BNSD)."""
    return torch_npu.npu_fusion_attention(
        q, k, v, H, padding_mask=None, atten_mask=atten_mask,
        scale=sm_scale, keep_prob=1.0, input_layout='BNSD',
        pre_tockens=65535, next_tockens=65535, sparse_mode=sparse_mode,
    )


def _profile_steps(step_fn, trace_dir: str, experimental_config, active: int) -> None:
    """Call ``step_fn()`` under the NPU profiler, one profiler step per call.

    Runs skip_first(1) + wait(1) + warmup(1) + *active* steps and synchronises
    the device after each call. Traces are written to *trace_dir* via
    tensorboard_trace_handler.
    """
    total_steps = 1 + 1 + 1 + active  # skip_first + wait + warmup + active
    with torch_npu.profiler.profile(
        activities=[torch_npu.profiler.ProfilerActivity.NPU],
        schedule=torch_npu.profiler.schedule(wait=1, warmup=1, active=active, repeat=1, skip_first=1),
        on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(trace_dir),
        record_shapes=True, profile_memory=False, with_stack=False,
        with_flops=False, with_modules=False,
        experimental_config=experimental_config,
    ) as prof:
        for _ in range(total_steps):
            step_fn()
            torch.npu.synchronize()
            prof.step()


def profile_single_config(attention_fn, get_tiling_fn,
                          Z, H, N_CTX, HEAD_DIM, causal, dtype,
                          config_index: int) -> Tuple[Optional[float], Optional[float], int, int]:
    """Profile the reference and the student kernel on one PERF_CONFIGS entry.

    Validates the student's tiling, then re-checks correctness on the profiling
    inputs (seed 20). It then runs each kernel under the NPU profiler, writing
    traces to ``<cwd>/judge_result_dir_<config_index>/{torch,triton}``. That
    directory is wiped first.

    Returns:
        (ref_time_us, student_time_us, BM, BN). Both times are None when the
        tiling is invalid or the pre-check fails or crashes. A single time may
        be None when no matching profiler rows were found. BM/BN are 0 when
        get_tiling itself failed.

    Raises:
        Errors from converting the tiling values with int(), and any exception
        from the profiled runs. The caller scores these configs as 0.
    """
    # Get and validate student's tiling
    try:
        tiling = get_tiling_fn(Z, H, N_CTX, HEAD_DIM, causal)
    except Exception as e:
        print(f"    get_tiling raised an exception: {e}")
        return None, None, 0, 0

    if not isinstance(tiling, (tuple, list)) or len(tiling) != 2:
        print(f"    get_tiling must return (BM, BN), got: {tiling}")
        return None, None, 0, 0

    BM, BN = int(tiling[0]), int(tiling[1])

    if BM <= 0 or BN <= 0:
        print(f"    Invalid tiling: BM={BM}, BN={BN} (must be positive)")
        return None, None, BM, BN

    if N_CTX % BM != 0:
        print(f"    Warning: N_CTX={N_CTX} not divisible by BM={BM}")
    if N_CTX % BN != 0:
        print(f"    Warning: N_CTX={N_CTX} not divisible by BN={BN}")

    print(f"    Student tiling: BM={BM}, BN={BN}")

    result_dir = os.path.join(os.getcwd(), f"judge_result_dir_{config_index}")
    if os.path.exists(result_dir):
        shutil.rmtree(result_dir)

    torch.manual_seed(20)
    q = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5)
    k = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5)
    v = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5)

    sm_scale = 0.5
    compressed_len = 2048
    atten_golden_mask = None
    sparse_mode = 0

    if causal:
        atten_golden_mask = torch.triu(
            torch.ones(compressed_len, compressed_len, device=DEVICE), diagonal=1
        ).bool()
        sparse_mode = 2

    # Quick correctness check with student's tiling before profiling
    try:
        ref_check = _reference_attention(q, k, v, H, atten_golden_mask, sm_scale, sparse_mode)[0]
        tri_check = attention_fn(q, k, v, causal, sm_scale, BM, BN).to(dtype)
        # allclose broadcasts, so insist on the exact output shape first.
        if tri_check.shape != ref_check.shape:
            print(f"    Correctness FAILED with BM={BM}, BN={BN} "
                  f"(output shape {tuple(tri_check.shape)} != expected {tuple(ref_check.shape)})")
            return None, None, BM, BN
        if not torch.allclose(ref_check, tri_check, atol=1e-2, rtol=1e-2):
            max_diff = (ref_check - tri_check).abs().max().item()
            print(f"    Correctness FAILED with BM={BM}, BN={BN} (max_diff={max_diff:.6f})")
            return None, None, BM, BN
        print("    Correctness OK with student tiling")
    except Exception as e:
        print(f"    Kernel crashed with BM={BM}, BN={BN}: {e}")
        return None, None, BM, BN

    experimental_config = torch_npu.profiler._ExperimentalConfig(
        aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
        profiler_level=torch_npu.profiler.ProfilerLevel.Level1,
        l2_cache=False,
        data_simplification=False,
    )
    active = 30

    torch_dir = os.path.join(result_dir, "torch")
    triton_dir = os.path.join(result_dir, "triton")
    _profile_steps(
        lambda: _reference_attention(q, k, v, H, atten_golden_mask, sm_scale, sparse_mode),
        torch_dir, experimental_config, active,
    )
    _profile_steps(
        lambda: attention_fn(q, k, v, causal, sm_scale, BM, BN).to(dtype),
        triton_dir, experimental_config, active,
    )

    ref_time = fetch_op_time(torch_dir, op_types=REFERENCE_OP_TYPES)
    student_time = fetch_op_time(triton_dir, op_types=STUDENT_OP_TYPES)
    return ref_time, student_time, BM, BN


def run_performance_grading(attention_fn, get_tiling_fn) -> Tuple[float, List[dict]]:
    """Profile every PERF_CONFIGS entry and average the ref/student time ratios.

    Each config contributes ``ref_time / student_time``. A config that fails
    (exception, missing profiler data, non-positive student time) contributes
    0 but still counts in the denominator. Ratios are not capped, so the
    average can exceed 1.0.

    Returns:
        (avg_ratio, details). avg_ratio is a float, and 0.0 when there are no
        configs. details holds one dict per config; failed configs carry an
        "error" key.
    """
    ratios: List[float] = []
    details: List[dict] = []
    n_configs = len(PERF_CONFIGS)

    def record_failure(config_str: str, error: str) -> None:
        # Failed configs score 0 but still count toward the average.
        details.append({"config": config_str, "ratio": 0, "error": error})
        ratios.append(0)

    for i, (Z, H, N_CTX, HEAD_DIM, causal, dtype) in enumerate(PERF_CONFIGS):
        config_str = f"Z={Z} H={H} N_CTX={N_CTX} HEAD_DIM={HEAD_DIM} causal={causal}"
        print(f"\n  Config {i+1}/{n_configs}: {config_str}")

        try:
            ref_time, student_time, BM, BN = profile_single_config(
                attention_fn, get_tiling_fn, Z, H, N_CTX, HEAD_DIM, causal, dtype, i
            )
        except Exception as e:
            print(f"    Profiling failed: {e}")
            print("    This config scores 0")
            record_failure(config_str, str(e))
            continue

        if ref_time is None or student_time is None:
            print("    Error: could not extract profiling data — this config scores 0")
            record_failure(config_str, "missing profiling data")
            continue

        if student_time <= 0:
            print("    Error: student kernel time is 0 — this config scores 0")
            record_failure(config_str, "student_time is 0")
            continue

        ratio = ref_time / student_time
        ratios.append(ratio)
        print(f"    Reference: {ref_time:.1f} us | Student: {student_time:.1f} us | Ratio: {ratio:.4f}")
        details.append({
            "config": config_str, "BM": BM, "BN": BN,
            "ref_time_us": ref_time, "student_time_us": student_time, "ratio": ratio,
        })

    avg_ratio = sum(ratios) / n_configs if n_configs else 0.0
    n_success = sum(1 for r in ratios if r > 0)
    print(f"\n  Configs succeeded: {n_success}/{n_configs}")
    print(f"  Average ratio: {avg_ratio:.4f}")
    return avg_ratio, details


# ============================================================================
# 4. MAIN JUDGE
# ============================================================================

@dataclass
class JudgeResult:
    """Everything judge() records about one submission.

    main() serialises this with dataclasses.asdict() for --output-json, so the
    field order is the JSON key order. It is also the positional constructor
    order.
    """
    submission_file: str             # basename of the submission path
    timestamp: str                   # local time, "%Y-%m-%d %H:%M:%S"
    static_analysis_passed: bool = False
    static_violations: List[str] = field(default_factory=list)     # messages from check_submission()
    correctness_passed: bool = False                               # True only if every test passed
    tests_passed: int = 0
    tests_total: int = 0                                           # 0 => suite never ran (e.g. load failure)
    correctness_details: List[dict] = field(default_factory=list)  # {"test", "config", "result"} per test
    correctness_grade: float = 0.0   # out of 40
    perf_details: List[dict] = field(default_factory=list)         # one dict per PERF_CONFIGS entry
    perf_grade: float = 0.0          # nominally out of 60; ratios are uncapped so it can exceed 60
    grade: int = 0                   # final total (correctness_grade + perf_grade), integer


def load_submission(filepath: str):
    """Import the file at *filepath* as a fresh module named ``submission``.

    The file is always loaded as Python source through SourceFileLoader.
    This way a submission saved without a ``.py`` extension still loads; the
    default spec lookup returns None for unknown extensions. Loading executes
    the submission's top-level code.

    Raises:
        ImportError: if no module spec can be built for *filepath*.
        Any exception raised by the submission's top-level code propagates.
    """
    loader = importlib.machinery.SourceFileLoader("submission", filepath)
    spec = importlib.util.spec_from_file_location("submission", filepath, loader=loader)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot build an import spec for '{filepath}'")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def judge(submission_path: str, skip_perf: bool = False) -> JudgeResult:
    """Grade one submission end to end and return the populated JudgeResult.

    Steps:
      1. Static analysis. Any violation (or unreadable source) rejects the
         submission with grade 0.
      2. Import. The submission must export ``attention`` and ``get_tiling``.
      3. Correctness suite, scored proportionally out of 40.
      4. Performance profiling out of 60, run only when every correctness test
         passed and *skip_perf* is False.
    Progress is reported with print(). On early rejection the partially filled
    result is returned with grade 0.
    """
    result = JudgeResult(
        submission_file=os.path.basename(submission_path),
        timestamp=time.strftime("%Y-%m-%d %H:%M:%S"),
    )

    print("=" * 60)
    print("STEP 1: Static Analysis (Anti-Cheat)")
    print("=" * 60)
    try:
        # tokenize.open decodes the file as the interpreter will when importing
        # it (PEP 263 coding cookie / BOM, otherwise UTF-8), instead of relying
        # on the judge host's locale encoding.
        with tokenize.open(submission_path) as f:
            source = f.read()
    except (OSError, SyntaxError, UnicodeDecodeError) as e:
        violations = [f"Could not read submission source: {e}"]
    else:
        violations = check_submission(source)

    if violations:
        result.static_violations = violations
        print("  REJECTED — violations found:")
        for v in violations:
            print(f"    • {v}")
        print("\n  Grade: 0")
        return result

    result.static_analysis_passed = True
    print("  PASSED")

    print("\n" + "=" * 60)
    print("STEP 2: Loading Submission")
    print("=" * 60)
    try:
        module = load_submission(submission_path)
    except (Exception, SystemExit) as e:
        # SystemExit included: a stray sys.exit() in the submission's top-level
        # code must fail the load, not terminate the judge.
        print(f"  FAILED: {e}")
        traceback.print_exc()
        return result

    # Check exports separately from the import, so an AttributeError raised
    # *inside* the submission's top-level code is not reported as a missing export.
    missing = [name for name in ("attention", "get_tiling") if not hasattr(module, name)]
    if missing:
        print(f"  FAILED: missing required export — {', '.join(missing)}")
        print("  Submission must define both 'attention' and 'get_tiling'")
        return result
    attention_fn = module.attention
    get_tiling_fn = module.get_tiling
    print("  Loaded attention and get_tiling successfully")

    print("\n" + "=" * 60)
    print("STEP 3: Correctness Testing  (40 pts)")
    print("=" * 60)
    all_passed, n_passed, n_total, details = run_correctness_gate(attention_fn, get_tiling_fn)
    result.tests_passed = n_passed
    result.tests_total = n_total
    result.correctness_details = details
    result.correctness_passed = all_passed

    # Correctness grade: always proportional (n_passed / n_total) * 40
    result.correctness_grade = round((n_passed / n_total) * 40, 1)
    print(f"\n  Result: {n_passed}/{n_total} tests passed")
    print(f"  Correctness grade: {result.correctness_grade} / 40")

    if skip_perf:
        print("\n  Performance profiling skipped (--skip-perf)")
        result.perf_grade = 0.0
    elif not all_passed:
        print("\n  Performance grade: 0 / 60  (requires 100% correctness)")
        result.perf_grade = 0.0
    else:
        print("\n" + "=" * 60)
        print("STEP 4: Performance Testing  (60 pts)")
        print(f"  Running {len(PERF_CONFIGS)} configurations...")
        print("  Grade = avg( ref_time_i / student_time_i ) x 60")
        print("=" * 60)
        try:
            avg_ratio, perf_details = run_performance_grading(attention_fn, get_tiling_fn)
            result.perf_grade = round(avg_ratio * 60, 1)
            result.perf_details = perf_details
        except (Exception, SystemExit) as e:
            print(f"  Performance profiling failed: {e}")
            traceback.print_exc()
            result.perf_grade = 0.0

    # Both parts are rounded to 0.1, so the exact sum has one decimal. Round
    # before truncating so float noise (e.g. 59.999999...) cannot drop a point.
    result.grade = int(round(result.correctness_grade + result.perf_grade, 1))

    _print_summary(result)
    return result


def _print_summary(result: JudgeResult):
    """Print the boxed FINAL RESULTS banner for *result*. Returns nothing."""
    width = 60
    heavy_rule = "═" * width
    light_rule = "─" * width

    def mark(ok, fail_text):
        return "✓ PASS" if ok else fail_text

    print("\n" + heavy_rule)
    print("║" + " FINAL RESULTS ".center(width - 2) + "║")
    print(heavy_rule)
    print(f"  Submission:      {result.submission_file}")
    print(f"  Timestamp:       {result.timestamp}")
    print(f"  Static Analysis: {mark(result.static_analysis_passed, '✗ FAIL')}")
    print(f"  Correctness:     {result.tests_passed}/{result.tests_total}"
          f" ({mark(result.correctness_passed, '✗ PARTIAL')})")
    print(light_rule)
    print(f"  Correctness score:   {result.correctness_grade:5.1f} / 40")
    if result.correctness_passed:
        perf_note = ""
    else:
        perf_note = "  (0 — requires 100% correctness)"
    print(f"  Performance score:   {result.perf_grade:5.1f} / 60{perf_note}")
    print(light_rule)

    grade = result.grade
    # First threshold the grade reaches wins; below 40 fall through.
    tiers = ((100, "Outstanding"), (80, "Excellent"), (60, "Good"), (40, "Fair"))
    label = next((name for floor, name in tiers if grade >= floor), None)
    if label is None:
        label = "Needs work" if grade > 0 else "No score"

    bar_width = 40
    filled = min(grade, 100) * bar_width // 100
    bar = "█" * filled + "░" * (bar_width - filled)

    print(f"\n  GRADE:  {grade} / 100  ({label})")
    print(f"  [{bar}]")
    print("\n" + heavy_rule)


def build_comment(result: JudgeResult) -> str:
    """Render the per-testcase summary used as the JSON ``comment`` field.

    Static-analysis rejections list the violations. A load failure (no tests
    run) gets a one-line message. Otherwise the result has a correctness
    section with one line per test, a blank line, and a performance section.
    """
    if not result.static_analysis_passed:
        rejected = ["Static analysis FAILED — submission rejected:"]
        rejected.extend(f"  - {v}" for v in result.static_violations)
        return "\n".join(rejected)

    # Static analysis passed but the correctness suite never ran.
    if result.tests_total == 0:
        return ("Failed to load submission "
                "(could not import 'attention' / 'get_tiling').")

    out = [
        f"Correctness: {result.tests_passed}/{result.tests_total} "
        f"tests passed  ({result.correctness_grade:g} / 40 pts)"
    ]
    out += [
        f"  Test {d['test']:2d}: {d['result']}  |  {d['config']}"
        for d in result.correctness_details
    ]
    out.append("")

    if not result.correctness_passed:
        out.append("Performance: 0 / 60 pts  (skipped — requires 100% correctness)")
        return "\n".join(out)

    if not result.perf_details:
        out.append(f"Performance: {result.perf_grade:g} / 60 pts  (no per-config detail)")
        return "\n".join(out)

    out.append(f"Performance: {result.perf_grade:g} / 60 pts")
    for idx, d in enumerate(result.perf_details, start=1):
        if "error" in d:
            out.append(f"  Config {idx}: ERROR — {d['error']}  |  {d['config']}")
            continue
        out.append(
            f"  Config {idx}: ratio={d['ratio']:.4f}  "
            f"(ref={d['ref_time_us']:.1f}us, student={d['student_time_us']:.1f}us)  "
            f"BM={d['BM']} BN={d['BN']}  |  {d['config']}"
        )
    return "\n".join(out)


def main():
    """Command-line entry point.

    Parses arguments and runs judge() with stdout/stderr captured by
    CapturingLogger. It writes the integer grade to --score-file, and
    optionally the full JudgeResult to --output-json. It then prints exactly
    one JSON line {"verdict", "score", "comment"} to the real stdout. A missing
    submission prints a zero-score JSON line and exits with status 1.

    Whenever no grade could be produced (missing file, judge crash), the score
    file is written as 0. This stops the web UI from reading a stale score
    left by an earlier run.
    """
    parser = argparse.ArgumentParser(description="Flash Attention Forward Competition Judge")
    parser.add_argument("submission", help="Path to submission .py file")
    parser.add_argument("--skip-perf", action="store_true",
                        help="Skip performance profiling (correctness check only)")
    parser.add_argument("--score-file", type=str, default="score.txt",
                        help="File to write the final grade (default: score.txt)")
    parser.add_argument("--log-file", type=str, default="judge.log",
                        help="File to write the full judging log (default: judge.log)")
    parser.add_argument("--output-json", type=str, default=None,
                        help="Save structured results to JSON file")
    args = parser.parse_args()

    def write_score(score: int) -> None:
        with open(args.score_file, "w") as f:
            f.write(str(score))

    if not os.path.exists(args.submission):
        try:
            write_score(0)
        except OSError as e:
            # Best effort only; stdout is reserved for the JSON line below.
            print(f"Warning: could not write {args.score_file}: {e}", file=sys.stderr)
        print(json.dumps({
            "verdict": "correctness 0 + performance 0",
            "score": 0,
            "comment": f"Error: '{args.submission}' not found",
        }))
        sys.exit(1)

    logger = CapturingLogger(args.log_file)
    sys.stdout = logger
    sys.stderr = logger

    result = None
    log_content = ""
    try:
        result = judge(args.submission, skip_perf=args.skip_perf)

        write_score(result.grade)
        print(f"\nScore written to {args.score_file}")

        if args.output_json:
            with open(args.output_json, "w") as f:
                json.dump(asdict(result), f, indent=2, default=str)
            print(f"Results saved to {args.output_json}")
    except (Exception, SystemExit):
        # SystemExit too: student code calling sys.exit() must not abort the
        # judge before the JSON verdict is emitted.
        traceback.print_exc()
    finally:
        if result is None:
            # No grade was produced. Overwrite any stale score with 0.
            try:
                write_score(0)
            except OSError:
                traceback.print_exc()
        print(f"Log written to {args.log_file}")
        log_content = logger.get_log()
        logger.close()

    # Build the JSON payload printed to the real stdout
    if result is not None:
        score = result.grade
        verdict = (f"correctness {result.correctness_grade:g} "
                   f"+ performance {result.perf_grade:g}")
        comment = build_comment(result)
    else:
        score = 0
        verdict = "judge crashed"
        comment = log_content  # fall back to full log on hard crash

    print(json.dumps({
        "verdict": verdict,
        "score": score,
        "comment": comment,
    }))


# Script entry point: `python judge.py <submission> [options]` (see module docstring).
if __name__ == "__main__":
    main()