#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/Support/Debug.h"
using llvm::dbgs;
#define DEBUG_TYPE "hoist-padding"
#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::linalg::detail;
#ifndef NDEBUG
/// Debug-only helper. If `op` is an scf.for, prints its induction variable as
/// an operand, followed by " @ " and the loop's Operation pointer, and returns
/// true. For any other op nothing is printed and false is returned.
static bool debugPrintLoopInShortForm(Operation *op) {
  // The printing state is built from the enclosing func.func so that the
  // induction variable's name is computed in the context of that function.
  AsmState state(op->getParentOfType<func::FuncOp>());
  (void)state;
  auto loop = dyn_cast<scf::ForOp>(op);
  if (!loop)
    return false;
  loop.getInductionVar().printAsOperand(dbgs(), state);
  dbgs() << " @ " << loop.getOperation();
  return true;
}
#endif
/// Debug helper: dumps `backwardSlice` under a "--backwardSlice:" header, one
/// "----"-prefixed entry per op. scf.for ops use the short form of
/// debugPrintLoopInShortForm; all other ops are printed in full. Produces
/// output only when LLVM_DEBUG is active for this DEBUG_TYPE.
static void debugPrintBackwardSlice(SetVector<Operation *> &backwardSlice) {
  LLVM_DEBUG({
    auto printEntry = [](Operation *entry) {
      dbgs() << "\n";
      DBGS() << "----";
      if (!debugPrintLoopInShortForm(entry))
        dbgs() << *entry;
      dbgs() << "\n";
    };
    llvm::raw_ostream &os = DBGS() << "--backwardSlice:";
    llvm::interleaveComma(backwardSlice, os, printEntry);
    DBGS() << "\n";
  });
}
/// Appends to `reverseEnclosingLoops` up to `nLevels` scf.for ops that
/// contiguously enclose `padOp`, ordered from innermost to outermost. The walk
/// stops early at the first parent that is not an scf.for.
static void
getAtMostNEnclosingLoops(tensor::PadOp padOp, int nLevels,
                         SmallVector<scf::ForOp> &reverseEnclosingLoops) {
  Operation *parent = padOp->getParentOp();
  for (int remaining = nLevels; remaining > 0; --remaining) {
    auto loop = dyn_cast<scf::ForOp>(parent);
    if (!loop)
      break;
    LLVM_DEBUG(DBGS() << "loops: ";
               debugPrintLoopInShortForm(loop);
               dbgs() << "\n");
    reverseEnclosingLoops.push_back(loop);
    parent = loop->getParentOp();
  }
}
/// Appends to `reverseEnclosingLoops` the scf.for ops that contiguously enclose
/// `padOp`, from innermost outwards. The walk stops right after `untilLoop`
/// has been appended, or at the first parent that is not an scf.for (in which
/// case `untilLoop` is not in the list). Nothing is collected when `untilLoop`
/// is null.
static void
getEnclosingLoopsUntil(tensor::PadOp padOp, scf::ForOp untilLoop,
                       SmallVector<scf::ForOp> &reverseEnclosingLoops) {
  if (!untilLoop)
    return;
  Operation *parent = padOp->getParentOp();
  while (auto loop = dyn_cast<scf::ForOp>(parent)) {
    LLVM_DEBUG(DBGS() << "loops: ";
               debugPrintLoopInShortForm(loop);
               dbgs() << "\n");
    reverseEnclosingLoops.push_back(loop);
    if (loop == untilLoop)
      break;
    parent = loop->getParentOp();
  }
}
/// Accumulates into `backwardSlice` the ops `padOp` depends on, restricted to
/// ops that `outermostEnclosingForOp` dominates (per DominanceInfo) and
/// excluding ops nested inside `padOp`. Values captured from above by the
/// padding region are sliced first, then `padOp` itself, which is included in
/// the slice (`inclusive = true`).
static void computeBackwardSlice(tensor::PadOp padOp,
                                 scf::ForOp outermostEnclosingForOp,
                                 SetVector<Operation *> &backwardSlice) {
  DominanceInfo domInfo(outermostEnclosingForOp);
  BackwardSliceOptions options;
  options.inclusive = true;
  options.filter = [&](Operation *candidate) {
    // Stay within the scope of the outermost enclosing loop ...
    if (!domInfo.dominates(outermostEnclosingForOp, candidate))
      return false;
    // ... and never descend into the padding region of `padOp`.
    return !padOp->isProperAncestor(candidate);
  };

  SetVector<Value> capturedValues;
  getUsedValuesDefinedAbove(padOp.getRegion(), padOp.getRegion(),
                            capturedValues);
  for (Value captured : capturedValues)
    getBackwardSlice(captured, &backwardSlice, options);
  getBackwardSlice(padOp.getOperation(), &backwardSlice, options);
}
namespace {
/// Analysis deciding whether a tensor.pad can be hoisted above (some of) its
/// enclosing scf.for loops, and collecting what is needed to build the packing
/// loop nest.
///
/// Intended sequence (as used by the entry points of this file): construct,
/// call `enableHoistPadding` (which may rewrite the outermost loop), call
/// `finalizeHoistPaddingAnalysis`, then check `isValid()`.
struct HoistPaddingAnalysis {
  /// Analyze hoisting `padOp` above at most `numLoops` enclosing scf.for ops.
  HoistPaddingAnalysis(tensor::PadOp padOp, int numLoops);
  /// Analyze hoisting `padOp` above the enclosing scf.for ops up to and
  /// including `outermostEnclosingForOp`.
  HoistPaddingAnalysis(tensor::PadOp padOp, scf::ForOp outermostEnclosingForOp);

