import csv
import logging
import os
from collections import defaultdict
from typing import Dict, Iterable, List, Optional
from .models import CATEGORY_ORDER, CATEGORY_TITLES, Violation
# Module-level logger named after this module; ViolationReport sends its
# console summary and its "report written" notice through it.
logger = logging.getLogger(__name__)
class ViolationReport:
    """Accumulate violations and render them to the log or to a CSV file.

    The report can optionally be given two lookup tables, keyed by slot
    index, that supply a tensor name and a function name for each slot.
    They are used to label violations in the console output and to fill
    the ``tensor`` and ``func`` columns of the CSV export.
    """

    # Column order of the CSV file produced by ``save_csv``.
    CSV_HEADER = ["category", "rule", "slot", "tensor", "func", "cell", "message"]

    def __init__(
        self,
        slot_tensor_names: Optional[Dict[int, str]] = None,
        slot_func_names: Optional[Dict[int, str]] = None,
    ):
        """Create an empty report.

        Args:
            slot_tensor_names: Mapping from slot index to tensor name.
                ``None`` (or an empty mapping) means no names are known.
            slot_func_names: Mapping from slot index to function name.
                ``None`` (or an empty mapping) means no names are known.
        """
        self._violations: List[Violation] = []
        self._slot_tensor_names: Dict[int, str] = slot_tensor_names or {}
        self._slot_func_names: Dict[int, str] = slot_func_names or {}

    @property
    def violations(self) -> List[Violation]:
        """Return the internal list of collected violations (not a copy)."""
        return self._violations

    def extend(self, violations: Iterable[Violation]):
        """Append every item of ``violations`` to the report.

        ``list.extend`` is used rather than an append-per-item loop: when it
        is handed a list it copies that list by its length at call time, so
        passing the report's own ``violations`` list duplicates the entries
        instead of iterating forever over a list that keeps growing.
        """
        self._violations.extend(violations)

    def has_failure(self) -> bool:
        """Return ``True`` when at least one violation has been collected."""
        return bool(self._violations)

    def print_console(self) -> None:
        """Log a human-readable summary of the report at INFO level.

        With no violations a single "PASS" line is logged. Otherwise a
        "FAIL" line with the total count is followed by one section per
        category: categories listed in ``CATEGORY_ORDER`` come first, in
        that order, then any remaining categories in the order they were
        first encountered. Violations with a falsy category are grouped
        under "Other".
        """
        if not self._violations:
            logger.info("PASS")
            return

        grouped: Dict[str, List[Violation]] = defaultdict(list)
        for v in self._violations:
            grouped[v.category or "Other"].append(v)

        logger.info("FAIL: %d issue(s) detected.", len(self._violations))

        # Known categories first (fixed order), unknown ones afterwards in
        # dict insertion order, i.e. the order of first appearance.
        ordered_categories = [c for c in CATEGORY_ORDER if c in grouped]
        ordered_categories += [c for c in grouped if c not in CATEGORY_ORDER]

        for cat in ordered_categories:
            # Fall back to the raw category key when no title is registered.
            title = CATEGORY_TITLES.get(cat, cat)
            logger.info("\n[%s]", title)
            for v in grouped[cat]:
                logger.info(" - %s: %s", self._format_subject(v), v.message)

    def save_csv(self, path: str) -> None:
        """Write all violations to ``path`` as CSV, creating parent directories.

        The file is encoded as UTF-8 with a BOM (``utf-8-sig``) and starts
        with ``CSV_HEADER``. A missing slot or cell index is written as an
        empty field; tensor and function names are looked up from the slot
        index and are empty when unknown.

        Args:
            path: Destination file path; an existing file is overwritten.

        Raises:
            OSError: If the directory cannot be created or the file cannot
                be opened or written.
        """
        parent = os.path.dirname(os.path.abspath(path))
        if parent:
            os.makedirs(parent, exist_ok=True)

        # newline="" lets the csv module control line endings itself.
        with open(path, "w", encoding="utf-8-sig", newline="") as f:
            w = csv.writer(f)
            w.writerow(self.CSV_HEADER)
            w.writerows(
                [
                    v.category,
                    v.rule_id,
                    "" if v.slot_idx is None else v.slot_idx,
                    self._tensor_of(v.slot_idx),
                    self._func_of(v.slot_idx),
                    "" if v.cell_idx is None else v.cell_idx,
                    v.message,
                ]
                for v in self._violations
            )

        logger.info("report written to: %s", path)

    def _format_subject(self, v: Violation) -> str:
        """Return the label that identifies which tensor a violation is about.

        Produces "(no tensor)" when the violation has no slot index;
        otherwise the tensor name (or "unnamed tensor" when none is known)
        followed by the slot index and, if present, the cell index.
        """
        if v.slot_idx is None:
            return "(no tensor)"

        tensor = self._tensor_of(v.slot_idx)
        slot_label = f"slot={v.slot_idx}"
        if v.cell_idx is not None:
            slot_label = f"{slot_label}, cell={v.cell_idx}"

        if tensor:
            return f"tensor '{tensor}' ({slot_label})"
        return f"unnamed tensor ({slot_label})"

    def _func_of(self, slot_idx: Optional[int]) -> str:
        """Return the function name for ``slot_idx``, or "" when unknown."""
        if slot_idx is None:
            return ""
        # "or" also maps a stored None/empty value to "".
        return self._slot_func_names.get(slot_idx, "") or ""

    def _tensor_of(self, slot_idx: Optional[int]) -> str:
        """Return the tensor name for ``slot_idx``, or "" when unknown."""
        if slot_idx is None:
            return ""
        # "or" also maps a stored None/empty value to "".
        return self._slot_tensor_names.get(slot_idx, "") or ""