#include "TritonAMDGPUToLLVM/TargetUtils.h"
#include "TritonAMDGPUTransforms/MfmaGroup.h"
#include "TritonAMDGPUTransforms/Passes.h"
#include "TritonAMDGPUTransforms/WmmaGroup.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.h"
#include "triton/Tools/LayoutUtils.h"
#include "llvm/ADT/TypeSwitch.h"
namespace tt = mlir::triton;
namespace ttg = mlir::triton::gpu;
using ::mlir::LLVM::AMD::hasTransInDefChain;
using ::mlir::LLVM::AMD::isChainDotHead;
using ::mlir::LLVM::AMD::isChainDotTail;
using ::mlir::LLVM::AMD::scaleDotElemTypeToMLIRType;
using mlir::triton::gpu::chooseScaledMfmaScaleLayout;
namespace mlir {
namespace {
using triton::AMD::ISAFamily;
/// Returns the MFMA instruction-set revision for a CDNA ISA family
/// (CDNA1 -> 1 ... CDNA4 -> 4), or 0 for any other family.
int getMfmaVersion(ISAFamily isaFamily) {
  if (isaFamily == ISAFamily::CDNA1)
    return 1;
  if (isaFamily == ISAFamily::CDNA2)
    return 2;
  if (isaFamily == ISAFamily::CDNA3)
    return 3;
  if (isaFamily == ISAFamily::CDNA4)
    return 4;
  // Not a CDNA family: no MFMA revision.
  return 0;
}
/// Returns the WMMA revision implied by the architecture string: 1 for gfx11
/// targets, 2 for gfx12 targets, and 0 when neither substring is present.
int getWmmaVersion(StringRef archGen) {
  int version = 0;
  // gfx11 is checked first, matching the original precedence.
  if (archGen.contains("gfx11"))
    version = 1;
  else if (archGen.contains("gfx12"))
    version = 2;
  return version;
}
/// Maps an MLIR low-precision float element type to the matching
/// ScaleDotElemType enumerator (fp8 E4M3/E5M2, fp6 E3M2/E2M3, fp4 E2M1).
/// Any other type yields failure.
FailureOr<ScaleDotElemType> mlirTypeToScaledElemType(Type type) {
  if (isa<Float8E4M3FNType>(type))
    return ScaleDotElemType::E4M3;
  if (isa<Float8E5M2Type>(type))
    return ScaleDotElemType::E5M2;
  if (isa<Float6E3M2FNType>(type))
    return ScaleDotElemType::E3M2;
  if (isa<Float6E2M3FNType>(type))
    return ScaleDotElemType::E2M3;
  if (isa<Float4E2M1FNType>(type))
    return ScaleDotElemType::E2M1;
  return failure();
}
/// Distributes `numWarps` warps over the (M, N) dimensions of a dot result.
///
/// - 3-D (batched) results: all warps go on the batch dimension.
/// - Head of a dot chain: all warps go along M.
/// - Tail of a dot chain: as many warps along M as there are M tiles (capped
///   at numWarps), the remaining factor along N.
/// - Otherwise: starting from 1x1, repeatedly double M or N until the grid
///   holds numWarps warps, preferring M while it is at least twice as
///   "long" (in tiles per warp) as N and still has tiles left. If N ends up
///   over-covered, the grid is returned transposed.
///
/// `shapePerWarp` is the (M, N) size one warp covers per instruction.
SmallVector<unsigned, 3>
warpsPerTile(Operation *dotOp, ArrayRef<int64_t> shape, int numWarps,
             std::pair<int64_t, int64_t> shapePerWarp) {
  if (shape.size() == 3)
    return {static_cast<unsigned>(numWarps), 1, 1};

  auto chainDot = cast<tt::DotOpInterface>(dotOp);
  const bool chainHead = isChainDotHead(chainDot);
  const bool chainTail = isChainDotTail(chainDot);

  if (chainHead)
    return {static_cast<unsigned>(numWarps), 1};

  if (chainTail) {
    const unsigned warpsM = static_cast<unsigned>(std::min(
        static_cast<int64_t>(numWarps),
        static_cast<int64_t>(llvm::divideCeil(shape[0], shapePerWarp.first))));
    return {warpsM, numWarps / warpsM};
  }

  const int64_t extentM = shape[0];
  const int64_t extentN = shape[1];
  SmallVector<unsigned, 3> grid = {1, 1};
  while (grid[0] * grid[1] < numWarps) {
    const bool preferM = extentM / (shapePerWarp.first * 2) / grid[0] >=
                         extentN / shapePerWarp.second / grid[1];
    // Grow M only while it still has more tiles than warps assigned to it.
    if (preferM && grid[0] < extentM / shapePerWarp.first)
      grid[0] *= 2;
    else
      grid[1] *= 2;
  }

  if (grid[1] * shapePerWarp.second > extentN)
    return {grid[1], grid[0]};
  return grid;
}
/// Warp distribution for MFMA dots; currently delegates to the shared
/// warpsPerTile heuristic unchanged.
SmallVector<unsigned, 3>
warpsPerTileMFMA(Operation *dotOp, ArrayRef<int64_t> shape, int numWarps,
                 std::pair<int64_t, int64_t> shapePerWarp) {
  SmallVector<unsigned, 3> distribution =
      warpsPerTile(dotOp, shape, numWarps, shapePerWarp);
  return distribution;
}
/// Warp distribution for WMMA dots; currently delegates to the shared
/// warpsPerTile heuristic unchanged.
SmallVector<unsigned, 3>
warpsPerTileWMMA(Operation *dotOp, ArrayRef<int64_t> shape, int numWarps,
                 std::pair<int64_t, int64_t> shapePerWarp) {
  SmallVector<unsigned, 3> distribution =
      warpsPerTile(dotOp, shape, numWarps, shapePerWarp);
  return distribution;
}
/// Picks an MFMA intrinsic for a dot whose accumulator/result type is `cType`.
///
/// If `enforcedNonKDim` is non-zero it is used for both the M and N
/// instruction size. Otherwise the size is derived from the smaller of the
/// result's M/N extents:
/// - >= 32: 32x32 (16x16 for f64 operands);
/// - >= 16: 16x16;
/// - >= 4: 64x4 or 4x64 when one side is at least 64.
///
/// Returns failure when no instruction size fits the shape, or when
/// MfmaIntrinsic::selectFor finds no intrinsic. It also fails when an
/// automatically chosen instruction's K does not divide `inputKSize`.
FailureOr<MfmaIntrinsic>
chooseMfmaInstruction(Location loc, int mfmaVersion, RankedTensorType cType,
                      Type aElemType, Type bElemType, int inputKSize,
                      int enforcedNonKDim, bool withScale, bool allowXF32) {
  auto resShape = cType.getShape();
  auto rank = resShape.size();
  int64_t M = resShape[rank - 2];
  int64_t N = resShape[rank - 1];
  unsigned mDim = 0;
  unsigned nDim = 0;
  if (enforcedNonKDim != 0) {
    mDim = nDim = enforcedNonKDim;
  } else {
    // Keep the extent in 64 bits; narrowing to int could wrap.
    int64_t minSize = std::min(M, N);
    if (minSize >= 32) {
      mDim = nDim = 32;
      // f64 operands use the 16x16 instruction size here.
      if (aElemType.isF64() || bElemType.isF64()) {
        mDim = nDim = 16;
      }
    } else if (minSize >= 16) {
      mDim = nDim = 16;
    } else if (minSize >= 4) {
      // Skinny results: use the 64x4 / 4x64 shapes when one side is long.
      if (M >= 64) {
        mDim = 64;
        nDim = 4;
      } else if (N >= 64) {
        mDim = 4;
        nDim = 64;
      }
    }
  }
  // No instruction size fits this shape. Fail the match (as
  // chooseWmmaInstruction does) instead of querying the intrinsic table with
  // a zero-sized tile and later dividing by a zero mDim/nDim.
  if (mDim == 0 || nDim == 0)
    return failure();

  FailureOr<MfmaIntrinsic> maybeMfmaIntrinsic =
      MfmaIntrinsic::selectFor(loc, mfmaVersion, mDim, nDim, inputKSize,
                               aElemType, bElemType, withScale, allowXF32);
  if (failed(maybeMfmaIntrinsic))
    return failure();

  unsigned kDim = maybeMfmaIntrinsic->kDim;
  assert(kDim != 0);
  assert(enforcedNonKDim != 0 || (M % mDim == 0 && N % nDim == 0));
  // With an automatically chosen size the K extent must be a multiple of
  // the instruction's K.
  if (enforcedNonKDim == 0 && inputKSize % kDim != 0)
    return failure();
  return maybeMfmaIntrinsic;
}
/// Picks an MFMA intrinsic for a `tt.dot`, using A's last dimension as the
/// K extent. xf32 instructions are only allowed on MFMA v3 when the dot
/// requests TF32 input precision.
FailureOr<MfmaIntrinsic> chooseMfmaInstruction(tt::DotOp dot, int mfmaVersion,
                                               int nonKDim,
                                               bool withScale = false) {
  const bool allowXF32 =
      mfmaVersion == 3 && dot.getInputPrecision() == InputPrecision::TF32;
  auto lhsTensorTy = dot.getA().getType();
  auto rhsTensorTy = dot.getB().getType();
  const int64_t kExtent = lhsTensorTy.getShape().back();
  return chooseMfmaInstruction(dot.getLoc(), mfmaVersion, dot.getC().getType(),
                               lhsTensorTy.getElementType(),
                               rhsTensorTy.getElementType(), kExtent, nonKDim,
                               withScale, allowXF32);
}
/// Picks a scaled MFMA intrinsic for a `tt.dot_scaled` whose operands are
/// consumed in their native low-precision formats.
FailureOr<MfmaIntrinsic> chooseMfmaInstruction(tt::DotScaledOp dot,
                                               int mfmaVersion, int nonKDim) {
  MLIRContext *context = dot.getContext();
  // An E2M1 LHS packed along K stores two fp4 values per element, so the
  // logical K extent is twice the stored last dimension.
  const bool lhsPackedFp4 =
      dot.getAElemType() == ScaleDotElemType::E2M1 && dot.getLhsKPack();
  int64_t kExtent = dot.getA().getType().getShape().back();
  if (lhsPackedFp4)
    kExtent *= 2;
  Type lhsElemTy = scaleDotElemTypeToMLIRType(context, dot.getAElemType());
  Type rhsElemTy = scaleDotElemTypeToMLIRType(context, dot.getBElemType());
  return chooseMfmaInstruction(dot.getLoc(), mfmaVersion, dot.getC().getType(),
                               lhsElemTy, rhsElemTy, kExtent, nonKDim,
                               /*withScale=*/true, /*allowXF32=*/false);
}
/// Picks an unscaled MFMA intrinsic for a `tt.dot_scaled` whose operands will
/// be upcast to f16 (`useFp16`) or bf16 before the matrix multiply.
FailureOr<MfmaIntrinsic> chooseMfmaInstruction(tt::DotScaledOp dot,
                                               int mfmaVersion, int nonKDim,
                                               bool useFp16) {
  Builder typeBuilder(dot.getContext());
  Type upcastTy;
  if (useFp16)
    upcastTy = typeBuilder.getF16Type();
  else
    upcastTy = typeBuilder.getBF16Type();
  const auto kExtent = dot.getA().getType().getShape().back();
  return chooseMfmaInstruction(dot.getLoc(), mfmaVersion, dot.getC().getType(),
                               upcastTy, upcastTy, kExtent, nonKDim,
                               /*withScale=*/false, /*allowXF32=*/false);
}
using OperandTypesVector = SmallVector<Type, 4>;
/// Chooses, among `applicableTypes`, the (A, B, C, D) element types that the
/// dot's current operand element types can be converted to most cheaply.
///
/// Per-operand conversion cost:
/// - 0 when the type is unchanged;
/// - 1 for a non-narrowing conversion within ints or within floats
///   (including unsigned int -> wider non-unsigned int);
/// - a prohibitive cost otherwise (int <-> float, narrowing, signed/signless
///   -> unsigned).
///
/// A candidate needing any prohibitive conversion is never selected; ties go
/// to the earliest candidate. Returns an empty vector when nothing is
/// acceptable, including when `applicableTypes` is empty.
OperandTypesVector
selectMatrixCoreOperandTypes(tt::DotOp dot,
                             ArrayRef<OperandTypesVector> applicableTypes) {
  // Guard: front() below would be undefined on an empty list.
  if (applicableTypes.empty())
    return {};

  SmallVector<Value> dotOperands = {dot.getA(), dot.getB(), dot.getC(),
                                    dot.getD()};
  OperandTypesVector initElemTypes;
  llvm::transform(dotOperands, std::back_inserter(initElemTypes), [](Value v) {
    return cast<RankedTensorType>(v.getType()).getElementType();
  });

  // Disqualifying cost. It is divided by the operand count so that summing it
  // over every operand cannot overflow int32.
  int maxConvertCost =
      std::numeric_limits<int32_t>::max() / applicableTypes.front().size();
  auto calcConvertCost = [&](Type fromTy, Type toTy) -> int32_t {
    if (fromTy == toTy)
      return 0;
    if (fromTy.isIntOrIndex() != toTy.isIntOrIndex())
      return maxConvertCost;
    if (fromTy.isIntOrIndex() && toTy.isIntOrIndex() &&
        fromTy.isUnsignedInteger() != toTy.isUnsignedInteger())
      return fromTy.isUnsignedInteger() && fromTy.getIntOrFloatBitWidth() <
                                               toTy.getIntOrFloatBitWidth()
                 ? 1
                 : maxConvertCost;
    return fromTy.getIntOrFloatBitWidth() <= toTy.getIntOrFloatBitWidth()
               ? 1
               : maxConvertCost;
  };

  int minCost = maxConvertCost;
  OperandTypesVector optTypes;
  // Iterate by reference: each candidate is a SmallVector, so copying it on
  // every iteration was pure overhead.
  for (const OperandTypesVector &types : applicableTypes) {
    assert(types.size() == initElemTypes.size());
    int accumulatedConvertCost = 0;
    for (size_t i = 0, e = initElemTypes.size(); i < e; ++i)
      accumulatedConvertCost += calcConvertCost(initElemTypes[i], types[i]);
    if (accumulatedConvertCost < minCost) {
      minCost = accumulatedConvertCost;
      optTypes = types;
    }
  }
  return optTypes;
}
/// Returns the (A, B, C, D) element types a WMMA instruction of `version`
/// should use for `dot`. The choice is made by selectMatrixCoreOperandTypes
/// among the combinations the version supports.
///
/// All versions accept f16/bf16 with f32 accumulation and i8 with i32
/// accumulation. Version 2 additionally accepts every pairing of the fp8
/// E4M3FN and E5M2 formats with f32 accumulation.
OperandTypesVector getOperandTypesForWmmaOp(PatternRewriter &rewriter,
                                            tt::DotOp dot, int version) {
  Type f16Ty = rewriter.getF16Type();
  Type f32Ty = rewriter.getF32Type();
  Type bf16Ty = rewriter.getBF16Type();
  Type i8Ty = rewriter.getIntegerType(8);
  Type i32Ty = rewriter.getIntegerType(32);

  // Candidate order matters: the earliest of equally cheap candidates wins.
  SmallVector<OperandTypesVector> candidates;
  candidates.push_back({f16Ty, f16Ty, f32Ty, f32Ty});
  candidates.push_back({bf16Ty, bf16Ty, f32Ty, f32Ty});
  candidates.push_back({i8Ty, i8Ty, i32Ty, i32Ty});

  if (version == 2) {
    Type e4m3Ty = rewriter.getType<Float8E4M3FNType>();
    Type e5m2Ty = rewriter.getType<Float8E5M2Type>();
    // Yields (e4m3,e4m3), (e4m3,e5m2), (e5m2,e4m3), (e5m2,e5m2).
    for (Type lhsTy : {e4m3Ty, e5m2Ty})
      for (Type rhsTy : {e4m3Ty, e5m2Ty})
        candidates.push_back({lhsTy, rhsTy, f32Ty, f32Ty});
  }
  return selectMatrixCoreOperandTypes(dot, candidates);
}
/// Converts `value` to layout `newEncoding`, then casts its elements to
/// `newElemType`.
///
/// The layout conversion is always emitted first, so the cast (if any) runs
/// on the tensor in its new layout. Integer casts use bitcast (same width),
/// trunci (narrowing), extsi (explicitly signed source) or extui (otherwise).
/// Float casts use extf / truncf for f16 <-> f32 and tt.fp_to_fp for every
/// other pair. Casting between integer and float is not supported (asserted).
Value convertAndCastTensor(PatternRewriter &rewriter, Value value,
                           Attribute newEncoding, Type newElemType) {
  assert(newElemType.isIntOrFloat());
  auto loc = value.getLoc();
  auto oldType = cast<RankedTensorType>(value.getType());
  auto oldElemType = oldType.getElementType();
  assert(oldElemType.isIntOrFloat());
  assert(oldElemType.isIntOrIndex() == newElemType.isIntOrIndex());

  auto convertedType =
      RankedTensorType::get(oldType.getShape(), oldElemType, newEncoding);
  Value convertedTensor =
      rewriter.create<ttg::ConvertLayoutOp>(loc, convertedType, value);
  if (newElemType == oldElemType)
    return convertedTensor;

  Type castedType = convertedType.cloneWith(std::nullopt, newElemType);
  Value castedTensor;
  if (newElemType.isIntOrIndex()) {
    unsigned oldWidth = oldElemType.getIntOrFloatBitWidth();
    unsigned newWidth = newElemType.getIntOrFloatBitWidth();
    if (oldWidth == newWidth)
      // Bitcast to the *new* element type. Previously the old tensor type was
      // passed here, so the result silently kept the old element type.
      castedTensor =
          rewriter.create<arith::BitcastOp>(loc, castedType, convertedTensor);
    else if (oldWidth > newWidth)
      castedTensor =
          rewriter.create<arith::TruncIOp>(loc, castedType, convertedTensor);
    else if (oldElemType.isSignedInteger())
      castedTensor =
          rewriter.create<arith::ExtSIOp>(loc, castedType, convertedTensor);
    else
      castedTensor =
          rewriter.create<arith::ExtUIOp>(loc, castedType, convertedTensor);
  } else {
    if (oldElemType.isF16() && newElemType.isF32())
      castedTensor =
          rewriter.create<arith::ExtFOp>(loc, castedType, convertedTensor);
    else if (oldElemType.isF32() && newElemType.isF16())
      castedTensor =
          rewriter.create<arith::TruncFOp>(loc, castedType, convertedTensor);
    else
      castedTensor =
          rewriter.create<tt::FpToFpOp>(loc, castedType, convertedTensor);
  }
  return castedTensor;
}
/// Rewrites a `tt.dot` whose result uses a blocked layout into a dot on the
/// AMD MFMA layout.
///
/// The accumulator and operands are converted to the MFMA / dot-operand
/// layouts and cast to the element types of the chosen intrinsic. The new dot
/// is created, and its result is converted back to the original layout and
/// element type.
///
/// On MFMA v4 with F8F6F4 operands a scaled instruction is tried first. It is
/// emitted as a `tt.dot_scaled` without scale tensors. If no scaled
/// instruction fits, the unscaled intrinsic is used instead.
class BlockedToMFMA : public OpRewritePattern<tt::DotOp> {
  int mfmaVersion;
  int nonKDim;
  int kPack;

public:
  BlockedToMFMA(MLIRContext *context, int mfmaVersion, int nonKDim, int kPack,
                PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit), mfmaVersion(mfmaVersion),
        nonKDim(nonKDim), kPack(kPack) {}

