#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
#include <type_traits>
using namespace mlir;
using namespace mlir::linalg;
#define DEBUG_TYPE "linalg-vectorization"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
/// Forward declaration of the convolution vectorizer, which is defined later
/// in this file (outside the section shown here). `inputVecSizes` and
/// `inputVecScalableFlags` presumably play the same role as the vector sizes /
/// scalable flags handled by VectorizationState, and `flatten1DDepthwiseConv`
/// presumably requests a flattened lowering for 1-D depthwise convolutions —
/// confirm against the definition.
static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
ArrayRef<int64_t> inputVecSizes = {},
ArrayRef<bool> inputVecScalableFlags = {},
bool flatten1DDepthwiseConv = false);
/// Returns the only operation of type `OpType` nested anywhere inside
/// `block`. Returns a null op if there is no such operation or if there is
/// more than one; the walk stops as soon as a second match is seen.
template <typename OpType>
static OpType getSingleOpOfType(Block &block) {
  OpType found;
  bool seenDuplicate = false;
  block.walk([&](OpType candidate) -> WalkResult {
    if (!found) {
      found = candidate;
      return WalkResult::advance();
    }
    seenDuplicate = true;
    return WalkResult::interrupt();
  });
  return seenDuplicate ? OpType() : found;
}
/// Extracts the input slices consumed by every (kw, w) step of a 1-D
/// convolution, in kw-major order (for each kw, for each w step).
///
/// Single-channeled (W-only) convolutions take 1-D slices of size
/// {wSizeStep}; channeled ones take slices of size {nSize, wSizeStep, cSize}.
/// In both cases the slice starts at W offset `w * strideW + kw * dilationW`.
/// All slice strides are 1.
static SmallVector<Value>
extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
                       int64_t nSize, int64_t wSize, int64_t cSize,
                       int64_t kwSize, int strideW, int dilationW,
                       int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Use the same stride/dilation-aware offset as the channeled case. It
    // reduces to `w + kw` for unit stride and dilation, and stays correct
    // otherwise.
    SmallVector<int64_t> sizes{wSizeStep};
    SmallVector<int64_t> strides{1};
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, input, ArrayRef<int64_t>{w * strideW + kw * dilationW},
            sizes, strides));
      }
    }
  } else {
    SmallVector<int64_t> sizes{nSize, wSizeStep, cSize};
    SmallVector<int64_t> strides{1, 1, 1};
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, input,
            ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            sizes, strides));
      }
    }
  }
  return result;
}
/// Splits `filter` along its leading dimension into `kwSize` sub-vectors,
/// one per kernel-width position, returned in increasing position order.
static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
                                                  Location loc, Value filter,
                                                  int64_t kwSize) {
  SmallVector<Value> slices;
  int64_t pos = 0;
  while (pos < kwSize) {
    Value slice =
        rewriter.create<vector::ExtractOp>(loc, filter, ArrayRef<int64_t>{pos});
    slices.push_back(slice);
    ++pos;
  }
  return slices;
}
/// Extracts one result slice per W step from `res`, in increasing W order.
///
/// W-only (single-channeled) convolutions use slices of shape {wSizeStep} at
/// offset [w]. Channeled ones use slices of shape {nSize, wSizeStep, fSize}
/// at offset [0, w, 0]. All slice strides are 1.
static SmallVector<Value>
extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
                        int64_t nSize, int64_t wSize, int64_t fSize,
                        int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<int64_t> sliceShape =
      isSingleChanneled ? SmallVector<int64_t>{wSizeStep}
                        : SmallVector<int64_t>{nSize, wSizeStep, fSize};
  SmallVector<int64_t> unitStrides(sliceShape.size(), 1);

  SmallVector<Value> slices;
  for (int64_t w = 0; w < wSize; w += wSizeStep) {
    SmallVector<int64_t> offsets = isSingleChanneled
                                       ? SmallVector<int64_t>{w}
                                       : SmallVector<int64_t>{0, w, 0};
    slices.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
        loc, res, offsets, sliceShape, unitStrides));
  }
  return slices;
}
/// Inserts the per-step result slices in `resVals` back into `res` and
/// returns the updated vector.
///
/// `resVals` must hold one slice per W step, in the order produced by
/// extractConvResultSlices. Slot `i` is therefore written at W offset
/// `i * wSizeStep`: [w] for W-only convolutions, [0, w, 0] otherwise. All
/// strides are 1.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
                                    Value res, int64_t wSize, int64_t wSizeStep,
                                    SmallVectorImpl<Value> &resVals,
                                    bool isSingleChanneled) {
  // Index `resVals` by step count rather than by the W offset. The two only
  // coincide when wSizeStep is 1 or equal to wSize.
  if (isSingleChanneled) {
    SmallVector<int64_t> strides{1};
    for (int64_t w = 0, slot = 0; w < wSize; w += wSizeStep, ++slot) {
      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[slot], res, ArrayRef<int64_t>{w}, strides);
    }
  } else {
    SmallVector<int64_t> strides{1, 1, 1};
    for (int64_t w = 0, slot = 0; w < wSize; w += wSizeStep, ++slot) {
      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[slot], res, ArrayRef<int64_t>{0, w, 0}, strides);
    }
  }
  return res;
}
/// Per-op state shared by the vectorization helpers in this file.
///
/// It records:
/// - the canonical vector shape and its scalable flags;
/// - the static and SSA-value sizes of the op's iteration space;
/// - a cache of masks keyed by masking map.
/// The `rewriterGuard` member is an OpBuilder::InsertionGuard, so the
/// rewriter's insertion point at construction time is restored when the state
/// is destroyed.
struct VectorizationState {
VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
/// Sets the rewriter's insertion point before `linalgOp` and initializes the
/// canonical vector shape and scalable flags. These come from
/// `inputVectorSizes` / `inputScalableVecDims` when sizes are given, and from
/// the op's static loop ranges otherwise. Also initializes the
/// iteration-space sizes. Fails if the canonical shape is dynamic or a dynamic
/// iteration dim cannot be mapped to an operand.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
ArrayRef<int64_t> inputVectorSizes,
ArrayRef<bool> inputScalableVecDims);
/// Returns the canonical vector shape.
ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
/// Returns, for each canonical vector dim, whether it is scalable.
ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }
/// Returns a vector type with `elementType` whose shape and scalable flags are
/// the canonical ones, optionally permuted by `dimPermutation` via
/// applyPermutationMap.
VectorType getCanonicalVecType(
Type elementType,
std::optional<AffineMap> dimPermutation = std::nullopt) const {
SmallVector<int64_t> vectorShape;
SmallVector<bool> scalableDims;
if (dimPermutation.has_value()) {
vectorShape =
applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
scalableDims =
applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
} else {
vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
}
return VectorType::get(vectorShape, elementType, scalableDims);
}
/// Wraps `opToMask` in a mask when the iteration space requires one. Returns
/// the resulting op, or `opToMask` itself when no mask is needed.
/// `maybeMaskingMap` defaults to the identity map over the op's loops.
Operation *
maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
std::optional<AffineMap> maybeMaskingMap = std::nullopt);
private:
/// Fills `iterSpaceStaticSizes` from the op's static loop ranges.
void initIterSpaceStaticSizes(LinalgOp linalgOp) {
iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
}
/// Materializes one SSA size per canonical vector dim into
/// `iterSpaceValueSizes`, as index constants or dim ops.
LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
LinalgOp linalgOp);
/// Returns a cached or newly created mask for `opToMask` under the given
/// masking map. Returns a null Value when the op is not maskable or no mask
/// is needed.
Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
LinalgOp linalgOp,
std::optional<AffineMap> maybeMaskingMap);
/// Static iteration-space sizes; entries may be dynamic.
SmallVector<int64_t> iterSpaceStaticSizes;
/// SSA values holding the iteration-space sizes, one per canonical vector dim.
SmallVector<Value> iterSpaceValueSizes;
/// Vector shape used to vectorize the op's iteration space.
SmallVector<int64_t> canonicalVecShape;
/// Scalable flag for each entry of `canonicalVecShape`.
SmallVector<bool> scalableVecDims;
/// Masks created so far, keyed by masking map. A null Value records that no
/// mask is needed for that map.
DenseMap<AffineMap, Value> activeMaskCache;
/// Restores the rewriter's insertion point when this state is destroyed.
OpBuilder::InsertionGuard rewriterGuard;
};
/// Materializes one SSA size per canonical vector dimension into
/// `iterSpaceValueSizes`.
///
/// Static sizes become index constants. Dynamic sizes become a tensor.dim
/// (pure tensor semantics) or memref.dim of an operand that the iteration dim
/// maps to. Fails if a dynamic iteration dim cannot be mapped to any operand
/// dim.
LogicalResult
VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                                  LinalgOp linalgOp) {
  const int numVecDims = canonicalVecShape.size();
  for (int dimIdx = 0; dimIdx < numVecDims; ++dimIdx) {
    const int64_t staticSize = iterSpaceStaticSizes[dimIdx];
    if (!ShapedType::isDynamic(staticSize)) {
      Value sizeCst =
          rewriter.create<arith::ConstantIndexOp>(linalgOp.getLoc(), staticSize);
      iterSpaceValueSizes.push_back(sizeCst);
      continue;
    }

    Value sourceOperand;
    unsigned sourceDim;
    if (failed(linalgOp.mapIterationSpaceDimToOperandDim(dimIdx, sourceOperand,
                                                         sourceDim)))
      return failure();

    Value dimValue;
    if (linalgOp.hasPureTensorSemantics())
      dimValue = rewriter.create<tensor::DimOp>(linalgOp.getLoc(),
                                                sourceOperand, sourceDim);
    else
      dimValue = rewriter.create<memref::DimOp>(linalgOp.getLoc(),
                                                sourceOperand, sourceDim);
    iterSpaceValueSizes.push_back(dimValue);
  }
  return success();
}
/// Initializes the vectorization state for `linalgOp`.
///
/// Sets the rewriter's insertion point right before `linalgOp`. The canonical
/// vector shape comes from `inputVectorSizes` when provided, and from the
/// op's static loop ranges otherwise. Scalable flags come from
/// `inputScalableVecDims`. When that is empty they default to false, so
/// `scalableVecDims` always has exactly one entry per canonical vector dim.
/// Fails if the canonical shape is dynamic or the iteration-space sizes
/// cannot be materialized.
LogicalResult
VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
                              ArrayRef<int64_t> inputVectorSizes,
                              ArrayRef<bool> inputScalableVecDims) {
  // All IR produced during vectorization goes right before the original op.
  rewriter.setInsertionPoint(linalgOp);

  if (!inputVectorSizes.empty()) {
    canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
    if (inputScalableVecDims.empty()) {
      // Scalable flags are optional. `scalableVecDims` is indexed and permuted
      // alongside `canonicalVecShape` (e.g. getCanonicalVecType), so it must
      // never be shorter than the shape.
      scalableVecDims.append(inputVectorSizes.size(), false);
    } else {
      assert(inputScalableVecDims.size() == inputVectorSizes.size() &&
             "expected one scalable flag per input vector size");
      scalableVecDims.append(inputScalableVecDims.begin(),
                             inputScalableVecDims.end());
    }
  } else {
    canonicalVecShape = linalgOp.getStaticLoopRanges();
    scalableVecDims.append(linalgOp.getNumLoops(), false);
  }

  LDBG("Canonical vector shape: ");
  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");
  LDBG("Scalable vector dims: ");
  LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (ShapedType::isDynamicShape(canonicalVecShape))
    return failure();

  initIterSpaceStaticSizes(linalgOp);
  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
    return failure();
  return success();
}
/// Returns the mask to apply to `opToMask`, creating and caching it if needed.
///
/// Returns a null Value in two cases:
/// - `opToMask` does not implement MaskableOpInterface;
/// - the static iteration-space sizes, permuted by the masking map, already
///   equal the mask vector shape, so no masking is needed.
/// `maybeMaskingMap` defaults to the identity map over the op's loops. The
/// result, including the "no mask needed" null Value, is cached per masking
/// map in `activeMaskCache`, so ops sharing a masking map reuse one mask.
Value VectorizationState::getOrCreateMaskFor(
RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
std::optional<AffineMap> maybeMaskingMap) {
auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
if (!maskableOp)
return Value();
assert(!maskableOp.isMasked() &&
"Masking an operation that is already masked");
assert((!maybeMaskingMap || *maybeMaskingMap) &&
"Unexpected null mask permutation map");
AffineMap maskingMap =
maybeMaskingMap ? *maybeMaskingMap
: AffineMap::getMultiDimIdentityMap(
linalgOp.getNumLoops(), rewriter.getContext());
LDBG("Masking map: " << maskingMap << "\n");
// Reuse a previously computed result (mask or "no mask") for this map.
auto activeMaskIt = activeMaskCache.find(maskingMap);
if (activeMaskIt != activeMaskCache.end()) {
Value mask = activeMaskIt->second;
LDBG("Reusing mask: " << mask << "\n");
return mask;
}
// The mask type is the canonical vector type permuted by the masking map.
// Compare it against the iteration-space sizes permuted the same way.
SmallVector<int64_t> permutedStaticSizes =
applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
auto maskShape = maskType.getShape();
LDBG("Mask shape: ");
LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << "\n");
// Static sizes that match the vector shape exactly need no mask. Record that
// decision as a null Value.
if (permutedStaticSizes == maskShape) {
LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
activeMaskCache[maskingMap] = Value();
return Value();
}
// Otherwise build a vector.create_mask bounded by the SSA iteration-space
// sizes (constants or dim ops), permuted by the masking map.
SmallVector<Value> upperBounds =
applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
assert(!maskShape.empty() && !upperBounds.empty() &&
"Masked 0-d vectors are not supported yet");
Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
maskType, upperBounds);
LDBG("Creating new mask: " << mask << "\n");
activeMaskCache[maskingMap] = mask;
return mask;
}
/// Masks `opToMask` if the iteration space requires it.
///
/// When no mask is needed, including when the op is not maskable, returns
/// `opToMask` unchanged. Otherwise it:
/// - wraps the op in a `vector.mask` using the mask from getOrCreateMaskFor;
/// - redirects all uses of the original results, except the one in the mask
///   region's terminator, to the mask op's results;
/// - returns the mask op.
Operation *
VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
                                  LinalgOp linalgOp,
                                  std::optional<AffineMap> maybeMaskingMap) {
  // Check the precondition before `opToMask` is first dereferenced, both by
  // the debug print and by the dyn_cast in getOrCreateMaskFor.
  assert(opToMask && "Expected a valid operation to mask");
  LDBG("Trying to mask: " << *opToMask << "\n");

  Value mask =
      getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
  if (!mask) {
    LDBG("No mask required\n");
    return opToMask;
  }

  auto maskOp = cast<vector::MaskOp>(
      mlir::vector::maskOperation(rewriter, opToMask, mask));
  // The terminator yields the original results and must keep using them.
  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();

  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
    rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
                                  maskOpTerminator);

  LDBG("Masked operation: " << *maskOp << "\n");
  return maskOp;
}
/// Compresses away the unused dims of `map`, which must be a projected
/// permutation (zero results allowed). The resulting map has as many dims as
/// results.
static AffineMap reindexIndexingMap(AffineMap map) {
  assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
         "expected projected permutation");
  AffineMap compressed = compressUnusedDims(map);
  assert(compressed.getNumDims() == compressed.getNumResults() &&
         "expected reindexed map with same number of dims and results");
  return compressed;
}
/// Operand layouts of the 1-D convolutions handled by the conv vectorizer.
/// The enumerator names presumably spell the dimension order (W = width only;
/// N = batch, C = channel, W = width). Confirm against the conv vectorizer
/// that consumes this enum.
enum class Conv1DOpOrder {
W,
Ncw,
Nwc
};
/// Outcome of vectorizing a single body op (see vectorizeOneOp and its
/// callers).
enum VectorizationStatus {
/// Not handled. vectorizeOneOp moves on to the next custom hook, and
/// vectorizeAsLinalgGeneric fails when the final result has this status.
Failure = 0,
/// Handled, but no replacement op is recorded in the value mapping (used for
/// linalg.yield).
NoReplace,
/// Vectorized into `VectorizationResult::newOp`, whose results are mapped in
/// place of the original op's results.
NewOp
};
/// Result of vectorizing a single op from a linalg body.
///
/// `status` says how the result is consumed (see VectorizationStatus).
/// `newOp` is the replacement op for NewOp results and null in this file's
/// other uses. Both members have default initializers, so a default-constructed
/// result is a well-defined Failure with no op rather than one carrying an
/// indeterminate pointer. The struct remains an aggregate, so
/// `VectorizationResult{status, op}` still works.
struct VectorizationResult {
  enum VectorizationStatus status = VectorizationStatus::Failure;
  Operation *newOp = nullptr;
};
/// Maps a recognized arith combiner op to its vector::CombiningKind. Returns
/// std::nullopt for a null op or any op that is not a supported combiner.
std::optional<vector::CombiningKind>
mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
  using ::mlir::vector::CombiningKind;
  if (!combinerOp)
    return std::nullopt;

  if (isa<arith::AddIOp, arith::AddFOp>(combinerOp))
    return CombiningKind::ADD;
  if (isa<arith::MulIOp, arith::MulFOp>(combinerOp))
    return CombiningKind::MUL;
  if (isa<arith::AndIOp>(combinerOp))
    return CombiningKind::AND;
  if (isa<arith::OrIOp>(combinerOp))
    return CombiningKind::OR;
  if (isa<arith::XOrIOp>(combinerOp))
    return CombiningKind::XOR;
  if (isa<arith::MaxSIOp>(combinerOp))
    return CombiningKind::MAXSI;
  if (isa<arith::MaxUIOp>(combinerOp))
    return CombiningKind::MAXUI;
  if (isa<arith::MaximumFOp>(combinerOp))
    return CombiningKind::MAXIMUMF;
  if (isa<arith::MinSIOp>(combinerOp))
    return CombiningKind::MINSI;
  if (isa<arith::MinUIOp>(combinerOp))
    return CombiningKind::MINUI;
  if (isa<arith::MinimumFOp>(combinerOp))
    return CombiningKind::MINIMUMF;
  return std::nullopt;
}
/// Returns the single combiner op of the reduction that feeds
/// `outputOperand`'s region argument. Returns nullptr if no reduction matches
/// or it does not consist of exactly one combiner op.
static Operation *matchLinalgReduction(OpOperand *outputOperand) {
  auto owner = cast<LinalgOp>(outputOperand->getOwner());
  unsigned initIdx =
      outputOperand->getOperandNumber() - owner.getNumDpsInputs();
  SmallVector<Operation *, 4> combiners;
  Value reduced =
      matchReduction(owner.getRegionOutputArgs(), initIdx, combiners);
  bool hasSingleCombiner = reduced && combiners.size() == 1;
  return hasSingleCombiner ? combiners.front() : nullptr;
}
/// Broadcasts `value` to `dstType` when that is both meaningful and possible.
///
/// Returns `value` unchanged in these cases:
/// - `dstType` is not a vector type;
/// - `dstType` is a 0-D vector;
/// - `value`'s type cannot be broadcast to `dstType`.
/// Otherwise creates, or folds, a vector.broadcast at the builder's insertion
/// point.
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
  auto dstVecType = dyn_cast<VectorType>(dstType);
  // dyn_cast yields a null type for non-vector destinations. Guard it before
  // calling getRank(). A 0-D destination leaves nothing to broadcast to.
  if (!dstVecType || dstVecType.getRank() == 0)
    return value;
  if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
      vector::BroadcastableToResult::Success)
    return value;
  Location loc = b.getInsertionPoint()->getLoc();
  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
}
/// Emits a vector.multi_reduction of `valueToReduce` into `acc` over the dims
/// flagged in `dimsToMask`. The combining kind comes from `reduceOp`, which
/// must be a recognized combiner (asserted).
static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
                                      Value valueToReduce, Value acc,
                                      ArrayRef<bool> dimsToMask) {
  std::optional<vector::CombiningKind> kind = getCombinerOpKind(reduceOp);
  assert(kind && "Failed precondition: could not get reduction kind");
  Location reduceLoc = reduceOp->getLoc();
  return b.create<vector::MultiDimReductionOp>(reduceLoc, valueToReduce, acc,
                                               dimsToMask, *kind);
}
/// Returns one flag per loop of `linalgOp`. A flag is true when that loop's
/// iterator type is a reduction.
static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
  SmallVector<bool> isReductionDim;
  for (auto iterType : linalgOp.getIteratorTypesArray())
    isReductionDim.push_back(isReductionIterator(iterType));
  return isReductionDim;
}
/// Writes the vectorized `value` back to `outputOperand` with a
/// vector.transfer_write, masked via `state` when required.
///
/// The written vector type is the canonical vector type restricted to the
/// loop dims that appear in the operand's indexing map.
/// - Rank > 0: `value` is broadcast to that type if needed, then written at
///   all-zero indices using the inverse of the reindexed indexing map.
/// - Rank 0: a scalar `value` is broadcast to a 0-D vector and written with
///   no indices.
/// If the write ends up masked, all of its dims are marked in-bounds.
///
/// Returns the write's first result if it has one, otherwise a null Value
/// (presumably the buffer/memref case — confirm).
static Value buildVectorWrite(RewriterBase &rewriter, Value value,
OpOperand *outputOperand,
VectorizationState &state) {
Location loc = value.getLoc();
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
// Identity map over all loop dims, filtered to the dims that the operand's
// indexing map actually uses.
AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
opOperandMap.getContext(), opOperandMap.getNumInputs(),
[&](AffineDimExpr dimExpr) -> bool {
return llvm::is_contained(opOperandMap.getResults(), dimExpr);
});
auto vectorType = state.getCanonicalVecType(
getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
Operation *write;
if (vectorType.getRank() > 0) {
AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
SmallVector<Value> indices(linalgOp.getRank(outputOperand),
rewriter.create<arith::ConstantIndexOp>(loc, 0));
value = broadcastIfNeeded(rewriter, value, vectorType);
assert(value.getType() == vectorType && "Incorrect type");
write = rewriter.create<vector::TransferWriteOp>(
loc, value, outputOperand->get(), indices, writeMap);
} else {
// 0-D case: wrap a scalar into a 0-D vector before writing.
if (!isa<VectorType>(value.getType()))
value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
assert(value.getType() == vectorType && "Incorrect type");
write = rewriter.create<vector::TransferWriteOp>(
loc, value, outputOperand->get(), ValueRange{});
}
write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
// A masked write gets all dims marked in-bounds, presumably because the mask
// already excludes the out-of-bounds lanes.
if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
}
LDBG("vectorized op: " << *write << "\n");
if (!write->getResults().empty())
return write->getResult(0);
return Value();
}
/// Callback deciding whether an op can be vectorized. The `bool` argument
/// presumably mirrors `vectorizeNDExtract` in
/// tensorExtractVectorizationPrecondition. Confirm at the use sites.
using CustomVectorizationPrecondition =
std::function<LogicalResult(Operation *, bool)>;
/// Callback that tries to vectorize one op, given the mapping from original to
/// vectorized values. Returning a Failure status makes vectorizeOneOp try the
/// next hook and then fall back to its default handling.
using CustomVectorizationHook =
std::function<VectorizationResult(Operation *, const IRMapping &)>;
/// Custom hook for linalg.yield.
///
/// Writes each yielded (already vectorized) value back to the matching DPS
/// init operand via buildVectorWrite, and appends every non-null written
/// result to `newResults`. Returns NoReplace for a yield and Failure for any
/// other op.
static VectorizationResult
vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
                     const IRMapping &bvm, VectorizationState &state,
                     LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
  if (!yieldOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};

  ValueRange yielded = yieldOp.getValues();
  for (unsigned pos = 0, e = yielded.size(); pos < e; ++pos) {
    Value vectorized = bvm.lookup(yielded[pos]);
    OpOperand *init = linalgOp.getDpsInitOperand(pos);
    if (Value written = buildVectorWrite(rewriter, vectorized, init, state))
      newResults.push_back(written);
  }
  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
}
/// Custom hook for linalg.index.
///
/// Materializes the index along loop dim `d` as a vector.step of the
/// canonical size of `d`. If `d` is the innermost canonical dim, the step is
/// the result. Otherwise the step is broadcast to the canonical shape with
/// dims `d` and innermost swapped, then transposed back with the same swap.
/// Returns Failure for ops that are not linalg.index.
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
                                                VectorizationState &state,
                                                Operation *op,
                                                LinalgOp linalgOp) {
  auto indexOp = dyn_cast<linalg::IndexOp>(op);
  if (!indexOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};

  Location loc = indexOp.getLoc();
  ArrayRef<int64_t> vecShape = state.getCanonicalVecShape();
  auto loopDim = indexOp.getDim();
  auto innermostDim = vecShape.size() - 1;

  // [0, 1, ..., N-1] along the indexed loop dimension.
  bool isScalable = state.getScalableVecDims()[loopDim];
  auto stepType = VectorType::get({vecShape[loopDim]}, rewriter.getIndexType(),
                                  isScalable);
  auto stepOp = rewriter.create<vector::StepOp>(loc, stepType);
  if (loopDim == innermostDim)
    return VectorizationResult{VectorizationStatus::NewOp, stepOp};

  // Broadcast with `loopDim` moved to the innermost position...
  auto swapPerm = llvm::to_vector(llvm::seq<unsigned>(0, vecShape.size()));
  std::swap(swapPerm[loopDim], swapPerm.back());
  AffineMap swapMap =
      AffineMap::getPermutationMap(swapPerm, linalgOp.getContext());
  VectorType bcastType =
      state.getCanonicalVecType(rewriter.getIndexType(), swapMap);
  auto bcastOp = rewriter.create<vector::BroadcastOp>(loc, bcastType, stepOp);

  // ...then transpose it back into canonical order.
  SmallVector<int64_t> transposePerm =
      llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
  std::swap(transposePerm.back(), transposePerm[loopDim]);
  auto transposeOp =
      rewriter.create<vector::TransposeOp>(loc, bcastOp, transposePerm);
  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
}
/// Checks whether `op` is a `tensor.extract` the linalg vectorizer can handle.
///
/// Fails when any of the following holds:
/// - `op` is not a tensor.extract;
/// - it does not have exactly one index and `vectorizeNDExtract` is false;
/// - its first index type is not a valid vector element type;
/// - any of its result types is not a valid vector element type.
static LogicalResult
tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
  auto extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return failure();

  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
    return failure();

  if (!extractOp.getIndices().empty()) {
    if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
      return failure();
  }

  if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
        return !VectorType::isValidElementType(type);
      }))
    return failure();

  return success();
}
/// Computes the per-lane linear offsets for a gather-based tensor.extract.
///
/// Inputs are the vectorized indices i_0..i_{n-1}, each broadcast to the
/// canonical index vector type, and the source tensor sizes d_1..d_{n-1},
/// obtained with tensor.dim. The offset starts as i_0 and is updated as
/// offset = i_k + offset * d_k for k = 1..n-1. This is a row-major
/// linearization of the indices.
static Value calculateGatherOffset(RewriterBase &rewriter,
VectorizationState &state,
tensor::ExtractOp extractOp,
const IRMapping &bvm) {
auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
auto loc = extractOp.getLoc();
Value offset = broadcastIfNeeded(
rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
const size_t numIndices = extractOp.getIndices().size();
for (size_t i = 1; i < numIndices; i++) {
// Scale the running offset by the size of source dim i...
Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
auto dimSize = broadcastIfNeeded(
rewriter,
rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
indexVecType);
offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
// ...then add index i.
auto extractOpIndex = broadcastIfNeeded(
rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
}
return offset;
}
/// How a tensor.extract in a linalg body is lowered (see
/// getTensorExtractMemoryAccessPattern and vectorizeTensorExtract):
/// - ScalarBroadcast: vector.transfer_read with a broadcasting (all-zero)
///   permutation map;
/// - Contiguous: vector.transfer_read with a minor-identity permutation map;
/// - Gather: vector.gather.
enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
/// Returns true if the index value `val`, used inside the body of `linalgOp`,
/// does not depend on the trailing loop dimension.
///
/// - Block arguments of the linalg body are not invariant; other block
///   arguments are.
/// - A `linalg.index` is invariant unless it indexes the trailing loop dim.
/// - Values defined outside the body, and constants, are invariant.
/// - Any other value is invariant iff all operands of its ancestor op in the
///   body are invariant.
///
/// Requires an iteration space with exactly one dim of size > 1 and a
/// trailing dim other than 1 (asserted).
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
  auto targetShape = linalgOp.getStaticLoopRanges();
  assert(((llvm::count_if(targetShape,
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");
  assert(targetShape.back() != 1 &&
         "1-D vectors with the trailing dim equal 1 are not yet supported");

  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return llvm::all_of(block->getArguments(),
                        [&val](Value v) { return (v != val); });

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  // Reuse the loop ranges computed above instead of rebuilding them.
  auto trailingLoopDim = targetShape.size() - 1;
  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
    return (indexOp.getDim() != trailingLoopDim);

  auto *ancestor = block->findAncestorOpInBlock(*defOp);
  // Defined outside the linalg body.
  if (!ancestor)
    return true;
  if (isa<arith::ConstantOp>(ancestor))
    return true;

  bool result = true;
  for (auto op : ancestor->getOperands())
    result &= isLoopInvariantIdx(linalgOp, op);
  return result;
}
/// Returns true if the index value `val`, used by a tensor.extract in the body
/// of `linalgOp`, is accepted as a contiguous index along the trailing loop
/// dim. Sets `foundIndexOp` when a visited `linalg.index` indexes the trailing
/// loop dim. getTensorExtractMemoryAccessPattern requires both the return
/// value and `foundIndexOp` to be true.
///
/// Accepted values:
/// - block arguments that do not belong to the linalg body;
/// - `linalg.index` results;
/// - values whose ancestor op in the body is arith.addi / arith.constant /
///   linalg.index and has at least one accepted operand.
/// Values defined outside the body by an op are rejected.
///
/// NOTE(review): `foundIndexOp` is assigned, not OR-ed, for every visited
/// linalg.index, so the last one visited wins. For example,
/// `addi(index(trailing), index(0))` ends with foundIndexOp == false, which is
/// conservative. Also, since operand results are OR-ed,
/// `addi(index(trailing), index(trailing))` appears to be accepted even though
/// its stride would be 2. Confirm whether such IR can reach here.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
bool &foundIndexOp) {
auto targetShape = linalgOp.getStaticLoopRanges();
assert(((llvm::count_if(targetShape,
[](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
"n-D vectors are not yet supported");
assert(targetShape.back() != 1 &&
"1-D vectors with the trailing dim 1 are not yet supported");
auto *block = linalgOp.getBlock();
if (isa<BlockArgument>(val))
return llvm::all_of(block->getArguments(),
[&val](Value v) { return (v != val); });
Operation *defOp = val.getDefiningOp();
assert(defOp && "This is neither a block argument nor an operation result");
auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
foundIndexOp = (indexOp.getDim() == trailingLoopDim);
return true;
}
auto *ancestor = block->findAncestorOpInBlock(*defOp);
if (!ancestor)
return false;
// Conservatively reject anything that could produce a non-unit stride.
if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
return false;
bool result = false;
for (auto op : ancestor->getOperands())
result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);
return result;
}
/// Classifies the memory access of `extractOp` inside `linalgOp`. The checks
/// run in order:
/// - 0-D source tensor: ScalarBroadcast.
/// - Op has dynamic shapes: Gather.
/// - Iteration space without exactly one dim > 1, or with a trailing dim of
///   1: Gather.
/// - Source tensor whose trailing dim is 1: Gather.
/// - Some leading index over a source dim of size != 1 is not loop
///   invariant: Gather.
/// - Trailing index is also loop invariant: ScalarBroadcast.
/// - Trailing index accepted by isContiguousLoadIdx and derived from the
///   trailing-dim linalg.index: Contiguous.
/// - Otherwise: Gather.
static VectorMemoryAccessKind
getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
LinalgOp &linalgOp) {
auto targetShape = linalgOp.getStaticLoopRanges();
auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
if (inputShape.getShape().empty())
return VectorMemoryAccessKind::ScalarBroadcast;
if (linalgOp.hasDynamicShape())
return VectorMemoryAccessKind::Gather;
if ((llvm::count_if(targetShape,
[](int64_t dimSize) { return dimSize > 1; }) != 1) ||
targetShape.back() == 1)
return VectorMemoryAccessKind::Gather;
if (inputShape.getShape().back() == 1)
return VectorMemoryAccessKind::Gather;
bool leadingIdxsLoopInvariant = true;
auto indices = extractOp.getIndices();
auto leadIndices = indices.drop_back(1);
// Indices into unit source dims can only be 0, so they are skipped.
for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
if (inputShape.getShape()[i] == 1)
continue;
leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal);
}
if (!leadingIdxsLoopInvariant) {
LDBG("Found gather load: " << extractOp);
return VectorMemoryAccessKind::Gather;
}
auto extractOpTrailingIdx = indices.back();
// `leadingIdxsLoopInvariant` is always true here; the first operand of the
// && below is redundant.
if (leadingIdxsLoopInvariant &&
isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
LDBG("Found scalar broadcast load: " << extractOp);
return VectorMemoryAccessKind::ScalarBroadcast;
}
bool foundIndexOp = false;
bool isContiguousLoad =
isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
isContiguousLoad &= foundIndexOp;
if (isContiguousLoad) {
LDBG("Found contigous load: " << extractOp);
return VectorMemoryAccessKind::Contiguous;
}
LDBG("Found gather load: " << extractOp);
return VectorMemoryAccessKind::Gather;
}
/// Custom hook vectorizing a tensor.extract in the linalg body. The lowering
/// follows getTensorExtractMemoryAccessPattern:
/// - Gather: vector.gather at all-zero base indices, with per-lane offsets
///   from calculateGatherOffset, an all-true mask constant and a zero
///   pass-through. The gather is then masked via `state` if needed.
/// - ScalarBroadcast: vector.transfer_read whose permutation map yields
///   constant 0 for every result dim, i.e. a broadcast of one element.
/// - Contiguous: vector.transfer_read with a minor-identity permutation map.
///   Leading constant-0 results are added when the result rank exceeds the
///   source rank.
/// Returns Failure for ops that are not tensor.extract.
static VectorizationResult
vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
if (!extractOp)
return VectorizationResult{VectorizationStatus::Failure, nullptr};
auto loc = extractOp.getLoc();
auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
// NOTE(review): the mask, pass-through and base indices below are only used
// by the gather path, but they are created unconditionally.
auto maskConstantOp = rewriter.create<arith::ConstantOp>(
loc,
DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
true));
auto passThruConstantOp =
rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
SmallVector<Value> baseIndices(
extractOp.getIndices().size(),
rewriter.create<arith::ConstantIndexOp>(loc, 0));
VectorMemoryAccessKind memAccessKind =
getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
if (memAccessKind == VectorMemoryAccessKind::Gather) {
Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
Operation *gatherOp = rewriter.create<vector::GatherOp>(
loc, resultType, extractOp.getTensor(), baseIndices, offset,
maskConstantOp, passThruConstantOp);
gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
LDBG("Vectorised as gather load: " << extractOp << "\n");
return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
}
// Build scalar indices for the transfer_read. Scalar (index-typed) indices
// are used directly. Vectorized indices are shape-cast to 1-D and element 0
// is taken as the starting index, presumably valid because the access was
// classified as a broadcast or a contiguous run.
SmallVector<Value> transferReadIdxs;
auto resTrailingDim = resultType.getShape().back();
auto zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
auto idx = bvm.lookup(extractOp.getIndices()[i]);
if (idx.getType().isIndex()) {
transferReadIdxs.push_back(idx);
continue;
}
auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()),
bvm.lookup(extractOp.getIndices()[i]));
transferReadIdxs.push_back(
rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
}
auto dstRank = resultType.getRank();
auto srcRank = extractOp.getTensor().getType().getRank();
SmallVector<bool> inBounds(dstRank, true);
if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
// Every result dim maps to constant 0: read one element and broadcast it.
MLIRContext *ctx = rewriter.getContext();
SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
auto transferReadOp = rewriter.create<vector::TransferReadOp>(
loc, resultType, extractOp.getTensor(), transferReadIdxs,
permutationMap, inBounds);
LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
}
// Contiguous: minor identity over the trailing dims. Prepend constant-0
// results when the vector has more dims than the source tensor.
auto permutationMap = AffineMap::getMinorIdentityMap(
srcRank, std::min(dstRank, srcRank), rewriter.getContext());
int32_t rankDiff = dstRank - srcRank;
while (rankDiff > 0) {
permutationMap = permutationMap.insertResult(
mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
rankDiff--;
}
auto transferReadOp = rewriter.create<vector::TransferReadOp>(
loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
inBounds);
LDBG("Vectorised as contiguous load: " << extractOp);
return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
}
/// Emits a multi-dim reduction of the vectorized `reduceValue` into the
/// vectorized `initialValue` over the op's reduction loops.
///
/// Returns nullptr when no reduction is needed: the reduced value is not a
/// vector, or the output is a vector of the same shape.
static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
                                 Value reduceValue, Value initialValue,
                                 const IRMapping &bvm) {
  Value source = bvm.lookup(reduceValue);
  Value accumulator = bvm.lookup(initialValue);

  auto sourceVecType = dyn_cast<VectorType>(source.getType());
  if (!sourceVecType)
    return nullptr;
  auto accVecType = dyn_cast<VectorType>(accumulator.getType());
  if (accVecType && accVecType.getShape() == sourceVecType.getShape())
    return nullptr;

  SmallVector<bool> reductionDims = getDimsToReduce(linalgOp);
  return buildMultiDimReduce(b, op, source, accumulator, reductionDims);
}
/// Vectorizes one op from the body of `linalgOp`. `bvm` maps original values
/// to their vectorized counterparts.
///
/// The strategies are tried in this order:
/// 1. Each custom hook. The first result whose status is not Failure wins.
/// 2. arith / func constants are cloned unchanged.
/// 3. Ops without elementwise-mappable traits fail.
/// 4. Reduction: if an operand is a body argument of a DPS init whose region
///    matches a reduction, and the vectorized reduced value differs in shape
///    from the vectorized init, a vector.multi_reduction is emitted.
/// 5. Otherwise the op is re-created on vectors. Every operand is broadcast
///    (when possible) to the highest-rank vector type among the vectorized
///    operands, keeping its own element type; the first such type wins on
///    ties. Result types get that shape. If no operand is a vector, the op is
///    re-created with the mapped operands and the original result types.
static VectorizationResult
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
LDBG("vectorize op " << *op << "\n");
if (!customVectorizationHooks.empty()) {
for (auto &customFunc : customVectorizationHooks) {
VectorizationResult result = customFunc(op, bvm);
if (result.status == VectorizationStatus::Failure)
continue;
return result;
}
}
if (isa<arith::ConstantOp, func::ConstantOp>(op))
return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
if (!OpTrait::hasElementwiseMappableTraits(op))
return VectorizationResult{VectorizationStatus::Failure, nullptr};
// Collect (reduced value, init block argument) pairs for operands that are
// init block arguments participating in a matched reduction.
SmallVector<std::pair<Value, Value>> reductionOperands;
for (Value operand : op->getOperands()) {
auto blockArg = dyn_cast<BlockArgument>(operand);
if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
continue;
SmallVector<Operation *> reductionOps;
Value reduceValue = matchReduction(
linalgOp.getRegionOutputArgs(),
blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
if (!reduceValue)
continue;
reductionOperands.push_back(std::make_pair(reduceValue, operand));
}
if (!reductionOperands.empty()) {
assert(reductionOperands.size() == 1);
Operation *reduceOp =
reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
reductionOperands[0].second, bvm);
if (reduceOp)
return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
}
// Find the highest-rank vector type among the vectorized operands (the first
// one on ties).
VectorType firstMaxRankedType;
for (Value operand : op->getOperands()) {
auto vecOperand = bvm.lookup(operand);
assert(vecOperand && "Vector operand couldn't be found");
auto vecType = dyn_cast<VectorType>(vecOperand.getType());
if (vecType && (!firstMaxRankedType ||
firstMaxRankedType.getRank() < vecType.getRank()))
firstMaxRankedType = vecType;
}
// Broadcast every operand to that shape, keeping each operand's element type.
SmallVector<Value> vecOperands;
for (Value scalarOperand : op->getOperands()) {
Value vecOperand = bvm.lookup(scalarOperand);
assert(vecOperand && "Vector operand couldn't be found");
if (firstMaxRankedType) {
auto vecType = VectorType::get(firstMaxRankedType.getShape(),
getElementTypeOrSelf(vecOperand.getType()),
firstMaxRankedType.getScalableDims());
vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
} else {
vecOperands.push_back(vecOperand);
}
}
SmallVector<Type> resultTypes;
for (Type resultType : op->getResultTypes()) {
resultTypes.push_back(
firstMaxRankedType
? VectorType::get(firstMaxRankedType.getShape(), resultType,
firstMaxRankedType.getScalableDims())
: resultType);
}
// Re-create the same op (same name and attributes) on the vector operands.
return VectorizationResult{
VectorizationStatus::NewOp,
rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
resultTypes, op->getAttrs())};
}
/// Vectorizes `linalgOp` generically in three steps:
/// 1. Read every non-scalar operand with a vector.transfer_read.
/// 2. Vectorize the body ops one by one with vectorizeOneOp.
/// 3. Write the yielded values back through the linalg.yield hook, which
///    appends any produced results to `newResults`.
///
/// Fails if the op has no DPS inits or if any body op cannot be vectorized.
static LogicalResult
vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
LinalgOp linalgOp,
SmallVectorImpl<Value> &newResults) {
LDBG("Vectorizing operation as linalg generic\n");
Block *block = linalgOp.getBlock();
// Values defined above the region are used as-is (mapped to themselves).
IRMapping bvm;
SetVector<Value> valuesSet;
mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
if (linalgOp.getNumDpsInits() == 0)
return failure();
Location loc = linalgOp.getLoc();
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
// Read each operand that feeds a block argument.
for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
// Scalar operands are forwarded without a vector read.
if (linalgOp.isScalar(opOperand)) {
bvm.map(bbarg, opOperand->get());
continue;
}
// Constant results of the indexing map do not correspond to a loop dim, so
// they are dropped from the masking map.
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
SmallVector<int64_t> zeroPos;
auto results = indexingMap.getResults();
for (const auto &result : llvm::enumerate(results)) {
if (isa<AffineConstantExpr>(result.value())) {
zeroPos.push_back(result.index());
}
}
AffineMap maskingMap = indexingMap.dropResults(zeroPos);
AffineMap readMap;
VectorType readType;
Type elemType = getElementTypeOrSelf(opOperand->get());
// Inputs are read into the full canonical vector type through a (possibly
// broadcasting) inverse of their indexing map. Inits are read through the
// inverse of the reindexed map into the correspondingly permuted canonical
// type.
if (linalgOp.isDpsInput(opOperand)) {
readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
readType = state.getCanonicalVecType(elemType);
} else {
readMap = inversePermutation(reindexIndexingMap(indexingMap));
readType =
state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
}
SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
// Broadcast dims are in-bounds; all other dims start as out-of-bounds.
SmallVector<unsigned> broadcastedDims = readMap.getBroadcastDims();
SmallVector<bool> inBounds(readType.getRank(), false);
for (auto idx : broadcastedDims)
inBounds[idx] = true;
Operation *read = rewriter.create<vector::TransferReadOp>(
loc, readType, opOperand->get(), indices, readMap,
ArrayRef<bool>(inBounds));
read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
Value readValue = read->getResult(0);
// A masked read is marked fully in-bounds, presumably because the mask
// covers the out-of-bounds lanes. Note: this `inBounds` shadows the outer
// variable of the same name.
if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
SmallVector<bool> inBounds(readType.getRank(), true);
cast<vector::TransferReadOp>(maskOp.getMaskableOp())
.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
}
// 0-D reads are unwrapped to a scalar.
if (readType.getRank() == 0)
readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
<< "\n");
bvm.map(bbarg, readValue);
bvm.map(opOperand->get(), readValue);
}
// Custom hooks for ops needing special handling: linalg.yield, linalg.index
// and tensor.extract.
SmallVector<CustomVectorizationHook> hooks;
CustomVectorizationHook vectorizeYield =
[&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
};
hooks.push_back(vectorizeYield);
CustomVectorizationHook vectorizeIndex =
[&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
};
hooks.push_back(vectorizeIndex);
CustomVectorizationHook vectorizeExtract =
[&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
};
hooks.push_back(vectorizeExtract);
// Vectorize the body in order. NewOp results are masked if needed and
// recorded in `bvm` for later users.
for (Operation &op : block->getOperations()) {
VectorizationResult result =
vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
if (result.status == VectorizationStatus::Failure) {
LDBG("failed to vectorize: " << op << "\n");
return failure();
}
if (result.status == VectorizationStatus::NewOp) {
Operation *maybeMaskedOp =
state.maskOperation(rewriter, result.newOp, linalgOp);
LDBG("New vector op: " << *maybeMaskedOp << "\n");
bvm.map(op.getResults(), maybeMaskedOp->getResults());
}
}
return success();
}
/// Permutes `destShape` by `tensor::getPackInverseDestPerm(packOp)`.
/// `vectorizeAsTensorPackOp` uses the result as the intermediate
/// `vector.shape_cast` type before transposing into the destination layout.
static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
                                              ArrayRef<int64_t> destShape) {
  SmallVector<int64_t> inverseDestPerm = tensor::getPackInverseDestPerm(packOp);
  return applyPermutation(destShape, inverseDestPerm);
}
/// Creates a `vector.transfer_write` of `input` into a freshly created
/// `tensor.empty`. The empty tensor has sizes `destSizes` and the element type
/// of `input`. All write indices are zero.
///
/// * If `useInBoundsInsteadOfMasking` is true, no mask is created. Instead, the
///   `in_bounds` flag of destination dim `i` is set only when that dim is
///   static and equal to `inputVectorSizes[i]`. NOTE(review): this path reads
///   `inputVectorSizes[i]` for every destination dim, so it assumes
///   `inputVectorSizes.size() >= rank`. That is not checked here.
/// * Otherwise every dim is marked in-bounds. If `inputVectorSizes` does not
///   exactly match the leading destination dims, the write is wrapped in a
///   mask produced by `vector.create_mask` over `destSizes`. The mask shape is
///   `inputVectorSizes` followed by the remaining (trailing) destination dims.
///
/// Destination dims beyond `inputVectorSizes.size()` must be static
/// (asserted). Returns the write op, or the masking op wrapping it.
static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
Value input,
SmallVector<OpFoldResult> destSizes,
ArrayRef<int64_t> inputVectorSizes,
bool useInBoundsInsteadOfMasking) {
auto inputType = cast<VectorType>(input.getType());
Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
inputType.getElementType());
int64_t rank = cast<ShapedType>(dest.getType()).getRank();
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
auto destShape = cast<ShapedType>(dest.getType()).getShape();
// Default: every dim in-bounds. Only the no-masking path refines this.
SmallVector<bool> inBoundsVal(rank, true);
if (useInBoundsInsteadOfMasking) {
// Without a mask, a dim may only be declared in-bounds when it is statically
// known that the vector covers it exactly.
for (unsigned i = 0; i < rank; i++)
inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
!ShapedType::isDynamic(destShape[i]);
}
Operation *write = builder.create<vector::TransferWriteOp>(
loc,
input,
dest,
SmallVector<Value>(rank, zero),
inBoundsVal);
assert(llvm::none_of(
destShape.drop_front(inputVectorSizes.size()),
[](int64_t size) { return size == ShapedType::kDynamic; }) &&
"Only dims aligned with inputVectorSizes may be dynamic");
if (useInBoundsInsteadOfMasking)
return write;
// A mask is only needed when the vector sizes differ from the leading
// destination dims. A dynamic dim (kDynamic) never compares equal, so it
// always forces a mask.
bool needMaskForWrite = !llvm::equal(
inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
if (needMaskForWrite) {
// Mask shape: the vector sizes for the leading dims, then the (static)
// trailing destination dims.
SmallVector<int64_t> writeMaskShape;
writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
destShape.end());
auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
Value maskForWrite =
builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
write = mlir::vector::maskOperation(builder, write, maskForWrite);
}
return write;
}
/// Vectorizes a `tensor.pack` op as the following chain:
///   (possibly masked) transfer_read of the source
///   -> vector.shape_cast to the "tiled" pack shape (see getTiledPackShape)
///   -> vector.transpose into the destination layout
///   -> (possibly masked) transfer_write into a new tensor.empty
///
/// The padding value is the pack op's padding value if present; otherwise it
/// is a zero constant of the source element type.
///
/// `inputVectorSizes` are the vector sizes of the leading (outer) destination
/// dims. If they are empty, they are taken from the destination shape via
/// `take_front(sourceRank)`, and the read uses in_bounds instead of masking.
/// The write is always created with `useInBoundsInsteadOfMasking = false`.
/// The single result of the write is appended to `newResults`.
static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
ArrayRef<int64_t> inputVectorSizes,
SmallVectorImpl<Value> &newResults) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(packOp);
Location loc = packOp.getLoc();
auto padValue = packOp.getPaddingValue();
if (!padValue) {
padValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
}
// Reify the result shape; it provides the (possibly dynamic) sizes of the
// write destination.
ReifiedRankedShapedTypeDims reifiedReturnShapes;
LogicalResult status =
cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
.reifyResultShapes(rewriter, reifiedReturnShapes);
(void)status;
assert(succeeded(status) && "failed to reify result shapes");
bool useInBoundsInsteadOfMasking = false;
if (inputVectorSizes.empty()) {
// No user sizes: use the outer destination dims directly.
// NOTE(review): assumes a static destination shape in this case.
ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
useInBoundsInsteadOfMasking = true;
}
// Read shape, in source-dim order: undo outer_dims_perm, then scale each
// tiled dim by its inner tile size.
SmallVector<int64_t> inputShape(inputVectorSizes);
auto innerTiles = packOp.getStaticInnerTiles();
auto innerDimsPos = packOp.getInnerDimsPos();
auto outerDimsPerm = packOp.getOuterDimsPerm();
if (!outerDimsPerm.empty())
applyPermutationToVector(inputShape,
invertPermutationVector(outerDimsPerm));
for (auto [idx, size] : enumerate(innerTiles))
inputShape[innerDimsPos[idx]] *= size;
auto maskedRead = vector::createReadOrMaskedRead(
rewriter, loc, packOp.getSource(), inputShape, padValue,
useInBoundsInsteadOfMasking);
// Destination vector shape is the outer sizes followed by the inner tiles.
// It is first materialized in the tiled order and then transposed into the
// destination order.
SmallVector<int64_t> destShape(inputVectorSizes);
destShape.append(innerTiles.begin(), innerTiles.end());
auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
packOp.getDestType().getElementType());
auto shapeCastOp =
rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
auto destPermutation =
invertPermutationVector(tensor::getPackInverseDestPerm(packOp));
auto transposeOp = rewriter.create<vector::TransposeOp>(
loc, shapeCastOp.getResult(), destPermutation);
Operation *write = createWriteOrMaskedWrite(
rewriter, loc, transposeOp.getResult(), reifiedReturnShapes[0],
inputVectorSizes, false);
newResults.push_back(write->getResult(0));
return success();
}
/// Vectorizes a `tensor.unpack` op as the following chain:
///   masked transfer_read of the source (padded with zero)
///   -> vector.transpose by `tensor::getUnPackInverseSrcPerm`
///   -> vector.shape_cast to the shape collapsed by the pack metadata's
///      reassociations
///   -> transfer_write into a new tensor.empty of the reified result shape
///
/// If `inputVectorSizes` is non-empty, it must have one entry per destination
/// dim (asserted). If it is empty, vector sizes are derived from the source
/// shape: the leading `destRank` source dims, permuted by `outer_dims_perm`,
/// with each tiled dim multiplied by its inner tile. In that case the final
/// write uses in_bounds instead of masking. The write result is appended to
/// `newResults`. Fails only if the result shape cannot be reified.
static LogicalResult
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
ArrayRef<int64_t> inputVectorSizes,
SmallVectorImpl<Value> &newResults) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unpackOp);
RankedTensorType unpackTensorType = unpackOp.getSourceType();
ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
bool useInBoundsInsteadOfMasking = false;
ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
auto destSize = unpackOp.getDestRank();
if (!inputVectorSizes.empty())
assert(inputVectorSizes.size() == destSize &&
"Incorrect number of input vector sizes");
SmallVector<int64_t> vectorSizes(inputVectorSizes);
if (vectorSizes.empty()) {
// Derive destination-shaped vector sizes from the static source shape.
llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
if (!outerDimsPerm.empty())
applyPermutationToVector(vectorSizes, outerDimsPerm);
for (auto [i, pos] : llvm::enumerate(innerDimPos))
vectorSizes[pos] *= innerTiles[i];
useInBoundsInsteadOfMasking = true;
}
// Read sizes: divide each tiled dim by its tile size (rounding up), re-apply
// outer_dims_perm, then append the trailing (inner-tile) dims of the source.
SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
for (auto [index, size] : enumerate(innerTiles)) {
readVectorSizes[innerDimPos[index]] =
llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
}
if (!outerDimsPerm.empty()) {
applyPermutationToVector(readVectorSizes, outerDimsPerm);
}
readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
sourceShape.end());
ReifiedRankedShapedTypeDims reifiedRetShapes;
LogicalResult status =
cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
.reifyResultShapes(rewriter, reifiedRetShapes);
if (status.failed()) {
LDBG("Unable to reify result shapes of " << unpackOp);
return failure();
}
Location loc = unpackOp->getLoc();
// Padding for the masked read: zero of the source element type.
auto padValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
Value readResult = vector::createReadOrMaskedRead(
rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
false);
// Transpose the read vector by the permutation returned from
// getUnPackInverseSrcPerm (which also fills packMetadata). The collapsed
// type is obtained by collapsing the permuted shape with
// packMetadata.reassociations.
PackingMetadata packMetadata;
SmallVector<int64_t> lastDimToInsertPosPerm =
tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata);
ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
RankedTensorType stripMineTensorType =
RankedTensorType::get(stripMineShape, stripMineElemType);
vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
loc, readResult, lastDimToInsertPosPerm);
RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
stripMineTensorType, packMetadata.reassociations);
mlir::VectorType vecCollapsedType =
VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
loc, vecCollapsedType, transposeOp->getResult(0));
// For a static destination, write with the destination-shaped vectorSizes.
// Otherwise use the shape of the shape_cast result.
SmallVector<int64_t> writeVectorSizes(
unpackOp.getDestType().hasStaticShape()
? vectorSizes
: shapeCastOp.getResultVectorType().getShape());
Operation *write = createWriteOrMaskedWrite(
rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0],
writeVectorSizes, useInBoundsInsteadOfMasking);
newResults.push_back(write->getResult(0));
return success();
}
/// Vectorizes a `tensor.pad` op in two steps:
///  1. A (possibly masked) transfer_read of the pad source with vector shape
///     `inputVectorSizes`, using `padOp.getConstantPaddingValue()` as padding.
///  2. A (possibly masked) transfer_write of that vector into a new
///     `tensor.empty` sized by the pad op's reified result shape.
/// The write result is appended to `newResults`. Always succeeds; result
/// shape reification is asserted to succeed.
static LogicalResult
vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
                       ArrayRef<int64_t> inputVectorSizes,
                       SmallVectorImpl<Value> &newResults) {
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(padOp);
  const Location padLoc = padOp.getLoc();
  Value constPad = padOp.getConstantPaddingValue();

  // Reify the (possibly dynamic) result sizes for the write destination.
  ReifiedRankedShapedTypeDims resultDims;
  auto reifiable =
      cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation());
  LogicalResult reified = reifiable.reifyResultShapes(rewriter, resultDims);
  (void)reified;
  assert(succeeded(reified) && "failed to reify result shapes");

  Value readVec = vector::createReadOrMaskedRead(
      rewriter, padLoc, padOp.getSource(), inputVectorSizes, constPad,
      /*useInBoundsInsteadOfMasking=*/false);
  Operation *writeOp = createWriteOrMaskedWrite(
      rewriter, padLoc, readVec, resultDims[0], inputVectorSizes,
      /*useInBoundsInsteadOfMasking=*/false);
  newResults.push_back(writeOp->getResult(0));
  return success();
}
/// Checks that `op` can be vectorized as a reduction.
///  * At least one iterator must be a reduction iterator.
///  * Every init operand whose indexing map is not a permutation must be
///    matched by `matchLinalgReduction` to a combiner op, and that combiner's
///    kind must be recognized by `getCombinerOpKind`.
static LogicalResult reductionPreconditions(LinalgOp op) {
  const bool hasReduction =
      llvm::any_of(op.getIteratorTypesArray(), isReductionIterator);
  if (!hasReduction) {
    LDBG("reduction precondition failed: no reduction iterator\n");
    return failure();
  }
  for (OpOperand &init : op.getDpsInitsMutable()) {
    // Inits whose indexing map is a permutation are not checked.
    if (op.getMatchingIndexingMap(&init).isPermutation())
      continue;
    Operation *combiner = matchLinalgReduction(&init);
    const bool recognized = combiner && getCombinerOpKind(combiner);
    if (!recognized) {
      LDBG("reduction precondition failed: reduction detection failed\n");
      return failure();
    }
  }
  return success();
}
/// Precondition for vectorizing a convolution with dynamic shapes.
/// Only non-flattened `linalg.depthwise_conv_1d_nwc_wc` ops are accepted, and
/// only the trailing (channel) dimension of the first input operand may be
/// dynamic.
static LogicalResult
vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
                                   bool flatten1DDepthwiseConv) {
  if (flatten1DDepthwiseConv) {
    LDBG("Vectorization of flattened convs with dynamic shapes is not "
         "supported\n");
    return failure();
  }

  const bool isDepthwise1D = isa<linalg::DepthwiseConv1DNwcWcOp>(conv);
  if (!isDepthwise1D) {
    LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
    return failure();
  }

  auto lhsType =
      cast<ShapedType>(conv.getDpsInputOperand(0)->get().getType());
  // Every dim except the last (channel) one must be static.
  if (ShapedType::isDynamicShape(lhsType.getShape().drop_back(1))) {
    LDBG("Dynamically-shaped op vectorization precondition failed: only "
         "channel dim can be dynamic\n");
    return failure();
  }
  return success();
}
/// Precondition for vectorizing a LinalgOp with dynamic shapes.
/// Convolutions are delegated to `vectorizeDynamicConvOpPrecondition`. Any
/// other op must be elementwise, a `linalg.generic`, a `linalg.copy`, or
/// implement `ContractionOpInterface`.
static LogicalResult
vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
                                     bool flatten1DDepthwiseConv) {
  Operation *rawOp = op.getOperation();
  if (isa<ConvolutionOpInterface>(rawOp))
    return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);

  const bool supported =
      isElementwise(op) ||
      isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
          rawOp);
  if (!supported)
    return failure();

  LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
  return success();
}
/// Precondition for vectorizing `tensor.unpack`.
///  * Every entry of `getInnerTiles()` must fold to a constant integer.
///  * `inputVectorSizes` may be empty only when both the source and the
///    destination types are fully static.
///  * Otherwise `inputVectorSizes` is validated against the destination shape
///    with `vector::isValidMaskedInputVector`.
static LogicalResult
vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
                              ArrayRef<int64_t> inputVectorSizes) {
  const bool allTilesConstant =
      llvm::all_of(unpackOp.getInnerTiles(), [](OpFoldResult tile) {
        return getConstantIntValue(tile).has_value();
      });
  if (!allTilesConstant) {
    LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
    return failure();
  }

  // With no explicit sizes, they can be inferred from fully static types.
  const bool canInferSizes = inputVectorSizes.empty() &&
                             unpackOp.getDestType().hasStaticShape() &&
                             unpackOp.getSourceType().hasStaticShape();
  if (canInferSizes)
    return success();

  ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
  return vector::isValidMaskedInputVector(destShape, inputVectorSizes);
}
/// Precondition for vectorizing a LinalgOp. Fails if any of these hold:
///  * any dim in `linalgOp.getStaticShape()` is 0;
///  * `inputVectorSizes` is non-empty and not valid for the static loop
///    ranges (`vector::isValidMaskedInputVector`);
///  * the op has dynamic shapes and fails
///    `vectorizeDynamicLinalgOpPrecondition`;
///  * some body op has an operand or result type that is not a valid vector
///    element type. Body ops accepted by a custom precondition are exempt;
///    currently the only one is `tensorExtractVectorizationPrecondition`.
/// Elementwise ops and convolutions then succeed. All other ops must also
/// have only projected-permutation indexing maps and satisfy
/// `reductionPreconditions`, which requires at least one reduction iterator.
static LogicalResult vectorizeLinalgOpPrecondition(
LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
if (llvm::is_contained(linalgOp.getStaticShape(), 0))
return failure();
if (!inputVectorSizes.empty() &&
failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
inputVectorSizes)))
return failure();
if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
linalgOp, flatten1DDepthwiseConv))) {
LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
return failure();
}
SmallVector<CustomVectorizationPrecondition> customPreconditions;
customPreconditions.push_back(tensorExtractVectorizationPrecondition);
for (Operation &innerOp : linalgOp->getRegion(0).front()) {
// Ops accepted by a custom precondition skip the generic type checks below.
if (llvm::any_of(
customPreconditions,
[&](const CustomVectorizationPrecondition &customPrecondition) {
return succeeded(
customPrecondition(&innerOp, vectorizeNDExtract));
})) {
continue;
}
if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
return !VectorType::isValidElementType(type);
})) {
return failure();
}
if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
return !VectorType::isValidElementType(type);
})) {
return failure();
}
}
if (isElementwise(linalgOp))
return success();
if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
return success();
if (!allIndexingsAreProjectedPermutation(linalgOp)) {
LDBG("precondition failed: not projected permutations\n");
return failure();
}
if (failed(reductionPreconditions(linalgOp))) {
LDBG("precondition failed: reduction preconditions\n");
return failure();
}
return success();
}
/// Precondition for vectorizing `tensor.pack`. Requires:
///  * the padding value, if present, to be a constant;
///  * `inputVectorSizes` to be valid (`vector::isValidMaskedInputVector`) for
///    the leading `sourceRank` (outer) dims of the destination shape. Empty
///    `inputVectorSizes` skip this check only when both the source and the
///    destination types are fully static, because the sizes are then taken
///    from the static destination shape;
///  * every entry of `getInnerTiles()` to fold to a constant integer.
static LogicalResult
vectorizePackOpPrecondition(tensor::PackOp packOp,
                            ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = packOp.getPaddingValue();
  Attribute cstAttr;
  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
    LDBG("pad value is not constant: " << packOp << "\n");
    return failure();
  }

  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
  // Mirrors vectorizeUnPackOpPrecondition: validation may be skipped only when
  // no sizes were given AND they can be inferred from static types. Sizes
  // supplied by the caller must always be validated. Otherwise a wrong-length
  // or too-small size list would reach vectorizeAsTensorPackOp unchecked.
  const bool satisfyEmptyCond = inputVectorSizes.empty() &&
                                packOp.getDestType().hasStaticShape() &&
                                packOp.getSourceType().hasStaticShape();
  if (!satisfyEmptyCond &&
      failed(vector::isValidMaskedInputVector(
          resultTensorShape.take_front(packOp.getSourceRank()),
          inputVectorSizes)))
    return failure();

  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
        return !getConstantIntValue(v).has_value();
      })) {
    LDBG("inner_tiles must be constant: " << packOp << "\n");
    return failure();
  }
  return success();
}
/// Precondition for vectorizing `tensor.pad`. Requires:
///  * a constant padding value (used as the padding of the vectorized read);
///  * `inputVectorSizes` to be valid for the result shape
///    (`vector::isValidMaskedInputVector`);
///  * every low padding amount, static or dynamic, to be the constant 0. The
///    vectorized write into the result uses all-zero indices
///    (createWriteOrMaskedWrite), so non-zero low padding cannot be honored.
static LogicalResult
vectorizePadOpPrecondition(tensor::PadOp padOp,
                           ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = padOp.getConstantPaddingValue();
  if (!padValue) {
    LDBG("pad value is not constant: " << padOp << "\n");
    return failure();
  }

  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
                                              inputVectorSizes)))
    return failure();

  // `getLow()` yields only the *dynamic* low-pad operands; static low pads
  // live in the `static_low` attribute. `hasZeroLowPad()` inspects the mixed
  // (static + dynamic) list, so a static non-zero low pad is rejected as well.
  if (!padOp.hasZeroLowPad()) {
    LDBG("low pad must all be zero: " << padOp << "\n");
    return failure();
  }

  return success();
}
/// Precondition for scalable vectorization. Succeeds trivially when no entry
/// of `inputScalableVecDims` is set. Otherwise all of these must hold:
///  1. `op` is a LinalgOp;
///  2. at most 2 dims are scalable;
///  3. the innermost scalable dim is a parallel iterator, and no parallel
///     iterator follows it (it is the last parallel dim);
///  4. if 2 dims are scalable, they are adjacent and both parallel;
///  5. the op is elementwise, `linalg.matmul`, `linalg.matmul_transpose_a` or
///     `linalg.depthwise_conv_1d_nwc_wc`.
/// `inputVectorSizes` and `inputScalableVecDims` must have equal length
/// (asserted).
static LogicalResult
vectorizeScalableVectorPrecondition(Operation *op,
ArrayRef<int64_t> inputVectorSizes,
ArrayRef<bool> inputScalableVecDims) {
assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
"Number of input vector sizes and scalable dims doesn't match");
size_t numOfScalableDims =
llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
if (numOfScalableDims == 0)
return success();
auto linalgOp = dyn_cast<LinalgOp>(op);
if (!linalgOp)
return failure();
if (numOfScalableDims > 2)
return failure();
bool seenParalell = false;
auto iterators = linalgOp.getIteratorTypesArray();
SmallVector<bool> scalableFlags(inputScalableVecDims);
// Walk from the innermost dim outwards until the innermost scalable dim is
// at the back, remembering whether any parallel iterator was skipped. The
// loop terminates because at least one flag is set. NOTE(review): assumes
// `iterators` has at least as many entries as `inputScalableVecDims`.
while (!scalableFlags.back()) {
seenParalell |= (iterators.back() == utils::IteratorType::parallel);
iterators.pop_back();
scalableFlags.pop_back();
}
// The innermost scalable dim must not be a reduction dim...
if (iterators.back() == utils::IteratorType::reduction)
return failure();
// ...and it must be the last parallel dim.
if (seenParalell)
return failure();
if (numOfScalableDims == 2) {
// The second scalable dim must be immediately outside the first one and
// also be parallel.
scalableFlags.pop_back();
iterators.pop_back();
if (!scalableFlags.back() ||
(iterators.back() != utils::IteratorType::parallel))
return failure();
}
return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
isa<linalg::MatmulTransposeAOp>(op) ||
isa<linalg::DepthwiseConv1DNwcWcOp>(op));
}
/// Returns success if `op` can be vectorized with the given vector sizes and
/// scalable flags. Scalable-vector restrictions are checked first for every
/// op kind. The op-specific precondition is then dispatched for LinalgOps,
/// `tensor.pad`, `tensor.pack` and `tensor.unpack`. Any other op fails.
LogicalResult mlir::linalg::vectorizeOpPrecondition(
    Operation *op, ArrayRef<int64_t> inputVectorSizes,
    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
    bool flatten1DDepthwiseConv) {
  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
                                                 inputScalableVecDims)))
    return failure();

  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op))
    return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
                                         vectorizeNDExtract,
                                         flatten1DDepthwiseConv);
  if (auto padOp = dyn_cast<tensor::PadOp>(op))
    return vectorizePadOpPrecondition(padOp, inputVectorSizes);
  if (auto packOp = dyn_cast<tensor::PackOp>(op))
    return vectorizePackOpPrecondition(packOp, inputVectorSizes);
  if (auto unpackOp = dyn_cast<tensor::UnPackOp>(op))
    return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);

  // Unsupported op kind.
  return failure();
}
/// Expands every `affine.apply` op in the body of `linalgOp` into equivalent
/// arithmetic ops using `affine::expandAffineExpr`, then replaces the apply op
/// with the expanded value. The rewriter's insertion point is restored on
/// exit.
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
  OpBuilder::InsertionGuard guard(rewriter);
  Block *body = linalgOp.getBlock();
  // Early-increment iteration, because each visited op is replaced (erased).
  for (affine::AffineApplyOp applyOp :
       llvm::make_early_inc_range(body->getOps<affine::AffineApplyOp>())) {
    AffineMap map = applyOp.getAffineMap();
    auto operands = applyOp.getOperands();
    rewriter.setInsertionPoint(applyOp);
    Value expanded = affine::expandAffineExpr(
        rewriter, applyOp->getLoc(), map.getResult(0),
        operands.take_front(map.getNumDims()),
        operands.take_back(map.getNumSymbols()));
    rewriter.replaceOp(applyOp, expanded);
  }
}
/// Entry point for vectorizing `op`: a LinalgOp, `tensor.pad`, `tensor.pack`
/// or `tensor.unpack`.
///  * `vectorizeOpPrecondition` is checked first.
///  * For LinalgOps, a VectorizationState is initialized from
///    `inputVectorSizes` and `inputScalableVecDims`.
///  * Convolutions are vectorized with `vectorizeConvolution`. Other LinalgOps
///    first have their `affine.apply` ops expanded and are then vectorized
///    with `vectorizeAsLinalgGeneric`.
///  * pad/pack/unpack use their dedicated vectorizeAsTensor* helpers.
/// On success `op` is replaced with the produced results, or erased if there
/// are none. On failure `op` is not replaced.
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes,
ArrayRef<bool> inputScalableVecDims,
bool vectorizeNDExtract,
bool flatten1DDepthwiseConv) {
LDBG("Attempting to vectorize:\n" << *op << "\n");
LDBG("Input vector sizes: ");
LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << "\n");
LDBG("Input scalable vector dims: ");
LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << "\n");
if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
vectorizeNDExtract,
flatten1DDepthwiseConv))) {
LDBG("Vectorization pre-conditions failed\n");
return failure();
}
// The vectorization state is only initialized (and only used) for LinalgOps.
VectorizationState state(rewriter);
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
inputScalableVecDims))) {
LDBG("Vectorization state couldn't be initialized\n");
return failure();
}
}
SmallVector<Value> results;
auto vectorizeResult =
TypeSwitch<Operation *, LogicalResult>(op)
.Case<linalg::LinalgOp>([&](auto linalgOp) {
if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
FailureOr<Operation *> convOr = vectorizeConvolution(
rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
flatten1DDepthwiseConv);
if (succeeded(convOr)) {
llvm::append_range(results, (*convOr)->getResults());
return success();
}
LDBG("Unsupported convolution can't be vectorized.\n");
return failure();
}
LDBG("Vectorize generic by broadcasting to the canonical vector "
"shape\n");
// Expand affine.apply ops in the body before generic vectorization.
convertAffineApply(rewriter, linalgOp);
return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
})
.Case<tensor::PadOp>([&](auto padOp) {
return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
results);
})
.Case<tensor::PackOp>([&](auto packOp) {
return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
results);
})
.Case<tensor::UnPackOp>([&](auto unpackOp) {
return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
inputVectorSizes, results);
})
.Default([](auto) { return failure(); });
if (failed(vectorizeResult)) {
LDBG("Vectorization failed\n");
return failure();
}
// Ops without results (e.g. on buffers) are simply erased.
if (!results.empty())
rewriter.replaceOp(op, results);
else
rewriter.eraseOp(op);
return success();
}
/// Vectorizes a statically-shaped `memref.copy` into two ops:
///  * a full-size `vector.transfer_read` from the source, and
///  * a `vector.transfer_write` into the target.
/// Both use zero indices and the identity map of the source rank. A rank-0
/// read is extracted to a scalar and broadcast to the write vector type.
/// Fails if either memref is dynamically shaped, or if either element type is
/// not a valid vector element type.
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
                                          memref::CopyOp copyOp) {
  auto sourceMemref = cast<MemRefType>(copyOp.getSource().getType());
  auto targetMemref = cast<MemRefType>(copyOp.getTarget().getType());
  const bool bothStatic =
      sourceMemref.hasStaticShape() && targetMemref.hasStaticShape();
  if (!bothStatic)
    return failure();

  Type sourceElt = getElementTypeOrSelf(sourceMemref);
  Type targetElt = getElementTypeOrSelf(targetMemref);
  if (!(VectorType::isValidElementType(sourceElt) &&
        VectorType::isValidElementType(targetElt)))
    return failure();

  VectorType readVecType = VectorType::get(sourceMemref.getShape(), sourceElt);
  VectorType writeVecType =
      VectorType::get(targetMemref.getShape(), targetElt);

  const Location copyLoc = copyOp->getLoc();
  Value c0 = rewriter.create<arith::ConstantIndexOp>(copyLoc, 0);
  const int64_t srcRank = sourceMemref.getRank();
  SmallVector<Value> zeroIndices(srcRank, c0);

  Value vec = rewriter.create<vector::TransferReadOp>(
      copyLoc, readVecType, copyOp.getSource(), zeroIndices,
      rewriter.getMultiDimIdentityMap(srcRank));
  // 0-d case: extract the scalar and broadcast it to the write vector type.
  if (cast<VectorType>(vec.getType()).getRank() == 0) {
    Value scalar = rewriter.create<vector::ExtractElementOp>(copyLoc, vec);
    vec = rewriter.create<vector::BroadcastOp>(copyLoc, writeVecType, scalar);
  }
  Operation *write = rewriter.create<vector::TransferWriteOp>(
      copyLoc, vec, copyOp.getTarget(), zeroIndices,
      rewriter.getMultiDimIdentityMap(srcRank));
  rewriter.replaceOp(copyOp, write->getResults());
  return success();
}
/// Returns the integer value held by `attr`. `attr` must be an `IntegerAttr`;
/// the `cast` asserts otherwise.
static int64_t getIntFromAttr(Attribute attr) {
  auto intAttr = cast<IntegerAttr>(attr);
  return intAttr.getInt();
}
/// Converts each OpFoldResult in `ofrs` to a Value, preserving order.
///  * SSA values are forwarded unchanged.
///  * Attributes are expected to be `IntegerAttr`s and are materialized as
///    `arith.constant` index ops at `loc`.
static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
                                           ArrayRef<OpFoldResult> ofrs) {
  SmallVector<Value> values;
  values.reserve(ofrs.size());
  for (OpFoldResult ofr : ofrs) {
    Value v = llvm::dyn_cast_if_present<Value>(ofr);
    if (!v) {
      int64_t cst = getIntFromAttr(ofr.get<Attribute>());
      v = rewriter.create<arith::ConstantIndexOp>(loc, cst);
    }
    values.push_back(v);
  }
  return values;
}
/// Generalizes `tensor.pad` through GeneralizePadOpPattern, passing
/// `tryVectorizeCopy` as the routine that copies the pad source into the
/// generalized destination.
struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
GenericPadOpVectorizationPattern(MLIRContext *context,
PatternBenefit benefit = 1)
: GeneralizePadOpPattern(context, tryVectorizeCopy, benefit) {}
/// Copies the source of `padOp` into `dest` with a transfer_read /
/// transfer_write pair, replacing `padOp` with the write.
///
/// Fails if:
///  * the element type is not a valid vector element type;
///  * the pad value is non-constant and the source shape is dynamic;
///  * some dim is dynamic in both the source and the result.
///
/// Per dim:
///  * A static source dim gives a vector size equal to the source size, with
///    in-bounds read and write.
///  * A dynamic source dim with a static result dim gives a vector size equal
///    to the result size. The read may be out-of-bounds, and the write is
///    in-bounds only if that dim's low pad is the constant 0.
///
/// The write goes to `dest` at the low-pad offsets.
static LogicalResult tryVectorizeCopy(RewriterBase &rewriter,
tensor::PadOp padOp, Value dest) {
auto sourceType = padOp.getSourceType();
auto resultType = padOp.getResultType();
if (!VectorType::isValidElementType(sourceType.getElementType()))
return failure();
auto padValue = padOp.getConstantPaddingValue();
if (!padValue) {
// With a static source every read below is in-bounds, so the padding value
// is never used and a dummy zero suffices. With a dynamic source the read
// needs a real (constant) padding value.
if (!sourceType.hasStaticShape())
return failure();
auto elemType = sourceType.getElementType();
padValue = rewriter.create<arith::ConstantOp>(
padOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
}
SmallVector<int64_t> vecShape;
SmallVector<bool> readInBounds;
SmallVector<bool> writeInBounds;
for (unsigned i = 0; i < sourceType.getRank(); ++i) {
if (!sourceType.isDynamicDim(i)) {
vecShape.push_back(sourceType.getDimSize(i));
readInBounds.push_back(true);
writeInBounds.push_back(true);
} else if (!resultType.isDynamicDim(i)) {
// The result size may exceed the source size, so the read may be
// out-of-bounds. The write can exceed `dest` only if there is low padding.
vecShape.push_back(resultType.getDimSize(i));
readInBounds.push_back(false);
writeInBounds.push_back(
getConstantIntValue(padOp.getMixedLowPad()[i]) ==
static_cast<int64_t>(0));
} else {
return failure();
}
}
auto vecType = VectorType::get(vecShape, sourceType.getElementType());
SmallVector<Value> readIndices(
vecType.getRank(),
rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
auto read = rewriter.create<vector::TransferReadOp>(
padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue,
ArrayRef<bool>{readInBounds});
// If the write covers the whole result and is fully in-bounds, the fill
// producing `dest` is fully overwritten, so write into the fill's output
// operand directly.
if (llvm::equal(vecShape, resultType.getShape()) &&
llvm::all_of(writeInBounds, [](bool b) { return b; }))
if (auto fill = dest.getDefiningOp<FillOp>())
dest = fill.output();
auto writeIndices =
ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});
return success();
}
};
/// Base class for patterns that rewrite the users of a `tensor.pad` op that
/// have type `OpTy`. `matchAndRewrite` calls `rewriteUser` on every such user
/// and succeeds if at least one user was rewritten.
template <typename OpTy>
struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    // Snapshot the users first: rewriting a user can drop its use of padOp
    // and thereby change the use list while it is being traversed.
    SmallVector<Operation *, 4> users = llvm::to_vector<4>(padOp->getUsers());
    bool anyRewritten = false;
    for (Operation *user : users) {
      auto typedUser = dyn_cast<OpTy>(user);
      if (!typedUser)
        continue;
      if (succeeded(rewriteUser(rewriter, padOp, typedUser)))
        anyRewritten = true;
    }
    return success(anyRewritten);
  }

