"""Unit tests for FSDP utility registries, dtype helpers, and dataclass helpers."""

import os
from dataclasses import dataclass, field, is_dataclass

import pytest


class TestDtype:
    """Checks for ``get_dtype``: short-name lookup and rejection of unknown names."""

    # (short name accepted by get_dtype, attribute of the torch module it must resolve to)
    _SUPPORTED_CASES = [
        ("fp16", "float16"),
        ("bf16", "bfloat16"),
        ("fp32", "float32"),
        ("fp64", "float64"),
        ("int8", "int8"),
        ("int16", "int16"),
        ("int32", "int32"),
        ("int64", "int64"),
    ]

    # Inputs that must be refused: empty string, torch-style long names, wrong case,
    # embedded whitespace, dtypes outside the supported table, and a non-string.
    _UNSUPPORTED_CASES = [
        "",
        "float16",
        "float32",
        "FP16",
        "bf 16",
        "uint8",
        "bool",
        None,
    ]

    @pytest.mark.parametrize("name,torch_attr", _SUPPORTED_CASES)
    def test_get_dtype_maps_supported_names(self, name, torch_attr):
        """Each supported short name resolves to the identical ``torch`` dtype object."""
        torch_module = pytest.importorskip("torch")
        from mindspeed_mm.fsdp.utils.dtype import get_dtype

        resolved = get_dtype(name)
        expected = getattr(torch_module, torch_attr)
        assert resolved is expected

    @pytest.mark.parametrize("name", _UNSUPPORTED_CASES)
    def test_get_dtype_rejects_unsupported_names(self, name):
        """Anything outside the supported table raises ``ValueError`` mentioning the dtype."""
        pytest.importorskip("torch")
        from mindspeed_mm.fsdp.utils.dtype import get_dtype

        expected_error = pytest.raises(ValueError, match="Unsupported dtype")
        with expected_error:
            get_dtype(name)


class TestRegister:
    """Behaviour of ``Register``: lookup, decorator return value, duplicate and missing ids."""

    @staticmethod
    def _fresh_registry():
        """Skip when torch is unavailable; otherwise return a newly constructed ``Register()``."""
        pytest.importorskip("torch")
        from mindspeed_mm.fsdp.utils.register import Register

        return Register()

    def test_register_and_get_returns_original_object(self):
        """A class registered via the decorator is returned unchanged by ``get``."""
        registry = self._fresh_registry()

        @registry.register("model")
        class Model:
            pass

        assert registry.get("model") is Model

    def test_register_decorator_returns_registered_object(self):
        """Applying the decorator by hand returns the very same callable it registered."""
        registry = self._fresh_registry()

        def fn():
            return "ok"

        register_fn = registry.register("fn")
        returned = register_fn(fn)

        assert returned is fn
        assert registry.get("fn") is fn

    def test_register_rejects_duplicate_id(self):
        """Registering a second object under an existing id raises ``KeyError``."""
        registry = self._fresh_registry()

        @registry.register("duplicate")
        class First:
            pass

        # Both the register(...) call and the decoration stay inside the context,
        # since either step may be the one that raises.
        with pytest.raises(KeyError, match="already registered"):

            @registry.register("duplicate")
            class Second:
                pass

    @pytest.mark.parametrize(
        "missing_id",
        [
            "missing",
            "",
            None,
            123,
        ],
    )
    def test_register_get_rejects_missing_id(self, missing_id):
        """Looking up an id that was never registered raises ``KeyError``."""
        registry = self._fresh_registry()

        expected_error = pytest.raises(KeyError, match="not registered")
        with expected_error:
            registry.get(missing_id)


