import ctypes
import json
import unittest
from unittest.mock import patch, MagicMock

from ms_service_profiler.mstx import LibServiceProfiler, ProfilerCallbackResult
class TestProfilerCallbackResult(unittest.TestCase):
    """Checks the mode, message and mode-predicate accessors of ProfilerCallbackResult."""

    def test_callback_result_dynamic(self):
        # Built without a message: the message is expected to be empty.
        dynamic_result = ProfilerCallbackResult(ProfilerCallbackResult.DYNAMIC)

        self.assertEqual(dynamic_result.mode, ProfilerCallbackResult.DYNAMIC)
        self.assertEqual(dynamic_result.message, "")
        self.assertTrue(dynamic_result.is_dynamic)
        self.assertFalse(dynamic_result.is_legacy)

    def test_callback_result_legacy(self):
        legacy_text = "Legacy mode message"
        legacy_result = ProfilerCallbackResult(ProfilerCallbackResult.LEGACY, legacy_text)

        self.assertEqual(legacy_result.mode, ProfilerCallbackResult.LEGACY)
        self.assertEqual(legacy_result.message, legacy_text)
        self.assertFalse(legacy_result.is_dynamic)
        self.assertTrue(legacy_result.is_legacy)

    def test_callback_result_dynamic_with_message(self):
        # An explicit message must not affect the mode predicates.
        success_text = "Success message"
        dynamic_result = ProfilerCallbackResult(ProfilerCallbackResult.DYNAMIC, success_text)

        self.assertTrue(dynamic_result.is_dynamic)
        self.assertFalse(dynamic_result.is_legacy)
        self.assertEqual(dynamic_result.message, success_text)
