from types import SimpleNamespace
import pytest
import torch
import torch.nn as nn
from amct_pytorch.common.optimization import factory
def _named_param_model():
m = nn.Module()
m.encoder_weight = nn.Parameter(torch.zeros(3))
m.encoder_bias = nn.Parameter(torch.zeros(3))
m.decoder_weight = nn.Parameter(torch.zeros(3))
return m
def test_get_n_set_parameters_byname_matches_substring_and_marks_grad():
    """Parameters whose names contain a requested substring are selected and unfrozen.

    Every parameter is frozen first, so any ``requires_grad=True`` afterwards
    must come from ``get_n_set_parameters_byname``. The non-matching
    ``decoder_weight`` must stay frozen.
    """
    model = _named_param_model()
    factory.set_require_grad_all(model, requires_grad=False)

    # Fully consume the result before checking flags, in case the selector
    # applies its requires_grad change lazily while being iterated.
    matched = [param for param in factory.get_n_set_parameters_byname(model, ["encoder"])]

    assert len(matched) == 2
    assert all(param.requires_grad is True for param in matched)
    assert model.decoder_weight.requires_grad is False
def test_get_n_set_parameters_byname_returns_iterator():
    """The selector returns an iterator rather than a materialised sequence."""
    model = _named_param_model()
    result = factory.get_n_set_parameters_byname(model, ["decoder"])
    # An iterator's __iter__ returns itself; a list or tuple would not.
    assert result is iter(result)
def test_set_require_grad_all_toggles_every_parameter():
    """``set_require_grad_all`` sets requires_grad on every parameter, both ways."""
    model = _named_param_model()
    # Freeze first, then unfreeze, checking the whole module after each step.
    for flag in (False, True):
        factory.set_require_grad_all(model, requires_grad=flag)
        assert all(param.requires_grad == flag for param in model.parameters())
def test_check_params_grad_returns_none():
    """``check_params_grad`` returns ``None`` rather than a value."""
    outcome = factory.check_params_grad(_named_param_model())
    assert outcome is None
def _params():
return [nn.Parameter(torch.zeros(2))]
def test_build_optimizer_default_is_adamw():
    """With no ``optimizer`` attribute on args, an AdamW optimizer is built."""
    optimizer = factory.build_optimizer(SimpleNamespace(), _params())
    assert isinstance(optimizer, torch.optim.AdamW)
def test_build_optimizer_adam_uses_args_lr_and_weight_decay():
    """``optimizer="adam"`` builds plain Adam configured from the args.

    ``base_lr`` and ``weight_decay`` on ``args`` must end up in the
    optimizer's defaults as ``lr`` and ``weight_decay``.
    """
    args = SimpleNamespace(optimizer="adam", base_lr=2e-4, weight_decay=0.1)
    opt = factory.build_optimizer(args, _params())
    # Exact type check: recent PyTorch releases implement AdamW as a subclass
    # of Adam, so isinstance(opt, torch.optim.Adam) would also accept an AdamW
    # and could not catch "adam" being mapped to the wrong optimizer.
    assert type(opt) is torch.optim.Adam
    assert opt.defaults["lr"] == pytest.approx(2e-4)
    assert opt.defaults["weight_decay"] == pytest.approx(0.1)
def test_build_optimizer_sgd_uses_momentum_default_when_missing():
    """Without ``momentum`` on args, SGD is built with momentum 0.9."""
    optimizer = factory.build_optimizer(
        SimpleNamespace(optimizer="sgd", base_lr=1e-3), _params()
    )
    assert isinstance(optimizer, torch.optim.SGD)
    momentum = optimizer.defaults["momentum"]
    assert momentum == pytest.approx(0.9)
def test_build_optimizer_unknown_raises():
    """An unrecognised optimizer name is rejected with a ValueError naming it."""
    params = _params()
    with pytest.raises(ValueError, match="Unsupported optimizer 'lamb'"):
        factory.build_optimizer(SimpleNamespace(optimizer="lamb"), params)
def _adam(params=None):
return torch.optim.AdamW(params or _params())
@pytest.mark.parametrize("name", ["none", "", "NONE"])
def test_build_lr_scheduler_returns_none_for_disabled(name):
    """Disabled scheduler names ('none' in lower/upper case, or '') give no scheduler."""
    scheduler = factory.build_lr_scheduler(SimpleNamespace(lr_scheduler=name), _adam())
    assert scheduler is None
def test_build_lr_scheduler_cosine_uses_args_for_t_max():
    """Cosine annealing takes T_max and eta_min from the training args.

    Expected: T_max == epochs * (nsamples // cali_bsz) and
    eta_min == base_lr * 1e-3.
    """
    nsamples, cali_bsz, epochs, base_lr = 64, 8, 4, 1e-3
    args = SimpleNamespace(
        lr_scheduler="cosine",
        base_lr=base_lr,
        nsamples=nsamples,
        cali_bsz=cali_bsz,
        epochs=epochs,
    )
    scheduler = factory.build_lr_scheduler(args, _adam())

    assert isinstance(scheduler, torch.optim.lr_scheduler.CosineAnnealingLR)
    steps_per_epoch = nsamples // cali_bsz
    assert scheduler.T_max == epochs * steps_per_epoch
    assert scheduler.eta_min == pytest.approx(base_lr * 1e-3)
def test_build_lr_scheduler_step_uses_args_step_and_gamma():
    """``lr_scheduler="step"`` builds StepLR from lr_step_size and lr_gamma."""
    scheduler = factory.build_lr_scheduler(
        SimpleNamespace(lr_scheduler="step", lr_step_size=3, lr_gamma=0.5),
        _adam(),
    )
    assert isinstance(scheduler, torch.optim.lr_scheduler.StepLR)
    assert scheduler.step_size == 3
    assert scheduler.gamma == pytest.approx(0.5)
def test_build_lr_scheduler_step_clamps_zero_step_size_to_one():
    """A zero ``lr_step_size`` is raised to a step size of 1."""
    scheduler = factory.build_lr_scheduler(
        SimpleNamespace(lr_scheduler="step", lr_step_size=0), _adam()
    )
    assert scheduler.step_size == 1
def test_build_lr_scheduler_unknown_raises():
    """An unrecognised scheduler name is rejected with a ValueError naming it."""
    optimizer = _adam()
    with pytest.raises(ValueError, match="Unsupported lr scheduler 'warmup'"):
        factory.build_lr_scheduler(SimpleNamespace(lr_scheduler="warmup"), optimizer)