#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::tensor;
namespace {
/// Replaces a `GenerateOp` whose body yields a constant with a single constant
/// op materialized by the tensor dialect.
///
/// The pattern applies only when the result tensor type has a static shape
/// and the value passed to the terminating `tensor::YieldOp` matches
/// `m_Constant`. In every other case it reports failure and leaves the IR
/// untouched.
struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp generateOp,
                                PatternRewriter &rewriter) const override {
    // Only statically shaped results are handled.
    auto resultType =
        llvm::cast<RankedTensorType>(generateOp.getResult().getType());
    if (!resultType.hasStaticShape())
      return failure();

    // The body must end by yielding a value that is itself a constant.
    auto &bodyBlock = generateOp.getBody().front();
    auto yieldOp = cast<tensor::YieldOp>(bodyBlock.getTerminator());
    Attribute yieldedAttr;
    if (!matchPattern(yieldOp.getValue(), m_Constant(&yieldedAttr)))
      return failure();

    // Build the dense attribute from the single yielded element and ask the
    // tensor dialect to turn it into a constant op.
    auto denseAttr = DenseElementsAttr::get(resultType, yieldedAttr);
    auto *tensorDialect =
        rewriter.getContext()->getLoadedDialect<TensorDialect>();
    Operation *constantOp = tensorDialect->materializeConstant(
        rewriter, denseAttr, resultType, generateOp->getLoc());

    // Materialization may decline; in that case nothing is rewritten.
    if (!constantOp)
      return failure();

    rewriter.replaceOp(generateOp, constantOp->getResults());
    return success();
  }
};
/// Converts a linear index over `inputShape` into a linear index expressed
/// with `outputStrides`.
///
/// `srcLinearIndex` is decomposed into one coordinate per dimension, treating
/// the last dimension of `inputShape` as the fastest varying one. Each
/// coordinate is then scaled by the stride of the same dimension in
/// `outputStrides` and the contributions are summed. No intermediate
/// coordinate vector is allocated.
///
/// `inputShape` and `outputStrides` must have the same number of entries.
int64_t transformIndexSpace(ArrayRef<int64_t> inputShape,
                            ArrayRef<int64_t> outputStrides,
                            int64_t srcLinearIndex) {
  assert(inputShape.size() == outputStrides.size());

  int64_t dstLinearIndex = 0;

  // Peel dimensions off from the innermost to the outermost one.
  for (int64_t dim = static_cast<int64_t>(inputShape.size()); dim-- > 0;) {
    const int64_t extent = inputShape[dim];

    // Coordinate along `dim`; what is left indexes the outer dimensions.
    const int64_t coordinate = srcLinearIndex % extent;
    srcLinearIndex /= extent;

    dstLinearIndex += outputStrides[dim] * coordinate;
  }

  return dstLinearIndex;
}
/// Folds the padding of a constant tensor into a new constant.
///
/// The elements of `input` are read as `ElemType`. A buffer of the padded size
/// is filled with `padValue.getValue()` and the input elements are then copied
/// into the region that starts `padLow` elements into each dimension. The
/// resulting attribute is materialized as a constant through the tensor
/// dialect at `loc`.
///
/// \param rewriter  Builder used to materialize the resulting constant.
/// \param loc       Location attached to the materialized constant.
/// \param input     Constant data of the tensor being padded.
/// \param padValue  Attribute whose `getValue()` yields the padding element.
/// \param padLow    Per-dimension amount of padding before the data.
/// \param padHigh   Per-dimension amount of padding after the data.
/// \returns The value of the new constant, or a null `Value` when the input
///          elements cannot be read as `ElemType`, when the padding amounts
///          cannot be handled (wrong rank or negative entries), or when the
///          constant cannot be materialized.
template <typename ElemType, typename AttrType>
Value constantFoldPadOp(PatternRewriter &rewriter, Location loc,
                        DenseElementsAttr input, AttrType padValue,
                        ArrayRef<int64_t> padLow, ArrayRef<int64_t> padHigh) {
  auto inputValues = input.tryGetValues<ElemType>();
  if (failed(inputValues))
    return nullptr;

  auto oldShape = input.getType().getShape();

  // One low and one high padding amount is required per input dimension;
  // otherwise the shape and stride computations below would disagree on rank.
  if (padLow.size() != oldShape.size() || padHigh.size() != oldShape.size())
    return nullptr;

  // The copy loop below assumes that every input element has a slot in the
  // output buffer. A negative padding amount breaks that assumption (negative
  // starting offset, or an output smaller than the input) and would index
  // `values` out of bounds, so such requests are not folded.
  auto hasNegativeEntry = [](ArrayRef<int64_t> amounts) {
    for (int64_t amount : amounts)
      if (amount < 0)
        return true;
    return false;
  };
  if (hasNegativeEntry(padLow) || hasNegativeEntry(padHigh))
    return nullptr;

  // Shape of the padded result: old extent plus both padding amounts.
  auto newShape =
      llvm::map_to_vector(llvm::zip(oldShape, padLow, padHigh),
                          [](std::tuple<int64_t, int64_t, int64_t> pack) {
                            auto [old, low, high] = pack;
                            return old + low + high;
                          });

  int64_t outputSize = computeProduct(newShape);

  // Start with a buffer made entirely of the padding value; the input data
  // overwrites the interior region afterwards.
  SmallVector<ElemType> values(outputSize, padValue.getValue());

  // Strides of the padded shape, used to place input elements in the output.
  SmallVector<int64_t> outputStrides = computeStrides(newShape);

  // Linear position in the output of the first input element, i.e. the
  // offset contributed by the low padding of every dimension.
  int64_t startingOffset = linearize(padLow, outputStrides);

  // Copy every input element to its position inside the padded buffer.
  for (auto [inputIndex, inputValue] : llvm::enumerate(*inputValues)) {
    auto outputIndex = transformIndexSpace(oldShape, outputStrides, inputIndex);
    values[outputIndex + startingOffset] = inputValue;
  }

  // Wrap the buffer in an attribute of the padded type and materialize it.
  auto newType = input.getType().clone(newShape);
  auto newAttr = DenseElementsAttr::get(newType, values);

  Operation *constantOp =
      rewriter.getContext()
          ->getLoadedDialect<TensorDialect>()
          ->materializeConstant(rewriter, newAttr, newType, loc);

  return constantOp ? constantOp->getResult(0) : nullptr;
}
/// Rewrites a `PadOp` with a constant source, a constant padding value and
/// constant low/high padding amounts into a single constant.
///
/// Folding is skipped when the op is marked `nofold` and when the
/// user-provided control function rejects the candidate. Only padding values
/// given as `FloatAttr` or `IntegerAttr` are handled. If the folded constant
/// has a type different from the op's result type, a `tensor::CastOp` to the
/// result type is inserted.
struct PadOpToConstant final : public OpRewritePattern<PadOp> {
  /// \param context    Context the pattern is registered in.
  /// \param controlFn  Callback consulted with the source operand of each
  ///                   candidate; folding happens only if it returns true.
  /// \param benefit    Pattern benefit forwarded to the base class.
  PadOpToConstant(MLIRContext *context, const ControlFoldFn &controlFn,
                  PatternBenefit benefit = 1)
      : OpRewritePattern<PadOp>(context, benefit), controlFn{controlFn} {}

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    if (padTensorOp.getNofold())
      return rewriter.notifyMatchFailure(
          padTensorOp, "refusing to fold nofold pad operation");

