#include "Utils/CodegenEnv.h"
#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"
#include <optional>
using namespace mlir;
using namespace mlir::sparse_tensor;
/// Returns true iff affine expression `a` only refers to loops with an id
/// smaller than `curr` (constants are always accepted). As a side effect,
/// `isCurrentLoop` is set to true whenever `a` mentions loop `curr - 1`.
/// Only dimension, add, mul and constant expressions are expected here.
static bool isInvariantAffine(AffineExpr a, LoopId curr, bool &isCurrentLoop) {
  const AffineExprKind kind = a.getKind();
  if (kind == AffineExprKind::DimId) {
    const LoopId pos = cast<AffineDimExpr>(a).getPosition();
    if (pos + 1 == curr) {
      // Innermost loop this expression is allowed to depend on.
      isCurrentLoop = true;
      return true;
    }
    return pos < curr;
  }
  if (kind == AffineExprKind::Add || kind == AffineExprKind::Mul) {
    const auto bin = cast<AffineBinaryOpExpr>(a);
    // The RHS is only inspected when the LHS is invariant (short-circuit).
    if (!isInvariantAffine(bin.getLHS(), curr, isCurrentLoop))
      return false;
    return isInvariantAffine(bin.getRHS(), curr, isCurrentLoop);
  }
  // Only constants remain in the supported subset.
  assert(isa<AffineConstantExpr>(a));
  return true;
}
/// Checks whether affine expression `a` on level `lvl` of tensor `tid` is
/// admissible and, for a plain dimension expression, records the level and
/// level type of that loop in the merger (only when `setLvlFormat` is set).
/// Returns false when a loop is already bound on this tensor or when the
/// expression kind is unsupported. Compound operands are visited with
/// `setLvlFormat == false`, i.e. only checked, never bound.
static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
                       LevelType lt, bool setLvlFormat = true) {
  const AffineExprKind kind = a.getKind();
  if (kind == AffineExprKind::DimId) {
    const LoopId loop =
        merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
    // A loop index may only be bound once per tensor.
    if (!isUndefLT(merger.getLvlType(tid, loop)))
      return false;
    if (setLvlFormat)
      merger.setLevelAndType(tid, loop, lvl, lt);
    return true;
  }
  const bool isCompoundOrConstant = kind == AffineExprKind::Add ||
                                    kind == AffineExprKind::Mul ||
                                    kind == AffineExprKind::Constant;
  if (!isCompoundOrConstant)
    return false;
  // Compound and constant expressions are expected on dense levels only.
  assert(lt.hasDenseSemantic());
  const auto bin = dyn_cast<AffineBinaryOpExpr>(a);
  if (!bin)
    return true; // A constant expression.
  // Recurse only to validate the operands; do not bind them to this level.
  return findAffine(merger, tid, lvl, bin.getLHS(), lt, false) &&
         findAffine(merger, tid, lvl, bin.getRHS(), lt, false);
}
/// Inspects affine expression `a` on level `lvl` of `tensor` for
/// index-reduction based codegen and records in the merger how loops map onto
/// this tensor level.
///
/// - A plain `d_i` at the top level (`isSubExp == false`) binds loop `i`
///   directly to (`tensor`, `lvl`) with level type `lt`.
/// - Inside a sum (`isSubExp == true`), each `d_i`, optionally scaled by a
///   constant `coefficient`, is recorded as a loop that (`tensor`, `lvl`)
///   depends on.
///
/// Returns false for inadmissible expressions: a non-positive coefficient, a
/// loop already bound on this tensor, a loop that already has a dependent
/// level on this tensor, a scaled/constant expression outside a sum, or any
/// other expression kind.
static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
                          AffineExpr a, LevelType lt, bool isSubExp = false,
                          int64_t coefficient = 1) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    // Only positive coefficients are accepted on dimension expressions.
    if (coefficient <= 0)
      return false;
    const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
    // Reject a loop that is already bound to a level of this tensor
    // (e.g. A[i][i]).
    if (!isUndefLT(merger.getLvlType(tensor, idx)))
      return false;
    if (!isSubExp) {
      // Trivial index expression: bind the loop directly to this level.
      assert(coefficient == 1);
      merger.setLevelAndType(tensor, idx, lvl, lt);
    }
    if (isSubExp) {
      // Reject the loop if the merger already reports a dependent level for
      // it on this tensor (the loop is used by another compound expression).
      if (merger.hasDependentLvl(idx, tensor)) {
        return false;
      }
      merger.setLoopDependentTensorLevel(idx, tensor, lvl, lt, coefficient);
    }
    return true;
  }
  case AffineExprKind::Constant:
  case AffineExprKind::Mul: {
    // Scaled dimensions are only accepted as operands of a sum
    // (e.g. `2 * d0 + d1`, but not `2 * d0` on its own).
    if (!isSubExp)
      return false;
    // Constant terms inside a sum are not supported yet.
    if (isa<AffineConstantExpr>(a))
      llvm_unreachable("Not yet implemented");
    auto binOp = cast<AffineBinaryOpExpr>(a);
    auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
    // Normalize to the form `constant * d`.
    if (isa<AffineConstantExpr>(rhs))
      std::swap(lhs, rhs);
    assert(isa<AffineConstantExpr>(lhs) && isa<AffineDimExpr>(rhs));
    // Note: this local shadows the `coefficient` parameter; it carries the
    // scale factor into the recursive call on the dimension operand.
    int64_t coefficient = cast<AffineConstantExpr>(lhs).getValue();
    return findDepIdxSet(merger, tensor, lvl, rhs, lt, isSubExp, coefficient);
  }
  case AffineExprKind::Add: {
    // Both operands of a sum are sub-expressions; the coefficient restarts
    // at its default of 1.
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return findDepIdxSet(merger, tensor, lvl, binOp.getLHS(), lt, true) &&
           findDepIdxSet(merger, tensor, lvl, binOp.getRHS(), lt, true);
  }
  default:
    return false;
  }
}
/// Counts the levels of `tensor` that are indexed by something other than a
/// plain dimension expression while not having dense semantics.
/// Returns 0 when `tensor` is not a ranked tensor.
static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
                                                   Value tensor) {
  // The operand is not guaranteed to be a ranked tensor, so guard before
  // building a SparseTensorType from it.
  const auto rankedTp = dyn_cast<RankedTensorType>(tensor.getType());
  if (!rankedTp)
    return 0;
  const SparseTensorType stt(rankedTp);
  const Level lvlRank = stt.getLvlRank();
  const auto results = map.getResults();
  assert(static_cast<Dimension>(results.size()) == lvlRank &&
         "AffineMap does not have dimension-rank many results");
  unsigned count = 0;
  for (Level lvl = 0; lvl < lvlRank; ++lvl) {
    if (isa<AffineDimExpr>(results[lvl]))
      continue; // Trivial index expression.
    if (!stt.getLvlType(lvl).hasDenseSemantic())
      ++count;
  }
  return count;
}
/// Sums, over all operands of `op`, the number of non-dense levels indexed by
/// a non-trivial affine expression.
static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
  unsigned total = 0;
  for (OpOperand &operand : op->getOpOperands()) {
    const AffineMap idxMap = op.getMatchingIndexingMap(&operand);
    total += getNumNonTrivialIdxExpOnSparseLvls(idxMap, operand.get());
  }
  return total;
}
/// Returns true when the first output of `op` is not all-dense and has at
/// least one non-dense level indexed by a non-trivial affine expression.
static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
  OpOperand *out = op.getDpsInitOperand(0);
  if (getSparseTensorType(out->get()).isAllDense())
    return false;
  const AffineMap outMap = op.getMatchingIndexingMap(out);
  return getNumNonTrivialIdxExpOnSparseLvls(outMap, out->get()) != 0;
}
/// Maps every level of every operand onto the loops that index it, recording
/// the level types in the merger. When `idxReducBased` is set, operands with a
/// sparse encoding and at least one non-trivial index expression on a
/// non-dense level are analyzed with `findDepIdxSet`; all other operands use
/// `findAffine`.
/// Returns false on an inadmissible affine expression; otherwise returns
/// whether any operand carries a sparse encoding.
static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
  bool hasSparseOperand = false;
  for (OpOperand &operand : env.op()->getOpOperands()) {
    const TensorId tid = env.makeTensorId(operand.getOperandNumber());
    const auto idxMap = env.op().getMatchingIndexingMap(&operand);
    const auto enc = getSparseTensorEncoding(operand.get().getType());
    if (enc)
      hasSparseOperand = true;
    const Level lvlRank = idxMap.getNumResults();
    assert(!enc || lvlRank == enc.getLvlRank());
    assert(static_cast<Level>(env.op().getRank(&operand)) == lvlRank);
    // Index reduction is only required when a non-trivial index expression
    // sits on a non-dense level; dense levels are located by random access.
    const bool needIdxReduc =
        enc && getNumNonTrivialIdxExpOnSparseLvls(idxMap, operand.get()) != 0;
    const bool useIdxReduc = idxReducBased && needIdxReduc;
    for (Level l = 0; l < lvlRank; l++) {
      const AffineExpr expr = idxMap.getResult(l);
      const LevelType lt = enc.getLvlType(l);
      const bool admissible =
          useIdxReduc ? findDepIdxSet(env.merger(), tid, l, expr, lt)
                      : findAffine(env.merger(), tid, l, expr, lt);
      if (!admissible)
        return false; // Inadmissible affine expression.
    }
  }
  return hasSparseOperand;
}
/// Initializes loop emission for the kernel. It computes the loop ranges of
/// the linalg op and passes two callbacks to the loop emitter: one that
/// materializes the buffer of the (non-annotated) output tensor, and one that
/// returns the size of a loop.
static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  // Exactly one output operand is expected.
  assert(op.getNumOperands() == op.getNumDpsInputs() + 1);
  SmallVector<Range, 4> loopRange =
      llvm::cast<linalg::LinalgOp>(op.getOperation())
          .createLoopRanges(builder, loc);
  env.emitter().initializeLoopEmit(
      builder, loc,
      // Output buffer callback: reuses `memref` as the output buffer. When
      // `isInitTensor` reports that the output's incoming contents are not
      // used, the buffer is first filled with zero. Only the computed values
      // then need to be written.
      [&op](OpBuilder &builder, Location loc, Value memref,
            Value tensor) -> Value {
        // Must not be a sparse tensor.
        assert(!getSparseTensorEncoding(tensor.getType()));
        // Both references must denote the same output tensor.
        OpOperand *lhs = op.getDpsInitOperand(0);
        assert(lhs->get() == tensor);
        bool isInit = op.isInitTensor(lhs);
        Value init = memref;
        if (!isInit) {
          Value zero = constantZero(builder, loc,
                                    getElementTypeOrSelf(tensor.getType()));
          builder.create<linalg::FillOp>(loc, ValueRange{zero},
                                         ValueRange{init});
        }
        return init;
      },
      // Loop size callback: returns the size of loop range `l` as an index
      // value.
      [&loopRange](OpBuilder &b, Location loc, Level l) {
        assert(l < loopRange.size());
        return mlir::getValueOrCreateConstantIndexOp(b, loc, loopRange[l].size);
      });
}
/// Returns the loop variable that indexes the last level of operand `t`.
/// That last index expression is required to be a plain dimension expression.
static Value genIndex(CodegenEnv &env, OpOperand *t) {
  const auto idxMap = env.op().getMatchingIndexingMap(t);
  const auto stt = getSparseTensorType(t->get());
  const Level lvlRank = stt.getLvlRank();
  assert(static_cast<Level>(idxMap.getNumResults()) == lvlRank);
  const AffineExpr lastExpr = idxMap.getResult(lvlRank - 1);
  assert(lastExpr.getKind() == AffineExprKind::DimId);
  const unsigned pos = cast<AffineDimExpr>(lastExpr).getPosition();
  return env.getLoopVar(env.makeLoopId(pos));
}
/// Appends the subscripts used to access operand `t` to `args` and returns
/// the value buffer of that operand. Annotated operands use the value
/// positions tracked by the loop emitter. Non-annotated operands get one
/// coordinate per level, generated from the operand's index expressions.
static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                          SmallVectorImpl<Value> &args) {
  const Location loc = env.op().getLoc();
  const TensorId tid = env.makeTensorId(t->getOperandNumber());
  const auto idxMap = env.op().getMatchingIndexingMap(t);
  const auto stt = getSparseTensorType(t->get());
  if (!stt.hasEncoding()) {
    // Dense operand: generate one coordinate per level, in level order.
    const Level lvlRank = stt.getLvlRank();
    assert(static_cast<Level>(idxMap.getNumResults()) == lvlRank);
    for (Level l = 0; l < lvlRank; l++)
      args.push_back(
          env.emitter().genAffine(builder, loc, idxMap.getResult(l)));
  } else {
    // Annotated operand: use the positions maintained by the loop emitter.
    const auto posits = env.emitter().getValPosits(tid);
    assert(!posits.empty());
    args.append(posits);
  }
  return env.emitter().getValBuffer()[tid];
}
/// Generates a read of the sparse output during insertion. Without access
/// pattern expansion this is a zero of the element type. With expansion it is
/// a load from the expanded values buffer at the last-level index.
static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder,
                              OpOperand *t) {
  const Location loc = env.op().getLoc();
  if (env.isExpand()) {
    Value idx = genIndex(env, t);
    return builder.create<memref::LoadOp>(loc, env.getExpandValues(), idx);
  }
  return constantZero(builder, loc, getElementTypeOrSelf(t->get().getType()));
}
/// Generates a read of the sparse output during insertion for a custom
/// reduction. Without expansion this is the custom reduction identity. With
/// expansion it selects the expanded value at the last-level index when that
/// entry is marked filled, and the identity otherwise.
static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder,
                                    OpOperand *t) {
  const Location loc = env.op().getLoc();
  Value identity = env.getCustomRedId();
  if (!env.isExpand())
    return identity;
  Value idx = genIndex(env, t);
  Value isFilled =
      builder.create<memref::LoadOp>(loc, env.getExpandFilled(), idx);
  Value current =
      builder.create<memref::LoadOp>(loc, env.getExpandValues(), idx);
  return builder.create<arith::SelectOp>(loc, isFilled, current, identity);
}
/// Emits `scf.if (cond)` that yields `sparseOut` with `v` inserted at `ivs` in
/// the then-branch, and `sparseOut` unchanged in the else-branch. Leaves the
/// insertion point after the if and returns its result.
static Value genConditionalInsert(Location loc, OpBuilder &builder, Value cond,
                                  Value sparseOut, ValueRange ivs, Value v) {
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, sparseOut.getType(), cond,
                                             /*withElseRegion=*/true);
  // Then-branch: insert and yield the updated tensor.
  builder.setInsertionPointToStart(ifOp.thenBlock());
  Value inserted = builder.create<tensor::InsertOp>(loc, v, sparseOut, ivs);
  builder.create<scf::YieldOp>(loc, inserted);
  // Else-branch: yield the tensor unchanged.
  builder.setInsertionPointToStart(ifOp.elseBlock());
  builder.create<scf::YieldOp>(loc, sparseOut);
  // Continue after the conditional.
  builder.setInsertionPointAfter(ifOp);
  return ifOp.getResult(0);
}
/// Generates a store of `rhs` into the sparse output operand `t`.
/// Without access pattern expansion, `rhs` is inserted into the insertion
/// chain at the current loop induction variables. The insertion is guarded by
/// the valid-lex flag when one is tracked, or by `rhs != 0` when all inputs
/// are non-annotated. With expansion, `rhs` is written into the expanded
/// values buffer. On the first write to an index, that index is also marked
/// filled and appended to the added list.
static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
                              Value rhs) {
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  // Direct insertion in lexicographic coordinate order.
  if (!env.isExpand()) {
    const LoopId numLoops = op.getRank(t);
    // Drop the last (currentDepth - numLoops) induction variables so that
    // one coordinate remains per output dimension.
    SmallVector<Value> ivs = llvm::to_vector(llvm::drop_end(
        env.emitter().getLoopIVsRange(), env.getCurrentDepth() - numLoops));
    Value chain = env.getInsertionChain();
    if (env.isValidLexInsert()) {
      // Insert only when the tracked valid-lex flag is set.
      Value out = genConditionalInsert(loc, builder, env.getValidLexInsert(),
                                       chain, ivs, rhs);
      env.updateInsertionChain(out);
    } else {
      Value sparseOut;
      if (!hasAnySparseType(env.op().getInputs().getTypes())) {
        // No annotated input: only insert nonzero values.
        Value nz = genIsNonzero(builder, loc, rhs);
        sparseOut = genConditionalInsert(loc, builder, nz, chain, ivs, rhs);
      } else {
        sparseOut = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
      }
      env.updateInsertionChain(sparseOut);
    }
    return;
  }
  // Insertion along the expanded access pattern:
  //   if (filled[index] == false) {
  //     filled[index] = true; added[count] = index; count = count + 1;
  //   }
  //   values[index] = rhs;
  Value values = env.getExpandValues();
  Value filled = env.getExpandFilled();
  Value added = env.getExpandAdded();
  Value count = env.getExpandCount();
  Value index = genIndex(env, t);
  Value fval = constantI1(builder, loc, false);
  Value tval = constantI1(builder, loc, true);
  Value isFilled = builder.create<memref::LoadOp>(loc, filled, index);
  Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                             isFilled, fval);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond,
                                             /*else=*/true);
  // Then-branch: first write to this index.
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  builder.create<memref::StoreOp>(loc, tval, filled, index);
  builder.create<memref::StoreOp>(loc, index, added, count);
  Value one = constantIndex(builder, loc, 1);
  Value add = builder.create<arith::AddIOp>(loc, count, one);
  builder.create<scf::YieldOp>(loc, add);
  // Else-branch: index already recorded, count unchanged.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, count);
  builder.setInsertionPointAfter(ifOp);
  env.updateExpandCount(ifOp.getResult(0));
  builder.create<memref::StoreOp>(loc, rhs, values, index);
}
/// Generates (or reuses) the load of the tensor operand referenced by
/// expression `exp`. The checks, in order:
/// - An already hoisted value is returned as-is.
/// - An explicit stored value folds to a constant.
/// - A sparse output is read through the insertion helpers.
/// - Otherwise the value is loaded from the operand's buffer.
static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
  // A load hoisted to an enclosing loop nest is reused directly.
  if (Value hoisted = env.exp(exp).val)
    return hoisted;
  linalg::GenericOp op = env.op();
  const Location loc = op.getLoc();
  OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
  // Operands with an explicit value fold into that value.
  const auto stt = getSparseTensorType(t->get());
  if (auto explVal = stt.getExplicitVal())
    return genValFromAttr(builder, loc, explVal);
  if (env.isSparseOutput(t))
    return env.isCustomReduc() ? genInsertionLoadReduce(env, builder, t)
                               : genInsertionLoad(env, builder, t);
  SmallVector<Value> args;
  Value buffer = genSubscript(env, builder, t, args);
  return builder.create<memref::LoadOp>(loc, buffer, args);
}
/// Stores `rhs`, the value computed for expression `exp`, into the output.
/// The store is routed according to the current state:
/// - A null `rhs` (allowed only for unary, binary and reduce expressions)
///   stores nothing.
/// - A scalarized reduction updates the reduction value.
/// - A non-sparse output uses a plain memref store.
/// - A sparse output inserts the value, wrapped in an `scf.if` on `rhs` when
///   `exp` is a select.
static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                           Value rhs) {
  // A missing value means there is nothing to store.
  if (!rhs) {
    assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
           env.exp(exp).kind == TensorExp::Kind::kBinary ||
           env.exp(exp).kind == TensorExp::Kind::kReduce);
    return;
  }
  // Scalarized reduction.
  if (env.isReduc()) {
    env.updateReduc(rhs);
    return;
  }
  // Regular store into a non-sparse output.
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  OpOperand *t = op.getDpsInitOperand(0);
  if (!env.isSparseOutput(t)) {
    SmallVector<Value> args;
    Value ptr = genSubscript(env, builder, t, args);
    builder.create<memref::StoreOp>(loc, rhs, ptr, args);
    return;
  }
  // Sparse insertion.
  if (env.exp(exp).kind != TensorExp::Kind::kSelect) {
    genInsertionStore(env, builder, t, rhs);
    return;
  }
  // Select: `rhs` is the condition; the value to insert was preserved in
  // the expression node by genExp.
  Value chain = env.getInsertionChain();
  scf::IfOp ifOp =
      builder.create<scf::IfOp>(loc, chain.getType(), rhs, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  assert(env.exp(exp).val);
  Value v0 = env.exp(exp).val;
  genInsertionStore(env, builder, t, v0);
  env.merger().clearExprValue(exp);
  // Then-branch yields the modified insertion chain.
  Value mchain = env.getInsertionChain();
  builder.create<scf::YieldOp>(op.getLoc(), mchain);
  // Else-branch yields the original insertion chain.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, chain);
  env.updateInsertionChain(ifOp->getResult(0));
  builder.setInsertionPointAfter(ifOp);
}
/// Returns the value of an invariant expression, which is stored directly in
/// its expression node.
inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
  const auto &node = env.exp(exp);
  return node.val;
}
/// Rewrites value `e`, taken from a semi-ring branch body, so that it refers to
/// values available in the generated code. Cases handled:
/// - A block argument of the original linalg op becomes a load from the
///   corresponding non-annotated tensor.
/// - A `linalg.index` becomes the matching loop variable.
/// - An operation defined in `block` has its operands relinked recursively,
///   in place.
/// Anything else is returned unchanged.
static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
                          Value e) {
  if (auto arg = dyn_cast<BlockArgument>(e)) {
    linalg::GenericOp op = env.op();
    if (arg.getOwner()->getParentOp() == op) {
      const TensorId tid = env.makeTensorId(arg.getArgNumber());
      OpOperand *t = &op->getOpOperand(tid);
      // Only non-annotated tensors are expected here.
      assert(!getSparseTensorType(t->get()).hasEncoding());
      SmallVector<Value> args;
      Value ptr = genSubscript(env, rewriter, t, args);
      return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
    }
  } else if (Operation *def = e.getDefiningOp()) {
    // Index computation maps onto the loop variable.
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return env.getLoopVar(env.makeLoopId(indexOp.getDim()));
    // Still defined in the branch body: recurse into the operands.
    if (def->getBlock() == block) {
      rewriter.setInsertionPoint(def);
      for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
        rewriter.modifyOpInPlace(def, [&]() {
          def->setOperand(
              i, relinkBranch(env, rewriter, block, def->getOperand(i)));
        });
      }
    }
  }
  return e;
}
/// Recursively generates code for tensor expression `e` and returns its value.
/// A null value is returned for an invalid id, and when a custom reduction did
/// not receive both operands. The result may also be null when `buildExp`
/// produces no value.
static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
  if (e == ::mlir::sparse_tensor::detail::kInvalidId)
    return Value();
  linalg::GenericOp op = env.op();
  Location loc = op.getLoc();
  const TensorExp &exp = env.exp(e);
  const auto kind = exp.kind;
  // Leaves.
  if (kind == TensorExp::Kind::kTensor)
    return genTensorLoad(env, rewriter, e);
  if (kind == TensorExp::Kind::kInvariant)
    return genInvariantValue(env, e);
  if (kind == TensorExp::Kind::kLoopVar)
    return env.getLoopVar(exp.loop);
  if (kind == TensorExp::Kind::kReduce)
    env.startCustomReduc(e); // Enter custom reduction.
  // A synthetic zero child takes its type from the sibling operand, so the
  // sibling is generated first.
  Value v0, v1;
  if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
      env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) {
    v1 = genExp(env, rewriter, exp.children.e1);
    v0 = constantZero(rewriter, loc, v1.getType());
  } else if (exp.children.e1 != ::mlir::sparse_tensor::detail::kInvalidId &&
             env.exp(exp.children.e1).kind == TensorExp::Kind::kSynZero) {
    v0 = genExp(env, rewriter, exp.children.e0);
    v1 = constantZero(rewriter, loc, v0.getType());
  } else {
    v0 = genExp(env, rewriter, exp.children.e0);
    v1 = genExp(env, rewriter, exp.children.e1);
  }
  Value ee;
  if (kind == TensorExp::Kind::kReduce && (!v0 || !v1)) {
    // The custom reduction did not receive a value: `ee` stays null.
  } else {
    ee = env.merger().buildExp(rewriter, loc, e, v0, v1);
    // Semi-ring style ops may carry branch bodies that still refer to
    // values of the original op; relink them.
    if (ee &&
        (kind == TensorExp::Kind::kUnary || kind == TensorExp::Kind::kBinary ||
         kind == TensorExp::Kind::kBinaryBranch ||
         kind == TensorExp::Kind::kReduce ||
         kind == TensorExp::Kind::kSelect)) {
      OpBuilder::InsertionGuard guard(rewriter);
      ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee);
    }
  }
  if (kind == TensorExp::Kind::kReduce)
    env.endCustomReduc(); // Exit custom reduction.
  // Preserve the selected value; genTensorStore inserts it later.
  if (kind == TensorExp::Kind::kSelect)
    env.merger().setExprValue(e, v0);
  return ee;
}
/// Starts (`isStart`) or ends the hoisting of loop-invariant tensor accesses in
/// expression `exp` at loop `curr`. The traversal recurses through compound
/// expressions. A tensor access becomes invariant once all of its index
/// expressions depend only on already-open loops and at least one depends on
/// the innermost open loop. For the output tensor this starts/ends a
/// scalarized reduction (plus the valid-lex flag for sparse outputs). For
/// other tensors the load is cached in, or cleared from, the expression node.
static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                          LoopId curr, bool isStart) {
  if (exp == ::mlir::sparse_tensor::detail::kInvalidId)
    return;
  if (env.exp(exp).kind == TensorExp::Kind::kTensor) {
    // Inspect the tensor's index expressions.
    linalg::GenericOp op = env.op();
    OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
    const auto map = op.getMatchingIndexingMap(&t);
    const auto stt = getSparseTensorType(t.get());
    const Level lvlRank = stt.getLvlRank();
    assert(static_cast<Level>(map.getNumResults()) == lvlRank);
    bool isCurrentLoop = curr == 0; // Scalar tensors are invariant at curr 0.
    for (Level l = 0; l < lvlRank; l++) {
      const AffineExpr a = map.getResult(l);
      if (!isInvariantAffine(a, curr, /*out*/ isCurrentLoop))
        return; // Still depends on a loop that is not yet open.
    }
    // Only act right at the level where the access becomes invariant.
    if (!isCurrentLoop)
      return;
    OpOperand *lhs = op.getDpsInitOperand(0);
    if (lhs == &t) {
      // Start or end a scalarized reduction. For custom reductions, the
      // `isReduc` checks make start/end happen only once, even if the
      // output occurs several times in the expression.
      if (isStart) {
        if (env.isCustomReduc()) {
          if (!env.isReduc())
            env.startReduc(exp, env.getCustomRedId());
        } else {
          env.startReduc(exp, genTensorLoad(env, builder, exp));
        }
        if (env.hasSparseOutput())
          env.startValidLexInsert(
              constantI1(builder, env.op().getLoc(), false));
      } else {
        if (!env.isCustomReduc() || env.isReduc())
          genTensorStore(env, builder, exp, env.endReduc());
        if (env.hasSparseOutput())
          env.endValidLexInsert();
      }
    } else {
      // Start or end hoisting of an invariant tensor load.
      if (isStart) {
        env.merger().setExprValue(exp, genTensorLoad(env, builder, exp));
      } else {
        env.merger().clearExprValue(exp);
      }
    }
  } else if (env.exp(exp).kind != TensorExp::Kind::kInvariant &&
             env.exp(exp).kind != TensorExp::Kind::kLoopVar &&
             env.exp(exp).kind != TensorExp::Kind::kSynZero) {
    // Traverse into compound expressions; only tensor loads are hoisted.
    if (env.exp(exp).kind == TensorExp::Kind::kReduce)
      env.startCustomReduc(exp); // Enter custom reduction.
    const ExprId e0 = env.exp(exp).children.e0;
    const ExprId e1 = env.exp(exp).children.e1;
    genInvariants(env, builder, e0, curr, isStart);
    genInvariants(env, builder, e1, curr, isStart);
    if (env.exp(exp).kind == TensorExp::Kind::kReduce)
      env.endCustomReduc(); // Exit custom reduction.
  }
}
/// Starts (`isStart`) or ends an expanded access pattern for the sparse output
/// at loop `curr`, if `atExpandLevel` says expansion applies there.
/// - Start emits an `ExpandOp` that yields the values, filled, and added
///   buffers plus a count.
/// - End emits a `CompressOp` that folds those buffers into the insertion
///   chain at the induction variables of loops `0..curr-1`.
static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopId curr,
                      bool isStart) {
  linalg::GenericOp op = env.op();
  OpOperand *lhs = op.getDpsInitOperand(0);
  if (!env.atExpandLevel(lhs, op.getRank(lhs), curr))
    return; // Not needed at the current level.
  assert(!env.isReduc());
  // The expansion is created from the original output tensor value rather
  // than from the insertion chain.
  Value tensor = lhs->get();
  Location loc = op.getLoc();
  if (isStart) {
    auto dynShape = {ShapedType::kDynamic};
    Type etp = cast<ShapedType>(tensor.getType()).getElementType();
    Type t1 = MemRefType::get(dynShape, etp);                    // values
    Type t2 = MemRefType::get(dynShape, builder.getI1Type());    // filled
    Type t3 = MemRefType::get(dynShape, builder.getIndexType()); // added
    Type t4 = builder.getIndexType();                            // count
    auto r = builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
    assert(r.getNumResults() == 4);
    env.startExpand(r.getResult(0), r.getResult(1), r.getResult(2),
                    r.getResult(3));
  } else {
    SmallVector<Value> indices;
    for (LoopId i = 0; i < curr; i++)
      indices.push_back(env.emitter().getLoopIV(i));
    Value values = env.getExpandValues();
    Value filled = env.getExpandFilled();
    Value added = env.getExpandAdded();
    Value count = env.getExpandCount();
    Value chain = env.getInsertionChain();
    Value compress = builder.create<CompressOp>(loc, values, filled, added,
                                                count, chain, indices);
    env.updateInsertionChain(compress);
    env.endExpand();
  }
}
/// Decides whether a loop may be emitted as a parallel loop under the
/// configured parallelization strategy. `isOuter` tells whether this is the
/// outermost loop, and `isSparse` whether it iterates over a sparse level.
static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {
  // Sparse outputs are never parallelized, and a parallel loop over an
  // expanded access pattern could race on the expansion buffers.
  if (env.hasSparseOutput() || env.isExpand())
    return false;
  const bool isDense = !isSparse;
  switch (env.options().parallelizationStrategy) {
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return true;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return isDense;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && isDense;
  case SparseParallelizationStrategy::kNone:
    return false;
  }
  llvm_unreachable("unexpected parallelization strategy");
}
/// Decides whether the loop for `curr` should be attempted as a parallel loop.
/// The loop counts as "sparse" when any of the tensor levels it iterates over
/// has sparse semantics. The decision itself is delegated to `isParallelFor`,
/// together with whether this is the outermost loop.
static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
                               ArrayRef<TensorLevel> tidLvls) {
  // The level type is queried by <tensor, loop>, as CodegenEnv::lt expects.
  // (The previously computed op / iterator-type array was never used.)
  const bool isSparse =
      llvm::any_of(tidLvls, [curr, &env](TensorLevel tidLvl) {
        const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, curr);
        return lt.hasSparseSemantic();
      });
  return isParallelFor(env, /*isOuter=*/curr == 0, isSparse);
}
/// Emits a loop that co-iterates over the given tensor levels, threading the
/// environment's loop-carried values through `genLoopBoundary`. Returns the
/// generated loop operation, which must be non-null.
static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
                                 ArrayRef<TensorLevel> tidLvls,
                                 bool tryParallel, bool needsUniv) {
  auto emitLoop = [&](MutableArrayRef<Value> reduc) {
    return env.emitter().enterCoIterationOverTensorsAtLvls(
        builder, env.op().getLoc(), tidLvls, reduc, tryParallel, needsUniv);
  };
  Operation *loop = *env.genLoopBoundary(emitLoop);
  assert(loop);
  return loop;
}
/// Emits the loop for `curr` over `tidLvls`, making it parallel when the
/// parallelization strategy allows.
static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr,
                          bool needsUniv, ArrayRef<TensorLevel> tidLvls) {
  return genCoIteration(env, builder, tidLvls,
                        /*tryParallel=*/shouldTryParallize(env, curr, tidLvls),
                        needsUniv);
}
/// Closes the `scf.if` ops enclosing the current insertion point at the end of
/// a while-loop body. This only happens when a reduction, an expansion, or an
/// insertion chain is active. Each if gets a yield of the current reduction,
/// valid-lex flag, expansion count, and insertion chain (whichever are
/// active), and those values are rebound to the if's results. The walk stops
/// at an if tagged "slice" by the loop emitter. `needsUniv` is not used.
static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
                            bool needsUniv) {
  Location loc = env.op().getLoc();
  if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
    while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
               builder.getInsertionBlock()->getParentOp())) {
      // Do not finalize ifs generated by the loop emitter for slicing.
      if (ifOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ==
          StringAttr::get(ifOp->getContext(), "slice"))
        break;
      // `y` tracks the result index; the order must match genIf's types.
      unsigned y = 0;
      SmallVector<Value> yields;
      if (env.isReduc()) {
        yields.push_back(env.getReduc());
        env.updateReduc(ifOp.getResult(y++));
        if (env.isValidLexInsert()) {
          yields.push_back(env.getValidLexInsert());
          env.updateValidLexInsert(ifOp.getResult(y++));
        }
      }
      if (env.isExpand()) {
        yields.push_back(env.getExpandCount());
        env.updateExpandCount(ifOp->getResult(y++));
      }
      if (env.getInsertionChain()) {
        yields.push_back(env.getInsertionChain());
        env.updateInsertionChain(ifOp->getResult(y++));
      }
      assert(y == yields.size());
      builder.create<scf::YieldOp>(loc, yields);
      builder.setInsertionPointAfter(ifOp);
    }
  }
}
/// Emits the `scf.if` that guards the body for lattice point `p` inside a
/// co-iterating loop, and moves the insertion point to its then-block.
/// The condition is the conjunction, over all simple conditions of `p`, of
/// `coord == loopVar` for sparse levels and `true` for dense/undefined levels.
/// The if's result types are: reduction value, valid-lex flag, expansion
/// count, and insertion chain, for whichever of these are active.
static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
                       LatPointId p) {
  Location loc = env.op().getLoc();
  SmallVector<Type> types;
  Value cond;
  env.merger().foreachTensorLoopId(
      p, /*simple=*/true,
      [&](TensorLoopId b, TensorId tid, std::optional<Level> lvl, LevelType lt,
          bool isIdxRed) {
        if (isIdxRed) {
          // With index reduction, the loop->level mapping is not 1:1, so the
          // level type is recovered from the input tensor type.
          assert(lvl.has_value() && isUndefLT(lt));
          auto stt = getSparseTensorType(env.op().getInputs()[tid]);
          lt = stt.getLvlType(*lvl);
        }
        assert(curr == env.merger().loop(b));
        Value clause;
        if (lt.hasSparseSemantic()) {
          assert(lvl.has_value());
          const Value crd = env.emitter().getCoord(tid, *lvl);
          const Value lvar = env.getLoopVar(curr);
          clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 crd, lvar);
        } else {
          assert(lt.hasDenseSemantic() || isUndefLT(lt));
          clause = constantI1(builder, loc, true);
        }
        cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
      });
  // The result order must match finalizeWhileOp and endIf.
  if (env.isReduc()) {
    types.push_back(env.getReduc().getType());
    if (env.isValidLexInsert())
      types.push_back(env.getValidLexInsert().getType());
  }
  if (env.isExpand())
    types.push_back(builder.getIndexType());
  if (env.getInsertionChain())
    types.push_back(env.getInsertionChain().getType());
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  return ifOp;
}
/// Terminates the then-block of `ifOp` (created by genIf) by yielding the
/// current reduction, valid-lex flag (yielded as `true`), expansion count,
/// and insertion chain, for whichever are active. It then resets those
/// values in the environment to the given loop-entry values (`redInput`,
/// `validIns`, `cntInput`, `insInput`) and moves the insertion point to the
/// start of the else-block.
static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
                  Value redInput, Value cntInput, Value insInput,
                  Value validIns) {
  SmallVector<Value> operands;
  if (env.isReduc()) {
    operands.push_back(env.getReduc());
    env.updateReduc(redInput);
    if (env.isValidLexInsert()) {
      // Any branch that was taken yields a valid lex insert.
      operands.push_back(constantI1(builder, env.op().getLoc(), true));
      env.updateValidLexInsert(validIns);
    }
  }
  if (env.isExpand()) {
    operands.push_back(env.getExpandCount());
    env.updateExpandCount(cntInput);
  }
  if (env.getInsertionChain()) {
    operands.push_back(env.getInsertionChain());
    env.updateInsertionChain(insInput);
  }
  if (!operands.empty())
    builder.create<scf::YieldOp>(env.op().getLoc(), operands);
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
}
/// Reports, through `callback`, every tensor level involved in lattice point
/// `li` at loop `curr`.
/// - Plain levels are reported with a null AffineExpr.
/// - Dense levels of annotated inputs whose compound index expression becomes
///   invariant exactly at `curr` are reported together with that expression.
/// - When no loop condition is found, the synthetic tensor level at `curr` is
///   reported.
/// Returns true when the loop needs a single condition on unique levels, or
/// when the sparse-iterator emit strategy is in use.
static bool getAllTidLvlsInLatPoints(
    CodegenEnv &env, LatPointId li, LoopId curr,
    llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
  const BitVector &simple = env.lat(li).simple;
  const TensorId outTid = env.merger().getOutTensorID();
  const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);
  unsigned numloopCond = 0;
  bool hasNonUnique = false;
  env.merger().foreachTensorLoopId(
      li, [&, curr](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
                    LevelType lt, bool isIdxReduc) {
        if (simple[b]) {
          // A loop condition.
          if (isIdxReduc) {
            callback(env.makeTensorLevel(tid, *lvl), nullptr);
            numloopCond++;
            return;
          }
          if (isUndefLT(lt)) {
            if (env.merger().getSynTensorID() == tid) {
              // The synthetic tensor's level is the current loop depth.
              assert(curr == env.getCurrentDepth());
              lvl = curr;
            } else if (!lvl) {
              // No level for this tensor at this loop; skip it.
              return;
            }
          }
          hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
          callback(env.makeTensorLevel(tid, *lvl), nullptr);
          numloopCond++;
        } else if (lt.hasDenseSemantic() || isIdxReduc) {
          callback(env.makeTensorLevel(tid, *lvl), nullptr);
        } else {
          assert(isUndefLT(lt));
          linalg::GenericOp op = env.op();
          // Only input tensors are handled here.
          if (tid >= op.getNumDpsInputs())
            return;
          OpOperand *operand = &op->getOpOperand(tid);
          const auto stt = getSparseTensorType(operand->get());
          // Non-annotated tensors need no special handling.
          if (!stt.hasEncoding())
            return;
          ArrayRef<AffineExpr> affines =
              op.getMatchingIndexingMap(operand).getResults();
          const Level lvlRank = stt.getLvlRank();
          assert(affines.size() == static_cast<size_t>(lvlRank));
          for (Level l = 0; l < lvlRank; l++) {
            AffineExpr exp = affines[l];
            // Skip plain dimension expressions and non-dense levels.
            LevelType lt = stt.getLvlType(l);
            if (isa<AffineDimExpr>(exp) || !lt.hasDenseSemantic())
              continue;
            // Constant expressions are handled elsewhere
            // (genConstantDenseAddressFromLevel).
            if (!isa<AffineConstantExpr>(exp)) {
              bool isCurrentLoop = false;
              assert(curr == env.getCurrentDepth());
              // Report the level if its compound expression becomes
              // invariant exactly at this loop.
              if (isInvariantAffine(exp, curr + 1, /*out*/ isCurrentLoop) &&
                  isCurrentLoop) {
                callback(env.makeTensorLevel(tid, l), exp);
              }
            }
          }
        }
      });
  // The level of an all-dense annotated output is reported unconditionally,
  // so it may duplicate a level reported above.
  if (isDenseLT(env.lt(outTid, curr))) {
    auto stt = getSparseTensorType(env.op().getOutputs().front());
    if (stt.hasEncoding() && stt.isAllDense())
      callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
  }
  // Without any loop condition, iterate over the synthetic tensor instead.
  if (numloopCond == 0) {
    callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr);
    numloopCond++;
  }
  return numloopCond == 1 &&
         (!hasNonUnique || env.options().sparseEmitStrategy ==
                               SparseEmitStrategy::kSparseIterator);
}
/// Starts the loop sequence for loop `curr` over lattice set `lts`:
/// - emits the invariants and the access pattern expansion for this level;
/// - enters a new loop sequence over the tensor levels of the first lattice
///   point.
/// Returns true when a later lattice point has no sparse condition, i.e. the
/// universal index must be maintained.
static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                         LoopId curr, LatSetId lts) {
  assert(!env.getLoopVar(curr));
  genInvariants(env, builder, exp, curr, /*isStart=*/true);
  genExpand(env, builder, curr, /*isStart=*/true);
  const LatPointId l0 = env.set(lts)[0];
  SmallVector<TensorLevel> tidLvls;
  getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
    // The same tensor level can be reported more than once (for example the
    // unconditionally reported level of an all-dense annotated output), so
    // keep only its first occurrence.
    if (!llvm::is_contained(tidLvls, tl))
      tidLvls.push_back(tl);
  });
  env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
  // Maintain the universal index only if a subsequent lattice point without
  // any sparse condition actually consumes it.
  return llvm::any_of(env.set(lts).drop_front(), [&](LatPointId li) {
    return !env.merger().hasAnySparse(env.lat(li).simple);
  });
}
/// Starting at `startLvl`, locates the address of every consecutive level of
/// input `tid` that is dense and indexed by a constant affine expression.
/// It stops at the first level that is not both. Non-annotated inputs are
/// ignored.
static void genConstantDenseAddressFromLevel(CodegenEnv &env,
                                             OpBuilder &builder, TensorId tid,
                                             Level startLvl) {
  // TODO: Handle affine expression on output tensor.
  linalg::GenericOp op = env.op();
  assert(tid < op.getNumDpsInputs());
  // Access the input directly instead of materializing the vector of all
  // input operands on every call.
  OpOperand *input = op.getDpsInputOperand(tid);
  const auto enc = getSparseTensorEncoding(input->get().getType());
  if (!enc)
    return;
  const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
  const Location loc = op.getLoc();
  // Distinct name: the original code shadowed the `tid` parameter here.
  const TensorId inputTid = env.makeTensorId(input->getOperandNumber());
  const Level lvlRank = enc.getLvlRank();
  assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
  for (Level l = startLvl; l < lvlRank; l++) {
    const AffineExpr lvlExpr = lvlExprs[l];
    if (!enc.getLvlType(l).hasDenseSemantic() ||
        !isa<AffineConstantExpr>(lvlExpr))
      return; // Stop at the first non-dense or non-constant level.
    env.emitter().locateLvlAtAffineAddress(
        builder, loc, env.makeTensorLevel(inputTid, l), lvlExpr);
  }
}
/// Locates the leading constant-indexed dense levels of every input before
/// any loop is generated.
static void genInitConstantDenseAddress(CodegenEnv &env,
                                        RewriterBase &rewriter) {
  const TensorId numInputs = env.op().getNumDpsInputs();
  for (TensorId tid = 0; tid < numInputs; ++tid)
    genConstantDenseAddressFromLevel(env, rewriter, tid, /*startLvl=*/0);
}
/// Splits the tensor levels of lattice point `li` at loop `curr` into two
/// lists:
/// - plain levels (`tidLvls`);
/// - levels whose compound affine index becomes invariant here
///   (`affineTidLvls`).
/// Returns the single-condition flag from getAllTidLvlsInLatPoints.
static bool translateBitsToTidLvlPairs(
    CodegenEnv &env, LatPointId li, LoopId curr,
    SmallVectorImpl<TensorLevel> &tidLvls,
    SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
  auto collect = [&](TensorLevel tl, AffineExpr exp) {
    if (!exp) {
      tidLvls.emplace_back(tl);
      return;
    }
    affineTidLvls.emplace_back(tl, exp);
  };
  return getAllTidLvlsInLatPoints(env, li, curr, collect);
}
/// Starts the loop for lattice point `li` at loop `curr`. It emits the loop,
/// locates levels with a compound affine index that became invariant, and
/// locates any following constant-indexed dense levels of the involved
/// inputs. Returns the loop and whether it needs a single condition only.
static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
                                              OpBuilder &builder, LoopId curr,
                                              LatPointId li, bool needsUniv) {
  // Tensor levels that drive the loop.
  SmallVector<TensorLevel> tidLvls;
  // Dense levels whose compound affine index becomes invariant here.
  SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls;
  bool isSingleCond =
      translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls);
  // Emit the loop control.
  Operation *loop = genLoop(env, builder, curr, needsUniv, tidLvls);
  Location loc = env.op().getLoc();
  for (auto [tidLvl, exp] : affineTidLvls) {
    env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp);
  }
  // With all levels above entered, the addresses of the subsequent
  // constant-indexed dense levels of each input can now be located.
  auto allTidLvls =
      llvm::concat<TensorLevel>(tidLvls, llvm::make_first_range(affineTidLvls));
  for (auto [tid, lvl] : env.unpackTensorLevelRange(allTidLvls)) {
    if (tid != env.merger().getOutTensorID() &&
        tid != env.merger().getSynTensorID())
      genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1);
  }
  return std::make_pair(loop, isSingleCond);
}
/// Ends the loop `loop` started for a lattice point and exits it through the
/// loop emitter. Behavior depends on the loop shape:
/// - Single-condition loops mark the valid-lex flag as set during a
///   reduction.
/// - Multi-condition while-loops finalize their pending if statements.
/// - Any other loop clears `needsUniv`.
/// Returns the updated `needsUniv`. `li` is currently unused.
static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
                    LatPointId li, bool needsUniv, bool isSingleCond) {
  if (isSingleCond) {
    // Any iteration produces a valid lex insert.
    if (env.isReduc() && env.isValidLexInsert())
      env.updateValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
  } else if (isa<scf::WhileOp>(loop)) {
    // End a while-loop (the op itself is not needed, only its kind).
    finalizeWhileOp(env, rewriter, needsUniv);
  } else {
    needsUniv = false;
  }
  env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
    env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
    return std::nullopt;
  });
  return needsUniv;
}
static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
unsigned at) {
assert(!env.getLoopVar(at));
env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc());
genInvariants(env, builder, exp, at, false);
genExpand(env, builder, at, false);
}
/// Recursively generates the loop nest for expression `exp` starting at loop
/// `curr`. At the innermost level the expression is evaluated and stored.
/// Otherwise one loop is emitted per lattice point. Each loop's body covers
/// every dominated lattice point, guarded by ifs when the loop needs more
/// than one condition.
static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
                    LoopId curr) {
  assert(curr == env.getCurrentDepth());
  // At each leaf, assign the remaining (sub)expression to the output.
  if (curr == env.getLoopNum()) {
    Value rhs = genExp(env, rewriter, exp);
    genTensorStore(env, rewriter, exp, rhs);
    return;
  }
  // Construct the iteration lattices for the current loop index.
  const LatSetId lts =
      env.merger().optimizeSet(env.merger().buildLattices(exp, curr));
  bool needsUniv = startLoopSeq(env, rewriter, exp, curr, lts);
  // Emit a loop for every lattice point in this loop sequence.
  const unsigned lsize = env.set(lts).size();
  for (unsigned i = 0; i < lsize; i++) {
    const LatPointId li = env.set(lts)[i];
    auto [loop, isSingleCond] = startLoop(env, rewriter, curr, li, needsUniv);
    // Loop-entry values, restored at the start of each else-branch.
    Value redInput = env.getReduc();
    Value cntInput = env.getExpandCount();
    Value insInput = env.getInsertionChain();
    Value validIns = env.getValidLexInsert();
    // Visit all lattice points Lj with Li >= Lj to build the loop body.
    for (unsigned j = 0; j < lsize; j++) {
      const LatPointId lj = env.set(lts)[j];
      const ExprId ej = env.lat(lj).exp;
      if (li == lj || env.merger().latGT(li, lj)) {
        if (!isSingleCond) {
          scf::IfOp ifOp = genIf(env, rewriter, curr, lj);
          genStmt(env, rewriter, ej, curr + 1);
          endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
        } else {
          genStmt(env, rewriter, ej, curr + 1);
        }
      }
    }
    // endLoop takes the lattice point of this loop (a LatPointId), not the
    // loop id `curr`.
    needsUniv = endLoop(env, rewriter, loop, li, needsUniv, isSingleCond);
  }
  endLoopSeq(env, rewriter, exp, curr);
  assert(curr == env.getCurrentDepth());
}
/// Replaces the linalg op with the materialized result tensor.
/// - An annotated output is rematerialized with a LoadOp. The load reads
///   from the insertion chain (with `hasInserts` set) when one exists, and
///   from the original output operand otherwise.
/// - A non-annotated output is converted back to a tensor from its value
///   buffer.
static void genResult(CodegenEnv &env, RewriterBase &rewriter) {
  linalg::GenericOp op = env.op();
  Value tensor = op.getDpsInitOperand(0)->get();
  Type resType = tensor.getType();
  if (!getSparseTensorEncoding(resType)) {
    Value buf = env.emitter().getValBuffer()[env.merger().getOutTensorID()];
    rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, buf);
    return;
  }
  Value chain = env.getInsertionChain();
  const bool hasInserts = static_cast<bool>(chain);
  Value source = hasInserts ? chain : tensor;
  rewriter.replaceOpWithNewOp<LoadOp>(op, resType, source, hasInserts);
}
namespace {
/// Rewrites a sorted, single-output `linalg.generic` with pure tensor
/// semantics into explicit (sparse) loop code.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Only single-output operations with pure tensor semantics.
    if (op.getNumDpsInits() != 1 || !op.hasPureTensorSemantics())
      return failure();
    // Only trivial index expressions on a sparse output.
    if (hasNonTrivialAffineOnSparseOut(op))
      return failure();
    // Only scheduled loops.
    if (!op->hasAttr("sorted")) {
      return rewriter.notifyMatchFailure(
          op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
              "before sparsification.");
    }
    assert(!hasAnyNonIdentityOperandsOrResults(op));
    // Set up the code generation environment.
    const unsigned numTensors = op->getNumOperands();
    const unsigned numLoops = op.getNumLoops();
    bool needIdxRed = getNumNonTrivialIdxExpOnSparseLvls(op) != 0;
    // A tensor may have more levels than the op has loops (e.g. with constant
    // index expressions), so bound the level rank by the operand types.
    Level maxLvlRank = 0;
    for (auto operand : op.getOperands()) {
      if (auto rtp = dyn_cast<RankedTensorType>(operand.getType())) {
        maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank());
      }
    }
    CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
    if (!findSparseAnnotations(env, needIdxRed))
      return failure();
    if (op.getNumReductionLoops() > 0 && !isAdmissibleReduction(op))
      return failure();
    // Build the tensor expression tree; fail if it is inadmissible.
    if (failed(env.initTensorExp()))
      return failure();
    // Generate code.
    env.startEmit(options.sparseEmitStrategy);
    genBuffers(env, rewriter);
    genInitConstantDenseAddress(env, rewriter);
    genStmt(env, rewriter, env.getExprId(), 0);
    genResult(env, rewriter);
    return success();
  }

private:
  /// Returns true when the value yielded by `op` is produced by one of the
  /// accepted reduction ops: add/sub/or/xor, or a custom ReduceOp.
  /// The yielded value may be a block argument (e.g. `linalg.yield %in`),
  /// which has no defining op. That case is rejected instead of passing a
  /// null pointer to `isa<>`, which LLVM does not allow.
  static bool isAdmissibleReduction(linalg::GenericOp op) {
    Operation *yield = op.getRegion().front().getTerminator();
    assert(isa<linalg::YieldOp>(yield));
    Operation *redop = yield->getOperand(0).getDefiningOp();
    if (!redop)
      return false;
    return isa<arith::AddFOp, complex::AddOp, arith::AddIOp, arith::SubFOp,
               complex::SubOp, arith::SubIOp, arith::OrIOp, arith::XOrIOp,
               ReduceOp>(redop);
  }

  /// Options that control sparse code generation.
  SparsificationOptions options;
};
} // namespace
/// Registers the `linalg.generic` sparsification rewrite pattern, configured
/// with `options`, into `patterns`.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<GenericOpSparsifier>(ctx, options);
}