#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTARMSMETOLLVM
#include "mlir/Conversion/Passes.h.inc"
}
using namespace mlir;
namespace {
/// Name of the discardable attribute attached to the `memref.alloca` that backs
/// an in-memory tile. Its integer value is the tile ID the alloca stands in
/// for, and is used to find an existing alloca for that ID.
static constexpr StringLiteral kInMemoryTileIdAttr("arm_sme.in_memory_tile_id");
/// Creates the `aarch64_sme_ld1{b,h,w,d,q}_{horiz,vert}` intrinsic op that
/// matches the given tile `type` and slice `layout`.
///
/// \param rewriter      Rewriter used to create the op at its current
///                      insertion point.
/// \param loc           Location attached to the new op.
/// \param type          SME tile type; selects the b/h/w/d/q variant.
/// \param layout        `Horizontal` selects the `_horiz` variant, any other
///                      value selects the `_vert` variant.
/// \param maskOp        Predicate operand forwarded to the intrinsic.
/// \param ptr           Pointer operand forwarded to the intrinsic.
/// \param tileId        Tile ID attribute forwarded to the intrinsic.
/// \param tileSliceI32  Tile slice index operand forwarded to the intrinsic.
/// \returns The newly created intrinsic operation.
static Operation *createLoadTileSliceIntrinsic(
    RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
    arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
    IntegerAttr tileId, Value tileSliceI32) {
  if (layout == arm_sme::TileSliceLayout::Horizontal) {
    switch (type) {
    case arm_sme::ArmSMETileType::ZAB:
      return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAH:
      return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAS:
      return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAD:
      return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAQ:
      return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    }
  } else {
    switch (type) {
    case arm_sme::ArmSMETileType::ZAB:
      return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAH:
      return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAS:
      return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAD:
      return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAQ:
      return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    }
  }
  // Every enumerator returns above. Without this, an out-of-range `type`
  // would fall off the end of a non-void function (undefined behaviour), and
  // compilers that do not treat the switches as exhaustive warn about it.
  llvm_unreachable("unknown type in createLoadTileSliceIntrinsic");
}
/// Creates the `aarch64_sme_st1{b,h,w,d,q}_{horiz,vert}` intrinsic op that
/// matches the given tile `type` and slice `layout`.
///
/// \param rewriter      Rewriter used to create the op at its current
///                      insertion point.
/// \param loc           Location attached to the new op.
/// \param type          SME tile type; selects the b/h/w/d/q variant.
/// \param layout        `Horizontal` selects the `_horiz` variant, any other
///                      value selects the `_vert` variant.
/// \param maskOp        Predicate operand forwarded to the intrinsic.
/// \param ptr           Pointer operand forwarded to the intrinsic.
/// \param tileId        Tile ID attribute forwarded to the intrinsic.
/// \param tileSliceI32  Tile slice index operand forwarded to the intrinsic.
/// \returns The newly created intrinsic operation.
static Operation *createStoreTileSliceIntrinsic(
    RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
    arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
    IntegerAttr tileId, Value tileSliceI32) {
  if (layout == arm_sme::TileSliceLayout::Horizontal) {
    switch (type) {
    case arm_sme::ArmSMETileType::ZAB:
      return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAH:
      return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAS:
      return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAD:
      return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAQ:
      return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    }
  } else {
    switch (type) {
    case arm_sme::ArmSMETileType::ZAB:
      return rewriter.create<arm_sme::aarch64_sme_st1b_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAH:
      return rewriter.create<arm_sme::aarch64_sme_st1h_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAS:
      return rewriter.create<arm_sme::aarch64_sme_st1w_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAD:
      return rewriter.create<arm_sme::aarch64_sme_st1d_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAQ:
      return rewriter.create<arm_sme::aarch64_sme_st1q_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    }
  }
  // Every enumerator returns above. Without this, an out-of-range `type`
  // would fall off the end of a non-void function (undefined behaviour), and
  // compilers that do not treat the switches as exhaustive warn about it.
  llvm_unreachable("unknown type in createStoreTileSliceIntrinsic");
}
/// Returns the tile ID attribute of `op`. When the op has no tile ID, an op
/// error is emitted on it and the (null) attribute is returned, so callers can
/// test the result for truthiness.
IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
  IntegerAttr allocatedId = op.getTileId();
  if (allocatedId)
    return allocatedId;

