#include "triton/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::gpu;
namespace {
/// Builds the permutation [0, 1, ..., rank - 3, rank - 1, rank - 2]: the
/// identity order over `rank` dimensions with the two innermost dimensions
/// swapped. Requires rank >= 2.
SmallVector<int, 2> getTransposeOrder(int rank) {
  assert(rank >= 2);
  SmallVector<int, 2> order;
  // Leading (batch-like) dimensions keep their position.
  for (int d = 0; d < rank - 2; ++d)
    order.push_back(d);
  // Swap the last two dimensions.
  order.push_back(rank - 1);
  order.push_back(rank - 2);
  return order;
}
/// Rewrite pattern that decomposes a DotScaledOp whose result is not in an MMA
/// layout into explicit upcast / scale-multiply ops feeding a plain DotOp.
///
/// For each operand (see scaleArg):
///   0. the value is upcast to the compute type (f16 or bf16);
///   1. if a scale tensor is present, each integer scale code is turned into a
///      floating-point power-of-two multiplier (scaleTo16);
///   2. every scale is broadcast over 32 consecutive elements along K and
///      multiplied into the upcast value (broadcastScale);
///   3. unless fast-math is set, elements whose scale code is 0xFF are forced
///      to NaN (maskNan).
/// The scaled operands are converted to DotOperandEncodingAttr layouts parented
/// on the original result encoding, combined with a DotOp using the original
/// accumulator, and the DotOp result is converted back to the original result
/// type.
class DecomposeScaledBlocked : public OpRewritePattern<DotScaledOp> {
public:
  DecomposeScaledBlocked(MLIRContext *context, int benefit)
      : OpRewritePattern<DotScaledOp>(context, benefit) {}

  LogicalResult matchAndRewrite(DotScaledOp scaledDotOp,
                                PatternRewriter &rewriter) const override {
    // This decomposition only applies to results that are NOT in an MMA layout.
    if (isa_and_nonnull<MmaEncodingTrait>(
            scaledDotOp.getResult().getType().getEncoding()))
      return failure();
    // Both operands must report K-packing; FP4 operands are unpacked along
    // kDim in scaleArg.
    if (!scaledDotOp.getLhsKPack() || !scaledDotOp.getRhsKPack())
      return failure();

    auto computeType = getComputeType(scaledDotOp.getAElemType(),
                                      scaledDotOp.getBElemType(), rewriter);
    auto loc = scaledDotOp.getLoc();

    // Converts `v` to the dot-operand layout for operand `opIdx`, derived from
    // the encoding of the original DotScaledOp result.
    auto cvtDotOperand = [&](TypedValue<RankedTensorType> v,
                             int opIdx) -> TypedValue<RankedTensorType> {
      auto *ctx = rewriter.getContext();
      auto retEnc = scaledDotOp.getType().getEncoding();
      auto vType = v.getType();
      auto encoding = DotOperandEncodingAttr::get(ctx, opIdx, retEnc,
                                                  vType.getElementType());
      auto retTy = vType.cloneWithEncoding(encoding);
      return rewriter.create<ConvertLayoutOp>(loc, retTy, v);
    };

    auto scaledA = scaleArg(rewriter, scaledDotOp, 0, computeType);
    scaledA = cvtDotOperand(scaledA, 0);
    auto scaledB = scaleArg(rewriter, scaledDotOp, 1, computeType);
    scaledB = cvtDotOperand(scaledB, 1);
    auto newDot =
        rewriter.create<DotOp>(loc, scaledA, scaledB, scaledDotOp.getC());

    // The DotOp result is in whatever layout the builder produced; convert it
    // back to the exact type the original op had.
    rewriter.replaceOpWithNewOp<ConvertLayoutOp>(scaledDotOp,
                                                 scaledDotOp.getType(), newDot);
    return success();
  }

private:
  /// Picks the 16-bit type both operands are upcast to: f16 if either operand
  /// is FP16, otherwise bf16.
  FloatType getComputeType(ScaleDotElemType aType, ScaleDotElemType bType,
                           PatternRewriter &rewriter) const {
    if (aType == ScaleDotElemType::FP16 || bType == ScaleDotElemType::FP16)
      return rewriter.getF16Type();
    return rewriter.getBF16Type();
  }

