#include "mlir/Transforms/Mem2Reg.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/GenericIteratedDominanceFrontier.h"
namespace mlir {
#define GEN_PASS_DEF_MEM2REG
#include "mlir/Transforms/Passes.h.inc"
}
#define DEBUG_TYPE "mem2reg"
using namespace mlir;
namespace {
/// Maps every operation that must stop using the slot pointer (directly or
/// through values derived from it) to the set of its operands whose uses
/// block the promotion. `llvm::MapVector` keeps iteration in insertion order.
using BlockingUsesMap =
llvm::MapVector<Operation *, SmallPtrSet<OpOperand *, 4>>;
/// Result of analysing a single memory slot; consumed by `MemorySlotPromoter`.
struct MemorySlotPromotionInfo {
/// Blocks that receive a new block argument carrying the slot value during
/// promotion (an iterated dominance frontier of the storing blocks).
SmallPtrSet<Block *, 8> mergePoints;
/// Operations whose uses must be removed, with their blocking operands.
BlockingUsesMap userToBlockingUses;
};
/// Decides whether a memory slot can be promoted and, if so, gathers the
/// information needed by `MemorySlotPromoter` to perform the promotion.
class MemorySlotPromotionAnalyzer {
public:
MemorySlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance,
const DataLayout &dataLayout)
: slot(slot), dominance(dominance), dataLayout(dataLayout) {}
/// Returns the promotion information for the slot, or `std::nullopt` when a
/// blocking use cannot be removed or when a merge point has a predecessor
/// whose terminator is not a `BranchOpInterface`.
std::optional<MemorySlotPromotionInfo> computeInfo();
private:
/// Fills `userToBlockingUses` with every operation that must drop its uses,
/// propagating requests through the forward slice of the slot pointer.
/// Fails if an operation refuses or is not promotable, or if a memory
/// operation needing promotion lives outside the slot pointer's region.
LogicalResult computeBlockingUses(BlockingUsesMap &userToBlockingUses);
/// Returns the blocks in which the slot value is live on entry, given the
/// blocks that contain a store to the slot.
SmallPtrSet<Block *, 16>
computeSlotLiveIn(SmallPtrSetImpl<Block *> &definingBlocks);
/// Inserts into `mergePoints` the blocks where different reaching
/// definitions of the slot may meet.
void computeMergePoints(SmallPtrSetImpl<Block *> &mergePoints);
/// Returns true if every predecessor of every merge point is terminated by a
/// `BranchOpInterface`, so that a successor operand can be appended to it.
bool areMergePointsUsable(SmallPtrSetImpl<Block *> &mergePoints);
MemorySlot slot;
DominanceInfo &dominance;
const DataLayout &dataLayout;
};
/// Per-region cache mapping each block to its index in a dominance-respecting
/// block order. It is shared across promotions to avoid recomputation.
using BlockIndexCache = DenseMap<Region *, DenseMap<Block *, size_t>>;
/// Rewrites the IR so that a memory slot is replaced by SSA values, using the
/// information computed by `MemorySlotPromotionAnalyzer`.
class MemorySlotPromoter {
public:
MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
OpBuilder &builder, DominanceInfo &dominance,
const DataLayout &dataLayout, MemorySlotPromotionInfo info,
const Mem2RegStatistics &statistics,
BlockIndexCache &blockIndexCache);
/// Performs the promotion and returns the result of the allocator's
/// `handlePromotionComplete`; `tryToPromoteMemorySlots` re-queues a returned
/// allocator for further promotion.
std::optional<PromotableAllocationOpInterface> promoteSlot();
private:
/// Walks `block` in order. It records the reaching definition seen by each
/// blocking memory operation and updates the definition at every store.
/// Returns the reaching definition at the end of the block.
Value computeReachingDefInBlock(Block *block, Value reachingDef);
/// Propagates reaching definitions through `region` in dominator-tree order.
/// It adds block arguments at merge points and forwards the current value to
/// them from branch terminators.
void computeReachingDefInRegion(Region *region, Value reachingDef);
/// Asks every operation in `info.userToBlockingUses` to drop its blocking
/// uses, then erases the operations that requested deletion.
void removeBlockingUses();
/// Returns the allocator's default value for the slot. It is created once, at
/// the start of the slot pointer's block, and cached afterwards.
Value getOrCreateDefaultValue();
MemorySlot slot;
PromotableAllocationOpInterface allocator;
OpBuilder &builder;
/// Cached default value; null until `getOrCreateDefaultValue` first runs.
Value defaultValue;
/// Reaching definition of the slot as observed by each blocking memory op.
DenseMap<PromotableMemOpInterface, Value> reachingDefs;
/// Value produced by `getStored` for each memory op that stores to the slot.
DenseMap<PromotableMemOpInterface, Value> replacedValuesMap;
DominanceInfo &dominance;
const DataLayout &dataLayout;
MemorySlotPromotionInfo info;
/// Optional counters; each pointer is only incremented when non-null.
const Mem2RegStatistics &statistics;
BlockIndexCache &blockIndexCache;
};
} // namespace
MemorySlotPromoter::MemorySlotPromoter(
    MemorySlot slot, PromotableAllocationOpInterface allocator,
    OpBuilder &builder, DominanceInfo &dominance, const DataLayout &dataLayout,
    MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics,
    BlockIndexCache &blockIndexCache)
    : slot(slot), allocator(allocator), builder(builder), dominance(dominance),
      dataLayout(dataLayout), info(std::move(info)), statistics(statistics),
      blockIndexCache(blockIndexCache) {
#ifndef NDEBUG
  // The slot pointer must belong to the allocator. It is either one of the
  // allocator's results, or an argument of a block whose parent operation is
  // the allocator.
  Operation *slotOwner = nullptr;
  if (auto arg = dyn_cast<BlockArgument>(slot.ptr))
    slotOwner = arg.getOwner()->getParentOp();
  else
    slotOwner = slot.ptr.getDefiningOp();
  assert(slotOwner == allocator &&
         "a slot must be a result of the allocator or an argument of the child "
         "regions of the allocator");
#endif
}
/// Returns the allocator-provided default value of the slot. It is created
/// lazily, at the start of the block holding the slot pointer, and reused on
/// every later call.
Value MemorySlotPromoter::getOrCreateDefaultValue() {
  if (!defaultValue) {
    // Restore the caller's insertion point once the value has been created.
    OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(slot.ptr.getParentBlock());
    defaultValue = allocator.getDefaultValue(slot, builder);
  }
  return defaultValue;
}
/// Computes every operation that must stop using the slot pointer, or values
/// derived from it, for the slot to be promoted. The result is stored in
/// `userToBlockingUses`.
///
/// Fails if any such operation refuses to drop its uses, or implements neither
/// promotion interface. It also fails if a memory operation that must be
/// promoted lives outside the slot pointer's region, because reaching
/// definitions are only computed for that region.
LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
    BlockingUsesMap &userToBlockingUses) {
  // Every direct use of the slot pointer must be removed.
  for (OpOperand &use : slot.ptr.getUses())
    userToBlockingUses[use.getOwner()].insert(&use);

  // Removing a use may in turn require removing uses of the user's results.
  // `getForwardSlice` is documented to return its slice in topological order,
  // so all blocking uses of an operation are known before it is visited.
  SetVector<Operation *> forwardSlice;
  mlir::getForwardSlice(slot.ptr, &forwardSlice);
  for (Operation *user : forwardSlice) {
    // Single lookup instead of `contains` followed by `operator[]`.
    auto it = userToBlockingUses.find(user);
    if (it == userToBlockingUses.end())
      continue;

    // This reference points into the map's storage. Inserting new keys below
    // may invalidate it, so it is only used before that happens.
    SmallPtrSet<OpOperand *, 4> &blockingUses = it->second;

    // The operation must accept to drop its blocking uses. If it cannot, or it
    // is not promotable at all, the slot cannot be promoted.
    SmallVector<OpOperand *> newBlockingUses;
    if (auto promotable = dyn_cast<PromotableOpInterface>(user)) {
      if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses,
                                       dataLayout))
        return failure();
    } else if (auto promotable = dyn_cast<PromotableMemOpInterface>(user)) {
      if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses,
                                       dataLayout))
        return failure();
    } else {
      return failure();
    }

    // Register the uses of `user`'s results that now have to be removed too.
    for (OpOperand *blockingUse : newBlockingUses) {
      assert(llvm::is_contained(user->getResults(), blockingUse->get()));
      userToBlockingUses[blockingUse->getOwner()].insert(blockingUse);
    }
  }

  // Reaching definitions are only computed in the slot pointer's region, so a
  // memory operation needing promotion anywhere else cannot be handled.
  for (auto &[toPromote, _] : userToBlockingUses)
    if (isa<PromotableMemOpInterface>(toPromote) &&
        toPromote->getParentRegion() != slot.ptr.getParentRegion())
      return failure();

  return success();
}
/// Returns the set of blocks in which the value of the slot is live on entry.
/// `definingBlocks` are the blocks containing a store to the slot; propagation
/// of liveness towards predecessors stops at them.
SmallPtrSet<Block *, 16> MemorySlotPromotionAnalyzer::computeSlotLiveIn(
    SmallPtrSetImpl<Block *> &definingBlocks) {
  // True when the first memory operation touching the slot in `block` is a
  // load, i.e. the block reads the slot value it received from predecessors.
  auto readsBeforeWriting = [&](Block *block) {
    for (Operation &op : block->getOperations()) {
      auto memOp = dyn_cast<PromotableMemOpInterface>(op);
      if (!memOp)
        continue;
      if (memOp.loadsFrom(slot))
        return true;
      if (memOp.storesTo(slot))
        return false;
    }
    return false;
  };

  // Seed the worklist with the user blocks that read before writing; the
  // slot value is definitely live-in there.
  SmallVector<Block *> pending;
  SmallPtrSet<Block *, 16> inspected;
  for (Operation *user : slot.ptr.getUsers()) {
    Block *userBlock = user->getBlock();
    if (inspected.insert(userBlock).second && readsBeforeWriting(userBlock))
      pending.push_back(userBlock);
  }

  // Walk predecessors backwards until a block that stores to the slot is met.
  // Defining blocks that read before writing were already seeded above.
  SmallPtrSet<Block *, 16> liveIn;
  while (!pending.empty()) {
    Block *current = pending.pop_back_val();
    if (!liveIn.insert(current).second)
      continue;
    for (Block *pred : current->getPredecessors())
      if (!definingBlocks.contains(pred))
        pending.push_back(pred);
  }
  return liveIn;
}
/// Iterated dominance frontier calculator over forward (non-post-dominance)
/// block graphs.
using IDFCalculator = llvm::IDFCalculatorBase<Block, false>;

