* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/passes/standard_optimize/same_transdata_breadth_fusion_pass.h"
#include <stack>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "framework/common/ge_inner_error_codes.h"
#include "framework/common/framework_types_internal.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/common/trans_op_creator.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "api/gelib/gelib.h"
#include "formats/utils/formats_trans_utils.h"
namespace ge {
namespace {
constexpr size_t kPrintNameLimit = 200U;
bool IsSingleOutputToMultiInput(const OutDataAnchorPtr &out_data_anchor) {
const auto &peer_in_data_anchors = out_data_anchor->GetPeerInDataAnchors();
if (peer_in_data_anchors.empty()) {
return false;
}
if (peer_in_data_anchors.size() > 1U) {
return true;
}
const auto peer_in_anchor = *peer_in_data_anchors.begin();
const auto next_node = peer_in_anchor->GetOwnerNode();
if ((next_node != nullptr) && (next_node->GetOpDesc() != nullptr)) {
if (next_node->GetOpDesc()->GetSubgraphInstanceNames().size() > 1U) {
return true;
}
}
return false;
}
bool IsTransOp(const NodePtr &node) {
if (node == nullptr) {
return false;
}
return node->GetType() == CAST || node->GetType() == TRANSPOSE || node->GetType() == TRANSPOSED ||
node->GetType() == RESHAPE || node->GetType() == TRANSDATA;
}
std::unordered_map<std::string, NodeType> kNodeTypeMap = {
{NETOUTPUT, NodeType::kNetOutput},
{DATA, NodeType::kData},
{TRANSDATA, NodeType::kTransdata},
{CAST, NodeType::kCast}
};
NodeType GetNodeType(const NodePtr &node) {
const auto type = node->GetType();
const auto iter = kNodeTypeMap.find(type);
if (iter != kNodeTypeMap.end()) {
return iter->second;
}
if (!node->GetOpDescBarePtr()->GetSubgraphInstanceNames().empty()) {
return NodeType::kWrapperNode;
}
return NodeType::kOthers;
}
bool IsWrapperNode(const OpDescPtr &op_desc) {
return !op_desc->GetSubgraphInstanceNames().empty();
}
void PrintPaths(const Paths &paths) {
GELOGI("Paths size: %zu", paths.size());
for (size_t i = 0U; i < paths.size(); ++i) {
if (!paths[i].empty()) {
std::stringstream ss;
for (size_t j = 0U; j < paths[i].size(); ++j) {
const auto &link_node = paths[i][j];
const auto &pre_node = link_node.real_peer_out_anchor->GetOwnerNode();
const auto out_index = link_node.real_peer_out_anchor->GetIdx();
const auto in_index = link_node.in_anchor->GetIdx();
const auto &print_name = pre_node->GetName().size() > kPrintNameLimit ?
pre_node->GetName().substr(kPrintNameLimit) : pre_node->GetName();
ss << "[" << pre_node->GetType() << "][" << print_name << "][" << out_index << "]->[" << in_index << "]";
}
const auto &link_node = paths[i].back();
const auto &last_node = link_node.in_anchor->GetOwnerNode();
const auto &print_name = last_node->GetName().size() > kPrintNameLimit ?
last_node->GetName().substr(kPrintNameLimit) : last_node->GetName();
ss << "[" << last_node->GetType() << "][" << print_name << "]";
GELOGI("paths[%zu]: %s", i, ss.str().c_str());
}
}
}
void PrintFusedTransdata(const std::vector<Paths> &same_transdata_paths_groups) {
for (const auto &same_transdata_paths : same_transdata_paths_groups) {
std::stringstream ss;
auto iter = same_transdata_paths.begin();
while (iter != same_transdata_paths.end()) {
ss << iter->back().in_anchor->GetOwnerNodeBarePtr()->GetName();
if (iter + 1 != same_transdata_paths.end()) {
ss << ", ";
}
++iter;
}
GELOGI("transdata node [%s] will be fused.", ss.str().c_str());
}
}
graphStatus CheckOpSupported(const Path &path, const LinkNode &link_node, bool &is_supported) {
if (path.empty()) {
is_supported = true;
return GRAPH_SUCCESS;
}
const auto &transdata = link_node.in_anchor->GetOwnerNode();
const auto &first_cast = path.front().in_anchor->GetOwnerNode();
auto input_desc = first_cast->GetOpDesc()->GetInputDescPtr(0U);
GE_CHECK_NOTNULL(input_desc);
auto trans_op_desc = transdata->GetOpDesc();
auto trans_out_desc = trans_op_desc->GetOutputDescPtr(0U);
GE_CHECK_NOTNULL(trans_out_desc);
const std::shared_ptr<GeTensorDesc> new_trans_out_tensor = MakeShared<GeTensorDesc>(*trans_out_desc);
GE_CHECK_NOTNULL(new_trans_out_tensor);
new_trans_out_tensor->SetDataType(input_desc->GetDataType());
auto transdata_op_desc =
TransOpCreator::CreateTransDataOp("tmp_" + trans_op_desc->GetName(), *input_desc, *new_trans_out_tensor);
bool is_transdata_supported = false;
bool is_cast_supported = false;
if (transdata_op_desc != nullptr) {
is_transdata_supported = true;
}
GELOGD("transdata: %s, [%s->%s], new datatype: %s, is_transdata_supported: %d.", transdata->GetNamePtr(),
TypeUtils::FormatToSerialString(input_desc->GetFormat()).c_str(),
TypeUtils::FormatToSerialString(trans_out_desc->GetFormat()).c_str(),
TypeUtils::DataTypeToSerialString(input_desc->GetDataType()).c_str(), is_transdata_supported);
for (const auto &cast_link_node : path) {
const auto &cast_node = cast_link_node.in_anchor->GetOwnerNode();
auto cast_input_desc = cast_node->GetOpDesc()->GetInputDescPtr(0U);
GE_CHECK_NOTNULL(cast_input_desc);
auto cast_output_desc = cast_node->GetOpDesc()->GetOutputDescPtr(0U);
GE_CHECK_NOTNULL(cast_output_desc);
const std::shared_ptr<GeTensorDesc> new_cast_in_tensor = MakeShared<GeTensorDesc>(*cast_input_desc);
GE_CHECK_NOTNULL(new_cast_in_tensor);
const std::shared_ptr<GeTensorDesc> new_cast_out_tensor = MakeShared<GeTensorDesc>(*cast_output_desc);
GE_CHECK_NOTNULL(new_cast_out_tensor);
new_cast_in_tensor->SetFormat(trans_out_desc->GetFormat());
new_cast_out_tensor->SetFormat(trans_out_desc->GetFormat());
auto cast_op_desc =
TransOpCreator::CreateCastOp("tmp_" + cast_node->GetName(), *new_cast_in_tensor, *new_cast_out_tensor);
if (cast_op_desc != nullptr) {
is_cast_supported = true;
}
GELOGD("cast: %s, [%s->%s], new format: %s, is_cast_supported: %d.", cast_node->GetName().c_str(),
TypeUtils::DataTypeToSerialString(cast_input_desc->GetDataType()).c_str(),
TypeUtils::DataTypeToSerialString(cast_output_desc->GetDataType()).c_str(),
TypeUtils::FormatToSerialString(trans_out_desc->GetFormat()).c_str(), is_cast_supported);
if (!is_cast_supported) {
break;
}
}
is_supported = is_transdata_supported && is_cast_supported;
return GRAPH_SUCCESS;
}
std::set<std::string> GetInControlIdentityNodes(const Path &path, const LinkNode &transdata_link_node) {
std::set<std::string> in_node_names;
const auto &transdata_node = transdata_link_node.in_anchor->GetOwnerNode();
for (const auto &in_node : transdata_node->GetInControlNodes()) {
in_node_names.insert(in_node->GetName());
}
for (const auto &link_node : path) {
if (link_node.node_type == NodeType::kTransdata) {
break;
}
const auto node = link_node.in_anchor->GetOwnerNode();
for (const auto &in_node : node->GetInControlNodes()) {
if (in_node->GetType() == IDENTITY) {
in_node_names.insert(in_node->GetName());
}
}
}
return in_node_names;
}
bool IsSame(const CompareInfo &lh, const CompareInfo &rh) {
const bool all_same = (lh.stream_label == rh.stream_label) &&
(lh.input_tensor_desc->GetFormat() == rh.input_tensor_desc->GetFormat()) &&
(lh.output_tensor_desc->GetFormat() == rh.output_tensor_desc->GetFormat()) &&
(lh.input_tensor_desc->GetOriginShape() == rh.input_tensor_desc->GetOriginShape()) &&
(lh.output_tensor_desc->GetOriginShape() == rh.output_tensor_desc->GetOriginShape()) &&
(lh.in_ctrl_nodes == rh.in_ctrl_nodes);
return all_same;
}
void PrintCompareInfo(const NodePtr &node, const CompareInfo &info) {
std::stringstream ss;
ss << "[";
for (auto &item : info.in_ctrl_nodes) {
ss << item << ", ";
}
ss << "]";
GELOGI("transdata[%s] compare info: [IN %s[0x%x] [%s]], [OUT %s[0x%x] [%s]], stream_label: \"%s\", in_ctrl_nodes: %s",
node->GetNamePtr(),
TypeUtils::FormatToSerialString(info.input_tensor_desc->GetFormat()).c_str(),
info.input_tensor_desc->GetFormat(),
formats::ShapeToString(info.input_tensor_desc->GetOriginShape()).c_str(),
TypeUtils::FormatToSerialString(info.output_tensor_desc->GetFormat()).c_str(),
info.output_tensor_desc->GetFormat(),
formats::ShapeToString(info.output_tensor_desc->GetOriginShape()).c_str(),
info.stream_label.c_str(), ss.str().c_str());
}
graphStatus CopyTensorDesc(const GeTensorDesc &src_desc, GeTensorDescPtr &dst_desc) {
GE_ASSERT_NOTNULL(dst_desc);
dst_desc->SetFormat(src_desc.GetFormat());
dst_desc->SetOriginFormat(src_desc.GetOriginFormat());
dst_desc->SetShape(src_desc.GetShape());
dst_desc->SetOriginShape(src_desc.GetOriginShape());
std::vector<std::pair<int64_t, int64_t>> shape_range;
GE_ASSERT_SUCCESS(src_desc.GetShapeRange(shape_range));
GE_ASSERT_SUCCESS(dst_desc->SetShapeRange(shape_range));
std::vector<std::pair<int64_t, int64_t>> origin_shape_range;
GE_ASSERT_SUCCESS(src_desc.GetOriginShapeRange(origin_shape_range));
GE_ASSERT_SUCCESS(dst_desc->SetOriginShapeRange(origin_shape_range));
uint32_t real_dim = 0;
if (TensorUtils::GetRealDimCnt(src_desc, real_dim) == GRAPH_SUCCESS) {
TensorUtils::SetRealDimCnt(*dst_desc, real_dim);
}
return GRAPH_SUCCESS;
}
void UpdateDataType(const GeTensorDescPtr &src_desc, GeTensorDescPtr &dst_desc) {
dst_desc->SetDataType(src_desc->GetDataType());
dst_desc->SetOriginDataType(src_desc->GetOriginDataType());
}
* head--+--cast1--transdata1
* |
* +--cast2--transdata2
* 不能直接用transdata的dtype,比如这种场景,显然head的输出dtype和transdata的dtype是不一样的
*/
graphStatus UpdateTransdataDtype(const Node *trans, const OutDataAnchorPtr &head_out_anchor) {
const auto head_op_desc = head_out_anchor->GetOwnerNode()->GetOpDesc();
const auto head_tensor_desc = head_op_desc->GetOutputDesc(head_out_anchor->GetIdx());
auto trans_op_desc = trans->GetOpDesc();
GE_ASSERT_NOTNULL(trans_op_desc->MutableInputDesc(0));
GE_ASSERT_NOTNULL(trans_op_desc->MutableOutputDesc(0));
trans_op_desc->MutableInputDesc(0)->SetDataType(head_tensor_desc.GetDataType());
trans_op_desc->MutableInputDesc(0)->SetOriginDataType(head_tensor_desc.GetOriginDataType());
trans_op_desc->MutableOutputDesc(0)->SetDataType(head_tensor_desc.GetDataType());
trans_op_desc->MutableOutputDesc(0)->SetOriginDataType(head_tensor_desc.GetOriginDataType());
return GRAPH_SUCCESS;
}
NodePtr CreateDataNode(ComputeGraphPtr &sub_graph, const size_t parent_index) {
const auto data_name = sub_graph->GetName() + "_transdata_fusion_arg_" +std::to_string(parent_index);
OpDescPtr op_desc = MakeShared<OpDesc>(data_name, DATA);
GE_ASSERT_NOTNULL(op_desc);
GE_ASSERT_TRUE(AttrUtils::SetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index));
const GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT);
GE_ASSERT_SUCCESS(op_desc->AddInputDesc(data_desc));
GE_ASSERT_SUCCESS(op_desc->AddOutputDesc(data_desc));
NodePtr data_node = sub_graph->AddNode(op_desc);
GE_ASSERT_NOTNULL(data_node);
GELOGI("add new data[%s]", data_name.c_str());
return data_node;
}
size_t GetKeepTransdataPathIndex(const Paths &paths_group) {
size_t keep_transdata_path_index = 0U;
for (size_t i = 1U; i < paths_group.size(); ++i) {
if (paths_group[i].back().in_anchor->GetOwnerNode()->GetOpDesc()->GetId() <
paths_group[keep_transdata_path_index].back().in_anchor->GetOwnerNode()->GetOpDesc()->GetId()) {
keep_transdata_path_index = i;
}
}
return keep_transdata_path_index;
}
* data--partitioned_call--transdata
*
* +---------------------+
* | partitioned_call |
* | |
* | sub_data |
* | | |
* | Netoutput |
* +---------------------+
*
* 为什么要判断类型对端必须是data节点?
* 因为上图中也满足transdata和data在一个子图中的条件,但是不能在这里处理,应该在AddNewInputForNetOutput中处理
*/
graphStatus ConnetToFusedAnchors(std::vector<InDataAnchorPtr> &fused_anchors, const ComputeGraph *const sub_graph,
OutDataAnchorPtr &new_data_out_anchor) {
auto iter = fused_anchors.begin();
while (iter != fused_anchors.end()) {
auto fused_in_anchor = *iter;
GE_ASSERT_NOTNULL(fused_in_anchor->GetPeerOutAnchor());
if ((fused_in_anchor->GetOwnerNode()->GetOwnerComputeGraphBarePtr() == sub_graph)
&& (fused_in_anchor->GetPeerOutAnchor()->GetOwnerNodeBarePtr()->GetType() == DATA)) {
GE_ASSERT_SUCCESS(fused_in_anchor->Unlink(fused_in_anchor->GetPeerOutAnchor()));
GE_ASSERT_NOTNULL(new_data_out_anchor);
GE_ASSERT_SUCCESS(new_data_out_anchor->LinkTo(fused_in_anchor));
iter = fused_anchors.erase(iter);
} else {
++iter;
}
}
return GRAPH_SUCCESS;
}
bool IsAllNodeInPathsWithSameTransdata(const std::set<InDataAnchorPtr> &allowed_set, std::queue<Path> &path_queue) {
while (!path_queue.empty()) {
const auto &path = path_queue.front();
for (const auto &node : path) {
if (allowed_set.find(node.in_anchor) == allowed_set.end()) {
const auto node_ptr = node.in_anchor->GetOwnerNodeBarePtr();
GELOGI("node: %s, type: %s is not in allowed_set", node_ptr->GetNamePtr(), node_ptr->GetTypePtr());
return false;
}
}
path_queue.pop();
}
return true;
}
graphStatus MoveTransdataOutDataEdge(NodePtr &trans) {
auto trans_out = trans->GetOutDataAnchor(0);
GE_ASSERT_NOTNULL(trans_out);
auto trans_in = trans->GetInDataAnchor(0);
GE_ASSERT_NOTNULL(trans_in);
auto pre_out = trans_in->GetPeerOutAnchor();
GE_ASSERT_NOTNULL(pre_out);
for (auto &peer_in : trans_out->GetPeerInDataAnchors()) {
GE_ASSERT_SUCCESS(peer_in->Unlink(trans_out));
GE_ASSERT_SUCCESS(pre_out->LinkTo(peer_in));
}
return GRAPH_SUCCESS;
}
graphStatus MoveTransdataInCtrlEdge(NodePtr &trans) {
auto trans_out = trans->GetOutDataAnchor(0);
GE_ASSERT_NOTNULL(trans_out);
for (auto &next_node : trans->GetOutDataNodes()) {
GE_ASSERT_SUCCESS(GraphUtils::MoveInCtrlEdges(trans, next_node));
}
for (auto &next_node : trans->GetOutControlNodes()) {
GE_ASSERT_SUCCESS(GraphUtils::MoveInCtrlEdges(trans, next_node));
}
return GRAPH_SUCCESS;
}
* a
* |
* if then_subgraph then_subgraph
* | +-------------+ +-------------+
* transdata | data1 | | data1 |
* | | | | | | |
* c | netoutput | | b |
* +-------------+ | | |
* | netoutput |
* +-------------+
* 约束:多子图贯穿场景,path从子图外到子图内再到子图外,约束所有子图上都存在从a到达transdata的路径
* 完整实现这个约束比较复杂,当前为简单实现,不允许贯穿具有多子图的节点
*/
bool IsHeadNodeOutOfSubgraph(const NodePtr &netoutput, const NodePtr &head_node) {
auto sub_graph = netoutput->GetOwnerComputeGraphBarePtr();
while ((sub_graph != nullptr) && (sub_graph->GetParentNodeBarePtr() != nullptr)) {
auto parent_node_graph = sub_graph->GetParentNodeBarePtr()->GetOwnerComputeGraphBarePtr();
if (parent_node_graph == head_node->GetOwnerComputeGraphBarePtr()) {
return true;
}
sub_graph = parent_node_graph;
}
return false;
}
}
graphStatus SameTransdataBreadthFusionPass::Run(ComputeGraphPtr graph) {
if (root_graph_ == nullptr) {
root_graph_ = GraphUtils::FindRootGraph(graph);
GE_ASSERT_NOTNULL(root_graph_);
}
if (graph != root_graph_) {
GELOGI("sub graph %s in, just return.", graph->GetName().c_str());
return GRAPH_SUCCESS;
}
GE_ASSERT_SUCCESS(graph->TopologicalSorting());
GE_DUMP(graph, "before_SameTransdataBreadthFusionPass");
GELOGI("[SameTransdataBreadthFusionPass]: optimize begin, graph: %s.", graph->GetName().c_str());
GE_ASSERT_SUCCESS(CollectAllSubgraphDataNodesMap());
GE_ASSERT_SUCCESS(DoRun(graph));
GELOGI("[SameTransdataBreadthFusionPass]: Optimize success.");
GE_DUMP(graph, "after_SameTransdataBreadthFusionPass");
return GRAPH_SUCCESS;
}
graphStatus SameTransdataBreadthFusionPass::CollectAllSubgraphDataNodesMap() {
for (const auto &sub_graph : root_graph_->GetAllSubgraphs()) {
if ((sub_graph == nullptr) || (sub_graph->GetParentNode() == nullptr)) {
continue;
}
auto &data_nodes = graph_nodes_[sub_graph];
for (const auto &data : sub_graph->GetDirectNode()) {
GE_CHECK_NOTNULL(data->GetOpDesc());
if (data->GetType() != DATA) {
continue;
}
uint32_t parent_index = 0;
GE_ASSERT_TRUE(AttrUtils::GetInt(data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index),
"[Get][Attr] %s from op:%s(%s) failed", ATTR_NAME_PARENT_NODE_INDEX.c_str(),
data->GetName().c_str(), data->GetType().c_str());
data_nodes[parent_index] = data;
}
}
return GRAPH_SUCCESS;
}
graphStatus SameTransdataBreadthFusionPass::DoRun(ComputeGraphPtr graph) {
for (const auto &head_node : graph->GetAllNodes()) {
if (IsTransOp(head_node)) {
continue;
}
for (auto &head_out_anchor : head_node->GetAllOutDataAnchors()) {
if (!IsSingleOutputToMultiInput(head_out_anchor)) {
continue;
}
GE_ASSERT_SUCCESS(RunForNode(head_out_anchor));
}
}
return GRAPH_SUCCESS;
}
graphStatus SameTransdataBreadthFusionPass::RunForNode(OutDataAnchorPtr &head_out_anchor) {
Paths paths;
GE_ASSERT_SUCCESS(GetPathsToTransdata(head_out_anchor, paths));
if (paths.size() <= 1U) {
return GRAPH_SUCCESS;
}
PrintPaths(paths);
GE_ASSERT_SUCCESS(FuseTransdata(paths));
return GRAPH_SUCCESS;
}
graphStatus SameTransdataBreadthFusionPass::GetPathsToTransdata(const OutDataAnchorPtr &head_out_anchor,
Paths &paths) const {
std::queue<Path> path_queue;
bool is_supported = false;
GE_ASSERT_SUCCESS(GetRealInAnchors(head_out_anchor, head_out_anchor, path_queue, Path()));
while (!path_queue.empty()) {
auto cur_path = path_queue.front();
path_queue.pop();
const auto &link_node = cur_path.back();
const auto &owner_node = link_node.in_anchor->GetOwnerNode();
switch (link_node.node_type) {
case NodeType::kTransdata :
GE_ASSERT_SUCCESS(CheckOpSupported(cur_path, link_node, is_supported));
if ((owner_node->GetOutDataNodesSize() != 0U) && is_supported) {
paths.emplace_back(std::move(cur_path));
}
break;
case NodeType::kCast:
for (const auto &cast_out_anchor : owner_node->GetAllOutDataAnchors()) {
GE_ASSERT_SUCCESS(GetRealInAnchors(cast_out_anchor, cast_out_anchor, path_queue, cur_path));
}
break;
case NodeType::kOthers:
break;
case NodeType::kData:
case NodeType::kNetOutput:
case NodeType::kWrapperNode:
GELOGE(FAILED, "type: %s node name: %s should not be here.",
owner_node->GetTypePtr(), owner_node->GetNamePtr());
return GRAPH_FAILED;
}
}
return GRAPH_SUCCESS;
}
graphStatus SameTransdataBreadthFusionPass::GetRealInAnchors(const OutDataAnchorPtr &real_out_anchor,
const OutDataAnchorPtr &out_anchor,
std::queue<Path> &path_queue,
const Path &path) const {
std::stack<OutDataAnchorPtr> out_anchor_stack;
out_anchor_stack.push(out_anchor);
while (!out_anchor_stack.empty()) {
const auto cur_out_anchor = out_anchor_stack.top();
out_anchor_stack.pop();
for (const auto &in_anchor : cur_out_anchor->GetPeerInDataAnchors()) {
const auto next_node = in_anchor->GetOwnerNode();
if ((next_node != nullptr) && (next_node->GetOpDesc() != nullptr)) {
if (IsWrapperNode(next_node->GetOpDesc())) {
GE_ASSERT_SUCCESS(GetRealInAnchorsForWrapperNode(in_anchor, out_anchor_stack));
} else if (next_node->GetType() == NETOUTPUT) {
GE_ASSERT_SUCCESS(GetRealInAnchorsForNetOutput(real_out_anchor, in_anchor, path, out_anchor_stack));
} else {
Path new_path(path);
new_path.emplace_back(LinkNode{in_anchor, real_out_anchor, GetNodeType(next_node)});
path_queue.push(std::move(new_path));
}
}
}
}
return GRAPH_SUCCESS;
}
graphStatus SameTransdataBreadthFusionPass::GetRealInAnchorsForWrapperNode(
const InDataAnchorPtr &in_anchor, std::stack<OutDataAnchorPtr> &out_anchor_stack) const {
const auto op_desc = in_anchor->GetOwnerNode()->GetOpDesc();
for (const auto &name : op_desc->GetSubgraphInstanceNames()) {
const auto &sub_graph = root_graph_->GetSubgraph(name);
GE_ASSERT_NOTNULL(sub_graph);
OutDataAnchorPtr data_out_anchor = nullptr;
GE_ASSERT_SUCCESS(GetSubgraphDataOutAnchor(sub_graph, in_anchor->GetIdx(), data_out_anchor));
if (data_out_anchor == nullptr) {
continue;
}
out_anchor_stack.push(data_out_anchor);
}
return GRAPH_SUCCESS;
}
graphStatus SameTransdataBreadthFusionPass::GetSubgraphDataOutAnchor(const ComputeGraphPtr &sub_graph,
const int32_t wrapper_node_input_index,
OutDataAnchorPtr &data_out_anchor) const {
GE_ASSERT_NOTNULL(sub_graph);
const auto &sub_graph_iter = graph_nodes_.find(sub_graph);
GE_ASSERT_TRUE(sub_graph_iter != graph_nodes_.end());
const auto &data_nodes = sub_graph_iter->second;
const auto &data_iter = data_nodes.find(wrapper_node_input_index);
if (data_iter != data_nodes.end()) {
data_out_anchor = data_iter->second->GetOutDataAnchor(0);
} else {
data_out_anchor = nullptr;
}
return GRAPH_SUCCESS;
}
graphStatus SameTransdataBreadthFusionPass::GetRealInAnchorsForNetOutput(
const OutDataAnchorPtr &real_out_anchor, const InDataAnchorPtr &in_anchor, const Path &path,
std::stack<OutDataAnchorPtr> &out_anchor_stack) const {
const auto &netoutput = in_anchor->GetOwnerNode();
const auto &op_desc = netoutput->GetOpDesc();
const auto in_tensor_desc = op_desc->GetInputDesc(static_cast<uint32_t>(in_anchor->GetIdx()));
uint32_t parent_index;
if (!AttrUtils::GetInt(in_tensor_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
return GRAPH_SUCCESS;
}
GE_ASSERT_NOTNULL(netoutput->GetOwnerComputeGraphBarePtr());
const auto &wrapper_node = netoutput->GetOwnerComputeGraphBarePtr()->GetParentNode();
GE_ASSERT_NOTNULL(wrapper_node);
* 如果父节点有多个子图,并且head在子图内,要求path不能通过netoutput延伸到子图外。
* 因为如果transdata提取到某一子图内,那么其他子图就不等价了
*/
const auto &wrapper_op_desc = wrapper_node->GetOpDesc();
GE_ASSERT_NOTNULL(wrapper_op_desc);
if ((wrapper_op_desc->GetSubgraphInstanceNames().size() > 1U) && (real_out_anchor != nullptr)) {
const auto &head_node = path.empty() ? real_out_anchor->GetOwnerNode() :
path.front().real_peer_out_anchor->GetOwnerNode();
GE_ASSERT_NOTNULL(head_node);
if (IsHeadNodeOutOfSubgraph(netoutput, head_node)
|| (head_node->GetOwnerComputeGraphBarePtr() == netoutput->GetOwnerComputeGraphBarePtr())) {
GELOGI("wrapper_node:%s has %zu subgraphs, head_node: %s is in subgraph[%s], stop search through netoutput: %s",
wrapper_node->GetNamePtr(), wrapper_op_desc->GetSubgraphInstanceNames().size(), head_node->GetNamePtr(),
head_node->GetOwnerComputeGraphBarePtr()->GetName().c_str(), netoutput->GetNamePtr());
return GRAPH_SUCCESS;
}
}
const auto &out_anchor = wrapper_node->GetOutDataAnchor(parent_index);
GE_ASSERT_NOTNULL(out_anchor, "wrapper_node: %s, parent_index: %u", wrapper_node->GetNamePtr(), parent_index);
out_anchor_stack.push(out_anchor);
return GRAPH_SUCCESS;
}
graphStatus SameTransdataBreadthFusionPass::FuseTransdata(Paths &paths) {
std::vector<Paths> same_transdata_paths_groups;
GE_ASSERT_SUCCESS(GetSameTransdataPath(paths, same_transdata_paths_groups));
auto iter = same_transdata_paths_groups.begin();
while (iter != same_transdata_paths_groups.end()) {
GE_ASSERT_SUCCESS(RemoveUnSupportedPath(*iter));
if (iter->size() <= 1U) {
iter = same_transdata_paths_groups.erase(iter);
} else {
++iter;
}
}
PrintFusedTransdata(same_transdata_paths_groups);
for (auto &paths_group : same_transdata_paths_groups) {
GE_ASSERT_SUCCESS(AddNewPathToTransdataForDiffGraph(paths_group));
const auto keep_transdata_path_index = GetKeepTransdataPathIndex(paths_group);
GE_ASSERT_SUCCESS(UpdateTensorDesc(paths_group, keep_transdata_path_index));
GE_ASSERT_SUCCESS(ExtractTransdata(paths_group, keep_transdata_path_index));
for (size_t i = 0U; i < paths_group.size(); ++i) {
if (i != keep_transdata_path_index) {
GE_ASSERT_SUCCESS(DeleteTransdata(paths_group[i]));
}
}
}
return GRAPH_SUCCESS;
}
graphStatus SameTransdataBreadthFusionPass::GetSameTransdataPath(Paths &paths,
std::vector<Paths> &same_transdata_paths_groups) {
while (paths.size() > 1U) {
auto iter = paths.begin();
Paths same_transdata_paths;
same_transdata_paths.emplace_back(std::move(*iter));
const auto &first_path = same_transdata_paths.front();
iter = paths.erase(iter);
LinkNode first_transdata = first_path.back();
CompareInfo first_info;
GE_ASSERT_SUCCESS(GetCompareInfo(first_path, first_transdata, first_info));
while (iter != paths.end()) {
auto &another_path = *iter;
LinkNode another_transdata = another_path.back();
CompareInfo another_info;
GE_ASSERT_SUCCESS(GetCompareInfo(another_path, another_transdata, another_info));
if (IsSame(first_info, another_info)) {
same_transdata_paths.emplace_back(std::move(another_path));
iter = paths.erase(iter);
} else {
GELOGI("%s is different from %s", another_transdata.in_anchor->GetOwnerNode()->GetNamePtr(),
first_transdata.in_anchor->GetOwnerNode()->GetNamePtr());
iter++;
}
}
if (same_transdata_paths.size() > 1U) {
same_transdata_paths_groups.emplace_back(std::move(same_transdata_paths));
}
}
return GRAPH_SUCCESS;
}
* +----------------------+
* +---transdata1 | partitioned_call |
* | | |
* op--+---transdata2 | +---add |
* | | | |
* +---partitioned_call | data--+---transdta3 |
* +----------------------+
*
* |
* \|/
*
* +----------------------+
* +---transdata1 | partitioned_call |
* | | |
* op--+---transdata2 | data---add |
* | | |
* +---partitioned_call | data_1---transdta3 |
* | | +----------------------+
* +------+
* add new path from op to transdata3
*/
graphStatus SameTransdataBreadthFusionPass::AddNewPathToTransdataForDiffGraph(Paths &paths_group) {
auto head_out_anchor = paths_group[0].front().real_peer_out_anchor;
std::set<InDataAnchorPtr> allowed_in_anchors;
for (const auto &cur_path : paths_group) {
allowed_in_anchors.emplace(cur_path.front().in_anchor);
}
GE_ASSERT_SUCCESS(AddNewPath(head_out_anchor, head_out_anchor, allowed_in_anchors));
return GRAPH_SUCCESS;
}
graphStatus SameTransdataBreadthFusionPass::AddNewPath(OutDataAnchorPtr &out_anchor,
OutDataAnchorPtr &new_out_anchor,
const std::set<InDataAnchorPtr> &allowed_in_anchors) {
AnchorPairStack out_anchor_pair_stack;
out_anchor_pair_stack.push(std::make_pair(out_anchor, new_out_anchor));
while (!out_anchor_pair_stack.empty()) {
const auto cur_out_anchor = out_anchor_pair_stack.top().first;
const auto cur_new_out_anchor = out_anchor_pair_stack.top().second;
out_anchor_pair_stack.pop();
for (auto &peer_in_anchor : cur_out_anchor->GetPeerInDataAnchors()) {
const auto head_next = peer_in_anchor->GetOwnerNode();
GE_ASSERT_NOTNULL(head_next->GetOpDescBarePtr());
const auto head_next_type = GetNodeType(head_next);
if ((head_next_type != NodeType::kWrapperNode) && (head_next_type != NodeType::kNetOutput)) {
if (allowed_in_anchors.find(peer_in_anchor) != allowed_in_anchors.end()) {
GE_ASSERT_SUCCESS(peer_in_anchor->Unlink(peer_in_anchor->GetPeerOutAnchor()));
GE_ASSERT_SUCCESS(cur_new_out_anchor->LinkTo(peer_in_anchor));
}
continue;
}
std::vector<InDataAnchorPtr> fused_anchors;
std::vector<InDataAnchorPtr> not_fused_anchors;
GE_ASSERT_SUCCESS(CollectFusedInAnchors(peer_in_anchor, allowed_in_anchors, head_next_type,
fused_anchors, not_fused_anchors));
if (fused_anchors.empty() || not_fused_anchors.empty()) {
* peer_in_anchor 对应的真实的节点都是不能和当前transdata融合在一起的,
* or, peer_in_anchor 对应的真实的节点都能和当前transdata融合在一起的
*/
continue;
}
* 部分可以融合,其余的不能融合;real_in_anchor对端必然是子图的Data节点。
* 有两种:1是子图Data单输出多引用,输出节点有的可以被融合,有的不能被融合;
* 2是多个子图,有的子图Data节点对应的real_in_anchor可以被融合,有的Data对应的real_in_anchor
* 1和2有可能同时出现。
* 需要在子图中新增一个Data节点连接到可被融合的real_in_anchor上
*
* 为每个子图新增data,即使有的子图中没有可以融合的transdata;
* 为每个netoutput新增输入,即使它对应的实际输出节点没有可以融合的transdata
*/
const auto input_size = head_next->GetAllInDataAnchorsSize();
GE_ASSERT_SUCCESS(NodeUtils::AppendInputAnchor(head_next, input_size + 1U));
GELOGI("add new input for %s[%s], new input index:%u",
head_next->GetTypePtr(), head_next->GetNamePtr(), input_size);
auto new_in_anchor = head_next->GetInDataAnchor(input_size);
GE_ASSERT_SUCCESS(cur_new_out_anchor->LinkTo(new_in_anchor));
if (head_next_type == NodeType::kWrapperNode) {
GE_ASSERT_SUCCESS(AddNewInputForWrapper(peer_in_anchor, fused_anchors, out_anchor_pair_stack));
} else {
GE_ASSERT_SUCCESS(AddNewInputForNetOutput(peer_in_anchor, fused_anchors, out_anchor_pair_stack));
}
}
}
return GRAPH_SUCCESS;
}
graphStatus SameTransdataBreadthFusionPass::AddNewInputForWrapper(InDataAnchorPtr &wrapper_in_anchor,
std::vector<InDataAnchorPtr> &fused_anchors,
AnchorPairStack &out_anchor_pair_stack) {
if (fused_anchors.empty()) {
return GRAPH_SUCCESS;
}
auto owner_node = wrapper_in_anchor->GetOwnerNode();
auto op_desc = owner_node->GetOpDesc();
const auto parent_index = owner_node->GetAllInDataAnchorsSize() - 1U;
for (const auto &name : op_desc->GetSubgraphInstanceNames()) {
auto sub_graph = root_graph_->GetSubgraph(name);
GE_ASSERT_NOTNULL(sub_graph);
auto new_data = CreateDataNode(sub_graph, parent_index);
GE_ASSERT_NOTNULL(new_data);
UpdateGraphNode(sub_graph, parent_index, new_data);
auto new_data_out_anchor = new_data->GetOutDataAnchor(0);
GE_ASSERT_NOTNULL(new_data_out_anchor);
GE_ASSERT_SUCCESS(ConnetToFusedAnchors(fused_anchors, sub_graph.get(), new_data_out_anchor));
if (fused_anchors.empty()) {
return GRAPH_SUCCESS;
}
OutDataAnchorPtr sub_graph_data_out_anchor;
GE_ASSERT_SUCCESS(GetSubgraphDataOutAnchor(sub_graph, wrapper_in_anchor->GetIdx(), sub_graph_data_out_anchor));
if (sub_graph_data_out_anchor == nullptr) {
continue;
}
out_anchor_pair_stack.push(std::make_pair(sub_graph_data_out_anchor, new_data_out_anchor));
}
return GRAPH_SUCCESS;
}
graphStatus SameTransdataBreadthFusionPass::AddNewInputForNetOutput(InDataAnchorPtr &netout_in_anchor,
std::vector<InDataAnchorPtr> &fused_anchors,
AnchorPairStack &out_anchor_pair_stack) const {
if (fused_anchors.empty()) {
return GRAPH_SUCCESS;
}
auto netoutput = netout_in_anchor->GetOwnerNode();
auto netoutput_op_desc = netoutput->GetOpDesc();
GE_ASSERT_NOTNULL(netoutput_op_desc);
auto in_tensor_desc = netoutput_op_desc->GetInputDesc(static_cast<uint32_t>(netout_in_anchor->GetIdx()));
uint32_t parent_index;
if (!AttrUtils::GetInt(in_tensor_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
GELOGW("node %s(%s) %d input does not has %s attr.", netoutput_op_desc->GetNamePtr(),
netoutput_op_desc->GetTypePtr(), netout_in_anchor->GetIdx(), ATTR_NAME_PARENT_NODE_INDEX.c_str());
return GRAPH_SUCCESS;
}
GE_ASSERT_NOTNULL(netoutput->GetOwnerComputeGraphBarePtr());
auto wrapper_node = netoutput->GetOwnerComputeGraphBarePtr()->GetParentNode();
GE_ASSERT_NOTNULL(wrapper_node);
const auto out_size = wrapper_node->GetAllOutDataAnchorsSize();
GE_ASSERT_SUCCESS(NodeUtils::AppendOutputAnchor(wrapper_node, out_size + 1U));
GELOGI("add new input for %s[%s], new input index:%u",
wrapper_node->GetTypePtr(), wrapper_node->GetNamePtr(), out_size);
const auto input_index = netoutput->GetAllInDataAnchorsSize() - 1U;
auto new_in_tensor_desc = netoutput_op_desc->MutableInputDesc(input_index);
GE_ASSERT_TRUE(AttrUtils::SetInt(new_in_tensor_desc, ATTR_NAME_PARENT_NODE_INDEX, out_size));
auto new_out_anchor = wrapper_node->GetOutDataAnchor(out_size);
GE_ASSERT_SUCCESS(ConnetToFusedAnchors(fused_anchors, wrapper_node->GetOwnerComputeGraphBarePtr(), new_out_anchor));
if (fused_anchors.empty()) {
return GRAPH_SUCCESS;
}
auto out_anchor = wrapper_node->GetOutDataAnchor(parent_index);
GE_ASSERT_NOTNULL(out_anchor);
out_anchor_pair_stack.push(std::make_pair(out_anchor, new_out_anchor));
return GRAPH_SUCCESS;
}
void SameTransdataBreadthFusionPass::UpdateGraphNode(const ComputeGraphPtr &sub_graph,
const uint32_t parent_index, NodePtr &node) {
graph_nodes_[sub_graph][parent_index] = node;
}
graphStatus SameTransdataBreadthFusionPass::RemoveUnSupportedPath(Paths &paths_with_same_transdata) const {
std::set<InDataAnchorPtr> allowed_set;
for (const auto &path : paths_with_same_transdata) {
for (const auto &node : path) {
allowed_set.emplace(node.in_anchor);
}
}
auto path_iter = paths_with_same_transdata.begin();
while (path_iter != paths_with_same_transdata.end()) {
std::queue<Path> path_queue;
for (size_t i = 0U; i < path_iter->size() - 1U; ++i) {
const auto &cast_node = path_iter->at(i);
for (auto &out_data_anchor : cast_node.in_anchor->GetOwnerNodeBarePtr()->GetAllOutDataAnchors()) {
GE_ASSERT_SUCCESS(GetRealInAnchors(out_data_anchor, out_data_anchor, path_queue, Path()));
}
}
if (!IsAllNodeInPathsWithSameTransdata(allowed_set, path_queue)) {
GELOGI("remove path from paths_with_same_transdata, transdata node: %s",
path_iter->back().in_anchor->GetOwnerNodeBarePtr()->GetNamePtr());
path_iter = paths_with_same_transdata.erase(path_iter);
} else {
++path_iter;
}
}
return GRAPH_SUCCESS;
}
graphStatus SameTransdataBreadthFusionPass::GetCompareInfo(const Path &path, const LinkNode &link_node,
CompareInfo &info) {
const auto &node = link_node.in_anchor->GetOwnerNode();
const auto &iter = node_to_info_map_.find(node);
if (iter != node_to_info_map_.end()) {
info = iter->second;
return GRAPH_SUCCESS;
}
const auto &op_desc = node->GetOpDesc();
(void)AttrUtils::GetStr(op_desc, ATTR_NAME_STREAM_LABEL, info.stream_label);
info.in_ctrl_nodes = GetInControlIdentityNodes(path, link_node);
info.input_tensor_desc = op_desc->GetInputDescPtr(link_node.in_anchor->GetIdx());
GE_ASSERT_NOTNULL(info.input_tensor_desc);
info.output_tensor_desc = op_desc->GetOutputDescPtr(0);
GE_ASSERT_NOTNULL(info.output_tensor_desc);
node_to_info_map_[node] = info;
PrintCompareInfo(node, info);
return GRAPH_SUCCESS;
}
graphStatus SameTransdataBreadthFusionPass::UpdateTensorDesc(const Paths &paths_group,
size_t keep_transdata_path_index) {
const auto trans = paths_group[keep_transdata_path_index].back().in_anchor->GetOwnerNodeBarePtr();
const auto head_out_anchor = paths_group[keep_transdata_path_index].front().real_peer_out_anchor;
GE_ASSERT_SUCCESS(UpdateTransdataDtype(trans, head_out_anchor));
const auto trans_out_tensor_desc = trans->GetOpDescBarePtr()->GetOutputDesc(0U);
for (const auto &path : paths_group) {
for (size_t i = 0U; i < path.size() - 1U; ++i) {
const auto &link_node = path[i];
const auto node = link_node.in_anchor->GetOwnerNodeBarePtr();
if (link_node.in_anchor->GetPeerOutAnchor() != link_node.real_peer_out_anchor) {
GE_ASSERT_SUCCESS(UpdateTensorDescForDiffGraph(trans_out_tensor_desc, link_node));
}
auto node_in_tensor_desc = node->GetOpDesc()->MutableInputDesc(0U);
auto node_out_tensor_desc = node->GetOpDesc()->MutableOutputDesc(0U);
GE_ASSERT_SUCCESS(CopyTensorDesc(trans_out_tensor_desc, node_in_tensor_desc));
GE_ASSERT_SUCCESS(CopyTensorDesc(trans_out_tensor_desc, node_out_tensor_desc));
}
const auto &link_node = path.back();
if (link_node.in_anchor->GetPeerOutAnchor() != link_node.real_peer_out_anchor) {
GE_ASSERT_SUCCESS(UpdateTensorDescForDiffGraph(trans_out_tensor_desc, link_node));
}
}
return GRAPH_SUCCESS;
}
graphStatus SameTransdataBreadthFusionPass::UpdateTensorDescForDiffGraph(
const GeTensorDesc &trans_out_tensor_desc, const LinkNode &link_node) {
std::stack<LinkNode> link_node_stack;
link_node_stack.push(link_node);
while (!link_node_stack.empty()) {
const auto cur_link_node = link_node_stack.top();
link_node_stack.pop();
const auto &pre_node = cur_link_node.in_anchor->GetPeerOutAnchor()->GetOwnerNode();
GE_ASSERT_NOTNULL(pre_node->GetOpDescBarePtr());
const auto pre_node_type = GetNodeType(pre_node);
GE_ASSERT_TRUE((pre_node_type == NodeType::kData) || (pre_node_type == NodeType::kWrapperNode),
"current node %s must connect to Data or Wrapper, but pre_node %s node type is %u",
cur_link_node.in_anchor->GetOwnerNodeBarePtr()->GetNamePtr(), pre_node->GetNamePtr(),
static_cast<uint32_t>(pre_node_type));
if (pre_node_type == NodeType::kData) {
GE_ASSERT_SUCCESS(UpdateTensorDescForConnectData(trans_out_tensor_desc, cur_link_node, link_node_stack));
} else {
GE_ASSERT_SUCCESS(UpdateTensorDescForConnectWrapper(trans_out_tensor_desc, cur_link_node, link_node_stack));
}
}
return GRAPH_SUCCESS;
}
graphStatus SameTransdataBreadthFusionPass::UpdateTensorDescForConnectData(
const GeTensorDesc &trans_out_tensor_desc, const LinkNode &link_node, std::stack<LinkNode> &link_node_stack) const {
const auto &owner_node = link_node.in_anchor->GetOwnerNode();
const auto owner_graph = owner_node->GetOwnerComputeGraphBarePtr();
GE_ASSERT_NOTNULL(owner_graph->GetParentNode());
auto out_anchor = link_node.in_anchor->GetPeerOutAnchor();
GE_ASSERT_NOTNULL(out_anchor);
const auto data_op_desc = out_anchor->GetOwnerNode()->GetOpDesc();
GE_ASSERT_NOTNULL(data_op_desc);
auto data_in_tensor_desc = data_op_desc->MutableInputDesc(0U);
auto data_out_tensor_desc = data_op_desc->MutableOutputDesc(0U);
GE_ASSERT_SUCCESS(CopyTensorDesc(trans_out_tensor_desc, data_in_tensor_desc));
GE_ASSERT_SUCCESS(CopyTensorDesc(trans_out_tensor_desc, data_out_tensor_desc));
uint32_t parent_index;
GE_ASSERT_TRUE(AttrUtils::GetInt(data_op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index));
const auto wrapper_op_desc = owner_graph->GetParentNode()->GetOpDesc();
GE_ASSERT_NOTNULL(wrapper_op_desc);
auto wrapper_in_tensor_desc = wrapper_op_desc->MutableInputDesc(parent_index);
GE_ASSERT_NOTNULL(wrapper_in_tensor_desc);
GE_ASSERT_SUCCESS(CopyTensorDesc(trans_out_tensor_desc, wrapper_in_tensor_desc));
const auto wrapper_in_anchor = owner_graph->GetParentNode()->GetInDataAnchor(parent_index);
GE_ASSERT_NOTNULL(wrapper_in_anchor);
const auto peer_out_anchor = wrapper_in_anchor->GetPeerOutAnchor();
GE_ASSERT_NOTNULL(peer_out_anchor);
if (peer_out_anchor != link_node.real_peer_out_anchor) {
const LinkNode new_link_node{wrapper_in_anchor, link_node.real_peer_out_anchor, NodeType::kWrapperNode};
link_node_stack.push(new_link_node);
}
const auto real_peer_node = link_node.real_peer_out_anchor->GetOwnerNode();
GE_ASSERT_NOTNULL(real_peer_node);
GE_ASSERT_NOTNULL(real_peer_node->GetOpDesc());
const auto &real_src_tensor_desc = real_peer_node->GetOpDesc()->MutableOutputDesc(
static_cast<uint32_t>(link_node.real_peer_out_anchor->GetIdx()));
GE_ASSERT_NOTNULL(real_src_tensor_desc);
UpdateDataType(real_src_tensor_desc, wrapper_in_tensor_desc);
UpdateDataType(real_src_tensor_desc, data_in_tensor_desc);
UpdateDataType(real_src_tensor_desc, data_out_tensor_desc);
return GRAPH_SUCCESS;
}
graphStatus SameTransdataBreadthFusionPass::UpdateTensorDescForConnectWrapper(
const GeTensorDesc &trans_out_tensor_desc, const LinkNode &link_node, std::stack<LinkNode> &link_node_stack) {
GE_ASSERT_NOTNULL(link_node.in_anchor->GetPeerOutAnchor());
const auto &wrapper_node = link_node.in_anchor->GetPeerOutAnchor()->GetOwnerNode();
const auto &wrapper_op_desc = wrapper_node->GetOpDesc();
GE_ASSERT_NOTNULL(wrapper_op_desc);
const auto wrapper_node_output_index = link_node.in_anchor->GetPeerOutAnchor()->GetIdx();
auto wrapper_node_output_tensor_desc = wrapper_op_desc->MutableOutputDesc(wrapper_node_output_index);
GE_ASSERT_SUCCESS(CopyTensorDesc(trans_out_tensor_desc, wrapper_node_output_tensor_desc));
const auto &real_in_tensor_desc = link_node.in_anchor->GetOwnerNode()->GetOpDesc()->MutableInputDesc(
static_cast<uint32_t>(link_node.in_anchor->GetIdx()));
GE_ASSERT_NOTNULL(real_in_tensor_desc);
UpdateDataType(real_in_tensor_desc, wrapper_node_output_tensor_desc);
const auto &sub_graph_names = wrapper_op_desc->GetSubgraphInstanceNames();
GE_ASSERT_TRUE(!sub_graph_names.empty());
for (const auto &sub_graph_name : sub_graph_names) {
const auto &sub_graph = root_graph_->GetSubgraph(sub_graph_name);
GE_ASSERT_NOTNULL(sub_graph);
const auto &netoutput = sub_graph->FindFirstNodeMatchType(NETOUTPUT);
GE_ASSERT_NOTNULL(netoutput);
GE_ASSERT_NOTNULL(netoutput->GetOpDesc());
const auto &netoutput_op_desc = netoutput->GetOpDesc();
const auto input_size = netoutput->GetAllInDataAnchorsSize();
uint32_t netoutput_input_index = input_size;
for (uint32_t i = 0U; i < input_size; ++i) {
const auto input_desc = netoutput_op_desc->GetInputDesc(i);
uint32_t parent_index = 0U;
if (!AttrUtils::GetInt(input_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)
|| (parent_index != static_cast<uint32_t>(wrapper_node_output_index))) {
continue;
}
netoutput_input_index = i;
}
GE_ASSERT_TRUE(netoutput_input_index != input_size,
"cannot find correspond input. netoutput: %s input_size: %u, parent_node: %s, parent output: %u",
netoutput->GetNamePtr(), input_size, wrapper_node->GetNamePtr(), wrapper_node_output_index);
auto netoutput_in_tensor_desc = netoutput_op_desc->MutableInputDesc(netoutput_input_index);
GE_ASSERT_SUCCESS(CopyTensorDesc(trans_out_tensor_desc, netoutput_in_tensor_desc));
UpdateDataType(real_in_tensor_desc, netoutput_in_tensor_desc);
const auto netoutput_in_anchor = netoutput->GetInDataAnchor(netoutput_input_index);
GE_ASSERT_NOTNULL(netoutput_in_anchor);
const auto peer_out_anchor = netoutput_in_anchor->GetPeerOutAnchor();
GE_ASSERT_NOTNULL(peer_out_anchor);
if (peer_out_anchor != link_node.real_peer_out_anchor) {
const LinkNode new_link_node{netoutput_in_anchor, link_node.real_peer_out_anchor, NodeType::kNetOutput};
link_node_stack.push(new_link_node);
}
}
return GRAPH_SUCCESS;
}
graphStatus SameTransdataBreadthFusionPass::ExtractTransdata(const Paths &paths_group,
size_t keep_transdata_path_index) const {
const auto &path = paths_group[keep_transdata_path_index];
auto trans_in_anchor = path.back().in_anchor;
GELOGI("keep transdata[%s]", trans_in_anchor->GetOwnerNode()->GetNamePtr());
const auto &head_out_anchor = path.front().real_peer_out_anchor;
if (head_out_anchor == trans_in_anchor->GetPeerOutAnchor()) {
GE_ASSERT_SUCCESS(LinkHeadToTransdata(paths_group, keep_transdata_path_index));
return GRAPH_SUCCESS;
}
auto trans = trans_in_anchor->GetOwnerNode();
auto trans_out_anchor = trans->GetOutDataAnchor(0);
GE_ASSERT_NOTNULL(trans_out_anchor);
GE_ASSERT_NOTNULL(trans_in_anchor->GetPeerOutAnchor());
auto trans_pre = trans_in_anchor->GetPeerOutAnchor()->GetOwnerNode();
GE_ASSERT_SUCCESS(MoveTransdataInCtrlEdge(trans));
GE_ASSERT_SUCCESS(GraphUtils::MoveOutCtrlEdges(trans, trans_pre));
GE_ASSERT_SUCCESS(MoveTransdataOutDataEdge(trans));
GE_ASSERT_SUCCESS(LinkHeadToTransdata(paths_group, keep_transdata_path_index));
auto head = head_out_anchor->GetOwnerNode();
if (head->GetOwnerComputeGraphBarePtr() != trans->GetOwnerComputeGraphBarePtr()) {
head->GetOwnerComputeGraphBarePtr()->AddNode(trans);
GraphUtils::RemoveJustNode(trans->GetOwnerComputeGraph(), trans);
trans->SetOwnerComputeGraph(head->GetOwnerComputeGraph());
}
return GRAPH_SUCCESS;
}
graphStatus SameTransdataBreadthFusionPass::LinkHeadToTransdata(const Paths &paths_group,
size_t keep_transdata_path_index) const {
const auto &path = paths_group[keep_transdata_path_index];
const auto head_out_anchor = path.front().real_peer_out_anchor;
const auto trans_in_anchor = path.back().in_anchor;
const auto trans = trans_in_anchor->GetOwnerNode();
auto trans_out_anchor = trans->GetOutDataAnchor(0);
GE_ASSERT_NOTNULL(trans_out_anchor);
std::set<InDataAnchorPtr> allowed_in_anchors;
for (const auto &cur_path : paths_group) {
allowed_in_anchors.emplace(cur_path.front().in_anchor);
}
for (auto &peer_in_anchor : head_out_anchor->GetPeerInDataAnchors()) {
const auto head_next = peer_in_anchor->GetOwnerNode();
GE_ASSERT_NOTNULL(head_next->GetOpDescBarePtr());
const auto head_next_type = GetNodeType(head_next);
if ((head_next_type != NodeType::kWrapperNode) && (head_next_type != NodeType::kNetOutput)) {
if (allowed_in_anchors.find(peer_in_anchor) != allowed_in_anchors.end()) {
peer_in_anchor->UnlinkAll();
if (peer_in_anchor->GetOwnerNode() != trans) {
GE_ASSERT_SUCCESS(trans_out_anchor->LinkTo(peer_in_anchor));
}
}
continue;
}
std::vector<InDataAnchorPtr> fused_anchors;
std::vector<InDataAnchorPtr> not_fused_anchors;
GE_ASSERT_SUCCESS(CollectFusedInAnchors(peer_in_anchor, allowed_in_anchors, head_next_type,
fused_anchors, not_fused_anchors));
(void) not_fused_anchors;
if (!fused_anchors.empty()) {
peer_in_anchor->UnlinkAll();
GE_ASSERT_SUCCESS(trans_out_anchor->LinkTo(peer_in_anchor));
}
}
if (!head_out_anchor->IsLinkedWith(trans_in_anchor)) {
if (trans_in_anchor->GetPeerOutAnchor() != nullptr) {
GE_ASSERT_SUCCESS(trans_in_anchor->Unlink(trans_in_anchor->GetPeerOutAnchor()));
}
GE_ASSERT_SUCCESS(head_out_anchor->LinkTo(trans_in_anchor));
}
return GRAPH_SUCCESS;
}
graphStatus SameTransdataBreadthFusionPass::CollectFusedInAnchors(
const InDataAnchorPtr &in_anchor, const std::set<InDataAnchorPtr> &allowed_in_anchors,
const NodeType head_next_type, std::vector<InDataAnchorPtr> &fused_anchors,
std::vector<InDataAnchorPtr> ¬_fused_anchors) const {
std::queue<Path> path_queue;
std::stack<OutDataAnchorPtr> out_anchor_stack;
if (head_next_type == NodeType::kWrapperNode) {
GE_ASSERT_SUCCESS(GetRealInAnchorsForWrapperNode(in_anchor, out_anchor_stack));
} else {
GE_ASSERT_SUCCESS(GetRealInAnchorsForNetOutput(nullptr, in_anchor, Path(), out_anchor_stack));
}
while (!out_anchor_stack.empty()) {
const auto cur_out_anchor = out_anchor_stack.top();
out_anchor_stack.pop();
for (const auto &cur_in_anchor : cur_out_anchor->GetPeerInDataAnchors()) {
const auto next_node = cur_in_anchor->GetOwnerNode();
if ((next_node != nullptr) && (next_node->GetOpDesc() != nullptr)) {
if (IsWrapperNode(next_node->GetOpDesc())) {
GE_ASSERT_SUCCESS(GetRealInAnchorsForWrapperNode(cur_in_anchor, out_anchor_stack));
continue;
}
if (next_node->GetType() == NETOUTPUT) {
GE_ASSERT_SUCCESS(GetRealInAnchorsForNetOutput(nullptr, cur_in_anchor, Path(), out_anchor_stack));
continue;
}
if (allowed_in_anchors.find(cur_in_anchor) == allowed_in_anchors.end()) {
not_fused_anchors.emplace_back(cur_in_anchor);
} else {
fused_anchors.emplace_back(cur_in_anchor);
}
}
}
}
return GRAPH_SUCCESS;
}
graphStatus SameTransdataBreadthFusionPass::DeleteTransdata(const Path &path) const {
auto trans_in_anchor = path.back().in_anchor;
auto trans = trans_in_anchor->GetOwnerNode();
auto trans_out_anchor = path.back().in_anchor->GetPeerOutAnchor();
GE_ASSERT_NOTNULL(trans_out_anchor);
auto trans_pre = trans_out_anchor->GetOwnerNode();
GE_ASSERT_SUCCESS(MoveTransdataInCtrlEdge(trans));
GE_ASSERT_SUCCESS(GraphUtils::MoveOutCtrlEdges(trans, trans_pre));
GE_ASSERT_SUCCESS(MoveTransdataOutDataEdge(trans));
GE_ASSERT_SUCCESS(trans_in_anchor->Unlink(trans_in_anchor->GetPeerOutAnchor()));
GE_ASSERT_SUCCESS(GraphUtils::RemoveJustNode(trans->GetOwnerComputeGraph(), trans));
return GRAPH_SUCCESS;
}
REG_PASS_OPTION("SameTransdataBreadthFusionPass").LEVELS(OoLevel::kO3);
}