#include <numeric>
#include <optional>
#include <type_traits>
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOSCF
#include "mlir/Conversion/Passes.h.inc"
}
using namespace mlir;
using vector::TransferReadOp;
using vector::TransferWriteOp;
namespace {
/// Name of the unit attribute that the patterns in this file attach to
/// transfer ops they have prepared for, or created during, the loop-based
/// lowering. Patterns test for this attribute to decide whether an op is theirs
/// to handle.
static const char kPassLabel[] = "__vector_to_scf_lowering__";
/// Common base of the rewrite patterns in this file: an OpRewritePattern over
/// `OpTy` that additionally carries a copy of the lowering options.
template <typename OpTy>
struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
  /// Lowering options, copied when the pattern is constructed.
  VectorTransferToSCFOptions options;

  /// Builds the pattern for `ctx` and remembers `lowering` as its options.
  explicit VectorToSCFPattern(MLIRContext *ctx,
                              VectorTransferToSCFOptions lowering)
      : OpRewritePattern<OpTy>(ctx), options(lowering) {}
};
/// Returns the position of the source dimension referenced by the first result
/// of `xferOp`'s permutation map. Returns std::nullopt when that result is not
/// a dimension expression, in which case it is asserted to be a broadcast.
/// 0-d transfers are not supported (asserted).
template <typename OpTy>
static std::optional<int64_t> unpackedDim(OpTy xferOp) {
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
  auto permutation = xferOp.getPermutationMap();
  auto leadingDim = dyn_cast<AffineDimExpr>(permutation.getResult(0));
  if (!leadingDim) {
    assert(xferOp.isBroadcastDim(0) &&
           "Expected AffineDimExpr or AffineConstantExpr");
    return std::nullopt;
  }
  return leadingDim.getPosition();
}
/// Returns `xferOp`'s permutation map with its first result removed; the number
/// of dims is unchanged and the new map has no symbols. 0-d transfers are not
/// supported (asserted).
template <typename OpTy>
static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
  AffineMap permutation = xferOp.getPermutationMap();
  ArrayRef<AffineExpr> trailingResults = permutation.getResults().drop_front();
  return AffineMap::get(permutation.getNumDims(), /*symbolCount=*/0,
                        trailingResults, b.getContext());
}
/// Appends the indices of `xferOp` to `indices` and, unless the unpacked
/// dimension is a broadcast, replaces the index of that dimension with
/// `original index + iv` (built as a composed affine.apply).
///
/// `indices` is addressed by absolute position after the append, so callers in
/// this file pass it in empty.
template <typename OpTy>
static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
                           SmallVector<Value, 8> &indices) {
  typename OpTy::Adaptor adaptor(xferOp);
  std::optional<int64_t> memrefDim = unpackedDim(xferOp);

  auto originalIndices = adaptor.getIndices();
  indices.append(originalIndices.begin(), originalIndices.end());

  // Broadcast dimension: the indices are used unchanged.
  if (!memrefDim.has_value())
    return;

  AffineExpr d0, d1;
  bindDims(xferOp.getContext(), d0, d1);
  Value baseOffset = adaptor.getIndices()[*memrefDim];
  indices[*memrefDim] = affine::makeComposedAffineApply(
      b, xferOp.getLoc(), d0 + d1, {baseOffset, iv});
}
/// Creates an scf.yield at `loc`. The yield carries `value` (which must be
/// non-null) when `hasRetVal` is set, and carries nothing otherwise.
static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
                            Value value) {
  if (!hasRetVal) {
    b.create<scf::YieldOp>(loc);
    return;
  }
  assert(value && "Expected non-empty value");
  b.create<scf::YieldOp>(loc, value);
}
/// Extracts element `iv` of `xferOp`'s mask as a value usable in a condition.
///
/// Returns a null Value, creating no IR, when the op has no mask, when the
/// mask is not 1-D, or when the unpacked dimension is a broadcast.
template <typename OpTy>
static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
  // Short-circuit order matters: getMaskType() is only queried with a mask.
  const bool noPerElementMask = !xferOp.getMask() ||
                                xferOp.getMaskType().getRank() != 1 ||
                                xferOp.isBroadcastDim(0);
  if (noPerElementMask)
    return Value();

  return b.create<vector::ExtractElementOp>(xferOp.getLoc(), xferOp.getMask(),
                                            iv);
}
/// Guards the code produced by `inBoundsCase` with the checks that the
/// unpacked dimension of `xferOp` requires at iteration `iv`.
///
/// Two conditions may be generated and and-ed together:
///   * `source dim size > index + iv`, when dimension 0 of the transfer is not
///     marked in-bounds and `dim` holds a value (i.e. it is not a broadcast);
///   * element `iv` of the mask, when generateMaskCheck produces one.
///
/// If at least one condition exists, an scf.if is created: the "then" region
/// runs `inBoundsCase`; the "else" region runs `outOfBoundsCase` when given and
/// otherwise only yields. The result is the first scf.if result when
/// `resultTypes` is non-empty and a null Value otherwise. If no condition
/// exists, `inBoundsCase` is invoked directly on `b` and its value returned.
template <typename OpTy>
static Value generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
    TypeRange resultTypes,
    function_ref<Value(OpBuilder &, Location)> inBoundsCase,
    function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  const bool yieldsValue = !resultTypes.empty();
  Location xferLoc = xferOp.getLoc();
  ImplicitLocOpBuilder implicitB(xferOp.getLoc(), b);

  // Accumulated guard; stays null if no check is needed.
  Value guard;

  // Check 1: is the access within the bounds of the source dimension?
  if (!xferOp.isDimInBounds(0) && dim.has_value()) {
    Value dimSize =
        vector::createOrFoldDimOp(b, xferLoc, xferOp.getSource(), *dim);
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value startIdx = xferOp.getIndices()[*dim];
    Value accessIdx =
        affine::makeComposedAffineApply(b, xferLoc, d0 + d1, {startIdx, iv});
    guard = implicitB.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, dimSize,
                                            accessIdx);
  }

  // Check 2: is the element enabled by the mask?
  if (Value maskBit = generateMaskCheck(b, xferOp, iv)) {
    if (!guard)
      guard = maskBit;
    else
      guard = implicitB.create<arith::AndIOp>(guard, maskBit);
  }

  // Nothing to check: emit the in-bounds code unconditionally.
  if (!guard)
    return inBoundsCase(b, xferLoc);

  auto thenBuilder = [&](OpBuilder &nested, Location nestedLoc) {
    maybeYieldValue(nested, nestedLoc, yieldsValue,
                    inBoundsCase(nested, nestedLoc));
  };
  auto elseBuilder = [&](OpBuilder &nested, Location nestedLoc) {
    if (!outOfBoundsCase) {
      nested.create<scf::YieldOp>(nestedLoc);
      return;
    }
    maybeYieldValue(nested, nestedLoc, yieldsValue,
                    outOfBoundsCase(nested, nestedLoc));
  };
  auto ifOp = implicitB.create<scf::IfOp>(guard, thenBuilder, elseBuilder);
  return yieldsValue ? ifOp.getResult(0) : Value();
}
/// Overload of generateInBoundsCheck for callbacks that produce no value. It
/// forwards to the Value-returning overload with an empty result type list and
/// adapters that always return a null Value.
template <typename OpTy>
static void generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
    function_ref<void(OpBuilder &, Location)> inBoundsCase,
    function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  auto inBoundsAdapter = [&](OpBuilder &nested, Location nestedLoc) {
    inBoundsCase(nested, nestedLoc);
    return Value();
  };
  auto outOfBoundsAdapter = [&](OpBuilder &nested, Location nestedLoc) {
    // The out-of-bounds callback is optional.
    if (outOfBoundsCase)
      outOfBoundsCase(nested, nestedLoc);
    return Value();
  };
  generateInBoundsCheck(b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
                        inBoundsAdapter, outOfBoundsAdapter);
}
/// Returns a copy of `attr` without its first element. A null `attr` is
/// returned unchanged.
static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
  if (attr) {
    ArrayRef<Attribute> tail = attr.getValue().drop_front();
    return ArrayAttr::get(b.getContext(), tail);
  }
  return attr;
}
/// Attaches the kPassLabel unit attribute to `newXferOp` if the rank of its
/// vector type is still greater than `targetRank`.
template <typename OpTy>
static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
                                unsigned targetRank) {
  const bool aboveTargetRank =
      newXferOp.getVectorType().getRank() > targetRank;
  if (!aboveTargetRank)
    return;
  newXferOp->setAttr(kPassLabel, b.getUnitAttr());
}
/// Returns true if the shaped operand of `xferOp` is a ranked tensor. For a
/// TransferWriteOp on a tensor it additionally asserts that the op has a
/// result.
template <typename OpTy>
static bool isTensorOp(OpTy xferOp) {
  if (!isa<RankedTensorType>(xferOp.getShapedType()))
    return false;

  if (xferOp.getOperationName() == TransferWriteOp::getOperationName()) {
    assert(xferOp->getNumResults() > 0);
  }
  return true;
}
namespace lowering_n_d {
/// Values produced by allocBuffers for one transfer op.
struct BufferAllocs {
// 0-d memref (element type: the transfer op's vector type) created with
// memref.alloca.
Value dataBuffer;
// Result of the memref.load that reads the mask back from its own 0-d alloca;
// left null when the transfer op has no mask.
Value maskBuffer;
};
/// Returns the closest ancestor of `op` that has the AutomaticAllocationScope
/// trait. Asserts that such an ancestor exists.
static Operation *getAutomaticAllocationScope(Operation *op) {
  Operation *enclosingScope =
      op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  assert(enclosingScope &&
         "Expected op to be inside automatic allocation scope");
  return enclosingScope;
}
/// Allocates the temporary buffers used to lower `xferOp`.
///
/// A 0-d memref holding the op's vector type is alloca'd at the start of the
/// first block of the enclosing automatic allocation scope. If the op is
/// masked, a second 0-d memref for the mask is alloca'd in the same place, the
/// mask is stored into it right before `xferOp`, and loaded back; the loaded
/// value is returned as `maskBuffer`. The builder's insertion point is restored
/// before returning.
template <typename OpTy>
static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
  OpBuilder::InsertionGuard restoreInsertionPoint(b);
  Location loc = xferOp.getLoc();

