# 9c6f9fba - created on 2025-12-30 (commit history)
import io
import sys
import unittest
from unittest.mock import patch
from unittest import mock
from msopst.st.interface.framework import pytorch_model_parser


class Model:
    """Bare stand-in for a model object.

    Instances start with no attributes; tests attach what they need
    (for example ``graph``) after construction.
    """


class Graph:
    """Stand-in for a graph object with placeholder ``input``/``value_info``."""

    def __init__(self):
        # Both fields default to the same placeholder; tests overwrite them
        # (and attach extra attributes) when a case needs real content.
        self.input = self.value_info = 1


class Tensor:
    """Stand-in for a tensor object with placeholder ``dims``/``data_type``."""

    def __init__(self):
        # Placeholder values; ``name`` and ``type`` are attached by the tests.
        self.dims = self.data_type = 1


class Type:
    """Empty stand-in for a tensor's ``type`` field.

    Tests attach a ``tensor_type`` attribute to instances as needed.
    """


class Tensor_Type:
    """Empty stand-in for a ``tensor_type`` field.

    Tests attach ``elem_type`` and, optionally, ``shape`` to instances.
    """


class Shape:
    """Empty stand-in for a shape object; tests attach a ``dim`` list."""


class Dim:
    """Stand-in for a single shape dimension.

    ``dim_param`` and ``dim_value`` exist only on instances; the class
    object itself carries neither attribute.
    """

    def __init__(self):
        self.dim_param, self.dim_value = "1", 1


class Node:
    """Stand-in for a graph node with placeholder name, input and output."""

    def __init__(self):
        # ``op_type`` and ``attribute`` are attached by the tests when needed.
        self.name, self.input, self.output = "1", 1, 2


class Attr:
    """Stand-in for a node attribute with a placeholder name and type."""

    def __init__(self):
        self.name, self.type = "1", 2


class Args:
    """Stand-in for the argument object handed to ``PyTorchModelParse``."""

    def __init__(self):
        # Fixed placeholder values, assigned in declaration order.
        defaults = (
            ("input_file", "a.json"),
            ("output_path", ""),
            ("model_path", "./"),
            ("quiet", ""),
        )
        for field_name, value in defaults:
            setattr(self, field_name, value)


