#include "mlir/Transforms/SROA.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/Passes.h"
namespace mlir {
#define GEN_PASS_DEF_SROA
#include "mlir/Transforms/Passes.h.inc"
}
#define DEBUG_TYPE "sroa"
using namespace mlir;
namespace {
struct MemorySlotDestructuringInfo {
SmallPtrSet<Attribute, 8> usedIndices;
DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> userToBlockingUses;
SmallVector<DestructurableAccessorOpInterface> accessors;
};
}
/// Analyzes whether `slot` can be split into subslots.
///
/// The analysis runs in three phases:
///  1. Direct uses of `slot.ptr`. An owner implementing
///     DestructurableAccessorOpInterface whose canRewire succeeds is recorded
///     as an accessor; canRewire also fills `usedIndices` and may queue
///     subslots. Every other direct use is recorded as a blocking use.
///  2. Queued subslots. Each use of a subslot pointer must be proven safe by
///     SafeMemorySlotAccessOpInterface::ensureOnlySafeAccesses, which may
///     queue further subslots. Otherwise it becomes a blocking use.
///  3. The forward slice of `slot.ptr`. Every operation holding blocking uses
///     must be a PromotableOpInterface that accepts their removal. The new
///     blocking uses it reports are attached to their owners.
///
/// Returns std::nullopt if `slot.ptr` has no uses, or if some operation with
/// blocking uses is not promotable or refuses to remove them.
static std::optional<MemorySlotDestructuringInfo>
computeDestructuringInfo(DestructurableMemorySlot &slot,
                         const DataLayout &dataLayout) {
  assert(isa<DestructurableTypeInterface>(slot.elemType));

  // Nothing to do for an unused slot.
  if (slot.ptr.use_empty())
    return std::nullopt;

  MemorySlotDestructuringInfo info;
  SmallVector<MemorySlot> pendingSubslots;

  // Records `operand` as a use its owner has to get rid of.
  auto markBlocking = [&](OpOperand &operand) {
    info.userToBlockingUses.getOrInsertDefault(operand.getOwner())
        .insert(&operand);
  };

  // Phase 1: classify the direct uses of the slot pointer.
  for (OpOperand &directUse : slot.ptr.getUses()) {
    auto accessor =
        dyn_cast<DestructurableAccessorOpInterface>(directUse.getOwner());
    const bool rewirable =
        accessor && accessor.canRewire(slot, info.usedIndices,
                                       pendingSubslots, dataLayout);
    if (rewirable) {
      info.accessors.push_back(accessor);
      continue;
    }
    // Not a rewirable accessor: maybe its owner can be promoted away instead.
    markBlocking(directUse);
  }

  // Phase 2: every use of a queued subslot must be a safe access.
  SmallPtrSet<OpOperand *, 16> alreadySeen;
  while (!pendingSubslots.empty()) {
    MemorySlot subslot = pendingSubslots.pop_back_val();
    for (OpOperand &subslotUse : subslot.ptr.getUses()) {
      if (!alreadySeen.insert(&subslotUse).second)
        continue;
      auto safeAccess =
          dyn_cast<SafeMemorySlotAccessOpInterface>(subslotUse.getOwner());
      const bool provenSafe =
          safeAccess && succeeded(safeAccess.ensureOnlySafeAccesses(
                            subslot, pendingSubslots, dataLayout));
      if (!provenSafe)
        markBlocking(subslotUse);
    }
  }

  // Phase 3: resolve blocking uses along the forward slice of the pointer.
  SetVector<Operation *> forwardSlice;
  mlir::getForwardSlice(slot.ptr, &forwardSlice);
  for (Operation *user : forwardSlice) {
    auto entry = info.userToBlockingUses.find(user);
    if (entry == info.userToBlockingUses.end())
      continue;

    // Blocking uses can only be dropped by a promotable operation.
    auto promotable = dyn_cast<PromotableOpInterface>(user);
    if (!promotable)
      return std::nullopt;

    SmallVector<OpOperand *> newBlockingUses;
    if (!promotable.canUsesBeRemoved(entry->second, newBlockingUses,
                                     dataLayout))
      return std::nullopt;

    // `entry` is not used past this point: the insertions below may
    // invalidate DenseMap iterators.
    for (OpOperand *blockingUse : newBlockingUses) {
      assert(llvm::is_contained(user->getResults(), blockingUse->get()));
      markBlocking(*blockingUse);
    }
  }

  return info;
}
/// Destructures `slot` according to the analysis result `info`.
///
/// The subslots are created by `allocator.destructure`, with the builder
/// positioned at the start of the block that defines `slot.ptr`. Then every
/// accessor and every operation holding blocking uses is rewritten, visiting
/// them in reverse topological order. Operations whose rewrite reports
/// DeletionKind::Delete are erased afterwards.
///
/// Finally, `allocator.handleDestructuringComplete` is invoked. Any allocator
/// it returns is appended to `newAllocators`, which `destructure` may also
/// have extended. The statistics counters present in `statistics` are
/// updated. The builder's insertion point is restored on exit.
static void destructureSlot(
    DestructurableMemorySlot &slot,
    DestructurableAllocationOpInterface allocator, OpBuilder &builder,
    const DataLayout &dataLayout, MemorySlotDestructuringInfo &info,
    SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators,
    const SROAStatistics &statistics) {
  OpBuilder::InsertionGuard restoreInsertionPoint(builder);

  builder.setInsertionPointToStart(slot.ptr.getParentBlock());
  DenseMap<Attribute, MemorySlot> subslots =
      allocator.destructure(slot, info.usedIndices, builder, newAllocators);

  const size_t subelementCount = slot.subelementTypes.size();
  // A slot benefits memory-wise when some subelements are never used.
  if (statistics.slotsWithMemoryBenefit &&
      subelementCount != info.usedIndices.size())
    ++*statistics.slotsWithMemoryBenefit;
  if (statistics.maxSubelementAmount)
    statistics.maxSubelementAmount->updateMax(subelementCount);

  // Every op that must be rewritten: owners of blocking uses, plus accessors.
  SetVector<Operation *> rewireOrder;
  for (auto &entry : info.userToBlockingUses)
    rewireOrder.insert(entry.first);
  for (DestructurableAccessorOpInterface accessor : info.accessors)
    rewireOrder.insert(accessor.getOperation());
  rewireOrder = mlir::topologicalSort(rewireOrder);

  SmallVector<Operation *> deadOps;
  for (Operation *op : llvm::reverse(rewireOrder)) {
    builder.setInsertionPointAfter(op);
    DeletionKind outcome;
    if (auto accessor = dyn_cast<DestructurableAccessorOpInterface>(op))
      outcome = accessor.rewire(slot, subslots, builder, dataLayout);
    else
      outcome = cast<PromotableOpInterface>(op).removeBlockingUses(
          info.userToBlockingUses[op], builder);
    if (outcome == DeletionKind::Delete)
      deadOps.push_back(op);
  }

  // Erase only after all rewrites, so no visited op is deleted mid-iteration.
  for (Operation *deadOp : deadOps)
    deadOp->erase();

  assert(slot.ptr.use_empty() && "after destructuring, the original slot "
                                 "pointer should no longer be used");

  LLVM_DEBUG(llvm::dbgs() << "[sroa] Destructured memory slot: " << slot.ptr
                          << "\n");

  if (statistics.destructuredAmount)
    ++*statistics.destructuredAmount;

  if (std::optional<DestructurableAllocationOpInterface> replacement =
          allocator.handleDestructuringComplete(slot, builder))
    newAllocators.push_back(*replacement);
}
/// Repeatedly destructures slots of `allocators` until a full round makes no
/// progress.
///
/// At most one slot per allocator is destructured per round. That allocator
/// is then dropped from the next round's worklist; only allocators that
/// destructureSlot reports (via `newAllocators`) are added for it. Allocators
/// that made no progress are retried in the next round, because destructuring
/// one slot may enable destructuring of another.
///
/// Returns success if at least one slot was destructured, failure otherwise.
LogicalResult mlir::tryToDestructureMemorySlots(
    ArrayRef<DestructurableAllocationOpInterface> allocators,
    OpBuilder &builder, const DataLayout &dataLayout,
    SROAStatistics statistics) {
  SmallVector<DestructurableAllocationOpInterface> pending(allocators.begin(),
                                                           allocators.end());
  SmallVector<DestructurableAllocationOpInterface> nextRound;
  nextRound.reserve(allocators.size());

  // Destructures the first viable slot of `allocator`; returns whether one
  // was found. Stopping after one slot matters: destructuring may invalidate
  // the allocator's remaining slots.
  auto destructureFirstViableSlot =
      [&](DestructurableAllocationOpInterface allocator) -> bool {
    for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) {
      std::optional<MemorySlotDestructuringInfo> info =
          computeDestructuringInfo(slot, dataLayout);
      if (!info)
        continue;
      destructureSlot(slot, allocator, builder, dataLayout, *info, nextRound,
                      statistics);
      return true;
    }
    return false;
  };

  bool everDestructured = false;
  for (;;) {
    bool progressed = false;
    for (DestructurableAllocationOpInterface allocator : pending) {
      if (destructureFirstViableSlot(allocator))
        progressed = true;
      else
        nextRound.push_back(allocator);
    }

    if (!progressed)
      break;
    everDestructured = true;

    // Reuse both buffers across rounds instead of reallocating.
    pending.swap(nextRound);
    nextRound.clear();
  }

  return success(everDestructured);
}
namespace {

/// Pass wrapper around tryToDestructureMemorySlots. It runs over every
/// non-empty region of the anchor operation and reports the counters
/// declared by the generated pass base through SROAStatistics.
struct SROA : public impl::SROABase<SROA> {
  using impl::SROABase<SROA>::SROABase;

  void runOnOperation() override {
    Operation *scopeOp = getOperation();

    SROAStatistics statistics{&destructuredAmount, &slotsWithMemoryBenefit,
                              &maxSubelementAmount};

    auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
    bool changed = false;

    // Fixed: the region reference and its address were mis-encoded as
    // "®ion" (an HTML `&reg` artifact), which does not compile.
    for (Region &region : scopeOp->getRegions()) {
      if (region.getBlocks().empty())
        continue;

      OpBuilder builder(&region.front(), region.front().begin());

      // Collect every allocator nested in the region as a destructuring
      // candidate.
      SmallVector<DestructurableAllocationOpInterface> allocators;
      region.walk([&](DestructurableAllocationOpInterface allocator) {
        allocators.emplace_back(allocator);
      });

      if (succeeded(tryToDestructureMemorySlots(allocators, builder, dataLayout,
                                                statistics)))
        changed = true;
    }

    // If nothing was rewritten, all cached analyses remain valid.
    if (!changed)
      markAllAnalysesPreserved();
  }
};

} // namespace