#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "vector-transfer-opt"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
using namespace mlir;
/// Climbs the parent-op chain starting at `op` (inclusive) and returns the
/// first operation whose parent region is `region`. Returns nullptr when the
/// chain is exhausted without finding such an operation.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  Operation *current = op;
  while (current && current->getParentRegion() != region)
    current = current->getParentOp();
  return current;
}
namespace {
/// Helper bundling the state needed by the transfer read/write optimizations
/// defined in this file: a rewriter, dominance and post-dominance information
/// built from the root operation, and a list of operations queued for erasure.
class TransferOptimization {
public:
/// Builds the dominance and post-dominance analyses from `op`.
TransferOptimization(RewriterBase &rewriter, Operation *op)
: rewriter(rewriter), dominators(op), postDominators(op) {}
/// Queues the given transfer_write for erasure when it is found to be dead.
void deadStoreOp(vector::TransferWriteOp);
/// Replaces the uses of the given transfer_read with a previously written
/// vector when possible, and queues the read for erasure.
void storeToLoadForwarding(vector::TransferReadOp);
/// Erases every operation queued so far through the rewriter, then empties
/// the queue so the helper can be reused.
void removeDeadOp() {
for (Operation *op : opToErase)
rewriter.eraseOp(op);
opToErase.clear();
}
private:
RewriterBase &rewriter;
/// Returns true if `dest` may be reached from `start`; both operations must
/// be in the same region.
bool isReachable(Operation *start, Operation *dest);
DominanceInfo dominators;
PostDominanceInfo postDominators;
/// Operations to erase on the next call to `removeDeadOp`.
std::vector<Operation *> opToErase;
};
}
/// Returns true if `dest` may be reached from `start`. Both operations must
/// live in the same region (asserted).
///
/// The answer is true when `start` dominates `dest`, or when some block
/// reachable through the successors of `start`'s block dominates the block of
/// `dest`. Otherwise the function returns false.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case: the start op dominates the destination.
  if (dominators.dominates(start, dest))
    return true;
  Block *startBlock = start->getBlock();
  Block *destBlock = dest->getBlock();
  // Walk the CFG forward from the successors of the start block.
  SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
                                    startBlock->succ_end());
  SmallPtrSet<Block *, 32> visited;
  while (!worklist.empty()) {
    Block *bb = worklist.pop_back_val();
    // Each block is expanded at most once, which also terminates on cycles.
    if (!visited.insert(bb).second)
      continue;
    if (dominators.dominates(bb, destBlock))
      return true;
    worklist.append(bb->succ_begin(), bb->succ_end());
  }
  return false;
}
/// Queues `write` for erasure when it can be shown to be a dead store.
///
/// The users of the underlying memref (looking through subview,
/// collapse_shape and cast ops) are scanned for:
///   * an overwriting transfer_write on the same view that post-dominates
///     `write` and passes `checkSameValueWAW`; the candidate closest to
///     `write` in post-dominance order is kept;
///   * any other access with memory effects that is not provably disjoint
///     from `write`; these are "blocking" accesses.
/// The store is considered dead when an overwrite candidate exists and every
/// blocking access reachable from `write` is dominated by that candidate.
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source =
      memref::skipSubViewsAndCasts(cast<MemrefValue>(write.getSource()));
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // Skip users that were already processed.
    if (!processed.insert(user).second)
      continue;
    // Look through view-like ops so that accesses made through them are
    // analyzed too. memref.collapse_shape must be followed here (as it is in
    // storeToLoadForwarding): it has no memory effect itself, so it would
    // otherwise be skipped below and accesses through it would be missed.
    if (isa<memref::SubViewOp, memref::CollapseShapeOp, memref::CastOp>(user)) {
      users.append(user->getUsers().begin(), user->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user))
      continue;
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check whether this write is a candidate that overwrites the store.
      if (memref::isSameViewOrTrivialAlias(
              cast<MemrefValue>(nextWrite.getSource()),
              cast<MemrefValue>(write.getSource())) &&
          checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        // Keep the candidate that is post-dominated by the other ones, i.e.
        // the first overwrite after `write`.
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        continue;
      }
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      // Disjoint transfers cannot observe the stored value.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(transferOp.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
    }
    blockingAccesses.push_back(user);
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  // Compare positions within the region that contains the overwrite.
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");
  for (Operation *access : blockingAccesses) {
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    // Accesses outside the region or unreachable from the write are ignored.
    if (accessAncestor == nullptr ||
        !isReachable(writeAncestor, accessAncestor))
      continue;
    // A reachable access that is not dominated by the overwrite may observe
    // the stored value, so the store cannot be removed.
    if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
                        << *accessAncestor << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}