  LogicalResult matchAndRewrite(tt::DotOp dotOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType oldRetType = dotOp.getType();
    // Single encoding check. The previous duplicate check without a
    // diagnostic was redundant.
    if (!isa_and_nonnull<ttg::BlockedEncodingAttr>(oldRetType.getEncoding()))
      return rewriter.notifyMatchFailure(
          dotOp, "expected blocked encoding result tensor");

    auto CTALayout = ttg::getCTALayout(oldRetType.getEncoding());
    auto retShape = oldRetType.getShape();
    int numWarps = ttg::lookupNumWarps(dotOp);

    Value a = dotOp.getA();
    Value b = dotOp.getB();
    auto oldAType = cast<RankedTensorType>(a.getType());
    auto oldBType = cast<RankedTensorType>(b.getType());
    auto ctx = oldAType.getContext();
    Type aElemType = oldAType.getElementType();
    Type bElemType = oldBType.getElementType();

    // MFMA v4 offers scaled instructions for F8F6F4 operands; try them first.
    bool withScale =
        mfmaVersion == 4 && isF8F6F4(aElemType) && isF8F6F4(bElemType);
    FailureOr<MfmaIntrinsic> mfmaInstr =
        chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim, withScale);
    if (failed(mfmaInstr)) {
      if (!withScale) {
        return failure();
      }
      // No scaled instruction fits: fall back to the unscaled intrinsic.
      mfmaInstr = chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim, false);
      if (failed(mfmaInstr))
        return failure();
      withScale = false;
    }

