#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
#include "triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h"
#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"
#include "triton/Dialect/TritonGPU/Transforms/Schedule.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
#include "triton/Tools/StrUtil.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#define DEBUG_TYPE "triton-loop-pipeline"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
using namespace mlir;
namespace tt = mlir::triton;
namespace ttg = mlir::triton::gpu;
namespace ttng = mlir::triton::nvidia_gpu;
namespace mlir {
namespace triton {
namespace gpu {
namespace {
/// Reads the self-latency attribute attached to `op` and strips it from the
/// op. Returns 0 when the attribute is not present (nothing is removed then).
int getSelfLatencyFromAttr(Operation *op) {
  auto parentModule = op->getParentOfType<ModuleOp>();
  auto latencyHelper =
      TritonDialect::getLoaded(parentModule)->getSelfLatencyAttrHelper();
  if (latencyHelper.isAttrPresent(op)) {
    const int latency = latencyHelper.getAttr(op).getInt();
    // The attribute is consumed: once read it is removed from the op.
    latencyHelper.removeAttr(op);
    return latency;
  }
  return 0;
}
/// Returns true when the result of the load-like `op` has to go through
/// registers, i.e. when it cannot be consumed directly by its local_alloc.
///
/// This is the case when:
///  - `op` is a tt.load with an "other" operand that is not a zero constant;
///  - `op` does not have exactly one use, or that use is not a local_alloc;
///  - `op` is a descriptor load/gather whose descriptor-derived encoding
///    differs from the encoding of the local_alloc result type.
bool mustLoadToRegisters(Operation *op) {
  if (auto plainLoad = dyn_cast<tt::LoadOp>(op)) {
    Value otherVal = plainLoad.getOther();
    if (otherVal && !isZeroConst(otherVal))
      return true;
  }

  if (!op->hasOneUse())
    return true;
  auto localAlloc = dyn_cast<ttg::LocalAllocOp>(*op->getUsers().begin());
  if (!localAlloc)
    return true;

  // Only descriptor-based loads carry an encoding to compare; for any other
  // op the encoding stays null and the load can stay in shared memory.
  Attribute descEncoding;
  if (auto descLoad = dyn_cast<DescriptorLoadOp>(op)) {
    descEncoding = nvidia_gpu::getEncodingFromDescriptor(
        op, descLoad.getType(), descLoad.getDesc());
  } else if (auto descGather = dyn_cast<DescriptorGatherOp>(op)) {
    descEncoding = nvidia_gpu::getEncodingFromDescriptor(
        op, descGather.getType(), descGather.getDesc());
  }
  if (!descEncoding)
    return false;
  return descEncoding != localAlloc.getType().getEncoding();
}
/// Computes how many pipeline stages separate `op` from its use inside
/// `forOp`, according to `schedule`.
///
/// The "use stage" is the earliest stage among the top-level users of `op`.
/// For load-like ops, the top-level users of any local_alloc user are counted
/// as users too. If any of the users is a wait_barrier, the latest stage among
/// the wait_barrier users is taken instead. Returns 0 when there are no users.
int getDefUseStageDiff(Operation *op, scf::ForOp forOp,
                       CoarseSchedule &schedule) {
  assert(schedule.count(op) && "Op not found in the schedule");
  const int defStage = schedule[op].first;

  DenseSet<Operation *> users = triton::getTopLevelUsersInLoop(op, forOp);
  if (isa<tt::LoadOp, tt::DescriptorLoadOp, tt::DescriptorGatherOp>(op)) {
    // Gather the indirect users first so the set is not modified while it is
    // being iterated.
    SmallVector<Operation *> usersThroughAlloc;
    for (Operation *user : users) {
      auto localAlloc = dyn_cast<ttg::LocalAllocOp>(user);
      if (!localAlloc)
        continue;
      DenseSet<Operation *> allocUsers =
          triton::getTopLevelUsersInLoop(localAlloc, forOp);
      usersThroughAlloc.append(allocUsers.begin(), allocUsers.end());
    }
    users.insert(usersThroughAlloc.begin(), usersThroughAlloc.end());
  }

  std::optional<int> earliestUseStage;
  std::optional<int> latestWaitStage;
  for (Operation *user : users) {
    const int stage = schedule[user].first;
    earliestUseStage =
        earliestUseStage ? std::min(*earliestUseStage, stage) : stage;
    if (isa<ttng::WaitBarrierOp>(user))
      latestWaitStage =
          latestWaitStage ? std::max(*latestWaitStage, stage) : stage;
  }
  if (!earliestUseStage)
    return 0;

  // wait_barrier users override the minimum: the latest wait wins.
  const int useStage = latestWaitStage
                           ? std::max(*earliestUseStage, *latestWaitStage)
                           : *earliestUseStage;
  assert(useStage >= defStage && "Op used before defined");
  return useStage - defStage;
}
/// Replaces with `newValue` every use of `oldValue` whose owner operation is
/// properly dominated by `domOp`. Does nothing when both values are the same.
void replaceAllUsesDominatedBy(Operation *domOp, Value newValue, Value oldValue,
                               DominanceInfo &domInfo) {
  if (newValue != oldValue) {
    auto isDominatedUse = [&](OpOperand &operand) {
      return domInfo.properlyDominates(domOp, operand.getOwner());
    };
    oldValue.replaceUsesWithIf(newValue, isDominatedUse);
  }
}
/// Convenience wrapper around triton::createAlloc that takes the tensor type
/// and location from the first result / location of `loadOp`.
static Value createAlloc(scf::ForOp &forOp, Operation *loadOp,
                         ttg::SharedEncodingTrait sharedEnc,
                         unsigned distance) {
  auto loadedTy = cast<RankedTensorType>(loadOp->getResultTypes().front());
  Location loadLoc = loadOp->getLoc();
  return triton::createAlloc(forOp, loadedTy, loadLoc, sharedEnc, distance);
}
/// Replaces `loadOp` with an async global-to-local copy into the slot of
/// `alloc` selected by `insertIdx`, followed (in the stage/cluster of the
/// first use) by an async wait and a read from the slot selected by
/// `extractIdx`. `loadOp` is removed from `schedule` and erased.
void createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,
                     Value insertIdx, Value extractIdx,
                     CoarseSchedule &schedule) {
  OpBuilderForStage builder(loadOp.getLoc(), forOp, schedule);
  // This constant is emitted but never referenced afterwards; it is kept so
  // that the produced IR is unchanged.
  [[maybe_unused]] Value zeroI32 =
      builder.create<arith::ConstantIntOp>(forOp.getLoc(), 0, 32);

  Operation *earliestUse = getFirstUseOfPipelinedOp({loadOp}, forOp, schedule);
  assert(earliestUse && "LoadOp has no users");

  OpBuilder::InsertionGuard restoreInsertionPoint(builder);
  builder.setInsertionPoint(loadOp);
  builder.setStageCluster(schedule[loadOp]);
  Value srcPtr = loadOp.getPtr();
  Value loadMask = loadOp.getMask();
  Value otherVal = loadOp.getOther();
  // Checked cast only: verifies that `alloc` has a memdesc type.
  [[maybe_unused]] auto bufferTy = cast<ttg::MemDescType>(alloc.getType());

  // Producer side: copy into the insert slot and commit the group.
  Value writeSlot = createSingleBufferView(builder, alloc, insertIdx);
  Operation *asyncCopy = builder.create<ttg::AsyncCopyGlobalToLocalOp>(
      srcPtr, writeSlot, loadMask, otherVal, loadOp.getCache(),
      loadOp.getEvict(), loadOp.getIsVolatile());
  Operation *commitGroup =
      builder.create<ttg::AsyncCommitGroupOp>(asyncCopy->getResult(0));

  // Consumer side: wait and read back, scheduled with the first use.
  builder.setStageCluster(schedule[earliestUse]);
  auto asyncWait =
      builder.create<ttg::AsyncWaitOp>(commitGroup->getResult(0), 0);
  auto readSlot = createSingleBufferView(builder, alloc, extractIdx);

  const bool otherIsAbsentOrZero = !otherVal || isZeroConst(otherVal);
  if (otherIsAbsentOrZero) {
    replaceUsesWithLocalLoad(builder, loadOp->getResult(0), readSlot,
                             asyncWait.getResult());
  } else if (!loadOp->use_empty()) {
    // A non-zero "other" is re-applied explicitly with a select on the mask
    // of the original load.
    auto localLoad = builder.create<ttg::LocalLoadOp>(
        loadOp.getType(), readSlot, asyncWait.getResult());
    auto maskedValue = builder.create<arith::SelectOp>(
        loadOp.getType(), loadOp.getMask(), localLoad.getResult(), otherVal);
    loadOp->replaceAllUsesWith(maskedValue->getResults());
  }
  schedule.erase(loadOp);
  loadOp->erase();
}
/// Shared lowering for descriptor-based loads: emits the copy through
/// `createCopy` into the slot of `alloc` selected by `insertIdx`, then, right
/// after `waitOp` and in the stage/cluster of the first use, replaces the uses
/// of `loadOp` with a read from the slot selected by `extractIdx`. `loadOp` is
/// removed from `schedule` and erased.
///
/// `createCopy` is invoked as createCopy(builder, desc, barrier, view, pred),
/// where `pred` is a constant-true i1.
void createTMAAsyncCopy(
    scf::ForOp forOp, Operation *loadOp, Value desc, Value alloc,
    Value insertIdx, Value extractIdx, Value barrier, Operation *waitOp,
    CoarseSchedule &schedule,
    function_ref<void(OpBuilderForStage &, Value, Value, Value, Value)>
        createCopy) {
  OpBuilderForStage builder(loadOp->getLoc(), forOp, schedule);
  // This constant is emitted but never referenced afterwards; it is kept so
  // that the produced IR is unchanged.
  [[maybe_unused]] Value zeroI32 =
      builder.create<arith::ConstantIntOp>(forOp.getLoc(), 0, 32);

  Operation *earliestUse = getFirstUseOfPipelinedOp({loadOp}, forOp, schedule);
  assert(earliestUse && "LoadOp has no users");
  [[maybe_unused]] Attribute smemSpace =
      ttg::SharedMemorySpaceAttr::get(forOp.getContext());

  builder.setInsertionPoint(loadOp);
  builder.setStageCluster(schedule[loadOp]);
  // Checked cast only: verifies that `alloc` has a memdesc type.
  [[maybe_unused]] auto bufferTy = cast<ttg::MemDescType>(alloc.getType());

  // Producer side: issue the copy into the insert slot.
  Value writeSlot = createSingleBufferView(builder, alloc, insertIdx);
  Value alwaysTrue = builder.create<arith::ConstantIntOp>(1, 1);
  createCopy(builder, desc, barrier, writeSlot, alwaysTrue);

  // Consumer side: read the extract slot once the wait has been passed.
  builder.setInsertionPointAfter(waitOp);
  builder.setStageCluster(schedule[earliestUse]);
  auto readSlot = createSingleBufferView(builder, alloc, extractIdx);
  replaceUsesWithLocalLoad(builder, loadOp->getResult(0), readSlot);

  schedule.erase(loadOp);
  loadOp->erase();
}
/// Lowers a descriptor load through createTMAAsyncCopy, emitting an
/// AsyncTMACopyGlobalToLocalOp with the translated indices of `loadOp`.
void createTMAAsyncLoad(scf::ForOp forOp, tt::DescriptorLoadOp loadOp,
                        Value alloc, Value insertIdx, Value extractIdx,
                        Value barrier, Operation *waitOp,
                        CoarseSchedule &schedule) {
  auto emitCopy = [&](OpBuilderForStage &copyBuilder, Value tmaDesc,
                      Value copyBarrier, Value dstView, Value pred) {
    auto tmaIndices = ttng::translateTMAIndices(
        copyBuilder, loadOp.getLoc(),
        loadOp.getDesc().getType().getBlockType().getEncoding(),
        loadOp.getIndices());
    copyBuilder.create<ttng::AsyncTMACopyGlobalToLocalOp>(
        loadOp.getLoc(), tmaDesc, tmaIndices, copyBarrier, dstView, pred);
  };
  createTMAAsyncCopy(forOp, loadOp, loadOp.getDesc(), alloc, insertIdx,
                     extractIdx, barrier, waitOp, schedule, emitCopy);
}
/// Lowers a descriptor gather through createTMAAsyncCopy, emitting an
/// AsyncTMAGatherOp with the x/y offsets of `gatherOp`.
void createTMAAsyncGather(scf::ForOp forOp, tt::DescriptorGatherOp gatherOp,
                          Value alloc, Value insertIdx, Value extractIdx,
                          Value barrier, Operation *waitOp,
                          CoarseSchedule &schedule) {
  auto emitGather = [&](OpBuilderForStage &copyBuilder, Value tmaDesc,
                        Value copyBarrier, Value dstView, Value pred) {
    copyBuilder.create<ttng::AsyncTMAGatherOp>(
        gatherOp.getLoc(), tmaDesc, gatherOp.getXOffsets(),
        gatherOp.getYOffset(), copyBarrier, dstView, pred);
  };
  createTMAAsyncCopy(forOp, gatherOp, gatherOp.getDesc(), alloc, insertIdx,
                     extractIdx, barrier, waitOp, schedule, emitGather);
}
/// Per-load bookkeeping used while lowering pipelined loads.
struct AsyncLoad {
  /// Stage distance between the load and its use; also used as the number of
  /// buffers of the allocation backing this load.
  int stageDiff = 0;
  /// Multi-buffered allocation the load is copied into.
  Value alloc;
  /// Barrier view used by descriptor-based loads (left null otherwise).
  Value barrier;
  /// Wait operation guarding the loaded data for descriptor-based loads
  /// (left null otherwise).
  Operation *waitOp = nullptr;
  /// Shared-memory encoding selected for `alloc`.
  SharedEncodingTrait sharedEncoding;
};
// State shared by all async loads that use the same number of buffers.
// NOTE: the field order matters; a structured binding elsewhere in this file
// unpacks the fields in declaration order.
struct LoadGroupInfo {
// Index of the buffer slot written in the current iteration.
Value insertIdx;
// Index of the buffer slot read in the current iteration.
Value extractIdx;
// Barrier phase value; only set for groups that contain a TMA load.
Value phase;
// True when at least one load of the group is a TMA load.
bool hasTMALoad = false;
};
/// Rewrites the scalar tt.load `op` in place into a load of a 1-element
/// tensor: pointer, mask and "other" operands are splatted to 1-element
/// tensors with a default blocked encoding, the result type is changed
/// accordingly, and an unsplat (scheduled with the first use) feeds the
/// original users with a scalar again.
void convertScalarToTensorLoad(Operation *op, CoarseSchedule &schedule,
                               scf::ForOp forOp) {
  auto load = cast<tt::LoadOp>(op);
  Type originalScalarTy = load.getType();

  OpBuilderForStage builder(op->getLoc(), op, schedule);
  builder.setInsertionPoint(op);

  MLIRContext *ctx = op->getContext();
  auto nWarps = lookupNumWarps(op);
  ModuleOp mod = forOp->getParentOfType<ModuleOp>();
  auto threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod);
  auto numCTAs = TritonGPUDialect::getNumCTAs(mod);
  auto blockedEnc =
      getDefaultBlockedEncoding(ctx, {1}, nWarps, threadsPerWarp, numCTAs);