/// Tries to replace `read` with the vector stored by an earlier transfer_write
/// to the same view. On success the uses of `read` are redirected to the
/// written vector and `read` is queued for erasure.
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
// Reads with an out-of-bounds dimension are not handled.
if (read.hasOutOfBoundsDim())
return;
LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
<< "\n");
// Writes that may interfere with the forwarding.
SmallVector<Operation *, 8> blockingWrites;
// Best forwarding candidate found so far.
vector::TransferWriteOp lastwrite = nullptr;
Value source =
memref::skipSubViewsAndCasts(cast<MemrefValue>(read.getSource()));
llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
source.getUsers().end());
llvm::SmallDenseSet<Operation *, 32> processed;
while (!users.empty()) {
Operation *user = users.pop_back_val();
// Skip users that were already processed.
if (!processed.insert(user).second)
continue;
// Look through subview, collapse_shape and cast ops: their users are
// visited in turn.
if (isa<memref::SubViewOp, memref::CollapseShapeOp, memref::CastOp>(user)) {
users.append(user->getUsers().begin(), user->getUsers().end());
continue;
}
// Ops without memory effects and other transfer reads are ignored.
if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
continue;
if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
// Writes provably disjoint from the read are ignored.
if (vector::isDisjointTransferSet(
cast<VectorTransferOpInterface>(write.getOperation()),
cast<VectorTransferOpInterface>(read.getOperation()),
true))
continue;
// A write to the same view that dominates the read and passes
// checkSameValueRAW is a forwarding candidate.
if (memref::isSameViewOrTrivialAlias(
cast<MemrefValue>(read.getSource()),
cast<MemrefValue>(write.getSource())) &&
dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
// Keep the candidate that is dominated by the previous one; candidates
// are asserted to be ordered by dominance.
if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
lastwrite = write;
else
assert(dominators.dominates(write, lastwrite));
continue;
}
}
// Any other user with memory effects is treated as a blocking write.
blockingWrites.push_back(user);
}
if (lastwrite == nullptr)
return;
// Compare positions within the region that contains the candidate write.
Region *topRegion = lastwrite->getParentRegion();
Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
assert(readAncestor &&
"read op should be recursively part of the top region");
for (Operation *write : blockingWrites) {
Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
// Writes outside the region, or from which the read cannot be reached, are
// ignored.
if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
continue;
// Give up unless the candidate write post-dominates the blocking write.
if (!postDominators.postDominates(lastwrite, write)) {
LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
<< *write << "\n");
return;
}
}
LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
<< " to: " << *read.getOperation() << "\n");
read.replaceAllUsesWith(lastwrite.getVector());
opToErase.push_back(read.getOperation());
}
/// Converts a list of mixed static/dynamic sizes into a shape with the static
/// unit sizes removed. Dynamic sizes are kept as `ShapedType::kDynamic`.
static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<int64_t> keptSizes;
  for (OpFoldResult ofr : mixedSizes) {
    if (llvm::dyn_cast_if_present<Value>(ofr)) {
      // Dynamic size: always kept.
      keptSizes.push_back(ShapedType::kDynamic);
    } else {
      // Static size: kept unless it is exactly 1.
      const auto staticSize =
          cast<IntegerAttr>(ofr.get<Attribute>()).getValue();
      if (staticSize != 1)
        keptSizes.push_back(staticSize.getSExtValue());
    }
  }
  return keptSizes;
}
/// Infers the rank-reduced subview result type of `inputType` obtained by
/// dropping the static unit entries of `sizes`, and returns it with a
/// canonicalized strided layout.
static MemRefType dropUnitDims(MemRefType inputType,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes,
                               ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> reducedShape = getReducedShape(sizes);
  auto reducedType =
      cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
          reducedShape, inputType, offsets, sizes, strides));
  return canonicalizeStridedLayout(reducedType);
}
/// Builds a full-extent `memref.subview` of `input` (zero offsets, unit
/// strides) whose result type has the unit dimensions dropped. When dropping
/// the unit dimensions does not change the canonicalized type, `input` is
/// returned unchanged and no subview is created.
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  auto inputType = cast<MemRefType>(input.getType());
  const int64_t rank = inputType.getRank();
  SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
  SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
  const bool typeChanged = canonicalizeStridedLayout(resultType) !=
                           canonicalizeStridedLayout(inputType);
  if (!typeChanged)
    return input;
  return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
                                            sizes, strides);
}
/// Returns the number of entries of `shape` that are different from 1.
static int getReducedRank(ArrayRef<int64_t> shape) {
  int nonUnitDims = 0;
  for (int64_t dimSize : shape)
    if (dimSize != 1)
      ++nonUnitDims;
  return nonUnitDims;
}
/// Returns a copy of `oldType` without its non-scalable unit dimensions.
/// Scalable dimensions are always kept, even when their size is 1.
static VectorType trimNonScalableUnitDims(VectorType oldType) {
  SmallVector<int64_t> keptShape;
  SmallVector<bool> keptScalableFlags;
  for (auto [dimSize, isScalable] :
       llvm::zip_equal(oldType.getShape(), oldType.getScalableDims())) {
    const bool isFixedUnitDim = dimSize == 1 && !isScalable;
    if (isFixedUnitDim)
      continue;
    keptShape.push_back(dimSize);
    keptScalableFlags.push_back(isScalable);
  }
  return VectorType::get(keptShape, oldType.getElementType(),
                         keptScalableFlags);
}
/// Re-creates the `vector.create_mask` `op` without its non-scalable unit
/// dimensions. Fails when there is nothing to drop, or when the operand of a
/// dropped dimension is not a constant index equal to 1.
static FailureOr<Value>
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
                                  vector::CreateMaskOp op) {
  auto maskType = op.getType();
  VectorType reducedType = trimNonScalableUnitDims(maskType);
  // No dimension was removed: nothing to do.
  if (reducedType.getRank() == maskType.getRank())
    return failure();

  SmallVector<Value> keptOperands;
  for (auto [dimSize, isScalable, bound] : llvm::zip_equal(
           maskType.getShape(), maskType.getScalableDims(), op.getOperands())) {
    const bool isDropped = dimSize == 1 && !isScalable;
    if (!isDropped) {
      keptOperands.push_back(bound);
      continue;
    }
    // A dropped dimension must be fully enabled, i.e. its bound is the
    // constant 1.
    auto constantBound = bound.getDefiningOp<arith::ConstantIndexOp>();
    if (!constantBound || (constantBound.value() != 1))
      return failure();
  }
  return rewriter.create<vector::CreateMaskOp>(loc, reducedType, keptOperands)
      .getResult();
}
namespace {
/// Rewrites a `vector.transfer_read` from a memref with unit dimensions into
/// a read from a rank-reduced subview followed by a `vector.shape_cast` back
/// to the original vector type.
///
/// Only applies to in-bounds, minor-identity reads whose indices are all the
/// constant 0, and whose mask (if any) is produced by `vector.create_mask`.
class TransferReadDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    Location loc = transferReadOp.getLoc();
    auto vectorType = cast<VectorType>(transferReadOp.getVector().getType());
    Value source = transferReadOp.getSource();
    auto sourceType = dyn_cast<MemRefType>(source.getType());
    // Only memref sources are handled.
    if (!sourceType)
      return failure();
    if (transferReadOp.hasOutOfBoundsDim() ||
        !transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();

    // The source must have at least one unit dimension to drop.
    const int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // The reduced vector rank must match the reduced source rank.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    // Every index must be the constant 0.
    const bool allIndicesZero =
        llvm::all_of(transferReadOp.getIndices(), [](Value index) {
          return getConstantIntValue(index) == static_cast<int64_t>(0);
        });
    if (!allIndicesZero)
      return failure();

    // Rank-reduce the mask, if there is one.
    Value mask = transferReadOp.getMask();
    if (mask) {
      auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
                            "currently supported");
      FailureOr<Value> reducedMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(reducedMask))
        return failure();
      mask = *reducedMask;
    }

    // Build the rank-reduced read and cast its result back.
    Value reducedSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeroIndices(reducedRank, zero);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto reducedRead = rewriter.create<vector::TransferReadOp>(
        loc, reducedVectorType, reducedSource, zeroIndices, identityMap,
        transferReadOp.getPadding(), mask, rewriter.getBoolArrayAttr(inBounds));
    auto castBack = rewriter.createOrFold<vector::ShapeCastOp>(loc, vectorType,
                                                               reducedRead);
    rewriter.replaceOp(transferReadOp, castBack);
    return success();
  }
};
/// Rewrites a `vector.transfer_write` to a memref with unit dimensions into a
/// `vector.shape_cast` of the written vector followed by a write to a
/// rank-reduced subview.
///
/// Only applies to in-bounds, minor-identity writes whose indices are all the
/// constant 0, and whose mask (if any) is produced by `vector.create_mask`.
class TransferWriteDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    Location loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    auto vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    auto sourceType = dyn_cast<MemRefType>(source.getType());
    // Only memref destinations are handled.
    if (!sourceType)
      return failure();
    if (transferWriteOp.hasOutOfBoundsDim() ||
        !transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();

    // The destination must have at least one unit dimension to drop.
    const int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // The reduced vector rank must match the reduced destination rank.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    // Every index must be the constant 0.
    const bool allIndicesZero =
        llvm::all_of(transferWriteOp.getIndices(), [](Value index) {
          return getConstantIntValue(index) == static_cast<int64_t>(0);
        });
    if (!allIndicesZero)
      return failure();

    // Rank-reduce the mask, if there is one.
    Value mask = transferWriteOp.getMask();
    if (mask) {
      auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferWriteOp, "unsupported mask op, only 'vector.create_mask' "
                             "is currently supported");
      FailureOr<Value> reducedMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(reducedMask))
        return failure();
      mask = *reducedMask;
    }

    // Build the rank-reduced destination, cast the vector and write it.
    Value reducedSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeroIndices(reducedRank, zero);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto reducedVector = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        transferWriteOp, Type(), reducedVector, reducedSource, zeroIndices,
        identityMap, mask, rewriter.getBoolArrayAttr(inBounds));
    return success();
  }
};
}
/// Creates a `memref.collapse_shape` of `input` that keeps the dimensions
/// before `firstDimToCollapse` as they are and folds all the dimensions from
/// `firstDimToCollapse` onwards into a single one. A rank-1 input is returned
/// unchanged.
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  const int64_t rank = cast<ShapedType>(input.getType()).getRank();
  if (rank == 1)
    return input;

  SmallVector<ReassociationIndices> groups;
  // Leading dimensions: one group per dimension.
  int64_t dim = 0;
  while (dim < firstDimToCollapse) {
    groups.push_back(ReassociationIndices{dim});
    ++dim;
  }
  // Trailing dimensions: a single group containing all of them.
  ReassociationIndices innerGroup;
  for (dim = firstDimToCollapse; dim < rank; ++dim)
    innerGroup.push_back(dim);
  groups.push_back(innerGroup);
  return rewriter.create<memref::CollapseShapeOp>(loc, input, groups);
}
/// Returns the indices to use once the dimensions of `shape` starting at
/// `firstDimToCollapse` have been collapsed into one: the leading indices are
/// forwarded unchanged and the trailing ones are replaced by a single index.
///
/// When all the trailing indices are zero, the first of them is reused.
/// Otherwise they are linearized using the suffix product of the collapsed
/// part of `shape` as strides, and the result is materialized as a Value.
static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
                                              Location loc,
                                              ArrayRef<int64_t> shape,
                                              ValueRange indices,
                                              int64_t firstDimToCollapse) {
  assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));

  auto splitPoint = indices.begin() + firstDimToCollapse;
  SmallVector<Value> result(indices.begin(), splitPoint);
  SmallVector<Value> innerIndices(splitPoint, indices.end());

  // Fast path: nothing to compute when every collapsed index is zero.
  const bool innerAllZero = llvm::all_of(innerIndices, isZeroIndex);
  if (innerAllZero) {
    result.push_back(innerIndices[0]);
    return result;
  }

  // Linearize the collapsed indices, starting from a zero offset.
  OpFoldResult linearOffset =
      rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
  auto innerStrides = computeSuffixProduct(
      ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));
  auto &&[linearExpr, linearOperands] =
      computeLinearIndex(linearOffset, innerStrides, innerIndices);
  linearOffset = affine::makeComposedFoldedAffineApply(rewriter, loc,
                                                       linearExpr,
                                                       linearOperands);

  // Materialize a constant op when the offset folded to an attribute.
  if (!linearOffset.is<Value>()) {
    result.push_back(rewriter.create<arith::ConstantIndexOp>(
        loc, *getConstantIntValue(linearOffset)));
    return result;
  }
  result.push_back(linearOffset.get<Value>());
  return result;
}
namespace {
/// Rewrites an n-D `vector.transfer_read` of a contiguous memref slice into a
/// 1-D read from a collapsed memref followed by a `vector.shape_cast` to the
/// original vector type.
///
/// The pattern only fires when the bitwidth of the trailing vector dimension
/// is smaller than `targetVectorBitwidth`, and when the read is in-bounds,
/// unmasked and uses a minor-identity permutation map.
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
                                               unsigned vectorBitwidth,
                                               PatternBenefit benefit)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    Location loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    auto vectorType = cast<VectorType>(vector.getType());
    auto source = transferReadOp.getSource();
    auto sourceType = dyn_cast<MemRefType>(source.getType());
    // Only memref sources are handled.
    if (!sourceType)
      return failure();
    // 0-D and 1-D vectors need no flattening.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    // Skip vectors whose trailing dimension already fills the target width.
    unsigned trailingDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    if (transferReadOp.hasOutOfBoundsDim() ||
        !transferReadOp.getPermutationMap().isMinorIdentity() ||
        transferReadOp.getMask())
      return failure();

    // 1. Collapse the trailing dimensions of the source memref.
    const int64_t firstDimToCollapse =
        sourceType.getRank() - vectorType.getRank();
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    auto collapsedSourceType = cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Permutation map selecting the collapsed dimension, and new indices.
    MLIRContext *ctx = rewriter.getContext();
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, ctx)};
    auto collapsedMap = AffineMap::get(collapsedRank, 0, dimExprs, ctx);
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferReadOp.getIndices(), firstDimToCollapse);

    // 3. Read a flat vector from the collapsed memref.
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    auto flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Cast the flat vector back to the original vector type.
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transferReadOp, vectorType,
                                                     flatRead);
    return success();
  }

