import abc
import os
from unittest.mock import patch

from ascend_deployer.module_utils.common_info import NPUCardName

from library_test.base_test import BaseLibraryTest
from library_test.mock_manage.mock_model.mock_ansible_module import AnsibleModule


class TestBaseHCCN(BaseLibraryTest, metaclass=abc.ABCMeta):
    """Shared fixture for the ``process_hccn`` library tests.

    For the lifetime of each concrete test class, ``CheckUtil.get_card`` in the
    module under test is patched, and one instance of each helper under test is
    created and shared by all tests of that class.
    """

    TESTCASE_DIR = os.path.join(os.path.dirname(__file__), "testcase")
    # Card type returned by the patched ``CheckUtil.get_card`` unless a test overrides it.
    DEFAULT_CARD = "910b"

    @classmethod
    def get_module_path(cls):
        return "ascend_deployer.library.process_hccn"

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()

        cls.product_mocker = cls._mock_get_card()
        try:
            # default mocker value
            cls.product_mocker.return_value = cls.DEFAULT_CARD
            from ascend_deployer.library.process_hccn import BaseModule, IPUtils, Template, HCCN, CommonInfo
            cls.base_module = BaseModule()
            cls.ip_utils = IPUtils()
            cls.template = Template()
            cls.hccn = HCCN()
            cls.common_info = CommonInfo
        except Exception:
            # unittest skips tearDownClass when setUpClass raises, so undo the patch here.
            cls._get_card_patcher.stop()
            raise

    @classmethod
    def tearDownClass(cls) -> None:
        # Stop the patch started in setUpClass. Otherwise every subclass leaves another
        # live patch on CheckUtil.get_card behind after its tests have finished.
        try:
            cls._get_card_patcher.stop()
        finally:
            super().tearDownClass()

    def setUp(self) -> None:
        super().setUp()
        # Tests mutate the shared mock; restore the default so results do not depend on order.
        self.product_mocker.return_value = self.DEFAULT_CARD

    @classmethod
    def _mock_get_card(cls):
        """Start patching ``CheckUtil.get_card`` in the module under test.

        The patcher is kept on the class so ``tearDownClass`` can stop it.

        :return: the mock object that replaces ``CheckUtil.get_card``.
        """
        patcher = patch(cls.get_module_path() + ".CheckUtil.get_card")
        mocker = patcher.start()
        cls._get_card_patcher = patcher
        return mocker

    @classmethod
    def get_testcase_path(cls):
        return os.path.join(cls.TESTCASE_DIR, "hccn.yml")


class TestBaseModule(TestBaseHCCN):
    """Tests for ``BaseModule`` helpers."""

    def test_is_ipv6(self):
        # str input
        self.assertFalse(self.base_module._is_ipv6("192.168.0.1"))
        self.assertTrue(self.base_module._is_ipv6("fd15:4ba5:5a2b:1008:b5e6:77db:eea2:73bc"))
        # list[str] input; an empty list is not reported as IPv6
        self.assertFalse(self.base_module._is_ipv6(["192.168.0.1"]))
        self.assertFalse(self.base_module._is_ipv6([]))
        self.assertTrue(self.base_module._is_ipv6(["fd15:4ba5:5a2b:1008:b5e6:77db:eea2:73bc"]))

    def test_get_npu_name(self):
        """Map the card type reported by ``CheckUtil.get_card`` to an NPU name."""
        # The mock is shared by the whole class: put the default back when this test ends.
        self.addCleanup(setattr, self.product_mocker, "return_value", "910b")
        cases = (
            ("910b", NPUCardName.A910A2),
            ("910", NPUCardName.A910A1),
            ("910_93", NPUCardName.A910A3),
            # any other card type is returned unchanged
            ("310p", "310p"),
        )
        for card, expected in cases:
            self.product_mocker.return_value = card
            self.assertEqual(expected, self.base_module._get_npu_name(), msg=card)


