#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
namespace mlir {
#define GEN_PASS_DEF_SCFFORLOOPCANONICALIZATION
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
}
using namespace mlir;
using namespace mlir::scf;
/// Conservatively checks whether the value yielded at position `arg` of
/// `forOp` can be traced back to the loop's own region iter_arg at the same
/// position.
///
/// Starting from the yielded value, the use-def chain is followed backwards:
///   * through the destination operand of a `tensor.insert_slice`, and
///   * through a nested `scf.for` result to the matching init operand of that
///     nested loop, provided the nested loop passes this same check.
/// Any other producer, or a block argument other than the expected iter_arg,
/// ends the walk with `false`.
///
/// \param forOp the loop being inspected.
/// \param arg   zero-based index into the loop's results / iter_args.
/// \returns true iff the walk reaches `forOp.getRegionIterArgs()[arg]`.
static bool isShapePreserving(ForOp forOp, int64_t arg) {
  assert(arg < static_cast<int64_t>(forOp.getNumResults()) &&
         "arg is out of bounds");

  Value current = forOp.getYieldedValues()[arg];
  // The IR is not modified during the walk, so the target can be looked up
  // once up front.
  const Value targetIterArg = forOp.getRegionIterArgs()[arg];

  while (current) {
    if (current == targetIterArg)
      return true;

    // Reaching a block argument that is not the target iter_arg means the
    // chain cannot be followed any further.
    auto producedResult = dyn_cast<OpResult>(current);
    if (!producedResult)
      return false;

    Operation *producer = producedResult.getOwner();
    if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(producer)) {
      // Keep following the tensor that is being inserted into.
      current = insertSlice.getDest();
    } else if (auto innerLoop = dyn_cast<ForOp>(producer)) {
      // Only step through a nested loop result when that loop passes the same
      // check for this result; otherwise abandon the walk.
      const unsigned resultIdx = producedResult.getResultNumber();
      current = isShapePreserving(innerLoop, resultIdx)
                    ? innerLoop.getInitArgs()[resultIdx]
                    : Value();
    } else {
      // Unknown producer: a null value terminates the loop.
      current = Value();
    }
  }
  return false;
}
namespace {
/// Rewrites a dim-like op (`OpTy`) whose source is a block argument of an
/// `scf.for` body so that it reads from the loop's tied init operand instead.
///
/// The rewrite only fires when `isShapePreserving` succeeds for the
/// corresponding loop position; in every other case the pattern fails to
/// match and the op is left untouched.
template <typename OpTy>
struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy dimOp,
                                PatternRewriter &rewriter) const override {
    // The source must be a block argument...
    auto iterArg = dyn_cast<BlockArgument>(dimOp.getSource());
    if (!iterArg)
      return failure();

    // ...whose block is directly owned by an scf.for...
    auto loop = dyn_cast<ForOp>(iterArg.getParentBlock()->getParentOp());
    // ...and the loop must pass the check at that position. The block
    // argument number is shifted down by one to obtain the index used for
    // results / iter_args (presumably to skip the induction variable, which
    // is the leading block argument of an scf.for body).
    if (!loop || !isShapePreserving(loop, iterArg.getArgNumber() - 1))
      return failure();

    // Redirect the dim op to the value the iter_arg is initialized with.
    Value replacementSource = loop.getTiedLoopInit(iterArg)->get();
    auto retargetSource = [&]() {
      dimOp.getSourceMutable().assign(replacementSource);
    };
    rewriter.modifyOpInPlace(dimOp, retargetSource);
    return success();
  }
};
/// Rewrites a dim-like op (`OpTy`) whose source is a result of an `scf.for`
/// so that it reads from the loop's init operand at the same position.
///
/// The rewrite only fires when `isShapePreserving` succeeds for that result
/// number; otherwise the pattern fails to match.
template <typename OpTy>
struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy dimOp,
                                PatternRewriter &rewriter) const override {
    // The source has to be an op result (not a block argument)...
    auto loopResult = dyn_cast<OpResult>(dimOp.getSource());
    if (!loopResult)
      return failure();

    // ...produced by an scf.for.
    auto loop = dyn_cast<scf::ForOp>(loopResult.getOwner());
    if (!loop)
      return failure();

    const unsigned resultIdx = loopResult.getResultNumber();
    if (!isShapePreserving(loop, resultIdx))
      return failure();

    // Query the dimension on the loop's initial value instead of its result.
    Value replacementSource = loop.getInitArgs()[resultIdx];
    auto retargetSource = [&]() {
      dimOp.getSourceMutable().assign(replacementSource);
    };
    rewriter.modifyOpInPlace(dimOp, retargetSource);
    return success();
  }
};
/// Pattern adaptor that forwards every matched `OpTy` to
/// `scf::canonicalizeMinMaxOpInLoop`, using `scf::matchForLikeLoop` as the
/// loop matcher. The result of that helper is returned unchanged.
template <typename OpTy>
struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    LogicalResult outcome =
        scf::canonicalizeMinMaxOpInLoop(rewriter, op, scf::matchForLikeLoop);
    return outcome;
  }
};
/// Pass that applies the SCF for-loop canonicalization patterns to the
/// operation it runs on, using the greedy pattern rewrite driver.
struct SCFForLoopCanonicalization
    : public impl::SCFForLoopCanonicalizationBase<SCFForLoopCanonicalization> {
  /// Collects the patterns and runs the greedy driver; signals pass failure
  /// when the driver reports failure.
  void runOnOperation() override {
    Operation *root = getOperation();

    RewritePatternSet patternSet(root->getContext());
    scf::populateSCFForLoopCanonicalizationPatterns(patternSet);

    LogicalResult status =
        applyPatternsAndFoldGreedily(root, std::move(patternSet));
    if (failed(status))
      signalPassFailure();
  }
};
}
/// Adds the for-loop canonicalization patterns defined in this file to
/// `patterns`:
///   * affine.min / affine.max canonicalization within loops,
///   * dim-of-iter_arg folding for tensor.dim and memref.dim,
///   * dim-of-loop-result folding for tensor.dim and memref.dim.
///
/// \param patterns the set to extend; its context is used to construct the
///                 patterns.
void mlir::scf::populateSCFForLoopCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();

  // Registration order matches a single combined `add` call.
  patterns.add<AffineOpSCFCanonicalizationPattern<affine::AffineMinOp>,
               AffineOpSCFCanonicalizationPattern<affine::AffineMaxOp>>(
      context);
  patterns.add<DimOfIterArgFolder<tensor::DimOp>,
               DimOfIterArgFolder<memref::DimOp>>(context);
  patterns.add<DimOfLoopResultFolder<tensor::DimOp>,
               DimOfLoopResultFolder<memref::DimOp>>(context);
}
/// Factory for the SCF for-loop canonicalization pass.
/// \returns a newly allocated `SCFForLoopCanonicalization` owned by the
///          caller.
std::unique_ptr<Pass> mlir::createSCFForLoopCanonicalizationPass() {
  std::unique_ptr<Pass> pass = std::make_unique<SCFForLoopCanonicalization>();
  return pass;
}