#include "mlir/Transforms/Inliner.h"
#include "mlir/IR/Threading.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/DebugStringHelper.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "inlining"
using namespace mlir;
using ResolvedCall = Inliner::ResolvedCall;
/// Invokes `callback` for every symbol use reported by
/// `SymbolTable::getSymbolUses(op)` whose referenced symbol resolves to a
/// callable operation that owns a node in `cg`.
///
/// `resolvedRefs` caches the node found for each symbol reference attribute,
/// so a given reference is looked up in `symbolTable` at most once per cache.
/// References that do not resolve to a callgraph node are cached as null and
/// never reach `callback`. The callback receives the resolved node and the
/// operation that holds the use.
static void walkReferencedSymbolNodes(
    Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
    DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
    function_ref<void(CallGraphNode *, Operation *)> callback) {
  auto symbolUses = SymbolTable::getSymbolUses(op);
  assert(symbolUses && "expected uses to be valid");

  // Symbol lookups start from the operation enclosing `op`.
  Operation *lookupScope = op->getParentOp();
  for (const SymbolTable::SymbolUse &use : *symbolUses) {
    auto symbolRef = use.getSymbolRef();
    auto insertion = resolvedRefs.insert({symbolRef, nullptr});
    CallGraphNode *&cachedNode = insertion.first->second;
    const bool firstSighting = insertion.second;

    // Only the first occurrence of a reference pays for the symbol lookup.
    if (firstSighting) {
      auto *definingOp =
          symbolTable.lookupNearestSymbolFrom(lookupScope, symbolRef);
      if (auto callable = dyn_cast_or_null<CallableOpInterface>(definingOp))
        cachedNode = cg.lookupNode(callable.getCallableRegion());
    }

    if (cachedNode)
      callback(cachedNode, use.getUser());
  }
}
namespace {
/// Tracks symbol uses of callgraph nodes so the inliner can tell when a
/// callable has no remaining uses or exactly one remaining use.
struct CGUseList {
/// The uses held by a single callgraph node.
struct CGUser {
/// Nodes referenced by the callable operation itself (the symbol use's user
/// is the operation owning the callable region). Each node is counted once.
DenseSet<CallGraphNode *> topLevelUses;
/// Nodes referenced by any other operation reached while walking the callable
/// operation's symbol uses, with the number of such references.
DenseMap<CallGraphNode *, int> innerUses;
};
/// Populates the set of discardable nodes found in the symbol tables walked
/// from `op`, then computes the uses held by every node in `cg`.
CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);
/// Drops the uses that `callOp` contributes to `userNode`'s inner-use counts.
void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);
/// Removes `node` and, recursively, the targets of its child edges from the
/// use list, releasing the uses they held.
void eraseNode(CallGraphNode *node);
/// Returns true if `node` has no uses: SSA-use-empty and memory-effect free
/// for non-symbol callables, or a tracked use count of zero for symbols.
bool isDead(CallGraphNode *node) const;
/// Returns true if `node` has exactly one use: a single SSA use and memory
/// effect free for non-symbol callables, or a tracked use count of one.
bool hasOneUseAndDiscardable(CallGraphNode *node) const;
/// Rebuilds the uses held by `node` from the current state of its operation.
void recomputeUses(CallGraphNode *node, CallGraph &cg);
/// Adds the inner uses held by `lhs` to `rhs` and to the global use counts.
void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);
private:
/// Subtracts every use recorded in `uses` from the global use counts.
void decrementDiscardableUses(CGUser &uses);
/// Use count for each node considered discardable. A node that is absent
/// from this map is never reported as dead or single-use by the symbol path.
DenseMap<CallGraphNode *, int> discardableSymNodeUses;
/// The uses held by each callgraph node.
DenseMap<CallGraphNode *, CGUser> nodeUses;
/// Symbol table collection used when resolving symbol references.
SymbolTableCollection &symbolTable;
};
}
/// Seeds the set of discardable callgraph nodes from the symbol tables walked
/// from `op`, removes every node referenced by a non-callgraph operation, and
/// finally computes the uses held by each node of `cg`.
CGUseList::CGUseList(Operation *op, CallGraph &cg,
                     SymbolTableCollection &symbolTable)
    : symbolTable(symbolTable) {
  // Nodes referenced from operations that are not themselves callgraph nodes.
  // They are removed from the discardable set below.
  DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;

  auto scanSymbolTable = [&](Operation *symbolTableOp, bool allUsesVisible) {
    for (Operation &nestedOp : symbolTableOp->getRegion(0).getOps()) {
      // Find the callgraph node owned by this operation, if any.
      CallGraphNode *callableNode = nullptr;
      if (auto callable = dyn_cast<CallableOpInterface>(&nestedOp))
        callableNode = cg.lookupNode(callable.getCallableRegion());

      // Not a callgraph operation: only its symbol references matter here.
      if (!callableNode) {
        walkReferencedSymbolNodes(&nestedOp, cg, symbolTable, alwaysLiveNodes,
                                  [](CallGraphNode *, Operation *) {});
        continue;
      }

      // A callgraph operation is discardable when it is a symbol whose uses
      // are all visible (or which is private) and which may be dropped once
      // it has no uses.
      auto symbol = dyn_cast<SymbolOpInterface>(&nestedOp);
      if (!symbol)
        continue;
      if (!allUsesVisible && !symbol.isPrivate())
        continue;
      if (symbol.canDiscardOnUseEmpty())
        discardableSymNodeUses.try_emplace(callableNode, 0);
    }
  };

  // The second argument is true only when `op` is not nested in a block.
  const bool rootHasNoBlock = !op->getBlock();
  SymbolTable::walkSymbolTables(op, rootHasNoBlock, scanSymbolTable);

  // Anything referenced from outside the callgraph can never be discarded.
  for (auto &liveEntry : alwaysLiveNodes)
    discardableSymNodeUses.erase(liveEntry.second);

  // Compute the uses held by every node of the graph.
  for (CallGraphNode *graphNode : cg)
    recomputeUses(graphNode, cg);
}
/// Removes the uses contributed by `callOp` from the inner-use counts of
/// `userNode` and from the global use counts. References to nodes that
/// `userNode` does not track as inner uses are ignored.
void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
                             CallGraph &cg) {
  DenseMap<CallGraphNode *, int> &innerUseCounts =
      nodeUses[userNode].innerUses;
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  walkReferencedSymbolNodes(
      callOp, cg, symbolTable, resolvedRefs,
      [&](CallGraphNode *referenced, Operation *) {
        auto countIt = innerUseCounts.find(referenced);
        if (countIt != innerUseCounts.end()) {
          --countIt->second;
          --discardableSymNodeUses[referenced];
        }
      });
}
/// Erases `node` from the use list after first erasing, recursively, every
/// node reached through one of its child edges. The uses held by each erased
/// node are released from the global use counts.
void CGUseList::eraseNode(CallGraphNode *node) {
  // Children go first.
  for (const CallGraphNode::Edge &edge : *node) {
    if (!edge.isChild())
      continue;
    eraseNode(edge.getTarget());
  }

  // Release what this node was holding, then forget about it entirely.
  auto entry = nodeUses.find(node);
  assert(entry != nodeUses.end() && "expected node to be valid");
  decrementDiscardableUses(entry->second);
  nodeUses.erase(entry);
  discardableSymNodeUses.erase(node);
}
/// Returns true if `node` has no remaining uses.
///
/// Symbol callables are dead only when they are tracked as discardable and
/// their use count is zero. Any other callable is dead when its operation is
/// memory-effect free and has no SSA uses.
bool CGUseList::isDead(CallGraphNode *node) const {
  Operation *callableOp = node->getCallableRegion()->getParentOp();
  if (isa<SymbolOpInterface>(callableOp)) {
    auto countIt = discardableSymNodeUses.find(node);
    if (countIt == discardableSymNodeUses.end())
      return false;
    return countIt->second == 0;
  }
  return isMemoryEffectFree(callableOp) && callableOp->use_empty();
}
/// Returns true if `node` has exactly one remaining use.
///
/// Symbol callables qualify only when they are tracked as discardable and
/// their use count is one. Any other callable qualifies when its operation is
/// memory-effect free and has a single SSA use.
bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
  Operation *callableOp = node->getCallableRegion()->getParentOp();
  if (isa<SymbolOpInterface>(callableOp)) {
    auto countIt = discardableSymNodeUses.find(node);
    if (countIt == discardableSymNodeUses.end())
      return false;
    return countIt->second == 1;
  }
  return isMemoryEffectFree(callableOp) && callableOp->hasOneUse();
}
/// Recomputes the uses held by `node`: the previously recorded uses are
/// released from the global counts, then the symbol uses of the node's
/// operation are walked again to rebuild them.
void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
  Operation *callableOp = node->getCallableRegion()->getParentOp();

  // Release whatever was recorded before and start from a clean slate.
  CGUser &held = nodeUses[node];
  decrementDiscardableUses(held);
  held = CGUser();

  auto recordUse = [&](CallGraphNode *referenced, Operation *user) {
    // Only discardable nodes are tracked.
    auto countIt = discardableSymNodeUses.find(referenced);
    if (countIt == discardableSymNodeUses.end())
      return;

    if (user == callableOp) {
      // A reference held by the callable operation itself counts once.
      const bool firstTopLevelUse = held.topLevelUses.insert(referenced).second;
      if (!firstTopLevelUse)
        return;
    } else {
      ++held.innerUses[referenced];
    }
    ++countIt->second;
  };

  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  walkReferencedSymbolNodes(callableOp, cg, symbolTable, resolvedRefs,
                            recordUse);
}
/// Adds every inner use held by `lhs` to the inner uses of `rhs`, and bumps
/// the global use count of each referenced node by the same amount.
void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
  CGUser &inlinedUses = nodeUses[lhs];
  CGUser &receiverUses = nodeUses[rhs];
  for (auto &entry : inlinedUses.innerUses) {
    CallGraphNode *referenced = entry.first;
    // The count is re-read from the entry for each update on purpose.
    receiverUses.innerUses[referenced] += entry.second;
    discardableSymNodeUses[referenced] += entry.second;
  }
}
/// Subtracts the uses recorded in `uses` from the global use counts: one per
/// top-level use, and the recorded count per inner use.
void CGUseList::decrementDiscardableUses(CGUser &uses) {
  for (CallGraphNode *referenced : uses.topLevelUses)
    discardableSymNodeUses[referenced] -= 1;
  for (const auto &entry : uses.innerUses)
    discardableSymNodeUses[entry.first] -= entry.second;
}
namespace {
class CallGraphSCC {
public:
CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
: parentIterator(parentIterator) {}
std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }
void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
void remove(CallGraphNode *node) {
auto it = llvm::find(nodes, node);
if (it != nodes.end()) {
nodes.erase(it);
parentIterator.ReplaceNode(node, nullptr);
}
}
private:
std::vector<CallGraphNode *> nodes;
llvm::scc_iterator<const CallGraph *> &parentIterator;
};
}
/// Runs `sccTransformer` on each SCC of `cg` in the order produced by
/// `llvm::scc_begin`, stopping at and returning the first failure.
static LogicalResult runTransformOnCGSCCs(
    const CallGraph &cg,
    function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
  llvm::scc_iterator<const CallGraph *> sccIt = llvm::scc_begin(&cg);
  CallGraphSCC workingSCC(sccIt);
  while (!sccIt.isAtEnd()) {
    // Take a copy of the SCC and advance before handing it to the
    // transformer, which receives the copy rather than the iterator's view.
    workingSCC.reset(*sccIt);
    ++sccIt;

    LogicalResult status = sccTransformer(workingSCC);
    if (failed(status))
      return status;
  }
  return success();
}
/// Collects the call operations found in `blocks` into `calls`.
///
/// Each collected call is recorded together with the callgraph node it was
/// found in and the node it resolves to. Calls resolving to an external node
/// and calls through a non-flat symbol reference are skipped. Regions nested
/// under non-call operations are traversed as well; a nested region that owns
/// its own callgraph node is only entered when `traverseNestedCGNodes` is
/// true, and calls inside it are attributed to that nested node.
static void collectCallOps(iterator_range<Region::iterator> blocks,
                           CallGraphNode *sourceNode, CallGraph &cg,
                           SymbolTableCollection &symbolTable,
                           SmallVectorImpl<ResolvedCall> &calls,
                           bool traverseNestedCGNodes) {
  using WorkItem = std::pair<Block *, CallGraphNode *>;
  SmallVector<WorkItem, 8> pending;
  auto enqueue = [&pending](CallGraphNode *owner,
                            iterator_range<Region::iterator> range) {
    for (Block &queuedBlock : range)
      pending.emplace_back(&queuedBlock, owner);
  };

  enqueue(sourceNode, blocks);
  while (!pending.empty()) {
    WorkItem item = pending.pop_back_val();
    Block *currentBlock = item.first;
    CallGraphNode *ownerNode = item.second;

    for (Operation &op : *currentBlock) {
      auto call = dyn_cast<CallOpInterface>(op);

      // Non-call operations: descend into their regions.
      if (!call) {
        for (auto &nestedRegion : op.getRegions()) {
          CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
          if (nestedNode && !traverseNestedCGNodes)
            continue;
          enqueue(nestedNode ? nestedNode : ownerNode, nestedRegion);
        }
        continue;
      }

      // Calls through nested (non-flat) symbol references are not collected.
      CallInterfaceCallable callee = call.getCallableForCallee();
      if (SymbolRefAttr symRef = dyn_cast<SymbolRefAttr>(callee)) {
        if (!isa<FlatSymbolRefAttr>(symRef))
          continue;
      }

      CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
      if (targetNode->isExternal())
        continue;
      calls.emplace_back(call, ownerNode, targetNode);
    }
  }
}
#ifndef NDEBUG
/// Returns a printable name for the callee of `op`, used in debug output:
/// the debug string of the call when its callee is a symbol reference,
/// otherwise a fixed placeholder.
static std::string getNodeName(CallOpInterface op) {
  const bool calleeIsSymbol = static_cast<bool>(
      llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee()));
  if (!calleeIsSymbol)
    return "_unnamed_callee_";
  return debugString(op);
}
#endif
/// Returns true if `node` appears in the inline-history chain that starts at
/// `inlineHistoryID`.
///
/// Each entry of `inlineHistory` pairs a node with the ID of its parent
/// entry. The chain is followed through the parent IDs until an entry with no
/// parent is reached.
static bool inlineHistoryIncludes(
    CallGraphNode *node, std::optional<size_t> inlineHistoryID,
    MutableArrayRef<std::pair<CallGraphNode *, std::optional<size_t>>>
        inlineHistory) {
  for (std::optional<size_t> cursor = inlineHistoryID; cursor.has_value();
       cursor = inlineHistory[*cursor].second) {
    assert(*cursor < inlineHistory.size() &&
           "Invalid inline history ID");
    if (inlineHistory[*cursor].first == node)
      return true;
  }
  return false;
}
namespace {
struct InlinerInterfaceImpl : public InlinerInterface {
InlinerInterfaceImpl(MLIRContext *context, CallGraph &cg,
SymbolTableCollection &symbolTable)
: InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}
void
processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
CallGraphNode *node;
Region *region = inlinedBlocks.begin()->getParent();
while (!(node = cg.lookupNode(region))) {
region = region->getParentRegion();
assert(region && "expected valid parent node");
}
collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
true);
}
void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }
void eraseDeadCallables() {
for (CallGraphNode *node : deadNodes)
node->getCallableRegion()->getParentOp()->erase();
}
SmallPtrSet<CallGraphNode *, 8> deadNodes;
SmallVector<ResolvedCall, 8> calls;
CallGraph &cg;
SymbolTableCollection &symbolTable;
};
}
namespace mlir {
/// Private implementation of the inliner driver. Holds a reference to the
/// owning `Inliner` and the pool of per-thread pass pipelines used when
/// optimizing callables.
class Inliner::Impl {
public:
  Impl(Inliner &inliner) : inliner(inliner) {}