private:
  // Minimum trailing-dimension bitwidth above which the pattern is disabled.
  unsigned targetVectorBitwidth;
};
/// Rewrites an n-D `vector.transfer_write` to a contiguous memref slice into
/// a `vector.shape_cast` to a 1-D vector followed by a 1-D write to a
/// collapsed memref.
///
/// The pattern only fires when the bitwidth of the trailing vector dimension
/// is smaller than `targetVectorBitwidth`, and when the write is in-bounds,
/// unmasked and uses a minor-identity permutation map.
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
public:
  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
                                                unsigned vectorBitwidth,
                                                PatternBenefit benefit)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    Location loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    auto vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    auto sourceType = dyn_cast<MemRefType>(source.getType());
    // Only memref destinations are handled.
    if (!sourceType)
      return failure();
    // 0-D and 1-D vectors need no flattening.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    // Skip vectors whose trailing dimension already fills the target width.
    unsigned trailingDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    if (transferWriteOp.hasOutOfBoundsDim() ||
        !transferWriteOp.getPermutationMap().isMinorIdentity() ||
        transferWriteOp.getMask())
      return failure();

    // 1. Collapse the trailing dimensions of the destination memref.
    const int64_t firstDimToCollapse =
        sourceType.getRank() - vectorType.getRank();
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    auto collapsedSourceType = cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Permutation map selecting the collapsed dimension, and new indices.
    MLIRContext *ctx = rewriter.getContext();
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, ctx)};
    auto collapsedMap = AffineMap::get(collapsedRank, 0, dimExprs, ctx);
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferWriteOp.getIndices(), firstDimToCollapse);

    // 3. Flatten the vector and write it to the collapsed memref.
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    auto flatWrite = rewriter.create<vector::TransferWriteOp>(
        loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. The original write is now redundant.
    rewriter.eraseOp(transferWriteOp);
    return success();
  }

