#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"
#include "llvm/Support/Debug.h"
#include <queue>
namespace mlir {
namespace triton {
namespace gpu {
#define GEN_PASS_DEF_TRITONGPUFUSENESTEDLOOPS
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
// Attribute looked up on the root loop of a nest by shouldFuse; together with
// the "exactly one outer and one inner loop" shape check it opts the nest into
// fusion.
static constexpr llvm::StringLiteral kFlattenAttr = "tt.flatten";
// Unit attribute set on an inner loop by speculateInnerLoopLength and read by
// fuseOneLevel, which replaces the guard of that loop's body with constant
// true when the attribute is present.
static constexpr llvm::StringLiteral kMustExecuteAttrName = "ttg.must-execute";
// Attribute looked up on the root loop of a nest: shouldFuse accepts the nest
// unconditionally when it is present, and the pass skips preprocessLoopNest
// for such nests.
static constexpr llvm::StringLiteral kAlwaysFuseAttrName = "ttg.always-fuse";
namespace {
// Pass that flattens eligible loop nests into a single loop. The base class
// comes from the pass declaration pulled in through
// GEN_PASS_DEF_TRITONGPUFUSENESTEDLOOPS; its constructors are inherited as-is.
struct FuseNestedLoopsPass
: public impl::TritonGPUFuseNestedLoopsBase<FuseNestedLoopsPass> {
using TritonGPUFuseNestedLoopsBase::TritonGPUFuseNestedLoopsBase;
// Entry point: finds loop nests in every function and fuses the eligible ones.
void runOnOperation() override;
};
// One loop in a loop nest tree.
struct LoopNestNode {
LoopNestNode(scf::ForOp loop) : loop(loop) {}
// The loop this node stands for. fuseOneLevel overwrites it with the fused
// loop after a successful fusion.
scf::ForOp loop;
// Non-owning pointers to the nodes of the loops nested inside `loop`; the
// nodes themselves are owned by LoopNest::nodes.
SmallVector<LoopNestNode *, 1> children;
};
// A tree of nested `scf.for` loops rooted at an outermost loop.
struct LoopNest {
// Creates a nest whose only node is `outermost`; that node becomes `root`.
LoopNest(scf::ForOp outermost);
// Writes one line per loop (the first line of the loop's printed form),
// indented by nesting depth.
void print(raw_ostream &os) const;
// Prints the nest to llvm::dbgs().
LLVM_DUMP_METHOD void dump() const;
// Owning storage for every node of the tree, root included.
SmallVector<std::unique_ptr<LoopNestNode>> nodes;
// The node of the outermost loop; points into `nodes`.
LoopNestNode *root;
};
}
// Create a nest that initially contains only the node for `outermost`, and
// make that node the root of the tree.
LoopNest::LoopNest(scf::ForOp outermost) : root(nullptr) {
  nodes.push_back(std::make_unique<LoopNestNode>(outermost));
  root = nodes.back().get();
}
// Print the nest as an indented outline: a "LoopNest:" header followed by the
// first printed line of every loop, indented two spaces per nesting level.
// Nodes are visited depth-first using an explicit stack, so the children of a
// node come out in reverse of the order in which they are stored.
void LoopNest::print(raw_ostream &os) const {
  // Scratch storage reused for every loop to avoid reallocating.
  std::string scratch;
  auto emitHeaderLine = [&](scf::ForOp forOp) {
    scratch.clear();
    llvm::raw_string_ostream sink(scratch);
    forOp.print(sink);
    // Only the text up to the first newline is of interest.
    os << scratch.substr(0, scratch.find('\n'));
  };

  os << "LoopNest:\n";

  using Entry = std::pair<LoopNestNode *, unsigned>;
  SmallVector<Entry> worklist;
  worklist.emplace_back(root, 0);
  while (!worklist.empty()) {
    Entry entry = worklist.pop_back_val();
    LoopNestNode *current = entry.first;
    unsigned depth = entry.second;

    os.indent(depth * 2);
    emitHeaderLine(current->loop);
    os << "\n";

    for (LoopNestNode *nested : current->children)
      worklist.emplace_back(nested, depth + 1);
  }
  os << "\n";
}
void LoopNest::dump() const { print(llvm::dbgs()); }
// Defined below; constructLoopNest and findLoopNests are mutually recursive.
static void findLoopNests(Operation *container,
                          SmallVectorImpl<LoopNest> &nests);

// Populate `nest` with the loops nested inside `parent->loop`.
//
// The body of the parent loop is walked in pre-order:
//  - a nested `scf.for` becomes a child node of `parent` (owned by `nest`) and
//    is expanded recursively as part of the same nest;
//  - any other op that has regions is not looked through: the loops inside it
//    are collected as separate nests appended to `nests`;
//  - ops without regions are ignored.
static void constructLoopNest(LoopNestNode *parent, LoopNest &nest,
                              SmallVectorImpl<LoopNest> &nests) {
  Operation *parentOp = parent->loop;

  auto visit = [&](Operation *visited) -> WalkResult {
    // The walk reports the root op as well; step into it.
    if (visited == parentOp)
      return WalkResult::advance();

    if (auto nestedFor = dyn_cast<scf::ForOp>(visited)) {
      nest.nodes.push_back(std::make_unique<LoopNestNode>(nestedFor));
      LoopNestNode *childNode = nest.nodes.back().get();
      parent->children.push_back(childNode);
      constructLoopNest(childNode, nest, nests);
      // The recursive call already covered everything below this loop.
      return WalkResult::skip();
    }

    if (visited->getNumRegions() == 0)
      return WalkResult::advance();

    // A region-holding op other than `scf.for` starts fresh, independent
    // nests for whatever loops it contains.
    findLoopNests(visited, nests);
    return WalkResult::skip();
  };

  parentOp->walk<mlir::WalkOrder::PreOrder>(visit);
}
// Collect the loop nests found under `container` into `nests`.
//
// Each `scf.for` reached by the pre-order walk becomes the root of a new nest.
// The walk does not descend into that loop itself; constructLoopNest handles
// its interior and may append further nests to `nests` before the nest rooted
// at the loop is appended.
static void findLoopNests(Operation *container,
                          SmallVectorImpl<LoopNest> &nests) {
  auto startNest = [&nests](scf::ForOp outermost) -> WalkResult {
    LoopNest current(outermost);
    constructLoopNest(current.root, current, nests);
    nests.push_back(std::move(current));
    return WalkResult::skip();
  };
  container->walk<mlir::WalkOrder::PreOrder>(startNest);
}
namespace {
struct Logue {
void moveBefore(Block *block, Block::iterator it) {
for (Operation *op : ops)
op->moveBefore(block, it);
}
void replaceAllUsesWith(ValueRange values, Region &containingRegion) {
for (auto [newOut, output] : llvm::zip(values, outputs)) {
output.replaceUsesWithIf(newOut, [&](OpOperand &use) {
return !containingRegion.isAncestor(use.getOwner()->getParentRegion());
});
}
}
unsigned getNumOutputs() const { return outputs.size(); }
ValueRange getOutputs() const { return outputs; }
TypeRange getOutputTypes() const { return getOutputs().getTypes(); }
SmallVector<Operation *> ops;
SmallVector<Value> outputs;
};
}
// Build a Logue out of the operations in `ops`.
//
// A result of one of these operations is recorded as an output when at least
// one of its users is properly dominated by the last operation in the range.
// An empty range yields an empty logue.
static Logue createLogueFrom(llvm::iterator_range<Block::iterator> ops,
                             mlir::DominanceInfo &domInfo) {
  Logue result;
  for (Operation &member : ops)
    result.ops.push_back(&member);
  if (result.ops.empty())
    return result;

  Operation *tail = result.ops.back();
  auto usedAfterTail = [&](OpResult value) {
    return llvm::any_of(value.getUsers(), [&](Operation *user) {
      return domInfo.properlyDominates(tail, user);
    });
  };

  for (Operation *member : result.ops) {
    for (OpResult value : member->getOpResults()) {
      if (!usedAfterTail(value))
        continue;
      result.outputs.push_back(value);
    }
  }
  return result;
}
// An op may be hoisted as part of a loop-bound computation only if it has no
// memory effects and every operand and result is an integer, index or float.
static bool canHoistLoopBoundComputation(Operation *op) {
  if (!isMemoryEffectFree(op))
    return false;

  auto isNonScalar = [](Type candidate) {
    return !candidate.isIntOrIndexOrFloat();
  };
  if (llvm::any_of(op->getOperandTypes(), isNonScalar))
    return false;
  return llvm::none_of(op->getResultTypes(), isNonScalar);
}
// Thin wrapper over getDominatingValueSetOpsToHoist: forwards `values`, the
// loop `outer` and the `toHoist` set, using canHoistLoopBoundComputation as
// the per-op filter, and returns that helper's verdict.
static bool isOuterLoopInvariant(mlir::DominanceInfo &domInfo, scf::ForOp outer,
                                 ArrayRef<Value> values,
                                 llvm::SetVector<Operation *> &toHoist) {
  const bool boundsAreInvariant = getDominatingValueSetOpsToHoist(
      domInfo, outer, values, toHoist, canHoistLoopBoundComputation);
  return boundsAreInvariant;
}
// Bit width of an integer-like type: the internal storage width for `index`,
// otherwise the width of the IntegerType (any other type trips the cast).
static unsigned getIntTypeWidth(Type type) {
  const bool isIndex = isa<IndexType>(type);
  if (!isIndex)
    return cast<IntegerType>(type).getWidth();
  return IndexType::kInternalStorageBitWidth;
}
// Emit `ceildivsi(ub - lb, step)` for `loop` at the builder's insertion point
// and return the result.
static Value computeNumIters(ImplicitLocOpBuilder &b, scf::ForOp loop) {
  Value upperBound = loop.getUpperBound();
  Value lowerBound = loop.getLowerBound();
  Value span = b.create<arith::SubIOp>(upperBound, lowerBound);
  Value tripCount = b.create<arith::CeilDivSIOp>(span, loop.getStep());
  return tripCount;
}
// Convert `value` to the integer-like type `type`, emitting a cast only when
// the types differ:
//  - `index` on either side -> arith.index_cast
//  - narrowing integer      -> arith.trunci
//  - otherwise              -> arith.extsi
static Value castIntIfNecessary(ImplicitLocOpBuilder &b, Value value,
                                Type type) {
  Type sourceType = value.getType();
  if (sourceType == type)
    return value;

  const bool involvesIndex = isa<IndexType>(sourceType) || isa<IndexType>(type);
  if (involvesIndex)
    return b.create<arith::IndexCastOp>(type, value);

  const unsigned sourceWidth = cast<IntegerType>(sourceType).getWidth();
  const unsigned targetWidth = cast<IntegerType>(type).getWidth();
  if (sourceWidth > targetWidth)
    return b.create<arith::TruncIOp>(type, value);
  return b.create<arith::ExtSIOp>(type, value);
}
// Produce a placeholder value of the given type.
//
// Scalars and ranked tensors whose element type is an integer, index or float
// get a zero constant (a splat for tensors); every other type gets
// `ub.poison`.
static Value createPoisonOrZero(ImplicitLocOpBuilder &b, Type type) {
  Type elementType = getElementTypeOrSelf(type);
  auto tensorType = dyn_cast<RankedTensorType>(type);

  const bool hasScalarElement = elementType.isIntOrIndexOrFloat();
  const bool isTensorOrScalar = tensorType || type == elementType;
  if (!hasScalarElement || !isTensorOrScalar)
    return b.create<ub::PoisonOp>(type);

  TypedAttr zero;
  if (isa<FloatType>(elementType))
    zero = TypedAttr(b.getFloatAttr(elementType, 0));
  else
    zero = TypedAttr(b.getIntegerAttr(elementType, 0));

  if (tensorType)
    zero = SplatElementsAttr::get(tensorType, zero);
  return b.create<arith::ConstantOp>(zero);
}
// Return the `scf.yield` that is the last operation of the first block of
// `body`.
static scf::YieldOp getYield(Region &body) {
  Block &entry = body.front();
  Operation &terminator = entry.back();
  return cast<scf::YieldOp>(terminator);
}
// Rebuild `ifOp` without the results whose bit is set in `indices`.
//
// The matching operands are dropped from both yields, a new `scf.if` with the
// remaining result types is created in front of the old one, and both regions
// are moved over. Uses of a removed result are redirected to the next entry of
// `replaceWith` (consumed in order); uses of a kept result go to the
// corresponding result of the new op. The old op is erased and the new one is
// returned. The builder's insertion point is restored on exit.
static scf::IfOp eraseIfResults(ImplicitLocOpBuilder &b, scf::IfOp ifOp,
                                llvm::BitVector indices,
                                SmallVector<Value> replaceWith) {
  OpBuilder::InsertionGuard restoreInsertionPoint(b);
  b.setInsertionPoint(ifOp);

  // Pad the mask with "keep" bits so that it covers every result.
  const unsigned oldNumResults = ifOp.getNumResults();
  if (indices.size() < oldNumResults)
    indices.resize(oldNumResults, false);

  scf::YieldOp thenYield = getYield(ifOp.getThenRegion());
  scf::YieldOp elseYield = getYield(ifOp.getElseRegion());
  thenYield->eraseOperands(indices);
  elseYield->eraseOperands(indices);

  TypeRange keptTypes = thenYield.getOperandTypes();
  auto rebuilt = b.create<scf::IfOp>(keptTypes, ifOp.getCondition());
  rebuilt.getThenRegion().takeBody(ifOp.getThenRegion());
  rebuilt.getElseRegion().takeBody(ifOp.getElseRegion());

  // Interleave the caller-provided replacements with the surviving results,
  // following the original result order.
  SmallVector<Value> replacements;
  auto nextReplacement = replaceWith.begin();
  auto nextResult = rebuilt->result_begin();
  for (unsigned idx = 0; idx != oldNumResults; ++idx) {
    if (indices[idx])
      replacements.push_back(*nextReplacement++);
    else
      replacements.push_back(*nextResult++);
  }
  assert(ValueRange(replacements).getTypes() == ifOp.getResultTypes());

  ifOp.replaceAllUsesWith(replacements);
  ifOp.erase();
  return rebuilt;
}
// Fuse the fusible direct child loops of `parent` into `parent->loop`,
// replacing the outer loop by one flat `scf.for`.
//
// With inner loops j0..jN accepted for fusion, the generated code has the
// shape (names follow the locals below):
//
//   inner_len   = sum_k max(1, len_jk) - N
//   total_iters = len_i * inner_len
//   T = -1; i = lb_i - step_i
//   for _ in range(0, total_iters, 1):
//     T = 0 if T == inner_len - 1 else T + 1
//     for each k in 0..N, with start_k = sum_{m<k} max(1, len_jm) - k:
//       if T == start_k:
//         (k == 0 only: i += step_i)
//         prologue_k; jk = lb_jk; inner iter args = inner inits
//       if start_k <= T < start_k + len_jk:
//         body_jk; jk += step_jk
//     if T == inner_len - 1:
//       epilogue
//
// On return `parent->children` is always empty. If at least one child was
// fused, the outer loop and the fused inner loops are erased and
// `parent->loop` refers to the new loop.
static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) {
scf::ForOp outer = parent->loop;
// Keep the children whose lower bound, upper bound and step are accepted by
// isOuterLoopInvariant. The same `toHoist` set is passed to every check and is
// handed to hoistOpsBefore below.
SmallVector<scf::ForOp> innerLoops;
llvm::SetVector<Operation *> toHoist;
for (LoopNestNode *child : parent->children) {
scf::ForOp inner = child->loop;
assert(child->children.empty() && "fuseOneLevel runs leaf-to-root");
if (!isOuterLoopInvariant(
domInfo, outer,
{inner.getLowerBound(), inner.getUpperBound(), inner.getStep()},
toHoist))
continue;
innerLoops.push_back(child->loop);
}
// The children are dropped from the tree whether or not they were selected;
// loops that were not selected stay in the IR as ordinary ops of the body.
parent->children.clear();
if (innerLoops.empty())
return;
hoistOpsBefore(outer, toHoist);
// The counter type is as wide as the widest of the outer induction variable
// and the inner trip counts.
unsigned intTyWidth = getIntTypeWidth(outer.getInductionVar().getType());
Location loc = outer.getLoc();
// All bound computations are emitted in front of the outer loop.
ImplicitLocOpBuilder b(loc, outer);
Value lenOuter = computeNumIters(b, outer);
SmallVector<Value> lenInners;
for (scf::ForOp loop : innerLoops) {
Value lenInner = computeNumIters(b, loop);
intTyWidth = std::max(intTyWidth, getIntTypeWidth(lenInner.getType()));
lenInners.push_back(lenInner);
}
auto intTy = b.getIntegerType(intTyWidth);
// Helper: constant of the counter type.
auto intTyCst = [&](int64_t v) {
return b.create<arith::ConstantOp>(IntegerAttr::get(intTy, v));
};
// N is the index of the last inner loop (innerLoops is non-empty here).
unsigned N = innerLoops.size() - 1;
// partialInnerSums[k] = sum of max(1, len_jm) for m < k; entry 0 is zero and
// the vector ends up with N + 2 entries.
Value innerLen = intTyCst(0);
SmallVector<Value> partialInnerSums;
partialInnerSums.push_back(innerLen);
for (Value lenInner : lenInners) {
lenInner = castIntIfNecessary(b, lenInner, intTy);
lenInner = b.create<arith::MaxSIOp>(intTyCst(1), lenInner);
innerLen = b.create<arith::AddIOp>(innerLen, lenInner);
partialInnerSums.push_back(innerLen);
}
// inner_len = sum_k max(1, len_jk) - N
innerLen = b.create<arith::SubIOp>(innerLen, intTyCst(N));
// total_iters = len_i * inner_len
Value totalIters =
b.create<arith::MulIOp>(castIntIfNecessary(b, lenOuter, intTy), innerLen);
// Slice the outer body into logues: logues[0] is everything before the first
// inner loop, logues[k] (0 < k <= N) is what sits between inner loop k-1 and
// inner loop k, and the last entry is everything after the last inner loop up
// to, but excluding, the terminator of the outer body.
SmallVector<Logue> logues;
auto addLogue = [&](Block::iterator begin, Block::iterator end) {
logues.push_back(createLogueFrom({begin, end}, domInfo));
};
addLogue(outer.getBody()->begin(), innerLoops.front()->getIterator());
for (auto i : llvm::seq<unsigned>(0, innerLoops.size() - 1)) {
addLogue(std::next(innerLoops[i]->getIterator()),
innerLoops[i + 1]->getIterator());
}
addLogue(std::next(innerLoops.back()->getIterator()),
std::prev(outer.getBody()->end()));
// Layout of the fused loop's iter args:
//   [0]                    T, initialised to -1
//   [1]                    i, initialised to lb_i - step_i
//   [outerArgsStartIdx..)  the outer loop's own iter args
//   [ivarStartIdx..)       one induction variable per inner loop
//   [innerOutsStartIdx..)  the results of every inner loop
//   [logueOutsStartIdx..)  the outputs of every logue except the last one
// The last three groups start out as createPoisonOrZero placeholders.
SmallVector<Value> fusedInits;
fusedInits.push_back(intTyCst(-1));
fusedInits.push_back(
b.create<arith::SubIOp>(outer.getLowerBound(), outer.getStep()));
unsigned outerArgsStartIdx = fusedInits.size();
llvm::append_range(fusedInits, outer.getInits());
unsigned ivarStartIdx = fusedInits.size();
for (scf::ForOp loop : innerLoops) {
fusedInits.push_back(
createPoisonOrZero(b, loop.getInductionVar().getType()));
}
unsigned innerOutsStartIdx = fusedInits.size();
for (scf::ForOp loop : innerLoops) {
for (Type resultType : loop.getResultTypes())
fusedInits.push_back(createPoisonOrZero(b, resultType));
}
unsigned logueOutsStartIdx = fusedInits.size();
for (Logue &logue : llvm::drop_end(logues)) {
for (Type outputType : logue.getOutputTypes())
fusedInits.push_back(createPoisonOrZero(b, outputType));
}
// The fused loop counts from 0 to total_iters in steps of 1.
auto fused =
b.create<scf::ForOp>(intTyCst(0), totalIters, intTyCst(1), fusedInits);
// Redirect uses of the outer loop's iter args to the matching fused args.
for (auto [arg, fusedArg] :
llvm::zip(outer.getRegionIterArgs(),
fused.getRegionIterArgs().slice(outerArgsStartIdx))) {
arg.replaceAllUsesWith(fusedArg);
}
b.setInsertionPointToStart(fused.getBody());
// T = 0 if T == inner_len - 1 else T + 1. From here on `T` names the updated
// value, which is also what gets yielded for iter arg 0.
Value T = fused.getRegionIterArg(0);
Value nextT = b.create<arith::AddIOp>(T, intTyCst(1));
Value rollover =
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, T,
b.create<arith::SubIOp>(innerLen, intTyCst(1)));
T = b.create<arith::SelectOp>(rollover, intTyCst(0), nextT);
// `curI` is the carried outer induction value; `i` is assigned inside the
// k == 0 prologue below.
Value curI = fused.getRegionIterArg(1);
Value i;
assert(partialInnerSums.size() == N + 2);
ArrayRef<BlockArgument> ivars = fused.getRegionIterArgs().slice(ivarStartIdx);
// Cursors into the fused iter args; they advance as each inner loop / logue
// consumes its slice.
auto bodyOutsIt =
ValueRange(fused.getRegionIterArgs()).begin() + innerOutsStartIdx;
auto logueOutsIt =
ValueRange(fused.getRegionIterArgs()).begin() + logueOutsStartIdx;
SmallVector<scf::IfOp> prologueIfs, bodyIfs;
for (unsigned k = 0; k <= N; ++k) {
// start_k = partialInnerSums[k] - k; the prologue runs when T == start_k.
Value innerStartT =
b.create<arith::SubIOp>(partialInnerSums[k], intTyCst(k));
Value prologueCond =
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, T, innerStartT);
scf::ForOp inner = innerLoops[k];
Logue &prologue = logues[k];
// Results of the prologue `scf.if`: jk, the prologue outputs, the initial
// values of the inner loop's iter args and, for k == 0 only, the value of i.
SmallVector<Type> prologueOutTypes{inner.getInductionVar().getType()};
llvm::append_range(prologueOutTypes, prologue.getOutputTypes());
llvm::append_range(prologueOutTypes, inner.getInits().getTypes());
if (k == 0)
prologueOutTypes.push_back(curI.getType());
auto prologueIf = b.create<scf::IfOp>(prologueOutTypes, prologueCond);
prologueIfs.push_back(prologueIf);
// `then`: the prologue ops are moved in here.
Block *thenBlock = b.createBlock(&prologueIf.getThenRegion());
prologue.moveBefore(thenBlock, thenBlock->end());
if (k == 0) {
// i = curI + step_i is computed at the top of the first prologue and replaces
// the outer induction variable inside it.
b.setInsertionPointToStart(thenBlock);
i = b.create<arith::AddIOp>(curI, outer.getStep());
mlir::replaceAllUsesInRegionWith(outer.getInductionVar(), i,
prologueIf.getThenRegion());
}
b.setInsertionPointToEnd(thenBlock);
SmallVector<Value> thenOuts{inner.getLowerBound()};
llvm::append_range(thenOuts, prologue.getOutputs());
llvm::append_range(thenOuts, inner.getInits());
if (k == 0)
thenOuts.push_back(i);
b.create<scf::YieldOp>(thenOuts);
// `else`: forward the values carried by the fused loop's iter args.
b.createBlock(&prologueIf.getElseRegion());
Value lastJk = ivars[k];
unsigned numOuts = prologue.getNumOutputs();
SmallVector<Value> elseOuts{lastJk};
elseOuts.append(logueOutsIt, logueOutsIt + numOuts);
elseOuts.append(bodyOutsIt, bodyOutsIt + inner.getNumResults());
if (k == 0)
elseOuts.push_back(curI);
logueOutsIt += numOuts;
b.create<scf::YieldOp>(elseOuts);
// From here on the results of the prologue `scf.if` stand in for jk, the
// prologue outputs (outside the `then` region) and the inner iter args.
Value jk = prologueIf.getResult(0);
ValueRange prologueOuts = prologueIf.getResults().slice(1, numOuts);
ValueRange prologueInits =
prologueIf.getResults().slice(1 + numOuts, inner.getNumResults());
inner.getInductionVar().replaceAllUsesWith(jk);
prologue.replaceAllUsesWith(prologueOuts, prologueIf.getThenRegion());
for (auto [init, iterArg] :
llvm::zip(prologueInits, inner.getRegionIterArgs()))
iterArg.replaceAllUsesWith(init);
if (k == 0) {
// Remaining uses of the outer induction variable read the last result of the
// first prologue `scf.if`.
i = prologueIf.getResults().back();
outer.getInductionVar().replaceAllUsesWith(i);
}
// Body guard: start_k <= T < start_k + len_jk. Note that the un-clamped
// length is used here, not max(1, len_jk).
b.setInsertionPointAfter(prologueIf);
Value innerEndT = b.create<arith::AddIOp>(
innerStartT, castIntIfNecessary(b, lenInners[k], intTy));
Value ge =
b.create<arith::CmpIOp>(arith::CmpIPredicate::sge, T, innerStartT);
Value lt = b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, T, innerEndT);
Value bodyCond = b.create<arith::AndIOp>(ge, lt);
// Results of the body `scf.if`: the next jk followed by the inner results.
SmallVector<Type> bodyOutTypes{jk.getType()};
llvm::append_range(bodyOutTypes, inner->getResultTypes());
auto bodyIf = b.create<scf::IfOp>(bodyOutTypes, bodyCond);
bodyIfs.push_back(bodyIf);
// `then`: the inner loop body is moved in wholesale (its block arguments have
// all been replaced above, so they are erased first), and jk + step_jk is
// prepended to its yield.
inner.getBody()->eraseArguments([](Value arg) { return true; });
bodyIf.getThenRegion().takeBody(inner.getBodyRegion());
auto yield = getYield(bodyIf.getThenRegion());
b.setInsertionPoint(yield);
Value nextJk = b.create<arith::AddIOp>(jk, inner.getStep());
yield->insertOperands(0, nextJk);
// `else`: forward jk and the carried inner results unchanged.
b.createBlock(&bodyIf.getElseRegion());
SmallVector<Value> bodyForwardedOuts{jk};
bodyForwardedOuts.append(bodyOutsIt, bodyOutsIt + inner.getNumResults());
bodyOutsIt += inner->getNumResults();
b.create<scf::YieldOp>(bodyForwardedOuts);
inner.replaceAllUsesWith(
bodyIf.getResults().slice(1, inner.getNumResults()));
// A loop tagged with kMustExecuteAttrName gets a constant-true guard.
if (inner->hasAttr(kMustExecuteAttrName)) {
b.setInsertionPoint(bodyIf);
bodyIf.getConditionMutable().assign(
b.create<arith::ConstantOp>(b.getBoolAttr(true)));
}
b.setInsertionPointAfter(bodyIf);
}
// Epilogue: runs when T == inner_len - 1.
Logue &epilogue = logues.back();
auto outerYield = cast<scf::YieldOp>(outer.getBody()->getTerminator());
// For every use of an epilogue output by the outer yield, remember the fused
// iter arg at the same yield position; these are what the `else` branch
// forwards.
// NOTE(review): this pushes one entry per matching use, so it appears to rely
// on each epilogue output being used by the outer yield exactly once for the
// `else` yield to line up with the result types -- confirm.
SmallVector<Value> usedIterArgs;
for (Value output : epilogue.getOutputs()) {
for (OpOperand &use : output.getUses()) {
if (use.getOwner() == outerYield) {
usedIterArgs.push_back(fused.getRegionIterArgs().drop_front(
outerArgsStartIdx)[use.getOperandNumber()]);
}
}
}
auto epilogueCond =
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, T,
b.create<arith::SubIOp>(innerLen, intTyCst(1)));
auto epilogueIf =
b.create<scf::IfOp>(epilogue.getOutputTypes(), epilogueCond);
Block *thenBlock = b.createBlock(&epilogueIf.getThenRegion());
epilogue.moveBefore(thenBlock, thenBlock->end());
b.setInsertionPointToEnd(thenBlock);
b.create<scf::YieldOp>(epilogue.getOutputs());
b.createBlock(&epilogueIf.getElseRegion());
b.create<scf::YieldOp>(usedIterArgs);
epilogue.replaceAllUsesWith(epilogueIf.getResults(),
epilogueIf.getThenRegion());
// Yield of the fused loop, in the same order as the iter-arg layout: T, i,
// the outer yield operands, every jk, every inner loop's results, and the
// outputs of every prologue.
SmallVector<Value> outerOuts{T, i};
llvm::append_range(outerOuts, outerYield.getOperands());
for (scf::IfOp bodyIf : bodyIfs)
outerOuts.push_back(bodyIf.getResult(0));
for (auto [bodyIf, loop] : llvm::zip(bodyIfs, innerLoops)) {
llvm::append_range(outerOuts,
bodyIf.getResults().slice(1, loop.getNumResults()));
}
for (auto [logueIf, logue] : llvm::zip(prologueIfs, llvm::drop_end(logues))) {
llvm::append_range(outerOuts,
logueIf.getResults().slice(1, logue.getNumOutputs()));
}
b.setInsertionPointToEnd(fused.getBody());
b.create<scf::YieldOp>(outerOuts);
outer.replaceAllUsesWith(
fused.getResults().slice(outerArgsStartIdx, outer.getNumResults()));
// Post-processing of inner-loop inits. An init whose parent region is not the
// fused loop body is installed directly as the init operand of the fused
// loop; the corresponding prologue result is removed (its uses read the fused
// iter arg instead), and the value is re-established through the epilogue:
// the `then` branch yields the init again (`reset`), the `else` branch yields
// the body result (`forwarded`).
auto fusedInitsIt = fused.getInitsMutable().begin() + innerOutsStartIdx;
auto fusedArgsIt = fused.getRegionIterArgs().begin() + innerOutsStartIdx;
auto fusedYieldIt = getYield(fused.getBodyRegion())->getOpOperands().begin() +
innerOutsStartIdx;
SmallVector<OpOperand *> yieldsToUpdate;
SmallVector<Value> reset, forwarded;
// llvm::zip stops at the shortest range, so the trailing epilogue entry of
// `logues` is not visited here.
for (auto [loop, ifOp, bodyIf, prologue] :
llvm::zip(innerLoops, prologueIfs, bodyIfs, logues)) {
unsigned numResults = loop.getNumResults();
// The inner inits follow jk and the prologue outputs in the prologue results.
unsigned prologueSkip = 1 + prologue.getNumOutputs();
llvm::BitVector removeIndices(prologueSkip + numResults);
SmallVector<Value> replaceWith;
// This `i` is the enumeration index and shadows the outer `Value i`.
for (auto [i, init] : llvm::enumerate(loop.getInits())) {
if (init.getParentRegion() == &fused.getBodyRegion())
continue;
fusedInitsIt[i].assign(init);
replaceWith.push_back(fusedArgsIt[i]);
removeIndices.set(prologueSkip + i);
yieldsToUpdate.push_back(&fusedYieldIt[i]);
forwarded.push_back(bodyIf.getResult(1 + i));
reset.push_back(init);
}
eraseIfResults(b, ifOp, removeIndices, replaceWith);
fusedInitsIt += numResults;
fusedArgsIt += numResults;
fusedYieldIt += numResults;
}
if (!yieldsToUpdate.empty()) {
// Rebuild the epilogue `scf.if` with the extra results appended and make the
// recorded fused-yield operands read them.
MutableOperandRange(getYield(epilogueIf.getThenRegion())).append(reset);
MutableOperandRange(getYield(epilogueIf.getElseRegion())).append(forwarded);
b.setInsertionPoint(epilogueIf);
TypeRange newTypes = getYield(epilogueIf.getThenRegion()).getOperandTypes();
auto newIf = b.create<scf::IfOp>(newTypes, epilogueIf.getCondition());
newIf.getThenRegion().takeBody(epilogueIf.getThenRegion());
newIf.getElseRegion().takeBody(epilogueIf.getElseRegion());
epilogueIf.replaceAllUsesWith(
newIf.getResults().take_front(epilogueIf.getNumResults()));
ResultRange newResults =
newIf.getResults().drop_front(epilogueIf.getNumResults());
for (auto [i, yieldOperand] : llvm::enumerate(yieldsToUpdate))
yieldOperand->set(newResults[i]);
epilogueIf.erase();
}
// Attribute propagation: the fused loop is tagged if the outer loop or any
// fused inner loop was.
if (outer->hasAttr(kWarpSpecializeAttrName) ||
llvm::any_of(innerLoops, [](scf::ForOp loop) {
return loop->hasAttr(kWarpSpecializeAttrName);
}))
fused->setAttr(kWarpSpecializeAttrName, b.getUnitAttr());
bool disallowAccMultiBuffer = getDisallowAccMultiBuffer(outer);
for (scf::ForOp loop : innerLoops) {
disallowAccMultiBuffer |= getDisallowAccMultiBuffer(loop);
}
if (disallowAccMultiBuffer)
fused->setAttr(kDisallowAccMultiBufferAttrName, b.getUnitAttr());
// The fused loop takes the maximum stage count found on the outer loop and
// the fused inner loops; the attribute is only written when it exceeds 1.
int numStages = 1;
if (auto stageAttr = outer->getAttrOfType<IntegerAttr>(kNumStagesAttrName))
numStages = stageAttr.getInt();
for (scf::ForOp loop : innerLoops) {
if (auto stageAttr = loop->getAttrOfType<IntegerAttr>(kNumStagesAttrName))
numStages = std::max<int>(numStages, stageAttr.getInt());
// The inner loop's body and uses have been taken over; drop the shell.
loop.erase();
}
outer.erase();
parent->loop = fused;
if (numStages > 1)
fused->setAttr(kNumStagesAttrName, b.getI32IntegerAttr(numStages));
}
// Flatten the subtree rooted at `node` bottom-up: every child subtree is
// flattened first, then the (now leaf) children are fused into `node`.
static void flattenLoopNest(LoopNestNode *node, mlir::DominanceInfo &domInfo) {
  llvm::for_each(node->children, [&domInfo](LoopNestNode *nested) {
    flattenLoopNest(nested, domInfo);
  });
  fuseOneLevel(node, domInfo);
}
static bool shouldFuse(const LoopNest &nest) {
if (nest.root->loop->hasAttr(kAlwaysFuseAttrName))
return true;
return nest.nodes.size() == 2 && nest.root->children.size() == 1 &&
nest.root->loop->hasAttr(kFlattenAttr);
}
// Move ops out of `prologue` to just before `sinkBefore` in `sinkBlock`.
//
// An op is selected when it is pure, is not a DotOpInterface op, and each of
// its users either satisfies `inSinkRegion` or has itself already been
// selected. The range is scanned back to front so that users are classified
// before the ops that feed them. The selected set is then handed to
// hoistOpsBefore, which performs the actual move.
//
// `limit` is currently unused; it is kept so that callers need not change.
static void sinkOps(Region &limit, Block *sinkBlock, Block::iterator sinkBefore,
                    llvm::iterator_range<Block::iterator> prologue,
                    function_ref<bool(Operation *)> inSinkRegion) {
  llvm::SetVector<Operation *> sunkOps;

  auto canBeSunk = [&](Operation &op) {
    if (!isPure(&op) || isa<DotOpInterface>(op))
      return false;
    return llvm::all_of(op.getUsers(), [&](Operation *user) {
      return inSinkRegion(user) || sunkOps.contains(user);
    });
  };

  for (Operation &op : llvm::reverse(prologue)) {
    if (canBeSunk(op))
      sunkOps.insert(&op);
  }
  if (sunkOps.empty())
    return;

  hoistOpsBefore(sinkBlock, sinkBefore, sunkOps);
}
// Try to sink ops that precede `innerLoop` in the body of `outerLoop` to the
// position right after `innerLoop`. An op's user counts as being in the sink
// region when `innerLoop` properly dominates it (with enclosingOpOk = false).
static void optimizeEpilogueDependencies(scf::ForOp outerLoop,
                                         scf::ForOp innerLoop,
                                         mlir::DominanceInfo &domInfo) {
  Block *outerBody = outerLoop.getBody();
  Block::iterator innerPos = innerLoop->getIterator();

  auto isAfterInnerLoop = [&](Operation *candidate) {
    return domInfo.properlyDominates(innerLoop, candidate, false);
  };

  llvm::iterator_range<Block::iterator> opsBeforeInner(outerBody->begin(),
                                                       innerPos);
  sinkOps(outerLoop.getBodyRegion(), outerBody, std::next(innerPos),
          opsBeforeInner, isAfterInnerLoop);
}
// Version the loop nest on whether the inner loop executes at all.
//
// After this transformation the IR looks like
//
//   if (len_inner <= 0) { <clone of outerLoop with the inner loop removed> }
//   else                { <outerLoop, inner loop tagged kMustExecuteAttrName> }
//
// so that fuseOneLevel can drop the guard around the inner body in the `else`
// version. Fails, leaving the IR untouched, when the inner loop bounds are
// rejected by isOuterLoopInvariant.
static LogicalResult speculateInnerLoopLength(scf::ForOp outerLoop,
                                              scf::ForOp innerLoop,
                                              mlir::DominanceInfo &domInfo) {
  Location loc = innerLoop.getLoc();
  llvm::SetVector<Operation *> toHoist;
  if (!isOuterLoopInvariant(domInfo, outerLoop,
                            {innerLoop.getLowerBound(),
                             innerLoop.getUpperBound(), innerLoop.getStep()},
                            toHoist))
    return failure();

  // Make the inner bounds available in front of the outer loop.
  hoistOpsBefore(outerLoop, toHoist);

  ImplicitLocOpBuilder b(loc, outerLoop);
  innerLoop->setAttr(kMustExecuteAttrName, b.getUnitAttr());

  // The inner loop is empty when its trip count is zero *or negative*:
  // computeNumIters emits ceildivsi(ub - lb, step), which is negative when
  // ub < lb (e.g. lb = 5, ub = 3, step = 1 gives -2). Testing `== 0` would send
  // such a loop down the must-execute path and run its body once, so the
  // comparison has to be `<= 0`.
  Value lenInner = computeNumIters(b, innerLoop);
  auto zeroAttr = IntegerAttr::get(lenInner.getType(), 0);
  Value innerLoopEmpty =
      b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, lenInner,
                              b.create<arith::ConstantOp>(zeroAttr));
  auto ifOp = b.create<scf::IfOp>(outerLoop.getResultTypes(), innerLoopEmpty);

  // `then`: the inner loop does not run. Clone the nest and delete the cloned
  // inner loop, forwarding its init values in place of its results.
  mlir::IRMapping map;
  b.createBlock(&ifOp.getThenRegion());
  auto newLoop = cast<scf::ForOp>(b.clone(*outerLoop, map));
  b.create<scf::YieldOp>(newLoop.getResults());
  auto newInnerLoop = cast<scf::ForOp>(map.lookup(innerLoop));
  newInnerLoop.replaceAllUsesWith(newInnerLoop.getInits());
  newInnerLoop.erase();

  // `else`: move the original nest in. Its uses must be redirected to the
  // `scf.if` results before the yield below creates new uses of its results.
  outerLoop.replaceAllUsesWith(ifOp.getResults());
  b.createBlock(&ifOp.getElseRegion());
  outerLoop->remove();
  b.insert(outerLoop);
  b.create<scf::YieldOp>(outerLoop.getResults());
  return success();
}
// Prepare a simple two-loop nest for fusion: hoist loop-invariant code out of
// the outer loop, sink prologue ops past the inner loop where possible, and
// finally version the nest on the inner loop's trip count. The result is that
// of the last step.
static LogicalResult preprocessLoopNest(const LoopNest &nest,
                                        mlir::DominanceInfo &domInfo) {
  assert(nest.nodes.size() == 2 && nest.root->children.size() == 1);

  LoopNestNode *outerNode = nest.root;
  LoopNestNode *innerNode = outerNode->children.front();
  scf::ForOp outer = outerNode->loop;
  scf::ForOp inner = innerNode->loop;

  moveLoopInvariantCode(outer);
  optimizeEpilogueDependencies(outer, inner, domInfo);
  return speculateInnerLoopLength(outer, inner, domInfo);
}
// Pass driver. For every function: discover the loop nests, keep the ones
// accepted by shouldFuse, preprocess them (skipped for nests whose root
// carries kAlwaysFuseAttrName; a nest whose preprocessing fails is left
// alone), and flatten what remains.
void FuseNestedLoopsPass::runOnOperation() {
  auto &domInfo = getAnalysis<DominanceInfo>();

  for (auto func : getOperation().getOps<FuncOp>()) {
    SmallVector<LoopNest> nests;
    findLoopNests(func, nests);

    for (LoopNest &nest : nests) {
      if (!shouldFuse(nest))
        continue;

      const bool alwaysFuse = nest.root->loop->hasAttr(kAlwaysFuseAttrName);
      const bool readyToFuse =
          alwaysFuse || !failed(preprocessLoopNest(nest, domInfo));
      if (readyToFuse)
        flattenLoopNest(nest.root, domInfo);
    }
  }
}
}
}
}