#include "CodegenEnv.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include <optional>
using namespace mlir;
using namespace mlir::sparse_tensor;
/// Returns true when `val` is produced directly by a `tensor::EmptyOp` or a
/// `bufferization::AllocTensorOp`; returns false otherwise (including when
/// `val` has no defining op).
static bool isMaterializing(Value val) {
  if (val.getDefiningOp<tensor::EmptyOp>())
    return true;
  return static_cast<bool>(val.getDefiningOp<bufferization::AllocTensorOp>());
}
/// Sorts `target` in place in ascending order of each pair's `first`
/// component (presumably the loop id — confirm against LoopCoeffPair).
/// The comparator asserts that two distinct objects never hold equal pairs,
/// i.e. the list is expected to be free of duplicates.
static void sortDependentLoops(std::vector<LoopCoeffPair> &target) {
  auto byFirst = [](const LoopCoeffPair &lhs, const LoopCoeffPair &rhs) {
    // Comparing an element with itself is allowed; otherwise values differ.
    assert(std::addressof(lhs) == std::addressof(rhs) || lhs != rhs);
    return lhs.first < rhs.first;
  };
  std::sort(target.begin(), target.end(), byFirst);
}
/// Constructs the code generation environment for the generic op `linop`
/// with sparsification options `opts`.
///
/// The lattice merger is sized with `numTensors`, `numLoops`, and `maxRank`.
/// All emission state starts empty: there is no sparse output yet, no
/// insertion chain, no expansion values, and no active reduction
/// (`redExp` and `redCustom` hold `detail::kInvalidId`). `outerParNest`
/// starts at -1u; it is assigned a real value in isAdmissibleTensorExp()
/// when a sparse output is recorded.
CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
                       unsigned numTensors, unsigned numLoops, unsigned maxRank)
    : linalgOp(linop), sparseOptions(opts),
      latticeMerger(numTensors, numLoops, maxRank), loopEmitter(),
      sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(),
      expFilled(), expAdded(), expCount(), redVal(), redExp(detail::kInvalidId),
      redCustom(detail::kInvalidId), redValidLexInsert() {}