private:
  // Minimum trailing-dimension bitwidth above which the pattern is disabled.
  unsigned targetVectorBitwidth;
};
/// Shared matching logic for patterns that rewrite a scalar extract
/// (`vector.extract` / `vector.extractelement`) of a `vector.transfer_read`
/// result into a scalar load.
///
/// `allowMultipleUses` selects the use constraint on the transfer_read: when
/// false its result must have exactly one use; when true all of its uses must
/// be `vector.extract` or `vector.extractelement` ops.
template <class VectorExtractOp>
class RewriteScalarExtractOfTransferReadBase
    : public OpRewritePattern<VectorExtractOp> {
  using Base = OpRewritePattern<VectorExtractOp>;

public:
  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
                                         PatternBenefit benefit,
                                         bool allowMultipleUses)
      : Base::OpRewritePattern(context, benefit),
        allowMultipleUses(allowMultipleUses) {}

  LogicalResult match(VectorExtractOp extractOp) const override {
    auto xferOp =
        extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Only scalar extraction is handled, not sub-vector extraction.
    if (isa<VectorType>(extractOp.getResult().getType()))
      return failure();

    // Check the constraint on the uses of the transfer_read.
    if (allowMultipleUses) {
      auto isExtractUse = [](OpOperand &use) {
        return isa<vector::ExtractOp, vector::ExtractElementOp>(
            use.getOwner());
      };
      if (!llvm::all_of(xferOp->getUses(), isExtractUse))
        return failure();
    } else if (!xferOp.getResult().hasOneUse()) {
      return failure();
    }

    // Masked, permuted or potentially out-of-bounds reads are not supported.
    if (xferOp.getMask() || !xferOp.getPermutationMap().isMinorIdentity() ||
        xferOp.hasOutOfBoundsDim())
      return failure();
    return success();
  }