  Operation *allocScope = getAutomaticAllocationScope(xferOp);
  assert(allocScope->getNumRegions() == 1 &&
         "AutomaticAllocationScope with >1 regions");
  Block &entryBlock = allocScope->getRegion(0).front();
  b.setInsertionPointToStart(&entryBlock);

  BufferAllocs allocs;
  auto dataType = MemRefType::get({}, xferOp.getVectorType());
  allocs.dataBuffer = b.create<memref::AllocaOp>(loc, dataType);

  if (!xferOp.getMask())
    return allocs;

  auto maskMemRefType = MemRefType::get({}, xferOp.getMask().getType());
  auto maskAlloca = b.create<memref::AllocaOp>(loc, maskMemRefType);
  // The store/load pair lives next to the transfer op, not at the alloca.
  b.setInsertionPoint(xferOp);
  b.create<memref::StoreOp>(loc, xferOp.getMask(), maskAlloca);
  allocs.maskBuffer = b.create<memref::LoadOp>(loc, maskAlloca, ValueRange());
  return allocs;
}
/// Moves the leading vector dimension of `type`'s element type into the memref
/// shape: the size of vector dim 0 is appended to the memref shape and dim 0 is
/// dropped from the vector type. E.g. the shape/element pair
/// `[5, 9] x vector<3x4xf32>` becomes `[5, 9, 3] x vector<4xf32>`.
///
/// Returns failure if the element type is not a vector, if the vector is 0-d
/// (there is no dimension to unpack), or if the leading vector dimension is
/// scalable.
static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
  auto vectorType = dyn_cast<VectorType>(type.getElementType());
  // Guard the accesses below: a failed dyn_cast yields a null type, and a 0-d
  // vector has an empty scalable-dims array.
  if (!vectorType || vectorType.getRank() == 0)
    return failure();
  if (vectorType.getScalableDims().front())
    return failure();

  ArrayRef<int64_t> memrefShape = type.getShape();
  SmallVector<int64_t, 8> newMemrefShape(memrefShape.begin(),
                                         memrefShape.end());
  newMemrefShape.push_back(vectorType.getDimSize(0));

  VectorType unpackedVectorType = VectorType::Builder(vectorType).dropDim(0);
  return MemRefType::get(newMemrefShape, unpackedVectorType);
}
/// Returns the memref that `xferOp`'s mask is loaded from. Asserts that the op
/// has a mask and that the mask is produced by a memref.load.
template <typename OpTy>
static Value getMaskBuffer(OpTy xferOp) {
  assert(xferOp.getMask() && "Expected that transfer op has mask");
  auto maskLoad = xferOp.getMask().template getDefiningOp<memref::LoadOp>();
  assert(maskLoad && "Expected transfer op mask produced by LoadOp");
  Value maskMemRef = maskLoad.getMemRef();
  return maskMemRef;
}
/// Op-specific code generation hooks used by TransferOpConversion. Only the
/// explicit specializations for TransferReadOp and TransferWriteOp are defined.
template <typename OpTy>
struct Strategy;
template <>
struct Strategy<TransferReadOp> {
static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
assert(storeOp && "Expected TransferReadOp result used by StoreOp");
return storeOp;
}
static Value getBuffer(TransferReadOp xferOp) {
return getStoreOp(xferOp).getMemRef();
}
static void getBufferIndices(TransferReadOp xferOp,
SmallVector<Value, 8> &indices) {
auto storeOp = getStoreOp(xferOp);
auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
indices.append(prevIndices.begin(), prevIndices.end());
}
static TransferReadOp rewriteOp(OpBuilder &b,
VectorTransferToSCFOptions options,
TransferReadOp xferOp, Value buffer, Value iv,
ValueRange ) {
SmallVector<Value, 8> storeIndices;
getBufferIndices(xferOp, storeIndices);
storeIndices.push_back(iv);
SmallVector<Value, 8> xferIndices;
getXferIndices(b, xferOp, iv, xferIndices);
Location loc = xferOp.getLoc();
auto bufferType = dyn_cast<ShapedType>(buffer.getType());
auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
auto newXferOp = b.create<vector::TransferReadOp>(
loc, vecType, xferOp.getSource(), xferIndices,
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
xferOp.getPadding(), Value(), inBoundsAttr);
maybeApplyPassLabel(b, newXferOp, options.targetRank);
b.create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices);
return newXferOp;
}
static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
Value buffer, Value iv,
ValueRange ) {
SmallVector<Value, 8> storeIndices;
getBufferIndices(xferOp, storeIndices);
storeIndices.push_back(iv);
Location loc = xferOp.getLoc();
auto bufferType = dyn_cast<ShapedType>(buffer.getType());
auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
return Value();
}
static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
scf::ForOp ) {
rewriter.eraseOp(getStoreOp(xferOp));
rewriter.eraseOp(xferOp);
}
static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
};
template <>
struct Strategy<TransferWriteOp> {
static Value getBuffer(TransferWriteOp xferOp) {
auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
assert(loadOp && "Expected transfer op vector produced by LoadOp");
return loadOp.getMemRef();
}
static void getBufferIndices(TransferWriteOp xferOp,
SmallVector<Value, 8> &indices) {
auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
indices.append(prevIndices.begin(), prevIndices.end());
}
static TransferWriteOp rewriteOp(OpBuilder &b,
VectorTransferToSCFOptions options,
TransferWriteOp xferOp, Value buffer,
Value iv, ValueRange loopState) {
SmallVector<Value, 8> loadIndices;
getBufferIndices(xferOp, loadIndices);
loadIndices.push_back(iv);
SmallVector<Value, 8> xferIndices;
getXferIndices(b, xferOp, iv, xferIndices);
Location loc = xferOp.getLoc();
auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
auto source = loopState.empty() ? xferOp.getSource() : loopState[0];
Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
auto newXferOp = b.create<vector::TransferWriteOp>(
loc, type, vec, source, xferIndices,
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
inBoundsAttr);
maybeApplyPassLabel(b, newXferOp, options.targetRank);
return newXferOp;
}
static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
Value buffer, Value iv,
ValueRange loopState) {
return isTensorOp(xferOp) ? loopState[0] : Value();
}
static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
scf::ForOp forOp) {
if (isTensorOp(xferOp)) {
assert(forOp->getNumResults() == 1 && "Expected one for loop result");
rewriter.replaceOp(xferOp, forOp->getResult(0));
} else {
rewriter.eraseOp(xferOp);
}
}
static Value initialLoopState(TransferWriteOp xferOp) {
return isTensorOp(xferOp) ? xferOp.getSource() : Value();
}
};
/// Decides whether `xferOp` should be prepared for the loop-based lowering.
///
/// Returns failure if the op already carries the pass label, if its vector
/// rank does not exceed `options.targetRank`, if its leading vector dimension
/// is scalable, if it operates on a tensor while `options.lowerTensors` is
/// off, or if the vector and shaped element types differ. Returns success
/// otherwise.
template <typename OpTy>
LogicalResult checkPrepareXferOp(OpTy xferOp,
                                 VectorTransferToSCFOptions options) {
  if (xferOp->hasAttr(kPassLabel))
    return failure();

  auto vecType = xferOp.getVectorType();
  // The rank test must come first: it keeps front() off an empty array.
  if (vecType.getRank() <= options.targetRank ||
      vecType.getScalableDims().front())
    return failure();

  if (isTensorOp(xferOp) && !options.lowerTensors)
    return failure();

  const bool sameElementType =
      vecType.getElementType() == xferOp.getShapedType().getElementType();
  if (!sameElementType)
    return failure();
  return success();
}
/// Prepares a TransferReadOp for the loop-based lowering.
///
/// The op is cloned and the clone is labeled with kPassLabel; the clone's
/// result is stored into a freshly allocated data buffer, and the original op
/// is replaced by a load from that buffer. If the op is masked, the clone's
/// mask is replaced by the value re-loaded from the mask buffer.
struct PrepareTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (failed(checkPrepareXferOp(xferOp, options)))
      return failure();

    auto buffers = allocBuffers(rewriter, xferOp);
    // The clone of a TransferReadOp is a TransferReadOp, so use `cast`: it
    // asserts on a mismatch, whereas an unchecked dyn_cast would hand back a
    // null op to call into.
    auto newXfer =
        cast<TransferReadOp>(rewriter.clone(*xferOp.getOperation()));
    newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
    if (xferOp.getMask())
      newXfer.getMaskMutable().assign(buffers.maskBuffer);

    Location loc = xferOp.getLoc();
    rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
                                     buffers.dataBuffer);
    rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);

    return success();
  }
};
/// Prepares a TransferWriteOp for the loop-based lowering.
///
/// The vector to be written is stored into a freshly allocated data buffer and
/// loaded back; the op is updated in place to write the loaded vector and is
/// labeled with kPassLabel. If the op is masked, its mask is replaced by the
/// value re-loaded from the mask buffer.
struct PrepareTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (failed(checkPrepareXferOp(xferOp, options)))
      return failure();

    Location loc = xferOp.getLoc();
    BufferAllocs allocs = allocBuffers(rewriter, xferOp);

    // Route the vector operand through the data buffer.
    rewriter.create<memref::StoreOp>(loc, xferOp.getVector(),
                                     allocs.dataBuffer);
    auto reloadedVec = rewriter.create<memref::LoadOp>(loc, allocs.dataBuffer);
    rewriter.modifyOpInPlace(xferOp, [&]() {
      xferOp.getVectorMutable().assign(reloadedVec);
      xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
    });

    if (!xferOp.getMask())
      return success();

    rewriter.modifyOpInPlace(xferOp, [&]() {
      xferOp.getMaskMutable().assign(allocs.maskBuffer);
    });
    return success();
  }
};
/// Decomposes a vector.print of a vector value into nested scf.for loops that
/// print one scalar element at a time, surrounded by open/close punctuation
/// per dimension and separated by commas.
///
/// The pattern does not apply to prints without a source, prints of non-vector
/// types, or scalable vectors of rank > 1. Integer element types are first
/// extended to a power-of-two width of at least 8 bits. Vectors whose rank is
/// not 1 are flattened with vector.shape_cast so that elements can be
/// extracted with a single dynamic index.
struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
  using VectorToSCFPattern<vector::PrintOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(vector::PrintOp printOp,
                                PatternRewriter &rewriter) const override {
    if (!printOp.getSource())
      return failure();

    VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
    if (!vectorType)
      return failure();

    // Scalable vectors are only handled when they are at most 1-D.
    if (vectorType.getRank() > 1 && vectorType.isScalable())
      return failure();

    auto loc = printOp.getLoc();
    auto value = printOp.getSource();

    if (auto intTy = dyn_cast<IntegerType>(vectorType.getElementType())) {
      // Round the width up to a power of two that is at least 8.
      auto width = intTy.getWidth();
      auto legalWidth = llvm::NextPowerOf2(std::max(8u, width) - 1);
      auto legalIntTy = IntegerType::get(rewriter.getContext(), legalWidth,
                                         intTy.getSignedness());
      // The extension ops below are applied to signless integers, so bitcast
      // to signless first and back to the original signedness afterwards.
      auto signlessSourceVectorType =
          vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(intTy));
      auto signlessTargetVectorType =
          vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
      auto targetVectorType = vectorType.cloneWith({}, legalIntTy);
      value = rewriter.create<vector::BitCastOp>(loc, signlessSourceVectorType,
                                                 value);
      if (value.getType() != signlessTargetVectorType) {
        // i1 and unsigned integers are zero-extended, all others
        // sign-extended.
        if (width == 1 || intTy.isUnsigned())
          value = rewriter.create<arith::ExtUIOp>(loc, signlessTargetVectorType,
                                                  value);
        else
          value = rewriter.create<arith::ExtSIOp>(loc, signlessTargetVectorType,
                                                  value);
      }
      value = rewriter.create<vector::BitCastOp>(loc, targetVectorType, value);
      vectorType = targetVectorType;
    }

    auto scalableDimensions = vectorType.getScalableDims();
    auto shape = vectorType.getShape();
    // A 0-d vector is printed like a vector with a single element.
    constexpr int64_t singletonShape[] = {1};
    if (vectorType.getRank() == 0)
      shape = singletonShape;

    if (vectorType.getRank() != 1) {
      // Seed the product with an int64_t: std::accumulate accumulates in the
      // type of its init value, so an `int` seed would truncate large shapes.
      int64_t flatLength = std::accumulate(shape.begin(), shape.end(),
                                           int64_t{1},
                                           std::multiplies<int64_t>());
      auto flatVectorType =
          VectorType::get({flatLength}, vectorType.getElementType());
      value = rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, value);
    }

    // One loop per dimension, each nested in the body of the previous one.
    vector::PrintOp firstClose;
    SmallVector<Value, 8> loopIndices;
    for (size_t d = 0, rank = shape.size(); d < rank; ++d) {
      Value lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
      Value upperBound = rewriter.create<arith::ConstantIndexOp>(loc, shape[d]);
      Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
      if (!scalableDimensions.empty() && scalableDimensions[d]) {
        // Scalable dimension: the trip count is the static size times vscale.
        auto vscale = rewriter.create<vector::VectorScaleOp>(
            loc, rewriter.getIndexType());
        upperBound = rewriter.create<arith::MulIOp>(loc, upperBound, vscale);
      }
      auto lastIndex = rewriter.create<arith::SubIOp>(loc, upperBound, step);

      rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
      auto loop =
          rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
      auto printClose = rewriter.create<vector::PrintOp>(
          loc, vector::PrintPunctuation::Close);
      // Remember the outermost close; the trailing punctuation goes after it.
      if (!firstClose)
        firstClose = printClose;

      auto loopIdx = loop.getInductionVar();
      loopIndices.push_back(loopIdx);

      // Print a comma on every iteration except the last one.
      rewriter.setInsertionPointToStart(loop.getBody());
      auto notLastIndex = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ult, loopIdx, lastIndex);
      rewriter.create<scf::IfOp>(loc, notLastIndex,
                                 [&](OpBuilder &builder, Location loc) {
                                   builder.create<vector::PrintOp>(
                                       loc, vector::PrintPunctuation::Comma);
                                   builder.create<scf::YieldOp>(loc);
                                 });

      // Everything else in this body is emitted before the comma check.
      rewriter.setInsertionPointToStart(loop.getBody());
    }

    // Linearize the loop indices into an index into the flattened vector.
    // The stride is a product of int64_t extents, so keep it 64-bit.
    Value flatIndex;
    int64_t currentStride = 1;
    for (int64_t d = static_cast<int64_t>(shape.size()) - 1; d >= 0; --d) {
      auto stride = rewriter.create<arith::ConstantIndexOp>(loc, currentStride);
      auto index = rewriter.create<arith::MulIOp>(loc, stride, loopIndices[d]);
      if (flatIndex)
        flatIndex = rewriter.create<arith::AddIOp>(loc, flatIndex, index);
      else
        flatIndex = index;
      currentStride *= shape[d];
    }

    // Print the scalar element in the innermost loop.
    auto element =
        rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex);
    rewriter.create<vector::PrintOp>(loc, element,
                                     vector::PrintPunctuation::NoPunctuation);

    // Emit the original op's punctuation after the outermost close.
    rewriter.setInsertionPointAfter(firstClose);
    rewriter.create<vector::PrintOp>(loc, printOp.getPunctuation());
    rewriter.eraseOp(printOp);
    return success();
  }

  /// Returns the signless integer type with the same width as `intTy`.
  static IntegerType getIntTypeWithSignlessSemantics(IntegerType intTy) {
    return IntegerType::get(intTy.getContext(), intTy.getWidth(),
                            IntegerType::Signless);
  }
};
/// Lowers a labeled transfer op by one dimension.
///
/// The data buffer (and, where applicable, the mask buffer) is type-cast so
/// that the leading vector dimension becomes a memref dimension. An scf.for
/// loop over that dimension is generated; each iteration, guarded by
/// generateInBoundsCheck, emits a transfer op of one rank less via
/// Strategy<OpTy>. The original op is then cleaned up via Strategy<OpTy>.
/// Only ops carrying kPassLabel are matched.
template <typename OpTy>
struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  void initialize() {
    // The pattern is applied again to the ops it creates.
    this->setHasBoundedRewriteRecursion();
  }

  /// Computes the indices used to load the mask for iteration `iv`: the
  /// indices of the first memref.load found among the users of the mask
  /// buffer (none if there is no such load), followed by `iv` unless the
  /// unpacked dimension is a broadcast.
  static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer,
                                       SmallVectorImpl<Value> &loadIndices,
                                       Value iv) {
    assert(xferOp.getMask() && "Expected transfer op to have mask");

    Value maskBuffer = getMaskBuffer(xferOp);
    for (Operation *user : maskBuffer.getUsers()) {
      if (auto loadOp = dyn_cast<memref::LoadOp>(user)) {
        Operation::operand_range prevIndices = loadOp.getIndices();
        loadIndices.append(prevIndices.begin(), prevIndices.end());
        break;
      }
    }

    // For a broadcast, the same mask slice is used on every iteration.
    if (!xferOp.isBroadcastDim(0))
      loadIndices.push_back(iv);
  }

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp->hasAttr(kPassLabel))
      return failure();

    // Find the data buffer and compute its type with one dimension unpacked.
    ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
    Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
    // The buffer is the memref of a memref.load/store, so `cast` (which
    // asserts) is appropriate; an unchecked dyn_cast would pass a null type on.
    auto dataBufferType = cast<MemRefType>(dataBuffer.getType());
    FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType);
    if (failed(castedDataType))
      return failure();

    auto castedDataBuffer =
        locB.create<vector::TypeCastOp>(*castedDataType, dataBuffer);

    // If the op is masked, find the mask buffer and cast it if needed.
    Value castedMaskBuffer;
    if (xferOp.getMask()) {
      Value maskBuffer = getMaskBuffer(xferOp);
      if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
        // The mask is not unpacked when the unpacked transfer dimension is a
        // broadcast or when the mask is 1-D.
        castedMaskBuffer = maskBuffer;
      } else {
        // NOTE(review): the result of unpackOneDim is dereferenced without a
        // failure check here; this relies on the mask being unpackable
        // whenever the data buffer was.
        auto maskBufferType = cast<MemRefType>(maskBuffer.getType());
        MemRefType castedMaskType = *unpackOneDim(maskBufferType);
        castedMaskBuffer =
            locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
      }
    }

    // Loop over the unpacked dimension, i.e. the last dim of the casted memref.
    auto lb = locB.create<arith::ConstantIndexOp>(0);
    auto ub = locB.create<arith::ConstantIndexOp>(
        castedDataType->getDimSize(castedDataType->getRank() - 1));
    auto step = locB.create<arith::ConstantIndexOp>(1);
    // Non-null only when the strategy needs a loop-carried value.
    auto loopState = Strategy<OpTy>::initialLoopState(xferOp);

    auto result = locB.create<scf::ForOp>(
        lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Type stateType = loopState.empty() ? Type() : loopState[0].getType();

          auto result = generateInBoundsCheck(
              b, xferOp, iv, unpackedDim(xferOp),
              stateType ? TypeRange(stateType) : TypeRange(),
              /*inBoundsCase=*/
              [&](OpBuilder &b, Location loc) {
                OpTy newXfer = Strategy<OpTy>::rewriteOp(
                    b, this->options, xferOp, castedDataBuffer, iv, loopState);

                // Forward the mask to the new op. A 1-D mask on a
                // non-broadcast dimension is skipped here: it is handled
                // per element by generateMaskCheck.
                if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
                                         xferOp.getMaskType().getRank() > 1)) {
                  OpBuilder::InsertionGuard guard(b);
                  // The mask load must precede the op that uses it.
                  b.setInsertionPoint(newXfer);

                  SmallVector<Value, 8> loadIndices;
                  getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
                                           loadIndices, iv);
                  auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
                                                       loadIndices);
                  rewriter.modifyOpInPlace(newXfer, [&]() {
                    newXfer.getMaskMutable().assign(mask);
                  });
                }

                return loopState.empty() ? Value() : newXfer->getResult(0);
              },
              /*outOfBoundsCase=*/
              [&](OpBuilder &b, Location /*loc*/) {
                return Strategy<OpTy>::handleOutOfBoundsDim(
                    b, xferOp, castedDataBuffer, iv, loopState);
              });

          maybeYieldValue(b, loc, !loopState.empty(), result);
        });

    Strategy<OpTy>::cleanup(rewriter, xferOp, result);
    return success();
  }
};
}
namespace lowering_n_d_unrolled {
/// Sets the mask of `newXferOp`, the transfer op generated for index `i` of
/// the unrolled dimension of `xferOp`.
///
/// Nothing is done if `xferOp` has no mask. For a broadcast dimension the
/// original mask is reused as is. Otherwise, if the mask has rank > 1, slice
/// `i` of the mask is extracted right before `newXferOp` and assigned; masks
/// of rank <= 1 are left unassigned here.
template <typename OpTy>
static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
                            int64_t i) {
  if (!xferOp.getMask())
    return;

  if (xferOp.isBroadcastDim(0)) {
    newXferOp.getMaskMutable().assign(xferOp.getMask());
    return;
  }

  if (xferOp.getMaskType().getRank() <= 1)
    return;

  OpBuilder::InsertionGuard restoreInsertionPoint(b);
  b.setInsertionPoint(newXferOp);

  llvm::SmallVector<int64_t, 1> position({i});
  Location loc = xferOp.getLoc();
  auto maskSlice = b.create<vector::ExtractOp>(loc, xferOp.getMask(), position);
  newXferOp.getMaskMutable().assign(maskSlice);
}
/// Lowers a TransferReadOp by fully unrolling its leading vector dimension.
///
/// For each index `i` of that dimension a transfer read of one rank less is
/// generated (guarded by generateInBoundsCheck) and inserted at position `i`
/// into a result vector. If the single user of the original read is a
/// vector.insert, the results are inserted directly into that op's destination
/// vector and the insert is replaced as well; otherwise the result vector
/// starts out as a splat of the padding value.
struct UnrollTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  void initialize() {
    // The pattern is applied again to the reads it creates.
    setHasBoundedRewriteRecursion();
  }

  /// Returns the vector that the unrolled reads are inserted into: the
  /// destination of the consuming vector.insert if there is one, otherwise a
  /// new splat of the padding value.
  Value buildResultVector(PatternRewriter &rewriter,
                          TransferReadOp xferOp) const {
    if (auto insertOp = getInsertOp(xferOp))
      return insertOp.getDest();
    Location loc = xferOp.getLoc();
    return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
                                            xferOp.getPadding());
  }

  /// Returns the vector.insert that is the only user of `xferOp`, or a null op
  /// if `xferOp` does not have exactly one use or that user is not an insert.
  vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
    if (xferOp->hasOneUse()) {
      Operation *xferOpUser = *xferOp->getUsers().begin();
      if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
        return insertOp;
    }

    return vector::InsertOp();
  }

  /// Appends the position of the consuming vector.insert, if any, to
  /// `indices`.
  void getInsertionIndices(TransferReadOp xferOp,
                           SmallVectorImpl<OpFoldResult> &indices) const {
    if (auto insertOp = getInsertOp(xferOp)) {
      auto pos = insertOp.getMixedPosition();
      indices.append(pos.begin(), pos.end());
    }
  }

  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return rewriter.notifyMatchFailure(
          xferOp, "vector rank is less or equal to target rank");
    if (isTensorOp(xferOp) && !options.lowerTensors)
      return rewriter.notifyMatchFailure(
          xferOp, "transfers operating on tensors are excluded");
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return rewriter.notifyMatchFailure(
          xferOp, "not yet supported: element type mismatch");
    auto xferVecType = xferOp.getVectorType();
    if (xferVecType.getScalableDims()[0]) {
      return rewriter.notifyMatchFailure(
          xferOp, "scalable dimensions cannot be unrolled");
    }

    auto insertOp = getInsertOp(xferOp);
    auto vec = buildResultVector(rewriter, xferOp);
    // `vec` is either an insert destination or a vector splat, so its type is
    // a vector; `cast` asserts that instead of yielding a null type.
    auto vecType = cast<VectorType>(vec.getType());

    VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0);

    int64_t dimSize = xferVecType.getShape()[0];

    // Generate one guarded read per index of the unrolled dimension.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);

      vec = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Position for the new vector.insert op.
            SmallVector<OpFoldResult, 8> insertionIndices;
            getInsertionIndices(xferOp, insertionIndices);
            insertionIndices.push_back(rewriter.getIndexAttr(i));

            auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
            auto newXferOp = b.create<vector::TransferReadOp>(
                loc, newXferVecType, xferOp.getSource(), xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
                xferOp.getPadding(), Value(), inBoundsAttr);
            maybeAssignMask(b, xferOp, newXferOp, i);
            return b.create<vector::InsertOp>(loc, newXferOp, vec,
                                              insertionIndices);
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Pass the vector through unmodified.
            return vec;
          });
    }

    if (insertOp) {
      // The consuming insert is subsumed by the generated inserts.
      rewriter.replaceOp(insertOp, vec);
      rewriter.eraseOp(xferOp);
    } else {
      rewriter.replaceOp(xferOp, vec);
    }

    return success();
  }
};
/// Lowers a TransferWriteOp by fully unrolling its leading vector dimension.
///
/// For each index `i` of that dimension, slice `i` is extracted from the data
/// vector and written with a transfer write of one rank less (guarded by
/// generateInBoundsCheck). If the written vector is itself produced by a
/// vector.extract, slices are extracted directly from that op's source vector.
/// For tensor transfers the updated tensor is threaded from one write to the
/// next and finally replaces the original op's result.
struct UnrollTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  void initialize() {
    // The pattern is applied again to the writes it creates.
    setHasBoundedRewriteRecursion();
  }

  /// Returns the vector that slices are extracted from: the source of the
  /// vector.extract defining the written vector if there is one, otherwise
  /// the written vector itself.
  Value getDataVector(TransferWriteOp xferOp) const {
    if (auto extractOp = getExtractOp(xferOp))
      return extractOp.getVector();
    return xferOp.getVector();
  }

  /// Returns the vector.extract that defines the written vector, or a null op
  /// if the vector is not defined by a vector.extract.
  vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
    if (auto *op = xferOp.getVector().getDefiningOp())
      return dyn_cast<vector::ExtractOp>(op);
    return vector::ExtractOp();
  }

  /// Appends the position of the defining vector.extract, if any, to
  /// `indices`.
  void getExtractionIndices(TransferWriteOp xferOp,
                            SmallVectorImpl<OpFoldResult> &indices) const {
    if (auto extractOp = getExtractOp(xferOp)) {
      auto pos = extractOp.getMixedPosition();
      indices.append(pos.begin(), pos.end());
    }
  }

  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    VectorType inputVectorTy = xferOp.getVectorType();

    // Report the reason for every bail-out, mirroring
    // UnrollTransferReadConversion.
    if (inputVectorTy.getRank() <= options.targetRank)
      return rewriter.notifyMatchFailure(
          xferOp, "vector rank is less or equal to target rank");

    // Loop-invariant: the op is not modified while the writes are generated.
    const bool onTensor = isTensorOp(xferOp);
    if (onTensor && !options.lowerTensors)
      return rewriter.notifyMatchFailure(
          xferOp, "transfers operating on tensors are excluded");
    if (inputVectorTy.getElementType() !=
        xferOp.getShapedType().getElementType())
      return rewriter.notifyMatchFailure(
          xferOp, "not yet supported: element type mismatch");
    if (inputVectorTy.getScalableDims()[0])
      return rewriter.notifyMatchFailure(
          xferOp, "scalable dimensions cannot be unrolled");

    Value vec = getDataVector(xferOp);
    int64_t dimSize = inputVectorTy.getShape()[0];
    // Memref or tensor that is written to; updated per write for tensors.
    Value source = xferOp.getSource();
    // Writes only have a result type when operating on tensors.
    Type sourceType = onTensor ? Type(xferOp.getShapedType()) : Type();

    // Generate one guarded write per index of the unrolled dimension.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);

      auto updatedSource = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp),
          onTensor ? TypeRange(sourceType) : TypeRange(),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Position for the new vector.extract op.
            SmallVector<OpFoldResult, 8> extractionIndices;
            getExtractionIndices(xferOp, extractionIndices);
            extractionIndices.push_back(b.getI64IntegerAttr(i));

            auto extracted =
                b.create<vector::ExtractOp>(loc, vec, extractionIndices);
            auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
            Value xferVec;
            if (inputVectorTy.getRank() == 1) {
              // Extracting from a 1-D vector yields a scalar; wrap it in a
              // 0-d vector so it can be passed to the transfer write.
              xferVec = b.create<vector::BroadcastOp>(
                  loc, VectorType::get({}, extracted.getType()), extracted);
            } else {
              xferVec = extracted;
            }
            auto newXferOp = b.create<vector::TransferWriteOp>(
                loc, sourceType, xferVec, source, xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
                inBoundsAttr);

            maybeAssignMask(b, xferOp, newXferOp, i);

            return onTensor ? newXferOp->getResult(0) : Value();
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Nothing is written; pass the tensor through unchanged.
            return onTensor ? source : Value();
          });

      if (onTensor)
        source = updatedSource;
    }

    if (onTensor)
      rewriter.replaceOp(xferOp, source);
    else
      rewriter.eraseOp(xferOp);

    return success();
  }
};
}
namespace lowering_1_d {
/// Computes the memref indices accessed by iteration `iv` of a 1-D transfer.
///
/// The indices of `xferOp` are appended to `memrefIndices`. If the single
/// permutation map result is a dimension expression, the index of that
/// dimension is replaced by `index + iv` and the dimension is returned.
/// Otherwise the result is asserted to be a broadcast, the indices are left
/// unchanged and std::nullopt is returned.
template <typename OpTy>
static std::optional<int64_t>
get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
                   SmallVector<Value, 8> &memrefIndices) {
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
  auto permutation = xferOp.getPermutationMap();
  auto baseIndices = xferOp.getIndices();
  memrefIndices.append(baseIndices.begin(), baseIndices.end());
  assert(permutation.getNumResults() == 1 &&
         "Expected 1 permutation map result for 1D transfer");

  auto dimExpr = dyn_cast<AffineDimExpr>(permutation.getResult(0));
  if (!dimExpr) {
    assert(xferOp.isBroadcastDim(0) &&
           "Expected AffineDimExpr or AffineConstantExpr");
    return std::nullopt;
  }

  auto memrefDim = dimExpr.getPosition();
  AffineExpr d0, d1;
  bindDims(xferOp.getContext(), d0, d1);
  Value baseOffset = memrefIndices[memrefDim];
  memrefIndices[memrefDim] = affine::makeComposedAffineApply(
      b, xferOp.getLoc(), d0 + d1, {baseOffset, iv});
  return memrefDim;
}
/// Op-specific code generation hooks used by TransferOp1dConversion. Only the
/// explicit specializations for TransferReadOp and TransferWriteOp are defined.
template <typename OpTy>
struct Strategy1d;
template <>
struct Strategy1d<TransferReadOp> {
static void generateForLoopBody(OpBuilder &b, Location loc,
TransferReadOp xferOp, Value iv,
ValueRange loopState) {
SmallVector<Value, 8> indices;
auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
auto vec = loopState[0];
auto nextVec = generateInBoundsCheck(
b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
[&](OpBuilder &b, Location loc) {
Value val =
b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
return b.create<vector::InsertElementOp>(loc, val, vec, iv);
},
[&](OpBuilder & , Location loc) { return vec; });
b.create<scf::YieldOp>(loc, nextVec);
}
static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
Location loc = xferOp.getLoc();
return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
xferOp.getPadding());
}
};
template <>
struct Strategy1d<TransferWriteOp> {
static void generateForLoopBody(OpBuilder &b, Location loc,
TransferWriteOp xferOp, Value iv,
ValueRange ) {
SmallVector<Value, 8> indices;
auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
generateInBoundsCheck(
b, xferOp, iv, dim,
[&](OpBuilder &b, Location loc) {
auto val =
b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
});
b.create<scf::YieldOp>(loc);
}
static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
return Value();
}
};
/// Lowers a 1-D transfer op on a memref to an scf.for loop that transfers one
/// scalar element per iteration, using Strategy1d<OpTy> for the loop body.
///
/// The pattern does not apply to 0-d transfers, non-memref sources, vectors of
/// rank other than 1, or transfers with a minor-identity permutation map on a
/// memref whose last dimension has unit stride. For scalable vectors the trip
/// count is the static dimension size multiplied by vscale.
template <typename OpTy>
struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getTransferRank() == 0)
      return failure();

    auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
    auto vecType = xferOp.getVectorType();
    if (!memRefType || vecType.getRank() != 1)
      return failure();

    // Contiguous minor-identity transfers are left untouched by this pattern.
    auto permutation = xferOp.getPermutationMap();
    if (permutation.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
      return failure();

    // Loop bounds and step.
    Location loc = xferOp.getLoc();
    auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value ub =
        rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
    if (vecType.isScalable()) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      ub = rewriter.create<arith::MulIOp>(loc, ub, vscale);
    }
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);

    // Null unless the strategy threads a value through the loop.
    auto initialState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);

    auto buildBody = [&](OpBuilder &b, Location bodyLoc, Value iv,
                         ValueRange iterArgs) {
      Strategy1d<OpTy>::generateForLoopBody(b, bodyLoc, xferOp, iv, iterArgs);
    };
    rewriter.replaceOpWithNewOp<scf::ForOp>(
        xferOp, lb, ub, step,
        initialState ? ValueRange(initialState) : ValueRange(), buildBody);
    return success();
  }
};
}
}
/// Adds the vector-to-SCF lowering patterns selected by `options` to
/// `patterns`:
///   * the unrolling transfer patterns if `options.unroll` is set, otherwise
///     the prepare/loop-based transfer patterns;
///   * the scalar-loop 1-D transfer patterns if `options.targetRank` is 1;
///   * the vector.print decomposition pattern, unconditionally.
void mlir::populateVectorToSCFConversionPatterns(
    RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
  MLIRContext *ctx = patterns.getContext();

  if (!options.unroll) {
    patterns.add<lowering_n_d::PrepareTransferReadConversion,
                 lowering_n_d::PrepareTransferWriteConversion,
                 lowering_n_d::TransferOpConversion<TransferReadOp>,
                 lowering_n_d::TransferOpConversion<TransferWriteOp>>(ctx,
                                                                      options);
  } else {
    patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
                 lowering_n_d_unrolled::UnrollTransferWriteConversion>(ctx,
                                                                       options);
  }

  if (options.targetRank == 1) {
    patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
                 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
        ctx, options);
  }

  patterns.add<lowering_n_d::DecomposePrintOpConversion>(ctx, options);
}
namespace {
/// Pass that lowers vector transfer ops (and vector.print) to SCF.
struct ConvertVectorToSCFPass
    : public impl::ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
  ConvertVectorToSCFPass() = default;

