#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>
#include <type_traits>
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "vector-to-vector"
using namespace mlir;
using namespace mlir::vector;
/// Reads every element of `arrayAttr` as an `IntegerAttr` and converts its
/// integer value to `IntType` with `static_cast`, preserving element order.
template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  SmallVector<IntType> values;
  for (IntegerAttr elem : arrayAttr.getAsRange<IntegerAttr>())
    values.push_back(static_cast<IntType>(elem.getInt()));
  return values;
}
/// Returns the index of the first result of `map` whose dimension position is
/// `index`, or std::nullopt if no result refers to that dimension.
/// NOTE: `getDimPosition` is queried for every result, so this assumes each
/// result of `map` is a plain dimension expression.
static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
  const int64_t numResults = map.getNumResults();
  int64_t pos = 0;
  while (pos < numResults) {
    const int64_t dimPos = map.getDimPosition(pos);
    if (dimPos == index)
      return pos;
    ++pos;
  }
  return std::nullopt;
}
namespace {
/// Cancels a round trip of two `vector.shape_cast` ops:
///
///   %mid = vector.shape_cast %x   : A to B
///   %out = vector.shape_cast %mid : B to A
///
/// `%out` is replaced with `%x`. Both casts must be between vector types, and
/// the outer cast's result type must equal the inner cast's source type.
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    Value outerInput = shapeCastOp.getSource();
    auto outerInTy = dyn_cast_or_null<VectorType>(outerInput.getType());
    auto outerOutTy =
        dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
    if (!outerInTy || !outerOutTy)
      return failure();

    // The outer cast's operand must itself come from a shape_cast.
    auto innerCast = outerInput.getDefiningOp<vector::ShapeCastOp>();
    if (!innerCast)
      return failure();
    Value innerInput = innerCast.getSource();
    auto innerInTy = cast<VectorType>(innerInput.getType());
    auto innerOutTy = innerCast.getType();

