#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
namespace mlir {
// Request the ConvertAffineToStandard pass definition from the generated
// pass include file; the `impl::ConvertAffineToStandardBase` class used by
// LowerAffinePass later in this file is expected to come from this include.
#define GEN_PASS_DEF_CONVERTAFFINETOSTANDARD
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
using namespace mlir::affine;
using namespace mlir::vector;
/// Folds `values` left-to-right into a single value using signed integer
/// max/min operations built with `builder` at `loc`.
///
/// `predicate` selects the operation: `sgt` emits `arith::MaxSIOp`, `slt`
/// emits `arith::MinSIOp`. Any other predicate is rejected by assertion, as
/// is an empty `values` range. With a single value no operation is created
/// and that value is returned unchanged.
static Value buildMinMaxReductionSeq(Location loc,
                                     arith::CmpIPredicate predicate,
                                     ValueRange values, OpBuilder &builder) {
  assert(!values.empty() && "empty min/max chain");
  assert(predicate == arith::CmpIPredicate::sgt ||
         predicate == arith::CmpIPredicate::slt);

  // The predicate is loop-invariant, so decide once which op to emit.
  const bool emitMax = predicate == arith::CmpIPredicate::sgt;

  Value accumulator = values.front();
  for (Value next : values.drop_front()) {
    if (emitMax)
      accumulator = builder.create<arith::MaxSIOp>(loc, accumulator, next);
    else
      accumulator = builder.create<arith::MinSIOp>(loc, accumulator, next);
  }
  return accumulator;
}
/// Expands every result of `map` applied to `operands` and reduces the
/// expanded values to their signed maximum. Returns a null Value when the
/// map cannot be expanded.
static Value lowerAffineMapMax(OpBuilder &builder, Location loc, AffineMap map,
                               ValueRange operands) {
  auto expanded = expandAffineMap(builder, loc, map, operands);
  if (!expanded)
    return nullptr;
  return buildMinMaxReductionSeq(loc, arith::CmpIPredicate::sgt, *expanded,
                                 builder);
}
/// Expands every result of `map` applied to `operands` and reduces the
/// expanded values to their signed minimum. Returns a null Value when the
/// map cannot be expanded.
static Value lowerAffineMapMin(OpBuilder &builder, Location loc, AffineMap map,
                               ValueRange operands) {
  auto expanded = expandAffineMap(builder, loc, map, operands);
  if (!expanded)
    return nullptr;
  return buildMinMaxReductionSeq(loc, arith::CmpIPredicate::slt, *expanded,
                                 builder);
}
/// Emits code computing the upper bound of the given affine loop as the
/// minimum over all results of its upper-bound map. Returns a null Value if
/// the map could not be expanded.
Value mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) {
  AffineMap boundMap = op.getUpperBoundMap();
  ValueRange boundOperands = op.getUpperBoundOperands();
  return lowerAffineMapMin(builder, op.getLoc(), boundMap, boundOperands);
}
/// Emits code computing the lower bound of the given affine loop as the
/// maximum over all results of its lower-bound map. Returns a null Value if
/// the map could not be expanded.
Value mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) {
  AffineMap boundMap = op.getLowerBoundMap();
  ValueRange boundOperands = op.getLowerBoundOperands();
  return lowerAffineMapMax(builder, op.getLoc(), boundMap, boundOperands);
}
namespace {
/// Rewrites an AffineMinOp into the expanded results of its map combined by
/// a chain of signed-minimum operations.
class AffineMinLowering : public OpRewritePattern<AffineMinOp> {
public:
  using OpRewritePattern<AffineMinOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineMinOp op,
                                PatternRewriter &rewriter) const override {
    // A null result means the map could not be expanded; leave the op alone.
    if (Value minimum = lowerAffineMapMin(rewriter, op.getLoc(), op.getMap(),
                                          op.getOperands())) {
      rewriter.replaceOp(op, minimum);
      return success();
    }
    return failure();
  }
};
/// Rewrites an AffineMaxOp into the expanded results of its map combined by
/// a chain of signed-maximum operations.
class AffineMaxLowering : public OpRewritePattern<AffineMaxOp> {
public:
  using OpRewritePattern<AffineMaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineMaxOp op,
                                PatternRewriter &rewriter) const override {
    // A null result means the map could not be expanded; leave the op alone.
    if (Value maximum = lowerAffineMapMax(rewriter, op.getLoc(), op.getMap(),
                                          op.getOperands())) {
      rewriter.replaceOp(op, maximum);
      return success();
    }
    return failure();
  }
};
/// Rewrites an AffineYieldOp into an scf::YieldOp carrying the same operands.
class AffineYieldOpLowering : public OpRewritePattern<AffineYieldOp> {
public:
  using OpRewritePattern<AffineYieldOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineYieldOp op,
                                PatternRewriter &rewriter) const override {
    // A yield sitting directly inside an scf::ParallelOp is not handled here:
    // AffineParallelLowering replaces that terminator itself with an
    // scf::ReduceOp.
    Operation *parent = op->getParentOp();
    if (isa<scf::ParallelOp>(parent))
      return failure();

    rewriter.replaceOpWithNewOp<scf::YieldOp>(op, op.getOperands());
    return success();
  }
};
/// Rewrites an AffineForOp into an scf::ForOp.
///
/// The lower bound is the maximum over the results of the lower-bound map,
/// the upper bound is the minimum over the results of the upper-bound map,
/// and the step becomes an index constant. The original loop body region is
/// moved into the new loop, and the loop-carried init values are forwarded.
class AffineForLowering : public OpRewritePattern<AffineForOp> {
public:
  using OpRewritePattern<AffineForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineForOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // The bound helpers return a null Value when the affine map cannot be
    // expanded. Bail out instead of feeding a null operand to scf::ForOp,
    // mirroring the handling in AffineParallelLowering.
    Value lowerBound = lowerAffineLowerBound(op, rewriter);
    if (!lowerBound)
      return rewriter.notifyMatchFailure(op, "couldn't convert lower bounds");
    Value upperBound = lowerAffineUpperBound(op, rewriter);
    if (!upperBound)
      return rewriter.notifyMatchFailure(op, "couldn't convert upper bounds");

