#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Transforms/OneToNTypeConversion.h"
#define DEBUG_TYPE "arm-sme-vector-legalization"
namespace mlir::arm_sme {
#define GEN_PASS_DEF_VECTORLEGALIZATION
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
}
using namespace mlir;
using namespace mlir::arm_sme;
namespace {
// Shared diagnostic strings handed to `notifyMatchFailure` by the rewrite
// patterns in this file.

// Used when the op's vector type is rejected by
// `isMultipleOfSMETileVectorType`.
static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple(
"op vector size is not multiple of SME tiles");
// Used when the op's mask is rejected by `isSupportedMaskOp`.
static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
"op mask is unsupported for legalization/decomposition");
// Used when a transfer op's permutation map is not a permutation.
static constexpr StringLiteral
kMatchFailureNonPermutationMap("op affine map is not a permutation");
// Used when an op does not go from an illegal vector type (see
// `isLegalVectorType`) to a legal one.
static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
"expected transpose from illegal type to legal type");
/// One SME-tile-sized piece of a larger 2-D vector, as produced by
/// `decomposeToSMETiles`.
struct SMESubTile {
  /// Row offset of the sub-tile. The helpers in this file multiply it by
  /// `vscale` when turning it into an index.
  int row = 0;
  /// Column offset of the sub-tile (scaled by `vscale` in the same way).
  int col = 0;
  /// Vector type of the sub-tile.
  VectorType type;
};
/// Returns `indices[i] + scalableOffsets[i] * vscale` for every `i`.
///
/// `indices` and `scalableOffsets` must have the same length (enforced by
/// `llvm::zip_equal`). A single `vector.vscale` op is created up front and
/// shared by all the computed offsets.
SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
                                                Location loc,
                                                ValueRange indices,
                                                ArrayRef<int> scalableOffsets) {
  Value vscale = builder.create<vector::VectorScaleOp>(loc);
  SmallVector<Value, 2> shiftedIndices;
  for (auto [index, base] : llvm::zip_equal(indices, scalableOffsets)) {
    // Emit, in this order: the constant multiplier, the scaled offset and the
    // shifted index.
    Value multiplier = builder.create<arith::ConstantIndexOp>(loc, base);
    Value scaledOffset =
        builder.create<arith::MulIOp>(loc, multiplier, vscale);
    shiftedIndices.push_back(
        builder.create<arith::AddIOp>(loc, index, scaledOffset));
  }
  return shiftedIndices;
}
/// Offsets `indices` by the (vscale-scaled) row/column position of `smeTile`,
/// i.e. returns `{indices[0] + row * vscale, indices[1] + col * vscale}`.
SmallVector<Value, 2> getSMESubTileIndices(OpBuilder &builder, Location loc,
                                           ValueRange indices,
                                           SMESubTile smeTile) {
  const int tileOffsets[] = {smeTile.row, smeTile.col};
  return addConstantScalableOffset(builder, loc, indices, tileOffsets);
}
/// Returns true if `mask` can be handled by the decompositions in this file:
/// either there is no mask at all, or the mask is the result of a
/// `vector.create_mask` op.
bool isSupportedMaskOp(Value mask) {
  if (!mask)
    return true;
  return static_cast<bool>(mask.getDefiningOp<vector::CreateMaskOp>());
}
/// Builds the mask for the sub-tile `smeTile` from the mask of the larger
/// vector.
///
/// Returns a null `Value` when `mask` is null. Otherwise `mask` must be
/// defined by a `vector.create_mask` (see `isSupportedMaskOp`); the new mask
/// is a `vector.create_mask` of the sub-tile's shape (with `i1` elements)
/// whose operands are the original operands minus the sub-tile's
/// vscale-scaled row/column offsets.
Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
                     SMESubTile smeTile) {
  assert(isSupportedMaskOp(mask));
  if (!mask)
    return Value{};

  auto parentCreateMask = mask.getDefiningOp<vector::CreateMaskOp>();
  VectorType tileMaskType = smeTile.type.clone(builder.getI1Type());
  // Shift the mask bounds so they are relative to the start of this sub-tile.
  SmallVector<Value, 2> tileMaskBounds =
      addConstantScalableOffset(builder, loc, parentCreateMask.getOperands(),
                                {-smeTile.row, -smeTile.col});
  return builder
      .create<vector::CreateMaskOp>(loc, tileMaskType, tileMaskBounds)
      .getResult();
}
/// Returns a lazily-evaluated range of `SMESubTile`s, one per
/// `smeTileType`-sized tile offset that `StaticTileOffsetRange` yields for the
/// shape of `type`.
///
/// `type` must be a multiple of SME tiles (asserted). When `transposeIndices`
/// is set, the row and column offsets of every sub-tile are swapped.
/// `builder` is not used by this helper.
auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
                         VectorType smeTileType,
                         bool transposeIndices = false) {
  assert(isMultipleOfSMETileVectorType(type) &&
         "`type` not multiple of SME tiles");
  return llvm::map_range(
      StaticTileOffsetRange(type.getShape(), {smeTileType.getDimSize(0),
                                              smeTileType.getDimSize(1)}),
      [transposeIndices, smeTileType](auto offsets) {
        const int dim0Offset = static_cast<int>(offsets[0]);
        const int dim1Offset = static_cast<int>(offsets[1]);
        return transposeIndices
                   ? SMESubTile{dim1Offset, dim0Offset, smeTileType}
                   : SMESubTile{dim0Offset, dim1Offset, smeTileType};
      });
}
/// Returns how many SME tiles are needed to cover the 2-D vector `type`:
/// `(rows * cols) / (n * n)`, where `n` is the value returned by
/// `getSMETileSliceMinNumElts` for the element type.
///
/// `type` must be a multiple of SME tiles (asserted).
int getNumberOfSMETilesForVectorType(VectorType type) {
  assert(isMultipleOfSMETileVectorType(type) &&
         "`type` not multiple of SME tiles");
  unsigned tileDimElts = getSMETileSliceMinNumElts(type.getElementType());
  int64_t totalElts = type.getDimSize(0) * type.getDimSize(1);
  return totalElts / (tileDimElts * tileDimElts);
}
/// Decomposes a splat `arith.constant` whose vector type is a multiple of SME
/// tiles: a single SME-tile-sized splat constant is created and used as the
/// replacement for every tile of the original value.
struct LegalizeArithConstantOpsByDecomposition
    : public OneToNOpConversionPattern<arith::ConstantOp> {
  using OneToNOpConversionPattern::OneToNOpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const override {
    // Only dense splat constants of vector type are handled.
    auto vectorType = dyn_cast<VectorType>(constantOp.getType());
    if (!vectorType)
      return failure();
    auto splatAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
    if (!splatAttr || !splatAttr.isSplat())
      return failure();

    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(constantOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto tileType = getSMETileTypeForElement(vectorType.getElementType());
    auto numTiles = getNumberOfSMETilesForVectorType(vectorType);

    // Every tile of a splat is the same splat, so one constant is shared.
    auto tileConstant = rewriter.create<arith::ConstantOp>(
        constantOp.getLoc(), splatAttr.resizeSplat(tileType));
    SmallVector<Value> replacements(numTiles, tileConstant);
    rewriter.replaceOp(constantOp, replacements, adaptor.getResultMapping());
    return success();
  }
};
/// Decomposes a `vector.outerproduct` whose result type is a multiple of SME
/// tiles into one SME-tile-sized `vector.outerproduct` per sub-tile.
///
/// For each sub-tile the matching slices of the LHS/RHS are taken with
/// `vector.scalable.extract`, the accumulator tile (if any) comes from the
/// adaptor, and the new op is wrapped with the sub-tile's mask via
/// `vector::maskOperation`. If the original op is masked, the enclosing
/// masking op is the one that gets replaced.
struct LegalizeVectorOuterProductOpsByDecomposition
    : public OneToNOpConversionPattern<vector::OuterProductOp> {
  using OneToNOpConversionPattern::OneToNOpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::OuterProductOp outerProductOp, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const override {
    auto resultType = outerProductOp.getResultVectorType();
    if (!isMultipleOfSMETileVectorType(resultType))
      return rewriter.notifyMatchFailure(outerProductOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto loc = outerProductOp.getLoc();
    // Defaults for the unmasked case; overridden below when masked.
    Operation *opToReplace = outerProductOp;
    Value mask;
    if (outerProductOp.isMasked()) {
      auto maskingOp = outerProductOp.getMaskingOp();
      mask = maskingOp.getMask();
      opToReplace = maskingOp;
    }

    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(outerProductOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto tileType = getSMETileTypeForElement(resultType.getElementType());
    // 1-D slice type: the tile type with its leading dimension dropped.
    VectorType sliceType = VectorType::Builder(tileType).dropDim(0);
    ValueRange accTiles = adaptor.getAcc();
    const bool hasAccumulator = !accTiles.empty();

    SmallVector<Value> tileResults;
    for (auto [tileIdx, subTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, resultType, tileType))) {
      auto tileMask = extractSMEMask(rewriter, loc, mask, subTile);
      auto lhsSlice = rewriter.create<vector::ScalableExtractOp>(
          loc, sliceType, outerProductOp.getLhs(), subTile.row);
      auto rhsSlice = rewriter.create<vector::ScalableExtractOp>(
          loc, sliceType, outerProductOp.getRhs(), subTile.col);
      Value accTile = hasAccumulator ? accTiles[tileIdx] : Value{};
      auto tileOuterProduct = rewriter.create<vector::OuterProductOp>(
          loc, tileType, lhsSlice, rhsSlice, accTile,
          outerProductOp.getKind());
      Operation *maybeMaskedOp =
          vector::maskOperation(rewriter, tileOuterProduct, tileMask);
      tileResults.push_back(maybeMaskedOp->getResult(0));
    }

    rewriter.replaceOp(opToReplace, tileResults, adaptor.getResultMapping());
    return success();
  }
};
/// Entry point for masked outer products: matches a `vector.mask` whose
/// maskable op is a `vector.outerproduct` and forwards the rewrite to
/// `LegalizeVectorOuterProductOpsByDecomposition`, which handles the masking
/// op itself. Any other masked op fails to match.
struct LegalizeMaskedVectorOuterProductOpsByDecomposition
    : public OneToNOpConversionPattern<vector::MaskOp> {
  using OneToNOpConversionPattern::OneToNOpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const override {
    auto outerProductOp =
        llvm::dyn_cast<vector::OuterProductOp>(maskOp.getMaskableOp());
    if (!outerProductOp)
      return failure();

    // Delegate through the generic `RewritePattern` interface so the nested
    // pattern builds its own adaptor for the outer product.
    LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
                                                         getContext());
    return static_cast<RewritePattern &>(pattern).matchAndRewrite(
        outerProductOp, rewriter);
  }
};
/// Decomposes a `vector.transfer_read` of a vector type that is a multiple of
/// SME tiles into one SME-tile-sized `vector.transfer_read` per sub-tile.
///
/// Each new read uses the original source, permutation map, padding and
/// `in_bounds` attribute; only the indices and the mask are adjusted for the
/// sub-tile. A non-identity permutation map causes the sub-tile row/column
/// offsets to be swapped.
struct LegalizeTransferReadOpsByDecomposition
    : public OneToNOpConversionPattern<vector::TransferReadOp> {
  using OneToNOpConversionPattern::OneToNOpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const override {
    auto readType = readOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(readType))
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto mask = readOp.getMask();
    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = readOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureNonPermutationMap);

    // Any permutation other than the identity is treated as a transpose.
    const bool isTransposed = !permutationMap.isIdentity();

    auto loc = readOp.getLoc();
    auto tileType = getSMETileTypeForElement(readType.getElementType());

    SmallVector<Value> tileReads;
    for (SMESubTile subTile :
         decomposeToSMETiles(rewriter, readType, tileType, isTransposed)) {
      // The mask ops are emitted before the index ops for each sub-tile.
      Value tileMask = extractSMEMask(rewriter, loc, mask, subTile);
      SmallVector<Value, 2> tileIndices =
          getSMESubTileIndices(rewriter, loc, readOp.getIndices(), subTile);
      auto tileRead = rewriter.create<vector::TransferReadOp>(
          loc, tileType, readOp.getSource(), tileIndices,
          readOp.getPermutationMapAttr(), readOp.getPadding(), tileMask,
          readOp.getInBoundsAttr());
      tileReads.push_back(tileRead);
    }

    rewriter.replaceOp(readOp, tileReads, adaptor.getResultMapping());
    return success();
  }
};
/// Decomposes a `vector.transfer_write` of a vector type that is a multiple of
/// SME tiles into one SME-tile-sized `vector.transfer_write` per sub-tile.
///
/// With pure tensor semantics the writes are chained (each write's result is
/// the destination of the next) and the original op is replaced by the final
/// result; otherwise the original op is simply erased. A non-identity
/// permutation map causes the sub-tile row/column offsets to be swapped.
struct LegalizeTransferWriteOpsByDecomposition
    : public OneToNOpConversionPattern<vector::TransferWriteOp> {
  using OneToNOpConversionPattern::OneToNOpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const override {
    auto writtenType = writeOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(writtenType))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto mask = writeOp.getMask();
    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    // Any permutation other than the identity is treated as a transpose.
    const bool isTransposed = !permutationMap.isIdentity();
    const bool writesToTensor = writeOp.hasPureTensorSemantics();

    auto loc = writeOp.getLoc();
    auto tileType = getSMETileTypeForElement(writtenType.getElementType());
    auto sourceTiles = adaptor.getVector();

    // Current destination; updated after every write when writing tensors.
    Value destination = writeOp.getSource();
    for (auto [tileIdx, subTile] : llvm::enumerate(decomposeToSMETiles(
             rewriter, writtenType, tileType, isTransposed))) {
      // The mask ops are emitted before the index ops for each sub-tile.
      Value tileMask = extractSMEMask(rewriter, loc, mask, subTile);
      SmallVector<Value, 2> tileIndices =
          getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), subTile);
      auto tileWrite = rewriter.create<vector::TransferWriteOp>(
          loc, sourceTiles[tileIdx], destination, tileIndices,
          writeOp.getPermutationMapAttr(), tileMask,
          writeOp.getInBoundsAttr());
      if (writesToTensor)
        destination = tileWrite.getResult();
    }

    if (!writesToTensor) {
      rewriter.eraseOp(writeOp);
      return success();
    }
    rewriter.replaceOp(writeOp, destination);
    return success();
  }
};
/// Rewrites a `vector.transfer_write` of a vector type that is a multiple of
/// SME tiles as a single `scf.for` loop over tile slices. In every iteration
/// one slice is extracted from each sub-tile of the input and stored with a
/// 1-D `vector.transfer_write`.
///
/// The pattern does not match when the write has pure tensor semantics, when
/// the permutation map is not the identity, or when the mask is unsupported.
struct LegalizeMultiTileTransferWriteAsStoreLoop
: public OneToNOpConversionPattern<vector::TransferWriteOp> {
using OneToNOpConversionPattern::OneToNOpConversionPattern;
LogicalResult
matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
OneToNPatternRewriter &rewriter) const override {
// Only writes to memrefs are handled.
if (writeOp.hasPureTensorSemantics())
return rewriter.notifyMatchFailure(
writeOp, "TODO: tensor semantics are unsupported");
auto permutationMap = writeOp.getPermutationMap();
if (!permutationMap.isPermutation())
return rewriter.notifyMatchFailure(writeOp,
kMatchFailureNonPermutationMap);
// Any non-identity permutation is treated as a transpose, which is not
// handled by this pattern.
bool transposed = !permutationMap.isIdentity();
if (transposed)
return rewriter.notifyMatchFailure(writeOp,
"TODO: transpose unsupported");
auto vectorType = writeOp.getVectorType();
if (!isMultipleOfSMETileVectorType(vectorType))
return rewriter.notifyMatchFailure(writeOp,
kMatchFailureNotSMETileTypeMultiple);
// A mask must come from `vector.create_mask`, and masked writes are only
// accepted when neither vector dimension exceeds 16.
// NOTE(review): the reason for the limit of 16 is not visible in this file
// -- presumably a constraint of a later lowering; confirm.
auto mask = writeOp.getMask();
if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 ||
vectorType.getDimSize(1) > 16)))
return rewriter.notifyMatchFailure(writeOp,
kMatchFailureUnsupportedMaskOp);
auto loc = writeOp.getLoc();
auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
// Emits `vscale * multiplier` at the current insertion point.
auto createVscaleMultiple = [&](int64_t multiplier) {
return rewriter.create<arith::MulIOp>(
loc, vscale,
rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
};
// SME tile type for the element type, and the mask type of one tile slice
// (a 1-D scalable `i1` vector sized by the tile's leading dimension).
auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
auto minTileSlices = smeTileType.getDimSize(0);
VectorType sliceMaskType =
VectorType::get(minTileSlices, rewriter.getI1Type(), true);
// Loop from 0 to `vscale * minTileSlices` with step 1, i.e. once per slice
// of an SME tile.
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto upperBound = createVscaleMultiple(minTileSlices);
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto storeLoop =
rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
// Everything created from here on is placed inside the loop body.
rewriter.setInsertionPointToStart(storeLoop.getBody());
auto inputSMETiles = adaptor.getVector();
auto tileSliceIndex = storeLoop.getInductionVar();
// For every sub-tile of the input vector, store the slice selected by the
// loop induction variable.
for (auto [index, smeTile] : llvm::enumerate(
decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
// vscale-scaled row/column offset of this sub-tile within the vector.
auto tileRow = createVscaleMultiple(smeTile.row);
auto tileCol = createVscaleMultiple(smeTile.col);
// Row of the full vector that the current slice corresponds to.
auto sliceIndex =
rewriter.create<arith::AddIOp>(loc, tileRow, tileSliceIndex);
// Destination indices: the original write indices plus the slice/tile
// offsets.
auto storeRow = rewriter.create<arith::AddIOp>(loc, sliceIndex,
writeOp.getIndices()[0]);
auto storeCol =
rewriter.create<arith::AddIOp>(loc, tileCol, writeOp.getIndices()[1]);
// Mask for the current slice: extract row `sliceIndex` of the mask and,
// if that row is not already of the slice mask type, take the part
// starting at this sub-tile's column offset.
Value sliceMask = nullptr;
if (mask) {
sliceMask = rewriter.create<vector::ExtractOp>(
loc, mask, OpFoldResult(sliceIndex));
if (sliceMaskType != sliceMask.getType())
sliceMask = rewriter.create<vector::ScalableExtractOp>(
loc, sliceMaskType, sliceMask, smeTile.col);
}
// Extract the slice from the converted input tile and write it. The
// permutation map and `in_bounds` values each drop their first entry
// because the write is now 1-D.
Value tile = inputSMETiles[index];
auto slice =
rewriter.create<vector::ExtractOp>(loc, tile, tileSliceIndex);
rewriter.create<vector::TransferWriteOp>(
loc, slice, writeOp.getSource(), ValueRange{storeRow, storeCol},
AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
sliceMask,
rewriter.getBoolArrayAttr(
ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front()));
}
// The original write has no results to replace (memref semantics only).
rewriter.eraseOp(writeOp);
return success();
}
};
/// Folds a `vector.extract` (with a single extraction index) of a
/// `vector.create_mask` into a new, smaller `vector.create_mask` when the
/// extracted type is a vector with exactly two scalable dimensions.
///
/// With `%mask = create_mask %a, %b, %c` and extraction index `%i`, the
/// result is `create_mask (select (%i < %a), %b, 0), %c`. The comparison is
/// signed. The pattern does not match when `%a` is defined by an
/// `arith.constant`.
struct FoldExtractFromVectorOfSMELikeCreateMasks
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto loc = extractOp.getLoc();

    auto sourceMaskOp =
        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
    if (!sourceMaskOp)
      return rewriter.notifyMatchFailure(
          extractOp, "extract not from vector.create_mask op");

    VectorType resultMaskType =
        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
    if (!resultMaskType)
      return rewriter.notifyMatchFailure(extractOp,
                                         "extracted type is not a vector type");

    auto scalableDimCount =
        llvm::count(resultMaskType.getScalableDims(), true);
    if (scalableDimCount != 2)
      return rewriter.notifyMatchFailure(
          extractOp, "expected extracted type to be an SME-like mask");

    if (extractOp.getStaticPosition().size() != 1)
      return rewriter.notifyMatchFailure(
          extractOp, "only a single extraction index is supported");

    auto leadingBound = sourceMaskOp.getOperand(0);
    if (leadingBound.getDefiningOp<arith::ConstantOp>())
      return rewriter.notifyMatchFailure(
          extractOp,
          "constant vector.create_masks dims should be folded elsewhere");

    // The extracted mask is all-false unless the extraction index lies below
    // the leading mask bound; otherwise it keeps the remaining two bounds.
    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto position = getValueOrCreateConstantIndexOp(
        rewriter, loc, extractOp.getMixedPosition()[0]);
    auto isInsideMask = rewriter.create<arith::CmpIOp>(
        loc, rewriter.getI1Type(), arith::CmpIPredicate::slt, position,
        leadingBound);
    auto newLeadingBound = rewriter.create<arith::SelectOp>(
        loc, isInsideMask, sourceMaskOp.getOperand(1), zero);

    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
        extractOp, resultMaskType,
        ValueRange{newLeadingBound, sourceMaskOp.getOperand(2)});
    return success();
  }
};
/// Returns true if all scalable dimensions of `vType` are trailing, i.e. no
/// scalable dimension comes before a fixed-size dimension.
///
/// For example `vector<2x[4]xf32>` and `vector<[4]x[4]xf32>` are legal, while
/// `vector<[8]x4xf32>` is not.
bool isLegalVectorType(VectorType vType) {
  ArrayRef<bool> leadingDims = vType.getScalableDims();
  // Strip the (allowed) run of trailing scalable dimensions.
  while (!leadingDims.empty() && leadingDims.back())
    leadingDims = leadingDims.drop_back();
  // Whatever is left ends in a fixed dimension; none of it may be scalable.
  return !llvm::is_contained(leadingDims, true);
}
/// Rewrites a `vector.transpose` from an illegal vector type to a legal one
/// (see `isLegalVectorType`) whose source is a `vector.transfer_read`,
/// optionally followed by an `arith.extsi`/`extui`/`extf`.
///
/// The transpose is moved to memory: a `memref.subview` is taken at the
/// original read's indices, transposed with `memref.transpose`, and read with
/// a new `vector.transfer_read` of the (legal) result shape. If the original
/// value was extended, the same kind of extension op is re-created on the new
/// read.
struct LiftIllegalVectorTransposeToMemory
: public OpRewritePattern<vector::TransposeOp> {
using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
// Returns the first operand of `op` if it is an `arith.extsi`, `arith.extui`
// or `arith.extf`; returns a null Value otherwise (including when `op` is
// null).
static Value getExtensionSource(Operation *op) {
if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
return op->getOperand(0);
return {};
}
LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
auto sourceType = transposeOp.getSourceVectorType();
auto resultType = transposeOp.getResultVectorType();
// Only transposes that turn an illegal type into a legal type are handled.
if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
return rewriter.notifyMatchFailure(transposeOp,
kMatchFailureNotIllegalToLegal);
// Look through an optional extension op to find the transfer_read.
Value maybeRead = transposeOp.getVector();
auto *transposeSourceOp = maybeRead.getDefiningOp();
Operation *extendOp = nullptr;
if (Value extendSource = getExtensionSource(transposeSourceOp)) {
maybeRead = extendSource;
extendOp = transposeSourceOp;
}
auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>();
if (!illegalRead)
return rewriter.notifyMatchFailure(
transposeOp,
"expected source to be (possibly extended) transfer_read");
if (!illegalRead.getPermutationMap().isIdentity())
return rewriter.notifyMatchFailure(
illegalRead, "expected read to have identity permutation map");
auto loc = transposeOp.getLoc();
auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
// Sizes of the region that is read: each dimension's static size,
// multiplied by `vscale` for scalable dimensions.
auto readType = illegalRead.getVectorType();
auto readSizes = llvm::map_to_vector(
llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
[&](auto dim) -> Value {
auto [size, isScalable] = dim;
auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size);
if (!isScalable)
return dimSize;
auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
return rewriter.create<arith::MulIOp>(loc, vscale, dimSize);
});
// Unit-stride subview of the read source, starting at the read's indices.
// NOTE(review): no check is visible here that `illegalRead.getSource()` is
// a memref rather than a tensor -- confirm that tensor sources cannot reach
// this point.
SmallVector<Value> strides(readType.getRank(), Value(one));
auto readSubview = rewriter.create<memref::SubViewOp>(
loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes,
strides);
// The mask (if any) is transposed with the same permutation.
Value mask = illegalRead.getMask();
if (mask) {
mask = rewriter.create<vector::TransposeOp>(loc, mask,
transposeOp.getPermutation());
}
// Transpose the subview in memory using the transpose's permutation.
mlir::AffineMap transposeMap = AffineMap::getPermutationMap(
transposeOp.getPermutation(), getContext());
auto transposedSubview = rewriter.create<memref::TransposeOp>(
loc, readSubview, AffineMapAttr::get(transposeMap));
// The `in_bounds` attribute (if present) is permuted in the same way.
ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
if (inBoundsAttr) {
SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
inBoundsAttr.end());
applyPermutationToVector(inBoundsValues, transposeOp.getPermutation());
inBoundsAttr = rewriter.getArrayAttr(inBoundsValues);
}
// Read the transposed shape with the original read's element type. All
// indices are zero because the subview already starts at the read's
// indices.
VectorType legalReadType = resultType.clone(readType.getElementType());
SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero);
auto legalRead = rewriter.create<vector::TransferReadOp>(
loc, legalReadType, transposedSubview, readIndices,
illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
inBoundsAttr);
// Replace the transpose with the new read, re-creating the extension op
// (same op name, producing `resultType`) when one was looked through.
rewriter.replaceOp(transposeOp, [&]() -> Operation * {
if (extendOp)
return rewriter.create(loc, extendOp->getName().getIdentifier(),
Value(legalRead), resultType);
return legalRead;
}());
return success();
}
};
/// Rewrites a `vector.shape_cast` from an illegal 2-D vector type with a
/// trailing unit dimension to a legal type as a `vector.transpose` with
/// permutation `[1, 0]`. When the result is 1-D, a further `vector.shape_cast`
/// of the transposed value to the result type is emitted.
struct ConvertIllegalShapeCastOpsToTransposes
    : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto srcType = shapeCastOp.getSourceVectorType();
    auto dstType = shapeCastOp.getResultVectorType();
    if (isLegalVectorType(srcType) || !isLegalVectorType(dstType))
      return rewriter.notifyMatchFailure(shapeCastOp,
                                         kMatchFailureNotIllegalToLegal);

    const bool hasTrailingUnitDim =
        srcType.getRank() == 2 && srcType.getDimSize(1) == 1;
    if (!hasTrailingUnitDim)
      return rewriter.notifyMatchFailure(
          shapeCastOp, "expected source to be a 2D scalable vector with a "
                       "trailing unit dim");

    auto loc = shapeCastOp.getLoc();
    auto transpose = rewriter.create<vector::TransposeOp>(
        loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});

    if (dstType.getRank() != 1) {
      // The transposed value is used directly as the replacement.
      rewriter.replaceOp(shapeCastOp, transpose);
      return success();
    }

    // 1-D result: cast the transposed value to the result type.
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, dstType,
                                                     transpose);
    return success();
  }
};
/// Pass that applies the legalization patterns defined in this file through a
/// partial 1:N type conversion.
///
/// Vector types that are a multiple of SME tiles are converted to N copies of
/// the SME tile type for their element type; every other type is left
/// unchanged.
struct VectorLegalizationPass
    : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    OneToNTypeConverter typeConverter;
    RewritePatternSet patternSet(ctx);

    // Identity conversion for all types ...
    typeConverter.addConversion([](Type type) { return type; });
    // ... and the 1:N conversion for multiples of SME tiles.
    typeConverter.addConversion(
        [](VectorType vectorType,
           SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> {
          if (!isMultipleOfSMETileVectorType(vectorType))
            return std::nullopt;
          auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
          auto tileType =
              getSMETileTypeForElement(vectorType.getElementType());
          types.assign(tileCount, tileType);
          return success();
        });

    patternSet.add<FoldExtractFromVectorOfSMELikeCreateMasks,
                   LiftIllegalVectorTransposeToMemory,
                   ConvertIllegalShapeCastOpsToTransposes>(ctx);
    // These two patterns are registered with a pattern benefit of 1024.
    patternSet.add<LegalizeMaskedVectorOuterProductOpsByDecomposition,
                   LegalizeMultiTileTransferWriteAsStoreLoop>(typeConverter,
                                                              ctx, 1024);
    patternSet.add<LegalizeArithConstantOpsByDecomposition,
                   LegalizeVectorOuterProductOpsByDecomposition,
                   LegalizeTransferReadOpsByDecomposition,
                   LegalizeTransferWriteOpsByDecomposition>(typeConverter,
                                                            ctx);
    populateFuncTypeConversionPatterns(typeConverter, patternSet);
    scf::populateSCFStructuralOneToNTypeConversions(typeConverter, patternSet);

    if (failed(applyPartialOneToNConversion(getOperation(), typeConverter,
                                            std::move(patternSet))))
      signalPassFailure();
  }
};
}
/// Creates a new instance of the ArmSME vector legalization pass.
std::unique_ptr<Pass> mlir::arm_sme::createVectorLegalizationPass() {
  std::unique_ptr<Pass> pass = std::make_unique<VectorLegalizationPass>();
  return pass;
}