* Copyright 2026 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mfusion/Analysis/Cluster/Graph.h"
#include <algorithm>
#include <queue>
#include "llvm/Support/Debug.h"
#include "mlir/IR/OpDefinition.h"
#define DEBUG_TYPE "analysis-graph"
namespace mlir {
Graph::Cluster::Cluster(size_t node_id, Operation *op, const llvm::DenseMap<Operation *, size_t> &op_idx_map)
: cluster_id_(node_id), max_id_(node_id) {
for (auto operand : op->getOperands()) {
if (auto *def_op = operand.getDefiningOp()) {
auto it = op_idx_map.find(def_op);
if (it != op_idx_map.end()) {
inputs_.insert(it->second);
}
}
}
}
void Graph::Cluster::Merge(Cluster *other) {
other->cluster_id_ = cluster_id_;
max_id_ = std::max(max_id_, other->max_id_);
cluster_size_ += other->cluster_size_;
inputs_.insert(other->inputs_.cbegin(), other->inputs_.cend());
other->Clean();
}
std::unique_ptr<Graph> Graph::Build(Block *block, std::vector<Operation *> *ops,
llvm::DenseMap<Operation *, size_t> *op_idx_map, bool aggressive_cut) {
std::vector<Operation *> local_ops;
llvm::DenseMap<Operation *, size_t> local_map;
for (Operation &op : block->getOperations()) {
if (op.hasTrait<OpTrait::IsTerminator>()) {
continue;
}
local_map[&op] = local_ops.size();
local_ops.push_back(&op);
}
auto graph = std::unique_ptr<Graph>(new Graph(local_ops, local_map, aggressive_cut));
if (ops != nullptr) {
*ops = std::move(local_ops);
}
if (op_idx_map != nullptr) {
*op_idx_map = std::move(local_map);
}
return graph;
}
Graph::Graph(const std::vector<Operation *> &ops, const llvm::DenseMap<Operation *, size_t> &op_idx_map,
bool aggressive_cut)
: ops_(ops) {
clusters_.reserve(ops.size());
for (size_t i = 0; i < ops.size(); ++i) {
clusters_.emplace_back(i, ops[i], op_idx_map);
}
if (!aggressive_cut) {
bitmax_ = std::make_unique<BitMax>(ops.size());
}
}
size_t Graph::Find(size_t node_id) {
size_t &pre_id = clusters_[node_id].cluster_id_;
return (pre_id == clusters_[pre_id].cluster_id_) ? pre_id : (pre_id = Find(pre_id));
}
void Graph::Merge(const std::vector<size_t> &candidates) {
size_t min_id = *std::min_element(candidates.begin(), candidates.end());
for (auto id : candidates) {
if (id == min_id) {
continue;
}
clusters_[min_id].Merge(&clusters_[id]);
}
if (bitmax_ != nullptr) {
bitmax_->SetMax(min_id, clusters_[min_id].max_id_);
}
}
std::vector<std::vector<size_t>> Graph::CollectClusters() {
std::vector<std::vector<size_t>> cluster_map(clusters_.size());
for (size_t i = 0; i < clusters_.size(); ++i) {
cluster_map[Find(i)].push_back(i);
}
return cluster_map;
}
void Graph::Dfs(size_t node_id, const VisitFunc &visitor) {
++seen_;
DepthFirstSearch(Find(node_id), visitor);
}
size_t Graph::GetSize(size_t cluster_id) { return clusters_[Find(cluster_id)].cluster_size_; }
size_t Graph::GetMaxIdWithCutStrategy(size_t cluster_id) {
return (bitmax_ == nullptr) ? GetMaxId(cluster_id) : bitmax_->FindMax(Find(cluster_id));
}
size_t Graph::GetMaxId(size_t cluster_id) { return clusters_[Find(cluster_id)].max_id_; }
const std::set<size_t> &Graph::GetInputs(size_t cluster_id) {
cluster_id = Find(cluster_id);
RefreshInputs(cluster_id);
return clusters_[cluster_id].inputs_;
}
void Graph::RefreshInputs(size_t i) {
auto &inputs = clusters_[i].inputs_;
std::set<size_t> new_inputs;
for (auto it = inputs.cbegin(); it != inputs.cend(); ++it) {
size_t new_id = Find(*it);
if (new_id != i) {
new_inputs.insert(new_id);
}
}
inputs = std::move(new_inputs);
}
void Graph::DepthFirstSearch(size_t cluster_id, const VisitFunc &visitor) {
if (clusters_[cluster_id].seen_ >= seen_) {
return;
}
clusters_[cluster_id].seen_ = seen_;
auto result = visitor(cluster_id);
if (result != VisitResult::kFollow) {
return;
}
const auto &inputs = GetInputs(cluster_id);
for (auto it = inputs.crbegin(); it != inputs.crend(); ++it) {
DepthFirstSearch(*it, visitor);
}
}
bool Graph::HasCircle() {
std::vector<size_t> valid_clusters;
for (size_t i = 0; i < clusters_.size(); ++i) {
if (clusters_[i].cluster_id_ == i) {
valid_clusters.push_back(i);
}
}
std::vector<int> out_degree(clusters_.size(), 0);
std::queue<size_t> que;
for (auto &cluster_id : valid_clusters) {
for (size_t i : GetInputs(cluster_id)) {
out_degree[i]++;
}
}
for (auto &cluster_id : valid_clusters) {
if (out_degree[cluster_id] == 0) {
que.push(cluster_id);
}
}
size_t count = 0;
while (!que.empty()) {
size_t u = que.front();
que.pop();
count++;
for (size_t i : GetInputs(u)) {
if (--out_degree[i] == 0) {
que.push(i);
}
}
}
return count != valid_clusters.size();
}
void CircleChecker::RemoveCircle(std::vector<size_t> *candidates) {
if (candidates->size() <= 1) {
return;
}
candidates_.clear();
candidates_.insert(candidates->cbegin(), candidates->cend());
for (auto iter = candidates->cbegin(); iter != candidates->cend(); ++iter) {
if (candidates_.count(*iter) == 0) {
continue;
}
circle_nodes_.clear();
if (CheckCircle(*iter)) {
RemoveCircleNodesFromCandidates();
}
}
candidates->erase(std::remove_if(candidates->begin(), candidates->end(),
[this](size_t c) { return this->candidates_.count(c) == 0; }),
candidates->end());
}
bool CircleChecker::CheckCircle(size_t basenode) {
const auto &inputs = graph_->GetInputs(basenode);
std::set<size_t> visited_circle_nodes;
for (auto x : inputs) {
if (candidates_.count(x) > 0) {
continue;
}
bool has_circle = false;
std::set<size_t> done;
auto candidate_min = *candidates_.begin();
auto vis_func = [this, &has_circle, &done, &visited_circle_nodes, &candidate_min](size_t cluster_id) {
if (graph_->GetMaxIdWithCutStrategy(cluster_id) < candidate_min) {
return VisitResult::kExclude;
}
if (done.count(cluster_id) > 0 || acyclic_nodes_.count(cluster_id) > 0 ||
visited_circle_nodes.count(cluster_id) > 0) {
return VisitResult::kExclude;
}
done.insert(cluster_id);
if (candidates_.count(cluster_id) > 0) {
has_circle = true;
circle_nodes_.push_back(cluster_id);
return VisitResult::kExclude;
}
return VisitResult::kFollow;
};
graph_->Dfs(x, vis_func);
if (has_circle) {
visited_circle_nodes.insert(done.cbegin(), done.cend());
} else {
acyclic_nodes_.insert(done.cbegin(), done.cend());
}
}
return !circle_nodes_.empty();
}
void CircleChecker::RemoveCircleNodesFromCandidates() {
auto remove_from_candidates = [this](size_t node_id) {
if (candidates_.count(node_id) > 0) {
candidates_.erase(node_id);
return VisitResult::kFollow;
}
return VisitResult::kExclude;
};
for (auto node : circle_nodes_) {
graph_->Dfs(node, remove_from_candidates);
}
}
}