#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Casting.h"
using namespace mlir;
namespace {
/// Lowers an in-bounds 2-D `vector.transfer_read` from a memref into an
/// `arm_sme.tile_load`.
///
/// The permutation map must be a true permutation: the identity map becomes a
/// horizontal tile load, and any other permutation becomes a vertical load.
/// The padding value is only forwarded when the read is masked.
struct TransferReadToArmSMELowering
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const final {
    // Small helper so every bail-out reports against the transfer_read.
    auto bail = [&](const char *why) {
      return rewriter.notifyMatchFailure(transferReadOp, why);
    };

    if (transferReadOp.getTransferRank() != 2)
      return bail("not a 2 result permutation map");

    VectorType vectorType = transferReadOp.getVectorType();
    if (!arm_sme::isValidSMETileVectorType(vectorType))
      return bail("not a valid vector type for SME");

    Value source = transferReadOp.getSource();
    if (!llvm::isa<MemRefType>(source.getType()))
      return bail("not a memref source");

    if (transferReadOp.hasOutOfBoundsDim())
      return bail("not inbounds transfer read");

    AffineMap permutationMap = transferReadOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return bail("unsupported permutation map");

    // Identity -> row-wise (horizontal) slices; any other permutation of the
    // two dims -> column-wise (vertical) slices.
    arm_sme::TileSliceLayout layout = arm_sme::TileSliceLayout::Horizontal;
    if (!permutationMap.isIdentity())
      layout = arm_sme::TileSliceLayout::Vertical;

    // Out-of-bounds reads were rejected above, so padding is only relevant
    // for masked-off lanes; leave it null when there is no mask.
    Value mask = transferReadOp.getMask();
    Value padding;
    if (mask)
      padding = transferReadOp.getPadding();

    rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
        transferReadOp, vectorType, source, transferReadOp.getIndices(),
        padding, mask, layout);
    return success();
  }
};
/// Lowers an in-bounds `vector.transfer_write` of a valid SME tile vector to a
/// memref into an `arm_sme.tile_store`.
///
/// An identity permutation map produces a horizontal store; any other
/// permutation produces a vertical store. The optional mask is forwarded
/// unchanged.
struct TransferWriteToArmSMELowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const final {
    Value destination = writeOp.getSource();
    bool isTileVector =
        arm_sme::isValidSMETileVectorType(writeOp.getVectorType());
    bool isMemRefDest = llvm::isa<MemRefType>(destination.getType());
    if (!isTileVector || !isMemRefDest)
      return failure();

