#include "ReduceScanCommon.h"
#include "mlir/Support/LLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
using namespace mlir;
using namespace mlir::triton;
using ::mlir::LLVM::linearize;
using ::mlir::triton::gpu::DistributedEncodingTrait;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getThreadOrder;
using ::mlir::triton::gpu::getTotalElemsPerThread;
namespace {
// Conversion pattern that lowers `triton::ReduceOp` to LLVM-dialect IR.
//
// `matchAndRewrite` drives the lowering in stages: values are first combined
// inside each thread, then across the lanes of a warp, and -- unless
// `ReduceOpHelper::isWarpSynchronous()` reports that this is sufficient --
// across warps through scratch buffers obtained from `getSmemBases`.
struct ReduceOpConversion
: public ConvertTritonGPUReduceScanToLLVMPattern<triton::ReduceOp> {
public:
// Forwards `typeConverter` and `benefit` to the base pattern and stores a
// reference to `targetInfo`. Only a reference is kept, so the referenced
// object has to stay alive for as long as this pattern is used.
ReduceOpConversion(LLVMTypeConverter &typeConverter,
const TargetInfoBase &targetInfo, PatternBenefit benefit)
: ConvertTritonGPUReduceScanToLLVMPattern<triton::ReduceOp>(typeConverter,
benefit),
targetInfo(targetInfo) {}
// Lowers one `triton::ReduceOp`.
//
// Pipeline:
//   1. unpack the operands and combine them inside each thread,
//   2. combine the per-thread partials across the lanes of a warp,
//   3. if the helper reports the reduction as warp-synchronous, pack the
//      partials as the final result; otherwise
//   4. spill the partials to scratch buffers, synchronize, combine them
//      across warps, synchronize again and load the final values back.
//
// Always returns `success()`; an unsupported source layout is only caught by
// the assertion below.
LogicalResult
matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  using PartialMap = std::map<SmallVector<unsigned>, SmallVector<Value>>;

  ReduceOpHelper helper(op);
  assert(helper.isReduceWithinCTA() &&
         "Unexpected srcLayout in ReduceOpConversion");
  Location loc = op->getLoc();

  // Partial accumulators and the indices recorded for them; both maps are
  // filled by reduceWithinThreads and keyed by the element offset with the
  // reduction axis set to zero.
  PartialMap partials;
  PartialMap partialIndices;

  auto inputs = unpackInputs(loc, op, adaptor, rewriter);
  reduceWithinThreads(helper, inputs, partials, partialIndices, rewriter);
  reduceWithinWarps(helper, partials, rewriter);

  if (!helper.isWarpSynchronous()) {
    // Cross-warp stage: one scratch base pointer per operand.
    auto scratchShape = helper.getScratchRepShape();
    SmallVector<Value> scratchBases = getSmemBases(
        op, product<unsigned>(scratchShape), rewriter, targetInfo);

    storeWarpReduceToSharedMemory(helper, partials, partialIndices,
                                  scratchBases, rewriter);
    sync(rewriter, loc, op);
    accumulatePartialReductions(helper, scratchBases, rewriter);
    sync(rewriter, loc, op);
    loadReductionAndPackResult(helper, scratchShape, scratchBases, rewriter);
  } else {
    // Nothing is left to combine after the warp-level stage.
    packResults(helper, partials, rewriter);
  }
  return success();
}
private:
// Target-specific hooks. This pattern calls its `warpReduce`, `shuffleXor`,
// `storeShared` and `loadShared` members and forwards it to `emitIndices` and
// `getSmemBases`. Held by reference; not owned.
const TargetInfoBase &targetInfo;
// Combines `cur` into `acc` using the reduction's combine region and stores
// the combined values back into `acc`.
//
//   loc, rewriter - where and with what to emit the combine operations.
//   combineOp     - region of the reduce op that defines the combination.
//   acc           - in/out accumulator. It may be empty on the first call
//                   (reduceWithinThreads passes a freshly inserted map
//                   entry); it is then sized to whatever `applyCombineOp`
//                   returns.
//   cur           - values to fold into the accumulator.
//   pred          - optional predicate forwarded unchanged to
//                   `applyCombineOp`.
//
// The result is copied with a single `assign`, so `acc` always ends up with
// exactly the values returned by `applyCombineOp`. A copy loop bounded by
// `acc.size()` would index past the end of the result if `acc` were ever
// longer than it.
void accumulate(Location loc, ConversionPatternRewriter &rewriter,
                Region &combineOp, SmallVector<Value> &acc, ValueRange cur,
                Value pred = {}) const {
  auto results = applyCombineOp(loc, rewriter, combineOp, acc, cur, pred);
  acc.assign(results.begin(), results.end());
}
// Unpacks every converted operand of `op` into its per-thread elements and
// transposes the result: entry `e` of the returned vector holds, for element
// `e`, one value per operand (in operand order).
//
// The number of elements is taken from the first input type; every operand is
// asserted to unpack to that same number of values.
SmallVector<SmallVector<Value>>
unpackInputs(Location loc, triton::ReduceOp op, OpAdaptor adaptor,
             ConversionPatternRewriter &rewriter) const {
  auto convertedOperands = adaptor.getOperands();
  const unsigned numOperands = op.getNumOperands();
  const unsigned elemsPerThread =
      getTotalElemsPerThread(op.getInputTypes()[0]);

  SmallVector<SmallVector<Value>> perElement(elemsPerThread);
  for (unsigned operandIdx = 0; operandIdx != numOperands; ++operandIdx) {
    auto unpacked =
        unpackLLElements(loc, convertedOperands[operandIdx], rewriter);
    assert(unpacked.size() == perElement.size());
    for (unsigned elem = 0, numElems = perElement.size(); elem != numElems;
         ++elem)
      perElement[elem].push_back(unpacked[elem]);
  }
  return perElement;
}
// Emits a barrier at `loc` through `TritonLLVMOpBuilder::barrier()`.
// `op` is accepted for interface symmetry but is not used.
void sync(ConversionPatternRewriter &rewriter, Location loc,
          triton::ReduceOp op) const {
  TritonLLVMOpBuilder builder(loc, rewriter);
  builder.barrier();
}
// First reduction stage: combines, inside each thread, all values that differ
// only in their position along the reduction axis.
//
//   srcValues - per-element operand values as produced by unpackInputs.
//   accs      - out: accumulator per key, where the key is the element offset
//               with the reduction-axis coordinate set to 0.
//   indices   - out: for every key, the emitted indices of the first element
//               that contributed to it.
void reduceWithinThreads(
    ReduceOpHelper &helper, SmallVector<SmallVector<Value>> &srcValues,
    std::map<SmallVector<unsigned>, SmallVector<Value>> &accs,
    std::map<SmallVector<unsigned>, SmallVector<Value>> &indices,
    ConversionPatternRewriter &rewriter) const {
  triton::ReduceOp op = helper.getOperation();
  Location loc = op.getLoc();
  const unsigned axis = op.getAxis();
  RankedTensorType inputTy = op.getInputTypes()[0];

  SmallVector<SmallVector<unsigned>> elemOffsets =
      emitOffsetForLayout(helper.getSrcLayout(), inputTy);

  // Several registers may map to the same offset. MapVector::insert does not
  // overwrite an existing key, so this keeps the first register per distinct
  // offset while preserving insertion order. The ArrayRef keys point into
  // `elemOffsets`, which outlives the map.
  llvm::MapVector<ArrayRef<unsigned>, int> firstRegPerOffset;
  for (int reg = 0; reg < elemOffsets.size(); ++reg)
    firstRegPerOffset.insert({elemOffsets[reg], reg});

  Region &combineRegion = op.getCombineOp();
  auto elemIndices = emitIndices(loc, rewriter, targetInfo,
                                 helper.getSrcLayout(), inputTy, true);

  for (const auto &entry : firstRegPerOffset) {
    const int reg = entry.second;

    SmallVector<unsigned> key = elemOffsets[reg];
    key[axis] = 0;

    // Create the accumulator slot on first sight of this key and remember
    // whether it was new, so the indices are recorded exactly once.
    auto [slot, isNewKey] = accs.try_emplace(key);
    accumulate(loc, rewriter, combineRegion, slot->second, srcValues[reg]);
    if (isNewKey)
      indices[key] = elemIndices[reg];
  }
}
// Reduces `acc` across `numLaneToReduce` lanes, in place.
//
// The target gets the first chance via `targetInfo.warpReduce`; if it reports
// that it handled the reduction, nothing else is emitted (note that `pred` is
// not passed to that hook). Otherwise a butterfly of XOR shuffles is emitted:
// the shuffle distance starts at `numLaneToReduce / 2`, is halved every step,
// and is scaled by `interleave`. After every shuffle the received values are
// combined into `acc` under `pred`.
void warpReduce(ConversionPatternRewriter &rewriter, Location loc,
                SmallVector<Value> &acc, triton::ReduceOp op,
                unsigned numLaneToReduce, unsigned interleave,
                Value pred = {}) const {
  if (targetInfo.warpReduce(rewriter, loc, acc, op, numLaneToReduce,
                            interleave))
    return;

  unsigned distance = numLaneToReduce / 2;
  while (distance > 0) {
    SmallVector<Value> received;
    received.reserve(acc.size());
    for (Value mine : acc)
      received.push_back(
          targetInfo.shuffleXor(rewriter, loc, mine, distance * interleave));

    accumulate(op.getLoc(), rewriter, op.getCombineOp(), acc, received, pred);
    distance >>= 1;
  }
}
// Second reduction stage: reduces every accumulator in `accs` across the
// lanes of a warp, in place.
//
// The number of lanes and the shuffle interleave come from
// `ReduceOpHelper::getIntraWarpSizeWithUniqueData()` and
// `ReduceOpHelper::getThreadOffsetOnReductionAxis()`.
//
// The map is walked by reference so that `warpReduce` updates the stored
// accumulator directly; iterating by value would copy every key/value pair
// and require a second lookup to reach the real entry.
void
reduceWithinWarps(ReduceOpHelper &helper,
                  std::map<SmallVector<unsigned>, SmallVector<Value>> &accs,
                  ConversionPatternRewriter &rewriter) const {
  triton::ReduceOp op = helper.getOperation();
  unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData();
  unsigned threadOffsetOnReductionAxis =
      helper.getThreadOffsetOnReductionAxis();
  for (auto &entry : accs) {
    warpReduce(rewriter, op.getLoc(), entry.second, op, sizeIntraWarps,
               threadOffsetOnReductionAxis);
  }
}
// Builds the final results directly from the accumulators in `accs` and
// replaces `op` with them. Used when no cross-warp stage is needed.
//
// For a ranked-tensor result, every per-thread result offset is turned back
// into an accumulator key by inserting a 0 at the reduction axis, and the
// matching accumulator value for that operand is packed with
// `packLLElements`. For any other result type the value is taken from the
// first accumulator in the map.
//
// Lookups use `find` plus an assertion rather than `operator[]`, which would
// silently insert an empty accumulator for an unexpected key and then index
// into it.
void packResults(ReduceOpHelper &helper,
                 std::map<SmallVector<unsigned>, SmallVector<Value>> &accs,
                 ConversionPatternRewriter &rewriter) const {
  triton::ReduceOp op = helper.getOperation();
  Location loc = op.getLoc();
  unsigned axis = op.getAxis();
  const unsigned numOperands = op.getNumOperands();

  SmallVector<Value> results(numOperands);
  for (unsigned i = 0; i < numOperands; ++i) {
    auto resultTy = dyn_cast<RankedTensorType>(op.getResult()[i].getType());
    if (!resultTy) {
      // Non-tensor result: a single accumulator holds the value.
      assert(!accs.empty() && "no accumulator available for scalar result");
      results[i] = accs.begin()->second[i];
      continue;
    }

    auto resultLayout = cast<SliceEncodingAttr>(resultTy.getEncoding());
    unsigned resultElems = getTotalElemsPerThread(resultTy);
    SmallVector<SmallVector<unsigned>> resultOffset =
        emitOffsetForLayout(resultLayout, resultTy);

    SmallVector<Value> resultVals;
    resultVals.reserve(resultElems);
    for (unsigned j = 0; j < resultElems; ++j) {
      // Re-insert the reduced axis (as coordinate 0) to recover the key used
      // when the accumulators were created.
      auto key = resultOffset[j];
      key.insert(key.begin() + axis, 0);
      auto accIt = accs.find(key);
      assert(accIt != accs.end() && "missing accumulator for result offset");
      resultVals.push_back(accIt->second[i]);
    }
    results[i] = packLLElements(loc, getTypeConverter(), resultVals, rewriter,
                                resultTy);
  }
  rewriter.replaceOp(op, results);
}
// Stores the warp-level partial results to the scratch buffers so that they
// can be combined across warps.
//
//   accs      - accumulator per key (see reduceWithinThreads).
//   indices   - indices recorded for each key; the reduction-axis component
//               is replaced by the warp's coordinate along that axis before
//               the index is linearized into the scratch shape.
//   smemBases - one scratch base pointer per operand.
//
// The store is predicated: it is performed only when both `delinearize` calls
// report a representative lane/warp and the lane's coordinate along the
// reduction axis is 0.
//
// The accumulator map is walked by const reference; iterating by value would
// deep-copy the key and value vectors of every entry.
void storeWarpReduceToSharedMemory(
    ReduceOpHelper &helper,
    std::map<SmallVector<unsigned>, SmallVector<Value>> &accs,
    std::map<SmallVector<unsigned>, SmallVector<Value>> &indices,
    SmallVector<Value> &smemBases,
    ConversionPatternRewriter &rewriter) const {
  triton::ReduceOp op = helper.getOperation();
  Location loc = op.getLoc();
  auto b = TritonLLVMOpBuilder(loc, rewriter);
  auto srcLayout =
      mlir::cast<DistributedEncodingTrait>(helper.getSrcLayout());
  auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
  unsigned axis = op.getAxis();
  auto smemShape = helper.getScratchRepShape();

  // Split the linear lane and warp ids into per-dimension coordinates.
  auto srcShape = helper.getSrcShape();
  auto kLane = rewriter.getStringAttr("lane");
  auto [multiDimLaneId, isRepresentativeLane] =
      delinearize(rewriter, loc, srcLayout, srcShape, kLane, laneId);
  auto kWarp = rewriter.getStringAttr("warp");
  auto [multiDimWarpId, isRepresentativeWarp] =
      delinearize(rewriter, loc, srcLayout, srcShape, kWarp, warpId);

  // Write predicate: representative lane and warp, and lane coordinate 0
  // along the reduction axis.
  Value laneIdAxis = multiDimLaneId[axis];
  Value laneZero = b.icmp_eq(laneIdAxis, b.i32_val(0));
  Value write =
      b.and_(b.and_(isRepresentativeLane, isRepresentativeWarp), laneZero);

  Value warpIdAxis = multiDimWarpId[axis];

  auto smemOrder = helper.getOrderWithAxisAtBeginning();
  for (const auto &entry : accs) {
    const SmallVector<unsigned> &key = entry.first;
    const SmallVector<Value> &acc = entry.second;

    // Copy the recorded indices: the reduction-axis slot is overwritten with
    // this warp's coordinate.
    SmallVector<Value> writeIdx = indices[key];
    writeIdx[axis] = warpIdAxis;
    Value writeOffset =
        linearize(rewriter, loc, writeIdx, smemShape, smemOrder);
    for (unsigned i = 0; i < op.getNumOperands(); ++i) {
      auto elemTy = getElementType(op, i);
      Value writePtr =
          b.gep(smemBases[i].getType(), elemTy, smemBases[i], writeOffset);
      targetInfo.storeShared(rewriter, loc, writePtr, acc[i], write);
    }
  }
}
// Cross-warp stage: combines the partial results previously written to the
// scratch buffers at `smemBases` and writes the combined values back in
// place.
//
// Each thread starts at scratch offset `threadId` and performs
// `elemsPerThread` rounds, advancing its offset by the total thread count
// between rounds. In every round it loads one value per operand, reduces it
// across `sizeInterWarps` lanes with `warpReduce` (interleave 1), and the
// first lane of every `sizeInterWarps`-lane group stores the result back to
// the offset it was loaded from.
void accumulatePartialReductions(ReduceOpHelper &helper,
SmallVector<Value> &smemBases,
ConversionPatternRewriter &rewriter) const {
triton::ReduceOp op = helper.getOperation();
auto smemShape = helper.getScratchRepShape();
// Total number of scratch elements per operand.
unsigned elems = product<unsigned>(smemShape);
unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData();
Location loc = op.getLoc();
auto b = TritonLLVMOpBuilder(loc, rewriter);
auto mod = op->getParentOfType<ModuleOp>();
int numLanes = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
int numWarps = triton::gpu::lookupNumWarps(op);
int numThreads = numLanes * numWarps;
Value threadId = getThreadId(rewriter, loc);
Value warpSize = b.i32_val(numLanes);
Value laneId = b.urem(threadId, warpSize);
Value zero = b.i32_val(0);
// Number of rounds: at least one, even when there are fewer scratch elements
// than threads.
unsigned elemsPerThread = std::max<unsigned>(elems / numThreads, 1);
// Loads, combines and stores are predicated on `threadId < elems` (signed
// compare). NOTE(review): the predicate is computed from `threadId`, not from
// the per-round `readOffset`.
Value threadIsNeeded = b.icmp_slt(threadId, b.i32_val(elems));
Value readOffset = threadId;
for (unsigned round = 0; round < elemsPerThread; ++round) {
// Load one value per operand from the current offset.
SmallVector<Value> acc(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
auto elemTy = getElementType(op, i);
Value readPtr =
b.gep(smemBases[i].getType(), elemTy, smemBases[i], readOffset);
acc[i] = targetInfo.loadShared(rewriter, loc, readPtr, elemTy,
threadIsNeeded);
}
// Reduce across `sizeInterWarps` lanes; the literal 1 is the interleave.
warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 ,
threadIsNeeded);
// Results are written back to the same offset they were read from.
Value writeOffset = readOffset;
SmallVector<Value> writePtrs(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
auto elemTy = getElementType(op, i);
writePtrs[i] =
b.gep(smemBases[i].getType(), elemTy, smemBases[i], writeOffset);
}
// Only lanes whose id is a multiple of `sizeInterWarps` (and that pass
// `threadIsNeeded`) perform the store.
Value laneIdModSizeInterWarps = b.urem(laneId, b.i32_val(sizeInterWarps));
Value laneIdModSizeInterWarpsIsZero =
b.icmp_eq(laneIdModSizeInterWarps, zero);
Value pred = b.and_(threadIsNeeded, laneIdModSizeInterWarpsIsZero);
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
targetInfo.storeShared(rewriter, loc, writePtrs[i], acc[i], pred);
}
// Advance to this thread's next element, except after the last round.
if (round != elemsPerThread - 1) {
readOffset = b.add(readOffset, b.i32_val(numThreads));
}
}
}
// Final stage of the cross-warp path: loads the fully reduced values from the
// scratch buffers and replaces `op` with them.
//
//   smemShape - shape of the scratch buffer (as returned by
//               `ReduceOpHelper::getScratchRepShape()` in the caller).
//   smemBases - one scratch base pointer per operand.
//
// For a ranked-tensor result, the indices of every per-thread result element
// are emitted, a 0 is inserted at the reduction axis, and the index is
// linearized into the scratch shape to find the value to load. For any other
// result type the value is loaded straight from the operand's base pointer.
void loadReductionAndPackResult(ReduceOpHelper &helper,
                                SmallVector<unsigned> smemShape,
                                SmallVector<Value> &smemBases,
                                ConversionPatternRewriter &rewriter) const {
  triton::ReduceOp op = helper.getOperation();
  Location loc = op.getLoc();
  auto b = TritonLLVMOpBuilder(loc, rewriter);
  auto axis = op.getAxis();
  auto smemOrder = helper.getOrderWithAxisAtBeginning();
  SmallVector<Value> results(op.getNumOperands());
  for (unsigned i = 0; i < op.getNumOperands(); ++i) {
    auto elemTy = getElementType(op, i);
    if (auto resultTy =
            dyn_cast<RankedTensorType>(op.getResult()[i].getType())) {
      auto resultLayout = cast<SliceEncodingAttr>(resultTy.getEncoding());
      unsigned resultElems = getTotalElemsPerThread(resultTy);
      auto resultIndices = emitIndices(loc, rewriter, targetInfo,
                                       resultLayout, resultTy, true);
      auto resultShape = resultTy.getShape();
      assert(resultIndices.size() == resultElems);

      SmallVector<Value> resultVals(resultElems);
      for (size_t j = 0; j < resultElems; ++j) {
        // The result has one dimension fewer than the scratch buffer;
        // re-insert the reduced axis with coordinate 0.
        SmallVector<Value> readIdx = resultIndices[j];
        readIdx.insert(readIdx.begin() + axis, b.i32_val(0));
        for (size_t resultIdx = 0, resultDim = resultShape.size();
             resultIdx < resultDim; ++resultIdx) {
          // Map a result dimension to the matching scratch dimension by
          // skipping over the reduction axis.
          auto smemIdx = resultIdx < axis ? resultIdx : resultIdx + 1;
          if (resultShape[resultIdx] > smemShape[smemIdx]) {
            // The result extent exceeds the scratch extent in this
            // dimension: wrap the index so the read stays inside the
            // scratch buffer.
            readIdx[smemIdx] =
                b.urem(readIdx[smemIdx], b.i32_val(smemShape[smemIdx]));
          }
        }
        Value readOffset =
            linearize(rewriter, loc, readIdx, smemShape, smemOrder);
        Value readPtr =
            b.gep(smemBases[i].getType(), elemTy, smemBases[i], readOffset);
        resultVals[j] = b.load(elemTy, readPtr);
      }
      results[i] = packLLElements(loc, getTypeConverter(), resultVals,
                                  rewriter, resultTy);
    } else {
      // Non-tensor result: the value sits at the start of the buffer.
      results[i] = b.load(elemTy, smemBases[i]);
    }
  }
  rewriter.replaceOp(op, results);
}
};
}
// Registers the `triton::ReduceOp` lowering pattern defined in this file.
//
// Adds a `ReduceOpConversion` to `patterns`, constructed from
// `typeConverter`, `targetInfo` and `benefit`. The pattern keeps a reference
// to `targetInfo`, so that object must outlive the pattern set.
void mlir::triton::populateReduceOpToLLVMPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
const TargetInfoBase &targetInfo, PatternBenefit benefit) {
patterns.add<ReduceOpConversion>(typeConverter, targetInfo, benefit);
}