#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include <cassert>
#include <optional>
using namespace mlir;
using namespace mlir::dataflow;
/// Called when this lattice changes. After forwarding the notification to the
/// base class, re-enqueues every user of the underlying value once for each
/// analysis registered in `useDefSubscribers`.
void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
  AnalysisState::onUpdate(solver);

  // This state is anchored on an SSA value; its users are the operations that
  // have to be revisited.
  Value anchor = point.get<Value>();
  for (Operation *userOp : anchor.getUsers()) {
    for (DataFlowAnalysis *subscriber : useDefSubscribers)
      solver->enqueue({userOp, subscriber});
  }
}
/// Builds the forward analysis on top of `solver` and registers `CFGEdge` as a
/// program point kind, which `visitBlock` later uses to query edge liveness.
AbstractSparseForwardDataFlowAnalysis::AbstractSparseForwardDataFlowAnalysis(
    DataFlowSolver &solver)
    : DataFlowAnalysis(solver) {
  registerPointKind<CFGEdge>();
}
/// Initializes the analysis starting at `top`.
///
/// The arguments of the entry block of every non-empty region of `top` are
/// first set to their entry state; afterwards every nested operation and block
/// is visited through `initializeRecursively`.
///
/// \returns the result of `initializeRecursively(top)`.
LogicalResult
AbstractSparseForwardDataFlowAnalysis::initialize(Operation *top) {
  for (Region &region : top->getRegions()) {
    // An empty region has no entry block and therefore no arguments to seed.
    if (region.empty())
      continue;
    for (Value argument : region.front().getArguments())
      setToEntryState(getLatticeElement(argument));
  }
  return initializeRecursively(top);
}
/// Visits `op` and then, for every block nested in its regions, subscribes
/// this analysis to the block's `Executable` state, visits the block and
/// recurses into each operation the block contains.
///
/// \returns failure as soon as a nested call fails, success otherwise.
LogicalResult
AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) {
  visitOperation(op);
  for (Region &region : op->getRegions()) {
    for (Block &block : region) {
      // Subscribe before visiting so that later changes to the block's
      // executable state are reported to this analysis.
      getOrCreate<Executable>(&block)->blockContentSubscribe(this);
      visitBlock(&block);
      for (Operation &nestedOp : block)
        if (failed(initializeRecursively(&nestedOp)))
          return failure();
    }
  }
  return success();
}
/// Dispatches a program point to the matching visitor: operations go to
/// `visitOperation`, blocks go to `visitBlock`. Any other kind of program
/// point yields failure.
LogicalResult AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint point) {
  if (Operation *operation = llvm::dyn_cast_if_present<Operation *>(point)) {
    visitOperation(operation);
    return success();
  }
  if (Block *blockPoint = llvm::dyn_cast_if_present<Block *>(point)) {
    visitBlock(blockPoint);
    return success();
  }
  return failure();
}
/// Updates the lattices of the results of `op`.
///
/// Nothing happens for operations without results or whose parent block is not
/// live. Region branch operations are handled through `visitRegionSuccessors`,
/// calls either through `visitExternalCallImpl` or by joining the operands of
/// the known predecessors, and every other operation is handed to
/// `visitOperationImpl`.
void AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
  if (op->getNumResults() == 0)
    return;
  if (!getOrCreate<Executable>(op->getBlock())->isLive())
    return;

  // Collect one lattice per result.
  SmallVector<AbstractSparseLattice *> resultLattices;
  resultLattices.reserve(op->getNumResults());
  for (Value result : op->getResults())
    resultLattices.push_back(getLatticeElement(result));

  // Results of a region branch op come from its region control flow.
  if (auto regionBranch = dyn_cast<RegionBranchOpInterface>(op)) {
    visitRegionSuccessors({regionBranch}, regionBranch,
                          RegionBranchPoint::parent(), resultLattices);
    return;
  }

  // Collect one lattice per operand, subscribing this analysis to each of
  // them along the use-def chain.
  SmallVector<const AbstractSparseLattice *> operandLattices;
  operandLattices.reserve(op->getNumOperands());
  for (Value operand : op->getOperands()) {
    AbstractSparseLattice *lattice = getLatticeElement(operand);
    lattice->useDefSubscribe(this);
    operandLattices.push_back(lattice);
  }

  auto call = dyn_cast<CallOpInterface>(op);
  if (!call) {
    visitOperationImpl(op, operandLattices, resultLattices);
    return;
  }

  // The call is treated as external when the solver is not interprocedural or
  // when the resolved callable has no callable region.
  auto callable =
      dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
  const bool treatAsExternal = !getSolverConfig().isInterprocedural() ||
                               (callable && !callable.getCallableRegion());
  if (treatAsExternal) {
    visitExternalCallImpl(call, operandLattices, resultLattices);
    return;
  }

  const auto *predecessors = getOrCreateFor<PredecessorState>(op, call);
  if (!predecessors->allPredecessorsKnown()) {
    setAllToEntryStates(resultLattices);
    return;
  }

  // Join the operands of every known predecessor into the call results.
  for (Operation *predecessor : predecessors->getKnownPredecessors()) {
    for (auto [forwarded, resultLattice] :
         llvm::zip(predecessor->getOperands(), resultLattices))
      join(resultLattice, *getLatticeElementFor(op, forwarded));
  }
}
/// Updates the lattices of the arguments of `block`.
///
/// Blocks without arguments and blocks that are not live are skipped. Entry
/// blocks are resolved from the call sites of a callable parent, from region
/// control flow of a region-branch parent, or through
/// `visitNonControlFlowArgumentsImpl`. Other blocks join the operands
/// forwarded along each live incoming CFG edge.
void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
  if (block->getNumArguments() == 0)
    return;
  if (!getOrCreate<Executable>(block)->isLive())
    return;

  // Collect one lattice per block argument.
  SmallVector<AbstractSparseLattice *> argLattices;
  argLattices.reserve(block->getNumArguments());
  for (BlockArgument argument : block->getArguments())
    argLattices.push_back(getLatticeElement(argument));

  if (block->isEntryBlock()) {
    Operation *parentOp = block->getParentOp();

    // Entry block of a callable region: arguments come from the call sites.
    auto callable = dyn_cast<CallableOpInterface>(parentOp);
    if (callable && callable.getCallableRegion() == block->getParent()) {
      const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
      if (!callsites->allPredecessorsKnown() ||
          !getSolverConfig().isInterprocedural()) {
        setAllToEntryStates(argLattices);
        return;
      }
      for (Operation *callsite : callsites->getKnownPredecessors()) {
        auto call = cast<CallOpInterface>(callsite);
        for (auto [argOperand, argLattice] :
             llvm::zip(call.getArgOperands(), argLattices))
          join(argLattice, *getLatticeElementFor(block, argOperand));
      }
      return;
    }

    // Entry block of a region owned by a region branch op.
    if (auto regionBranch = dyn_cast<RegionBranchOpInterface>(parentOp)) {
      visitRegionSuccessors(block, regionBranch, block->getParent(),
                            argLattices);
      return;
    }

    // Neither callable nor region branch: defer to the extension hook.
    visitNonControlFlowArgumentsImpl(parentOp,
                                     RegionSuccessor(block->getParent()),
                                     argLattices, 0);
    return;
  }

  // Non-entry block: walk the predecessor edges. The iterator itself is
  // needed because it provides the successor index of the edge.
  for (Block::pred_iterator predIt = block->pred_begin(),
                            predEnd = block->pred_end();
       predIt != predEnd; ++predIt) {
    Block *predecessor = *predIt;

    auto *edgeExecutable =
        getOrCreate<Executable>(getProgramPoint<CFGEdge>(predecessor, block));
    edgeExecutable->blockContentSubscribe(this);
    if (!edgeExecutable->isLive())
      continue;

    auto branch = dyn_cast<BranchOpInterface>(predecessor->getTerminator());
    if (!branch) {
      // The terminator does not implement BranchOpInterface; reset every
      // argument to its entry state and stop.
      setAllToEntryStates(argLattices);
      return;
    }

    SuccessorOperands operands =
        branch.getSuccessorOperands(predIt.getSuccessorIndex());
    for (auto [idx, lattice] : llvm::enumerate(argLattices)) {
      Value operand = operands[idx];
      if (operand)
        join(lattice, *getLatticeElementFor(block, operand));
      else
        // No forwarded value for this argument: use its entry state.
        setAllToEntryStates(lattice);
    }
  }
}
/// Joins the values forwarded by the region predecessors of `point` into
/// `lattices`.
///
/// \param point     the operation or block whose inputs are being updated.
/// \param branch    the region branch operation that owns the control flow.
/// \param successor the region branch point being entered.
/// \param lattices  the lattices of the results (operation point) or block
///                  arguments (block point) to update.
///
/// When a predecessor forwards fewer values than there are lattices,
/// `visitNonControlFlowArgumentsImpl` is invoked with the index of the first
/// forwarded input, and the join starts at that index.
void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
    ProgramPoint point, RegionBranchOpInterface branch,
    RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
  const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
  assert(predecessors->allPredecessorsKnown() &&
         "unexpected unresolved region successors");

  for (Operation *predecessorOp : predecessors->getKnownPredecessors()) {
    // Operands forwarded by this predecessor, if they can be determined.
    std::optional<OperandRange> forwarded;
    if (predecessorOp == branch) {
      forwarded = branch.getEntrySuccessorOperands(successor);
    } else if (auto terminator =
                   dyn_cast<RegionBranchTerminatorOpInterface>(predecessorOp)) {
      forwarded = terminator.getSuccessorOperands(successor);
    }

    if (!forwarded) {
      setAllToEntryStates(lattices);
      return;
    }

    ValueRange inputs = predecessors->getSuccessorInputs(predecessorOp);
    assert(inputs.size() == forwarded->size() &&
           "expected the same number of successor inputs as operands");

    unsigned firstIndex = 0;
    if (inputs.size() != lattices.size()) {
      if (llvm::dyn_cast_if_present<Operation *>(point)) {
        // Operation point: the inputs are results of `branch`.
        if (!inputs.empty())
          firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
        RegionSuccessor parentSuccessor(
            branch->getResults().slice(firstIndex, inputs.size()));
        visitNonControlFlowArgumentsImpl(branch, parentSuccessor, lattices,
                                         firstIndex);
      } else {
        // Block point: the inputs are arguments of the block's region.
        if (!inputs.empty())
          firstIndex = cast<BlockArgument>(inputs.front()).getArgNumber();
        Region *region = point.get<Block *>()->getParent();
        RegionSuccessor regionSuccessor(
            region, region->getArguments().slice(firstIndex, inputs.size()));
        visitNonControlFlowArgumentsImpl(branch, regionSuccessor, lattices,
                                         firstIndex);
      }
    }

    for (auto [operand, lattice] :
         llvm::zip(*forwarded, lattices.drop_front(firstIndex)))
      join(lattice, *getLatticeElementFor(point, operand));
  }
}
/// Returns the lattice attached to `value` as a read-only pointer, after
/// registering a dependency between that lattice and `point` through
/// `addDependency`.
const AbstractSparseLattice *
AbstractSparseForwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint point,
                                                            Value value) {
  AbstractSparseLattice *lattice = getLatticeElement(value);
  addDependency(lattice, point);
  return lattice;
}
/// Applies `setToEntryState` to each lattice in `lattices`, in order.
void AbstractSparseForwardDataFlowAnalysis::setAllToEntryStates(
    ArrayRef<AbstractSparseLattice *> lattices) {
  for (AbstractSparseLattice *entry : lattices) {
    setToEntryState(entry);
  }
}
/// Joins `rhs` into `lhs` and passes the resulting change status to
/// `propagateIfChanged`.
void AbstractSparseForwardDataFlowAnalysis::join(
    AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs) {
  ChangeResult changed = lhs->join(rhs);
  propagateIfChanged(lhs, changed);
}
/// Builds the backward analysis on top of `solver`, keeps a reference to
/// `symbolTable` (used by `visitOperation` to resolve callables) and registers
/// `CFGEdge` as a program point kind.
AbstractSparseBackwardDataFlowAnalysis::AbstractSparseBackwardDataFlowAnalysis(
    DataFlowSolver &solver, SymbolTableCollection &symbolTable)
    : DataFlowAnalysis(solver), symbolTable(symbolTable) {
  registerPointKind<CFGEdge>();
}
/// Initializes the backward analysis by recursively visiting `top` and
/// everything nested in it.
LogicalResult
AbstractSparseBackwardDataFlowAnalysis::initialize(Operation *top) {
  LogicalResult status = initializeRecursively(top);
  return status;
}
/// Visits `op` and then, for every block nested in its regions, subscribes
/// this analysis to the block's `Executable` state and recurses into the
/// block's operations from last to first.
///
/// \returns failure as soon as a nested call fails, success otherwise.
LogicalResult
AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) {
  visitOperation(op);
  for (Region &region : op->getRegions()) {
    for (Block &block : region) {
      getOrCreate<Executable>(&block)->blockContentSubscribe(this);
      // Walk the operations in reverse order, matching the direction in which
      // this analysis propagates information.
      for (Operation &nestedOp : llvm::reverse(block))
        if (failed(initializeRecursively(&nestedOp)))
          return failure();
    }
  }
  return success();
}
/// Dispatches a program point: operations go to `visitOperation`, blocks are
/// accepted without any work, and any other kind of point yields failure.
LogicalResult
AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
  if (Operation *operation = llvm::dyn_cast_if_present<Operation *>(point)) {
    visitOperation(operation);
    return success();
  }
  // Blocks need no processing of their own here.
  if (llvm::dyn_cast_if_present<Block *>(point))
    return success();
  return failure();
}
/// Returns the mutable lattice of every value in `values`, in the same order.
SmallVector<AbstractSparseLattice *>
AbstractSparseBackwardDataFlowAnalysis::getLatticeElements(ValueRange values) {
  SmallVector<AbstractSparseLattice *> lattices;
  lattices.reserve(values.size());
  for (Value value : values)
    lattices.push_back(getLatticeElement(value));
  return lattices;
}
/// Returns the read-only lattice of every value in `values`, in the same
/// order, obtaining each one through `getLatticeElementFor(point, value)`.
SmallVector<const AbstractSparseLattice *>
AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor(
    ProgramPoint point, ValueRange values) {
  SmallVector<const AbstractSparseLattice *> lattices;
  lattices.reserve(values.size());
  for (Value value : values)
    lattices.push_back(getLatticeElementFor(point, value));
  return lattices;
}
/// Views an `OperandRange` as the `OpOperand`s it covers: the returned array
/// starts at the range's base pointer and has the range's size.
static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) {
  OpOperand *first = operands.getBase();
  const size_t count = operands.size();
  return MutableArrayRef<OpOperand>(first, count);
}
/// Updates the lattices of the operands of `op`.
///
/// Operations in blocks that are not live are skipped. The cases are tried in
/// this order: region branch ops, branch ops, calls whose callable resolves to
/// a `CallableOpInterface`, region branch terminators nested in a region
/// branch op, return-like ops nested in a callable, and finally the generic
/// `visitOperationImpl` hook.
void AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
  if (!getOrCreate<Executable>(op->getBlock())->isLive())
    return;

  SmallVector<AbstractSparseLattice *> operandLattices =
      getLatticeElements(op->getOperands());
  SmallVector<const AbstractSparseLattice *> resultLattices =
      getLatticeElementsFor(op, op->getResults());

  // Region branch op: handled by the dedicated helper.
  if (auto regionBranch = dyn_cast<RegionBranchOpInterface>(op)) {
    visitRegionSuccessors(regionBranch, operandLattices);
    return;
  }

  // Branch op: meet each forwarded operand with the matching successor block
  // argument.
  if (auto branch = dyn_cast<BranchOpInterface>(op)) {
    // Bits stay set for operands that are not forwarded to any successor.
    BitVector notForwarded(op->getNumOperands(), true);
    for (auto [index, successorBlock] : llvm::enumerate(op->getSuccessors())) {
      SuccessorOperands successorOperands = branch.getSuccessorOperands(index);
      OperandRange forwarded = successorOperands.getForwardedOperands();
      if (forwarded.empty())
        continue;
      MutableArrayRef<OpOperand> forwardedOperands = op->getOpOperands().slice(
          forwarded.getBeginOperandIndex(), forwarded.size());
      for (OpOperand &operand : forwardedOperands) {
        notForwarded.reset(operand.getOperandNumber());
        std::optional<BlockArgument> blockArg =
            detail::getBranchSuccessorArgument(
                successorOperands, operand.getOperandNumber(), successorBlock);
        if (blockArg)
          meet(getLatticeElement(operand.get()),
               *getLatticeElementFor(op, *blockArg));
      }
    }
    for (int index : notForwarded.set_bits())
      visitBranchOperand(op->getOpOperand(index));
    return;
  }

  // Call op: meet each argument operand with the matching entry block
  // argument of the callee.
  if (auto call = dyn_cast<CallOpInterface>(op)) {
    Operation *callableOp = call.resolveCallable(&symbolTable);
    if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
      OperandRange argOperands = call.getArgOperands();
      MutableArrayRef<OpOperand> argOpOperands =
          operandsToOpOperands(argOperands);

      // No usable body, or interprocedural analysis disabled: use the
      // external-call hook.
      Region *region = callable.getCallableRegion();
      if (!region || region->empty() ||
          !getSolverConfig().isInterprocedural()) {
        visitExternalCallImpl(call, operandLattices, resultLattices);
        return;
      }

      // Bits stay set for call operands that are not passed as arguments.
      BitVector notForwarded(op->getNumOperands(), true);
      Block &entryBlock = region->front();
      for (auto [blockArg, argOpOperand] :
           llvm::zip(entryBlock.getArguments(), argOpOperands)) {
        meet(getLatticeElement(argOpOperand.get()),
             *getLatticeElementFor(op, blockArg));
        notForwarded.reset(argOpOperand.getOperandNumber());
      }
      for (int index : notForwarded.set_bits())
        visitCallOperand(op->getOpOperand(index));
      return;
    }
  }

  // Region branch terminator nested directly in a region branch op.
  if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
    if (auto parentBranch =
            dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
      visitRegionSuccessorsFromTerminator(terminator, parentBranch);
      return;
    }
  }

  // Return-like op nested directly in a callable: the operands are met with
  // the results of every known call site.
  if (op->hasTrait<OpTrait::ReturnLike>()) {
    if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) {
      const PredecessorState *callsites =
          getOrCreateFor<PredecessorState>(op, callable);
      if (!callsites->allPredecessorsKnown()) {
        // Unknown callers: reset the operands to their exit state.
        setAllToExitStates(operandLattices);
        return;
      }
      for (Operation *callsite : callsites->getKnownPredecessors()) {
        SmallVector<const AbstractSparseLattice *> callResultLattices =
            getLatticeElementsFor(op, callsite->getResults());
        for (auto [operandLattice, callResultLattice] :
             llvm::zip(operandLattices, callResultLattices))
          meet(operandLattice, *callResultLattice);
      }
      return;
    }
  }

  visitOperationImpl(op, operandLattices, resultLattices);
}
/// Handles a region branch operation: for each entry successor of `branch`,
/// the operands forwarded to that successor are met with the successor's
/// inputs. Operands that are forwarded to no successor are passed to
/// `visitBranchOperand`.
void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
    RegionBranchOpInterface branch,
    ArrayRef<AbstractSparseLattice *> operandLattices) {
  Operation *op = branch.getOperation();

  // Query the entry successors with every operand attribute left null.
  SmallVector<Attribute> operandAttributes(op->getNumOperands(), nullptr);
  SmallVector<RegionSuccessor> successors;
  branch.getEntrySuccessorRegions(operandAttributes, successors);

  // Bits stay set for operands that are not forwarded to any successor.
  BitVector notForwarded(op->getNumOperands(), true);
  for (RegionSuccessor &successor : successors) {
    OperandRange forwarded = branch.getEntrySuccessorOperands(successor);
    MutableArrayRef<OpOperand> forwardedOperands =
        operandsToOpOperands(forwarded);
    ValueRange inputs = successor.getSuccessorInputs();
    for (auto [operand, input] : llvm::zip(forwardedOperands, inputs)) {
      meet(getLatticeElement(operand.get()), *getLatticeElementFor(op, input));
      notForwarded.reset(operand.getOperandNumber());
    }
  }

  for (int index : notForwarded.set_bits())
    visitBranchOperand(op->getOpOperand(index));
}
/// Handles a region branch terminator: for each region successor of
/// `terminator`, the operands forwarded to that successor are met with the
/// successor's inputs. Operands that are forwarded to no successor are passed
/// to `visitBranchOperand`.
///
/// `branch` must be the parent operation of `terminator` (asserted).
void AbstractSparseBackwardDataFlowAnalysis::
    visitRegionSuccessorsFromTerminator(
        RegionBranchTerminatorOpInterface terminator,
        RegionBranchOpInterface branch) {
  assert(isa<RegionBranchTerminatorOpInterface>(terminator) &&
         "expected a `RegionBranchTerminatorOpInterface` op");
  assert(terminator->getParentOp() == branch.getOperation() &&
         "expected `branch` to be the parent op of `terminator`");

  // Query the successors with every operand attribute left null.
  const unsigned numOperands = terminator->getNumOperands();
  SmallVector<Attribute> operandAttributes(numOperands, nullptr);
  SmallVector<RegionSuccessor> successors;
  terminator.getSuccessorRegions(operandAttributes, successors);

  // Bits stay set for operands that are not forwarded to any successor.
  BitVector notForwarded(terminator->getNumOperands(), true);
  for (const RegionSuccessor &successor : successors) {
    ValueRange inputs = successor.getSuccessorInputs();
    OperandRange forwarded = terminator.getSuccessorOperands(successor);
    MutableArrayRef<OpOperand> forwardedOperands =
        operandsToOpOperands(forwarded);
    for (auto [opOperand, input] : llvm::zip(forwardedOperands, inputs)) {
      meet(getLatticeElement(opOperand.get()),
           *getLatticeElementFor(terminator, input));
      notForwarded.reset(opOperand.getOperandNumber());
    }
  }

  for (int index : notForwarded.set_bits())
    visitBranchOperand(terminator->getOpOperand(index));
}
/// Returns the lattice attached to `value` as a read-only pointer, after
/// registering a dependency between that lattice and `point` through
/// `addDependency`.
const AbstractSparseLattice *
AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint point,
                                                             Value value) {
  AbstractSparseLattice *lattice = getLatticeElement(value);
  addDependency(lattice, point);
  return lattice;
}
/// Applies `setToExitState` to each lattice in `lattices`, in order.
void AbstractSparseBackwardDataFlowAnalysis::setAllToExitStates(
    ArrayRef<AbstractSparseLattice *> lattices) {
  for (AbstractSparseLattice *exitLattice : lattices) {
    setToExitState(exitLattice);
  }
}
/// Meets `rhs` into `lhs` and passes the resulting change status to
/// `propagateIfChanged`.
void AbstractSparseBackwardDataFlowAnalysis::meet(
    AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs) {
  ChangeResult changed = lhs->meet(rhs);
  propagateIfChanged(lhs, changed);
}