#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
namespace {
/// Test pass that runs the Arith wide-integer emulation patterns
/// (`arith::populateArithWideIntEmulationPatterns`) on a function, using an
/// `arith::WideIntEmulationConverter` configured with the
/// `widest-int-supported` option.
///
/// Only functions whose symbol name starts with the `function-prefix` option
/// are processed; all other functions are skipped.
struct TestEmulateWideIntPass
    : public PassWrapper<TestEmulateWideIntPass, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEmulateWideIntPass)

  TestEmulateWideIntPass() = default;

  /// Copy constructor: forwards only to the base; the option members below
  /// are rebuilt from their in-class initializers rather than copied here.
  TestEmulateWideIntPass(const TestEmulateWideIntPass &pass)
      : PassWrapper(pass) {}

  /// Declares the dialects whose ops this pass may create, e.g. the
  /// `llvm.bitcast` used for type materializations in `runOnOperation`.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<arith::ArithDialect, func::FuncDialect, LLVM::LLVMDialect,
                    vector::VectorDialect>();
  }

  StringRef getArgument() const final { return "test-arith-emulate-wide-int"; }
  StringRef getDescription() const final {
    return "Function pass to test Wide Integer Emulation";
  }

  /// Validates the options, then applies the wide-integer emulation patterns
  /// as a partial conversion on the current function (if its name matches
  /// the configured prefix). Signals pass failure on invalid options or when
  /// the conversion fails.
  void runOnOperation() override {
    func::FuncOp op = getOperation();
    const unsigned width = widestIntSupported;

    // The supported width must be a power of two and at least 2. Report the
    // bad value instead of failing silently.
    if (!llvm::isPowerOf2_32(width) || width < 2) {
      op.emitError() << "'widest-int-supported' must be a power of two >= 2, "
                        "got "
                     << width;
      signalPassFailure();
      return;
    }

    if (!op.getSymName().starts_with(testFunctionPrefix))
      return;

    MLIRContext *ctx = op.getContext();
    arith::WideIntEmulationConverter typeConverter(width);

    // Materialize conversions between original and emulated types with a
    // single `llvm.bitcast`. That op takes exactly one operand, so decline
    // (return std::nullopt) for any other input count instead of building a
    // malformed cast.
    auto addBitcast = [](OpBuilder &builder, Type type, ValueRange inputs,
                         Location loc) -> std::optional<Value> {
      if (inputs.size() != 1)
        return std::nullopt;
      auto cast = builder.create<LLVM::BitcastOp>(loc, type, inputs);
      return cast->getResult(0);
    };
    typeConverter.addSourceMaterialization(addBitcast);
    typeConverter.addTargetMaterialization(addBitcast);

    // arith/vector ops are legal exactly when the type converter considers all
    // of their types legal (presumably: no integers wider than `width`).
    ConversionTarget target(*ctx);
    target
        .addDynamicallyLegalDialect<arith::ArithDialect, vector::VectorDialect>(
            [&typeConverter](Operation *op) {
              return typeConverter.isLegal(op);
            });

    RewritePatternSet patterns(ctx);
    arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);

    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }

  /// Only functions whose symbol name starts with this prefix are processed.
  Option<std::string> testFunctionPrefix{
      *this, "function-prefix",
      llvm::cl::desc("Prefix of functions to run the emulation pass on"),
      llvm::cl::init("emulate_")};

  /// Widest integer bit width treated as natively supported. It must be a
  /// power of two >= 2; this is checked in `runOnOperation`.
  Option<unsigned> widestIntSupported{
      *this, "widest-int-supported",
      llvm::cl::desc("Maximum integer bit width supported by the target"),
      llvm::cl::init(32)};
};
} // namespace
namespace mlir::test {
void registerTestArithEmulateWideIntPass() {
PassRegistration<TestEmulateWideIntPass>();
}
}