#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "llvm/Support/Debug.h"
#include <queue>
#define DEBUG_TYPE "licm"
using namespace mlir;
/// Return "true" if `op` may be moved out of its enclosing region. A
/// terminator never may. Otherwise every operand of `op`, or of any op nested
/// inside it, must either be defined within `op` itself (in one of its
/// regions) or satisfy `condition`.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(OpOperand &)> condition) {
  // Terminators are tied to their block and are never candidates.
  if (op->hasTrait<OpTrait::IsTerminator>())
    return false;

  WalkResult result = op->walk([&](Operation *nested) -> WalkResult {
    for (OpOperand &operand : nested->getOpOperands()) {
      // Values produced inside `op` move together with it, so only values
      // coming from outside of `op` are subject to `condition`.
      Operation *definingScope = operand.get().getParentRegion()->getParentOp();
      if (!op->isAncestor(definingScope) && !condition(operand))
        return WalkResult::interrupt();
    }
    return WalkResult::advance();
  });
  return !result.wasInterrupted();
}
/// Convenience overload of canBeHoisted: the condition is a predicate on the
/// used value itself rather than on the operand slot.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  auto operandDefinedOutside = [&](OpOperand &operand) {
    return definedOutside(operand.get());
  };
  return canBeHoisted(op, operandDefinedOutside);
}
/// Hoist ops out of each region in `regions`, one region at a time.
///
/// Ops at the top level of a region are processed in FIFO order. An op is
/// moved (through `moveOutOfRegion`) when `shouldMoveOutOfRegion` accepts it
/// and every value it, or any op nested in it, uses from outside of itself
/// satisfies `isDefinedOutsideRegion`. After a move, the op's users at the top
/// level of the same region are queued again, since they may now qualify.
/// Returns the total number of ops moved.
size_t mlir::moveLoopInvariantCode(
    ArrayRef<Region *> regions,
    function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
    function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
    function_ref<void(Operation *, Region *)> moveOutOfRegion) {
  size_t hoistedCount = 0;

  for (Region *region : regions) {
    LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
                            << *region->getParentOp() << "\n");

    // Seed the worklist with the ops at the top level of the region.
    std::queue<Operation *> pending;
    for (Operation &topLevelOp : region->getOps())
      pending.push(&topLevelOp);

    auto isInvariantValue = [&](Value v) {
      return isDefinedOutsideRegion(v, region);
    };

    while (!pending.empty()) {
      Operation *candidate = pending.front();
      pending.pop();

      // Skip ops that are no longer directly inside `region`, e.g. because
      // they were already hoisted. An op can be queued more than once.
      if (candidate->getParentRegion() != region)
        continue;

      LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *candidate << "\n");
      bool hoistable = shouldMoveOutOfRegion(candidate, region) &&
                       canBeHoisted(candidate, isInvariantValue);
      if (!hoistable)
        continue;

      LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *candidate
                              << "\n");
      moveOutOfRegion(candidate, region);
      ++hoistedCount;

      // Users of the moved op may have become invariant as well.
      for (Operation *user : candidate->getUsers()) {
        if (user->getParentRegion() == region)
          pending.push(user);
      }
    }
  }

  return hoistedCount;
}
/// Hoist loop-invariant ops out of every region of `loopLike`. An op is a
/// candidate if it is memory-effect free and speculatable. Its external
/// operands must be defined outside of the loop, as decided by
/// `isDefinedOutsideOfLoop`. Returns the number of ops moved.
size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
  auto isOutsideLoop = [&](Value value, Region *) {
    return loopLike.isDefinedOutsideOfLoop(value);
  };
  auto isSafeToHoist = [&](Operation *op, Region *) {
    return isMemoryEffectFree(op) && isSpeculatable(op);
  };
  auto hoistOutOfLoop = [&](Operation *op, Region *) {
    loopLike.moveOutOfLoop(op);
  };
  return moveLoopInvariantCode(loopLike.getLoopRegions(), isOutsideLoop,
                               isSafeToHoist, hoistOutOfLoop);
}
namespace {
class MatchingSubsets {
public:
void insert(SubsetOpInterface op, bool collectHoistableOps = true) {
allSubsetOps.push_back(op);
if (!collectHoistableOps)
return;
if (auto extractionOp =
dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
insertExtractionOp(extractionOp);
if (auto insertionOp =
dyn_cast<SubsetInsertionOpInterface>(op.getOperation()))
insertInsertionOp(insertionOp);
}
auto getHoistableSubsetOps() {
return llvm::make_filter_range(
llvm::zip(extractions, insertions), [&](auto pair) {
auto [extractionOp, insertionOp] = pair;
if (extractionOp && insertionOp &&
extractionOp->getResult(0).getType() !=
insertionOp.getSourceOperand().get().getType())
return false;
return allDisjoint(extractionOp, insertionOp);
});
}
LogicalResult populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
BlockArgument iterArg,
bool collectHoistableOps = true);
private:
static bool isEquivalent(Value v1, Value v2) { return true; }
bool allDisjoint(SubsetExtractionOpInterface extractionOp,
SubsetInsertionOpInterface insertionOp) const {
for (SubsetOpInterface other : allSubsetOps) {
if (other == extractionOp || other == insertionOp)
continue;
if (extractionOp &&
!other.operatesOnDisjointSubset(extractionOp, isEquivalent))
return false;
if (insertionOp &&
!other.operatesOnDisjointSubset(insertionOp, isEquivalent))
return false;
}
return true;
}
void insertExtractionOp(SubsetExtractionOpInterface extractionOp) {
for (auto it : llvm::enumerate(insertions)) {
if (!it.value())
continue;
auto other = cast<SubsetOpInterface>(it.value().getOperation());
if (other.operatesOnEquivalentSubset(extractionOp, isEquivalent)) {
extractions[it.index()] = extractionOp;
return;
}
}
extractions.push_back(extractionOp);
insertions.push_back({});
}
void insertInsertionOp(SubsetInsertionOpInterface insertionOp) {
for (auto it : llvm::enumerate(extractions)) {
if (!it.value())
continue;
auto other = cast<SubsetOpInterface>(it.value().getOperation());
if (other.operatesOnEquivalentSubset(insertionOp, isEquivalent)) {
insertions[it.index()] = insertionOp;
return;
}
}
extractions.push_back({});
insertions.push_back(insertionOp);
}
SmallVector<SubsetExtractionOpInterface> extractions;
SmallVector<SubsetInsertionOpInterface> insertions;
SmallVector<SubsetOpInterface> allSubsetOps;
};
}
/// Return the only use of `value` when `value` has exactly one use and the
/// owner of that use is a terminator. Otherwise return nullptr.
static OpOperand *getSingleTerminatorUse(Value value) {
  if (!value.hasOneUse())
    return nullptr;
  OpOperand &onlyUse = *value.getUses().begin();
  Operation *owner = onlyUse.getOwner();
  return owner->hasTrait<OpTrait::IsTerminator>() ? &onlyUse : nullptr;
}
/// Walk the SSA use-def chain that starts at the region iter_arg `iterArg` of
/// `loopLike`, and record every subset op found along the way.
///
/// Every use of a value on the chain must be one of the following:
///  - a subset op. Extractions are leaf reads. Insertions must be
///    destination-style, must use the value as their destination, and
///    continue the chain through their updated destination.
///  - a nested loop that takes the value as an iter_arg init. The nested loop
///    is traversed recursively and continues the chain through its tied
///    result.
/// At most one use may continue the chain; branching is rejected. The chain
/// must end in a single terminator use, and that use must be the operand
/// yielded for `iterArg`.
///
/// When `collectHoistableOps` is false (the recursive case), subset ops are
/// recorded only for the disjointness check. They are never offered by
/// `getHoistableSubsetOps`.
///
/// Returns failure if any of these conditions is violated.
LogicalResult
MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
                                            BlockArgument iterArg,
                                            bool collectHoistableOps) {
  assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
  Value value = iterArg;

  OpOperand *yieldedOperand = nullptr;
  // Follow the chain until the current value's only use is a terminator.
  while (!(yieldedOperand = getSingleTerminatorUse(value))) {
    // The value that continues the chain. At most one use may provide it.
    // Otherwise the chain branches, and the users on the other branch would
    // never be checked.
    Value nextValue = {};

    for (OpOperand &use : value.getUses()) {
      if (auto nestedLoop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
        // The value must enter the nested loop as an iter_arg init.
        auto nestedIterArg = nestedLoop.getTiedLoopRegionIterArg(&use);
        if (!nestedIterArg)
          return failure();
        // Subset ops inside the nested loop only take part in the
        // disjointness check. They are never hoisted out of this loop.
        if (failed(populateSubsetOpsAtIterArg(nestedLoop, nestedIterArg,
                                              /*collectHoistableOps=*/false)))
          return failure();
        // A nested loop continues the chain, just like an insertion op does,
        // so reject a second continuation here as well.
        if (nextValue)
          return failure();
        nextValue = nestedLoop.getTiedLoopResult(&use);
        continue;
      }

      auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner());
      if (!subsetOp)
        return failure();
      // Forward `collectHoistableOps`. Ops found while traversing a nested
      // loop must not become hoisting candidates of the outer loop.
      insert(subsetOp, collectHoistableOps);

      if (auto insertionOp =
              dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) {
        // Only destination-style insertion ops are supported.
        if (!isa<DestinationStyleOpInterface>(use.getOwner()))
          return failure();
        // The value must be the destination. As the source, the whole tensor
        // would be read, which rules out hoisting.
        if (&use != &insertionOp.getDestinationOperand())
          return failure();
        // Only a single continuation of the chain is allowed.
        if (nextValue)
          return failure();
        nextValue = insertionOp.getUpdatedDestination();
      }
    }

    // The chain neither reached a terminator nor continued.
    if (!nextValue)
      return failure();
    value = nextValue;
  }

  // The chain must end in the operand that is yielded for `iterArg` itself,
  // so that iter_args are not swapped through the yield.
  if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand)
    return failure();
  return success();
}
/// Hoist matching subset extraction/insertion pairs that operate on the region
/// iter_arg `iterArg` of `loopLike` out of the loop.
///
/// For each hoisted pair, the loop is rebuilt with one additional iter_arg.
/// That iter_arg is initialized with the extracted value and yields the value
/// that was being inserted. The extraction is moved before the loop and the
/// insertion after it. Returns the (possibly replaced) loop op. The loop is
/// returned unchanged if the use-def chain of `iterArg` cannot be analyzed.
static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
LoopLikeOpInterface loopLike,
BlockArgument iterArg) {
assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
// Remember the position of `iterArg`. Each hoisted pair replaces the loop op,
// so the block argument has to be looked up again by index afterwards.
auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
MatchingSubsets subsets;
// Give up (leaving the loop untouched) if the chain from `iterArg` to the
// loop terminator is not one that populateSubsetOpsAtIterArg accepts.
if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg)))
return loopLike;
// Hoist the candidate pairs one at a time. Note: this `it` shadows the
// iterator declared above.
for (auto it : subsets.getHoistableSubsetOps()) {
auto extractionOp = std::get<0>(it);
auto insertionOp = std::get<1>(it);
// Drop an op from the pair if it uses a loop-variant value. The only
// exception is the tensor operand(s) threaded through the iter_arg chain;
// those are rewired below.
if (extractionOp) {
if (!canBeHoisted(extractionOp, [&](OpOperand &operand) {
return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
&operand == &extractionOp.getSourceOperand();
}))
extractionOp = {};
}
if (insertionOp) {
if (!canBeHoisted(insertionOp, [&](OpOperand &operand) {
return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
&operand == &insertionOp.getSourceOperand() ||
&operand == &insertionOp.getDestinationOperand();
}))
insertionOp = {};
}
// Only complete extraction/insertion pairs are hoisted.
if (extractionOp && insertionOp) {
// The new iter_arg yields the value that the insertion op was inserting.
NewYieldValuesFn newYieldValuesFn =
[&](OpBuilder &b, Location loc,
ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
return {insertionOp.getSourceOperand().get()};
};
// Rebuild the loop with one more iter_arg, whose init is the extracted
// value.
FailureOr<LoopLikeOpInterface> newLoop =
loopLike.replaceWithAdditionalYields(
rewriter, extractionOp.getResult(),
/*replaceInitOperandUsesInLoop=*/true, newYieldValuesFn);
// If the loop cannot be rebuilt, stop here. Pairs hoisted in earlier
// iterations stay hoisted.
if (failed(newLoop))
return loopLike;
loopLike = *newLoop;
// Re-fetch the iter_arg and results from the new loop op. The newly added
// iter_arg's result is the last loop result.
iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
OpResult newLoopResult = loopLike.getLoopResults()->back();
rewriter.moveOpBefore(extractionOp, loopLike);
rewriter.moveOpAfter(insertionOp, loopLike);
// Inside the loop, bypass the (now hoisted) insertion: its users see the
// destination tensor directly.
rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
insertionOp.getDestinationOperand().get());
// The hoisted extraction reads from the tensor entering the loop.
// NOTE(review): these operand updates use OpOperand::set directly, not the
// rewriter, so rewriter listeners are not notified. Confirm this is intended.
extractionOp.getSourceOperand().set(
loopLike.getTiedLoopInit(iterArg)->get());
// Users of the loop result now see the tensor with the subset inserted.
// This RAUW must happen before the insertion's destination is set to
// `loopResult`, otherwise that operand would be replaced too.
rewriter.replaceAllUsesWith(loopResult,
insertionOp.getUpdatedDestination());
// After the loop, insert the final value of the new iter_arg into the
// tensor produced by the loop.
insertionOp.getSourceOperand().set(newLoopResult);
insertionOp.getDestinationOperand().set(loopResult);
}
}
return loopLike;
}
/// Hoist matching subset extraction/insertion pairs out of `loopLike`, one
/// region iter_arg at a time. Returns the (possibly replaced) loop op.
LoopLikeOpInterface
mlir::hoistLoopInvariantSubsets(RewriterBase &rewriter,
                                LoopLikeOpInterface loopLike) {
  // The iter_arg count is re-read on every iteration on purpose. Hoisting a
  // pair replaces the loop with one that has an extra iter_arg, and those new
  // iter_args are visited as well.
  int64_t idx = 0;
  while (idx < static_cast<int64_t>(loopLike.getRegionIterArgs().size())) {
    BlockArgument current = loopLike.getRegionIterArgs()[idx];
    loopLike = hoistSubsetAtIterArg(rewriter, loopLike, current);
    ++idx;
  }
  return loopLike;
}