* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <gtest/gtest.h>
#include <iostream>
#include "parser/common/op_parser_factory.h"
#include "graph/operator_reg.h"
#include "graph/utils/graph_utils_ex.h"
#include "register/op_registry.h"
#include "parser/common/op_registration_tbe.h"
#include "parser/onnx_parser.h"
#include "st/parser_st_utils.h"
#include "ge/ge_api_types.h"
#include "depends/ops_stub/ops_stub.h"
#include "framework/omg/parser/parser_factory.h"
#include "parser/onnx/onnx_util.h"
#include "graph/ge_local_context.h"
#include "common/ge_common/ge_types.h"
#include "parser/onnx/onnx_parser_internal.h"
#include "parser/onnx/onnx_file_constant_parser.h"
namespace ge {
class STestOnnxParser : public testing::Test {
protected:
void SetUp() {
ParerSTestsUtils::ClearParserInnerCtx();
RegisterCustomOp();
}
void TearDown() {}
public:
void RegisterCustomOp();
};
// Stub ParseParamsFn used when registering the custom ops: ignores both the
// source message and the destination operator and always reports SUCCESS.
static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& op_dest) {
  (void)op_src;
  (void)op_dest;
  return SUCCESS;
}
// Stub ParseParamsByOperatorFn for the custom "If" op: leaves op_dest untouched
// and always reports SUCCESS.
static Status ParseParamByOpFunc(const ge::Operator &op_src, ge::Operator& op_dest) {
  (void)op_src;
  (void)op_dest;
  return SUCCESS;
}
// ParseSubgraphPostFn for the custom "If" op. It fetches the framework's ONNX
// subgraph IO auto-mapping function and applies it to `graph`:
//   - subgraph data index i  -> parent input index i + 1
//     (presumably because parent input 0 is the If condition — TODO confirm)
//   - subgraph output index i -> parent output index i
// Returns FAILED (after printing a message) when no auto-mapping function is
// registered for ONNX; otherwise returns whatever the mapping function returns.
// `subgraph_name` is not used.
Status ParseSubgraphPostFnIf(const std::string& subgraph_name, const ge::Graph& graph) {
  (void)subgraph_name;
  const domi::AutoMappingSubgraphIOIndexFunc mapping_fn =
      domi::FrameworkRegistry::Instance().GetAutoMappingSubgraphIOIndexFunc(domi::ONNX);
  if (mapping_fn != nullptr) {
    const auto map_data_index = [](int data_index, int &parent_index) -> Status {
      parent_index = data_index + 1;
      return SUCCESS;
    };
    const auto map_output_index = [](int output_index, int &parent_index) -> Status {
      parent_index = output_index;
      return SUCCESS;
    };
    return mapping_fn(graph, map_data_index, map_output_index);
  }
  std::cout << "auto mapping if subgraph func is nullptr!" << std::endl;
  return FAILED;
}
// Declares the custom ONNX ops used by this suite (Conv2D, If, Add, Identity). It then
// runs every pending entry in OpRegistry::registrationDatas through
// OpRegTbeParserFactory::Finalize and OpRegistry::Register, and finally clears the
// pending list.
// NOTE(review): REGISTER_CUSTOM_OP presumably expands to a static receiver object,
// so the declarations below are likely only queued on the first call of this
// function (later SetUp calls find an empty list) — confirm against the macro.
void STestOnnxParser::RegisterCustomOp() {
  REGISTER_CUSTOM_OP("Conv2D")
      .FrameworkType(domi::ONNX)
      .OriginOpType("ai.onnx::11::Conv")
      .ParseParamsFn(ParseParams);
  REGISTER_CUSTOM_OP("If")
      .FrameworkType(domi::ONNX)
      .OriginOpType({"ai.onnx::9::If",
                     "ai.onnx::10::If",
                     "ai.onnx::11::If",
                     "ai.onnx::12::If",
                     "ai.onnx::13::If"})
      .ParseParamsFn(ParseParams)
      .ParseParamsByOperatorFn(ParseParamByOpFunc)
      .ParseSubgraphPostFn(ParseSubgraphPostFnIf);
  REGISTER_CUSTOM_OP("Add")
      .FrameworkType(domi::ONNX)
      .OriginOpType("ai.onnx::11::Add")
      .ParseParamsFn(ParseParams);
  REGISTER_CUSTOM_OP("Identity")
      .FrameworkType(domi::ONNX)
      .OriginOpType("ai.onnx::11::Identity")
      .ParseParamsFn(ParseParams);

  std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas;
  // Iterate by reference: the elements already live in a local copy of the
  // pending list, so a second per-element copy is unnecessary.
  for (auto &reg_data : reg_datas) {
    domi::OpRegTbeParserFactory::Instance()->Finalize(reg_data);
    domi::OpRegistry::Instance()->Register(reg_data);
  }
  domi::OpRegistry::Instance()->registrationDatas.clear();
}
// Builds a small synthetic ONNX graph for the ModelParseToGraph / SetExternalPath
// tests: one unnamed graph input, one unnamed graph output, and three nodes in this
// order: Constant, Constant, `op_type` (default "Add"). No node names or node
// inputs/outputs are set.
//  - Node 0 (Constant) has a "value" tensor whose data location is EXTERNAL and
//    whose external_data holds an empty entry followed by {"location": "const.onnx"}.
//    It also carries a second, empty (unnamed) attribute.
//  - Node 1 (Constant) has a "value" tensor whose data location is DEFAULT.
// NOTE(review): the empty attribute and the empty external_data entry look
// accidental, but the tests' expected status codes may depend on them; kept as-is.
ge::onnx::GraphProto CreateOnnxGraph(const std::string &op_type = "Add") {
  ge::onnx::GraphProto graph_proto;
  (void)graph_proto.add_input();
  (void)graph_proto.add_output();

  ge::onnx::NodeProto *const external_const = graph_proto.add_node();
  ge::onnx::NodeProto *const default_const = graph_proto.add_node();
  ge::onnx::NodeProto *const compute_node = graph_proto.add_node();

  // Node 0: Constant whose value tensor points at external data.
  external_const->set_op_type(kOpTypeConstant);
  ge::onnx::AttributeProto *const external_value = external_const->add_attribute();
  external_value->set_name(ge::kAttrNameValue);
  ge::onnx::TensorProto *const external_tensor = external_value->mutable_t();
  external_tensor->set_data_location(ge::onnx::TensorProto_DataLocation_EXTERNAL);
  (void)external_const->add_attribute();
  (void)external_tensor->add_external_data();
  ge::onnx::StringStringEntryProto *const location_entry = external_tensor->add_external_data();
  location_entry->set_key("location");
  location_entry->set_value("const.onnx");

  // Node 1: Constant whose value tensor uses the default data location.
  default_const->set_op_type(kOpTypeConstant);
  ge::onnx::AttributeProto *const default_value = default_const->add_attribute();
  default_value->set_name(ge::kAttrNameValue);
  default_value->mutable_t()->set_data_location(ge::onnx::TensorProto_DataLocation_DEFAULT);

  // Node 2: the operator under test.
  compute_node->set_op_type(op_type);
  return graph_proto;
}
// Parses onnx_conv2d.onnx with no parser_params. It expects exactly one graph output,
// Conv_0:0, and a single net_out_nodes entry "Conv_0:0:y" in the parser context.
TEST_F(STestOnnxParser, onnx_parser_user_output_with_default) {
  const std::string this_file = __FILE__;
  const std::string model_file =
      this_file.substr(0, this_file.find_last_of('/')) + "/origin_models/onnx_conv2d.onnx";
  std::map<ge::AscendString, ge::AscendString> parser_params;
  ge::Graph graph;
  ASSERT_EQ(ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph), GRAPH_SUCCESS);

  const ge::ComputeGraphPtr compute_graph = ge::GraphUtilsEx::GetComputeGraph(graph);
  const auto &output_nodes_info = compute_graph->GetGraphOutNodesInfo();
  ASSERT_EQ(output_nodes_info.size(), 1U);
  const auto &only_output = output_nodes_info.front();
  EXPECT_EQ(only_output.first->GetName(), "Conv_0");
  EXPECT_EQ(only_output.second, 0);

  const auto &net_out_names = ge::GetParserContext().net_out_nodes;
  ASSERT_EQ(net_out_names.size(), 1U);
  EXPECT_EQ(net_out_names.front(), "Conv_0:0:y");
}
// Parses onnx_conv2d.onnx with the parser context in ONLY_PRE_CHECK run mode and
// expects aclgrphParseONNX to return GRAPH_FAILED.
TEST_F(STestOnnxParser, onnx_parser_precheck) {
  std::string case_dir = __FILE__;
  case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
  std::string model_file = case_dir + "/origin_models/onnx_conv2d.onnx";
  std::map<ge::AscendString, ge::AscendString> parser_params;
  ge::Graph graph;

  auto &parser_ctx = ge::GetParserContext();
  const auto saved_run_mode = parser_ctx.run_mode;
  parser_ctx.run_mode = ge::ONLY_PRE_CHECK;
  auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph);
  // Restore the shared run mode before asserting: ASSERT_* returns early on
  // failure, and later tests must not inherit pre-check mode.
  parser_ctx.run_mode = saved_run_mode;
  ASSERT_EQ(ret, GRAPH_FAILED);
}
// Parsing the If-node model onnx_if.onnx is expected to return FAILED.
TEST_F(STestOnnxParser, onnx_parser_if_node) {
  const std::string this_file = __FILE__;
  const std::string model_file =
      this_file.substr(0, this_file.find_last_of('/')) + "/origin_models/onnx_if.onnx";
  std::map<ge::AscendString, ge::AscendString> parser_params;
  ge::Graph graph;
  const auto status = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph);
  EXPECT_EQ(status, FAILED);
}
// Parses onnx_clip_v9.onnx twice, once from its file path and once from an
// in-memory copy of the same file, and expects both parses to succeed.
TEST_F(STestOnnxParser, onnx_parser_expand_one_to_many) {
  std::string case_dir = __FILE__;
  case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
  std::string model_file = case_dir + "/origin_models/onnx_clip_v9.onnx";
  std::map<ge::AscendString, ge::AscendString> parser_params;
  ge::Graph graph;
  auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph);
  EXPECT_EQ(ret, GRAPH_SUCCESS);

  MemBuffer *buffer = ParerSTestsUtils::MemBufferFromFile(model_file.c_str());
  // Fail cleanly rather than dereferencing a null buffer if the file could not be read.
  ASSERT_NE(buffer, nullptr);
  // NOTE(review): `buffer` is never released; freeing it needs the allocation
  // scheme used by MemBufferFromFile — TODO confirm and free.
  ret = ge::aclgrphParseONNXFromMem(reinterpret_cast<char *>(buffer->data), buffer->size, parser_params, graph);
  EXPECT_EQ(ret, GRAPH_SUCCESS);
}
// Parses onnx_clip_v9.onnx from file with graph option OPTION_TOPOSORTING_MODE set
// to "3". It then sets that option to "" and parses the same model from memory,
// expecting both parses to succeed.
TEST_F(STestOnnxParser, onnx_parser_expand_one_to_many_with_stable_sort) {
  std::string case_dir = __FILE__;
  case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
  std::string model_file = case_dir + "/origin_models/onnx_clip_v9.onnx";
  std::map<ge::AscendString, ge::AscendString> parser_params;
  ge::Graph graph;

  auto graph_options = GetThreadLocalContext().GetAllGraphOptions();
  graph_options[OPTION_TOPOSORTING_MODE] = "3";
  GetThreadLocalContext().SetGraphOption(graph_options);
  auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph);
  EXPECT_EQ(ret, GRAPH_SUCCESS);

  graph_options = GetThreadLocalContext().GetAllGraphOptions();
  graph_options[OPTION_TOPOSORTING_MODE] = "";
  GetThreadLocalContext().SetGraphOption(graph_options);

  MemBuffer *buffer = ParerSTestsUtils::MemBufferFromFile(model_file.c_str());
  // Fail cleanly rather than dereferencing a null buffer if the file could not be read.
  ASSERT_NE(buffer, nullptr);
  // NOTE(review): `buffer` is never released; freeing it needs the allocation
  // scheme used by MemBufferFromFile — TODO confirm and free.
  ret = ge::aclgrphParseONNXFromMem(reinterpret_cast<char *>(buffer->data), buffer->size, parser_params, graph);
  EXPECT_EQ(ret, GRAPH_SUCCESS);
}
// OnnxModelParser::ToJson must succeed for a valid model plus JSON path. It must
// return FAILED when the JSON path is null, and also when both paths are null.
// NOTE(review): the successful call writes "tmp.json" into the working directory
// and the test never removes it.
TEST_F(STestOnnxParser, onnx_parser_to_json) {
  const std::string this_file = __FILE__;
  const std::string model_file =
      this_file.substr(0, this_file.find_last_of('/')) + "/origin_models/onnx_clip_v9.onnx";
  OnnxModelParser onnx_parser;
  const char *const json_path = "tmp.json";
  const char *const no_json = nullptr;
  const char *const no_model = nullptr;

  EXPECT_EQ(onnx_parser.ToJson(model_file.c_str(), json_path), SUCCESS);
  EXPECT_EQ(onnx_parser.ToJson(model_file.c_str(), no_json), FAILED);
  EXPECT_EQ(onnx_parser.ToJson(no_model, no_json), FAILED);
}
// Parsing onnx_const_type.onnx (a model exercising constant data types, per
// its name) is expected to succeed.
TEST_F(STestOnnxParser, onnx_parser_const_data_type) {
  const std::string this_file = __FILE__;
  const std::string model_file =
      this_file.substr(0, this_file.find_last_of('/')) + "/origin_models/onnx_const_type.onnx";
  std::map<ge::AscendString, ge::AscendString> parser_params;
  ge::Graph graph;
  const auto status = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph);
  EXPECT_EQ(status, GRAPH_SUCCESS);
}
// Parsing the If-with-constant-input model is expected to succeed.
// (The model file name really is spelled "intput" on disk.)
TEST_F(STestOnnxParser, onnx_parser_if_node_with_const_input) {
  const std::string this_file = __FILE__;
  const std::string model_file =
      this_file.substr(0, this_file.find_last_of('/')) + "/origin_models/onnx_if_const_intput.onnx";
  std::map<ge::AscendString, ge::AscendString> parser_params;
  ge::Graph graph;
  const auto status = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph);
  EXPECT_EQ(status, GRAPH_SUCCESS);
}
// Feeds the synthetic graph from CreateOnnxGraph() (default "Add" node) with
// opset ai.onnx v11 into ModelParseToGraph and expects INTERNAL_ERROR.
TEST_F(STestOnnxParser, onnx_test_ModelParseToGraph) {
  OnnxModelParser model_parser;
  ge::onnx::ModelProto model_proto;
  *model_proto.mutable_graph() = CreateOnnxGraph();
  ge::onnx::OperatorSetIdProto *const opset = model_proto.add_opset_import();
  opset->set_domain("ai.onnx");
  opset->set_version(11);

  ge::Graph graph;
  EXPECT_EQ(model_parser.ModelParseToGraph(model_proto, graph), INTERNAL_ERROR);
}
// Builds a NodeProto whose "value" attribute is a UINT16 tensor with dims [4] and
// external_data {location: "/tmp/weight", offset: "4", length: "16"}. It then checks
// that OnnxFileConstantParser::ParseParams accepts it for a FileConstant operator.
TEST_F(STestOnnxParser, FileConstantParseParam)
{
  OnnxFileConstantParser parser;
  ge::onnx::NodeProto input_node;
  ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("file_constant", "FileConstant");
  ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src);

  ge::onnx::AttributeProto *attribute = input_node.add_attribute();
  attribute->set_name("value");
  // mutable_t() already yields a default-constructed tensor, so no extra copy-in is needed.
  ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t();
  attribute_tensor->set_data_type(OnnxDataType::UINT16);
  attribute_tensor->add_dims(4);
  ge::onnx::StringStringEntryProto *string_proto1 = attribute_tensor->add_external_data();
  string_proto1->set_key("location");
  string_proto1->set_value("/tmp/weight");
  ge::onnx::StringStringEntryProto *string_proto2 = attribute_tensor->add_external_data();
  string_proto2->set_key("offset");
  string_proto2->set_value("4");
  ge::onnx::StringStringEntryProto *string_proto3 = attribute_tensor->add_external_data();
  string_proto3->set_key("length");
  string_proto3->set_value("16");

  // Derived-to-base pointer conversion: static_cast performs it correctly
  // (assumes NodeProto is a full protobuf Message subclass, as the ONNX parser
  // itself relies on when passing NodeProto* to ParseParams).
  Status ret = parser.ParseParams(static_cast<Message *>(&input_node), op);
  EXPECT_EQ(ret, SUCCESS);
}
// Builds the synthetic graph with a compute node of op type "Test", which is not
// among the ops this fixture registers. ModelParseToGraph should fail and the
// PreChecker singleton should report an error.
TEST_F(STestOnnxParser, onnx_test_PreChecker_not_support) {
  OnnxModelParser model_parser;
  ge::onnx::ModelProto model_proto;
  *model_proto.mutable_graph() = CreateOnnxGraph("Test");
  ge::onnx::OperatorSetIdProto *const opset = model_proto.add_opset_import();
  opset->set_domain("ai.onnx");
  opset->set_version(11);

  ge::Graph graph;
  EXPECT_EQ(model_parser.ModelParseToGraph(model_proto, graph), FAILED);
  EXPECT_EQ(PreChecker::Instance().HasError(), true);
}
// SetExternalPath("/usr/local", ...) on the synthetic "Test" graph (no opset
// import) is expected to return SUCCESS.
TEST_F(STestOnnxParser, onnx_test_SetExternalPath) {
  OnnxModelParser model_parser;
  ge::onnx::ModelProto model_proto;
  *model_proto.mutable_graph() = CreateOnnxGraph("Test");
  EXPECT_EQ(model_parser.SetExternalPath("/usr/local", model_proto), SUCCESS);
}
}