#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Interfaces/VectorInterfaces.h"
using namespace mlir;
using namespace mlir::vector;
/// Undoes a transpose on a transfer op's `in_bounds` attribute: the i-th
/// boolean of `attr` is written to position `permutation[i]` of the result.
///
/// NOTE(review): assumes `attr` holds at least `permutation.size()` BoolAttr
/// entries and that every value in `permutation` is < permutation.size().
static ArrayAttr
inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
                             const SmallVector<unsigned> &permutation) {
  ArrayRef<Attribute> inBoundsEntries = attr.getValue();
  SmallVector<bool> result(permutation.size());
  for (size_t srcIdx = 0, e = permutation.size(); srcIdx < e; ++srcIdx) {
    unsigned dstIdx = permutation[srcIdx];
    result[dstIdx] = cast<BoolAttr>(inBoundsEntries[srcIdx]).getValue();
  }
  return builder.getBoolArrayAttr(result);
}
/// Broadcasts `vec` to a vector type that has `addedRank` extra leading unit
/// dimensions (marked non-scalable), followed by the original shape and
/// scalable flags. E.g. vector<4x[8]xf32> with addedRank = 2 becomes
/// vector<1x1x4x[8]xf32>. The broadcast is created with `builder` at `loc`.
static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
                              int64_t addedRank) {
  auto srcType = cast<VectorType>(vec.getType());
  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<bool> srcScalable = srcType.getScalableDims();

  SmallVector<int64_t> extendedShape(addedRank, 1);
  extendedShape.append(srcShape.begin(), srcShape.end());

  SmallVector<bool> extendedScalable(addedRank, false);
  extendedScalable.append(srcScalable.begin(), srcScalable.end());

  VectorType extendedType = VectorType::get(
      extendedShape, srcType.getElementType(), extendedScalable);
  return builder.create<vector::BroadcastOp>(loc, extendedType, vec);
}
/// Extends a mask with `addedRank` unit dimensions placed at the END of its
/// shape. This is done by prepending them with extendVectorRank and then
/// transposing so the new leading dims move behind the original ones.
static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec,
                            int64_t addedRank) {
  Value broadcasted = extendVectorRank(builder, loc, vec, addedRank);
  int64_t fullRank = cast<VectorType>(broadcasted.getType()).getRank();

  // Transpose order: original dims [addedRank, fullRank) first, then the
  // prepended unit dims [0, addedRank).
  SmallVector<int64_t> permutation;
  permutation.reserve(fullRank);
  for (int64_t d = addedRank; d < fullRank; ++d)
    permutation.push_back(d);
  for (int64_t d = 0; d < addedRank; ++d)
    permutation.push_back(d);
  return builder.create<vector::TransposeOp>(loc, broadcasted, permutation);
}
namespace {
/// Lowers a `vector.transfer_read` whose permutation map is a permutation of
/// a minor identity (possibly with broadcasts). The result is a
/// `vector.transfer_read` with the permutation undone, followed by a
/// `vector.transpose` that restores the original vector layout.
///
/// The new read's vector shape, scalable flags and `in_bounds` attribute are
/// the original ones with the permutation inverted. The pattern fails on:
/// - 0-d transfers;
/// - reads nested in a masking op;
/// - 0-result maps;
/// - maps that are not permutations of a minor identity;
/// - maps whose permutation is already the identity.
struct TransferReadPermutationLowering
    : public MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: support the 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
    // TODO: support transfer_read nested in a masking op.
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.getNumResults() == 0)
      return rewriter.notifyMatchFailure(op, "0 result permutation map");
    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");
    }

    AffineMap permutationMap =
        map.getPermutationMap(permutation, op.getContext());
    // Nothing to transpose when the permutation is already the identity.
    if (permutationMap.isIdentity())
      return rewriter.notifyMatchFailure(op, "permutation is already identity");

    // The new read uses the original map with the permutation undone.
    AffineMap newMap = inversePermutation(permutationMap).compose(map);

    // Apply the inverse permutation to the vector shape and scalable flags to
    // get the type of the new transfer_read.
    ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
    ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
    SmallVector<int64_t> newVectorShape(originalShape.size());
    SmallVector<bool> newScalableDims(originalShape.size());
    for (const auto &pos : llvm::enumerate(permutation)) {
      newVectorShape[pos.value()] = originalShape[pos.index()];
      newScalableDims[pos.value()] = originalScalableDims[pos.index()];
    }

    // Transpose the in_bounds attribute the same way.
    ArrayAttr newInBoundsAttr =
        inverseTransposeInBoundsAttr(rewriter, op.getInBounds(), permutation);

    VectorType newReadType = VectorType::get(
        newVectorShape, op.getVectorType().getElementType(), newScalableDims);
    Value newRead = rewriter.create<vector::TransferReadOp>(
        op.getLoc(), newReadType, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
        newInBoundsAttr);

    // Transpose the new read back into the original vector layout.
    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
    return rewriter
        .create<vector::TransposeOp>(op.getLoc(), newRead, transposePerm)
        .getResult();
  }
};
/// Lowers a `vector.transfer_write` whose permutation map is a permutation of
/// a minor identity, but not a minor identity itself. The stored vector is
/// first transposed with `vector.transpose`, then written by a
/// `vector.transfer_write` that uses a minor identity map. The `in_bounds`
/// attribute is permuted accordingly.
///
/// Returns the new write's result for tensor destinations. For memref
/// destinations it returns an empty Value, which signals success without a
/// result.
struct TransferWritePermutationLowering
    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: support the 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
    // TODO: support transfer_write nested in a masking op.
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.isMinorIdentity())
      return rewriter.notifyMatchFailure(op, "map is already minor identity");

    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");
    }

    // Drop the dims the map does not use, e.g.
    //   (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
    // becomes (d0, d1, d2) -> (d2, d0, d1). Then invert it: the inverse's
    // results give the transpose to apply to the stored vector.
    AffineMap compressed = compressUnusedDims(map);
    AffineMap permutationMap = inversePermutation(compressed);
    SmallVector<int64_t> indices;
    llvm::transform(permutationMap.getResults(), std::back_inserter(indices),
                    [](AffineExpr expr) {
                      // The inverted permutation's results are expected to be
                      // plain dims. `cast` asserts this, rather than calling
                      // getPosition() on a possibly-null `dyn_cast` result.
                      return cast<AffineDimExpr>(expr).getPosition();
                    });

    // Transpose the in_bounds attribute.
    ArrayAttr newInBoundsAttr =
        inverseTransposeInBoundsAttr(rewriter, op.getInBounds(), permutation);

    // Transpose the vector, then write it with a minor identity map.
    Value newVec = rewriter.create<vector::TransposeOp>(
        op.getLoc(), op.getVector(), indices);
    auto newMap = AffineMap::getMinorIdentityMap(
        map.getNumDims(), map.getNumResults(), rewriter.getContext());
    auto newWrite = rewriter.create<vector::TransferWriteOp>(
        op.getLoc(), newVec, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr);
    if (newWrite.hasPureTensorSemantics())
      return newWrite.getResult();
    // Memref destination: there is no result, so an empty Value signals
    // success.
    return Value();
  }
};
/// Handles a `vector.transfer_write` whose permutation map is NOT a
/// permutation of a minor identity because some dims are missing after the
/// outermost dim the map uses, e.g. (d0, d1, d2) -> (d0, d2).
///
/// Each dim that comes after the first used dim but does not appear in the
/// map is re-added, as follows:
/// - the vector gains that many leading unit dims (via broadcast);
/// - the mask, if present, gains that many trailing unit dims;
/// - those dims are prepended to the map's results;
/// - the `in_bounds` attribute is prefixed with `true` for them.
/// Dims before the first used dim are left untouched.
struct TransferWriteNonPermutationLowering
    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    AffineMap map = op.getPermutationMap();
    SmallVector<unsigned> ignoredPermutation;
    if (map.isPermutationOfMinorIdentityWithBroadcasting(ignoredPermutation))
      return rewriter.notifyMatchFailure(
          op,
          "map is already permutable to minor identity, apply another pattern");

    // Mark every dim referenced by the map.
    // NOTE(review): `cast` assumes every result is a dim expr (no broadcasts).
    unsigned numDims = map.getNumDims();
    SmallVector<bool> isUsed(numDims, false);
    for (AffineExpr result : map.getResults())
      isUsed[cast<AffineDimExpr>(result).getPosition()] = true;

    // Skip to the outermost used dim. Every unused dim after it is a
    // "missing inner dim" that must be re-added.
    unsigned firstUsedDim = 0;
    while (firstUsedDim < numDims && !isUsed[firstUsedDim])
      ++firstUsedDim;

    SmallVector<AffineExpr> newResults;
    SmallVector<int64_t> missingInnerDims;
    for (unsigned d = firstUsedDim; d < numDims; ++d) {
      if (isUsed[d])
        continue;
      missingInnerDims.push_back(d);
      newResults.push_back(rewriter.getAffineDimExpr(d));
    }
    size_t numAdded = missingInnerDims.size();

    // Vector: prepend unit dims. Mask: append unit dims.
    Value newVec =
        extendVectorRank(rewriter, op.getLoc(), op.getVector(), numAdded);
    Value newMask =
        op.getMask()
            ? extendMaskRank(rewriter, op.getLoc(), op.getMask(), numAdded)
            : Value();

    // New map = missing dims first, followed by the original results.
    newResults.append(map.getResults().begin(), map.getResults().end());
    AffineMap newMap =
        AffineMap::get(numDims, 0, newResults, op.getContext());

    // The added unit dims are in bounds; the original dims keep their flags.
    SmallVector<bool> inBounds(numAdded, true);
    int64_t vecRank = op.getVectorType().getRank();
    for (int64_t dim = 0; dim < vecRank; ++dim)
      inBounds.push_back(op.isDimInBounds(dim));

    auto newWrite = rewriter.create<vector::TransferWriteOp>(
        op.getLoc(), newVec, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), newMask,
        rewriter.getBoolArrayAttr(inBounds));
    // Memref destination: there is no result, so an empty Value signals
    // success.
    if (!newWrite.hasPureTensorSemantics())
      return Value();
    return newWrite.getResult();
  }
};
/// Strips leading broadcast dims from a `vector.transfer_read`. The read is
/// re-issued without the leading constant-0 map results (and without the
/// matching leading vector dims), and its result is broadcast back to the
/// original vector type. E.g. a read of vector<2x3x8xf32> with map
/// (d0, d1, d2) -> (0, 0, d2) becomes a read of vector<8xf32> with map
/// (d0, d1, d2) -> (d2), followed by a broadcast.
///
/// The pattern only applies when the remaining map is a minor identity with
/// broadcasting. When every result is a broadcast, the read becomes a scalar
/// `tensor.extract` / `memref.load` plus a broadcast. Masked reads are
/// rejected in that case, because the scalar access cannot apply the mask.
struct TransferOpReduceRank
    : public MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: support the 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
    // TODO: support transfer_read nested in a masking op.
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    // Count the leading results that are the constant 0, i.e. broadcasts.
    AffineMap map = op.getPermutationMap();
    unsigned numLeadingBroadcast = 0;
    for (AffineExpr expr : map.getResults()) {
      auto constExpr = dyn_cast<AffineConstantExpr>(expr);
      if (!constExpr || constExpr.getValue() != 0)
        break;
      numLeadingBroadcast++;
    }
    if (numLeadingBroadcast == 0)
      return rewriter.notifyMatchFailure(op, "no leading broadcasts in map");

    VectorType originalVecType = op.getVectorType();
    unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
    AffineMap newMap = AffineMap::get(
        map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
        op.getContext());
    // Only strip the broadcasts when the remainder needs no permutation.
    // Otherwise the permutation patterns have to run first.
    if (!newMap.isMinorIdentityWithBroadcasting()) {
      return rewriter.notifyMatchFailure(
          op, "map is not a minor identity with broadcasting");
    }

    // Every result is a broadcast: read a single scalar and broadcast it.
    if (reducedShapeRank == 0) {
      // The scalar extract/load below cannot apply a mask, nor select the
      // padding value when the mask is off. Lowering a masked read here would
      // silently drop the mask, so leave such reads to other lowerings.
      if (op.getMask())
        return rewriter.notifyMatchFailure(
            op, "masked fully-broadcast read not supported");

      Value newRead;
      if (isa<TensorType>(op.getShapedType())) {
        newRead = rewriter.create<tensor::ExtractOp>(
            op.getLoc(), op.getSource(), op.getIndices());
      } else {
        newRead = rewriter.create<memref::LoadOp>(
            op.getLoc(), originalVecType.getElementType(), op.getSource(),
            op.getIndices());
      }
      return rewriter
          .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
          .getVector();
    }

    // reducedShapeRank > 0 here, so the reduced vector is never 0-d.
    SmallVector<int64_t> newShape(
        originalVecType.getShape().take_back(reducedShapeRank));
    SmallVector<bool> newScalableDims(
        originalVecType.getScalableDims().take_back(reducedShapeRank));
    VectorType newReadType = VectorType::get(
        newShape, originalVecType.getElementType(), newScalableDims);

    // Keep only the in_bounds flags of the dims that remain.
    ArrayAttr newInBoundsAttr =
        op.getInBounds()
            ? rewriter.getArrayAttr(
                  op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
            : ArrayAttr();
    Value newRead = rewriter.create<vector::TransferReadOp>(
        op.getLoc(), newReadType, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
        newInBoundsAttr);
    return rewriter
        .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
        .getVector();
  }
};
}
/// Registers the patterns that rewrite transfer ops with non-trivial
/// permutation maps (permuted reads/writes, leading broadcasts on reads,
/// writes with missing inner dims). All patterns share `benefit`.
void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<TransferReadPermutationLowering>(ctx, benefit);
  patterns.add<TransferWritePermutationLowering>(ctx, benefit);
  patterns.add<TransferOpReduceRank>(ctx, benefit);
  patterns.add<TransferWriteNonPermutationLowering>(ctx, benefit);
}
namespace {
/// Lowers a `vector.transfer_read` from a memref whose permutation map is a
/// minor identity (possibly with broadcasts). The read becomes a
/// `vector.load`, or a `vector.maskedload` when the read carries a mask. A
/// `vector.broadcast` is added when the map broadcasts some dims.
///
/// The pattern fails when:
/// - the vector rank exceeds `maxTransferRank` (if set);
/// - the read is nested in a masking op;
/// - the map is not a minor identity with broadcasting;
/// - the source is not a memref, or its last dim is not unit-stride;
/// - the element types are incompatible;
/// - any dim may be out of bounds;
/// - the read is masked and is either not rank-1 or broadcasting.
struct TransferReadToVectorLoadLowering
    : public MaskableOpRewritePattern<vector::TransferReadOp> {
  TransferReadToVectorLoadLowering(MLIRContext *context,
                                   std::optional<unsigned> maxRank,
                                   PatternBenefit benefit = 1)
      : MaskableOpRewritePattern<vector::TransferReadOp>(context, benefit),
        maxTransferRank(maxRank) {}

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp read,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          read, "vector type is greater than max transfer rank");
    }

    // TODO: support transfer_read nested in a masking op.
    if (maskOp)
      return rewriter.notifyMatchFailure(read, "Masked case not supported");

    // Permuted maps must be lowered by other patterns first.
    SmallVector<unsigned> broadcastedDims;
    if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
            &broadcastedDims))
      return rewriter.notifyMatchFailure(read, "not minor identity + bcast");

    auto memRefType = dyn_cast<MemRefType>(read.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(read, "not a memref source");

    if (!isLastMemrefDimUnitStride(memRefType))
      return rewriter.notifyMatchFailure(read, "!= 1 stride needs VectorToSCF");

    // Load the vector with every broadcast dim collapsed to size 1, then
    // broadcast it to the requested type.
    ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
    SmallVector<int64_t> unbroadcastedVectorShape(vectorShape.begin(),
                                                  vectorShape.end());
    for (unsigned i : broadcastedDims)
      unbroadcastedVectorShape[i] = 1;
    VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
        unbroadcastedVectorShape, read.getVectorType().getElementType());

    // A vector-typed memref element must equal the loaded (unbroadcast) type.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
      return rewriter.notifyMatchFailure(read, "incompatible element type");

    // A scalar memref element must equal the vector's element type.
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != read.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(read, "non-matching element type");

    if (read.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask");

    Operation *res;
    if (read.getMask()) {
      if (read.getVectorType().getRank() != 1)
        return rewriter.notifyMatchFailure(
            read, "vector type is not rank 1, can't create masked load, needs "
                  "VectorToSCF");
      // NOTE(review): a rank-1 broadcast read loads vector<1xT>, but its mask
      // presumably omits the broadcast dim (0-d mask). That mask cannot feed
      // vector.maskedload, so bail out instead of emitting a mismatched op.
      if (!broadcastedDims.empty())
        return rewriter.notifyMatchFailure(
            read, "masked broadcast read not supported");

      Value fill = rewriter.create<vector::SplatOp>(
          read.getLoc(), unbroadcastedVectorType, read.getPadding());
      res = rewriter.create<vector::MaskedLoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices(), read.getMask(), fill);
    } else {
      res = rewriter.create<vector::LoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices());
    }

    if (!broadcastedDims.empty())
      res = rewriter.create<vector::BroadcastOp>(
          read.getLoc(), read.getVectorType(), res->getResult(0));
    return res->getResult(0);
  }

  /// Upper bound on the vector rank this pattern handles; unbounded if unset.
  std::optional<unsigned> maxTransferRank;
};
/// Replaces a single-element `vector.load` with a `memref.load`, followed by
/// a `vector.broadcast` back to the original vector type.
///
/// Scalable vectors are rejected. `getNumElements()` only looks at the static
/// shape, so e.g. vector<[1]xf32> reports one element even though it holds
/// vscale elements at runtime. A scalar load plus broadcast would then read
/// the wrong data.
struct VectorLoadToMemrefLoadLowering
    : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = loadOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(loadOp, "not a single element vector");
    if (vecType.isScalable())
      return rewriter.notifyMatchFailure(loadOp, "scalable vector");

    auto memrefLoad = rewriter.create<memref::LoadOp>(
        loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
                                                     memrefLoad);
    return success();
  }
};
/// Replaces a single-element `vector.store` with a `memref.store`.
///
/// For scalar memref element types, the single element is extracted first:
/// with `vector.extractelement` for 0-d vectors, or with `vector.extract` at
/// index 0 otherwise. When the memref element type is itself a vector, the
/// value is stored as is. Extracting a scalar there would produce an
/// ill-typed `memref.store`.
///
/// Scalable vectors are rejected. `getNumElements()` only looks at the static
/// shape, so vector<[1]xf32> would otherwise be treated as a single element.
struct VectorStoreToMemrefStoreLowering
    : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = storeOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(storeOp, "not single element vector");
    if (vecType.isScalable())
      return rewriter.notifyMatchFailure(storeOp, "scalable vector");

    Value toStore = storeOp.getValueToStore();
    // NOTE(review): for vector-typed memref elements, the stored value is
    // presumably required (by the vector.store verifier) to have exactly the
    // element type, so it can be passed to memref.store unchanged.
    if (!isa<VectorType>(storeOp.getMemRefType().getElementType())) {
      if (vecType.getRank() == 0) {
        toStore = rewriter.create<vector::ExtractElementOp>(storeOp.getLoc(),
                                                            toStore);
      } else {
        SmallVector<int64_t> indices(vecType.getRank(), 0);
        toStore = rewriter.create<vector::ExtractOp>(storeOp.getLoc(),
                                                     toStore, indices);
      }
    }

    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        storeOp, toStore, storeOp.getBase(), storeOp.getIndices());
    return success();
  }
};
/// Lowers a `vector.transfer_write` to a memref with a minor identity map.
/// The write becomes a `vector.store`, or a `vector.maskedstore` when it
/// carries a mask.
///
/// Preconditions:
/// - vector rank within `maxTransferRank` (if set);
/// - no enclosing masking op;
/// - a minor identity permutation map;
/// - a memref destination whose last dim is unit-stride;
/// - compatible element types;
/// - no out-of-bounds dims;
/// - rank 1 when masked.
/// The memref write has no result, so success is signalled with an empty
/// Value.
struct TransferWriteToVectorStoreLowering
    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
  TransferWriteToVectorStoreLowering(MLIRContext *context,
                                     std::optional<unsigned> maxRank,
                                     PatternBenefit benefit = 1)
      : MaskableOpRewritePattern<vector::TransferWriteOp>(context, benefit),
        maxTransferRank(maxRank) {}

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp write,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    VectorType vecType = write.getVectorType();
    Location loc = write.getLoc();

    if (maxTransferRank && vecType.getRank() > *maxTransferRank)
      return rewriter.notifyMatchFailure(
          write, "vector type is greater than max transfer rank");
    if (maskOp)
      return rewriter.notifyMatchFailure(write, "Masked case not supported");

    // Permuted maps must be lowered by other patterns first.
    if (!write.getPermutationMap().isMinorIdentity())
      return rewriter.notifyMatchFailure(loc, [=](Diagnostic &diag) {
        diag << "permutation map is not minor identity: " << write;
      });

    auto memRefType = dyn_cast<MemRefType>(write.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(loc, [=](Diagnostic &diag) {
        diag << "not a memref type: " << write;
      });
    if (!isLastMemrefDimUnitStride(memRefType))
      return rewriter.notifyMatchFailure(loc, [=](Diagnostic &diag) {
        diag << "most minor stride is not 1: " << write;
      });

    // A vector-typed memref element must equal the stored vector type. A
    // scalar memref element must equal the vector's element type.
    Type memrefElTy = memRefType.getElementType();
    Type requiredElTy = isa<VectorType>(memrefElTy)
                            ? Type(vecType)
                            : vecType.getElementType();
    if (memrefElTy != requiredElTy)
      return rewriter.notifyMatchFailure(loc, [=](Diagnostic &diag) {
        diag << "elemental type mismatch: " << write;
      });

    if (write.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(loc, [=](Diagnostic &diag) {
        diag << "out of bounds dim: " << write;
      });

    // Unmasked: a plain vector.store.
    if (!write.getMask()) {
      rewriter.create<vector::StoreOp>(loc, write.getVector(),
                                       write.getSource(), write.getIndices());
      return Value();
    }

    // Masked: only the rank-1 case maps onto vector.maskedstore.
    if (vecType.getRank() != 1)
      return rewriter.notifyMatchFailure(loc, [=](Diagnostic &diag) {
        diag << "vector type is not rank 1, can't create masked store, "
                "needs VectorToSCF: "
             << write;
      });
    rewriter.create<vector::MaskedStoreOp>(loc, write.getSource(),
                                           write.getIndices(), write.getMask(),
                                           write.getVector());
    return Value();
  }

  /// Upper bound on the vector rank this pattern handles; unbounded if unset.
  std::optional<unsigned> maxTransferRank;
};
}
/// Registers two groups of patterns, all sharing `benefit`:
/// - patterns lowering minor-identity transfer ops on memrefs to vector
///   loads/stores, bounded by `maxTransferRank`;
/// - patterns lowering single-element vector loads/stores to memref
///   loads/stores.
void mlir::vector::populateVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
    PatternBenefit benefit) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<TransferReadToVectorLoadLowering>(ctx, maxTransferRank,
                                                 benefit);
  patterns.add<TransferWriteToVectorStoreLowering>(ctx, maxTransferRank,
                                                   benefit);
  patterns.add<VectorLoadToMemrefLoadLowering>(ctx, benefit);
  patterns.add<VectorStoreToMemrefStoreLowering>(ctx, benefit);
}