    const bool isRoundTrip = innerInTy == outerOutTy && innerOutTy == outerInTy;
    if (!isRoundTrip)
      return failure();
    rewriter.replaceOp(shapeCastOp, innerInput);
    return success();
  }
};
/// Rewrites an ADD `vector.multi_reduction` whose source is produced by
/// `arith.muli` / `arith.mulf` into a `vector.contract`:
///   - lhs/rhs are the two operands of the multiplication,
///   - the accumulator is the reduction's accumulator,
///   - lhs and rhs use the identity indexing map over all source dims,
///   - the accumulator map keeps only the non-reduced dims,
///   - reduced dims become `reduction` iterators, the others `parallel`.
/// NOTE(review): the multiplication op is not erased here; presumably it is
/// removed as dead code when it has no other users.
struct MultiReduceToContract
: public OpRewritePattern<vector::MultiDimReductionOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
PatternRewriter &rewriter) const override {
// Only additive reductions correspond to a contraction.
if (reduceOp.getKind() != vector::CombiningKind::ADD)
return failure();
Operation *mulOp = reduceOp.getSource().getDefiningOp();
if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
return failure();
SmallVector<bool> reductionMask = reduceOp.getReductionMask();
// lhs/rhs read every dim of the source in order.
auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
// Dims kept in the accumulator/result (the non-reduced ones).
SmallVector<AffineExpr> exprs;
SmallVector<vector::IteratorType> iteratorTypes;
for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
if (!isReduceDim.value()) {
iteratorTypes.push_back(vector::IteratorType::parallel);
exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
} else {
iteratorTypes.push_back(vector::IteratorType::reduction);
}
}
// Accumulator map: all source dims as inputs, no symbols (0), results are
// the parallel dims collected above.
auto dstMap =
AffineMap::get(reductionMask.size(),
0, exprs, reduceOp.getContext());
rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
rewriter.getArrayAttr(llvm::to_vector(llvm::map_range(
iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
return IteratorTypeAttr::get(rewriter.getContext(), t);
}))));
return success();
}
};
/// Folds `vector.transpose` ops that feed the LHS and/or RHS of a
/// `vector.contract` into the contraction's indexing maps. For each transposed
/// operand, the contraction reads the transpose's input directly, and the
/// operand's map becomes `inversePermutation(perm).compose(map)`.
/// Fails when neither operand is produced by a transpose.
struct CombineContractABTranspose final
: public OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
// Local copies; the original op is only replaced if something changed.
SmallVector<AffineMap> maps =
llvm::to_vector<4>(contractOp.getIndexingMapsArray());
Value lhs = contractOp.getLhs();
Value rhs = contractOp.getRhs();
// maps[0] belongs to lhs, maps[1] to rhs.
size_t index = 0;
bool changed = false;
for (Value *operand : {&lhs, &rhs}) {
AffineMap &map = maps[index++];
auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
if (!transposeOp)
continue;
AffineMap permutationMap = AffineMap::getPermutationMap(
transposeOp.getPermutation(), contractOp.getContext());
// Undo the permutation inside the indexing map instead of in the data.
map = inversePermutation(permutationMap).compose(map);
*operand = transposeOp.getVector();
changed = true;
}
if (!changed)
return failure();
rewriter.replaceOpWithNewOp<vector::ContractionOp>(
contractOp, lhs, rhs, contractOp.getAcc(),
rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
return success();
}
};
/// Folds a transpose pair around a `vector.contract` accumulator/result:
///
///   %acc_t = vector.transpose %acc, P
///   %c     = vector.contract ... %acc_t
///   %r     = vector.transpose %c, inverse(P)
/// =>
///   %r     = vector.contract ... %acc   (result map composed with inverse(P))
///
/// Requires the contraction to have a single use (the result transpose) and
/// the result permutation to be the inverse of the accumulator permutation.
/// NOTE(review): the old contract/transpose ops are not erased explicitly;
/// presumably they are cleaned up as dead code.
struct CombineContractResultTranspose final
: public OpRewritePattern<vector::TransposeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
PatternRewriter &rewriter) const override {
auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
if (!contractOp || !contractOp->hasOneUse())
return failure();
auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
if (!accTOp)
return failure();
MLIRContext *context = contractOp.getContext();
auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
// The last indexing map describes the accumulator/result.
AffineMap contractMap = maps.back();
auto accTMap =
AffineMap::getPermutationMap(accTOp.getPermutation(), context);
auto resTMap =
AffineMap::getPermutationMap(resTOp.getPermutation(), context);
auto combinedResMap = resTMap.compose(contractMap);
// Both transposes cancel only if the result one undoes the accumulator one.
if (inversePermutation(accTMap) != resTMap)
return failure();
maps.back() = combinedResMap;
rewriter.replaceOpWithNewOp<vector::ContractionOp>(
resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
return success();
}
};
/// Folds `vector.broadcast` ops that only add leading dims to the LHS and/or
/// RHS of a `vector.contract` into the contraction's indexing maps.
///
/// A broadcast operand is folded only if:
///   - its source is a vector (scalar broadcasts are skipped),
///   - it actually adds dims (source rank < result rank),
///   - the source's dims equal the trailing result dims (no stretching of
///     existing dims),
///   - no added non-unit dim is mapped to a reduction iterator.
/// After folding, dims no longer used by any map are compressed away, along
/// with their iterators. The rewrite is abandoned if no reduction iterator
/// remains that is used by both LHS and RHS, or if a dim is used by neither
/// LHS nor RHS.
struct CombineContractBroadcast
: public OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
// Local copies; the op is only replaced once every check has passed.
SmallVector<AffineMap> maps =
llvm::to_vector<4>(contractOp.getIndexingMapsArray());
Value lhs = contractOp.getLhs();
Value rhs = contractOp.getRhs();
// maps[0] belongs to lhs, maps[1] to rhs.
size_t index = 0;
bool changed = false;
for (Value *operand : {&lhs, &rhs}) {
AffineMap &map = maps[index++];
auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
if (!broadcast)
continue;
auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
if (!srcType ||
srcType.getRank() == broadcast.getResultVectorType().getRank())
continue;
// Number of leading dims added by the broadcast.
int64_t rankDiff =
broadcast.getResultVectorType().getRank() - srcType.getRank();
bool innerDimBroadcast = false;
// Result-space dims that correspond to the source's dims.
SmallVector<AffineExpr> originalDims;
for (const auto &dim : llvm::enumerate(srcType.getShape())) {
// A size mismatch means an existing dim is stretched; not foldable.
if (dim.value() != broadcast.getResultVectorType().getDimSize(
rankDiff + dim.index())) {
innerDimBroadcast = true;
break;
}
originalDims.push_back(
rewriter.getAffineDimExpr(dim.index() + rankDiff));
}
if (innerDimBroadcast)
continue;
// Dropping a non-unit broadcast dim that is reduced over would change
// the reduction result, so such operands are left alone.
bool nonUnitDimReductionBroadcast = false;
for (int64_t i = 0; i < rankDiff; ++i) {
if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
isReductionIterator(contractOp.getIteratorTypes()
.getValue()[map.getDimPosition(i)])) {
nonUnitDimReductionBroadcast = true;
break;
}
}
if (nonUnitDimReductionBroadcast)
continue;
// Map from the broadcast result space to the source's dims.
AffineMap broadcastMap =
AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
originalDims, contractOp.getContext());
map = broadcastMap.compose(map);
*operand = broadcast.getSource();
changed = true;
}
if (!changed)
return failure();
// Drop iteration dims that are no longer referenced by any map.
llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
for (auto &m : maps)
m = compressDims(m, unusedDimsBitVector);
SmallVector<Attribute> iterators;
for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
if (!unusedDimsBitVector.test(i))
iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
}
// The new contraction must still reduce over at least one dim shared by
// lhs and rhs.
bool hasReductionIteratorApplyingOnBothSides = false;
for (unsigned i = 0; i < iterators.size(); ++i) {
if (!isReductionIterator(iterators[i]))
continue;
if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
hasReductionIteratorApplyingOnBothSides = true;
break;
}
}
if (!hasReductionIteratorApplyingOnBothSides)
return failure();
// Every remaining dim must be used by lhs or rhs.
if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
return failure();
rewriter.replaceOpWithNewOp<vector::ContractionOp>(
contractOp, lhs, rhs, contractOp.getAcc(),
rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
return success();
}
};
/// Moves a cast op that consumes a `vector.broadcast` before the broadcast,
/// so the cast runs on the (smaller) broadcast source:
///
///   %b = vector.broadcast %src : A to vector<...xT>
///   %c = <cast> %b : vector<...xT> to vector<...xU>
/// =>
///   %s = <cast> %src : A to A'     (A' = A with element type U)
///   %c = vector.broadcast %s : A' to vector<...xU>
///
/// Only casts with exactly one operand and one vector result are handled.
/// The replacement is a single `vector.broadcast` value, so multi-result
/// casts (such as an `unrealized_conversion_cast` with several results) and
/// casts to non-vector types must not match.
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> {
  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1 || op->getNumResults() != 1)
      return failure();
    // `vector.broadcast` can only produce vector results.
    if (!isa<VectorType>(op->getResult(0).getType()))
      return failure();
    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();
    // The new cast produces the cast's element type, in the broadcast
    // source's shape (or a scalar for scalar sources).
    Type castResTy = getElementTypeOrSelf(op->getResult(0));
    if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
      castResTy = vecTy.clone(castResTy);
    auto *castOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        bcastOp.getSource(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op->getResult(0).getType(), castOp->getResult(0));
    return success();
  }
};
/// Moves `vector.transpose` ops below an elementwise op:
///
///   %a = vector.transpose %x, P
///   %b = vector.transpose %y, P
///   %r = <elementwise> %a, %b
/// =>
///   %e = <elementwise> %x, %y
///   %r = vector.transpose %e, P
///
/// Every operand must be either a transpose (all using the same permutation
/// P) or a constant. Constants are transposed with the inverse of P so they
/// line up with the untransposed sources. All operands and the single result
/// must be vectors: a scalar constant operand (e.g. the i1 condition of an
/// `arith.select`) cannot be transposed.
struct ReorderElementwiseOpsOnTranspose final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1 || op->getNumRegions() != 0)
      return failure();
    if (!isa<VectorType>(op->getResultTypes()[0]))
      return failure();

    // Collect the permutation of every transpose operand; all other operands
    // must be vector constants.
    SmallVector<ArrayRef<int64_t>> transposeMaps;
    transposeMaps.reserve(op->getNumOperands());
    // Source type of one of the transposes. Any one will do once all
    // permutations are known to be equal, because elementwise operands share
    // a shape.
    VectorType srcType;
    for (Value operand : op->getOperands()) {
      if (!isa<VectorType>(operand.getType()))
        return failure();
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        transposeMaps.push_back(transposeOp.getPermutation());
        srcType = transposeOp.getSourceVectorType();
      } else if (!matchPattern(operand, m_Constant())) {
        return failure();
      }
    }
    if (transposeMaps.empty())
      return failure();
    if (!llvm::all_equal(transposeMaps))
      return rewriter.notifyMatchFailure(op, "different transpose map");

    // Build the untransposed operand list.
    SmallVector<Value> srcValues;
    srcValues.reserve(op->getNumOperands());
    auto order = transposeMaps.front();
    SmallVector<int64_t> invOrder(order.size());
    for (int i = 0, e = order.size(); i < e; ++i)
      invOrder[order[i]] = i;
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        srcValues.push_back(transposeOp.getVector());
      } else {
        // Constant operand: bring it into the untransposed layout.
        auto vectorType =
            srcType.clone(cast<VectorType>(operand.getType()).getElementType());
        srcValues.push_back(rewriter.create<vector::TransposeOp>(
            operand.getLoc(), vectorType, operand, invOrder));
      }
    }

    // Re-create the elementwise op on the untransposed shape, then apply the
    // shared permutation once to its result.
    auto vectorType = srcType.clone(
        cast<VectorType>(op->getResultTypes()[0]).getElementType());
    Operation *elementwiseOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
                        vectorType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        op, op->getResultTypes()[0], elementwiseOp->getResult(0),
        transposeMaps.front());
    return success();
  }
};
/// Reads each element of `arrayAttr` as an `IntegerAttr` and returns the
/// integer values in the same order.
static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
  SmallVector<int64_t> result;
  for (IntegerAttr attr : arrayAttr.getAsRange<IntegerAttr>())
    result.push_back(attr.getInt());
  return result;
}
/// Moves a `vector.bitcast` that increases the element count below a scalar
/// `vector.extract` from a rank-1 vector:
///
///   %c = vector.bitcast %src : vector<NxT> to vector<MxU>      (M >= N)
///   %e = vector.extract %c[i] : U from vector<MxU>
/// =>
///   %p = vector.extract %src[i / (M/N)]
///   %v = vector.insert %p, <zero vector<1xT>> [0]
///   %b = vector.bitcast %v : vector<1xT> to vector<(M/N)xU>
///   %e = vector.extract %b[i % (M/N)]
///
/// Only a single static, non-negative extraction position is supported. A
/// dynamic index is legal IR and simply does not match. Bitcasts whose
/// source has a single element are rejected, because the rewrite itself
/// creates a vector<1xT> bitcast that must not match again.
struct BubbleDownVectorBitCastForExtract
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    if (extractOp.getSourceVectorType().getRank() != 1)
      return failure();
    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();
    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    if (castSrcType.getNumElements() == 1)
      return failure();
    // Only casts to a larger number of elements are handled.
    if (castSrcType.getNumElements() > castDstType.getNumElements())
      return failure();
    unsigned expandRatio =
        castDstType.getNumElements() / castSrcType.getNumElements();

    // The position must be one static index. An empty position would
    // extract the whole vector rather than a scalar.
    SmallVector<OpFoldResult> mixedPos = extractOp.getMixedPosition();
    if (mixedPos.empty() || !mixedPos.front().is<Attribute>())
      return rewriter.notifyMatchFailure(
          extractOp, "expected a static extraction position");
    int64_t index =
        cast<IntegerAttr>(mixedPos.front().get<Attribute>()).getInt();
    if (index < 0)
      return rewriter.notifyMatchFailure(
          extractOp, "expected a non-negative extraction position");

    Location loc = extractOp.getLoc();
    // Extract the packed source element that holds lane `index`.
    Value packedValue = rewriter.create<vector::ExtractOp>(
        loc, castOp.getSource(), index / expandRatio);
    // Wrap it into a single-element vector so it can be bitcast.
    Type packedVecType = VectorType::get({1}, packedValue.getType());
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, packedVecType, rewriter.getZeroAttr(packedVecType));
    packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero,
                                                    0);
    // Unpack into `expandRatio` lanes and pick the requested one.
    VectorType packedType =
        VectorType::get({expandRatio}, castDstType.getElementType());
    Value castedValue =
        rewriter.create<vector::BitCastOp>(loc, packedType, packedValue);
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
                                                   index % expandRatio);
    return success();
  }
};
/// Moves a `vector.bitcast` that increases (or keeps) the last-dim element
/// count below a unit-stride `vector.extract_strided_slice`:
///
///   %c = vector.bitcast %src : ...xNxT to ...xMxU        (N <= M)
///   %e = vector.extract_strided_slice %c {offsets, sizes, strides = 1}
/// =>
///   %s = vector.extract_strided_slice %src {offsets', sizes', strides}
///   %e = vector.bitcast %s
///
/// The last-dim offset and size (when given for every dim) are divided by
/// M/N, and must be multiples of it.
/// NOTE: assumes M is a multiple of N (asserted below).
struct BubbleDownBitCastForStridedSliceExtract
: public OpRewritePattern<vector::ExtractStridedSliceOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
PatternRewriter &rewriter) const override {
auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
if (!castOp)
return failure();
VectorType castSrcType = castOp.getSourceVectorType();
VectorType castDstType = castOp.getResultVectorType();
assert(castSrcType.getRank() == castDstType.getRank());
int64_t castSrcLastDim = castSrcType.getShape().back();
int64_t castDstLastDim = castDstType.getShape().back();
// Only casts to an equal or larger number of elements are handled.
if (castSrcLastDim > castDstLastDim)
return failure();
// Only unit strides are supported.
if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
[](const APInt &val) { return !val.isOne(); }))
return failure();
unsigned rank = extractOp.getSourceVectorType().getRank();
assert(castDstLastDim % castSrcLastDim == 0);
int64_t expandRatio = castDstLastDim / castSrcLastDim;
// Offsets shorter than the rank leave the last dim at offset 0, which
// needs no adjustment.
ArrayAttr newOffsets = extractOp.getOffsets();
if (newOffsets.size() == rank) {
SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
if (offsets.back() % expandRatio != 0)
return failure();
offsets.back() = offsets.back() / expandRatio;
newOffsets = rewriter.getI64ArrayAttr(offsets);
}
// Likewise, sizes shorter than the rank take the full last dim.
ArrayAttr newSizes = extractOp.getSizes();
if (newSizes.size() == rank) {
SmallVector<int64_t> sizes = getIntValueVector(newSizes);
if (sizes.back() % expandRatio != 0)
return failure();
sizes.back() = sizes.back() / expandRatio;
newSizes = rewriter.getI64ArrayAttr(sizes);
}
// The new slice has the original result shape with a scaled last dim and
// the pre-bitcast element type.
SmallVector<int64_t> dims =
llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
dims.back() = dims.back() / expandRatio;
VectorType newExtractType =
VectorType::get(dims, castSrcType.getElementType());
auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
newSizes, extractOp.getStrides());
rewriter.replaceOpWithNewOp<vector::BitCastOp>(
extractOp, extractOp.getType(), newExtractOp);
return success();
}
};
/// Moves a `vector.bitcast` above the `vector.insert` that produces its
/// source, bitcasting the inserted value and the destination separately:
///
///   %0 = vector.insert %src, %dst [pos]
///   %1 = vector.bitcast %0 : vector<...xNxT> to vector<...xMxU>
/// =>
///   %s = vector.bitcast %src   (last dim scaled by M/N)
///   %d = vector.bitcast %dst   (last dim scaled by M/N)
///   %1 = vector.insert %s, %d [pos]
///
/// When the element count shrinks (N >= M) the last dims are divided by
/// N/M; otherwise they are multiplied by M/N.
///
/// Not supported (the pattern fails):
///   - 0-D or scalable bitcasts,
///   - last dims that do not divide evenly (possible for legal bitcasts such
///     as vector<3xi16> -> vector<2xi24>),
///   - scalar or 0-D inserted values, which have no last dim to rescale.
struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
        castDstType.isScalable())
      return failure();
    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
    int64_t largerLastDim = isNumElemsShrink ? castSrcLastDim : castDstLastDim;
    int64_t smallerLastDim = isNumElemsShrink ? castDstLastDim : castSrcLastDim;
    if (largerLastDim % smallerLastDim != 0)
      return failure();
    int64_t ratio = largerLastDim / smallerLastDim;

    auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
    if (!insertOp)
      return failure();
    auto insertSrcType = dyn_cast<VectorType>(insertOp.getSourceType());
    if (!insertSrcType || insertSrcType.getRank() == 0)
      return failure();

    // Bitcast the inserted value.
    SmallVector<int64_t> srcDims(insertSrcType.getShape());
    srcDims.back() =
        isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());
    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());

    // Bitcast the destination.
    SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
    dstDims.back() =
        isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());
    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    // The insert positions address leading dims only, so they are unchanged.
    rewriter.replaceOpWithNewOp<vector::InsertOp>(
        bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
    return success();
  }
};
/// Moves a `vector.bitcast` that reduces the last-dim element count above a
/// unit-stride, same-rank `vector.insert_strided_slice`:
///
///   %0 = vector.insert_strided_slice %src, %dst {offsets, strides = 1}
///   %1 = vector.bitcast %0 : ...xNxT to ...xMxU          (N >= M)
/// =>
///   %s = vector.bitcast %src   (last dim divided by N/M)
///   %d = vector.bitcast %dst   (last dim divided by N/M)
///   %1 = vector.insert_strided_slice %s, %d {offsets', strides}
///
/// The last-dim offset is divided by the shrink ratio N/M. The inserted
/// value's last dim must be a multiple of that ratio, because the value is
/// bitcast on its own. A total-element-count check is not enough: e.g.
/// vector<2x3xi8> has 6 elements but cannot be bitcast to i16 along its last
/// dim.
struct BubbleUpBitCastForStridedSliceInsert
    : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    if (castSrcType.getRank() == 0)
      return failure();
    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Only casts to a smaller (or equal) number of elements are handled.
    if (castSrcLastDim < castDstLastDim)
      return failure();
    // Legal bitcasts need not have evenly dividing last dims.
    if (castSrcLastDim % castDstLastDim != 0)
      return failure();
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;

    auto insertOp =
        bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
    if (!insertOp)
      return failure();
    // Only unit strides are supported.
    if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();
    unsigned rank = insertOp.getSourceVectorType().getRank();
    // Only same-rank inserts, so the last dims of source and dest line up.
    if (rank != insertOp.getDestVectorType().getRank())
      return failure();
    if (insertOp.getSourceVectorType().getShape().back() % shrinkRatio != 0)
      return failure();

    ArrayAttr newOffsets = insertOp.getOffsets();
    assert(newOffsets.size() == rank);
    SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
    if (offsets.back() % shrinkRatio != 0)
      return failure();
    offsets.back() = offsets.back() / shrinkRatio;
    newOffsets = rewriter.getI64ArrayAttr(offsets);

    // Bitcast the inserted value.
    SmallVector<int64_t> srcDims =
        llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
    srcDims.back() = srcDims.back() / shrinkRatio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());
    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());

    // Bitcast the destination.
    SmallVector<int64_t> dstDims =
        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
    dstDims.back() = dstDims.back() / shrinkRatio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());
    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
        bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
        insertOp.getStrides());
    return success();
  }
};
/// Breaks a rank-1 `vector.bitcast` that reduces the number of elements into
/// `shrinkRatio` smaller bitcasts over contiguous slices:
///   - the source is split into `shrinkRatio` slices of `castDstLastDim`
///     elements (`vector.extract_strided_slice`),
///   - each slice is bitcast to `castDstLastDim / shrinkRatio` result
///     elements,
///   - the pieces are inserted (`vector.insert_strided_slice`) into a zero
///     splat of the result type.
///
/// An optional `controlFn` can veto the rewrite for a given op. The pattern
/// fails (instead of emitting an invalid bitcast) when the element counts do
/// not divide evenly, and on scalable vectors, where fixed-size slices would
/// only cover the minimum number of lanes.
struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

