import csv
import dataclasses
import io
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
# Kernel-name prefixes (matched case-sensitively with str.startswith) that
# normalize_kernel_name leaves untouched and classify_kernel reports as
# "runtime_overhead". Note "MEMCPY_" is not covered by "MEM_" ("MEMC" != "MEM_").
_RUNTIME_PREFIXES = ("CAPTURE_", "EVENT_", "MEM_", "MEMCPY_", "NOP", "NOTIFY_")
@dataclasses.dataclass(frozen=True)
class ObservedKernel:
    """One per-kernel row of a Raw Insight table, with derived grouping fields."""

    name: str  # stripped "Name" cell
    normalized_name: str  # normalize_kernel_name(name)
    category: str  # classify_kernel(normalized_name)
    wall_duration_ms: float  # "Wall Duration(ms)"
    self_time_ms: float  # "Self Time(ms)"
    average_wall_duration_ms: float  # "Average Wall Duration(ms)"
    max_wall_duration_ms: float  # "Max Wall Duration(ms)"
    min_wall_duration_ms: float  # "Min Wall Duration(ms)"
    occurrences: int  # "Occurrences"

    @classmethod
    def from_row(cls, row: Dict[str, str]) -> "ObservedKernel":
        """Build a kernel record from one table row keyed by column header.

        Blank or missing numeric cells become 0; a blank or missing "Name"
        raises ValueError, as does a numeric cell that cannot be parsed.
        """
        raw_name = _get_required(row, "Name")
        short_name = normalize_kernel_name(raw_name)
        metrics = {
            "wall_duration_ms": _to_float(row.get("Wall Duration(ms)")),
            "self_time_ms": _to_float(row.get("Self Time(ms)")),
            "average_wall_duration_ms": _to_float(row.get("Average Wall Duration(ms)")),
            "max_wall_duration_ms": _to_float(row.get("Max Wall Duration(ms)")),
            "min_wall_duration_ms": _to_float(row.get("Min Wall Duration(ms)")),
            "occurrences": _to_int(row.get("Occurrences")),
        }
        return cls(
            name=raw_name,
            normalized_name=short_name,
            category=classify_kernel(short_name),
            **metrics,
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a plain dict, in declaration order."""
        return {field.name: getattr(self, field.name) for field in dataclasses.fields(self)}
@dataclasses.dataclass(frozen=True)
class InsightTotals:
    """Aggregate values taken from the "Totals" row of a Raw Insight table."""

    wall_duration_ms: float  # "Wall Duration(ms)"
    self_time_ms: float  # "Self Time(ms)"
    average_wall_duration_ms: float  # "Average Wall Duration(ms)"
    max_wall_duration_ms: float  # "Max Wall Duration(ms)"
    min_wall_duration_ms: float  # "Min Wall Duration(ms)"
    occurrences: int  # "Occurrences"

    @classmethod
    def from_row(cls, row: Dict[str, str]) -> "InsightTotals":
        """Build totals from one table row keyed by column header.

        Blank or missing cells become 0; unparseable numbers raise ValueError.
        """
        float_columns = {
            "wall_duration_ms": "Wall Duration(ms)",
            "self_time_ms": "Self Time(ms)",
            "average_wall_duration_ms": "Average Wall Duration(ms)",
            "max_wall_duration_ms": "Max Wall Duration(ms)",
            "min_wall_duration_ms": "Min Wall Duration(ms)",
        }
        values: Dict[str, Any] = {
            field: _to_float(row.get(column)) for field, column in float_columns.items()
        }
        values["occurrences"] = _to_int(row.get("Occurrences"))
        return cls(**values)

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a plain dict, in declaration order."""
        return {field.name: getattr(self, field.name) for field in dataclasses.fields(self)}
@dataclasses.dataclass(frozen=True)
class RawInsightSummary:
    """Parsed Raw Insight table: per-kernel rows plus the "Totals" row."""

    kernels: List[ObservedKernel]  # per-kernel rows (file order when built by load_raw_insight)
    totals: InsightTotals  # values from the "Totals" row
    source_path: Optional[str] = None  # file the summary came from, if known

    @property
    def total_wall_duration_ms(self) -> float:
        """Wall duration reported by the Totals row, in milliseconds."""
        return self.totals.wall_duration_ms

    def top_kernels(self, limit: int = 20) -> List[ObservedKernel]:
        """Return at most ``limit`` kernels with the largest wall duration.

        Kernels are ordered by descending ``wall_duration_ms``; ties keep their
        relative order in ``kernels`` (``sorted`` is stable).

        Raises:
            ValueError: If ``limit`` is negative. A negative slice bound would
                otherwise silently drop kernels from the end of the ranking.
        """
        if limit < 0:
            raise ValueError(f"limit must be non-negative, got {limit}.")
        ranked = sorted(self.kernels, key=lambda item: item.wall_duration_ms, reverse=True)
        return ranked[:limit]

    def to_dict(self, top_n: Optional[int] = None) -> Dict[str, Any]:
        """Serialize the summary to plain dicts.

        Args:
            top_n: When given, only the ``top_n`` longest kernels are included
                (see ``top_kernels``); otherwise all kernels in stored order.
        """
        kernels = self.kernels if top_n is None else self.top_kernels(top_n)
        return {
            "source_path": self.source_path,
            "totals": self.totals.to_dict(),
            "total_wall_duration_ms": self.total_wall_duration_ms,
            "kernels": [kernel.to_dict() for kernel in kernels],
        }
def _get_required(row: Dict[str, str], key: str) -> str:
    """Return the whitespace-stripped text of ``row[key]``.

    Raises:
        ValueError: If the column is absent, ``None``, or blank after stripping.
    """
    raw = row.get(key)
    cleaned = "" if raw is None else str(raw).strip()
    if cleaned:
        return cleaned
    raise ValueError(f"Raw Insight row is missing required column {key!r}.")
def _to_float(value: Optional[str]) -> float:
    """Parse a numeric cell as float, treating ``None`` or blank text as 0.0.

    Raises ValueError (from ``float``) when non-blank text is not a number.
    """
    text = "" if value is None else str(value).strip()
    return float(text) if text else 0.0
def _to_int(value: Optional[str]) -> int:
    """Parse a count cell as int, treating ``None`` or blank text as 0.

    Integer text is parsed exactly. Text in float form (e.g. "3.0" or "1e3")
    is accepted and truncated toward zero.

    Raises:
        ValueError: If the text is not numeric (or is NaN).
        OverflowError: If the text is an infinite float.
    """
    text = "" if value is None else str(value).strip()
    if not text:
        return 0
    try:
        # Parse integral text directly: routing through float would round
        # values above 2**53 (e.g. "9007199254740993").
        return int(text)
    except ValueError:
        return int(float(text))
def normalize_kernel_name(name: str) -> str:
    """Reduce a raw kernel name to a grouping key.

    After stripping surrounding whitespace:
    * names starting with an entry of ``_RUNTIME_PREFIXES`` are returned as-is;
    * names starting with "_" only lose a trailing ``_<digits>`` suffix;
    * any other name is cut at its first underscore (e.g. "MatMul_1_2" -> "MatMul").
    """
    cleaned = name.strip()
    if cleaned.startswith(_RUNTIME_PREFIXES):
        return cleaned
    if cleaned[:1] == "_":
        # Splitting at "_" would leave an empty head, so only drop a numeric tail.
        return re.sub(r"_[0-9]+$", "", cleaned)
    head, _, _ = cleaned.partition("_")
    return head if head else cleaned
def classify_kernel(normalized_name: str) -> str:
    """Map a normalized kernel name to a coarse category label.

    Runtime prefixes are matched case-sensitively; every other rule is matched
    against the lower-cased name. Rules are tried in order and the first match
    wins, so e.g. a name containing both "allgather" and "matmul" is
    "communication". Unmatched names are "other".
    """
    if normalized_name.startswith(_RUNTIME_PREFIXES):
        return "runtime_overhead"
    lowered = normalized_name.lower()
    # (category, substrings matched anywhere, prefixes matched at the start).
    # "inferattentionscore" and "batchmatmul" need no entries of their own:
    # they already contain "attention" and "matmul".
    rules = (
        ("communication", ("allreduce", "allgather", "alltoall"), ("hcom",)),
        ("attention", ("attention",), ()),
        ("moe", ("moe", "dispatchffncombine", "gatingtopk"), ()),
        ("matmul", ("matmul",), ()),
        ("quant", ("quant",), ()),
        ("norm", ("norm",), ()),
        ("cast", (), ("cast",)),
    )
    for category, fragments, prefixes in rules:
        # str.startswith(()) is False, so rules without prefixes rely on fragments.
        if lowered.startswith(prefixes) or any(fragment in lowered for fragment in fragments):
            return category
    return "other"
def _sniff_dialect(sample: str) -> csv.Dialect:
    """Guess whether ``sample`` is tab- or comma-separated.

    Falls back to ``csv.excel_tab`` when ``csv.Sniffer`` cannot decide.
    """
    sniffer = csv.Sniffer()
    try:
        dialect = sniffer.sniff(sample, delimiters="\t,")
    except csv.Error:
        dialect = csv.excel_tab
    return dialect
def _sniff_sample(content: str, size: int = 4096) -> str:
    """Return at most ``size`` leading characters of ``content`` for dialect sniffing.

    When ``content`` is longer than ``size``, the sample is cut back to its last
    newline. A truncated final row has fewer delimiters than complete rows and
    may skew ``csv.Sniffer``'s guess.
    """
    if len(content) <= size:
        return content
    head = content[:size]
    last_newline = head.rfind("\n")
    # Keep the raw prefix if the first `size` characters contain no line break.
    return head[: last_newline + 1] if last_newline >= 0 else head


def load_raw_insight(path: Union[str, Path]) -> RawInsightSummary:
    """Parse a Raw Insight summary table into a RawInsightSummary.

    The file is decoded as UTF-8 (a leading byte-order mark is dropped). The
    delimiter (tab or comma) is sniffed from the start of the file. The first
    non-blank data row must be the "Totals" row. Every later row becomes an
    ObservedKernel, except additional "Totals" rows, which are skipped.

    Args:
        path: Location of the exported table.

    Returns:
        The kernels in file order, the totals, and ``str(path)`` as source_path.

    Raises:
        OSError: If the file cannot be read.
        ValueError: If the file is empty, its header has no "Name" column, the
            first data row is not "Totals", there is no data row at all, or a
            row is malformed. Row errors carry the file name and line number.
    """
    insight_path = Path(path)
    content = insight_path.read_text(encoding="utf-8-sig")
    if not content.strip():
        raise ValueError(f"Raw Insight file {insight_path} is empty.")

    dialect = _sniff_dialect(_sniff_sample(content))
    # Use a file-like object with newline="" instead of content.splitlines().
    # splitlines() also splits on characters such as "\x1c" or "\u2028" that
    # can occur inside a field. The csv docs also require newline="" so that
    # line breaks embedded in quoted fields are interpreted correctly.
    reader = csv.DictReader(io.StringIO(content, newline=""), dialect=dialect)
    fieldnames = reader.fieldnames or []
    if "Name" not in fieldnames:
        # Typically a sign the wrong delimiter was sniffed; fail once, clearly.
        raise ValueError(
            f"Raw Insight file {insight_path}: header row has no 'Name' column "
            f"(found {fieldnames!r})."
        )

    kernels: List[ObservedKernel] = []
    totals: Optional[InsightTotals] = None
    for row in reader:
        # Skip rows whose cells are all empty (e.g. a line of bare delimiters).
        if not any(row.values()):
            continue
        try:
            name = _get_required(row, "Name")
            if totals is None:
                # The first data row must be the Totals row.
                if name != "Totals":
                    raise ValueError("'Totals' row must immediately follow the header.")
                totals = InsightTotals.from_row(row)
            elif name != "Totals":
                kernels.append(ObservedKernel.from_row(row))
        except ValueError as exc:
            # Give row-level parse errors (missing Name, bad numbers) a location.
            raise ValueError(
                f"Raw Insight file {insight_path} line {reader.line_num}: {exc}"
            ) from exc

    if totals is None:
        raise ValueError(
            f"Raw Insight file {insight_path} must include a 'Totals' row after the header."
        )
    return RawInsightSummary(kernels=kernels, totals=totals, source_path=str(insight_path))