#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "lower-vector-mask"
// Pass-definition boilerplate. The macro is defined immediately before
// Passes.h.inc is included, inside mlir::vector, so the include expands
// whatever section that macro guards (presumably the LowerVectorMaskPassBase
// that the pass later in this file derives from - confirm against the
// generated header).
namespace mlir {
namespace vector {
#define GEN_PASS_DEF_LOWERVECTORMASKPASS
#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
} // namespace vector
} // namespace mlir
using namespace mlir;
using namespace mlir::vector;
namespace {
/// Unrolls the leading dimension of an n-D `vector.create_mask` (n > 1).
///
/// A new create_mask of the (n-1)-D type is built from all operands except the
/// first. Then, for every static index `i` of the leading dimension, the
/// rewrite emits:
///   * an index constant `i`,
///   * a signed `i < operand(0)` comparison,
///   * a select between the lower-rank mask and an all-zero constant,
///   * an insert of the selected value at position `i` of the accumulated
///     result, which starts out as an all-zero constant of the original type.
/// The original op is finally replaced by the accumulated result.
class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto maskType = cast<VectorType>(op.getResult().getType());

    // Rank 0 and rank 1 masks are deliberately not matched here.
    if (maskType.getRank() <= 1)
      return rewriter.notifyMatchFailure(
          op, "0-D and 1-D vectors are handled separately");

    // The unrolling loop below iterates over the leading dimension, so that
    // dimension must not be scalable.
    if (maskType.getScalableDims().front())
      return rewriter.notifyMatchFailure(
          op, "Cannot unroll leading scalable dim in dstType");

    Location loc = op.getLoc();
    const int64_t leadingSize = maskType.getDimSize(0);
    Value leadingBound = op.getOperand(0);

    // Mask type with the leading dimension removed.
    VectorType innerType = VectorType::Builder(maskType).dropDim(0);

    // Values shared by every unrolled step.
    Value innerMask = rewriter.create<vector::CreateMaskOp>(
        loc, innerType, op.getOperands().drop_front());
    Value innerZero = rewriter.create<arith::ConstantOp>(
        loc, innerType, rewriter.getZeroAttr(innerType));
    Value accumulated = rewriter.create<arith::ConstantOp>(
        loc, maskType, rewriter.getZeroAttr(maskType));

