* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/utils/graph_utils_ex.h"
#include "graph_metadef/common/ge_common/util.h"
#include "common/util/trace_manager/trace_manager.h"
#include "graph/refiner/format_refiner.h"
#include "graph/normal_graph/operator_impl.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/op_desc_utils_ex.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/transformer_utils.h"
#include "graph/utils/node_utils_ex.h"
#include "common/util/mem_utils.h"
#include "graph/utils/op_type_utils.h"
#include "graph/operator_factory.h"
namespace af {
graphStatus GraphUtilsEx::InferOriginFormat(const ComputeGraphPtr &graph) {
return FormatRefiner::InferOrigineFormat(graph);
}
graphStatus GraphUtilsEx::InferShapeInNeed(const ComputeGraphPtr &graph) {
GE_LOGW_IF(graph->TopologicalSorting() != GRAPH_SUCCESS, "Verify failed.");
for (const auto &node_ptr : graph->GetAllNodes()) {
GE_CHECK_NOTNULL(node_ptr);
const auto op_desc = node_ptr->GetOpDesc();
bool is_need_infer = false;
(void)AttrUtils::GetBool(op_desc, NEED_INFER, is_need_infer);
if (is_need_infer) {
if (NodeUtilsEx::Verify(node_ptr) != GRAPH_SUCCESS) {
REPORT_INNER_ERR_MSG("E18888", "Verifying %s failed.", node_ptr->GetName().c_str());
GELOGE(FAILED, "[Call][Verify] Verifying %s failed.", node_ptr->GetName().c_str());
return GRAPH_FAILED;
}
const graphStatus status = NodeUtilsEx::InferShapeAndType(node_ptr);
if ((!OpTypeUtils::IsDataNode(node_ptr->GetType())) && (status == GRAPH_PARAM_INVALID)) {
GELOGI("Op %s does not have the IMPLEMT_INFERFUNC definition, "
"and subsequent operators no longer perform shape inference.",
node_ptr->GetName().c_str());
break;
}
if (status != GRAPH_SUCCESS) {
REPORT_INNER_ERR_MSG("E18888", "Inferring %s failed.", node_ptr->GetName().c_str());
GELOGE(FAILED, "[Call][InferShapeAndType] Inferring %s failed.", node_ptr->GetName().c_str());
return GRAPH_FAILED;
}
for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) {
GE_CHECK_NOTNULL(out_anchor->GetOwnerNodeBarePtr()->GetOpDesc());
auto output_tensor = out_anchor->GetOwnerNodeBarePtr()->GetOpDesc()->MutableOutputDesc(
static_cast<uint32_t>(out_anchor->GetIdx()));
GE_CHECK_NOTNULL(output_tensor);
TensorUtils::SetRealDimCnt(*(output_tensor.get()),
static_cast<uint32_t>(output_tensor->GetShape().GetDims().size()));
for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) {
const auto peer_in_tensor_desc = peer_anchor->GetOwnerNodeBarePtr()->GetOpDesc()->MutableInputDesc(
static_cast<uint32_t>(peer_anchor->GetIdx()));
GE_CHECK_NOTNULL(peer_in_tensor_desc);
OpDescUtilsEx::UpdateShapeAndDType(output_tensor, peer_in_tensor_desc);
}
}
}
}
return GRAPH_SUCCESS;
}
std::vector<NodePtr> GraphUtilsEx::GetUserInputDataNodes(const ComputeGraphPtr &compute_graph) {
std::vector<NodePtr> user_input_nodes;
for (const auto &node : compute_graph->GetInputNodes()) {
if (!AttrUtils::HasAttr(node->GetOpDesc(), "_is_multi_batch_shape_data")) {
user_input_nodes.emplace_back(node);
}
}
return user_input_nodes;
}
graphStatus GraphUtilsEx::CopyGraph(const Graph &src_graph, Graph &dst_graph) {
std::string graph_name;
AscendString ascend_name;
if (dst_graph.GetName(ascend_name) == GRAPH_SUCCESS) {
graph_name = std::string((ascend_name.GetString() != nullptr) ? ascend_name.GetString() : "");
}
if (graph_name.empty() && (src_graph.GetName(ascend_name) == GRAPH_SUCCESS)) {
graph_name = std::string((ascend_name.GetString() != nullptr) ? ascend_name.GetString() : "");
}
ComputeGraphPtr new_compute_graph = MakeShared<ComputeGraph>(graph_name);
GE_CHECK_NOTNULL(new_compute_graph);
const ComputeGraphPtr src_compute_graph = GraphUtilsEx::GetComputeGraph(src_graph);
GE_CHECK_NOTNULL(src_compute_graph);
if (src_compute_graph->GetParentGraph() != nullptr) {
GELOGE(GRAPH_FAILED, "[Check][RootGraph] Only support copy root graph, current graph name:%s, "
"parent graph name:%s.", src_compute_graph->GetName().c_str(),
src_compute_graph->GetParentGraph()->GetName().c_str());
return GRAPH_FAILED;
}
const int32_t depth = 0;
std::map<ConstNodePtr, NodePtr> node_old_2_new;
std::map<ConstOpDescPtr, OpDescPtr> op_desc_old_2_new;
graphStatus ret = GraphUtils::CopyComputeGraph(src_compute_graph, new_compute_graph,
node_old_2_new, op_desc_old_2_new, depth);
if (ret != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "[Copy][Graph] failed, ret:%d.", ret);
return GRAPH_FAILED;
}
Graph tmp_graph = GraphUtilsEx::CreateGraphFromComputeGraph(new_compute_graph);
ret = GraphUtilsEx::CopyGraphImpl(src_graph, tmp_graph, node_old_2_new, op_desc_old_2_new);
if (ret != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "[Copy][GraphImpl] failed, ret:%d.", ret);
return GRAPH_FAILED;
}
std::swap(dst_graph, tmp_graph);
return GRAPH_SUCCESS;
}
Graph GraphUtilsEx::CreateGraph() {
return Graph("");
}
Operator GraphUtilsEx::CreateOperator(const char_t *const operator_name, const char_t *const operator_type) {
return af::OperatorFactory::CreateOperator(operator_name, operator_type);
}
}
af::Graph GeApiWrapper_CreateGraphFromComputeGraph(const af::ComputeGraphPtr &compute_graph) {
return af::GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph);
}
size_t GeApiWrapper_GetComputeGraphInputSize(const af::Graph &graph) {
return af::GraphUtilsEx::GetComputeGraph(graph)->GetInputSize();
}
size_t GeApiWrapper_GetComputeGraphOutputSize(const af::Graph &graph) {
return af::GraphUtilsEx::GetComputeGraph(graph)->GetOutputSize();
}