class test_pytorch_model_parser(unittest.TestCase):
    """Unit tests for the ``pytorch_model_parser`` module.

    Model, graph, tensor and node objects are replaced by the lightweight
    stub classes defined in this file, and collaborators of the function
    under test are patched out with ``unittest.mock``.
    """

    def setUp(self):
        """No per-test fixtures are required."""

    def tearDown(self):
        """Nothing to clean up; each test restores whatever it patches."""

    def test_check_shape_and_notice(self):
        """Only the string "1,2,3,4" is accepted; the other inputs fail."""
        # A list instead of a string, a non-numeric entry, and a zero entry.
        rejected_shapes = ([1, 2, 3, 4], "1w,2,3,4", "0,2,3,4")
        for shape in rejected_shapes:
            with self.subTest(shape=shape):
                out = pytorch_model_parser.check_shape_and_notice(shape)
                self.assertIs(out, False)

        out = pytorch_model_parser.check_shape_and_notice("1,2,3,4")
        self.assertIs(out, True)

    def test__get_tensor_shape(self):
        """Only the name found in ``all_tensors`` yields a dtype and shape."""
        tensor_names = ["a", "b"]
        all_tensors = {"a": {'dtype': 1, 'shape': [1, 3, 224, 224]}}
        dtypes, shapes = pytorch_model_parser._get_tensor_shape(
            tensor_names, all_tensors)
        self.assertEqual(dtypes, [1])
        self.assertEqual(shapes, [[1, 3, 224, 224]])

    @patch("msopst.st.interface.framework.pytorch_model_parser"
           "._update_tensor_info")
    def test__parse_tensor_info(self, mock__update_tensor_info):
        """``all_tensors`` comes back unchanged while the updater is mocked."""
        mock__update_tensor_info.return_value = "1"
        all_tensors = {"b": 1}

        # tensor1: has a type whose tensor_type carries no ``shape``.
        tensor_type1 = Tensor_Type()
        tensor_type1.elem_type = "int32"
        type1 = Type()
        type1.tensor_type = tensor_type1
        tensor1 = Tensor()
        tensor1.name = "a"
        tensor1.type = type1

        # tensor2: name already present in ``all_tensors``; no ``type`` set.
        tensor2 = Tensor()
        tensor2.name = "b"

        # tensor3: has a type whose tensor_type carries an empty-string shape.
        tensor_type2 = Tensor_Type()
        tensor_type2.elem_type = "int32"
        tensor_type2.shape = ""
        type2 = Type()
        type2.tensor_type = tensor_type2
        tensor3 = Tensor()
        tensor3.name = "c"
        tensor3.type = type2

        tensors = [tensor1, tensor2, tensor3]
        out = pytorch_model_parser._parse_tensor_info(tensors, all_tensors)
        self.assertEqual(out, {'b': 1})

        # Second pass: give tensor3 a real shape object with two dims.
        shape = Shape()
        dim1 = Dim()
        # NOTE(review): dim2 is the Dim class itself, not an instance, so it
        # has neither ``dim_param`` nor ``dim_value`` - confirm intentional.
        dim2 = Dim
        shape.dim = [dim1, dim2]
        tensor_type2.shape = shape
        out = pytorch_model_parser._parse_tensor_info(tensors, all_tensors)
        self.assertEqual(out, {'b': 1})

    @patch("msopst.st.interface.framework.pytorch_model_parser"
           "._parse_tensor_info")
    def test__get_all_tensors(self, mock__parse_tensor_info):
        """Initializers missing from the parsed info are added to the result."""
        tensor1 = Tensor()
        tensor1.name = "a"
        tensor2 = Tensor()
        tensor2.name = "b"
        graph = Graph()
        graph.initializer = [tensor1, tensor2]
        mock__parse_tensor_info.return_value = {"a": 1}
        out = pytorch_model_parser._get_all_tensors(graph)
        self.assertEqual(
            out, {'a': 1, 'b': {'dtype': 'float', 'shape': 1}})

    @patch("msopst.st.interface.framework.pytorch_model_parser"
           ".check_shape_and_notice")
    @patch("msopst.st.interface.framework.pytorch_model_parser"
           ".get_shape_and_notice")
    def test__get_input_shape_from_user(self, mock_get_shape_and_notice,
                                        mock_check_shape_and_notice):
        """Answering "n" yields False; a shape answer that validates, True."""
        new_shape_map = {"a": {"ori_shape": [1, 2, 3, 4]}}
        mock_get_shape_and_notice.return_value = None

        # Patch stdin through a context manager so the real stream is
        # restored afterwards and the fake input cannot leak to other tests.
        with patch.object(sys, "stdin", io.StringIO("n\n")):
            out = pytorch_model_parser._get_input_shape_from_user(
                new_shape_map)
        self.assertIs(out, False)

        mock_check_shape_and_notice.return_value = True
        with patch.object(sys, "stdin", io.StringIO("m,m\n")):
            out = pytorch_model_parser._get_input_shape_from_user(
                new_shape_map)
        self.assertIs(out, True)

    @patch("msopst.st.interface.framework.pytorch_model_parser"
           "._get_tensor_shape")
    @patch("msopst.st.interface.framework.pytorch_model_parser"
           "._get_all_tensors")
    @patch("msopst.st.interface.framework.pytorch_model_parser"
           "._infer_model_shape")
    @patch("msopst.st.interface.framework.pytorch_model_parser._load_model")
    @patch("onnx.helper.get_attribute_value")
    def test__get_node_list(self, mock_get_attribute_value, mock__load_model,
                            mock__infer_model_shape, mock__get_all_tensors,
                            mock__get_tensor_shape):
        """A node list is returned for a two-node graph with mocked helpers."""
        model_path = ""
        ini_op_type = "1"

        # node1 matches ``ini_op_type`` and carries two attributes; node2
        # has a different op type and no attributes.
        node1 = Node()
        node1.op_type = "1"
        node1.attribute = [Attr(), Attr()]
        node2 = Node()
        node2.op_type = "2"
        graph = Graph()
        graph.node = [node1, node2]
        model = Model()
        model.graph = graph

        mock__load_model.return_value = ""
        mock__infer_model_shape.return_value = model
        mock__get_all_tensors.return_value = []
        mock__get_tensor_shape.return_value = ("int32", [])
        mock_get_attribute_value.return_value = ""

        out = pytorch_model_parser._get_node_list(model_path, ini_op_type)
        self.assertIsNotNone(out)

    @patch("msopst.st.interface.framework.pytorch_model_parser"
           "._parse_tensor_info")
    @patch("msopst.st.interface.framework.pytorch_model_parser._load_model")
    def test__get_model_inputs(self, mock__load_model,
                               mock__parse_tensor_info):
        """Tensors listed as graph initializers are left out of the inputs."""
        tensor = Tensor()
        tensor.name = "a"
        graph = Graph()
        graph.initializer = [tensor]
        model = Model()
        model.graph = graph
        mock__load_model.return_value = model
        mock__parse_tensor_info.return_value = {"a": {"shape": ""},
                                                "b": {"shape": ""}}

        pytorch_model = pytorch_model_parser.PyTorchModelParse(Args())
        out = pytorch_model._get_model_inputs()
        self.assertEqual(out, {'b': {'ori_shape': '', 'new_shape': []}})

    @patch("onnx.save")
    @patch("onnx.shape_inference.infer_shapes")
    @patch("onnx.helper.make_tensor_value_info")
    @patch("onnx.checker.check_model")
    @patch("msopst.st.interface.framework.pytorch_model_parser._load_model")
    def test__change_shape_fn(self, mock__load_model, mock_check_model,
                              mock_make_tensor_value_info, mock_infer_shapes,
                              mock_save):
        """``_change_shape_fn`` returns a value with every onnx call mocked."""
        tensor_type = Tensor_Type()
        tensor_type.elem_type = "int8"
        type1 = Type()
        type1.tensor_type = tensor_type

        # Only tensor1 ("a") appears in ``new_shape_map`` and carries a type.
        tensor1 = Tensor()
        tensor1.type = type1
        tensor1.name = "a"
        tensor2 = Tensor()
        tensor2.name = "b"

        graph = Graph()
        graph.input = [tensor1, tensor2]
        model = Model()
        model.graph = graph

        mock__load_model.return_value = model
        mock_check_model.return_value = ""
        mock_make_tensor_value_info.return_value = ""
        mock_infer_shapes.return_value = ""
        mock_save.return_value = ""

        new_shape_map = {"a": {"new_shape": [1, 2, 3]}}
        pytorch_model = pytorch_model_parser.PyTorchModelParse(Args())
        out = pytorch_model._change_shape_fn(new_shape_map)
        self.assertIsNotNone(out)

    @patch("msopst.st.interface.framework.pytorch_model_parser"
           "._get_node_list")
    @patch("msopst.st.interface.framework.pytorch_model_parser"
           "._get_input_shape_from_user")
    @patch("msopst.st.interface.utils.check_name_valid")
    @patch("msopst.st.interface.utils.check_path_valid")
    def test_get_model_nodes(self, mock_check_path_valid,
                             mock_check_name_valid,
                             mock__get_input_shape_from_user,
                             mock__get_node_list):
        """The mocked node list is returned on all three exercised paths."""
        mock_check_path_valid.return_value = ""
        mock_check_name_valid.return_value = ""
        mock__get_node_list.return_value = []
        ini_op_type = "int8"

        pytorch_model = pytorch_model_parser.PyTorchModelParse(Args())
        pytorch_model._get_model_inputs = mock.Mock(return_value='')
        pytorch_model._change_shape_fn = mock.Mock(return_value='')

        # Path 1: the user-shape helper reports True.
        mock__get_input_shape_from_user.return_value = True
        out = pytorch_model.get_model_nodes(ini_op_type)
        self.assertEqual(out, [])

        # Path 2: the user-shape helper reports False.
        mock__get_input_shape_from_user.return_value = False
        out = pytorch_model.get_model_nodes(ini_op_type)
        self.assertEqual(out, [])

        # Path 3: quiet mode enabled on the parser instance.
        pytorch_model.quiet_flag = True
        out = pytorch_model.get_model_nodes(ini_op_type)
        self.assertEqual(out, [])


# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()