  /// Converts an integer scale tensor into `computeType` (f16 or bf16) values
  /// by placing each (zero-extended) code directly into the exponent field:
  /// zext -> shift left by the number of stored mantissa bits -> bitcast.
  ///
  /// bf16 is built directly. For f16, whose exponent field is too narrow to
  /// hold the code, the value is built as f32 and then truncated to f16.
  ///
  /// NOTE(review): a scale code of 0 produces the all-zero bit pattern, i.e.
  /// +0.0. If scales are E8M0 (where code 0 denotes 2^-127), that scale is
  /// flushed to zero — confirm this is intended.
  TypedValue<RankedTensorType> scaleTo16(PatternRewriter &rewriter,
                                         TypedValue<RankedTensorType> scale,
                                         FloatType computeType) const {
    auto loc = scale.getLoc();
    auto scaleTy = scale.getType();
    assert(computeType == rewriter.getBF16Type() ||
           computeType == rewriter.getF16Type());

    // Float type wide enough for its exponent field to hold the scale code.
    FloatType largeFpType = computeType == rewriter.getF16Type()
                                ? rewriter.getF32Type()
                                : computeType;
    int intWidth = largeFpType.getIntOrFloatBitWidth();
    auto intType = rewriter.getIntegerType(intWidth);

    auto zexted =
        rewriter.create<arith::ExtUIOp>(loc, scaleTy.clone(intType), scale);
    // getFPMantissaWidth() includes the implicit integer bit, so subtracting
    // one gives the number of stored mantissa bits, which is the bit offset of
    // the exponent field (7 for bf16, 23 for f32).
    int shiftValue = largeFpType.getFPMantissaWidth() - 1;
    auto shiftConst =
        rewriter.create<arith::ConstantIntOp>(loc, shiftValue, intWidth);
    auto shift =
        rewriter.create<SplatOp>(loc, scaleTy.clone(intType), shiftConst);
    auto shlRes = rewriter.create<arith::ShLIOp>(loc, zexted, shift);
    Value scaleFP =
        rewriter.create<BitcastOp>(loc, scaleTy.clone(largeFpType), shlRes);
    if (largeFpType != computeType) {
      scaleFP = rewriter.create<arith::TruncFOp>(
          loc, scaleTy.clone(computeType), scaleFP);
    }
    return cast<TypedValue<RankedTensorType>>(scaleFP);
  }

  /// Repeats every element of `scale` 32 times along dimension `dim`: shape
  /// [.., s_dim, ..] becomes [.., s_dim * 32, ..].
  ///
  /// Steps:
  ///   1. convert `scale` to a slice (at dim `rank`) of the default blocked
  ///      encoding for the expanded shape, so ExpandDimsOp has a valid layout;
  ///   2. expand a trailing unit dim and broadcast it to 32;
  ///   3. transpose that new dim to sit right after `dim`;
  ///   4. reshape, merging it into `dim`.
  /// The result's encoding is whatever the ReshapeOp builder infers; callers
  /// convert it to the layout they need.
  TypedValue<RankedTensorType>
  broadcastScale(PatternRewriter &rewriter, DotScaledOp scaledDotOp,
                 ModuleOp mod, TypedValue<RankedTensorType> scale,
                 int dim) const {
    auto *ctx = rewriter.getContext();
    auto loc = scale.getLoc();
    auto scaleTy = scale.getType();
    auto rank = scaleTy.getRank();
    {
      // Default blocked encoding for the (rank + 1)-D expanded shape.
      auto shape = to_vector(scaleTy.getShape());
      shape.push_back(1);
      auto nWarps = lookupNumWarps(scaledDotOp);
      auto threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod);
      auto numCTAs = TritonGPUDialect::getNumCTAs(mod);
      auto blockedEnc = getDefaultBlockedEncoding(ctx, shape, nWarps,
                                                  threadsPerWarp, numCTAs);
      // Put `scale` in the slice layout that ExpandDimsOp at `rank` expects.
      auto sliceEnc = SliceEncodingAttr::get(ctx, rank, blockedEnc);
      auto sliceType = scaleTy.cloneWithEncoding(sliceEnc);
      scale = rewriter.create<ConvertLayoutOp>(loc, sliceType, scale);
    }
    auto expandScale = rewriter.create<ExpandDimsOp>(loc, scale, rank);

    // Broadcast the trailing unit dimension to 32.
    auto scaleShape = to_vector(scaleTy.getShape());
    scaleShape.push_back(32);
    auto broadcastScale = rewriter.create<BroadcastOp>(
        loc, expandScale.getType().clone(scaleShape), expandScale);

    // Move the new dimension (index `rank`) to just after `dim`.
    auto transposeOrder = llvm::to_vector(llvm::seq<int32_t>(rank));
    transposeOrder.insert(transposeOrder.begin() + dim + 1, rank);
    auto transposedScale =
        rewriter.create<TransOp>(loc, broadcastScale, transposeOrder);

    // Fold the 32-wide dimension into `dim`.
    scaleShape.pop_back();
    scaleShape[dim] *= 32;
    auto reshapeScale =
        rewriter.create<ReshapeOp>(loc, scaleShape, transposedScale);
    return reshapeScale;
  }