    auto mDim = mfmaInstr->mDim;
    auto nDim = mfmaInstr->nDim;
    auto kBase = mfmaInstr->kBase;
    auto warpsPerTile =
        warpsPerTileMFMA(dotOp, retShape, numWarps, {mDim, nDim});

    // Accumulate in i32 for integer dots, f64 for f64 dots, f32 otherwise.
    Type mfmaAccType;
    if (oldRetType.getElementType().isIntOrIndex())
      mfmaAccType = rewriter.getIntegerType(32);
    else if (oldRetType.getElementType().isF64())
      mfmaAccType = rewriter.getF64Type();
    else
      mfmaAccType = rewriter.getF32Type();

    // The transposed MFMA layout is used for every instruction shape except
    // 4x64.
    bool isTransposed = !(mDim == 4 && nDim == 64);
    auto aElemTy = mfmaInstr->aElementType;
    ttg::AMDMfmaEncodingAttr mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
        oldRetType.getContext(), mfmaVersion, warpsPerTile, mDim, nDim,
        isTransposed, CTALayout, mfmaAccType);

    auto oldAcc = dotOp.getC();
    auto newAcc = convertAndCastTensor(rewriter, oldAcc, mfmaEnc, mfmaAccType);

    // kWidth starts at the intrinsic's kBase. kPack multiplies it except for
    // chain-dot tails.
    auto kWidth = kBase;
    auto isDotChainTail = isChainDotTail(dotOp);
    if (!isDotChainTail)
      kWidth *= kPack;
    // 16-bit operands: chain tails use kWidth 4. If the def chain of
    // operand 1 contains a transpose, heads use 4 and tails use 8.
    auto is16BitElemTy = (aElemTy.isF16() || aElemTy.isBF16());
    if (is16BitElemTy && isDotChainTail) {
      kWidth = 4;
    }
    if (is16BitElemTy && hasTransInDefChain(dotOp, 1u)) {
      if (isChainDotHead(dotOp)) {
        kWidth = 4;
      } else if (isDotChainTail) {
        kWidth = 8;
      }
    }

