#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SetVector.h"
using namespace mlir;
using namespace mlir::bufferization;
/// Constructs the analysis and immediately populates it by walking `op`.
BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) {
  this->build(op);
}
/// Returns every value reachable from `value` by repeatedly following the
/// edges stored in `map`. The result always contains `value` itself.
static BufferViewFlowAnalysis::ValueSetT
resolveValues(const BufferViewFlowAnalysis::ValueMapT &map, Value value) {
  BufferViewFlowAnalysis::ValueSetT visited;
  SmallVector<Value, 8> worklist;
  worklist.push_back(value);
  while (!worklist.empty()) {
    Value current = worklist.pop_back_val();
    // Each value is expanded at most once; this also terminates on cycles.
    if (!visited.insert(current).second)
      continue;
    auto edges = map.find(current);
    if (edges == map.end())
      continue;
    worklist.append(edges->second.begin(), edges->second.end());
  }
  return visited;
}
/// Returns `rootValue` together with all values it transitively flows into,
/// following the forward `dependencies` map.
BufferViewFlowAnalysis::ValueSetT
BufferViewFlowAnalysis::resolve(Value rootValue) const {
  ValueSetT reachable = resolveValues(dependencies, rootValue);
  return reachable;
}
/// Returns `rootValue` together with all values that transitively flow into
/// it, following the `reverseDependencies` map.
BufferViewFlowAnalysis::ValueSetT
BufferViewFlowAnalysis::resolveReverse(Value rootValue) const {
  ValueSetT origins = resolveValues(reverseDependencies, rootValue);
  return origins;
}
void BufferViewFlowAnalysis::remove(const SetVector<Value> &aliasValues) {
for (auto &entry : dependencies)
llvm::set_subtract(entry.second, aliasValues);
}
void BufferViewFlowAnalysis::rename(Value from, Value to) {
dependencies[to] = dependencies[from];
dependencies.erase(from);
for (auto &[_, value] : dependencies) {
if (value.contains(from)) {
value.insert(to);
value.erase(from);
}
}
}
/// Populates `dependencies`, `reverseDependencies` and `terminals` by walking
/// `op`.
///
/// `dependencies[x]` collects the values that `x` flows into.
/// `reverseDependencies[y]` collects the values that flow into `y`.
/// Flow is derived from the first matching interface, in this order:
/// BufferViewFlowOpInterface, ViewLikeOpInterface, BranchOpInterface,
/// RegionBranchOpInterface. Call ops and ops implementing none of these are
/// handled conservatively.
void BufferViewFlowAnalysis::build(Operation *op) {
  // Records that each value in `values` flows into the value at the same
  // position in `dependencies`, in both the forward and the reverse map. The
  // two ranges are expected to have the same length.
  auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
    for (auto [value, dep] : llvm::zip_equal(values, dependencies)) {
      this->dependencies[value].insert(dep);
      this->reverseDependencies[dep].insert(value);
    }
  };

  // Marks all memref-typed results and region entry block arguments of `op`
  // as potential terminal buffers.
  auto populateTerminalValues = [&](Operation *op) {
    for (Value v : op->getResults())
      if (isa<BaseMemRefType>(v.getType()))
        this->terminals.insert(v);
    for (Region &r : op->getRegions())
      for (BlockArgument v : r.getArguments())
        if (isa<BaseMemRefType>(v.getType()))
          this->terminals.insert(v);
  };

  op->walk([&](Operation *op) {
    // Ops that describe their own buffer view flow take precedence. They
    // report both their dependencies and which of their buffers are terminal.
    if (auto bufferViewFlowOp = dyn_cast<BufferViewFlowOpInterface>(op)) {
      bufferViewFlowOp.populateDependencies(registerDependencies);
      for (Value v : op->getResults())
        if (isa<BaseMemRefType>(v.getType()) &&
            bufferViewFlowOp.mayBeTerminalBuffer(v))
          this->terminals.insert(v);
      for (Region &r : op->getRegions())
        for (BlockArgument v : r.getArguments())
          if (isa<BaseMemRefType>(v.getType()) &&
              bufferViewFlowOp.mayBeTerminalBuffer(v))
            this->terminals.insert(v);
      return WalkResult::advance();
    }

    // View-like ops: the view source flows into the first result.
    if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) {
      registerDependencies(viewInterface.getViewSource(),
                           viewInterface->getResult(0));
      return WalkResult::advance();
    }

    // Branches: forwarded successor operands flow into the successor block
    // arguments. The leading arguments produced by the branch op itself have
    // no corresponding forwarded operand and are skipped.
    if (auto branchInterface = dyn_cast<BranchOpInterface>(op)) {
      Block *parentBlock = branchInterface->getBlock();
      for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();
           it != e; ++it) {
        auto successorOperands =
            branchInterface.getSuccessorOperands(it.getIndex());
        registerDependencies(successorOperands.getForwardedOperands(),
                             (*it)->getArguments().drop_front(
                                 successorOperands.getProducedOperandCount()));
      }
      return WalkResult::advance();
    }

    // Region branch ops.
    if (auto regionInterface = dyn_cast<RegionBranchOpInterface>(op)) {
      // (1) Operands passed on entry flow into the inputs of every successor
      //     reachable from the parent op.
      SmallVector<RegionSuccessor, 2> entrySuccessors;
      regionInterface.getSuccessorRegions(RegionBranchPoint::parent(),
                                          entrySuccessors);
      for (RegionSuccessor &entrySuccessor : entrySuccessors) {
        registerDependencies(
            regionInterface.getEntrySuccessorOperands(entrySuccessor),
            entrySuccessor.getSuccessorInputs());
      }

      // (2) Within each region, the operands of every region branch
      //     terminator flow into the inputs of each successor reachable from
      //     that region.
      for (Region &branchRegion : regionInterface->getRegions()) {
        SmallVector<RegionSuccessor, 2> successorRegions;
        regionInterface.getSuccessorRegions(branchRegion, successorRegions);
        for (RegionSuccessor &successorRegion : successorRegions) {
          for (Block &block : branchRegion)
            if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
                    block.getTerminator()))
              registerDependencies(
                  terminator.getSuccessorOperands(successorRegion),
                  successorRegion.getSuccessorInputs());
        }
      }
      return WalkResult::advance();
    }

    // Region branch terminators get no entries of their own. Their operands
    // are wired when the enclosing RegionBranchOpInterface op is visited.
    if (isa<RegionBranchTerminatorOpInterface>(op))
      return WalkResult::advance();

    // Calls: the callee is not analyzed. Conservatively assume that every
    // operand may flow into every result, and treat the results as possible
    // terminals.
    if (isa<CallOpInterface>(op)) {
      populateTerminalValues(op);
      for (Value operand : op->getOperands())
        for (Value result : op->getResults())
          registerDependencies({operand}, {result});
      return WalkResult::advance();
    }

    // Any other op is unknown to this analysis. Its buffers are terminals.
    populateTerminalValues(op);
    return WalkResult::advance();
  });
}
/// Returns true if `value` was recorded as a terminal buffer during `build`.
/// `value` must have a memref type.
bool BufferViewFlowAnalysis::mayBeTerminalBuffer(Value value) const {
  assert(isa<BaseMemRefType>(value.getType()) && "expected memref");
  const bool recorded = terminals.count(value) != 0;
  return recorded;
}
/// Returns true if `v` is an op result whose defining op reports a
/// `MemoryEffects::Allocate` effect on `v`. Block arguments never qualify.
static bool hasAllocateSideEffect(Value v) {
  if (Operation *definingOp = v.getDefiningOp())
    return hasEffect<MemoryEffects::Allocate>(definingOp, v);
  return false;
}
/// Returns true if `v` is an argument of the first block of the function body
/// of a FunctionOpInterface op.
static bool isFunctionArgument(Value v) {
  auto bbArg = dyn_cast<BlockArgument>(v);
  if (!bbArg)
    return false;
  Block *ownerBlock = bbArg.getOwner();
  auto funcOp = dyn_cast<FunctionOpInterface>(ownerBlock->getParentOp());
  return funcOp && ownerBlock == &funcOp.getFunctionBody().front();
}
/// Follows view sources through a chain of ViewLikeOpInterface ops. Returns
/// the first value in the chain that is not defined by such an op.
static Value getViewBase(Value value) {
  for (;;) {
    auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>();
    if (!viewLikeOp)
      return value;
    value = viewLikeOp.getViewSource();
  }
}
BufferOriginAnalysis::BufferOriginAnalysis(Operation *op) : analysis(op) {}
/// Decides whether the buffers `v1` and `v2` originate from the same
/// allocation.
///
/// Returns `true` if they are known to share an allocation, `false` if they
/// are known to originate from different allocations, and `std::nullopt` if
/// the analysis cannot decide. Both values must have a memref type.
std::optional<bool> BufferOriginAnalysis::isSameAllocation(Value v1, Value v2) {
  assert(isa<BaseMemRefType>(v1.getType()) && "expected buffer");
  assert(isa<BaseMemRefType>(v2.getType()) && "expected buffer");

  // Look through view-like ops.
  v1 = getViewBase(v1);
  v2 = getViewBase(v2);

  // Fast path: the same SSA value trivially shares its allocation.
  if (v1 == v2)
    return true;

  // All values from which `v1` / `v2` may (transitively) originate.
  SmallPtrSet<Value, 16> origin1 = analysis.resolveReverse(v1);
  SmallPtrSet<Value, 16> origin2 = analysis.resolveReverse(v2);

  // Terminal buffers among those origins. While gathering them, track whether
  // every terminal is an allocation, and whether every terminal is an
  // allocation or a function entry argument.
  SmallPtrSet<Value, 16> terminal1, terminal2;
  bool allAllocs1 = true, allAllocs2 = true;
  bool allAllocsOrFuncEntryArgs1 = true, allAllocsOrFuncEntryArgs2 = true;

  auto gatherTerminalBuffers = [this](const SmallPtrSet<Value, 16> &origin,
                                      SmallPtrSet<Value, 16> &terminal,
                                      bool &allAllocs,
                                      bool &allAllocsOrFuncEntryArgs) {
    for (Value v : origin) {
      if (isa<BaseMemRefType>(v.getType()) && analysis.mayBeTerminalBuffer(v)) {
        terminal.insert(v);
        allAllocs &= hasAllocateSideEffect(v);
        allAllocsOrFuncEntryArgs &=
            isFunctionArgument(v) || hasAllocateSideEffect(v);
      }
    }
    assert(!terminal.empty() && "expected non-empty terminal set");
  };

  gatherTerminalBuffers(origin1, terminal1, allAllocs1,
                        allAllocsOrFuncEntryArgs1);
  gatherTerminalBuffers(origin2, terminal2, allAllocs2,
                        allAllocsOrFuncEntryArgs2);

  // The assertion above is compiled out in release builds. There, an empty
  // terminal set would leave the "all allocs" flags vacuously true and fall
  // through to `return false`, claiming distinct allocations without
  // evidence. Answer conservatively instead.
  if (terminal1.empty() || terminal2.empty())
    return std::nullopt;

  // Exactly one terminal on each side, and it is the same buffer.
  if (llvm::hasSingleElement(terminal1) && llvm::hasSingleElement(terminal2) &&
      *terminal1.begin() == *terminal2.begin())
    return true;

  // Overlapping terminal sets cannot be decided without further analysis. At
  // this point at least one of the sets has more than one element.
  if (llvm::any_of(terminal1,
                   [&](Value v) { return terminal2.contains(v); }))
    return std::nullopt;

  // The terminal sets are disjoint. If one side originates only from
  // allocations, and the other side only from allocations or function entry
  // arguments, they must originate from different allocations.
  bool isolatedAlloc1 = allAllocs1 && (allAllocs2 || allAllocsOrFuncEntryArgs2);
  bool isolatedAlloc2 = (allAllocs1 || allAllocsOrFuncEntryArgs1) && allAllocs2;
  if (isolatedAlloc1 || isolatedAlloc2)
    return false;

  // Function arguments and other terminals (e.g. results of unknown ops) are
  // handled conservatively: they may refer to the same allocation.
  return std::nullopt;
}