import os
import sys
import tempfile
import shutil
from contextlib import contextmanager
from unittest.mock import patch, MagicMock

from library_test.base_test import BaseTest


class TestLabelNode(BaseTest):
    """Shared fixtures and helpers for the ``LabelNode`` test classes.

    The class defines no tests itself; subclasses use its helpers to build
    a ``LabelNode`` without running ``__init__`` and to observe what
    ``get_labels`` passes to ``module.exit_json``.
    """

    # Device-plugin package versions used to build mock package file names.
    _VERSION_NEW = '26.1.0'
    _VERSION_HIGHER = '26.2.0'
    _VERSION_LOWER = '26.0.0'
    _VERSION_MUCH_LOWER = '25.0.0'
    _VERSION_RC = '26.1.RC1'
    _VERSION_OLD = '24.1.0'
    _VERSION_FUTURE = '27.0.0'
    _VERSION_VERY_OLD = '23.0.0'

    class _ExitJsonCalled(Exception):
        """Sentinel raised by the mocked ``exit_json`` to stop ``get_labels``."""

    @classmethod
    def setUpClass(cls):
        """Make the workspace importable and load ``LabelNode`` once per class.

        The directory three levels above this file is appended to
        ``sys.path`` (if absent) before the module under test is imported.
        """
        workspace_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..'))
        if workspace_path not in sys.path:
            sys.path.append(workspace_path)
        super().setUpClass()
        # Provided by BaseTest; presumably stubs Linux-only modules so the
        # import below works on any platform -- TODO confirm.
        cls.replace_linux_module()
        from ascend_deployer.library.label_node import LabelNode
        cls.LabelNode = LabelNode

    @contextmanager
    def _create_mock_package_dir(self, version_str):
        """Yield a temporary resources dir holding one device-plugin package.

        The package is a small text file named
        ``ascend-k8s-device-plugin-<version_str>.zip`` placed under
        ``mindxdl/dlPackage/aarch64``. The whole tree is removed on exit.
        """
        temp_dir = tempfile.mkdtemp()
        try:
            package_dir = os.path.join(temp_dir, 'mindxdl', 'dlPackage', 'aarch64')
            os.makedirs(package_dir)
            package_path = os.path.join(package_dir,
                                        'ascend-k8s-device-plugin-{}.zip'.format(version_str))
            with open(package_path, 'w') as f:
                f.write("mock package")
            yield temp_dir
        finally:
            shutil.rmtree(temp_dir)

    def _init_label_node(self, **kwargs):
        """Build a ``LabelNode`` with mocked module state, bypassing ``__init__``.

        Keyword arguments override the default module parameters below.
        ``resources_dir`` is also accepted and defaults to ``/tmp/test``.
        """
        default_params = {
            'step': 'get_label',
            'ansible_run_tags': [],
            'node_name': 'test-node',
            'master_node': False,
            'worker_node': True,
            'nodes_label': {},
            'group_count': 1,
            'noded_label': 'off',
        }
        default_params.update(kwargs)

        # __new__ skips LabelNode.__init__, so every attribute is set by hand.
        node = self.LabelNode.__new__(self.LabelNode)
        node.module = MagicMock()
        node.module.params = default_params
        node.step = default_params['step']
        node.tags = default_params['ansible_run_tags']
        node.node_name = default_params['node_name']
        node.master_node = default_params['master_node']
        node.worker_node = default_params['worker_node']
        node.nodes_label = default_params['nodes_label']
        node.sub_groups = default_params['group_count']
        node.noded_label = default_params['noded_label']
        node.resources_dir = default_params.get('resources_dir', '/tmp/test')
        node.arch = 'aarch64'
        node.package_dir = os.path.join(node.resources_dir, 'mindxdl', 'dlPackage', node.arch)
        node.facts = {}
        node.label_yaml_dir = '/tmp/test/label'
        return node

    def _capture_exit_json_result(self, node):
        """Run ``node.get_labels()`` and return the ``ansible_facts`` it reported.

        ``module.exit_json`` is replaced by a side effect that records its
        keyword arguments and raises a sentinel, and ``get_npu_info`` is
        patched to return an empty dict. Returns ``{}`` when ``exit_json``
        was not called or was called without ``ansible_facts``.
        """
        captured_facts = {}

        def exit_json_side_effect(**kwargs):
            captured_facts.update(kwargs)
            raise self._ExitJsonCalled("exit_json called")

        with patch.object(node.module, 'exit_json', side_effect=exit_json_side_effect), \
                patch('ascend_deployer.library.label_node.common_info.get_npu_info', return_value={}):
            try:
                node.get_labels()
            except self._ExitJsonCalled:
                # Expected: the mocked exit_json stopped get_labels. Any other
                # exception is a real failure and propagates to the test.
                pass

        return captured_facts.get('ansible_facts', {})