    Value newDot;
    if (withScale) {
      auto aScaledElemTy = mlirTypeToScaledElemType(aElemType);
      auto bScaledElemTy = mlirTypeToScaledElemType(bElemType);
      if (failed(aScaledElemTy) || failed(bScaledElemTy))
        return failure();
      assert(kWidth == 32);
      auto newAEncoding =
          DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kWidth / 2);
      auto newBEncoding =
          DotOperandEncodingAttr::get(ctx, 1, mfmaEnc, kWidth / 2);
      a = convertAndCastTensor(rewriter, a, newAEncoding,
                               mfmaInstr->aElementType);
      b = convertAndCastTensor(rewriter, b, newBEncoding,
                               mfmaInstr->bElementType);
      // Scaled instruction without scale tensors (null scale values).
      newDot = rewriter.create<triton::DotScaledOp>(
          dotOp.getLoc(), newAcc.getType(), a, b, newAcc, Value(), Value(),
          aScaledElemTy.value(), bScaledElemTy.value(), false);
    } else {
      auto newAEncoding =
          ttg::DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kWidth);
      auto newBEncoding =
          ttg::DotOperandEncodingAttr::get(ctx, 1, mfmaEnc, kWidth);
      a = convertAndCastTensor(rewriter, a, newAEncoding,
                               mfmaInstr->aElementType);
      b = convertAndCastTensor(rewriter, b, newBEncoding,
                               mfmaInstr->bElementType);
      newDot = rewriter.create<tt::DotOp>(dotOp.getLoc(), newAcc.getType(), a,
                                          b, newAcc, dotOp.getInputPrecision(),
                                          dotOp.getMaxNumImpreciseAcc());
    }

    Value dotOutput =
        convertAndCastTensor(rewriter, newDot, oldRetType.getEncoding(),
                             oldRetType.getElementType());
    rewriter.replaceOp(dotOp, dotOutput);
    return success();
  }
};
// Rewrites a `tt.dot_scaled` with a blocked result layout into a plain
// `tt.dot` on the MFMA layout. Operands the MFMA cannot consume directly are
// upcast to bf16, or to fp16 when either operand is FP16. An operand that
// carries a scale is upcast through amdgpu::UpcastMXFPOp.
// Bails out unless:
// - both operands are packed along K;
// - the result is 2-D;
// - at most one operand has a scale;
// - both element types are E2M1 / E4M3 / E5M2 / BF16 / FP16.
class ScaledBlockedToMFMA final : public OpRewritePattern<triton::DotScaledOp> {
int mfmaVersion;
int nonKDim;
int kPack;
public:
ScaledBlockedToMFMA(MLIRContext *context, int mfmaVersion, int nonKDim,
int kPack, PatternBenefit benefit = 1)
: OpRewritePattern(context, benefit), mfmaVersion(mfmaVersion),
nonKDim(nonKDim), kPack(kPack) {}
LogicalResult matchAndRewrite(triton::DotScaledOp dotOp,
PatternRewriter &rewriter) const override {
// Only operands packed along K are handled by this pattern.
if (!dotOp.getLhsKPack() || !dotOp.getRhsKPack())
return failure();
using TensorValue = TypedValue<RankedTensorType>;
RankedTensorType oldRetType = dotOp.getType();
if (!isa_and_nonnull<BlockedEncodingAttr>(oldRetType.getEncoding()))
return rewriter.notifyMatchFailure(
dotOp, "expected blocked encoding result tensor");
unsigned rank = oldRetType.getRank();
if (rank == 3)
return rewriter.notifyMatchFailure(dotOp, "NYI: 3d case");
TensorValue a = dotOp.getA();
TensorValue b = dotOp.getB();
TensorValue aScale = dotOp.getAScale();
TensorValue bScale = dotOp.getBScale();
if (aScale && bScale)
return rewriter.notifyMatchFailure(dotOp, "NYI: both LHS and RHS scale");
ScaleDotElemType aElemType = dotOp.getAElemType();
ScaleDotElemType bElemType = dotOp.getBElemType();
// E3M2 / E2M3 (mxfp6) operands are rejected below.
auto supportsTypes = [](ScaleDotElemType elemType) {
return elemType == ScaleDotElemType::E2M1 ||
elemType == ScaleDotElemType::E4M3 ||
elemType == ScaleDotElemType::E5M2 ||
elemType == ScaleDotElemType::BF16 ||
elemType == ScaleDotElemType::FP16;
};
if (!supportsTypes(aElemType) || !supportsTypes(bElemType))
return rewriter.notifyMatchFailure(dotOp, "NYI: mxfp6 operand");
MLIRContext *ctx = dotOp.getContext();
auto moduleOp = dotOp->getParentOfType<ModuleOp>();
int numWarps = ttg::lookupNumWarps(dotOp);
ttg::CTALayoutAttr ctaLayout = ttg::getCTALayout(oldRetType.getEncoding());
int numThreads = ttg::TritonGPUDialect::getThreadsPerWarp(moduleOp);
// Upcast target: fp16 if either operand is FP16, bf16 otherwise.
bool useFp16 = aElemType == ScaleDotElemType::FP16 ||
bElemType == ScaleDotElemType::FP16;
FailureOr<MfmaIntrinsic> mfmaInstr =
chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim, useFp16);
if (failed(mfmaInstr))
return rewriter.notifyMatchFailure(dotOp, "cannot choose mfma intrinsic");
if (useFp16) {
dotOp.emitRemark(
"Warning: detected one dot_scaled operand is fp16 tensor so "
"upcasting to fp16 for computation, which impacts precision; "
"experimental behavior and may change in future");
}
unsigned mDim = mfmaInstr->mDim;
unsigned nDim = mfmaInstr->nDim;
unsigned kDim = mfmaInstr->kDim;
unsigned kBase = mfmaInstr->kBase;
// Dot-operand kWidth per operand (index 0 = A, 1 = B). With no E2M1 operand
// both use kBase * kPack. If A is E2M1, A uses 4 and B uses 8. If only B is
// E2M1, A uses 8 and B uses 4.
bool isAPacked = aElemType == ScaleDotElemType::E2M1;
bool isBPacked = bElemType == ScaleDotElemType::E2M1;
bool isPacked = isAPacked || isBPacked;
unsigned kWidths[] = {isPacked ? (isAPacked ? 4 : 8) : kBase * kPack,
isPacked ? (isAPacked ? 8 : 4) : kBase * kPack};
// All warps are placed along dim 0 when A carries the scale, along dim 1
// otherwise.
SmallVector<unsigned, 2> mfmaWarpsPerCTA(rank, 1);
mfmaWarpsPerCTA[aScale ? 0 : 1] = numWarps;
auto mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
ctx, mfmaVersion, mfmaWarpsPerCTA, mDim,
nDim, true, ctaLayout, oldRetType.getElementType());
auto newRetType = RankedTensorType::get(
oldRetType.getShape(), oldRetType.getElementType(), mfmaEnc);
auto newAcc = rewriter.create<ttg::ConvertLayoutOp>(
dotOp.getC().getLoc(), newRetType, dotOp.getC());
// Moves operand `v` (index `idx`) to its dot-operand layout. BF16, FP16 and
// E2M1 operands are returned as-is, as are E4M3 / E5M2 operands on MFMA v4.
// Any other operand is converted with FpToFpOp to fp16/bf16.
auto upcastForMMA = [&](TensorValue v, int idx,
ScaleDotElemType type) -> TensorValue {
auto vType = v.getType();
auto newVEncoding = DotOperandEncodingAttr::get(
ctx, idx, newRetType.getEncoding(), kWidths[idx]);
auto newVType = RankedTensorType::get(
vType.getShape(), vType.getElementType(), newVEncoding);
v = rewriter.create<ttg::ConvertLayoutOp>(v.getLoc(), newVType, v);
if (type == ScaleDotElemType::BF16 || type == ScaleDotElemType::FP16 ||
type == ScaleDotElemType::E2M1 ||
(mfmaVersion == 4 &&
(type == ScaleDotElemType::E4M3 || type == ScaleDotElemType::E5M2)))
return v;
auto upcastedType = RankedTensorType::get(
vType.getShape(),
useFp16 ? rewriter.getF16Type() : rewriter.getBF16Type(),
newVEncoding);
return cast<TensorValue>(
rewriter.create<FpToFpOp>(v.getLoc(), upcastedType, v).getResult());
};
a = upcastForMMA(a, 0, aElemType);
b = upcastForMMA(b, 1, bElemType);
assert(mDim == nDim);
// Blocked layout for the scale tensor: mDim threads along dim 0, the rest of
// the warp along dim 1, and all warps along dim 0.
SmallVector<unsigned, 2> threadsPerWarp = {mDim, numThreads / mDim};
SmallVector<unsigned, 2> blockWarpsPerCTA(rank, 1);
blockWarpsPerCTA[0] = numWarps;
auto newScaleEncoding = triton::gpu::BlockedEncodingAttr::get(
ctx, {1, 1}, threadsPerWarp, blockWarpsPerCTA, {1, 0}, ctaLayout);
// Applies `scale` to `v` with amdgpu::UpcastMXFPOp (output element type
// fp16/bf16). Returns `v` unchanged when there is no scale.
auto upcastMXFP = [&](TensorValue v, TensorValue scale,
ScaleDotElemType elemType, bool fastMath) -> Value {
if (!scale)
return v;
auto newScaleType = RankedTensorType::get(
scale.getType().getShape(), scale.getType().getElementType(),
newScaleEncoding);
auto convOp = rewriter.create<ttg::ConvertLayoutOp>(scale.getLoc(),
newScaleType, scale);
Builder b(v.getContext());
Type outputElemType = useFp16 ? b.getF16Type() : b.getBF16Type();
auto outputType =
amdgpu::UpcastMXFPOp::deduceOutputType(v, elemType, outputElemType);
return rewriter.create<amdgpu::UpcastMXFPOp>(
dotOp.getLoc(), outputType, v, convOp, elemType, fastMath);
};
Value scaledA =
upcastMXFP(a, aScale, dotOp.getAElemType(), dotOp.getFastMath());
Value scaledB =
upcastMXFP(b, bScale, dotOp.getBElemType(), dotOp.getFastMath());
auto newDot = rewriter.create<DotOp>(dotOp.getLoc(), newRetType, scaledA,
scaledB, newAcc);
// Convert the MFMA-layout result back to the original blocked layout.
rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(dotOp, oldRetType,
newDot);
return success();
}
};
/// Looks through any chain of ttg.convert_layout ops feeding `op`. Returns the
/// defining op of the underlying value if it is an `Op`, and a null `Op`
/// otherwise.
template <typename Op> Op getDefOpBeforeConvertLayout(Value op) {
  Value source = op;
  for (auto cvt = source.getDefiningOp<ttg::ConvertLayoutOp>(); cvt;
       cvt = source.getDefiningOp<ttg::ConvertLayoutOp>())
    source = cvt.getSrc();
  return source.getDefiningOp<Op>();
}
/// Returns true when `scale` is produced by the pattern
///   reshape -> 7-D {nonK/32, K/256, 4, 16, 2, 2, 1}
///   -> trans(order = 0,5,3,1,4,2,6)
///   -> reshape back to the scale's own shape.
/// Layout conversions between these ops are ignored. Here nonK is the
/// scale's second-to-last extent and K is its last extent times 32.
bool isScaleShuffled(Value scale) {
  if (!scale)
    return false;

  auto scaleShape = cast<RankedTensorType>(scale.getType()).getShape();
  const int rank = scaleShape.size();
  const int nonKExtent = scaleShape[rank - 2];
  const int kExtent = scaleShape[rank - 1] * 32;

  auto outerReshape = getDefOpBeforeConvertLayout<triton::ReshapeOp>(scale);
  if (!outerReshape || outerReshape.getType().getShape() != scaleShape)
    return false;

  static constexpr std::array<int, 7> kShuffleOrder{0, 5, 3, 1, 4, 2, 6};
  auto shuffle =
      getDefOpBeforeConvertLayout<triton::TransOp>(outerReshape.getSrc());
  if (!shuffle || shuffle.getOrder() != ArrayRef<int>(kShuffleOrder))
    return false;

  const std::array<int64_t, 7> expected7DShape{
      nonKExtent / 32, kExtent / 32 / 8, 4, 16, 2, 2, 1};
  auto innerReshape =
      getDefOpBeforeConvertLayout<triton::ReshapeOp>(shuffle.getSrc());
  return innerReshape && innerReshape.getType().getShape() ==
                             ArrayRef<int64_t>(expected7DShape);
}
/// Tiles per warp for the scaled MFMA layout: {2, 2} when either scale tensor
/// matches the shuffled pattern recognized by isScaleShuffled, else {1, 1}.
SmallVector<unsigned, 2> getTilesPerWarp(Value aScale, Value bScale) {
  const bool shuffled = isScaleShuffled(aScale) || isScaleShuffled(bScale);
  const unsigned tiles = shuffled ? 2u : 1u;
  return {tiles, tiles};
}
/// Rewrites a 2-D `tt.dot_scaled` with a blocked result layout into a scaled
/// MFMA `tt.dot_scaled` (MFMA v4 / gfx950 only) for E2M1 / E4M3 / E5M2
/// operands.
///
/// - Operands packed along K are converted to dot-operand layouts directly.
///   An operand not packed along K goes through shared memory and
///   LocalLoadPackedTransposedOp, which doubles its non-K extent and halves
///   its K extent.
/// - Scales are converted to the linear layout from
///   chooseScaledMfmaScaleLayout.
/// - When only one operand has a scale, the other gets a constant i8 scale
///   tensor filled with 0x7F.
class ScaledBlockedToScaledMFMAF8F6F4 final
    : public OpRewritePattern<triton::DotScaledOp> {
  int mfmaVersion;
  int nonKDim;

public:
  ScaledBlockedToScaledMFMAF8F6F4(MLIRContext *context, int mfmaVersion,
                                  int nonKDim, PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit), mfmaVersion(mfmaVersion),
        nonKDim(nonKDim) {}

  LogicalResult matchAndRewrite(triton::DotScaledOp dotOp,
                                PatternRewriter &rewriter) const override {
    using TensorValue = TypedValue<RankedTensorType>;
    if (mfmaVersion != 4) {
      return rewriter.notifyMatchFailure(
          dotOp, "F8F6F4 scaled dot is only natively supported on gfx950");
    }

    RankedTensorType oldRetType = dotOp.getType();
    if (!isa_and_nonnull<BlockedEncodingAttr>(oldRetType.getEncoding()))
      return rewriter.notifyMatchFailure(
          dotOp, "expected blocked encoding result tensor");

    unsigned rank = oldRetType.getRank();
    if (rank == 3)
      return rewriter.notifyMatchFailure(dotOp, "NYI: 3d case");

    TensorValue a = dotOp.getA();
    TensorValue b = dotOp.getB();
    TensorValue aScale = dotOp.getAScale();
    TensorValue bScale = dotOp.getBScale();
    auto oldShape = oldRetType.getShape();

    ScaleDotElemType aElemType = dotOp.getAElemType();
    ScaleDotElemType bElemType = dotOp.getBElemType();
    auto supportsTypes = [](ScaleDotElemType elemType) {
      return elemType == ScaleDotElemType::E2M1 ||
             elemType == ScaleDotElemType::E4M3 ||
             elemType == ScaleDotElemType::E5M2;
    };
    if (!supportsTypes(aElemType) || !supportsTypes(bElemType)) {
      return rewriter.notifyMatchFailure(dotOp, "NYI: mxfp6");
    }

    bool bothScalesAbsent = !aScale && !bScale;
    MLIRContext *ctx = dotOp.getContext();
    ttg::CTALayoutAttr ctaLayout = ttg::getCTALayout(oldRetType.getEncoding());
    unsigned numWarps = ttg::lookupNumWarps(dotOp);
    if (numWarps == 1)
      return rewriter.notifyMatchFailure(dotOp,
                                         "num_warps==1 is not supported");

    FailureOr<MfmaIntrinsic> mfmaInstr =
        chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim);
    if (failed(mfmaInstr))
      return rewriter.notifyMatchFailure(dotOp,
                                         "cannot choose scaled mfma intrinsic");

    auto mDim = mfmaInstr->mDim;
    auto nDim = mfmaInstr->nDim;
    auto kBase = mfmaInstr->kBase;
    assert(mDim == nDim);

    auto warpsPerTile =
        warpsPerTileMFMA(dotOp, oldShape, numWarps, {mDim, nDim});
    // rank is always 2 here (3-D was rejected above), so no batch entry is
    // needed in tilesPerWarp.
    SmallVector<unsigned> tilesPerWarp = getTilesPerWarp(aScale, bScale);

    mlir::Attribute mfmaEnc;
    if (llvm::any_of(tilesPerWarp, [](int x) { return x != 1; })) {
      mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
          ctx, mfmaVersion, warpsPerTile, tilesPerWarp, mDim, nDim, true,
          ctaLayout, oldRetType.getElementType());
    } else {
      mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
          ctx, mfmaVersion, warpsPerTile, mDim, nDim, true, ctaLayout,
          oldRetType.getElementType());
    }

    auto newRetType =
        RankedTensorType::get(oldShape, oldRetType.getElementType(), mfmaEnc);
    auto newAcc = rewriter.create<ttg::ConvertLayoutOp>(
        dotOp.getC().getLoc(), newRetType, dotOp.getC());

    const unsigned kWidth = kBase;
    assert(kWidth == 32);

    // Operand shapes seen by the scale layout. They are updated when an
    // operand is repacked below.
    auto aShape = a.getType().getShape();
    auto bShape = b.getType().getShape();
    auto aEncLL = LinearLayout::empty();
    auto bEncLL = LinearLayout::empty();

    auto convertInputLayout = [&](TensorValue v,
                                  unsigned opIdx) -> TensorValue {
      auto vType = v.getType();
      auto newEnc =
          DotOperandEncodingAttr::get(ctx, opIdx, mfmaEnc, kWidth / 2);
      bool kPacked = opIdx == 0 ? dotOp.getLhsKPack() : dotOp.getRhsKPack();
      if (!kPacked) {
        // Operand is packed along its non-K dimension. Repack it along K
        // through shared memory: the non-K extent doubles and K halves.
        SmallVector<int64_t> newShape(vType.getShape());
        newShape[opIdx == 0 ? 0 : 1] = newShape[opIdx == 0 ? 0 : 1] * 2;
        newShape[opIdx == 0 ? 1 : 0] = newShape[opIdx == 0 ? 1 : 0] / 2;
        auto newVType =
            RankedTensorType::get(newShape, vType.getElementType(), newEnc);
        auto srcEncoding = vType.getEncoding();
        SmallVector<unsigned> newOrder = opIdx == 1
                                             ? SmallVector<unsigned>{1, 0}
                                             : SmallVector<unsigned>{0, 1};
        auto sharedMemorySpace =
            triton::gpu::SharedMemorySpaceAttr::get(vType.getContext());
        auto tmpType = triton::gpu::MemDescType::get(
            vType.getShape(), vType.getElementType(),
            triton::gpu::SwizzledSharedEncodingAttr::get(
                v.getContext(), newEnc, vType.getShape(), newOrder,
                triton::gpu::getCTALayout(srcEncoding),
                vType.getElementType()),
            sharedMemorySpace);
        // Create these ops through the rewriter so the pattern driver is
        // notified (a standalone OpBuilder bypasses it). They stay right
        // before dotOp, as before.
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPoint(dotOp);
        auto tmp = rewriter.create<triton::gpu::LocalAllocOp>(dotOp.getLoc(),
                                                              tmpType, v);
        auto newConvert =
            rewriter.create<triton::amdgpu::LocalLoadPackedTransposedOp>(
                dotOp.getLoc(), newVType, tmp);
        if (opIdx == 0) {
          aShape = newConvert.getType().getShape();
          aEncLL *= newEnc.toLinearLayout(aShape);
        } else {
          bShape = newConvert.getType().getShape();
          bEncLL *= newEnc.toLinearLayout(bShape);
        }
        return newConvert;
      }
      if (opIdx == 0)
        aEncLL *= newEnc.toLinearLayout(aShape);
      else
        bEncLL *= newEnc.toLinearLayout(bShape);
      auto newVType = RankedTensorType::get(vType.getShape(),
                                            vType.getElementType(), newEnc);
      return rewriter.create<ttg::ConvertLayoutOp>(v.getLoc(), newVType, v);
    };
    a = convertInputLayout(a, 0);
    b = convertInputLayout(b, 1);

    // Converts (or, when missing, synthesizes) the scale for operand `idx`.
    // Returns a null Value when neither operand has a scale.
    auto convertScaleLayout = [&](TensorValue scale,
                                  llvm::ArrayRef<int64_t> valShape,
                                  LinearLayout dotLL, int idx) -> Value {
      if (bothScalesAbsent)
        return Value();
      SmallVector<int64_t> shape;
      if (!scale) {
        int64_t nonKExtent = idx == 0 ? valShape[0] : valShape[1];
        int64_t k = idx == 0 ? valShape[1] : valShape[0];
        ScaleDotElemType &elemType = idx == 0 ? aElemType : bElemType;
        // E2M1 stores two values per element, so the logical K is doubled.
        int packSize = elemType == ScaleDotElemType::E2M1 ? 2 : 1;
        // One scale per 32 logical K elements.
        shape = {nonKExtent, k * packSize / 32};
      } else {
        shape = llvm::to_vector(scale.getType().getShape());
      }
      LinearLayout newLL = chooseScaledMfmaScaleLayout(
          ctx, idx, shape, mDim, tilesPerWarp, warpsPerTile);
      Attribute newScaleEncoding = ttg::LinearEncodingAttr::get(ctx, newLL);
      auto newScaleType = RankedTensorType::get(shape, i8_ty, newScaleEncoding);
      if (!scale) {
        // 0x7F is presumably the E8M0 encoding of 1.0 (a neutral scale) --
        // confirm against the scale format used by the MFMA.
        return rewriter.create<arith::ConstantOp>(
            dotOp->getLoc(), newScaleType,
            DenseElementsAttr::get(newScaleType, llvm::APInt(8, 0x7F)));
      }
      return rewriter.create<ttg::ConvertLayoutOp>(scale.getLoc(),
                                                   newScaleType, scale);
    };
    auto newAScale = convertScaleLayout(aScale, aShape, aEncLL, 0);
    auto newBScale = convertScaleLayout(bScale, bShape, bEncLL, 1);

    auto newDot = rewriter.create<triton::DotScaledOp>(
        dotOp.getLoc(), newRetType, a, b, newAcc, newAScale, newBScale,
        aElemType, bElemType, dotOp.getFastMath());
    rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(dotOp, oldRetType,
                                                      newDot);
    return success();
  }
};
/// Emits a tt.fp_to_fp that converts the tensor `operand` to element type
/// `promotedType`. Shape and encoding are unchanged.
static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
                            Type promotedType) {
  auto operandTensorTy = cast<RankedTensorType>(operand.getType());
  Type resultTy = operandTensorTy.cloneWith(std::nullopt, promotedType);
  return builder.create<triton::FpToFpOp>(loc, resultTy, operand);
}
/// Gives the A and B operands of every MFMA/WMMA-layout `tt.dot` in `mod` a
/// common element type by inserting tt.fp_to_fp promotions.
///
/// - MFMA: dots are left alone when the wider operand is 8 bits or the
///   operand types already match. Otherwise both operands are promoted to
///   f16 (max width < 16) or f32 (max width <= 32). Mixed operands wider
///   than 32 bits have no promotion target and are left untouched.
/// - WMMA: both operands are promoted to the wider operand type (B's type
///   when the widths are equal).
/// - Dots with any other result encoding are ignored.
static void decomposeMixedModeDotOp(ModuleOp mod) {
  mod.walk([](triton::DotOp dotOp) -> void {
    auto D = dotOp.getD();
    OpBuilder builder(dotOp);
    Type AElType = dotOp.getA().getType().getElementType();
    Attribute dEncoding = D.getType().getEncoding();
    Type promoteType;
    if (isa_and_nonnull<ttg::AMDMfmaEncodingAttr>(dEncoding)) {
      Type BElType = dotOp.getB().getType().getElementType();
      auto maxBitWidth = std::max(AElType.getIntOrFloatBitWidth(),
                                  BElType.getIntOrFloatBitWidth());
      // 8-bit operand pairs are left as they are.
      if (maxBitWidth == 8)
        return;
      if (AElType == BElType)
        return;
      if (maxBitWidth < 16)
        promoteType = builder.getF16Type();
      else if (maxBitWidth <= 32)
        promoteType = builder.getF32Type();
    } else if (isa_and_nonnull<ttg::AMDWmmaEncodingAttr>(dEncoding)) {
      Type BElType = dotOp.getB().getType().getElementType();
      if (AElType == BElType)
        return;
      promoteType =
          AElType.getIntOrFloatBitWidth() > BElType.getIntOrFloatBitWidth()
              ? AElType
              : BElType;
    } else {
      return;
    }
    // No promotion target (e.g. mixed MFMA operands wider than 32 bits).
    // Building a tensor type with a null element type would crash.
    if (!promoteType)
      return;
    Location loc = dotOp.getLoc();
    Value promotedA = promoteOperand(builder, loc, dotOp.getA(), promoteType);
    Value promotedB = promoteOperand(builder, loc, dotOp.getB(), promoteType);
    dotOp.setOperand(0, promotedA);
    dotOp.setOperand(1, promotedB);
  });
}
/// Picks a WMMA intrinsic for a dot whose result type is `cType`.
///
/// The instruction tile is `enforcedNonKDim` x `enforcedNonKDim` when that
/// is non-zero. Otherwise it is 16x16 when both result dimensions are at
/// least 16.
///
/// Fails when:
/// - no tile fits (silently);
/// - WmmaIntrinsic::selectFor finds nothing (an error is emitted at `loc`);
/// - an automatically chosen instruction's K does not divide `inputKSize`.
FailureOr<WmmaIntrinsic> chooseWmmaInstruction(Location loc, int wmmaVersion,
                                               RankedTensorType cType,
                                               Type aElemType, Type bElemType,
                                               Type cElemType, int inputKSize,
                                               int enforcedNonKDim) {
  auto cShape = cType.getShape();
  const auto rank = cShape.size();
  const auto M = cShape[rank - 2];
  const auto N = cShape[rank - 1];

  unsigned tileM = 0;
  unsigned tileN = 0;
  if (enforcedNonKDim != 0) {
    tileM = tileN = enforcedNonKDim;
  } else {
    const int smallerExtent = std::min(M, N);
    if (smallerExtent >= 16)
      tileM = tileN = 16;
  }
  if (tileM == 0 || tileN == 0)
    return failure();

  FailureOr<WmmaIntrinsic> intrinsic = WmmaIntrinsic::selectFor(
      wmmaVersion, tileM, tileN, inputKSize, aElemType, bElemType, cElemType);
  if (failed(intrinsic))
    return emitError(loc, "no matching matrix core intrinsic due to "
                          "unsupported element type");

  const unsigned kDim = intrinsic->kDim;
  assert(kDim != 0);
  assert(enforcedNonKDim != 0 || (M % tileM == 0 && N % tileN == 0));
  if (enforcedNonKDim == 0 && inputKSize % kDim != 0)
    return failure();
  return intrinsic;
}
/// Picks a WMMA intrinsic for a `tt.dot`. `operandTypes` holds the
/// (A, B, C, D) element types; the K extent is A's last dimension.
FailureOr<WmmaIntrinsic> chooseWmmaInstruction(tt::DotOp dot,
                                               OperandTypesVector operandTypes,
                                               int wmmaVersion, int nonKDim) {
  Type lhsElemTy = operandTypes[0];
  Type rhsElemTy = operandTypes[1];
  Type accElemTy = operandTypes[2];
  const auto kExtent = dot.getA().getType().getShape().back();
  return chooseWmmaInstruction(dot.getLoc(), wmmaVersion, dot.getC().getType(),
                               lhsElemTy, rhsElemTy, accElemTy, kExtent,
                               nonKDim);
}
/// Rewrites a `tt.dot` whose result uses a blocked layout into a dot on the
/// AMD WMMA layout.
///
/// getOperandTypesForWmmaOp picks the element types. The accumulator and
/// operands are then converted/cast to the WMMA and dot-operand layouts. The
/// new dot's result is converted back to the original layout and element
/// type.
class BlockedToWMMA : public OpRewritePattern<tt::DotOp> {
  int wmmaVersion;
  int nonKDim;

public:
  BlockedToWMMA(MLIRContext *context, int wmmaVersion, int nonKDim,
                PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit), wmmaVersion(wmmaVersion),
        nonKDim(nonKDim) {}