/// Builds the tensor expression for the wrapped linalg op via the merger and
/// stores it in `tensorExp`.
///
/// Returns failure when the merger cannot build an expression, or when the
/// built expression is rejected by isAdmissibleTensorExp(). In both failure
/// cases `tensorExp` is left untouched.
LogicalResult CodegenEnv::initTensorExp() {
  const std::optional<ExprId> built =
      latticeMerger.buildTensorExpFromLinalg(op());
  if (!built.has_value())
    return failure();
  const ExprId root = built.value();
  if (!isAdmissibleTensorExp(root))
    return failure();
  tensorExp = root;
  return success();
}
/// Prepares the environment for code emission. Must be called at most once
/// (asserted through the still-empty insertion chain).
///
/// When a sparse output has been recorded, the insertion chain is seeded with
/// that output's value and the merger is told a sparse output exists. Then,
/// for every operand of the linalg op, each level's dependent-loop list is
/// sorted (and, in debug builds, the indexing-map result count is checked
/// against the encoding's level rank). Finally the loop emitter is initialized
/// with all operand values, the generic op's name, the loop count, a getter
/// over the merger's dependent loops, and `emitStrategy`.
void CodegenEnv::startEmit(SparseEmitStrategy emitStrategy) {
  assert(insChain == nullptr && "must only start emitting once");
  if (sparseOut) {
    insChain = sparseOut->get();
    latticeMerger.setHasSparseOut(true);
  }

  SmallVector<Value> operandValues;
  for (OpOperand &operand : linalgOp->getOpOperands()) {
    Value tensor = operand.get();
    operandValues.push_back(tensor);
    const TensorId tid = makeTensorId(operand.getOperandNumber());
    const Level lvlRank =
        linalgOp.getMatchingIndexingMap(&operand).getNumResults();
    [[maybe_unused]] const auto enc = getSparseTensorEncoding(tensor.getType());
    assert(!enc || lvlRank == enc.getLvlRank());
    for (Level lvl = 0; lvl < lvlRank; ++lvl)
      sortDependentLoops(latticeMerger.getDependentLoops(tid, lvl));
  }

  const bool hasSparseOut = sparseOut != nullptr;
  auto dependentLoopsOf = [this](TensorId t,
                                 Level lvl) -> std::vector<LoopCoeffPair> {
    return merger().getDependentLoops(t, lvl);
  };
  // NOTE(review): the literal `true` is presumably the "has output" flag of
  // LoopEmitter::initialize — confirm against its declaration.
  loopEmitter.initialize(operandValues,
                         StringAttr::get(linalgOp.getContext(),
                                         linalg::GenericOp::getOperationName()),
                         true, hasSparseOut, getLoopNum(), dependentLoopsOf,
                         emitStrategy);
}
/// Packs the currently active loop-carried values into a parameter list,
/// invokes `callback` on it, and writes the (possibly replaced) values back.
///
/// Packing order: the reduction value (plus the lex-insert validity flag when
/// one is active) if a reduction is active, then the expansion count if
/// expanding, then the insertion chain if one exists. Because `callback`
/// receives a MutableArrayRef, it may overwrite entries; the same order is
/// used afterwards to feed them to the matching update* methods.
///
/// Returns whatever `callback` returns.
std::optional<Operation *> CodegenEnv::genLoopBoundary(
    function_ref<std::optional<Operation *>(MutableArrayRef<Value> parameters)>
        callback) {
  SmallVector<Value> params;
  if (isReduc()) {
    params.push_back(redVal);
    if (isValidLexInsert())
      params.push_back(redValidLexInsert);
  } else {
    // A lex-insert validity flag is only tracked alongside a reduction.
    assert(!isValidLexInsert());
  }
  if (isExpand())
    params.push_back(expCount);
  if (insChain != nullptr)
    params.push_back(insChain);
  auto r = callback(params);
  // Unpack in exactly the order used for packing above.
  unsigned i = 0;
  if (isReduc()) {
    updateReduc(params[i++]);
    if (isValidLexInsert())
      updateValidLexInsert(params[i++]);
  }
  if (isExpand())
    updateExpandCount(params[i++]);
  if (insChain != nullptr)
    updateInsertionChain(params[i]);
  return r;
}
/// Decides whether the tensor expression `exp` can be code-generated.
///
/// - Rejected if the op has any reduction iterator and the merger reports a
///   negation on the output (`hasNegateOnOut`).
/// - Accepted if the first DPS init (output) operand is all-dense.
/// - Accepted if the merger reports `isSingleCondition` for the output tensor.
/// - Otherwise the output operand is recorded in `sparseOut`, `outerParNest`
///   is set to the number of leading non-reduction loops, and the expression
///   is accepted only if the output value is materializing (produced by a
///   tensor.empty or bufferization.alloc_tensor op). Note that `sparseOut` and
///   `outerParNest` are updated even when this final check returns false.
bool CodegenEnv::isAdmissibleTensorExp(ExprId exp) {
  const auto iterTypes = linalgOp.getIteratorTypesArray();

  bool hasReduction = false;
  for (utils::IteratorType iterType : iterTypes) {
    if (linalg::isReductionIterator(iterType)) {
      hasReduction = true;
      break;
    }
  }
  if (hasReduction && latticeMerger.hasNegateOnOut(exp))
    return false;

  OpOperand *output = linalgOp.getDpsInitOperand(0);
  const TensorId outTid = makeTensorId(output->getOperandNumber());
  if (getSparseTensorType(output->get()).isAllDense())
    return true;
  if (latticeMerger.isSingleCondition(outTid, exp))
    return true;

  sparseOut = output;
  // Count the leading loops that are not reductions.
  outerParNest = 0;
  const unsigned numLoops = getLoopNum();
  while (outerParNest < numLoops &&
         !linalg::isReductionIterator(iterTypes[outerParNest]))
    ++outerParNest;

  assert(static_cast<int64_t>(outerParNest) >=
         linalgOp.getRank(linalgOp.getDpsInitOperand(0)) - 1);
  return isMaterializing(output->get());
}
/// Returns the induction variable that the loop emitter reports for loop `i`.
Value CodegenEnv::getLoopVar(LoopId i) const {
  Value inductionVar = loopEmitter.getLoopIV(i);
  return inductionVar;
}
/// Replaces the current insertion chain with `chain`. Requires a recorded
/// sparse output and an already-started insertion chain.
void CodegenEnv::updateInsertionChain(Value chain) {
  assert(sparseOut != nullptr);
  assert(insChain != nullptr);
  insChain = chain;
}
/// Returns true iff `o` is the recorded sparse output, `n` equals
/// `outerParNest`, and `outerParNest` equals `rank - 1`.
bool CodegenEnv::atExpandLevel(OpOperand *o, unsigned rank, LoopId n) const {
  if (sparseOut != o)
    return false;
  if (outerParNest != n)
    return false;
  return outerParNest == static_cast<LoopId>(rank - 1);
}
/// Begins an expansion by storing its four values. Requires a recorded sparse
/// output and no expansion already in progress.
void CodegenEnv::startExpand(Value values, Value filled, Value added,
                             Value count) {
  assert(sparseOut != nullptr);
  assert(expValues == nullptr);
  expCount = count;
  expAdded = added;
  expFilled = filled;
  expValues = values;
}
/// Replaces the expansion count while an expansion is in progress.
void CodegenEnv::updateExpandCount(Value count) {
  assert(sparseOut != nullptr);
  assert(expValues != nullptr);
  expCount = count;
}
/// Ends the current expansion by resetting all four expansion values to null.
void CodegenEnv::endExpand() {
  assert(sparseOut != nullptr);
  assert(expValues != nullptr);
  expCount = Value();
  expAdded = Value();
  expFilled = Value();
  expValues = Value();
}
/// Starts a reduction over expression `exp` with initial value `val`, and
/// records `val` as the merger's value for `exp`.
void CodegenEnv::startReduc(ExprId exp, Value val) {
  assert(!isReduc());
  assert(exp != detail::kInvalidId);
  assert(val);
  redVal = val;
  redExp = exp;
  latticeMerger.setExprValue(exp, val);
}
/// Replaces the active reduction's value with `val`, both locally and in the
/// merger (the old merger value is cleared before the new one is set).
void CodegenEnv::updateReduc(Value val) {
  assert(isReduc());
  assert(val);
  redVal = val;
  latticeMerger.clearExprValue(redExp);
  latticeMerger.setExprValue(redExp, val);
}
/// Ends the active reduction and returns its final value.
///
/// Clears the merger's value for the reduction expression and resets `redExp`
/// to `detail::kInvalidId`. `redVal` is not reset; it keeps the last value.
Value CodegenEnv::endReduc() {
  assert(isReduc());
  Value val = redVal;
  latticeMerger.clearExprValue(redExp);
  redExp = detail::kInvalidId;
  return val;
}
/// Starts tracking the lex-insert validity value `val`. Requires an active
/// reduction and no validity value already being tracked.
void CodegenEnv::startValidLexInsert(Value val) {
  assert(!isValidLexInsert());
  assert(isReduc());
  assert(val);
  redValidLexInsert = val;
}
/// Replaces the tracked lex-insert validity value with `val`. Requires an
/// existing validity value and an active reduction.
void CodegenEnv::updateValidLexInsert(Value val) {
  assert(redValidLexInsert);
  assert(isReduc());
  assert(val);
  redValidLexInsert = val;
}
/// Stops tracking the lex-insert validity value. Asserts that the reduction
/// has already ended (isReduc() is false) before this is called.
void CodegenEnv::endValidLexInsert() {
  assert(isValidLexInsert());
  assert(!isReduc());
  redValidLexInsert = Value();
}
/// Marks `exp` as the active custom reduction. Requires that no custom
/// reduction is active and that `exp` is a valid expression id.
void CodegenEnv::startCustomReduc(ExprId exp) {
  assert(exp != detail::kInvalidId);
  assert(!isCustomReduc());
  redCustom = exp;
}
/// Returns the identity value of the active custom reduction, taken from the
/// `sparse_tensor::ReduceOp` recorded on the custom-reduction expression.
///
/// `cast<>` is used rather than `dyn_cast<>` because the result is used
/// unconditionally: `cast` asserts (in debug builds) that the op really is a
/// ReduceOp instead of silently producing a null op that is then dereferenced.
Value CodegenEnv::getCustomRedId() const {
  assert(isCustomReduc());
  return cast<sparse_tensor::ReduceOp>(exp(redCustom).op).getIdentity();
}
void CodegenEnv::endCustomReduc() {
assert(isCustomReduc());
redCustom = detail::kInvalidId;
}