from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest
import torch

from tensor_cast.core.input_generator import (
    RequestInfo,
    _is_v4_model,
    _layer_uses_sparse_attention_indexer,
    _resolve_decoder_layers,
    _resolve_indexer_cache_dtype,
    _resolve_main_kv_cache_dtype,
    _resolve_sparse_attention_indexer_cache_width,
    _resolve_sparse_attention_kv_cache_width,
    _resolve_v4_kv_cache_size,
    generate_image_inputs,
    generate_inputs,
    generate_inputs_varlen,
    get_sparse_attention_indexer_cache_info,
)
from tensor_cast.layers.deepseek_v4 import DeepseekV4SparseAttention
from tensor_cast.device import TEST_DEVICE
from tensor_cast.performance_model.analytic import AnalyticPerformanceModel
from tensor_cast.runtime import Runtime
from tensor_cast.transformers.model import TransformerModel


@pytest.mark.parametrize("is_decode", [True, False])
def test_selected_token_indices_for_lmhead(qwen3_32b_lmhead_attention_transformer: TransformerModel, is_decode):
    """Check the LM-head output shape for uniform-length requests.

    The test expects logits for every query token in decode mode, and a single
    row per request otherwise.
    """
    model = qwen3_32b_lmhead_attention_transformer
    tokens_per_request = 100
    num_requests = 2
    request = RequestInfo(
        query_len=tokens_per_request,
        seq_len=tokens_per_request,
        concurrency=num_requests,
        is_decode=is_decode,
    )
    inputs = generate_inputs(model, [request])

    selected_rows = num_requests * tokens_per_request if is_decode else num_requests
    expected_shape = (1, selected_rows, model.vocab_size)

    device = TEST_DEVICE
    with Runtime(AnalyticPerformanceModel(device), device), torch.no_grad():
        outputs = model.forward(**inputs)
    assert outputs.shape == expected_shape


@pytest.mark.parametrize("is_decode", [True, False])
def test_varlen_selected_token_indices_for_lmhead(qwen3_32b_lmhead_attention_transformer: TransformerModel, is_decode):
    """Check the LM-head output shape for requests with different query lengths.

    The test expects logits for every query token in decode mode
    (``sum(query_len)`` rows), and a single row per request otherwise.
    """
    model = qwen3_32b_lmhead_attention_transformer
    query_len = [90, 110]
    batch_size = len(query_len)
    # One RequestInfo per entry of query_len; seq_len mirrors query_len.
    request_infos = [
        RequestInfo(
            query_len=request_query_len,
            seq_len=request_query_len,
            is_decode=is_decode,
        )
        for request_query_len in query_len
    ]
    # NOTE(review): the positional 128 is forwarded unchanged to
    # generate_inputs_varlen; presumably a block size — confirm against its signature.
    inputs = generate_inputs_varlen(model, request_infos, 128)
    if is_decode:
        output_shape = (1, sum(query_len), model.vocab_size)
    else:
        output_shape = (1, batch_size, model.vocab_size)

    machine_config = TEST_DEVICE
    perf_model = AnalyticPerformanceModel(machine_config)
    with Runtime(perf_model, machine_config), torch.no_grad():
        outputs = model.forward(**inputs)
    assert outputs.shape == output_shape


_DSA_INDEXER_CACHE_QUERY_LEN = 32
_DSA_INDEXER_CACHE_NUM_MTP_TOKENS = 2
_DSA_INDEXER_CACHE_BLOCK_SIZE = 128
_DSA_INDEXER_CACHE_NUM_BLOCKS = (
    _DSA_INDEXER_CACHE_QUERY_LEN + _DSA_INDEXER_CACHE_NUM_MTP_TOKENS + 1 + _DSA_INDEXER_CACHE_BLOCK_SIZE - 1
) // _DSA_INDEXER_CACHE_BLOCK_SIZE


