#include "mlir/Analysis/Presburger/PWMAFunction.h"
#include "mlir/Analysis/Presburger/Simplex.h"
using namespace mlir;
using namespace presburger;
/// Returns the element-wise difference `vecA - vecB`.
///
/// Both inputs must have the same length (asserted).
static SmallVector<int64_t, 8> subtract(ArrayRef<int64_t> vecA,
                                        ArrayRef<int64_t> vecB) {
  assert(vecA.size() == vecB.size() &&
         "Cannot subtract vectors of differing lengths!");
  // Start from a copy of vecA and subtract vecB in place.
  SmallVector<int64_t, 8> diff(vecA.begin(), vecA.end());
  for (unsigned idx = 0, n = diff.size(); idx != n; ++idx)
    diff[idx] -= vecB[idx];
  return diff;
}
/// Returns the union of the domains of all pieces of this function, as a set
/// in this function's space. A function with no pieces has an empty domain.
PresburgerSet PWMAFunction::getDomain() const {
  PresburgerSet unionOfDomains = PresburgerSet::getEmpty(getSpace());
  for (const MultiAffineFunction &maf : pieces)
    unionOfDomains.unionInPlace(maf.getDomain());
  return unionOfDomains;
}
/// Evaluates this function at `point`, which gives values for the dimension
/// and symbol variables of the domain.
///
/// Returns an empty Optional when `point` is outside the domain. Otherwise,
/// returns the output vector, computed by multiplying the output matrix with
/// the homogeneous column [point, locals, 1].
Optional<SmallVector<int64_t, 8>>
MultiAffineFunction::valueAt(ArrayRef<int64_t> point) const {
  assert(point.size() == domainSet.getNumDimAndSymbolVars() &&
         "Point has incorrect dimensionality!");

  // A successful membership check also yields an assignment of the local
  // variables, which the output expressions may depend on.
  Optional<SmallVector<int64_t, 8>> localVals =
      getDomain().containsPointNoLocal(point);
  if (!localVals)
    return {};

  // Assemble [point..., locals..., 1]; the trailing 1 selects the constant
  // column of the output matrix.
  SmallVector<int64_t, 8> homogeneous(point.begin(), point.end());
  homogeneous.append(localVals->begin(), localVals->end());
  homogeneous.push_back(1);

  SmallVector<int64_t, 8> values = output.postMultiplyWithColumn(homogeneous);
  assert(values.size() == getNumOutputs());
  return values;
}
/// Evaluates this piecewise function at `point`.
///
/// Returns the output of the first piece whose domain contains `point`, or an
/// empty Optional if no piece contains it. addPiece asserts that piece domains
/// do not overlap, so at most one piece can match.
Optional<SmallVector<int64_t, 8>>
PWMAFunction::valueAt(ArrayRef<int64_t> point) const {
  assert(point.size() == getNumInputs() &&
         "Point has incorrect dimensionality!");
  for (const MultiAffineFunction &piece : pieces) {
    Optional<SmallVector<int64_t, 8>> candidate = piece.valueAt(point);
    if (candidate)
      return candidate;
  }
  return {};
}
void MultiAffineFunction::print(raw_ostream &os) const {
os << "Domain:";
domainSet.print(os);
os << "Output:\n";
output.print(os);
os << "\n";
}
void MultiAffineFunction::dump() const { print(llvm::errs()); }
/// Returns true if `this` and `other` have compatible domain spaces, equal
/// domains, and produce the same output at every point of that domain.
bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const {
  if (!getDomainSpace().isCompatible(other.getDomainSpace()))
    return false;
  if (!getDomain().isEqual(other.getDomain()))
    return false;
  // Domains are equal, so "equal where the domains overlap" means equal
  // everywhere.
  return isEqualWhereDomainsOverlap(other);
}
/// Inserts `num` variables of kind `kind` at position `pos` (relative to the
/// start of that kind). Inserting Domain variables is not allowed (asserted).
///
/// The output matrix gets `num` new columns at the matching absolute position
/// so that its columns stay aligned with the domain's variables. Returns the
/// value returned by the domain set's insertVar.
unsigned MultiAffineFunction::insertVar(VarKind kind, unsigned pos,
                                        unsigned num) {
  assert(kind != VarKind::Domain && "Domain has to be zero in a set");
  // The offset must be read before the domain set changes.
  output.insertColumns(domainSet.getVarKindOffset(kind) + pos, num);
  return domainSet.insertVar(kind, pos, num);
}
/// Removes the variables of kind `kind` in the range [varStart, varLimit)
/// (positions relative to the start of that kind). The corresponding columns
/// of the output matrix are removed as well.
void MultiAffineFunction::removeVarRange(VarKind kind, unsigned varStart,
                                         unsigned varLimit) {
  // Compute the absolute column range before the domain set is modified.
  unsigned firstCol = domainSet.getVarKindOffset(kind) + varStart;
  unsigned numCols = varLimit - varStart;
  output.removeColumns(firstCol, numCols);
  domainSet.removeVarRange(kind, varStart, varLimit);
}
/// Keeps only the first `newNumOutputs` output rows of this function.
/// `newNumOutputs` must not exceed the current number of rows (asserted).
void MultiAffineFunction::truncateOutput(unsigned newNumOutputs) {
  assert(newNumOutputs <= output.getNumRows());
  output.resizeVertically(newNumOutputs);
}
void PWMAFunction::truncateOutput(unsigned count) {
assert(count <= numOutputs);
for (MultiAffineFunction &piece : pieces)
piece.truncateOutput(count);
numOutputs = count;
}
/// Brings the local variables of `this` and `other` into a common layout so
/// that both domain sets and both output matrices refer to the same locals.
/// Both `this` and `other` are modified.
///
/// Columns for the other function's locals are first inserted into each
/// output matrix. presburger::mergeLocalVars then reconciles the two domain
/// sets, calling `merge(i, j)` whenever it combines local j into local i;
/// `merge` mirrors that change in both domain sets and both output matrices.
void MultiAffineFunction::mergeLocalVars(MultiAffineFunction &other) {
// In `output`, insert columns for other's locals after this's locals.
output.insertColumns(domainSet.getVarKindEnd(VarKind::Local),
other.domainSet.getNumLocalVars());
// In `other.output`, insert columns for this's locals before other's locals.
// Both outputs now have columns [this's locals, other's locals], presumably
// matching the layout presburger::mergeLocalVars gives the domain sets --
// TODO confirm against its documentation.
other.output.insertColumns(other.domainSet.getVarKindOffset(VarKind::Local),
domainSet.getNumLocalVars());
auto merge = [this, &other](unsigned i, unsigned j) -> bool {
// Fold local j into local i in both domain sets.
domainSet.eliminateRedundantLocalVar(i, j);
other.domainSet.eliminateRedundantLocalVar(i, j);
// NOTE(review): one offset, taken from this->domainSet, is used for both
// output matrices; this assumes locals start at the same column in both
// functions (i.e. compatible domain spaces).
unsigned localOffset = domainSet.getVarKindOffset(VarKind::Local);
// Add column j into column i, then drop column j, in both outputs, so the
// output expressions keep referring to the merged local.
output.addToColumn(localOffset + j, localOffset + i, 1);
output.removeColumn(localOffset + j);
other.output.addToColumn(localOffset + j, localOffset + i, 1);
other.output.removeColumn(localOffset + j);
// Presumably tells presburger::mergeLocalVars that the merge was applied.
return true;
};
presburger::mergeLocalVars(domainSet, other.domainSet, merge);
}
/// Returns true if `this` and `other` produce the same output at every point
/// lying in both of their domains. Points outside the intersection of the
/// two domains are not considered.
///
/// Returns false if the domain spaces are incompatible or the two functions
/// have a different number of outputs.
///
/// `other` is taken by value because mergeLocalVars modifies it.
bool MultiAffineFunction::isEqualWhereDomainsOverlap(
    MultiAffineFunction other) const {
  if (!getDomainSpace().isCompatible(other.getDomainSpace()))
    return false;

  // The loop below walks this function's output rows and indexes the same
  // rows of `other.output`. Treat mismatched output counts as unequal rather
  // than reading rows that `other` does not have.
  if (getNumOutputs() != other.getNumOutputs())
    return false;

  // `commonFunc` has the same output as `this`, but after merging its locals
  // match those of `other`.
  MultiAffineFunction commonFunc = *this;
  commonFunc.mergeLocalVars(other);
  // Append other's constraints, so commonFunc's domain becomes the
  // intersection of both domains.
  commonFunc.domainSet.append(other.domainSet);

  // The part of the common domain where every output of `this` (taken with
  // merged locals from commonFunc) equals the corresponding output of `other`.
  IntegerPolyhedron commonDomainMatching = commonFunc.getDomain();
  for (unsigned row = 0, e = getNumOutputs(); row < e; ++row)
    commonDomainMatching.addEquality(
        subtract(commonFunc.output.getRow(row), other.output.getRow(row)));

  // The outputs agree on the whole common domain iff the common domain is
  // contained in the matching subset.
  return commonFunc.getDomain().isSubsetOf(commonDomainMatching);
}
bool PWMAFunction::isEqual(const PWMAFunction &other) const {
if (!space.isCompatible(other.space))
return false;
if (!this->getDomain().isEqual(other.getDomain()))
return false;
for (const MultiAffineFunction &aPiece : this->pieces)
for (const MultiAffineFunction &bPiece : other.pieces)
if (!aPiece.isEqualWhereDomainsOverlap(bPiece))
return false;
return true;
}
void PWMAFunction::addPiece(const MultiAffineFunction &piece) {
assert(space.isCompatible(piece.getDomainSpace()) &&
"Piece to be added is not compatible with this PWMAFunction!");
assert(piece.isConsistent() && "Piece is internally inconsistent!");
assert(this->getDomain()
.intersect(PresburgerSet(piece.getDomain()))
.isIntegerEmpty() &&
"New piece's domain overlaps with that of existing pieces!");
pieces.push_back(piece);
}
void PWMAFunction::addPiece(const IntegerPolyhedron &domain,
const Matrix &output) {
addPiece(MultiAffineFunction(domain, output));
}
/// Adds one piece per disjunct of `domain`, all sharing the output matrix
/// `output`.
void PWMAFunction::addPiece(const PresburgerSet &domain, const Matrix &output) {
  for (const IntegerRelation &disjunct : domain.getAllDisjuncts()) {
    IntegerPolyhedron disjunctPoly(disjunct);
    addPiece(disjunctPoly, output);
  }
}
void PWMAFunction::print(raw_ostream &os) const {
os << pieces.size() << " pieces:\n";
for (const MultiAffineFunction &piece : pieces)
piece.print(os);
}
void PWMAFunction::dump() const { print(llvm::errs()); }
/// Returns a piecewise function defined on the union of the domains of
/// `this` and `func`.
///
/// - Where only `this` is defined, its output is used.
/// - Where only `func` is defined, its output is used.
/// - Where a piece `funcA` of `this` and a piece `funcB` of `func` are both
///   defined, `tiebreak(funcB, funcA)` gives the region on which funcB's
///   output is used; funcA's output is used on the rest of funcA's domain.
///
/// Both functions must have the same number of outputs and compatible spaces
/// (asserted). NOTE(review): `tiebreak` is expected to return a subset of the
/// common domain of its two arguments, and its results must not overlap other
/// added pieces, since addPiece asserts that piece domains are disjoint.
PWMAFunction PWMAFunction::unionFunction(
const PWMAFunction &func,
llvm::function_ref<PresburgerSet(MultiAffineFunction maf1,
MultiAffineFunction maf2)>
tiebreak) const {
assert(getNumOutputs() == func.getNumOutputs() &&
"Number of outputs of functions should be same.");
assert(getSpace().isCompatible(func.getSpace()) &&
"Space is not compatible.");
PWMAFunction result(getSpace(), getNumOutputs());
// Phase 1: for each piece of `this`, hand every piece of `func` the region
// that tiebreak awards it, and give funcA whatever remains of its domain.
for (const MultiAffineFunction &funcA : pieces) {
PresburgerSet dom(funcA.getDomain());
for (const MultiAffineFunction &funcB : func.pieces) {
PresburgerSet better = tiebreak(funcB, funcA);
result.addPiece(better, funcB.getOutputMatrix());
// The region awarded to funcB is no longer available to funcA.
dom = dom.subtract(better);
}
result.addPiece(dom, funcA.getOutputMatrix());
}
// Phase 2: parts of `func`'s domain not covered by `this` use func's output.
PresburgerSet dom = getDomain();
for (const MultiAffineFunction &funcB : func.pieces)
result.addPiece(funcB.getDomain().subtract(dom), funcB.getOutputMatrix());
return result;
}
/// Tiebreak that compares the outputs of two functions lexicographically.
///
/// Returns the subset of the intersection of the domains of `mafA` and `mafB`
/// on which mafA's output is strictly lexicographically smaller than mafB's
/// (when `lexMin` is true), or strictly larger (when `lexMin` is false).
/// Points where the outputs are equal are excluded.
///
/// The set is built as the union, over each output level k, of
///   outA[0] == outB[0], ..., outA[k-1] == outB[k-1], outA[k] <op> outB[k]
/// where <op> is `<` for lexMin and `>` for lexMax.
///
/// Local variables in mafA's domain are not supported (asserted).
/// NOTE(review): only mafA's locals are checked; mafB is presumably also
/// expected to be local-free.
template <bool lexMin>
static PresburgerSet tiebreakLex(const MultiAffineFunction &mafA,
const MultiAffineFunction &mafB) {
assert(mafA.getDomainSpace().isCompatible(mafB.getDomainSpace()) &&
"Domain spaces should be compatible.");
assert(mafA.getNumOutputs() == mafB.getNumOutputs() &&
"Number of outputs of both functions should be same.");
assert(mafA.getDomain().getNumLocalVars() == 0 &&
"Local variables are not supported yet.");
PresburgerSpace compatibleSpace = mafA.getDomain().getSpaceWithoutLocals();
const PresburgerSpace &space = mafA.getDomain().getSpace();
PresburgerSet result = PresburgerSet::getEmpty(compatibleSpace);
// Scratch polyhedron holding the constraints for the current level. The
// leading arguments are presumably reservation hints: 1 inequality, one
// equality per output, and one column per variable plus the constant.
IntegerPolyhedron levelSet(1,
mafA.getNumOutputs(),
space.getNumVars() + 1, space);
for (unsigned level = 0; level < mafA.getNumOutputs(); ++level) {
// Affine expression outA[level] - outB[level].
SmallVector<int64_t, 8> subExpr =
subtract(mafA.getOutputExpr(level), mafB.getOutputExpr(level));
if (lexMin) {
// outA - outB <= -1, i.e. outA < outB.
levelSet.addBound(IntegerPolyhedron::BoundType::UB, subExpr, -1);
} else {
// outA - outB >= 1, i.e. outA > outB.
levelSet.addBound(IntegerPolyhedron::BoundType::LB, subExpr, 1);
}
result.unionInPlace(levelSet);
// levelSet holds exactly one inequality at this point (the one added in this
// iteration), so index 0 removes it.
levelSet.removeInequality(0);
// Require equality at this level before moving on to the next one.
levelSet.addEquality(subExpr);
}
// Restrict to points where both functions are defined.
result = result.intersect(PresburgerSet(mafA.getDomain()))
.intersect(PresburgerSet(mafB.getDomain()));
return result;
}
/// Union of `this` and `func`. Where both are defined, `func`'s output is used
/// only where it is strictly lexicographically smaller; otherwise the output
/// of `this` is used.
PWMAFunction PWMAFunction::unionLexMin(const PWMAFunction &func) {
  return unionFunction(func, tiebreakLex</*lexMin=*/true>);
}
/// Union of `this` and `func`. Where both are defined, `func`'s output is used
/// only where it is strictly lexicographically larger; otherwise the output
/// of `this` is used.
PWMAFunction PWMAFunction::unionLexMax(const PWMAFunction &func) {
  return unionFunction(func, tiebreakLex</*lexMin=*/false>);
}