/// Inserts into `mergePoints` the blocks where several reaching definitions of
/// the slot may meet. These are the iterated dominance frontier of the blocks
/// storing to the slot, pruned to the blocks where the slot value is live-in.
void MemorySlotPromotionAnalyzer::computeMergePoints(
    SmallPtrSetImpl<Block *> &mergePoints) {
  Region *slotRegion = slot.ptr.getParentRegion();
  // A single-block region has no control-flow joins.
  if (slotRegion->hasOneBlock())
    return;

  SmallPtrSet<Block *, 16> storingBlocks;
  for (Operation *user : slot.ptr.getUsers()) {
    auto memOp = dyn_cast<PromotableMemOpInterface>(user);
    if (memOp && memOp.storesTo(slot))
      storingBlocks.insert(user->getBlock());
  }
  SmallPtrSet<Block *, 16> liveInBlocks = computeSlotLiveIn(storingBlocks);

  // The calculator keeps pointers to both sets, which outlive `calculate`.
  IDFCalculator calculator(dominance.getDomTree(slotRegion));
  calculator.setDefiningBlocks(storingBlocks);
  calculator.setLiveInBlocks(liveInBlocks);

  SmallVector<Block *> frontier;
  calculator.calculate(frontier);
  mergePoints.insert(frontier.begin(), frontier.end());
}
/// Returns true when every predecessor of every merge point ends with a
/// `BranchOpInterface` terminator. Only such terminators can be given the
/// extra successor operand that feeds the new block argument.
bool MemorySlotPromotionAnalyzer::areMergePointsUsable(
    SmallPtrSetImpl<Block *> &mergePoints) {
  return llvm::all_of(mergePoints, [](Block *mergePoint) {
    return llvm::all_of(mergePoint->getPredecessors(), [](Block *pred) {
      return isa<BranchOpInterface>(pred->getTerminator());
    });
  });
}
std::optional<MemorySlotPromotionInfo>
MemorySlotPromotionAnalyzer::computeInfo() {
MemorySlotPromotionInfo info;
if (failed(computeBlockingUses(info.userToBlockingUses)))
return {};
computeMergePoints(info.mergePoints);
if (!areMergePointsUsable(info.mergePoints))
return {};
return info;
}
/// Walks the operations of `block` in order, starting with `reachingDef` as
/// the slot value. Each blocking memory operation records the value it
/// observes in `reachingDefs`. Each store replaces the current value with the
/// one returned by `getStored` and records it in `replacedValuesMap`.
/// Returns the slot value at the end of the block.
Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
                                                   Value reachingDef) {
  // Snapshot the operation list first: `getStored` receives a builder placed
  // right after the store and may insert operations into this block, which
  // must not be visited by this walk.
  SmallVector<Operation *> snapshot;
  for (Operation &op : block->getOperations())
    snapshot.push_back(&op);

  Value currentDef = reachingDef;
  for (Operation *op : snapshot) {
    auto memOp = dyn_cast<PromotableMemOpInterface>(op);
    if (!memOp)
      continue;

    if (info.userToBlockingUses.contains(memOp))
      reachingDefs.insert({memOp, currentDef});

    if (!memOp.storesTo(slot))
      continue;

    builder.setInsertionPointAfter(memOp);
    Value stored = memOp.getStored(slot, builder, currentDef, dataLayout);
    assert(stored && "a memory operation storing to a slot must provide a "
                     "new definition of the slot");
    currentDef = stored;
    replacedValuesMap[memOp] = stored;
  }
  return currentDef;
}
/// Propagates reaching definitions of the slot through `region`, starting
/// from `reachingDef` at the entry block.
///
/// Blocks are visited in depth-first dominator-tree order. A merge point gets
/// a new block argument that becomes its incoming reaching definition. Branch
/// terminators append the current reaching definition to the successor
/// operands of every merge point they target. Blocks not reached by this walk
/// get no entry in `reachingDefs`.
void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
                                                    Value reachingDef) {
  assert(reachingDef && "expected an initial reaching def to be provided");
  // Without control flow, a single in-order walk of the block is enough.
  if (region->hasOneBlock()) {
    computeReachingDefInBlock(&region->front(), reachingDef);
    return;
  }

  // A dominator-tree node paired with the reaching definition at the end of
  // its immediate dominator.
  struct DfsJob {
    llvm::DomTreeNodeBase<Block> *block;
    Value reachingDef;
  };
  SmallVector<DfsJob> dfsStack;

  // Query the dominator tree of the region actually being walked. The only
  // caller passes the slot pointer's region, so this is unchanged for it.
  auto &domTree = dominance.getDomTree(region);
  dfsStack.emplace_back<DfsJob>(
      {domTree.getNode(&region->front()), reachingDef});

  while (!dfsStack.empty()) {
    DfsJob job = dfsStack.pop_back_val();
    Block *block = job.block->getBlock();

    // At a merge point, the incoming slot value is selected by a new block
    // argument that predecessors will feed.
    if (info.mergePoints.contains(block)) {
      BlockArgument blockArgument =
          block->addArgument(slot.elemType, slot.ptr.getLoc());
      builder.setInsertionPointToStart(block);
      allocator.handleBlockArgument(slot, blockArgument, builder);
      job.reachingDef = blockArgument;

      if (statistics.newBlockArgumentAmount)
        (*statistics.newBlockArgumentAmount)++;
    }

    job.reachingDef = computeReachingDefInBlock(block, job.reachingDef);
    assert(job.reachingDef);

    // Forward the value live at the end of this block to every merge point
    // targeted by its terminator.
    if (auto terminator = dyn_cast<BranchOpInterface>(block->getTerminator())) {
      for (BlockOperand &blockOperand : terminator->getBlockOperands()) {
        if (info.mergePoints.contains(blockOperand.get())) {
          terminator.getSuccessorOperands(blockOperand.getOperandNumber())
              .append(job.reachingDef);
        }
      }
    }

    for (auto *child : job.block->children())
      dfsStack.emplace_back<DfsJob>({child, job.reachingDef});
  }
}
/// Returns, for `region`, a map from each block to its position in the order
/// produced by `getBlocksSortedByDominance`. The map is computed on first
/// request and cached in `blockIndexCache`.
static const DenseMap<Block *, size_t> &
getOrCreateBlockIndices(BlockIndexCache &blockIndexCache, Region *region) {
  auto [entry, isNew] = blockIndexCache.try_emplace(region);
  DenseMap<Block *, size_t> &indices = entry->second;
  if (isNew) {
    size_t position = 0;
    for (Block *block : getBlocksSortedByDominance(*region))
      indices[block] = position++;
  }
  return indices;
}
static void dominanceSort(SmallVector<Operation *> &ops, Region ®ion,
BlockIndexCache &blockIndexCache) {
const DenseMap<Block *, size_t> &topoBlockIndices =
getOrCreateBlockIndices(blockIndexCache, ®ion);
llvm::sort(ops, [&](Operation *lhs, Operation *rhs) {
size_t lhsBlockIndex = topoBlockIndices.at(lhs->getBlock());
size_t rhsBlockIndex = topoBlockIndices.at(rhs->getBlock());
if (lhsBlockIndex == rhsBlockIndex)
return lhs->isBeforeInBlock(rhs);
return lhsBlockIndex < rhsBlockIndex;
});
}
/// Makes every operation in `info.userToBlockingUses` drop its blocking uses,
/// then erases the operations that request deletion.
///
/// Operations are visited in reverse dominance order, so each one is handled
/// before the operations that dominate it. Memory operations receive the
/// reaching definition they observed, or the default value if none was
/// recorded. Operations that require replaced values are visited at the end
/// with the list of (store, stored value) pairs.
void MemorySlotPromoter::removeBlockingUses() {
  llvm::SmallVector<Operation *> usersToRemoveUses(
      llvm::make_first_range(info.userToBlockingUses));
  dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent(),
                blockIndexCache);

  llvm::SmallVector<Operation *> toErase;
  llvm::SmallVector<std::pair<Operation *, Value>> replacedValuesList;
  llvm::SmallVector<PromotableOpInterface> toVisit;
  for (Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
    if (auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
      Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
      // No recorded reaching definition means the walk never reached this
      // operation, so it observes the default value.
      if (!reachingDef)
        reachingDef = getOrCreateDefaultValue();

      builder.setInsertionPointAfter(toPromote);
      if (toPromoteMemOp.removeBlockingUses(
              slot, info.userToBlockingUses[toPromote], builder, reachingDef,
              dataLayout) == DeletionKind::Delete)
        toErase.push_back(toPromote);

      // Use `lookup` rather than `operator[]` so that stores with no recorded
      // value do not get null entries inserted into the map.
      if (toPromoteMemOp.storesTo(slot))
        if (Value replacedValue = replacedValuesMap.lookup(toPromoteMemOp))
          replacedValuesList.push_back({toPromoteMemOp, replacedValue});
      continue;
    }

    auto toPromoteBasic = cast<PromotableOpInterface>(toPromote);
    builder.setInsertionPointAfter(toPromote);
    if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote],
                                          builder) == DeletionKind::Delete)
      toErase.push_back(toPromote);
    if (toPromoteBasic.requiresReplacedValues())
      toVisit.push_back(toPromoteBasic);
  }

  for (PromotableOpInterface op : toVisit) {
    builder.setInsertionPointAfter(op);
    op.visitReplacedValues(replacedValuesList, builder);
  }

  for (Operation *toEraseOp : toErase)
    toEraseOp->erase();

  assert(slot.ptr.use_empty() &&
         "after promotion, the slot pointer should not be used anymore");
}
std::optional<PromotableAllocationOpInterface>
MemorySlotPromoter::promoteSlot() {
computeReachingDefInRegion(slot.ptr.getParentRegion(),
getOrCreateDefaultValue());
removeBlockingUses();
for (Block *mergePoint : info.mergePoints) {
for (BlockOperand &use : mergePoint->getUses()) {
auto user = cast<BranchOpInterface>(use.getOwner());
SuccessorOperands succOperands =
user.getSuccessorOperands(use.getOperandNumber());
assert(succOperands.size() == mergePoint->getNumArguments() ||
succOperands.size() + 1 == mergePoint->getNumArguments());
if (succOperands.size() + 1 == mergePoint->getNumArguments())
succOperands.append(getOrCreateDefaultValue());
}
}
LLVM_DEBUG(llvm::dbgs() << "[mem2reg] Promoted memory slot: " << slot.ptr
<< "\n");
if (statistics.promotedAmount)
(*statistics.promotedAmount)++;
return allocator.handlePromotionComplete(slot, defaultValue, builder);
}
/// Repeatedly promotes slots of `allocators` until a whole round makes no
/// progress. Each round promotes at most one slot per allocator. An allocator
/// returned by a promotion is queued for the next round; allocators with
/// nothing promoted are kept as well. Succeeds iff at least one slot was
/// promoted.
LogicalResult mlir::tryToPromoteMemorySlots(
    ArrayRef<PromotableAllocationOpInterface> allocators, OpBuilder &builder,
    const DataLayout &dataLayout, DominanceInfo &dominance,
    Mem2RegStatistics statistics) {
  // Block index maps are computed lazily and shared by every promotion below.
  BlockIndexCache blockIndexCache;

  SmallVector<PromotableAllocationOpInterface> pending(allocators.begin(),
                                                       allocators.end());
  SmallVector<PromotableAllocationOpInterface> next;
  next.reserve(pending.size());

  // Promotes the first promotable slot of `allocator` and returns true, or
  // returns false if none of its slots can be promoted.
  auto promoteFirstSlot = [&](PromotableAllocationOpInterface allocator) {
    for (MemorySlot slot : allocator.getPromotableSlots()) {
      if (slot.ptr.use_empty())
        continue;

      MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
      std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
      if (!info)
        continue;

      MemorySlotPromoter promoter(slot, allocator, builder, dominance,
                                  dataLayout, std::move(*info), statistics,
                                  blockIndexCache);
      if (std::optional<PromotableAllocationOpInterface> replacement =
              promoter.promoteSlot())
        next.push_back(*replacement);

      // Stop after one slot, presumably because promoting it may invalidate
      // the allocator's remaining slots.
      return true;
    }
    return false;
  };

  bool promotedAny = false;
  for (;;) {
    bool roundChanged = false;
    for (PromotableAllocationOpInterface allocator : pending) {
      if (promoteFirstSlot(allocator))
        roundChanged = true;
      else
        next.push_back(allocator);
    }
    if (!roundChanged)
      break;
    promotedAny = true;

    // Reuse both buffers across rounds instead of reallocating.
    pending.swap(next);
    next.clear();
  }

  return success(promotedAny);
}
namespace {

/// Pass that collects every `PromotableAllocationOpInterface` nested in each
/// region of the scope operation and tries to promote their memory slots.
struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
  using impl::Mem2RegBase<Mem2Reg>::Mem2RegBase;

  void runOnOperation() override {
    Operation *scopeOp = getOperation();

    // Counters are presumably the pass statistics declared by the generated
    // base class.
    Mem2RegStatistics statistics{&promotedAmount, &newBlockArgumentAmount};

    bool changed = false;

    auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
    auto &dominance = getAnalysis<DominanceInfo>();

    for (Region &region : scopeOp->getRegions()) {
      if (region.getBlocks().empty())
        continue;

      OpBuilder builder(&region.front(), region.front().begin());

      SmallVector<PromotableAllocationOpInterface> allocators;
      region.walk([&](PromotableAllocationOpInterface allocator) {
        allocators.emplace_back(allocator);
      });

      if (succeeded(tryToPromoteMemorySlots(allocators, builder, dataLayout,
                                            dominance, statistics)))
        changed = true;
    }

    if (!changed)
      markAllAnalysesPreserved();
  }
};

} // namespace