"""Tests for generate_per_shape_comparison.py — per-(kernel_type, shape) delta."""
import csv
import json
import sys
from pathlib import Path
import pytest
# Put <ancestor>/tools/perf_data_analysis at the front of sys.path so the
# module under test can be imported by its bare name below. parents[4] is the
# directory four levels above the one containing this test file.
sys.path.insert(
0,
str(Path(__file__).resolve().parents[4] / "tools" / "perf_data_analysis"),
)
from generate_per_shape_comparison import generate_per_shape_comparison
from tests.helpers.cli_runner import run_module_main
def _make_trace(tmp_path, events):
    """Write ``events`` into ``tc_trace.json`` under ``tmp_path``.

    The file holds a single JSON object whose ``traceEvents`` key maps to
    ``events``. Returns the path of the written file as a string.
    """
    trace_file = tmp_path.joinpath("tc_trace.json")
    payload = json.dumps({"traceEvents": events})
    trace_file.write_text(payload)
    return str(trace_file)
def _make_prof(tmp_path, rows):
    """Write a prof forward-pass CSV named ``prof.csv`` under ``tmp_path``.

    ``rows`` is a list of ``(Type, Input Shapes, Duration)`` tuples. Each row
    gets a synthetic start time: the first starts at 0 and every following
    row starts ``duration + 10`` after the previous one. The shapes column is
    always wrapped in double quotes. Returns the file path as a string.
    """
    header = "Type,Duration(us),Start Time(us),Input Shapes"
    records = []
    start = 0
    for kernel, shape_text, duration in rows:
        records.append(f'{kernel},{duration},{start},"{shape_text}"')
        start += duration + 10
    csv_file = tmp_path.joinpath("prof.csv")
    # No trailing newline: lines are only separated, not terminated.
    csv_file.write_text("\n".join([header, *records]))
    return str(csv_file)
def _x(name, dur, pid=0, **kwargs):
    """Build one trace event dict with ``"ph"`` set to ``"X"``.

    ``ts`` and ``tid`` are fixed at 0. Every extra keyword argument is stored
    under ``args`` with its value converted to ``str``.
    """
    stringified = {}
    for key, value in kwargs.items():
        stringified[key] = str(value)
    # Keyword order is kept so the serialized JSON key order is unchanged.
    return dict(
        name=name,
        ph="X",
        ts=0,
        dur=dur,
        pid=pid,
        tid=0,
        args=stringified,
    )
class TestBasicComparison:
    """One-to-one matches between TC trace events and prof rows."""

    def test_single_kernel_exact_match(self, tmp_path):
        """TC and prof both have MatMulV2 with same shape."""
        trace = _make_trace(
            tmp_path,
            [
                _x(
                    "aten.mm.default",
                    20,
                    kernel_type="MatMulV2",
                    source="MEASURED",
                    simulation_shapes="[[4112, 5120], [5120, 768]]",
                ),
            ],
        )
        prof = _make_prof(
            tmp_path,
            [
                ("MatMulV2", "4112,5120;5120,768", 25.0),
            ],
        )
        out = tmp_path / "out.csv"
        generate_per_shape_comparison(trace, prof, str(out))
        # Read under a context manager so the handle is closed right away
        # rather than left to the garbage collector (avoids ResourceWarning).
        with out.open(newline="") as fh:
            rows = list(csv.DictReader(fh))
        assert len(rows) == 1
        r = rows[0]
        assert r["kernel_type"] == "MatMulV2"
        assert float(r["tc_dur_us"]) == pytest.approx(20.0)
        assert float(r["prof_dur_us"]) == pytest.approx(25.0)
        # 20 us vs 25 us is expected to be reported as -20 %.
        assert float(r["delta_pct"]) == pytest.approx(-20.0, abs=0.1)

    def test_multiple_shapes_same_kernel(self, tmp_path):
        """Two different shapes for MatMulV2."""
        trace = _make_trace(
            tmp_path,
            [
                _x(
                    "mm",
                    20,
                    kernel_type="MatMulV2",
                    source="MEASURED",
                    simulation_shapes="[[100, 200], [200, 300]]",
                ),
                _x(
                    "mm",
                    40,
                    kernel_type="MatMulV2",
                    source="MEASURED",
                    simulation_shapes="[[500, 200], [200, 300]]",
                ),
            ],
        )
        prof = _make_prof(
            tmp_path,
            [
                ("MatMulV2", "100,200;200,300", 22.0),
                ("MatMulV2", "500,200;200,300", 38.0),
            ],
        )
        out = tmp_path / "out.csv"
        generate_per_shape_comparison(trace, prof, str(out))
        with out.open(newline="") as fh:
            rows = list(csv.DictReader(fh))
        # One output row per distinct shape.
        assert len(rows) == 2