protected:
  /// Attempts to rewrite `op`, a user of `padOp`. Returns failure if `op`
  /// does not match.
  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
                                    tensor::PadOp padOp, OpTy op) const = 0;
};
/// Folds a `tensor.pad` into a `vector.transfer_read` that reads its result.
/// The read is rewritten in place to read the pad source directly, using the
/// pad's constant padding value. Requires zero low padding, a constant
/// padding value, and an original read that is fully in-bounds and unmasked.
/// In that case the original read's padding value is unused, so it can be
/// replaced.
struct PadOpVectorizationWithTransferReadPattern
: public VectorizePadOpUserPattern<vector::TransferReadOp> {
using VectorizePadOpUserPattern<
vector::TransferReadOp>::VectorizePadOpUserPattern;
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
vector::TransferReadOp xferOp) const override {
if (!padOp.hasZeroLowPad())
return failure();
auto padValue = padOp.getConstantPaddingValue();
if (!padValue)
return failure();
if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
return failure();
rewriter.modifyOpInPlace(xferOp, [&]() {
// The pad source may be smaller than the padded tensor, so mark every dim
// out-of-bounds; elements past the source then take the padding value.
SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
xferOp->setAttr(xferOp.getInBoundsAttrName(),
rewriter.getBoolArrayAttr(inBounds));
xferOp.getSourceMutable().assign(padOp.getSource());
xferOp.getPaddingMutable().assign(padValue);
});
return success();
}
};
/// Rewrites a `vector.transfer_write` into a `tensor.pad` result whose output
/// is immediately trimmed back by a `tensor.extract_slice`:
///
///   %p = tensor.pad %src low[0, ...] high[...]
///   %w = vector.transfer_write %v, %p[...]
///   %r = tensor.extract_slice %w[0, ...] [sizes of %src] [...]
///
/// The result is a single transfer_write of %v directly into %src, with all
/// dims marked out-of-bounds, which replaces %r.
///
/// Requirements:
///  * the transfer is not 0-d;
///  * the low padding is zero and the padding value is constant;
///  * the write has exactly one use, which is a zero-offset extract_slice;
///  * that slice has the same size as the pad source (see
///    `hasSameTensorSize`).
struct PadOpVectorizationWithTransferWritePattern
    : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
  using VectorizePadOpUserPattern<
      vector::TransferWriteOp>::VectorizePadOpUserPattern;

  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferWriteOp xferOp) const override {
    // TODO: support the 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();
    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // The write result must be consumed only by the slice that trims padding.
    if (!xferOp->hasOneUse())
      return failure();
    auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
    if (!trimPadding)
      return failure();
    // Only static zero offsets are supported when trimming padding.
    if (!trimPadding.hasZeroOffset())
      return failure();
    // The slice must remove exactly the padding that was added.
    if (!hasSameTensorSize(padOp.getSource(), trimPadding))
      return failure();

    rewriter.setInsertionPoint(xferOp);
    // The pad source can be smaller than the padded tensor, so the new write
    // may run past its end: conservatively mark every dim out-of-bounds.
    SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
    auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        xferOp, padOp.getSource().getType(), xferOp.getVector(),
        padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
        xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
    rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
    return success();
  }

  /// Returns true only if `beforePadding` (the pad source) and
  /// `afterTrimming` (the slice trimming the padded result) are proven to have
  /// the same shape.
  ///
  /// Both must be ranked tensors of equal rank, with identical static dims and
  /// the same set of dynamic dims. Dynamic sizes can only be proven equal when
  /// `beforePadding` is itself produced by an ExtractSliceOp (possibly behind
  /// a `tensor.cast`). For each dynamic dim, the two slices' sizes must then
  /// either be the same value/constant or come from identical `affine.min`
  /// ops. Any other pair of dynamic sizes is treated as unequal.
  bool hasSameTensorSize(Value beforePadding,
                         tensor::ExtractSliceOp afterTrimming) const {
    // If the pad source is a tensor.cast, also try the cast's operand.
    if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
      if (hasSameTensorSize(castOp.getSource(), afterTrimming))
        return true;

    auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
    auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
    // Only ranked tensors are supported.
    if (!t1 || !t2)
      return false;
    if (t1.getRank() != t2.getRank())
      return false;

    // Static dims must match. A dim that is static in one type and dynamic in
    // the other is not supported.
    for (unsigned i = 0; i < t1.getRank(); ++i) {
      if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
        return false;
      if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
        return false;
    }

    // Nothing more to check if all dims are static.
    if (t1.getNumDynamicDims() == 0)
      return true;

    // Dynamic sizes can only be compared when the pad source is a slice.
    auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
    if (!beforeSlice)
      return false;

    // Compute the mixed sizes once instead of on every loop iteration.
    SmallVector<OpFoldResult> beforeSizes = beforeSlice.getMixedSizes();
    SmallVector<OpFoldResult> afterSizes = afterTrimming.getMixedSizes();
    assert(static_cast<size_t>(t1.getRank()) == beforeSizes.size());
    assert(static_cast<size_t>(t2.getRank()) == afterSizes.size());

    for (unsigned i = 0; i < t1.getRank(); ++i) {
      // Static dims were already compared above.
      if (!t1.isDynamicDim(i))
        continue;
      OpFoldResult size1 = beforeSizes[i];
      OpFoldResult size2 = afterSizes[i];

      // Case 1: same SSA value or same constant.
      if (isEqualConstantIntOrValue(size1, size2))
        continue;

      auto v1 = llvm::dyn_cast_if_present<Value>(size1);
      auto v2 = llvm::dyn_cast_if_present<Value>(size2);
      if (!v1 || !v2)
        return false;

      // Case 2: structurally identical affine.min ops (e.g. before CSE).
      auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
      auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
          minOp1.getOperands() == minOp2.getOperands())
        continue;

      // Equality could not be proven for this dim. Be conservative: falling
      // through here would let the rewrite replace a slice of a different
      // size.
      return false;
    }

    // Every dim was proven equal.
    return true;
  }
};
/// Rewrites a `tensor.insert_slice` of a `tensor.pad` result into a
/// transfer_read of the pad source followed by a transfer_write into the
/// insert_slice destination.
///
/// The read is padded with the pad's constant value and covers the full
/// static pad result shape. The write uses the insert_slice offsets and is
/// fully in-bounds.
///
/// Requirements:
///  * the low padding is zero and the padding value is constant;
///  * the insert_slice has unit strides;
///  * the pad result is statically shaped and is not the insert destination;
///  * the insert_slice sizes are 1 for the leading (rank-reduced) dims and
///    equal to the pad result shape for the trailing dims.
struct PadOpVectorizationWithInsertSlicePattern
: public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
using VectorizePadOpUserPattern<
tensor::InsertSliceOp>::VectorizePadOpUserPattern;
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
tensor::InsertSliceOp insertOp) const override {
if (!padOp.hasZeroLowPad())
return failure();
if (!insertOp.hasUnitStride())
return failure();
auto padValue = padOp.getConstantPaddingValue();
if (!padValue)
return failure();
if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
return failure();
if (insertOp.getDest() == padOp.getResult())
return failure();
auto vecType = VectorType::get(padOp.getType().getShape(),
padOp.getType().getElementType());
unsigned vecRank = vecType.getRank();
unsigned tensorRank = insertOp.getType().getRank();
// The whole padded tensor must be inserted into the innermost dims; the
// leading dims of the destination must be unit-sized (no permutations).
SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
if (!llvm::all_of(
llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
}))
return failure();
rewriter.setInsertionPoint(insertOp);
// Read the entire source; out-of-range elements take the padding value.
SmallVector<Value> readIndices(
vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
auto read = rewriter.create<vector::TransferReadOp>(
padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
// The insert_slice source must fit in its destination at the given
// offsets, so the write is fully in-bounds.
auto writeIndices =
ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
SmallVector<bool> inBounds(vecRank, true);
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
insertOp, read, insertOp.getDest(), writeIndices,
ArrayRef<bool>{inBounds});
return success();
}
};
/// Registers the tensor.pad vectorization patterns.
///  * The generic pad-to-copy lowering gets `baseBenefit`.
///  * The patterns that fold the pad into its transfer_read, transfer_write or
///    insert_slice users get `baseBenefit + 1`, so they are preferred.
void mlir::linalg::populatePadOpVectorizationPatterns(
    RewritePatternSet &patterns, PatternBenefit baseBenefit) {
  MLIRContext *ctx = patterns.getContext();
  const unsigned userPatternBenefit = baseBenefit.getBenefit() + 1;
  patterns.add<GenericPadOpVectorizationPattern>(ctx, baseBenefit);
  patterns.add<PadOpVectorizationWithTransferReadPattern,
               PadOpVectorizationWithTransferWritePattern,
               PadOpVectorizationWithInsertSlicePattern>(ctx,
                                                         userPatternBenefit);
}
/// Conservatively returns true if some use of `values` may lie between
/// `firstOp` and `secondOp`.
///
/// It also returns true when the two ops are not in the same block, or when
/// `firstOp` does not precede `secondOp`. Uses by `firstOp` or `secondOp`
/// themselves are ignored. Any other use counts as safe only if it is in the
/// same block and lies strictly before `firstOp` or strictly after
/// `secondOp`.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                    ValueRange values) {
  Block *block = firstOp->getBlock();
  const bool orderedInSameBlock =
      block == secondOp->getBlock() && firstOp->isBeforeInBlock(secondOp);
  if (!orderedInSameBlock) {
    LDBG("interleavedUses precondition failed, firstOp: "
         << *firstOp << ", second op: " << *secondOp << "\n");
    return true;
  }

  // A user is harmless if it is one of the two ops, or if it sits in the same
  // block outside the [firstOp, secondOp] window.
  auto isHarmlessUser = [&](Operation *user) {
    if (user == firstOp || user == secondOp)
      return true;
    return user->getBlock() == block &&
           (user->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(user));
  };

  for (Value value : values) {
    for (OpOperand &use : value.getUses()) {
      Operation *user = use.getOwner();
      if (isHarmlessUser(user))
        continue;
      LDBG(" found interleaved op " << *user << ", firstOp: " << *firstOp
                                    << ", second op: " << *secondOp << "\n");
      return true;
    }
  }
  return false;
}
/// Returns the `memref.subview` op that uses `v`, if there is exactly one
/// subview use. Returns a null op if `v` has no subview use or more than one.
static memref::SubViewOp getSubViewUseIfUnique(Value v) {
  memref::SubViewOp found;
  // getUsers() yields one entry per use, matching a walk over the uses.
  for (Operation *user : v.getUsers()) {
    auto subView = dyn_cast<memref::SubViewOp>(user);
    if (!subView)
      continue;
    if (found)
      return {};
    found = subView;
  }
  return found;
}
/// Forwards a `vector.transfer_read` from a local view/alloc buffer back to
/// the original data. The matched IR is:
///   [linalg.fill %pad -> %buf]   (optional)
///   memref.copy %in -> subview(%buf)
///   vector.transfer_read %buf[...]
/// Here %buf is a memref.view or memref.alloc with a unique subview use. The
/// copy and fill must have no interleaved uses of %buf or its subview. The
/// read is recreated on %in with all dims out-of-bounds; the copy and the
/// optional fill are erased.
///
/// Fails for:
///  * masked reads;
///  * sources that are not a view or alloc;
///  * buffers without a unique subview;
///  * missing a suitable copy;
///  * a fill whose value differs from the read's padding.
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
// TODO: support masks.
if (xferOp.getMask())
return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
Value viewOrAlloc = xferOp.getSource();
if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
!viewOrAlloc.getDefiningOp<memref::AllocOp>())
return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
if (!subViewOp)
return rewriter.notifyMatchFailure(xferOp, "no subview found");
Value subView = subViewOp.getResult();
// Find a copy into `subView` with no interleaved uses before the read.
memref::CopyOp copyOp;
for (auto &u : subView.getUses()) {
if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
if (newCopyOp.getTarget() != subView)
continue;
if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
continue;
copyOp = newCopyOp;
break;
}
}
if (!copyOp)
return rewriter.notifyMatchFailure(xferOp, "no copy found");
// Find an optional fill of `viewOrAlloc` with no interleaved uses before the
// copy.
FillOp maybeFillOp;
for (auto &u : viewOrAlloc.getUses()) {
if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
assert(isa<MemRefType>(newFillOp.output().getType()));
if (newFillOp.output() != viewOrAlloc)
continue;
if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
continue;
maybeFillOp = newFillOp;
break;
}
}
// The fill value is what the padded buffer held outside the copied region;
// it must match the read's padding value.
if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
return rewriter.notifyMatchFailure(xferOp,
"padding value does not match fill");
Value in = copyOp.getSource();
// in_bounds information was only valid for the padded local buffer, so reset
// it conservatively when reading from `in`.
auto vectorType = xferOp.getVectorType();
Value res = rewriter.create<vector::TransferReadOp>(
xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
rewriter.getBoolArrayAttr(
SmallVector<bool>(vectorType.getRank(), false)));
if (maybeFillOp)
rewriter.eraseOp(maybeFillOp);
rewriter.eraseOp(copyOp);
rewriter.replaceOp(xferOp, res);
return success();
}
/// Forwards a vector.transfer_write into a buffer whose unique subview is
/// subsequently copied out with `memref.copy`. The write is retargeted
/// directly at the copy destination, and both the copy and the original
/// write are erased.
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
  // Masked writes are not forwarded.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

  // The write must target a memref.view or memref.alloc result directly.
  Value viewOrAlloc = xferOp.getSource();
  const bool fromViewOrAlloc = viewOrAlloc.getDefiningOp<memref::ViewOp>() ||
                               viewOrAlloc.getDefiningOp<memref::AllocOp>();
  if (!fromViewOrAlloc)
    return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");

  // The buffer must have exactly one subview.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return rewriter.notifyMatchFailure(xferOp, "no subview found");
  Value subView = subViewOp.getResult();

  // Pick the first memref.copy reading from the subview with no interleaved
  // use of the buffer or the subview between the write and the copy.
  memref::CopyOp copyOp;
  for (Operation *user : subView.getUsers()) {
    auto candidate = dyn_cast<memref::CopyOp>(user);
    if (candidate && candidate.getSource() == subView &&
        !mayExistInterleavedUses(xferOp, candidate, {viewOrAlloc, subView})) {
      copyOp = candidate;
      break;
    }
  }
  if (!copyOp)
    return rewriter.notifyMatchFailure(xferOp, "no copy found");

  // Store the vector straight into the copy destination. The original
  // in_bounds flags are not carried over; all dimensions are marked
  // out-of-bounds.
  assert(isa<MemRefType>(copyOp.getTarget().getType()));
  Value copyDest = copyOp.getTarget();
  auto vecToStore = xferOp.getVector();
  SmallVector<bool> inBounds(vecToStore.getType().getRank(), false);
  rewriter.create<vector::TransferWriteOp>(
      xferOp.getLoc(), vecToStore, copyDest, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getMask(),
      rewriter.getBoolArrayAttr(inBounds));

  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);
  return success();
}
/// Terminal overload of the indexed form below: nothing left to bind.
template <int N>
static void bindShapeDims(ShapedType shapedType) {}
/// Assigns shapedType.getShape()[N], [N + 1], ... to `val`, `vals...` in
/// order.
template <int N, typename IntTy, typename... IntTy2>
static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
  ArrayRef<int64_t> dims = shapedType.getShape();
  val = dims[N];
  [[maybe_unused]] size_t next = N + 1;
  ((vals = dims[next++]), ...);
}
/// Binds each reference in `vals` to the leading dimensions of `shapedType`,
/// starting at dimension 0.
template <typename... IntTy>
static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
  bindShapeDims<0>(shapedType, vals...);
}
namespace {
/// Returns true if `op` implements CastOpInterface and has a single operand
/// that is a block argument.
bool isCastOfBlockArgument(Operation *op) {
  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
         isa<BlockArgument>(op->getOperand(0));
}

/// Returns true if `kind` is a combining kind accepted as the reduction of a
/// 1-D pooling op.
bool isSupportedPoolKind(vector::CombiningKind kind) {
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::MAXNUMF:
  case vector::CombiningKind::MAXIMUMF:
  case vector::CombiningKind::MAXSI:
  case vector::CombiningKind::MAXUI:
  case vector::CombiningKind::MINNUMF:
  case vector::CombiningKind::MINIMUMF:
  case vector::CombiningKind::MINSI:
  case vector::CombiningKind::MINUI:
    return true;
  default:
    return false;
  }
}

