import os
import sys
import unittest
from unittest.mock import patch, MagicMock

from library_test.base_test import BaseTest
from ascend_deployer.module_utils.common_info import NPUCardName


class TestHccnCheck(BaseTest):
    """Shared fixture for the HccnCheck unit tests.

    ``setUpClass`` forces ``DEPLOYER_CHECK_MODE=fast``, makes the workspace
    root importable, patches the Ansible helpers listed below and imports
    ``HccnCheck``. ``_create_checker`` then builds checker instances without
    running ``HccnCheck.__init__``. The concrete test classes in this file
    subclass this one, so the class-level setup/teardown runs once per
    subclass.
    """

    # Class attributes holding the patchers started in setUpClass, in start
    # order; tearDownClass stops them in reverse order.
    _PATCHER_ATTRS = (
        '_ansible_patcher',
        '_user_check_patcher',
        '_check_output_patcher',
        '_wait_patcher',
    )

    @classmethod
    def setUpClass(cls):
        # Remember the pre-existing value so tearDownClass can restore it
        # instead of leaking the override into unrelated tests.
        cls._saved_check_mode = os.environ.get('DEPLOYER_CHECK_MODE')
        os.environ['DEPLOYER_CHECK_MODE'] = 'fast'

        workspace_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..'))
        if workspace_path not in sys.path:
            sys.path.append(workspace_path)
        super().setUpClass()

        cls._ansible_patcher = patch(
            'library_test.mock_manage.mock_model.mock_ansible_module.AnsibleModule'
        )
        cls._mock_ansible_cls = cls._ansible_patcher.start()

        cls._user_check_patcher = patch(
            'ansible.module_utils.check_library_utils.check_user.UserCheck'
        )
        cls._user_check_patcher.start()

        # Imported only after the two patches above are active.
        from ascend_deployer.library.check_hccn import HccnCheck
        cls.HccnCheck = HccnCheck

        cls._check_output_patcher = patch(
            'ansible.module_utils.check_output_manager.CHECK_OUTPUT_MANAGER'
        )
        cls._mock_check_output = cls._check_output_patcher.start()

        cls._wait_patcher = patch(
            'ansible.module_utils.check_output_manager.wait_for_finish'
        )
        cls._wait_patcher.start()

    @classmethod
    def tearDownClass(cls):
        """Undo what setUpClass did, then always run the base-class teardown.

        Only the patchers started by this class are stopped. ``patch.stopall()``
        would also stop patches started by unrelated code in the same process.
        """
        try:
            for attr in reversed(cls._PATCHER_ATTRS):
                # Look only at this class's own attributes: each subclass
                # starts (and must stop) its own set of patchers.
                patcher = cls.__dict__.get(attr)
                if patcher is not None:
                    patcher.stop()
            saved = cls.__dict__.get('_saved_check_mode')
            if saved is None:
                os.environ.pop('DEPLOYER_CHECK_MODE', None)
            else:
                os.environ['DEPLOYER_CHECK_MODE'] = saved
        finally:
            super().tearDownClass()

    def setUp(self):
        """Reset the class-level mocks so tests do not see each other's state."""
        super().setUp()
        self._mock_ansible_cls.reset_mock()
        self._mock_check_output.reset_mock()
        self._mock_check_output.fail_happen = False
        self._mock_check_output.get_check_output.return_value.error_msg = []

    def _create_checker(self, **params):
        """Build an ``HccnCheck`` without calling its ``__init__``.

        Keyword arguments override the default module parameters below. The
        merged dict becomes ``checker.module.params``, and the matching
        attributes are copied onto the checker. ``HARDWARE_TYPE`` is only
        placed in ``module.params``. ``npu_name`` defaults to
        ``NPUCardName.A910A2`` and ``npu_count`` to 8; tests may overwrite
        either after construction.
        """
        default_params = {
            'device_ips': [],
            'gateways': [],
            'netmask': '255.255.255.0',
            'detect_ips': [],
            'common_network': '',
            'bitmap': '',
            'dscp_tc': '',
            'ip_address': 'localhost',
            'HARDWARE_TYPE': '',
        }
        default_params.update(params)

        # __new__ skips __init__, so the attributes it would set are
        # assigned by hand below.
        checker = self.HccnCheck.__new__(self.HccnCheck)
        checker.module = MagicMock()
        checker.module.params = default_params
        checker.device_ips = default_params['device_ips']
        checker.gateways = default_params['gateways']
        checker.netmask = default_params['netmask']
        checker.detect_ips = default_params['detect_ips']
        checker.common_network = default_params['common_network']
        checker.bitmap = default_params['bitmap']
        checker.dscp_tc = default_params['dscp_tc']
        checker.ip_address = default_params['ip_address']
        checker.error_messages = []
        checker.user_check = MagicMock()
        checker.npu_name = NPUCardName.A910A2
        checker.npu_count = 8
        return checker


