"""
ascend_download.py 命令行接口测试
这个测试文件专注于测试 ascend_download.py 的命令行接口,
只测试参数解析和验证,不测试实际的下载功能。
"""
import os
import sys
import unittest
import io
from unittest.mock import patch, Mock
# Snapshot of the interpreter search path taken before this module changes it;
# tearDownModule() below puts it back once this module's tests have finished.
_original_sys_path = sys.path.copy()
# Repository root (two directories above this file) and its ascend_deployer directory.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
ascend_deployer_dir = os.path.join(project_root, 'ascend_deployer')
# path_method is the sys.path list object itself and insert_method is its bound
# insert(); both are obtained through getattr rather than attribute syntax.
path_method = getattr(sys, 'path')
insert_method = getattr(path_method, 'insert')
# Prepend both directories (project_root ends up first) so the imports in this
# file can be resolved; entries already present are not duplicated.
if ascend_deployer_dir not in path_method:
    insert_method(0, ascend_deployer_dir)
if project_root not in path_method:
    insert_method(0, project_root)
from ascend_download import CLI, main  # noqa: E402  (requires the sys.path setup above)


def tearDownModule():
    """Restore ``sys.path`` to the snapshot taken before this module modified it.

    Without this, the directories inserted at import time stay on ``sys.path``
    for every test module that runs afterwards in the same process.
    """
    # Accessed through getattr, consistent with the setup code above; the slice
    # assignment updates the list in place so existing references see the change.
    current_path = getattr(sys, 'path')
    current_path[:] = _original_sys_path
