//===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewriting rules that are specific to sparse tensor
// primitives with memref operands.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

//===---------------------------------------------------------------------===//
// Helper methods for the actual rewriting rules.
//===---------------------------------------------------------------------===//

// Positions of the leading operands shared by the sorting helpers in this
// file, which take their operands in the order (lo, hi, xs, ys).
// Index of the `lo` operand.
static constexpr uint64_t loIdx = 0;
// Index of the `hi` operand.
static constexpr uint64_t hiIdx = 1;
// Index of the first data buffer operand (the buffer holding the xs).
static constexpr uint64_t xStartIdx = 2;

// Name prefixes for the generated helper functions.
// NOTE(review): presumably each prefix is passed as `namePrefix` to
// getMangledSortHelperFunc by the corresponding rewriting rule; the call
// sites are outside this part of the file.
static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_";
static constexpr const char kBinarySearchFuncNamePrefix[] =
    "_sparse_binary_search_";
static constexpr const char kHybridQuickSortFuncNamePrefix[] =
    "_sparse_hybrid_qsort_";
static constexpr const char kSortStableFuncNamePrefix[] =
    "_sparse_sort_stable_";
static constexpr const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
static constexpr const char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";

// Signature of the callbacks that generate the body of a helper function.
// The callback receives, in order: the builder, the enclosing module, the
// (still empty) function to fill in, the permutation of the x values, the
// number of y values, and the number of trailing parameters.
using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
                                            AffineMap, uint64_t, uint32_t)>;

/// Writes the mangled name of a sort helper function into `nameOstream`.
/// The emitted name has the following shape:
///   <namePrefix><p0>_..._<pk>_<x type>_coo_<ny>_<y0 type>..._<yn type>
/// where <p0>...<pk> are the dimension positions listed by `xPerm`, <x type>
/// is the element type of the buffer at operand index `xStartIdx`, and the
/// <yi type> are the element types of all buffers following that operand.
static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
                                         StringRef namePrefix, AffineMap xPerm,
                                         uint64_t ny, ValueRange operands) {
  nameOstream << namePrefix;

  // Encode the permutation, one dimension position at a time.
  for (unsigned d = 0, numDims = xPerm.getNumResults(); d < numDims; ++d) {
    AffineExpr dimExpr = xPerm.getResult(d);
    nameOstream << cast<AffineDimExpr>(dimExpr).getPosition() << "_";
  }

  // Encode the element type of the first buffer and the number of ys.
  Type xElemType = getMemRefType(operands[xStartIdx]).getElementType();
  nameOstream << xElemType << "_coo_" << ny;

  // Encode the element type of every buffer after the first one.
  constexpr uint64_t numLeadingBuffers = 1;
  ValueRange trailingBuffers =
      operands.drop_front(xStartIdx + numLeadingBuffers);
  for (Value yBuffer : trailingBuffers)
    nameOstream << "_" << getMemRefType(yBuffer).getElementType();
}

/// Returns a symbol reference to the helper function that matches the given
/// operands, generating that function first when the module does not contain
/// it yet. The parameters `xPerm` and `ny` tell the number of x and y values
/// provided by the buffer in xStartIdx.
//
// Every sorting function generator receives (lo, hi, xs, ys) through
// `operands`. Additional parameters, such as the recursive call depth, sit at
// the end of the operand list as "trailing parameters"; `nTrailingP` gives
// their count. They are part of the function type but not of the mangled name.
static FlatSymbolRefAttr getMangledSortHelperFunc(
    OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes,
    StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands,
    FuncGeneratorType createFunc, uint32_t nTrailingP = 0) {
  // Mangle the name from the operands without the trailing parameters.
  SmallString<32> mangledName;
  llvm::raw_svector_ostream os(mangledName);
  getMangledSortHelperFuncName(os, namePrefix, xPerm, ny,
                               operands.drop_back(nTrailingP));

  ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
  MLIRContext *context = module.getContext();
  auto symbol = SymbolRefAttr::get(context, os.str());

  // Nothing to do when the helper was generated earlier.
  if (module.lookupSymbol<func::FuncOp>(symbol.getAttr()))
    return symbol;

  // Otherwise, declare a private function right before `insertPoint` and let
  // the generator fill in its body. The guard restores the caller's insertion
  // point on exit.
  OpBuilder::InsertionGuard insertionGuard(builder);
  builder.setInsertionPoint(insertPoint);
  FunctionType funcType =
      FunctionType::get(context, operands.getTypes(), resultTypes);
  auto helper =
      builder.create<func::FuncOp>(insertPoint.getLoc(), os.str(), funcType);
  helper.setPrivate();
  createFunc(builder, module, helper, xPerm, ny, nTrailingP);
  return symbol;
}

/// Invokes `bodyBuilder` once per x dimension to generate the code that
/// processes the value pair (xs[i], xs[j]) of that dimension. `args` holds
/// the two element indices in its first two slots and the data buffer at
/// `xStartIdx`. The callback receives the dimension number k together with
/// the two flat buffer positions and the buffer itself.
static void forEachIJPairInXs(
    OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
    uint64_t ny,
    function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
  // Each element occupies (number of xs + ny) consecutive buffer slots.
  Value stride = constantIndex(builder, loc, xPerm.getNumResults() + ny);
  Value iBase = builder.create<arith::MulIOp>(loc, args[0], stride);
  Value jBase = builder.create<arith::MulIOp>(loc, args[1], stride);
  Value buffer = args[xStartIdx];

  uint64_t k = 0;
  for (AffineExpr expr : xPerm.getResults()) {
    // The permutation tells which slot of the element holds the k-th key.
    unsigned slot = cast<AffineDimExpr>(expr).getPosition();
    Value slotCst = constantIndex(builder, loc, slot);
    Value iPos = builder.create<arith::AddIOp>(loc, slotCst, iBase);
    Value jPos = builder.create<arith::AddIOp>(loc, slotCst, jBase);
    bodyBuilder(k++, iPos, jPos, buffer);
  }
}

/// Invokes `bodyBuilder` to generate the code that processes each value pair
/// (xys[i], xys[j]) across all buffers: first every x and y value stored in
/// the buffer at `xStartIdx`, then one pair per remaining buffer in `args`.
static void forEachIJPairInAllBuffers(
    OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
    uint64_t ny,
    function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
  const unsigned numX = xPerm.getNumResults();

  // Extend the x permutation with identity entries for the ny y values, so
  // that the first buffer can be walked as if it held (numX + ny) x values.
  SmallVector<AffineExpr> xyExprs;
  xyExprs.append(xPerm.getResults().begin(), xPerm.getResults().end());
  for (unsigned y = 0; y < ny; y++)
    xyExprs.push_back(builder.getAffineDimExpr(y + numX));
  AffineMap xyPerm =
      AffineMap::get(xyExprs.size(), 0, xyExprs, builder.getContext());
  assert(xyPerm.isPermutation());

  forEachIJPairInXs(builder, loc, args, xyPerm, 0, bodyBuilder);

  // The remaining buffers are addressed directly with the element indices.
  constexpr uint64_t numHandledBuffers = 1;
  Value iIdx = args[0];
  Value jIdx = args[1];
  ValueRange remaining = args.drop_front(xStartIdx + numHandledBuffers);
  for (unsigned r = 0, e = remaining.size(); r < e; r++)
    bodyBuilder(r + xPerm.getNumResults() + ny, iIdx, jIdx, remaining[r]);
}

