#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h"
#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"
#include "triton/Dialect/TritonGPU/Transforms/Schedule.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "triton-wgmma-pipeline"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
#define int_attr(num) builder.getI64IntegerAttr(num)
using namespace mlir;
namespace tt = mlir::triton;
namespace ttg = mlir::triton::gpu;
namespace ttng = mlir::triton::nvidia_gpu;
// Returns true when `dot` is a ttng::WarpGroupDotOp that:
//   * takes its LHS `a` as a RankedTensorType value (not a memory descriptor),
//   * produces `a` inside `forOp` (not loop-invariant), and
//   * does not get `a` from a ttg::ConvertLayoutOp whose source already has an
//     NvidiaMmaEncodingAttr.
// insertAsyncWarpGroupDotWaitInLoop uses this to give such dots a wait right
// after themselves (presumably so the LHS values are not rewritten while a
// previous iteration's dot may still read them).
static bool rsDotNeedsWait(Operation *dot, scf::ForOp forOp) {
  auto dotOp = dyn_cast<ttng::WarpGroupDotOp>(dot);
  if (!dotOp)
    return false;
  auto a = dotOp.getA();
  // Memory-descriptor LHS: this check does not apply.
  if (!isa<RankedTensorType>(a.getType()))
    return false;
  // Loop-invariant LHS is never rewritten inside the loop.
  if (forOp.isDefinedOutsideOfLoop(a))
    return false;
  // `a` may be a block argument of the loop body (no defining op). The typed
  // getDefiningOp<OpTy>() returns null in that case. The previous
  // dyn_cast<>(a.getDefiningOp()) asserted on the null pointer.
  if (auto cvt = a.getDefiningOp<ttg::ConvertLayoutOp>()) {
    return !isa<ttg::NvidiaMmaEncodingAttr>(
        cvt.getSrc().getType().getEncoding());
  }
  return true;
}
// Returns a conservative lower bound on the number of ttg::AsyncCommitGroupOp
// ops executed between the definition of `waitOp`'s single operand and
// `waitOp` itself.
//
// Through scf.for block arguments, both incoming paths are explored: the
// loop's init value and the value yielded by the previous iteration. The
// minimum over all explored paths is returned.
//
// Returns 0 (the most conservative answer) when `waitOp` does not have
// exactly one operand, or when the def chain cannot be tracked.
static int minNumInterleavedCommitOps(Operation *waitOp) {
  // Counts commit-group ops in the sibling range [op1, op2). Ops nested in
  // regions of ops within the range are not visited, which can only make the
  // count lower (i.e. more conservative). Assumes op2 is reached by walking
  // forward from op1 in the same block.
  auto countCommitsBetween = [](Operation *op1, Operation *op2) {
    int count = 0;
    for (auto op = op1; op != op2; op = op->getNextNode()) {
      if (isa<ttg::AsyncCommitGroupOp>(op))
        count++;
    }
    return count;
  };
  // Smallest commit count found on any completed path so far. It also serves
  // as the pruning bound for paths still being explored.
  int minCommitNumber = INT_MAX;
  // Depth-first walk backwards along the def chain of `val`.
  // `thisHistorySum` accumulates the commits seen on the current path.
  // `sinkOp` is where `val` is consumed on that path.
  std::function<int(Value, Operation *, int)> minOverHistories =
      [&](Value val, Operation *sinkOp, int thisHistorySum) -> int {
    // Defined by an op: the path ends here. Add the commits between the def
    // and the sink.
    if (Operation *defOp = val.getDefiningOp()) {
      thisHistorySum += countCommitsBetween(defOp->getNextNode(), sinkOp);
      minCommitNumber = std::min(minCommitNumber, thisHistorySum);
      return minCommitNumber;
    }
    if (auto arg = mlir::dyn_cast<BlockArgument>(val)) {
      Block *block = arg.getOwner();
      auto forOp = dyn_cast<scf::ForOp>(block->getParentOp());
      // Only scf.for block arguments can be followed; give up conservatively.
      if (!forOp)
        return 0;
      // Commits from the start of the loop body up to the sink lie on both
      // incoming paths.
      Operation *firstForInst = &*forOp.getBody()->begin();
      int insertsBetween = countCommitsBetween(firstForInst, sinkOp);
      thisHistorySum += insertsBetween;
      // Prune: this path cannot lower the minimum any further.
      if (thisHistorySum >= minCommitNumber)
        return minCommitNumber;
      // Path 1: the value entering the loop through its init operand. The -1
      // skips the body's first argument, which for scf.for is the induction
      // variable.
      // NOTE(review): assumes `val` is never the induction variable itself
      // (that would index getInitArgs() with -1).
      Value incomingVal = forOp.getInitArgs()[arg.getArgNumber() - 1];
      int min1 = minOverHistories(incomingVal, forOp, thisHistorySum);
      // Path 2: the value yielded by the previous iteration.
      // NOTE(review): if the loop body contains no commit ops and `prevVal`
      // leads back to this same block argument, thisHistorySum does not grow
      // between visits. The pruning check above may then never fire. Confirm
      // such loops cannot reach this code.
      Operation *yieldOp = block->getTerminator();
      Value prevVal = yieldOp->getOperand(arg.getArgNumber() - 1);
      int min2 = minOverHistories(prevVal, yieldOp, thisHistorySum);
      return std::min(std::min(min1, min2), minCommitNumber);
    }
    // Untrackable value: be conservative.
    return 0;
  };
  if (waitOp->getNumOperands() != 1)
    return 0;
  Value val = waitOp->getOperand(0);
  // Move the sink up to the ancestor of the wait that lives in the same
  // region as `val`, so the sibling walk in countCommitsBetween can reach it.
  // Commits inside the nested regions along the way are not counted.
  while (waitOp->getParentRegion() != val.getParentRegion())
    waitOp = waitOp->getParentOp();
  int minCommits = minOverHistories(val, waitOp, 0);
  return minCommits;
}
// Sets the `num` of every ttg::AsyncWaitOp in `module` to the smallest number
// of commit groups that minNumInterleavedCommitOps can prove lie between the
// awaited value and the wait. All visited waits are then passed, in walk
// order, to tt::combineRedundantWaitOps.
void mlir::triton::updateWaits(ModuleOp module) {
  llvm::SmallSetVector<ttg::AsyncWaitOp, 8> touchedWaits;
  module.walk([&touchedWaits](ttg::AsyncWaitOp asyncWait) {
    const int relaxedNum = minNumInterleavedCommitOps(asyncWait);
    asyncWait.setNum(relaxedNum);
    touchedWaits.insert(asyncWait);
  });
  tt::combineRedundantWaitOps(touchedWaits);
}
// Replaces `wait` with a new ttng::WarpGroupDotWaitOp (same location and
// pendings) whose operands are:
//   1. the old wait's operands,
//   2. `values`, and
//   3. every memory-descriptor operand of each WarpGroupDotOp reached by a
//      backward slice from `values`. The slice stops at those dots and at ops
//      outside the wait's block.
// Uses of the old wait's non-memdesc results are redirected to the matching
// new results. Uses of each newly threaded non-memdesc value located after the
// new wait in its block are rewired to the new wait's result. `wait` is
// erased afterwards.
static void threadValuesThroughWait(ttng::WarpGroupDotWaitOp wait,
                                    MutableArrayRef<Value> values) {
  // Operands only ever reach a wait through this helper, so the existing
  // operand list is expected to be duplicate-free.
  const size_t numOldOperands = wait.getNumOperands();
  SetVector<Value> waitOperands;
  waitOperands.insert(wait.getOperands().begin(), wait.getOperands().end());
  assert(waitOperands.size() == numOldOperands &&
         "Wait op has duplicate operands.");
  waitOperands.insert(values.begin(), values.end());

  // Record the dots that feed `values` from within the wait's block.
  Block *waitBlock = wait->getBlock();
  SmallVector<ttng::WarpGroupDotOp> feedingDots;
  for (Value value : values) {
    BackwardSliceOptions sliceOpts;
    sliceOpts.omitBlockArguments = true;
    sliceOpts.filter = [&](Operation *op) {
      auto dot = dyn_cast<ttng::WarpGroupDotOp>(op);
      if (!dot)
        return op->getBlock() == waitBlock;
      feedingDots.push_back(dot);
      return false;
    };
    SetVector<Operation *> unusedSlice;
    (void)getBackwardSlice(value, &unusedSlice, sliceOpts);
  }
  for (ttng::WarpGroupDotOp dot : feedingDots)
    for (Value dotOperand : dot.getOperands())
      if (isa<ttg::MemDescType>(dotOperand.getType()))
        waitOperands.insert(dotOperand);

  // The result count changes, so a fresh op is built rather than replaced.
  IRRewriter builder(wait.getContext());
  builder.setInsertionPoint(wait);
  auto replacement = builder.create<ttng::WarpGroupDotWaitOp>(
      wait.getLoc(), llvm::to_vector(waitOperands), wait.getPendings());

  for (size_t idx = 0; idx < numOldOperands; ++idx) {
    Value oldResult = wait.getResult(idx);
    if (isa<ttg::MemDescType>(oldResult.getType()))
      continue;
    oldResult.replaceAllUsesWith(replacement.getResult(idx));
  }

  Block *replacementBlock = replacement->getBlock();
  auto isAfterReplacement = [&](OpOperand &use) {
    Operation *ancestor =
        replacementBlock->findAncestorOpInBlock(*use.getOwner());
    return ancestor && replacement->isBeforeInBlock(ancestor);
  };
  for (size_t idx = numOldOperands, end = waitOperands.size(); idx < end;
       ++idx) {
    Value threaded = replacement.getOperand(idx);
    if (isa<ttg::MemDescType>(threaded.getType()))
      continue;
    threaded.replaceUsesWithIf(replacement.getResult(idx), isAfterReplacement);
  }
  wait->erase();
}
// Splits the tensor operand `lhs` along its last dimension into
// shape.back() / newK pieces of extent `newK`. Each piece is converted back
// to the encoding of `lhs`'s original type.
//
// Steps:
//   1. Reshape: the last dim is replaced by one size-2 dim per doubling step,
//      followed by `newK`.
//   2. Transpose: `newK` directly follows the leading dims, and the size-2
//      dims come in reverse order.
//   3. Split: tt::SplitOp is applied repeatedly, doubling the piece count on
//      each round.
//
// Assuming tt::SplitOp splits the trailing size-2 dim into results 0 and 1,
// the pieces come out in increasing-K order.
//
// NOTE(review): the reshape only preserves the element count when nSplits is
// a power of two. Otherwise the final size assert fires (debug builds only).
// Confirm that callers guarantee this.
SmallVector<Value> splitLhs(OpBuilder &builder,
                            TypedValue<RankedTensorType> lhs, int64_t newK) {
  auto loc = lhs.getLoc();
  auto type = lhs.getType();
  auto rank = type.getRank();
  auto shape = to_vector(type.getShape());
  // Number of newK-wide chunks along the last dimension.
  auto nSplits = shape.back() / newK;
  assert(nSplits > 1);
  // Replace the last dim by a run of size-2 dims followed by newK.
  shape.pop_back();
  for (int i = 1; i < nSplits; i *= 2) {
    shape.push_back(2);
  }
  shape.push_back(newK);
  lhs = builder.create<tt::ReshapeOp>(loc, shape, lhs);
  // Permutation: the leading rank-1 dims, then the newK dim, then the size-2
  // dims reversed. The first size-2 dim (outermost after the reshape) ends up
  // last.
  auto transOrder = to_vector(llvm::seq<int>(rank - 1));
  transOrder.push_back(shape.size() - 1);
  llvm::append_range(transOrder, llvm::reverse(llvm::seq(
                                     rank - 1, (int64_t)shape.size() - 1)));
  lhs = builder.create<tt::TransOp>(loc, lhs, transOrder);
  // Each round splits every current piece in two.
  SmallVector<Value> curr;
  SmallVector<Value> ret = {lhs};
  for (int i = 1; i < nSplits; i *= 2) {
    curr = ret;
    ret.clear();
    for (auto v : curr) {
      auto split = builder.create<tt::SplitOp>(loc, v);
      ret.push_back(split.getResult(0));
      ret.push_back(split.getResult(1));
    }
  }
  // Restore the original encoding on each piece. The conversion is asserted
  // to be a no-op.
  auto mmav3Type =
      type.clone(cast<RankedTensorType>(ret.front().getType()).getShape());
  for (auto &v : ret) {
    v = builder.create<ttg::ConvertLayoutOp>(loc, mmav3Type, v);
    assert(isNoop(v.getDefiningOp()));
  }
  assert(ret.size() == nSplits);
  return ret;
}
// Carves the memory-descriptor operand `rhs` into consecutive
// ttg::MemDescSubsliceOp views of extent `newK` along dimension rank-2. All
// other dimensions are kept whole.
// The views are returned in increasing-offset order.
SmallVector<Value> splitRhs(OpBuilder &builder,
                            TypedValue<ttg::MemDescType> rhs, int64_t newK) {
  auto srcTy = rhs.getType();
  const auto rank = srcTy.getRank();
  const auto kDim = rank - 2;
  const auto numChunks = srcTy.getShape()[kDim] / newK;

  auto chunkShape = llvm::to_vector(srcTy.getShape());
  chunkShape[kDim] = newK;
  // The `false` argument is passed for every view.
  // NOTE(review): presumably it marks the views immutable — confirm against
  // MemDescType::get.
  auto chunkTy = ttg::MemDescType::get(
      chunkShape, srcTy.getElementType(), srcTy.getEncoding(),
      srcTy.getMemorySpace(), false, srcTy.getAllocShape());

  SmallVector<Value> chunks;
  SmallVector<int32_t> offsets(rank, 0);
  for (int chunk = 0; chunk < numChunks; ++chunk) {
    offsets[kDim] = chunk * newK;
    chunks.push_back(builder.create<ttg::MemDescSubsliceOp>(
        rhs.getLoc(), chunkTy, rhs, offsets));
  }
  return chunks;
}
// Splits a WarpGroupDotOp whose LHS is a RankedTensorType value into a chain
// of dots. Each dot in the chain covers `newK` elements of K, where `newK` is
// getInstrShape()[2] of the result's NvidiaMmaEncodingAttr.
//
// Each dot in the chain takes the previous one's result as its accumulator.
// Every dot after the first uses a constant-true `useC`. The original dot's
// uses are redirected to the last dot of the chain, and the original dot is
// erased.
//
// Returns {dotOp} unchanged when the LHS is not a RankedTensorType or when K
// is not larger than newK. Otherwise returns the chain in order.
std::vector<ttng::WarpGroupDotOp> splitRSDot(ttng::WarpGroupDotOp dotOp) {
  if (!isa<RankedTensorType>(dotOp.getA().getType())) {
    return {dotOp};
  }
  auto a = cast<TypedValue<RankedTensorType>>(dotOp.getA());
  auto b = cast<TypedValue<ttg::MemDescType>>(dotOp.getB());
  auto origK = a.getType().getShape().back();
  auto newK = cast<ttg::NvidiaMmaEncodingAttr>(dotOp.getType().getEncoding())
                  .getInstrShape()[2];
  auto numSplits = origK / newK;
  if (numSplits <= 1) {
    return {dotOp};
  }
  assert(origK % newK == 0 && "origK must be divisible by newK");
  auto builder = OpBuilder(dotOp);
  auto loc = dotOp.getLoc();
  auto lhss = splitLhs(builder, a, newK);
  auto rhss = splitRhs(builder, b, newK);
  assert(lhss.size() == numSplits && "lhs must have the same number of splits");
  assert(rhss.size() == numSplits && "rhs must have the same number of splits");
  Value useC = dotOp.getUseC();
  Value C = dotOp.getC();
  // Imprecise-accumulation budget still to be distributed over the chain.
  auto numImpreciseAccLeft = dotOp.getMaxNumImpreciseAcc();
  std::vector<ttng::WarpGroupDotOp> dots;
  for (int i = 0; i < numSplits; i++) {
    // A sub-dot gets the "unlimited" sentinel 2**30 only when the remaining
    // budget strictly exceeds its own K extent. Otherwise it receives the
    // exact remaining budget.
    //
    // The previous `take == newK` test also mapped a budget exactly equal to
    // newK to 2**30. As a result, a dot with maxNumImpreciseAcc == K got
    // 2**30 on every split, and the finite budget was lost for the whole
    // chain.
    uint32_t numImpreciseAcc =
        (numImpreciseAccLeft > newK) ? (1u << 30) : numImpreciseAccLeft;
    // Deduct what this sub-dot actually consumes.
    numImpreciseAccLeft -= std::min(numImpreciseAccLeft, newK);
    auto dot = builder.create<ttng::WarpGroupDotOp>(
        loc, dotOp.getType(), lhss[i], rhss[i], C, useC,
        dotOp.getInputPrecision(), numImpreciseAcc, dotOp.getIsAsync());
    dots.push_back(dot);
    C = dot.getResult();
    // Later sub-dots always accumulate into the previous partial result.
    useC = builder.create<mlir::arith::ConstantIntOp>(loc, 1, 1);
  }
  dotOp.replaceAllUsesWith(dots.back().getResult());
  dotOp.erase();
  return dots;
}
// Applies splitRSDot to every dot in `dots` (dot -> loop-carried result
// index). Returns a new map that lists the resulting pieces in order. Each
// piece inherits the index of the dot it came from.
llvm::MapVector<Operation *, int>
splitRSDots(const llvm::MapVector<Operation *, int> &dots) {
  llvm::MapVector<Operation *, int> splitDots;
  for (const auto &entry : dots) {
    const int iterArgIdx = entry.second;
    for (ttng::WarpGroupDotOp piece :
         splitRSDot(cast<ttng::WarpGroupDotOp>(entry.first)))
      splitDots.insert({piece, iterArgIdx});
  }
  return splitDots;
}
// Decides whether `dotOp` in the body of `forOp` can stay asynchronous across
// loop iterations. Returns the index of the loop-carried value its result is
// yielded into, or std::nullopt. The checks, in order:
//   1. Every memory-descriptor operand qualifies. Looking through
//      ConvertLayout / MemDesc view ops and this loop's iter_args, it must be
//      defined outside the loop or by a ttg::MemDescIndexOp. Tensor operands
//      always pass.
//   2. Directly in the loop body, the result only flows (through isNoop ops)
//      into a single scf.yield. Uses nested in scf.if are followed through the
//      if's yield. Any other parent op rejects the dot.
//   3. The accumulator element type is f32.
//   4. The result is actually yielded by the loop.
//   5. Either the loop-carried value is only consumed as operand 2 of another
//      WarpGroupDotOp (possibly through isNoop ops), or all its users follow
//      the first warp_group_dot_wait {pendings = 0} in the loop body.
// NOTE: the second case of check 5 mutates the IR. The iter arg is threaded
// through that wait.
static std::optional<int> dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp,
                                                scf::ForOp forOp) {
  LDBG("Considering whether to make MMAv3 dot properly async: " << dotOp);
  auto checkOperand = [&](Value operand) {
    if (isa<RankedTensorType>(operand.getType())) {
      return true;
    }
    Value transitiveOperand = operand;
    // Values already stepped through. Meeting one again means the chain
    // cycles through loop-carried values without reaching a valid source.
    // The old loop spun forever on this.
    SetVector<Value> visited;
    while (isa_and_nonnull<ttg::ConvertLayoutOp, ttg::MemDescTransOp,
                           ttg::MemDescReshapeOp, ttg::MemDescSubsliceOp>(
               transitiveOperand.getDefiningOp()) ||
           isa<BlockArgument>(transitiveOperand)) {
      if (!visited.insert(transitiveOperand))
        return false;
      if (auto blockArg = dyn_cast<BlockArgument>(transitiveOperand)) {
        // A block argument of any other block has no defining op to step
        // through. Stop here and let the final check classify it; previously
        // this case never made progress and looped forever.
        if (blockArg.getOwner() != forOp.getBody())
          break;
        // Follow this loop's iter_arg to the value yielded for it.
        transitiveOperand =
            cast<scf::YieldOp>(blockArg.getOwner()->getTerminator())
                .getOperand(blockArg.getArgNumber() - 1);
        continue;
      }
      // One of the view / layout ops above: look through to its source.
      transitiveOperand = transitiveOperand.getDefiningOp()->getOperand(0);
    }
    return forOp.isDefinedOutsideOfLoop(transitiveOperand) ||
           transitiveOperand.getDefiningOp<ttg::MemDescIndexOp>();
  };
  assert(isa<ttg::NvidiaMmaEncodingAttr>(dotOp.getC().getType().getEncoding()));
  if (!checkOperand(dotOp.getA()) || !checkOperand(dotOp.getB())) {
    LDBG("Can't make dot async because shmem operands aren't multi-buffered");
    return std::nullopt;
  }
  // Walk the uses of the dot result, looking for the single loop yield.
  int iterArgIdx = -1;
  Value iterArg = nullptr;
  SmallVector<std::pair<Operation *, int>> queue;
  for (auto &use : dotOp->getUses()) {
    queue.push_back({use.getOwner(), use.getOperandNumber()});
  }
  while (!queue.empty()) {
    auto [user, argIdx] = queue.pop_back_val();
    if (user->getParentOp() == forOp) {
      // No-op ops between the dot and the yield are looked through.
      if (isNoop(user)) {
        for (auto &use : user->getResult(0).getUses()) {
          queue.push_back({use.getOwner(), use.getOperandNumber()});
        }
        continue;
      }
      if (isa<scf::YieldOp>(user)) {
        if (iterArg) {
          LDBG("Can't make dot async because dot is used by multiple ops in "
               "the loop.");
          return std::nullopt;
        }
        iterArgIdx = argIdx;
        iterArg = forOp.getRegionIterArg(argIdx);
        continue;
      }
      LDBG("Can't make dot async because dot is unconditionally used in the "
           "loop.");
      return std::nullopt;
    }
    // Uses under scf.if are allowed; a yielded value is followed to the if's
    // corresponding result.
    if (auto ifOp = dyn_cast<scf::IfOp>(user->getParentOp())) {
      if (isa<scf::YieldOp>(user)) {
        auto uses = ifOp.getResult(argIdx).getUses();
        for (auto &use : uses) {
          queue.push_back({use.getOwner(), use.getOperandNumber()});
        }
      }
    } else {
      return std::nullopt;
    }
  }
  if (!dotOp.getC().getType().getElementType().isF32()) {
    LDBG("Can't make dot async because the accumulator is not fp32");
    return std::nullopt;
  }
  // Without a yield use there is no loop-carried value to inspect and no
  // index to report. Previously `iterArg.getUses()` below dereferenced a null
  // Value in this case.
  if (!iterArg) {
    LDBG("Can't make dot async because its result is not yielded by the "
         "loop.");
    return std::nullopt;
  }
  // Case 5a: the carried value only feeds the `c` operand (index 2) of
  // WarpGroupDotOps, possibly through no-op ops.
  std::function<bool(OpOperand &)> isTransitivelyWarpGroupDot =
      [&](OpOperand &use) -> bool {
    Operation *user = use.getOwner();
    if (isa<ttng::WarpGroupDotOp>(user))
      return use.getOperandNumber() == 2;
    if (isNoop(user))
      return llvm::all_of(user->getResult(0).getUses(),
                          isTransitivelyWarpGroupDot);
    return false;
  };
  if (llvm::all_of(iterArg.getUses(), isTransitivelyWarpGroupDot))
    return iterArgIdx;
  // Case 5b: every user of the carried value (lifted to its ancestor in the
  // loop body) comes after the first pendings=0 wait in the body.
  auto waitOps = forOp.getBody()->getOps<ttng::WarpGroupDotWaitOp>();
  auto firstWaitOpIter = llvm::find_if(
      waitOps, [&](auto waitOp) { return waitOp.getPendings() == 0; });
  if (firstWaitOpIter != waitOps.end() &&
      llvm::all_of(iterArg.getUsers(), [&](Operation *user) {
        assert(forOp->isAncestor(user));
        while (user->getParentOp() != forOp) {
          user = user->getParentOp();
        }
        return (*firstWaitOpIter)->isBeforeInBlock(user);
      })) {
    LDBG("MMAv3 dot can be properly async because it follows a "
         "warp_group_dot_wait "
         "{pendings=0}.\n"
         << "  wait: " << *firstWaitOpIter << "\n"
         << "  dot: " << dotOp);
    threadValuesThroughWait(*firstWaitOpIter, {iterArg});
    return iterArgIdx;
  }
  LDBG("Can't make dot async because its result from i-1 is used by "
       "something other than another MMAv3 dot as the `c` operand.");
  return std::nullopt;
}
// Inserts WarpGroupDotWaitOps inside `forOp`'s body for the dots in
// `properlyAsyncDots` (dot -> index of the loop-carried value it feeds).
// Does nothing when the map is empty, or when the body already contains a
// WarpGroupDotWaitOp with pendings == 0.
static void insertAsyncWarpGroupDotWaitInLoop(
    scf::ForOp forOp,
    const llvm::MapVector<Operation *, int > &properlyAsyncDots) {
  if (properlyAsyncDots.empty())
    return;
  if (llvm::any_of(forOp.getBody()->getOps<ttng::WarpGroupDotWaitOp>(),
                   [](auto wait) { return wait.getPendings() == 0; })) {
    return;
  }
  for (auto asyncDot : llvm::make_first_range(properlyAsyncDots)) {
    // Dots selected by rsDotNeedsWait get one wait right after themselves.
    // Its pendings is (number of properly async dots - 1), and the dot's
    // result is threaded through it. Their remaining uses are not handled
    // individually below.
    if (rsDotNeedsWait(asyncDot, forOp)) {
      OpBuilder builder(asyncDot);
      builder.setInsertionPointAfter(asyncDot);
      auto newWait = builder.create<ttng::WarpGroupDotWaitOp>(
          asyncDot->getLoc(), ArrayRef<Value>{}, properlyAsyncDots.size() - 1);
      SmallVector<Value> waitOperands = {asyncDot->getResult(0)};
      threadValuesThroughWait(newWait, waitOperands);
      continue;
    }
    // Every use of the dot's result other than the loop yield.
    SmallVector<OpOperand *> uses;
    for (auto &use : asyncDot->getUses()) {
      if (auto yieldOp = dyn_cast<scf::YieldOp>(use.getOwner())) {
        continue;
      }
      uses.push_back(&use);
    }
    // Group those uses by the block that contains the user. Each entry holds
    // the used value (the dot's result) once per use. threadValuesThroughWait
    // collapses the duplicates when it builds its operand set.
    DenseMap<Block *, SmallVector<Value>> blockToUsers;
    for (auto use : uses) {
      auto block = use->getOwner()->getBlock();
      blockToUsers[block].push_back(use->get());
    }
    // A pendings=0 wait goes at the start of each such block. Threading the
    // result through it rewires the later uses in that block to the wait's
    // result.
    // NOTE(review): dotCanBeProperlyAsync accepts isNoop users placed directly
    // in the loop body. Such a use would put this wait at the top of the loop
    // body, ahead of the dot that defines the threaded value. Confirm that
    // case cannot reach here.
    for (auto [block, users] : blockToUsers) {
      OpBuilder builder(block, block->begin());
      auto newWait = builder.create<ttng::WarpGroupDotWaitOp>(
          asyncDot->getLoc(), ArrayRef<Value>{}, 0);
      threadValuesThroughWait(newWait, users);
    }
  }
  // Finally, place a wait right after the last dot. Its pendings is the
  // number of properly async dots, and every such dot's result is threaded
  // through it. This is skipped when the last dot already got its own wait
  // via the rsDotNeedsWait path above.
  IRRewriter builder(forOp.getContext());
  auto lastAsyncDot = properlyAsyncDots.back().first;
  if (rsDotNeedsWait(lastAsyncDot, forOp)) {
    return;
  }
  builder.setInsertionPointAfter(lastAsyncDot);
  auto wait = builder.create<ttng::WarpGroupDotWaitOp>(
      lastAsyncDot->getLoc(),
      ArrayRef<Value>{}, properlyAsyncDots.size());
  SmallVector<Value> addlWaitOperands;
  for (auto [asyncDot, iterArgIdx] : properlyAsyncDots) {
    addlWaitOperands.push_back(asyncDot->getResult(0));
  }
  threadValuesThroughWait(wait, addlWaitOperands);
}
// Marks every WarpGroupDotOp in `forOp`'s body as async and inserts the waits
// that keep the loop correct:
//   * Dots rejected by dotCanBeProperlyAsync get a pendings=0 wait right after
//     themselves.
//   * The properly async dots are optionally split (splitRSDots, only when the
//     body then holds no pendings=0 wait), then get their in-loop waits.
//   * A pendings=0 wait after the loop threads every loop result they feed.
void triton::asyncLaunchDots(scf::ForOp forOp) {
  LDBG("Original loop:\n" << *forOp);
  IRRewriter rewriter(forOp.getContext());
  Block *loopBody = forOp.getBody();

  llvm::MapVector<Operation *, int> asyncDots;
  for (ttng::WarpGroupDotOp dot : loopBody->getOps<ttng::WarpGroupDotOp>()) {
    dot.setIsAsync(true);
    std::optional<int> carriedIdx = dotCanBeProperlyAsync(dot, forOp);
    if (carriedIdx) {
      asyncDots[dot] = *carriedIdx;
      continue;
    }
    rewriter.setInsertionPointAfter(dot);
    auto syncWait = rewriter.create<ttng::WarpGroupDotWaitOp>(
        dot.getLoc(), ArrayRef<Value>{}, 0);
    SmallVector<Value> syncedValues{dot.getResult()};
    threadValuesThroughWait(syncWait, syncedValues);
  }

  if (asyncDots.empty()) {
    LDBG("No properly async dots.");
    return;
  }

  const bool hasBlockingWait =
      llvm::any_of(loopBody->getOps<ttng::WarpGroupDotWaitOp>(),
                   [](auto wait) { return wait.getPendings() == 0; });
  if (!hasBlockingWait)
    asyncDots = splitRSDots(asyncDots);

  insertAsyncWarpGroupDotWaitInLoop(forOp, asyncDots);

  SmallVector<Value> loopResults;
  for (const auto &entry : asyncDots)
    loopResults.push_back(forOp.getResult(entry.second));
  rewriter.setInsertionPointAfter(forOp);
  auto afterLoopWait = rewriter.create<ttng::WarpGroupDotWaitOp>(
      forOp.getLoc(), ArrayRef<Value>{}, 0);
  threadValuesThroughWait(afterLoopWait, loopResults);
}