public:
  BreakDownVectorBitCast(MLIRContext *context,
                         std::function<bool(vector::BitCastOp)> controlFn,
                         PatternBenefit benefit)
      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    if (controlFn && !controlFn(bitcastOp))
      return failure();
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    // Only rank-1, fixed-length vectors are handled.
    if (castSrcType.getRank() != 1)
      return failure();
    if (castSrcType.isScalable() || castDstType.isScalable())
      return failure();
    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Only casts to fewer elements are handled.
    if (castSrcLastDim < castDstLastDim)
      return failure();
    // Legal bitcasts need not have evenly dividing lengths
    // (e.g. vector<3xi16> -> vector<2xi24>).
    if (castSrcLastDim % castDstLastDim != 0)
      return failure();
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
    // Already bitcasting to a single element: nothing to break down.
    if (castSrcLastDim == shrinkRatio)
      return failure();
    // Each slice of `castDstLastDim` source elements must map to a whole
    // number of result elements. For example, vector<6xi8> -> vector<3xi16>
    // would otherwise produce a vector<3xi8> -> vector<1xi16> bitcast.
    if (castDstLastDim % shrinkRatio != 0)
      return failure();

    Location loc = bitcastOp.getLoc();
    Type elemType = castDstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, castDstType, zero);
    SmallVector<int64_t> sliceShape{castDstLastDim};
    SmallVector<int64_t> strides{1};
    VectorType newCastDstType =
        VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
                        castDstType.getElementType());
    for (int64_t i = 0; i < shrinkRatio; ++i) {
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          loc, bitcastOp.getSource(), ArrayRef<int64_t>{i * castDstLastDim},
          sliceShape, strides);
      Value bitcast =
          rewriter.create<BitCastOp>(loc, newCastDstType, extracted);
      res = rewriter.create<InsertStridedSliceOp>(
          loc, bitcast, res,
          ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
    }
    rewriter.replaceOp(bitcastOp, res);
    return success();
  }

