//===- SparseVectorization.cpp - Vectorization of sparsified loops --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts loops generated by the sparsifier into a form that
// can exploit SIMD instructions of the target architecture. Note that this pass
// ensures the sparsifier can generate efficient SIMD (including ArmSVE
// support) with proper separation of concerns as far as sparsification and
// vectorization is concerned. However, this pass is not the final abstraction
// level we want, and not the general vectorizer we want either. It forms a good
// stepping stone for incremental future improvements though.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Target SIMD properties that drive the vectorizer. The fields are listed in
/// aggregate-initialization order.
struct VL {
  /// Number of packed data elements per vector (vector<16xf32> has 16).
  unsigned vectorLength;
  /// When set, the generated vectors are scalable (e.g. for ArmSVE).
  bool enableVLAVectorization;
  /// When set, 32-bit indices are kept in gather/scatter for efficiency.
  bool enableSIMDIndex32;
};

/// Returns true if 'val' is produced by an operation that lives outside of
/// 'block'. Block arguments have no defining operation and are therefore
/// never reported as invariant by this test (see isInvariantArg for those).
static bool isInvariantValue(Value val, Block *block) {
  Operation *producer = val.getDefiningOp();
  return producer != nullptr && producer->getBlock() != block;
}

/// Returns true if the block argument 'arg' belongs to a block other than
/// 'block' (e.g. an outer loop index), i.e. it is invariant inside 'block'.
static bool isInvariantArg(BlockArgument arg, Block *block) {
  Block *owner = arg.getOwner();
  return block != owner;
}

/// Returns the 1-D vector type with 'vl.vectorLength' elements of type 'etp'.
/// The single dimension is scalable when VLA vectorization is enabled.
static VectorType vectorType(VL vl, Type etp) {
  const int64_t numElements = vl.vectorLength;
  const bool scalable = vl.enableVLAVectorization;
  return VectorType::get(numElements, etp, scalable);
}

/// Returns the vector type whose element type matches the element type of
/// the memref value 'mem'.
static VectorType vectorType(VL vl, Value mem) {
  Type elementType = getMemRefType(mem).getElementType();
  return vectorType(vl, elementType);
}

/// Constructs the iteration mask for a vector loop that starts at 'iv' and
/// advances by 'step' towards the upper bound 'hi' (with lower bound 'lo').
static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
                           Value iv, Value lo, Value hi, Value step) {
  VectorType maskType = vectorType(vl, rewriter.getI1Type());
  // Fast path: when all bounds are constant and the step evenly divides the
  // trip count (for example, "for i = 0, 128, 16"), every vector iteration is
  // full. A constant all-true mask lets all subsequent masked memory
  // operations fold into unconditional memory operations.
  IntegerAttr loAttr, hiAttr, stepAttr;
  bool constantBounds = matchPattern(lo, m_Constant(&loAttr)) &&
                        matchPattern(hi, m_Constant(&hiAttr)) &&
                        matchPattern(step, m_Constant(&stepAttr));
  if (constantBounds) {
    int64_t tripSpan = hiAttr.getInt() - loAttr.getInt();
    if (tripSpan % stepAttr.getInt() == 0) {
      Value allTrue = constantI1(rewriter, loc, true);
      return rewriter.create<vector::BroadcastOp>(loc, maskType, allTrue);
    }
  }
  // General case: the number of active lanes is min(step, hi - iv), so that
  // the final vector iteration never overruns the upper bound. We rely on
  // subsequent loop optimizations (for example, splitting the loop into an
  // unconditional vector loop and a scalar cleanup loop) to avoid executing
  // the mask in all iterations.
  AffineExpr remaining =
      rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1);
  AffineExpr fullVector = rewriter.getAffineSymbolExpr(0);
  AffineMap minMap =
      AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, {fullVector, remaining},
                     rewriter.getContext());
  Value activeLanes = rewriter.createOrFold<affine::AffineMinOp>(
      loc, minMap, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, maskType, activeLanes);
}