/// Generates the code that exchanges the values at index i and index j in
/// every buffer. `args` carries i and j in its first two slots, followed by
/// the buffers.
//
// In C-like pseudo code, the generated IR does:
//     swap(x0[i], x0[j]);
//     swap(x1[i], x1[j]);
//     ...
//     swap(xn[i], xn[j]);
//     swap(y0[i], y0[j]);
//     ...
//     swap(yn[i], yn[j]);
static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
                       AffineMap xPerm, uint64_t ny) {
  forEachIJPairInAllBuffers(
      builder, loc, args, xPerm, ny,
      [&](uint64_t /*k*/, Value i, Value j, Value buffer) {
        // Read both values before overwriting either position.
        Value atI = builder.create<memref::LoadOp>(loc, buffer, i);
        Value atJ = builder.create<memref::LoadOp>(loc, buffer, j);
        builder.create<memref::StoreOp>(loc, atJ, buffer, i);
        builder.create<memref::StoreOp>(loc, atI, buffer, j);
      });
}

/// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare
/// each pair is created via `compareBuilder`, which is invoked once per x
/// dimension with the builder, the location, the two buffer positions, the
/// buffer, and two flags telling whether the dimension is the first and/or
/// the last one. Returns the value produced for the first dimension, and
/// leaves the insertion point right after the operation defining that value.
static Value createInlinedCompareImplementation(
    OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
    uint64_t ny,
    function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)>
        compareBuilder) {
  Value result;
  auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
    bool isFirstDim = (k == 0);
    bool isLastDim = (k == xPerm.getNumResults() - 1);
    Value val =
        compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim);
    if (isFirstDim) {
      // The value computed for the first dimension is the overall result; the
      // code for the remaining dimensions is nested inside of it.
      result = val;
    } else if (!isLastDim) {
      // For a dimension in the middle, `val` is the result of a nested IfOp.
      // Yield it from the enclosing region, right after that IfOp. The guard
      // restores the insertion point chosen by `compareBuilder`, so that the
      // code for the next dimension still goes where `compareBuilder` wants.
      OpBuilder::InsertionGuard insertionGuard(builder);
      auto ifOp = cast<scf::IfOp>(val.getDefiningOp());
      builder.setInsertionPointAfter(ifOp);
      builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
    }
  };

  forEachIJPairInXs(builder, loc, args, xPerm, ny, bodyBuilder);

  // Continue generating code after the outermost compare operation.
  builder.setInsertionPointAfterValue(result);
  return result;
}

/// Emits the equality test between x[i] and x[j] for a single dimension and
/// returns the value holding the outcome. For every dimension but the last
/// one, the test is wrapped in an IfOp whose else-region is left open (as the
/// new insertion point) to receive the test of the next dimension.
static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j,
                             Value x, bool isFirstDim, bool isLastDim) {
  Value lhs = builder.create<memref::LoadOp>(loc, x, i);
  Value rhs = builder.create<memref::LoadOp>(loc, x, j);

  if (isLastDim) {
    Value isEq =
        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, lhs, rhs);
    // A 1D compare needs no control flow. In all other cases we are inside a
    // nested if-stmt, and the result has to be yielded from it.
    if (!isFirstDim)
      builder.create<scf::YieldOp>(loc, isEq);
    return isEq;
  }

  Value differs =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, lhs, rhs);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1),
                                             differs, /*else=*/true);

  // Then-branch, x[i] != x[j]: the elements are not equal.
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  Value falseVal = constantI1(builder, loc, false);
  builder.create<scf::YieldOp>(loc, falseVal);

  // Else-branch, x[i] == x[j]: leave the insertion point here for the nested
  // if-stmt that inspects the remaining dimensions.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  return ifOp.getResult(0);
}

/// Emits inlined code that tests whether xs[i] equals xs[j] in all dimensions
/// and returns the value holding the outcome.
//
// In C-like pseudo code, the generated IR does:
//   if (x0[i] != x0[j])
//     return false;
//   else
//     if (x1[i] != x1[j])
//       return false;
//     else if (x2[i] != x2[j])
//       and so on ...
static Value createInlinedEqCompare(OpBuilder &builder, Location loc,
                                    ValueRange args, AffineMap xPerm,
                                    uint64_t ny, uint32_t nTrailingP = 0) {
  // Compare functions take no trailing parameters. The void cast keeps the
  // parameter "used" when asserts are compiled out.
  assert(nTrailingP == 0);
  (void)nTrailingP;
  Value isEqual = createInlinedCompareImplementation(builder, loc, args, xPerm,
                                                     ny, createEqCompare);
  return isEqual;
}

/// Emits the (unsigned) less-than test between x[i] and x[j] for a single
/// dimension and returns the value holding the outcome. For every dimension
/// but the last one, the test is wrapped in an IfOp whose else-region is left
/// open (as the new insertion point) to receive the test of the next
/// dimension.
static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i,
                                   Value j, Value x, bool isFirstDim,
                                   bool isLastDim) {
  Value lhs = builder.create<memref::LoadOp>(loc, x, i);
  Value rhs = builder.create<memref::LoadOp>(loc, x, j);

  if (isLastDim) {
    Value isLess =
        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, lhs, rhs);
    // A 1D compare needs no control flow. In all other cases we are inside a
    // nested if-stmt, and the result has to be yielded from it.
    if (!isFirstDim)
      builder.create<scf::YieldOp>(loc, isLess);
    return isLess;
  }

  Value differs =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, lhs, rhs);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIntegerType(1),
                                             differs, /*else=*/true);

  // Then-branch, x[i] != x[j]: this dimension decides the order.
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  Value isLess =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, lhs, rhs);
  builder.create<scf::YieldOp>(loc, isLess);

  // Else-branch, x[i] == x[j]: leave the insertion point here for the nested
  // if-stmt that inspects the remaining dimensions.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  return ifOp.getResult(0);
}

/// Emits inlined code that tests whether xs[i] is less than xs[j], comparing
/// dimension by dimension, and returns the value holding the outcome.
//
// In C-like pseudo code, the generated IR does:
//   if (x0[i] != x0[j])
//     return x0[i] < x0[j];
//   else if (x1[i] != x1[j])
//     return x1[i] < x1[j];
//   else
//       and so on ...
static Value createInlinedLessThan(OpBuilder &builder, Location loc,
                                   ValueRange args, AffineMap xPerm,
                                   uint64_t ny, uint32_t nTrailingP = 0) {
  // Compare functions take no trailing parameters. The void cast keeps the
  // parameter "used" when asserts are compiled out.
  assert(nTrailingP == 0);
  (void)nTrailingP;
  Value isLess = createInlinedCompareImplementation(builder, loc, args, xPerm,
                                                    ny, createLessThanCompare);
  return isLess;
}

/// Creates a function to use a binary search to find the insertion point for
/// inserting xs[hi] to the sorted values xs[lo..hi). The body is generated
/// into the empty function `func`, whose leading arguments are (lo, hi, xs).
//
// The generated IR corresponds to this C like algorithm:
//   p = hi
//   while (lo < hi)
//      mid = (lo + hi) >> 1
//      if (xs[p] < xs[mid])
//        hi = mid
//      else
//        lo = mid + 1
//   return lo;
//
static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
                                   func::FuncOp func, AffineMap xPerm,
                                   uint64_t ny, uint32_t nTrailingP = 0) {
  // Binary search doesn't use trailing parameters.
  (void)nTrailingP;
  assert(nTrailingP == 0);
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  // The element to insert sits at the original `hi` position.
  Value p = args[hiIdx];
  SmallVector<Type, 2> types(2, p.getType()); // Only two types.
  // The loop carries the current (lo, hi) pair.
  scf::WhileOp whileOp = builder.create<scf::WhileOp>(
      loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});

  // The before-region of the WhileOp: continue while lo < hi.
  Block *before =
      builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
  builder.setInsertionPointToEnd(before);
  Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                              before->getArgument(0),
                                              before->getArgument(1));
  builder.create<scf::ConditionOp>(loc, cond1, before->getArguments());

  // The after-region of the WhileOp.
  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
  builder.setInsertionPointToEnd(after);
  Value lo = after->getArgument(0);
  Value hi = after->getArgument(1);
  // Compute mid = (lo + hi) >> 1.
  Value c1 = constantIndex(builder, loc, 1);
  Value mid = builder.create<arith::ShRUIOp>(
      loc, builder.create<arith::AddIOp>(loc, lo, hi), c1);
  Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1);

  // Compare xs[p] < xs[mid]. The compare helper expects its operands as
  // (i, j, xs).
  SmallVector<Value> compareOperands{p, mid};
  constexpr uint64_t numXBuffers = 1;
  compareOperands.append(args.begin() + xStartIdx,
                         args.begin() + xStartIdx + numXBuffers);
  Value cond2 = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
  // Update lo and hi for the WhileOp as follows:
  //   if (xs[p] < xs[mid])
  //     hi = mid;
  //   else
  //     lo = mid + 1;
  Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1);
  Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi);
  builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi});

  // The final value of lo is the insertion point.
  builder.setInsertionPointAfter(whileOp);
  builder.create<func::ReturnOp>(loc, whileOp.getResult(0));
}

