#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Analysis/Presburger/Fraction.h"
#include "mlir/Analysis/Presburger/LinearTransform.h"
#include "mlir/Analysis/Presburger/PWMAFunction.h"
#include "mlir/Analysis/Presburger/PresburgerRelation.h"
#include "mlir/Analysis/Presburger/PresburgerSpace.h"
#include "mlir/Analysis/Presburger/Simplex.h"
#include "mlir/Analysis/Presburger/Utils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>
#include <functional>
#include <memory>
#include <optional>
#include <utility>
#include <vector>
#define DEBUG_TYPE "presburger"
using namespace mlir;
using namespace presburger;
using llvm::SmallDenseMap;
using llvm::SmallDenseSet;
/// Returns a heap-allocated copy of this relation.
std::unique_ptr<IntegerRelation> IntegerRelation::clone() const {
  auto copy = std::make_unique<IntegerRelation>(*this);
  return copy;
}
/// Returns a heap-allocated copy of this polyhedron.
std::unique_ptr<IntegerPolyhedron> IntegerPolyhedron::clone() const {
  auto copy = std::make_unique<IntegerPolyhedron>(*this);
  return copy;
}
void IntegerRelation::setSpace(const PresburgerSpace &oSpace) {
assert(space.getNumVars() == oSpace.getNumVars() && "invalid space!");
space = oSpace;
}
/// Replaces the space with `oSpace` (which must not contain locals). Every
/// variable of this relation that `oSpace` does not account for becomes a
/// local variable.
void IntegerRelation::setSpaceExceptLocals(const PresburgerSpace &oSpace) {
  assert(oSpace.getNumLocalVars() == 0 && "no locals should be present!");
  assert(oSpace.getNumVars() <= getNumVars() && "invalid space!");
  // Count the surplus variables before `space` is overwritten below.
  const unsigned numLocalsNeeded = getNumVars() - oSpace.getNumVars();
  space = oSpace;
  space.insertVar(VarKind::Local, 0, numLocalsNeeded);
}
/// Attaches identifier `id` to the `i`-th variable of kind `kind`.
/// Preconditions: the space already uses identifiers, `kind` is not
/// VarKind::Local, and `i` is in range for that kind.
void IntegerRelation::setId(VarKind kind, unsigned i, Identifier id) {
  assert(space.isUsingIds() &&
         "space must be using identifiers to set an identifier");
  assert(kind != VarKind::Local && "local variables cannot have identifiers");
  assert(i < space.getNumVarKind(kind) && "invalid variable index");

  space.setId(kind, i, id);
}
/// Returns the identifiers of the variables of kind `kind`. If the space does
/// not use identifiers yet, it is first switched over via resetIds().
ArrayRef<Identifier> IntegerRelation::getIds(VarKind kind) {
  if (space.isUsingIds())
    return space.getIds(kind);
  space.resetIds();
  return space.getIds(kind);
}
void IntegerRelation::append(const IntegerRelation &other) {
assert(space.isEqual(other.getSpace()) && "Spaces must be equal.");
inequalities.reserveRows(inequalities.getNumRows() +
other.getNumInequalities());
equalities.reserveRows(equalities.getNumRows() + other.getNumEqualities());
for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) {
addInequality(other.getInequality(r));
}
for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) {
addEquality(other.getEquality(r));
}
}
/// Returns the intersection of this relation with `other`. The local
/// variables of the two are merged first so that the constraints of `other`
/// can be appended, which requires equal spaces.
IntegerRelation IntegerRelation::intersect(IntegerRelation other) const {
  IntegerRelation merged(*this);
  merged.mergeLocalVars(other);
  merged.append(other);
  return merged;
}
bool IntegerRelation::isEqual(const IntegerRelation &other) const {
assert(space.isCompatible(other.getSpace()) && "Spaces must be compatible.");
return PresburgerRelation(*this).isEqual(PresburgerRelation(other));
}
bool IntegerRelation::isObviouslyEqual(const IntegerRelation &other) const {
if (!space.isEqual(other.getSpace()))
return false;
if (getNumEqualities() != other.getNumEqualities())
return false;
if (getNumInequalities() != other.getNumInequalities())
return false;
unsigned cols = getNumCols();
for (unsigned i = 0, eqs = getNumEqualities(); i < eqs; ++i) {
for (unsigned j = 0; j < cols; ++j) {
if (atEq(i, j) != other.atEq(i, j))
return false;
}
}
for (unsigned i = 0, ineqs = getNumInequalities(); i < ineqs; ++i) {
for (unsigned j = 0; j < cols; ++j) {
if (atIneq(i, j) != other.atIneq(i, j))
return false;
}
}
return true;
}
bool IntegerRelation::isSubsetOf(const IntegerRelation &other) const {
assert(space.isCompatible(other.getSpace()) && "Spaces must be compatible.");
return PresburgerRelation(*this).isSubsetOf(PresburgerRelation(other));
}
/// Computes the rational lexicographic minimum via LexSimplex. Symbol
/// variables are not supported. When the optimum is bounded, the returned
/// point keeps only the first getNumDimAndSymbolVars() entries, i.e. the
/// local variables are dropped.
MaybeOptimum<SmallVector<Fraction, 8>>
IntegerRelation::findRationalLexMin() const {
  assert(getNumSymbolVars() == 0 && "Symbols are not supported!");
  MaybeOptimum<SmallVector<Fraction, 8>> maybeLexMin =
      LexSimplex(*this).findRationalLexMin();
  if (maybeLexMin.isBounded()) {
    assert(maybeLexMin->size() == getNumVars() &&
           "Incorrect number of vars in lexMin!");
    maybeLexMin->resize(getNumDimAndSymbolVars());
  }
  return maybeLexMin;
}
/// Computes the integer lexicographic minimum via LexSimplex. Symbol
/// variables are not supported. A bounded optimum is truncated to the
/// non-local variables; otherwise only the optimum kind is returned.
MaybeOptimum<SmallVector<DynamicAPInt, 8>>
IntegerRelation::findIntegerLexMin() const {
  assert(getNumSymbolVars() == 0 && "Symbols are not supported!");
  MaybeOptimum<SmallVector<DynamicAPInt, 8>> maybeLexMin =
      LexSimplex(*this).findIntegerLexMin();
  if (maybeLexMin.isBounded()) {
    assert(maybeLexMin->size() == getNumVars() &&
           "Incorrect number of vars in lexMin!");
    maybeLexMin->resize(getNumDimAndSymbolVars());
    return maybeLexMin;
  }
  // Empty or unbounded: the kind is all that is meaningful.
  return maybeLexMin.getKind();
}
/// Returns true iff no entry of `range` is non-zero (true for empty ranges).
static bool rangeIsZero(ArrayRef<DynamicAPInt> range) {
  return llvm::none_of(range, [](const DynamicAPInt &x) { return x != 0; });
}
/// Removes every equality and inequality of `poly` that has a non-zero
/// coefficient for some variable in [begin, begin + count).
static void removeConstraintsInvolvingVarRange(IntegerRelation &poly,
                                               unsigned begin, unsigned count) {
  // Walk rows from last to first so that a removal never shifts a row that
  // is still to be visited.
  for (unsigned row = poly.getNumEqualities(); row-- > 0;)
    if (!rangeIsZero(poly.getEquality(row).slice(begin, count)))
      poly.removeEquality(row);
  for (unsigned row = poly.getNumInequalities(); row-- > 0;)
    if (!rangeIsZero(poly.getInequality(row).slice(begin, count)))
      poly.removeInequality(row);
}
/// Snapshots the current space and constraint counts. truncate() can later
/// roll the relation back to this state.
IntegerRelation::CountsSnapshot IntegerRelation::getCounts() const {
  CountsSnapshot snapshot{getSpace(), getNumInequalities(), getNumEqualities()};
  return snapshot;
}
/// Keeps only the first `num` variables of kind `kind` and removes the rest.
void IntegerRelation::truncateVarKind(VarKind kind, unsigned num) {
  unsigned curNum = getNumVarKind(kind);
  assert(num <= curNum && "Can't truncate to more vars!");
  // Remove the tail [num, curNum) of this kind.
  removeVarRange(kind, num, curNum);
}
/// Truncates variables of kind `kind` back to the count recorded in `counts`.
void IntegerRelation::truncateVarKind(VarKind kind,
                                      const CountsSnapshot &counts) {
  const unsigned numToKeep = counts.getSpace().getNumVarKind(kind);
  truncateVarKind(kind, numToKeep);
}
void IntegerRelation::truncate(const CountsSnapshot &counts) {
truncateVarKind(VarKind::Domain, counts);
truncateVarKind(VarKind::Range, counts);
truncateVarKind(VarKind::Symbol, counts);
truncateVarKind(VarKind::Local, counts);
removeInequalityRange(counts.getNumIneqs(), getNumInequalities());
removeEqualityRange(counts.getNumEqs(), getNumEqualities());
}
// Returns a PresburgerRelation, in this relation's space with locals removed
// from the result's space, whose locals all have division representations.
// If there are no locals, or every local already has a division
// representation, this relation is returned unchanged. Otherwise the locals
// lacking a representation are projected out: a symbolic integer lexmin is
// computed over them, and the result is the union of the lexmin's domain and
// the domain on which the lexmin is unbounded.
PresburgerRelation IntegerRelation::computeReprWithOnlyDivLocals() const {
if (getNumLocalVars() == 0)
return PresburgerRelation(*this);
// Work on a copy because variables are reordered below.
IntegerRelation copy = *this;
std::vector<MaybeLocalRepr> reprs(getNumLocalVars());
copy.getLocalReprs(&reprs);
// Move every local without a division representation to the end of the
// local range, so that those locals form a contiguous trailing block. The
// last `numNonDivLocals` locals have already been scanned and lack a
// representation; `reprs` is swapped in lockstep with the variables.
unsigned numNonDivLocals = 0;
unsigned offset = copy.getVarKindOffset(VarKind::Local);
for (unsigned i = 0, e = copy.getNumLocalVars(); i < e - numNonDivLocals;) {
if (!reprs[i]) {
copy.swapVar(offset + i, offset + e - numNonDivLocals - 1);
std::swap(reprs[i], reprs[e - numNonDivLocals - 1]);
++numNonDivLocals;
// Position i now holds a not-yet-scanned local; re-examine it.
continue;
}
++i;
}
if (numNonDivLocals == 0)
return PresburgerRelation(*this);
// Presumably, symbolOffset 0 plus a symbol domain with
// getNumVars() - numNonDivLocals dims makes every variable except the
// trailing non-div locals a "symbol", so the lexmin is taken over the
// non-div locals only (TODO confirm against SymbolicLexSimplex). The symbol
// values for which some assignment of those locals exists are the lexmin's
// domain plus its unbounded domain.
SymbolicLexOpt lexminResult =
SymbolicLexSimplex(copy, 0,
IntegerPolyhedron(PresburgerSpace::getSetSpace(
copy.getNumVars() - numNonDivLocals)))
.computeSymbolicIntegerLexMin();
PresburgerRelation result =
lexminResult.lexopt.getDomain().unionSet(lexminResult.unboundedDomain);
// The result was built in a set space; give it this relation's space with
// the locals removed.
PresburgerSpace space = getSpace();
space.removeVarRange(VarKind::Local, 0, getNumLocalVars());
result.setSpace(space);
return result;
}
/// Computes the lexicographic minimum of the range variables as a function of
/// the domain and symbol variables. Domain and symbol variables are marked
/// as "symbols" (parameters) for SymbolicLexSimplex. The locals take part in
/// the optimization but are stripped from the returned lexmin.
SymbolicLexOpt IntegerRelation::findSymbolicIntegerLexMin() const {
  llvm::SmallBitVector isSymbol(getNumVars(), false);
  for (VarKind paramKind : {VarKind::Symbol, VarKind::Domain})
    isSymbol.set(getVarKindOffset(paramKind), getVarKindEnd(paramKind));

  IntegerPolyhedron symbolDomain(PresburgerSpace::getSetSpace(
      getNumDomainVars(), getNumSymbolVars()));
  SymbolicLexOpt result = SymbolicLexSimplex(*this, symbolDomain, isSymbol)
                              .computeSymbolicIntegerLexMin();

  // The locals are the trailing outputs of the lexmin; drop them.
  const unsigned numOutputs = result.lexopt.getNumOutputs();
  result.lexopt.removeOutputs(numOutputs - getNumLocalVars(), numOutputs);
  return result;
}
/// Computes the symbolic integer lexmax as follows:
///  1. Negate the coefficients of the range variables.
///  2. Take the symbolic lexmin of the negated relation.
///  3. Negate every output expression of that lexmin back.
/// The unbounded domain carries over unchanged.
SymbolicLexOpt IntegerRelation::findSymbolicIntegerLexMax() const {
  IntegerRelation negated = *this;
  const unsigned rangeBegin = getNumDomainVars();
  const unsigned rangeEnd = rangeBegin + getNumRangeVars();
  for (unsigned col = rangeBegin; col < rangeEnd; ++col) {
    for (unsigned row = 0, numEqs = getNumEqualities(); row < numEqs; ++row)
      negated.atEq(row, col) = -1 * atEq(row, col);
    for (unsigned row = 0, numIneqs = getNumInequalities(); row < numIneqs;
         ++row)
      negated.atIneq(row, col) = -1 * atIneq(row, col);
  }

  SymbolicLexOpt negatedMin = negated.findSymbolicIntegerLexMin();
  SymbolicLexOpt lexMax(negatedMin.lexopt.getSpace());
  for (const auto &negatedPiece : negatedMin.lexopt.getAllPieces()) {
    IntMatrix outputs = negatedPiece.output.getOutputMatrix();
    for (unsigned r = 0, numRows = outputs.getNumRows(); r < numRows; ++r)
      outputs.negateRow(r);
    MultiAffineFunction maf(negatedPiece.output.getSpace(), outputs);
    PWMAFunction::Piece piece = {negatedPiece.domain, maf};
    lexMax.lexopt.addPiece(piece);
  }
  lexMax.unboundedDomain = negatedMin.unboundedDomain;
  return lexMax;
}
/// Returns the points of this relation that are not in `set`.
PresburgerRelation
IntegerRelation::subtract(const PresburgerRelation &set) const {
  PresburgerRelation self(*this);
  return self.subtract(set);
}
/// Inserts `num` variables of kind `kind` at relative position `pos` (which
/// may equal the current count, to append). The space decides the absolute
/// column position; the same number of columns is inserted there in both
/// constraint matrices, and that absolute position is returned.
unsigned IntegerRelation::insertVar(VarKind kind, unsigned pos, unsigned num) {
  assert(pos <= getNumVarKind(kind));
  const unsigned absPos = space.insertVar(kind, pos, num);
  inequalities.insertColumns(absPos, num);
  equalities.insertColumns(absPos, num);
  return absPos;
}
/// Appends `num` variables after the existing variables of kind `kind` and
/// returns the absolute position of the first new one.
unsigned IntegerRelation::appendVar(VarKind kind, unsigned num) {
  return insertVar(kind, getNumVarKind(kind), num);
}
/// Appends the equality `eq == 0`. `eq` holds one coefficient per variable
/// followed by the constant term, i.e. exactly getNumCols() entries.
void IntegerRelation::addEquality(ArrayRef<DynamicAPInt> eq) {
  assert(eq.size() == getNumCols());
  const unsigned newRow = equalities.appendExtraRow();
  for (unsigned col = 0, numCols = eq.size(); col != numCols; ++col)
    equalities(newRow, col) = eq[col];
}
/// Appends the inequality `inEq >= 0`. `inEq` holds one coefficient per
/// variable followed by the constant term, i.e. exactly getNumCols() entries.
void IntegerRelation::addInequality(ArrayRef<DynamicAPInt> inEq) {
  assert(inEq.size() == getNumCols());
  const unsigned newRow = inequalities.appendExtraRow();
  for (unsigned col = 0, numCols = inEq.size(); col != numCols; ++col)
    inequalities(newRow, col) = inEq[col];
}
/// Removes the single variable at relative position `pos` of kind `kind`.
void IntegerRelation::removeVar(VarKind kind, unsigned pos) {
  const unsigned limit = pos + 1;
  removeVarRange(kind, pos, limit);
}
/// Removes the single variable at absolute position `pos`.
void IntegerRelation::removeVar(unsigned pos) {
  const unsigned limit = pos + 1;
  removeVarRange(pos, limit);
}
/// Removes the variables [varStart, varLimit) of kind `kind` (positions
/// relative to that kind) from both constraint matrices and from the space.
/// An empty range is a no-op.
void IntegerRelation::removeVarRange(VarKind kind, unsigned varStart,
                                     unsigned varLimit) {
  assert(varLimit <= getNumVarKind(kind));
  if (varLimit <= varStart)
    return;
  const unsigned firstCol = getVarKindOffset(kind) + varStart;
  const unsigned count = varLimit - varStart;
  equalities.removeColumns(firstCol, count);
  inequalities.removeColumns(firstCol, count);
  // The space is updated last because the column offset above was read
  // from it.
  space.removeVarRange(kind, varStart, varLimit);
}
void IntegerRelation::removeVarRange(unsigned varStart, unsigned varLimit) {
assert(varLimit <= getNumVars());
if (varStart >= varLimit)
return;
auto removeVarKindInRange = [this](VarKind kind, unsigned &start,
unsigned &limit) {
if (start >= limit)
return;
unsigned offset = getVarKindOffset(kind);
unsigned num = getNumVarKind(kind);
unsigned relativeStart =
start <= offset ? 0 : std::min(num, start - offset);
unsigned relativeLimit =
limit <= offset ? 0 : std::min(num, limit - offset);
removeVarRange(kind, relativeStart, relativeLimit);
limit -= relativeLimit - relativeStart;
};
removeVarKindInRange(VarKind::Domain, varStart, varLimit);
removeVarKindInRange(VarKind::Range, varStart, varLimit);
removeVarKindInRange(VarKind::Symbol, varStart, varLimit);
removeVarKindInRange(VarKind::Local, varStart, varLimit);
}
/// Deletes the equality stored at row `pos`.
void IntegerRelation::removeEquality(unsigned pos) {
  equalities.removeRow(pos);
}
/// Deletes the inequality stored at row `pos`.
void IntegerRelation::removeInequality(unsigned pos) {
  inequalities.removeRow(pos);
}
/// Deletes the equalities in rows [start, end). An empty range is a no-op.
void IntegerRelation::removeEqualityRange(unsigned start, unsigned end) {
  if (start < end)
    equalities.removeRows(start, end - start);
}
/// Deletes the inequalities in rows [start, end). An empty range is a no-op.
void IntegerRelation::removeInequalityRange(unsigned start, unsigned end) {
  if (start < end)
    inequalities.removeRows(start, end - start);
}
void IntegerRelation::swapVar(unsigned posA, unsigned posB) {
assert(posA < getNumVars() && "invalid position A");
assert(posB < getNumVars() && "invalid position B");
if (posA == posB)
return;
VarKind kindA = space.getVarKindAt(posA);
VarKind kindB = space.getVarKindAt(posB);
unsigned relativePosA = posA - getVarKindOffset(kindA);
unsigned relativePosB = posB - getVarKindOffset(kindB);
space.swapVar(kindA, kindB, relativePosA, relativePosB);
inequalities.swapColumns(posA, posB);
equalities.swapColumns(posA, posB);
}
/// Drops every equality and inequality. The space is left unchanged.
void IntegerRelation::clearConstraints() {
  const unsigned noRows = 0;
  equalities.resizeVertically(noRows);
  inequalities.resizeVertically(noRows);
}
void IntegerRelation::getLowerAndUpperBoundIndices(
unsigned pos, SmallVectorImpl<unsigned> *lbIndices,
SmallVectorImpl<unsigned> *ubIndices, SmallVectorImpl<unsigned> *eqIndices,
unsigned offset, unsigned num) const {
assert(pos < getNumVars() && "invalid position");
assert(offset + num < getNumCols() && "invalid range");
auto containsConstraintDependentOnRange = [&](unsigned r, bool isEq) {
unsigned c, f;
auto cst = isEq ? getEquality(r) : getInequality(r);
for (c = offset, f = offset + num; c < f; ++c) {
if (c == pos)
continue;
if (cst[c] != 0)
break;
}
return c < f;
};
for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
if (containsConstraintDependentOnRange(r, false))
continue;
if (atIneq(r, pos) >= 1) {
lbIndices->emplace_back(r);
} else if (atIneq(r, pos) <= -1) {
ubIndices->emplace_back(r);
}
}
if (!eqIndices)
return;
for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
if (atEq(r, pos) == 0)
continue;
if (containsConstraintDependentOnRange(r, true))
continue;
eqIndices->emplace_back(r);
}
}
/// Returns whether both constraint matrices report a consistent state.
bool IntegerRelation::hasConsistentState() const {
  return inequalities.hasConsistentState() && equalities.hasConsistentState();
}
/// Fixes the variables at positions [pos, pos + values.size()) to `values`
/// and then removes them. Their contribution is folded into the constant
/// column of every constraint.
void IntegerRelation::setAndEliminate(unsigned pos,
                                      ArrayRef<DynamicAPInt> values) {
  if (values.empty())
    return;
  assert(pos + values.size() <= getNumVars() &&
         "invalid position or too many values");
  // addToColumn(src, dst, scale) is assumed to perform dst += scale * src,
  // i.e. constant += coeff[pos + i] * values[i] -- TODO confirm in Matrix.
  const unsigned constCol = getNumCols() - 1;
  for (unsigned i = 0, n = values.size(); i < n; ++i) {
    inequalities.addToColumn(pos + i, constCol, values[i]);
    equalities.addToColumn(pos + i, constCol, values[i]);
  }
  removeVarRange(pos, pos + values.size());
}
/// Replaces the contents of this relation with a copy of `other`.
void IntegerRelation::clearAndCopyFrom(const IntegerRelation &other) {
  IntegerRelation::operator=(other);
}
bool IntegerRelation::findConstraintWithNonZeroAt(unsigned colIdx, bool isEq,
unsigned *rowIdx) const {
assert(colIdx < getNumCols() && "position out of bounds");
auto at = [&](unsigned rowIdx) -> DynamicAPInt {
return isEq ? atEq(rowIdx, colIdx) : atIneq(rowIdx, colIdx);
};
unsigned e = isEq ? getNumEqualities() : getNumInequalities();
for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) {
if (at(*rowIdx) != 0) {
return true;
}
}
return false;
}
void IntegerRelation::normalizeConstraintsByGCD() {
for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
equalities.normalizeRow(i);
for (unsigned i = 0, e = getNumInequalities(); i < e; ++i)
inequalities.normalizeRow(i);
}
bool IntegerRelation::hasInvalidConstraint() const {
assert(hasConsistentState());
auto check = [&](bool isEq) -> bool {
unsigned numCols = getNumCols();
unsigned numRows = isEq ? getNumEqualities() : getNumInequalities();
for (unsigned i = 0, e = numRows; i < e; ++i) {
unsigned j;
for (j = 0; j < numCols - 1; ++j) {
DynamicAPInt v = isEq ? atEq(i, j) : atIneq(i, j);
if (v != 0)
break;
}
if (j < numCols - 1) {
continue;
}
DynamicAPInt v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1);
if ((isEq && v != 0) || (!isEq && v < 0)) {
return true;
}
}
return false;
};
if (check(true))
return true;
return check(false);
}
// Eliminates the coefficient at column `pivotCol` from constraint `rowIdx`
// (an equality if `isEq`, otherwise an inequality) by combining it with
// equality `pivotRow`:
//   row := pivotMultiplier * pivotEq + rowMultiplier * row
// rowMultiplier = lcm / |leadCoeff| is positive, so an inequality keeps its
// direction. pivotMultiplier carries the sign that makes the `pivotCol`
// entries cancel.
//
// Columns in [elimColStart, pivotCol) are skipped. NOTE: skipped columns are
// NOT rescaled by rowMultiplier, so the result is only correct if those
// entries are zero in row `rowIdx` (or rowMultiplier happens to be 1).
static void eliminateFromConstraint(IntegerRelation *constraints,
unsigned rowIdx, unsigned pivotRow,
unsigned pivotCol, unsigned elimColStart,
bool isEq) {
// The pivot equality itself is left untouched.
if (isEq && rowIdx == pivotRow)
return;
auto at = [&](unsigned i, unsigned j) -> DynamicAPInt {
return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j);
};
DynamicAPInt leadCoeff = at(rowIdx, pivotCol);
// Nothing to eliminate.
if (leadCoeff == 0)
return;
DynamicAPInt pivotCoeff = constraints->atEq(pivotRow, pivotCol);
// Give the pivot term the opposite sign of the row's term so that the two
// cancel at `pivotCol`.
int sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1;
DynamicAPInt lcm = llvm::lcm(pivotCoeff, leadCoeff);
DynamicAPInt pivotMultiplier = sign * (lcm / abs(pivotCoeff));
DynamicAPInt rowMultiplier = lcm / abs(leadCoeff);
unsigned numCols = constraints->getNumCols();
for (unsigned j = 0; j < numCols; ++j) {
// Columns assumed to be already eliminated are skipped (see NOTE above).
if (j >= elimColStart && j < pivotCol)
continue;
DynamicAPInt v = pivotMultiplier * constraints->atEq(pivotRow, j) +
rowMultiplier * at(rowIdx, j);
isEq ? constraints->atEq(rowIdx, j) = v
: constraints->atIneq(rowIdx, j) = v;
}
}
static unsigned getBestVarToEliminate(const IntegerRelation &cst,
unsigned start, unsigned end) {
assert(start < cst.getNumVars() && end < cst.getNumVars() + 1);
auto getProductOfNumLowerUpperBounds = [&](unsigned pos) {
unsigned numLb = 0;
unsigned numUb = 0;
for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
if (cst.atIneq(r, pos) > 0) {
++numLb;
} else if (cst.atIneq(r, pos) < 0) {
++numUb;
}
}
return numLb * numUb;
};
unsigned minLoc = start;
unsigned min = getProductOfNumLowerUpperBounds(start);
for (unsigned c = start + 1; c < end; c++) {
unsigned numLbUbProduct = getProductOfNumLowerUpperBounds(c);
if (numLbUbProduct < min) {
min = numLbUbProduct;
minLoc = c;
}
}
return minLoc;
}
// Returns true if this relation is determined to have no integer points.
// The checks, in order:
//  1. Cheap checks: the GCD test and trivially invalid constraints.
//  2. On a copy, eliminate locals through unit-coefficient equalities.
//  3. Gaussian elimination over all variables, re-checking after each step.
//  4. Fourier-Motzkin projection of the remaining variables, checking for
//     trivially invalid constraints after each projection.
//
// Returns false when no check finds a contradiction. It also returns false
// as soon as Fourier-Motzkin grows the system to
// kExplosionFactor * getNumVars() constraints or more, so a false result
// does not by itself prove non-emptiness in that case.
bool IntegerRelation::isEmpty() const {
if (isEmptyByGCDTest() || hasInvalidConstraint())
return true;
IntegerRelation tmpCst(*this);
tmpCst.removeRedundantLocalVars();
if (tmpCst.isEmptyByGCDTest() || tmpCst.hasInvalidConstraint())
return true;
// gaussianEliminateVars stops at the first variable that appears only in
// inequalities. That variable ends up at `currentPos`, so step past it and
// continue eliminating from the next position.
unsigned currentPos = 0;
while (currentPos < tmpCst.getNumVars()) {
tmpCst.gaussianEliminateVars(currentPos, tmpCst.getNumVars());
++currentPos;
if (tmpCst.hasInvalidConstraint() || tmpCst.isEmptyByGCDTest())
return true;
}
// Project out the remaining variables one at a time, cheapest first.
for (unsigned i = 0, e = tmpCst.getNumVars(); i < e; i++) {
tmpCst.fourierMotzkinEliminate(
getBestVarToEliminate(tmpCst, 0, tmpCst.getNumVars()));
// Guard against FM's worst-case constraint blow-up; give up (reporting
// "not known to be empty") rather than continue.
if (tmpCst.getNumConstraints() >= kExplosionFactor * getNumVars()) {
LLVM_DEBUG(llvm::dbgs() << "FM constraint explosion detected\n");
return false;
}
if (tmpCst.hasInvalidConstraint())
return true;
}
return false;
}
/// Cheap emptiness check: the GCD test on equalities, or a constraint that
/// is invalid on its constant term alone.
bool IntegerRelation::isObviouslyEmpty() const {
  if (isEmptyByGCDTest())
    return true;
  return hasInvalidConstraint();
}
bool IntegerRelation::isEmptyByGCDTest() const {
assert(hasConsistentState());
unsigned numCols = getNumCols();
for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
DynamicAPInt gcd = abs(atEq(i, 0));
for (unsigned j = 1; j < numCols - 1; ++j) {
gcd = llvm::gcd(gcd, abs(atEq(i, j)));
}
DynamicAPInt v = abs(atEq(i, numCols - 1));
if (gcd > 0 && (v % gcd != 0)) {
return true;
}
}
return false;
}
/// Returns a matrix whose rows are the variable coefficients (the constant
/// column is dropped) of:
/// - each inequality along which the set is bounded, according to
///   Simplex::isBoundedAlongConstraint;
/// - each equality.
/// The set must be rationally non-empty.
IntMatrix IntegerRelation::getBoundedDirections() const {
  Simplex simplex(*this);
  assert(!simplex.isEmpty() && "It is not meaningful to ask whether a "
                               "direction is bounded in an empty set.");

  SmallVector<unsigned, 8> boundedIneqs;
  for (unsigned i = 0, e = getNumInequalities(); i < e; ++i)
    if (simplex.isBoundedAlongConstraint(i))
      boundedIneqs.emplace_back(i);

  const unsigned numDirCols = getNumCols() - 1;
  IntMatrix dirs(boundedIneqs.size() + getNumEqualities(), numDirCols);
  unsigned outRow = 0;
  auto copyRow = [&](ArrayRef<DynamicAPInt> constraint) {
    for (unsigned col = 0; col < numDirCols; ++col)
      dirs(outRow, col) = constraint[col];
    ++outRow;
  };
  for (unsigned ineqIdx : boundedIneqs)
    copyRow(getInequality(ineqIdx));
  for (unsigned eqIdx = 0, e = getNumEqualities(); eqIdx < e; ++eqIdx)
    copyRow(getEquality(eqIdx));
  return dirs;
}
/// Returns true iff findIntegerSample() finds no integer point.
bool IntegerRelation::isIntegerEmpty() const {
  return !findIntegerSample().has_value();
}
// Returns an integer point of this relation, or std::nullopt if none exists.
//
// Bounded case: delegates to Simplex::findIntegerSample().
// Unbounded case:
//  1. Build the matrix of bounded directions and a LinearTransform that
//     brings it to column echelon form. After transforming, the first
//     `numBoundedDims` variables are presumably the bounded ones and the
//     rest unbounded (TODO confirm against makeTransformToColumnEchelon).
//  2. Sample the bounded part alone (constraints mentioning unbounded
//     variables are dropped). If that has no integer point, return nullopt.
//  3. Substitute that sample into the transformed set; what remains (the
//     "cone") involves only the unbounded variables.
//  4. Shrink the cone so that rounding any rational point of it up (ceil)
//     stays inside the unshrunk cone, and round a rational sample of it.
//  5. Concatenate both samples and map the point back through the transform.
std::optional<SmallVector<DynamicAPInt, 8>>
IntegerRelation::findIntegerSample() const {
if (isEmptyByGCDTest())
return {};
Simplex simplex(*this);
if (simplex.isEmpty())
return {};
if (!simplex.isUnbounded())
return simplex.findIntegerSample();
// Step 1: transform so that bounded directions come first.
IntMatrix m = getBoundedDirections();
std::pair<unsigned, LinearTransform> result =
LinearTransform::makeTransformToColumnEchelon(m);
const LinearTransform &transform = result.second;
IntegerRelation transformedSet = transform.applyTo(*this);
// Step 2: sample the bounded part on its own.
IntegerRelation boundedSet(transformedSet);
unsigned numBoundedDims = result.first;
unsigned numUnboundedDims = getNumVars() - numBoundedDims;
removeConstraintsInvolvingVarRange(boundedSet, numBoundedDims,
numUnboundedDims);
boundedSet.removeVarRange(numBoundedDims, boundedSet.getNumVars());
std::optional<SmallVector<DynamicAPInt, 8>> boundedSample =
Simplex(boundedSet).findIntegerSample();
if (!boundedSample)
return {};
assert(boundedSet.containsPoint(*boundedSample) &&
"Simplex returned an invalid sample!");
// Step 3: fix the bounded variables; only the unbounded ones remain.
transformedSet.setAndEliminate(0, *boundedSample);
IntegerRelation &cone = transformedSet;
// Step 4: shrink the cone. For sum(a_j * x_j) + c >= 0, adding every
// negative a_j to c guarantees that ceil(x) also satisfies the original
// inequality: a_j * (ceil(x_j) - x_j) >= a_j when a_j < 0, and >= 0
// otherwise. Column getNumVars() is the constant column here, which
// assumes the cone holds only plain variables at this point.
for (unsigned i = 0, e = cone.getNumInequalities(); i < e; ++i) {
for (unsigned j = 0; j < cone.getNumVars(); ++j) {
DynamicAPInt coeff = cone.atIneq(i, j);
if (coeff < 0)
cone.atIneq(i, cone.getNumVars()) += coeff;
}
}
Simplex shrunkenConeSimplex(cone);
assert(!shrunkenConeSimplex.isEmpty() && "Shrunken cone cannot be empty!");
SmallVector<Fraction, 8> shrunkenConeSample =
*shrunkenConeSimplex.getRationalSample();
SmallVector<DynamicAPInt, 8> coneSample(
llvm::map_range(shrunkenConeSample, ceil));
// Step 5: bounded coordinates first, then the cone's, mapped back.
SmallVector<DynamicAPInt, 8> &sample = *boundedSample;
sample.append(coneSample.begin(), coneSample.end());
return transform.postMultiplyWithColumn(sample);
}
/// Evaluates the affine expression `expr` at `point`. `expr` holds one
/// coefficient per coordinate followed by the constant term.
static DynamicAPInt valueAt(ArrayRef<DynamicAPInt> expr,
                            ArrayRef<DynamicAPInt> point) {
  assert(expr.size() == 1 + point.size() &&
         "Dimensionalities of point and expression don't match!");
  DynamicAPInt value = expr.back();
  for (auto [coeff, coord] : llvm::zip(expr.drop_back(), point))
    value += coeff * coord;
  return value;
}
/// Returns true iff `point` satisfies every equality (value == 0) and every
/// inequality (value >= 0). `point` gives a value for every variable.
bool IntegerRelation::containsPoint(ArrayRef<DynamicAPInt> point) const {
  auto eqHolds = [&](unsigned i) {
    return valueAt(getEquality(i), point) == 0;
  };
  auto ineqHolds = [&](unsigned i) {
    return valueAt(getInequality(i), point) >= 0;
  };
  return llvm::all_of(llvm::seq<unsigned>(0, getNumEqualities()), eqHolds) &&
         llvm::all_of(llvm::seq<unsigned>(0, getNumInequalities()),
                      ineqHolds);
}
/// Checks membership of `point`, which gives values for all non-local
/// variables. Returns an integer assignment of the locals that witnesses
/// membership, or std::nullopt if none exists.
std::optional<SmallVector<DynamicAPInt, 8>>
IntegerRelation::containsPointNoLocal(ArrayRef<DynamicAPInt> point) const {
  assert(point.size() == getNumVars() - getNumLocalVars() &&
         "Point should contain all vars except locals!");
  assert(getVarKindOffset(VarKind::Local) == getNumVars() - getNumLocalVars() &&
         "This function depends on locals being stored last!");
  // After fixing the leading (non-local) variables only the locals remain.
  IntegerRelation localsOnly(*this);
  localsOnly.setAndEliminate(0, point);
  return localsOnly.findIntegerSample();
}
// Computes division representations for the local variables.
//
// All non-local variables start out as "known". Each pass calls
// computeSingleVarRepr for every local not yet represented, which may use
// any already-known variable. Passes repeat until one makes no progress,
// since a local found in one pass can enable another in the next.
//
// Locals without a representation have their entry cleared in the returned
// DivisionRepr. If `repr` is non-null, (*repr)[i] is set for every local i
// that received a representation; other entries are left untouched, so the
// caller must size *repr to at least getNumLocalVars().
DivisionRepr
IntegerRelation::getLocalReprs(std::vector<MaybeLocalRepr> *repr) const {
SmallVector<bool, 8> foundRepr(getNumVars(), false);
for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; ++i)
foundRepr[i] = true;
unsigned localOffset = getVarKindOffset(VarKind::Local);
DivisionRepr divs(getNumVars(), getNumLocalVars());
bool changed;
do {
changed = false;
for (unsigned i = 0, e = getNumLocalVars(); i < e; ++i) {
if (!foundRepr[i + localOffset]) {
MaybeLocalRepr res =
computeSingleVarRepr(*this, foundRepr, localOffset + i,
divs.getDividend(i), divs.getDenom(i));
if (!res) {
// Not representable (yet): clear whatever was written into
// `divs` for this local and retry in a later pass.
divs.clearRepr(i);
continue;
}
foundRepr[localOffset + i] = true;
if (repr)
(*repr)[i] = res;
changed = true;
}
}
} while (changed);
return divs;
}
void IntegerRelation::gcdTightenInequalities() {
unsigned numCols = getNumCols();
for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
DynamicAPInt gcd = inequalities.normalizeRow(i, getNumCols() - 1);
if (gcd > 1)
atIneq(i, numCols - 1) = floorDiv(atIneq(i, numCols - 1), gcd);
}
}
// Eliminates variables in [posStart, posLimit), in order, using equalities.
// Each column is handled as follows:
// - If some equality has a non-zero coefficient there, that equality is the
//   pivot: the column is eliminated from every other equality and every
//   inequality, and the pivot equality is removed.
// - If no constraint mentions the column, it is simply dropped.
// - If only inequalities mention it, elimination stops at that column.
// The handled columns [posStart, stopping column) are then removed.
//
// Returns the number of variables removed.
//
// NOTE(review): when a pivot coefficient is not +/-1, eliminating the
// variable discards the divisibility condition that the pivot equality
// implies on the other variables; confirm that callers only need this as a
// rational projection.
unsigned IntegerRelation::gaussianEliminateVars(unsigned posStart,
unsigned posLimit) {
assert(posLimit <= getNumVars());
assert(hasConsistentState());
if (posStart >= posLimit)
return 0;
gcdTightenInequalities();
unsigned pivotCol = 0;
for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) {
unsigned pivotRow;
if (!findConstraintWithNonZeroAt(pivotCol, true, &pivotRow)) {
// No equality mentions this column.
if (!findConstraintWithNonZeroAt(pivotCol, false, &pivotRow)) {
// Unconstrained variable: drop it along with the others.
continue;
}
// Only inequalities mention it: stop here.
break;
}
// Columns [posStart, pivotCol) are zero in every constraint at this
// point, so skipping them inside eliminateFromConstraint is safe.
for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
true);
equalities.normalizeRow(i);
}
for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
false);
inequalities.normalizeRow(i);
}
removeEquality(pivotRow);
gcdTightenInequalities();
}
// Remove every column handled above.
posLimit = pivotCol;
removeVarRange(posStart, posLimit);
return posLimit - posStart;
}
/// Brings the equalities into row echelon form over the variable columns.
/// Each pivot column is eliminated from all lower equalities and from every
/// inequality. Equalities left with only zero variable coefficients are then
/// handled as follows:
/// - if one has a non-zero constant, the relation is replaced by the empty
///   relation;
/// - otherwise (they all read 0 == 0) they are removed.
///
/// Returns true if an equality was removed or the relation was found empty,
/// and false otherwise.
bool IntegerRelation::gaussianEliminate() {
  gcdTightenInequalities();
  const unsigned vars = getNumVars();
  const unsigned eqs = getNumEqualities();
  unsigned firstVar = 0;
  unsigned nowDone = 0;
  for (; nowDone < eqs; ++nowDone) {
    // Find the leftmost column, at or after `firstVar`, that is non-zero in
    // one of the rows not yet in echelon form, i.e. [nowDone, eqs). Rows
    // above `nowDone` already own a pivot and must not be picked again;
    // otherwise the same pivot is reused forever, later columns are never
    // eliminated, and redundant or contradictory rows go undetected.
    unsigned pivotRow = eqs;
    for (; firstVar < vars; ++firstVar) {
      for (unsigned r = nowDone; r < eqs; ++r) {
        if (atEq(r, firstVar) != 0) {
          pivotRow = r;
          break;
        }
      }
      if (pivotRow != eqs)
        break;
    }
    // Every remaining row is zero on all variable columns.
    if (firstVar >= vars)
      break;

    if (pivotRow != nowDone) {
      equalities.swapRows(pivotRow, nowDone);
      pivotRow = nowDone;
    }

    // Pass `firstVar` as `elimColStart` so that no column is skipped.
    // Columns before the pivot are zero in the remaining equalities, but they
    // can be non-zero in inequalities. They must be scaled together with the
    // rest of the row, otherwise the inequality would silently change.
    for (unsigned i = nowDone + 1; i < eqs; ++i) {
      eliminateFromConstraint(this, i, pivotRow, firstVar,
                              /*elimColStart=*/firstVar, /*isEq=*/true);
      equalities.normalizeRow(i);
    }
    for (unsigned i = 0, ineqs = getNumInequalities(); i < ineqs; ++i) {
      eliminateFromConstraint(this, i, pivotRow, firstVar,
                              /*elimColStart=*/firstVar, /*isEq=*/false);
      inequalities.normalizeRow(i);
    }
    gcdTightenInequalities();

    // The pivot column is now zero below the pivot; move right.
    ++firstVar;
  }

  // Every equality owns a pivot: nothing is redundant.
  if (nowDone == eqs)
    return false;

  // Rows [nowDone, eqs) have only zero variable coefficients; column `vars`
  // is the constant column. A non-zero constant means the relation is empty.
  for (unsigned i = nowDone; i < eqs; ++i) {
    if (atEq(i, vars) != 0) {
      *this = getEmpty(getSpace());
      return true;
    }
  }
  // The remaining rows read 0 == 0; drop them.
  removeEqualityRange(nowDone, eqs);
  return true;
}
// Removes every inequality implied by the other constraints. Redundancy of
// inequality r (expr >= 0) is tested by replacing it with its strict
// integer negation (-expr - 1 >= 0, i.e. expr <= -1). If the resulting
// system is empty, r is implied by the rest.
void IntegerRelation::removeRedundantInequalities() {
SmallVector<bool, 32> redun(getNumInequalities(), false);
IntegerRelation tmpCst(*this);
for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
// Turn row r into -expr - 1 >= 0.
tmpCst.inequalities.negateRow(r);
--tmpCst.atIneq(r, tmpCst.getNumCols() - 1);
if (tmpCst.isEmpty()) {
redun[r] = true;
// Zero the row (0 >= 0 always holds) so that later tests cannot rely on
// an already-removed constraint; otherwise two identical inequalities
// would each be deemed redundant and both be dropped.
inequalities.fillRow(r, 0);
tmpCst.inequalities.fillRow(r, 0);
} else {
// Not redundant: restore the original row.
++tmpCst.atIneq(r, tmpCst.getNumCols() - 1);
tmpCst.inequalities.negateRow(r);
}
}
// Compact the surviving rows to the top, preserving their order.
unsigned pos = 0;
for (unsigned r = 0, e = getNumInequalities(); r < e; ++r) {
if (!redun[r])
inequalities.copyRow(r, pos++);
}
inequalities.resizeVertically(pos);
}
/// Removes constraints that Simplex::detectRedundant marks as redundant.
/// The Simplex indexes the inequalities first, followed by two entries per
/// equality. An equality is dropped only when both of its entries are
/// redundant.
void IntegerRelation::removeRedundantConstraints() {
  gcdTightenInequalities();
  Simplex simplex(*this);
  simplex.detectRedundant();

  const unsigned numIneqs = getNumInequalities();
  unsigned kept = 0;
  for (unsigned r = 0; r < numIneqs; ++r)
    if (!simplex.isMarkedRedundant(r))
      inequalities.copyRow(r, kept++);
  inequalities.resizeVertically(kept);

  kept = 0;
  for (unsigned r = 0, numEqs = getNumEqualities(); r < numEqs; ++r) {
    const unsigned firstHalf = numIneqs + 2 * r;
    const bool bothHalvesRedundant = simplex.isMarkedRedundant(firstHalf) &&
                                     simplex.isMarkedRedundant(firstHalf + 1);
    if (!bothHalvesRedundant)
      equalities.copyRow(r, kept++);
  }
  equalities.resizeVertically(kept);
}
// Returns an overapproximation of the number of integer points, or
// std::nullopt if it is unbounded. For each non-local variable it computes
// the integer min and max, then multiplies the (max - min + 1) ranges, i.e.
// it counts the points of the axis-parallel bounding box of the non-local
// variables.
// - Returns 0 if the relation is rationally empty.
// - Returns 0 if some variable admits no integer value (min > max), even
//   when another variable is unbounded.
// - Otherwise returns nullopt if any variable is unbounded.
// Symbol variables are not supported.
std::optional<DynamicAPInt> IntegerRelation::computeVolume() const {
assert(getNumSymbolVars() == 0 && "Symbols are not yet supported!");
Simplex simplex(*this);
if (simplex.isEmpty())
return DynamicAPInt(0);
DynamicAPInt count(1);
// Objective vector selecting one variable at a time (last entry is the
// constant term).
SmallVector<DynamicAPInt, 8> dim(getNumVars() + 1);
bool hasUnboundedVar = false;
for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; ++i) {
dim[i] = 1;
auto [min, max] = simplex.computeIntegerBounds(dim);
dim[i] = 0;
assert((!min.isEmpty() && !max.isEmpty()) &&
"Polytope should be rationally non-empty!");
// Remember unboundedness but keep scanning: a later variable with no
// integer values still makes the answer 0.
if (min.isUnbounded() || max.isUnbounded()) {
hasUnboundedVar = true;
continue;
}
if (min.getBoundedOptimum() > max.getBoundedOptimum())
return DynamicAPInt(0);
count *= (*max - *min + 1);
}
if (count == 0)
return DynamicAPInt(0);
if (hasUnboundedVar)
return {};
return count;
}
/// Merges local `posB` into local `posA` (both positions are relative to
/// the locals): B's coefficients are added to A's columns, then B is
/// removed. Callers use this when local B duplicates local A.
void IntegerRelation::eliminateRedundantLocalVar(unsigned posA, unsigned posB) {
  assert(posA < getNumLocalVars() && "Invalid local var position");
  assert(posB < getNumLocalVars() && "Invalid local var position");
  const unsigned localOffset = getVarKindOffset(VarKind::Local);
  const unsigned colA = localOffset + posA;
  const unsigned colB = localOffset + posB;
  inequalities.addToColumn(colB, colA, 1);
  equalities.addToColumn(colB, colA, 1);
  removeVar(colB);
}
// Merges the symbol variables of this relation and `other` by identifier.
// Afterwards both relations have the same symbol identifiers in the same
// order:
//  - For each symbol i of this relation, its identifier is searched among
//    other's symbols at positions >= i. If found, it is swapped into
//    position i of `other`; otherwise a new symbol with that identifier is
//    inserted into `other` at position i.
//  - The symbols of `other` beyond this relation's count are then inserted
//    into this relation at the same positions.
// Both relations must be using identifiers.
void IntegerRelation::mergeAndAlignSymbols(IntegerRelation &other) {
assert(space.isUsingIds() && other.space.isUsingIds() &&
"both relations need to have identifers to merge and align");
unsigned i = 0;
for (const Identifier identifier : space.getIds(VarKind::Symbol)) {
// Positions [0, i) of `other` are already aligned; search only the rest.
// The range is re-read each iteration because `other` is modified below.
const Identifier *findBegin =
other.space.getIds(VarKind::Symbol).begin() + i;
const Identifier *findEnd = other.space.getIds(VarKind::Symbol).end();
const Identifier *itr = std::find(findBegin, findEnd, identifier);
if (itr != findEnd) {
other.swapVar(other.getVarKindOffset(VarKind::Symbol) + i,
other.getVarKindOffset(VarKind::Symbol) + i +
std::distance(findBegin, itr));
} else {
other.insertVar(VarKind::Symbol, i);
other.space.setId(VarKind::Symbol, i, identifier);
}
++i;
}
// Symbols only present in `other` are appended to this relation.
for (unsigned e = other.getNumVarKind(VarKind::Symbol); i < e; ++i) {
insertVar(VarKind::Symbol, i);
space.setId(VarKind::Symbol, i, other.space.getId(VarKind::Symbol, i));
}
}
/// Merges the local variables of this relation and `other` via
/// presburger::mergeLocalVars. Duplicate locals are folded together, except
/// that two locals both originating from this relation are never merged.
/// Returns how many locals this relation gained.
unsigned IntegerRelation::mergeLocalVars(IntegerRelation &other) {
  IntegerRelation &relA = *this;
  IntegerRelation &relB = other;
  const unsigned oldALocals = relA.getNumLocalVars();

  auto merge = [&relA, &relB, oldALocals](unsigned i, unsigned j) -> bool {
    // Only fold a later local into an earlier one, and leave duplicates
    // among relA's original locals alone.
    if (i >= j || j < oldALocals)
      return false;
    relA.eliminateRedundantLocalVar(i, j);
    relB.eliminateRedundantLocalVar(i, j);
    return true;
  };

  presburger::mergeLocalVars(*this, other, merge);
  return relA.getNumLocalVars() - oldALocals;
}
bool IntegerRelation::hasOnlyDivLocals() const {
return getLocalReprs().hasAllReprs();
}
/// Removes locals whose division representation duplicates an earlier one.
/// Each duplicate pair (i, j) reported by DivisionRepr is folded into i.
void IntegerRelation::removeDuplicateDivs() {
  DivisionRepr divs = getLocalReprs();
  auto foldIntoFirst = [this](unsigned i, unsigned j) -> bool {
    eliminateRedundantLocalVar(i, j);
    return true;
  };
  divs.removeDuplicateDivs(foldIntoFirst);
}
/// Repeatedly normalizes, Gaussian-eliminates and de-duplicates the
/// constraints until no step changes anything, or until the relation is
/// obviously empty.
void IntegerRelation::simplify() {
  for (bool changed = true; changed;) {
    if (isObviouslyEmpty())
      return;
    normalizeConstraintsByGCD();
    // Both steps always run; either one reporting a change triggers
    // another round.
    changed = gaussianEliminate();
    changed |= removeDuplicateConstraints();
  }
}
// Eliminates local variables that an equality defines exactly. Repeatedly:
//  1. Find an equality i with a +/-1 coefficient on some local j.
//  2. Use it to eliminate local j from every other constraint.
//  3. Remove both local j and equality i.
// Stops when no equality has a unit coefficient on a local.
void IntegerRelation::removeRedundantLocalVars() {
for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
equalities.normalizeRow(i);
while (true) {
// On exit from the search, (i, j) identifies the equality and the local
// column to pivot on; i == e means nothing was found.
unsigned i, e, j, f;
for (i = 0, e = getNumEqualities(); i < e; ++i) {
for (j = getNumDimAndSymbolVars(), f = getNumVars(); j < f; ++j)
if (abs(atEq(i, j)) == 1)
break;
if (j < f)
break;
}
if (i == e)
break;
// A unit pivot makes the row multiplier 1, so no other row is rescaled.
// elimColStart == pivotCol means no column is skipped.
for (unsigned k = 0, t = getNumEqualities(); k < t; ++k) {
if (atEq(k, j) != 0) {
eliminateFromConstraint(this, k, i, j, j, true);
equalities.normalizeRow(k);
}
}
for (unsigned k = 0, t = getNumInequalities(); k < t; ++k)
eliminateFromConstraint(this, k, i, j, j, false);
removeVar(j);
removeEquality(i);
}
}
/// Converts the variables [varStart, varLimit) of kind `srcKind` into
/// variables of kind `dstKind`, placed at relative position `pos` among the
/// `dstKind` variables. The matching columns of both constraint matrices are
/// moved first, then the space is updated.
void IntegerRelation::convertVarKind(VarKind srcKind, unsigned varStart,
                                     unsigned varLimit, VarKind dstKind,
                                     unsigned pos) {
  assert(varLimit <= getNumVarKind(srcKind) && "invalid id range");

  if (varStart >= varLimit)
    return;

  unsigned srcOffset = getVarKindOffset(srcKind);
  unsigned dstOffset = getVarKindOffset(dstKind);
  unsigned convertCount = varLimit - varStart;

  // Destination column of the moved block. When the destination kind lies
  // after the source kind, `dstOffset + pos` is shifted back by
  // `convertCount`. This presumably matches Matrix::moveColumns taking the
  // destination after the block has been taken out -- TODO confirm.
  // The shift is done in unsigned arithmetic rather than by negating
  // `convertCount` and narrowing to int, which is implementation-defined
  // before C++20; the resulting column index is the same.
  unsigned dstCol = dstOffset + pos;
  if (dstOffset > srcOffset)
    dstCol -= convertCount;

  equalities.moveColumns(srcOffset + varStart, convertCount, dstCol);
  inequalities.moveColumns(srcOffset + varStart, convertCount, dstCol);

  space.convertVarKind(srcKind, varStart, varLimit - varStart, dstKind, pos);
}
/// Adds a single-variable bound on column `pos`:
/// - EQ: x_pos - value == 0
/// - LB: x_pos - value >= 0
/// - UB: -x_pos + value >= 0
void IntegerRelation::addBound(BoundType type, unsigned pos,
                               const DynamicAPInt &value) {
  assert(pos < getNumCols());
  const unsigned constCol = getNumCols() - 1;
  if (type == BoundType::EQ) {
    const unsigned row = equalities.appendExtraRow();
    equalities(row, pos) = 1;
    equalities(row, constCol) = -value;
    return;
  }
  const bool isLower = type == BoundType::LB;
  const unsigned row = inequalities.appendExtraRow();
  inequalities(row, pos) = isLower ? 1 : -1;
  inequalities(row, constCol) = isLower ? -value : value;
}
/// Adds a bound on an affine expression. `expr` has one entry per column,
/// with the constant last.
/// - LB: expr - value >= 0
/// - UB: -expr + value >= 0
/// Equality bounds are not implemented here.
void IntegerRelation::addBound(BoundType type, ArrayRef<DynamicAPInt> expr,
                               const DynamicAPInt &value) {
  assert(type != BoundType::EQ && "EQ not implemented");
  assert(expr.size() == getNumCols());
  const bool isLower = type == BoundType::LB;
  const unsigned row = inequalities.appendExtraRow();
  for (unsigned col = 0, e = expr.size(); col < e; ++col)
    inequalities(row, col) = isLower ? expr[col] : -expr[col];
  inequalities(row, getNumCols() - 1) += isLower ? -value : value;
}
/// Appends a new local variable q = floor(dividend / divisor) and the two
/// inequalities (from getDivLowerBound / getDivUpperBound) that define it.
/// `dividend` is expressed over the columns that existed before the call.
void IntegerRelation::addLocalFloorDiv(ArrayRef<DynamicAPInt> dividend,
                                       const DynamicAPInt &divisor) {
  assert(dividend.size() == getNumCols() && "incorrect dividend size");
  assert(divisor > 0 && "positive divisor expected");

  appendVar(VarKind::Local);

  // Give the dividend a zero coefficient for the new local, which sits just
  // before the constant column.
  SmallVector<DynamicAPInt, 8> paddedDividend(dividend.begin(), dividend.end());
  paddedDividend.insert(paddedDividend.end() - 1, DynamicAPInt(0));
  const unsigned newLocalPos = paddedDividend.size() - 2;
  addInequality(getDivLowerBound(paddedDividend, divisor, newLocalPos));
  addInequality(getDivUpperBound(paddedDividend, divisor, newLocalPos));
}
static int findEqualityToConstant(const IntegerRelation &cst, unsigned pos,
bool symbolic = false) {
assert(pos < cst.getNumVars() && "invalid position");
for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
DynamicAPInt v = cst.atEq(r, pos);
if (v * v != 1)
continue;
unsigned c;
unsigned f = symbolic ? cst.getNumDimVars() : cst.getNumVars();
for (c = 0; c < f; c++) {
if (c == pos)
continue;
if (cst.atEq(r, c) != 0) {
break;
}
}
if (c == f)
return r;
}
return -1;
}
/// Looks for an equality that pins the variable at `pos` to a constant. If
/// one exists, the constant is substituted for the variable via
/// setAndEliminate. Fails without modifying anything otherwise.
LogicalResult IntegerRelation::constantFoldVar(unsigned pos) {
  assert(pos < getNumVars() && "invalid position");

  const int rowIdx = findEqualityToConstant(*this, pos);
  if (rowIdx < 0)
    return failure();

  // The coefficient at `pos` is +1 or -1, so the division is exact.
  const DynamicAPInt &coeff = atEq(rowIdx, pos);
  assert(coeff * coeff == 1);
  DynamicAPInt constVal = -atEq(rowIdx, getNumCols() - 1) / coeff;
  setAndEliminate(pos, constVal);
  return success();
}
/// Tries to constant-fold each of the `num` variables starting at `pos`.
/// A successful fold removes the variable, so the next candidate slides into
/// the same column. On failure the cursor moves past the variable.
void IntegerRelation::constantFoldVarRange(unsigned pos, unsigned num) {
  unsigned cursor = pos;
  for (unsigned attempt = 0; attempt < num; ++attempt) {
    const bool folded = !constantFoldVar(cursor).failed();
    if (!folded)
      ++cursor;
  }
}
/// Returns a constant upper bound on the number of distinct integer values
/// the dimension variable at `pos` can take, or std::nullopt if none is found.
///
/// Two cases:
///  - Equality case: an equality fixes the variable as a function of the
///    symbols and the constant, with no local variable involved. The result
///    is 1.
///  - Inequality case: the result is the minimum of
///    max(0, ceildiv(ubConst + lbConst + 1, lbCoeff)). The minimum is taken
///    over pairs of (upper, lower) bound inequalities whose non-constant
///    coefficients are exact negations of each other.
///
/// Optional outputs:
///  - `lb`, `ub` and `*boundFloorDivisor` are written only when a bound is
///    found and `lb` is non-null.
///      - Each bound is given as numSymbols + 1 entries (symbol coefficients,
///        then constant).
///      - `lb` is in floor-division form with divisor `*boundFloorDivisor`.
///      - `boundFloorDivisor` must be non-null whenever `lb` is.
///  - `minLbPos` / `minUbPos`, when non-null, receive the row indices of the
///    constraints used.
///      - In the equality case this is the equality row.
///      - In the inequality case they are written (possibly as 0) even when
///        no matching pair is found.
///
/// NOTE(review): coefficients are read from columns getNumDimVars() + c for
/// c in [0, numSymbols]. The last entry is therefore the constant column only
/// when there are no local variables; with locals it reads the first local's
/// coefficient. Confirm callers never pass relations with locals here.
std::optional<DynamicAPInt> IntegerRelation::getConstantBoundOnDimSize(
unsigned pos, SmallVectorImpl<DynamicAPInt> *lb,
DynamicAPInt *boundFloorDivisor, SmallVectorImpl<DynamicAPInt> *ub,
unsigned *minLbPos, unsigned *minUbPos) const {
assert(pos < getNumDimVars() && "Invalid variable position");
// Look for an equality fixing the variable in terms of symbols/constant only
// (other dimension variables must be absent).
int eqPos = findEqualityToConstant(*this, pos, true);
if (eqPos != -1) {
auto eq = getEquality(eqPos);
// Equalities that involve local variables are not handled.
if (!std::all_of(eq.begin() + getNumDimAndSymbolVars(), eq.end() - 1,
[](const DynamicAPInt &coeff) { return coeff == 0; }))
return std::nullopt;
if (lb) {
lb->resize(getNumSymbolVars() + 1);
if (ub)
ub->resize(getNumSymbolVars() + 1);
for (unsigned c = 0, f = getNumSymbolVars() + 1; c < f; c++) {
// The coefficient at `pos` is +1 or -1, so these divisions are exact.
DynamicAPInt v = atEq(eqPos, pos);
assert(v * v == 1);
(*lb)[c] = v < 0 ? atEq(eqPos, getNumDimVars() + c) / -v
: -atEq(eqPos, getNumDimVars() + c) / v;
// A single value: the upper bound equals the lower bound.
if (ub)
(*ub)[c] = (*lb)[c];
}
assert(boundFloorDivisor &&
"both lb and divisor or none should be provided");
*boundFloorDivisor = 1;
}
if (minLbPos)
*minLbPos = eqPos;
if (minUbPos)
*minUbPos = eqPos;
return DynamicAPInt(1);
}
// If no inequality mentions the variable, there is no bound.
unsigned r, e;
for (r = 0, e = getNumInequalities(); r < e; r++) {
if (atIneq(r, pos) != 0)
break;
}
if (r == e)
return std::nullopt;
// Collect the lower/upper bound rows for `pos`. With offset 0 and
// num = getNumDimVars(), this is presumably restricted to rows free of the
// other dimension variables; see getLowerAndUpperBoundIndices.
SmallVector<unsigned, 4> lbIndices, ubIndices;
getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices,
nullptr, 0,
getNumDimVars());
std::optional<DynamicAPInt> minDiff;
unsigned minLbPosition = 0, minUbPosition = 0;
for (auto ubPos : ubIndices) {
for (auto lbPos : lbIndices) {
// Only pairs whose non-constant parts are exact negations differ by a
// constant; skip all others.
unsigned j, e;
for (j = 0, e = getNumCols() - 1; j < e; j++)
if (atIneq(ubPos, j) != -atIneq(lbPos, j)) {
break;
}
if (j < getNumCols() - 1)
continue;
DynamicAPInt diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) +
atIneq(lbPos, getNumCols() - 1) + 1,
atIneq(lbPos, pos));
// The count of values is never negative.
diff = std::max<DynamicAPInt>(diff, DynamicAPInt(0));
// Keep the tightest pair.
if (minDiff == std::nullopt || diff < minDiff) {
minDiff = diff;
minLbPosition = lbPos;
minUbPosition = ubPos;
}
}
}
if (lb && minDiff) {
lb->resize(getNumSymbolVars() + 1);
if (ub)
ub->resize(getNumSymbolVars() + 1);
// The lower-bound row is `div * x + rest >= 0`, and the matching upper
// bound has coefficient `-div`.
*boundFloorDivisor = atIneq(minLbPosition, pos);
assert(*boundFloorDivisor == -atIneq(minUbPosition, pos));
for (unsigned c = 0, e = getNumSymbolVars() + 1; c < e; c++) {
(*lb)[c] = -atIneq(minLbPosition, getNumDimVars() + c);
}
if (ub) {
for (unsigned c = 0, e = getNumSymbolVars() + 1; c < e; c++)
(*ub)[c] = atIneq(minUbPosition, getNumDimVars() + c);
}
// Express the ceildiv lower bound in floordiv form:
// ceil(v / d) == floor((v + d - 1) / d).
(*lb)[getNumSymbolVars()] += atIneq(minLbPosition, pos) - 1;
}
if (minLbPos)
*minLbPos = minLbPosition;
if (minUbPos)
*minUbPos = minUbPosition;
return minDiff;
}
template <bool isLower>
std::optional<DynamicAPInt>
IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) {
assert(pos < getNumVars() && "invalid position");
projectOut(0, pos);
projectOut(1, getNumVars() - 1);
int eqRowIdx = findEqualityToConstant(*this, 0, false);
if (eqRowIdx != -1)
return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, 0);
unsigned r, e;
for (r = 0, e = getNumInequalities(); r < e; r++) {
if (atIneq(r, 0) != 0)
break;
}
if (r == e)
return std::nullopt;
std::optional<DynamicAPInt> minOrMaxConst;
for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
if (isLower) {
if (atIneq(r, 0) <= 0)
continue;
} else if (atIneq(r, 0) >= 0) {
continue;
}
unsigned c, f;
for (c = 0, f = getNumCols() - 1; c < f; c++)
if (c != 0 && atIneq(r, c) != 0)
break;
if (c < getNumCols() - 1)
continue;
DynamicAPInt boundConst =
isLower ? ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0))
: floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0));
if (isLower) {
if (minOrMaxConst == std::nullopt || boundConst > minOrMaxConst)
minOrMaxConst = boundConst;
} else {
if (minOrMaxConst == std::nullopt || boundConst < minOrMaxConst)
minOrMaxConst = boundConst;
}
}
return minOrMaxConst;
}
/// Returns the best constant bound of kind `type` on the variable at `pos`,
/// or std::nullopt if there is none. For EQ, a value is returned only when
/// the constant lower and upper bounds coincide. `*this` is left untouched;
/// the projections are done on copies.
std::optional<DynamicAPInt>
IntegerRelation::getConstantBound(BoundType type, unsigned pos) const {
  switch (type) {
  case BoundType::LB:
    return IntegerRelation(*this)
        .computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
  case BoundType::UB:
    return IntegerRelation(*this)
        .computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
  default:
    break;
  }

  assert(type == BoundType::EQ && "expected EQ");
  std::optional<DynamicAPInt> lower =
      IntegerRelation(*this)
          .computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
  std::optional<DynamicAPInt> upper =
      IntegerRelation(*this)
          .computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
  if (!lower || !upper || *lower != *upper)
    return std::nullopt;
  return upper;
}
/// Returns true if no single constraint (equality or inequality) has more
/// than one nonzero coefficient among the `num` columns starting at `pos`.
bool IntegerRelation::isHyperRectangular(unsigned pos, unsigned num) const {
  assert(pos < getNumCols() - 1);
  const unsigned end = pos + num;

  // Inequalities: at most one nonzero coefficient in [pos, end).
  for (unsigned row = 0, numRows = getNumInequalities(); row < numRows;
       ++row) {
    bool seenNonZero = false;
    for (unsigned col = pos; col < end; ++col) {
      if (atIneq(row, col) == 0)
        continue;
      if (seenNonZero)
        return false;
      seenNonZero = true;
    }
  }

  // Equalities: the same requirement.
  for (unsigned row = 0, numRows = getNumEqualities(); row < numRows; ++row) {
    bool seenNonZero = false;
    for (unsigned col = pos; col < end; ++col) {
      if (atEq(row, col) == 0)
        continue;
      if (seenNonZero)
        return false;
      seenNonZero = true;
    }
  }
  return true;
}
/// Applies GCD tightening and normalization, then removes inequalities that
/// are trivially redundant:
///  - constant-only rows `c >= 0` with c >= 0;
///  - exact duplicate rows;
///  - among rows with identical variable coefficients, all but the one with
///    the smallest constant term.
/// Equalities are left untouched. Surviving inequalities keep their order.
void IntegerRelation::removeTrivialRedundancy() {
gcdTightenInequalities();
normalizeConstraintsByGCD();
// Maps a row's variable part (all columns but the constant) to the
// (row index, constant term) of the tightest such row seen so far.
SmallDenseMap<ArrayRef<DynamicAPInt>, std::pair<unsigned, DynamicAPInt>>
rowsWithoutConstTerm;
// Full rows seen so far, used to detect exact duplicates. Keys are views
// into `inequalities`; rows are only rewritten after this scan is done.
SmallDenseSet<ArrayRef<DynamicAPInt>, 8> rowSet;
// True for rows of the form `<non-negative constant> >= 0`.
auto isTriviallyValid = [&](unsigned r) -> bool {
for (unsigned c = 0, e = getNumCols() - 1; c < e; c++) {
if (atIneq(r, c) != 0)
return false;
}
return atIneq(r, getNumCols() - 1) >= 0;
};
// redunIneq[r] marks inequality r for deletion.
SmallVector<bool, 256> redunIneq(getNumInequalities(), false);
for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
DynamicAPInt *rowStart = &inequalities(r, 0);
auto row = ArrayRef<DynamicAPInt>(rowStart, getNumCols());
if (isTriviallyValid(r) || !rowSet.insert(row).second) {
redunIneq[r] = true;
continue;
}
DynamicAPInt constTerm = atIneq(r, getNumCols() - 1);
auto rowWithoutConstTerm =
ArrayRef<DynamicAPInt>(rowStart, getNumCols() - 1);
const auto &ret =
rowsWithoutConstTerm.insert({rowWithoutConstTerm, {r, constTerm}});
if (!ret.second) {
// Same variable part seen before: a smaller constant is the tighter
// bound, so keep that row and mark the other one redundant.
auto &val = ret.first->second;
if (val.second > constTerm) {
redunIneq[val.first] = true;
val = {r, constTerm};
} else {
redunIneq[r] = true;
}
}
}
// Compact the surviving rows in place, preserving their order.
unsigned pos = 0;
for (unsigned r = 0, e = getNumInequalities(); r < e; r++)
if (!redunIneq[r])
inequalities.copyRow(r, pos++);
inequalities.resizeVertically(pos);
}
#undef DEBUG_TYPE
#define DEBUG_TYPE "fm"
/// Eliminates the variable at `pos` from the system.
///
/// The method used depends on how the variable appears:
///  - If some equality involves it, gaussianEliminateVar is used instead.
///  - If it appears in no constraint, its column is simply dropped.
///  - Otherwise Fourier-Motzkin elimination is applied:
///      - Every (upper bound, lower bound) pair of inequalities on it is
///        combined into one inequality without it. Both rows are first
///        scaled by the lcm of their coefficients.
///      - Inequalities not involving the variable are copied over unchanged.
///      - All equalities are copied over unchanged.
///
/// `darkShadow`: when set, each combined inequality's constant term is
/// increased by (lbCoeff - 1) * (ubCoeff - 1), i.e. the dark-shadow
/// tightening.
///
/// `isResultIntegerExact`: when non-null, set to true if every lcm was 1.
/// It is never set to false here, and is not touched on the early-return
/// paths, so callers presumably initialize it. TODO confirm.
void IntegerRelation::fourierMotzkinEliminate(unsigned pos, bool darkShadow,
bool *isResultIntegerExact) {
LLVM_DEBUG(llvm::dbgs() << "FM input (eliminate pos " << pos << "):\n");
LLVM_DEBUG(dump());
assert(pos < getNumVars() && "invalid position");
assert(hasConsistentState());
// If an equality involves the variable, eliminate it exactly by substitution.
for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
if (atEq(r, pos) != 0) {
LogicalResult ret = gaussianEliminateVar(pos);
(void)ret;
assert(ret.succeeded() && "Gaussian elimination guaranteed to succeed");
LLVM_DEBUG(llvm::dbgs() << "FM output (through Gaussian elimination):\n");
LLVM_DEBUG(dump());
return;
}
}
gcdTightenInequalities();
// The variable appears nowhere: just drop its column.
if (isColZero(pos)) {
removeVar(pos);
LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
LLVM_DEBUG(dump());
return;
}
// Classify inequality rows by the coefficient at `pos`: zero means the row
// is independent of the variable, positive means a lower bound, and
// negative means an upper bound.
SmallVector<unsigned, 4> lbIndices;
SmallVector<unsigned, 4> ubIndices;
std::vector<unsigned> nbIndices;
nbIndices.reserve(getNumInequalities());
for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
if (atIneq(r, pos) == 0) {
nbIndices.emplace_back(r);
} else if (atIneq(r, pos) >= 1) {
lbIndices.emplace_back(r);
} else {
ubIndices.emplace_back(r);
}
}
// Result space: the current space minus the eliminated variable.
PresburgerSpace newSpace = getSpace();
VarKind idKindRemove = newSpace.getVarKindAt(pos);
unsigned relativePos = pos - newSpace.getVarKindOffset(idKindRemove);
newSpace.removeVarRange(idKindRemove, relativePos, relativePos + 1);
IntegerRelation newRel(lbIndices.size() * ubIndices.size() + nbIndices.size(),
getNumEqualities(), getNumCols() - 1, newSpace);
bool allLCMsAreOne = true;
// For a lower bound `lbCoeff*x + L >= 0` and an upper bound
// `-ubCoeff*x + U >= 0`, the sum (lcm/ubCoeff)*upper + (lcm/lbCoeff)*lower
// cancels the x terms and gives an inequality without x.
for (auto ubPos : ubIndices) {
for (auto lbPos : lbIndices) {
SmallVector<DynamicAPInt, 4> ineq;
ineq.reserve(newRel.getNumCols());
DynamicAPInt lbCoeff = atIneq(lbPos, pos);
// Negated so that both coefficients are positive.
DynamicAPInt ubCoeff = -atIneq(ubPos, pos);
for (unsigned l = 0, e = getNumCols(); l < e; l++) {
if (l == pos)
continue;
assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified");
DynamicAPInt lcm = llvm::lcm(lbCoeff, ubCoeff);
ineq.emplace_back(atIneq(ubPos, l) * (lcm / ubCoeff) +
atIneq(lbPos, l) * (lcm / lbCoeff));
assert(lcm > 0 && "lcm should be positive!");
if (lcm != 1)
allLCMsAreOne = false;
}
if (darkShadow) {
// (lbCoeff - 1) * (ubCoeff - 1), added to the constant term.
ineq[ineq.size() - 1] += lbCoeff * ubCoeff - lbCoeff - ubCoeff + 1;
}
newRel.addInequality(ineq);
}
}
LLVM_DEBUG(llvm::dbgs() << "FM isResultIntegerExact: " << allLCMsAreOne
<< "\n");
if (allLCMsAreOne && isResultIntegerExact)
*isResultIntegerExact = true;
// Copy inequalities that do not involve the variable, minus its column.
for (auto nbPos : nbIndices) {
SmallVector<DynamicAPInt, 4> ineq;
ineq.reserve(getNumCols() - 1);
for (unsigned l = 0, e = getNumCols(); l < e; l++) {
if (l == pos)
continue;
ineq.emplace_back(atIneq(nbPos, l));
}
newRel.addInequality(ineq);
}
assert(newRel.getNumConstraints() ==
lbIndices.size() * ubIndices.size() + nbIndices.size());
// Copy equalities, minus the variable's column. None of them involve the
// variable, since the Gaussian path above would have returned.
for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
SmallVector<DynamicAPInt, 4> eq;
eq.reserve(newRel.getNumCols());
for (unsigned l = 0, e = getNumCols(); l < e; l++) {
if (l == pos)
continue;
eq.emplace_back(atEq(r, l));
}
newRel.addEquality(eq);
}
// Simplify the new system, then make it the contents of *this.
newRel.gcdTightenInequalities();
newRel.normalizeConstraintsByGCD();
newRel.removeTrivialRedundancy();
clearAndCopyFrom(newRel);
LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
LLVM_DEBUG(dump());
}
#undef DEBUG_TYPE
#define DEBUG_TYPE "presburger"
/// Projects out the `num` variables starting at `pos`.
///
/// Elimination runs in two phases:
///  1. Variables that can be removed through equalities are eliminated with
///     Gaussian elimination.
///  2. The remaining ones are removed one at a time with Fourier-Motzkin
///     elimination, picking the best candidate each time.
///
/// Finishes with GCD tightening and normalization.
void IntegerRelation::projectOut(unsigned pos, unsigned num) {
  if (num == 0)
    return;

  // 'pos' can be at most getNumCols() - 2 if num > 0.
  assert((getNumCols() < 2 || pos <= getNumCols() - 2) && "invalid position");
  assert(pos + num < getNumCols() && "invalid range");

  // Gaussian phase. The variables still to be examined always form the
  // window [currentPos, currentPos + numToEliminate). Variables that could
  // not be eliminated are skipped over and accumulate in [pos, currentPos).
  //
  // The loop stops as soon as the window is empty. Previously it decremented
  // `numToEliminate` past zero (unsigned wraparound) and kept calling
  // gaussianEliminateVars on empty ranges until `currentPos` reached
  // getNumVars().
  unsigned currentPos = pos;
  unsigned numToEliminate = num;
  unsigned numGaussianEliminated = 0;
  while (numToEliminate > 0) {
    unsigned curNumEliminated =
        gaussianEliminateVars(currentPos, currentPos + numToEliminate);
    numGaussianEliminated += curNumEliminated;
    numToEliminate -= curNumEliminated;
    if (numToEliminate == 0)
      break;
    // Elimination stopped short of the window end, presumably at a variable
    // with no usable equality. Skip that variable.
    ++currentPos;
    --numToEliminate;
  }

  // Fourier-Motzkin phase. The survivors occupy [pos, pos + numRemaining).
  const unsigned numRemaining = num - numGaussianEliminated;
  for (unsigned left = numRemaining; left > 0; --left)
    fourierMotzkinEliminate(getBestVarToEliminate(*this, pos, pos + left));

  // Fast/trivial simplifications. Normalize after tightening since tightening
  // can change the GCDs, but not the other way round.
  gcdTightenInequalities();
  normalizeConstraintsByGCD();
}
namespace {
/// Outcome of comparing two affine bounds.
enum BoundCmpResult { Greater, Less, Equal, Unknown };

/// Compares two affine bounds given as coefficient lists whose last entry is
/// the constant term. The bounds are comparable only when all non-constant
/// coefficients match; the constant terms then decide the ordering.
/// Otherwise the result is Unknown.
static BoundCmpResult compareBounds(ArrayRef<DynamicAPInt> a,
                                    ArrayRef<DynamicAPInt> b) {
  assert(a.size() == b.size());

  const size_t numVarCoeffs = a.size() - 1;
  for (size_t i = 0; i < numVarCoeffs; ++i)
    if (a[i] != b[i])
      return Unknown;

  const DynamicAPInt &lhsConst = a.back();
  const DynamicAPInt &rhsConst = b.back();
  if (lhsConst < rhsConst)
    return Less;
  if (rhsConst < lhsConst)
    return Greater;
  return Equal;
}
} // namespace
/// Resets `c` to an unconstrained relation in the space of `a`. It then adds
/// every constraint of `a` that also appears verbatim in `b`: matching
/// inequalities first, then matching equalities, each in `a`'s order.
static void getCommonConstraints(const IntegerRelation &a,
                                 const IntegerRelation &b, IntegerRelation &c) {
  c = IntegerRelation(a.getSpace());

  // A quadratic scan is acceptable for the small systems handled here.
  for (unsigned r = 0, e = a.getNumInequalities(); r < e; ++r) {
    bool foundInB = false;
    for (unsigned s = 0, f = b.getNumInequalities(); s < f && !foundInB; ++s)
      foundInB = a.getInequality(r) == b.getInequality(s);
    if (foundInB)
      c.addInequality(a.getInequality(r));
  }

  for (unsigned r = 0, e = a.getNumEqualities(); r < e; ++r) {
    bool foundInB = false;
    for (unsigned s = 0, f = b.getNumEqualities(); s < f && !foundInB; ++s)
      foundInB = a.getEquality(r) == b.getEquality(s);
    if (foundInB)
      c.addEquality(a.getEquality(r));
  }
}
/// Replaces `*this` with a bounding box of the union of `*this` and
/// `otherCst`. For every dimension variable the box has:
///  - a lower bound that is the minimum of the two lower bounds;
///  - an upper bound that is the maximum of the two upper bounds.
/// The constraints both systems share verbatim are added as well.
///
/// Fails, leaving `*this` unchanged, if either of these holds:
///  - some dimension lacks a usable bound in either system;
///  - the bound divisors of the two systems differ.
///
/// Both systems must be in the same space and have no local variables.
LogicalResult
IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
  assert(space.isEqual(otherCst.getSpace()) && "Spaces should match.");
  assert(getNumLocalVars() == 0 && "local ids not supported yet here");

  // Constraints common to both systems are carried over unchanged.
  IntegerRelation commonCst(PresburgerSpace::getRelationSpace());
  getCommonConstraints(*this, otherCst, commonCst);

  // One lower-bound row and one upper-bound row per dimension variable.
  std::vector<SmallVector<DynamicAPInt, 8>> boundingLbs;
  std::vector<SmallVector<DynamicAPInt, 8>> boundingUbs;
  boundingLbs.reserve(getNumDimVars());
  boundingUbs.reserve(getNumDimVars());

  // Per-dimension bounds of each system (symbol coefficients + constant).
  SmallVector<DynamicAPInt, 4> lb, otherLb, ub, otherUb;
  // Min of the lower bounds and max of the upper bounds for a dimension.
  SmallVector<DynamicAPInt, 4> minLb(getNumSymbolVars() + 1);
  SmallVector<DynamicAPInt, 4> maxUb(getNumSymbolVars() + 1);
  // Full rows of the new bounding constraints.
  SmallVector<DynamicAPInt, 8> newLb(getNumCols()), newUb(getNumCols());

  DynamicAPInt lbFloorDivisor, otherLbFloorDivisor;
  for (unsigned d = 0, e = getNumDimVars(); d < e; ++d) {
    auto extent = getConstantBoundOnDimSize(d, &lb, &lbFloorDivisor, &ub);
    if (!extent.has_value())
      return failure();

    auto otherExtent = otherCst.getConstantBoundOnDimSize(
        d, &otherLb, &otherLbFloorDivisor, &otherUb);
    if (!otherExtent.has_value() || lbFloorDivisor != otherLbFloorDivisor)
      return failure();

    assert(lbFloorDivisor > 0 && "divisor always expected to be positive");

    // The rows built below constrain `lbFloorDivisor * x_d`, so every bound
    // stored in minLb/maxUb must be expressed in those units.
    auto res = compareBounds(lb, otherLb);
    if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) {
      minLb = lb;
      // `lb` is in floordiv form, floor((v + div - 1) / div). Undo the
      // adjustment to get `div * x_d >= v`.
      minLb.back() -= lbFloorDivisor - 1;
    } else if (res == BoundCmpResult::Greater) {
      minLb = otherLb;
      minLb.back() -= otherLbFloorDivisor - 1;
    } else {
      // Not comparable: fall back to constant bounds. These bound x_d itself,
      // so scale them by the divisor. Without the scaling,
      // `div * x_d >= c` would exclude valid points whenever c < 0.
      auto constLb = getConstantBound(BoundType::LB, d);
      auto constOtherLb = otherCst.getConstantBound(BoundType::LB, d);
      if (!constLb.has_value() || !constOtherLb.has_value())
        return failure();
      std::fill(minLb.begin(), minLb.end(), 0);
      minLb.back() = std::min(*constLb, *constOtherLb) * lbFloorDivisor;
    }

    auto uRes = compareBounds(ub, otherUb);
    if (uRes == BoundCmpResult::Greater || uRes == BoundCmpResult::Equal) {
      maxUb = ub;
    } else if (uRes == BoundCmpResult::Less) {
      maxUb = otherUb;
    } else {
      // Not comparable: constant upper bounds on x_d, scaled to
      // `div * x_d <= div * c`. Without the scaling, valid points with
      // c > 0 would be excluded.
      auto constUb = getConstantBound(BoundType::UB, d);
      auto constOtherUb = otherCst.getConstantBound(BoundType::UB, d);
      if (!constUb.has_value() || !constOtherUb.has_value())
        return failure();
      std::fill(maxUb.begin(), maxUb.end(), 0);
      maxUb.back() = std::max(*constUb, *constOtherUb) * lbFloorDivisor;
    }

    std::fill(newLb.begin(), newLb.end(), 0);
    std::fill(newUb.begin(), newUb.end(), 0);

    // newLb: div * x_d - minLb >= 0;  newUb: -div * x_d + maxUb >= 0.
    newLb[d] = lbFloorDivisor;
    newUb[d] = -lbFloorDivisor;
    std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimVars());
    std::transform(newLb.begin() + getNumDimVars(), newLb.end(),
                   newLb.begin() + getNumDimVars(),
                   std::negate<DynamicAPInt>());
    std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimVars());

    boundingLbs.emplace_back(newLb);
    boundingUbs.emplace_back(newUb);
  }

  // Replace all constraints with the bounding box.
  clearConstraints();
  for (unsigned d = 0, e = getNumDimVars(); d < e; ++d) {
    addInequality(boundingLbs[d]);
    addInequality(boundingUbs[d]);
  }

  // Add the constraints the two systems had in common.
  append(commonCst);
  removeTrivialRedundancy();
  return success();
}
/// Returns true if no inequality and no equality has a nonzero coefficient
/// in column `pos`.
bool IntegerRelation::isColZero(unsigned pos) const {
  unsigned ignoredRow;
  if (findConstraintWithNonZeroAt(pos, /*isEq=*/false, &ignoredRow))
    return false;
  return !findConstraintWithNonZeroAt(pos, /*isEq=*/true, &ignoredRow);
}
/// Appends to `nbIneqIndices` / `nbEqIndices` the indices of the
/// inequalities / equalities whose coefficients are all zero on the columns
/// [pos, pos + num). Indices are appended in increasing order.
static void getIndependentConstraints(const IntegerRelation &cst, unsigned pos,
                                      unsigned num,
                                      SmallVectorImpl<unsigned> &nbIneqIndices,
                                      SmallVectorImpl<unsigned> &nbEqIndices) {
  assert(pos < cst.getNumVars() && "invalid start position");
  assert(pos + num <= cst.getNumVars() && "invalid limit");
  const unsigned end = pos + num;

  for (unsigned row = 0, numRows = cst.getNumInequalities(); row < numRows;
       ++row) {
    bool independent = true;
    for (unsigned col = pos; col < end && independent; ++col)
      independent = cst.atIneq(row, col) == 0;
    if (independent)
      nbIneqIndices.push_back(row);
  }

  for (unsigned row = 0, numRows = cst.getNumEqualities(); row < numRows;
       ++row) {
    bool independent = true;
    for (unsigned col = pos; col < end && independent; ++col)
      independent = cst.atEq(row, col) == 0;
    if (independent)
      nbEqIndices.push_back(row);
  }
}
/// Removes every inequality and equality that does not involve any of the
/// `num` variables starting at `pos`.
void IntegerRelation::removeIndependentConstraints(unsigned pos, unsigned num) {
  assert(pos + num <= getNumVars() && "invalid range");

  // The scan must use the caller's window [pos, pos + num). Starting it at
  // column 0 would silently ignore `pos`.
  SmallVector<unsigned, 4> nbIneqIndices, nbEqIndices;
  getIndependentConstraints(*this, pos, num, nbIneqIndices, nbEqIndices);

  // Indices are ascending; remove back to front so earlier ones stay valid.
  for (auto nbIndex : llvm::reverse(nbIneqIndices))
    removeInequality(nbIndex);
  for (auto nbIndex : llvm::reverse(nbEqIndices))
    removeEquality(nbIndex);
}
/// Returns the domain of this relation as a set. Range variables are turned
/// into local variables, and domain variables become the set dimensions.
IntegerPolyhedron IntegerRelation::getDomainSet() const {
  IntegerRelation domain(*this);
  const unsigned numRange = getNumVarKind(VarKind::Range);
  const unsigned numDomain = getNumVarKind(VarKind::Domain);

  domain.convertVarKind(VarKind::Range, 0, numRange, VarKind::Local);
  domain.convertVarKind(VarKind::Domain, 0, numDomain, VarKind::SetDim);
  return IntegerPolyhedron(std::move(domain));
}
bool IntegerRelation::removeDuplicateConstraints() {
bool changed = false;
SmallDenseMap<ArrayRef<DynamicAPInt>, unsigned> hashTable;
unsigned ineqs = getNumInequalities(), cols = getNumCols();
if (ineqs <= 1)
return changed;
ArrayRef<DynamicAPInt> row = getInequality(0).drop_back();
hashTable.insert({row, 0});
for (unsigned k = 1; k < ineqs; ++k) {
row = getInequality(k).drop_back();
if (!hashTable.contains(row)) {
hashTable.insert({row, k});
continue;
}
unsigned l = hashTable[row];
changed = true;
if (atIneq(k, cols - 1) <= atIneq(l, cols - 1))
inequalities.swapRows(k, l);
removeInequality(k);
--k;
--ineqs;
}
SmallVector<DynamicAPInt> negIneq(cols - 1);
for (unsigned k = 0; k < ineqs; ++k) {
row = getInequality(k).drop_back();
negIneq.assign(row.begin(), row.end());
for (DynamicAPInt &ele : negIneq)
ele = -ele;
if (!hashTable.contains(negIneq))
continue;
unsigned l = hashTable[row];
auto sum = atIneq(l, cols - 1) + atIneq(k, cols - 1);
if (sum > 0 || l == k)
continue;
changed = true;
if (k < l)
std::swap(l, k);
if (sum == 0) {
addEquality(getInequality(k));
removeInequality(k);
removeInequality(l);
} else
*this = getEmpty(getSpace());
break;
}
return changed;
}
/// Returns the range of this relation as a set. Domain variables are turned
/// into local variables.
IntegerPolyhedron IntegerRelation::getRangeSet() const {
  IntegerRelation range(*this);
  const unsigned numDomain = getNumVarKind(VarKind::Domain);

  range.convertVarKind(VarKind::Domain, 0, numDomain, VarKind::Local);
  return IntegerPolyhedron(std::move(range));
}
void IntegerRelation::intersectDomain(const IntegerPolyhedron &poly) {
assert(getDomainSet().getSpace().isCompatible(poly.getSpace()) &&
"Domain set is not compatible with poly");
IntegerRelation rel = poly;
rel.inverse();
rel.appendVar(VarKind::Range, getNumRangeVars());
mergeLocalVars(rel);
append(rel);
}
void IntegerRelation::intersectRange(const IntegerPolyhedron &poly) {
assert(getRangeSet().getSpace().isCompatible(poly.getSpace()) &&
"Range set is not compatible with poly");
IntegerRelation rel = poly;
rel.appendVar(VarKind::Domain, getNumDomainVars());
mergeLocalVars(rel);
append(rel);
}
/// Swaps the roles of the domain and range variables in place.
void IntegerRelation::inverse() {
  // Capture the range size before the domain variables are moved into it.
  const unsigned origNumRange = getNumVarKind(VarKind::Range);

  // First turn every domain variable into a range variable. The original
  // range variables then occupy the first `origNumRange` range slots; this
  // assumes the 4-argument convertVarKind appends at the end of the
  // destination kind. Then turn those back into domain variables.
  convertVarKind(VarKind::Domain, 0, getVarKindEnd(VarKind::Domain),
                 VarKind::Range);
  convertVarKind(VarKind::Range, 0, origNumRange, VarKind::Domain);
}
void IntegerRelation::compose(const IntegerRelation &rel) {
assert(getRangeSet().getSpace().isCompatible(rel.getDomainSet().getSpace()) &&
"Range of `this` should be compatible with Domain of `rel`");
IntegerRelation copyRel = rel;
unsigned numBVars = getNumRangeVars();
appendVar(VarKind::Range, copyRel.getNumRangeVars());
copyRel.convertVarKind(VarKind::Domain, 0, numBVars, VarKind::Range, 0);
intersectRange(IntegerPolyhedron(copyRel));
convertVarKind(VarKind::Range, 0, numBVars, VarKind::Local);
}
/// Maps the domain of this relation through `rel`, in place. The range is
/// left as is.
void IntegerRelation::applyDomain(const IntegerRelation &rel) {
  // Flip so the domain becomes the range, ...
  inverse();
  // ... push it through `rel`, ...
  compose(rel);
  // ... and flip back.
  inverse();
}
/// Maps the range of this relation through `rel`, in place. This is the
/// same operation as compose().
void IntegerRelation::applyRange(const IntegerRelation &rel) {
  compose(rel);
}
/// Prints the variable space, followed by the total number of constraints.
void IntegerRelation::printSpace(raw_ostream &os) const {
  space.print(os);
  const unsigned numConstraints = getNumConstraints();
  os << numConstraints << " constraints\n";
}
/// Removes every equality row that rangeIsZero reports as all-zero, i.e. a
/// row that imposes no constraint.
void IntegerRelation::removeTrivialEqualities() {
  // Walk from the last row down so removals never shift unvisited rows.
  for (unsigned i = getNumEqualities(); i > 0; --i) {
    const unsigned row = i - 1;
    if (rangeIsZero(getEquality(row)))
      removeEquality(row);
  }
}
bool IntegerRelation::isFullDim() {
if (getNumVars() == 0)
return true;
if (isEmpty())
return false;
removeTrivialEqualities();
if (getNumEqualities() > 0)
return false;
Simplex simplex(*this);
return llvm::none_of(llvm::seq<int>(getNumInequalities()), [&](int i) {
return simplex.isFlatAlong(getInequality(i));
});
}
/// Replaces `*this` with the composition that first applies `other` and then
/// `*this`. If `other` is X -> Y and `*this` is Y -> Z, the result is X -> Z,
/// with the Y variables kept as local variables.
///
/// Before composing:
///  - symbol and local variables of the two relations are merged;
///  - identifiers present on the resulting domain and range are carried over.
void IntegerRelation::mergeAndCompose(const IntegerRelation &other) {
assert(getNumDomainVars() == other.getNumRangeVars() &&
"Domain of this and range of other do not match");
IntegerRelation result = other;
const unsigned thisDomain = getNumDomainVars();
const unsigned thisRange = getNumRangeVars();
const unsigned otherDomain = other.getNumDomainVars();
const unsigned otherRange = other.getNumRangeVars();
// Temporarily pad both relations to the common shape
// [otherDomain thisDomain] -> [otherRange thisRange], so that their symbols
// and locals can be merged.
insertVar(VarKind::Domain, 0, otherDomain);
insertVar(VarKind::Range, 0, otherRange);
result.insertVar(VarKind::Domain, otherDomain, thisDomain);
result.insertVar(VarKind::Range, otherRange, thisRange);
mergeAndAlignSymbols(result);
mergeLocalVars(result);
// Reshape `result` to [otherDomain] -> [thisRange]: drop its padded
// thisDomain vars and turn its otherRange (Y) vars into locals.
result.removeVarRange(VarKind::Domain, otherDomain, otherDomain + thisDomain);
result.convertToLocal(VarKind::Range, 0, otherRange);
// Reshape `this` the same way: turn its thisDomain (Y) vars into locals and
// drop its padded otherRange vars.
// NOTE(review): the Y variables are now locals in both relations. This
// presumably relies on convertToLocal appending them after the
// already-merged locals, so that result.append(*this) lines them up
// positionally. Confirm.
convertToLocal(VarKind::Domain, otherDomain, otherDomain + thisDomain);
removeVarRange(VarKind::Range, 0, otherRange);
// Carry domain identifiers from `result` over to this space.
for (unsigned i = 0, e = result.getNumDomainVars(); i < e; ++i)
if (result.getSpace().getId(VarKind::Domain, i).hasValue())
space.setId(VarKind::Domain, i,
result.getSpace().getId(VarKind::Domain, i));
// Carry range identifiers from this space over to `result`.
for (unsigned i = 0, e = getNumRangeVars(); i < e; ++i)
if (space.getId(VarKind::Range, i).hasValue())
result.space.setId(VarKind::Range, i, space.getId(VarKind::Range, i));
// Conjoin both systems, simplify away redundant locals, and adopt the result.
result.append(*this);
result.removeRedundantLocalVars();
*this = result;
}
/// Prints the space header, then one line per equality (`... = 0`), then one
/// line per inequality (`... >= 0`). Coefficients are tab-separated.
void IntegerRelation::print(raw_ostream &os) const {
  assert(hasConsistentState());
  printSpace(os);

  const unsigned numCols = getNumCols();

  for (unsigned row = 0, numEqs = getNumEqualities(); row < numEqs; ++row) {
    os << " ";
    for (unsigned col = 0; col < numCols; ++col)
      os << atEq(row, col) << "\t";
    os << "= 0\n";
  }

  for (unsigned row = 0, numIneqs = getNumInequalities(); row < numIneqs;
       ++row) {
    os << " ";
    for (unsigned col = 0; col < numCols; ++col)
      os << atIneq(row, col) << "\t";
    os << ">= 0\n";
  }

  os << '\n';
}
void IntegerRelation::dump() const { print(llvm::errs()); }
/// Inserts `num` variables of kind `kind` at position `pos` and returns the
/// result of IntegerRelation::insertVar. A set has no domain, so domain
/// variables may only be "inserted" when `num` is zero.
unsigned IntegerPolyhedron::insertVar(VarKind kind, unsigned pos,
                                      unsigned num) {
  const bool addsNoDomainVars = kind != VarKind::Domain || num == 0;
  assert(addsNoDomainVars && "Domain has to be zero in a set");
  (void)addsNoDomainVars;
  return IntegerRelation::insertVar(kind, pos, num);
}
/// Returns the intersection of this set with `other`, as a set.
IntegerPolyhedron
IntegerPolyhedron::intersect(const IntegerPolyhedron &other) const {
  IntegerRelation combined = IntegerRelation::intersect(other);
  return IntegerPolyhedron(std::move(combined));
}
/// Returns the set difference of this set and `other`.
PresburgerSet IntegerPolyhedron::subtract(const PresburgerSet &other) const {
  auto difference = IntegerRelation::subtract(other);
  return PresburgerSet(std::move(difference));
}