#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#define DEBUG_TYPE "vector-contract-lowering"
using namespace mlir;
using namespace mlir::vector;
static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
int64_t idx = map.getDimPosition(i);
if (idx == index)
return i;
}
return std::nullopt;
}
static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes,
int64_t index) {
SmallVector<Attribute> results;
for (const auto &it : llvm::enumerate(iteratorTypes)) {
int64_t idx = it.index();
if (idx == index)
continue;
results.push_back(it.value());
}
return results;
}
static AffineMap adjustMap(AffineMap map, int64_t index,
PatternRewriter &rewriter) {
auto *ctx = rewriter.getContext();
SmallVector<AffineExpr> results;
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
int64_t idx = map.getDimPosition(i);
if (idx == index)
continue;
auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
results.push_back(targetExpr);
}
return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
}
static Value reshapeLoad(Location loc, Value val, VectorType type,
int64_t index, int64_t pos,
PatternRewriter &rewriter) {
if (index == -1)
return val;
if (index == 0)
return rewriter.create<vector::ExtractOp>(loc, val, pos);
VectorType vType = VectorType::Builder(type).dropDim(0);
VectorType resType = VectorType::Builder(type).dropDim(index);
Value result = rewriter.create<arith::ConstantOp>(
loc, resType, rewriter.getZeroAttr(resType));
for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
Value ext = rewriter.create<vector::ExtractOp>(loc, val, d);
Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
result = rewriter.create<vector::InsertOp>(loc, load, result, d);
}
return result;
}
static Value reshapeStore(Location loc, Value val, Value result,
VectorType type, int64_t index, int64_t pos,
PatternRewriter &rewriter) {
if (index == -1)
return val;
if (index == 0)
return rewriter.create<vector::InsertOp>(loc, val, result, pos);
VectorType vType = VectorType::Builder(type).dropDim(0);
for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
Value ext = rewriter.create<vector::ExtractOp>(loc, result, d);
Value ins = rewriter.create<vector::ExtractOp>(loc, val, d);
Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
result = rewriter.create<vector::InsertOp>(loc, sto, result, d);
}
return result;
}
static std::optional<Value>
createContractArithOp(Location loc, Value x, Value y, Value acc,
vector::CombiningKind kind, PatternRewriter &rewriter,
bool isInt, Value mask = Value()) {
using vector::CombiningKind;
Value mul;
if (isInt) {
if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
return std::nullopt;
mul = rewriter.create<arith::MulIOp>(loc, x, y);
} else {
if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
kind == CombiningKind::XOR)
return std::nullopt;
if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) {
Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
if (mask)
fma = selectPassthru(rewriter, mask, fma, acc);
return fma;
}
mul = rewriter.create<arith::MulFOp>(loc, x, y);
}
if (!acc)
return std::optional<Value>(mul);
return makeArithReduction(rewriter, loc, kind, mul, acc,
nullptr, mask);
}
static SmallVector<int64_t> getReductionIndex(AffineMap map,
ArrayAttr iteratorTypes) {
SmallVector<int64_t> dimsIdx;
for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
dimsIdx.push_back(i);
}
return dimsIdx;
}
static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
if (map.getDimPosition(i) == dim)
return i;
}
return std::nullopt;
}
static Value createAdd(Location loc, Value x, Value y, bool isInt,
PatternRewriter &rewriter) {
if (isInt)
return rewriter.create<arith::AddIOp>(loc, x, y);
return rewriter.create<arith::AddFOp>(loc, x, y);
}
static Value createMul(Location loc, Value x, Value y, bool isInt,
PatternRewriter &rewriter) {
if (isInt)
return rewriter.create<arith::MulIOp>(loc, x, y);
return rewriter.create<arith::MulFOp>(loc, x, y);
}
namespace {
class ContractionOpToMatmulOpLowering
: public vector::MaskableOpRewritePattern<vector::ContractionOp> {
public:
using MaskableOpRewritePattern::MaskableOpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
static LogicalResult defaultFilter(vector::ContractionOp op) {
return success();
}
ContractionOpToMatmulOpLowering(
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions),
filter(std::move(constraint)) {}
FailureOr<Value>
matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
PatternRewriter &rewriter) const override;
private:
vector::VectorTransformsOptions vectorTransformOptions;
FilterConstraintType filter;
};
class ContractionOpToOuterProductOpLowering
: public MaskableOpRewritePattern<vector::ContractionOp> {
public:
using MaskableOpRewritePattern::MaskableOpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
static LogicalResult defaultFilter(vector::ContractionOp op) {
return success();
}
ContractionOpToOuterProductOpLowering(
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions),
filter(std::move(constraint)) {}
FailureOr<Value>
matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
PatternRewriter &rewriter) const override;
private:
vector::VectorTransformsOptions vectorTransformOptions;
FilterConstraintType filter;
};
class ContractionOpToDotLowering
: public MaskableOpRewritePattern<vector::ContractionOp> {
public:
using MaskableOpRewritePattern::MaskableOpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
static LogicalResult defaultFilter(vector::ContractionOp op) {
return success();
}
ContractionOpToDotLowering(
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1,
const FilterConstraintType &constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
FailureOr<Value>
matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
PatternRewriter &rewriter) const override;
private:
vector::VectorTransformsOptions vectorTransformOptions;
FilterConstraintType filter;
};
class ContractionOpLowering
: public MaskableOpRewritePattern<vector::ContractionOp> {
public:
using MaskableOpRewritePattern::MaskableOpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
static LogicalResult defaultFilter(vector::ContractionOp op) {
return success();
}
ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions),
filter(std::move(constraint)) {}
FailureOr<Value>
matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
PatternRewriter &rewriter) const override;
private:
vector::VectorTransformsOptions vectorTransformOptions;
FilterConstraintType filter;
FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
vector::ContractionOp op, int64_t lhsIndex,
int64_t rhsIndex, Value mask) const;
FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
vector::ContractionOp op, Value mask) const;
};
struct UnrolledOuterProductGenerator
: public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
: StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
res(op.getAcc()), lhsType(op.getLhsType()) {
auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
if (maskableOp.isMasked())
mask = maskableOp.getMaskingOp().getMask();
}
Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
if (!v)
return v;
return rewriter.create<vector::TransposeOp>(loc, v, perm);
}
Value promote(Value v, Type dstElementType) {
Type elementType = v.getType();
auto vecType = dyn_cast<VectorType>(elementType);
if (vecType)
elementType = vecType.getElementType();
if (elementType == dstElementType)
return v;
Type promotedType = dstElementType;
if (vecType)
promotedType = vecType.clone(promotedType);
if (isa<FloatType>(dstElementType))
return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
}
FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
VectorType lhsType, int reductionSize,
std::optional<Value> maybeMask = std::nullopt) {
if (mask && !maybeMask.has_value())
return failure();
Type resElementType = cast<VectorType>(res.getType()).getElementType();
for (int64_t k = 0; k < reductionSize; ++k) {
Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
extractA = promote(extractA, resElementType);
extractB = promote(extractB, resElementType);
Value extractMask;
if (maybeMask.has_value() && maybeMask.value())
extractMask =
rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);
Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
loc, res.getType(), extractA, extractB, res, kind);
res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
}
return res;
}
std::optional<int64_t> getReductionSize(VectorType vecType,
int64_t reductionDim) {
if (vecType.getScalableDims()[reductionDim])
return std::nullopt;
int64_t reductionSize = vecType.getDimSize(reductionDim);
assert(reductionSize > 0 &&
"Reduction dim must be a known static size to allow unrolling");
return reductionSize;
}
FailureOr<Value> matmat() {
if (!iters({Par(), Par(), Red()}))
return failure();
AffineExpr m, n, k;
bindDims(rewriter.getContext(), m, n, k);
if (layout({{m, k}, {k, n}, {m, n}})) {
if (auto reductionSize = getReductionSize(lhsType, 1)) {
Value tLhs = t(lhs);
Value tMask = t(mask, {2, 0, 1});
return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
}
}
if (layout({{m, k}, {n, k}, {m, n}})) {
if (auto reductionSize = getReductionSize(lhsType, 1)) {
Value tLhs = t(lhs);
Value tRhs = t(rhs);
Value tMask = t(mask, {2, 0, 1});
return outerProd(tLhs, tRhs, res, lhsType, *reductionSize, tMask);
}
}
if (layout({{k, m}, {k, n}, {m, n}})) {
if (auto reductionSize = getReductionSize(lhsType, 0)) {
Value tMask = t(mask, {2, 0, 1});
return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
}
}
if (layout({{k, m}, {n, k}, {m, n}})) {
if (auto reductionSize = getReductionSize(lhsType, 0)) {
Value tRhs = t(rhs);
Value tMask = t(mask, {2, 0, 1});
return outerProd(lhs, tRhs, res, lhsType, *reductionSize, tMask);
}
}
if (layout({{m, k}, {k, n}, {n, m}})) {
if (auto reductionSize = getReductionSize(lhsType, 1)) {
Value tLhs = t(lhs);
Value tMask = t(mask, {2, 0, 1});
return outerProd(rhs, tLhs, res, lhsType, *reductionSize, tMask);
}
}
if (layout({{m, k}, {n, k}, {n, m}})) {
if (auto reductionSize = getReductionSize(lhsType, 1)) {
Value tRhs = t(rhs);
Value tLhs = t(lhs);
Value tMask = t(mask, {2, 0, 1});
return outerProd(tRhs, tLhs, res, lhsType, *reductionSize, tMask);
}
}
if (layout({{k, m}, {k, n}, {n, m}})) {
if (auto reductionSize = getReductionSize(lhsType, 0)) {
Value tMask = t(mask, {2, 0, 1});
return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
}
}
if (layout({{k, m}, {n, k}, {n, m}})) {
if (auto reductionSize = getReductionSize(lhsType, 0)) {
Value tRhs = t(rhs);
Value tMask = t(mask, {2, 0, 1});
return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
}
}
return failure();
}
FailureOr<Value> matvec() {
if (!iters({Par(), Red()}))
return failure();
AffineExpr m, k;
bindDims(rewriter.getContext(), m, k);
if (layout({{m, k}, {k}, {m}})) {
if (auto reductionSize = getReductionSize(lhsType, 1)) {
Value tLhs = t(lhs);
Value tMask = t(mask);
return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
}
}
if (layout({{k, m}, {k}, {m}})) {
if (auto reductionSize = getReductionSize(lhsType, 0)) {
Value tMask = t(mask);
return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
}
}
if (layout({{k}, {m, k}, {m}})) {
if (auto reductionSize = getReductionSize(lhsType, 0)) {
Value tRhs = t(rhs);
Value tMask = t(mask);
return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
}
}
if (layout({{k}, {k, m}, {m}})) {
if (auto reductionSize = getReductionSize(lhsType, 0)) {
Value tMask = t(mask);
return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
}
}
return failure();
}
FailureOr<Value> tmatvec() {
if (!iters({Red(), Par()}))
return failure();
AffineExpr k, m;
bindDims(rewriter.getContext(), k, m);
if (layout({{m, k}, {k}, {m}}))
if (auto reductionSize = getReductionSize(lhsType, 1))
return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
if (layout({{k, m}, {k}, {m}}))
if (auto reductionSize = getReductionSize(lhsType, 0))
return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
if (layout({{k}, {m, k}, {m}}))
if (auto reductionSize = getReductionSize(lhsType, 0))
return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
if (layout({{k}, {k, m}, {m}}))
if (auto reductionSize = getReductionSize(lhsType, 0))
return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
return failure();
}
private:
vector::CombiningKind kind;
Value lhs, rhs, res, mask;
VectorType lhsType;
};
FailureOr<Value>
ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
vector::ContractionOp op, MaskingOpInterface maskOp,
PatternRewriter &rewriter) const {
if (vectorTransformOptions.vectorContractLowering !=
vector::VectorContractLowering::OuterProduct)
return failure();
if (failed(filter(op)))
return failure();
UnrolledOuterProductGenerator e(rewriter, op);
FailureOr<Value> matmatRes = e.matmat();
if (succeeded(matmatRes)) {
return matmatRes;
}
FailureOr<Value> matvecRes = e.matvec();
if (succeeded(matvecRes)) {
return matvecRes;
}
FailureOr<Value> tmatvecRes = e.tmatvec();
return tmatvecRes;
}
FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
vector::ContractionOp op, MaskingOpInterface maskOp,
PatternRewriter &rewriter) const {
if (maskOp)
return failure();
if (failed(filter(op)))
return failure();
if (vectorTransformOptions.vectorContractLowering !=
vector::VectorContractLowering::Dot)
return failure();
auto iteratorTypes = op.getIteratorTypes().getValue();
static constexpr std::array<int64_t, 2> perm = {1, 0};
Location loc = op.getLoc();
Value lhs = op.getLhs(), rhs = op.getRhs();
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
auto infer = [&](MapList m) {
return AffineMap::inferFromExprList(m, op.getContext());
};
AffineExpr m, n, k;
bindDims(rewriter.getContext(), m, n, k);
SmallVector<AffineMap> maps = op.getIndexingMapsArray();
if (isParallelIterator(iteratorTypes[0]) &&
isParallelIterator(iteratorTypes[1]) &&
isReductionIterator(iteratorTypes[2])) {
if (maps == infer({{m, k}, {k, n}, {m, n}})) {
rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
} else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
} else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
} else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
} else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
Value tmp = lhs;
lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
rhs = tmp;
} else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
std::swap(lhs, rhs);
} else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
Value tmp = lhs;
lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
} else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
Value tmp = rhs;
rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
lhs = tmp;
} else {
return failure();
}
} else if (isParallelIterator(iteratorTypes[0]) &&
isReductionIterator(iteratorTypes[1])) {
if (maps == infer({{m, n}, {n}, {m}})) {
} else if (maps == infer({{n, m}, {n}, {m}})) {
lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
} else if (maps == infer({{n}, {m, n}, {m}})) {
std::swap(lhs, rhs);
} else if (maps == infer({{n}, {n, m}, {m}})) {
std::swap(lhs, rhs);
lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
} else {
return failure();
}
} else {
return failure();
}
VectorType dstType = cast<VectorType>(op.getResultType());
assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
"Expected dst type of rank 1 or 2");
unsigned rank = dstType.getRank();
unsigned dstRows = dstType.getShape()[0];
unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
rewriter.getZeroAttr(dstType));
bool isInt = isa<IntegerType>(dstType.getElementType());
for (unsigned r = 0; r < dstRows; ++r) {
Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
for (unsigned c = 0; c < dstColumns; ++c) {
Value b = rank == 1
? rhs
: rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
Value reduced = rewriter.create<vector::ReductionOp>(
op.getLoc(), vector::CombiningKind::ADD, m);
SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
: SmallVector<int64_t, 2>{r, c};
res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
}
}
if (auto acc = op.getAcc())
res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
return res;
}
struct ContractOpToElementwise
: public MaskableOpRewritePattern<vector::ContractionOp> {
using MaskableOpRewritePattern::MaskableOpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
static LogicalResult defaultFilter(vector::ContractionOp op) {
return success();
}
ContractOpToElementwise(
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1,
const FilterConstraintType &constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
FailureOr<Value>
matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
MaskingOpInterface maskOp,
PatternRewriter &rewriter) const override {
if (maskOp)
return failure();
if (failed(filter(contractOp)))
return failure();
if (vectorTransformOptions.vectorContractLowering !=
vector::VectorContractLowering::ParallelArith)
return failure();
ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
SmallVector<int64_t> lhsReductionDims =
getReductionIndex(lhsMap, contractOp.getIteratorTypes());
SmallVector<int64_t> rhsReductionDims =
getReductionIndex(rhsMap, contractOp.getIteratorTypes());
for (int64_t dim : lhsReductionDims) {
if (lhsShape[dim] != 1)
return failure();
}
for (int64_t dim : rhsReductionDims) {
if (rhsShape[dim] != 1)
return failure();
}
AffineMap accMap = contractOp.getIndexingMapsArray()[2];
unsigned numParallelDims = accMap.getNumResults();
unsigned numLhsDimToBroadcast =
numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
unsigned numRhsDimToBroadcast =
numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
SmallVector<int64_t> lhsDims;
SmallVector<int64_t> lhsTranspose;
SmallVector<int64_t> rhsDims;
SmallVector<int64_t> rhsTranspose;
for (int64_t dim : lhsReductionDims)
lhsTranspose.push_back(numLhsDimToBroadcast + dim);
for (int64_t dim : rhsReductionDims)
rhsTranspose.push_back(numRhsDimToBroadcast + dim);
for (unsigned i = 0; i < numParallelDims; i++) {
std::optional<unsigned> lhsDim =
getDimPosition(lhsMap, accMap.getDimPosition(i));
if (lhsDim) {
lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
} else {
lhsDims.push_back(
cast<VectorType>(contractOp.getResultType()).getDimSize(i));
lhsTranspose.push_back(lhsDims.size() - 1);
}
std::optional<unsigned> rhsDim =
getDimPosition(rhsMap, accMap.getDimPosition(i));
if (rhsDim) {
rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
} else {
rhsDims.push_back(
cast<VectorType>(contractOp.getResultType()).getDimSize(i));
rhsTranspose.push_back(rhsDims.size() - 1);
}
}
Value newLhs = contractOp.getLhs();
Value newRhs = contractOp.getRhs();
Location loc = contractOp.getLoc();
if (!lhsDims.empty()) {
lhsDims.append(lhsShape.begin(), lhsShape.end());
auto expandedType =
VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
}
if (!rhsDims.empty()) {
rhsDims.append(rhsShape.begin(), rhsShape.end());
auto expandedType =
VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
}
bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
newLhs = rewriter.create<vector::ExtractOp>(loc, newLhs, lhsOffsets);
newRhs = rewriter.create<vector::ExtractOp>(loc, newRhs, rhsOffsets);
std::optional<Value> result =
createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
contractOp.getKind(), rewriter, isInt);
if (result)
return *result;
return failure();
}
private:
vector::VectorTransformsOptions vectorTransformOptions;
FilterConstraintType filter;
};
FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
vector::ContractionOp op, MaskingOpInterface maskOp,
PatternRewriter &rewriter) const {
if (failed(filter(op)))
return failure();
if (op.getLhsType().getElementType() !=
getElementTypeOrSelf(op.getAccType()) ||
op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
return failure();
if (op.getKind() != vector::CombiningKind::ADD) {
return rewriter.notifyMatchFailure(
op, "contractions other than 'add' not supported");
}
MLIRContext *ctx = op.getContext();
ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
FailureOr<Value> newVal1 =
pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
if (!failed(newVal1))
return newVal1;
ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
FailureOr<Value> newVal2 =
pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
if (!failed(newVal2))
return newVal2;
ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
FailureOr<Value> newVal3 =
pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter);
if (!failed(newVal3))
return newVal3;
ContractOpToElementwise pat4(vectorTransformOptions, ctx);
FailureOr<Value> newVal4 =
pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
if (!failed(newVal4))
return newVal4;
Value mask;
if (maskOp)
mask = maskOp.getMask();
std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
if (!batchDimMap.empty()) {
int64_t lhsIndex = batchDimMap[0].first;
int64_t rhsIndex = batchDimMap[0].second;
auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
if (failed(newOp))
return failure();
return newOp;
}
std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
op.getContractingDimMap();
DenseSet<int64_t> lhsContractingDimSet;
DenseSet<int64_t> rhsContractingDimSet;
for (auto &dimPair : contractingDimMap) {
lhsContractingDimSet.insert(dimPair.first);
rhsContractingDimSet.insert(dimPair.second);
}
VectorType lhsType = op.getLhsType();
for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
if (lhsContractingDimSet.count(lhsIndex) == 0) {
auto newOp = lowerParallel(rewriter, op, lhsIndex, -1, mask);
if (failed(newOp))
return failure();
return newOp;
}
}
VectorType rhsType = op.getRhsType();
for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
if (rhsContractingDimSet.count(rhsIndex) == 0) {
auto newOp = lowerParallel(rewriter, op, -1, rhsIndex, mask);
if (failed(newOp))
return failure();
return newOp;
}
}
if (!contractingDimMap.empty()) {
auto newOp = lowerReduction(rewriter, op, mask);
if (failed(newOp))
return failure();
return newOp;
}
return failure();
}
FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
vector::ContractionOp op,
int64_t lhsIndex,
int64_t rhsIndex,
Value mask) const {
VectorType lhsType = op.getLhsType();
VectorType rhsType = op.getRhsType();
VectorType resType = cast<VectorType>(op.getResultType());
SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
int64_t iterIndex = -1;
int64_t dimSize = -1;
if (lhsIndex >= 0) {
iterIndex = iMap[0].getDimPosition(lhsIndex);
if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
<< " to map to the same dimension";
});
if (lhsType.getScalableDims()[lhsIndex])
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "Unrolling scalable dimension (lhsIndex=" << lhsIndex
<< ") is not supported yet";
});
dimSize = lhsType.getDimSize(lhsIndex);
} else if (rhsIndex >= 0) {
iterIndex = iMap[1].getDimPosition(rhsIndex);
if (rhsType.getScalableDims()[rhsIndex])
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "Unrolling scalable dimension (rhsIndex=" << rhsIndex
<< ") is not supported yet";
});
dimSize = rhsType.getDimSize(rhsIndex);
}
if (iterIndex < 0)
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "expected either lhsIndex=" << lhsIndex
<< " or rhsIndex=" << rhsIndex << " to be nonnegative";
});
int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
if (resIndex == -1 && dimSize != 1)
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "expected the dimension for iterIndex=" << iterIndex
<< " to either appear in the result map, or to be a unit dimension";
});
std::array<AffineMap, 3> lowIndexingMaps = {
adjustMap(iMap[0], iterIndex, rewriter),
adjustMap(iMap[1], iterIndex, rewriter),
adjustMap(iMap[2], iterIndex, rewriter)};
auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
auto lowIter =
rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
Location loc = op.getLoc();
Value result = rewriter.create<arith::ConstantOp>(
loc, resType, rewriter.getZeroAttr(resType));
for (int64_t d = 0; d < dimSize; ++d) {
auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
Value lowMask;
if (mask)
lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
iterIndex, d, rewriter);
Operation *lowContract = rewriter.create<vector::ContractionOp>(
loc, lhs, rhs, acc, lowAffine, lowIter);
lowContract = maskOperation(rewriter, lowContract, lowMask);
result = reshapeStore(loc, lowContract->getResult(0), result, resType,
resIndex, d, rewriter);
}
return result;
}
FailureOr<Value> ContractionOpLowering::lowerReduction(
PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
auto loc = op.getLoc();
VectorType lhsType = op.getLhsType();
VectorType rhsType = op.getRhsType();
Type resType = op.getResultType();
if (isa<VectorType>(resType))
return rewriter.notifyMatchFailure(op,
"did not expect a VectorType result");
bool isInt = isa<IntegerType>(resType);
int64_t iterIndex = 0;
SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
std::optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
std::optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
if (!lookupLhs.has_value())
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension";
});
if (!lookupRhs.has_value())
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension";
});
int64_t lhsIndex = *lookupLhs;
int64_t rhsIndex = *lookupRhs;
int64_t dimSize = lhsType.getDimSize(lhsIndex);
if (dimSize != rhsType.getDimSize(rhsIndex))
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "expect LHS dimension " << lhsIndex
<< " to have the same size as RHS dimension " << rhsIndex;
});
if (lhsType.getRank() == 1) {
if (rhsType.getRank() != 1)
return rewriter.notifyMatchFailure(
op, "When LHS has rank 1, expected also RHS to have rank 1");
Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
auto kind = vector::CombiningKind::ADD;
Value acc = op.getAcc();
Operation *reductionOp =
acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
: rewriter.create<vector::ReductionOp>(loc, kind, m);
return maskOperation(rewriter, reductionOp, mask)->getResult(0);
}
std::array<AffineMap, 3> lowIndexingMaps = {
adjustMap(iMap[0], iterIndex, rewriter),
adjustMap(iMap[1], iterIndex, rewriter),
adjustMap(iMap[2], iterIndex, rewriter)};
auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
auto lowIter =
rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
Value result = op.getAcc();
for (int64_t d = 0; d < dimSize; ++d) {
auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
Value newMask;
if (mask)
newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
iterIndex, d, rewriter);
Operation *newContract = rewriter.create<vector::ContractionOp>(
loc, lhs, rhs, result, lowAffine, lowIter);
result = maskOperation(rewriter, newContract, newMask)->getResult(0);
}
return result;
}
class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::OuterProductOp op,
PatternRewriter &rewriter) const override {
VectorType resType = op.getResultVectorType();
if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
return failure();
auto loc = op.getLoc();
VectorType lhsType = op.getOperandVectorTypeLHS();
VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
Type eltType = resType.getElementType();
bool isInt = isa<IntegerType, IndexType>(eltType);
Value acc = op.getAcc();
vector::CombiningKind kind = op.getKind();
OpBuilder::InsertionGuard guard(rewriter);
auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
Operation *rootOp;
Value mask;
if (maskableOp.isMasked()) {
rewriter.setInsertionPoint(maskableOp.getMaskingOp());
rootOp = maskableOp.getMaskingOp();
mask = maskableOp.getMaskingOp().getMask();
} else {
rootOp = op;
}
if (!rhsType) {
Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
std::optional<Value> mult = createContractArithOp(
loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
if (!mult.has_value())
return failure();
rewriter.replaceOp(rootOp, *mult);
return success();
}
Value result = rewriter.create<arith::ConstantOp>(
loc, resType, rewriter.getZeroAttr(resType));
for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), d);
Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
Value r = nullptr;
if (acc)
r = rewriter.create<vector::ExtractOp>(loc, acc, d);
Value extrMask;
if (mask)
extrMask = rewriter.create<vector::ExtractOp>(loc, mask, d);
std::optional<Value> m = createContractArithOp(
loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
if (!m.has_value())
return failure();
result = rewriter.create<vector::InsertOp>(loc, *m, result, d);
}
rewriter.replaceOp(rootOp, result);
return success();
}
};
FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
vector::ContractionOp op, MaskingOpInterface maskOp,
PatternRewriter &rew) const {
if (maskOp)
return failure();
if (vectorTransformOptions.vectorContractLowering !=
vector::VectorContractLowering::Matmul)
return failure();
if (failed(filter(op)))
return failure();
auto iteratorTypes = op.getIteratorTypes().getValue();
if (!isParallelIterator(iteratorTypes[0]) ||
!isParallelIterator(iteratorTypes[1]) ||
!isReductionIterator(iteratorTypes[2]))
return failure();
Type elementType = op.getLhsType().getElementType();
if (!elementType.isIntOrFloat())
return failure();
Type dstElementType = op.getType();
if (auto vecType = dyn_cast<VectorType>(dstElementType))
dstElementType = vecType.getElementType();
if (elementType != dstElementType)
return failure();
MLIRContext *ctx = op.getContext();
Location loc = op.getLoc();
AffineExpr m, n, k;
bindDims(rew.getContext(), m, n, k);
Value lhs = op.getLhs();
auto lhsMap = op.getIndexingMapsArray()[0];
if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
return failure();
Value rhs = op.getRhs();
auto rhsMap = op.getIndexingMapsArray()[1];
if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
return failure();
VectorType lhsType = cast<VectorType>(lhs.getType());
VectorType rhsType = cast<VectorType>(rhs.getType());
int64_t lhsRows = lhsType.getDimSize(0);
int64_t lhsColumns = lhsType.getDimSize(1);
int64_t rhsColumns = rhsType.getDimSize(1);
Type flattenedLHSType =
VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);
Type flattenedRHSType =
VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
rhsColumns);
mul = rew.create<vector::ShapeCastOp>(
loc,
VectorType::get({lhsRows, rhsColumns},
getElementTypeOrSelf(op.getAcc().getType())),
mul);
auto accMap = op.getIndexingMapsArray()[2];
if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
llvm_unreachable("invalid contraction semantics");
Value res =
isa<IntegerType>(elementType)
? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
: static_cast<Value>(
rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
return res;
}
}
void mlir::vector::populateVectorContractLoweringPatterns(
RewritePatternSet &patterns, VectorTransformsOptions options,
PatternBenefit benefit, bool disableOuterProductLowering) {
if (!disableOuterProductLowering)
patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
ContractionOpToOuterProductOpLowering>(
options, patterns.getContext(), benefit);
}
void mlir::vector::populateVectorOuterProductLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
}