#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <optional>
#define DEBUG_TYPE "affine-utils"
using namespace mlir;
using namespace affine;
using namespace presburger;
namespace {
/// Lowers an AffineExpr into a tree of `arith` dialect operations.
///
/// Dimension and symbol expressions resolve to the matching entries of
/// `dimValues` / `symbolValues`; every other expression kind creates new ops
/// at the builder's current insertion point. A null Value is returned when the
/// expression, or any of its sub-expressions, cannot be lowered (mod, floordiv
/// or ceildiv by a non-positive constant); an error is emitted at `loc` for the
/// offending sub-expression and the failure is propagated to the root.
class AffineApplyExpander
    : public AffineExprVisitor<AffineApplyExpander, Value> {
public:
  /// Creates an expander that builds ops with `builder` at location `loc`.
  AffineApplyExpander(OpBuilder &builder, ValueRange dimValues,
                      ValueRange symbolValues, Location loc)
      : builder(builder), dimValues(dimValues), symbolValues(symbolValues),
        loc(loc) {}

  /// Lowers both operands of `expr` and combines them with a single `OpTy`.
  /// Returns null if either operand failed to lower.
  template <typename OpTy>
  Value buildBinaryExpr(AffineBinaryOpExpr expr) {
    auto lhs = visit(expr.getLHS());
    auto rhs = visit(expr.getRHS());
    if (!lhs || !rhs)
      return nullptr;
    auto op = builder.create<OpTy>(loc, lhs, rhs);
    return op.getResult();
  }

  Value visitAddExpr(AffineBinaryOpExpr expr) {
    return buildBinaryExpr<arith::AddIOp>(expr);
  }

  Value visitMulExpr(AffineBinaryOpExpr expr) {
    return buildBinaryExpr<arith::MulIOp>(expr);
  }

  /// Lowers `lhs mod rhs` as
  ///   r = lhs remsi rhs
  ///   result = r < 0 ? r + rhs : r
  /// which lies in [0, rhs) whenever `rhs` is positive. Only a constant `rhs`
  /// can be checked for positivity here.
  Value visitModExpr(AffineBinaryOpExpr expr) {
    if (!hasPositiveOrUnknownRHS(
            expr, "modulo by non-positive value is not supported"))
      return nullptr;

    auto lhs = visit(expr.getLHS());
    auto rhs = visit(expr.getRHS());
    // A nested sub-expression may have failed to lower (its error has already
    // been reported); propagate the failure instead of building ops on null
    // operands.
    if (!lhs || !rhs)
      return nullptr;

    Value remainder = builder.create<arith::RemSIOp>(loc, lhs, rhs);
    Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
    Value isRemainderNegative = builder.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, remainder, zeroCst);
    Value correctedRemainder =
        builder.create<arith::AddIOp>(loc, remainder, rhs);
    Value result = builder.create<arith::SelectOp>(
        loc, isRemainderNegative, correctedRemainder, remainder);
    return result;
  }

  /// Lowers `lhs floordiv rhs` using only truncating division:
  ///   neg = lhs < 0
  ///   q = (neg ? -1 - lhs : lhs) divsi rhs
  ///   result = neg ? -1 - q : q
  /// which rounds toward negative infinity when `rhs` is positive.
  Value visitFloorDivExpr(AffineBinaryOpExpr expr) {
    if (!hasPositiveOrUnknownRHS(
            expr, "division by non-positive value is not supported"))
      return nullptr;

    auto lhs = visit(expr.getLHS());
    auto rhs = visit(expr.getRHS());
    // Propagate failures of nested sub-expressions.
    if (!lhs || !rhs)
      return nullptr;

    Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
    Value noneCst = builder.create<arith::ConstantIndexOp>(loc, -1);
    Value negative = builder.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, lhs, zeroCst);
    Value negatedDecremented = builder.create<arith::SubIOp>(loc, noneCst, lhs);
    Value dividend =
        builder.create<arith::SelectOp>(loc, negative, negatedDecremented, lhs);
    Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs);
    Value correctedQuotient =
        builder.create<arith::SubIOp>(loc, noneCst, quotient);
    Value result = builder.create<arith::SelectOp>(loc, negative,
                                                   correctedQuotient, quotient);
    return result;
  }

  /// Lowers `lhs ceildiv rhs` using only truncating division:
  ///   nonPos = lhs <= 0
  ///   q = (nonPos ? 0 - lhs : lhs - 1) divsi rhs
  ///   result = nonPos ? 0 - q : q + 1
  /// which rounds toward positive infinity when `rhs` is positive.
  Value visitCeilDivExpr(AffineBinaryOpExpr expr) {
    if (!hasPositiveOrUnknownRHS(
            expr, "division by non-positive value is not supported"))
      return nullptr;

    auto lhs = visit(expr.getLHS());
    auto rhs = visit(expr.getRHS());
    // Propagate failures of nested sub-expressions.
    if (!lhs || !rhs)
      return nullptr;

    Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
    Value oneCst = builder.create<arith::ConstantIndexOp>(loc, 1);
    Value nonPositive = builder.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sle, lhs, zeroCst);
    Value negated = builder.create<arith::SubIOp>(loc, zeroCst, lhs);
    Value decremented = builder.create<arith::SubIOp>(loc, lhs, oneCst);
    Value dividend =
        builder.create<arith::SelectOp>(loc, nonPositive, negated, decremented);
    Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs);
    Value negatedQuotient =
        builder.create<arith::SubIOp>(loc, zeroCst, quotient);
    Value incrementedQuotient =
        builder.create<arith::AddIOp>(loc, quotient, oneCst);
    Value result = builder.create<arith::SelectOp>(
        loc, nonPositive, negatedQuotient, incrementedQuotient);
    return result;
  }

  /// Materializes a constant as an `arith.constant` of index type.
  Value visitConstantExpr(AffineConstantExpr expr) {
    auto op = builder.create<arith::ConstantIndexOp>(loc, expr.getValue());
    return op.getResult();
  }

  /// Dimensions resolve directly to the provided dimension values.
  Value visitDimExpr(AffineDimExpr expr) {
    assert(expr.getPosition() < dimValues.size() &&
           "affine dim position out of range");
    return dimValues[expr.getPosition()];
  }

  /// Symbols resolve directly to the provided symbol values.
  Value visitSymbolExpr(AffineSymbolExpr expr) {
    assert(expr.getPosition() < symbolValues.size() &&
           "symbol dim position out of range");
    return symbolValues[expr.getPosition()];
  }

