#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_REDUCESCANCOMMON_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_REDUCESCANCOMMON_H
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/IR/TypeUtilities.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include <algorithm>
#include <iterator>
#include <map>
#include <numeric>
#include <type_traits>
#include <vector>
#define DEBUG_TYPE "ttgpu_to_llvm"
using namespace mlir;
using namespace mlir::triton;
namespace mlir::triton {
class ReduceOp;
class ScanOp;
/// Inlines the operations of `combineBlock` into `insertionBlock` in front of
/// `insertionPoint`, passing `combineArgs` as the values for the block's
/// arguments.
///
/// The terminator of `combineBlock` is erased once the block has been
/// inlined; the values it carried as operands are returned to the caller.
inline SmallVector<Value>
inlineCombineBlock(ConversionPatternRewriter &rewriter, Block &combineBlock,
                   Block *insertionBlock, Block::iterator insertionPoint,
                   ValueRange combineArgs) {
  // Grab the terminator through the block before the block is handed over to
  // the rewriter for inlining.
  Operation *terminator = combineBlock.getTerminator();
  rewriter.inlineBlockBefore(&combineBlock, insertionBlock, insertionPoint,
                             combineArgs);

  // Copy the operands out first: they are the result of the combine, and the
  // terminator itself is removed right after.
  SmallVector<Value> yielded(terminator->getOperands());
  rewriter.eraseOp(terminator);
  return yielded;
}
/// Emits the combine region `combineOp` applied to (`acc`, `cur`) at the
/// rewriter's current insertion point and returns the combined values.
///
/// \param loc       Location used for the ops and block arguments created
///                  here.
/// \param rewriter  Rewriter whose current block / insertion point is used.
/// \param combineOp Region holding the combine computation; it is cloned, the
///                  original region is left in place.
/// \param acc       Accumulated values. May be empty, in which case `cur` is
///                  returned unchanged and no IR is created.
/// \param cur       Values to fold into `acc`; must have as many entries as
///                  `acc` when `acc` is non-empty.
/// \param pred      Optional predicate. When it is null, or when every
///                  operation of the cloned combine block is reported as
///                  speculatable by `isSpeculatable`, the block is inlined
///                  directly into the current block.
///
/// Otherwise the current block is split and the combine block is only
/// branched to when `pred` holds. The continuation block then receives the
/// combine results on that path and `llvm.undef` placeholders on the other,
/// and the rewriter is left inserting at the start of the continuation block.
inline SmallVector<Value> applyCombineOp(Location loc,
                                         ConversionPatternRewriter &rewriter,
                                         Region &combineOp, ValueRange acc,
                                         ValueRange cur, Value pred = {}) {
  const size_t numValues = acc.size();
  if (numValues == 0)
    return cur;
  assert(cur.size() == acc.size());

  // Clone the combine region into the enclosing region, directly behind the
  // block we are currently inserting into, and pick up the block that now
  // follows it.
  Block *originBlock = rewriter.getBlock();
  rewriter.cloneRegionBefore(combineOp, *originBlock->getParent(),
                             std::next(originBlock->getIterator()));
  Block &clonedCombine = *originBlock->getNextNode();

  // Operands for the combine block: every `acc` value first, followed by
  // every `cur` value.
  llvm::SmallVector<Value> operands(2 * numValues);
  for (size_t idx = 0; idx != numValues; ++idx) {
    operands[idx] = acc[idx];
    operands[numValues + idx] = cur[idx];
  }

  bool allSpeculatable = true;
  for (Operation &nested : clonedCombine) {
    if (!isSpeculatable(&nested)) {
      allSpeculatable = false;
      break;
    }
  }

  // Guarding is only needed when a predicate was supplied and at least one
  // operation of the combine block is not speculatable.
  const bool needsGuard = pred && !allSpeculatable;
  if (!needsGuard) {
    return inlineCombineBlock(rewriter, clonedCombine, originBlock,
                              rewriter.getInsertionPoint(), operands);
  }

  // Guarded form:
  //   originBlock:   cond_br pred, clonedCombine(operands),
  //                                continuation(undef...)
  //   clonedCombine: <combine ops>; br continuation(results)
  //   continuation:  everything that followed the insertion point
  Block *continuation =
      rewriter.splitBlock(originBlock, rewriter.getInsertionPoint());

  Operation *combineTerminator = clonedCombine.getTerminator();
  SmallVector<Value> combined(combineTerminator->getOperands());

  // One undef placeholder and one continuation block argument per result.
  rewriter.setInsertionPointToEnd(originBlock);
  SmallVector<Value> placeholders;
  placeholders.reserve(combined.size());
  for (Value value : combined) {
    Type valueTy = value.getType();
    auto placeholder = rewriter.create<LLVM::UndefOp>(loc, valueTy);
    placeholders.push_back(placeholder);
    continuation->addArgument(valueTy, loc);
  }
  rewriter.create<LLVM::CondBrOp>(loc, pred, &clonedCombine, operands,
                                  continuation, placeholders);

  // Turn the combine block's terminator into a branch that forwards the
  // results to the continuation block.
  rewriter.setInsertionPointToEnd(&clonedCombine);
  rewriter.replaceOpWithNewOp<LLVM::BrOp>(combineTerminator, combined,
                                          continuation);

  rewriter.setInsertionPointToStart(continuation);
  return SmallVector<Value>(continuation->getArguments());
}
}
/// Common base class for the patterns that lower reduce and scan ops to LLVM.
/// `SourceOp` must be `ReduceOp` or `ScanOp`; this is enforced at compile time
/// by the `static_assert` below.
template <typename SourceOp>
class ConvertTritonGPUReduceScanToLLVMPattern
    : public ConvertOpToLLVMPattern<SourceOp> {
public:
  static_assert(std::is_same_v<SourceOp, ReduceOp> ||
                std::is_same_v<SourceOp, ScanOp>);
  using ConvertOpToLLVMPattern<SourceOp>::getTypeConverter;
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;

  /// Returns the element type of the `i`-th input of `op` after running it
  /// through this pattern's type converter.
  Type getElementType(SourceOp op, int i) const {
    auto ty = op.getInputTypes()[i].getElementType();
    return getTypeConverter()->convertType(ty);
  }

  /// Returns one base pointer per operand of `op`; entry `k` belongs to
  /// operand `k`.
  ///
  /// The operand with the widest element type is assigned the pointer
  /// returned by `LLVM::getSharedMemoryBase` for `op`. Every following
  /// operand, taken in descending element bit width, gets the previous
  /// operand's base advanced by `elems` elements of the previous operand's
  /// converted element type. Operands of equal width keep their operand
  /// order, so the layout does not depend on the standard library's sort
  /// implementation.
  ///
  /// NOTE(review): the widest-first order presumably keeps each sub-buffer
  /// naturally aligned -- confirm against the scratch allocation code.
  ///
  /// \param op         The reduce/scan op being lowered.
  /// \param elems      Number of elements each operand's buffer is advanced
  ///                   by to reach the next operand's buffer.
  /// \param rewriter   Rewriter used to create the address computations.
  /// \param targetInfo Target hooks forwarded to `LLVM::getSharedMemoryBase`.
  /// \returns A vector with `op.getNumOperands()` entries; empty (and no IR
  ///          created) when the op has no operands.
  SmallVector<Value> getSmemBases(SourceOp op, unsigned elems,
                                  ConversionPatternRewriter &rewriter,
                                  const TargetInfoBase &targetInfo) const {
    const unsigned numOperands = op.getNumOperands();
    SmallVector<Value> smemBases(numOperands);
    // Nothing to lay out; also keeps the `indices[0]` access below in bounds.
    if (numOperands == 0)
      return smemBases;

    auto loc = op.getLoc();
    auto b = TritonLLVMOpBuilder(loc, rewriter);

    // Query the element types once and cache their widths so the comparator
    // does not go back to the op for every comparison.
    auto elemTypes = op.getElementTypes();
    std::vector<unsigned> bitWidths(numOperands);
    for (unsigned i = 0; i < numOperands; ++i)
      bitWidths[i] = elemTypes[i].getIntOrFloatBitWidth();

    // `indices` lists the operand numbers in descending order of element bit
    // width; stable so that ties are broken by operand order.
    std::vector<unsigned> indices(numOperands);
    std::iota(indices.begin(), indices.end(), 0u);
    std::stable_sort(indices.begin(), indices.end(),
                     [&bitWidths](unsigned lhs, unsigned rhs) {
                       return bitWidths[lhs] > bitWidths[rhs];
                     });

    // `indices` is a permutation of [0, numOperands), so each slot of
    // `smemBases` is written exactly once and always before it is read as the
    // predecessor of the next operand.
    auto basePtr =
        LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
    smemBases[indices[0]] = basePtr;
    for (unsigned i = 1; i < numOperands; ++i) {
      const unsigned prev = indices[i - 1];
      smemBases[indices[i]] =
          b.gep(basePtr.getType(), getElementType(op, prev), smemBases[prev],
                b.i32_val(elems));
    }
    return smemBases;
  }
};
#endif