#include "mlir/Analysis/Liveness.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
namespace {
/// Per-block bookkeeping used while computing liveness.
///
/// The constructor records, for `block`:
///  - defValues: values defined in the block (block arguments, results of all
///    operations walked inside it, and arguments of blocks nested in it),
///  - useValues: operands used in the block (including nested regions) that
///    are not defined in it,
///  - outValues: values defined by the block (arguments / top-level op
///    results) that have a user whose ancestor block in this region differs
///    from `block`.
/// `updateLiveIn` / `updateLiveOut` are then iterated by `buildBlockMapping`
/// until no live-in set changes anymore.
struct BlockInfoBuilder {
  using ValueSetT = Liveness::ValueSetT;

  BlockInfoBuilder() = default;

  /// Collects the def/use/out sets of `block`.
  explicit BlockInfoBuilder(Block *block) : block(block) {
    // Adds `value` to outValues as soon as one user lives in a different
    // block of this region. Users inside regions nested in `block` are mapped
    // back to `block` by findAncestorBlockInRegion, so they do not make the
    // value escape.
    auto gatherOutValues = [&](Value value) {
      for (Operation *useOp : value.getUsers()) {
        Block *ownerBlock = useOp->getBlock();
        ownerBlock = block->getParent()->findAncestorBlockInRegion(*ownerBlock);
        assert(ownerBlock && "Use leaves the current parent region");
        if (ownerBlock != block) {
          outValues.insert(value);
          break;
        }
      }
    };

    // Block arguments are defined here and may escape the block.
    for (BlockArgument argument : block->getArguments()) {
      defValues.insert(argument);
      gatherOutValues(argument);
    }

    // Results of the block's top-level operations may escape the block.
    for (Operation &operation : *block)
      for (Value result : operation.getResults())
        gatherOutValues(result);

    // Walk every operation in the block, including nested regions: results
    // and nested block arguments count as definitions, operands as uses.
    block->walk([&](Operation *op) {
      for (Value result : op->getResults())
        defValues.insert(result);
      for (Value operand : op->getOperands())
        useValues.insert(operand);
      for (Region &region : op->getRegions())
        for (Block &child : region.getBlocks())
          for (BlockArgument arg : child.getArguments())
            defValues.insert(arg);
    });

    // Only uses of values defined outside this block are interesting.
    llvm::set_subtract(useValues, defValues);
  }

  /// Recomputes inValues = (useValues union outValues) minus defValues.
  /// Returns true if the live-in set changed.
  bool updateLiveIn() {
    ValueSetT newIn = useValues;
    llvm::set_union(newIn, outValues);
    llvm::set_subtract(newIn, defValues);

    // Comparing sizes is sufficient: useValues and defValues are fixed after
    // construction and outValues only ever grows (updateLiveOut unions into
    // it), so the recomputed live-in set is always a superset of the previous
    // one.
    if (newIn.size() == inValues.size())
      return false;

    inValues = std::move(newIn);
    return true;
  }

  /// Merges the live-in sets of all successors of `block` into outValues.
  void updateLiveOut(const DenseMap<Block *, BlockInfoBuilder> &builders) {
    for (Block *succ : block->getSuccessors()) {
      auto it = builders.find(succ);
      assert(it != builders.end() && "successor block has no liveness info");
      llvm::set_union(outValues, it->second.inValues);
    }
  }

