#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>
#define DEBUG_TYPE "tile-using-interface"
using namespace mlir;
/// Installs a tile-size computation function that ignores the builder and the
/// operation and always returns a copy of `ts`. Asserts that no computation
/// function has been installed before.
scf::SCFTilingOptions &
scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  // Take a private copy: `ts` is a non-owning view that may not outlive us.
  SmallVector<OpFoldResult> fixedSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [fixedSizes](OpBuilder &, Operation *) {
    return fixedSizes;
  };
  return *this;
}
/// Normalizes `interchangeVector` to exactly `iterationDomainSize` entries.
/// Surplus trailing entries are dropped. Missing trailing positions are
/// filled with the identity, so position `i` maps to `i`.
static SmallVector<int64_t>
fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
                      size_t iterationDomainSize) {
  // `take_front` returns the whole array when it is already short enough.
  ArrayRef<int64_t> kept = interchangeVector.take_front(iterationDomainSize);
  SmallVector<int64_t> normalized(kept.begin(), kept.end());
  for (size_t pos = normalized.size(); pos < iterationDomainSize; ++pos)
    normalized.push_back(static_cast<int64_t>(pos));
  return normalized;
}
/// Returns true if `loopRange.stride` is statically known to evenly divide
/// the extent `loopRange.size - loopRange.offset`. Here `size` acts as the
/// exclusive upper bound. When this holds, every tile is full-sized.
///
/// Returns false (conservatively "may not divide") when any of offset, size,
/// or stride is not a compile-time constant. It also returns false when the
/// stride is the constant 0, because computing `x % 0` is undefined behavior.
static bool tileDividesIterationDomain(Range loopRange) {
  std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
  if (!offsetAsInt)
    return false;
  std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
  if (!sizeAsInt)
    return false;
  std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
  if (!strideAsInt || *strideAsInt == 0)
    return false;
  return (*sizeAsInt - *offsetAsInt) % *strideAsInt == 0;
}
/// Returns the size of the tile that starts at induction variable `iv`,
/// namely `min(tileSize, ub - iv)`, where `ub` is `loopRange.size`.
/// No IR is created and `tileSize` is returned unchanged when either:
/// - the tile size is the constant 1, or
/// - the tile size statically divides `[offset, ub)`.
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
                                       Range loopRange, Value iv,
                                       OpFoldResult tileSize) {
  // Unit tiles and evenly dividing tiles never need clamping.
  if (isConstantIntValue(tileSize, 1) ||
      tileDividesIterationDomain(
          Range{loopRange.offset, loopRange.size, tileSize}))
    return tileSize;

  MLIRContext *ctx = b.getContext();
  AffineExpr ivDim, tileSym, ubSym;
  bindDims(ctx, ivDim);
  bindSymbols(ctx, tileSym, ubSym);
  // (iv)[tileSize, ub] -> min(tileSize, ub - iv)
  AffineMap boundMap = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/2,
                                      {tileSym, ubSym - ivDim}, ctx);
  Value ub = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
  return affine::makeComposedFoldedAffineMin(
      b, loc, boundMap, SmallVector<OpFoldResult>{iv, tileSize, ub});
}
using YieldTiledValuesFn = std::function<LogicalResult(
RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
SmallVector<Value> &tiledValues,
SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
/// Clones `op` at the rewriter's insertion point and returns the clone.
/// If `newDestArgs` is non-empty and the clone implements
/// DestinationStyleOpInterface, its DPS init operands are replaced by
/// `newDestArgs`. Otherwise the clone keeps the original operands.
static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange newDestArgs) {
  Operation *clone = rewriter.clone(*op);
  auto dpsClone = dyn_cast<DestinationStyleOpInterface>(clone);
  if (dpsClone && !newDestArgs.empty())
    dpsClone.getDpsInitsMutable().assign(newDestArgs);
  return clone;
}
/// Builds a perfect `scf.for` nest with one loop per entry of `loopRanges`
/// whose tile size is not the constant 0. Each loop runs from
/// `range.offset` to `range.size` with the tile size as its step.
/// `destinationTensors` are threaded through the nest as iter args.
///
/// `yieldTiledValuesFn` populates the innermost body. Its tiled values are
/// written into the innermost iter args with unit-stride
/// `tensor.insert_slice` ops and yielded. Each outer loop yields the results
/// of the loop directly nested in it.
///
/// Created loops are appended to `loops`, outermost first. The rewriter's
/// insertion point is restored on return.
static LogicalResult generateLoopNestUsingForOp(
    RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
    ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
    YieldTiledValuesFn yieldTiledValuesFn,
    SmallVector<LoopLikeOpInterface> &loops) {
  assert(!loopRanges.empty() && "unexpected empty loop ranges");
  assert(loopRanges.size() == tileSizes.size() &&
         "expected as many tile sizes as loop ranges");
  OpBuilder::InsertionGuard guard(rewriter);

  // Build the loops from the outside in, skipping untiled dimensions.
  SmallVector<Value> inductionVars;
  ValueRange iterArgs = destinationTensors;
  for (auto [range, step] : llvm::zip_equal(loopRanges, tileSizes)) {
    if (isConstantIntValue(step, 0))
      continue;
    // Materialize the bounds in a fixed order (lb, ub, step).
    Value lowerBound =
        getValueOrCreateConstantIndexOp(rewriter, loc, range.offset);
    Value upperBound =
        getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
    Value stepValue = getValueOrCreateConstantIndexOp(rewriter, loc, step);
    auto forOp = rewriter.create<scf::ForOp>(
        loc, lowerBound, upperBound, stepValue, iterArgs,
        [](OpBuilder &, Location, Value, ValueRange) {});
    loops.push_back(forOp);
    inductionVars.push_back(forOp.getInductionVar());
    // The next loop, or the tiled body, is nested inside this loop.
    rewriter.setInsertionPointToEnd(forOp.getBody());
    iterArgs = forOp.getRegionIterArgs();
  }

  SmallVector<Value> tiledValues;
  SmallVector<SmallVector<OpFoldResult>> insertOffsets, insertSizes;
  if (failed(yieldTiledValuesFn(rewriter, loc, inductionVars, iterArgs,
                                tiledValues, insertOffsets, insertSizes)))
    return rewriter.notifyMatchFailure(
        loc, "failed to generate inner tile loop body");

  // Without loops there is nothing to yield.
  if (loops.empty())
    return success();

  assert(tiledValues.size() == iterArgs.size() &&
         "Number of results of body should be equal to number of iter args");
  SmallVector<Value> innerYields;
  innerYields.reserve(tiledValues.size());
  for (auto [value, dest, offsets, sizes] :
       llvm::zip_equal(tiledValues, iterArgs, insertOffsets, insertSizes)) {
    SmallVector<OpFoldResult> unitStrides(offsets.size(),
                                          rewriter.getIndexAttr(1));
    auto insertOp = rewriter.create<tensor::InsertSliceOp>(
        loc, value, dest, offsets, sizes, unitStrides);
    innerYields.push_back(insertOp);
  }
  rewriter.create<scf::YieldOp>(loc, innerYields);

  // Every outer loop forwards the results of the loop it contains.
  for (size_t i = 0, e = loops.size(); i + 1 < e; ++i) {
    auto outerFor = cast<scf::ForOp>(loops[i].getOperation());
    rewriter.setInsertionPointToEnd(outerFor.getBody());
    rewriter.create<scf::YieldOp>(outerFor.getLoc(),
                                  loops[i + 1]->getResults());
  }
  return success();
}
/// Builds a single multi-dimensional `scf.forall`. It has one dimension per
/// entry of `loopRanges` whose tile size is not the constant 0, with
/// lower bound `offset`, upper bound `size`, and the tile size as step.
///
/// `destinationTensors` become the forall's shared outputs. `tiledBodyFn` is
/// called just before the `scf.forall.in_parallel` terminator, using the
/// forall's induction variables and region out args. Each resulting tile is
/// written back with a unit-stride `tensor.parallel_insert_slice` inside the
/// terminator.
///
/// If `mappingVector` is non-empty, it is attached as the forall's mapping
/// attribute. The created loop is appended to `loops`. The rewriter's
/// insertion point is restored on return.
static LogicalResult generateLoopNestUsingForallOp(
    RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
    ArrayRef<OpFoldResult> tileSizes, ArrayRef<Attribute> mappingVector,
    ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn,
    SmallVector<LoopLikeOpInterface> &loops) {
  assert(!loopRanges.empty() && "unexpected empty loop ranges");
  assert(loopRanges.size() == tileSizes.size() &&
         "expected as many tile sizes as loop ranges");
  OpBuilder::InsertionGuard guard(rewriter);

  // Only dimensions with a non-zero tile size become forall dimensions.
  SmallVector<OpFoldResult> lbs, ubs, steps;
  for (auto [tileSize, loopRange] : llvm::zip_equal(tileSizes, loopRanges)) {
    if (isConstantIntValue(tileSize, 0))
      continue;
    lbs.push_back(loopRange.offset);
    ubs.push_back(loopRange.size);
    steps.push_back(tileSize);
  }
  assert(!lbs.empty() && "Expected at least one loop range");

  std::optional<ArrayAttr> mappingAttr;
  if (!mappingVector.empty())
    mappingAttr = rewriter.getArrayAttr(mappingVector);

  auto forallOp = rewriter.create<scf::ForallOp>(
      loc, lbs, ubs, steps, destinationTensors, mappingAttr);
  loops.push_back(forallOp);

  // The tiled computation goes right before the in_parallel terminator.
  rewriter.setInsertionPoint(forallOp.getTerminator());
  destinationTensors = forallOp.getRegionOutArgs();

  SmallVector<Value> tiledResults;
  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
  if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
                         destinationTensors, tiledResults, resultOffsets,
                         resultSizes)))
    return rewriter.notifyMatchFailure(loc, "failed to generate loop body");

  // Parallel insertions must live inside the terminator's body.
  rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
       llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
                       resultSizes)) {
    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
                                           rewriter.getIndexAttr(1));
    rewriter.create<tensor::ParallelInsertSliceOp>(
        loc, tiledValue, destinationTensor, resultOffset, resultSize,
        resultStride);
  }
  return success();
}
/// Generates the tiling loop nest of the kind selected by `options.loopType`
/// and fills its body via `tiledBodyFn`.
///
/// When every tile size is zero, no loop is created. In that case
/// `tiledBodyFn` runs once at the current insertion point with no induction
/// variables, its offsets/sizes outputs are discarded, and `loops` is left
/// untouched. An unknown loop type is reported as a match failure.
static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc,
                                      const scf::SCFTilingOptions &options,
                                      ArrayRef<Range> loopRanges,
                                      ArrayRef<OpFoldResult> tileSizes,
                                      ValueRange destinationTensors,
                                      YieldTiledValuesFn tiledBodyFn,
                                      SmallVector<LoopLikeOpInterface> &loops) {
  // Nothing is tiled: emit the body in place, without any loop.
  if (llvm::all_of(tileSizes, isZeroIndex)) {
    SmallVector<Value> ignoredValues;
    SmallVector<SmallVector<OpFoldResult>> ignoredOffsets, ignoredSizes;
    return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
                       ignoredValues, ignoredOffsets, ignoredSizes);
  }

  switch (options.loopType) {
  case scf::SCFTilingOptions::LoopType::ForOp:
    return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
                                      destinationTensors, tiledBodyFn, loops);
  case scf::SCFTilingOptions::LoopType::ForallOp:
    return generateLoopNestUsingForallOp(
        rewriter, loc, loopRanges, tileSizes, options.mappingVector,
        destinationTensors, tiledBodyFn, loops);
  }
  return rewriter.notifyMatchFailure(loc, "unhandled loop type");
}
/// Fallback for loop kinds that have no explicit specialization: reports a
/// match failure on `loopOp` and does not modify the IR.
template <typename LoopType>
FailureOr<LoopLikeOpInterface>
yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter,
                               ValueRange /*newInitOperands*/,
                               YieldTiledValuesFn /*yieldTiledValuesFn*/) {
  return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
}
/// `scf.for` specialization. It rebuilds `loopOp` as a new `scf.for` with the
/// same bounds and step, whose inits are the original inits followed by
/// `newInitOperands`.
///
/// The original body is moved into the new loop, and the original block
/// arguments are mapped onto the leading new block arguments.
/// `yieldTiledValuesFn` is then invoked just before the moved `scf.yield`,
/// using the new induction variable and the region iter args created for
/// `newInitOperands`. Each returned tiled value is inserted into its iter arg
/// with a unit-stride `tensor.insert_slice`, and these are appended after the
/// original yield operands.
///
/// `loopOp` is replaced by the leading results of the new loop, and the new
/// loop is returned.
template <>
FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
YieldTiledValuesFn yieldTiledValuesFn) {
OpBuilder::InsertionGuard g(rewriter);
Location loc = loopOp.getLoc();
rewriter.setInsertionPoint(loopOp);
// New inits: the original ones first, then the additional ones.
auto inits = llvm::to_vector(loopOp.getInitArgs());
inits.append(newInitOperands.begin(), newInitOperands.end());
// The body is filled below by moving in the original block; the code below
// relies on the original scf.yield becoming the new loop's terminator.
auto newLoop = rewriter.create<scf::ForOp>(
loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
inits, [](OpBuilder &, Location, Value, ValueRange) {});
Block *loopBody = loopOp.getBody();
Block *newLoopBody = newLoop.getBody();
// Move the original ops into the new body; the old block args (iv + old iter
// args) are replaced by the first new block args.
rewriter.mergeBlocks(
loopBody, newLoopBody,
newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
rewriter.setInsertionPoint(yieldOp);
SmallVector<Value> tiledValues;
SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
// Iter args that correspond to `newInitOperands`.
ValueRange newRegionIterArgs =
newLoop.getRegionIterArgs().take_back(newInitOperands.size());
if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
newRegionIterArgs, tiledValues, resultOffsets,
resultSizes))) {
// NOTE(review): mergeBlocks above has already moved loopOp's body into
// newLoop, so erasing newLoop here also discards those ops and leaves
// loopOp with an empty region. Confirm whether callers rely on loopOp being
// intact after a failure.
rewriter.eraseOp(newLoop);
return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");
}
// Yield the original values followed by the tiles written into the new iter
// args.
SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
resultSizes)) {
SmallVector<OpFoldResult> resultStride(resultOffset.size(),
rewriter.getIndexAttr(1));
Value insert = rewriter.create<tensor::InsertSliceOp>(
yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
resultStride);
newYieldValues.push_back(insert);
}
rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
// Users of the old loop see only the leading (pre-existing) results.
rewriter.replaceOp(loopOp,
newLoop->getResults().take_front(loopOp.getNumResults()));
return cast<LoopLikeOpInterface>(newLoop.getOperation());
}
/// `scf.forall` specialization. It rebuilds `loopOp` as a new `scf.forall`
/// with the same mixed bounds, steps, and mapping, whose outputs are the
/// original outputs followed by `newInitOperands`.
///
/// The original body is moved into the new loop. `yieldTiledValuesFn` is
/// invoked before the `scf.forall.in_parallel` terminator, using the new
/// induction variables and the region iter args created for
/// `newInitOperands`. Each tile is written back with a unit-stride
/// `tensor.parallel_insert_slice` inside the terminator.
///
/// `loopOp` is replaced by the leading results of the new loop, and the new
/// loop is returned.
template <>
FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
YieldTiledValuesFn yieldTiledValuesFn) {
OpBuilder::InsertionGuard g(rewriter);
Location loc = loopOp.getLoc();
rewriter.setInsertionPoint(loopOp);
// New outputs: the original ones first, then the additional ones.
auto inits = llvm::to_vector(loopOp.getOutputs());
inits.append(newInitOperands.begin(), newInitOperands.end());
auto newLoop = rewriter.create<scf::ForallOp>(
loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
loopOp.getMixedStep(), inits, loopOp.getMapping(),
[](OpBuilder &, Location, ValueRange) {});
Block *loopBody = loopOp.getBody();
Block *newLoopBody = newLoop.getBody();
// Move the original ops (including the in_parallel terminator) into the new
// body, remapping the old block args onto the leading new ones.
rewriter.mergeBlocks(
loopBody, newLoopBody,
newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
rewriter.setInsertionPoint(terminator);
SmallVector<Value> tiledValues;
SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
// Shared-output block args that correspond to `newInitOperands`.
ValueRange regionIterArgs =
newLoop.getRegionIterArgs().take_back(newInitOperands.size());
if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
regionIterArgs, tiledValues, resultOffsets,
resultSizes))) {
// NOTE(review): as in the scf.for specialization, loopOp's body was already
// merged into newLoop, so erasing newLoop discards it too. Confirm this is
// acceptable for callers.
rewriter.eraseOp(newLoop);
return rewriter.notifyMatchFailure(loopOp,
"failed to get yielded tiled values");
}
// Parallel insertions must be placed inside the terminator's body.
rewriter.setInsertionPointToEnd(terminator.getBody());
for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
SmallVector<OpFoldResult> resultStride(resultOffset.size(),
rewriter.getIndexAttr(1));
rewriter.create<tensor::ParallelInsertSliceOp>(
terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
resultStride);
}
// Users of the old loop see only the leading (pre-existing) results.
rewriter.replaceOp(loopOp,
newLoop->getResults().take_front(loopOp.getNumResults()));
return cast<LoopLikeOpInterface>(newLoop.getOperation());
}
/// Dispatches to the `scf.for` or `scf.forall` specialization of
/// `yieldTiledValuesAndReplaceLoop`. Each specialization replaces the loop
/// with one that carries `newInitOperands` as extra inits. Any other loop
/// kind is reported as a match failure.
FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
    LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
    ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
  Operation *loopOperation = loopLikeOp.getOperation();
  if (auto forOp = dyn_cast<scf::ForOp>(loopOperation))
    return yieldTiledValuesAndReplaceLoop(forOp, rewriter, newInitOperands,
                                          yieldTiledValuesFn);
  if (auto forallOp = dyn_cast<scf::ForallOp>(loopOperation))
    return yieldTiledValuesAndReplaceLoop(forallOp, rewriter, newInitOperands,
                                          yieldTiledValuesFn);
  return rewriter.notifyMatchFailure(loopOperation, "unhandled loop type");
}
/// Appends `newInitValues` as extra iter args to every loop of `loops`,
/// ordered outermost first, and updates `loops` in place with the rebuilt
/// loops.
///
/// - Every loop except the innermost is cast to `scf.for`. It is rebuilt
///   with the extra inits, and the next inner loop is seeded with its new
///   region iter args.
/// - The innermost loop is rebuilt through `yieldTiledValuesAndReplaceLoop`,
///   which uses `getNewTiledYieldsFn` to produce the values for the new iter
///   args.
/// - Each outer loop then additionally yields the trailing (new) results of
///   the loop it contains.
///
/// An empty `loops` is a no-op. The rewriter's insertion point is restored
/// on return.
static LogicalResult addInitOperandsToLoopNest(
    RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops,
    ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
  if (loops.empty())
    return success();
  OpBuilder::InsertionGuard g(rewriter);

  // Rebuild the outer loops with the additional inits.
  for (auto &loop : loops.drop_back()) {
    rewriter.setInsertionPoint(loop);
    auto forLoop = cast<scf::ForOp>(loop.getOperation());
    SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs());
    newInits.append(newInitValues.begin(), newInitValues.end());
    auto newLoop = rewriter.create<scf::ForOp>(
        forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
        forLoop.getStep(), newInits,
        [](OpBuilder &, Location, Value, ValueRange) {});

    // Map the old block arguments (iv + original iter args) onto the new ones.
    SmallVector<Value> sourceBlockArgs;
    sourceBlockArgs.push_back(newLoop.getInductionVar());
    auto newRegionIterArgs = newLoop.getRegionIterArgs();
    sourceBlockArgs.append(
        newRegionIterArgs.begin(),
        std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
    rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
    rewriter.replaceOp(
        forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
    loop = newLoop;
    // The next (inner) loop is initialized from this loop's new iter args.
    newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
  }

  // Rebuild the innermost loop; this also generates the new yielded values.
  LoopLikeOpInterface innerMostLoop = loops.back();
  FailureOr<LoopLikeOpInterface> newInnerMostLoop =
      yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues,
                                     getNewTiledYieldsFn);
  if (failed(newInnerMostLoop))
    return innerMostLoop.emitOpError("failed to return additional yields");
  loops.back() = newInnerMostLoop.value();

  // Thread the additional results of each inner loop out of its parent.
  for (auto [outerLoop, innerLoop] :
       llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
    auto outerForLoop = cast<scf::ForOp>(outerLoop.getOperation());
    auto outerLoopYield =
        cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
    SmallVector<Value> newYields =
        llvm::to_vector(outerLoopYield.getOperands());
    ValueRange additionalYields =
        innerLoop->getResults().take_back(newInitValues.size());
    newYields.append(additionalYields.begin(), additionalYields.end());
    rewriter.setInsertionPoint(outerLoopYield);
    rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
  }
  return success();
}
/// Tiles `op` using the loop kind, tile sizes, and optional interchange in
/// `options`.
///
/// Tile sizes:
/// - Missing trailing tile sizes are treated as 0 (untiled).
/// - More tile sizes than loops is reported as a match failure instead of
///   tripping the size assertions in the loop generators.
///
/// A tile of the cloned op is generated inside the innermost loop, and its
/// results are inserted into the loop-carried destinations. If every tile
/// size is zero, the op is simply cloned and no loop is created.
///
/// Returns the tiled ops, the generated loops (outermost first), and the
/// values that replace `op`'s results: the outermost loop's results, or the
/// clone's results when there is no loop.
FailureOr<scf::SCFTilingResult>
mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
                        const scf::SCFTilingOptions &options) {
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointAfter(op);

  if (!options.tileSizeComputationFunction) {
    return rewriter.notifyMatchFailure(
        op, "missing tile size computation function");
  }

  // 1. Get the iteration domain of the operation.
  SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
  size_t numLoops = iterationDomain.size();

  // 2. Materialize one tile size per loop; trailing loops default to untiled.
  SmallVector<OpFoldResult> tileSizes =
      options.tileSizeComputationFunction(rewriter, op);
  if (tileSizes.size() > numLoops) {
    return rewriter.notifyMatchFailure(
        op, "too many tile sizes provided, expected at most one per loop");
  }
  if (tileSizes.size() < numLoops) {
    auto zero = rewriter.getIndexAttr(0);
    tileSizes.append(numLoops - tileSizes.size(), zero);
  }

  // 3. Optionally permute the loops.
  SmallVector<int64_t> interchangeVector;
  if (!options.interchangeVector.empty()) {
    interchangeVector = fillInterchangeVector(options.interchangeVector,
                                              iterationDomain.size());
  }
  if (!interchangeVector.empty()) {
    if (!isPermutationVector(interchangeVector)) {
      return rewriter.notifyMatchFailure(
          op, "invalid intechange vector, not a permutation of the entire "
              "iteration space");
    }
    applyPermutationToVector(iterationDomain, interchangeVector);
    applyPermutationToVector(tileSizes, interchangeVector);
  }

  // 4. Body generator: tile a clone of `op` at the current loop position.
  FailureOr<TilingResult> tilingResult;
  YieldTiledValuesFn innerYieldTiledValuesFn =
      [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
          ValueRange regionIterArgs, SmallVector<Value> &tiledResults,
          SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
          SmallVector<SmallVector<OpFoldResult>> &resultSizes)
      -> LogicalResult {
    // 4a. Offsets/sizes of the tile in the (possibly permuted) iteration
    // space. Untiled dimensions span the whole range.
    SmallVector<OpFoldResult> offsets, sizes;
    {
      int materializedLoopNum = 0;
      for (auto [tileSize, loopRange] :
           llvm::zip_equal(tileSizes, iterationDomain)) {
        if (isConstantIntValue(tileSize, 0)) {
          offsets.push_back(loopRange.offset);
          sizes.push_back(loopRange.size);
          continue;
        }
        Value iv = ivs[materializedLoopNum++];
        offsets.push_back(iv);
        sizes.push_back(
            getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
      }
    }

    // 4b. Undo the interchange so offsets/sizes match the op's own order.
    if (!interchangeVector.empty()) {
      auto inversePermutation = invertPermutationVector(interchangeVector);
      applyPermutationToVector(offsets, inversePermutation);
      applyPermutationToVector(sizes, inversePermutation);
    }

    // 4c. Clone the op so that it writes into the loop-carried destinations.
    auto clonedOp = cast<TilingInterface>(
        cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));

    // Nothing is tiled: the clone itself is the result.
    if (llvm::all_of(tileSizes, isZeroIndex)) {
      tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
      tilingResult = TilingResult{{clonedOp}, clonedOp->getResults()};
      return success();
    }

    // 4d. Tile the clone, then drop it.
    tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes);
    if (failed(tilingResult)) {
      rewriter.eraseOp(clonedOp);
      return op.emitOpError("faild to tile operation");
    }
    rewriter.eraseOp(clonedOp);

    // 4e. Compute where each tiled result goes in its destination.
    for (auto [index, tiledValue] :
         llvm::enumerate(tilingResult->tiledValues)) {
      tiledResults.push_back(tiledValue);
      SmallVector<OpFoldResult> resultOffset, resultSize;
      if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
                                          resultOffset, resultSize))) {
        for (Operation *tiledOp : tilingResult->tiledOps)
          rewriter.eraseOp(tiledOp);
        return rewriter.notifyMatchFailure(
            op, "failed to get slice of result produced");
      }
      resultOffsets.emplace_back(std::move(resultOffset));
      resultSizes.emplace_back(std::move(resultSize));
    }
    return success();
  };

  // 5. Find the destination tensors to use for the operation.
  SmallVector<Value> destinationTensors;
  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
                                             destinationTensors))) {
    return rewriter.notifyMatchFailure(op,
                                       "unable to create destination tensors");
  }

  // 6. Generate the tiled loops nest using the callback defined above.
  SmallVector<LoopLikeOpInterface> loops;
  if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
                              tileSizes, destinationTensors,
                              innerYieldTiledValuesFn, loops)))
    return op.emitOpError("failed to generate tiling loops");
  assert(succeeded(tilingResult) &&
         "expected tiling result to be computed after loop generation");

  // Without loops the tiled values directly replace the op's results.
  if (loops.empty()) {
    return scf::SCFTilingResult{tilingResult->tiledOps, loops,
                                tilingResult->tiledValues};
  }

  SmallVector<Value> replacements = llvm::map_to_vector(
      loops.front()->getResults(), [](OpResult r) -> Value { return r; });
  return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements};
}
/// Tiles the reduction `op` with the partial-reduction strategy.
///
/// 1. Partial-result init tensors are created by
///    `generateInitialTensorForPartialReduction`.
/// 2. An `scf.for` nest is built over the non-zero `tileSizes`. Inside it,
///    each iteration tiles a clone of `op` with `tileToPartialReduction`,
///    accumulating into the loop-carried partial results.
/// 3. After the nest, the partial results are combined with
///    `mergeReductions`, and `op` is replaced by the merged values.
///
/// At least one non-zero tile size is required, because the results are
/// read from the outermost generated loop. More tile sizes than loops is
/// rejected.
FailureOr<scf::SCFReductionTilingResult>
mlir::scf::tileReductionUsingScf(RewriterBase &b,
                                 PartialReductionOpInterface op,
                                 ArrayRef<OpFoldResult> tileSizes) {
  Location loc = op.getLoc();
  // Ops implementing PartialReductionOpInterface are assumed to implement
  // TilingInterface as well (the cast asserts otherwise).
  auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
  SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);

  // Normalize to one tile size per loop, padding with 0 (untiled).
  auto tileSizesVector = llvm::to_vector(tileSizes);
  if (tileSizesVector.size() > iterationDomain.size()) {
    return b.notifyMatchFailure(
        op, "too many tile sizes provided, expected at most one per loop");
  }
  if (tileSizesVector.size() < iterationDomain.size()) {
    auto zero = b.getIndexAttr(0);
    tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
                           zero);
  }
  // With all-zero tile sizes no loop would be generated, and the results
  // below are taken from the outermost loop.
  if (llvm::all_of(tileSizesVector, isZeroIndex)) {
    return b.notifyMatchFailure(
        op, "expected at least one non-zero tile size for reduction tiling");
  }

  // Collect the reduction dimensions of the iteration space.
  SmallVector<int> reductionDims;
  for (auto [idx, iteratorType] :
       llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
    if (iteratorType == utils::IteratorType::reduction)
      reductionDims.push_back(idx);
  }

  // Create the partial-result tensors accumulated by the loop nest.
  FailureOr<SmallVector<Value>> maybeInitTensors =
      op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
                                                  reductionDims);
  if (failed(maybeInitTensors)) {
    return b.notifyMatchFailure(op, "Failed to create initial tensors.");
  }
  SmallVector<Value> &initTensors = maybeInitTensors.value();

  SmallVector<Operation *> parallelTiledOps;
  auto innerYieldTiledValuesFn =
      [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
          ValueRange regionIterArgs, SmallVector<Value> &tiledResult,
          SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
          SmallVector<SmallVector<OpFoldResult>> &resultSizes)
      -> LogicalResult {
    // Offsets/sizes of the current tile; untiled dims span the full range.
    SmallVector<OpFoldResult> offsets, sizes;
    {
      int materializedLoopNum = 0;
      for (auto [tileSize, loopRange] :
           llvm::zip_equal(tileSizesVector, iterationDomain)) {
        if (isConstantIntValue(tileSize, 0)) {
          offsets.push_back(loopRange.offset);
          sizes.push_back(loopRange.size);
          continue;
        }
        Value iv = ivs[materializedLoopNum++];
        offsets.push_back(iv);
        sizes.push_back(
            getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
      }
    }

    // Tile a clone of the op into a partial reduction over this tile. The
    // clone is only a template for tiling, so it is erased on every path.
    {
      auto clonedOp = cast<PartialReductionOpInterface>(
          cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
      FailureOr<TilingResult> partialTilingResult =
          clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets,
                                          sizes, reductionDims);
      if (failed(partialTilingResult)) {
        b.eraseOp(clonedOp);
        return failure();
      }
      std::swap(parallelTiledOps, partialTilingResult->tiledOps);
      std::swap(tiledResult, partialTilingResult->tiledValues);
      b.eraseOp(clonedOp);
    }

    // Each partial result overwrites its whole accumulator. Offsets and sizes
    // must therefore have the rank of the result, which is not necessarily
    // the rank of the iteration domain.
    for (Value result : tiledResult) {
      int64_t rank = cast<RankedTensorType>(result.getType()).getRank();
      SmallVector<OpFoldResult> outOffsets(rank, b.getIndexAttr(0));
      resultOffsets.emplace_back(std::move(outOffsets));
      SmallVector<OpFoldResult> outSizes;
      for (int64_t i = 0; i < rank; ++i)
        outSizes.push_back(tensor::getMixedSize(b, loc, result, i));
      resultSizes.emplace_back(std::move(outSizes));
    }
    return success();
  };

  SmallVector<LoopLikeOpInterface> loops;
  scf::SCFTilingOptions options;
  options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
  if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
                              initTensors, innerYieldTiledValuesFn, loops)))
    return b.notifyMatchFailure(op, "failed to tile for parallel reduction");
  assert(!loops.empty() && "expected at least one generated loop");

  // Merge the partial results after the outermost loop.
  SmallVector<Value> replacements = llvm::map_to_vector(
      loops.front()->getResults(), [](OpResult r) -> Value { return r; });
  b.setInsertionPointAfter(*loops.begin());
  FailureOr<MergeResult> mergeResult =
      op.mergeReductions(b, loc, replacements, reductionDims);
  if (failed(mergeResult)) {
    return failure();
  }
  b.replaceOp(op, mergeResult->replacements);

  SCFReductionTilingResult reductionTilingResult;
  std::swap(reductionTilingResult.parallelTiledOps, parallelTiledOps);
  std::swap(reductionTilingResult.mergeOps, mergeResult->mergeOps);
  std::swap(reductionTilingResult.initialValues, initTensors);
  std::swap(reductionTilingResult.loops, loops);
  std::swap(reductionTilingResult.replacements, mergeResult->replacements);
  return reductionTilingResult;
}
/// Traces `source` outward through the region iter args of `loops`, where
/// `loops` is ordered outermost first, to find the value that feeds it from
/// outside the nest.
///
/// Returns a tuple:
/// - The reached value as an OpResult. This is null if it is not an op
///   result.
/// - If every loop in `loops` was traversed, the init operand of the
///   outermost loop that was reached (the destination iter arg). Otherwise
///   std::nullopt.
///
/// The walk stops once all loops are visited. Without this check, an
/// outermost init that is itself a block argument (e.g. a function argument)
/// would dereference `loops.rend()`. The walk also stops at block arguments
/// that have no tied init.
static std::tuple<OpResult, std::optional<OpOperand *>>
getUntiledProducerFromSliceSource(OpOperand *source,
                                  ArrayRef<LoopLikeOpInterface> loops) {
  std::optional<OpOperand *> destinationIterArg;
  auto loopIt = loops.rbegin();
  while (loopIt != loops.rend()) {
    auto iterArg = dyn_cast<BlockArgument>(source->get());
    if (!iterArg)
      break;
    LoopLikeOpInterface loop = *loopIt;
    if (iterArg.getOwner()->getParentOp() != loop)
      break;
    OpOperand *tiedInit = loop.getTiedLoopInit(iterArg);
    if (!tiedInit)
      break;
    source = tiedInit;
    ++loopIt;
  }
  if (loopIt == loops.rend())
    destinationIterArg = source;
  return {dyn_cast<OpResult>(source->get()), destinationIterArg};
}
/// Fuses the producer of `candidateSliceOp`'s source into the loop nest
/// `loops`. The producer is found by looking through the loops' region iter
/// args.
///
/// 1. The producer is cloned next to the slice. If the producer is a DPS op
///    that also provides the nest's destination, the clone writes into the
///    slice's source instead.
/// 2. A tile of the clone that computes the slice is generated, and all uses
///    of the slice are replaced with it.
/// 3. The temporary clones are erased.
///
/// In the destination case, the outermost loop's init is redirected to a
/// fresh destination of the original producer.
///
/// Returns std::nullopt, leaving no temporaries behind, when there is no
/// fusable producer or tiling fails.
std::optional<scf::SCFFuseProducerOfSliceResult>
mlir::scf::tileAndFuseProducerOfSlice(
    RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
    MutableArrayRef<LoopLikeOpInterface> loops) {
  // 1. Find the untiled producer, looking through the loops' iter args.
  auto [fusableProducer, destinationInitArg] =
      getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
                                        loops);
  if (!fusableProducer)
    return std::nullopt;
  unsigned resultNumber = fusableProducer.getResultNumber();

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(candidateSliceOp);

  // 2. Clone the producer next to the slice.
  SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
  Operation *fusableProducerOp = fusableProducer.getOwner();
  if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
      failed(tensor::getOrCreateDestinations(
          rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
          origDestinationTensors)))
    return std::nullopt;

  clonedOpDestinationTensors = origDestinationTensors;
  if (destinationInitArg &&
      isa<DestinationStyleOpInterface>(fusableProducerOp)) {
    // The producer also seeds the loop-carried destination. Make the clone
    // write into the slice's source (the iter arg) instead.
    clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
  }
  Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
      rewriter, fusableProducerOp, clonedOpDestinationTensors);

  // 3. Clone the slice so that it reads from the cloned producer.
  SmallVector<Value> candidateSliceOpOperands =
      llvm::to_vector(candidateSliceOp->getOperands());
  candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
  tensor::ExtractSliceOp clonedCandidateSliceOp =
      mlir::clone(rewriter, candidateSliceOp,
                  candidateSliceOp->getResultTypes(), candidateSliceOpOperands);

  // 4. Generate the tiled producer for the slice.
  FailureOr<TilingResult> tileAndFuseResult =
      tensor::replaceExtractSliceWithTiledProducer(
          rewriter, clonedCandidateSliceOp,
          clonedProducerOp->getResult(resultNumber));
  if (failed(tileAndFuseResult)) {
    // Do not leave the temporary clones in the IR. Erase the slice first,
    // since it uses the cloned producer's result.
    rewriter.eraseOp(clonedCandidateSliceOp);
    rewriter.eraseOp(clonedProducerOp);
    return std::nullopt;
  }
  rewriter.replaceAllUsesWith(candidateSliceOp,
                              tileAndFuseResult->tiledValues[0]);
  rewriter.eraseOp(clonedCandidateSliceOp);
  rewriter.eraseOp(clonedProducerOp);

  // 5. The loop-carried destination now comes from the fused producer.
  // Initialize the outermost loop with a fresh destination instead. Route
  // the change through the rewriter so listeners are notified.
  if (destinationInitArg &&
      isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
    Operation *outermostLoop = loops.front().getOperation();
    unsigned initOperandNumber = destinationInitArg.value()->getOperandNumber();
    Value freshDestination = origDestinationTensors[resultNumber];
    rewriter.modifyOpInPlace(outermostLoop, [&]() {
      outermostLoop->getOpOperand(initOperandNumber).set(freshDestination);
    });
  }
  return scf::SCFFuseProducerOfSliceResult{fusableProducer,
                                           tileAndFuseResult->tiledValues[0],
                                           tileAndFuseResult->tiledOps};
}
/// Makes values that can replace the untiled producer's results available
/// outside `loops`. It does so by adding one new iter arg per selected result
/// to the whole loop nest.
///
/// - The results used are those listed in `yieldResultNumber`, or all results
///   when it is empty.
/// - Each new iter arg is initialized with a destination created for the
///   corresponding original producer result.
/// - Inside the innermost loop, the tiled producer's result is inserted into
///   the new iter arg at the slice's offsets/sizes (for the sliced result) or
///   at the positions derived via the iteration-domain tile (for other
///   results).
/// - If the tiled producer is destination-style, its inits are rewired to
///   extract_slices of the new iter args.
///
/// Fails if a destination cannot be created, if the slice has non-unit
/// strides, or if a tile position cannot be computed. An empty `loops` is a
/// no-op.
LogicalResult mlir::scf::yieldReplacementForFusedProducer(
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
MutableArrayRef<LoopLikeOpInterface> loops,
ArrayRef<unsigned> yieldResultNumber) {
if (loops.empty())
return success();
Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
*tiledOwner = fusedProducerInfo.tiledOps[0];
Location loc = originalOwner->getLoc();
// Result numbers to yield: the requested subset, or all of them.
SmallVector<unsigned> initNumberList =
yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
0, originalOwner->getNumResults()))
: llvm::to_vector(yieldResultNumber);
// One fresh destination per yielded result, used as the new loop inits.
SmallVector<Value> initValueList;
for (const auto &resultNumber : initNumberList) {
FailureOr<Value> initValue = tensor::getOrCreateDestination(
rewriter, loc, originalOwner->getResult(resultNumber));
if (succeeded(initValue)) {
initValueList.push_back(initValue.value());
} else {
return failure();
}
}
// Body callback run inside the rebuilt innermost loop. `innerRewriter` is
// the same object as `rewriter` (it is forwarded through
// addInitOperandsToLoopNest), so the guard below also undoes the insertion
// point changes made through `rewriter`.
YieldTiledValuesFn newYieldValuesFn =
[&](RewriterBase &innerRewriter, Location loc, ValueRange ,
ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
OpBuilder::InsertionGuard g(innerRewriter);
SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
sliceSizes = sliceOp.getMixedSizes();
// Only unit-stride slices are supported.
if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
return !isConstantIntValue(ofr, 1);
}))
return failure();
unsigned sliceResultNumber =
fusedProducerInfo.origProducer.getResultNumber();
auto tilableOp = cast<TilingInterface>(originalOwner);
// For multi-result producers, recover the iteration-domain tile from the
// sliced result so that positions of the other results can be derived.
SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
if (tilableOp->getNumResults() > 1 &&
failed(tilableOp.getIterationDomainTileFromResultTile(
rewriter, sliceResultNumber, sliceOffset, sliceSizes,
iterDomainOffset, iterDomainSizes))) {
return failure();
}
// Offsets/sizes at which each yielded result is inserted.
SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
for (const auto &resultNumber : initNumberList) {
if (resultNumber == sliceResultNumber) {
offsetList.push_back(sliceOffset);
sizesList.push_back(sliceSizes);
} else {
assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
SmallVector<OpFoldResult> offset, sizes;
if (failed(tilableOp.getResultTilePosition(
rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
offset, sizes))) {
return failure();
}
offsetList.push_back(offset);
sizesList.push_back(sizes);
}
}
// Make a destination-style tiled producer write into slices of the new
// iter args instead of its previous inits.
if (auto tiledDestStyleOp =
dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
rewriter.setInsertionPoint(tiledDestStyleOp);
for (const auto &&[index, newRegionArg] :
llvm::enumerate(newRegionIterArgs)) {
auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
loc, newRegionArg, offsetList[index], sizesList[index],
SmallVector<OpFoldResult>(offsetList[index].size(),
rewriter.getIndexAttr(1)));
unsigned resultNumber = initNumberList[index];
rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
});
}
}
// Move to just before the terminator of the block that currently holds the
// insertion point, then report the tiled values and their positions.
Block *block = rewriter.getInsertionPoint()->getBlock();
rewriter.setInsertionPoint(block->getTerminator());
for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
tiledResult.push_back(tiledOwner->getResult(resultNumber));
tiledOffset.emplace_back(offsetList[index]);
tiledSizes.emplace_back(sizesList[index]);
}
return success();
};
return addInitOperandsToLoopNest(rewriter, loops, initValueList,
newYieldValuesFn);
}
/// Tiles `consumer` with `options.tilingOptions`, then greedily fuses
/// producers of the `tensor.extract_slice` operands of the tiled ops into the
/// loop nest. The fusion is breadth-first over newly fused ops.
///
/// `options.fusionControlFn` decides, for each candidate:
/// - whether to fuse it, and
/// - whether to also yield a replacement for the fused producer's results
///   from the loop nest.
///
/// Returns:
/// - the fused original producers,
/// - the tiled and fused ops,
/// - the loops (outermost first), and
/// - a map from each original value (consumer results, and the results of
///   producers whose replacements were yielded) to its replacement.
FailureOr<scf::SCFTileAndFuseResult>
mlir::scf::tileConsumerAndFuseProducersUsingSCF(
    RewriterBase &rewriter, TilingInterface consumer,
    const scf::SCFTileAndFuseOptions &options) {
  if (!consumer->getNumResults()) {
    return rewriter.notifyMatchFailure(
        consumer, "invalid pattern for op with no results");
  }

  // 1. Tile the consumer.
  SetVector<Operation *> fusedProducers, tiledAndFusedOps;
  FailureOr<scf::SCFTilingResult> tilingResult =
      tileUsingSCF(rewriter, consumer, options.tilingOptions);
  if (failed(tilingResult))
    return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
  for (auto *tiledOp : tilingResult->tiledOps)
    tiledAndFusedOps.insert(tiledOp);

  // If there are no loops generated, fusion is immaterial.
  auto &loops = tilingResult->loops;
  if (loops.empty()) {
    DenseMap<Value, Value> replacements;
    for (auto [origVal, replacement] :
         llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
      replacements[origVal] = replacement;
    }
    return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
                                     replacements};
  }

  // Maps each original value to the result number of the outermost loop
  // that replaces it.
  DenseMap<Value, size_t> origValToResultNumber;
  for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
    origValToResultNumber[result] = index;
  }

  // 2. Typically, the operands of the tiled operation are slices of the
  // operands of the untiled operation. These are expressed in IR using
  // `tensor.extract_slice` operations with source being the operands of the
  // untiled operation. Collect them as fusion candidates.
  auto addCandidateSlices = [](Operation *fusedOp,
                               std::deque<tensor::ExtractSliceOp> &candidates) {
    for (Value operand : fusedOp->getOperands())
      if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
        candidates.push_back(sliceOp);
  };

  std::deque<tensor::ExtractSliceOp> candidates;
  addCandidateSlices(tiledAndFusedOps.back(), candidates);
  OpBuilder::InsertionGuard g(rewriter);
  while (!candidates.empty()) {
    // Traverse the slices in BFS fashion.
    tensor::ExtractSliceOp candidateSliceOp = candidates.front();
    candidates.pop_front();

    // Find the original producer of the slice.
    auto [fusableProducer, destinationInitArg] =
        getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
                                          loops);
    if (!fusableProducer)
      continue;

    auto [fuseSlice, yieldReplacement] = options.fusionControlFn(
        candidateSliceOp, fusableProducer, destinationInitArg.has_value());
    if (!fuseSlice)
      continue;

    // The operands of the fused producer may themselves be slices of other
    // producers; fuse those transitively.
    std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
        tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, loops);
    if (!fusedResult)
      continue;

    if (yieldReplacement) {
      Operation *fusableProducerOp = fusableProducer.getOwner();
      if (failed(yieldReplacementForFusedProducer(
              rewriter, candidateSliceOp, fusedResult.value(), loops))) {
        return rewriter.notifyMatchFailure(
            fusableProducerOp, "failed to replacement value for this "
                               "operation from within the tiled loop");
      }
      // All producer results were appended as the trailing loop results.
      for (auto [index, result] :
           llvm::enumerate(fusableProducerOp->getResults())) {
        origValToResultNumber[result] = loops.front()->getNumResults() -
                                        fusableProducerOp->getNumResults() +
                                        index;
      }
    }

    if (Operation *tiledAndFusedOp =
            fusedResult->tiledAndFusedProducer.getDefiningOp()) {
      fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
      tiledAndFusedOps.insert(tiledAndFusedOp);
      addCandidateSlices(tiledAndFusedOp, candidates);
    }
  }

  DenseMap<Value, Value> replacements;
  for (auto [origVal, resultNumber] : origValToResultNumber) {
    replacements[origVal] = loops.front()->getResult(resultNumber);
  }

  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
                                   replacements};
}
/// Checks the preconditions for fusing the consumer of an `scf.for` result
/// through `candidateSliceOp`:
///   - the slice result has exactly one use,
///   - that use is an `scf.yield`,
///   - the yield is in the same block as the slice op.
/// On failure, a debug message describing the violated condition is emitted.
static LogicalResult
checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
  Value result = candidateSliceOp.getResult();
  Value::use_range uses = result.getUses();
  if (!llvm::hasSingleElement(uses)) {
    // Covers both "no uses" and "more than one use".
    LLVM_DEBUG(llvm::dbgs()
               << "Expected the candidate slice op to have a single use\n");
    return failure();
  }
  OpOperand &operandUse = (*uses.begin());
  Operation *userOp = operandUse.getOwner();
  if (!isa<scf::YieldOp>(userOp)) {
    // Printing an op emits no trailing newline, so terminate the line here.
    LLVM_DEBUG(llvm::dbgs()
               << "Expected scf.yield to be the only user, but got -> "
               << (*userOp) << "\n");
    return failure();
  }
  if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
    LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
                               "be in the same block\n");
    return failure();
  }
  return success();
}
/// Returns the only use of `val` when its owner implements both
/// `TilingInterface` and `DestinationStyleOpInterface` and lives in
/// `containingOpBlock`; returns failure otherwise (including when `val` has
/// zero or several uses).
static FailureOr<OpOperand *> getConsumerFromUses(Value val,
                                                  Block *containingOpBlock) {
  Value::use_range allUses = val.getUses();
  if (!llvm::hasSingleElement(allUses))
    return failure();

  OpOperand &onlyUse = *allUses.begin();
  Operation *user = onlyUse.getOwner();
  bool isTileableDpsOp =
      isa<TilingInterface>(user) && isa<DestinationStyleOpInterface>(user);
  if (!isTileableDpsOp || user->getBlock() != containingOpBlock)
    return failure();
  return &onlyUse;
}
/// For a `tensor.insert_slice` whose single user is the `scf.yield` of an
/// `scf.for`, returns the operand of the single tileable DPS consumer of the
/// loop result that the slice is yielded into. Fails if the preconditions of
/// `checkAssumptionForFusingConsumer` do not hold, the parent is not an
/// `scf.for`, or the loop result has no suitable consumer.
static FailureOr<OpOperand *>
getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
    return failure();

  Operation *containingOp = candidateSliceOp->getParentOp();
  auto forOp = dyn_cast<scf::ForOp>(containingOp);
  if (!forOp)
    return failure();

  // The single use (checked above) is an scf.yield operand; its position
  // selects the matching loop result.
  OpOperand &yieldUse = *candidateSliceOp.getResult().getUses().begin();
  Value loopResult = forOp->getResult(yieldUse.getOperandNumber());
  return getConsumerFromUses(loopResult, containingOp->getBlock());
}
/// For a `tensor.parallel_insert_slice` nested in an `scf.forall`, returns the
/// operand of the single tileable DPS consumer of the forall result tied to
/// the slice's destination block argument. Fails if the destination is not a
/// block argument, its owning op is not the slice's grandparent op, that op is
/// not an `scf.forall`, or the result has no suitable consumer.
static FailureOr<OpOperand *>
getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
  auto destArg = dyn_cast<BlockArgument>(candidateSliceOp.getDest());
  if (!destArg)
    return failure();

  // The block argument must belong to the op two levels above the slice.
  Operation *loopOp = destArg.getOwner()->getParentOp();
  Operation *sliceGrandParent = candidateSliceOp->getParentOp()->getParentOp();
  if (loopOp != sliceGrandParent)
    return failure();

  auto forallOp = dyn_cast<scf::ForallOp>(loopOp);
  if (!forallOp)
    return failure();

  OpOperand *tiedOperand = forallOp.getTiedOpOperand(destArg);
  Value loopResult = forallOp.getTiedOpResult(tiedOperand);
  return getConsumerFromUses(loopResult, loopOp->getBlock());
}
/// Succeeds when `loopOp` has exactly one result, or when every user of its
/// results other than `consumerOp` is in the same block as `consumerOp` and
/// is positioned after it.
static LogicalResult checkAssumptionForLoop(Operation *loopOp,
                                            Operation *consumerOp) {
  if (loopOp->getNumResults() == 1)
    return success();

  Block *consumerBlock = consumerOp->getBlock();
  bool othersComeLater =
      llvm::all_of(loopOp->getUsers(), [&](Operation *user) {
        if (user == consumerOp)
          return true;
        return user->getBlock() == consumerBlock &&
               consumerOp->isBeforeInBlock(user);
      });
  return success(othersComeLater);
}
/// Dispatches to the `tensor.insert_slice` or `tensor.parallel_insert_slice`
/// overload; any other kind of op yields failure.
static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp))
    return getUntiledConsumerFromSlice(insertSlice);

  if (auto parallelInsertSlice =
          dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp))
    return getUntiledConsumerFromSlice(parallelInsertSlice);

  return failure();
}
/// Replaces the `scf.yield` terminator of `newForOp`. The new terminator
/// yields the values already yielded, followed by one `tensor.insert_slice`
/// per result of the tiled consumer (`tilingResult.tiledOps[0]`). Each slice
/// inserts that result into the matching block argument from `bbArgs` at the
/// given `resultOffsets`/`resultSizes`, with unit strides. The old terminator
/// is erased.
static void
fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
                      TilingResult &tilingResult,
                      ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
                      ArrayRef<SmallVector<OpFoldResult>> resultSizes,
                      ArrayRef<BlockArgument> bbArgs) {
  auto oldTerminatorOp =
      cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
  Operation *tiledConsumerOp = tilingResult.tiledOps[0];

  // Values yielded by `scf.yield` are its operands (exposed through the
  // `getResults()` accessor). The op itself has no results, so
  // `Operation::getNumResults()` would be 0 here and under-reserve.
  SmallVector<Value> newYieldOperands;
  newYieldOperands.reserve(oldTerminatorOp->getNumOperands() +
                           tiledConsumerOp->getNumResults());
  llvm::append_range(newYieldOperands, oldTerminatorOp.getResults());

  rewriter.setInsertionPointAfter(oldTerminatorOp);
  Location loc = newForOp.getLoc();
  for (auto [tiledResult, bbArg, resultOffset, resultSize] :
       llvm::zip_equal(tiledConsumerOp->getResults(), bbArgs, resultOffsets,
                       resultSizes)) {
    SmallVector<OpFoldResult> strides(resultOffset.size(),
                                      rewriter.getIndexAttr(1));
    Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
        loc, tiledResult, bbArg, resultOffset, resultSize, strides);
    newYieldOperands.push_back(newInsertSliceOp);
  }
  rewriter.create<scf::YieldOp>(loc, newYieldOperands);
  rewriter.eraseOp(oldTerminatorOp);
}
static void
fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
SmallVector<Value> tiledResults,
ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
ArrayRef<BlockArgument> bbArgs) {
scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
Location firstYieldOpLoc =
(*(newTerminatorOp.getYieldingOps().begin())).getLoc();
for (auto [tiledResult, bbArg, resultOffset, resultSize] :
llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
SmallVector<OpFoldResult> strides(resultOffset.size(),
rewriter.getIndexAttr(1));
rewriter.create<tensor::ParallelInsertSliceOp>(
firstYieldOpLoc, tiledResult, bbArg, resultOffset, resultSize, strides);
}
}
/// Fuses into the containing loop the consumer of the loop result fed by
/// `candidateSliceOp`. The slice op is either a `tensor.insert_slice` yielded
/// from an `scf.for`, or a `tensor.parallel_insert_slice` inside an
/// `scf.forall`.
///
/// The rewrite proceeds as follows:
///   1. Create a new loop whose inits are the old loop's inits followed by the
///      consumer's DPS inits.
///   2. Move the old loop body into the new loop.
///   3. Tile a clone of the consumer at the inserted tile.
///   4. Insert the tiled consumer's results into the new loop's trailing
///      block arguments through the terminator.
///   5. Redirect uses of the old loop's results and of the consumer's results
///      to the new loop, then erase the old loop.
///
/// Preconditions that can be checked without modifying the IR, including the
/// unit-stride requirement, are checked before anything is created or moved.
/// A match failure on those paths therefore leaves the IR untouched.
///
/// Returns the original consumer operand, the corresponding operand of the
/// tiled consumer, and the tiled ops.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
                                      Operation *candidateSliceOp) {
  if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
          candidateSliceOp))
    return failure();
  bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);

  // 1. Find the consumer operand that uses the loop result fed by the slice.
  FailureOr<OpOperand *> maybeConsumerOpOperand =
      getUntiledConsumerFromSlice(candidateSliceOp);
  if (failed(maybeConsumerOpOperand)) {
    return rewriter.notifyMatchFailure(candidateSliceOp,
                                       "could not fetch consumer to fuse");
  }
  OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
  Operation *consumerOp = consumerOpOperand->getOwner();
  unsigned operandNumber = consumerOpOperand->getOperandNumber();
  unsigned resultNumber = 0;
  if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
    resultNumber = producerResult.getResultNumber();
  } else {
    return rewriter.notifyMatchFailure(
        consumerOp, "consumer op's operand doesn't seem to be an OpResult");
  }

  // 2. Collect the containing loop, its inits and its body. `rank` is the
  // number of leading (induction variable) block arguments before the inits:
  // 1 for scf.for, the forall rank for scf.forall.
  Operation *oldLoopOp = nullptr;
  SmallVector<Value> newOuts;
  Block *oldLoopBody = nullptr;
  unsigned initSize = 0;
  unsigned rank = 1;
  if (isInsertSliceOp) {
    auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
    oldLoopOp = forOp;
    llvm::append_range(newOuts, forOp.getInits());
    oldLoopBody = forOp.getBody();
    initSize = forOp.getInits().size();
  } else {
    auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
    oldLoopOp = forallOp;
    llvm::append_range(newOuts, forallOp.getOutputs());
    oldLoopBody = forallOp.getBody();
    initSize = forallOp.getOutputs().size();
    rank = forallOp.getRank();
  }

  if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
    return rewriter.notifyMatchFailure(
        oldLoopOp, "containing loop op should either yield just one value or "
                   "have the consumer op as its first user");
  }

  OpBuilder::InsertionGuard g(rewriter);

  // 3. The consumer's DPS inits become extra inits of the new loop.
  auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
  SmallVector<Value> dpsInits =
      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
  if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
    return rewriter.notifyMatchFailure(
        consumerOp,
        "consumer op taking the result of scf.for as init is not supported");
  }

  // 4. Only unit-stride insertions are supported. Check this before any IR is
  // created or moved: once the old loop body has been merged into the new
  // loop, bailing out would leave the IR half-rewritten. The slice that is
  // cloned/rebuilt below copies the candidate's strides, so checking the
  // candidate is equivalent.
  auto candidateOffsetSizeStrideOp =
      cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp);
  if (llvm::any_of(candidateOffsetSizeStrideOp.getMixedStrides(),
                   [](OpFoldResult stride) {
                     return !isConstantIntValue(stride, 1);
                   })) {
    return rewriter.notifyMatchFailure(
        candidateSliceOp, "containingOp's result yield with stride");
  }
  newOuts.append(dpsInits);

  // 5. Create the new loop right before the consumer and move the old loop
  // body into it. The old block arguments map onto the leading new ones.
  Location loc = oldLoopOp->getLoc();
  rewriter.setInsertionPoint(consumerOp);
  Operation *newLoopOp = nullptr;
  Block *newLoopBody = nullptr;
  if (isInsertSliceOp) {
    auto forOp = cast<scf::ForOp>(oldLoopOp);
    auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
                                                forOp.getUpperBound(),
                                                forOp.getStep(), newOuts);
    newLoopOp = newForOp;
    newLoopBody = newForOp.getBody();
  } else {
    auto forallOp = cast<scf::ForallOp>(oldLoopOp);
    auto newForallOp = rewriter.create<scf::ForallOp>(
        loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
        forallOp.getMixedStep(), newOuts, forallOp.getMapping());
    newLoopOp = newForallOp;
    // The moved body brings its own terminator.
    rewriter.eraseOp(newForallOp.getTerminator());
    newLoopBody = newForallOp.getBody();
  }
  unsigned oldNumArguments = oldLoopBody->getNumArguments();
  rewriter.mergeBlocks(oldLoopBody, newLoopBody,
                       newLoopBody->getArguments().take_front(oldNumArguments));

  // 6. Materialize a tensor.insert_slice mirroring the candidate slice. For
  // scf.forall it is placed before the terminator; for scf.for it is a clone
  // placed before the candidate.
  tensor::InsertSliceOp clonedInsertSliceOp;
  if (auto sliceOp =
          dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
    auto newForallOp = cast<scf::ForallOp>(newLoopOp);
    rewriter.setInsertionPoint(newForallOp.getTerminator());
    clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
        loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
        sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
  } else {
    rewriter.setInsertionPoint(candidateSliceOp);
    clonedInsertSliceOp =
        cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
  }

  // 7. Clone the consumer, passing the new loop's trailing block arguments
  // (those added for the DPS inits) to cloneOpAndUpdateDestinationArgs. Then
  // make the fused operand read the materialized insert_slice.
  auto newForOpBlockArgsForConsumerDest =
      newLoopBody->getArguments().drop_front(oldNumArguments);
  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
  rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
    operandToReplace.set(clonedInsertSliceOp.getResult());
  });

  // 8. Tile the cloned consumer at the inserted slice. Then replace all uses
  // of the tiled op's fused operand value with the inserted slice's source.
  auto ossSliceOp =
      cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
  FailureOr<TilingResult> tileAndFuseResult =
      tensor::replaceInsertSliceWithTiledConsumer(
          rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
  if (failed(tileAndFuseResult)) {
    return failure();
  }
  rewriter.replaceAllUsesWith(
      tileAndFuseResult->tiledOps[0]->getOperand(operandNumber),
      clonedInsertSliceOp.getSource());

  // 9. Map the operand tile to an iteration-domain tile, then to the tile
  // position of every consumer result.
  SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
  SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
          rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
          iterDomainSizes))) {
    return rewriter.notifyMatchFailure(
        clonedConsumerOp, "can't get iter domain position from input position");
  }
  unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
  SmallVector<SmallVector<OpFoldResult>> resultOffsets(
      totalNumResultsOfConsumer);
  SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
  for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
    if (failed(clonedConsumerOp.getResultTilePosition(
            rewriter, idx, iterDomainOffsets, iterDomainSizes,
            resultOffsets[idx], resultSizes[idx]))) {
      return rewriter.notifyMatchFailure(
          clonedConsumerOp,
          "can't get result domain position from iter domain position");
    }
  }

  // 10. Insert the tiled consumer's results into the new loop's trailing
  // block arguments via the loop terminator.
  auto arrayRefOffsets = ArrayRef<SmallVector<OpFoldResult>>(resultOffsets);
  auto arrayRefSizes = ArrayRef<SmallVector<OpFoldResult>>(resultSizes);
  if (isInsertSliceOp) {
    auto newForOp = cast<scf::ForOp>(newLoopOp);
    fixTerminatorSCFYield(
        rewriter, newForOp, *tileAndFuseResult, arrayRefOffsets, arrayRefSizes,
        newForOp.getBody()->getArguments().drop_front(1 + initSize));
  } else {
    auto newForallOp = cast<scf::ForallOp>(newLoopOp);
    fixTerminatorSCFInParallel(
        rewriter, newForallOp, tileAndFuseResult->tiledOps[0]->getResults(),
        arrayRefOffsets, arrayRefSizes,
        newForallOp.getBody()->getArguments().drop_front(rank + initSize));
  }

  // 11. Redirect uses of the old loop results (leading new results) and of the
  // consumer results (trailing new results), then clean up.
  for (auto &&[oldResult, newResult] :
       llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
    rewriter.replaceAllUsesWith(oldResult, newResult);
  }
  for (auto &&[oldResult, newResult] :
       llvm::zip(consumerOp->getResults(),
                 newLoopOp->getResults().drop_front(initSize))) {
    rewriter.replaceAllUsesWith(oldResult, newResult);
  }
  rewriter.eraseOp(oldLoopOp);
  rewriter.eraseOp(clonedConsumerOp);
  return scf::SCFFuseConsumerOfSliceResult{
      consumerOpOperand,
      &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
      tileAndFuseResult->tiledOps};
}
/// Lowers `op`, which must have no results, into a nest of `scf.for` loops
/// with one loop per iteration-domain dimension, created outermost first.
/// Each loop takes the range's `offset` as lower bound, its `size` as upper
/// bound and its `stride` as step. The op's scalar implementation is then
/// emitted in the innermost body using the induction variables.
///
/// The rewriter's insertion point is not restored. For a non-empty domain it
/// is left before the innermost loop's terminator. Returns the loops from
/// outermost to innermost.
FailureOr<SmallVector<scf::ForOp>>
mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
                                     TilingInterface op) {
  if (op->getNumResults() > 0) {
    return rewriter.notifyMatchFailure(
        op, "unable to lower to loops operations with return values");
  }

  SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
  Location loc = op.getLoc();
  SmallVector<scf::ForOp> loopNest;
  SmallVector<Value> inductionVars;
  for (const Range &dimRange : iterationDomain) {
    // Materialize the bounds one statement at a time so they are created in
    // the order offset, size, stride.
    Value lowerBound =
        getValueOrCreateConstantIndexOp(rewriter, loc, dimRange.offset);
    Value upperBound =
        getValueOrCreateConstantIndexOp(rewriter, loc, dimRange.size);
    Value step =
        getValueOrCreateConstantIndexOp(rewriter, loc, dimRange.stride);
    scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound,
                                                   step, ValueRange{});
    loopNest.push_back(forOp);
    inductionVars.push_back(forOp.getInductionVar());
    // Everything created next (inner loops, then the scalar body) goes inside
    // this loop.
    rewriter.setInsertionPoint(forOp.getBody()->getTerminator());
  }

  if (failed(op.generateScalarImplementation(rewriter, loc, inductionVars)))
    return failure();
  return loopNest;
}