#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/LayoutUtility.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Tools/LayoutUtils.h"
#include "triton/Tools/LinearLayout.h"
#include <memory>
namespace mlir::triton::gpu {
namespace {
// Routes a transposed tt.dot operand through shared memory.
//
// Matches
//
//   %t = tt.trans %x {order = [1, 0]}
//   %c = convert_layout %t : ... -> tensor<..., #dot_operand>
//   tt.dot ..., %c, ...          (the dot is the only user of %c)
//
// and rewrites the source of the convert to
//
//   local_load(memdesc_trans(local_alloc(%x), [1, 0]))
//
// where the local_load already yields the convert's result type.
class SwizzleShmemConvert : public OpRewritePattern<ConvertLayoutOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ConvertLayoutOp cvtOp,
                                PatternRewriter &rewriter) const override {
    // The convert must feed exactly one operation, and that must be a tt.dot.
    if (!cvtOp->hasOneUse())
      return failure();
    if (!isa<triton::DotOp>(cvtOp->use_begin()->getOwner()))
      return failure();

    // Its input must be produced by a transpose with order [1, 0].
    auto transposeOp = cvtOp.getSrc().getDefiningOp<TransOp>();
    if (!transposeOp)
      return failure();
    if (transposeOp.getOrder() != ArrayRef<int32_t>{1, 0})
      return failure();

    // The convert has to produce a dot-operand layout.
    RankedTensorType loadResultTy = cvtOp.getType();
    auto dotOperandEnc =
        dyn_cast<DotOperandEncodingAttr>(loadResultTy.getEncoding());
    if (!dotOperandEnc)
      return failure();

    // Type used to derive the shared layout: the transpose input, or, when
    // that input is itself a convert_layout, the input of that inner convert.
    // Note that the value stored into shared memory below is always the
    // transpose's direct input.
    Value transposeInput = transposeOp.getSrc();
    RankedTensorType layoutSrcTy = transposeOp.getSrc().getType();
    if (auto innerCvt = transposeInput.getDefiningOp<ConvertLayoutOp>())
      layoutSrcTy = innerCvt.getSrc().getType();

    auto ctx = getContext();
    auto srcCTALayout = triton::gpu::getCTALayout(layoutSrcTy.getEncoding());
    auto permutedCTALayout =
        permuteCTALayout(ctx, srcCTALayout, transposeOp.getOrder());
    // NOTE(review): the trailing `true` is presumably the builder's
    // "needs transpose" flag — confirm against SwizzledSharedEncodingAttr.
    auto sharedEnc = SwizzledSharedEncodingAttr::get(
        ctx, dotOperandEnc, layoutSrcTy.getShape(),
        getOrderForMemory(layoutSrcTy), permutedCTALayout,
        layoutSrcTy.getElementType(), true);
    // NOTE(review): this compares a swizzled-shared attribute with a
    // dot-operand attribute; it looks like it can never be equal — confirm
    // whether the guard is still needed.
    if (sharedEnc == dotOperandEnc)
      return failure();

    // Materialize alloc -> memdesc_trans -> load right before the transpose.
    rewriter.setInsertionPoint(transposeOp);
    auto loc = transposeOp.getLoc();
    auto smemSpace = SharedMemorySpaceAttr::get(ctx);
    auto allocTy =
        MemDescType::get(layoutSrcTy.getShape(), layoutSrcTy.getElementType(),
                         sharedEnc, smemSpace);
    auto smemAlloc =
        rewriter.create<LocalAllocOp>(loc, allocTy, transposeOp.getSrc());
    auto smemTrans = rewriter.create<MemDescTransOp>(
        loc, smemAlloc, ArrayRef<int32_t>({1, 0}));
    auto smemLoad =
        rewriter.create<LocalLoadOp>(loc, loadResultTy, smemTrans);

    // Only the convert's operand is redirected; the original transpose is
    // left in place since it may have other users.
    rewriter.modifyOpInPlace(
        cvtOp, [&]() { cvtOp.getSrcMutable().assign(smemLoad.getResult()); });
    return success();
  }
};
// Folds a register-level transpose into the shared-memory operand of a
// warp-group dot / MMAv5 op.
//
// Rewrites
//
//   local_alloc(tt.trans(%x, [1, 0]))   ->   memdesc_trans(local_alloc(%x), [1, 0])
//
// when the local_alloc has a single user that is a WarpGroupDotOp or
// implements MMAv5OpInterface, and the allocation uses an NVMMA shared
// encoding.
class FuseTransMMAV3Plus : public OpRewritePattern<LocalAllocOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(LocalAllocOp allocOp,
                                PatternRewriter &rewriter) const override {
    // Need an initialized allocation whose only user is a supported MMA op.
    if (!allocOp.getSrc() || !allocOp->hasOneUse() ||
        !isa<triton::nvidia_gpu::WarpGroupDotOp,
             triton::nvidia_gpu::MMAv5OpInterface>(
            *allocOp->getUsers().begin()))
      return failure();
    auto dot = *allocOp->getUsers().begin();

    // The stored value must come from a transpose with order [1, 0].
    auto trans = allocOp.getSrc().getDefiningOp<TransOp>();
    if (!trans || trans.getOrder() != ArrayRef<int32_t>({1, 0}))
      return failure();

    MemDescType allocType = allocOp.getType();
    // Decline (rather than assert) when the allocation does not carry an
    // NVMMA shared encoding; nothing above guarantees the encoding kind.
    auto allocEncoding =
        dyn_cast<NVMMASharedEncodingAttr>(allocType.getEncoding());
    if (!allocEncoding)
      return failure();

    RankedTensorType srcTy = trans.getSrc().getType();

    // Memory order of the new allocation: taken from the transpose input, or
    // from the input of a convert_layout feeding the transpose if present.
    auto newInnerCvtOrder = getOrderForMemory(srcTy);
    if (auto cvt = trans.getSrc().getDefiningOp<ConvertLayoutOp>()) {
      newInnerCvtOrder = getOrderForMemory(cvt.getSrc().getType());
    }

    // For element types other than f16/bf16 the order is dictated by which
    // dot operand the allocation feeds: {0, 1} for operand 0, {1, 0} for
    // operand 1. NOTE(review): presumably because the MMA can only consume
    // transposed shared operands for 16-bit float types — confirm.
    auto srcElemTy = allocType.getElementType();
    if (!srcElemTy.isF16() && !srcElemTy.isBF16()) {
      if (allocOp.getResult() == dot->getOperand(0)) {
        newInnerCvtOrder = {0, 1};
      } else if (allocOp.getResult() == dot->getOperand(1)) {
        newInnerCvtOrder = {1, 0};
      }
    }

    // Build the pre-transpose shared encoding; the CTA layout is the
    // original allocation's CTA layout permuted by [1, 0].
    auto ctx = getContext();
    auto newCTALayout =
        permuteCTALayout(ctx, allocEncoding.getCTALayout(), {1, 0});
    auto newInnerEnc = NVMMASharedEncodingAttr::get(
        ctx, srcTy.getShape(), newInnerCvtOrder, newCTALayout,
        srcTy.getElementType(), allocEncoding.getFp4Padded());

    // Allocate the un-transposed tensor and expose it through a
    // memdesc_trans view that replaces the original allocation.
    MemDescType innerTy =
        MemDescType::get(srcTy.getShape(), srcTy.getElementType(), newInnerEnc,
                         allocType.getMemorySpace());
    auto newAlloc = rewriter.create<LocalAllocOp>(allocOp.getLoc(), innerTy,
                                                  trans.getSrc());
    rewriter.replaceOpWithNewOp<MemDescTransOp>(allocOp, newAlloc,
                                                ArrayRef<int32_t>({1, 0}));
    return success();
  }
};
// Moves a tensor reshape behind the shared-memory allocation.
//
// Rewrites
//
//   local_alloc(tt.reshape(%x))   ->   memdesc_reshape(local_alloc(%x))
//
// The result keeps the original allocation's memdesc type; the type of the
// new inner allocation is obtained from MemDescReshapeOp::inferReturnTypes.
class ReshapeMemDesc : public OpRewritePattern<LocalAllocOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(LocalAllocOp allocOp,
                                PatternRewriter &rewriter) const override {
    // Only allocations initialized from a value can be matched.
    if (!allocOp.getSrc())
      return failure();

    auto reshapeOp = allocOp.getSrc().getDefiningOp<ReshapeOp>();
    if (!reshapeOp)
      return failure();

    MemDescType allocType = allocOp.getType();
    RankedTensorType srcTy = reshapeOp.getSrc().getType();
    auto srcShape = srcTy.getShape();

    // Ask the reshape op's type inference for the memdesc type obtained by
    // reshaping `allocType` to the pre-reshape tensor shape; that is the type
    // of the new inner allocation. Give up if inference rejects it.
    MemDescType innerTy;
    if (failed(MemDescReshapeOp::inferReturnTypes(
            getContext(), allocOp.getLoc(), allocType, srcShape, innerTy)))
      return failure();

    auto newAlloc = rewriter.create<LocalAllocOp>(allocOp.getLoc(), innerTy,
                                                  reshapeOp.getSrc());
    rewriter.replaceOpWithNewOp<MemDescReshapeOp>(allocOp, allocType,
                                                  newAlloc);
    return success();
  }
};
// Lets a scaled MMAv5 op read its scales directly from shared memory.
//
// For each scale operand (A and B) of a TCGen5MMAScaledOp whose type carries
// a TensorMemoryScalesEncodingAttr, looks through the producer chain
//
//   local_load
//     -> reshape to (blockMN / 128, numScales / 4, 32, 4, 4)
//     -> trans with order (0, 3, 2, 1, 4)
//     -> reshape to (blockMN, numScales)
//     -> tmem_alloc
//
// (convert_layout ops between the steps are skipped) and, when found and
// compatible, replaces the operand with the memdesc read by the local_load.
class UseShmemForScales
    : public OpRewritePattern<triton::nvidia_gpu::TCGen5MMAScaledOp> {
public:
  using OpRewritePattern<
      triton::nvidia_gpu::TCGen5MMAScaledOp>::OpRewritePattern;

  // Succeeds if at least one of the two scale operands was rewritten.
  LogicalResult matchAndRewrite(triton::nvidia_gpu::TCGen5MMAScaledOp mmaOp,
                                PatternRewriter &rewriter) const override {
    auto aScale = mmaOp.getAScale();
    auto bScale = mmaOp.getBScale();
    LogicalResult ret = failure();
    if (aScale && isa<triton::nvidia_gpu::TensorMemoryScalesEncodingAttr>(
                      aScale.getType().getEncoding())) {
      if (rewriteOperand(mmaOp.getAScaleMutable(), rewriter).succeeded())
        ret = success();
    }
    if (bScale && isa<triton::nvidia_gpu::TensorMemoryScalesEncodingAttr>(
                      bScale.getType().getEncoding())) {
      if (rewriteOperand(mmaOp.getBScaleMutable(), rewriter).succeeded())
        ret = success();
    }
    return ret;
  }

private:
  // Tries to replace `opOperand` (a scale operand defined by a tmem_alloc)
  // with the shared-memory descriptor the scales were originally loaded from.
  // Returns failure, leaving the IR untouched, if the producer chain does not
  // match or the shared-memory type is not compatible.
  LogicalResult rewriteOperand(OpOperand &opOperand,
                               PatternRewriter &rewriter) const {
    auto src = cast<TypedValue<MemDescType>>(opOperand.get());
    auto tmemAlloc = src.getDefiningOp<triton::nvidia_gpu::TMEMAllocOp>();
    if (!tmemAlloc) {
      return failure();
    }
    auto dstType = tmemAlloc.getResult().getType();

    // A tmem_alloc without an initializer has no producer chain to inspect.
    if (!tmemAlloc.getSrc()) {
      return failure();
    }

    // NOTE(review): indexing [0] and [1] assumes the scale memdesc has rank
    // >= 2 — confirm scale types are always 2D.
    auto scale2DShape = dstType.getShape();
    auto blockMN = scale2DShape[0];
    auto numScales = scale2DShape[1];
    const SmallVector<int> transposeOrder{0, 3, 2, 1, 4};
    const SmallVector<int64_t> reshape5DShape{blockMN / 128, numScales / 4, 32,
                                              4, 4};

    // Walk the chain backwards from the tmem_alloc.
    auto reshapeOp2D = getNextOp<triton::ReshapeOp>(tmemAlloc.getSrc());
    if (!reshapeOp2D ||
        reshapeOp2D.getResult().getType().getShape() != scale2DShape) {
      return failure();
    }

    auto transOp = getNextOp<triton::TransOp>(reshapeOp2D.getSrc());
    if (!transOp || transOp.getOrder() != ArrayRef<int>(transposeOrder)) {
      return failure();
    }

    auto reshapeOp5D = getNextOp<triton::ReshapeOp>(transOp.getSrc());
    if (!reshapeOp5D || reshapeOp5D.getResult().getType().getShape() !=
                            ArrayRef<int64_t>(reshape5DShape)) {
      return failure();
    }

    auto localLoad = getNextOp<triton::gpu::LocalLoadOp>(reshapeOp5D.getSrc());
    if (!localLoad) {
      return failure();
    }

    // Detect whether the shared buffer was filled from a descriptor load;
    // that relaxes the compatibility check below.
    auto localAlloc = getNextOp<LocalAllocOp>(localLoad.getSrc());
    bool usesTMAload =
        (localAlloc && localAlloc.getSrc() &&
         (getNextOp<DescriptorLoadOp>(localAlloc.getSrc()) != nullptr));
    if (!isTmemCopyCompatible(localLoad.getSrc().getType(), usesTMAload))
      return failure();

    // The operand's owner is mutated in place, so the change has to be
    // reported to the rewriter instead of being applied behind its back.
    rewriter.modifyOpInPlace(opOperand.getOwner(), [&]() {
      opOperand.assign(localLoad.getSrc());
    });
    return success();
  }

  // Returns the op of type `Op` defining `op`, looking through any number of
  // convert_layout ops; null if the defining op is of a different kind (or
  // the value has no defining op).
  template <typename Op> Op getNextOp(Value op) const {
    while (auto cvtOp = op.getDefiningOp<ConvertLayoutOp>()) {
      op = cvtOp.getSrc();
    }
    return op.getDefiningOp<Op>();
  }

  // Decides whether a shared-memory scale buffer of type `scaleType` may be
  // used as the MMA scale operand. Requires isInnermostContiguous(scaleType,
  // 512); beyond that, buffers filled by a descriptor load are always
  // accepted, while others must be rank 2 with an innermost dimension whose
  // size in bits is a multiple of 32 * 128.
  bool isTmemCopyCompatible(triton::gpu::MemDescType scaleType,
                            bool usesTMAload) const {
    if (!isInnermostContiguous(scaleType, 512))
      return false;

    if (usesTMAload) {
      return true;
    }

    if (scaleType.getRank() != 2) {
      return false;
    }

    auto elemBits = scaleType.getElementType().getIntOrFloatBitWidth();
    auto innerMostBits =
        scaleType.getDimSize(scaleType.getRank() - 1) * elemBits;
    return innerMostBits % (32 * 128) == 0;
  }
};
}
#define GEN_PASS_DEF_TRITONGPUOPTIMIZEDOTOPERANDS
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
// Module pass that first canonicalizes the module and then greedily applies
// the dot-operand rewrite patterns defined in this file together with the
// ConvertLayoutOp canonicalization patterns.
class TritonGPUOptimizeDotOperandsPass
    : public impl::TritonGPUOptimizeDotOperandsBase<
          TritonGPUOptimizeDotOperandsPass> {
public:
  using impl::TritonGPUOptimizeDotOperandsBase<
      TritonGPUOptimizeDotOperandsPass>::TritonGPUOptimizeDotOperandsBase;

  void runOnOperation() override {
    ModuleOp moduleOp = getOperation();

    // Step 1: run the canonicalizer as a nested pipeline; abort on failure.
    OpPassManager canonicalizePipeline;
    canonicalizePipeline.addPass(mlir::createCanonicalizerPass());
    if (failed(runPipeline(canonicalizePipeline, moduleOp))) {
      signalPassFailure();
      return;
    }

    // Step 2: collect the rewrite patterns (same registration order as
    // before) and apply them to a fixed point.
    MLIRContext *ctx = &getContext();
    mlir::RewritePatternSet patternSet(ctx);
    patternSet.add<SwizzleShmemConvert, FuseTransMMAV3Plus, ReshapeMemDesc,
                   UseShmemForScales>(ctx);
    ConvertLayoutOp::getCanonicalizationPatterns(patternSet, ctx);

    if (failed(applyPatternsGreedily(moduleOp, std::move(patternSet)))) {
      signalPassFailure();
    }
  }
};
}