#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include <cstdint>
namespace mlir::arith {
// Pull in the tablegen-generated pass definition. This is expected to provide
// `impl::ArithIntNarrowingBase`, which `ArithIntNarrowingPass` below derives
// from.
#define GEN_PASS_DEF_ARITHINTNARROWING
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
} // namespace mlir::arith
namespace mlir::arith {
namespace {
/// Common base for the integer-narrowing rewrite patterns in this file.
///
/// Stores the integer bitwidths allowed by `ArithIntNarrowingOptions`, sorted
/// in ascending order. Provides helpers that pick the narrowest allowed
/// integer type for a given bit requirement.
template <typename SourceOp>
struct NarrowingPattern : OpRewritePattern<SourceOp> {
  NarrowingPattern(MLIRContext *ctx, const ArithIntNarrowingOptions &options,
                   PatternBenefit benefit = 1)
      : OpRewritePattern<SourceOp>(ctx, benefit),
        supportedBitwidths(options.bitwidthsSupported.begin(),
                           options.bitwidthsSupported.end()) {
    assert(!supportedBitwidths.empty() && "Invalid options");
    assert(!llvm::is_contained(supportedBitwidths, 0) && "Invalid bitwidth");
    llvm::sort(supportedBitwidths);
  }

  /// Returns the smallest supported bitwidth that is >= `bitsRequired`.
  /// Fails when every supported bitwidth is too small.
  FailureOr<unsigned>
  getNarrowestCompatibleBitwidth(unsigned bitsRequired) const {
    // `supportedBitwidths` is sorted ascending, so the first hit is narrowest.
    for (unsigned candidate : supportedBitwidths)
      if (candidate >= bitsRequired)
        return candidate;
    return failure();
  }

