#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/LivenessAnalysis.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include <cassert>
#include <cstddef>
#include <memory>
#include <optional>
#include <vector>
namespace mlir {
#define GEN_PASS_DEF_REMOVEDEADVALUES
#include "mlir/Transforms/Passes.h.inc"
}
using namespace mlir;
using namespace mlir::dataflow;
namespace {
/// Returns true if at least one non-null value in `values` may be live.
///
/// Null entries are ignored. A value for which the analysis has no liveness
/// state is conservatively treated as live.
static bool hasLive(ValueRange values, RunLivenessAnalysis &la) {
  return llvm::any_of(values, [&](Value candidate) {
    if (!candidate)
      return false;
    const Liveness *state = la.getLiveness(candidate);
    return state == nullptr || state->isLive;
  });
}
/// Returns a bit vector with one bit per entry of `values`.
///
/// A bit is set when the value may be live. It is cleared when the value is
/// null or the analysis reports it as not live. A value without liveness state
/// is conservatively reported as live.
static BitVector markLives(ValueRange values, RunLivenessAnalysis &la) {
  BitVector liveBits(values.size(), false);
  for (unsigned idx = 0, size = values.size(); idx < size; ++idx) {
    Value candidate = values[idx];
    if (!candidate)
      continue;
    const Liveness *state = la.getLiveness(candidate);
    bool knownDead = state && !state->isLive;
    if (!knownDead)
      liveBits.set(idx);
  }
  return liveBits;
}
static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
assert(op->getNumResults() == toErase.size() &&
"expected the number of results in `op` and the size of `toErase` to "
"be the same");
std::vector<Type> newResultTypes;
for (OpResult result : op->getResults())
if (!toErase[result.getResultNumber()])
newResultTypes.push_back(result.getType());
OpBuilder builder(op);
builder.setInsertionPointAfter(op);
OperationState state(op->getLoc(), op->getName().getStringRef(),
op->getOperands(), newResultTypes, op->getAttrs());
for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
state.addRegion();
Operation *newOp = builder.create(state);
for (const auto &[index, region] : llvm::enumerate(op->getRegions())) {
Region &newRegion = newOp->getRegion(index);
Block *temp = new Block();
newRegion.push_back(temp);
while (!region.empty())
region.front().moveBefore(temp);
temp->erase();
}
unsigned indexOfNextNewCallOpResultToReplace = 0;
for (auto [index, result] : llvm::enumerate(op->getResults())) {
assert(result && "expected result to be non-null");
if (toErase[index]) {
result.dropAllUses();
} else {
result.replaceAllUsesWith(
newOp->getResult(indexOfNextNewCallOpResultToReplace++));
}
}
op->erase();
}
/// Returns, in order, pointers to the `OpOperand`s backing `operands`.
static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
  SmallVector<OpOperand *> result;
  OpOperand *first = operands.getBase();
  OpOperand *last = first + operands.size();
  for (OpOperand *it = first; it != last; ++it)
    result.push_back(it);
  return result;
}
/// Erases `op` when it is free of memory effects and none of its results may
/// be live. Any remaining uses of its results are dropped first.
static void cleanSimpleOp(Operation *op, RunLivenessAnalysis &la) {
  bool removable = isMemoryEffectFree(op) && !hasLive(op->getResults(), la);
  if (removable) {
    op->dropAllUses();
    op->erase();
  }
}
/// Removes dead arguments and dead return values of the private function
/// `funcOp`, and updates all of its call sites inside `module` to match.
///
/// - An argument is removed when liveness analysis reports it as not live. The
///   corresponding argument operand is removed from every call site.
/// - A return value is removed when the corresponding call result is not live
///   at every call site. It is removed from every terminator that has the
///   expected number of operands, from the function's result list, and from
///   every call site.
///
/// The function is left untouched in these cases:
/// - it is public, so callers may be invisible;
/// - it has no body;
/// - its symbol uses cannot be computed;
/// - some use is not a call, which could not be updated consistently.
/// All of these checks happen before any IR is modified.
static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
                        RunLivenessAnalysis &la) {
  if (funcOp.isPublic() || funcOp.isExternal())
    return;

  // Collect and validate every use up front so that bailing out never leaves
  // the IR half-updated.
  std::optional<SymbolTable::UseRange> maybeUses =
      funcOp.getSymbolUses(module);
  if (!maybeUses)
    return;
  SymbolTable::UseRange uses = *maybeUses;
  bool allUsesAreCalls = llvm::all_of(uses, [](SymbolTable::SymbolUse use) {
    return isa<CallOpInterface>(use.getUser());
  });
  if (!allUsesAreCalls)
    return;

  // Remove dead arguments from the function...
  SmallVector<Value> arguments(funcOp.getArguments());
  BitVector nonLiveArgs = markLives(arguments, la);
  nonLiveArgs.flip();
  for (auto [index, arg] : llvm::enumerate(arguments))
    if (arg && nonLiveArgs[index])
      arg.dropAllUses();
  funcOp.eraseArguments(nonLiveArgs);

  // ...and the matching argument operands from every call site. Argument
  // operands may be a sub-range of the call's operands, so map each argument
  // index to the call's operand number.
  for (SymbolTable::SymbolUse use : uses) {
    auto callOp = cast<CallOpInterface>(use.getUser());
    SmallVector<OpOperand *> callOpOperands =
        operandsToOpOperands(callOp.getArgOperands());
    BitVector nonLiveCallOperands(callOp->getNumOperands(), false);
    for (unsigned index : nonLiveArgs.set_bits())
      nonLiveCallOperands.set(callOpOperands[index]->getOperandNumber());
    callOp->eraseOperands(nonLiveCallOperands);
  }

  // A return value can only be removed if its result is dead at every call
  // site.
  Operation *lastReturnOp = funcOp.back().getTerminator();
  size_t numReturns = lastReturnOp->getNumOperands();
  BitVector nonLiveRets(numReturns, true);
  for (SymbolTable::SymbolUse use : uses) {
    BitVector liveCallRets = markLives(use.getUser()->getResults(), la);
    nonLiveRets &= liveCallRets.flip();
  }

  for (Block &block : funcOp.getBlocks()) {
    Operation *returnOp = block.getTerminator();
    if (returnOp && returnOp->getNumOperands() == numReturns)
      returnOp->eraseOperands(nonLiveRets);
  }
  funcOp.eraseResults(nonLiveRets);

  for (SymbolTable::SymbolUse use : uses)
    dropUsesAndEraseResults(use.getUser(), nonLiveRets);
}
/// Removes dead values flowing through the region-branch op `regionBranchOp`.
///
/// If the op is free of memory effects and none of its results may be live,
/// the whole op is erased. Otherwise this function decides which of the
/// following to keep, and erases the rest:
///   - results of the op,
///   - arguments of the entry block of each (non-empty) region,
///   - operands of the op, and
///   - operands of the terminator of each region's entry block.
///
/// Results and entry-block arguments start out kept iff liveness analysis
/// reports that they may be live. Operands of the op or of a terminator start
/// out kept iff they are not forwarded to any successor. Their role is then
/// unknown to this function, so they are never removed. The two sets are then
/// grown to a fixpoint:
///   - a forwarded operand is kept if a successor input it is forwarded to is
///     kept;
///   - a successor input is kept if an operand forwarded to it is kept.
/// This keeps forwarded operands and successor inputs consistent after
/// erasure.
///
/// NOTE(review): only the terminator of each region's entry block is
/// inspected, i.e. regions are assumed to have a single block. This is expected
/// to hold because the pass rejects IR containing branch ops.
static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
                                RunLivenessAnalysis &la) {
  if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
      !hasLive(regionBranchOp->getResults(), la)) {
    regionBranchOp->dropAllUses();
    regionBranchOp->erase();
    return;
  }

  // Which results of the op, entry-block arguments (per region), operands of
  // the op, and terminator operands (per entry-block terminator) to keep.
  BitVector resultsToKeep;
  DenseMap<Region *, BitVector> argsToKeep;
  BitVector operandsToKeep;
  DenseMap<Operation *, BitVector> terminatorOperandsToKeep;

  // Returns the terminator of the entry block of the non-empty `region`. All
  // per-terminator bookkeeping is keyed by this op.
  auto entryTerminator = [](Region &region) {
    return region.front().getTerminator();
  };

  // Returns the successors reachable from `region`, or the entry successors
  // of the op when `region` is null.
  auto getSuccessors = [&](Region *region) {
    auto point = region ? region : RegionBranchPoint::parent();
    SmallVector<RegionSuccessor> successors;
    regionBranchOp.getSuccessorRegions(point, successors);
    return successors;
  };

  // Returns the operands forwarded to `successor` by `terminator`, or by the
  // op itself when `terminator` is null.
  auto getForwardedOpOperands = [&](const RegionSuccessor &successor,
                                    Operation *terminator) {
    OperandRange operands =
        terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
                         .getSuccessorOperands(successor)
                   : regionBranchOp.getEntrySuccessorOperands(successor);
    return operandsToOpOperands(operands);
  };

  // A successor input is an entry-block argument of `successorRegion`, or a
  // result of the op when `successorRegion` is null.
  auto isInputKept = [&](Region *successorRegion, Value input) -> bool {
    if (successorRegion)
      return argsToKeep[successorRegion]
                       [cast<BlockArgument>(input).getArgNumber()];
    return resultsToKeep[cast<OpResult>(input).getResultNumber()];
  };
  auto keepInput = [&](Region *successorRegion, Value input) {
    if (successorRegion)
      argsToKeep[successorRegion].set(
          cast<BlockArgument>(input).getArgNumber());
    else
      resultsToKeep.set(cast<OpResult>(input).getResultNumber());
  };

  // Sets the bit in `operandBits` of every operand that the entry-block
  // terminator of `region` forwards to a kept successor input. When `region`
  // is null, the operands forwarded by the op itself are used instead.
  auto keepOperandsFeedingKeptInputs = [&](BitVector &operandBits,
                                           Region *region) {
    Operation *terminator = region ? entryTerminator(*region) : nullptr;
    for (const RegionSuccessor &successor : getSuccessors(region))
      for (auto [opOperand, input] :
           llvm::zip(getForwardedOpOperands(successor, terminator),
                     successor.getSuccessorInputs()))
        if (isInputKept(successor.getSuccessor(), input))
          operandBits.set(opOperand->getOperandNumber());
  };

  // Marks as kept every successor input that receives an operand whose bit is
  // set in `operandBits`. The operands are those forwarded by the entry-block
  // terminator of `region`, or by the op itself when `region` is null.
  // Returns true if any input was newly marked.
  auto keepInputsFedByKeptOperands = [&](const BitVector &operandBits,
                                         Region *region) {
    bool changed = false;
    Operation *terminator = region ? entryTerminator(*region) : nullptr;
    for (const RegionSuccessor &successor : getSuccessors(region)) {
      Region *successorRegion = successor.getSuccessor();
      for (auto [opOperand, input] :
           llvm::zip(getForwardedOpOperands(successor, terminator),
                     successor.getSuccessorInputs())) {
        if (!operandBits[opOperand->getOperandNumber()] ||
            isInputKept(successorRegion, input))
          continue;
        keepInput(successorRegion, input);
        changed = true;
      }
    }
    return changed;
  };

  // Results and entry-block arguments start out kept iff they may be live.
  // Empty regions (allowed by some region-branch ops) have nothing to track.
  resultsToKeep = markLives(regionBranchOp->getResults(), la);
  for (Region &region : regionBranchOp->getRegions()) {
    if (region.empty())
      continue;
    SmallVector<Value> arguments(region.front().getArguments());
    argsToKeep[&region] = markLives(arguments, la);
  }

  // Operands start out kept iff they are not forwarded to any successor.
  operandsToKeep.resize(regionBranchOp->getNumOperands(), true);
  for (const RegionSuccessor &successor : getSuccessors(nullptr))
    for (OpOperand *opOperand : getForwardedOpOperands(successor, nullptr))
      operandsToKeep.reset(opOperand->getOperandNumber());

  for (Region &region : regionBranchOp->getRegions()) {
    if (region.empty())
      continue;
    Operation *terminator = entryTerminator(region);
    terminatorOperandsToKeep[terminator] =
        BitVector(terminator->getNumOperands(), true);
    for (const RegionSuccessor &successor : getSuccessors(&region))
      for (OpOperand *opOperand : getForwardedOpOperands(successor, terminator))
        terminatorOperandsToKeep[terminator].reset(
            opOperand->getOperandNumber());
  }

  // Grow the kept sets until neither operands nor successor inputs change.
  bool changed = true;
  while (changed) {
    keepOperandsFeedingKeptInputs(operandsToKeep, nullptr);
    for (Region &region : regionBranchOp->getRegions())
      if (!region.empty())
        keepOperandsFeedingKeptInputs(
            terminatorOperandsToKeep[entryTerminator(region)], &region);

    changed = keepInputsFedByKeptOperands(operandsToKeep, nullptr);
    for (Region &region : regionBranchOp->getRegions())
      if (!region.empty())
        changed |= keepInputsFedByKeptOperands(
            terminatorOperandsToKeep[entryTerminator(region)], &region);
  }

  // Erase everything that is not kept.
  regionBranchOp->eraseOperands(operandsToKeep.flip());

  for (Region &region : regionBranchOp->getRegions()) {
    if (region.empty())
      continue;
    Block &entry = region.front();
    BitVector &regionArgsToKeep = argsToKeep[&region];
    for (auto [index, arg] : llvm::enumerate(entry.getArguments()))
      if (!regionArgsToKeep[index] && arg)
        arg.dropAllUses();
    entry.eraseArguments(regionArgsToKeep.flip());
  }

  for (Region &region : regionBranchOp->getRegions()) {
    if (region.empty())
      continue;
    Operation *terminator = entryTerminator(region);
    terminator->eraseOperands(terminatorOperandsToKeep[terminator].flip());
  }

  dropUsesAndEraseResults(regionBranchOp.getOperation(), resultsToKeep.flip());
}
/// Pass that removes values which liveness analysis reports as not live, along
/// with ops, arguments, operands and results that only carry such values.
/// `runOnOperation` is defined out of line, after this anonymous namespace.
struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
void runOnOperation() override;
};
}
void RemoveDeadValues::runOnOperation() {
auto &la = getAnalysis<RunLivenessAnalysis>();
Operation *module = getOperation();
WalkResult acceptableIR = module->walk([&](Operation *op) {
if (isa<BranchOpInterface>(op) ||
(isa<SymbolOpInterface>(op) && !isa<FunctionOpInterface>(op)) ||
(isa<SymbolUserOpInterface>(op) && !isa<CallOpInterface>(op))) {
op->emitError() << "cannot optimize an IR with non-function symbol ops, "
"non-call symbol user ops or branch ops\n";
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (acceptableIR.wasInterrupted())
return;
module->walk([&](Operation *op) {
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
cleanFuncOp(funcOp, module, la);
} else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
cleanRegionBranchOp(regionBranchOp, la);
} else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
} else if (isa<CallOpInterface>(op)) {
} else {
cleanSimpleOp(op, la);
}
});
}
/// Creates a new instance of the dead-value removal pass.
std::unique_ptr<Pass> mlir::createRemoveDeadValuesPass() {
  std::unique_ptr<Pass> pass = std::make_unique<RemoveDeadValues>();
  return pass;
}