class TestGetDevicePluginVersion(TestLabelNode):
    """Tests for ``LabelNode.get_device_plugin_version``."""

    def test_extract_version_from_package_name(self):
        """A release version is read back from the package file name."""
        version_str = self._VERSION_NEW
        with self._create_mock_package_dir(version_str) as temp_dir:
            node = self._init_label_node(resources_dir=temp_dir)
            version = node.get_device_plugin_version()
            self.assertEqual(version, version_str,
                             "Should extract version from valid package name")

    def test_extract_version_rc_format(self):
        """An ``RC`` style version is read back from the package file name."""
        version_str = "26.0.RC3"
        with self._create_mock_package_dir(version_str) as temp_dir:
            node = self._init_label_node(resources_dir=temp_dir)
            version = node.get_device_plugin_version()
            self.assertEqual(version, version_str,
                             "Should extract RC version from package name")

    def test_no_package_found(self):
        """An existing but empty package directory yields ``None``."""
        # TemporaryDirectory cleans up even if directory setup itself fails.
        with tempfile.TemporaryDirectory() as temp_dir:
            package_dir = os.path.join(temp_dir, 'mindxdl', 'dlPackage', 'aarch64')
            os.makedirs(package_dir)
            node = self._init_label_node(resources_dir=temp_dir)
            version = node.get_device_plugin_version()
            self.assertIsNone(version,
                              "Should return None when no matching package found")

    def test_package_dir_not_exists(self):
        """A missing package directory yields ``None``."""
        node = self._init_label_node(resources_dir='/nonexistent/path')
        version = node.get_device_plugin_version()
        self.assertIsNone(version,
                          "Should return None when package directory does not exist")

    def test_ignore_non_device_plugin_packages(self):
        """A package that is not the device plugin is not used as the version source."""
        # TemporaryDirectory cleans up even if creating the mock package fails.
        with tempfile.TemporaryDirectory() as temp_dir:
            package_dir = os.path.join(temp_dir, 'mindxdl', 'dlPackage', 'aarch64')
            os.makedirs(package_dir)
            other_package = os.path.join(package_dir, 'ascend-operator-26.0.0.zip')
            with open(other_package, 'w') as f:
                f.write("mock package")
            node = self._init_label_node(resources_dir=temp_dir)
            version = node.get_device_plugin_version()
            self.assertIsNone(version,
                              "Should ignore non-device-plugin packages")


class TestIsNewVersion(TestLabelNode):
    """Check how ``is_new_version`` classifies device-plugin package versions."""

    def _is_new_for(self, package_version):
        """Return ``is_new_version()`` for resources holding *package_version*."""
        with self._create_mock_package_dir(package_version) as resources_root:
            label_node = self._init_label_node(resources_dir=resources_root)
            return label_node.is_new_version()

    def test_version_26_1_0_exact_match(self):
        """26.1.0 is reported as new."""
        self.assertTrue(self._is_new_for(self._VERSION_NEW),
                        "26.1.0 should be considered new version")

    def test_version_26_2_0_higher(self):
        """26.2.0 is reported as new."""
        self.assertTrue(self._is_new_for(self._VERSION_HIGHER),
                        "26.2.0 should be considered new version")

    def test_version_26_0_0_lower(self):
        """26.0.0 is reported as old."""
        self.assertFalse(self._is_new_for(self._VERSION_LOWER),
                         "26.0.0 should be considered old version")

    def test_version_25_0_0_much_lower(self):
        """25.0.0 is reported as old."""
        self.assertFalse(self._is_new_for(self._VERSION_MUCH_LOWER),
                         "25.0.0 should be considered old version")

    def test_rc_version_less_than_release(self):
        """26.1.RC1 is reported as old."""
        self.assertFalse(self._is_new_for(self._VERSION_RC),
                         "26.1.RC1 should be considered old version")

    def test_no_version_fallback(self):
        """With no package directory at all the result is falsy."""
        label_node = self._init_label_node(resources_dir='/nonexistent/path')
        self.assertFalse(label_node.is_new_version(),
                         "Should fallback to False when no version can be detected")


