#include "Utils/CodegenUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Support/LLVM.h"
using namespace mlir;
using namespace mlir::sparse_tensor;
// Positions of the fixed operands in the argument list used by the generated
// sort helpers in this file: args[loIdx] and args[hiIdx] are the two leading
// index operands, and the buffer operands begin at args[xStartIdx].
static constexpr uint64_t loIdx = 0;
static constexpr uint64_t hiIdx = 1;
static constexpr uint64_t xStartIdx = 2;
// Name prefixes of the generated helper functions. The full symbol name is
// produced by appending a mangled suffix to one of these prefixes.
static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_";
static constexpr const char kBinarySearchFuncNamePrefix[] =
"_sparse_binary_search_";
static constexpr const char kHybridQuickSortFuncNamePrefix[] =
"_sparse_hybrid_qsort_";
static constexpr const char kSortStableFuncNamePrefix[] =
"_sparse_sort_stable_";
static constexpr const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
static constexpr const char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";
// Signature of the callbacks that fill in the body of a freshly created helper
// function. The arguments are: builder, enclosing module, the function to
// populate, the permutation map xPerm, the count ny, and the number of
// trailing non-buffer parameters (nTrailingP).
using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
AffineMap, uint64_t, uint32_t)>;
/// Writes the mangled name of a sort helper function into `nameOstream`.
///
/// The name is assembled from, in order:
///   * `namePrefix`,
///   * the dimension position of every result of `xPerm`, each followed by "_",
///   * the element type of the buffer at `operands[xStartIdx]`,
///   * the literal "_coo_" followed by `ny`,
///   * "_" plus the element type of every buffer after `operands[xStartIdx]`.
static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
                                         StringRef namePrefix, AffineMap xPerm,
                                         uint64_t ny, ValueRange operands) {
  nameOstream << namePrefix;

  for (AffineExpr dimExpr : xPerm.getResults()) {
    nameOstream << cast<AffineDimExpr>(dimExpr).getPosition() << "_";
  }

  nameOstream << getMemRefType(operands[xStartIdx]).getElementType();
  nameOstream << "_coo_" << ny;

  // Skip the two index operands and the first buffer, which was handled above.
  constexpr uint64_t yBufferOffset = 1;
  ValueRange trailingBuffers = operands.drop_front(xStartIdx + yBufferOffset);
  for (Value trailingBuffer : trailingBuffers) {
    nameOstream << "_" << getMemRefType(trailingBuffer).getElementType();
  }
}
/// Returns a symbol reference to the sort helper function identified by
/// `namePrefix`, `xPerm`, `ny` and the operand types, creating the function
/// first if the enclosing module does not contain it yet.
///
/// The last `nTrailingP` operands are excluded from name mangling but are part
/// of the function signature. A newly created function is private, is inserted
/// right before `insertPoint`, and has its body produced by `createFunc`.
static FlatSymbolRefAttr getMangledSortHelperFunc(
    OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes,
    StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands,
    FuncGeneratorType createFunc, uint32_t nTrailingP = 0) {
  SmallString<32> mangledName;
  llvm::raw_svector_ostream nameStream(mangledName);
  getMangledSortHelperFuncName(nameStream, namePrefix, xPerm, ny,
                               operands.drop_back(nTrailingP));

  ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
  MLIRContext *context = module.getContext();
  auto symbol = SymbolRefAttr::get(context, nameStream.str());

  // Nothing to do when a function with this name already exists.
  if (module.lookupSymbol<func::FuncOp>(symbol.getAttr()))
    return symbol;

  // Otherwise create the function and let the callback emit its body. The
  // guard restores the caller's insertion point on exit.
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPoint(insertPoint);
  FunctionType helperType =
      FunctionType::get(context, operands.getTypes(), resultTypes);
  auto helper = builder.create<func::FuncOp>(insertPoint.getLoc(),
                                             nameStream.str(), helperType);
  helper.setPrivate();
  createFunc(builder, module, helper, xPerm, ny, nTrailingP);
  return symbol;
}
/// Invokes `bodyBuilder` once per result of `xPerm` with the pair of element
/// positions that belong to the entries `args[0]` and `args[1]` in the buffer
/// `args[xStartIdx]`.
///
/// For the k-th result of `xPerm` with dimension position `d`, the callback
/// receives (k, args[0] * stride + d, args[1] * stride + d, buffer), where
/// stride is `xPerm.getNumResults() + ny`.
static void forEachIJPairInXs(
    OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
    uint64_t ny,
    function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
  Value stride = constantIndex(builder, loc, xPerm.getNumResults() + ny);
  Value baseI = builder.create<arith::MulIOp>(loc, args[0], stride);
  Value baseJ = builder.create<arith::MulIOp>(loc, args[1], stride);
  Value buffer = args[xStartIdx];

  for (const auto &dim : llvm::enumerate(xPerm.getResults())) {
    unsigned position = cast<AffineDimExpr>(dim.value()).getPosition();
    Value offset = constantIndex(builder, loc, position);
    Value posI = builder.create<arith::AddIOp>(loc, offset, baseI);
    Value posJ = builder.create<arith::AddIOp>(loc, offset, baseJ);
    bodyBuilder(dim.index(), posI, posJ, buffer);
  }
}
/// Invokes `bodyBuilder` for every (i, j) position pair in all buffers.
///
/// The buffer at `args[xStartIdx]` is visited through forEachIJPairInXs using
/// `xPerm` extended with `ny` extra identity dimensions, so the callback sees
/// `xPerm.getNumResults() + ny` position pairs for it. Every buffer after
/// `args[xStartIdx]` is then visited once with the raw positions `args[0]` and
/// `args[1]`; the callback index keeps counting from where the first buffer
/// stopped.
static void forEachIJPairInAllBuffers(
    OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
    uint64_t ny,
    function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
  // Extend xPerm with the trailing ny dimensions mapped to themselves.
  ArrayRef<AffineExpr> xResults = xPerm.getResults();
  SmallVector<AffineExpr> exps;
  exps.append(xResults.begin(), xResults.end());
  for (unsigned y = 0; y < ny; y++)
    exps.push_back(builder.getAffineDimExpr(y + xPerm.getNumResults()));
  AffineMap xyPerm = AffineMap::get(exps.size(), 0, exps, builder.getContext());
  assert(xyPerm.isPermutation());

  forEachIJPairInXs(builder, loc, args, xyPerm, 0, bodyBuilder);

  // The remaining buffers are addressed directly by args[0] and args[1].
  constexpr uint64_t numHandledBuffers = 1;
  Value posI = args[0];
  Value posJ = args[1];
  ValueRange remaining = args.drop_front(xStartIdx + numHandledBuffers);
  for (const auto &en : llvm::enumerate(remaining)) {
    uint64_t k = en.index() + xPerm.getNumResults() + ny;
    bodyBuilder(k, posI, posJ, en.value());
  }
}
/// Emits code that exchanges the values stored at the two positions given by
/// `args[0]` and `args[1]`, for every position pair produced by
/// forEachIJPairInAllBuffers.
static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
                       AffineMap xPerm, uint64_t ny) {
  auto emitSwap = [&](uint64_t /*k*/, Value posI, Value posJ, Value mem) {
    // Load both values first, then store them crosswise.
    Value atI = builder.create<memref::LoadOp>(loc, mem, posI);
    Value atJ = builder.create<memref::LoadOp>(loc, mem, posJ);
    builder.create<memref::StoreOp>(loc, atJ, mem, posI);
    builder.create<memref::StoreOp>(loc, atI, mem, posJ);
  };
  forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, emitSwap);
}
/// Emits an inlined comparison across all results of `xPerm` and returns the
/// value produced for the first dimension.
///
/// `compareBuilder` is called once per dimension, in order. For a dimension
/// that is neither first nor last, the value it returns must be the result of
/// an scf.if; a yield of that result is emitted right after the scf.if so the
/// value propagates to the enclosing region. On return the builder's insertion
/// point is placed after the returned value.
static Value createInlinedCompareImplementation(
    OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
    uint64_t ny,
    function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)>
        compareBuilder) {
  Value outermost;
  auto perDim = [&](uint64_t k, Value i, Value j, Value buffer) {
    const bool first = (k == 0);
    const bool last = (k == xPerm.getNumResults() - 1);
    Value cmp = compareBuilder(builder, loc, i, j, buffer, first, last);
    if (first) {
      outermost = cmp;
      return;
    }
    if (last)
      return;
    // Middle dimension: forward the nested scf.if result to its parent region
    // without disturbing the insertion point set up by compareBuilder.
    OpBuilder::InsertionGuard guard(builder);
    auto nestedIf = cast<scf::IfOp>(cmp.getDefiningOp());
    builder.setInsertionPointAfter(nestedIf);
    builder.create<scf::YieldOp>(loc, nestedIf.getResult(0));
  };

  forEachIJPairInXs(builder, loc, args, xPerm, ny, perDim);

  builder.setInsertionPointAfterValue(outermost);
  return outermost;
}
/// Emits the equality test for one dimension: compares x[i] with x[j].
///
/// For the last dimension a plain `eq` compare is emitted; unless it is also
/// the first dimension, its result is yielded to the enclosing scf.if. For any
/// other dimension an scf.if on `x[i] != x[j]` is emitted whose then-branch
/// yields false; the insertion point is left at the start of the else-branch so
/// that the next dimension's compare is nested there. Returns the compare
/// result or the scf.if result, respectively.
static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j,
                             Value x, bool isFirstDim, bool isLastDim) {
  Value lhs = builder.create<memref::LoadOp>(loc, x, i);
  Value rhs = builder.create<memref::LoadOp>(loc, x, j);

  if (isLastDim) {
    Value isEq =
        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, lhs, rhs);
    // A single-dimension compare has no enclosing scf.if to yield to.
    if (!isFirstDim)
      builder.create<scf::YieldOp>(loc, isEq);
    return isEq;
  }

  Value differs =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, lhs, rhs);
  scf::IfOp branch = builder.create<scf::IfOp>(
      loc, builder.getIntegerType(1), differs, /*withElseRegion=*/true);

  // x[i] != x[j]: the entries are not equal.
  builder.setInsertionPointToStart(&branch.getThenRegion().front());
  Value falseVal = constantI1(builder, loc, false);
  builder.create<scf::YieldOp>(loc, falseVal);

  // x[i] == x[j]: the remaining dimensions decide; code goes in the else part.
  builder.setInsertionPointToStart(&branch.getElseRegion().front());
  return branch.getResult(0);
}
/// Emits an inlined equality comparison between the entries at `args[0]` and
/// `args[1]` over all dimensions of `xPerm`, using createEqCompare for each
/// dimension. `nTrailingP` must be 0.
static Value createInlinedEqCompare(OpBuilder &builder, Location loc,
                                    ValueRange args, AffineMap xPerm,
                                    uint64_t ny, uint32_t nTrailingP = 0) {
  assert(nTrailingP == 0);
  // Keep release builds free of an unused-parameter warning.
  (void)nTrailingP;
  Value isEqual = createInlinedCompareImplementation(builder, loc, args, xPerm,
                                                     ny, createEqCompare);
  return isEqual;
}
/// Emits the less-than test for one dimension: compares x[i] with x[j] as
/// unsigned integers.
///
/// For the last dimension a plain `ult` compare is emitted; unless it is also
/// the first dimension, its result is yielded to the enclosing scf.if. For any
/// other dimension an scf.if on `x[i] != x[j]` is emitted whose then-branch
/// yields `x[i] <u x[j]`; the insertion point is left at the start of the
/// else-branch so that the next dimension's compare is nested there. Returns
/// the compare result or the scf.if result, respectively.
static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i,
                                   Value j, Value x, bool isFirstDim,
                                   bool isLastDim) {
  Value lhs = builder.create<memref::LoadOp>(loc, x, i);
  Value rhs = builder.create<memref::LoadOp>(loc, x, j);

  if (isLastDim) {
    Value isLess =
        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, lhs, rhs);
    // A single-dimension compare has no enclosing scf.if to yield to.
    if (!isFirstDim)
      builder.create<scf::YieldOp>(loc, isLess);
    return isLess;
  }

  Value differs =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, lhs, rhs);
  scf::IfOp branch = builder.create<scf::IfOp>(
      loc, builder.getIntegerType(1), differs, /*withElseRegion=*/true);

  // x[i] != x[j]: this dimension alone decides the ordering.
  builder.setInsertionPointToStart(&branch.getThenRegion().front());
  Value isLess =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, lhs, rhs);
  builder.create<scf::YieldOp>(loc, isLess);

  // x[i] == x[j]: the remaining dimensions decide; code goes in the else part.
  builder.setInsertionPointToStart(&branch.getElseRegion().front());
  return branch.getResult(0);
}
/// Emits an inlined less-than comparison between the entries at `args[0]` and
/// `args[1]` over all dimensions of `xPerm`, using createLessThanCompare for
/// each dimension. `nTrailingP` must be 0.
static Value createInlinedLessThan(OpBuilder &builder, Location loc,
                                   ValueRange args, AffineMap xPerm,
                                   uint64_t ny, uint32_t nTrailingP = 0) {
  assert(nTrailingP == 0);
  // Keep release builds free of an unused-parameter warning.
  (void)nTrailingP;
  Value isLess = createInlinedCompareImplementation(
      builder, loc, args, xPerm, ny, createLessThanCompare);
  return isLess;
}
/// Fills in the body of the binary-search helper `func`.
///
/// The generated code narrows the range [lo, hi) taken from the function
/// arguments until lo >= hi and returns the final lo:
///
///   p = hi
///   while (lo < hi) {
///     mid = (lo + hi) >> 1
///     if (data[p] < data[mid]) hi = mid
///     else                     lo = mid + 1
///   }
///   return lo
///
/// `nTrailingP` must be 0.
static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
                                   func::FuncOp func, AffineMap xPerm,
                                   uint64_t ny, uint32_t nTrailingP = 0) {
  assert(nTrailingP == 0);
  (void)nTrailingP;
  OpBuilder::InsertionGuard guard(builder);
  Block *entry = func.addEntryBlock();
  builder.setInsertionPointToStart(entry);

  Location loc = func.getLoc();
  ValueRange args = entry->getArguments();
  Value p = args[hiIdx];
  SmallVector<Type, 2> loopTypes(2, p.getType());
  SmallVector<Value, 2> loopInits{args[loIdx], args[hiIdx]};
  SmallVector<Location, 2> loopLocs(2, loc);
  scf::WhileOp whileOp =
      builder.create<scf::WhileOp>(loc, loopTypes, loopInits);

  // Before-region: keep looping while lo < hi.
  Block *before =
      builder.createBlock(&whileOp.getBefore(), {}, loopTypes, loopLocs);
  builder.setInsertionPointToEnd(before);
  Value keepGoing = builder.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::ult, before->getArgument(0),
      before->getArgument(1));
  builder.create<scf::ConditionOp>(loc, keepGoing, before->getArguments());

  // After-region: halve the range.
  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, loopTypes, loopLocs);
  builder.setInsertionPointToEnd(after);
  Value lo = after->getArgument(0);
  Value hi = after->getArgument(1);
  Value c1 = constantIndex(builder, loc, 1);
  Value loPlusHi = builder.create<arith::AddIOp>(loc, lo, hi);
  Value mid = builder.create<arith::ShRUIOp>(loc, loPlusHi, c1);
  Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1);

  // Compare the entry at p against the entry at mid in the first buffer.
  constexpr uint64_t numXBuffers = 1;
  SmallVector<Value> compareOperands{p, mid};
  compareOperands.append(args.begin() + xStartIdx,
                         args.begin() + xStartIdx + numXBuffers);
  Value pLessThanMid =
      createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);

  // data[p] < data[mid] ? (lo, mid) : (mid + 1, hi)
  Value nextLo = builder.create<arith::SelectOp>(loc, pLessThanMid, lo, midp1);
  Value nextHi = builder.create<arith::SelectOp>(loc, pLessThanMid, mid, hi);
  builder.create<scf::YieldOp>(loc, ValueRange{nextLo, nextHi});

  builder.setInsertionPointAfter(whileOp);
  builder.create<func::ReturnOp>(loc, whileOp.getResult(0));
}
/// Emits a loop that advances the index `i` by `step` for as long as the
/// less-than comparison against the entry at `p` holds, followed by an
/// equality comparison between the final index and `p`.
///
/// For step > 0 the loop condition is data[i] < data[p]; for step < 0 it is
/// data[p] < data[i]. `step` must not be zero. Returns the final index and the
/// result of the equality comparison. The insertion point is left after the
/// emitted code.
static std::pair<Value, Value> createScanLoop(OpBuilder &builder,
                                              ModuleOp module,
                                              func::FuncOp func, ValueRange xs,
                                              Value i, Value p, AffineMap xPerm,
                                              uint64_t ny, int step) {
  Location loc = func.getLoc();
  Type indexTy = i.getType();
  scf::WhileOp whileOp =
      builder.create<scf::WhileOp>(loc, TypeRange{indexTy}, ValueRange{i});

  // Before-region: evaluate the scan condition on the current index.
  Block *before =
      builder.createBlock(&whileOp.getBefore(), {}, {indexTy}, {loc});
  builder.setInsertionPointToEnd(before);
  assert(step != 0);
  const bool forward = step > 0;
  Value current = before->getArgument(0);
  SmallVector<Value> compareOperands;
  compareOperands.push_back(forward ? current : p);
  compareOperands.push_back(forward ? p : current);
  compareOperands.append(xs.begin(), xs.end());
  Value keepScanning =
      createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
  builder.create<scf::ConditionOp>(loc, keepScanning, before->getArguments());

  // After-region: move the index by one step.
  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, {indexTy}, {loc});
  builder.setInsertionPointToEnd(after);
  Value stepVal = constantIndex(builder, loc, step);
  Value advanced =
      builder.create<arith::AddIOp>(loc, after->getArgument(0), stepVal);
  builder.create<scf::YieldOp>(loc, ValueRange{advanced});

  // Check whether the scan stopped on an entry equal to the one at p.
  Value stoppedAt = whileOp.getResult(0);
  builder.setInsertionPointAfter(whileOp);
  compareOperands[0] = stoppedAt;
  compareOperands[1] = p;
  Value compareEq =
      createInlinedEqCompare(builder, loc, compareOperands, xPerm, ny);

  return std::make_pair(stoppedAt, compareEq);
}
/// Emits `if (data[b] < data[a]) swap(data[b], data[a])` and returns the
/// scf.if operation.
///
/// The first two slots of `compareOperands` and `swapOperands` are overwritten
/// with (b, a). On return the insertion point is inside the then-region,
/// after the swap code, so the caller may nest further code there.
static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc,
                                       AffineMap xPerm, uint64_t ny,
                                       SmallVectorImpl<Value> &swapOperands,
                                       SmallVectorImpl<Value> &compareOperands,
                                       Value a, Value b) {
  compareOperands[0] = b;
  compareOperands[1] = a;
  Value outOfOrder =
      createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);

  scf::IfOp guardedSwap =
      builder.create<scf::IfOp>(loc, outOfOrder, /*withElseRegion=*/false);
  builder.setInsertionPointToStart(&guardedSwap.getThenRegion().front());

  swapOperands[0] = b;
  swapOperands[1] = a;
  createSwap(builder, loc, swapOperands, xPerm, ny);
  return guardedSwap;
}
/// Emits code that compare-and-swaps (v1, v2) and, nested inside the branch
/// where that swap happened, compare-and-swaps (v0, v1). The insertion point
/// is left after the outer scf.if.
static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm,
                            uint64_t ny, SmallVectorImpl<Value> &swapOperands,
                            SmallVectorImpl<Value> &compareOperands, Value v0,
                            Value v1, Value v2) {
  scf::IfOp outerIf = createCompareThenSwap(
      builder, loc, xPerm, ny, swapOperands, compareOperands, v1, v2);
  // Only when v2 moved down can it still be out of order with respect to v0.
  createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, compareOperands,
                        v0, v1);
  builder.setInsertionPointAfter(outerIf);
}
/// Emits code that orders the three entries at v0, v1 and v2: first a
/// compare-and-swap of (v0, v1), then createInsert3rd for v2.
static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm,
                        uint64_t ny, SmallVectorImpl<Value> &swapOperands,
                        SmallVectorImpl<Value> &compareOperands, Value v0,
                        Value v1, Value v2) {
  scf::IfOp firstPair = createCompareThenSwap(
      builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1);
  // Leave the then-region so the next step runs unconditionally.
  builder.setInsertionPointAfter(firstPair);
  createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
                  v1, v2);
}
/// Emits code that orders the five entries at v0..v4.
///
/// The first three entries are ordered with createSort3. The entry at v3 is
/// then inserted via a compare-and-swap of (v2, v3) with a nested
/// createInsert3rd. Finally (v3, v4) is compare-and-swapped, and the same
/// insertion step for v3 is nested inside the branch where that swap happened.
static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm,
                        uint64_t ny, SmallVectorImpl<Value> &swapOperands,
                        SmallVectorImpl<Value> &compareOperands, Value v0,
                        Value v1, Value v2, Value v3, Value v4) {
  createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1,
              v2);

  // Inserts the entry at v3 into the ordered entries at v0..v2.
  auto emitInsert4th = [&]() {
    scf::IfOp swapped23 = createCompareThenSwap(
        builder, loc, xPerm, ny, swapOperands, compareOperands, v2, v3);
    createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
                    v1, v2);
    builder.setInsertionPointAfter(swapped23);
  };

  emitInsert4th();

  scf::IfOp swapped34 = createCompareThenSwap(
      builder, loc, xPerm, ny, swapOperands, compareOperands, v3, v4);
  // Still inside the then-region of the (v3, v4) swap.
  emitInsert4th();
  builder.setInsertionPointAfter(swapped34);
}
/// Emits code that prepares the pivot for a partition step by ordering a
/// small sample of entries so that the entry at `mi` holds their median.
///
/// With len = hi + 1 - lo:
///   * len < 1000: the three entries at lo, mi and hi are sorted;
///   * otherwise:  the five entries at lo, a, mi, b and hi are sorted, where
///     a = (lo + mi + 1) >> 1 and b = (mi + hi + 1) >> 1.
///
/// `args` supplies the buffers: only the buffer at `args[xStartIdx]` is used
/// for comparisons, while every buffer from `args[xStartIdx]` on takes part in
/// swaps. The insertion point is left after the emitted scf.if.
static void createChoosePivot(OpBuilder &builder, ModuleOp module,
                              func::FuncOp func, AffineMap xPerm, uint64_t ny,
                              Value lo, Value hi, Value mi, ValueRange args) {
  // The first two slots are placeholders; createCompareThenSwap overwrites
  // them for every comparison and swap.
  SmallVector<Value> compareOperands{mi, lo};
  constexpr uint64_t numXBuffers = 1;
  compareOperands.append(args.begin() + xStartIdx,
                         args.begin() + xStartIdx + numXBuffers);
  SmallVector<Value> swapOperands{mi, lo};
  swapOperands.append(args.begin() + xStartIdx, args.end());

  Location loc = func.getLoc();
  Value c1 = constantIndex(builder, loc, 1);
  Value hiP1 = builder.create<arith::AddIOp>(loc, hi, c1);
  Value len = builder.create<arith::SubIOp>(loc, hiP1, lo);
  Value lenThreshold = constantIndex(builder, loc, 1000);
  Value lenCond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                                len, lenThreshold);
  scf::IfOp lenIf =
      builder.create<scf::IfOp>(loc, lenCond, /*withElseRegion=*/true);

  // len < 1000: median of three.
  builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
  createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, mi,
              hi);

  // len >= 1000: median of five.
  builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
  // Fix: derive miP1 from mi (it was previously computed from hi), so that
  // `a` is the midpoint of [lo, mi] instead of the midpoint of [lo, hi].
  Value miP1 = builder.create<arith::AddIOp>(loc, mi, c1);
  Value a = builder.create<arith::AddIOp>(loc, lo, miP1);
  a = builder.create<arith::ShRUIOp>(loc, a, c1);
  // `b` is the midpoint of [mi, hi].
  Value b = builder.create<arith::AddIOp>(loc, mi, hiP1);
  b = builder.create<arith::ShRUIOp>(loc, b, c1);
  createSort5(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, a, mi,
              b, hi);

  builder.setInsertionPointAfter(lenIf);
}
// Fills in the body of the partition helper `func`. nTrailingP must be 0.
//
// Outline of the generated code, with lo/hi taken from the function arguments:
//
//   p = (lo + hi) >> 1
//   i = lo; j = hi - 1
//   <choose pivot over (i, j, p)>
//   cont = true
//   while (cont) {
//     while (data[i] < data[p]) i++      // also records data[i] == data[p]
//     while (data[p] < data[j]) j--      // also records data[j] == data[p]
//     if (i < j) {
//       swap(data[i], data[j])
//       if (i == p) p = j; else if (j == p) p = i;
//       if (data[i] == data[p] && data[j] == data[p]) { i++; j--; }
//     } else {
//       p = j + 1; cont = false
//     }
//   }
//   return p
//
// NOTE(review): the two equality flags are computed by the scan loops before
// the swap and before p is updated - confirm this is the intended timing.
static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
func::FuncOp func, AffineMap xPerm, uint64_t ny,
uint32_t nTrailingP = 0) {
(void)nTrailingP;
assert(nTrailingP == 0);
OpBuilder::InsertionGuard insertionGuard(builder);
Block *entryBlock = func.addEntryBlock();
builder.setInsertionPointToStart(entryBlock);
Location loc = func.getLoc();
ValueRange args = entryBlock->getArguments();
Value lo = args[loIdx];
Value hi = args[hiIdx];
// p = (lo + hi) >> 1.
Value sum = builder.create<arith::AddIOp>(loc, lo, hi);
Value c1 = constantIndex(builder, loc, 1);
Value p = builder.create<arith::ShRUIOp>(loc, sum, c1);
Value i = lo;
Value j = builder.create<arith::SubIOp>(loc, hi, c1);
createChoosePivot(builder, module, func, xPerm, ny, i, j, p, args);
// The while loop carries (i, j, p, cont); cont starts as true.
Value trueVal = constantI1(builder, loc, true);
SmallVector<Value, 4> operands{i, j, p, trueVal};
SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(),
trueVal.getType()};
scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);
// Before-region: the loop condition is simply the carried cont flag.
Block *before = builder.createBlock(&whileOp.getBefore(), {}, types,
{loc, loc, loc, loc});
builder.setInsertionPointToEnd(before);
builder.create<scf::ConditionOp>(loc, before->getArgument(3),
before->getArguments());
// After-region: one round of scanning from both ends.
Block *after =
builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc, loc});
builder.setInsertionPointToEnd(after);
i = after->getArgument(0);
j = after->getArgument(1);
p = after->getArgument(2);
constexpr uint64_t numXBuffers = 1;
// Advance i upwards (step 1) and j downwards (step -1); only the first
// buffer is passed to the scan loops for comparison.
auto [iresult, iCompareEq] =
createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
i, p, xPerm, ny, 1);
i = iresult;
auto [jresult, jCompareEq] =
createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
j, p, xPerm, ny, -1);
j = jresult;
// if (i < j): swap the two entries and continue.
Value cond =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, true);
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
SmallVector<Value> swapOperands{i, j};
swapOperands.append(args.begin() + xStartIdx, args.end());
createSwap(builder, loc, swapOperands, xPerm, ny);
// If the swap moved the pivot entry, follow it: i == p -> p = j,
// else j == p -> p = i, else p is unchanged.
Value icond =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
icond, true);
builder.setInsertionPointToStart(&ifOpI.getThenRegion().front());
builder.create<scf::YieldOp>(loc, ValueRange{j});
builder.setInsertionPointToStart(&ifOpI.getElseRegion().front());
Value jcond =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p);
scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
jcond, true);
builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front());
builder.create<scf::YieldOp>(loc, ValueRange{i});
builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front());
builder.create<scf::YieldOp>(loc, ValueRange{p});
builder.setInsertionPointAfter(ifOpJ);
builder.create<scf::YieldOp>(loc, ifOpJ.getResults());
builder.setInsertionPointAfter(ifOpI);
// When both scans stopped on an entry equal to the pivot, step both indices
// past those entries: (i + 1, j - 1); otherwise keep (i, j).
Value compareEqIJ =
builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq);
scf::IfOp ifOp2 = builder.create<scf::IfOp>(
loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, true);
builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
Value i2 = builder.create<arith::AddIOp>(loc, i, c1);
Value j2 = builder.create<arith::SubIOp>(loc, j, c1);
builder.create<scf::YieldOp>(loc, ValueRange{i2, j2});
builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
builder.create<scf::YieldOp>(loc, ValueRange{i, j});
builder.setInsertionPointAfter(ifOp2);
// Yield (new i, new j, new p, cont = true) from the i < j branch.
builder.create<scf::YieldOp>(
loc,
ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0),
constantI1(builder, loc, true)});
// else (i >= j): set p = j + 1 and stop the loop (cont = false).
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
p = builder.create<arith::AddIOp>(loc, j,
constantOne(builder, loc, j.getType()));
builder.create<scf::YieldOp>(
loc, ValueRange{i, j, p, constantI1(builder, loc, false)});
// Forward the scf.if results as the next loop-carried values.
builder.setInsertionPointAfter(ifOp);
builder.create<scf::YieldOp>(loc, ifOp.getResults());
// The function returns the final p (loop result #2).
builder.setInsertionPointAfter(whileOp);
builder.create<func::ReturnOp>(loc, whileOp.getResult(2));
}
/// Emits code computing `(n - 2) >> 1` with a logical (unsigned) right shift
/// and returns the resulting value.
static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc,
                                      Value n) {
  Value two = constantIndex(builder, loc, 2);
  Value nMinusTwo = builder.create<arith::SubIOp>(loc, n, two);
  Value one = constantIndex(builder, loc, 1);
  Value halved = builder.create<arith::ShRUIOp>(loc, nMinusTwo, one);
  return halved;
}
// Fills in the body of the shift-down helper `func` used by heap sort.
// nTrailingP must be 1: the last function argument is `n`, and the remaining
// arguments are (first, start, buffers...).
//
// Outline of the generated code:
//
//   if (n >= 2) {
//     child = start - first
//     if ((n - 2) / 2 >= child) {
//       (child, childIdx) = larger child of `child`
//       while (data[start] < data[childIdx]) {
//         swap(data[start], data[childIdx])
//         start = childIdx
//         if ((n - 2) / 2 >= child)
//           (child, childIdx) = larger child of `child`
//       }
//     }
//   }
//   return
static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
func::FuncOp func, AffineMap xPerm, uint64_t ny,
uint32_t nTrailingP) {
assert(nTrailingP == 1);
OpBuilder::InsertionGuard insertionGuard(builder);
Block *entryBlock = func.addEntryBlock();
builder.setInsertionPointToStart(entryBlock);
Location loc = func.getLoc();
// Split off the trailing `n` argument from the regular sort arguments.
Value n = entryBlock->getArguments().back();
ValueRange args = entryBlock->getArguments().drop_back();
Value first = args[loIdx];
Value start = args[hiIdx];
// if (n >= 2).
Value c2 = constantIndex(builder, loc, 2);
Value condN =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, n, c2);
scf::IfOp ifN = builder.create<scf::IfOp>(loc, condN, false);
builder.setInsertionPointToStart(&ifN.getThenRegion().front());
// child is the position of `start` relative to `first`.
Value child = builder.create<arith::SubIOp>(loc, start, first);
// if ((n - 2) / 2 >= child).
Value t = createSubTwoDividedByTwo(builder, loc, n);
Value condNc =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
scf::IfOp ifNc = builder.create<scf::IfOp>(loc, condNc, false);
builder.setInsertionPointToStart(&ifNc.getThenRegion().front());
Value c1 = constantIndex(builder, loc, 1);
// The first two slots are placeholders that are overwritten before each
// comparison; only the first buffer is used for comparisons.
SmallVector<Value> compareOperands{start, start};
constexpr uint64_t numXBuffers = 1;
compareOperands.append(args.begin() + xStartIdx,
args.begin() + xStartIdx + numXBuffers);
// Emits code that inspects the children of `r` and returns the pair
// (relative position, absolute index) of the selected child:
//   lChild = r * 2 + 1; lChildIdx = lChild + first
//   rChild = lChild + 1
//   if (rChild < n && data[lChildIdx] < data[rChild + first])
//     select the right child
//   else
//     select the left child
auto getLargerChild = [&](Value r) -> std::pair<Value, Value> {
Value lChild = builder.create<arith::ShLIOp>(loc, r, c1);
lChild = builder.create<arith::AddIOp>(loc, lChild, c1);
Value lChildIdx = builder.create<arith::AddIOp>(loc, lChild, first);
Value rChild = builder.create<arith::AddIOp>(loc, lChild, c1);
Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
rChild, n);
SmallVector<Type, 2> ifTypes(2, r.getType());
scf::IfOp if1 =
builder.create<scf::IfOp>(loc, ifTypes, cond1, true);
builder.setInsertionPointToStart(&if1.getThenRegion().front());
Value rChildIdx = builder.create<arith::AddIOp>(loc, rChild, first);
// Compare data[left] < data[right].
compareOperands[0] = lChildIdx;
compareOperands[1] = rChildIdx;
Value cond2 =
createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
scf::IfOp if2 =
builder.create<scf::IfOp>(loc, ifTypes, cond2, true);
builder.setInsertionPointToStart(&if2.getThenRegion().front());
builder.create<scf::YieldOp>(loc, ValueRange{rChild, rChildIdx});
builder.setInsertionPointToStart(&if2.getElseRegion().front());
builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
builder.setInsertionPointAfter(if2);
builder.create<scf::YieldOp>(loc, if2.getResults());
// No right child within the first n entries: the left child is selected.
builder.setInsertionPointToStart(&if1.getElseRegion().front());
builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
builder.setInsertionPointAfter(if1);
return std::make_pair(if1.getResult(0), if1.getResult(1));
};
Value childIdx;
std::tie(child, childIdx) = getLargerChild(child);
// The while loop carries (start, child, childIdx).
SmallVector<Type, 3> types(3, child.getType());
scf::WhileOp whileOp = builder.create<scf::WhileOp>(
loc, types, SmallVector<Value, 2>{start, child, childIdx});
SmallVector<Location, 3> locs(3, loc);
// Before-region: loop while data[start] < data[childIdx].
Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
builder.setInsertionPointToEnd(before);
start = before->getArgument(0);
childIdx = before->getArgument(2);
compareOperands[0] = start;
compareOperands[1] = childIdx;
Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
// After-region: swap the entry with the selected child and descend.
Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
start = after->getArgument(0);
child = after->getArgument(1);
childIdx = after->getArgument(2);
SmallVector<Value> swapOperands{start, childIdx};
swapOperands.append(args.begin() + xStartIdx, args.end());
createSwap(builder, loc, swapOperands, xPerm, ny);
start = childIdx;
// Pick the next child only if (n - 2) / 2 >= child; otherwise keep the
// current (child, childIdx).
Value cond2 =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
scf::IfOp if2 = builder.create<scf::IfOp>(
loc, TypeRange{child.getType(), child.getType()}, cond2, true);
builder.setInsertionPointToStart(&if2.getThenRegion().front());
auto [newChild, newChildIdx] = getLargerChild(child);
builder.create<scf::YieldOp>(loc, ValueRange{newChild, newChildIdx});
builder.setInsertionPointToStart(&if2.getElseRegion().front());
builder.create<scf::YieldOp>(loc, ValueRange{child, childIdx});
builder.setInsertionPointAfter(if2);
builder.create<scf::YieldOp>(
loc, ValueRange{start, if2.getResult(0), if2.getResult(1)});
// The function returns nothing; the return goes after the outermost scf.if.
builder.setInsertionPointAfter(ifN);
builder.create<func::ReturnOp>(loc);
}
// Fills in the body of the heap-sort helper `func`. nTrailingP must be 0.
//
// Outline of the generated code, with n = hi - lo:
//
//   for (i = (n - 2) / 2 downto 0)
//     shift_down(lo, lo + i, buffers..., n)
//   for (l = n downto 2) {
//     swap(data[lo], data[lo + l - 1])
//     shift_down(lo, lo, buffers..., l - 1)
//   }
//   return
//
// Both loops are emitted as upward-counting scf.for loops; the downward
// counter is derived from the induction variable by subtraction.
static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
func::FuncOp func, AffineMap xPerm, uint64_t ny,
uint32_t nTrailingP) {
(void)nTrailingP;
assert(nTrailingP == 0);
OpBuilder::InsertionGuard insertionGuard(builder);
Block *entryBlock = func.addEntryBlock();
builder.setInsertionPointToStart(entryBlock);
Location loc = func.getLoc();
ValueRange args = entryBlock->getArguments();
Value lo = args[loIdx];
Value hi = args[hiIdx];
Value n = builder.create<arith::SubIOp>(loc, hi, lo);
Value c0 = constantIndex(builder, loc, 0);
Value c1 = constantIndex(builder, loc, 1);
// First loop: iv runs over [0, s + 1) with s = (n - 2) / 2, and i = s - iv.
Value s = createSubTwoDividedByTwo(builder, loc, n);
Value up = builder.create<arith::AddIOp>(loc, s, c1);
scf::ForOp forI = builder.create<scf::ForOp>(loc, c0, up, c1);
builder.setInsertionPointToStart(forI.getBody());
Value i = builder.create<arith::SubIOp>(loc, s, forI.getInductionVar());
Value lopi = builder.create<arith::AddIOp>(loc, lo, i);
// Operands of the shift-down call: (lo, lo + i, buffers..., n).
SmallVector<Value> shiftDownOperands = {lo, lopi};
shiftDownOperands.append(args.begin() + xStartIdx, args.end());
shiftDownOperands.push_back(n);
// The trailing `n` operand is excluded from name mangling (nTrailingP = 1).
FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc(
builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny,
shiftDownOperands, createShiftDownFunc, 1);
builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
shiftDownOperands);
builder.setInsertionPointAfter(forI);
// Second loop: iv runs over [0, n - 1), and l = n - iv.
up = builder.create<arith::SubIOp>(loc, n, c1);
scf::ForOp forL = builder.create<scf::ForOp>(loc, c0, up, c1);
builder.setInsertionPointToStart(forL.getBody());
Value l = builder.create<arith::SubIOp>(loc, n, forL.getInductionVar());
// Swap the entry at lo with the entry at lo + l - 1.
Value loplm1 = builder.create<arith::AddIOp>(loc, lo, l);
loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1);
SmallVector<Value> swapOperands{lo, loplm1};
swapOperands.append(args.begin() + xStartIdx, args.end());
createSwap(builder, loc, swapOperands, xPerm, ny);
// Shift down from lo, this time with l - 1 as the trailing operand.
shiftDownOperands[1] = lo;
shiftDownOperands[shiftDownOperands.size() - 1] =
builder.create<arith::SubIOp>(loc, l, c1);
builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
shiftDownOperands);
builder.setInsertionPointAfter(forL);
builder.create<func::ReturnOp>(loc);
}
// Emits one quick-sort step at the current insertion point of `builder` and
// returns a pair of index values (the two results of the outer scf.if).
//
// Outline of the generated code, with lo/hi taken from `args`:
//
//   p = partition(args without the last nTrailingP operands)
//   if (hi - lo > 2) {
//     if (p - lo <= hi - p) {
//       if (p - lo != 0) func(lo, p, rest of args...)
//       result = (p, hi)
//     } else {
//       if (hi - p != 0) func(p, hi, rest of args...)
//       result = (lo, p)
//     }
//   } else {
//     result = (lo, lo)
//   }
//
// That is, `func` is called recursively on the smaller side and the other
// side is returned. NOTE(review): the returned range is presumably processed
// by a loop in the caller, and (lo, lo) presumably marks "nothing left to
// sort" - confirm against the caller.
static std::pair<Value, Value>
createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
ValueRange args, AffineMap xPerm, uint64_t ny,
uint32_t nTrailingP) {
MLIRContext *context = module.getContext();
Location loc = func.getLoc();
Value lo = args[loIdx];
Value hi = args[hiIdx];
SmallVector<Type, 2> types(2, lo.getType());
// Look up or create the partition helper and call it; it returns one index.
FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm,
ny, args.drop_back(nTrailingP), createPartitionFunc);
Value p = builder
.create<func::CallOp>(loc, partitionFunc,
TypeRange{IndexType::get(context)},
args.drop_back(nTrailingP))
.getResult(0);
Value lenLow = builder.create<arith::SubIOp>(loc, p, lo);
Value lenHigh = builder.create<arith::SubIOp>(loc, hi, p);
// Only ranges with hi - lo > 2 get further processing.
Value c2 = constantIndex(builder, loc, 2);
Value len = builder.create<arith::SubIOp>(loc, hi, lo);
Value lenGtTwo =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, len, c2);
scf::IfOp ifLenGtTwo =
builder.create<scf::IfOp>(loc, types, lenGtTwo, true);
// hi - lo <= 2: yield the empty range (lo, lo).
builder.setInsertionPointToStart(&ifLenGtTwo.getElseRegion().front());
builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
// hi - lo > 2: decide which side is smaller.
builder.setInsertionPointToStart(&ifLenGtTwo.getThenRegion().front());
Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
lenLow, lenHigh);
Value c0 = constantIndex(builder, loc, 0);
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, true);
// Emits `if (len != 0) func(low, high, args[xStartIdx..])` and leaves the
// insertion point after that scf.if.
auto mayRecursion = [&](Value low, Value high, Value len) {
Value cond =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, len, c0);
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, false);
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
SmallVector<Value> operands{low, high};
operands.append(args.begin() + xStartIdx, args.end());
builder.create<func::CallOp>(loc, func, operands);
builder.setInsertionPointAfter(ifOp);
};
// Low side is not larger: recurse on [lo, p), yield (p, hi).
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
mayRecursion(lo, p, lenLow);
builder.create<scf::YieldOp>(loc, ValueRange{p, hi});
// High side is smaller: recurse on [p, hi), yield (lo, p).
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
mayRecursion(p, hi, lenHigh);
builder.create<scf::YieldOp>(loc, ValueRange{lo, p});
builder.setInsertionPointAfter(ifOp);
builder.create<scf::YieldOp>(loc, ifOp.getResults());
builder.setInsertionPointAfter(ifLenGtTwo);
return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1));
}
// Fills in the body of the stable-sort helper `func`, which is an insertion
// sort that locates each insertion point with the binary-search helper.
// nTrailingP must be 0.
//
// Outline of the generated code, with lo/hi taken from the function arguments:
//
//   for (i = lo + 1; i < hi; i++) {
//     p = binary_search(lo, i, buffers...)
//     d = data[i]
//     for (j = 0; j < i - p; j++)
//       data[i - j] = data[i - j - 1]
//     data[p] = d
//   }
//   return
static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
func::FuncOp func, AffineMap xPerm,
uint64_t ny, uint32_t nTrailingP) {
(void)nTrailingP;
assert(nTrailingP == 0);
OpBuilder::InsertionGuard insertionGuard(builder);
Block *entryBlock = func.addEntryBlock();
builder.setInsertionPointToStart(entryBlock);
MLIRContext *context = module.getContext();
Location loc = func.getLoc();
ValueRange args = entryBlock->getArguments();
// for (i = lo + 1; i < hi; i++).
Value c1 = constantIndex(builder, loc, 1);
Value lo = args[loIdx];
Value hi = args[hiIdx];
Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1);
scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1);
builder.setInsertionPointToStart(forOpI.getBody());
Value i = forOpI.getInductionVar();
// Call the binary-search helper with (lo, i, buffers...) to get the
// insertion point p.
SmallVector<Value> operands{lo, i};
operands.append(args.begin() + xStartIdx, args.end());
FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix,
xPerm, ny, operands, createBinarySearchFunc);
Value p = builder
.create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
operands)
.getResult(0);
// Load every value of the entry at i into d, in callback order.
operands[0] = operands[1] = i;
SmallVector<Value> d;
forEachIJPairInAllBuffers(
builder, loc, operands, xPerm, ny,
[&](uint64_t unused, Value i, Value unused2, Value buffer) {
d.push_back(builder.create<memref::LoadOp>(loc, buffer, i));
});
// Move the entries in [p, i) up by one position, starting from the top:
//   for (j = 0; j < i - p; j++) data[i - j] = data[i - j - 1]
Value imp = builder.create<arith::SubIOp>(loc, i, p);
Value c0 = constantIndex(builder, loc, 0);
scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1);
builder.setInsertionPointToStart(forOpJ.getBody());
Value j = forOpJ.getInductionVar();
Value imj = builder.create<arith::SubIOp>(loc, i, j);
operands[1] = imj;
operands[0] = builder.create<arith::SubIOp>(loc, imj, c1);
forEachIJPairInAllBuffers(
builder, loc, operands, xPerm, ny,
[&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1);
builder.create<memref::StoreOp>(loc, t, buffer, imj);
});
// Store the saved values into the entry at p; the callback index k selects
// the matching value in d.
builder.setInsertionPointAfter(forOpJ);
operands[0] = operands[1] = p;
forEachIJPairInAllBuffers(
builder, loc, operands, xPerm, ny,
[&](uint64_t k, Value p, Value usused, Value buffer) {
builder.create<memref::StoreOp>(loc, d[k], buffer, p);
});
builder.setInsertionPointAfter(forOpI);
builder.create<func::ReturnOp>(loc);
}
/// Populates `func` with the body of a quick sort helper.
///
/// The first two entry block arguments are the `lo` and `hi` bounds of the
/// range to sort and the remaining ones are forwarded to the helpers. When
/// `nTrailingP` is 1 the function is the hybrid variant, whose last argument
/// is an i64 depth limit:
///   - ranges of at most 30 elements are handed to the stable sort helper,
///   - once the decremented depth limit is <= 0 the heap sort helper is used,
///   - otherwise one quick sort step is emitted through `createQuickSort`.
/// When `nTrailingP` is 0 only the quick sort step is emitted.
///
/// The generated code is an scf.while over the pair (lo, hi) that keeps
/// iterating while `lo + 1 < hi` (unsigned comparison).
static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
                                func::FuncOp func, AffineMap xPerm, uint64_t ny,
                                uint32_t nTrailingP) {
  assert(nTrailingP == 1 || nTrailingP == 0);
  const bool isHybrid = nTrailingP == 1;

  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);
  Location loc = func.getLoc();

  // Mutable copy of the function arguments; the bounds (and, for the hybrid
  // variant, the depth limit) are replaced by loop-local values below.
  SmallVector<Value> args(entryBlock->getArguments().begin(),
                          entryBlock->getArguments().end());
  Value entryLo = args[loIdx];
  Value entryHi = args[hiIdx];

  // The loop carries exactly two values: the current lo and hi.
  SmallVector<Type, 2> types(2, entryLo.getType());
  scf::WhileOp whileOp = builder.create<scf::WhileOp>(
      loc, types, SmallVector<Value, 2>{entryLo, entryHi});

  // Before-region: continue while the range holds more than one element.
  Block *before =
      builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
  builder.setInsertionPointToEnd(before);
  Value indexOne = constantIndex(builder, loc, 1);
  Value loPlusOne =
      builder.create<arith::AddIOp>(loc, before->getArgument(0), indexOne);
  Value needSort = builder.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::ult, loPlusOne, before->getArgument(1));
  builder.create<scf::ConditionOp>(loc, needSort, before->getArguments());

  // After-region: perform one step and compute the next (lo, hi).
  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
  builder.setInsertionPointToEnd(after);
  Value lo = after->getArgument(0);
  Value hi = after->getArgument(1);
  args[loIdx] = lo;
  args[hiIdx] = hi;

  Value nextLo, nextHi;
  if (!isHybrid) {
    std::tie(nextLo, nextHi) =
        createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
  } else {
    // Arguments for the fallback sorts: everything but the depth limit. The
    // view stays valid because only the dropped trailing slot is rewritten.
    ValueRange helperArgs = ValueRange(args).drop_back(nTrailingP);

    Value len = builder.create<arith::SubIOp>(loc, hi, lo);
    Value lenLimit = constantIndex(builder, loc, 30);
    Value isShort = builder.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::ule, len, lenLimit);
    scf::IfOp lenIf = builder.create<scf::IfOp>(loc, types, isShort, true);

    // len <= limit: sort with the stable sort helper and yield the empty
    // range (lo, lo) so that the while-loop stops.
    builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
    FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
        builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny,
        helperArgs, createSortStableFunc);
    builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
                                 helperArgs);
    builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});

    // len > limit: decrement the depth limit and test whether it is used up.
    builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
    Value i64One = constantI64(builder, loc, 1);
    Value depthLimit =
        builder.create<arith::SubIOp>(loc, args.back(), i64One);
    Value i64Zero = constantI64(builder, loc, 0);
    Value depthExhausted = builder.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::ule, depthLimit, i64Zero);
    scf::IfOp depthIf =
        builder.create<scf::IfOp>(loc, types, depthExhausted, true);

    // Depth limit reached: fall back to heap sort and stop the loop.
    builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
    FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc(
        builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny,
        helperArgs, createHeapSortFunc);
    builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(), helperArgs);
    builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});

    // Depth limit not reached: one quick sort step with the new limit.
    builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
    args.back() = depthLimit;
    Value stepLo, stepHi;
    std::tie(stepLo, stepHi) =
        createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
    builder.create<scf::YieldOp>(loc, ValueRange{stepLo, stepHi});

    // Forward the result of the depth test out of the length test.
    builder.setInsertionPointAfter(depthIf);
    builder.create<scf::YieldOp>(loc, depthIf->getResults());

    builder.setInsertionPointAfter(lenIf);
    nextLo = lenIf.getResult(0);
    nextHi = lenIf.getResult(1);
  }

  // The new (lo, hi) for the next while-loop iteration.
  builder.create<scf::YieldOp>(loc, ValueRange{nextLo, nextHi});

  // After the while-loop.
  builder.setInsertionPointAfter(whileOp);
  builder.create<func::ReturnOp>(loc);
}
/// Replaces the sort operation `op` with a call to a sort helper function
/// selected by `op.getAlgorithm()`.
///
/// \param op       the sort operation being rewritten.
/// \param xys      the buffers to sort; each one whose first dimension is
///                 static is cast to a 1-D dynamically shaped memref first.
/// \param xPerm    forwarded to the helper function lookup/generation.
/// \param ny       forwarded to the helper function lookup/generation.
/// \param rewriter the rewriter used to create and replace operations.
/// \returns failure (without touching the IR) when `op` is not nested inside
///          a func::FuncOp, success otherwise.
template <typename OpTy>
LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm,
                                    uint64_t ny, PatternRewriter &rewriter) {
  // Check the only failure condition before creating any operation: a
  // pattern must not modify the IR when it reports a failed match.
  auto insertPoint = op->template getParentOfType<func::FuncOp>();
  if (!insertPoint)
    return failure();

  Location loc = op.getLoc();
  // The helper operands start with lo = 0 and hi = n.
  SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()};

  // Convert the buffers to have a dynamic shape and append them to the
  // operands.
  for (Value v : xys) {
    auto mtp = getMemRefType(v);
    if (!mtp.isDynamicDim(0)) {
      auto newMtp =
          MemRefType::get({ShapedType::kDynamic}, mtp.getElementType());
      v = rewriter.create<memref::CastOp>(loc, newMtp, v);
    }
    operands.push_back(v);
  }

  SmallString<32> funcName;
  FuncGeneratorType funcGenerator;
  uint32_t nTrailingP = 0;
  switch (op.getAlgorithm()) {
  case SparseTensorSortKind::HybridQuickSort: {
    funcName = kHybridQuickSortFuncNamePrefix;
    funcGenerator = createQuickSortFunc;
    // The hybrid variant takes one extra trailing operand: the depth limit.
    nTrailingP = 1;
    // depthLimit = 64 - ctlz(hi - lo), with the length taken as an i64.
    Value lo = operands[loIdx];
    Value hi = operands[hiIdx];
    Value len = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI64Type(),
        rewriter.create<arith::SubIOp>(loc, hi, lo));
    Value depthLimit = rewriter.create<arith::SubIOp>(
        loc, constantI64(rewriter, loc, 64),
        rewriter.create<math::CountLeadingZerosOp>(loc, len));
    operands.push_back(depthLimit);
    break;
  }
  case SparseTensorSortKind::QuickSort:
    funcName = kQuickSortFuncNamePrefix;
    funcGenerator = createQuickSortFunc;
    break;
  case SparseTensorSortKind::InsertionSortStable:
    funcName = kSortStableFuncNamePrefix;
    funcGenerator = createSortStableFunc;
    break;
  case SparseTensorSortKind::HeapSort:
    funcName = kHeapSortFuncNamePrefix;
    funcGenerator = createHeapSortFunc;
    break;
  }

  // Look up (or generate) the helper and replace the operation with a call.
  FlatSymbolRefAttr func =
      getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName,
                               xPerm, ny, operands, funcGenerator, nTrailingP);
  rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
  return success();
}
namespace {
/// Rewrites a PushBackOp into explicit buffer operations:
///
///   new_size = size + n
///   if (new_size > capacity(buffer))        // skipped when inbounds is set
///     do { capacity *= 2 } while (new_size > capacity)
///     buffer = realloc(buffer, capacity)
///     optionally zero-fill buffer[new_size, capacity)
///   store `value` into buffer[size, size + n)
///
/// The operation is replaced by the (possibly reallocated) buffer and the
/// new size. When `n` is absent or the constant 1, the capacity is doubled
/// once and a single store is emitted instead of a loop and a fill.
struct PushBackRewriter : OpRewritePattern<PushBackOp> {
public:
  using OpRewritePattern<PushBackOp>::OpRewritePattern;
  PushBackRewriter(MLIRContext *context, bool enableInit)
      : OpRewritePattern(context), enableBufferInitialization(enableInit) {}

