"""Tests for theory_router.py — template engine and routing logic."""
import unittest
from tools.perf_data_collection.grid_generator.theory_router import (
_GRID_REGISTRY,
_resolve_grid,
collect_theory_generated_rows,
generate_from_template,
resolve_theory_pattern_name,
resolve_complex_generator,
get_theory_generator,
get_default_theory_generator,
default_complex_generators,
)
class TestCollectTheoryGeneratedRows(unittest.TestCase):
    """Exercise ``collect_theory_generated_rows`` against a one-row CSV template."""

    def test_basic_collection(self):
        """Each generated shape row becomes one output row with a zeroed duration."""
        import tempfile
        from pathlib import Path

        from tools.perf_data_collection.grid_generator.generators import TheoryShapeRow

        column_names = [
            "Input Shapes",
            "Input Data Types",
            "Input Formats",
            "Output Shapes",
            "Output Data Types",
            "Average Duration(us)",
        ]
        template_row = {
            "Input Shapes": '"1,5120;5120,25600"',
            "Input Data Types": "DT_BF16;DT_BF16",
            "Input Formats": "ND;ND",
            "Output Shapes": '"1,25600"',
            "Output Data Types": "DT_BF16",
            "Average Duration(us)": "10.0",
        }
        shape_rows = [
            TheoryShapeRow([(128, 5120), (5120, 25600)], [(128, 25600)]),
            TheoryShapeRow([(256, 5120), (5120, 25600)], [(256, 25600)]),
        ]

        with tempfile.TemporaryDirectory() as scratch_dir:
            collected = collect_theory_generated_rows(
                column_names,
                [template_row],
                iter(shape_rows),  # the API is fed an iterator, not a list
                csv_path=Path(scratch_dir) / "MatMulV2.csv",
                file_index=1,
                total_files=1,
                max_rows=None,
                rng=None,
            )
            self.assertEqual(len(collected), 2)
            first = collected[0]
            self.assertIn("128,5120;5120,25600", first["Input Shapes"])
            self.assertEqual(first["Average Duration(us)"], "0")

    def test_max_rows_limit(self):
        """``max_rows`` caps the number of collected rows when an RNG is supplied."""
        import random
        import tempfile
        from pathlib import Path

        from tools.perf_data_collection.grid_generator.generators import TheoryShapeRow

        column_names = ["Input Shapes", "Output Shapes", "Average Duration(us)"]
        template_row = {
            "Input Shapes": '"1,5120;5120,25600"',
            "Output Shapes": '"1,25600"',
            "Average Duration(us)": "10.0",
        }
        # 19 candidate rows (multiples 1..19 of 128 tokens), far more than the cap.
        shape_rows = [
            TheoryShapeRow([(k * 128, 5120), (5120, 25600)], [(k * 128, 25600)])
            for k in range(1, 20)
        ]

        with tempfile.TemporaryDirectory() as scratch_dir:
            collected = collect_theory_generated_rows(
                column_names,
                [template_row],
                iter(shape_rows),
                csv_path=Path(scratch_dir) / "MatMulV2.csv",
                file_index=1,
                total_files=1,
                max_rows=5,
                rng=random.Random(42),
            )
            self.assertEqual(len(collected), 5)
class TestGridRegistry(unittest.TestCase):
    """Sanity checks over every named grid registered in ``_GRID_REGISTRY``."""

    def test_all_names_resolve(self):
        """Every registered grid is a non-empty list and resolves via ``_resolve_grid``."""
        # An empty registry would make the loop below pass vacuously.
        self.assertTrue(_GRID_REGISTRY, "grid registry is empty")
        for name in _GRID_REGISTRY:
            # subTest names the failing grid and keeps checking the remaining ones.
            with self.subTest(grid=name):
                grid = _GRID_REGISTRY[name]
                self.assertIsInstance(grid, list)
                self.assertGreater(len(grid), 0)
                # The test name promises resolution, so exercise the lookup path too.
                resolved = _resolve_grid(name)
                self.assertIsInstance(resolved, list)
                self.assertGreater(len(resolved), 0)