def test_dsa_indexer_cache_dtype_follows_attention_quant_config(
    deepseek_v32_build_model_int8,
):
    """The int8 fixture model is expected to produce an int8 indexer cache for layer 0."""
    indexer_caches = get_sparse_attention_indexer_cache_info(
        deepseek_v32_build_model_int8,
        num_blocks=_DSA_INDEXER_CACHE_NUM_BLOCKS,
        block_size=_DSA_INDEXER_CACHE_BLOCK_SIZE,
    )["indexer_cache_by_layers"]

    first_layer_cache = indexer_caches[0]
    assert first_layer_cache.dtype == torch.int8


def test_dsa_indexer_cache_dtype_uses_fp8_when_attention_quant_is_fp8(
    deepseek_v32_build_model_fp8,
):
    """The fp8 fixture model is expected to produce a float8_e4m3fn indexer cache for layer 0."""
    indexer_caches = get_sparse_attention_indexer_cache_info(
        deepseek_v32_build_model_fp8,
        num_blocks=_DSA_INDEXER_CACHE_NUM_BLOCKS,
        block_size=_DSA_INDEXER_CACHE_BLOCK_SIZE,
    )["indexer_cache_by_layers"]

    first_layer_cache = indexer_caches[0]
    assert first_layer_cache.dtype == torch.float8_e4m3fn


def test_qwen3_vl_1080p_resize_to_1088x1920(
    qwen3_vl_8b_instruct_transformer: TransformerModel,
):
    """A single 1080x1920 image is expected to yield an image grid of (t=1, h=68, w=120)."""
    # grid_h=68, grid_w=120 -> resized height/width = 1088x1920
    expected_grid_thw = torch.tensor([[1, 68, 120]])

    image_kwargs = generate_image_inputs(
        model=qwen3_vl_8b_instruct_transformer,
        image_batch_size=1,
        image_height=1080,
        image_width=1920,
        concurrency=1,
    )

    assert torch.equal(image_kwargs["image_grid_thw"], expected_grid_thw)


class TestSparseAttentionCacheHelpers:
    """Mock-based tests for the sparse-attention cache width and layer helpers."""

    def test_resolve_kv_cache_width_from_attention_head_dim(self):
        # The layer's head_dim (480) is expected to take precedence over
        # kv_lora_rank + qk_rope_head_dim (512 + 64).
        fake_model = MagicMock()
        fake_model.text_config.kv_lora_rank = 512
        fake_model.text_config.qk_rope_head_dim = 64
        fake_attention = MagicMock(head_dim=480, _head_dim=None)
        width = _resolve_sparse_attention_kv_cache_width(fake_model, fake_attention)
        assert width == 480

    def test_resolve_kv_cache_width_fallback_without_layer(self):
        # With no attention layer the width is expected to be 512 + 64.
        fake_model = MagicMock()
        fake_model.text_config.kv_lora_rank = 512
        fake_model.text_config.qk_rope_head_dim = 64
        width = _resolve_sparse_attention_kv_cache_width(fake_model, None)
        assert width == 576

    def test_resolve_indexer_cache_width_prefers_index_head_dim(self):
        # The layer's _index_head_dim is expected to win over text_config.index_head_dim.
        fake_model = MagicMock()
        fake_model.text_config.index_head_dim = 999
        fake_attention = MagicMock(_index_head_dim=128, indexer=None)
        width = _resolve_sparse_attention_indexer_cache_width(fake_model, fake_attention)
        assert width == 128

    def test_resolve_indexer_cache_width_from_indexer_module(self):
        # When neither config nor layer gives a width, the indexer's head_dim is used.
        fake_model = MagicMock()
        fake_model.text_config.index_head_dim = None
        fake_indexer = MagicMock(head_dim=64)
        fake_attention = MagicMock(_index_head_dim=None, indexer=fake_indexer)
        width = _resolve_sparse_attention_indexer_cache_width(fake_model, fake_attention)
        assert width == 64

    def test_layer_uses_sparse_attention_indexer(self):
        explicitly_enabled = MagicMock(use_indexer=True)
        explicitly_disabled = MagicMock(use_indexer=False, indexer=None)
        implied_by_indexer = MagicMock(use_indexer=None, indexer=object())
        assert _layer_uses_sparse_attention_indexer(explicitly_enabled)
        assert not _layer_uses_sparse_attention_indexer(explicitly_disabled)
        assert _layer_uses_sparse_attention_indexer(implied_by_indexer)

    def test_resolve_decoder_layers_direct_layout(self):
        # unwrap() exposes `layers` directly.
        decoder_layers = [MagicMock(), MagicMock()]
        fake_model = MagicMock()
        fake_model.unwrap.return_value = SimpleNamespace(layers=decoder_layers)
        assert _resolve_decoder_layers(fake_model) is decoder_layers

    def test_resolve_decoder_layers_nested_causal_lm_layout(self):
        # unwrap() exposes the layers one level down, under `.model.layers`.
        decoder_layers = [MagicMock()]
        inner = SimpleNamespace(layers=decoder_layers)
        fake_model = MagicMock()
        fake_model.unwrap.return_value = SimpleNamespace(model=inner)
        assert _resolve_decoder_layers(fake_model) is decoder_layers

    @patch("tensor_cast.core.input_generator.get_attention_quant_config", return_value=None)
    def test_get_sparse_attention_indexer_cache_info_v4_layers(self, _mock_attn_quant):
        no_indexer_layer = MagicMock(self_attn=MagicMock(use_indexer=False, indexer=None))
        indexer_layer = MagicMock(
            self_attn=MagicMock(use_indexer=True, _index_head_dim=128, indexer=None)
        )

        fake_model = MagicMock()
        fake_model.num_hidden_layers = 2
        fake_model.model_config.mla_config = MagicMock(mla_cls=DeepseekV4SparseAttention)
        fake_model.model_config.dtype = torch.bfloat16
        fake_model.unwrap.return_value = MagicMock(layers=[no_indexer_layer, indexer_layer])

        info = get_sparse_attention_indexer_cache_info(fake_model, num_blocks=4, block_size=16)
        caches = info["indexer_cache_by_layers"]

        # Layer 1 uses the indexer; its cache is expected to be (num_blocks, block_size, 128).
        assert 1 in caches
        assert caches[1].shape == (4, 16, 128)
        assert info["indexer_cache_per_token"] > 0


