#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include <climits>
using namespace mlir;
namespace {
/// Rewrite pattern rooted at `math::PowFOp`.
///
/// The out-of-line `matchAndRewrite` replaces a power operation whose
/// exponent is a known constant (scalar, or a splat of one value) with a
/// short sequence of cheaper operations; it fails to match for any other
/// exponent.
struct PowFStrengthReduction : public OpRewritePattern<math::PowFOp> {
  // Reuse the base-class constructors unchanged.
  using OpRewritePattern<math::PowFOp>::OpRewritePattern;

  /// Attempts the rewrite of `op`; returns `success()` only when `op` was
  /// replaced.
  LogicalResult matchAndRewrite(math::PowFOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace
/// Strength-reduces `math::PowFOp` when its exponent is a compile-time
/// constant that is either a scalar float or a splat of a single float.
///
/// Recognized exponents and their replacements:
///   *  1.0   ->  x
///   *  2.0   ->  x * x
///   *  3.0   ->  x * (x * x)
///   * -1.0   ->  1 / x
///   *  0.5   ->  sqrt(x)
///   * -0.5   ->  rsqrt(x)
///   *  0.75  ->  sqrt(x) * sqrt(sqrt(x))
///
/// Returns `failure()` without creating any operation when the exponent is
/// not constant, is a non-splat constant, or is none of the values above.
LogicalResult
PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
                                       PatternRewriter &rewriter) const {
  const Location loc = op.getLoc();
  Value base = op.getLhs();
  Value exponent = op.getRhs();

  // Collapse the scalar and the splat-vector cases into one attribute so the
  // comparisons below do not need to distinguish between them.
  FloatAttr uniformExponent;
  if (!matchPattern(exponent, m_Constant(&uniformExponent))) {
    DenseFPElementsAttr denseExponent;
    if (matchPattern(exponent, m_Constant(&denseExponent)) &&
        denseExponent.isSplat())
      uniformExponent = denseExponent.getSplatValue<FloatAttr>();
  }
  if (!uniformExponent)
    return failure();

  const auto exponentValue = uniformExponent.getValue();
  auto exponentIs = [&exponentValue](double candidate) -> bool {
    return exponentValue.isExactlyValue(candidate);
  };

  // Turns a scalar into the op's result type when that type is a vector;
  // any other value is passed through untouched.
  // NOTE(review): only VectorType is handled here; non-vector shaped result
  // types are not broadcast -- confirm such ops never reach this pattern.
  auto toResultShape = [&](Value scalar) -> Value {
    auto vectorType = dyn_cast<VectorType>(op.getType());
    if (!vectorType)
      return scalar;
    return rewriter.create<vector::BroadcastOp>(loc, vectorType, scalar);
  };

  // pow(x, 1.0) -> x
  if (exponentIs(1.0)) {
    rewriter.replaceOp(op, base);
    return success();
  }

  // pow(x, 2.0) -> x * x
  if (exponentIs(2.0)) {
    rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({base, base}));
    return success();
  }

  // pow(x, 3.0) -> x * (x * x)
  if (exponentIs(3.0)) {
    Value squared =
        rewriter.create<arith::MulFOp>(loc, ValueRange({base, base}));
    rewriter.replaceOpWithNewOp<arith::MulFOp>(op,
                                               ValueRange({base, squared}));
    return success();
  }

  // pow(x, -1.0) -> 1.0 / x
  if (exponentIs(-1.0)) {
    Type elementType = getElementTypeOrSelf(op.getType());
    Value unit = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getFloatAttr(elementType, 1.0));
    rewriter.replaceOpWithNewOp<arith::DivFOp>(
        op, ValueRange({toResultShape(unit), base}));
    return success();
  }

  // pow(x, 0.5) -> sqrt(x)
  if (exponentIs(0.5)) {
    rewriter.replaceOpWithNewOp<math::SqrtOp>(op, base);
    return success();
  }

  // pow(x, -0.5) -> rsqrt(x)
  if (exponentIs(-0.5)) {
    rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, base);
    return success();
  }

  // pow(x, 0.75) -> sqrt(x) * sqrt(sqrt(x))
  if (exponentIs(0.75)) {
    Value root = rewriter.create<math::SqrtOp>(loc, base);
    Value fourthRoot = rewriter.create<math::SqrtOp>(loc, root);
    rewriter.replaceOpWithNewOp<arith::MulFOp>(op,
                                               ValueRange{root, fourthRoot});
    return success();
  }

  return failure();
}
namespace {
/// Rewrite pattern for power operations with an integer exponent.
///
/// Template parameters:
///   * `PowIOpTy` - the power operation the pattern is rooted at.
///   * `DivOpTy`  - division operation used to invert the base for negative
///                  exponents.
///   * `MulOpTy`  - multiplication operation used to expand the power.
template <typename PowIOpTy, typename DivOpTy, typename MulOpTy>
struct PowIStrengthReduction : public OpRewritePattern<PowIOpTy> {
  using Base = OpRewritePattern<PowIOpTy>;