    Value step =
        rewriter.create<arith::ConstantIndexOp>(loc, op.getStepAsInt());
    auto scfForOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound,
                                                step, op.getInits());

    // Drop the body block the builder created and move the affine loop's
    // region into its place.
    rewriter.eraseBlock(scfForOp.getBody());
    rewriter.inlineRegionBefore(op.getRegion(), scfForOp.getRegion(),
                                scfForOp.getRegion().end());
    rewriter.replaceOp(op, scfForOp.getResults());
    return success();
  }
};
/// Rewrites an AffineParallelOp into an scf::ParallelOp.
///
/// Each dimension's lower bound is the maximum over the results of its
/// lower-bound map and each upper bound is the minimum over the results of
/// its upper-bound map; steps become index constants. The affine terminator
/// is replaced by an scf::ReduceOp. When the op has results, each reduction
/// gets an identity init value and a reduction body combining its two block
/// arguments with the operation matching the reduction kind.
class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
public:
  using OpRewritePattern<AffineParallelOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineParallelOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    SmallVector<Value, 8> steps;
    SmallVector<Value, 8> upperBoundTuple;
    SmallVector<Value, 8> lowerBoundTuple;
    SmallVector<Value, 8> identityVals;

    // Emit the per-dimension bounds, giving up if any map fails to expand.
    lowerBoundTuple.reserve(op.getNumDims());
    upperBoundTuple.reserve(op.getNumDims());
    for (unsigned i = 0, e = op.getNumDims(); i < e; ++i) {
      Value lower = lowerAffineMapMax(rewriter, loc, op.getLowerBoundMap(i),
                                      op.getLowerBoundsOperands());
      if (!lower)
        return rewriter.notifyMatchFailure(op, "couldn't convert lower bounds");
      lowerBoundTuple.push_back(lower);

      Value upper = lowerAffineMapMin(rewriter, loc, op.getUpperBoundMap(i),
                                      op.getUpperBoundsOperands());
      if (!upper)
        return rewriter.notifyMatchFailure(op, "couldn't convert upper bounds");
      upperBoundTuple.push_back(upper);
    }

    steps.reserve(op.getSteps().size());
    for (int64_t step : op.getSteps())
      steps.push_back(rewriter.create<arith::ConstantIndexOp>(loc, step));

    // Grab the affine terminator before the region is moved; it is replaced
    // by an scf::ReduceOp once it lives inside the new parallel op.
    auto affineParOpTerminator =
        cast<AffineYieldOp>(op.getBody()->getTerminator());
    scf::ParallelOp parOp;

    // Without results there is nothing to reduce: move the body over and
    // terminate it with an operand-less scf::ReduceOp.
    if (op.getResults().empty()) {
      parOp = rewriter.create<scf::ParallelOp>(loc, lowerBoundTuple,
                                               upperBoundTuple, steps,
                                               /*bodyBuilderFn=*/nullptr);
      rewriter.eraseBlock(parOp.getBody());
      rewriter.inlineRegionBefore(op.getRegion(), parOp.getRegion(),
                                  parOp.getRegion().end());
      rewriter.replaceOp(op, parOp.getResults());
      rewriter.setInsertionPoint(affineParOpTerminator);
      rewriter.replaceOpWithNewOp<scf::ReduceOp>(affineParOpTerminator);
      return success();
    }

    // With results, every reduction needs an identity value to seed the
    // scf::ParallelOp init operands.
    ArrayRef<Attribute> reductions = op.getReductions().getValue();
    for (auto pair : llvm::zip(reductions, op.getResultTypes())) {
      Attribute reduction = std::get<0>(pair);
      Type resultType = std::get<1>(pair);
      arith::AtomicRMWKind reductionOpValue = getReductionKind(reduction);
      identityVals.push_back(
          arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc));
    }
    parOp = rewriter.create<scf::ParallelOp>(
        loc, lowerBoundTuple, upperBoundTuple, steps, identityVals,
        /*bodyBuilderFn=*/nullptr);

    // Replace the builder-created body with the affine op's region.
    rewriter.eraseBlock(parOp.getBody());
    rewriter.inlineRegionBefore(op.getRegion(), parOp.getRegion(),
                                parOp.getRegion().end());
    assert(reductions.size() == affineParOpTerminator->getNumOperands() &&
           "Unequal number of reductions and operands.");

    // Turn the affine terminator into an scf::ReduceOp over the same operands.
    rewriter.setInsertionPoint(affineParOpTerminator);
    auto reduceOp = rewriter.replaceOpWithNewOp<scf::ReduceOp>(
        affineParOpTerminator, affineParOpTerminator->getOperands());

    // Populate each reduction body with the combining operation followed by
    // the scf::ReduceReturnOp terminator.
    for (unsigned i = 0, end = reductions.size(); i < end; i++) {
      arith::AtomicRMWKind reductionOpValue = getReductionKind(reductions[i]);
      Block &reductionBody = reduceOp.getReductions()[i].front();
      rewriter.setInsertionPointToEnd(&reductionBody);
      Value reductionResult = arith::getReductionOp(
          reductionOpValue, rewriter, loc, reductionBody.getArgument(0),
          reductionBody.getArgument(1));
      rewriter.create<scf::ReduceReturnOp>(loc, reductionResult);
    }
    rewriter.replaceOp(op, parOp.getResults());
    return success();
  }