class TestGetNpuName(TestHccnCheck):
    """``HccnCheck.get_npu_name`` maps ``CheckUtil.get_card`` output to card names."""

    _GET_CARD_TARGET = 'ansible.module_utils.check_utils.CheckUtil.get_card'

    def _name_for_card(self, card):
        """Return ``get_npu_name()`` while ``get_card`` is stubbed to return *card*."""
        with patch(self._GET_CARD_TARGET, return_value=card):
            return self.HccnCheck.get_npu_name()

    def test_npu_name_910b(self):
        self.assertEqual(NPUCardName.A910A2, self._name_for_card('910b'))

    def test_npu_name_910_93(self):
        self.assertEqual(NPUCardName.A910A3, self._name_for_card('910_93'))

    def test_npu_name_910(self):
        self.assertEqual(NPUCardName.A910A1, self._name_for_card('910'))

    def test_npu_name_other(self):
        # Unrecognised cards are passed through unchanged.
        self.assertEqual('310p', self._name_for_card('310p'))


class TestIsIpv6(TestHccnCheck):
    """``HccnCheck.is_ipv6`` tells IPv6 literals apart from IPv4 ones."""

    def test_ipv4(self):
        verdict = self.HccnCheck.is_ipv6('192.168.0.1')
        self.assertFalse(verdict)

    def test_ipv6(self):
        verdict = self.HccnCheck.is_ipv6('2001:db8::1')
        self.assertTrue(verdict)

    def test_ipv6_full(self):
        # Fully expanded form, no "::" compression.
        verdict = self.HccnCheck.is_ipv6('fd15:4ba5:5a2b:1008:b5e6:77db:eea2:73bc')
        self.assertTrue(verdict)


class TestIsValidNetmask(TestHccnCheck):
    """``is_valid_netmask`` accepts dotted IPv4 masks, IPv6 CIDR and bare prefix lengths."""

    def test_valid_ipv4_netmask(self):
        chk = self._create_checker()
        for mask in ('255.255.255.0', '255.255.0.0', '255.0.0.0'):
            self.assertTrue(chk.is_valid_netmask(mask))

    def test_invalid_ipv4_netmask_non_contiguous(self):
        chk = self._create_checker()
        # The one-bits of a mask must be contiguous.
        self.assertFalse(chk.is_valid_netmask('255.255.0.255'))

    def test_invalid_ipv4_netmask_wrong_format(self):
        chk = self._create_checker()
        for mask in ('255.255.255', '256.255.255.0', 'abc.def.ghi.jkl'):
            self.assertFalse(chk.is_valid_netmask(mask))

    def test_valid_ipv6_cidr_netmask(self):
        chk = self._create_checker()
        self.assertTrue(chk.is_valid_netmask('2001:db8::/64'))

    def test_invalid_ipv6_cidr_netmask(self):
        chk = self._create_checker()
        self.assertFalse(chk.is_valid_netmask('2001:db8::/129'))

    def test_valid_ipv6_integer_netmask(self):
        chk = self._create_checker()
        for prefix in ('64', '0', '128'):
            self.assertTrue(chk.is_valid_netmask(prefix))

    def test_invalid_ipv6_integer_netmask(self):
        chk = self._create_checker()
        for prefix in ('129', '-1'):
            self.assertFalse(chk.is_valid_netmask(prefix))