  LogicalResult matchAndRewrite(PushBackOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Value c0 = constantIndex(rewriter, loc, 0);
    Value buffer = op.getInBuffer();
    // The capacity is the size of the buffer's first dimension.
    Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
    Value size = op.getCurSize();
    Value value = op.getValue();

    // A missing `n` means a single value is pushed.
    Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
    Value newSize = rewriter.create<arith::AddIOp>(loc, size, n);
    auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
    bool nIsOne = (nValue && nValue.value() == 1);

    if (!op.getInbounds()) {
      Value cond = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ugt, newSize, capacity);

      Value c2 = constantIndex(rewriter, loc, 2);
      auto bufferType =
          MemRefType::get({ShapedType::kDynamic}, value.getType());
      scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
                                                  /*else=*/true);

      // True branch: the buffer has to grow.
      // NOTE(review): doubling a zero capacity yields zero again, so this
      // presumably relies on buffers never having capacity 0 -- confirm.
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      if (nIsOne) {
        capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
      } else {
        // Use a do-while loop to compute the new capacity:
        //   do { capacity *= 2 } while (new_size > capacity)
        scf::WhileOp whileOp =
            rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity);

        // The before-region doubles the capacity and tests the condition.
        Block *before = rewriter.createBlock(&whileOp.getBefore(), {},
                                             {capacity.getType()}, {loc});
        rewriter.setInsertionPointToEnd(before);
        capacity =
            rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2);
        cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
                                              newSize, capacity);
        rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity});

        // The after-region only forwards the value to the next iteration.
        Block *after = rewriter.createBlock(&whileOp.getAfter(), {},
                                            {capacity.getType()}, {loc});
        rewriter.setInsertionPointToEnd(after);
        rewriter.create<scf::YieldOp>(loc, after->getArguments());

        rewriter.setInsertionPointAfter(whileOp);
        capacity = whileOp.getResult(0);
      }

      Value newBuffer =
          rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
      if (enableBufferInitialization) {
        // Zero-fill the part of the new buffer beyond the new size.
        Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
        Value fillValue = constantZero(rewriter, loc, value.getType());
        Value subBuffer = rewriter.create<memref::SubViewOp>(
            loc, newBuffer, /*offset=*/ValueRange{newSize},
            /*size=*/ValueRange{fillSize},
            /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
        rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer);
      }
      rewriter.create<scf::YieldOp>(loc, newBuffer);

      // False branch: the current buffer is large enough.
      rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
      rewriter.create<scf::YieldOp>(loc, buffer);

      // Continue after the conditional with whichever buffer it produced.
      rewriter.setInsertionPointAfter(ifOp);
      buffer = ifOp.getResult(0);
    }

    // Write the value(s) starting at the old size.
    if (nIsOne) {
      rewriter.create<memref::StoreOp>(loc, value, buffer, size);
    } else {
      Value subBuffer = rewriter.create<memref::SubViewOp>(
          loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n},
          /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
      rewriter.create<linalg::FillOp>(loc, value, subBuffer);
    }

    // The results are the buffer and the updated size.
    rewriter.replaceOp(op, {buffer, newSize});
    return success();
  }

