#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/RegionGraphTraits.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
using namespace mlir;
/// Returns true if `op`, together with every operation nested inside it, only
/// uses values that are currently available.
///
/// A value is available when one of the following holds:
///  - `isOperandReady` is non-null and accepts it for `op`;
///  - it has no defining operation (a block argument);
///  - walking up from its defining op through the parent ops reaches `op`
///    before reaching any op contained in `unscheduledOps`. Values produced
///    inside `op` therefore never block it.
static bool isOpReady(Operation *op, DenseSet<Operation *> &unscheduledOps,
                      function_ref<bool(Value, Operation *)> isOperandReady) {
  auto valueIsAvailable = [&](Value value) -> bool {
    // The caller-provided hook can force a value to count as ready.
    if (isOperandReady && isOperandReady(value, op))
      return true;
    // Climb from the defining op (null for block arguments) to the root. The
    // first interesting ancestor decides the outcome.
    for (Operation *ancestor = value.getDefiningOp(); ancestor;
         ancestor = ancestor->getParentOp()) {
      if (ancestor == op)
        return true;
      if (unscheduledOps.contains(ancestor))
        return false;
    }
    // Either a block argument or no pending producer on the parent chain.
    return true;
  };

  // Every operation in the nested walk must only use available operands.
  // Stop at the first one that does not.
  WalkResult status = op->walk([&](Operation *nestedOp) -> WalkResult {
    for (Value operand : nestedOp->getOperands()) {
      if (!valueIsAvailable(operand))
        return WalkResult::interrupt();
    }
    return WalkResult::advance();
  });
  return !status.wasInterrupted();
}
/// Reorders `ops`, a contiguous range of operations inside `block`, in place.
/// After sorting, each operation appears after the operations of the range
/// that produce its operands, or that contain those producers.
///
/// `isOperandReady`, if non-null, may declare additional values ready for a
/// given user operation.
///
/// Returns true if every operation was scheduled in dependency order.
/// Returns false if some pass over the remaining operations found none
/// ready, for example because of a cycle. In that case the first unsorted
/// operation is left where it is, treated as scheduled, and sorting
/// continues with the rest.
bool mlir::sortTopologically(
    Block *block, llvm::iterator_range<Block::iterator> ops,
    function_ref<bool(Value, Operation *)> isOperandReady) {
  if (ops.empty())
    return true;
  // Every operation of the range starts out unscheduled.
  DenseSet<Operation *> unscheduledOps;
  for (Operation &op : ops)
    unscheduledOps.insert(&op);
  // Operations before `nextScheduledOp` form the sorted prefix.
  // [nextScheduledOp, end) still has to be processed.
  Block::iterator nextScheduledOp = ops.begin();
  Block::iterator end = ops.end();
  bool allOpsScheduled = true;
  while (!unscheduledOps.empty()) {
    bool scheduledAtLeastOnce = false;
    // Scan the unsorted suffix. `make_early_inc_range` advances the iterator
    // before the body runs, so moving `op` does not disturb the scan.
    for (Operation &op :
         llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) {
      if (!isOpReady(&op, unscheduledOps, isOperandReady))
        continue;
      // Schedule `op` by appending it to the end of the sorted prefix.
      unscheduledOps.erase(&op);
      op.moveBefore(block, nextScheduledOp);
      scheduledAtLeastOnce = true;
      // If `op` was itself the first unsorted op, `nextScheduledOp` still
      // points at it, so advance the boundary past it.
      if (&op == &*nextScheduledOp)
        ++nextScheduledOp;
    }
    // A full pass made no progress. Stop trying to order the first unsorted
    // op, leave it where it is, and continue with the ops after it.
    if (!scheduledAtLeastOnce) {
      allOpsScheduled = false;
      unscheduledOps.erase(&*nextScheduledOp);
      ++nextScheduledOp;
    }
  }
  return allOpsScheduled;
}
/// Sorts all operations of `block` in place. If the last operation has the
/// IsTerminator trait, it is excluded from the sort and stays where it is.
/// An empty block is trivially sorted and yields true. Otherwise the result
/// of the range-based overload is returned.
bool mlir::sortTopologically(
    Block *block, function_ref<bool(Value, Operation *)> isOperandReady) {
  if (block->empty())
    return true;

  // Keep a trailing terminator out of the range so it is never moved.
  const bool endsWithTerminator =
      block->back().hasTrait<OpTrait::IsTerminator>();
  llvm::iterator_range<Block::iterator> range =
      endsWithTerminator ? block->without_terminator()
                         : llvm::make_range(block->begin(), block->end());
  return sortTopologically(block, range, isOperandReady);
}
/// Permutes the array `ops` in place so that each operation comes after the
/// operations of the array that produce its operands, or that contain those
/// producers. The IR itself is not modified.
///
/// `isOperandReady`, if non-null, may declare additional values ready for a
/// given user operation.
///
/// Returns true if every operation was placed in dependency order. Returns
/// false if some pass found no ready operation. In that case the first
/// unplaced entry is kept at its position and treated as placed.
bool mlir::computeTopologicalSorting(
    MutableArrayRef<Operation *> ops,
    function_ref<bool(Value, Operation *)> isOperandReady) {
  if (ops.empty())
    return true;

  // Operations that have not been given a slot in the sorted prefix yet.
  DenseSet<Operation *> pending;
  pending.insert(ops.begin(), ops.end());

  // ops[0, frontier) is the sorted prefix.
  unsigned frontier = 0;
  bool gaveUpOnSome = false;
  while (!pending.empty()) {
    bool progressed = false;
    for (unsigned idx = frontier; idx < ops.size(); ++idx) {
      Operation *candidate = ops[idx];
      if (!isOpReady(candidate, pending, isOperandReady))
        continue;
      // Place the candidate at the end of the sorted prefix.
      pending.erase(candidate);
      std::swap(ops[idx], ops[frontier]);
      ++frontier;
      progressed = true;
    }
    if (progressed)
      continue;
    // Stuck: accept the first unplaced op as-is and keep going.
    gaveUpOnSome = true;
    pending.erase(ops[frontier]);
    ++frontier;
  }
  return !gaveUpOnSome;
}
/// Returns every block of `region`, ordered so that blocks reached from a
/// common starting block follow a reverse post-order (RPO) of the CFG.
///
/// Blocks are considered in region order. Each block not yet collected
/// starts a new RPO traversal over its successors, and the traversal's
/// blocks are appended. Blocks already present in the SetVector are skipped
/// on insertion. This ensures blocks unreachable from earlier traversals are
/// still included.
SetVector<Block *> mlir::getBlocksSortedByDominance(Region &region) {
  SetVector<Block *> blocks;
  for (Block &b : region) {
    // Only blocks not reached by an earlier traversal start a new one.
    if (blocks.contains(&b))
      continue;
    llvm::ReversePostOrderTraversal<Block *> traversal(&b);
    blocks.insert(traversal.begin(), traversal.end());
  }
  assert(blocks.size() == region.getBlocks().size() &&
         "some blocks are not sorted");
  return blocks;
}
namespace {
/// Orders a set of operations that may live in different blocks and regions.
/// It does so with a pre-order walk of the IR below the operations' closest
/// common ancestor region. The blocks of each visited region are walked in
/// the order returned by getBlocksSortedByDominance.
class TopoSortHelper {
public:
  /// Keeps a reference to `toSort`, so the set must outlive this helper.
  explicit TopoSortHelper(const SetVector<Operation *> &toSort)
      : toSort(toSort) {}
  /// Returns the operations of `toSort` in traversal order. This fills the
  /// internal ancestor caches as a side effect, so it is presumably meant to
  /// be called at most once per instance.
  SetVector<Operation *> sort() {
    // Zero or one operation is trivially sorted. Return a copy of the input.
    if (toSort.size() <= 1) {
      return toSort;
    }
    // Find the region to start the walk from. This also populates
    // `ancestorRegions` and `ancestorBlocks`.
    Region *rootRegion = findCommonAncestorRegion();
    assert(rootRegion && "expected all ops to have a common ancestor");
    SetVector<Operation *> result = topoSortRegion(*rootRegion);
    assert(result.size() == toSort.size() &&
           "expected all operations to be present in the result");
    return result;
  }
private:
  /// Returns the closest region that transitively contains every operation
  /// of `toSort`, or nullptr if there is none. Every region visited while
  /// walking up from each op is recorded in `ancestorRegions`. Each op's
  /// block, and the block of each visited region's parent op, is recorded
  /// in `ancestorBlocks`.
  Region *findCommonAncestorRegion() {
    // Number of operations from `toSort` found below each region.
    DenseMap<Region *, size_t> regionCounts;
    size_t expectedCount = toSort.size();
    Region *res = nullptr;
    for (Operation *op : toSort) {
      Region *current = op->getParentRegion();
      ancestorBlocks.insert(op->getBlock());
      while (current) {
        // A region whose count reaches the number of ops contains all of
        // them. The walk goes bottom-up, so the first such region is the
        // closest one and there is no need to go higher.
        if (++regionCounts[current] == expectedCount) {
          res = current;
          break;
        }
        // Record the block holding this region's parent op so the downward
        // walk will descend through it.
        ancestorBlocks.insert(current->getParentOp()->getBlock());
        current = current->getParentRegion();
      }
    }
    // Every region seen on an upward walk is an ancestor of some op.
    auto firstRange = llvm::make_first_range(regionCounts);
    ancestorRegions.insert(firstRange.begin(), firstRange.end());
    return res;
  }
  /// Walks the IR under `rootRegion` in pre-order, descending only into
  /// recorded ancestor regions and blocks. Operations of `toSort` are
  /// collected in the order they are reached.
  SetVector<Operation *> topoSortRegion(Region &rootRegion) {
    using StackT = PointerUnion<Region *, Block *, Operation *>;
    SetVector<Operation *> result;
    // Explicit work stack. Children are pushed in reverse so that they are
    // popped (and therefore visited) in forward order.
    SmallVector<StackT> stack;
    stack.push_back(&rootRegion);
    while (!stack.empty()) {
      StackT current = stack.pop_back_val();
      if (auto *region = dyn_cast<Region *>(current)) {
        // Visit blocks in dominance-sorted order, skipping blocks that do not
        // transitively contain any operation to sort.
        SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(*region);
        for (Block *block : llvm::reverse(sortedBlocks)) {
          if (ancestorBlocks.contains(block))
            stack.push_back(block);
        }
        continue;
      }
      if (auto *block = dyn_cast<Block *>(current)) {
        // Every op of an ancestor block is visited. Each one is either
        // collected or descended into below.
        for (Operation &op : llvm::reverse(*block))
          stack.push_back(&op);
        continue;
      }
      auto *op = cast<Operation *>(current);
      // Record the op before any of its nested ops (pre-order).
      if (toSort.contains(op))
        result.insert(op);
      // Descend only into subregions that lead to operations to sort.
      for (Region &subRegion : op->getRegions())
        if (ancestorRegions.contains(&subRegion))
          stack.push_back(&subRegion);
    }
    return result;
  }
  /// Operations to sort (not owned).
  const SetVector<Operation *> &toSort;
  /// Regions encountered while walking up from the operations to sort.
  DenseSet<Region *> ancestorRegions;
  /// Blocks containing an operation to sort or one of its ancestor ops.
  DenseSet<Block *> ancestorBlocks;
};
} // namespace
/// Returns a new set containing the operations of `toSort`, ordered by a
/// pre-order walk of the IR below their closest common ancestor region.
/// Inputs with at most one operation are returned unchanged (as a copy).
SetVector<Operation *>
mlir::topologicalSort(const SetVector<Operation *> &toSort) {
  TopoSortHelper helper(toSort);
  return helper.sort();
}