class TestIpToBinary(TestHccnCheck):
    """``ip_to_binary`` renders an address as a string of bits, or None on bad input."""

    def test_ipv4_to_binary(self):
        chk = self._create_checker()
        expected_bits = {
            '192.168.0.1': '11000000101010000000000000000001',
            '0.0.0.0': '00000000000000000000000000000000',
            '255.255.255.255': '11111111111111111111111111111111',
        }
        for address, bits in expected_bits.items():
            self.assertEqual(bits, chk.ip_to_binary(address))

    def test_ipv6_to_binary(self):
        chk = self._create_checker()
        bits = chk.ip_to_binary('::1')
        self.assertEqual(128, len(bits))
        self.assertTrue(bits.endswith('1'))

    def test_invalid_ip(self):
        chk = self._create_checker()
        bits = chk.ip_to_binary('invalid_ip')
        # Invalid input yields None and records an error on the checker.
        self.assertIsNone(bits)
        self.assertGreater(len(chk.error_messages), 0)
        self.assertGreater(len(checker.error_messages), 0)


class TestInSameSubnet(TestHccnCheck):
    """``in_same_subnet`` compares two addresses under the checker's netmask."""

    def _same_subnet(self, netmask, first, second):
        """Build a checker with *netmask* and return ``in_same_subnet(first, second)``."""
        return self._create_checker(netmask=netmask).in_same_subnet(first, second)

    def test_same_ipv4_subnet(self):
        self.assertTrue(self._same_subnet('255.255.255.0', '192.168.1.10', '192.168.1.1'))

    def test_different_ipv4_subnet(self):
        self.assertFalse(self._same_subnet('255.255.255.0', '192.168.2.10', '192.168.1.1'))

    def test_same_ipv6_subnet(self):
        self.assertTrue(self._same_subnet('64', '2001:db8::1', '2001:db8::2'))

    def test_different_ipv6_subnet(self):
        self.assertFalse(self._same_subnet('64', '2001:db8::1', '2001:db9::2'))

    def test_mixed_ip_versions(self):
        self.assertFalse(self._same_subnet('255.255.255.0', '192.168.1.10', '2001:db8::1'))

    def test_same_ipv6_subnet_cidr_netmask(self):
        self.assertTrue(self._same_subnet('2001:db8::/64', '2001:db8::1', '2001:db8::2'))

    def test_zero_prefix_length(self):
        # Pins the current behaviour for a "/0" CIDR netmask.
        # NOTE(review): every address falls inside a /0 network, so True would
        # also be defensible; confirm that False is the intended contract.
        self.assertFalse(self._same_subnet('2001:db8::/0', '2001:db8::1', '2001:db9::1'))


class TestCheckHccnBitmap(TestHccnCheck):
    """``check_hccn_bitmap`` validates the comma-separated 0/1 bitmap."""

    def _errors_after_check(self, chk):
        """Run the bitmap check on *chk* and return its collected error messages."""
        chk.check_hccn_bitmap()
        return chk.error_messages

    def test_valid_bitmap(self):
        errors = self._errors_after_check(self._create_checker(bitmap='1,0,1,0,1,0,1,0'))
        self.assertEqual(0, len(errors))

    def test_invalid_bitmap_wrong_length(self):
        errors = self._errors_after_check(self._create_checker(bitmap='1,0,1,0'))
        self.assertGreater(len(errors), 0)

    def test_invalid_bitmap_wrong_chars(self):
        errors = self._errors_after_check(self._create_checker(bitmap='1,0,2,0,1,0,1,0'))
        self.assertGreater(len(errors), 0)

    def test_empty_bitmap(self):
        errors = self._errors_after_check(self._create_checker(bitmap=''))
        self.assertEqual(0, len(errors))

    def test_none_bitmap(self):
        chk = self._create_checker()
        # Override only the attribute; module.params keeps the default ''.
        chk.bitmap = None
        self.assertEqual(0, len(self._errors_after_check(chk)))


