#include "mlir/IR/Dominance.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"
#include "triton/Dialect/TritonGPU/Transforms/Schedule.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "triton-loop-pipeline"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
using namespace mlir;
namespace tt = mlir::triton;
namespace ttg = mlir::triton::gpu;
namespace ttng = mlir::triton::nvidia_gpu;
namespace mlir::triton::gpu {
/// Returns true when walking `forOp` visits at least one
/// `mlir::gpu::BarrierOp`.
bool hasGpuBarriers(scf::ForOp forOp) {
  // The walk is interrupted on the first barrier visited, so the
  // "was interrupted" flag is exactly the answer.
  const bool sawBarrier =
      forOp.walk([](mlir::gpu::BarrierOp) { return WalkResult::interrupt(); })
          .wasInterrupted();
  return sawBarrier;
}
/// Returns true only if none of the following rejects the loop:
/// `loopHasDistGreaterThanOne`, `isOuterLoop`, `hasGpuBarriers`.
bool isSafeToPipeline(scf::ForOp forOp) {
  // Evaluated left to right; the first check that fires short-circuits the
  // remaining ones.
  const bool rejected = loopHasDistGreaterThanOne(forOp) ||
                        isOuterLoop(forOp) || hasGpuBarriers(forOp);
  return !rejected;
}
/// For every scheduled op of the loop body that is not in the last stage,
/// inspects the operands returned by `getNestedOperands` that are block
/// arguments of the op's own block with a non-zero argument number. The op
/// defining the matching terminator operand, if it exists and is not yet
/// scheduled, is placed as follows:
///  * a `tt::LoadOp` goes to the same stage and cluster as the user op;
///  * any other op goes to the next stage, into a cluster created right
///    before the user's cluster (one such cluster per user-cluster hash).
/// In both cases `insertDepsOfOp` is then invoked for the placed op.
void scheduleDistanceOneDependencies(scf::ForOp forOp,
                                     CoarseSchedule &schedule) {
  const int lastStage = schedule.getNumStages() - 1;
  // Remembers, per user-cluster hash, the cluster that was created in front
  // of it so it is created only once.
  DenseMap<CoarseSchedule::ClusterHash, CoarseSchedule::Cluster> clusterBefore;

  for (Operation &userOp : forOp.getBody()->without_terminator()) {
    if (!schedule.count(&userOp))
      continue;
    auto [stage, cluster] = schedule[&userOp];
    // Ops already in the last stage have no later stage to push into.
    if (stage == lastStage)
      continue;

    for (Value operand : getNestedOperands(&userOp)) {
      auto blockArg = dyn_cast<BlockArgument>(operand);
      if (!blockArg)
        continue;
      // Argument 0 is skipped, as are arguments owned by any other block.
      if (blockArg.getArgNumber() == 0 ||
          blockArg.getOwner() != userOp.getBlock())
        continue;

      // Terminator operand (i - 1) is paired with block argument i.
      Operation *blockTerminator = userOp.getBlock()->getTerminator();
      Value carried =
          blockTerminator->getOperand(blockArg.getArgNumber() - 1);
      Operation *producer = carried.getDefiningOp();
      if (!producer || schedule.count(producer) != 0)
        continue;

      if (isa<tt::LoadOp>(producer)) {
        schedule.insertIfAbsent(producer, stage, cluster);
        schedule.insertDepsOfOp(producer, stage, cluster, true, true);
        continue;
      }

      const CoarseSchedule::ClusterHash hash =
          CoarseSchedule::hashCluster(cluster);
      if (!clusterBefore.count(hash))
        clusterBefore[hash] = schedule.clusters.newBefore(cluster);
      CoarseSchedule::Cluster target = clusterBefore[hash];
      schedule.insertIfAbsent(producer, stage + 1, target);
      schedule.insertDepsOfOp(producer, stage + 1, target, true, true);
    }
  }
}
/// Places every body op that is not yet in `schedule` into the last stage.
/// Each such op starts in `afterPrologue`. Clusters are then propagated from
/// the ops already scheduled in the last stage to their unscheduled users:
/// whenever a user's cluster compares less than its producer's cluster, the
/// user is moved to the producer's cluster and revisited.
void scheduleRemainingToLastStage(scf::ForOp forOp, CoarseSchedule &schedule,
                                  CoarseSchedule::Cluster afterPrologue) {
  const int lastStage = schedule.getNumStages() - 1;

  // Unscheduled ops together with the cluster currently chosen for them.
  DenseMap<Operation *, CoarseSchedule::Cluster> pending;
  for (Operation &bodyOp : forOp.getBody()->without_terminator())
    if (!schedule.count(&bodyOp))
      pending[&bodyOp] = afterPrologue;

  // Seed the worklist with the ops already scheduled in the last stage.
  SmallVector<Operation *> worklist;
  for (auto [op, stage, cluster] : schedule.getOpsInOrder(forOp))
    if (stage == lastStage)
      worklist.push_back(op);

  while (!worklist.empty()) {
    Operation *producer = worklist.pop_back_val();
    for (Operation *consumer : producer->getUsers()) {
      if (!pending.count(consumer))
        continue;
      CoarseSchedule::Cluster consumerCluster = pending[consumer];
      // The producer is either an already scheduled op or a pending op that
      // was pushed after its cluster got updated.
      CoarseSchedule::Cluster producerCluster;
      if (schedule.count(producer))
        producerCluster = schedule[producer].second;
      else
        producerCluster = pending[producer];
      if (*consumerCluster < *producerCluster) {
        pending[consumer] = producerCluster;
        worklist.push_back(consumer);
      }
    }
  }

  for (auto [op, cluster] : pending)
    schedule.insert(op, lastStage, cluster);
}
namespace {
/// Returns true if at least one op directly in the body of `forOp`
/// (terminator excluded) has an entry in `opLatency`.
bool hasLatenciesAssigned(scf::ForOp forOp,
                          const DenseMap<Operation *, int> &opLatency) {
  return llvm::any_of(
      forOp.getBody()->without_terminator(),
      [&opLatency](Operation &op) { return opLatency.count(&op) != 0; });
}
/// Builds a coarse schedule from the ops of `forOp`'s body that have an entry
/// in `opLatency`.
///
/// For each such op, and transitively for every in-body user reachable from
/// it, a "distance" is computed: the op's own latency (0 when it has no entry
/// in `opLatency`) plus the largest distance among its users inside the loop
/// body, ignoring the terminator. An op is then given the stage
/// `maxDistance - distance`, where `maxDistance` is the largest distance
/// found for a latency op. One cluster per stage is appended to the schedule
/// and an op of stage `s` is inserted into cluster `maxStage - s`.
///
/// Finally an extra cluster is appended at the back; every staged `scf.if`
/// that has no latency entry and whose forward slice contains no staged op is
/// inserted again with that cluster (presumably overriding its earlier
/// placement - confirm against CoarseSchedule::insert).
///
/// \returns `CoarseSchedule(0)` when the body has no op with a latency entry,
///          otherwise the schedule described above.
CoarseSchedule scheduleKeyOps(scf::ForOp forOp,
                              const DenseMap<Operation *, int> &opLatency) {
  llvm::MapVector<Operation *, int> opToStage;
  Block *body = forOp.getBody();
  auto terminator = cast<scf::YieldOp>(body->getTerminator());

  // Ops of the body that carry a latency entry.
  SmallVector<Operation *> latOps;
  for (Operation &op : body->without_terminator()) {
    if (opLatency.count(&op))
      latOps.push_back(&op);
  }
  if (latOps.empty())
    return CoarseSchedule(0);

  // Memoized distance of each visited op (see the function comment).
  DenseMap<Operation *, int> distance;
  std::function<int(Operation *)> computeDistance = [&](Operation *op) -> int {
    auto cached = distance.find(op);
    if (cached != distance.end())
      return cached->second;

    // -1 means "no user inside the loop body other than the terminator".
    int maxUserDist = -1;
    for (Operation *user : op->getUsers()) {
      // Map users nested in regions back to their ancestor in the body.
      Operation *inBlockUser = body->findAncestorOpInBlock(*user);
      if (!inBlockUser || inBlockUser == terminator)
        continue;
      int userDist = computeDistance(inBlockUser);
      if (userDist > maxUserDist)
        maxUserDist = userDist;
    }

    // lookup() yields 0 for ops without a latency entry.
    int d = opLatency.lookup(op) + (maxUserDist < 0 ? 0 : maxUserDist);
    distance[op] = d;
    return d;
  };

  int maxDistance = 0;
  for (Operation *latOp : latOps) {
    int d = computeDistance(latOp);
    if (d > maxDistance)
      maxDistance = d;
  }

  // The larger the distance, the earlier the stage.
  for (auto [op, dist] : distance) {
    if (dist >= 0)
      opToStage[op] = maxDistance - dist;
  }

  auto stages = llvm::make_second_range(opToStage);
  int maxStage = *llvm::max_element(stages);
  CoarseSchedule schedule(maxStage + 1);
  SmallVector<CoarseSchedule::Cluster> clusters(maxStage + 1);
  for (int i = 0; i <= maxStage; i++) {
    clusters[i] = schedule.clusters.newAtBack();
  }

  // Stage s goes to cluster (maxStage - s): later stages get earlier
  // clusters.
  for (auto [op, stage] : opToStage)
    schedule.insert(op, stage, clusters[maxStage - stage]);

  // Extra cluster behind all the per-stage clusters for eligible `scf.if`s.
  CoarseSchedule::Cluster epilogue = schedule.clusters.newAtBack();
  for (auto [op, stage] : opToStage) {
    auto ifOp = dyn_cast<scf::IfOp>(op);
    if (!ifOp)
      continue;
    if (opLatency.contains(ifOp))
      continue;
    // Leave the `scf.if` where it is if anything in its forward slice has
    // been staged.
    SetVector<Operation *> slice;
    getForwardSlice(ifOp, &slice);
    if (llvm::any_of(slice, [&](Operation *dependent) {
          return opToStage.count(dependent);
        }))
      continue;
    schedule.insert(ifOp, stage, epilogue);
  }

  return schedule;
}
/// Produces the starting schedule for `forOp`.
///
///  * `CoarseSchedule(0)` if `isSafeToPipeline` rejects the loop.
///  * The result of `scheduleKeyOps` if any body op has a latency entry.
///  * Otherwise, if the loop carries `kWarpSpecializeAttrName` and a schedule
///    can be deserialized from it: that schedule after `shrinkToFit()` when
///    its latency-like ops span more than one stage, or else a single-stage
///    schedule holding every body op in one cluster.
///  * `CoarseSchedule(0)` in every remaining case.
CoarseSchedule getInitialSchedule(scf::ForOp forOp,
                                  const DenseMap<Operation *, int> &opLatency) {
  if (!isSafeToPipeline(forOp))
    return CoarseSchedule(0);
  if (hasLatenciesAssigned(forOp, opLatency))
    return scheduleKeyOps(forOp, opLatency);

  CoarseSchedule schedule;
  const bool reuseExisting = forOp->hasAttr(kWarpSpecializeAttrName) &&
                             succeeded(schedule.deSerialize(forOp));
  if (!reuseExisting)
    return CoarseSchedule(0);

  // Ops treated as latency ops when looking at the deserialized schedule.
  auto isLatencyOp = [&](Operation &op) {
    return opLatency.count(&op) ||
           isa<LoadOp, DescriptorLoadOp, DescriptorGatherOp, LocalStoreOp,
               LocalLoadOp, ttng::TMEMLoadOp, ttng::TMEMStoreOp,
               AsyncCopyGlobalToLocalOp, ttng::AsyncTMACopyGlobalToLocalOp,
               ttng::AsyncTMAGatherOp, ttng::MMAv5OpInterface,
               ttng::WaitBarrierOp, ttng::ArriveBarrierOp>(op);
  };

  // Collect the distinct stages of the latency ops that have a stage.
  auto bodyOps = forOp.getBody()->without_terminator();
  DenseSet<int> latencyStages;
  for (Operation &op : bodyOps) {
    if (!isLatencyOp(op))
      continue;
    if (schedule.count(&op))
      latencyStages.insert(schedule[&op].first);
  }

  if (latencyStages.size() > 1) {
    schedule.shrinkToFit();
    return schedule;
  }

  // Zero or one distinct stage: put the whole body in stage 0 of a
  // one-stage schedule.
  CoarseSchedule singleStage(1);
  auto onlyCluster = singleStage.clusters.newAtFront();
  for (Operation &op : bodyOps)
    singleStage.insert(&op, 0, onlyCluster);
  return singleStage;
}
/// Places the `scf.if` ops of the loop body at the two ends of the cluster
/// list.
///
/// `scf.if` ops found in the backward slice of an already scheduled op are
/// inserted (if absent) into a new cluster at the front, using the stage of
/// the first scheduled op, visited stage by stage, whose slice contained
/// them. All other `scf.if` ops directly in the body are inserted (if absent)
/// into a new cluster at the back, in the last stage.
///
/// \returns the cluster that was first in the list on entry, i.e. before any
///          front cluster was added.
CoarseSchedule::Cluster schedulePrologueAndEpilogue(scf::ForOp forOp,
                                                    CoarseSchedule &schedule) {
  const int numStages = schedule.getNumStages();
  // Captured before a front cluster may be created below.
  CoarseSchedule::Cluster firstClusterOnEntry = schedule.clusters.begin();

  BackwardSliceOptions sliceOptions;
  sliceOptions.omitBlockArguments = true;
  sliceOptions.omitUsesFromAbove = false;

  DenseMap<scf::IfOp, int> ifsToStage;
  for (int stage = 0; stage < numStages; ++stage) {
    for (auto [scheduledOp, opStage, opCluster] :
         schedule.getOpsInOrder(forOp)) {
      if (opStage != stage)
        continue;
      SetVector<Operation *> producers;
      (void)getBackwardSlice(static_cast<Operation *>(scheduledOp), &producers,
                             sliceOptions);
      for (Operation *producer : producers) {
        // insert() keeps an existing entry, so the earliest stage wins.
        if (auto ifOp = dyn_cast<scf::IfOp>(producer))
          ifsToStage.insert({ifOp, stage});
      }
    }
  }

  if (!ifsToStage.empty()) {
    CoarseSchedule::Cluster frontCluster = schedule.clusters.newAtFront();
    for (auto [ifOp, stage] : ifsToStage)
      schedule.insertIfAbsent(ifOp, stage, frontCluster);
  }

  // The back cluster is created even if no `scf.if` ends up in it.
  CoarseSchedule::Cluster backCluster = schedule.clusters.newAtBack();
  for (Operation &bodyOp : forOp.getBody()->without_terminator()) {
    auto ifOp = dyn_cast<scf::IfOp>(bodyOp);
    if (!ifOp || ifsToStage.count(ifOp) != 0)
      continue;
    schedule.insertIfAbsent(ifOp, numStages - 1, backCluster);
  }
  return firstClusterOnEntry;
}
/// Computes the coarse schedule of `forOp` and serializes it onto the loop.
///
/// Steps, in order: initial schedule, prologue/epilogue placement, dependency
/// scheduling, distance-one dependency scheduling, and assignment of all
/// remaining ops to the last stage. Nothing is done when the initial schedule
/// is empty. After each step the intermediate schedule is serialized and the
/// loop printed when debug output is enabled.
void scheduleLoop(scf::ForOp forOp,
                  const DenseMap<Operation *, int> &opLatency) {
  CoarseSchedule schedule = getInitialSchedule(forOp, opLatency);
  if (schedule.empty())
    return;

  // Debug-only: serialize the current schedule onto the loop and print it
  // after the given banner.
  auto dumpSchedule = [&](const char *banner) {
    (void)banner; // Unused when LLVM_DEBUG compiles to nothing.
    LLVM_DEBUG({
      schedule.serialize(forOp);
      DBGS() << banner << forOp << "\n";
    });
  };

  dumpSchedule("Initial coarse schedule:\n");

  CoarseSchedule::Cluster afterPrologue =
      schedulePrologueAndEpilogue(forOp, schedule);
  dumpSchedule("Coarse schedule with prologue and epilogue:\n");

  scheduleDependencies(forOp, schedule);
  dumpSchedule("Coarse schedule with dependencies:\n");

  scheduleDistanceOneDependencies(forOp, schedule);
  dumpSchedule("Coarse schedule with dist 1:\n");

  scheduleRemainingToLastStage(forOp, schedule, afterPrologue);
  dumpSchedule("Final coarse schedule:\n");

  // Unconditional: this is the result later consumers read from the loop.
  schedule.serialize(forOp);
}
/// Schedules every `scf.for` found by walking `moduleOp`, using the latencies
/// returned by `deserializeLatencies(moduleOp)`.
void scheduleLoops(ModuleOp moduleOp) {
  DenseMap<Operation *, int> latencies = deserializeLatencies(moduleOp);

  // Collect the loops first, then schedule them in a separate pass over the
  // collected list.
  SmallVector<scf::ForOp> worklist;
  moduleOp->walk([&worklist](scf::ForOp loop) { worklist.push_back(loop); });

  for (scf::ForOp loop : worklist)
    scheduleLoop(loop, latencies);
}
}
#define GEN_PASS_DEF_TRITONGPUSCHEDULELOOPS
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
/// Pass wrapper: runs `scheduleLoops` on the operation the pass is applied
/// to. Constructors are inherited from the generated base class.
struct ScheduleLoops : public impl::TritonGPUScheduleLoopsBase<ScheduleLoops> {
  using TritonGPUScheduleLoopsBase::TritonGPUScheduleLoopsBase;

  void runOnOperation() override {
    ModuleOp moduleOp = getOperation();
    scheduleLoops(moduleOp);
  }
};
}