#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Reducer/Passes.h"
#include "mlir/Reducer/ReductionNode.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Reducer/Tester.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ManagedStatic.h"
// Defining GEN_PASS_DEF_REDUCTIONTREE before including Passes.h.inc selects
// the definitions for the ReductionTree pass; this is presumably what provides
// the impl::ReductionTreeBase class that ReductionTreePass derives from below.
namespace mlir {
#define GEN_PASS_DEF_REDUCTIONTREE
#include "mlir/Reducer/Passes.h.inc"
} // namespace mlir
using namespace mlir;
static void applyPatterns(Region ®ion,
const FrozenRewritePatternSet &patterns,
ArrayRef<ReductionNode::Range> rangeToKeep,
bool eraseOpNotInRange) {
std::vector<Operation *> opsNotInRange;
std::vector<Operation *> opsInRange;
size_t keepIndex = 0;
for (const auto &op : enumerate(region.getOps())) {
int index = op.index();
if (keepIndex < rangeToKeep.size() &&
index == rangeToKeep[keepIndex].second)
++keepIndex;
if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
opsNotInRange.push_back(&op.value());
else
opsInRange.push_back(&op.value());
}
for (Operation *op : opsInRange) {
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
(void)applyOpPatternsAndFold(op, patterns, config);
}
if (eraseOpNotInRange)
for (Operation *op : opsNotInRange) {
op->dropAllUses();
op->erase();
}
}
/// Searches the reduction tree rooted at the full op range of `region` for
/// the smallest interesting variant and applies that reduction to `region`.
///
/// \tparam IteratorType  iterator used to traverse the ReductionNode tree; it
///                       is constructed from the root and compared against
///                       `IteratorType::end()`.
/// \param module             module handed to the tester.
/// \param region             region whose operations are being reduced.
/// \param patterns           rewrite patterns applied to the kept operations.
/// \param test               tester reporting interestingness and size.
/// \param eraseOpNotInRange  forwarded to applyPatterns; when true, ops
///                           outside the kept ranges are erased.
/// \returns success() once the reduction has been applied to `region`. If the
///          incoming module is not interesting, a warning is emitted and the
///          diagnostic is returned as the result (a failure in MLIR).
///
/// Aborts via report_fatal_error if, after replaying the chosen reductions,
/// the module is no longer interesting or its size differs from the size
/// recorded on the selected node.
template <typename IteratorType>
static LogicalResult findOptimal(ModuleOp module, Region &region,
                                 const FrozenRewritePatternSet &patterns,
                                 const Tester &test, bool eraseOpNotInRange) {
  std::pair<Tester::Interestingness, size_t> initStatus =
      test.isInteresting(module);
  // Reducing only makes sense when the starting point is interesting.
  if (initStatus.first != Tester::Interestingness::True)
    return module.emitWarning() << "uninterested module will not be reduced";

  // All nodes of the tree are allocated from this allocator, which lives only
  // for the duration of this call.
  llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;

  // The root keeps every operation of the region: [0, number of ops).
  std::vector<ReductionNode::Range> ranges{
      {0, std::distance(region.op_begin(), region.op_end())}};

  ReductionNode *root = allocator.Allocate();
  new (root) ReductionNode(nullptr, ranges, allocator);
  if (failed(root->initialize(module, region)))
    llvm_unreachable("unexpected initialization failure");
  root->update(initStatus);

  ReductionNode *smallestNode = root;
  IteratorType iter(root);

  // Visit each node, reduce the node's own region, re-test the node's module,
  // and remember the smallest variant that is still interesting.
  while (iter != IteratorType::end()) {
    ReductionNode &currentNode = *iter;
    Region &curRegion = currentNode.getRegion();

    applyPatterns(curRegion, patterns, currentNode.getRanges(),
                  eraseOpNotInRange);
    currentNode.update(test.isInteresting(currentNode.getModule()));

    if (currentNode.isInteresting() == Tester::Interestingness::True &&
        currentNode.getSize() < smallestNode->getSize())
      smallestNode = &currentNode;

    ++iter;
  }

  // Collect the path from the selected node up to the root...
  SmallVector<ReductionNode *> trace;
  ReductionNode *curNode = smallestNode;
  trace.push_back(curNode);
  while (curNode != root) {
    curNode = curNode->getParent();
    trace.push_back(curNode);
  }

  // ...and replay it root-first on the real region.
  while (!trace.empty()) {
    ReductionNode *top = trace.pop_back_val();
    applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
  }

  // Sanity-check the result. The tester is queried once and both fields of
  // its answer are checked, instead of running it a second time for the size.
  std::pair<Tester::Interestingness, size_t> finalStatus =
      test.isInteresting(module);
  if (finalStatus.first != Tester::Interestingness::True)
    llvm::report_fatal_error("Reduced module is not interesting");
  if (finalStatus.second != smallestNode->getSize())
    llvm::report_fatal_error(
        "Reduced module doesn't have consistent size with smallestNode");
  return success();
}
/// Runs the reduction search on `region` in two phases.
///
/// The first phase uses an empty pattern set and erases operations outside
/// the kept ranges. The second phase applies `patterns` to the kept
/// operations without erasing anything. If the first phase fails, failure()
/// is returned and the second phase is skipped; otherwise the result of the
/// second phase is returned.
///
/// \tparam IteratorType  tree traversal iterator forwarded to the
///                       five-argument findOptimal overload.
template <typename IteratorType>
static LogicalResult findOptimal(ModuleOp module, Region &region,
                                 const FrozenRewritePatternSet &patterns,
                                 const Tester &test) {
  // Phase 1: no rewrite patterns, only erase ops that are not kept.
  if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
                                       /*eraseOpNotInRange=*/true)))
    return failure();

  // Phase 2: apply the reduction patterns, keep every op.
  return findOptimal<IteratorType>(module, region, patterns, test,
                                   /*eraseOpNotInRange=*/false);
}
namespace {
/// Collection of the DialectReductionPatternInterface implementations that
/// can be iterated to gather every reduction pattern they provide.
class ReductionPatternInterfaceCollection
    : public DialectInterfaceCollection<DialectReductionPatternInterface> {
public:
  using Base::Base;

  /// Asks each interface in this collection to add its reduction patterns to
  /// `pattern`.
  void populateReductionPatterns(RewritePatternSet &pattern) const {
    for (const auto &dialectInterface : *this) {
      dialectInterface.populateReductionPatterns(pattern);
    }
  }
};
/// Pass that reduces the regions of the operation it runs on, using a tester
/// to decide which reduced variants are still interesting.
class ReductionTreePass : public impl::ReductionTreeBase<ReductionTreePass> {
public:
  ReductionTreePass() = default;
  ReductionTreePass(const ReductionTreePass &pass) = default;

