"""MLIR checker utilities for UTs."""
from __future__ import annotations
import re
from typing import Any, Dict, Iterable, List, Optional
from mfusion import ir
from mfusion.dialects import torch as torch_d
class MlirChecker:
    """Helper to parse MLIR and validate IR content in UTs.

    Every public ``check_*`` method returns ``True`` on success and ``False``
    on failure. On failure a message, followed by a truncated dump of the
    module text, is stored in :attr:`error`. On success :attr:`error` is reset
    to ``""``, so it always reflects the most recent check.
    """

    # Fallback used to recover the operator name from the printed form
    # ``torch.operator "aten.foo"(...)`` when the op has no ``name`` attribute.
    # This is a raw string, so a single backslash is the regex escape; a
    # doubled backslash would require a literal backslash in the IR text and
    # the pattern would never match.
    _TORCH_OPERATOR_NAME_RE = re.compile(r'torch\.operator\s+"([^"]+)"')

    def __init__(self, module: ir.Module):
        """Create a checker for an already-parsed module.

        Args:
            module: The parsed MLIR module that all checks inspect.
        """
        self.module = module
        self._error = ""

    @staticmethod
    def parse_torch_module(text: str) -> "MlirChecker":
        """Parse Torch MLIR text and wrap it with a checker.

        A fresh ``ir.Context`` is created and the Torch dialect is registered
        on it before parsing.

        Args:
            text: MLIR source text.

        Returns:
            A new :class:`MlirChecker` wrapping the parsed module.

        Raises:
            ValueError: If parsing fails. The underlying exception is chained
                as ``__cause__``.
        """
        ctx = ir.Context()
        torch_d.register_dialect(ctx)
        try:
            module = ir.Module.parse(text, ctx)
        except Exception as e:
            # Normalize whatever the bindings raise into a single exception
            # type for callers, keeping the original as the cause.
            raise ValueError(f"Failed to parse MLIR text: {e}") from e
        return MlirChecker(module)

    @property
    def error(self) -> str:
        """Return the latest error message (``""`` after a passing check)."""
        return self._error

    def check_has_op(self, op_name: str, count: int | None = None) -> bool:
        """Check that an op appears, optionally with an exact count.

        Args:
            op_name: Operation name to look for (e.g. ``func.func``). The search
                is recursive over the whole module.
            count: Exact number of occurrences required. ``None`` means
                "at least one".

        Returns:
            ``True`` if the check passes, ``False`` otherwise (see :attr:`error`).
        """
        ops = self._filter_ops(op_name)
        if count is None:
            if not ops:
                return self._set_error(f"Expected op '{op_name}', but not found.")
        else:
            actual = len(ops)
            if actual != count:
                return self._set_error(f"Expected {count} ops of '{op_name}', but got {actual}.")
        return self._clear_error()

    def check_no_op(self, op_name: str) -> bool:
        """Check that an op does not appear anywhere in the module."""
        ops = self._filter_ops(op_name)
        if ops:
            return self._set_error(f"Unexpected op '{op_name}' found (count: {len(ops)}).")
        return self._clear_error()

    def check_top_level_ops(self, expected: List[str]) -> bool:
        """Check the exact sequence of op names directly in the module body.

        Nested ops (e.g. those inside function bodies) are not included.
        """
        actual = [op.operation.name for op in self.module.body.operations]
        if actual != expected:
            return self._set_error(f"Top-level ops mismatch.\nExpected: {expected}\nActual: {actual}")
        return self._clear_error()

    def check_has_function(self, func_name: str) -> bool:
        """Check that a top-level function with the given symbol name exists."""
        if self._find_func_op(func_name) is None:
            return self._set_error(f"Function '@{func_name}' not found.")
        return self._clear_error()

    def check_func_op_sequence(self, func_name: str, expected: List[str]) -> bool:
        """Check the op name sequence in the first block of a function.

        Only ops directly in the entry block are listed; ops nested in their
        regions are not.

        Args:
            func_name: Symbol name of the function, without the ``@``.
            expected: Exact expected list of op names, in order.
        """
        func_op = self._find_func_op(func_name)
        if func_op is None:
            return self._set_error(f"Function '@{func_name}' not found.")
        if not func_op.regions or not func_op.regions[0].blocks:
            return self._set_error(f"Function '@{func_name}' has no body block.")
        block = func_op.regions[0].blocks[0]
        actual = [op.operation.name for op in block.operations]
        if actual != expected:
            return self._set_error(
                f"Ops mismatch in '@{func_name}'.\nExpected: {expected}\nActual: {actual}"
            )
        return self._clear_error()

    def check_text_contains(self, text: str) -> bool:
        """Check that the printed module text contains ``text`` as a substring."""
        module_text = str(self.module)
        if text not in module_text:
            return self._set_error(f"Substring '{text}' not found in IR.")
        return self._clear_error()

    def check_text_not_contains(self, text: str) -> bool:
        """Check that the printed module text does not contain ``text``."""
        module_text = str(self.module)
        if text in module_text:
            return self._set_error(f"Unexpected substring '{text}' found in IR.")
        return self._clear_error()

    def check_has_op_with_attrs(
        self,
        op_name: str,
        *,
        attrs: Dict[str, Any] | None = None,
        attr_keys: List[str] | None = None,
        count: int | None = None,
    ) -> bool:
        """Check that ops exist with the given name and attributes.

        Args:
            op_name: Operation name to match.
            attrs: Attribute key/value matches (see :meth:`_attr_equals` for
                the comparison rules). If None or empty, no value matching is
                applied.
            attr_keys: Attribute keys that must exist. If None or empty, no key
                presence check.
            count: Exact number of matching ops. If None, require at least one
                match.
        """
        if attrs is None:
            attrs = {}
        if attr_keys is None:
            attr_keys = []
        ops = [
            op
            for op in self._filter_ops(op_name)
            if self._op_matches_attrs(op, attrs) and self._op_has_attr_keys(op, attr_keys)
        ]
        if count is None:
            if not ops:
                return self._set_error(
                    f"Expected op '{op_name}' with attrs {attrs} and keys {attr_keys}, but not found."
                )
        else:
            actual = len(ops)
            if actual != count:
                return self._set_error(
                    f"Expected {count} ops of '{op_name}' with attrs {attrs} and keys {attr_keys}, but got {actual}."
                )
        return self._clear_error()

    def check_has_torch_operator(
        self,
        operator_name: str,
        *,
        attrs: Dict[str, Any] | None = None,
        attr_keys: List[str] | None = None,
        count: int | None = None,
    ) -> bool:
        """Check that torch.operator exists with the given operator name and attributes.

        Args:
            operator_name: The operator name inside torch.operator "..."
            attrs: Attribute key/value matches. If None or empty, no value
                matching is applied.
            attr_keys: Attribute keys that must exist. If None or empty, no key
                presence check.
            count: Exact number of matching ops. If None, require at least one
                match.
        """
        if attrs is None:
            attrs = {}
        if attr_keys is None:
            attr_keys = []
        ops = [
            op
            for op in self._filter_ops("torch.operator")
            if self._torch_operator_name(op) == operator_name
            and self._op_matches_attrs(op, attrs)
            and self._op_has_attr_keys(op, attr_keys)
        ]
        if count is None:
            if not ops:
                return self._set_error(
                    f"Expected torch.operator '{operator_name}' with attrs {attrs} and keys {attr_keys}, but not found."
                )
        else:
            actual = len(ops)
            if actual != count:
                return self._set_error(
                    f"Expected {count} torch.operator '{operator_name}' with attrs {attrs} and keys {attr_keys}, but got {actual}."
                )
        return self._clear_error()

    def check_total_op_count(self, expected: int) -> bool:
        """Check the total number of operations in the module.

        The count includes the top-level module operation itself, because
        :meth:`_walk_ops` yields it first.
        """
        actual = sum(1 for _ in self._walk_ops())
        if actual != expected:
            return self._set_error(f"Total op count mismatch. Expected {expected}, got {actual}.")
        return self._clear_error()

    def _filter_ops(self, op_name: str) -> List[ir.Operation]:
        """Return all ops in the module (recursively) whose name equals ``op_name``."""
        return [op for op in self._walk_ops() if op.operation.name == op_name]

    def _op_matches_attrs(self, op: ir.Operation, attrs: Dict[str, Any]) -> bool:
        """Return True if every ``attrs`` key exists on ``op`` with a matching value."""
        for key, expected in attrs.items():
            if key not in op.attributes:
                return False
            if not self._attr_equals(op.attributes[key], expected):
                return False
        return True

    def _op_has_attr_keys(self, op: ir.Operation, attr_keys: List[str]) -> bool:
        """Check that an op contains all requested attribute keys."""
        return all(key in op.attributes for key in attr_keys)

    def _attr_equals(self, attr: ir.Attribute, expected: Any) -> bool:
        """Compare an attribute with a Python value.

        Matching rules, by the type of ``expected``:
          * ``bool``: a ``BoolAttr`` compares by value and an ``IntegerAttr``
            by truthiness. Any other attribute must render (via
            :meth:`_attr_to_string`) as ``true``/``false``, case-insensitively.
          * ``int``/``float``: compared numerically against an ``IntegerAttr``
            or ``FloatAttr``. Other attribute kinds never match.
          * ``str``: compared to :meth:`_attr_to_string` of the attribute.
          * Any other type never matches.
        """
        # bool must be handled before int, because bool is a subclass of int.
        if isinstance(expected, bool):
            if isinstance(attr, ir.BoolAttr):
                return attr.value == expected
            if isinstance(attr, ir.IntegerAttr):
                return bool(attr.value) == expected
            text = self._attr_to_string(attr).lower()
            return text in ("true", "false") and (text == "true") == expected
        if isinstance(expected, (int, float)):
            if isinstance(attr, ir.IntegerAttr):
                # Do not truncate ``expected``: int(2.5) == 2 would wrongly
                # match an integer attribute holding 2.
                return int(attr.value) == expected
            if isinstance(attr, ir.FloatAttr):
                return float(attr.value) == float(expected)
            return False
        if isinstance(expected, str):
            return self._attr_to_string(attr) == expected
        return False

    def _attr_to_string(self, attr: ir.Attribute) -> str:
        """Get a string representation of an attribute.

        ``StringAttr`` (or anything with a string ``value``) yields that value.
        Otherwise the printed attribute is used, with surrounding double quotes
        stripped.
        """
        if isinstance(attr, ir.StringAttr):
            return attr.value
        if hasattr(attr, "value") and isinstance(attr.value, str):
            return attr.value
        return str(attr).strip('"')

    def _torch_operator_name(self, op: ir.Operation) -> str:
        """Extract the operator name from a torch.operator op.

        The op's ``name`` attribute is preferred. Otherwise the quoted name is
        parsed from the op's printed text. Returns ``""`` if neither is
        available.
        """
        if "name" in op.attributes:
            return self._attr_to_string(op.attributes["name"])
        match = self._TORCH_OPERATOR_NAME_RE.search(str(op))
        return match.group(1) if match else ""

    def _walk_ops(self) -> Iterable[ir.Operation]:
        """Yield every operation in the module in pre-order.

        The walk is a manual recursion over regions -> blocks -> operations.
        It starts with the module operation itself.
        """
        def _recursive_walk(operation):
            yield operation
            for region in operation.regions:
                for block in region.blocks:
                    for op in block.operations:
                        yield from _recursive_walk(op)

        yield from _recursive_walk(self.module.operation)

    def _find_func_op(self, func_name: str) -> Optional[ir.Operation]:
        """Find a top-level function-like op by symbol name.

        Only ops directly in the module body are considered. An op qualifies
        when its name contains ``"func"`` (e.g. ``func.func``) and its
        ``sym_name`` attribute equals ``func_name``.
        """
        for op in self.module.body.operations:
            if "func" not in op.operation.name:
                continue
            if "sym_name" in op.attributes:
                if ir.StringAttr(op.attributes["sym_name"]).value == func_name:
                    return op
        return None

    def _set_error(self, msg: str) -> bool:
        """Record a formatted error message and return False."""
        self._error = self._format_error(msg)
        return False

    def _clear_error(self) -> bool:
        """Clear the error message and return True."""
        self._error = ""
        return True

    def _format_error(self, msg: str) -> str:
        """Prefix ``msg`` with a tag and append a truncated module snippet."""
        snippet = self._module_snippet()
        return f"[MlirChecker Error]: {msg}\n--- IR Snippet ---\n{snippet}\n------------------"

    def _module_snippet(self, max_lines: int = 40) -> str:
        """Return the first ``max_lines`` lines of the module text.

        If the text was cut, a note with the total line count is appended.
        """
        lines = str(self.module).splitlines()
        if len(lines) <= max_lines:
            return "\n".join(lines)
        return "\n".join(lines[:max_lines]) + f"\n... (truncated, total {len(lines)} lines)"