class TestCheckHccnDscpTc(TestHccnCheck):
    """``check_hccn_dscp_tc`` validates "dscp:tc," pairs (DSCP 0-63, TC 0-3)."""

    def _errors_after_check(self, chk):
        """Run the DSCP/TC check on *chk* and return its collected error messages."""
        chk.check_hccn_dscp_tc()
        return chk.error_messages

    def _errors_for(self, dscp_tc):
        return self._errors_after_check(self._create_checker(dscp_tc=dscp_tc))

    def test_valid_dscp_tc(self):
        self.assertEqual(0, len(self._errors_for('22:0,')))

    def test_valid_dscp_tc_boundary(self):
        self.assertEqual(0, len(self._errors_for('63:3,')))

    def test_valid_dscp_tc_zero(self):
        self.assertEqual(0, len(self._errors_for('0:0,')))

    def test_invalid_dscp_tc_format(self):
        # Missing the trailing comma.
        self.assertGreater(len(self._errors_for('22:0')), 0)

    def test_invalid_dscp_out_of_range(self):
        self.assertGreater(len(self._errors_for('64:0,')), 0)

    def test_invalid_tc_out_of_range(self):
        self.assertGreater(len(self._errors_for('22:4,')), 0)

    def test_invalid_dscp_tc_bad_format(self):
        self.assertGreater(len(self._errors_for('abc:0,')), 0)

    def test_empty_dscp_tc(self):
        self.assertEqual(0, len(self._errors_for('')))

    def test_none_dscp_tc(self):
        chk = self._create_checker()
        # Override only the attribute; module.params keeps the default ''.
        chk.dscp_tc = None
        self.assertEqual(0, len(self._errors_after_check(chk)))


class TestCheckHccnSupport(TestHccnCheck):
    """``check_hccn_support`` accepts the 910 A1/A2/A3 cards and rejects others."""

    def _support_errors(self, npu_name):
        """Return the errors recorded by ``check_hccn_support`` for *npu_name*."""
        chk = self._create_checker()
        chk.npu_name = npu_name
        chk.check_hccn_support()
        return chk.error_messages

    def test_supported_npu(self):
        self.assertEqual(0, len(self._support_errors(NPUCardName.A910A2)))

    def test_supported_npu_a1(self):
        self.assertEqual(0, len(self._support_errors(NPUCardName.A910A1)))

    def test_supported_npu_a3(self):
        self.assertEqual(0, len(self._support_errors(NPUCardName.A910A3)))

    def test_unsupported_npu(self):
        self.assertGreater(len(self._support_errors('310p')), 0)


class TestCheckHccnIp(TestHccnCheck):
    """``check_hccn_ip`` validates device/detect IP lists against the NPU layout."""

    def _ip_errors(self, device_ips, detect_ips, **overrides):
        """Run ``check_hccn_ip`` and return the collected error messages.

        *overrides* are assigned onto the checker, in order, after it is
        built (e.g. ``npu_count``, ``npu_name``).
        """
        chk = self._create_checker(device_ips=device_ips, detect_ips=detect_ips)
        for attr_name, attr_value in overrides.items():
            setattr(chk, attr_name, attr_value)
        chk.check_hccn_ip()
        return chk.error_messages

    def test_empty_device_ips(self):
        self.assertGreater(len(self._ip_errors([], ['192.168.0.1'])), 0)

    def test_empty_detect_ips(self):
        self.assertGreater(len(self._ip_errors(['192.168.0.1'], [])), 0)

    def test_mismatched_npu_count(self):
        # One IP for eight NPUs.
        errors = self._ip_errors(['192.168.0.1'], ['192.168.0.1'], npu_count=8)
        self.assertGreater(len(errors), 0)

    def test_invalid_ip(self):
        errors = self._ip_errors(['invalid_ip'], ['192.168.0.1'], npu_count=1)
        self.assertGreater(len(errors), 0)

    def test_mixed_ipv4_ipv6(self):
        errors = self._ip_errors(
            ['192.168.0.1', '2001:db8::1'],
            ['192.168.0.1', '2001:db8::1'],
            npu_count=2,
        )
        self.assertGreater(len(errors), 0)

    def test_ipv6_not_supported(self):
        errors = self._ip_errors(
            ['2001:db8::1'],
            ['2001:db8::1'],
            npu_count=1,
            npu_name=NPUCardName.A910A1,
        )
        self.assertGreater(len(errors), 0)

    def test_valid_ipv4_config(self):
        addresses = ['192.168.0.%d' % octet for octet in range(1, 9)]
        errors = self._ip_errors(addresses, addresses,
                                 npu_count=8, npu_name=NPUCardName.A910A2)
        self.assertEqual(0, len(errors))

    def test_valid_ipv6_config(self):
        addresses = ['2001:db8::%d' % suffix for suffix in range(1, 9)]
        errors = self._ip_errors(addresses, addresses,
                                 npu_count=8, npu_name=NPUCardName.A910A2)
        self.assertEqual(0, len(errors))


