from collections import defaultdict
import tempfile
import unittest
from unittest.mock import Mock, patch
import numpy as np
import mindspore as ms
from mindspore import Tensor, ops
from msprobe.core.common.utils import Const
from msprobe.mindspore.dump.mindspore_service import MindsporeService
from msprobe.core.common.runtime import Runtime
from msprobe.core.dump.common_config import CommonConfig, BaseConfig
from msprobe.mindspore.dump.debugger.debugger_config import DebuggerConfig
from msprobe.mindspore.dump.dump_processor.hook_cell.hook_cell import HOOKCell
from msprobe.mindspore.dump.dump_processor.hook_cell.primitive_hooks import PrimitiveHookService
from msprobe.mindspore.dump.ms_config import StatisticsConfig
class TestPrimitiveHookService(unittest.TestCase):
    """Unit tests for ``PrimitiveHookService``.

    Several tests reach into the closures produced by ``wrap_primitive`` to get
    at its inner helpers. They rely on the closure-cell layout this file has
    always assumed: cell 0 of the wrapped call is ``hook_primitive_inputs``,
    cell 1 is ``hook_primitive_outputs``, and cell 0 of
    ``hook_primitive_inputs`` is ``create_backward_hook``.
    """

    def setUp(self):
        self.temp_dir = tempfile.TemporaryDirectory()
        # Runtime.is_running is process-global state. Remember its value so
        # tearDown can restore it; some tests switch it off, and that must not
        # leak into other test classes.
        self._saved_is_running = getattr(Runtime, "is_running", False)
        dump_path = self.temp_dir.name
        json_config = {
            "task": "statistics",
            "dump_path": dump_path,
            "rank": [],
            "step": [0, 2],
            "level": "L1"
        }
        common_config = CommonConfig(json_config)
        task_config = StatisticsConfig(json_config)
        config = DebuggerConfig(common_config, task_config)
        # Patch out heavy collaborators so constructing the service only wires
        # up mocks. The data collector is presumably the patched
        # build_data_collector's Mock return value; the backward-hook tests
        # inspect calls on it.
        with patch('msprobe.core.dump.service.build_data_collector'), \
                patch('msprobe.mindspore.dump.mindspore_service.CellProcessor'), \
                patch('msprobe.mindspore.dump.mindspore_service.PrimitiveHookService'), \
                patch('msprobe.mindspore.dump.mindspore_service.get_api_register'):
            self.mock_service_instance = MindsporeService(config)
        Runtime.is_running = True
        self.primitive_hook_service = PrimitiveHookService(self.mock_service_instance)

    def tearDown(self):
        Runtime.is_running = self._saved_is_running
        self.temp_dir.cleanup()

    # ------------------------------------------------------------------ helpers

    def _wrapped_closure_cell(self, index, origin_func=None, name="example"):
        """Return closure cell ``index`` of ``wrap_primitive(origin_func, name)``."""
        wrapped = self.primitive_hook_service.wrap_primitive(origin_func, name)
        return wrapped.__closure__[index].cell_contents

    def _get_create_backward_hook(self):
        """Return the ``create_backward_hook`` factory captured by ``hook_primitive_inputs``."""
        hook_primitive_inputs = self._wrapped_closure_cell(0)
        return hook_primitive_inputs.__closure__[0].cell_contents

    def _drive_backward_hook(self, num_tensors, hook_type):
        """Feed ``num_tensors`` 3-element gradients through a fresh backward hook.

        Gradient k (1-based) is ``[3k-2, 3k-1, 3k]`` as float64. After each call,
        ``captured_grads`` is expected to have grown by 3 entries.
        """
        captured_grads = []
        create_backward_hook = self._get_create_backward_hook()
        backward_hook = create_backward_hook(captured_grads, num_tensors, "test_primitive_output", hook_type)
        for idx in range(num_tensors):
            grad = np.arange(3 * idx + 1, 3 * idx + 4, dtype=np.float64)
            backward_hook(grad)
            self.assertEqual(len(captured_grads), 3 * (idx + 1))

    # ---------------------------------------------------- backward-hook tests

    def test_two_input_backward_hook(self):
        # INPUT-type hooks report through backward_output_data_collect.
        self._drive_backward_hook(2, Const.INPUT)
        self.assertTrue(self.mock_service_instance.data_collector.backward_output_data_collect.called)

    def test_four_input_backward_hook(self):
        self._drive_backward_hook(4, Const.INPUT)
        self.assertTrue(self.mock_service_instance.data_collector.backward_output_data_collect.called)

    def test_two_output_backward_hook(self):
        # OUTPUT-type hooks report through backward_input_data_collect.
        self._drive_backward_hook(2, Const.OUTPUT)
        self.assertTrue(self.mock_service_instance.data_collector.backward_input_data_collect.called)

    def test_four_output_backward_hook(self):
        self._drive_backward_hook(4, Const.OUTPUT)
        self.assertTrue(self.mock_service_instance.data_collector.backward_input_data_collect.called)

    # -------------------------------------------------- input/output hooking

    def test_hook_primitive_inputs(self):
        args = (Tensor(np.array([1.0, 2.0]), ms.float32), Tensor(np.array([3.0, 4.0]), ms.float32))
        hook_primitive_inputs = self._wrapped_closure_cell(0)
        with patch.object(ops, 'HookBackward') as mock_hook_backward_cls:
            target_value = Tensor([1.0])
            # Each HookBackward(...) instance, when applied, yields target_value.
            mock_hook_backward_cls.return_value.return_value = target_value
            hooked_inputs = hook_primitive_inputs(args, [], "test_primitive_input")
            self.assertEqual(mock_hook_backward_cls.call_count, len(args))
            for hooked_input in hooked_inputs:
                self.assertTrue((hooked_input == target_value).all())

    def test_hook_primitive_outputs(self):
        out = (Tensor(np.array([1.0, 2.0]), ms.float32), Tensor(np.array([3.0, 4.0]), ms.float32))
        hook_primitive_outputs = self._wrapped_closure_cell(1)
        with patch.object(ops, 'HookBackward') as mock_hook_backward_cls:
            target_value = Tensor([1.0])
            mock_hook_backward_cls.return_value.return_value = target_value
            hooked_outputs = hook_primitive_outputs(out, [], "test_primitive_output")
            self.assertEqual(mock_hook_backward_cls.call_count, len(out))
            for hooked_output in hooked_outputs:
                self.assertTrue((hooked_output == target_value).all())

    def test_wrapped_primitive_call_args(self):
        args = (Tensor(np.array([1.0, 2.0]), ms.float32), Tensor(np.array([3.0, 4.0]), ms.float32))
        # No try/except here: unittest already reports unexpected exceptions
        # with a full traceback, and a broad handler would also swallow
        # AssertionError from the checks below.
        hook_primitive_inputs = self._wrapped_closure_cell(0, lambda x, y: x + y, "add")
        with patch.object(ops, 'HookBackward') as mock_hook_backward_cls:
            target_value = Tensor([1.0])
            mock_hook_backward_cls.return_value.return_value = target_value
            hooked_inputs = hook_primitive_inputs(args, [], "test_primitive_args")
            self.assertEqual(mock_hook_backward_cls.call_count, len(args))
            for hooked_input in hooked_inputs:
                self.assertTrue((hooked_input == target_value).all())

    # -------------------------------------------------------- counter tests

    def test_update_primitive_counters_multiple(self):
        for name in ["MatMul", "Conv2D", "ReLU", "Softmax"]:
            for expected in range(3):
                self.primitive_hook_service.update_primitive_counters(name)
                self.assertEqual(self.primitive_hook_service.primitive_counters[name], expected)

    def test_update_primitive_counters_different_names(self):
        for name in ["MatMul", "Add", "Sub", "Mul", "Conv2D"]:
            for expected in range(5):
                self.primitive_hook_service.update_primitive_counters(name)
                self.assertEqual(self.primitive_hook_service.primitive_counters[name], expected)

    def test_update_primitive_counters(self):
        primitive_name = "MatMul"
        # The first update yields 0, the second yields 1.
        self.primitive_hook_service.update_primitive_counters(primitive_name)
        self.assertEqual(self.primitive_hook_service.primitive_counters[primitive_name], 0)
        self.primitive_hook_service.update_primitive_counters(primitive_name)
        self.assertEqual(self.primitive_hook_service.primitive_counters[primitive_name], 1)

    # ------------------------------------------------------ wrap_primitive

    @patch('msprobe.mindspore.dump.dump_processor.hook_cell.primitive_hooks.ops.HookBackward')
    def test_wrap_primitive_forward_hook_various_inputs(self, mock_hook_backward):
        input_tensors = [
            Tensor(np.random.randn(2, 2).astype(np.float32)),
            Tensor(np.random.randn(4, 4).astype(np.float32)),
            Tensor(np.random.randn(10, 10).astype(np.float32)),
        ]
        for input_tensor in input_tensors:
            with self.subTest(shape=input_tensor.shape):
                mock_origin_func = Mock(return_value=input_tensor)
                wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul")
                result = wrapped_func(Mock(), input_tensor)
                mock_origin_func.assert_called_once()
                self.assertIsInstance(result, Tensor)

    def test_wrap_primitive_no_hook_with_invalid_input(self):
        # With dumping switched off, the wrapper should pass values straight through.
        Runtime.is_running = False
        for invalid_input in [None, "invalid_tensor", 123]:
            with self.subTest(invalid_input=invalid_input):
                mock_origin_func = Mock(return_value=invalid_input)
                wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul")
                result = wrapped_func(Mock(), invalid_input)
                mock_origin_func.assert_called_once()
                self.assertEqual(result, invalid_input)

    @patch('msprobe.mindspore.dump.dump_processor.hook_cell.primitive_hooks.ops.HookBackward')
    def test_wrap_primitive_with_multiple_hooks(self, mock_hook_backward):
        input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32))
        for name in ["MatMul", "Add", "Sub"]:
            with self.subTest(primitive=name):
                mock_origin_func = Mock(return_value=input_tensor)
                wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, name)
                result = wrapped_func(Mock(), input_tensor)
                mock_origin_func.assert_called_once()
                self.assertIsInstance(result, Tensor)

    @patch('msprobe.mindspore.dump.dump_processor.hook_cell.primitive_hooks.ops.HookBackward')
    def test_wrap_primitive_with_exception_handling_multiple(self, mock_hook_backward):
        input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32))
        for exception_message in ["Invalid operation", "Null reference", "Type error"]:
            with self.subTest(message=exception_message):
                mock_origin_func = Mock(side_effect=Exception(exception_message))
                wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul")
                with self.assertRaises(Exception) as context:
                    wrapped_func(Mock(), input_tensor)
                self.assertIn(exception_message, str(context.exception))

    def test_create_backward_hook_multiple(self):
        # Exercise create_backward_hook itself. The previous version called the
        # wrapped primitive and asserted on a Mock's return value, which
        # could never fail.
        create_backward_hook = self._get_create_backward_hook()
        captured_grads_sets = [[Mock()], [Mock(), Mock()], [Mock(), Mock(), Mock()]]
        for captured_grads in captured_grads_sets:
            with self.subTest(num_tensors=len(captured_grads)):
                hook = create_backward_hook(captured_grads, len(captured_grads), "MatMul.Backward", Const.INPUT)
                self.assertIsNotNone(hook)
                self.assertTrue(callable(hook))

    @patch('msprobe.mindspore.dump.dump_processor.hook_cell.primitive_hooks.ops.HookBackward')
    def test_wrap_primitive_forward_and_backward_hooks(self, mock_hook_backward):
        input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32))
        mock_origin_func = Mock(return_value=input_tensor)
        wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "Conv2D")
        result = wrapped_func(Mock(), input_tensor)
        mock_origin_func.assert_called_once()
        self.assertIsInstance(result, Tensor)

    @patch('msprobe.mindspore.dump.dump_processor.hook_cell.primitive_hooks.ops.HookBackward')
    def test_wrap_primitive_forward_hook(self, mock_hook_backward):
        input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32))
        mock_origin_func = Mock(return_value=input_tensor)
        wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul")
        result = wrapped_func(Mock(), input_tensor)
        mock_origin_func.assert_called_once()
        self.assertIsInstance(result, Tensor)

    @patch('msprobe.mindspore.dump.dump_processor.hook_cell.primitive_hooks.ops.HookBackward')
    def test_wrap_primitive_backward_hook(self, mock_hook_backward):
        input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32))
        grad_tensor = Tensor(np.random.randn(2, 2).astype(np.float32))
        mock_hook_backward.return_value = lambda x: grad_tensor
        mock_origin_func = Mock(return_value=input_tensor)
        wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul")
        with patch.object(self.mock_service_instance.data_collector, 'backward_data_collect'):
            result = wrapped_func(Mock(), input_tensor)
            self.assertIsInstance(result, Tensor)

    def test_wrap_primitive_no_hook_when_switch_off(self):
        # Reset the global HOOKCell counter registry after the test, even if
        # it fails (the original reset it inline before the final assertion).
        self.addCleanup(setattr, HOOKCell, "cell_count", defaultdict(int))
        Runtime.is_running = False
        input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32))
        mock_origin_func = Mock(return_value=input_tensor)
        wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul")
        result = wrapped_func(Mock(), input_tensor)
        mock_origin_func.assert_called_once()
        self.assertTrue((result == input_tensor).all())

    @patch('msprobe.mindspore.dump.dump_processor.hook_cell.primitive_hooks.ops.HookBackward')
    def test_wrap_primitive_error_handling(self, mock_hook_backward):
        input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32))
        mock_origin_func = Mock(side_effect=Exception("Mocked exception"))
        wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul")
        with self.assertRaises(Exception) as context:
            wrapped_func(Mock(), input_tensor)
        self.assertIn("Mocked exception", str(context.exception))

    @patch('msprobe.mindspore.dump.dump_processor.hook_cell.primitive_hooks.ops.HookBackward')
    def test_create_backward_hook(self, mock_hook_backward):
        # Obtain create_backward_hook via the closure chain instead of calling
        # the wrapped primitive, which never reached the factory.
        create_backward_hook = self._get_create_backward_hook()
        hook = create_backward_hook([], 1, "MatMul.Backward", Const.INPUT)
        self.assertIsNotNone(hook)
        self.assertTrue(callable(hook))