  /// Returns a type shaped like `origTy` (a scalar integer, or a shaped type
  /// of integers) whose integer element type is the narrowest supported width
  /// that holds `bitsRequired` bits.
  ///
  /// Fails when:
  ///  - no supported width fits,
  ///  - `origTy` has no integer element type, or
  ///  - the selected width is not strictly narrower than the original element
  ///    width.
  /// The last condition matters because callers feed the result to
  /// `arith.trunci` and later re-extend to the original type. Both operations
  /// require a strictly narrower type.
  FailureOr<Type> getNarrowType(unsigned bitsRequired, Type origTy) const {
    assert(origTy);
    FailureOr<unsigned> bestBitwidth =
        getNarrowestCompatibleBitwidth(bitsRequired);
    if (failed(bestBitwidth))
      return failure();

    auto elemTy = dyn_cast<IntegerType>(getElementTypeOrSelf(origTy));
    if (!elemTy)
      return failure();

    // Only ever narrow. Previously only an *equal* width was rejected, so e.g.
    // supported bitwidths {32} with an `i16` operand produced `i32`. That led
    // to invalid `trunci i16 -> i32` / `ext* i32 -> i16` ops in callers that
    // do not re-check the width themselves.
    if (*bestBitwidth >= elemTy.getWidth())
      return failure();

    auto newElemTy = IntegerType::get(origTy.getContext(), *bestBitwidth);
    if (origTy == elemTy)
      return newElemTy;

    // The element type was already verified to be an integer above.
    if (auto shapedTy = dyn_cast<ShapedType>(origTy))
      return shapedTy.clone(shapedTy.getShape(), newElemTy);

    return failure();
  }

private:
  // Supported integer bitwidths, sorted in ascending order.
  llvm::SmallVector<unsigned, 6> supportedBitwidths;
};
/// Returns the width of the integer type (or integer element type of a shaped
/// type) `type`. Fails for any non-integer (element) type.
FailureOr<unsigned> calculateBitsRequired(Type type) {
  assert(type);
  auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(type));
  if (!intTy)
    return failure();
  return intTy.getWidth();
}
enum class ExtensionKind { Sign, Zero };
class ExtensionOp {
public:
static FailureOr<ExtensionOp> from(Operation *op) {
if (dyn_cast_or_null<arith::ExtSIOp>(op))
return ExtensionOp{op, ExtensionKind::Sign};
if (dyn_cast_or_null<arith::ExtUIOp>(op))
return ExtensionOp{op, ExtensionKind::Zero};
return failure();
}
ExtensionOp(const ExtensionOp &) = default;
ExtensionOp &operator=(const ExtensionOp &) = default;
Operation *recreate(PatternRewriter &rewriter, Location loc, Type newType,
Value in) {
if (kind == ExtensionKind::Sign)
return rewriter.create<arith::ExtSIOp>(loc, newType, in);
return rewriter.create<arith::ExtUIOp>(loc, newType, in);
}
void recreateAndReplace(PatternRewriter &rewriter, Operation *toReplace,
Value in) {
assert(toReplace->getNumResults() == 1);
Type newType = toReplace->getResult(0).getType();
Operation *newOp = recreate(rewriter, toReplace->getLoc(), newType, in);
rewriter.replaceOp(toReplace, newOp->getResult(0));
}
ExtensionKind getKind() { return kind; }
Value getResult() { return op->getResult(0); }
Value getIn() { return op->getOperand(0); }
Type getType() { return getResult().getType(); }
Type getElementType() { return getElementTypeOrSelf(getType()); }
Type getInType() { return getIn().getType(); }
Type getInElementType() { return getElementTypeOrSelf(getInType()); }
private:
ExtensionOp(Operation *op, ExtensionKind kind) : op(op), kind(kind) {
assert(op);
assert((isa<arith::ExtSIOp, arith::ExtUIOp>(op)) && "Not an extension op");
}
Operation *op = nullptr;
ExtensionKind kind = {};
};
unsigned calculateBitsRequired(const APInt &value,
ExtensionKind lookThroughExtension) {
if (lookThroughExtension == ExtensionKind::Zero)
return std::max(value.getActiveBits(), 1u);
if (value.isNonNegative())
return value.getActiveBits() + 1;
if (value.isMinSignedValue())
return value.getBitWidth();
return value.getBitWidth() - value.getNumSignBits() + 1;
}
/// Returns the number of bits needed to represent `value`.
///
/// - Integer constants (scalar or dense) are measured element by element.
/// - A value produced by an extension of kind `lookThroughExtension` is
///   measured by the width of the extension's input.
/// - Otherwise the width of the value's integer (element) type is used.
FailureOr<unsigned> calculateBitsRequired(Value value,
                                          ExtensionKind lookThroughExtension) {
  TypedAttr constAttr;
  if (matchPattern(value, m_Constant(&constAttr))) {
    if (auto scalar = dyn_cast<IntegerAttr>(constAttr))
      return calculateBitsRequired(scalar.getValue(), lookThroughExtension);

    auto dense = dyn_cast<DenseElementsAttr>(constAttr);
    if (dense && dense.getElementType().isIntOrIndex()) {
      if (dense.isSplat())
        return calculateBitsRequired(dense.getSplatValue<APInt>(),
                                     lookThroughExtension);

      // Non-splat: the widest element determines the requirement.
      unsigned widest = 1;
      for (const APInt &element : dense.getValues<APInt>())
        widest = std::max(
            widest, calculateBitsRequired(element, lookThroughExtension));
      return widest;
    }
  }

  // Look through a matching extension op to its narrower input.
  switch (lookThroughExtension) {
  case ExtensionKind::Sign:
    if (auto sext = value.getDefiningOp<arith::ExtSIOp>())
      return calculateBitsRequired(sext.getIn().getType());
    break;
  case ExtensionKind::Zero:
    if (auto zext = value.getDefiningOp<arith::ExtUIOp>())
      return calculateBitsRequired(zext.getIn().getType());
    break;
  }

  // Fallback: the full width of the value's own integer (element) type.
  return calculateBitsRequired(value.getType());
}
/// Shared driver for narrowing an integer binary op whose LHS is an extension.
///
/// The RHS must be either an extension of the same kind or a constant. When
/// both operands and the op's exact result fit in a supported bitwidth that
/// is narrower than the result type, the op is rebuilt on truncated operands
/// and re-extended. Roughly:
///   %r = arith.addi (arith.extsi %a : i8 to i32), %rhs : i32
/// becomes
///   %n = arith.addi (trunci ... to i16), (trunci %rhs to i16) : i16
///   %r = arith.extsi %n : i16 to i32
/// The truncations are created with `createOrFold` and may fold away.
template <typename BinaryOp>
struct BinaryOpNarrowingPattern : NarrowingPattern<BinaryOp> {
  using NarrowingPattern<BinaryOp>::NarrowingPattern;

