"""Unit tests for FSDP checkpoint load utility helpers."""

import os

import pytest


class TestChunkList:
    @pytest.mark.parametrize(
        "items,chunk_size,expected",
        [
            ([], 1, [[]]),
            ([], 3, [[], [], []]),
            ([1], 1, [[1]]),
            ([1], 2, [[1], []]),
            ([1, 2], 1, [[1, 2]]),
            ([1, 2], 2, [[1], [2]]),
            ([1, 2], 3, [[1], [2], []]),
            ([1, 2, 3], 2, [[1, 2], [3]]),
            ([1, 2, 3, 4], 2, [[1, 2], [3, 4]]),
            ([1, 2, 3, 4, 5], 2, [[1, 2, 3], [4, 5]]),
            ([1, 2, 3, 4, 5], 3, [[1, 2], [3, 4], [5]]),
            ([1, 2, 3, 4, 5, 6, 7], 3, [[1, 2, 3], [4, 5], [6, 7]]),
            ([1, 2, 3, 4, 5, 6, 7], 4, [[1, 2], [3, 4], [5, 6], [7]]),
            (list(range(10)), 3, [list(range(4)), list(range(4, 7)), list(range(7, 10))]),
            (list(range(10)), 4, [list(range(3)), list(range(3, 6)), list(range(6, 8)), list(range(8, 10))]),
            (list(range(12)), 5, [list(range(3)), list(range(3, 6)), list(range(6, 8)), list(range(8, 10)), list(range(10, 12))]),
        ],
    )
    def test_chunk_list_balances_remainder_to_earlier_chunks(self, items, chunk_size, expected):
        pytest.importorskip("torch")
        from mindspeed_mm.fsdp.checkpoint.load_utils import chunk_list

        assert chunk_list(items, chunk_size) == expected

    @pytest.mark.parametrize(
        "chunk_size",
        [1, 2, 3, 4, 5, 6, 7, 8],
    )
    def test_chunk_list_keeps_original_order(self, chunk_size):
        pytest.importorskip("torch")
        from mindspeed_mm.fsdp.checkpoint.load_utils import chunk_list

        items = [f"param_{idx}" for idx in range(17)]
        chunks = chunk_list(items, chunk_size)

        assert [item for chunk in chunks for item in chunk] == items

    @pytest.mark.parametrize(
        "length,chunk_size",
        [
            (0, 1),
            (0, 5),
            (1, 4),
            (2, 8),
            (5, 2),
            (5, 7),
            (16, 4),
            (17, 4),
            (31, 6),
            (64, 9),
        ],
    )
    def test_chunk_list_returns_requested_number_of_chunks(self, length, chunk_size):
        pytest.importorskip("torch")
        from mindspeed_mm.fsdp.checkpoint.load_utils import chunk_list

        chunks = chunk_list(list(range(length)), chunk_size)

        assert len(chunks) == chunk_size

    @pytest.mark.parametrize(
        "length,chunk_size",
        [
            (1, 2),
            (2, 3),
            (3, 2),
            (5, 2),
            (5, 3),
            (5, 4),
            (9, 4),
            (10, 4),
            (11, 4),
            (15, 6),
            (29, 7),
        ],
    )
    def test_chunk_list_chunk_lengths_differ_by_at_most_one(self, length, chunk_size):
        pytest.importorskip("torch")
        from mindspeed_mm.fsdp.checkpoint.load_utils import chunk_list

        chunks = chunk_list(list(range(length)), chunk_size)
        lengths = [len(chunk) for chunk in chunks]

        assert max(lengths) - min(lengths) <= 1

    def test_chunk_list_raises_when_chunk_size_is_zero(self):
        pytest.importorskip("torch")
        from mindspeed_mm.fsdp.checkpoint.load_utils import chunk_list

        with pytest.raises(ZeroDivisionError):
            chunk_list([1, 2, 3], 0)


class TestParamInfo:
    def test_param_info_defaults_to_empty_metadata(self):
        pytest.importorskip("torch")
        from mindspeed_mm.fsdp.checkpoint.load_utils import ParamInfo

        info = ParamInfo()

        assert info.name is None
        assert info.shape is None
        assert info.dtype is None
        assert info.prefix is None

    def test_param_info_stores_tensor_metadata(self):
        torch = pytest.importorskip("torch")
        from mindspeed_mm.fsdp.checkpoint.load_utils import ParamInfo

        info = ParamInfo(
            name="model.layers.0.weight",
            shape=torch.Size([2, 3]),
            dtype=torch.float32,
            prefix="model",
        )

        assert info.name == "model.layers.0.weight"
        assert info.shape == torch.Size([2, 3])
        assert info.dtype is torch.float32
        assert info.prefix == "model"

    def test_param_info_equality_uses_dataclass_value_semantics(self):
        torch = pytest.importorskip("torch")
        from mindspeed_mm.fsdp.checkpoint.load_utils import ParamInfo

        left = ParamInfo("weight", torch.Size([4]), torch.bfloat16, "optimizer")
        right = ParamInfo("weight", torch.Size([4]), torch.bfloat16, "optimizer")
        different = ParamInfo("bias", torch.Size([4]), torch.bfloat16, "optimizer")

        assert left == right
        assert left != different