"""
Tests for convert_utils module.
These tests verify the utility functions for module/function/method replacement.
"""
import torch.nn as nn
from torchtitan_npu.converters.convert_utils import (
find_functions,
find_methods,
find_modules,
replace_functions,
replace_methods,
replace_modules,
)
class SimpleModule(nn.Module):
    """Minimal fixture module with one ``nn.Linear`` and one ``nn.Conv2d`` child.

    Only ``linear`` participates in ``forward``; ``conv`` exists so pattern
    searches have a second, differently-typed child to match or skip.
    """

    def __init__(self):
        super(SimpleModule, self).__init__()
        # Children are registered linear-first, then conv; keep this order so
        # any order-sensitive traversal sees them the same way.
        self.linear = nn.Linear(in_features=10, out_features=10)
        self.conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)

    def forward(self, x):
        """Project ``x`` through the linear layer (last dimension must be 10)."""
        projected = self.linear(x)
        return projected
def test_find_linear_modules():
    """find_modules locates the single nn.Linear child and reports its attribute name."""
    net = SimpleModule()
    hits = find_modules(net, r"Linear")

    assert len(hits) == 1
    hit = hits[0]
    assert isinstance(hit.module, nn.Linear)
    assert hit.attr_name == "linear"
def test_find_conv_modules():
    """find_modules also matches the nn.Conv2d child by class-name pattern."""
    net = SimpleModule()
    hits = find_modules(net, r"Conv2d")

    assert len(hits) == 1
    only_hit = hits[0]
    assert isinstance(only_hit.module, nn.Conv2d)
def test_find_no_matches():
    """A pattern naming no child class yields an empty result."""
    net = SimpleModule()
    hits = find_modules(net, r"NonExistent")
    assert not hits
def test_module_match_is_meta():
    """A freshly constructed (non-meta) child is reported with ``is_meta`` False."""
    net = SimpleModule()
    first_hit = find_modules(net, r"Linear")[0]
    assert first_hit.is_meta is False
def test_replace_module():
    """ModuleMatch.replace installs the given module on the parent in place."""
    net = SimpleModule()
    target = find_modules(net, r"Linear")[0]
    substitute = nn.Linear(10, 20)

    target.replace(substitute)

    assert net.linear is substitute
def test_find_functions_in_package():
    """find_functions locates ``find_modules`` under the converters package."""
    # convert_utils is already imported at the top of this file, so it is
    # loaded before find_functions runs; no local import is needed here.
    matches = find_functions("find_modules", package="torchtitan_npu.converters")

    assert len(matches) >= 1
    assert matches[0].func_name == "find_modules"
def test_find_methods_in_class():
    """find_methods locates SimpleModule.forward inside the named test module."""
    # Imported only for its side effect: make sure this file is loaded under
    # the dotted name passed as ``package`` below (pytest may have imported
    # it under a different module name).
    import tests.unit_tests.functions.test_convert_utils  # noqa: F401

    target_package = "tests.unit_tests.functions.test_convert_utils"
    hits = find_methods("SimpleModule", "forward", package=target_package)

    assert len(hits) == 1
    hit = hits[0]
    assert hit.class_name == "SimpleModule"
    assert hit.method_name == "forward"
def test_replace_linear_modules():
    """replace_modules swaps the matching child for the factory's output.

    Checks the observable effect (the new module is installed and the
    unmatched Conv2d child is untouched), not just the returned count.
    """
    model = SimpleModule()
    original_linear = model.linear
    original_conv = model.conv
    created = []

    def factory(old_module):
        new_module = nn.Linear(10, 10)
        created.append(new_module)
        return new_module

    count = replace_modules(model, r"Linear", factory)

    assert count == 1
    assert len(created) == 1
    assert model.linear is created[0]
    assert model.linear is not original_linear
    # The pattern does not match Conv2d, so that child must be left alone.
    assert model.conv is original_conv
def test_replace_returns_count():
    """replace_modules reports how many children it replaced."""
    net = SimpleModule()

    def make_linear(_old):
        return nn.Linear(10, 10)

    replaced = replace_modules(net, r"Linear", make_linear)
    assert replaced == 1
def test_replace_functions_replaces_loaded_function():
    """replace_functions rebinds ``find_modules`` in loaded package modules.

    Every binding the call rewrites is restored in a ``finally`` block. This
    way a failing assertion, or a re-export of ``find_modules`` elsewhere in
    the package, cannot leak the stub into later tests.
    """
    import sys

    import torchtitan_npu.converters.convert_utils as utils_module

    package = "torchtitan_npu.converters"
    prefix = package + "."
    original = utils_module.find_modules

    def replacement(*args, **kwargs):
        return []

    # Snapshot each in-package module's current binding so it can be put
    # back exactly, even if some module held a different ``find_modules``.
    saved = {
        name: getattr(module, "find_modules")
        for name, module in list(sys.modules.items())
        if (name == package or name.startswith(prefix))
        and hasattr(module, "find_modules")
    }

    try:
        count = replace_functions("find_modules", replacement, package=package)
        assert count >= 1
        assert utils_module.find_modules is replacement
    finally:
        # count >= 1 allows several modules to have been patched; undo every
        # in-package binding that now points at the stub, not only convert_utils.
        for name, module in list(sys.modules.items()):
            if not (name == package or name.startswith(prefix)):
                continue
            if getattr(module, "find_modules", None) is replacement:
                module.find_modules = saved.get(name, original)
        utils_module.find_modules = original
def test_replace_methods_replaces_loaded_method():
    """replace_methods rebinds SimpleModule.forward in the named module.

    The original method is restored in a ``finally`` block, so a failing
    assertion cannot leave the patched ``forward`` in place for later tests.
    """
    import tests.unit_tests.functions.test_convert_utils as test_module

    original = test_module.SimpleModule.forward

    def replacement(self, x):
        return x

    try:
        count = replace_methods(
            "SimpleModule",
            "forward",
            replacement,
            package="tests.unit_tests.functions.test_convert_utils",
        )
        assert count == 1
        assert test_module.SimpleModule.forward is replacement
    finally:
        test_module.SimpleModule.forward = original