import importlib.util
import os
import sys
import types
import unittest
from unittest.mock import patch


_project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
_patcher_dir = os.path.join(_project_root, "mx_driving", "patcher")
_PATCHER_MODULE_NAMES = [
    "mx_driving.patcher",
    "mx_driving.patcher.patcher_logger",
    "mx_driving.patcher.reporting",
    "mx_driving.patcher.version",
    "mx_driving.patcher.patch",
    "mx_driving.patcher.patcher",
    "mx_driving.patcher.legacy",
]


def _load_module_from_file(module_name: str, file_path: str):
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module


def _backup_patcher_modules():
    return {name: sys.modules.get(name) for name in _PATCHER_MODULE_NAMES}


def _restore_patcher_modules(backup):
    for name in _PATCHER_MODULE_NAMES:
        sys.modules.pop(name, None)
    for name, module in backup.items():
        if module is not None:
            sys.modules[name] = module


def _load_patcher_modules():
    patcher_logger_module = _load_module_from_file(
        "mx_driving.patcher.patcher_logger",
        os.path.join(_patcher_dir, "patcher_logger.py"),
    )
    _load_module_from_file(
        "mx_driving.patcher.reporting",
        os.path.join(_patcher_dir, "reporting.py"),
    )
    _load_module_from_file(
        "mx_driving.patcher.version",
        os.path.join(_patcher_dir, "version.py"),
    )
    patch_module = _load_module_from_file(
        "mx_driving.patcher.patch",
        os.path.join(_patcher_dir, "patch.py"),
    )
    patcher_module = _load_module_from_file(
        "mx_driving.patcher.patcher",
        os.path.join(_patcher_dir, "patcher.py"),
    )
    legacy_module = _load_module_from_file(
        "mx_driving.patcher.legacy",
        os.path.join(_patcher_dir, "legacy.py"),
    )
    patcher_init_module = _load_module_from_file(
        "mx_driving.patcher",
        os.path.join(_patcher_dir, "__init__.py"),
    )
    return {
        "AtomicPatch": patch_module.AtomicPatch,
        "Patch": patch_module.Patch,
        "Patcher": patcher_module.Patcher,
        "LegacyPatcherBuilder": legacy_module.LegacyPatcherBuilder,
        "LegacyPatchWrapper": legacy_module.LegacyPatchWrapper,
        "patcher_logger": patcher_logger_module.patcher_logger,
        "patcher_init_module": patcher_init_module,
    }


def _clear_logger_buffers(patcher_logger):
    patcher_logger._applied_patches.clear()
    patcher_logger._skipped_patches.clear()
    patcher_logger._failed_patches.clear()
    patcher_logger._skipped_modules.clear()
    patcher_logger._injected_imports.clear()


class _BaseIsolatedPatcherTest(unittest.TestCase):
    def setUp(self):
        self._module_backup = _backup_patcher_modules()
        modules = _load_patcher_modules()
        self.AtomicPatch = modules["AtomicPatch"]
        self.Patch = modules["Patch"]
        self.Patcher = modules["Patcher"]
        self.LegacyPatcherBuilder = modules["LegacyPatcherBuilder"]
        self.LegacyPatchWrapper = modules["LegacyPatchWrapper"]
        self.patcher_logger = modules["patcher_logger"]
        self._patcher_init_module = modules["patcher_init_module"]
        _clear_logger_buffers(self.patcher_logger)

    def tearDown(self):
        _clear_logger_buffers(self.patcher_logger)
        _restore_patcher_modules(self._module_backup)


class TestPatchNameCompatibility(_BaseIsolatedPatcherTest):
    def setUp(self):
        super().setUp()
        self.mock_module = types.ModuleType("patch_name_compat_test")
        self.mock_module.func1 = lambda x: x
        self.mock_module.func2 = lambda x: x
        sys.modules["patch_name_compat_test"] = self.mock_module

    def tearDown(self):
        sys.modules.pop("patch_name_compat_test", None)
        super().tearDown()

    def test_patch_name_defaults_to_class_name(self):
        class AutoNamedPatch(self.Patch):
            @classmethod
            def patches(cls, options=None):
                return [self.AtomicPatch("patch_name_compat_test.func1", lambda x: x * 2)]

        self.assertEqual(AutoNamedPatch.name, "AutoNamedPatch")

        patcher = self.Patcher()
        patcher.add(AutoNamedPatch).apply()

        self.assertEqual(self.mock_module.func1(5), 10)

    def test_explicit_patch_name_is_preserved(self):
        class ExplicitlyNamedPatch(self.Patch):
            name = "stable_patch_id"

            @classmethod
            def patches(cls, options=None):
                return [self.AtomicPatch("patch_name_compat_test.func1", lambda x: x * 3)]

        self.assertEqual(ExplicitlyNamedPatch.name, "stable_patch_id")

    def test_disable_accepts_patch_classes_and_patch_instances(self):
        class AutoDisabledPatch(self.Patch):
            @classmethod
            def patches(cls, options=None):
                return [self.AtomicPatch("patch_name_compat_test.func1", lambda x: x * 5)]

        direct_patch = self.AtomicPatch("patch_name_compat_test.func2", lambda x: x * 7)

        patcher = self.Patcher()
        patcher.add(AutoDisabledPatch, direct_patch)
        patcher.disable(AutoDisabledPatch, direct_patch)
        patcher.apply()

        self.assertIn("AutoDisabledPatch", patcher._blacklist)
        self.assertIn("patch_name_compat_test.func2", patcher._blacklist)
        self.assertEqual(self.mock_module.func1(5), 5)
        self.assertEqual(self.mock_module.func2(5), 5)


