import sys
import builtins
import unittest
from unittest.mock import Mock, patch
import torch
ACTION_NAME = "test"
FLAG = True
UNFLAG = False
# Captured at import time, before any test patches ``builtins.__import__``,
# so the fake importer below can still delegate to the genuine implementation.
REAL_IMPORT = builtins.__import__

# Root package whose absence the "error import" tests simulate.
_BLOCKED_PACKAGE = "ms_service_profiler"


def mock_import_module(name, *args, **kwargs):
    """Stand-in for ``builtins.__import__`` that hides ``ms_service_profiler``.

    Both the package itself and any dotted submodule of it (for example
    ``ms_service_profiler.mstx``) are rejected. ``import a.b`` and
    ``from a.b import c`` hand the full dotted name to ``__import__``, so
    blocking only the bare package name would let submodule imports through.
    Every other import is forwarded unchanged to the real importer.

    Args:
        name: module name as passed to ``__import__``.
        *args, **kwargs: remaining ``__import__`` arguments (globals, locals,
            fromlist, level), forwarded untouched.

    Returns:
        Whatever the real ``__import__`` returns for non-blocked names.

    Raises:
        ModuleNotFoundError: for the blocked package or any of its submodules.
            This is the subclass of ``ImportError`` that a genuinely missing
            package raises.
    """
    if name == _BLOCKED_PACKAGE or name.startswith(_BLOCKED_PACKAGE + "."):
        raise ModuleNotFoundError("Mocked ImportERROR", name=name)
    return REAL_IMPORT(name, *args, **kwargs)
