"""
Test module for transfer.py from tools/load_ms_weights_to_pt/
Tests the transfer and patch functionality
"""
import pytest
import sys
import os
from unittest.mock import patch
# Resolve the tool directory relative to this test file and put it first on
# sys.path so the ``from transfer import ...`` statement below can find it.
tools_path = os.path.join(os.path.dirname(__file__), '../../../tools/load_ms_weights_to_pt')
sys.path.insert(0, tools_path)
from transfer import (
transfer_load,
copy_weights_transfer_tool_file,
patch_torch_load,
patch_texts
)
class TestCopyWeightsTransferToolFile:
    """Checks for copy_weights_transfer_tool_file using temporary directory trees.

    Every test replaces ``os.path.dirname`` / ``os.path.abspath`` (as reached
    through the ``transfer`` module) with fixed return values pointing at a
    scratch "source" directory.
    """

    def test_copy_weights_transfer_tool_file_success(self, tmp_path):
        """Both tool files end up, byte-for-byte, in the training directory."""
        src = tmp_path / "source"
        src.mkdir()
        (src / "checkpointing.py").write_text("# checkpointing code")
        (src / "serialization.py").write_text("# serialization code")

        dest_root = tmp_path / "target"
        dest_training = dest_root / "mindspeed_llm" / "mindspore" / "training"
        dest_training.mkdir(parents=True)

        with patch('transfer.os.path.dirname', return_value=str(src)), \
                patch('transfer.os.path.abspath', return_value=str(src / "transfer.py")):
            copy_weights_transfer_tool_file(str(dest_root))

        copied_checkpointing = dest_training / "checkpointing.py"
        copied_serialization = dest_training / "serialization.py"
        assert copied_checkpointing.exists()
        assert copied_serialization.exists()
        assert copied_checkpointing.read_text() == "# checkpointing code"
        assert copied_serialization.read_text() == "# serialization code"

    def test_copy_weights_transfer_tool_file_missing_checkpointing(self, tmp_path):
        """A source directory without checkpointing.py raises FileNotFoundError."""
        src = tmp_path / "source"
        src.mkdir()
        # Only serialization.py is provided on purpose.
        (src / "serialization.py").write_text("# serialization code")

        dest_root = tmp_path / "target"
        (dest_root / "mindspeed_llm" / "mindspore" / "training").mkdir(parents=True)

        with patch('transfer.os.path.dirname', return_value=str(src)), \
                patch('transfer.os.path.abspath', return_value=str(src / "transfer.py")), \
                pytest.raises(FileNotFoundError, match="checkpointing.py does not exist"):
            copy_weights_transfer_tool_file(str(dest_root))

    def test_copy_weights_transfer_tool_file_missing_serialization(self, tmp_path):
        """A source directory without serialization.py raises FileNotFoundError."""
        src = tmp_path / "source"
        src.mkdir()
        # Only checkpointing.py is provided on purpose.
        (src / "checkpointing.py").write_text("# checkpointing code")

        dest_root = tmp_path / "target"
        (dest_root / "mindspeed_llm" / "mindspore" / "training").mkdir(parents=True)

        with patch('transfer.os.path.dirname', return_value=str(src)), \
                patch('transfer.os.path.abspath', return_value=str(src / "transfer.py")), \
                pytest.raises(FileNotFoundError, match="serialization.py does not exist"):
            copy_weights_transfer_tool_file(str(dest_root))

    def test_copy_weights_transfer_tool_file_missing_target_directory(self, tmp_path):
        """A target path that was never created raises FileNotFoundError."""
        src = tmp_path / "source"
        src.mkdir()
        (src / "checkpointing.py").write_text("# checkpointing code")
        (src / "serialization.py").write_text("# serialization code")

        # Deliberately not created on disk.
        absent_root = tmp_path / "nonexistent"

        with patch('transfer.os.path.dirname', return_value=str(src)), \
                patch('transfer.os.path.abspath', return_value=str(src / "transfer.py")), \
                pytest.raises(FileNotFoundError, match="does not exist"):
            copy_weights_transfer_tool_file(str(absent_root))
