"""
msmodelslim.utils.dag_utils.model_structure_process 模块的单元测试
"""
import unittest
from unittest.mock import Mock
from ascend_utils.core.dag.dag_node import DagNode
from msmodelslim.utils.dag_utils.model_infos import ModuleType
from msmodelslim.utils.dag_utils.model_structure_process import StructureProcess
class MockDagNode:
def __init__(self, name, op_type, in_features=0, out_features=0, inputs=None, input_nodes=None, output_nodes=None):
self.name_in_network = name
self.op_type = op_type
self.node = Mock()
self.node.in_features = in_features
self.node.out_features = out_features
self.inputs = inputs or []
self.input_nodes = input_nodes or []
self.output_nodes = output_nodes or []
class TestStructureProcess(unittest.TestCase):
"""测试StructureProcess类"""
def test_is_ffn_matmul_with_ffn_matmul_num_2_true_case(self):
"""测试is_ffn_matmul函数,ffn_matmul_num=2,返回True的情况"""
matmul1 = MockDagNode("matmul1", "Linear", in_features=128, out_features=256)
matmul2 = MockDagNode("matmul2", "Linear", in_features=256, out_features=128)
matmul_list = [matmul1, matmul2]
result = StructureProcess.is_ffn_matmul(matmul_list, 2)
self.assertTrue(result)
def test_is_ffn_matmul_with_ffn_matmul_num_2_false_case(self):
"""测试is_ffn_matmul函数,ffn_matmul_num=2,返回False的情况"""
matmul1 = MockDagNode("matmul1", "Linear", in_features=128, out_features=256)
matmul2 = MockDagNode("matmul2", "Linear", in_features=128, out_features=64)
matmul_list = [matmul1, matmul2]
result = StructureProcess.is_ffn_matmul(matmul_list, 2)
self.assertFalse(result)
def test_is_ffn_matmul_with_wrong_length(self):
"""测试is_ffn_matmul函数,长度不匹配的情况"""
matmul1 = MockDagNode("matmul1", "Linear", in_features=128, out_features=256)
matmul_list = [matmul1]
result = StructureProcess.is_ffn_matmul(matmul_list, 2)
self.assertFalse(result)
def test_is_ffn_matmul_with_ffn_matmul_num_3_true_case1(self):
"""测试is_ffn_matmul函数,ffn_matmul_num=3,返回True的情况1"""
matmul1 = MockDagNode("matmul1", "Linear", in_features=128, out_features=256)
matmul2 = MockDagNode("matmul2", "Linear", in_features=256, out_features=128)
matmul3 = MockDagNode("matmul3", "Linear", in_features=64, out_features=32)
matmul_list = [matmul1, matmul2, matmul3]
result = StructureProcess.is_ffn_matmul(matmul_list, 3)
self.assertTrue(result)
def test_is_ffn_matmul_with_ffn_matmul_num_3_true_case2(self):
"""测试is_ffn_matmul函数,ffn_matmul_num=3,返回True的情况2"""
matmul1 = MockDagNode("matmul1", "Linear", in_features=128, out_features=256)
matmul2 = MockDagNode("matmul2", "Linear", in_features=256, out_features=128)
matmul3 = MockDagNode("matmul3", "Linear", in_features=256, out_features=128)
matmul_list = [matmul1, matmul2, matmul3]
result = StructureProcess.is_ffn_matmul(matmul_list, 3)
self.assertTrue(result)
def test_is_ffn_matmul_with_ffn_matmul_num_3_false_case(self):
"""测试is_ffn_matmul函数,ffn_matmul_num=3,返回False的情况"""
matmul1 = MockDagNode("matmul1", "Linear", in_features=128, out_features=256)
matmul2 = MockDagNode("matmul2", "Linear", in_features=128, out_features=64)
matmul3 = MockDagNode("matmul3", "Linear", in_features=32, out_features=16)
matmul_list = [matmul1, matmul2, matmul3]
result = StructureProcess.is_ffn_matmul(matmul_list, 3)
self.assertFalse(result)
def test_is_ffn_matmul_with_unsupported_ffn_matmul_num(self):
"""测试is_ffn_matmul函数,不支持的ffn_matmul_num"""
matmul1 = MockDagNode("matmul1", "Linear", in_features=128, out_features=256)
matmul2 = MockDagNode("matmul2", "Linear", in_features=256, out_features=128)
matmul3 = MockDagNode("matmul3", "Linear", in_features=128, out_features=256)
matmul4 = MockDagNode("matmul4", "Linear", in_features=256, out_features=128)
matmul_list = [matmul1, matmul2, matmul3, matmul4]
with self.assertRaises(Exception) as context:
StructureProcess.is_ffn_matmul(matmul_list, 4)
self.assertIn("unsupported ffn_matmul_num: 4", str(context.exception))
def test_mhsa_matmul_process_with_proj_mat_out_features_times_3_equal_qkv_mat_out_features(self):
"""测试mhsa_matmul_process函数,proj_mat.out_features * 3 == qkv_mat.out_features的情况"""
proj_mat = MockDagNode("proj", "Linear", in_features=256, out_features=128)
qkv_mat = MockDagNode("qkv", "Linear", in_features=128, out_features=384)
matmul_list = [proj_mat, qkv_mat]
qkv_list = []
proj_list = []
StructureProcess.mhsa_matmul_process(matmul_list, qkv_list, proj_list)
self.assertEqual(qkv_list, [["qkv"]])
self.assertEqual(proj_list, ["proj"])
def test_mhsa_matmul_process_with_proj_mat_out_features_equal_qkv_mat_out_features(self):
"""测试mhsa_matmul_process函数,proj_mat.out_features == qkv_mat.out_features的情况"""
input_node = MockDagNode("input", "Add")
qkv1 = MockDagNode("q", "Linear", in_features=128, out_features=256)
qkv2 = MockDagNode("k", "Linear", in_features=128, out_features=256)
qkv3 = MockDagNode("v", "Linear", in_features=128, out_features=256)
input_node.output_nodes = [qkv1, qkv2, qkv3, MockDagNode("other", "ReLU")]
qkv_mat = MockDagNode("qkv", "Linear", in_features=128, out_features=256)
qkv_mat.input_nodes = [input_node]
qkv_mat.inputs = [Mock()]
proj_mat = MockDagNode("proj", "Linear", in_features=256, out_features=256)
matmul_list = [proj_mat, qkv_mat]
qkv_list = []
proj_list = []
StructureProcess.mhsa_matmul_process(matmul_list, qkv_list, proj_list)
self.assertEqual(qkv_list, [["q", "k", "v"]])
self.assertEqual(proj_list, ["proj"])
def test_mhsa_matmul_process_with_wrong_length(self):
"""测试mhsa_matmul_process函数,matmul_list长度不正确"""
matmul1 = MockDagNode("matmul1", "Linear", in_features=128, out_features=256)
matmul_list = [matmul1]
qkv_list = []
proj_list = []
StructureProcess.mhsa_matmul_process(matmul_list, qkv_list, proj_list)
self.assertEqual(qkv_list, [])
self.assertEqual(proj_list, [])
def test_mhsa_matmul_process_with_qkv_mat_inputs_not_equal_1(self):
"""测试mhsa_matmul_process函数,qkv_mat.inputs长度不为1"""
qkv_mat = MockDagNode("qkv", "Linear", in_features=128, out_features=256)
qkv_mat.inputs = [Mock(), Mock()]
proj_mat = MockDagNode("proj", "Linear", in_features=256, out_features=256)
matmul_list = [proj_mat, qkv_mat]
qkv_list = []
proj_list = []
StructureProcess.mhsa_matmul_process(matmul_list, qkv_list, proj_list)
self.assertEqual(qkv_list, [])
self.assertEqual(proj_list, [])
def test_mhsa_matmul_process_with_wrong_qkv_mat_list_length(self):
"""测试mhsa_matmul_process函数,qkv_mat_list长度不为3"""
input_node = MockDagNode("input", "Add")
qkv1 = MockDagNode("q", "Linear", in_features=128, out_features=256)
qkv2 = MockDagNode("k", "Linear", in_features=128, out_features=256)
input_node.output_nodes = [qkv1, qkv2, MockDagNode("other", "ReLU")]
qkv_mat = MockDagNode("qkv", "Linear", in_features=128, out_features=256)
qkv_mat.input_nodes = [input_node]
qkv_mat.inputs = [Mock()]
proj_mat = MockDagNode("proj", "Linear", in_features=256, out_features=256)
matmul_list = [proj_mat, qkv_mat]
qkv_list = []
proj_list = []
StructureProcess.mhsa_matmul_process(matmul_list, qkv_list, proj_list)
self.assertEqual(qkv_list, [])
self.assertEqual(proj_list, [])
def test_mhsa_matmul_ln_process_with_proj_mat_out_features_times_3_equal_qkv_mat_out_features(self):
"""测试mhsa_matmul_ln_process函数,proj_mat.out_features * 3 == qkv_mat.out_features的情况"""
proj_mat = MockDagNode("proj", "Linear", in_features=256, out_features=128)
qkv_mat = MockDagNode("qkv", "Linear", in_features=128, out_features=384)
ln1 = MockDagNode("ln", "LayerNorm")
matmul_list = [proj_mat, qkv_mat]
ln_list = [ln1]
qkv_list = []
proj_list = []
mhsa_ln_list = []
StructureProcess.mhsa_matmul_ln_process(matmul_list, ln_list, qkv_list, proj_list, mhsa_ln_list)
self.assertEqual(qkv_list, [["qkv"]])
self.assertEqual(proj_list, ["proj"])
self.assertEqual(mhsa_ln_list, ["ln"])
def test_mhsa_matmul_ln_process_with_proj_mat_out_features_equal_qkv_mat_out_features(self):
"""测试mhsa_matmul_ln_process函数,proj_mat.out_features == qkv_mat.out_features的情况"""
input_node = MockDagNode("input", "Add")
qkv1 = MockDagNode("q", "Linear", in_features=128, out_features=256)
qkv2 = MockDagNode("k", "Linear", in_features=128, out_features=256)
qkv3 = MockDagNode("v", "Linear", in_features=128, out_features=256)
input_node.output_nodes = [qkv1, qkv2, qkv3, MockDagNode("other", "ReLU")]
qkv_mat = MockDagNode("qkv", "Linear", in_features=128, out_features=256)
qkv_mat.input_nodes = [input_node]
qkv_mat.inputs = [Mock()]
proj_mat = MockDagNode("proj", "Linear", in_features=256, out_features=256)
ln1 = MockDagNode("ln", "LayerNorm")
matmul_list = [proj_mat, qkv_mat]
ln_list = [ln1]
qkv_list = []
proj_list = []
mhsa_ln_list = []
StructureProcess.mhsa_matmul_ln_process(matmul_list, ln_list, qkv_list, proj_list, mhsa_ln_list)
self.assertEqual(qkv_list, [["q", "k", "v"]])
self.assertEqual(proj_list, ["proj"])
self.assertEqual(mhsa_ln_list, ["ln"])
def test_mhsa_matmul_ln_process_with_wrong_length(self):
"""测试mhsa_matmul_ln_process函数,matmul_list长度不正确"""
matmul1 = MockDagNode("matmul1", "Linear", in_features=128, out_features=256)
matmul_list = [matmul1]
ln_list = [MockDagNode("ln", "LayerNorm")]
qkv_list = []
proj_list = []
mhsa_ln_list = []
StructureProcess.mhsa_matmul_ln_process(matmul_list, ln_list, qkv_list, proj_list, mhsa_ln_list)
self.assertEqual(qkv_list, [])
self.assertEqual(proj_list, [])
self.assertEqual(mhsa_ln_list, [])
def test_mhsa_matmul_ln_process_with_qkv_mat_inputs_not_equal_1(self):
"""测试mhsa_matmul_ln_process函数,qkv_mat.inputs长度不为1"""
qkv_mat = MockDagNode("qkv", "Linear", in_features=128, out_features=256)
qkv_mat.inputs = [Mock(), Mock()]
proj_mat = MockDagNode("proj", "Linear", in_features=256, out_features=256)
ln1 = MockDagNode("ln", "LayerNorm")
matmul_list = [proj_mat, qkv_mat]
ln_list = [ln1]
qkv_list = []
proj_list = []
mhsa_ln_list = []
StructureProcess.mhsa_matmul_ln_process(matmul_list, ln_list, qkv_list, proj_list, mhsa_ln_list)
self.assertEqual(qkv_list, [])
self.assertEqual(proj_list, [])
self.assertEqual(mhsa_ln_list, [])
def test_mhsa_matmul_ln_process_with_wrong_qkv_mat_list_length(self):
"""测试mhsa_matmul_ln_process函数,qkv_mat_list长度不为3"""
input_node = MockDagNode("input", "Add")
qkv1 = MockDagNode("q", "Linear", in_features=128, out_features=256)
qkv2 = MockDagNode("k", "Linear", in_features=128, out_features=256)
input_node.output_nodes = [qkv1, qkv2, MockDagNode("other", "ReLU")]
qkv_mat = MockDagNode("qkv", "Linear", in_features=128, out_features=256)
qkv_mat.input_nodes = [input_node]
qkv_mat.inputs = [Mock()]
proj_mat = MockDagNode("proj", "Linear", in_features=256, out_features=256)
ln1 = MockDagNode("ln", "LayerNorm")
matmul_list = [proj_mat, qkv_mat]
ln_list = [ln1]
qkv_list = []
proj_list = []
mhsa_ln_list = []
StructureProcess.mhsa_matmul_ln_process(matmul_list, ln_list, qkv_list, proj_list, mhsa_ln_list)
self.assertEqual(qkv_list, [])
self.assertEqual(proj_list, [])
self.assertEqual(mhsa_ln_list, [])
if __name__ == '__main__':
unittest.main()