  /// Number of bits needed for the exact result of `BinaryOp` when both
  /// operands need at most `operandBits` bits. Each derived pattern implements
  /// this according to its op's semantics.
  virtual unsigned getResultBitsProduced(unsigned operandBits) const = 0;

  /// Hook for restricting which extension kind (sign or zero) is accepted on
  /// the LHS. Both kinds are accepted by default.
  virtual bool isSupported(ExtensionOp) const { return true; }

  LogicalResult matchAndRewrite(BinaryOp op,
                                PatternRewriter &rewriter) const final {
    Type wideTy = op.getType();
    FailureOr<unsigned> wideBits = calculateBitsRequired(wideTy);
    if (failed(wideBits))
      return failure();

    // True iff `bits` is known and strictly below the result width.
    auto isNarrower = [&](FailureOr<unsigned> bits) {
      return succeeded(bits) && *bits < *wideBits;
    };

    // The LHS must come from a supported extension; its kind also governs how
    // the RHS is analyzed.
    FailureOr<ExtensionOp> lhsExt =
        ExtensionOp::from(op.getLhs().getDefiningOp());
    if (failed(lhsExt) || !isSupported(*lhsExt))
      return failure();
    ExtensionKind kind = lhsExt->getKind();

    FailureOr<unsigned> lhsBits = calculateBitsRequired(lhsExt->getIn(), kind);
    if (!isNarrower(lhsBits))
      return failure();

    FailureOr<unsigned> rhsBits = calculateBitsRequired(op.getRhs(), kind);
    if (!isNarrower(rhsBits))
      return failure();

    // Both operands share one narrow type that is wide enough for the exact
    // result of the op.
    unsigned neededBits = getResultBitsProduced(std::max(*lhsBits, *rhsBits));
    FailureOr<Type> narrowTy = this->getNarrowType(neededBits, wideTy);
    if (failed(narrowTy) || calculateBitsRequired(*narrowTy) >= *wideBits)
      return failure();

    Location loc = op.getLoc();
    Value narrowLhs =
        rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getLhs());
    Value narrowRhs =
        rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getRhs());
    Value narrowResult = rewriter.create<BinaryOp>(loc, narrowLhs, narrowRhs);
    lhsExt->recreateAndReplace(rewriter, op, narrowResult);
    return success();
  }
};
/// Narrows `arith.addi`. A sum may need one bit more than its operands,
/// e.g. `255 + 1 == 256` for unsigned 8-bit inputs.
struct AddIPattern final : BinaryOpNarrowingPattern<arith::AddIOp> {
  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;

  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return 1 + operandBits;
  }
};
/// Narrows `arith.subi` with sign-extended operands. A difference may need one
/// extra bit, e.g. `127 - (-1) == 128` for signed 8-bit inputs.
struct SubIPattern final : BinaryOpNarrowingPattern<arith::SubIOp> {
  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;

  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return 1 + operandBits;
  }

  // Only sign extensions are handled for subtraction.
  bool isSupported(ExtensionOp ext) const override {
    return ExtensionKind::Sign == ext.getKind();
  }
};
/// Narrows `arith.muli`. A product may need twice the operand bits,
/// e.g. `255 * 255 == 65025` for unsigned 8-bit inputs.
struct MulIPattern final : BinaryOpNarrowingPattern<arith::MulIOp> {
  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;

  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return operandBits * 2;
  }
};
/// Narrows `arith.divsi` with sign-extended operands. Signed division needs at
/// most one extra bit, e.g. `-128 / -1 == 128` for signed 8-bit inputs.
struct DivSIPattern final : BinaryOpNarrowingPattern<arith::DivSIOp> {
  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;

  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return 1 + operandBits;
  }

  // Only sign extensions are handled for signed division.
  bool isSupported(ExtensionOp ext) const override {
    return ExtensionKind::Sign == ext.getKind();
  }
};
/// Narrows `arith.divui` with zero-extended operands. An unsigned quotient
/// never needs more bits than its operands.
struct DivUIPattern final : BinaryOpNarrowingPattern<arith::DivUIOp> {
  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;

  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return operandBits;
  }

  // Only zero extensions are handled for unsigned division.
  bool isSupported(ExtensionOp ext) const override {
    return ExtensionKind::Zero == ext.getKind();
  }
};
/// Narrows integer min/max ops. The result is always one of the operands, so
/// no extra bits are needed. `Kind` is the extension flavor the LHS must use.
template <typename MinMaxOp, ExtensionKind Kind>
struct MinMaxPattern final : BinaryOpNarrowingPattern<MinMaxOp> {
  using BinaryOpNarrowingPattern<MinMaxOp>::BinaryOpNarrowingPattern;

  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return operandBits;
  }

  bool isSupported(ExtensionOp ext) const override {
    return Kind == ext.getKind();
  }
};
using MaxSIPattern = MinMaxPattern<arith::MaxSIOp, ExtensionKind::Sign>;
using MaxUIPattern = MinMaxPattern<arith::MaxUIOp, ExtensionKind::Zero>;
using MinSIPattern = MinMaxPattern<arith::MinSIOp, ExtensionKind::Sign>;
using MinUIPattern = MinMaxPattern<arith::MinUIOp, ExtensionKind::Zero>;
/// Narrows the integer input of an int-to-float conversion. The input is
/// truncated to the narrowest supported width that still holds it, where the
/// input is interpreted according to `Extension`. The conversion is then
/// rebuilt on the truncated value.
template <typename IToFPOp, ExtensionKind Extension>
struct IToFPPattern final : NarrowingPattern<IToFPOp> {
  using NarrowingPattern<IToFPOp>::NarrowingPattern;