class TestIPUtils(TestBaseHCCN):
    """Tests for ``IPUtils``.

    ``self.ip_utils`` is a single instance shared by every test in this class, so
    each test sets the ``netmask``/``gateways`` it relies on. Tests do not inherit
    whatever a previously-run test left behind.
    """

    def test_get_ipv6_subnet(self):
        # netmask larger than the 128-bit address: the expected value is the whole
        # binary form of the address (without leading zeros).
        self.ip_utils.netmask = 1000
        ipv6 = "2001:db8::1"
        expect_subnet = "100000000000010000110110111000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001"
        self.assertEqual(expect_subnet, self.ip_utils._get_ipv6_subnet(ipv6))

        # malformed IPv6 address
        with self.assertRaises(AnsibleModule.FailJson) as e:
            self.ip_utils._get_ipv6_subnet("invalid:db8::1")
        self.assertIn("Invalid IPV6 address", self.parse_context_exception(e).get("msg"))

        # valid address, but an IPv4-style netmask is rejected with the same message
        self.ip_utils.netmask = "255.255.255.0"
        with self.assertRaises(AnsibleModule.FailJson) as e:
            self.ip_utils._get_ipv6_subnet(ipv6)
        self.assertIn("Invalid IPV6 address", self.parse_context_exception(e).get("msg"))

    def test_is_same_ipv6_subnet(self):
        self.ip_utils.netmask = 64
        ip = "2001:db8::1"
        self.assertTrue(self.ip_utils._is_same_ipv6_subnet(ip, "2001:db8::2"))
        self.assertFalse(self.ip_utils._is_same_ipv6_subnet(ip, "2001:db9::2"))

    def test_get_ipv6_gateway(self):
        self.ip_utils.netmask = 64
        self.ip_utils.gateways = [
            "2001:db8::1", "2001:db8::2", "2001:db8::3", "2001:db8::4",
            "2001:db8::5", "2001:db8::6", "2001:db8::7", "2001:db8::8",
        ]
        # the first gateway in the same /64 is chosen
        self.assertEqual("2001:db8::1", self.ip_utils.get_ipv6_gateway("2001:db8::9"))
        # no gateway in the ip's subnet -> empty string
        self.assertEqual("", self.ip_utils.get_ipv6_gateway("2001:db9::1"))

    def test_convert_ipv4_to_int(self):
        self.assertEqual(self.ip_utils._convert_ipv4_to_int("192.168.0.1"), 3232235521)
        self.assertEqual(self.ip_utils._convert_ipv4_to_int("0.0.0.0"), 0)
        self.assertEqual(self.ip_utils._convert_ipv4_to_int("255.255.255.255"), 4294967295)
        self.assertEqual(self.ip_utils._convert_ipv4_to_int("1.0.0.0"), 16777216)
        self.assertEqual(self.ip_utils._convert_ipv4_to_int("0.1.0.0"), 65536)
        self.assertEqual(self.ip_utils._convert_ipv4_to_int("0.0.0.1"), 1)
        self.assertEqual(self.ip_utils._convert_ipv4_to_int("0.0.1.0"), 256)

    def test_is_ipv4_in_subnet(self):
        # Pin the /24 netmask this test was written against instead of inheriting it.
        self.ip_utils.netmask = "255.255.255.0"
        # ip and gateway are in the same subnet
        self.assertTrue(self.ip_utils._is_ipv4_in_subnet("192.168.1.10", "192.168.1.1"))
        # ip and gateway are in different subnets
        self.assertFalse(self.ip_utils._is_ipv4_in_subnet("192.168.2.10", "192.168.1.1"))
        # ip is the broadcast address
        self.assertTrue(self.ip_utils._is_ipv4_in_subnet("192.168.1.255", "192.168.1.1"))
        # ip is the network address
        self.assertTrue(self.ip_utils._is_ipv4_in_subnet("192.168.1.0", "192.168.1.1"))
        # invalid ip
        with self.assertRaises(ValueError):
            self.ip_utils._is_ipv4_in_subnet("invalid", "192.168.1.1")

    def test_get_ipv4_subnet(self):
        gateway = "192.168.0.1"
        self.ip_utils.netmask = "255.255.255.0"
        self.assertEqual("192.168.0.0/24", self.ip_utils._get_ipv4_subnet(gateway))
        self.ip_utils.netmask = "255.255.224.0"
        self.assertEqual("192.168.0.0/19", self.ip_utils._get_ipv4_subnet(gateway))

    def test_get_ipv4_gateway(self):
        # The expected "/24" subnet below requires this netmask; set it explicitly.
        self.ip_utils.netmask = "255.255.255.0"
        self.ip_utils.gateways = ["192.168.1.1"]
        # no gateway in the ip's subnet -> ("", "")
        self.assertEqual(("", ""), self.ip_utils.get_ipv4_gateway("192.168.2.1"))
        self.assertEqual(("192.168.1.1", "192.168.1.0/24"), self.ip_utils.get_ipv4_gateway("192.168.1.10"))


