import os
import unittest
from mock import patch
from msprof_analyze.compare_tools.compare_backend.utils.args_manager import ArgsManager
from msprof_analyze.compare_tools.compare_backend.utils.compare_args import Args
from msprof_analyze.prof_common.constant import Constant
from msprof_analyze.prof_common.file_manager import FileManager
from msprof_analyze.prof_common.path_manager import PathManager
class TestArgsManager(unittest.TestCase):
    """Unit tests for ``ArgsManager``: construction, ``init`` and ``parse_profiling_path``.

    ``ArgsManager`` caches its instance in ``ArgsManager._instance``; the cache
    is cleared before and after every test so that tests do not leak state
    into each other.
    """

    def setUp(self):
        """Reset the cached instance and build an ``ArgsManager`` from sample arguments."""
        ArgsManager._instance = {}
        self.args = Args(
            base_profiling_path="/path/to/base/profiling",
            comparison_profiling_path="/path/to/comparison/profiling",
            base_step="1",
            comparison_step="2",
        )
        self.args_manager = ArgsManager(self.args)

    def tearDown(self) -> None:
        """Drop the cached instance so the next test starts clean."""
        ArgsManager._instance = {}

    def test_singleton_pattern(self):
        """Constructing ``ArgsManager`` twice yields the very same object."""
        another_args_manager = ArgsManager(self.args)
        self.assertIs(self.args_manager, another_args_manager)

    # NOTE: stacked ``patch`` decorators inject mocks bottom-up, so the first
    # mock parameter always belongs to the decorator closest to the ``def``.
    @patch.object(PathManager, 'check_input_directory_path')
    @patch.object(PathManager, 'check_input_file_path')
    @patch('os.path.exists', return_value=True)
    def test_check_profiling_path_success(self, mock_exists, mock_file_check, mock_directory_check):
        """``check_profiling_path`` succeeds when every path check is mocked out."""
        self.args_manager.check_profiling_path({"profiling_path": "/valid/path"})
        self.assertEqual(mock_directory_check.call_count, 5)

    @patch.object(PathManager, 'check_input_directory_path')
    @patch.object(PathManager, 'check_input_file_path')
    @patch('os.path.isfile')
    @patch('os.listdir')
    def test_init_with_default_value(self, mock_listdir, mock_isfile, mock_check_file, mock_check_dir):
        """``init`` fills in the expected attributes for the sample arguments."""
        mock_listdir.return_value = [""]
        # One answer per ``os.path.isfile`` call, consumed in order.
        mock_isfile.side_effect = [False, True, False, True]
        self.args_manager.init()

        self.assertEqual(self.args_manager.base_profiling_path, "/path/to/base/profiling")
        # Steps were passed as strings and are expected back as integers.
        self.assertEqual(self.args_manager.base_step, 1)
        self.assertEqual(self.args_manager.comparison_step, 2)
        self.assertEqual(self.args_manager.comparison_profiling_type, "NPU")
        self.assertEqual(len(self.args_manager.base_path_dict), 4)
        self.assertEqual(len(self.args_manager.comparison_path_dict), 4)
        self.assertTrue(self.args_manager.enable_memory_compare)
        self.assertTrue(self.args_manager.enable_communication_compare)
        self.assertFalse(self.args_manager.use_kernel_type)

    def test_init_with_invalid_max_kernel_num(self):
        """``init`` rejects ``max_kernel_num=3``; the value has to be greater than 3."""
        # Discard the instance created in setUp so the new arguments are used.
        ArgsManager._instance = {}
        arg_manager = ArgsManager(Args(
            max_kernel_num=3
        ))
        with self.assertRaises(RuntimeError) as exec_info:
            arg_manager.init()
        # Checked after the ``with`` block: statements placed after the raising
        # call inside the block would never execute.
        self.assertEqual(exec_info.exception.args, ("Invalid param, --max_kernel_num has to be greater than 3",))

    def test_set_compare_type(self):
        """``set_compare_type`` switches on the flag matching each compare type."""
        self.args_manager.set_compare_type(Constant.OVERALL_COMPARE)
        self.assertTrue(self.args_manager.enable_profiling_compare)
        self.args_manager.set_compare_type(Constant.OPERATOR_COMPARE)
        self.assertTrue(self.args_manager.enable_operator_compare)
        self.args_manager.set_compare_type(Constant.API_COMPARE)
        self.assertTrue(self.args_manager.enable_api_compare)
        self.args_manager.set_compare_type(Constant.KERNEL_COMPARE)
        self.assertTrue(self.args_manager.enable_kernel_compare)

    @patch.object(PathManager, 'check_input_file_path')
    @patch.object(os.path, 'isfile')
    @patch.object(os.path, 'split')
    @patch.object(os.path, 'splitext')
    @patch.object(FileManager, 'check_json_type')
    def test_parse_profiling_path_json_file(self, mock_check_json_type, mock_splitext,
                                            mock_split, mock_isfile, mock_path_check):
        """A single JSON file path is parsed into type, profiling path and trace path."""
        mock_path_check.return_value = None
        mock_isfile.return_value = True
        mock_split.return_value = ("/path/to", "file.json")
        mock_splitext.return_value = ("file", ".json")
        mock_check_json_type.return_value = Constant.GPU

        result = ArgsManager(Args()).parse_profiling_path("/path/to/file.json")

        expected_result = {
            Constant.PROFILING_TYPE: Constant.GPU,
            Constant.PROFILING_PATH: "/path/to/file.json",
            Constant.TRACE_PATH: "/path/to/file.json"
        }
        self.assertEqual(result, expected_result)
        mock_path_check.assert_called_once_with("/path/to/file.json")
        mock_isfile.assert_called_once_with("/path/to/file.json")
        mock_split.assert_called_once_with("/path/to/file.json")
        mock_splitext.assert_called_once_with("file.json")
        mock_check_json_type.assert_called_once_with("/path/to/file.json")

    @patch.object(PathManager, 'check_input_file_path')
    @patch.object(os.path, 'isfile')
    @patch.object(os.path, 'split')
    @patch.object(os.path, 'splitext')
    def test_parse_profiling_path_db_file(self, mock_splitext, mock_split, mock_isfile, mock_path_check):
        """A single ``.db`` file path is parsed as NPU profiling with a profiler DB path."""
        mock_path_check.return_value = None
        mock_isfile.return_value = True
        mock_split.return_value = ("/path/to", "ascend_pytorch_profiler.db")
        mock_splitext.return_value = ("ascend_pytorch_profiler", ".db")

        # Pass ``Args()`` like every other test in this class does.
        result = ArgsManager(Args()).parse_profiling_path("/path/to/ascend_pytorch_profiler.db")

        expected_result = {
            Constant.PROFILING_TYPE: Constant.NPU,
            Constant.PROFILING_PATH: "/path/to/ascend_pytorch_profiler.db",
            Constant.PROFILER_DB_PATH: "/path/to/ascend_pytorch_profiler.db"
        }
        self.assertEqual(result, expected_result)

    @patch.object(PathManager, 'check_input_file_path')
    @patch.object(os.path, 'isfile')
    @patch.object(os.path, 'split')
    @patch.object(os.path, 'splitext')
    def test_parse_profiling_path_invalid_file_extension(self, mock_splitext, mock_split, mock_isfile, mock_path_check):
        """A file with an unsupported extension (``.txt``) raises ``RuntimeError``."""
        mock_path_check.return_value = None
        mock_isfile.return_value = True
        mock_split.return_value = ("/path/to", "file.txt")
        mock_splitext.return_value = ("file", ".txt")

        with self.assertRaises(RuntimeError) as context:
            ArgsManager(Args()).parse_profiling_path("/path/to/file.txt")
        self.assertIn("Invalid profiling path suffix", str(context.exception))

    @patch.object(PathManager, 'check_input_directory_path')
    @patch.object(os.path, 'isfile')
    @patch.object(os.path, 'isdir')
    @patch('os.listdir')
    @patch.object(os.path, 'join')
    def test_parse_profiling_path_directory_with_profiler_info(self, mock_join, mock_listdir,
                                                               mock_isdir, mock_isfile, mock_path_check):
        """A directory listing only ``profiler_info.json`` is rejected as invalid.

        With the mocked filesystem below, ``parse_profiling_path`` is expected
        to raise ``RuntimeError`` mentioning "Invalid profiling path".
        """
        mock_path_check.return_value = None
        mock_isfile.side_effect = [False, False]
        mock_isdir.side_effect = [True, False]
        mock_listdir.side_effect = [
            ["profiler_info.json", "other_file.txt"],
            []
        ]
        mock_join.return_value = "/path/to/directory/profiler_info.json"

        with self.assertRaises(RuntimeError) as context:
            ArgsManager(Args()).parse_profiling_path("/path/to/directory")
        self.assertIn("Invalid profiling path", str(context.exception))

    @patch.object(PathManager, 'check_input_directory_path')
    @patch.object(os.path, 'isfile')
    @patch.object(os.path, 'isdir')
    @patch('os.listdir')
    @patch.object(os.path, 'join')
    def test_parse_profiling_path_directory_with_db_file(self, mock_join, mock_listdir,
                                                         mock_isdir, mock_isfile, mock_path_check):
        """A directory whose second listing holds a ``.db`` file yields a profiler DB path."""
        mock_path_check.return_value = None
        mock_isfile.return_value = False
        mock_isdir.side_effect = [True, False]
        mock_listdir.side_effect = [
            ["other_file.txt"],
            ["ascend_pytorch_profiler.db", "other_file.txt"]
        ]
        # Values returned by successive ``os.path.join`` calls, in order.
        mock_join.side_effect = [
            "/path/to/directory/ASCEND_PROFILER_OUTPUT",
            "/path/to/directory/ascend_pytorch_profiler.db"
        ]

        result = ArgsManager(Args()).parse_profiling_path("/path/to/directory")

        expected_result = {
            Constant.PROFILING_TYPE: Constant.NPU,
            Constant.PROFILING_PATH: "/path/to/directory",
            Constant.PROFILER_DB_PATH: "/path/to/directory/ascend_pytorch_profiler.db",
            Constant.ASCEND_OUTPUT_PATH: "/path/to/directory/ASCEND_PROFILER_OUTPUT"
        }
        self.assertEqual(result, expected_result)

    @patch.object(PathManager, 'check_input_directory_path')
    @patch.object(os.path, 'isfile')
    @patch.object(os.path, 'isdir')
    @patch('os.listdir')
    @patch.object(os.path, 'join')
    def test_parse_profiling_path_directory_with_trace_view_json(self, mock_join, mock_listdir,
                                                                 mock_isdir, mock_isfile, mock_path_check):
        """A directory whose second listing holds ``trace_view.json`` yields a trace path."""
        mock_path_check.return_value = None
        mock_isfile.side_effect = [False, True]
        mock_isdir.side_effect = [True, False]
        mock_listdir.side_effect = [
            ["other_file.txt"],
            ["trace_view.json", "other_file.txt"]
        ]
        # Values returned by successive ``os.path.join`` calls, in order.
        mock_join.side_effect = [
            "/path/to/directory/ASCEND_PROFILER_OUTPUT",
            "/path/to/directory/trace_view.json"
        ]

        result = ArgsManager(Args()).parse_profiling_path("/path/to/directory")

        expected_result = {
            Constant.PROFILING_TYPE: Constant.NPU,
            Constant.PROFILING_PATH: "/path/to/directory",
            Constant.TRACE_PATH: "/path/to/directory/trace_view.json",
            Constant.ASCEND_OUTPUT_PATH: "/path/to/directory/ASCEND_PROFILER_OUTPUT"
        }
        self.assertEqual(result, expected_result)

    @patch.object(PathManager, 'input_path_common_check')
    def test_parse_profiling_path_path_validation_fails(self, mock_path_check):
        """An error raised by the common path check propagates out unchanged."""
        mock_path_check.side_effect = RuntimeError("Invalid path")

        with self.assertRaises(RuntimeError) as context:
            ArgsManager(Args()).parse_profiling_path("/invalid/path")
        self.assertEqual("Invalid path", str(context.exception))
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()