/// Broadcasts the scalar 'val' into a vector of the target length. The
/// broadcast is emitted at the current insertion point; subsequent loop
/// optimizations are relied upon to hoist it out of the vector loop.
static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
                                     Value val) {
  Location loc = val.getLoc();
  VectorType broadcastType = vectorType(vl, val.getType());
  return rewriter.create<vector::BroadcastOp>(loc, broadcastType, val);
}

/// Generates a masked vector load of either a[lo:hi] (contiguous) or
/// a[ind[lo:hi]] (gather), where 'lo' denotes the current index and
/// 'hi = lo + vl - 1'. Disabled lanes receive zero. The sparsifier only
/// generates indirect loads in the last subscript, so only idxs.back() is
/// inspected for a vector of indices.
static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
                           Value mem, ArrayRef<Value> idxs, Value vmask) {
  VectorType vtp = vectorType(vl, mem);
  Value pass = constantZero(rewriter, loc, vtp);
  Value lastIdx = idxs.back();
  // Contiguous access: a plain masked load.
  if (!llvm::isa<VectorType>(lastIdx.getType()))
    return rewriter.create<vector::MaskedLoadOp>(loc, vtp, mem, idxs, vmask,
                                                 pass);
  // Indirect access: gather relative to base position 0 in the last dim.
  SmallVector<Value> baseIdxs(idxs.begin(), idxs.end());
  baseIdxs.back() = constantIndex(rewriter, loc, 0);
  return rewriter.create<vector::GatherOp>(loc, vtp, mem, baseIdxs, lastIdx,
                                           vmask, pass);
}

/// Generates a masked vector store of 'rhs' into either a[lo:hi]
/// (contiguous) or a[ind[lo:hi]] (scatter), where 'lo' denotes the current
/// index and 'hi = lo + vl - 1'. The sparsifier only generates indirect
/// stores in the last subscript, so only idxs.back() is inspected.
static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem,
                           ArrayRef<Value> idxs, Value vmask, Value rhs) {
  Value lastIdx = idxs.back();
  // Contiguous access: a plain masked store.
  if (!llvm::isa<VectorType>(lastIdx.getType())) {
    rewriter.create<vector::MaskedStoreOp>(loc, mem, idxs, vmask, rhs);
    return;
  }
  // Indirect access: scatter relative to base position 0 in the last dim.
  SmallVector<Value> baseIdxs(idxs.begin(), idxs.end());
  baseIdxs.back() = constantIndex(rewriter, loc, 0);
  rewriter.create<vector::ScatterOp>(loc, mem, baseIdxs, lastIdx, vmask, rhs);
}

/// Returns true if 'red' is a reduction operation on the iteration value
/// 'iter' that maps onto a vector combining kind, which is then stored in
/// 'kind'. Commutative operations accept 'iter' on either side, while
/// subtractions only accept 'iter' as the left operand (r - x == r + (-x)).
/// 'kind' is left untouched when 'red' is not one of the recognized ops.
static bool isVectorizableReduction(Value red, Value iter,
                                    vector::CombiningKind &kind) {
  Operation *def = red.getDefiningOp();
  if (!def)
    return false;
  auto onEitherSide = [&](vector::CombiningKind k) {
    kind = k;
    return def->getOperand(0) == iter || def->getOperand(1) == iter;
  };
  auto onLeftSide = [&](vector::CombiningKind k) {
    kind = k;
    return def->getOperand(0) == iter;
  };
  if (isa<arith::AddFOp, arith::AddIOp>(def))
    return onEitherSide(vector::CombiningKind::ADD);
  if (isa<arith::SubFOp, arith::SubIOp>(def))
    return onLeftSide(vector::CombiningKind::ADD);
  if (isa<arith::MulFOp, arith::MulIOp>(def))
    return onEitherSide(vector::CombiningKind::MUL);
  if (isa<arith::AndIOp>(def))
    return onEitherSide(vector::CombiningKind::AND);
  if (isa<arith::OrIOp>(def))
    return onEitherSide(vector::CombiningKind::OR);
  if (isa<arith::XOrIOp>(def))
    return onEitherSide(vector::CombiningKind::XOR);
  return false;
}

