#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
namespace ttg = mlir::triton::gpu;
namespace mlir {
namespace triton {
namespace nvidia_gpu {
#define GEN_PASS_DEF_TRITONNVIDIAGPUOPTIMIZETMEMLAYOUTSPASS
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
namespace {
/// Follows the chain of `ttg::ConvertLayoutOp`s that define `v` (if any) and
/// returns the first value that is not produced by a layout conversion.
/// `v` is returned unchanged when it is not the result of a ConvertLayoutOp.
static Value stripConvertLayout(Value v) {
  for (;;) {
    auto convert = v.getDefiningOp<ttg::ConvertLayoutOp>();
    if (!convert)
      return v;
    v = convert.getSrc();
  }
}
/// Rewrites a `SplitOp` whose operand is, ignoring layout conversions,
/// `trans(reshape(tmem_load))` with transpose order {0, 2, 1}.
///
/// On a match the split is replaced by two `TMEMLoadOp`s, each reading one
/// `TMEMSubSliceOp` of the loaded tensor-memory buffer (at offsets 0 and
/// `splitNSize`), followed by a `ConvertLayoutOp` back to the type of the
/// split's results.
class TMemSplitLoadPattern : public OpRewritePattern<SplitOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(SplitOp splitOp,
                                PatternRewriter &rewriter) const override {
    // Match: split <- [cvt]* <- trans{0,2,1} <- reshape <- [cvt]* <- tmem_load.
    Value splitInput = stripConvertLayout(splitOp.getSrc());
    auto transposeOp = splitInput.getDefiningOp<TransOp>();
    if (!transposeOp)
      return failure();
    if (transposeOp.getOrder() != ArrayRef<int>({0, 2, 1}))
      return failure();

    auto reshapeOp = transposeOp.getSrc().getDefiningOp<ReshapeOp>();
    if (!reshapeOp)
      return failure();

    Value reshapeInput = stripConvertLayout(reshapeOp.getSrc());
    auto tmemLoad = reshapeInput.getDefiningOp<TMEMLoadOp>();
    if (!tmemLoad)
      return failure();

    // The reshape must leave the leading dimension untouched.
    auto reshapedShape = reshapeOp.getResult().getType().getShape();
    auto loadedShape =
        cast<RankedTensorType>(reshapeInput.getType()).getShape();
    if (reshapedShape[0] != loadedShape[0])
      return failure();

    // Only a per-CTA leading dimension of exactly 128 is handled.
    int mDim = getShapePerCTA(tmemLoad.getSrc().getType())[0];
    if (mDim != 128)
      return failure();

    // Each half must span at least 8 elements along the split dimension.
    int splitNSize = reshapedShape[2];
    if (splitNSize < 8)
      return failure();

    Value tmem = tmemLoad.getSrc();
    int numWarps = ttg::lookupNumWarps(tmemLoad);
    auto halfType = splitOp.getOutLHS().getType();
    Location loc = tmemLoad.getLoc();

    // The new ops are emitted right before the original load.
    rewriter.setInsertionPoint(tmemLoad);

    // Emits subslice -> tmem_load -> convert_layout for the half starting at
    // `nOffset` and returns the value converted to the split's result type.
    auto emitSliceLoad = [&](int64_t nOffset) -> Value {
      Value slice =
          rewriter.create<TMEMSubSliceOp>(loc, tmem, nOffset, splitNSize);
      Attribute sliceLayout =
          getTmemCompatibleLayout(mDim, splitNSize, halfType, numWarps);
      RankedTensorType sliceLoadType = halfType.cloneWithEncoding(sliceLayout);
      auto sliceLoad = rewriter.create<TMEMLoadOp>(loc, sliceLoadType, slice);
      auto converted =
          rewriter.create<ttg::ConvertLayoutOp>(loc, halfType, sliceLoad);
      return converted.getResult();
    };

    Value firstHalf = emitSliceLoad(0);
    Value secondHalf = emitSliceLoad(splitNSize);
    rewriter.replaceOp(splitOp, {firstHalf, secondHalf});
    return success();
  }
};
/// Rewrites a `TMEMStoreOp` whose stored value is, ignoring layout
/// conversions, `reshape(trans(join(lhs, rhs)))` with transpose order
/// {0, 2, 1}.
///
/// On a match the store is replaced by two stores, one per join operand, each
/// writing into a `TMEMSubSliceOp` of the destination (offsets 0 and
/// `splitNSize`). Both new stores reuse the predicate of the original store so
/// that a conditionally executed store stays conditional.
class TMemStoreJoinPattern : public OpRewritePattern<TMEMStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TMEMStoreOp storeOp,
                                PatternRewriter &b) const override {
    // Match: tmem_store <- [cvt]* <- reshape <- trans{0,2,1} <- join.
    Value src = stripConvertLayout(storeOp.getSrc());
    auto reshapeOp = src.getDefiningOp<ReshapeOp>();
    if (!reshapeOp)
      return failure();

    // The reshape must leave the leading dimension untouched.
    auto shape = reshapeOp.getSrc().getType().getShape();
    if (reshapeOp.getType().getShape().front() != shape[0])
      return failure();

    auto transOp = reshapeOp.getSrc().getDefiningOp<TransOp>();
    if (!transOp || transOp.getOrder() != ArrayRef<int>({0, 2, 1}))
      return failure();
    auto joinOp = transOp.getSrc().getDefiningOp<JoinOp>();
    if (!joinOp)
      return failure();

    // Only a per-CTA leading dimension of exactly 128 is handled.
    int mDim = getShapePerCTA(storeOp.getDst().getType())[0];
    if (mDim != 128)
      return failure();

    // Each half must span at least 8 elements along the joined dimension.
    int splitNSize = shape[2];
    if (splitNSize < 8)
      return failure();

    Location loc = storeOp.getLoc();
    Value tmem = storeOp.getDst();
    // Forward the original predicate instead of forcing the stores to be
    // unconditional.
    Value pred = storeOp.getPred();
    int numWarps = ttg::lookupNumWarps(storeOp);

    Attribute distLayout = getTmemCompatibleLayout(
        mDim, splitNSize, joinOp.getLhs().getType(), numWarps);
    auto newStoreType = joinOp.getLhs().getType().cloneWithEncoding(distLayout);

    // Emits subslice -> convert_layout -> tmem_store for one join operand.
    auto emitSliceStore = [&](int nOffset, Value half) {
      auto subSlice = b.create<TMEMSubSliceOp>(loc, tmem, nOffset, splitNSize);
      auto cvt = b.create<ttg::ConvertLayoutOp>(loc, newStoreType, half);
      b.create<TMEMStoreOp>(loc, subSlice, cvt.getResult(), pred);
    };
    emitSliceStore(0, joinOp.getLhs());
    emitSliceStore(splitNSize, joinOp.getRhs());

    b.eraseOp(storeOp);
    return success();
  }
};
/// Changes the register layout produced by a `TMEMLoadOp` when its result
/// reaches a `triton::ReduceOp` along axis 1 through layout conversions and
/// elementwise ops only.
///
/// The pattern applies only when the op runs with 8 warps and the source
/// tensor-memory encoding has blockM == 128. The load's result type is
/// switched to the linear layout returned by
/// `ttg::getTmemLoadLayoutSplitLongM`, and a `ConvertLayoutOp` back to the
/// original type is inserted right after the load so existing users keep
/// seeing the type they expect.
class TMemLoadReducePattern : public OpRewritePattern<TMEMLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TMEMLoadOp tmemLoadOp,
                                PatternRewriter &rewriter) const override {
    int numWarps = ttg::lookupNumWarps(tmemLoadOp);
    if (numWarps != 8)
      return failure();
    auto tmemEnc = dyn_cast<triton::nvidia_gpu::TensorMemoryEncodingAttr>(
        tmemLoadOp.getSrc().getType().getEncoding());
    if (!tmemEnc)
      return failure();
    int M = tmemEnc.getBlockM();
    int N = tmemEnc.getBlockN();
    if (M != 128)
      return failure();

    // Walk the forward slice through convert_layout / elementwise ops and
    // record whether any reduction along axis 1 is reached. The flag is
    // sticky: once such a reduction is seen, visiting a reduction along a
    // different axis later must not clear it, otherwise the outcome would
    // depend on traversal order.
    bool foundReductionAlongN = false;
    auto filter = [&](Operation *op) {
      if (isa<ttg::ConvertLayoutOp>(op) || op->hasTrait<OpTrait::Elementwise>())
        return true;
      if (auto reduce = dyn_cast<triton::ReduceOp>(op)) {
        if (reduce.getAxis() == 1)
          foundReductionAlongN = true;
      }
      return false;
    };
    ForwardSliceOptions fwdOpt;
    fwdOpt.filter = filter;
    SetVector<mlir::Operation *> fwdSlices;
    getForwardSlice(tmemLoadOp.getResult(), &fwdSlices, fwdOpt);
    if (!foundReductionAlongN)
      return failure();

    RankedTensorType oldType = tmemLoadOp.getType();
    Attribute newLayout = ttg::LinearEncodingAttr::get(
        tmemLoadOp.getContext(),
        ttg::getTmemLoadLayoutSplitLongM(M, N, oldType, numWarps));
    // Already using the preferred layout: nothing to do (this also stops the
    // pattern from re-applying to its own output).
    if (newLayout == oldType.getEncoding())
      return failure();
    auto newType = oldType.cloneWithEncoding(newLayout);

    // All IR changes go through the rewriter so the greedy driver is notified
    // of the in-place type change, the new op and the rewired uses.
    rewriter.modifyOpInPlace(
        tmemLoadOp, [&]() { tmemLoadOp.getResult().setType(newType); });
    rewriter.setInsertionPointAfter(tmemLoadOp);
    auto cvt = rewriter.create<ttg::ConvertLayoutOp>(
        tmemLoadOp.getLoc(), oldType, tmemLoadOp.getResult());
    // Every user except the new conversion now reads the converted value.
    rewriter.replaceAllUsesExcept(tmemLoadOp.getResult(), cvt.getResult(),
                                  cvt.getOperation());
    return success();
  }
};
/// Inserts a `ConvertLayoutOp` in front of a `TMEMStoreOp` so that the stored
/// tensor uses the linear layout returned by
/// `gpu::getTmemLoadStoreLayout16x256`.
///
/// The rewrite only happens when
///  * the destination has a `TensorMemoryEncodingAttr`,
///  * the helper yields a layout and it differs from the current source type,
///  * `getConvertBackwardSlice` succeeds for the new encoding, and
///  * the backward slice contains at least one `gpu::LocalLoadOp`, and every
///    such load has a 16-bit element type and a register/shared-memory layout
///    composition whose `getNumConsecutiveInOut()` equals 1.
class TMemFromSharedMemPattern : public OpRewritePattern<TMEMStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TMEMStoreOp tmemStoreOp,
                                PatternRewriter &rewriter) const override {
    auto dstEncoding = dyn_cast<triton::nvidia_gpu::TensorMemoryEncodingAttr>(
        tmemStoreOp.getDst().getType().getEncoding());
    if (!dstEncoding)
      return failure();

    auto srcType = tmemStoreOp.getSrc().getType();
    int blockM = dstEncoding.getBlockM();
    int blockN = dstEncoding.getBlockN();
    int numWarps = ttg::lookupNumWarps(tmemStoreOp);

    // Ask for the candidate store layout; bail out when none is available.
    std::optional<LinearLayout> candidateLayout =
        gpu::getTmemLoadStoreLayout16x256(blockM, blockN, srcType, numWarps);
    if (!candidateLayout.has_value())
      return failure();

    Attribute candidateEncoding = gpu::LinearEncodingAttr::get(
        tmemStoreOp.getContext(), *candidateLayout);
    auto candidateType = RankedTensorType::get(
        srcType.getShape(), srcType.getElementType(), candidateEncoding);
    // Nothing to gain if the source already has that exact type.
    if (candidateType == srcType)
      return failure();

    // Collect the values that would have to change layout.
    SetVector<Value> backwardSlice;
    DenseMap<Value, Attribute> sliceLayouts;
    if (failed(getConvertBackwardSlice(tmemStoreOp.getSrcMutable(),
                                       backwardSlice, candidateEncoding,
                                       sliceLayouts)))
      return failure();

    // Every local_load in the slice must qualify; any that does not aborts
    // the rewrite.
    bool sawQualifyingLoad = false;
    for (Value sliceValue : backwardSlice) {
      if (auto localLoad = sliceValue.getDefiningOp<gpu::LocalLoadOp>()) {
        unsigned elementBits =
            localLoad.getType().getElementType().getIntOrFloatBitWidth();
        if (elementBits != 16)
          return failure();

        LinearLayout registerLayout = gpu::toLinearLayout(localLoad.getType());
        LinearLayout sharedLayout =
            gpu::toLinearLayout(localLoad.getSrc().getType());
        int consecutive = registerLayout.invertAndCompose(sharedLayout)
                              .getNumConsecutiveInOut();
        if (consecutive != 1)
          return failure();

        sawQualifyingLoad = true;
      }
    }
    if (!sawQualifyingLoad)
      return failure();

    // Convert the stored value to the new type and feed it to the store.
    auto converted = rewriter.create<ttg::ConvertLayoutOp>(
        tmemStoreOp.getLoc(), candidateType, tmemStoreOp.getSrc());
    rewriter.modifyOpInPlace(tmemStoreOp, [&]() {
      tmemStoreOp.getSrcMutable().assign(converted.getResult());
    });
    return success();
  }
};
}
/// Module pass that greedily applies the tensor-memory layout patterns
/// defined in this file: TMemSplitLoadPattern, TMemStoreJoinPattern,
/// TMemLoadReducePattern and TMemFromSharedMemPattern.
class TritonNvidiaGPUOptimizeTMemLayoutsPass
    : public impl::TritonNvidiaGPUOptimizeTMemLayoutsPassBase<
          TritonNvidiaGPUOptimizeTMemLayoutsPass> {
public:
  using BaseT = TritonNvidiaGPUOptimizeTMemLayoutsPassBase<
      TritonNvidiaGPUOptimizeTMemLayoutsPass>;
  using BaseT::BaseT;

  /// Registers the patterns (in the order listed above) and runs the greedy
  /// rewrite driver on the module; the pass is marked as failed when the
  /// driver reports failure.
  void runOnOperation() override {
    ModuleOp moduleOp = getOperation();
    MLIRContext *ctx = &getContext();

    mlir::RewritePatternSet patternSet(ctx);
    patternSet.add<TMemSplitLoadPattern>(ctx);
    patternSet.add<TMemStoreJoinPattern>(ctx);
    patternSet.add<TMemLoadReducePattern>(ctx);
    patternSet.add<TMemFromSharedMemPattern>(ctx);

    LogicalResult status = applyPatternsGreedily(moduleOp, std::move(patternSet));
    if (failed(status))
      signalPassFailure();
  }
};
}
}
}