#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include <type_traits>
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "vector-to-vector"
using namespace mlir;
using namespace mlir::vector;
/// Returns the position of the first result of `map` whose dimension is
/// `index`, or None when no result of `map` refers to that dimension.
/// Every result of `map` is expected to be a plain dimension expression.
static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
  const int64_t numResults = map.getNumResults();
  for (int64_t resultPos = 0; resultPos < numResults; ++resultPos) {
    const int64_t dimPos = map.getDimPosition(resultPos);
    if (dimPos == index)
      return resultPos;
  }
  return None;
}
/// Returns the entries of `iteratorTypes` in order, skipping the one at
/// position `index`.
static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
                                            int64_t index) {
  SmallVector<Attribute, 4> kept;
  int64_t pos = 0;
  for (Attribute iterType : iteratorTypes) {
    if (pos++ != index)
      kept.push_back(iterType);
  }
  return kept;
}
/// Projects dimension `index` out of `map`. Results that use dimension
/// `index` are dropped, and dimensions numbered above `index` are renumbered
/// down by one. The new map has one fewer dimension and no symbols. Every
/// result of `map` is expected to be a plain dimension expression.
static AffineMap adjustMap(AffineMap map, int64_t index,
                           PatternRewriter &rewriter) {
  MLIRContext *context = rewriter.getContext();
  SmallVector<AffineExpr, 4> newResults;
  for (unsigned r = 0, numResults = map.getNumResults(); r < numResults; ++r) {
    const int64_t dim = map.getDimPosition(r);
    if (dim == index)
      continue;
    const int64_t shifted = dim > index ? dim - 1 : dim;
    newResults.push_back(getAffineDimExpr(shifted, context));
  }
  return AffineMap::get(map.getNumDims() - 1, /*symbolCount=*/0, newResults,
                        context);
}
/// Extracts the slice of `val` (a vector of type `type`) at position `pos`
/// along dimension `index`. The returned value has type `type` with dimension
/// `index` dropped.
///
/// An `index` of -1 returns `val` unchanged. For `index > 0` the leading
/// dimension is unrolled: each row is extracted, reshaped recursively with
/// `index - 1`, and inserted into a zero-initialized result.
static Value reshapeLoad(Location loc, Value val, VectorType type,
int64_t index, int64_t pos,
PatternRewriter &rewriter) {
if (index == -1)
return val;
Type lowType = VectorType::Builder(type).dropDim(0);
// At the extraction dimension: one extract along dimension 0 suffices.
if (index == 0) {
auto posAttr = rewriter.getI64ArrayAttr(pos);
return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
}
// Otherwise the leading dimension survives in the result; unroll it.
VectorType vType = lowType.cast<VectorType>();
Type resType = VectorType::Builder(type).dropDim(index);
auto resVectorType = resType.cast<VectorType>();
Value result = rewriter.create<arith::ConstantOp>(
loc, resVectorType, rewriter.getZeroAttr(resVectorType));
for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
auto posAttr = rewriter.getI64ArrayAttr(d);
Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
// Row `d` has one fewer leading dimension, so the target dim shifts down.
Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
posAttr);
}
return result;
}
/// Inserts `val` into `result` (a vector of type `type`) at position `pos`
/// along dimension `index`. `val` has type `type` with dimension `index`
/// dropped.
///
/// An `index` of -1 returns `val` itself and ignores `result`. For
/// `index > 0` the leading dimension is unrolled: row `d` of `result` and
/// row `d` of `val` are combined recursively with `index - 1`, and the
/// updated row is written back into `result`.
static Value reshapeStore(Location loc, Value val, Value result,
VectorType type, int64_t index, int64_t pos,
PatternRewriter &rewriter) {
if (index == -1)
return val;
// At the insertion dimension: one insert along dimension 0 suffices.
if (index == 0) {
auto posAttr = rewriter.getI64ArrayAttr(pos);
return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
}
// vType: one row of `result`; insType: one row of `val`.
Type lowType = VectorType::Builder(type).dropDim(0);
VectorType vType = lowType.cast<VectorType>();
Type insType = VectorType::Builder(vType).dropDim(0);
for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
auto posAttr = rewriter.getI64ArrayAttr(d);
Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
// Within a row the target dimension is one position closer to the front.
Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
}
return result;
}
/// Converts an ArrayAttr whose elements are IntegerAttrs into a vector of
/// `IntType`, applying a static_cast to each element's integer value.
template <typename IntType>
static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
  SmallVector<IntType, 4> values;
  values.reserve(arrayAttr.size());
  for (IntegerAttr attr : arrayAttr.getAsRange<IntegerAttr>())
    values.push_back(static_cast<IntType>(attr.getInt()));
  return values;
}
/// Multiplies `x` by `y` and, when `acc` is non-null, combines the product
/// with `acc` using `kind`.
///
/// `isInt` selects `arith.muli` (true) or `arith.mulf` (false). Returns None
/// when `kind` does not apply to that element class:
///   * MINF/MAXF are rejected on integers;
///   * AND/OR/XOR and the signed/unsigned min/max kinds are rejected on floats.
/// A floating-point ADD into a vector accumulator is emitted as one
/// `vector.fma`.
static Optional<Value> createContractArithOp(Location loc, Value x, Value y,
                                             Value acc,
                                             vector::CombiningKind kind,
                                             PatternRewriter &rewriter,
                                             bool isInt) {
  using vector::CombiningKind;
  const bool floatOnlyKind =
      kind == CombiningKind::MINF || kind == CombiningKind::MAXF;
  const bool intOnlyKind =
      kind == CombiningKind::AND || kind == CombiningKind::OR ||
      kind == CombiningKind::XOR || kind == CombiningKind::MINUI ||
      kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
      kind == CombiningKind::MAXSI;
  if (isInt ? floatOnlyKind : intOnlyKind)
    return llvm::None;

  Value product;
  if (isInt) {
    product = rewriter.create<arith::MulIOp>(loc, x, y);
  } else {
    // Float multiply-accumulate into a vector maps onto a fused FMA.
    const bool fusable = kind == CombiningKind::ADD && acc &&
                         acc.getType().isa<VectorType>();
    if (fusable)
      return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
    product = rewriter.create<arith::MulFOp>(loc, x, y);
  }
  if (!acc)
    return Optional<Value>(product);
  return makeArithReduction(rewriter, loc, kind, product, acc);
}
/// Returns the positions of the results of `map` whose dimension is a
/// reduction iterator in `iteratorTypes`, in increasing order.
static SmallVector<int64_t> getReductionIndex(AffineMap map,
                                              ArrayAttr iteratorTypes) {
  SmallVector<int64_t> reductionResults;
  const unsigned numResults = map.getNumResults();
  for (unsigned resultPos = 0; resultPos != numResults; ++resultPos) {
    Attribute iterType = iteratorTypes[map.getDimPosition(resultPos)];
    if (isReductionIterator(iterType))
      reductionResults.push_back(resultPos);
  }
  return reductionResults;
}
/// Returns the position of the first result of `map` that is dimension `dim`,
/// or None if `dim` does not appear among the results.
static llvm::Optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
  const unsigned numResults = map.getNumResults();
  unsigned pos = 0;
  while (pos < numResults && map.getDimPosition(pos) != dim)
    ++pos;
  if (pos == numResults)
    return llvm::None;
  return pos;
}
namespace {
/// Folds `shape_cast(shape_cast(x))` back to `x` when the outer cast
/// exactly undoes the inner one, i.e. the inner cast's source type equals
/// the outer cast's result type and vice versa.
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    Value castInput = shapeCastOp.getSource();
    auto outerSrcType = castInput.getType().dyn_cast_or_null<VectorType>();
    auto outerDstType =
        shapeCastOp.getResult().getType().dyn_cast_or_null<VectorType>();
    if (!outerSrcType || !outerDstType)
      return failure();

    auto producer = castInput.getDefiningOp<vector::ShapeCastOp>();
    if (!producer)
      return failure();

    Value original = producer.getSource();
    auto innerSrcType = original.getType().cast<VectorType>();
    auto innerDstType = producer.getType();
    const bool undoesProducer =
        innerSrcType == outerDstType && innerDstType == outerSrcType;
    if (!undoesProducer)
      return failure();

