#include "TestDialect.h"
#include "TestOps.h"
#include "mlir/Bytecode/BytecodeReader.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/raw_ostream.h"
#include <list>
using namespace mlir;
using namespace llvm;
namespace {
/// Command-line parser for `test::TestDialectVersion` option values spelled
/// as `<major>.<minor>`.
class TestDialectVersionParser : public cl::parser<test::TestDialectVersion> {
public:
  TestDialectVersionParser(cl::Option &o)
      : cl::parser<test::TestDialectVersion>(o) {}

  /// Parses `arg` into `v`.
  ///
  /// Returns false on success. On failure, reports the problem through
  /// `o.error(...)` and returns its result.
  bool parse(cl::Option &o, StringRef /*argName*/, StringRef arg,
             test::TestDialectVersion &v) {
    // Split once on the first "." rather than re-splitting per component.
    auto [majorStr, minorStr] = arg.split(".");
    long long major, minor;
    // `getAsSignedInteger` returns true when the text is not a base-10
    // integer; both components share the same error message.
    if (getAsSignedInteger(majorStr, 10, major) ||
        getAsSignedInteger(minorStr, 10, minor))
      return o.error("Invalid argument '" + arg + "'");
    v = test::TestDialectVersion(major, minor);
    return false;
  }

  /// Prints `v` as `<major>.<minor>`.
  static void print(raw_ostream &os, const test::TestDialectVersion &v) {
    os << v.major_ << "." << v.minor_;
  }
};
struct TestBytecodeRoundtripPass
: public PassWrapper<TestBytecodeRoundtripPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestBytecodeRoundtripPass)
/// Returns the argument string under which this pass is registered.
StringRef getArgument() const final {
  return StringRef("test-bytecode-roundtrip");
}
/// Returns the one-line, human-readable summary of this pass.
StringRef getDescription() const final {
  return StringRef("Test pass to implement bytecode roundtrip tests.");
}
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<test::TestDialect>();
}
// Default construction: the option members are set up by their in-class
// initializers.
TestBytecodeRoundtripPass() = default;
// Copy construction copies no state from the source object: the body is
// empty, so the option members of the new instance are again built from
// their own in-class initializers (bound to the new `*this`).
TestBytecodeRoundtripPass(const TestBytecodeRoundtripPass &) {}
/// Fetches the test dialect from `context` via `getOrLoadDialect` and caches
/// the resulting pointer in `testDialect`. Always reports success.
LogicalResult initialize(MLIRContext *context) override {
  auto *loadedDialect = context->getOrLoadDialect<test::TestDialect>();
  testDialect = loadedDialect;
  return success();
}
void runOnOperation() override {
switch (testKind) {
case (0):
return runTest0(getOperation());
case (1):
return runTest1(getOperation());
case (2):
return runTest2(getOperation());
case (3):
return runTest3(getOperation());
case (4):
return runTest4(getOperation());
case (5):
return runTest5(getOperation());
case (6):
return runTest6(getOperation());
default:
llvm_unreachable("unhandled test kind for TestBytecodeCallbacks pass");
}
}
// `test-dialect-version` option: the test dialect version used by the tests
// that configure a dialect version on the bytecode writer. Parsed with
// TestDialectVersionParser; defaults to a default-constructed
// test::TestDialectVersion.
mlir::Pass::Option<test::TestDialectVersion, TestDialectVersionParser>
targetVersion{*this, "test-dialect-version",
llvm::cl::desc(
"Specifies the test dialect version to emit and parse"),
cl::init(test::TestDialectVersion())};
// `test-kind` option: integer selector consumed by runOnOperation() to pick
// which runTestN routine to execute. Defaults to 0.
mlir::Pass::Option<int> testKind{
*this, "test-kind", llvm::cl::desc("Specifies the test kind to execute"),
cl::init(0)};
private:
/// Writes `op` as bytecode into an in-memory string using `writeConfig`,
/// parses that string back using `parseConfig`, and prints the re-parsed IR
/// to `llvm::outs()`.
///
/// If either the write or the parse step fails, an error is emitted on `op`
/// and the pass is marked as failed.
void doRoundtripWithConfigs(Operation *op,
                            const BytecodeWriterConfig &writeConfig,
                            const ParserConfig &parseConfig) {
  // Emits `message` as an error on `op` and flags the pass as failed.
  auto reportFailure = [&](const char *message) {
    op->emitError() << message;
    signalPassFailure();
  };

  std::string buffer;
  llvm::raw_string_ostream stream(buffer);
  if (failed(writeBytecodeToFile(op, stream, writeConfig)))
    return reportFailure("failed to write bytecode\n");

  auto reparsed = parseSourceString(StringRef(buffer), parseConfig);
  if (!reparsed.get())
    return reportFailure("failed to read bytecode\n");

  reparsed->print(llvm::outs());
}
/// Test 0: roundtrips `op` with a custom `IntegerType` encoding.
///
/// The writer is configured with the version from the `test-dialect-version`
/// option and a type callback that, for test dialect versions with major < 2,
/// writes `IntegerType` under the "funky" group as the marker 999 followed by
/// `(width << 2) | signedness`. The reader, running in a fresh context,
/// installs the matching callback that decodes this layout.
void runTest0(Operation *op) {
  auto newCtx = std::make_shared<MLIRContext>();

  // ---- Writer configuration -------------------------------------------
  test::TestDialectVersion emissionVersion = targetVersion;
  BytecodeWriterConfig writeConfig;
  writeConfig.setDialectVersion<test::TestDialect>(
      std::make_unique<test::TestDialectVersion>(emissionVersion));
  writeConfig.attachTypeCallback(
      [&](Type entryValue, std::optional<StringRef> &dialectGroupName,
          DialectBytecodeWriter &writer) -> LogicalResult {
        auto versionOr = writer.getDialectVersion<test::TestDialect>();
        assert(succeeded(versionOr) && "expected reader to be able to access "
                                       "the version for test dialect");
        const auto *dialectVersion =
            reinterpret_cast<const test::TestDialectVersion *>(*versionOr);

        // Decline to override anything from major version 2 onwards.
        if (dialectVersion->major_ >= 2)
          return failure();

        // Only IntegerType gets the custom encoding.
        auto intType = llvm::dyn_cast<IntegerType>(entryValue);
        if (!intType)
          return failure();

        llvm::outs() << "Overriding IntegerType encoding...\n";
        dialectGroupName = StringLiteral("funky");
        writer.writeVarInt(999);
        writer.writeVarInt(intType.getWidth() << 2 | intType.getSignedness());
        return success();
      });

  // ---- Reader configuration -------------------------------------------
  // The fresh context shares the dialect registry of the original one and
  // additionally accepts unregistered dialects.
  newCtx->appendDialectRegistry(op->getContext()->getDialectRegistry());
  newCtx->allowUnregisteredDialects();
  ParserConfig parseConfig(newCtx.get(), true);
  parseConfig.getBytecodeReaderConfig().attachTypeCallback(
      [&](DialectBytecodeReader &reader, StringRef dialectName,
          Type &entry) -> LogicalResult {
        auto versionOr = reader.getDialectVersion<test::TestDialect>();
        assert(succeeded(versionOr) && "expected reader to be able to access "
                                       "the version for test dialect");
        const auto *dialectVersion =
            reinterpret_cast<const test::TestDialectVersion *>(*versionOr);

        // In each of the following cases `entry` is left untouched.
        if (dialectVersion->major_ >= 2)
          return success();
        if (dialectName != StringLiteral("funky"))
          return success();
        uint64_t marker;
        if (failed(reader.readVarInt(marker)) || marker != 999)
          return success();

        llvm::outs() << "Overriding parsing of IntegerType encoding...\n";
        uint64_t packed;
        if (failed(reader.readVarInt(packed)))
          return success();

        // Undo the writer's `(width << 2) | signedness` packing.
        uint64_t width = packed >> 2;
        auto signedness =
            static_cast<IntegerType::SignednessSemantics>(packed & 0x3);
        entry = IntegerType::get(reader.getContext(), width, signedness);
        return success();
      });

  doRoundtripWithConfigs(op, writeConfig, parseConfig);
}
/// Test 1: roundtrips `op` while the writer encodes every
/// `test::TestI32Type` as the builtin signless 32-bit integer type, using
/// the builtin dialect's bytecode interface and the "builtin" group name.
/// The reader uses no custom callbacks.
void runTest1(Operation *op) {
  MLIRContext *ctx = op->getContext();
  auto *builtinDialect = ctx->getLoadedDialect<mlir::BuiltinDialect>();
  BytecodeDialectInterface *builtinBytecode =
      builtinDialect->getRegisteredInterface<BytecodeDialectInterface>();

  BytecodeWriterConfig writeConfig;
  writeConfig.attachTypeCallback(
      [&](Type entryValue, std::optional<StringRef> &dialectGroupName,
          DialectBytecodeWriter &writer) -> LogicalResult {
        // Every type other than TestI32Type is declined.
        if (!llvm::isa<test::TestI32Type>(entryValue))
          return failure();

        llvm::outs() << "Overriding TestI32Type encoding...\n";
        auto signlessI32 = IntegerType::get(
            op->getContext(), 32, IntegerType::SignednessSemantics::Signless);
        dialectGroupName = StringLiteral("builtin");
        // The callback's result mirrors the result of the builtin writer.
        return succeeded(builtinBytecode->writeType(signlessI32, writer))
                   ? success()
                   : failure();
      });

  ParserConfig parseConfig(ctx, true);
  doRoundtripWithConfigs(op, writeConfig, parseConfig);
}
/// Test 2: roundtrips `op` with a default writer while the reader maps each
/// builtin signless 32-bit integer type it decodes to `test::TestI32Type`.
void runTest2(Operation *op) {
  MLIRContext *ctx = op->getContext();
  auto *builtinDialect = ctx->getLoadedDialect<mlir::BuiltinDialect>();
  BytecodeDialectInterface *builtinBytecode =
      builtinDialect->getRegisteredInterface<BytecodeDialectInterface>();

  BytecodeWriterConfig writeConfig;
  ParserConfig parseConfig(ctx, true);
  parseConfig.getBytecodeReaderConfig().attachTypeCallback(
      [&](DialectBytecodeReader &reader, StringRef dialectName,
          Type &entry) -> LogicalResult {
        // Only entries of the "builtin" group are inspected.
        if (dialectName != StringLiteral("builtin"))
          return success();

        // Decode with the builtin reader, then substitute the test type
        // when the result is exactly a signless i32.
        Type decoded = builtinBytecode->readType(reader);
        auto intType = llvm::dyn_cast_or_null<IntegerType>(decoded);
        if (intType && intType.getWidth() == 32 && intType.isSignless()) {
          llvm::outs() << "Overriding parsing of TestI32Type encoding...\n";
          entry = test::TestI32Type::get(reader.getContext());
        }
        return success();
      });

  doRoundtripWithConfigs(op, writeConfig, parseConfig);
}
/// Test 3: roundtrips `op` while the writer encodes every
/// `test::TestAttrParamsAttr` as a builtin dense integer elements attribute
/// of type `tensor<2xi32>` holding its two parameters (v0, v1). The reader
/// uses no custom callbacks.
void runTest3(Operation *op) {
  MLIRContext *ctx = op->getContext();
  auto *builtinDialect = ctx->getLoadedDialect<mlir::BuiltinDialect>();
  BytecodeDialectInterface *builtinBytecode =
      builtinDialect->getRegisteredInterface<BytecodeDialectInterface>();
  auto i32Type =
      IntegerType::get(ctx, 32, IntegerType::SignednessSemantics::Signless);

  BytecodeWriterConfig writeConfig;
  writeConfig.attachAttributeCallback(
      [&](Attribute entryValue, std::optional<StringRef> &dialectGroupName,
          DialectBytecodeWriter &writer) -> LogicalResult {
        // Every attribute other than TestAttrParamsAttr is declined.
        auto paramsAttr = llvm::dyn_cast<test::TestAttrParamsAttr>(entryValue);
        if (!paramsAttr)
          return failure();

        llvm::outs() << "Overriding TestAttrParamsAttr encoding...\n";
        dialectGroupName = StringLiteral("builtin");
        auto packed = DenseIntElementsAttr::get(
            RankedTensorType::get({2}, i32Type),
            {paramsAttr.getV0(), paramsAttr.getV1()});
        // The callback's result mirrors the result of the builtin writer.
        return succeeded(builtinBytecode->writeAttribute(packed, writer))
                   ? success()
                   : failure();
      });

  ParserConfig parseConfig(ctx, false);
  doRoundtripWithConfigs(op, writeConfig, parseConfig);
}
/// Test 4: roundtrips `op` with a default writer while the reader turns each
/// decoded dense integer elements attribute of shape [2] and element type
/// i32 into a `test::TestAttrParamsAttr` built from its two elements.
void runTest4(Operation *op) {
  MLIRContext *ctx = op->getContext();
  auto *builtinDialect = ctx->getLoadedDialect<mlir::BuiltinDialect>();
  BytecodeDialectInterface *builtinBytecode =
      builtinDialect->getRegisteredInterface<BytecodeDialectInterface>();
  auto i32Type =
      IntegerType::get(ctx, 32, IntegerType::SignednessSemantics::Signless);

  BytecodeWriterConfig writeConfig;
  ParserConfig parseConfig(ctx, false);
  parseConfig.getBytecodeReaderConfig().attachAttributeCallback(
      [&](DialectBytecodeReader &reader, StringRef dialectName,
          Attribute &entry) -> LogicalResult {
        // Decode with the builtin reader; anything that is not a dense
        // integer elements attribute leaves `entry` untouched.
        Attribute decoded = builtinBytecode->readAttribute(reader);
        auto dense = llvm::dyn_cast_or_null<DenseIntElementsAttr>(decoded);
        if (!dense)
          return success();

        if (dense.getType().getShape() == ArrayRef<int64_t>(2) &&
            dense.getElementType() == i32Type) {
          llvm::outs()
              << "Overriding parsing of TestAttrParamsAttr encoding...\n";
          auto elements = dense.getValues<IntegerAttr>();
          int v0 = elements[0].getInt();
          int v1 = elements[1].getInt();
          entry = test::TestAttrParamsAttr::get(reader.getContext(), v0, v1);
        }
        return success();
      });

  doRoundtripWithConfigs(op, writeConfig, parseConfig);
}
/// Test 5: roundtrips `op` with writer and reader callbacks that hand every
/// attribute and type to the builtin dialect's bytecode interface. The
/// writer callbacks return the interface's result directly; the reader
/// callbacks fail when the interface yields a null attribute/type.
void runTest5(Operation *op) {
  MLIRContext *ctx = op->getContext();
  auto *builtinDialect = ctx->getLoadedDialect<mlir::BuiltinDialect>();
  BytecodeDialectInterface *builtinBytecode =
      builtinDialect->getRegisteredInterface<BytecodeDialectInterface>();

  // ---- Writer callbacks -------------------------------------------------
  BytecodeWriterConfig writeConfig;
  writeConfig.attachAttributeCallback(
      [&](Attribute attr, std::optional<StringRef> &dialectGroupName,
          DialectBytecodeWriter &writer) -> LogicalResult {
        return builtinBytecode->writeAttribute(attr, writer);
      });
  writeConfig.attachTypeCallback(
      [&](Type type, std::optional<StringRef> &dialectGroupName,
          DialectBytecodeWriter &writer) -> LogicalResult {
        return builtinBytecode->writeType(type, writer);
      });

  // ---- Reader callbacks -------------------------------------------------
  ParserConfig parseConfig(ctx, false);
  auto &readerConfig = parseConfig.getBytecodeReaderConfig();
  readerConfig.attachAttributeCallback(
      [&](DialectBytecodeReader &reader, StringRef dialectName,
          Attribute &entry) -> LogicalResult {
        Attribute decoded = builtinBytecode->readAttribute(reader);
        if (decoded) {
          entry = decoded;
          return success();
        }
        return failure();
      });
  readerConfig.attachTypeCallback(
      [&](DialectBytecodeReader &reader, StringRef dialectName,
          Type &entry) -> LogicalResult {
        Type decoded = builtinBytecode->readType(reader);
        if (decoded) {
          entry = decoded;
          return success();
        }
        return failure();
      });

  doRoundtripWithConfigs(op, writeConfig, parseConfig);
}
/// Checks whether the IR nested under `op` can be emitted at `version`.
///
/// - Version 2.0 succeeds immediately.
/// - Any version above 2.0 emits an error on `op` and fails.
/// - For lower versions, every `test::TestVersionedOpA` under `op` is
///   visited: an op whose `modifier` property is set emits an op error and
///   stops the walk (failure); otherwise "downgrading op..." is printed to
///   `llvm::outs()`.
LogicalResult downgradeToVersion(Operation *op,
                                 const test::TestDialectVersion &version) {
  const bool isCurrentVersion = version.major_ == 2 && version.minor_ == 0;
  if (isCurrentVersion)
    return success();

  const bool isNewerVersion =
      version.major_ > 2 || (version.major_ == 2 && version.minor_ > 0);
  if (isNewerVersion)
    return op->emitError() << "current test dialect version is 2.0, "
                              "can't downgrade to version: "
                           << version.major_ << "." << version.minor_;

  auto walkStatus = op->walk([&](test::TestVersionedOpA versionedOp) {
    auto &properties = versionedOp.getProperties();
    if (!properties.modifier.getValue()) {
      llvm::outs() << "downgrading op...\n";
      return WalkResult::advance();
    }
    versionedOp->emitOpError()
        << "cannot downgrade to version " << version.major_ << "."
        << version.minor_ << " since the modifier is not compatible";
    return WalkResult::interrupt();
  });
  return failure(walkStatus.wasInterrupted());
}
/// Test 6: downgrades `op` to the version given by the
/// `test-dialect-version` option, then roundtrips it with the writer
/// configured to emit the test dialect at that version.
///
/// Whether the downgrade succeeds depends on the user-supplied version and
/// on the input IR, so a failed downgrade marks the pass as failed and skips
/// the roundtrip instead of being checked only by an assertion (which is
/// compiled out in release builds).
void runTest6(Operation *op) {
  test::TestDialectVersion targetEmissionVersion = targetVersion;
  if (failed(downgradeToVersion(op, targetEmissionVersion))) {
    signalPassFailure();
    return;
  }

  BytecodeWriterConfig writeConfig;
  writeConfig.setDialectVersion<test::TestDialect>(
      std::make_unique<test::TestDialectVersion>(targetEmissionVersion));
  ParserConfig parseConfig(op->getContext(), true);
  doRoundtripWithConfigs(op, writeConfig, parseConfig);
}
// Test dialect pointer assigned in initialize(); null until then so the
// member never holds an indeterminate value after construction or copy.
test::TestDialect *testDialect = nullptr;
};
}
namespace mlir {
void registerTestBytecodeRoundtripPasses() {
PassRegistration<TestBytecodeRoundtripPass>();
}
}