class TestCompositeExpansion:
    """Composite TC events are compared against prof per sub-kernel."""

    def test_composite_sub_kernels_expanded(self, tmp_path):
        """Composite op with sub_kernel_durations creates separate rows."""
        trace = _make_trace(
            tmp_path,
            [
                _x(
                    "mla",
                    77,
                    kernel_type="BMNd,FIA,TBMM",
                    source="MEASURED",
                    composite="True",
                    sub_kernel_durations="[('BMNd', 9.0), ('FIA', 55.0), ('TBMM', 13.0)]",
                    simulation_shapes="[[4, 512]]",
                ),
            ],
        )
        prof = _make_prof(
            tmp_path,
            [
                ("BMNd", "4,512", 10.0),
                ("FIA", "4,512", 50.0),
                ("TBMM", "4,512", 14.0),
            ],
        )
        out = tmp_path / "out.csv"
        generate_per_shape_comparison(trace, prof, str(out))
        # Read under a context manager so the handle is closed right away
        # rather than left to the garbage collector (avoids ResourceWarning).
        with out.open(newline="") as fh:
            rows = list(csv.DictReader(fh))
        kts = {r["kernel_type"] for r in rows}
        assert "FIA" in kts
        fia_row = next(r for r in rows if r["kernel_type"] == "FIA")
        # The FIA row carries its own sub-kernel duration (55), not the
        # duration of the whole composite event (77).
        assert float(fia_row["tc_dur_us"]) == pytest.approx(55.0)
        assert float(fia_row["prof_dur_us"]) == pytest.approx(50.0)
class TestUnmatchedEntries:
    """Kernels present on only one side still produce an output row."""

    def test_tc_only_kernel(self, tmp_path):
        """Kernel in TC but not in prof → prof_dur_us = 0."""
        trace = _make_trace(
            tmp_path,
            [
                _x(
                    "op",
                    10,
                    kernel_type="OnlyInTC",
                    source="MEASURED",
                    simulation_shapes="[[100]]",
                ),
            ],
        )
        # Header-only prof CSV: no kernels on the prof side.
        prof = _make_prof(tmp_path, [])
        out = tmp_path / "out.csv"
        generate_per_shape_comparison(trace, prof, str(out))
        # Read under a context manager so the handle is closed right away
        # rather than left to the garbage collector (avoids ResourceWarning).
        with out.open(newline="") as fh:
            rows = list(csv.DictReader(fh))
        assert len(rows) == 1
        assert float(rows[0]["prof_dur_us"]) == 0

    def test_prof_only_kernel(self, tmp_path):
        """Kernel in prof but not in TC → tc_dur_us = 0."""
        # Trace with an empty event list: no kernels on the TC side.
        trace = _make_trace(tmp_path, [])
        prof = _make_prof(
            tmp_path,
            [
                ("OnlyInProf", "100,200", 30.0),
            ],
        )
        out = tmp_path / "out.csv"
        generate_per_shape_comparison(trace, prof, str(out))
        with out.open(newline="") as fh:
            rows = list(csv.DictReader(fh))
        assert len(rows) == 1
        assert float(rows[0]["tc_dur_us"]) == 0
class TestAggregation:
    """Repeated (kernel_type, shape) pairs collapse into one summed row."""

    def test_same_shape_aggregated(self, tmp_path):
        """Multiple invocations of same (kernel_type, shape) are summed."""
        trace = _make_trace(
            tmp_path,
            [
                _x(
                    "mm",
                    20,
                    kernel_type="MatMulV2",
                    source="MEASURED",
                    simulation_shapes="[[100, 200]]",
                ),
                _x(
                    "mm",
                    30,
                    kernel_type="MatMulV2",
                    source="MEASURED",
                    simulation_shapes="[[100, 200]]",
                ),
            ],
        )
        prof = _make_prof(
            tmp_path,
            [
                ("MatMulV2", "100,200", 22.0),
                ("MatMulV2", "100,200", 28.0),
            ],
        )
        out = tmp_path / "out.csv"
        generate_per_shape_comparison(trace, prof, str(out))
        # Read under a context manager so the handle is closed right away
        # rather than left to the garbage collector (avoids ResourceWarning).
        with out.open(newline="") as fh:
            rows = list(csv.DictReader(fh))
        assert len(rows) == 1
        # 20 + 30 on the TC side, 22 + 28 on the prof side.
        assert float(rows[0]["tc_dur_us"]) == pytest.approx(50.0)
        assert float(rows[0]["prof_dur_us"]) == pytest.approx(50.0)
        assert int(rows[0]["tc_count"]) == 2
