* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <stdio.h>
#include <gtest/gtest.h>
#include <nlohmann/json.hpp>
#include "fusion_stub.hpp"
#include "graph/compute_graph.h"
#include "graph/op_desc.h"
using namespace std;
using namespace ge;
using namespace fe;
// String constant "fusion_scope"; not referenced in the visible part of this file.
static const std::string SCOPE_KEY = "fusion_scope";
// Attribute name under which SetPattern() stores and GetPattern() reads an op's pattern string.
static const std::string PATTERN_KEY = "_pattern";
// Sentinel value -1; presumably marks an unset offset -- TODO confirm, no use is visible in this file.
static const int64_t INVALID_OFFSET = -1;
/**
 * Resolves the anchor indices of the edge src_op -> dst_op from the two lookup maps.
 *
 * dst_map is searched under src_op's name; among its entries, the index of the last one whose
 * targetname equals dst_op's name is written to src_index. src_map is searched under dst_op's
 * name in the same way (matching on src_op's name) and the result is written to dst_index.
 *
 * @param src_op     op on the producing side of the edge
 * @param dst_op     op on the consuming side of the edge
 * @param src_map    per-op list of desc_info, looked up by dst_op's name
 * @param dst_map    per-op list of desc_info, looked up by src_op's name
 * @param src_index  [out] left unchanged when no matching entry exists
 * @param dst_index  [out] left unchanged when no matching entry exists
 *
 * A null op, a missing map key or a missing target entry is not an error: the corresponding
 * output simply keeps the value the caller initialised it with.
 */
void GetSrcDstIndex(OpDescPtr src_op, OpDescPtr dst_op, map<string, vector<desc_info>> &src_map,
                    map<string, vector<desc_info>> &dst_map, uint32_t &src_index, uint32_t &dst_index) {
  if (src_op == nullptr || dst_op == nullptr) {
    return;
  }
  const auto &src_name = src_op->GetName();
  const auto &dst_name = dst_op->GetName();

  // Guard against end(): dereferencing it (as the unchecked lookup did) is undefined behaviour.
  const auto out_it = dst_map.find(src_name);
  if (out_it != dst_map.end()) {
    // No early break: when several entries match, the last one wins.
    for (const auto &info : out_it->second) {
      if (info.targetname == dst_name) {
        src_index = info.index;
      }
    }
  }

  const auto in_it = src_map.find(dst_name);
  if (in_it != src_map.end()) {
    for (const auto &info : in_it->second) {
      if (info.targetname == src_name) {
        dst_index = info.index;
      }
    }
  }
}
/**
 * Builds model_graph out of op_list and assigns consecutive input/output offsets to every op.
 *
 * Steps, in order:
 *   1. Add one node to model_graph for each op description in op_list.
 *   2. For every pair (opdef, dst_opdef) where dst_opdef's OPDESC_SRC_NAME attribute lists
 *      opdef's name, add a data edge between them; the anchor indices come from
 *      GetSrcDstIndex() and default to 0 when it finds no entry.
 *   3. Dump each node and its output data nodes to stdout.
 *   4. Topologically sort the graph (the return status is ignored).
 *   5. Walk the nodes and hand out offsets from a single running counter: first one per input,
 *      then one per output of each op.
 *
 * @param model_graph  graph to populate
 * @param op_list      op descriptions to turn into nodes
 * @param src_map      forwarded to GetSrcDstIndex()
 * @param dst_map      forwarded to GetSrcDstIndex()
 */
void CreateModelGraph(ge::ComputeGraphPtr model_graph, vector<ge::OpDescPtr> &op_list,
                      map<string, vector<desc_info>> &src_map, map<string, vector<desc_info>> &dst_map) {
  // 1. Nodes.
  for (const auto &opdef : op_list) {
    (void) model_graph->AddNode(opdef);
  }

  // 2. Edges.
  for (const OpDescPtr &opdef : op_list) {
    const auto &producer_name = opdef->GetName();
    NodePtr node = model_graph->FindNode(producer_name);
    if (node == nullptr) {
      continue;  // no node for this op, so there is nothing to connect from
    }
    for (const OpDescPtr &dst_opdef : op_list) {
      vector<string> src_names;
      ge::AttrUtils::GetListStr(dst_opdef, OPDESC_SRC_NAME, src_names);
      for (const auto &src_name : src_names) {
        if (src_name != producer_name) {
          continue;
        }
        // Reset for every edge: GetSrcDstIndex leaves the outputs untouched on a miss.
        uint32_t src_index = 0;
        uint32_t dst_index = 0;
        GetSrcDstIndex(opdef, dst_opdef, src_map, dst_map, src_index, dst_index);
        NodePtr dst_node = model_graph->FindNode(dst_opdef->GetName());
        if (dst_node == nullptr) {
          continue;
        }
        (void) ge::GraphUtils::AddEdge(node->GetOutDataAnchor(src_index), dst_node->GetInDataAnchor(dst_index));
      }
    }
  }

  // 3. Debug dump of the resulting connectivity.
  for (NodePtr node : model_graph->GetDirectNode()) {
    cout << "nodename = " << node->GetName() << endl;
    for (auto out_node : node->GetOutDataNodes()) {
      cout << "output node = " << out_node->GetName() << endl;
    }
    cout << "===========================================" << endl;
  }

  // 4. Sort.
  (void) model_graph->TopologicalSorting();

  // 5. Offsets: one running counter shared by all inputs and outputs of all ops.
  uint32_t ddr_addr = 0;
  for (auto node : model_graph->GetDirectNode()) {
    OpDescPtr opdef = node->GetOpDesc();
    const int32_t input_size = static_cast<int32_t>(opdef->GetInputsSize());
    const int32_t output_size = static_cast<int32_t>(opdef->GetOutputsSize());

    vector<int64_t> input_list;
    for (int32_t loop = 0; loop < input_size; loop++) {
      input_list.push_back(ddr_addr++);
    }
    opdef->SetInputOffset(input_list);

    vector<int64_t> output_list;
    for (int32_t loop = 0; loop < output_size; loop++) {
      output_list.push_back(ddr_addr++);
    }
    // Output offsets must come from output_list; passing input_list here left it unused.
    opdef->SetOutputOffset(output_list);
  }
}
/**
 * Creates an OpDesc named `name` of type `type` carrying the given input and output tensor descs.
 *
 * Every entry of inputdesc_list is added as an input desc and flagged as non-const via
 * SetIsInputConst; every entry of outputdesc_list is added as an output desc.
 * srcname_list and dstname_list are accepted but not read by this function.
 *
 * @return the newly created op description
 */
