"""
UTs for python/msprobe/msprobe.py
"""
from unittest import TestCase
from unittest.mock import patch, MagicMock
from msprobe.msprobe import main
class TestMsprobeMain(TestCase):
    """Check that ``main`` hands each sub-command to the expected handler.

    Every handler (and, where used, ``ArgumentParser.parse_args``) is replaced
    by a mock, so these tests only verify *which* callable ``main`` invokes and
    with *which* arguments - never the handlers' own behaviour.
    """

    @staticmethod
    def _run_main(argv):
        """Call ``main()`` with ``msprobe.msprobe.sys.argv`` patched to *argv*."""
        with patch("msprobe.msprobe.sys.argv", argv):
            main()

    def _assert_raw_args_forwarded(self, subcommand, mock_cli):
        """Run ``msprobe <subcommand> --arg1 value1`` and check the CLI mock.

        Asserts that *mock_cli* was called exactly once with the arguments
        that follow the sub-command name.
        """
        forwarded = ["--arg1", "value1"]
        self._run_main(["msprobe", subcommand, *forwarded])
        mock_cli.assert_called_once_with(forwarded)

    def _assert_parsed_args_dispatched(self, subcommand, mock_parse_args, mock_handler, *extra_args):
        """Run ``msprobe <subcommand>`` and check parsing and dispatch.

        Asserts that ``parse_args`` was called once with ``[subcommand]`` and
        that *mock_handler* was called once with the object returned by
        ``parse_args``, followed by any *extra_args*.
        """
        parsed = MagicMock()
        mock_parse_args.return_value = parsed
        self._run_main(["msprobe", subcommand])
        mock_parse_args.assert_called_once_with([subcommand])
        mock_handler.assert_called_once_with(parsed, *extra_args)

    @patch("msprobe.msprobe.argparse.ArgumentParser")
    def test_main_when_no_args_then_pass(self, mock_arg_parser):
        """With only the program name in argv, help is printed and exit code is 0."""
        parser_instance = MagicMock()
        subparsers_instance = MagicMock()
        mock_arg_parser.return_value = parser_instance
        parser_instance.add_subparsers.return_value = subparsers_instance
        with patch("msprobe.msprobe.sys") as mock_sys:
            mock_sys.argv = ["msprobe"]
            # The mocked sys.exit would otherwise return normally; raising
            # SystemExit(0) makes the exit observable to assertRaises.
            mock_sys.exit.side_effect = SystemExit(0)
            with self.assertRaises(SystemExit) as cm:
                main()
        # cm.exception is only populated once the assertRaises block has exited.
        self.assertEqual(cm.exception.code, 0)
        parser_instance.print_help.assert_called_once()

    @patch("msprobe.msprobe.acc_check_cli")
    def test_main_when_acc_check_subcommand_then_pass(self, mock_acc_check_cli):
        """``acc_check`` passes the remaining argv entries to ``acc_check_cli``."""
        self._assert_raw_args_forwarded("acc_check", mock_acc_check_cli)

    @patch("msprobe.msprobe.multi_acc_check_cli")
    def test_main_when_multi_acc_check_subcommand_then_pass(self, mock_multi_acc_check_cli):
        """``multi_acc_check`` passes the remaining argv entries to ``multi_acc_check_cli``."""
        self._assert_raw_args_forwarded("multi_acc_check", mock_multi_acc_check_cli)

    @patch("msprobe.msprobe.compare_cli")
    @patch("msprobe.msprobe.argparse.ArgumentParser.parse_args")
    def test_main_when_compare_subcommand_then_pass(self, mock_parse_args, mock_compare_cli):
        """``compare`` calls ``compare_cli`` with the parsed args and ``["compare"]``."""
        # compare_cli is the only handler here expected to receive a second
        # positional argument after the parsed namespace.
        self._assert_parsed_args_dispatched("compare", mock_parse_args, mock_compare_cli, ["compare"])

    @patch("msprobe.msprobe.merge_result_cli")
    @patch("msprobe.msprobe.argparse.ArgumentParser.parse_args")
    def test_main_when_merge_result_subcommand_then_pass(self, mock_parse_args, mock_merge_result_cli):
        """``merge_result`` calls ``merge_result_cli`` with the parsed args."""
        self._assert_parsed_args_dispatched("merge_result", mock_parse_args, mock_merge_result_cli)

    @patch("msprobe.msprobe._run_overflow_check")
    @patch("msprobe.msprobe.argparse.ArgumentParser.parse_args")
    def test_main_when_overflow_check_subcommand_then_pass(self, mock_parse_args, mock_run_overflow_check):
        """``overflow_check`` calls ``_run_overflow_check`` with the parsed args."""
        self._assert_parsed_args_dispatched("overflow_check", mock_parse_args, mock_run_overflow_check)

    @patch("msprobe.msprobe._graph_service_command")
    @patch("msprobe.msprobe.argparse.ArgumentParser.parse_args")
    def test_main_when_graph_visualize_subcommand_then_pass(self, mock_parse_args, mock_graph_service_cmd):
        """``graph_visualize`` calls ``_graph_service_command`` with the parsed args."""
        self._assert_parsed_args_dispatched("graph_visualize", mock_parse_args, mock_graph_service_cmd)

    # NOTE(review): unlike the other handlers, this one is patched in its
    # defining module rather than on msprobe.msprobe - presumably main()
    # imports it at call time; confirm against msprobe.py.
    @patch("msprobe.pytorch.api_accuracy_checker.compare.api_precision_compare._api_precision_compare_command")
    @patch("msprobe.msprobe.argparse.ArgumentParser.parse_args")
    def test_main_when_api_precision_compare_subcommand_then_pass(self, mock_parse_args, mock_api_precision_cmd):
        """``api_precision_compare`` calls ``_api_precision_compare_command`` with the parsed args."""
        self._assert_parsed_args_dispatched("api_precision_compare", mock_parse_args, mock_api_precision_cmd)

    @patch("msprobe.msprobe._run_config_checking_command")
    @patch("msprobe.msprobe.argparse.ArgumentParser.parse_args")
    def test_main_when_config_check_subcommand_then_pass(self, mock_parse_args, mock_run_config):
        """``config_check`` calls ``_run_config_checking_command`` with the parsed args."""
        self._assert_parsed_args_dispatched("config_check", mock_parse_args, mock_run_config)

    @patch("msprobe.msprobe._data2db_command")
    @patch("msprobe.msprobe.argparse.ArgumentParser.parse_args")
    def test_main_when_data2db_subcommand_then_pass(self, mock_parse_args, mock_data2db_cmd):
        """``data2db`` calls ``_data2db_command`` with the parsed args."""
        self._assert_parsed_args_dispatched("data2db", mock_parse_args, mock_data2db_cmd)

    @patch("msprobe.msprobe.offline_dump_cli")
    @patch("msprobe.msprobe.argparse.ArgumentParser.parse_args")
    def test_main_when_offline_dump_subcommand_then_pass(self, mock_parse_args, mock_offline_dump_cli):
        """``offline_dump`` calls ``offline_dump_cli`` with the parsed args."""
        self._assert_parsed_args_dispatched("offline_dump", mock_parse_args, mock_offline_dump_cli)

    @patch("msprobe.msprobe.install_deps_cli")
    @patch("msprobe.msprobe.argparse.ArgumentParser.parse_args")
    def test_main_when_install_deps_subcommand_then_pass(self, mock_parse_args, mock_install_deps_cli):
        """``install_deps`` calls ``install_deps_cli`` with the parsed args."""
        self._assert_parsed_args_dispatched("install_deps", mock_parse_args, mock_install_deps_cli)