    for (int64_t i = 0; i < leadingSize; ++i) {
      Value position =
          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(i));
      Value isEnabled = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::slt, position, leadingBound);
      Value slice =
          rewriter.create<arith::SelectOp>(loc, isEnabled, innerMask, innerZero);
      accumulated = rewriter.create<vector::InsertOp>(loc, slice, accumulated, i);
    }

    rewriter.replaceOp(op, accumulated);
    return success();
  }
};
/// Lowers `vector.constant_mask` depending on the rank of its result type.
///
///   * rank 0: replaced by a constant whose single value is true iff the one
///     and only mask dim size equals 1.
///   * rank 1: replaced by a splat constant when the mask covers no lane or
///     every lane; otherwise by an explicit bool vector constant with the
///     first `trueDimSize` lanes set.
///   * rank > 1: the leading dimension is unrolled. A lower-rank constant_mask
///     built from the remaining dim sizes is inserted at positions
///     [0, trueDimSize) of an all-zero constant. Fails to match when the
///     leading dimension is scalable.
class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto maskType = op.getType();
    auto maskDimSizes = op.getMaskDimSizes();
    const int64_t rank = maskType.getRank();

    // 0-D mask: a single boolean.
    if (rank == 0) {
      assert(maskDimSizes.size() == 1 &&
             "Expected exactly one dim size for a 0-D vector");
      bool value = cast<IntegerAttr>(maskDimSizes[0]).getInt() == 1;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, maskType,
          DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()),
                                    value));
      return success();
    }

    // Number of enabled entries along the leading dimension.
    const int64_t trueDimSize = cast<IntegerAttr>(maskDimSizes[0]).getInt();

    // 1-D mask: a single constant.
    if (rank == 1) {
      const bool noneSet = trueDimSize == 0;
      if (noneSet || trueDimSize == maskType.getDimSize(0)) {
        // Uniform mask: emit a splat.
        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
            op, DenseElementsAttr::get(maskType, trueDimSize != 0));
        return success();
      }

      // Partial mask: spell out the lanes as [true x trueDimSize, false...].
      SmallVector<bool> lanes(maskType.getDimSize(0), false);
      for (int64_t lane = 0; lane < trueDimSize; ++lane)
        lanes[lane] = true;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, maskType, rewriter.getBoolVectorAttr(lanes));
      return success();
    }

    // n-D mask: the leading dimension is unrolled and thus must be static.
    if (maskType.getScalableDims().front())
      return rewriter.notifyMatchFailure(
          op, "Cannot unroll leading scalable dim in dstType");

    VectorType innerType = VectorType::Builder(maskType).dropDim(0);
    Value innerMask = rewriter.create<vector::ConstantMaskOp>(
        loc, innerType,
        rewriter.getArrayAttr(maskDimSizes.getValue().drop_front()));
    Value accumulated = rewriter.create<arith::ConstantOp>(
        loc, maskType, rewriter.getZeroAttr(maskType));

    // Only the first `trueDimSize` slices receive the lower-rank mask; the
    // remaining ones keep the zero initialisation.
    for (int64_t i = 0; i < trueDimSize; ++i)
      accumulated =
          rewriter.create<vector::InsertOp>(loc, innerMask, accumulated, i);

    rewriter.replaceOp(op, accumulated);
    return success();
  }
};
}
/// Adds the create_mask and constant_mask lowering patterns defined in this
/// file to `patterns`, each registered with the given `benefit`.
void mlir::vector::populateVectorMaskOpLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<CreateMaskOpLowering>(ctx, benefit);
  patterns.add<ConstantMaskOpLowering>(ctx, benefit);
}
namespace {
/// Base class for patterns that rewrite a `MaskOp` whose masked operation is
/// of type `SourceOp`.
///
/// The `MaskOp` entry point is sealed (`final`); it extracts the maskable
/// operation, fails the match when there is none or when it is not a
/// `SourceOp`, and otherwise forwards to `matchAndRewriteMaskableOp`, which
/// derived patterns implement.
template <class SourceOp>
struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
  using OpRewritePattern<MaskOp>::OpRewritePattern;

private:
  LogicalResult matchAndRewrite(MaskOp maskOp,
                                PatternRewriter &rewriter) const final {
    auto maskedOp = cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp());
    if (maskedOp) {
      // Only dispatch when the masked operation is the op kind this pattern
      // was instantiated for.
      if (SourceOp sourceOp = dyn_cast<SourceOp>(maskedOp.getOperation()))
        return matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
    }
    return failure();
  }

