/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#ifndef INC_GRAPH_UTILS_NODE_UTILS_H_
#define INC_GRAPH_UTILS_NODE_UTILS_H_
#include <set>
#include <vector>
#include <cstring>
#include "graph/types.h"
#include "graph/anchor.h"
#include "graph/node.h"
#include "graph/compute_graph.h"
namespace ge {
// String sets declared here and defined in another translation unit.
// NOTE(review): judging by the identifiers, each set presumably holds the op type names that belong to
// one category (const, Enter, Merge, Switch, NextIteration, Exit, If, While, Case, For) - confirm
// against the definitions.
extern const std::set<std::string> kConstOpTypes;
extern const std::set<std::string> kEnterOpTypes;
extern const std::set<std::string> kMergeOpTypes;
extern const std::set<std::string> kSwitchOpTypes;
extern const std::set<std::string> kNextIterationOpTypes;
extern const std::set<std::string> kExitOpTypes;
extern const std::set<std::string> kIfOpTypes;
extern const std::set<std::string> kWhileOpTypes;
extern const std::set<std::string> kCaseOpTypes;
extern const std::set<std::string> kForOpTypes;
class NodeUtils {
public:
static graphStatus ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor);
static graphStatus SetAllAnchorStatus(const NodePtr &node_ptr);
static graphStatus SetAllAnchorStatus(Node &node);
static bool IsAnchorStatusSet(const NodePtr &node_ptr);
static bool IsAnchorStatusSet(const Node &node);
static graphStatus MoveOutputEdges(const NodePtr &origin_node, const NodePtr &new_node);
static void UpdateIsInputConst(const NodePtr &node_ptr);
static void UpdateIsInputConst(Node &node);
static bool IsConst(const Node &node);
static void UnlinkAll(const Node &node);
static bool ClearInputDesc(const OpDescPtr &op_desc, const uint32_t index);
static bool ClearOutputDesc(const OpDescPtr &op_desc, const uint32_t index);
static graphStatus AppendInputAnchor(const NodePtr &node, const uint32_t num);
static graphStatus RemoveInputAnchor(const NodePtr &node, const uint32_t num);
static graphStatus AppendOutputAnchor(const NodePtr &node, const uint32_t num);
static graphStatus RemoveOutputAnchor(const NodePtr &node, const uint32_t num);
static GeTensorDesc GetOutputDesc(const Node &node, const uint32_t index);
static graphStatus GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow);
static std::string GetNodeType(const Node &node);
static std::string GetNodeType(const NodePtr &node);
static graphStatus GetDirectSubgraphs(const NodePtr &node, std::vector<ComputeGraphPtr> &subgraphs);
static ComputeGraphPtr GetSubgraph(const Node &node, const uint32_t index);
static graphStatus SetSubgraph(Node &node, const uint32_t index, const ComputeGraphPtr &subgraph);
* @brief Add the subgraph to the node with the ir name, will not register the ir name with type
* @param node the node the subgraph will be added
* @param subgraph_ir_name the subgraph ir name
* @param subgraph the subgraph
* @return GRAPH_SUCCESS: success, others: failed
*/
static graphStatus AddSubgraph(Node &node, const std::string &subgraph_ir_name, const ComputeGraphPtr &subgraph);
* @brief Add the static subgraph to the node with the given ir name, will register the ir name with as kStatic
* @param node_ptr the node the subgraph will be added
* @param subgraph_ir_name the subgraph ir name
* @param subgraph the subgraph
* @return GRAPH_SUCCESS: success, others: failed
*/
static graphStatus AddSubgraph(const NodePtr &node_ptr, const std::string &subgraph_ir_name,
const ComputeGraphPtr &subgraph);
* @brief Add the dynamic subgraph to the node with the given ir name, will register the ir name with as kDynamic
* @param node_ptr the node the subgraph will be added
* @param subgraph_ir_name the subgraph ir name
* @param subgraphs vector of dynamic subgraphs
* @return GRAPH_SUCCESS: success, others: failed
*/
static graphStatus AddSubgraphs(const NodePtr &node_ptr, const std::string &subgraph_ir_name,
const std::vector<ComputeGraphPtr> &subgraphs);
static std::string GenDynamicSubgraphName(const std::string &subgraph_ir_name, int64_t index);
static NodePtr CreatNodeWithoutGraph(const OpDescPtr op_desc);
static bool IsSubgraphInput(const NodePtr &node);
static bool IsSubgraphInput(const Node *const node);
static bool IsSubgraphOutput(const NodePtr &node);
static NodePtr GetParentInput(const Node &node);
static NodePtr GetParentInput(const NodePtr &node);
static bool IsWrapperNode(const NodePtr &node);
static NodePtr GetParentNode(const Node &node);
static NodePtr GetParentNode(const NodePtr &node);
static InDataAnchorPtr GetParentInDataAnchor(const NodePtr &node);
static NodeToOutAnchor GetParentInputAndAnchor(const NodePtr &node);
static NodeToOutAnchor GetParentInputAndAnchorCrossSubgraph(const NodePtr &node);
static bool IsDynamicShape(const Node &node);
static bool IsDynamicShape(const NodePtr &node);
static bool IsWhileVaryingInput(const ge::NodePtr &node);
static bool GetConstOpType(const NodePtr &node, std::string &type);
static graphStatus RemoveSubgraphsOnNode(const NodePtr &node);
* 获取`node`挂载的所有子图中的索引为`index`的Data节点集合;
* 每个子图最多能找到一个跟`index`匹配的Data节点
* @param node
* @param index
* @return
*/
static std::vector<NodePtr> GetSubgraphDataNodesByIndex(const Node &node, const int32_t index);
* 获取`node`挂载的所有子图中的NetOutput节点集合;
* 每个子图有且只有一个NetOutput节点
* @param node
* @return
*/
static std::vector<NodePtr> GetSubgraphOutputNodes(const Node &node);
* 获取`node`所在的图对应的根图
* @param node
* @return
*/
static ComputeGraphPtr FindRootGraph(const Node &node);
* 根据`node_filter`获取被node控制的输出节点
* @param node
* @param node_filter 控制边拷贝白名单过滤器,可以通过传递此参数实现满足条件的输出节点的获取
* @return
*/
static std::vector<NodePtr> GetOutControlNodes(const Node &node, const NodeFilter &node_filter);
* 根据`node_filter`获取node的输出数据消费节点
* @param node
* @param node_filter 数据边拷贝白名单过滤器,可以通过传递此参数实现满足条件的输出节点的获取
* @return
*/
static std::vector<NodePtr> GetOutDataNodes(const Node &node, const NodeFilter &node_filter);
* 根据`node_filter`获取控制node的输入节点
* @param node
* @param node_filter 控制边拷贝白名单过滤器,可以通过传递此参数实现满足条件的输入节点的获取
* @return
*/
static std::vector<NodePtr> GetInControlNodes(const Node &node, const NodeFilter &node_filter);
* 根据`node_filter`获取node的数据输入节点
* @param node
* @param node_filter 数据边拷贝白名单过滤器,可以通过传递此参数实现满足条件的输入节点的获取
* @return
*/
static std::vector<NodePtr> GetInDataNodes(const Node &node, const NodeFilter &node_filter);
static NodePtr GetInDataNodeByIndex(const Node &node, const int32_t index);
static std::pair<NodePtr, OutDataAnchorPtr> GetInDataNodeAndAnchorByIndex(const Node &node, const int32_t index);
static std::vector<std::pair<InDataAnchorPtr, NodePtr>> GetOutDataNodesWithAnchorByIndex(const Node &node,
const int32_t index);
* 适用于`node`节点作为子图中的Data占位节点时,获取根图中父节点对应的实际输入节点的类型
* 其他情况返回`node`本身的节点类型
* @param node
* @return
*/
static std::string GetInConstNodeTypeCrossSubgraph(const ge::NodePtr &node);
* 适用于`node`节点作为子图中的Data占位节点时,获取根图中父节点对应的实际输入节点对象
* 其他情况返回`node`本身
* @param node
* @return
*/
static NodePtr GetInNodeCrossSubgraph(const ge::NodePtr &node);
static graphStatus GetInNodeCrossPartionedCallNode(const NodePtr &node, uint32_t index, NodePtr &peer_node);
static graphStatus GetInNodeCrossPartionedCallNode(const NodePtr &node, uint32_t index, NodePtr &peer_node,
int32_t &peer_out_anchor_index);
static bool IsNodeInRootGraph(const NodePtr &node);
static bool IsMultiBranchControlFlowOp(const NodePtr &node);
static graphStatus SetNodeParallelGroup(Node &node, const char_t *const group_name);
static graphStatus UpdateInputOriginalShapeAndShape(const Node &node, const uint32_t index, const GeShape &shape);
static graphStatus UpdateOutputOriginalShapeAndShape(const Node &node, const uint32_t index, const GeShape &shape);
static bool IsDtResourceNode(const NodePtr &node);
static bool IsLikeAtomicClean(const NodePtr &node);
* 用于判断identity节点是否被用于控制先读后写顺序的,如果是的话,
* 则图优化的时候不能无脑删除identity节点来提升性能
* @param node_ptr
* @return
*/
static bool IsIdentityUsefulForRWControl(const NodePtr &node_ptr);
* 尝试通过pld占位节点对应的实际const节点来获取权重
* @param node_ptr placeholder的占位节点,常见于图拆分中间状态的图的输入节点类型
* @param ge_tensor 权重的承载对象,成功获取时ge_tensor被设置为非空
* @return 失败时代表内部流程错误,成功时不代表一定获取到了权重
*/
static graphStatus TryGetWeightByPlaceHolderNode(const NodePtr &node_ptr, ConstGeTensorPtr &ge_tensor);
* 尝试通过Data占位节点对应的实际const节点来获取权重
* @param node_ptr Data占位节点,常见于子图的输入节点类型
* @param ge_tensor 权重的承载对象,成功获取时ge_tensor被设置为非空
* @return 失败时代表内部流程错误,成功时不代表一定获取到了权重
*/
static graphStatus TryGetWeightByDataNode(const NodePtr &node_ptr, ConstGeTensorPtr &ge_tensor);
* 判断`node`的名称是否是`name`
* @param node
* @param name
* @return 如果是的话,返回true,否则 false
*/
static bool IsNameEqual(const NodePtr &node, const ge::char_t *const name);
* 判断`node`的类型是否是`type`
* @param node
* @param type
* @return
*/
static bool IsTypeEqual(const NodePtr &node, const ge::char_t *const type);
static NodePtr GetNodeWithMinimalId(const std::vector<NodePtr> &nodes);
};
struct NodeCompareKey {
bool operator()(const NodePtr &n0, const NodePtr &n1) const {
if ((n0 == nullptr) || (n1 == nullptr)) {
return false;
}
int32_t comp_res = strcmp(n0->GetNamePtr(), n1->GetNamePtr());
if (comp_res == 0) {
const auto graph0 = n0->GetOwnerComputeGraph();
const auto graph1 = n1->GetOwnerComputeGraph();
if ((graph0 == nullptr) || (graph1 == nullptr)) {
return false;
}
return (graph0->GetName() < graph1->GetName());
}
return (comp_res < 0);
}
};
// std::set of NodePtr whose ordering is supplied by NodeCompareKey.
using OrderedNodeSet = std::set<NodePtr, NodeCompareKey>;
}
#endif