class TestLegacyUnknownLabelCompatibility(_BaseIsolatedPatcherTest):
    def setUp(self):
        super().setUp()
        self.mock_module = types.ModuleType("legacy_label_test")
        self.mock_module.func = lambda x: x
        sys.modules["legacy_label_test"] = self.mock_module
        self._legacy_name_to_class = self._patcher_init_module._LEGACY_NAME_TO_CLASS
        self._legacy_name_to_class_backup = dict(self._legacy_name_to_class)
        self.LegacyPatchWrapper._migration_warning_shown = False

    def tearDown(self):
        self._legacy_name_to_class.clear()
        self._legacy_name_to_class.update(self._legacy_name_to_class_backup)
        sys.modules.pop("legacy_label_test", None)
        super().tearDown()

    def test_custom_legacy_patch_uses_its_own_group_name(self):
        def custom_legacy(module, options):
            module.func = lambda x: x * 3

        builder = self.LegacyPatcherBuilder()
        builder.add_module_patch("legacy_label_test", self.Patch(custom_legacy))

        with patch.object(self.patcher_logger, "flush_summary", return_value=None):
            builder.build().apply()

        self.assertEqual(self.mock_module.func(5), 15)
        self.assertEqual(len(self.patcher_logger._applied_patches), 1)
        info = self.patcher_logger._applied_patches[0]
        self.assertEqual(info.patch_name, "custom_legacy")
        self.assertEqual(info.target, "custom_legacy")
        self.assertEqual(info.package, "legacy_label_test")

    def test_known_legacy_patch_expands_to_detailed_children(self):
        class DetailedPatch(self.Patch):
            name = "detailed_patch"

            @classmethod
            def patches(cls, options=None):
                multiplier = (options or {}).get("multiplier", 4)
                return [
                    self.AtomicPatch("legacy_label_test.func", lambda x, m=multiplier: x * m),
                ]

        def mapped_legacy(module, options):
            raise AssertionError("Known legacy patch should expand via its Patch class")

        self._legacy_name_to_class["mapped_legacy"] = DetailedPatch

        builder = self.LegacyPatcherBuilder()
        builder.add_module_patch(
            "legacy_label_test",
            self.Patch(mapped_legacy, options={"multiplier": 4}),
        )

        with patch.object(self.patcher_logger, "flush_summary", return_value=None):
            builder.build().apply()

        self.assertEqual(self.mock_module.func(5), 20)
        self.assertEqual(len(self.patcher_logger._applied_patches), 1)
        info = self.patcher_logger._applied_patches[0]
        self.assertEqual(info.patch_name, "detailed_patch")
        self.assertEqual(info.target, "legacy_label_test.func")
        self.assertEqual(info.package, "legacy_label_test")

    def test_known_and_custom_legacy_patches_keep_legacy_order(self):
        class MappedSecondPatch(self.Patch):
            name = "mapped_second_patch"

            @classmethod
            def patches(cls, options=None):
                return [self.AtomicPatch("legacy_label_test.func", lambda x: "mapped-second")]

        def custom_first(module, options):
            module.func = lambda x: "custom-first"

        def mapped_second(module, options):
            raise AssertionError("Known legacy patch should expand via its Patch class")

        self._legacy_name_to_class["mapped_second"] = MappedSecondPatch

        builder = self.LegacyPatcherBuilder()
        builder.add_module_patch(
            "legacy_label_test",
            self.Patch(custom_first, priority=1),
            self.Patch(mapped_second, priority=2),
        )

        with patch.object(self.patcher_logger, "flush_summary", return_value=None):
            builder.build().apply()

        self.assertEqual(self.mock_module.func(0), "mapped-second")
        self.assertEqual(
            [info.patch_name for info in self.patcher_logger._applied_patches],
            ["custom_first", "mapped_second_patch"],
        )


if __name__ == "__main__":
    unittest.main()