OS_FIELD = "CentOS_7.6_aarch64"
OS_UBUNTU18_FIELD = "Ubuntu_18.04_aarch64"
CANN_FIELD = "CANN"
CANN_VERSION_FIELD = f"{CANN_FIELD}==8.5.0"
MINDSPORE_FIELD = "MindSpore"
class TestAscendDownloadCLI(unittest.TestCase):
    """Tests for the command-line interface of ascend_download.py (CLI class).

    NOTE(review): the OS/package providers are patched at
    ``ascend_deployer.download_util``, while other tests in this file import from
    ``ascend_deployer.downloader.download_util``. Confirm that ``CLI`` really looks
    the providers up through the patched module; otherwise the mocked choice
    lists below do not take effect.
    """

    def _assert_invalid_choice_reported(self, mock_error, bad_value):
        """Assert that the patched ``ArgumentParser.error`` reported ``bad_value`` as an invalid choice.

        Because ``error`` is patched, it returns instead of exiting, and it may be
        called more than once, so any recorded message is accepted.
        """
        mock_error.assert_called()
        error_messages = [recorded[0][0] for recorded in mock_error.call_args_list]
        self.assertTrue(
            any(bad_value in msg and 'invalid choice' in msg.lower() for msg in error_messages)
        )

    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    def test_cli_initialization(self, mock_get_pkg_list, mock_get_os_list):
        """CLI passes prog, description and epilog through to its parser."""
        mock_get_os_list.return_value = [OS_FIELD]
        mock_get_pkg_list.return_value = [CANN_FIELD]
        cli = CLI("test-program", "Test description", "Test epilog")
        self.assertEqual(cli.parser.prog, "test-program")
        self.assertEqual(cli.parser.description, "Test description")
        self.assertEqual(cli.parser.epilog, "Test epilog")

    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    def test_cli_help_options(self, mock_get_pkg_list, mock_get_os_list):
        """Both -h and --help print the usage text and ask the parser to exit."""
        mock_get_os_list.return_value = [OS_FIELD]
        mock_get_pkg_list.return_value = [CANN_FIELD]
        cli = CLI("test", "desc", "epilog")
        for flag in ('-h', '--help'):
            with self.subTest(flag=flag):
                with patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
                    # exit() is patched so that printing help does not raise SystemExit.
                    with patch('argparse.ArgumentParser.exit') as mock_exit:
                        cli.parser.parse_args([flag])
                    mock_exit.assert_called_once()
                    self.assertIn('usage:', mock_stdout.getvalue())

    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    def test_os_list_argument_not_required_in_parser(self, mock_get_pkg_list, mock_get_os_list):
        """--os-list is not required by the parser itself (the check happens in run())."""
        mock_get_os_list.return_value = [OS_FIELD]
        mock_get_pkg_list.return_value = [CANN_FIELD]
        cli = CLI("test", "desc", "epilog")
        args = cli.parser.parse_args([])
        self.assertIsNone(args.os_list)

    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    def test_os_list_argument_choices_set(self, mock_get_pkg_list, mock_get_os_list):
        """--os-list accepts a listed OS and rejects an unknown one."""
        mock_get_os_list.return_value = [OS_FIELD, OS_UBUNTU18_FIELD, 'EulerOS_2.8_aarch64']
        mock_get_pkg_list.return_value = [CANN_FIELD]
        cli = CLI("test", "desc", "epilog")
        args = cli.parser.parse_args(['--os-list', OS_FIELD])
        self.assertEqual(args.os_list, [OS_FIELD])
        with patch('argparse.ArgumentParser.error') as mock_error:
            cli.parser.parse_args(['--os-list', 'InvalidOS'])
            self._assert_invalid_choice_reported(mock_error, 'InvalidOS')

    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    def test_download_argument_optional(self, mock_get_pkg_list, mock_get_os_list):
        """Omitting --download yields an empty package list."""
        mock_get_os_list.return_value = [OS_FIELD]
        mock_get_pkg_list.return_value = [CANN_FIELD, MINDSPORE_FIELD]
        cli = CLI("test", "desc", "epilog")
        args = cli.parser.parse_args(['--os-list', OS_FIELD])
        self.assertEqual(args.pkg_list, [])

    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    def test_download_argument_choices_set(self, mock_get_pkg_list, mock_get_os_list):
        """--download accepts a listed package and rejects an unknown one."""
        mock_get_os_list.return_value = [OS_FIELD]
        mock_get_pkg_list.return_value = [
            CANN_FIELD, MINDSPORE_FIELD, 'TensorFlow', 'PyTorch', CANN_VERSION_FIELD,
        ]
        cli = CLI("test", "desc", "epilog")
        args = cli.parser.parse_args(['--os-list', OS_FIELD, '--download', CANN_FIELD])
        self.assertEqual(args.pkg_list, [CANN_FIELD])
        with patch('argparse.ArgumentParser.error') as mock_error:
            cli.parser.parse_args(['--os-list', OS_FIELD, '--download', 'InvalidPackage'])
            self._assert_invalid_choice_reported(mock_error, 'InvalidPackage')

    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    def test_parse_single_os_argument(self, mock_get_pkg_list, mock_get_os_list):
        """A single --os-list value is parsed into a one-element list."""
        mock_get_os_list.return_value = [OS_FIELD]
        mock_get_pkg_list.return_value = [CANN_FIELD]
        cli = CLI("test", "desc", "epilog")
        args = cli.parser.parse_args(['--os-list', OS_FIELD])
        self.assertEqual(args.os_list, [OS_FIELD])
        self.assertEqual(args.pkg_list, [])

    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    def test_parse_multiple_os_arguments(self, mock_get_pkg_list, mock_get_os_list):
        """Several --os-list values are all collected."""
        mock_get_os_list.return_value = [OS_FIELD, OS_UBUNTU18_FIELD]
        mock_get_pkg_list.return_value = [CANN_FIELD]
        cli = CLI("test", "desc", "epilog")
        args = cli.parser.parse_args(['--os-list', OS_FIELD, OS_UBUNTU18_FIELD])
        self.assertCountEqual(args.os_list, [OS_FIELD, OS_UBUNTU18_FIELD])
        self.assertEqual(args.pkg_list, [])

    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    def test_parse_os_and_package_arguments(self, mock_get_pkg_list, mock_get_os_list):
        """--os-list and a multi-value --download are parsed together."""
        mock_get_os_list.return_value = [OS_FIELD]
        mock_get_pkg_list.return_value = [CANN_FIELD, MINDSPORE_FIELD]
        cli = CLI("test", "desc", "epilog")
        args = cli.parser.parse_args(
            ['--os-list', OS_FIELD, '--download', CANN_FIELD, MINDSPORE_FIELD]
        )
        self.assertEqual(args.os_list, [OS_FIELD])
        self.assertCountEqual(args.pkg_list, [CANN_FIELD, MINDSPORE_FIELD])

    @patch('ascend_download.get_download_path')
    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    @patch('ascend_download.download_dependency')
    def test_cli_run_method(self, mock_download_dependency, mock_get_pkg_list,
                            mock_get_os_list, mock_get_download_path):
        """run() forwards the parsed arguments, download path and check flag to download_dependency."""
        mock_get_os_list.return_value = [OS_FIELD]
        mock_get_pkg_list.return_value = [CANN_FIELD]
        mock_get_download_path.return_value = '/tmp/test'
        mock_download_dependency.return_value = 0
        cli = CLI("test", "desc", "epilog")
        cli.run(['--os-list', OS_FIELD], check=False)
        mock_download_dependency.assert_called_once_with([OS_FIELD], [], '/tmp/test', False)

    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    def test_parse_package_with_version(self, mock_get_pkg_list, mock_get_os_list):
        """A version-pinned "name==version" package is parsed verbatim."""
        mock_get_os_list.return_value = [OS_FIELD]
        # Offer the pinned package under its real name ("CANN==8.5.0") so that the
        # value requested below is actually among the mocked choices.
        mock_get_pkg_list.return_value = [
            CANN_FIELD, CANN_VERSION_FIELD, f"{MINDSPORE_FIELD}==2.7.2",
        ]
        cli = CLI("test", "desc", "epilog")
        args = cli.parser.parse_args(['--os-list', OS_FIELD, '--download', CANN_VERSION_FIELD])
        self.assertEqual(args.os_list, [OS_FIELD])
        self.assertEqual(args.pkg_list, [CANN_VERSION_FIELD])