  LogicalResult matchAndRewrite(tt::DotOp dotOp,
                                PatternRewriter &rewriter) const override {
    auto resultTy = cast<RankedTensorType>(dotOp.getResult().getType());
    Attribute resultEnc = resultTy.getEncoding();
    if (!isa_and_nonnull<ttg::BlockedEncodingAttr>(resultEnc))
      return failure();

    // (A, B, C, D) element types the WMMA instruction will operate on.
    OperandTypesVector wmmaTypes =
        getOperandTypesForWmmaOp(rewriter, dotOp, wmmaVersion);
    if (wmmaTypes.empty())
      return failure();

    FailureOr<WmmaIntrinsic> intrinsic =
        chooseWmmaInstruction(dotOp, wmmaTypes, wmmaVersion, nonKDim);
    if (failed(intrinsic))
      return failure();

    MLIRContext *ctx = dotOp->getContext();
    auto resultShape = resultTy.getShape();
    const int warps = ttg::lookupNumWarps(dotOp);
    auto warpLayout = warpsPerTileWMMA(dotOp, resultShape, warps,
                                       {intrinsic->mDim, intrinsic->nDim});
    ttg::AMDWmmaEncodingAttr wmmaEnc = ttg::AMDWmmaEncodingAttr::get(
        ctx, wmmaVersion, /*isTransposed=*/false, warpLayout,
        ttg::getCTALayout(resultEnc));

    // IR is emitted in the order accumulator, A, B, dot, result conversion.
    Value acc =
        convertAndCastTensor(rewriter, dotOp.getC(), wmmaEnc, wmmaTypes[2]);
    auto kWidth = intrinsic->kBase;
    Attribute lhsEnc = ttg::DotOperandEncodingAttr::get(ctx, 0, wmmaEnc, kWidth);
    Attribute rhsEnc = ttg::DotOperandEncodingAttr::get(ctx, 1, wmmaEnc, kWidth);
    Value lhs =
        convertAndCastTensor(rewriter, dotOp.getA(), lhsEnc, wmmaTypes[0]);
    Value rhs =
        convertAndCastTensor(rewriter, dotOp.getB(), rhsEnc, wmmaTypes[1]);

    auto wmmaResultTy =
        RankedTensorType::get(resultShape, wmmaTypes[3], wmmaEnc);
    auto wmmaDot = rewriter.create<tt::DotOp>(
        dotOp.getLoc(), wmmaResultTy, lhs, rhs, acc, dotOp.getInputPrecision(),
        dotOp.getMaxNumImpreciseAcc());
    Value restored = convertAndCastTensor(rewriter, wmmaDot, resultEnc,
                                          resultTy.getElementType());
    rewriter.replaceOp(dotOp, restored);
    return success();
  }
};
/// Handles `tt.dot` ops whose result still has a blocked layout, i.e. dots
/// not mapped onto MFMA/WMMA.
///
/// - If the dot is already in a form accepted as-is (see isLegalFMAForm),
///   the pattern does nothing.
/// - Otherwise it casts A, B and C to a common floating-point element type,
///   emits a new dot in that type, and casts the result back to D's
///   element type.
class AccelerateBlocked : public OpRewritePattern<DotOp> {
  // Target architecture name; queried for V_DOT support and the ISA family.
  StringRef arch;

public:
  AccelerateBlocked(MLIRContext *context, StringRef arch,
                    PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit), arch(arch) {}

