#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
using namespace mlir::arith;
using namespace mlir::scf;
namespace {
/// Test pass exercising the region-builder overload of `scf.while`'s build
/// function.
///
/// For every `scf.while` nested in the function, a second `scf.while` with the
/// same operands and result types is created immediately before it. The new
/// op's regions hold only placeholder `UnrealizedConversionCastOp`s that adapt
/// the region arguments to whatever the terminator needs. The original
/// `scf.while` is left in place.
struct TestSCFWhileOpBuilderPass
    : public PassWrapper<TestSCFWhileOpBuilderPass,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFWhileOpBuilderPass)

  explicit TestSCFWhileOpBuilderPass() = default;
  TestSCFWhileOpBuilderPass(const TestSCFWhileOpBuilderPass &pass) = default;

  /// Command-line flag that selects this pass.
  StringRef getArgument() const final { return "test-scf-while-op-builder"; }

  /// One-line summary shown in the pass listing.
  StringRef getDescription() const final {
    return "test build functions of scf.while";
  }

  void runOnOperation() override {
    getOperation().walk([](WhileOp existing) {
      Location whileLoc = existing->getLoc();
      // Construct the builder from the op so new IR lands right before it.
      ImplicitLocOpBuilder topBuilder(whileLoc, existing);

      // Views into `existing`. It is never erased, so they remain valid while
      // the region-builder callbacks below run.
      TypeRange loopResultTypes = existing->getResultTypes();
      ValueRange loopInits = existing->getOperands();

      // "before" region: its arguments carry the operand types, but
      // scf.condition must forward values of the result types. Cast them,
      // then emit an i1 constant (both arguments are 1) as the condition.
      auto buildBefore = [&loopResultTypes](OpBuilder &nested,
                                            Location nestedLoc,
                                            ValueRange beforeArgs) {
        ImplicitLocOpBuilder regionBuilder(nestedLoc, nested);
        auto toResultTypes = regionBuilder.create<UnrealizedConversionCastOp>(
            loopResultTypes, beforeArgs);
        auto condValue = regionBuilder.create<ConstantIntOp>(1, 1);
        regionBuilder.create<ConditionOp>(condValue,
                                          toResultTypes->getResults());
      };

      // "after" region: its arguments carry the result types, but scf.yield
      // must produce values of the operand types. Cast them back and yield.
      auto buildAfter = [&loopInits](OpBuilder &nested, Location nestedLoc,
                                     ValueRange afterArgs) {
        ImplicitLocOpBuilder regionBuilder(nestedLoc, nested);
        auto toInitTypes = regionBuilder.create<UnrealizedConversionCastOp>(
            loopInits.getTypes(), afterArgs);
        regionBuilder.create<YieldOp>(toInitTypes->getResults());
      };

      topBuilder.create<WhileOp>(whileLoc, loopResultTypes, loopInits,
                                 buildBefore, buildAfter);
    });
  }
};
} // namespace
namespace mlir {
namespace test {
void registerTestSCFWhileOpBuilderPass() {
PassRegistration<TestSCFWhileOpBuilderPass>();
}
}
}