#include "Utility.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <cstdint>
#include <functional>
#include <limits>
namespace deduceMin {
int deduceMinCountInBlock(Block &block,
                          const std::function<int(Operation *)> &countFunc);

/// Returns the minimum possible sum of `countFunc(op)` over the ops in the
/// half-open range [beginOp, endOp) of a single block; `endOp` itself is NOT
/// counted. Region-holding ops are handled structurally:
///   - scf.if: the cheaper of its two branches. An scf.if without an `else`
///     region is treated as having an empty else branch (contributing 0).
///   - scf.for: the body minimum multiplied by the constant trip count. A loop
///     whose trip count is not a known positive constant contributes 0, since
///     it may execute zero times.
///   - any other op: `countFunc(op)` (nested regions are not inspected).
/// Counts are assumed non-negative; the result saturates at INT_MAX rather
/// than overflowing.
int deduceMinCountBetweeOps(Operation *beginOp, Operation *endOp,
                            const std::function<int(Operation *)> &countFunc) {
  assert(beginOp && endOp);
  assert(beginOp == endOp || beginOp->isBeforeInBlock(endOp));
  // Accumulate in 64 bits: trip count * body count can exceed `int`.
  int64_t count = 0;
  for (Operation *op = beginOp; op != endOp; op = op->getNextNode()) {
    if (auto ifOp = llvm::dyn_cast<scf::IfOp>(op)) {
      assert(!ifOp.getThenRegion().empty());
      int minThen =
          deduceMinCountInBlock(ifOp.getThenRegion().front(), countFunc);
      // The else region of scf.if is optional; when it is absent, the
      // not-taken path executes nothing.
      Region &elseRegion = ifOp.getElseRegion();
      int minElse = elseRegion.empty()
                        ? 0
                        : deduceMinCountInBlock(elseRegion.front(), countFunc);
      count += std::min(minThen, minElse);
    } else if (auto forOp = llvm::dyn_cast<scf::ForOp>(op)) {
      auto tripCount = constantTripCount(forOp.getLowerBound(),
                                         forOp.getUpperBound(), forOp.getStep())
                           .value_or(0);
      if (tripCount > 0) {
        count += static_cast<int64_t>(tripCount) *
                 deduceMinCountInBlock(*forOp.getBody(), countFunc);
      }
    } else {
      count += countFunc(op);
    }
  }
  return static_cast<int>(
      std::min<int64_t>(count, std::numeric_limits<int>::max()));
}

/// Returns the minimum count over every op of `block` except the last one.
/// For the scf.if / scf.for bodies this is called on, the last op is the
/// region terminator, which is therefore never counted.
int deduceMinCountInBlock(Block &block,
                          const std::function<int(Operation *)> &countFunc) {
  if (block.empty())
    return 0;
  return deduceMinCountBetweeOps(&block.front(), &block.back(), countFunc);
}
} // namespace deduceMin
int deduceMinCountOnDefChain(Value defValue, Operation *consumerOp,
const std::function<int(Operation *)> &countFunc,
int pathSum, int foundMin) {
using namespace deduceMin;
while (consumerOp->getParentRegion() != defValue.getParentRegion()) {
pathSum += deduceMin::deduceMinCountBetweeOps(
&consumerOp->getBlock()->front(), consumerOp, countFunc);
consumerOp = consumerOp->getParentOp();
}
if (Operation *defOp = defValue.getDefiningOp()) {
pathSum +=
deduceMinCountBetweeOps(defOp->getNextNode(), consumerOp, countFunc);
foundMin = std::min(foundMin, pathSum);
return foundMin;
}
if (auto arg = mlir::dyn_cast<BlockArgument>(defValue)) {
Block *block = arg.getOwner();
auto forOp = dyn_cast<scf::ForOp>(block->getParentOp());
if (!forOp || forOp.getBody()->empty()) {
return 0;
}
Operation *firstOpInLoop = &*forOp.getBody()->begin();
pathSum += deduceMinCountBetweeOps(firstOpInLoop, consumerOp, countFunc);
if (pathSum >= foundMin)
return foundMin;
Value incomingVal = forOp.getInitArgs()[arg.getArgNumber() - 1];
int countLoopInit = deduceMinCountOnDefChain(incomingVal, forOp, countFunc,
pathSum, foundMin);
Operation *yieldOp = block->getTerminator();
Value prevVal = yieldOp->getOperand(arg.getArgNumber() - 1);
int countPreviousIter = deduceMinCountOnDefChain(
prevVal, yieldOp, countFunc, pathSum, foundMin);
return std::min(std::min(countLoopInit, countPreviousIter), foundMin);
}
return 0;
}
/// Convenience entry point: computes the minimum `countFunc` sum between the
/// definition of `defValue` and `consumerOp`. The search starts with an empty
/// path and no minimum found yet.
int deduceMinCountOnDefChain(Value defValue, Operation *consumerOp,
                             llvm::function_ref<int(Operation *)> countFunc) {
  const int emptyPathSum = 0;
  const int noMinFoundYet = std::numeric_limits<int>::max();
  return deduceMinCountOnDefChain(defValue, consumerOp, countFunc,
                                  emptyPathSum, noMinFoundYet);
}