    rewriter.replaceOp(shapeCastOp, original);
    return success();
  }
};
/// Progressive lowering of vector.broadcast. Each application handles at most
/// one leading dimension:
///   * scalar source -> vector.splat;
///   * single-element source (0-D or vector<1xT>) into a 1-D result ->
///     extract that element, then splat;
///   * lower-rank source -> broadcast to the trailing (n-1)-D shape, then
///     insert that value into every leading row of the result;
///   * equal-rank source -> forwarded unchanged when every dimension matches;
///     otherwise stretched along the first mismatching dimension via per-row
///     (n-1)-D broadcasts.
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
public:
  using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    VectorType dstType = op.getVectorType();
    VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
    Type eltType = dstType.getElementType();

    // Scalar to any vector can use splat.
    if (!srcType) {
      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
      return success();
    }

    int64_t srcRank = srcType.getRank();
    int64_t dstRank = dstType.getRank();

    // Stretching a single element (0-D vector or vector<1xT>) into a 1-D
    // vector can use splat. A 1-D source with more than one element can only
    // be an identity broadcast (same shape). Splatting its element 0 would be
    // wrong, so that case falls through to the equal-rank pass-through below.
    bool singleElementSrc =
        srcRank == 0 || (srcRank == 1 && srcType.getDimSize(0) == 1);
    if (dstRank == 1 && singleElementSrc) {
      Value ext;
      if (srcRank == 0)
        ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
      else
        ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
      return success();
    }

    // Duplicate the leading rank:
    //   %b = broadcast %y : k-D to (n-1)-D
    //   %x = [%b, %b, ..., %b] : n-D
    if (srcRank < dstRank) {
      VectorType resType =
          VectorType::get(dstType.getShape().drop_front(), eltType);
      Value bcst =
          rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
      Value result = rewriter.create<arith::ConstantOp>(
          loc, dstType, rewriter.getZeroAttr(dstType));
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      rewriter.replaceOp(op, result);
      return success();
    }

    // Equal ranks: find the first dimension that must be stretched, if any.
    assert(srcRank == dstRank);
    int64_t m = -1;
    for (int64_t r = 0; r < dstRank; r++)
      if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
        m = r;
        break;
      }

    // All dimensions match: the broadcast is the identity.
    if (m == -1) {
      rewriter.replaceOp(op, op.getSource());
      return success();
    }

    VectorType resType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    if (m == 0) {
      // Stretch at the leading dimension: broadcast the single source row
      // once and reuse it for every result row.
      Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
      Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
    } else {
      // Stretch in an inner dimension: broadcast each source row separately.
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
        Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
        Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
/// Appends `transpose` to `result`, leaving out the maximal trailing run of
/// positions that map to themselves (i.e. `transpose[i] == i`). Those
/// trailing dimensions are not moved by the permutation.
void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
                            SmallVectorImpl<int64_t> &result) {
  size_t keep = transpose.size();
  while (keep > 0 && static_cast<size_t>(transpose[keep - 1]) == keep - 1)
    --keep;
  result.append(transpose.begin(), transpose.begin() + keep);
}
/// Progressive lowering of vector.transpose into extract/insert chains.
///
/// The pattern behaves as follows, depending on the configured
/// `vectorTransposeLowering`:
///   * Shuffle: true 2-D transposes ([1, 0]) are declined here.
///   * Flat: true 2-D transposes are lowered via shape_cast +
///     vector.flat_transpose.
/// Everything else is unrolled. Trailing dimensions that the permutation
/// leaves in place are kept as vectors; every combination of the remaining
/// leading indices gets one extract/insert pair.
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context)
: OpRewritePattern<vector::TransposeOp>(context),
vectorTransformOptions(vectorTransformOptions) {}
LogicalResult matchAndRewrite(vector::TransposeOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
Value input = op.getVector();
VectorType inputType = op.getVectorType();
VectorType resType = op.getResultType();
// Permutation as plain integers.
SmallVector<int64_t, 4> transp;
for (auto attr : op.getTransp())
transp.push_back(attr.cast<IntegerAttr>().getInt());
// Leave true 2-D transposes for the shuffle-based pattern when requested.
if (vectorTransformOptions.vectorTransposeLowering ==
vector::VectorTransposeLowering::Shuffle &&
resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
return rewriter.notifyMatchFailure(
op, "Options specifies lowering to shuffle");
// Flat lowering of a true 2-D transpose through a 1-D flat_transpose.
if (vectorTransformOptions.vectorTransposeLowering ==
vector::VectorTransposeLowering::Flat &&
resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
Type flattenedType =
VectorType::get(resType.getNumElements(), resType.getElementType());
auto matrix =
rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
// NOTE(review): rows/columns are taken from the *result* shape; this
// appears to rely on flat_transpose's column-major matrix convention.
// Confirm against the vector.flat_transpose documentation.
auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
Value trans = rewriter.create<vector::FlatTransposeOp>(
loc, flattenedType, matrix, rows, columns);
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
return success();
}
// Drop trailing identity dimensions so they stay in vector form instead
// of being unrolled element by element.
SmallVector<int64_t, 4> prunedTransp;
pruneNonTransposedDims(transp, prunedTransp);
size_t numPrunedDims = transp.size() - prunedTransp.size();
auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
SmallVector<int64_t, 4> ones(prunedInShape.size(), 1);
auto prunedInStrides = computeStrides(prunedInShape, ones);
Value result = rewriter.create<arith::ConstantOp>(
loc, resType, rewriter.getZeroAttr(resType));
int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
// Visit every leading index via a linear counter: delinearize it into the
// source position, permute it into the destination position, then move
// the (possibly vector-valued) element across.
for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
++linearIdx) {
auto extractIdxs = delinearize(prunedInStrides, linearIdx);
SmallVector<int64_t, 4> insertIdxs(extractIdxs);
applyPermutationToVector(insertIdxs, prunedTransp);
Value extractOp =
rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
result =
rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
}
rewriter.replaceOp(op, result);
return success();
}
private:
vector::VectorTransformsOptions vectorTransformOptions;
};
/// Lowers a true 2-D vector.transpose (permutation [1, 0]) of an m x n
/// vector into shape_cast -> vector.shuffle -> shape_cast. The pattern fires
/// only when the options request `VectorTransposeLowering::Shuffle`.
class TransposeOp2DToShuffleLowering
    : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
  TransposeOp2DToShuffleLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context)
      : OpRewritePattern<vector::TransposeOp>(context),
        vectorTransformOptions(vectorTransformOptions) {}
  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    VectorType srcType = op.getVectorType();
    if (srcType.getRank() != 2)
      return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
    SmallVector<int64_t, 4> transp;
    for (auto attr : op.getTransp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());
    // Only [1, 0] swaps the two dimensions. Any other permutation is
    // rejected; previously `&&` was used here, which was only correct
    // because [0, 1] is the sole other rank-2 permutation.
    if (transp[0] != 1 || transp[1] != 0)
      return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
    if (vectorTransformOptions.vectorTransposeLowering !=
        VectorTransposeLowering::Shuffle)
      return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
    Value casted = rewriter.create<vector::ShapeCastOp>(
        loc, VectorType::get({m * n}, srcType.getElementType()),
        op.getVector());
    // Result row j, column i takes the flattened source element at i * n + j.
    SmallVector<int64_t> mask;
    mask.reserve(m * n);
    for (int64_t j = 0; j < n; ++j)
      for (int64_t i = 0; i < m; ++i)
        mask.push_back(i * n + j);
    Value shuffled =
        rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
                                                     shuffled);
    return success();
  }

private:
  vector::VectorTransformsOptions vectorTransformOptions;
};
/// Progressive lowering of vector.outerproduct.
///
///   * Vector RHS: the result is built row by row. Row `d` combines a
///     broadcast of lhs[d] with the RHS, plus row `d` of the accumulator
///     when one is present.
///   * Scalar RHS (AXPY form): the RHS is broadcast to the LHS type and
///     combined elementwise with the LHS (and the accumulator, if any).
///
/// The combining-kind / element-type pair is validated before any IR is
/// created, so the pattern never reports failure after modifying the IR.
class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:
  using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::OuterProductOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    VectorType lhsType = op.getOperandVectorTypeLHS();
    VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
    VectorType resType = op.getVectorType();
    Type eltType = resType.getElementType();
    bool isInt = eltType.isa<IntegerType, IndexType>();
    Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
    vector::CombiningKind kind = op.getKind();

    // createContractArithOp rejects some kind/element-type pairs. Check this
    // up front: once the broadcast/extract ops below exist, returning failure
    // would leave the IR modified by a failed pattern.
    if (!isKindSupported(kind, isInt))
      return rewriter.notifyMatchFailure(
          op, "combining kind is not supported for the element type");

    if (!rhsType) {
      // AXPY: broadcast the scalar RHS to the LHS vector type.
      Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
      Optional<Value> mult = createContractArithOp(loc, op.getLhs(), b, acc,
                                                   kind, rewriter, isInt);
      assert(mult.has_value() && "combining kind was validated above");
      rewriter.replaceOp(op, mult.value());
      return success();
    }

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
      auto pos = rewriter.getI64ArrayAttr(d);
      Value x =
          rewriter.create<vector::ExtractOp>(loc, eltType, op.getLhs(), pos);
      Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
      Value r = nullptr;
      if (acc)
        r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
      Optional<Value> m =
          createContractArithOp(loc, a, op.getRhs(), r, kind, rewriter, isInt);
      assert(m.has_value() && "combining kind was validated above");
      result = rewriter.create<vector::InsertOp>(loc, resType, m.value(),
                                                 result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  /// Returns true when createContractArithOp accepts `kind` for integer/index
  /// (`isInt`) or floating-point elements. Mirrors the kinds it rejects.
  static bool isKindSupported(vector::CombiningKind kind, bool isInt) {
    using vector::CombiningKind;
    if (isInt)
      return kind != CombiningKind::MINF && kind != CombiningKind::MAXF;
    return kind != CombiningKind::AND && kind != CombiningKind::MINUI &&
           kind != CombiningKind::MINSI && kind != CombiningKind::MAXUI &&
           kind != CombiningKind::MAXSI && kind != CombiningKind::OR &&
           kind != CombiningKind::XOR;
  }
};
/// Lowers a vector.contract whose reduction dimensions all have extent 1
/// into elementwise arithmetic. The steps are:
///   1. Broadcast each operand, prepending any accumulator dimensions it
///      does not index.
///   2. Transpose it so its reduction dimensions come first, followed by the
///      accumulator dimensions in accumulator order.
///   3. Peel off the unit reduction dimensions with vector.extract at
///      offset 0.
///   4. Combine the two operands with the accumulator via
///      createContractArithOp.
///
/// The pattern fires only when all of the following hold:
///   * the contraction is unmasked;
///   * `vectorTransformOptions.vectorContractLowering` is `ParallelArith`;
///   * the `constraint` supplied at construction (default: accept all)
///     succeeds;
///   * the combining kind is valid for the element type.
struct ContractOpToElementwise
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }
  ContractOpToElementwise(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context,
      const FilterConstraintType &constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context),
        // Honour the caller-supplied constraint (it defaults to
        // defaultFilter).
        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    if (llvm::size(contractOp.getMasks()) != 0)
      return failure();
    if (failed(filter(contractOp)))
      return failure();
    if (vectorTransformOptions.vectorContractLowering !=
        vector::VectorContractLowering::ParallelArith)
      return failure();

    // Reject kinds createContractArithOp cannot materialize before any IR is
    // built. Its result is dereferenced unconditionally below, and a pattern
    // must not modify the IR and then report failure.
    bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
    if (!isKindSupported(contractOp.getKind(), isInt))
      return failure();

    ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
    ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
    SmallVector<AffineMap, 4> indexingMaps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    AffineMap lhsMap = indexingMaps[0];
    AffineMap rhsMap = indexingMaps[1];
    AffineMap accMap = indexingMaps[2];
    SmallVector<int64_t> lhsReductionDims =
        getReductionIndex(lhsMap, contractOp.getIteratorTypes());
    SmallVector<int64_t> rhsReductionDims =
        getReductionIndex(rhsMap, contractOp.getIteratorTypes());
    // Only unit-extent reduction dimensions can be removed by an extract.
    for (int64_t dim : lhsReductionDims) {
      if (lhsShape[dim] != 1)
        return failure();
    }
    for (int64_t dim : rhsReductionDims) {
      if (rhsShape[dim] != 1)
        return failure();
    }

    unsigned numParallelDims = accMap.getNumResults();
    // Accumulator dimensions missing from an operand are added as new leading
    // dimensions by a broadcast.
    unsigned numLhsDimToBroadcast =
        numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
    unsigned numRhsDimToBroadcast =
        numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
    SmallVector<int64_t> lhsDims;
    SmallVector<int64_t> lhsTranspose;
    SmallVector<int64_t> rhsDims;
    SmallVector<int64_t> rhsTranspose;
    // Reduction dimensions go first in the transposed layout...
    for (int64_t dim : lhsReductionDims)
      lhsTranspose.push_back(numLhsDimToBroadcast + dim);
    for (int64_t dim : rhsReductionDims)
      rhsTranspose.push_back(numRhsDimToBroadcast + dim);
    // ...followed by the accumulator dimensions in accumulator order, taken
    // either from the operand itself or from a newly broadcast leading dim.
    for (unsigned i = 0; i < numParallelDims; i++) {
      llvm::Optional<unsigned> lhsDim =
          getDimPosition(lhsMap, accMap.getDimPosition(i));
      if (lhsDim) {
        lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
      } else {
        lhsDims.push_back(
            contractOp.getResultType().cast<VectorType>().getDimSize(i));
        lhsTranspose.push_back(lhsDims.size() - 1);
      }
      llvm::Optional<unsigned> rhsDim =
          getDimPosition(rhsMap, accMap.getDimPosition(i));
      if (rhsDim) {
        rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
      } else {
        rhsDims.push_back(
            contractOp.getResultType().cast<VectorType>().getDimSize(i));
        rhsTranspose.push_back(rhsDims.size() - 1);
      }
    }

    Value newLhs = contractOp.getLhs();
    Value newRhs = contractOp.getRhs();
    Location loc = contractOp.getLoc();
    if (!lhsDims.empty()) {
      lhsDims.append(lhsShape.begin(), lhsShape.end());
      auto expandedType =
          VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
      newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
    }
    if (!rhsDims.empty()) {
      rhsDims.append(rhsShape.begin(), rhsShape.end());
      auto expandedType =
          VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
      newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
    }
    newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
    newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
    // Drop the leading unit reduction dimensions.
    SmallVector<int64_t, 4> lhsOffsets(lhsReductionDims.size(), 0);
    SmallVector<int64_t, 4> rhsOffsets(rhsReductionDims.size(), 0);
    newLhs = rewriter.create<vector::ExtractOp>(
        loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets));
    newRhs = rewriter.create<vector::ExtractOp>(
        loc, newRhs, rewriter.getI64ArrayAttr(rhsOffsets));
    Optional<Value> result =
        createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
                              contractOp.getKind(), rewriter, isInt);
    assert(result && "combining kind was validated before rewriting");
    rewriter.replaceOp(contractOp, {*result});
    return success();
  }