class TestResolveGrid(unittest.TestCase):
    """``_resolve_grid`` accepts a registry name or an inline list."""

    def test_by_name(self):
        """A known registry name yields a non-empty list."""
        grid = _resolve_grid("M_GRID")
        self.assertIsInstance(grid, list)
        self.assertGreater(len(grid), 0)

    def test_inline_list(self):
        """An inline list is returned with the same contents."""
        literal = [1, 2, 3]
        self.assertEqual(_resolve_grid(literal), [1, 2, 3])

    def test_unknown_name_raises(self):
        """An unregistered name raises ``KeyError``."""
        self.assertRaises(KeyError, _resolve_grid, "NONEXISTENT_GRID")

    def test_invalid_type_raises(self):
        """A value that is neither a name nor a list raises ``ValueError``."""
        self.assertRaises(ValueError, _resolve_grid, 42)
class TestGenerateFromTemplate(unittest.TestCase):
    """Behaviour of the template-driven shape generator ``generate_from_template``."""

    def _make_pattern(self, **overrides):
        """Return a fresh two-token pattern dict, with ``overrides`` replacing keys."""
        return {
            "iterators": {"tokens": [128, 256]},
            "constants": {"hidden": 5120},
            "inputs": ["(tokens, hidden)"],
            "outputs": ["(tokens, hidden)"],
            **overrides,
        }

    @staticmethod
    def _generate(pattern):
        """Materialise every row produced for ``pattern`` (no RNG)."""
        return list(generate_from_template(pattern, None))

    def test_basic_generation(self):
        """One row per iterator value."""
        self.assertEqual(len(self._generate(self._make_pattern())), 2)

    def test_constraints_filter(self):
        """Rows violating a constraint expression are dropped."""
        produced = self._generate(self._make_pattern(constraints=["tokens > 200"]))
        self.assertEqual(len(produced), 1)
        self.assertEqual(produced[0].input_shapes[0], (256, 5120))

    def test_with_output_templates(self):
        """Multiple output templates are all evaluated per row."""
        produced = self._generate(
            self._make_pattern(outputs=["(tokens, hidden)", "(tokens, 1)"])
        )
        self.assertEqual(len(produced), 2)
        self.assertEqual(produced[0].output_shapes, [(128, 5120), (128, 1)])

    def test_extra_values_dtypes(self):
        """Dtype/format lists are surfaced as ``extra_values`` columns."""
        produced = self._generate(
            self._make_pattern(
                input_dtypes=["DT_BF16"],
                input_formats=["ND"],
                output_dtypes=["DT_BF16"],
                output_formats=["ND"],
            )
        )
        extras = produced[0].extra_values
        self.assertIn("Input Data Types", extras)
        expected = {
            "Input Data Types": "DT_BF16",
            "Input Formats": "ND",
            "Output Data Types": "DT_BF16",
            "Output Formats": "ND",
        }
        for column, value in expected.items():
            self.assertEqual(extras[column], value)

    def test_empty_iters_with_constants_only(self):
        """With no iterators, exactly one row is built from the constants."""
        produced = self._generate(
            {
                "iterators": {},
                "constants": {"D": 5120},
                "inputs": ["(D,)"],
                "outputs": ["(D,)"],
            }
        )
        self.assertEqual(len(produced), 1)
        self.assertEqual(produced[0].input_shapes, [(5120,)])

    def test_two_iterators_product(self):
        """Two iterators expand to their Cartesian product."""
        produced = self._generate(
            {
                "iterators": {"tokens": [128, 256], "hidden": [4096, 5120]},
                "constants": {},
                "inputs": ["(tokens, hidden)"],
                "outputs": ["(tokens, hidden)"],
            }
        )
        self.assertEqual(len(produced), 4)

    def test_multi_input_expression(self):
        """Arithmetic expressions in input templates are evaluated per row."""
        produced = self._generate(
            {
                "iterators": {"seq": [1, 2]},
                "constants": {"D": 5120, "R": 64},
                "inputs": ["(seq, D)", "(seq + D, R)"],
                "outputs": ["(seq, D)"],
            }
        )
        self.assertEqual(produced[0].input_shapes, [(1, 5120), (5121, 64)])
        self.assertEqual(produced[1].input_shapes, [(2, 5120), (5122, 64)])

    def test_max_func_in_expr(self):
        """``max`` is callable inside template expressions."""
        produced = self._generate(
            {
                "iterators": {"tokens": [1, 2048, 4096]},
                "constants": {"limit": 2048},
                "inputs": ["(max(tokens, limit), 64)"],
                "outputs": ["(max(tokens, limit), 128)"],
            }
        )
        self.assertEqual(produced[0].input_shapes, [(2048, 64)])
        self.assertEqual(produced[1].input_shapes, [(2048, 64)])
        self.assertEqual(produced[2].input_shapes, [(4096, 64)])
