import os
import sys
import unittest
from unittest.mock import patch
from ascend_deployer.module_utils.inventory_file import (
    Var, Host, HostParams, InventoryFile, ConfigrationError, IPRange, StrTool
    )


class TestVar(unittest.TestCase):
    """Tests for Var.to_string, which renders an inventory variable as ``key=value``."""

    @patch("sys.exit")
    def test_to_string(self, _exit_mock):
        """Each (name, value) pair renders as ``name=value``, even for an empty value."""
        # (variable name, raw value, expected rendered line)
        cases = (
            ("netmask", "255.255.255.0", "netmask=255.255.255.0"),
            ("gateways", "", "gateways="),
            ("roce_port", "4791", "roce_port=4791"),
            ("common_network", "\"0.0.0.0/0\"", "common_network=\"0.0.0.0/0\""),
        )
        for name, raw, rendered in cases:
            var = Var(name, raw)
            self.assertEqual(rendered, var.to_string())


class TestHost(unittest.TestCase):
    """Tests for Host.to_string, which joins a host IP and its parameter string."""

    @patch("sys.exit")
    def test_to_string(self, _exit_mock):
        """The rendered line is the IP, a single space, then the parameters verbatim."""
        # (host IP, ansible parameter string, expected inventory line)
        cases = [
            (
                "10.10.1.1",
                "ansible_connection='local' ansible_ssh_user='root'",
                "10.10.1.1 ansible_connection='local' ansible_ssh_user='root'",
            ),
        ]
        for address, host_params, line in cases:
            host = Host(address, host_params)
            self.assertEqual(line, host.to_string())


class TestIPRange(unittest.TestCase):
    """Tests for IPRange.expand_ip_range."""

    @patch("sys.exit")
    def test_expand_ip_range(self, mock_exit):
        """Expand ``start-end`` ranges by step; reject malformed or inverted ranges."""
        # (ip range, step length, expected expanded IPs).  As the expectations
        # show, the range's end IP is appended even when the step skips past it.
        test_cases = [
            ("10.10.0.1-10.10.0.4", 1, ["10.10.0.1", "10.10.0.2", "10.10.0.3", "10.10.0.4"]),
            ("10.10.0.1-10.10.0.4", 4, ["10.10.0.1", "10.10.0.4"]),
            ("10.10.0.1-10.10.0.6", 2, ["10.10.0.1", "10.10.0.3", "10.10.0.5", "10.10.0.6"]),
        ]

        for ip_range, step_len, expect in test_cases:
            # subTest keeps going after a failure and reports which case broke.
            with self.subTest(ip_range=ip_range, step_len=step_len):
                self.assertEqual(expect, IPRange(ip_range, step_len).expand_ip_range())

        # (ip range, substring expected in the ConfigrationError message)
        error_cases = [
            ("1-1", "Parse ip range"),
            ("10.10.0.10-10.10.0.4", "less than to end IP"),
        ]
        for ip_range, error_msg in error_cases:
            with self.subTest(ip_range=ip_range):
                with self.assertRaises(ConfigrationError) as context:
                    IPRange(ip_range, 1).expand_ip_range()
                # assertIn shows both strings on failure, unlike assertTrue(x in y).
                self.assertIn(error_msg, str(context.exception))


