#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include <functional>
using namespace mlir;
#define PASS_NAME "convert-cf-to-llvm"
namespace {
/// Conversion pattern for `cf::AssertOp`.
///
/// The block holding the assert is split at the current insertion point, a
/// new block that calls `abort` and ends in `LLVM::UnreachableOp` is created
/// in the same region, and the assert itself is replaced by an
/// `LLVM::CondBrOp` on the adaptor's operand whose destinations are the
/// continuation block and the aborting block (passed in that order).
struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
  using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location assertLoc = op.getLoc();

    // Make sure an `abort` declaration is available in the enclosing module.
    LLVM::LLVMFuncOp abortDecl =
        lookupOrDeclareAbort(op->getParentOfType<ModuleOp>(), rewriter);

    // Cut the current block in two at the rewriter's insertion point; the
    // second half becomes the continuation.
    Block *headBlock = rewriter.getInsertionBlock();
    Block *tailBlock =
        rewriter.splitBlock(headBlock, rewriter.getInsertionPoint());

    // Build the failure path: call `abort`, then mark the end unreachable.
    // `createBlock` leaves the insertion point inside the new block.
    Block *abortBlock = rewriter.createBlock(headBlock->getParent());
    rewriter.create<LLVM::CallOp>(assertLoc, abortDecl, llvm::None);
    rewriter.create<LLVM::UnreachableOp>(assertLoc);

    // Terminate the head block with the branch that replaces the assert.
    rewriter.setInsertionPointToEnd(headBlock);
    rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(op, adaptor.getArg(),
                                                tailBlock, abortBlock);
    return success();
  }

private:
  /// Returns the `LLVM::LLVMFuncOp` named "abort" found in `module`. If the
  /// lookup yields nothing, a function with a void, zero-argument type is
  /// created at the start of the module body; the rewriter's insertion point
  /// is restored before returning.
  LLVM::LLVMFuncOp
  lookupOrDeclareAbort(ModuleOp module,
                       ConversionPatternRewriter &rewriter) const {
    if (auto existing = module.lookupSymbol<LLVM::LLVMFuncOp>("abort"))
      return existing;

    OpBuilder::InsertionGuard restoreInsertionPoint(rewriter);
    rewriter.setInsertionPointToStart(module.getBody());
    auto abortTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
    return rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(), "abort",
                                             abortTy);
  }
};
/// Generic pattern that replaces a `SourceOp` with a `TargetOp` built from
/// the adaptor's operands together with the source operation's successors
/// and attributes.
template <typename SourceOp, typename TargetOp>
struct OneToOneLLVMTerminatorLowering
    : public ConvertOpToLLVMPattern<SourceOp> {
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;

  /// Convenience alias for this instantiation, available to derived patterns
  /// (for example to inherit its constructors).
  using Base = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;

  LogicalResult
  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Gather the three ingredients of the replacement, then build it.
    auto operands = adaptor.getOperands();
    auto successors = op->getSuccessors();
    auto attributes = op->getAttrs();
    rewriter.replaceOpWithNewOp<TargetOp>(op, operands, successors,
                                          attributes);
    return success();
  }
};
/// Rewrites `cf::BranchOp` as `LLVM::BrOp` through the one-to-one terminator
/// lowering.
struct BranchOpLowering
    : public OneToOneLLVMTerminatorLowering<cf::BranchOp, LLVM::BrOp> {
  // Inherit the base pattern's constructors.
  using OneToOneLLVMTerminatorLowering<
      cf::BranchOp, LLVM::BrOp>::OneToOneLLVMTerminatorLowering;
};
/// Rewrites `cf::CondBranchOp` as `LLVM::CondBrOp` through the one-to-one
/// terminator lowering.
struct CondBranchOpLowering
    : public OneToOneLLVMTerminatorLowering<cf::CondBranchOp, LLVM::CondBrOp> {
  // Inherit the base pattern's constructors.
  using OneToOneLLVMTerminatorLowering<
      cf::CondBranchOp, LLVM::CondBrOp>::OneToOneLLVMTerminatorLowering;
};
/// Rewrites `cf::SwitchOp` as `LLVM::SwitchOp` through the one-to-one
/// terminator lowering.
struct SwitchOpLowering
    : public OneToOneLLVMTerminatorLowering<cf::SwitchOp, LLVM::SwitchOp> {
  // Inherit the base pattern's constructors.
  using OneToOneLLVMTerminatorLowering<
      cf::SwitchOp, LLVM::SwitchOp>::OneToOneLLVMTerminatorLowering;
};
}
/// Appends the ControlFlow-to-LLVM lowering patterns defined in this file to
/// `patterns`, constructing each of them with `converter`.
void mlir::cf::populateControlFlowToLLVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  // Assertion lowering first, followed by the one-to-one terminator
  // lowerings; the overall registration order is assert, branch,
  // conditional branch, switch.
  patterns.add<AssertOpLowering>(converter);
  patterns.add<BranchOpLowering, CondBranchOpLowering, SwitchOpLowering>(
      converter);
}
namespace {
/// Pass that runs the ControlFlow-to-LLVM patterns over the current
/// operation as a partial conversion against an `LLVMConversionTarget`.
struct ConvertControlFlowToLLVM
    : public ConvertControlFlowToLLVMBase<ConvertControlFlowToLLVM> {
  ConvertControlFlowToLLVM() = default;

  /// Builds the type converter and pattern set, applies the conversion, and
  /// signals pass failure if the conversion fails.
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();

    // Lowering options: the index bitwidth is overridden only when the pass
    // option differs from the "derive from data layout" sentinel.
    LowerToLLVMOptions loweringOptions(ctx);
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      loweringOptions.overrideIndexBitwidth(indexBitwidth);
    LLVMTypeConverter typeConverter(ctx, loweringOptions);

    RewritePatternSet cfPatterns(ctx);
    mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
                                                          cfPatterns);

    LLVMConversionTarget llvmTarget(*ctx);
    LogicalResult status = applyPartialConversion(getOperation(), llvmTarget,
                                                  std::move(cfPatterns));
    if (failed(status))
      signalPassFailure();
  }
};
}
/// Creates a new, default-constructed instance of the ControlFlow-to-LLVM
/// conversion pass.
std::unique_ptr<Pass> mlir::cf::createConvertControlFlowToLLVMPass() {
  std::unique_ptr<Pass> pass = std::make_unique<ConvertControlFlowToLLVM>();
  return pass;
}