class TestPatchTorchLoad:
    """Checks for patch_torch_load, which rewrites tasks/megatron_adaptor.py."""

    def test_patch_torch_load_success(self, tmp_path):
        """The import and the torch.load registration appear after patching."""
        tasks_dir = tmp_path / "mindspeed_llm" / "tasks"
        tasks_dir.mkdir(parents=True)
        adaptor_path = tasks_dir / "megatron_adaptor.py"
        adaptor_path.write_text(
            "class MegatronAdaptor:\n"
            "def patch_datasets(self):\n"
            "pass"
        )

        patch_torch_load(str(tmp_path))

        result = adaptor_path.read_text()
        assert "from mindspeed_llm.mindspore.training.checkpointing import load_wrapper" in result
        assert "MegatronAdaptation.register('torch.load', load_wrapper)" in result

    def test_patch_torch_load_file_not_found(self, tmp_path):
        """An empty tree (no megatron_adaptor.py) raises FileNotFoundError."""
        with pytest.raises(FileNotFoundError, match="megatron_adaptor.py does not exist"):
            patch_torch_load(str(tmp_path))

    def test_patch_torch_load_pattern_not_found(self, tmp_path):
        """An adaptor file lacking patch_datasets raises ValueError."""
        tasks_dir = tmp_path / "mindspeed_llm" / "tasks"
        tasks_dir.mkdir(parents=True)
        adaptor_path = tasks_dir / "megatron_adaptor.py"
        # No patch_datasets definition here, so there is nothing to replace.
        adaptor_path.write_text(
            "class MegatronAdaptor:\n"
            "def some_other_method(self):\n"
            "pass"
        )

        with pytest.raises(ValueError, match="replace fail"):
            patch_torch_load(str(tmp_path))
class TestTransferLoad:
    """Checks for transfer_load, the entry point that copies files then patches.

    The tests replace ``os.path.dirname`` / ``os.path.abspath`` (as reached
    through the ``transfer`` module) with fixed values pointing at a scratch
    "source" directory holding stand-in tool files.
    """

    def test_transfer_load_integration(self, tmp_path):
        """Full workflow: tool files are copied and the adaptor is patched."""
        source_dir = tmp_path / "source"
        source_dir.mkdir()
        (source_dir / "checkpointing.py").write_text("# checkpointing code")
        (source_dir / "serialization.py").write_text("# serialization code")

        target_dir = tmp_path / "target"
        training_dir = target_dir / "mindspeed_llm" / "mindspore" / "training"
        training_dir.mkdir(parents=True)
        tasks_dir = target_dir / "mindspeed_llm" / "tasks"
        tasks_dir.mkdir(parents=True)
        adaptor_file = tasks_dir / "megatron_adaptor.py"
        adaptor_file.write_text(
            "class MegatronAdaptor:\n"
            "def patch_datasets(self):\n"
            "pass"
        )

        with patch('transfer.os.path.dirname', return_value=str(source_dir)), \
                patch('transfer.os.path.abspath', return_value=str(source_dir / "transfer.py")):
            transfer_load(str(target_dir))

        assert (training_dir / "checkpointing.py").exists()
        assert (training_dir / "serialization.py").exists()
        assert "load_wrapper" in adaptor_file.read_text()

    def test_transfer_load_with_copy_failure(self, tmp_path):
        """An error raised by the copy step propagates out of transfer_load."""
        target_dir = tmp_path / "target"
        with patch('transfer.copy_weights_transfer_tool_file',
                   side_effect=FileNotFoundError("Test error")), \
                pytest.raises(FileNotFoundError, match="Test error"):
            transfer_load(str(target_dir))

    def test_transfer_load_with_patch_failure(self, tmp_path):
        """A missing megatron_adaptor.py makes transfer_load raise.

        The copy prerequisites are all present; only the ``tasks`` directory
        and adaptor file are absent. The expected message is matched so the
        test cannot pass merely because the copy step raised a
        FileNotFoundError of its own.
        """
        source_dir = tmp_path / "source"
        source_dir.mkdir()
        (source_dir / "checkpointing.py").write_text("# checkpointing code")
        (source_dir / "serialization.py").write_text("# serialization code")

        target_dir = tmp_path / "target"
        training_dir = target_dir / "mindspeed_llm" / "mindspore" / "training"
        training_dir.mkdir(parents=True)

        with patch('transfer.os.path.dirname', return_value=str(source_dir)), \
                patch('transfer.os.path.abspath', return_value=str(source_dir / "transfer.py")), \
                pytest.raises(FileNotFoundError, match=r"megatron_adaptor\.py does not exist"):
            transfer_load(str(target_dir))
class TestPatchTexts:
    """Sanity checks on the module-level ``patch_texts`` replacement snippet."""

    def test_patch_texts_format(self):
        """patch_texts is a string containing each expected source fragment."""
        assert isinstance(patch_texts, str)
        expected_fragments = (
            "def patch_datasets(self):",
            "from mindspeed_llm.mindspore.training.checkpointing import load_wrapper",
            "MegatronAdaptation.register('torch.load', load_wrapper)",
        )
        for fragment in expected_fragments:
            assert fragment in patch_texts