  LogicalResult matchAndRewrite(IToFPOp op,
                                PatternRewriter &rewriter) const override {
    Value src = op.getIn();
    FailureOr<unsigned> srcBits = calculateBitsRequired(src, Extension);
    if (failed(srcBits))
      return failure();

    FailureOr<Type> narrowSrcTy = this->getNarrowType(*srcBits, src.getType());
    if (failed(narrowSrcTy))
      return failure();

    Value narrowSrc =
        rewriter.createOrFold<arith::TruncIOp>(op.getLoc(), *narrowSrcTy, src);
    rewriter.replaceOpWithNewOp<IToFPOp>(op, op.getType(), narrowSrc);
    return success();
  }
};
using SIToFPPattern = IToFPPattern<arith::SIToFPOp, ExtensionKind::Sign>;
using UIToFPPattern = IToFPPattern<arith::UIToFPOp, ExtensionKind::Zero>;
/// Narrows scalar `index -> integer` casts using constant bounds of the index
/// value obtained from `ValueBoundsConstraintSet`.
///
/// If both bounds fit in a supported width narrower than the result type, the
/// op becomes a cast to that narrow type followed by an extension of flavor
/// `Kind`. The lower bound is checked for unsigned casts too.
template <typename CastOp, ExtensionKind Kind>
struct IndexCastPattern final : NarrowingPattern<CastOp> {
  using NarrowingPattern<CastOp>::NarrowingPattern;

