#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
#include "triton/Tools/Sys/GetEnv.hpp"
namespace ttg = mlir::triton::gpu;
namespace mlir {
namespace triton {
namespace nvidia_gpu {
#define GEN_PASS_DEF_TRITONNVIDIAGPUPROMOTELHSTOTMEMPASS
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
namespace {
/// Picks a TMEM-compatible distributed layout for the LHS (A operand) source
/// tensor of a tcgen05 MMA op.
///
/// M is taken from the accumulator's TMEM encoding (`getBlockM()`), N from
/// dimension 1 of the A operand's shape. The layout itself is produced by
/// `getTmemCompatibleLayout` using the warp count looked up for the op.
///
/// Precondition: the accumulator D carries a `TensorMemoryEncodingAttr`
/// (asserted via `cast`).
///
/// @param tcGen5MMAOp the MMA op whose A operand is being promoted.
/// @param srcType     the tensor type that will be stored into TMEM.
/// @return the layout attribute to convert the source tensor to.
template <class MMAOpTy>
Attribute getLHSTMemLayout(MMAOpTy tcGen5MMAOp, RankedTensorType srcType) {
  int numWarps = ttg::lookupNumWarps(tcGen5MMAOp);
  // `cast` instead of an unchecked `dyn_cast`: a missing TMEM encoding on the
  // accumulator is a caller bug and should trip an assertion rather than
  // dereference a null attribute.
  auto accTmemEncoding = cast<TensorMemoryEncodingAttr>(
      tcGen5MMAOp.getD().getType().getEncoding());
  auto lhsShape = tcGen5MMAOp.getA().getType().getShape();
  int M = accTmemEncoding.getBlockM();
  int N = lhsShape[1];
  return getTmemCompatibleLayout(M, N, srcType, numWarps);
}
/// Rewrite pattern that feeds the A (LHS) operand of a tcgen05 MMA op from a
/// newly created `TMEMAllocOp` instead of the `ttg::LocalAllocOp` it was
/// originally produced by.
///
/// The pattern matches only when all of the following hold:
///   * A is defined by a `ttg::LocalAllocOp` that has a source tensor and
///     lives in the same region as the MMA op;
///   * the accumulator D has a `TensorMemoryEncodingAttr` encoding;
///   * A's element bit width is 16 or 32.
/// If the source tensor's layout is not TMEM-compatible, a
/// `ttg::ConvertLayoutOp` to a compatible layout is inserted. The exception is
/// a source that comes from a load or block argument: there the conversion
/// happens only when the ALLOW_LHS_TMEM_LAYOUT_CONVERSION env var is set.
/// Otherwise the pattern fails.
template <class MMAOpTy> class LHSToTMem : public OpRewritePattern<MMAOpTy> {
public:
  using OpRewritePattern<MMAOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(MMAOpTy tcGen5MMAOp,
                                PatternRewriter &rewriter) const override {
    MLIRContext *context = tcGen5MMAOp->getContext();
    Location loc = tcGen5MMAOp.getLoc();
    auto lhs = tcGen5MMAOp.getA();
    auto localAllocOp = lhs.template getDefiningOp<ttg::LocalAllocOp>();
    if (!localAllocOp)
      return failure();
    // Only promote when the alloc and the MMA share a parent region
    // (presumably to keep the TMEM allocation's live range local — confirm).
    if (localAllocOp->getParentRegion() != tcGen5MMAOp->getParentRegion())
      return failure();
    // A local_alloc without a source tensor has no value that could be copied
    // into tensor memory; bail out instead of dereferencing a null Value.
    Value src = localAllocOp.getSrc();
    if (!src)
      return failure();
    auto srcType = cast<RankedTensorType>(src.getType());
    Attribute srcLayout = srcType.getEncoding();
    // The LHS TMEM encoding reuses the accumulator's blockM, so D must
    // already be in tensor memory.
    auto accTMemEncoding = dyn_cast<TensorMemoryEncodingAttr>(
        tcGen5MMAOp.getD().getType().getEncoding());
    if (!accTMemEncoding)
      return failure();
    ArrayRef<unsigned> CTASplitNum =
        triton::gpu::getCTALayout(srcLayout).getCTASplitNum();
    unsigned elemBitWidth =
        lhs.getType().getElementType().getIntOrFloatBitWidth();
    if (elemBitWidth != 16 && elemBitWidth != 32)
      return failure();
    // 16-bit elements use the packed TMEM form; 32-bit use unpacked.
    bool unpacked = elemBitWidth != 16;
    auto aTMemEncoding = TensorMemoryEncodingAttr::get(
        context, accTMemEncoding.getBlockM(), lhs.getType().getShape()[1],
        unpacked, CTASplitNum[0], CTASplitNum[1]);
    Attribute tensorMemorySpace =
        triton::nvidia_gpu::TensorMemorySpaceAttr::get(context);
    ttg::MemDescType lhsMemDescType = ttg::MemDescType::get(
        lhs.getType().getShape(), lhs.getType().getElementType(), aTMemEncoding,
        tensorMemorySpace, false);
    bool layoutTmemCompatible =
        isDistributedLayoutTMemCompatible(tcGen5MMAOp, srcType, lhsMemDescType);
    Attribute newLayout = srcLayout;
    if (!layoutTmemCompatible) {
      // Converting a value that comes straight from a load or block argument
      // is opt-in via ALLOW_LHS_TMEM_LAYOUT_CONVERSION.
      if (comesFromLoadOrBlockArg(src) &&
          !triton::tools::getBoolEnv("ALLOW_LHS_TMEM_LAYOUT_CONVERSION"))
        return failure();
      newLayout = getLHSTMemLayout(tcGen5MMAOp, srcType);
    }
    // Insert right after the local_alloc: it dominates the MMA (the MMA uses
    // its result), and `src` is defined before it.
    rewriter.setInsertionPointAfter(localAllocOp);
    if (newLayout != srcLayout) {
      auto newTy = srcType.cloneWithEncoding(newLayout);
      src = rewriter.create<ttg::ConvertLayoutOp>(loc, newTy, src);
    }
    Value tMemAlloc = rewriter.create<TMEMAllocOp>(loc, lhsMemDescType, src);
    // In-place operand updates inside a pattern must go through the rewriter
    // so the greedy driver is notified that the MMA op changed.
    rewriter.modifyOpInPlace(tcGen5MMAOp, [&] {
      tcGen5MMAOp.getAMutable().assign(tMemAlloc);
    });
    return success();
  }
};
}
/// Module-level pass that runs the `LHSToTMem` rewrite for both
/// `TCGen5MMAOp` and `TCGen5MMAScaledOp`. The patterns are applied greedily
/// over the whole module.
class TritonNvidiaGPUPromoteLHSToTMemPass
    : public impl::TritonNvidiaGPUPromoteLHSToTMemPassBase<
          TritonNvidiaGPUPromoteLHSToTMemPass> {
public:
  using TritonNvidiaGPUPromoteLHSToTMemPassBase<
      TritonNvidiaGPUPromoteLHSToTMemPass>::
      TritonNvidiaGPUPromoteLHSToTMemPassBase;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    ModuleOp moduleOp = getOperation();

    RewritePatternSet promotionPatterns(ctx);
    promotionPatterns
        .add<LHSToTMem<TCGen5MMAOp>, LHSToTMem<TCGen5MMAScaledOp>>(ctx);

    // Propagate a failure reported by the greedy rewrite driver.
    if (failed(applyPatternsGreedily(moduleOp, std::move(promotionPatterns))))
      signalPassFailure();
  }
};
}
}
}