#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
namespace {
using namespace mlir;
using namespace mlir::triton;
/// Lowers `triton::ReturnOp` to `LLVM::ReturnOp`.
///
/// A return nested in a function carrying the "nvvm.kernel" attribute must
/// have no operands. In any other function, fewer than two operands are
/// returned as-is, while two or more are first inserted one by one into a
/// single value of the type produced by `packFunctionResults`.
///
/// The attributes of the original op are carried over to the new op.
struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
  using ConvertOpToLLVMPattern<triton::ReturnOp>::ConvertOpToLLVMPattern;

  /// Rewrites `op` into an `LLVM::ReturnOp`.
  ///
  /// Reports a match failure (without creating any IR) when the op is not
  /// nested in an `LLVM::LLVMFuncOp`, when a kernel return carries operands,
  /// or when the function result types cannot be packed.
  LogicalResult
  matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
    // getParentOfType yields a null op when no such ancestor exists; bail out
    // instead of dereferencing it below.
    if (!funcOp) {
      return rewriter.notifyMatchFailure(
          op, "expected return to be nested in an llvm.func");
    }

    auto loc = op.getLoc();

    if (funcOp->hasAttr("nvvm.kernel")) {
      if (op.getNumOperands() > 0) {
        return rewriter.notifyMatchFailure(
            op, "Kernel functions do not support return with operands");
      }
      rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
                                                  op->getAttrs());
      return success();
    }

    auto operands = adaptor.getOperands();
    LLVM::ReturnOp newOp;
    if (operands.size() < 2) {
      // Zero or one value: forward the converted operands unchanged.
      newOp = rewriter.create<LLVM::ReturnOp>(loc, operands);
    } else {
      auto packedResultsTy = this->getTypeConverter()->packFunctionResults(
          funcOp.getResultTypes());
      // Checked before any op is created, mirroring the null check that
      // CallOpConversion performs on the same helper.
      if (!packedResultsTy) {
        return rewriter.notifyMatchFailure(
            op, "could not pack the function result types");
      }

      auto b = TritonLLVMOpBuilder(loc, rewriter);
      // Start from an undef value and insert each operand at its own index.
      Value packedResults =
          rewriter.create<LLVM::UndefOp>(loc, packedResultsTy);
      for (auto it : llvm::enumerate(operands)) {
        packedResults = b.insert_val(packedResultsTy, packedResults,
                                     it.value(), it.index());
      }
      newOp = rewriter.create<LLVM::ReturnOp>(loc, packedResults);
    }

    newOp->setAttrs(op->getAttrs());
    // Replace the original op with the (result-less) new return.
    rewriter.replaceOp(op, newOp->getResults());
    return success();
  }
};
/// Lowers `triton::CallOp` to `LLVM::CallOp`.
///
/// Besides the promoted original operands, the lowered call receives three
/// trailing operands, in this order: a base value obtained from either
/// `LLVM::getSharedMemoryBase` or `LLVM::getStackPointer`, the value returned
/// by `LLVM::getGlobalScratchPtr`, and the value returned by
/// `LLVM::getProfileScratchPtr`. When the original call has two or more
/// results, they are read back out of the single result of the new call with
/// `LLVM::ExtractValueOp`.
struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
  CallOpConversion(LLVMTypeConverter &converter,
                   const TargetInfoBase &targetInfo, PatternBenefit benefit)
      : ConvertOpToLLVMPattern<triton::CallOp>(converter, benefit),
        targetInfo(targetInfo) {}

  /// Builds the operand list, emits the `LLVM::CallOp` and replaces `callOp`
  /// with its results. Fails when the call could not be created.
  LogicalResult
  matchAndRewrite(triton::CallOp callOp,
                  typename triton::CallOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto callArgs = promoteOperands(callOp, adaptor, rewriter);
    auto llvmCall = convertCallOpToLLVMCallOp(callOp, callArgs, rewriter);
    if (!llvmCall)
      return failure();

    auto replacements = getCallOpResults(callOp, llvmCall, rewriter);
    rewriter.replaceOp(callOp, replacements);
    return success();
  }

