import unittest
from unittest.mock import Mock
import torch
from tensor_cast.device import TEST_DEVICE
from tensor_cast.performance_model.analytic import AnalyticPerformanceModel
from tensor_cast.performance_model.base import PerformanceModel
from tensor_cast.performance_model.empirical import EmpiricalPerformanceModel
from tensor_cast.performance_model.profiling_database.data_source import (
DataSourcePerformanceModel,
QueryResult,
QuerySource,
)
from tensor_cast.runtime import Runtime
class TestEmpiricalPerformanceModel(unittest.TestCase):
    """Unit tests for EmpiricalPerformanceModel.

    The profiling data source is a Mock, so each test decides whether a
    lookup "hits" (returns a QueryResult) or "misses" (returns None).
    """

    def setUp(self):
        """Build a model wired to a mocked data source and fallback model."""
        self.device_profile = TEST_DEVICE
        self.data_source = Mock(spec=DataSourcePerformanceModel)
        self.fallback_model = Mock(spec=PerformanceModel)
        self.empirical_model = EmpiricalPerformanceModel(
            device_profile=self.device_profile,
            data_source=self.data_source,
            fallback_model=self.fallback_model,
        )

    def _run_under_runtime(self, perf_model, func, *args):
        """Call ``func(*args)`` inside a Runtime for ``perf_model``.

        Autograd is disabled while the function runs. The Runtime is returned
        so callers can query its accumulated results after the context exits.
        """
        with Runtime(perf_model, self.device_profile) as runtime, torch.no_grad():
            func(*args)
        return runtime

    def test_init_with_fallback_model(self):
        """An explicitly supplied fallback model is stored on the instance."""
        self.assertEqual(self.empirical_model.name, "empirical")
        self.assertEqual(self.empirical_model.device_profile, self.device_profile)
        # Mocks only compare equal to themselves; assertIs states that intent.
        self.assertIs(self.empirical_model.data_source, self.data_source)
        self.assertIs(self.empirical_model.fallback_model, self.fallback_model)

    def test_init_without_fallback_model(self):
        """Without a fallback model, an AnalyticPerformanceModel is created."""
        empirical_model = EmpiricalPerformanceModel(
            device_profile=self.device_profile, data_source=self.data_source
        )
        self.assertEqual(empirical_model.name, "empirical")
        self.assertEqual(empirical_model.device_profile, self.device_profile)
        self.assertIs(empirical_model.data_source, self.data_source)
        self.assertIsInstance(empirical_model.fallback_model, AnalyticPerformanceModel)
        # The default fallback should be built for the same device profile.
        self.assertEqual(empirical_model.fallback_model.device_profile, self.device_profile)

    def test_get_classifiers_with_runtime(self):
        """get_classifiers returns a list while ops execute under a Runtime.

        Every data-source lookup misses here (returns None), so ops are
        expected to be handled by the default fallback model.
        """
        self.data_source.lookup.return_value = None
        perf_model = EmpiricalPerformanceModel(self.device_profile, self.data_source)

        x = torch.randn([100, 100], device="meta")
        y = torch.randn([100, 100], device="meta")
        # get_classifiers is queried while the Runtime is still active, as in
        # the original test, so it is not routed through _run_under_runtime.
        with Runtime(perf_model, self.device_profile) as runtime, torch.no_grad():
            torch.matmul(x, y)
            classifiers = perf_model.get_classifiers()
            self.assertIsInstance(classifiers, list)

        result = runtime.table_averages()
        self.assertIn("empirical", result)
        breakdowns = runtime.get_breakdowns()
        self.assertIsInstance(breakdowns, dict)

    def test_empirical_model_with_runtime_matmul(self):
        """A data-source hit for matmul produces a positive total execution time."""
        query_result = Mock(spec=QueryResult)
        query_result.latency_us = 100.0
        query_result.confidence = 0.95
        query_result.source = QuerySource.MEASURED
        query_result.details = {"kernel_type": "MatMulV2"}
        query_result.shape_debug_statistics.return_value = {}
        self.data_source.lookup.return_value = query_result
        perf_model = EmpiricalPerformanceModel(self.device_profile, self.data_source)

        x = torch.randn([100, 100], device="meta")
        y = torch.randn([100, 100], device="meta")
        runtime = self._run_under_runtime(perf_model, torch.matmul, x, y)

        total_time_s = runtime.total_execution_time_s()[perf_model.name]
        self.assertGreater(total_time_s, 0)
        # assert_called() reports the mock and its calls on failure, unlike
        # assertTrue(mock.called).
        self.data_source.lookup.assert_called()

    def test_empirical_model_with_runtime_add(self):
        """A data-source miss for add still produces a positive total execution time."""
        self.data_source.lookup.return_value = None
        perf_model = EmpiricalPerformanceModel(self.device_profile, self.data_source)

        x = torch.randn([1000, 1000], device="meta")
        y = torch.randn([1000, 1000], device="meta")
        runtime = self._run_under_runtime(perf_model, torch.add, x, y)

        total_time_s = runtime.total_execution_time_s()[perf_model.name]
        self.assertGreater(total_time_s, 0)
        self.data_source.lookup.assert_called()
# Script entry point: run the tests defined in this module when executed directly.
if __name__ == "__main__":
    unittest.main(module="__main__")