#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectInterface.h"
#include "gtest/gtest.h"
using namespace mlir;
using namespace mlir::detail;
namespace {
/// Minimal dialect used by the tests in this file. It reports the "test"
/// namespace and its constructor only forwards to the `Dialect` base.
struct TestDialect : public Dialect {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialect)

  /// Namespace string passed to the `Dialect` base constructor.
  static StringRef getDialectNamespace() { return "test"; }

  TestDialect(MLIRContext *context)
      : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {}
};
/// A distinct dialect type that reports the very same "test" namespace as
/// `TestDialect`. It exists so a test can load two different dialect types
/// that collide on their namespace.
struct AnotherTestDialect : public Dialect {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnotherTestDialect)

  /// Namespace string passed to the `Dialect` base constructor.
  static StringRef getDialectNamespace() { return "test"; }

  AnotherTestDialect(MLIRContext *context)
      : Dialect(getDialectNamespace(), context,
                TypeID::get<AnotherTestDialect>()) {}
};
/// Loading two different dialect types that report the same namespace into
/// one context is expected to kill the process.
TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) {
  MLIRContext ctx;

  // The first dialect takes the "test" namespace.
  ctx.loadDialect<TestDialect>();

  // The second dialect reports the same namespace. The empty pattern accepts
  // any death message.
  ASSERT_DEATH(ctx.loadDialect<AnotherTestDialect>(), "");
}
/// Second dialect used by the tests; it reports the "test2" namespace and its
/// constructor only forwards to the `Dialect` base.
struct SecondTestDialect : Dialect {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialect)

  SecondTestDialect(MLIRContext *context)
      : Dialect(getDialectNamespace(), context,
                TypeID::get<SecondTestDialect>()) {}

  /// Namespace string passed to the `Dialect` base constructor.
  static StringRef getDialectNamespace() { return "test2"; }
};
/// Common base of the dialect interfaces used in these tests. The tests cast
/// dialects to this type (or to a derived one) to see whether an interface
/// has been attached.
struct TestDialectInterfaceBase
    : DialectInterface::Base<TestDialectInterfaceBase> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterfaceBase)

  /// Overridable hook; the base implementation returns 42.
  virtual int function() const { return 42; }

  TestDialectInterfaceBase(Dialect *dialect) : Base(dialect) {}
};
/// Concrete interface that the tests attach to `TestDialect`.
struct TestDialectInterface : TestDialectInterfaceBase {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterface)

  /// Final override of the base hook; returns 56.
  int function() const final { return 56; }

  // Reuse the base-class constructor taking the owning dialect.
  using TestDialectInterfaceBase::TestDialectInterfaceBase;
};
/// Concrete interface that the tests attach to `SecondTestDialect`.
struct SecondTestDialectInterface : TestDialectInterfaceBase {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialectInterface)

  /// Final override of the base hook; returns 78.
  int function() const final { return 78; }

  // Reuse the base-class constructor taking the owning dialect.
  using TestDialectInterfaceBase::TestDialectInterfaceBase;
};
/// Checks interface attachment through registry extensions in two situations:
/// the extension is registered before the dialect is loaded, and the
/// extension arrives after the dialect has already been loaded.
TEST(Dialect, DelayedInterfaceRegistration) {
  DialectRegistry initialRegistry;
  initialRegistry.insert<TestDialect, SecondTestDialect>();

  // Extension that attaches TestDialectInterface to TestDialect only. The
  // unary `+` turns the lambda into a plain function pointer.
  auto *attachToFirst = +[](MLIRContext *, TestDialect *dialect) {
    dialect->addInterfaces<TestDialectInterface>();
  };
  initialRegistry.addExtension(attachToFirst);

  MLIRContext context(initialRegistry);

  // TestDialect must expose the interface once it has been loaded.
  Dialect *firstDialect = context.getOrLoadDialect<TestDialect>();
  ASSERT_TRUE(firstDialect != nullptr);
  EXPECT_TRUE(dyn_cast<TestDialectInterfaceBase>(firstDialect) != nullptr);

  // No extension targets SecondTestDialect yet, so the cast must fail.
  Dialect *secondDialect = context.getOrLoadDialect<SecondTestDialect>();
  ASSERT_TRUE(secondDialect != nullptr);
  EXPECT_TRUE(dyn_cast<SecondTestDialectInterface>(secondDialect) == nullptr);

  // Append a registry carrying an extension for the already loaded
  // SecondTestDialect; afterwards the interface must be present.
  DialectRegistry lateRegistry;
  lateRegistry.insert<SecondTestDialect>();
  auto *attachToSecond = +[](MLIRContext *, SecondTestDialect *dialect) {
    dialect->addInterfaces<SecondTestDialectInterface>();
  };
  lateRegistry.addExtension(attachToSecond);
  context.appendDialectRegistry(lateRegistry);

  EXPECT_TRUE(dyn_cast<SecondTestDialectInterface>(secondDialect) != nullptr);
}
/// Registers the same interface for the same dialect twice, once before the
/// dialect is loaded and once afterwards, and checks that the interface is
/// still reachable after the second registration.
TEST(Dialect, RepeatedDelayedRegistration) {
  DialectRegistry initialRegistry;
  initialRegistry.insert<TestDialect>();
  auto *attachFirstTime = +[](MLIRContext *, TestDialect *dialect) {
    dialect->addInterfaces<TestDialectInterface>();
  };
  initialRegistry.addExtension(attachFirstTime);

  MLIRContext context(initialRegistry);

  Dialect *loadedDialect = context.getOrLoadDialect<TestDialect>();
  ASSERT_TRUE(loadedDialect != nullptr);
  EXPECT_TRUE(dyn_cast<TestDialectInterfaceBase>(loadedDialect) != nullptr);

  // Same registration again through a second registry.
  // NOTE(review): this is intentionally a separate lambda object rather than
  // a reuse of the one above; whether the registry tells extensions apart by
  // function pointer is not visible from this file.
  DialectRegistry repeatedRegistry;
  repeatedRegistry.insert<TestDialect>();
  auto *attachSecondTime = +[](MLIRContext *, TestDialect *dialect) {
    dialect->addInterfaces<TestDialectInterface>();
  };
  repeatedRegistry.addExtension(attachSecondTime);
  context.appendDialectRegistry(repeatedRegistry);

  EXPECT_TRUE(dyn_cast<TestDialectInterfaceBase>(loadedDialect) != nullptr);
}
namespace {
/// Test extension for `TestDialect`. Each call to `apply` increments the
/// counter it was given and then appends `numRecursive` further
/// `DummyExtension`s to the context; those nested extensions share the same
/// counter and append nothing themselves.
struct DummyExtension : DialectExtension<DummyExtension, TestDialect> {
  DummyExtension(int *counter, int numRecursive)
      : DialectExtension(), applyCount(counter), fanOut(numRecursive) {}

