#!/usr/bin/env python3
# coding: utf-8
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""
ascend_download.py 命令行接口测试

这个测试文件专注于测试 ascend_download.py 的命令行接口,
主要测试参数解析和验证;download_dependency、URL 请求等下载相关调用
均通过 mock 替换,不会执行实际的下载。
"""

import os
import sys
import unittest
import io
from unittest.mock import patch, Mock

# Snapshot sys.path before modifying it so tearDownModule() can undo the
# entries this module adds once its tests have finished.
_original_sys_path = sys.path.copy()

# Build absolute project paths so the imports below do not depend on the
# current working directory.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
ascend_deployer_dir = os.path.join(project_root, 'ascend_deployer')

# sys.path.insert is looked up dynamically (as in the original code); the
# paths are inserted at index 0 so project modules take precedence over
# anything of the same name later on sys.path.
path_method = getattr(sys, 'path')
insert_method = getattr(path_method, 'insert')

# Only insert paths that are not already present, to avoid duplicates.
# project_root is inserted last, so it ends up ahead of ascend_deployer_dir.
if ascend_deployer_dir not in path_method:
    insert_method(0, ascend_deployer_dir)
if project_root not in path_method:
    insert_method(0, project_root)


def tearDownModule():
    """Remove the project paths this module added to ``sys.path``.

    Called once by the test runner after every test in this module has run.
    A path is removed only if it was absent from the snapshot taken before
    the insertion above, so entries that were already on ``sys.path`` are
    left untouched. Calling it more than once is harmless.
    """
    for added_path in (project_root, ascend_deployer_dir):
        if added_path not in _original_sys_path and added_path in sys.path:
            sys.path.remove(added_path)

# 导入测试模块
from ascend_download import CLI, main

# Fixture values shared by the test cases below.
OS_FIELD = "CentOS_7.6_aarch64"
OS_UBUNTU18_FIELD = "Ubuntu_18.04_aarch64"
CANN_FIELD = "CANN"
# Package name pinned to a specific version, in "<name>==<version>" form.
CANN_VERSION_FIELD = CANN_FIELD + "==8.5.0"
MINDSPORE_FIELD = "MindSpore"


class TestAscendDownloadCLI(unittest.TestCase):
    """Tests for the command-line interface of ascend_download.py.

    Each test patches ``ascend_deployer.download_util.get_os_list`` and
    ``get_pkg_list`` before constructing ``CLI``, then exercises its
    ``parser`` (or ``run``) directly; no real download is performed.
    """

    def _assert_invalid_choice(self, cli, argv, bad_value):
        """Assert that parsing ``argv`` reports ``bad_value`` as an invalid choice.

        ``ArgumentParser.error`` is patched so parsing does not exit the
        process; every message passed to it is then searched for an
        "invalid choice" error that names ``bad_value``.
        """
        with patch('argparse.ArgumentParser.error') as mock_error:
            cli.parser.parse_args(argv)

        mock_error.assert_called()
        # error() is replaced on the class by a plain MagicMock (no autospec),
        # so ``self`` is not recorded and args[0] is the message text.
        error_messages = [recorded[0][0] for recorded in mock_error.call_args_list]
        self.assertTrue(
            any(bad_value in msg and 'invalid choice' in msg.lower()
                for msg in error_messages),
            f"no 'invalid choice' error mentioning {bad_value!r}: {error_messages!r}",
        )

    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    def test_cli_initialization(self, mock_get_pkg_list, mock_get_os_list):
        """CLI passes prog, description and epilog through to its parser."""
        mock_get_os_list.return_value = [OS_FIELD]
        mock_get_pkg_list.return_value = [CANN_FIELD]

        cli = CLI("test-program", "Test description", "Test epilog")

        self.assertEqual(cli.parser.prog, "test-program")
        self.assertEqual(cli.parser.description, "Test description")
        self.assertEqual(cli.parser.epilog, "Test epilog")

    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    def test_cli_help_options(self, mock_get_pkg_list, mock_get_os_list):
        """Both -h and --help print usage text and request an exit."""
        mock_get_os_list.return_value = [OS_FIELD]
        mock_get_pkg_list.return_value = [CANN_FIELD]

        cli = CLI("test", "desc", "epilog")

        for flag in ('-h', '--help'):
            with self.subTest(flag=flag):
                # exit() is patched so the help action does not terminate the
                # test run; the help text is captured from stdout.
                with patch('sys.stdout', new_callable=io.StringIO) as mock_stdout, \
                        patch('argparse.ArgumentParser.exit') as mock_exit:
                    cli.parser.parse_args([flag])
                    mock_exit.assert_called_once()
                    self.assertIn('usage:', mock_stdout.getvalue())

    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    def test_os_list_argument_not_required_in_parser(self, mock_get_pkg_list, mock_get_os_list):
        """--os-list is optional at the parser level (CLI.run checks it instead)."""
        mock_get_os_list.return_value = [OS_FIELD]
        mock_get_pkg_list.return_value = [CANN_FIELD]

        cli = CLI("test", "desc", "epilog")

        # Omitting --os-list must not be a parse error.
        args = cli.parser.parse_args([])
        self.assertIsNone(args.os_list)

    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    def test_os_list_argument_choices_set(self, mock_get_pkg_list, mock_get_os_list):
        """--os-list accepts values from get_os_list() and rejects others."""
        expected_os_list = [OS_FIELD, OS_UBUNTU18_FIELD, 'EulerOS_2.8_aarch64']
        mock_get_os_list.return_value = expected_os_list
        mock_get_pkg_list.return_value = [CANN_FIELD]

        cli = CLI("test", "desc", "epilog")

        args = cli.parser.parse_args(['--os-list', OS_FIELD])
        self.assertEqual(args.os_list, [OS_FIELD])

        self._assert_invalid_choice(cli, ['--os-list', 'InvalidOS'], 'InvalidOS')

    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    def test_download_argument_optional(self, mock_get_pkg_list, mock_get_os_list):
        """--download is optional and defaults to an empty package list."""
        mock_get_os_list.return_value = [OS_FIELD]
        mock_get_pkg_list.return_value = [CANN_FIELD, MINDSPORE_FIELD]

        cli = CLI("test", "desc", "epilog")

        args = cli.parser.parse_args(['--os-list', OS_FIELD])
        self.assertEqual(args.pkg_list, [])

    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    def test_download_argument_choices_set(self, mock_get_pkg_list, mock_get_os_list):
        """--download accepts values from get_pkg_list() and rejects others."""
        expected_pkg_list = [CANN_FIELD, MINDSPORE_FIELD, 'TensorFlow', 'PyTorch', CANN_VERSION_FIELD]
        mock_get_os_list.return_value = [OS_FIELD]
        mock_get_pkg_list.return_value = expected_pkg_list

        cli = CLI("test", "desc", "epilog")

        args = cli.parser.parse_args(['--os-list', OS_FIELD, '--download', CANN_FIELD])
        self.assertEqual(args.pkg_list, [CANN_FIELD])

        self._assert_invalid_choice(
            cli, ['--os-list', OS_FIELD, '--download', 'InvalidPackage'], 'InvalidPackage')

    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    def test_parse_single_os_argument(self, mock_get_pkg_list, mock_get_os_list):
        """A single --os-list value is parsed into a one-element list."""
        mock_get_os_list.return_value = [OS_FIELD]
        mock_get_pkg_list.return_value = [CANN_FIELD]

        cli = CLI("test", "desc", "epilog")

        args = cli.parser.parse_args(['--os-list', OS_FIELD])

        self.assertEqual(args.os_list, [OS_FIELD])
        self.assertEqual(args.pkg_list, [])

    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    def test_parse_multiple_os_arguments(self, mock_get_pkg_list, mock_get_os_list):
        """Several --os-list values are all collected."""
        mock_get_os_list.return_value = [OS_FIELD, OS_UBUNTU18_FIELD]
        mock_get_pkg_list.return_value = [CANN_FIELD]

        cli = CLI("test", "desc", "epilog")

        args = cli.parser.parse_args(['--os-list', OS_FIELD, OS_UBUNTU18_FIELD])

        # Order is deliberately not asserted: CLI's argument definition is not
        # visible here and may normalize the list.
        self.assertCountEqual(args.os_list, [OS_FIELD, OS_UBUNTU18_FIELD])
        self.assertEqual(args.pkg_list, [])

    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    def test_parse_os_and_package_arguments(self, mock_get_pkg_list, mock_get_os_list):
        """--os-list and --download can be combined in one command line."""
        mock_get_os_list.return_value = [OS_FIELD]
        mock_get_pkg_list.return_value = [CANN_FIELD, MINDSPORE_FIELD]

        cli = CLI("test", "desc", "epilog")

        args = cli.parser.parse_args(
            ['--os-list', OS_FIELD, '--download', CANN_FIELD, MINDSPORE_FIELD])

        self.assertEqual(args.os_list, [OS_FIELD])
        # Package order is not asserted, only the exact set of packages.
        self.assertCountEqual(args.pkg_list, [CANN_FIELD, MINDSPORE_FIELD])

    @patch('ascend_download.get_download_path')
    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    @patch('ascend_download.download_dependency')
    def test_cli_run_method(self, mock_download_dependency, mock_get_pkg_list,
                            mock_get_os_list, mock_get_download_path):
        """CLI.run forwards the parsed arguments to download_dependency."""
        mock_get_os_list.return_value = [OS_FIELD]
        mock_get_pkg_list.return_value = [CANN_FIELD]
        mock_get_download_path.return_value = '/tmp/test'
        mock_download_dependency.return_value = 0

        cli = CLI("test", "desc", "epilog")

        cli.run(['--os-list', OS_FIELD], check=False)

        # run() returns nothing, so only the downstream call is verified.
        mock_download_dependency.assert_called_once_with(
            [OS_FIELD],
            [],
            '/tmp/test',
            False
        )

    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    def test_parse_package_with_version(self, mock_get_pkg_list, mock_get_os_list):
        """A version-pinned package ("<name>==<version>") is accepted by --download."""
        mock_get_os_list.return_value = [OS_FIELD]
        # Use the constant's value ("CANN==8.5.0"); the literal 'CANN_FIELD==8.5.0'
        # never matched the value parsed below.
        mock_get_pkg_list.return_value = [
            CANN_FIELD, CANN_VERSION_FIELD, f'{MINDSPORE_FIELD}==2.7.2']

        cli = CLI("test", "desc", "epilog")

        args = cli.parser.parse_args(['--os-list', OS_FIELD, '--download', CANN_VERSION_FIELD])

        self.assertEqual(args.os_list, [OS_FIELD])
        self.assertEqual(args.pkg_list, [CANN_VERSION_FIELD])


class TestAscendDownloadErrorHandling(unittest.TestCase):
    """Error-handling tests for the ascend_download.py argument parser."""

    def _assert_invalid_choice(self, cli, argv, bad_value):
        """Assert that parsing ``argv`` reports ``bad_value`` as an invalid choice.

        ``parse_args`` would normally call ``ArgumentParser.error``, which
        exits the process. It is patched here so the test keeps running, and
        the messages passed to it are searched for an "invalid choice" error
        naming ``bad_value``.
        """
        with patch('argparse.ArgumentParser.error') as mock_error:
            cli.parser.parse_args(argv)

        mock_error.assert_called()
        # error() is replaced on the class by a plain MagicMock (no autospec),
        # so ``self`` is not recorded and args[0] is the message text.
        error_messages = [recorded[0][0] for recorded in mock_error.call_args_list]
        self.assertTrue(
            any(bad_value in msg and 'invalid choice' in msg.lower()
                for msg in error_messages),
            f"no 'invalid choice' error mentioning {bad_value!r}: {error_messages!r}",
        )

    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    def test_invalid_os_argument_error(self, mock_get_pkg_list, mock_get_os_list):
        """An operating system not returned by get_os_list() is rejected."""
        mock_get_os_list.return_value = [OS_FIELD]
        mock_get_pkg_list.return_value = [CANN_FIELD]

        cli = CLI("test", "desc", "epilog")

        self._assert_invalid_choice(cli, ['--os-list', 'InvalidOS'], 'InvalidOS')

    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    def test_invalid_package_argument_error(self, mock_get_pkg_list, mock_get_os_list):
        """A package not returned by get_pkg_list() is rejected."""
        mock_get_os_list.return_value = [OS_FIELD]
        mock_get_pkg_list.return_value = [CANN_FIELD]

        cli = CLI("test", "desc", "epilog")

        self._assert_invalid_choice(
            cli, ['--os-list', OS_FIELD, '--download', 'InvalidPackage'], 'InvalidPackage')


class TestAscendDownloadURLMocking(unittest.TestCase):
    """Checks that URL requests made by the download utilities can be mocked."""

    @staticmethod
    def _content_length_of(url):
        """Import get_remote_content_length at call time and apply it to ``url``.

        The import happens inside the test, i.e. while the urllib patches are
        active.
        """
        from ascend_deployer.downloader.download_util import get_remote_content_length
        return get_remote_content_length(url)

    @patch('urllib.request.Request')
    @patch('urllib.request.urlopen')
    def test_url_request_mocking(self, mock_urlopen, mock_request):
        """The length reported by a mocked response is returned as an int."""
        fake_response = Mock()
        fake_response.getheader.return_value = '"123456"'
        fake_response.read.return_value = b'test data'
        # The response is handed out through urlopen's context-manager protocol.
        mock_urlopen.return_value.__enter__.return_value = fake_response

        length = self._content_length_of('http://example.com')

        self.assertEqual(length, 123456)
        mock_request.assert_called_once()
        mock_urlopen.assert_called_once()

    @patch('urllib.request.Request')
    @patch('urllib.request.urlopen')
    def test_url_request_exception_handling(self, mock_urlopen, mock_request):
        """A failing urlopen call makes the helper return None instead of raising."""
        mock_urlopen.side_effect = Exception('Connection failed')

        self.assertIsNone(self._content_length_of('http://example.com'))


class TestAscendDownloadMainFunction(unittest.TestCase):
    """Tests for the module-level ``main()`` entry point of ascend_download.py."""

    _DOWNLOAD_PATH = '/tmp/test'

    def _prime(self, mock_download_dependency, mock_get_pkg_list,
               mock_get_os_list, mock_get_download_path, packages):
        """Set the return values shared by every test in this class."""
        mock_get_os_list.return_value = [OS_FIELD]
        mock_get_pkg_list.return_value = packages
        mock_get_download_path.return_value = self._DOWNLOAD_PATH
        mock_download_dependency.return_value = 0

    @patch('ascend_download.get_download_path')
    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    @patch('ascend_download.download_dependency')
    def test_main_function(self, mock_download_dependency, mock_get_pkg_list,
                           mock_get_os_list, mock_get_download_path):
        """main(check=False) forwards the OS list and the False check flag."""
        self._prime(mock_download_dependency, mock_get_pkg_list,
                    mock_get_os_list, mock_get_download_path, [CANN_FIELD])

        main(args=['--os-list', OS_FIELD], check=False)

        # main() returns nothing, so the downstream call is what gets verified.
        mock_download_dependency.assert_called_once_with(
            [OS_FIELD], [], self._DOWNLOAD_PATH, False)

    @patch('ascend_download.get_download_path')
    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    @patch('ascend_download.download_dependency')
    def test_main_function_default_check(self, mock_download_dependency, mock_get_pkg_list,
                                         mock_get_os_list, mock_get_download_path):
        """Without an explicit ``check`` argument, main() passes True downstream."""
        self._prime(mock_download_dependency, mock_get_pkg_list,
                    mock_get_os_list, mock_get_download_path, [CANN_FIELD])

        main(args=['--os-list', OS_FIELD])

        mock_download_dependency.assert_called_once_with(
            [OS_FIELD], [], self._DOWNLOAD_PATH, True)

    @patch('ascend_download.get_download_path')
    @patch('ascend_deployer.download_util.get_os_list')
    @patch('ascend_deployer.download_util.get_pkg_list')
    @patch('ascend_download.download_dependency')
    def test_main_function_with_packages(self, mock_download_dependency, mock_get_pkg_list,
                                         mock_get_os_list, mock_get_download_path):
        """Packages given via --download reach download_dependency."""
        self._prime(mock_download_dependency, mock_get_pkg_list,
                    mock_get_os_list, mock_get_download_path,
                    [MINDSPORE_FIELD, CANN_FIELD])

        main(args=['--os-list', OS_FIELD, '--download', CANN_FIELD, MINDSPORE_FIELD],
             check=False)

        mock_download_dependency.assert_called_once()
        positional = mock_download_dependency.call_args[0]

        self.assertEqual(positional[0], [OS_FIELD])
        # Package order is not asserted, only the exact set of packages.
        self.assertCountEqual(positional[1], [CANN_FIELD, MINDSPORE_FIELD])
        self.assertEqual(positional[2], self._DOWNLOAD_PATH)
        self.assertFalse(positional[3])


if __name__ == "__main__":
    # Run every TestCase defined in this module when executed as a script.
    unittest.main()