#include "mlir/Transforms/Passes.h"
#include "mlir/Analysis/CallGraph.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Inliner.h"
namespace mlir {
#define GEN_PASS_DEF_INLINER
#include "mlir/Transforms/Passes.h.inc"
}
#define DEBUG_TYPE "inliner-pass"
using namespace mlir;
/// Populates `pm` with the optimization pipeline used when the caller does
/// not supply one: a single canonicalizer pass.
static void defaultInlinerOptPipeline(OpPassManager &pm) {
  std::unique_ptr<Pass> canonicalizer = createCanonicalizerPass();
  pm.addPass(std::move(canonicalizer));
}
namespace {
/// Pass wrapper around the `Inliner` utility. It owns the `InlinerConfig`
/// that is handed to the inliner when the pass runs.
class InlinerPass : public impl::InlinerBase<InlinerPass> {
public:
  /// Constructs the pass with the file-local default optimization pipeline.
  InlinerPass();
  InlinerPass(const InlinerPass &) = default;
  /// Constructs the pass with `defaultPipeline` as the default pipeline
  /// builder and no op-specific pipelines. Marked `explicit` so that a
  /// callable is never silently converted into a pass object.
  explicit InlinerPass(std::function<void(OpPassManager &)> defaultPipeline);
  /// Constructs the pass with `defaultPipeline` as the default pipeline
  /// builder and `opPipelines` as the op-specific pipelines.
  InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
              llvm::StringMap<OpPassManager> opPipelines);

  /// Pass entry point; see the out-of-line definition.
  void runOnOperation() override;

  /// Static trampoline handed to the `Inliner`: casts `pass` to an
  /// `InlinerPass` and forwards to its `runPipeline(pipeline, op)`.
  static LogicalResult runPipelineHelper(Pass &pass, OpPassManager &pipeline,
                                         Operation *op) {
    return mlir::cast<InlinerPass>(pass).runPipeline(pipeline, op);
  }

private:
  /// Parses the textual pass options and mirrors them into `config`.
  LogicalResult initializeOptions(
      StringRef options,
      function_ref<LogicalResult(const Twine &)> errorHandler) override;

  /// Inliner configuration assembled from the constructor arguments and the
  /// pass options.
  InlinerConfig config;
};
} // namespace
InlinerPass::InlinerPass() : InlinerPass(defaultInlinerOptPipeline) {}
/// Constructs the pass from a default pipeline builder only; delegates to the
/// two-argument constructor with an empty op-specific pipeline map.
InlinerPass::InlinerPass(
    std::function<void(OpPassManager &)> defaultPipelineArg)
    : InlinerPass(std::move(defaultPipelineArg),
                  llvm::StringMap<OpPassManager>()) {}
