"""Regression tests for scripts.prefetch_model_configs."""

from __future__ import annotations

import json
import logging
import sys
from pathlib import Path
from types import SimpleNamespace

import pytest

import scripts.prefetch_model_configs as prefetch
from tests.helpers.cli_runner import run_module_main


@pytest.fixture(autouse=True)
def _clean_prefetch_env(monkeypatch: pytest.MonkeyPatch) -> None:
    for key in (
        "HF_HOME",
        "TORCH_HOME",
        "MODELSCOPE_CACHE",
        "MSMODELING_OFFLINE",
        "HF_HUB_OFFLINE",
        "TRANSFORMERS_OFFLINE",
        "HF_DATASETS_OFFLINE",
    ):
        monkeypatch.delenv(key, raising=False)


@pytest.mark.parametrize(
    ("model_id", "expected"),
    [
        ("Qwen/Qwen3-32B", True),
        ("deepseek-ai/DeepSeek-R1", True),
        (" tests/foo", False),
        ("tests/foo", False),
        ("tensor_cast/bar", False),
        ("http://example.com/model", False),
        ("org/model.json", False),
        ("a/model", False),
        ("org/b", False),
        ("org/model with space", False),
        (r"org\model", False),
    ],
)
def test_looks_like_model_id_filters_expected_values(model_id: str, expected: bool) -> None:
    assert prefetch._looks_like_model_id(model_id) is expected


def test_iter_string_values_walks_nested_dicts_and_lists() -> None:
    data = {
        "a": "Qwen/Qwen3-32B",
        "b": ["deepseek-ai/DeepSeek-R1", {"c": "not/a/file.json"}],
        "d": 123,
    }

    assert list(prefetch._iter_string_values(data)) == [
        "Qwen/Qwen3-32B",
        "deepseek-ai/DeepSeek-R1",
        "not/a/file.json",
    ]


def test_collect_from_python_returns_empty_for_syntax_error(tmp_path: Path) -> None:
    path = tmp_path / "broken.py"
    path.write_text('x = "unterminated', encoding="utf-8")

    assert prefetch._collect_from_python(path, frozenset()) == set()


def test_collect_from_json_returns_empty_for_invalid_json(tmp_path: Path) -> None:
    path = tmp_path / "broken.json"
    path.write_text("{bad json", encoding="utf-8")

    assert prefetch._collect_from_json(path, frozenset()) == set()


def test_collect_model_ids_discovers_from_python_and_json_and_skips_ignored_paths(
    tmp_path: Path,
) -> None:
    scan_dir = tmp_path / "scan"
    (scan_dir / "suite").mkdir(parents=True)
    (scan_dir / "tests" / ".ci").mkdir(parents=True)
    (scan_dir / "tests" / "assets" / "cache").mkdir(parents=True)
    (scan_dir / "scripts" / "helpers").mkdir(parents=True)

    (scan_dir / "suite" / "case.py").write_text(
        "\n".join(
            [
                'MODEL_A = "Qwen/Qwen3-32B"',
                'MODEL_B = "deepseek-ai/DeepSeek-R1"',
                'IGNORE_A = "tests/fixture"',
                'IGNORE_B = "org/model.json"',
            ]
        ),
        encoding="utf-8",
    )
    (scan_dir / "suite" / "case.json").write_text(
        json.dumps(
            {
                "models": ["Qwen/Qwen3-32B", {"id": "THUDM/GLM-4-9B"}],
                "nested": {"remote": "deepseek-ai/DeepSeek-R1"},
            }
        ),
        encoding="utf-8",
    )
    (scan_dir / "tests" / ".ci" / "ignored.py").write_text('MODEL = "ignored/Model"', encoding="utf-8")
    (scan_dir / "tests" / "assets" / "cache" / "ignored.json").write_text(
        json.dumps({"model": "ignored/CacheModel"}),
        encoding="utf-8",
    )
    (scan_dir / "scripts" / "helpers" / "ignored.py").write_text(
        'MODEL = "ignored/HelperModel"',
        encoding="utf-8",
    )

    assert prefetch.collect_model_ids(scan_dir) == [
        "Qwen/Qwen3-32B",
        "THUDM/GLM-4-9B",
        "deepseek-ai/DeepSeek-R1",
    ]