private:
  std::function<bool(BitCastOp)> controlFn;
};
/// Moves `vector.broadcast` / `vector.splat` ops below an elementwise op:
///
///   %a = vector.broadcast %x : S to V
///   %b = vector.splat %y : V            (with %y : S)
///   %r = <elementwise> %a, %b : V
/// =>
///   %e = <elementwise> %x, %y : S
///   %r = vector.broadcast %e : S to V
///
/// Requirements checked below: a single shaped result whose type equals the
/// first operand's type, elementwise-mappable traits, at least one operand,
/// and every operand broadcast/splat from the same source type S.
/// `vector.fma` is excluded.
/// NOTE(review): presumably fma is excluded because it only accepts vector
/// operands and so cannot be re-created on scalar sources; confirm.
struct ReorderElementwiseOpsOnBroadcast final
: public OpTraitRewritePattern<OpTrait::Elementwise> {
using OpTraitRewritePattern::OpTraitRewritePattern;
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (op->getNumResults() != 1)
return failure();
if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
return failure();
if (!OpTrait::hasElementwiseMappableTraits(op))
return failure();
// Result type must match operand 0's type; this rejects ops whose result
// type differs from their operands.
if (op->getNumOperands() == 0 ||
op->getResults()[0].getType() != op->getOperand(0).getType()) {
return failure();
}
if (isa<vector::FMAOp>(op)) {
return failure();
}
auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
if (!lhsBcastOrSplat ||
!isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
return failure();
// The pre-broadcast type S that every operand must share.
auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
auto bcast = val.getDefiningOp<vector::BroadcastOp>();
if (bcast)
return (bcast.getOperand().getType() == lhsBcastOrSplatType);
auto splat = val.getDefiningOp<vector::SplatOp>();
if (splat)
return (splat.getOperand().getType() == lhsBcastOrSplatType);
return false;
})) {
return failure();
}
// Operand 0 of each broadcast/splat is its source value.
SmallVector<Value> srcValues;
srcValues.reserve(op->getNumOperands());
for (Value operand : op->getOperands()) {
srcValues.push_back(operand.getDefiningOp()->getOperand(0));
}
// Re-create the same op (name and attributes) on the narrower type S.
Operation *elementwiseOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
lhsBcastOrSplatType, op->getAttrs());
auto vectorType = op->getResultTypes()[0];
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
op, vectorType, elementwiseOp->getResults());
return success();
}
};
/// Builds a mask-like comparison `indices < splat(b)` with predicate `slt`.
///
/// `indices` is the constant vector [0, 1, ..., dim-1], or a 0-D vector
/// holding 0 when `dim == 0`. If `off` is given, splat(*off) is added to the
/// indices first. Index vectors use i32 when `force32BitVectorIndices` is
/// set, otherwise i64. `b` and `*off` are cast to that type with
/// getValueOrCreateCastToIndexLike. All ops are created at `op`'s location.
static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
bool force32BitVectorIndices, int64_t dim,
Value b, Value *off = nullptr) {
auto loc = op->getLoc();
Type idxType =
force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
DenseIntElementsAttr indicesAttr;
// dim == 0 selects the 0-D case: a rank-0 vector containing the single
// index 0.
if (dim == 0 && force32BitVectorIndices) {
indicesAttr = DenseIntElementsAttr::get(
VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
} else if (dim == 0) {
indicesAttr = DenseIntElementsAttr::get(
VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
} else if (force32BitVectorIndices) {
indicesAttr = rewriter.getI32VectorAttr(
llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
} else {
indicesAttr = rewriter.getI64VectorAttr(
llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
}
Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
// Optional offset: indices become [off, off+1, ...].
if (off) {
Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
}
Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
Value bounds =
rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
// Signed less-than: lane i is set iff indices[i] < b.
return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
bounds);
}
/// Makes the out-of-bounds handling of a 0-D/1-D transfer op explicit.
///
/// The pattern builds a `vector.create_mask` with bound `dim - off`, where
/// `off` is the transfer's last index and `dim` is the size of the
/// corresponding source dim. This mask is ANDed with any existing mask,
/// stored on the op, and the op is then marked in-bounds.
/// `ConcreteOp` is expected to be a transfer op (read or write) providing
/// the accessors used below.
/// NOTE(review): `force32BitVectorIndices` is stored but not read by
/// matchAndRewrite in this block.
/// NOTE(review): the permutation map is not inspected; this assumes the
/// vector dim corresponds to the last source dim. Confirm for
/// non-minor-identity transfers.
template <typename ConcreteOp>
struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
public:
explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
PatternBenefit benefit = 1)
: mlir::OpRewritePattern<ConcreteOp>(context, benefit),
force32BitVectorIndices(enableIndexOpt) {}
LogicalResult matchAndRewrite(ConcreteOp xferOp,
PatternRewriter &rewriter) const override {
// Nothing to do if the transfer is already fully in bounds.
if (!xferOp.hasOutOfBoundsDim())
return failure();
if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
return failure();
Location loc = xferOp->getLoc();
VectorType vtp = xferOp.getVectorType();
unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
Value off = xferOp.getIndices()[lastIndex];
Value dim =
vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
// Number of valid elements remaining from `off` to the end of the dim.
Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
Value mask = rewriter.create<vector::CreateMaskOp>(
loc,
VectorType::get(vtp.getShape(), rewriter.getI1Type(),
vtp.getScalableDims()),
b);
// Intersect with the user-provided mask, if any.
if (xferOp.getMask()) {
mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
}
rewriter.modifyOpInPlace(xferOp, [&]() {
xferOp.getMaskMutable().assign(mask);
xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
});
return success();
}
private:
const bool force32BitVectorIndices;
};
/// Lowers a fixed-length 0-D or 1-D `vector.create_mask` into an explicit
/// comparison via buildVectorComparison: the lane indices are compared
/// (signed `<`) against a splat of the mask bound. For rank 0 a dim of 0 is
/// passed, which selects the 0-D index vector. `force32BitVectorIndices`
/// chooses i32 instead of i64 index vectors.
class VectorCreateMaskOpConversion
    : public OpRewritePattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                        bool enableIndexOpt,
                                        PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}
  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto maskType = op.getType();
    // Scalable masks and masks of rank > 1 are left to other patterns.
    if (cast<VectorType>(maskType).isScalable() || maskType.getRank() > 1)
      return failure();
    const int64_t numLanes =
        maskType.getRank() == 0 ? 0 : maskType.getDimSize(0);
    Value comparison = buildVectorComparison(
        rewriter, op, force32BitVectorIndices, numLanes, op.getOperand(0));
    rewriter.replaceOp(op, comparison);
    return success();
  }

