#include "mlir/Transforms/Passes.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"
namespace mlir {
// Pull in the generated definition for the SCCP pass; this provides the
// impl::SCCPBase<> base class that the pass struct below derives from.
#define GEN_PASS_DEF_SCCP
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir
using namespace mlir;
using namespace mlir::dataflow;
/// Try to replace every use of `value` with a materialized constant.
///
/// The constant is taken from the `Lattice<ConstantValue>` state the solver
/// computed for `value`. The constant op is produced through `folder`, which
/// uniques constants and is given the builder's current insertion block as
/// the requested location.
///
/// \returns success() if all uses of `value` were redirected to a constant;
///          failure() if the solver has no state for `value`, the state is
///          uninitialized or not a concrete constant, or the folder could not
///          materialize a constant of `value`'s type.
static LogicalResult replaceWithConstant(DataFlowSolver &solver,
                                         OpBuilder &builder,
                                         OperationFolder &folder, Value value) {
  const auto *state = solver.lookupState<Lattice<ConstantValue>>(value);
  if (!state)
    return failure();

  // Only an initialized lattice element holding an actual constant attribute
  // can be materialized.
  const ConstantValue &known = state->getValue();
  if (known.isUninitialized() || !known.getConstantValue())
    return failure();

  Value replacement = folder.getOrCreateConstant(
      builder.getInsertionBlock(), known.getConstantDialect(),
      known.getConstantValue(), value.getType());
  if (!replacement)
    return failure();

  value.replaceAllUsesWith(replacement);
  return success();
}
/// Replace values proven constant by `solver` with materialized constants.
///
/// Every block of `initialRegions` is visited. Regions nested in operations
/// that survive the rewrite are visited too. For each operation, all results
/// are offered to replaceWithConstant(). If every result was replaced and the
/// operation would be trivially dead, it is erased. Its nested regions are
/// then not visited. After a block's operations are processed, its block
/// arguments are also offered for replacement, but the block is left intact.
///
/// \param solver          Solver that has already run constant propagation.
/// \param context         Context used to build the constant folder/builder.
/// \param initialRegions  Top-level regions to rewrite.
static void rewrite(DataFlowSolver &solver, MLIRContext *context,
                    MutableArrayRef<Region> initialRegions) {
  SmallVector<Block *> worklist;
  // Blocks are pushed in reverse so that popping from the back visits the
  // blocks of each region in their original order.
  auto addToWorklist = [&](MutableArrayRef<Region> regions) {
    for (Region &region : regions)
      for (Block &block : llvm::reverse(region))
        worklist.push_back(&block);
  };

  // The folder creates (and uniques) the constants used as replacements.
  OperationFolder folder(context);
  OpBuilder builder(context);

  addToWorklist(initialRegions);
  while (!worklist.empty()) {
    Block *block = worklist.pop_back_val();

    // Early-increment iteration so that `op` may be erased inside the loop.
    for (Operation &op : llvm::make_early_inc_range(*block)) {
      builder.setInsertionPoint(&op);

      // Attempt to replace every result. Use `&=` rather than `&&` so that
      // later results are still replaced after an earlier one fails.
      bool replacedAll = op.getNumResults() != 0;
      for (Value res : op.getResults())
        replacedAll &=
            succeeded(replaceWithConstant(solver, builder, folder, res));

      // If every result now comes from a constant and the op has no effects
      // preventing removal, drop it entirely (including its regions).
      if (replacedAll && wouldOpBeTriviallyDead(&op)) {
        assert(op.use_empty() && "expected all uses to be replaced");
        op.erase();
        continue;
      }

      // Otherwise descend into any regions the operation holds.
      addToWorklist(op.getRegions());
    }

    // Block arguments cannot be erased here, but their uses can still be
    // redirected to constants; a failure to do so is not an error.
    builder.setInsertionPointToStart(block);
    for (BlockArgument arg : block->getArguments())
      (void)replaceWithConstant(solver, builder, folder, arg);
  }
}
namespace {
/// The SCCP pass: runs dead-code analysis together with sparse constant
/// propagation on the target operation, then folds the values proven to be
/// constant (see SCCP::runOnOperation).
struct SCCP : impl::SCCPBase<SCCP> {
  void runOnOperation() override;
};
} // namespace
void SCCP::runOnOperation() {
Operation *op = getOperation();
DataFlowSolver solver;
solver.load<DeadCodeAnalysis>();
solver.load<SparseConstantPropagation>();
if (failed(solver.initializeAndRun(op)))
return signalPassFailure();
rewrite(solver, op->getContext(), op->getRegions());
}
/// Create a new instance of the SCCP pass, owned by the caller.
std::unique_ptr<Pass> mlir::createSCCPPass() {
  auto pass = std::make_unique<SCCP>();
  return pass;
}