"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import unittest
from typing import Tuple, Optional
from unittest.mock import patch
import torch
import torch.nn as nn
from msmodelslim.core.base.protocol import ProcessRequest
from msmodelslim.model.common.layer_wise_forward import (
TransformersForwardBreak,
generated_decoder_layer_visit_func,
transformers_generated_forward_func
)
class DummyDecoderLayer(nn.Module):
    """Parameter-free stand-in for a decoder layer, used only by the tests.

    ``forward`` performs no computation: it hands back its input together with
    ``None`` as a 2-tuple. The return annotation suggests the second slot is
    normally an optional pair of tensors (presumably a key/value cache) — it is
    always ``None`` here.
    """

    def __init__(self):
        # Nothing to register: this layer owns no parameters or submodules.
        super().__init__()

    def forward(self, x: torch.Tensor) \
            -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Return ``x`` unchanged, paired with ``None``."""
        passthrough = x
        second_output = None
        return passthrough, second_output
class SimpleModel(nn.Module):
    """Minimal test model that wraps a single ``DummyDecoderLayer``.

    The layer is registered as the child module ``decoder``; the model has no
    other layers or parameters of its own.
    """

    def __init__(self):
        super().__init__()
        self.decoder = DummyDecoderLayer()

    def forward(self, x: torch.Tensor) \
            -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run ``x`` through the decoder layer and return its output tuple as-is."""
        # Return directly instead of rebinding ``x`` (a tensor) to the tuple output.
        return self.decoder(x)
class TestTransformersGenerated(unittest.TestCase):
    """Tests for ``TransformersForwardBreak`` and the layer-wise generator helpers.

    Covers ``generated_decoder_layer_visit_func`` (with and without an explicit
    list of transformer blocks) and ``transformers_generated_forward_func``.
    """

    def setUp(self):
        """Build a fresh ``SimpleModel`` for every test."""
        self.model = SimpleModel()

    def test_transformers_forward_break(self):
        """``TransformersForwardBreak`` is an ``Exception`` that can be raised and caught."""
        exception = TransformersForwardBreak()
        self.assertIsInstance(exception, Exception)
        with self.assertRaises(TransformersForwardBreak):
            raise exception

    def test_transformers_generated_visit_func(self):
        """The visit generator's first request targets the ``decoder`` child module."""
        generator = generated_decoder_layer_visit_func(self.model)
        request = next(generator)
        self.assertIsInstance(request, ProcessRequest)
        self.assertEqual(request.name, "decoder")
        # The request must reference the model's own submodule, not a copy.
        self.assertIs(request.module, self.model.decoder)

    def test_transformers_generated_visit_func_with_custom_blocks(self):
        """The visit generator accepts an explicitly supplied list of transformer blocks."""
        # Pick every submodule whose class name mentions "decoder"; for SimpleModel
        # this is only the DummyDecoderLayer registered as ``decoder``.
        transformer_blocks = [(name, module)
                              for name, module in self.model.named_modules()
                              if "decoder" in module.__class__.__name__.lower()]
        # Guard the fixture: if the filter selected nothing (or something else),
        # the assertions below would not be exercising the custom-blocks path.
        self.assertEqual([name for name, _ in transformer_blocks], ["decoder"])

        generator = generated_decoder_layer_visit_func(self.model, transformer_blocks)
        request = next(generator)
        self.assertIsInstance(request, ProcessRequest)
        self.assertEqual(request.name, "decoder")
        self.assertIs(request.module, self.model.decoder)

    @patch('torch.distributed.get_rank')
    @patch('torch.distributed.get_world_size')
    @patch('torch.distributed.barrier')
    def test_transformers_generated_forward_func(self, mock_barrier, mock_get_world_size, mock_get_rank):
        """The forward generator yields a ``decoder`` request carrying the model input."""
        # Present a single-process "distributed" environment (rank 0 of 1) in case
        # the helper queries torch.distributed.
        mock_get_rank.return_value = 0
        mock_get_world_size.return_value = 1

        input_data = torch.randn(5, 10)
        generator = transformers_generated_forward_func(self.model, input_data)
        request = next(generator)
        self.assertIsInstance(request, ProcessRequest)
        self.assertEqual(request.name, "decoder")
        self.assertIs(request.module, self.model.decoder)
        # The decoder should be invoked with the tensor that was fed to the model.
        self.assertTrue(torch.equal(request.args[0], input_data))
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()