#include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTARMSMETOSCF
#include "mlir/Conversion/Passes.h.inc"
}
using namespace mlir;
namespace {
/// Builds the memref indices used to access a single tile slice.
///
/// The leading index is offset by the slice position; a second index, if
/// present, is forwarded unchanged:
///   rank 1: (indices[0] + tileSliceIndex * tileSliceNumElts)
///   rank 2: (indices[0] + tileSliceIndex, indices[1])
///
/// The `arith` ops required for the computation are created through
/// `rewriter` at its current insertion point. Only ranks 1 and 2 are
/// supported, which is enforced by an assertion.
SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank,
                                       Value tileSliceIndex,
                                       Value tileSliceNumElts, Location loc,
                                       PatternRewriter &rewriter) {
  assert((rank == 1 || rank == 2) && "memref has unexpected rank!");

  // For a rank-1 memref the slice index is scaled by the number of elements
  // per slice; for rank 2 it is used as the offset directly.
  Value sliceOffset = tileSliceIndex;
  if (rank == 1)
    sliceOffset = rewriter.create<arith::MulIOp>(loc, tileSliceIndex,
                                                 tileSliceNumElts);

  Value leadingIndex =
      rewriter.create<arith::AddIOp>(loc, indices[0], sliceOffset);

  SmallVector<Value, 2> adjusted;
  adjusted.push_back(leadingIndex);
  if (rank == 2)
    adjusted.push_back(indices[1]);
  return adjusted;
}
/// Creates an `scf.for` that iterates over the slices of an ArmSME tile and
/// populates its body through `makeLoopBody`.
///
/// The loop runs from 0 with step 1. The number of tile slices is computed as
/// `getSMETileSliceMinNumElts(elementType) * vscale`.
///   * Without a `mask`, the upper bound is the number of tile slices and the
///     per-slice predicate is an all-true constant.
///   * With a `mask` (which must be defined by `vector.create_mask`), the
///     upper bound is the signed minimum of the mask's first operand and the
///     number of tile slices, and the per-slice predicate is a 1-D
///     `vector.create_mask` built from the mask's second operand.
///
/// If `initTile` is non-null it becomes the single loop-carried value.
/// `makeLoopBody` is called with (tile slice index, adjusted memref indices,
/// predicate, current tile) with the insertion point at the start of the loop
/// body; it must return the next tile if and only if `initTile` was provided,
/// in which case an `scf.yield` of that value is created.
///
/// Returns failure, without having created any IR, if `mask` is present but
/// not defined by `vector.create_mask`. The caller's insertion point is
/// restored on return.
FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
    PatternRewriter &rewriter, Location loc, VectorType tileType,
    ValueRange memrefIndices, int memrefRank, Value mask, Value initTile,
    function_ref<Value(/*index=*/Value, ValueRange, /*predicate=*/Value,
                       /*currentTile=*/Value)>
        makeLoopBody) {
  // Validate the mask before building anything: a pattern must not leave
  // newly created ops behind when it reports a match failure.
  vector::CreateMaskOp createMaskOp;
  if (mask) {
    createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return rewriter.notifyMatchFailure(
          loc, "unsupported mask op, only 'vector.create_mask' is "
               "currently supported");
  }

  PatternRewriter::InsertionGuard guard(rewriter);

  auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
      loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
  auto vscale =
      rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
  // i1 vector type sized by the tile's second dimension (built with the
  // scalable flag set); used for the per-slice predicate.
  auto predicateType =
      VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);

  // Number of tile slices; also passed to getMemrefIndices as the number of
  // elements per slice.
  auto numTileSlices =
      rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);

  Value predicate;
  Value upperBound;
  if (mask) {
    auto maskDim0 = createMaskOp.getOperands()[0];
    auto maskDim1 = createMaskOp.getOperands()[1];

    // Clamp the number of rows requested by the mask to the number of tile
    // slices so the loop never iterates past the tile. The minimum is taken
    // on i64 values and cast back to index.
    auto numRowI64 = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI64Type(), maskDim0);
    auto numTileSlicesI64 = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI64Type(), numTileSlices);
    auto upperBoundI64 =
        rewriter.create<arith::MinSIOp>(loc, numRowI64, numTileSlicesI64);
    upperBound = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getIndexType(), upperBoundI64);

    predicate =
        rewriter.create<vector::CreateMaskOp>(loc, predicateType, maskDim1);
  } else {
    upperBound = numTileSlices;
    // No mask: use an all-true predicate for every tile slice.
    predicate = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(predicateType, true));
  }

  bool hasCarriedArgs = bool(initTile);
  auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
                                           hasCarriedArgs ? ValueRange{initTile}
                                                          : ValueRange{});

  rewriter.setInsertionPointToStart(forOp.getBody());
  Value tileSliceIndex = forOp.getInductionVar();

  auto adjustedIndices = getMemrefIndices(
      memrefIndices, memrefRank, tileSliceIndex, numTileSlices, loc, rewriter);
  auto nextTile = makeLoopBody(
      tileSliceIndex, adjustedIndices, predicate,
      /*currentTile=*/hasCarriedArgs ? forOp.getRegionIterArg(0) : Value{});

  // The callback must produce a tile exactly when one is carried by the loop.
  assert(bool(nextTile) == hasCarriedArgs);
  if (nextTile)
    rewriter.create<scf::YieldOp>(loc, nextTile);

  return forOp;
}
/// Convenience overload for loops that do not carry a tile value.
///
/// Forwards to the carried-tile overload with a null initial tile, wrapping
/// `makeLoopBody` (called with tile slice index, adjusted memref indices and
/// predicate) in a callback that ignores the current-tile argument and returns
/// a null `Value`.
FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
    PatternRewriter &rewriter, Location loc, VectorType tileType,
    ValueRange memrefIndices, int memrefRank, Value mask,
    function_ref<void(/*index=*/Value, ValueRange, /*predicate=*/Value)>
        makeLoopBody) {
  auto bodyWithoutTile = [&](Value sliceIndex, ValueRange sliceIndices,
                             Value slicePredicate,
                             Value /*currentTile*/) -> Value {
    makeLoopBody(sliceIndex, sliceIndices, slicePredicate);
    return Value{};
  };
  return createLoadStoreForOverTileSlices(rewriter, loc, tileType,
                                          memrefIndices, memrefRank, mask,
                                          /*initTile=*/Value{},
                                          bodyWithoutTile);
}
/// Lowers `arm_sme.tile_load` to an `scf.for` over tile slices whose body
/// loads one slice with `arm_sme.load_tile_slice`.
///
/// Handles unmasked loads and masked loads whose padding is a constant zero.
/// Masked loads with any other padding are rejected here with a match
/// failure, as are masks that are not defined by `vector.create_mask`.
struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
  using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = tileLoadOp.getLoc();
    auto tileType = tileLoadOp.getVectorType();
    auto mask = tileLoadOp.getMask();

    Value initTile;
    if (mask) {
      auto padOp = tileLoadOp.getPadding();
      assert(padOp && "expected padding when masking!");

      auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
      if (!constPadOp || constPadOp.getValue() !=
                             rewriter.getZeroAttr(tileType.getElementType()))
        return rewriter.notifyMatchFailure(
            tileLoadOp, "op has non-zero pad, needs non-zero pad pattern");

      // createLoadStoreForOverTileSlices only accepts masks defined by
      // 'vector.create_mask'. Check that here, before the init tile is
      // created, so no op is left behind when the match fails.
      if (!mask.getDefiningOp<vector::CreateMaskOp>())
        return rewriter.notifyMatchFailure(
            tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
                        "currently supported");

      // The masked loop stops at the (clamped) number of rows given by the
      // mask, so rows beyond it are never loaded. Start from a zeroed tile
      // so those rows hold the zero padding.
      initTile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
    } else {
      // Unmasked: the loop covers every tile slice.
      initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
    }

    // Create a loop that loads the tile slices from memory, threading the
    // tile through the loop as a carried value.
    auto forOp = createLoadStoreForOverTileSlices(
        rewriter, loc, tileType, tileLoadOp.getIndices(),
        tileLoadOp.getMemRefType().getRank(), mask, initTile,
        [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate,
            Value currentTile) -> Value {
          return rewriter.create<arm_sme::LoadTileSliceOp>(
              loc, tileType, tileLoadOp.getBase(), predicate, currentTile,
              memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
        });

    if (failed(forOp))
      return forOp;

    // Replace 'arm_sme.tile_load' with the tile produced by the loop.
    rewriter.replaceOp(tileLoadOp, forOp->getResult(0));

    return success();
  }
};
/// Lowers a masked `arm_sme.tile_load` whose padding is not a constant zero.
///
/// The pattern reports a match failure when the op has no mask, when the mask
/// is not defined by `vector.create_mask`, or when the padding is a constant
/// zero. Otherwise it emits an `scf.for` over all
/// `getSMETileSliceMinNumElts(elementType) * vscale` tile slices that carries
/// the tile as a loop value. Each iteration:
///   1. builds a 1-D mask whose size is the mask's column operand when the
///      slice index is (unsigned) less than the mask's row operand, and zero
///      otherwise;
///   2. performs a `vector.maskedload` of the slice using the splatted padding
///      as pass-through;
///   3. inserts the slice with `arm_sme.move_vector_to_tile_slice` and yields
///      the updated tile.
struct TileLoadOpWithMaskAndPadNonZeroConversion
    : public OpRewritePattern<arm_sme::TileLoadOp> {
  using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
                                PatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard insertionGuard(rewriter);

    // --- Matching: bail out before touching the IR. ---
    auto tileMask = tileLoadOp.getMask();
    if (!tileMask)
      return rewriter.notifyMatchFailure(
          tileLoadOp, "op has no mask, needs unmasked pattern");

    auto padValue = tileLoadOp.getPadding();
    assert(padValue && "expected padding when masking!");

    auto maskDef = tileMask.getDefiningOp<vector::CreateMaskOp>();
    if (!maskDef)
      return rewriter.notifyMatchFailure(
          tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
                      "currently supported");

    auto tileType = tileLoadOp.getVectorType();
    auto elementType = tileType.getElementType();
    if (auto constPad = padValue.getDefiningOp<arith::ConstantOp>()) {
      if (constPad.getValue() == rewriter.getZeroAttr(elementType))
        return rewriter.notifyMatchFailure(
            tileLoadOp, "op has constant zero pad, needs zero pad pattern");
    }

    // --- Rewriting. ---
    auto loc = tileLoadOp.getLoc();

    // Types used inside the loop; computing them creates no operations.
    auto slicePredicateType =
        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
    VectorType sliceType = VectorType::Builder(tileType).dropDim(0);

    auto activeRows = maskDef.getOperands()[0];
    auto activeCols = maskDef.getOperands()[1];

    auto activeColsI32 = rewriter.create<arith::IndexCastUIOp>(
        loc, rewriter.getI32Type(), activeCols);

    auto startTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);

    // Loop bounds: [0, minTileSlices * vscale) with unit step.
    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto minSlices = rewriter.create<arith::ConstantIndexOp>(
        loc, arm_sme::getSMETileSliceMinNumElts(elementType));
    auto vscaleOp =
        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto sliceCount = rewriter.create<arith::MulIOp>(loc, minSlices, vscaleOp);
    auto sliceLoop = rewriter.create<scf::ForOp>(loc, zero, sliceCount, one,
                                                 ValueRange{startTile});

    rewriter.setInsertionPointToStart(sliceLoop.getBody());

    auto sliceIndex = sliceLoop.getInductionVar();
    auto carriedTile = sliceLoop.getRegionIterArg(0);

    // Fold the row and column masks into one size: sign-extending the i1
    // comparison gives all-ones for an active row, so the AND keeps the
    // column count for active rows and yields zero for inactive ones.
    auto rowActive = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::ult, sliceIndex, activeRows);
    auto rowActiveI32 = rewriter.create<arith::ExtSIOp>(
        loc, rewriter.getI32Type(), rowActive);
    auto sliceMaskSizeI32 =
        rewriter.create<arith::AndIOp>(loc, rowActiveI32, activeColsI32);
    auto sliceMaskSize = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getIndexType(), sliceMaskSizeI32);
    auto sliceMask = rewriter.create<vector::CreateMaskOp>(
        loc, slicePredicateType, sliceMaskSize.getResult());

    auto sliceIndices = getMemrefIndices(
        tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(),
        sliceIndex, sliceCount, loc, rewriter);

    // Splat the padding into a 1-D vector of the slice type so it can serve
    // as the pass-through operand of the masked load.
    auto padSlice =
        rewriter.create<vector::SplatOp>(loc, sliceType, padValue);

    auto loadedSlice = rewriter.create<vector::MaskedLoadOp>(
        loc, sliceType, tileLoadOp.getBase(), sliceIndices, sliceMask,
        /*passthru=*/padSlice);

    // Move the loaded slice into the tile and carry the result forward.
    auto updatedTile = rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
        loc, tileType, loadedSlice->getResult(0), carriedTile, sliceIndex,
        tileLoadOp.getLayout());
    rewriter.create<scf::YieldOp>(loc, updatedTile.getResult());

    rewriter.setInsertionPointAfter(sliceLoop);

    // Replace 'arm_sme.tile_load' with the tile produced by the loop.
    rewriter.replaceOp(tileLoadOp, sliceLoop.getResult(0));

    return success();
  }
};
/// Lowers `arm_sme.tile_store` to an `scf.for` over tile slices whose body is
/// a single `arm_sme.store_tile_slice`.
///
/// The loop is built by `createLoadStoreForOverTileSlices` without a carried
/// tile; the original op is replaced by the slice store created inside the
/// loop body. Fails if the helper rejects the op's mask.
struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
  using OpRewritePattern<arm_sme::TileStoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
                                PatternRewriter &rewriter) const override {
    // Emits the per-slice store in place of the original tile store.
    auto storeOneSlice = [&](Value sliceIndex, ValueRange sliceIndices,
                             Value slicePredicate) {
      rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
          tileStoreOp, tileStoreOp.getValueToStore(), sliceIndex,
          slicePredicate, tileStoreOp.getBase(), sliceIndices,
          tileStoreOp.getLayout());
    };

    return createLoadStoreForOverTileSlices(
        rewriter, tileStoreOp.getLoc(), tileStoreOp.getVectorType(),
        tileStoreOp.getIndices(), tileStoreOp.getMemRefType().getRank(),
        tileStoreOp.getMask(), storeOneSlice);
  }
};
}
/// Registers the ArmSME-to-SCF lowering patterns (both tile load lowerings
/// and the tile store lowering) in `patterns`.
void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  // Same registration order as a single variadic add<> call.
  patterns.add<TileLoadOpConversion>(context);
  patterns.add<TileLoadOpWithMaskAndPadNonZeroConversion>(context);
  patterns.add<TileStoreOpConversion>(context);
}
namespace {
/// Pass that runs the ArmSME-to-SCF patterns as a partial dialect conversion.
///
/// The ArmSME, Vector, Arith and SCF dialects are marked legal, while
/// `arm_sme.tile_load` and `arm_sme.tile_store` are explicitly marked illegal.
/// The pass signals failure if the conversion does not succeed.
struct ConvertArmSMEToSCFPass
    : public impl::ConvertArmSMEToSCFBase<ConvertArmSMEToSCFPass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();

    ConversionTarget target(context);
    target.addLegalDialect<arm_sme::ArmSMEDialect, vector::VectorDialect,
                           arith::ArithDialect, scf::SCFDialect>();
    target.addIllegalOp<arm_sme::TileLoadOp, arm_sme::TileStoreOp>();

    RewritePatternSet patterns(&context);
    populateArmSMEToSCFConversionPatterns(patterns);

    if (succeeded(applyPartialConversion(getOperation(), target,
                                         std::move(patterns))))
      return;
    signalPassFailure();
  }
};
}
/// Factory for the ArmSME-to-SCF conversion pass.
std::unique_ptr<Pass> mlir::createConvertArmSMEToSCFPass() {
  std::unique_ptr<Pass> pass = std::make_unique<ConvertArmSMEToSCFPass>();
  return pass;
}