#include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h"
#include "mlir/Analysis/AliasAnalysis.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/Casting.h"
#include <cassert>
#include <optional>
#include <utility>
using namespace mlir;
/// Upper bound on the number of recursive steps taken when walking back from a
/// value to the values it may be derived from. Once the budget is exhausted,
/// the value currently being visited is recorded as an underlying value as-is.
static constexpr unsigned maxUnderlyingValueSearchDepth{10};

// Forward declaration: the per-kind handlers defined below recurse back into
// this generic `Value` overload.
static void collectUnderlyingAddressValues(Value value, unsigned maxDepth,
                                           DenseSet<Value> &visited,
                                           SmallVectorImpl<Value> &output);
/// Trace `inputValue` back through the region-branch edges of `branch`.
///
/// `inputValue` is either an entry-block argument of `region`, or, when
/// `region` is null, a result of `branch` itself. `inputIndex` is its argument
/// number or result number respectively.
///
/// For every predecessor whose successors include `region`, the operand that
/// feeds `inputValue` is located and recursively traced. A predecessor is
/// either the parent op or one of its child regions. Whenever that mapping
/// cannot be determined, `inputValue` itself is appended to `output`.
static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
                                           Region *region, Value inputValue,
                                           unsigned inputIndex,
                                           unsigned maxDepth,
                                           DenseSet<Value> &visited,
                                           SmallVectorImpl<Value> &output) {
  // For predecessor `pred`, return the index into the operands it forwards to
  // `region` that corresponds to `inputValue`. Returns std::nullopt when
  // `pred` does not branch to `region`, or when the mapping cannot be
  // determined. In the latter case `inputValue` is conservatively appended to
  // `output`.
  auto getOperandIndexIfPred =
      [&](RegionBranchPoint pred) -> std::optional<unsigned> {
    SmallVector<RegionSuccessor, 2> successors;
    branch.getSuccessorRegions(pred, successors);
    for (RegionSuccessor &successor : successors) {
      if (successor.getSuccessor() != region)
        continue;
      // The successor inputs must cover `inputIndex` for a mapping to exist.
      ValueRange inputs = successor.getSuccessorInputs();
      if (inputs.empty()) {
        output.push_back(inputValue);
        break;
      }
      // Successor inputs are entry-block arguments of `region`, or results of
      // `branch` when `region` is null.
      unsigned firstInputIndex, lastInputIndex;
      if (region) {
        firstInputIndex = cast<BlockArgument>(inputs[0]).getArgNumber();
        lastInputIndex = cast<BlockArgument>(inputs.back()).getArgNumber();
      } else {
        firstInputIndex = cast<OpResult>(inputs[0]).getResultNumber();
        lastInputIndex = cast<OpResult>(inputs.back()).getResultNumber();
      }
      if (firstInputIndex > inputIndex || lastInputIndex < inputIndex) {
        output.push_back(inputValue);
        break;
      }
      return inputIndex - firstInputIndex;
    }
    return std::nullopt;
  };

  // The successor point `inputValue` belongs to: `region`, or the parent op
  // when `region` is null. Used to query which operands are forwarded to it.
  auto branchPoint = RegionBranchPoint::parent();
  if (region)
    branchPoint = region;

  // Operands forwarded from the parent op on entry.
  if (std::optional<unsigned> operandIndex =
          getOperandIndexIfPred(RegionBranchPoint::parent())) {
    collectUnderlyingAddressValues(
        branch.getEntrySuccessorOperands(branchPoint)[*operandIndex], maxDepth,
        visited, output);
  }

  // Operands forwarded by the terminators of each child region. The loop
  // variable is named `predRegion` so it does not shadow the `region`
  // parameter captured by the lambda above.
  Operation *op = branch.getOperation();
  for (Region &predRegion : op->getRegions()) {
    std::optional<unsigned> operandIndex = getOperandIndexIfPred(predRegion);
    if (!operandIndex)
      continue;
    for (Block &block : predRegion) {
      Operation *terminator = block.getTerminator();
      if (auto term = dyn_cast<RegionBranchTerminatorOpInterface>(terminator)) {
        collectUnderlyingAddressValues(
            term.getSuccessorOperands(branchPoint)[*operandIndex], maxDepth,
            visited, output);
      } else if (block.getNumSuccessors()) {
        // This terminator branches to other blocks but does not implement
        // RegionBranchTerminatorOpInterface, so the values it forwards can't
        // be analyzed here. Bail out conservatively.
        output.push_back(inputValue);
        return;
      }
    }
  }
}
/// Collect the underlying address values of an operation result.
///
/// - Results of ViewLikeOpInterface ops are looked through to the view source.
/// - Results of RegionBranchOpInterface ops are traced through the values the
///   op forwards into them.
/// - Any other result is recorded as its own underlying value.
static void collectUnderlyingAddressValues(OpResult result, unsigned maxDepth,
                                           DenseSet<Value> &visited,
                                           SmallVectorImpl<Value> &output) {
  Operation *definingOp = result.getOwner();

  if (auto viewOp = dyn_cast<ViewLikeOpInterface>(definingOp)) {
    collectUnderlyingAddressValues(viewOp.getViewSource(), maxDepth, visited,
                                   output);
    return;
  }

  if (auto regionBranch = dyn_cast<RegionBranchOpInterface>(definingOp)) {
    collectUnderlyingAddressValues(regionBranch, /*region=*/nullptr, result,
                                   result.getResultNumber(), maxDepth, visited,
                                   output);
    return;
  }

  // Opaque producer: nothing further to trace.
  output.push_back(result);
}
/// Collect the underlying address values of a block argument.
///
/// - Entry-block arguments of a region owned by a RegionBranchOpInterface op
///   are traced through that op's region-branch edges.
/// - Arguments of non-entry blocks are traced through the operand each
///   predecessor's BranchOpInterface terminator forwards.
/// - In every other case the argument itself is recorded.
static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth,
                                           DenseSet<Value> &visited,
                                           SmallVectorImpl<Value> &output) {
  Block *ownerBlock = arg.getOwner();
  unsigned position = arg.getArgNumber();

  if (ownerBlock->isEntryBlock()) {
    // Only a region-branch parent tells us which values flow into the entry
    // block. Otherwise the argument is opaque.
    Region *ownerRegion = ownerBlock->getParent();
    auto regionBranch =
        dyn_cast<RegionBranchOpInterface>(ownerRegion->getParentOp());
    if (!regionBranch) {
      output.push_back(arg);
      return;
    }
    collectUnderlyingAddressValues(regionBranch, ownerRegion, arg, position,
                                   maxDepth, visited, output);
    return;
  }

  // Non-entry block: follow the operand each predecessor passes for `arg`.
  for (auto predIt = ownerBlock->pred_begin(), predEnd = ownerBlock->pred_end();
       predIt != predEnd; ++predIt) {
    auto branchOp = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
    Value incoming;
    if (branchOp)
      incoming =
          branchOp.getSuccessorOperands(predIt.getSuccessorIndex())[position];
    // The terminator is not a BranchOpInterface, or it yields no explicit
    // operand for this position (presumably an operand the terminator
    // produces itself). Either way the flow can't be analyzed, so give up.
    if (!incoming) {
      output.push_back(arg);
      return;
    }
    collectUnderlyingAddressValues(incoming, maxDepth, visited, output);
  }
}
/// Recursive driver. Dispatches `value` to the block-argument or op-result
/// handler.
///
/// - A value that was already visited is skipped, which stops infinite
///   recursion on cycles.
/// - Once the depth budget reaches zero, `value` itself is recorded as an
///   underlying value.
static void collectUnderlyingAddressValues(Value value, unsigned maxDepth,
                                           DenseSet<Value> &visited,
                                           SmallVectorImpl<Value> &output) {
  bool firstVisit = visited.insert(value).second;
  if (!firstVisit)
    return;

  if (!maxDepth) {
    output.push_back(value);
    return;
  }
  unsigned remainingDepth = maxDepth - 1;

  if (auto blockArg = dyn_cast<BlockArgument>(value)) {
    collectUnderlyingAddressValues(blockArg, remainingDepth, visited, output);
    return;
  }
  collectUnderlyingAddressValues(cast<OpResult>(value), remainingDepth,
                                 visited, output);
}
/// Entry point: append the underlying address values of `value` to `output`.
/// The search uses a fresh visited set and the default depth budget
/// `maxUnderlyingValueSearchDepth`.
static void collectUnderlyingAddressValues(Value value,
                                           SmallVectorImpl<Value> &output) {
  DenseSet<Value> seen;
  unsigned depthBudget = maxUnderlyingValueSearchDepth;
  collectUnderlyingAddressValues(value, depthBudget, seen, output);
}
/// Look for a MemoryEffects::Allocate effect on `value`. The effect is queried
/// from the op that defines `value`; for a block argument, that is the op
/// owning the argument's block.
///
/// On success:
/// - `effect` holds the found effect.
/// - `allocScopeOp` is set to the op bounding the allocation:
///   - for an AutomaticAllocationScopeResource, the nearest ancestor with the
///     AutomaticAllocationScope trait;
///   - otherwise, the enclosing FunctionOpInterface op.
///
/// Fails when the op has no MemoryEffectOpInterface or no Allocate effect on
/// `value`. In that case `allocScopeOp` is left untouched.
static LogicalResult
getAllocEffectFor(Value value,
                  std::optional<MemoryEffects::EffectInstance> &effect,
                  Operation *&allocScopeOp) {
  Operation *op;
  if (auto opResult = dyn_cast<OpResult>(value))
    op = opResult.getOwner();
  else
    op = cast<BlockArgument>(value).getOwner()->getParentOp();

  auto memInterface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!memInterface)
    return failure();

  effect = memInterface.getEffectOnValue<MemoryEffects::Allocate>(value);
  if (!effect)
    return failure();

  // Without finer-grained escape analysis, non-automatic allocations are
  // scoped to the enclosing function.
  if (isa<SideEffects::AutomaticAllocationScopeResource>(effect->getResource()))
    allocScopeOp = op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  else
    allocScopeOp = op->getParentOfType<FunctionOpInterface>();
  return success();
}
/// Alias query between two values that are already underlying address values.
///
/// Uses only two facts about each value:
/// - whether it is a constant (matched via m_Constant);
/// - whether it carries an allocation effect, and that allocation's scope.
AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) {
  if (lhs == rhs)
    return AliasResult::MustAlias;

  Operation *lhsScope = nullptr;
  Operation *rhsScope = nullptr;
  std::optional<MemoryEffects::EffectInstance> lhsEffect;
  std::optional<MemoryEffects::EffectInstance> rhsEffect;

  // Constants: two constants are conservatively treated as possibly aliasing.
  // A constant and a value with an allocation effect do not alias.
  Attribute lhsAttr, rhsAttr;
  bool lhsIsConstant = matchPattern(lhs, m_Constant(&lhsAttr));
  if (lhsIsConstant && matchPattern(rhs, m_Constant(&rhsAttr)))
    return AliasResult::MayAlias;
  if (lhsIsConstant) {
    if (succeeded(getAllocEffectFor(rhs, rhsEffect, rhsScope)))
      return AliasResult::NoAlias;
    return AliasResult::MayAlias;
  }
  if (matchPattern(rhs, m_Constant(&rhsAttr))) {
    if (succeeded(getAllocEffectFor(lhs, lhsEffect, lhsScope)))
      return AliasResult::NoAlias;
    return AliasResult::MayAlias;
  }

  bool lhsHasAlloc = succeeded(getAllocEffectFor(lhs, lhsEffect, lhsScope));
  bool rhsHasAlloc = succeeded(getAllocEffectFor(rhs, rhsEffect, rhsScope));
  // Two values that each carry an allocation effect are treated as distinct.
  if (lhsHasAlloc && rhsHasAlloc)
    return AliasResult::NoAlias;
  // With no allocation effect on either side, nothing can be concluded.
  if (!lhsHasAlloc && !rhsHasAlloc)
    return AliasResult::MayAlias;

  // Exactly one side is allocated. Compare the other value against that
  // allocation's scope.
  Value other = lhsHasAlloc ? rhs : lhs;
  Operation *allocScope = lhsHasAlloc ? lhsScope : rhsScope;
  if (!allocScope)
    return AliasResult::MayAlias;

  // A value defined strictly outside the allocation scope, or an entry-block
  // argument of the scope op's region, cannot refer to the allocation.
  Operation *otherParent = other.getParentRegion()->getParentOp();
  if (otherParent->isProperAncestor(allocScope))
    return AliasResult::NoAlias;
  if (otherParent == allocScope && isa<BlockArgument>(other) &&
      other.getParentBlock()->isEntryBlock())
    return AliasResult::NoAlias;

  return AliasResult::MayAlias;
}
/// Alias query between two arbitrary values.
///
/// Each value is expanded into its underlying address values. The pairwise
/// aliasImpl results are then merged, iterating every lhs root against every
/// rhs root.
AliasResult LocalAliasAnalysis::alias(Value lhs, Value rhs) {
  if (lhs == rhs)
    return AliasResult::MustAlias;

  SmallVector<Value, 8> lhsRoots;
  SmallVector<Value, 8> rhsRoots;
  collectUnderlyingAddressValues(lhs, lhsRoots);
  collectUnderlyingAddressValues(rhs, rhsRoots);

  // Nothing was collected for one side: stay conservative.
  if (lhsRoots.empty() || rhsRoots.empty())
    return AliasResult::MayAlias;

  std::optional<AliasResult> combined;
  for (Value lhsRoot : lhsRoots) {
    for (Value rhsRoot : rhsRoots) {
      AliasResult pairResult = aliasImpl(lhsRoot, rhsRoot);
      if (!combined)
        combined = pairResult;
      else
        combined = combined->merge(pairResult);
    }
  }
  // Both lists are non-empty, so at least one pair was merged.
  return *combined;
}
/// Compute whether `op` may read (Ref) and/or write (Mod) `location`.
///
/// - Ops with HasRecursiveMemoryEffects, or without a MemoryEffectOpInterface,
///   are not inspected and are conservatively reported as ModAndRef.
/// - Otherwise each Read/Write effect that may alias `location` contributes a
///   Ref or Mod. Allocate and Free effects are ignored.
ModRefResult LocalAliasAnalysis::getModRef(Operation *op, Value location) {
  // Nested effects are not walked (doing so could require many queries).
  if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
    return ModRefResult::getModAndRef();
  auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!effectInterface)
    return ModRefResult::getModAndRef();

  SmallVector<MemoryEffects::EffectInstance> effectList;
  effectInterface.getEffects(effectList);

  ModRefResult accumulated = ModRefResult::getNoModRef();
  for (const MemoryEffects::EffectInstance &instance : effectList) {
    auto *effectKind = instance.getEffect();
    if (isa<MemoryEffects::Allocate, MemoryEffects::Free>(effectKind))
      continue;

    // An effect without an attached value is assumed to possibly touch
    // `location`. Otherwise skip it only when it provably doesn't alias.
    Value effectValue = instance.getValue();
    if (effectValue && alias(effectValue, location).isNo())
      continue;

    bool isRead = isa<MemoryEffects::Read>(effectKind);
    assert(isRead || isa<MemoryEffects::Write>(effectKind));
    accumulated = accumulated.merge(isRead ? ModRefResult::getRef()
                                           : ModRefResult::getMod());
    // ModAndRef is the most conservative answer; nothing can change it.
    if (accumulated.isModAndRef())
      return accumulated;
  }
  return accumulated;
}