    if (writeOp.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(writeOp,
                                         "not inbounds transfer write");

    AffineMap permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(writeOp,
                                         "unsupported permutation map");

    // Assuming 2-D SME tile vectors, the only non-identity permutation is a
    // transpose, which is stored as column-wise (vertical) tile slices.
    arm_sme::TileSliceLayout layout =
        permutationMap.isIdentity() ? arm_sme::TileSliceLayout::Horizontal
                                    : arm_sme::TileSliceLayout::Vertical;

    rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
        writeOp, writeOp.getVector(), destination, writeOp.getIndices(),
        writeOp.getMask(), layout);
    return success();
  }
};
/// Lowers a `vector.load` of a valid SME tile vector into an
/// `arm_sme.tile_load` from the same base and indices (no padding, mask or
/// layout is passed explicitly).
struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp load,
                                PatternRewriter &rewriter) const override {
    VectorType tileType = load.getVectorType();
    if (arm_sme::isValidSMETileVectorType(tileType)) {
      rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
          load, tileType, load.getBase(), load.getIndices());
      return success();
    }
    return failure();
  }
};
/// Lowers a `vector.store` of a valid SME tile vector into an
/// `arm_sme.tile_store` to the same base and indices (no mask or layout is
/// passed explicitly).
struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp store,
                                PatternRewriter &rewriter) const override {
    bool isTile = arm_sme::isValidSMETileVectorType(store.getVectorType());
    if (!isTile)
      return failure();

    Value tileValue = store.getValueToStore();
    rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
        store, tileValue, store.getBase(), store.getIndices());
    return success();
  }
};
/// Lowers a `vector.broadcast` whose result is a valid SME tile.
///
/// The source is first turned into a single 1-D tile slice, i.e. the tile type
/// with its leading dimension dropped. A loop built by
/// `createLoopOverTileSlices` then writes that slice into every slice of a
/// fresh tile.
///
/// Supported sources are scalars, 0-D vectors and 1-D vectors. A 1-D source is
/// used directly only when it already has the tile-slice type. Otherwise it is
/// broadcast to the tile-slice type first. An example is `vector<1xf32>` being
/// stretched to `vector<[4]x[4]xf32>`, which `vector.broadcast` permits.
/// This guarantees `arm_sme.move_vector_to_tile_slice` always receives a
/// correctly typed slice. Sources of rank >= 2 are not handled.
struct BroadcastOpToArmSMELowering
    : public OpRewritePattern<vector::BroadcastOp> {
  using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const final {
    auto tileType = broadcastOp.getResultVectorType();
    if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
      return failure();

    auto srcType = broadcastOp.getSourceType();
    auto srcVectorType = dyn_cast<VectorType>(srcType);
    // Only scalars and vectors of rank 0 or 1 can be expressed as a single
    // tile slice; reject anything else before creating any IR.
    bool isSupportedSource = srcType.isIntOrFloat() ||
                             (srcVectorType && srcVectorType.getRank() <= 1);
    if (!isSupportedSource)
      return failure();

    auto loc = broadcastOp.getLoc();
    VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);

    // Bring the source to exactly the tile-slice type. This covers scalars and
    // 0-D vectors, and also 1-D vectors whose unit dimension is stretched
    // (previously these were forwarded unchanged, producing an ill-typed
    // tile-slice operand). A source already of the slice type is used as-is.
    Value broadcastOp1D = broadcastOp.getSource();
    if (srcType != tileSliceType)
      broadcastOp1D = rewriter.create<vector::BroadcastOp>(
          loc, tileSliceType, broadcastOp.getSource());

    auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);

    // Every iteration writes the same 1-D slice into slice `tileSliceIndex`.
    auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
                            Value currentTile) {
      auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
          loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
      return nextTile.getResult();
    };

    auto forOp =
        createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);
    rewriter.replaceOp(broadcastOp, forOp.getResult(0));
    return success();
  }
};
/// Lowers a `vector.splat` whose result is a valid SME tile.
///
/// The scalar is broadcast into one 1-D tile slice, i.e. the tile type with its
/// leading dimension dropped. A loop built by `createLoopOverTileSlices` then
/// writes that slice into every slice of a fresh tile.
struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
  using OpRewritePattern<vector::SplatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::SplatOp splatOp,
                                PatternRewriter &rewriter) const final {
    auto tileType = splatOp.getResult().getType();
    bool isSMETile = tileType && arm_sme::isValidSMETileVectorType(tileType);
    if (!isSMETile)
      return failure();

    // vector.splat is expected to take a scalar int/float operand.
    assert(splatOp.getOperand().getType().isIntOrFloat() &&
           "Invalid source type for vector.splat");

    Location loc = splatOp.getLoc();
    VectorType sliceType = VectorType::Builder(tileType).dropDim(0);
    Value splatSlice = rewriter.create<vector::BroadcastOp>(
        loc, sliceType, splatOp.getInput());

    // Keep the concrete op type here; it is passed straight to the loop
    // helper below.
    auto emptyTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
    auto fillSlice = [&](OpBuilder &builder, Location bodyLoc, Value sliceIdx,
                         Value tile) -> Value {
      return builder
          .create<arm_sme::MoveVectorToTileSliceOp>(bodyLoc, tileType,
                                                    splatSlice, tile, sliceIdx)
          .getResult();
    };

    auto loop = createLoopOverTileSlices(rewriter, loc, emptyTile, fillSlice);
    rewriter.replaceOp(splatOp, loop.getResult(0));
    return success();
  }
};
/// Lowers a 2-D `vector.transpose` with permutation [1, 0] whose result is a
/// valid SME tile.
///
/// Two strategies are used:
///  1. If the transposed value comes from a single-use `vector.transfer_read`
///     whose permutation map is the identity and whose in-bounds flags agree,
///     the transpose is folded into the read by giving it a transposing
///     permutation map.
///  2. Otherwise the tile is stored into a `memref.alloca` buffer of
///     `vscale * minTileSlices` x `vscale * minTileSlices` elements and
///     reloaded with a vertical tile-slice layout.
///
/// The fold is restricted to identity maps because it overwrites the map
/// rather than composing with it. Overwriting an already-transposed map would
/// drop that transpose. Overwriting a minor-identity map over a higher-rank
/// memref would produce an invalid map. `in_bounds` is indexed per vector
/// dimension and is not permuted here, so both flags must be equal for the
/// fold to preserve meaning.
struct TransposeOpToArmSMELowering
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const final {
    auto tileType = transposeOp.getResultVectorType();
    if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
      return failure();

    // Only a true 2-D swap is handled.
    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
    if (permutation[0] != 1 || permutation[1] != 0)
      return failure();

    auto loc = transposeOp.getLoc();
    Value input = transposeOp.getVector();

    // Strategy 1: fold into the producing transfer_read when that is
    // semantics-preserving (see the struct comment for the conditions).
    if (auto xferOp = input.getDefiningOp<vector::TransferReadOp>();
        xferOp && xferOp->hasOneUse() &&
        xferOp.getPermutationMap().isIdentity() &&
        xferOp.isDimInBounds(0) == xferOp.isDimInBounds(1)) {
      rewriter.modifyOpInPlace(xferOp, [&]() {
        xferOp->setAttr(xferOp.getPermutationMapAttrName(),
                        AffineMapAttr::get(AffineMap::getPermutationMap(
                            permutation, transposeOp.getContext())));
      });
      rewriter.replaceOp(transposeOp, xferOp);
      return success();
    }

    // Strategy 2: round-trip through memory. Store the tile row-wise, then
    // reload it column-wise (vertical layout), which yields the transpose.
    Value vscale =
        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
    Value minTileSlices = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIndexAttr(tileType.getDimSize(0)));
    Value c0 =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
    Value numTileSlices =
        rewriter.create<arith::MulIOp>(loc, vscale, minTileSlices);
    auto bufferType =
        MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic},
                        tileType.getElementType());
    auto buffer = rewriter.create<memref::AllocaOp>(
        loc, bufferType, ValueRange{numTileSlices, numTileSlices});
    auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>(
        loc, input, buffer, ValueRange{c0, c0});
    rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
        transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(),
        arm_sme::TileSliceLayout::Vertical);
    return success();
  }
};
/// Lowers a `vector.outerproduct` into an `arm_sme.outerproduct`.
///
/// The op is only lowered if it has the ADD combining kind, a vector (non-AXPY)
/// RHS, and a result that is a valid SME tile. When the outer product is
/// wrapped in a masking op, that op's mask must come from a
/// `vector.create_mask`. The mask is split into one 1-D mask per operand, and
/// the masking op itself is replaced.
struct VectorOuterProductToArmSMELowering
    : public OpRewritePattern<vector::OuterProductOp> {
  using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
                                PatternRewriter &rewriter) const override {
    // A scalar RHS (AXPY form) has no lowering here.
    if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
      return rewriter.notifyMatchFailure(outerProductOp,
                                         "AXPY operations not supported");

    VectorType resultType = outerProductOp.getResultVectorType();
    if (!arm_sme::isValidSMETileVectorType(resultType))
      return rewriter.notifyMatchFailure(
          outerProductOp, "outer product does not fit into SME tile");

    if (outerProductOp.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(
          outerProductOp,
          "unsupported kind (lowering to SME only supports ADD at the moment)");

    Location loc = outerProductOp.getLoc();
    Value lhsMask;
    Value rhsMask;
    Operation *opToReplace = outerProductOp;

    if (outerProductOp.isMasked()) {
      auto maskingOp = outerProductOp.getMaskingOp();
      // The per-operand masks are built right before the masking op, and the
      // masking op (not the inner outer product) is what gets replaced.
      rewriter.setInsertionPoint(maskingOp);
      opToReplace = maskingOp;
      auto splitMasks = decomposeResultMask(loc, maskingOp.getMask(), rewriter);
      if (failed(splitMasks))
        return failure();
      lhsMask = splitMasks->first;
      rhsMask = splitMasks->second;
    }

    rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>(
        opToReplace, resultType, outerProductOp.getLhs(),
        outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());
    return success();
  }

  /// Splits a 2-D mask produced by `vector.create_mask` into two 1-D masks.
  /// The LHS mask uses the create_mask's first bound and the RHS mask its
  /// second bound. Fails (without creating IR) for any other mask producer.
  static FailureOr<std::pair<Value, Value>>
  decomposeResultMask(Location loc, Value mask, PatternRewriter &rewriter) {
    auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return failure();

    // Both operand masks share the 1-D type obtained by dropping the leading
    // dimension of the 2-D mask type.
    VectorType operandMaskType =
        VectorType::Builder(createMaskOp.getVectorType()).dropDim(0);
    Value lhsMask = rewriter.create<vector::CreateMaskOp>(
        loc, operandMaskType, createMaskOp.getOperand(0));
    Value rhsMask = rewriter.create<vector::CreateMaskOp>(
        loc, operandMaskType, createMaskOp.getOperand(1));
    return std::make_pair(lhsMask, rhsMask);
  }
};
/// Lowers a `vector.extract` from a valid SME tile:
///  - no position: the source tile is forwarded,
///  - one position: that tile slice is moved out as a 1-D vector,
///  - two positions: the tile slice is moved out, then the element at the
///    second position is extracted from it.
struct VectorExtractToArmSMELowering
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    if (!arm_sme::isValidSMETileVectorType(extractOp.getSourceVectorType()))
      return failure();

    Value tile = extractOp.getVector();
    auto position = extractOp.getMixedPosition();

    // Position-less extract: the result is the whole tile.
    if (position.empty()) {
      rewriter.replaceOp(extractOp, tile);
      return success();
    }

    Location loc = extractOp.getLoc();
    Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
    auto tileSlice =
        rewriter.create<arm_sme::MoveTileSliceToVectorOp>(loc, tile, sliceIndex);

    switch (position.size()) {
    case 1:
      // Only the slice itself was requested.
      rewriter.replaceOp(extractOp, tileSlice);
      return success();
    default:
      // A single element of the slice was requested.
      assert(position.size() == 2);
      rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, tileSlice,
                                                     position[1]);
      return success();
    }
  }
};
/// Lowers a `vector.insert` into a valid SME tile:
///  - no position: the inserted value replaces the whole tile,
///  - one position: the inserted 1-D vector is written into that tile slice,
///  - two positions: the target slice is moved out of the tile, the value is
///    inserted into it at the second position, and the updated slice is
///    written back.
struct VectorInsertToArmSMELowering
    : public OpRewritePattern<vector::InsertOp> {
  using OpRewritePattern<vector::InsertOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::InsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    VectorType tileType = insertOp.getResult().getType();
    if (!arm_sme::isValidSMETileVectorType(tileType))
      return failure();

    Value inserted = insertOp.getSource();
    auto position = insertOp.getMixedPosition();
    if (position.empty()) {
      rewriter.replaceOp(insertOp, inserted);
      return success();
    }

    Location loc = insertOp.getLoc();
    Value destTile = insertOp.getDest();
    Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();

    // With a single index the inserted value already is the full new slice.
    Value newSlice = inserted;
    if (position.size() == 2) {
      Value oldSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
          loc, destTile, sliceIndex);
      newSlice = rewriter.create<vector::InsertOp>(loc, inserted, oldSlice,
                                                   position[1]);
    }

    rewriter.replaceOpWithNewOp<arm_sme::MoveVectorToTileSliceOp>(
        insertOp, newSlice, destTile, sliceIndex);
    return success();
  }
};
/// Lowers `vector.print` of a valid SME tile into an `scf.for` loop over the
/// tile's rows. The loop runs `vscale * minRows` iterations and prints each
/// row as a 1-D vector with the original op's punctuation.
struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
  using OpRewritePattern<vector::PrintOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::PrintOp printOp,
                                PatternRewriter &rewriter) const override {
    // Prints without a source operand are not handled.
    Value tile = printOp.getSource();
    if (!tile)
      return failure();

    auto tileType = dyn_cast<VectorType>(printOp.getPrintType());
    if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
      return failure();

    Location loc = printOp.getLoc();

    // Loop bounds: rows = minRows * vscale, iterated from 0 with step 1.
    auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
    auto minRows =
        rewriter.create<arith::ConstantIndexOp>(loc, tileType.getDimSize(0));
    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto numRows = rewriter.create<arith::MulIOp>(loc, minRows, vscale);
    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto rowLoop = rewriter.create<scf::ForOp>(loc, zero, numRows, one);

    // Loop body: move row `i` out of the tile and print it.
    rewriter.setInsertionPointToStart(rowLoop.getBody());
    Value row = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
        loc, tile, rowLoop.getInductionVar());
    rewriter.create<vector::PrintOp>(loc, row, printOp.getPunctuation());

    rewriter.eraseOp(printOp);
    return success();
  }
};
/// Folds a `vector.transfer_write` of a value produced by
/// `arm_sme.move_tile_slice_to_vector` into an `arm_sme.store_tile_slice`.
/// The store reads directly from the tile, skipping the intermediate 1-D
/// vector. When the write is unmasked, an all-true mask constant is
/// materialized for the store.
struct FoldTransferWriteOfExtractTileSlice
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const final {
    Value destination = writeOp.getSource();
    if (!isa<MemRefType>(destination.getType()))
      return rewriter.notifyMatchFailure(writeOp, "destination not a memref");

    if (writeOp.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(writeOp,
                                         "not inbounds transfer write");

    auto sliceRead =
        writeOp.getVector().getDefiningOp<arm_sme::MoveTileSliceToVectorOp>();
    if (!sliceRead)
      return rewriter.notifyMatchFailure(
          writeOp, "vector to store not from MoveTileSliceToVectorOp");

    if (!writeOp.getPermutationMap().isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp,
                                         "unsupported permutation map");

    Value storeMask = writeOp.getMask();
    if (!storeMask) {
      // The store below is given an explicit mask; synthesize an all-true one.
      auto allTrueType = writeOp.getVectorType().clone(rewriter.getI1Type());
      storeMask = rewriter.create<arith::ConstantOp>(
          writeOp.getLoc(), allTrueType,
          DenseElementsAttr::get(allTrueType, true));
    }

    rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
        writeOp, sliceRead.getTile(), sliceRead.getTileSliceIndex(), storeMask,
        destination, writeOp.getIndices(), sliceRead.getLayout());
    return success();
  }
};
/// Lowers `vector.extract %mask[%i]` into an `arm_sve.psel`, where `%mask` is
/// a 2-D all-scalable `vector.create_mask` whose base sizes are SVE predicate
/// sizes.
///
/// The 2-D mask is split into a 1-D row mask (bound 0) and a 1-D column mask
/// (bound 1). These are then combined with the extract position through
/// `arm_sve.psel` (presumably selecting the column mask when lane `%i` of the
/// row mask is set; see the ArmSVE dialect docs).
struct ExtractFromCreateMaskToPselLowering
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;

  /// Returns true for power-of-two sizes in the range [1, 16].
  static bool isSVEPredicateSize(int64_t size) {
    return size > 0 && size <= 16 &&
           llvm::isPowerOf2_32(static_cast<uint32_t>(size));
  }

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    if (extractOp.getNumIndices() != 1)
      return rewriter.notifyMatchFailure(extractOp, "not single extract index");

    if (!isa<VectorType>(extractOp.getResult().getType()))
      return rewriter.notifyMatchFailure(extractOp, "result not VectorType");

    auto createMaskOp =
        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp");

    VectorType maskType = createMaskOp.getVectorType();
    if (maskType.getRank() != 2 || !maskType.allDimsScalable())
      return rewriter.notifyMatchFailure(createMaskOp, "not 2-D scalable mask");

    if (!isSVEPredicateSize(maskType.getDimSize(0)) ||
        !isSVEPredicateSize(maskType.getDimSize(1)))
      return rewriter.notifyMatchFailure(
          createMaskOp, "mask dimensions not SVE predicate-sized");

    Location loc = extractOp.getLoc();
    VectorType rowMaskType = VectorType::Builder(maskType).dropDim(1);
    VectorType colMaskType = VectorType::Builder(maskType).dropDim(0);

    // Build the 1-D masks at the position of the 2-D create_mask rather than
    // at the extract.
    rewriter.setInsertionPoint(createMaskOp);
    Value rowMask = rewriter.create<vector::CreateMaskOp>(
        loc, rowMaskType, createMaskOp.getOperand(0));
    Value colMask = rewriter.create<vector::CreateMaskOp>(
        loc, colMaskType, createMaskOp.getOperand(1));

    rewriter.setInsertionPoint(extractOp);
    Value rowIndex =
        vector::getAsValues(rewriter, loc, extractOp.getMixedPosition())
            .front();
    rewriter.replaceOpWithNewOp<arm_sve::PselOp>(extractOp, colMask, rowMask,
                                                 rowIndex);
    return success();
  }
};
}
/// Adds every Vector-to-ArmSME rewrite pattern defined in this file to
/// `patterns`, all constructed with `ctx`. Patterns are added in a fixed
/// order, grouped here by the kind of op they handle.
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                          MLIRContext &ctx) {
  MLIRContext *context = &ctx;
  // Ops that fill an entire tile from a scalar or a 1-D value.
  patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering>(context);
  // Memory transfers, loads/stores and transposes.
  patterns.add<TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
               TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
               VectorStoreToArmSMELowering>(context);
  // Outer products.
  patterns.add<VectorOuterProductToArmSMELowering>(context);
  // Slice-level access, printing, and related folds.
  patterns.add<VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
               VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice,
               ExtractFromCreateMaskToPselLowering>(context);
}