private:
  /// Emits `message` at `loc` and returns false if the RHS of `expr` is a
  /// constant that is not strictly positive. A non-constant RHS cannot be
  /// checked statically and is accepted.
  bool hasPositiveOrUnknownRHS(AffineBinaryOpExpr expr, const char *message) {
    auto rhsConst = dyn_cast<AffineConstantExpr>(expr.getRHS());
    if (rhsConst && rhsConst.getValue() <= 0) {
      emitError(loc, message);
      return false;
    }
    return true;
  }

  OpBuilder &builder;
  ValueRange dimValues;
  ValueRange symbolValues;
  Location loc;
};
}
/// Materializes `expr` as `arith` ops at `builder`'s insertion point, binding
/// the expression's dimensions to `dimValues` and its symbols to
/// `symbolValues`. Returns a null Value if the expression could not be lowered.
mlir::Value mlir::affine::expandAffineExpr(OpBuilder &builder, Location loc,
                                           AffineExpr expr,
                                           ValueRange dimValues,
                                           ValueRange symbolValues) {
  AffineApplyExpander expander(builder, dimValues, symbolValues, loc);
  return expander.visit(expr);
}
/// Expands every result of `affineMap` into `arith` ops. The leading
/// `affineMap.getNumDims()` entries of `operands` bind the map's dimensions;
/// the remaining entries bind its symbols. All results are expanded first;
/// std::nullopt is returned if any of them failed to lower.
std::optional<SmallVector<Value, 8>>
mlir::affine::expandAffineMap(OpBuilder &builder, Location loc,
                              AffineMap affineMap, ValueRange operands) {
  unsigned numDims = affineMap.getNumDims();
  ValueRange dimOperands = operands.take_front(numDims);
  ValueRange symOperands = operands.drop_front(numDims);

  SmallVector<Value, 8> results;
  results.reserve(affineMap.getNumResults());
  for (AffineExpr resultExpr : affineMap.getResults())
    results.push_back(
        expandAffineExpr(builder, loc, resultExpr, dimOperands, symOperands));

  if (llvm::any_of(results, [](Value v) { return !v; }))
    return std::nullopt;
  return results;
}
/// Inlines the `then` block (or the `else` block when `elseBlock` is true) of
/// `ifOp` right before `ifOp`, leaving the block's terminator behind, and then
/// erases `ifOp` (which also drops that terminator).
static void promoteIfBlock(AffineIfOp ifOp, bool elseBlock) {
  assert((!elseBlock || ifOp.hasElse()) && "else block expected");
  Block *srcBlock = elseBlock ? ifOp.getElseBlock() : ifOp.getThenBlock();
  Block *destBlock = ifOp->getBlock();
  auto &srcOps = srcBlock->getOperations();
  destBlock->getOperations().splice(Block::iterator(ifOp), srcOps,
                                    srcOps.begin(), std::prev(srcOps.end()));
  ifOp.erase();
}
/// Returns the outermost ancestor of `ifOp` (or `ifOp` itself) above which the
/// condition can be hoisted.
///
/// The walk moves up through enclosing affine.if ops, and through affine.for /
/// affine.parallel ops none of whose induction variables are operands of
/// `ifOp`. It never moves past any other kind of op, nor past the enclosing
/// func.func.
static Operation *getOutermostInvariantForOp(AffineIfOp ifOp) {
  auto ifOperands = ifOp.getOperands();
  Operation *res = ifOp.getOperation();
  while (!isa<func::FuncOp>(res->getParentOp())) {
    Operation *parentOp = res->getParentOp();
    if (auto forOp = dyn_cast<AffineForOp>(parentOp)) {
      // The condition depends on this loop's IV: cannot hoist past it.
      if (llvm::is_contained(ifOperands, forOp.getInductionVar()))
        break;
    } else if (auto parallelOp = dyn_cast<AffineParallelOp>(parentOp)) {
      // The condition depends on one of this loop's IVs: stop here. This must
      // leave the outer `while`, not merely an inner loop over the IVs;
      // otherwise the if would be hoisted out of the scope of the IV it uses.
      if (llvm::any_of(parallelOp.getIVs(), [&](auto iv) {
            return llvm::is_contained(ifOperands, iv);
          }))
        break;
    } else if (!isa<AffineIfOp>(parentOp)) {
      // Only affine.for / affine.parallel / affine.if ops are walked past.
      break;
    }
    // Enclosing affine.if ops can always be hoisted past.
    res = parentOp;
  }
  return res;
}
/// Hoists `ifOp` above `hoistOverOp`, which is expected to enclose `ifOp`
/// (presumably the result of getOutermostInvariantForOp — confirm at callers).
///
/// A new affine.if with the same integer set and operands, and with an else
/// region, is created right before `hoistOverOp`.
///  - Its `then` block receives `hoistOverOp` itself, in which `ifOp` is
///    replaced by the contents of its `then` block.
///  - Its `else` block receives a clone of `hoistOverOp`, in which the clone of
///    `ifOp` is replaced by the contents of its `else` block, or simply erased
///    when it has none.
/// Returns the new affine.if, or `ifOp` unchanged when `hoistOverOp` is `ifOp`
/// itself.
static AffineIfOp hoistAffineIfOp(AffineIfOp ifOp, Operation *hoistOverOp) {
  // Nothing to hoist over.
  if (hoistOverOp == ifOp)
    return ifOp;
  IRMapping operandMap;
  OpBuilder b(hoistOverOp);
  // Create the hoisted if in front of `hoistOverOp`. The trailing `true`
  // requests an else region; the else block is populated at the end.
  auto hoistedIfOp = b.create<AffineIfOp>(ifOp.getLoc(), ifOp.getIntegerSet(),
                                          ifOp.getOperands(),
                                          true);
  Operation *hoistOverOpClone = nullptr;
  // Marker attribute used to find the clone of `ifOp` inside the clone of
  // `hoistOverOp`.
  StringAttr idForIfOp = b.getStringAttr("__mlir_if_hoisting");
  // NOTE: `operandMap` was just default-constructed, so this clear() is a no-op.
  operandMap.clear();
  // Tag `ifOp`, then clone `hoistOverOp` (including the tagged `ifOp`) right
  // after the original.
  b.setInsertionPointAfter(hoistOverOp);
  ifOp->setAttr(idForIfOp, b.getBoolAttr(true));
  hoistOverOpClone = b.clone(*hoistOverOp, operandMap);
  // In the original copy, keep only the `then` body of `ifOp`. This erases
  // `ifOp`, and its marker attribute with it.
  promoteIfBlock(ifOp, false);
  // Move the original `hoistOverOp` into the then block of the new if.
  auto *thenBlock = hoistedIfOp.getThenBlock();
  thenBlock->getOperations().splice(thenBlock->begin(),
                                    hoistOverOp->getBlock()->getOperations(),
                                    Block::iterator(hoistOverOp));
  // Locate the tagged clone of `ifOp` inside the cloned `hoistOverOp`.
  AffineIfOp ifCloneInElse;
  hoistOverOpClone->walk([&](AffineIfOp ifClone) {
    if (!ifClone->getAttr(idForIfOp))
      return WalkResult::advance();
    ifCloneInElse = ifClone;
    return WalkResult::interrupt();
  });
  assert(ifCloneInElse && "if op clone should exist");
  // In the clone, keep only the `else` body of the if (nothing if it has no
  // else block).
  if (!ifCloneInElse.hasElse())
    ifCloneInElse.erase();
  else
    promoteIfBlock(ifCloneInElse, true);
  // Move the cloned `hoistOverOp` into the else block of the new if.
  auto *elseBlock = hoistedIfOp.getElseBlock();
  elseBlock->getOperations().splice(
      elseBlock->begin(), hoistOverOpClone->getBlock()->getOperations(),
      Block::iterator(hoistOverOpClone));
  return hoistedIfOp;
}
/// Replaces `forOp` with an equivalent one-dimensional affine.parallel op.
///
/// Every loop-carried value of `forOp` must be described by an entry of
/// `parallelReductions` (one per iter operand). Otherwise failure is returned
/// before any IR is modified.
///
/// The body of `forOp` is moved into the new op. For each reduction, the op
/// producing the yielded value is moved out of the loop, placed right after
/// the new parallel op, where it combines the original init value with the
/// parallel op's corresponding result. It is expected to be a single-op
/// reduction. Inside the loop, the terminator yields each reduction's `value`
/// directly, and the block arguments for the loop-carried values are removed.
///
/// On success `forOp` is erased and, if `resOp` is non-null, the new op is
/// stored in `*resOp`.
LogicalResult
mlir::affine::affineParallelize(AffineForOp forOp,
                                ArrayRef<LoopReduction> parallelReductions,
                                AffineParallelOp *resOp) {
  // Every iter operand must be a recognized reduction.
  unsigned numReductions = parallelReductions.size();
  if (numReductions != forOp.getNumIterOperands())
    return failure();

  Location loc = forOp.getLoc();
  OpBuilder outsideBuilder(forOp);
  AffineMap lowerBoundMap = forOp.getLowerBoundMap();
  ValueRange lowerBoundOperands = forOp.getLowerBoundOperands();
  AffineMap upperBoundMap = forOp.getUpperBoundMap();
  ValueRange upperBoundOperands = forOp.getUpperBoundOperands();

  // Create the 1-D parallel op, with one result per reduction, in front of
  // `forOp`.
  auto reducedValues = llvm::to_vector<4>(llvm::map_range(
      parallelReductions, [](const LoopReduction &red) { return red.value; }));
  auto reductionKinds = llvm::to_vector<4>(llvm::map_range(
      parallelReductions, [](const LoopReduction &red) { return red.kind; }));
  AffineParallelOp newPloop = outsideBuilder.create<AffineParallelOp>(
      loc, ValueRange(reducedValues).getTypes(), reductionKinds,
      llvm::ArrayRef(lowerBoundMap), lowerBoundOperands,
      llvm::ArrayRef(upperBoundMap), upperBoundOperands,
      llvm::ArrayRef(forOp.getStepAsInt()));
  // Move the body of `forOp` into the new op.
  newPloop.getRegion().takeBody(forOp.getRegion());
  Operation *yieldOp = &newPloop.getBody()->back();

  // The parallel op starts each reduction from its neutral value, so the
  // original init value is folded in after the loop. This relies on each
  // reduction being a single op whose operands are (accumulator, value).
  for (unsigned i = 0; i < numReductions; ++i) {
    Value init = forOp.getInits()[i];
    Operation *reductionOp = yieldOp->getOperand(i).getDefiningOp();
    assert(reductionOp && "yielded value is expected to be produced by an op");
    outsideBuilder.getInsertionBlock()->getOperations().splice(
        outsideBuilder.getInsertionPoint(), newPloop.getBody()->getOperations(),
        reductionOp);
    reductionOp->setOperands({init, newPloop->getResult(i)});
    forOp->getResult(i).replaceAllUsesWith(reductionOp->getResult(0));
  }

  // Yield the reduced values directly (bypassing the reduction ops that were
  // moved out), and drop the block arguments of the loop-carried values. A
  // loop coming from affine.for always has exactly one induction variable.
  unsigned numIVs = 1;
  yieldOp->setOperands(reducedValues);
  newPloop.getBody()->eraseArguments(numIVs, numReductions);

  forOp.erase();
  if (resOp)
    *resOp = newPloop;
  return success();
}
/// Hoists `ifOp` above the outermost enclosing ops it is invariant on.
///
/// Ifs that produce results are rejected. The op is first canonicalized on
/// its own.
///  - If that erases it, `*folded` (when non-null) is set to true and failure
///    is returned.
///  - Otherwise `*folded` is set to false. Failure is returned if there is
///    nothing to hoist over.
///  - On success, the closest isolated-from-above ancestor of the hoisted op
///    is canonicalized with the same affine.if patterns.
LogicalResult mlir::affine::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
  // Only result-less ifs are handled.
  if (ifOp.getNumResults() != 0)
    return failure();

  MLIRContext *ctx = ifOp.getContext();
  RewritePatternSet ifPatterns(ctx);
  AffineIfOp::getCanonicalizationPatterns(ifPatterns, ctx);
  FrozenRewritePatternSet frozenPatterns(std::move(ifPatterns));

  // Canonicalize only this op; newly created ops are not revisited.
  GreedyRewriteConfig config;
  config.strictMode = GreedyRewriteStrictness::ExistingOps;
  bool wasErased = false;
  (void)applyOpPatternsAndFold(ifOp.getOperation(), frozenPatterns, config,
                               /*changed=*/nullptr, &wasErased);
  if (folded)
    *folded = wasErased;
  if (wasErased)
    return failure();

  assert(llvm::all_of(ifOp.getOperands(),
                      [](Value v) {
                        return isTopLevelValue(v) || isAffineForInductionVar(v);
                      }) &&
         "operands not composed");

  Operation *hoistOverOp = getOutermostInvariantForOp(ifOp);
  AffineIfOp hoistedIfOp = ::hoistAffineIfOp(ifOp, hoistOverOp);
  if (hoistedIfOp == ifOp)
    return failure();

  (void)applyPatternsAndFoldGreedily(
      hoistedIfOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(),
      frozenPatterns);
  return success();
}
/// Replaces occurrences of `dim` in `e` with `min` on "positive" paths and
/// with `max` on "negative" paths.
///
/// The path polarity starts as `positivePath`. It is flipped when descending
/// through a non-add binary expression whose other operand is a negative
/// constant. Add expressions keep the current polarity on both sides.
/// Expressions that are neither `dim` nor binary are returned unchanged.
AffineExpr mlir::affine::substWithMin(AffineExpr e, AffineExpr dim,
                                      AffineExpr min, AffineExpr max,
                                      bool positivePath) {
  if (e == dim)
    return positivePath ? min : max;

  auto bin = dyn_cast<AffineBinaryOpExpr>(e);
  if (!bin)
    return e;

  AffineExpr lhs = bin.getLHS();
  AffineExpr rhs = bin.getRHS();
  auto recurse = [&](AffineExpr sub, bool polarity) {
    return substWithMin(sub, dim, min, max, polarity);
  };

  if (bin.getKind() == AffineExprKind::Add)
    return recurse(lhs, positivePath) + recurse(rhs, positivePath);

  auto lhsCst = dyn_cast<AffineConstantExpr>(lhs);
  if (lhsCst && lhsCst.getValue() < 0)
    return getAffineBinaryOpExpr(bin.getKind(), lhsCst,
                                 recurse(rhs, !positivePath));

  auto rhsCst = dyn_cast<AffineConstantExpr>(rhs);
  if (rhsCst && rhsCst.getValue() < 0)
    return getAffineBinaryOpExpr(bin.getKind(), recurse(lhs, !positivePath),
                                 rhsCst);

  return getAffineBinaryOpExpr(bin.getKind(), recurse(lhs, positivePath),
                               recurse(rhs, positivePath));
}
/// Rewrites `op` so that every dimension has a zero lower bound and unit step.
///
/// Ops with min/max bounds, and ops already in that form, are left unchanged.
/// For each dimension:
///  - the new upper bound is `ceildiv(ub - lb, step)`;
///  - every use of the induction variable is replaced by an affine.apply,
///    created at the start of the body, that computes `lb + iv * step`.
void mlir::affine::normalizeAffineParallel(AffineParallelOp op) {
  // Min/max bounds are not supported.
  if (op.hasMinMaxBounds())
    return;

  AffineMap lbMap = op.getLowerBoundsMap();
  SmallVector<int64_t, 8> steps = op.getSteps();
  auto isNormalizedDim = [](auto stepAndLb) {
    auto [step, lbResult] = stepAndLb;
    auto lbConst = dyn_cast<AffineConstantExpr>(lbResult);
    return lbConst && lbConst.getValue() == 0 && step == 1;
  };
  if (llvm::all_of(llvm::zip(steps, lbMap.getResults()), isNormalizedDim))
    return;

  // Per-dimension trip extents: ub - lb.
  AffineValueMap ranges;
  AffineValueMap::difference(op.getUpperBoundsValueMap(),
                             op.getLowerBoundsValueMap(), &ranges);

  auto builder = OpBuilder::atBlockBegin(op.getBody());
  AffineExpr zero = builder.getAffineConstantExpr(0);

  // The lower-bound operands are the same for every dimension. The apply map
  // takes them as (lb dims..., iv, lb symbols...).
  unsigned numLbDims = lbMap.getNumDims();
  unsigned numLbSyms = lbMap.getNumSymbols();
  OperandRange lbOperands = op.getLowerBoundsOperands();
  OperandRange lbDimOperands = lbOperands.take_front(numLbDims);
  OperandRange lbSymOperands = lbOperands.drop_front(numLbDims);
  AffineExpr ivExpr = builder.getAffineDimExpr(numLbDims);

  SmallVector<AffineExpr, 8> lbExprs;
  SmallVector<AffineExpr, 8> ubExprs;
  for (unsigned dimIdx = 0, numSteps = steps.size(); dimIdx < numSteps;
       ++dimIdx) {
    int64_t step = steps[dimIdx];
    lbExprs.push_back(zero);
    ubExprs.push_back(ranges.getResult(dimIdx).ceilDiv(step));

    BlockArgument iv = op.getBody()->getArgument(dimIdx);
    AffineMap ivMap = AffineMap::get(numLbDims + 1, numLbSyms,
                                     lbMap.getResult(dimIdx) + ivExpr * step);
    SmallVector<Value, 8> applyOperands(lbDimOperands.begin(),
                                        lbDimOperands.end());
    applyOperands.push_back(iv);
    applyOperands.append(lbSymOperands.begin(), lbSymOperands.end());
    auto apply = builder.create<AffineApplyOp>(op.getLoc(), ivMap,
                                               applyOperands);
    iv.replaceAllUsesExcept(apply, apply);
  }

  SmallVector<int64_t, 8> unitSteps(op.getNumDims(), 1);
  op.setSteps(unitSteps);
  MLIRContext *ctx = op.getContext();
  op.setLowerBounds({}, AffineMap::get(0, 0, lbExprs, ctx));
  op.setUpperBounds(ranges.getOperands(),
                    AffineMap::get(ranges.getNumDims(), ranges.getNumSymbols(),
                                   ubExprs, ctx));
}
/// Normalizes affine.for `op` so that its lower bound is 0 and its step is 1.
///
/// Cases:
///  - If `promoteSingleIter` is set and promoteIfSingleIteration succeeds, the
///    loop is promoted instead and success is returned.
///  - Loops that already have a constant zero lower bound and unit step are
///    left unchanged (success).
///  - Loops whose lower bound map has more than one result are not handled
///    (failure).
/// Otherwise:
///  - the upper bound becomes `ceildiv(ub - lb, step)`, one per upper-bound
///    result;
///  - the lower bound becomes 0 and the step becomes 1;
///  - uses of the induction variable are replaced by an affine.apply,
///    computing `lb + iv * step`, inserted at the start of the body.
LogicalResult mlir::affine::normalizeAffineFor(AffineForOp op,
                                               bool promoteSingleIter) {
  if (promoteSingleIter && succeeded(promoteIfSingleIteration(op)))
    return success();
  // Already normalized.
  if (op.hasConstantLowerBound() && (op.getConstantLowerBound() == 0) &&
      (op.getStep() == 1))
    return success();
  // Max-style lower bounds (several results) are not supported.
  if (op.getLowerBoundMap().getNumResults() != 1)
    return failure();
  Location loc = op.getLoc();
  OpBuilder opBuilder(op);
  int64_t origLoopStep = op.getStepAsInt();
  AffineMap oldLbMap = op.getLowerBoundMap();
  // Repeat the single lower-bound result once per upper-bound result, so that
  // the two maps can be subtracted result-by-result below.
  SmallVector<AffineExpr> lbExprs(op.getUpperBoundMap().getNumResults(),
                                  op.getLowerBoundMap().getResult(0));
  // `lbMap` captures the original lower bound; it is still needed after the
  // lower bound is overwritten below.
  AffineValueMap lbMap(oldLbMap, op.getLowerBoundOperands());
  AffineMap paddedLbMap =
      AffineMap::get(oldLbMap.getNumDims(), oldLbMap.getNumSymbols(), lbExprs,
                     op.getContext());
  AffineValueMap paddedLbValueMap(paddedLbMap, op.getLowerBoundOperands());
  AffineValueMap ubValueMap(op.getUpperBoundMap(), op.getUpperBoundOperands());
  AffineValueMap newUbValueMap;
  // newUb = ub - lb (per result).
  AffineValueMap::difference(ubValueMap, paddedLbValueMap, &newUbValueMap);
  (void)newUbValueMap.canonicalize();
  // Build (d0, ..., dn) -> (d0 ceildiv step, ..., dn ceildiv step) and compose
  // it with the difference to scale the new upper bound down by the step.
  unsigned numResult = newUbValueMap.getNumResults();
  SmallVector<AffineExpr> scaleDownExprs(numResult);
  for (unsigned i = 0; i < numResult; ++i)
    scaleDownExprs[i] = opBuilder.getAffineDimExpr(i).ceilDiv(origLoopStep);
  AffineMap scaleDownMap =
      AffineMap::get(numResult, 0, scaleDownExprs, op.getContext());
  AffineMap newUbMap = scaleDownMap.compose(newUbValueMap.getAffineMap());
  op.setUpperBound(newUbValueMap.getOperands(), newUbMap);
  op.setLowerBound({}, opBuilder.getConstantAffineMap(0));
  op.setStep(1);
  // Recover the old IV at the top of the body:
  // oldIv = lb - (-newIv * step) = lb + newIv * step.
  opBuilder.setInsertionPointToStart(op.getBody());
  AffineMap scaleIvMap =
      AffineMap::get(1, 0, -opBuilder.getAffineDimExpr(0) * origLoopStep);
  AffineValueMap scaleIvValueMap(scaleIvMap, ValueRange{op.getInductionVar()});
  AffineValueMap newIvToOldIvMap;
  AffineValueMap::difference(lbMap, scaleIvValueMap, &newIvToOldIvMap);
  (void)newIvToOldIvMap.canonicalize();
  auto newIV = opBuilder.create<AffineApplyOp>(
      loc, newIvToOldIvMap.getAffineMap(), newIvToOldIvMap.getOperands());
  // Every use of the IV except the apply itself now sees the old IV value.
  op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0), newIV);
  return success();
}
static bool mustReachAtInnermost(const MemRefAccess &srcAccess,
const MemRefAccess &destAccess) {
if (getAffineScope(srcAccess.opInst) != getAffineScope(destAccess.opInst))
return false;
unsigned nsLoops =
getNumCommonSurroundingLoops(*srcAccess.opInst, *destAccess.opInst);
DependenceResult result =
checkMemrefAccessDependence(srcAccess, destAccess, nsLoops + 1);
return hasDependence(result);
}
/// Returns true if `srcMemOp` may affect the location accessed by `destMemOp`.
///
/// Affine dependence analysis is used only when both ops access the same
/// memref from the same affine scope. In that case dependences are checked at
/// every loop depth from (common loops + 1) down to (`minSurroundingLoops` +
/// 1), and anything other than a proven absence of dependence counts as an
/// effect. In every other case the answer is conservatively true.
static bool mayHaveEffect(Operation *srcMemOp, Operation *destMemOp,
                          unsigned minSurroundingLoops) {
  MemRefAccess srcAccess(srcMemOp);
  MemRefAccess destAccess(destMemOp);

  // Without a shared memref and scope there is nothing to analyze.
  if (srcAccess.memref != destAccess.memref ||
      getAffineScope(srcMemOp) != getAffineScope(destMemOp))
    return true;

  unsigned numCommonLoops = getNumCommonSurroundingLoops(*srcMemOp, *destMemOp);
  FlatAffineValueConstraints dependenceConstraints;
  for (unsigned depth = numCommonLoops + 1; depth > minSurroundingLoops;
       --depth) {
    DependenceResult result = checkMemrefAccessDependence(
        srcAccess, destAccess, depth, &dependenceConstraints,
        /*dependenceComponents=*/nullptr);
    if (!noDependence(result))
      return true;
  }
  return false;
}
/// Returns true if no operation that may execute between `start` and `memOp`
/// can have an `EffectType` memory effect on `memOp`'s memref.
///
/// Each visited operation is classified as follows:
///  - Ops implementing MemoryEffectOpInterface are inspected for effects of
///    type `EffectType`. An effect is ignored only when it is attached to a
///    value that is different from the memref and that `mayAlias` reports as
///    non-aliasing. Affine reads/writes reporting such an effect are refined
///    with dependence analysis (mayHaveEffect); any other such op counts as
///    having the effect.
///  - Ops with recursive memory effects are inspected through their nested
///    ops.
///  - Any other op is conservatively assumed to have the effect.
///
/// When `memOp` is nested deeper than `start`, each enclosing ancestor of
/// `memOp` that lies on the way is checked in its entirety rather than only up
/// to `memOp`. This is conservative.
template <typename EffectType, typename T>
bool mlir::affine::hasNoInterveningEffect(
    Operation *start, T memOp,
    llvm::function_ref<bool(Value, Value)> mayAlias) {
  // Set once any examined operation may have the effect; never reset.
  bool hasSideEffect = false;
  Value memref = memOp.getMemRef();

  // Examines `op` and sets `hasSideEffect` if it may have an `EffectType`
  // effect on `memref`.
  std::function<void(Operation *)> checkOperation = [&](Operation *op) {
    // An effect was already found; nothing more to learn.
    if (hasSideEffect)
      return;

    if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(op)) {
      SmallVector<MemoryEffects::EffectInstance, 1> effects;
      memEffect.getEffects(effects);

      // Look for an `EffectType` effect that may touch `memref`.
      bool opMayHaveEffect = false;
      for (const auto &effect : effects) {
        if (isa<EffectType>(effect.getEffect())) {
          if (effect.getValue() && effect.getValue() != memref &&
              !mayAlias(effect.getValue(), memref))
            continue;
          opMayHaveEffect = true;
          break;
        }
      }

      if (!opMayHaveEffect)
        return;

      // For affine reads/writes, refine with dependence analysis. Only loop
      // depths deeper than the loops common to `start` and `memOp` are checked.
      if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
        unsigned minSurroundingLoops =
            getNumCommonSurroundingLoops(*start, *memOp);
        if (mayHaveEffect(op, memOp, minSurroundingLoops))
          hasSideEffect = true;
        return;
      }

      // Any other op that reports the effect is assumed to have it.
      hasSideEffect = true;
      return;
    }

    if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) {
      // The op's effects are those of its nested ops: inspect all of them.
      for (Region &region : op->getRegions())
        for (Block &block : region)
          for (Operation &nestedOp : block)
            checkOperation(&nestedOp);
      return;
    }

    // Nothing is known about this op's effects: be conservative.
    hasSideEffect = true;
  };

  // Checks the paths from ancestor `parent` to `to`. Currently the whole
  // `parent` op is checked, which is conservatively correct.
  auto until = [&](Operation *parent, Operation *to) {
    assert(parent->isAncestor(to));
    checkOperation(parent);
  };

  // Checks every operation that may execute after `from` and before `untilOp`.
  std::function<void(Operation *, Operation *)> recur =
      [&](Operation *from, Operation *untilOp) {
        assert(
            from->getParentRegion()->isAncestor(untilOp->getParentRegion()) &&
            "Checking for side effect between two operations without a common "
            "ancestor");

        // `untilOp` is nested deeper: first reach its parent op, then check
        // that parent op.
        if (from->getParentRegion() != untilOp->getParentRegion()) {
          recur(from, untilOp->getParentOp());
          until(untilOp->getParentOp(), untilOp);
          return;
        }

        // Same region: check the rest of `from`'s block, then traverse the CFG.
        SmallVector<Block *, 2> todoBlocks;
        {
          for (auto iter = ++from->getIterator(), end = from->getBlock()->end();
               iter != end && &*iter != untilOp; ++iter) {
            checkOperation(&*iter);
          }

          if (untilOp->getBlock() != from->getBlock())
            for (Block *succ : from->getBlock()->getSuccessors())
              todoBlocks.push_back(succ);
        }

        // Visit reachable blocks (each once) until `untilOp` is reached.
        SmallPtrSet<Block *, 4> done;
        while (!todoBlocks.empty()) {
          Block *blk = todoBlocks.pop_back_val();
          if (done.count(blk))
            continue;
          done.insert(blk);
          for (auto &op : *blk) {
            if (&op == untilOp)
              break;
            checkOperation(&op);
            if (&op == blk->getTerminator())
              for (Block *succ : blk->getSuccessors())
                todoBlocks.push_back(succ);
          }
        }
      };
  recur(start, memOp);
  return !hasSideEffect;
}
/// Tries to forward a stored value directly to `loadOp`.
///
/// A store qualifies as the source when all of the following hold:
///  - it accesses exactly the same element as `loadOp`;
///  - it dominates `loadOp`;
///  - when in a different block, mustReachAtInnermost confirms it reaches the
///    load;
///  - no intervening write may clobber the location.
/// If such a store is found and its stored value has the load's result type:
///  - all uses of the load are replaced with the stored value;
///  - the memref is recorded in `memrefsToErase`;
///  - `loadOp` is queued in `loadOpsToErase`.
/// No operation is erased here.
static void forwardStoreToLoad(
    AffineReadOpInterface loadOp, SmallVectorImpl<Operation *> &loadOpsToErase,
    SmallPtrSetImpl<Value> &memrefsToErase, DominanceInfo &domInfo,
    llvm::function_ref<bool(Value, Value)> mayAlias) {
  // Returns true if `storeOp` satisfies every forwarding requirement. The
  // checks run in order from cheapest to most expensive.
  auto canForwardFrom = [&](AffineWriteOpInterface storeOp) {
    MemRefAccess storeAccess(storeOp);
    MemRefAccess loadAccess(loadOp);
    if (storeAccess != loadAccess)
      return false;
    if (!domInfo.dominates(storeOp, loadOp))
      return false;
    if (storeOp->getBlock() != loadOp->getBlock() &&
        !mustReachAtInnermost(storeAccess, loadAccess))
      return false;
    return affine::hasNoInterveningEffect<MemoryEffects::Write>(storeOp, loadOp,
                                                                mayAlias);
  };

  Operation *forwardingStore = nullptr;
  for (Operation *user : loadOp.getMemRef().getUsers()) {
    auto storeOp = dyn_cast<AffineWriteOpInterface>(user);
    if (!storeOp || !canForwardFrom(storeOp))
      continue;
    assert(forwardingStore == nullptr &&
           "multiple simultaneous replacement stores");
    forwardingStore = storeOp;
  }
  if (!forwardingStore)
    return;

  Value storedValue =
      cast<AffineWriteOpInterface>(forwardingStore).getValueToStore();
  // The stored value's type may differ from the load's result type
  // (presumably for vector accesses); only forward when they match.
  if (storedValue.getType() != loadOp.getValue().getType())
    return;
  loadOp.getValue().replaceAllUsesWith(storedValue);
  memrefsToErase.insert(loadOp.getMemRef());
  loadOpsToErase.push_back(loadOp);
}
// Explicit instantiation of hasNoInterveningEffect for read effects observed
// from an AffineReadOpInterface.
// NOTE(review): no use of this exact specialization is visible in this chunk;
// presumably it is needed by callers in other translation units — confirm.
template bool
mlir::affine::hasNoInterveningEffect<mlir::MemoryEffects::Read,
                                     affine::AffineReadOpInterface>(
    mlir::Operation *, affine::AffineReadOpInterface,
    llvm::function_ref<bool(Value, Value)>);
