import re
import unittest

from torch.testing._internal.common_utils import TestCase, run_tests
import torch
import torch_npu
from torch_npu.utils.collect_env import get_cann_version as get_cann_version_from_env
from torch_npu.npu.utils import get_cann_version, _is_gte_cann_version


class TestCANNversion(TestCase):
    """Checks the CANN / driver version helpers from ``torch_npu.npu.utils``.

    Version thresholds are compared on parsed integer fields instead of with
    plain string comparison, which misorders multi-digit components
    (as strings, ``"10.0.0" < "8.1.RC1"`` and ``"3.0.0" >= "25."``).
    """

    # Accepted formats of the version returned by get_cann_version(module="CANN").
    # Raw strings: "\." inside a plain literal is an invalid escape sequence
    # (DeprecationWarning, and SyntaxWarning from Python 3.12 on).
    _CANN_VERSION_PATTERNS = (
        r"([0-9]+)\.([0-9]+)\.RC([0-9]+)$",
        r"([0-9]+)\.([0-9]+)\.([0-9]+)$",
        r"([0-9]+)\.([0-9]+)\.T([0-9]+)$",
        r"([0-9]+)\.([0-9]+)\.RC([0-9]+)\.alpha([0-9]+)$",
        r"([0-9]+)\.([0-9]+)\.([0-9]+)-alpha\.([0-9]+)$",
        r"([0-9]+)\.([0-9]+)\.([0-9]+)\.alpha([0-9]+)$",
    )

    # Accepted formats of the version returned by get_cann_version(module="DRIVER"),
    # matched case-insensitively after any ".B<digits>" build suffix is stripped.
    _DRIVER_VERSION_PATTERNS = (
        r"([0-9]+)\.([0-9]+)\.RC([0-9]+)$",
        r"([0-9]+)\.([0-9]+)\.([0-9]+)$",
        r"([0-9]+)\.([0-9]+)\.RC([0-9]+)\.([0-9]+)$",
        r"([0-9]+)\.([0-9]+)\.([0-9]+)\.([0-9]+)$",
        r"([0-9]+)\.([0-9]+)\.T([0-9]+)$",
        r"([0-9]+)\.([0-9]+)\.RC([0-9]+)\.beta([0-9]+)$",
        r"([0-9]+)\.([0-9]+)\.RC([0-9]+)\.alpha([0-9]+)$",
    )

    @staticmethod
    def _major_minor(version):
        """Return the leading ``(major, minor)`` integers of ``version``, or None.

        None is returned when ``version`` does not begin with two
        dot-separated decimal fields (for example an empty string).
        """
        match = re.match(r"([0-9]+)\.([0-9]+)(?:\.|$)", version)
        if match is None:
            return None
        return int(match.group(1)), int(match.group(2))

    @classmethod
    def _is_at_least_8_1(cls, version):
        """Return True when the major.minor of ``version`` is at least 8.1.

        This is the "8.1.RC1" threshold these tests use; comparing integers
        avoids the string-ordering pitfalls of ``version >= "8.1.RC1"``.
        """
        key = cls._major_minor(version)
        return key is not None and key >= (8, 1)

    @unittest.skip("ci error, beta version")
    def test_get_cann_version(self):
        version_env = get_cann_version_from_env()
        version = get_cann_version(module="CANN")
        # Env values starting with "CANN" are not numeric versions; the
        # format checks below only apply otherwise.
        if not version_env.startswith("CANN"):
            if self._is_at_least_8_1(version_env):
                is_match = any(re.match(pattern, version) for pattern in self._CANN_VERSION_PATTERNS)
                self.assertTrue(is_match, f"The env version is {version_env}. The format of cann version {version} is invalid.")
            else:
                self.assertTrue(version == "", "When version_env < '8.1.RC1', the result of get_cann_version is not right.")

        # An unknown module name must yield an empty version string.
        version = get_cann_version(module="CAN")
        self.assertTrue(version == "", "When module is invalid, the result of get_cann_version is not right.")

    def test_get_driver_version(self):
        try:
            version = get_cann_version(module="DRIVER")
        except UnicodeDecodeError:
            print("Failed to get driver version. Your driver version is too old, or the environment information about the driver may be incomplete.")
            return
        # Strip a trailing build suffix, e.g. "25.0.RC1.B010" -> "25.0.RC1".
        if re.match(r"([0-9]+)\.([0-9]+)\.RC([0-9]+)\.B([0-9]+)$", version, re.IGNORECASE):
            version = re.sub(r"\.B([0-9]+)$", "", version, flags=re.IGNORECASE)
        leading = re.match(r"([0-9]+)\.", version)
        if leading is None:
            # Not a "<digits>." version string (e.g. empty): nothing to validate.
            return
        if int(leading.group(1)) >= 25:
            is_match = any(re.match(pattern, version, re.IGNORECASE) for pattern in self._DRIVER_VERSION_PATTERNS)
            self.assertTrue(is_match, f"The format of driver version {version} is invalid.")
        else:
            # ``version`` is non-empty here (it matched above), so this always
            # fails: a numeric driver version below 25 is treated as an error.
            self.assertTrue(version == "", "When the driver version < 25, the result of get_cann_version is not right.")

    @unittest.skip("ci error, beta version")
    def test_compare_cann_version(self):
        version_env = get_cann_version_from_env()
        if not version_env.startswith("CANN") and self._is_at_least_8_1(version_env):
            result = _is_gte_cann_version("8.1.RC1", module="CANN")
            self.assertTrue(result, f"The env version is {version_env}, the result from _is_gte_cann_version is False")

            tags = get_cann_version(module="CANN")
            # Parse the whole leading field (not just its first character) so
            # multi-digit majors such as "10" are handled; every version with
            # the next major number must then compare as newer than the env.
            major = int(tags.split(".")[0]) + 1
            result1 = _is_gte_cann_version(f"{major}.0.0", module="CANN")
            result2 = _is_gte_cann_version(f"{major}.0.T10", module="CANN")
            result3 = _is_gte_cann_version(f"{major}.0.RC1.alpha001", module="CANN")
            self.assertTrue(not result1 and not result2 and not result3, "the result from _is_gte_cann_version is not right.")

        else:
            with self.assertRaisesRegex(RuntimeError,
                    "When the version 7.0.0 is less than \"8.1.RC1\", this function is not supported"):
                _is_gte_cann_version("7.0.0", "CANN")


if __name__ == "__main__":
    # Run this module's tests through torch's common test runner (imported from
    # torch.testing._internal.common_utils).
    run_tests()