import dataclasses
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import yaml
@dataclasses.dataclass(frozen=True)
class ExpectedValue:
    """An expected time value together with the tolerances used to compare it.

    A measured value is accepted when it lies within the larger of the
    relative allowance (``abs(time_s) * rel_tolerance``) and, if configured,
    the absolute allowance (``abs_tolerance_s``).
    """

    # Expected time; the ``_s`` suffix suggests seconds.
    time_s: float
    # Allowed deviation expressed as a fraction of ``abs(time_s)``.
    rel_tolerance: float = 0.2
    # Optional absolute allowance; ``None`` means only the relative one applies.
    abs_tolerance_s: Optional[float] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any], default_rel_tolerance: float) -> "ExpectedValue":
        """Build an instance from a plain mapping.

        Args:
            data: Mapping holding ``time_s`` and, optionally,
                ``rel_tolerance`` and ``abs_tolerance_s``.
            default_rel_tolerance: Relative tolerance applied when ``data``
                has no ``rel_tolerance`` key.

        Returns:
            The parsed ``ExpectedValue`` with numeric fields coerced to float.

        Raises:
            ValueError: If ``data`` has no ``time_s`` key.
        """
        if "time_s" not in data:
            raise ValueError("Expected time value must include 'time_s'.")

        expected_time = float(data["time_s"])
        relative = float(data.get("rel_tolerance", default_rel_tolerance))
        # A missing key and an explicit null both mean "no absolute tolerance".
        raw_absolute = data.get("abs_tolerance_s")
        if raw_absolute is None:
            absolute = None
        else:
            absolute = float(raw_absolute)
        return cls(time_s=expected_time, rel_tolerance=relative, abs_tolerance_s=absolute)

    def matches(self, actual_s: float) -> bool:
        """Return True when ``actual_s`` is within tolerance of ``time_s``."""
        allowances = [abs(self.time_s) * self.rel_tolerance]
        if self.abs_tolerance_s is not None:
            allowances.append(self.abs_tolerance_s)
        # The most permissive configured allowance wins.
        deviation = abs(actual_s - self.time_s)
        return deviation <= max(allowances)
@dataclasses.dataclass(frozen=True)
class ExpectedOp:
    """Expectation for one major op: how often it occurs and its total time.

    The occurrence count can be pinned exactly (``count``) or bounded
    (``count_min`` / ``count_max``); an exact count takes precedence when
    matching.
    """

    # Name of the op this expectation applies to.
    name: str
    # Exact expected occurrence count, if any.
    count: Optional[int] = None
    # Inclusive lower bound on the occurrence count, if any.
    count_min: Optional[int] = None
    # Inclusive upper bound on the occurrence count, if any.
    count_max: Optional[int] = None
    # Expected total time for the op, if provided.
    total_time: Optional[ExpectedValue] = None
    # Relative tolerance; also applied to ``total_time`` when it is built.
    rel_tolerance: float = 0.3
    # Free-form confidence label; not interpreted by this class.
    confidence: str = "high"

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ExpectedOp":
        """Build an instance from a plain mapping.

        ``count`` may be a scalar (exact count) or a mapping with ``min`` /
        ``max`` keys, in which case it replaces any top-level ``count_min`` /
        ``count_max`` and no exact count is recorded.

        Args:
            data: Mapping with ``name`` and optional ``count``, ``count_min``,
                ``count_max``, ``rel_tolerance``, ``total_time_s``,
                ``abs_tolerance_s`` and ``confidence`` keys.

        Returns:
            The parsed ``ExpectedOp``.

        Raises:
            ValueError: If ``data`` has no ``name`` key.
        """
        if "name" not in data:
            raise ValueError("Each major op expectation must include 'name'.")
        count = data.get("count")
        count_min = data.get("count_min")
        count_max = data.get("count_max")
        if isinstance(count, dict):
            # Range form: ``count: {min: ..., max: ...}``.
            count_min = count.get("min")
            count_max = count.get("max")
            count = None
        rel_tolerance = float(data.get("rel_tolerance", 0.3))
        total_time = None
        # A key left blank in YAML loads as None; treat that the same as an
        # absent key instead of failing in float(None).
        raw_total_time = data.get("total_time_s")
        if raw_total_time is not None:
            raw_abs_tolerance = data.get("abs_tolerance_s")
            total_time = ExpectedValue(
                time_s=float(raw_total_time),
                rel_tolerance=rel_tolerance,
                abs_tolerance_s=(None if raw_abs_tolerance is None else float(raw_abs_tolerance)),
            )
        return cls(
            name=str(data["name"]),
            count=None if count is None else int(count),
            count_min=None if count_min is None else int(count_min),
            count_max=None if count_max is None else int(count_max),
            total_time=total_time,
            rel_tolerance=rel_tolerance,
            confidence=str(data.get("confidence", "high")),
        )

    def count_matches(self, actual_count: int) -> bool:
        """Return True when ``actual_count`` satisfies the count expectation.

        An exact ``count`` is checked alone; otherwise the value must respect
        whichever of ``count_min`` / ``count_max`` are set. With no count
        constraints at all, every value matches.
        """
        if self.count is not None:
            return actual_count == self.count
        if self.count_min is not None and actual_count < self.count_min:
            return False
        if self.count_max is not None and actual_count > self.count_max:
            return False
        return True