class TestGetNodedLabel(TestLabelNode):
    """Tests for ``LabelNode.get_noded_label``."""

    def _noded_label_for(self, **overrides):
        """Return ``get_noded_label()`` for a node built with *overrides*."""
        label_node = self._init_label_node(**overrides)
        return label_node.get_noded_label()

    def test_noded_on(self):
        """``noded_label='on'`` yields the nodeDEnable label."""
        labels = self._noded_label_for(noded_label='on')
        self.assertEqual(labels, {'nodeDEnable': 'on'})

    def test_noded_tag(self):
        """The ``noded`` run tag yields the nodeDEnable label."""
        labels = self._noded_label_for(ansible_run_tags=['noded'])
        self.assertEqual(labels, {'nodeDEnable': 'on'})

    def test_dl_tag(self):
        """The ``dl`` run tag yields the nodeDEnable label."""
        labels = self._noded_label_for(ansible_run_tags=['dl'])
        self.assertEqual(labels, {'nodeDEnable': 'on'})

    def test_noded_off(self):
        """``noded_label='off'`` yields no labels."""
        labels = self._noded_label_for(noded_label='off')
        self.assertEqual(labels, {})

    def test_noded_off_no_tags(self):
        """``noded_label='off'`` with an explicit empty tag list yields no labels."""
        labels = self._noded_label_for(noded_label='off', ansible_run_tags=[])
        self.assertEqual(labels, {})


class TestGetDevicePluginLabel(TestLabelNode):
    """Tests for ``LabelNode.get_device_plugin_label``."""

    # Patch target for the NPU info lookup used by the module under test.
    _NPU_INFO_TARGET = 'ascend_deployer.library.label_node.common_info.get_npu_info'

    def _labels_for_lspci(self, lspci_lines):
        """Return the labels produced when the first command output is *lspci_lines*.

        ``iter_cmd_output`` is mocked to yield *lspci_lines* on its first call
        and empty output on the next two; ``get_npu_info`` returns ``{}``.
        """
        label_node = self._init_label_node()
        command_outputs = [iter(lspci_lines), iter([]), iter([])]
        with patch.object(label_node, 'iter_cmd_output', side_effect=command_outputs):
            with patch(self._NPU_INFO_TARGET, return_value={}):
                return label_node.get_device_plugin_label()

    def test_arch_arm(self):
        """An aarch64 node gets the ``huawei-arm`` host-arch label."""
        label_node = self._init_label_node()
        label_node.arch = 'aarch64'
        with patch.object(label_node, 'iter_cmd_output', return_value=iter([])):
            with patch(self._NPU_INFO_TARGET, return_value={}):
                result = label_node.get_device_plugin_label()
        self.assertIn('host-arch', result)
        self.assertEqual(result['host-arch'], 'huawei-arm')

    def test_accelerator_ascend310(self):
        """Device id ``d100`` maps to the Ascend310 accelerator label."""
        result = self._labels_for_lspci(['Processing accelerators: Device d100'])
        self.assertIn('accelerator', result)
        self.assertEqual(result['accelerator'], 'huawei-Ascend310')

    def test_accelerator_ascend910(self):
        """Device id ``d801`` maps to the Ascend910 accelerator label."""
        result = self._labels_for_lspci(['Processing accelerators: Device d801'])
        self.assertIn('accelerator', result)
        self.assertEqual(result['accelerator'], 'huawei-Ascend910')

    def test_accelerator_ascend310P(self):
        """Device id ``d500`` maps to the Ascend310P accelerator label."""
        result = self._labels_for_lspci(['Processing accelerators: Device d500'])
        self.assertIn('accelerator', result)
        self.assertEqual(result['accelerator'], 'huawei-Ascend310P')

    def test_server_usage_infer(self):
        """An Atlas 800I A2 card gets the ``infer`` server-usage label."""
        label_node = self._init_label_node()
        from ascend_deployer.module_utils import compatibility_config
        npu_info = {'card': compatibility_config.Hardware.ATLAS_800I_A2}
        with patch.object(label_node, 'iter_cmd_output', return_value=iter([])):
            with patch(self._NPU_INFO_TARGET, return_value=npu_info):
                result = label_node.get_device_plugin_label()
        self.assertIn('server-usage', result)
        self.assertEqual(result['server-usage'], 'infer')