class TestLibServiceProfiler(unittest.TestCase):
    """Unit tests for the LibServiceProfiler ctypes wrapper.

    Most tests swap individual ``func_*`` attributes of a fresh profiler for
    ``MagicMock`` objects (or ``None``). They then check how each Python-level
    method forwards to those native entry points, or falls back when one is
    missing.
    """

    # Patch target and fake path shared by setUp and the init-path tests.
    _LIB_PATH_TARGET = "ms_service_profiler.utils.file_open_check.get_valid_lib_path"
    _FAKE_LIB_PATH = "/path/to/libms_service_profiler.so"

    def setUp(self):
        # The lib-path lookup is patched only while the profiler is constructed.
        with patch(self._LIB_PATH_TARGET, return_value=self._FAKE_LIB_PATH):
            self.service_profiler = LibServiceProfiler()
        self.service_profiler.is_initialized = False

    def _stub_native_library(self, with_register_funcs=True):
        """Run init(), then replace the profiler's library with a MagicMock.

        Args:
            with_register_funcs: If True, the start/stop callback registration
                hooks are MagicMocks. If False, they are None, which simulates
                a library without dynamic-callback support.
        """
        # NOTE(review): init() runs here without any of the patches used
        # elsewhere in this class; confirm it does not depend on a real
        # library being installed.
        self.service_profiler.init()
        self.service_profiler.lib = MagicMock()
        if with_register_funcs:
            self.service_profiler._func_register_start_callback = MagicMock()
            self.service_profiler._func_register_stop_callback = MagicMock()
        else:
            self.service_profiler._func_register_start_callback = None
            self.service_profiler._func_register_stop_callback = None

    # ------------------------------------------------------------------ init

    @patch(_LIB_PATH_TARGET)
    @patch("ctypes.cdll.LoadLibrary")
    def test_init_with_valid_lib_path(self, mock_load_library, mock_get_valid_lib_path):
        mock_get_valid_lib_path.return_value = self._FAKE_LIB_PATH
        mock_load_library.return_value = MagicMock()
        profiler = LibServiceProfiler()
        profiler.init()
        self.assertIsNotNone(profiler.lib)

    @patch(_LIB_PATH_TARGET)
    def test_init_with_empty_lib_path(self, mock_get_valid_lib_path):
        mock_get_valid_lib_path.return_value = ""
        profiler = LibServiceProfiler()
        # Call init() explicitly, as the valid-path test does, so the
        # assertion covers the init path whether or not the constructor
        # already loads the library.
        profiler.init()
        self.assertIsNone(profiler.lib)

    @patch(_LIB_PATH_TARGET)
    def test_init_with_none_lib_path(self, mock_get_valid_lib_path):
        mock_get_valid_lib_path.return_value = None
        profiler = LibServiceProfiler()
        profiler.init()
        self.assertIsNone(profiler.lib)

    @patch(_LIB_PATH_TARGET)
    def test_init_with_load_library_error(self, mock_get_valid_lib_path):
        mock_get_valid_lib_path.return_value = self._FAKE_LIB_PATH
        # ctypes raises OSError for unloadable libraries. Both construction and
        # init() must run while the failing loader is patched in.
        with patch("ctypes.cdll.LoadLibrary", side_effect=OSError("Library load error")):
            profiler = LibServiceProfiler()
            profiler.init()
        self.assertIsNone(profiler.lib)

    def test_init_already_initialized(self):
        self.service_profiler.is_initialized = True
        self.service_profiler.init()
        self.assertTrue(self.service_profiler.is_initialized)

    # ------------------------------------------------- span / event forwarding

    def test_start_span(self):
        self.service_profiler.func_start_span_with_name = MagicMock(return_value=12345)
        span_handle = self.service_profiler.start_span("test_span")
        self.assertEqual(span_handle, 12345)
        self.service_profiler.func_start_span_with_name.assert_called_once_with(b'test_span')

    def test_end_span(self):
        self.service_profiler.func_end_span = MagicMock()
        self.service_profiler.end_span(12345)
        self.service_profiler.func_end_span.assert_called_once_with(12345)

    def test_mark_span_attr(self):
        self.service_profiler.func_mark_span_attr = MagicMock()
        self.service_profiler.mark_span_attr("test_attr", 12345)
        self.service_profiler.func_mark_span_attr.assert_called_once_with(b'test_attr', 12345)

    def test_mark_event(self):
        self.service_profiler.func_mark_event = MagicMock()
        self.service_profiler.mark_event("test_event")
        self.service_profiler.func_mark_event.assert_called_once_with(b'test_event')

    def test_start_profiler(self):
        self.service_profiler.func_start_service_profiler = MagicMock()
        self.service_profiler.start_profiler()
        self.service_profiler.func_start_service_profiler.assert_called_once()

    def test_stop_profiler(self):
        self.service_profiler.func_stop_service_profiler = MagicMock()
        self.service_profiler.stop_profiler()
        self.service_profiler.func_stop_service_profiler.assert_called_once()

    def test_is_enable(self):
        self.service_profiler.func_is_enable = MagicMock(return_value=True)
        result = self.service_profiler.is_enable(1)
        self.assertTrue(result)
        self.service_profiler.func_is_enable.assert_called_once_with(1)

    def test_add_meta_info(self):
        self.service_profiler.is_initialized = True
        self.service_profiler.func_add_meta_info = MagicMock()
        self.service_profiler.add_meta_info("key", "value")
        self.service_profiler.func_add_meta_info.assert_called_once_with(b"key", b"value")

    def test_mark_event_ex_with_func(self):
        self.service_profiler.func_mark_event_ex = MagicMock()
        self.service_profiler.mark_event_ex("test_name", "test_domain", "test_msg")
        self.service_profiler.func_mark_event_ex.assert_called_once_with(
            b"test_name", b"test_domain", b"test_msg"
        )

    def test_mark_event_ex_fallback(self):
        # Without the extended symbol, the payload is sent as JSON via mark_event.
        self.service_profiler.func_mark_event_ex = None
        self.service_profiler.func_mark_event = MagicMock()
        self.service_profiler.mark_event_ex("test_name", "test_domain", "test_msg")
        self.service_profiler.func_mark_event.assert_called_once()
        call_args = self.service_profiler.func_mark_event.call_args[0][0]
        result = json.loads(call_args)
        self.assertEqual(result["name"], "test_name")
        self.assertEqual(result["domain"], "test_domain")
        self.assertEqual(result["msg"], "test_msg")

    def test_span_end_ex_with_func(self):
        self.service_profiler.func_span_end_ex = MagicMock()
        self.service_profiler.span_end_ex("test_name", "test_domain", "test_msg", 12345)
        self.service_profiler.func_span_end_ex.assert_called_once_with(
            b"test_name", b"test_domain", b"test_msg", 12345
        )

    def test_span_end_ex_fallback(self):
        # Without the extended symbol: JSON attribute via mark_span_attr, then end_span.
        self.service_profiler.func_span_end_ex = None
        self.service_profiler.func_end_span = MagicMock()
        self.service_profiler.func_mark_span_attr = MagicMock()
        self.service_profiler.span_end_ex("test_name", "test_domain", "test_msg", 12345)
        self.service_profiler.func_end_span.assert_called_once_with(12345)
        self.service_profiler.func_mark_span_attr.assert_called_once()
        call_args = self.service_profiler.func_mark_span_attr.call_args[0][0]
        self.assertIsInstance(call_args, bytes)
        result = json.loads(call_args.decode("utf-8"))
        self.assertEqual(result["name"], "test_name")
        self.assertEqual(result["domain"], "test_domain")
        self.assertEqual(result["msg"], "test_msg")

    # ------------------------------------------------------ config getters

    def test_is_domain_enable(self):
        self.service_profiler.func_is_valid_dommain = MagicMock(return_value=True)
        result = self.service_profiler.is_domain_enable("test_domain")
        self.assertTrue(result)
        self.service_profiler.func_is_valid_dommain.assert_called_once_with(b"test_domain")

    def test_is_domain_enable_no_func(self):
        self.service_profiler.func_is_valid_dommain = None
        result = self.service_profiler.is_domain_enable("test_domain")
        self.assertTrue(result)

    def test_get_prof_path(self):
        self.service_profiler.func_get_prof_path = MagicMock(return_value=b"/test/path")
        result = self.service_profiler.get_prof_path()
        self.assertEqual(result, "/test/path")

    def test_get_prof_path_empty(self):
        self.service_profiler.func_get_prof_path = MagicMock(return_value=None)
        result = self.service_profiler.get_prof_path()
        self.assertEqual(result, "")

    def test_get_acl_task_time_level(self):
        self.service_profiler.func_get_acl_task_time_level = MagicMock(return_value=b"L1")
        result = self.service_profiler.get_acl_task_time_level()
        self.assertEqual(result, "L1")

    def test_get_acl_task_time_level_default(self):
        self.service_profiler.func_get_acl_task_time_level = None
        result = self.service_profiler.get_acl_task_time_level()
        self.assertEqual(result, "L0")

    def test_get_acl_prof_aicore_metrics(self):
        self.service_profiler.func_get_acl_prof_aicore_metrics = MagicMock(return_value=3)
        result = self.service_profiler.get_acl_prof_aicore_metrics()
        self.assertEqual(result, 3)

    def test_get_acl_prof_aicore_metrics_default(self):
        self.service_profiler.func_get_acl_prof_aicore_metrics = None
        result = self.service_profiler.get_acl_prof_aicore_metrics()
        self.assertEqual(result, -1)

    def test_get_torch_prof_step_num(self):
        self.service_profiler.func_get_torch_prof_step_num = MagicMock(return_value=100)
        result = self.service_profiler.get_torch_prof_step_num()
        self.assertEqual(result, 100)

    def test_get_torch_prof_step_num_default(self):
        self.service_profiler.func_get_torch_prof_step_num = None
        result = self.service_profiler.get_torch_prof_step_num()
        self.assertEqual(result, 0)

    def test_is_torch_prof_stack(self):
        self.service_profiler.func_get_torch_prof_stack = MagicMock(return_value=True)
        result = self.service_profiler.is_torch_prof_stack()
        self.assertTrue(result)

    def test_is_torch_prof_stack_default(self):
        self.service_profiler.func_get_torch_prof_stack = None
        result = self.service_profiler.is_torch_prof_stack()
        self.assertFalse(result)

    def test_is_torch_prof_modules(self):
        self.service_profiler.func_get_torch_prof_modules = MagicMock(return_value=True)
        result = self.service_profiler.is_torch_prof_modules()
        self.assertTrue(result)

    def test_is_torch_prof_modules_default(self):
        self.service_profiler.func_get_torch_prof_modules = None
        result = self.service_profiler.is_torch_prof_modules()
        self.assertFalse(result)

    def test_is_torch_profiler_enable(self):
        self.service_profiler.func_get_torch_profiler_enable = MagicMock(return_value=True)
        self.service_profiler.func_is_enable = MagicMock(return_value=True)
        result = self.service_profiler.is_torch_profiler_enable(10)
        self.assertTrue(result)

    def test_is_torch_profiler_enable_no_torch_func(self):
        self.service_profiler.func_get_torch_profiler_enable = None
        result = self.service_profiler.is_torch_profiler_enable(10)
        self.assertFalse(result)

    def test_is_torch_profiler_enable_disabled(self):
        self.service_profiler.func_get_torch_profiler_enable = MagicMock(return_value=True)
        self.service_profiler.func_is_enable = MagicMock(return_value=False)
        result = self.service_profiler.is_torch_profiler_enable(10)
        self.assertFalse(result)

    # ------------------------------------------- missing native symbols / edges

    def test_start_span_with_none_name(self):
        self.service_profiler.func_start_span_with_name = MagicMock(return_value=12345)
        span_handle = self.service_profiler.start_span(None)
        self.assertEqual(span_handle, 12345)
        self.service_profiler.func_start_span_with_name.assert_called_once_with(b"")

    def test_start_span_no_func(self):
        self.service_profiler.func_start_span_with_name = None
        span_handle = self.service_profiler.start_span("test")
        self.assertEqual(span_handle, 0)

    # The *_no_func tests below only check that the call does not raise when
    # the native symbol is missing.

    def test_end_span_no_func(self):
        self.service_profiler.func_end_span = None
        self.service_profiler.end_span(12345)

    def test_mark_span_attr_no_func(self):
        self.service_profiler.func_mark_span_attr = None
        self.service_profiler.mark_span_attr("test_attr", 12345)

    def test_mark_event_no_func(self):
        self.service_profiler.func_mark_event = None
        self.service_profiler.mark_event("test_event")

    def test_start_profiler_no_func(self):
        self.service_profiler.func_start_service_profiler = None
        self.service_profiler.start_profiler()

    def test_stop_profiler_no_func(self):
        self.service_profiler.func_stop_service_profiler = None
        self.service_profiler.stop_profiler()

    def test_is_enable_no_func(self):
        self.service_profiler.func_is_enable = None
        result = self.service_profiler.is_enable(1)
        self.assertFalse(result)

    def test_add_meta_info_no_func(self):
        self.service_profiler.func_add_meta_info = None
        self.service_profiler.add_meta_info("key", "value")

    # ------------------------------------------------ dynamic callbacks

    def test_supports_dynamic_callbacks_true(self):
        self._stub_native_library()
        result = self.service_profiler.supports_dynamic_callbacks()
        self.assertTrue(result)

    def test_supports_dynamic_callbacks_no_lib(self):
        self.service_profiler.init()
        self.service_profiler.lib = None
        result = self.service_profiler.supports_dynamic_callbacks()
        self.assertFalse(result)

    def test_supports_dynamic_callbacks_no_register_funcs(self):
        self._stub_native_library(with_register_funcs=False)
        result = self.service_profiler.supports_dynamic_callbacks()
        self.assertFalse(result)

    def test_on_cpp_start(self):
        callback1 = MagicMock()
        callback2 = MagicMock()
        self.service_profiler._start_callbacks = [callback1, callback2]
        self.service_profiler._on_cpp_start()
        callback1.assert_called_once()
        callback2.assert_called_once()

    def test_on_cpp_start_exception(self):
        # A raising callback must not prevent later callbacks from running.
        callback1 = MagicMock(side_effect=Exception("Test error"))
        callback2 = MagicMock()
        self.service_profiler._start_callbacks = [callback1, callback2]
        self.service_profiler._on_cpp_start()
        callback1.assert_called_once()
        callback2.assert_called_once()

    def test_on_cpp_stop(self):
        callback1 = MagicMock()
        callback2 = MagicMock()
        self.service_profiler._stop_callbacks = [callback1, callback2]
        self.service_profiler._on_cpp_stop()
        callback1.assert_called_once()
        callback2.assert_called_once()

    def test_on_cpp_stop_exception(self):
        # A raising callback must not prevent later callbacks from running.
        callback1 = MagicMock(side_effect=Exception("Test error"))
        callback2 = MagicMock()
        self.service_profiler._stop_callbacks = [callback1, callback2]
        self.service_profiler._on_cpp_stop()
        callback1.assert_called_once()
        callback2.assert_called_once()

    def test_register_profiler_start_callback_dynamic(self):
        callback = MagicMock()
        self._stub_native_library()
        self.service_profiler._cpp_callbacks_registered = False
        result = self.service_profiler.register_profiler_start_callback(callback)
        self.assertEqual(result.mode, ProfilerCallbackResult.DYNAMIC)
        self.assertIn(callback, self.service_profiler._start_callbacks)

    def test_register_profiler_start_callback_legacy(self):
        callback = MagicMock()
        self.service_profiler._func_register_start_callback = None
        self.service_profiler._func_register_stop_callback = None
        result = self.service_profiler.register_profiler_start_callback(callback)
        self.assertEqual(result.mode, ProfilerCallbackResult.LEGACY)
        self.assertIn(callback, self.service_profiler._start_callbacks)

    def test_register_profiler_stop_callback_dynamic(self):
        callback = MagicMock()
        self._stub_native_library()
        self.service_profiler._cpp_callbacks_registered = False
        result = self.service_profiler.register_profiler_stop_callback(callback)
        self.assertEqual(result.mode, ProfilerCallbackResult.DYNAMIC)
        self.assertIn(callback, self.service_profiler._stop_callbacks)

    def test_register_profiler_stop_callback_legacy(self):
        callback = MagicMock()
        self.service_profiler._func_register_start_callback = None
        self.service_profiler._func_register_stop_callback = None
        result = self.service_profiler.register_profiler_stop_callback(callback)
        self.assertEqual(result.mode, ProfilerCallbackResult.LEGACY)
        self.assertIn(callback, self.service_profiler._stop_callbacks)

    def test_ensure_cpp_callbacks_registered_already_registered(self):
        self.service_profiler._cpp_callbacks_registered = True
        result = self.service_profiler._ensure_cpp_callbacks_registered()
        self.assertTrue(result)
        self.assertEqual(len(self.service_profiler._c_callback_refs), 0)

    @patch(_LIB_PATH_TARGET)
    def test_ensure_cpp_callbacks_registered_lib_none(self, mock_get_valid_lib_path):
        mock_get_valid_lib_path.return_value = None
        profiler = LibServiceProfiler()
        profiler.init()
        result = profiler._ensure_cpp_callbacks_registered()
        self.assertFalse(result)

    def test_ensure_cpp_callbacks_registered_no_register_funcs(self):
        self._stub_native_library(with_register_funcs=False)
        result = self.service_profiler._ensure_cpp_callbacks_registered()
        self.assertFalse(result)

    def test_register_start_callback_message_dynamic(self):
        callback = MagicMock()
        self._stub_native_library()
        self.service_profiler._cpp_callbacks_registered = False
        result = self.service_profiler.register_profiler_start_callback(callback)
        self.assertEqual(result.message, "Callback registered successfully")

    def test_register_stop_callback_message_legacy(self):
        callback = MagicMock()
        self.service_profiler._func_register_start_callback = None
        self.service_profiler._func_register_stop_callback = None
        result = self.service_profiler.register_profiler_stop_callback(callback)
        self.assertEqual(result.message, "C++ library does not support dynamic callbacks")
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()