/// Creates code to advance i in a loop based on xs[p] as follows:
///   while (xs[i] < xs[p]) i += step (step > 0)
/// or
///   while (xs[i] > xs[p]) i += step (step < 0)
/// The routine returns i as well as a boolean value to indicate whether
/// xs[i] == xs[p]. On return, the insertion point is after the code that
/// computes both values. `step` must not be zero.
static std::pair<Value, Value> createScanLoop(OpBuilder &builder,
                                              ModuleOp module,
                                              func::FuncOp func, ValueRange xs,
                                              Value i, Value p, AffineMap xPerm,
                                              uint64_t ny, int step) {
  Location loc = func.getLoc();
  // The loop carries a single value, the current index i.
  scf::WhileOp whileOp =
      builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});

  // The before-region of the WhileOp evaluates the loop condition.
  Block *before =
      builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc});
  builder.setInsertionPointToEnd(before);
  SmallVector<Value> compareOperands;
  if (step > 0) {
    // Scanning upwards: continue while xs[i] < xs[p].
    compareOperands.push_back(before->getArgument(0));
    compareOperands.push_back(p);
  } else {
    // Scanning downwards: continue while xs[p] < xs[i], i.e. xs[i] > xs[p].
    assert(step < 0);
    compareOperands.push_back(p);
    compareOperands.push_back(before->getArgument(0));
  }
  compareOperands.append(xs.begin(), xs.end());
  Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());

  // The after-region of the WhileOp performs i += step.
  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc});
  builder.setInsertionPointToEnd(after);
  Value cs = constantIndex(builder, loc, step);
  i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs);
  builder.create<scf::YieldOp>(loc, ValueRange{i});
  // From here on, i denotes the index at which the scan stopped.
  i = whileOp.getResult(0);

  // After the loop, test whether xs[i] == xs[p]. The operand vector is
  // reused: only the two index slots change, the buffers stay in place.
  builder.setInsertionPointAfter(whileOp);
  compareOperands[0] = i;
  compareOperands[1] = p;
  Value compareEq =
      createInlinedEqCompare(builder, loc, compareOperands, xPerm, ny);

  return std::make_pair(whileOp.getResult(0), compareEq);
}

/// Emits `if (data[b] < data[a]) swap(data[b], data[a])` and returns the IfOp
/// guarding the swap. On return, the insertion point sits inside the
/// then-region, right after the swap instructions. The first two entries of
/// `swapOperands` and `compareOperands` are overwritten with (b, a).
static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc,
                                       AffineMap xPerm, uint64_t ny,
                                       SmallVectorImpl<Value> &swapOperands,
                                       SmallVectorImpl<Value> &compareOperands,
                                       Value a, Value b) {
  // Both helpers read the element indices from the two leading operand slots.
  compareOperands[0] = swapOperands[0] = b;
  compareOperands[1] = swapOperands[1] = a;

  // Condition: data[b] < data[a].
  Value outOfOrder =
      createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
  scf::IfOp guard = builder.create<scf::IfOp>(loc, outOfOrder, /*else=*/false);

  // Then-region: exchange the two elements.
  builder.setInsertionPointToStart(&guard.getThenRegion().front());
  createSwap(builder, loc, swapOperands, xPerm, ny);
  return guard;
}

/// Emits the code that inserts a third element (at v2) into the two sorted
/// elements at v0 and v1. On return, the insertion point is right after the
/// generated code.
static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm,
                            uint64_t ny, SmallVectorImpl<Value> &swapOperands,
                            SmallVectorImpl<Value> &compareOperands, Value v0,
                            Value v1, Value v2) {
  // if (data[v2] < data[v1]) { swap(v2, v1); ...
  scf::IfOp outerIf = createCompareThenSwap(
      builder, loc, xPerm, ny, swapOperands, compareOperands, v1, v2);
  // ... if (data[v1] < data[v0]) swap(v1, v0); }
  // The second compare is nested in the then-region of the first one, since
  // that is where the builder was left.
  createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, compareOperands,
                        v0, v1);
  builder.setInsertionPointAfter(outerIf);
}

/// Emits the code that sorts the three elements at v0, v1 and v2. On return,
/// the insertion point is right after the generated code.
static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm,
                        uint64_t ny, SmallVectorImpl<Value> &swapOperands,
                        SmallVectorImpl<Value> &compareOperands, Value v0,
                        Value v1, Value v2) {
  // Step 1: order the first two elements.
  scf::IfOp firstPair = createCompareThenSwap(
      builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1);
  builder.setInsertionPointAfter(firstPair);

  // Step 2: insert the 3rd element. If it has to move below v1, it may have
  // to move below v0 as well, hence the nested compare.
  scf::IfOp insertThird = createCompareThenSwap(
      builder, loc, xPerm, ny, swapOperands, compareOperands, v1, v2);
  createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, compareOperands,
                        v0, v1);
  builder.setInsertionPointAfter(insertThird);
}

/// Emits the code that sorts the five elements at v0 through v4. On return,
/// the insertion point is right after the generated code.
static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm,
                        uint64_t ny, SmallVectorImpl<Value> &swapOperands,
                        SmallVectorImpl<Value> &compareOperands, Value v0,
                        Value v1, Value v2, Value v3, Value v4) {
  // Start with the first 3 elements sorted.
  createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1,
              v2);

  // Insert the 4th element: if it moves below v2, keep inserting it into the
  // first three elements.
  scf::IfOp fourthIf = createCompareThenSwap(
      builder, loc, xPerm, ny, swapOperands, compareOperands, v2, v3);
  createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
                  v1, v2);
  builder.setInsertionPointAfter(fourthIf);

  // Insert the 5th element: if it moves below v3, repeat the insertion of
  // the 4th element, nested in the then-region of this compare.
  scf::IfOp fifthIf = createCompareThenSwap(
      builder, loc, xPerm, ny, swapOperands, compareOperands, v3, v4);
  scf::IfOp nestedFourthIf = createCompareThenSwap(
      builder, loc, xPerm, ny, swapOperands, compareOperands, v2, v3);
  createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
                  v1, v2);
  builder.setInsertionPointAfter(nestedFourthIf);
  builder.setInsertionPointAfter(fifthIf);
}