private:
  /// Whether the unused tail of a reallocated buffer is zero-filled.
  /// Defaults to false so that the flag has a defined value even when the
  /// pattern is built through one of the inherited constructors.
  bool enableBufferInitialization = false;
};
/// Rewrites a SortOp through `matchAndRewriteSortOp`, passing the xy buffer
/// followed by the y buffers, the permutation map, and the `ny` attribute
/// value (0 when the attribute is absent).
struct SortRewriter : public OpRewritePattern<SortOp> {
public:
  using OpRewritePattern<SortOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SortOp op,
                                PatternRewriter &rewriter) const override {
    // The xy buffer comes first, followed by all y buffers.
    auto ys = op.getYs();
    SmallVector<Value> buffers;
    buffers.push_back(op.getXy());
    buffers.append(ys.begin(), ys.end());

    auto nyAttr = op.getNyAttr();
    const uint64_t ny = nyAttr ? nyAttr.getInt() : 0;
    return matchAndRewriteSortOp(op, buffers, op.getPermMap(), ny, rewriter);
  }
};
}
/// Adds the PushBackRewriter and SortRewriter patterns to `patterns`.
/// `enableBufferInitialization` is forwarded to the PushBackRewriter.
void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
                                         bool enableBufferInitialization) {
  MLIRContext *context = patterns.getContext();
  patterns.add<PushBackRewriter>(context, enableBufferInitialization);
  patterns.add<SortRewriter>(context);
}