  /// Alternates between optimizing the callables of `currentSCC` and inlining
  /// the calls within it, for at most the configured number of iterations.
  /// Fails only if optimization fails.
  LogicalResult inlineSCC(InlinerInterfaceImpl &inlinerIface,
                          CGUseList &useList, CallGraphSCC &currentSCC,
                          MLIRContext *context);

private:
  /// Runs the optimization pipelines on the eligible nodes of `currentSCC`
  /// and recomputes the uses they hold.
  LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
                            CallGraphSCC &currentSCC, MLIRContext *context);

  /// Optimizes `nodesToVisit` through `failableParallelForEach`, giving each
  /// invocation its own entry of `pipelines`.
  LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
                                 MLIRContext *context);

  /// Runs the pipeline registered for the callable's operation name, or one
  /// built from the default pipeline callback, on the callable of `node`.
  LogicalResult optimizeCallable(CallGraphNode *node,
                                 llvm::StringMap<OpPassManager> &pipelines);

  /// Attempts to inline the calls collected from `currentSCC`. Returns
  /// success if at least one call was inlined.
  LogicalResult inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
                                 CGUseList &useList, CallGraphSCC &currentSCC);

  /// Returns true if `resolvedCall` is legal and profitable to inline.
  bool shouldInline(ResolvedCall &resolvedCall);

