from unittest import mock

from unittest.mock import patch, MagicMock

from library_test.base_test import BaseTest





class AnsibleModuleMocker:
    """Minimal stand-in for an Ansible module object.

    Only exposes the ``params`` attribute; the dictionary passed in is stored
    as-is (no copy is made).
    """

    def __init__(self, params: dict):
        # Keep a reference to the caller's dictionary rather than a copy.
        self.params = params





class BaseTestCheckUtil(BaseTest):
    """Shared fixture for tests of the ``check_library_utils`` package.

    One ``CANNCheck`` instance is created per test class in ``setUpClass`` and
    published as ``cls.cann_check``.  Because that object is shared by every
    test method, ``setUp`` puts it back into a known state before each test so
    that results do not depend on the order in which tests run.
    """

    @classmethod
    def get_module_path(cls):
        """Return the dotted path of the package under test."""
        return "ascend_deployer.module_utils.check_library_utils"

    @classmethod
    def setUpClass(cls) -> None:
        """Create the class-wide ``CANNCheck`` fixture with empty parameters."""
        super().setUpClass()
        # Imported here, after the parent set-up has run, instead of at module
        # level.  NOTE(review): presumably BaseTest.setUpClass prepares the
        # import using get_module_path() - confirm against BaseTest.
        from ascend_deployer.module_utils.check_library_utils.cann_checks import CANNCheck

        param = {"tags": "", "ascend_deployer_work_dir": "", "python_version": "", "packages": ""}
        cls.cann_check = CANNCheck(AnsibleModuleMocker(param), {}, [])

    def setUp(self) -> None:
        """Isolate each test from state left on the shared fixture.

        ``error_messages`` is emptied before the test starts, and the
        attributes that tests overwrite on the fixture are restored to their
        pre-test values once the test has finished.
        """
        super().setUp()
        fixture = self.cann_check
        fixture.error_messages = []
        for name in ("npu_info", "packages", "tags"):
            # Only attributes that exist before the test can be restored.
            if hasattr(fixture, name):
                self.addCleanup(setattr, fixture, name, getattr(fixture, name))





class TestCheckUtils(BaseTestCheckUtil):
    """Scenario tests for the ``CANNCheck`` fixture supplied by the base class.

    Each test drives ``self.cann_check`` through several mocked environments
    in sequence and compares ``error_messages`` with the expected list after
    every call.
    """

    @patch('os.stat')
    @patch('os.path.isdir')
    def test_check_cann_install_path_permission(self, mock_isdir, mock_stat):
        """Run ``check_cann_install_path_permission`` against mocked ``os.path.isdir``/``os.stat``."""
        checker = self.cann_check
        owner_error = (
            "[ASCEND][ERROR] The owner of the cann installation dir "
            "'/usr/local/Ascend' must be root, change the owner to root"
        )
        permission_error = (
            "[ASCEND][ERROR] When installing cann, the user and group of the installation path "
            "must be root, and the permission must be 755. "
        )

        # isdir reports False: no message is expected.
        mock_isdir.return_value = False
        checker.check_cann_install_path_permission()
        self.assertEqual([], checker.error_messages)

        # isdir reports True and stat reports uid 1: owner message expected.
        mock_isdir.return_value = True
        mock_stat.return_value = MagicMock(st_uid=1)
        checker.check_cann_install_path_permission()
        self.assertEqual([owner_error], checker.error_messages)

        # uid 0 with st_mode 0o40747 (== 16871): permission message expected.
        checker.error_messages = []
        mock_stat.return_value = MagicMock(st_uid=0, st_mode=0o40747)
        checker.check_cann_install_path_permission()
        self.assertEqual([permission_error], checker.error_messages)

        # uid 0 with st_mode 0o40755 (== 16877): no message expected.
        checker.error_messages = []
        mock_stat.return_value = MagicMock(st_uid=0, st_mode=0o40755)
        checker.check_cann_install_path_permission()
        self.assertEqual([], checker.error_messages)

    @patch(
        'builtins.open',
        new_callable=mock.mock_open,
        read_data="Driver_Install_Path_Param=/usr/local/Ascend",
    )
    @patch('os.path.isfile')
    def test_check_driver_installation(self, mock_isfile, mock_open):
        """Run ``check_driver_installation`` with ``open`` and ``os.path.isfile`` mocked.

        The mocked ``open`` yields the single line
        ``Driver_Install_Path_Param=/usr/local/Ascend``.
        """
        checker = self.cann_check
        install_info = "/etc/ascend_install.info"
        version_info = "/usr/local/Ascend/driver/version.info"
        missing_version_error = (
            "[ASCEND][ERROR] The /etc/ascend_install.info file exists in the environment, "
            "and the file records the driver installation path. However, "
            "the driver/version.info does not exist in the installation path. "
            "Please check the driver is correctly installed."
        )

        # isfile is False for every path: no message expected.
        mock_isfile.return_value = False
        checker.check_driver_installation()
        self.assertEqual([], checker.error_messages)

        # Install info present, version.info absent; unknown paths are absent.
        only_install_info = {install_info: True, version_info: False}
        mock_isfile.side_effect = lambda path: only_install_info.get(path, False)
        checker.check_driver_installation()
        self.assertEqual([missing_version_error], checker.error_messages)

        # Both files present; unknown paths are reported as present too.
        checker.error_messages = []
        both_present = {install_info: True, version_info: True}
        mock_isfile.side_effect = lambda path: both_present.get(path, True)
        checker.check_driver_installation()
        self.assertEqual([], checker.error_messages)

    @patch('glob.glob')
    def test_check_kernels(self, mock_glob):
        """Run ``check_kernels`` for several scene/package/tag combinations with ``glob.glob`` mocked."""
        checker = self.cann_check
        no_package_error = "[ASCEND][ERROR] Do not find kernels package, please download kernels package first."

        # Scene "infer": rejected outright.
        checker.npu_info = {"scene": "infer"}
        checker.check_kernels()
        self.assertEqual(["[ASCEND][ERROR] kernels not support infer scene"], checker.error_messages)

        # Scene "train" with no packages at all.
        checker.error_messages = []
        checker.npu_info = {"scene": "train"}
        checker.packages = {}
        checker.check_kernels()
        self.assertEqual([no_package_error], checker.error_messages)

        # Scene "train" with both a kernels and an nnae entry, tag "nnae".
        checker.error_messages = []
        checker.npu_info = {"scene": "train"}
        checker.packages = {"kernels": "path_to_kernels", "nnae": "path_to_nnae"}
        checker.tags = ["nnae"]
        checker.check_kernels()
        self.assertEqual([], checker.error_messages)

        # Only a kernels entry, with the value "kernels path err".
        checker.error_messages = []
        checker.npu_info = {"scene": "train"}
        checker.packages = {"kernels": "kernels path err"}
        checker.tags = ["nnae"]
        checker.check_kernels()
        self.assertEqual([no_package_error], checker.error_messages)

        # Kernels archive named with version 8.0.0 while glob finds nothing.
        checker.error_messages = []
        checker.npu_info = {"scene": "train"}
        checker.packages = {"kernels": "Atlas-A3-cann-kernels_8.0.0_linux-aarch64.zip"}
        checker.tags = ["nnae"]
        mock_glob.return_value = []
        checker.check_kernels()
        self.assertEqual(
            ['[ASCEND][ERROR] Please install toolkit, nnae or nnrt version 8.0.0 before '
             'installing kernels 8.0.0'],
            checker.error_messages,
        )