"""Tests for grid_generator/config.py — config loading utilities."""
import tempfile
import unittest
from pathlib import Path
import yaml
from tools.perf_data_collection.grid_generator.config import (
load_op_mapping_metadata,
load_shape_grid_config,
)
class TestLoadShapeGridConfig(unittest.TestCase):
def test_loads_yaml(self):
with tempfile.TemporaryDirectory() as td:
p = Path(td) / "config.yaml"
p.write_text("assignments:\n MatMulV2:\n pattern: MatMulFamily\n models: [dsv3, qwen332b]\n")
result = load_shape_grid_config(p)
self.assertIn("assignments", result)
self.assertEqual(result["assignments"]["MatMulV2"]["pattern"], "MatMulFamily")
def test_empty_yaml(self):
with tempfile.TemporaryDirectory() as td:
p = Path(td) / "empty.yaml"
p.write_text("{}")
result = load_shape_grid_config(p)
self.assertEqual(result, {})
class TestLoadOpMappingMetadata(unittest.TestCase):
def test_basic_mapping(self):
with tempfile.TemporaryDirectory() as td:
datadir = Path(td)
mapping = {
"operator_mappings": {
"aten.mm.default": {
"kernel_type": "MatMulV2",
"zero_cost": False,
"composite": False,
},
"aten.view.default": {
"kernel_type": "TransData",
"zero_cost": True,
"composite": False,
},
"tensor_cast.mla.default": {
"kernel_type": "FusedInferAttentionScore",
"composite": True,
"alternate_kernel_types": [
"BatchMatMulV2",
"TransposeBatchMatMul",
],
},
"tensor_cast.all_reduce.default": {
"kernel_type": "hcom_allReduce_",
"category": "communication",
"query_mode": "hcom",
},
}
}
with (datadir / "op_mapping.yaml").open("w") as f:
yaml.dump(mapping, f)
meta = load_op_mapping_metadata(datadir)
self.assertIn("MatMulV2", meta)
self.assertFalse(meta["MatMulV2"]["zero_cost"])
self.assertFalse(meta["MatMulV2"]["composite"])
self.assertIn("TransData", meta)
self.assertTrue(meta["TransData"]["zero_cost"])
self.assertIn("FusedInferAttentionScore", meta)
self.assertTrue(meta["FusedInferAttentionScore"]["composite"])
self.assertIn("BatchMatMulV2", meta)
self.assertEqual(meta["BatchMatMulV2"]["alternates_of"], "FusedInferAttentionScore")
self.assertIn("TransposeBatchMatMul", meta)
self.assertIn("hcom_allReduce_", meta)
self.assertTrue(meta["hcom_allReduce_"]["communication"])
self.assertEqual(meta["hcom_allReduce_"]["query_mode"], "hcom")
def test_missing_file_returns_empty(self):
with tempfile.TemporaryDirectory() as td:
datadir = Path(td)
meta = load_op_mapping_metadata(datadir)
self.assertEqual(meta, {})
def test_empty_mapping_returns_empty(self):
with tempfile.TemporaryDirectory() as td:
datadir = Path(td)
with (datadir / "op_mapping.yaml").open("w") as f:
yaml.dump({}, f)
meta = load_op_mapping_metadata(datadir)
self.assertEqual(meta, {})
def test_skips_non_dict_entries(self):
with tempfile.TemporaryDirectory() as td:
datadir = Path(td)
mapping = {
"operator_mappings": {
"aten.skip": "just_a_string",
"aten.mm": {
"kernel_type": "MatMulV2",
},
}
}
with (datadir / "op_mapping.yaml").open("w") as f:
yaml.dump(mapping, f)
meta = load_op_mapping_metadata(datadir)
self.assertIn("MatMulV2", meta)
self.assertNotIn("aten.skip", meta)
def test_skips_no_kernel_type(self):
with tempfile.TemporaryDirectory() as td:
datadir = Path(td)
mapping = {
"operator_mappings": {
"aten.no_kt": {
"zero_cost": True,
},
}
}
with (datadir / "op_mapping.yaml").open("w") as f:
yaml.dump(mapping, f)
meta = load_op_mapping_metadata(datadir)
self.assertEqual(meta, {})
if __name__ == "__main__":
unittest.main()