  /// True once the analysis has completed and concluded hoisting is possible.
  bool isValid() const { return valid.has_value() && valid.value(); }
  /// True once the analysis has concluded hoisting is not possible.
  bool isInvalid() const { return valid.has_value() && !valid.value(); }

  /// Creates, at the current insertion point of `rewriter`, one value per
  /// packing loop bounding that loop's trip count; these are the dynamic sizes
  /// of the packed tensor.
  SmallVector<Value> getHoistedPackedTensorSizes(RewriterBase &rewriter,
                                                 Location loc) const;

  /// Tries to make the source of `sliceOp` available before the outermost
  /// enclosing loop; may rewrite that loop.
  void enableHoistPadding(RewriterBase &rewriter);

  /// Runs the remaining checks and computes `backwardSlice`, `packingLoops`
  /// and `padConsumingForOp`; sets the validity flag.
  void finalizeHoistPaddingAnalysis();

private:
  /// Unset while undecided, false when hoisting was rejected, true once
  /// finalizeHoistPaddingAnalysis succeeds.
  std::optional<bool> valid;

  /// The tensor.pad being analyzed.
  tensor::PadOp opToHoist;

  /// Enclosing scf.for ops of `opToHoist`, innermost first.
  SmallVector<scf::ForOp> reverseEnclosingLoops;

  /// Shrinks `backwardSlice` to the ops used for index computations (plus
  /// arith.constant ops); fails on unsupported ops in that computation.
  LogicalResult dropNonIndexDependencies();

public:
  /// The outermost loop above which `opToHoist` is hoisted.
  scf::ForOp outermostEnclosingForOp;

  /// Ops to clone when building the packing loop nest.
  SetVector<Operation *> backwardSlice;

  /// Enclosing loops kept in `backwardSlice`, outermost first.
  SmallVector<scf::ForOp> packingLoops;

  /// The tensor.extract_slice producing the source of `opToHoist`.
  tensor::ExtractSliceOp sliceOp;

