"""Unit tests for FSDP MoE checkpoint planner helper behavior."""

import os
from dataclasses import dataclass

import pytest


@dataclass(frozen=True)
class _FakeReadItem:
    """Immutable stand-in for the read item handed to ``get_chunk_readitem``.

    Only the fields that the tests in this file set or assert on are modelled.
    Being a frozen dataclass, instances reject attribute assignment.
    """

    # Per-dimension offsets; the tests populate this with ``torch.Size`` values.
    storage_offsets: object
    # Per-dimension lengths; the tests populate this with ``torch.Size`` values.
    lengths: object
    # Optional; stays ``None`` in tests that never inspect it.
    dest_offsets: object = None
    # Name string carried along with the item (e.g. "moe.experts.weight").
    fqn: str = "param"


class _FakeMesh:
    """Bare object that exposes nothing but a ``mesh_dim_names`` attribute."""

    def __init__(self, mesh_dim_names):
        # Stored exactly as given; the tests pass lists of strings such as
        # ["dp", "efsdp"].
        self.mesh_dim_names = mesh_dim_names


class _FakeDTensor:
    """Stand-in that the tests patch over ``moe_utils.DTensor``.

    It exposes ``device_mesh.mesh_dim_names`` and nothing else.
    """

    def __init__(self, mesh_dim_names):
        # Wrap the names in a _FakeMesh so they are reachable as
        # ``self.device_mesh.mesh_dim_names``.
        self.device_mesh = _FakeMesh(mesh_dim_names)


class _FakeModel:
    """Stand-in model that provides only ``named_parameters()``."""

    def __init__(self, params):
        # Kept by reference, not copied; the tests pass a list of
        # ``(name, parameter)`` tuples.
        self._params = params

    def named_parameters(self):
        """Return a new iterator over the stored ``params`` on each call."""
        return iter(self._params)


class TestGetChunkReadItem:
    """Tests for ``get_chunk_readitem``, driven through ``_FakeReadItem`` inputs."""

    @staticmethod
    def _require():
        """Return ``(torch, get_chunk_readitem)``.

        The calling test is skipped when torch cannot be imported. The project
        import is deferred until after that check.
        """
        torch = pytest.importorskip("torch")
        from mindspeed_mm.fsdp.checkpoint.moe_utils import get_chunk_readitem

        return torch, get_chunk_readitem

    @pytest.mark.parametrize(
        "offsets,lengths,ep_rank,expected_offsets",
        [
            ([0], [1], 0, [0]),
            ([0], [1], 1, [1]),
            ([0], [1], 2, [2]),
            ([3], [7], 0, [3]),
            ([3], [7], 1, [10]),
            ([3], [7], 4, [31]),
            ([0, 0], [2, 5], 0, [0, 0]),
            ([0, 0], [2, 5], 1, [2, 0]),
            ([1, 3], [4, 8], 2, [9, 3]),
            ([5, 6, 7], [10, 11, 12], 3, [35, 6, 7]),
        ],
    )
    def test_get_chunk_readitem_offsets_first_dimension_by_ep_rank(
        self, offsets, lengths, ep_rank, expected_offsets
    ):
        """Default call shifts only dimension 0 of ``storage_offsets``.

        Every row expects ``offsets[0] + ep_rank * lengths[0]`` in the first
        dimension and unchanged offsets elsewhere. ``lengths`` and
        ``dest_offsets`` must come back untouched on a distinct object.
        """
        torch, get_chunk_readitem = self._require()

        zero_dest = [0] * len(offsets)
        source_item = _FakeReadItem(
            storage_offsets=torch.Size(offsets),
            lengths=torch.Size(lengths),
            dest_offsets=torch.Size(zero_dest),
        )

        result = get_chunk_readitem(source_item, ep_rank)

        # The helper must hand back a different object, not the input itself.
        assert result is not source_item
        assert result.storage_offsets == torch.Size(expected_offsets)
        assert result.lengths == torch.Size(lengths)
        assert result.dest_offsets == torch.Size(zero_dest)

    @pytest.mark.parametrize(
        "operate_dim,expected_offsets",
        [
            (0, [10, 4, 6]),
            (1, [2, 24, 6]),
            (2, [2, 4, 36]),
        ],
    )
    def test_get_chunk_readitem_can_offset_selected_dimension(self, operate_dim, expected_offsets):
        """With ``ep_rank=2``, only the dimension named by ``operate_dim`` shifts.

        The expected shift in each row is twice that dimension's length.
        """
        torch, get_chunk_readitem = self._require()

        base_offsets = torch.Size([2, 4, 6])
        extents = torch.Size([4, 10, 15])
        source_item = _FakeReadItem(storage_offsets=base_offsets, lengths=extents)

        result = get_chunk_readitem(source_item, ep_rank=2, operate_dim=operate_dim)

        assert result.storage_offsets == torch.Size(expected_offsets)

    def test_get_chunk_readitem_raises_when_offsets_and_lengths_have_different_rank(self):
        """A 2-D offset paired with a 1-D length must raise ``ValueError``."""
        torch, get_chunk_readitem = self._require()

        mismatched = _FakeReadItem(
            storage_offsets=torch.Size([0, 1]),
            lengths=torch.Size([2]),
        )

        # The message is expected to mention "same size".
        with pytest.raises(ValueError, match="same size"):
            get_chunk_readitem(mismatched, ep_rank=1)

    def test_get_chunk_readitem_preserves_non_offset_fields(self):
        """``fqn``, ``dest_offsets`` and ``lengths`` pass through unchanged."""
        torch, get_chunk_readitem = self._require()

        source_item = _FakeReadItem(
            storage_offsets=torch.Size([1, 2]),
            lengths=torch.Size([3, 4]),
            dest_offsets=torch.Size([5, 6]),
            fqn="moe.experts.weight",
        )

        result = get_chunk_readitem(source_item, ep_rank=3)

        assert result.fqn == "moe.experts.weight"
        assert result.dest_offsets == torch.Size([5, 6])
        assert result.lengths == torch.Size([3, 4])


