#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include <type_traits>
namespace mlir {
namespace tensor {
// Expand the tablegen-generated pass definition selected by
// GEN_PASS_DEF_FOLDTENSORSUBSETOPS. The resulting
// `tensor::impl::FoldTensorSubsetOpsBase` is the CRTP base class that
// FoldTensorSubsetOpsPass derives from later in this file.
#define GEN_PASS_DEF_FOLDTENSORSUBSETOPS
#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
} // namespace tensor
} // namespace mlir
using namespace mlir;
static Value getTensorOperand(vector::TransferReadOp op) {
return op.getSource();
}
static Value getTensorOperand(tensor::InsertSliceOp op) {
return op.getSource();
}
namespace {
/// Rewrites a `vector.transfer_read` whose tensor operand is produced by a
/// `tensor.extract_slice` into a transfer_read of the slice's source tensor.
/// Implemented as a maskable pattern so a read wrapped in `vector.mask` is
/// handled as well.
struct TransferReadOfExtractSliceOpFolder final
    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp readOp,
                            vector::MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;
};

/// Rewrites a `tensor.insert_slice` whose source is produced by a
/// `vector.transfer_write` into a transfer_write directly into the
/// insert_slice destination.
struct InsertSliceOfTransferWriteOpFolder final
    : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override;
};
} // namespace
/// Shared preconditions for folding a tensor.extract_slice / tensor.insert_slice
/// into an adjacent vector transfer op.
///
/// Fails, reporting the reason against `xferOp`, when any of these holds
/// (checked in this order):
///   * the transfer has a dimension that is not marked in-bounds,
///   * the transfer carries an explicit mask operand,
///   * the slice op has a non-unit stride.
template <typename XferOp, typename ExtractOrInsertOp>
static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(
    RewriterBase &rewriter, XferOp xferOp,
    ExtractOrInsertOp extractOrInsertSliceOp) {
  StringRef failureReason;
  if (xferOp.hasOutOfBoundsDim()) {
    failureReason = "out of bounds transfer dim";
  } else if (xferOp.getMask()) {
    failureReason = "masked transfer";
  } else if (!extractOrInsertSliceOp.hasUnitStride()) {
    failureReason = "non-1 stride insert/extract, requires keeping track of "
                    "strides, this may result in needing to insert "
                    "vector.insert_strided_slice/extract_strided_slice ops";
  }

  if (failureReason.empty())
    return success();
  return rewriter.notifyMatchFailure(xferOp, failureReason);
}
/// Folds `vector.transfer_read(tensor.extract_slice(%src))` into a single
/// `vector.transfer_read` of `%src`.
///
/// The read indices are combined with the extract_slice offsets/strides, and
/// the permutation map is expanded to the source rank over the dimensions the
/// (possibly rank-reducing) extract_slice dropped. If the original read was
/// wrapped in a masking op (`maskOp` non-null), the new read is wrapped with
/// the same mask.
///
/// Returns the value replacing the original read, or failure.
FailureOr<mlir::Value>
TransferReadOfExtractSliceOpFolder::matchAndRewriteMaskableOp(
    vector::TransferReadOp readOp, vector::MaskingOpInterface maskOp,
    PatternRewriter &rewriter) const {
  auto extractSliceOp =
      getTensorOperand(readOp).getDefiningOp<tensor::ExtractSliceOp>();
  if (!extractSliceOp)
    return rewriter.notifyMatchFailure(readOp, "not an extract_slice");

  // The helper already reported a specific match-failure reason; propagate it
  // instead of emitting a second, generic diagnostic on top of it (this is
  // what InsertSliceOfTransferWriteOpFolder does too).
  LogicalResult preconditionResult =
      preconditionsFoldExtractOrInsertWithTransferOp(rewriter, readOp,
                                                     extractSliceOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices = llvm::to_vector(readOp.getIndices());
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(),
      extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
      indices, sourceIndices);

  Operation *newOp = rewriter.create<vector::TransferReadOp>(
      readOp.getLoc(), readOp.getVectorType(), extractSliceOp.getSource(),
      sourceIndices,
      AffineMapAttr::get(expandDimsToRank(
          readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(),
          extractSliceOp.getDroppedDims())),
      readOp.getPadding(),
      // Explicitly masked reads were rejected by the preconditions above.
      /*mask=*/Value(), readOp.getInBoundsAttr());
  // The vector type is unchanged, so the enclosing mask applies as-is.
  if (maskOp)
    newOp = mlir::vector::maskOperation(rewriter, newOp, maskOp.getMask());
  return newOp->getResults()[0];
}
LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
auto writeOp = getTensorOperand(insertSliceOp)
.template getDefiningOp<vector::TransferWriteOp>();
if (!writeOp)
return rewriter.notifyMatchFailure(insertSliceOp, "not a transfer_write");
LogicalResult preconditionResult =
preconditionsFoldExtractOrInsertWithTransferOp(rewriter, writeOp,
insertSliceOp);
if (failed(preconditionResult))
return preconditionResult;
SmallVector<Value> indices(writeOp.getIndices().begin(),
writeOp.getIndices().end());
SmallVector<Value> sourceIndices;
affine::resolveIndicesIntoOpWithOffsetsAndStrides(
rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(),
insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices,
sourceIndices);
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
insertSliceOp, writeOp.getValue(), insertSliceOp.getDest(), sourceIndices,
AffineMapAttr::get(expandDimsToRank(writeOp.getPermutationMap(),
insertSliceOp.getDestType().getRank(),
insertSliceOp.getDroppedDims())),
writeOp.getInBoundsAttr());
return success();
}
/// Folds `OpTy(tensor.insert_slice(%x into %y), into %dest)` into
/// `OpTy(%x into %dest)`, where OpTy is tensor.insert_slice or
/// tensor.parallel_insert_slice.
///
/// Requires unit strides on both ops. It also requires the inner insert_slice
/// to write a region whose sizes match the (non-dropped) sizes of the outer
/// slice, i.e. `%y` is fully overwritten. Otherwise elements of `%y` would be
/// lost and a copy would be needed.
template <typename OpTy>
struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto sourceInsertSliceOp =
        insertSliceOp.getSource()
            .template getDefiningOp<tensor::InsertSliceOp>();
    if (!sourceInsertSliceOp)
      return failure();

    if (!insertSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "requires unit strides");
    }
    if (!sourceInsertSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(sourceInsertSliceOp,
                                         "requires unit strides");
    }

    // Materialize the mixed sizes once; each getMixedSizes() call rebuilds a
    // fresh vector.
    SmallVector<OpFoldResult> outerSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> innerSizes = sourceInsertSliceOp.getMixedSizes();
    llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims();

    // Walk the outer slice's non-dropped dims in lockstep with the inner
    // slice's dims.
    int64_t srcDim = 0;
    for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e; ++d) {
      if (droppedDims[d])
        continue;
      if (outerSizes[d] != innerSizes[srcDim++]) {
        return rewriter.notifyMatchFailure(
            sourceInsertSliceOp,
            "requires matching sizes to fold, otherwise a copy is needed");
      }
    }

    SmallVector<OpFoldResult> resolvedSizes;
    affine::resolveSizesIntoOpWithSizes(outerSizes, innerSizes, droppedDims,
                                        resolvedSizes);

    if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
      // The offset computations cannot be created inside the parallel
      // combining region, so emit them right before the combining op. Use the
      // immediate parent rather than searching for scf::InParallelOp. A search
      // returns null (crashing setInsertionPoint) under other
      // ParallelCombiningOpInterface parents, and could pick an outer forall's
      // terminator. NOTE(review): relies on the op verifier requiring the
      // direct parent to be the combining op.
      rewriter.setInsertionPoint(insertSliceOp->getParentOp());
    }
    SmallVector<Value> resolvedOffsets;
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets(),
        insertSliceOp.getMixedStrides(), droppedDims,
        sourceInsertSliceOp.getMixedOffsets(), resolvedOffsets);

    // The replacement itself must sit where the original op is.
    rewriter.setInsertionPoint(insertSliceOp);
    rewriter.replaceOpWithNewOp<OpTy>(
        insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(),
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
        insertSliceOp.getMixedStrides());
    return success();
  }
};
/// Adds every subset-folding pattern from this file to `patterns`:
///   * the vector-transfer folders, and
///   * the insert_slice-of-insert_slice folders for tensor.insert_slice and
///     tensor.parallel_insert_slice (added in that order).
void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) {
  populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);

  MLIRContext *context = patterns.getContext();
  patterns.add<InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>>(context);
  patterns.add<InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>(
      context);
}
/// Adds the patterns that fold tensor.extract_slice into a consuming
/// vector.transfer_read, and a producing vector.transfer_write into
/// tensor.insert_slice.
void tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<TransferReadOfExtractSliceOpFolder>(ctx);
  patterns.add<InsertSliceOfTransferWriteOpFolder>(ctx);
}
namespace {
/// Pass that greedily applies the tensor subset folding patterns; built on the
/// tablegen-generated FoldTensorSubsetOpsBase.
class FoldTensorSubsetOpsPass final
    : public tensor::impl::FoldTensorSubsetOpsBase<FoldTensorSubsetOpsPass> {
public:
  void runOnOperation() override;
};
} // namespace
void FoldTensorSubsetOpsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
tensor::populateFoldTensorSubsetOpPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
/// Creates a new, default-constructed FoldTensorSubsetOpsPass.
std::unique_ptr<Pass> tensor::createFoldTensorSubsetOpsPass() {
  std::unique_ptr<Pass> pass = std::make_unique<FoldTensorSubsetOpsPass>();
  return pass;
}