class TestCommandLineInterface:
    """Simulates a command-line style invocation of transfer_load.

    The test builds its own argparse parser with a ``--mindspeed_llm_path``
    option and feeds the parsed value to transfer_load.
    """

    def test_main_with_valid_args(self, tmp_path):
        """A path supplied via sys.argv drives a complete transfer_load run."""
        import argparse

        source_dir = tmp_path / "source"
        source_dir.mkdir()
        (source_dir / "checkpointing.py").write_text("# checkpointing code")
        (source_dir / "serialization.py").write_text("# serialization code")

        target_dir = tmp_path / "target"
        training_dir = target_dir / "mindspeed_llm" / "mindspore" / "training"
        training_dir.mkdir(parents=True)
        tasks_dir = target_dir / "mindspeed_llm" / "tasks"
        tasks_dir.mkdir(parents=True)
        adaptor_file = tasks_dir / "megatron_adaptor.py"
        adaptor_file.write_text(
            "class MegatronAdaptor:\n"
            "def patch_datasets(self):\n"
            "pass"
        )

        test_args = ['transfer.py', '--mindspeed_llm_path', str(target_dir)]
        with patch('sys.argv', test_args):
            # Parse from the patched sys.argv (no explicit list) so the argv
            # substitution is actually exercised, and do it before os.path is
            # patched so argparse never runs against the stubbed functions.
            parser = argparse.ArgumentParser()
            parser.add_argument("--mindspeed_llm_path", type=str, required=True)
            args = parser.parse_args()

            with patch('transfer.os.path.dirname', return_value=str(source_dir)), \
                    patch('transfer.os.path.abspath', return_value=str(source_dir / "transfer.py")):
                transfer_load(args.mindspeed_llm_path)

        assert args.mindspeed_llm_path == str(target_dir)
        assert (training_dir / "checkpointing.py").exists()
        assert (training_dir / "serialization.py").exists()
        assert "load_wrapper" in adaptor_file.read_text()
class TestEdgeCases:
    """Less common situations: read-only inputs, non-ASCII text, symlinked targets."""

    def test_copy_with_readonly_source(self, tmp_path):
        """Source files with mode 0o444 are still copied to the target."""
        src = tmp_path / "source"
        src.mkdir()
        ckpt_src = src / "checkpointing.py"
        ser_src = src / "serialization.py"
        ckpt_src.write_text("# checkpointing code")
        ser_src.write_text("# serialization code")
        # Strip write permission from both inputs.
        os.chmod(ckpt_src, 0o444)
        os.chmod(ser_src, 0o444)

        dest_root = tmp_path / "target"
        dest_training = dest_root / "mindspeed_llm" / "mindspore" / "training"
        dest_training.mkdir(parents=True)

        with patch('transfer.os.path.dirname', return_value=str(src)), \
                patch('transfer.os.path.abspath', return_value=str(src / "transfer.py")):
            copy_weights_transfer_tool_file(str(dest_root))

        assert (dest_training / "checkpointing.py").exists()

    def test_patch_with_unicode_content(self, tmp_path):
        """Non-ASCII text already in the adaptor file survives patching."""
        tasks_dir = tmp_path / "mindspeed_llm" / "tasks"
        tasks_dir.mkdir(parents=True)
        adaptor_path = tasks_dir / "megatron_adaptor.py"
        adaptor_path.write_text(
            "class MegatronAdaptor:\n"
            "def patch_datasets(self):\n"
            "# 中文注释\n"
            "pass",
            encoding='utf-8',
        )

        patch_torch_load(str(tmp_path))

        assert "中文注释" in adaptor_path.read_text(encoding='utf-8')

    def test_transfer_with_symlinks(self, tmp_path):
        """Copying through a symlinked target root lands in the real directory."""
        if os.name == 'nt':
            pytest.skip("Symlinks not reliably supported on Windows")

        src = tmp_path / "source"
        src.mkdir()
        (src / "checkpointing.py").write_text("# checkpointing code")
        (src / "serialization.py").write_text("# serialization code")

        link_root = tmp_path / "target"
        real_root = tmp_path / "actual"
        real_training = real_root / "mindspeed_llm" / "mindspore" / "training"
        real_training.mkdir(parents=True)

        # "target" is a symlink to "actual"; skip if the platform refuses.
        try:
            os.symlink(real_root, link_root)
        except (OSError, NotImplementedError):
            pytest.skip("Symlinks not supported on this system")

        with patch('transfer.os.path.dirname', return_value=str(src)), \
                patch('transfer.os.path.abspath', return_value=str(src / "transfer.py")):
            copy_weights_transfer_tool_file(str(link_root))

        assert (real_training / "checkpointing.py").exists()
if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file directly as a
    # script reports a non-zero status when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))