@dataclasses.dataclass(frozen=True)
class EvidenceCase:
    """One profiling evidence case: its input and the expectations for it."""

    # Case name; "default" when none is given.
    name: str
    # Input description for the case, copied from the ``input`` section.
    input: Dict[str, Any]
    # Expected total forward time, if provided under ``expected``.
    total_forward: Optional[ExpectedValue]
    # Per-op expectations from ``expected.major_ops``.
    major_ops: List[ExpectedOp]
    # The remaining fields are copied from the source mapping unchanged.
    notes: List[str] = dataclasses.field(default_factory=list)
    accepted_gaps: List[str] = dataclasses.field(default_factory=list)
    observed_kernels: List[Dict[str, Any]] = dataclasses.field(default_factory=list)
    shape_hints: Dict[str, Any] = dataclasses.field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "EvidenceCase":
        """Build a case from a plain mapping.

        Sections that are present but null (for example a YAML key left
        blank) are treated like missing sections rather than raising
        ``AttributeError`` / ``TypeError``.

        Args:
            data: Mapping with optional ``name``, ``input``, ``expected``,
                ``notes``, ``accepted_gaps``, ``observed_kernels`` and
                ``shape_hints`` keys.

        Returns:
            The parsed ``EvidenceCase``.

        Raises:
            ValueError: Propagated from the nested expectation parsers when a
                required key is missing.
        """
        # ``dict.get(key, default)`` still returns None for a null value, so
        # normalise with ``or`` before using the section.
        expected = data.get("expected") or {}
        total_forward = None
        if expected.get("total_forward") is not None:
            total_forward = ExpectedValue.from_dict(expected["total_forward"], 0.2)
        raw_name = data.get("name")
        return cls(
            # Avoid turning a null name into the literal string "None".
            name="default" if raw_name is None else str(raw_name),
            input=dict(data.get("input") or {}),
            total_forward=total_forward,
            major_ops=[ExpectedOp.from_dict(op) for op in (expected.get("major_ops") or [])],
            notes=list(data.get("notes") or []),
            accepted_gaps=list(data.get("accepted_gaps") or []),
            observed_kernels=list(data.get("observed_kernels") or []),
            shape_hints=dict(data.get("shape_hints") or {}),
        )
@dataclasses.dataclass(frozen=True)
class EvidenceDocument:
    """Top-level profiling evidence document: a version, model info and cases."""

    # Schema version; only version 1 is accepted by ``from_dict``.
    version: int
    # Model description, copied from the ``model`` section.
    model: Dict[str, Any]
    # Parsed evidence cases; ``from_dict`` guarantees at least one.
    cases: List[EvidenceCase]

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "EvidenceDocument":
        """Build a document from a plain mapping.

        Args:
            data: Mapping with optional ``version`` (defaults to 1) and
                ``model`` keys, and a ``cases`` list.

        Returns:
            The parsed ``EvidenceDocument``.

        Raises:
            ValueError: If the version is not 1, or if ``cases`` is missing,
                null or empty.
        """
        version = int(data.get("version", 1))
        if version != 1:
            raise ValueError(f"Unsupported profiling evidence version: {version}")
        # A null ``cases`` section (YAML key left blank) must reach the
        # "at least one case" error below instead of failing on iteration.
        cases = [EvidenceCase.from_dict(case) for case in (data.get("cases") or [])]
        if not cases:
            raise ValueError("Profiling evidence must contain at least one case.")
        # Likewise, a null ``model`` section is treated as an empty mapping.
        return cls(version=version, model=dict(data.get("model") or {}), cases=cases)
def load_evidence(path: Union[str, Path]) -> EvidenceDocument:
    """Read a YAML profiling-evidence file and parse it into a document.

    Args:
        path: Location of the YAML file, as a string or ``Path``.

    Returns:
        The ``EvidenceDocument`` built from the file's top-level mapping.

    Raises:
        ValueError: If the YAML root is not a mapping. An empty file is
            treated as an empty mapping and handed to
            ``EvidenceDocument.from_dict``, which performs its own validation.
    """
    with Path(path).open("r", encoding="utf-8") as stream:
        payload = yaml.safe_load(stream)
    # An empty document (or any falsy root) is treated as an empty mapping.
    if not payload:
        payload = {}
    if isinstance(payload, dict):
        return EvidenceDocument.from_dict(payload)
    raise ValueError("Profiling evidence root must be a mapping.")