import torch
from torch import nn
from amct_pytorch.common.utils.model_util import ModuleHelper
def _build_model():
    """Return a small two-level ``nn.Sequential`` shared by the tests below.

    Child names, as produced by ``named_modules()``:
      "0"   -> nn.Linear(4, 4)
      "1"   -> nn.Sequential
      "1.0" -> nn.Linear(4, 8)
      "1.1" -> nn.ReLU
    """
    # Build the children in their original order (head first) so that
    # parameter initialisation draws from the RNG in the same sequence.
    head = nn.Linear(4, 4)
    tail = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
    return nn.Sequential(head, tail)
def test_named_module_dict_collects_all_submodules():
    """``named_module_dict`` is keyed by exactly the names ``named_modules()`` yields."""
    net = _build_model()
    helper = ModuleHelper(net)
    # dict(...) over the (name, module) pairs keeps just the names as keys.
    wanted = set(dict(net.named_modules()))
    assert wanted == set(helper.named_module_dict)
def test_named_module_dict_holds_module_references():
    """Values in ``named_module_dict`` are the model's own submodule objects."""
    net = _build_model()
    helper = ModuleHelper(net)
    # Identity (``is``) rather than equality: the dict must not hold copies.
    expectations = {"0": net[0], "1.0": net[1][0]}
    for key, module in expectations.items():
        assert helper.named_module_dict[key] is module
def test_replace_module_by_name_top_level():
    """Replacing a direct child by name swaps it both by index and in ``forward``."""
    model = nn.Sequential(nn.Linear(4, 4))
    new = nn.Linear(4, 8)
    ModuleHelper.replace_module_by_name(model, "0", new)
    assert model[0] is new
    # Mirror the nested test: prove the forward pass actually routes through
    # the replacement. Linear(4, 8) widens the output from 4 to 8 features.
    out = model(torch.randn(1, 4))
    assert out.shape == (1, 8)
def test_replace_module_by_name_nested():
    """A dotted name reaches into a nested container and swaps that child."""
    net = _build_model()
    widened = nn.Linear(4, 16)
    ModuleHelper.replace_module_by_name(net, "1.0", widened)
    inner = net[1]
    assert inner[0] is widened
    # Linear(4, 4) -> Linear(4, 16) -> ReLU, so the output width becomes 16.
    result = net(torch.randn(1, 4))
    assert tuple(result.shape) == (1, 16)