private:
  /// Returns true when createContractArithOp accepts `kind` for integer/index
  /// (`isInt`) or floating-point elements. Mirrors the kinds it rejects.
  static bool isKindSupported(vector::CombiningKind kind, bool isInt) {
    using vector::CombiningKind;
    if (isInt)
      return kind != CombiningKind::MINF && kind != CombiningKind::MAXF;
    return kind != CombiningKind::AND && kind != CombiningKind::MINUI &&
           kind != CombiningKind::MINSI && kind != CombiningKind::MAXUI &&
           kind != CombiningKind::MAXSI && kind != CombiningKind::OR &&
           kind != CombiningKind::XOR;
  }

  vector::VectorTransformsOptions vectorTransformOptions;
  FilterConstraintType filter;
};
/// Progressive lowering of vector.constant_mask:
///   * 0-D: an i1 constant that is true iff the single mask size equals 1.
///   * scalable: an all-false constant.
///   * 1-D: a bool vector constant whose first `trueDim` lanes are true.
///   * n-D: a (n-1)-D constant_mask over the trailing mask sizes is inserted
///     into the first `trueDim` rows of an all-false constant.
/// `trueDim` is the leading mask size clamped to the leading vector extent.
class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
public:
using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto dstType = op.getType();
auto eltType = dstType.getElementType();
auto dimSizes = op.getMaskDimSizes();
int64_t rank = dstType.getRank();
if (rank == 0) {
assert(dimSizes.size() == 1 &&
"Expected exactly one dim size for a 0-D vector");
bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, dstType,
DenseIntElementsAttr::get(
VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
ArrayRef<bool>{value}));
return success();
}
// NOTE(review): this branch ignores `dimSizes` and always produces an
// all-false mask. Presumably only all-zero scalable masks reach here;
// confirm against the constant_mask verifier.
if (dstType.cast<VectorType>().isScalable()) {
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, DenseElementsAttr::get(dstType, false));
return success();
}
// Number of leading rows (or lanes) that are set.
int64_t trueDim = std::min(dstType.getDimSize(0),
dimSizes[0].cast<IntegerAttr>().getInt());
// 1-D: emit the explicit [T,..,T,F,..,F] constant.
if (rank == 1) {
SmallVector<bool, 4> values(dstType.getDimSize(0));
for (int64_t d = 0; d < trueDim; d++)
values[d] = true;
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, dstType, rewriter.getBoolVectorAttr(values));
return success();
}
// n-D: build the row mask from the trailing sizes and insert it into the
// first `trueDim` rows; the remaining rows stay all-false.
VectorType lowType =
VectorType::get(dstType.getShape().drop_front(), eltType);
SmallVector<int64_t, 4> newDimSizes;
for (int64_t r = 1; r < rank; r++)
newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
Value trueVal = rewriter.create<vector::ConstantMaskOp>(
loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
Value result = rewriter.create<arith::ConstantOp>(
loc, dstType, rewriter.getZeroAttr(dstType));
for (int64_t d = 0; d < trueDim; d++) {
auto pos = rewriter.getI64ArrayAttr(d);
result =
rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
}
rewriter.replaceOp(op, result);
return success();
}
};
/// Lowers an n-D (n >= 2) vector.create_mask by one dimension.
///
/// Row `d` of the result is the (n-1)-D create_mask over the trailing
/// operands when `d < operand(0)` (signed compare), and all-false otherwise.
/// 0-D and 1-D masks are rejected by this pattern.
class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
public:
using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::CreateMaskOp op,
PatternRewriter &rewriter) const override {
auto dstType = op.getResult().getType().cast<VectorType>();
int64_t rank = dstType.getRank();
if (rank <= 1)
return rewriter.notifyMatchFailure(
op, "0-D and 1-D vectors are handled separately");
auto loc = op.getLoc();
auto eltType = dstType.getElementType();
int64_t dim = dstType.getDimSize(0);
// Runtime bound on the leading dimension.
Value idx = op.getOperand(0);
VectorType lowType =
VectorType::get(dstType.getShape().drop_front(), eltType);
// Candidate row values: the trailing mask, or all-false.
Value trueVal = rewriter.create<vector::CreateMaskOp>(
loc, lowType, op.getOperands().drop_front());
Value falseVal = rewriter.create<arith::ConstantOp>(
loc, lowType, rewriter.getZeroAttr(lowType));
Value result = rewriter.create<arith::ConstantOp>(
loc, dstType, rewriter.getZeroAttr(dstType));
for (int64_t d = 0; d < dim; d++) {
// Select row `d` at runtime: trueVal if d < idx, else falseVal.
Value bnd =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
bnd, idx);
Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
auto pos = rewriter.getI64ArrayAttr(d);
result =
rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
}
rewriter.replaceOp(op, result);
return success();
}
};
/// Lowers a rank-2 -> rank-1 vector.shape_cast. Each source row is extracted
/// and written into a zero-initialized 1-D vector with insert_strided_slice
/// at offset `row * rowLength`.
class ShapeCastOp2DDownCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    VectorType srcType = op.getSourceVectorType();
    VectorType dstType = op.getResultVectorType();
    const bool isRank2To1 = srcType.getRank() == 2 && dstType.getRank() == 1;
    if (!isRank2To1)
      return failure();

    Location loc = op.getLoc();
    Value flat = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    const unsigned rowLength = srcType.getShape()[1];
    const int64_t numRows = srcType.getShape().front();
    for (int64_t row = 0; row != numRows; ++row) {
      Value rowVec =
          rewriter.create<vector::ExtractOp>(loc, op.getSource(), row);
      flat = rewriter.create<vector::InsertStridedSliceOp>(
          loc, rowVec, flat, /*offsets=*/row * rowLength, /*strides=*/1);
    }
    rewriter.replaceOp(op, flat);
    return success();
  }
};
/// Lowers a rank-1 -> rank-2 vector.shape_cast. Consecutive
/// `rowLength`-element chunks of the source are taken with
/// extract_strided_slice and inserted as successive rows of a
/// zero-initialized 2-D vector.
class ShapeCastOp2DUpCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    VectorType srcType = op.getSourceVectorType();
    VectorType dstType = op.getResultVectorType();
    const bool isRank1To2 = srcType.getRank() == 1 && dstType.getRank() == 2;
    if (!isRank1To2)
      return failure();

    Location loc = op.getLoc();
    Value nested = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    const unsigned rowLength = dstType.getShape()[1];
    const int64_t numRows = dstType.getShape().front();
    for (int64_t row = 0; row != numRows; ++row) {
      Value rowVec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, op.getSource(), /*offsets=*/row * rowLength,
          /*sizes=*/rowLength, /*strides=*/1);
      nested = rewriter.create<vector::InsertOp>(loc, rowVec, nested, row);
    }
    rewriter.replaceOp(op, nested);
    return success();
  }
};
/// Generic vector.shape_cast lowering that moves every element individually.
///
/// Source and result index vectors are both advanced in row-major order, so
/// the i-th source element (row-major) becomes the i-th result element.
/// Casts between rank 2 and rank 1 are rejected here; the
/// ShapeCastOp2DDownCast/UpCast patterns in this file match those shapes.
/// This emits one extract/insert pair per element.
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto sourceVectorType = op.getSourceVectorType();
auto resultVectorType = op.getResultVectorType();
int64_t srcRank = sourceVectorType.getRank();
int64_t resRank = resultVectorType.getRank();
if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
return failure();
// Total element count (product of the source dimensions).
int64_t numElts = 1;
for (int64_t r = 0; r < srcRank; r++)
numElts *= sourceVectorType.getDimSize(r);
// Multi-indices into source and result, both starting at all zeros.
SmallVector<int64_t, 4> srcIdx(srcRank);
SmallVector<int64_t, 4> resIdx(resRank);
Value result = rewriter.create<arith::ConstantOp>(
loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
for (int64_t i = 0; i < numElts; i++) {
if (i != 0) {
incIdx(srcIdx, sourceVectorType, srcRank - 1);
incIdx(resIdx, resultVectorType, resRank - 1);
}
Value e = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
}
rewriter.replaceOp(op, result);
return success();
}
private:
/// Advances the row-major multi-index `idx` within the shape of `tp`. It
/// increments dimension `r` and carries into outer dimensions on overflow.
/// Callers must not step past the last element (asserts `r >= 0`).
static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
assert(0 <= r && r < tp.getRank());
if (++idx[r] == tp.getDimSize(r)) {
idx[r] = 0;
incIdx(idx, tp, r - 1);
}
}
};
/// Rewrites `multi_reduction<add>(muli/mulf(a, b))` as a vector.contract of
/// `a` and `b` into the reduction's accumulator.
///
/// Both operands use the identity indexing map. The accumulator/result map
/// keeps only the non-reduced dimensions. Reduced dimensions become
/// reduction iterators; all other dimensions become parallel iterators.
struct MultiReduceToContract
: public OpRewritePattern<vector::MultiDimReductionOp> {
using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
PatternRewriter &rewriter) const override {
// Only a sum of products corresponds to a contraction.
if (reduceOp.getKind() != vector::CombiningKind::ADD)
return failure();
Operation *mulOp = reduceOp.getSource().getDefiningOp();
if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
return failure();
SmallVector<bool> reductionMask = reduceOp.getReductionMask();
auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
SmallVector<AffineExpr> exprs;
SmallVector<StringRef> iteratorTypes;
// Non-reduced dims are parallel and appear in the result map.
for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
if (!isReduceDim.value()) {
iteratorTypes.push_back(getParallelIteratorTypeName());
exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
} else {
iteratorTypes.push_back(getReductionIteratorTypeName());
}
}
auto dstMap = AffineMap::get(reductionMask.size(),
0, exprs, reduceOp.getContext());
rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
rewriter.getStrArrayAttr(iteratorTypes));
return success();
}
};
/// Folds vector.transpose ops that feed the LHS and/or RHS of a
/// vector.contract into the contraction's indexing maps. The new contraction
/// reads the un-transposed vectors directly.
///
/// For an operand `transpose(src)`, its map is composed with the inverse of
/// the transpose's permutation map, so the map indexes `src` instead.
struct CombineContractTranspose
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      AffineMap permutationMap = AffineMap::getPermutationMap(
          extractVector<unsigned>(transposeOp.getTransp()),
          contractOp.getContext());
      map = inversePermutation(permutationMap).compose(map);
      *operand = transposeOp.getVector();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};