/// Creates a code block to swap the values in indices lo, mi, and hi so that
/// data[lo], data[mi] and data[hi] are sorted in non-decreasing values. When
/// the number of values in the inclusive range [lo, hi] reaches a threshold
/// (1000), we also include the middle of [lo, mi] and the middle of [mi, hi]
/// and sort a total of five values.
///
/// Note that `hi` is inclusive here: the length of the range is computed as
/// hi + 1 - lo. `args` are the arguments of the enclosing function; the data
/// buffers are taken from `args` starting at `xStartIdx`. On return, the
/// insertion point is right after the generated code.
static void createChoosePivot(OpBuilder &builder, ModuleOp module,
                              func::FuncOp func, AffineMap xPerm, uint64_t ny,
                              Value lo, Value hi, Value mi, ValueRange args) {
  // The two leading slots are placeholders for the element indices; they are
  // overwritten by every compare-then-swap step.
  SmallVector<Value> compareOperands{mi, lo};
  constexpr uint64_t numXBuffers = 1;
  // Comparing only needs the buffer with the xs, swapping needs all buffers.
  compareOperands.append(args.begin() + xStartIdx,
                         args.begin() + xStartIdx + numXBuffers);
  SmallVector<Value> swapOperands{mi, lo};
  swapOperands.append(args.begin() + xStartIdx, args.end());
  Location loc = func.getLoc();
  Value c1 = constantIndex(builder, loc, 1);
  Value hiP1 = builder.create<arith::AddIOp>(loc, hi, c1);
  Value len = builder.create<arith::SubIOp>(loc, hiP1, lo);
  Value lenThreshold = constantIndex(builder, loc, 1000);
  Value lenCond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                                len, lenThreshold);
  scf::IfOp lenIf = builder.create<scf::IfOp>(loc, lenCond, /*else=*/true);

  // When len < 1000, choose pivot from median of 3 values.
  builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
  createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, mi,
              hi);

  // When len >= 1000, choose pivot from median of 5 values.
  builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
  // Use mi + 1 (not hi + 1) here; otherwise `a` is the middle of [lo, hi]
  // rather than the middle of [lo, mi], and does not add a sample point in
  // the lower half of the range.
  Value miP1 = builder.create<arith::AddIOp>(loc, mi, c1);
  Value a = builder.create<arith::AddIOp>(loc, lo, miP1);
  // Value a is the middle between [lo, mi].
  a = builder.create<arith::ShRUIOp>(loc, a, c1);
  Value b = builder.create<arith::AddIOp>(loc, mi, hiP1);
  // Value b is the middle between [mi, hi].
  b = builder.create<arith::ShRUIOp>(loc, b, c1);
  createSort5(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, a, mi,
              b, hi);

  builder.setInsertionPointAfter(lenIf);
}

/// Creates a function to perform quick sort partition on the values in the
/// range of index [lo, hi), assuming lo < hi. The body is generated into the
/// empty function `func`, whose leading arguments are (lo, hi, xs, ys). The
/// generated function returns the index that splits the range.
//
// The generated IR corresponds to this C like algorithm:
// int partition(lo, hi, xs) {
//   p = (lo+hi)/2  // pivot index
//   i = lo
//   j = hi-1
//   while (true) do {
//     while (xs[i] < xs[p]) i ++;
//     i_eq = (xs[i] == xs[p]);
//     while (xs[j] > xs[p]) j --;
//     j_eq = (xs[j] == xs[p]);
//
//     if (i >= j) return j + 1;
//
//     if (i < j) {
//       swap(xs[i], xs[j])
//       if (i == p) {
//         p = j;
//       } else if (j == p) {
//         p = i;
//       }
//       if (i_eq && j_eq) {
//         ++i;
//         --j;
//       }
//     }
//   }
// }
static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
                                func::FuncOp func, AffineMap xPerm, uint64_t ny,
                                uint32_t nTrailingP = 0) {
  // Quick sort partition doesn't use trailing parameters.
  (void)nTrailingP;
  assert(nTrailingP == 0);
  OpBuilder::InsertionGuard insertionGuard(builder);

  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  Value lo = args[loIdx];
  Value hi = args[hiIdx];
  // The pivot index p = (lo + hi) >> 1.
  Value sum = builder.create<arith::AddIOp>(loc, lo, hi);
  Value c1 = constantIndex(builder, loc, 1);
  Value p = builder.create<arith::ShRUIOp>(loc, sum, c1);

  // i and j are the inclusive bounds of the range still to be partitioned.
  Value i = lo;
  Value j = builder.create<arith::SubIOp>(loc, hi, c1);
  // Sort the values at lo, p and hi-1 (and possibly two more), so that the
  // value at p is a median of the sampled values.
  createChoosePivot(builder, module, func, xPerm, ny, i, j, p, args);
  Value trueVal = constantI1(builder, loc, true); // The value for while (true)
  // The loop carries (i, j, p, cont), where `cont` tells whether to continue.
  SmallVector<Value, 4> operands{i, j, p, trueVal}; // Exactly four values.
  SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(),
                             trueVal.getType()};
  scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);

  // The before-region of the WhileOp: loop for as long as `cont` is true.
  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types,
                                      {loc, loc, loc, loc});
  builder.setInsertionPointToEnd(before);
  builder.create<scf::ConditionOp>(loc, before->getArgument(3),
                                   before->getArguments());

  // The after-region of the WhileOp.
  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc, loc});
  builder.setInsertionPointToEnd(after);
  i = after->getArgument(0);
  j = after->getArgument(1);
  p = after->getArgument(2);

  // Advance i upwards while xs[i] < xs[p], and j downwards while
  // xs[j] > xs[p]. Also record whether the values at which the scans stopped
  // are equal to the pivot.
  constexpr uint64_t numXBuffers = 1;
  auto [iresult, iCompareEq] =
      createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
                     i, p, xPerm, ny, 1);
  i = iresult;
  auto [jresult, jCompareEq] =
      createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
                     j, p, xPerm, ny, -1);
  j = jresult;

  // If i < j:
  Value cond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
  // The IfOp yields the next (i, j, p, cont) for the WhileOp.
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  SmallVector<Value> swapOperands{i, j};
  swapOperands.append(args.begin() + xStartIdx, args.end());
  createSwap(builder, loc, swapOperands, xPerm, ny);
  // If the pivot is moved, update p with the new pivot.
  // That is: p = (i == p) ? j : ((j == p) ? i : p).
  Value icond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
  scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
                                              icond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOpI.getThenRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{j});
  builder.setInsertionPointToStart(&ifOpI.getElseRegion().front());
  Value jcond =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p);
  scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
                                              jcond, /*else=*/true);
  builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{i});
  builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{p});
  builder.setInsertionPointAfter(ifOpJ);
  builder.create<scf::YieldOp>(loc, ifOpJ.getResults());
  builder.setInsertionPointAfter(ifOpI);
  // If both scans stopped at values equal to the pivot, step over them by
  // moving i up and j down by one; otherwise i and j stay unchanged.
  Value compareEqIJ =
      builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq);
  scf::IfOp ifOp2 = builder.create<scf::IfOp>(
      loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true);
  builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
  Value i2 = builder.create<arith::AddIOp>(loc, i, c1);
  Value j2 = builder.create<arith::SubIOp>(loc, j, c1);
  builder.create<scf::YieldOp>(loc, ValueRange{i2, j2});
  builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{i, j});
  builder.setInsertionPointAfter(ifOp2);
  // Continue the loop with the updated i, j and p.
  builder.create<scf::YieldOp>(
      loc,
      ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0),
                 /*cont=*/constantI1(builder, loc, true)});

  // False branch for if i < j (i.e., i >= j):
  // Stop the loop. The slot that carried p now carries j + 1, which becomes
  // the return value of the function.
  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  p = builder.create<arith::AddIOp>(loc, j,
                                    constantOne(builder, loc, j.getType()));
  builder.create<scf::YieldOp>(
      loc, ValueRange{i, j, p, /*cont=*/constantI1(builder, loc, false)});

  // Return for the whileOp.
  builder.setInsertionPointAfter(ifOp);
  builder.create<scf::YieldOp>(loc, ifOp.getResults());

  // Return for the function.
  builder.setInsertionPointAfter(whileOp);
  builder.create<func::ReturnOp>(loc, whileOp.getResult(2));
}

/// Computes (n-2)/2, assuming n has index type. The division is emitted as an
/// unsigned right shift by one.
static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc,
                                      Value n) {
  Value c2 = constantIndex(builder, loc, 2);
  Value nMinus2 = builder.create<arith::SubIOp>(loc, n, c2);
  Value c1 = constantIndex(builder, loc, 1);
  Value halved = builder.create<arith::ShRUIOp>(loc, nMinus2, c1);
  return halved;
}

