#include "triton/Analysis/Utility.h"
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::gpu;
/// Emits IR that computes a warp-level histogram of the values in
/// `srcValues`.
///
/// The `numBins` bins are split evenly across the lanes of a warp: every
/// thread accumulates `numBins / numThreadPerWarp` counters. For each source
/// element, one ballot per bin-index bit is issued so that every lane learns,
/// for all lanes at once, which bits of the value are set. The upper
/// log2(numThreadPerWarp) of those bits are compared with the low bits of
/// `threadId` to pick the lanes whose value belongs to this thread; the
/// remaining lower bits select which of the thread's counters is incremented.
///
/// \param loc              Location attached to every created operation.
/// \param srcType          Tensor type of the histogram input; only used to
///                         query the number of elements held per thread.
/// \param srcValues        Per-thread input values, one per element.
/// \param maskValues       Optional per-element predicates. When non-empty, a
///                         ballot of `maskValues[i]` is ANDed into the lane
///                         mask so that masked-off elements are not counted.
/// \param numBins          Number of bins; must be divisible by
///                         `numThreadPerWarp`.
/// \param numThreadPerWarp Warp width; also the bit width of ballot results.
/// \param threadId         Thread id; its low bits are matched against the
///                         upper bin-index bits.
/// \param rewriter         Rewriter used to create operations.
/// \param targetInfo       Target hook providing the `ballot` lowering.
/// \returns One i32 counter per bin owned by the calling thread.
///
/// NOTE(review): bit counts are derived with `llvm::Log2_64` (floor), so this
/// presumably expects `numBins` and `numThreadPerWarp` to be powers of two;
/// only divisibility is asserted here -- confirm against callers.
static SmallVector<Value> computeWarpLevelHistogram(
    Location loc, RankedTensorType srcType, SmallVector<Value> &srcValues,
    SmallVector<Value> &maskValues, int numBins, int numThreadPerWarp,
    Value threadId, ConversionPatternRewriter &rewriter,
    const TargetInfoBase &targetInfo) {
  auto b = TritonLLVMOpBuilder(loc, rewriter);
  assert(numBins % numThreadPerWarp == 0 &&
         "numBins must be divisible by numThreadPerWarp");
  Value zero = b.i32_val(0);
  const int numBits = llvm::Log2_64(numBins);
  const int numBitsLaneId = llvm::Log2_64(numThreadPerWarp);
  const unsigned numElementsPerThreads = getTotalElemsPerThread(srcType);
  SmallVector<Value> warpLevelHistogram(numBins / numThreadPerWarp, zero);
  const unsigned numBinsPerThread = warpLevelHistogram.size();
  // Plain host-side constant (no IR is created for it), so it can be computed
  // once instead of on every iteration.
  const uint64_t fullMaskValue =
      numThreadPerWarp == 32 ? 0xFFFFFFFF : 0xFFFFFFFFFFFFFFFF;

  for (unsigned elem = 0; elem < numElementsPerThreads; ++elem) {
    Value value = srcValues[elem];

    // ballotBits[bit] has lane L set iff bit `bit` of lane L's value is set.
    SmallVector<Value> ballotBits;
    ballotBits.reserve(numBits);
    for (int bit = 0; bit < numBits; ++bit) {
      Value bitSet = b.and_(value, b.i32_val(1 << bit));
      Value cmp = b.icmp_ne(bitSet, zero);
      Value ballot =
          targetInfo.ballot(rewriter, loc, int_ty(numThreadPerWarp), cmp);
      ballotBits.push_back(ballot);
    }

    Value fullMask = b.int_val(numThreadPerWarp, fullMaskValue);
    Value mask = fullMask;
    // Keep only the lanes whose upper bin-index bits equal the low bits of
    // this thread's id. XOR with 0 keeps the ballot as-is (bit must be set);
    // XOR with the full mask inverts it (bit must be clear).
    for (int laneBit = 0; laneBit < numBitsLaneId; ++laneBit) {
      Value updateMask =
          b.select(b.icmp_ne(b.and_(threadId, b.i32_val(1 << laneBit)), zero),
                   b.int_val(numThreadPerWarp, 0), fullMask);
      mask = b.and_(
          mask,
          b.xor_(ballotBits[laneBit + numBits - numBitsLaneId], updateMask));
    }

    // Drop lanes whose element is disabled by the optional input mask.
    Value inputMaskBit = fullMask;
    if (!maskValues.empty()) {
      inputMaskBit = targetInfo.ballot(rewriter, loc, int_ty(numThreadPerWarp),
                                       maskValues[elem]);
    }
    mask = b.and_(mask, inputMaskBit);

    // `mask` now selects the lanes whose value falls in a bin owned by this
    // thread; the lower bin-index bits decide which of the owned counters.
    for (unsigned k = 0; k < numBinsPerThread; ++k) {
      Value binMask = mask;
      for (int lowBit = 0; lowBit < numBits - numBitsLaneId; ++lowBit) {
        const uint64_t flip = (k & (1u << lowBit)) ? 0 : fullMaskValue;
        Value updateMask = b.int_val(numThreadPerWarp, flip);
        binMask = b.and_(binMask, b.xor_(ballotBits[lowBit], updateMask));
      }
      // The number of surviving lanes is this element's contribution to bin k.
      Value bitCount = rewriter.create<LLVM::CtPopOp>(
          loc, int_ty(numThreadPerWarp), binMask);
      if (numThreadPerWarp > 32)
        bitCount = b.trunc(i32_ty, bitCount);
      warpLevelHistogram[k] = b.add(warpLevelHistogram[k], bitCount);
    }
  }
  return warpLevelHistogram;
}
/// Creates an LLVM-dialect atomic read-modify-write `add` of `val` into the
/// location addressed by `ptr`, using monotonic ordering. The result of the
/// operation is discarded.
static void atomicAdd(Value ptr, Value val, Location loc,
                      ConversionPatternRewriter &rewriter) {
  const auto binOp = LLVM::AtomicBinOp::add;
  const auto ordering = LLVM::AtomicOrdering::monotonic;
  rewriter.create<LLVM::AtomicRMWOp>(loc, binOp, ptr, val, ordering);
}
/// Emits IR that merges the per-warp histograms into one histogram stored at
/// `baseSharedMemPtr` and reads back the bins requested through `indices`.
///
/// The emitted code has three phases separated by barriers:
///   1. the `numBins` i32 slots starting at `baseSharedMemPtr` are zeroed,
///      each thread storing to `(threadId + i * numWarps * numThreadPerWarp)
///      % numBins`;
///   2. every thread atomically adds its counters; counter `i` of a thread
///      goes to slot `laneId * warpLevelHistogram.size() + i`;
///   3. every thread loads one i32 per entry of `indices`.
///
/// \param loc                Location attached to every created operation.
/// \param rewriter           Rewriter used to create operations.
/// \param srcType            Unused; kept so the signature stays unchanged.
/// \param baseSharedMemPtr   Base pointer of the i32 scratch buffer.
/// \param warpLevelHistogram Counters owned by the calling thread.
/// \param numBins            Number of slots in the scratch buffer.
/// \param numThreadPerWarp   Warp width; `numThreadPerWarp - 1` is used as the
///                           lane-id bit mask.
/// \param indices            Bin indices this thread must return.
/// \param threadId           Thread id.
/// \param numWarps           Number of warps taking part.
/// \returns The loaded bin values, in the order of `indices`.
static SmallVector<Value> computeCrossWarpHistogram(
    Location loc, ConversionPatternRewriter &rewriter, RankedTensorType srcType,
    Value baseSharedMemPtr, const SmallVector<Value> &warpLevelHistogram,
    int numBins, int numThreadPerWarp, const SmallVector<Value> &indices,
    Value threadId, int numWarps) {
  auto b = TritonLLVMOpBuilder(loc, rewriter);
  Value laneId = b.and_(threadId, b.i32_val(numThreadPerWarp - 1));

  // Phase 1: zero the scratch buffer. The modulo keeps the offset in range
  // when there are more threads than bins.
  const int64_t numElementPerThread =
      ceil<int64_t>(numBins, numThreadPerWarp * numWarps);
  for (int i = 0; i < numElementPerThread; ++i) {
    Value offset =
        b.add(threadId, b.i32_val((i * numWarps * numThreadPerWarp)));
    offset = b.urem(offset, b.i32_val(numBins));
    Value sharedMemPtr =
        b.gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, offset);
    b.store(b.i32_val(0), sharedMemPtr);
  }
  b.barrier();

  // Phase 2: accumulate this thread's counters into the scratch buffer.
  const int numOwnedBins = static_cast<int>(warpLevelHistogram.size());
  for (int i = 0; i < numOwnedBins; ++i) {
    Value warpLevelHistogramValue = warpLevelHistogram[i];
    Value offset = b.add(b.mul(laneId, b.i32_val(warpLevelHistogram.size())),
                         b.i32_val(i));
    Value sharedMemPtr =
        b.gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, offset);
    atomicAdd(sharedMemPtr, warpLevelHistogramValue, loc, rewriter);
  }
  b.barrier();

  // Phase 3: load the requested bins.
  SmallVector<Value> histogramValues;
  histogramValues.reserve(indices.size());
  for (Value index : indices) {
    Value sharedMemPtr =
        b.gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, index);
    histogramValues.push_back(b.load(i32_ty, sharedMemPtr));
  }
  return histogramValues;
}
namespace {
/// Conversion pattern that lowers `triton::HistogramOp` to the LLVM dialect.
///
/// The lowering first computes a histogram per warp with ballot operations,
/// then merges the warps through a scratch buffer obtained from
/// `LLVM::getSharedMemoryBase`, and finally packs the bins each thread owns
/// according to the result type's encoding.
struct HistogramOpConversion
    : public ConvertOpToLLVMPattern<triton::HistogramOp> {
public:
  using ConvertOpToLLVMPattern<triton::HistogramOp>::ConvertOpToLLVMPattern;

  /// \param typeConverter Type converter forwarded to the base pattern.
  /// \param targetInfo    Target hooks; stored by reference, so it must
  ///                      outlive the pattern.
  /// \param benefit       Pattern benefit forwarded to the base pattern.
  explicit HistogramOpConversion(LLVMTypeConverter &typeConverter,
                                 const TargetInfoBase &targetInfo,
                                 PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
  }

  /// Replaces `op` with a packed value holding this thread's histogram bins.
  /// Always returns `success()`.
  LogicalResult
  matchAndRewrite(triton::HistogramOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = adaptor.getSrc();
    auto typeConverter = getTypeConverter();
    SmallVector<Value> srcValues = unpackLLElements(loc, input, rewriter);

    // The mask operand is optional; an empty vector means "no mask".
    Value llMask = adaptor.getMask();
    SmallVector<Value> maskValues;
    if (llMask)
      maskValues = unpackLLElements(loc, llMask, rewriter);

    int numBins = op.getType().getDimSize(0);
    auto mod = op->getParentOfType<ModuleOp>();
    int numThreadsPerWarp =
        triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
    // Parenthesized so the message applies to the whole condition instead of
    // relying on `&&` binding tighter than `||`.
    assert((numThreadsPerWarp == 32 || numThreadsPerWarp == 64) &&
           "Only supports 32 or 64 threads per warp");
    int numWarps = triton::gpu::lookupNumWarps(op);
    // The warp-level pass needs at least one bin per lane.
    numBins = std::max(numBins, numThreadsPerWarp);
    Value threadId = getThreadId(rewriter, loc);
    auto srcType = op.getSrc().getType();

    // Step 1: histogram within each warp.
    SmallVector<Value> warpLevelHistogram = computeWarpLevelHistogram(
        loc, srcType, srcValues, maskValues, numBins, numThreadsPerWarp,
        threadId, rewriter, targetInfo);

    // Step 2: merge the warps and read back the bins owned by this thread.
    Value baseSharedMemPtr =
        LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
    auto dstType = op.getType();
    Attribute dstEncoding = dstType.getEncoding();
    auto indices = emitIndices(op.getLoc(), rewriter, targetInfo, dstEncoding,
                               dstType, true);
    SmallVector<Value> innerDimIndices;
    innerDimIndices.reserve(indices.size());
    for (const auto &index : indices)
      innerDimIndices.push_back(index[0]);
    SmallVector<Value> histogramValue = computeCrossWarpHistogram(
        loc, rewriter, srcType, baseSharedMemPtr, warpLevelHistogram, numBins,
        numThreadsPerWarp, innerDimIndices, threadId, numWarps);

    // Step 3: scale the counts. The divisor is the total thread count divided
    // by the product of the source layout's threads-per-warp and
    // warps-per-CTA; presumably this compensates for threads that hold
    // duplicated source elements and were therefore counted more than once.
    auto replicationFactor = numWarps * numThreadsPerWarp;
    auto threadsPerWarp = getThreadsPerWarp(srcType);
    auto warpsPerCTA =
        getWarpsPerCTA(srcType.getEncoding(), srcType.getShape());
    replicationFactor /= std::accumulate(
        threadsPerWarp.begin(), threadsPerWarp.end(), 1, std::multiplies<>());
    replicationFactor /= std::accumulate(warpsPerCTA.begin(), warpsPerCTA.end(),
                                         1, std::multiplies<>());
    auto b = TritonLLVMOpBuilder(loc, rewriter);
    for (Value &count : histogramValue)
      count = b.sdiv(count, b.i32_val(replicationFactor));

    Value results = packLLElements(loc, typeConverter, histogramValue, rewriter,
                                   op.getType());
    rewriter.replaceOp(op, results);
    return success();
  }

private:
  const TargetInfoBase &targetInfo;
};
} // namespace
// Registers the HistogramOp-to-LLVM conversion pattern in `patterns`.
// `typeConverter`, `targetInfo` and `benefit` are forwarded unchanged to the
// HistogramOpConversion constructor.
void mlir::triton::populateHistogramOpToLLVMPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
const TargetInfoBase &targetInfo, PatternBenefit benefit) {
patterns.add<HistogramOpConversion>(typeConverter, targetInfo, benefit);
}