class TestHostParam(unittest.TestCase):
    """Tests for HostParams: parsing, step_len handling and per-IP expansion."""

    @patch("sys.exit")
    def test_convert_to_dict(self, mock_exit):
        """Space-separated ``key=value`` pairs become a dict; quotes in values are kept."""
        test_cases = [
            ('ansible_ssh_user="root" ansible_ssh_pass="test1234" step_len=3', {"ansible_ssh_user": "\"root\"", "ansible_ssh_pass": "\"test1234\"", "step_len": "3"}),
        ]

        for params, expect in test_cases:
            with self.subTest(params=params):
                self.assertEqual(expect, HostParams(params)._convert_to_dict(params))

    @patch("sys.exit")
    def test_get_step_len(self, mock_exit):
        """step_len parses as a positive int; zero, negatives and non-ints are rejected."""
        test_cases = [
            ('step_len=3', 3),
            ('step_len=1', 1),
            ('step_len=5', 5),
        ]

        for params, expect in test_cases:
            with self.subTest(params=params):
                self.assertEqual(expect, HostParams(params).get_step_len())

        # (params, substring expected in the ConfigrationError message)
        error_cases = [
            ('step_len=0', "bigger than 0"),
            ('step_len=-1', "bigger than 0"),
            ('step_len=ttt', "must be int"),
        ]

        for params, error_msg in error_cases:
            with self.subTest(params=params):
                with self.assertRaises(ConfigrationError) as context:
                    HostParams(params).get_step_len()
                self.assertIn(error_msg, str(context.exception))

    @patch("sys.exit")
    def test_remove_step_len(self, mock_exit):
        """remove_step_len strips step_len from both params_dict and the params string."""
        # (input params, expected params_dict, expected params string).  The
        # second case has no step_len and must come through unchanged.
        test_cases = [
            ('ansible_ssh_user="root" ansible_ssh_pass="test1234" step_len=3', {"ansible_ssh_user": "\"root\"", "ansible_ssh_pass": "\"test1234\""}, 'ansible_ssh_user="root" ansible_ssh_pass="test1234"'),
            ('ansible_ssh_user="root" ansible_ssh_pass="test1234"', {"ansible_ssh_user": "\"root\"", "ansible_ssh_pass": "\"test1234\""}, 'ansible_ssh_user="root" ansible_ssh_pass="test1234"'),
        ]

        for params, expect1, expect2 in test_cases:
            with self.subTest(params=params):
                host_param = HostParams(params)
                host_param.remove_step_len()
                self.assertEqual(expect1, host_param.params_dict)
                self.assertEqual(expect2, host_param.params)

    @patch("sys.exit")
    def test_generate_new_params_str_list(self, mock_exit):
        """One inventory line per IP, with ``{ip}``/``{index}`` placeholders rendered.

        Per the expectations below, ``{index}`` is 1-based, brace expressions
        such as ``{index+1}`` are evaluated, and step_len is dropped from the
        generated lines.
        """
        ips = ["10.10.10.1", "10.10.10.2"]
        test_cases = [
            (
                'ansible_ssh_user="root" ansible_ssh_pass="test1234" step_len=3',
                ['10.10.10.1 ansible_ssh_user="root" ansible_ssh_pass="test1234"',
                 '10.10.10.2 ansible_ssh_user="root" ansible_ssh_pass="test1234"']
            ),
            (
                'ansible_ssh_user="root" ansible_ssh_pass="test1234"',
                ['10.10.10.1 ansible_ssh_user="root" ansible_ssh_pass="test1234"',
                 '10.10.10.2 ansible_ssh_user="root" ansible_ssh_pass="test1234"']
            ),
            (
                'set_hostname="master-{index}"',
                ['10.10.10.1 set_hostname="master-1"', '10.10.10.2 set_hostname="master-2"']
            ),
            (
                'set_hostname="master-{str(index+1)+\'x\'}"',
                ['10.10.10.1 set_hostname="master-2x"', '10.10.10.2 set_hostname="master-3x"']
            ),
            (
                'local_ip_port="{ip}-{index}:8080"',
                ['10.10.10.1 local_ip_port="10.10.10.1-1:8080"', '10.10.10.2 local_ip_port="10.10.10.2-2:8080"']
            ),
            (
                'local_ip_port="{ip}-{index+1}:8080"',
                ['10.10.10.1 local_ip_port="10.10.10.1-2:8080"', '10.10.10.2 local_ip_port="10.10.10.2-3:8080"']
            ),
            (
                'local_ip_port="{ip}-{str(index+int("20"))+"x"}:8080"',
                ['10.10.10.1 local_ip_port="10.10.10.1-21x:8080"', '10.10.10.2 local_ip_port="10.10.10.2-22x:8080"']
            ),
        ]

        for params, expect in test_cases:
            with self.subTest(params=params):
                self.assertEqual(expect, HostParams(params).generate_new_params_str_list(ips))


class TestInventoryFile(unittest.TestCase):
    """Smoke tests for InventoryFile."""

    def setUp(self):
        # Build a fresh InventoryFile for every test so no state is shared.
        self.inventory_file = InventoryFile()

    def test_get_parsed_inventory_file_path(self):
        """get_parsed_inventory_file_path can be called without raising."""
        # No assertion on the return value: this only checks that the call succeeds.
        inventory = self.inventory_file
        inventory.get_parsed_inventory_file_path()


class TestStrTool(unittest.TestCase):
    """Tests for StrTool.safe_eval, the restricted expression evaluator."""

    @patch("sys.exit")
    def test_safe_eval(self, mock_exit):
        """Plain string/int arithmetic evaluates; imports, calls and subscripts raise."""
        test_cases = [
            ("'master-'+str(1+20)+'x'", "master-21x"),
            ("'master-'+'d'+'d'+str(1+1+1+1)", "master-dd4"),
            ("'master.'+str(1+1)", "master.2"),
        ]
        for param, expect in test_cases:
            with self.subTest(param=param):
                self.assertEqual(expect, StrTool.safe_eval(param))

        # Each entry must be its own literal: a missing comma makes Python
        # silently concatenate adjacent strings into one malformed case, so
        # neither half is actually tested on its own.
        error_cases = [
            "__import__('os')",
            "'None'[0]",
            "open('/etc')",
            "'ddd'.upper()",
            "'master-'+{open('/etc/passwd')}",
        ]
        for case in error_cases:
            with self.subTest(case=case):
                with self.assertRaises(Exception):
                    StrTool.safe_eval(case)


if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()