/// Creates a function to heapify the subtree with root `start` within the full
/// binary tree in the range of index [first, first + n).
//
// The generated IR corresponds to this C like algorithm:
// void shiftDown(first, start, n, data) {
//   if (n >= 2) {
//     child = start - first
//     if ((n-2)/2 >= child) {
//       // Left child exists.
//       child = child * 2 + 1 // Initialize the bigger child to left child.
//       childIndex = child + first
//       if (child+1 < n && data[childIndex] < data[childIndex+1])
//         // Right child exists and is bigger.
//         childIndex++; child++;
//       // Shift data[start] down to where it belongs in the subtree.
//       while (data[start] < data[childIndex]) {
//         swap(data[start], data[childIndex])
//         start = childIndex
//         if ((n - 2)/2 >= child) {
//           // Left child exists.
//           child = 2*child + 1
//           childIndex = child + 1
//           if (child + 1) < n && data[childIndex] < data[childIndex+1]
//             childIndex++; child++;
//         }
//       }
//     }
//   }
// }
//
static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
                                func::FuncOp func, AffineMap xPerm, uint64_t ny,
                                uint32_t nTrailingP) {
  // The value n is passed in as a trailing parameter.
  assert(nTrailingP == 1);
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  Value n = entryBlock->getArguments().back();
  ValueRange args = entryBlock->getArguments().drop_back();
  Value first = args[loIdx];
  Value start = args[hiIdx];

  // If (n >= 2).
  Value c2 = constantIndex(builder, loc, 2);
  Value condN =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, n, c2);
  scf::IfOp ifN = builder.create<scf::IfOp>(loc, condN, /*else=*/false);
  builder.setInsertionPointToStart(&ifN.getThenRegion().front());
  Value child = builder.create<arith::SubIOp>(loc, start, first);

  // If ((n-2)/2 >= child).
  Value t = createSubTwoDividedByTwo(builder, loc, n);
  Value condNc =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
  scf::IfOp ifNc = builder.create<scf::IfOp>(loc, condNc, /*else=*/false);

  builder.setInsertionPointToStart(&ifNc.getThenRegion().front());
  Value c1 = constantIndex(builder, loc, 1);
  SmallVector<Value> compareOperands{start, start};
  constexpr uint64_t numXBuffers = 1;
  compareOperands.append(args.begin() + xStartIdx,
                         args.begin() + xStartIdx + numXBuffers);

  // Generate code to inspect the children of 'r' and return the larger child
  // as follows:
  //   child = r * 2 + 1 // Left child.
  //   childIndex = child + first
  //   if (child+1 < n && data[childIndex] < data[childIndex+1])
  //     childIndex ++; child ++ // Right child is bigger.
  auto getLargerChild = [&](Value r) -> std::pair<Value, Value> {
    Value lChild = builder.create<arith::ShLIOp>(loc, r, c1);
    lChild = builder.create<arith::AddIOp>(loc, lChild, c1);
    Value lChildIdx = builder.create<arith::AddIOp>(loc, lChild, first);
    Value rChild = builder.create<arith::AddIOp>(loc, lChild, c1);
    Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                                rChild, n);
    SmallVector<Type, 2> ifTypes(2, r.getType());
    scf::IfOp if1 =
        builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);
    builder.setInsertionPointToStart(&if1.getThenRegion().front());
    Value rChildIdx = builder.create<arith::AddIOp>(loc, rChild, first);
    // Compare data[left] < data[right].
    compareOperands[0] = lChildIdx;
    compareOperands[1] = rChildIdx;
    Value cond2 =
        createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
    scf::IfOp if2 =
        builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
    builder.setInsertionPointToStart(&if2.getThenRegion().front());
    builder.create<scf::YieldOp>(loc, ValueRange{rChild, rChildIdx});
    builder.setInsertionPointToStart(&if2.getElseRegion().front());
    builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
    builder.setInsertionPointAfter(if2);
    builder.create<scf::YieldOp>(loc, if2.getResults());
    builder.setInsertionPointToStart(&if1.getElseRegion().front());
    builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
    builder.setInsertionPointAfter(if1);
    return std::make_pair(if1.getResult(0), if1.getResult(1));
  };

  Value childIdx;
  std::tie(child, childIdx) = getLargerChild(child);

  // While (data[start] < data[childIndex]).
  SmallVector<Type, 3> types(3, child.getType());
  scf::WhileOp whileOp = builder.create<scf::WhileOp>(
      loc, types, SmallVector<Value, 2>{start, child, childIdx});

  // The before-region of the WhileOp.
  SmallVector<Location, 3> locs(3, loc);
  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
  builder.setInsertionPointToEnd(before);
  start = before->getArgument(0);
  childIdx = before->getArgument(2);
  compareOperands[0] = start;
  compareOperands[1] = childIdx;
  Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());

  // The after-region of the WhileOp.
  Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
  start = after->getArgument(0);
  child = after->getArgument(1);
  childIdx = after->getArgument(2);
  SmallVector<Value> swapOperands{start, childIdx};
  swapOperands.append(args.begin() + xStartIdx, args.end());
  createSwap(builder, loc, swapOperands, xPerm, ny);
  start = childIdx;
  Value cond2 =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
  scf::IfOp if2 = builder.create<scf::IfOp>(
      loc, TypeRange{child.getType(), child.getType()}, cond2, /*else=*/true);
  builder.setInsertionPointToStart(&if2.getThenRegion().front());
  auto [newChild, newChildIdx] = getLargerChild(child);
  builder.create<scf::YieldOp>(loc, ValueRange{newChild, newChildIdx});
  builder.setInsertionPointToStart(&if2.getElseRegion().front());
  builder.create<scf::YieldOp>(loc, ValueRange{child, childIdx});
  builder.setInsertionPointAfter(if2);
  builder.create<scf::YieldOp>(
      loc, ValueRange{start, if2.getResult(0), if2.getResult(1)});

  builder.setInsertionPointAfter(ifN);
  builder.create<func::ReturnOp>(loc);
}

/// Creates a function to perform heap sort on the values in the range of index
/// [lo, hi) with the assumption hi - lo >= 2.
///
/// The entry block arguments are (lo, hi, buffers...); the generated function
/// returns no value. The shift-down helper is obtained through
/// getMangledSortHelperFunc with `n` as its single trailing parameter.
//
// The generated IR corresponds to this C like algorithm:
// void heapSort(lo, hi, data) {
//   n = hi - lo
//   for i = (n-2)/2 downto 0
//     shiftDown(lo, lo+i, n)
//
//   for l = n downto 2
//      swap(lo, lo+l-1)
//      shiftdown(lo, lo, l-1)
// }
static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
                               func::FuncOp func, AffineMap xPerm, uint64_t ny,
                               uint32_t nTrailingP) {
  // Heap sort function doesn't have trailing parameters.
  (void)nTrailingP;
  assert(nTrailingP == 0);
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  Value lo = args[loIdx];
  Value hi = args[hiIdx];
  Value n = builder.create<arith::SubIOp>(loc, hi, lo);

  // For i = (n-2)/2 downto 0.
  // The scf.for counts upward over [0, s + 1) with step 1; the descending
  // index is recovered inside the body as i = s - inductionVar.
  Value c0 = constantIndex(builder, loc, 0);
  Value c1 = constantIndex(builder, loc, 1);
  Value s = createSubTwoDividedByTwo(builder, loc, n);
  Value up = builder.create<arith::AddIOp>(loc, s, c1);
  scf::ForOp forI = builder.create<scf::ForOp>(loc, c0, up, c1);
  builder.setInsertionPointToStart(forI.getBody());
  Value i = builder.create<arith::SubIOp>(loc, s, forI.getInductionVar());
  Value lopi = builder.create<arith::AddIOp>(loc, lo, i);
  // Call operands for shiftDown: (first=lo, start=lo+i, buffers..., n).
  SmallVector<Value> shiftDownOperands = {lo, lopi};
  shiftDownOperands.append(args.begin() + xStartIdx, args.end());
  shiftDownOperands.push_back(n);
  FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc(
      builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny,
      shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1);
  builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
                               shiftDownOperands);

  builder.setInsertionPointAfter(forI);
  // For l = n downto 2.
  // Again expressed as an ascending loop over [0, n - 1) with
  // l = n - inductionVar.
  up = builder.create<arith::SubIOp>(loc, n, c1);
  scf::ForOp forL = builder.create<scf::ForOp>(loc, c0, up, c1);
  builder.setInsertionPointToStart(forL.getBody());
  Value l = builder.create<arith::SubIOp>(loc, n, forL.getInductionVar());
  // loplm1 = lo + l - 1, the last index of the current range.
  Value loplm1 = builder.create<arith::AddIOp>(loc, lo, l);
  loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1);
  SmallVector<Value> swapOperands{lo, loplm1};
  swapOperands.append(args.begin() + xStartIdx, args.end());
  createSwap(builder, loc, swapOperands, xPerm, ny);
  // Reuse the operand list of the first call: start becomes lo and the
  // trailing size becomes l - 1.
  shiftDownOperands[1] = lo;
  shiftDownOperands[shiftDownOperands.size() - 1] =
      builder.create<arith::SubIOp>(loc, l, c1);
  builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
                               shiftDownOperands);

  builder.setInsertionPointAfter(forL);
  builder.create<func::ReturnOp>(loc);
}

