#include "mlir/Analysis/Presburger/PWMAFunction.h"
#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Analysis/Presburger/PresburgerRelation.h"
#include "mlir/Analysis/Presburger/PresburgerSpace.h"
#include "mlir/Analysis/Presburger/Utils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>
#include <optional>
using namespace mlir;
using namespace presburger;
/// Debug-only sanity check that the space, the division representation and
/// the output matrix agree on their dimensions, and that every division has
/// an explicit representation. Every check is an `assert`, so this does
/// nothing when assertions are disabled.
void MultiAffineFunction::assertIsConsistent() const {
  // One column per non-range variable, plus one for the constant term.
  assert(output.getNumColumns() ==
             space.getNumVars() - space.getNumRangeVars() + 1 &&
         "Inconsistent number of output columns");
  // Domain and symbol variables are the non-division inputs of `divs`.
  assert(divs.getNumNonDivs() ==
             space.getNumDomainVars() + space.getNumSymbolVars() &&
         "Inconsistent number of non-division variables in divs");
  // One output row per range variable.
  assert(output.getNumRows() == space.getNumRangeVars() &&
         "Inconsistent number of output rows");
  // One division per local variable.
  assert(divs.getNumDivs() == space.getNumLocalVars() &&
         "Inconsistent number of divisions.");
  assert(divs.hasAllReprs() && "All divisions should have a representation");
}
/// Returns the element-wise difference `vecA - vecB`. Both inputs must have
/// the same length.
static SmallVector<DynamicAPInt, 8> subtractExprs(ArrayRef<DynamicAPInt> vecA,
                                                  ArrayRef<DynamicAPInt> vecB) {
  assert(vecA.size() == vecB.size() &&
         "Cannot subtract vectors of differing lengths!");
  const unsigned len = vecA.size();
  SmallVector<DynamicAPInt, 8> diff;
  diff.reserve(len);
  for (unsigned idx = 0; idx != len; ++idx)
    diff.emplace_back(vecA[idx] - vecB[idx]);
  return diff;
}
/// Returns the union of the domains of all pieces, i.e. the set of points on
/// which this piecewise function is defined. If there are no pieces, this is
/// the empty set over the domain space.
PresburgerSet PWMAFunction::getDomain() const {
  PresburgerSet unionOfDomains = PresburgerSet::getEmpty(getDomainSpace());
  for (const Piece &p : pieces)
    unionOfDomains.unionInPlace(p.domain);
  return unionOfDomains;
}
/// Writes the space, the division representation and the output matrix of
/// this function to `os`, in that order.
void MultiAffineFunction::print(raw_ostream &os) const {
  space.print(os);

  os << "Division Representation:\n";
  divs.print(os);

  os << "Output:\n";
  output.print(os);
}
/// Evaluates the function at `point`, which holds values for the domain and
/// symbol variables.
///
/// The division values are computed from `point` and appended to it, followed
/// by a trailing 1. The output matrix (one row per output, last column holding
/// the constant term) is then multiplied by this homogeneous column to produce
/// the output vector.
SmallVector<DynamicAPInt, 8>
MultiAffineFunction::valueAt(ArrayRef<DynamicAPInt> point) const {
  assert(point.size() == getNumDomainVars() + getNumSymbolVars() &&
         "Point has incorrect dimensionality!");

  SmallVector<std::optional<DynamicAPInt>, 8> divValues =
      divs.divValuesAt(point);

  // Build the column [point..., divValues..., 1].
  SmallVector<DynamicAPInt, 8> column{llvm::to_vector(point)};
  column.reserve(column.size() + divValues.size());
  // Dereferenced unchecked: assumes divValuesAt yields a value for every
  // division, which `divs.hasAllReprs()` is meant to guarantee.
  for (const std::optional<DynamicAPInt> &value : divValues)
    column.emplace_back(*value);
  column.emplace_back(1);

  SmallVector<DynamicAPInt, 8> values = output.postMultiplyWithColumn(column);
  assert(values.size() == getNumOutputs());
  return values;
}
bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const {
assert(space.isCompatible(other.space) &&
"Spaces should be compatible for equality check.");
return getAsRelation().isEqual(other.getAsRelation());
}
bool MultiAffineFunction::isEqual(const MultiAffineFunction &other,
const IntegerPolyhedron &domain) const {
assert(space.isCompatible(other.space) &&
"Spaces should be compatible for equality check.");
IntegerRelation restrictedThis = getAsRelation();
restrictedThis.intersectDomain(domain);
IntegerRelation restrictedOther = other.getAsRelation();
restrictedOther.intersectDomain(domain);
return restrictedThis.isEqual(restrictedOther);
}
bool MultiAffineFunction::isEqual(const MultiAffineFunction &other,
const PresburgerSet &domain) const {
assert(space.isCompatible(other.space) &&
"Spaces should be compatible for equality check.");
return llvm::all_of(domain.getAllDisjuncts(),
[&](const IntegerRelation &disjunct) {
return isEqual(other, IntegerPolyhedron(disjunct));
});
}
/// Removes the outputs in the half-open range [start, end) from both the
/// space and the output matrix. An empty range is a no-op.
void MultiAffineFunction::removeOutputs(unsigned start, unsigned end) {
  assert(end <= getNumOutputs() && "Invalid range");
  if (end <= start)
    return;
  const unsigned count = end - start;
  space.removeVarRange(VarKind::Range, start, end);
  output.removeRows(start, count);
}
/// Unifies the local (division) variables of `this` and `other` so that both
/// end up with the same division list.
///
/// The merged list holds the divisions of `this` first, followed by the
/// divisions of `other` that `removeDuplicateDivs` did not merge away. The
/// spaces and output matrices of BOTH functions are updated to match, so
/// `other` is modified too.
void MultiAffineFunction::mergeDivs(MultiAffineFunction &other) {
assert(space.isCompatible(other.space) && "Functions should be compatible");
unsigned nDivs = getNumDivs();
unsigned divOffset = divs.getDivOffset();
// Make room for the divisions of `this` at the front of `other`'s divisions.
other.divs.insertDiv(0, nDivs);
// Scratch dividend row sized for `other`'s enlarged variable list plus the
// constant term.
SmallVector<DynamicAPInt, 8> div(other.divs.getNumVars() + 1);
for (unsigned i = 0; i < nDivs; ++i) {
// Copy the i-th dividend of `this` into the scratch row. The non-constant
// coefficients go at the front and the constant goes last. The remaining
// entries stay zero; presumably these are the columns of `other`'s own
// divisions, which follow the inserted ones.
std::fill(div.begin(), div.end(), 0);
std::copy(divs.getDividend(i).begin(), divs.getDividend(i).end() - 1,
div.begin());
div.back() = divs.getDividend(i).back();
other.divs.setDiv(i, div, divs.getDenom(i));
}
// Mirror the inserted divisions in `other`'s space and output matrix.
other.space.insertVar(VarKind::Local, 0, nDivs);
other.output.insertColumns(divOffset, nDivs);
// Callback for removeDuplicateDivs, called with a duplicate pair (i, j).
// It only folds a later local `j` into an earlier one `i`. It refuses
// whenever `j` is one of the divisions copied from `this` (j < nDivs), so
// the divisions of `this` keep their positions. When it accepts, it removes
// local `j` from `other`'s space and adds its output column into column `i`
// before removing it.
auto merge = [&](unsigned i, unsigned j) {
if (i >= j)
return false;
if (j < nDivs)
return false;
other.space.removeVarRange(VarKind::Local, j, j + 1);
other.output.addToColumn(divOffset + i, divOffset + j, 1);
other.output.removeColumn(divOffset + j);
return true;
};
other.divs.removeDuplicateDivs(merge);
// Number of `other`'s divisions that survived deduplication. `this` gets
// matching local variables and output columns after its own nDivs
// divisions, and then adopts the merged division list.
unsigned newDivs = other.divs.getNumDivs() - nDivs;
space.insertVar(VarKind::Local, nDivs, newDivs);
output.insertColumns(divOffset + nDivs, newDivs);
divs = other.divs;
assertIsConsistent();
other.assertIsConsistent();
}
/// Returns the set of domain points where the output of `this` compares
/// lexicographically to the output of `other` according to `comp`.
///
/// Only OrderingKind::LT and OrderingKind::GT are implemented; requesting any
/// other ordering is a programming error.
///
/// The result is the union, over each output level `l`, of
///   outA[0] == outB[0], ..., outA[l-1] == outB[l-1], outA[l] <op> outB[l]
/// where <op> is `<` (LT) or `>` (GT). Every set in this union also carries
/// the constraints that define the merged division variables.
PresburgerSet
MultiAffineFunction::getLexSet(OrderingKind comp,
                               const MultiAffineFunction &other) const {
  assert(getSpace().isCompatible(other.getSpace()) &&
         "Output space of funcs should be compatible");

  // mergeDivs mutates both operands, so work on copies. Afterwards both copies
  // share the same local (division) variables, and their output expressions
  // can be subtracted column by column.
  MultiAffineFunction funcA = *this;
  MultiAffineFunction funcB = other;
  funcA.mergeDivs(funcB);

  PresburgerSpace resultSpace = funcA.getDomainSpace();
  PresburgerSet result =
      PresburgerSet::getEmpty(resultSpace.getSpaceWithoutLocals());
  IntegerPolyhedron levelSet(
      /*numReservedInequalities=*/1 + 2 * resultSpace.getNumLocalVars(),
      /*numReservedEqualities=*/funcA.getNumOutputs(),
      /*numReservedCols=*/resultSpace.getNumVars() + 1, resultSpace);

  // Tie each local variable to its division with an upper and a lower bound
  // inequality.
  for (unsigned i = 0, e = funcA.getNumDivs(); i < e; ++i) {
    levelSet.addInequality(getDivUpperBound(funcA.divs.getDividend(i),
                                            funcA.divs.getDenom(i),
                                            funcA.divs.getDivOffset() + i));
    levelSet.addInequality(getDivLowerBound(funcA.divs.getDividend(i),
                                            funcA.divs.getDenom(i),
                                            funcA.divs.getDivOffset() + i));
  }

  for (unsigned level = 0; level < funcA.getNumOutputs(); ++level) {
    // The expression outA[level] - outB[level].
    SmallVector<DynamicAPInt, 8> subExpr =
        subtractExprs(funcA.getOutputExpr(level), funcB.getOutputExpr(level));

    switch (comp) {
    case OrderingKind::LT:
      // outA - outB <= -1, i.e. outA < outB.
      levelSet.addBound(BoundType::UB, subExpr, DynamicAPInt(-1));
      break;
    case OrderingKind::GT:
      // outA - outB >= 1, i.e. outA > outB.
      levelSet.addBound(BoundType::LB, subExpr, DynamicAPInt(1));
      break;
    case OrderingKind::GE:
    case OrderingKind::LE:
    case OrderingKind::EQ:
    case OrderingKind::NE:
      // Must not fall through. The code below removes the last inequality,
      // which here would be a division constraint (or an underflowed index).
      llvm_unreachable("Not implemented case");
    }

    result.unionInPlace(levelSet);
    // Drop the strict bound just added for this level. Then pin this level to
    // equality before moving on to the next one.
    levelSet.removeInequality(levelSet.getNumInequalities() - 1);
    levelSet.addEquality(subExpr);
  }

  return result;
}
bool PWMAFunction::isEqual(const PWMAFunction &other) const {
if (!space.isCompatible(other.space))
return false;
if (!this->getDomain().isEqual(other.getDomain()))
return false;
return llvm::all_of(this->pieces, [&other](const Piece &pieceA) {
return llvm::all_of(other.pieces, [&pieceA](const Piece &pieceB) {
PresburgerSet commonDomain = pieceA.domain.intersect(pieceB.domain);
return pieceA.output.isEqual(pieceB.output, commonDomain);
});
});
}
/// Appends a copy of `piece` to this function. In debug builds it asserts
/// that the piece is consistent and that its domain does not intersect the
/// current domain.
void PWMAFunction::addPiece(const Piece &piece) {
  assert(piece.isConsistent() && "Piece should be consistent");
  assert(piece.domain.intersect(getDomain()).isIntegerEmpty() &&
         "Piece should be disjoint from the function");
  pieces.push_back(piece);
}
/// Writes the space and the piece count to `os`, followed by the domain and
/// output of every piece.
void PWMAFunction::print(raw_ostream &os) const {
  space.print(os);
  os << getNumPieces() << " pieces:\n";
  for (const Piece &p : pieces) {
    os << "Domain of piece:\n";
    p.domain.print(os);

    os << "Output of piece\n";
    p.output.print(os);
  }
}
void PWMAFunction::dump() const { print(llvm::errs()); }
/// Combines `this` with `func` into one piecewise function defined on the
/// union of their domains.
///
///  - For each piece `pieceA` of `this` and `pieceB` of `func`,
///    `tiebreak(pieceB, pieceA)` gives the region where pieceB's output is
///    used. That region is removed from what pieceA keeps.
///  - pieceA's output is used on whatever remains of its domain.
///  - Pieces of `func` are used unchanged wherever `this` is undefined.
///
/// addPiece asserts that every piece added is disjoint from those already
/// added.
PWMAFunction PWMAFunction::unionFunction(
    const PWMAFunction &func,
    llvm::function_ref<PresburgerSet(Piece maf1, Piece maf2)> tiebreak) const {
  assert(getNumOutputs() == func.getNumOutputs() &&
         "Ranges of functions should be same.");
  assert(getSpace().isCompatible(func.getSpace()) &&
         "Space is not compatible.");

  PWMAFunction combined(getSpace());

  for (const Piece &pieceA : pieces) {
    PresburgerSet remainingA(pieceA.domain);
    for (const Piece &pieceB : func.pieces) {
      PresburgerSet winsB = tiebreak(pieceB, pieceA);
      combined.addPiece({winsB, pieceB.output});
      remainingA = remainingA.subtract(winsB);
    }
    combined.addPiece({remainingA, pieceA.output});
  }

  // The part of each piece of `func` that lies outside the domain of `this`.
  PresburgerSet domainOfThis = getDomain();
  for (const Piece &pieceB : func.pieces)
    combined.addPiece({pieceB.domain.subtract(domainOfThis), pieceB.output});

  return combined;
}
/// Returns the part of the common domain of `pieceA` and `pieceB` where the
/// output of `pieceA` compares lexicographically to the output of `pieceB`
/// according to `comp`.
template <OrderingKind comp>
static PresburgerSet tiebreakLex(const PWMAFunction::Piece &pieceA,
                                 const PWMAFunction::Piece &pieceB) {
  return pieceA.output.getLexSet(comp, pieceB.output)
      .intersect(pieceA.domain)
      .intersect(pieceB.domain);
}
/// Union of `this` and `func`. Where both are defined, the piece with the
/// lexicographically smaller output is chosen.
PWMAFunction PWMAFunction::unionLexMin(const PWMAFunction &func) {
  // unionFunction calls this as (pieceOfFunc, pieceOfThis), so it selects the
  // region where func's output is lexicographically less than this's.
  auto preferSmaller = tiebreakLex<OrderingKind::LT>;
  return unionFunction(func, preferSmaller);
}
/// Union of `this` and `func`. Where both are defined, the piece with the
/// lexicographically larger output is chosen.
PWMAFunction PWMAFunction::unionLexMax(const PWMAFunction &func) {
  // unionFunction calls this as (pieceOfFunc, pieceOfThis), so it selects the
  // region where func's output is lexicographically greater than this's.
  auto preferLarger = tiebreakLex<OrderingKind::GT>;
  return unionFunction(func, preferLarger);
}
void MultiAffineFunction::subtract(const MultiAffineFunction &other) {
assert(space.isCompatible(other.space) &&
"Spaces should be compatible for subtraction.");
MultiAffineFunction copyOther = other;
mergeDivs(copyOther);
for (unsigned i = 0, e = getNumOutputs(); i < e; ++i)
output.addToRow(i, copyOther.getOutputExpr(i), DynamicAPInt(-1));
assertIsConsistent();
}
/// For each division in `divs`, adds to `rel` the upper-bound and lower-bound
/// inequalities produced by getDivUpperBound / getDivLowerBound for the
/// corresponding local variable.
///
/// `rel` must have as many variables and local variables as `divs` has
/// variables and divisions, and every division must have a representation.
static void addDivisionConstraints(IntegerRelation &rel,
                                   const DivisionRepr &divs) {
  assert(divs.hasAllReprs() &&
         "All divisions in divs should have a representation");
  assert(rel.getNumVars() == divs.getNumVars() &&
         "Relation and divs should have the same number of vars");
  assert(rel.getNumLocalVars() == divs.getNumDivs() &&
         "Relation and divs should have the same number of local vars");

  const unsigned firstDivPos = divs.getDivOffset();
  for (unsigned d = 0, numDivs = divs.getNumDivs(); d < numDivs; ++d) {
    const unsigned localPos = firstDivPos + d;
    rel.addInequality(
        getDivUpperBound(divs.getDividend(d), divs.getDenom(d), localPos));
    rel.addInequality(
        getDivLowerBound(divs.getDividend(d), divs.getDenom(d), localPos));
  }
}
/// Returns this function as an IntegerRelation.
///
/// The relation has the same domain, symbol and local variables as this
/// function, plus one range variable per output. The division constraints
/// are included, along with one equality per output that sets range
/// variable i equal to output expression i.
IntegerRelation MultiAffineFunction::getAsRelation() const {
// Start with no range variables. The outputs are inserted as range variables
// after the division constraints have been added.
IntegerRelation result(PresburgerSpace::getRelationSpace(
space.getNumDomainVars(), 0, space.getNumSymbolVars(),
space.getNumLocalVars()));
addDivisionConstraints(result, divs);
result.insertVar(VarKind::Range, 0, getNumOutputs());
// Scratch equality row, reused for every output.
SmallVector<DynamicAPInt, 8> eq(result.getNumCols());
for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) {
// The output expression has no range-variable columns. Copy its leading
// (domain) coefficients as-is, zero the relation's range columns, and copy
// the remaining coefficients after the range columns.
ArrayRef<DynamicAPInt> expr = getOutputExpr(i);
std::copy(expr.begin(), expr.begin() + getNumDomainVars(), eq.begin());
// Re-zeroed every iteration: the previous iteration wrote -1 into one of
// these columns.
std::fill(eq.begin() + result.getVarKindOffset(VarKind::Range),
eq.begin() + result.getVarKindEnd(VarKind::Range), 0);
std::copy(expr.begin() + getNumDomainVars(), expr.end(),
eq.begin() + result.getVarKindEnd(VarKind::Range));
// Encodes output_i(...) - rangeVar_i == 0.
eq[result.getVarKindOffset(VarKind::Range) + i] = -1;
result.addEquality(eq);
}
return result;
}
void PWMAFunction::removeOutputs(unsigned start, unsigned end) {
space.removeVarRange(VarKind::Range, start, end);
for (Piece &piece : pieces)
piece.output.removeOutputs(start, end);
}
/// Evaluates the function at `point` (values for the domain and symbol
/// variables). Uses the first piece whose domain contains `point`, and returns
/// std::nullopt if no piece does.
std::optional<SmallVector<DynamicAPInt, 8>>
PWMAFunction::valueAt(ArrayRef<DynamicAPInt> point) const {
  assert(point.size() == getNumDomainVars() + getNumSymbolVars());

  for (const Piece &candidate : pieces) {
    if (!candidate.domain.containsPoint(point))
      continue;
    return candidate.output.valueAt(point);
  }
  return std::nullopt;
}