"""Canonical signature utilities for perf database CSV rows.
Shape-grid generation and microbench backfill both need to decide whether two
profiling rows describe the same operator case. This module keeps that matching
logic in one place, including operator-specific normalization for MatMul-family
aliases, parameter slots, and DispatchFFNCombine EP size.
"""
from __future__ import annotations
import re
# Operator whose tuple signature in get_sig() additionally carries the row's
# "EP Size" value.
DISPATCH_FFN_OP = "DispatchFFNCombine"
# Operator names that is_matmul_family() treats as members of the MatMul family.
MATMUL_FAMILY_OPS = {"MatMulV2", "MatMulV3", "MatMulCommon"}
def normalize_op_name(name: str) -> str:
    """Return *name* without surrounding whitespace and one known suffix.

    The suffixes ``_run.py``, ``_run`` and ``.csv`` are tried in that order;
    only the first one that matches is removed, so at most one suffix is
    stripped per call.
    """
    stripped = name.strip()
    for suffix in ("_run.py", "_run", ".csv"):
        if stripped.endswith(suffix):
            return stripped[: -len(suffix)]
    return stripped
def _split_slot_cell(value: str) -> list[str]:
cleaned = (value or "").strip().strip('"')
if not cleaned:
return []
return [part.strip().strip('"') for part in cleaned.split(";")]
def _trim_trailing_empty(values: list[str]) -> list[str]:
result = list(values)
while result and result[-1] == "":
result.pop()
return result
def _normalize_shape_slot(slot: str) -> str:
cleaned = (slot or "").strip().strip('"').strip()
if cleaned in {"()", "N/A", "NA", "NULL", "None", "none"}:
return ""
if cleaned.startswith("(") and cleaned.endswith(")"):
cleaned = cleaned[1:-1]
return ",".join(part.strip() for part in re.split(r"[,\s]+", cleaned) if part.strip())
def _normalize_shape_attr_sig(
    shapes_text: str,
    attr_text: str,
) -> tuple[str, str]:
    """Normalize a shapes cell together with a parallel per-slot attribute cell.

    Both cells are split into slots and padded with ``""`` to the same length.
    Shape slots are normalized; an attribute is kept only where the shape slot
    at the same position is non-empty. Trailing empty slots are dropped from
    each list before it is re-joined with ``;``.

    Returns:
        ``(normalized_shapes, normalized_attrs)`` as ``;``-joined strings.
    """
    shapes = [_normalize_shape_slot(cell) for cell in _split_slot_cell(shapes_text)]
    attrs = _split_slot_cell(attr_text)

    # Pad the shorter list so both can be walked in lockstep.
    width = max(len(shapes), len(attrs))
    shapes.extend([""] * (width - len(shapes)))
    attrs.extend([""] * (width - len(attrs)))

    # An attribute without a shape at the same position is blanked out.
    masked_attrs = [attr if shape else "" for shape, attr in zip(shapes, attrs)]

    joined_shapes = ";".join(_trim_trailing_empty(shapes))
    joined_attrs = ";".join(_trim_trailing_empty(masked_attrs))
    return joined_shapes, joined_attrs
def _parse_shape_slot(slot: str) -> tuple[int, ...] | None:
    """Parse one shape slot into a tuple of ints.

    Returns ``None`` when the slot normalizes to an empty string or when any
    of its tokens is not a valid integer literal.
    """
    normalized = _normalize_shape_slot(slot)
    if not normalized:
        return None
    dims: list[int] = []
    for token in normalized.split(","):
        if not token:
            continue
        try:
            dims.append(int(token))
        except ValueError:
            return None
    return tuple(dims)
def _format_shape_slot(shape: tuple[int, ...]) -> str:
return ",".join(str(dim) for dim in shape)
def is_matmul_family(op_name: str) -> bool:
    """Return whether *op_name*, once normalized, is in ``MATMUL_FAMILY_OPS``."""
    canonical_name = normalize_op_name(op_name)
    return canonical_name in MATMUL_FAMILY_OPS
def canonicalize_matmul_family_signature(
    row: dict[str, str],
) -> tuple[str, str, str, str] | None:
    """Build a layout-independent signature for a 2-D MatMul-family row.

    The first two input shapes and the first output shape are read from the
    row; all three must parse as 2-D. With the output taken as ``(M, N)``,
    ``K`` is determined as follows:

    * if the first input starts with ``M`` and the second input is either
      ``(N, K)`` or ``(K, N)`` with a matching ``K``, ``K`` is the first
      input's second dimension;
    * otherwise ``K`` is the single dimension shared by both inputs that is
      neither ``M`` nor ``N``, or failing that, the single dimension shared
      by both inputs.

    Returns:
        ``(input_shapes, input_dtypes, input_formats, output_shapes)`` where
        the input shapes are rendered as ``"M,K;N,K"``, the output shape as
        ``"M,N"``, and dtypes/formats are the first two slots of their cells;
        ``None`` when the shapes are missing, not 2-D, or ``K`` cannot be
        determined.
    """
    parsed_inputs = [
        _parse_shape_slot(cell) for cell in _split_slot_cell(row.get("Input Shapes", ""))
    ]
    parsed_outputs = [
        _parse_shape_slot(cell) for cell in _split_slot_cell(row.get("Output Shapes", ""))
    ]

    if len(parsed_inputs) < 2:
        return None
    lhs, rhs = parsed_inputs[0], parsed_inputs[1]
    if not lhs or not rhs:
        return None
    if len(lhs) != 2 or len(rhs) != 2:
        return None

    result = parsed_outputs[0] if parsed_outputs else None
    if not result or len(result) != 2:
        return None
    rows, cols = result

    inner: int | None = None
    if lhs[0] == rows:
        rhs_is_n_by_k = rhs[0] == cols and lhs[1] == rhs[1]
        rhs_is_k_by_n = rhs[1] == cols and lhs[1] == rhs[0]
        if rhs_is_n_by_k or rhs_is_k_by_n:
            inner = lhs[1]

    if inner is None:
        # Fallback: infer K from the dimensions the two inputs have in common.
        shared = set(lhs) & set(rhs)
        shared_not_in_output = shared - {rows, cols}
        if len(shared_not_in_output) == 1:
            inner = next(iter(shared_not_in_output))
        elif len(shared) == 1:
            inner = next(iter(shared))

    if inner is None:
        return None

    dtypes = _split_slot_cell(row.get("Input Data Types", ""))
    formats = _split_slot_cell(row.get("Input Formats", ""))
    return (
        f"{rows},{inner};{cols},{inner}",
        ";".join(dtypes[:2]),
        ";".join(formats[:2]),
        f"{rows},{cols}",
    )
