from types import SimpleNamespace

import torch

from tensor_cast.layers.glm5 import Glm5SparseAttention
from tensor_cast.layers.mla import DeepseekSparseAttention
from tensor_cast.layers.mtp import MultiTokenPredictorLayer
from tensor_cast.model_config import MtpConfig
from tensor_cast.transformers.transformations import maybe_enable_mtp


def test_glm5_sparse_attention_returns_topk_slot(monkeypatch):
    """Glm5SparseAttention.forward returns the parent's 2-tuple plus a trailing ``None``.

    The parent ``DeepseekSparseAttention.forward`` is stubbed to return
    ``("hidden", None)``. The GLM-5 wrapper must then yield a 3-tuple. The test
    name suggests the extra slot is meant for top-k indices; this is not verified here.
    """

    def _stub_parent_forward(self, *args, **kwargs):
        return ("hidden", None)

    monkeypatch.setattr(DeepseekSparseAttention, "forward", _stub_parent_forward)

    # Bypass __init__ so no real config or weights are required.
    attn = object.__new__(Glm5SparseAttention)
    result = Glm5SparseAttention.forward(attn, None, None, None)

    expected = ("hidden", None, None)
    assert result == expected


def test_glm5_mtp_layer_uses_hidden_states_from_tuple_block_output():
    """MultiTokenPredictorLayer takes the hidden states from a tuple-returning block.

    The inner block returns ``(hidden_states + 1, None)``. The layer must take
    the first element rather than passing the whole tuple through.
    """

    class TupleBlock(torch.nn.Module):
        # Decoder-block stand-in that returns (hidden_states, extra).
        def forward(self, hidden_states, **_kwargs):
            return hidden_states + 1, None

    config = SimpleNamespace(hidden_size=2, rms_norm_eps=1e-5)
    layer = MultiTokenPredictorLayer(config, TupleBlock())

    # Replace the norms with no-ops so only the projection and block shape the output.
    layer.hidden_norm = torch.nn.Identity()
    layer.emb_norm = torch.nn.Identity()

    # The weight [0 | I] keeps only the last two of the four input features.
    # The expected value below relies on the layer concatenating
    # [inputs_embeds, previous_hidden_states] in that order.
    projection = torch.nn.Linear(4, 2, bias=False)
    with torch.no_grad():
        projection.weight.copy_(torch.cat([torch.zeros(2, 2), torch.eye(2)], dim=1))
    layer.linear_proj = projection

    embeds = torch.zeros(1, 1, 2)
    positions = torch.zeros(1, 1, dtype=torch.long)
    prev_hidden = torch.tensor([[[2.0, 3.0]]])

    result = layer(
        inputs_embeds=embeds,
        position_ids=positions,
        previous_hidden_states=prev_hidden,
    )

    # The projection passes prev_hidden through unchanged; TupleBlock then adds 1.
    expected = torch.tensor([[[3.0, 4.0]]])
    assert torch.equal(result, expected)


def test_glm5_mtp_extends_indexer_types(monkeypatch):
    """maybe_enable_mtp appends one ``"shared"`` indexer type per MTP layer.

    MtpWrapper is replaced with a fake that records its constructor arguments.
    The test then checks the ``hf_config`` that was handed to the wrapper.
    """
    recorded = {}

    class FakeMtpWrapper:
        # Records what maybe_enable_mtp passes in; does no wrapping.
        def __init__(self, mtp_config, hf_config, inner):
            recorded.update(mtp_config=mtp_config, hf_config=hf_config, inner=inner)

    # These fixtures use names distinct from the class attributes below.
    # A class-body statement like `hf_config = hf_config` would look up the
    # right-hand name in the class namespace and globals, not in this function.
    base_hf_config = SimpleNamespace(
        indexer_types=["full", "shared"],
        layer_types=["full_attention", "full_attention"],
        mlp_layer_types=["sparse", "sparse"],
    )
    base_model_config = SimpleNamespace(
        mtp_config=MtpConfig(num_mtp_layers=3, mtp_block_module_name="GlmMoeDsaDecoderLayer"),
        dtype=torch.float32,
    )

    class FakeModel:
        # Minimal model surface that maybe_enable_mtp reads.
        is_vl_model = False
        text_config = None
        _inner = object()
        hf_config = base_hf_config
        model_config = base_model_config

        def unwrap(self):
            return SimpleNamespace()

    monkeypatch.setattr("tensor_cast.layers.mtp.MtpWrapper", FakeMtpWrapper)

    maybe_enable_mtp(FakeModel())

    # Three MTP layers, so three extra "shared" entries after the original two.
    expected_indexer_types = ["full", "shared"] + ["shared"] * 3
    assert recorded["hf_config"].indexer_types == expected_indexer_types