* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "graph/partition/base_cluster.h"
#include <algorithm>
#include <memory>
#include <queue>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>
#include "graph_metadef/common/ge_common/util.h"
#include "common/plugin/ge_make_unique_util.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h"
#include "framework/common/framework_types_internal.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/op_type_utils.h"
#include "graph/partition/base_partitioner.h"
namespace ge {
thread_local size_t ge::BaseCluster::unique_id_ = 0;
using ClusterPtr = std::shared_ptr<BaseCluster>;
namespace {
NodePtr AddPartitionedCallKeepTopo(const ComputeGraphPtr &compute_graph,
const std::vector<NodePtr> &ori_nodes, const OpDescPtr &partitioned_call_op) {
auto min_id_node = NodeUtils::GetNodeWithMinimalId(ori_nodes);
GE_ASSERT_NOTNULL(min_id_node);
auto ret_nodes = compute_graph->InsertNodes(min_id_node, {partitioned_call_op});
GE_ASSERT_TRUE(!ret_nodes.empty());
return ret_nodes[0];
}
}
PartitionNodeAttrNameManager &PartitionNodeAttrNameManager::GetInstance() {
static PartitionNodeAttrNameManager instance;
return instance;
}
void PartitionNodeAttrNameManager::RegisterNodeAttrName(const PartitionNodeAttrName &attr_name) {
const std::lock_guard<std::mutex> lock(mutex_);
attr_array_.emplace_back(attr_name);
}
PartitionNodeAttrRegister::PartitionNodeAttrRegister(const std::string &attr_name, bool is_support_tensor,
const bool is_need_copy, const SetAttrFn set_fn,
const GetAttrFn get_fn) noexcept {
PartitionNodeAttrName partition_attr_name(attr_name, is_support_tensor, is_need_copy, set_fn, get_fn);
PartitionNodeAttrNameManager::GetInstance().RegisterNodeAttrName(partition_attr_name);
}
std::string BaseCluster::DebugString() const {
std::stringstream ss;
if (type_index_ < static_cast<int32_t>(partitioner_->type_index_to_type_.size())) {
ss << partitioner_->type_index_to_type_[type_index_];
}
ss << "[" << id_ << "]" << "[" << type_index_ << "]" << "(size:" << nodes_.size() << ")";
ss << "(" << min_ << "," << max_ << ")(";
for (const auto &cluster : in_clusters_) {
ss << cluster->id_ << ",";
}
ss << ")->(";
for (const auto &cluster : out_clusters_) {
ss << cluster->id_ << ",";
}
ss << ")|";
for (const auto &node : nodes_) {
ss << (node->GetName() + "|");
}
return ss.str();
}
size_t BaseCluster::MinId() const { return min_; }
size_t BaseCluster::Id() const { return id_; }
size_t BaseCluster::UniqueId() const { return unique_id_; }
void BaseCluster::UpdateRank(size_t rank) {
max_ = rank;
min_ = rank;
};
bool BaseCluster::IsData() const { return partitioner_->type_index_to_type_[type_index_] == kClusterData; }
bool BaseCluster::IsIndependent() const { return partitioner_->type_index_to_type_[type_index_] == kClusterStage; }
bool BaseCluster::IsNetOutput() const { return partitioner_->type_index_to_type_[type_index_] == kClusterNetoutput; }
bool BaseCluster::IsInputNode() const { return partitioner_->type_index_to_type_[type_index_] == kClusterInputNode; }
bool BaseCluster::IsRefVariable() const {
if ((nodes_.size() == 1) && OpTypeUtils::IsVarLikeNode(nodes_[0]->GetType())) {
std::string ref_variable_name;
return (AttrUtils::GetStr(nodes_[0]->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_variable_name) &&
!ref_variable_name.empty());
}
return false;
}
void BaseCluster::AddInput(ClusterPtr in) {
if (std::find(in_clusters_.begin(), in_clusters_.end(), in.get()) != in_clusters_.end()) {
return;
}
in_clusters_.insert(in_clusters_.end(), in.get());
if (std::find(in->out_clusters_.begin(), in->out_clusters_.end(), this) != in->out_clusters_.end()) {
return;
}
in->out_clusters_.insert(in->out_clusters_.end(), this);
};
void BaseCluster::RemoveInput(ClusterPtr in) {
in_clusters_.erase(std::remove(in_clusters_.begin(), in_clusters_.end(), in.get()), in_clusters_.end());
in->out_clusters_.erase(std::remove(in->out_clusters_.begin(), in->out_clusters_.end(), this),
in->out_clusters_.end());
};
void BaseCluster::AddOutput(ClusterPtr out) {
if (std::find(out_clusters_.begin(), out_clusters_.end(), out.get()) != out_clusters_.end()) {
return;
}
out_clusters_.insert(out_clusters_.end(), out.get());
if (std::find(out->in_clusters_.begin(), out->in_clusters_.end(), this) != out->in_clusters_.end()) {
return;
}
out->in_clusters_.insert(out->in_clusters_.end(), this);
};
void BaseCluster::RemoveOutput(ClusterPtr out) {
out_clusters_.erase(std::remove(out_clusters_.begin(), out_clusters_.end(), out.get()), out_clusters_.end());
out->in_clusters_.erase(std::remove(out->in_clusters_.begin(), out->in_clusters_.end(), this),
out->in_clusters_.end());
};
void BaseCluster::Merge(ClusterPtr other) {
if (other->IsIndependent()) { return; }
nodes_.insert(nodes_.end(), other->nodes_.begin(), other->nodes_.end());
other->in_clusters_.erase(std::remove(other->in_clusters_.begin(), other->in_clusters_.end(), this),
other->in_clusters_.end());
other->out_clusters_.erase(std::remove(other->out_clusters_.begin(), other->out_clusters_.end(), this),
other->out_clusters_.end());
in_clusters_.erase(std::remove(in_clusters_.begin(), in_clusters_.end(), other.get()), in_clusters_.end());
out_clusters_.erase(std::remove(out_clusters_.begin(), out_clusters_.end(), other.get()), out_clusters_.end());
auto in_clusters = other->in_clusters_;
for (const auto &cluster : in_clusters) {
cluster->RemoveOutput(other);
cluster->AddOutput(shared_from_this());
}
auto out_clusters = other->out_clusters_;
for (const auto &cluster : out_clusters) {
cluster->RemoveInput(other);
cluster->AddInput(shared_from_this());
}
if (other->max_ > max_) {
max_ = other->max_;
}
if (other->min_ < min_) {
min_ = other->min_;
}
}
bool BaseCluster::TryMerge(ClusterPtr other) {
std::queue<BaseCluster *> forward_reached;
forward_reached.push(other.get());
std::unordered_set<BaseCluster *> visited;
while (!forward_reached.empty()) {
auto current_cluster = forward_reached.front();
forward_reached.pop();
for (const auto &cluster : current_cluster->out_clusters_) {
if ((cluster->max_ == max_) && (current_cluster != other.get())) {
return false;
} else if (cluster->min_ < max_) {
if (visited.insert(cluster).second) {
forward_reached.push(cluster);
}
}
}
}
Merge(other);
return true;
}
void BaseCluster::MergeInOutClusters(const std::unordered_set<BaseCluster *> &path_cluster_set,
const std::vector<BaseCluster *> &ordered_path_clusters) {
std::vector<BaseCluster *> merge_in_clusters;
std::vector<BaseCluster *> merge_out_clusters;
std::unordered_set<BaseCluster *> unique_path_in_clusters;
std::unordered_set<BaseCluster *> unique_path_out_clusters;
for (auto &path_cluster : ordered_path_clusters) {
for (const auto &cluster_input : path_cluster->in_clusters_) {
if ((path_cluster_set.find(cluster_input) == path_cluster_set.cend()) &&
unique_path_in_clusters.insert(cluster_input).second) {
merge_in_clusters.emplace_back(cluster_input);
}
}
for (const auto &cluster_output : path_cluster->out_clusters_) {
if ((path_cluster_set.find(cluster_output) == path_cluster_set.cend()) &&
unique_path_out_clusters.insert(cluster_output).second) {
merge_out_clusters.emplace_back(cluster_output);
}
}
max_ = (path_cluster->max_ > max_) ? path_cluster->max_ : max_;
min_ = (path_cluster->min_ < min_) ? path_cluster->min_ : min_;
}
in_clusters_ = std::move(merge_in_clusters);
out_clusters_ = std::move(merge_out_clusters);
}
void BaseCluster::UpdateInOutClusters(const std::unordered_set<BaseCluster *> &path_cluster_set) {
for (auto &in_cluster : in_clusters_) {
const auto &it = in_cluster;
for (const auto &out_cluster : path_cluster_set) {
it->out_clusters_.erase(std::remove(it->out_clusters_.begin(), it->out_clusters_.end(), out_cluster),
it->out_clusters_.end());
}
if (std::find(it->out_clusters_.begin(), it->out_clusters_.end(), this) == it->out_clusters_.end()) {
it->out_clusters_.emplace_back(this);
}
}
for (auto &out_cluster : out_clusters_) {
const auto &it = out_cluster;
for (const auto &in_cluster : path_cluster_set) {
it->in_clusters_.erase(std::remove(it->in_clusters_.begin(), it->in_clusters_.end(), in_cluster),
it->in_clusters_.end());
}
if (std::find(it->in_clusters_.begin(), it->in_clusters_.end(), this) == it->in_clusters_.end()) {
it->in_clusters_.emplace_back(this);
}
}
}
void BaseCluster::MergeAllPath(const std::vector<ClusterPtr> &path_clusters) {
constexpr int32_t kUnknownShapeTypeIndex = 5;
std::unordered_set<BaseCluster *> path_cluster_set;
std::vector<BaseCluster *> ordered_path_clusters;
ordered_path_clusters.reserve(path_clusters.size() + 1U);
path_cluster_set.insert(this);
ordered_path_clusters.emplace_back(this);
size_t node_size = nodes_.size();
for (const auto &path_cluster : path_clusters) {
if (path_cluster->IsIndependent()) {
continue;
}
if (path_cluster->type_index_ == kUnknownShapeTypeIndex) {
type_index_ = kUnknownShapeTypeIndex;
}
path_cluster_set.insert(path_cluster.get());
ordered_path_clusters.emplace_back(path_cluster.get());
node_size += path_cluster->nodes_.size();
}
nodes_.reserve(node_size);
for (const auto &path_cluster : path_clusters) {
nodes_.insert(nodes_.end(), path_cluster->nodes_.begin(), path_cluster->nodes_.end());
}
MergeInOutClusters(path_cluster_set, ordered_path_clusters);
path_cluster_set.erase(this);
UpdateInOutClusters(path_cluster_set);
}
std::vector<ClusterPtr> BaseCluster::MergeAllPathFrom(const ClusterPtr &other) {
std::queue<BaseCluster *> forward_reached_queue;
std::queue<BaseCluster *> backward_reached_queue;
std::unordered_set<BaseCluster *> forward_reached_clusters;
std::unordered_set<BaseCluster *> backward_reached_clusters;
std::vector<ClusterPtr> path_clusters;
if (other->IsIndependent()) {
return path_clusters;
}
path_clusters.push_back(other);
forward_reached_queue.push(other.get());
backward_reached_queue.push(this);
while (!forward_reached_queue.empty()) {
auto current_cluster = forward_reached_queue.front();
forward_reached_queue.pop();
for (const auto &cluster : current_cluster->out_clusters_) {
if (cluster->min_ < max_ && cluster->max_ != max_ && forward_reached_clusters.count(cluster) == 0) {
forward_reached_clusters.insert(cluster);
forward_reached_queue.push(cluster);
}
}
}
while (!backward_reached_queue.empty()) {
auto current_cluster = backward_reached_queue.front();
backward_reached_queue.pop();
for (const auto &cluster : current_cluster->in_clusters_) {
if (cluster->max_ > other->min_ && cluster->max_ != other->max_ &&
backward_reached_clusters.count(cluster) == 0) {
backward_reached_clusters.insert(cluster);
backward_reached_queue.push(cluster);
if (forward_reached_clusters.count(cluster) != 0) {
path_clusters.push_back(cluster->shared_from_this());
}
}
}
}
MergeAllPath(path_clusters);
return path_clusters;
}
std::vector<BaseCluster *> BaseCluster::Inputs() const { return in_clusters_; }
std::vector<BaseCluster *> BaseCluster::Outputs() const { return out_clusters_; }
const std::vector<NodePtr> &BaseCluster::Nodes() const { return nodes_; }
bool BaseCluster::AddFrameInput(InDataAnchorPtr anchor) {
bool added = true;
if (anchor != nullptr && anchor->GetPeerOutAnchor() != nullptr) {
auto index = inputs_.size();
if (merge_inputs_) {
GELOGD("Merge inputs is enabled");
auto src_node = anchor->GetPeerOutAnchor()->GetOwnerNode();
std::string src_key = src_node->GetName() + ":" + std::to_string(anchor->GetPeerOutAnchor()->GetIdx());
std::map<std::string, size_t>::const_iterator it = src_key_to_frame_input_.find(src_key);
if (it != src_key_to_frame_input_.cend()) {
index = it->second;
GELOGD("[%s:%d] Reuse data index: %zu",
anchor->GetOwnerNode()->GetName().c_str(),
anchor->GetIdx(),
it->second);
added = false;
} else {
inputs_.push_back(anchor);
inputs_index_[anchor] = index;
GELOGD("[%s:%d] Assign data index: %zu",
anchor->GetOwnerNode()->GetName().c_str(),
anchor->GetIdx(),
index);
src_key_to_frame_input_[src_key] = index;
}
} else {
inputs_index_[anchor] = index;
inputs_.push_back(anchor);
}
data_to_node_inputs_[index].emplace_back(anchor);
}
return added;
}
void BaseCluster::AddFrameOutput(OutDataAnchorPtr anchor) {
if (anchor != nullptr) {
outputs_index_[anchor] = outputs_.size();
outputs_.push_back(anchor);
}
}
InDataAnchorPtr BaseCluster::GetFrameInDataAnchor(InDataAnchorPtr anchor) {
return partition_node_->GetInDataAnchor(static_cast<int32_t>(inputs_index_[anchor]));
}
OutDataAnchorPtr BaseCluster::GetFrameOutDataAnchor(OutDataAnchorPtr anchor) {
return partition_node_->GetOutDataAnchor(static_cast<int32_t>(outputs_index_[anchor]));
}
InControlAnchorPtr BaseCluster::GetFrameInControlAnchor() { return partition_node_->GetInControlAnchor(); };
OutControlAnchorPtr BaseCluster::GetFrameOutControlAnchor() { return partition_node_->GetOutControlAnchor(); };
Status BaseCluster::BuildFrame() {
if (partitioner_->IsNeedBuildPartitionFrame(*this)) {
return BuildPartitionFrame();
} else {
auto node = nodes_.front();
auto in_control_anchor = node->GetInControlAnchor();
if (in_control_anchor != nullptr) {
for (const auto &peer_out_control_anchor : in_control_anchor->GetPeerOutControlAnchors()) {
auto src_cluster = partitioner_->node_2_cluster_[peer_out_control_anchor->GetOwnerNode()];
if (src_cluster->id_ != id_) {
GE_CHK_GRAPH_STATUS_RET(
GraphUtils::RemoveEdge(peer_out_control_anchor, in_control_anchor),
"[Remove][Edge] from node %s index %d to node %s failed, index %d.",
peer_out_control_anchor->GetOwnerNode()->GetName().c_str(), AnchorUtils::GetIdx(peer_out_control_anchor),
in_control_anchor->GetOwnerNode()->GetName().c_str(), AnchorUtils::GetIdx(in_control_anchor));
control_inputs_.insert(src_cluster);
src_cluster->control_outputs_.insert(peer_out_control_anchor);
}
}
}
if (IsData() || IsIndependent()) {
for (const auto &anchor : node->GetAllOutDataAnchors()) {
AddFrameOutput(anchor);
}
} else {
for (const auto &anchor : node->GetAllInDataAnchors()) {
(void)AddFrameInput(anchor);
}
}
partition_node_ = node;
}
return SUCCESS;
}
Status BaseCluster::AddPartitionedCallNodeOutput(const NodePtr &node, const OpDescPtr &partition_op) {
for (const auto &anchor : node->GetAllOutDataAnchors()) {
auto peer_in_anchors = anchor->GetPeerInDataAnchors();
for (const auto &peer_in_anchor : peer_in_anchors) {
auto src_cluster = partitioner_->node_2_cluster_[peer_in_anchor->GetOwnerNode()];
if (src_cluster->id_ != id_) {
AddFrameOutput(anchor);
GE_CHK_GRAPH_STATUS_RET(partition_op->AddOutputDesc(node->GetOpDesc()->GetOutputDesc(anchor->GetIdx())),
"[Add][OutputDesc] to op:%s failed.", partition_op->GetName().c_str());
auto tensor = partition_op->MutableOutputDesc(partition_op->GetAllOutputsDescPtr().size() - 1U);
GE_CHK_GRAPH_STATUS_RET(UpdateFrameTensor(node->GetOpDesc(), tensor),
"[Call][UpdateFrameOutTensor] failed, from %s to %s:%d.", node->GetName().c_str(),
partition_op->GetName().c_str(), anchor->GetIdx());
break;
}
}
}
return SUCCESS;
}
Status BaseCluster::RemoveNodeFromRoot(const ge::ComputeGraphPtr &graph) {
for (auto &node : nodes_) {
GE_ASSERT_NOTNULL(subgraph_->AddNode(node), "[Add][Node] failed.");
GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveJustNode(graph, node), "[Remove][JustNode] failed, graph:%s, node:%s.",
graph->GetName().c_str(), node->GetName().c_str());
GE_CHK_GRAPH_STATUS_RET(node->SetOwnerComputeGraph(subgraph_), "[Set][OwnerComputeGraph] %s for node:%s failed.",
subgraph_->GetName().c_str(), node->GetName().c_str());
}
return SUCCESS;
}
Status BaseCluster::BuildPartitionNodes(const OpDescPtr &partition_op) {
for (auto &node : nodes_) {
for (const auto &anchor : node->GetAllInDataAnchors()) {
auto peer_out_anchor = anchor->GetPeerOutAnchor();
if (peer_out_anchor == nullptr) {
continue;
}
auto src_cluster = partitioner_->node_2_cluster_[peer_out_anchor->GetOwnerNode()];
if (src_cluster->id_ != id_) {
if (AddFrameInput(anchor)) {
GE_CHK_GRAPH_STATUS_RET(partition_op->AddInputDesc(node->GetOpDesc()->GetInputDesc(anchor->GetIdx())),
"[Add][InputDesc] to op:%s failed.", partition_op->GetName().c_str());
auto tensor = partition_op->MutableInputDesc(partition_op->GetAllInputsSize() - 1U);
GE_CHK_GRAPH_STATUS_RET(UpdateFrameTensor(node->GetOpDesc(), tensor),
"[Call][UpdateFrameInTensor] failed, from %s to %s:%zu.", node->GetName().c_str(),
partition_op->GetName().c_str(), partition_op->GetAllInputsSize() - 1U);
}
}
}
auto in_control_anchor = node->GetInControlAnchor();
if (in_control_anchor != nullptr) {
for (const auto &peer_out_control_anchor : in_control_anchor->GetPeerOutControlAnchors()) {
if (peer_out_control_anchor == nullptr) {
continue;
}
auto src_cluster = partitioner_->node_2_cluster_[peer_out_control_anchor->GetOwnerNode()];
if (src_cluster->id_ != id_) {
GE_CHK_GRAPH_STATUS_RET(
GraphUtils::RemoveEdge(peer_out_control_anchor, in_control_anchor),
"[Remove][Edge] from %s:%d to %s:%d failed.", peer_out_control_anchor->GetOwnerNode()->GetName().c_str(),
peer_out_control_anchor->GetIdx(), node->GetName().c_str(), in_control_anchor->GetIdx());
control_inputs_.insert(src_cluster);
src_cluster->control_outputs_.insert(peer_out_control_anchor);
}
}
}
GE_CHK_GRAPH_STATUS_RET(AddPartitionedCallNodeOutput(node, partition_op),
"[Add][PartitionedCallNodeOutput] failed, node[%s], partitioned_call[%s].",
node->GetName().c_str(), partition_op->GetName().c_str());
}
return SUCCESS;
}
Status BaseCluster::BuildPartitionFrame() {
auto graph = partitioner_->root_graph_;
std::string sub_graph_name = partitioner_->GetSubgraphName(*this);
subgraph_ = MakeShared<ComputeGraph>(sub_graph_name);
GE_ASSERT_NOTNULL(subgraph_, "[New][Memory] for subgraph failed, name:%s.", sub_graph_name.c_str());
auto partitioned_op =
MakeShared<OpDesc>(sub_graph_name + "_PartitionedCall_" + std::to_string(unique_id_++), "PartitionedCall");
GE_ASSERT_NOTNULL(partitioned_op, "[New][Memory] for partition op failed.");
GE_CHK_GRAPH_STATUS_RET(partitioned_op->AddSubgraphName(subgraph_->GetName()), "[Add][SubgraphName] %s for op:%s.",
subgraph_->GetName().c_str(), partitioned_op->GetName().c_str());
GE_CHK_GRAPH_STATUS_RET(partitioned_op->SetSubgraphInstanceName(0, subgraph_->GetName()),
"[Call][SetSubgraphInstanceName] for op:%s failed, index:0, name:%s.",
partitioned_op->GetName().c_str(), subgraph_->GetName().c_str());
GE_CHK_STATUS_RET(BuildPartitionNodes(partitioned_op),
"[Call][SetSubgraphInstanceName] for op:%s failed, index:0, name:%s.",
partitioned_op->GetName().c_str(), subgraph_->GetName().c_str());
partition_node_ = AddPartitionedCallKeepTopo(graph, nodes_, partitioned_op);
GE_ASSERT_NOTNULL(partition_node_, "[Add][Node] %s to graph:%s failed.", partitioned_op->GetName().c_str(),
graph->GetName().c_str());
GE_CHK_GRAPH_STATUS_RET(partition_node_->SetOwnerComputeGraph(graph),
"[Set][OwnerComputeGraph] %s for node:%s failed.", graph->GetName().c_str(),
partitioned_op->GetName().c_str());
GE_CHK_STATUS_RET(RemoveNodeFromRoot(graph),
"[Call][SetSubgraphInstanceName] for op:%s failed, index:0, name:%s.",
partitioned_op->GetName().c_str(), subgraph_->GetName().c_str());
subgraph_->SetParentNode(partition_node_);
subgraph_->SetParentGraph(graph);
auto root_graph = GraphUtils::FindRootGraph(graph);
GE_CHK_GRAPH_STATUS_RET(root_graph->AddSubgraph(subgraph_), "[Add][Subgraph] %s to root graph:%s failed.",
subgraph_->GetName().c_str(), graph->GetName().c_str());
std::string session_graph_id;
GE_ASSERT_TRUE(AttrUtils::GetStr(*graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id),
"[Get][Attr] %s on root graph:%s failed.", ATTR_NAME_SESSION_GRAPH_ID.c_str(),
graph->GetName().c_str());
GE_ASSERT_TRUE(AttrUtils::SetStr(*subgraph_, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id),
"[Set][Attr] %s on subgraph:%s failed.", ATTR_NAME_SESSION_GRAPH_ID.c_str(),
subgraph_->GetName().c_str());
return SUCCESS;
}
Status BaseCluster::CombinePartitionFrame() {
size_t index = 0U;
for (const auto &anchor : inputs_) {
auto peer_out_anchor = anchor->GetPeerOutAnchor();
GE_ASSERT_NOTNULL(peer_out_anchor);
auto src_cluster = partitioner_->node_2_cluster_[peer_out_anchor->GetOwnerNode()];
auto src_anchor = src_cluster->GetFrameOutDataAnchor(peer_out_anchor);
auto dst_anchor = GetFrameInDataAnchor(anchor);
for (const auto &node_anchor : data_to_node_inputs_[index]) {
GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(peer_out_anchor, node_anchor),
"[Remove][Edge] from %s:%d to %s:%d fail.",
peer_out_anchor->GetOwnerNode()->GetName().c_str(), peer_out_anchor->GetIdx(),
node_anchor->GetOwnerNode()->GetName().c_str(), node_anchor->GetIdx());
}
GE_ASSERT_NOTNULL(src_anchor);
GE_ASSERT_NOTNULL(dst_anchor);
GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(src_anchor, dst_anchor), "[Add][Edge] from %s:%d to %s:%d failed.",
src_anchor->GetOwnerNode()->GetName().c_str(), src_anchor->GetIdx(),
dst_anchor->GetOwnerNode()->GetName().c_str(), dst_anchor->GetIdx());
++index;
}
for (const auto &src_cluster : control_inputs_) {
auto src_anchor = src_cluster->GetFrameOutControlAnchor();
auto dst_anchor = GetFrameInControlAnchor();
GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(src_anchor, dst_anchor), "[Add][Edge] from %s:%d to %s:%d failed.",
src_anchor->GetOwnerNode()->GetName().c_str(), src_anchor->GetIdx(),
dst_anchor->GetOwnerNode()->GetName().c_str(), dst_anchor->GetIdx());
}
return SUCCESS;
}
Status BaseCluster::BuildPartitionSubgraph() {
if (IsData() || IsNetOutput() || IsIndependent()) {
return SUCCESS;
}
int64_t parent_node_index = 0;
for (const auto &anchor : inputs_) {
GE_CHECK_NOTNULL(subgraph_);
auto data_op =
MakeShared<OpDesc>(subgraph_->GetName() + std::string("_Data_") + std::to_string(parent_node_index), ge::DATA);
GE_ASSERT_NOTNULL(data_op, "[New][Memory] for data op failed.");
auto input_desc = anchor->GetOwnerNode()->GetOpDesc()->GetInputDesc(anchor->GetIdx());
GE_CHK_GRAPH_STATUS_RET(data_op->AddInputDesc(input_desc), "[Add][InputDesc] to op:%s failed, graph %s.",
data_op->GetName().c_str(), subgraph_->GetName().c_str());
GE_CHK_GRAPH_STATUS_RET(data_op->AddOutputDesc(input_desc), "[Add][OutputDesc] to op:%s failed, graph %s.",
data_op->GetName().c_str(), subgraph_->GetName().c_str());
GE_ASSERT_TRUE(AttrUtils::SetInt(data_op, ATTR_NAME_PARENT_NODE_INDEX, parent_node_index),
"[Set][Attr] %s on subgraph data node:%s failed.", ATTR_NAME_PARENT_NODE_INDEX.c_str(),
data_op->GetName().c_str());
GE_CHK_STATUS_RET(CopyOpAttr(anchor->GetOwnerNode()->GetOpDesc(), data_op), "[Call][CopyOpAttr] data op failed.");
auto data_node = subgraph_->AddNode(data_op);
GE_ASSERT_NOTNULL(data_node, "[Add][Node] %s to subgraph:%s failed.", data_op->GetName().c_str(),
subgraph_->GetName().c_str());
GE_CHK_GRAPH_STATUS_RET(data_node->SetOwnerComputeGraph(subgraph_), "[Set][OwnerGraph] %s of data node:%s failed.",
subgraph_->GetName().c_str(), data_op->GetName().c_str());
for (const auto &node_anchor : data_to_node_inputs_[parent_node_index]) {
GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), node_anchor),
"[Call][AddEdge] Failed add data input edge to %s:%d",
node_anchor->GetOwnerNode()->GetName().c_str(), node_anchor->GetIdx());
}
parent_node_index++;
}
if (outputs_.empty() && control_outputs_.empty()) {
return SUCCESS;
}
auto net_output_op = MakeShared<OpDesc>(subgraph_->GetName() + "_" + NODE_NAME_NET_OUTPUT, ge::NETOUTPUT);
GE_ASSERT_NOTNULL(net_output_op, "[New][Memory] for netoutput op failed.");
GE_CHK_STATUS_RET(SetAttrToNetoutput(net_output_op), "[Call][SetAttrToNetoutput] net_output_op failed.");
for (size_t i = 0; i < outputs_.size(); ++i) {
GeTensorDesc input_desc;
GE_CHK_GRAPH_STATUS_RET(net_output_op->AddInputDesc(input_desc), "[Add][InputDesc] to op:%s failed.",
net_output_op->GetName().c_str());
}
auto net_output_node = subgraph_->AddNode(net_output_op);
GE_ASSERT_NOTNULL(net_output_node, "[Call][AddNode] Failed add netoutput node:%s to subgraph:%s.",
net_output_op->GetName().c_str(), subgraph_->GetName().c_str());
GE_CHK_GRAPH_STATUS_RET(net_output_node->SetOwnerComputeGraph(subgraph_),
"[Set][OwnerGraph] %s of netoutput node:%s failed.", subgraph_->GetName().c_str(),
net_output_node->GetName().c_str());
parent_node_index = 0;
for (const auto &anchor : outputs_) {
auto output_desc = anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(static_cast<uint32_t>(anchor->GetIdx()));
GE_ASSERT_TRUE(AttrUtils::SetInt(output_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_node_index),
"[Set][Attr] parent_node_index on subgraph node:%s netoutput's input failed.",
anchor->GetOwnerNode()->GetName().c_str());
GE_CHK_GRAPH_STATUS_RET(net_output_op->UpdateInputDesc(parent_node_index, output_desc),
"[Update][InputDesc] of netoutput node:%s failed.", net_output_op->GetName().c_str());
GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(anchor, net_output_node->GetInDataAnchor(parent_node_index)),
"[Add][Edge] from %s:%d to netoutput node:%s failed.",
anchor->GetOwnerNode()->GetName().c_str(), anchor->GetIdx(),
net_output_op->GetName().c_str());
parent_node_index++;
}
for (const auto &anchor : control_outputs_) {
GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(anchor, net_output_node->GetInControlAnchor()),
"[Add][ControlEdge] from %s:%d to netoutput node:%s failed.",
anchor->GetOwnerNode()->GetName().c_str(), anchor->GetIdx(),
net_output_op->GetName().c_str());
}
return SUCCESS;
}
void BaseCluster::Clear() {
in_clusters_.clear();
out_clusters_.clear();
nodes_.clear();
inputs_index_.clear();
outputs_index_.clear();
inputs_.clear();
outputs_.clear();
control_inputs_.clear();
control_outputs_.clear();
partition_node_.reset();
subgraph_.reset();
unique_id_ = 0;
}
void BaseCluster::SetMergeInputs(bool merge_inputs) {
merge_inputs_ = merge_inputs;
}
Status BaseCluster::SetGraphId(const uint32_t graph_id) const {
if (subgraph_ == nullptr) {
GELOGW("Subgraph is null, cannot set graph id, graph_id=%u.", graph_id);
return SUCCESS;
}
subgraph_->SetGraphID(graph_id);
return SUCCESS;
}
std::vector<NodePtr> &BaseCluster::GetMutableDirectNode() {
return nodes_;
}
Status BaseCluster::UpdateFrameTensor(const OpDescPtr &op_desc, GeTensorDescPtr &tensor) {
GE_ASSERT_NOTNULL(op_desc, "[Check][Valid] op_desc is invalid");
GE_ASSERT_NOTNULL(tensor, "[Check][Valid] tensor is invalid");
const auto &node_attr_names = PartitionNodeAttrNameManager::GetInstance().GetAllPartitionNodeAttrNames();
for (const auto &item : node_attr_names) {
if (item.IsSupportTensor()) {
GE_CHK_STATUS_RET(CopyAttr(op_desc, tensor, item), "[Call][CopyAttr] failed, from %s to %s.",
op_desc->GetName().c_str(), tensor->GetName().c_str());
}
}
return SUCCESS;
}
Status BaseCluster::CopyOpAttr(const OpDescPtr &from, OpDescPtr &to) {
if (from == nullptr) {
return SUCCESS;
}
const auto &node_attr_names = PartitionNodeAttrNameManager::GetInstance().GetAllPartitionNodeAttrNames();
for (const auto &item : node_attr_names) {
if (item.IsNeedCopy()) {
GE_CHK_STATUS_RET(CopyAttr(from, to, item), "[Call][CopyAttr] failed, from %s to %s.", from->GetName().c_str(),
to->GetName().c_str());
}
}
return SUCCESS;
}
Status BaseCluster::CopyAttr(AttrUtils::ConstAttrHolderAdapter &&from, AttrUtils::AttrHolderAdapter &&to,
const PartitionNodeAttrName &item) {
const auto get_fn = item.GetAttrFunction();
const auto set_fn = item.SetAttrFunction();
GE_CHECK_NOTNULL(get_fn);
GE_CHECK_NOTNULL(set_fn);
AnyValue val;
const auto has_attr = get_fn(from.get(), item.GetAttrName(), val);
if (has_attr) {
GE_ASSERT_TRUE(set_fn(to.get(), item.GetAttrName(), val), "[Call][set_fn] failed, attr %s",
item.GetAttrName().c_str());
}
return SUCCESS;
}
}