/// A helper for generating code to perform quick sort. It partitions [lo, hi),
/// recursively calls quick sort to process the smaller partition and returns
/// the bigger partition to be processed by the enclosed while-loop.
///
/// `args` holds (lo, hi, buffers..., trailing parameters...), where the last
/// `nTrailingP` entries are dropped for the partition call but kept for the
/// recursive call to `func`. The code is emitted at the builder's current
/// insertion point; on return the insertion point is right after the outermost
/// scf.if created here. The returned pair is the (lo, hi) range left for the
/// caller to process; it is the empty range (lo, lo) when hi - lo <= 2.
static std::pair<Value, Value>
createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
                ValueRange args, AffineMap xPerm, uint64_t ny,
                uint32_t nTrailingP) {
  MLIRContext *context = module.getContext();
  Location loc = func.getLoc();
  Value lo = args[loIdx];
  Value hi = args[hiIdx];
  SmallVector<Type, 2> types(2, lo.getType()); // Only two types.

  // Call the partition helper, which returns a single index p.
  FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
      builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm,
      ny, args.drop_back(nTrailingP), createPartitionFunc);
  Value p = builder
                .create<func::CallOp>(loc, partitionFunc,
                                      TypeRange{IndexType::get(context)},
                                      args.drop_back(nTrailingP))
                .getResult(0);

  // Lengths of the two partitions [lo, p) and [p, hi).
  Value lenLow = builder.create<arith::SubIOp>(loc, p, lo);
  Value lenHigh = builder.create<arith::SubIOp>(loc, hi, p);
  // Partition already sorts array with len <= 2
  Value c2 = constantIndex(builder, loc, 2);
  Value len = builder.create<arith::SubIOp>(loc, hi, lo);
  Value lenGtTwo =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, len, c2);
  scf::IfOp ifLenGtTwo =
      builder.create<scf::IfOp>(loc, types, lenGtTwo, /*else=*/true);
  builder.setInsertionPointToStart(&ifLenGtTwo.getElseRegion().front());
  // Returns an empty range to mark the entire region is fully sorted.
  builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});

  // Else len > 2, need recursion.
  builder.setInsertionPointToStart(&ifLenGtTwo.getThenRegion().front());
  // cond is true when the lower partition is not longer than the upper one.
  Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
                                             lenLow, lenHigh);

  Value c0 = constantIndex(builder, loc, 0);
  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);

  // Emits `if (len != 0) func(low, high, buffers..., trailing...)` at the
  // current insertion point and leaves the insertion point after that scf.if.
  auto mayRecursion = [&](Value low, Value high, Value len) {
    Value cond =
        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, len, c0);
    scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
    SmallVector<Value> operands{low, high};
    operands.append(args.begin() + xStartIdx, args.end());
    builder.create<func::CallOp>(loc, func, operands);
    builder.setInsertionPointAfter(ifOp);
  };

  // Recursively call quickSort to process the smaller partition and return
  // the bigger partition to be processed by the enclosed while-loop.
  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
  mayRecursion(lo, p, lenLow);
  builder.create<scf::YieldOp>(loc, ValueRange{p, hi});

  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
  mayRecursion(p, hi, lenHigh);
  builder.create<scf::YieldOp>(loc, ValueRange{lo, p});

  // Forward the chosen range out of the `len > 2` branch.
  builder.setInsertionPointAfter(ifOp);
  builder.create<scf::YieldOp>(loc, ifOp.getResults());

  builder.setInsertionPointAfter(ifLenGtTwo);
  return std::make_pair(ifLenGtTwo.getResult(0), ifLenGtTwo.getResult(1));
}

/// Creates a function to perform insertion sort on the values in the range of
/// index [lo, hi).
///
/// The entry block arguments are (lo, hi, buffers...); the generated function
/// returns no value. The insertion point for each element is obtained from a
/// binary search helper created through getMangledSortHelperFunc.
//
// The generated IR corresponds to this C like algorithm:
// void insertionSort(lo, hi, data) {
//   for (i = lo+1; i < hi; i++) {
//      d = data[i];
//      p = binarySearch(lo, i, data)
//      for (j = 0; j < i - p; j++)
//        data[i-j] = data[i-j-1]
//      data[p] = d
//   }
// }
//
// NOTE(review): the emitted call passes `i` as the second index of the binary
// search; how that bound is interpreted is defined by the binary search
// helper, which is not part of this function.
static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
                                 func::FuncOp func, AffineMap xPerm,
                                 uint64_t ny, uint32_t nTrailingP) {
  // Stable sort function doesn't use trailing parameters.
  (void)nTrailingP;
  assert(nTrailingP == 0);
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  MLIRContext *context = module.getContext();
  Location loc = func.getLoc();
  ValueRange args = entryBlock->getArguments();
  Value c1 = constantIndex(builder, loc, 1);
  Value lo = args[loIdx];
  Value hi = args[hiIdx];
  Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1);

  // Start the outer for-stmt with induction variable i.
  scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1);
  builder.setInsertionPointToStart(forOpI.getBody());
  Value i = forOpI.getInductionVar();

  // Binary search to find the insertion point p.
  // `operands` is (lo, i, buffers...); its first two slots are reused below
  // as the index pair handed to forEachIJPairInAllBuffers.
  SmallVector<Value> operands{lo, i};
  operands.append(args.begin() + xStartIdx, args.end());
  FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
      builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix,
      xPerm, ny, operands, createBinarySearchFunc);
  Value p = builder
                .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
                                      operands)
                .getResult(0);

  // Move the value at data[i] to a temporary location.
  // One loaded value is collected in `d` per callback invocation.
  operands[0] = operands[1] = i;
  SmallVector<Value> d;
  forEachIJPairInAllBuffers(
      builder, loc, operands, xPerm, ny,
      [&](uint64_t unused, Value i, Value unused2, Value buffer) {
        d.push_back(builder.create<memref::LoadOp>(loc, buffer, i));
      });

  // Start the inner for-stmt with induction variable j, for moving data[p..i)
  // to data[p+1..i+1).
  // The loop runs j over [0, i - p) and copies data[i-j-1] to data[i-j].
  Value imp = builder.create<arith::SubIOp>(loc, i, p);
  Value c0 = constantIndex(builder, loc, 0);
  scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1);
  builder.setInsertionPointToStart(forOpJ.getBody());
  Value j = forOpJ.getInductionVar();
  Value imj = builder.create<arith::SubIOp>(loc, i, j);
  operands[1] = imj;
  operands[0] = builder.create<arith::SubIOp>(loc, imj, c1);
  forEachIJPairInAllBuffers(
      builder, loc, operands, xPerm, ny,
      [&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
        Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1);
        builder.create<memref::StoreOp>(loc, t, buffer, imj);
      });

  // Store the value at data[i] to data[p].
  // d[k] is indexed with the callback's first argument; this relies on the
  // callback being invoked with the same k sequence as when `d` was filled.
  builder.setInsertionPointAfter(forOpJ);
  operands[0] = operands[1] = p;
  forEachIJPairInAllBuffers(
      builder, loc, operands, xPerm, ny,
      [&](uint64_t k, Value p, Value usused, Value buffer) {
        builder.create<memref::StoreOp>(loc, d[k], buffer, p);
      });

  builder.setInsertionPointAfter(forOpI);
  builder.create<func::ReturnOp>(loc);
}

