import logging
from amct_pytorch.common.utils import safe_load
def test_safe_torch_load_defaults_to_weights_only(monkeypatch):
    """Check that ``weights_only=True`` is forwarded when the caller omits it.

    ``safe_load.torch.load`` is replaced with a stub that records the keyword
    arguments it receives and echoes the path back in a dict.
    """
    seen = {}

    def recording_load(path, *, weights_only=None, **kwargs):
        # Record every keyword argument that reaches the patched loader.
        seen.update(kwargs, weights_only=weights_only)
        return {"path": path}

    monkeypatch.setattr(safe_load.torch, "load", recording_load)

    loaded = safe_load.safe_torch_load("params.pth")

    # The stub's return value must be handed back to the caller untouched.
    assert loaded == {"path": "params.pth"}
    assert seen["weights_only"] is True
def test_safe_torch_load_preserves_explicit_weights_only(monkeypatch):
    """Check that an explicit ``weights_only=False`` reaches the loader as-is.

    ``safe_load.torch.load`` is replaced with a stub that records the keyword
    arguments it receives.
    """
    forwarded = {}

    def recording_load(path, *, weights_only=None, **kwargs):
        # Keep whatever keyword arguments the wrapper passed through.
        forwarded.update(kwargs, weights_only=weights_only)
        return {"path": path}

    monkeypatch.setattr(safe_load.torch, "load", recording_load)

    safe_load.safe_torch_load("params.pth", weights_only=False)

    # Identity check: the value must be exactly ``False``, not merely falsy.
    assert forwarded["weights_only"] is False
def test_safe_torch_load_falls_back_when_weights_only_unsupported(monkeypatch, caplog):
    """Check the fallback for a loader that does not support ``weights_only``.

    ``safe_load.torch.load`` is replaced with a stub whose signature has no
    ``weights_only`` parameter and which raises ``TypeError`` if that keyword
    is passed anyway. The test then verifies that:

    * the stub's return value is still handed back to the caller,
    * the successful call did not carry ``weights_only``, and
    * a record mentioning ``weights_only`` was logged at WARNING or above.
    """
    captured = {}

    def fake_load_without_weights_only(path, **kwargs):
        # A bare ``**kwargs`` would silently swallow ``weights_only`` and so
        # would not model a loader that lacks the option. Rejecting it
        # explicitly keeps the test valid whether the wrapper decides by
        # inspecting the signature or by retrying after a ``TypeError``.
        if "weights_only" in kwargs:
            raise TypeError(
                "load() got an unexpected keyword argument 'weights_only'"
            )
        captured["kwargs"] = kwargs
        return {"path": path}

    monkeypatch.setattr(safe_load.torch, "load", fake_load_without_weights_only)

    with caplog.at_level(logging.WARNING):
        result = safe_load.safe_torch_load("params.pth", weights_only=True)

    assert result == {"path": "params.pth"}
    assert "weights_only" not in captured["kwargs"]
    # ``getMessage()`` renders the text directly from the record, so the check
    # does not rely on a handler having populated ``record.message``.
    assert any("weights_only" in record.getMessage() for record in caplog.records)