* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "inner_identity_add_pass.h"
#include <stack>
#include "node_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/op_type_utils.h"
#include "graph/debug/ge_util.h"
#include "graph/operator_factory.h"
#include "graph/passes/pass_utils.h"
#include "common/checker.h"
namespace ge {
namespace {
const std::string kInnerIdentityAttr = "_inner_identity";
bool IsChangedByRefNode(const NodePtr &node) {
const auto &type = NodeUtils::GetNodeType(node);
return (type == REFSWITCH) || (type == REFMERGE) || (type == READVARIABLEOP);
}
bool IsNoNeedInsertIdentity(const OutDataAnchorPtr &anchor) {
std::stack<OutDataAnchorPtr> out_data_anchor_stack;
out_data_anchor_stack.push(anchor);
auto input_node = anchor->GetOwnerNode();
while (!out_data_anchor_stack.empty()) {
auto out_data_anchor = out_data_anchor_stack.top();
out_data_anchor_stack.pop();
auto node = out_data_anchor->GetOwnerNode();
GE_ASSERT_NOTNULL(node);
if (OpTypeUtils::IsDataNode(node->GetType())) {
const auto in_node = NodeUtils::GetParentInput(node);
if (in_node != nullptr) {
node = in_node;
out_data_anchor = node->GetOutDataAnchor(0U);
}
}
int32_t reuse_in_index = -1;
if (!GraphUtils::IsRefFromInput(out_data_anchor, reuse_in_index)) {
input_node = node;
} else {
auto in_data_anchor = node->GetInDataAnchor(reuse_in_index);
GE_ASSERT_NOTNULL(in_data_anchor);
auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor();
if (peer_out_data_anchor != nullptr) {
out_data_anchor_stack.push(peer_out_data_anchor);
}
}
}
return OpTypeUtils::IsVarLikeNode(input_node->GetType()) || IsChangedByRefNode(input_node);
}
}
Status InnerIdentityAddPass::Run(ComputeGraphPtr graph) {
GE_ASSERT_NOTNULL(graph);
if (connectivity_ == nullptr) {
connectivity_ = ComGraphMakeUnique<ConnectionMatrix>(graph);
if (connectivity_ == nullptr) {
GELOGW("Make shared failed");
return FAILED;
}
GE_ASSERT_SUCCESS(connectivity_->Generate(graph), "Cannot generate connection matrix for graph %s.",
graph->GetName().c_str());
}
GE_ASSERT_SUCCESS(PassUtils::UpdateRefAttr(graph));
for (const auto &node : graph->GetDirectNode()) {
auto op_desc = node->GetOpDesc();
GE_ASSERT_NOTNULL(op_desc);
bool is_ref = false;
(void)AttrUtils::GetBool(op_desc, ATTR_NAME_REFERENCE, is_ref);
if (!is_ref) {
continue;
}
for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
int32_t reuse_in_index = -1;
if (!GraphUtils::IsRefFromInput(out_data_anchor, reuse_in_index)) {
continue;
}
auto in_data_anchor = node->GetInDataAnchor(reuse_in_index);
GE_ASSERT_NOTNULL(in_data_anchor);
auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor();
GE_ASSERT_NOTNULL(peer_out_data_anchor);
auto ref_input_node = peer_out_data_anchor->GetOwnerNodeBarePtr();
if ((ref_input_node->GetType() == TENSORMOVE) && (peer_out_data_anchor->GetPeerInDataNodesSize() == 1U)) {
GELOGD("ref node %s input is TensorMove %s, this Tensormove has only one output, skip insert inner identity",
node->GetNamePtr(), ref_input_node->GetNamePtr());
continue;
}
if (IsNoNeedInsertIdentity(peer_out_data_anchor)) {
GELOGD("ref node %s input is %s, skip insert inner identity", node->GetNamePtr(), ref_input_node->GetNamePtr());
continue;
}
std::vector<InDataAnchorPtr> target_in_data_anchors;
for (const auto &peer_in_data_anchor : peer_out_data_anchor->GetPeerInDataAnchors()) {
GE_ASSERT_NOTNULL(peer_in_data_anchor);
if ((peer_in_data_anchor == in_data_anchor) || (node == peer_in_data_anchor->GetOwnerNode())) {
continue;
}
if (connectivity_->IsConnected(node, peer_in_data_anchor->GetOwnerNode())) {
target_in_data_anchors.emplace_back(peer_in_data_anchor);
}
}
auto identity = AddIdentity(node->GetOwnerComputeGraph(), peer_out_data_anchor, in_data_anchor);
GE_ASSERT_NOTNULL(identity);
for (const auto &peer_in_data_anchor : target_in_data_anchors) {
GE_ASSERT_SUCCESS(GraphUtils::RemoveEdge(peer_out_data_anchor, peer_in_data_anchor),
"RemoveEdge from %s to %s failed", ref_input_node->GetNamePtr(),
peer_in_data_anchor->GetOwnerNode()->GetNamePtr());
GE_ASSERT_SUCCESS(GraphUtils::AddEdge(identity->GetOutDataAnchor(0), peer_in_data_anchor));
GELOGD("Add edge from %s to %s", identity->GetNamePtr(),
peer_in_data_anchor->GetOwnerNode()->GetNamePtr());
}
}
}
return SUCCESS;
}
NodePtr InnerIdentityAddPass::AddIdentity(const ComputeGraphPtr &graph, const OutDataAnchorPtr &src_anchor,
const InDataAnchorPtr &dst_anchor) {
static size_t identity_count = 0;
std::string name = "inner_identity_" + std::to_string(identity_count++);
const auto identity_op = OperatorFactory::CreateOperator(name.c_str(), IDENTITY);
identity_op.BreakConnect();
const auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(identity_op);
GE_ASSERT_NOTNULL(op_desc);
GE_ASSERT_TRUE(AttrUtils::SetBool(op_desc, kInnerIdentityAttr, true));
GE_ASSERT_TRUE(AttrUtils::SetBool(op_desc, ATTR_NAME_CANNOT_BE_DELETED, true));
GE_ASSERT_TRUE(AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, true));
auto identity_node = graph->AddNode(op_desc);
GE_ASSERT_NOTNULL(identity_node);
GE_ASSERT_SUCCESS(GraphUtils::InsertNodeAfter(src_anchor, {dst_anchor}, identity_node, 0, 0),
"Insert node %s between %s and %s failed", identity_node->GetNamePtr(),
src_anchor->GetOwnerNode()->GetNamePtr(), dst_anchor->GetOwnerNode()->GetNamePtr());
connectivity_->Update(graph, {identity_node});
GELOGI("Add identity_node %s between src_node %s and dst_node %s", identity_node->GetName().c_str(),
src_anchor->GetOwnerNode()->GetName().c_str(), dst_anchor->GetOwnerNode()->GetName().c_str());
return identity_node;
}
REG_PASS_OPTION("InnerIdentityAddPass").LEVELS(OoLevel::kO0);
}