class TestParamsUtils:
    def test_create_nested_dataclass_builds_nested_defaults(self):
        pytest.importorskip("torch")
        from mindspeed_mm.fsdp.params.utils import create_nested_dataclass

        Config = create_nested_dataclass(
            "Config",
            {
                "model": {
                    "hidden_size": 4096,
                    "dropout": 0.1,
                },
                "enabled": True,
                "names": ["q_proj", "v_proj"],
                "tags": {"vision", "language"},
            },
        )

        cfg = Config()

        assert is_dataclass(Config)
        assert cfg.model.hidden_size == 4096
        assert cfg.model.dropout == pytest.approx(0.1)
        assert cfg.enabled is True
        assert cfg.names == []
        assert cfg.tags == set()

    def test_create_nested_dataclass_uses_independent_mutable_defaults(self):
        pytest.importorskip("torch")
        from mindspeed_mm.fsdp.params.utils import create_nested_dataclass

        Config = create_nested_dataclass(
            "Config",
            {
                "items": [],
                "labels": set(),
                "nested": {"values": []},
            },
        )

        left = Config()
        right = Config()

        left.items.append("left")
        left.labels.add("tag")
        left.nested.values.append(1)

        assert right.items == []
        assert right.labels == set()
        assert right.nested.values == []

    def test_allow_extra_fields_adds_unknown_scalar_attributes(self, monkeypatch):
        pytest.importorskip("torch")
        import mindspeed_mm.fsdp.params.utils as params_utils

        monkeypatch.setattr(params_utils, "print_rank", lambda *args, **kwargs: None)

        @params_utils.allow_extra_fields
        @dataclass
        class Config:
            name: str = "default"

        cfg = Config(name="custom", lr=0.01, enabled=True)

        assert cfg.name == "custom"
        assert cfg.lr == pytest.approx(0.01)
        assert cfg.enabled is True
        assert cfg._extra_fields == {"lr": 0.01, "enabled": True}

    def test_allow_extra_fields_adds_unknown_nested_dict_as_dataclass(self, monkeypatch):
        pytest.importorskip("torch")
        import mindspeed_mm.fsdp.params.utils as params_utils

        monkeypatch.setattr(params_utils, "print_rank", lambda *args, **kwargs: None)

        @params_utils.allow_extra_fields
        @dataclass
        class Config:
            name: str = "default"

        cfg = Config(
            name="custom",
            optimizer={
                "lr": 0.001,
                "betas": {
                    "beta1": 0.9,
                    "beta2": 0.95,
                },
            },
        )

        assert cfg.optimizer.lr == pytest.approx(0.001)
        assert cfg.optimizer.betas.beta1 == pytest.approx(0.9)
        assert cfg.optimizer.betas.beta2 == pytest.approx(0.95)
        assert cfg._extra_fields == {
            "optimizer": {
                "lr": 0.001,
                "betas": {
                    "beta1": 0.9,
                    "beta2": 0.95,
                },
            }
        }

    def test_instantiate_dataclass_recursively_instantiates_nested_dataclasses(self, monkeypatch):
        pytest.importorskip("torch")
        import mindspeed_mm.fsdp.params.utils as params_utils

        monkeypatch.setattr(params_utils, "print_rank", lambda *args, **kwargs: None)

        @params_utils.allow_extra_fields
        @dataclass
        class Inner:
            width: int = 1
            depth: int = 2

        @params_utils.allow_extra_fields
        @dataclass
        class Outer:
            inner: Inner = field(default_factory=Inner)
            name: str = "outer"

        cfg = params_utils.instantiate_dataclass(
            Outer,
            {
                "name": "configured",
                "inner": {
                    "width": 16,
                    "depth": 32,
                },
                "new_field": "accepted",
            },
        )

        assert isinstance(cfg, Outer)
        assert isinstance(cfg.inner, Inner)
        assert cfg.name == "configured"
        assert cfg.inner.width == 16
        assert cfg.inner.depth == 32
        assert cfg.new_field == "accepted"

    def test_instantiate_dataclass_returns_input_for_non_dataclass_type(self):
        pytest.importorskip("torch")
        from mindspeed_mm.fsdp.params.utils import instantiate_dataclass

        data = {"a": 1}

        assert instantiate_dataclass(dict, data) is data

    def test_instantiate_dataclass_wraps_type_hint_failures(self, monkeypatch):
        pytest.importorskip("torch")
        import mindspeed_mm.fsdp.params.utils as params_utils

        @dataclass
        class Config:
            value: "MissingType"

        with pytest.raises(RuntimeError, match="Failed to get type hints"):
            params_utils.instantiate_dataclass(Config, {"value": 1})


class TestRuntimeUtils:
    def test_configure_hsdp_gradient_sync_sets_both_flags(self):
        pytest.importorskip("torch")
        from mindspeed_mm.fsdp.utils.utils import configure_hsdp_gradient_sync

        class Model:
            def __init__(self):
                self.last_backward_values = []
                self.requires_all_reduce_values = []

            def set_is_last_backward(self, value):
                self.last_backward_values.append(value)

            def set_requires_all_reduce(self, value):
                self.requires_all_reduce_values.append(value)

        model = Model()

        configure_hsdp_gradient_sync(model, True)
        configure_hsdp_gradient_sync(model, False)

        assert model.last_backward_values == [True, False]
        assert model.requires_all_reduce_values == [True, False]

    def test_singleton_metaclass_reuses_instance(self):
        pytest.importorskip("torch")
        from mindspeed_mm.fsdp.utils.decorators import Singleton

        class Example(metaclass=Singleton):
            def __init__(self, value):
                self.value = value

        try:
            first = Example("first")
            second = Example("second")

            assert first is second
            assert second.value == "first"
        finally:
            Singleton._instances.pop(Example, None)