"""
Test module for checkpointing.py from tools/load_ms_weights_to_pt/
Tests the load_wrapper decorator functionality
"""
import pytest
import sys
import os
from unittest.mock import patch
# Make the tool modules (checkpointing, serialization) importable as top-level
# modules. The path is normalized to an absolute form (no ".." segments) so the
# sys.path entry stays valid even if __file__ was relative or the working
# directory changes later; it is inserted at index 0 to take precedence.
tools_path = os.path.abspath(
    os.path.join(os.path.dirname(__file__), '../../../tools/load_ms_weights_to_pt')
)
sys.path.insert(0, tools_path)
from checkpointing import load_wrapper
from serialization import load_ms_weights
class TestLoadWrapper:
    """Unit tests for the ``load_wrapper`` decorator on small in-memory callables."""

    def test_load_wrapper_success(self):
        """
        Feature: Test load_wrapper
        Description: the wrapped function returns normally and its value is passed through
        Expectation: Success.
        """
        expected = {"data": "loaded successfully"}

        @load_wrapper
        def mock_load_func(file_path):
            return {"data": "loaded successfully"}

        assert mock_load_func("test.pt") == expected

    def test_load_wrapper_with_kwargs(self):
        """
        Feature: Test load_wrapper
        Description: keyword arguments reach the wrapped function unchanged
        Expectation: Success.
        """
        @load_wrapper
        def mock_load_func(file_path, map_location=None, weights_only=False):
            return {"file": file_path, "map_location": map_location, "weights_only": weights_only}

        loaded = mock_load_func("test.pt", map_location="cpu", weights_only=True)

        assert (loaded["file"], loaded["map_location"]) == ("test.pt", "cpu")
        # Identity check on purpose: the flag must be the bool True, not merely truthy.
        assert loaded["weights_only"] is True

    def test_load_wrapper_preserves_function_metadata(self):
        """
        Feature: Test load_wrapper
        Description: the decorated callable keeps the original function's __name__
        Expectation: Success.
        """
        @load_wrapper
        def documented_function(arg1, arg2):
            return arg1 + arg2

        wrapped_name = documented_function.__name__
        assert wrapped_name == "documented_function"

    def test_load_wrapper_with_multiple_args(self):
        """
        Feature: Test load_wrapper
        Description: several positional arguments are forwarded in order
        Expectation: Success.
        """
        @load_wrapper
        def mock_multi_arg_func(arg1, arg2, arg3):
            return arg1 + arg2 + arg3

        assert mock_multi_arg_func(1, 2, 3) == 6

    def test_load_wrapper_with_mixed_args(self):
        """
        Feature: Test load_wrapper
        Description: positional arguments, an overridden keyword and an untouched default combine
        Expectation: Success.
        """
        @load_wrapper
        def mock_mixed_func(a, b, c=10, d=20):
            return a + b + c + d

        # a=1, b=2, c overridden to 5, d keeps its default of 20.
        assert mock_mixed_func(1, 2, c=5) == 1 + 2 + 5 + 20

    def test_load_wrapper_return_none(self):
        """
        Feature: Test load_wrapper
        Description: a None return value is passed through, not replaced
        Expectation: Success.
        """
        @load_wrapper
        def return_none_func():
            return None

        assert return_none_func() is None

    def test_load_wrapper_with_complex_return_type(self):
        """
        Feature: Test load_wrapper
        Description: a nested dict return value comes back intact
        Expectation: Success.
        """
        @load_wrapper
        def complex_return_func():
            return {
                "model": {"layer1": [1, 2, 3], "layer2": [4, 5, 6]},
                "optimizer": {"lr": 0.001, "momentum": 0.9},
                "epoch": 10
            }

        payload = complex_return_func()

        assert isinstance(payload, dict)
        assert "model" in payload
        assert payload["epoch"] == 10

    def test_load_wrapper_exception(self, tmp_path):
        """
        Feature: Test load_wrapper
        Description: the wrapped function raises; the result is expected to come from the patched load_ms_weights
        Expectation: Success.
        """
        checkpoint = tmp_path / "test.pt"
        checkpoint.write_text("test data")

        @load_wrapper
        def mock_load_func(file_path):
            raise Exception("Failed to load")

        fallback = {"data": "loaded from backup"}
        # NOTE(review): patch() only takes effect if this dotted target is the name
        # load_wrapper resolves at call time. This module imports `checkpointing` and
        # `serialization` as top-level modules via sys.path, so confirm this target
        # matches how checkpointing.py imports load_ms_weights.
        with patch('tools.load_ms_weights_to_pt.serialization.load_ms_weights',
                   return_value=fallback):
            outcome = mock_load_func(checkpoint)

        assert outcome == {"data": "loaded from backup"}
class TestLoadWrapperIntegration:
    """Integration tests for load_wrapper"""

    def test_load_wrapper_with_file_operations(self, tmp_path):
        """
        Feature: Test load_wrapper
        Description: the wrapped function reads a real file from disk
        Expectation: Success.
        """
        test_file = tmp_path / "test.pt"
        # Pin the encoding on both the write and the read so the round trip does not
        # depend on the platform's locale default (pylint W1514 / PEP 597).
        test_file.write_text("test data", encoding="utf-8")

        @load_wrapper
        def read_file(path):
            with open(path, 'r', encoding="utf-8") as f:
                return f.read()

        result = read_file(str(test_file))
        assert result == "test data"

    def test_load_wrapper_decorator_chain(self):
        """
        Feature: Test load_wrapper
        Description: load_wrapper composes with an outer decorator that post-processes the result
        Expectation: Success.
        """
        def another_decorator(func):
            # Outer decorator: formats whatever the inner (load_wrapper-decorated) call returns.
            def wrapper(*args, **kwargs):
                result = func(*args, **kwargs)
                return f"decorated: {result}"
            return wrapper

        @another_decorator
        @load_wrapper
        def chained_function(value):
            return value

        result = chained_function("test")
        assert result == "decorated: test"
if __name__ == "__main__":
    # Propagate pytest's exit code so running this file as a script reports
    # failures to the shell/CI instead of always exiting with status 0.
    sys.exit(pytest.main([__file__, "-v"]))