  LogicalResult matchAndRewrite(CastOp op,
                                PatternRewriter &rewriter) const override {
    Value index = op.getIn();
    // Only scalar index inputs are handled.
    if (!isa<IndexType>(index.getType()))
      return failure();

    FailureOr<int64_t> lowerBound =
        ValueBoundsConstraintSet::computeConstantBound(
            presburger::BoundType::LB, index);
    if (failed(lowerBound))
      return failure();

    FailureOr<int64_t> upperBound =
        ValueBoundsConstraintSet::computeConstantBound(
            presburger::BoundType::UB, index,
            /*stopCondition=*/nullptr, /*closedUB=*/true);
    if (failed(upperBound))
      return failure();

    assert(*lowerBound <= *upperBound && "Invalid bounds");
    unsigned bitsRequired =
        std::max(calculateBitsRequired(APInt(64, *lowerBound), Kind),
                 calculateBitsRequired(APInt(64, *upperBound), Kind));

    IntegerType resultTy = cast<IntegerType>(op.getType());
    if (bitsRequired >= resultTy.getWidth())
      return failure();

    FailureOr<Type> narrowTy = this->getNarrowType(bitsRequired, resultTy);
    if (failed(narrowTy))
      return failure();

    Value narrowCast = rewriter.create<CastOp>(op.getLoc(), *narrowTy, index);
    if constexpr (Kind == ExtensionKind::Sign)
      rewriter.replaceOpWithNewOp<arith::ExtSIOp>(op, resultTy, narrowCast);
    else
      rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, resultTy, narrowCast);
    return success();
  }
};
using IndexCastSIPattern =
    IndexCastPattern<arith::IndexCastOp, ExtensionKind::Sign>;
using IndexCastUIPattern =
    IndexCastPattern<arith::IndexCastUIOp, ExtensionKind::Zero>;