class TestDeepseekV4KvCacheHelpers:
    """Mock-based tests for the DeepSeek V4 KV-cache dtype and size helpers."""

    @staticmethod
    def _model_with_hf_type(hf_model_type):
        """Build a MagicMock model with bf16 dtype, the given hf model_type and no text_config."""
        fake_model = MagicMock()
        fake_model.model_config.dtype = torch.bfloat16
        fake_model.model_config.hf_config = MagicMock(model_type=hf_model_type)
        fake_model.text_config = None
        return fake_model

    @pytest.mark.parametrize(
        ("hf_model_type", "text_model_type", "expected"),
        [
            ("deepseek_v4", None, True),
            (None, "deepseek_v4", True),
            ("deepseek_v32", "deepseek_v32", False),
            (None, None, False),
        ],
    )
    def test_is_v4_model(self, hf_model_type, text_model_type, expected):
        # A missing model type is modelled by setting the whole config object to None.
        fake_model = MagicMock()
        if hf_model_type is None:
            fake_model.model_config.hf_config = None
        else:
            fake_model.model_config.hf_config = MagicMock(model_type=hf_model_type)
        if text_model_type is None:
            fake_model.text_config = None
        else:
            fake_model.text_config = MagicMock(model_type=text_model_type)
        assert _is_v4_model(fake_model) is expected

    @patch("tensor_cast.core.input_generator.get_attention_quant_config")
    def test_resolve_main_kv_cache_dtype_v4_ignores_attention_quant(self, mock_get_attn_quant):
        # An fp8 attention quant config is present, yet V4 is expected to keep bf16.
        mock_get_attn_quant.return_value = MagicMock(get_quant_dtype=lambda: torch.float8_e4m3fn)
        fake_model = self._model_with_hf_type("deepseek_v4")
        assert _resolve_main_kv_cache_dtype(fake_model, 0) == torch.bfloat16

    @patch("tensor_cast.core.input_generator.get_attention_quant_config")
    def test_resolve_main_kv_cache_dtype_non_v4_uses_attention_quant(self, mock_get_attn_quant):
        # For a non-V4 model the attention quant dtype is expected to be used.
        mock_get_attn_quant.return_value = MagicMock(get_quant_dtype=lambda: torch.float8_e4m3fn)
        fake_model = self._model_with_hf_type("deepseek_v32")
        assert _resolve_main_kv_cache_dtype(fake_model, 0) == torch.float8_e4m3fn

    @patch("tensor_cast.core.input_generator.get_attention_quant_config", return_value=None)
    def test_resolve_main_kv_cache_dtype_non_v4_fallback_to_model_dtype(self, _mock_get_attn_quant):
        # With no attention quant config the model dtype is expected to be used.
        fake_model = self._model_with_hf_type("deepseek_v32")
        assert _resolve_main_kv_cache_dtype(fake_model, 0) == torch.bfloat16

    @patch("tensor_cast.core.input_generator.get_attention_quant_config")
    def test_resolve_indexer_cache_dtype_uses_attention_quant(self, mock_get_attn_quant):
        mock_get_attn_quant.return_value = MagicMock(get_quant_dtype=lambda: torch.int8)
        fake_model = MagicMock()
        fake_model.model_config.dtype = torch.bfloat16
        assert _resolve_indexer_cache_dtype(fake_model, 0) == torch.int8

    @patch("tensor_cast.core.input_generator._resolve_sparse_attention_kv_cache_width", return_value=576)
    def test_resolve_v4_kv_cache_size_compressed_sparse_layer(self, _mock_kv_width):
        fake_model = MagicMock()
        fake_model.text_config.sliding_window = 128
        fake_model.text_config.kv_lora_rank = 512
        fake_model.text_config.qk_rope_head_dim = 64
        compressed_layer = MagicMock(compress_ratio=4, head_dim=576)

        shape = _resolve_v4_kv_cache_size(
            fake_model,
            attention_layer=compressed_layer,
            num_blocks=100,
            block_size=128,
            batch_size=2,
            total_kv_tokens=8192,
        )

        # window_slots=256, compressed_slots=2048 -> total_slots=2304 -> 18 blocks
        assert shape == [18, 128, 576]

    @patch("tensor_cast.core.input_generator._resolve_sparse_attention_kv_cache_width", return_value=480)
    def test_resolve_v4_kv_cache_size_fallback_without_batch_info(self, _mock_kv_width):
        # Without batch_size / total_kv_tokens, num_blocks is expected to be kept as-is.
        fake_model = MagicMock()
        fake_model.text_config.sliding_window = 128
        compressed_layer = MagicMock(compress_ratio=4)

        shape = _resolve_v4_kv_cache_size(
            fake_model,
            attention_layer=compressed_layer,
            num_blocks=42,
            block_size=128,
        )

        assert shape == [42, 128, 480]

    @patch("tensor_cast.core.input_generator.get_attention_quant_config", return_value=None)
    def test_v4_indexer_cache_compression_only_for_v4_model(self, _mock_attn_quant):
        compressed_indexer_attn = MagicMock(
            use_indexer=True,
            _index_head_dim=128,
            indexer=None,
            compress_ratio=4,
        )

        fake_model = MagicMock()
        fake_model.num_hidden_layers = 1
        fake_model.model_config.mla_config = MagicMock(mla_cls=DeepseekV4SparseAttention)
        fake_model.model_config.dtype = torch.bfloat16
        fake_model.model_config.hf_config = MagicMock(model_type="deepseek_v32")
        fake_model.text_config = None
        fake_model.unwrap.return_value = MagicMock(
            layers=[MagicMock(self_attn=compressed_indexer_attn)]
        )

        info = get_sparse_attention_indexer_cache_info(
            fake_model,
            num_blocks=100,
            block_size=128,
            batch_size=2,
            total_kv_tokens=8192,
        )

        # Non-V4 models must keep the full paged pool even when compress_ratio is set.
        assert info["indexer_cache_by_layers"][0].shape[0] == 100