import unittest
import torch
import torch.nn as nn
from msmodelslim.utils.memory import align_input_to_module_device_hook, register_device_alignment_hook, \
unregister_device_alignment_hook
class MockModule(nn.Module):
    """Minimal module used as a hook target in the tests below.

    It owns exactly one sub-module, ``linear`` (a 10-in / 10-out
    ``nn.Linear``), whose weight is what the tests use to look up the
    device the module lives on.
    """

    def __init__(self):
        super(MockModule, self).__init__()
        self.linear = nn.Linear(in_features=10, out_features=10)

    def forward(self, x):
        """Run ``x`` through the linear layer and return the result."""
        projected = self.linear(x)
        return projected
class TestDeviceAlignmentHook(unittest.TestCase):
    """Tests for the device-alignment hook helpers of ``msmodelslim.utils.memory``.

    Three entry points are exercised:

    * ``align_input_to_module_device_hook(module, args, kwargs)`` -- called
      directly; the tests unpack its result as ``(args, kwargs)``.
    * ``register_device_alignment_hook(module, ...)`` -- expected to return a
      dict containing the key ``'pre_hook'`` and to set bookkeeping
      attributes on the module.
    * ``unregister_device_alignment_hook(module, ...)`` -- expected to remove
      those bookkeeping attributes again.

    Every test runs on CPU only, so no accelerator is required.
    """

    def setUp(self):
        """Build a fresh ``MockModule`` on the CPU for each test."""
        self.module = MockModule()
        self.device = torch.device('cpu')
        self.module.to(self.device)

    # ------------------------------------------------------------------
    # Shared assertion helpers
    # ------------------------------------------------------------------
    def _assert_on_module_device(self, tensor):
        """Assert that ``tensor`` sits on the device of the module's weight."""
        self.assertEqual(tensor.device, self.module.linear.weight.device)

    def _assert_valid_handle(self, hook_handle):
        """Assert that ``hook_handle`` is a dict exposing a ``'pre_hook'`` entry."""
        self.assertIsNotNone(hook_handle)
        self.assertIsInstance(hook_handle, dict)
        self.assertIn('pre_hook', hook_handle)

    def _assert_inputs_untouched(self, result_args, result_kwargs, args, kwargs):
        """Assert that the hook handed back exactly the objects it was given.

        Positional entries are compared by identity on purpose: ``==`` on a
        tuple holding multi-element tensors either short-circuits on identity
        or raises "Boolean value of Tensor ... is ambiguous", so identity is
        the only outcome under which an equality check could ever pass, and
        checking it directly yields a clean assertion failure instead of an
        unrelated ``RuntimeError``.
        """
        self.assertIsInstance(result_args, type(args))
        self.assertEqual(len(result_args), len(args))
        for returned, original in zip(result_args, args):
            self.assertIs(returned, original)
        self.assertEqual(result_kwargs, kwargs)

    # ------------------------------------------------------------------
    # Direct calls to align_input_to_module_device_hook
    # ------------------------------------------------------------------
    def test_align_input_to_module_device_hook(self):
        """A single positional tensor ends up on the module's device."""
        input_tensor = torch.randn(5, 10).to(torch.device('cpu'))
        aligned_args, _ = align_input_to_module_device_hook(self.module, (input_tensor,), {})
        self._assert_on_module_device(aligned_args[0])

    # ------------------------------------------------------------------
    # Registration / unregistration
    # ------------------------------------------------------------------
    def test_register_device_alignment_hook(self):
        """Registering sets the bookkeeping attributes and returns a handle dict."""
        hook_handle = register_device_alignment_hook(self.module)
        # Tear the hook down even if an assertion below fails, matching the
        # other registration tests which all unregister what they register.
        self.addCleanup(unregister_device_alignment_hook, self.module)
        self.assertTrue(hasattr(self.module, '_device_alignment_hooks_registered'))
        self.assertTrue(self.module._device_alignment_hooks_registered)
        self.assertTrue(hasattr(self.module, '_device_alignment_pre_hook_handle'))
        self._assert_valid_handle(hook_handle)

    def test_register_device_alignment_hook_with_name(self):
        """Same as the plain registration test, but passing an explicit ``name``."""
        custom_name = "MockModule"
        hook_handle = register_device_alignment_hook(self.module, name=custom_name)
        self.addCleanup(unregister_device_alignment_hook, self.module, name=custom_name)
        self.assertTrue(hasattr(self.module, '_device_alignment_hooks_registered'))
        self.assertTrue(self.module._device_alignment_hooks_registered)
        self.assertTrue(hasattr(self.module, '_device_alignment_pre_hook_handle'))
        self._assert_valid_handle(hook_handle)

    def test_unregister_device_alignment_hook(self):
        """Unregistering removes both bookkeeping attributes from the module."""
        register_device_alignment_hook(self.module)
        self.assertTrue(hasattr(self.module, '_device_alignment_hooks_registered'))
        unregister_device_alignment_hook(self.module)
        self.assertFalse(hasattr(self.module, '_device_alignment_hooks_registered'))
        self.assertFalse(hasattr(self.module, '_device_alignment_pre_hook_handle'))

    def test_unregister_device_alignment_hook_with_name(self):
        """Unregistering with an explicit ``name`` also removes both attributes."""
        custom_name = "MockModule"
        register_device_alignment_hook(self.module, name=custom_name)
        self.assertTrue(hasattr(self.module, '_device_alignment_hooks_registered'))
        unregister_device_alignment_hook(self.module, name=custom_name)
        self.assertFalse(hasattr(self.module, '_device_alignment_hooks_registered'))
        self.assertFalse(hasattr(self.module, '_device_alignment_pre_hook_handle'))

    def test_hook_prevents_duplicate_registration(self):
        """Registering twice yields handles that compare equal."""
        handle1 = register_device_alignment_hook(self.module)
        self.addCleanup(unregister_device_alignment_hook, self.module)
        handle2 = register_device_alignment_hook(self.module)
        self.assertEqual(handle1, handle2)

    def test_hook_prevents_duplicate_registration_with_name(self):
        """Registering twice under the same ``name`` yields equal handles."""
        custom_name = "MockModule"
        handle1 = register_device_alignment_hook(self.module, name=custom_name)
        self.addCleanup(unregister_device_alignment_hook, self.module, name=custom_name)
        handle2 = register_device_alignment_hook(self.module, name=custom_name)
        self.assertEqual(handle1, handle2)

    def test_hook_with_kwargs(self):
        """Registration with ``with_kwargs=True`` still returns a valid handle."""
        hook_handle = register_device_alignment_hook(self.module, with_kwargs=True)
        self.addCleanup(unregister_device_alignment_hook, self.module)
        self.assertTrue(hasattr(self.module, '_device_alignment_hooks_registered'))
        self._assert_valid_handle(hook_handle)

    def test_register_hook_with_none_module(self):
        """Registering on ``None`` returns ``None``."""
        result = register_device_alignment_hook(None)
        self.assertIsNone(result)

    # ------------------------------------------------------------------
    # Nested / unusual inputs
    # ------------------------------------------------------------------
    def test_hook_with_complex_input(self):
        """Tensors nested inside a dict, a list and a tuple are all aligned."""
        input_tensor1 = torch.randn(5, 10).to(torch.device('cpu'))
        input_tensor2 = torch.randn(5, 10).to(torch.device('cpu'))
        complex_input = {
            'tensor1': input_tensor1,
            'tensor2': input_tensor2,
            'list_data': [input_tensor1, input_tensor2],
            'tuple_data': (input_tensor1, input_tensor2),
        }
        aligned_args, _ = align_input_to_module_device_hook(self.module, (complex_input,), {})
        aligned_complex_input = aligned_args[0]
        self._assert_on_module_device(aligned_complex_input['tensor1'])
        self._assert_on_module_device(aligned_complex_input['tensor2'])
        for container_key in ('list_data', 'tuple_data'):
            for position in (0, 1):
                self._assert_on_module_device(aligned_complex_input[container_key][position])

    def test_hook_with_no_parameters_module(self):
        """A module without parameters gets its inputs back unchanged."""
        class NoParamModule(nn.Module):
            def forward(self, x):
                return x

        no_param_module = NoParamModule()
        input_tensor = torch.randn(5, 10)
        args = (input_tensor,)
        kwargs = {}
        aligned_args, aligned_kwargs = align_input_to_module_device_hook(no_param_module, args, kwargs)
        self._assert_inputs_untouched(aligned_args, aligned_kwargs, args, kwargs)

    def test_hook_with_none_input(self):
        """Passing ``None`` as the module hands the inputs back unchanged."""
        args = (torch.randn(5, 10),)
        kwargs = {}
        result_args, result_kwargs = align_input_to_module_device_hook(None, args, kwargs)
        self._assert_inputs_untouched(result_args, result_kwargs, args, kwargs)

    def test_hook_with_mixed_input_types(self):
        """Tensors are aligned while strings and ints pass through by value."""
        tensor_input = torch.randn(5, 10).cpu()
        string_input = "test_string"
        int_input = 42
        list_input = [tensor_input, string_input, int_input]
        args = (tensor_input, string_input, int_input, list_input)
        kwargs = {'tensor': tensor_input, 'string': string_input}
        aligned_args, aligned_kwargs = align_input_to_module_device_hook(self.module, args, kwargs)

        # Top-level positional arguments.
        self._assert_on_module_device(aligned_args[0])
        self.assertEqual(aligned_args[1], string_input)
        self.assertEqual(aligned_args[2], int_input)

        # The nested list carries the same three kinds of value.
        aligned_list = aligned_args[3]
        self._assert_on_module_device(aligned_list[0])
        self.assertEqual(aligned_list[1], string_input)
        self.assertEqual(aligned_list[2], int_input)

        # Keyword arguments.
        self._assert_on_module_device(aligned_kwargs['tensor'])
        self.assertEqual(aligned_kwargs['string'], string_input)

    # ------------------------------------------------------------------
    # "Statistics" scenarios
    # NOTE(review): despite their names, these tests assert device placement
    # only; nothing about statistics is checked here.
    # ------------------------------------------------------------------
    def test_hook_data_statistics(self):
        """Several differently sized tensors, one already on the target device."""
        cpu_tensor1 = torch.randn(100, 100).cpu()
        cpu_tensor2 = torch.randn(50, 50).cpu()
        cpu_tensor3 = torch.randn(10, 10).cpu()
        complex_input = {
            'tensor1': cpu_tensor1,
            'tensor2': cpu_tensor2,
            'list_data': [cpu_tensor3],
            'already_on_device': torch.randn(5, 5).to(self.module.linear.weight.device),
        }
        aligned_args, _ = align_input_to_module_device_hook(self.module, (complex_input,), {})
        aligned_complex_input = aligned_args[0]
        self._assert_on_module_device(aligned_complex_input['tensor1'])
        self._assert_on_module_device(aligned_complex_input['tensor2'])
        self._assert_on_module_device(aligned_complex_input['list_data'][0])
        self._assert_on_module_device(aligned_complex_input['already_on_device'])

    def test_hook_no_movement_statistics(self):
        """A tensor already on the module's device is still reported there."""
        correct_device_tensor = torch.randn(10, 10).to(self.module.linear.weight.device)
        aligned_args, _ = align_input_to_module_device_hook(self.module, (correct_device_tensor,), {})
        self._assert_on_module_device(aligned_args[0])

    def test_hook_large_tensor_statistics(self):
        """A 1000x1000 tensor ends up on the module's device."""
        large_tensor = torch.randn(1000, 1000).cpu()
        aligned_args, _ = align_input_to_module_device_hook(self.module, (large_tensor,), {})
        self._assert_on_module_device(aligned_args[0])
# Allow running this file directly (``python <file>``) in addition to
# discovery through a test runner.
if __name__ == '__main__':
    unittest.main()