/// Folds vector.broadcast ops that feed the LHS and/or RHS of a
/// vector.contract into the contraction's indexing maps. The contraction
/// then reads the pre-broadcast vectors.
///
/// An operand's broadcast is folded only if all of the following hold:
///   * the source is a vector of strictly lower rank;
///   * only new leading dimensions were added (no inner-dimension stretch);
///   * none of the added leading dimensions is a non-unit reduction
///     dimension.
/// Dimensions that become unused afterwards are compressed out of the maps
/// and iterator list. The rewrite is abandoned unless the result still has
/// at least one reduction iterator indexed by both LHS and RHS, and every
/// remaining dimension is used by LHS or RHS.
struct CombineContractBroadcast
: public OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
SmallVector<AffineMap, 4> maps =
llvm::to_vector<4>(contractOp.getIndexingMapsArray());
Value lhs = contractOp.getLhs();
Value rhs = contractOp.getRhs();
size_t index = 0;
bool changed = false;
for (Value *operand : {&lhs, &rhs}) {
AffineMap &map = maps[index++];
auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
if (!broadcast)
continue;
// Contraction operands must be vectors, and same-rank broadcasts add no
// leading dimensions to fold away.
auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
continue;
int64_t rankDiff =
broadcast.getVectorType().getRank() - srcType.getRank();
// Every source dim must equal the matching trailing result dim. Meanwhile
// record the result dims that the source dims correspond to.
bool innerDimBroadcast = false;
SmallVector<AffineExpr> originalDims;
for (const auto &dim : llvm::enumerate(srcType.getShape())) {
if (dim.value() !=
broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
innerDimBroadcast = true;
break;
}
originalDims.push_back(
rewriter.getAffineDimExpr(dim.index() + rankDiff));
}
if (innerDimBroadcast)
continue;
// Folding a broadcast along a non-unit reduction dimension would change
// the reduced sum, so skip it.
bool nonUnitDimReductionBroadcast = false;
for (int64_t i = 0; i < rankDiff; ++i) {
if (broadcast.getVectorType().getDimSize(i) != 1 &&
isReductionIterator(contractOp.getIteratorTypes()
.getValue()[map.getDimPosition(i)])) {
nonUnitDimReductionBroadcast = true;
break;
}
}
if (nonUnitDimReductionBroadcast)
continue;
// Map from broadcast-result indices to source indices (drops the leading
// `rankDiff` dims), composed onto the operand's indexing map.
AffineMap broadcastMap =
AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
contractOp.getContext());
map = broadcastMap.compose(map);
*operand = broadcast.getSource();
changed = true;
}
if (!changed)
return failure();
// Compress iteration dims that no map references any more.
llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
for (auto &m : maps)
m = compressDims(m, unusedDimsBitVector);
SmallVector<Attribute, 4> iterators;
for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
if (!unusedDimsBitVector.test(i))
iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
}
// Bail if no reduction iterator is indexed by both LHS and RHS, e.g.
// when the only reduction dim was a unit dim introduced by a broadcast.
bool hasReductionIteratorApplyingOnBothSides = false;
for (unsigned i = 0; i < iterators.size(); ++i) {
if (!isReductionIterator(iterators[i]))
continue;
if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
hasReductionIteratorApplyingOnBothSides = true;
break;
}
}
if (!hasReductionIteratorApplyingOnBothSides)
return failure();
// A dimension used by neither LHS nor RHS would not form a valid
// contraction (presumably rejected by the ContractionOp verifier).
if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
return failure();
rewriter.replaceOpWithNewOp<vector::ContractionOp>(
contractOp, lhs, rhs, contractOp.getAcc(),
rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
return success();
}
};
/// Swaps a single-operand cast with the vector.broadcast that feeds it:
///   cast(broadcast(x)) -> broadcast(cast(x))
/// The new cast keeps the original op name and attributes. Its result type
/// is the original cast's element type, reshaped to the broadcast source's
/// shape when that source is a vector.
struct ReorderCastOpsOnBroadcast
: public OpInterfaceRewritePattern<CastOpInterface> {
using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
LogicalResult matchAndRewrite(CastOpInterface op,
PatternRewriter &rewriter) const override {
if (op->getNumOperands() != 1)
return failure();
auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
if (!bcastOp)
return failure();
// Scalar broadcast source: cast to the scalar element type; vector source:
// cast to a vector of the source's shape.
Type castResTy = getElementTypeOrSelf(op->getResult(0));
if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
castResTy = VectorType::get(vecTy.getShape(), castResTy);
// Recreate the same cast generically (by op name) on the broadcast input.
auto *castOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(),
bcastOp.getSource(), castResTy, op->getAttrs());
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
op, op->getResult(0).getType(), castOp->getResult(0));
return success();
}
};
/// Moves vector.transpose ops from the operands of a single-result,
/// region-free elementwise op to its result:
///   elementwise(transpose(a, P), transpose(b, P)) ->
///       transpose(elementwise(a, b), P)
/// Every operand must be a transpose or a constant, and all transposes must
/// share the same permutation P. Vector constant operands are transposed by
/// the inverse of P so they line up with the un-transposed inputs. Non-vector
/// (scalar) constant operands have no shape to permute and are forwarded
/// unchanged; previously they hit an unconditional cast<VectorType>.
struct ReorderElementwiseOpsOnTranspose final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1 || op->getNumRegions() != 0)
      return failure();
    // Collect the transpose permutations; any operand that is neither a
    // transpose nor a constant blocks the rewrite.
    SmallVector<ArrayAttr, 4> transposeMaps;
    transposeMaps.reserve(op->getNumOperands());
    VectorType srcType;
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        transposeMaps.push_back(transposeOp.getTransp());
        srcType = transposeOp.getVectorType();
      } else if (!matchPattern(operand, m_Constant())) {
        return failure();
      }
    }
    if (transposeMaps.empty())
      return failure();
    if (!llvm::is_splat(transposeMaps))
      return rewriter.notifyMatchFailure(op, "different transpose map");

    // invOrder undoes the shared permutation.
    auto order = extractVector<unsigned>(transposeMaps.front());
    SmallVector<int64_t> invOrder(order.size());
    for (int i = 0, e = order.size(); i < e; ++i)
      invOrder[order[i]] = i;

    SmallVector<Value, 4> srcValues;
    srcValues.reserve(op->getNumOperands());
    for (Value operand : op->getOperands()) {
      if (auto transposeOp = operand.getDefiningOp<vector::TransposeOp>()) {
        srcValues.push_back(transposeOp.getVector());
        continue;
      }
      auto constVecType = operand.getType().dyn_cast<VectorType>();
      if (!constVecType) {
        // Scalar constant: shape-agnostic, forward as-is.
        srcValues.push_back(operand);
        continue;
      }
      auto vectorType =
          VectorType::get(srcType.getShape(), constVecType.getElementType());
      srcValues.push_back(rewriter.create<vector::TransposeOp>(
          operand.getLoc(), vectorType, operand,
          rewriter.getI64ArrayAttr(invOrder)));
    }

    // Re-create the elementwise op (same name and attributes) on the
    // un-transposed operands, then transpose its result back.
    auto vectorType = VectorType::get(
        srcType.getShape(),
        op->getResultTypes()[0].cast<VectorType>().getElementType());
    Operation *elementwiseOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
                        vectorType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        op, op->getResultTypes()[0], elementwiseOp->getResult(0),
        transposeMaps.front());
    return success();
  }
};
}
/// Emits `x + y` at `loc`: `arith.addf` unless `isInt` is set, in which case
/// `arith.addi`.
static Value createAdd(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (!isInt)
    return rewriter.create<arith::AddFOp>(loc, x, y);
  return rewriter.create<arith::AddIOp>(loc, x, y);
}
/// Emits `x * y` at `loc`: `arith.mulf` unless `isInt` is set, in which case
/// `arith.muli`.
static Value createMul(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (!isInt)
    return rewriter.create<arith::MulFOp>(loc, x, y);
  return rewriter.create<arith::MulIOp>(loc, x, y);
}
namespace mlir {
/// Lowers a (parallel, parallel, reduction) contraction to `vector.matmul` on
/// flattened operands, followed by an add of the accumulator. Only applies
/// when the Matmul lowering strategy is selected and lhs, rhs and result share
/// one int/float element type.
LogicalResult
ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
                                                 PatternRewriter &rew) const {
  // TODO: implement masks.
  if (llvm::size(op.getMasks()) != 0)
    return failure();
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Matmul)
    return failure();
  if (failed(filter(op)))
    return failure();

  // Require exactly (parallel, parallel, reduction). Check the count first so
  // that contractions with fewer iterators are not indexed out of bounds.
  auto iteratorTypes = op.getIteratorTypes().getValue();
  if (iteratorTypes.size() != 3 || !isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();

  Type elementType = op.getLhsType().getElementType();
  if (!elementType.isIntOrFloat())
    return failure();

  Type dstElementType = op.getType();
  if (auto vecType = dstElementType.dyn_cast<VectorType>())
    dstElementType = vecType.getElementType();
  if (elementType != dstElementType)
    return failure();

  // Classify all three indexing maps before creating any IR. A pattern must
  // not modify the IR and then report failure.
  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();
  AffineExpr m, n, k;
  bindDims(rew.getContext(), m, n, k);
  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
  // LHS must be A(m, k) or A(k, m).
  bool transposeLhs = maps[0] == AffineMap::get(3, 0, {k, m}, ctx);
  if (!transposeLhs && maps[0] != AffineMap::get(3, 0, {m, k}, ctx))
    return failure();
  // RHS must be B(k, n) or B(n, k).
  bool transposeRhs = maps[1] == AffineMap::get(3, 0, {n, k}, ctx);
  if (!transposeRhs && maps[1] != AffineMap::get(3, 0, {k, n}, ctx))
    return failure();
  // ACC must be C(m, n) or C(n, m).
  bool transposeAcc = maps[2] == AffineMap::get(3, 0, {n, m}, ctx);
  if (!transposeAcc && maps[2] != AffineMap::get(3, 0, {m, n}, ctx))
    return failure();

  // Bring lhs and rhs into row-major (m, k) x (k, n) form.
  Value lhs = op.getLhs();
  if (transposeLhs)
    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
  Value rhs = op.getRhs();
  if (transposeRhs)
    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});

  // vector.matmul operates on 1-D (flattened) operands.
  VectorType lhsType = lhs.getType().cast<VectorType>();
  VectorType rhsType = rhs.getType().cast<VectorType>();
  int64_t lhsRows = lhsType.getDimSize(0);
  int64_t lhsColumns = lhsType.getDimSize(1);
  int64_t rhsColumns = rhsType.getDimSize(1);
  Type flattenedLHSType =
      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
  Type flattenedRHSType =
      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);

  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
                                           rhsColumns);
  // Restore the 2-D (m, n) shape of the product.
  mul = rew.create<vector::ShapeCastOp>(
      loc,
      VectorType::get({lhsRows, rhsColumns},
                      getElementTypeOrSelf(op.getAcc().getType())),
      mul);
  if (transposeAcc)
    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});

  Value res =
      elementType.isa<IntegerType>()
          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
          : static_cast<Value>(
                rew.create<arith::AddFOp>(loc, op.getAcc(), mul));

  rew.replaceOp(op, res);
  return success();
}
namespace {
/// Matcher for iterator-type attributes: compares a StringAttr against a
/// fixed iterator name.
struct IteratorType {
  IteratorType(StringRef strRef) : strRef(strRef) {}

  /// Returns true iff `attr` is a StringAttr whose value equals `strRef`.
  bool isOfType(Attribute attr) const {
    if (auto nameAttr = attr.dyn_cast<StringAttr>())
      return nameAttr.getValue() == strRef;
    return false;
  }

