#include "mlir/Dialect/Affine/LoopFusionUtils.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
#define DEBUG_TYPE "loop-fusion-utils"
using namespace mlir;
using namespace mlir::affine;
/// Records every memref accessed by an affine read or write nested under
/// 'opA' (including 'opA' itself) in 'values'.
///
/// On return, 'values[memref]' is true if at least one write to 'memref' was
/// seen, and false if the memref was only read. Entries already present in
/// 'values' are kept; a read never downgrades an existing entry.
static void getLoadAndStoreMemRefAccesses(Operation *opA,
                                          DenseMap<Value, bool> &values) {
  opA->walk([&values](Operation *nestedOp) {
    if (auto readOp = dyn_cast<AffineReadOpInterface>(nestedOp)) {
      // Insert 'false' only when the memref has not been recorded yet, so an
      // earlier 'true' (written) marker is left untouched.
      values.try_emplace(readOp.getMemRef(), false);
      return;
    }
    if (auto writeOp = dyn_cast<AffineWriteOpInterface>(nestedOp))
      values[writeOp.getMemRef()] = true;
  });
}
/// Returns true if 'op' is an affine read or write whose memref is recorded
/// in 'values' and at least one of the two accesses involved is a write:
///   * a read is dependent only if the memref is recorded as written;
///   * a write is dependent if the memref is recorded at all.
/// Any other operation yields false. 'values' is not modified.
static bool isDependentLoadOrStoreOp(Operation *op,
                                     DenseMap<Value, bool> &values) {
  if (auto readOp = dyn_cast<AffineReadOpInterface>(op)) {
    auto entry = values.find(readOp.getMemRef());
    return entry != values.end() && entry->second;
  }
  if (auto writeOp = dyn_cast<AffineWriteOpInterface>(op))
    return values.find(writeOp.getMemRef()) != values.end();
  return false;
}
/// Returns the first operation in the open block range ('opA', 'opB') that
/// contains (at any nesting depth) an affine access depending on a memref
/// accessed under 'opA', or nullptr if there is none.
///
/// 'opA' and 'opB' are expected to live in the same block with 'opA' first,
/// since the range is traversed with block iterators.
static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) {
  // memref -> "is written somewhere under opA".
  DenseMap<Value, bool> values;
  getLoadAndStoreMemRefAccesses(opA, values);

  auto opsBetween = llvm::make_range(std::next(Block::iterator(opA)),
                                     Block::iterator(opB));
  for (Operation &candidate : opsBetween) {
    // The result is the top-level op of the range, not the nested access
    // that triggered the match.
    WalkResult scan = candidate.walk([&](Operation *nestedOp) {
      return isDependentLoadOrStoreOp(nestedOp, values)
                 ? WalkResult::interrupt()
                 : WalkResult::advance();
    });
    if (scan.wasInterrupted())
      return &candidate;
  }
  return nullptr;
}
/// Returns the last operation 'opX' in the open block range ('opA', 'opB')
/// that 'opB' depends on, or nullptr if there is none.
///
/// 'opX' qualifies when some op nested in it (or 'opX' itself) either
///   * is an affine read/write that conflicts with a memref access recorded
///     for 'opB', or
///   * is not an affine read/write and defines a result that has a user
///     nested inside the loop 'opB'.
/// The range is scanned backwards, starting just before 'opB'.
static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) {
  // memref -> "is written somewhere under opB".
  DenseMap<Value, bool> values;
  getLoadAndStoreMemRefAccesses(opB, values);

  // True if 'opB' is one of the affine.for ops surrounding 'user'. The cast
  // is only evaluated when a user actually exists, as in a plain nested loop.
  auto isNestedInOpB = [&](Operation *user) {
    SmallVector<AffineForOp, 4> enclosingLoops;
    getAffineForIVs(*user, &enclosingLoops);
    return llvm::is_contained(enclosingLoops, cast<AffineForOp>(opB));
  };

  Block::reverse_iterator cursor = std::next(Block::reverse_iterator(opB));
  const Block::reverse_iterator stop = Block::reverse_iterator(opA);
  for (; cursor != stop; ++cursor) {
    Operation *candidate = &*cursor;
    WalkResult scan = candidate->walk([&](Operation *nestedOp) {
      bool feedsOpB;
      if (isa<AffineReadOpInterface, AffineWriteOpInterface>(nestedOp)) {
        // Memory accesses are judged purely on memref conflicts; their SSA
        // results are not inspected.
        feedsOpB = isDependentLoadOrStoreOp(nestedOp, values);
      } else {
        // SSA dependence: some result of 'nestedOp' is consumed under 'opB'.
        feedsOpB = llvm::any_of(nestedOp->getUsers(), isNestedInOpB);
      }
      return feedsOpB ? WalkResult::interrupt() : WalkResult::advance();
    });
    if (scan.wasInterrupted())
      return candidate;
  }
  return nullptr;
}
/// Computes where in the block a fused loop nest may be placed without
/// breaking dependences with the operations between the two loops.
///
/// Let 'A' be whichever of 'srcForOp'/'dstForOp' comes first in the block and
/// 'B' the other one. With 'firstUser' the first op in (A, B) depending on A
/// and 'lastProducer' the last op in (A, B) that B depends on:
///   * no 'firstUser'                         -> returns B;
///   * 'firstUser' at or before 'lastProducer' -> returns nullptr (no valid
///     position);
///   * otherwise                              -> returns 'firstUser'.
static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp,
                                                 AffineForOp dstForOp) {
  const bool srcComesFirst = srcForOp->isBeforeInBlock(dstForOp);
  AffineForOp earlierLoop = srcComesFirst ? srcForOp : dstForOp;
  AffineForOp laterLoop = srcComesFirst ? dstForOp : srcForOp;

  Operation *firstUser = getFirstDependentOpInRange(earlierLoop, laterLoop);
  Operation *lastProducer = getLastDependentOpInRange(earlierLoop, laterLoop);

  // Nothing between the loops depends on the earlier loop: place the fused
  // nest at the later loop.
  if (!firstUser)
    return laterLoop;

  // The fused nest would have to sit after 'lastProducer' and before
  // 'firstUser'; that interval is empty when they coincide or are inverted.
  if (lastProducer && (firstUser == lastProducer ||
                       firstUser->isBeforeInBlock(lastProducer)))
    return nullptr;

  return firstUser;
}
/// Appends every affine read/write op nested under 'forOp' to
/// 'loadAndStoreOps' (in walk order). Returns false if an affine.if was
/// encountered anywhere in the nest, true otherwise. The walk is never cut
/// short, so 'loadAndStoreOps' is fully populated in both cases.
static bool
gatherLoadsAndStores(AffineForOp forOp,
                     SmallVectorImpl<Operation *> &loadAndStoreOps) {
  bool sawAffineIf = false;
  forOp.walk([&](Operation *nestedOp) {
    if (isa<AffineReadOpInterface, AffineWriteOpInterface>(nestedOp)) {
      loadAndStoreOps.push_back(nestedOp);
      return;
    }
    if (isa<AffineIfOp>(nestedOp))
      sawAffineIf = true;
  });
  return !sawAffineIf;
}
/// Returns the greatest destination loop depth at which the source accesses
/// 'srcOps' can be fused into the nest containing 'dstOps' without violating
/// dependences among the destination accesses.
///
/// Only destination ops touching a producer-consumer memref (written in
/// 'srcOps' and read in 'dstOps') are considered. Returns 0 when 'dstOps' is
/// empty. Every element of 'dstOps' must be an affine read or write op.
static unsigned getMaxLoopDepth(ArrayRef<Operation *> srcOps,
                                ArrayRef<Operation *> dstOps) {
  if (dstOps.empty())
    return 0;

  DenseSet<Value> producerConsumerMemrefs;
  gatherProducerConsumerMemrefs(srcOps, dstOps, producerConsumerMemrefs);

  // Keep only the destination accesses to producer-consumer memrefs.
  SmallVector<Operation *, 4> targetDstOps;
  for (Operation *dstOp : dstOps) {
    Value accessedMemref;
    if (auto readOp = dyn_cast<AffineReadOpInterface>(dstOp))
      accessedMemref = readOp.getMemRef();
    else
      accessedMemref = cast<AffineWriteOpInterface>(dstOp).getMemRef();
    if (producerConsumerMemrefs.count(accessedMemref) > 0)
      targetDstOps.push_back(dstOp);
  }

  assert(!targetDstOps.empty() &&
         "No dependences between 'srcForOp' and 'dstForOp'?");

  // Start from the innermost loop depth shared by all retained accesses.
  unsigned loopDepth = getInnermostCommonLoopDepth(targetDstOps);

  // With reads only there is nothing further to check.
  if (llvm::all_of(targetDstOps, llvm::IsaPred<AffineReadOpInterface>))
    return loopDepth;

  // For every ordered pair of retained accesses, find the smallest depth 'd'
  // at which a dependence exists and cap the result at 'd - 1'.
  for (Operation *fromOp : targetDstOps) {
    MemRefAccess fromAccess(fromOp);
    for (Operation *toOp : targetDstOps) {
      MemRefAccess toAccess(toOp);
      const unsigned numCommonLoops =
          getNumCommonSurroundingLoops(*fromOp, *toOp);
      for (unsigned depth = 1; depth <= numCommonLoops + 1; ++depth) {
        DependenceResult dep =
            checkMemrefAccessDependence(fromAccess, toAccess, depth);
        if (!hasDependence(dep))
          continue;
        loopDepth = std::min(loopDepth, depth - 1);
        break;
      }
    }
  }
  return loopDepth;
}
/// Checks whether 'srcForOp' can be fused into 'dstForOp' at destination loop
/// depth 'dstLoopDepth' and, on the way, computes the union of the source
/// computation slices into 'srcSlice'.
///
/// Returns:
///   * FailPrecondition     - depth is 0, the loops are in different blocks,
///                            either nest contains an affine.if, or the slice
///                            union computation failed generically;
///   * FailBlockDependence  - no insertion point in the block preserves
///                            dependences;
///   * FailFusionDependence - (producer-consumer strategy only) the requested
///                            depth exceeds the maximal legal depth;
///   * FailIncorrectSlice   - the slice union was reported as incorrect;
///   * Success              - otherwise.
FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp,
                                        AffineForOp dstForOp,
                                        unsigned dstLoopDepth,
                                        ComputationSliceState *srcSlice,
                                        FusionStrategy fusionStrategy) {
  if (dstLoopDepth == 0) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n");
    return FusionResult::FailPrecondition;
  }

  if (srcForOp->getBlock() != dstForOp->getBlock()) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n");
    return FusionResult::FailPrecondition;
  }

  if (getFusedLoopNestInsertionPoint(srcForOp, dstForOp) == nullptr) {
    LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n");
    return FusionResult::FailBlockDependence;
  }

  // 'earlierLoop' precedes 'laterLoop' in the shared block.
  const bool srcPrecedesDst = srcForOp->isBeforeInBlock(dstForOp);
  AffineForOp earlierLoop = srcPrecedesDst ? srcForOp : dstForOp;
  AffineForOp laterLoop = srcPrecedesDst ? dstForOp : srcForOp;

  // Collect the memory accesses of both nests; the later nest is only
  // scanned if the earlier one is free of affine.if ops.
  SmallVector<Operation *, 4> earlierAccesses;
  SmallVector<Operation *, 4> laterAccesses;
  if (!gatherLoadsAndStores(earlierLoop, earlierAccesses) ||
      !gatherLoadsAndStores(laterLoop, laterAccesses)) {
    LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n");
    return FusionResult::FailPrecondition;
  }

  const auto strategy = fusionStrategy.getStrategy();

  // The loop-depth legality check is only applied to producer-consumer
  // fusion, which requires the source to come first.
  if (strategy == FusionStrategy::ProducerConsumer) {
    assert(srcPrecedesDst && "Unexpected forward slice fusion");
    const unsigned maxLegalDepth =
        getMaxLoopDepth(earlierAccesses, laterAccesses);
    if (maxLegalDepth < dstLoopDepth) {
      LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n");
      return FusionResult::FailFusionDependence;
    }
  }

  const unsigned numSharedLoops =
      affine::getNumCommonSurroundingLoops(*srcForOp, *dstForOp);

  // Select which accesses of the earlier nest drive the slice computation.
  SmallVector<Operation *, 4> sliceDrivingOps;
  switch (strategy) {
  case FusionStrategy::Generic:
    // All accesses.
    llvm::append_range(sliceDrivingOps, earlierAccesses);
    break;
  case FusionStrategy::ProducerConsumer:
    // Writes only.
    llvm::copy_if(earlierAccesses, std::back_inserter(sliceDrivingOps),
                  llvm::IsaPred<AffineWriteOpInterface>);
    break;
  case FusionStrategy::Sibling: {
    // Reads of the sibling-fusion memref only.
    const Value siblingMemRef = fusionStrategy.getSiblingFusionMemRef();
    llvm::copy_if(earlierAccesses, std::back_inserter(sliceDrivingOps),
                  [&](Operation *access) {
                    auto readOp = dyn_cast<AffineReadOpInterface>(access);
                    return readOp && readOp.getMemRef() == siblingMemRef;
                  });
    break;
  }
  }

  SliceComputationResult sliceStatus = affine::computeSliceUnion(
      sliceDrivingOps, laterAccesses, dstLoopDepth, numSharedLoops,
      srcPrecedesDst, srcSlice);

  if (sliceStatus.value == SliceComputationResult::GenericFailure) {
    LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
    return FusionResult::FailPrecondition;
  }
  if (sliceStatus.value == SliceComputationResult::IncorrectSliceFailure) {
    LLVM_DEBUG(llvm::dbgs() << "Incorrect slice computation\n");
    return FusionResult::FailIncorrectSlice;
  }
  return FusionResult::Success;
}
/// Folds the single-iteration loop 'forOp' into its parent affine.for.
///
/// The parent loop is rebuilt with 'forOp's init values as additional
/// iter operands and the values yielded by 'forOp' as additional yields. All
/// uses of 'forOp's results, induction variable and region iter args are
/// redirected to the rebuilt parent, and 'forOp's body (minus its terminator)
/// is spliced into the block that held 'forOp'.
///
/// When 'siblingFusionUser' is true, the operations in the forward slice of
/// 'forOp's results are moved after the rebuilt parent loop, because they now
/// consume results of that loop.
///
/// Returns failure, leaving the IR untouched, if 'forOp' does not have a
/// constant trip count of exactly 1 or its parent is not an affine.for.
static LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp,
                                                    bool siblingFusionUser) {
  // Only loops known to execute exactly once can be folded away.
  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
  if (!tripCount || *tripCount != 1)
    return failure();
  // The carried values are re-homed on the parent, which must be a loop too.
  auto *parentOp = forOp->getParentOp();
  if (!isa<AffineForOp>(parentOp))
    return failure();
  // Snapshot the values yielded by 'forOp'; they become the additional yields
  // of the rebuilt parent loop.
  SmallVector<Value> newOperands;
  llvm::append_range(newOperands,
                     forOp.getBody()->getTerminator()->getOperands());
  IRRewriter rewriter(parentOp->getContext());
  // Results contributed by 'forOp' are appended after the parent's existing
  // ones; remember the offset before the parent is replaced.
  int64_t parentOpNumResults = parentOp->getNumResults();
  // Replace the parent loop, adding iter operands and yields from 'forOp'.
  AffineForOp parentForOp = forOp->getParentOfType<AffineForOp>();
  AffineForOp newLoop =
      cast<AffineForOp>(*parentForOp.replaceWithAdditionalYields(
          rewriter, forOp.getInits(), /*replaceInitOperandUsesInLoop=*/false,
          [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
            return newOperands;
          }));

  // For sibling-fusion users, collect the transitive users of 'forOp's
  // results before those results are rewired below.
  SetVector<Operation *> forwardSlice;
  if (siblingFusionUser) {
    for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
      SetVector<Operation *> tmpForwardSlice;
      getForwardSlice(forOp.getResult(i), &tmpForwardSlice);
      forwardSlice.set_union(tmpForwardSlice);
    }
  }
  // Redirect uses of 'forOp's results to the matching new parent results.
  for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
    forOp.getResult(i).replaceAllUsesWith(
        newLoop.getResult(i + parentOpNumResults));
  }
  // For sibling-fusion users, move the collected users after the new parent
  // loop so that they are dominated by the results they now consume.
  if (siblingFusionUser) {
    // 'topologicalSort' returns the sorted set instead of sorting in place,
    // so the result has to be stored; the union of several per-result slices
    // is not guaranteed to be in topological order by itself.
    forwardSlice = topologicalSort(forwardSlice);
    // Each op is placed directly after 'newLoop', so walking the sorted set in
    // reverse leaves the ops in sorted order behind the loop.
    for (Operation *op : llvm::reverse(forwardSlice))
      op->moveAfter(newLoop);
  }
  // Replace the induction variable with the new parent's induction variable.
  auto iv = forOp.getInductionVar();
  iv.replaceAllUsesWith(newLoop.getInductionVar());
  // Replace 'forOp's iter args with the trailing iter args of the new parent,
  // i.e. the ones that were appended for 'forOp'.
  auto forOpIterArgs = forOp.getRegionIterArgs();
  for (auto it : llvm::zip(forOpIterArgs, newLoop.getRegionIterArgs().take_back(
                                              forOpIterArgs.size()))) {
    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
  }
  // Drop the terminator, then move the remaining body ops in front of 'forOp'
  // in its containing block and erase the now-empty loop.
  forOp.getBody()->back().erase();
  auto *parentBlock = forOp->getBlock();
  parentBlock->getOperations().splice(Block::iterator(forOp),
                                      forOp.getBody()->getOperations());
  forOp.erase();
  return success();
}
/// Fuses 'srcForOp' into the destination nest by cloning it at
/// 'srcSlice.insertPoint' and narrowing the cloned loops' bounds to the
/// bounds recorded in 'srcSlice'.
///
/// After the bounds are updated, each cloned slice loop is simplified:
/// a parallel loop containing a reduction is folded into its parent loop when
/// 'isInnermostSiblingInsertion' is set and the slice iterates exactly once;
/// every other loop is promoted if it has a single iteration. 'srcForOp'
/// itself is left in place; 'dstForOp' is not referenced by this function.
void mlir::affine::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
                             const ComputationSliceState &srcSlice,
                             bool isInnermostSiblingInsertion) {
  // Clone the source nest at the slice insertion point, tracking the mapping
  // from original values to their clones.
  OpBuilder builder(srcSlice.insertPoint->getBlock(), srcSlice.insertPoint);
  IRMapping cloneMapping;
  builder.clone(*srcForOp, cloneMapping);

  // Apply the slice bounds to the cloned loops.
  SmallVector<AffineForOp, 4> clonedSliceLoops;
  const unsigned numSliceIVs = srcSlice.ivs.size();
  for (unsigned idx = 0; idx < numSliceIVs; ++idx) {
    Value clonedIV = cloneMapping.lookupOrNull(srcSlice.ivs[idx]);
    // Slice IVs that were not part of the cloned nest have no counterpart.
    if (!clonedIV)
      continue;
    AffineForOp clonedLoop = getForInductionVarOwner(clonedIV);
    clonedSliceLoops.push_back(clonedLoop);

    // A null map means "keep the loop's existing bound".
    AffineMap lowerMap = srcSlice.lbs[idx];
    if (lowerMap) {
      auto lowerOperands = srcSlice.lbOperands[idx];
      canonicalizeMapAndOperands(&lowerMap, &lowerOperands);
      clonedLoop.setLowerBound(lowerOperands, lowerMap);
    }
    AffineMap upperMap = srcSlice.ubs[idx];
    if (upperMap) {
      auto upperOperands = srcSlice.ubOperands[idx];
      canonicalizeMapAndOperands(&upperMap, &upperOperands);
      clonedLoop.setUpperBound(upperOperands, upperMap);
    }
  }

  // Evaluated lazily, and only for loops that pass the cheaper checks.
  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
  auto sliceRunsExactlyOnce = [&]() {
    if (!buildSliceTripCountMap(srcSlice, &sliceTripCountMap))
      return false;
    return getSliceIterationCount(sliceTripCountMap) == 1;
  };

  // Simplify or eliminate single-iteration loops; failures are ignored on
  // purpose since promotion is best-effort.
  for (AffineForOp sliceLoop : clonedSliceLoops) {
    const bool foldIntoParent = isLoopParallelAndContainsReduction(sliceLoop) &&
                                isInnermostSiblingInsertion &&
                                sliceRunsExactlyOnce();
    if (foldIntoParent) {
      (void)promoteSingleIterReductionLoop(sliceLoop, true);
      continue;
    }
    (void)promoteIfSingleIteration(sliceLoop);
  }
}
/// Collects statistics for the loop nest rooted at 'forOpRoot' into 'stats':
///   * 'loopMap'      - for each loop, its directly nested affine.for ops;
///   * 'opCountMap'   - per loop, the number of ops directly in its body that
///                      are neither affine.for nor affine.if;
///   * 'tripCountMap' - per loop, its constant trip count.
///
/// Returns false (with 'stats' possibly partially filled) if a nested loop's
/// parent is not an affine.for or if some loop has a non-constant trip count;
/// returns true otherwise.
bool mlir::affine::getLoopNestStats(AffineForOp forOpRoot,
                                    LoopNestStats *stats) {
  WalkResult traversal = forOpRoot.walk([&](AffineForOp forOp) {
    Operation *loopOp = forOp.getOperation();
    Operation *enclosingOp = forOp->getParentOp();

    // Register every non-root loop under its parent loop.
    if (forOp != forOpRoot) {
      if (!isa<AffineForOp>(enclosingOp)) {
        LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp\n");
        return WalkResult::interrupt();
      }
      stats->loopMap[enclosingOp].push_back(forOp);
    }

    // Count the ops sitting directly in the loop body, skipping nested
    // affine.for/affine.if ops (the terminator is included in the count).
    const unsigned bodyOpCount =
        llvm::count_if(*forOp.getBody(), [](Operation &bodyOp) {
          return !isa<AffineForOp, AffineIfOp>(bodyOp);
        });
    stats->opCountMap[loopOp] = bodyOpCount;

    // Only loops with a constant trip count are supported.
    std::optional<uint64_t> constTripCount = getConstantTripCount(forOp);
    if (!constTripCount) {
      LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported\n");
      return WalkResult::interrupt();
    }
    stats->tripCountMap[loopOp] = *constTripCount;
    return WalkResult::advance();
  });
  return !traversal.wasInterrupted();
}
/// Recursively computes the cost of the loop nest rooted at 'forOp' as
///   tripCount(forOp) * (ownOps - 1 + sum(cost of child loops) + extra)
/// where 'ownOps' comes from 'stats.opCountMap' (one op, the terminator, is
/// discounted) and child loops come from 'stats.loopMap'.
///
/// If 'computeCostMap' is non-null and has an entry for a loop, that entry is
/// added to the loop's per-iteration op count before multiplying by the trip
/// count. If 'tripCountOverrideMap' is non-null and has an entry for a loop,
/// that entry replaces the trip count taken from 'stats.tripCountMap'.
///
/// Note: 'stats.opCountMap' and 'stats.tripCountMap' are read with
/// 'operator[]', so a missing loop gets a default (zero) entry inserted.
static int64_t getComputeCostHelper(
    Operation *forOp, LoopNestStats &stats,
    llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap,
    DenseMap<Operation *, int64_t> *computeCostMap) {
  // Ops executed per iteration by this loop's own body, minus the terminator.
  int64_t perIterationOps = stats.opCountMap[forOp] - 1;

  // Add the full cost of each directly nested loop.
  auto children = stats.loopMap.find(forOp);
  if (children != stats.loopMap.end()) {
    for (auto nestedLoop : children->second)
      perIterationOps += getComputeCostHelper(
          nestedLoop, stats, tripCountOverrideMap, computeCostMap);
  }

  // Apply the additive adjustment registered for this loop, if any.
  if (computeCostMap) {
    auto adjustment = computeCostMap->find(forOp);
    if (adjustment != computeCostMap->end())
      perIterationOps += adjustment->second;
  }

  // Use the overriding trip count when one is registered for this loop.
  int64_t iterations = stats.tripCountMap[forOp];
  if (tripCountOverrideMap) {
    auto overriding = tripCountOverrideMap->find(forOp);
    if (overriding != tripCountOverrideMap->end())
      iterations = overriding->second;
  }

  return iterations * perIterationOps;
}
/// Returns the compute cost of the loop nest rooted at 'forOp' using the
/// statistics in 'stats', without any trip count overrides or additional
/// per-loop cost adjustments.
int64_t mlir::affine::getComputeCost(AffineForOp forOp, LoopNestStats &stats) {
  llvm::SmallDenseMap<Operation *, uint64_t, 8> *noTripCountOverrides = nullptr;
  DenseMap<Operation *, int64_t> *noCostAdjustments = nullptr;
  return getComputeCostHelper(forOp, stats, noTripCountOverrides,
                              noCostAdjustments);
}
/// Computes into '*computeCost' the cost of the destination nest 'dstForOp'
/// after fusing the slice 'slice' of the source nest 'srcForOp' into it.
///
/// The source nest cost is evaluated with the slice's (reduced) trip counts
/// and is then added to the op count of the loop enclosing the slice
/// insertion point when evaluating the destination nest. When the slice
/// iterates exactly once, the stores in 'srcForOp' and the loads of the
/// stored memrefs nested under the insertion point's parent loop are
/// discounted, as they are expected to be forwarded away.
///
/// Returns false, leaving '*computeCost' untouched, if the slice trip count
/// map cannot be built; returns true otherwise.
bool mlir::affine::getFusionComputeCost(AffineForOp srcForOp,
                                        LoopNestStats &srcStats,
                                        AffineForOp dstForOp,
                                        LoopNestStats &dstStats,
                                        const ComputationSliceState &slice,
                                        int64_t *computeCost) {
  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
  DenseMap<Operation *, int64_t> computeCostMap;

  // Build the trip count map for the computation slice.
  if (!buildSliceTripCountMap(slice, &sliceTripCountMap))
    return false;
  // A slice with exactly one iteration is treated as guaranteeing
  // store-to-load forwarding.
  int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
  assert(sliceIterationCount > 0);
  bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
  auto *insertPointParent = slice.insertPoint->getParentOp();

  if (storeLoadFwdGuaranteed) {
    // Count in a signed 64-bit integer: negating an 'unsigned' counter wraps
    // around and would store a large positive value in the int64_t map entry
    // instead of a negative adjustment.
    int64_t storeCount = 0;
    llvm::SmallDenseSet<Value, 4> storeMemrefs;
    srcForOp.walk([&](AffineWriteOpInterface storeOp) {
      storeMemrefs.insert(storeOp.getMemRef());
      ++storeCount;
    });
    // Discount the stores of the slice.
    // NOTE(review): this entry is overwritten with 'sliceComputeCost' below,
    // so it only takes part in the source nest evaluation - confirm intent.
    if (storeCount > 0)
      computeCostMap[insertPointParent] = -storeCount;
    // Discount each load of a stored memref that is nested under
    // 'insertPointParent'; the discount is attributed to the load's
    // immediately enclosing affine.for (loads with another parent kind are
    // skipped).
    for (Value memref : storeMemrefs) {
      for (Operation *user : memref.getUsers()) {
        if (!isa<AffineReadOpInterface>(user))
          continue;
        SmallVector<AffineForOp, 4> loops;
        getAffineForIVs(*user, &loops);
        if (!llvm::is_contained(loops, cast<AffineForOp>(insertPointParent)))
          continue;
        if (auto forOp = dyn_cast_or_null<AffineForOp>(user->getParentOp())) {
          // 'operator[]' value-initializes a missing entry to 0.
          computeCostMap[forOp] -= 1;
        }
      }
    }
  }

  // Cost of the source nest, evaluated with the slice trip counts.
  int64_t sliceComputeCost = getComputeCostHelper(
      srcForOp, srcStats, &sliceTripCountMap, &computeCostMap);

  // Model the insertion of the slice into the body of 'insertPointParent'.
  computeCostMap[insertPointParent] = sliceComputeCost;

  *computeCost =
      getComputeCostHelper(dstForOp, dstStats,
                           /*tripCountOverrideMap=*/nullptr, &computeCostMap);
  return true;
}
/// Inserts into 'producerConsumerMemrefs' every memref that is written by an
/// affine write op in 'srcOps' and read by an affine read op in 'dstOps'.
/// Existing contents of 'producerConsumerMemrefs' are preserved.
void mlir::affine::gatherProducerConsumerMemrefs(
    ArrayRef<Operation *> srcOps, ArrayRef<Operation *> dstOps,
    DenseSet<Value> &producerConsumerMemrefs) {
  // Memrefs written on the source side.
  DenseSet<Value> writtenBySrc;
  for (Operation *srcOp : srcOps) {
    auto writeOp = dyn_cast<AffineWriteOpInterface>(srcOp);
    if (writeOp)
      writtenBySrc.insert(writeOp.getMemRef());
  }

  // Keep those that are also read on the destination side.
  for (Operation *dstOp : dstOps) {
    auto readOp = dyn_cast<AffineReadOpInterface>(dstOp);
    if (!readOp)
      continue;
    Value readMemref = readOp.getMemRef();
    if (writtenBySrc.count(readMemref) > 0)
      producerConsumerMemrefs.insert(readMemref);
  }
}