private:
  /// The inliner this implementation works for.
  Inliner &inliner;

  /// One set of op-name-keyed pipelines per thread, grown on demand by
  /// `optimizeSCCAsync`.
  llvm::SmallVector<llvm::StringMap<OpPassManager>> pipelines;
};
/// Repeatedly optimizes and then inlines within `currentSCC`.
///
/// The loop ends when an inlining round inlines nothing, or when the
/// configured maximum number of iterations has run. Returns failure only when
/// optimizing the SCC fails; an inlining round that inlines nothing is not an
/// error.
LogicalResult Inliner::Impl::inlineSCC(InlinerInterfaceImpl &inlinerIface,
                                       CGUseList &useList,
                                       CallGraphSCC &currentSCC,
                                       MLIRContext *context) {
  unsigned iterationCount = 0;
  do {
    if (failed(optimizeSCC(inlinerIface.cg, useList, currentSCC, context)))
      return failure();
    // `inlineCallsInSCC` reports failure when no call was inlined, which
    // means there is nothing new to optimize in another round.
    if (failed(inlineCallsInSCC(inlinerIface, useList, currentSCC)))
      break;
  } while (++iterationCount < inliner.config.getMaxInliningIterations());
  return success();
}
/// Runs the optimization pipelines on the eligible nodes of `currentSCC` and
/// then recomputes the uses held by each optimized node.
///
/// A node is eligible when it is not external, has no child nodes, and its
/// callable operation is isolated from above. Returns failure if optimizing
/// any node fails.
LogicalResult Inliner::Impl::optimizeSCC(CallGraph &cg, CGUseList &useList,
                                         CallGraphSCC &currentSCC,
                                         MLIRContext *context) {
  // Collect the nodes to optimize.
  SmallVector<CallGraphNode *, 4> nodesToVisit;
  for (auto *node : currentSCC) {
    if (node->isExternal())
      continue;
    if (node->hasChildren())
      continue;
    auto *region = node->getCallableRegion();
    if (!region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
      continue;
    nodesToVisit.push_back(node);
  }
  if (nodesToVisit.empty())
    return success();

  if (failed(optimizeSCCAsync(nodesToVisit, context)))
    return failure();

  // The pipelines may have changed the symbol references held by each node.
  for (CallGraphNode *node : nodesToVisit)
    useList.recomputeUses(node, cg);
  return success();
}
/// Optimizes every node of `nodesToVisit` via `failableParallelForEach`.
///
/// The pool of pipelines is grown to at least the context's thread count, and
/// each invocation claims one unused pool entry for the duration of its run.
LogicalResult
Inliner::Impl::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
                                MLIRContext *ctx) {
  // Grow the pool so there is one pipeline set per thread; it never shrinks.
  const size_t threadCount = ctx->getNumThreads();
  const auto &opPipelines = inliner.config.getOpPipelines();
  if (pipelines.size() < threadCount) {
    pipelines.reserve(threadCount);
    pipelines.resize(threadCount, opPipelines);
  }

  // Nest an analysis manager for each callable before the parallel loop.
  for (CallGraphNode *node : nodesToVisit) {
    Operation *callableOp = node->getCallableRegion()->getParentOp();
    inliner.am.nest(callableOp);
  }

  // One "in use" flag per pool entry, all initially clear.
  std::vector<std::atomic<bool>> pmInUse(pipelines.size());
  for (std::atomic<bool> &flag : pmInUse)
    flag = false;

  auto optimizeNode = [&](CallGraphNode *node) {
    // Claim the first pool entry whose flag can be flipped from false to true.
    auto claimed = llvm::find_if(pmInUse, [](std::atomic<bool> &inUse) {
      bool expectedUnused = false;
      return inUse.compare_exchange_strong(expectedUnused, true);
    });
    assert(claimed != pmInUse.end() &&
           "could not find inactive pass manager for thread");
    const unsigned slot = claimed - pmInUse.begin();

    LogicalResult outcome = optimizeCallable(node, pipelines[slot]);

    // Hand the pool entry back.
    pmInUse[slot].store(false);
    return outcome;
  };
  return failableParallelForEach(ctx, nodesToVisit, optimizeNode);
}
/// Runs an optimization pipeline on the callable operation of `node`.
///
/// The pipeline is looked up in `pipelines` by the operation's name. When
/// none is registered, one is built from the configured default pipeline
/// callback and cached in `pipelines`; if there is no default callback
/// either, nothing is run and success is returned.
LogicalResult
Inliner::Impl::optimizeCallable(CallGraphNode *node,
                                llvm::StringMap<OpPassManager> &pipelines) {
  Operation *callable = node->getCallableRegion()->getParentOp();
  StringRef opName = callable->getName().getStringRef();
  auto registered = pipelines.find(opName);
  const auto &defaultPipeline = inliner.config.getDefaultPipeline();

  // Fast path: a pipeline for this operation name already exists.
  if (registered != pipelines.end())
    return inliner.runPipelineHelper(inliner.pass, registered->second,
                                     callable);

  if (!defaultPipeline)
    return success();

  // Build the default pipeline for this operation name and remember it.
  OpPassManager freshPM(opName);
  defaultPipeline(freshPM);
  auto cached = pipelines.try_emplace(opName, std::move(freshPM)).first;
  return inliner.runPipelineHelper(inliner.pass, cached->second, callable);
}
/// Attempts to inline the calls found in the nodes of `currentSCC`.
///
/// Calls are collected from every live, non-external node of the SCC. Calls
/// that inlining itself introduces are appended to the same list and
/// processed in the same pass. An inline history records which callee each
/// such call came from, so a callee is not inlined again into code that
/// originated from it. Nodes that are dead on entry, or whose last use was
/// inlined in place, are removed from the SCC and marked for deletion.
///
/// Returns success if at least one call was inlined, failure otherwise.
LogicalResult
Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
                                CGUseList &useList, CallGraphSCC &currentSCC) {
  CallGraph &cg = inlinerIface.cg;
  auto &calls = inlinerIface.calls;

  // Nodes to remove from the SCC once inlining is done.
  llvm::SmallSetVector<CallGraphNode *, 1> deadNodes;

  // Collect the calls of each live node. Nested callgraph nodes are not
  // traversed here.
  for (CallGraphNode *node : currentSCC) {
    if (node->isExternal())
      continue;
    if (useList.isDead(node)) {
      deadNodes.insert(node);
    } else {
      collectCallOps(*node->getCallableRegion(), node, cg,
                     inlinerIface.symbolTable, calls,
                     /*traverseNestedCGNodes=*/false);
    }
  }

  // `inlineHistory` holds (callee, parent history ID) entries; `callHistory`
  // holds, for each entry of `calls`, the history ID it was created under
  // (empty for calls that were present before any inlining).
  using InlineHistoryT = std::optional<size_t>;
  SmallVector<std::pair<CallGraphNode *, InlineHistoryT>, 8> inlineHistory;
  std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});

  LLVM_DEBUG({
    llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n";
    for (unsigned i = 0, e = calls.size(); i < e; ++i)
      llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n";
    llvm::dbgs() << "}\n";
  });

  // The bound is re-read on every iteration because inlining appends to
  // `calls`.
  bool inlinedAnyCalls = false;
  for (unsigned i = 0; i < calls.size(); ++i) {
    if (deadNodes.contains(calls[i].sourceNode))
      continue;

    // Work on a copy: `calls` may grow while this call is being inlined.
    ResolvedCall it = calls[i];

    InlineHistoryT inlineHistoryID = callHistory[i];
    bool inHistory =
        inlineHistoryIncludes(it.targetNode, inlineHistoryID, inlineHistory);
    bool doInline = !inHistory && shouldInline(it);
    CallOpInterface call = it.call;
    LLVM_DEBUG({
      if (doInline)
        llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n";
      else
        llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n";
    });
    if (!doInline)
      continue;

    unsigned prevSize = calls.size();
    Region *targetRegion = it.targetNode->getCallableRegion();

    // When this is the only remaining use of a discardable callee, inline its
    // region in place instead of cloning it; the callee is dropped afterwards.
    bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);

    LogicalResult inlineResult =
        inlineCall(inlinerIface, call,
                   cast<CallableOpInterface>(targetRegion->getParentOp()),
                   targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
    if (failed(inlineResult)) {
      LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
      continue;
    }
    inlinedAnyCalls = true;

    // Record that any new call sites came from inlining this callee.
    InlineHistoryT newInlineHistoryID{inlineHistory.size()};
    inlineHistory.push_back(std::make_pair(it.targetNode, inlineHistoryID));

    auto historyToString = [](InlineHistoryT h) {
      return h.has_value() ? std::to_string(*h) : "root";
    };
    (void)historyToString;
    LLVM_DEBUG(llvm::dbgs()
               << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
               << getNodeName(call) << ", " << historyToString(inlineHistoryID)
               << "]\n");

    // Every call appended to `calls` during this inlining inherits the new
    // history entry. The debug output names the new call at index `k`.
    for (unsigned k = prevSize; k != calls.size(); ++k) {
      callHistory.push_back(newInlineHistoryID);
      LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[k].call
                              << "}\n with historyID = " << newInlineHistoryID
                              << ", added due to inlining of\n call {" << call
                              << "}\n with historyID = "
                              << historyToString(inlineHistoryID) << "\n");
    }

    // Move the callee's uses over to the caller, then erase the call.
    useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
    useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);
    call.erase();

    // A callee inlined in place has no remaining use.
    if (inlineInPlace) {
      useList.eraseNode(it.targetNode);
      deadNodes.insert(it.targetNode);
    }
  }

  for (CallGraphNode *node : deadNodes) {
    currentSCC.remove(node);
    inlinerIface.markForDeletion(node);
  }
  calls.clear();
  return success(inlinedAnyCalls);
}
/// Decides whether `resolvedCall` should be inlined.
///
/// Inlining is rejected when the call is a terminator, when the callee node
/// has an edge to itself, when the callee region is an ancestor of the call,
/// or when the callee has more than one block and the operation enclosing the
/// call is of a different kind that might be single-block. Otherwise the
/// decision is left to the inliner's profitability hook.
bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) {
  CallGraphNode *calleeNode = resolvedCall.targetNode;
  Operation *callOp = resolvedCall.call.getOperation();

  if (callOp->hasTrait<OpTrait::IsTerminator>())
    return false;

  // Reject callees that have an edge back to themselves.
  const bool calleeHasSelfEdge =
      llvm::any_of(*calleeNode, [calleeNode](const CallGraphNode::Edge &edge) {
        return edge.getTarget() == calleeNode;
      });
  if (calleeHasSelfEdge)
    return false;

  // Reject calls located inside the region they would inline.
  Region *calleeRegion = calleeNode->getCallableRegion();
  if (calleeRegion->isAncestor(callOp->getParentRegion()))
    return false;

  // A multi-block callee is accepted when the enclosing operation has the
  // same name as the callee's operation, or cannot be single-block.
  if (llvm::hasNItemsOrMore(*calleeRegion, /*N=*/2)) {
    Operation *enclosingOp = callOp->getParentOp();
    const bool sameOpKind =
        calleeRegion->getParentOp()->getName() == enclosingOp->getName();
    if (!sameOpKind && enclosingOp->mightHaveTrait<OpTrait::SingleBlock>())
      return false;
  }

  if (!inliner.isProfitableToInline(resolvedCall))
    return false;
  return true;
}
/// Entry point of the inliner: runs inlining over every SCC of the callgraph
/// and, if no SCC failed, erases the callables marked as dead on the way.
LogicalResult Inliner::doInlining() {
  Impl impl(*this);
  MLIRContext *context = op->getContext();

  // State shared across all SCCs.
  SymbolTableCollection symbolTable;
  InlinerInterfaceImpl inlinerIface(context, cg, symbolTable);
  CGUseList useList(op, cg, symbolTable);

  auto processSCC = [&](CallGraphSCC &scc) {
    return impl.inlineSCC(inlinerIface, useList, scc, context);
  };
  if (failed(runTransformOnCGSCCs(cg, processSCC)))
    return failure();

  // Dead callables are erased only after all SCCs have been processed.
  inlinerIface.eraseDeadCallables();
  return success();
}
}