"""
start_deploy.py 系统测试
测试 ascend_deployer/start_deploy.py 模块。
"""
import os
import sys
import types
import unittest
from unittest.mock import patch, Mock
# Snapshot of sys.path taken before the insertions below.
_original_sys_path = sys.path.copy()
# Repository root: two directories above this test file.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
# The ascend_deployer package directory. Presumably it is put on sys.path so that
# top-level modules living inside it (e.g. ``jobs``, which the tests import
# directly) can be imported by bare name -- confirm against the package layout.
base_dir = os.path.join(project_root, 'ascend_deployer')
# Insert base_dir first and project_root second; both go to index 0, so
# project_root ends up ahead of base_dir when neither was present before.
if base_dir not in sys.path:
    sys.path.insert(0, base_dir)
if project_root not in sys.path:
    sys.path.insert(0, project_root)
# Module names that the tests replace with stub modules in sys.modules.
JOBS_FIELD = "jobs"
UTILS_FIELD = "utils"
MODULE_UTILS_FIELD = "module_utils"
MODULE_UTILS_PATH_MANAGER_FIELD = "module_utils.path_manager"
# Program name passed to the CLI (used as argv[0]) and logger name in the stub
# logging configuration.
ASCEND_DEPLOYER_FIELD = "ascend_deployer"
# Install / scene / patch item names used as CLI arguments and stub choices.
TOOLKIT_FIELD = "toolkit"
TOOLBOX_FIELD = "toolbox"
TENSORFLOW_VERSION_FIELD = "tensorflow"
MINDSPORE_FIELD = "mindspore"
PYTORCH_FIELD = "pytorch"
NNAE_FIELD = "nnae"
SYS_PKG_FIELD = "sys_pkg"
NPU_FIELD = "npu"
# Keys/values used to build the stub LOGGING_CONFIG, which follows the
# logging.config.dictConfig schema. dictConfig reads both the top-level handler
# definitions and each logger's handler list from the key "handlers"; any other
# spelling is silently ignored and no handler would be attached.
CONSOLE_FIELD = "console"
HANDLER_FIELD = "handlers"
LEVEL_FIELD = "level"
INFO_LEVEL = "INFO"
# CLI flag that selects the packages to install.
INSTALL_FIELD = "--install"
class TestStartDeploy(unittest.TestCase):
    """Tests for the ``ascend_deployer.start_deploy`` module.

    ``start_deploy`` depends on project modules (``jobs``, ``utils`` and
    ``module_utils.path_manager``). These are replaced by in-memory stub
    modules registered in ``sys.modules`` for the duration of this class. The
    previous ``sys.modules`` entries are restored in ``tearDownClass`` so the
    stubs do not leak into other test modules.
    """

    # Description string passed to every CLI instance built by these tests.
    _CLI_DESCRIPTION = "Manage Ascend Packages and dependence packages for specified OS"

    # sys.modules keys replaced by stubs in setUpClass and restored in tearDownClass.
    _STUBBED_MODULE_NAMES = (
        JOBS_FIELD,
        UTILS_FIELD,
        MODULE_UTILS_FIELD,
        MODULE_UTILS_PATH_MANAGER_FIELD,
    )

    @classmethod
    def setUpClass(cls):
        """Install stub ``jobs`` / ``utils`` / ``module_utils`` modules into ``sys.modules``."""
        # Remember what was registered before (None if absent) so it can be restored.
        cls._saved_modules = {name: sys.modules.get(name) for name in cls._STUBBED_MODULE_NAMES}

        sys.modules[JOBS_FIELD] = cls._build_jobs_stub()
        sys.modules[UTILS_FIELD] = cls._build_utils_stub()

        module_utils = types.ModuleType(MODULE_UTILS_FIELD)
        path_manager = types.ModuleType(MODULE_UTILS_PATH_MANAGER_FIELD)
        path_manager.TmpPath = Mock()
        path_manager.TmpPath.DEPLOY_INFO = "/tmp/deploy_info"
        path_manager.TmpPath.CHECK_RES_OUTPUT_JSON = "/tmp/check_res_output.json"
        # Also expose the submodule as an attribute of its parent, so that
        # ``import module_utils.path_manager`` followed by attribute access works
        # in addition to ``from module_utils.path_manager import TmpPath``.
        module_utils.path_manager = path_manager
        sys.modules[MODULE_UTILS_FIELD] = module_utils
        sys.modules[MODULE_UTILS_PATH_MANAGER_FIELD] = path_manager

    @classmethod
    def tearDownClass(cls):
        """Restore the ``sys.modules`` entries that setUpClass replaced."""
        for name, original in cls._saved_modules.items():
            if original is None:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = original
        # NOTE(review): ``ascend_deployer.start_deploy`` stays cached in
        # sys.modules, possibly bound to the stubs. Evict it here if other test
        # modules need to import it against the real dependencies.

    @staticmethod
    def _build_jobs_stub():
        """Return a stub ``jobs`` module whose ``process_*`` entry points all return 0."""
        jobs = types.ModuleType(JOBS_FIELD)
        jobs.accept_eula = Mock(return_value=True)
        jobs.get_localhost_ip = Mock(return_value="127.0.0.1")
        jobs.PrepareJob = Mock()
        # Configure the instance returned by PrepareJob() through return_value,
        # so that no spurious call is recorded on the class mock.
        jobs.PrepareJob.return_value.run = Mock()
        for name in (
            "process_hccn_check",
            "process_check",
            "process_install",
            "process_scene",
            "process_patch",
            "process_upgrade",
            "process_patch_rollback",
            "process_test",
            "process_clean",
            "process_hccn",
        ):
            setattr(jobs, name, Mock(return_value=0))
        jobs.ResourcePkg = Mock()
        resource_pkg = jobs.ResourcePkg.return_value
        resource_pkg.handle_pkgs = Mock()
        resource_pkg.start_nexus_daemon = Mock(return_value=None)
        resource_pkg.clean = Mock()
        return jobs

    @staticmethod
    def _build_utils_stub():
        """Return a stub ``utils`` module with the choice lists and config the CLI reads."""
        utils = types.ModuleType(UTILS_FIELD)
        utils.HelpFormatter = Mock()
        utils.install_items = [SYS_PKG_FIELD, NPU_FIELD, TOOLKIT_FIELD, TOOLBOX_FIELD,
                               TENSORFLOW_VERSION_FIELD, MINDSPORE_FIELD, PYTORCH_FIELD, NNAE_FIELD]
        utils.ValidChoices = "store"
        utils.stdout_callbacks = ["default", "ansible_log"]
        utils.scene_items = ["auto", "dl", MINDSPORE_FIELD, "offline_dev"]
        utils.patch_items = [NNAE_FIELD, TOOLKIT_FIELD]
        utils.test_items = ["all", "firmware", "driver"]
        utils.check_items = ["full", "fast"]
        utils.upgrade_items = [NPU_FIELD, NNAE_FIELD, TOOLKIT_FIELD]
        utils.args_with_comma = Mock(side_effect=lambda args: args)
        utils.get_hosts_name = Mock(return_value="worker")
        utils.ROOT_PATH = project_root
        # Shaped like a logging.config.dictConfig configuration.
        utils.LOGGING_CONFIG = {
            'version': 1,
            'disable_existing_loggers': False,
            HANDLER_FIELD: {
                CONSOLE_FIELD: {
                    'class': 'logging.StreamHandler',
                    LEVEL_FIELD: INFO_LEVEL,
                    'formatter': 'simple',
                },
            },
            'formatters': {
                'simple': {
                    'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                },
            },
            'loggers': {
                ASCEND_DEPLOYER_FIELD: {
                    LEVEL_FIELD: INFO_LEVEL,
                    HANDLER_FIELD: [CONSOLE_FIELD],
                },
                'install_operation': {
                    LEVEL_FIELD: INFO_LEVEL,
                    HANDLER_FIELD: [CONSOLE_FIELD],
                },
            },
        }
        return utils

    def test_cli_initialization(self):
        """The CLI can be constructed and exposes an argument parser."""
        # print_help is patched so that help text is not printed if --help is parsed.
        with patch.object(sys, "argv", [ASCEND_DEPLOYER_FIELD, "--help"]), \
                patch('argparse.ArgumentParser.print_help'):
            cli = self._new_cli()
        self.assertIsNotNone(cli)
        self.assertIsNotNone(cli.parser)

    def test_cli_process_args(self):
        """_process_args stores the install items and maps --verbose to ['-vv']."""
        cli = self._create_cli_with_args([INSTALL_FIELD, SYS_PKG_FIELD, NPU_FIELD, "--verbose"],
                                         process_args=True)
        self.assertEqual(cli.install, [SYS_PKG_FIELD, NPU_FIELD])
        self.assertEqual(cli.ansible_args, ['-vv'])

    def test_cli_check_ai_frameworks(self):
        """At most one AI framework may be installed at a time."""
        cli = self._create_cli_with_args([INSTALL_FIELD, SYS_PKG_FIELD], process_args=True)
        self.assertEqual(cli.install, [SYS_PKG_FIELD])

        cli = self._create_cli_with_args([INSTALL_FIELD, NPU_FIELD, TENSORFLOW_VERSION_FIELD],
                                         process_args=True)
        self.assertEqual(cli.install, [NPU_FIELD, TENSORFLOW_VERSION_FIELD])

        with self.assertRaises(Exception) as context:
            self._create_cli_with_args([INSTALL_FIELD, TENSORFLOW_VERSION_FIELD, MINDSPORE_FIELD],
                                       process_args=True, check_args=True)
        # Checked after the context exits, once the exception has been captured.
        self.assertIn("cannot be installed at the same time", str(context.exception))

    def test_cli_check_args(self):
        """_check_args rejects empty or conflicting arguments and accepts a valid install."""
        with self.assertRaises(Exception):
            self._create_cli_with_args([], process_args=True, check_args=True)
        with self.assertRaises(Exception):
            self._create_cli_with_args([INSTALL_FIELD, TENSORFLOW_VERSION_FIELD, MINDSPORE_FIELD],
                                       process_args=True, check_args=True)
        cli = self._create_cli_with_args([INSTALL_FIELD, SYS_PKG_FIELD], process_args=True, check_args=True)
        self.assertEqual(cli.install, [SYS_PKG_FIELD])

    def test_cli_process_env(self):
        """_process_env exports the ANSIBLE_* environment variables."""
        cli = self._new_cli()
        process_env = getattr(cli, "_process_env")
        # Snapshot os.environ and restore it on exit, so the ANSIBLE_* variables
        # set here do not leak into other tests.
        with patch.dict(os.environ):
            cli.stdout_callback = "ansible_log"
            process_env()
            self.assertEqual(os.environ.get('ANSIBLE_STDOUT_CALLBACK'), "ansible_log")

            cli.stdout_callback = None
            process_env()
            self.assertIsNotNone(os.environ.get('ANSIBLE_CONFIG'))

    def test_cli_license_agreement(self):
        """Accepting the EULA continues; rejecting it raises."""
        import jobs
        with patch.object(jobs, "accept_eula", return_value=True):
            cli = self._create_cli_with_args([INSTALL_FIELD, SYS_PKG_FIELD],
                                             process_args=True, license_agreement=True)
            self.assertEqual(cli.install, [SYS_PKG_FIELD])

        with patch.object(jobs, "accept_eula", return_value=False):
            with self.assertRaises(Exception):
                self._create_cli_with_args([INSTALL_FIELD, SYS_PKG_FIELD],
                                           process_args=True, license_agreement=True)

        cli = self._create_cli_with_args([INSTALL_FIELD, SYS_PKG_FIELD, "--check"],
                                         process_args=True, license_agreement=True)
        self.assertEqual(cli.install, [SYS_PKG_FIELD])
        self.assertTrue(cli.check)

    def test_cli_prepare_job(self):
        """_prepare_job populates envs and warns that --install=nnae is deprecated."""
        cli = self._create_cli_with_args([INSTALL_FIELD, SYS_PKG_FIELD], process_args=True, prepare_job=True)
        self.assertIsNotNone(cli.envs)
        self.assertIsInstance(cli.envs, dict)

        with patch('builtins.print') as mock_print:
            self._create_cli_with_args([INSTALL_FIELD, NNAE_FIELD], process_args=True, prepare_job=True)
        mock_print.assert_called_once_with('Warning: --install=nnae feature will be deprecated soon.')

    def test_cli_run_check(self):
        """--check (optionally combined with --hccn) is parsed and envs are prepared."""
        import jobs
        with patch.object(jobs, "process_check", return_value=0):
            cli = self._create_cli_with_args([INSTALL_FIELD, SYS_PKG_FIELD, "--check"],
                                             process_args=True, prepare_job=True)
            self.assertTrue(cli.check)
            self.assertIsNotNone(cli.envs)

        with patch.object(jobs, "process_hccn_check", return_value=0):
            cli = self._create_cli_with_args(["--hccn", "--check"], process_args=True, prepare_job=True)
            self.assertTrue(cli.hccn)
            self.assertTrue(cli.check)
            self.assertIsNotNone(cli.envs)

    def test_cli_run_handler(self):
        """An install request prepares a dict of envs."""
        import jobs
        with patch.object(jobs, "process_install", return_value=0):
            cli = self._create_cli_with_args([INSTALL_FIELD, SYS_PKG_FIELD], process_args=True, prepare_job=True)
            self.assertIsNotNone(cli.envs)
            self.assertIsInstance(cli.envs, dict)

    def test_cli_run_test(self):
        """--test all is parsed and envs are prepared."""
        import jobs
        with patch.object(jobs, "process_test", return_value=0):
            cli = self._create_cli_with_args(["--test", "all"], process_args=True, prepare_job=True)
            self.assertEqual(cli.test, ["all"])
            self.assertIsNotNone(cli.envs)

    def test_cli_run_clean(self):
        """--clean is parsed and envs are prepared."""
        import jobs
        with patch.object(jobs, "process_clean", return_value=0):
            cli = self._create_cli_with_args(["--clean"], process_args=True, prepare_job=True)
            self.assertTrue(cli.clean)
            self.assertIsNotNone(cli.envs)

    def test_cli_run_hccn(self):
        """--hccn is parsed and envs are prepared."""
        import jobs
        with patch.object(jobs, "process_hccn_check", return_value=0), \
                patch.object(jobs, "process_hccn", return_value=0):
            cli = self._create_cli_with_args(["--hccn"], process_args=True, prepare_job=True)
            self.assertTrue(cli.hccn)
            self.assertIsNotNone(cli.envs)

    def test_cli_run(self):
        """run() returns the value produced by _run_handler when all stages are stubbed."""
        with patch.object(sys, "argv", [ASCEND_DEPLOYER_FIELD, INSTALL_FIELD, SYS_PKG_FIELD]):
            from ascend_deployer.start_deploy import CLI
            cli = CLI(ASCEND_DEPLOYER_FIELD, self._CLI_DESCRIPTION)
            with patch.object(CLI, '_process_args'), \
                    patch.object(CLI, '_check_args'), \
                    patch.object(CLI, '_process_env'), \
                    patch.object(CLI, '_license_agreement'), \
                    patch.object(CLI, '_prepare_job'), \
                    patch.object(CLI, '_run_handler', return_value=0):
                # Attributes normally populated by _process_args / _prepare_job.
                cli.check = False
                cli.install = [SYS_PKG_FIELD]
                cli.scene = None
                cli.patch = None
                cli.upgrade = None
                cli.patch_rollback = None
                cli.test = None
                cli.clean = False
                cli.hccn = False
                cli.check_mode = None
                cli.envs = {}
                result = cli.run()
        self.assertEqual(result, 0)

    def test_main_function(self):
        """main() returns -1 when argument checking raises."""
        with patch.object(sys, "argv", [ASCEND_DEPLOYER_FIELD]):
            from ascend_deployer import start_deploy
            with patch('ascend_deployer.start_deploy.CLI._check_args',
                       side_effect=Exception("expected one valid argument at least")):
                result = start_deploy.main()
        self.assertEqual(result, -1)

    def _new_cli(self):
        """Import start_deploy (against the stubs) and return a fresh CLI instance."""
        from ascend_deployer import start_deploy
        return start_deploy.CLI(ASCEND_DEPLOYER_FIELD, self._CLI_DESCRIPTION)

    def _create_cli_with_args(self, args, process_args=False, check_args=False,
                              license_agreement=False, prepare_job=False):
        """Build a CLI with ``sys.argv = [prog, *args]`` and run the selected stages.

        Enabled stages run in this order: process_args, check_args,
        license_agreement, prepare_job. For each stage, a public method of that
        name is used if it exists; otherwise the underscore-prefixed method is
        used. ``sys.argv`` is restored on return and on exception.
        """
        stages = (
            (process_args, "process_args"),
            (check_args, "check_args"),
            (license_agreement, "license_agreement"),
            (prepare_job, "prepare_job"),
        )
        with patch.object(sys, "argv", [ASCEND_DEPLOYER_FIELD, *args]):
            cli = self._new_cli()
            for enabled, stage in stages:
                if enabled:
                    method = getattr(cli, stage, None) or getattr(cli, "_" + stage)
                    method()
        return cli
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()