/// Commutes an extension past `vector.broadcast`:
///   broadcast(ext(x)) -> ext(broadcast(x))
/// so that the broadcast operates on the narrow element type.
struct ExtensionOverBroadcast final : NarrowingPattern<vector::BroadcastOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    Operation *producer = op.getSource().getDefiningOp();
    FailureOr<ExtensionOp> ext = ExtensionOp::from(producer);
    if (failed(ext))
      return failure();

    VectorType wideTy = op.getResultVectorType();
    Type narrowElemTy = ext->getInElementType();
    VectorType narrowTy = wideTy.cloneWith(wideTy.getShape(), narrowElemTy);

    Location loc = op.getLoc();
    Value narrowBroadcast =
        rewriter.create<vector::BroadcastOp>(loc, narrowTy, ext->getIn());
    ext->recreateAndReplace(rewriter, op, narrowBroadcast);
    return success();
  }
};
/// Commutes an extension past `vector.extract`:
///   extract(ext(v), pos) -> ext(extract(v, pos))
struct ExtensionOverExtract final : NarrowingPattern<vector::ExtractOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    Operation *producer = op.getVector().getDefiningOp();
    FailureOr<ExtensionOp> ext = ExtensionOp::from(producer);
    if (failed(ext))
      return failure();

    Location loc = op.getLoc();
    Value narrowExtract = rewriter.create<vector::ExtractOp>(
        loc, ext->getIn(), op.getMixedPosition());
    ext->recreateAndReplace(rewriter, op, narrowExtract);
    return success();
  }
};
/// Commutes an extension past `vector.extractelement`:
///   extractelement(ext(v), i) -> ext(extractelement(v, i))
struct ExtensionOverExtractElement final
    : NarrowingPattern<vector::ExtractElementOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::ExtractElementOp op,
                                PatternRewriter &rewriter) const override {
    Operation *producer = op.getVector().getDefiningOp();
    FailureOr<ExtensionOp> ext = ExtensionOp::from(producer);
    if (failed(ext))
      return failure();

    Location loc = op.getLoc();
    Value narrowExtract = rewriter.create<vector::ExtractElementOp>(
        loc, ext->getIn(), op.getPosition());
    ext->recreateAndReplace(rewriter, op, narrowExtract);
    return success();
  }
};
/// Commutes an extension past `vector.extract_strided_slice`. The slice keeps
/// its offsets/sizes/strides but is taken from the narrow, pre-extension
/// vector.
struct ExtensionOverExtractStridedSlice final
    : NarrowingPattern<vector::ExtractStridedSliceOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    Operation *producer = op.getVector().getDefiningOp();
    FailureOr<ExtensionOp> ext = ExtensionOp::from(producer);
    if (failed(ext))
      return failure();

    VectorType wideTy = op.getType();
    Type narrowElemTy = ext->getInElementType();
    VectorType narrowTy = wideTy.cloneWith(wideTy.getShape(), narrowElemTy);

    Location loc = op.getLoc();
    Value narrowSlice = rewriter.create<vector::ExtractStridedSliceOp>(
        loc, narrowTy, ext->getIn(), op.getOffsets(), op.getSizes(),
        op.getStrides());
    ext->recreateAndReplace(rewriter, op, narrowSlice);
    return success();
  }
};
/// Base for commuting an extension past a vector insertion op:
///   insert(ext(x), dest) -> ext(insert(x', dest'))
/// Here `x'` and `dest'` are truncated to a common narrow element type.
/// This only applies when both the inserted value and the destination need
/// fewer bits than the result element type.
template <typename InsertionOp>
struct ExtensionOverInsertionPattern : NarrowingPattern<InsertionOp> {
  using NarrowingPattern<InsertionOp>::NarrowingPattern;

  /// Builds the concrete insertion op mirroring `origInsert`, but operating on
  /// `narrowValue` and `narrowDest`. Implemented by each derived pattern.
  virtual InsertionOp createInsertionOp(PatternRewriter &rewriter,
                                        InsertionOp origInsert,
                                        Value narrowValue,
                                        Value narrowDest) const = 0;

  LogicalResult matchAndRewrite(InsertionOp op,
                                PatternRewriter &rewriter) const final {
    Operation *producer = op.getSource().getDefiningOp();
    FailureOr<ExtensionOp> ext = ExtensionOp::from(producer);
    if (failed(ext))
      return failure();

    FailureOr<InsertionOp> narrowInsert = createNarrowInsert(op, rewriter, *ext);
    if (failed(narrowInsert))
      return failure();

    ext->recreateAndReplace(rewriter, op, *narrowInsert);
    return success();
  }