  /// Computes where(scale == 0xFF, NaN, mxfp), with `scale` broadcast along
  /// `dim` to match `mxfp`.
  TypedValue<RankedTensorType> maskNan(PatternRewriter &rewriter,
                                       DotScaledOp scaledDotOp, ModuleOp mod,
                                       TypedValue<RankedTensorType> mxfp,
                                       TypedValue<RankedTensorType> scale,
                                       int dim) const {
    auto loc = scale.getLoc();

    // Elementwise "scale code is 0xFF" predicate.
    auto scaleTy = scale.getType();
    auto constFF = rewriter.create<arith::ConstantOp>(
        loc, scaleTy,
        DenseElementsAttr::get(scaleTy,
                               APInt(scaleTy.getElementTypeBitWidth(), 0xff)));
    auto scaleIsNan = cast<TypedValue<RankedTensorType>>(
        rewriter
            .create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, scale,
                                   constFF)
            .getResult());

    // Broadcast the predicate to mxfp's shape and put it in mxfp's layout.
    auto cond = broadcastScale(rewriter, scaledDotOp, mod, scaleIsNan, dim);
    auto condTy = cond.getType();
    condTy = condTy.cloneWithEncoding(mxfp.getType().getEncoding());
    cond = rewriter.create<ConvertLayoutOp>(loc, condTy, cond);

    // Splat NaN in mxfp's element type.
    auto mxfpTy = mxfp.getType();
    auto nan = APFloat::getNaN(
        cast<FloatType>(mxfpTy.getElementType()).getFloatSemantics());
    auto constNan = rewriter.create<arith::ConstantOp>(
        loc, mxfpTy, DenseElementsAttr::get(mxfpTy, nan));

    auto result = rewriter.create<arith::SelectOp>(loc, cond, constNan, mxfp);
    return cast<TypedValue<RankedTensorType>>(result.getResult());
  }

  /// Produces operand `opIdx` (0 = A, 1 = B) of the decomposed dot: the value
  /// upcast to `computeType` and, when a scale is present, multiplied by the
  /// broadcast scale (plus NaN masking unless fast-math is set).
  TypedValue<RankedTensorType> scaleArg(PatternRewriter &rewriter,
                                        DotScaledOp scaledDotOp, int opIdx,
                                        FloatType computeType) const {
    auto v = opIdx == 0 ? scaledDotOp.getA() : scaledDotOp.getB();
    auto scale = opIdx == 0 ? scaledDotOp.getAScale() : scaledDotOp.getBScale();
    auto isFp4 =
        ScaleDotElemType::E2M1 ==
        (opIdx == 0 ? scaledDotOp.getAElemType() : scaledDotOp.getBElemType());
    auto fastMath = scaledDotOp.getFastMath();

    auto loc = v.getLoc();
    auto mod = scaledDotOp->getParentOfType<ModuleOp>();
    auto rank = v.getType().getRank();
    // K is the last dim of A and the second-to-last dim of B.
    auto kDim = opIdx == 0 ? rank - 1 : rank - 2;

    // 0) Upcast the value to computeType. FP4 (E2M1) is unpacked along kDim.
    if (isFp4) {
      v = rewriter.create<Fp4ToFpOp>(loc, v, computeType, kDim);
    } else {
      auto vType16 = v.getType().clone(computeType);
      v = cast<TypedValue<RankedTensorType>>(
          rewriter.create<FpToFpOp>(loc, vType16, v).getResult());
    }
    if (!scale)
      return v;

    // The B scale has its K-group dimension last, as for A. Swap its last two
    // dims so that group dimension lines up with kDim (= rank - 2) for B.
    if (opIdx == 1) {
      auto order = getTransposeOrder(rank);
      scale = rewriter.create<TransOp>(loc, scale, order);
    }

    // 1) Scale codes -> computeType multipliers.
    auto scale16 = scaleTo16(rewriter, scale, computeType);

    // 2) Broadcast along K and match v's type/layout.
    auto reshapeScale =
        broadcastScale(rewriter, scaledDotOp, mod, scale16, kDim);
    reshapeScale =
        rewriter.create<ConvertLayoutOp>(loc, v.getType(), reshapeScale);

    // 3) Apply the scale.
    auto mxfp = cast<TypedValue<RankedTensorType>>(
        rewriter.create<arith::MulFOp>(loc, v, reshapeScale).getResult());

    if (fastMath)
      return mxfp;

    // 4) Propagate NaN scales (code 0xFF) into the result.
    return maskNan(rewriter, scaledDotOp, mod, mxfp, scale, kDim);
  }
};
}
namespace mlir::triton::gpu {

/// Registers the DecomposeScaledBlocked rewrite pattern in `patterns` with the
/// given pattern `benefit`.
void populateDecomposeScaledBlockedPatterns(RewritePatternSet &patterns,
                                            int benefit) {
  MLIRContext *context = patterns.getContext();
  patterns.add<DecomposeScaledBlocked>(context, benefit);
}

} // namespace mlir::triton::gpu