  /// The block described by this builder.
  Block *block{nullptr};
  /// Values live on entry to the block.
  ValueSetT inValues;
  /// Values live on exit from the block.
  ValueSetT outValues;
  /// Values defined in the block (see struct comment).
  ValueSetT defValues;
  /// Values used in the block but defined outside of it.
  ValueSetT useValues;
};
} // namespace
/// Computes a BlockInfoBuilder for every block nested (at any depth) under
/// `operation` and stores it in `builders`.
///
/// Every block is first seeded in pre-order. Whenever a block's live-in set
/// grows, its predecessors are queued so that their live-out and live-in sets
/// get recomputed. This repeats until the worklist is empty (fixpoint).
static void buildBlockMapping(Operation *operation,
                              DenseMap<Block *, BlockInfoBuilder> &builders) {
  SetVector<Block *> worklist;

  // Queues every predecessor of `block` for (re)processing.
  auto enqueuePredecessors = [&worklist](Block *block) {
    worklist.insert(block->pred_begin(), block->pred_end());
  };

  operation->walk<WalkOrder::PreOrder>([&](Block *block) {
    auto inserted = builders.try_emplace(block, block);
    BlockInfoBuilder &info = inserted.first->second;
    if (info.updateLiveIn())
      enqueuePredecessors(block);
  });

  while (!worklist.empty()) {
    Block *block = worklist.pop_back_val();
    BlockInfoBuilder &info = builders[block];
    info.updateLiveOut(builders);
    if (info.updateLiveIn())
      enqueuePredecessors(block);
  }
}
/// Creates the analysis for `op` and computes all liveness information
/// immediately.
Liveness::Liveness(Operation *op) : operation(op) {
  build();
}
void Liveness::build() {
DenseMap<Block *, BlockInfoBuilder> builders;
buildBlockMapping(operation, builders);
for (auto &entry : builders) {
BlockInfoBuilder &builder = entry.second;
LivenessBlockInfo &info = blockMapping[entry.first];
info.block = builder.block;
info.inValues = std::move(builder.inValues);
info.outValues = std::move(builder.outValues);
}
}
/// Returns all operations at which `value` is live.
///
/// The search starts from the block that defines `value` and every block
/// containing a use of it. For each visited block, the in-block range from
/// getStartOperation to getEndOperation is appended. Successors in which
/// `value` is live-in are then visited. Each block is visited at most once.
Liveness::OperationListT Liveness::resolveLiveness(Value value) const {
  OperationListT liveOps;
  SmallPtrSet<Block *, 32> seen;
  SmallVector<Block *, 8> worklist;

  // Queues `block` unless it has already been queued.
  auto enqueue = [&](Block *block) {
    if (seen.insert(block).second)
      worklist.push_back(block);
  };

  Operation *definingOp = value.getDefiningOp();
  enqueue(definingOp ? definingOp->getBlock()
                     : cast<BlockArgument>(value).getOwner());
  for (OpOperand &use : value.getUses())
    enqueue(use.getOwner()->getBlock());

  while (!worklist.empty()) {
    Block *block = worklist.pop_back_val();
    const LivenessBlockInfo *info = getLiveness(block);

    // Both endpoints lie in `block`; record every operation between them.
    Operation *first = info->getStartOperation(value);
    Operation *last = info->getEndOperation(value, first);
    for (Operation *current = first;; current = current->getNextNode()) {
      liveOps.push_back(current);
      if (current == last)
        break;
    }

    for (Block *successor : block->getSuccessors())
      if (getLiveness(successor)->isLiveIn(value))
        enqueue(successor);
  }

  return liveOps;
}
/// Returns the liveness information computed for `block`, or nullptr if the
/// block was not part of the analyzed operation.
const LivenessBlockInfo *Liveness::getLiveness(Block *block) const {
  auto found = blockMapping.find(block);
  if (found == blockMapping.end())
    return nullptr;
  return &found->second;
}
/// Returns the set of values live on entry to `block`. The liveness info of
/// `block` is dereferenced without a null check, so `block` must have been
/// analyzed.
const Liveness::ValueSetT &Liveness::getLiveIn(Block *block) const {
  const LivenessBlockInfo *info = getLiveness(block);
  return info->in();
}
/// Returns the set of values live on exit from `block`. The liveness info of
/// `block` is dereferenced without a null check, so `block` must have been
/// analyzed.
const Liveness::ValueSetT &Liveness::getLiveOut(Block *block) const {
  const LivenessBlockInfo *info = getLiveness(block);
  return info->out();
}
/// Returns true if `value` is not live after `operation`. That holds when the
/// value does not escape `operation`'s block and no user of it (mapped to its
/// ancestor in that block) comes after `operation`.
bool Liveness::isDeadAfter(Value value, Operation *operation) const {
  const LivenessBlockInfo *info = getLiveness(operation->getBlock());

  // A value that is live-out stays live past every operation in the block.
  if (info->isLiveOut(value))
    return false;

  Operation *lastLive = info->getEndOperation(value, operation);
  if (lastLive == operation)
    return true;
  return lastLive->isBeforeInBlock(operation);
}
void Liveness::dump() const { print(llvm::errs()); }
/// Prints a textual dump of the liveness information of every block nested
/// under `operation`. For each block it prints:
///  - its live-in and live-out sets,
///  - for every operation result, the operations returned by resolveLiveness
///    (its live interval),
///  - for every operation, the values currently live at it.
/// Blocks, operations and values are numbered in pre-order walk order. Block
/// arguments are printed as `arg<argNumber>@<blockId>`, operation results as
/// `val_<valueId>`.
void Liveness::print(raw_ostream &os) const {
os << "// ---- Liveness -----\n";
// Walk-order ids used for naming and for sorting values/operations below, so
// the printed order follows the walk rather than set iteration order.
DenseMap<Block *, size_t> blockIds;
DenseMap<Operation *, size_t> operationIds;
DenseMap<Value, size_t> valueIds;
// Note: the lambda parameter `operation` below shadows the member
// `operation`; inside the loop it refers to the block's operation.
operation->walk<WalkOrder::PreOrder>([&](Block *block) {
blockIds.insert({block, blockIds.size()});
for (BlockArgument argument : block->getArguments())
valueIds.insert({argument, valueIds.size()});
for (Operation &operation : *block) {
operationIds.insert({&operation, operationIds.size()});
for (Value result : operation.getResults())
valueIds.insert({result, valueIds.size()});
}
});
// Prints a single value reference followed by one trailing space.
auto printValueRef = [&](Value value) {
if (value.getDefiningOp())
os << "val_" << valueIds[value];
else {
auto blockArg = cast<BlockArgument>(value);
os << "arg" << blockArg.getArgNumber() << "@"
<< blockIds[blockArg.getOwner()];
}
os << " ";
};
// Prints all values of a set, ordered by their walk-order value id.
auto printValueRefs = [&](const ValueSetT &values) {
std::vector<Value> orderedValues(values.begin(), values.end());
llvm::sort(orderedValues, [&](Value left, Value right) {
return valueIds[left] < valueIds[right];
});
for (Value value : orderedValues)
printValueRef(value);
};
operation->walk<WalkOrder::PreOrder>([&](Block *block) {
os << "// - Block: " << blockIds[block] << "\n";
const auto *liveness = getLiveness(block);
os << "// --- LiveIn: ";
printValueRefs(liveness->inValues);
os << "\n// --- LiveOut: ";
printValueRefs(liveness->outValues);
os << "\n";
os << "// --- BeginLivenessIntervals";
// Only operations that produce results have intervals to print.
for (Operation &op : *block) {
if (op.getNumResults() < 1)
continue;
os << "\n";
for (Value result : op.getResults()) {
os << "// ";
printValueRef(result);
os << ":";
// Sort the live operations by walk-order id for stable output.
auto liveOperations = resolveLiveness(result);
llvm::sort(liveOperations, [&](Operation *left, Operation *right) {
return operationIds[left] < operationIds[right];
});
for (Operation *operation : liveOperations) {
os << "\n// ";
operation->print(os);
}
}
}
os << "\n// --- EndLivenessIntervals\n";
os << "// --- BeginCurrentlyLive\n";
// Operations with no currently-live values are skipped.
for (Operation &op : *block) {
auto currentlyLive = liveness->currentlyLiveValues(&op);
if (currentlyLive.empty())
continue;
os << "// ";
op.print(os);
os << " [";
printValueRefs(currentlyLive);
// NOTE(review): "\b" writes a literal backspace character into the stream
// to visually erase the trailing space emitted by printValueRef; the byte
// remains in files/piped output. Consider not emitting the trailing
// separator instead (check textual tests that match this output first).
os << "\b]\n";
}
os << "// --- EndCurrentlyLive\n";
});
os << "// -------------------\n";
}
/// Returns true if `value` is live on entry to this block.
bool LivenessBlockInfo::isLiveIn(Value value) const {
  return inValues.count(value) != 0;
}
/// Returns true if `value` is live on exit from this block.
bool LivenessBlockInfo::isLiveOut(Value value) const {
  return outValues.count(value) != 0;
}
/// Returns the first operation of this block at which `value` is live. This
/// is the block's first operation when `value` is live-in or has no defining
/// operation (a block argument), and the defining operation otherwise.
Operation *LivenessBlockInfo::getStartOperation(Value value) const {
  Operation *definingOp = value.getDefiningOp();
  const bool liveFromBlockStart = isLiveIn(value) || !definingOp;
  return liveFromBlockStart ? &block->front() : definingOp;
}
/// Returns the last operation of this block at which `value` is live.
///
/// For a live-out value this is the block's last operation. Otherwise it is
/// the latest user of `value` (mapped to its ancestor operation in this block)
/// that comes after `startOperation`, or `startOperation` itself if there is
/// no such user.
Operation *LivenessBlockInfo::getEndOperation(Value value,
                                              Operation *startOperation) const {
  if (isLiveOut(value))
    return &block->back();

  Operation *latest = startOperation;
  for (Operation *user : value.getUsers()) {
    Operation *ancestor = block->findAncestorOpInBlock(*user);
    // Users that are not nested in this block do not extend the range here.
    if (!ancestor)
      continue;
    if (latest->isBeforeInBlock(ancestor))
      latest = ancestor;
  }
  return latest;
}
/// Returns the set of values whose in-block live range contains `op`.
///
/// Three kinds of candidate are considered: block arguments, live-in values,
/// and results of operations up to and including `op`.
LivenessBlockInfo::ValueSetT
LivenessBlockInfo::currentlyLiveValues(Operation *op) const {
  ValueSetT live;

  // Inserts `value` into `live` if `op` lies within its live range here.
  auto considerValue = [&](Value value) {
    // Live-in values and block arguments are live from the block start;
    // otherwise the range starts at the defining op's ancestor in this block.
    Operation *rangeBegin = nullptr;
    if (isLiveIn(value) || isa<BlockArgument>(value))
      rangeBegin = &block->front();
    else
      rangeBegin = block->findAncestorOpInBlock(*value.getDefiningOp());

    // Live-out values stay live until the block end; otherwise the range
    // ends at the last in-block user found by getEndOperation.
    Operation *rangeEnd = nullptr;
    if (isLiveOut(value))
      rangeEnd = &block->back();
    else if (rangeBegin)
      rangeEnd = getEndOperation(value, rangeBegin);

    assert(rangeEnd && "Must have endOfLiveRange at this point!");

    if (op->isBeforeInBlock(rangeBegin) || rangeEnd->isBeforeInBlock(op))
      return;
    live.insert(value);
  };

  for (Value argument : block->getArguments())
    considerValue(argument);

  for (Value liveIn : inValues)
    considerValue(liveIn);

  // Results of every operation from the block start up to and including `op`.
  auto stop = op->getIterator();
  ++stop;
  for (Operation &current : llvm::make_range(block->begin(), stop))
    for (Value result : current.getResults())
      considerValue(result);

  return live;
}