/// Generates the initial vector value of a vector reduction, following the
/// scheme of Chapter 5 of "The Software Vectorization Handbook": the initial
/// scalar value 'r' (the reduction value from outside the loop) is embedded
/// into the vector such that a plain horizontal reduction of the final vector
/// completes the operation. 'red' and 'iter' must form a vectorizable
/// reduction (see isVectorizableReduction).
static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
                                Value red, Value iter, Value r,
                                VectorType vtp) {
  vector::CombiningKind kind;
  if (!isVectorizableReduction(red, iter, kind))
    llvm_unreachable("unknown reduction");
  switch (kind) {
  case vector::CombiningKind::AND:
  case vector::CombiningKind::OR:
    // Idempotent kinds: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::XOR:
  case vector::CombiningKind::MUL: {
    // Identity-filled vector with r in one lane:
    //   ADD/XOR: | 0 | .. | 0 | r |      MUL: | 1 | .. | 1 | r |
    Value identity = kind == vector::CombiningKind::MUL
                         ? constantOne(rewriter, loc, vtp)
                         : constantZero(rewriter, loc, vtp);
    Value lane = constantIndex(rewriter, loc, 0);
    return rewriter.create<vector::InsertElementOp>(loc, r, identity, lane);
  }
  default:
    break;
  }
  llvm_unreachable("unknown reduction kind");
}

/// This method is called twice to analyze and rewrite the given subscripts.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter
/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
/// stay in sync. Note that the analysis part is simple because the sparsifier
/// only generates relatively simple subscript expressions.
///
/// Accepted subscripts are loop-invariant values and invariant block
/// arguments in all but the innermost dimension. In the innermost dimension,
/// the accepted forms are:
///   (1) the loop induction variable itself,
///   (2) an invariant offset plus the induction variable, or
///   (3) a (possibly cast) coordinate load whose own subscripts are again
///       accepted by this method, yielding a gather/scatter index vector.
/// An empty subscript list (0-d memref) is rejected. So is a subscript list
/// whose innermost subscript is invariant. Neither addresses one element
/// per vector lane.
///
/// See https://llvm.org/docs/GetElementPtr.html for some background on
/// the complications described below.
///
/// We need to generate a position/coordinate load from the sparse storage
/// scheme. Narrower data types need to be zero extended before casting
/// the value into the `index` type used for looping and indexing.
///
/// For the scalar case, subscripts simply zero extend narrower indices
/// into 64-bit values before casting to an index type without a performance
/// penalty. Indices that already are 64-bit, in theory, cannot express the
/// full range since the LLVM backend defines addressing in terms of an
/// unsigned pointer/signed index pair.
static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                VL vl, ValueRange subs, bool codegen,
                                Value vmask, SmallVectorImpl<Value> &idxs) {
  // A 0-d access is fully invariant. It has no innermost subscript that
  // varies per lane, and the vector load/store generators need a last
  // subscript.
  if (subs.empty())
    return false;
  unsigned d = 0;
  unsigned dim = subs.size();
  Block *block = &forOp.getRegion().front();
  Value iv = forOp.getInductionVar();
  for (auto sub : subs) {
    bool innermost = ++d == dim;
    // Invariant subscripts in outer dimensions simply pass through.
    // Note that we rely on LICM to hoist loads where all subscripts
    // are invariant in the innermost loop.
    // Example:
    //   a[inv][i] for inv
    if (isInvariantValue(sub, block)) {
      if (innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Invariant block arguments (including outer loop indices) in outer
    // dimensions simply pass through. In the innermost dimension, only the
    // induction variable itself addresses consecutive elements. Other
    // arguments of the loop body (iteration arguments) are rejected.
    // Example:
    //   a[i][j] for both i and j
    if (auto arg = llvm::dyn_cast<BlockArgument>(sub)) {
      bool accepted = innermost ? arg == iv : isInvariantArg(arg, block);
      if (!accepted)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Look under the hood of casting.
    auto cast = sub;
    while (true) {
      if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
        cast = icast->getOperand(0);
      else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
        cast = ecast->getOperand(0);
      else
        break;
    }
    // Since the index vector is used in a subsequent gather/scatter
    // operations, which effectively defines an unsigned pointer + signed
    // index, we must zero extend the vector to an index width. For 8-bit
    // and 16-bit values, a 32-bit index width suffices. For 32-bit values,
    // zero extending the elements into 64-bit loses some performance since
    // the 32-bit indexed gather/scatter is more efficient than the 64-bit
    // index variant (if the negative 32-bit index space is unused, the
    // enableSIMDIndex32 flag can preserve this performance). For 64-bit
    // values, there is no good way to state that the indices are unsigned,
    // which creates the potential of incorrect address calculations in the
    // unlikely case we need such extremely large offsets.
    // Example:
    //    a[ ind[i] ]
    if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
      if (!innermost)
        return false;
      // The coordinate load itself must be vectorizable (e.g. ind[i]).
      // Otherwise, a load such as ind[inv] would silently be turned into
      // the unrelated slice ind[inv:inv+vl].
      SmallVector<Value> idxs2;
      if (!vectorizeSubscripts(rewriter, forOp, vl, load.getIndices(), codegen,
                               vmask, idxs2))
        return false;
      if (codegen) {
        Location loc = forOp.getLoc();
        Value vload =
            genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
        Type etp = llvm::cast<VectorType>(vload.getType()).getElementType();
        if (!llvm::isa<IndexType>(etp)) {
          if (etp.getIntOrFloatBitWidth() < 32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI32Type()), vload);
          else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI64Type()), vload);
        }
        idxs.push_back(vload);
      }
      continue; // success so far
    }
    // Address calculation 'i = add inv, idx' (after LICM).
    // Example:
    //    a[base + i]
    if (auto add = cast.getDefiningOp<arith::AddIOp>()) {
      Value inv = add.getOperand(0);
      Value idx = add.getOperand(1);
      // Swap non-invariant.
      if (!isInvariantValue(inv, block)) {
        inv = idx;
        idx = add.getOperand(0);
      }
      // Inspect: only 'invariant + induction variable' in the innermost
      // dimension addresses consecutive elements.
      if (isInvariantValue(inv, block)) {
        if (auto arg = llvm::dyn_cast<BlockArgument>(idx)) {
          if (!innermost || arg != iv)
            return false;
          if (codegen)
            idxs.push_back(
                rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
          continue; // success so far
        }
      }
    }
    return false;
  }
  return true;
}