  op.emitOpError("expected tile ID to be allocated before conversion to LLVM");
  return allocatedId;
}
/// Creates a 2-D dynamically-sized `memref.alloca` at the start of the first
/// block of `func`, with the element type of `tileOp`'s tile type. Both
/// dimensions are `vscale * getSMETileSliceMinNumElts(elementType)`.
///
/// The rewriter's insertion point is restored before returning.
static memref::AllocaOp
createAllocaForTile(RewriterBase &rewriter, Location loc,
                    FunctionOpInterface func,
                    arm_sme::ArmSMETileOpInterface tileOp) {
  RewriterBase::InsertionGuard restoreInsertionPoint(rewriter);
  auto &entryBlock = func.getBlocks().front();
  rewriter.setInsertionPointToStart(&entryBlock);

  // Runtime tile dimension: vscale * (minimum elements per slice).
  auto vscaleOp = rewriter.create<vector::VectorScaleOp>(loc);
  auto elementType = tileOp.getTileType().getElementType();
  unsigned minSliceElts = arm_sme::getSMETileSliceMinNumElts(elementType);
  auto minSliceEltsOp =
      rewriter.create<arith::ConstantIndexOp>(loc, minSliceElts);
  auto tileDim = rewriter.create<arith::MulIOp>(loc, vscaleOp, minSliceEltsOp);

  auto tileMemRefType = MemRefType::get(
      {ShapedType::kDynamic, ShapedType::kDynamic}, elementType);
  return rewriter.create<memref::AllocaOp>(loc, tileMemRefType,
                                           ValueRange{tileDim, tileDim});
}
/// Returns the `memref.alloca` in the first block of `func` whose
/// `kInMemoryTileIdAttr` integer attribute equals `tileId`. If none exists, a
/// new alloca is created via `createAllocaForTile` and tagged with `tileId`.
static memref::AllocaOp getOrCreateAllocaForTile(
    RewriterBase &rewriter, Location loc, FunctionOpInterface func,
    arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) {
  for (Operation &candidate : func.getBlocks().front()) {
    if (auto existingAlloca = llvm::dyn_cast<memref::AllocaOp>(candidate)) {
      // Only allocas carrying an integer tile-ID tag are considered.
      auto taggedId = llvm::dyn_cast_or_null<IntegerAttr>(
          existingAlloca->getDiscardableAttr(kInMemoryTileIdAttr));
      if (taggedId && taggedId.getInt() == tileId)
        return existingAlloca;
    }
  }

  // No match: create a fresh alloca and tag it so later lookups find it.
  auto freshAlloca = createAllocaForTile(rewriter, loc, func, tileOp);
  freshAlloca->setDiscardableAttr(kInMemoryTileIdAttr,
                                  rewriter.getI32IntegerAttr(tileId));
  return freshAlloca;
}
struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName,
const LLVMTypeConverter &typeConverter,
PatternBenefit benefit)
: ConvertToLLVMPattern(rootOpName, &typeConverter.getContext(),
typeConverter, benefit) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto tileOp = cast<arm_sme::ArmSMETileOpInterface>(op);
if (!tileOp.isInMemoryTile())
return failure();
tileOp->emitWarning(
"failed to allocate SME virtual tile to operation, tile value will go "
"through memory, expect degraded performance");
auto loc = tileOp.getLoc();
auto func = tileOp->getParentOfType<FunctionOpInterface>();
auto tileAlloca = getOrCreateAllocaForTile(rewriter, loc, func, tileOp,
tileOp.getTileId().getInt());
auto zeroTileId = rewriter.getI32IntegerAttr(0);
rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });
VectorType tileVectorType = tileOp.getTileType();
auto sliceType = VectorType::Builder(tileVectorType).dropDim(0);
auto swapInMemoryTileWithSMETileZero = [&] {
emitFullTileSwap(rewriter, loc, tileAlloca,
*arm_sme::getSMETileType(tileVectorType), sliceType,
zeroTileId);
};
{
rewriter.setInsertionPoint(op);
swapInMemoryTileWithSMETileZero();
rewriter.setInsertionPointAfter(op);
swapInMemoryTileWithSMETileZero();
}
return success();
}
Value getInMemoryTileSlicePtr(RewriterBase &rewriter, Location loc,
Value tileMemory, Value sliceIndex) const {
auto llvmType = getTypeConverter()->convertType(tileMemory.getType());
auto descriptor =
rewriter.create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory);
auto zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 64);
auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getI64Type(), sliceIndex);
return getStridedElementPtr(
loc, llvm::cast<MemRefType>(tileMemory.getType()),
descriptor.getResult(0), {sliceIndexI64, zero},
static_cast<ConversionPatternRewriter &>(rewriter));
}
void emitSliceSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
arm_sme::ArmSMETileType tileType, VectorType sliceType,
IntegerAttr tileId, Value sliceIndex) const {
auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getI32Type(), sliceIndex);
auto predicateType = sliceType.clone(rewriter.getI1Type());
auto allTruePredicate = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(predicateType, true));
auto padVector = rewriter.create<LLVM::UndefOp>(loc, sliceType);
auto slicePtr =
getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
auto currentTileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32);
createLoadTileSliceIntrinsic(
rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
allTruePredicate, slicePtr, tileId, sliceIndexI32);
auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
rewriter.create<vector::StoreOp>(loc, currentTileSlice, tileAlloca,
ValueRange{sliceIndex, zero});
}
void emitFullTileSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
arm_sme::ArmSMETileType tileType, VectorType sliceType,
IntegerAttr tileId) const {
RewriterBase::InsertionGuard guard(rewriter);
auto minNumElts =
rewriter.create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0));
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto upperBound = rewriter.create<arith::MulIOp>(
loc, minNumElts, rewriter.create<vector::VectorScaleOp>(loc));
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
rewriter.setInsertionPointToStart(forOp.getBody());
auto sliceIndex = forOp.getInductionVar();
emitSliceSwap(rewriter, loc, tileAlloca, tileType, sliceType, tileId,
sliceIndex);
}
};
/// Tag selecting whether `addArmSMEConversionPattern` may also register
/// `ConvertArmSMESpillsAndFillsToLLVM` for a pattern's source op.
enum class RequiresSpillsAndFills { Yes, No };