  /// Returns true if `t` is a scalar floating-point type, i.e. an
  /// int-or-float type that is not an integer or index type.
  bool isFloat(Type t) const { return t.isIntOrFloat() && !t.isIntOrIndex(); }

  /// Casts ranked tensor `v` to element type `elTy`, keeping shape and
  /// encoding. Returns `v` unchanged when the types already match.
  ///  - float -> float: tt.fp_to_fp with RTNE rounding
  ///  - int   -> float: arith.sitofp (integers are treated as signed)
  ///  - float -> int:   arith.fptosi
  /// int -> int is not expected here and hits an assertion.
  Value castToElTy(PatternRewriter &rewriter, Value v, Type elTy) const {
    Location loc = v.getLoc();
    auto srcTy = cast<RankedTensorType>(v.getType());
    auto dstTy = srcTy.cloneWith(std::nullopt, elTy);
    if (srcTy == dstTy)
      return v;
    auto srcElTy = srcTy.getElementType();
    auto dstElTy = dstTy.getElementType();
    if (isFloat(srcElTy) && isFloat(dstElTy)) {
      auto rmode =
          RoundingModeAttr::get(rewriter.getContext(), RoundingMode::RTNE);
      return rewriter.create<FpToFpOp>(loc, dstTy, v, rmode);
    }
    if (!isFloat(srcElTy) && isFloat(dstElTy))
      return rewriter.create<arith::SIToFPOp>(loc, dstTy, v);
    if (isFloat(srcElTy) && !isFloat(dstElTy))
      return rewriter.create<arith::FPToSIOp>(loc, dstTy, v);
    assert(false && "int -> int cast is unexpected in FMA legalization");
    return Value();
  }