  /// Collects the reduction patterns and stores them in `reducerPatterns`.
  LogicalResult initialize(MLIRContext *context) override;

  /// Reduces every non-empty region reachable from the current operation.
  void runOnOperation() override;

private:
  /// Reduces a single `region`; `module` is the module given to the tester.
  LogicalResult reduceOp(ModuleOp module, Region &region);

  /// Frozen set of reduction patterns, assigned in initialize().
  FrozenRewritePatternSet reducerPatterns;
};
}
/// Gathers the reduction patterns exposed through
/// ReductionPatternInterfaceCollection for `context` and freezes them into
/// `reducerPatterns`. Always returns success().
LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
  RewritePatternSet collected(context);

  // The collection is only needed to populate the pattern set.
  ReductionPatternInterfaceCollection interfaces(context);
  interfaces.populateReductionPatterns(collected);

  reducerPatterns = std::move(collected);
  return success();
}
void ReductionTreePass::runOnOperation() {
Operation *topOperation = getOperation();
while (topOperation->getParentOp() != nullptr)
topOperation = topOperation->getParentOp();
ModuleOp module = dyn_cast<ModuleOp>(topOperation);
if (!module) {
emitError(getOperation()->getLoc())
<< "top-level op must be 'builtin.module'";
return signalPassFailure();
}
SmallVector<Operation *, 8> workList;
workList.push_back(getOperation());
do {
Operation *op = workList.pop_back_val();
for (Region ®ion : op->getRegions())
if (!region.empty())
if (failed(reduceOp(module, region)))
return signalPassFailure();
for (Region ®ion : op->getRegions())
for (Operation &op : region.getOps())
if (op.getNumRegions() != 0)
workList.push_back(&op);
} while (!workList.empty());
}
/// Reduces `region` using the traversal selected by `traversalModeId`.
///
/// A Tester is built from `testerName` and `testerArgs` (presumably pass
/// options declared on the generated base class; confirm in the pass
/// definition). Only TraversalMode::SinglePath is handled; any other mode
/// emits an error on `module` and returns that diagnostic as the result.
LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
  Tester test(testerName, testerArgs);
  switch (traversalModeId) {
  case TraversalMode::SinglePath:
    return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
        module, region, reducerPatterns, test);
  default:
    return module.emitError() << "unsupported traversal mode detected";
  }
}
/// Creates a new, default-constructed ReductionTreePass.
std::unique_ptr<Pass> mlir::createReductionTreePass() {
  std::unique_ptr<Pass> pass = std::make_unique<ReductionTreePass>();
  return pass;
}