#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENENV_H_
#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_CODEGENENV_H_
#include "CodegenUtils.h"
#include "LoopEmitter.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include <optional>
namespace mlir {
namespace sparse_tensor {
/// Bundles the state that sparse-tensor code generation carries around for a
/// single `linalg.generic` operation: the operation itself, the
/// sparsification options, a `Merger`, a `LoopEmitter`, and bookkeeping
/// values for sparse output insertion, access-pattern expansion, and
/// reductions. Most accessors are thin forwards to the merger or the loop
/// emitter; the stateful start/update/end methods are defined out of line.
class CodegenEnv {
public:
  /// Constructs the environment for `linop` with options `opts`.
  /// NOTE(review): `numTensors`, `numLoops`, and `maxRank` are presumably
  /// forwarded to the merger; the constructor is defined out of line.
  CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
             unsigned numTensors, unsigned numLoops, unsigned maxRank);

  //
  // General methods.
  //

  /// Defined out of line. NOTE(review): presumably builds the tensor
  /// expression whose id `getExprId()` returns; confirm in the .cpp file.
  LogicalResult initTensorExp();

  /// Returns the stored tensor expression id (`detail::kInvalidId` until
  /// something assigns it).
  ExprId getExprId() const { return tensorExp; }

  /// Returns the `linalg.generic` operation this environment was built for.
  linalg::GenericOp op() const { return linalgOp; }

  /// Returns the sparsification options held by this environment.
  const SparsificationOptions &options() const { return sparseOptions; }

  /// Returns mutable access to the owned merger.
  Merger &merger() { return latticeMerger; }

  /// Returns mutable access to the owned loop emitter.
  LoopEmitter &emitter() { return loopEmitter; }

  /// Defined out of line; takes the emit strategy to use.
  void startEmit(SparseEmitStrategy emitStrategy);

  /// Defined out of line. `callback` receives a mutable range of values and
  /// may return an operation; the result of this method is likewise an
  /// optional operation.
  std::optional<Operation *>
  genLoopBoundary(function_ref<
                  std::optional<Operation *>(MutableArrayRef<Value> parameters)>
                      callback);

  //
  // Merger delegates.
  //

  /// Forwards to `Merger::makeTensorId`.
  constexpr TensorId makeTensorId(unsigned t) const {
    return latticeMerger.makeTensorId(t);
  }

  /// Forwards to `Merger::makeLoopId`.
  constexpr LoopId makeLoopId(unsigned i) const {
    return latticeMerger.makeLoopId(i);
  }

  /// Forwards to `Merger::makeTensorLoopId`.
  constexpr TensorLoopId makeTensorLoopId(unsigned t, unsigned i) const {
    return latticeMerger.makeTensorLoopId(t, i);
  }

  /// Forwards to `Merger::exp`.
  const TensorExp &exp(ExprId e) const { return latticeMerger.exp(e); }

  /// Forwards to `Merger::lat`.
  const LatPoint &lat(LatPointId l) const { return latticeMerger.lat(l); }

  /// Forwards to `Merger::set`.
  ArrayRef<LatPointId> set(LatSetId s) const { return latticeMerger.set(s); }

  /// Forwards to `Merger::getLvlType` for a (tensor, loop) pair.
  LevelType lt(TensorId t, LoopId i) const {
    return latticeMerger.getLvlType(t, i);
  }

  /// Forwards to `Merger::getLvlType` for a combined tensor-loop id.
  LevelType lt(TensorLoopId b) const { return latticeMerger.getLvlType(b); }

  /// Returns the number of loops reported by the merger.
  unsigned getLoopNum() const { return latticeMerger.getNumLoops(); }

  //
  // LoopEmitter delegates.
  //

  /// Forwards to `LoopEmitter::makeTensorLevel`. In assert-enabled builds it
  /// first checks that the loop emitter agrees with the operation on the
  /// operand count and with the merger on the tensor count and on the output
  /// and synthetic tensor ids.
  TensorLevel makeTensorLevel(TensorId t, Level l) const {
    assert(loopEmitter.getNumManifestTensors() == linalgOp->getNumOperands() &&
           loopEmitter.getNumTensors() == latticeMerger.getNumTensors() &&
           loopEmitter.getOutTensorId() == latticeMerger.getOutTensorID() &&
           loopEmitter.getSynTensorId() == latticeMerger.getSynTensorID());
    return loopEmitter.makeTensorLevel(t, l);
  }