  // Type of a 1-element tensor holding `elemTy` with the blocked encoding.
  auto singleElemTensorOf = [&](Type elemTy) {
    return RankedTensorType::get({1}, elemTy, blockedEnc);
  };
  // Emits a splat of `scalar` into a 1-element tensor before the load.
  auto splatToTensor = [&](Value scalar) -> Value {
    return builder.create<tt::SplatOp>(
        op->getLoc(), singleElemTensorOf(scalar.getType()), scalar);
  };

  load.getPtrMutable().assign(splatToTensor(load.getPtr()));
  if (Value maskVal = load.getMask())
    load.getMaskMutable().assign(splatToTensor(maskVal));
  if (Value otherVal = load.getOther())
    load.getOtherMutable().assign(splatToTensor(otherVal));

  load.getResult().setType(singleElemTensorOf(load.getType()));

  // Convert back to a scalar for the existing users.
  builder.setInsertionPointAfter(op);
  Operation *earliestUse = getFirstUseOfPipelinedOp({op}, forOp, schedule);
  builder.setStageCluster(schedule[earliestUse]);
  Operation *unsplat = builder.create<tt::UnsplatOp>(
      op->getLoc(), originalScalarTy, load.getResult());
  load.getResult().replaceAllUsesExcept(unsplat->getResult(0), unsplat);
}
// Creates the barriers and waits for the TMA loads in `asyncLoads`.
//
// First, TMA loads are partitioned into groups that share one barrier; then,
// for every group, a barrier allocation, a barrier_expect covering the bytes
// of all loads in the group, and a wait_barrier are emitted. The barrier view
// and the wait op are recorded in the AsyncLoad entry of each load of the
// group.
//
// `loadGroups` is looked up by buffer count (the `stageDiff` of the first load
// of the group) to obtain insert/extract indices and the phase value.
void createTMABarrierAndWait(
scf::ForOp forOp, llvm::MapVector<Operation *, AsyncLoad> &asyncLoads,
const llvm::MapVector<int, LoadGroupInfo> &loadGroups,
CoarseSchedule &schedule) {
SmallVector<SmallVector<Operation *>> commonWaitGroups;
llvm::SmallDenseSet<Operation *> visited;
// Phase 1: build groups by scanning the ops that follow each not yet visited
// TMA load.
for (auto &[loadOp, asyncLoad] : asyncLoads) {
if (!isTMALoad(loadOp) || visited.count(loadOp))
continue;
// Ops that consume data of the loads collected so far; reaching one of them
// ends the group.
llvm::SmallDenseSet<Operation *> users;
SmallVector<Operation *> group;
Block *loadBlock = loadOp->getBlock();
// Adds a load to the current group and records its users.
auto addToGroup = [&](Operation *loadOp) {
group.push_back(loadOp);
visited.insert(loadOp);
for (Operation *user : loadOp->getUsers()) {
// When the load does not have to go through registers, its single user is
// a local_alloc; if that alloc lives in the same block, the users of the
// alloc are recorded instead of the alloc itself.
if (!mustLoadToRegisters(loadOp)) {
assert(loadOp->hasOneUse());
auto alloc = cast<ttg::LocalAllocOp>(*loadOp->getUsers().begin());
if (alloc->getBlock() == loadBlock) {
users.insert(alloc->getUsers().begin(), alloc->getUsers().end());
continue;
}
}
// Otherwise record the ancestor of the user that lives in the load's block
// (if there is one).
Operation *userInBlock = loadBlock->findAncestorOpInBlock(*user);
if (userInBlock)
users.insert(userInBlock);
}
};
addToGroup(loadOp);
Operation *nextOp = loadOp->getNextNode();
int numBuffers = asyncLoad.stageDiff;
while (nextOp) {
// Stop at the first user of the group or at an op that already belongs to
// a group.
if (users.count(nextOp) || visited.count(nextOp))
break;
if (isTMALoad(nextOp) && asyncLoads.count(nextOp)) {
// A TMA load with a different buffer count ends the group.
if (asyncLoads[nextOp].stageDiff != numBuffers)
break;
// Only loads with the same schedule entry as the first load of the group
// are merged; others are skipped without ending the scan.
if (group.size() > 0 && schedule[group[0]] == schedule[nextOp]) {
addToGroup(nextOp);
}
}
nextOp = nextOp->getNextNode();
}
commonWaitGroups.push_back(group);
}
// Phase 2: emit barrier allocation, expect and wait for each group.
for (SmallVector<Operation *> &group : commonWaitGroups) {
int sizeInBytes = 0;
int numBuffers = asyncLoads[group[0]].stageDiff;
const LoadGroupInfo loadGroup = loadGroups.find(numBuffers)->second;
// Total number of bytes transferred by the loads of the group.
for (Operation *op : group) {
auto tensorTy = cast<RankedTensorType>(op->getResultTypes()[0]);
int loadSize = product(getShapePerCTA(tensorTy));
sizeInBytes += loadSize * tensorTy.getElementTypeBitWidth() / 8;
}
Value barrierAlloc = triton::createBarrierAlloc(forOp, numBuffers);
// The expect is built with a builder anchored on the first load of the
// group and uses the insert slot of the barrier.
OpBuilderForStage builder(forOp.getLoc(), group[0], schedule);
Value barrier = triton::createSingleBufferView(builder, barrierAlloc,
loadGroup.insertIdx);
Value pred = builder.create<arith::ConstantIntOp>(1, 1);
builder.create<ttng::BarrierExpectOp>(barrier, sizeInBytes, pred);
// The wait goes after the last load of the group, in the stage/cluster of
// the first use of any load of the group, and uses the extract slot.
builder.setInsertionPointAfter(group.back());
Operation *firstUse = getFirstUseOfPipelinedOp(group, forOp, schedule);
builder.setStageCluster(schedule[firstUse]);
Value barrierViewWait = triton::createSingleBufferView(
builder, barrierAlloc, loadGroup.extractIdx);
auto wait =
builder.create<ttng::WaitBarrierOp>(barrierViewWait, loadGroup.phase);
// Publish the barrier (insert view) and the wait to every load of the group.
for (Operation *op : group) {
asyncLoads[op].barrier = barrier;
asyncLoads[op].waitOp = wait;
}
}
}
/// Returns true when `loadOp` feeds (through its single local_alloc user and
/// any chain of single-use memdesc view ops) a WarpGroupDotOp. Loads that must
/// go through registers never qualify.
bool loadRequiresAdditionalBuffer(Operation *loadOp) {
  if (mustLoadToRegisters(loadOp))
    return false;

  assert(loadOp->hasOneUse());
  auto alloc = dyn_cast<ttg::LocalAllocOp>(*loadOp->getUsers().begin());
  if (!alloc)
    return false;

  // Follows single-use memdesc view ops down to the op that consumes them.
  auto lookThroughViews = [](Operation *current) -> Operation * {
    while (current->hasOneUse() &&
           current->hasTrait<OpTrait::MemDescViewTrait>())
      current = *current->getUsers().begin();
    return current;
  };

  for (Operation *allocUser : alloc->getUsers()) {
    if (isa<ttng::WarpGroupDotOp>(lookThroughViews(allocUser)))
      return true;
  }
  return false;
}
// Lowers the pipelined loads found at the top level of `forOp` into
// asynchronous copies through multi-buffered allocations.
//
// Steps:
//  1. Collect the loads whose use is in a later stage than their definition
//     and decide whether they can become async copies or TMA copies.
//  2. Turn eligible scalar loads into 1-element tensor loads.
//  3. Create one allocation per load and one LoadGroupInfo per buffer count.
//  4. Add loop-carried insert/extract indices (and a phase for groups with TMA
//     loads) to the loop and compute their per-iteration updates.
//  5. Emit barriers/waits for TMA loads and replace every load with its
//     asynchronous form.
//  6. Patch the yield with the updated counters and schedule dependencies.
//
// Returns the (possibly new) loop; the input loop is returned unchanged when
// no load is lowered.
scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule,
triton::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
llvm::MapVector<Operation *, AsyncLoad> asyncLoads;
llvm::MapVector<int, LoadGroupInfo> loadGroups;
llvm::SmallVector<Operation *> scalarLoads;
// Only ops directly in the loop body are visited; nested regions are not
// inspected here.
for (auto &op : forOp.getBody()->without_terminator()) {
if (isa<tt::LoadOp, tt::DescriptorLoadOp, tt::DescriptorGatherOp>(op)) {
int stageDiff = getDefUseStageDiff(&op, forOp, schedule);
// Loads used in the same stage they are defined in are left alone.
if (stageDiff == 0) {
continue;
}
SharedEncodingTrait sharedEncoding;
bool canUseAsyncCp = false;
if (!isa<RankedTensorType>(op.getResultTypes()[0])) {
// Scalar load: eligible only when the element is at least 32 bits wide.
canUseAsyncCp = op.getResultTypes()[0].getIntOrFloatBitWidth() >= 32;
sharedEncoding = ttg::SwizzledSharedEncodingAttr::get(
forOp.getContext(), 1, 1, 1, {0},
ttg::CTALayoutAttr::get(forOp.getContext(), {1}, {1}, {0}));
if (canUseAsyncCp) {
scalarLoads.push_back(&op);
}
} else {
// Tensor load: must be a tt.load that can be converted to an async load
// and whose copy vector is at least 4 bytes.
sharedEncoding = getSharedEncoding(&op);
canUseAsyncCp =
isa<tt::LoadOp>(op) &&
canBeConvertedToAsyncLoad(cast<tt::LoadOp>(op), axisInfoAnalysis);
int copyVecBytes = getCopyVecBytes(
cast<RankedTensorType>(op.getResultTypes()[0]), sharedEncoding);
canUseAsyncCp &= copyVecBytes >= 4;
}
if (canUseAsyncCp || isTMALoad(&op)) {
// One extra buffer when the load feeds a warp-group dot.
if (loadRequiresAdditionalBuffer(&op)) {
stageDiff += 1;
}
auto &asyncLoad = asyncLoads[&op];
asyncLoad.stageDiff = stageDiff;
asyncLoad.sharedEncoding = sharedEncoding;
} else if (stageDiff > 1) {
// The load stays a regular load; only a remark is emitted for distances
// greater than one.
op.emitRemark() << "Pipelining load that cannot use vectorized "
"copy. This will likely "
"lead to pipelining in registers and severe "
"performance degradation.";
}
}
}
for (auto op : scalarLoads) {
convertScalarToTensorLoad(op, schedule, forOp);
}
if (asyncLoads.empty())
return forOp;
// One allocation per load; one load group per distinct buffer count.
for (auto &[loadOp, asyncLoad] : asyncLoads) {
Value alloc = createAlloc(forOp, loadOp, asyncLoad.sharedEncoding,
asyncLoad.stageDiff);
asyncLoad.alloc = alloc;
loadGroups.insert({asyncLoad.stageDiff, {}});
if (isTMALoad(loadOp)) {
loadGroups[asyncLoad.stageDiff].hasTMALoad = true;
}
}
IRRewriter builder(forOp);
builder.setInsertionPoint(forOp);
Location loc = forOp.getLoc();
Value minusOne = builder.create<arith::ConstantIntOp>(loc, -1, 32);
Value zero = builder.create<arith::ConstantIntOp>(loc, 0, 32);
Value one = builder.create<arith::ConstantIntOp>(loc, 1, 32);
SmallVector<Value> newOperands;
// Block-argument index of the first iter arg added below.
unsigned newOperandIndex = forOp.getBody()->getNumArguments();
// Initial values per group: insertIdx = -1, extractIdx = -1 and, for groups
// with a TMA load, phase = 0.
for (auto [_, loadGroup] : loadGroups) {
newOperands.push_back(minusOne);
newOperands.push_back(minusOne);
if (loadGroup.hasTMALoad) {
newOperands.push_back(zero);
}
}
forOp = addIterArgsToLoop(builder, forOp, newOperands);
// Append placeholder yield operands; they are overwritten with the updated
// counters further below.
auto forYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
for (unsigned i = 0; i < newOperands.size(); ++i) {
forYield.getResultsMutable().append(newOperands[i]);
}
builder.setInsertionPoint(forOp);
loc = forOp.getLoc();
int argIdx = newOperandIndex;
// Compute the per-iteration updates of the counters at the top of the body.
for (auto &[numBuffers, loadGroup] : loadGroups) {
Value insertIdx = forOp.getBody()->getArgument(argIdx);
argIdx++;
Value extractIdx = forOp.getBody()->getArgument(argIdx);
argIdx++;
Value phase = nullptr;
if (loadGroup.hasTMALoad) {
phase = forOp.getBody()->getArgument(argIdx);
argIdx++;
}
builder.setInsertionPoint(forOp.getBody(), forOp.getBody()->begin());
Value numBuffersVal =
builder.create<arith::ConstantIntOp>(loc, numBuffers, 32);
loadGroup.insertIdx = createIncrementModulo(builder, loc, insertIdx,
numBuffersVal, zero, one);
// `cndExt` is filled in by createIncrementModulo and is used below to decide
// when the phase is flipped (presumably the wrap-around condition of
// extractIdx; confirm against createIncrementModulo).
Value cndExt = nullptr;
loadGroup.extractIdx = createIncrementModulo(
builder, loc, extractIdx, numBuffersVal, zero, one, &cndExt);
if (phase) {
Value nextPhase = builder.create<arith::XOrIOp>(loc, phase, one);
phase = builder.create<arith::SelectOp>(loc, cndExt, nextPhase, phase);
loadGroup.phase = phase;
}
}
createTMABarrierAndWait(forOp, asyncLoads, loadGroups, schedule);
// Replace every collected load with its asynchronous form.
bool hasAsyncLoads = false;
for (auto [op, asyncLoad] : asyncLoads) {
auto [insertIdx, extractIdx, phase, _] = loadGroups[asyncLoad.stageDiff];
if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
createAsyncCopy(forOp, loadOp, asyncLoad.alloc, insertIdx, extractIdx,
schedule);
hasAsyncLoads = true;
} else if (auto loadOp = dyn_cast<tt::DescriptorLoadOp>(op)) {
createTMAAsyncLoad(forOp, loadOp, asyncLoad.alloc, insertIdx, extractIdx,
asyncLoad.barrier, asyncLoad.waitOp, schedule);
} else if (auto loadOp = dyn_cast<tt::DescriptorGatherOp>(op)) {
createTMAAsyncGather(forOp, loadOp, asyncLoad.alloc, insertIdx,
extractIdx, asyncLoad.barrier, asyncLoad.waitOp,
schedule);
}
}
// Patch the yield with the updated counters. The block arguments include the
// induction variable while the yield operands do not, hence the -1.
argIdx = newOperandIndex - 1;
for (auto &[numBuffers, loadGroup] : loadGroups) {
forYield.setOperand(argIdx++, loadGroup.insertIdx);
forYield.setOperand(argIdx++, loadGroup.extractIdx);
if (loadGroup.phase)
forYield.setOperand(argIdx++, loadGroup.phase);
}
scheduleDependencies(forOp, schedule);
// When at least one tt.load was turned into an async copy, an async wait with
// no tokens and num = 0 is emitted after the loop.
if (hasAsyncLoads) {
builder.setInsertionPointAfter(forOp);
builder.create<ttg::AsyncWaitOp>(loc, ValueRange({}), 0);
}
// Sanity check: every op of the loop body must have a schedule entry.
for (Operation &op : forOp.getBody()->without_terminator()) {
if (!schedule.count(&op)) {
op.emitError() << "op not found in the schedule";
}
assert(schedule.count(&op) && "op not found in the schedule");
}
return forOp;
}
/// Returns the pair (first, last) of top-level ops of the `forOp` body that
/// use `alloc`, where "first"/"last" are determined with
/// `schedule.isOpBefore`. Users whose parent op is not inside `forOp` are
/// ignored. Both entries are null when there is no such user.
std::pair<Operation *, Operation *>
getTmemUseStageBoundOps(ttng::TMEMAllocOp alloc, scf::ForOp forOp,
                        CoarseSchedule &schedule) {
  Operation *earliest = nullptr;
  Operation *latest = nullptr;
  for (Operation *user : alloc->getUsers()) {
    if (!forOp->isAncestor(user->getParentOp()))
      continue;

    Operation *topLevel = forOp.getBody()->findAncestorOpInBlock(*user);
    // Seed both bounds with the first qualifying user.
    if (!earliest)
      earliest = topLevel;
    if (!latest)
      latest = topLevel;

    if (schedule.isOpBefore(topLevel, earliest))
      earliest = topLevel;
    if (schedule.isOpBefore(latest, topLevel))
      latest = topLevel;
  }
  return {earliest, latest};
}
/// If `op` is a TMEM or local allocation located inside `forOp`, replaces it
/// with a mutable allocation created through a builder anchored on `forOp`,
/// plus a store of the original source operand emitted right after the
/// original op.
///
/// \returns the newly created store operation, or nullptr when `op` is not an
/// allocation of the supported kinds or is not inside `forOp` (nothing is
/// changed in that case).
Operation *hoistBufferOutOfLoop(scf::ForOp forOp, Operation *op,
                                CoarseSchedule &schedule) {
  if (!isa<ttng::TMEMAllocOp, ttg::LocalAllocOp>(op))
    return nullptr;
  if (!forOp->isAncestor(op))
    return nullptr;

  OpBuilderForStage builder(op->getLoc(), forOp, schedule);
  // The result type is dereferenced unconditionally below, so use the checked
  // cast<> rather than dyn_cast<>, whose result would have to be tested.
  auto allocType = cast<MemDescType>(op->getResult(0).getType());
  // Same memdesc type as the original, but flagged as mutable memory.
  auto newType = triton::gpu::MemDescType::get(
      allocType.getShape(), allocType.getElementType(), allocType.getEncoding(),
      allocType.getMemorySpace(),
      /*mutableMemory=*/true);
  auto newAlloc = builder.clone(*op);
  newAlloc->getResult(0).setType(newType);

  // The store takes the place of the original allocation in the schedule.
  builder.setStageCluster(schedule[op]);
  Operation *newStore = nullptr;
  if (auto tmemAlloc = dyn_cast<ttng::TMEMAllocOp>(newAlloc)) {
    // The clone must not carry the source operand; the value is written by an
    // explicit store instead.
    tmemAlloc.getSrcMutable().clear();
    builder.setInsertionPointAfter(op);
    Value trueVal = builder.create<arith::ConstantIntOp>(1, 1);
    newStore = builder.create<ttng::TMEMStoreOp>(tmemAlloc.getResult(),
                                                 op->getOperand(0), trueVal);
  } else {
    auto localAlloc = cast<ttg::LocalAllocOp>(newAlloc);
    localAlloc.getSrcMutable().clear();
    builder.setInsertionPointAfter(op);
    newStore = builder.create<ttg::LocalStoreOp>(op->getOperand(0),
                                                 localAlloc.getResult());
  }
  replaceUsesAndPropagateType(builder, op, newAlloc->getResult(0));
  op->erase();
  return newStore;
}
// Makes `mma` asynchronous: attaches a completion barrier to it and inserts
// the wait_barrier ops that guard the accumulator `alloc`.
//
// `phaseArgIdx` and `barrierIdxArgIdx` are the indices of the region iter args
// of `forOp` that carry the barrier phase and the barrier slot index; their
// yielded values are updated at the end of this function.
void createBarrierAndWaitOps(scf::ForOp forOp, CoarseSchedule &schedule,
ttng::MMAv5OpInterface mma, int mmaSelfLatency,
ttng::TMEMAllocOp alloc, int phaseArgIdx,
int barrierIdxArgIdx) {
// An op counts as "to be pipelined" when it is scheduled in an earlier stage
// than the mma.
auto isLoadToBePipelined = [&](Operation *op) {
return schedule[mma].first > schedule[op].first;
};
// Pick, among the tmem loads of the accumulator that live in the same block
// as the mma, the one that comes first according to `schedule.isOpBefore`.
std::optional<Operation *> latestSyncPoint;
for (auto user : alloc->getUsers()) {
if (auto load = dyn_cast<ttng::TMEMLoadOp>(user)) {
if (load->getBlock() != mma->getBlock()) {
continue;
}
if (!latestSyncPoint || schedule.isOpBefore(load, *latestSyncPoint)) {
latestSyncPoint = load;
}
}
}
ttng::MMAv5PipelineableOperandsHelper mmaPipeHelper(mma, forOp,
isLoadToBePipelined);
// Try to hoist the buffers of operands that cannot be pipelined; when a
// buffer was hoisted, the store that replaces it is tracked instead of the
// original definition.
SmallVector<Operation *> updatedDefs;
for (auto def : mmaPipeHelper.unpipelineableOperandDefs) {
auto newStore = hoistBufferOutOfLoop(forOp, def, schedule);
if (newStore) {
updatedDefs.push_back(newStore);
} else {
updatedDefs.push_back(def);
}
}
// When the operands are known not to be pipelineable, their definitions are
// also candidates for the sync point.
if (!mmaPipeHelper.isPipelineable &&
mmaPipeHelper.isOperandsStateDetermined) {
for (auto def : updatedDefs) {
if (!latestSyncPoint || schedule.isOpBefore(def, *latestSyncPoint)) {
latestSyncPoint = def;
}
}
}
// Default placement of the main wait: `mmaSelfLatency` stages after the mma,
// in the mma's cluster.
int mainWaitStage = schedule[mma].first + mmaSelfLatency;
CoarseSchedule::Cluster mainWaitCluster = schedule[mma].second;
if (latestSyncPoint && mmaPipeHelper.isOperandsStateDetermined) {
// With a sync point, the wait goes into a new cluster created right before
// the sync point: one stage after the mma when the sync point precedes the
// mma in the schedule, otherwise in the stage of the sync point.
if (schedule.isOpBefore(*latestSyncPoint, mma)) {
mainWaitStage = schedule[mma].first + 1;
mainWaitCluster = schedule.clusters.newBefore(
schedule.splitClusterBefore(*latestSyncPoint, forOp));
} else {
mainWaitStage = schedule[*latestSyncPoint].first;
mainWaitCluster = schedule.clusters.newBefore(
schedule.splitClusterBefore(*latestSyncPoint, forOp));
}
}
// Number of barrier slots: stage distance between the mma and the main wait,
// plus one.
int numStages = mainWaitStage - schedule[mma].first + 1;
OpBuilderForStage builder(mma.getLoc(), mma, schedule);
Value barrierAlloc = createBarrierAlloc(forOp, numStages);
Value vTrue = builder.create<arith::ConstantIntOp>(1, 1);
Value phase = forOp.getRegionIterArg(phaseArgIdx);
Value barrierIdx = forOp.getRegionIterArg(barrierIdxArgIdx);
Value zero = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 0, 32);
Value one = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 1, 32);
Value numStagesVal =
builder.create<arith::ConstantIntOp>(forOp.getLoc(), numStages, 32);
// With a single slot the allocation is used directly; otherwise the slot is
// selected with the loop-carried barrier index.
Value barrierSlice = barrierAlloc;
if (numStages > 1) {
barrierSlice =
triton::createSingleBufferView(builder, barrierAlloc, barrierIdx);
}
mma.addCompletionBarrier(barrierSlice, vTrue);
mma.setIsAsync(true);
// Operands of the mma passed along to the wait ops: A, B and, for scaled
// mma ops, the two scale operands.
SmallVector<Value> waitBuffers;
auto mmaAsDotOp = cast<DotOpInterface>(mma.getOperation());
waitBuffers.push_back(mmaAsDotOp.getA());
waitBuffers.push_back(mmaAsDotOp.getB());
if (auto mmaAsScaledDotOp =
dyn_cast<ttng::TCGen5MMAScaledOp>(mma.getOperation())) {
waitBuffers.push_back(mmaAsScaledDotOp.getAScale());
waitBuffers.push_back(mmaAsScaledDotOp.getBScale());
}
// Main wait, placed after the mma with the stage/cluster computed above.
builder.setInsertionPointAfter(mma);
builder.setStageCluster({mainWaitStage, mainWaitCluster});
builder.create<ttng::WaitBarrierOp>(barrierSlice, phase, waitBuffers);
// Additional waits in front of tmem loads that live in a different block
// than the mma and whose top-level ancestor is scheduled in an earlier stage
// than the main wait.
for (auto user : alloc->getUsers()) {
if (auto load = dyn_cast<ttng::TMEMLoadOp>(user)) {
if (load->getBlock() == mma->getBlock()) {
continue;
}
auto topLevelUser = forOp.getBody()->findAncestorOpInBlock(*load);
// Loads outside of the loop body have no ancestor in it and are skipped.
if (!topLevelUser) {
continue;
}
auto [loadStage, loadCluster] = schedule[topLevelUser];
if (loadStage < mainWaitStage) {
builder.setStageCluster({loadStage, loadCluster});
builder.setInsertionPoint(load);
builder.create<ttng::WaitBarrierOp>(barrierSlice, phase, waitBuffers);
}
}
}
// Update phase and barrier index right before the yield, in the stage/cluster
// of the mma.
builder.setStageCluster(schedule[mma]);
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
builder.setInsertionPoint(yieldOp);
Value newPhase = builder.create<arith::XOrIOp>(phase, one);
Value newBarrierIdx = barrierIdx;
if (numStages > 1) {
// With several slots the phase is only replaced when `barWrap` is set by
// createIncrementModulo (presumably when the index wraps; confirm against
// createIncrementModulo).
Value barWrap;
newBarrierIdx = createIncrementModulo(builder, builder.getLoc(), barrierIdx,
numStagesVal, zero, one, &barWrap);
newPhase = builder.create<arith::SelectOp>(phase.getType(), barWrap,
newPhase, phase);
}
// The yield currently forwards the iter args themselves; swap in the updated
// values.
yieldOp->replaceUsesOfWith(phase, newPhase);
yieldOp->replaceUsesOfWith(barrierIdx, newBarrierIdx);
}
/// Replaces the tensor-memory allocation `alloc` with a multi-buffered one
/// holding `tmemUseNumStages` buffers and rewrites all of its users (tmem
/// stores, tmem loads and MMAv5 ops) to access a single-buffer view selected
/// by a buffer index.
///
/// The buffer index starts from the region iter arg `bufIdxArgIdx` of `forOp`
/// and is advanced (modulo `tmemUseNumStages`) at every in-loop store and at
/// every MMA; MMAs keep the current index when they use the accumulator.
///
/// Aborts with a fatal error when a user of an unsupported kind is found, or
/// when there is neither an in-loop store nor an MMA whose use-accumulator
/// flag is something other than a constant true.
void multibufferTensorMemory(scf::ForOp forOp, CoarseSchedule &schedule,
                             ttng::TMEMAllocOp alloc, int bufIdxArgIdx,
                             int tmemUseNumStages) {
  DominanceInfo domInfo(forOp);
  Value bufIdx = forOp.getRegionIterArg(bufIdxArgIdx);
  // (op, value) pairs: the buffer index `value` is in effect for operations
  // properly dominated by `op`.
  SmallVector<std::pair<Operation *, Value>> bufIdxDefs;
  // Returns the most recently recorded buffer index whose defining point
  // properly dominates `op`, or a null Value when there is none.
  auto getCurrBufIdx = [&](Operation *op) {
    for (auto [_op, _val] : llvm::reverse(bufIdxDefs)) {
      if (domInfo.properlyDominates(_op, op)) {
        return _val;
      }
    }
    return Value();
  };
  bufIdxDefs.push_back({&forOp.getBody()->front(), bufIdx});

  OpBuilderForStage builder(alloc.getLoc(), alloc, schedule);
  auto newAlloc = createTMemAlloc(builder, alloc, true, tmemUseNumStages);
  Value numStagesVal =
      builder.create<arith::ConstantIntOp>(tmemUseNumStages, 32);
  Value zero = builder.create<arith::ConstantIntOp>(0, 32);
  Value one = builder.create<arith::ConstantIntOp>(1, 32);

  // Set once a point is found where the buffer index may legally change.
  bool multibufferingIsValid = false;

  // Snapshot the users: the loop below rewrites operands of these ops.
  SmallVector<Operation *> allocUsers =
      llvm::to_vector(alloc.getResult().getUsers());
  // Tokens produced by the rewritten ops are replaced by this poison value.
  Value replTok = OpBuilder(forOp).create<ub::PoisonOp>(
      forOp.getLoc(), builder.getType<AsyncTokenType>());
  if (newAlloc.getToken()) {
    newAlloc.getToken().replaceAllUsesWith(replTok);
  }
  for (auto user : allocUsers) {
    if (auto store = dyn_cast<ttng::TMEMStoreOp>(user)) {
      store.getDepMutable().clear();
      store.getToken().replaceAllUsesWith(replTok);
      if (forOp->isAncestor(store)) {
        // A store inside the loop is a point where the buffer can be switched.
        multibufferingIsValid = true;
        builder.setStageCluster(schedule[store]);
        builder.setInsertionPoint(store);
        Value curBufIdx = getCurrBufIdx(store);
        Value newBufIdx = createIncrementModulo(
            builder, forOp.getLoc(), curBufIdx, numStagesVal, zero, one);
        // A predicated store only advances the index when the predicate holds.
        if (Value pred = store.getPred()) {
          newBufIdx = builder.create<arith::SelectOp>(newBufIdx.getType(), pred,
                                                      newBufIdx, curBufIdx);
        }
        replaceAllUsesDominatedBy(store, newBufIdx, curBufIdx, domInfo);
        bufIdxDefs.push_back({store, newBufIdx});
        auto tmemSlice =
            triton::createSingleBufferView(builder, newAlloc, newBufIdx);
        store.getDstMutable().assign(tmemSlice);
      } else {
        // Store outside of the loop: it writes buffer 0.
        assert(store->isBeforeInBlock(forOp) && "Store is not before the loop");
        builder.setInsertionPoint(store);
        auto tmemSlice =
            triton::createSingleBufferView(builder, newAlloc, zero);
        store.getDstMutable().assign(tmemSlice);
      }
    } else if (auto load = dyn_cast<ttng::TMEMLoadOp>(user)) {
      load.getDepMutable().clear();
      load.getToken().replaceAllUsesWith(replTok);
      if (forOp->isAncestor(load)) {
        // In-loop load: reads the buffer that is current at the load.
        builder.setStageCluster(schedule[load]);
        builder.setInsertionPoint(load);
        Value curBufIdx = getCurrBufIdx(load);
        auto tmemSlice =
            triton::createSingleBufferView(builder, newAlloc, curBufIdx);
        load.getSrcMutable().assign(tmemSlice);
      } else {
        // Load outside of the loop: reads the buffer selected by the loop
        // result at position `bufIdxArgIdx`.
        assert(forOp->isBeforeInBlock(load) && "Load is not after the loop");
        builder.setInsertionPoint(load);
        auto tmemSlice = triton::createSingleBufferView(
            builder, newAlloc, forOp->getResult(bufIdxArgIdx));
        load.getSrcMutable().assign(tmemSlice);
      }
    } else if (auto mma = dyn_cast<ttng::MMAv5OpInterface>(user)) {
      mma.getAccDepMutable().clear();
      mma.getToken().replaceAllUsesWith(replTok);
      builder.setStageCluster(schedule[mma]);
      builder.setInsertionPoint(mma);
      auto isConstTrue = [](Value v) {
        if (auto constOp = v.getDefiningOp<arith::ConstantOp>()) {
          if (auto attr = dyn_cast<BoolAttr>(constOp.getValueAttr())) {
            return attr.getValue();
          }
        }
        return false;
      };
      // An MMA that does not always use the accumulator is also a legal
      // switching point. The flag is only ever raised here, never cleared:
      // users are visited in use-list order, and an in-loop store seen earlier
      // must not be invalidated by an MMA with a constant-true accumulator
      // flag.
      if (!isConstTrue(mma.useAccumulator())) {
        multibufferingIsValid = true;
      }
      Value curBufIdx = getCurrBufIdx(mma.getOperation());
      Value newBufIdx = createIncrementModulo(
          builder, forOp.getLoc(), curBufIdx, numStagesVal, zero, one);
      // Stay on the current buffer when the accumulator is used.
      newBufIdx = builder.create<arith::SelectOp>(
          newBufIdx.getType(), mma.useAccumulator(), curBufIdx, newBufIdx);
      replaceAllUsesDominatedBy(mma.getOperation(), newBufIdx, curBufIdx,
                                domInfo);
      bufIdxDefs.push_back({mma.getOperation(), newBufIdx});
      auto tmemSlice =
          triton::createSingleBufferView(builder, newAlloc, newBufIdx);
      mma.setAccumulator(tmemSlice);
    } else {
      llvm::errs() << "Unsupported user of the accumulator: " << *user << "\n";
      llvm::report_fatal_error("Unsupported user of the accumulator");
    }
  }
  if (!multibufferingIsValid) {
    llvm::report_fatal_error(
        "Trying to multibuffer TMEM while there is no store to the "
        "accumulator, and the mma uses the accumulator all the time.");
  }
  alloc.getToken().replaceAllUsesWith(newAlloc.getToken());
  alloc->erase();
  // Finally route the last recorded buffer index to the remaining uses of the
  // iter arg that are dominated by its definition.
  Value newBufIdx = bufIdxDefs.back().second;
  replaceAllUsesDominatedBy(newBufIdx.getDefiningOp(), newBufIdx, bufIdx,
                            domInfo);
}
/// Lowers a single MMAv5 op for pipelining: adds loop-carried barrier state to
/// `forOp`, makes the MMA asynchronous with the matching waits and, when the
/// accumulator is live across more than one stage, multi-buffers the tensor
/// memory backing it.
///
/// Nothing is changed (and `forOp` is returned as is) when the accumulator is
/// not defined by a TMEMAllocOp or when the MMA carries no non-zero
/// self-latency attribute. Note that the self-latency attribute is removed
/// from the MMA when it is read.
///
/// \returns the loop that replaces `forOp` (the original loop is erased).
scf::ForOp lowerMMA(ttng::MMAv5OpInterface mma, scf::ForOp forOp,
                    CoarseSchedule &schedule) {
  auto alloc = mma.getAccumulator().getDefiningOp<ttng::TMEMAllocOp>();
  if (!alloc) {
    return forOp;
  }
  int mmaSelfLatency = getSelfLatencyFromAttr(mma.getOperation());
  if (mmaSelfLatency == 0) {
    return forOp;
  }

  // Number of stages spanned by the uses of the accumulator inside the loop.
  std::pair<Operation *, Operation *> tmemUseStageBoundOps =
      getTmemUseStageBoundOps(alloc, forOp, schedule);
  int tmemUseNumStages = schedule[tmemUseStageBoundOps.second].first -
                         schedule[tmemUseStageBoundOps.first].first;
  // One more buffer is needed when the first use is in an earlier cluster than
  // the last one, or in the same cluster but earlier in the block.
  if (schedule.isOpInEarlierCluster(tmemUseStageBoundOps.first,
                                    tmemUseStageBoundOps.second) ||
      (schedule.isOpInSameCluster(tmemUseStageBoundOps.first,
                                  tmemUseStageBoundOps.second) &&
       tmemUseStageBoundOps.first->isBeforeInBlock(
           tmemUseStageBoundOps.second))) {
    tmemUseNumStages += 1;
  }

  OpBuilder builder(forOp);
  Value minusOne = builder.create<arith::ConstantIntOp>(forOp.getLoc(), -1, 32);
  Value zero = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 0, 32);

  // New iter args, in order: phase, barrierIdx and, only when the tensor
  // memory is multi-buffered, bufIdx.
  unsigned newOperandIndex = forOp.getInitArgs().size();
  SmallVector<Value> newOperands = {
      zero, // phase
      zero, // barrierIdx
  };
  if (tmemUseNumStages > 1) {
    newOperands.push_back(minusOne); // bufIdx
  }
  scf::ForOp newForOp =
      replaceForOpWithNewSignature(builder, forOp, newOperands);
  forOp.erase();
  forOp = newForOp;

  int phaseArgIdx = newOperandIndex + 0;
  int barrierIdxArgIdx = newOperandIndex + 1;
  int bufIdxArgIdx = newOperandIndex + 2;

  // Initially yield the iter args unchanged; the callees below replace them
  // with the updated values.
  Value phase = forOp.getRegionIterArg(phaseArgIdx);
  Value barrierIdx = forOp.getRegionIterArg(barrierIdxArgIdx);
  SmallVector<Value> newYieldOperands = {phase, barrierIdx};
  if (tmemUseNumStages > 1) {
    Value bufIdx = forOp.getRegionIterArg(bufIdxArgIdx);
    newYieldOperands.push_back(bufIdx);
  }
  cast<scf::YieldOp>(forOp.getBody()->getTerminator())
      .getResultsMutable()
      .append(newYieldOperands);

  createBarrierAndWaitOps(forOp, schedule, mma, mmaSelfLatency, alloc,
                          phaseArgIdx, barrierIdxArgIdx);
  if (tmemUseNumStages > 1) {
    multibufferTensorMemory(forOp, schedule, alloc, bufIdxArgIdx,
                            tmemUseNumStages);
  }
  return forOp;
}
/// Lowers every MMAv5 op found by walking `forOp`, threading the (possibly
/// replaced) loop from one lowering to the next. Returns the final loop.
scf::ForOp lowerMMAs(scf::ForOp forOp, CoarseSchedule &schedule) {
  // Collect first: lowering an MMA replaces the loop.
  SmallVector<ttng::MMAv5OpInterface> pendingMMAs;
  forOp.walk([&pendingMMAs](ttng::MMAv5OpInterface mmaOp) {
    pendingMMAs.push_back(mmaOp);
  });

  scf::ForOp currentLoop = forOp;
  for (ttng::MMAv5OpInterface mmaOp : pendingMMAs)
    currentLoop = lowerMMA(mmaOp, currentLoop, schedule);
  return currentLoop;
}
/// Lowers one loop: MMAs first, then loads, then TMA descriptors. The
/// schedule is deserialized from `forOp` and serialized back onto the
/// resulting loop. Loops whose schedule cannot be deserialized are skipped.
void lowerLoop(scf::ForOp forOp,
               triton::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
  CoarseSchedule schedule;
  if (failed(schedule.deSerialize(forOp)))
    return;

  scf::ForOp afterMMAs = lowerMMAs(forOp, schedule);
  scf::ForOp afterLoads = lowerLoads(afterMMAs, schedule, axisInfoAnalysis);
  scf::ForOp loweredLoop = lowerTMADescriptors(afterLoads, schedule);
  schedule.serialize(loweredLoop);
}
}
/// Entry point: lowers every scf.for found by walking `moduleOp`.
void lowerLoops(ModuleOp moduleOp) {
  triton::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp);

  // Collect the loops up front, then lower them one at a time.
  SmallVector<scf::ForOp> worklist;
  moduleOp->walk([&worklist](scf::ForOp loop) { worklist.push_back(loop); });

  for (scf::ForOp loop : worklist)
    lowerLoop(loop, axisInfoAnalysis);
}
}
}
}