import os
import shutil
from unittest import mock
from torch_npu._C._profiler import ProfilerActivity
from torch_npu.npu import Event
from torch_npu.profiler import supported_activities
from torch_npu.profiler._profiler_path_creator import ProfPathCreator
from torch_npu.profiler.analysis.prof_common_func._cann_package_manager import (
CannPackageManager,
)
from torch_npu.profiler.analysis.prof_common_func._constant import Constant
from torch_npu.profiler.profiler_interface import (
_disable_event_record,
_enable_event_record,
_ProfInterface,
)
from torch_npu.testing.testcase import run_tests, TestCase
import torch
class TestActionController(TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.prof_dir = "./result_dir"
cls.namespace = "torch_npu.profiler.profiler_interface"
@classmethod
def tearDownClass(cls):
if os.path.exists(cls.prof_dir):
shutil.rmtree(cls.prof_dir)
def setUp(self):
self.prof_if = _ProfInterface()
ProfPathCreator().init(dir_name=self.prof_dir)
def test_init_trace(self):
self.prof_if.custom_trace_id_callback = lambda: "trace_0"
with mock.patch(self.namespace + "._init_profiler") as mock_func:
self.prof_if.init_trace()
self.assertEqual(1, mock_func.call_count)
self.assertTrue(os.path.exists(self.prof_dir))
self.assertEqual("trace_0", self.prof_if.trace_id)
def test_start_trace(self):
with (
mock.patch(
self.namespace + ".NpuProfilerConfig", return_value="config_obj"
),
mock.patch(self.namespace + "._get_syscnt_enable", return_value=True),
mock.patch(self.namespace + "._get_freq", return_value=100),
mock.patch(self.namespace + "._get_syscnt", return_value=10000),
mock.patch(self.namespace + "._get_monotonic", return_value=20000),
mock.patch(self.namespace + "._start_profiler") as mock_start,
):
self.prof_if.start_trace()
self.assertEqual(True, self.prof_if.syscnt_enable)
self.assertEqual(100, self.prof_if.freq)
self.assertEqual(10000, self.prof_if.start_cnt)
self.assertEqual(20000, self.prof_if.start_monotonic)
mock_start.assert_called_once_with("config_obj", supported_activities())
def test_stop_trace(self):
with mock.patch(self.namespace + "._stop_profiler") as mock_stop:
self.prof_if.stop_trace()
mock_stop.assert_called_once()
def test_finalize_trace(self):
with (
mock.patch(self.namespace + "._init_profiler"),
mock.patch(self.namespace + "._finalize_profiler") as mock_finalize,
):
self.prof_if.metadata = {"key": "val"}
self.prof_if.init_trace()
self.prof_if.finalize_trace()
mock_finalize.assert_called_once()
self.assertTrue(self._check_profiler_info_json(self.prof_if.prof_path))
self.assertTrue(self._check_metadata_json(self.prof_if.prof_path))
def test_analyse(self):
with mock.patch(
"torch_npu.profiler.analysis._npu_profiler.NpuProfiler.analyse"
) as mock_analyse:
self.prof_if.analyse()
mock_analyse.assert_called_once()
def test_supported_activities(self):
activities = set(supported_activities())
self.assertEqual(2, len(activities))
self.assertTrue(ProfilerActivity.CPU in activities)
self.assertTrue(ProfilerActivity.NPU in activities)
def test_create_trace_id_should_return_default_when_callback_is_none(self):
self.prof_if.custom_trace_id_callback = None
trace_id = self.prof_if.create_trace_id()
self.assertIsInstance(trace_id, str)
self.assertEqual(32, len(trace_id))
def test_create_trace_id_should_return_callback_result_when_callback_is_valid(self):
self.prof_if.custom_trace_id_callback = lambda: "custom_trace_id"
trace_id = self.prof_if.create_trace_id()
self.assertEqual("custom_trace_id", trace_id)
def test_create_trace_id_should_return_default_when_callback_is_not_callable(self):
self.prof_if.custom_trace_id_callback = "not_callable"
trace_id = self.prof_if.create_trace_id()
self.assertIsInstance(trace_id, str)
self.assertEqual(32, len(trace_id))
def test_create_trace_id_should_return_default_when_callback_returns_non_string(
self,
):
self.prof_if.custom_trace_id_callback = lambda: 12345
trace_id = self.prof_if.create_trace_id()
self.assertIsInstance(trace_id, str)
self.assertEqual(32, len(trace_id))
def test_create_trace_id_should_return_default_when_callback_raises_exception(self):
def bad_callback():
raise RuntimeError("callback error")
self.prof_if.custom_trace_id_callback = bad_callback
trace_id = self.prof_if.create_trace_id()
self.assertIsInstance(trace_id, str)
self.assertEqual(32, len(trace_id))
def test_create_trace_id_should_return_default_when_callback_returns_too_long_string(
self,
):
long_str = "a" * (self.prof_if.MAX_TRACE_ID_LEN + 1)
self.prof_if.custom_trace_id_callback = lambda: long_str
trace_id = self.prof_if.create_trace_id()
self.assertIsInstance(trace_id, str)
self.assertEqual(32, len(trace_id))
def test_event_record_should_have_return_true_attr_when_enable_record(self):
_enable_event_record()
self.assertTrue(hasattr(Event.record, "origin_func"))
self.assertTrue(hasattr(Event.wait, "origin_func"))
self.assertTrue(hasattr(Event.query, "origin_func"))
self.assertTrue(hasattr(Event.elapsed_time, "origin_func"))
self.assertTrue(hasattr(Event.synchronize, "origin_func"))
_disable_event_record()
self.assertFalse(hasattr(Event.record, "origin_func"))
self.assertFalse(hasattr(Event.wait, "origin_func"))
self.assertFalse(hasattr(Event.query, "origin_func"))
self.assertFalse(hasattr(Event.elapsed_time, "origin_func"))
self.assertFalse(hasattr(Event.synchronize, "origin_func"))
def _check_profiler_info_json(self, prof_path: str) -> bool:
if torch.distributed.is_available() and torch.distributed.is_initialized():
rank_id = torch.distributed.get_rank()
path = os.path.join(
os.path.realpath(prof_path), f"profiler_info_{rank_id}.json"
)
else:
path = os.path.join(os.path.realpath(prof_path), "profiler_info.json")
return os.path.exists(path)
def _check_metadata_json(self, prof_path: str) -> bool:
path = os.path.join(os.path.realpath(prof_path), "profiler_metadata.json")
return os.path.exists(path)
def _check_params(self):
CannPackageManager.SUPPORT_EXPORT_DB = False
self.prof_if.activities = set(supported_activities())
self.prof_if.experimental_config.export_type = Constant.Db
with self.assertExpectedRaises(RuntimeError):
self.prof_if._check_params()
if __name__ == "__main__":
run_tests()