/// Queues `writeA` in `opsToErase` if it is dead.
///
/// `writeA` is considered dead when another write (`writeB`) meets all of the
/// following:
///  - it is in the same region;
///  - it writes exactly the same element;
///  - it post-dominates `writeA`;
///  - no read in between may observe the value written by `writeA`.
/// No operation is erased here.
static void findUnusedStore(AffineWriteOpInterface writeA,
                            SmallVectorImpl<Operation *> &opsToErase,
                            PostDominanceInfo &postDominanceInfo,
                            llvm::function_ref<bool(Value, Value)> mayAlias) {
  // Returns true if `writeB` makes `writeA` unobservable.
  auto overwrites = [&](AffineWriteOpInterface writeB) {
    if (writeB == writeA ||
        writeB->getParentRegion() != writeA->getParentRegion())
      return false;
    MemRefAccess laterAccess(writeB);
    MemRefAccess earlierAccess(writeA);
    if (laterAccess != earlierAccess)
      return false;
    if (!postDominanceInfo.postDominates(writeB, writeA))
      return false;
    return affine::hasNoInterveningEffect<MemoryEffects::Read>(writeA, writeB,
                                                               mayAlias);
  };

  for (Operation *user : writeA.getMemRef().getUsers()) {
    auto writeB = dyn_cast<AffineWriteOpInterface>(user);
    if (writeB && overwrites(writeB)) {
      opsToErase.push_back(writeA);
      return;
    }
  }
}
/// Replaces `loadA` with an earlier equivalent load, if one exists.
///
/// A candidate is a load that meets all of the following:
///  - it reads exactly the same element as `loadA`;
///  - it dominates `loadA`;
///  - no write in between may clobber the location;
///  - it produces the same type.
/// Among the candidates, the first one (in use-list order) that dominates all
/// the others is chosen. `loadA`'s uses are redirected to it and `loadA` is
/// queued in `loadOpsToErase`.
static void loadCSE(AffineReadOpInterface loadA,
                    SmallVectorImpl<Operation *> &loadOpsToErase,
                    DominanceInfo &domInfo,
                    llvm::function_ref<bool(Value, Value)> mayAlias) {
  SmallVector<AffineReadOpInterface, 4> loadCandidates;
  for (Operation *user : loadA.getMemRef().getUsers()) {
    auto loadB = dyn_cast<AffineReadOpInterface>(user);
    if (!loadB || loadB == loadA)
      continue;
    MemRefAccess candidateAccess(loadB);
    MemRefAccess targetAccess(loadA);
    if (candidateAccess != targetAccess)
      continue;
    if (!domInfo.dominates(loadB, loadA))
      continue;
    if (!affine::hasNoInterveningEffect<MemoryEffects::Write>(
            loadB.getOperation(), loadA, mayAlias))
      continue;
    if (loadB.getValue().getType() != loadA.getValue().getType())
      continue;
    loadCandidates.push_back(loadB);
  }

  // Pick a candidate that dominates every other candidate.
  auto dominatesOthers = [&](AffineReadOpInterface option) {
    return llvm::all_of(loadCandidates, [&](AffineReadOpInterface other) {
      return other == option ||
             domInfo.dominates(option.getOperation(), other.getOperation());
    });
  };
  auto *chosen = llvm::find_if(loadCandidates, dominatesOthers);
  if (chosen == loadCandidates.end())
    return;

  loadA.getValue().replaceAllUsesWith(chosen->getValue());
  loadOpsToErase.push_back(loadA);
}
/// Runs scalar replacement on the affine memory accesses in `f`, in four
/// phases:
///  1. store-to-load forwarding;
///  2. removal of stores that are later overwritten without being read;
///  3. removal of locally allocated memrefs (from phase 1) that are left with
///     only writes and deallocations;
///  4. replacement of loads with earlier equivalent loads.
/// Queued ops are erased after each walk, so nothing is erased while a walk is
/// in progress.
void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
                                       PostDominanceInfo &postDomInfo,
                                       AliasAnalysis &aliasAnalysis) {
  SmallVector<Operation *, 8> opsToErase;
  SmallPtrSet<Value, 4> memrefsToErase;

  auto mayAlias = [&](Value lhs, Value rhs) -> bool {
    return !aliasAnalysis.alias(lhs, rhs).isNo();
  };
  // Erases every op queued so far and resets the queue.
  auto eraseQueuedOps = [&]() {
    for (Operation *op : opsToErase)
      op->erase();
    opsToErase.clear();
  };

  // Phase 1: store-to-load forwarding.
  f.walk([&](AffineReadOpInterface loadOp) {
    forwardStoreToLoad(loadOp, opsToErase, memrefsToErase, domInfo, mayAlias);
  });
  eraseQueuedOps();

  // Phase 2: stores overwritten before being read.
  f.walk([&](AffineWriteOpInterface storeOp) {
    findUnusedStore(storeOp, opsToErase, postDomInfo, mayAlias);
  });
  eraseQueuedOps();

  // Phase 3: local allocations that are only written to and freed.
  for (Value memref : memrefsToErase) {
    Operation *defOp = memref.getDefiningOp();
    if (!defOp || !hasSingleEffect<MemoryEffects::Allocate>(defOp, memref))
      continue;
    bool onlyWritesAndFrees =
        llvm::all_of(memref.getUsers(), [&](Operation *user) {
          return isa<AffineWriteOpInterface>(user) ||
                 hasSingleEffect<MemoryEffects::Free>(user, memref);
        });
    if (!onlyWritesAndFrees)
      continue;
    for (Operation *user : llvm::make_early_inc_range(memref.getUsers()))
      user->erase();
    defOp->erase();
  }

  // Phase 4: redundant loads.
  f.walk([&](AffineReadOpInterface loadOp) {
    loadCSE(loadOp, opsToErase, domInfo, mayAlias);
  });
  eraseQueuedOps();
}
/// Replaces the use of `oldMemRef` in `op` with `newMemRef`.
///
/// If `op` implements AffineMapAccessInterface, its access is rewritten:
///  1. the old indices are computed as `oldMap(oldMapOperands)`;
///  2. they are optionally remapped through `indexRemap`, whose inputs are
///     `extraOperands`, then the old indices, then `symbolOperands`;
///  3. the result is prefixed with `extraIndices`;
///  4. it is composed and simplified into a new access map.
/// `op` is then recreated with the new memref operand, indices and map
/// attribute (its other operands, result types and attributes are kept), and
/// the original `op` is erased.
///
/// Ops that do not implement the interface only have their memref operand
/// swapped, and only when `allowNonDereferencingOps` is set; otherwise failure
/// is returned.
///
/// Returns success without changes if `op` does not use `oldMemRef`. More than
/// one use of `oldMemRef` in `op` is not supported.
LogicalResult mlir::affine::replaceAllMemRefUsesWith(
    Value oldMemRef, Value newMemRef, Operation *op,
    ArrayRef<Value> extraIndices, AffineMap indexRemap,
    ArrayRef<Value> extraOperands, ArrayRef<Value> symbolOperands,
    bool allowNonDereferencingOps) {
  // The ranks are only used in assertions.
  unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank();
  (void)newMemRefRank;
  unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank();
  (void)oldMemRefRank;
  if (indexRemap) {
    assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
           "symbolic operand count mismatch");
    assert(indexRemap.getNumInputs() ==
           extraOperands.size() + oldMemRefRank + symbolOperands.size());
    assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
  } else {
    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
  }
  // Both memrefs must hold the same element type.
  assert(cast<MemRefType>(oldMemRef.getType()).getElementType() ==
         cast<MemRefType>(newMemRef.getType()).getElementType());
  // Find the operand position(s) where `oldMemRef` is used.
  SmallVector<unsigned, 2> usePositions;
  for (const auto &opEntry : llvm::enumerate(op->getOperands())) {
    if (opEntry.value() == oldMemRef)
      usePositions.push_back(opEntry.index());
  }
  // `oldMemRef` is not used by `op`: nothing to do.
  if (usePositions.empty())
    return success();
  if (usePositions.size() > 1) {
    assert(false && "multiple dereferencing uses in a single op not supported");
    return failure();
  }
  unsigned memRefOperandPos = usePositions.front();
  OpBuilder builder(op);
  auto affMapAccInterface = dyn_cast<AffineMapAccessInterface>(op);
  if (!affMapAccInterface) {
    // Non-dereferencing use (the memref may escape): only swap the operand,
    // and only when explicitly allowed.
    if (!allowNonDereferencingOps) {
      return failure();
    }
    op->setOperand(memRefOperandPos, newMemRef);
    return success();
  }
  // The access indices follow the memref operand and are the inputs of the
  // access map.
  NamedAttribute oldMapAttrPair =
      affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
  AffineMap oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue();
  unsigned oldMapNumInputs = oldMap.getNumInputs();
  SmallVector<Value, 4> oldMapOperands(
      op->operand_begin() + memRefOperandPos + 1,
      op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
  // oldMemRefOperands = oldMap(oldMapOperands), with one single-result
  // affine.apply per result unless the map is the identity. All created
  // applies are tracked so that unused ones can be erased later.
  SmallVector<Value, 4> oldMemRefOperands;
  SmallVector<Value, 4> affineApplyOps;
  oldMemRefOperands.reserve(oldMemRefRank);
  if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
    for (auto resultExpr : oldMap.getResults()) {
      auto singleResMap = AffineMap::get(oldMap.getNumDims(),
                                         oldMap.getNumSymbols(), resultExpr);
      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
                                                oldMapOperands);
      oldMemRefOperands.push_back(afOp);
      affineApplyOps.push_back(afOp);
    }
  } else {
    oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end());
  }
  // Inputs of `indexRemap`: extra operands, old indices, then symbols.
  SmallVector<Value, 4> remapOperands;
  remapOperands.reserve(extraOperands.size() + oldMemRefRank +
                        symbolOperands.size());
  remapOperands.append(extraOperands.begin(), extraOperands.end());
  remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());
  remapOperands.append(symbolOperands.begin(), symbolOperands.end());
  SmallVector<Value, 4> remapOutputs;
  remapOutputs.reserve(oldMemRefRank);
  if (indexRemap &&
      indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
    // Apply the remapping, one single-result affine.apply per result.
    for (auto resultExpr : indexRemap.getResults()) {
      auto singleResMap = AffineMap::get(
          indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
                                                remapOperands);
      remapOutputs.push_back(afOp);
      affineApplyOps.push_back(afOp);
    }
  } else {
    // No (or identity) remapping: pass the operands through.
    remapOutputs.assign(remapOperands.begin(), remapOperands.end());
  }
  // New indices: extra indices first, then the remapped ones.
  SmallVector<Value, 4> newMapOperands;
  newMapOperands.reserve(newMemRefRank);
  for (Value extraIndex : extraIndices) {
    assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
           "invalid memory op index");
    newMapOperands.push_back(extraIndex);
  }
  newMapOperands.append(remapOutputs.begin(), remapOutputs.end());
  assert(newMapOperands.size() == newMemRefRank);
  // Fold the affine.apply chains into a single access map and simplify it.
  auto newMap = builder.getMultiDimIdentityMap(newMemRefRank);
  fullyComposeAffineMapAndOperands(&newMap, &newMapOperands);
  newMap = simplifyAffineMap(newMap);
  canonicalizeMapAndOperands(&newMap, &newMapOperands);
  // Remove the temporary applies that the composition made unnecessary.
  for (Value value : affineApplyOps)
    if (value.use_empty())
      value.getDefiningOp()->erase();
  // Recreate `op` with: operands before the memref, the new memref, the new
  // indices, and the operands after the old indices.
  OperationState state(op->getLoc(), op->getName());
  state.operands.reserve(op->getNumOperands() + extraIndices.size());
  state.operands.append(op->operand_begin(),
                        op->operand_begin() + memRefOperandPos);
  state.operands.push_back(newMemRef);
  state.operands.append(newMapOperands.begin(), newMapOperands.end());
  state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
                            oldMapNumInputs,
                        op->operand_end());
  state.types.reserve(op->getNumResults());
  for (auto result : op->getResults())
    state.types.push_back(result.getType());
  // Copy all attributes, substituting the new access map for the old one.
  auto newMapAttr = AffineMapAttr::get(newMap);
  for (auto namedAttr : op->getAttrs()) {
    if (namedAttr.getName() == oldMapAttrPair.getName())
      state.attributes.push_back({namedAttr.getName(), newMapAttr});
    else
      state.attributes.push_back(namedAttr);
  }
  auto *repOp = builder.create(state);
  op->replaceAllUsesWith(repOp);
  op->erase();
  return success();
}
/// Replaces the uses of `oldMemRef` with `newMemRef` in every selected user,
/// rewriting affine access indices through `extraIndices`, `indexRemap`,
/// `extraOperands` and `symbolOperands`.
///
/// Users are selected as follows:
///  - When `domOpFilter` is non-null, only users it dominates are considered.
///  - When `postDomOpFilter` is non-null, only users it post-dominates are
///    considered.
///  - Users with a single Free effect on `oldMemRef` are skipped unless
///    `replaceInDeallocOp` is set.
///
/// All selected users are validated before any IR is modified. Failure is
/// returned, with the IR unchanged, if a user does not implement
/// AffineMapAccessInterface and either `allowNonDereferencingOps` is false or
/// the user lacks the MemRefsNormalizable trait.
LogicalResult mlir::affine::replaceAllMemRefUsesWith(
    Value oldMemRef, Value newMemRef, ArrayRef<Value> extraIndices,
    AffineMap indexRemap, ArrayRef<Value> extraOperands,
    ArrayRef<Value> symbolOperands, Operation *domOpFilter,
    Operation *postDomOpFilter, bool allowNonDereferencingOps,
    bool replaceInDeallocOp) {
  // The ranks are only used in assertions.
  unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank();
  (void)newMemRefRank;
  unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank();
  (void)oldMemRefRank;
  if (indexRemap) {
    assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
           "symbol operand count mismatch");
    assert(indexRemap.getNumInputs() ==
           extraOperands.size() + oldMemRefRank + symbolOperands.size());
    assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
  } else {
    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
  }
  // Both memrefs must hold the same element type.
  assert(cast<MemRefType>(oldMemRef.getType()).getElementType() ==
         cast<MemRefType>(newMemRef.getType()).getElementType());

  // (Post)dominance information is only computed when the corresponding
  // filter is requested, over the func.func enclosing the filter op.
  std::unique_ptr<DominanceInfo> domInfo;
  std::unique_ptr<PostDominanceInfo> postDomInfo;
  if (domOpFilter)
    domInfo = std::make_unique<DominanceInfo>(
        domOpFilter->getParentOfType<func::FuncOp>());
  if (postDomOpFilter)
    postDomInfo = std::make_unique<PostDominanceInfo>(
        postDomOpFilter->getParentOfType<func::FuncOp>());

  // Collect and validate the users first so that a failure leaves the IR
  // untouched. getUsers() yields an op once per use, hence the set.
  DenseSet<Operation *> opsToReplace;
  for (auto *op : oldMemRef.getUsers()) {
    if (domOpFilter && !domInfo->dominates(domOpFilter, op))
      continue;
    if (postDomOpFilter && !postDomInfo->postDominates(postDomOpFilter, op))
      continue;
    // Test the flag first to skip the effect query when deallocs are wanted.
    if (!replaceInDeallocOp &&
        hasSingleEffect<MemoryEffects::Free>(op, oldMemRef))
      continue;
    if (!isa<AffineMapAccessInterface>(*op)) {
      if (!allowNonDereferencingOps) {
        LLVM_DEBUG(llvm::dbgs()
                   << "Memref replacement failed: non-dereferencing memref op: \n"
                   << *op << '\n');
        return failure();
      }
      if (!op->hasTrait<OpTrait::MemRefsNormalizable>()) {
        LLVM_DEBUG(llvm::dbgs() << "Memref replacement failed: use without a "
                                   "memrefs normalizable trait: \n"
                                << *op << '\n');
        return failure();
      }
    }
    opsToReplace.insert(op);
  }

  // Every collected op passed validation above, so replacement cannot fail.
  for (auto *op : opsToReplace) {
    if (failed(replaceAllMemRefUsesWith(
            oldMemRef, newMemRef, op, extraIndices, indexRemap, extraOperands,
            symbolOperands, allowNonDereferencingOps)))
      llvm_unreachable("memref replacement guaranteed to succeed here");
  }
  return success();
}
/// Gives `opInst` a private copy of the affine computation feeding it.
///
/// Each operand of `opInst` that is produced by an affine.apply is replaced
/// with the result of a new single-result affine.apply. The new op is created
/// right before `opInst`, and its map is the full composition of the reachable
/// affine.apply chain. The new ops are appended to `*sliceOps`; entries the
/// caller had already placed there are preserved and are never used as
/// replacements.
///
/// Nothing is done if `opInst` has no such operands, or if every reachable
/// affine.apply is used only by `opInst`.
void mlir::affine::createAffineComputationSlice(
    Operation *opInst, SmallVectorImpl<AffineApplyOp> *sliceOps) {
  // Collect the operands that are results of affine.apply ops.
  SmallVector<Value, 4> subOperands;
  subOperands.reserve(opInst->getNumOperands());
  for (auto operand : opInst->getOperands())
    if (isa_and_nonnull<AffineApplyOp>(operand.getDefiningOp()))
      subOperands.push_back(operand);

  // Gather the affine.apply ops reachable from them.
  SmallVector<Operation *, 4> affineApplyOps;
  getReachableAffineApplyOps(subOperands, affineApplyOps);
  if (affineApplyOps.empty())
    return;

  // Nothing to do if the computation is already private to `opInst`.
  bool localized = llvm::all_of(affineApplyOps, [&](Operation *applyOp) {
    return llvm::all_of(applyOp->getUsers(),
                        [&](Operation *user) { return user == opInst; });
  });
  if (localized)
    return;

  OpBuilder builder(opInst);
  SmallVector<Value, 4> composedOpOperands(subOperands);
  auto composedMap = builder.getMultiDimIdentityMap(composedOpOperands.size());
  fullyComposeAffineMapAndOperands(&composedMap, &composedOpOperands);

  // `sliceOps` may already hold caller entries: result `j` of the composed
  // map is materialized as `(*sliceOps)[sliceBase + j]`.
  size_t sliceBase = sliceOps->size();
  sliceOps->reserve(sliceBase + composedMap.getNumResults());
  for (auto resultExpr : composedMap.getResults()) {
    auto singleResMap = AffineMap::get(composedMap.getNumDims(),
                                       composedMap.getNumSymbols(), resultExpr);
    sliceOps->push_back(builder.create<AffineApplyOp>(
        opInst->getLoc(), singleResMap, composedOpOperands));
  }

  // Redirect each collected operand to its slice result (the first match wins
  // when an operand appears more than once).
  SmallVector<Value, 4> newOperands(opInst->getOperands());
  for (Value &operand : newOperands) {
    auto *it = llvm::find(subOperands, operand);
    if (it != subOperands.end())
      operand = (*sliceOps)[sliceBase + (it - subOperands.begin())];
  }
  for (unsigned idx = 0, e = newOperands.size(); idx < e; idx++)
    opInst->setOperand(idx, newOperands[idx]);
}
/// Role of a layout-map result with respect to a tiled layout.
enum TileExprPattern {
  TileFloorDiv, ///< `x floordiv tile` part of a tiling pair.
  TileMod,      ///< `x mod tile` part of a tiling pair.
  TileNone      ///< Not part of a tiling pair.
};
/// Detects a "tiled" layout in `map`. For each tiled input expression, it
/// records the tile size and the result positions of its `floordiv` and
/// `mod` parts.
///
/// For every result of the form `x floordiv C` (C a constant), the other
/// results are scanned for subexpressions equal to `x`. The only accepted
/// occurrence is a result that is exactly `x mod C` (same LHS and RHS). The
/// first such result appends (C, floordiv position, mod position) to
/// `tileSizePos`.
///
/// Example: (d0, d1) -> (d0 floordiv 128, d1 floordiv 256, d0 mod 128,
/// d1 mod 256) appends (128, 0, 2) and (256, 1, 3).
///
/// Any other occurrence of `x` marks the map as not tiled, and `tileSizePos`
/// is reset to empty. Such occurrences are a different RHS, a result that is
/// not a `mod`, or a second matching `mod`. `tileSizePos` is also reset when
/// no constant `floordiv` exists. A `floordiv` with no matching `mod`
/// contributes nothing.
///
/// Always returns success(); callers check whether `tileSizePos` is empty.
static LogicalResult getTileSizePos(
AffineMap map,
SmallVectorImpl<std::tuple<AffineExpr, unsigned, unsigned>> &tileSizePos) {
// (LHS, RHS, result position) of each `floordiv` result with a constant RHS.
SmallVector<std::tuple<AffineExpr, AffineExpr, unsigned>, 4> floordivExprs;
unsigned pos = 0;
for (AffineExpr expr : map.getResults()) {
if (expr.getKind() == AffineExprKind::FloorDiv) {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
if (isa<AffineConstantExpr>(binaryExpr.getRHS()))
floordivExprs.emplace_back(
std::make_tuple(binaryExpr.getLHS(), binaryExpr.getRHS(), pos));
}
pos++;
}
// No constant floordiv at all: not a tiled layout.
if (floordivExprs.empty()) {
tileSizePos = SmallVector<std::tuple<AffineExpr, unsigned, unsigned>>{};
return success();
}
for (std::tuple<AffineExpr, AffineExpr, unsigned> fexpr : floordivExprs) {
AffineExpr floordivExprLHS = std::get<0>(fexpr);
AffineExpr floordivExprRHS = std::get<1>(fexpr);
unsigned floordivPos = std::get<2>(fexpr);
// Whether the matching `mod` for this floordiv has already been recorded.
bool found = false;
pos = 0;
// Scan every result other than the floordiv itself for uses of its LHS.
for (AffineExpr expr : map.getResults()) {
bool notTiled = false;
if (pos != floordivPos) {
expr.walk([&](AffineExpr e) {
if (e == floordivExprLHS) {
// The LHS may only appear as the LHS of a top-level `mod` result.
if (expr.getKind() == AffineExprKind::Mod) {
AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(expr);
if (floordivExprLHS == binaryExpr.getLHS() &&
floordivExprRHS == binaryExpr.getRHS()) {
// First matching mod: record it. A second one makes the map non-tiled.
if (!found) {
tileSizePos.emplace_back(
std::make_tuple(binaryExpr.getRHS(), floordivPos, pos));
found = true;
} else {
notTiled = true;
}
} else {
// Same LHS but a different mod operand (e.g. different tile size).
notTiled = true;
}
} else {
// Same LHS used by something other than a mod.
notTiled = true;
}
}
});
}
// Any violation discards everything collected so far, including entries
// recorded for earlier floordivs.
if (notTiled) {
tileSizePos = SmallVector<std::tuple<AffineExpr, unsigned, unsigned>>{};
return success();
}
pos++;
}
}
return success();
}
static bool
isNormalizedMemRefDynamicDim(unsigned dim, AffineMap layoutMap,
SmallVectorImpl<unsigned> &inMemrefTypeDynDims) {
AffineExpr expr = layoutMap.getResults()[dim];
MLIRContext *context = layoutMap.getContext();
return expr
.walk([&](AffineExpr e) {
if (isa<AffineDimExpr>(e) &&
llvm::any_of(inMemrefTypeDynDims, [&](unsigned dim) {
return e == getAffineDimExpr(dim, context);
}))
return WalkResult::interrupt();
return WalkResult::advance();
})
.wasInterrupted();
}
/// Builds the extent expression for one result of a tiled layout map.
///
/// `oldMapOutput` is the layout-map result. When evaluated on the old
/// memref's extents, the returned expression gives that result dimension's
/// size:
///   - TileMod:      `x mod C`      -> C
///   - TileFloorDiv: `x floordiv C` -> `x ceildiv C`
///   - otherwise:    the expression is returned unchanged.
static AffineExpr createDimSizeExprForTiledLayout(AffineExpr oldMapOutput,
                                                  TileExprPattern pat) {
  const bool isTiled = pat == TileExprPattern::TileMod ||
                       pat == TileExprPattern::TileFloorDiv;
  if (!isTiled)
    return oldMapOutput;

  auto tiledExpr = cast<AffineBinaryOpExpr>(oldMapOutput);
  if (pat == TileExprPattern::TileMod)
    return tiledExpr.getRHS();

  return getAffineBinaryOpExpr(AffineExprKind::CeilDiv, tiledExpr.getLHS(),
                               tiledExpr.getRHS());
}
/// Computes the dynamic size operands for the normalized memref that replaces
/// `allocOp`.
///
/// For each result of the layout `map` whose extent in `newMemRefType` is
/// dynamic, one affine.apply is created at `b`'s insertion point and its
/// value is appended to `newDynamicSizes`. The apply evaluates the extent
/// expression (see createDimSizeExprForTiledLayout) on:
///   - the old memref's extents: the alloc's dynamic size operands, or index
///     constants for static extents;
///   - followed by the alloc's symbol operands, which bind the layout map's
///     symbols.
static void createNewDynamicSizes(MemRefType oldMemRefType,
                                  MemRefType newMemRefType, AffineMap map,
                                  memref::AllocOp *allocOp, OpBuilder b,
                                  SmallVectorImpl<Value> &newDynamicSizes) {
  Location loc = allocOp->getLoc();

  // Dimension operands: one value per old dimension, in order.
  SmallVector<Value, 4> inAffineApply;
  ArrayRef<int64_t> oldMemRefShape = oldMemRefType.getShape();
  unsigned dynIdx = 0;
  for (int64_t extent : oldMemRefShape) {
    if (ShapedType::isDynamic(extent)) {
      inAffineApply.emplace_back(allocOp->getDynamicSizes()[dynIdx++]);
      continue;
    }
    auto constantAttr = b.getIntegerAttr(b.getIndexType(), extent);
    inAffineApply.emplace_back(b.create<arith::ConstantOp>(loc, constantAttr));
  }
  // Symbol operands follow the dimension operands, matching the layout map's
  // symbol list. The list is empty when the layout map has no symbols.
  llvm::append_range(inAffineApply, allocOp->getSymbolOperands());

  SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
  (void)getTileSizePos(map, tileSizePos);

  ArrayRef<int64_t> newMemRefShape = newMemRefType.getShape();
  unsigned newDimIdx = 0;
  for (AffineExpr expr : map.getResults()) {
    if (ShapedType::isDynamic(newMemRefShape[newDimIdx])) {
      // Classify this result. Later tileSizePos entries take precedence.
      TileExprPattern pat = TileExprPattern::TileNone;
      for (const auto &pos : tileSizePos) {
        if (newDimIdx == std::get<1>(pos))
          pat = TileExprPattern::TileFloorDiv;
        else if (newDimIdx == std::get<2>(pos))
          pat = TileExprPattern::TileMod;
      }
      AffineExpr newMapOutput = createDimSizeExprForTiledLayout(expr, pat);
      // Use the layout map's dimension count. getNumInputs() would also count
      // the symbols and would not match the operand list built above.
      AffineMap newMap =
          AffineMap::get(map.getNumDims(), map.getNumSymbols(), newMapOutput);
      Value affineApp = b.create<AffineApplyOp>(loc, newMap, inAffineApply);
      newDynamicSizes.emplace_back(affineApp);
    }
    newDimIdx++;
  }
}
/// Replaces `*allocOp` with an allocation of the normalized memref type
/// (identity layout) and rewrites its uses.
///
/// Accesses are remapped through the old layout map. Non-dereferencing uses
/// are permitted.
///
/// Returns failure, leaving `*allocOp` in place, when:
/// - normalization does not change the type, or
/// - the use replacement fails; the new alloc is then erased.
///
/// On success, the remaining (free-effect) users are redirected to the new
/// alloc and `*allocOp` is erased.
LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
  MemRefType memrefType = allocOp->getType();
  OpBuilder b(*allocOp);

  MemRefType newMemRefType = normalizeMemRefType(memrefType);
  if (newMemRefType == memrefType)
    return failure();

  Value oldMemRef = allocOp->getResult();
  SmallVector<Value, 4> symbolOperands(allocOp->getSymbolOperands());
  AffineMap layoutMap = memrefType.getLayout().getAffineMap();

  // Dynamic sizes are only recomputed for tiled layouts.
  SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
  (void)getTileSizePos(layoutMap, tileSizePos);
  const bool needsNewDynamicSizes =
      newMemRefType.getNumDynamicDims() > 0 && !tileSizePos.empty();

  memref::AllocOp newAlloc;
  if (!needsNewDynamicSizes) {
    newAlloc = b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
                                         allocOp->getAlignmentAttr());
  } else {
    auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
    SmallVector<Value, 4> newDynamicSizes;
    createNewDynamicSizes(oldMemRefType, newMemRefType, layoutMap, allocOp, b,
                          newDynamicSizes);
    newAlloc =
        b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
                                  newDynamicSizes, allocOp->getAlignmentAttr());
  }

  if (failed(replaceAllMemRefUsesWith(oldMemRef, newAlloc,
                                      /*extraIndices=*/{},
                                      /*indexRemap=*/layoutMap,
                                      /*extraOperands=*/{},
                                      /*symbolOperands=*/symbolOperands,
                                      /*domOpFilter=*/nullptr,
                                      /*postDomOpFilter=*/nullptr,
                                      /*allowNonDereferencingOps=*/true))) {
    // NOTE(review): ops created by createNewDynamicSizes are not cleaned up
    // here; only the new alloc is erased.
    newAlloc.erase();
    return failure();
  }

  // Only users with a single Free effect (deallocations) are expected to
  // remain on the old buffer. Point them at the new one.
  assert(llvm::all_of(oldMemRef.getUsers(), [&](Operation *op) {
    return hasSingleEffect<MemoryEffects::Free>(op, oldMemRef);
  }));
  oldMemRef.replaceAllUsesWith(newAlloc);
  allocOp->erase();
  return success();
}
/// Returns the memref type obtained by folding the layout map of
/// `memrefType` into its shape. The result has an identity layout.
///
/// `memrefType` is returned unchanged when:
/// - it is 0-d or already has an identity layout;
/// - it has dynamic dimensions but a non-tiled layout (see getTileSizePos);
/// - composing the layout map into the constraint system fails; or
/// - a non-dynamic result dimension has no known non-negative constant upper
///   bound.
///
/// Result dimension d is dynamic when the d-th layout result references a
/// dynamic input dimension. Otherwise its extent is (constant upper bound + 1).
MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
  unsigned rank = memrefType.getRank();
  if (rank == 0)
    return memrefType;
  if (memrefType.getLayout().isIdentity()) {
    return memrefType;
  }
  AffineMap layoutMap = memrefType.getLayout().getAffineMap();
  unsigned numSymbolicOperands = layoutMap.getNumSymbols();

  // Dynamic memrefs are only normalized when their layout is tiled.
  SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
  (void)getTileSizePos(layoutMap, tileSizePos);
  if (memrefType.getNumDynamicDims() > 0 && tileSizePos.empty())
    return memrefType;

  ArrayRef<int64_t> shape = memrefType.getShape();
  // One variable per old dimension plus the layout map's symbols.
  FlatAffineValueConstraints fac(rank, numSymbolicOperands);
  SmallVector<unsigned, 4> memrefTypeDynDims;
  for (unsigned d = 0; d < rank; ++d) {
    // Constrain positive static extents to 0 <= d <= size - 1. Every other
    // dimension is recorded as dynamic.
    // NOTE(review): a static extent of 0 also takes the "dynamic" branch.
    if (shape[d] > 0) {
      fac.addBound(BoundType::LB, d, 0);
      fac.addBound(BoundType::UB, d, shape[d] - 1);
    } else {
      memrefTypeDynDims.emplace_back(d);
    }
  }

  // Compose the layout map; the code below reads bounds of its results at
  // positions [0, newRank). Then project out every variable after those
  // positions except local variables.
  unsigned newRank = layoutMap.getNumResults();
  if (failed(fac.composeMatchingMap(layoutMap)))
    return memrefType;
  fac.projectOut(newRank, fac.getNumVars() - newRank - fac.getNumLocalVars());

  SmallVector<int64_t, 4> newShape(newRank);
  MLIRContext *context = memrefType.getContext();
  for (unsigned d = 0; d < newRank; ++d) {
    if (isNormalizedMemRefDynamicDim(d, layoutMap, memrefTypeDynDims)) {
      newShape[d] = ShapedType::kDynamic;
      continue;
    }
    // The extent is the constant upper bound + 1. An unknown or negative
    // upper bound cannot be normalized.
    std::optional<int64_t> ubConst = fac.getConstantBound64(BoundType::UB, d);
    if (!ubConst.has_value() || *ubConst < 0) {
      LLVM_DEBUG(llvm::dbgs()
                 << "can't normalize map due to unknown/invalid upper bound\n");
      return memrefType;
    }
    newShape[d] = *ubConst + 1;
  }
  auto newMemRefType =
      MemRefType::Builder(memrefType)
          .setShape(newShape)
          .setLayout(AffineMapAttr::get(
              AffineMap::getMultiDimIdentityMap(newRank, context)));
  return newMemRefType;
}
/// Materializes `lhs floordiv rhs` and `lhs mod rhs` as composed
/// affine.apply ops. The quotient is created first, then the remainder.
DivModValue mlir::affine::getDivMod(OpBuilder &b, Location loc, Value lhs,
                                    Value rhs) {
  AffineExpr dividend, divisor;
  bindDims(b.getContext(), dividend, divisor);
  AffineExpr quotientExpr = dividend.floorDiv(divisor);
  AffineExpr remainderExpr = dividend % divisor;

  DivModValue divMod;
  divMod.quotient =
      affine::makeComposedAffineApply(b, loc, quotientExpr, {lhs, rhs});
  divMod.remainder =
      affine::makeComposedAffineApply(b, loc, remainderExpr, {lhs, rhs});
  return divMod;
}
/// Returns the product of all values in `set`, built as a left-to-right chain
/// of composed/folded `s0 * s1` applications. Fails on an empty `set`.
static FailureOr<OpFoldResult> getIndexProduct(OpBuilder &b, Location loc,
                                               ArrayRef<Value> set) {
  if (set.empty())
    return failure();

  AffineExpr accSym, factorSym;
  bindSymbols(b.getContext(), accSym, factorSym);
  AffineExpr mulExpr = accSym * factorSym;

  OpFoldResult product = set.front();
  for (Value factor : set.drop_front())
    product = makeComposedFoldedAffineApply(b, loc, mulExpr, {product, factor});
  return product;
}
/// Splits `linearIndex` into one coordinate per entry of `basis`, outermost
/// first.
///
/// The divisor for coordinate k is the product of `basis[k+1..]`. The
/// innermost coordinate is the final remainder. Fails if a divisor product
/// cannot be formed.
FailureOr<SmallVector<Value>>
mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
                               ArrayRef<Value> basis) {
  SmallVector<Value> divisors;
  for (size_t skip = 1; skip < basis.size(); ++skip) {
    FailureOr<OpFoldResult> suffixProduct =
        getIndexProduct(b, loc, basis.drop_front(skip));
    if (failed(suffixProduct))
      return failure();
    divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, *suffixProduct));
  }

  // Peel off one coordinate per divisor. The remainder carries over to the
  // next inner dimension.
  SmallVector<Value> results;
  results.reserve(divisors.size() + 1);
  Value remaining = linearIndex;
  for (Value divisor : divisors) {
    DivModValue qr = getDivMod(b, loc, remaining, divisor);
    results.push_back(qr.quotient);
    remaining = qr.remainder;
  }
  results.push_back(remaining);
  return results;
}
/// Linearizes `multiIndex` against `basis`: sum over i of
/// multiIndex[i] * stride[i].
///
/// The strides are computed symbolically from `basis` and then folded
/// against the actual basis values. `multiIndex` and `basis` must have the
/// same length.
OpFoldResult mlir::affine::linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
                                          ArrayRef<OpFoldResult> basis,
                                          ImplicitLocOpBuilder &builder) {
  assert(multiIndex.size() == basis.size());
  MLIRContext *ctx = builder.getContext();
  Location loc = builder.getLoc();

  // Model basis entry i as symbol s_i so strides can be derived symbolically.
  SmallVector<AffineExpr> basisSymbols;
  basisSymbols.reserve(basis.size());
  for (unsigned pos = 0, e = basis.size(); pos < e; ++pos)
    basisSymbols.push_back(getAffineSymbolExpr(pos, ctx));

  SmallVector<AffineExpr> strideExprs = computeStrides(basisSymbols);
  SmallVector<OpFoldResult> strides;
  strides.reserve(strideExprs.size());
  for (AffineExpr strideExpr : strideExprs)
    strides.push_back(affine::makeComposedFoldedAffineApply(
        builder, loc, strideExpr, basis));

  // Combine indices and strides into one expression with a zero offset.
  auto [linearIndexExpr, multiIndexAndStrides] = computeLinearIndex(
      OpFoldResult(builder.getIndexAttr(0)), strides, multiIndex);
  return affine::makeComposedFoldedAffineApply(builder, loc, linearIndexExpr,
                                               multiIndexAndStrides);
}