  /// Initializes the pass options from `options`.
  ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
    this->lowerTensors = options.lowerTensors;
    this->targetRank = options.targetRank;
    this->fullUnroll = options.unroll;
  }

  /// Runs two greedy rewrites over the operation: first the transfer
  /// permutation-map lowering patterns, then the vector-to-SCF patterns
  /// configured from the pass options. The results of both rewrites are
  /// deliberately ignored.
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();

    // Stage 1: lower permutation maps.
    {
      RewritePatternSet permutationPatterns(ctx);
      mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
          permutationPatterns);
      (void)applyPatternsAndFoldGreedily(getOperation(),
                                         std::move(permutationPatterns));
    }

    // Stage 2: lower to SCF using the options this pass was configured with.
    VectorTransferToSCFOptions options;
    options.lowerTensors = lowerTensors;
    options.targetRank = targetRank;
    options.unroll = fullUnroll;

    RewritePatternSet scfPatterns(ctx);
    populateVectorToSCFConversionPatterns(scfPatterns, options);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(scfPatterns));
  }
};
}
/// Creates a ConvertVectorToSCFPass configured with `options`.
std::unique_ptr<Pass>
mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
  std::unique_ptr<Pass> pass =
      std::make_unique<ConvertVectorToSCFPass>(options);
  return pass;
}