* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "jit_infer_utils.h"
#include "common/memory/tensor_trans_utils.h"
#include "graph/optimize/symbolic/infer_symbolic_shape/symbolic_shape_inference.h"
#include "graph/optimize/symbolic/infer_symbolic_shape/symbolic_info_pre_processor.h"
#include "graph/optimize/symbolic/infer_symbolic_shape/symbolic_infer_util.h"
#include "graph/utils/op_type_utils.h"
#include "graph/debug/ge_attr_define.h"
#include "attribute_group/attr_group_symbolic_desc.h"
#include "opt_info/ge_opt_info.h"
#include <stdbool.h>
namespace ge {
namespace {
* 判断node是否为uninfered node, uninfered node会被切到下一张图
*
* 1、if/case节点, 符号化暂不支持对子图的推导, 因此if和case的输出必没有符号, 切到下一张图后可以根据上一张图的输出把cond构造成const,
* 然后对if/case做剪枝, 消除if/case节点, 把子图展平到根图上, 达成对if/case做符号化推导的目的
* 2、存在没有符号的输入且所有输出都没符号
*
* @param node
* @return
*/
bool IsUnInferedNode(const NodePtr &node) {
auto op_desc = node->GetOpDesc();
size_t empty_input_count = 0u;
for (size_t i = 0U; i < op_desc->GetAllInputsDescPtr().size(); ++i) {
auto attr = op_desc->GetInputDesc(i).GetAttrsGroup<SymbolicDescAttr>();
if (attr == nullptr) {
empty_input_count++;
}
}
size_t empty_output_count = 0;
for (size_t i = 0U; i < op_desc->GetAllOutputsDescPtr().size(); ++i) {
auto attr = op_desc->GetOutputDesc(i).GetAttrsGroup<SymbolicDescAttr>();
if (attr == nullptr) {
empty_output_count++;
}
}
if (empty_output_count == op_desc->GetAllOutputsDescPtr().size()) {
if (empty_input_count != 0) {
return true;
}
auto cond_input = SymbolicInferUtil::GetCondInput(node);
if (cond_input != nullptr && !OpTypeUtils::IsDataNode(cond_input->GetType())) {
return true;
}
}
return false;
}
}
Status JitInferUtils::InferSymbolForGraph(const ComputeGraphPtr &graph, const std::vector<GeTensor> &inputs,
std::vector<NodePtr> &infered_nodes) {
PrepareBeforeInferSymbol(graph, inputs);
InferGraphAndGetInferredNodes(graph, inputs, infered_nodes);
return SUCCESS;
}
Status JitInferUtils::PrepareBeforeInferSymbol(const ComputeGraphPtr &graph, const std::vector<GeTensor> &inputs) {
GE_ASSERT_SUCCESS(GeOptInfo::SetOptInfo());
GraphNodePtr graph_node = make_shared<GraphNode>(0);
graph_node->SetComputeGraph(graph);
GraphPtr graph_ptr = std::make_shared<Graph>(GraphUtilsEx::CreateGraphFromComputeGraph(graph));
graph_node->SetGraph(graph_ptr);
GraphPrepare graph_prepare;
GraphOptimize graph_optimize;
GE_ASSERT_SUCCESS(graph_prepare.PrepareInit(graph_node, graph->GetSessionID()));
GE_ASSERT_SUCCESS(graph_optimize.HandleSummaryOp(graph));
graph_prepare.SetGraphNormalized(true);
GE_ASSERT_SUCCESS(graph_prepare.NormalizeGraph(graph, graph_node->GetOptions(), inputs));
GE_ASSERT_SUCCESS(graph_prepare.PrepareDynShape());
return SUCCESS;
}
static void ClearInferedNodesWithAllDataNodes(std::vector<NodePtr> &infered_nodes) {
size_t data_node_num = 0;
for (auto &node : infered_nodes) {
auto node_type = node->GetType();
if (OpTypeUtils::IsVariableNode(node_type) || OpTypeUtils::IsConstNode(node_type) ||
OpTypeUtils::IsDataNode(node_type)) {
data_node_num++;
}
}
if (data_node_num == infered_nodes.size()) {
infered_nodes.clear();
}
}
static bool ParentNodeInfered(const NodePtr &node, std::vector<NodePtr> &infered_nodes) {
if (!node->GetInDataNodes().empty()) {
for (auto &in_data_node : node->GetInDataNodes()) {
if (std::find(infered_nodes.begin(), infered_nodes.end(),in_data_node) == infered_nodes.end()) {
return false;
}
}
}
if (!node->GetInControlNodes().empty()) {
for (auto &in_control_node : node->GetInControlNodes()) {
if (std::find(infered_nodes.begin(), infered_nodes.end(),in_control_node) == infered_nodes.end()) {
return false;
}
}
}
return true;
}
static void DeleteNodesWithoutParentNode(std::vector<NodePtr> &infered_nodes) {
for (auto it = infered_nodes.begin(); it != infered_nodes.end();) {
if (!ParentNodeInfered(*it, infered_nodes)) {
GELOGD("Infer node:%s. Parent node uninfered", (*it)->GetName().c_str());
it = infered_nodes.erase(it);
} else {
++it;
}
}
}
Status JitInferUtils::InferGraphAndGetInferredNodes(const ComputeGraphPtr &graph, const std::vector<GeTensor> &inputs,
std::vector<NodePtr> &infered_nodes) {
bool is_single_op_scene = false;
ge::AttrUtils::GetBool(graph, ge::ATTR_SINGLE_OP_SCENE, is_single_op_scene);
if (is_single_op_scene) {
GELOGI("Skip auto fuse for single op scene.");
return SUCCESS;
}
GE_ASSERT_GRAPH_SUCCESS(SymbolicShapeSymbolizer::Symbolize(graph, inputs), "Symbolize graph input failed, graph %s",
graph->GetName().c_str());
SymbolicShapeInference symbolic_shape_inference;
GE_ASSERT_SUCCESS(symbolic_shape_inference.Infer(graph));
std::vector<NodePtr> uninfered_nodes;
for (auto &node : graph->GetDirectNode()) {
if (IsUnInferedNode(node)) {
GELOGD("[%s][%s] add to uninfered_nodes.", node->GetNamePtr(), node->GetTypePtr());
uninfered_nodes.push_back(node);
continue;
}
infered_nodes.push_back(node);
}
if (uninfered_nodes.size() == 1 && uninfered_nodes[0]->GetType() == "NetOutput") {
infered_nodes.push_back(uninfered_nodes[0]);
}
ClearInferedNodesWithAllDataNodes(infered_nodes);
DeleteNodesWithoutParentNode(infered_nodes);
GELOGD("Infer node size: %d", infered_nodes.size());
return SUCCESS;
}
}