#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
using namespace mlir;
using namespace mlir::tensor;
namespace {
/// Folds `tensor.expand_shape(tensor.extract_slice)` into a single
/// `tensor.extract_slice` when the expand_shape exactly undoes the rank
/// reduction performed by the extract_slice.
///
/// The type is inferred from the slice's source type and its static
/// offsets/sizes/strides via `ExtractSliceOp::inferResultType`. The pattern
/// fires only when that inferred type is identical to the expand_shape result
/// type. The expand_shape is then replaced by a new extract_slice that reads
/// from the original slice source with the same mixed offsets, sizes and
/// strides. The new op is built without an explicit result type, so the op
/// builder infers it.
struct FoldExpandOfRankReducingExtract
    : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
                                PatternRewriter &rewriter) const override {
    // The value being expanded must come straight from an extract_slice.
    auto producer = expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
    if (!producer)
      return failure();

    // Type the slice would have if it did not drop any dimensions.
    auto unreducedSliceType = ExtractSliceOp::inferResultType(
        producer.getSourceType(), producer.getStaticOffsets(),
        producer.getStaticSizes(), producer.getStaticStrides());

    // Only the "expand_shape disappears entirely" case is supported: the
    // expansion must restore precisely the un-reduced slice type.
    if (unreducedSliceType != expandShapeOp.getResultType())
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
        expandShapeOp, producer.getSource(), producer.getMixedOffsets(),
        producer.getMixedSizes(), producer.getMixedStrides());
    return success();
  }
};
/// Folds a `tensor.collapse_shape` whose source is a `tensor.extract_slice`
/// into one extract_slice that directly produces the collapsed type.
///
/// Preconditions checked before rewriting:
///  * the collapse_shape source is defined by an extract_slice that has
///    exactly one use. With other users the original slice would presumably
///    stay alive, so folding would not remove it.
///  * `isRankReducedType(collapseSrcType, collapseResultType)` reports
///    `Success`, i.e. the collapse only performs a rank reduction that a
///    slice op can express. Otherwise a match failure is reported with
///    "expected unpadding collapse".
/// The new extract_slice reuses the original slice's location, source, and
/// mixed offsets/sizes/strides, with the collapse_shape result type as its
/// explicit result type.
struct FoldUnPaddingCollapseIntoExtract
    : public OpRewritePattern<tensor::CollapseShapeOp> {
  using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
                                PatternRewriter &rewriter) const override {
    auto sliceOp =
        collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
    if (!sliceOp)
      return failure();
    // Folding is only worthwhile if the collapse is the slice's sole user.
    if (!sliceOp->hasOneUse())
      return failure();

    if (isRankReducedType(collapseShapeOp.getSrcType(),
                          collapseShapeOp.getResultType()) !=
        SliceVerificationResult::Success)
      return rewriter.notifyMatchFailure(collapseShapeOp,
                                         "expected unpadding collapse");

    auto collapsedType = collapseShapeOp.getResultType();
    Value foldedSlice = rewriter.create<tensor::ExtractSliceOp>(
        sliceOp.getLoc(), collapsedType, sliceOp.getSource(),
        sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
        sliceOp.getMixedStrides());
    rewriter.replaceOp(collapseShapeOp, foldedSlice);
    return success();
  }
};
/// Folds `OpTy(collapse_shape(x), dest)` into `OpTy(x, dest)`. `OpTy` is an
/// insert-slice-like op; it is instantiated for `tensor::InsertSliceOp` and
/// `tensor::ParallelInsertSliceOp`.
///
/// The fold applies only when `x` already has exactly the type built from the
/// insert's static sizes and the destination's element type. That is the
/// source type the insert would have without any rank reduction, so the
/// collapse_shape can be dropped entirely. Destination, offsets, sizes and
/// strides are carried over unchanged into a freshly created `OpTy`.
template <typename OpTy>
struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    // `template` is required because the call is on a value whose type
    // depends on the template parameter OpTy.
    auto collapseOp =
        insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
    if (!collapseOp)
      return failure();

    // Source type the insert would expect if it were not rank-reducing.
    auto fullRankSourceType =
        RankedTensorType::get(insertSliceOp.getStaticSizes(),
                              insertSliceOp.getDestType().getElementType());
    // The collapse can only vanish when its input already has that type.
    if (collapseOp.getSrcType() != fullRankSourceType)
      return failure();

    rewriter.replaceOpWithNewOp<OpTy>(
        insertSliceOp, collapseOp.getSrc(), insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
/// Folds `OpTy(expand_shape(x), dest)` into `OpTy(x, dest)` by rewiring the
/// insert's source operand in place. `OpTy` is an insert-slice-like op; it is
/// instantiated for `tensor::InsertSliceOp` and
/// `tensor::ParallelInsertSliceOp`.
///
/// The fold applies only when `isRankReducedType(expandResultType,
/// expandSrcType)` reports `Success`, i.e. the expansion merely adds
/// dimensions that a rank-reducing insert can drop again. Otherwise a match
/// failure is reported with "expected rank increasing expansion". Offsets,
/// sizes and strides of the insert are not modified.
template <typename OpTy>
struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto expandOp = insertSliceOp.getSource()
                        .template getDefiningOp<tensor::ExpandShapeOp>();
    if (!expandOp)
      return failure();

    const bool onlyAddsDroppableDims =
        isRankReducedType(expandOp.getResultType(), expandOp.getSrcType()) ==
        SliceVerificationResult::Success;
    if (!onlyAddsDroppableDims)
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "expected rank increasing expansion");

    // Feed the pre-expansion value straight into the insert.
    auto unexpanded = expandOp.getSrc();
    rewriter.modifyOpInPlace(insertSliceOp, [&]() {
      insertSliceOp.getSourceMutable().assign(unexpanded);
    });
    return success();
  }
};
}
/// Registers the reshape-folding patterns defined in this file into
/// `patterns`, all constructed with the pattern set's context. Registration
/// order is: the two extract-side folds, then the two insert-side folds, each
/// instantiated for `tensor::InsertSliceOp` and
/// `tensor::ParallelInsertSliceOp`.
void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
    RewritePatternSet &patterns) {
  auto *ctx = patterns.getContext();

  // Folds that absorb a reshape into a producing extract_slice.
  patterns.add<FoldExpandOfRankReducingExtract,
               FoldUnPaddingCollapseIntoExtract>(ctx);

  // Folds that absorb a reshape into a consuming (parallel_)insert_slice.
  patterns.add<FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
               FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
               FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
               FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
      ctx);
}