/// Emits vector code for 1-D convolution, pooling and depthwise convolution
/// LinalgOps.
///
/// The constructor classifies the op and sets `valid`. Each generate*() entry
/// point then matches one specific iterator/indexing-map form and, when it
/// matches, emits the vector IR. Otherwise it returns a match failure.
struct Conv1DGenerator
    : public StructuredGenerator<LinalgOp, utils::IteratorType> {
  /// Records `strideW`/`dilationW` and checks the preconditions shared by all
  /// entry points:
  ///   - 2 DPS inputs and 1 DPS init, all of ShapedType;
  ///   - LHS/RES of rank 3 (channeled) or rank 1 (non-channeled);
  ///   - a recognised reduction body;
  ///   - a filter rank compatible with the operation kind.
  /// `valid` is set only when all of these hold.
  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
                  int dilationW)
      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
        strideW(strideW), dilationW(dilationW) {
    if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
      return;
    lhsShaped = linalgOp.getDpsInputOperand(0)->get();
    rhsShaped = linalgOp.getDpsInputOperand(1)->get();
    resShaped = linalgOp.getDpsInitOperand(0)->get();
    lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
    rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
    resShapedType = dyn_cast<ShapedType>(resShaped.getType());
    if (!lhsShapedType || !rhsShapedType || !resShapedType)
      return;
    // Channeled form: LHS and RES are rank 3.
    // Non-channeled form: LHS and RES are rank 1.
    if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
        (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
      return;

    Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
    if (!reduceOp)
      return;
    redOp = reduceOp->getName().getIdentifier();

    if (!setOperKind(reduceOp))
      return;
    // Convolutions must accumulate with ADD. Pooling may use any supported
    // combining kind.
    auto maybeKind = getCombinerOpKind(reduceOp);
    if (!maybeKind || (*maybeKind != vector::CombiningKind::ADD &&
                       (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
      return;
    }

    auto rhsRank = rhsShapedType.getRank();
    switch (oper) {
    case Conv:
      if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
        return;
      break;
    case Pool:
      if (rhsRank != 1)
        return;
      break;
    }
    // The op is now known to be valid.
    valid = true;
  }

  /// Generates vector code for a 1-D convolution or pooling in the layout
  /// named by `conv1DOpOrder`:
  ///   W:   {{w + kw}, {kw}, {w}}
  ///   Nwc: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
  ///        (pooling: filter {kw})
  ///   Ncw: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
  ///        (pooling: filter {kw})
  ///
  /// kw is always unrolled. w is unrolled one step at a time when
  /// strideW != 1; otherwise it is processed as one slice. Ncw operands are
  /// transposed to the Nwc form before computing, and the result is
  /// transposed back.
  ///
  /// Returns the final vector.transfer_write. Returns a match failure when
  /// the op is invalid or a dimension feeding a vector shape is dynamic.
  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
    if (!valid)
      return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool");

    int64_t nSize, wSize, cSize, kwSize, fSize;
    SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
    bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
      // Dimensions that do not exist in this form.
      nSize = fSize = cSize = 0;
      // out{w}
      bindShapeDims(resShapedType, wSize);
      // kernel{kw}
      bindShapeDims(rhsShapedType, kwSize);
      // iw = ow + kw - 1 (e.g. an input of 16 convolved with 3 gives 14).
      lhsShape = {(wSize + kwSize - 1)};
      rhsShape = {kwSize};
      resShape = {wSize};
      break;
    case Conv1DOpOrder::Nwc:
      // out{n, w, f}
      bindShapeDims(resShapedType, nSize, wSize, fSize);
      switch (oper) {
      case Conv:
        // kernel{kw, c, f}
        bindShapeDims(rhsShapedType, kwSize, cSize);
        break;
      case Pool:
        // kernel{kw}; pooling preserves the channel count.
        bindShapeDims(rhsShapedType, kwSize);
        cSize = fSize;
        break;
      }
      // Input width needed for strided/dilated windows: the inclusive extent
      // of the strided output plus the inclusive extent of the dilated
      // kernel, minus one.
      lhsShape = {nSize,
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1,
                  cSize};
      switch (oper) {
      case Conv:
        rhsShape = {kwSize, cSize, fSize};
        break;
      case Pool:
        rhsShape = {kwSize};
        break;
      }
      resShape = {nSize, wSize, fSize};
      break;
    case Conv1DOpOrder::Ncw:
      // out{n, f, w}
      bindShapeDims(resShapedType, nSize, fSize, wSize);
      switch (oper) {
      case Conv:
        // kernel{f, c, kw}
        bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
        break;
      case Pool:
        // kernel{kw}; pooling preserves the channel count.
        bindShapeDims(rhsShapedType, kwSize);
        cSize = fSize;
        break;
      }
      lhsShape = {nSize, cSize,
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1};
      switch (oper) {
      case Conv:
        rhsShape = {fSize, cSize, kwSize};
        break;
      case Pool:
        rhsShape = {kwSize};
        break;
      }
      resShape = {nSize, fSize, wSize};
      break;
    }

    // VectorType requires static sizes. Bail out before creating any IR if a
    // dimension that feeds a vector shape is dynamic. The unused dims of the
    // W form are 0 here, not dynamic.
    if (llvm::any_of(ArrayRef<int64_t>{nSize, wSize, cSize, kwSize, fSize},
                     ShapedType::isDynamic))
      return rewriter.notifyMatchFailure(op, "unsupported dynamic shape");

    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

    // w is unrolled (wSizeStep == 1) iff strideW != 1. With unit stride the
    // loads along w are contiguous and are done as a single slice.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();
    auto lhsType = VectorType::get(lhsShape, lhsEltType);
    auto rhsType = VectorType::get(rhsShape, rhsEltType);
    auto resType = VectorType::get(resShape, resEltType);
    // All-zero indices for each transfer.
    SmallVector<Value> lhsPadding(lhsShape.size(), zero);
    SmallVector<Value> rhsPadding(rhsShape.size(), zero);
    SmallVector<Value> resPadding(resShape.size(), zero);

    // Read the whole lhs, rhs and res in one shot.
    Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
                                                        lhsPadding);
    // The filter is only read for convolutions; pooling ignores it.
    Value rhs = nullptr;
    if (oper == Conv)
      rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                    rhsPadding);
    Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
                                                        resPadding);

    // The base case is input {n, w, c}, filter {kw, c, f}, output {n, w, f}.
    // Other layouts are transposed into it.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      // Already in the base form.
      break;
    case Conv1DOpOrder::Ncw: {
      // ncw -> nwc
      static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
      // fcw -> wcf
      static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
      if (oper == Conv)
        rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
      // nfw -> nwf
      static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
      res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
      break;
    }
    }

    // Unroll along kw (and w when strided) and slice lhs, rhs and res.
    SmallVector<Value> lhsVals, rhsVals, resVals;
    lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
                                     kwSize, strideW, dilationW, wSizeStep,
                                     isSingleChanneled);
    if (oper == Conv)
      rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
    resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
                                      wSizeStep, isSingleChanneled);

    // lhsVals is laid out kw-major, with wSize / wSizeStep entries per kw.
    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };

    // Accumulate each slice into its result slice:
    //   Conv (channeled):     contraction O{n,w,f} += I{n,w,c} * F{c,f}
    //   Conv (non-channeled): outer product
    //   Pool:                 elementwise reduction
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        switch (oper) {
        case Conv:
          if (isSingleChanneled) {
            resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
                                                   lhsVals[linearIndex(kw, w)],
                                                   rhsVals[kw], resVals[w]);
          } else {
            resVals[w] = conv1dSliceAsContraction(rewriter, loc,
                                                  lhsVals[linearIndex(kw, w)],
                                                  rhsVals[kw], resVals[w]);
          }
          break;
        case Pool:
          resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
                                   resVals[w]);
          break;
        }
      }
    }

    res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
                                 isSingleChanneled);

    // Undo the pre-transpose of the result for non-base layouts.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      break;
    case Conv1DOpOrder::Ncw: {
      // nwf -> nfw
      static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
      res = rewriter.create<vector::TransposeOp>(loc, res, perm);
      break;
    }
    }

    return rewriter
        .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
        .getOperation();
  }

  /// Widens `val` to the element type of `ty`. The supported conversions are:
  ///   - int to float via arith.sitofp;
  ///   - narrower float to wider float via arith.extf;
  ///   - narrower int to wider int via arith.extsi.
  /// Returns `val` unchanged when the element types already match. Any other
  /// combination hits an assertion, and returns null when assertions are
  /// disabled.
  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
    const Type srcElementType = getElementTypeOrSelf(val.getType());
    const Type dstElementType = getElementTypeOrSelf(ty);
    assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
    if (srcElementType == dstElementType)
      return val;

    const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
    const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
    const Type dstType =
        cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);

    if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
      return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
    }

    if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
        srcWidth < dstWidth)
      return rewriter.create<arith::ExtFOp>(loc, dstType, val);

    if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
        srcWidth < dstWidth)
      return rewriter.create<arith::ExtSIOp>(loc, dstType, val);

    assert(false && "unhandled promotion case");
    return nullptr;
  }

  /// Emits the contraction lhs{n, w, c} * rhs{c, f} -> res{n, w, f} as a
  /// vector.contract. c is the reduction dim. Both operands are promoted to
  /// the result element type first.
  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
                                 Value lhs, Value rhs, Value res) {
    vector::IteratorType par = vector::IteratorType::parallel;
    vector::IteratorType red = vector::IteratorType::reduction;
    AffineExpr n, w, f, c;
    bindDims(ctx, n, w, f, c);
    lhs = promote(rewriter, loc, lhs, res.getType());
    rhs = promote(rewriter, loc, rhs, res.getType());
    return rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, res,
        /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
        /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
  }

  /// Accumulates the outer product of `lhs` and `rhs` into `res` using a
  /// vector.outerproduct with ADD combining. This is used for the
  /// non-channeled (W) convolution.
  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
                                  Value lhs, Value rhs, Value res) {
    return rewriter.create<vector::OuterProductOp>(
        loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
  }

  /// Combines a pooling input slice `lhs` into `res` by creating the body's
  /// reduction op (`redOp`) by name. When the body extends its input first,
  /// `poolExtOp` is applied to `lhs` beforehand.
  Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
                    Value res) {
    if (isPoolExt)
      lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
    return rewriter
        .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
        ->getResult(0);
  }

  /// Generates vector code for a depthwise convolution in the layout
  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}.
  ///
  /// kw is always unrolled, and w is unrolled when strideW != 1.
  ///
  /// A dynamic channel dim is supported through masking with a vector size
  /// of `channelDimVecSize`, which is scalable iff `channelDimScalableFlag`
  /// is set. All other dims must be static.
  ///
  /// With `flatten`, the w and c dims of each slice are collapsed before the
  /// multiply-accumulate.
  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
                                       bool channelDimScalableFlag,
                                       bool flatten) {
    if (!valid)
      return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");

    bool scalableChDim = false;
    bool useMasking = false;
    int64_t nSize, wSize, cSize, kwSize;
    // kernel{kw, c}
    bindShapeDims(rhsShapedType, kwSize, cSize);
    // out{n, w, c}
    bindShapeDims(resShapedType, nSize, wSize);
    // Only the channel dim may be dynamic (it is masked below). Every other
    // dim feeds a VectorType directly and must be static.
    if (llvm::any_of(ArrayRef<int64_t>{nSize, wSize, kwSize},
                     ShapedType::isDynamic))
      return rewriter.notifyMatchFailure(op, "unsupported dynamic shape");
    if (ShapedType::isDynamic(cSize)) {
      // Callers in this file pass 0 or ShapedType::kDynamic when no vector
      // size is known. Both are non-positive once read back as int64_t, and
      // neither can size a vector.
      if (static_cast<int64_t>(channelDimVecSize) <= 0)
        return rewriter.notifyMatchFailure(
            op, "dynamic channel dim requires a positive vector size");
      cSize = channelDimVecSize;
      // Scalable vectors are only used when the channel dim is dynamic and
      // channelDimScalableFlag is set.
      scalableChDim = channelDimScalableFlag;
      useMasking = true;
    }

    assert(!(useMasking && flatten) &&
           "Unsupported flattened conv with dynamic shapes");

    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

    // w is unrolled (wSizeStep == 1) iff strideW != 1. With unit stride the
    // loads along w are contiguous and are done as a single slice.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();
    VectorType lhsType = VectorType::get(
        {nSize,
         // Input width covering all strided/dilated windows.
         ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
         cSize},
        lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
    VectorType rhsType =
        VectorType::get({kwSize, cSize}, rhsEltType,
                        /*scalableDims=*/{false, scalableChDim});
    VectorType resType =
        VectorType::get({nSize, wSize, cSize}, resEltType,
                        /*scalableDims=*/{false, false, scalableChDim});

    // When masking is needed, marks the transfer op as in-bounds and wraps it
    // in a vector.mask created from its mixed sizes. Otherwise returns the op
    // unchanged.
    auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
                               ArrayRef<bool> scalableDims,
                               Operation *opToMask) {
      if (!useMasking)
        return opToMask;
      auto maskType =
          VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);

      SmallVector<bool> inBounds(maskShape.size(), true);
      auto xferOp = cast<VectorTransferOpInterface>(opToMask);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));

      SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
          cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);

      Value maskOp =
          rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);

      return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
    };

    // Read lhs {n, iw, c}, rhs {kw, c} and res {n, w, c} at offset 0.
    Value lhs = rewriter.create<vector::TransferReadOp>(
        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
    auto maybeMaskedLhs = maybeMaskXferOp(
        lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());

    Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                        ValueRange{zero, zero});
    auto maybeMaskedRhs = maybeMaskXferOp(
        rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());

    Value res = rewriter.create<vector::TransferReadOp>(
        loc, resType, resShaped, ValueRange{zero, zero, zero});
    auto maybeMaskedRes = maybeMaskXferOp(
        resType.getShape(), resType.getScalableDims(), res.getDefiningOp());

    SmallVector<Value> lhsVals, rhsVals, resVals;
    auto inOutSliceSizes = SmallVector<int64_t>{nSize, wSizeStep, cSize};
    auto inOutStrides = SmallVector<int64_t>{1, 1, 1};

    // lhs slices {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0], kw-major.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, maybeMaskedLhs->getResult(0),
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            inOutSliceSizes, inOutStrides));
      }
    }
    // rhs slices {c} @ [kw].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      rhsVals.push_back(rewriter.create<vector::ExtractOp>(
          loc, maybeMaskedRhs->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{kw}));
    }
    // res slices {n, wSizeStep, c} @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
          inOutStrides));
    }

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };

    // Flattened slice types. Scalable flags are ignored because flattening is
    // not combined with masking (see the assert above).
    auto inOutFlattenSliceSizes =
        SmallVector<int64_t>{nSize, wSizeStep * cSize};
    auto lhsTypeAfterFlattening =
        VectorType::get(inOutFlattenSliceSizes, lhsEltType);
    auto resTypeAfterFlattening =
        VectorType::get(inOutFlattenSliceSizes, resEltType);

    // O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        Value lhsVal = lhsVals[linearIndex(kw, w)];
        Value resVal = resVals[w];
        if (flatten) {
          // Collapse the w and c dims.
          lhsVal = rewriter.create<vector::ShapeCastOp>(
              loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
          resVal = rewriter.create<vector::ShapeCastOp>(
              loc, resTypeAfterFlattening, resVals[w]);
        }
        resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
                                                  rhsVals[kw], resVal, flatten);
        if (flatten) {
          // Restore the {n, wSizeStep, c} slice shape.
          resVals[w] = rewriter.create<vector::ShapeCastOp>(
              loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
        }
      }
    }

    // The multiply-accumulate may have failed (returned a null Value).
    if (!llvm::all_of(resVals, [](Value v) { return v; })) {
      // Manually revert to avoid leaving a bad IR state. The failed slices
      // are null Values and have no defining op, so they are skipped.
      // NOTE(review): ops created inside depthwiseConv1dSliceAsMulAcc and the
      // overwritten original res slices are not tracked here, so some erased
      // ops may still have uses. This revert likely needs a more thorough
      // rework.
      for (auto &collection :
           {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
        for (Value v : collection)
          if (v)
            rewriter.eraseOp(v.getDefiningOp());
      return rewriter.notifyMatchFailure(op, "failed to create FMA");
    }

    // Insert each res slice {n, wSizeStep, c} back @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
    }

    // Write back res {n, w, c} @ [0, 0, 0].
    Operation *resOut = rewriter.create<vector::TransferWriteOp>(
        loc, maybeMaskedRes->getResult(0), resShaped,
        ValueRange{zero, zero, zero});
    return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
                           resOut);
  }

  /// Lowers one depthwise slice to a multiply-accumulate:
  ///   lhs{n, w, c} * rhs{c} + res{n, w, c}      (flatten = false)
  ///   lhs{n, w * c} * rhs{c} + res{n, w * c}    (flatten = true)
  ///
  /// In the flattened case, rhs is first replicated with a vector.shuffle to
  /// cover the trailing res dim. rhs is then broadcast to the res shape.
  /// Float results use vector.fma; others use arith.muli + arith.addi.
  /// Returns null if a promoted operand is null.
  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
                                     Value lhs, Value rhs, Value res,
                                     bool flatten) {
    auto rhsTy = cast<ShapedType>(rhs.getType());
    auto resTy = cast<ShapedType>(res.getType());

    lhs = promote(rewriter, loc, lhs, resTy);

    if (flatten) {
      // Replicate the filter (resSize / rhsSize) times so it lines up with the
      // flattened w * c dim. This shuffle does not work for scalable vectors,
      // which is why flattening is rejected with dynamic shapes.
      auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
      auto resSize = cast<VectorType>(res.getType()).getShape()[1];

      SmallVector<int64_t, 16> indices;
      for (int i = 0; i < resSize / rhsSize; ++i) {
        for (int j = 0; j < rhsSize; ++j)
          indices.push_back(j);
      }

      rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
    }
    // Broadcast the filter to the shape of the output slice.
    rhs = rewriter.create<vector::BroadcastOp>(
        loc, resTy.clone(rhsTy.getElementType()), rhs);

    rhs = promote(rewriter, loc, rhs, resTy);

    if (!lhs || !rhs)
      return nullptr;

    if (isa<FloatType>(resTy.getElementType()))
      return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);

    auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
    return rewriter.create<arith::AddIOp>(loc, mul, res);
  }

  /// Entry point for the non-channeled convolution {{w + kw}, {kw}, {w}}.
  FailureOr<Operation *> generateNonChanneledConv() {
    AffineExpr w, kw;
    bindDims(ctx, w, kw);
    if (!iters({Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match conv::W 1-par 1-red");

    // No transposition needed.
    if (layout({/*lhsIndex*/ {w + kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {w}}))
      return conv(Conv1DOpOrder::W);

    return rewriter.notifyMatchFailure(op, "not a conv::W layout");
  }

  /// Entry point for the Nwc convolution
  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}.
  FailureOr<Operation *> generateNwcConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, w, f, kw, c);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Nwc 3-par 2-red");

    // No transposition needed.
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c, f},
                /*resIndex*/ {n, w, f}}))
      return conv(Conv1DOpOrder::Nwc);

    return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
  }

  /// Entry point for the Ncw convolution
  ///   {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}.
  /// conv() transposes it into the Nwc form.
  FailureOr<Operation *> generateNcwConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, f, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Ncw 3-par 2-red");

    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {f, c, kw},
                /*resIndex*/ {n, f, w}}))
      return conv(Conv1DOpOrder::Ncw);

    return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
  }

  /// Entry point for Nwc pooling
  ///   {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}}.
  FailureOr<Operation *> generateNwcPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");

    // No transposition needed.
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, w, c}}))
      return conv(Conv1DOpOrder::Nwc);

    return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
  }

  /// Entry point for Ncw pooling
  ///   {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}}.
  FailureOr<Operation *> generateNcwPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, c, w, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");

    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, c, w}}))
      return conv(Conv1DOpOrder::Ncw);

    return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
  }

  /// Entry point for the Nwc depthwise convolution
  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}.
  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
                                             bool vecChDimScalableFlag = false,
                                             bool flatten = false) {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match depthwise::Nwc conv 3-par 1-red");

    // No transposition needed.
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c},
                /*resIndex*/ {n, w, c}}))
      return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);

    return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
  }