private:
  bool allowMultipleUses;
};
/// Rewrites `vector.extractelement` of a `vector.transfer_read` into a
/// `memref.load` (memref source) or `tensor.extract` (other sources). The
/// extract position, when present, is added to the last transfer index.
class RewriteScalarExtractElementOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractElementOp extractOp,
               PatternRewriter &rewriter) const override {
    Location loc = extractOp.getLoc();
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());

    if (Value position = extractOp.getPosition()) {
      // Fold `lastIndex + position` into a single index.
      Value &lastIndex = newIndices[newIndices.size() - 1];
      AffineExpr sym0, sym1;
      bindSymbols(extractOp.getContext(), sym0, sym1);
      OpFoldResult sum = affine::makeComposedFoldedAffineApply(
          rewriter, loc, sym0 + sym1, {lastIndex, position});
      if (sum.is<Value>())
        lastIndex = sum.get<Value>();
      else
        lastIndex = rewriter.create<arith::ConstantIndexOp>(
            loc, *getConstantIntValue(sum));
    }

    // Emit the scalar access matching the kind of source.
    const bool isMemRefSource = isa<MemRefType>(xferOp.getSource().getType());
    if (!isMemRefSource) {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
      return;
    }
    rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                newIndices);
  }
};
/// Rewrites `vector.extract` of a `vector.transfer_read` into a `memref.load`
/// (memref source) or `tensor.extract` (other sources). Each extract position
/// is added to the corresponding trailing transfer index.
///
/// Both constant and dynamic extract positions are supported: a constant
/// position is folded as an affine constant, a dynamic one is composed as an
/// additional affine operand.
class RewriteScalarExtractOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractOp extractOp,
               PatternRewriter &rewriter) const override {
    Location loc = extractOp.getLoc();
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      // The extract positions apply to the trailing transfer indices.
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;

      // Compute `newIndices[idx] + pos`.
      OpFoldResult ofr;
      if (auto attr = llvm::dyn_cast_if_present<Attribute>(pos)) {
        int64_t offset = cast<IntegerAttr>(attr).getInt();
        ofr = affine::makeComposedFoldedAffineApply(
            rewriter, loc, rewriter.getAffineSymbolExpr(0) + offset,
            {newIndices[idx]});
      } else {
        // Dynamic position: it is an operand of the extract op, so it is
        // available at the insertion point of the rewrite.
        AffineExpr sym0, sym1;
        bindSymbols(extractOp.getContext(), sym0, sym1);
        ofr = affine::makeComposedFoldedAffineApply(
            rewriter, loc, sym0 + sym1, {newIndices[idx], pos.get<Value>()});
      }

      if (ofr.is<Value>()) {
        newIndices[idx] = ofr.get<Value>();
      } else {
        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
            loc, *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};
/// Rewrites a `vector.transfer_write` of a vector whose dimensions are all of
/// size 1 into a scalar extract followed by a `memref.store` (memref
/// destination) or `tensor.insert` (other destinations). Masked writes and
/// writes with a non-minor-identity permutation map are not handled.
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = xferOp.getVectorType();
    // Every vector dimension must have size 1.
    const bool hasNonUnitDim = llvm::any_of(
        vecType.getShape(), [](int64_t dimSize) { return dimSize != 1; });
    if (hasNonUnitDim)
      return failure();
    if (xferOp.getMask() || !xferOp.getPermutationMap().isMinorIdentity())
      return failure();

    // Extract the single element; 0-D vectors go through
    // vector.extractelement, other ranks through vector.extract at position
    // [0, ..., 0].
    Location loc = xferOp.getLoc();
    Value scalar;
    if (vecType.getRank() != 0) {
      SmallVector<int64_t> pos(vecType.getRank(), 0);
      scalar = rewriter.create<vector::ExtractOp>(loc, xferOp.getVector(), pos);
    } else {
      scalar =
          rewriter.create<vector::ExtractElementOp>(loc, xferOp.getVector());
    }

    // Emit the scalar store matching the kind of destination.
    const bool isMemRefDest = isa<MemRefType>(xferOp.getSource().getType());
    if (!isMemRefDest) {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
      return success();
    }
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    return success();
  }
};
}
/// Runs store-to-load forwarding and then dead store elimination on the
/// memref-based vector transfer ops nested under `rootOp`. Operations found
/// redundant by each phase are erased before the next phase starts.
void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  TransferOptimization opt(rewriter, rootOp);

  // Phase 1: forward stored vectors to reads of memrefs.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (!isa<MemRefType>(read.getShapedType()))
      return;
    opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();

  // Phase 2: remove dead writes to memrefs.
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (!isa<MemRefType>(write.getShapedType()))
      return;
    opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}