private:
  /// Returns the operands for the lowered call: the type converter's promoted
  /// operands followed by the base value, the global scratch pointer and the
  /// profile scratch pointer.
  SmallVector<Value, 4>
  promoteOperands(triton::CallOp callOp,
                  typename triton::CallOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const {
    auto loc = callOp.getLoc();
    auto b = TritonLLVMOpBuilder(loc, rewriter);
    auto caller = callOp->getParentOfType<FunctionOpInterface>();

    auto callArgs = this->getTypeConverter()->promoteOperands(
        loc, callOp->getOperands(), adaptor.getOperands(), rewriter);

    // The shared-memory base is used only when both the enclosing function
    // and the call carry "allocation.offset"; otherwise the caller's stack
    // pointer is passed along.
    const bool bothHaveAllocationOffset =
        caller->hasAttr("allocation.offset") &&
        callOp->hasAttr("allocation.offset");
    Value base;
    if (bothHaveAllocationOffset) {
      base = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, callOp);
    } else {
      base = LLVM::getStackPointer(rewriter, caller);
    }
    callArgs.push_back(base);

    // Stays default-constructed when the call has no integer
    // "ttg.global_scratch_memory_offset" attribute.
    Value scratchOffset;
    if (auto offsetAttr = callOp->getAttrOfType<mlir::IntegerAttr>(
            "ttg.global_scratch_memory_offset")) {
      scratchOffset = b.i32_val(offsetAttr.getValue().getZExtValue());
    }

    callArgs.push_back(LLVM::getGlobalScratchPtr(loc, rewriter, targetInfo,
                                                 caller, scratchOffset));
    callArgs.push_back(LLVM::getProfileScratchPtr(loc, rewriter, caller));
    return callArgs;
  }

  /// Creates the `LLVM::CallOp` for `callOp` using `promotedOperands`.
  ///
  /// The new op has no result types when `callOp` has no results; otherwise
  /// it has the single type produced by `packFunctionResults`. Returns a null
  /// op when that packing fails.
  LLVM::CallOp
  convertCallOpToLLVMCallOp(triton::CallOp callOp,
                            ArrayRef<Value> promotedOperands,
                            ConversionPatternRewriter &rewriter) const {
    Type packedResult = nullptr;
    if (callOp.getNumResults() != 0) {
      auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
      packedResult =
          this->getTypeConverter()->packFunctionResults(resultTypes);
      if (!packedResult)
        return nullptr;
    }

    auto llvmCall = rewriter.create<LLVM::CallOp>(
        callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
        promotedOperands, callOp->getAttrs());

    // Set the segment properties explicitly: an empty op-bundle size list,
    // and operand segment sizes of {number of promoted operands, 0}.
    auto &props = llvmCall.getProperties();
    props.setOpBundleSizes(rewriter.getDenseI32ArrayAttr({}));
    props.setOperandSegmentSizes(
        {static_cast<int>(promotedOperands.size()), 0});
    return llvmCall;
  }

  /// Returns the values that replace the results of `callOp`.
  ///
  /// With fewer than two original results, the results of `newCallOp` are
  /// returned directly. Otherwise one `LLVM::ExtractValueOp` per original
  /// result is created on result 0 of `newCallOp`.
  SmallVector<Value>
  getCallOpResults(triton::CallOp callOp, LLVM::CallOp newCallOp,
                   ConversionPatternRewriter &rewriter) const {
    SmallVector<Value> unpacked;
    auto resultCount = callOp.getNumResults();

    if (resultCount < 2) {
      unpacked.append(newCallOp.result_begin(), newCallOp.result_end());
      return unpacked;
    }

    unpacked.reserve(resultCount);
    for (unsigned idx = 0; idx < resultCount; ++idx) {
      Value element = rewriter.create<LLVM::ExtractValueOp>(
          callOp.getLoc(), newCallOp->getResult(0), idx);
      unpacked.push_back(element);
    }
    return unpacked;
  }

  const TargetInfoBase &targetInfo;
};
}
/// Adds the two conversion patterns defined in this file,
/// `ReturnOpConversion` and `CallOpConversion`, to `patterns`.
///
/// Both patterns are constructed with `typeConverter` and `benefit`;
/// `targetInfo` is forwarded to `CallOpConversion` only.
void mlir::triton::populateControlFlowOpToLLVMPattern(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
const TargetInfoBase &targetInfo, PatternBenefit benefit) {
patterns.add<ReturnOpConversion>(typeConverter, benefit);
patterns.add<CallOpConversion>(typeConverter, targetInfo, benefit);
}