def canonicalize_profile_signature(
    row: dict[str, str],
    op_name: str | None = None,
) -> tuple[str, str, str]:
    """Apply operator-specific input-slot rules to a profiling row.

    The operator name is *op_name* if given, else the row's ``OP Type``, else
    its ``OP State``, passed through ``normalize_op_name``.

    * ``Index``: when the first output shape parses and there is at least one
      input shape, the inputs become the first input shape plus a 1-D shape
      holding the output's first dimension; dtypes and formats are reduced to
      their first and last slots.
    * ``Slice``, ``SliceAiCore``, ``Transpose``, ``TransposeAiCore``: only the
      first input slot is kept and its format is forced to ``ND``.
    * Any other operator: slots are passed through unchanged.

    Returns:
        ``(input_shapes, input_dtypes, input_formats)`` as ``;``-joined strings.
    """
    raw_name = op_name or row.get("OP Type", "") or row.get("OP State", "") or ""
    resolved = normalize_op_name(raw_name.strip().strip('"'))

    shapes = _split_slot_cell(row.get("Input Shapes", ""))
    dtypes = _split_slot_cell(row.get("Input Data Types", ""))
    formats = _split_slot_cell(row.get("Input Formats", ""))

    if resolved == "Index":
        out_cells = _split_slot_cell(row.get("Output Shapes", ""))
        out_shape = _parse_shape_slot(out_cells[0]) if out_cells else None
        if out_shape and shapes:
            shapes = [shapes[0], _format_shape_slot((out_shape[0],))]
            if dtypes:
                dtypes = [dtypes[0], dtypes[-1]]
            if formats:
                formats = [formats[0], formats[-1]]
    elif resolved in {"Slice", "SliceAiCore", "Transpose", "TransposeAiCore"}:
        # Only slot 0 takes part in the signature for these operators.
        shapes, dtypes, formats = shapes[:1], dtypes[:1], formats[:1]
        if formats:
            formats[0] = "ND"

    return ";".join(shapes), ";".join(dtypes), ";".join(formats)
def get_sig(
    row: dict[str, str],
    as_str: bool = False,
    op_name: str | None = None,
) -> tuple[str, ...] | str:
    """Return the matching signature of a profiling row.

    For MatMul-family operators whose shapes can be canonicalized, the input
    side and the output shape come from
    ``canonicalize_matmul_family_signature``. For every other case the input
    side comes from ``canonicalize_profile_signature`` followed by slot
    normalization, and the output shape is the normalized ``Output Shapes``
    cell. Output dtypes are always the normalized ``Output Data Types`` cell.

    Args:
        row: CSV row keyed by column name.
        as_str: When true, return ``"<Input Shapes> -> <Output Shapes>"``
            built from the row's raw cells (``N/A`` for an empty cell)
            instead of the tuple.
        op_name: Operator name override; defaults to the row's ``OP Type``,
            then ``OP State``.

    Returns:
        ``(input_shapes, input_dtypes, input_formats, output_shapes,
        output_dtypes)``, with the stripped ``EP Size`` cell appended for
        ``DISPATCH_FFN_OP`` rows, or the display string when *as_str* is set.
    """
    raw_name = op_name or row.get("OP Type", "") or row.get("OP State", "") or ""
    resolved = normalize_op_name(raw_name.strip().strip('"'))

    normalized_out_shapes, out_dtypes = _normalize_shape_attr_sig(
        row.get("Output Shapes", ""),
        row.get("Output Data Types", ""),
    )

    matmul_sig = (
        canonicalize_matmul_family_signature(row) if is_matmul_family(resolved) else None
    )
    if matmul_sig is not None:
        in_shapes, in_dtypes, in_formats, out_shapes = matmul_sig
    else:
        raw_shapes, raw_dtypes, raw_formats = canonicalize_profile_signature(
            row, op_name=op_name
        )
        in_shapes, in_dtypes = _normalize_shape_attr_sig(raw_shapes, raw_dtypes)
        in_formats = _normalize_shape_attr_sig(raw_shapes, raw_formats)[1]
        out_shapes = normalized_out_shapes

    if as_str:
        shown_inputs = row.get("Input Shapes", "") or "N/A"
        shown_outputs = row.get("Output Shapes", "") or "N/A"
        return f"{shown_inputs} -> {shown_outputs}"

    signature = (in_shapes, in_dtypes, in_formats, out_shapes, out_dtypes)
    # EP size is only appended on the generic (non-MatMul) path.
    if matmul_sig is None and resolved == normalize_op_name(DISPATCH_FFN_OP):
        signature += ((row.get("EP Size", "") or "").strip(),)
    return signature