  StringRef strRef;
};
/// IteratorType matcher bound to the parallel iterator name.
struct Par : IteratorType {
  Par() : IteratorType{getParallelIteratorTypeName()} {}
};
/// IteratorType matcher bound to the reduction iterator name.
struct Red : IteratorType {
  Red() : IteratorType{getReductionIteratorTypeName()} {}
};
/// Lowers a vector.contract into a chain of vector.outerproduct ops, one per
/// position along the reduction dimension. Operands are transposed as needed
/// so that extracting position k of each operand yields the k-th reduction
/// slice. The matmat/matvec/tmatvec entry points each accept one iterator
/// structure and a fixed set of indexing-map layouts, and return failure()
/// without creating IR when the contraction does not match.
struct UnrolledOuterProductGenerator
: public StructuredGenerator<vector::ContractionOp> {
UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
: StructuredGenerator<vector::ContractionOp>(builder, op),
kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
res(op.getAcc()), lhsType(op.getLhsType()) {}
/// Creates a vector.transpose of `v` with permutation {1, 0}, i.e. it swaps
/// the two dimensions (so `v` is expected to be rank 2).
Value t(Value v) {
static constexpr std::array<int64_t, 2> perm = {1, 0};
return builder.create<vector::TransposeOp>(loc, v, perm);
}
/// Extends `v` (scalar or vector) to `dstElementType`. Returns `v` unchanged
/// when the element types already match; otherwise emits arith.extf for a
/// float destination and arith.extsi (signed extension) for anything else.
Value promote(Value v, Type dstElementType) {
Type elementType = v.getType();
auto vecType = elementType.dyn_cast<VectorType>();
if (vecType)
elementType = vecType.getElementType();
if (elementType == dstElementType)
return v;
// Keep the vector shape (if any); only the element type changes.
Type promotedType = dstElementType;
if (vecType)
promotedType = VectorType::get(vecType.getShape(), promotedType);
if (dstElementType.isa<FloatType>())
return builder.create<arith::ExtFOp>(loc, promotedType, v);
return builder.create<arith::ExtSIOp>(loc, promotedType, v);
}
/// Accumulates into `res` the outer products of `lhs[k]` and `rhs[k]` for
/// k in [0, reductionSize), combining with the contraction's `kind`. Each
/// extracted slice is promoted to the element type of `res` first.
Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
assert(reductionSize > 0);
Type resElementType = res.getType().cast<VectorType>().getElementType();
for (int64_t k = 0; k < reductionSize; ++k) {
Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
a = promote(a, resElementType);
b = promote(b, resElementType);
res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
res, kind);
}
return res;
}
/// Matrix-matrix case with iterators (parallel m, parallel n, reduction k).
/// Every supported layout is normalized so that the reduction dimension k is
/// the leading dimension of both outerProd operands. When the result is
/// (n, m), the operand roles are swapped.
FailureOr<Value> matmat() {
if (!iters({Par(), Par(), Red()}))
return failure();
AffineExpr m, n, k;
bindDims(builder.getContext(), m, n, k);
// lhs (m, k) needs a transpose; rhs (k, n) is already k-major.
if (layout({{m, k}, {k, n}, {m, n}}))
return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
// Both operands need a transpose. The lhs transpose is hoisted into its own
// statement, which fixes the order in which the two transposes are created:
// C++ leaves the evaluation order of function arguments unspecified.
if (layout({{m, k}, {n, k}, {m, n}})) {
Value tlhs = t(lhs);
return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
}
// Both operands already k-major: no transposes.
if (layout({{k, m}, {k, n}, {m, n}}))
return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
// Only rhs (n, k) needs a transpose.
if (layout({{k, m}, {n, k}, {m, n}}))
return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
// Transposed (n, m) result: rhs becomes the first outerProd operand.
if (layout({{m, k}, {k, n}, {n, m}}))
return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
// Transposed result with both operands transposed. As above, the first
// transpose is hoisted to fix op-creation order.
if (layout({{m, k}, {n, k}, {n, m}})) {
Value trhs = t(rhs);
return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
}
if (layout({{k, m}, {k, n}, {n, m}}))
return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
if (layout({{k, m}, {n, k}, {n, m}}))
return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
return failure();
}
/// Matrix-vector case with iterators (parallel m, reduction k). The matrix
/// operand is made k-major and always passed first to outerProd. The vector
/// operand's extracted element is a scalar at each step.
FailureOr<Value> matvec() {
if (!iters({Par(), Red()}))
return failure();
AffineExpr m, k;
bindDims(builder.getContext(), m, k);
// Matrix lhs (m, k): transpose it.
if (layout({{m, k}, {k}, {m}}))
return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
// Matrix lhs (k, m): use it as is.
if (layout({{k, m}, {k}, {m}}))
return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
// Vector lhs, matrix rhs (m, k): swap roles and transpose the matrix.
if (layout({{k}, {m, k}, {m}}))
return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
// Vector lhs, matrix rhs (k, m): swap roles only.
if (layout({{k}, {k, m}, {m}}))
return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
return failure();
}
/// Same layouts as matvec(), but for iterators ordered (reduction k,
/// parallel m). Note that k and m are bound to dims 0 and 1 respectively.
FailureOr<Value> tmatvec() {
if (!iters({Red(), Par()}))
return failure();
AffineExpr k, m;
bindDims(builder.getContext(), k, m);
if (layout({{m, k}, {k}, {m}}))
return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
if (layout({{k, m}, {k}, {m}}))
return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
if (layout({{k}, {m, k}, {m}}))
return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
if (layout({{k}, {k, m}, {m}}))
return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
return failure();
}
private:
// Combining kind forwarded to every vector.outerproduct.
vector::CombiningKind kind;
// Operands of the contraction being lowered (res starts as the accumulator).
Value lhs, rhs, res;
VectorType lhsType;
};
}
/// Lowers a contraction into unrolled vector.outerproduct ops when the
/// OuterProduct strategy is selected. The matmat, matvec and tmatvec shapes
/// are tried in that order; the first one that matches wins.
LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
    vector::ContractionOp op, PatternRewriter &rewriter) const {
  // TODO: implement masks.
  if (llvm::size(op.getMasks()) != 0)
    return failure();
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::OuterProduct)
    return failure();
  if (failed(filter(op)))
    return failure();

  UnrolledOuterProductGenerator generator(rewriter, op);
  FailureOr<Value> lowered = generator.matmat();
  if (failed(lowered))
    lowered = generator.matvec();
  if (failed(lowered))
    lowered = generator.tmatvec();
  if (failed(lowered))
    return failure();

  rewriter.replaceOp(op, *lowered);
  return success();
}
/// Lowers a matrix-matrix or matrix-vector contraction into per-element
/// elementwise multiplies followed by `vector.reduction <add>`, then adds the
/// accumulator. Only applies when the Dot lowering strategy is selected.
LogicalResult
ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
                                            PatternRewriter &rewriter) const {
  // TODO: implement masks.
  if (llvm::size(op.getMasks()) != 0)
    return failure();
  if (failed(filter(op)))
    return failure();
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Dot)
    return failure();

  auto iteratorTypes = op.getIteratorTypes().getValue();
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  Location loc = op.getLoc();
  Value lhs = op.getLhs(), rhs = op.getRhs();

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();

  // The iterator-count checks keep the subscripts below in bounds. Without
  // them, a (parallel, parallel) contraction would read iteratorTypes[2].
  bool isMatmat = iteratorTypes.size() == 3 &&
                  isParallelIterator(iteratorTypes[0]) &&
                  isParallelIterator(iteratorTypes[1]) &&
                  isReductionIterator(iteratorTypes[2]);
  bool isMatvec = iteratorTypes.size() == 2 &&
                  isParallelIterator(iteratorTypes[0]) &&
                  isReductionIterator(iteratorTypes[1]);

  // Normalize the operands so that result element (r, c) is the dot product
  // of row r of `lhs` with row c of `rhs` (or with `rhs` itself when the
  // result is 1-D).
  if (isMatmat) {
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      // No need to permute anything.
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // Transposed result: swap the operand roles.
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = tmp;
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      Value tmp = rhs;
      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      lhs = tmp;
    } else {
      return failure();
    }
  } else if (isMatvec) {
    if (maps == infer({{m, n}, {n}, {m}})) {
      // No need to permute anything.
    } else if (maps == infer({{n, m}, {n}, {m}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{n}, {m, n}, {m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{n}, {n, m}, {m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else {
      return failure();
    }
  } else {
    return failure();
  }

  VectorType dstType = op.getResultType().cast<VectorType>();
  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
         "Expected dst type of rank 1 or 2");

  unsigned rank = dstType.getRank();
  unsigned dstRows = dstType.getShape()[0];
  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];

  // Fill the result element by element with add-reductions of products.
  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
                                                 rewriter.getZeroAttr(dstType));
  bool isInt = dstType.getElementType().isa<IntegerType>();
  for (unsigned r = 0; r < dstRows; ++r) {
    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
    for (unsigned c = 0; c < dstColumns; ++c) {
      Value b = rank == 1
                    ? rhs
                    : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
      Value prod = createMul(op.getLoc(), a, b, isInt, rewriter);
      Value reduced = rewriter.create<vector::ReductionOp>(
          op.getLoc(), vector::CombiningKind::ADD, prod);

      SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
                                              : SmallVector<int64_t, 2>{r, c};
      res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
    }
  }
  if (auto acc = op.getAcc())
    res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
  rewriter.replaceOp(op, res);
  return success();
}
/// Progressive lowering of vector.contract. The specialized lowerings
/// (matmul, outer product, dot, elementwise) are tried first. Otherwise one
/// dimension is peeled off, in this order of preference: a batch dimension,
/// a free lhs dimension, a free rhs dimension, and finally a reduction
/// dimension.
LogicalResult
ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
                                       PatternRewriter &rewriter) const {
  // TODO: implement masks.
  if (llvm::size(op.getMasks()) != 0)
    return failure();
  if (failed(filter(op)))
    return failure();

  // Mixed element types between operands and accumulator are not handled.
  Type accElementType = getElementTypeOrSelf(op.getAccType());
  if (op.getLhsType().getElementType() != accElementType ||
      op.getRhsType().getElementType() != accElementType)
    return failure();

  // Specialized lowerings, tried in order.
  MLIRContext *ctx = op.getContext();
  ContractionOpToMatmulOpLowering matmulLowering(vectorTransformOptions, ctx);
  if (succeeded(matmulLowering.matchAndRewrite(op, rewriter)))
    return success();
  ContractionOpToOuterProductOpLowering outerProductLowering(
      vectorTransformOptions, ctx);
  if (succeeded(outerProductLowering.matchAndRewrite(op, rewriter)))
    return success();
  ContractionOpToDotLowering dotLowering(vectorTransformOptions, ctx);
  if (succeeded(dotLowering.matchAndRewrite(op, rewriter)))
    return success();
  ContractOpToElementwise elementwiseLowering(vectorTransformOptions, ctx);
  if (succeeded(elementwiseLowering.matchAndRewrite(op, rewriter)))
    return success();

  // Replaces `op` with the result of one unrolling step, if it succeeded.
  auto replaceWithUnrolled = [&](FailureOr<Value> unrolled) -> LogicalResult {
    if (failed(unrolled))
      return failure();
    rewriter.replaceOp(op, *unrolled);
    return success();
  };

  // Peel off the first batch dimension, if any.
  std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
  if (!batchDimMap.empty()) {
    const std::pair<int64_t, int64_t> &firstBatch = batchDimMap.front();
    return replaceWithUnrolled(
        lowerParallel(op, firstBatch.first, firstBatch.second, rewriter));
  }

  std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
      op.getContractingDimMap();
  DenseSet<int64_t> lhsContracted;
  DenseSet<int64_t> rhsContracted;
  for (const std::pair<int64_t, int64_t> &dims : contractingDimMap) {
    lhsContracted.insert(dims.first);
    rhsContracted.insert(dims.second);
  }

  // Peel off the first free (non-contracted) lhs dimension.
  for (int64_t dim = 0, rank = op.getLhsType().getRank(); dim < rank; ++dim)
    if (lhsContracted.count(dim) == 0)
      return replaceWithUnrolled(lowerParallel(op, dim, -1, rewriter));

  // Then the first free rhs dimension.
  for (int64_t dim = 0, rank = op.getRhsType().getRank(); dim < rank; ++dim)
    if (rhsContracted.count(dim) == 0)
      return replaceWithUnrolled(lowerParallel(op, -1, dim, rewriter));

  // Only reduction dimensions remain.
  if (contractingDimMap.empty())
    return failure();
  return replaceWithUnrolled(lowerReduction(op, rewriter));
}
/// Unrolls the iteration dimension addressed by `lhsIndex` and/or `rhsIndex`
/// (a negative index means "not present in that operand"). Emits one
/// lower-rank vector.contract per position, each inserted into a
/// zero-initialized result. Returns a match failure (without creating IR) if
/// the indices are inconsistent or unusable.
FailureOr<Value>
ContractionOpLowering::lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
                                     int64_t rhsIndex,
                                     PatternRewriter &rewriter) const {
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  VectorType resType = op.getResultType().cast<VectorType>();
  SmallVector<AffineMap, 4> iMap = op.getIndexingMapsArray();

  // Resolve the iteration dimension and its extent. The lhs index takes
  // precedence; when both are given they must name the same dimension.
  int64_t iterIndex = -1;
  int64_t dimSize = -1;
  if (lhsIndex >= 0) {
    iterIndex = iMap[0].getDimPosition(lhsIndex);
    bool rhsAgrees =
        rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex);
    if (!rhsAgrees)
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
             << " to map to the same dimension";
      });
    dimSize = lhsType.getDimSize(lhsIndex);
  } else if (rhsIndex >= 0) {
    iterIndex = iMap[1].getDimPosition(rhsIndex);
    dimSize = rhsType.getDimSize(rhsIndex);
  }
  if (iterIndex < 0)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected either lhsIndex=" << lhsIndex
           << " or rhsIndex=" << rhsIndex << " to be nonnegative";
    });

  // A dimension missing from the result can only be dropped if it is unit.
  int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
  if (resIndex == -1 && dimSize != 1)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected the dimension for iterIndex=" << iterIndex
           << " to either appear in the result map, or to be a unit dimension";
    });

  // Indexing maps and iterator types of the rank-reduced contractions.
  std::array<AffineMap, 3> reducedMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  ArrayAttr reducedMapsAttr = rewriter.getAffineMapArrayAttr(reducedMaps);
  ArrayAttr reducedItersAttr =
      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));

  Location loc = op.getLoc();
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));
  for (int64_t pos = 0; pos < dimSize; ++pos) {
    Value lhsSlice =
        reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, pos, rewriter);
    Value rhsSlice =
        reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, pos, rewriter);
    Value accSlice =
        reshapeLoad(loc, op.getAcc(), resType, resIndex, pos, rewriter);
    Value sliceContract = rewriter.create<vector::ContractionOp>(
        loc, lhsSlice, rhsSlice, accSlice, reducedMapsAttr, reducedItersAttr);
    result = reshapeStore(loc, sliceContract, result, resType, resIndex, pos,
                          rewriter);
  }
  return result;
}
/// Unrolls iteration dimension 0 of a contraction with a scalar result. That
/// dimension must appear in both lhs and rhs with equal size. Rank-1 operands
/// become a multiply plus `vector.reduction <add>`. Higher ranks become a
/// chain of lower-rank contractions that thread the accumulator through.
FailureOr<Value>
ContractionOpLowering::lowerReduction(vector::ContractionOp op,
                                      PatternRewriter &rewriter) const {
  auto loc = op.getLoc();
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  Type resType = op.getResultType();
  if (resType.isa<VectorType>())
    return rewriter.notifyMatchFailure(op,
                                       "did not expect a VectorType result");
  bool isInt = resType.isa<IntegerType>();
  // Use iterator index 0.
  int64_t iterIndex = 0;
  SmallVector<AffineMap, 4> iMap = op.getIndexingMapsArray();
  Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
  Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
  if (!lookupLhs.has_value())
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected iterIndex=" << iterIndex
           << " to map to a LHS dimension";
    });
  if (!lookupRhs.has_value())
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected iterIndex=" << iterIndex
           << " to map to a RHS dimension";
    });
  int64_t lhsIndex = lookupLhs.value();
  int64_t rhsIndex = lookupRhs.value();
  int64_t dimSize = lhsType.getDimSize(lhsIndex);
  if (dimSize != rhsType.getDimSize(rhsIndex))
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expect LHS dimension " << lhsIndex
           << " to have the same size as RHS dimension " << rhsIndex;
    });

  // Base case: a plain dot product.
  if (lhsType.getRank() == 1) {
    if (rhsType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "When LHS has rank 1, expected also RHS to have rank 1");
    Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
    auto kind = vector::CombiningKind::ADD;
    if (auto acc = op.getAcc())
      return rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
          .getResult();
    return rewriter.create<vector::ReductionOp>(loc, kind, m).getResult();
  }

  // Otherwise chain rank-reduced contractions, feeding each one's result in
  // as the next one's accumulator.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
  Value result = op.getAcc();
  for (int64_t d = 0; d < dimSize; ++d) {
    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
    result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
                                                    lowAffine, lowIter);
  }
  return result;
}
}
/// Wraps the single vector result of `op` in a vector.extract_map /
/// vector.insert_map pair, inserted right after `op`. Returns None when `op`
/// does not have exactly one vector result, or when `map` and `multiplicity`
/// do not describe an even split of the result's dimensions.
Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
    OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
    ArrayRef<int64_t> multiplicity, const AffineMap &map) {
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointAfter(op);
  Location loc = op->getLoc();
  if (op->getNumResults() != 1)
    return {};
  Value result = op->getResult(0);
  VectorType type = result.getType().dyn_cast<VectorType>();
  if (!type || map.getNumResults() != multiplicity.size())
    return {};

  // Each map result must name an existing vector dimension whose size is a
  // multiple of the corresponding multiplicity.
  for (const auto &indexedExpr : llvm::enumerate(map.getResults())) {
    auto dimExpr = indexedExpr.value().dyn_cast<AffineDimExpr>();
    if (!dimExpr)
      return {};
    if (dimExpr.getPosition() >= type.getRank())
      return {};
    if (type.getDimSize(dimExpr.getPosition()) %
            multiplicity[indexedExpr.index()] !=
        0)
      return {};
  }

  DistributeOps distributed;
  distributed.extract =
      builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map);
  distributed.insert = builder.create<vector::InsertMapOp>(
      loc, distributed.extract, result, ids);
  return distributed;
}
/// Lowers an in-bounds, minor-identity (optionally broadcasting)
/// vector.transfer_read from a memref with unit-stride last dimension into
/// vector.load or vector.maskedload. Broadcast dimensions are loaded as size 1
/// and then expanded with vector.broadcast.
struct TransferReadToVectorLoadLowering
    : public OpRewritePattern<vector::TransferReadOp> {
  TransferReadToVectorLoadLowering(MLIRContext *context,
                                   llvm::Optional<unsigned> maxRank)
      : OpRewritePattern<vector::TransferReadOp>(context),
        maxTransferRank(maxRank) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    VectorType readType = read.getVectorType();
    if (maxTransferRank && readType.getRank() > *maxTransferRank)
      return failure();

    // Permutations other than (broadcasting) minor identities are left to
    // other patterns.
    SmallVector<unsigned, 4> broadcastedDims;
    if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
            &broadcastedDims))
      return failure();

    auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
    if (!memRefType || !vector::isLastMemrefDimUnitStride(memRefType))
      return failure();

    // The loaded vector has size 1 along every broadcast dimension.
    SmallVector<int64_t, 4> loadShape(readType.getShape().begin(),
                                      readType.getShape().end());
    for (unsigned dim : broadcastedDims)
      loadShape[dim] = 1;
    VectorType loadType =
        VectorType::get(loadShape, readType.getElementType());

    // A vector element type must equal the loaded vector type exactly;
    // otherwise the scalar element types must agree.
    Type memrefElTy = memRefType.getElementType();
    bool elementTypeOk = memrefElTy.isa<VectorType>()
                             ? memrefElTy == loadType
                             : memrefElTy == readType.getElementType();
    if (!elementTypeOk)
      return failure();

    // Out-of-bounds accesses need masking elsewhere first.
    if (read.hasOutOfBoundsDim())
      return failure();

    Location loc = read.getLoc();
    Value loaded;
    if (Value mask = read.getMask()) {
      Value passThru =
          rewriter.create<vector::SplatOp>(loc, loadType, read.getPadding());
      loaded = rewriter.create<vector::MaskedLoadOp>(
          loc, loadType, read.getSource(), read.getIndices(), mask, passThru);
    } else {
      loaded = rewriter.create<vector::LoadOp>(loc, loadType, read.getSource(),
                                               read.getIndices());
    }

    if (broadcastedDims.empty())
      rewriter.replaceOp(read, loaded);
    else
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, readType, loaded);
    return success();
  }

  llvm::Optional<unsigned> maxTransferRank;
};
/// Replaces a single-element vector.load with memref.load followed by
/// vector.broadcast back to the original vector type.
struct VectorLoadToMemrefLoadLowering
    : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = loadOp.getVectorType();
    if (vecType.getNumElements() == 1) {
      Value element = rewriter.create<memref::LoadOp>(
          loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
                                                       element);
      return success();
    }
    return failure();
  }
};
/// Replaces a single-element vector.store into a memref of scalars with
/// vector.extract (or vector.extractelement for 0-d vectors) followed by
/// memref.store.
struct VectorStoreToMemrefStoreLowering
    : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = storeOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return failure();
    // memref.store requires the stored value to have the memref's element
    // type. If that element type is itself a vector, storing an extracted
    // scalar would produce invalid IR, so leave such stores alone.
    Type memrefElTy =
        storeOp.getBase().getType().cast<MemRefType>().getElementType();
    if (memrefElTy.isa<VectorType>())
      return failure();

    Value extracted;
    if (vecType.getRank() == 0) {
      // TODO: Unify once ExtractOp supports 0-d vectors.
      extracted = rewriter.create<vector::ExtractElementOp>(
          storeOp.getLoc(), storeOp.getValueToStore());
    } else {
      SmallVector<int64_t> indices(vecType.getRank(), 0);
      extracted = rewriter.create<vector::ExtractOp>(
          storeOp.getLoc(), storeOp.getValueToStore(), indices);
    }

    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
    return success();
  }
};
/// Lowers an in-bounds, minor-identity vector.transfer_write into a memref
/// with unit-stride last dimension into vector.store or vector.maskedstore.
struct TransferWriteToVectorStoreLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  TransferWriteToVectorStoreLowering(MLIRContext *context,
                                     llvm::Optional<unsigned> maxRank)
      : OpRewritePattern<vector::TransferWriteOp>(context),
        maxTransferRank(maxRank) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    VectorType writtenType = write.getVectorType();
    if (maxTransferRank && writtenType.getRank() > *maxTransferRank)
      return failure();

    // Permuted writes are handled by other patterns.
    if (!write.getPermutationMap().isMinorIdentity())
      return failure();

    auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
    if (!memRefType || !vector::isLastMemrefDimUnitStride(memRefType))
      return failure();

    // A vector element type must equal the written vector type exactly;
    // otherwise the scalar element types must agree.
    Type memrefElTy = memRefType.getElementType();
    Type requiredElTy = memrefElTy.isa<VectorType>()
                            ? static_cast<Type>(writtenType)
                            : writtenType.getElementType();
    if (memrefElTy != requiredElTy)
      return failure();

    // Out-of-bounds accesses need masking elsewhere first.
    if (write.hasOutOfBoundsDim())
      return failure();

    if (Value mask = write.getMask()) {
      rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
          write, write.getSource(), write.getIndices(), mask,
          write.getVector());
      return success();
    }
    rewriter.replaceOpWithNewOp<vector::StoreOp>(
        write, write.getVector(), write.getSource(), write.getIndices());
    return success();
  }

  llvm::Optional<unsigned> maxTransferRank;
};
/// Returns the integer payloads of `arrayAttr`, whose elements are accessed
/// as IntegerAttr.
static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
  SmallVector<int64_t, 4> values;
  for (IntegerAttr element : arrayAttr.getAsRange<IntegerAttr>())
    values.push_back(element.getInt());
  return values;
}
/// Rewrites `extract(bitcast(x), i)` on 1-D vectors, where the bitcast
/// increases the element count, into an extract of the single source element
/// that packs position `i`. That element is bitcast on its own, and the
/// wanted sub-element is then extracted from it.
struct BubbleDownVectorBitCastForExtract
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Only support extracting scalars from 1-D vectors for now.
    if (extractOp.getVectorType().getRank() != 1)
      return failure();

    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // This pattern itself produces single-element sources. Refusing them
    // avoids matching its own output again.
    if (castSrcType.getNumElements() == 1)
      return failure();

    // Only support casting to a larger number of elements for now.
    if (castSrcType.getNumElements() > castDstType.getNumElements())
      return failure();

    // Each source element must split into a whole number of destination
    // elements. With e.g. vector<2xi24> -> vector<3xi16>, integer division
    // would truncate the ratio and build an invalid bitcast.
    if (castDstType.getNumElements() % castSrcType.getNumElements() != 0)
      return failure();

    unsigned expandRatio =
        castDstType.getNumElements() / castSrcType.getNumElements();

    auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
      return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
    };
    uint64_t index = getFirstIntValue(extractOp.getPosition());

    // Extract the source element that packs the requested one, as a 1-element
    // vector, e.g. vector<1xf32> out of vector<4xf32>.
    VectorType oneScalarType =
        VectorType::get({1}, castSrcType.getElementType());
    Value packedValue = rewriter.create<vector::ExtractOp>(
        extractOp.getLoc(), oneScalarType, castOp.getSource(),
        rewriter.getI64ArrayAttr(index / expandRatio));

    // Cast it to the destination element type, e.g. f32 -> vector<2xf16>.
    VectorType packedType =
        VectorType::get({expandRatio}, castDstType.getElementType());
    Value castedValue = rewriter.create<vector::BitCastOp>(
        extractOp.getLoc(), packedType, packedValue);

    // Finally extract the desired sub-element.
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
        extractOp, extractOp.getType(), castedValue,
        rewriter.getI64ArrayAttr(index % expandRatio));
    return success();
  }
};
/// Rewrites `extract_strided_slice(bitcast(x))`, where the bitcast increases
/// the last dimension, into `bitcast(extract_strided_slice(x))`. The offset
/// and size of the last dimension are scaled down accordingly. Only unit
/// strides are supported.
struct BubbleDownBitCastForStridedSliceExtract
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to more elements for now; other cases to be implemented.
    if (castSrcLastDim > castDstLastDim)
      return failure();
    // With non-power-of-two element widths (e.g. vector<2xi24> to
    // vector<3xi16>) the expansion ratio is not integral. Bail out instead of
    // asserting.
    if (castDstLastDim % castSrcLastDim != 0)
      return failure();

    // Only accept all one strides for now.
    if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOneValue(); }))
      return failure();

    unsigned rank = extractOp.getVectorType().getRank();
    int64_t expandRatio = castDstLastDim / castSrcLastDim;

    // With fewer offsets than the rank, the last dimension is implicitly taken
    // in full and needs no change. Otherwise scale its offset down.
    ArrayAttr newOffsets = extractOp.getOffsets();
    if (newOffsets.size() == rank) {
      SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
      if (offsets.back() % expandRatio != 0)
        return failure();
      offsets.back() = offsets.back() / expandRatio;
      newOffsets = rewriter.getI64ArrayAttr(offsets);
    }

    // Same for sizes.
    ArrayAttr newSizes = extractOp.getSizes();
    if (newSizes.size() == rank) {
      SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes);
      if (sizes.back() % expandRatio != 0)
        return failure();
      sizes.back() = sizes.back() / expandRatio;
      newSizes = rewriter.getI64ArrayAttr(sizes);
    }

    SmallVector<int64_t, 4> dims =
        llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
    dims.back() = dims.back() / expandRatio;
    VectorType newExtractType =
        VectorType::get(dims, castSrcType.getElementType());

    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
        extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
        newSizes, extractOp.getStrides());

    rewriter.replaceOpWithNewOp<vector::BitCastOp>(
        extractOp, extractOp.getType(), newExtractOp);
    return success();
  }
};
/// Rewrites `bitcast(insert_strided_slice(src, dst))`, where the bitcast
/// shrinks the last dimension, into
/// `insert_strided_slice(bitcast(src), bitcast(dst))`. The last-dimension
/// offset is scaled down. Only unit strides and equal-rank source and
/// destination are supported.
struct BubbleUpBitCastForStridedSliceInsert
    : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be
    // implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();
    // Non-power-of-two element widths can make the ratio non-integral. Bail
    // out instead of asserting.
    if (castSrcLastDim % castDstLastDim != 0)
      return failure();
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;

    auto insertOp =
        bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
    if (!insertOp)
      return failure();

    // Only accept all one strides for now.
    if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOneValue(); }))
      return failure();

    unsigned rank = insertOp.getSourceVectorType().getRank();
    // Require the same rank for source and destination for now.
    if (rank != insertOp.getDestVectorType().getRank())
      return failure();

    // The inserted slice must itself be bitcastable to the narrower element
    // count. Otherwise its last dimension would be truncated (possibly to 0),
    // producing an invalid bitcast.
    SmallVector<int64_t, 4> srcDims =
        llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
    if (srcDims.back() % shrinkRatio != 0)
      return failure();

    ArrayAttr newOffsets = insertOp.getOffsets();
    assert(newOffsets.size() == rank);
    SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
    if (offsets.back() % shrinkRatio != 0)
      return failure();
    offsets.back() = offsets.back() / shrinkRatio;
    newOffsets = rewriter.getI64ArrayAttr(offsets);

    // All checks passed; create the rewritten IR.
    srcDims.back() = srcDims.back() / shrinkRatio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());
    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());

    SmallVector<int64_t, 4> dstDims =
        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
    dstDims.back() = dstDims.back() / shrinkRatio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());
    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
        bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
        insertOp.getStrides());
    return success();
  }
};
/// Builds the i1 vector `[off + 0, off + 1, ...] < splat(b)` (signed compare).
///
/// `dim == 0` requests a 0-D index vector holding the single index 0; any
/// other value builds the 1-D index vector `[0, 1, ..., dim - 1]`. When `off`
/// is non-null it is splatted and added to the indices before comparing.
/// Indices are i32 when `force32BitVectorIndices` is set, i64 otherwise; `b`
/// and `*off` are cast to that type first.
static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
                                   bool force32BitVectorIndices, int64_t dim,
                                   Value b, Value *off = nullptr) {
  Location loc = op->getLoc();
  // 32-bit indices allow more lanes per SIMD comparison when the caller can
  // guarantee that all indices fit; otherwise compare with 64-bit indices.
  Type idxType =
      force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();

  // Shape {} for the 0-D request, {dim} otherwise. The 0-D vector still holds
  // exactly one element (index 0).
  SmallVector<int64_t, 1> indicesShape;
  if (dim != 0)
    indicesShape.push_back(dim);
  VectorType indicesType = VectorType::get(indicesShape, idxType);
  int64_t numIndices = (dim == 0) ? 1 : dim;

  DenseIntElementsAttr indicesAttr;
  if (force32BitVectorIndices) {
    SmallVector<int32_t, 4> values =
        llvm::to_vector<4>(llvm::seq<int32_t>(0, numIndices));
    indicesAttr =
        DenseIntElementsAttr::get(indicesType, ArrayRef<int32_t>(values));
  } else {
    SmallVector<int64_t, 4> values =
        llvm::to_vector<4>(llvm::seq<int64_t>(0, numIndices));
    indicesAttr =
        DenseIntElementsAttr::get(indicesType, ArrayRef<int64_t>(values));
  }
  Value laneIndices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);

  // Shift the lane indices by the requested offset, if any.
  if (off != nullptr) {
    Value scalarOffset =
        getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
    Value splatOffset = rewriter.create<vector::SplatOp>(
        loc, laneIndices.getType(), scalarOffset);
    laneIndices = rewriter.create<arith::AddIOp>(loc, splatOffset, laneIndices);
  }

  Value scalarBound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
  Value splatBound =
      rewriter.create<vector::SplatOp>(loc, laneIndices.getType(), scalarBound);
  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                        laneIndices, splatBound);
}
/// Materializes an explicit in-bounds mask on a 1-D vector transfer op
/// (`vector.transfer_read` / `vector.transfer_write`) that has an
/// out-of-bounds dimension, then marks the transfer as in-bounds.
///
/// Mask lanes `[0, dim - offset)` are set, where `dim` is the size of the
/// source dimension traversed by the vector and `offset` is the transfer
/// index along that dimension. An existing mask operand is AND-ed with the
/// new one.
template <typename ConcreteOp>
struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
public:
  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt)
      : mlir::OpRewritePattern<ConcreteOp>(context),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(ConcreteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp.hasOutOfBoundsDim())
      return failure();

    // Only 1-D transfers with at least one index are handled.
    if (xferOp.getVectorType().getRank() != 1 ||
        llvm::size(xferOp.getIndices()) == 0)
      return failure();

    Location loc = xferOp->getLoc();
    VectorType vtp = xferOp.getVectorType();

    // The source dimension traversed by the vector is the one named by the
    // permutation map. It is the last source dimension only for minor
    // identity maps; e.g. with (d0, d1) -> (d0) it is dimension 0. Bail out
    // if the single result is not a plain dimension expression.
    AffineMap permMap = xferOp.getPermutationMap();
    auto xferDimExpr = permMap.getResult(0).dyn_cast<AffineDimExpr>();
    if (!xferDimExpr)
      return failure();
    unsigned xferDim = xferDimExpr.getPosition();

    // Create the in-bounds mask with all elements between [0 .. dim - offset)
    // set and [dim - offset .. vector_length) unset.
    //
    // TODO: when the leaf transfer rank is k > 1, we need the last `k`
    //       dimensions here.
    Value off = xferOp.getIndices()[xferDim];
    Value dim =
        vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), xferDim);
    Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
    Value mask = rewriter.create<vector::CreateMaskOp>(
        loc,
        VectorType::get(vtp.getShape(), rewriter.getI1Type(),
                        vtp.getNumScalableDims()),
        b);
    if (xferOp.getMask()) {
      // Intersect the in-bounds mask with the mask specified as an operand.
      mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
    }

    rewriter.updateRootInPlace(xferOp, [&]() {
      xferOp.getMaskMutable().assign(mask);
      xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    });
    return success();
  }