/// Registers the patterns rewriting scalar extracts of transfer reads and
/// single-element transfer writes into scalar load/store operations.
/// `allowMultipleUses` is forwarded to the extract patterns only.
void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit,
    bool allowMultipleUses) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<RewriteScalarExtractElementOfTransferRead>(ctx, benefit,
                                                          allowMultipleUses);
  patterns.add<RewriteScalarExtractOfTransferRead>(ctx, benefit,
                                                   allowMultipleUses);
  patterns.add<RewriteScalarWrite>(ctx, benefit);
}
/// Registers the patterns dropping unit dimensions from transfer read/write
/// ops on memrefs, together with the shape_cast folding patterns that clean
/// up the casts they introduce. `benefit` is applied to all of them, as is
/// done in `populateFlattenVectorTransferPatterns`.
void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
  // Forward the requested benefit instead of silently using the default one.
  populateShapeCastFoldingPatterns(patterns, benefit);
}
/// Registers the patterns flattening contiguous n-D transfer read/write ops
/// into 1-D ones, plus the shape_cast folding and unit-dim dropping patterns,
/// all with the given `benefit`. `targetVectorBitwidth` is forwarded to the
/// flattening patterns.
void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
    PatternBenefit benefit) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<FlattenContiguousRowMajorTransferReadPattern>(
      ctx, targetVectorBitwidth, benefit);
  patterns.add<FlattenContiguousRowMajorTransferWritePattern>(
      ctx, targetVectorBitwidth, benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
}