  /// Element types of the dot operands A, B, C and of the result D.
  struct DotElTypes {
    Type a, b, c, d;
  };

  /// Returns true when the dot is accepted without rewriting.
  ///
  /// On targets where AMD::supportsVDot(arch) holds, with K = last dim of A:
  ///  - f16 x f16 -> f32 (C and D f32) when K is even;
  ///  - bf16 x bf16 -> f32 on CDNA4 when K is even;
  ///  - i8 x i8 -> i32 (C and D i32) when K % 4 == 0.
  /// On any target: when the operands A, B and C share a single element type,
  /// and that type is f16, f32 or f64.
  bool isLegalFMAForm(DotOp dotOp, const DotElTypes &dotTypes) const {
    if (AMD::supportsVDot(arch)) {
      auto aOpType = dotOp.getA().getType();
      int rank = aOpType.getRank();
      int k = aOpType.getShape()[rank - 1];
      if (dotTypes.a.isF16() && dotTypes.b.isF16() && dotTypes.c.isF32() &&
          dotTypes.d.isF32() && k % 2 == 0) {
        return true;
      }
      if (AMD::deduceISAFamily(arch) == ISAFamily::CDNA4 &&
          dotTypes.a.isBF16() && dotTypes.b.isBF16() && dotTypes.c.isF32() &&
          dotTypes.d.isF32() && k % 2 == 0) {
        return true;
      }
      // Intentionally disabled (`false &&`). When enabled, an all-f16 dot
      // would be treated as non-legal so that tryAccelerateF16WithVDot could
      // rewrite it. NOTE(review): the matching call in matchAndRewrite is
      // disabled too; enable both together.
      if (false && dotTypes.a.isF16() && dotTypes.b.isF16() &&
          dotTypes.c.isF16() && dotTypes.d.isF16() && k % 2 == 0) {
        return false;
      }
      if (dotTypes.a.isInteger(8) && dotTypes.b.isInteger(8) &&
          dotTypes.c.isInteger(32) && dotTypes.d.isInteger(32) && k % 4 == 0) {
        return true;
      }
    }
    auto expectedElTy = dotTypes.a;
    for (auto operand : dotOp.getOperands()) {
      auto opTy = cast<RankedTensorType>(operand.getType());
      auto elTy = opTy.getElementType();
      if (elTy != expectedElTy)
        return false;
      if (!elTy.isF16() && !elTy.isF32() && !elTy.isF64())
        return false;
    }
    return true;
  }

