#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#define DEBUG_TYPE "linalg-padding"
using namespace mlir;
using namespace mlir::linalg;
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
/// Computes the static shape that `opOperand` of `opToPad` should be padded to.
///
/// Every dimension of the operand's shape whose indexing-map result is a
/// function of one of `options.paddingDimensions` is replaced by a constant
/// upper bound (computed with ValueBoundsConstraintSet) rounded up to the
/// matching entry of `options.padToMultipleOf`, or to a multiple of 1 when that
/// option is unset. All other dimensions are copied unchanged.
///
/// \param paddedShape receives the resulting shape (overwritten).
/// \param alreadyHasRequestedShape set to true iff every dimension selected for
///        padding is static and already a multiple of its requested multiple.
/// \returns failure if a selected dimension has no constant upper bound, or if
///          the requested multiple for a selected dimension is not strictly
///          positive (a zero multiple would otherwise divide by zero).
static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
                                        OpOperand *opOperand,
                                        const LinalgPaddingOptions &options,
                                        SmallVector<int64_t> &paddedShape,
                                        bool &alreadyHasRequestedShape) {
  AffineMap indexingMap = opToPad.getMatchingIndexingMap(opOperand);
  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);

  // Map each operand shape dimension that depends on a padding dimension to
  // the multiple it must be padded to. If a shape dimension depends on several
  // padding dimensions, the last matching entry of `paddingDimensions` wins.
  alreadyHasRequestedShape = true;
  DenseMap<int64_t, int64_t> shapeDimToMultiple;
  for (const auto &dimEn : enumerate(options.paddingDimensions)) {
    int64_t multiple = options.padToMultipleOf.has_value()
                           ? (*options.padToMultipleOf)[dimEn.index()]
                           : 1;
    for (const auto &en : enumerate(indexingMap.getResults())) {
      if (!en.value().isFunctionOfDim(dimEn.value()))
        continue;
      // Guard the `%` below and the rounding division: a zero multiple is
      // undefined behavior and a negative one is meaningless.
      if (multiple <= 0) {
        LLVM_DEBUG(DBGS() << "--invalid padToMultipleOf value " << multiple
                          << " for padding dim " << dimEn.value() << "\n");
        return failure();
      }
      int64_t dimSize = shape[en.index()];
      shapeDimToMultiple[en.index()] = multiple;
      if (ShapedType::isDynamic(dimSize) || dimSize % multiple != 0)
        alreadyHasRequestedShape = false;
    }
  }

  // Round `val` up to the nearest multiple of `multiple` (`multiple` > 0).
  auto ceil = [](int64_t val, int64_t multiple) {
    return ((val + multiple - 1) / multiple) * multiple;
  };

  // Replace every dimension that requires padding by a constant upper bound.
  paddedShape.assign(shape.begin(), shape.end());
  for (int64_t i = 0, e = shape.size(); i < e; ++i) {
    LLVM_DEBUG(DBGS() << "--compute padded size for dim " << i << "\n");
    auto multipleIt = shapeDimToMultiple.find(i);
    if (multipleIt == shapeDimToMultiple.end()) {
      LLVM_DEBUG(DBGS() << "----dim does not require padding, SKIP\n");
      continue;
    }
    FailureOr<int64_t> upperBound =
        ValueBoundsConstraintSet::computeConstantBound(
            presburger::BoundType::UB,
            {opOperand->get(),
             /*dim=*/i},
            /*stopCondition=*/nullptr, /*closedUB=*/true);
    if (failed(upperBound)) {
      LLVM_DEBUG(
          DBGS() << "----could not compute a bounding box for padding\n");
      return failure();
    }
    paddedShape[i] = ceil(*upperBound, multipleIt->second);
    LLVM_DEBUG(DBGS() << "----new dim size: " << paddedShape[i] << "\n");
  }
  return success();
}
/// Pads `opOperand` of `opToPad` with a high-side pad so that its shape becomes
/// the static bounding box computed by `computePaddedShape`.
///
/// Returns the original operand value unchanged when it already has the
/// requested shape and its `packPaddings` entry (forwarded as `nofold`) is not
/// set. Otherwise it does two things:
///   - materializes the padding value from `options.paddingValues`, as a
///     `complex.constant` for complex element types and an `arith.constant`
///     otherwise;
///   - returns the value produced by `makeComposedPadHighOp`.
///
/// Fails via notifyMatchFailure, rather than asserting, in four cases:
///   - the padded shape cannot be computed;
///   - no padding value is provided for the operand;
///   - the padding value is null or of the wrong attribute kind;
///   - a typed padding value does not match the operand's element type, which
///     would otherwise produce an ill-typed pad.
static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
    RewriterBase &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const LinalgPaddingOptions &options) {
  assert(
      (!options.padToMultipleOf.has_value() ||
       options.padToMultipleOf->size() == options.paddingDimensions.size()) &&
      "invalid number of elements in padToMultipleOf");

  // Compute the padded (static) shape.
  SmallVector<int64_t> paddedShape;
  bool alreadyHasRequestedShape = false;
  if (failed(computePaddedShape(opToPad, opOperand, options, paddedShape,
                                alreadyHasRequestedShape)))
    return rewriter.notifyMatchFailure(opToPad,
                                       "--failed to compute padded shape");

  unsigned operandNumber = opOperand->getOperandNumber();
  // Operands without a `packPaddings` entry default to nofold = false.
  bool nofold = operandNumber < options.packPaddings.size()
                    ? options.packPaddings[operandNumber]
                    : false;
  // Nothing to do if the shape is already as requested and folding is allowed.
  if (!nofold && alreadyHasRequestedShape)
    return opOperand->get();

  if (operandNumber >= options.paddingValues.size()) {
    return rewriter.notifyMatchFailure(opToPad, "--no padding value specified");
  }
  Attribute paddingAttr = options.paddingValues[operandNumber];

  // Materialize the padding value. Use checked casts: the attribute may be
  // null (e.g. a defaulted zero attribute for an unsupported element type) or
  // of the wrong kind, which must not crash the rewrite.
  Type elementType = getElementTypeOrSelf(opOperand->get().getType());
  Value paddingValue;
  if (auto complexTy = dyn_cast<ComplexType>(elementType)) {
    auto complexAttr = dyn_cast_if_present<ArrayAttr>(paddingAttr);
    if (!complexAttr)
      return rewriter.notifyMatchFailure(
          opToPad, "--expected an array attribute as complex padding value");
    paddingValue = rewriter.create<complex::ConstantOp>(opToPad.getLoc(),
                                                        complexTy, complexAttr);
  } else {
    auto typedAttr = dyn_cast_if_present<TypedAttr>(paddingAttr);
    if (!typedAttr || typedAttr.getType() != elementType)
      return rewriter.notifyMatchFailure(
          opToPad,
          "--expected a typed padding value matching the element type");
    paddingValue =
        rewriter.create<arith::ConstantOp>(opToPad.getLoc(), typedAttr);
  }

  // Pad the operand up to the bounding box described by `paddedShape`.
  auto paddedTensorType = RankedTensorType::get(paddedShape, elementType);
  LLVM_DEBUG(DBGS() << "--SUCCESS, makeComposedPadHighOp with type: "
                    << paddedTensorType << "\n");
  return makeComposedPadHighOp(rewriter, opToPad->getLoc(), paddedTensorType,
                               opOperand->get(), paddingValue, nofold);
}
/// Rewrites `opToPad` so that it computes on statically padded operands.
///
/// Steps:
///   1. Each operand is padded with `padOperandToSmallestStaticBoundingBox`.
///   2. The op is cloned onto the padded operands; the padded inits provide
///      the result types of the clone.
///   3. Every result of the clone is sliced back to the original result shape
///      obtained from `reifyResultShapes`.
///   4. Depending on `options.copyBackOp`, the slices are either returned
///      directly or copied into the original DPS inits. The copy uses
///      `linalg.copy` or `bufferization.materialize_in_destination`.
///
/// When `options.paddingValues` is empty, a zero attribute of the element type
/// of each operand and then each result is used as the padding value.
/// NOTE(review): `getZeroAttr` may not support every element type — confirm.
///
/// Outputs:
///   - `paddedOp`: the padded clone.
///   - `replacements`: the values that replace `opToPad`'s results.
///   - `padOps`: appended with every padded operand defined by a `tensor.pad`.
///
/// All new IR is inserted right after `opToPad`.
LogicalResult
linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
                          const LinalgPaddingOptions &constOptions,
                          LinalgOp &paddedOp, SmallVector<Value> &replacements,
                          SmallVector<tensor::PadOp> &padOps) {
  LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
  Location loc = opToPad->getLoc();

  // Private copy so that missing padding values can be defaulted.
  LinalgPaddingOptions options(constOptions);
  if (options.paddingValues.empty()) {
    // NOTE(review): zero is not necessarily a neutral padding value for every
    // operation.
    auto appendZeroPadding = [&](Type t) {
      options.paddingValues.push_back(
          rewriter.getZeroAttr(getElementTypeOrSelf(t)));
    };
    llvm::for_each(opToPad->getOperandTypes(), appendZeroPadding);
    llvm::for_each(opToPad->getResultTypes(), appendZeroPadding);
  }

  if (!opToPad.hasPureTensorSemantics())
    return rewriter.notifyMatchFailure(opToPad,
                                       "expected operation on tensors");

  OpBuilder::InsertionGuard insertionGuard(rewriter);
  rewriter.setInsertionPointAfter(opToPad);

  // Pad every operand; bail out on the first one that cannot be bounded.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad->getNumOperands());
  for (OpOperand &operand : opToPad->getOpOperands()) {
    FailureOr<Value> maybePadded = padOperandToSmallestStaticBoundingBox(
        rewriter, opToPad, &operand, options);
    if (failed(maybePadded)) {
      LLVM_DEBUG(DBGS() << "--operand cannot be bound statically : "
                        << operand.get() << " -> FAIL\n");
      return rewriter.notifyMatchFailure(opToPad,
                                         "operand cannot be bound statically");
    }
    Value padded = *maybePadded;
    newOperands.push_back(padded);
    if (auto definingPad = padded.getDefiningOp<tensor::PadOp>())
      padOps.push_back(definingPad);
  }

  ReifiedRankedShapedTypeDims originalResultShapes;
  if (failed(reifyResultShapes(rewriter, opToPad, originalResultShapes))) {
    LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n");
    return rewriter.notifyMatchFailure(opToPad,
                                       "failed to reify result shapes");
  }
  assert(originalResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // The trailing (init) operands carry the padded result types.
  ValueRange paddedInits =
      ValueRange(newOperands).take_back(opToPad.getNumDpsInits());
  paddedOp = clone(rewriter, opToPad, paddedInits.getTypes(), newOperands);
  LLVM_DEBUG(DBGS() << "--cloned padded op: " << paddedOp << "\n");

  // Extract the original-sized region (zero offsets, unit strides) from each
  // padded result.
  SmallVector<Value> slicedResults;
  slicedResults.reserve(opToPad->getNumResults());
  for (unsigned resultIdx = 0, numResults = paddedOp->getNumResults();
       resultIdx < numResults; ++resultIdx) {
    Value paddedResult = paddedOp->getResult(resultIdx);
    int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank();
    SmallVector<OpFoldResult> zeroOffsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> unitStrides(rank, rewriter.getIndexAttr(1));
    slicedResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
        loc, paddedResult, zeroOffsets, originalResultShapes[resultIdx],
        unitStrides));
  }

  if (options.copyBackOp == LinalgPaddingOptions::CopyBackOp::None) {
    replacements = std::move(slicedResults);
    return success();
  }

  // Copy each slice back into the corresponding original init.
  assert(static_cast<int64_t>(slicedResults.size()) ==
             opToPad.getNumDpsInits() &&
         "expected matching number of results");
  using CopyBackOp = LinalgPaddingOptions::CopyBackOp;
  for (auto [slice, init] :
       llvm::zip(slicedResults, opToPad.getDpsInitsMutable())) {
    Value destination = init.get();
    Value copied;
    switch (options.copyBackOp) {
    case CopyBackOp::LinalgCopy:
      copied = rewriter.create<linalg::CopyOp>(loc, slice, destination)
                   .getResult(0);
      break;
    case CopyBackOp::BufferizationMaterializeInDestination:
      copied = rewriter
                   .create<bufferization::MaterializeInDestinationOp>(
                       loc, slice, destination)
                   ->getResult(0);
      break;
    default:
      llvm_unreachable("unsupported copy back op");
    }
    replacements.push_back(copied);
  }
  return success();
}
/// Pads `linalgOp` with `rewriteAsPaddedOp`, then tries to hoist the pads that
/// feed the padded op's operands.
///
/// For operand `i` of the padded op (for `i` below both the number of operands
/// and `options.hoistPaddings.size()`), `options.hoistPaddings[i]` is forwarded
/// to `hoistPaddingOnTensors`. That value is presumably the number of
/// enclosing loops to hoist across; confirm against its declaration. A value
/// of 0 skips the operand. `options.transposePaddings[i]`, when present, is
/// forwarded as the transpose vector.
///
/// Operands that are not produced by a `tensor.pad`, have a dynamic shape, or
/// fail to hoist are reported via notifyMatchFailure and skipped; they do not
/// fail the rewrite. `linalgOp` is finally replaced by the padded op's sliced
/// results.
///
/// Requires `options.copyBackOp == CopyBackOp::None` (asserted).
/// \returns the padded op, or failure if `linalgOp` lacks pure tensor
///          semantics or cannot be padded.
FailureOr<LinalgOp>
mlir::linalg::padAndHoistLinalgOp(RewriterBase &rewriter, LinalgOp linalgOp,
                                  const LinalgPaddingOptions &options) {
  assert(options.copyBackOp == LinalgPaddingOptions::CopyBackOp::None &&
         "invalid options");

  if (!linalgOp.hasPureTensorSemantics())
    return rewriter.notifyMatchFailure(
        linalgOp, "only applies to Linalg ops with tensor semantics");

  LinalgOp paddedOp;
  SmallVector<Value> replacementValues;
  SmallVector<tensor::PadOp> createdPadOps;
  LogicalResult padStatus = rewriteAsPaddedOp(
      rewriter, linalgOp, options, paddedOp, replacementValues, createdPadOps);
  if (failed(padStatus))
    return rewriter.notifyMatchFailure(linalgOp,
                                       "failed to rewrite as a padded op");

  // Only operands that actually exist on the padded op are candidates.
  unsigned numOperands = paddedOp->getNumOperands();
  for (unsigned idx = 0;
       idx < options.hoistPaddings.size() && idx < numOperands; ++idx) {
    int64_t numLoops = options.hoistPaddings[idx];
    OpOperand &operand = paddedOp->getOpOperand(idx);
    auto padOp = operand.get().getDefiningOp<tensor::PadOp>();
    if (!padOp || numLoops == 0) {
      (void)rewriter.notifyMatchFailure(linalgOp, "not a tensor.pad -- skip");
      continue;
    }

    // Hoisting is only attempted for fully static operand shapes.
    if (llvm::any_of(paddedOp.getShape(&operand), ShapedType::isDynamic)) {
      (void)rewriter.notifyMatchFailure(linalgOp,
                                        "non static padding shape -- skip");
      continue;
    }

    SmallVector<int64_t> transposeVector;
    if (idx < options.transposePaddings.size())
      transposeVector = options.transposePaddings[idx];

    tensor::PadOp hoistedOp;
    SmallVector<GenericOp> transposeOps;
    FailureOr<Value> hoisted = hoistPaddingOnTensors(
        padOp, numLoops, transposeVector, hoistedOp, transposeOps);
    if (failed(hoisted)) {
      (void)rewriter.notifyMatchFailure(linalgOp,
                                        "failed to apply hoistPadding");
      continue;
    }
    rewriter.replaceOp(padOp, *hoisted);
  }

  rewriter.replaceOp(linalgOp, replacementValues);
  return paddedOp;
}