* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "cluster.h"
#include <sstream>
namespace optimize {
/**
 * Links `input` to this cluster: adds `&input` to `inputs_` and adds `this`
 * to `input.outputs_`.
 *
 * If `input` is already present in `inputs_` the function returns immediately
 * and does not inspect `input.outputs_`.
 *
 * @param input cluster to register as an input of this cluster.
 */
void Cluster::AddInput(Cluster &input) {
  // Use the container's keyed lookup rather than a linear std::find scan.
  if (inputs_.count(&input) > 0UL) {
    return;
  }
  inputs_.insert(&input);
  // Record the reverse edge only when it is not there yet.
  if (input.outputs_.count(this) == 0UL) {
    input.outputs_.insert(this);
  }
}
/**
 * Links this cluster to `output`: adds `&output` to `outputs_` and adds `this`
 * to `output.inputs_`.
 *
 * If `output` is already present in `outputs_` the function returns
 * immediately and does not inspect `output.inputs_`.
 *
 * @param output cluster to register as an output of this cluster.
 */
void Cluster::AddOutput(Cluster &output) {
  // Use the container's keyed lookup rather than a linear std::find scan.
  if (outputs_.count(&output) > 0UL) {
    return;
  }
  outputs_.insert(&output);
  // Record the reverse edge only when it is not there yet.
  if (output.inputs_.count(this) == 0UL) {
    output.inputs_.insert(this);
  }
}
/**
 * Absorbs `from` into this cluster.
 *
 * Steps, in order:
 *  1. append `from`'s nodes to `nodes_` and `node_set_`;
 *  2. drop any direct edges between the two clusters;
 *  3. fold `from`'s meta data into this cluster's meta data;
 *  4. re-point every remaining input/output edge of `from` at this cluster;
 *  5. recompute `in_nodes_` / `out_nodes_`, treating `from` as the preceding
 *     cluster and `*this` as the following one.
 *
 * Merging a cluster into itself is a no-op.
 *
 * @param from cluster whose nodes and edges are taken over. Its edge sets are
 *             emptied of every cluster visited here; its node containers are
 *             left untouched.
 */
void Cluster::MergeFrom(Cluster &from) {
  // Self-merge guard: the range insert below must not receive iterators into
  // the list it inserts into, and ins_num must not be doubled.
  if (&from == this) {
    return;
  }

  nodes_.insert(nodes_.cend(), from.nodes_.cbegin(), from.nodes_.cend());
  node_set_.insert(from.node_set_.begin(), from.node_set_.end());

  // Remove the direct edges between the two clusters, in both directions, so
  // the re-pointing loops below never link this cluster to itself.
  from.inputs_.erase(this);
  from.outputs_.erase(this);
  inputs_.erase(&from);
  outputs_.erase(&from);

  meta_data_.ins_num += from.meta_data_.ins_num;
  // A non-empty vectorized_repeats on `from` replaces the current value.
  if (!from.meta_data_.vectorized_repeats.empty()) {
    meta_data_.vectorized_repeats = from.meta_data_.vectorized_repeats;
  }

  // Iterate over copies: RemoveOutput/RemoveInput erase from from.inputs_ and
  // from.outputs_ while we walk them.
  const auto in_clusters = from.inputs_;
  for (const auto &cluster : in_clusters) {
    cluster->RemoveOutput(from);
    cluster->AddOutput(*this);
  }
  const auto out_clusters = from.outputs_;
  for (const auto &cluster : out_clusters) {
    cluster->RemoveInput(from);
    cluster->AddInput(*this);
  }

  // Boundary nodes are computed last, after node_set_ has been merged.
  // NOTE(review): CalculateMergedOutNodes relies on ContainsNode, which
  // presumably consults node_set_ - confirm against the header.
  auto connected_nodes = Cluster::FindConnectedNodes(from, *this);
  in_nodes_ = Cluster::CalculateMergedInNodes(from, *this, connected_nodes);
  out_nodes_ = Cluster::CalculateMergedOutNodes(from, *this);
}
/**
 * Unlinks `input` from this cluster: erases `this` from `input.outputs_` and
 * erases `&input` from `inputs_`.
 *
 * @param input cluster to detach.
 */
void Cluster::RemoveInput(Cluster &input) {
  Cluster *const peer = &input;
  peer->outputs_.erase(this);
  inputs_.erase(peer);
}
/**
 * Unlinks `output` from this cluster: erases `this` from `output.inputs_` and
 * erases `&output` from `outputs_`.
 *
 * @param output cluster to detach.
 */
void Cluster::RemoveOutput(Cluster &output) {
  Cluster *const peer = &output;
  peer->inputs_.erase(this);
  outputs_.erase(peer);
}
const std::list<af::AscNodePtr> &Cluster::Nodes() const {
return nodes_;
}
std::unordered_set<af::AscNodePtr> Cluster::FindConnectedNodes(const Cluster &pre, const Cluster &post) {
std::unordered_set<af::AscNodePtr> connected_nodes;
for (const auto &node : pre.out_nodes_) {
if (post.in_nodes_.count(node) > 0UL) {
connected_nodes.insert(node);
}
}
return connected_nodes;
}
std::unordered_set<af::AscNodePtr> Cluster::CalculateMergedInNodes(
const Cluster &pre, const Cluster &post, const std::unordered_set<af::AscNodePtr> &connected_nodes) {
std::unordered_set<af::AscNodePtr> merged_in;
for (const auto &node : pre.in_nodes_) {
merged_in.insert(node);
}
for (const auto &node : post.in_nodes_) {
if (connected_nodes.count(node) == 0UL) {
merged_in.insert(node);
}
}
return merged_in;
}
/**
 * Builds the out-node set for the merge of `pre` and `post`.
 *
 * The result holds every node of `post.out_nodes_`, plus each node of
 * `pre.out_nodes_` that has at least one out-data node which casts to
 * `af::AscNode` and for which `post.ContainsNode` returns false.
 * Out-data nodes that do not cast to `af::AscNode` are ignored.
 *
 * @param pre  cluster whose out nodes are filtered by their consumers.
 * @param post cluster whose out nodes are all kept and whose membership test
 *             decides whether a consumer is external.
 * @return the combined out-node set.
 */
std::unordered_set<af::AscNodePtr> Cluster::CalculateMergedOutNodes(const Cluster &pre, const Cluster &post) {
  std::unordered_set<af::AscNodePtr> merged_out;
  for (const auto &node : post.out_nodes_) {
    merged_out.insert(node);
  }
  for (const auto &node : pre.out_nodes_) {
    for (const auto &out_node : node->GetOutDataNodes()) {
      auto asc_out_node = std::dynamic_pointer_cast<af::AscNode>(out_node);
      if ((asc_out_node != nullptr) && !post.ContainsNode(asc_out_node)) {
        merged_out.insert(node);
        // One external consumer is enough; further consumers would only
        // repeat the same insert.
        break;
      }
    }
  }
  return merged_out;
}
/**
 * Renders a one-line textual summary of this cluster: its id, node count,
 * the `enable_vf` meta flag, the ids of the clusters in `inputs_` and
 * `outputs_`, and the `|`-separated names of its nodes.
 *
 * @return the formatted description.
 */
std::string Cluster::DebugString() const {
  std::ostringstream out;
  out << "id:" << id_;
  out << ", node_size:" << nodes_.size();
  out << ", enable_vf:" << meta_data_.enable_vf << ",";

  out << " inputs:[";
  for (const auto &peer : inputs_) {
    out << peer->id_ << ",";
  }

  out << "] outputs:[";
  for (const auto &peer : outputs_) {
    out << peer->id_ << ",";
  }

  out << "] nodes:|";
  for (const auto &member : nodes_) {
    const auto entry = member->GetName() + "|";
    out << entry;
  }
  return out.str();
}
}