class TestCheckHccnGateways(TestHccnCheck):
    """``check_hccn_gateways`` requires at least one gateway and all must be valid IPs."""

    def _gateway_errors(self, gateways):
        """Return the errors recorded by ``check_hccn_gateways`` for *gateways*."""
        chk = self._create_checker(gateways=gateways)
        chk.check_hccn_gateways()
        return chk.error_messages

    def test_empty_gateways(self):
        self.assertGreater(len(self._gateway_errors([])), 0)

    def test_invalid_gateway(self):
        self.assertGreater(len(self._gateway_errors(['invalid_ip'])), 0)

    def test_valid_gateway_ipv4(self):
        self.assertEqual(0, len(self._gateway_errors(['192.168.0.1'])))

    def test_valid_gateway_ipv6(self):
        self.assertEqual(0, len(self._gateway_errors(['2001:db8::1'])))

    def test_multiple_gateways_mixed_valid_invalid(self):
        # A single bad entry is enough to fail the check.
        self.assertGreater(len(self._gateway_errors(['192.168.0.1', 'invalid_ip'])), 0)


class TestCheckHccnNetmask(TestHccnCheck):
    """``check_hccn_netmask`` rejects empty or malformed netmasks."""

    def _netmask_errors(self, netmask):
        """Return the errors recorded by ``check_hccn_netmask`` for *netmask*."""
        chk = self._create_checker(netmask=netmask)
        chk.check_hccn_netmask()
        return chk.error_messages

    def test_empty_netmask(self):
        self.assertGreater(len(self._netmask_errors('')), 0)

    def test_invalid_netmask(self):
        self.assertGreater(len(self._netmask_errors('255.255.0.255')), 0)

    def test_valid_netmask_ipv4(self):
        self.assertEqual(0, len(self._netmask_errors('255.255.255.0')))

    def test_valid_netmask_ipv6(self):
        self.assertEqual(0, len(self._netmask_errors('64')))


class TestCheckHccnConfiguration(TestHccnCheck):
    """``check_hccn_configuration`` requires each device IP to share a subnet with a gateway."""

    def _config_errors(self, **params):
        """Run ``check_hccn_configuration`` on a checker built from *params*."""
        chk = self._create_checker(**params)
        chk.check_hccn_configuration()
        return chk.error_messages

    def test_empty_device_ips(self):
        # Nothing to compare, so nothing to report.
        errors = self._config_errors(device_ips=[], gateways=['192.168.0.1'])
        self.assertEqual(0, len(errors))

    def test_empty_gateways(self):
        errors = self._config_errors(device_ips=['192.168.0.1'], gateways=[])
        self.assertEqual(0, len(errors))

    def test_ip_in_gateway_subnet(self):
        errors = self._config_errors(device_ips=['192.168.1.10'],
                                     gateways=['192.168.1.1'],
                                     netmask='255.255.255.0')
        self.assertEqual(0, len(errors))

    def test_ip_not_in_any_gateway_subnet(self):
        errors = self._config_errors(device_ips=['192.168.2.10'],
                                     gateways=['192.168.1.1'],
                                     netmask='255.255.255.0')
        self.assertGreater(len(errors), 0)

    def test_ip_in_second_gateway_subnet(self):
        # Matching any one gateway is sufficient.
        errors = self._config_errors(device_ips=['192.168.2.10'],
                                     gateways=['192.168.1.1', '192.168.2.1'],
                                     netmask='255.255.255.0')
        self.assertEqual(0, len(errors))


class TestCheckHccnCommonNetwork(TestHccnCheck):
    """``check_hccn_common_network`` validates the common network per NPU model."""

    def _common_network_errors(self, npu_name, common_network):
        """Return the errors recorded for *common_network* on an *npu_name* card."""
        chk = self._create_checker(common_network=common_network)
        chk.npu_name = npu_name
        chk.check_hccn_common_network()
        return chk.error_messages

    def test_a910a2_empty_common_network(self):
        self.assertEqual(0, len(self._common_network_errors(NPUCardName.A910A2, '')))

    def test_a910a2_default_common_network(self):
        self.assertEqual(0, len(self._common_network_errors(NPUCardName.A910A2, '0.0.0.0/0')))

    def test_a910a2_invalid_common_network(self):
        self.assertGreater(len(self._common_network_errors(NPUCardName.A910A2, '192.168.0.0/24')), 0)

    def test_a910a1_valid_common_network(self):
        self.assertEqual(0, len(self._common_network_errors(NPUCardName.A910A1, '0.0.0.0/0')))

    def test_a910a1_invalid_common_network(self):
        self.assertGreater(len(self._common_network_errors(NPUCardName.A910A1, '192.168.0.0/24')), 0)

    def test_a910a3_valid_common_network(self):
        self.assertEqual(0, len(self._common_network_errors(NPUCardName.A910A3, '0.0.0.0/0')))

    def test_a910a3_invalid_common_network(self):
        # Unlike A910A2, an empty value is reported for A910A3.
        self.assertGreater(len(self._common_network_errors(NPUCardName.A910A3, '')), 0)