protected:
  /// Hook invoked with the masked `sourceOp` and the enclosing masking op.
  virtual LogicalResult
  matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const = 0;
};
/// Replaces a masking op wrapping a `TransferReadOp` with a new
/// `TransferReadOp` that takes the mask of the masking op as its mask operand
/// and otherwise reuses the operands/attributes of the wrapped read.
struct MaskedTransferReadOpPattern
    : public MaskOpRewritePattern<TransferReadOp> {
public:
  using MaskOpRewritePattern<TransferReadOp>::MaskOpRewritePattern;

  LogicalResult
  matchAndRewriteMaskableOp(TransferReadOp readOp, MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    // A passthru value on the masking op is not handled by this lowering.
    const bool hasPassthru = maskingOp.hasPassthru();
    if (hasPassthru)
      return rewriter.notifyMatchFailure(
          maskingOp, "Can't lower passthru to vector.transfer_read");

    // The new read replaces the whole masking op, not just the inner read.
    Operation *opToReplace = maskingOp.getOperation();
    auto vectorType = readOp.getVectorType();
    auto mask = maskingOp.getMask();
    rewriter.replaceOpWithNewOp<TransferReadOp>(
        opToReplace, vectorType, readOp.getSource(), readOp.getIndices(),
        readOp.getPermutationMap(), readOp.getPadding(), mask,
        readOp.getInBounds());
    return success();
  }
};
/// Replaces a masking op wrapping a `TransferWriteOp` with a new
/// `TransferWriteOp` that takes the mask of the masking op as its mask operand
/// and otherwise reuses the operands/attributes of the wrapped write.
struct MaskedTransferWriteOpPattern
    : public MaskOpRewritePattern<TransferWriteOp> {
public:
  using MaskOpRewritePattern<TransferWriteOp>::MaskOpRewritePattern;

  LogicalResult
  matchAndRewriteMaskableOp(TransferWriteOp writeOp,
                            MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    // The write may or may not produce a result; a null type is passed when
    // it does not.
    Type resultType;
    if (auto writeResult = writeOp.getResult())
      resultType = writeResult.getType();

    // The new write replaces the whole masking op, not just the inner write.
    Operation *opToReplace = maskingOp.getOperation();
    auto mask = maskingOp.getMask();
    rewriter.replaceOpWithNewOp<TransferWriteOp>(
        opToReplace, resultType, writeOp.getVector(), writeOp.getSource(),
        writeOp.getIndices(), writeOp.getPermutationMap(), mask,
        writeOp.getInBounds());
    return success();
  }
};
/// Replaces a masking op wrapping a `GatherOp` with a new `GatherOp` that uses
/// the mask of the masking op. The passthru of the new gather is the passthru
/// of the masking op when present, otherwise a zero constant of the gather's
/// vector type.
struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> {
public:
  using MaskOpRewritePattern<GatherOp>::MaskOpRewritePattern;

  LogicalResult
  matchAndRewriteMaskableOp(GatherOp gatherOp, MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    Value passthru;
    if (maskingOp.hasPassthru()) {
      passthru = maskingOp.getPassthru();
    } else {
      // No passthru supplied: materialise a zero-valued one.
      passthru = rewriter.create<arith::ConstantOp>(
          gatherOp.getLoc(), rewriter.getZeroAttr(gatherOp.getVectorType()));
    }

    // The new gather replaces the whole masking op, not just the inner gather.
    Operation *opToReplace = maskingOp.getOperation();
    rewriter.replaceOpWithNewOp<GatherOp>(
        opToReplace, gatherOp.getVectorType(), gatherOp.getBase(),
        gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(),
        passthru);
    return success();
  }
};
/// Pass that greedily applies, on the operation it runs on, the `vector.mask`
/// lowering patterns for side-effecting ops together with the `MaskOp`
/// canonicalization patterns.
struct LowerVectorMaskPass
    : public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> {
  using Base::Base;

  /// Collects the patterns and runs the greedy rewrite driver; a driver
  /// failure is reported as a pass failure.
  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();

    RewritePatternSet loweringPatterns(context);
    populateVectorMaskLoweringPatternsForSideEffectingOps(loweringPatterns);
    MaskOp::getCanonicalizationPatterns(loweringPatterns, context);

    if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
      signalPassFailure();
  }

  /// Registers the vector dialect as a dependency of this pass.
  // The parameter must be spelled `&registry`: the body refers to `registry`,
  // and a mis-encoded `&reg` sequence here leaves that name undeclared.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<vector::VectorDialect>();
  }
};
}
/// Adds to `patterns` the patterns of this file that rewrite a masking op
/// wrapping a transfer_read, transfer_write or gather into the corresponding
/// op carrying the mask directly.
void vector::populateVectorMaskLoweringPatternsForSideEffectingOps(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<MaskedTransferReadOpPattern>(ctx);
  patterns.add<MaskedTransferWriteOpPattern>(ctx);
  patterns.add<MaskedGatherOpPattern>(ctx);
}
/// Factory returning a newly allocated LowerVectorMaskPass, owned by the
/// caller through the returned pointer.
std::unique_ptr<Pass> mlir::vector::createLowerVectorMaskPass() {
  std::unique_ptr<Pass> pass = std::make_unique<LowerVectorMaskPass>();
  return pass;
}