class TestResolveTheoryPatternName(unittest.TestCase):
    """How a kernel name is mapped to a theory pattern name."""

    def test_direct_assignment(self):
        """An explicit assignment wins."""
        resolved = resolve_theory_pattern_name(
            "MatMulV2", {"MatMulV2": "MatMulFamily"}, {}
        )
        self.assertEqual(resolved, "MatMulFamily")

    def test_alternate_resolution(self):
        """A kernel marked ``alternates_of`` borrows the target kernel's pattern."""
        resolved = resolve_theory_pattern_name(
            "BatchMatMulV2",
            {"FusedInferAttentionScore": "FIA"},
            {"alternates_of": "FusedInferAttentionScore"},
        )
        self.assertEqual(resolved, "FIA")

    def test_elementwise_query_mode(self):
        """``query_mode='elementwise'`` falls back to the binary elementwise pattern."""
        resolved = resolve_theory_pattern_name(
            "Add", {}, {"query_mode": "elementwise"}
        )
        self.assertEqual(resolved, "elementwise_binary")

    def test_no_match(self):
        """Without any assignment or metadata hint the result is ``None``."""
        resolved = resolve_theory_pattern_name("UnknownOp", {}, {})
        self.assertIsNone(resolved)
class TestResolveComplexGenerator(unittest.TestCase):
    """Lookup and invocation of named complex generator callables."""

    def test_not_in_dict(self):
        """An unregistered generator name resolves to ``None``."""
        self.assertIsNone(resolve_complex_generator("nonexistent", None, {}, {}))

    def test_calls_with_model_names(self):
        """The registered callable receives the model-name list verbatim."""
        invocations = []

        def recording_generator(model_names):
            invocations.append(model_names)
            return iter([])

        registry = {"test_func": recording_generator}
        produced = resolve_complex_generator("test_func", ["dsv3"], registry, {})
        self.assertIsNotNone(produced)
        self.assertEqual(len(list(produced)), 0)
        self.assertEqual(invocations, [["dsv3"]])
class TestDefaultComplexGenerators(unittest.TestCase):
    """The built-in complex generator registry."""

    EXPECTED_NAMES = (
        "_theory_grouped_matmul",
        "_theory_dfc",
        "_theory_fused_attention",
        "_theory_split_qkv_rmsnorm_rope",
    )

    def test_all_registered(self):
        """Every expected generator is present and every entry is callable."""
        registry = default_complex_generators()
        for expected in self.EXPECTED_NAMES:
            self.assertIn(expected, registry)
        for generator in registry.values():
            self.assertTrue(callable(generator))