private:
  enum OperKind { Conv, Pool };
  bool valid = false;
  OperKind oper = Conv;
  // Name of the reduction op in the body, recreated by pool1dSlice.
  StringAttr redOp;
  // Name of the input-extension (cast) op of a pooling body, if any.
  StringAttr poolExtOp;
  bool isPoolExt = false;
  int strideW, dilationW;
  Value lhsShaped, rhsShaped, resShaped;
  ShapedType lhsShapedType, rhsShapedType, resShapedType;

  /// Classifies the body from the operands of its reduction op and sets
  /// `oper`, `isPoolExt` and `poolExtOp`. Returns false for unsupported
  /// bodies.
  ///
  /// - Two block-argument operands: pooling without extension.
  /// - One block-argument operand: the other operand's producer decides.
  ///   - A single-operand cast of a block argument means pooling with
  ///     extension.
  ///   - A MulI/MulF whose operands are block arguments or casts of block
  ///     arguments means convolution.
  /// - Anything else is rejected.
  bool setOperKind(Operation *reduceOp) {
    int numBlockArguments =
        llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
    switch (numBlockArguments) {
    case 1: {
      auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                         llvm::IsaPred<BlockArgument>);
      Operation *feedOp = (*feedValIt).getDefiningOp();
      if (isCastOfBlockArgument(feedOp)) {
        oper = Pool;
        isPoolExt = true;
        poolExtOp = feedOp->getName().getIdentifier();
      } else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
                   llvm::all_of(feedOp->getOperands(), [](Value v) {
                     if (isa<BlockArgument>(v))
                       return true;
                     if (Operation *defOp = v.getDefiningOp())
                       return isCastOfBlockArgument(defOp);
                     return false;
                   }))) {
        return false;
      }
      return true;
    }
    case 2:
      // Both operands come straight from block arguments: must be pooling.
      oper = Pool;
      isPoolExt = false;
      return true;
    default:
      return false;
    }
  }
};
} // namespace
/// Vectorizes a LinalgOp with 1-D convolution semantics. It tries the
/// Conv1DGenerator entry points in this order:
///   1. non-channeled conv;
///   2. Nwc conv;
///   3. Ncw conv;
///   4. Nwc pooling;
///   5. Ncw pooling;
///   6. Nwc depthwise conv.
///
/// The `strides`/`dilations` attributes are used when present (only their
/// first element). Otherwise 1 is assumed and the generator's matchers decide.
///
/// `inputVecSizes` / `inputScalableVecDims` are consulted only for the
/// depthwise case. There, the channel-dim entry gives the vector size (and
/// scalability) used when the channel dim is dynamic.
static FailureOr<Operation *> vectorizeConvolution(
    RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
    ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
  auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
  auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
  auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
  auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
  Conv1DGenerator e(rewriter, op, stride, dilation);
  auto res = e.generateNonChanneledConv();
  if (succeeded(res))
    return res;
  res = e.generateNwcConv();
  if (succeeded(res))
    return res;
  res = e.generateNcwConv();
  if (succeeded(res))
    return res;
  res = e.generateNwcPooling();
  if (succeeded(res))
    return res;
  res = e.generateNcwPooling();
  if (succeeded(res))
    return res;

  // Only 1-D depthwise convs remain. Their channel dim is the only one that
  // may be dynamic (masked / scalable).
  uint64_t vecChDimSize = ShapedType::kDynamic;
  bool vecChDimScalableFlag = false;
  if (!inputVecSizes.empty()) {
    assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
            isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
           "Not a 1D depthwise conv!");
    // Provide a Default so an unexpected op type is handled explicitly
    // instead of converting a TypeSwitch that matched no case, which is
    // invalid once the assert above is compiled out.
    std::optional<size_t> chDimIdx =
        TypeSwitch<Operation *, std::optional<size_t>>(op)
            .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto) { return 2; })
            .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto) { return 1; })
            .Default([](Operation *) -> std::optional<size_t> {
              return std::nullopt;
            });
    if (!chDimIdx || *chDimIdx >= inputVecSizes.size())
      return rewriter.notifyMatchFailure(
          op, "vector sizes require a 1-D depthwise conv with a channel dim");

    vecChDimSize = inputVecSizes[*chDimIdx];
    // A missing scalable flag (e.g. the defaulted empty array) means "fixed".
    vecChDimScalableFlag = *chDimIdx < inputScalableVecDims.size() &&
                           inputScalableVecDims[*chDimIdx];
  }
  return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
                               flatten1DDepthwiseConv);
}
/// Rewrite pattern that vectorizes any LinalgOp accepted by
/// vectorizeConvolution (default vector sizes) and replaces the original op
/// with the result of the generated op. If the generated op has no result,
/// the original op is erased instead.
struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<Operation *> vectorized = vectorizeConvolution(rewriter, op);
    if (failed(vectorized))
      return failure();
    Operation *replacement = *vectorized;
    switch (replacement->getNumResults()) {
    case 0:
      // No value to forward (presumably a write into a memref); just drop
      // the original op.
      rewriter.eraseOp(op.getOperation());
      break;
    default:
      assert(replacement->getNumResults() == 1 && "expected single result");
      rewriter.replaceOp(op.getOperation(), replacement->getResult(0));
      break;
    }
    return success();
  }
};
/// Registers the VectorizeConvolution pattern with the given benefit.
void mlir::linalg::populateConvolutionVectorizationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  MLIRContext *context = patterns.getContext();
  patterns.add<VectorizeConvolution>(context, benefit);
}