class TestGetLabelsNewVersion(TestLabelNode):
    """``get_labels`` output when the device-plugin package is version 26.1.0."""

    def _new_version_labels(self, **role_flags):
        """Return the ``test-node`` labels reported for a 26.1.0 package.

        *role_flags* are forwarded to ``_init_label_node`` (for example
        ``master_node`` / ``worker_node``). Command output is mocked as empty.
        """
        with self._create_mock_package_dir(self._VERSION_NEW) as resources_root:
            label_node = self._init_label_node(resources_dir=resources_root, **role_flags)
            with patch.object(label_node, 'iter_cmd_output', return_value=iter([])):
                facts = self._capture_exit_json_result(label_node)
        per_node = facts.get('node_label', {})
        return per_node.get('test-node', {})

    def test_worker_node_new_version(self):
        """A worker gets role/selector labels but no device-plugin labels."""
        labels = self._new_version_labels(worker_node=True)

        self.assertIn('node-role.kubernetes.io/worker', labels)
        self.assertIn('workerselector', labels)
        self.assertNotIn('host-arch', labels,
                         "New version should NOT have host-arch")
        self.assertNotIn('accelerator', labels,
                         "New version should NOT have accelerator")
        self.assertNotIn('accelerator-type', labels,
                         "New version should NOT have accelerator-type")
        self.assertNotIn('server-usage', labels,
                         "New version should NOT have server-usage")
        self.assertNotIn('nodeDEnable', labels,
                         "New version should NOT have nodeDEnable when off")

    def test_master_node_new_version(self):
        """A master-only node gets the master selector and no worker role."""
        labels = self._new_version_labels(master_node=True, worker_node=False)

        self.assertIn('masterselector', labels)
        self.assertNotIn('node-role.kubernetes.io/worker', labels,
                         "Master node should NOT have worker labels")
        self.assertNotIn('host-arch', labels)

    def test_both_master_and_worker_new_version(self):
        """A combined master/worker node gets both sets of role labels."""
        labels = self._new_version_labels(master_node=True, worker_node=True)

        self.assertIn('masterselector', labels)
        self.assertIn('node-role.kubernetes.io/worker', labels)
        self.assertIn('workerselector', labels)
        self.assertNotIn('host-arch', labels)


class TestGetLabelsOldVersion(TestLabelNode):
    """``get_labels`` output when the device-plugin package is version 24.1.0."""

    def _old_version_labels(self, **node_options):
        """Return the ``test-node`` labels reported for a 24.1.0 package.

        *node_options* are forwarded to ``_init_label_node``. Command output
        is mocked as empty.
        """
        with self._create_mock_package_dir(self._VERSION_OLD) as resources_root:
            label_node = self._init_label_node(resources_dir=resources_root, **node_options)
            with patch.object(label_node, 'iter_cmd_output', return_value=iter([])):
                facts = self._capture_exit_json_result(label_node)
        per_node = facts.get('node_label', {})
        return per_node.get('test-node', {})

    def test_worker_node_old_version(self):
        """A worker gets role/selector labels plus ``host-arch``."""
        labels = self._old_version_labels(worker_node=True)

        self.assertIn('node-role.kubernetes.io/worker', labels)
        self.assertIn('workerselector', labels)
        self.assertIn('host-arch', labels,
                      "Old version should have host-arch")

    def test_master_node_old_version(self):
        """A master-only node gets the master selector and no worker role."""
        labels = self._old_version_labels(master_node=True, worker_node=False)

        self.assertIn('masterselector', labels)
        self.assertNotIn('node-role.kubernetes.io/worker', labels)

    def test_worker_node_noded_enabled(self):
        """``noded_label='on'`` adds ``nodeDEnable: on`` to a worker."""
        labels = self._old_version_labels(worker_node=True, noded_label='on')

        self.assertIn('nodeDEnable', labels,
                      "Should have nodeDEnable when enabled")
        self.assertEqual(labels['nodeDEnable'], 'on')

    def test_default_case_simplified_labels(self):
        """A node that is neither master nor worker still gets role and arch labels."""
        labels = self._old_version_labels(master_node=False, worker_node=False)

        self.assertIn('node-role.kubernetes.io/worker', labels,
                      "Default should have role label")
        self.assertIn('host-arch', labels,
                      "Old version default should have host-arch")


