#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/DebugCounter.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/Timing.h"
#include "mlir/Support/ToolUtilities.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/StringSaver.h"
#include "llvm/Support/ThreadPool.h"
#include "llvm/Support/ToolOutputFile.h"
using namespace mlir;
using namespace llvm;
/// Parses the module held in `sourceMgr`, runs the pass pipeline built by
/// `passManagerSetupFn` on it, and prints the resulting module to `os`.
///
/// \param os                 Stream that receives the printed module (followed
///                           by a newline) when every step succeeds.
/// \param verifyDiagnostics  NOTE(review): not read by this function; the
///                           diagnostic-verification handling is done by the
///                           caller (processBuffer). Kept for interface
///                           compatibility.
/// \param verifyPasses       Forwarded to PassManager::enableVerifier.
/// \param sourceMgr          Source manager holding the buffer to parse.
/// \param context            Context to parse into. Its multithreading flag is
///                           temporarily disabled while parsing and restored
///                           immediately afterwards.
/// \param passManagerSetupFn Callback that populates the pass manager.
/// \returns failure() if parsing, pipeline construction, or pipeline execution
///          fails; success() otherwise.
static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
bool verifyPasses, SourceMgr &sourceMgr,
MLIRContext *context,
PassPipelineFn passManagerSetupFn) {
// Timing manager configured from the timing command-line options. Its root
// scope is handed to the pass manager and nested for the parser/output phases.
DefaultTimingManager tm;
applyDefaultTimingManagerCLOptions(tm);
TimingScope timing = tm.getRootScope();
// Remember the caller's threading state, then disable multithreading for the
// parse (presumably to avoid context synchronization overhead while parsing).
// The original state is restored right after parseSourceFile below.
bool wasThreadingEnabled = context->isMultithreadingEnabled();
context->disableMultithreading();
// Pass manager with implicit nesting, optional per-pass verification,
// command-line pass-manager options, and timing attached.
PassManager pm(context, OpPassManager::Nesting::Implicit);
pm.enableVerifier(verifyPasses);
applyPassManagerCLOptions(pm);
pm.enableTiming(timing);
// Parser configuration with the pass-reproducer resource handler attached.
// It receives the *original* threading state, not the temporarily disabled one.
ParserConfig config(context);
attachPassReproducerAsmResource(config, pm, wasThreadingEnabled);
// Parse under a "Parser" timing scope. Threading is restored before the null
// check so the context is returned to its original state even on failure.
TimingScope parserTiming = timing.nest("Parser");
OwningOpRef<ModuleOp> module(parseSourceFile<ModuleOp>(sourceMgr, config));
context->enableMultithreading(wasThreadingEnabled);
if (!module)
return failure();
parserTiming.stop();
// Let the caller populate the pipeline.
if (failed(passManagerSetupFn(pm)))
return failure();
// Run the pipeline on the parsed module.
if (failed(pm.run(*module)))
return failure();
// Print the transformed module under an "Output" timing scope.
TimingScope outputTiming = timing.nest("Output");
module->print(os);
os << '\n';
return success();
}
/// Processes a single input buffer (one chunk when the input is split) in a
/// freshly created MLIRContext.
///
/// \param os                        Output stream for the printed module.
/// \param ownedBuffer               Buffer to parse; ownership moves into a
///                                  local SourceMgr.
/// \param verifyDiagnostics         When true, emitted diagnostics are checked
///                                  against `expected-*` annotations instead of
///                                  being printed, and the verification result
///                                  becomes the return value.
/// \param verifyPasses              Run the verifier after each pass.
/// \param allowUnregisteredDialects Accept operations from unregistered
///                                  dialects.
/// \param preloadDialectsInContext  Eagerly load every dialect of `registry`.
/// \param passManagerSetupFn        Callback that populates the pass pipeline.
/// \param registry                  Dialect registry used to build the context.
/// \param threadPool                Optional shared thread pool; may be null.
/// \returns In verification mode, the outcome of diagnostic verification.
///          Otherwise, the outcome of parsing and running the pipeline.
static LogicalResult
processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
              bool verifyDiagnostics, bool verifyPasses,
              bool allowUnregisteredDialects, bool preloadDialectsInContext,
              PassPipelineFn passManagerSetupFn, DialectRegistry &registry,
              llvm::ThreadPool *threadPool) {
  // The parser reads the input from this source manager.
  SourceMgr sourceMgr;
  sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());

  // A dedicated context per buffer. It is created with threading disabled.
  // The shared pool, when provided, is attached afterwards (presumably so
  // the context does not create a pool of its own).
  MLIRContext context(registry, MLIRContext::Threading::DISABLED);
  if (threadPool)
    context.setThreadPool(*threadPool);

  if (preloadDialectsInContext)
    context.loadAllAvailableDialects();
  context.allowUnregisteredDialects(allowUnregisteredDialects);
  // In verification mode, do not attach the operation to each diagnostic.
  if (verifyDiagnostics)
    context.printOpOnDiagnostic(false);
  context.getDebugActionManager().registerActionHandler<DebugCounter>();

  // Shared invocation of the parse/run/print pipeline for both modes below.
  auto runActions = [&]() {
    return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr,
                          &context, passManagerSetupFn);
  };

  // Normal mode: diagnostics are reported through the source manager, and
  // the pipeline result is the result of this function.
  if (!verifyDiagnostics) {
    SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
    return runActions();
  }

  // Verification mode: failures are expected and captured as diagnostics, so
  // the pipeline result is deliberately ignored. Only whether the emitted
  // diagnostics match the expectations matters.
  SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
  (void)runActions();
  return sourceMgrHandler.verify();
}
/// Runs the opt driver over `buffer`, using `passManagerSetupFn` to build the
/// pass pipeline, and writes the result to `outputStream`.
///
/// When `splitInputFile` is set, the buffer is split into chunks. Each chunk
/// is processed independently in its own MLIRContext via processBuffer.
/// \returns the combined result reported by splitAndProcessBuffer.
LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
                                std::unique_ptr<MemoryBuffer> buffer,
                                PassPipelineFn passManagerSetupFn,
                                DialectRegistry &registry, bool splitInputFile,
                                bool verifyDiagnostics, bool verifyPasses,
                                bool allowUnregisteredDialects,
                                bool preloadDialectsInContext) {
  // This context exists only to own a thread pool that all per-chunk contexts
  // share. It must outlive every processBuffer call below, which holds because
  // it lives in this frame. Whether threading is enabled presumably reflects
  // the MLIRContext command-line options (confirm).
  ThreadPool *threadPool = nullptr;
  MLIRContext threadPoolCtx;
  if (threadPoolCtx.isMultithreadingEnabled())
    threadPool = &threadPoolCtx.getThreadPool();

  // Per-chunk handler: each chunk gets a fresh context built from `registry`.
  auto chunkFn = [&](std::unique_ptr<MemoryBuffer> chunkBuffer,
                     raw_ostream &os) {
    return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics,
                         verifyPasses, allowUnregisteredDialects,
                         preloadDialectsInContext, passManagerSetupFn, registry,
                         threadPool);
  };
  return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream,
                               splitInputFile, /*insertMarkerInOutput=*/true);
}
/// Runs the opt driver with a pipeline described by command-line pass options.
///
/// Adapts `passPipeline` into a PassPipelineFn and forwards everything else
/// unchanged to the PassPipelineFn overload of MlirOptMain. Pipeline
/// construction errors are reported as errors at an unknown location in the
/// pass manager's context.
LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
                                std::unique_ptr<MemoryBuffer> buffer,
                                const PassPipelineCLParser &passPipeline,
                                DialectRegistry &registry, bool splitInputFile,
                                bool verifyDiagnostics, bool verifyPasses,
                                bool allowUnregisteredDialects,
                                bool preloadDialectsInContext) {
  // Builds the pipeline from the parsed command-line description. This lambda
  // is only invoked during the call below, so capturing by reference is safe.
  auto passManagerSetupFn = [&](PassManager &pm) {
    auto errorHandler = [&](const Twine &msg) {
      emitError(UnknownLoc::get(pm.getContext())) << msg;
      return failure();
    };
    return passPipeline.addToPipeline(pm, errorHandler);
  };
  return MlirOptMain(outputStream, std::move(buffer), passManagerSetupFn,
                     registry, splitInputFile, verifyDiagnostics, verifyPasses,
                     allowUnregisteredDialects, preloadDialectsInContext);
}
/// Command-line entry point for an `mlir-opt`-style tool.
///
/// Registers and parses the tool's command-line options, then either lists the
/// registered dialects (`--show-dialects`) or opens the input and output files
/// and runs the driver. The output file is kept only if the driver succeeds.
///
/// \param argc, argv               Process arguments.
/// \param toolName                 First line of the `--help` header.
/// \param registry                 Dialects available to the tool.
/// \param preloadDialectsInContext Eagerly load all dialects in each context.
/// \returns failure() if a file cannot be opened or the driver fails.
LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
                                DialectRegistry &registry,
                                bool preloadDialectsInContext) {
  static cl::opt<std::string> inputFilename(
      cl::Positional, cl::desc("<input file>"), cl::init("-"));

  static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"),
                                             cl::value_desc("filename"),
                                             cl::init("-"));

  static cl::opt<bool> splitInputFile(
      "split-input-file",
      cl::desc("Split the input file into pieces and process each "
               "chunk independently"),
      cl::init(false));

  static cl::opt<bool> verifyDiagnostics(
      "verify-diagnostics",
      cl::desc("Check that emitted diagnostics match "
               "expected-* lines on the corresponding line"),
      cl::init(false));

  static cl::opt<bool> verifyPasses(
      "verify-each",
      cl::desc("Run the verifier after each transformation pass"),
      cl::init(true));

  static cl::opt<bool> allowUnregisteredDialects(
      "allow-unregistered-dialect",
      cl::desc("Allow operation with no registered dialects"), cl::init(false));

  static cl::opt<bool> showDialects(
      "show-dialects", cl::desc("Print the list of registered dialects"),
      cl::init(false));

  InitLLVM y(argc, argv);

  // Register the option groups that must exist before the command line is
  // parsed.
  registerAsmPrinterCLOptions();
  registerMLIRContextCLOptions();
  registerPassManagerCLOptions();
  registerDefaultTimingManagerCLOptions();
  DebugCounter::registerCLOptions();
  PassPipelineCLParser passPipeline("", "Compiler passes to run");

  // Build the help header: the tool name followed by the dialect list. The
  // inner scope destroys the raw_string_ostream before helpHeader is used.
  std::string helpHeader = (toolName + "\nAvailable Dialects: ").str();
  {
    llvm::raw_string_ostream os(helpHeader);
    interleaveComma(registry.getDialectNames(), os,
                    [&](auto name) { os << name; });
  }

  cl::ParseCommandLineOptions(argc, argv, helpHeader);

  // List the registered dialects, one per line, and exit successfully. The
  // final newline keeps the last name from running into whatever follows.
  if (showDialects) {
    llvm::outs() << "Available Dialects:\n";
    interleave(
        registry.getDialectNames(), llvm::outs(),
        [](auto name) { llvm::outs() << name; }, "\n");
    llvm::outs() << "\n";
    return success();
  }

  // Set up the input file.
  std::string errorMessage;
  auto file = openInputFile(inputFilename, &errorMessage);
  if (!file) {
    llvm::errs() << errorMessage << "\n";
    return failure();
  }

  // Set up the output file.
  auto output = openOutputFile(outputFilename, &errorMessage);
  if (!output) {
    llvm::errs() << errorMessage << "\n";
    return failure();
  }

  if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry,
                         splitInputFile, verifyDiagnostics, verifyPasses,
                         allowUnregisteredDialects, preloadDialectsInContext)))
    return failure();

  // Keep the output file only when the driver succeeded.
  output->keep();
  return success();
}