#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/GPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <cstdint>
using namespace mlir;
namespace {
/// Splits a `gpu.subgroup_reduce` over a multi-element 1-D vector into a
/// sequence of smaller reductions, each covering a contiguous slice whose total
/// bitwidth does not exceed `maxShuffleBitwidth`.
///
/// Slices are taken from the front of the vector; the last slice may be
/// shorter than the others. Single-element slices are reduced as scalars (via
/// `vector.extract` / `vector.insert`), wider slices as vectors (via
/// `vector.extract_strided_slice` / `vector.insert_strided_slice`). The
/// partial results are written back into a zero-initialized vector whose lanes
/// are all overwritten, since the slices cover the whole input.
///
/// Example, with `maxShuffleBitwidth == 32`:
///   vector<5xf16>  ->  vector<2xf16>, vector<2xf16>, f16
///   vector<3xf32>  ->  f32, f32, f32
struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
  BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
                          PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) {
  }

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() < 2)
      return rewriter.notifyMatchFailure(op, "not a multi-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");

    Type elemTy = vecTy.getElementType();
    unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
    // An element exactly as wide as the limit still fits into one shuffle, so
    // only strictly wider elements make a break-down impossible. Rejecting the
    // equal case would leave e.g. vector<2xf32> with a 32-bit limit without an
    // applicable lowering. This check also rejects maxShuffleBitwidth == 0
    // before the division below.
    if (elemBitwidth > maxShuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("element type too large ({0}), cannot break down "
                            "into vectors of bitwidth {1} or less",
                            elemBitwidth, maxShuffleBitwidth));

    unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
    assert(elementsPerShuffle >= 1);

    unsigned numNewReductions =
        llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle);
    assert(numNewReductions >= 1);
    if (numNewReductions == 1)
      return rewriter.notifyMatchFailure(op, "nothing to break down");

    Location loc = op.getLoc();
    // Accumulator for the reassembled result; every lane is overwritten below.
    Value res =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecTy));

    for (unsigned i = 0; i != numNewReductions; ++i) {
      // Widen before multiplying so the index cannot wrap in 32-bit unsigned
      // arithmetic.
      int64_t startIdx = static_cast<int64_t>(i) * elementsPerShuffle;
      int64_t endIdx =
          std::min(startIdx + elementsPerShuffle, vecTy.getNumElements());
      int64_t numElems = endIdx - startIdx;

      // Pull out the slice to reduce: a scalar for one element, else a
      // sub-vector.
      Value extracted;
      if (numElems == 1) {
        extracted =
            rewriter.create<vector::ExtractOp>(loc, op.getValue(), startIdx);
      } else {
        extracted = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, op.getValue(), startIdx, numElems,
            /*strides=*/1);
      }

      Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
          loc, extracted, op.getOp(), op.getUniform());

      // Write the partial result back at the same position.
      if (numElems == 1) {
        res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
        continue;
      }
      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, reduce, res, startIdx, /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned maxShuffleBitwidth = 0;
};
/// Rewrites a `gpu.subgroup_reduce` over a single-element vector (for example
/// `vector<1xf32>`) into a scalar reduction of that element. The scalar result
/// is broadcast back to the original vector type.
struct ScalarizeSingleElementReduce final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vectorType = dyn_cast<VectorType>(op.getType());
    bool isSingleElementVector =
        vectorType && vectorType.getNumElements() == 1;
    if (!isSingleElementVector)
      return rewriter.notifyMatchFailure(op, "not a single-element reduction");

    assert(vectorType.getRank() == 1 && "Unexpected vector type");
    assert(!vectorType.isScalable() && "Unexpected vector type");

    Location loc = op.getLoc();
    // Reduce the sole element as a scalar...
    Value scalar = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
    Value scalarReduce = rewriter.create<gpu::SubgroupReduceOp>(
        loc, scalar, op.getOp(), op.getUniform());
    // ...then splat it back into the original one-element vector shape.
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vectorType,
                                                     scalarReduce);
    return success();
  }
};
/// Emits an XOR-shuffle (butterfly) reduction of `input` across a subgroup of
/// `subgroupSize` lanes. Values are combined with the arithmetic operation
/// that corresponds to `mode`.
///
/// The loop runs log2(subgroupSize) steps. Step k shuffles the running partial
/// value with `gpu.shuffle xor` at offset 2^k and folds the received value into
/// the running one.
///
/// `packFn` is applied to the running value before each shuffle, and
/// `unpackFn` is applied to the shuffle result. This lets callers adapt types
/// that should not be shuffled directly. `unpackFn` must produce a value of
/// `input`'s type.
///
/// Requires `subgroupSize` to be a power of two.
static Value createSubgroupShuffleReduction(
    OpBuilder &builder, Location loc, Value input, gpu::AllReduceOperation mode,
    unsigned subgroupSize, function_ref<Value(Value)> packFn,
    function_ref<Value(Value)> unpackFn) {
  assert(llvm::isPowerOf2_32(subgroupSize));

  Value acc = input;
  for (unsigned offset = 1; offset < subgroupSize; offset *= 2) {
    // Exchange partial results with the lane whose id differs in the bit
    // selected by `offset`.
    Value packed = packFn(acc);
    auto shuffle = builder.create<gpu::ShuffleOp>(
        loc, packed, offset, subgroupSize, gpu::ShuffleMode::XOR);
    Value partner = unpackFn(shuffle.getShuffleResult());

    acc = vector::makeArithReduction(
        builder, loc, gpu::convertReductionKind(mode), acc, partner);
    assert(acc.getType() == input.getType());
  }
  return acc;
}
/// Lowers a scalar `gpu.subgroup_reduce` to a sequence of `gpu.shuffle` ops of
/// width `shuffleBitwidth` over a subgroup of `subgroupSize` lanes.
///
/// Scalars exactly `shuffleBitwidth` wide are shuffled as-is. Narrower scalars
/// go through an integer round-trip at each shuffle:
///   - before: bitcast to an integer of their own width, then zero-extend to
///     `shuffleBitwidth`;
///   - after: truncate and bitcast back.
/// Non-integer/float types (including vectors) and scalars wider than
/// `shuffleBitwidth` are rejected.
struct ScalarSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    Type valueTy = op.getType();
    // Check the type before asking for its bitwidth:
    // getIntOrFloatBitWidth() asserts on anything that is not an integer or
    // float.
    if (!valueTy.isIntOrFloat())
      return rewriter.notifyMatchFailure(
          op, "value type is not a compatible scalar");

    unsigned elemBitwidth = valueTy.getIntOrFloatBitWidth();
    if (elemBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "value type is not a compatible scalar");

    Location loc = op.getLoc();

    // Fast path: the value already has the native shuffle width.
    if (elemBitwidth == shuffleBitwidth) {
      auto identityFn = [](Value v) { return v; };
      rewriter.replaceOp(op, createSubgroupShuffleReduction(
                                 rewriter, loc, op.getValue(), op.getOp(),
                                 subgroupSize, identityFn, identityFn));
      return success();
    }

    // Narrow value: carry it in the low bits of a shuffleBitwidth-wide integer.
    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto equivIntType = rewriter.getIntegerType(elemBitwidth);
    auto packFn = [loc, &rewriter, equivIntType,
                   shuffleIntType](Value unpackedVal) -> Value {
      auto asInt =
          rewriter.create<arith::BitcastOp>(loc, equivIntType, unpackedVal);
      return rewriter.create<arith::ExtUIOp>(loc, shuffleIntType, asInt);
    };
    auto unpackFn = [loc, &rewriter, equivIntType,
                     valueTy](Value packedVal) -> Value {
      auto asInt =
          rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal);
      return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
    };

    rewriter.replaceOp(op, createSubgroupShuffleReduction(
                               rewriter, loc, op.getValue(), op.getOp(),
                               subgroupSize, packFn, unpackFn));
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
};
/// Lowers a `gpu.subgroup_reduce` over a small vector to `gpu.shuffle` ops of
/// width `shuffleBitwidth` over a subgroup of `subgroupSize` lanes.
///
/// The whole vector travels in one shuffle. It is bitcast to a single
/// `shuffleBitwidth`-wide integer before each shuffle and bitcast back after.
/// A vector narrower than `shuffleBitwidth` is first zero-padded up to that
/// width, and the padding is sliced off the final result.
///
/// The pattern fails if the vector is wider than `shuffleBitwidth`, or if the
/// element bitwidth does not evenly divide `shuffleBitwidth`.
struct VectorSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto inputTy = dyn_cast<VectorType>(op.getType());
    if (!inputTy)
      return rewriter.notifyMatchFailure(op, "value type is not a vector");

    const unsigned elemBits = inputTy.getElementTypeBitWidth();
    const unsigned totalBits = inputTy.getNumElements() * elemBits;
    if (totalBits > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op,
          llvm::formatv("vector type bitwidth too large ({0}), cannot lower "
                        "to shuffles of size {1}",
                        totalBits, shuffleBitwidth));

    const unsigned elemsPerShuffle = shuffleBitwidth / elemBits;
    if (elemsPerShuffle * elemBits != shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "shuffle bitwidth is not a multiple of the element bitwidth");

    Location loc = op.getLoc();
    const bool needsPadding = totalBits < shuffleBitwidth;
    auto paddedTy = VectorType::get(static_cast<int64_t>(elemsPerShuffle),
                                    inputTy.getElementType());

    // Widen the input to exactly one shuffle's worth of elements, filling the
    // extra positions with zeros.
    Value shuffleInput = op.getValue();
    if (needsPadding) {
      Value zeros = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getZeroAttr(paddedTy));
      shuffleInput = rewriter.create<vector::InsertStridedSliceOp>(
          loc, shuffleInput, zeros, /*offsets=*/0, /*strides=*/1);
    }

    auto intTy = rewriter.getIntegerType(shuffleBitwidth);
    auto intVecTy = VectorType::get(1, intTy);

    // Reinterpret the padded vector as one integer so it fits in one shuffle.
    auto toShuffleable = [loc, &rewriter, intVecTy](Value v) -> Value {
      auto bits = rewriter.create<vector::BitCastOp>(loc, intVecTy, v);
      return rewriter.create<vector::ExtractOp>(loc, bits, 0);
    };
    // Recover the padded vector from a shuffled integer.
    auto fromShuffleable = [loc, &rewriter, intVecTy,
                            paddedTy](Value v) -> Value {
      auto bits = rewriter.create<vector::BroadcastOp>(loc, intVecTy, v);
      return rewriter.create<vector::BitCastOp>(loc, paddedTy, bits);
    };

    Value reduced = createSubgroupShuffleReduction(
        rewriter, loc, shuffleInput, op.getOp(), subgroupSize, toShuffleable,
        fromShuffleable);

    // Drop the padding again so the result has the original type.
    if (needsPadding)
      reduced = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reduced, /*offsets=*/0, inputTy.getNumElements(),
          /*strides=*/1);

    rewriter.replaceOp(op, reduced);
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
};
}
/// Populates `patterns` with two rewrites:
///   - split multi-element vector `gpu.subgroup_reduce` ops into pieces no
///     wider than `maxShuffleBitwidth` bits;
///   - scalarize single-element vector reductions.
/// Both patterns get the same `benefit`.
///
/// NOTE: the spelling of this function name ("Subgrup") must match its
/// declaration elsewhere, so it is deliberately left unchanged here.
void mlir::populateGpuBreakDownSubgrupReducePatterns(
    RewritePatternSet &patterns, unsigned maxShuffleBitwidth,
    PatternBenefit benefit) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<BreakDownSubgroupReduce>(ctx, maxShuffleBitwidth, benefit);
  patterns.add<ScalarizeSingleElementReduce>(ctx, benefit);
}
/// Populates `patterns` with rewrites that lower scalar and small-vector
/// `gpu.subgroup_reduce` ops to `gpu.shuffle` sequences. The shuffles use a
/// width of `subgroupSize` lanes and carry `shuffleBitwidth` bits each. Both
/// patterns get the same `benefit`.
///
/// NOTE: the spelling of this function name ("Pattenrs") must match its
/// declaration elsewhere, so it is deliberately left unchanged here.
void mlir::populateGpuLowerSubgroupReduceToShufflePattenrs(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<ScalarSubgroupReduceToShuffles>(ctx, subgroupSize,
                                               shuffleBitwidth, benefit);
  patterns.add<VectorSubgroupReduceToShuffles>(ctx, subgroupSize,
                                               shuffleBitwidth, benefit);
}