class TestProfilerFunctions(unittest.TestCase):
def __init__(self, methodName="runTest"):
super().__init__(methodName)
self.torch_npu = None
def clean(self):
if "mindie_llm.utils.prof.profiler" in sys.modules:
del sys.modules["mindie_llm.utils.prof.profiler"]
if "ms_service_profiler" in sys.modules:
del sys.modules["ms_service_profiler"]
if "ms_service_profiler.mstx" in sys.modules:
del sys.modules["ms_service_profiler.mstx"]
if "mindie_llm.utils.log.logging" in sys.modules:
del sys.modules["mindie_llm.utils.log.logging"]
if "torch_npu" in sys.modules:
del sys.modules["torch_npu"]
del torch.npu
def setUp(self):
self.torch_npu = sys.modules.get("torch_npu", None)
self.clean()
torch.npu = Mock()
sys.modules["torch_npu"] = Mock()
def tearDown(self):
self.clean()
if self.torch_npu is not None:
sys.modules["torch_npu"] = self.torch_npu
self.torch_npu = None
def test_span_start_with_error_import(self):
with patch('builtins.__import__', side_effect=mock_import_module):
from mindie_llm.utils.prof.profiler import span_start
try:
span_start(ACTION_NAME, FLAG)
except Exception as e:
self.fail(f"Function raised an exception when it shouldn't have: {e}")
def test_span_end_with_error_import(self):
with patch('builtins.__import__', side_effect=mock_import_module):
from mindie_llm.utils.prof.profiler import span_start, span_end
try:
prof = span_start(ACTION_NAME)
span_end(prof, FLAG)
except Exception as e:
self.fail(f"Function raised an exception when it shouldn't have: {e}")
def test_span_req_with_error_import(self):
with patch('builtins.__import__', side_effect=mock_import_module):
from mindie_llm.utils.prof.profiler import span_req
try:
span_req(ACTION_NAME, ["rid_1234"])
except Exception as e:
self.fail(f"Function raised an exception when it shouldn't have: {e}")
def test_profiler_with_error_import(self):
with patch('builtins.__import__', side_effect=mock_import_module):
from mindie_llm.utils.prof.profiler import Profiler
try:
prof = Profiler()
prof.span_start()
prof.span_end()
except Exception as e:
self.fail(f"Function raised an exception when it shouldn't have: {e}")
def test_count_block_with_success_import(self):
from mindie_llm.utils.prof.profiler import count_block
from mindie_llm.utils.prof.profiler import np
try:
count_block(np.zeros((2, 3, 4), dtype=np.int32))
count_block(np.zeros((10, 2), dtype=np.int32))
except Exception as e:
self.fail(f"Function raised an exception when it shouldn't have: {e}")
def test_tensor_attr_with_error_import(self):
"""测试在导入失败的情况下 tensor_attr 是否可用"""
with patch('builtins.__import__', side_effect=mock_import_module):
from mindie_llm.utils.prof.profiler import tensor_attr
test_tensor = torch.Tensor([1, 2, 3, 4, 5])
try:
_ = tensor_attr(test_tensor, statistics=True)
except Exception as e:
self.fail(f"tensor_attr raised an exception when it shouldn't have: {e}")
def test_tensor_attr_statistics_mode(self):
sys.modules["ms_service_profiler"] = Mock()
sys.modules["ms_service_profiler.mstx"] = Mock()
sys.modules["mindie_llm.utils.log.logging"] = Mock()
sys.modules["mindie_llm.utils.tensor"] = Mock()
"""测试 tensor_attr 在 statistics=True 时的行为"""
from mindie_llm.utils.prof.profiler import tensor_attr
test_tensor_1d = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0])
result = tensor_attr(test_tensor_1d, statistics=True)
self.assertIsInstance(result, dict)
self.assertIn("min", result)
self.assertIn("max", result)
self.assertIn("mean", result)
self.assertIn("first_10", result)
self.assertIn("shape", result)
self.assertEqual(result["min"], 1.0)
self.assertEqual(result["max"], 5.0)
self.assertEqual(result["mean"], 3.0)
self.assertEqual(result["first_10"], [1.0, 2.0, 3.0, 4.0, 5.0])
self.assertEqual(result["shape"], [5])
def test_tensor_attr_no_statistics_mode(self):
sys.modules["ms_service_profiler"] = Mock()
sys.modules["ms_service_profiler.mstx"] = Mock()
sys.modules["mindie_llm.utils.log.logging"] = Mock()
sys.modules["mindie_llm.utils.tensor"] = Mock()
"""测试 tensor_attr 在 statistics=False 时的行为"""
from mindie_llm.utils.prof.profiler import tensor_attr
test_tensor_1d = torch.Tensor([1, 2, 3])
result = tensor_attr(test_tensor_1d, statistics=False)
self.assertEqual(result, [1, 2, 3])
def test_tensor_attr_2d_tensor(self):
sys.modules["ms_service_profiler"] = Mock()
sys.modules["ms_service_profiler.mstx"] = Mock()
sys.modules["mindie_llm.utils.log.logging"] = Mock()
sys.modules["mindie_llm.utils.tensor"] = Mock()
"""测试 2D 张量的处理"""
from mindie_llm.utils.prof.profiler import tensor_attr
test_tensor_2d = torch.Tensor([[1, 2, 3], [4, 5, 6]])
result = tensor_attr(test_tensor_2d, statistics=True)
self.assertEqual(result["min"], 1)
self.assertEqual(result["max"], 6)
self.assertEqual(result["mean"], 3.5)
self.assertEqual(result["first_10"], [1, 2, 3, 4, 5, 6])
self.assertEqual(result["shape"], [2, 3])
def test_tensor_attr_large_tensor(self):
sys.modules["ms_service_profiler"] = Mock()
sys.modules["ms_service_profiler.mstx"] = Mock()
sys.modules["mindie_llm.utils.log.logging"] = Mock()
sys.modules["mindie_llm.utils.tensor"] = Mock()
"""测试大型张量(元素超过10个)"""
from mindie_llm.utils.prof.profiler import tensor_attr
test_tensor_large = torch.Tensor(list(range(20)))
result = tensor_attr(test_tensor_large, statistics=True)
self.assertEqual(result["first_10"], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
self.assertEqual(result["shape"], [20])
def test_tensor_attr_tuple_input(self):
"""测试输入为元组的情况"""
with patch('builtins.__import__', side_effect=mock_import_module):
from mindie_llm.utils.prof.profiler import tensor_attr
test_tuple = (1, 2, 3)
result = tensor_attr(test_tuple, statistics=True)
self.assertEqual(result, (1, 2, 3))
def test_tensor_attr_empty_tensor(self):
sys.modules["ms_service_profiler"] = Mock()
sys.modules["ms_service_profiler.mstx"] = Mock()
sys.modules["mindie_llm.utils.log.logging"] = Mock()
sys.modules["mindie_llm.utils.tensor"] = Mock()
"""测试空张量"""
from mindie_llm.utils.prof.profiler import tensor_attr
empty_tensor = torch.Tensor([])
result = tensor_attr(empty_tensor, statistics=True)
self.assertEqual(result, {})
def test_tensor_attr_complex_tensor(self):
sys.modules["ms_service_profiler"] = Mock()
sys.modules["ms_service_profiler.mstx"] = Mock()
sys.modules["mindie_llm.utils.log.logging"] = Mock()
sys.modules["mindie_llm.utils.tensor"] = Mock()
"""测试复杂张量(浮点数、负数等)"""
from mindie_llm.utils.prof.profiler import tensor_attr
complex_tensor = torch.Tensor([-1.5, 0.0, 3.14, -2.7, 10.0])
result = tensor_attr(complex_tensor, statistics=True)
self.assertAlmostEqual(result["min"], -2.7)
self.assertAlmostEqual(result["max"], 10.0)
self.assertAlmostEqual(result["mean"], 1.788, places=2)
self.assertEqual(result["first_10"], torch.Tensor([-1.5, 0.0, 3.14, -2.7, 10.0]).tolist())
def test_tensor_attr_with_import_success(self):
"""测试在成功导入的情况下 tensor_attr 的行为"""
sys.modules["ms_service_profiler"] = Mock()
sys.modules["ms_service_profiler.mstx"] = Mock()
sys.modules["mindie_llm.utils.log.logging"] = Mock()
sys.modules["mindie_llm.utils.tensor"] = Mock()
from mindie_llm.utils.prof.profiler import tensor_attr
test_tensor = torch.Tensor([[1, 2], [3, 4]])
result = tensor_attr(test_tensor, statistics=True)
self.assertIsInstance(result, dict)
self.assertEqual(result["min"], 1)
self.assertEqual(result["max"], 4)
self.assertEqual(result["mean"], 2.5)
def mock_profiler(self):
ms_service_profiler = Mock()
ms_service_profiler.Profiler = lambda _: self.mock_prof()
mstx = Mock()
mstx.service_profiler = Mock()
mstx.service_profiler.is_enable = lambda _: True
return ms_service_profiler, mstx
def mock_prof(self):
prof = Mock()
prof.span_start = Mock()
prof.span_end = Mock()
prof._enable = True
return prof
def test_span_start_when_import_success_no_error(self):
ms_service_profiler, mstx = self.mock_profiler()
sys.modules["ms_service_profiler"] = ms_service_profiler
sys.modules["ms_service_profiler.mstx"] = mstx
sys.modules["mindie_llm.utils.log.logging"] = Mock()
from mindie_llm.utils.prof.profiler import span_start
try:
span_start(ACTION_NAME, FLAG)
except Exception as e:
self.fail(f"Function raised an exception when it shouldn't have: {e}")
def test_span_end_when_import_success_no_error(self):
ms_service_profiler, mstx = self.mock_profiler()
sys.modules["ms_service_profiler"] = ms_service_profiler
sys.modules["ms_service_profiler.mstx"] = mstx
sys.modules["mindie_llm.utils.log.logging"] = Mock()
from mindie_llm.utils.prof.profiler import span_attr, span_end
try:
span_attr(self.mock_prof(), "key", "value")
span_end(self.mock_prof(), FLAG)
except Exception as e:
self.fail(f"Function raised an exception when it shouldn't have: {e}")
def test_span_req_when_import_success(self):
sys.modules["ms_service_profiler"] = Mock()
sys.modules["ms_service_profiler.mstx"] = Mock()
sys.modules["mindie_llm.utils.log.logging"] = Mock()
from mindie_llm.utils.prof.profiler import span_req
try:
span_req(0, [])
except Exception as e:
self.fail(f"Function raised an exception when it shouldn't have: {e}")
def test_span_req_when_import_success_no_error(self):
sys.modules["ms_service_profiler"] = Mock()
sys.modules["ms_service_profiler.mstx"] = Mock()
sys.modules["mindie_llm.utils.log.logging"] = Mock()
from mindie_llm.utils.prof.profiler import span_req
try:
span_req(ACTION_NAME, [1])
except Exception as e:
self.fail(f"Function raised an exception when it shouldn't have: {e}")