class TestHcomDedupInProf:
    """Prof-side filtering: duplicate hcom rows and AICPU kernel rows."""

    def test_hcom_dedup_groups_by_start_and_type_and_shape(self, tmp_path):
        """Two hcom rows with identical type, start time and shape count once.

        The expected duration is 100.0, not the 180.0 sum. This input cannot
        tell "keep the largest" apart from "keep the first" as the rule.
        """
        trace = _make_trace(tmp_path, [])
        prof = tmp_path / "prof.csv"
        # Written by hand because _make_prof always assigns distinct start times.
        lines = [
            "Type,Duration(us),Start Time(us),Input Shapes",
            'hcom_allReduce_,100.0,1000.0,"128,5120"',
            'hcom_allReduce_,80.0,1000.0,"128,5120"',
            'MatMulV2,50.0,2000.0,"128,5120"',
        ]
        prof.write_text("\n".join(lines))
        out = tmp_path / "out.csv"
        generate_per_shape_comparison(trace, str(prof), str(out))
        # Read under a context manager so the handle is closed right away
        # rather than left to the garbage collector (avoids ResourceWarning).
        with out.open(newline="") as fh:
            rows = list(csv.DictReader(fh))
        hcom_row = next(r for r in rows if r["kernel_type"] == "hcom_allReduce_")
        assert float(hcom_row["prof_dur_us"]) == pytest.approx(100.0)

    def test_aicpu_excluded_from_prof(self, tmp_path):
        """A prof row of type allgatherAicpuKernel does not reach the output."""
        trace = _make_trace(tmp_path, [])
        prof = tmp_path / "prof.csv"
        lines = [
            "Type,Duration(us),Start Time(us),Input Shapes",
            'allgatherAicpuKernel,200.0,1000.0,""',
            'MatMulV2,50.0,2000.0,"128,5120"',
        ]
        prof.write_text("\n".join(lines))
        out = tmp_path / "out.csv"
        generate_per_shape_comparison(trace, str(prof), str(out))
        with out.open(newline="") as fh:
            rows = list(csv.DictReader(fh))
        kts = {r["kernel_type"] for r in rows}
        assert "allgatherAicpuKernel" not in kts
        # The ordinary kernel in the same file must still be reported.
        assert "MatMulV2" in kts
class TestNormalizeShapeKey:
    """Unit checks for the private ``_normalize_shape_key`` helper."""

    def test_invalid_list_literal(self):
        """An unparseable "[" comes back unchanged."""
        from generate_per_shape_comparison import _normalize_shape_key as normalize

        malformed = "["
        assert normalize(malformed) == "["

    def test_empty_quoted_string(self):
        """A string made of two double quotes normalizes to the empty string."""
        from generate_per_shape_comparison import _normalize_shape_key as normalize

        normalized = normalize('""')
        assert normalized == ""

    def test_bracket_shape_parsing(self):
        """Nested bracket lists become ';'-separated, comma-joined dims."""
        from generate_per_shape_comparison import _normalize_shape_key as normalize

        normalized = normalize("[[100,200],[200,300]]")
        assert normalized == "100,200;200,300"
class TestMainCli:
    """Smoke test for the module's command-line entry point."""

    def test_main_with_output(self, tmp_path):
        """Running main with --tc-trace/--prof-trace/--output writes the CSV."""
        event = _x(
            "mm",
            20,
            kernel_type="MatMulV2",
            source="MEASURED",
            simulation_shapes="[[100,200]]",
        )
        trace = _make_trace(tmp_path, [event])
        prof = _make_prof(tmp_path, [("MatMulV2", "100,200", 25.0)])
        out = tmp_path / "out.csv"
        cli_args = ["--tc-trace", trace, "--prof-trace", prof, "--output", str(out)]
        result = run_module_main(
            "tools.perf_data_analysis.generate_per_shape_comparison",
            cli_args,
        )
        # Clean exit, output file created, and a "Generated" line on stdout.
        assert result.returncode == 0
        assert out.exists()
        assert "Generated" in result.stdout