#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/Debug/CLOptionsSetup.h"
#include "mlir/Debug/Counter.h"
#include "mlir/Debug/DebuggerExecutionContextHook.h"
#include "mlir/Debug/ExecutionContext.h"
#include "mlir/Debug/Observers/ActionLogging.h"
#include "mlir/Dialect/IRDL/IR/IRDL.h"
#include "mlir/Dialect/IRDL/IRDLLoading.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/Timing.h"
#include "mlir/Support/ToolUtilities.h"
#include "mlir/Tools/ParseUtilities.h"
#include "mlir/Tools/Plugins/DialectPlugin.h"
#include "mlir/Tools/Plugins/PassPlugin.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/Process.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/StringSaver.h"
#include "llvm/Support/ThreadPool.h"
#include "llvm/Support/ToolOutputFile.h"
using namespace mlir;
using namespace llvm;
namespace {
class BytecodeVersionParser : public cl::parser<std::optional<int64_t>> {
public:
BytecodeVersionParser(cl::Option &o)
: cl::parser<std::optional<int64_t>>(o) {}
bool parse(cl::Option &o, StringRef , StringRef arg,
std::optional<int64_t> &v) {
long long w;
if (getAsSignedInteger(arg, 10, w))
return o.error("Invalid argument '" + arg +
"', only integer is supported.");
v = w;
return false;
}
};
struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
MlirOptMainConfigCLOptions() {
static cl::opt<bool, true> allowUnregisteredDialects(
"allow-unregistered-dialect",
cl::desc("Allow operation with no registered dialects"),
cl::location(allowUnregisteredDialectsFlag), cl::init(false));
static cl::opt<bool, true> dumpPassPipeline(
"dump-pass-pipeline", cl::desc("Print the pipeline that will be run"),
cl::location(dumpPassPipelineFlag), cl::init(false));
static cl::opt<bool, true> emitBytecode(
"emit-bytecode", cl::desc("Emit bytecode when generating output"),
cl::location(emitBytecodeFlag), cl::init(false));
static cl::opt<bool, true> elideResourcesFromBytecode(
"elide-resource-data-from-bytecode",
cl::desc("Elide resources when generating bytecode"),
cl::location(elideResourceDataFromBytecodeFlag), cl::init(false));
static cl::opt<std::optional<int64_t>, true,
BytecodeVersionParser>
bytecodeVersion(
"emit-bytecode-version",
cl::desc("Use specified bytecode when generating output"),
cl::location(emitBytecodeVersion), cl::init(std::nullopt));
static cl::opt<std::string, true> irdlFile(
"irdl-file",
cl::desc("IRDL file to register before processing the input"),
cl::location(irdlFileFlag), cl::init(""), cl::value_desc("filename"));
static cl::opt<bool, true> enableDebuggerHook(
"mlir-enable-debugger-hook",
cl::desc("Enable Debugger hook for debugging MLIR Actions"),
cl::location(enableDebuggerActionHookFlag), cl::init(false));
static cl::opt<bool, true> explicitModule(
"no-implicit-module",
cl::desc("Disable implicit addition of a top-level module op during "
"parsing"),
cl::location(useExplicitModuleFlag), cl::init(false));
static cl::opt<bool, true> runReproducer(
"run-reproducer", cl::desc("Run the pipeline stored in the reproducer"),
cl::location(runReproducerFlag), cl::init(false));
static cl::opt<bool, true> showDialects(
"show-dialects",
cl::desc("Print the list of registered dialects and exit"),
cl::location(showDialectsFlag), cl::init(false));
static cl::opt<std::string, true> splitInputFile{
"split-input-file", llvm::cl::ValueOptional,
cl::callback([&](const std::string &str) {
if (str.empty())
splitInputFile.setValue(kDefaultSplitMarker);
}),
cl::desc("Split the input file into chunks using the given or "
"default marker and process each chunk independently"),
cl::location(splitInputFileFlag), cl::init("")};
static cl::opt<std::string, true> outputSplitMarker(
"output-split-marker",
cl::desc("Split marker to use for merging the ouput"),
cl::location(outputSplitMarkerFlag), cl::init(kDefaultSplitMarker));
static cl::opt<bool, true> verifyDiagnostics(
"verify-diagnostics",
cl::desc("Check that emitted diagnostics match "
"expected-* lines on the corresponding line"),
cl::location(verifyDiagnosticsFlag), cl::init(false));
static cl::opt<bool, true> verifyPasses(
"verify-each",
cl::desc("Run the verifier after each transformation pass"),
cl::location(verifyPassesFlag), cl::init(true));
static cl::opt<bool, true> verifyRoundtrip(
"verify-roundtrip",
cl::desc("Round-trip the IR after parsing and ensure it succeeds"),
cl::location(verifyRoundtripFlag), cl::init(false));
static cl::list<std::string> passPlugins(
"load-pass-plugin", cl::desc("Load passes from plugin library"));
static cl::opt<std::string, true>
generateReproducerFile(
"mlir-generate-reproducer",
llvm::cl::desc(
"Generate an mlir reproducer at the provided filename"
" (no crash required)"),
cl::location(generateReproducerFileFlag), cl::init(""),
cl::value_desc("filename"));
passPlugins.setCallback([&](const std::string &pluginPath) {
auto plugin = PassPlugin::load(pluginPath);
if (!plugin) {
errs() << "Failed to load passes from '" << pluginPath
<< "'. Request ignored.\n";
return;
}
plugin.get().registerPassRegistryCallbacks();
});
static cl::list<std::string> dialectPlugins(
"load-dialect-plugin", cl::desc("Load dialects from plugin library"));
this->dialectPlugins = std::addressof(dialectPlugins);
static PassPipelineCLParser passPipeline("", "Compiler passes to run", "p");
setPassPipelineParser(passPipeline);
}
void setDialectPluginsCallback(DialectRegistry ®istry);
cl::list<std::string> *dialectPlugins = nullptr;
};
}
/// Process-wide configuration bound to the command-line options, constructed
/// lazily on first access through llvm::ManagedStatic. Internal linkage: its
/// type lives in an anonymous namespace, so no other translation unit can
/// refer to it.
static ManagedStatic<MlirOptMainConfigCLOptions> clOptionsConfig;
void MlirOptMainConfig::registerCLOptions(DialectRegistry ®istry) {
clOptionsConfig->setDialectPluginsCallback(registry);
tracing::DebugConfig::registerCLOptions();
}
/// Snapshot the parsed command-line configuration (plus the tracing debug
/// configuration derived from the command line) as a plain MlirOptMainConfig.
/// NOTE(review): the copied pass-pipeline callback installed by
/// setPassPipelineParser still refers to the CL-bound object, not the copy.
MlirOptMainConfig MlirOptMainConfig::createFromCLOptions() {
  MlirOptMainConfigCLOptions &cliConfig = *clOptionsConfig;
  cliConfig.setDebugConfig(tracing::DebugConfig::createFromCLOptions());
  // Returned by value: only the MlirOptMainConfig base part is copied.
  return cliConfig;
}
/// Install a pass-pipeline callback. The callback:
/// - adds the passes selected through `passPipeline` to the given PassManager,
///   reporting parse errors at an unknown location;
/// - when the dump-pass-pipeline flag is set, dumps the resulting pipeline
///   followed by a newline on llvm::errs().
/// Both `passPipeline` and this config are referenced (not copied) by the
/// callback, so they must outlive its uses.
MlirOptMainConfig &MlirOptMainConfig::setPassPipelineParser(
    const PassPipelineCLParser &passPipeline) {
  const PassPipelineCLParser *parser = &passPipeline;
  passPipelineCallback = [this, parser](PassManager &pm) -> LogicalResult {
    auto reportError = [&pm](const Twine &msg) -> LogicalResult {
      emitError(UnknownLoc::get(pm.getContext())) << msg;
      return failure();
    };

    if (failed(parser->addToPipeline(pm, reportError)))
      return failure();

    if (shouldDumpPassPipeline()) {
      pm.dump();
      llvm::errs() << "\n";
    }
    return success();
  };
  return *this;
}
void MlirOptMainConfigCLOptions::setDialectPluginsCallback(
DialectRegistry ®istry) {
dialectPlugins->setCallback([&](const std::string &pluginPath) {
auto plugin = DialectPlugin::load(pluginPath);
if (!plugin) {
errs() << "Failed to load dialect plugin from '" << pluginPath
<< "'. Request ignored.\n";
return;
};
plugin.get().registerDialectRegistryCallbacks(registry);
});
}
/// Load the dialects defined in the IRDL file `irdlFile` into `ctx`.
/// The IRDL dialect itself is first made available to `ctx` so the file can be
/// parsed. Fails if any of the following happens:
/// - the file cannot be opened (reported at an unknown location);
/// - it does not parse as a module (reported against the file's own SourceMgr);
/// - irdl::loadDialects fails.
LogicalResult loadIRDLDialects(StringRef irdlFile, MLIRContext &ctx) {
  DialectRegistry irdlOnlyRegistry;
  irdlOnlyRegistry.insert<irdl::IRDLDialect>();
  ctx.appendDialectRegistry(irdlOnlyRegistry);

  std::string openError;
  std::unique_ptr<MemoryBuffer> irdlContents =
      openInputFile(irdlFile, &openError);
  if (!irdlContents)
    return emitError(UnknownLoc::get(&ctx)) << openError;

  SourceMgr irdlSources;
  irdlSources.AddNewSourceBuffer(std::move(irdlContents), SMLoc());
  SourceMgrDiagnosticHandler irdlDiagHandler(irdlSources, &ctx);

  OwningOpRef<ModuleOp> irdlModule =
      parseSourceFile<ModuleOp>(irdlSources, &ctx);
  if (!irdlModule)
    return failure();
  return irdl::loadDialects(irdlModule.get());
}
/// Check that `op` survives a round-trip, by one of two routes:
/// - textual: generic op form with debug info;
/// - bytecode (`useBytecode`).
/// The serialized form is reparsed into a fresh context. The generic-form
/// printouts (with debug info) of the original and reparsed IR are then
/// compared. Serialization failures, parse failures and mismatches are
/// reported as errors on `op`.
static LogicalResult doVerifyRoundTrip(Operation *op,
                                       const MlirOptMainConfig &config,
                                       bool useBytecode) {
  // The fresh context mirrors the original one's dialect registry and
  // unregistered-dialect setting.
  MLIRContext *sourceContext = op->getContext();
  MLIRContext freshContext;
  freshContext.appendDialectRegistry(sourceContext->getDialectRegistry());
  if (sourceContext->allowsUnregisteredDialects())
    freshContext.allowUnregisteredDialects();

  // Dialects defined through IRDL must be loaded into the fresh context too.
  StringRef irdlPath = config.getIrdlFile();
  if (!irdlPath.empty() && failed(loadIRDLDialects(irdlPath, freshContext)))
    return failure();

  std::string modeName = useBytecode ? "bytecode" : "textual";

  // The same printing flags serve the textual serialization and the final
  // comparison.
  auto printGeneric = [](Operation *target, raw_ostream &out) {
    target->print(out,
                  OpPrintingFlags().printGenericOpForm().enableDebugInfo());
  };

  // Declared after freshContext so it is destroyed before the context that
  // owns the reparsed IR.
  OwningOpRef<Operation *> reparsed;
  {
    std::string serialized;
    llvm::raw_string_ostream serialStream(serialized);
    if (!useBytecode) {
      printGeneric(op, serialStream);
    } else if (failed(writeBytecodeToFile(op, serialStream))) {
      op->emitOpError()
          << "failed to write bytecode, cannot verify round-trip.\n";
      return failure();
    }

    FallbackAsmResourceMap resourceMap;
    ParserConfig reparseConfig(&freshContext, /*verifyAfterParse=*/true,
                               &resourceMap);
    reparsed =
        parseSourceString<Operation *>(serialStream.str(), reparseConfig);
    if (!reparsed) {
      op->emitOpError() << "failed to parse " << modeName
                        << " content back, cannot verify round-trip.\n";
      return failure();
    }
  }

  std::string reference;
  std::string roundtrip;
  {
    llvm::raw_string_ostream referenceStream(reference);
    printGeneric(op, referenceStream);
    llvm::raw_string_ostream roundtripStream(roundtrip);
    printGeneric(reparsed.get(), roundtripStream);
  }

  if (reference == roundtrip)
    return success();

  return op->emitOpError()
         << modeName
         << " roundTrip testing roundtripped module differs "
            "from reference:\n<<<<<<Reference\n"
         << reference << "\n=====\n"
         << roundtrip << "\n>>>>>roundtripped\n";
}
/// Run the textual and then the bytecode round-trip checks on `op`.
/// Both checks always run so that every failure gets reported. The result is
/// success only if both checks succeed.
static LogicalResult doVerifyRoundTrip(Operation *op,
                                       const MlirOptMainConfig &config) {
  bool textualOk =
      succeeded(doVerifyRoundTrip(op, config, /*useBytecode=*/false));
  bool bytecodeOk =
      succeeded(doVerifyRoundTrip(op, config, /*useBytecode=*/true));
  return success(textualOk && bytecodeOk);
}
/// Process the buffer held by `sourceMgr` in `context`, in order:
/// 1. Parse it, optionally wrapping it in an implicit top-level module.
/// 2. Optionally verify that it round-trips.
/// 3. Build and run the configured pass pipeline.
/// 4. Print the result to `os`, as text or as bytecode.
/// When `--mlir-generate-reproducer` is given, a reproducer for the input IR
/// and the configured pipeline is written before the pipeline runs.
/// Returns failure as soon as any step fails.
static LogicalResult
performActions(raw_ostream &os,
               const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
               MLIRContext *context, const MlirOptMainConfig &config) {
  DefaultTimingManager tm;
  applyDefaultTimingManagerCLOptions(tm);
  TimingScope timing = tm.getRootScope();

  // Parsing and the optional round-trip check run single-threaded; the
  // previous threading state is restored before the pass pipeline is built.
  bool wasThreadingEnabled = context->isMultithreadingEnabled();
  context->disableMultithreading();

  // External resources without a dedicated handler are collected in
  // fallbackResourceMap so they can be written back out with the result.
  PassReproducerOptions reproOptions;
  FallbackAsmResourceMap fallbackResourceMap;
  ParserConfig parseConfig(context, /*verifyAfterParse=*/true,
                           &fallbackResourceMap);
  if (config.shouldRunReproducer())
    reproOptions.attachResourceParser(parseConfig);

  TimingScope parserTiming = timing.nest("Parser");
  OwningOpRef<Operation *> op = parseSourceFileForTool(
      sourceMgr, parseConfig, !config.shouldUseExplicitModule());
  parserTiming.stop();
  if (!op)
    return failure();

  if (config.shouldVerifyRoundtrip() &&
      failed(doVerifyRoundTrip(op.get(), config)))
    return failure();

  context->enableMultithreading(wasThreadingEnabled);

  // Build the pass manager anchored on the parsed top-level operation.
  PassManager pm(op.get()->getName(), PassManager::Nesting::Implicit);
  pm.enableVerifier(config.shouldVerifyPasses());
  if (failed(applyPassManagerCLOptions(pm)))
    return failure();
  pm.enableTiming(timing);
  if (config.shouldRunReproducer() && failed(reproOptions.apply(pm)))
    return failure();
  if (failed(config.setupPassPipeline(pm)))
    return failure();

  // Generate the reproducer before running the pipeline, for two reasons.
  // First, it must pair the pipeline with the IR the pipeline starts from;
  // replaying a snapshot taken after pm.run would reapply the passes to
  // already-transformed IR. Second, it must still be written when the
  // pipeline fails.
  if (!config.getReproducerFilename().empty()) {
    StringRef anchorName = pm.getAnyOpAnchorName();
    const auto &passes = pm.getPasses();
    if (failed(makeReproducer(anchorName, passes, op.get(),
                              config.getReproducerFilename())))
      return failure();
  }

  if (failed(pm.run(*op)))
    return failure();

  TimingScope outputTiming = timing.nest("Output");
  if (config.shouldEmitBytecode()) {
    BytecodeWriterConfig writerConfig(fallbackResourceMap);
    if (auto v = config.bytecodeVersionToEmit())
      writerConfig.setDesiredBytecodeVersion(*v);
    if (config.shouldElideResourceDataFromBytecode())
      writerConfig.setElideResourceDataFlag();
    return writeBytecodeToFile(op.get(), os, writerConfig);
  }

  // A bytecode version only makes sense together with --emit-bytecode.
  if (config.bytecodeVersionToEmit().has_value())
    return emitError(UnknownLoc::get(pm.getContext()))
           << "bytecode version while not emitting bytecode";

  AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr,
                    &fallbackResourceMap);
  op.get()->print(os, asmState);
  os << '\n';
  return success();
}
/// Process a single input buffer in a freshly created MLIRContext, writing the
/// result to `os`.
/// - The context uses `registry` and, when non-null, the shared `threadPool`.
/// - Without `--verify-diagnostics`, diagnostics are printed against the
///   buffer and the processing result is returned.
/// - With it, the processing result is ignored and the outcome of checking
///   the expected-* annotations is returned instead.
static LogicalResult processBuffer(raw_ostream &os,
                                   std::unique_ptr<MemoryBuffer> ownedBuffer,
                                   const MlirOptMainConfig &config,
                                   DialectRegistry &registry,
                                   llvm::ThreadPoolInterface *threadPool) {
  // The parser reads the input from this SourceMgr.
  auto sourceMgr = std::make_shared<SourceMgr>();
  sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());

  // Threading is disabled at construction; the shared pool (if any) is
  // injected afterwards.
  MLIRContext context(registry, MLIRContext::Threading::DISABLED);
  if (threadPool)
    context.setThreadPool(*threadPool);

  StringRef irdlFile = config.getIrdlFile();
  if (!irdlFile.empty() && failed(loadIRDLDialects(irdlFile, context)))
    return failure();

  context.allowUnregisteredDialects(config.shouldAllowUnregisteredDialects());
  if (config.shouldVerifyDiagnostics())
    context.printOpOnDiagnostic(false);

  tracing::InstallDebugHandler installDebugHandler(context,
                                                   config.getDebugConfig());

  if (!config.shouldVerifyDiagnostics()) {
    SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, &context);
    return performActions(os, sourceMgr, &context, config);
  }

  // In verification mode the diagnostics are the result; the status of
  // performActions is deliberately ignored.
  SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr, &context);
  (void)performActions(os, sourceMgr, &context, config);
  return sourceMgrHandler.verify();
}
/// Register every command-line option used by this driver, then parse `argv`.
/// The options cover:
/// - the positional input filename and `-o`, both defaulting to "-";
/// - MlirOptMainConfig;
/// - the asm printer, MLIRContext, pass manager, timing and debug-counter
///   options.
/// The `--help` header lists the dialects registered in `registry`.
/// Returns {input filename, output filename}.
std::pair<std::string, std::string>
mlir::registerAndParseCLIOptions(int argc, char **argv,
                                 llvm::StringRef toolName,
                                 DialectRegistry &registry) {
  static cl::opt<std::string> inputFilename(
      cl::Positional, cl::desc("<input file>"), cl::init("-"));

  static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"),
                                             cl::value_desc("filename"),
                                             cl::init("-"));

  MlirOptMainConfig::registerCLOptions(registry);
  registerAsmPrinterCLOptions();
  registerMLIRContextCLOptions();
  registerPassManagerCLOptions();
  registerDefaultTimingManagerCLOptions();
  tracing::DebugCounter::registerCLOptions();

  // Build the "Available Dialects" header shown by --help.
  std::string helpHeader = (toolName + "\nAvailable Dialects: ").str();
  {
    llvm::raw_string_ostream os(helpHeader);
    interleaveComma(registry.getDialectNames(), os,
                    [&](auto name) { os << name; });
  }

  cl::ParseCommandLineOptions(argc, argv, helpHeader);
  return std::make_pair(inputFilename.getValue(), outputFilename.getValue());
}
/// Print "Available Dialects: " followed by the comma-separated names of the
/// dialects in `registry` on one line of stdout. Always succeeds.
static LogicalResult printRegisteredDialects(DialectRegistry &registry) {
  llvm::outs() << "Available Dialects: ";
  interleave(registry.getDialectNames(), llvm::outs(), ",");
  llvm::outs() << "\n";
  return success();
}
/// Run the driver on `buffer` with `config`, writing the result to
/// `outputStream`.
/// - With `--show-dialects`, only the registered dialects are printed.
/// - Otherwise the buffer is split on the configured input split marker.
///   Each chunk is processed independently (each in its own context), and the
///   outputs are joined with the output split marker.
LogicalResult mlir::MlirOptMain(llvm::raw_ostream &outputStream,
                                std::unique_ptr<llvm::MemoryBuffer> buffer,
                                DialectRegistry &registry,
                                const MlirOptMainConfig &config) {
  if (config.shouldShowDialects())
    return printRegisteredDialects(registry);

  // One thread pool, owned by threadPoolCtx, is shared by all per-chunk
  // contexts. It is only shared when threadPoolCtx has multithreading
  // enabled.
  ThreadPoolInterface *threadPool = nullptr;
  MLIRContext threadPoolCtx;
  if (threadPoolCtx.isMultithreadingEnabled())
    threadPool = &threadPoolCtx.getThreadPool();

  auto chunkFn = [&](std::unique_ptr<MemoryBuffer> chunkBuffer,
                     raw_ostream &os) {
    return processBuffer(os, std::move(chunkBuffer), config, registry,
                         threadPool);
  };
  return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream,
                               config.inputSplitMarker(),
                               config.outputSplitMarker());
}
/// Run the driver on the files named `inputFilename` and `outputFilename`
/// ("-" means stdin/stdout), using the configuration parsed from the command
/// line. Open errors are printed to stderr. The output file is only kept when
/// processing succeeds.
LogicalResult mlir::MlirOptMain(int argc, char **argv,
                                llvm::StringRef inputFilename,
                                llvm::StringRef outputFilename,
                                DialectRegistry &registry) {
  InitLLVM y(argc, argv);

  MlirOptMainConfig config = MlirOptMainConfig::createFromCLOptions();
  if (config.shouldShowDialects())
    return printRegisteredDialects(registry);

  // Reading stdin from a terminal looks like a hang; tell the user.
  if (inputFilename == "-" &&
      sys::Process::FileDescriptorIsDisplayed(fileno(stdin)))
    llvm::errs() << "(processing input from stdin now, hit ctrl-c/ctrl-d to "
                    "interrupt)\n";

  std::string errorMessage;
  auto file = openInputFile(inputFilename, &errorMessage);
  if (!file) {
    llvm::errs() << errorMessage << "\n";
    return failure();
  }

  auto output = openOutputFile(outputFilename, &errorMessage);
  if (!output) {
    llvm::errs() << errorMessage << "\n";
    return failure();
  }

  if (failed(MlirOptMain(output->os(), std::move(file), registry, config)))
    return failure();

  // Only mark the output file to be kept once processing has succeeded.
  output->keep();
  return success();
}
/// Convenience entry point. It registers and parses all command-line options
/// (using `toolName` in the help header), then runs the driver on the parsed
/// input/output filenames.
LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
                                DialectRegistry &registry) {
  auto [inputFilename, outputFilename] =
      registerAndParseCLIOptions(argc, argv, toolName, registry);
  return MlirOptMain(argc, argv, inputFilename, outputFilename, registry);
}