#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/Utility.h"
#include "triton/Conversion/MLIRTypes.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "triton/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Tools/LayoutUtils.h"
#include "triton/Tools/StrUtil.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
namespace mlir {
namespace triton {
namespace gpu {
namespace {
static int getMMAVersionSafe(int computeCapability, DotOp op) {
SmallVector<int> versionsSupported;
if (computeCapability < 75) {
versionsSupported = {1};
} else if (computeCapability < 90) {
versionsSupported = {2};
} else if (computeCapability < 100) {
versionsSupported = {3, 2};
} else if (computeCapability < 110) {
versionsSupported = {5, 2};
} else if (computeCapability < 130) {
versionsSupported = {2};
} else {
assert(false && "computeCapability not supported");
}
for (int baseVersion : versionsSupported) {
if (supportMMA(op, baseVersion))
return baseVersion;
if (baseVersion == 3) {
auto remark = op.emitRemark()
<< "MMA version 3 acceleration not applied due to "
"unsupported shapes or data types.";
remark.attachNote() << "Target compute capability (" << computeCapability
<< ") supports MMA v3.";
}
if (baseVersion == 5) {
auto remark = op.emitRemark()
<< "MMA version 5 acceleration not applied due to "
"unsupported shapes or data types.";
remark.attachNote() << "Target compute capability (" << computeCapability
<< ") supports MMA v5.";
}
}
return 0;
}
/// Computes the warps-per-CTA tiling used for an MMA v2 dot result.
///
/// \param dotOp    the dot being converted.
/// \param shape    per-CTA shape of the dot result.
/// \param numWarps number of warps to distribute.
/// \returns one entry per result dimension.
SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
                                     int numWarps) {
  auto rank = shape.size();
  // Rank-3 results put every warp on the leading dimension.
  if (rank == 3)
    return {(unsigned)numWarps, 1, 1};

  // Only look at ops in the same region as the dot, and never through a
  // transpose.
  auto sameRegionNoTrans = [&dotOp](Operation *op) {
    if (isa<TransOp>(op))
      return false;
    return op->getParentRegion() == dotOp->getParentRegion();
  };
  auto slices =
      multiRootGetSlice(dotOp, {sameRegionNoTrans}, {sameRegionNoTrans});

  bool sawChainedDot = false;
  for (Operation *op : slices) {
    auto otherDot = dyn_cast<DotOp>(op);
    if (!otherDot || op == dotOp.getOperation())
      continue;
    auto resTy = otherDot.getResult().getType();
    if (resTy.getRank() != rank)
      continue;
    // A chained dot that already has an MMA encoding dictates the tiling.
    if (auto mmaEncoding =
            dyn_cast<NvidiaMmaEncodingAttr>(resTy.getEncoding()))
      return to_vector(mmaEncoding.getWarpsPerCTA());
    sawChainedDot = true;
  }

  // Chained dots without an MMA encoding: put all warps along the larger
  // dimension.
  if (sawChainedDot) {
    if (shape[0] >= shape[1])
      return {(unsigned)numWarps, 1};
    return {1, (unsigned)numWarps};
  }

  assert(rank == 2);
  SmallVector<int64_t> shapePerWarp = {16, 8};
  SmallVector<int64_t> warps = {1, 1};
  SmallVector<int64_t> reps = {ceil(shape[0], shapePerWarp[0]),
                               ceil(shape[1], shapePerWarp[1])};
  // Keep doubling the warp count along whichever dimension still has the most
  // repetitions until all warps are used.
  while (product(warps) < numWarps) {
    if (reps[0] < reps[1]) {
      warps[1] *= 2;
      reps[1] /= 2;
      continue;
    }
    warps[0] *= 2;
    // reps[0] is never reduced below 1.
    if (reps[0] != 1)
      reps[0] /= 2;
  }
  return {(unsigned)warps[0], (unsigned)warps[1]};
}
/// Computes the warps-per-CTA tiling used for an MMA v3 dot result.
///
/// If any op in the forward slice of the dot result implements
/// `DotOpInterface`, all warps are placed along dimension 0. Otherwise the
/// tiling starts at {4, 1} and is doubled until it covers `numWarps`.
SmallVector<unsigned, 2>
warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
               const SmallVector<unsigned, 3> &instrShape) {
  SetVector<Operation *> forwardSlice;
  mlir::getForwardSlice(dotOp.getResult(), &forwardSlice);
  const bool feedsAnotherDot =
      llvm::any_of(forwardSlice, [](Operation *user) {
        return isa<mlir::triton::DotOpInterface>(user);
      });
  if (feedsAnotherDot)
    return {(unsigned)numWarps, 1};

  SmallVector<unsigned, 2> ret = {4, 1};
  SmallVector<int64_t, 2> shapePerWarp = {16, instrShape[1]};
  // Grow along dimension 0 while the result is taller than what the current
  // warps cover; otherwise grow along dimension 1.
  while (ret[0] * ret[1] < numWarps) {
    if (shape[0] > shapePerWarp[0] * ret[0])
      ret[0] *= 2;
    else
      ret[1] *= 2;
  }
  return ret;
}
/// Creates a `LocalAllocOp` holding dot operand `v` with an
/// `NVMMASharedEncodingAttr` layout and returns it.
///
/// \param v                the operand value; a defining `ConvertLayoutOp` is
///                         looked through (one level only).
/// \param rewriter         rewriter used to create the allocation; its
///                         insertion point is restored on return.
/// \param opIdx            operand index (0 for A, 1 for B).
/// \param allowTranspose   when false, the memory order is forced to {1, 0}
///                         for operand 0 and {0, 1} for operand 1.
/// \param isMMAv5Fp4Padded forwarded to the shared encoding.
/// \param forceTranspose   when the order is forced, swap its two entries.
/// \param op               optional op used to emit a warning when the chosen
///                         order differs from the operand's own order.
static Value
getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, int opIdx,
                          bool allowTranspose, bool isMMAv5Fp4Padded = false,
                          bool forceTranspose = false,
                          Operation *op = nullptr) {
  OpBuilder::InsertionGuard guard(rewriter);

  // Allocate from the source of a layout conversion rather than its result.
  Value src = v;
  if (auto cvtOp = v.getDefiningOp<ConvertLayoutOp>())
    src = cvtOp.getSrc();
  auto srcType = cast<RankedTensorType>(src.getType());
  assert(srcType.getEncoding() && "unexpected tensor type");

  auto order = getOrderForMemory(srcType);
  llvm::SmallVector<unsigned> newOrder = order;
  if (!allowTranspose) {
    const unsigned leading = (opIdx == 1) ? 0u : 1u;
    newOrder = {leading, 1u - leading};
    if (forceTranspose)
      std::swap(newOrder[0], newOrder[1]);
  }

  const bool orderChanged = newOrder != order;
  if (orderChanged && op) {
    op->emitWarning("Warning: Forcing a different order [")
        << newOrder[0] << ", " << newOrder[1]
        << "] on SMEM than the register order for the operand " << opIdx
        << ". Registers will be transposed before SMEM store and the pipelined "
           "load for this operand will be disabled, so poor performance is "
           "expected. Recommendation: consider transposing the operand in "
           "global "
           "memory to remove the need to transpose the tensor in registers.";
  }

  MLIRContext *ctx = srcType.getContext();
  Attribute SharedMemorySpace = SharedMemorySpaceAttr::get(ctx);
  auto CTALayout = getCTALayout(srcType.getEncoding());
  auto sharedLayout = NVMMASharedEncodingAttr::get(
      ctx, srcType.getShape(), newOrder, CTALayout, srcType.getElementType(),
      isMMAv5Fp4Padded);
  auto memDescType =
      MemDescType::get(srcType.getShape(), srcType.getElementType(),
                       sharedLayout, SharedMemorySpace);

  // The allocation is placed right after the value it copies.
  rewriter.setInsertionPointAfterValue(src);
  return rewriter.create<LocalAllocOp>(src.getLoc(), memDescType, src);
}
/// Creates a `LocalAllocOp` for the scale tensor `arg`, placed right after the
/// definition of `arg`, and returns it.
///
/// The rewriter's insertion point is restored on return.
static LocalAllocOp
getSharedMemoryScale(Value arg, mlir::PatternRewriter &rewriter, Location loc) {
  OpBuilder::InsertionGuard guard(rewriter);
  auto scaleType = cast<RankedTensorType>(arg.getType());
  assert(scaleType.getEncoding() && "unexpected tensor type");
  // NOTE(review): the computed order is not used below; the call is kept as in
  // the original code.
  auto newOrder = getOrderForMemory(scaleType);

  MLIRContext *ctx = scaleType.getContext();
  Type elemType = scaleType.getElementType();
  Attribute SharedMemorySpace = SharedMemorySpaceAttr::get(ctx);
  auto CTALayout = getCTALayout(scaleType.getEncoding());
  // NOTE(review): the literal arguments (0, false, false) presumably select an
  // unswizzled, non-transposed, non-padded layout -- confirm against
  // NVMMASharedEncodingAttr::get.
  auto sharedLayout = NVMMASharedEncodingAttr::get(
      ctx, 0,
      false,
      elemType.getIntOrFloatBitWidth(),
      false, CTALayout);
  auto memDescType = MemDescType::get(scaleType.getShape(), elemType,
                                      sharedLayout, SharedMemorySpace);
  rewriter.setInsertionPointAfterValue(arg);
  return rewriter.create<LocalAllocOp>(loc, memDescType, arg);
}
/// Dispatches to the version-specific warps-per-tile computation.
///
/// Only MMA versions 2 and 3 are handled; any other version asserts and, in
/// builds without assertions, yields {0, 0}.
SmallVector<unsigned, 3>
getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
                int numWarps, const SmallVector<unsigned, 3> &instrShape) {
  if (version == 2)
    return warpsPerTileV2(dotOp, shape, numWarps);
  if (version == 3)
    return warpsPerTileV3(dotOp, shape, numWarps, instrShape);
  assert(false && "not supported version");
  return {0, 0};
}
/// Backward-slice filter: returns true for ops the slice is allowed to
/// traverse, namely memory-effect-free elementwise ops, view ops, and a fixed
/// set of load / conversion ops.
static bool bwdFilter(Operation *op) {
  if (op->hasTrait<OpTrait::Elementwise>() && isMemoryEffectFree(op))
    return true;
  if (isView(op))
    return true;
  return isa<Fp4ToFpOp, LoadOp, DescriptorLoadOp, BroadcastOp,
             ConvertLayoutOp>(op);
}
/// Estimates the bit width of the data that `x` was originally produced from.
///
/// The backward slice of `x` (restricted by `bwdFilter`, block arguments
/// omitted) is inspected:
///  - if it contains an `Fp4ToFpOp`, the result is 4;
///  - otherwise the result is the minimum of `x`'s element width and the
///    element widths of all ranked-tensor loads in the slice, halved if the
///    slice contains a `JoinOp`.
static int computeOrigBitWidth(Value x) {
  mlir::BackwardSliceOptions sliceOptions;
  sliceOptions.omitBlockArguments = true;
  sliceOptions.filter = bwdFilter;
  SetVector<Operation *> producers;
  (void)getBackwardSlice(x, &producers, sliceOptions);

  const bool hasFp4Upcast = llvm::any_of(
      producers, [](Operation *op) { return isa<Fp4ToFpOp>(op); });
  if (hasFp4Upcast)
    return 4;

  int width = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
  bool sawJoin = false;
  for (Operation *op : producers) {
    if (isa<JoinOp>(op))
      sawJoin = true;
    if (!isa<LoadOp, DescriptorLoadOp>(op))
      continue;
    auto tensorTy = dyn_cast<RankedTensorType>(op->getResultTypes().front());
    if (tensorTy)
      width = std::min<int>(width, tensorTy.getElementTypeBitWidth());
  }
  // The halving is applied once, after the minimum has been taken.
  if (sawJoin)
    width /= 2;
  return width;
}
// Rewrite pattern that converts a DotOp whose result does not yet carry an
// NvidiaMmaEncodingAttr into an MMA-encoded dot.
//
// For MMA version 3 the dot is replaced by a WarpGroupDotOp whose B operand
// (and A operand, when it comes from a load or block argument) lives in shared
// memory. Otherwise a new DotOp is created whose operands are converted to a
// DotOperandEncodingAttr. In both cases the accumulator is converted to the new
// MMA encoding and the result is converted back to the original type.
class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
// Target compute capability, as passed to the constructor.
int computeCapability;
// NOTE(review): not referenced anywhere in the visible code -- confirm whether
// this member is still needed.
mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;
public:
BlockedToMMA(mlir::MLIRContext *context, int computeCapability, int benefit)
: OpRewritePattern<DotOp>(context, benefit),
computeCapability(computeCapability) {}
// Returns failure() (leaving the IR untouched) when the target or the dot is
// not eligible; otherwise performs the rewrite and returns success().
mlir::LogicalResult
matchAndRewrite(triton::DotOp dotOp,
mlir::PatternRewriter &rewriter) const override {
// Below compute capability 70 the pattern silently does not apply.
if (computeCapability < 70)
return failure();
// For 70 <= capability < 80 the pattern also does not apply, but tells the
// user why via a remark.
if (computeCapability < 80) {
dotOp.emitRemark()
<< "Dot op using MMA for compute capability " << computeCapability
<< " has been deprecated. It falls back to the FMA path.";
return failure();
}
// Skip dots without an encoding and dots that already use an MMA encoding.
if (!dotOp.getType().getEncoding() ||
mlir::isa<NvidiaMmaEncodingAttr>(dotOp.getType().getEncoding()))
return failure();
int numWarps = lookupNumWarps(dotOp);
int versionMajor = getMMAVersionSafe(computeCapability, dotOp);
// Versions outside [1, 3] (including 0 = unsupported and 5) are not handled
// by this pattern.
if (!(versionMajor >= 1 && versionMajor <= 3))
return failure();
bool aFromLoad = comesFromLoadOrBlockArg(dotOp.getA());
// NOTE(review): bFromLoad is computed but not read below.
bool bFromLoad = comesFromLoadOrBlockArg(dotOp.getB());
auto origDotOp = dotOp;
Value a = dotOp.getA();
Value b = dotOp.getB();
auto oldAType = cast<RankedTensorType>(a.getType());
auto oldBType = cast<RankedTensorType>(b.getType());
auto oldRetType = cast<RankedTensorType>(dotOp.getType());
// f64 operands or results are only converted on compute capability 80 and 90.
if ((oldAType.getElementType().isF64() ||
oldBType.getElementType().isF64() ||
oldRetType.getElementType().isF64()) &&
!(computeCapability == 80 || computeCapability == 90)) {
return failure();
}
auto CTALayout = getCTALayout(oldRetType.getEncoding());
auto retShapePerCTA = getShapePerCTA(oldRetType);
auto instrShape = mmaVersionToInstrShape(
versionMajor, retShapePerCTA, oldAType.getElementType(), numWarps);
assert(versionMajor == 2 || versionMajor == 3);
int versionMinor = computeCapability == 75 ? 1 : 0;
auto warpsPerTile = getWarpsPerTile(dotOp, retShapePerCTA, versionMajor,
numWarps, instrShape);
// Build the MMA encoding for the result and convert the accumulator to it.
auto mmaEnc = NvidiaMmaEncodingAttr::get(
oldRetType.getContext(), versionMajor, versionMinor, warpsPerTile,
CTALayout, instrShape);
auto newRetType = oldRetType.cloneWithEncoding(mmaEnc);
auto oldAcc = dotOp.getOperand(2);
auto newAcc =
rewriter.create<ConvertLayoutOp>(oldAcc.getLoc(), newRetType, oldAcc);
// Converts `v` to a dot-operand encoding whose parent is the new MMA encoding.
// When `bitwidth` is positive an integer type of that width is passed as the
// last argument of the encoding; otherwise the type of `v` is passed.
auto getDotOperand = [&](Value v, int opIdx, int bitwidth) {
auto minType =
bitwidth > 0 ? rewriter.getIntegerType(bitwidth) : v.getType();
auto vType = cast<RankedTensorType>(v.getType());
auto newVEncoding = DotOperandEncodingAttr::get(
v.getContext(), opIdx, newRetType.getEncoding(), minType);
auto newVType = vType.cloneWithEncoding(newVEncoding);
return rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);
};
Operation *newDot = nullptr;
if (versionMajor == 3) {
auto eltType = dotOp.getA().getType().getElementType();
// Only f16 / bf16 operands may keep their own memory order.
bool allowTranspose = eltType.isF16() || eltType.isBF16();
// A stays in registers unless it comes from a load or block argument.
if (!aFromLoad) {
int bitwidth = getElementTypeOrSelf(a).getIntOrFloatBitWidth();
a = getDotOperand(a, 0, bitwidth);
} else {
a = getSharedMemoryMMAOperand(a, rewriter, 0, allowTranspose,
false,
false, dotOp);
}
// B always goes through shared memory.
b = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose,
false,
false, dotOp);
newDot = rewriter.create<triton::nvidia_gpu::WarpGroupDotOp>(
dotOp.getLoc(), newRetType, a, b, newAcc, nullptr,
dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc(), false);
} else {
// Both operands use the smaller of the two original bit widths.
int minBitwidth =
std::min(computeOrigBitWidth(a), computeOrigBitWidth(b));
a = getDotOperand(a, 0, minBitwidth);
b = getDotOperand(b, 1, minBitwidth);
newDot = rewriter.create<DotOp>(dotOp.getLoc(), newRetType, a, b, newAcc,
dotOp.getInputPrecision(),
dotOp.getMaxNumImpreciseAcc());
}
// Convert the new result back to the original type so users are unaffected.
rewriter.replaceOpWithNewOp<ConvertLayoutOp>(origDotOp, origDotOp.getType(),
newDot->getResult(0));
return success();
}
};
/// Wraps the linear layout returned by `getScaleTMEMStoreLinearLayout` for
/// `type` and `numWarps` in a `LinearEncodingAttr`.
static Attribute getTmemScales(RankedTensorType type, unsigned numWarps) {
  auto storeLayout = getScaleTMEMStoreLinearLayout(type, numWarps);
  return triton::gpu::LinearEncodingAttr::get(type.getContext(), storeLayout);
}
/// Returns true when `dotOp` qualifies for the two-CTA MMA mode.
///
/// All of the following must hold:
///  - the result's CTA split is exactly {2, 1};
///  - the per-CTA result shape is at least 64 x 32;
///  - operand B, looked at through any chain of `ConvertLayoutOp`s, is
///    produced by a load, descriptor load or descriptor gather.
static bool canUseTwoCTAs(triton::DotOp dotOp) {
  RankedTensorType retType = dotOp.getType();
  auto retShapePerCTA = getShapePerCTA(retType);
  SmallVector<unsigned> splitNum = getCTASplitNum(retType.getEncoding());

  const bool splitIsTwoByOne =
      splitNum.size() == 2 && splitNum[0] == 2 && splitNum[1] == 1;
  if (!splitIsTwoByOne)
    return false;

  int m = retShapePerCTA[0];
  int n = retShapePerCTA[1];
  const bool bigEnough = m >= 64 && n >= 32;
  if (!bigEnough)
    return false;

  // Find the op that actually produces B.
  Value b = dotOp.getB();
  while (auto cvtOp = b.getDefiningOp<ConvertLayoutOp>())
    b = cvtOp.getSrc();
  Operation *producer = b.getDefiningOp();
  return llvm::isa_and_nonnull<triton::LoadOp, triton::DescriptorLoadOp,
                               triton::DescriptorGatherOp>(producer);
}
/// Returns a copy of `layout` whose CTA layout is replaced by `newCTALayout`.
///
/// Blocked encodings are rebuilt directly; slice encodings are rebuilt around
/// their (recursively updated) parent. Any other encoding is a fatal error.
static DistributedEncodingTrait
replaceCTALayout(DistributedEncodingTrait layout,
                 const triton::gpu::CTALayoutAttr &newCTALayout) {
  if (auto blocked = mlir::dyn_cast<BlockedEncodingAttr>(layout))
    return BlockedEncodingAttr::get(
        layout.getContext(), blocked.getSizePerThread(),
        blocked.getThreadsPerWarp(), blocked.getWarpsPerCTA(),
        blocked.getOrder(), newCTALayout);

  if (auto slice = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
    auto newParent = replaceCTALayout(slice.getParent(), newCTALayout);
    return SliceEncodingAttr::get(layout.getContext(), slice.getDim(),
                                  newParent);
  }

  llvm::report_fatal_error("not implemented");
  // Not reached; kept so every path syntactically returns a value.
  return layout;
}
/// Re-lays-out the load that produces operand `b` so that it uses a {1, 2}
/// CTA layout, and returns the load's (re-typed) result.
///
/// The load is found by looking through any chain of `ConvertLayoutOp`s. Its
/// ranked-tensor operands are converted to the new layout, its result type is
/// changed in place, and a `ConvertLayoutOp` back to the original type is
/// inserted after it; all other users of the load are redirected to that
/// conversion.
static Value splitBOperand(Value b, mlir::PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard guard(rewriter);
  MLIRContext *ctx = b.getContext();

  while (auto cvtOp = b.getDefiningOp<ConvertLayoutOp>())
    b = cvtOp.getSrc();
  Operation *producer = b.getDefiningOp();
  assert((isa<triton::LoadOp, triton::DescriptorLoadOp,
              triton::DescriptorGatherOp>(producer)) &&
         "expected LoadOp");

  RankedTensorType origBType = cast<RankedTensorType>(b.getType());
  auto origLayout = cast<DistributedEncodingTrait>(origBType.getEncoding());
  auto splitCTALayout =
      CTALayoutAttr::get(ctx, {1, 2}, {1, 2}, getCTAOrder(origLayout));
  Attribute splitLayout = replaceCTALayout(origLayout, splitCTALayout);

  // Convert every ranked-tensor operand of the load to the new layout.
  rewriter.setInsertionPoint(producer);
  for (OpOperand &use : producer->getOpOperands()) {
    Value oldOperand = use.get();
    auto operandType = dyn_cast<RankedTensorType>(oldOperand.getType());
    if (!operandType)
      continue;
    Value converted = rewriter.create<ConvertLayoutOp>(
        oldOperand.getLoc(), operandType.cloneWithEncoding(splitLayout),
        oldOperand);
    producer->setOperand(use.getOperandNumber(), converted);
  }

  // Re-type the load result in place.
  producer->getResult(0).setType(origBType.cloneWithEncoding(splitLayout));
  Value newB = producer->getResult(0);

  // Existing users keep seeing the original type through a conversion.
  rewriter.setInsertionPointAfter(producer);
  auto backToOrig =
      rewriter.create<ConvertLayoutOp>(b.getLoc(), origBType, newB);
  rewriter.replaceAllUsesExcept(newB, backToOrig.getResult(), backToOrig);
  return newB;
}
// Rewrite pattern that converts a DotOp into a TCGen5MMAOp when
// getMMAVersionSafe selects MMA version 5.
//
// Both operands are placed in shared memory, the accumulator is moved into a
// tensor-memory allocation, the MMA is issued, and the result is loaded back
// and converted to the original result type.
class BlockedToMMAv5 : public mlir::OpRewritePattern<DotOp> {
// Target compute capability, as passed to the constructor.
int computeCapability;
public:
BlockedToMMAv5(mlir::MLIRContext *context, int computeCapability, int benefit)
: OpRewritePattern<DotOp>(context, benefit),
computeCapability(computeCapability) {}
// Returns failure() without touching the IR when the dot is not eligible;
// otherwise performs the rewrite and returns success().
mlir::LogicalResult
matchAndRewrite(triton::DotOp dotOp,
mlir::PatternRewriter &rewriter) const override {
RankedTensorType oldRetType = dotOp.getType();
// Skip dots without an encoding and dots that already use an MMA encoding.
if (!oldRetType.getEncoding() ||
mlir::isa<NvidiaMmaEncodingAttr>(oldRetType.getEncoding()))
return failure();
auto retShapePerCTA = getShapePerCTA(oldRetType);
int numWarps = lookupNumWarps(dotOp);
auto CTALayout = getCTALayout(oldRetType.getEncoding());
int versionMajor = getMMAVersionSafe(computeCapability, dotOp);
if (versionMajor != 5)
return failure();
Location loc = dotOp.getLoc();
Value a = dotOp.getA();
Value b = dotOp.getB();
// Operands whose original width is 32 bits or more are only accepted when the
// requested input precision is TF32.
if (std::min(computeOrigBitWidth(a), computeOrigBitWidth(b)) >= 32 &&
dotOp.getInputPrecision() != InputPrecision::TF32)
return failure();
auto oldAType = dotOp.getA().getType();
// NOTE(review): oldBType is not read below.
auto oldBType = dotOp.getB().getType();
bool useTwoCTAs = canUseTwoCTAs(dotOp);
if (useTwoCTAs) {
// Re-lays-out the load feeding B; this mutates the IR, so it must come after
// every failure() above.
b = splitBOperand(b, rewriter);
}
// f32 operands are not allowed to keep their own memory order.
bool allowTranspose = !dotOp.getA().getType().getElementType().isF32();
a = getSharedMemoryMMAOperand(a, rewriter, 0, allowTranspose);
b = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose);
MLIRContext *context = dotOp->getContext();
auto instrShape = mmaVersionToInstrShape(
versionMajor, retShapePerCTA, oldAType.getElementType(), numWarps);
ArrayRef<unsigned> CTASplitNum = CTALayout.getCTASplitNum();
// Tensor-memory encoding for the accumulator, sized by the instruction shape.
Attribute accEncoding = triton::nvidia_gpu::TensorMemoryEncodingAttr::get(
context, instrShape[0], instrShape[1], true,
CTASplitNum[0], CTASplitNum[1]);
Attribute tensorMemorySpace =
triton::nvidia_gpu::TensorMemorySpaceAttr::get(context);
// NOTE(review): the trailing `true` presumably marks the descriptor as
// mutable -- confirm against MemDescType::get.
Type accMemDescType = triton::gpu::MemDescType::get(
oldRetType.getShape(), oldRetType.getElementType(), accEncoding,
tensorMemorySpace,
true);
// Register layout used for the accumulator before the tensor-memory
// allocation and after the tensor-memory load.
Attribute newDistributedEncoding = nvidia_gpu::getTmemCompatibleLayout(
instrShape[0], instrShape[1], oldRetType, numWarps);
auto newAccType = oldRetType.cloneWithEncoding(newDistributedEncoding);
Value cvtAcc =
rewriter.create<ConvertLayoutOp>(loc, newAccType, dotOp.getOperand(2));
auto tokType = rewriter.getType<AsyncTokenType>();
auto acc = rewriter.create<triton::nvidia_gpu::TMEMAllocOp>(
loc, accMemDescType, tokType, cvtAcc);
// i1 constant `true`, passed twice as the last two MMA operands.
auto vTrue = rewriter.create<arith::ConstantIntOp>(dotOp.getLoc(), 1, 1);
auto mma = rewriter.create<triton::nvidia_gpu::TCGen5MMAOp>(
loc, tokType, a, b, acc, acc.getToken(), vTrue,
vTrue);
mma.setTwoCtas(useTwoCTAs);
// Load the result, ordered after the MMA through its token.
auto ld = rewriter.create<triton::nvidia_gpu::TMEMLoadOp>(
loc, newAccType, tokType, acc, mma.getToken());
// Convert back to the original type so users are unaffected.
rewriter.replaceOpWithNewOp<ConvertLayoutOp>(dotOp, oldRetType, ld);
return success();
}
};
/// Routes the load that produces `scale` through shared memory.
///
/// \param scale    the scale operand of a scaled dot.
/// \param rewriter rewriter used to create the new ops; its insertion point
///                 is restored on return.
/// \returns `scale` unchanged when it is not (transitively) produced by a
///          load through reshape / trans / convert_layout ops, or when at
///          least one such op sits between the load and `scale` (in which
///          case that op now reads from the new local load). Otherwise the
///          newly created local load is returned.
Value addSmemStageToScaleLoad(Value scale, mlir::PatternRewriter &rewriter) {
  /*
    Rewrite load(scale) -> local_load(local_alloc(load(scale))).
    This function does not add anything to the final IR when num_stages > 1,
    but it makes it easy to apply TMEM copy rewriting later.

    Since scales are stored in TMEM for MMAv5 scaled dot, the scale loads do
    not need to go through SMEM. But in practice, the software pipeliner puts
    the scale loads into multi-buffered SMEM. At that point, the SMEM
    allocation created here is eliminated.
  */
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = scale.getDefiningOp();
  // Last reshape / trans / convert_layout op seen on the way up to the load.
  Operation *loadConsumer = nullptr;

  // Walk up the def chain until a load is found.
  while (true) {
    // `scale`, or the source of one of the looked-through ops, is a block
    // argument: there is no load to rewrite. This check must precede isa<>,
    // which must not be called on a null pointer.
    if (!op)
      return scale;
    if (isa<LoadOp, DescriptorLoadOp>(op))
      break;

    if (auto reshape = dyn_cast<ReshapeOp>(op)) {
      op = reshape.getSrc().getDefiningOp();
      loadConsumer = reshape;
    } else if (auto trans = dyn_cast<TransOp>(op)) {
      op = trans.getSrc().getDefiningOp();
      loadConsumer = trans;
    } else if (auto cvt = dyn_cast<ConvertLayoutOp>(op)) {
      op = cvt.getSrc().getDefiningOp();
      loadConsumer = cvt;
    } else {
      // Unrecognized producer: leave the IR untouched.
      return scale;
    }
  }

  auto scaleAfterLoad = op->getResult(0);
  auto scaleSmemAlloc =
      getSharedMemoryScale(scaleAfterLoad, rewriter, op->getLoc());

  rewriter.setInsertionPointAfterValue(scaleSmemAlloc);
  auto localLoad = rewriter.create<LocalLoadOp>(
      op->getLoc(), scaleAfterLoad.getType(), scaleSmemAlloc);

  // Every user of the load except the allocation itself now reads the value
  // that went through shared memory.
  rewriter.replaceAllUsesExcept(scaleAfterLoad, localLoad.getResult(),
                                scaleSmemAlloc);

  // With intermediate ops, `scale` is still the right value to hand back; it
  // is now fed by the local load.
  if (loadConsumer)
    return scale;
  return localLoad;
}
// Rewrite pattern that converts a DotScaledOp with both scales present into a
// TCGen5MMAScaledOp.
//
// The operands are placed in shared memory, the accumulator and both scales
// are moved into tensor-memory allocations, the scaled MMA is issued, and the
// result is loaded back and converted to the original result type.
class ScaledBlockedToMMAv5
: public mlir::OpRewritePattern<triton::DotScaledOp> {
// Target compute capability, as passed to the constructor.
int computeCapability;
public:
ScaledBlockedToMMAv5(mlir::MLIRContext *context, int computeCapability,
int benefit)
: mlir::OpRewritePattern<triton::DotScaledOp>(context, benefit),
computeCapability(computeCapability) {}
// Returns failure() without touching the IR when the op is not eligible;
// otherwise performs the rewrite and returns success().
mlir::LogicalResult
matchAndRewrite(triton::DotScaledOp dotOp,
mlir::PatternRewriter &rewriter) const override {
RankedTensorType oldRetType = dotOp.getType();
// Skip ops without an encoding and ops that already use an MMA encoding.
if (!oldRetType.getEncoding() ||
mlir::isa<NvidiaMmaEncodingAttr>(oldRetType.getEncoding()))
return failure();
// Both scales are required.
if (dotOp.getAScale() == nullptr || dotOp.getBScale() == nullptr) {
return failure();
}
auto retShapePerCTA = getShapePerCTA(oldRetType);
int numWarps = lookupNumWarps(dotOp);
auto CTALayout = getCTALayout(oldRetType.getEncoding());
// Only compute capabilities 100..109 are accepted.
if ((computeCapability) / 10 != 10)
return failure();
if (numWarps != 4 && numWarps != 8)
return failure();
// The per-CTA result must be at least 128 x 8.
if (retShapePerCTA[0] < 128 || retShapePerCTA[1] < 8)
return failure();
Location loc = dotOp.getLoc();
Value a = dotOp.getA();
Value b = dotOp.getB();
// NOTE(review): oldAType and oldBType are not read below.
auto oldAType = a.getType();
auto oldBType = b.getType();
// An operand is "mixed-precision fp4" when it is E2M1 and the other operand
// has a different element type.
bool IsAMixedPrecFp4 = false;
bool IsBMixedPrecFp4 = false;
bool isAFP4 = dotOp.getAElemType() == ScaleDotElemType::E2M1;
bool isBFP4 = dotOp.getBElemType() == ScaleDotElemType::E2M1;
if (dotOp.getAElemType() != dotOp.getBElemType()) {
if (isAFP4)
IsAMixedPrecFp4 = true;
else if (isBFP4)
IsBMixedPrecFp4 = true;
}
// The padded fp4 shared layout is used for mixed-precision fp4 operands and
// for operands whose K-pack flag is not set.
bool isMMAv5Fp4PaddedLhs = IsAMixedPrecFp4 || !dotOp.getLhsKPack();
bool isMMAv5Fp4PaddedRhs = IsBMixedPrecFp4 || !dotOp.getRhsKPack();
// fp4 operands may not keep their own memory order; the forced order is
// additionally swapped when the K-pack flag is not set.
a = getSharedMemoryMMAOperand(a, rewriter, 0,
!isAFP4,
isMMAv5Fp4PaddedLhs,
!dotOp.getLhsKPack(),
dotOp);
b = getSharedMemoryMMAOperand(b, rewriter, 1,
!isBFP4,
isMMAv5Fp4PaddedRhs,
!dotOp.getRhsKPack(),
dotOp);
MLIRContext *context = dotOp->getContext();
// Accumulator block: M is fixed at 128, N is the per-CTA N capped at 256.
unsigned m = 128;
unsigned n = retShapePerCTA[1] >= 256 ? 256 : retShapePerCTA[1];
ArrayRef<unsigned> CTASplitNum = CTALayout.getCTASplitNum();
Attribute accEncoding = triton::nvidia_gpu::TensorMemoryEncodingAttr::get(
context, m, n, true, CTASplitNum[0], CTASplitNum[1]);
Attribute tensorMemorySpace =
triton::nvidia_gpu::TensorMemorySpaceAttr::get(context);
// NOTE(review): the trailing `true` presumably marks the descriptor as
// mutable -- confirm against MemDescType::get.
Type accMemDescType = triton::gpu::MemDescType::get(
oldRetType.getShape(), oldRetType.getElementType(), accEncoding,
tensorMemorySpace,
true);
// Register layout used for the accumulator before the tensor-memory
// allocation and after the tensor-memory load.
Attribute newDistributedEncoding =
nvidia_gpu::getTmemCompatibleLayout(m, n, oldRetType, numWarps);
auto newAccType = oldRetType.cloneWithEncoding(newDistributedEncoding);
Value cvtAcc =
rewriter.create<ConvertLayoutOp>(loc, newAccType, dotOp.getOperand(2));
auto tokType = rewriter.getType<AsyncTokenType>();
auto acc = rewriter.create<triton::nvidia_gpu::TMEMAllocOp>(
loc, accMemDescType, tokType, cvtAcc);
RankedTensorType oldScaleAType = dotOp.getAScale().getType();
RankedTensorType oldScaleBType = dotOp.getBScale().getType();
// Tensor-memory descriptor types for the two scales; unlike the accumulator
// these pass `false` as the last MemDescType argument.
Attribute scaleEncoding =
triton::nvidia_gpu::TensorMemoryScalesEncodingAttr::get(
context, CTASplitNum[0], CTASplitNum[1]);
Type scaleAType = triton::gpu::MemDescType::get(
oldScaleAType.getShape(), oldScaleAType.getElementType(), scaleEncoding,
tensorMemorySpace,
false);
Type scaleBType = triton::gpu::MemDescType::get(
oldScaleBType.getShape(), oldScaleBType.getElementType(), scaleEncoding,
tensorMemorySpace,
false);
// Register layouts the scales are converted to before being stored.
Attribute scaleALayout = getTmemScales(oldScaleAType, numWarps);
Attribute scaleBLayout = getTmemScales(oldScaleBType, numWarps);
RankedTensorType newScaleAType =
oldScaleAType.cloneWithEncoding(scaleALayout);
RankedTensorType newScaleBType =
oldScaleBType.cloneWithEncoding(scaleBLayout);
// Route the scale loads through shared memory where the producer chain is
// recognized.
auto lhsScale = addSmemStageToScaleLoad(dotOp.getAScale(), rewriter);
auto rhsScale = addSmemStageToScaleLoad(dotOp.getBScale(), rewriter);
Value newScaleA =
rewriter.create<ConvertLayoutOp>(loc, newScaleAType, lhsScale);
Value newScaleB =
rewriter.create<ConvertLayoutOp>(loc, newScaleBType, rhsScale);
// The scale allocations are created without a token type.
auto scaleA = rewriter.create<triton::nvidia_gpu::TMEMAllocOp>(
loc, scaleAType, Type(), newScaleA);
auto scaleB = rewriter.create<triton::nvidia_gpu::TMEMAllocOp>(
loc, scaleBType, Type(), newScaleB);
// i1 constant `true`, passed twice as the last two MMA operands.
auto vTrue = rewriter.create<arith::ConstantIntOp>(dotOp.getLoc(), 1, 1);
auto mmaOp = rewriter.create<triton::nvidia_gpu::TCGen5MMAScaledOp>(
loc, tokType, a, b, acc.getResult(), acc.getToken(), scaleA.getResult(),
scaleB.getResult(), dotOp.getAElemType(), dotOp.getBElemType(),
vTrue, vTrue);
// Load the result, ordered after the MMA through its token.
auto ld = rewriter.create<triton::nvidia_gpu::TMEMLoadOp>(
loc, newAccType, tokType, acc, mmaOp.getToken());
// Convert back to the original type so users are unaffected.
rewriter.replaceOpWithNewOp<ConvertLayoutOp>(dotOp, oldRetType, ld);
return success();
}
};
}
/// Creates an op that converts the elements of tensor `operand` to
/// `promotedType`, keeping its shape and encoding.
///
/// Float8 element types are converted with `FpToFpOp`; everything else uses
/// `arith::ExtFOp`.
static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
                            Type promotedType) {
  auto operandType = cast<RankedTensorType>(operand.getType());
  Type tensorPromotedType = operandType.cloneWith(std::nullopt, promotedType);
  if (type::isFloat8(operandType.getElementType()))
    return builder.create<FpToFpOp>(loc, tensorPromotedType, operand);
  return builder.create<arith::ExtFOp>(loc, tensorPromotedType, operand);
}
/// Returns true for the compute capabilities (89 and 120) on which MMA v2
/// accepts fp8 operands directly.
static bool mmav2SupportsFp8Operands(int computeCapability) {
  switch (computeCapability) {
  case 89:
  case 120:
    return true;
  default:
    return false;
  }
}
/// Promotes the A and B operands of every `DotOp` in `mod` whose operand
/// element type cannot be used as-is.
///
///  - With an MMA result encoding: fp8 (E5M2 / E4M3FN) operands are promoted
///    to f16, unless the target's MMA v2 accepts fp8 or the layout is Hopper.
///  - Without an MMA result encoding: operands are promoted to the result
///    element type when it differs from A's element type.
static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
  mod.walk([=](DotOp dotOp) -> void {
    auto D = dotOp.getD();
    OpBuilder builder(dotOp);
    Type AElType = dotOp.getA().getType().getElementType();

    Type promoteType;
    if (auto mmaLayout =
            dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding())) {
      // Non-fp8 operands never need promotion on the MMA path.
      if (!llvm::isa<Float8E5M2Type, Float8E4M3FNType>(AElType))
        return;
      // fp8 is consumed natively on these targets / layouts.
      if (mmav2SupportsFp8Operands(computeCapability) || mmaLayout.isHopper())
        return;
      promoteType = builder.getF16Type();
    } else {
      Type DElType = D.getType().getElementType();
      if (AElType == DElType)
        return;
      promoteType = DElType;
    }

    Location loc = dotOp.getLoc();
    Value promotedA = promoteOperand(builder, loc, dotOp.getA(), promoteType);
    Value promotedB = promoteOperand(builder, loc, dotOp.getB(), promoteType);
    dotOp.setOperand(0, promotedA);
    dotOp.setOperand(1, promotedB);
  });
}
/// Replaces `dotOp` with the transposed computation: a `DotScaledOp` on the
/// transposed operands with A and B (and their scales and element types)
/// swapped, followed by a transpose of the result. The original op is erased.
static void transposeDotOp(DotScaledOp dotOp) {
  OpBuilder builder(dotOp);
  std::array<int, 2> transOrder = {1, 0};
  // Creates a 2-D transpose of `v` at `v`'s location.
  auto transpose = [&](Value v) {
    return builder.create<TransOp>(v.getLoc(), v, transOrder);
  };

  // The three transposes are created in the order A, B, C.
  Value lhsTransposed = transpose(dotOp.getA());
  Value rhsTransposed = transpose(dotOp.getB());
  Value cTransposed = transpose(dotOp.getC());

  Value result = builder.create<DotScaledOp>(
      dotOp.getLoc(), cTransposed.getType(), rhsTransposed, lhsTransposed,
      cTransposed, dotOp.getBScale(), dotOp.getAScale(), dotOp.getBElemType(),
      dotOp.getAElemType(), dotOp.getFastMath());

  Operation *transposedResult = transpose(result);
  dotOp.replaceAllUsesWith(transposedResult);
  dotOp.erase();
}
/// Transposes every `DotScaledOp` in `m` that has a B scale but no A scale.
///
/// Candidates are collected first and rewritten afterwards, because the
/// rewrite erases the op being visited.
static void transposeDots(ModuleOp m) {
  SmallVector<DotScaledOp> pending;
  m.walk([&pending](DotScaledOp dotOp) -> void {
    const bool hasAScale = dotOp.getAScale() != nullptr;
    const bool hasBScale = dotOp.getBScale() != nullptr;
    if (!hasAScale && hasBScale)
      pending.push_back(dotOp);
  });
  for (DotScaledOp dotOp : pending)
    transposeDotOp(dotOp);
}
#define GEN_PASS_DEF_TRITONGPUACCELERATEMATMUL
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
/// Pass that rewrites dot operations in a module to use MMA encodings.
///
/// Steps, in order:
///  1. transpose scaled dots that only have a B scale;
///  2. greedily apply the MMA conversion patterns (the v5 patterns have a
///     higher benefit than the default ones);
///  3. promote operands of dots whose element types need it.
class TritonGPUAccelerateMatmulPass
    : public impl::TritonGPUAccelerateMatmulBase<
          TritonGPUAccelerateMatmulPass> {
public:
  using impl::TritonGPUAccelerateMatmulBase<
      TritonGPUAccelerateMatmulPass>::TritonGPUAccelerateMatmulBase;

  void runOnOperation() override {
    ModuleOp m = getOperation();
    MLIRContext *context = &getContext();
    auto computeCapability = getNVIDIAComputeCapability(m);

    transposeDots(m);

    constexpr int benefitDefault = 1;
    constexpr int benefitMMAv5 = 10;
    mlir::RewritePatternSet patterns(context);
    patterns.add<BlockedToMMA>(context, computeCapability, benefitDefault);
    populateDecomposeScaledBlockedPatterns(patterns, benefitDefault);
    patterns.add<BlockedToMMAv5, ScaledBlockedToMMAv5>(
        context, computeCapability, benefitMMAv5);

    if (failed(applyPatternsGreedily(m, std::move(patterns))))
      signalPassFailure();

    decomposeMixedModeDotOp(m, computeCapability);
  }
};
}
}
}