#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
namespace {
using namespace mlir;
using namespace mlir::triton;
/// Lowers a `triton::FuncOp` to an `LLVM::LLVMFuncOp` for the NVVM target.
///
/// The lowering proceeds in three steps:
///   1. `amendFuncOp` builds an intermediate `triton::FuncOp` whose signature
///      carries extra trailing pointer arguments, and moves the original body
///      into it.
///   2. `mlir::convertFuncOpToLLVMFuncOp` turns that intermediate op into an
///      `LLVM::LLVMFuncOp`.
///   3. NVVM-specific attributes are attached:
///      - kernel/noinline marking and linkage;
///      - `reqntid` and, when present on the parent op, `maxnreg`;
///      - by-value handling of TMA-descriptor arguments.
struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
  FuncOpConversion(LLVMTypeConverter &converter,
                   const TargetInfoBase &targetInfo, PatternBenefit benefit)
      : ConvertOpToLLVMPattern(converter, benefit), targetInfo(targetInfo) {}

  /// Argument attribute placed on kernel arguments of `TensorDescType` by
  /// `amendFuncOp` and consumed by `handleByvalTmaDescArgs`. It is kept in one
  /// place so the producer and consumer cannot drift apart.
  static constexpr llvm::StringLiteral kTmaDescAttrName = "tt.nv_tma_desc";

  /// Copies every attribute of `op` into `result`, except the symbol name,
  /// the function type, `std.varargs` and, when `filterArgAttrs` is true,
  /// the per-argument attribute array.
  ///
  /// These are dropped because the caller either re-creates them from
  /// explicit builder arguments or rebuilds them for an amended signature.
  static void filterFuncAttributes(triton::FuncOp op, bool filterArgAttrs,
                                   SmallVectorImpl<NamedAttribute> &result) {
    for (const auto &attr : op->getAttrs()) {
      if (attr.getName() == SymbolTable::getSymbolAttrName() ||
          attr.getName() == op.getFunctionTypeAttrName() ||
          attr.getName() == "std.varargs" ||
          (filterArgAttrs && attr.getName() == op.getArgAttrsAttrName()))
        continue;
      result.push_back(attr);
    }
  }

  /// Creates a new `triton::FuncOp` whose signature is `funcOp`'s signature
  /// plus trailing pointer arguments, and moves `funcOp`'s body into it.
  ///
  /// Appended arguments, in order:
  ///   - non-kernels only: a pointer in the target's shared address space
  ///     (`targetInfo.getSharedAddressSpace()`);
  ///   - two pointers in address space 1. Judging by their names
  ///     (`globalPtrTy` / `profilePtrTy`) these are presumably a global
  ///     scratch buffer and a profiling buffer; confirm with callers.
  ///
  /// For kernels, every `TensorDescType` argument is tagged with
  /// `kTmaDescAttrName` on the ORIGINAL op first. This must happen before the
  /// arg attrs are read below, so the tag is carried into the amended op.
  ///
  /// NOTE(review): `funcOp.setArgAttr` and `region.addArgument` mutate IR
  /// directly rather than through `rewriter`, so they are not undone if the
  /// pattern later fails. This matches the existing behavior; revisit if
  /// rollback of this pattern matters.
  ///
  /// @param funcOp     The op being converted; its body is moved out.
  /// @param rewriter   Used to create the amended op and move the region.
  /// @param targetInfo Supplies the shared-memory address space.
  /// @return The newly created amended `triton::FuncOp`.
  triton::FuncOp amendFuncOp(triton::FuncOp funcOp,
                             ConversionPatternRewriter &rewriter,
                             const TargetInfoBase &targetInfo) const {
    auto loc = funcOp.getLoc();
    auto ctx = funcOp->getContext();
    auto sharedPtrTy =
        LLVM::LLVMPointerType::get(ctx, targetInfo.getSharedAddressSpace());
    auto globalPtrTy = LLVM::LLVMPointerType::get(ctx, 1);
    auto profilePtrTy = LLVM::LLVMPointerType::get(ctx, 1);

    // 1. Build the amended function type.
    auto funcTy = funcOp.getFunctionType();
    auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs());
    bool isKernel = triton::isKernel(funcOp);
    if (isKernel) {
      // Tag tensor-descriptor arguments so they can later be lowered as
      // by-value TMA descriptors (see handleByvalTmaDescArgs).
      for (auto i : llvm::seq(amendedInputTy.size())) {
        if (isa<TensorDescType>(amendedInputTy[i])) {
          funcOp.setArgAttr(i, kTmaDescAttrName,
                            mlir::IntegerAttr::get(i32_ty, 1));
        }
      }
    }
    if (!isKernel) {
      amendedInputTy.push_back(sharedPtrTy);
    }
    amendedInputTy.push_back(globalPtrTy);
    amendedInputTy.push_back(profilePtrTy);
    auto amendedFuncTy =
        FunctionType::get(ctx, amendedInputTy, funcTy.getResults());

