#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Matchers.h"
using namespace mlir;
using namespace mlir::sparse_tensor;
namespace {
// Bundle of the vectorization options that is passed by value through the
// helper functions of this file.
struct VL {
// Lane count used for every generated vector type; also the base loop step
// of the vectorized loop.
unsigned vectorLength;
// When set, vector types are built with the scalable flag and the loop step
// is multiplied by vector.vscale.
bool enableVLAVectorization;
// When set, 32-bit index vectors loaded for indirect subscripts are used
// directly instead of being zero extended to 64-bit.
bool enableSIMDIndex32;
};
/// Returns true when `val` is produced by an operation that lives in a block
/// other than `block`. Values without a defining operation (block arguments)
/// are never reported as invariant by this helper.
static bool isInvariantValue(Value val, Block *block) {
  if (Operation *def = val.getDefiningOp())
    return def->getBlock() != block;
  return false;
}
/// Returns true when the block argument `arg` belongs to a block other than
/// `block`.
static bool isInvariantArg(BlockArgument arg, Block *block) {
  Block *owner = arg.getOwner();
  return !(owner == block);
}
/// Builds the 1-D vector type with `vl.vectorLength` lanes of element type
/// `etp`; the dimension is marked scalable when VLA vectorization is enabled.
static VectorType vectorType(VL vl, Type etp) {
  const bool scalable = vl.enableVLAVectorization;
  return VectorType::get(vl.vectorLength, etp, scalable);
}
/// Builds the vector type whose element type is taken from the memref type
/// of `mem`.
static VectorType vectorType(VL vl, Value mem) {
  Type elemTp = getMemRefType(mem).getElementType();
  return vectorType(vl, elemTp);
}
/// Generates the i1 mask vector that guards one vector iteration at `iv`.
///
/// If `lo`, `hi` and `step` are all constants and `step` evenly divides
/// `hi - lo`, an all-true mask is broadcast. Otherwise the number of active
/// lanes is computed as min(step, hi - iv) and turned into a mask.
static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
                           Value iv, Value lo, Value hi, Value step) {
  VectorType maskTp = vectorType(vl, rewriter.getI1Type());

  // Constant bounds with an evenly dividing step: no lane is ever disabled.
  IntegerAttr loAttr, hiAttr, stepAttr;
  const bool constantBounds = matchPattern(lo, m_Constant(&loAttr)) &&
                              matchPattern(hi, m_Constant(&hiAttr)) &&
                              matchPattern(step, m_Constant(&stepAttr));
  if (constantBounds &&
      ((hiAttr.getInt() - loAttr.getInt()) % stepAttr.getInt()) == 0) {
    Value allTrue = constantI1(rewriter, loc, true);
    return rewriter.create<vector::BroadcastOp>(loc, maskTp, allTrue);
  }

  // General case. With operands (d0, d1, s0) = (hi, iv, step) the map below
  // evaluates to min(s0, d0 - d1).
  AffineExpr stepSym = rewriter.getAffineSymbolExpr(0);
  AffineExpr remaining =
      rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1);
  AffineMap minMap =
      AffineMap::get(2, 1, {stepSym, remaining}, rewriter.getContext());
  Value activeLanes = rewriter.createOrFold<affine::AffineMinOp>(
      loc, minMap, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, maskTp, activeLanes);
}
/// Broadcasts the scalar `val` into a vector of its own type, using the
/// location of `val` for the new operation.
static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
                                     Value val) {
  Location valLoc = val.getLoc();
  VectorType splatTp = vectorType(vl, val.getType());
  return rewriter.create<vector::BroadcastOp>(valLoc, splatTp, val);
}
/// Generates a masked vector load from `mem` at subscripts `idxs`, with a
/// zero pass-through value for disabled lanes.
///
/// When the last subscript is itself a vector, a gather is emitted with that
/// vector as the index vector and constant 0 as the last scalar subscript;
/// otherwise a masked load is emitted. `idxs` must not be empty.
static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
                           Value mem, ArrayRef<Value> idxs, Value vmask) {
  VectorType resultTp = vectorType(vl, mem);
  Value passThru = constantZero(rewriter, loc, resultTp);
  Value lastIdx = idxs.back();

  // Scalar subscripts only: masked load.
  if (!llvm::isa<VectorType>(lastIdx.getType()))
    return rewriter.create<vector::MaskedLoadOp>(loc, resultTp, mem, idxs,
                                                 vmask, passThru);

  // Vector subscript in the last position: gather.
  SmallVector<Value> baseIdxs(idxs.begin(), idxs.end());
  baseIdxs.back() = constantIndex(rewriter, loc, 0);
  return rewriter.create<vector::GatherOp>(loc, resultTp, mem, baseIdxs,
                                           lastIdx, vmask, passThru);
}
/// Generates a masked vector store of `rhs` into `mem` at subscripts `idxs`.
///
/// When the last subscript is itself a vector, a scatter is emitted with
/// that vector as the index vector and constant 0 as the last scalar
/// subscript; otherwise a masked store is emitted. `idxs` must not be empty.
static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem,
                           ArrayRef<Value> idxs, Value vmask, Value rhs) {
  Value lastIdx = idxs.back();

  // Scalar subscripts only: masked store.
  if (!llvm::isa<VectorType>(lastIdx.getType())) {
    rewriter.create<vector::MaskedStoreOp>(loc, mem, idxs, vmask, rhs);
    return;
  }

  // Vector subscript in the last position: scatter.
  SmallVector<Value> baseIdxs(idxs.begin(), idxs.end());
  baseIdxs.back() = constantIndex(rewriter, loc, 0);
  rewriter.create<vector::ScatterOp>(loc, mem, baseIdxs, lastIdx, vmask, rhs);
}
/// Tests whether `red` is a recognized reduction update of `iter` and, for
/// every recognized operation, stores the matching combining kind in `kind`.
///
/// Recognized operations are arith add/sub/mul (integer and floating-point)
/// and arith and/or/xor. Subtraction is classified as ADD and only matches
/// with `iter` as its first operand; all other operations match with `iter`
/// in either operand position. `kind` is left untouched when `red` is not
/// defined by one of these operations.
static bool isVectorizableReduction(Value red, Value iter,
                                    vector::CombiningKind &kind) {
  Operation *def = red.getDefiningOp();
  if (!def)
    return false;

  // Whether `iter` may also appear as the second operand.
  bool eitherOperand = true;
  if (isa<arith::AddFOp, arith::AddIOp>(def)) {
    kind = vector::CombiningKind::ADD;
  } else if (isa<arith::SubFOp, arith::SubIOp>(def)) {
    kind = vector::CombiningKind::ADD;
    eitherOperand = false;
  } else if (isa<arith::MulFOp, arith::MulIOp>(def)) {
    kind = vector::CombiningKind::MUL;
  } else if (isa<arith::AndIOp>(def)) {
    kind = vector::CombiningKind::AND;
  } else if (isa<arith::OrIOp>(def)) {
    kind = vector::CombiningKind::OR;
  } else if (isa<arith::XOrIOp>(def)) {
    kind = vector::CombiningKind::XOR;
  } else {
    return false;
  }

  if (def->getOperand(0) == iter)
    return true;
  return eitherOperand && def->getOperand(1) == iter;
}
/// Generates the initial vector value of a vectorized reduction whose scalar
/// initial value is `r`.
///
/// The combining kind is derived from `red` and `iter`:
///   ADD, XOR : zero vector with `r` inserted at position 0
///   MUL      : one vector with `r` inserted at position 0
///   AND, OR  : `r` broadcast to all lanes
/// Any other situation is treated as unreachable.
static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
                                Value red, Value iter, Value r,
                                VectorType vtp) {
  vector::CombiningKind kind;
  if (!isVectorizableReduction(red, iter, kind))
    llvm_unreachable("unknown reduction");

  const bool zeroNeutral = kind == vector::CombiningKind::ADD ||
                           kind == vector::CombiningKind::XOR;
  if (zeroNeutral)
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantZero(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));

  if (kind == vector::CombiningKind::MUL)
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantOne(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));

  const bool idempotent = kind == vector::CombiningKind::AND ||
                          kind == vector::CombiningKind::OR;
  if (idempotent)
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);

  llvm_unreachable("unknown reduction kind");
}
/// Analyzes (codegen == false) or rewrites (codegen == true) the subscripts
/// `subs` of a memory access inside the body of `forOp`.
///
/// Returns true when every subscript has a supported form. During codegen
/// the (possibly vectorized) subscripts are appended to `idxs`, and `vmask`
/// guards any index vector that has to be loaded. Supported forms:
///   * a value defined by an operation outside the loop body, in any but the
///     last position;
///   * a block argument of another block in any but the last position, or a
///     block argument of the loop body in the last position;
///   * in the last position, a memref.load (possibly behind index_cast/extui
///     operations), which becomes a vector of indices;
///   * in the last position, an addi of a value defined outside the loop body
///     and a block argument of the loop body.
/// An empty subscript list is rejected.
static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                VL vl, ValueRange subs, bool codegen,
                                Value vmask, SmallVectorImpl<Value> &idxs) {
  // A subscript-free (rank-0) access has no innermost subscript to vectorize.
  // The load/store generators inspect the last subscript unconditionally, so
  // reject such accesses during analysis instead of failing during codegen.
  if (subs.empty())
    return false;
  unsigned d = 0;
  unsigned dim = subs.size();
  Block *block = &forOp.getRegion().front();
  for (Value sub : subs) {
    bool innermost = ++d == dim;
    // Values computed outside the loop body pass through unchanged, but are
    // not accepted as the last subscript.
    if (isInvariantValue(sub, block)) {
      if (innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Block arguments pass through unchanged: arguments of other blocks only
    // in outer positions, arguments of the loop body only in the last one.
    if (auto arg = llvm::dyn_cast<BlockArgument>(sub)) {
      if (isInvariantArg(arg, block) == innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Look through any chain of index_cast/extui operations.
    Value cast = sub;
    while (true) {
      if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
        cast = icast->getOperand(0);
      else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
        cast = ecast->getOperand(0);
      else
        break;
    }
    // Indirect subscript: the loaded indices become an index vector. Its
    // elements are zero extended to 32-bit when narrower than that, and to
    // 64-bit when narrower than 64-bit unless enableSIMDIndex32 is set.
    // Vectors with index element type are used as is.
    if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
      if (!innermost)
        return false;
      if (codegen) {
        // The subscripts of the index load itself are used without analysis.
        SmallVector<Value> idxs2(load.getIndices());
        Location loc = forOp.getLoc();
        Value vload =
            genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
        Type etp = llvm::cast<VectorType>(vload.getType()).getElementType();
        if (!llvm::isa<IndexType>(etp)) {
          if (etp.getIntOrFloatBitWidth() < 32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI32Type()), vload);
          else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI64Type()), vload);
        }
        idxs.push_back(vload);
      }
      continue; // success so far
    }
    // Address computation 'inv + idx' with the operands in either order.
    if (auto add = cast.getDefiningOp<arith::AddIOp>()) {
      Value inv = add.getOperand(0);
      Value idx = add.getOperand(1);
      // Make `inv` the operand that is defined outside the loop body.
      if (!isInvariantValue(inv, block)) {
        inv = idx;
        idx = add.getOperand(0);
      }
      if (isInvariantValue(inv, block)) {
        if (auto arg = llvm::dyn_cast<BlockArgument>(idx)) {
          if (isInvariantArg(arg, block) || !innermost)
            return false;
          if (codegen)
            idxs.push_back(
                rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
          continue; // success so far
        }
      }
    }
    // Any other subscript form is not supported.
    return false;
  }
  return true;
}
// The three helper macros below are only meant to be expanded inside
// vectorizeExpr(): they refer to its locals `def`, `codegen`, `rewriter`,
// `loc`, `vl`, `vx`, `vy` and `vexp`, and they return from that function
// when the defining operation matches.

// Unary operation `xxx`: during codegen, rebuilds it on the vectorized
// operand `vx`.
#define UNAOP(xxx) \
if (isa<xxx>(def)) { \
if (codegen) \
vexp = rewriter.create<xxx>(loc, vx); \
return true; \
}
// Unary operation `xxx` with an explicit result type (casts): during codegen,
// rebuilds it on `vx` with the vector type derived from the scalar result
// type.
#define TYPEDUNAOP(xxx) \
if (auto x = dyn_cast<xxx>(def)) { \
if (codegen) { \
VectorType vtp = vectorType(vl, x.getType()); \
vexp = rewriter.create<xxx>(loc, vtp, vx); \
} \
return true; \
}
// Binary operation `xxx`: during codegen, rebuilds it on the vectorized
// operands `vx` and `vy`.
#define BINOP(xxx) \
if (isa<xxx>(def)) { \
if (codegen) \
vexp = rewriter.create<xxx>(loc, vx, vy); \
return true; \
}
// Analyzes (codegen == false) or rewrites (codegen == true) the expression
// tree rooted at `exp` inside the body of `forOp`.
//
// Returns true when the expression has a supported form and false otherwise.
// During codegen the vectorized value is stored in `vexp`, and `vmask` is
// passed to the generated vector loads. The function recurses over the
// operands of unary and binary operations.
static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
Value exp, bool codegen, Value vmask, Value &vexp) {
Location loc = forOp.getLoc();
// Reject types that cannot be used as a vector element type.
if (!VectorType::isValidElementType(exp.getType()))
return false;
// Block arguments.
if (auto arg = llvm::dyn_cast<BlockArgument>(exp)) {
if (arg == forOp.getInductionVar()) {
// The loop index itself becomes broadcast(iv) + step vector.
if (codegen) {
VectorType vtp = vectorType(vl, arg.getType());
Value veci = rewriter.create<vector::BroadcastOp>(loc, vtp, arg);
Value incr = rewriter.create<vector::StepOp>(loc, vtp);
vexp = rewriter.create<arith::AddIOp>(loc, veci, incr);
}
return true;
}
// Every other block argument is broadcast as a scalar.
if (codegen)
vexp = genVectorInvariantValue(rewriter, vl, exp);
return true;
}
// `exp` is not a block argument here, so it has a defining operation.
// Values defined outside the loop body are broadcast as well.
Operation *def = exp.getDefiningOp();
Block *block = &forOp.getRegion().front();
if (def->getBlock() != block) {
if (codegen)
vexp = genVectorInvariantValue(rewriter, vl, exp);
return true;
}
// Loads inside the loop body become vector loads when their subscripts can
// be vectorized.
if (auto load = dyn_cast<memref::LoadOp>(def)) {
auto subs = load.getIndices();
SmallVector<Value> idxs;
if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
if (codegen)
vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
return true;
}
return false;
}
// Unary and binary operations inside the loop body. Only the operations
// listed explicitly below are accepted; anything else falls through to the
// final `return false`.
if (def->getNumOperands() == 1) {
Value vx;
if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
vx)) {
UNAOP(math::AbsFOp)
UNAOP(math::AbsIOp)
UNAOP(math::CeilOp)
UNAOP(math::FloorOp)
UNAOP(math::SqrtOp)
UNAOP(math::ExpM1Op)
UNAOP(math::Log1pOp)
UNAOP(math::SinOp)
UNAOP(math::TanhOp)
UNAOP(arith::NegFOp)
TYPEDUNAOP(arith::TruncFOp)
TYPEDUNAOP(arith::ExtFOp)
TYPEDUNAOP(arith::FPToSIOp)
TYPEDUNAOP(arith::FPToUIOp)
TYPEDUNAOP(arith::SIToFPOp)
TYPEDUNAOP(arith::UIToFPOp)
TYPEDUNAOP(arith::ExtSIOp)
TYPEDUNAOP(arith::ExtUIOp)
TYPEDUNAOP(arith::IndexCastOp)
TYPEDUNAOP(arith::TruncIOp)
TYPEDUNAOP(arith::BitcastOp)
}
} else if (def->getNumOperands() == 2) {
Value vx, vy;
if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
vx) &&
vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
vy)) {
// Shifts are only accepted when the shift amount is defined by an operation
// outside the loop body. This check runs after both operands have been
// visited.
if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
isa<arith::ShRSIOp>(def)) {
Value shiftFactor = def->getOperand(1);
if (!isInvariantValue(shiftFactor, block))
return false;
}
BINOP(arith::MulFOp)
BINOP(arith::MulIOp)
BINOP(arith::DivFOp)
BINOP(arith::DivSIOp)
BINOP(arith::DivUIOp)
BINOP(arith::AddFOp)
BINOP(arith::AddIOp)
BINOP(arith::SubFOp)
BINOP(arith::SubIOp)
BINOP(arith::AndIOp)
BINOP(arith::OrIOp)
BINOP(arith::XOrIOp)
BINOP(arith::ShLIOp)
BINOP(arith::ShRUIOp)
BINOP(arith::ShRSIOp)
}
}
// Unsupported operation or operand.
return false;
}
#undef UNAOP
#undef TYPEDUNAOP
#undef BINOP
// Analyzes (codegen == false) or rewrites (codegen == true) the body of
// `forOp` as a whole.
//
// Two loop shapes are handled: a loop whose yield carries a value (treated as
// a reduction over the first iteration argument) and a loop whose last
// operation before the terminator is a memref.store. Returns true when the
// loop can be / has been vectorized. Calling this with codegen == true for a
// loop that does not pass the analysis trips the assertion at the end.
static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
bool codegen) {
Block &block = forOp.getRegion().front();
// A body that holds nothing but its terminator is not vectorized.
if (block.getOperations().size() <= 1)
return false;
Location loc = forOp.getLoc();
scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
// The operation immediately preceding the terminator.
auto &last = *++block.rbegin();
scf::ForOp forOpNew;
Value vmask;
// Codegen set up: compute the new step, prepare the loop that will receive
// the vector code, and build the iteration mask.
if (codegen) {
// New step: vectorLength, multiplied by vscale for scalable vectors.
Value step = constantIndex(rewriter, loc, vl.vectorLength);
if (vl.enableVLAVectorization) {
Value vscale =
rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
step = rewriter.create<arith::MulIOp>(loc, vscale, step);
}
if (!yield.getResults().empty()) {
// Reduction: a new loop is created, since the iteration argument changes
// from the scalar type to a vector type. The loop attribute of the
// original loop is copied over.
Value init = forOp.getInitArgs()[0];
VectorType vtp = vectorType(vl, init.getType());
Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
forOp.getRegionIterArg(0), init, vtp);
forOpNew = rewriter.create<scf::ForOp>(
loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
forOpNew->setAttr(
LoopEmitter::getLoopEmitterLoopAttrName(),
forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
rewriter.setInsertionPointToStart(forOpNew.getBody());
} else {
// Store: the existing loop is kept; only its step is updated and new
// code is inserted in front of the terminator.
rewriter.modifyOpInPlace(forOp, [&]() { forOp.setStep(step); });
rewriter.setInsertionPoint(yield);
}
// The mask is built from the induction variable and bounds of the original
// loop together with the new step.
vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
forOp.getLowerBound(), forOp.getUpperBound(), step);
}
if (!yield.getResults().empty()) {
// Reduction loop: only a single yielded value is supported.
if (yield->getNumOperands() != 1)
return false;
Value red = yield->getOperand(0);
Value iter = forOp.getRegionIterArg(0);
vector::CombiningKind kind;
Value vrhs;
if (isVectorizableReduction(red, iter, kind) &&
vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
if (codegen) {
Value partial = forOpNew.getResult(0);
// Masked-off lanes keep the incoming value of the iteration argument.
Value vpass = genVectorInvariantValue(rewriter, vl, iter);
Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
rewriter.create<scf::YieldOp>(loc, vred);
// The horizontal reduction of the vector result is placed after the new
// loop.
rewriter.setInsertionPointAfter(forOpNew);
Value vres = rewriter.create<vector::ReductionOp>(loc, kind, partial);
// Redirect all uses of the original loop's result, induction variable and
// iteration argument to their new counterparts, then drop the original
// loop. NOTE(review): the last replacement substitutes a vector-typed
// argument for a scalar-typed one; presumably the resulting
// broadcast/insert of a reduction result is what ReducChainRewriter
// cleans up -- confirm.
rewriter.replaceAllUsesWith(forOp.getResult(0), vres);
rewriter.replaceAllUsesWith(forOp.getInductionVar(),
forOpNew.getInductionVar());
rewriter.replaceAllUsesWith(forOp.getRegionIterArg(0),
forOpNew.getRegionIterArg(0));
rewriter.eraseOp(forOp);
}
return true;
}
} else if (auto store = dyn_cast<memref::StoreOp>(last)) {
// Store loop: vectorize the subscripts and the stored value, then replace
// the scalar store with a vector store.
auto subs = store.getIndices();
SmallVector<Value> idxs;
Value rhs = store.getValue();
Value vrhs;
if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
if (codegen) {
genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
rewriter.eraseOp(store);
}
return true;
}
}
// Reaching this point means the loop is not vectorizable, which must have
// been detected by the analysis call already.
assert(!codegen && "cannot call codegen when analysis failed");
return false;
}
/// Rewrite pattern that vectorizes scf.for loops carrying the loop-emitter
/// attribute. The loop body is first analyzed and only rewritten when the
/// analysis succeeds.
struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
public:
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  ForOpRewriter(MLIRContext *context, unsigned vectorLength,
                bool enableVLAVectorization, bool enableSIMDIndex32)
      : OpRewritePattern(context),
        vl{vectorLength, enableVLAVectorization, enableSIMDIndex32} {}

  LogicalResult matchAndRewrite(scf::ForOp op,
                                PatternRewriter &rewriter) const override {
    // Only single-block, unit-step loops tagged with the loop-emitter
    // attribute are considered.
    if (!op.getRegion().hasOneBlock())
      return failure();
    if (!isConstantIntValue(op.getStep(), 1))
      return failure();
    if (!op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
      return failure();

    // First pass analyzes without touching the IR, second pass rewrites.
    if (!vectorizeStmt(rewriter, op, vl, false))
      return failure();
    if (!vectorizeStmt(rewriter, op, vl, true))
      return failure();
    return success();
  }

private:
  const VL vl;
};
/// Rewrite pattern that replaces a `VectorOp` whose source is a
/// vector.reduction over the result of an scf.for tagged with the
/// loop-emitter attribute by the vector operand of that reduction.
template <typename VectorOp>
struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
public:
  using OpRewritePattern<VectorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(VectorOp op,
                                PatternRewriter &rewriter) const override {
    Value src = op.getSource();
    auto reduction = src.getDefiningOp<vector::ReductionOp>();
    if (!reduction)
      return failure();

    auto loop = reduction.getVector().getDefiningOp<scf::ForOp>();
    if (!loop)
      return failure();
    if (!loop->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
      return failure();

    // Use the vector being reduced directly.
    rewriter.replaceOp(op, reduction.getVector());
    return success();
  }
};
}
/// Adds the loop vectorization pattern, configured with the given options,
/// and the two reduction-chain cleanup patterns to `patterns`.
/// `vectorLength` must be positive.
void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                               unsigned vectorLength,
                                               bool enableVLAVectorization,
                                               bool enableSIMDIndex32) {
  assert(vectorLength > 0);
  MLIRContext *ctx = patterns.getContext();
  patterns.add<ForOpRewriter>(ctx, vectorLength, enableVLAVectorization,
                              enableSIMDIndex32);
  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
               ReducChainRewriter<vector::BroadcastOp>>(ctx);
}