#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/IR/OwningOpRef.h"
#include "gtest/gtest.h"
using namespace mlir;
namespace {
class IndexFolderTest : public testing::Test {
public:
IndexFolderTest() { ctx.getOrLoadDialect<index::IndexDialect>(); }
template <typename OpT>
void foldOp(IntegerAttr &value, Type type, ArrayRef<Attribute> operands);
protected:
MLIRContext ctx;
OpBuilder b{&ctx};
};
}
/// Creates a detached `OpT` whose single result has type `type`, invokes its
/// folder with the constant `operands`, and stores the folded integer
/// attribute in `value`.
///
/// `value` is reset to null before any other work. Every early exit then
/// leaves it null instead of holding the result of a previous call. Early
/// exits are a failed fold, or a gtest `ASSERT_*` failure, which only returns
/// from this helper. Callers can therefore rely on `ASSERT_TRUE(value)` /
/// `EXPECT_FALSE(value)`.
///
/// The function returns `void` because gtest's `ASSERT_*` macros may only be
/// used in functions returning `void`.
template <typename OpT>
void IndexFolderTest::foldOp(IntegerAttr &value, Type type,
                             ArrayRef<Attribute> operands) {
  value = nullptr;

  // No SSA operands are attached. The folder sees the constant inputs only
  // through `operands`.
  OperationState state(UnknownLoc::get(&ctx), OpT::getOperationName());
  state.addTypes(type);
  // `b` has no insertion point, so the op is detached. OwningOpRef destroys
  // it when this function returns.
  OwningOpRef<OpT> op = cast<OpT>(b.create(state));

  SmallVector<OpFoldResult> results;
  if (failed(op->getOperation()->fold(operands, results)))
    return; // `value` is already null.

  ASSERT_EQ(results.size(), 1u);
  // A fold to an existing SSA value, or to a non-integer attribute, yields
  // null here and is reported by the assertion below.
  value = dyn_cast_or_null<IntegerAttr>(dyn_cast<Attribute>(results.front()));
  ASSERT_TRUE(value);
}
TEST_F(IndexFolderTest, TestCastUOpFolder) {
  // Result of the most recent fold; null when the folder declined to fold.
  IntegerAttr folded;
  // Folds an `index::CastUOp` of the constant `input` to a `width`-bit integer.
  auto castTo = [&](unsigned width, Attribute input) {
    foldOp<index::CastUOp>(folded, b.getIntegerType(width), input);
  };

  // 16-bit target: only the low 16 bits of 8000000000 (0x5000) remain.
  castTo(16, b.getIndexAttr(8000000000));
  ASSERT_TRUE(folded);
  EXPECT_EQ(folded.getInt(), 20480u);

  // 64-bit target with a small value: folds unchanged.
  castTo(64, b.getIndexAttr(2000));
  ASSERT_TRUE(folded);
  EXPECT_EQ(folded.getInt(), 2000u);

  // 64-bit target with a value wider than 32 bits: expected not to fold.
  // NOTE(review): consistent with the folder refusing when the 32-bit and
  // 64-bit interpretations of the index disagree. Confirm against
  // CastUOp::fold.
  castTo(64, b.getIndexAttr(8000000000));
  EXPECT_FALSE(folded);

  // 40-bit target: 2^52 + 2^16 truncates to 2^16, so the fold succeeds.
  castTo(40, b.getIndexAttr(0x10000000010000));
  ASSERT_TRUE(folded);
  EXPECT_EQ(folded.getInt(), 65536);

  // 60-bit target: the 2^52 bit is kept, so the fold is expected to fail.
  castTo(60, b.getIndexAttr(0x10000000010000));
  EXPECT_FALSE(folded);
}
TEST_F(IndexFolderTest, TestCastSOpFolder) {
  // Result of the most recent fold; null when the folder declined to fold.
  IntegerAttr folded;
  // Folds an `index::CastSOp` of the constant `input` to a `width`-bit integer.
  auto castTo = [&](unsigned width, Attribute input) {
    foldOp<index::CastSOp>(folded, b.getIntegerType(width), input);
  };

  // Negative inputs check that the sign survives the cast.

  // 64-bit target: -2000 folds unchanged.
  castTo(64, b.getIndexAttr(-2000));
  ASSERT_TRUE(folded);
  EXPECT_EQ(folded.getInt(), -2000);

  // 40-bit target: -(2^52 + 2^16) truncates to -2^16 because the 2^52 term
  // vanishes modulo 2^40, so the fold succeeds.
  castTo(40, b.getIndexAttr(-0x10000000010000));
  ASSERT_TRUE(folded);
  EXPECT_EQ(folded.getInt(), -65536);
}