    // 2. Carry over the attributes. The arg-attr array is padded with empty
    //    dictionaries so it has one entry per (amended) argument.
    SmallVector<NamedAttribute> amendedAttrs;
    filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, amendedAttrs);
    if (auto argAttrs = funcOp.getAllArgAttrs()) {
      llvm::SmallVector<mlir::Attribute> amendedArgAttrs(argAttrs.begin(),
                                                         argAttrs.end());
      while (amendedArgAttrs.size() < amendedInputTy.size()) {
        amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx));
      }
      amendedAttrs.push_back(
          rewriter.getNamedAttr(funcOp.getArgAttrsAttrName(),
                                rewriter.getArrayAttr(amendedArgAttrs)));
    }

    // 3. Create the amended op. Append matching block arguments to the
    //    original body, then move that body into the new op.
    auto amendedFuncOp = rewriter.create<triton::FuncOp>(
        funcOp.getLoc(), funcOp.getName(), amendedFuncTy, amendedAttrs);
    auto &region = funcOp.getBody();
    if (!isKernel) {
      region.addArgument(sharedPtrTy, loc);
    }
    region.addArgument(globalPtrTy, loc);
    region.addArgument(profilePtrTy, loc);
    rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(),
                                amendedFuncOp.end());
    return amendedFuncOp;
  }

  /// For each argument of `llvmFuncOp` tagged with `kTmaDescAttrName`, adds
  /// three argument attributes:
  ///   - `byval` of type `[128 x i8]`;
  ///   - the NVVM grid-constant attribute;
  ///   - `align 64`.
  ///
  /// Asserts (debug builds only) that:
  ///   - the tag's value is `i32 1`;
  ///   - the function is a kernel.
  static void handleByvalTmaDescArgs(LLVM::LLVMFuncOp &llvmFuncOp) {
    // Only consulted by the assert below; unused under NDEBUG.
    [[maybe_unused]] const bool isKernel = triton::isKernel(llvmFuncOp);
    MLIRContext *ctx = llvmFuncOp.getContext();
    for (unsigned i = 0; i < llvmFuncOp.getNumArguments(); ++i) {
      DictionaryAttr attrs = llvmFuncOp.getArgAttrDict(i);
      if (!attrs)
        continue;
      Attribute tmaDescAttr = attrs.get(kTmaDescAttrName);
      if (!tmaDescAttr)
        continue;

      const auto i32Type = mlir::IntegerType::get(ctx, 32);
      assert(tmaDescAttr == mlir::IntegerAttr::get(i32Type, 1));
      assert(isKernel &&
             "tt.nv_tma_desc is not supported for device functions");

      // The descriptor is passed by value as an opaque 128-byte blob.
      const auto byteType = mlir::IntegerType::get(ctx, 8);
      const auto arrayType = mlir::LLVM::LLVMArrayType::get(ctx, byteType, 128);
      llvmFuncOp.setArgAttr(i, LLVM::LLVMDialect::getByValAttrName(),
                            mlir::TypeAttr::get(arrayType));
      llvmFuncOp.setArgAttr(i, NVVM::NVVMDialect::getGridConstantAttrName(),
                            mlir::UnitAttr::get(ctx));
      llvmFuncOp.setArgAttr(i, LLVM::LLVMDialect::getAlignAttrName(),
                            mlir::IntegerAttr::get(i32Type, 64));
    }
  }

  /// Performs the full lowering described on the struct.
  ///
  /// Returns failure (without erasing `funcOp`) if the LLVM function
  /// conversion fails. On success, both `funcOp` and the intermediate
  /// amended op are erased.
  LogicalResult
  matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto amendedFuncOp = amendFuncOp(funcOp, rewriter, targetInfo);

    FailureOr<LLVM::LLVMFuncOp> maybeNewFuncOp =
        mlir::convertFuncOpToLLVMFuncOp(amendedFuncOp, rewriter,
                                        *getTypeConverter());
    if (failed(maybeNewFuncOp)) {
      return failure();
    }

    LLVM::LLVMFuncOp newFuncOp = *maybeNewFuncOp;
    auto ctx = funcOp->getContext();

    if (triton::isKernel(funcOp)) {
      // Mark the entry point as an NVVM kernel with external linkage.
      newFuncOp->setAttr(NVVM::NVVMDialect::getKernelFuncAttrName(),
                         rewriter.getIntegerAttr(type::u1Ty(ctx), 1));
      newFuncOp.setLinkage(LLVM::Linkage::External);
    } else {
      // Device functions get a "noinline" passthrough attribute and internal
      // linkage.
      newFuncOp.setPassthroughAttr(
          ArrayAttr::get(ctx, rewriter.getStringAttr("noinline")));
      newFuncOp.setLinkage(LLVM::Linkage::Internal);
    }

    Operation *parentOp = funcOp.getParentOp();

    // A "ttg.total-num-warps" attribute on the parent op overrides the value
    // from lookupNumWarps.
    int numWarps = triton::gpu::lookupNumWarps(funcOp);
    if (auto totalNumWarps =
            parentOp->getAttrOfType<IntegerAttr>("ttg.total-num-warps"))
      numWarps = totalNumWarps.getInt();

    // Forward a max-registers attribute from the parent op, if present.
    if (Attribute maxnregAttr =
            parentOp->getAttr(triton::gpu::AttrMaxRegistersName))
      newFuncOp->setAttr(NVVM::NVVMDialect::getMaxnregAttrName(), maxnregAttr);

    // Required thread count: 32 threads per warp.
    newFuncOp->setAttr(NVVM::NVVMDialect::getReqntidAttrName(),
                       rewriter.getDenseI32ArrayAttr(32 * numWarps));

    rewriter.eraseOp(funcOp);
    rewriter.eraseOp(amendedFuncOp);

    // Runs on the final LLVM function, whose arg attrs were carried over from
    // the amended op by the conversion.
    handleByvalTmaDescArgs(newFuncOp);

    return success();
  }

private:
  const TargetInfoBase &targetInfo;
};
}
/// Registers the `triton::FuncOp` -> `LLVM::LLVMFuncOp` lowering pattern
/// (`FuncOpConversion`) into `patterns`.
///
/// @param typeConverter Type converter the pattern uses for signatures.
/// @param patterns      Pattern set that receives the new pattern.
/// @param targetInfo    Target hooks, e.g. the shared-memory address space.
///                      The pattern stores a reference to it, so it must
///                      outlive the pattern set.
/// @param benefit       Benefit assigned to the pattern.
void mlir::triton::populateFuncOpConversionPattern(
    LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    const TargetInfoBase &targetInfo, PatternBenefit benefit) {
  // `insert` is RewritePatternSet's alias for `add`.
  patterns.insert<FuncOpConversion>(typeConverter, targetInfo, benefit);
}