#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Conversion/TritonToTritonGPU/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
namespace mlir::triton {
#define GEN_PASS_DEF_RELAYOUTTRITONGPU
#include "triton/Conversion/TritonToTritonGPU/Passes.h.inc"
}
namespace {
using namespace mlir;
using namespace triton;
using namespace triton::gpu;
namespace ttng = triton::nvidia_gpu;
// Returns the register-side tensor type to use when a tensor of `type` is
// moved to/from the tensor memory described by `memdesc`.
//
// `type` is first run through the type converter `tc`. The result keeps the
// converted shape and element type, but its encoding is replaced based on the
// memdesc encoding:
//  - ttng::TensorMemoryScalesEncodingAttr -> a LinearEncodingAttr built from
//    getScaleTMEMStoreLinearLayout(convertedType, numWarps);
//  - anything else is cast<> to ttng::TensorMemoryEncodingAttr (which asserts
//    in debug builds on a mismatch), and ttng::getTmemCompatibleLayout picks
//    a layout from its blockM/blockN.
RankedTensorType getTMEMTensorLayout(const TypeConverter *tc,
                                     RankedTensorType type, MemDescType memdesc,
                                     unsigned numWarps) {
  auto converted = cast<RankedTensorType>(tc->convertType(type));
  Attribute memEncoding = memdesc.getEncoding();

  if (isa<ttng::TensorMemoryScalesEncodingAttr>(memEncoding)) {
    auto scalesLayout = getScaleTMEMStoreLinearLayout(converted, numWarps);
    return converted.cloneWithEncoding(
        LinearEncodingAttr::get(converted.getContext(), scalesLayout));
  }

  auto tmemEncoding = cast<ttng::TensorMemoryEncodingAttr>(memEncoding);
  Attribute compatible = ttng::getTmemCompatibleLayout(
      tmemEncoding.getBlockM(), tmemEncoding.getBlockN(), converted, numWarps);
  return converted.cloneWithEncoding(compatible);
}
// Rewrites a ttng::TMEMLoadOp so that it produces its tensor directly in a
// tensor-memory-compatible layout (see getTMEMTensorLayout). A
// ConvertLayoutOp is then inserted right after the load to convert back to
// the type the converter assigns to the load's ORIGINAL result type. All
// previous users are redirected to that ConvertLayoutOp.
struct TMEMLoadOpPattern : public OpConversionPattern<ttng::TMEMLoadOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ttng::TMEMLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Compute the type users should observe BEFORE mutating the result type.
    // Once the result carries the TMEM encoding, converting it again yields
    // the TMEM type itself. (This assumes TritonGPUTypeConverter passes
    // already-encoded tensor types through unchanged; confirm if that
    // converter changes.) In that case the ConvertLayoutOp below would be an
    // identity conversion and the TMEM layout would leak to every user.
    Type resultType = getTypeConverter()->convertType(op.getType());
    if (!resultType)
      return rewriter.notifyMatchFailure(
          op, "failed to convert tmem_load result type");

    RankedTensorType type = getTMEMTensorLayout(
        typeConverter, op.getType(), op.getSrc().getType(), lookupNumWarps(op));
    rewriter.modifyOpInPlace(op, [&] { op.getResult().setType(type); });

    rewriter.setInsertionPointAfter(op);
    auto cvt = rewriter.create<ConvertLayoutOp>(op.getLoc(), resultType,
                                                op.getResult());
    // Every user except the new ConvertLayoutOp now reads the converted value.
    rewriter.replaceAllUsesExcept(op.getResult(), cvt, cvt);
    return success();
  }
};
// Rewrites a ttng::TMEMStoreOp so the tensor it stores is already in a
// tensor-memory-compatible layout. The (type-converted) source operand is
// wrapped in a ConvertLayoutOp targeting the layout chosen by
// getTMEMTensorLayout for the destination memdesc. The store is then updated
// in place to read that converted value.
struct TMEMStoreOpPattern : public OpConversionPattern<ttng::TMEMStoreOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ttng::TMEMStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto warps = lookupNumWarps(op);
    auto dstDesc = op.getDst().getType();
    RankedTensorType tmemType = getTMEMTensorLayout(
        typeConverter, op.getSrc().getType(), dstDesc, warps);

    Value relayoutedSrc = rewriter.create<ConvertLayoutOp>(
        op.getLoc(), tmemType, adaptor.getSrc());
    rewriter.modifyOpInPlace(
        op, [&] { op.getSrcMutable().assign(relayoutedSrc); });
    return success();
  }
};
// Rewrites a ttng::TMEMAllocOp that is initialized from a tensor. The
// (type-converted) initializer is first converted to a
// tensor-memory-compatible layout (see getTMEMTensorLayout), and the alloc is
// updated in place to use it.
// Allocations without a source operand have no tensor to relayout, so the
// pattern declines to match them.
struct TMEMAllocOpPattern : public OpConversionPattern<ttng::TMEMAllocOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ttng::TMEMAllocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // A pattern that leaves the IR untouched must report failure. Returning
    // success() here would tell the conversion driver the root op was
    // legalized even though it was neither replaced nor modified.
    if (!op.getSrc())
      return rewriter.notifyMatchFailure(op, "tmem_alloc has no source tensor");

    RankedTensorType type = getTMEMTensorLayout(
        typeConverter, op.getSrc().getType(), op.getType(), lookupNumWarps(op));
    Value src =
        rewriter.create<ConvertLayoutOp>(op.getLoc(), type, adaptor.getSrc());
    rewriter.modifyOpInPlace(op, [&] { op.getSrcMutable().assign(src); });
    return success();
  }
};
// Pass driver. Runs a partial dialect conversion over the module so that
// TritonNvidiaGPU ops that are not yet legal for the TritonGPU type converter
// are rewritten by the TMA gather/scatter and TMEM load/store/alloc patterns.
class RelayoutTritonGPU
    : public triton::impl::RelayoutTritonGPUBase<RelayoutTritonGPU> {
public:
  using RelayoutTritonGPUBase::RelayoutTritonGPUBase;
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    ModuleOp moduleOp = getOperation();

    // Module-level launch configuration used to build the type converter.
    int warpsPerCTA = lookupNumWarps(moduleOp);
    int warpSize = TritonGPUDialect::getThreadsPerWarp(moduleOp);
    int ctaCount = TritonGPUDialect::getNumCTAs(moduleOp);
    TritonGPUTypeConverter converter(ctx, warpsPerCTA, warpSize, ctaCount,
                                     true);

    // Legality of every ttng op is delegated to
    // TritonGPUConversionTarget::isDynamicallyLegal with this converter.
    TritonGPUConversionTarget target(*ctx, converter);
    auto isLegalNvidiaOp = [&converter](Operation *op) {
      return TritonGPUConversionTarget::isDynamicallyLegal(op, converter);
    };
    target.addDynamicallyLegalDialect<ttng::TritonNvidiaGPUDialect>(
        isLegalNvidiaOp);

    RewritePatternSet patterns(ctx);
    patterns.add<GatherScatterOpPattern<ttng::AsyncTMAGatherOp>,
                 GatherScatterOpPattern<ttng::AsyncTMAScatterOp>,
                 TMEMLoadOpPattern, TMEMStoreOpPattern, TMEMAllocOpPattern>(
        converter, ctx);

    if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
      signalPassFailure();
  }
};
}