* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/build/memory/checker/dependency_analyzer.h"
#include <stack>
#include <limits>
#include "graph/build/memory/block_mem_assigner.h"
#include "graph/debug/ge_attr_define.h"
#include "common/checker.h"
#include "graph/utils/op_type_utils.h"
#include "graph/optimize/mem_layout_conflict_optimize/mem_layout_conflict_util.h"
namespace ge {
namespace {
constexpr size_t kPrintMaxNameLen = 350U;
std::string kOutNodeMeaning =
"the output node refers to all nodes that use this memory as input, not only directly connected output "
"nodes, but also consider RefNode, continuous memory";
constexpr int64_t kInvalidTopoId = std::numeric_limits<int64_t>::max();
constexpr size_t kMaxLogCharNum = 1200U;
int64_t GetTopoId(const Node *node) {
return node->GetOpDescBarePtr()->GetId();
}
std::string NodeNameOut(const ge::Node *n, const uint32_t out_index) {
std::stringstream ss;
ss << n->GetName().substr(0, kPrintMaxNameLen) << "_out_" << out_index << "(" << n->GetType()
<< ", stream:" << n->GetOpDescBarePtr()->GetStreamId() << ", topo:" << GetTopoId(n) << ")";
return ss.str();
}
std::string NodeName(const ge::Node *n) {
std::stringstream ss;
ss << n->GetName().substr(0, kPrintMaxNameLen) << "(" << n->GetType()
<< ", stream:" << n->GetOpDescBarePtr()->GetStreamId() << ", topo:" << GetTopoId(n) << ")";
return ss.str();
}
std::string TopoIdTypeOutIndex(const ge::Node *n, const uint32_t out_index) {
std::stringstream ss;
int64_t offset = -1;
const auto outputs_offset = n->GetOpDescBarePtr()->GetOutputOffset();
if (out_index < outputs_offset.size()) {
offset = outputs_offset[out_index];
}
const auto output_tensor_desc = n->GetOpDescBarePtr()->MutableOutputDesc(static_cast<uint32_t>(out_index));
int64_t mem_size = -1;
if (output_tensor_desc != nullptr) {
MemReuseUtils::GetTensorSize(*output_tensor_desc, mem_size, false);
}
const auto stream = n->GetOpDescBarePtr()->GetStreamId();
ss << GetTopoId(n) << "(" << n->GetType() << ")[out:" << out_index << ", s:" << stream << ", offset:" << offset
<< ", size:" << mem_size << "]";
return ss.str();
}
std::string TopoIdTypeInIndex(const ge::Node *n, const uint32_t in_index) {
std::stringstream ss;
const auto stream = n->GetOpDescBarePtr()->GetStreamId();
ss << GetTopoId(n) << "(" << n->GetType() << ")[in:" << in_index << ", s:" << stream << "]";
return ss.str();
}
std::string TopoIdTypeIndex(const NodeIndexIO &node_index_io) {
if (node_index_io.io_type_ == kOut) {
return TopoIdTypeOutIndex(node_index_io.node_ptr_, node_index_io.index_);
} else {
return TopoIdTypeInIndex(node_index_io.node_ptr_, node_index_io.index_);
}
}
std::string NodesIdStr(const std::vector<Node *> &nodes) {
std::stringstream ss;
for (const auto node : nodes) {
ss << GetTopoId(node) << ", ";
}
return ss.str();
}
std::vector<std::string> IdsStr(const std::vector<int64_t> &ids) {
std::vector<std::string> ret;
if (ids.empty()) {
return ret;
}
std::stringstream ss;
auto start = ids[0U];
auto end = start;
for (size_t i = 1U; i < ids.size(); ++i) {
if (ids[i] == end + 1) {
end = ids[i];
} else {
if (start == end) {
ss << start;
} else {
ss << start << "-" << end;
}
ss << ", ";
end = ids[i];
start = ids[i];
if (ss.str().length() > kMaxLogCharNum) {
ret.emplace_back(ss.str());
ss.str("");
ss.clear();
}
}
}
if (start == end) {
ss << start;
} else {
ss << start << "-" << end;
}
if (!ss.str().empty()) {
ret.emplace_back(ss.str());
}
return ret;
}
bool IsNeedMergeContinuousInput(const Node *node) {
if (MemLayoutConflictUtil::IsNoPaddingContinuousInput(node)) {
return true;
}
if (MemLayoutConflictUtil::IsContinuousInput(node)) {
for (const auto in_anchor : node->GetAllInDataAnchorsPtr()) {
if (in_anchor->GetPeerOutAnchor() != nullptr) {
const auto in_node = in_anchor->GetPeerOutAnchor()->GetOwnerNodeBarePtr();
GE_ASSERT_NOTNULL(in_node);
GE_ASSERT_NOTNULL(in_node->GetOpDescBarePtr());
std::vector<int64_t> offsets_for_fusion = {};
if (AttrUtils::GetListInt(in_node->GetOpDescBarePtr(), ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION,
offsets_for_fusion)) {
GELOGI("continuous_input node %s(%s) input node %s has %s attr, need merge symbol", NodeName(node).c_str(),
node->GetTypePtr(), NodeName(in_node).c_str(), ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION.c_str());
return true;
}
}
}
}
return false;
}
}
WrapperInfo::WrapperInfo(const ComputeGraphPtr &sub_graph, const size_t bitmap_size)
: subgraph_all_nodes_bitmap(bitmap_size) {
const auto &direct_nodes_visitor = sub_graph->GetDirectNode();
std::for_each(direct_nodes_visitor.begin(), direct_nodes_visitor.end(), [this](const NodePtr &node) {
if (node->GetOpDesc() != nullptr) {
direct_nodes.emplace_back(GetTopoId(node.get()));
}
});
const auto &all_nodes_visitor = sub_graph->GetAllNodes();
std::for_each(all_nodes_visitor.begin(), all_nodes_visitor.end(), [this](const NodePtr &node) {
if (node->GetOpDesc() != nullptr) {
subgraph_all_nodes_bitmap.SetBit(GetTopoId(node.get()));
}
});
}
bool DependencyAnalyzer::CanAReuseB(const Node *const a, uint32_t a_output_index, const Node *const b,
uint32_t b_output_index) const {
if (!init_flag_) {
GELOGE(ge::FAILED, "node memory dependency analyzer is not inited.");
return false;
}
if (!CheckParam(a, a_output_index, b, b_output_index)) {
GELOGE(FAILED, "invalid param.");
return false;
}
return CanAReuseBInner(a, a_output_index, b, b_output_index);
}
bool DependencyAnalyzer::CanAReuseBInner(const Node *const a, uint32_t a_output_index, const Node *const b,
uint32_t b_output_index) const {
const auto &b_node_index_ios = GetSymbolNodeIndexIOList(b, b_output_index);
if (b_node_index_ios.empty()) {
GELOGW("b cannot find symbol. b: %s, b_out_index: %u", b->GetNamePtr(), b_output_index);
return true;
}
bool suspended = true;
for (const auto &b_out_node : b_node_index_ios) {
if ((b_out_node.io_type_ != kIn) || (b_out_node.node_ptr_ == nullptr) ||
(b_out_node.node_ptr_->GetOpDescBarePtr() == nullptr)) {
continue;
}
const auto out_node_id = GetTopoId(b_out_node.node_ptr_);
if (IsOutNodeSkip(b_out_node.node_ptr_, b)) {
continue;
}
suspended = false;
if (!CanReachAllSameBlockAnchor(a, a_output_index, out_node_id)) {
return false;
}
}
if (suspended && (!CanReachAllSameBlockAnchor(a, a_output_index, GetTopoId(b)))) {
return false;
}
return true;
}
bool DependencyAnalyzer::CanReachAllSameBlockAnchor(const Node *const a, const uint32_t a_output_index,
const uint32_t out_node_id) const {
const auto a_out_anchor = a->GetOutDataAnchor(a_output_index);
const auto &id_iter = out_data_anchor_2_id_map_.find(a_out_anchor);
if (id_iter == out_data_anchor_2_id_map_.end()) {
GELOGW("cannot find id for out anchor, node: %s, out_index: %u", a->GetNamePtr(), a_output_index);
return true;
}
if (same_block_anchor_ids_[id_iter->second] != nullptr) {
for (const auto same_out_anchor_id : *same_block_anchor_ids_[id_iter->second]) {
const auto out_anchor = id_2_out_data_anchor_table_[same_out_anchor_id];
if ((out_anchor == nullptr) || (out_anchor->GetOwnerNodeBarePtr() == nullptr) ||
(out_anchor->GetOwnerNodeBarePtr()->GetOpDescBarePtr() == nullptr)) {
return true;
}
const auto dst_id = GetTopoId(out_anchor->GetOwnerNodeBarePtr());
if (IsOutNodeSkip(out_anchor->GetOwnerNodeBarePtr(), a)) {
continue;
}
if (!reach_nodes_bit_map_table_[out_node_id].GetBit(dst_id)) {
return false;
}
}
} else {
if (!reach_nodes_bit_map_table_[out_node_id].GetBit(GetTopoId(a))) {
return false;
}
}
return true;
}
std::string DependencyAnalyzer::WhyACannotReuseB(const ge::Node *a, uint32_t a_output_index, const ge::Node *b,
uint32_t b_output_index) {
if (!CheckParam(a, a_output_index, b, b_output_index)) {
return "invalid param";
}
if (CanAReuseB(a, a_output_index, b, b_output_index)) {
return NodeNameOut(a, a_output_index) + " can reuse " + NodeNameOut(b, b_output_index);
}
ErrorLogReachNodes(b);
if (!reach_nodes_bit_map_table_[GetTopoId(b)].GetBit(GetTopoId(a))) {
return NodeName(b) + " cannot reach " + NodeName(a);
}
std::stringstream ss;
auto b_output_nodes = GetAllUseSameInBlockNodes(b, b_output_index);
ErrorLogReachNodes(b_output_nodes);
if (b_output_nodes.empty()) {
ss << NodeNameOut(b, b_output_index) << " has no output node, ";
b_output_nodes.push_back(b);
}
for (const auto out_node : b_output_nodes) {
if (IsOutNodeSkip(out_node, b)) {
continue;
}
if (!reach_nodes_bit_map_table_[GetTopoId(out_node)].GetBit(GetTopoId(a))) {
if (out_node != a) {
ss << GetTopoId(b) << "\'s output node " << NodeName(out_node) << " cannot reach " << NodeName(a) << ", "
<< kOutNodeMeaning;
} else {
ss << GetTopoId(b) << "\'s output node " << NodeName(out_node) << " cannot reuse its input, "
<< kOutNodeMeaning;
}
ErrorLogOutSymbol(b, b_output_index);
ErrorLogOutSymbol(a, a_output_index);
return ss.str();
}
}
WhyACannotReuseBInner({a, a_output_index, b, b_output_index}, b_output_nodes, ss);
return ss.str();
}
void DependencyAnalyzer::WhyACannotReuseBInner(const MemCheckParam ¶m,
const std::list<const Node *> &b_output_nodes, std::stringstream &ss) {
const auto same_out_anchors = GetAllUseSameOutBlockAnchors(param.a, param.a_output_index);
ErrorLogDependNodes(same_out_anchors);
for (const auto out_node : b_output_nodes) {
if (IsOutNodeSkip(out_node, param.b)) {
continue;
}
for (const auto &dst_anchor : same_out_anchors) {
if (IsOutNodeSkip(dst_anchor->GetOwnerNodeBarePtr(), param.a)) {
continue;
}
if (!reach_nodes_bit_map_table_[GetTopoId(out_node)].GetBit(GetTopoId(dst_anchor->GetOwnerNodeBarePtr()))) {
if (out_node != param.b) {
ss << GetTopoId(param.b) << "\'s output ";
}
ss << "node " << NodeName(out_node) << " cannot reach "
<< NodeNameOut(dst_anchor->GetOwnerNodeBarePtr(), dst_anchor->GetIdx());
if (dst_anchor->GetOwnerNodeBarePtr() != param.a) {
ss << ", which use same memory block with " << NodeNameOut(param.a, param.a_output_index);
}
if (out_node != param.b) {
ss << ", " << kOutNodeMeaning;
}
ErrorLogOutSymbol(param.b, param.b_output_index);
ErrorLogOutSymbol(param.a, param.a_output_index);
return;
}
}
}
}
void DependencyAnalyzer::Debug() {
debug_mode = true;
DeInitWrapperInfo();
(void)InitTables();
InitReachNodesMap();
}
DependencyAnalyzer::DependencyAnalyzer(ComputeGraphPtr &graph, AnchorToSymbol &anchor_to_symbol,
SymbolToAnchors &symbol_to_anchors)
: compute_graph_(graph), anchor_to_symbol_(anchor_to_symbol), symbol_to_anchors_(symbol_to_anchors) {}
Status DependencyAnalyzer::Init() {
const auto start_time = GetCurrentTimestamp();
GELOGI("init start.");
root_graph_ = compute_graph_;
while (root_graph_->GetParentGraph() != root_graph_) {
root_graph_ = root_graph_->GetParentGraph();
if (root_graph_ == nullptr) {
root_graph_ = compute_graph_;
break;
}
}
if (InitTables() != SUCCESS) {
GELOGW("init table failed. skip memory reuse check");
return FAILED;
}
ExtendSymbolMapByContinuousMemory();
InitSameOutBlockNodesMap();
InitReachNodesMap();
GELOGI("init finish.");
const auto end_time = GetCurrentTimestamp();
GELOGI("cost time: %llu microseconds, graph: %s has %zu nodes and %zu out data anchors",
end_time - start_time, compute_graph_->GetName().c_str(), id_max_, id_2_out_data_anchor_table_.size());
return SUCCESS;
}
Status DependencyAnalyzer::InitTables() {
id_2_node_table_.clear();
graph_all_nodes_.clear();
const auto &all_nodes = compute_graph_->GetAllNodesPtr();
id_max_ = all_nodes.size();
id_2_node_table_.resize(id_max_, nullptr);
for (const auto node : all_nodes) {
graph_all_nodes_.emplace_back(node);
if (node->GetOpDescBarePtr() == nullptr) {
GELOGW("op_desc is nullptr, node: %s", node->GetNamePtr());
return FAILED;
}
const auto id = static_cast<size_t>(GetTopoId(node));
if (id >= id_max_) {
GELOGW("node topo id: %" PRId64 " >= id_max_: %zu", id, id_max_);
return FAILED;
}
if (id_2_node_table_[id] != nullptr) {
GELOGW("topo id duplicated, id: %" PRId64 ", origin node: %s, current node: %s", id,
id_2_node_table_[id]->GetNamePtr(), node->GetNamePtr());
return FAILED;
}
id_2_node_table_[GetTopoId(node)] = node;
}
InitOutDataAnchorId(graph_all_nodes_);
InitPreAndNextNodeOnSameStreamTable();
InitNodeTypeTable();
InitWrapperInfo();
init_flag_ = true;
GELOGI("init tables finish.");
return SUCCESS;
}
void DependencyAnalyzer::InitOutDataAnchorId(const std::vector<const Node *> &all_nodes) {
GELOGI("init out data anchor id begin.");
id_2_out_data_anchor_table_.clear();
out_data_anchor_2_id_map_.clear();
size_t num = 0U;
for (const auto &node : all_nodes) {
num += node->GetAllOutDataAnchorsSize();
}
id_2_out_data_anchor_table_.resize(num);
out_data_anchor_2_id_map_.reserve(num);
same_block_anchor_ids_.resize(num, nullptr);
size_t id = 0U;
for (const auto &node : all_nodes) {
for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
id_2_out_data_anchor_table_[id] = out_data_anchor;
out_data_anchor_2_id_map_[out_data_anchor] = id;
++id;
}
}
GELOGI("init out data anchor id finish.");
}
void DependencyAnalyzer::ExtendSymbolMapByContinuousMemory() {
for (const auto &cur_node : graph_all_nodes_) {
if (MemLayoutConflictUtil::IsNoPaddingContinuousOutput(cur_node)) {
OutContinuousMergeSymbol(cur_node);
}
if (IsNeedMergeContinuousInput(cur_node)) {
InContinuousMergeSymbol(cur_node);
}
}
GELOGI("extend symbols finish.");
}
const std::list<NodeIndexIO> &DependencyAnalyzer::GetSymbolNodeIndexIOList(const Node *const node,
const uint32_t output_index) const {
static std::list<NodeIndexIO> dummy_ret;
const auto cur_anchor = NodeIndexIO(node, output_index, kOut);
const auto &symbol_iter = anchor_to_symbol_.find(cur_anchor.ToString());
if (symbol_iter != anchor_to_symbol_.end()) {
const auto &anchor_iter = symbol_to_anchors_.find(symbol_iter->second);
if (anchor_iter != symbol_to_anchors_.end()) {
return anchor_iter->second;
}
}
GELOGW("cannot find symbol for node %s, out_index: %u", node->GetNamePtr(), output_index);
return dummy_ret;
}
std::list<const Node *> DependencyAnalyzer::GetAllUseSameInBlockNodes(const Node *node, uint32_t out_index) {
return GetAllUseSameBlockNodesByType(node, out_index, kIn);
}
std::list<OutDataAnchorPtr> DependencyAnalyzer::GetAllUseSameOutBlockAnchors(const Node *node, uint32_t out_index) {
const auto cur_anchor = NodeIndexIO(node, out_index, kOut);
const auto &symbol = anchor_to_symbol_[cur_anchor.ToString()];
std::list<OutDataAnchorPtr> use_same_block_anchors;
std::for_each(symbol_to_anchors_[symbol].cbegin(), symbol_to_anchors_[symbol].cend(),
[&use_same_block_anchors](const NodeIndexIO &anchor) {
if ((anchor.io_type_ == kOut) && (anchor.node_ != nullptr)) {
use_same_block_anchors.emplace_back(anchor.node_->GetOutDataAnchor(anchor.index_));
}
});
return use_same_block_anchors;
}
std::list<const Node *> DependencyAnalyzer::GetAllUseSameBlockNodesByType(const Node *node, uint32_t out_index,
IOType type) {
const auto cur_anchor = NodeIndexIO(node, out_index, kOut);
const auto &symbol = anchor_to_symbol_[cur_anchor.ToString()];
std::list<const Node *> use_same_block_nodes;
std::for_each(symbol_to_anchors_[symbol].begin(), symbol_to_anchors_[symbol].end(),
[&use_same_block_nodes, type](const NodeIndexIO &anchor) {
if ((anchor.io_type_ == type) && (anchor.node_ != nullptr)) {
use_same_block_nodes.emplace_back(anchor.node_.get());
}
});
return use_same_block_nodes;
}
void DependencyAnalyzer::InContinuousMergeSymbol(const Node *cur_node) {
std::vector<OutDataAnchorPtr> out_anchors;
for (const auto &in_anchor : cur_node->GetAllInDataAnchors()) {
if ((in_anchor != nullptr) && (in_anchor->GetPeerOutAnchor() != nullptr)) {
out_anchors.emplace_back(in_anchor->GetPeerOutAnchor());
}
}
if (!out_anchors.empty()) {
MergeOutAnchorSymbol(out_anchors);
}
}
void DependencyAnalyzer::OutContinuousMergeSymbol(const Node *cur_node) {
std::vector<OutDataAnchorPtr> out_anchors;
for (const auto &out_anchor : cur_node->GetAllOutDataAnchors()) {
if (out_anchor != nullptr) {
out_anchors.emplace_back(out_anchor);
}
}
if (!out_anchors.empty()) {
MergeOutAnchorSymbol(out_anchors);
}
}
void DependencyAnalyzer::MergeOutAnchorSymbol(const std::vector<OutDataAnchorPtr> &out_anchors) {
std::vector<NodeIndexIO> anchors_vec;
std::unordered_set<std::string> symbols_set;
for (const auto &out_anchor : out_anchors) {
const auto out_node_index_io = NodeIndexIO(out_anchor->GetOwnerNode(), out_anchor->GetIdx(), kOut);
if (anchor_to_symbol_.find(out_node_index_io.ToString()) == anchor_to_symbol_.end()) {
GELOGW("%s is not in anchor_to_symbol", out_node_index_io.ToString().c_str());
continue;
}
const auto &symbol = anchor_to_symbol_[out_node_index_io.ToString()];
if (!symbols_set.insert(symbol).second) {
continue;
}
for (const auto &node_index_io : symbol_to_anchors_[symbol]) {
anchors_vec.emplace_back(node_index_io);
}
}
MergeSymbols(anchors_vec, symbols_set);
}
void DependencyAnalyzer::MergeSymbols(const std::vector<NodeIndexIO> &anchors_vec,
const std::unordered_set<std::string> &symbols_set) {
if (symbols_set.size() <= 1U) {
return;
}
const std::string &the_one_symbol = *symbols_set.begin();
symbol_to_anchors_[the_one_symbol].clear();
for (const auto &anchor : anchors_vec) {
anchor_to_symbol_[anchor.ToString()] = the_one_symbol;
symbol_to_anchors_[the_one_symbol].emplace_back(anchor);
}
auto iter = symbols_set.begin();
++iter;
for (; iter != symbols_set.end(); ++iter) {
symbol_to_anchors_.erase(*iter);
}
}
void DependencyAnalyzer::InitSameOutBlockNodesMap() {
std::set<OutDataAnchorPtr> visit_set;
for (const auto node : graph_all_nodes_) {
for (uint32_t i = 0U; i < node->GetAllOutDataAnchorsSize(); ++i) {
if (!visit_set.insert(node->GetOutDataAnchor(i)).second) {
continue;
}
const auto anchors_with_same_block = GetAllUseSameOutBlockAnchors(node, i);
if (anchors_with_same_block.size() <= 1U) {
continue;
}
std::list<size_t> ids;
for (const auto &anchor : anchors_with_same_block) {
visit_set.insert(anchor);
const auto iter = out_data_anchor_2_id_map_.find(anchor);
if (iter != out_data_anchor_2_id_map_.end()) {
ids.emplace_back(iter->second);
}
}
same_block_anchor_id_holder_.emplace_back(std::move(ids));
for (const auto id : same_block_anchor_id_holder_.back()) {
same_block_anchor_ids_[id] = &same_block_anchor_id_holder_.back();
}
}
}
GELOGI("init same output block anchor map finish, size: %zu.", same_block_anchor_id_holder_.size());
}
void DependencyAnalyzer::InitPreAndNextNodeOnSameStreamTable() {
pre_node_on_same_stream_table_.clear();
next_node_on_same_stream_table_.clear();
pre_node_on_same_stream_table_.resize(id_max_, kInvalidTopoId);
next_node_on_same_stream_table_.resize(id_max_, kInvalidTopoId);
std::map<int64_t, int64_t> last_node_map;
std::map<int64_t, std::vector<int64_t>> stream_2_ids;
for (const auto node : id_2_node_table_) {
const auto id = GetTopoId(node);
const auto stream_id = node->GetOpDescBarePtr()->GetStreamId();
if (debug_mode) {
stream_2_ids[stream_id].push_back(id);
}
if (stream_id == kInvalidStreamId) {
continue;
}
auto iter = last_node_map.find(stream_id);
if (iter == last_node_map.end()) {
last_node_map[stream_id] = id;
continue;
}
pre_node_on_same_stream_table_[id] = iter->second;
next_node_on_same_stream_table_[iter->second] = id;
iter->second = id;
}
if (debug_mode) {
for (const auto &ids : stream_2_ids) {
const auto ids_strs = IdsStr(ids.second);
for (const auto &str : ids_strs) {
GELOGI("stream: %" PRId64 ", node topo id: [%s]", ids.first, str.c_str());
}
}
}
GELOGI("init pre/next node on same stream table finish, stream size: %u.", last_node_map.size());
}
void DependencyAnalyzer::InitNodeTypeTable() {
is_wrapper_.clear();
wrapper_ids_.clear();
is_wrapper_.resize(id_max_, false);
for (const auto node : graph_all_nodes_) {
const auto id = GetTopoId(node);
if (!node->GetOpDescBarePtr()->GetSubgraphInstanceNames().empty()) {
is_wrapper_[id] = true;
wrapper_ids_.emplace_back(id);
}
}
}
void DependencyAnalyzer::DeInitWrapperInfo() {
wrapper_info_holder_.clear();
parent_2_childs_map_.clear();
parents_on_root_graph_.clear();
}
void DependencyAnalyzer::InitWrapperInfo() {
for (const auto parent_node : graph_all_nodes_) {
if ((parent_node->GetOpDescBarePtr() == nullptr) || (parent_node->GetOwnerComputeGraphBarePtr() == nullptr) ||
parent_node->GetOpDescBarePtr()->GetSubgraphInstanceNames().empty()) {
continue;
}
if (debug_mode) {
GELOGI("parent_node: %s, topo_id: %" PRId64 "", parent_node->GetNamePtr(), GetTopoId(parent_node));
}
for (const auto &sub_graph_name : parent_node->GetOpDescBarePtr()->GetSubgraphInstanceNames()) {
const auto &sub_graph = root_graph_->GetSubgraph(sub_graph_name);
if (sub_graph != nullptr) {
wrapper_info_holder_[parent_node].emplace_back(WrapperInfo(sub_graph, id_max_));
}
}
const Node *grand_parent_node = parent_node->GetOwnerComputeGraphBarePtr()->GetParentNodeBarePtr();
if ((grand_parent_node != nullptr) && (parent_node->GetOwnerComputeGraphBarePtr() != compute_graph_.get())) {
parent_2_childs_map_[grand_parent_node].emplace_back(parent_node);
} else {
parents_on_root_graph_.insert(parent_node);
}
}
if (debug_mode) {
DebugWrapperInfo();
}
}
void DependencyAnalyzer::InitReachNodesMap() {
reach_nodes_bit_map_table_.clear();
reach_nodes_bit_map_table_.resize(id_max_, LargeBitmap(id_max_));
std::stack<const Node *> node_stack;
vector_bit_t visited_flag(id_max_, false);
vector_bit_t searched_flag(id_max_, false);
for (const auto node : graph_all_nodes_) {
if (node->GetType() == NETOUTPUT) {
node_stack.push(node);
if (debug_mode) {
GELOGI("Netoutput nodes: %" PRId64 ", stream_id: %" PRId64 "", GetTopoId(node),
node->GetOpDescBarePtr()->GetStreamId());
}
}
}
if (node_stack.empty() && !graph_all_nodes_.empty()) {
GELOGW("graph %s has no NETOUTPUT node, use %s as last node", compute_graph_->GetName().c_str(),
graph_all_nodes_.back()->GetNamePtr());
node_stack.push(graph_all_nodes_.back());
if (debug_mode) {
GELOGI("Netoutput nodes: %" PRId64 ", stream_id: %" PRId64 "", GetTopoId(graph_all_nodes_.back()),
graph_all_nodes_.back()->GetOpDescBarePtr()->GetStreamId());
}
}
while (!node_stack.empty()) {
const auto cur_node = node_stack.top();
node_stack.pop();
const auto cur_node_topo_id = GetTopoId(cur_node);
if (!visited_flag[cur_node_topo_id]) {
visited_flag[cur_node_topo_id] = true;
InitReachNodesMap(cur_node, searched_flag);
searched_flag[cur_node_topo_id] = true;
PushPreNode(cur_node, node_stack);
}
if (node_stack.empty()) {
PushNotVisited(visited_flag, node_stack);
}
}
ExtendReachNodesMapCrossSubGraph();
GELOGI("init reach nodes map finish.");
}
void DependencyAnalyzer::InitReachNodesMap(const Node *root_node, const vector_bit_t &searched_flag) {
std::stack<const Node *> nodes_stack;
vector_bit_t visited_flag(id_max_, false);
PushNextNode(root_node, nodes_stack);
const auto root_node_topo_id = GetTopoId(root_node);
while (!nodes_stack.empty()) {
const auto cur_node = nodes_stack.top();
const auto cur_node_topo_id = GetTopoId(cur_node);
nodes_stack.pop();
if (visited_flag[cur_node_topo_id]) {
continue;
}
visited_flag[cur_node_topo_id] = true;
reach_nodes_bit_map_table_[root_node_topo_id].SetBit(cur_node_topo_id);
if (searched_flag[cur_node_topo_id]) {
reach_nodes_bit_map_table_[root_node_topo_id].Or(reach_nodes_bit_map_table_[cur_node_topo_id]);
} else {
PushNextNode(cur_node, nodes_stack);
}
}
}
* 1 子图内所有节点都能到达父节点(除去Netoutput和与父节点同符号的节点),但反过来不行。
* 2 所有到达父节点的,都能到达该父节点对应子图内所有节点。
* 3 设置reuable map时,获取输出节点应该刨除父节点。
* 刨除后如果输出节点为空,表示该节点的输出没人用,能到达的节点都可以复用该节点。
*/
void DependencyAnalyzer::ExtendReachNodesMapCrossSubGraph() {
WrapperNotReachSubGraphNodes();
std::stack<const Node *> nodes_stack;
for (const auto &node : parents_on_root_graph_) {
nodes_stack.push(node);
}
while (!nodes_stack.empty()) {
const auto parent_node = nodes_stack.top();
const auto parent_id = GetTopoId(parent_node);
nodes_stack.pop();
for (const auto &child : parent_2_childs_map_[parent_node]) {
nodes_stack.push(child);
}
const auto owner_graph = parent_node->GetOwnerComputeGraphBarePtr();
for (const auto &node : owner_graph->GetDirectNode()) {
if (!reach_nodes_bit_map_table_[GetTopoId(node.get())].GetBit(parent_id)) {
continue;
}
for (const auto &wrapper_info : wrapper_info_holder_[parent_node]) {
reach_nodes_bit_map_table_[GetTopoId(node.get())].Or(wrapper_info.subgraph_all_nodes_bitmap);
if (debug_mode) {
GELOGI("%" PRId64 " reach %" PRId64 "(%s), extend to reach[%" PRId64 "-%" PRId64 "]", GetTopoId(node.get()),
parent_id, parent_node->GetTypePtr(), wrapper_info.direct_nodes.front(),
wrapper_info.direct_nodes.back());
}
}
}
}
for (const auto &node : parents_on_root_graph_) {
nodes_stack.push(node);
}
while (!nodes_stack.empty()) {
const auto parent_node = nodes_stack.top();
nodes_stack.pop();
for (const auto &child : parent_2_childs_map_[parent_node]) {
nodes_stack.push(child);
}
const auto same_out_nodes_with_wrapper_set = GetAllNodeIdsSameSymbolWithWrapper(parent_node);
const auto parent_id = GetTopoId(parent_node);
const auto &parent_reach_map = reach_nodes_bit_map_table_[parent_id];
for (const auto &wrapper_info : wrapper_info_holder_[parent_node]) {
for (const auto direct_node_id : wrapper_info.direct_nodes) {
reach_nodes_bit_map_table_[direct_node_id].Or(parent_reach_map);
if (same_out_nodes_with_wrapper_set.find(direct_node_id) == same_out_nodes_with_wrapper_set.cend()) {
reach_nodes_bit_map_table_[direct_node_id].SetBit(parent_id);
}
}
}
}
}
* 父节点拓扑序比子图节点的小,如果分配在同一个stream上,那么会存在父节点能够达到子图内节点的情况。
* 由于父节点和子图内与netoutput相连的节点同一块内存,所以只允许子图内节点到达父节点。
*/
void DependencyAnalyzer::WrapperNotReachSubGraphNodes() {
for (const auto &wrapper_pair : wrapper_info_holder_) {
const auto &parent_node = wrapper_pair.first;
const auto parent_id = GetTopoId(parent_node);
for (const auto &subgraph_name : parent_node->GetOpDescBarePtr()->GetSubgraphInstanceNames()) {
const auto &sub_graph = root_graph_->GetSubgraph(subgraph_name);
if (sub_graph != nullptr) {
for (const auto &node : sub_graph->GetAllNodes()) {
reach_nodes_bit_map_table_[parent_id].ClearBit(GetTopoId(node.get()));
}
}
}
}
}
std::unordered_set<int64_t> DependencyAnalyzer::GetAllNodeIdsSameSymbolWithWrapper(const Node *node) {
std::unordered_set<int64_t> ret;
for (uint32_t index = 0U; index < node->GetAllOutDataAnchorsSize(); ++index) {
const auto cur_anchor = NodeIndexIO(node, index, kOut);
const auto &symbol = anchor_to_symbol_[cur_anchor.ToString()];
std::for_each(symbol_to_anchors_[symbol].begin(), symbol_to_anchors_[symbol].end(),
[&ret](const NodeIndexIO &anchor) {
if ((anchor.node_ != nullptr) && (anchor.node_->GetOpDesc() != nullptr)) {
ret.insert(GetTopoId(anchor.node_.get()));
}
});
}
return ret;
}
void DependencyAnalyzer::PushPreNode(const Node *node, std::stack<const Node *> &node_stack) const {
for (auto &pre_node : node->GetInNodesPtr()) {
node_stack.push(pre_node);
}
if (debug_mode) {
GELOGI("%" PRId64 " input nodes: [%s]", GetTopoId(node), NodesIdStr(node->GetInNodesPtr()).c_str());
}
const auto cur_node_topo_id = GetTopoId(node);
auto pre_node_on_same_stream = GetPreNodeOnSameStream(cur_node_topo_id);
if (pre_node_on_same_stream != nullptr) {
node_stack.push(pre_node_on_same_stream);
if (debug_mode) {
GELOGI("%" PRId64 " prev node %" PRId64 " on stream %" PRId64 " ", GetTopoId(node),
GetTopoId(pre_node_on_same_stream), node->GetOpDescBarePtr()->GetStreamId());
}
}
}
void DependencyAnalyzer::PushNextNode(const Node *node, std::stack<const Node *> &nodes_stack) const {
for (const auto &next_node : node->GetOutNodesPtr()) {
nodes_stack.push(next_node);
}
if (debug_mode) {
GELOGI("%" PRId64 " out nodes: [%s]", GetTopoId(node), NodesIdStr(node->GetOutNodesPtr()).c_str());
}
auto next_node_on_same_stream = GetNextNodeOnSameStream(GetTopoId(node));
if (next_node_on_same_stream != nullptr) {
nodes_stack.push(next_node_on_same_stream);
if (debug_mode) {
GELOGI("%" PRId64 " next node %" PRId64 " on stream %" PRId64 "", GetTopoId(node),
GetTopoId(next_node_on_same_stream), node->GetOpDescBarePtr()->GetStreamId());
}
}
}
void DependencyAnalyzer::PushNotVisited(const vector_bit_t &visited_flag, std::stack<const Node *> &nodes_stack) const {
if (id_max_ > static_cast<size_t>(std::numeric_limits<int64_t>::max())) {
return;
}
for (auto i = static_cast<int64_t>(id_max_) - 1; i >= 0; --i) {
if ((!visited_flag[i]) && (id_2_node_table_[i] != nullptr)) {
nodes_stack.push(id_2_node_table_[i]);
if (debug_mode) {
GELOGI("push not visited node: %" PRId64 "", i);
}
break;
}
}
}
const Node *DependencyAnalyzer::GetPreNodeOnSameStream(const int64_t topo_id) const {
auto pre_id = pre_node_on_same_stream_table_[topo_id];
if (static_cast<size_t>(pre_id) >= id_2_node_table_.size()) {
return nullptr;
}
return id_2_node_table_[pre_id];
}
const Node *DependencyAnalyzer::GetNextNodeOnSameStream(const int64_t topo_id) const {
auto next_id = next_node_on_same_stream_table_[topo_id];
if (static_cast<size_t>(next_id) >= id_2_node_table_.size()) {
return nullptr;
}
return id_2_node_table_[next_id];
}
void DependencyAnalyzer::ErrorLogDependNodes(const ge::Node *node) const {
const auto node_id = GetTopoId(node);
bool continue_id_flag = false;
size_t start = std::numeric_limits<size_t>::max();
std::stringstream ss;
if (reach_nodes_bit_map_table_[0U].GetBit(node_id)) {
continue_id_flag = true;
start = 0U;
}
for (size_t i = 1U; i < id_max_; ++i) {
if (reach_nodes_bit_map_table_[i].GetBit(node_id)) {
if (!continue_id_flag) {
start = i;
continue_id_flag = true;
}
} else if (continue_id_flag) {
if ((i - 1U) == start) {
ss << start << ", ";
} else {
ss << start << "-" << (i - 1U) << ", ";
}
start = std::numeric_limits<size_t>::max();
continue_id_flag = false;
}
}
if (continue_id_flag) {
if (start == id_max_) {
ss << start;
} else {
ss << start << "-" << id_max_;
}
}
GELOGE(ge::FAILED, "%zu depends nodes[%s].", node_id, ss.str().c_str());
}
void DependencyAnalyzer::ErrorLogDependNodes(const std::list<OutDataAnchorPtr> &out_anchors) const {
std::set<OutDataAnchor *> visited_set;
for (auto out_anchor : out_anchors) {
if (visited_set.insert(out_anchor.get()).second) {
ErrorLogDependNodes(out_anchor->GetOwnerNodeBarePtr());
}
}
}
void DependencyAnalyzer::ErrorLogReachNodes(const ge::Node *node) const {
const auto &bit_map = reach_nodes_bit_map_table_[GetTopoId(node)];
bool continue_id = false;
size_t start = std::numeric_limits<size_t>::max();
std::stringstream ss;
for (size_t i = 1U; i < id_max_; ++i) {
if (bit_map.GetBit(i)) {
if (!continue_id) {
start = i;
continue_id = true;
}
} else if (continue_id) {
if (start == (i - 1U)) {
ss << start << ", ";
} else {
ss << start << "-" << (i - 1U) << ", ";
}
start = std::numeric_limits<size_t>::max();
continue_id = false;
}
}
if (continue_id) {
if (start == id_max_) {
ss << start;
} else {
ss << start << "-" << id_max_;
}
}
GELOGE(ge::FAILED, "%zu can reach [%s].", GetTopoId(node), ss.str().c_str());
}
void DependencyAnalyzer::ErrorLogReachNodes(const std::list<const Node *> &nodes) const {
std::set<const Node *> visited_set;
for (const auto node : nodes) {
if (visited_set.insert(node).second) {
ErrorLogReachNodes(node);
}
}
}
void DependencyAnalyzer::ErrorLogOutSymbol(const Node *node, uint32_t out_index) const {
NodeIndexIO node_index_io(node, out_index, kOut);
ErrorLogSymbol(node_index_io);
}
void DependencyAnalyzer::ErrorLogSymbol(const NodeIndexIO &node_index_io) const {
const auto symbol_iter = anchor_to_symbol_.find(node_index_io.ToString());
if (symbol_iter == anchor_to_symbol_.end()) {
return;
}
const auto anchors_iter = symbol_to_anchors_.find(symbol_iter->second);
if (anchors_iter == symbol_to_anchors_.end()) {
return;
}
GELOGE(FAILED, "print anchors in same symbol with %s, start-----------------------------",
TopoIdTypeIndex(node_index_io).c_str());
for (const auto &anchor : anchors_iter->second) {
if (anchor.node_ptr_ == nullptr) {
continue;
}
if (anchor.io_type_ == kIn) {
const auto in_anchor = anchor.node_ptr_->GetInDataAnchor(anchor.index_);
if (in_anchor == nullptr) {
continue;
}
const auto peer_out_anchor = in_anchor->GetPeerOutAnchor().get();
if (peer_out_anchor != nullptr) {
GELOGE(FAILED, "%s --> %s",
TopoIdTypeOutIndex(peer_out_anchor->GetOwnerNodeBarePtr(), peer_out_anchor->GetIdx()).c_str(),
TopoIdTypeIndex(anchor).c_str());
} else {
GELOGE(FAILED, "suspended in data anchor: %s", TopoIdTypeIndex(anchor).c_str());
}
} else {
const auto out_anchor = anchor.node_ptr_->GetOutDataAnchor(anchor.index_);
if ((out_anchor != nullptr) && (out_anchor->GetPeerInDataAnchorsPtr().empty())) {
GELOGE(FAILED, "suspended out data anchor: %s", TopoIdTypeIndex(anchor).c_str());
}
}
}
GELOGE(FAILED, "print anchors in same symbol with %s, end-------------------------------",
TopoIdTypeIndex(node_index_io).c_str());
}
bool DependencyAnalyzer::CheckParam(const ge::Node *a, uint32_t a_output_index, const ge::Node *b,
uint32_t b_output_index) const {
if ((a == nullptr) || (b == nullptr) || (a->GetOutDataAnchor(a_output_index) == nullptr) ||
(b->GetOutDataAnchor(b_output_index) == nullptr) || (a->GetOpDescBarePtr() == nullptr) ||
(b->GetOpDescBarePtr() == nullptr)) {
GELOGE(FAILED, "invalid param.");
return false;
}
if ((static_cast<size_t>(GetTopoId(a)) >= id_max_) || (static_cast<size_t>(GetTopoId(b)) >= id_max_)) {
GELOGE(FAILED, "invalid param.");
return false;
}
return true;
}
bool DependencyAnalyzer::IsOutNodeSkip(const Node *const node, const Node *const origin_node) const {
const auto node_id = GetTopoId(node);
if ((static_cast<size_t>(node_id) < id_max_) && is_wrapper_[node_id]) {
return true;
}
if (OpTypeUtils::IsDataNode(node->GetType())) {
return true;
}
if (node->GetOpDescBarePtr()->GetStreamId() == ge::kInvalidStreamId) {
return true;
}
std::string batch_label;
(void)ge::AttrUtils::GetStr(node->GetOpDescBarePtr(), ATTR_NAME_BATCH_LABEL, batch_label);
std::string origin_node_batch_label;
(void)ge::AttrUtils::GetStr(origin_node->GetOpDescBarePtr(), ATTR_NAME_BATCH_LABEL, origin_node_batch_label);
if (batch_label != origin_node_batch_label) {
return true;
}
return false;
}
void DependencyAnalyzer::DebugWrapperInfo() const {
std::stringstream ss;
for (const auto root_parent : parents_on_root_graph_) {
ss << GetTopoId(root_parent) << "(" << root_parent->GetType() << ")"
<< ", ";
}
GELOGI("parent nodes on root graph: [%s]", ss.str().c_str());
ss.str("");
ss.clear();
for (const auto &parent_to_childs : parent_2_childs_map_) {
for (const auto child : parent_to_childs.second) {
ss << GetTopoId(child) << "(" << child->GetType() << ")"
<< ", ";
}
GELOGI("parent node: %" PRId64 "(type: %s), children: [%s]", GetTopoId(parent_to_childs.first),
parent_to_childs.first->GetTypePtr(), ss.str().c_str());
ss.str("");
ss.clear();
}
ss.str("");
ss.clear();
for (const auto &parent_2_childs : wrapper_info_holder_) {
size_t sub_graph_num = 1U;
for (const auto &child : parent_2_childs.second) {
if (!child.direct_nodes.empty()) {
GELOGI("parent node: %" PRId64 "(type: %s), subgraph:%zu, children: [%" PRId64 "-%" PRId64 "]",
GetTopoId(parent_2_childs.first), parent_2_childs.first->GetTypePtr(), sub_graph_num,
child.direct_nodes.front(), child.direct_nodes.back());
}
++sub_graph_num;
}
}
}
}