  /// Upper bound on the absolute value of the exponent for which the
  /// out-of-line `matchAndRewrite` still expands the operation.
  unsigned exponentThreshold;

  /// Builds the pattern; `exponentThreshold` defaults to 3 and `benefit`
  /// to 1. `context`, `benefit` and `generatedNames` are forwarded to the
  /// base class.
  PowIStrengthReduction(MLIRContext *context, unsigned exponentThreshold = 3,
                        PatternBenefit benefit = 1,
                        ArrayRef<StringRef> generatedNames = {})
      : Base(context, benefit, generatedNames),
        exponentThreshold(exponentThreshold) {}

  /// Attempts the rewrite of `op`; returns `success()` only when `op` was
  /// replaced.
  LogicalResult matchAndRewrite(PowIOpTy op,
                                PatternRewriter &rewriter) const final;
};
} // namespace
/// Strength-reduces `PowIOpTy` when its exponent is a compile-time integer
/// constant (a scalar, or a splat of a single value) whose absolute value
/// does not exceed `exponentThreshold`.
///
///   * exponent == 0  ->  constant 1 (broadcast to the result vector type)
///   * exponent  > 0  ->  x * x * ... * x
///   * exponent  < 0  ->  (1 / x) * (1 / x) * ... * (1 / x)
///
/// Returns `failure()` when the exponent is not constant, is a non-splat
/// constant, or is larger in magnitude than the threshold. No operation is
/// created on any failure path, so the IR is left untouched in that case.
template <typename PowIOpTy, typename DivOpTy, typename MulOpTy>
LogicalResult
PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
    PowIOpTy op, PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  Value base = op.getLhs();

  IntegerAttr scalarExponent;
  DenseIntElementsAttr vectorExponent;

  bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
  bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));

  // Only a known, uniform exponent can be simplified.
  int64_t exponentValue = 0;
  if (isScalar)
    exponentValue = scalarExponent.getInt();
  else if (isVector && vectorExponent.isSplat())
    exponentValue = vectorExponent.getSplatValue<IntegerAttr>().getInt();
  else
    return failure();

  // Compute |exponent| in unsigned arithmetic: negating the most negative
  // int64_t value as a signed integer would overflow.
  const bool exponentIsNegative = exponentValue < 0;
  const uint64_t exponentMagnitude =
      exponentIsNegative ? uint64_t{0} - static_cast<uint64_t>(exponentValue)
                         : static_cast<uint64_t>(exponentValue);

  // Bail out if `abs(exponent)` exceeds the threshold. This check happens
  // before any operation is created so that a failed match does not modify
  // the IR.
  if (exponentMagnitude > exponentThreshold)
    return failure();

  // Maybe broadcasts a scalar value into the vector type of `op`.
  auto bcast = [&loc, &op, &rewriter](Value value) -> Value {
    if (auto vec = dyn_cast<VectorType>(op.getType()))
      return rewriter.create<vector::BroadcastOp>(loc, vec, value);
    return value;
  };

  // The constant `1` is needed only for a zero exponent (it is the result)
  // and for a negative exponent (it is the numerator of `1 / x`).
  Value one;
  if (exponentMagnitude == 0 || exponentIsNegative) {
    Type opType = getElementTypeOrSelf(op.getType());
    if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
      one = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(opType, 1.0));
    else
      one = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIntegerAttr(opType, 1));
  }

  // Replace `powi(x, 0)` with `1`.
  if (exponentMagnitude == 0) {
    rewriter.replaceOp(op, bcast(one));
    return success();
  }

  // For a negative exponent, invert the base: `x` becomes `1 / x`.
  if (exponentIsNegative)
    base = rewriter.create<DivOpTy>(loc, bcast(one), base);

  // Expand into a naive chain of `abs(exponent) - 1` multiplications of the
  // (possibly inverted) base with itself.
  Value result = base;
  for (uint64_t i = 1; i < exponentMagnitude; ++i)
    result = rewriter.create<MulOpTy>(loc, result, base);

  rewriter.replaceOp(op, result);
  return success();
}
/// Adds the strength-reduction patterns defined in this file to `patterns`:
/// the `math::PowFOp` pattern, followed by the integer-exponent pattern
/// instantiated for `math::IPowIOp` and for `math::FPowIOp`. Each pattern is
/// constructed from the context of `patterns` with its default arguments.
void mlir::populateMathAlgebraicSimplificationPatterns(
    RewritePatternSet &patterns) {
  using IPowIReduction =
      PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>;
  using FPowIReduction =
      PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>;

  MLIRContext *context = patterns.getContext();
  patterns.add<PowFStrengthReduction, IPowIReduction, FPowIReduction>(context);
}