#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Schedule.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include <queue>
#define DEBUG_TYPE "triton-loop-pipeline"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
using namespace mlir;
namespace tt = mlir::triton;
namespace ttg = mlir::triton::gpu;
namespace ttng = mlir::triton::nvidia_gpu;
// Returns true when `op` is pure and every operand and result type is an
// integer, index, or floating-point scalar type.
bool triton::isPureScalarOp(Operation *op) {
  auto isScalarType = [](Type ty) { return ty.isIntOrIndexOrFloat(); };
  if (!isPure(op))
    return false;
  if (!llvm::all_of(op->getOperandTypes(), isScalarType))
    return false;
  return llvm::all_of(op->getResultTypes(), isScalarType);
}
// Determines whether every value in `valueSet` can be made to properly
// dominate `refOp` by hoisting its (transitive) defining ops.
//
// Walks the use-def chains breadth-first. Values that already properly
// dominate `refOp` end the walk along that chain. A block argument that does
// not dominate, or a defining op rejected by `canHoist`, makes the whole query
// fail.
//
// Returns true on success and appends the discovered ops to `toHoist`.
// Returns false otherwise, and `toHoist` is left untouched (ops are only
// committed once the entire walk succeeds).
bool triton::getDominatingValueSetOpsToHoist(
    DominanceInfo &domInfo, Operation *refOp, ArrayRef<Value> valueSet,
    llvm::SetVector<Operation *> &toHoist,
    function_ref<bool(Operation *)> canHoist) {
  // Ops found so far that would need to move; doubles as the visited set.
  llvm::SetVector<Operation *> pending;
  std::queue<Value> worklist;
  for (Value v : valueSet)
    worklist.push(v);

  while (!worklist.empty()) {
    Value current = worklist.front();
    worklist.pop();

    // Already available at `refOp`: nothing to hoist along this chain.
    if (domInfo.properlyDominates(current, refOp))
      continue;
    // A non-dominating block argument has no defining op that could move.
    if (isa<BlockArgument>(current))
      return false;

    Operation *def = current.getDefiningOp();
    if (pending.contains(def))
      continue;
    if (!canHoist(def))
      return false;
    pending.insert(def);
    for (Value operand : def->getOperands())
      worklist.push(operand);
  }

  toHoist.insert(pending.begin(), pending.end());
  return true;
}
// Moves every op in `toHoist` immediately before `refOp`. This forwards to the
// (block, iterator) overload, anchored at `refOp`'s position.
void triton::hoistOpsBefore(Operation *refOp,
                            const llvm::SetVector<Operation *> &toHoist) {
  Block *parentBlock = refOp->getBlock();
  Block::iterator anchor = refOp->getIterator();
  hoistOpsBefore(parentBlock, anchor, toHoist);
}
// Moves every op in `toHoist` before position `it` of `block`. Ops are moved
// in the order produced by `topologicalSort`, so each one is placed after
// those moved before it.
void triton::hoistOpsBefore(Block *block, Block::iterator it,
                            const llvm::SetVector<Operation *> &toHoist) {
  auto ordered = topologicalSort(toHoist);
  for (Operation *toMove : ordered)
    toMove->moveBefore(block, it);
}
// Propagates a redefinition `out` of value `in` from the nested block `block`
// outward, one enclosing region at a time, until reaching the block that
// defines `in` (`in.getParentBlock()`). Returns the value that carries the
// redefinition in that block.
//
// For each enclosing op:
//   - scf.for: the loop is rebuilt via addIterArgsToLoop with `in` as the
//     extra initial value, `out` is appended to its yield, and the loop's new
//     last result becomes the new `out`.
//   - scf.if: the if is rebuilt with one extra result of `out`'s type. The
//     branch containing `block` yields `out`, the other branch yields `in`,
//     the new last result becomes the new `out`, and the old if is erased.
//   - anything else: report_fatal_error.
//
// NOTE(review): assumes `block` is nested (transitively) under
// `in.getParentBlock()`; otherwise the walk runs off the top of the IR.
// The rewriter's insertion point is restored on return (InsertionGuard).
Value triton::sinkValueRedefinition(RewriterBase &rewriter, Value in, Value out,
                                    Block *block) {
  OpBuilder::InsertionGuard guard(rewriter);
  for (; block != in.getParentBlock();
       block = block->getParentOp()->getBlock()) {
    Operation *op = block->getParentOp();
    // New ops (rebuilt loop / if) are created in place of the current parent.
    rewriter.setInsertionPoint(op);
    if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      forOp = addIterArgsToLoop(rewriter, forOp, in);
      appendToForOpYield(forOp, out);
      out = forOp.getResults().back();
      continue;
    }
    if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
      scf::IfOp newIfOp =
          replaceIfOpWithNewSignature(rewriter, ifOp, out.getType());
      // `taken` is the branch that contains `block`; it forwards the new
      // value, while the other branch forwards the original `in`.
      scf::YieldOp taken = newIfOp.thenYield();
      scf::YieldOp other = newIfOp.elseYield();
      if (block == newIfOp.elseBlock())
        std::swap(taken, other);
      taken->insertOperands(taken.getNumOperands(), out);
      other->insertOperands(other.getNumOperands(), in);
      out = newIfOp.getResults().back();
      rewriter.eraseOp(ifOp);
      continue;
    }
    llvm::report_fatal_error("FIXME: sinking into unhandled control flow op: " +
                             op->getName().getStringRef());
  }
  return out;
}
// Returns true if any value yielded by `forOp`'s body has no defining op,
// i.e. it is a block argument yielded as-is (for example an iter arg that is
// forwarded unchanged). Such a value carries a dependency across more than one
// iteration.
bool mlir::triton::loopHasDistGreaterThanOne(scf::ForOp forOp) {
  Operation *yield = forOp.getBody()->getTerminator();
  for (Value yielded : yield->getOperands()) {
    if (!yielded.getDefiningOp())
      return true;
  }
  return false;
}
// Returns true if `forOp`'s body directly contains an scf.for or scf.while.
// Only immediate children of the body are inspected; loops nested deeper
// (e.g. inside an scf.if) are not considered.
bool mlir::triton::isOuterLoop(scf::ForOp forOp) {
  for (Operation &nested : forOp.getBody()->getOperations()) {
    if (isa<scf::ForOp, scf::WhileOp>(nested))
      return true;
  }
  return false;
}
// Makes `op` conditional on `pred`, modifying it in place, and returns `op`
// itself. No replacement op is created. Helper ops produced by getPredMask
// (which combines an existing mask/predicate with `pred`) are inserted
// immediately before `op`.
//
// Returned untouched:
//   - memory-effect-free ops, and any op when `pred` is the constant 1;
//   - LLVM::AssumeOp, ttng::FenceAsyncSharedOp, ttg::AsyncCommitGroupOp,
//     ttg::AsyncWaitOp, ops with OpTrait::LocalLoadTrait, ttg::LocalStoreOp,
//     ttng::TMEMAllocOp, ttng::TMEMLoadOp;
//   - unregistered ops.
// Existing condition/mask/predicate replaced by getPredMask(existing, pred):
//   scf::IfOp (condition), ttg::AsyncCopyGlobalToLocalOp, tt::LoadOp,
//   ttng::AsyncTMACopyGlobalToLocalOp, ttng::AsyncTMAGatherOp,
//   ttng::BarrierExpectOp, ttng::MMAv5OpInterface ops, ttng::TMEMStoreOp,
//   tt::StoreOp, tt::AtomicRMWOp.
// ttng::WaitBarrierOp / ttng::ArriveBarrierOp: `pred` is used directly when
//   the op has no predicate yet, otherwise combined via getPredMask.
// Any other registered op with memory effects is a fatal error.
// The rewriter's insertion point is restored on return (InsertionGuard).
Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op,
                                     Value pred) {
  OpBuilder::InsertionGuard guard(rewriter);
  if (mlir::isMemoryEffectFree(op))
    return op;
  // An always-true predicate needs no masking.
  if (isConstantIntValue(pred, 1))
    return op;
  // The following op kinds are deliberately left unpredicated.
  // NOTE(review): presumably safe to execute unconditionally in pipelined
  // code; confirm per op kind.
  if (isa<LLVM::AssumeOp, ttng::FenceAsyncSharedOp>(op))
    return op;
  if (isa<ttg::AsyncCommitGroupOp, ttg::AsyncWaitOp>(op))
    return op;
  if (op->hasTrait<OpTrait::LocalLoadTrait>())
    return op;
  if (isa<ttg::LocalStoreOp>(op))
    return op;
  if (isa<ttng::TMEMAllocOp, ttng::TMEMLoadOp>(op))
    return op;
  if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
    rewriter.setInsertionPoint(op);
    Value cnd = getPredMask(rewriter, ifOp.getCondition().getType(),
                            ifOp.getCondition(), pred);
    ifOp.getConditionMutable().assign(cnd);
    return op;
  }
  if (auto asyncCopyOp = dyn_cast<ttg::AsyncCopyGlobalToLocalOp>(op)) {
    rewriter.setInsertionPoint(asyncCopyOp);
    Value mask = getPredMask(rewriter, asyncCopyOp.getSrc().getType(),
                             asyncCopyOp.getMask(), pred);
    asyncCopyOp.getMaskMutable().assign(mask);
    return op;
  }
  if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
    rewriter.setInsertionPoint(loadOp);
    Value mask = getPredMask(rewriter, loadOp.getPtr().getType(),
                             loadOp.getMask(), pred);
    loadOp.getMaskMutable().assign(mask);
    return op;
  }
  if (auto copyOp = dyn_cast<ttng::AsyncTMACopyGlobalToLocalOp>(op)) {
    rewriter.setInsertionPoint(copyOp);
    Value mask = getPredMask(rewriter, copyOp.getPred().getType(),
                             copyOp.getPred(), pred);
    copyOp.getPredMutable().assign(mask);
    return op;
  }
  if (auto gatherOp = dyn_cast<ttng::AsyncTMAGatherOp>(op)) {
    rewriter.setInsertionPoint(gatherOp);
    Value mask = getPredMask(rewriter, gatherOp.getPred().getType(),
                             gatherOp.getPred(), pred);
    gatherOp.getPredMutable().assign(mask);
    return op;
  }
  if (auto expectOp = dyn_cast<ttng::BarrierExpectOp>(op)) {
    rewriter.setInsertionPoint(expectOp);
    Value mask = getPredMask(rewriter, expectOp.getPred().getType(),
                             expectOp.getPred(), pred);
    expectOp.getPredMutable().assign(mask);
    return op;
  }
  if (auto mmav5Op = dyn_cast<ttng::MMAv5OpInterface>(op)) {
    rewriter.setInsertionPoint(mmav5Op);
    auto currPred = mmav5Op.getPredicate();
    Value mask = getPredMask(rewriter, currPred.getType(), currPred, pred);
    mmav5Op.setPredicate(mask);
    return op;
  }
  if (auto tmemStoreOp = dyn_cast<ttng::TMEMStoreOp>(op)) {
    rewriter.setInsertionPoint(tmemStoreOp);
    Value mask = getPredMask(rewriter, tmemStoreOp.getPred().getType(),
                             tmemStoreOp.getPred(), pred);
    tmemStoreOp.getPredMutable().assign(mask);
    return op;
  }
  // Barrier wait/arrive may have no predicate operand yet; in that case the
  // new predicate is attached as-is.
  if (auto waitBarrier = dyn_cast<ttng::WaitBarrierOp>(op)) {
    rewriter.setInsertionPoint(waitBarrier);
    Value mask = pred;
    Value currentPred = waitBarrier.getPred();
    if (currentPred) {
      mask = getPredMask(rewriter, currentPred.getType(), currentPred, pred);
    }
    waitBarrier.getPredMutable().assign(mask);
    return op;
  }
  if (auto arriveBarrier = dyn_cast<ttng::ArriveBarrierOp>(op)) {
    rewriter.setInsertionPoint(arriveBarrier);
    Value mask = pred;
    Value currentPred = arriveBarrier.getPred();
    if (currentPred) {
      mask = getPredMask(rewriter, currentPred.getType(), currentPred, pred);
    }
    arriveBarrier.getPredMutable().assign(mask);
    return op;
  }
  if (auto storeOp = dyn_cast<tt::StoreOp>(op)) {
    rewriter.setInsertionPoint(storeOp);
    Value mask = getPredMask(rewriter, storeOp.getPtr().getType(),
                             storeOp.getMask(), pred);
    storeOp.getMaskMutable().assign(mask);
    return op;
  }
  if (auto atomicRMWOp = dyn_cast<tt::AtomicRMWOp>(op)) {
    rewriter.setInsertionPoint(atomicRMWOp);
    Value mask = getPredMask(rewriter, atomicRMWOp.getPtr().getType(),
                             atomicRMWOp.getMask(), pred);
    atomicRMWOp.getMaskMutable().assign(mask);
    return op;
  }
  // Unregistered ops are passed through untouched.
  if (!op->isRegistered()) {
    return op;
  }
  op->emitOpError("pipeliner doesn't know how to predicate this op.");
  llvm::report_fatal_error("Fatal pipeliner error");
  return op;
}
// Replaces `op` with a ttg.mask op guarded by `pred`. The mask op's single
// region holds a clone of `op` followed by a ttg.mask_return of the clone's
// results. All uses of `op` are redirected to the mask op's results, and `op`
// is erased.
//
// The mask op is created at the rewriter's current insertion point. On return,
// the insertion point is left inside the new region, after the mask_return.
Operation *mlir::triton::wrapInMaskOp(RewriterBase &rewriter, Operation *op,
                                      Value pred) {
  Location loc = op->getLoc();
  auto maskOp = rewriter.create<ttg::MaskOp>(loc, op->getResultTypes(), pred);

  Region &body = maskOp->getRegion(0);
  rewriter.createBlock(&body);
  rewriter.setInsertionPointToStart(&body.front());
  Operation *cloned = rewriter.clone(*op);
  rewriter.create<ttg::MaskReturnOp>(loc, cloned->getResults());

  op->replaceAllUsesWith(maskOp->getResults());
  rewriter.eraseOp(op);
  return maskOp;
}
// Lowers every ttg.mask op in `moduleOp` away.
//
//  1. Runs the arith dialect's canonicalization patterns over the module, so
//     statically known mask predicates fold to constants.
//  2. For mask ops in `peeledMaskOps` whose predicate is the constant 0, the
//     masked ops are deleted outright. Results they had are replaced by
//     ub.poison values created before the mask op.
//  3. For every remaining mask op, its body ops are moved in front of it and
//     predicated on the mask's predicate via triton::predicateOp. The mask's
//     results are then forwarded to the operands of its terminator, and the
//     mask op is erased.
//
// Calls llvm::report_fatal_error if canonicalization fails.
void mlir::triton::resolveMaskOp(ModuleOp moduleOp,
                                 DenseSet<ttg::MaskOp> &peeledMaskOps) {
  IRRewriter rewriter(moduleOp);

  // Fold the arithmetic that computes mask predicates where possible.
  auto arithDialect =
      moduleOp.getContext()->getLoadedDialect<arith::ArithDialect>();
  RewritePatternSet patterns(moduleOp.getContext());
  arithDialect->getCanonicalizationPatterns(patterns);
  if (mlir::applyPatternsGreedily(moduleOp, std::move(patterns)).failed())
    return llvm::report_fatal_error("Failed to canonicalize the IR");

  // Drop the bodies of statically dead peeled mask ops. The predicate does not
  // change while the body is drained, so test it once up front. The test used
  // to sit inside the drain loop, where a non-constant-0 predicate erased
  // nothing and the loop spun forever.
  for (auto maskOp : peeledMaskOps) {
    if (!isConstantIntValue(maskOp.getPred(), 0))
      continue;
    rewriter.setInsertionPoint(maskOp);
    Block *body = maskOp.getBody();
    while (&body->front() != body->getTerminator()) {
      Operation *op = &body->front();
      if (op->getNumResults() > 0) {
        SmallVector<Value> results;
        for (auto result : op->getResults()) {
          auto poisonOp = rewriter.create<mlir::ub::PoisonOp>(
              op->getLoc(), result.getType());
          results.push_back(poisonOp);
        }
        op->replaceAllUsesWith(results);
      }
      op->erase();
    }
  }

  // Collect first, then rewrite, so the walk is not invalidated by erasure.
  SmallVector<ttg::MaskOp> maskOps;
  moduleOp->walk([&](ttg::MaskOp maskOp) { maskOps.push_back(maskOp); });
  for (auto maskOp : maskOps) {
    rewriter.setInsertionPoint(maskOp);
    // Hoist each body op (all but the terminator) in front of the mask op and
    // predicate it in place.
    while (&maskOp.getBody()->front() != maskOp.getBody()->getTerminator()) {
      Operation *op = &maskOp.getBody()->front();
      rewriter.moveOpBefore(op, maskOp);
      op = triton::predicateOp(rewriter, op, maskOp.getPred());
    }
    maskOp->replaceAllUsesWith(
        maskOp.getBody()->getTerminator()->getOperands());
    maskOp->erase();
  }
}
// Returns true if `forOp` carries the kDisallowAccMultiBufferAttrName
// attribute. Only the attribute's presence is checked, not its value.
bool mlir::triton::getDisallowAccMultiBuffer(scf::ForOp forOp) {
  Operation *loop = forOp.getOperation();
  return loop->hasAttr(mlir::triton::kDisallowAccMultiBufferAttrName);
}
// Follows `value` backwards through `forOp`'s iter args to the op result that
// ultimately produces it. Returns that result together with the number of
// iter-arg hops taken (the loop-carried distance).
//
// Returns {nullptr, 0} if the chain reaches:
//   - a block argument not owned by `forOp`'s body (an implicit capture),
//   - the induction variable (block argument #0), or
//   - a value already seen (a cycle made only of iter args).
std::pair<OpResult, int64_t>
mlir::triton::getDefinitionAndDistance(scf::ForOp forOp, Value value) {
  int64_t hops = 0;
  DenseSet<Value> visited;
  Value cur = value;
  while (auto iterArg = dyn_cast<BlockArgument>(cur)) {
    bool isForeignArg = iterArg.getOwner() != forOp.getBody();
    if (isForeignArg || iterArg.getArgNumber() == 0)
      return {nullptr, 0};
    ++hops;
    // Iter arg k (k >= 1) is fed by yield operand k - 1.
    cur = forOp.getYieldedValues()[iterArg.getArgNumber() - 1];
    if (!visited.insert(cur).second)
      return {nullptr, 0};
  }
  return {cast<OpResult>(cur), hops};
}
// Like getDefinitionAndDistance, but returns the defining op of the found
// result, or nullptr if no definition was found.
std::pair<Operation *, int64_t>
mlir::triton::getDefiningOpAndDistance(scf::ForOp forOp, Value value) {
  std::pair<OpResult, int64_t> found = getDefinitionAndDistance(forOp, value);
  Operation *defOp = nullptr;
  if (found.first)
    defOp = found.first.getDefiningOp();
  return {defOp, found.second};
}
// Returns the number of bytes that can be copied as one contiguous vector
// when moving `registerTy` into shared memory laid out with `sharedEnc`.
// This is the consecutive-element count of the composed
// register->shared linear layout, times the element bit width, divided by 8.
int mlir::triton::getCopyVecBytes(RankedTensorType registerTy,
                                  ttg::SharedEncodingTrait sharedEnc) {
  auto shape = registerTy.getShape();
  auto fromRegisters =
      triton::gpu::toLinearLayout(shape, registerTy.getEncoding());
  auto fromShared = triton::gpu::toLinearLayout(shape, sharedEnc);
  // NOTE(review): presumably maps register-layout inputs to shared-layout
  // inputs (invertAndCompose semantics); confirm against LinearLayout docs.
  auto registersToShared = fromRegisters.invertAndCompose(fromShared);
  const int contiguousElems = registersToShared.getNumConsecutiveInOut();
  return contiguousElems * registerTy.getElementTypeBitWidth() / 8;
}
// Returns true if `loadOp` is a candidate for conversion into an asynchronous
// global->shared copy.
//
// Conditions:
//   - the pointer operand is a ranked tensor of pointers. Scalar loads are
//     rejected: they have no tensor layout to stage through shared memory,
//     and helpers such as canBeAsyncLoad cast the load result to
//     RankedTensorType.
//   - vec * pointee bit width is at least 32 bits (4 bytes). Here vec is the
//     pointer contiguity, capped by the mask alignment when a mask is
//     present.
//
// Block pointers (tensor-pointer loads) must already have been lowered
// (asserted).
bool mlir::triton::canBeConvertedToAsyncLoad(
    tt::LoadOp loadOp, tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
  assert(!isLoadFromTensorPtr(loadOp) &&
         "Block ptr should have been lowered before this pass.");
  auto ptr = loadOp.getPtr();
  auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
  if (!tensorTy)
    return false;

  unsigned vec = axisInfoAnalysis.getContiguity(ptr);
  if (auto mask = loadOp.getMask())
    vec = std::min<unsigned>(vec, axisInfoAnalysis.getMaskAlignment(mask));

  auto ty = cast<tt::PointerType>(tensorTy.getElementType()).getPointeeType();
  unsigned width = vec * ty.getIntOrFloatBitWidth();
  LDBG("Load " << *loadOp << " has width " << width);
  return width >= 32;
}
// Records each op's latency from `opLatency` as an i32 attribute, using the
// Triton dialect's latency attribute helper.
void mlir::triton::serializeLatencies(ModuleOp module,
                                      DenseMap<Operation *, int> &opLatency) {
  auto latencyAttr = TritonDialect::getLoaded(module)->getLatencyAttrHelper();
  Builder attrBuilder(module);
  for (const auto &entry : opLatency)
    latencyAttr.setAttr(entry.first,
                        attrBuilder.getI32IntegerAttr(entry.second));
}
// Records each op's self-latency from `opSelfLatency` as an i32 attribute,
// using the Triton dialect's self-latency attribute helper.
void mlir::triton::serializeSelfLatencies(
    ModuleOp module, DenseMap<Operation *, int> &opSelfLatency) {
  auto selfLatencyAttr =
      TritonDialect::getLoaded(module)->getSelfLatencyAttrHelper();
  Builder attrBuilder(module);
  for (const auto &entry : opSelfLatency)
    selfLatencyAttr.setAttr(entry.first,
                            attrBuilder.getI32IntegerAttr(entry.second));
}
// Collects the latency attribute of every op nested under `op` (including
// `op` itself) into a map and removes the attribute from each op it was read
// from.
DenseMap<Operation *, int> mlir::triton::deserializeLatencies(Operation *op) {
  auto latencyAttr = TritonDialect::getLoaded(op)->getLatencyAttrHelper();
  DenseMap<Operation *, int> result;
  op->walk([&](Operation *nested) {
    auto attr = latencyAttr.getAttr(nested);
    if (!attr)
      return;
    result[nested] = attr.getInt();
    latencyAttr.removeAttr(nested);
  });
  return result;
}
// Creates a mutable shared-memory local_alloc of shape [numBuffers, 1] with
// element type `type`, at the builder's insertion point.
// The encoding is a 1-D swizzled-shared encoding with parameters (1, 1, 1),
// order {0}, and a CTA layout spanning all of the module's CTAs.
Value mlir::triton::createScalarAlloc(ImplicitLocOpBuilder &rewriter, Type type,
                                      unsigned numBuffers) {
  MLIRContext *ctx = rewriter.getContext();
  auto parentModule =
      rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
  unsigned numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(parentModule);

  auto ctaLayout = ttg::CTALayoutAttr::get(ctx, {numCTAs}, {1}, {0});
  auto encoding =
      ttg::SwizzledSharedEncodingAttr::get(ctx, 1, 1, 1, {0}, ctaLayout);
  Attribute smemSpace = ttg::SharedMemorySpaceAttr::get(ctx);
  ttg::MemDescType allocTy =
      ttg::MemDescType::get({numBuffers, 1}, type, encoding, smemSpace,
                            /*mutableMemory=*/true);
  return rewriter.create<ttg::LocalAllocOp>(allocTy, Value());
}
// Allocates `numBarriers` i64 barrier slots in shared memory right before `op`
// and initializes each with `arriveCount`. Right after `op`, each slot is
// invalidated and the allocation is deallocated. Returns the allocation.
Value mlir::triton::createBarrierAlloc(Operation *op, int numBarriers,
                                       int arriveCount) {
  ImplicitLocOpBuilder b(op->getLoc(), op);
  Value barriers = createScalarAlloc(b, b.getI64Type(), numBarriers);
  for (unsigned slotIdx = 0; slotIdx < numBarriers; ++slotIdx) {
    Value slot = createSingleBufferView(b, barriers, slotIdx);
    b.create<ttng::InitBarrierOp>(slot, arriveCount);
  }

  // Teardown after `op`: invalidate every slot, then free the allocation.
  b.setInsertionPointAfter(op);
  for (unsigned slotIdx = 0; slotIdx < numBarriers; ++slotIdx) {
    Value slot = createSingleBufferView(b, barriers, slotIdx);
    b.create<ttng::InvalBarrierOp>(slot);
  }
  b.create<ttg::LocalDeallocOp>(barriers);
  return barriers;
}
// Creates a mutable shared-memory local_alloc holding `distance` copies of
// `ty` (shape [distance, ...ty.shape]) with encoding `sharedEnc`. The alloc is
// placed before `insertBefore` and a matching local_dealloc right after it.
// Returns the allocation.
Value mlir::triton::createAlloc(Operation *insertBefore, RankedTensorType ty,
                                Location loc,
                                gpu::SharedEncodingTrait sharedEnc,
                                unsigned distance) {
  SmallVector<int64_t> allocShape;
  allocShape.push_back(distance);
  allocShape.append(ty.getShape().begin(), ty.getShape().end());

  Attribute smemSpace =
      ttg::SharedMemorySpaceAttr::get(insertBefore->getContext());
  Type allocTy = ttg::MemDescType::get(allocShape, ty.getElementType(),
                                       sharedEnc, smemSpace,
                                       /*mutableMemory=*/true);

  OpBuilder builder(insertBefore);
  Value alloc = builder.create<ttg::LocalAllocOp>(loc, allocTy);
  builder.setInsertionPointAfter(insertBefore);
  builder.create<ttg::LocalDeallocOp>(insertBefore->getLoc(), alloc);
  return alloc;
}
// Returns true for the descriptor-based load ops (tt.descriptor_load and
// tt.descriptor_gather), which this file treats as TMA loads.
bool mlir::triton::isTMALoad(Operation *op) {
  if (isa<tt::DescriptorLoadOp>(op))
    return true;
  return isa<tt::DescriptorGatherOp>(op);
}
// Returns true if `op` can be issued as an asynchronous load.
// TMA loads always qualify. Otherwise `op` must be a tt.load (asserted) whose
// register->shared copy vector, under its chosen shared encoding, is at least
// 4 bytes wide.
bool mlir::triton::canBeAsyncLoad(Operation *op) {
  if (isTMALoad(op))
    return true;
  assert(isa<tt::LoadOp>(op));
  ttg::SharedEncodingTrait sharedEnc = getSharedEncoding(op);
  auto resultTy = cast<RankedTensorType>(op->getResultTypes()[0]);
  return getCopyVecBytes(resultTy, sharedEnc) >= 4;
}
// Merges runs of ttg.async_wait ops into a single async_wait.
//
// Starting from each op in `waitOps`, the block is scanned forward until an
// async_commit_group, a RegionBranchOpInterface op, or the end of the block.
// Every async_wait met on the way joins the run, including waits not listed in
// `waitOps`. For a run of two or more waits, a new async_wait is created
// immediately before the first one. It takes the concatenated dependency
// tokens of all waits in the run and the minimum `num` among them. Every wait
// in the run then has its uses replaced by the new wait and is erased.
//
// NOTE(review): `waitOps` is not updated, so entries may refer to erased ops
// after this returns.
void mlir::triton::combineRedundantWaitOps(
    llvm::SmallSetVector<ttg::AsyncWaitOp, 8> &waitOps) {
  // Maps each wait that will be removed to the merged wait replacing it.
  llvm::MapVector<ttg::AsyncWaitOp, ttg::AsyncWaitOp> toDelete;
  for (auto waitOp : waitOps) {
    // Already absorbed into an earlier run.
    if (toDelete.count(waitOp))
      continue;
    SmallVector<ttg::AsyncWaitOp> waitGroup = {waitOp};
    SmallVector<Value> depTokens = waitOp.getOperands();
    unsigned minWaitNumber = waitOp.getNum();
    Operation *next = waitOp->getNextNode();
    // A commit group or region-branching op ends the run.
    while (next &&
           !isa<ttg::AsyncCommitGroupOp, RegionBranchOpInterface>(next)) {
      if (auto nextWait = dyn_cast<ttg::AsyncWaitOp>(next)) {
        waitGroup.push_back(nextWait);
        minWaitNumber = std::min(minWaitNumber, nextWait.getNum());
        depTokens.append(nextWait.getOperands().begin(),
                         nextWait.getOperands().end());
      }
      next = next->getNextNode();
    }
    if (waitGroup.size() == 1)
      continue;
    // The merged wait goes at the position of the first wait of the run.
    OpBuilder builder(waitGroup.front());
    auto newWaitOp = builder.create<ttg::AsyncWaitOp>(waitOp.getLoc(),
                                                      depTokens, minWaitNumber);
    for (auto waitOp : waitGroup) {
      toDelete[waitOp] = newWaitOp;
    }
  }
  // Erasure is deferred so the scans above never walk over erased ops.
  for (auto waitOp : toDelete) {
    waitOp.first->replaceAllUsesWith(waitOp.second);
    waitOp.first->erase();
  }
}
// Returns the memdesc type of one buffer of the multi-buffered `allocTy`. The
// leading dimension is dropped; element type, encoding, memory space, and
// the original allocation shape are kept; mutability is set to
// `mutableMemory`.
ttg::MemDescType mlir::triton::getBufferViewType(ttg::MemDescType allocTy,
                                                 bool mutableMemory) {
  ArrayRef<int64_t> perBufferShape = allocTy.getShape().drop_front();
  return ttg::MemDescType::get(perBufferShape, allocTy.getElementType(),
                               allocTy.getEncoding(), allocTy.getMemorySpace(),
                               mutableMemory, allocTy.getAllocShape());
}
// Returns a mutable memdesc type that stacks `depth` copies of `memDescType`
// along a new leading dimension. Element type, encoding, and memory space are
// kept.
ttg::MemDescType
mlir::triton::getMultiBufferedType(ttg::MemDescType memDescType,
                                   int32_t depth) {
  ArrayRef<int64_t> innerShape = memDescType.getShape();
  SmallVector<int64_t> stackedShape;
  stackedShape.reserve(innerShape.size() + 1);
  stackedShape.push_back(depth);
  stackedShape.append(innerShape.begin(), innerShape.end());
  return ttg::MemDescType::get(stackedShape, memDescType.getElementType(),
                               memDescType.getEncoding(),
                               memDescType.getMemorySpace(),
                               /*mutableMemory=*/true);
}
// Returns a swizzled-shared encoding with parameters (1, 1, 1), reusing the
// order and CTA layout of `ty`'s encoding.
ttg::SharedEncodingTrait mlir::triton::getSharedEncoding(RankedTensorType ty) {
  MLIRContext *ctx = ty.getContext();
  auto ctaLayout = ttg::getCTALayout(ty.getEncoding());
  auto order = ttg::getOrder(ty);
  return ttg::SwizzledSharedEncodingAttr::get(ctx, 1, 1, 1, order, ctaLayout);
}
// Picks the shared-memory encoding to use when pipelining load `op`.
//
// Priority:
//   1. TMA loads: the encoding derived from the tensor descriptor.
//   2. The encoding of the first ttg.local_alloc user of `op`. A remark is
//      emitted for each further local_alloc user whose encoding differs.
//   3. The encoding returned by getSharedEncIfAllUsersAreDotEnc, if any.
//   4. A swizzled-shared encoding (1, 1, 1) with the result's order and CTA
//      layout.
// Calls report_fatal_error if `op` is a TMA load of an unrecognized kind.
ttg::SharedEncodingTrait mlir::triton::getSharedEncoding(Operation *op) {
  ttg::SharedEncodingTrait allocEnc;
  for (Operation *user : op->getUsers()) {
    auto alloc = dyn_cast<ttg::LocalAllocOp>(user);
    if (!alloc)
      continue;
    auto userEnc =
        mlir::cast<ttg::SharedEncodingTrait>(alloc.getType().getEncoding());
    if (!allocEnc)
      allocEnc = userEnc;
    else if (userEnc != allocEnc)
      op->emitRemark()
          << "Pipelining load with different use encodings. This will lead "
             "to layout conversions and performance degradation.";
  }

  auto ty = cast<RankedTensorType>(op->getResultTypes()[0]);
  auto ctaLayout = ttg::getCTALayout(ty.getEncoding());
  auto order = ttg::getOrder(ty);

  if (isTMALoad(op)) {
    TypedValue<tt::TensorDescType> desc;
    if (auto load = dyn_cast<tt::DescriptorLoadOp>(op)) {
      desc = load.getDesc();
    } else if (auto gather = dyn_cast<tt::DescriptorGatherOp>(op)) {
      desc = gather.getDesc();
    } else {
      op->emitError() << "unrecognized tma load type";
      llvm::report_fatal_error("unrecognized tma load type");
    }
    return ttng::getEncodingFromDescriptor(op, ty, desc);
  }

  if (allocEnc)
    return allocEnc;

  // NOTE(review): presumably yields an encoding only when every user is a
  // dot-operand consumer; confirm against getSharedEncIfAllUsersAreDotEnc.
  bool incompatible = false;
  ttg::SharedEncodingTrait dotEnc =
      getSharedEncIfAllUsersAreDotEnc(op->getResult(0), incompatible)
          .value_or(nullptr);
  if (dotEnc)
    return dotEnc;

  return ttg::SwizzledSharedEncodingAttr::get(ty.getContext(), 1, 1, 1, order,
                                              ctaLayout);
}
// Returns the stage count recorded on `forOp` via the Triton dialect's
// num-stages attribute, or `defaultNumStages` when the attribute is absent.
int mlir::triton::getNumStagesOrDefault(scf::ForOp forOp,
                                        int defaultNumStages) {
  auto numStagesAttr =
      TritonDialect::getLoaded(forOp)->getNumStagesAttrHelper();
  auto attr = numStagesAttr.getAttr(forOp);
  if (!attr)
    return defaultNumStages;
  return attr.getInt();
}
// Creates a ttg.memdesc_index view selecting buffer `idx` of the
// multi-buffered memdesc `alloc`. The view drops the leading dimension and
// inherits element type, encoding, memory space, mutability, and the original
// allocation shape. `alloc` must be a memdesc of rank > 1 (asserted).
TypedValue<ttg::MemDescType>
triton::createSingleBufferView(OpBuilder &builder, Value alloc, Value idx) {
  assert(isa<ttg::MemDescType>(alloc.getType()) && "Expected MemDescType");
  auto allocTy = cast<ttg::MemDescType>(alloc.getType());
  ArrayRef<int64_t> fullShape = allocTy.getShape();
  assert(fullShape.size() > 1 &&
         "Expected multi-dimensional memdesc (e.g., Nx...) for subview");
  SmallVector<int64_t> viewShape(fullShape.begin() + 1, fullShape.end());
  auto viewTy = ttg::MemDescType::get(
      viewShape, allocTy.getElementType(), allocTy.getEncoding(),
      allocTy.getMemorySpace(), allocTy.getMutableMemory(),
      allocTy.getAllocShape());
  return builder.create<ttg::MemDescIndexOp>(alloc.getLoc(), viewTy, alloc,
                                             idx);
}
// Convenience overload: materializes `idx` as a 32-bit arith constant at
// `alloc`'s location and forwards to the Value-index overload.
TypedValue<ttg::MemDescType>
triton::createSingleBufferView(OpBuilder &builder, Value alloc, int idx) {
  Location loc = alloc.getLoc();
  Value idxConst = builder.create<arith::ConstantIntOp>(loc, idx, 32);
  return createSingleBufferView(builder, alloc, idxConst);
}
// Emits `counter + 1 >= modulus ? zero : counter + 1` using a signed
// comparison, and returns the result. If `outWrapCond` is non-null, it
// receives the i1 wrap-around condition.
Value triton::createIncrementModulo(OpBuilder &builder, Location loc,
                                    Value counter, Value modulus, Value zero,
                                    Value one, Value *outWrapCond) {
  Value incremented = builder.create<arith::AddIOp>(loc, counter, one);
  Value wraps = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
                                              incremented, modulus);
  Value next = builder.create<arith::SelectOp>(loc, wraps, zero, incremented);
  if (outWrapCond != nullptr)
    *outWrapCond = wraps;
  return next;
}
// For every tt.make_tensor_desc op nested in `forOp`, creates (before
// `forOp`) a global scratch allocation large enough for `maxStage` TMA
// descriptors, aligned to TMA_ALIGN. Each make_tensor_desc op is mapped to its
// allocation in `tmaBufferMapping`, in walk order.
static void
allocTMABuffers(scf::ForOp forOp,
                llvm::MapVector<Operation *, Value> &tmaBufferMapping,
                int maxStage) {
  IRRewriter rewriter(forOp);
  const auto scratchBytes = maxStage * ttng::TMA_SIZE_BYTES;
  auto bytePtrTy = triton::getPointerType(rewriter.getI8Type());
  forOp.walk([&](tt::MakeTensorDescOp makeDesc) {
    Value scratch = rewriter.create<triton::gpu::GlobalScratchAllocOp>(
        makeDesc.getLoc(), bytePtrTy, scratchBytes, ttng::TMA_ALIGN);
    tmaBufferMapping[makeDesc.getOperation()] = scratch;
  });
}
// Returns a pointer to descriptor slot `counter` inside the scratch buffer
// `alloc`, i.e. alloc + counter * TMA_SIZE_BYTES (byte offset, i32 math).
static Value subviewTMADescriptor(OpBuilder &builder, Location loc, Value alloc,
                                  Value counter) {
  Value slotBytes =
      builder.create<arith::ConstantIntOp>(loc, ttng::TMA_SIZE_BYTES, 32);
  Value byteOffset = builder.create<arith::MulIOp>(loc, slotBytes, counter);
  Type ptrTy = alloc.getType();
  return builder.create<triton::AddPtrOp>(loc, ptrTy, alloc, byteOffset);
}
// Rewrites each tt.make_tensor_desc op in `tmaBufferMapping` so that it writes
// its descriptor into a rotating slot of its scratch buffer.
//
// The i-th entry of `tmaBufferMapping` uses the i-th counter in `tmaCounters`
// (asserted equal in size). For each entry:
//   - computes the address of slot `counter` in the scratch buffer;
//   - writes a TMA descriptor there via ttng::createTMADesc, followed by a
//     tensormap fence-proxy acquire;
//   - reinterprets the slot as the descriptor type and replaces all uses of
//     the make_tensor_desc result with it (the make_tensor_desc op itself is
//     not erased here);
//   - advances the counter modulo `numBuffers`, sinks the new value out of
//     any nested regions up to the loop body, and yields it as the counter's
//     next value.
// New ops are created with triton::OpBuilderForStage at the make_tensor_desc
// op. NOTE(review): presumably they inherit its stage/cluster in `schedule`;
// confirm.
//
// Returns failure if createTMADesc fails. Descriptors handled before the
// failing one have already been rewritten at that point.
static LogicalResult rewriteTMABufferUpdates(
    scf::ForOp forOp,
    const llvm::MapVector<Operation *, Value> &tmaBufferMapping,
    ArrayRef<BlockArgument> tmaCounters, int numBuffers, Value one, Value zero,
    triton::CoarseSchedule &schedule) {
  assert(tmaBufferMapping.size() == tmaCounters.size());
  // Created before the loop so it is available to every iteration.
  Value numBuffersVal = mlir::OpBuilder(forOp).create<arith::ConstantIntOp>(
      forOp.getLoc(), numBuffers, 32);
  for (auto [iOp, pair] : llvm::enumerate(tmaBufferMapping)) {
    auto &[op, alloc] = pair;
    auto makeDescOp = cast<tt::MakeTensorDescOp>(op);
    triton::OpBuilderForStage builder(makeDescOp.getLoc(), makeDescOp,
                                      schedule);
    BlockArgument counter = tmaCounters[iOp];
    Value nextBuf =
        subviewTMADescriptor(builder, builder.getLoc(), alloc, counter);
    if (failed(ttng::createTMADesc(nextBuf, makeDescOp, builder))) {
      return failure();
    }
    builder.create<ttng::TensormapFenceproxyAcquireOp>(nextBuf);
    Value nextDesc = builder.create<ttng::ReinterpretTensorDescOp>(
        makeDescOp.getType(), nextBuf);
    makeDescOp.getResult().replaceAllUsesWith(nextDesc);
    Value nextCounter = createIncrementModulo(
        builder, builder.getLoc(), counter, numBuffersVal, zero, one);
    // The make_tensor_desc op may sit inside nested scf.if/scf.for; thread
    // the updated counter out to the loop body.
    IRRewriter rewriter(forOp);
    nextCounter = triton::sinkValueRedefinition(rewriter, counter, nextCounter,
                                                op->getBlock());
    // Block argument 0 is the induction variable, so iter arg k maps to
    // yield operand k - 1.
    auto forYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    forYield.setOperand(counter.getArgNumber() - 1, nextCounter);
  }
  return success();
}
// Lowers the tt.make_tensor_desc ops inside `forOp` to descriptors written
// into rotating global scratch slots, and returns the (possibly rebuilt) loop.
//
// The slot count is (numStages - 1), plus one more if the loop body directly
// contains a ttng.warp_group_dot. Each make_tensor_desc op gets its own i32
// counter, added as a new loop iter arg starting at 0. The counters are first
// yielded as 0 and then rewritten by rewriteTMABufferUpdates.
// If the loop has no make_tensor_desc ops, it is returned unchanged.
scf::ForOp triton::lowerTMADescriptors(scf::ForOp forOp,
                                       CoarseSchedule &schedule) {
  int maxStage = schedule.getNumStages() - 1;
  bool hasWarpGroupDot =
      llvm::any_of(forOp.getBody()->without_terminator(), [](Operation &op) {
        return isa<ttng::WarpGroupDotOp>(op);
      });
  if (hasWarpGroupDot)
    maxStage += 1;

  llvm::MapVector<Operation *, Value> tmaBufferMapping;
  allocTMABuffers(forOp, tmaBufferMapping, maxStage);
  if (tmaBufferMapping.empty())
    return forOp;

  IRRewriter builder(forOp);
  Location loc = forOp.getLoc();
  Value zero = builder.create<arith::ConstantIntOp>(loc, 0, 32);
  Value one = builder.create<arith::ConstantIntOp>(loc, 1, 32);

  // Counters are appended after all existing block arguments.
  unsigned firstCounterArg = forOp.getBody()->getNumArguments();
  SmallVector<Value> counterInits(tmaBufferMapping.size(), zero);
  forOp = addIterArgsToLoop(builder, forOp, counterInits);
  auto tmaCounters = ArrayRef<BlockArgument>(forOp.getBody()->getArguments())
                         .slice(firstCounterArg);

  // Placeholder yields; rewriteTMABufferUpdates replaces them.
  auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
  for (Value init : counterInits)
    yieldOp.getResultsMutable().append(init);

  if (failed(rewriteTMABufferUpdates(forOp, tmaBufferMapping, tmaCounters,
                                     maxStage, one, zero, schedule)))
    llvm_unreachable("Failed to rewrite TMA ops");
  return forOp;
}
// Returns the ops directly in `forOp`'s body that (transitively) use a result
// of `op`.
//
// Uses are followed through:
//   - `forOp`'s own scf.yield: a yielded value continues as the matching
//     region iter arg, whose uses are explored;
//   - ops with OpTrait::MemDescViewTrait: the view's own uses are explored,
//     and the view itself is not counted.
// Every other user accepted by `filter` (when set) is mapped to its ancestor
// op in the loop body and recorded.
//
// Robustness:
//   - each OpOperand is processed at most once, so the walk terminates even
//     if uses cycle back through iter args (e.g. via a multi-operand view);
//   - users that are not nested in the loop body have no ancestor there and
//     are skipped, rather than inserting nullptr into the result.
DenseSet<Operation *>
triton::getTopLevelUsersInLoop(Operation *op, scf::ForOp forOp,
                               std::function<bool(Operation *)> filter) {
  DenseSet<Operation *> topLevelUsers;
  DenseSet<OpOperand *> visited;
  SmallVector<OpOperand *> q;
  for (auto &use : op->getUses())
    q.push_back(&use);
  while (!q.empty()) {
    OpOperand *use = q.pop_back_val();
    if (!visited.insert(use).second)
      continue;
    Operation *owner = use->getOwner();
    auto yieldOp = dyn_cast<scf::YieldOp>(owner);
    if (yieldOp && yieldOp->getParentOp() == forOp) {
      // Yield operand k feeds region iter arg k in the next iteration.
      for (auto &iterArgUse :
           forOp.getRegionIterArgs()[use->getOperandNumber()].getUses())
        q.push_back(&iterArgUse);
      continue;
    }
    if (owner->hasTrait<OpTrait::MemDescViewTrait>()) {
      for (auto &viewUse : owner->getUses())
        q.push_back(&viewUse);
      continue;
    }
    if (filter && !filter(owner))
      continue;
    if (Operation *topLevelUser =
            forOp.getBody()->findAncestorOpInBlock(*owner))
      topLevelUsers.insert(topLevelUser);
  }
  return topLevelUsers;
}
// Gathers the top-level in-loop users of all `ops` (see
// triton::getTopLevelUsersInLoop with `filterUse`) and returns the one that
// `shouldPrefer(candidate, current)` favors over all others, or nullptr if
// there are no users. Every candidate must be present in `schedule`
// (asserted).
static Operation *getUseOfPipelinedOp(
    ArrayRef<Operation *> ops, scf::ForOp forOp,
    triton::CoarseSchedule &schedule,
    std::function<bool(Operation *)> filterUse,
    std::function<bool(Operation *, Operation *)> shouldPrefer) {
  DenseSet<Operation *> candidates;
  for (Operation *op : ops) {
    for (Operation *user :
         triton::getTopLevelUsersInLoop(op, forOp, filterUse))
      candidates.insert(user);
  }

  Operation *best = nullptr;
  for (Operation *candidate : candidates) {
    assert(schedule.count(candidate) && "op user not found in the schedule");
    if (best == nullptr || shouldPrefer(candidate, best))
      best = candidate;
  }
  return best;
}
// Returns the earliest top-level in-loop user of `ops`. Users are ordered
// lexicographically by (stage, cluster order, position within the block).
Operation *
triton::getFirstUseOfPipelinedOp(ArrayRef<Operation *> ops, scf::ForOp forOp,
                                 triton::CoarseSchedule &schedule,
                                 std::function<bool(Operation *)> filterUse) {
  auto comesFirst = [&](Operation *candidate, Operation *current) {
    auto [candStage, candCluster] = schedule[candidate];
    auto [curStage, curCluster] = schedule[current];
    if (candStage != curStage)
      return candStage < curStage;
    // Same stage: earlier cluster wins, then earlier position in the block.
    if (schedule.clusters.isBefore(candCluster, curCluster))
      return true;
    return candCluster == curCluster && candidate->isBeforeInBlock(current);
  };
  return getUseOfPipelinedOp(ops, forOp, schedule, filterUse, comesFirst);
}
// Returns the latest top-level in-loop user of `ops`. Users are ordered
// lexicographically by (stage, cluster order, position within the block).
Operation *
triton::getLastUseOfPipelinedOp(ArrayRef<Operation *> ops, scf::ForOp forOp,
                                triton::CoarseSchedule &schedule,
                                std::function<bool(Operation *)> filterUse) {
  auto comesLast = [&](Operation *candidate, Operation *current) {
    auto [candStage, candCluster] = schedule[candidate];
    auto [curStage, curCluster] = schedule[current];
    if (candStage != curStage)
      return candStage > curStage;
    // Same stage: later cluster wins, then later position in the block.
    if (schedule.clusters.isBefore(curCluster, candCluster))
      return true;
    return candCluster == curCluster && current->isBeforeInBlock(candidate);
  };
  return getUseOfPipelinedOp(ops, forOp, schedule, filterUse, comesLast);
}