  /// On V_DOT-capable targets, rewrites an all-f16 dot (K = last dim of A is
  /// even) into a dot with an f32 accumulator, then casts the result back to
  /// f16. Fails otherwise. Currently not reached: its call site is disabled.
  LogicalResult tryAccelerateF16WithVDot(DotOp dotOp, PatternRewriter &rewriter,
                                         const DotElTypes &dotTypes) const {
    if (!AMD::supportsVDot(arch))
      return failure();
    auto aOpType = dotOp.getA().getType();
    int rank = aOpType.getRank();
    int k = aOpType.getShape()[rank - 1];
    if (dotTypes.a.isF16() && dotTypes.b.isF16() && dotTypes.c.isF16() &&
        dotTypes.d.isF16() && k % 2 == 0) {
      auto newC = castToElTy(rewriter, dotOp.getC(), f32_ty);
      auto newDot = rewriter.create<DotOp>(
          dotOp.getLoc(), newC.getType(), dotOp.getA(), dotOp.getB(), newC,
          dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc());
      auto newD = castToElTy(rewriter, newDot.getResult(), f16_ty);
      rewriter.replaceOp(dotOp, newD);
      return success();
    }
    return failure();
  }

  /// Rewrites the dot so that A, B and C share one floating-point element
  /// type, then casts the result back to D's original element type.
  ///
  /// The common type is the narrowest of f16 / f32 / f64 that covers the
  /// widest operand, with a minimum of 8 bits. f16 is upgraded to f32 when
  /// any operand is bf16 or an integer wider than 8 bits, since f16 cannot
  /// represent those exactly.
  LogicalResult tryLegalizeFMA(DotOp dotOp, PatternRewriter &rewriter,
                               const DotElTypes &dotTypes) const {
    SmallVector<Type> opElTy{dotTypes.a, dotTypes.b, dotTypes.c, dotTypes.d};
    unsigned maxBitsize = 8;
    for (auto elTy : opElTy)
      maxBitsize = std::max(maxBitsize, elTy.getIntOrFloatBitWidth());
    assert(maxBitsize <= 64 && "unexpected dot operand bit width");

    // isLegalFMAForm accepts f64 dots, so 64-bit operands must map to f64
    // rather than being silently truncated to f32 (or tripping an assert).
    Type commonTy;
    if (maxBitsize > 32)
      commonTy = rewriter.getF64Type();
    else if (maxBitsize > 16)
      commonTy = rewriter.getF32Type();
    else
      commonTy = rewriter.getF16Type();

    if (commonTy.isF16()) {
      for (auto elTy : opElTy) {
        if (elTy.isInteger() && elTy.getIntOrFloatBitWidth() > 8) {
          commonTy = rewriter.getF32Type();
          break;
        }
        if (elTy.isBF16()) {
          commonTy = rewriter.getF32Type();
          break;
        }
      }
    }
    auto newA = castToElTy(rewriter, dotOp.getA(), commonTy);
    auto newB = castToElTy(rewriter, dotOp.getB(), commonTy);
    auto newC = castToElTy(rewriter, dotOp.getC(), commonTy);
    auto newDot = rewriter.create<DotOp>(dotOp.getLoc(), newC.getType(), newA,
                                         newB, newC, dotOp.getInputPrecision(),
                                         dotOp.getMaxNumImpreciseAcc());
    auto newD = castToElTy(rewriter, newDot.getResult(), dotTypes.d);
    rewriter.replaceOp(dotOp, newD);
    return success();
  }

  /// Matches dots with a blocked result layout. Legal forms are left
  /// untouched (failure); everything else goes through tryLegalizeFMA.
  LogicalResult matchAndRewrite(DotOp dotOp,
                                PatternRewriter &rewriter) const override {
    // Skip tensors without an encoding instead of asserting inside isa<>.
    Attribute dEncoding = dotOp.getD().getType().getEncoding();
    if (!dEncoding || !isa<BlockedEncodingAttr>(dEncoding))
      return failure();
    DotElTypes dotTypes;
    dotTypes.a = dotOp.getA().getType().getElementType();
    dotTypes.b = dotOp.getB().getType().getElementType();
    dotTypes.c = dotOp.getC().getType().getElementType();
    dotTypes.d = dotOp.getD().getType().getElementType();
    if (isLegalFMAForm(dotOp, dotTypes)) {
      return failure();
    }
    // Intentionally disabled (`false &&`). NOTE(review): this is paired with
    // the disabled all-f16 branch in isLegalFMAForm; enable both together.
    if (false &&
        tryAccelerateF16WithVDot(dotOp, rewriter, dotTypes).succeeded()) {
      return success();
    }
    return tryLegalizeFMA(dotOp, rewriter, dotTypes);
  }
};
}
#define GEN_PASS_DEF_TRITONAMDGPUACCELERATEMATMUL
#include "TritonAMDGPUTransforms/Passes.h.inc"
/// Pass that rewrites `tt.dot` ops with blocked layouts for AMD targets.
///
/// Phase 1 registers ISA-specific matrix-core patterns and applies them
/// greedily:
///  - CDNA1..CDNA4: MFMA patterns;
///  - CDNA4 only: an additional scaled F8/F6/F4 MFMA pattern;
///  - RDNA3: scaled-dot decomposition plus WMMA patterns.
/// Phase 2 applies AccelerateBlocked to whatever blocked dots remain.
/// Phase 3 runs decomposeMixedModeDotOp on the module.
/// Failure of either greedy application signals pass failure.
struct TritonAMDGPUAccelerateMatmulPass
    : impl::TritonAMDGPUAccelerateMatmulBase<TritonAMDGPUAccelerateMatmulPass> {
  using Base::Base;

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp module = getOperation();
    const auto isaFamily = triton::AMD::deduceISAFamily(archGenerationName);

    RewritePatternSet matrixCorePatterns(context);
    switch (isaFamily) {
    case ISAFamily::CDNA4:
      // Scaled F8/F6/F4 MFMA has the highest benefit in this set.
      matrixCorePatterns.add<::ScaledBlockedToScaledMFMAF8F6F4>(
          context, getMfmaVersion(isaFamily), matrixInstructionSize,
          /*benefit=*/10);
      [[fallthrough]];
    case ISAFamily::CDNA1:
    case ISAFamily::CDNA2:
    case ISAFamily::CDNA3:
      matrixCorePatterns.add<::BlockedToMFMA, ::ScaledBlockedToMFMA>(
          context, getMfmaVersion(isaFamily), matrixInstructionSize, kPack,
          /*benefit=*/2);
      break;
    case ISAFamily::RDNA3:
      ttg::populateDecomposeScaledBlockedPatterns(matrixCorePatterns,
                                                  /*benefit=*/3);
      matrixCorePatterns.add<::BlockedToWMMA>(
          context, getWmmaVersion(archGenerationName), matrixInstructionSize,
          /*benefit=*/2);
      break;
    default:
      // No matrix-core patterns for other ISA families.
      break;
    }
    if (failed(applyPatternsGreedily(module, std::move(matrixCorePatterns))))
      signalPassFailure();

    RewritePatternSet fmaPatterns(context);
    fmaPatterns.add<AccelerateBlocked>(context, archGenerationName,
                                       /*benefit=*/1);
    if (failed(applyPatternsGreedily(module, std::move(fmaPatterns))))
      signalPassFailure();

    decomposeMixedModeDotOp(module);
  }
};
}