"""
Pytest hooks for the test suite.
Hub access: default online. Set ``MSMODELING_OFFLINE=1`` to enable
``HF_HUB_OFFLINE`` / ``TRANSFORMERS_OFFLINE`` / ``HF_DATASETS_OFFLINE``.
See tests/README.md and docs/design/ut_refactor.md.
After the session ends, optionally remove hub weight shards under the repo-local
``.msmodeling_cache`` while keeping config and Python sources.
"""
from __future__ import annotations
import logging
import os
from pathlib import Path
import pytest
logger = logging.getLogger(__name__)
pytest_plugins = (
"tests.regression.tensor_cast.conftest",
"tests.regression.serving_cast.conftest",
)
_REPO_CACHE = Path.cwd() / ".msmodeling_cache"
def _resolve_cache_dir() -> Path:
raw = os.environ.get("MSMODELING_CACHE", "").strip()
if not raw:
return _REPO_CACHE
p = Path(raw)
if not p.is_absolute():
p = Path.cwd() / p
return p
_WEIGHT_SUFFIXES = (
".safetensors",
".bin",
".pt",
".pth",
".ckpt",
".h5",
".onnx",
".gguf",
".npz",
".zip",
".tar",
".tar.gz",
)
def _is_hub_weight_file(name: str) -> bool:
lower = name.lower()
if lower.endswith(".safetensors.index.json"):
return True
return any(lower.endswith(suf) for suf in _WEIGHT_SUFFIXES)
def _prune_hub_weight_files(root: Path) -> int:
removed = 0
if not root.is_dir():
return removed
for path in root.rglob("*"):
if not path.is_file():
continue
if not _is_hub_weight_file(path.name):
continue
try:
path.unlink()
removed += 1
except OSError:
logger.exception("Could not remove hub weight file %s", path)
if removed:
logger.info("Pruned %s hub weight file(s) under %s", removed, root)
return removed
def _msmodeling_offline_enabled() -> bool:
flag = os.environ.get("MSMODELING_OFFLINE", "").strip().lower()
return flag in ("1", "true", "yes", "on")
def _apply_hub_offline_env() -> None:
"""Single switch for Hugging Face / Transformers / Datasets offline mode."""
if _msmodeling_offline_enabled():
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["HF_DATASETS_OFFLINE"] = "1"
else:
os.environ.setdefault("HF_HUB_OFFLINE", "0")
os.environ.setdefault("TRANSFORMERS_OFFLINE", "0")
os.environ.setdefault("HF_DATASETS_OFFLINE", "0")
def _cache_dir_configured() -> bool:
return bool(os.environ.get("MSMODELING_CACHE", "").strip())
def pytest_sessionstart(session) -> None:
_apply_hub_offline_env()
if not _cache_dir_configured():
return
cache_dir = _resolve_cache_dir()
cache = str(cache_dir)
os.environ.setdefault("TORCH_HOME", cache)
os.environ.setdefault("HF_HOME", cache)
os.environ.setdefault("MODELSCOPE_CACHE", cache)
def _weights_prune_enabled() -> bool:
raw = os.environ.get("MSMODELING_TEST_WEIGHTS_PRUNE")
if raw is None or not raw.strip():
raw = os.environ.get("TENSOR_CAST_PRUNE_HUB_WEIGHTS_AFTER_UT", "0")
return raw.strip().lower() not in ("0", "false", "no", "off")
def pytest_sessionfinish(session, exitstatus) -> None:
if not _weights_prune_enabled():
return
if not _cache_dir_configured():
return
_prune_hub_weight_files(_resolve_cache_dir())
@pytest.fixture(autouse=True)
def _seed_rng():
"""Seed ``random`` and ``torch`` before every test for determinism."""
import random
import torch
random.seed(0)
torch.manual_seed(0)
@pytest.fixture(autouse=True)
def _restore_environ():
"""Snapshot os.environ per test and restore it afterwards.
The snapshot is taken after ``pytest_sessionstart`` set the session-level hub
env, so session env is part of the snapshot and preserved across tests.
"""
snapshot = dict(os.environ)
try:
yield
finally:
os.environ.clear()
os.environ.update(snapshot)
@pytest.fixture(scope="session")
def cfg_registry(model_zoo) -> dict:
"""Alias-resolving session view over the shared ``model_cache`` config cache."""
from tests.helpers.model_cache import get_hf_config
class _CfgRegistry(dict):
def __init__(self, alias_to_model_id: dict[str, str]):
super().__init__()
self._alias_to_model_id = alias_to_model_id
def _normalize_model_id(self, key: str) -> str:
model_id = self._alias_to_model_id.get(key, key)
if "/" not in model_id:
aliases = ", ".join(sorted(self._alias_to_model_id))
raise KeyError(f"Unknown model alias '{key}'. Available aliases: {aliases}")
return model_id
def __getitem__(self, key):
model_id = self._normalize_model_id(key)
config = get_hf_config(model_id)
dict.__setitem__(self, model_id, config)
return config
def __setitem__(self, key, value):
model_id = self._normalize_model_id(key)
dict.__setitem__(self, model_id, value)
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
return _CfgRegistry(model_zoo)
@pytest.fixture(scope="session")
def device() -> str:
"""Default test device profile name."""
return "TEST_DEVICE"
@pytest.fixture(scope="session")
def model_zoo() -> dict[str, str]:
"""Canonical model aliases used by regression fixtures."""
return {
"deepseek_v32": "deepseek-ai/DeepSeek-V3.2",
"qwen3_32b": "Qwen/Qwen3-32B",
"qwen3_vl_8b": "Qwen/Qwen3-VL-8B-Instruct",
}