#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::sparse_tensor;
#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp.inc"
/// Stages an operation whose result needs an extra sort (as reported by
/// `needsExtraSort()`) into a sequence of simpler operations:
///
///   1. a clone of `op` whose first result is retyped to
///      `getCOOType(false)` of the original result type,
///   2. a `ReorderCOOOp` (hybrid quick sort) from that value into
///      `getCOOType(true)`,
///   3. when that sorted type is not already the original result type, a
///      `ConvertOp` back to the original result type.
///
/// The `false`/`true` flag is read here as selecting an unordered/ordered COO
/// type. That reading is inferred from the reorder step; confirm it against
/// `SparseTensorType::getCOOType`.
///
/// New operations are created at the rewriter's current insertion point, and
/// `op` is replaced by the final value.
///
/// \param op       The operation to stage. Its first result type is `cast` to
///                 `RankedTensorType` unconditionally.
/// \param rewriter Rewriter used to create, modify and replace operations.
/// \param tmpBufs  Out-parameter. It is set to the intermediate sorted COO
///                 value only when a `ConvertOp` was needed, and left
///                 untouched otherwise. Presumably the caller releases it
///                 afterwards; confirm against callers.
/// \returns failure() if no extra sort is needed (no IR is created), success()
///          otherwise.
LogicalResult sparse_tensor::detail::stageWithSortImpl(
    StageWithSortSparseOp op, PatternRewriter &rewriter, Value &tmpBufs) {
  // Bail out early: ops that do not need an extra sort are left as they are.
  if (!op.needsExtraSort())
    return failure();

  Location loc = op.getLoc();
  Type resultTp = op->getOpResult(0).getType();
  SparseTensorType resultStt(cast<RankedTensorType>(resultTp));

  // Step 1: duplicate the op and let it produce the (presumably unordered)
  // COO type instead of its original result type.
  Type unorderedTp = resultStt.getCOOType(false);
  Operation *staged = rewriter.clone(*op.getOperation());
  rewriter.modifyOpInPlace(staged, [staged, unorderedTp]() {
    staged->getOpResult(0).setType(unorderedTp);
  });
  Value unorderedCOO = staged->getOpResult(0);

  // Step 2: sort the COO into its ordered form.
  Type orderedTp = resultStt.getCOOType(true);
  Value orderedCOO = rewriter.create<ReorderCOOOp>(
      loc, orderedTp, unorderedCOO, SparseTensorSortKind::HybridQuickSort);

  // Step 3a: the sorted COO already has the requested result type, so it can
  // stand in for `op` directly. No temporary is reported.
  if (orderedCOO.getType() == resultTp) {
    rewriter.replaceOp(op, orderedCOO);
    return success();
  }

  // Step 3b: convert to the requested type. Move the insertion point past the
  // conversion and hand the sorted COO back through `tmpBufs`. This
  // presumably lets the caller emit cleanup for it after its last use.
  auto convertOp =
      rewriter.replaceOpWithNewOp<ConvertOp>(op, resultTp, orderedCOO);
  rewriter.setInsertionPointAfter(convertOp);
  tmpBufs = orderedCOO;
  return success();
}