#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"
using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::linalg;
using namespace mlir::sparse_tensor;
/// Returns true when `val` matches either the `m_Zero()` matcher or the
/// `m_AnyZeroFloat()` matcher.
static bool isZeroValue(Value val) {
  if (matchPattern(val, m_Zero()))
    return true;
  return matchPattern(val, m_AnyZeroFloat());
}
/// Returns true when the type of `v` carries a sparse tensor encoding in which
/// at least one level type is not `LevelFormat::Dense`.
static bool isSparseTensor(Value v) {
  auto encoding = getSparseTensorEncoding(v.getType());
  if (!encoding)
    return false;
  return llvm::any_of(encoding.getLvlTypes(), [](auto lvlType) {
    return !(lvlType == LevelFormat::Dense);
  });
}
/// Operand overload: applies the `Value` overload to the value held by `op`.
static bool isSparseTensor(OpOperand *op) {
  Value held = op->get();
  return isSparseTensor(held);
}
/// Checks whether the value held by `op` is a tensor materialization.
///
/// With `isZero` set, the value must be either an `AllocTensorOp` whose copy
/// operand is a zero value, or itself a zero value. With `isZero` unset, the
/// value must be either an `AllocTensorOp` without a copy operand or a
/// `tensor::EmptyOp`.
static bool isMaterializing(OpOperand *op, bool isZero) {
  Value source = op->get();
  // Allocation: the copy operand decides which flavor this is.
  if (auto allocOp = source.getDefiningOp<AllocTensorOp>()) {
    Value initCopy = allocOp.getCopy();
    return isZero ? (initCopy && isZeroValue(initCopy)) : !initCopy;
  }
  // An empty tensor only qualifies when no zero is required.
  if (source.getDefiningOp<tensor::EmptyOp>())
    return !isZero;
  // Anything else only qualifies as a zero materialization of a zero value.
  return isZero && isZeroValue(source);
}
/// Returns true when the first yielded value of `op` is produced by an
/// `arith.mulf` / `arith.muli` whose two operands are exactly the first two
/// block arguments of `op`, in either order.
static bool isSampling(GenericOp op) {
  Block &body = op.getRegion().front();
  auto yieldOp = cast<linalg::YieldOp>(body.getTerminator());
  Operation *producer = yieldOp.getOperand(0).getDefiningOp();
  if (!producer || !isa<arith::MulFOp, arith::MulIOp>(producer))
    return false;
  Value first = op.getBlock()->getArgument(0);
  Value second = op.getBlock()->getArgument(1);
  Value lhs = producer->getOperand(0);
  Value rhs = producer->getOperand(1);
  const bool inOrder = lhs == first && rhs == second;
  const bool swapped = rhs == first && lhs == second;
  return inOrder || swapped;
}
/// Returns true when `val` is a tree of `arith.mulf` / `arith.muli` operations
/// whose leaves are all block arguments different from `x`. Any other leaf or
/// any other defining operation yields false.
static bool isMulChain(Value val, Value x) {
  if (auto blockArg = dyn_cast<BlockArgument>(val))
    return blockArg != x;
  Operation *producer = val.getDefiningOp();
  if (!producer || !isa<arith::MulFOp, arith::MulIOp>(producer))
    return false;
  if (!isMulChain(producer->getOperand(0), x))
    return false;
  return isMulChain(producer->getOperand(1), x);
}
/// Returns true when the first yielded value of `op` is produced by an
/// `arith.addf` / `arith.addi` with one operand equal to the last block
/// argument `x` and the other operand a multiplication chain (see
/// `isMulChain`) that does not involve `x`.
static bool isSumOfMul(GenericOp op) {
  Block &body = op.getRegion().front();
  auto yieldOp = cast<linalg::YieldOp>(body.getTerminator());
  Operation *sum = yieldOp.getOperand(0).getDefiningOp();
  if (!sum || !isa<arith::AddFOp, arith::AddIOp>(sum))
    return false;
  Value x = op.getBlock()->getArguments().back();
  Value lhs = sum->getOperand(0);
  Value rhs = sum->getOperand(1);
  if (lhs == x && isMulChain(rhs, x))
    return true;
  return rhs == x && isMulChain(lhs, x);
}
/// Returns true when the first yielded value of `op` is a zero value. When the
/// yielded value is a block argument owned by a block of `op`, the check is
/// applied to the operand of `op` at the same index as the argument instead.
static bool isZeroYield(GenericOp op) {
  Block &body = op.getRegion().front();
  auto yieldOp = cast<linalg::YieldOp>(body.getTerminator());
  Value candidate = yieldOp.getOperand(0);
  auto blockArg = dyn_cast<BlockArgument>(candidate);
  if (blockArg && blockArg.getOwner()->getParentOp() == op)
    candidate = op->getOperand(blockArg.getArgNumber());
  return isZeroValue(candidate);
}
/// Appends one size value per dimension of `stp` to `sizes`: a constant index
/// for a static extent, and a `tensor.dim` on `tensor` for a dynamic extent.
static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
                           Location loc, ShapedType stp, Value tensor) {
  ArrayRef<int64_t> shape = stp.getShape();
  for (size_t i = 0, rank = shape.size(); i < rank; ++i) {
    const int64_t extent = shape[i];
    Value dimSize;
    if (ShapedType::isDynamic(extent))
      dimSize = builder.create<tensor::DimOp>(loc, tensor, i);
    else
      dimSize = constantIndex(builder, loc, extent);
    sizes.push_back(dimSize);
  }
}
/// Picks the tensor type used for a buffer: `stt.getCOOType(false)` when
/// `needTmpCOO` is set, and `stt.getRankedTensorType()` otherwise.
static RankedTensorType getBufferType(const SparseTensorType &stt,
                                      bool needTmpCOO) {
  if (needTmpCOO)
    return stt.getCOOType(false);
  return stt.getRankedTensorType();
}
/// Appends to `dynSizes` the entries of `sizes` whose position corresponds to a
/// dynamic dimension in the shape of `tp`. `sizes` is indexed by dimension, so
/// it is expected to hold one entry per dimension of `tp`.
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes,
                            SmallVectorImpl<Value> &dynSizes) {
  ArrayRef<int64_t> shape = tp.getShape();
  for (size_t i = 0, rank = shape.size(); i < rank; ++i)
    if (ShapedType::isDynamic(shape[i]))
      dynSizes.push_back(sizes[i]);
}
/// Rewrites the `ForeachOp` `op` for the sparse constant `attr`.
///
/// `foreachInSparseConstant` is handed a callback (presumably invoked once per
/// stored element of `attr` -- confirm against its definition). Every callback
/// invocation inlines one copy of the foreach body in front of `op`, feeding it
/// the coordinates `cvs`, the element value `v` and the current reduction
/// values. The operands of the inlined yield become the reduction values for
/// the next invocation, and the final reduction values replace `op`.
///
/// Always returns success.
static LogicalResult genForeachOnSparseConstant(ForeachOp op,
RewriterBase &rewriter,
SparseElementsAttr attr) {
auto loc = op.getLoc();
// Running reduction values, seeded with the init arguments of the op.
SmallVector<Value> reduc = op.getInitArgs();
foreachInSparseConstant(
rewriter, loc, attr, op.getOrder().value_or(AffineMap()),
[&reduc, &rewriter, op](ArrayRef<Value> cvs, Value v) mutable {
// Values substituted for the body's block arguments, in the order:
// coordinates, element value, reduction values.
SmallVector<Value> args;
args.append(cvs.begin(), cvs.end());
args.push_back(v);
args.append(reduc);
// Clone the whole foreach op so that its body can be inlined without
// touching the body of the original op.
auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
assert(args.size() == cloned.getBody()->getNumArguments());
// Grab the terminator before inlining moves it out of the cloned block.
Operation *yield = cloned.getBody()->getTerminator();
rewriter.inlineBlockBefore(cloned.getBody(), op, args);
// The clone is no longer needed once its body has been inlined.
rewriter.eraseOp(cloned);
// Thread the yielded values into the next copy, then drop the yield.
reduc = yield->getOperands();
rewriter.eraseOp(yield);
});
rewriter.replaceOp(op, reduc);
return success();
}
/// Fills `sizes` with the sizes of a concatenation result along every
/// dimension. All sizes are first taken from `srcs[0]` via `sizesFromSrc`; the
/// entry for `dim` is then overwritten with the static extent of `dstTp` when
/// that extent is static, or otherwise extended by adding the `dim` size of
/// every remaining source.
static void concatSizesFromInputs(OpBuilder &builder,
                                  SmallVectorImpl<Value> &sizes, Location loc,
                                  ShapedType dstTp, ValueRange srcs,
                                  unsigned dim) {
  sizesFromSrc(builder, sizes, loc, srcs[0]);

  const int64_t dstExtent = dstTp.getShape()[dim];
  if (!ShapedType::isDynamic(dstExtent)) {
    // Static result extent: use it directly.
    sizes[dim] = constantIndex(builder, loc, dstExtent);
    return;
  }

  // Dynamic result extent: accumulate the sizes of the remaining sources.
  Value total = sizes[dim];
  for (Value src : srcs.drop_front()) {
    Value srcExtent = linalg::createOrFoldDimOp(builder, loc, src, dim);
    total = builder.create<arith::AddIOp>(loc, total, srcExtent);
  }
  sizes[dim] = total;
}
namespace {
/// Pattern on `tensor.extract_slice` whose source is produced by a
/// `tensor.concat`: when the slice's offsets, sizes and strides are all
/// constants that coincide with the region occupied by one of the concat
/// inputs, and that input has the same type as the slice result, the slice is
/// replaced by that input.
struct FuseExtractSliceWithConcat
: public OpRewritePattern<tensor::ExtractSliceOp> {
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
PatternRewriter &rewriter) const override {
auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>();
if (!concatOp)
return failure();
Location loc = extractOp.getLoc();
int64_t dim = concatOp.getDim();
int64_t rank = extractOp.getResultType().getRank();
// Expected strides (all 1) and offsets (all 0, except along `dim`, which is
// filled in per input below) of a slice that covers one concat input.
SmallVector<OpFoldResult> srcStrides(rank, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> srcOffsets(rank, rewriter.getIndexAttr(0));
// Build an affine map with one result per input: result k is d0 + ... + dk.
// Its operands are 0 followed by the `dim` sizes of all inputs but the last,
// so result k is the offset of input k along `dim`.
AffineExpr sum = rewriter.getAffineDimExpr(0);
SmallVector<AffineExpr> partialSums = {sum};
SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
for (auto [idx, input] :
llvm::enumerate(concatOp.getInputs().drop_back())) {
sum = sum + rewriter.getAffineDimExpr(idx + 1);
partialSums.push_back(sum);
offsetStrides.push_back(
rewriter.createOrFold<tensor::DimOp>(loc, input, dim));
}
auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
partialSums, rewriter.getContext());
SmallVector<OpFoldResult> dimOffsets =
affine::makeComposedFoldedMultiResultAffineApply(
rewriter, loc, partialSumMap, offsetStrides);
// Two lists are equal only if they have the same length and every pair of
// entries is a pair of identical constants; a non-constant entry on the
// left-hand side never compares equal.
auto allEqual = [](ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) {
for (auto [l, r] : llvm::zip(lhs, rhs)) {
std::optional<int64_t> staticVal = getConstantIntValue(l);
if (!staticVal.has_value() || staticVal != getConstantIntValue(r))
return false;
}
return lhs.size() == rhs.size();
};
// Look for the input whose region matches the slice exactly.
for (auto [i, input, offset] :
llvm::enumerate(concatOp.getInputs(), dimOffsets)) {
SmallVector<OpFoldResult> srcSizes =
tensor::getMixedSizes(rewriter, loc, input);
srcOffsets[dim] = offset;
SmallVector<OpFoldResult> dstSizes = extractOp.getMixedSizes();
SmallVector<OpFoldResult> dstOffsets = extractOp.getMixedOffsets();
SmallVector<OpFoldResult> dstStrides = extractOp.getMixedStrides();
if (allEqual(srcSizes, dstSizes) && allEqual(srcOffsets, dstOffsets) &&
allEqual(srcStrides, dstStrides)) {
Value operand = concatOp.getOperand(i);
// Only replace when no type change (e.g. an encoding difference) is
// involved; the search stops at the first matching region either way.
if (operand.getType() == extractOp.getResultType())
rewriter.replaceOp(extractOp, operand);
break;
}
}
// NOTE(review): success() is returned even when no input matched or the
// matching input had a different type, i.e. when `extractOp` was not
// replaced -- confirm this is intended.
return success();
}
};
/// Folds a `ConvertOp` into the `linalg.generic` that produces its source.
///
/// Applies when the producer has exactly one init operand, that operand is a
/// non-zero materialization (see `isMaterializing`), and the producer's result
/// has a single use. The materialization is cloned with the convert's result
/// type, the producer is updated in place to write into the clone and to
/// produce the convert's result type, and the convert is replaced by the
/// producer's results.
struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ConvertOp op,
                                PatternRewriter &rewriter) const override {
    auto producer = op.getSource().getDefiningOp<GenericOp>();
    if (!producer || producer.getDpsInits().size() != 1 ||
        !isMaterializing(producer.getDpsInitOperand(0), false) ||
        !producer.getResult(0).hasOneUse()) {
      return failure();
    }
    // Clone the materialization in front of the producer, retyped to the
    // result type of the convert.
    rewriter.setInsertionPoint(producer);
    Operation *init = producer.getDpsInitOperand(0)->get().getDefiningOp();
    Operation *cloned = rewriter.clone(*init);
    cloned->getResult(0).setType(op.getResult().getType());

    rewriter.modifyOpInPlace(producer, [&]() {
      producer.getDpsInitsMutable().assign(cloned->getResults());
      producer.getResult(0).setType(op.getResult().getType());
    });

    // Replace and erase through the rewriter (rather than erasing the
    // operation directly) so that the pattern driver is notified.
    rewriter.replaceOp(op, producer->getResults());
    return success();
  }
};
/// Folds a `linalg.generic` that simply yields a zero value into a fresh
/// materialization.
///
/// Applies when the op has pure tensor semantics, a single result, an init
/// operand that is a non-zero materialization with exactly one use, and a zero
/// yield (see `isZeroYield`). With a sparse result type the op is replaced by
/// its init operand; with a statically shaped non-sparse result type it is
/// replaced by a zero constant of the result type and the defining op of the
/// init operand is erased. Other result types are rejected.
struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.hasPureTensorSemantics() || op.getNumResults() != 1)
      return failure();
    OpOperand *initOperand = op.getDpsInitOperand(0);
    if (!isMaterializing(initOperand, false) || !isZeroYield(op) ||
        !initOperand->get().hasOneUse())
      return failure();

    Value init = initOperand->get();
    auto resultType = getRankedTensorType(op.getResult(0));
    // Sparse result: the materialization itself is the answer.
    if (getSparseTensorEncoding(resultType)) {
      rewriter.replaceOp(op, init);
      return success();
    }
    // Non-sparse result: only static shapes can become a constant.
    if (!resultType.hasStaticShape())
      return failure();
    Operation *initDef = init.getDefiningOp();
    Value zero = constantZero(rewriter, op.getLoc(), resultType);
    rewriter.replaceOp(op, zero);
    rewriter.eraseOp(initDef);
    return success();
  }
};
/// Fuses a "sampling" consumer `linalg.generic` (an element-wise multiply of
/// its two inputs, see `isSampling`) with the `linalg.generic` that produces
/// one of its inputs (a sum of a multiplication chain, see `isSumOfMul`) into a
/// single `linalg.generic`, in which the sampling multiply is applied to the
/// producer's multiplication chain before it is accumulated.
struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
public:
using OpRewritePattern<GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOp op,
PatternRewriter &rewriter) const override {
// Consumer requirements: pure tensor semantics, two inputs, one result, all
// loops parallel, and identity indexing maps on the init and both inputs.
if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 2 ||
op.getNumResults() != 1 ||
op.getNumParallelLoops() != op.getNumLoops() ||
!op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() ||
!op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() ||
!op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity())
return failure();
// At least one input must be sparse. `other` is the index of the input that
// is expected to come from the producer; `1 - other` is the sparse input.
unsigned other = 0;
if (isSparseTensor(op.getDpsInputOperand(0)))
other = 1;
else if (!isSparseTensor(op.getDpsInputOperand(1)))
return failure();
// Producer requirements: a generic op with pure tensor semantics and a
// single result that is used exactly once.
auto prod = dyn_cast_or_null<GenericOp>(
op.getDpsInputOperand(other)->get().getDefiningOp());
if (!prod || !prod.hasPureTensorSemantics() || prod.getNumResults() != 1 ||
!prod.getResult(0).hasOneUse())
return failure();
// The consumer must sample into a non-zero materialization and the producer
// must accumulate a multiplication chain into a zero materialization.
if (!isMaterializing(op.getDpsInitOperand(0), false) ||
!isMaterializing(prod.getDpsInitOperand(0), true) ||
!isSampling(op) || !isSumOfMul(prod))
return failure();
// Operands of the fused op: the producer's inputs followed by the sparse
// input of the consumer, with the consumer's outputs.
Location loc = prod.getLoc();
SmallVector<Value> inputOps = prod.getInputs();
SmallVector<Value> outputOps = op.getOutputs();
SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
inputOps.push_back(op.getDpsInputOperand(1 - other)->get());
// Duplicate the last indexing map; together with the extra input appended
// above, the extra input ends up with the map that was the producer's last.
fusedIndexMaps.push_back(fusedIndexMaps.back());
auto fusedOp = rewriter.create<GenericOp>(
loc, op.getResult(0).getType(), inputOps, outputOps,
rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(),
nullptr, nullptr);
Block &prodBlock = prod.getRegion().front();
Block &consBlock = op.getRegion().front();
IRMapping mapper;
// Block arguments of the fused body, in order: all producer arguments except
// the last, the consumer argument of the sparse input, and finally the
// producer's last argument.
Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
unsigned num = prodBlock.getNumArguments();
for (unsigned i = 0; i < num - 1; i++)
addArg(mapper, fusedBlock, prodBlock.getArgument(i));
addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
// `acc` is the producer's yielded operation, `sampler` the consumer's.
auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
// Clone every producer operation except `acc`; `last` tracks the result of
// the most recently cloned original operation.
Value last;
for (auto &op : prodBlock.without_terminator())
if (&op != acc) {
last = op.getResult(0);
rewriter.clone(op, mapper);
}
// Feed the result of the last cloned operation into the sampler in place of
// the consumer argument that used to carry the producer's result, then make
// the accumulation consume the sampled value instead of `last`.
mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
last = rewriter.clone(*acc, mapper)->getResult(0);
rewriter.create<linalg::YieldOp>(loc, last);
// For a non-sparse result, carry the copy operand of the producer's init
// allocation over to the consumer's init allocation.
// NOTE(review): both inits are assumed to be defined by AllocTensorOp here;
// isMaterializing also accepts other definitions -- confirm.
if (!getSparseTensorEncoding(op.getResult(0).getType())) {
Value init = prod.getDpsInitOperand(0)
->get()
.getDefiningOp<AllocTensorOp>()
.getCopy();
AllocTensorOp a =
op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
rewriter.modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); });
}
rewriter.replaceOp(op, fusedOp->getResults());
return success();
}
private:
// Adds an argument with the type and location of `a` to block `b` and maps
// `a` to the new argument in `mapper`.
static void addArg(IRMapping &mapper, Block *b, BlockArgument a) {
mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
}
};
/// Rewrites `tensor.cast` operations:
///  - a cast between identical types is replaced by its source;
///  - a cast that only changes the encoding, whose source is a single-use
///    `tensor.extract_slice`, is fused by retyping the slice result;
///  - any other cast with a sparse source or destination type is replaced by
///    a `ConvertOp`.
/// All remaining casts are left untouched (failure).
struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
public:
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    Type srcType = op.getSource().getType();
    Type dstType = op.getDest().getType();
    // A cast between identical types is a no-op: forward the source value.
    // Replacing the op with its own results would leave all uses referring to
    // the operation that is being erased.
    if (srcType == dstType) {
      rewriter.replaceOp(op, op.getSource());
      return success();
    }

    // Try to fuse an encoding-only cast into a producing extract_slice.
    if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) {
      if (Operation *def = op.getSource().getDefiningOp()) {
        if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
          rewriter.modifyOpInPlace(def, [&]() {
            def->getResult(0).setType(op->getResultTypes()[0]);
          });
          rewriter.replaceOp(op, def->getResult(0));
          return success();
        }
      }
    }

    // Casts that involve a sparse type become a sparse tensor conversion.
    if (getSparseTensorEncoding(srcType) || getSparseTensorEncoding(dstType)) {
      rewriter.replaceOpWithNewOp<ConvertOp>(op, dstType, op.getSource());
      return success();
    }
    return failure();
  }
};
/// Rewrites qualifying `arith.select` operations inside the body of a
/// `linalg.generic` with at least one sparse operand into
/// `sparse_tensor.binary` operations. Each of the three regions of the binary
/// op receives a clone of the select in which the value(s) not supplied by
/// that region are replaced by a zero constant.
struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
public:
using OpRewritePattern<GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOp op,
PatternRewriter &rewriter) const override {
if (!op.hasPureTensorSemantics() || !hasAnySparseOperand(op))
return failure();
Location loc = op.getLoc();
// Pairs of (original select, replacement binary op), applied after the loop.
SmallVector<std::pair<Operation *, sparse_tensor::BinaryOp>> semiRings;
for (Operation &inst : *op.getBody()) {
auto matched = isRewritablePattern(op, &inst);
if (!matched.has_value())
continue;
rewriter.setInsertionPoint(&inst);
// c: select condition, t / f: true / false values (both block arguments).
auto [c, t, f] = matched.value();
assert(t.getType() == f.getType());
auto selTp = t.getType();
auto c0 = constantZero(rewriter, loc, selTp);
auto binOp = rewriter.create<sparse_tensor::BinaryOp>(loc, selTp, t, f);
// Create the entry blocks: the overlap region takes two arguments, the
// left and right regions take one argument each.
rewriter.createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
{t.getLoc(), f.getLoc()});
rewriter.createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
rewriter.createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());
for (auto *r : binOp.getRegions()) {
Block *b = &r->front();
rewriter.setInsertionPointToStart(b);
IRMapping irMap;
// When the condition is computed by an operation, clone that operation
// into the region and use the clone's result as the condition.
Value newC = c;
if (auto *def = c.getDefiningOp())
newC = rewriter.clone(*def, irMap)->getResult(0);
irMap.map(c, newC);
// Left region: only `t` is supplied, `f` becomes zero. Right region: only
// `f` is supplied, `t` becomes zero. Overlap region: both are supplied.
if (r == &binOp.getLeftRegion()) {
irMap.map(t, b->getArgument(0));
irMap.map(f, c0);
} else if (r == &binOp.getRightRegion()) {
irMap.map(t, c0);
irMap.map(f, b->getArgument(0));
} else {
irMap.map(t, b->getArgument(0));
irMap.map(f, b->getArgument(1));
}
auto y = rewriter.clone(inst, irMap)->getResult(0);
rewriter.create<sparse_tensor::YieldOp>(loc, y);
}
// The replacement is deferred: the enclosing loop is still iterating over
// the block that contains `inst`.
semiRings.emplace_back(&inst, binOp);
}
for (auto [sel, semi] : semiRings)
rewriter.replaceOp(sel, semi->getResults());
// Success only if at least one select was rewritten.
return success(!semiRings.empty());
}
private:
// Checks whether `v` is an `arith.select` that can be rewritten. Returns the
// (condition, true value, false value) triple on a match and std::nullopt
// otherwise. A match requires both selected values to be block arguments, and
// the condition to be either (a) a block argument of a non-sparse input or a
// value defined outside the body of `op`, or (b) the result of an
// `arith.cmpi` / `arith.cmpf` with at least one operand satisfying (a).
static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
isRewritablePattern(GenericOp op, Operation *v) {
auto sel = dyn_cast<arith::SelectOp>(v);
if (!sel)
return std::nullopt;
auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
if (!tVal || !fVal)
return std::nullopt;
// True for a block argument whose matching input operand is not sparse, and
// for a value whose defining operation lives outside the body of `op`.
auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
if (auto bArg = dyn_cast<BlockArgument>(v);
bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
return true;
return v.getDefiningOp() && v.getDefiningOp()->getBlock() != op.getBody();
};
auto cond = sel.getCondition();
if (isValFromDenseInputOrInvariant(cond))
return std::make_tuple(cond, tVal, fVal);
// Otherwise look one level through an integer or float comparison.
Value cmpL, cmpR;
if (matchPattern(cond, m_Op<arith::CmpIOp>(matchers::m_Any(&cmpL),
matchers::m_Any(&cmpR))) ||
matchPattern(cond, m_Op<arith::CmpFOp>(matchers::m_Any(&cmpL),
matchers::m_Any(&cmpR)))) {
if (isValFromDenseInputOrInvariant(cmpL) ||
isValFromDenseInputOrInvariant(cmpR))
return std::make_tuple(cond, tVal, fVal);
}
return std::nullopt;
};
};
/// Rewrites the reduction inside a single-input `linalg.generic` over a sparse
/// tensor into sparse tensor semi-ring operations.
///
/// Applies when the yielded value is produced by one of the listed `arith`
/// and/mul/min/max operations whose operands are exactly the two block
/// arguments (in either order). The reduction is replaced by
///   - a `sparse_tensor.unary` on the input argument that yields the value
///     when present and a zero constant when absent, followed by
///   - a `sparse_tensor.reduce` of that result with the second block argument,
///     whose region holds a clone of the original reduction and whose identity
///     is extracted from the init operand.
struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Only single-input, single-result reductions are considered.
    if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 1 ||
        op.getNumReductionLoops() == 0 || op.getNumResults() != 1)
      return failure();
    auto *inp = op.getDpsInputOperand(0);
    auto *init = op.getDpsInitOperand(0);
    if (!isSparseTensor(inp))
      return failure();
    // The yielded value must be defined by a supported reduction operation.
    // A yielded block argument has no defining operation; reject it up front
    // because isa<> must not be applied to a null pointer.
    auto *red = cast<linalg::YieldOp>(op.getRegion().front().getTerminator())
                    .getOperand(0)
                    .getDefiningOp();
    if (!red ||
        !isa<arith::AndIOp, arith::MulIOp, arith::MulFOp, arith::MinimumFOp,
             arith::MinSIOp, arith::MinUIOp, arith::MaximumFOp, arith::MaxSIOp,
             arith::MaxUIOp>(red))
      return failure();
    // The reduction must combine exactly the two block arguments.
    Value s0 = op.getBlock()->getArgument(0);
    Value s1 = op.getBlock()->getArgument(1);
    if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) &&
        (red->getOperand(0) != s1 || red->getOperand(1) != s0))
      return failure();
    // Identity of the reduction: the value extracted from the init operand
    // with an empty index list.
    Location loc = op.getLoc();
    Value identity =
        rewriter.create<tensor::ExtractOp>(loc, init->get(), ValueRange());
    // Unary op at the start of the body: present -> value, absent -> zero.
    Type rtp = s0.getType();
    rewriter.setInsertionPointToStart(&op.getRegion().front());
    auto semiring = rewriter.create<sparse_tensor::UnaryOp>(loc, rtp, s0);
    Block *present =
        rewriter.createBlock(&semiring.getPresentRegion(), {}, rtp, loc);
    rewriter.setInsertionPointToStart(&semiring.getPresentRegion().front());
    rewriter.create<sparse_tensor::YieldOp>(loc, present->getArgument(0));
    rewriter.createBlock(&semiring.getAbsentRegion(), {}, {}, {});
    rewriter.setInsertionPointToStart(&semiring.getAbsentRegion().front());
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(rtp));
    rewriter.create<sparse_tensor::YieldOp>(loc, zero);
    rewriter.setInsertionPointAfter(semiring);
    // Reduce op whose region is a clone of the original reduction applied to
    // the two region arguments.
    auto custom = rewriter.create<sparse_tensor::ReduceOp>(
        loc, rtp, semiring.getResult(), s1, identity);
    Block *region =
        rewriter.createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc});
    rewriter.setInsertionPointToStart(&custom.getRegion().front());
    IRMapping irMap;
    irMap.map(red->getOperand(0), region->getArgument(0));
    irMap.map(red->getOperand(1), region->getArgument(1));
    auto *cloned = rewriter.clone(*red, irMap);
    rewriter.create<sparse_tensor::YieldOp>(loc, cloned->getResult(0));
    rewriter.setInsertionPointAfter(custom);
    rewriter.replaceOp(red, custom.getResult());
    return success();
  }
};
/// Lowers a sparse tensor `PrintOp` into a sequence of `vector.print`
/// operations: a header with the number of entries, the dimension and level
/// sizes, and then the contents of the positions, coordinates and values
/// buffers of the tensor. The original op is erased.
struct PrintRewriter : public OpRewritePattern<PrintOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PrintOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto tensor = op.getTensor();
auto stt = getSparseTensorType(tensor);
// Header and number of entries.
auto nse = rewriter.create<NumberOfEntriesOp>(loc, tensor);
rewriter.create<vector::PrintOp>(
loc, rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
rewriter.create<vector::PrintOp>(loc, nse);
// Dimension sizes followed by level sizes.
rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("dim = "));
printSizes(rewriter, loc, tensor, stt.getDimRank(), true);
rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("lvl = "));
printSizes(rewriter, loc, tensor, stt.getLvlRank(), false);
// Visit every storage field of the tensor type and print the ones that
// hold positions, coordinates or values.
foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc, &tensor,
&stt](Type, FieldIndex,
SparseTensorFieldKind kind,
Level l, LevelType) {
switch (kind) {
case SparseTensorFieldKind::StorageSpec: {
// Nothing is printed for this field kind.
break;
}
case SparseTensorFieldKind::PosMemRef: {
// Prints "pos[<l>] : " followed by the positions of level `l`.
auto lvl = constantIndex(rewriter, loc, l);
rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("pos["));
rewriter.create<vector::PrintOp>(
loc, lvl, vector::PrintPunctuation::NoPunctuation);
rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
auto pos = rewriter.create<ToPositionsOp>(loc, tensor, l);
printContents(rewriter, loc, pos);
break;
}
case SparseTensorFieldKind::CrdMemRef: {
// Prints "crd[<l>] : " followed by the coordinates of level `l`.
auto lvl = constantIndex(rewriter, loc, l);
rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("crd["));
rewriter.create<vector::PrintOp>(
loc, lvl, vector::PrintPunctuation::NoPunctuation);
rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
Value crd = nullptr;
// At the level reported by getAoSCOOStart() the whole coordinates buffer
// is printed; at every other level only that level's coordinates.
if (stt.getAoSCOOStart() == l)
crd = rewriter.create<ToCoordinatesBufferOp>(loc, tensor);
else
crd = rewriter.create<ToCoordinatesOp>(loc, tensor, l);
printContents(rewriter, loc, crd);
break;
}
case SparseTensorFieldKind::ValMemRef: {
// Prints "values : " followed by the values buffer.
rewriter.create<vector::PrintOp>(loc,
rewriter.getStringAttr("values : "));
auto val = rewriter.create<ToValuesOp>(loc, tensor);
printContents(rewriter, loc, val);
break;
}
}
// NOTE(review): presumably `true` tells the iteration to continue with the
// next field -- confirm against foreachFieldAndTypeInSparseTensor.
return true;
});
rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("----\n"));
rewriter.eraseOp(op);
return success();
}
private:
// Generates code that prints all elements of the memref `vec` as a
// parenthesized, comma-separated (and for higher ranks nested) list, followed
// by a newline.
static void printContents(PatternRewriter &rewriter, Location loc,
Value vec) {
auto shape = cast<ShapedType>(vec.getType()).getShape();
SmallVector<Value> idxs;
printContentsLevel(rewriter, loc, vec, 0, shape, idxs);
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
}
// Recursive helper of printContents: generates the loop over dimension `i` of
// `vec`. `idxs` holds the induction variables of the enclosing loops. On
// return the insertion point is placed after the generated loop.
static void printContentsLevel(PatternRewriter &rewriter, Location loc,
Value vec, unsigned i, ArrayRef<int64_t> shape,
SmallVectorImpl<Value> &idxs) {
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
// Loop from 0 to the run-time size of dimension `i` with step 1.
auto zero = constantIndex(rewriter, loc, 0);
auto index = constantIndex(rewriter, loc, i);
auto size = rewriter.create<memref::DimOp>(loc, vec, index);
auto step = constantIndex(rewriter, loc, 1);
auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
idxs.push_back(forOp.getInductionVar());
rewriter.setInsertionPointToStart(forOp.getBody());
if (i < shape.size() - 1) {
// Not the innermost dimension yet: nest another loop.
printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs);
} else {
// Innermost dimension: load and print the element.
auto val = rewriter.create<memref::LoadOp>(loc, vec, idxs);
if (llvm::isa<ComplexType>(val.getType())) {
// Complex values are printed as a "( real, imag )" pair.
Value real = rewriter.create<complex::ReOp>(loc, val);
Value imag = rewriter.create<complex::ImOp>(loc, val);
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
rewriter.create<vector::PrintOp>(loc, real,
vector::PrintPunctuation::Comma);
rewriter.create<vector::PrintOp>(loc, imag,
vector::PrintPunctuation::Close);
} else {
rewriter.create<vector::PrintOp>(
loc, val, vector::PrintPunctuation::NoPunctuation);
}
// Print a separating comma unless this is the last iteration, i.e. unless
// induction variable + 1 equals the loop bound.
auto bound = rewriter.create<arith::AddIOp>(loc, idxs.back(), step);
Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
bound, size);
// NOTE(review): the trailing `false` presumably requests an scf.if without
// an else region -- confirm against the scf::IfOp builder.
scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, false);
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
}
idxs.pop_back();
rewriter.setInsertionPointAfter(forOp);
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
}
// Generates code that prints the first `size` dimension sizes (`isDim` set)
// or level sizes (`isDim` unset) of `tensor` as a parenthesized,
// comma-separated list followed by a newline. The loop is unrolled at
// rewrite time: one size query and one print per index.
static void printSizes(PatternRewriter &rewriter, Location loc, Value tensor,
unsigned size, bool isDim) {
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
for (unsigned i = 0; i < size; i++) {
auto idx = constantIndex(rewriter, loc, i);
Value val;
if (isDim)
val = rewriter.create<tensor::DimOp>(loc, tensor, idx);
else
val = rewriter.create<LvlOp>(loc, tensor, idx);
// A comma follows every entry except the last one.
rewriter.create<vector::PrintOp>(
loc, val,
i != size - 1 ? vector::PrintPunctuation::Comma
: vector::PrintPunctuation::NoPunctuation);
}
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
}
};
/// Rewrites a `tensor.reshape` between two encoded (sparse) tensors with a
/// statically shaped result. Every element of the source is visited with a
/// `ForeachOp`; its coordinates are first collapsed into a single linear
/// coordinate and then expanded to the rank of the destination, and the
/// element is inserted at the resulting coordinates into a newly allocated
/// buffer tensor.
struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
public:
using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::ReshapeOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value srcTensor = op.getSource();
const auto srcTp = getSparseTensorType(srcTensor);
const auto dstTp = getSparseTensorType(op.getResult());
// Both sides must carry an encoding and the result shape must be static.
if (!srcTp.hasEncoding() || !dstTp.hasEncoding() ||
!dstTp.hasStaticDimShape())
return failure();
SmallVector<Value> srcSizes;
sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
SmallVector<Value> dstSizes;
for (Dimension d : dstTp.getDimShape())
dstSizes.push_back(constantIndex(rewriter, loc, d));
Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
// A temporary COO buffer type is requested unless the source is all-ordered
// and both source and destination report isIdentity().
Type bufferTp = getBufferType(
dstTp.withoutDimToLvl(),
!srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
// Left empty: the destination shape is static, so there are no dynamic
// sizes to pass to the allocation.
SmallVector<Value> dynSizes;
Value buffer = rewriter
.create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
nnz, Attribute())
.getResult();
const auto encSrc = srcTp.getEncoding();
ForeachOp foreachOp = rewriter.create<ForeachOp>(
loc, srcTensor, buffer,
[&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
ValueRange reduc) {
// Reorder the level coordinates into dimension order.
const Dimension srcRank = srcTp.getDimRank();
SmallVector<Value> srcDcvs;
srcDcvs.reserve(srcRank);
for (Dimension d = 0; d < srcRank; d++) {
Level lvl = toLvl(encSrc, d);
srcDcvs.push_back(srcLcvs[lvl]);
}
// Size of the collapsed (rank-1) tensor: product of all source sizes.
Value collapseSize = constantIndex(builder, loc, 1);
for (Dimension d = 0; d < srcRank; d++)
collapseSize =
builder.create<arith::MulIOp>(loc, collapseSize, srcSizes[d]);
SmallVector<Value, 1> collapsedSizes = {collapseSize};
// Collapse all source dimensions into one group.
ReassociationIndices collapseIdx;
for (Dimension i = 0; i < srcRank; i++)
collapseIdx.push_back(i);
SmallVector<ReassociationIndices, 1> collapseReass = {collapseIdx};
SmallVector<Value, 1> collapsedDcvs;
reshapeCvs(builder, loc, collapseReass, srcSizes, srcDcvs,
collapsedSizes, collapsedDcvs);
// Expand that single group into all destination dimensions.
ReassociationIndices expandIdx;
for (Dimension i = 0; i < dstTp.getDimRank(); i++)
expandIdx.push_back(i);
SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
SmallVector<Value> dstDcvs;
reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
dstSizes, dstDcvs);
// Insert the element into the buffer carried by the reduction chain.
auto t =
builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
builder.create<sparse_tensor::YieldOp>(loc, t);
});
// NOTE(review): the `true` argument is presumably the "has inserts" flag of
// the load op -- confirm.
Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
// When the buffer type differs from the destination type, convert to the
// destination type and release the temporary.
if (bufferTp != dstTp) {
auto dstRTT = dstTp.getRankedTensorType();
Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
rewriter.create<DeallocTensorOp>(loc, t);
t = converted;
}
rewriter.replaceOp(op, t);
return success();
}
};
/// Rewrites a reshape operation of type `ReshapeOp` between two encoded
/// (sparse) tensors. Every element of the source is visited with a
/// `ForeachOp`; its coordinates are translated with `reshapeCvs` according to
/// the reassociation indices of the op, and the element is inserted at the
/// translated coordinates into a newly allocated buffer tensor.
template <typename ReshapeOp>
struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
using OpRewritePattern<ReshapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ReshapeOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value srcTensor = op.getSrc();
const auto srcTp = getSparseTensorType(srcTensor);
const auto dstTp = getSparseTensorType(op.getResult());
// Both sides must carry an encoding.
if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
return failure();
SmallVector<Value> srcSizes;
sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
// Destination sizes: constants for a static shape, otherwise generated by
// genReshapeDstShape, with the dynamic ones collected for the allocation.
SmallVector<Value> dstSizes;
SmallVector<Value> dstDynSizes;
if (dstTp.hasStaticDimShape()) {
for (Dimension d : dstTp.getDimShape())
dstSizes.push_back(constantIndex(rewriter, loc, d));
} else {
ArrayRef<Size> dstShape = dstTp.getDimShape();
genReshapeDstShape(rewriter, loc, dstSizes, srcSizes, dstShape,
op.getReassociationIndices());
for (auto [idx, shape] : llvm::enumerate(dstShape)) {
if (shape == ShapedType::kDynamic)
dstDynSizes.push_back(dstSizes[idx]);
}
}
Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
// A temporary COO buffer type is requested unless the source is all-ordered
// and both source and destination report isIdentity().
Type bufferTp = getBufferType(
dstTp.withoutDimToLvl(),
!srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
Value buffer =
rewriter
.create<AllocTensorOp>(loc, bufferTp, dstDynSizes, Value(),
nnz, Attribute())
.getResult();
const auto encSrc = srcTp.getEncoding();
ForeachOp foreachOp = rewriter.create<ForeachOp>(
loc, srcTensor, buffer,
[&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
ValueRange reduc) {
// Reorder the level coordinates into dimension order.
const Dimension dimRank = srcTp.getDimRank();
SmallVector<Value> srcDcvs;
srcDcvs.reserve(dimRank);
for (Dimension d = 0; d < dimRank; d++) {
Level lvl = toLvl(encSrc, d);
srcDcvs.push_back(srcLcvs[lvl]);
}
// Translate source coordinates into destination coordinates.
SmallVector<Value> dstDcvs;
reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
srcDcvs, dstSizes, dstDcvs);
// Insert the element into the buffer carried by the reduction chain.
auto t =
builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
builder.create<sparse_tensor::YieldOp>(loc, t);
});
// NOTE(review): the `true` argument is presumably the "has inserts" flag of
// the load op -- confirm.
Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
// When the buffer type differs from the destination type, convert to the
// destination type and release the temporary.
if (bufferTp != dstTp) {
auto dstRTT = dstTp.getRankedTensorType();
Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
rewriter.create<DeallocTensorOp>(loc, t);
t = converted;
}
rewriter.replaceOp(op, t);
return success();
}
};
/// Rewrites a reshape operation of type `ReshapeOp` in which exactly one of
/// source and result is sparse, by splitting it into a reshape on unencoded
/// tensors plus a `ConvertOp`:
///  - sparse source: the source is first converted to an unencoded tensor and
///    the reshape is updated in place to consume the converted value;
///  - sparse result: a new reshape with an unencoded result type is created
///    and its result is converted to the original result type.
/// Reshapes with both or neither side sparse are not handled here (failure).
template <typename ReshapeOp>
struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
using OpRewritePattern<ReshapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ReshapeOp op,
PatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto encDst = getSparseTensorEncoding(op.getResult().getType());
auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
// Sparse-to-sparse reshapes are rejected by this pattern.
if (encDst && encSrc) {
return failure();
}
if (encSrc) {
// Sparse source: convert it to the same shape and element type without
// an encoding, then feed the converted value to the reshape.
auto rtp = getRankedTensorType(op.getSrc());
auto denseTp =
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); });
return success();
}
if (encDst) {
// Sparse result: reshape to an unencoded result type first, then convert
// the reshaped value to the original result type.
auto rtp = getRankedTensorType(op.getResult());
auto denseTp =
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
ReshapeOp reshape;
// tensor.expand_shape additionally takes its output shape operands.
if constexpr (std::is_same<ReshapeOp, tensor::ExpandShapeOp>::value) {
reshape = rewriter.create<ReshapeOp>(
loc, denseTp, op.getSrc(), op.getReassociation(),
op.getOutputShape(), op.getStaticOutputShape());
} else {
reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
op.getReassociation());
}
Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
rewriter.replaceOp(op, convert);
return success();
}
return failure();
}
};
/// Small wrapper around a tensor SSA value that is built up by a chain of
/// `tensor.insert` operations. `val` always holds the most recent value of the
/// chain.
struct TensorLike {
  /// Allocates a tensor of type `rtt`, passing the entries of `sizes` that
  /// belong to dynamic dimensions to the allocation. A tensor without a sparse
  /// encoding is additionally filled with zero.
  TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
             ValueRange sizes) {
    SmallVector<Value> dynamicExtents;
    getDynamicSizes(rtt, sizes, dynamicExtents);
    val = builder.create<AllocTensorOp>(loc, rtt, dynamicExtents);
    if (isSparse())
      return;
    Value zero = constantZero(builder, loc, rtt.getElementType());
    val = builder.create<linalg::FillOp>(loc, zero, val).getResult(0);
  }

  /// Inserts `v` at coordinates `crds` and advances `val` to the result of
  /// the insertion.
  void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
    auto inserted = builder.create<tensor::InsertOp>(loc, v, val, crds);
    val = inserted;
  }

  /// Returns the final tensor value: `val` itself when it has no sparse
  /// encoding, and the result of a `LoadOp` on `val` otherwise. The `rtp`
  /// parameter is not used.
  Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
    if (!isSparse())
      return val;
    return builder.create<LoadOp>(loc, val, true);
  }

  /// Returns true when the type of `val` has a sparse tensor encoding.
  bool isSparse() const {
    auto encoding = getSparseTensorEncoding(val.getType());
    return encoding != nullptr;
  }

  Value val;
};
/// Rewrites a `tensor.dim` with a constant index on an encoded (sparse)
/// tensor in terms of level sizes. When the type reports isPermutation(), the
/// op is replaced by a single `LvlOp` on the matching level. Otherwise the
/// dimension size is computed by applying the level-to-dimension expression
/// to the maximum level coordinates (level size - 1) and adding one.
struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::DimOp op,
PatternRewriter &rewriter) const override {
std::optional<int64_t> dim = op.getConstantIndex();
auto stt = getSparseTensorType(op.getSource());
// Requires a constant dimension index and an encoded source.
if (!dim || !stt.hasEncoding())
return failure();
if (stt.isPermutation()) {
rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
toLvl(stt.getEncoding(), *dim));
return success();
}
Location loc = op.getLoc();
// Largest coordinate of every level: level size - 1.
SmallVector<Value> maxLvlCrds;
for (Level l = 0; l < stt.getLvlRank(); l++) {
Value lvlSz = rewriter.create<LvlOp>(loc, op.getSource(), l);
Value maxLvlCrd = rewriter.create<arith::SubIOp>(
loc, lvlSz, constantOne(rewriter, loc, rewriter.getIndexType()));
maxLvlCrds.push_back(maxLvlCrd);
}
// Apply the level-to-dimension expression of the requested dimension to the
// largest level coordinates to obtain the largest dimension coordinate.
AffineExpr lvl2DimExp = stt.getLvlToDim().getResult(*dim);
Value maxDimCrd = rewriter.create<affine::AffineApplyOp>(
op.getLoc(), AffineMap::get(stt.getLvlRank(), 0, lvl2DimExp),
maxLvlCrds);
// Dimension size: largest dimension coordinate + 1.
Value dimSz = rewriter.create<arith::AddIOp>(
loc, maxDimCrd, constantOne(rewriter, loc, rewriter.getIndexType()));
rewriter.replaceOp(op, dimSz);
return success();
}
};
/// Rewrites a `ConcatenateOp` into one `ForeachOp` per input. Every element of
/// every input is inserted into a freshly allocated destination tensor, with
/// its coordinate along the concatenation dimension shifted by the accumulated
/// size of the preceding inputs along that dimension.
///
/// A concatenation that still needs an extra sort is rejected with an error.
struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatenateOp op,
                                PatternRewriter &rewriter) const override {
    // Report the error and stop: rewriting an op that has not been staged
    // after diagnosing it would still transform the invalid input.
    if (op.needsExtraSort())
      return op.emitError("ConcatenateOp not staged");

    const Location loc = op.getLoc();
    const auto dstTp = getSparseTensorType(op);
    const Dimension conDim = op.getDimension();
    SmallVector<Value> sizes;
    concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim);

    TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
    // Offset of the current input along the concatenation dimension.
    Value offset = constantIndex(rewriter, loc, 0);
    // Destination value threaded from one foreach op into the next.
    Value iterArg = dstBuf.val;

    ForeachOp foreachOp;
    for (Value input : op.getInputs()) {
      foreachOp = rewriter.create<ForeachOp>(
          loc, input, iterArg,
          [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
              ValueRange reduc) {
            // Shift the coordinate along the concatenation dimension.
            SmallVector<Value> offDimCrd(dcvs);
            offDimCrd[conDim] =
                builder.create<arith::AddIOp>(loc, offDimCrd[conDim], offset);

            // Inside the loop body the destination is the reduction value.
            dstBuf.val = reduc.front();
            if (!dstTp.isAllDense()) {
              // Unless the destination is all-dense, only insert values for
              // which genIsNonzero holds; the else branch forwards the
              // destination unchanged.
              Value cond = genIsNonzero(builder, loc, v);
              auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
                                                    true);
              builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
              builder.create<scf::YieldOp>(loc, dstBuf.val);

              builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
              dstBuf.insert(builder, loc, v, offDimCrd);
              builder.create<scf::YieldOp>(loc, dstBuf.val);

              // Continue after the conditional with its result.
              builder.setInsertionPointAfter(ifOp);
              dstBuf.val = ifOp.getResult(0);
            } else {
              dstBuf.insert(builder, loc, v, offDimCrd);
            }
            builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
          });
      // Advance the offset by the size of this input along the concatenation
      // dimension; that size is asserted to be static.
      const Size sz = getSparseTensorType(input).getDynamicDimSize(conDim);
      assert(!ShapedType::isDynamic(sz));
      offset = rewriter.create<arith::AddIOp>(loc, offset,
                                              constantIndex(rewriter, loc, sz));
      iterArg = foreachOp.getResult(0);
      dstBuf.val = iterArg;
    }

    dstBuf.val = iterArg;
    Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
    rewriter.replaceOp(op, ret);
    return success();
  }
};
/// Lowers a `ConvertOp` into a single `ForeachOp` over the source that
/// inserts every visited element into a freshly created destination buffer.
struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ConvertOp op,
                                PatternRewriter &rewriter) const override {
    if (op.needsExtraSort())
      return op.emitError("ConvertOp not staged.");

    // Not handled by this pattern: both sides are encoded, the source is not
    // a slice, and the encodings are equal once bit widths are ignored.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (encDst && encSrc && !encSrc.isSlice() &&
        encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
      return failure();
    }

    Location loc = op.getLoc();
    Value src = op.getSource();
    SparseTensorType srcStt = getSparseTensorType(op.getSource());
    SparseTensorType dstStt = getSparseTensorType(op.getDest());

    // Detect a source that is an arith.constant holding a SparseElementsAttr.
    bool fromSparseConst = false;
    if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>())
      if (isa<SparseElementsAttr>(constOp.getValue()))
        fromSparseConst = true;

    // An explicit foreach order is requested only for a sparse constant
    // source with a non-identity destination; otherwise none is passed.
    const AffineMapAttr foreachOrder =
        (!dstStt.isIdentity() && fromSparseConst)
            ? AffineMapAttr::get(dstStt.getExpandedDimToLvl())
            : nullptr;

    // The nonzero guard is emitted only for sources that have no encoding
    // and are not sparse constants.
    bool skipZeroCheck = srcStt.hasEncoding() || fromSparseConst;

    SmallVector<Value> sizes;
    sizesFromSrc(rewriter, sizes, loc, src);
    TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);

    auto foreachOp = rewriter.create<ForeachOp>(
        loc, src, dstBuf.val, foreachOrder,
        [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
            ValueRange reduc) {
          // Inside the body the buffer is the foreach reduction argument.
          dstBuf.val = reduc.front();
          if (!skipZeroCheck) {
            // Insert only when genIsNonzero(v) holds; the else branch yields
            // the buffer untouched.
            Value cond = genIsNonzero(builder, loc, v);
            auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
                                                  true);
            builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
            builder.create<scf::YieldOp>(loc, dstBuf.val);

            builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
            dstBuf.insert(builder, loc, v, dcvs);
            builder.create<scf::YieldOp>(loc, dstBuf.val);

            // Continue after the scf.if with its result as the buffer.
            builder.setInsertionPointAfter(ifOp);
            dstBuf.val = ifOp.getResult(0);
          } else {
            dstBuf.insert(builder, loc, v, dcvs);
          }
          builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
        });

    // Finalize the buffer after the loop and use it as the replacement.
    rewriter.setInsertionPointAfter(foreachOp);
    dstBuf.val = foreachOp.getResult(0);
    Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
    rewriter.replaceOp(op, ret);
    return success();
  }
};
/// Expands a `CrdTranslateOp` into one `affine.apply` per result of the
/// encoder's dim-to-lvl or lvl-to-dim map, chosen by the op's direction.
struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(CrdTranslateOp op,
                                PatternRewriter &rewriter) const override {
    // Select the map that matches the requested translation direction.
    AffineMap map;
    if (op.getDirection() == CrdTransDirectionKind::dim2lvl)
      map = op.getEncoder().getDimToLvl();
    else
      map = op.getEncoder().getLvlToDim();

    // Apply each result expression, as its own single-result map, to the
    // input coordinates.
    SmallVector<Value> translated;
    for (AffineExpr expr : map.getResults()) {
      AffineMap singleResultMap = AffineMap::get(map.getNumDims(), 0, expr);
      Value crd = rewriter.create<affine::AffineApplyOp>(
          op.getLoc(), singleResultMap, op.getInCrds());
      translated.push_back(crd);
    }

    rewriter.replaceOp(op, translated);
    return success();
  }
};
/// Lowers a `ForeachOp` by emitting a loop nest over the levels of its input
/// tensor with a LoopEmitter and inlining the op's body block into the
/// innermost position. A sparse constant input is delegated to
/// genForeachOnSparseConstant instead.
struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ForeachOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
Value input = op.getTensor();
// Initial reduction values; handed to the loop emitter when entering loops.
SmallVector<Value> reduc = op.getInitArgs();
const auto stt = getSparseTensorType(input);
const Level lvlRank = stt.getLvlRank();
// Special case: the input is an arith.constant holding a
// SparseElementsAttr; that path is generated by a separate helper.
if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
return genForeachOnSparseConstant(op, rewriter, attr);
}
}
const auto enc = stt.getEncoding();
// Set up a loop emitter over the single input tensor, tagged with the
// ForeachOp operation name.
LoopEmitter loopEmitter(
ValueRange{input},
StringAttr::get(getContext(), ForeachOp::getOperationName()));
loopEmitter.initializeLoopEmit(rewriter, loc);
// Enter one loop sequence and one loop per level of tensor 0, threading the
// reduction values through.
for (Level l = 0; l < lvlRank; l++) {
const SmallVector<TensorLevel, 1> tidLvls{
loopEmitter.makeTensorLevel(0, l)};
loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls,
reduc);
}
SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
// An explicit iteration order is only supported on the constant path above.
// NOTE(review): this is reached after the loops were already emitted.
if (op.getOrder()) {
llvm_unreachable(
"Level order not yet implemented on non-constant input tensors.");
}
Value vals = loopEmitter.getValBuffer()[0];
SmallVector<Value> pos = loopEmitter.getValPosits(0);
// With an encoding the element is loaded at the emitter's value positions;
// without one it is loaded at the loop induction variables.
Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pos)
: rewriter.create<memref::LoadOp>(loc, vals, lcvs);
Block *srcBlock = op.getBody();
// Block arguments for the inlined body: the loop IVs translated lvl2dim,
// then the loaded value, then the reduction values.
// NOTE(review): `enc` may be null here (see the ternary above); this relies
// on translateCrds accepting a null encoding - confirm.
SmallVector<Value> args =
enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);
args.push_back(val);
args.append(reduc);
// Capture the operands of the body terminator before erasing it; they are
// passed to the loop exits below and used as the replacement values.
SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands();
rewriter.eraseOp(srcBlock->getTerminator());
// If the current block already ends in an scf.yield, inline the body in
// front of it rather than at the current insertion point.
Operation &last = rewriter.getBlock()->back();
if (llvm::isa<scf::YieldOp>(last)) {
rewriter.setInsertionPoint(&last);
}
rewriter.inlineBlockBefore(srcBlock, rewriter.getBlock(),
rewriter.getInsertionPoint(), args);
rewriter.setInsertionPointToEnd(rewriter.getBlock());
// Leave the loops in reverse: one loop and one loop sequence per level.
for (Level l = 0; l < lvlRank; l++) {
loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
loopEmitter.exitCurrentLoopSeq(rewriter, loc);
}
rewriter.replaceOp(op, reducValue);
return success();
}
};
/// Rewrites a `NewOp` into a `NewOp` that produces a temporary COO tensor
/// followed by a `ConvertOp` to the requested type. For non-permutation
/// encodings the conversion is wrapped in `ReinterpretMapOp`s. The temporary
/// COO tensor is deallocated after the final value.
struct NewRewriter : public OpRewritePattern<NewOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(NewOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());
    // Skip results without an encoding and results whose AoS COO start is
    // level 0.
    if (!stt.hasEncoding())
      return failure();
    if (stt.getAoSCOOStart() == 0)
      return failure();

    RankedTensorType dstTp = stt.getRankedTensorType();
    RankedTensorType cooTp = stt.getCOOType(true);
    Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
    auto enc = stt.getEncoding();

    Value result = cooTensor;
    if (!stt.isPermutation()) {
      // Strip the dim-to-lvl map from both the COO value and the target type
      // before converting.
      auto cooEnc =
          getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
      result = rewriter.create<ReinterpretMapOp>(loc, cooEnc, result);
      dstTp = getSparseTensorType(result).withEncoding(enc.withoutDimToLvl());
    }

    result = rewriter.create<ConvertOp>(loc, dstTp, result);

    // Reinterpret back to the original encoding when it was stripped above.
    if (!stt.isPermutation())
      result = rewriter.create<ReinterpretMapOp>(loc, enc, result);

    rewriter.replaceOp(op, result);

    // Free the temporary COO tensor right after the final value.
    rewriter.setInsertionPointAfterValue(result);
    rewriter.create<DeallocTensorOp>(loc, cooTensor);
    return success();
  }
};
/// Lowers an `OutOp` into calls to the sparse tensor writer functions:
/// `createSparseTensorWriter`, `outSparseTensorWriterMetaData`, one
/// `outSparseTensorWriterNext<suffix>` call per element visited by a
/// `ForeachOp` over the source, and finally `delSparseTensorWriter`.
struct OutRewriter : public OpRewritePattern<OutOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(OutOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value src = op.getTensor();
    Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);

    // Store the dimension sizes of the source into a stack buffer.
    const auto srcTp = getSparseTensorType(src);
    const Dimension dimRank = srcTp.getDimRank();
    Type indexTp = rewriter.getIndexType();
    Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp);
    SmallVector<Value> dims;
    sizesForTensor(rewriter, dims, loc, srcTp, src);
    for (Dimension d = 0; d < dimRank; d++) {
      rewriter.create<memref::StoreOp>(loc, dims[d], dimSizes,
                                       constantIndex(rewriter, loc, d));
    }

    // Create the writer from the op's destination operand and emit the
    // metadata call (rank, number of entries, dimension sizes).
    Type opaqueTp = getOpaquePointerType(rewriter);
    Value writer =
        createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp},
                       {op.getDest()}, EmitCInterface::Off)
            .getResult(0);
    Value rankValue = constantIndex(rewriter, loc, dimRank);
    createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {},
                   {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);

    // The dimension-size buffer is reused to hold the coordinates of each
    // element while iterating.
    Value dimCoords = dimSizes;
    Type eltTp = srcTp.getElementType();
    SmallString<29> outNextFuncName{"outSparseTensorWriterNext",
                                    primaryTypeFunctionSuffix(eltTp)};
    Value value = genAllocaScalar(rewriter, loc, eltTp);
    ModuleOp module = op->getParentOfType<ModuleOp>();

    rewriter.create<ForeachOp>(
        loc, src, std::nullopt,
        [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
            ValueRange reduc) {
          // Every op of the body goes through the `builder` handed to this
          // callback, so the stores cannot land outside the foreach body
          // should `builder` ever differ from the captured `rewriter`.
          for (Dimension d = 0; d < dimRank; d++) {
            builder.create<memref::StoreOp>(loc, dcvs[d], dimCoords,
                                            constantIndex(builder, loc, d));
          }
          builder.create<memref::StoreOp>(loc, v, value);
          SmallVector<Value> operands{writer, rankValue, dimCoords, value};
          FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands,
                                         EmitCInterface::On);
          builder.create<func::CallOp>(loc, TypeRange(), fn, operands);
          builder.create<sparse_tensor::YieldOp>(loc);
        });

    // Release the writer and drop the original op.
    createFuncCall(rewriter, loc, "delSparseTensorWriter", {}, {writer},
                   EmitCInterface::Off);
    rewriter.eraseOp(op);
    return success();
  }
};
}
/// Registers the pre-sparsification rewrite patterns on `patterns`.
void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
  auto *ctx = patterns.getContext();
  patterns.add<FuseExtractSliceWithConcat, FoldConvertIntoProducer,
               FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
               GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(ctx);
}
/// Registers the patterns that lower sparse tensor ops to `ForeachOp`-based
/// code. `DirectConvertRewriter` is added only when `enableConvert` is set;
/// `NewRewriter` is added only when `enableRT` is not set.
void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
                                                   bool enableRT,
                                                   bool enableConvert) {
  auto *ctx = patterns.getContext();
  patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
               ReshapeRewriter<tensor::CollapseShapeOp>,
               Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
               Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
               SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
      ctx);
  if (enableConvert) {
    patterns.add<DirectConvertRewriter>(ctx);
  }
  if (!enableRT) {
    patterns.add<NewRewriter>(ctx);
  }
}
/// Registers the patterns that lower `CrdTranslateOp` and `ForeachOp`.
void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
  auto *ctx = patterns.getContext();
  patterns.add<CrdTranslateRewriter, ForeachRewriter>(ctx);
}