private:
  /// Decodes the reduction kind stored as an integer attribute in the op's
  /// reductions array. Asserts that the integer names a known kind.
  static arith::AtomicRMWKind getReductionKind(Attribute reduction) {
    std::optional<arith::AtomicRMWKind> reductionOp =
        arith::symbolizeAtomicRMWKind(
            static_cast<uint64_t>(cast<IntegerAttr>(reduction).getInt()));
    assert(reductionOp && "Reduction operation cannot be of None Type");
    return *reductionOp;
  }
};
/// Rewrites an AffineIfOp into an scf::IfOp.
///
/// Every constraint of the integer set is expanded and compared against zero
/// (`eq` for equalities, `sge` otherwise); the comparisons are and-ed
/// together to form the condition. A set without constraints yields a
/// constant-true condition. The then/else regions are moved into the new op.
class AffineIfLowering : public OpRewritePattern<AffineIfOp> {
public:
  using OpRewritePattern<AffineIfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineIfOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    IntegerSet set = op.getIntegerSet();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

    SmallVector<Value, 8> allOperands(op.getOperands());
    ArrayRef<Value> operandView(allOperands);
    // Operands are split into dimension values followed by symbol values.
    const unsigned dimCount = set.getNumDims();

    Value cond = nullptr;
    const unsigned constraintCount = set.getNumConstraints();
    for (unsigned idx = 0; idx != constraintCount; ++idx) {
      Value expanded = expandAffineExpr(rewriter, loc, set.getConstraint(idx),
                                        operandView.take_front(dimCount),
                                        operandView.drop_front(dimCount));
      if (!expanded)
        return failure();

      arith::CmpIPredicate pred = arith::CmpIPredicate::sge;
      if (set.isEq(idx))
        pred = arith::CmpIPredicate::eq;
      Value holds = rewriter.create<arith::CmpIOp>(loc, pred, expanded, zero);

      // Accumulate the conjunction of all constraints seen so far.
      if (cond)
        cond = rewriter.create<arith::AndIOp>(loc, cond, holds).getResult();
      else
        cond = holds;
    }

