"""Tests for theory-guided shape grid model configs."""

import importlib
import sys
from pathlib import Path


GRID_GENERATOR_DIR = Path(__file__).resolve().parents[3] / "tools" / "perf_data_collection" / "grid_generator"
if str(GRID_GENERATOR_DIR) not in sys.path:
    sys.path.insert(0, str(GRID_GENERATOR_DIR))

model_configs = importlib.import_module("model_configs")
get_matmul_nk_pairs = model_configs.get_matmul_nk_pairs
resolve_configs = model_configs.resolve_configs


def fail_fetch(model_name, model_id):
    raise RuntimeError(f"offline: {model_name} {model_id}")


def test_glm51_aliases_resolve_to_static_fallback(monkeypatch):
    """GLM-5.1 aliases should work even when remote config loading is unavailable."""
    monkeypatch.setattr(model_configs, "_fetch_from_huggingface", fail_fetch)
    model_configs._RESOLVED_CONFIGS.clear()

    by_short, by_hf = resolve_configs(["GLM-5.1", "zai-org/GLM-5.1"])

    assert by_short.name == "GLM-5.1"
    assert by_hf.hidden_size == 6144
    assert by_hf.q_lora_rank == 2048
    assert by_hf.kv_lora_rank == 512
    assert by_hf.head_dim == 256
    assert by_hf.expert_intermediate_size == 2048


def test_all_model_configs_dedupe_aliases():
    configs = resolve_configs(None)

    assert len(configs) == len(set(configs))
    assert [cfg.name for cfg in configs].count("GLM-5.1") == 1


def test_glm51_mla_pairs_include_q_and_kv_projection_dims(monkeypatch):
    """GLM-5.1 MLA matmul candidates should use 256-wide qk/v heads, not HF head_dim=64."""
    monkeypatch.setattr(model_configs, "_fetch_from_huggingface", fail_fetch)
    model_configs._RESOLVED_CONFIGS.clear()

    pairs = get_matmul_nk_pairs(["GLM5.1"])

    assert (2048, 6144) in pairs
    assert (576, 6144) in pairs
    assert (16384, 2048) in pairs
    assert (16384, 512) in pairs