private:
  // Stored for parity with VectorCreateMaskOpConversion; not read by this
  // pattern's rewrite.
  const bool force32BitVectorIndices;
};
/// Lowers a fixed-length 0-D or 1-D `vector.create_mask` into an index-vector
/// comparison built by `buildVectorComparison`. Scalable masks and masks of
/// rank > 1 are left untouched.
class VectorCreateMaskOpConversion
    : public OpRewritePattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                        bool enableIndexOpt)
      : mlir::OpRewritePattern<vector::CreateMaskOp>(context),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto maskType = op.getType().cast<VectorType>();
    int64_t maskRank = maskType.getRank();
    // Scalable masks have no static lane count to enumerate.
    if (maskType.isScalable() || maskRank > 1)
      return failure();

    // A 0-D mask is requested from the helper with a lane count of 0.
    int64_t laneCount = (maskRank == 0) ? 0 : maskType.getDimSize(0);
    Value comparison = buildVectorComparison(
        rewriter, op, force32BitVectorIndices, laneCount, op.getOperand(0));
    rewriter.replaceOp(op, comparison);
    return success();
  }

private:
  // When true, comparisons use i32 indices instead of i64.
  const bool force32BitVectorIndices;
};
/// Folds trailing unit dimensions out of a `vector.transfer_read`.
///
/// When the trailing vector dimensions and the matching trailing memref
/// dimensions are all of size 1 (with unit memref stride), the read is
/// rewritten as a lower-rank read from a rank-reduced `memref.subview`,
/// followed by a `vector.shape_cast` back to the original vector type.
///
/// Handled: unmasked reads from statically shaped memrefs with a minor
/// identity permutation map and a fixed-length vector of rank >= 2. At least
/// one vector dimension is always kept.
class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();

    // TODO: support mask.
    if (readOp.getMask())
      return failure();

    auto srcType = readOp.getSource().getType().dyn_cast<MemRefType>();
    if (!srcType || !srcType.hasStaticShape())
      return failure();

    if (!readOp.getPermutationMap().isMinorIdentity())
      return failure();

    auto targetType = readOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();
    // The reduced vector type built below carries no scalable dimensions.
    if (targetType.isScalable())
      return failure();

    SmallVector<int64_t> srcStrides;
    int64_t srcOffset;
    if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
      return failure();

    // With a minor identity map, vector dim `d` reads memref dim
    // `d + rankDiff`. A trailing dimension can be folded only if it is unit
    // in BOTH the vector and the memref and has unit stride. Stop before
    // vector dim 0 so the reduced vector (and memref) keep rank >= 1.
    int64_t srcRank = srcType.getRank();
    int64_t vecRank = targetType.getRank();
    int64_t rankDiff = srcRank - vecRank;
    size_t dimsToDrop = 0;
    for (int64_t vecDim = vecRank - 1; vecDim > 0; --vecDim) {
      int64_t srcDim = vecDim + rankDiff;
      if (targetType.getDimSize(vecDim) != 1 ||
          srcType.getDimSize(srcDim) != 1 || srcStrides[srcDim] != 1)
        break;
      ++dimsToDrop;
    }
    if (dimsToDrop == 0)
      return failure();

    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType());

    MemRefType resultMemrefType;
    if (srcType.getLayout().getAffineMap().isIdentity()) {
      resultMemrefType = MemRefType::get(
          srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
          {}, srcType.getMemorySpaceAsInt());
    } else {
      // Pin each dropped (trailing) dimension of the layout map to 0 and
      // shrink the map's dimension count accordingly.
      AffineMap map = srcType.getLayout().getAffineMap();
      int numSymbols = map.getNumSymbols();
      for (size_t i = 0; i < dimsToDrop; ++i) {
        int dim = srcType.getRank() - i - 1;
        map = map.replace(rewriter.getAffineDimExpr(dim),
                          rewriter.getAffineConstantExpr(0),
                          map.getNumDims() - 1, numSymbols);
      }
      resultMemrefType = MemRefType::get(
          srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
          map, srcType.getMemorySpaceAsInt());
    }

    auto loc = readOp.getLoc();
    SmallVector<int64_t> offsets(srcType.getRank(), 0);
    SmallVector<int64_t> strides(srcType.getRank(), 1);

    ArrayAttr inBoundsAttr =
        readOp.getInBounds()
            ? rewriter.getArrayAttr(
                  readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
            : ArrayAttr();
    // Full-extent, unit-stride subview whose result type is rank-reduced.
    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
        loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(),
        strides);
    auto permMap = getTransferMinorIdentityMap(
        rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType);
    Value result = rewriter.create<vector::TransferReadOp>(
        loc, resultTargetVecType, rankedReducedView,
        readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
        readOp.getPadding(),
        // TODO: support mask.
        /*mask=*/Value(), inBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
                                                     result);
    return success();
  }
};
namespace {

/// Returns true if the combining `kind` applies to elements of the given
/// category: integer/index when `isInt` is true, floating point otherwise.
/// ADD and MUL apply to both categories, MINF/MAXF are float-only, and the
/// remaining kinds are integer-only.
static bool isValidKind(bool isInt, vector::CombiningKind kind) {
  using vector::CombiningKind;
  switch (kind) {
  case CombiningKind::MINF:
  case CombiningKind::MAXF:
    return !isInt;
  case CombiningKind::MINUI:
  case CombiningKind::MINSI:
  case CombiningKind::MAXUI:
  case CombiningKind::MAXSI:
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:
    return isInt;
  case CombiningKind::ADD:
  case CombiningKind::MUL:
    return true;
  }
  return false;
}

/// Emits the arith op combining `x` and `y` with `kind`. ADD/MUL dispatch on
/// the vector element type (integer/index vs. float). Returns a null Value if
/// no case matches; callers are expected to validate with `isValidKind`.
static Value genOperator(Location loc, Value x, Value y,
                         vector::CombiningKind kind,
                         PatternRewriter &rewriter) {
  using vector::CombiningKind;
  auto elType = x.getType().cast<VectorType>().getElementType();
  bool isInt = elType.isIntOrIndex();
  switch (kind) {
  case CombiningKind::ADD:
    if (isInt)
      return rewriter.create<arith::AddIOp>(loc, x, y);
    return rewriter.create<arith::AddFOp>(loc, x, y);
  case CombiningKind::MUL:
    if (isInt)
      return rewriter.create<arith::MulIOp>(loc, x, y);
    return rewriter.create<arith::MulFOp>(loc, x, y);
  case CombiningKind::MINUI:
    return rewriter.create<arith::MinUIOp>(loc, x, y);
  case CombiningKind::MINSI:
    return rewriter.create<arith::MinSIOp>(loc, x, y);
  case CombiningKind::MAXUI:
    return rewriter.create<arith::MaxUIOp>(loc, x, y);
  case CombiningKind::MAXSI:
    return rewriter.create<arith::MaxSIOp>(loc, x, y);
  case CombiningKind::AND:
    return rewriter.create<arith::AndIOp>(loc, x, y);
  case CombiningKind::OR:
    return rewriter.create<arith::OrIOp>(loc, x, y);
  case CombiningKind::XOR:
    return rewriter.create<arith::XOrIOp>(loc, x, y);
  case CombiningKind::MINF:
    return rewriter.create<arith::MinFOp>(loc, x, y);
  case CombiningKind::MAXF:
    return rewriter.create<arith::MaxFOp>(loc, x, y);
  }
  return Value();
}

/// Unrolls `vector.scan` along its reduction dimension into a sequence of
/// `vector.extract_strided_slice`, arith combine ops and
/// `vector.insert_strided_slice`.
///
/// Slice 0 of the result is the first input slice (inclusive) or the initial
/// value (exclusive). Slice i > 0 combines the previous output slice with
/// input slice i (inclusive) or i - 1 (exclusive). The second result is the
/// last output slice, reshaped to the initial-value type.
struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
  using OpRewritePattern<vector::ScanOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ScanOp scanOp,
                                PatternRewriter &rewriter) const override {
    auto loc = scanOp.getLoc();
    VectorType destType = scanOp.getDestType();
    ArrayRef<int64_t> destShape = destType.getShape();
    auto elType = destType.getElementType();
    bool isInt = elType.isIntOrIndex();
    if (!isValidKind(isInt, scanOp.getKind()))
      return failure();

    VectorType resType = VectorType::get(destShape, elType);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    int64_t reductionDim = scanOp.getReductionDim();
    bool inclusive = scanOp.getInclusive();
    int64_t destRank = destType.getRank();
    VectorType initialValueType = scanOp.getInitialValueType();
    int64_t initialValueRank = initialValueType.getRank();

    // Each step handles one slice of size 1 along the reduction dimension.
    SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
    reductionShape[reductionDim] = 1;
    VectorType reductionType = VectorType::get(reductionShape, elType);
    SmallVector<int64_t> offsets(destRank, 0);
    SmallVector<int64_t> strides(destRank, 1);
    ArrayAttr scanSizes = rewriter.getI64ArrayAttr(reductionShape);
    ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);

    Value lastOutput, lastInput;
    // int64_t to match the dimension size it is compared against.
    for (int64_t i = 0, e = destShape[reductionDim]; i < e; ++i) {
      offsets[reductionDim] = i;
      ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
      Value input = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
          scanStrides);
      Value output;
      if (i == 0) {
        if (inclusive) {
          output = input;
        } else if (initialValueRank == 0) {
          // A 0-D initial value is broadcast to the slice shape.
          output = rewriter.create<vector::BroadcastOp>(
              loc, input.getType(), scanOp.getInitialValue());
        } else {
          output = rewriter.create<vector::ShapeCastOp>(
              loc, input.getType(), scanOp.getInitialValue());
        }
      } else {
        Value y = inclusive ? input : lastInput;
        output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter);
        assert(output != nullptr);
      }
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, output, result, offsets, strides);
      lastOutput = output;
      lastInput = input;
    }

    Value reduction;
    if (initialValueRank == 0) {
      Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
      reduction =
          rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
    } else {
      reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
                                                       lastOutput);
    }

    rewriter.replaceOp(scanOp, {result, reduction});
    return success();
  }
};
} // namespace
/// Adds the mask-materialization patterns to `patterns`:
/// `VectorCreateMaskOpConversion` plus `MaterializeTransferMask` for both
/// transfer_read and transfer_write. `force32BitVectorIndices` is forwarded
/// to each pattern.
void mlir::vector::populateVectorMaskMaterializationPatterns(
    RewritePatternSet &patterns, bool force32BitVectorIndices) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<VectorCreateMaskOpConversion,
               MaterializeTransferMask<vector::TransferReadOp>,
               MaterializeTransferMask<vector::TransferWriteOp>>(
      ctx, force32BitVectorIndices);
}
/// Adds the `ShapeCastOpFolder` pattern to `patterns`.
void mlir::vector::populateShapeCastFoldingPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<ShapeCastOpFolder>(ctx);
}
/// Adds the patterns that reposition `vector.bitcast` relative to
/// neighbouring extract / strided-slice ops: the two "bubble down" patterns
/// and `BubbleUpBitCastForStridedSliceInsert`.
void mlir::vector::populateBubbleVectorBitCastOpPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<BubbleDownVectorBitCastForExtract,
               BubbleDownBitCastForStridedSliceExtract,
               BubbleUpBitCastForStridedSliceInsert>(ctx);
}
/// Adds the `BroadcastOpLowering` pattern to `patterns`.
void mlir::vector::populateVectorBroadcastLoweringPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<BroadcastOpLowering>(ctx);
}
/// Adds the `CreateMaskOpLowering` and `ConstantMaskOpLowering` patterns to
/// `patterns`.
void mlir::vector::populateVectorMaskOpLoweringPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(ctx);
}
/// Adds the shape_cast lowering patterns (2-D down-cast, 2-D up-cast and the
/// general `ShapeCastOpRewritePattern`) to `patterns`.
void mlir::vector::populateVectorShapeCastLoweringPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<ShapeCastOp2DDownCastRewritePattern,
               ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
      ctx);
}
/// Adds `OuterProductOpLowering` followed by the contraction lowering
/// patterns. The contraction patterns are configured with `options`.
void mlir::vector::populateVectorContractLoweringPatterns(
    RewritePatternSet &patterns, VectorTransformsOptions options) {
  MLIRContext *ctx = patterns.getContext();
  // OuterProductOpLowering takes no options; it is registered first, as before.
  patterns.add<OuterProductOpLowering>(ctx);
  patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
               ContractionOpToOuterProductOpLowering>(options, ctx);
}
/// Adds the transpose lowering patterns, configured with `options`.
void mlir::vector::populateVectorTransposeLoweringPatterns(
    RewritePatternSet &patterns, VectorTransformsOptions options) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(options,
                                                                    ctx);
}
/// Adds `MultiReduceToContract` and the companion patterns that combine or
/// reorder broadcast/transpose/cast ops, to `patterns`.
void mlir::vector::populateVectorReductionToContractPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<MultiReduceToContract, CombineContractBroadcast,
               CombineContractTranspose, ReorderCastOpsOnBroadcast,
               ReorderElementwiseOpsOnTranspose>(ctx);
}
/// Adds the `DropInnerMostUnitDims` pattern, which folds trailing unit
/// dimensions out of `vector.transfer_read` ops, to `patterns`.
void mlir::vector::
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
        RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<DropInnerMostUnitDims>(ctx);
}
/// Adds the transfer-to-load/store lowering patterns, which receive
/// `maxTransferRank`, followed by the vector load/store to memref load/store
/// lowerings.
void mlir::vector::populateVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<TransferReadToVectorLoadLowering,
               TransferWriteToVectorStoreLowering>(ctx, maxTransferRank);
  patterns.add<VectorLoadToMemrefLoadLowering,
               VectorStoreToMemrefStoreLowering>(ctx);
}
/// Adds the `ScanToArithOps` pattern, which unrolls `vector.scan` into arith
/// and strided-slice ops, to `patterns`.
void mlir::vector::populateVectorScanLoweringPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<ScanToArithOps>(ctx);
}