    // No constraints at all: fall back to a 1-bit constant of value 1.
    if (!cond)
      cond = rewriter.create<arith::ConstantIntOp>(loc, /*value=*/1,
                                                   /*width=*/1);

    const bool hasElseRegion = !op.getElseRegion().empty();
    auto ifOp = rewriter.create<scf::IfOp>(loc, op.getResultTypes(), cond,
                                           hasElseRegion);

    // Move the affine regions in front of the builder-created blocks, then
    // discard those placeholder blocks.
    Region &newThen = ifOp.getThenRegion();
    rewriter.inlineRegionBefore(op.getThenRegion(), &newThen.back());
    rewriter.eraseBlock(&newThen.back());
    if (hasElseRegion) {
      Region &newElse = ifOp.getElseRegion();
      rewriter.inlineRegionBefore(op.getElseRegion(), &newElse.back());
      rewriter.eraseBlock(&newElse.back());
    }

    rewriter.replaceOp(op, ifOp.getResults());
    return success();
  }
};
/// Rewrites an AffineApplyOp into the values produced by expanding its
/// affine map over its operands.
class AffineApplyLowering : public OpRewritePattern<AffineApplyOp> {
public:
  using OpRewritePattern<AffineApplyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineApplyOp op,
                                PatternRewriter &rewriter) const override {
    auto applyOperands = llvm::to_vector<8>(op.getOperands());
    auto expanded = expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
                                    applyOperands);
    if (expanded) {
      rewriter.replaceOp(op, *expanded);
      return success();
    }
    return failure();
  }
};
/// Rewrites an AffineLoadOp into a memref::LoadOp whose indices are the
/// expanded results of the load's affine map.
class AffineLoadLowering : public OpRewritePattern<AffineLoadOp> {
public:
  using OpRewritePattern<AffineLoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineLoadOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 8> mapOperands(op.getMapOperands());
    auto loadIndices =
        expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), mapOperands);
    if (loadIndices) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(op, op.getMemRef(),
                                                  *loadIndices);
      return success();
    }
    return failure();
  }
};
/// Rewrites an AffinePrefetchOp into a memref::PrefetchOp whose indices are
/// the expanded results of the prefetch's affine map; the read/write flag,
/// locality hint and cache-type flag are forwarded unchanged.
class AffinePrefetchLowering : public OpRewritePattern<AffinePrefetchOp> {
public:
  using OpRewritePattern<AffinePrefetchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffinePrefetchOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 8> mapOperands(op.getMapOperands());
    auto prefetchIndices =
        expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), mapOperands);
    if (prefetchIndices) {
      rewriter.replaceOpWithNewOp<memref::PrefetchOp>(
          op, op.getMemref(), *prefetchIndices, op.getIsWrite(),
          op.getLocalityHint(), op.getIsDataCache());
      return success();
    }
    return failure();
  }
};
/// Rewrites an AffineStoreOp into a memref::StoreOp whose indices are the
/// expanded results of the store's affine map.
class AffineStoreLowering : public OpRewritePattern<AffineStoreOp> {
public:
  using OpRewritePattern<AffineStoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineStoreOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 8> mapOperands(op.getMapOperands());
    auto storeIndices =
        expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), mapOperands);
    if (storeIndices) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          op, op.getValueToStore(), op.getMemRef(), *storeIndices);
      return success();
    }
    return failure();
  }
};
/// Rewrites an AffineDmaStartOp into a memref::DmaStartOp, expanding the
/// source, destination and tag affine maps into explicit index values.
class AffineDmaStartLowering : public OpRewritePattern<AffineDmaStartOp> {
public:
  using OpRewritePattern<AffineDmaStartOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineDmaStartOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    SmallVector<Value, 8> allOperands(op.getOperands());
    ArrayRef<Value> operandView(allOperands);

    // Expands `map` over every operand that follows the memref located at
    // `memrefOperandIndex` in the operand list.
    auto expandAfterMemRef = [&](AffineMap map, unsigned memrefOperandIndex) {
      return expandAffineMap(rewriter, loc, map,
                             operandView.drop_front(memrefOperandIndex + 1));
    };