    TypedValue<RankedTensorType> input = padTensorOp.getSource();
    RankedTensorType resultType = padTensorOp.getResult().getType();

    // The source must be a constant carrying dense element data. Report the
    // reason like every other bail-out in this pattern does.
    DenseElementsAttr inputAttr = nullptr;
    if (!matchPattern(input, m_Constant(&inputAttr)))
      return rewriter.notifyMatchFailure(padTensorOp,
                                         "unable to get constant source");

    Value paddingValue = padTensorOp.getConstantPaddingValue();

    // The padding value must exist and be defined by a constant.
    Attribute paddingAttr = nullptr;
    if (!paddingValue || !matchPattern(paddingValue, m_Constant(&paddingAttr)))
      return rewriter.notifyMatchFailure(padTensorOp,
                                         "unable to get constant value");

    // Both padding amount lists must be fully constant.
    auto lowPad = getConstantIntValues(padTensorOp.getMixedLowPad());
    auto highPad = getConstantIntValues(padTensorOp.getMixedHighPad());
    if (!lowPad || !highPad)
      return rewriter.notifyMatchFailure(padTensorOp,
                                         "unable to extract constant padding");

    // All structural requirements hold; let the control function decide
    // whether this candidate should actually be folded.
    if (!controlFn(&padTensorOp.getSourceMutable()))
      return rewriter.notifyMatchFailure(padTensorOp,
                                         "not folding due to cost function");

    Location loc = padTensorOp.getLoc();

    // Dispatch on the kind of the padding attribute; anything other than
    // float or integer yields a null value.
    Value newOp =
        llvm::TypeSwitch<Attribute, Value>(paddingAttr)
            .Case([&](FloatAttr floatAttr) {
              return constantFoldPadOp<llvm::APFloat>(
                  rewriter, loc, inputAttr, floatAttr, *lowPad, *highPad);
            })
            .Case([&](IntegerAttr integerAttr) {
              return constantFoldPadOp<llvm::APInt>(
                  rewriter, loc, inputAttr, integerAttr, *lowPad, *highPad);
            })
            .Default(Value());

    if (!newOp)
      return rewriter.notifyMatchFailure(padTensorOp,
                                         "tensor type not supported");

    // Keep the replacement type-compatible with the users of the pad op.
    if (newOp.getType() != resultType)
      newOp = rewriter.create<tensor::CastOp>(loc, resultType, newOp);

    rewriter.replaceOp(padTensorOp, newOp);
    return success();
  }

private:
  /// Callback deciding, per candidate source operand, whether to fold.
  ControlFoldFn controlFn;
};
}
/// Registers the constant-rewriting patterns of this file in `patterns`:
/// `GenerateToConstant`, and `PadOpToConstant` configured with `controlFn`.
void mlir::tensor::populateRewriteAsConstantPatterns(
    RewritePatternSet &patterns, const ControlFoldFn &controlFn) {
  MLIRContext *context = patterns.getContext();
  patterns.add<GenerateToConstant>(context);
  patterns.add<PadOpToConstant>(context, controlFn);
}