#!/usr/bin/env python3
# coding: utf-8
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""
process_npu.py 系统测试
测试 ascend_deployer/library/process_npu.py 模块。
"""

import abc
import os
import sys
from unittest.mock import patch

# Snapshot of sys.path taken before this module adds its own entries.
# NOTE: it is only used to roll sys.path back when the test dependencies fail
# to import; on success the added entries are deliberately kept, because the
# test classes below import and patch project modules after this point
# (presumably resolved through project_root — confirm).
_original_sys_path = sys.path.copy()

# Make the project root and its test/ directory importable so that
# library_test can be found. Absolute paths keep this independent of the
# current working directory.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
test_dir = os.path.join(project_root, 'test')

# Append each directory only once to avoid duplicate sys.path entries.
if project_root not in sys.path:
    sys.path.append(project_root)
if test_dir not in sys.path:
    sys.path.append(test_dir)

# Import the shared test infrastructure; if it is missing, undo the sys.path
# changes above before failing.
try:
    from library_test.base_test import BaseLibraryTest
    from library_test.mock_manage.mock_model.mock_ansible_module import AnsibleModule
    from library_test.mock_manage.mock_handlers.mock_cmd_handler import MockCmdHandler
except ImportError as e:
    # Restore in place rather than rebinding sys.path, so code holding a
    # reference to the sys.path list also sees the rollback and the snapshot
    # list itself does not become the live sys.path.
    sys.path[:] = _original_sys_path
    raise ImportError(f"Failed to import test dependencies: {e}") from e

# Passed as the resource_dir parameter and used to build mocked glob results.
TMP_DIR = "/tmp/test"
# Component names checked against NpuInstallation.install_target/upgrade_target;
# DRIVER_FIELD is also used as an ansible run tag.
DRIVER_FIELD = "driver"
FIRM_FIELD = "firmware"


class TestBaseNpu(BaseLibraryTest, metaclass=abc.ABCMeta):
    """Shared fixture for process_npu tests.

    Once per test class it patches the OS/arch and NPU-info lookups and
    exposes ``NpuInstallation`` as ``cls.NpuInstallation``. Every patch
    started here is stopped again in ``tearDownClass`` so the mocks do not
    leak into other test classes or modules running in the same process.
    """

    TESTCASE_DIR = os.path.join(os.path.dirname(__file__), "testcase")

    @classmethod
    def get_module_path(cls):
        """Return the dotted import path of the module under test."""
        return "ascend_deployer.library.process_npu"

    @classmethod
    def get_testcase_path(cls):
        """Return the path of the npu.yml test-case file."""
        return os.path.join(cls.TESTCASE_DIR, "npu.yml")

    @classmethod
    def setUpClass(cls):
        """Start the class-wide mocks and load the NpuInstallation class.

        This runs once per test class, not before every test. unittest does
        not call tearDownClass when setUpClass raises, so any patches already
        started are stopped here before the error propagates.
        """
        super().setUpClass()

        # Fresh per-class list of started patchers, undone in tearDownClass.
        cls._patchers = []
        try:
            cls.get_os_and_arch_mocker = cls._mock_get_os_and_arch()
            cls.parse_card_mocker = cls._mock_parse_card()
            cls.get_npu_info_mocker = cls._mock_get_npu_info()

            from ascend_deployer.library.process_npu import NpuInstallation
        except BaseException:
            cls._stop_patchers()
            raise
        cls.NpuInstallation = NpuInstallation

    @classmethod
    def tearDownClass(cls):
        """Stop the patches started in setUpClass, then run the parent teardown."""
        try:
            cls._stop_patchers()
        finally:
            super().tearDownClass()

    @classmethod
    def _start_patch(cls, target):
        """Start ``patch(target)``, record its patcher on this class, and return the mock."""
        patcher = patch(target)
        mocker = patcher.start()
        # Keep the list on this exact class (not an inherited one) so each
        # test class only stops the patches it started itself.
        if "_patchers" not in vars(cls):
            cls._patchers = []
        cls._patchers.append(patcher)
        return mocker

    @classmethod
    def _stop_patchers(cls):
        """Stop every patcher recorded on this class, most recently started first."""
        patchers = vars(cls).get("_patchers", [])
        while patchers:
            patchers.pop().stop()

    @classmethod
    def _mock_get_os_and_arch(cls):
        """Patch ``get_os_and_arch`` as looked up in the module under test.

        Returns:
            The started mock, configured to return "centos7.6-aarch64".
        """
        mocker = cls._start_patch(cls.get_module_path() + ".get_os_and_arch")
        mocker.return_value = "centos7.6-aarch64"
        return mocker

    @classmethod
    def _mock_parse_card(cls):
        """Patch ``common_info.parse_card`` in its defining module.

        NOTE(review): if process_npu imports ``parse_card`` by name, its own
        reference is not affected by this patch — confirm how it is looked up.

        Returns:
            The started mock, configured to return "Atlas 900 A2".
        """
        mocker = cls._start_patch("ascend_deployer.module_utils.common_info.parse_card")
        mocker.return_value = "Atlas 900 A2"
        return mocker

    @classmethod
    def _mock_get_npu_info(cls):
        """Patch ``common_info.get_npu_info`` in its defining module.

        NOTE(review): as with parse_card, a by-name import in process_npu
        would bypass this patch — confirm.

        Returns:
            The started mock, configured to return a fixed Atlas 900 A2
            NPU-info dict.
        """
        mocker = cls._start_patch("ascend_deployer.module_utils.common_info.get_npu_info")
        mocker.return_value = {
            "card": "Atlas 900 A2",
            "model": "Atlas 900 A2",
            "scene": "910_a2",
            "product": "Atlas-900-A2",
        }
        return mocker


class TestNpuInstallation(TestBaseNpu):
    """Tests for the NpuInstallation class."""

    def test_npu_installation(self):
        """Check constructor attributes, tag-driven targets, run() and _find_files()."""
        install_params = {
            "force_upgrade_npu": False,
            "resource_dir": TMP_DIR,
            "cus_npu_info": "",
            "ansible_run_tags": ["npu"],
            "action": "install"
        }

        ansible_module = self._create_ansible_module(install_params)
        installer = self.NpuInstallation(ansible_module)

        # The constructor is expected to expose these module parameters as attributes.
        expected_attrs = (
            ("force_upgrade_npu", False),
            ("resource_dir", TMP_DIR),
            ("cus_npu_info", ""),
            ("action", "install"),
        )
        for attr_name, expected_value in expected_attrs:
            self.assertEqual(getattr(installer, attr_name), expected_value)

        # With the "npu" tag both driver and firmware must be targeted.
        for target in (installer.install_target, installer.upgrade_target):
            for component in (DRIVER_FIELD, FIRM_FIELD):
                self.assertIn(component, target)

        # _process_npu is replaced by a MagicMock, so no real NPU processing
        # runs; run() is nevertheless expected to end in FailJson.
        # NOTE(review): which check inside run() triggers FailJson is not
        # visible here — confirm.
        with patch.object(installer, '_process_npu'), \
                self.assertRaises(ansible_module.FailJson):
            installer.run()

        # With only the driver tag, firmware must be excluded from both targets.
        install_params["ansible_run_tags"] = [DRIVER_FIELD]
        ansible_module = self._create_ansible_module(install_params)
        installer = self.NpuInstallation(ansible_module)

        self.assertIn(DRIVER_FIELD, installer.install_target)
        self.assertIn(DRIVER_FIELD, installer.upgrade_target)
        self.assertNotIn(FIRM_FIELD, installer.install_target)
        self.assertNotIn(FIRM_FIELD, installer.upgrade_target)

        expected_file = f"{TMP_DIR}/test_file.run"
        with patch('ascend_deployer.library.process_npu.glob.glob',
                   return_value=[expected_file]):
            # Looked up via getattr so the protected helper is called without
            # a direct protected-member access.
            found = getattr(installer, "_find_files")(TMP_DIR, "*.run")

            self.assertEqual(found, expected_file)
            self.assertIn("try to find", installer.messages[0])
            self.assertIn("find files:", installer.messages[1])

    def _create_ansible_module(self, params):
        """Return a mocked AnsibleModule for *params* backed by an empty command handler."""
        return AnsibleModule(params, MockCmdHandler([], {}))


if __name__ == '__main__':
    # Allow running this test file directly as a script.
    from unittest import main as run_unittest_main

    run_unittest_main()