ge::OpDescPtr CreateOpDefUbFusion(string name, string type, vector<string> &srcname_list, vector<string> &dstname_list,
                                  vector<ge::GeTensorDesc> &inputdesc_list, vector<ge::GeTensorDesc> &outputdesc_list) {
  ge::OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);

  for (const auto &in_desc : inputdesc_list) {
    op_desc->AddInputDesc(in_desc);
  }
  // One "not const" flag per input.
  const vector<bool> input_const_flags(inputdesc_list.size(), false);
  op_desc->SetIsInputConst(input_const_flags);

  for (const auto &out_desc : outputdesc_list) {
    op_desc->AddOutputDesc(out_desc);
  }
  return op_desc;
}
/**
 * Fills tensor_desc with a four-dimensional shape {n, c, h, w}, a format and a data type.
 *
 * `format` and `datatype` are raw integers that are cast to the Format / DataType enums.
 */
void filltensordesc(GeTensorDesc &tensor_desc, uint32_t n, uint32_t c, uint32_t h, uint32_t w, uint32_t datatype,
                    uint32_t format) {
  const std::vector<int64_t> dims = {static_cast<int64_t>(n), static_cast<int64_t>(c), static_cast<int64_t>(h),
                                     static_cast<int64_t>(w)};
  GeShape shape(dims);
  tensor_desc.SetShape(shape);
  tensor_desc.SetFormat(static_cast<Format>(format));
  tensor_desc.SetDataType(static_cast<DataType>(datatype));
}
/** Stub: always returns an empty list of BufferFusionInfo. */
std::vector<BufferFusionInfo> SortedBufferFusionFun() {
  return std::vector<BufferFusionInfo>();
}
#ifndef DAVINCI_LITE
/** Sets the ATTR_NAME_IMPLY_TYPE integer attribute of opdef to domi::ImplyType::TVM. */
void SetTvmType(ge::OpDescPtr opdef) {
  const int64_t tvm_imply_type = static_cast<int64_t>(domi::ImplyType::TVM);
  ge::AttrUtils::SetInt(opdef, ge::ATTR_NAME_IMPLY_TYPE, tvm_imply_type);
}
#endif
/** Stores optype as a string attribute on opdef under the PATTERN_KEY name. */
void SetPattern(ge::OpDescPtr opdef, string optype) {
  ge::AttrUtils::SetStr(opdef, PATTERN_KEY, optype);
}
/**
 * Reads the string attribute stored under PATTERN_KEY on opdef into optype.
 *
 * @return true when the attribute lookup succeeds, false otherwise
 */
bool GetPattern(ge::OpDescPtr opdef, string &optype) {
  const bool found = ge::AttrUtils::GetStr(opdef, PATTERN_KEY, optype);
  return found;
}
/**
 * Prints every data edge of graph to stdout.
 *
 * For each edge three lines are written: the source and destination node names, the source
 * output-anchor index and destination input-anchor index, and the tensor formats on both ends.
 * The source format is taken from the producer's output desc at the out-anchor index; the
 * destination format is taken from the consumer's input desc at the in-anchor index.
 */
void PrintGraph(ge::ComputeGraphPtr graph) {
  for (const auto &node : graph->GetDirectNode()) {
    for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
      for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
        ge::NodePtr peer_node = peer_in_anchor->GetOwnerNode();
        if (peer_node == nullptr) {
          continue;  // anchor without an owner: nothing meaningful to print
        }
        cout << " src_node name = " << node->GetName() << " dst_node name = " << peer_node->GetName() << endl;
        cout << " src index = " << out_anchor->GetIdx() << " dst index = " << peer_in_anchor->GetIdx() << endl;
        // The peer anchor is an *input* anchor, so its index addresses the consumer's input descs.
        cout << " src format = " << node->GetOpDesc()->GetOutputDesc(out_anchor->GetIdx()).GetFormat()
             << " dst format = " << peer_node->GetOpDesc()->GetInputDesc(peer_in_anchor->GetIdx()).GetFormat()
             << endl;
      }
    }
  }
}
#ifndef DAVINCI_LITE
/** Sets the "tvm_magic" string attribute of opdef to "RT_DEV_BINARY_MAGIC_ELF". */
void SetAICoreOp(ge::OpDescPtr opdef) {
  const std::string magic_attr_name = "tvm_magic";
  const std::string magic_attr_value = "RT_DEV_BINARY_MAGIC_ELF";
  ge::AttrUtils::SetStr(opdef, magic_attr_name, magic_attr_value);
}
#endif