class TestVersionBoundaryCases(TestLabelNode):
    """Behaviour for versions at and around the 26.1.0 boundary."""

    def _worker_is_new(self, package_version):
        """Return ``is_new_version()`` for a worker node with *package_version*."""
        with self._create_mock_package_dir(package_version) as resources_root:
            worker = self._init_label_node(resources_dir=resources_root, worker_node=True)
            return worker.is_new_version()

    def test_version_26_1_rc1_just_below_boundary(self):
        """26.1.RC1 still produces the ``host-arch`` label."""
        with self._create_mock_package_dir(self._VERSION_RC) as resources_root:
            worker = self._init_label_node(resources_dir=resources_root, worker_node=True)
            with patch.object(worker, 'iter_cmd_output', return_value=iter([])):
                facts = self._capture_exit_json_result(worker)
        labels = facts.get('node_label', {}).get('test-node', {})

        self.assertIn('host-arch', labels,
                      "Old version should include host-arch")

    def test_version_26_1_0_exact_boundary(self):
        """26.1.0 is reported as new."""
        self.assertTrue(self._worker_is_new(self._VERSION_NEW),
                        "26.1.0 exactly at boundary should use new logic")

    def test_version_27_0_0_future_major(self):
        """27.0.0 is reported as new."""
        self.assertTrue(self._worker_is_new(self._VERSION_FUTURE),
                        "27.0.0 future major version should use new logic")


class TestBackwardCompatibility(TestLabelNode):
    """Labels reported for old packages and when no package is found."""

    def _labels_from(self, label_node):
        """Run ``get_labels`` on *label_node* with empty command output.

        Returns the ``test-node`` entry of the reported ``node_label`` fact,
        or ``{}`` when it is absent.
        """
        with patch.object(label_node, 'iter_cmd_output', return_value=iter([])):
            facts = self._capture_exit_json_result(label_node)
        per_node = facts.get('node_label', {})
        return per_node.get('test-node', {})

    def test_conservative_fallback_when_no_package_found(self):
        """Without a package directory the worker still gets arch and role labels."""
        worker = self._init_label_node(resources_dir='/nonexistent/path', worker_node=True)
        labels = self._labels_from(worker)

        self.assertIn('host-arch', labels,
                      "Conservative fallback should include host-arch")
        self.assertIn('node-role.kubernetes.io/worker', labels,
                      "Conservative fallback should include worker role")

    def test_arch_detection_in_old_version(self):
        """With a 24.1.0 package, ``x86_64`` maps to ``huawei-x86``."""
        with self._create_mock_package_dir(self._VERSION_OLD) as resources_root:
            worker = self._init_label_node(resources_dir=resources_root, worker_node=True)
            worker.arch = 'x86_64'
            labels = self._labels_from(worker)

        self.assertIn('host-arch', labels,
                      "Old version should detect architecture")
        self.assertEqual(labels['host-arch'], 'huawei-x86',
                         "x86_64 should map to huawei-x86")

    def test_arm_architecture_detection(self):
        """With a 23.0.0 package, ``aarch64`` maps to ``huawei-arm``."""
        with self._create_mock_package_dir(self._VERSION_VERY_OLD) as resources_root:
            worker = self._init_label_node(resources_dir=resources_root, worker_node=True)
            worker.arch = 'aarch64'
            labels = self._labels_from(worker)

        self.assertIn('host-arch', labels,
                      "Old version should detect ARM architecture")
        self.assertEqual(labels['host-arch'], 'huawei-arm',
                         "aarch64 should map to huawei-arm")