class TestAscendDownloadErrorHandling(unittest.TestCase):
    """Argument-validation error paths of the ascend_download.py CLI."""

    def _check_invalid_choice(self, argv, bad_token):
        """Build a CLI, parse ``argv`` with ``ArgumentParser.error`` patched, and
        require that one reported message flags ``bad_token`` as an invalid choice."""
        parser_owner = CLI("test", "desc", "epilog")
        with patch('argparse.ArgumentParser.error') as error_mock:
            parser_owner.parser.parse_args(argv)
            error_mock.assert_called()
            reported = [entry[0][0] for entry in error_mock.call_args_list]
            flagged = any(
                bad_token in text and 'invalid choice' in text.lower()
                for text in reported
            )
            self.assertTrue(flagged)

    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    def test_invalid_os_argument_error(self, mock_get_pkg_list, mock_get_os_list):
        """An --os-list value outside the offered OS list is rejected."""
        mock_get_os_list.return_value = [OS_FIELD]
        mock_get_pkg_list.return_value = [CANN_FIELD]
        self._check_invalid_choice(['--os-list', 'InvalidOS'], 'InvalidOS')

    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    def test_invalid_package_argument_error(self, mock_get_pkg_list, mock_get_os_list):
        """A --download value outside the offered package list is rejected."""
        mock_get_os_list.return_value = [OS_FIELD]
        mock_get_pkg_list.return_value = [CANN_FIELD]
        self._check_invalid_choice(
            ['--os-list', OS_FIELD, '--download', 'InvalidPackage'], 'InvalidPackage'
        )