class TestGetCheckMoeFunc:
    """Tests for ``get_check_moe_func`` using fake models and fake DTensor parameters."""

    @staticmethod
    def _build_check(monkeypatch, named_params):
        """Build the predicate for a fake model holding ``named_params``.

        Skips the calling test when torch cannot be imported. Otherwise it
        patches ``moe_utils.DTensor`` with ``_FakeDTensor`` for the duration of
        the test and returns ``get_check_moe_func`` applied to a ``_FakeModel``.
        """
        pytest.importorskip("torch")
        import mindspeed_mm.fsdp.checkpoint.moe_utils as moe_utils

        monkeypatch.setattr(moe_utils, "DTensor", _FakeDTensor)
        return moe_utils.get_check_moe_func(_FakeModel(named_params))

    @staticmethod
    def _assert_verdicts(check_moe, verdicts):
        """Assert ``check_moe(name) is expected`` for each pair, in order."""
        for name, expected in verdicts:
            assert check_moe(name) is expected, name

    def test_check_moe_func_matches_dtensor_params_on_efsdp_mesh(self, monkeypatch):
        """Only the fake DTensor whose mesh names include "efsdp" is accepted.

        A "module."-prefixed spelling of that name must be accepted as well.
        """
        check_moe = self._build_check(
            monkeypatch,
            [
                ("layers.0.mlp.experts.0.weight", _FakeDTensor(["dp", "efsdp"])),
                ("layers.0.mlp.shared.weight", object()),
                ("layers.1.attn.q_proj.weight", _FakeDTensor(["dp", "tp"])),
            ],
        )

        self._assert_verdicts(
            check_moe,
            [
                ("layers.0.mlp.experts.0.weight", True),
                ("module.layers.0.mlp.experts.0.weight", True),
                ("layers.0.mlp.shared.weight", False),
                ("layers.1.attn.q_proj.weight", False),
            ],
        )

    def test_check_moe_func_strips_recompute_prefix_from_model_params(self, monkeypatch):
        """A "_checkpoint_wrapped_module." prefix on the model side must not block a match."""
        check_moe = self._build_check(
            monkeypatch,
            [
                ("_checkpoint_wrapped_module.layers.4.experts.down_proj.weight", _FakeDTensor(["efsdp"])),
            ],
        )

        self._assert_verdicts(
            check_moe,
            [
                ("layers.4.experts.down_proj.weight", True),
                ("_checkpoint_wrapped_module.layers.4.experts.down_proj.weight", True),
                ("layers.4.router.weight", False),
            ],
        )

    def test_check_moe_func_returns_false_when_no_efsdp_params_exist(self, monkeypatch):
        """With no "efsdp" mesh on any parameter, every query must be rejected."""
        check_moe = self._build_check(
            monkeypatch,
            [
                ("layers.0.experts.weight", _FakeDTensor(["dp"])),
                ("layers.0.router.weight", _FakeDTensor(["tp"])),
                ("layers.0.norm.weight", object()),
            ],
        )

        self._assert_verdicts(
            check_moe,
            [
                ("layers.0.experts.weight", False),
                ("layers.0.router.weight", False),
                ("anything", False),
            ],
        )

    @pytest.mark.parametrize(
        "candidate,expected",
        [
            ("blocks.0.moe.experts.0.w1.weight", True),
            ("prefix.blocks.0.moe.experts.0.w1.weight", True),
            ("_checkpoint_wrapped_module.blocks.0.moe.experts.0.w1.weight", True),
            ("blocks.0.moe.experts.0.w2.weight", False),
            ("blocks.1.moe.experts.0.w1.weight", False),
            ("blocks.0.moe.router.weight", False),
        ],
    )
    def test_check_moe_func_uses_suffix_matching_for_wrapped_names(self, monkeypatch, candidate, expected):
        """Prefixed spellings of the one registered name match; other names do not."""
        registered = [
            ("blocks.0.moe.experts.0.w1.weight", _FakeDTensor(["dp", "efsdp"])),
        ]
        check_moe = self._build_check(monkeypatch, registered)

        assert check_moe(candidate) is expected