#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir::arm_sve {
#define GEN_PASS_DEF_LEGALIZEVECTORSTORAGE
#include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc"
}
using namespace mlir;
using namespace mlir::arm_sve;
// Name of the unit attribute this file attaches to the
// `UnrealizedConversionCastOp`s it creates. The same name is used to recognise
// those casts again when looking up a legalized memref and when checking, at
// the end of the pass, that no tagged cast is left in the IR.
constexpr StringLiteral kSVELegalizerTag("__arm_sve_legalize_vector_storage__");
namespace {
/// Returns true when `type` is a vector of `i1` with rank > 0 in which only the
/// trailing dimension is scalable and that trailing dimension has a
/// power-of-two size strictly below 16.
bool isSVEMaskType(VectorType type) {
  // Rank-0 vectors have no trailing dimension to inspect.
  if (type.getRank() <= 0 || !type.getElementType().isInteger(1))
    return false;

  auto scalableDims = type.getScalableDims();
  if (!scalableDims.back())
    return false;

  const int64_t trailingSize = type.getShape().back();
  if (trailingSize >= 16 || !llvm::isPowerOf2_32(trailingSize))
    return false;

  // Every leading dimension must be fixed-size.
  return !llvm::is_contained(scalableDims.drop_back(), true);
}
/// Returns a copy of `type` whose trailing dimension size is replaced by 16;
/// every other property of the vector type is kept. `type` must satisfy
/// `isSVEMaskType` (checked by assertion only).
VectorType widenScalableMaskTypeToSvbool(VectorType type) {
  assert(isSVEMaskType(type));
  const int64_t trailingDimIndex = type.getRank() - 1;
  VectorType svboolType = VectorType::Builder(type).setDim(trailingDimIndex, 16);
  return svboolType;
}
/// Replaces `op` with whatever `callback` produces from a freshly inserted
/// clone of `op`.
///
/// The clone is created and inserted through `rewriter` first, so `callback`
/// receives an op that is already part of the IR and may freely mutate it
/// (operands, result types, ...). The value returned by `callback` is then used
/// as the replacement for `op`.
template <typename TOp, typename TLegalizerCallback>
void replaceOpWithLegalizedOp(PatternRewriter &rewriter, TOp op,
                              TLegalizerCallback callback) {
  // Start from a clone of the original op and let the callback adjust it.
  TOp clonedOp = op.clone();
  rewriter.insert(clonedOp);
  rewriter.replaceOp(op, callback(clonedOp));
}
/// Replaces `op` with a tagged `UnrealizedConversionCastOp` that casts the
/// result of the legalized op (as produced by `callback` from a clone of `op`)
/// back to the original result type of `op`.
///
/// The cast carries the `kSVELegalizerTag` unit attribute so that it can be
/// identified later on.
template <typename TOp, typename TLegalizerCallback>
void replaceOpWithUnrealizedConversion(PatternRewriter &rewriter, TOp op,
                                       TLegalizerCallback callback) {
  replaceOpWithLegalizedOp(rewriter, op, [&](TOp clonedOp) {
    // Let the caller legalize the clone, then cast its result back to the
    // type the existing users of `op` expect.
    Value legalizedValue = callback(clonedOp);
    Type originalType = op.getResult().getType();
    NamedAttribute legalizerTag(rewriter.getStringAttr(kSVELegalizerTag),
                                rewriter.getUnitAttr());
    return rewriter.create<UnrealizedConversionCastOp>(
        op.getLoc(), TypeRange{originalType}, ValueRange{legalizedValue},
        legalizerTag);
  });
}
/// Given a value produced by one of the tagged `UnrealizedConversionCastOp`s
/// created in this file, returns the first operand of that cast (i.e. the
/// value that was cast from).
///
/// Returns failure when `illegalMemref` has no defining op, when the defining
/// op does not carry the `kSVELegalizerTag` attribute, or when the tagged op is
/// not an `UnrealizedConversionCastOp` with at least one operand. The last
/// case guards against input IR that happens to carry the tag on some other
/// op: such IR must make the patterns bail out instead of tripping the
/// assertion inside `llvm::cast` or reading a non-existent operand.
static FailureOr<Value> getSVELegalizedMemref(Value illegalMemref) {
  Operation *definingOp = illegalMemref.getDefiningOp();
  if (!definingOp || !definingOp->hasAttr(kSVELegalizerTag))
    return failure();
  // The attribute alone does not prove the op kind, so check it explicitly.
  auto unrealizedConversion =
      llvm::dyn_cast<UnrealizedConversionCastOp>(definingOp);
  if (!unrealizedConversion || unrealizedConversion->getNumOperands() == 0)
    return failure();
  return unrealizedConversion.getOperand(0);
}
/// Sets an explicit alignment on `memref.alloca` ops whose element type is a
/// scalable vector and that do not already specify one: 2 when the vector's
/// element type is `i1`, 16 otherwise.
/// NOTE(review): these values presumably mirror the default alignment of SVE
/// predicate and data vectors -- confirm against the target ABI.
struct RelaxScalableVectorAllocaAlignment
    : public OpRewritePattern<memref::AllocaOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::AllocaOp allocaOp,
                                PatternRewriter &rewriter) const override {
    auto scalableVectorType =
        llvm::dyn_cast<VectorType>(allocaOp.getType().getElementType());
    if (!scalableVectorType || !scalableVectorType.isScalable())
      return failure();
    // An alignment chosen elsewhere is never overridden.
    if (allocaOp.getAlignment())
      return failure();

    unsigned alignment = 16;
    if (scalableVectorType.getElementType().isInteger(1))
      alignment = 2;

    rewriter.modifyOpInPlace(allocaOp,
                             [&] { allocaOp.setAlignment(alignment); });
    return success();
  }
};
/// Rewrites an alloc-like op whose memref element type is an SVE mask type (see
/// `isSVEMaskType`) into the same op allocating the widened mask type (see
/// `widenScalableMaskTypeToSvbool`). A tagged `UnrealizedConversionCastOp` back
/// to the original memref type is left in place for the existing users.
template <typename AllocLikeOp>
struct LegalizeSVEMaskAllocation : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp allocLikeOp,
                                PatternRewriter &rewriter) const override {
    auto maskType =
        llvm::dyn_cast<VectorType>(allocLikeOp.getType().getElementType());
    if (!maskType)
      return failure();
    if (!isSVEMaskType(maskType))
      return failure();

    VectorType svboolType = widenScalableMaskTypeToSvbool(maskType);
    replaceOpWithUnrealizedConversion(
        rewriter, allocLikeOp, [&](AllocLikeOp legalAlloc) {
          // Same memref, but with the widened vector as element type.
          auto legalMemrefType = llvm::cast<MemRefType>(
              legalAlloc.getType().cloneWith({}, svboolType));
          legalAlloc.getResult().setType(legalMemrefType);
          return legalAlloc;
        });
    return success();
  }
};
/// Rewrites a `vector.type_cast` producing a memref of an SVE mask type so that
/// it operates on the already legalized source memref and produces a memref of
/// the widened mask type instead. Only applies when the source memref comes
/// from a tagged cast created in this file; a new tagged cast back to the
/// original result type is left in place for the existing users.
struct LegalizeSVEMaskTypeCastConversion
    : public OpRewritePattern<vector::TypeCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TypeCastOp typeCastOp,
                                PatternRewriter &rewriter) const override {
    auto maskType = llvm::dyn_cast<VectorType>(
        typeCastOp.getResultMemRefType().getElementType());
    if (!maskType || !isSVEMaskType(maskType))
      return failure();

    // The source must already have been legalized by this file.
    FailureOr<Value> legalSource =
        getSVELegalizedMemref(typeCastOp.getMemref());
    if (failed(legalSource))
      return failure();

    VectorType svboolType = widenScalableMaskTypeToSvbool(maskType);
    replaceOpWithUnrealizedConversion(
        rewriter, typeCastOp, [&](vector::TypeCastOp legalTypeCast) {
          legalTypeCast.setOperand(*legalSource);
          auto legalResultType = llvm::cast<MemRefType>(
              legalTypeCast.getType().cloneWith({}, svboolType));
          legalTypeCast.getResult().setType(legalResultType);
          return legalTypeCast;
        });
    return success();
  }
};
/// Rewrites a `memref.store` of an SVE mask into a store of the widened mask
/// into the legalized memref: the stored value is first converted with
/// `arm_sve::ConvertToSvboolOp`. Only applies when the destination memref comes
/// from a tagged cast created in this file.
struct LegalizeSVEMaskStoreConversion
    : public OpRewritePattern<memref::StoreOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    Value mask = storeOp.getValueToStore();
    auto maskType = llvm::dyn_cast<VectorType>(mask.getType());
    if (!maskType || !isSVEMaskType(maskType))
      return failure();

    FailureOr<Value> legalDestination =
        getSVELegalizedMemref(storeOp.getMemref());
    if (failed(legalDestination))
      return failure();

    // Widen the mask before the (cloned) store is created.
    VectorType svboolType = widenScalableMaskTypeToSvbool(maskType);
    auto svboolMask = rewriter.create<arm_sve::ConvertToSvboolOp>(
        storeOp.getLoc(), svboolType, mask);

    replaceOpWithLegalizedOp(rewriter, storeOp,
                             [&](memref::StoreOp legalStore) {
                               // Operand 0 is the stored value, operand 1 the
                               // destination memref.
                               legalStore.setOperand(0, svboolMask);
                               legalStore.setOperand(1, *legalDestination);
                               return legalStore;
                             });
    return success();
  }
};
/// Rewrites a `memref.load` of an SVE mask into a load of the widened mask from
/// the legalized memref, followed by an `arm_sve::ConvertFromSvboolOp` back to
/// the original mask type. Only applies when the source memref comes from a
/// tagged cast created in this file.
struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    Type originalType = loadOp.getResult().getType();
    auto maskType = llvm::dyn_cast<VectorType>(originalType);
    if (!maskType || !isSVEMaskType(maskType))
      return failure();

    FailureOr<Value> legalSource = getSVELegalizedMemref(loadOp.getMemref());
    if (failed(legalSource))
      return failure();

    VectorType svboolType = widenScalableMaskTypeToSvbool(maskType);
    Location loc = loadOp.getLoc();
    replaceOpWithLegalizedOp(rewriter, loadOp, [&](memref::LoadOp legalLoad) {
      // Load the widened mask from the legalized memref ...
      legalLoad.setMemRef(*legalSource);
      legalLoad.getResult().setType(svboolType);
      // ... and narrow it back to the type the original users expect.
      return rewriter.create<arm_sve::ConvertFromSvboolOp>(loc, originalType,
                                                           legalLoad);
    });
    return success();
  }
};
}
/// Adds all vector-storage legalization patterns defined in this file to
/// `patterns`, using the context `patterns` was created with.
void mlir::arm_sve::populateLegalizeVectorStoragePatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  // Alignment fix-up and allocation legalization.
  patterns.add<RelaxScalableVectorAllocaAlignment,
               LegalizeSVEMaskAllocation<memref::AllocaOp>,
               LegalizeSVEMaskAllocation<memref::AllocOp>>(context);
  // Legalization of the users of the legalized allocations.
  patterns.add<LegalizeSVEMaskTypeCastConversion,
               LegalizeSVEMaskStoreConversion, LegalizeSVEMaskLoadConversion>(
      context);
}
namespace {
/// Pass that applies the vector-storage legalization patterns greedily and then
/// verifies that no tagged `UnrealizedConversionCastOp` remains.
struct LegalizeVectorStorage
    : public arm_sve::impl::LegalizeVectorStorageBase<LegalizeVectorStorage> {
  void runOnOperation() override {
    // Phase 1: run the legalization patterns.
    RewritePatternSet legalizationPatterns(&getContext());
    populateLegalizeVectorStoragePatterns(legalizationPatterns);
    LogicalResult greedyResult = applyPatternsAndFoldGreedily(
        getOperation(), std::move(legalizationPatterns));
    if (failed(greedyResult))
      signalPassFailure();

    // Phase 2: a cast still carrying the legalizer tag is illegal; the
    // pattern-less partial conversion below fails if any such cast is left.
    ConversionTarget leftoverCheck(getContext());
    leftoverCheck.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
        [](UnrealizedConversionCastOp castOp) {
          const bool isTagged = castOp->hasAttr(kSVELegalizerTag);
          return !isTagged;
        });
    LogicalResult checkResult =
        applyPartialConversion(getOperation(), leftoverCheck, {});
    if (failed(checkResult))
      signalPassFailure();
  }
};
}
/// Creates a new instance of the `LegalizeVectorStorage` pass.
std::unique_ptr<Pass> mlir::arm_sve::createLegalizeVectorStoragePass() {
  std::unique_ptr<Pass> pass = std::make_unique<LegalizeVectorStorage>();
  return pass;
}