import unittest
from typing import List, Tuple
import torch
import torch.nn as nn
from msmodelslim.model.common.layer_wise_forward import (
generated_decoder_layer_visit_func,
transformers_generated_forward_func,
)
from msmodelslim.core.base.protocol import ProcessRequest
from msmodelslim.utils.exception import InvalidModelError
class DummyDecoderLayer(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(4, 4)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
return (self.linear(x),)
class DummyModel(nn.Module):
def __init__(self, num_layers: int = 2):
super().__init__()
self.layers = nn.ModuleList([DummyDecoderLayer() for _ in range(num_layers)])
def named_modules(self):
for idx, layer in enumerate(self.layers):
yield f"layers.{idx}", layer
def forward(self, x: torch.Tensor) -> torch.Tensor:
for layer in self.layers:
x = layer(x)[0]
return x
class TestLayerWiseForward(unittest.TestCase):
def setUp(self):
torch.manual_seed(0)
self.model = DummyModel(num_layers=3)
self.input = torch.randn(2, 4)
def test_generated_decoder_layer_visit_func_auto_discovery(self):
gen = generated_decoder_layer_visit_func(self.model)
collected: List[ProcessRequest] = list(gen)
self.assertEqual(len(collected), 3)
for _, req in enumerate(collected):
self.assertIsInstance(req, ProcessRequest)
self.assertIn("layers.", req.name)
self.assertIsInstance(req.module, DummyDecoderLayer)
self.assertEqual(req.args, tuple())
self.assertEqual(req.kwargs, {})
def test_generated_decoder_layer_visit_func_with_custom_blocks(self):
blocks = [("x", self.model.layers[1])]
gen = generated_decoder_layer_visit_func(self.model, transformer_blocks=blocks)
reqs = list(gen)
self.assertEqual(len(reqs), 1)
self.assertEqual(reqs[0].name, "x")
self.assertIs(reqs[0].module, self.model.layers[1])
def test_transformers_generated_forward_func_normal_flow(self):
gen = transformers_generated_forward_func(self.model, (self.input,))
req0 = next(gen)
self.assertIsInstance(req0, ProcessRequest)
out0 = req0.module(*req0.args, **req0.kwargs)
req1 = gen.send(out0)
out1 = req1.module(*req1.args, **req1.kwargs)
req2 = gen.send(out1)
out2 = req2.module(*req2.args, **req2.kwargs)
with self.assertRaises(StopIteration):
gen.send(out2)
def test_transformers_generated_forward_func_no_first_input_error(self):
class BadModel(nn.Module):
def named_modules(self):
return []
def forward(self, x):
return x
bad = BadModel()
with self.assertRaises(IndexError):
gen = transformers_generated_forward_func(bad, (self.input,))
next(gen)
def test_generated_decoder_layer_visit_func_when_distributed_then_call_barrier(self):
"""测试generated_decoder_layer_visit_func:分布式环境下应调用barrier"""
from unittest.mock import patch
with patch('msmodelslim.model.common.layer_wise_forward.dist.is_initialized', return_value=True):
with patch('msmodelslim.model.common.layer_wise_forward.dist.barrier') as mock_barrier:
gen = generated_decoder_layer_visit_func(self.model)
list(gen)
mock_barrier.assert_called_once()
def test_transformers_generated_forward_func_when_dict_input_then_use_kwargs(self):
"""测试transformers_generated_forward_func:dict输入时应使用**kwargs"""
from unittest.mock import patch
dict_inputs = {'x': self.input}
with patch('msmodelslim.model.common.layer_wise_forward.TransformersForwardBreak', Exception):
gen = transformers_generated_forward_func(self.model, dict_inputs)
try:
list(gen)
except (InvalidModelError, Exception):
pass
def test_transformers_generated_forward_func_when_single_tensor_input_then_call_directly(self):
"""测试transformers_generated_forward_func:单个tensor输入时应直接调用"""
from unittest.mock import patch
with patch('msmodelslim.model.common.layer_wise_forward.TransformersForwardBreak', Exception):
gen = transformers_generated_forward_func(self.model, self.input)
try:
list(gen)
except (InvalidModelError, Exception):
pass
def test_transformers_generated_forward_func_when_first_block_input_none_then_raise_error(self):
"""测试transformers_generated_forward_func:first_block_input为None时应抛出InvalidModelError"""
from unittest.mock import patch
class NoBreakModel(nn.Module):
def __init__(self):
super().__init__()
self.layer = DummyDecoderLayer()
def named_modules(self):
return [('layer0', self.layer)]
def forward(self, x):
return x
model = NoBreakModel()
gen = transformers_generated_forward_func(model, (self.input,))
with self.assertRaises(InvalidModelError) as context:
list(gen)
self.assertIn("Can't get first block input", str(context.exception))
def test_transformers_generated_forward_func_when_distributed_then_call_barrier(self):
"""测试transformers_generated_forward_func:分布式环境下应调用barrier"""
from unittest.mock import patch
with patch('msmodelslim.model.common.layer_wise_forward.dist.is_initialized', return_value=True):
with patch('msmodelslim.model.common.layer_wise_forward.dist.barrier') as mock_barrier:
gen = transformers_generated_forward_func(self.model, (self.input,))
try:
list(gen)
except (InvalidModelError, Exception):
pass