class TestGetDefaultTheoryGenerator(unittest.TestCase):
    """``get_default_theory_generator`` skip rules and template dispatch."""

    @staticmethod
    def _single_pattern_config(kernel, pattern_name, pattern):
        """Build a config that assigns ``kernel`` to one named pattern."""
        return {"assignments": {kernel: pattern_name}, "patterns": {pattern_name: pattern}}

    def test_returns_none_for_unknown_kernel(self):
        """A kernel with no assignment produces no generator."""
        empty_config = {"assignments": {}, "patterns": {}}
        self.assertIsNone(
            get_default_theory_generator("UnknownKernel", None, empty_config, {})
        )

    def test_returns_none_for_communication_op(self):
        """Communication ops are skipped."""
        kernel = "hcom_allReduce_"
        config = {"assignments": {kernel: "skip"}, "patterns": {}}
        meta = {kernel: {"communication": True}}
        self.assertIsNone(get_default_theory_generator(kernel, None, config, meta))

    def test_returns_none_for_composite(self):
        """Composite ops are skipped."""
        kernel = "FusedInferAttentionScore"
        config = {"assignments": {}, "patterns": {}}
        meta = {kernel: {"composite": True}}
        self.assertIsNone(get_default_theory_generator(kernel, None, config, meta))

    def test_returns_none_for_zero_cost(self):
        """Zero-cost ops are skipped even when they have a valid pattern."""
        config = self._single_pattern_config(
            "TransData",
            "elementwise_binary",
            {
                "iterators": {"tokens": [128]},
                "constants": {"D": 5120},
                "inputs": ["(tokens, D)"],
                "outputs": ["(tokens, D)"],
            },
        )
        meta = {"TransData": {"zero_cost": True}}
        self.assertIsNone(get_default_theory_generator("TransData", None, config, meta))

    def test_returns_generator_for_template_pattern(self):
        """A plain template pattern yields one row per iterator value."""
        config = self._single_pattern_config(
            "MatMulV2",
            "MatMulFamily",
            {
                "iterators": {"tokens": [128]},
                "constants": {"hidden": 5120},
                "inputs": ["(tokens, hidden)"],
                "outputs": ["(tokens, hidden)"],
            },
        )
        generator = get_default_theory_generator("MatMulV2", None, config, {"MatMulV2": {}})
        self.assertIsNotNone(generator)
        self.assertEqual(len(list(generator)), 1)
class TestGetTheoryGenerator(unittest.TestCase):
    """``get_theory_generator`` with an identity template pattern fixture."""

    def setUp(self):
        identity_pattern = {
            "iterators": {"tokens": [128]},
            "constants": {"hidden": 5120},
            "inputs": ["(tokens, hidden)"],
            "outputs": ["(tokens, hidden)"],
        }
        self.config = {
            "assignments": {"MatMulV2": "TestIdentity"},
            "patterns": {"TestIdentity": identity_pattern},
        }
        self.op_meta = {"MatMulV2": {}}

    def _generate(self, kernel_name, op_meta):
        """Call ``get_theory_generator`` with the fixture config and empty caches."""
        return get_theory_generator(
            kernel_name,
            None,
            self.config,
            op_meta,
            complex_generators={},
            signature_cache={},
        )

    def _assert_skipped_when(self, flag):
        """Setting ``flag`` on MatMulV2's metadata must suppress the generator."""
        self.op_meta["MatMulV2"][flag] = True
        self.assertIsNone(self._generate("MatMulV2", self.op_meta))

    def test_basic_generator(self):
        """The identity pattern yields exactly one row."""
        generator = self._generate("MatMulV2", self.op_meta)
        self.assertIsNotNone(generator)
        self.assertEqual(len(list(generator)), 1)

    def test_zero_cost_skipped(self):
        self._assert_skipped_when("zero_cost")

    def test_composite_skipped(self):
        self._assert_skipped_when("composite")

    def test_communication_skipped(self):
        self._assert_skipped_when("communication")

    def test_unknown_kernel(self):
        """A kernel missing from both assignments and metadata yields ``None``."""
        self.assertIsNone(self._generate("UnknownKernel", {}))

    def test_missing_pattern(self):
        """An assignment pointing at an undefined pattern yields ``None``."""
        self.config["assignments"]["MatMulV2"] = "NonExistentPattern"
        self.assertIsNone(self._generate("MatMulV2", self.op_meta))
if __name__ == "__main__":
    # Allow running this file directly, e.g. `python test_theory_router.py`.
    unittest.main()