def test_env_overrides_activate_sets_and_restores_values(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    monkeypatch.setenv("HF_HOME", "old-hf")
    monkeypatch.setenv("MSMODELING_OFFLINE", "1")

    overrides = prefetch.EnvOverrides(
        hf_home="new-hf",
        torch_home="new-torch",
        modelscope_cache="new-ms",
    )

    with overrides.activate():
        assert prefetch.os.environ["HF_HOME"] == "new-hf"
        assert prefetch.os.environ["TORCH_HOME"] == "new-torch"
        assert prefetch.os.environ["MODELSCOPE_CACHE"] == "new-ms"
        assert prefetch.os.environ["MSMODELING_OFFLINE"] == "0"
        assert prefetch.os.environ["HF_HUB_OFFLINE"] == "0"
        assert prefetch.os.environ["TRANSFORMERS_OFFLINE"] == "0"
        assert prefetch.os.environ["HF_DATASETS_OFFLINE"] == "0"

    assert prefetch.os.environ["HF_HOME"] == "old-hf"
    assert "TORCH_HOME" not in prefetch.os.environ
    assert "MODELSCOPE_CACHE" not in prefetch.os.environ
    assert prefetch.os.environ["MSMODELING_OFFLINE"] == "1"
    assert "HF_HUB_OFFLINE" not in prefetch.os.environ
    assert "TRANSFORMERS_OFFLINE" not in prefetch.os.environ
    assert "HF_DATASETS_OFFLINE" not in prefetch.os.environ


def test_prefetch_result_to_dict_serializes_all_fields() -> None:
    result = prefetch.PrefetchResult(model_id="org/model", source="huggingface", success=False, error="boom")

    assert result.to_dict() == {
        "model_id": "org/model",
        "source": "huggingface",
        "success": False,
        "error": "boom",
    }


def test_huggingface_prefetcher_fetch_success_first_try(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    snapshot_calls: list[str] = []
    config_calls: list[tuple[str, dict[str, object]]] = []

    def _snapshot(model_id: str) -> str:
        snapshot_calls.append(model_id)
        return f"/cache/hf/{model_id}"

    def _from_pretrained(model_id: str, **kwargs: object) -> object:
        config_calls.append((model_id, kwargs))
        return object()

    monkeypatch.setattr(prefetch, "snapshot_huggingface_config_only", _snapshot)
    monkeypatch.setitem(
        sys.modules,
        "transformers",
        SimpleNamespace(AutoConfig=SimpleNamespace(from_pretrained=_from_pretrained)),
    )

    fetcher = prefetch.HuggingFacePrefetcher()
    result = fetcher.fetch("Qwen/Qwen3-32B")

    assert result == prefetch.PrefetchResult(
        model_id="Qwen/Qwen3-32B",
        source="huggingface",
        success=True,
    )
    assert snapshot_calls == ["Qwen/Qwen3-32B"]
    assert config_calls == [("/cache/hf/Qwen/Qwen3-32B", {})]


def test_huggingface_prefetcher_fetch_retries_with_trust_remote_code(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    config_calls: list[tuple[str, dict[str, object]]] = []

    monkeypatch.setattr(
        prefetch,
        "snapshot_huggingface_config_only",
        lambda model_id: f"/cache/hf/{model_id}",
    )

    def _from_pretrained(model_id: str, **kwargs: object) -> object:
        config_calls.append((model_id, kwargs))
        if len(config_calls) == 1:
            raise RuntimeError("set trust_remote_code=True")
        return object()

    monkeypatch.setitem(
        sys.modules,
        "transformers",
        SimpleNamespace(AutoConfig=SimpleNamespace(from_pretrained=_from_pretrained)),
    )

    fetcher = prefetch.HuggingFacePrefetcher()
    result = fetcher.fetch("deepseek-ai/DeepSeek-R1")

    assert result.success is True
    assert config_calls == [
        ("/cache/hf/deepseek-ai/DeepSeek-R1", {}),
        ("/cache/hf/deepseek-ai/DeepSeek-R1", {"trust_remote_code": True}),
    ]


def test_modelscope_prefetcher_fetch_retries_with_trust_remote_code(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    snapshot_calls: list[str] = []
    config_calls: list[tuple[str, dict[str, object]]] = []

    def _snapshot(model_id: str) -> str:
        snapshot_calls.append(model_id)
        return f"/cache/{model_id}"

    def _from_pretrained(local_dir: str, **kwargs: object) -> object:
        config_calls.append((local_dir, kwargs))
        if len(config_calls) == 1:
            raise RuntimeError("please enable trust_remote_code")
        return object()

    monkeypatch.setattr(prefetch, "snapshot_modelscope_config_only", _snapshot)
    monkeypatch.setitem(
        sys.modules,
        "modelscope",
        SimpleNamespace(AutoConfig=SimpleNamespace(from_pretrained=_from_pretrained)),
    )

    fetcher = prefetch.ModelScopePrefetcher()
    result = fetcher.fetch("THUDM/GLM-4-9B")

    assert result == prefetch.PrefetchResult(
        model_id="THUDM/GLM-4-9B",
        source="modelscope",
        success=True,
    )
    assert snapshot_calls == ["THUDM/GLM-4-9B"]
    assert config_calls == [
        ("/cache/THUDM/GLM-4-9B", {}),
        ("/cache/THUDM/GLM-4-9B", {"trust_remote_code": True}),
    ]


class _FailingPrefetcher:
    def __init__(self, message: str) -> None:
        self.message = message

    def fetch(self, _model_id: str) -> prefetch.PrefetchResult:
        raise RuntimeError(self.message)


class _SuccessfulPrefetcher:
    def __init__(self, source: str) -> None:
        self.source = source
        self.calls: list[str] = []

    def fetch(self, model_id: str) -> prefetch.PrefetchResult:
        self.calls.append(model_id)
        return prefetch.PrefetchResult(model_id=model_id, source=self.source, success=True)


def test_try_prefetch_returns_first_successful_result() -> None:
    ok = _SuccessfulPrefetcher("huggingface")

    result = prefetch._try_prefetch(
        "Qwen/Qwen3-32B",
        [_FailingPrefetcher("first failed"), ok],
    )

    assert result == prefetch.PrefetchResult(
        model_id="Qwen/Qwen3-32B",
        source="huggingface",
        success=True,
    )
    assert ok.calls == ["Qwen/Qwen3-32B"]


def test_try_prefetch_returns_unresolved_when_all_prefetchers_fail() -> None:
    result = prefetch._try_prefetch(
        "Qwen/Qwen3-32B",
        [_FailingPrefetcher("first failed"), _FailingPrefetcher("last failed")],
    )

    assert result == prefetch.PrefetchResult(
        model_id="Qwen/Qwen3-32B",
        source="unresolved",
        success=False,
        error="last failed",
    )


def test_prefetch_all_preserves_input_order() -> None:
    ok = _SuccessfulPrefetcher("huggingface")

    results = prefetch._prefetch_all(
        ["Qwen/Qwen3-32B", "deepseek-ai/DeepSeek-R1"],
        [ok],
    )

    assert [item.model_id for item in results] == [
        "Qwen/Qwen3-32B",
        "deepseek-ai/DeepSeek-R1",
    ]
    assert ok.calls == ["Qwen/Qwen3-32B", "deepseek-ai/DeepSeek-R1"]


def test_build_prefetchers_skips_import_errors(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setattr(prefetch, "HuggingFacePrefetcher", lambda: "hf")

    def _raise_import_error() -> str:
        raise ImportError("modelscope missing")

    monkeypatch.setattr(prefetch, "ModelScopePrefetcher", _raise_import_error)

    assert prefetch._build_prefetchers() == ["hf"]


def test_write_manifest_persists_results_as_json(tmp_path: Path) -> None:
    results = [
        prefetch.PrefetchResult(model_id="Qwen/Qwen3-32B", source="dry-run", success=True),
        prefetch.PrefetchResult(model_id="bad/model", source="unresolved", success=False, error="boom"),
    ]

    manifest_path = prefetch._write_manifest(tmp_path, tmp_path / "scan", results)
    manifest = json.loads(manifest_path.read_text(encoding="utf-8"))

    assert manifest["schema_version"] == 1
    assert manifest["dest_dir"] == str(tmp_path)
    assert manifest["scan_dir"] == str(tmp_path / "scan")
    assert manifest["models"] == [item.to_dict() for item in results]


def test_main_returns_2_when_scan_dir_is_missing(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None:
    with caplog.at_level(logging.ERROR, logger="scripts.prefetch_model_configs"):
        result = run_module_main(
            "scripts.prefetch_model_configs",
            [
                "--scan-dir",
                str(tmp_path / "missing"),
                "--dest-dir",
                str(tmp_path / "cache"),
            ],
        )

    assert result.returncode == 2
    assert f"scan dir not found: {(tmp_path / 'missing').resolve()}" in caplog.text


def test_main_dry_run_writes_manifest_and_returns_zero(tmp_path: Path) -> None:
    scan_dir = tmp_path / "tests"
    scan_dir.mkdir()
    (scan_dir / "case.py").write_text('MODEL = "Qwen/Qwen3-32B"', encoding="utf-8")

    dest_dir = tmp_path / "cache"
    result = run_module_main(
        "scripts.prefetch_model_configs",
        [
            "--scan-dir",
            str(scan_dir),
            "--dest-dir",
            str(dest_dir),
            "--dry-run",
        ],
    )

    manifest = json.loads((dest_dir / "model_config_manifest.json").read_text(encoding="utf-8"))

    assert result.returncode == 0
    assert manifest["models"] == [
        {
            "model_id": "Qwen/Qwen3-32B",
            "source": "dry-run",
            "success": True,
            "error": "",
        }
    ]


def test_main_returns_1_when_no_model_ids_are_discovered(
    tmp_path: Path,
    caplog: pytest.LogCaptureFixture,
) -> None:
    scan_dir = tmp_path / "tests"
    scan_dir.mkdir()
    (scan_dir / "case.py").write_text('MODEL = "tests/not-a-model"', encoding="utf-8")

    with caplog.at_level(logging.ERROR, logger="scripts.prefetch_model_configs"):
        result = run_module_main(
            "scripts.prefetch_model_configs",
            [
                "--scan-dir",
                str(scan_dir),
                "--dest-dir",
                str(tmp_path / "cache"),
            ],
        )

    assert result.returncode == 1
    assert "No model id discovered from tests scan." in caplog.text


def test_main_returns_1_when_no_prefetchers_are_available(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
    caplog: pytest.LogCaptureFixture,
) -> None:
    scan_dir = tmp_path / "tests"
    scan_dir.mkdir()
    (scan_dir / "case.py").write_text('MODEL = "Qwen/Qwen3-32B"', encoding="utf-8")
    monkeypatch.setattr(prefetch, "_build_prefetchers", list)

    with caplog.at_level(logging.ERROR, logger="scripts.prefetch_model_configs"):
        result = run_module_main(
            "scripts.prefetch_model_configs",
            [
                "--scan-dir",
                str(scan_dir),
                "--dest-dir",
                str(tmp_path / "cache"),
            ],
        )

    assert result.returncode == 1
    assert "Neither transformers nor modelscope is installed." in caplog.text


def test_main_writes_manifest_and_returns_1_when_prefetch_fails(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    scan_dir = tmp_path / "tests"
    scan_dir.mkdir()
    (scan_dir / "case.py").write_text('MODEL = "Qwen/Qwen3-32B"', encoding="utf-8")
    dest_dir = tmp_path / "cache"

    monkeypatch.setattr(prefetch, "_build_prefetchers", lambda: [object()])
    monkeypatch.setattr(
        prefetch,
        "_prefetch_all",
        lambda model_ids, prefetchers: [
            prefetch.PrefetchResult(model_id=model_ids[0], source="unresolved", success=False, error="boom")
        ],
    )

    result = run_module_main(
        "scripts.prefetch_model_configs",
        [
            "--scan-dir",
            str(scan_dir),
            "--dest-dir",
            str(dest_dir),
        ],
    )

    manifest = json.loads((dest_dir / "model_config_manifest.json").read_text(encoding="utf-8"))

    assert result.returncode == 1
    assert manifest["models"] == [
        {
            "model_id": "Qwen/Qwen3-32B",
            "source": "unresolved",
            "success": False,
            "error": "boom",
        }
    ]