// Helper macros for vectorizeExpr(). Each macro tests whether the operation
// 'def' is an instance of the given op class. On a match, it rebuilds that
// operation on the vectorized operand(s) when 'codegen' is set, and then
// returns true from the enclosing function.

// Unary operation whose result type follows from its vector operand 'vx'.
#define UNAOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen) {                                                             \
      vexp = rewriter.create<xxx>(loc, vx);                                    \
    }                                                                          \
    return true;                                                               \
  }

// Unary operation with an explicit result type (e.g. casts); the scalar
// result type is widened into the corresponding vector type.
#define TYPEDUNAOP(xxx)                                                        \
  if (auto typedOp = dyn_cast<xxx>(def)) {                                     \
    if (codegen) {                                                             \
      VectorType resultVecType = vectorType(vl, typedOp.getType());            \
      vexp = rewriter.create<xxx>(loc, resultVecType, vx);                     \
    }                                                                          \
    return true;                                                               \
  }

// Binary operation on the vectorized operands 'vx' and 'vy'.
#define BINOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen) {                                                             \
      vexp = rewriter.create<xxx>(loc, vx, vy);                                \
    }                                                                          \
    return true;                                                               \
  }

/// This method is called twice to analyze and rewrite the given expression.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
/// This mechanism ensures that analysis and rewriting code stay in sync. Note
/// that the analysis part is simple because the sparsifier only generates
/// relatively simple expressions inside the for-loops.
///
/// The generated vector arithmetic computes all lanes, including lanes
/// disabled by 'vmask'. Masked loads and gathers fill those lanes with zero.
/// Their results are discarded by the final masked store or by the select
/// in a reduction, so this is harmless for most operations. Integer
/// division is the exception: a zero divisor in any lane is undefined
/// behavior and may trap. The divisor of an integer division is therefore
/// replaced by one in disabled lanes.
static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          Value exp, bool codegen, Value vmask, Value &vexp) {
  Location loc = forOp.getLoc();
  // Reject unsupported types.
  if (!VectorType::isValidElementType(exp.getType()))
    return false;
  // A block argument is invariant/reduction/index.
  if (auto arg = llvm::dyn_cast<BlockArgument>(exp)) {
    if (arg == forOp.getInductionVar()) {
      // We encountered a single, innermost index inside the computation,
      // such as a[i] = i, which must convert to [i, i+1, ...].
      if (codegen) {
        VectorType vtp = vectorType(vl, arg.getType());
        Value veci = rewriter.create<vector::BroadcastOp>(loc, vtp, arg);
        Value incr = rewriter.create<vector::StepOp>(loc, vtp);
        vexp = rewriter.create<arith::AddIOp>(loc, veci, incr);
      }
      return true;
    }
    // An invariant or reduction. In both cases, we treat this as an
    // invariant value, and rely on later replacing and folding to
    // construct a proper reduction chain for the latter case.
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Something defined outside the loop-body is invariant.
  Operation *def = exp.getDefiningOp();
  Block *block = &forOp.getRegion().front();
  if (def->getBlock() != block) {
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Proper load operations. These are either values involved in the
  // actual computation, such as a[i] = b[i] becomes a[lo:hi] = b[lo:hi],
  // or coordinate values inside the computation that are now fetched from
  // the sparse storage coordinates arrays, such as a[i] = i becomes
  // a[lo:hi] = ind[lo:hi], where 'lo' denotes the current index
  // and 'hi = lo + vl - 1'.
  if (auto load = dyn_cast<memref::LoadOp>(def)) {
    auto subs = load.getIndices();
    SmallVector<Value> idxs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
      if (codegen)
        vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
      return true;
    }
    return false;
  }
  // Inside loop-body unary and binary operations. Note that it would be
  // nicer if we could somehow test and build the operations in a more
  // concise manner than just listing them all (although this way we know
  // for certain that they can vectorize).
  //
  // TODO: avoid visiting CSEs multiple times
  //
  if (def->getNumOperands() == 1) {
    Value vx;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx)) {
      UNAOP(math::AbsFOp)
      UNAOP(math::AbsIOp)
      UNAOP(math::CeilOp)
      UNAOP(math::FloorOp)
      UNAOP(math::SqrtOp)
      UNAOP(math::ExpM1Op)
      UNAOP(math::Log1pOp)
      UNAOP(math::SinOp)
      UNAOP(math::TanhOp)
      UNAOP(arith::NegFOp)
      TYPEDUNAOP(arith::TruncFOp)
      TYPEDUNAOP(arith::ExtFOp)
      TYPEDUNAOP(arith::FPToSIOp)
      TYPEDUNAOP(arith::FPToUIOp)
      TYPEDUNAOP(arith::SIToFPOp)
      TYPEDUNAOP(arith::UIToFPOp)
      TYPEDUNAOP(arith::ExtSIOp)
      TYPEDUNAOP(arith::ExtUIOp)
      TYPEDUNAOP(arith::IndexCastOp)
      TYPEDUNAOP(arith::TruncIOp)
      TYPEDUNAOP(arith::BitcastOp)
      // TODO: complex?
    }
  } else if (def->getNumOperands() == 2) {
    Value vx, vy;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx) &&
        vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
                      vy)) {
      // We only accept shift-by-invariant (where the same shift factor applies
      // to all packed elements). In the vector dialect, this is still
      // represented with an expanded vector at the right-hand-side, however,
      // so that we do not have to special case the code generation.
      if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
          isa<arith::ShRSIOp>(def)) {
        Value shiftFactor = def->getOperand(1);
        if (!isInvariantValue(shiftFactor, block))
          return false;
      }
      // Integer division: lanes disabled by the mask hold the zero
      // pass-through of masked loads, so force their divisor to one to avoid
      // a division by zero. The select can fold away for an all-true mask.
      if (codegen && isa<arith::DivSIOp, arith::DivUIOp>(def)) {
        Value one = constantOne(rewriter, loc, vy.getType());
        vy = rewriter.create<arith::SelectOp>(loc, vmask, vy, one);
      }
      // Generate code.
      BINOP(arith::MulFOp)
      BINOP(arith::MulIOp)
      BINOP(arith::DivFOp)
      BINOP(arith::DivSIOp)
      BINOP(arith::DivUIOp)
      BINOP(arith::AddFOp)
      BINOP(arith::AddIOp)
      BINOP(arith::SubFOp)
      BINOP(arith::SubIOp)
      BINOP(arith::AndIOp)
      BINOP(arith::OrIOp)
      BINOP(arith::XOrIOp)
      BINOP(arith::ShLIOp)
      BINOP(arith::ShRUIOp)
      BINOP(arith::ShRSIOp)
      // TODO: complex?
    }
  }
  return false;
}