    // Expansion happens in source, destination, tag order; the first
    // failure aborts the rewrite.
    auto srcIndices =
        expandAfterMemRef(op.getSrcMap(), op.getSrcMemRefOperandIndex());
    if (!srcIndices)
      return failure();
    auto dstIndices =
        expandAfterMemRef(op.getDstMap(), op.getDstMemRefOperandIndex());
    if (!dstIndices)
      return failure();
    auto tagIndices =
        expandAfterMemRef(op.getTagMap(), op.getTagMemRefOperandIndex());
    if (!tagIndices)
      return failure();

    rewriter.replaceOpWithNewOp<memref::DmaStartOp>(
        op, op.getSrcMemRef(), *srcIndices, op.getDstMemRef(), *dstIndices,
        op.getNumElements(), op.getTagMemRef(), *tagIndices, op.getStride(),
        op.getNumElementsPerStride());
    return success();
  }
};
/// Rewrites an AffineDmaWaitOp into a memref::DmaWaitOp whose tag indices
/// are the expanded results of the tag affine map.
class AffineDmaWaitLowering : public OpRewritePattern<AffineDmaWaitOp> {
public:
  using OpRewritePattern<AffineDmaWaitOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineDmaWaitOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 8> tagOperands(op.getTagIndices());
    auto expandedTag =
        expandAffineMap(rewriter, op.getLoc(), op.getTagMap(), tagOperands);
    if (expandedTag) {
      rewriter.replaceOpWithNewOp<memref::DmaWaitOp>(
          op, op.getTagMemRef(), *expandedTag, op.getNumElements());
      return success();
    }
    return failure();
  }
};
/// Rewrites an AffineVectorLoadOp into a vector::LoadOp of the same vector
/// type, indexed by the expanded results of the load's affine map.
class AffineVectorLoadLowering : public OpRewritePattern<AffineVectorLoadOp> {
public:
  using OpRewritePattern<AffineVectorLoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineVectorLoadOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 8> mapOperands(op.getMapOperands());
    auto loadIndices =
        expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), mapOperands);
    if (loadIndices) {
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          op, op.getVectorType(), op.getMemRef(), *loadIndices);
      return success();
    }
    return failure();
  }
};
/// Rewrites an AffineVectorStoreOp into a vector::StoreOp indexed by the
/// expanded results of the store's affine map.
class AffineVectorStoreLowering : public OpRewritePattern<AffineVectorStoreOp> {
public:
  using OpRewritePattern<AffineVectorStoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineVectorStoreOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 8> mapOperands(op.getMapOperands());
    auto storeIndices =
        expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), mapOperands);
    if (storeIndices) {
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          op, op.getValueToStore(), op.getMemRef(), *storeIndices);
      return success();
    }
    return failure();
  }
};
}
/// Adds to `patterns` the rewrite patterns defined in this file for the
/// non-vector affine operations (apply, DMA start/wait, load, min, max,
/// parallel, prefetch, store, for, if and yield).
void mlir::populateAffineToStdConversionPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  // clang-format off
  patterns.add<AffineApplyLowering,
               AffineDmaStartLowering,
               AffineDmaWaitLowering,
               AffineLoadLowering,
               AffineMinLowering,
               AffineMaxLowering,
               AffineParallelLowering,
               AffinePrefetchLowering,
               AffineStoreLowering,
               AffineForLowering,
               AffineIfLowering,
               AffineYieldOpLowering>(context);
  // clang-format on
}
/// Adds to `patterns` the rewrite patterns defined in this file for the
/// affine vector load and vector store operations.
void mlir::populateAffineToVectorConversionPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<AffineVectorLoadLowering, AffineVectorStoreLowering>(context);
}
namespace {
/// Pass that applies the affine lowering patterns as a partial conversion,
/// with the arith, memref, scf and vector dialects marked legal.
class LowerAffinePass
    : public impl::ConvertAffineToStandardBase<LowerAffinePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();

    // Collect every lowering pattern: standard, vector and index-expansion.
    RewritePatternSet patterns(&context);
    populateAffineToStdConversionPatterns(patterns);
    populateAffineToVectorConversionPatterns(patterns);
    populateAffineExpandIndexOpsPatterns(patterns);

    ConversionTarget target(context);
    target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
                           scf::SCFDialect, VectorDialect>();

    LogicalResult status =
        applyPartialConversion(getOperation(), target, std::move(patterns));
    if (failed(status))
      signalPassFailure();
  }
};
}
/// Creates a new instance of the affine lowering pass.
std::unique_ptr<Pass> mlir::createLowerAffinePass() {
  std::unique_ptr<Pass> pass = std::make_unique<LowerAffinePass>();
  return pass;
}