* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/passes/memory_conflict/identity_pass.h"
#include <numeric>
#include <string>
#include <unordered_set>
#include <vector>
#include "framework/common/debug/ge_log.h"
#include "common/omg_util/omg_util.h"
#include "common/checker.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_type_utils.h"
#include "graph/debug/ge_attr_define.h"
namespace ge {
namespace {
// Node predicate used when collecting or copying control edges in this file:
// accepts a node only if OpTypeUtils classifies its type as assign-like.
const auto node_filter = [](const Node &candidate) -> bool {
  const auto &candidate_type = candidate.GetType();
  return OpTypeUtils::IsAssignLikeNode(candidate_type);
};
bool IsIdentityUsefulForVarCache(const NodePtr &node) {
if (node->GetType() != READVARIABLEOP) {
return false;
}
const auto &out_assign_nodes = NodeUtils::GetOutControlNodes(*node, node_filter);
if (out_assign_nodes.empty()) {
return false;
}
for (const auto &out_data_node : node->GetOutDataNodes()) {
std::unordered_set<Node *> out_data_node_all_in_nodes;
for (const auto &out_data_node_all_in_node : out_data_node->GetInAllNodes()) {
(void)out_data_node_all_in_nodes.insert(out_data_node_all_in_node.get());
}
for (const auto &out_assign_node : out_assign_nodes) {
if (out_data_node_all_in_nodes.find(out_assign_node.get()) != out_data_node_all_in_nodes.end()) {
GELOGI("Node [%s %s]'s out ctrl assign node [%s %s] is also a input of it's data consumer node [%s %s]",
node->GetName().c_str(), node->GetType().c_str(), out_assign_node->GetName().c_str(),
out_assign_node->GetType().c_str(), out_data_node->GetName().c_str(), out_data_node->GetType().c_str());
return true;
}
}
}
return false;
}
// Decides whether an identity-like node should be kept in the graph.
//
// `usable` is set to true when any of these is found (checked in this order):
//   1. the node's op desc carries ATTR_NO_NEED_CONSTANT_FOLDING or
//      ATTR_NAME_IS_INSERTED_BY_CANN;
//   2. a data producer is a subgraph-input DATA node or owns subgraph instances;
//   3. a data producer's original type is SWITCH/REFSWITCH and the node has
//      peers on its out-control anchor;
//   4. a data consumer is a subgraph output or owns subgraph instances;
//   5. a data consumer's original type is MERGE/REFMERGE and the node has
//      peers on its in-control anchor.
// If none applies, `usable` takes the result of IsIdentityUsefulForVarCache.
//
// Each scan stops at the first producer/consumer that is neither
// subgraph-related nor of the expected switch/merge type.
//
// Returns SUCCESS once a decision was made. The GE_CHECK_NOTNULL /
// GE_CHK_STATUS_RET macros presumably return an error status early on a null
// pointer or failed type lookup (confirm against their definitions); in that
// case `usable` is left untouched.
Status CheckIdentityUsable(const NodePtr &node, bool &usable) {
  const auto node_opdesc = node->GetOpDesc();
  const bool kept_by_attr = node_opdesc->HasAttr(ATTR_NO_NEED_CONSTANT_FOLDING) ||
                            node_opdesc->HasAttr(ATTR_NAME_IS_INSERTED_BY_CANN);
  if (kept_by_attr) {
    usable = true;
    return SUCCESS;
  }

  std::string node_type;

  // Producer side.
  for (auto &in_node : node->GetInDataNodes()) {
    auto in_node_opdesc = in_node->GetOpDesc();
    GE_CHECK_NOTNULL(in_node_opdesc);
    const bool fed_by_subgraph_input = (in_node->GetType() == DATA) && NodeUtils::IsSubgraphInput(in_node);
    if (fed_by_subgraph_input || !in_node_opdesc->GetSubgraphInstanceNames().empty()) {
      usable = true;
      return SUCCESS;
    }
    GE_CHK_STATUS_RET(GetOriginalType(in_node, node_type),
                      "[Get][OriginalType] of node:%s failed", in_node->GetName().c_str());
    const bool fed_by_switch = (node_type == SWITCH) || (node_type == REFSWITCH);
    if (!fed_by_switch) {
      // NOTE(review): this branch runs when the producer is NOT a switch, although the
      // message wording may suggest otherwise; confirm the intended text.
      GELOGD("skip identity %s connected to switch", node->GetName().c_str());
      break;
    }
    GE_CHECK_NOTNULL(node->GetOutControlAnchor());
    if (!node->GetOutControlAnchor()->GetPeerInControlAnchors().empty()) {
      usable = true;
      return SUCCESS;
    }
  }

  // Consumer side.
  for (auto &out_node : node->GetOutDataNodes()) {
    auto out_node_opdesc = out_node->GetOpDesc();
    GE_CHECK_NOTNULL(out_node_opdesc);
    if (NodeUtils::IsSubgraphOutput(out_node) || !out_node_opdesc->GetSubgraphInstanceNames().empty()) {
      usable = true;
      return SUCCESS;
    }
    GE_CHK_STATUS_RET(GetOriginalType(out_node, node_type),
                      "[Get][OriginalType] of node:%s failed", out_node->GetName().c_str());
    const bool feeds_merge = (node_type == MERGE) || (node_type == REFMERGE);
    if (!feeds_merge) {
      // NOTE(review): taken when the consumer is NOT a merge; confirm the intended wording.
      GELOGD("skip identity %s connected to merge", node->GetName().c_str());
      break;
    }
    GE_CHECK_NOTNULL(node->GetInControlAnchor());
    if (!node->GetInControlAnchor()->GetPeerOutControlAnchors().empty()) {
      usable = true;
      return SUCCESS;
    }
  }

  // Nothing above forced the node to stay; fall back to the variable/assign wiring check.
  usable = IsIdentityUsefulForVarCache(node);
  return SUCCESS;
}
}
// Pass entry point for a single node.
//
// Nodes whose original type is not IDENTITY, IDENTITYN or READVARIABLEOP are
// left alone (SUCCESS). Unless `force_` is set, a node that
// CheckIdentityUsable reports as usable is also left alone. Otherwise the
// node is handed to SafelyRemoveIdentity together with the index list
// {0, 1, ..., n-1}, where n is its output desc count.
//
// Returns:
//   * the failing status of GetOriginalType / CheckIdentityUsable;
//   * PARAM_INVALID when the input desc count differs from the output desc count;
//   * otherwise whatever SafelyRemoveIdentity returns.
Status IdentityPass::Run(NodePtr &node) {
  GE_CHECK_NOTNULL(node);
  auto op_desc = node->GetOpDesc();
  GE_CHECK_NOTNULL(op_desc);
  std::string type;
  Status status_ret = GetOriginalType(node, type);
  if (status_ret != SUCCESS) {
    GELOGE(status_ret, "[Get][OriginalType] of node:%s failed.", node->GetName().c_str());
    return status_ret;
  }
  if ((type != IDENTITY) && (type != IDENTITYN) && (type != READVARIABLEOP)) {
    return SUCCESS;
  }
  if (!force_) {
    bool usable = false;
    auto ret = CheckIdentityUsable(node, usable);
    if (ret != SUCCESS) {
      return ret;
    }
    if (usable) {
      return SUCCESS;
    }
  }

  // Query both counts once through the op desc that was null-checked above.
  const size_t output_num = op_desc->GetOutputsSize();
  const size_t input_num = op_desc->GetInputsSize();
  if (input_num != output_num) {
    REPORT_INNER_ERR_MSG("E19999", "Num:%zu of input desc node:%s(%s) not equal to it's output desc num:%zu, "
                         "check invalid", input_num,
                         node->GetName().c_str(), node->GetType().c_str(), output_num);
    GELOGE(PARAM_INVALID, "[Check][Param] Num:%zu of input desc node:%s(%s) not equal to it's output desc num:%zu",
           input_num, node->GetName().c_str(), node->GetType().c_str(), output_num);
    return PARAM_INVALID;
  }

  // Index list 0..n-1, sized up front so there is a single allocation and no
  // implicit size_t -> int32_t narrowing per element.
  std::vector<int32_t> io_map(output_num);
  std::iota(io_map.begin(), io_map.end(), 0);
  return SafelyRemoveIdentity(node, io_map);
}
// Removes `node` via IsolateAndDeleteNode(node, io_map).
//
// For a READVARIABLEOP node, control edges are first copied onto each of its
// data consumers:
//   * when NodeUtils::IsIdentityUsefulForRWControl(node) is true, the node's
//     out-control edges accepted by node_filter are copied;
//   * otherwise the node's in-control edges are copied.
// A failing copy aborts through GE_ASSERT_SUCCESS (presumably returning an
// error status; confirm against the macro definition).
Status IdentityPass::SafelyRemoveIdentity(NodePtr &node, const std::vector<int32_t> &io_map) {
  if (node->GetType() != READVARIABLEOP) {
    return IsolateAndDeleteNode(node, io_map);
  }

  const bool read_first_then_write = NodeUtils::IsIdentityUsefulForRWControl(node);
  for (const auto &out_data_node : node->GetOutDataNodes()) {
    if (!read_first_then_write) {
      GE_ASSERT_SUCCESS(GraphUtils::CopyInCtrlEdges(node, out_data_node));
      GELOGI("Node [%s %s]'s in ctrl relation is copyed to it's data consumer node [%s %s]", node->GetName().c_str(),
             node->GetType().c_str(), out_data_node->GetName().c_str(), out_data_node->GetType().c_str());
      continue;
    }
    GE_ASSERT_SUCCESS(GraphUtils::CopyOutCtrlEdges(node, out_data_node, node_filter));
    GELOGI("Node [%s %s]'s out ctrl relation is copyed to it's data consumer node [%s %s]", node->GetName().c_str(),
           node->GetType().c_str(), out_data_node->GetName().c_str(), out_data_node->GetType().c_str());
  }
  return IsolateAndDeleteNode(node, io_map);
}
// Registers a pass option entry under the name "IdentityPass" with level OoLevel::kO3.
// NOTE(review): presumably this ties the pass to the O3 optimization level; confirm
// against the REG_PASS_OPTION / LEVELS definitions.
REG_PASS_OPTION("IdentityPass").LEVELS(OoLevel::kO3);
}