/// Creates a function to perform quick sort or a hybrid quick sort on the
/// values in the range of index [lo, hi).
///
/// The entry block arguments are (lo, hi, buffers...) followed by
/// `nTrailingP` trailing parameters; the hybrid variant (nTrailingP == 1)
/// takes the depth limit as its last argument. The generated function returns
/// no value.
//
//
// When nTrailingP == 0, the generated IR corresponds to this C like algorithm:
// void quickSort(lo, hi, data) {
//   while (lo + 1 < hi) {
//        p = partition(lo, hi, data);
//        if (hi - lo <= 2) {
//          // Partition already sorted the range; continue with an empty one.
//          hi = lo;
//        } else if (p - lo <= hi - p) {
//          if (p - lo != 0) quickSort(lo, p, data);
//          lo = p;
//        } else {
//          if (hi - p != 0) quickSort(p, hi, data);
//          hi = p;
//        }
//   }
// }
//
// When nTrailingP == 1, the generated IR corresponds to this C like algorithm:
// void hybridQuickSort(lo, hi, data, depthLimit) {
//   while (lo + 1 < hi) {
//     len = hi - lo;
//     if (len <= 30) {
//       insertionSort(lo, hi, data);
//       hi = lo;
//     } else {
//       d = depthLimit - 1;
//       if (d <= 0) {
//         heapSort(lo, hi, data);
//         hi = lo;
//       } else {
//         // Same partition step as quickSort above, recursing into
//         // hybridQuickSort(..., d).
//       }
//     }
//   }
// }
//
// Note that only (lo, hi) are carried by the while-loop; the depth limit is
// not loop-carried, so every iteration computes the function argument minus
// one.
//
static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
                                func::FuncOp func, AffineMap xPerm, uint64_t ny,
                                uint32_t nTrailingP) {
  assert(nTrailingP == 1 || nTrailingP == 0);
  bool isHybrid = (nTrailingP == 1);
  OpBuilder::InsertionGuard insertionGuard(builder);
  Block *entryBlock = func.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);

  Location loc = func.getLoc();
  // Copy the arguments into a mutable vector: lo, hi and the depth limit are
  // replaced below before the vector is handed to the helpers.
  SmallVector<Value> args;
  args.append(entryBlock->getArguments().begin(),
              entryBlock->getArguments().end());
  Value lo = args[loIdx];
  Value hi = args[hiIdx];
  SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
  scf::WhileOp whileOp =
      builder.create<scf::WhileOp>(loc, types, SmallVector<Value, 2>{lo, hi});

  // The before-region of the WhileOp.
  // Continue while lo + 1 < hi, forwarding (lo, hi) unchanged.
  Block *before =
      builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
  builder.setInsertionPointToEnd(before);
  lo = before->getArgument(0);
  hi = before->getArgument(1);
  Value loP1 =
      builder.create<arith::AddIOp>(loc, lo, constantIndex(builder, loc, 1));
  Value needSort =
      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loP1, hi);
  builder.create<scf::ConditionOp>(loc, needSort, before->getArguments());

  // The after-region of the WhileOp.
  Block *after =
      builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
  builder.setInsertionPointToEnd(after);
  lo = after->getArgument(0);
  hi = after->getArgument(1);
  // From here on the helpers must see the loop-carried lo/hi rather than the
  // function arguments.
  args[0] = lo;
  args[1] = hi;

  if (isHybrid) {
    Value len = builder.create<arith::SubIOp>(loc, hi, lo);
    Value lenLimit = constantIndex(builder, loc, 30);
    Value lenCond = builder.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::ule, len, lenLimit);
    scf::IfOp lenIf =
        builder.create<scf::IfOp>(loc, types, lenCond, /*else=*/true);

    // When len <= limit.
    // Sort the whole range with insertion sort and yield the empty range
    // (lo, lo), which makes the while-loop condition false.
    builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
    FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
        builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny,
        ValueRange(args).drop_back(nTrailingP), createSortStableFunc);
    builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
                                 ValueRange(args).drop_back(nTrailingP));
    builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});

    // When len > limit.
    builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
    Value depthLimit = args.back();
    depthLimit = builder.create<arith::SubIOp>(loc, depthLimit,
                                               constantI64(builder, loc, 1));
    // The comparison is unsigned, so `ule 0` holds exactly when the
    // decremented depth limit equals zero.
    Value depthCond =
        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
                                      depthLimit, constantI64(builder, loc, 0));
    scf::IfOp depthIf =
        builder.create<scf::IfOp>(loc, types, depthCond, /*else=*/true);

    // When depth exceeds limit.
    // Fall back to heap sort and yield the empty range (lo, lo).
    builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
    FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc(
        builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny,
        ValueRange(args).drop_back(nTrailingP), createHeapSortFunc);
    builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(),
                                 ValueRange(args).drop_back(nTrailingP));
    builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});

    // When depth doesn't exceed limit.
    // Pass the decremented depth limit on to the recursive calls emitted by
    // createQuickSort.
    builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
    args.back() = depthLimit;
    std::tie(lo, hi) =
        createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
    builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});

    // Forward the results of depthIf out of the `len > limit` branch.
    builder.setInsertionPointAfter(depthIf);
    lo = depthIf.getResult(0);
    hi = depthIf.getResult(1);
    builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});

    builder.setInsertionPointAfter(lenIf);
    lo = lenIf.getResult(0);
    hi = lenIf.getResult(1);
  } else {
    std::tie(lo, hi) =
        createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
  }

  // New [lo, hi) for the next while-loop iteration.
  builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});

  // After the while-loop.
  builder.setInsertionPointAfter(whileOp);
  builder.create<func::ReturnOp>(loc);
}

