import unittest
from typing import Tuple, Optional
from unittest.mock import patch
import torch
import torch.nn as nn
from msmodelslim.core.base.protocol import ProcessRequest
from msmodelslim.model.common.layer_wise_forward import (
TransformersForwardBreak,
generated_decoder_layer_visit_func,
transformers_generated_forward_func
)
class DummyDecoderLayer(nn.Module):
    """Parameter-free stand-in for a decoder layer used only by the tests.

    ``forward`` hands back a 2-tuple whose first element is the very tensor
    it received (no copy, no computation) and whose second element is
    always ``None``.
    """

    def __init__(self):
        super(DummyDecoderLayer, self).__init__()

    def forward(self, x: torch.Tensor) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        passthrough = x
        return passthrough, None
class SimpleModel(nn.Module):
    """Minimal test model that wraps a single ``DummyDecoderLayer``.

    The layer is registered as the ``decoder`` submodule, so it appears in
    ``named_modules()`` under the name ``"decoder"``. ``forward`` delegates
    to it and returns its output unchanged. The model has no other layers
    and no parameters.
    """

    def __init__(self):
        super().__init__()
        # Registered under the attribute name "decoder"; the tests in this
        # module look the layer up by that name.
        self.decoder = DummyDecoderLayer()

    def forward(self, x: torch.Tensor) \
            -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        # Pass the decoder layer's output straight through to the caller.
        return self.decoder(x)
class TestTransformersGenerated(unittest.TestCase):
    """Tests for ``TransformersForwardBreak`` and the layer-wise generator helpers.

    Covers ``generated_decoder_layer_visit_func`` (with default and with
    explicitly supplied ``transformer_blocks``) and
    ``transformers_generated_forward_func``. Each generator is advanced only
    to its first ``ProcessRequest``.
    """

    def setUp(self):
        """Build a fresh ``SimpleModel`` for every test."""
        self.model = SimpleModel()

    def test_transformers_forward_break(self):
        """``TransformersForwardBreak`` instances are ``Exception`` instances."""
        exception = TransformersForwardBreak()
        self.assertIsInstance(exception, Exception)

    def test_transformers_generated_visit_func(self):
        """``generated_decoder_layer_visit_func`` first yields a request for the decoder layer."""
        generator = generated_decoder_layer_visit_func(self.model)
        request = next(generator)
        self.assertIsInstance(request, ProcessRequest)
        self.assertEqual(request.name, "decoder")
        # Identity, not equality: the request must reference the model's own layer object.
        self.assertIs(request.module, self.model.decoder)

    def test_transformers_generated_visit_func_with_custom_blocks(self):
        """``generated_decoder_layer_visit_func`` honours an explicit ``transformer_blocks`` list."""
        transformer_blocks = [(name, module)
                              for name, module in self.model.named_modules()
                              if "decoder" in module.__class__.__name__.lower()]
        # Guard the test input itself: the class-name filter must select exactly
        # the decoder layer, otherwise the assertions below would be vacuous.
        self.assertEqual([name for name, _ in transformer_blocks], ["decoder"])

        generator = generated_decoder_layer_visit_func(self.model, transformer_blocks)
        request = next(generator)
        self.assertIsInstance(request, ProcessRequest)
        self.assertEqual(request.name, "decoder")
        self.assertIs(request.module, self.model.decoder)

    # Decorators apply bottom-up, so the mocks arrive as (barrier, get_world_size, get_rank).
    @patch('torch.distributed.get_rank')
    @patch('torch.distributed.get_world_size')
    @patch('torch.distributed.barrier')
    def test_transformers_generated_forward_func(self, mock_barrier, mock_get_world_size, mock_get_rank):
        """``transformers_generated_forward_func`` first yields the decoder request carrying the model input."""
        # Pretend to be rank 0 in a single-process group so no real
        # distributed backend is required.
        mock_get_rank.return_value = 0
        mock_get_world_size.return_value = 1

        input_data = torch.randn(5, 10)
        generator = transformers_generated_forward_func(self.model, input_data)
        request = next(generator)
        self.assertIsInstance(request, ProcessRequest)
        self.assertEqual(request.name, "decoder")
        self.assertIs(request.module, self.model.decoder)
        # The decoder is the first layer, so it must receive the raw model input.
        self.assertTrue(torch.equal(request.args[0], input_data))
if __name__ == "__main__":
    # Discover and run every TestCase in this module when executed as a script.
    unittest.main()