  /// Pair overload; unpacks `tlPair` and calls the two-argument version.
  TensorLevel makeTensorLevel(std::pair<TensorId, Level> tlPair) const {
    return makeTensorLevel(tlPair.first, tlPair.second);
  }

  /// Forwards to `LoopEmitter::unpackTensorLevel`.
  std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tl) const {
    return loopEmitter.unpackTensorLevel(tl);
  }

  /// Perfect-forwards the container `c` to
  /// `LoopEmitter::unpackTensorLevelRange` and returns whatever it returns.
  template <class ContainerTy>
  auto unpackTensorLevelRange(ContainerTy &&c) const {
    return loopEmitter.unpackTensorLevelRange(std::forward<ContainerTy>(c));
  }

  /// Returns the current depth reported by the loop emitter.
  unsigned getCurrentDepth() const { return loopEmitter.getCurrentDepth(); }

  //
  // Code generation environment verify functions (defined out of line).
  //

  bool isAdmissibleTensorExp(ExprId e);

  /// Defined out of line; returns a value associated with loop `i`.
  Value getLoopVar(LoopId i) const;

  //
  // Sparse tensor output and expansion methods.
  //

  /// True when a sparse output operand has been recorded (non-null).
  bool hasSparseOutput() const { return sparseOut != nullptr; }

  /// True when `o` is the recorded sparse output operand (pointer identity).
  bool isSparseOutput(OpOperand *o) const { return sparseOut == o; }

  /// Returns the current insertion chain value.
  Value getInsertionChain() const { return insChain; }
  void updateInsertionChain(Value chain);

  bool atExpandLevel(OpOperand *o, unsigned rank, LoopId n) const;
  void startExpand(Value values, Value filled, Value added, Value count);

  /// True while expansion state is held, i.e. `expValues` is non-null.
  bool isExpand() const { return expValues != nullptr; }
  void updateExpandCount(Value count);
  Value getExpandValues() const { return expValues; }
  Value getExpandFilled() const { return expFilled; }
  Value getExpandAdded() const { return expAdded; }
  Value getExpandCount() const { return expCount; }
  void endExpand();

  //
  // Reduction methods.
  //

  void startReduc(ExprId exp, Value val);

  /// True while `redExp` holds something other than `detail::kInvalidId`.
  bool isReduc() const { return redExp != detail::kInvalidId; }
  void updateReduc(Value val);
  Value getReduc() const { return redVal; }
  Value endReduc();

  void startValidLexInsert(Value val);

  /// True while a valid-lex-insert value is held (non-null).
  bool isValidLexInsert() const { return redValidLexInsert != nullptr; }
  void updateValidLexInsert(Value val);
  Value getValidLexInsert() const { return redValidLexInsert; }
  void endValidLexInsert();

  void startCustomReduc(ExprId exp);

  /// True while `redCustom` holds something other than `detail::kInvalidId`.
  bool isCustomReduc() const { return redCustom != detail::kInvalidId; }
  Value getCustomRedId() const;
  void endCustomReduc();

private:
  // NOTE: members are initialized in declaration order; do not reorder.

  // The operation this environment was built for and its options.
  linalg::GenericOp linalgOp;
  SparsificationOptions sparseOptions;

  // Owned helpers that most accessors above forward to.
  Merger latticeMerger;
  LoopEmitter loopEmitter;

  // Sparse output bookkeeping. The pointer defaults to null so that
  // hasSparseOutput()/isSparseOutput() are well defined even if the
  // out-of-line constructor omits it from its mem-initializer list.
  OpOperand *sparseOut = nullptr;
  // NOTE(review): no sentinel for LoopId is visible in this header, so this
  // member is left for the out-of-line constructor to initialize; confirm.
  LoopId outerParNest;
  Value insChain;
  Value expValues;
  Value expFilled;
  Value expAdded;
  Value expCount;

  // Reduction bookkeeping. The ids default to the invalid sentinel that
  // isReduc()/isCustomReduc() test against, so both predicates report
  // "inactive" on a freshly constructed environment regardless of what the
  // constructor sets explicitly.
  Value redVal;
  ExprId redExp{detail::kInvalidId};
  ExprId redCustom{detail::kInvalidId};
  Value redValidLexInsert;

  // Id returned by getExprId(); invalid until assigned.
  ExprId tensorExp{detail::kInvalidId};
};
}
}
#endif