  /// The scf.for that is the single user of `sliceOp`, if any.
  scf::ForOp padConsumingForOp;
};
} // namespace
/// Starts the analysis for hoisting `padOp` above at most `numLoops` enclosing
/// scf.for ops. Marks the analysis invalid when there is no enclosing loop or
/// when the pad source is not produced by a tensor.extract_slice.
HoistPaddingAnalysis::HoistPaddingAnalysis(tensor::PadOp padOp, int numLoops)
    : valid(std::nullopt), opToHoist(padOp) {
  getAtMostNEnclosingLoops(opToHoist, numLoops, reverseEnclosingLoops);
  bool hasEnclosingLoop = !reverseEnclosingLoops.empty();
  if (!hasEnclosingLoop) {
    LLVM_DEBUG(DBGS() << "--No immediately enclosing loop -> Skip\n");
    valid = false;
    return;
  }
  outermostEnclosingForOp = reverseEnclosingLoops.back();

  // The padded value must come straight out of a tensor.extract_slice.
  sliceOp = opToHoist.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (sliceOp)
    return;
  LLVM_DEBUG(DBGS() << "--Cannot find the extract slice op -> Skip\n");
  valid = false;
}
/// Starts the analysis for hoisting `padOp` above all enclosing scf.for ops up
/// to and including `outermostEnclosingForOp`. Marks the analysis invalid when
/// there is no enclosing loop, when the contiguous walk of enclosing scf.for
/// ops does not end at `outermostEnclosingForOp`, or when the pad source is
/// not produced by a tensor.extract_slice.
HoistPaddingAnalysis::HoistPaddingAnalysis(tensor::PadOp padOp,
                                           scf::ForOp outermostEnclosingForOp)
    : valid(std::nullopt), opToHoist(padOp) {
  getEnclosingLoopsUntil(opToHoist, outermostEnclosingForOp,
                         reverseEnclosingLoops);
  if (reverseEnclosingLoops.empty()) {
    LLVM_DEBUG(DBGS() << "--No immediately enclosing loop -> Skip\n");
    valid = false;
    return;
  }

  // The walk stops at the first non-scf.for parent, so it may not have
  // reached the requested loop.
  scf::ForOp reachedLoop = reverseEnclosingLoops.back();
  this->outermostEnclosingForOp = reachedLoop;
  if (reachedLoop != outermostEnclosingForOp) {
    LLVM_DEBUG(DBGS() << "--Unexpected outermost enclosing loop -> Skip\n");
    valid = false;
    return;
  }

  // The padded value must come straight out of a tensor.extract_slice.
  sliceOp = opToHoist.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (sliceOp)
    return;
  LLVM_DEBUG(DBGS() << "--Cannot find the extract slice op -> Skip\n");
  valid = false;
}
/// Tries to make the source of `sliceOp` available before the outermost
/// enclosing loop by hoisting loop-invariant subset extraction/insertion pairs
/// out of that loop (`hoistLoopInvariantSubsets`). Does nothing if the analysis
/// is already invalid or if the source is already defined outside the loop.
///
/// `hoistLoopInvariantSubsets` returns the loop to use from now on, which may
/// differ from the one passed in. Both constructors set
/// `outermostEnclosingForOp` and the last entry of `reverseEnclosingLoops` to
/// the same loop. Both are updated here, so finalizeHoistPaddingAnalysis
/// selects packing loops among live loop handles rather than a stale one.
void HoistPaddingAnalysis::enableHoistPadding(RewriterBase &rewriter) {
  if (isInvalid())
    return;
  if (outermostEnclosingForOp.isDefinedOutsideOfLoop(sliceOp.getSource()))
    return;

  scf::ForOp hoistedLoop = cast<scf::ForOp>(
      hoistLoopInvariantSubsets(rewriter, outermostEnclosingForOp));
  outermostEnclosingForOp = hoistedLoop;
  // Keep the innermost-first list consistent with `outermostEnclosingForOp`;
  // it is scanned later to pick the packing loops.
  if (!reverseEnclosingLoops.empty())
    reverseEnclosingLoops.back() = hoistedLoop;
}
void HoistPaddingAnalysis::finalizeHoistPaddingAnalysis() {
if (isInvalid())
return;
if (!outermostEnclosingForOp.isDefinedOutsideOfLoop(sliceOp.getSource())) {
LLVM_DEBUG(DBGS() << "--outermostEnclosingForOp:\n"
<< outermostEnclosingForOp << "\n"
<< "--sliceOp: " << sliceOp << "\n"
<< "--sliceOp.getSource(): " << sliceOp.getSource()
<< "\n");
LLVM_DEBUG(DBGS() << "----Source not defined outside of loops -> Skip\n");
valid = false;
return;
}
if (sliceOp->hasOneUse()) {
padConsumingForOp = dyn_cast<scf::ForOp>(*(sliceOp->getUsers().begin()));
}
Value paddingValue = opToHoist.getConstantPaddingValue();
if (!paddingValue ||
!isa_and_nonnull<arith::ConstantOp>(paddingValue.getDefiningOp())) {
LLVM_DEBUG(DBGS() << "Cannot find constant padding value -> Skip\n");
valid = false;
return;
}
computeBackwardSlice(opToHoist, outermostEnclosingForOp, backwardSlice);
if (backwardSlice.size() <= 1) {
valid = false;
return;
}
debugPrintBackwardSlice(backwardSlice);
if (failed(dropNonIndexDependencies())) {
LLVM_DEBUG(DBGS() << "--Cannot dropNonIndexDependencies -> Skip\n");
valid = false;
return;
}
debugPrintBackwardSlice(backwardSlice);
for (scf::ForOp forOp : llvm::reverse(reverseEnclosingLoops))
if (backwardSlice.contains(forOp))
packingLoops.push_back(forOp);
if (packingLoops.size() > 1 && padConsumingForOp) {
LLVM_DEBUG(DBGS() << "--Cannot hoist multiple loops through iter_args -> "
"Downgrade to 1 loop\n");
packingLoops.resize(1);
}
valid = true;
}
/// Shrinks `backwardSlice` to the ops involved in computing the index-typed
/// operands of `opToHoist` and `sliceOp`; arith.constant ops are always kept.
///
/// The slice is visited in reverse while a set of index-typed SSA values
/// ("index edges") grows:
///   - `opToHoist` and `sliceOp` seed the set with their index operands;
///   - an scf.for with no result in the set but whose induction variable is in
///     the set is kept and contributes its index operands;
///   - any other op with a result in the set is kept and contributes its index
///     operands. It must have only index operands, no regions and no memory
///     effects, otherwise the function fails;
///   - every other op that is not an arith.constant is removed.
/// NOTE(review): this propagation presumably relies on the slice listing
/// producers before their users, so that reverse order sees users first.
///
/// Returns failure if an unsupported op is found in the index computation.
LogicalResult HoistPaddingAnalysis::dropNonIndexDependencies() {
  // All SSA values known to feed an index computation.
  SetVector<Value> indexEdges;

  // Add all index-typed operands of `operation` to `indexEdges`.
  auto addIndexOperandsToIndexEdges = [&](Operation *operation) {
    for (Value operand : operation->getOperands())
      if (operand.getType().isIndex())
        indexEdges.insert(operand);
  };

  // True if some result of `operation` is an index edge.
  auto hasIndexResult = [&](Operation *operation) {
    return llvm::any_of(operation->getResults(), [&](Value result) {
      return indexEdges.contains(result);
    });
  };

  SetVector<Operation *> operationsToRemove;
  for (Operation *op : llvm::reverse(backwardSlice)) {
    // Seeds of the exploration.
    if (op == opToHoist || op == sliceOp) {
      addIndexOperandsToIndexEdges(op);
      continue;
    }
    // Keep loops whose induction variable is used for indexing, and follow
    // their (index-typed) bounds and step.
    if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      if (!hasIndexResult(op) && indexEdges.contains(forOp.getInductionVar())) {
        addIndexOperandsToIndexEdges(op);
        continue;
      }
    }
    // Keep any other op producing an index edge, provided it is a simple
    // side-effect-free index computation.
    if (hasIndexResult(op)) {
      addIndexOperandsToIndexEdges(op);
      if (llvm::any_of(op->getOperandTypes(),
                       [](Type type) { return !type.isIndex(); })) {
        LLVM_DEBUG(DBGS() << "Unsupported op with non index type operands: "
                          << op << " -> Skip\n");
        return failure();
      }
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
      bool hasMemoryEffect = effectInterface && !effectInterface.hasNoEffect();
      if (hasMemoryEffect || op->getNumRegions() != 0) {
        LLVM_DEBUG(DBGS() << "Unsupported op with region or memory effect: "
                          << op << " -> Skip\n");
        return failure();
      }
      continue;
    }
    // Everything else is dropped, except arith.constant ops.
    if (!isa<arith::ConstantOp>(op))
      operationsToRemove.insert(op);
  }
  backwardSlice.set_subtract(operationsToRemove);
  return success();
}
/// For each packing loop (outermost first), creates at the current insertion
/// point of `rewriter` a value `(ubVal - lb) ceildiv step`. Here `ubVal` is a
/// closed upper bound of the loop's upper bound, reified without looking
/// through values other than results of affine.min / affine.max /
/// affine.apply ops. The results are the dynamic sizes of the packed tensor.
/// Asserts that the upper bound can be reified.
SmallVector<Value>
HoistPaddingAnalysis::getHoistedPackedTensorSizes(RewriterBase &rewriter,
                                                  Location loc) const {
  MLIRContext *ctx = rewriter.getContext();
  SmallVector<Value> dynamicTensorSizes;
  for (scf::ForOp loop : packingLoops) {
    Value loopUpperBound = loop.getUpperBound();

    // Keep expanding the loop's own upper bound and affine min/max/apply
    // results; stop at block arguments and at any other value.
    auto stopCondition = [&](Value v, std::optional<int64_t> d,
                             ValueBoundsConstraintSet &cstr) {
      if (v == loopUpperBound)
        return false;
      Operation *def = v.getDefiningOp();
      return !def || !isa<affine::AffineMinOp, affine::AffineMaxOp,
                          affine::AffineApplyOp>(def);
    };
    FailureOr<OpFoldResult> closedUb = affine::reifyIndexValueBound(
        rewriter, loc, presburger::BoundType::UB, loopUpperBound, stopCondition,
        /*closedUB=*/true);
    assert(succeeded(closedUb) && "could not get upper bound");
    Value ubVal = getValueOrCreateConstantIndexOp(rewriter, loc, *closedUb);

    // Maximal trip count of the loop: (ub - lb) ceildiv step.
    AffineExpr lbExpr, ubExpr, stepExpr;
    bindDims(ctx, lbExpr, ubExpr);
    bindSymbols(ctx, stepExpr);
    Value tripCount = rewriter.createOrFold<affine::AffineApplyOp>(
        loc, (ubExpr - lbExpr).ceilDiv(stepExpr),
        ValueRange{loop.getLowerBound(), ubVal, loop.getStep()});
    dynamicTensorSizes.push_back(tripCount);
  }
  return dynamicTensorSizes;
}
/// Returns true if `v` is defined outside of the loop `outer`, or matches
/// `m_Constant()`.
static bool isDefinedOutsideOrConstant(scf::ForOp outer, Value v) {
  if (outer.isDefinedOutsideOfLoop(v))
    return true;
  return matchPattern(v, m_Constant());
}
/// Builds (or folds) `(iv - lb) ceildiv step` for `forOp`, i.e. the zero-based
/// number of the current iteration of `forOp`. Returns a null Value when the
/// lower bound or the step of `forOp` is neither defined outside of `outer`
/// nor a constant.
static Value buildLoopIterationCount(RewriterBase &rewriter, scf::ForOp outer,
                                     scf::ForOp forOp) {
  Value lbVal = forOp.getLowerBound();
  Value stepVal = forOp.getStep();
  bool boundsAreInvariant = isDefinedOutsideOrConstant(outer, lbVal) &&
                            isDefinedOutsideOrConstant(outer, stepVal);
  if (!boundsAreInvariant)
    return Value();

  MLIRContext *ctx = forOp->getContext();
  AffineExpr ivExpr, lbExpr, stepExpr;
  bindDims(ctx, ivExpr, lbExpr);
  bindSymbols(ctx, stepExpr);
  return rewriter.createOrFold<affine::AffineApplyOp>(
      forOp->getLoc(), (ivExpr - lbExpr).ceilDiv(stepExpr),
      ValueRange{forOp.getInductionVar(), lbVal, stepVal});
}
/// Builds the packing loop nest of `opToHoist` at the current insertion point
/// of `rewriter` by cloning the ops of `analysis.backwardSlice`.
///
/// Each scf.for of the slice is re-created as a loop carrying the packed
/// tensor (initially `emptyOp`) as its single iter_arg. Every other op is
/// cloned through `bvm`, which ends up mapping the original IR to the clones.
/// At the innermost level, the cloned pad result is optionally transposed
/// with `transposeVector` into `transposedTensorType`. It is then inserted
/// into the packed tensor at offsets [iteration counts..., 0...], sizes
/// [1..., transposed shape] and unit strides, and the loops yield the result.
///
/// Returns failure if `transposedTensorType` has a dynamic dimension.
/// NOTE(review): that failure is detected after the cloned loops have already
/// been created, leaving partially built IR behind.
static FailureOr<PackingResult> buildPackingLoopNestImpl(
    RewriterBase &rewriter, IRMapping &bvm, tensor::PadOp opToHoist,
    ArrayRef<int64_t> transposeVector, RankedTensorType transposedTensorType,
    tensor::EmptyOp emptyOp, const HoistPaddingAnalysis &analysis) {
  SmallVector<OpFoldResult> offsets, sizes, strides;
  SmallVector<Value> clonedLoopIvs, leadingHoistedPackedTensorIndexings;
  scf::ForOp outerLoop = analysis.outermostEnclosingForOp;
  Location loc = opToHoist->getLoc();
  RankedTensorType paddedTensorType = opToHoist.getResultType();
  int paddedRank = paddedTensorType.getRank();
  // Step 0. If the pad source is an iter_arg of `outerLoop` or of a loop
  // nested in it, map it (transitively through enclosing iter_args) to the
  // corresponding init value, which is what the clones can read.
  BlockArgument bbArg = dyn_cast<BlockArgument>(opToHoist.getSource());
  while (bbArg) {
    auto forOp = dyn_cast<scf::ForOp>(bbArg.getOwner()->getParentOp());
    if (!forOp)
      break;
    if (forOp != outerLoop && !outerLoop->isAncestor(forOp))
      break;
    OpOperand &operand = *forOp.getTiedLoopInit(bbArg);
    bvm.map(bbArg, operand.get());
    bbArg = dyn_cast<BlockArgument>(operand.get());
  }
  // Step 1. Clone the slice, turning each scf.for into a packing loop that
  // threads `hoistedPackedTensor` as its iter_arg.
  Value hoistedPackedTensor = emptyOp.getResult();
  OpBuilder::InsertionGuard g(rewriter);
  for (Operation *op : analysis.backwardSlice) {
    // Skip an extract_slice whose (mapped) source is the packed tensor value
    // currently being threaded.
    if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
      if (bvm.lookupOrDefault(sliceOp.getSource()) == hoistedPackedTensor) {
        LLVM_DEBUG(DBGS() << "--Skip: " << sliceOp << "\n");
        continue;
      }
    }
    // Non-loop ops are cloned as-is at the current (innermost) position.
    auto forOp = dyn_cast<scf::ForOp>(op);
    if (!forOp) {
      rewriter.clone(*op, bvm);
      continue;
    }
    // Re-create the loop with the packed tensor as its only iter_arg.
    auto clonedForOp = rewriter.create<scf::ForOp>(
        loc, bvm.lookupOrDefault(forOp.getLowerBound()),
        bvm.lookupOrDefault(forOp.getUpperBound()),
        bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor);
    bvm.map(forOp.getInductionVar(), clonedForOp.getInductionVar());
    bvm.map(forOp.getRegionIterArgs(), clonedForOp.getRegionIterArgs());
    bvm.map(forOp.getResults(), clonedForOp.getResults());
    assert(clonedForOp->getNumRegions() == 1);
    clonedLoopIvs.push_back(clonedForOp.getInductionVar());
    // Continue building inside the new loop (no guard: we go deeper).
    rewriter.setInsertionPointToStart(&clonedForOp->getRegion(0).front());
    Value loopIndependentIterationCount =
        buildLoopIterationCount(rewriter, outerLoop, clonedForOp);
    // The iteration count must not depend on values defined in `outerLoop`.
    if (!loopIndependentIterationCount)
      llvm_unreachable("loop independence prerequisite not met");
    leadingHoistedPackedTensorIndexings.push_back(
        loopIndependentIterationCount);
    hoistedPackedTensor = clonedForOp.getRegionIterArgs().front();
  }
  // Step 2. Addressing of one padded tile inside the packed tensor:
  // offsets = [iteration counts, 0 .. 0], sizes = [1 .. 1, transposed shape],
  // strides = [1 .. 1].
  int64_t nPackedLoops = clonedLoopIvs.size();
  offsets =
      SmallVector<OpFoldResult>{leadingHoistedPackedTensorIndexings.begin(),
                                leadingHoistedPackedTensorIndexings.end()};
  offsets.append(paddedRank, rewriter.getIndexAttr(0));
  sizes = SmallVector<OpFoldResult>(nPackedLoops, rewriter.getIndexAttr(1));
  for (int64_t sz : transposedTensorType.getShape()) {
    if (ShapedType::isDynamic(sz))
      return failure();
    sizes.push_back(rewriter.getIndexAttr(sz));
  }
  strides = SmallVector<OpFoldResult>(nPackedLoops + paddedRank,
                                      rewriter.getIndexAttr(1));
  // Step 3. Optionally transpose the cloned pad result into the tile.
  GenericOp maybeTransposeOp;
  Value paddedTensor = bvm.lookup(opToHoist.getResult());
  if (!transposeVector.empty()) {
    Value outputTensor = rewriter.create<tensor::ExtractSliceOp>(
        loc, transposedTensorType, hoistedPackedTensor, offsets, sizes,
        strides);
    maybeTransposeOp = makeTransposeOp(rewriter, loc, paddedTensor,
                                       outputTensor, transposeVector);
    paddedTensor = maybeTransposeOp.getResult(0);
  }
  // Steps 4-5. With packing loops, insert the tile into the packed tensor and
  // yield it from the innermost loop outwards.
  if (nPackedLoops > 0) {
    Value inserted = rewriter.create<tensor::InsertSliceOp>(
        loc, paddedTensor, hoistedPackedTensor, offsets, sizes, strides);
    Value valueToYield = inserted;
    for (Value iv : llvm::reverse(clonedLoopIvs)) {
      auto forOp = scf::getForInductionVarOwner(iv);
      rewriter.setInsertionPointToEnd(&forOp.getRegion().front());
      rewriter.create<scf::YieldOp>(loc, valueToYield);
      valueToYield = forOp.getResult(0);
    }
  }
  // The last field is the cloned tensor.pad.
  return PackingResult{
      offsets,
      sizes,
      strides,
      clonedLoopIvs,
      leadingHoistedPackedTensorIndexings,
      maybeTransposeOp,
      cast<tensor::PadOp>(bvm.lookup(opToHoist.getResult()).getDefiningOp())};
}
/// Creates, right before `analysis.outermostEnclosingForOp`, the tensor.empty
/// holding the packed data, then builds the packing loop nest into it.
/// The packed type is one dynamic leading dimension per packing loop, followed
/// by the padded shape transposed by `transposeVector`. Fails if that
/// transposed type cannot be computed or if loop-nest construction fails.
static FailureOr<PackingResult> buildPackingLoopNestImpl(
    RewriterBase &rewriter, IRMapping &bvm, tensor::PadOp opToHoist,
    ArrayRef<int64_t> transposeVector, const HoistPaddingAnalysis &analysis) {
  int nPackedLoops = analysis.packingLoops.size();
  LLVM_DEBUG(DBGS() << "\n";
             DBGS() << "Func:\n"
                    << *opToHoist->getParentOfType<func::FuncOp>() << "\n";
             DBGS() << "Start hoisting above " << nPackedLoops << " loops\n");

  Location loc = opToHoist->getLoc();
  FailureOr<RankedTensorType> transposedTensorType =
      tensor::computeTransposedType(opToHoist.getResultType(), transposeVector);
  if (failed(transposedTensorType)) {
    LLVM_DEBUG(DBGS() << "--Could not compute transposed type -> Skip\n");
    return failure();
  }

  // Packed shape: [?, ..., ?] (one per packing loop) ++ transposed shape.
  SmallVector<int64_t> packedShape(nPackedLoops, ShapedType::kDynamic);
  llvm::append_range(packedShape, transposedTensorType->getShape());
  RankedTensorType packedType = RankedTensorType::get(
      packedShape, transposedTensorType->getElementType());

  // The packed buffer lives right before the outermost enclosing loop.
  scf::ForOp outermostLoop = analysis.outermostEnclosingForOp;
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(outermostLoop);
  SmallVector<Value> dynamicSizes =
      analysis.getHoistedPackedTensorSizes(rewriter, loc);
  auto packedEmpty = rewriter.create<tensor::EmptyOp>(
      loc, packedType.getShape(), packedType.getElementType(), dynamicSizes);

  return buildPackingLoopNestImpl(rewriter, bvm, opToHoist, transposeVector,
                                  *transposedTensorType, packedEmpty, analysis);
}
/// Runs the hoisting analysis for `opToHoist` with `outermostEnclosingForOp`
/// as the outermost loop and, if it succeeds, builds the packing loop nest
/// (optionally transposing each tile by `transposeVector`). Returns failure
/// when the analysis rejects the pad or when building the nest fails.
/// NOTE(review): enableHoistPadding may rewrite the outermost loop even when
/// the analysis ultimately fails.
FailureOr<PackingResult> mlir::linalg::detail::buildPackingLoopNest(
    RewriterBase &rewriter, tensor::PadOp opToHoist,
    scf::ForOp outermostEnclosingForOp, ArrayRef<int64_t> transposeVector) {
  HoistPaddingAnalysis analysis(opToHoist, outermostEnclosingForOp);
  analysis.enableHoistPadding(rewriter);
  analysis.finalizeHoistPaddingAnalysis();
  if (analysis.isValid()) {
    IRMapping bvm;
    return buildPackingLoopNestImpl(rewriter, bvm, opToHoist, transposeVector,
                                    analysis);
  }
  LLVM_DEBUG(DBGS() << "--Analysis failed -> Skip\n");
  return failure();
}
/// Returns true if walking back from the source of `extractSliceOp` reaches
/// `expectedSource`. The walk follows, for each value produced by a
/// DestinationStyleOpInterface op, the init operand with the same index as the
/// result. It stops at the first value that is null or not produced by such
/// an op.
static bool tracesBackToExpectedValue(tensor::ExtractSliceOp extractSliceOp,
                                      Value expectedSource) {
  LLVM_DEBUG(DBGS() << "Start tracesBackToExpectedValue on: " << extractSliceOp
                    << "\n");
  LLVM_DEBUG(DBGS() << "--with extractSlice: " << extractSliceOp << "\n");
  Value current = extractSliceOp.getSource();
  LLVM_DEBUG(DBGS() << "--with starting source: " << current << "\n");
  while (current && current != expectedSource) {
    Operation *producer = current.getDefiningOp();
    auto dpsProducer = dyn_cast_or_null<DestinationStyleOpInterface>(producer);
    if (!dpsProducer)
      break;
    LLVM_DEBUG(DBGS() << "--step dest op: " << dpsProducer << "\n");
    unsigned resultIdx = cast<OpResult>(current).getResultNumber();
    current = dpsProducer.getDpsInitOperand(resultIdx)->get();
  }
  LLVM_DEBUG(DBGS() << "--final source: " << current << "\n");
  LLVM_DEBUG(DBGS() << "--expected source: " << expectedSource << "\n");
  return current == expectedSource;
}
/// Rewires `forOp`, which consumes `outerSliceOp` through one of its
/// iter_args, so that this iter_arg carries `hoistedPackedTensor`.
///
/// With k the index of the iter_arg initialized by `outerSliceOp`:
///   - returns a null op (no IR change) unless the k-th yielded value is a
///     tensor.extract_slice whose source traces back, through
///     destination-style init operands, to `paddedValueBeforeHoisting`;
///   - creates after `forOp` an extract_slice of `hoistedPackedTensor` with the
///     offsets/sizes/strides of `outerSliceOp`, and redirects uses of the k-th
///     loop result to it;
///   - rebuilds the loop with one additional iter_arg per original one. Their
///     inits copy the original inits, except entry k becomes
///     `hoistedPackedTensor`. Their yields copy the original yields, except
///     entry k becomes the source of the yielding extract_slice;
///   - points the new extract_slice at the new loop result for entry k, and
///     replaces all uses of `paddedValueBeforeHoisting` with the new region
///     iter_arg for entry k.
/// Returns the new extract_slice. Asserts that `outerSliceOp` is used exactly
/// once by `forOp`.
static tensor::ExtractSliceOp
padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
                      Value hoistedPackedTensor,
                      tensor::ExtractSliceOp outerSliceOp, scf::ForOp forOp) {
  LLVM_DEBUG(DBGS() << "Start padThroughLoopIterArg on: " << forOp << "\n");
  LLVM_DEBUG(DBGS() << "--paddedValueBeforeHoisting: "
                    << paddedValueBeforeHoisting << "\n");
  // Find the unique operand of `forOp` that uses `outerSliceOp`.
  OpOperand *pUse = nullptr;
  for (OpOperand &use : outerSliceOp->getUses()) {
    if (use.getOwner() == forOp) {
      assert(!pUse && "Multiple slice uses in the for loop");
      pUse = &use;
    }
  }
  assert(pUse && "No slice use in the for loop");
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp());
  // Index of the loop result (and iter_arg) tied to that operand.
  unsigned iterArgNumber = forOp.getTiedLoopResult(pUse).getResultNumber();
  // Bail out unless the value yielded for that iter_arg is an extract_slice
  // whose source traces back to the padded value.
  auto yieldingExtractSliceOp = forOp.getYieldedValues()[iterArgNumber]
                                    .getDefiningOp<tensor::ExtractSliceOp>();
  if (!yieldingExtractSliceOp)
    return tensor::ExtractSliceOp();
  if (!tracesBackToExpectedValue(yieldingExtractSliceOp,
                                 paddedValueBeforeHoisting))
    return tensor::ExtractSliceOp();
  // Inits and yields of the additional iter_args: copies of the original
  // ones, with entry `iterArgNumber` replaced.
  SmallVector<Value> initArgs = forOp.getInitArgs();
  initArgs[iterArgNumber] = hoistedPackedTensor;
  SmallVector<Value> yieldOperands = llvm::to_vector(forOp.getYieldedValues());
  yieldOperands[iterArgNumber] = yieldingExtractSliceOp.getSource();
  int64_t numOriginalForOpResults = initArgs.size();
  LLVM_DEBUG(DBGS() << "numOriginalForOpResults: " << numOriginalForOpResults
                    << "\n");
  // Create the replacement extract_slice after the loop and redirect users of
  // the original loop result to it; its source is fixed up once the new loop
  // exists.
  tensor::ExtractSliceOp extracted;
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(forOp);
    extracted = rewriter.create<tensor::ExtractSliceOp>(
        hoistedPackedTensor.getLoc(), hoistedPackedTensor,
        outerSliceOp.getMixedOffsets(), outerSliceOp.getMixedSizes(),
        outerSliceOp.getMixedStrides());
    rewriter.replaceAllUsesWith(forOp.getResult(iterArgNumber), extracted);
  }
  // Rebuild the loop with the additional iter_args. The boolean argument is
  // `replaceInitOperandUsesInLoop`.
  scf::ForOp newForOp = cast<scf::ForOp>(*forOp.replaceWithAdditionalYields(
      rewriter, initArgs, true,
      [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
        return yieldOperands;
      }));
  LLVM_DEBUG(DBGS() << "newForOp results: " << newForOp.getNumResults()
                    << "\n");
  LLVM_DEBUG(DBGS() << "replace source of: " << extracted << "\n");
  LLVM_DEBUG(DBGS() << "with result #"
                    << numOriginalForOpResults + iterArgNumber
                    << " of forOp, giving us: " << extracted << "\n");
  // Read the packed tensor from the new loop result for entry `iterArgNumber`.
  rewriter.startOpModification(extracted);
  extracted.getSourceMutable().assign(
      newForOp.getResult(numOriginalForOpResults + iterArgNumber));
  rewriter.finalizeOpModification(extracted);
  LLVM_DEBUG(DBGS() << "replace uses of: " << paddedValueBeforeHoisting
                    << "\n");
  LLVM_DEBUG(DBGS() << "with region iter arg #"
                    << numOriginalForOpResults + iterArgNumber << "\n");
  // Inside the loop, the padded value is now the new region iter_arg.
  rewriter.replaceAllUsesWith(
      paddedValueBeforeHoisting,
      newForOp.getRegionIterArg(numOriginalForOpResults + iterArgNumber));
  return extracted;
}
/// Produces, right before `opToHoist`, the value that stands in for the padded
/// tensor once the padding has been hoisted.
///
/// With packing loops, the packed tensor (result 0 of the outermost cloned
/// loop) is read at offsets [iteration counts of `analysis.packingLoops`,
/// 0 .. 0] with the sizes/strides of `packingResult`. Without packing loops,
/// the cloned pad result (`bvm`) is read directly. When the slice feeding the
/// pad is consumed through an scf.for iter_arg (`analysis.padConsumingForOp`),
/// the rewrite is delegated to padThroughLoopIterArg instead.
///
/// Returns a null Value if padThroughLoopIterArg cannot perform its rewrite.
static Value replaceByPackingResult(RewriterBase &rewriter,
                                    const IRMapping &bvm,
                                    tensor::PadOp opToHoist,
                                    RankedTensorType transposedTensorType,
                                    const HoistPaddingAnalysis &analysis,
                                    const PackingResult &packingResult) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(opToHoist);

  Location loc = opToHoist->getLoc();
  RankedTensorType paddedTensorType = opToHoist.getResultType();
  int paddedRank = paddedTensorType.getRank();

  int64_t nPackedLoops = packingResult.clonedLoopIvs.size();
  LLVM_DEBUG(DBGS() << "nPackedLoops: " << nPackedLoops << " loops\n");

  scf::ForOp outerLoop = analysis.outermostEnclosingForOp;
  ArrayRef<scf::ForOp> packingLoops = analysis.packingLoops;

  Value hoistedPackedTensor;
  SmallVector<Value> loopIterationCounts;
  SmallVector<OpFoldResult> offsets(nPackedLoops + paddedRank,
                                    rewriter.getIndexAttr(0));
  if (nPackedLoops > 0) {
    loopIterationCounts =
        llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *loop) {
          return buildLoopIterationCount(rewriter, outerLoop,
                                         cast<scf::ForOp>(loop));
        }));
    // Every iteration count must be computable independently of `outerLoop`.
    if (llvm::any_of(loopIterationCounts, [](Value v) { return !v; }))
      llvm_unreachable("loop independence prerequisite not met");

    // offsets = [iteration counts, 0 .. 0].
    std::copy(loopIterationCounts.begin(), loopIterationCounts.end(),
              offsets.begin());
    hoistedPackedTensor =
        scf::getForInductionVarOwner(packingResult.clonedLoopIvs.front())
            ->getResult(0);
  } else {
    // No packing loops: hoisting without packing, read the cloned pad.
    hoistedPackedTensor = bvm.lookup(opToHoist.getResult());
  }

  LLVM_DEBUG(DBGS() << "hoistedPackedTensor: " << hoistedPackedTensor << "\n");

  if (scf::ForOp forOp = analysis.padConsumingForOp) {
    tensor::ExtractSliceOp extracted = padThroughLoopIterArg(
        rewriter, opToHoist, hoistedPackedTensor, analysis.sliceOp, forOp);
    // padThroughLoopIterArg reports failure with a null op. Converting a null
    // op to a Value would dereference a null Operation*, so surface the
    // failure as a null Value instead.
    if (!extracted)
      return Value();
    return extracted.getResult();
  }

  return rewriter.create<tensor::ExtractSliceOp>(
      loc, transposedTensorType, hoistedPackedTensor, offsets,
      packingResult.sizes, packingResult.strides);
}
/// Hoists `opToHoist` above at most `numLoops` enclosing scf.for loops.
///
/// On success:
///   - a packing loop nest is built before the outermost hoisting loop;
///   - `hoistedOp` is set to the cloned tensor.pad inside it;
///   - the returned value is the packed data read back right before
///     `opToHoist`, transposed back to the padded layout when
///     `transposeVector` is not empty;
///   - the transpose ops created (packing and un-transposing) are appended
///     to `transposeOps`.
///
/// Returns failure when the analysis rejects the pad, when the packing loop
/// nest cannot be built, or when no replacement value can be produced.
/// NOTE(review): in the later failure cases the IR may already have been
/// partially rewritten (enableHoistPadding, packing loop nest).
FailureOr<Value> mlir::linalg::hoistPaddingOnTensors(
    RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops,
    ArrayRef<int64_t> transposeVector, tensor::PadOp &hoistedOp,
    SmallVectorImpl<GenericOp> &transposeOps) {
  LLVM_DEBUG(DBGS() << "\n"; DBGS() << " Try to hoist " << *(opToHoist) << "\n";
             DBGS() << " by " << numLoops << " loops\n");

  HoistPaddingAnalysis analysis(opToHoist, numLoops);
  analysis.enableHoistPadding(rewriter);
  analysis.finalizeHoistPaddingAnalysis();
  if (!analysis.isValid()) {
    LLVM_DEBUG(DBGS() << "--Analysis failed -> Skip\n");
    return failure();
  }

  // Construct the packing loop nest.
  IRMapping bvm;
  FailureOr<PackingResult> packingResult = buildPackingLoopNestImpl(
      rewriter, bvm, opToHoist, transposeVector, analysis);
  if (failed(packingResult)) {
    LLVM_DEBUG(DBGS() << "--buildPackingLoopNestImpl failed -> Skip\n");
    return failure();
  }

  if (!transposeVector.empty())
    transposeOps.push_back(packingResult->maybeTransposeOp);

  FailureOr<RankedTensorType> transposedTensorType =
      tensor::computeTransposedType(opToHoist.getResultType(), transposeVector);
  assert(succeeded(transposedTensorType) && "unexpected failure in type");

  // Read the packed data back where `opToHoist` was.
  Value newResult =
      replaceByPackingResult(rewriter, bvm, opToHoist, *transposedTensorType,
                             analysis, *packingResult);
  // A null value means the iter_arg rewrite could not be performed; report
  // failure rather than dereferencing it below or returning it as success.
  if (!newResult) {
    LLVM_DEBUG(DBGS() << "--replaceByPackingResult failed -> Skip\n");
    return failure();
  }

  Location loc = opToHoist->getLoc();
  RankedTensorType paddedTensorType = opToHoist.getResultType();
  if (!transposeVector.empty()) {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newResult.getDefiningOp());
    // Transpose back to the original padded layout.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, paddedTensorType.getShape(), paddedTensorType.getElementType());
    GenericOp unTransposeOp =
        makeTransposeOp(rewriter, loc, newResult, emptyTensor, transposeVector);
    newResult = unTransposeOp.getResult(0);
    transposeOps.push_back(unTransposeOp);
  }

  LLVM_DEBUG(DBGS() << "newResult: " << newResult << "\n");
  LLVM_DEBUG(
      DBGS() << "After hoisting: "
             << newResult.getDefiningOp()->getParentOfType<func::FuncOp>()
             << "\n");

  // Expose the cloned pad to the caller.
  hoistedOp = packingResult->hoistedPadOp;

  LLVM_DEBUG(DBGS() << "--SUCCESS\n");
  return newResult;
}
/// Convenience overload of hoistPaddingOnTensors that creates its own
/// IRRewriter on the context of `opToHoist` and forwards every other argument
/// unchanged.
FailureOr<Value>
mlir::linalg::hoistPaddingOnTensors(tensor::PadOp opToHoist, int64_t numLoops,
                                    ArrayRef<int64_t> transposeVector,
                                    tensor::PadOp &hoistedOp,
                                    SmallVectorImpl<GenericOp> &transposeOps) {
  MLIRContext *ctx = opToHoist.getContext();
  IRRewriter rewriter(ctx);
  return hoistPaddingOnTensors(rewriter, opToHoist, numLoops, transposeVector,
                               hoistedOp, transposeOps);
}