/// Base class for ArmSME -> LLVM op conversion patterns. It exposes the source
/// op as `ArmSMEOp` and records, at compile time, whether the pattern opts in
/// to spill/fill handling (default: yes).
template <typename SourceOp, RequiresSpillsAndFills requiresSpillsAndFills =
                                 RequiresSpillsAndFills::Yes>
struct ConvertArmSMEOpToLLVMPattern : ConvertOpToLLVMPattern<SourceOp> {
  using ArmSMEOp = SourceOp;
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;

  /// True iff this pattern was declared with `RequiresSpillsAndFills::Yes`.
  static constexpr bool requiresSpillsAndFillsConversion() {
    constexpr bool optedIn =
        (RequiresSpillsAndFills::Yes == requiresSpillsAndFills);
    return optedIn;
  }
};
/// Adds `Pattern` to `patterns`. If the pattern opts in to spills/fills and
/// its source op implements `ArmSMETileOpInterface`, a
/// `ConvertArmSMESpillsAndFillsToLLVM` pattern rooted at the same op name is
/// added first, with benefit 1337.
template <typename Pattern>
static void addArmSMEConversionPattern(RewritePatternSet &patterns,
                                       LLVMTypeConverter const &typeConverter) {
  using SourceOp = typename Pattern::ArmSMEOp;
  using TileOpTrait = arm_sme::ArmSMETileOpInterface::Trait<SourceOp>;

  if constexpr (Pattern::requiresSpillsAndFillsConversion() &&
                std::is_base_of_v<TileOpTrait, SourceOp>) {
    patterns.add<ConvertArmSMESpillsAndFillsToLLVM>(
        SourceOp::getOperationName(), typeConverter, /*benefit=*/1337);
  }

  patterns.add<Pattern>(typeConverter);
}
/// Registers every pattern type in `Patterns` by calling
/// `addArmSMEConversionPattern` once per type, in pack order.
template <typename... Patterns>
static void
addArmSMEConversionPatterns(RewritePatternSet &patterns,
LLVMTypeConverter const &typeConverter) {
// Comma fold: expands to one addArmSMEConversionPattern call per pattern.
(addArmSMEConversionPattern<Patterns>(patterns, typeConverter), ...);
}
/// Lowers `arm_sme::ZeroOp` to an `aarch64_sme_zero` intrinsic whose i32 mask
/// is a per-tile-type base mask shifted left by the op's tile ID. The op itself
/// is replaced with an `arm_sme::GetTileOp` of the same vector type.
struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = zero.getLoc();

    auto tileId = getTileIdOrError(zero);
    if (!tileId)
      return failure();

    // Base mask for the first tile of each element size.
    // NOTE(review): presumably one bit per 64-bit ZA tile, as in the Arm ZERO
    // instruction encoding - confirm against the architecture reference.
    arm_sme::ArmSMETileType tileType =
        *arm_sme::getSMETileType(zero.getTileType());
    int32_t baseMask = 0;
    switch (tileType) {
    case arm_sme::ArmSMETileType::ZAB:
      baseMask = 0b1111'1111;
      break;
    case arm_sme::ArmSMETileType::ZAH:
      baseMask = 0b0101'0101;
      break;
    case arm_sme::ArmSMETileType::ZAS:
      baseMask = 0b0001'0001;
      break;
    case arm_sme::ArmSMETileType::ZAD:
      baseMask = 0b0000'0001;
      break;
    default:
      llvm_unreachable("bad element size");
    }

    // Shift the base mask to address the allocated tile.
    const int32_t zeroMask = baseMask << int32_t(tileId.getInt());
    rewriter.create<arm_sme::aarch64_sme_zero>(
        loc, rewriter.getI32IntegerAttr(zeroMask));

    // The intrinsic has no result; a GetTileOp stands in for the op's value.
    rewriter.replaceOpWithNewOp<arm_sme::GetTileOp>(zero, zero.getVectorType());
    return success();
  }
};
/// Lowers `arm_sme::LoadTileSliceOp` to the matching ld1 tile-slice intrinsic
/// and replaces the op with its input tile.
struct LoadTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp,
                  arm_sme::LoadTileSliceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = loadTileSliceOp.getLoc();

    auto tileId = getTileIdOrError(loadTileSliceOp);
    if (!tileId)
      return failure();

    // Address of the element selected by the op's base + indices.
    Value srcPtr = this->getStridedElementPtr(
        loc, loadTileSliceOp.getMemRefType(), adaptor.getBase(),
        adaptor.getIndices(), rewriter);

    // The intrinsic takes the slice index as an i32.
    auto sliceIndexI32 = rewriter.create<arith::IndexCastUIOp>(
        loc, rewriter.getI32Type(), loadTileSliceOp.getTileSliceIndex());

    auto mask = loadTileSliceOp.getMask();
    arm_sme::ArmSMETileType smeTileType =
        *arm_sme::getSMETileType(loadTileSliceOp.getVectorType());
    arm_sme::TileSliceLayout sliceLayout = loadTileSliceOp.getLayout();

    createLoadTileSliceIntrinsic(rewriter, loc, smeTileType, sliceLayout, mask,
                                 srcPtr, tileId, sliceIndexI32);

    // The created intrinsic's result is not used; the op's uses are forwarded
    // to the input tile instead.
    rewriter.replaceOp(loadTileSliceOp, loadTileSliceOp.getTile());
    return success();
  }
};
/// Lowers `arm_sme::StoreTileSliceOp` by replacing it with the matching st1
/// tile-slice intrinsic.
struct StoreTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp,
                  arm_sme::StoreTileSliceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = storeTileSliceOp.getLoc();
    auto tileVectorType = storeTileSliceOp.getVectorType();

    auto tileId = getTileIdOrError(storeTileSliceOp);
    if (!tileId)
      return failure();

    // Address of the element selected by the op's base + indices.
    Value dstPtr = this->getStridedElementPtr(
        loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
        adaptor.getIndices(), rewriter);

    // The intrinsic takes the slice index as an i32.
    auto sliceIndexI32 = rewriter.create<arith::IndexCastUIOp>(
        loc, rewriter.getI32Type(), storeTileSliceOp.getTileSliceIndex());

    auto mask = storeTileSliceOp.getMask();
    arm_sme::TileSliceLayout sliceLayout = storeTileSliceOp.getLayout();
    arm_sme::ArmSMETileType smeTileType =
        *arm_sme::getSMETileType(tileVectorType);

    Operation *storeIntrinsic = createStoreTileSliceIntrinsic(
        rewriter, loc, smeTileType, sliceLayout, mask, dstPtr, tileId,
        sliceIndexI32);
    rewriter.replaceOp(storeTileSliceOp, storeIntrinsic);
    return success();
  }
};
/// Lowers `arm_sme::MoveVectorToTileSliceOp` to an
/// `aarch64_sme_write_{horiz,vert}` intrinsic under an all-active predicate,
/// and replaces the op with its input tile.
struct MoveVectorToTileSliceConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::MoveVectorToTileSliceOp moveVectorToTileSliceOp,
                  arm_sme::MoveVectorToTileSliceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = moveVectorToTileSliceOp.getLoc();
    auto tileType = moveVectorToTileSliceOp.getTileType();

    auto tileId = getTileIdOrError(moveVectorToTileSliceOp);
    if (!tileId)
      return failure();

    // The intrinsic takes the slice index as an i32.
    auto sliceIndexI32 = rewriter.create<arith::IndexCastUIOp>(
        loc, rewriter.getI32Type(),
        moveVectorToTileSliceOp.getTileSliceIndex());

    // Splat an i1 `1` into a scalable 1-D predicate sized by the tile's
    // leading dimension.
    auto i1Type = rewriter.getI1Type();
    auto trueBit = rewriter.create<arith::ConstantOp>(
        loc, i1Type, rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
    auto maskType = VectorType::get(tileType.getShape()[0],
                                    rewriter.getI1Type(), {true});
    auto allActive = rewriter.create<vector::SplatOp>(loc, maskType, trueBit);

    switch (moveVectorToTileSliceOp.getLayout()) {
    case arm_sme::TileSliceLayout::Horizontal: {
      rewriter.create<arm_sme::aarch64_sme_write_horiz>(
          loc, tileId, sliceIndexI32, allActive,
          moveVectorToTileSliceOp.getVector());
      break;
    }
    case arm_sme::TileSliceLayout::Vertical: {
      rewriter.create<arm_sme::aarch64_sme_write_vert>(
          loc, tileId, sliceIndexI32, allActive,
          moveVectorToTileSliceOp.getVector());
      break;
    }
    }

    // The write intrinsic's result is not used; the op's uses are forwarded
    // to the input tile.
    rewriter.replaceOp(moveVectorToTileSliceOp,
                       moveVectorToTileSliceOp.getTile());
    return success();
  }
};
/// Lowers `arm_sme::MoveTileSliceToVectorOp` to an
/// `aarch64_sme_read_{horiz,vert}` intrinsic, using an all-true predicate and
/// a zero vector as the passthrough operand.
struct MoveTileSliceToVectorConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::MoveTileSliceToVectorOp moveTileSliceToVector,
                  OpAdaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = moveTileSliceToVector.getLoc();
    auto sliceType = moveTileSliceToVector.getSliceType();
    auto sliceIndex = moveTileSliceToVector.getTileSliceIndex();

    auto tileId = getTileIdOrError(moveTileSliceToVector);
    if (!tileId)
      return failure();

    // All-true i1 predicate with the slice's shape.
    auto maskType = sliceType.cloneWith({}, rewriter.getI1Type());
    auto allTrueMask = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(maskType, true));

    // Zero-filled passthrough vector.
    auto zeroPassthru = rewriter.create<arith::ConstantOp>(
        loc, sliceType, rewriter.getZeroAttr(sliceType));

    // The intrinsic takes the slice index as an i32.
    auto sliceIndexAsI32 = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI32Type(), sliceIndex);

    switch (moveTileSliceToVector.getLayout()) {
    case arm_sme::TileSliceLayout::Horizontal: {
      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
          moveTileSliceToVector, sliceType, zeroPassthru, allTrueMask, tileId,
          sliceIndexAsI32);
      break;
    }
    case arm_sme::TileSliceLayout::Vertical: {
      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
          moveTileSliceToVector, sliceType, zeroPassthru, allTrueMask, tileId,
          sliceIndexAsI32);
      break;
    }
    }

    return success();
  }
};
/// Lowers `arm_sme::OuterProductOp` to an `aarch64_sme_mopa` intrinsic.
///
/// Only `CombiningKind::Add` is accepted, and only for rank-2, fully scalable
/// f16/bf16/f32/f64 result vectors whose shape is
/// `MinStreamingVectorLengthInBits / elementBitWidth` in both dimensions. A
/// missing accumulator is replaced by a new `arm_sme::ZeroOp` on the same tile;
/// if either mask is missing, both are replaced with an all-true mask.
struct OuterProductOpConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::OuterProductOp> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
                  arm_sme::OuterProductOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto tileId = getTileIdOrError(outerProductOp);
    if (!tileId)
      return failure();

    const auto isSupportedType = [](VectorType vectorType) {
      const bool isScalable2D =
          vectorType.getRank() == 2 && vectorType.allDimsScalable();
      if (!isScalable2D)
        return false;

      auto elementType = vectorType.getElementType();
      const bool isSupportedFloat = elementType.isF16() ||
                                    elementType.isBF16() ||
                                    elementType.isF32() || elementType.isF64();
      if (!isSupportedFloat)
        return false;

      // The tile must be exactly square at the minimum streaming length.
      unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits /
                            vectorType.getElementTypeBitWidth();
      return vectorType.getShape() ==
             ArrayRef<int64_t>({minNumElts, minNumElts});
    };

    if (outerProductOp.getKind() != arm_sme::CombiningKind::Add)
      return outerProductOp.emitError("unsupported kind");

    auto resultVectorType = outerProductOp.getResultType();
    if (!isSupportedType(resultVectorType))
      return outerProductOp.emitError("unsupported type");

    auto loc = outerProductOp.getLoc();

    // Without an accumulator operand, start from a zeroed tile.
    Value accumulator = outerProductOp.getAcc();
    if (!accumulator) {
      auto zeroedTile = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
      zeroedTile.setTileId(tileId);
      accumulator = zeroedTile;
    }

    // If either mask is absent, both operands use an all-true mask.
    Value lhsMask = outerProductOp.getLhsMask();
    Value rhsMask = outerProductOp.getRhsMask();
    const bool hasBothMasks = lhsMask && rhsMask;
    if (!hasBothMasks) {
      auto maskType =
          outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type());
      Value allActive = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(maskType, true));
      lhsMask = allActive;
      rhsMask = allActive;
    }

    rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask,
                                               outerProductOp.getLhs(),
                                               outerProductOp.getRhs());

    // The intrinsic's result is not used; the op's uses are forwarded to the
    // accumulator.
    rewriter.replaceOp(outerProductOp, accumulator);
    return success();
  }
};
/// Lowers a widening outer-product op (`OuterProductWideningOp`) to the given
/// intrinsic (`OuterProductWideningIntrOp`).
///
/// A missing accumulator is replaced by a new `arm_sme::ZeroOp` on the same
/// tile; if either mask is missing, both are replaced with an all-true mask.
/// The op is replaced with the accumulator value.
template <class OuterProductWideningOp, class OuterProductWideningIntrOp>
struct OuterProductWideningOpConversion
    : public ConvertArmSMEOpToLLVMPattern<OuterProductWideningOp> {
  using ConvertArmSMEOpToLLVMPattern<
      OuterProductWideningOp>::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(OuterProductWideningOp op,
                  typename OuterProductWideningOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto tileId = getTileIdOrError(op);
    if (!tileId)
      return failure();

    auto loc = op.getLoc();

    // Without an accumulator operand, start from a zeroed tile.
    Value accumulator = op.getAcc();
    if (!accumulator) {
      auto zeroedTile =
          rewriter.create<arm_sme::ZeroOp>(loc, op.getResultType());
      zeroedTile.setTileId(tileId);
      accumulator = zeroedTile;
    }

    // If either mask is absent, both operands use an all-true mask.
    Value lhsMask = op.getLhsMask();
    Value rhsMask = op.getRhsMask();
    const bool hasBothMasks = lhsMask && rhsMask;
    if (!hasBothMasks) {
      auto maskType = op.getLhsType().cloneWith({}, rewriter.getI1Type());
      Value allActive = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(maskType, true));
      lhsMask = allActive;
      rhsMask = allActive;
    }

    rewriter.create<OuterProductWideningIntrOp>(
        loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs());

    // The intrinsic's result is not used; the op's uses are forwarded to the
    // accumulator.
    rewriter.replaceOp(op, accumulator);
    return success();
  }
};
/// Lowers `arm_sme::StreamingVLOp` to the `aarch64_sme_cnts{b,h,w,d}`
/// intrinsic selected by the op's type size, then casts the i64 result to
/// `index`. This pattern opts out of spill/fill handling
/// (`RequiresSpillsAndFills::No`).
struct StreamingVLOpConversion
    : public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
                                          RequiresSpillsAndFills::No> {
  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
                  arm_sme::StreamingVLOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = streamingVlOp.getLoc();
    auto i64Type = rewriter.getI64Type();
    auto *intrOp = [&]() -> Operation * {
      switch (streamingVlOp.getTypeSize()) {
      case arm_sme::TypeSize::Byte:
        return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
      case arm_sme::TypeSize::Half:
        return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
      case arm_sme::TypeSize::Word:
        return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
      case arm_sme::TypeSize::Double:
        return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
      }
      // Every enumerator returns above. Without this, an out-of-range value
      // would fall off the end of a lambda that must return a pointer
      // (undefined behaviour), and some compilers warn about it.
      llvm_unreachable("unknown type size in StreamingVLOpConversion");
    }();
    // The intrinsic yields an i64; the op being replaced yields `index`.
    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
        streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0));
    return success();
  }
};
}
namespace {
/// Pass that allocates SME tiles, runs the ArmSME -> LLVM partial conversion,
/// and then checks that no registered op other than `arm_sme::CopyTileOp`,
/// `arm_sme::GetTileOp` or `cf::BranchOp` still has an SME tile type among its
/// operands or results.
struct ConvertArmSMEToLLVMPass
    : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
  ConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
    this->dumpTileLiveRanges = dumpTileLiveRanges;
  }

  void runOnOperation() override {
    auto function = getOperation();

    // Tile allocation must succeed before any lowering is attempted.
    if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges))) {
      signalPassFailure();
      return;
    }

    LLVMConversionTarget target(getContext());
    RewritePatternSet patterns(&getContext());
    LLVMTypeConverter converter(&getContext());
    configureArmSMEToLLVMConversionLegality(target);
    populateArmSMEToLLVMConversionPatterns(converter, patterns);

    const bool conversionFailed =
        failed(applyPartialConversion(function, target, std::move(patterns)));
    if (conversionFailed)
      signalPassFailure();

    // Post-conversion check for leftover ops that touch SME tile types.
    function->walk([&](Operation *op) {
      const bool isExempt =
          isa<arm_sme::CopyTileOp, arm_sme::GetTileOp, cf::BranchOp>(op) ||
          !op->isRegistered();
      if (isExempt)
        return;

      auto isSMETileType = [](Type type) {
        return arm_sme::isValidSMETileVectorType(type);
      };
      const bool touchesSMETile =
          llvm::any_of(op->getResultTypes(), isSMETileType) ||
          llvm::any_of(op->getOperandTypes(), isSMETileType);
      if (!touchesSMETile)
        return;

      op->emitOpError("unexpected operation with SME tile type after "
                      "conversion to LLVM");
      signalPassFailure();
    });
  }
};
}
/// Configures `target` for the ArmSME -> LLVM conversion: the ArmSME dialect
/// is illegal as a whole, while the SME intrinsic ops listed below, the
/// arith/vector/scf/memref dialects, and `GetTileOp` / `CopyTileOp` /
/// `UnrealizedConversionCastOp` are legal.
void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
  target.addIllegalDialect<arm_sme::ArmSMEDialect>();

  // Zero / str intrinsics.
  target.addLegalOp<arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str>();

  // Horizontal tile-slice loads and stores.
  target.addLegalOp<
      arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
      arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
      arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
      arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
      arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz>();

  // Vertical tile-slice loads and stores.
  target.addLegalOp<
      arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
      arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
      arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
      arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
      arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert>();

  // Tile-slice reads and writes.
  target.addLegalOp<
      arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
      arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert>();

  // Outer-product intrinsics.
  target.addLegalOp<
      arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_mopa_wide,
      arm_sme::aarch64_sme_mops_wide, arm_sme::aarch64_sme_smopa_wide,
      arm_sme::aarch64_sme_smops_wide, arm_sme::aarch64_sme_umopa_wide,
      arm_sme::aarch64_sme_umops_wide, arm_sme::aarch64_sme_smopa_za32,
      arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32,
      arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide,
      arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide,
      arm_sme::aarch64_sme_usmops_wide>();

  // Streaming vector length counters.
  target.addLegalOp<arm_sme::aarch64_sme_cntsb, arm_sme::aarch64_sme_cntsh,
                    arm_sme::aarch64_sme_cntsw, arm_sme::aarch64_sme_cntsd>();

  // Dialects whose ops this conversion creates (constants, index casts, the
  // spill/fill loop, the tile alloca and its stores).
  target.addLegalDialect<arith::ArithDialect, vector::VectorDialect,
                         scf::SCFDialect, memref::MemRefDialect>();

  // ArmSME ops left legal even though the dialect is illegal, plus the
  // unrealized cast used when computing in-memory slice pointers.
  target.addLegalOp<arm_sme::GetTileOp, arm_sme::CopyTileOp,
                    UnrealizedConversionCastOp>();
}
/// Registers the ArmSME -> LLVM conversion patterns in `patterns`, and teaches
/// `converter` to keep valid SME tile vector types unchanged.
void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  // Valid SME tile vector types convert to themselves; other vector types are
  // left to the converter's other rules.
  converter.addConversion([&](VectorType type) -> std::optional<Type> {
    if (!arm_sme::isValidSMETileVectorType(type))
      return std::nullopt;
    return type;
  });

  // Tile-slice data movement, streaming VL and the non-widening outer product.
  addArmSMEConversionPatterns<
      LoadTileSliceConversion, MoveTileSliceToVectorConversion,
      MoveVectorToTileSliceConversion, StoreTileSliceConversion,
      StreamingVLOpConversion, OuterProductOpConversion>(patterns, converter);

  // 2-way widening outer products.
  addArmSMEConversionPatterns<
      OuterProductWideningOpConversion<arm_sme::FMopa2WayOp,
                                       arm_sme::aarch64_sme_mopa_wide>,
      OuterProductWideningOpConversion<arm_sme::FMops2WayOp,
                                       arm_sme::aarch64_sme_mops_wide>,
      OuterProductWideningOpConversion<arm_sme::SMopa2WayOp,
                                       arm_sme::aarch64_sme_smopa_za32>,
      OuterProductWideningOpConversion<arm_sme::SMops2WayOp,
                                       arm_sme::aarch64_sme_smops_za32>,
      OuterProductWideningOpConversion<arm_sme::UMopa2WayOp,
                                       arm_sme::aarch64_sme_umopa_za32>,
      OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
                                       arm_sme::aarch64_sme_umops_za32>>(
      patterns, converter);

  // 4-way widening outer products.
  addArmSMEConversionPatterns<
      OuterProductWideningOpConversion<arm_sme::SMopa4WayOp,
                                       arm_sme::aarch64_sme_smopa_wide>,
      OuterProductWideningOpConversion<arm_sme::SMops4WayOp,
                                       arm_sme::aarch64_sme_smops_wide>,
      OuterProductWideningOpConversion<arm_sme::UMopa4WayOp,
                                       arm_sme::aarch64_sme_umopa_wide>,
      OuterProductWideningOpConversion<arm_sme::UMops4WayOp,
                                       arm_sme::aarch64_sme_umops_wide>,
      OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp,
                                       arm_sme::aarch64_sme_sumopa_wide>,
      OuterProductWideningOpConversion<arm_sme::SuMops4WayOp,
                                       arm_sme::aarch64_sme_sumops_wide>,
      OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp,
                                       arm_sme::aarch64_sme_usmopa_wide>,
      OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
                                       arm_sme::aarch64_sme_usmops_wide>>(
      patterns, converter);

  // Tile zeroing.
  addArmSMEConversionPatterns<ZeroOpConversion>(patterns, converter);
}
/// Creates a `ConvertArmSMEToLLVMPass`, forwarding `dumpTileLiveRanges` to its
/// constructor.
std::unique_ptr<Pass>
mlir::createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
  auto pass = std::make_unique<ConvertArmSMEToLLVMPass>(dumpTileLiveRanges);
  return pass;
}