// The op-matching helper macros are only meant for vectorizeExpr().
#undef BINOP
#undef TYPEDUNAOP
#undef UNAOP

/// This method is called twice to analyze and rewrite the given for-loop.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) rewriters the IR into vector form. This mechanism ensures
/// that analysis and rewriting code stay in sync.
static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          bool codegen) {
  Block &block = forOp.getRegion().front();
  // For loops with single yield statement (as below) could be generated
  // when custom reduce is used with unary operation.
  // for (...)
  //   yield c_0
  if (block.getOperations().size() <= 1)
    return false;

  Location loc = forOp.getLoc();
  scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
  auto &last = *++block.rbegin();
  scf::ForOp forOpNew;

  // Perform initial set up during codegen (we know that the first analysis
  // pass was successful). For reductions, we need to construct a completely
  // new for-loop, since the incoming and outgoing reduction type
  // changes into SIMD form. For stores, we can simply adjust the stride
  // and insert in the existing for-loop. In both cases, we set up a vector
  // mask for all operations which takes care of confining vectors to
  // the original iteration space (later cleanup loops or other
  // optimizations can take care of those).
  Value vmask;
  if (codegen) {
    Value step = constantIndex(rewriter, loc, vl.vectorLength);
    if (vl.enableVLAVectorization) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      step = rewriter.create<arith::MulIOp>(loc, vscale, step);
    }
    if (!yield.getResults().empty()) {
      Value init = forOp.getInitArgs()[0];
      VectorType vtp = vectorType(vl, init.getType());
      Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
                                       forOp.getRegionIterArg(0), init, vtp);
      forOpNew = rewriter.create<scf::ForOp>(
          loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
      forOpNew->setAttr(
          LoopEmitter::getLoopEmitterLoopAttrName(),
          forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
      rewriter.setInsertionPointToStart(forOpNew.getBody());
    } else {
      rewriter.modifyOpInPlace(forOp, [&]() { forOp.setStep(step); });
      rewriter.setInsertionPoint(yield);
    }
    vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
                          forOp.getLowerBound(), forOp.getUpperBound(), step);
  }

  // Sparse for-loops either are terminated by a non-empty yield operation
  // (reduction loop) or otherwise by a store operation (pararallel loop).
  if (!yield.getResults().empty()) {
    // Analyze/vectorize reduction.
    if (yield->getNumOperands() != 1)
      return false;
    Value red = yield->getOperand(0);
    Value iter = forOp.getRegionIterArg(0);
    vector::CombiningKind kind;
    Value vrhs;
    if (isVectorizableReduction(red, iter, kind) &&
        vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
      if (codegen) {
        Value partial = forOpNew.getResult(0);
        Value vpass = genVectorInvariantValue(rewriter, vl, iter);
        Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
        rewriter.create<scf::YieldOp>(loc, vred);
        rewriter.setInsertionPointAfter(forOpNew);
        Value vres = rewriter.create<vector::ReductionOp>(loc, kind, partial);
        // Now do some relinking (last one is not completely type safe
        // but all bad ones are removed right away). This also folds away
        // nop broadcast operations.
        rewriter.replaceAllUsesWith(forOp.getResult(0), vres);
        rewriter.replaceAllUsesWith(forOp.getInductionVar(),
                                    forOpNew.getInductionVar());
        rewriter.replaceAllUsesWith(forOp.getRegionIterArg(0),
                                    forOpNew.getRegionIterArg(0));
        rewriter.eraseOp(forOp);
      }
      return true;
    }
  } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
    // Analyze/vectorize store operation.
    auto subs = store.getIndices();
    SmallVector<Value> idxs;
    Value rhs = store.getValue();
    Value vrhs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
        vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
      if (codegen) {
        genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
        rewriter.eraseOp(store);
      }
      return true;
    }
  }

  assert(!codegen && "cannot call codegen when analysis failed");
  return false;
}