private:
  const bool force32BitVectorIndices;
};
/// Returns true iff `constantOp` holds a splat DenseIntElementsAttr whose
/// splat value, read as bool, equals `value`. Constants that are not
/// DenseIntElementsAttr return false. The element type is asserted to be i1.
static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
  auto elems = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
  if (!elems)
    return false;
  assert(elems.getElementType().isInteger(1) && "Unexpected type");
  if (!elems.isSplat())
    return false;
  return elems.getSplatValue<bool>() == value;
}
/// Folds a select between all-true and all-false i1 vector constants, driven
/// by a scalar condition, into a broadcast of the condition:
///
///   %r = arith.select %cond, <splat true : vector<1xi1>>,
///                            <splat false : vector<1xi1>>
/// =>
///   %r = vector.broadcast %cond : i1 to vector<1xi1>
///
/// Only fixed-length, rank-1, single-element i1 vectors are handled.
struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
LogicalResult matchAndRewrite(arith::SelectOp selectOp,
PatternRewriter &rewriter) const override {
auto vecType = dyn_cast<VectorType>(selectOp.getType());
if (!vecType || !vecType.getElementType().isInteger(1))
return failure();
// Only a scalar condition can be broadcast to the result.
Value cond = selectOp.getCondition();
if (isa<VectorType>(cond.getType()))
return failure();
if (vecType.getRank() != 1 || vecType.isScalable())
return failure();
if (vecType.getShape()[0] != 1)
return failure();
auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
return failure();
auto falseConst =
selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
return failure();
// With a single element this is i1, so the result type is vector<1xi1>.
auto elemType = rewriter.getIntegerType(vecType.getNumElements());
auto bcastType = VectorType::get({1}, elemType);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(selectOp, bcastType, cond);
return success();
}
};
/// Counts how many trailing dims of `vectorType` can be dropped from a
/// transfer on `srcType`. Vector dims are aligned with the trailing memref
/// dims. Walking from the innermost dim outwards, a dim counts while the
/// matching memref dim has size 1 and stride 1 and the vector dim is a
/// non-scalable unit dim. Fails if the memref's strides cannot be computed.
static FailureOr<size_t>
getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
  SmallVector<int64_t> srcStrides;
  int64_t srcOffset;
  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
    return failure();

  const int vecRank = vectorType.getRank();
  const int rankDiff = srcType.getRank() - vecRank;
  size_t numFoldable = 0;
  for (int vecDim = vecRank - 1; vecDim >= 0; --vecDim) {
    const int memDim = vecDim + rankDiff;
    const bool isFixedUnitVecDim = vectorType.getDimSize(vecDim) == 1 &&
                                   !vectorType.getScalableDims()[vecDim];
    if (srcStrides[memDim] != 1 || srcType.getDimSize(memDim) != 1 ||
        !isFixedUnitVecDim)
      break;
    ++numFoldable;
  }
  return numFoldable;
}
/// Drops trailing unit dims from an unmasked, minor-identity
/// `vector.transfer_read` on a memref:
///
///   %v = vector.transfer_read %m[..., i, j] : memref<...x1xT>, vector<Nx1xT>
/// =>
///   %s = memref.subview %m (rank-reduced, trailing unit dims dropped)
///   %r = vector.transfer_read %s[..., i] : ..., vector<NxT>
///   %v = vector.shape_cast %r : vector<NxT> to vector<Nx1xT>
///
/// The droppable dims are computed by getTransferFoldableInnerUnitDims. Each
/// dropped dim must be marked in-bounds. Vectors of rank <= 1 and 0-rank
/// transfers are not handled.
class DropInnerMostUnitDimsTransferRead
: public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
PatternRewriter &rewriter) const override {
if (readOp.getTransferRank() == 0)
return failure();
if (readOp.getMask())
return failure();
auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
if (!srcType)
return failure();
if (!readOp.getPermutationMap().isMinorIdentity())
return failure();
auto targetType = readOp.getVectorType();
if (targetType.getRank() <= 1)
return failure();
FailureOr<size_t> maybeDimsToDrop =
getTransferFoldableInnerUnitDims(srcType, targetType);
if (failed(maybeDimsToDrop))
return failure();
size_t dimsToDrop = maybeDimsToDrop.value();
if (dimsToDrop == 0)
return failure();
// Out-of-bounds dropped dims would need masking/padding; bail out.
auto inBounds = readOp.getInBoundsValues();
auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
if (llvm::is_contained(droppedInBounds, false))
return failure();
auto resultTargetVecType =
VectorType::get(targetType.getShape().drop_back(dimsToDrop),
targetType.getElementType(),
targetType.getScalableDims().drop_back(dimsToDrop));
auto loc = readOp.getLoc();
// Full-size subview (offsets 0, strides 1) whose type drops the trailing
// unit dims.
SmallVector<OpFoldResult> sizes =
memref::getMixedSizes(rewriter, loc, readOp.getSource());
SmallVector<OpFoldResult> offsets(srcType.getRank(),
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(srcType.getRank(),
rewriter.getIndexAttr(1));
auto resultMemrefType =
cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
strides));
ArrayAttr inBoundsAttr =
readOp.getInBounds()
? rewriter.getArrayAttr(
readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
: ArrayAttr();
Value rankedReducedView = rewriter.create<memref::SubViewOp>(
loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides);
auto permMap = getTransferMinorIdentityMap(
cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
// New read: same padding, no mask (Value()), trimmed indices/in_bounds.
Value result = rewriter.create<vector::TransferReadOp>(
loc, resultTargetVecType, rankedReducedView,
readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
readOp.getPadding(),
Value(), inBoundsAttr);
// Restore the original vector shape for existing users.
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
result);
return success();
}
};
/// Rewrites an unmasked `vector.transfer_write` into a memref whose trailing
/// vector dimensions can be folded away (as reported by
/// `getTransferFoldableInnerUnitDims`) into a lower-rank write:
///   * the stored vector is `vector.shape_cast` to the reduced rank,
///   * the destination is replaced by a full-size, rank-reduced
///     `memref.subview` (zero offsets, unit strides),
///   * the trailing indices and `in_bounds` entries are dropped.
///
/// The pattern bails out on 0-d transfers, masked writes, non-memref
/// destinations, non-minor-identity permutation maps, vectors of rank <= 1,
/// when nothing can be dropped, and when any dropped dim is not in-bounds.
class DropInnerMostUnitDimsTransferWrite
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // 0-d transfers and masked writes are not supported.
    if (writeOp.getTransferRank() == 0 || writeOp.getMask())
      return failure();

    auto memrefType = dyn_cast<MemRefType>(writeOp.getSource().getType());
    if (!memrefType || !writeOp.getPermutationMap().isMinorIdentity())
      return failure();

    auto vecType = writeOp.getVectorType();
    if (vecType.getRank() <= 1)
      return failure();

    FailureOr<size_t> foldable =
        getTransferFoldableInnerUnitDims(memrefType, vecType);
    if (failed(foldable) || foldable.value() == 0)
      return failure();
    size_t numDropped = foldable.value();

    // Every dropped dimension has to be in-bounds; otherwise removing it
    // would lose the out-of-bounds handling of the original write.
    auto inBoundsFlags = writeOp.getInBoundsValues();
    if (llvm::is_contained(
            ArrayRef<bool>(inBoundsFlags).take_back(numDropped), false))
      return failure();

    auto reducedVecType =
        VectorType::get(vecType.getShape().drop_back(numDropped),
                        vecType.getElementType(),
                        vecType.getScalableDims().drop_back(numDropped));

    Location loc = writeOp.getLoc();
    // Full-size view of the destination: original sizes, zero offsets and
    // unit strides, with the trailing dims rank-reduced away.
    SmallVector<OpFoldResult> sizes =
        memref::getMixedSizes(rewriter, loc, writeOp.getSource());
    SmallVector<OpFoldResult> offsets(memrefType.getRank(),
                                      rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(memrefType.getRank(),
                                      rewriter.getIndexAttr(1));
    auto reducedMemrefType =
        cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
            memrefType.getShape().drop_back(numDropped), memrefType, offsets,
            sizes, strides));

    // Trim the in_bounds attribute too, but only when it is present.
    ArrayAttr newInBounds;
    if (writeOp.getInBounds())
      newInBounds = rewriter.getArrayAttr(
          writeOp.getInBoundsAttr().getValue().drop_back(numDropped));

    Value reducedView = rewriter.create<memref::SubViewOp>(
        loc, reducedMemrefType, writeOp.getSource(), offsets, sizes, strides);
    auto minorIdentity = getTransferMinorIdentityMap(
        cast<ShapedType>(reducedView.getType()), reducedVecType);

    Value reducedVector = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVecType, writeOp.getVector());
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        writeOp, reducedVector, reducedView,
        writeOp.getIndices().drop_back(numDropped),
        AffineMapAttr::get(minorIdentity), /*mask=*/Value(), newInBounds);
    return success();
  }
};
/// Canonicalizes a matmul-like `vector.contract` into the "MMT" form.
///
/// The op must have iterator types [parallel, parallel, reduction]. The
/// target indexing maps are
///   lhs: (m, k), rhs: (n, k), acc/result: (m, n)
/// i.e. A row-major, B column-major, C row-major. The other supported layouts
/// are reached by transposing and/or swapping the LHS and RHS operands. A
/// result laid out as (n, m) is handled by swapping the operands, which
/// relabels m <-> n.
///
/// If an operand that needs transposing is produced by `arith.extsi` or
/// `arith.extui`, the transpose is applied to the extension's input. The
/// extension is then re-created on the transposed value, so it remains the op
/// that directly feeds the contraction.
///
/// The pattern only applies when the user-provided `filter` succeeds.
struct CanonicalizeContractMatmulToMMT final
    : OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
                                  FilterConstraintType constraint)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(filter(op)))
      return failure();

    Location loc = op.getLoc();
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();
    Value res = op.getAcc();

    // Builds a list of indexing maps from (m, n, k) dimension expressions.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [&](MapList m) {
      return AffineMap::inferFromExprList(m, op.getContext());
    };
    AffineExpr m;
    AffineExpr n;
    AffineExpr k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
    if (iteratorTypes.size() != 3 ||
        !vector::isParallelIterator(iteratorTypes[0]) ||
        !vector::isParallelIterator(iteratorTypes[1]) ||
        !vector::isReductionIterator(iteratorTypes[2]))
      return rewriter.notifyMatchFailure(op, "contraction is not a gemm");

    // Target "MMT" form: (m, k) x (n, k) -> (m, n).
    const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
    if (maps == canonicalForm)
      return rewriter.notifyMatchFailure(op, "already in the canonical form");

    // Transposes a 2-D operand. When the operand is a sign/zero extension,
    // the narrow input is transposed and the extension is re-applied. The
    // extended type is built with `clone`, which keeps the transposed shape
    // including its scalable-dim flags and only swaps in the wide element
    // type. `VectorType::get(shape, elemType)` would instead mark every dim
    // as fixed-size.
    auto createTranspose = [&rewriter, loc](Value mat) -> Value {
      if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
        Value trans =
            rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
        VectorType newType =
            cast<VectorType>(trans.getType())
                .clone(cast<VectorType>(mat.getType()).getElementType());
        return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
      }
      if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
        Value trans =
            rewriter.create<vector::TransposeOp>(loc, zext.getIn(), perm);
        VectorType newType =
            cast<VectorType>(trans.getType())
                .clone(cast<VectorType>(mat.getType()).getElementType());
        return rewriter.create<arith::ExtUIOp>(loc, newType, trans);
      }
      return rewriter.create<vector::TransposeOp>(loc, mat, perm);
    };

    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      // B is row-major: transpose it.
      rhs = createTranspose(rhs);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      // A is column-major: transpose it.
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      // Transposed result: swap operands, then fix up both layouts.
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return rewriter.notifyMatchFailure(op, "unhandled contraction form");
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
        op.getIteratorTypes());
    return success();
  }

