* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "bg_condition.h"
#include "common/checker.h"
#include "common/debug/ge_log.h"
#include "common/ge_inner_attrs.h"
#include "common/util/mem_utils.h"
#include "core/builder/node_types.h"
#include "exe_graph/lowering/exe_graph_attrs.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "graph/fast_graph/execute_graph.h"
#include "exe_graph/lowering/value_holder_utils.h"
#include "graph/utils/execute_graph_utils.h"
namespace gert {
namespace bg {
namespace {
struct GuarderInfo {
bool has_guarder;
ge::FastNode *in_graph_guarder;
std::string guarder_node_type;
};
ValueHolderPtr CreateInnerData(int32_t index) {
auto data = ValueHolder::CreateSingleDataOutput("InnerData", {});
GE_ASSERT_NOTNULL(data);
auto desc = ValueHolderUtils::GetNodeOpDescBarePtr(data);
GE_ASSERT_TRUE(ge::AttrUtils::SetInt(desc, "index", index));
return data;
}
std::string GetGuarderOutside(const ge::FastNode *node) {
if (!IsInnerDataType(node->GetTypePtr())) {
return "";
}
auto outside_guarder_type = ge::AttrUtils::GetStr(node->GetOpDescBarePtr(), kNodeWithGuarderOutside);
return outside_guarder_type != nullptr ? *outside_guarder_type : "";
}
GuarderInfo FindGuarder(const ge::EdgeSrcEndpoint &edge_src_point) {
auto src_node = edge_src_point.node;
const auto guarder_type = GetGuarderOutside(src_node);
if (!guarder_type.empty()) {
if (!IsFreeNode(guarder_type.c_str())) {
GELOGE(ge::PARAM_INVALID, "get guarder type of node %s[%s] failed, guarder type = %s",
src_node->GetNamePtr(), src_node->GetTypePtr(),
guarder_type.c_str());
return {false, nullptr, ""};
}
return {true, nullptr, guarder_type};
}
for (const auto &guarder_edge : src_node->GetOutEdgesRefByIndex(edge_src_point.index)) {
if (guarder_edge == nullptr) {
continue;
}
auto guarder_node = guarder_edge->dst;
int32_t guarder_index = 0;
if (!ge::AttrUtils::GetInt(guarder_node->GetOpDescBarePtr(), kReleaseResourceIndex, guarder_index)) {
continue;
}
if (guarder_index == guarder_edge->dst_input) {
return {true, guarder_node, guarder_node->GetType()};
}
}
return {false, nullptr, ""};
}
}
std::vector<ValueHolderPtr> EmptySubgraphBuilder() { return {}; }
ge::graphStatus CreateBranchGraphs(
const ValueHolderPtr &cond_holder, const std::vector<SubgraphBuilder> &builders,
const std::vector<std::string> &subgraph_names,
std::unordered_map<std::string, std::vector<ValueHolderPtr>> &subgraph_names_to_outputs) {
GE_ASSERT_EQ(builders.size(), subgraph_names.size());
size_t output_count = 0UL;
for (size_t i = 0UL; i < builders.size(); ++i) {
GE_ASSERT_NOTNULL(ValueHolder::PushGraphFrame(cond_holder, subgraph_names[i].c_str()));
auto outputs = builders[i]();
if (i == 0UL) {
output_count = outputs.size();
} else {
if (output_count != outputs.size()) {
GELOGE(ge::PARAM_INVALID,
"Failed to create branch graphs, %zu th(%s) outputs num are different with first(%s) %zu/%zu", i,
subgraph_names[i].c_str(), subgraph_names[0U].c_str(), outputs.size(), output_count);
return ge::GRAPH_PARAM_INVALID;
}
}
auto frame = ValueHolder::PopGraphFrame(outputs, {});
if (frame == nullptr) {
GELOGE(ge::PARAM_INVALID, "Failed to construct %s graph", subgraph_names[i].c_str());
return ge::PARAM_INVALID;
}
subgraph_names_to_outputs[subgraph_names[i]] = std::move(outputs);
}
return ge::GRAPH_SUCCESS;
}
*
* WaitAnyone(Only One)
* :
* BranchDone(0..subgraph_names)
* :
* WatcherPlaceholder BranchPivot(0..subgraph_names)
* \ |
* SwitchNotify
* |
* InnerData(switch-index)
*/
ge::graphStatus CreateControlFrame(const ValueHolderPtr &cond_holder, size_t branch_num) {
GE_ASSERT_NOTNULL(ValueHolder::PushGraphFrame(cond_holder, ge::kConditionGraph));
auto cond_data = CreateInnerData(0);
auto outputs = ValueHolder::CreateDataOutput("SwitchNotify", {cond_data}, branch_num + 1U);
GE_ASSERT_NOTNULL(outputs[0U]);
auto output_desc = ValueHolderUtils::GetNodeOpDescBarePtr(outputs[0U]);
GE_ASSERT_TRUE(ge::AttrUtils::SetBool(output_desc, ge::kRequestWatcher, true));
GE_ASSERT_NOTNULL(ValueHolder::CreateVoid<bg::ValueHolder>("WatcherPlaceholder", {outputs[0U]}));
auto wait_anyone = ValueHolder::CreateVoid<bg::ValueHolder>("WaitAnyone", {});
for (size_t i = 0U; i < branch_num; ++i) {
auto pivot = ValueHolder::CreateVoid<bg::ValueHolder>("BranchPivot", {outputs[i + 1U]});
auto done = ValueHolder::CreateVoid<bg::ValueHolder>("BranchDone", {});
GE_ASSERT_NOTNULL(pivot);
GE_ASSERT_NOTNULL(done);
auto pivot_desc = ValueHolderUtils::GetNodeOpDescBarePtr(pivot);
GE_ASSERT_TRUE(ge::AttrUtils::SetInt(pivot_desc, ge::kRelativeBranch, i + 1U));
auto done_desc = ValueHolderUtils::GetNodeOpDescBarePtr(done);
GE_ASSERT_TRUE(ge::AttrUtils::SetInt(done_desc, ge::kRelativeBranch, i + 1U));
GE_ASSERT_HYPER_SUCCESS(ValueHolder::AddDependency(pivot, done));
GE_ASSERT_HYPER_SUCCESS(ValueHolder::AddDependency(done, wait_anyone));
}
ValueHolder::PopGraphFrame();
return ge::GRAPH_SUCCESS;
}
ge::graphStatus CalcSubgraphGuardersPolicy(const ge::FastNode * const parent_node,
std::set<std::pair<int32_t, std::string>> &all_guarder_indexes_types,
std::vector<OutputsGuarderInfo> &subgraph_indexes_to_policy) {
const auto op_desc = parent_node->GetOpDescBarePtr();
GE_ASSERT_NOTNULL(op_desc);
const auto &subgraph_name = op_desc->GetSubgraphInstanceNames();
GE_ASSERT_TRUE(!subgraph_name.empty());
const size_t subgraph_num = subgraph_name.size() - 1U;
subgraph_indexes_to_policy.resize(subgraph_num);
std::vector<std::set<int32_t>> subgraph_indexes_to_guarder_indexes(subgraph_num);
std::vector<std::set<ge::EdgeSrcEndpoint>> subgraph_indexes_to_guarded_endpoint(subgraph_num);
std::vector<std::map<int32_t, PlusRefGuarderInfo>> subgraph_indexes_to_parent_indexes_to_no_guarder(subgraph_num);
for (size_t i = 0U; i < subgraph_num; ++i) {
auto subgraph = ge::FastNodeUtils::GetSubgraphFromNode(parent_node, i + 1U);
GE_ASSERT_NOTNULL(subgraph);
auto netoutput_node = ge::ExecuteGraphUtils::FindFirstNodeMatchType(subgraph, "InnerNetOutput");
GE_ASSERT_NOTNULL(netoutput_node);
for (const auto edge : netoutput_node->GetAllInDataEdgesRef()) {
if (edge == nullptr) {
continue;
}
GE_ASSERT_NOTNULL(edge->src);
GE_ASSERT_NOTNULL(edge->dst);
ge::EdgeSrcEndpoint src_endpoint{edge->src, edge->src_output};
const GuarderInfo guarder_info = FindGuarder(src_endpoint);
if (guarder_info.has_guarder) {
all_guarder_indexes_types.insert({edge->dst_input, guarder_info.guarder_node_type});
subgraph_indexes_to_guarder_indexes[i].insert(edge->dst_input);
if (guarder_info.in_graph_guarder != nullptr) {
if (!subgraph_indexes_to_guarded_endpoint[i].insert(src_endpoint).second) {
subgraph_indexes_to_policy[i].plus_refs.push_back(
{edge->src, edge->src_output, ge::EdgeDstEndpoint{edge->dst, edge->dst_input}});
}
subgraph_indexes_to_policy[i].to_be_removed_guarders.push_back(guarder_info.in_graph_guarder);
} else {
subgraph_indexes_to_policy[i].plus_refs.push_back(
{edge->src, edge->src_output, ge::EdgeDstEndpoint{edge->dst, edge->dst_input}});
}
} else {
subgraph_indexes_to_parent_indexes_to_no_guarder[i][edge->dst_input] = {edge->src, edge->src_output,
{edge->dst, edge->dst_input}};
}
}
}
for (auto guarder_index_type : all_guarder_indexes_types) {
for (size_t i = 0U; i < subgraph_num; ++i) {
if (subgraph_indexes_to_guarder_indexes[i].count(guarder_index_type.first) > 0) {
continue;
}
subgraph_indexes_to_policy[i].plus_refs.emplace_back(
std::move(subgraph_indexes_to_parent_indexes_to_no_guarder[i][guarder_index_type.first]));
}
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus AddGuardersToParentNode(const std::set<std::pair<int32_t, std::string>> &all_guarder_indexes_types,
const std::vector<ValueHolderPtr> &outputs) {
for (const auto &guarder_index_type : all_guarder_indexes_types) {
ValueHolder::CreateVoidGuarder(guarder_index_type.second.c_str(), outputs[guarder_index_type.first], {});
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus RemoveGuarders(const std::vector<ge::FastNode *> &remove_guarders) {
for (const auto &guarder : remove_guarders) {
ge::FastNodeUtils::UnlinkAll(guarder);
ge::ExecuteGraphUtils::RemoveNodeWithoutRelink(guarder->GetExtendInfo()->GetOwnerGraphBarePtr(), guarder);
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus AddPlusCount(const std::string &name, const std::vector<PlusRefGuarderInfo> &plus_refs) {
auto op_desc = ge::MakeShared<ge::OpDesc>(name, "IdentityAddr");
GE_ASSERT_NOTNULL(op_desc);
for (size_t i = 0U; i < plus_refs.size(); ++i) {
GE_ASSERT_SUCCESS(op_desc->AddInputDesc(ge::GeTensorDesc()));
GE_ASSERT_SUCCESS(op_desc->AddOutputDesc(ge::GeTensorDesc()));
}
const auto node = plus_refs[0].output_node->GetExtendInfo()->GetOwnerGraphBarePtr()->AddNode(op_desc);
GE_ASSERT_NOTNULL(node);
for (size_t i = 0U; i < plus_refs.size(); ++i) {
ge::ExecuteGraphUtils::InsertNodeAfter({plus_refs[i].output_node, plus_refs[i].output_index},
{plus_refs[i].netoutput_in}, node, i, i);
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus CalcChainConflictSolvePolicy(const ge::FastNode * const node,
std::vector<ge::FastNode *> &subgraph_indexes_to_netoutput,
std::set<int32_t> &conflict_indexes, std::set<int32_t> &all_same_indexes) {
const auto op_desc = node->GetOpDescBarePtr();
GE_ASSERT_NOTNULL(op_desc);
const auto &subgraph_name = op_desc->GetSubgraphInstanceNames();
GE_ASSERT_TRUE(!subgraph_name.empty());
const size_t subgraph_num = subgraph_name.size() - 1U;
std::vector<std::set<int32_t>> output_indexes_to_input_indexes(node->GetDataOutNum());
subgraph_indexes_to_netoutput.resize(subgraph_num);
for (size_t subgraph_index = 0U; subgraph_index < subgraph_num; ++subgraph_index) {
auto subgraph = ge::FastNodeUtils::GetSubgraphFromNode(node, subgraph_index + 1U);
GE_ASSERT_NOTNULL(subgraph);
auto netoutput = ge::ExecuteGraphUtils::FindFirstNodeMatchType(subgraph, "InnerNetOutput");
GE_ASSERT_NOTNULL(netoutput);
subgraph_indexes_to_netoutput[subgraph_index] = netoutput;
size_t output_index = 0U;
for (const auto &in_node : netoutput->GetInDataNodes()) {
GE_ASSERT_TRUE(output_index < output_indexes_to_input_indexes.size());
if (IsInnerDataType(in_node->GetTypePtr())) {
int32_t data_index;
GE_ASSERT_TRUE(ge::AttrUtils::GetInt(in_node->GetOpDescBarePtr(), kFeedIndex, data_index));
output_indexes_to_input_indexes[output_index++].insert(data_index);
} else {
output_indexes_to_input_indexes[output_index++].insert(-1);
}
}
}
for (size_t i = 0U; i < output_indexes_to_input_indexes.size(); ++i) {
if (output_indexes_to_input_indexes[i].size() > 1U) {
conflict_indexes.insert(static_cast<int32_t>(i));
continue;
}
if (*output_indexes_to_input_indexes[i].begin() != -1) {
all_same_indexes.insert(static_cast<int32_t>(i));
continue;
}
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus AddPointFromForConflicts(const std::string &name, const std::set<int32_t> &conflict_indexes,
const std::vector<ge::FastNode *> &netoutput_nodes) {
size_t subgraph_index = 0U;
for (const auto &netoutput_node : netoutput_nodes) {
auto op_desc = ge::MakeShared<ge::OpDesc>(name + "_" + std::to_string(subgraph_index++), "PointFromInputs");
GE_ASSERT_NOTNULL(op_desc);
for (size_t i = 0U; i < conflict_indexes.size(); ++i) {
GE_ASSERT_SUCCESS(op_desc->AddInputDesc(ge::GeTensorDesc()));
GE_ASSERT_SUCCESS(op_desc->AddOutputDesc(ge::GeTensorDesc()));
}
const auto node = netoutput_node->GetExtendInfo()->GetOwnerGraphBarePtr()->AddNode(op_desc);
GE_ASSERT_NOTNULL(node);
uint32_t pfi_index = 0U;
for (auto conflict_index : conflict_indexes) {
GE_ASSERT_SUCCESS(ge::ExecuteGraphUtils::InsertNodeBefore({netoutput_node, conflict_index}, node,
pfi_index, pfi_index));
++pfi_index;
}
}
return ge::GRAPH_SUCCESS;
}
ge::graphStatus SortAllSubgraphs(const ValueHolderPtr &holder) {
auto node = holder->GetFastNode();
for (size_t i = 1U; i < node->GetOpDescBarePtr()->GetSubgraphInstanceNames().size(); ++i) {
const auto subgraph = ge::FastNodeUtils::GetSubgraphFromNode(node, i);
GE_ASSERT_NOTNULL(subgraph);
GE_ASSERT_SUCCESS(subgraph->TopologicalSorting(), "Failed to topo sort on %zu graph , node %s type %s", i,
node->GetNamePtr(), node->GetTypePtr());
}
return ge::GRAPH_SUCCESS;
}
}
}