/// Constructs the pass from a default pipeline builder and a map of
/// op-specific pipelines. `config` is initialized with the builder and the
/// current value of `maxInliningIterations`.
InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
                         llvm::StringMap<OpPassManager> opPipelines)
    : config(std::move(defaultPipeline), maxInliningIterations) {
  if (!opPipelines.empty()) {
    // Record each op-specific pipeline in the list option before handing the
    // map over to the config.
    for (auto &entry : opPipelines)
      opPipelineList.addValue(entry.getValue());
    config.setOpPipelines(std::move(opPipelines));
  }
}
/// Decides whether inlining `resolvedCall` is considered profitable.
///
/// `inliningThreshold` is the maximum allowed callee/caller operation-count
/// ratio, expressed in percent:
///   * 0 never inlines;
///   * -1U (all bits set) always inlines;
///   * otherwise the call is inlined when
///     `calleeOps * 100 / callerOps <= inliningThreshold`.
/// A caller with no operations is always considered profitable.
static bool isProfitableToInline(const Inliner::ResolvedCall &resolvedCall,
                                 unsigned inliningThreshold) {
  // Return early: `ratio <= 0U` could only hold for a zero ratio, and the
  // original policy treats a zero threshold as "never inline".
  if (inliningThreshold == 0U)
    return false;
  // Return early: `ratio <= -1U` is always true.
  if (inliningThreshold == -1U)
    return true;

  Region *callerRegion = resolvedCall.sourceNode->getCallableRegion();
  Region *calleeRegion = resolvedCall.targetNode->getCallableRegion();
  assert(calleeRegion && callerRegion && "unexpected external node");

  // Counts every operation visited by walking `region`.
  auto countOps = [](Region *region) {
    unsigned count = 0;
    region->walk([&](Operation *) { ++count; });
    return count;
  };

  unsigned callerOps = countOps(callerRegion);
  // Always inline into empty callers; this also avoids dividing by zero.
  if (callerOps == 0)
    return true;

  // Widen before multiplying: in 32-bit arithmetic `calleeOps * 100` wraps
  // for very large callees and would yield a bogus small ratio.
  uint64_t ratio =
      static_cast<uint64_t>(countOps(calleeRegion)) * 100 / callerOps;
  LLVM_DEBUG(llvm::dbgs() << "Callee / caller operation ratio (max: "
                          << inliningThreshold << "%): " << ratio << "%\n");
  return ratio <= inliningThreshold;
}
void InlinerPass::runOnOperation() {
CallGraph &cg = getAnalysis<CallGraph>();
Operation *op = getOperation();
if (!op->hasTrait<OpTrait::SymbolTable>()) {
op->emitOpError() << " was scheduled to run under the inliner, but does "
"not define a symbol table";
return signalPassFailure();
}
auto profitabilityCb = [=](const Inliner::ResolvedCall &call) {
return isProfitableToInline(call, inliningThreshold);
};
Inliner inliner(op, cg, *this, getAnalysisManager(), runPipelineHelper,
config, profitabilityCb);
if (failed(inliner.doInlining()))
signalPassFailure();
return;
}
/// Parses `options` through the base `Pass::initializeOptions` and then
/// copies the resulting option values into `config`.
///
/// Returns failure if the base class fails to parse the options, success
/// otherwise.
LogicalResult InlinerPass::initializeOptions(
    StringRef options,
    function_ref<LogicalResult(const Twine &)> errorHandler) {
  if (failed(Pass::initializeOptions(options, errorHandler)))
    return failure();

  // Default pipeline: a non-empty option string replaces the builder with one
  // that parses that string; an option that was given explicitly but is empty
  // clears the default pipeline altogether.
  if (!defaultPipelineStr.empty()) {
    std::string defaultPipelineCopy = defaultPipelineStr;
    config.setDefaultPipeline([=](OpPassManager &pm) {
      // NOTE(review): a parse failure is deliberately ignored here; the
      // builder signature has no way to report it.
      (void)parsePassPipeline(defaultPipelineCopy, pm);
    });
  } else if (defaultPipelineStr.getNumOccurrences()) {
    config.setDefaultPipeline(nullptr);
  }

  // Op-specific pipelines, keyed by their anchor op name. Iterate by const
  // reference so each pipeline is copied once (into the map) rather than
  // twice.
  llvm::StringMap<OpPassManager> pipelines;
  for (const OpPassManager &pipeline : opPipelineList)
    if (!pipeline.empty())
      pipelines.try_emplace(pipeline.getOpAnchorName(), pipeline);
  config.setOpPipelines(std::move(pipelines));

  config.setMaxInliningIterations(maxInliningIterations);
  return success();
}
/// Creates an inliner pass using the default-constructed `InlinerPass`.
std::unique_ptr<Pass> mlir::createInlinerPass() {
  std::unique_ptr<Pass> pass = std::make_unique<InlinerPass>();
  return pass;
}
/// Creates an inliner pass with the given op-specific pipelines and
/// `defaultInlinerOptPipeline` as the default pipeline builder.
std::unique_ptr<Pass>
mlir::createInlinerPass(llvm::StringMap<OpPassManager> opPipelines) {
  // Same construction as the overload that takes an explicit builder, with
  // the file-local default builder supplied.
  return mlir::createInlinerPass(std::move(opPipelines),
                                 defaultInlinerOptPipeline);
}
/// Creates an inliner pass with the given op-specific pipelines and a
/// caller-provided default pipeline builder.
std::unique_ptr<Pass> mlir::createInlinerPass(
    llvm::StringMap<OpPassManager> opPipelines,
    std::function<void(OpPassManager &)> defaultPipelineBuilder) {
  std::unique_ptr<Pass> pass = std::make_unique<InlinerPass>(
      std::move(defaultPipelineBuilder), std::move(opPipelines));
  return pass;
}