private:
  /// User-provided predicate; the pattern only fires when it succeeds.
  FilterConstraintType filter;
};
/// Folds an extension op of kind `ExtOp` into a `vector.contract` when the
/// extension feeds *both* the LHS and RHS. The contraction is rebuilt on the
/// extensions' inputs and keeps the original accumulator, indexing maps and
/// iterator types.
template <typename ExtOp>
struct FoldArithExtIntoContractionOp
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto lhsExt = contractOp.getLhs().getDefiningOp<ExtOp>();
    auto rhsExt = contractOp.getRhs().getDefiningOp<ExtOp>();
    // Both operands must be produced by the same kind of extension.
    if (!(lhsExt && rhsExt))
      return rewriter.notifyMatchFailure(contractOp,
                                         "no defining op on contract operands");

    auto narrowLhs = lhsExt->getOperand(0);
    auto narrowRhs = rhsExt->getOperand(0);
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, narrowLhs, narrowRhs, contractOp.getAcc(),
        contractOp.getIndexingMapsAttr(), contractOp.getIteratorTypesAttr());
    return success();
  }
};
/// Folds two chained additive `vector.reduction`s, where the first one's
/// result is the second one's accumulator, into a single reduction:
///
///   %a = vector.reduction <add> %x, %acc
///   %b = vector.reduction <add> %y, %a
/// ==>
///   %s = arith.add{i,f} %x, %y
///   %b = vector.reduction <add> %s, %acc
///
/// Requirements:
///   * both reductions use the `add` combining kind,
///   * both reduce vectors of the same type (needed by the element-wise add),
///   * the outer reduction is not masked.
struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: Handle other combining kinds.
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();

    // A masked reduction only combines the active lanes of its own vector.
    // Folding another vector into it would apply that mask to the other
    // vector's lanes as well.
    if (op.isMasked())
      return failure();

    // The accumulator is optional.
    Value acc = op.getAcc();
    if (!acc)
      return failure();
    if (!acc.getType().isIntOrFloat())
      return failure();

    auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
    if (!parentReduction)
      return failure();

    // The parent must be additive too. For example, sum(y) + prod(x) * acc is
    // not sum(x + y) + acc.
    if (parentReduction.getKind() != vector::CombiningKind::ADD)
      return failure();

    // arith.addi/addf require identical operand types.
    if (parentReduction.getVector().getType() != op.getVector().getType())
      return failure();

    Location loc = op.getLoc();
    Value vAdd;
    if (isa<IntegerType>(acc.getType())) {
      vAdd = rewriter.createOrFold<arith::AddIOp>(
          loc, parentReduction.getVector(), op.getVector());
    } else {
      vAdd = rewriter.create<arith::AddFOp>(loc, parentReduction.getVector(),
                                            op.getVector());
    }
    rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
                                                     parentReduction.getAcc());
    return success();
  }
};
/// Returns `inVecTy` with every fixed-size (non-scalable) unit dimension
/// removed. Scalable dimensions, including `[1]`, are always kept. If every
/// dimension is removed, the result is `vector<1xELT>`.
static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
  SmallVector<int64_t> keptSizes;
  SmallVector<bool> keptScalableFlags;
  for (auto [size, scalable] :
       llvm::zip_equal(inVecTy.getShape(), inVecTy.getScalableDims())) {
    bool isFixedUnitDim = size == 1 && !scalable;
    if (!isFixedUnitDim) {
      keptSizes.push_back(size);
      keptScalableFlags.push_back(scalable);
    }
  }

  // Nothing survived: keep a single fixed-size unit dim.
  if (keptSizes.empty()) {
    keptSizes.push_back(1);
    keptScalableFlags.push_back(false);
  }

  return VectorType::get(keptSizes, inVecTy.getElementType(),
                         keptScalableFlags);
}
/// Drops fixed-size unit dimensions from a single-result, region-free
/// elementwise op whose first operand is a vector of rank >= 2:
///
///   %r = arith.addf %a, %b : vector<1x4x1xf32>
/// ==>
///   %a2 = vector.shape_cast %a : vector<1x4x1xf32> to vector<4xf32>
///   %b2 = vector.shape_cast %b : vector<1x4x1xf32> to vector<4xf32>
///   %r2 = arith.addf %a2, %b2 : vector<4xf32>
///   %r  = vector.shape_cast %r2 : vector<4xf32> to vector<1x4x1xf32>
///
/// All operands are validated and their reduced types computed before any IR
/// is created. This way, a non-vector operand or an operand with no unit dim
/// to remove makes the pattern fail without leaving stray `shape_cast`s
/// behind.
struct DropUnitDimFromElementwiseOps final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1 || op->getNumRegions() != 0)
      return failure();

    auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
    if (!resultVectorType)
      return failure();

    auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
    if (!sourceVectorType)
      return failure();
    if (sourceVectorType.getRank() < 2)
      return failure();

    // First pass: check every operand and compute its reduced type without
    // modifying the IR.
    SmallVector<VectorType> newOperandTypes;
    for (auto operand : op->getOperands()) {
      auto opVectorType = dyn_cast<VectorType>(operand.getType());
      if (!opVectorType)
        return rewriter.notifyMatchFailure(op, "expected only vector operands");
      auto newVType = dropNonScalableUnitDimFromType(opVectorType);
      if (newVType == opVectorType)
        return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
      newOperandTypes.push_back(newVType);
    }

    // Second pass: materialize the rank-reducing shape_casts.
    SmallVector<Value> newOperands;
    auto loc = op->getLoc();
    for (unsigned i = 0, e = op->getNumOperands(); i < e; ++i)
      newOperands.push_back(rewriter.create<vector::ShapeCastOp>(
          loc, newOperandTypes[i], op->getOperand(i)));

    VectorType newResultVectorType =
        dropNonScalableUnitDimFromType(resultVectorType);
    // Re-create the same op (name + attributes) on the reduced operands.
    Operation *elementwiseOp =
        rewriter.create(loc, op->getName().getIdentifier(), newOperands,
                        newResultVectorType, op->getAttrs());
    // Restore the original result shape.
    rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
                                             elementwiseOp->getResult(0));
    return success();
  }
};
/// Removes an addition of a zero constant that feeds an additive float
/// `vector.reduction` through another `arith.addf`:
///
///   %a = arith.addf %x, %zero
///   %b = arith.addf %a, %y
///   %c = vector.reduction <add> %b, %acc
/// ==>
///   %b2 = arith.addf %x, %y
///   %c  = vector.reduction <add> %b2, %acc
///
/// `%zero` may be any float zero (matched with `m_AnyZeroFloat`) and must be
/// the RHS of the inner add.
struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    // Only additive reductions over float elements are handled.
    bool isFloatAddReduction =
        op.getKind() == vector::CombiningKind::ADD &&
        isa<FloatType>(op.getSourceVectorType().getElementType());
    if (!isFloatAddReduction)
      return failure();

    auto outerAdd = op.getVector().getDefiningOp<arith::AddFOp>();
    if (!outerAdd)
      return failure();
    auto innerAdd = outerAdd.getLhs().getDefiningOp<arith::AddFOp>();
    if (!innerAdd || !matchPattern(innerAdd.getRhs(), m_AnyZeroFloat()))
      return failure();

    auto fusedAdd = rewriter.create<arith::AddFOp>(
        outerAdd.getLoc(), innerAdd.getLhs(), outerAdd.getRhs());
    rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(),
                                                     fusedAdd, op.getAcc());
    return success();
  }
};
/// Unrolls an unmasked, fixed-size 1-D `vector.reduction` into one
/// `vector.extract` per element, folded left-to-right with
/// `vector::makeArithReduction`:
///
///   %r = vector.reduction <add> %v, %acc : vector<3xi32> into i32
/// ==>
///   %e0, %e1, %e2 = vector.extract %v[0], %v[1], %v[2]
///   %r = ((%e0 + %e1) + %e2) + %acc
///
/// Vectors with more than `maxNumElementsToExtract` elements are left alone.
struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
  BreakDownVectorReduction(MLIRContext *context,
                           unsigned maxNumElementsToExtract,
                           PatternBenefit benefit)
      : OpRewritePattern(context, benefit),
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    VectorType srcVecType = op.getSourceVectorType();
    // Scalable or masked reductions are not handled.
    if (srcVecType.isScalable() || op.isMasked())
      return failure();
    assert(srcVecType.getRank() == 1 && "Expected a 1-d vector");

    int64_t numElems = srcVecType.getNumElements();
    if (numElems > maxNumElementsToExtract)
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("has too many vector elements ({0}) to break down "
                            "(max allowed: {1})",
                            numElems, maxNumElementsToExtract));

    Location loc = op.getLoc();
    Value vec = op.getVector();
    SmallVector<Value> elements;
    elements.reserve(numElems);
    for (int64_t pos = 0; pos < numElems; ++pos)
      elements.push_back(rewriter.create<vector::ExtractOp>(loc, vec, pos));

    // Combine the extracted elements left-to-right, then fold in the
    // accumulator when there is one.
    Value res = elements[0];
    for (int64_t pos = 1; pos < numElems; ++pos)
      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
                                       elements[pos], op.getFastmathAttr());
    if (Value acc = op.getAcc())
      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
                                       op.getFastmathAttr());

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  /// Largest element count this pattern is allowed to unroll.
  unsigned maxNumElementsToExtract = 0;
};
/// Folds `mul(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A, B)`.
/// Either operand order of the mul is accepted. A and B must be 1-D vectors,
/// each broadcast to the 2-D result without stretching a unit dimension:
///
///   %ab = vector.broadcast %a : vector<M> to vector<NxM>
///   %at = vector.transpose %ab, [1, 0] : vector<NxM> to vector<MxN>
///   %bb = vector.broadcast %b : vector<N> to vector<MxN>
///   %r  = arith.mul{f,i} %at, %bb : vector<MxN>
/// ==>
///   %r  = vector.outerproduct %a, %b : vector<M>, vector<N>
///
/// Scalar multiplications (the mul ops are also defined on scalars) and
/// non-2-D results are ignored.
template <typename MulOpType>
struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
  using OpRewritePattern<MulOpType>::OpRewritePattern;

  /// Returns true if `broadcastOp` is a plain 1-D -> 2-D broadcast, i.e. its
  /// source is a rank-1 vector and no unit dimension gets stretched.
  bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
    if (!broadcastOp.computeBroadcastedUnitDims().empty())
      return false;
    // Reject scalar (`f32`) and 0-D (`vector<f32>`) sources as well as rank-2
    // sources: outer-product operands must be 1-D vectors.
    auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    return srcType && srcType.getRank() == 1;
  }

  LogicalResult matchAndRewrite(MulOpType mulOp,
                                PatternRewriter &rewriter) const override {
    // The mul may operate on scalars; only 2-D vector results qualify.
    auto resType = llvm::dyn_cast<VectorType>(mulOp.getResult().getType());
    if (!resType || resType.getRank() != 2)
      return failure();

    /// If operandA is tr(broadcast(A)) and operandB is broadcast(B), with
    /// both broadcasts 1-D -> 2-D, builds vector.outerproduct(A, B).
    /// Returns failure() otherwise.
    auto matchOuterProduct =
        [&](Value operandA,
            Value operandB) -> FailureOr<vector::OuterProductOp> {
      auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>();
      if (!transposedLhs)
        return failure();
      // Must be a true 2-D matrix transpose.
      ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
      if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0)
        return failure();

      auto broadcastedLhs =
          transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
      if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
        return failure();

      auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
      if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
        return failure();

      return rewriter.create<vector::OuterProductOp>(
          mulOp->getLoc(), resType, broadcastedLhs.getSource(),
          broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD);
    };

    Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
    auto maybeOuterP = matchOuterProduct(lhs, rhs);
    // The mul is commutative: the transposed operand may be on either side.
    if (failed(maybeOuterP))
      maybeOuterP = matchOuterProduct(rhs, lhs);
    if (failed(maybeOuterP))
      return failure();
    rewriter.replaceOp(mulOp, maybeOuterP->getResult());
    return success();
  }
};
}
/// Adds the patterns that fold `arith.extf` and `arith.extsi` operands into
/// `vector.contract`.
void mlir::vector::populateFoldArithExtensionPatterns(
    RewritePatternSet &patterns) {
  auto *ctx = patterns.getContext();
  patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>>(ctx);
  patterns.add<FoldArithExtIntoContractionOp<arith::ExtSIOp>>(ctx);
}
/// Adds the mask-materialization patterns. The create_mask conversion and the
/// transfer-mask materializations receive `force32BitVectorIndices`;
/// `FoldI1Select` is added with only the benefit.
void mlir::vector::populateVectorMaskMaterializationPatterns(
    RewritePatternSet &patterns, bool force32BitVectorIndices,
    PatternBenefit benefit) {
  auto *ctx = patterns.getContext();
  patterns.add<VectorCreateMaskOpConversion,
               MaterializeTransferMask<vector::TransferReadOp>,
               MaterializeTransferMask<vector::TransferWriteOp>>(
      ctx, force32BitVectorIndices, benefit);
  patterns.add<FoldI1Select>(ctx, benefit);
}
/// Adds the `ShapeCastOpFolder` pattern.
void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
                                                    PatternBenefit benefit) {
  auto *ctx = patterns.getContext();
  patterns.add<ShapeCastOpFolder>(ctx, benefit);
}
/// Adds the pattern that drops unit dims from elementwise ops via
/// `vector.shape_cast`, together with `ShapeCastOpFolder`.
void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  auto *ctx = patterns.getContext();
  patterns.add<DropUnitDimFromElementwiseOps>(ctx, benefit);
  patterns.add<ShapeCastOpFolder>(ctx, benefit);
}
/// Adds the patterns that bubble `vector.bitcast` across extract-like and
/// insert-like ops.
void mlir::vector::populateBubbleVectorBitCastOpPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  auto *ctx = patterns.getContext();
  // Extract-side patterns.
  patterns.add<BubbleDownVectorBitCastForExtract,
               BubbleDownBitCastForStridedSliceExtract>(ctx, benefit);
  // Insert-side patterns.
  patterns.add<BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
      ctx, benefit);
}
/// Adds `BreakDownVectorBitCast` and forwards `controlFn` to it.
void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
    RewritePatternSet &patterns,
    std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
  auto *ctx = patterns.getContext();
  patterns.add<BreakDownVectorBitCast>(ctx, std::move(controlFn), benefit);
}
/// Adds `CanonicalizeContractMatmulToMMT` and forwards the user `constraint`
/// as its filter.
void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
    RewritePatternSet &patterns,
    std::function<LogicalResult(vector::ContractionOp)> constraint,
    PatternBenefit benefit) {
  auto *ctx = patterns.getContext();
  patterns.add<CanonicalizeContractMatmulToMMT>(ctx, benefit,
                                                std::move(constraint));
}
/// Adds the reduction-to-contract patterns, the contract-combining patterns,
/// and the cast/elementwise reordering patterns that go with them.
void mlir::vector::populateVectorReductionToContractPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  auto *ctx = patterns.getContext();
  patterns.add<MultiReduceToContract, CombineContractBroadcast,
               CombineContractABTranspose, CombineContractResultTranspose>(
      ctx, benefit);
  patterns.add<ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnTranspose>(
      ctx, benefit);
}
/// Adds the patterns that drop innermost unit dims from
/// `vector.transfer_read` and `vector.transfer_write`.
void mlir::vector::
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
        RewritePatternSet &patterns, PatternBenefit benefit) {
  auto *ctx = patterns.getContext();
  patterns.add<DropInnerMostUnitDimsTransferRead>(ctx, benefit);
  patterns.add<DropInnerMostUnitDimsTransferWrite>(ctx, benefit);
}
/// Adds the patterns that reorder cast and elementwise ops across
/// `vector.broadcast`.
void mlir::vector::populateSinkVectorBroadcastPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  auto *ctx = patterns.getContext();
  patterns.add<ReorderCastOpsOnBroadcast>(ctx, benefit);
  patterns.add<ReorderElementwiseOpsOnBroadcast>(ctx, benefit);
}
/// Adds `ChainedReduction` with `benefit`, and `ReduceRedundantZero` with a
/// benefit one higher, so the latter is preferred when both match.
void mlir::vector::populateChainedVectorReductionFoldingPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  auto *ctx = patterns.getContext();
  PatternBenefit zeroFoldBenefit(benefit.getBenefit() + 1);
  patterns.add<ChainedReduction>(ctx, benefit);
  patterns.add<ReduceRedundantZero>(ctx, zeroFoldBenefit);
}
/// Adds `BreakDownVectorReduction`, which unrolls reductions over at most
/// `maxNumElementsToExtract` elements.
void mlir::vector::populateBreakDownVectorReductionPatterns(
    RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
    PatternBenefit benefit) {
  auto *ctx = patterns.getContext();
  patterns.add<BreakDownVectorReduction>(ctx, maxNumElementsToExtract,
                                         benefit);
}
/// Adds the patterns that fold `arith.mulf` / `arith.muli` of broadcasts into
/// `vector.outerproduct`.
void mlir::vector::populateElementwiseToVectorOpsPatterns(
    RewritePatternSet &patterns) {
  auto *ctx = patterns.getContext();
  patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>>(ctx);
  patterns.add<FoldArithToVectorOuterProduct<arith::MulIOp>>(ctx);
}
#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"