/// Implements the rewriting for operator sort and sort_coo.
///
/// Replaces `op` with a call to a sort helper function selected by
/// `op.getAlgorithm()`; the helper symbol is obtained through
/// getMangledSortHelperFunc. The call operands are a constant-zero lower
/// bound, `op.getN()` as the upper bound, every buffer in `xys` (cast to a
/// memref with a dynamic size when its leading dimension is static) and, for
/// the hybrid quick sort only, a trailing depth limit.
///
/// Returns failure() without modifying the IR when `op` is not nested inside
/// a func::FuncOp, and success() otherwise.
template <typename OpTy>
LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm,
                                    uint64_t ny, PatternRewriter &rewriter) {
  // Check the only failure condition before creating any operation: a
  // pattern that reports failure must leave the IR untouched.
  auto insertPoint = op->template getParentOfType<func::FuncOp>();
  if (!insertPoint)
    return failure();

  Location loc = op.getLoc();
  SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()};

  // Convert `values` to have dynamic shape and append them to `operands`.
  for (Value v : xys) {
    auto mtp = getMemRefType(v);
    if (!mtp.isDynamicDim(0)) {
      auto newMtp =
          MemRefType::get({ShapedType::kDynamic}, mtp.getElementType());
      v = rewriter.create<memref::CastOp>(loc, newMtp, v);
    }
    operands.push_back(v);
  }

  SmallString<32> funcName;
  FuncGeneratorType funcGenerator;
  uint32_t nTrailingP = 0;
  switch (op.getAlgorithm()) {
  case SparseTensorSortKind::HybridQuickSort: {
    funcName = kHybridQuickSortFuncNamePrefix;
    funcGenerator = createQuickSortFunc;
    nTrailingP = 1;
    // As a heuristic, set depthLimit = 64 - countLeadingZeros(hi - lo), i.e.
    // the number of significant bits of the length as an i64.
    // NOTE(review): introsort implementations commonly use 2 * log2(n); the
    // value emitted here has no factor of two -- confirm this is intended.
    Value lo = operands[loIdx];
    Value hi = operands[hiIdx];
    Value len = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI64Type(),
        rewriter.create<arith::SubIOp>(loc, hi, lo));
    Value depthLimit = rewriter.create<arith::SubIOp>(
        loc, constantI64(rewriter, loc, 64),
        rewriter.create<math::CountLeadingZerosOp>(loc, len));
    operands.push_back(depthLimit);
    break;
  }
  case SparseTensorSortKind::QuickSort:
    funcName = kQuickSortFuncNamePrefix;
    funcGenerator = createQuickSortFunc;
    break;
  case SparseTensorSortKind::InsertionSortStable:
    funcName = kSortStableFuncNamePrefix;
    funcGenerator = createSortStableFunc;
    break;
  case SparseTensorSortKind::HeapSort:
    funcName = kHeapSortFuncNamePrefix;
    funcGenerator = createHeapSortFunc;
    break;
  }

  FlatSymbolRefAttr func =
      getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName,
                               xPerm, ny, operands, funcGenerator, nTrailingP);
  rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
  return success();
}

//===---------------------------------------------------------------------===//
// The actual sparse buffer rewriting rules.
//===---------------------------------------------------------------------===//

namespace {
/// Sparse rewriting rule for the push_back operator.
///
/// The op is replaced by two values: the (possibly reallocated) buffer and
/// the new size. When `enableBufferInitialization` is set, the unused tail of
/// a reallocated buffer is filled with zero.
struct PushBackRewriter : OpRewritePattern<PushBackOp> {
public:
  using OpRewritePattern<PushBackOp>::OpRewritePattern;
  PushBackRewriter(MLIRContext *context, bool enableInit)
      : OpRewritePattern(context), enableBufferInitialization(enableInit) {}
  LogicalResult matchAndRewrite(PushBackOp op,
                                PatternRewriter &rewriter) const override {
    // Rewrite push_back(buffer, value, n) to:
    // new_size = size(buffer) + n
    // if (new_size > capacity(buffer))
    //    while new_size > new_capacity
    //      new_capacity = new_capacity*2
    //    new_buffer = realloc(buffer, new_capacity)
    // buffer = new_buffer
    // subBuffer = subviewof(buffer)
    // linalg.fill subBuffer value
    //
    // size(buffer) += n
    //
    // The capacity check is skipped when the attribute inbounds is present.
    Location loc = op->getLoc();
    Value c0 = constantIndex(rewriter, loc, 0);
    Value buffer = op.getInBuffer();
    // The capacity is the size of dimension 0 of the incoming buffer.
    Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
    Value size = op.getCurSize();
    Value value = op.getValue();

    // A missing `n` operand means a single value is pushed.
    Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
    Value newSize = rewriter.create<arith::AddIOp>(loc, size, n);
    // nIsOne is only true when n is defined by a constant index op equal to
    // one; it selects the cheaper code paths below.
    auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
    bool nIsOne = (nValue && nValue.value() == 1);

    if (!op.getInbounds()) {
      Value cond = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ugt, newSize, capacity);

      Value c2 = constantIndex(rewriter, loc, 2);
      auto bufferType =
          MemRefType::get({ShapedType::kDynamic}, value.getType());
      scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
                                                  /*else=*/true);
      // True branch.
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      if (nIsOne) {
        // Pushing one value: a single doubling of the capacity is emitted.
        capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
      } else {
        // Use a do-while loop to calculate the new capacity as follows:
        //   do { new_capacity *= 2 } while (new_size > new_capacity)
        scf::WhileOp whileOp =
            rewriter.create<scf::WhileOp>(loc, capacity.getType(), capacity);

        // The before-region of the WhileOp.
        // It doubles the capacity and decides whether to iterate again, which
        // gives the do-while shape.
        Block *before = rewriter.createBlock(&whileOp.getBefore(), {},
                                             {capacity.getType()}, {loc});
        rewriter.setInsertionPointToEnd(before);

        capacity =
            rewriter.create<arith::MulIOp>(loc, before->getArgument(0), c2);
        cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
                                              newSize, capacity);
        rewriter.create<scf::ConditionOp>(loc, cond, ValueRange{capacity});
        // The after-region of the WhileOp.
        // It only forwards the doubled capacity to the next iteration.
        Block *after = rewriter.createBlock(&whileOp.getAfter(), {},
                                            {capacity.getType()}, {loc});
        rewriter.setInsertionPointToEnd(after);
        rewriter.create<scf::YieldOp>(loc, after->getArguments());

        rewriter.setInsertionPointAfter(whileOp);
        capacity = whileOp.getResult(0);
      }

      Value newBuffer =
          rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
      if (enableBufferInitialization) {
        // Zero-fill the tail [newSize, capacity) of the reallocated buffer.
        Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
        Value fillValue = constantZero(rewriter, loc, value.getType());
        Value subBuffer = rewriter.create<memref::SubViewOp>(
            loc, newBuffer, /*offset=*/ValueRange{newSize},
            /*size=*/ValueRange{fillSize},
            /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
        rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer);
      }
      rewriter.create<scf::YieldOp>(loc, newBuffer);

      // False branch.
      rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
      rewriter.create<scf::YieldOp>(loc, buffer);

      // Prepare for adding the value to the end of the buffer.
      rewriter.setInsertionPointAfter(ifOp);
      buffer = ifOp.getResult(0);
    }

    // Add the value to the end of the buffer.
    if (nIsOne) {
      // A single value is stored directly at index `size`.
      rewriter.create<memref::StoreOp>(loc, value, buffer, size);
    } else {
      // Otherwise fill the range [size, size + n) with the value.
      Value subBuffer = rewriter.create<memref::SubViewOp>(
          loc, buffer, /*offset=*/ValueRange{size}, /*size=*/ValueRange{n},
          /*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
      rewriter.create<linalg::FillOp>(loc, value, subBuffer);
    }

    // Update the buffer size.
    rewriter.replaceOp(op, {buffer, newSize});
    return success();
  }

private:
  // Whether the unused tail of a reallocated buffer is zero-filled.
  bool enableBufferInitialization;
};

/// Sparse rewriting rule for SortOp: gathers the buffers and sort parameters
/// from the op and delegates the rewrite to matchAndRewriteSortOp.
struct SortRewriter : public OpRewritePattern<SortOp> {
public:
  using OpRewritePattern<SortOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SortOp op,
                                PatternRewriter &rewriter) const override {
    // The xy buffer comes first, followed by all ys buffers.
    auto ys = op.getYs();
    SmallVector<Value> buffers{op.getXy()};
    buffers.append(ys.begin(), ys.end());

    // `ny` defaults to zero when the attribute is absent.
    auto nyAttr = op.getNyAttr();
    uint64_t ny = nyAttr ? nyAttr.getInt() : 0;

    return matchAndRewriteSortOp(op, buffers, op.getPermMap(), ny, rewriter);
  }
};

} // namespace

//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//

void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
                                         bool enableBufferInitialization) {
  patterns.add<PushBackRewriter>(patterns.getContext(),
                                 enableBufferInitialization);
  patterns.add<SortRewriter>(patterns.getContext());
}