import json
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import auto, Enum
from typing import Any, Dict, List, Optional, TYPE_CHECKING
if TYPE_CHECKING:
from ..op_invoke_info import OpInvokeInfo
class QuerySource(Enum):
    """How the latency carried by a ``QueryResult`` was obtained.

    Values are produced by ``auto()`` (1..4 in declaration order). The
    per-member notes below are inferred from the member names only — confirm
    against the data sources that emit them.
    """

    MEASURED = auto()  # presumably read directly from profiling data
    INTERPOLATED = auto()  # presumably estimated between known data points
    EXTRAPOLATED = auto()  # presumably estimated beyond known data points
    PARTIAL = auto()  # presumably only part of the op was covered — confirm
@dataclass
class QueryResult:
    """Outcome of one performance lookup.

    Instances are what ``DataSourcePerformanceModel.lookup`` hands back. Besides
    the latency figure, a result may carry shape debug information that
    ``shape_debug_statistics`` flattens into statistics entries.
    """

    # Latency estimate; the ``_us`` suffix suggests microseconds — confirm.
    latency_us: float
    # Confidence in the estimate; range/scale not established here — confirm.
    confidence: float
    source: QuerySource
    details: Dict[str, Any] = field(default_factory=dict)
    # Shape debug info for a single profiling lookup, if any.
    shape_match_info: Optional["ShapeMatchInfo"] = None
    # Per-sub-kernel shape debug info for a composite op, if any. When this is a
    # list it wins over ``shape_match_info`` in shape_debug_statistics.
    sub_kernel_shapes: Optional[List["SubKernelShapeInfo"]] = None

    def shape_debug_statistics(self) -> dict:
        """Flatten the shape debug info into statistics-dict entries.

        Precedence:

        * ``sub_kernel_shapes`` is a list -> one ``"sub_kernel_shapes"`` key
          whose value is a JSON array with one object per sub-kernel (shape
          lists are rendered with ``str``).
        * else ``shape_match_info`` is a ``ShapeMatchInfo`` -> a
          ``"kernel_shapes"`` key (``""`` when the kernel shapes are falsy) and
          a ``"shape_match_rule"`` key.
        * else -> an empty dict.

        ``isinstance`` is used on purpose rather than ``is not None``: Mock
        objects substituted in tests must not reach the iteration below,
        where they would raise ``TypeError``.
        """

        def as_row(sub: "SubKernelShapeInfo") -> Dict[str, Any]:
            # Key order here fixes the key order inside the emitted JSON.
            return {
                "kernel_type": sub.kernel_type,
                "simulation_shapes": str(sub.simulation_shapes),
                "kernel_shapes": str(sub.kernel_shapes),
                "shape_match_rule": sub.shape_match_rule,
            }

        sub_kernels = self.sub_kernel_shapes
        if isinstance(sub_kernels, list):
            return {"sub_kernel_shapes": json.dumps([as_row(sub) for sub in sub_kernels])}

        match_info = self.shape_match_info
        if not isinstance(match_info, ShapeMatchInfo):
            return {}

        shapes = match_info.kernel_shapes
        kernel_repr = str(shapes) if shapes else ""
        return {
            "kernel_shapes": kernel_repr,
            "shape_match_rule": match_info.shape_match_rule,
        }
@dataclass
class ShapeMatchInfo:
    """Shape debug record for a single profiling lookup.

    Holds the shapes seen on the simulation side, the kernel shapes they were
    matched to, and the name of the rule that performed the match.
    """

    # Shapes from the simulation side; presumably one inner list per tensor — confirm.
    simulation_shapes: List[List[int]]
    # Kernel shapes the simulation shapes were matched to.
    kernel_shapes: List[List[int]]
    # Name/identifier of the matching rule that was applied.
    shape_match_rule: str
@dataclass
class SubKernelShapeInfo:
    """Shape debug record for one sub-kernel of a composite op.

    Same shape/rule fields as a single-lookup record, prefixed by the type of
    the sub-kernel they belong to.
    """

    # Which sub-kernel this record describes.
    kernel_type: str
    # Shapes from the simulation side for this sub-kernel.
    simulation_shapes: List[List[int]]
    # Kernel shapes those simulation shapes were matched to.
    kernel_shapes: List[List[int]]
    # Name/identifier of the matching rule that was applied.
    shape_match_rule: str
class DataSourcePerformanceModel(ABC):
    """Interface implemented by every performance data source.

    TensorCast talks to a source only through ``OpInvokeInfo`` objects; the
    storage format behind a source is invisible to it. The base class is
    read-only: ``store`` raises unless a subclass overrides it.
    """

    @abstractmethod
    def lookup(self, op_invoke_info: "OpInvokeInfo") -> Optional[QueryResult]:
        """Look up operator performance for *op_invoke_info*.

        Subclasses must override this. The return type allows ``None``
        (presumably meaning "no data available" — confirm with implementers).
        """

    def store(self, op_invoke_info: "OpInvokeInfo", result: QueryResult) -> None:
        """Persist *result* for *op_invoke_info* (optional capability).

        Raises:
            NotImplementedError: always in this base class; writable sources
                override this method.
        """
        message = "This data source is read-only"
        raise NotImplementedError(message)