#include "mlir/Analysis/CallGraph.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <memory>
using namespace mlir;
/// Returns true when this node has no callable region attached to it.
bool CallGraphNode::isExternal() const {
  return callableRegion == nullptr;
}
/// Returns the callable region held by this node. Asserts that the node is
/// not external, i.e. that a region is actually present.
Region *CallGraphNode::getCallableRegion() const {
  const bool hasRegion = !isExternal();
  assert(hasRegion && "the external node has no callable region");
  (void)hasRegion; // Silence the unused-variable warning when asserts are off.
  return callableRegion;
}
/// Adds an edge of kind `Abstract` from this node to `node`. Asserts that this
/// node is external.
void CallGraphNode::addAbstractEdge(CallGraphNode *node) {
  const bool external = isExternal();
  assert(external && "abstract edges are only valid on external nodes");
  (void)external; // Silence the unused-variable warning when asserts are off.
  const Edge::Kind kind = Edge::Kind::Abstract;
  addEdge(node, kind);
}
/// Adds an edge of kind `Call` from this node to `node`.
void CallGraphNode::addCallEdge(CallGraphNode *node) {
  const Edge::Kind kind = Edge::Kind::Call;
  addEdge(node, kind);
}
/// Adds an edge of kind `Child` from this node to `child`.
void CallGraphNode::addChildEdge(CallGraphNode *child) {
  const Edge::Kind kind = Edge::Kind::Child;
  addEdge(child, kind);
}
bool CallGraphNode::hasChildren() const {
return llvm::any_of(edges, [](const Edge &edge) { return edge.isChild(); });
}
/// Inserts an edge targeting `node` with the given `kind` into this node's
/// edge set.
void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
  const Edge newEdge = {node, kind};
  edges.insert(newEdge);
}
/// Recursively walks `op` and the operations nested in its regions, recording
/// what it finds into the callgraph `cg`.
///
/// \param op           Operation to process.
/// \param cg           Callgraph that receives nodes and edges.
/// \param symbolTable  Symbol table collection forwarded to
///                     `cg.resolveCallable` when resolving call targets.
/// \param parentNode   Node of the innermost enclosing callable region seen so
///                     far, or nullptr if there is none.
/// \param resolveCalls When false, call operations are skipped without adding
///                     edges; when true, a call edge is added for each call
///                     that has a parent node.
static void computeCallGraph(Operation *op, CallGraph &cg,
                             SymbolTableCollection &symbolTable,
                             CallGraphNode *parentNode, bool resolveCalls) {
  if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) {
    // A call edge is only recorded when edges were requested and there is a
    // node to attribute the call to. Either way, the walk stops here: regions
    // of a call operation are not visited.
    if (resolveCalls && parentNode)
      parentNode->addCallEdge(cg.resolveCallable(call, symbolTable));
    return;
  }

  if (CallableOpInterface callable = dyn_cast<CallableOpInterface>(op)) {
    // A callable with a region becomes the parent node for everything nested
    // inside it; a callable without a region is not descended into.
    if (auto *callableRegion = callable.getCallableRegion())
      parentNode = cg.getOrAddNode(callableRegion, parentNode);
    else
      return;
  }

  // Visit every operation nested in each region of `op`.
  for (Region &region : op->getRegions())
    for (Operation &nested : region.getOps())
      computeCallGraph(&nested, cg, symbolTable, parentNode, resolveCalls);
}
/// Builds the callgraph rooted at `op`. Both special nodes are constructed
/// with a null argument. The operation tree is then walked twice with the same
/// symbol table collection: first with call resolution disabled, then with it
/// enabled.
CallGraph::CallGraph(Operation *op)
    : externalCallerNode(nullptr), unknownCalleeNode(nullptr) {
  SymbolTableCollection symbolTable;
  // Order matters: the pass without call resolution runs first.
  for (bool resolveCalls : {false, true})
    computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr,
                     resolveCalls);
}
/// Returns the node registered for `region`, creating it first if needed.
///
/// A newly created node is linked in exactly once: as a child of `parentNode`
/// when one is supplied, otherwise through an abstract edge from the external
/// caller node. An already-registered node is returned untouched.
CallGraphNode *CallGraph::getOrAddNode(Region *region,
                                       CallGraphNode *parentNode) {
  assert(region && isa<CallableOpInterface>(region->getParentOp()) &&
         "expected parent operation to be callable");

  std::unique_ptr<CallGraphNode> &slot = nodes[region];
  if (slot)
    return slot.get();

  slot.reset(new CallGraphNode(region));
  CallGraphNode *created = slot.get();
  if (parentNode)
    parentNode->addChildEdge(created);
  else
    externalCallerNode.addAbstractEdge(created);
  return created;
}
/// Returns the node registered for `region`, or nullptr if there is none.
CallGraphNode *CallGraph::lookupNode(Region *region) const {
  auto found = nodes.find(region);
  if (found == nodes.end())
    return nullptr;
  return found->second.get();
}
/// Maps the target of `call` to a callgraph node.
///
/// Returns the node registered for the resolved callable's region. Falls back
/// to the unknown-callee node when the call does not resolve to a
/// `CallableOpInterface` operation or when no node is registered for its
/// region.
CallGraphNode *
CallGraph::resolveCallable(CallOpInterface call,
                           SymbolTableCollection &symbolTable) const {
  Operation *target = call.resolveCallable(&symbolTable);
  auto callableOp = dyn_cast_or_null<CallableOpInterface>(target);
  if (!callableOp)
    return getUnknownCalleeNode();

  CallGraphNode *resolved = lookupNode(callableOp.getCallableRegion());
  if (resolved)
    return resolved;
  return getUnknownCalleeNode();
}
/// Erases `node` from the callgraph.
///
/// Every node reachable from `node` through child edges is erased first
/// (recursively). Afterwards all edges targeting `node` are removed from the
/// edge lists of the nodes stored in `nodes`, and finally the entry for
/// `node`'s callable region is dropped from `nodes`.
///
/// NOTE(review): only the edge lists of nodes stored in `nodes` are scanned;
/// confirm whether edges held elsewhere (e.g. by the external caller node)
/// need separate cleanup.
void CallGraph::eraseNode(CallGraphNode *node) {
  // Snapshot the child targets before erasing anything. Each recursive
  // eraseNode() call below runs remove_if over the edge list of every node in
  // `nodes`; when `node` is one of them, that mutates the very list we would
  // otherwise still be iterating, which can skip children or revisit stale
  // slots.
  llvm::SmallVector<CallGraphNode *, 4> children;
  for (const CallGraphNode::Edge &edge : *node)
    if (edge.isChild())
      children.push_back(edge.getTarget());
  for (CallGraphNode *child : children)
    eraseNode(child);

  // Drop any edge that still points at the node being erased.
  for (auto &it : nodes) {
    it.second->edges.remove_if([node](const CallGraphNode::Edge &edge) {
      return edge.getTarget() == node;
    });
  }
  nodes.erase(node->getCallableRegion());
}
void CallGraph::dump() const { print(llvm::errs()); }
/// Writes a textual description of the callgraph to `os`: first every node in
/// `nodes` together with its edges, then the strongly connected components
/// produced by the SCC iterator over this graph.
void CallGraph::print(raw_ostream &os) const {
  os << "// ---- CallGraph ----\n";

  // Writes a label for `target`: a fixed tag for the two special nodes,
  // otherwise the parent operation's name, the region number and, when
  // non-empty, the parent operation's attribute dictionary.
  auto emitNodeName = [&](const CallGraphNode *target) {
    if (target == getExternalCallerNode()) {
      os << "<External-Caller-Node>";
      return;
    }
    if (target == getUnknownCalleeNode()) {
      os << "<Unknown-Callee-Node>";
      return;
    }

    auto *callableRegion = target->getCallableRegion();
    auto *parentOp = callableRegion->getParentOp();
    os << "'" << parentOp->getName() << "' - Region #"
       << callableRegion->getRegionNumber();
    auto attrs = parentOp->getAttrDictionary();
    if (!attrs.empty())
      os << " : " << attrs;
  };

  // Text placed in front of "-Edge"; empty for edges that are neither call
  // nor child edges.
  auto edgeKindLabel = [](const CallGraphNode::Edge &edge) -> const char * {
    if (edge.isCall())
      return "Call";
    if (edge.isChild())
      return "Child";
    return "";
  };

  for (auto &entry : nodes) {
    const CallGraphNode *current = entry.second.get();

    os << "// - Node : ";
    emitNodeName(current);
    os << "\n";

    for (auto &edge : *current) {
      os << "// -- " << edgeKindLabel(edge) << "-Edge : ";
      emitNodeName(edge.getTarget());
      os << "\n";
    }
    os << "//\n";
  }

  os << "// -- SCCs --\n";

  for (auto &scc : make_range(llvm::scc_begin(this), llvm::scc_end(this))) {
    os << "// - SCC : \n";
    for (auto &member : scc) {
      os << "// -- Node :";
      emitNodeName(member);
      os << "\n";
    }
    os << "\n";
  }

  os << "// -------------------\n";
}