/// Basic for-loop vectorizer. It only considers single-block, unit-stride
/// loops that carry the loop-emitter attribute, i.e. loops generated by the
/// sparsifier. No data dependence analysis is required for those, and their
/// loop bodies are very restricted in form.
struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
public:
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  ForOpRewriter(MLIRContext *context, unsigned vectorLength,
                bool enableVLAVectorization, bool enableSIMDIndex32)
      : OpRewritePattern(context),
        vl{vectorLength, enableVLAVectorization, enableSIMDIndex32} {}

  LogicalResult matchAndRewrite(scf::ForOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.getRegion().hasOneBlock())
      return failure();
    if (!isConstantIntValue(op.getStep(), 1))
      return failure();
    if (!op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
      return failure();
    // The analysis (!codegen) must succeed before the IR is rewritten
    // (codegen).
    if (!vectorizeStmt(rewriter, op, vl, /*codegen=*/false))
      return failure();
    return success(vectorizeStmt(rewriter, op, vl, /*codegen=*/true));
  }

private:
  /// Target SIMD properties used for every loop rewritten by this pattern.
  const VL vl;
};

/// Reduction chain cleanup.
///   v = for { }
///   s = vsum(v)               v = for { }
///   u = expand(s)       ->    for (v) { }
///   for (u) { }
///
/// Here 'expand' is the vector.insertelement or vector.broadcast that
/// genVectorReducInit() emits to seed the vector reduction of a subsequent
/// vectorized loop. Because the partial vector 'v' is substituted for 'u',
/// the rewrite only applies to exactly that shape:
///   - 'u' must have the same type as 'v';
///   - 'u' may only be used as an init operand of sparsifier-generated
///     for-loops.
/// Other expansions of a reduced scalar keep their scalar meaning and are
/// left alone. One example is the broadcast of a loop-invariant operand
/// inside a vectorized loop body.
///
/// NOTE(review): the rewrite additionally assumes that both loops reduce
/// with the same combining kind; this is not verified here.
template <typename VectorOp>
struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
public:
  using OpRewritePattern<VectorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(VectorOp op,
                                PatternRewriter &rewriter) const override {
    Value inp = op.getSource();
    auto redOp = inp.getDefiningOp<vector::ReductionOp>();
    if (!redOp)
      return failure();
    Value partial = redOp.getVector();
    auto forOp = partial.getDefiningOp<scf::ForOp>();
    if (!forOp || !forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
      return failure();
    Value expanded = op->getResult(0);
    // The partial vector must be a drop-in replacement for the expansion.
    if (expanded.getType() != partial.getType())
      return failure();
    // Only fold the expansion that seeds subsequent vectorized loops.
    for (Operation *user : expanded.getUsers()) {
      auto loop = dyn_cast<scf::ForOp>(user);
      if (!loop || !loop->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ||
          !llvm::is_contained(loop.getInitArgs(), expanded))
        return failure();
    }
    rewriter.replaceOp(op, partial);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating vectorization rules.
//===----------------------------------------------------------------------===//

/// Populates 'patterns' with the sparse loop vectorizer, configured for the
/// given target SIMD properties, and the reduction chain cleanups.
/// 'vectorLength' must be positive.
void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                               unsigned vectorLength,
                                               bool enableVLAVectorization,
                                               bool enableSIMDIndex32) {
  assert(vectorLength > 0);
  MLIRContext *ctx = patterns.getContext();
  // Main for-loop vectorizer.
  patterns.add<ForOpRewriter>(ctx, vectorLength, enableVLAVectorization,
                              enableSIMDIndex32);
  // Cleanups for chained reductions, one per reduction-init expansion form.
  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
               ReducChainRewriter<vector::BroadcastOp>>(ctx);
}