class TestTemplate(TestBaseHCCN):
    """Tests for ``Template`` rendering of hccn.conf entries.

    Per-NPU data is passed as ``CommonInfo`` instances. Assigning attributes on the
    ``CommonInfo`` class itself would leak into every other user of that class.
    The positional order used here, (device ip, subnet, detect ip, gateway), follows
    the expected output of ``test_generate_hccn_conf``.
    NOTE(review): confirm it against CommonInfo's definition.
    """

    def test_get_basic_defines(self):
        npu_id = 0
        self.template.dscp_tc = ""
        self.template.bitmap = ""

        # only ipv6
        self.template.working_on_ipv6 = True
        self.template.netmask = "64"
        info = self.common_info("2001::1", "", "2001::1", "")
        self.assertEqual(
            "IPv6address_0=2001::1\nIPv6netmask_0=64\nIPv6netdetect_0=2001::1\n",
            self.template._get_basic_defines(info, npu_id)
        )

        # only ipv4
        self.template.working_on_ipv6 = False
        self.template.netmask = "255.255.255.0"
        info = self.common_info("192.168.1.2", "192.168.1.0/24", "192.168.1.2", "")
        self.assertEqual(
            "address_0=192.168.1.2\nnetmask_0=255.255.255.0\nnetdetect_0=192.168.1.2\n",
            self.template._get_basic_defines(info, npu_id)
        )

        # ipv4 + dscp + bitmap
        self.template.dscp_tc = "35:2,"
        self.template.bitmap = "0,0,0,0,1,0,0,0"
        self.assertEqual(
            "address_0=192.168.1.2\nnetmask_0=255.255.255.0\nnetdetect_0=192.168.1.2\ndscp_tc_0=35:2,\nbitmap_0=0,0,0,0,1,0,0,0\n",
            self.template._get_basic_defines(info, npu_id)
        )

    def test_get_gateway_defines(self):
        npu_id = 0
        # ipv6
        self.template.working_on_ipv6 = True
        info = self.common_info("", "", "", "2001::1")
        self.assertEqual("IPv6gateway_0=2001::1\n", self.template._get_gateway_defines(info, npu_id))

        # ipv4
        self.template.working_on_ipv6 = False
        info = self.common_info("", "", "", "192.168.1.0")
        self.assertEqual("gateway_0=192.168.1.0\n", self.template._get_gateway_defines(info, npu_id))

    def test_generate_hccn_conf(self):
        self.template.dscp_tc = ""
        self.template.bitmap = ""
        # Every expected netmask line below (including the IPv6 case) is "255.255.255.0";
        # pin it rather than relying on the shared template's current state.
        self.template.netmask = "255.255.255.0"

        basic_info = {
            0: self.common_info("192.168.1.1", "192.168.1.0/24", "192.168.1.1", "192.168.1.0"),
            1: self.common_info("192.168.1.2", "192.168.1.0/24", "192.168.1.2", "192.168.1.0"),
        }
        is_standard_npu_card = False
        # NOTE(review): set("192.168.1.0") is a set of single characters, not {"192.168.1.0"}.
        # The "add None via ..." route lines expected below may depend on this; confirm intent.
        gateway_set = set("192.168.1.0")
        mac_addresses = {
            0: "00:1A:2B:3C:4D:5E",
            1: "12:34:56:78:9A:BC"
        }
        self.template.working_on_ipv6 = False

        # 910A1 + ipv4
        self.template.npu_name = NPUCardName.A910A1
        self.assertEqual("""address_0=192.168.1.1
netmask_0=255.255.255.0
netdetect_0=192.168.1.1
gateway_0=192.168.1.0
ip_rule_0=add from 192.168.1.1 table 100
ip_route_0=add None via 192.168.1.0 dev eth0 table 100
address_1=192.168.1.2
netmask_1=255.255.255.0
netdetect_1=192.168.1.2
gateway_1=192.168.1.0
ip_rule_1=add from 192.168.1.2 table 101
ip_route_1=add None via 192.168.1.0 dev eth1 table 101
""",
                         self.template.generate_hccn_conf(basic_info, is_standard_npu_card, gateway_set, mac_addresses))

        # 910A2 + ipv4
        self.template.npu_name = NPUCardName.A910A2
        self.assertEqual("""address_0=192.168.1.1
netmask_0=255.255.255.0
netdetect_0=192.168.1.1
gateway_0=192.168.1.0
address_1=192.168.1.2
netmask_1=255.255.255.0
netdetect_1=192.168.1.2
gateway_1=192.168.1.0
""",
                         self.template.generate_hccn_conf(basic_info, is_standard_npu_card, gateway_set, mac_addresses))

        # 910A3 + ipv4
        self.template.npu_name = NPUCardName.A910A3
        self.assertEqual("""address_0=192.168.1.1
netmask_0=255.255.255.0
netdetect_0=192.168.1.1
gateway_0=192.168.1.0
ip_rule_0=add from 192.168.1.1 table 100
ip_route_0=add None via 192.168.1.0 dev eth0 table 100
address_1=192.168.1.2
netmask_1=255.255.255.0
netdetect_1=192.168.1.2
gateway_1=192.168.1.0
ip_rule_1=add from 192.168.1.2 table 101
ip_route_1=add None via 192.168.1.0 dev eth1 table 101
""",
                         self.template.generate_hccn_conf(basic_info, is_standard_npu_card, gateway_set, mac_addresses))

        # 910A2 + ipv6
        self.template.working_on_ipv6 = True
        self.template.npu_name = NPUCardName.A910A2
        basic_info = {
            0: self.common_info("2001::1", "", "2001::1", "2001::1"),
            1: self.common_info("2001::2", "", "2001::2", "2001::1"),
        }
        # NOTE(review): set of characters again, as above.
        gateway_set = set("2001::1")
        self.assertEqual("""IPv6address_0=2001::1
IPv6netmask_0=255.255.255.0
IPv6netdetect_0=2001::1
IPv6gateway_0=2001::1
IPv6address_1=2001::2
IPv6netmask_1=255.255.255.0
IPv6netdetect_1=2001::2
IPv6gateway_1=2001::1
""",
                         self.template.generate_hccn_conf(basic_info, is_standard_npu_card, gateway_set, mac_addresses))


class TestHccn(TestBaseHCCN):
    """Tests for ``HCCN.ipv6_support``: with IPv6 only the 910A2 card is accepted."""

    def test_ipv6_support(self):
        expected_msg = "Current NPU Do not support HCCN"

        # IPv6 on a 910A3 or 910A1 card is refused via fail_json.
        self.hccn.working_on_ipv6 = True
        for refused_card in (NPUCardName.A910A3, NPUCardName.A910A1):
            self.hccn.npu_name = refused_card
            with self.assertRaises(AnsibleModule.FailJson) as ctx:
                self.hccn.ipv6_support()
            failure = self.parse_context_exception(ctx)
            self.assertTrue(expected_msg in failure.get("msg"))

        # IPv6 on a 910A2 card passes.
        self.hccn.npu_name = NPUCardName.A910A2
        self.hccn.ipv6_support()

        # With IPv4, both 910A1 and 910A3 pass.
        self.hccn.working_on_ipv6 = False
        for accepted_card in (NPUCardName.A910A1, NPUCardName.A910A3):
            self.hccn.npu_name = accepted_card
            self.hccn.ipv6_support()