#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/IR/PatternMatch.h"
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/LayoutUtility.h"
namespace {
using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::gpu;
/// Lowers a store of the register tensor `regVal` into the memory described by
/// `memDescTy` / `smemObj`.
///
/// A register-to-shared mapping is derived from the linear layout of `regVal`:
///   * for a PaddedSharedEncodingAttr destination it comes from
///     getPaddedRegToSharedLayout;
///   * otherwise it is the register layout inverted-and-composed with the
///     linear layout of `memDescTy`, and it must be trivial over "block".
/// The (register, lane, warp) -> offset sublayout of that mapping is then
/// forwarded, together with `inVals`, to lowerLocalLdSt.
///
/// \returns failure() if the non-padded mapping is not trivial over "block",
///          success() otherwise.
LogicalResult lowerLocalStore(Location loc, MLIRContext *ctx, Value regVal,
                              MemDescType memDescTy, SharedMemoryObject smemObj,
                              ArrayRef<Value> inVals,
                              const LLVMTypeConverter *typeConverter,
                              ConversionPatternRewriter &rewriter,
                              const TargetInfoBase &targetInfo) {
  auto srcTensorTy = cast<RankedTensorType>(regVal.getType());
  Type elemLLVMTy = typeConverter->convertType(memDescTy.getElementType());
  LinearLayout srcLayout = toLinearLayout(srcTensorTy);

  LinearLayout regToShared = LinearLayout::empty();
  if (auto padded = dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(
          memDescTy.getEncoding())) {
    regToShared = getPaddedRegToSharedLayout(srcLayout, padded);
  } else {
    regToShared = srcLayout.invertAndCompose(toLinearLayout(memDescTy));
    // Bail out when the mapping is not trivial over the "block" dimension.
    if (!regToShared.isTrivialOver({str_attr("block")}))
      return failure();
  }

  // Keep only the dimensions that lowerLocalLdSt is given.
  StringAttr registerDim = str_attr("register");
  StringAttr laneDim = str_attr("lane");
  StringAttr warpDim = str_attr("warp");
  StringAttr offsetDim = str_attr("offset");
  regToShared =
      regToShared.sublayout({registerDim, laneDim, warpDim}, {offsetDim});

  lowerLocalLdSt(loc, ctx, regToShared, inVals, elemLLVMTy, memDescTy, smemObj,
                 rewriter, targetInfo);
  return success();
}
/// Conversion pattern for triton::gpu::GlobalScratchAllocOp.
///
/// The op is replaced by the pointer returned from LLVM::getGlobalScratchPtr,
/// called with the enclosing LLVM::LLVMFuncOp and the value of the op's
/// "ttg.global_scratch_memory_offset" integer attribute (passed through
/// `i32_val`).
struct GlobalScratchAllocOpConversion
    : public ConvertOpToLLVMPattern<triton::gpu::GlobalScratchAllocOp> {
  const TargetInfoBase *targetInfo;

  GlobalScratchAllocOpConversion(LLVMTypeConverter &converter,
                                 const TargetInfoBase &targetInfo,
                                 PatternBenefit benefit)
      : ConvertOpToLLVMPattern(converter, benefit), targetInfo(&targetInfo) {}

  /// Fails to match (without touching the IR) when the offset attribute is
  /// absent or not an IntegerAttr, or when the op is not nested inside an
  /// LLVM::LLVMFuncOp.
  LogicalResult
  matchAndRewrite(triton::gpu::GlobalScratchAllocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Validate every precondition first. An `assert` would be compiled out in
    // NDEBUG builds and the null attribute would then be dereferenced below.
    auto opOffsetAttr = op->getAttrOfType<mlir::IntegerAttr>(
        "ttg.global_scratch_memory_offset");
    if (!opOffsetAttr) {
      return rewriter.notifyMatchFailure(
          op, "missing integer attribute 'ttg.global_scratch_memory_offset'");
    }
    auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
    if (!funcOp) {
      return rewriter.notifyMatchFailure(
          op, "expected to be nested inside an LLVM function");
    }

    Location loc = op.getLoc();
    auto b = TritonLLVMOpBuilder(loc, rewriter);
    auto opOffset = opOffsetAttr.getValue().getZExtValue();
    Value ptr = LLVM::getGlobalScratchPtr(loc, rewriter, *targetInfo, funcOp,
                                          b.i32_val(opOffset));
    rewriter.replaceOp(op, ptr);
    return success();
  }
};
/// Conversion pattern for triton::gpu::LocalAllocOp.
///
/// Only ops for which isSharedMemoryAlloc() holds are handled. The op is
/// replaced by the struct form of a SharedMemoryObject built on the base
/// returned by LLVM::getSharedMemoryBase; when the op has a source operand,
/// that operand is first written out through lowerLocalStore.
struct LocalAllocOpConversion
    : public ConvertOpToLLVMPattern<triton::gpu::LocalAllocOp> {
  LocalAllocOpConversion(const LLVMTypeConverter &converter,
                         const TargetInfoBase &targetInfo,
                         PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<triton::gpu::LocalAllocOp>(converter, benefit),
        targetInfo(targetInfo) {}

  LogicalResult
  matchAndRewrite(triton::gpu::LocalAllocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.isSharedMemoryAlloc())
      return failure();

    Location loc = op->getLoc();
    Value base =
        LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());

    const auto *converter = getTypeConverter();
    auto descTy = cast<MemDescType>(op.getType());
    Type elemTy = converter->convertType(descTy.getElementType());
    SharedMemoryObject smem(base, elemTy, descTy.getRank(), loc, rewriter);

    // An allocation with an initial value also stores that value.
    if (Value src = op.getSrc()) {
      auto srcElems = unpackLLElements(loc, adaptor.getSrc(), rewriter);
      LogicalResult stored =
          lowerLocalStore(loc, op.getContext(), src, descTy, smem, srcElems,
                          converter, rewriter, targetInfo);
      if (failed(stored))
        return failure();
    }

    Value smemStruct = getStructFromSharedMemoryObject(loc, smem, rewriter);
    rewriter.replaceOp(op, smemStruct);
    return success();
  }

private:
  const TargetInfoBase &targetInfo;
};
/// Conversion pattern for triton::gpu::LocalDeallocOp: the op is simply
/// erased; this pattern emits no replacement IR for it.
struct LocalDeallocOpConversion
    : public ConvertOpToLLVMPattern<triton::gpu::LocalDeallocOp> {
  // Reuse the base-class constructors unchanged.
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(triton::gpu::LocalDeallocOp op, OpAdaptor /*adaptor*/,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.eraseOp(op.getOperation());
    return success();
  }
};
/// Conversion pattern for LocalLoadOp.
///
/// A register-to-shared mapping is derived from the linear layout of the
/// result tensor type:
///   * for a PaddedSharedEncodingAttr source it comes from
///     getPaddedRegToSharedLayout;
///   * otherwise it is the register layout inverted-and-composed with the
///     linear layout of the source memdesc type, and it must be trivial over
///     "block".
/// The (register, lane, warp) -> offset sublayout is passed to lowerLocalLdSt,
/// and the values it returns are packed with packLLElements to replace the op.
struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
public:
  LocalLoadOpConversion(LLVMTypeConverter &typeConverter,
                        const TargetInfoBase &targetInfo,
                        PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
  }

  /// Fails to match when the non-padded mapping is not trivial over "block".
  /// That check runs before the rewriter is used, so the failure path does
  /// not depend on the conversion driver undoing partially emitted IR.
  LogicalResult
  matchAndRewrite(LocalLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto *ctx = op.getContext();
    auto memDescTy = cast<MemDescType>(op.getSrc().getType());
    auto regTy = cast<RankedTensorType>(op.getResult().getType());
    auto sharedEnc =
        cast<triton::gpu::SharedEncodingTrait>(memDescTy.getEncoding());

    // --- Analysis only: nothing below uses the rewriter until the layout has
    // been validated.
    auto kReg = str_attr("register");
    auto kLane = str_attr("lane");
    auto kWarp = str_attr("warp");
    auto kOffset = str_attr("offset");
    auto regLayout = toLinearLayout(regTy);
    LinearLayout cvt = LinearLayout::empty();
    if (auto paddedEnc =
            dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(sharedEnc)) {
      cvt = getPaddedRegToSharedLayout(regLayout, paddedEnc);
    } else {
      auto sharedLayout = toLinearLayout(memDescTy);
      cvt = regLayout.invertAndCompose(sharedLayout);
      auto kBlock = str_attr("block");
      if (!cvt.isTrivialOver({kBlock})) {
        return rewriter.notifyMatchFailure(
            op, "register-to-shared layout is not trivial over 'block'");
      }
    }
    cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset});

    // --- Rewrite: from here on the pattern always succeeds.
    auto typeConverter = getTypeConverter();
    auto llvmElemTy = typeConverter->convertType(memDescTy.getElementType());
    auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
                                                         llvmElemTy, rewriter);
    // A load has no input values, hence the empty list.
    auto outVals = lowerLocalLdSt(loc, ctx, cvt, {}, llvmElemTy, memDescTy,
                                  smemObj, rewriter, targetInfo, op);
    Value result = packLLElements(loc, typeConverter, outVals, rewriter, regTy);
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  const TargetInfoBase &targetInfo;
};
/// Conversion pattern for triton::gpu::LocalStoreOp.
///
/// The source operand's converted elements are unpacked and written to the
/// destination memory object via lowerLocalStore; on success the op is
/// erased. The pattern fails when lowerLocalStore fails.
struct LocalStoreOpConversion
    : public ConvertOpToLLVMPattern<triton::gpu::LocalStoreOp> {
public:
  // The base-class constructors are deliberately not inherited: they could
  // never initialize the `targetInfo` reference member, so the only usable
  // constructor is the one below.
  LocalStoreOpConversion(const LLVMTypeConverter &converter,
                         const TargetInfoBase &targetInfo,
                         PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<triton::gpu::LocalStoreOp>(converter, benefit),
        targetInfo(targetInfo) {}

  LogicalResult
  matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto *ctx = op.getContext();
    Value regVal = op.getSrc();
    Value memDescVal = op.getDst();
    auto typeConverter = getTypeConverter();
    auto memDescTy = cast<MemDescType>(memDescVal.getType());
    auto llvmElemTy = typeConverter->convertType(memDescTy.getElementType());
    // Rebuild the destination memory object from its converted struct value.
    auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getDst(),
                                                         llvmElemTy, rewriter);
    auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
    if (failed(lowerLocalStore(loc, ctx, regVal, memDescTy, smemObj, inVals,
                               typeConverter, rewriter, targetInfo))) {
      return failure();
    }
    // The store has no results, so it is erased rather than replaced.
    rewriter.eraseOp(op);
    return success();
  }

private:
  const TargetInfoBase &targetInfo;
};
}
/// Adds the memory-op conversion patterns defined in this file to `patterns`,
/// all with the given `benefit`. Registration order: global scratch alloc,
/// local alloc, local dealloc, local load, local store.
void mlir::triton::populateMemoryOpToLLVMPatterns(
    LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
    RewritePatternSet &patterns, PatternBenefit benefit) {
  // Patterns sharing a constructor signature are added in one variadic call;
  // the dealloc pattern takes no TargetInfoBase and is added on its own.
  patterns.add<GlobalScratchAllocOpConversion, LocalAllocOpConversion>(
      typeConverter, targetInfo, benefit);
  patterns.add<LocalDeallocOpConversion>(typeConverter, benefit);
  patterns.add<LocalLoadOpConversion, LocalStoreOpConversion>(
      typeConverter, targetInfo, benefit);
}