class TestAscendDownloadURLMocking(unittest.TestCase):
    """get_remote_content_length exercised with urllib.request patched out."""

    @patch('urllib.request.Request')
    @patch('urllib.request.urlopen')
    def test_url_request_mocking(self, mock_urlopen, mock_request):
        """A quoted numeric header value from the mocked response comes back as an int."""
        fake_response = Mock()
        fake_response.getheader.return_value = '"123456"'
        fake_response.read.return_value = b'test data'
        # urlopen is used as a context manager; its __enter__ yields the fake response.
        mock_urlopen.return_value.__enter__.return_value = fake_response
        from ascend_deployer.downloader.download_util import get_remote_content_length
        length = get_remote_content_length('http://example.com')
        self.assertEqual(length, 123456)
        mock_request.assert_called_once()
        mock_urlopen.assert_called_once()

    @patch('urllib.request.Request')
    @patch('urllib.request.urlopen')
    def test_url_request_exception_handling(self, mock_urlopen, mock_request):
        """A failing urlopen makes get_remote_content_length return None."""
        mock_urlopen.side_effect = Exception('Connection failed')
        from ascend_deployer.downloader.download_util import get_remote_content_length
        self.assertIsNone(get_remote_content_length('http://example.com'))
class TestAscendDownloadMainFunction(unittest.TestCase):
    """main() wiring: parsed CLI arguments are handed to download_dependency."""

    @staticmethod
    def _arrange(download_mock, pkg_mock, os_mock, path_mock, packages):
        """Give every patched collaborator the canned return value the tests expect."""
        os_mock.return_value = [OS_FIELD]
        pkg_mock.return_value = packages
        path_mock.return_value = '/tmp/test'
        download_mock.return_value = 0

    @patch('ascend_download.get_download_path')
    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    @patch('ascend_download.download_dependency')
    def test_main_function(self, mock_download_dependency, mock_get_pkg_list,
                           mock_get_os_list, mock_get_download_path):
        """main(check=False) passes the OS list, no packages, the path and check=False."""
        self._arrange(mock_download_dependency, mock_get_pkg_list,
                      mock_get_os_list, mock_get_download_path, [CANN_FIELD])
        main(args=['--os-list', OS_FIELD], check=False)
        mock_download_dependency.assert_called_once_with([OS_FIELD], [], '/tmp/test', False)

    @patch('ascend_download.get_download_path')
    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    @patch('ascend_download.download_dependency')
    def test_main_function_default_check(self, mock_download_dependency, mock_get_pkg_list,
                                         mock_get_os_list, mock_get_download_path):
        """Without an explicit check argument, main() forwards check=True."""
        self._arrange(mock_download_dependency, mock_get_pkg_list,
                      mock_get_os_list, mock_get_download_path, [CANN_FIELD])
        main(args=['--os-list', OS_FIELD])
        mock_download_dependency.assert_called_once_with([OS_FIELD], [], '/tmp/test', True)

    @patch('ascend_download.get_download_path')
    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    @patch('ascend_download.download_dependency')
    def test_main_function_with_packages(self, mock_download_dependency, mock_get_pkg_list,
                                         mock_get_os_list, mock_get_download_path):
        """Packages given via --download reach download_dependency (order not asserted)."""
        self._arrange(mock_download_dependency, mock_get_pkg_list,
                      mock_get_os_list, mock_get_download_path,
                      [MINDSPORE_FIELD, CANN_FIELD])
        main(args=['--os-list', OS_FIELD, '--download', CANN_FIELD, MINDSPORE_FIELD],
             check=False)
        mock_download_dependency.assert_called_once()
        positional = mock_download_dependency.call_args[0]
        self.assertEqual(positional[0], [OS_FIELD])
        self.assertCountEqual(positional[1], [CANN_FIELD, MINDSPORE_FIELD])
        self.assertEqual(positional[2], '/tmp/test')
        self.assertFalse(positional[3])


if __name__ == '__main__':
    unittest.main()