  /// Creates the narrow insertion op, or fails when no common narrow element
  /// type exists for the destination and the inserted value.
  FailureOr<InsertionOp> createNarrowInsert(InsertionOp op,
                                            PatternRewriter &rewriter,
                                            ExtensionOp insValue) const {
    FailureOr<unsigned> resultBits = calculateBitsRequired(op.getType());
    if (failed(resultBits))
      return failure();

    ExtensionKind kind = insValue.getKind();
    // True iff `bits` is known and strictly below the result element width.
    auto narrowerThanResult = [&](FailureOr<unsigned> bits) {
      return succeeded(bits) && *bits < *resultBits;
    };

    // Note: this conservatively considers every element of the destination,
    // including the ones the insertion overwrites.
    FailureOr<unsigned> destBits = calculateBitsRequired(op.getDest(), kind);
    if (!narrowerThanResult(destBits))
      return failure();

    FailureOr<unsigned> valueBits =
        calculateBitsRequired(insValue.getIn(), kind);
    if (!narrowerThanResult(valueBits))
      return failure();

    // One element width must satisfy both the destination and the value.
    unsigned commonBits = std::max(*destBits, *valueBits);
    FailureOr<Type> narrowDestTy =
        this->getNarrowType(commonBits, op.getType());
    if (failed(narrowDestTy) || *narrowDestTy == op.getType())
      return failure();

    FailureOr<Type> narrowValueTy =
        this->getNarrowType(commonBits, insValue.getType());
    if (failed(narrowValueTy))
      return failure();

    Location loc = op.getLoc();
    Value narrowValue = rewriter.createOrFold<arith::TruncIOp>(
        loc, *narrowValueTy, insValue.getResult());
    Value narrowDest =
        rewriter.createOrFold<arith::TruncIOp>(loc, *narrowDestTy, op.getDest());
    return createInsertionOp(rewriter, op, narrowValue, narrowDest);
  }
};
/// Extension commuting for `vector.insert`. The insertion position is taken
/// from the original op.
struct ExtensionOverInsert final
    : ExtensionOverInsertionPattern<vector::InsertOp> {
  using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;

  vector::InsertOp createInsertionOp(PatternRewriter &rewriter,
                                     vector::InsertOp origInsert,
                                     Value narrowValue,
                                     Value narrowDest) const override {
    Location loc = origInsert.getLoc();
    return rewriter.create<vector::InsertOp>(loc, narrowValue, narrowDest,
                                             origInsert.getMixedPosition());
  }
};
/// Extension commuting for `vector.insertelement`. The insertion position is
/// taken from the original op.
struct ExtensionOverInsertElement final
    : ExtensionOverInsertionPattern<vector::InsertElementOp> {
  using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;

  vector::InsertElementOp createInsertionOp(PatternRewriter &rewriter,
                                            vector::InsertElementOp origInsert,
                                            Value narrowValue,
                                            Value narrowDest) const override {
    Location loc = origInsert.getLoc();
    return rewriter.create<vector::InsertElementOp>(
        loc, narrowValue, narrowDest, origInsert.getPosition());
  }
};
/// Extension commuting for `vector.insert_strided_slice`. Offsets and strides
/// are taken from the original op.
struct ExtensionOverInsertStridedSlice final
    : ExtensionOverInsertionPattern<vector::InsertStridedSliceOp> {
  using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;

  vector::InsertStridedSliceOp
  createInsertionOp(PatternRewriter &rewriter,
                    vector::InsertStridedSliceOp origInsert, Value narrowValue,
                    Value narrowDest) const override {
    Location loc = origInsert.getLoc();
    return rewriter.create<vector::InsertStridedSliceOp>(
        loc, narrowValue, narrowDest, origInsert.getOffsets(),
        origInsert.getStrides());
  }
};
/// Commutes an extension past `vector.shape_cast`:
///   shape_cast(ext(v)) -> ext(shape_cast(v))
struct ExtensionOverShapeCast final : NarrowingPattern<vector::ShapeCastOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Operation *producer = op.getSource().getDefiningOp();
    FailureOr<ExtensionOp> ext = ExtensionOp::from(producer);
    if (failed(ext))
      return failure();

    VectorType wideTy = op.getResultVectorType();
    Type narrowElemTy = ext->getInElementType();
    VectorType narrowTy = wideTy.cloneWith(wideTy.getShape(), narrowElemTy);

    Location loc = op.getLoc();
    Value narrowCast =
        rewriter.create<vector::ShapeCastOp>(loc, narrowTy, ext->getIn());
    ext->recreateAndReplace(rewriter, op, narrowCast);
    return success();
  }
};
/// Commutes an extension past `vector.transpose`, keeping the permutation:
///   transpose(ext(v)) -> ext(transpose(v))
struct ExtensionOverTranspose final : NarrowingPattern<vector::TransposeOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    Operation *producer = op.getVector().getDefiningOp();
    FailureOr<ExtensionOp> ext = ExtensionOp::from(producer);
    if (failed(ext))
      return failure();

    VectorType wideTy = op.getResultVectorType();
    Type narrowElemTy = ext->getInElementType();
    VectorType narrowTy = wideTy.cloneWith(wideTy.getShape(), narrowElemTy);

    Location loc = op.getLoc();
    Value narrowTranspose = rewriter.create<vector::TransposeOp>(
        loc, narrowTy, ext->getIn(), op.getPermutation());
    ext->recreateAndReplace(rewriter, op, narrowTranspose);
    return success();
  }
};
/// Commutes an extension past `vector.flat_transpose`, keeping the row and
/// column attributes:
///   flat_transpose(ext(m)) -> ext(flat_transpose(m))
struct ExtensionOverFlatTranspose final
    : NarrowingPattern<vector::FlatTransposeOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::FlatTransposeOp op,
                                PatternRewriter &rewriter) const override {
    Operation *producer = op.getMatrix().getDefiningOp();
    FailureOr<ExtensionOp> ext = ExtensionOp::from(producer);
    if (failed(ext))
      return failure();

    VectorType wideTy = op.getType();
    Type narrowElemTy = ext->getInElementType();
    VectorType narrowTy = wideTy.cloneWith(wideTy.getShape(), narrowElemTy);

    Location loc = op.getLoc();
    Value narrowTranspose = rewriter.create<vector::FlatTransposeOp>(
        loc, narrowTy, ext->getIn(), op.getRowsAttr(), op.getColumnsAttr());
    ext->recreateAndReplace(rewriter, op, narrowTranspose);
    return success();
  }
};
/// Pass driver: validates the `bitwidthsSupported` option, then greedily
/// applies every pattern from `populateArithIntNarrowingPatterns` to the
/// operation the pass runs on.
struct ArithIntNarrowingPass final
    : impl::ArithIntNarrowingBase<ArithIntNarrowingPass> {
  using ArithIntNarrowingBase::ArithIntNarrowingBase;

  void runOnOperation() override {
    // At least one bitwidth is required, and a zero bitwidth is invalid.
    bool validOptions = !bitwidthsSupported.empty() &&
                        !llvm::is_contained(bitwidthsSupported, 0);
    if (!validOptions) {
      signalPassFailure();
      return;
    }

    Operation *root = getOperation();
    RewritePatternSet patterns(root->getContext());
    populateArithIntNarrowingPatterns(
        patterns, ArithIntNarrowingOptions{bitwidthsSupported});
    if (failed(applyPatternsAndFoldGreedily(root, std::move(patterns))))
      signalPassFailure();
  }
};
}
/// Adds the extension-commuting vector patterns and the arith narrowing
/// patterns to `patterns`, all configured with `options`.
void populateArithIntNarrowingPatterns(
    RewritePatternSet &patterns, const ArithIntNarrowingOptions &options) {
  MLIRContext *ctx = patterns.getContext();

  // Patterns that move extensions past vector ops. They are registered with
  // benefit 2, above the default benefit of 1 used for the narrowing patterns
  // below.
  patterns.add<ExtensionOverBroadcast, ExtensionOverExtract,
               ExtensionOverExtractElement, ExtensionOverExtractStridedSlice,
               ExtensionOverInsert, ExtensionOverInsertElement,
               ExtensionOverInsertStridedSlice, ExtensionOverShapeCast,
               ExtensionOverTranspose, ExtensionOverFlatTranspose>(
      ctx, options, PatternBenefit(2));

  // Patterns that narrow arith ops themselves.
  patterns.add<AddIPattern, SubIPattern, MulIPattern, DivSIPattern,
               DivUIPattern, MaxSIPattern, MaxUIPattern, MinSIPattern,
               MinUIPattern, SIToFPPattern, UIToFPPattern, IndexCastSIPattern,
               IndexCastUIPattern>(ctx, options);
}
}