  void apply(MLIRContext *ctx, TestDialect *dialect) const final {
    *applyCount += 1;

    // Gather the nested extensions in a scratch registry, then hand them to
    // the context in one call.
    DialectRegistry followUps;
    for (int remaining = fanOut; remaining > 0; --remaining)
      followUps.addExtension(std::make_unique<DummyExtension>(applyCount, 0));
    ctx->appendDialectRegistry(followUps);
  }

private:
  /// Incremented on every `apply`; points at storage owned by the test.
  int *applyCount;
  /// Number of nested extensions appended by each `apply`.
  int fanOut;
};
} // namespace
/// Registers an extension that appends 100 further extensions while it is
/// being applied, followed by a plain one, and checks that all of them ran.
TEST(Dialect, NestedDialectExtension) {
  int fanOutHits = 0;
  int plainHits = 0;

  DialectRegistry registry;
  registry.insert<TestDialect>();
  // First extension: appends 100 nested extensions sharing `fanOutHits`.
  registry.addExtension(std::make_unique<DummyExtension>(&fanOutHits, 100));
  // Second extension: appends nothing.
  registry.addExtension(std::make_unique<DummyExtension>(&plainHits, 0));

  // Loading the dialect is what triggers the extensions.
  MLIRContext context(registry);
  ASSERT_TRUE(context.getOrLoadDialect<TestDialect>() != nullptr);

  // Lower bounds only: 1 outer application + 100 nested ones, and 1 for the
  // plain extension.
  EXPECT_GE(fanOutHits, 101);
  EXPECT_GE(plainHits, 1);
}
}