#include <memory>
#include <numeric>
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
namespace mlir {
namespace triton {
namespace gpu {
#define GEN_PASS_DEF_TRITONGPUOPTIMIZETHREADLOCALITY
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
namespace {
struct OptimizeReshapeLayoutPattern : public OpRewritePattern<ReshapeOp> {
OptimizeReshapeLayoutPattern(MLIRContext *context)
: OpRewritePattern<ReshapeOp>(context, 1) {}
LogicalResult matchAndRewrite(ReshapeOp viewOp,
PatternRewriter &rewriter) const override {
if (!viewOp.getAllowReorder())
return failure();
std::optional<int> reductionAxis;
for (Operation *user : viewOp.getResult().getUsers()) {
if (auto reduceOp = dyn_cast<triton::ReduceOp>(user)) {
if (reductionAxis) {
if (reductionAxis != reduceOp.getAxis())
return failure();
} else {
reductionAxis = reduceOp.getAxis();
}
}
}
if (!reductionAxis)
return failure();
RankedTensorType tensorType = viewOp.getType();
if (auto blocked =
mlir::dyn_cast<BlockedEncodingAttr>(tensorType.getEncoding())) {
if (blocked.getThreadsPerWarp()[*reductionAxis] == 1 &&
blocked.getWarpsPerCTA()[*reductionAxis] == 1 &&
blocked.getCTAsPerCGA()[*reductionAxis] == 1)
return failure();
}
ArrayRef<int64_t> shape = tensorType.getShape();
SmallVector<unsigned> order;
for (int i : triton::gpu::getOrder(tensorType)) {
if (i != *reductionAxis)
order.push_back(i);
}
order.push_back(*reductionAxis);
SmallVector<unsigned> sizePerThread(shape.size(), 1);
auto mod = viewOp->getParentOfType<ModuleOp>();
int numWarps = lookupNumWarps(viewOp);
int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod);
int numCTAs = TritonGPUDialect::getNumCTAs(mod);
auto encoding =
BlockedEncodingAttr::get(viewOp.getContext(), shape, sizePerThread,
order, numWarps, threadsPerWarp, numCTAs);
if (encoding == tensorType.getEncoding())
return failure();
RankedTensorType newType =
RankedTensorType::get(shape, tensorType.getElementType(), encoding);
if (triton::gpu::isExpensiveView(viewOp.getSrc().getType(), newType))
return failure();
rewriter.setInsertionPointAfter(viewOp);
rewriter.modifyOpInPlace(viewOp, [&]() {
viewOp.getResult().setType(newType);
viewOp.setEfficientLayout(true);
});
auto cvt = rewriter.create<ConvertLayoutOp>(viewOp.getLoc(), tensorType,
viewOp.getResult());
rewriter.replaceAllUsesExcept(viewOp.getResult(), cvt.getResult(), cvt);
return success();
}
};
}
static LogicalResult setOptimizedGatherLayout(GatherOp op, RewriterBase &b) {
RankedTensorType srcType = op.getSrc().getType();
RankedTensorType idxType = op.getIndices().getType();
unsigned numThreadsPerWarp = lookupThreadsPerWarp(b);
unsigned numWarps = lookupNumWarps(op);
unsigned axis = op.getAxis();
unsigned rank = srcType.getRank();
if (rank == 1)
return failure();
SmallVector<unsigned> threadsPerWarp(rank);
SmallVector<unsigned> warpsPerCTA(rank);
SmallVector<unsigned> order;
order.push_back(axis);
unsigned maxThreadsInAxis =
std::min<unsigned>(srcType.getDimSize(axis), numThreadsPerWarp);
threadsPerWarp[axis] = maxThreadsInAxis;
unsigned threadsToAlloc = numThreadsPerWarp / maxThreadsInAxis;
for (unsigned dim : getThreadOrder(srcType)) {
if (dim == axis)
continue;
order.push_back(dim);
unsigned nextThreadAlloc =
std::min<unsigned>(srcType.getDimSize(dim), threadsToAlloc);
threadsPerWarp[dim] = nextThreadAlloc;
threadsToAlloc /= nextThreadAlloc;
}
assert(llvm::none_of(threadsPerWarp, [](unsigned c) { return c == 0; }));
warpsPerCTA[axis] = 1;
unsigned warpsToAlloc = numWarps;
for (unsigned dim : getWarpOrder(srcType)) {
if (dim == axis)
continue;
unsigned warpsCanFit = srcType.getDimSize(dim) / threadsPerWarp[dim];
assert(warpsCanFit != 0);
unsigned nextWarpAlloc = std::min<unsigned>(warpsCanFit, warpsToAlloc);
warpsPerCTA[dim] = nextWarpAlloc;
warpsToAlloc /= nextWarpAlloc;
}
assert(llvm::none_of(warpsPerCTA, [](unsigned c) { return c == 0; }));
SmallVector<unsigned> sizePerThread(rank, 1);
sizePerThread[axis] = srcType.getDimSize(axis) / threadsPerWarp[axis];
threadsPerWarp[axis] *= threadsToAlloc;
warpsPerCTA[axis] *= warpsToAlloc;
assert(product(threadsPerWarp) == numThreadsPerWarp);
assert(product(warpsPerCTA) == numWarps);
MLIRContext *ctx = srcType.getContext();
auto baseLayout = cast<LayoutEncodingTrait>(srcType.getEncoding());
auto ctaLayout =
CTALayoutAttr::get(ctx, baseLayout.getCTAsPerCGA(),
baseLayout.getCTASplitNum(), baseLayout.getCTAOrder());
auto newLayout = BlockedEncodingAttr::get(ctx, sizePerThread, threadsPerWarp,
warpsPerCTA, order, ctaLayout);
auto cvtSrc = b.create<ConvertLayoutOp>(
op.getLoc(), srcType.cloneWithEncoding(newLayout), op.getSrc());
auto cvtIdx = b.create<ConvertLayoutOp>(
op.getLoc(), idxType.cloneWithEncoding(newLayout), op.getIndices());
b.setInsertionPointAfter(op);
auto cvtOut =
b.create<ConvertLayoutOp>(op.getLoc(), op.getType(), op.getResult());
b.replaceAllUsesExcept(op.getResult(), cvtOut, cvtOut);
b.modifyOpInPlace(op, [&] {
op.getSrcMutable().set(cvtSrc);
op.getIndicesMutable().set(cvtIdx);
op.getResult().setType(op.getType().cloneWithEncoding(newLayout));
op.setEfficientLayout(true);
});
assert(GatherLoweringHelper(op).isWarpLocal());
return success();
}
namespace {
struct OptimizeGatherLayoutPattern : public mlir::OpRewritePattern<GatherOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(GatherOp op,
PatternRewriter &rewriter) const override {
if (op.getEfficientLayout())
return failure();
return setOptimizedGatherLayout(op, rewriter);
}
};
}
namespace {
class TritonGPUOptimizeThreadLocalityPass
: public impl::TritonGPUOptimizeThreadLocalityBase<
TritonGPUOptimizeThreadLocalityPass> {
void runOnOperation() override {
ModuleOp mod = getOperation();
mlir::RewritePatternSet layoutPatterns(&getContext());
layoutPatterns.add<OptimizeReshapeLayoutPattern>(&getContext());
layoutPatterns.add<OptimizeGatherLayoutPattern>(&getContext());
if (mlir::applyPatternsGreedily(mod, std::move(layoutPatterns)).failed()) {
signalPassFailure();
}
DenseSet<triton::ReduceOp> reduceOps;
mod.walk([&](triton::ReduceOp reduce) -> void {
auto srcType = cast<RankedTensorType>(reduce.getOperands()[0].getType());
auto rank = srcType.getShape().size();
auto srcEncoding = srcType.getEncoding();
auto reductionOp = getReductionOp(reduce);
if (!reductionOp ||
!isa<arith::AddFOp, arith::MulFOp, arith::MaximumFOp,
arith::MaxNumFOp, arith::MinimumFOp, arith::MinNumFOp>(
reductionOp.value()))
return;
if (!(isa<triton::gpu::BlockedEncodingAttr>(srcEncoding) && rank > 1))
return;
if (reduce.getAxis() != rank - 1)
return;
for (auto operand : reduce->getOperands()) {
if (!operand.getDefiningOp<triton::LoadOp>())
return;
}
auto elemsPerThread =
triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()];
if (elemsPerThread == 1)
return;
if (!reduce->hasOneUse())
return;
Operation *user = *(reduce->getUsers().begin());
if (!user->hasOneUse())
return;
OpOperand &yieldOpOperand = *(user->getUses().begin());
auto yieldOp = dyn_cast<scf::YieldOp>(yieldOpOperand.getOwner());
if (!yieldOp)
return;
auto operandNumber = yieldOpOperand.getOperandNumber();
Block *block = reduce->getBlock();
Operation *parentOp = block->getParentOp();
auto forOp = dyn_cast<scf::ForOp>(parentOp);
if (!forOp)
return;
auto argNum = yieldOpOperand.getOperandNumber();
auto oldAccum = forOp.getInitArgs()[argNum];
auto cstOp = oldAccum.getDefiningOp<arith::ConstantOp>();
if (!cstOp)
return;
reduceOps.insert(reduce);
});
IRRewriter builder(&getContext());
for (auto reduce : reduceOps) {
builder.setInsertionPoint(reduce);
auto srcType = cast<RankedTensorType>(reduce.getOperands()[0].getType());
auto srcShape = srcType.getShape();
auto srcEncoding = srcType.getEncoding();
assert(isa<triton::gpu::BlockedEncodingAttr>(srcEncoding) &&
"Thread locality optimization only supports blocked encoding");
auto blocked = dyn_cast<triton::gpu::BlockedEncodingAttr>(srcEncoding);
auto elemsPerThread =
triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()];
auto rank = srcShape.size();
auto blocked3d = getThreadLocalityOptimizedEncoding(reduce);
auto viewOpTensorShape = getThreadLocalityOptimizedShape(reduce);
auto viewOpTensorType = RankedTensorType::get(
viewOpTensorShape, srcType.getElementType(), blocked3d);
auto slice2d = triton::gpu::SliceEncodingAttr::get(mod.getContext(), rank,
blocked3d);
assert(reduce->hasOneUse());
OpOperand &use = *(reduce->getUses().begin());
auto operandNumber = use.getOperandNumber();
auto oldUpdate = use.getOwner();
assert(oldUpdate->getNumOperands() == 2);
auto accumOperandNumber = (operandNumber == 0) ? 1 : 0;
auto accumOperand = oldUpdate->getOperand(accumOperandNumber);
assert(isa<BlockArgument>(accumOperand));
auto blockArg = dyn_cast<BlockArgument>(accumOperand);
auto blockArgNum = blockArg.getArgNumber();
auto forOp = dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp());
auto oldAccum =
forOp.getInitArgs()[blockArgNum - forOp.getNumInductionVars()];
Value loopResult =
forOp.getResult(blockArgNum - forOp.getNumInductionVars());
assert(loopResult.hasOneUse());
OpOperand &loopUse = *(loopResult.getUses().begin());
Operation *loopUser = loopUse.getOwner();
auto oldYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
auto newAccum =
createAccum(builder, reduce, oldAccum, viewOpTensorShape, slice2d);
auto newLoop = replaceForOpWithNewSignature(
builder, forOp, ValueRange{newAccum->getResult(0)});
auto newReduce = createReduce(builder, reduce, viewOpTensorType);
auto newUpdate = createUpdate(builder, newLoop, newReduce, oldUpdate);
auto newYield = createYield(builder, newLoop, oldYield,
newUpdate->getResult(0), blockArgNum);
auto newReduce2 = createPostLoopReduce(builder, newLoop, reduce);
Type destType = loopResult.getType();
auto cvtLayout = createConvertLayout(builder, destType, newReduce2);
auto finalOp = incorporateOriginalAccumulatorValue(builder, oldUpdate,
cvtLayout, oldAccum);
loopUser->setOperand(loopUse.getOperandNumber(), finalOp->getResult(0));
oldYield.erase();
forOp.erase();
}
};
private:
std::optional<Operation *> getReductionOp(triton::ReduceOp reduce) const {
auto numRegions = reduce->getNumRegions();
if (numRegions != 1)
return std::nullopt;
Region ®ion = reduce->getRegion(0);
auto numBlocks = region.getBlocks().size();
if (numBlocks != 1)
return std::nullopt;
Block &block = region.front();
auto blockWithoutTerminator = block.without_terminator();
auto blockSizeWithoutTerminator = std::distance(
blockWithoutTerminator.begin(), blockWithoutTerminator.end());
if (blockSizeWithoutTerminator != 1)
return std::nullopt;
Operation *op = &block.front();
return std::optional<Operation *>(op);
}
Operation *incorporateOriginalAccumulatorValue(OpBuilder &builder,
Operation *oldUpdate,
Operation *cvtLayout,
Value oldAccum) const {
builder.setInsertionPointAfter(cvtLayout);
IRMapping mapping;
mapping.map(oldUpdate->getOperand(0), oldAccum);
mapping.map(oldUpdate->getOperand(1), cvtLayout->getResult(0));
auto finalOp = cloneWithInferType(builder, &(*oldUpdate), mapping);
return finalOp;
}
Operation *createConvertLayout(OpBuilder &builder, Type destType,
Operation *newReduce) const {
builder.setInsertionPointAfter(newReduce);
auto newCvt = builder.create<triton::gpu::ConvertLayoutOp>(
newReduce->getLoc(), destType, newReduce->getResult(0));
return newCvt;
}
Operation *createPostLoopReduce(OpBuilder &builder, scf::ForOp &loop,
triton::ReduceOp &reduce) const {
auto resultIndex =
loop.getBody()->getNumArguments() - 1 - loop.getNumInductionVars();
auto newLoopResult = loop.getResult(resultIndex);
builder.setInsertionPointAfter(loop);
IRMapping mapping;
mapping.map(*(reduce.getOperands().begin()), newLoopResult);
auto newReduce2 = cloneWithInferType(builder, &(*reduce), mapping);
return newReduce2;
}
Operation *createYield(OpBuilder &builder, scf::ForOp &loop,
scf::YieldOp &oldYield, Value newUpdate,
int oldAccumBlockArgNum) const {
builder.setInsertionPoint(oldYield);
SmallVector<Value> yieldValues = llvm::to_vector(oldYield.getOperands());
yieldValues[oldAccumBlockArgNum - 1] =
loop.getBody()->getArgument(oldAccumBlockArgNum);
yieldValues.push_back(newUpdate);
auto newYield =
builder.create<scf::YieldOp>(oldYield.getLoc(), yieldValues);
return newYield;
}
Operation *createUpdate(OpBuilder &builder, scf::ForOp &loop,
Operation *newReduce, Operation *oldUpdate) const {
auto blockArgNum = loop.getBody()->getNumArguments() - 1;
auto newArg = loop.getBody()->getArgument(blockArgNum);
builder.setInsertionPointAfter(newReduce);
IRMapping mapping;
mapping.map(oldUpdate->getOperand(0), newArg);
mapping.map(oldUpdate->getOperand(1), newReduce->getResult(0));
auto newUpdate = cloneWithInferType(builder, oldUpdate, mapping);
return newUpdate;
}
Operation *createReduce(OpBuilder &builder, triton::ReduceOp reduce,
Type viewOpTensorType) const {
auto srcType = cast<RankedTensorType>(reduce.getOperands()[0].getType());
auto rank = srcType.getShape().size();
builder.setInsertionPointAfter(reduce);
IRMapping mapping;
for (auto operand : reduce.getOperands()) {
auto viewOp = builder.create<triton::ReshapeOp>(
reduce.getLoc(), viewOpTensorType, operand,
true, true);
mapping.map(operand, viewOp);
}
auto newReduce = cloneWithInferType(builder, &(*reduce), mapping);
newReduce->setAttr("axis", builder.getI32IntegerAttr(rank));
auto typeInfer = dyn_cast<InferTypeOpInterface>(newReduce);
if (typeInfer) {
SmallVector<Type, 1> newTypes;
auto success = typeInfer.inferReturnTypes(
newReduce->getContext(), newReduce->getLoc(),
newReduce->getOperands(), newReduce->getAttrDictionary(),
newReduce->getPropertiesStorage(), newReduce->getRegions(), newTypes);
if (succeeded(success)) {
for (size_t i = 0; i < newTypes.size(); i++)
newReduce->getResult(i).setType(newTypes[i]);
}
}
return newReduce;
}
std::optional<TypedAttr> getNeutralElement(Operation *op) const {
if (isa<arith::MaxNumFOp, arith::MinNumFOp>(op)) {
OpBuilder builder(op->getContext());
Type resultType = op->getResult(0).getType();
const llvm::fltSemantics &semantic =
llvm::cast<FloatType>(resultType).getFloatSemantics();
if (isa<arith::MaxNumFOp>(op)) {
return builder.getFloatAttr(
resultType, APFloat::getInf(semantic, true));
}
if (isa<arith::MinNumFOp>(op)) {
return builder.getFloatAttr(
resultType, APFloat::getInf(semantic, false));
}
} else {
return mlir::arith::getNeutralElement(op);
}
llvm_unreachable("Unhandled reduction op");
return std::nullopt;
}
Operation *createAccum(OpBuilder &builder, triton::ReduceOp reduce,
Value &oldAccum, SmallVector<int64_t> &shape,
Attribute &slice2d) const {
SmallVector<int64_t> accumShape(shape.begin(), shape.end() - 1);
auto elemType = cast<RankedTensorType>(oldAccum.getType()).getElementType();
auto accumType = RankedTensorType::get(accumShape, elemType, slice2d);
builder.setInsertionPointAfter(oldAccum.getDefiningOp());
auto reductionOp = getReductionOp(reduce);
assert(reductionOp && "Processing a reduce that is not supported!");
auto neutralVal = getNeutralElement(reductionOp.value());
assert(neutralVal && "Could not find neutral value for reduction op!");
auto denseAttr = DenseElementsAttr::get(accumType, neutralVal.value());
auto newAccum = builder.create<arith::ConstantOp>(oldAccum.getLoc(),
accumType, denseAttr);
return newAccum;
}
SmallVector<int64_t>
getThreadLocalityOptimizedShape(triton::ReduceOp reduce) const {
auto srcType = cast<RankedTensorType>(reduce.getOperands()[0].getType());
auto srcShape = srcType.getShape();
auto rank = srcShape.size();
auto elemsPerThread =
triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()];
auto viewOpTensorShape = insertValue(srcShape, rank, 1);
viewOpTensorShape[reduce.getAxis()] /= elemsPerThread;
viewOpTensorShape[rank] = elemsPerThread;
return viewOpTensorShape;
}
BlockedEncodingAttr
getThreadLocalityOptimizedEncoding(triton::ReduceOp reduce) const {
auto srcType = cast<RankedTensorType>(reduce.getOperands()[0].getType());
auto rank = srcType.getShape().size();
auto srcEncoding = srcType.getEncoding();
auto blocked = dyn_cast<triton::gpu::BlockedEncodingAttr>(srcEncoding);
auto sizePerThread3d =
insertValue(blocked.getSizePerThread(), rank,
blocked.getSizePerThread()[reduce.getAxis()]);
sizePerThread3d[reduce.getAxis()] = 1;
auto threadsPerWarp3d = insertValue(blocked.getThreadsPerWarp(), rank, 1);
auto warsPerCTA3d = insertValue(blocked.getWarpsPerCTA(), rank, 1);
auto order3d = insertValue(blocked.getOrder(), 0, rank);
auto ctasPerCGA3d =
insertValue(blocked.getCTALayout().getCTAsPerCGA(), rank, 1);
auto ctasSplitNum3d =
insertValue(blocked.getCTALayout().getCTASplitNum(), rank, 1);
auto ctaOrder3d =
insertValue(blocked.getCTALayout().getCTAOrder(), rank, rank);
auto ctaLayout3d = triton::gpu::CTALayoutAttr::get(
reduce.getContext(), ctasPerCGA3d, ctasSplitNum3d, ctaOrder3d);
auto blocked3d = triton::gpu::BlockedEncodingAttr::get(
reduce.getContext(), sizePerThread3d, threadsPerWarp3d, warsPerCTA3d,
order3d, ctaLayout3d);
return blocked3d;
}
template <typename T>
SmallVector<T> insertValue(ArrayRef<T> vec, unsigned index, int value) const {
SmallVector<T> res(vec.begin(), vec.end());
res.insert(res.begin() + index, static_cast<T>(value));
return res;
}
template <typename T>
SmallVector<T> insertValue(const SmallVector<T> &vec, unsigned index,
int value) const {
SmallVector<T> res(vec.begin(), vec.end());
res.insert(res.begin() + index, static_cast<T>(value));
return res;
}
};
}
}
}
}