class TestCheckNpuInstalled(TestHccnCheck):
    """``check_npu_installed`` reports an error when the probe command fails."""

    def _installed_errors(self, command_result):
        """Stub ``module.run_command`` with *command_result* and run the check."""
        chk = self._create_checker()
        chk.module.run_command.return_value = command_result
        chk.check_npu_installed()
        return chk.error_messages

    def test_npu_installed(self):
        errors = self._installed_errors((0, 'npu info output', ''))
        self.assertEqual(0, len(errors))

    def test_npu_not_installed(self):
        errors = self._installed_errors((1, '', 'npu-smi not found'))
        self.assertGreater(len(errors), 0)


class TestGetNpuCount(TestHccnCheck):
    """Tests for ``HccnCheck.get_npu_count``.

    NOTE(review): both tests patch ``get_npu_count`` itself and then assert the
    patched return value. They therefore exercise only the mock, not the real
    implementation. TODO: drive the real method through whatever it actually
    reads (to be confirmed in check_hccn) so these tests can catch
    regressions.
    """

    def test_get_npu_count_success(self):
        checker = self._create_checker()
        with patch.object(self.HccnCheck, 'get_npu_count', return_value=8):
            self.assertEqual(8, checker.get_npu_count())

    def test_get_npu_count_zero(self):
        checker = self._create_checker()
        with patch.object(self.HccnCheck, 'get_npu_count', return_value=0):
            self.assertEqual(0, checker.get_npu_count())


class TestRun(TestHccnCheck):
    """Drive ``HccnCheck.run()`` end to end with the module calls stubbed."""

    class _ModuleExit(Exception):
        """Raised by the stubbed ``exit_json``/``fail_json``.

        Asserting on this specific type proves ``run()`` reached a terminal
        module call. The previous ``assertRaises(Exception)`` also passed
        when ``run()`` crashed on an unrelated error.
        """

    def _stub_module_exit(self, checker):
        """Make ``exit_json``/``fail_json`` raise instead of returning a MagicMock.

        The real Ansible module methods do not return. Raising here emulates
        that and lets the test observe the call.
        """
        checker.module.exit_json.side_effect = self._ModuleExit("exit")
        checker.module.fail_json.side_effect = self._ModuleExit("fail")

    def test_run_success(self):
        test_ips = ['192.168.0.1', '192.168.0.2', '192.168.0.3', '192.168.0.4',
                    '192.168.0.5', '192.168.0.6', '192.168.0.7', '192.168.0.8']
        checker = self._create_checker(
            device_ips=test_ips,
            detect_ips=test_ips,
            gateways=['192.168.0.1'],
            netmask='255.255.255.0',
            common_network='0.0.0.0/0',
        )
        checker.npu_name = NPUCardName.A910A2
        checker.npu_count = 8
        checker.module.run_command.return_value = (0, 'npu info', '')

        with patch.object(self.HccnCheck, 'get_npu_count', return_value=8):
            self._stub_module_exit(checker)
            with self.assertRaises(self._ModuleExit):
                checker.run()

    def test_run_with_errors(self):
        checker = self._create_checker(
            device_ips=[],
            detect_ips=[],
            gateways=[],
            netmask='',
            common_network='invalid',
        )
        checker.npu_name = 'unsupported'
        checker.npu_count = 0
        checker.module.run_command.return_value = (1, '', 'npu-smi not found')

        with patch.object(self.HccnCheck, 'get_npu_count', return_value=0):
            self._stub_module_exit(checker)
            with self.assertRaises(self._ModuleExit):
                checker.run()