#include "triton/Analysis/AxisInfo.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"
#include "triton/Dialect/TritonGPU/Transforms/Schedule.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/Support/Debug.h"
// Debug category string used by LLVM_DEBUG output in this file.
#define DEBUG_TYPE "triton-loop-pipeline"
// Debug stream with a "[triton-loop-pipeline]: " prefix already written to it.
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// Streams X, followed by a newline, to the prefixed debug stream; the whole
// statement is wrapped in LLVM_DEBUG.
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
using namespace mlir;
namespace tt = mlir::triton;
namespace ttg = mlir::triton::gpu;
namespace ttng = mlir::triton::nvidia_gpu;
namespace mlir::triton::gpu {
namespace {
// Returns true when `forOp` passes both eligibility checks used by this file:
// `loopHasDistGreaterThanOne` must be false and `isOuterLoop` must be false.
// The checks run in that order and the second is skipped if the first rejects
// the loop.
bool preCondition(scf::ForOp forOp) {
  const bool eligible =
      !loopHasDistGreaterThanOne(forOp) && !isOuterLoop(forOp);
  return eligible;
}
// Returns true if at least one operation in the loop body (terminator
// excluded) carries the attribute looked up through the Triton dialect's
// latency attribute helper.
bool hasLatenciesAssigned(scf::ForOp forOp) {
  auto latencyAttrs = TritonDialect::getLoaded(forOp)->getLatencyAttrHelper();
  return llvm::any_of(forOp.getBody()->without_terminator(),
                      [&](Operation &bodyOp) {
                        return static_cast<bool>(
                            latencyAttrs.getAttr(&bodyOp));
                      });
}
// Copies latencies that are already attached to operations of `forOp` into
// `opLatency`.
//
// Every operation in the loop body (terminator excluded) that has the latency
// attribute gets an entry `opLatency[op] = <attribute integer value>`;
// operations without the attribute are left untouched.
void assignUserProvidedLatencies(scf::ForOp forOp,
                                 DenseMap<Operation *, int> &opLatency) {
  auto latencyAttrs = TritonDialect::getLoaded(forOp)->getLatencyAttrHelper();
  for (auto &bodyOp : forOp.getBody()->without_terminator()) {
    auto userLatency = latencyAttrs.getAttr(&bodyOp);
    if (!userLatency)
      continue;
    opLatency[&bodyOp] = userLatency.getInt();
  }
}
class AssignLoadLatencies {
public:
AssignLoadLatencies(scf::ForOp forOp, int numStages,
DenseMap<Operation *, int> &opLatency)
: forOp(forOp), numStages(numStages), opLatency(opLatency) {};
void run() {
bool pipelineWithoutDot = forOp->hasAttr(mlir::triton::kNumStagesAttrName);
ModuleOp moduleOp = forOp->getParentOfType<ModuleOp>();
tt::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp);
llvm::MapVector<Operation *, std::pair<int, Operation *>> loadOpToIndLevel =
loadOpsToIndirectionLevel(forOp, pipelineWithoutDot, axisInfoAnalysis,
numStages);
if (loadOpToIndLevel.empty())
return;
int maxIndirectionLevel = 0;
for (auto &[loadOp, info] : loadOpToIndLevel)
maxIndirectionLevel = std::max(maxIndirectionLevel, info.first);
unsigned loadLatency = (numStages - 1) / (maxIndirectionLevel + 1);
for (auto [loadOp, dist] : loadOpToIndLevel) {
opLatency[loadOp] = loadLatency;
}
}
private:
scf::ForOp forOp;
int numStages;
DenseMap<Operation *, int> &opLatency;
public:
static bool canHaveSharedEncoding(tt::LoadOp op) {
bool incompatible = false;
getSharedEncIfAllUsersAreDotEnc(op.getResult(), incompatible);
return !incompatible;
}
static bool
isPipeliningBeneficial(Operation *op, Operation *finalUser,
tt::ModuleAxisInfoAnalysis &axisInfoAnalysis,
bool filterSmall) {
if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
if (filterSmall && !canBeConvertedToAsyncLoad(loadOp, axisInfoAnalysis)) {
LDBG("Load " << *loadOp << " is too small for pipelining");
return false;
}
}
if (isa<tt::DescriptorLoadOp, tt::DescriptorGatherOp>(op))
return true;
if (!canHaveSharedEncoding(cast<tt::LoadOp>(op))) {
LDBG("Load " << *op << " cannot have shared encoding");
return false;
}
ttg::SharedEncodingTrait localAllocEnc;
if (llvm::any_of(op->getUsers(), [&](Operation *user) {
return isa<ttg::LocalAllocOp>(user);
})) {
for (auto user : op->getUsers()) {
auto localAlloc = dyn_cast<ttg::LocalAllocOp>(user);
if (!localAlloc)
continue;
auto enc = mlir::cast<ttg::SharedEncodingTrait>(
localAlloc.getType().getEncoding());
if (!localAllocEnc) {
localAllocEnc = enc;
}
if (enc != localAllocEnc) {
return false;
}
}
}
if (localAllocEnc) {
auto registerTy = cast<RankedTensorType>(op->getResultTypes()[0]);
auto vecBytes = getCopyVecBytes(registerTy, localAllocEnc);
if (filterSmall && vecBytes < 4) {
return false;
}
}
return true;
}
};
class AssignMMALatencies {
public:
AssignMMALatencies(scf::ForOp forOp, DenseMap<Operation *, int> &opLatency)
: forOp(forOp), opLatency(opLatency) {};
void run() {
DenseMap<Operation *, int> mmaSelfLatency;
auto isLoadToBePipelined = [&](Operation *op) {
return opLatency.count(op) && opLatency[op] > 0;
};
for (auto &op : forOp.getBody()->without_terminator()) {
if (auto mma = dyn_cast<ttng::MMAv5OpInterface>(&op)) {
if (hasSyncDots(forOp)) {
continue;
}
auto pipeHelper = ttng::MMAv5PipelineableOperandsHelper(
mma, forOp, isLoadToBePipelined);
if (pipeHelper.isPipelineable ||
(pipeHelper.isOperandsStateDetermined &&
!ttng::hasLoadsAfterMMA(mma, forOp))) {
mmaSelfLatency[mma] = 1;
if (!ttng::requiresAccMultiBuffering(mma, forOp) ||
(ttng::isAccMultibufferingPossible(mma, forOp) &&
!getDisallowAccMultiBuffer(forOp))) {
opLatency[&op] = 1;
}
if (forOp->hasAttr(kWarpSpecializeAttrName)) {
if (ttng::hasAccReadModifyWrite(mma, forOp))
opLatency.erase(&op);
else
opLatency[&op] += 1;
}
}
}
}
serializeSelfLatencies(forOp->getParentOfType<ModuleOp>(), mmaSelfLatency);
}
private:
scf::ForOp forOp;
DenseMap<Operation *, int> &opLatency;
bool hasSyncDots(scf::ForOp forOp) {
for (auto &op : forOp.getBody()->without_terminator()) {
if (isa<mlir::triton::DotOp>(op))
return true;
}
return false;
}
};
// Assigns latencies for every eligible scf.for loop in `moduleOp` and
// serializes the collected map onto the module.
//
// A loop is eligible when `preCondition` accepts it and its stage count
// (`getNumStagesOrDefault` with `defaultNumStages` as the fallback) is greater
// than 1. Loops that already carry latency attributes have those copied
// verbatim; all others go through the load and MMA latency assignment. Nothing
// is serialized when no loop is eligible.
void assignLatencies(ModuleOp moduleOp, int defaultNumStages) {
  SmallVector<scf::ForOp> candidates;
  moduleOp->walk([&](scf::ForOp loop) {
    if (!preCondition(loop))
      return;
    if (getNumStagesOrDefault(loop, defaultNumStages) <= 1)
      return;
    candidates.push_back(loop);
  });
  if (candidates.empty())
    return;

  DenseMap<Operation *, int> opLatency;
  for (scf::ForOp loop : candidates) {
    if (hasLatenciesAssigned(loop)) {
      assignUserProvidedLatencies(loop, opLatency);
    } else {
      const int numStages = getNumStagesOrDefault(loop, defaultNumStages);
      AssignLoadLatencies loadLatencies(loop, numStages, opLatency);
      loadLatencies.run();
      AssignMMALatencies mmaLatencies(loop, opLatency);
      mmaLatencies.run();
    }
  }
  serializeLatencies(moduleOp, opLatency);
}
}
// Collects the loads of `forOp` that are candidates for pipelining.
//
// The result maps each selected load to a pair {indirection level, final
// user}. The indirection level is the number of other loads encountered on the
// def-use path from the root operation down to this load (0 for a load feeding
// the root directly); the final user is the root, or the nearest load closer
// to the root on that path.
//
// Roots are the operations implementing DotOpInterface plus ttng::TMEMStoreOp.
// When `pipelineWithoutDot` is set and the loop contains no such root, every
// non-load operation in the body is used as a root instead.
//
// A load is dropped when `AssignLoadLatencies::isPipeliningBeneficial` rejects
// it, when it is reached at two different indirection levels, or when its
// level is >= numStages - 1.
llvm::MapVector<Operation *, std::pair<int, Operation *>>
loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot,
                          tt::ModuleAxisInfoAnalysis &axisInfoAnalysis,
                          int numStages, bool filterSmall) {
  llvm::MapVector<Operation *, std::pair<int, Operation *>> loadOpToIndLevel;
  DenseSet<Operation *> seen;
  // Loads found at conflicting distances; never reconsidered.
  DenseSet<Operation *> excluded;

  // Walks backwards from `op` through defining ops that live in the same
  // block, recording every load on the way.
  std::function<void(Operation *, Operation *, int)> dfs =
      [&](Operation *op, Operation *finalUser, int distance) {
        if (!seen.insert(op).second || excluded.count(op))
          return;
        if (isa<tt::LoadOp, tt::DescriptorLoadOp, tt::DescriptorGatherOp>(op)) {
          if (!AssignLoadLatencies::isPipeliningBeneficial(
                  op, finalUser, axisInfoAnalysis, filterSmall))
            return;
          if (loadOpToIndLevel.count(op)) {
            int level = loadOpToIndLevel[op].first;
            if (level != distance) {
              // The same load is reachable at two different distances, so
              // there is no single level to assign: give up on it.
              LDBG("Load " << *op
                           << " has multiple uses at different distances:"
                           << level << " and " << distance);
              loadOpToIndLevel.erase(op);
              excluded.insert(op);
              return;
            }
          } else {
            LDBG("Load " << *op << " considered for pipelining with distance "
                         << distance);
            loadOpToIndLevel[op] = {distance, finalUser};
          }
          // Loads feeding this load are one indirection level deeper.
          finalUser = op;
          distance++;
        }
        const bool isDot = isa<mlir::triton::DotOpInterface>(op);
        for (Value operand : getNestedOperands(op)) {
          // For dot-like ops, operand #2 is not followed.
          if (isDot && operand == op->getOperand(2))
            continue;
          Operation *defOp = operand.getDefiningOp();
          if (defOp && defOp->getBlock() == op->getBlock())
            dfs(defOp, finalUser, distance);
        }
      };

  bool seenDot = false;
  for (Operation &op : forOp.getBody()->without_terminator()) {
    if (!isa<mlir::triton::DotOpInterface, ttng::TMEMStoreOp>(op))
      continue;
    seenDot = true;
    // Each root gets a fresh traversal so shared loads are re-checked for
    // distance conflicts.
    seen.clear();
    dfs(&op, &op, 0);
  }

  // Fall back to non-dot roots only when no dot-like root exists. Running this
  // pass in addition to the dot-rooted one lets unrelated roots reach an
  // already-recorded load at a different distance, which would erase and
  // exclude a load that the dot-rooted traversal had selected.
  if (pipelineWithoutDot && !seenDot) {
    for (Operation &op : forOp.getBody()->without_terminator()) {
      if (!isa<tt::LoadOp, tt::DescriptorLoadOp, tt::DescriptorGatherOp>(op))
        dfs(&op, &op, 0);
    }
  }

  // Drop loads whose indirection level is numStages - 1 or more.
  for (auto iter = loadOpToIndLevel.begin(); iter != loadOpToIndLevel.end();) {
    if (iter->second.first >= numStages - 1)
      iter = loadOpToIndLevel.erase(iter);
    else
      ++iter;
  }
  return loadOpToIndLevel;
}
#define GEN_PASS_DEF_TRITONGPUASSIGNLATENCIES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
// Pass entry point: derives from the generated TritonGPUAssignLatenciesBase
// and forwards to `assignLatencies`.
struct AssignLatencies
    : public impl::TritonGPUAssignLatenciesBase<AssignLatencies> {
  using TritonGPUAssignLatenciesBase::TritonGPUAssignLatenciesBase;

  // Runs latency assignment on the operation this pass is scheduled on.
  // `numStages` is not declared in this struct; presumably it is a pass
  // option supplied by the generated base class (TODO confirm).
  void runOnOperation() override {
    ModuleOp moduleOp = getOperation();
    assignLatencies(moduleOp, numStages);
  }
};
}