import sys
import logging
from unittest import mock
import unittest
import torch
import torch.cuda._sanitizer as csan
from torch.utils._python_dispatch import TorchDispatchMode
import torch_npu
import torch_npu.npu._stream_check as stream_check
from torch_npu.testing.testcase import TestCase, run_tests
class TestStreamCheck(TestCase):
    """Unit tests for ``stream_check.NPUSanitizerDispatchMode``.

    Each test constructs the dispatch mode around a ``MagicMock`` event
    handler and inspects attributes or mocked collaborators afterwards.
    """

    @staticmethod
    def _new_mode():
        """Build a dispatch mode on a fresh mock handler.

        Returns:
            A ``(mode, handler)`` pair, where ``handler`` is the ``MagicMock``
            that was passed to the ``NPUSanitizerDispatchMode`` constructor.
        """
        handler = mock.MagicMock()
        return stream_check.NPUSanitizerDispatchMode(handler), handler

    def test_parse_methods_with_valid_inputs(self):
        """``parse_inputs``/``parse_outputs`` forward their arguments to ``args_handler``."""
        mode, _ = self._new_mode()
        schema = mock.MagicMock()
        positional = (torch.tensor([1.0]), torch.tensor([2.0]))
        keyword = {'test_arg': torch.tensor([3.0])}

        # Replace the argument handler with a mock so the forwarded calls
        # can be inspected.
        mode.args_handler = mock.MagicMock()

        mode.parse_inputs(schema, positional, keyword, is_factory=False)
        mode.args_handler.parse_inputs.assert_called_once_with(
            schema, positional, keyword, is_factory=False
        )

        produced = [torch.tensor([4.0])]
        mode.parse_outputs(schema, produced, is_factory=False)
        mode.args_handler.parse_outputs.assert_called_once_with(
            schema, produced, is_factory=False
        )

    def test_torch_dispatch_success(self):
        """Prepare a mocked operator, stream and parse hooks for a dispatch call.

        NOTE(review): this test only sets up its fixtures; it never invokes
        the dispatch entry point and makes no assertions -- confirm whether
        the body was left unfinished.
        """
        mode, _ = self._new_mode()

        # Stand-in for an operator object: a name plus a mocked schema.
        op = mock.MagicMock()
        op.__name__ = "aten::add"
        op._schema = mock.MagicMock()
        op_args = (torch.tensor([1.0]), torch.tensor([2.0]))
        op_kwargs = {}

        with mock.patch('torch_npu.npu.current_stream') as current_stream:
            # The patched current_stream() returns a stub whose
            # ``npu_stream`` attribute is 1.
            current_stream.return_value = mock.MagicMock(npu_stream=1)
            op_outputs = [torch.tensor([4.0])]
            with mock.patch.object(mode, 'parse_inputs') as parse_inputs_stub, \
                    mock.patch.object(mode, 'parse_outputs') as parse_outputs_stub, \
                    mock.patch.object(mode, 'check_errors') as check_errors_stub:
                pass

    def test_enable_autograd_with_matching_api(self):
        """``enable_autograd`` un-excludes the autograd key for ``adaptive_avg_pool2d``."""
        mode, _ = self._new_mode()
        tls_setter = 'torch._C._dispatch_tls_set_dispatch_key_excluded'
        with mock.patch(tls_setter) as set_excluded:
            mode.enable_autograd("adaptive_avg_pool2d")
            # Exactly one call is expected, with the exclusion flag set to False.
            set_excluded.assert_called_once_with(
                torch._C.DispatchKey.AutogradFunctionality, False
            )

    def test_init_with_event_handler(self):
        """The constructor stores the handler and sets the expected defaults."""
        mode, handler = self._new_mode()
        expected_apis = [
            "adaptive_avg_pool2d",
            "batch_norm",
            "log_softmax",
            "nll_loss",
            "to",
        ]
        self.assertEqual(mode.event_handler, handler)
        self.assertIsNone(mode.args_handler)
        self.assertEqual(mode.npu_adjust_autograd, expected_apis)
# Run this module's tests through the torch_npu test runner when the file is
# executed directly as a script.
if __name__ == "__main__":
    run_tests()