#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/MMAv5PipelineUtility.h"
#include "triton/Dialect/TritonGPU/Transforms/Partition.h"
#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
using namespace mlir;
using namespace triton;
using namespace triton::gpu;
namespace ttng = triton::nvidia_gpu;
/// Find the operation inside `loop` that defines `value`, looking back across
/// at most one loop iteration.
///
/// If `value` is a region iteration argument of the loop body, the value
/// yielded for it is examined instead (this is allowed only once; `distance`
/// counts how many such hops have been taken). The result is the defining op
/// if it lives (possibly nested) inside the loop body.
///
/// Returns null when `value` is null (e.g. an absent optional operand), is the
/// loop induction variable, is a block argument of some other block, is a
/// loop-carried argument reached after one hop already, or is defined outside
/// the loop.
static Operation *findDefOpInLoop(scf::ForOp loop, Value value,
                                  int distance = 0) {
  // `dyn_cast` asserts on a null Value, so bail out explicitly.
  if (!value)
    return {};
  if (auto arg = dyn_cast<BlockArgument>(value)) {
    if (arg.getParentBlock() != loop.getBody())
      return {};
    // Don't look back more than one iteration.
    if (distance == 1)
      return {};
    // The induction variable (argument 0) has no yielded counterpart; indexing
    // the yielded values with `getArgNumber() - 1` would underflow.
    if (arg == loop.getInductionVar())
      return {};
    return findDefOpInLoop(
        loop, loop.getYieldedValues()[arg.getArgNumber() - 1], distance + 1);
  }
  // Not a block argument, so this is an op result and has a defining op.
  Operation *defOp = value.getDefiningOp();
  if (!loop.getBodyRegion().isAncestor(defOp->getParentRegion()))
    return {};
  return defOp;
}
/// Invoke `callback` for every in-loop definition that feeds `op`.
///
/// All operands of `op` are visited, including operands of ops nested in its
/// regions. Operands whose value does not belong to the loop body block are
/// skipped, as is the loop induction variable. Each remaining operand is
/// resolved through `getDefinitionAndDistance` (which may look through
/// loop-carried arguments). The resulting op result is reported only if it is
/// non-null and lives directly in the loop body block.
static void iterateDefs(scf::ForOp loop, Operation *op,
                        function_ref<void(OpResult)> callback) {
  Block *body = loop.getBody();
  visitNestedOperands(op, [&](OpOperand &operand) {
    Value operandValue = operand.get();
    bool isBodyValue = operandValue.getParentBlock() == body;
    if (!isBodyValue || operandValue == loop.getInductionVar())
      return;
    auto [defResult, defDistance] =
        getDefinitionAndDistance(loop, operandValue);
    if (!defResult || defResult.getParentBlock() != body)
      return;
    callback(defResult);
  });
}
/// Invoke `callback` on each loop-body-level op that consumes a result of
/// `op`.
///
/// Every use is mapped to its ancestor op in the loop body block. A use by the
/// loop's `scf.yield` is not reported. Instead, the matching region iteration
/// argument is followed and its uses are processed the same way, so consumers
/// on the next iteration are found too.
static void iterateUsers(scf::ForOp loop, Operation *op,
                         function_ref<void(Operation *)> callback) {
  Block *body = loop.getBody();
  SmallVector<OpOperand *> pending;
  auto enqueueUses = [&](auto &&useRange) {
    for (OpOperand &u : useRange)
      pending.push_back(&u);
  };

  enqueueUses(op->getUses());
  while (!pending.empty()) {
    OpOperand *current = pending.pop_back_val();
    Operation *user = body->findAncestorOpInBlock(*current->getOwner());
    if (isa<scf::YieldOp>(user)) {
      // Yielded into the next iteration: continue from the iter arg.
      enqueueUses(loop.getRegionIterArg(current->getOperandNumber()).getUses());
      continue;
    }
    callback(user);
  }
}
/// Return true if `op`, or any in-loop op it transitively depends on (as
/// reported by `iterateDefs`), is assigned to a partition other than the
/// schedule's root partition.
static bool hasDefPartition(scf::ForOp loop, Operation *op,
                            WarpSchedule &schedule) {
  DenseSet<Operation *> visited;
  SmallVector<Operation *> stack;
  stack.push_back(op);

  while (!stack.empty()) {
    Operation *current = stack.pop_back_val();
    bool firstVisit = visited.insert(current).second;
    if (!firstVisit)
      continue;

    Partition *assigned = schedule.getPartition(current);
    bool inNonRootPartition =
        assigned != nullptr && assigned != schedule.getRootPartition();
    if (inNonRootPartition)
      return true;

    auto pushProducer = [&](OpResult def) {
      stack.push_back(def.getDefiningOp());
    };
    iterateDefs(loop, current, pushProducer);
  }
  return false;
}
/// Transitively pull the in-loop producers of `op`'s tensor/memdesc operands
/// into `partition`.
///
/// The walk starts from the RankedTensorType / MemDescType operands of `op`,
/// including operands of ops nested in its regions.
/// - A producer op in the loop body is scheduled into `partition` only if
///   `hasDefPartition` holds for it and `schedule.trySchedule` succeeds. Only
///   then are its own operands (of any type) explored.
/// - Loop-carried block arguments are followed to the value yielded for them.
/// - The induction variable, block arguments of other blocks, and values
///   defined outside the loop are not followed.
static void scheduleDependencies(scf::ForOp loop, WarpSchedule &schedule,
                                 Partition *partition, Operation *op) {
  SmallVector<Value> deps;
  for (Value value : getNestedOperands(op)) {
    if (isa<RankedTensorType, MemDescType>(value.getType()))
      deps.push_back(value);
  }

  // Loop-carried arguments whose yielded value has already been queued. Iter
  // args can form a cycle purely through the yield, for example an argument
  // yielded unchanged, or two arguments swapped as in `yield %b, %a`. Such a
  // cycle never reaches an op where `trySchedule` would stop the walk, so
  // without this set it would be re-queued forever.
  DenseSet<Value> followedArgs;
  while (!deps.empty()) {
    Value dep = deps.pop_back_val();
    if (auto arg = dyn_cast<BlockArgument>(dep)) {
      if (arg.getOwner() == loop.getBody() && arg != loop.getInductionVar() &&
          followedArgs.insert(arg).second)
        deps.push_back(loop.getYieldedValues()[arg.getArgNumber() - 1]);
      continue;
    }

    // Map the producer to its ancestor op in the loop body. A producer
    // outside the loop has no such ancestor and is ignored.
    Operation *defOp =
        loop.getBody()->findAncestorOpInBlock(*dep.getDefiningOp());
    if (!defOp || !hasDefPartition(loop, defOp, schedule) ||
        !schedule.trySchedule(partition, defOp))
      continue;
    llvm::append_range(deps, getNestedOperands(defOp));
  }
}
/// Transitively schedule the in-loop consumers of `op` into `partition`.
///
/// Each use is mapped to its ancestor op in the loop body. A use by the loop
/// terminator is followed through the matching region iteration argument.
/// Any other user is offered to `schedule.trySchedule`. Only if that succeeds
/// are the user's own consumers explored, so the walk stops at ops that are
/// not (re)scheduled.
static void scheduleUsers(scf::ForOp loop, WarpSchedule &schedule,
                          Partition *partition, Operation *op) {
  Block *body = loop.getBody();
  SmallVector<OpOperand *> pending;
  auto enqueueUses = [&](auto &&useRange) {
    for (OpOperand &u : useRange)
      pending.push_back(&u);
  };

  enqueueUses(op->getUses());
  while (!pending.empty()) {
    OpOperand *current = pending.pop_back_val();
    Operation *user = body->findAncestorOpInBlock(*current->getOwner());
    if (user == body->getTerminator()) {
      enqueueUses(loop.getRegionIterArg(current->getOperandNumber()).getUses());
      continue;
    }
    if (schedule.trySchedule(partition, user))
      enqueueUses(user->getUses());
  }
}
/// Build the initial partition assignment for `loop`.
///
/// If `WarpSchedule::deserialize` succeeds on `loop`, that schedule is
/// returned as-is. Otherwise a new schedule is seeded with a default, an MMA
/// and a load partition, and anchor ops are assigned as follows:
///   - `DescriptorLoadOp` / `DescriptorGatherOp` ops go to the load
///     partition. So do their `LocalAllocOp` users with a matching shared
///     encoding and their `ttng::TMEMAllocOp` users.
///   - MMAv5 ops go to the MMA partition, along with some related TMEM
///     stores and memdesc view ops.
///   - Large `math::Exp2Op` / `ElementwiseInlineAsmOp` ops and their
///     dependencies go to the default partition.
///   - Users of the loads/allocs go to the default partition. Users of the
///     MMAs go to per-MMA user partitions.
/// Returns std::nullopt if the loop contains none of those load ops and no
/// MMAv5 ops.
///
/// Note: this can mutate the IR. Memdesc view ops used both inside and
/// outside the MMA partition are cloned.
static std::optional<WarpSchedule> getInitialSchedule(scf::ForOp loop) {
  // Reuse a schedule already serialized on the loop, if any.
  if (FailureOr<WarpSchedule> scheduleOr = WarpSchedule::deserialize(loop);
      succeeded(scheduleOr))
    return {std::move(*scheduleOr)};

  WarpSchedule schedule;
  // NOTE(review): the integer passed to addPartition is presumably the
  // partition's pipeline stage (TODO confirm in Partition.h); the MMA
  // partition is the only one created with 1 here.
  Partition *defaultPartition = schedule.addPartition(0);
  Partition *mmaPartition = schedule.addPartition(1);
  Partition *loadPartition = schedule.addPartition(0);

  // Put descriptor loads/gathers, and the allocs that directly consume them,
  // into the load partition. `loadsAndAllocs` remembers them so that their
  // users can be scheduled further below.
  SmallVector<Operation *> loadsAndAllocs;
  for (Operation &op : loop.getOps()) {
    if (!isa<DescriptorLoadOp, DescriptorGatherOp>(op))
      continue;
    schedule.trySchedule(loadPartition, &op);
    loadsAndAllocs.push_back(&op);
    SharedEncodingTrait sharedEnc = getSharedEncoding(&op);
    for (Operation *user : op.getUsers()) {
      if (auto alloc = dyn_cast<LocalAllocOp>(user)) {
        // Only local allocs whose encoding matches the load's shared encoding.
        if (sharedEnc == alloc.getType().getEncoding()) {
          schedule.trySchedule(loadPartition, alloc);
          loadsAndAllocs.push_back(alloc);
        }
      } else if (isa<ttng::TMEMAllocOp>(user)) {
        schedule.trySchedule(loadPartition, user);
        loadsAndAllocs.push_back(user);
      }
    }
  }

  // Put every MMAv5 op into the MMA partition.
  SmallVector<ttng::MMAv5OpInterface> mmas;
  for (auto mmaOp : loop.getOps<ttng::MMAv5OpInterface>()) {
    schedule.trySchedule(mmaPartition, mmaOp);
    mmas.push_back(mmaOp);

    // Keep an accumulator TMEM store in the MMA partition when three
    // conditions hold:
    //   - the store produces the MMA's acc dependency (looking back at most
    //     one iteration, see findDefOpInLoop);
    //   - the stored value is defined outside the loop;
    //   - `ttng::hasAccReadModifyWrite` is false for this MMA.
    auto storeOp = dyn_cast_or_null<ttng::TMEMStoreOp>(
        findDefOpInLoop(loop, mmaOp.getAccDep()));
    if (!ttng::hasAccReadModifyWrite(mmaOp, loop) && storeOp &&
        loop.isDefinedOutsideOfLoop(storeOp.getSrc()))
      schedule.trySchedule(mmaPartition, storeOp);

    // Walk the chains of memdesc view ops feeding the MMA operands (upwards
    // through operand 0 of each view) and put them in the MMA partition.
    SmallVector<Operation *> operandViews;
    for (Value operand : mmaOp->getOperands()) {
      if (Operation *defOp = operand.getDefiningOp())
        operandViews.push_back(defOp);
    }
    while (!operandViews.empty()) {
      Operation *op = operandViews.pop_back_val();
      if (!op->hasTrait<OpTrait::MemDescViewTrait>())
        continue;
      // If the view also has users outside the MMA partition, clone it. Its
      // MMA-partition users are redirected to the clone, which is then the
      // op that gets scheduled. The original keeps the remaining users.
      if (!llvm::all_of(op->getUsers(), [&](Operation *user) {
            return schedule.getPartition(user) == mmaPartition;
          })) {
        Operation *newOp = OpBuilder(op).clone(*op);
        op->replaceUsesWithIf(newOp->getResults(), [&](OpOperand &use) {
          return schedule.getPartition(use.getOwner()) == mmaPartition;
        });
        op = newOp;
      }
      schedule.trySchedule(mmaPartition, op);
      if (Operation *defOp = op->getOperand(0).getDefiningOp())
        operandViews.push_back(defOp);
    }
  }

  // Nothing to warp-specialize in this loop.
  if (loadsAndAllocs.empty() && mmas.empty())
    return std::nullopt;

  // Elementwise `exp2` / inline-asm ops whose ranked-tensor results have more
  // than 256 elements in total are pinned to the default partition together
  // with their dependencies.
  // NOTE(review): 256 is a heuristic threshold; its rationale is not visible
  // here.
  for (Operation &op : loop.getOps()) {
    if (!isa<math::Exp2Op, ElementwiseInlineAsmOp>(op))
      continue;
    int elementCount = 0;
    for (Type type : op.getResultTypes()) {
      if (auto tensorTy = dyn_cast<RankedTensorType>(type))
        elementCount += tensorTy.getNumElements();
    }
    if (elementCount > 256) {
      schedule.trySchedule(defaultPartition, &op);
      scheduleDependencies(loop, schedule, defaultPartition, &op);
    }
  }

  // Users of the loads/allocs go to the default partition. This runs before
  // the MMA users below. NOTE(review): that ordering presumably gives it
  // precedence if trySchedule refuses already-scheduled ops; confirm.
  for (Operation *loadOrAlloc : loadsAndAllocs)
    scheduleUsers(loop, schedule, defaultPartition, loadOrAlloc);

  // One user partition per MMA. The first MMA reuses the default partition
  // and each further MMA gets a newly created partition. The pairs are
  // processed from the last MMA to the first.
  SmallVector<Partition *> userPartitions{defaultPartition};
  while (userPartitions.size() < mmas.size()) {
    userPartitions.push_back(schedule.addPartition(userPartitions.size()));
  }
  for (auto [mmaOp, userPartition] :
       llvm::reverse(llvm::zip(mmas, userPartitions))) {
    scheduleUsers(loop, schedule, userPartition, mmaOp);
  }
  return schedule;
}
namespace {
/// A group of not-yet-scheduled operations that will be assigned together.
/// - `defPartitions`: partitions producing values the group consumes.
/// - `sinkPartitions`: partitions consuming values the group produces.
struct OpCluster {
  SetVector<Operation *> ops;
  SetVector<Partition *> defPartitions;
  SetVector<Partition *> sinkPartitions;
};

/// Maps each clustered operation to the cluster that currently owns it.
///
/// The clusters themselves are owned by `clusters`. Merging moves everything
/// out of the source cluster but never frees it, so cluster pointers stay
/// valid; merged-away clusters are simply left empty.
struct OpClusters : public llvm::MapVector<Operation *, OpCluster *> {
  using MapVector::MapVector;

  /// Return the cluster owning `op`, creating a new singleton cluster for it
  /// if it has none.
  OpCluster *getOrCreate(Operation *op) {
    OpCluster *&cluster = (*this)[op];
    if (!cluster) {
      cluster = clusters.emplace_back(std::make_unique<OpCluster>()).get();
      cluster->ops.insert(op);
    }
    return cluster;
  }

  /// Move all ops and partitions of `src` into `dst` and repoint `src`'s ops
  /// at `dst`. `src` is left empty.
  ///
  /// Merging a cluster into itself is a no-op. Without this guard the final
  /// clears would wipe out the cluster's contents.
  void merge(OpCluster *dst, OpCluster *src) {
    if (dst == src)
      return;
    dst->ops.insert_range(src->ops);
    dst->defPartitions.insert_range(src->defPartitions);
    dst->sinkPartitions.insert_range(src->sinkPartitions);
    for (Operation *op : src->ops)
      (*this)[op] = dst;
    src->ops.clear();
    src->defPartitions.clear();
    src->sinkPartitions.clear();
  }

  /// Owning storage for every cluster ever created (including empty,
  /// merged-away ones).
  SmallVector<std::unique_ptr<OpCluster>> clusters;
};
} // namespace
/// Assign partitions to unscheduled loop-body ops that are connected to
/// already-scheduled ops.
///
/// Unscheduled ops are grouped into clusters (see OpClusters). Each cluster
/// records the partitions that produce values it consumes (`defPartitions`)
/// and the partitions that consume values it produces (`sinkPartitions`).
/// Each non-empty cluster is then placed as follows:
///   - several def partitions or several sink partitions: all cluster ops
///     move into a newly created partition;
///   - no sink partition: the cluster joins its single def partition;
///   - otherwise the cluster joins its def partition. If only part of it
///     feeds the sink partition, that part is also cloned into the sink
///     partition and the sink's uses are redirected to the clones.
void propagatePartitions(scf::ForOp loop, WarpSchedule &schedule) {
  OpClusters opClusters;

  // Phase 1: seed clusters at the boundaries of every partition.
  for (Partition &partition : schedule.getPartitions()) {
    // Unscheduled defs feeding this partition get it as a sink partition.
    // This only applies to defs that themselves have a def partition.
    auto defCallback = [&](OpResult result, unsigned distance) {
      Operation *defOp = result.getDefiningOp();
      if (!schedule.isScheduled(defOp) &&
          hasDefPartition(loop, defOp, schedule)) {
        opClusters.getOrCreate(defOp)->sinkPartitions.insert(&partition);
      }
    };
    schedule.iterateDefs(loop, &partition, defCallback);

    // Unscheduled users of this partition's results get it as a def
    // partition.
    auto useCallback = [&](OpResult result, OpOperand &use, unsigned distance) {
      Operation *user = loop.getBody()->findAncestorOpInBlock(*use.getOwner());
      if (!schedule.isScheduled(user)) {
        opClusters.getOrCreate(user)->defPartitions.insert(&partition);
      }
    };
    schedule.iterateUses(loop, &partition, useCallback);
  }

  // Phase 2: grow the clusters over connected unscheduled ops.
  SmallVector<Operation *> worklist =
      llvm::to_vector(llvm::make_first_range(opClusters));
  while (!worklist.empty()) {
    Operation *op = worklist.pop_back_val();
    OpCluster *cluster = opClusters.find(op)->second;
    iterateDefs(loop, op, [&](OpResult def) {
      Operation *defOp = def.getDefiningOp();
      if (schedule.isScheduled(defOp)) {
        // A scheduled def contributes its partition as a def partition.
        cluster->defPartitions.insert(schedule.getPartition(defOp));
      } else {
        // Unscheduled defs with no def partition are not pulled in.
        if (!hasDefPartition(loop, defOp, schedule))
          return;
        // Claim an unclaimed def for this cluster, or merge the def's
        // existing cluster into this one.
        OpCluster *&defCluster = opClusters[defOp];
        if (!defCluster) {
          defCluster = cluster;
          cluster->ops.insert(defOp);
          worklist.push_back(defOp);
        } else if (defCluster != cluster) {
          opClusters.merge(cluster, defCluster);
        }
      }
    });
    iterateUsers(loop, op, [&](Operation *user) {
      // A scheduled user contributes its partition as a sink partition.
      if (schedule.isScheduled(user)) {
        Partition *userPartition = schedule.getPartition(user);
        cluster->sinkPartitions.insert(userPartition);
        return;
      }
      // An unclaimed user joins this cluster. A user that already belongs to
      // a cluster is left alone here; merging only happens on the def side
      // above.
      OpCluster *&userCluster = opClusters[user];
      if (userCluster)
        return;
      userCluster = cluster;
      cluster->ops.insert(user);
      worklist.push_back(user);
    });
  }

  // Phase 3: pick a partition for each non-empty cluster (clusters emptied
  // by a merge are skipped).
  for (OpCluster &cluster : llvm::make_pointee_range(opClusters.clusters)) {
    if (cluster.ops.empty())
      continue;
    assert(!cluster.defPartitions.empty());
    assert(llvm::all_of(
        cluster.ops, [&](Operation *op) { return !schedule.isScheduled(op); }));

    // Ambiguous producers or consumers: give the cluster a partition of its
    // own.
    if (cluster.defPartitions.size() > 1 || cluster.sinkPartitions.size() > 1) {
      Partition *newPartition = schedule.addPartition(0);
      for (Operation *op : cluster.ops)
        schedule.insert(newPartition, op);
      continue;
    }

    // No consuming partition: the cluster joins its producer partition.
    Partition *defPartition = cluster.defPartitions.front();
    if (cluster.sinkPartitions.empty()) {
      for (Operation *op : cluster.ops)
        schedule.insert(defPartition, op);
      continue;
    }

    // Exactly one def and one sink partition. Collect the "critical path":
    // the cluster ops that the sink partition transitively depends on.
    Partition *sinkPartition = cluster.sinkPartitions.front();
    SetVector<Operation *> critPath;
    DenseSet<Operation *> opsInCluster(cluster.ops.begin(), cluster.ops.end());
    auto callback = [&](OpResult result, unsigned distance) {
      Operation *defOp = result.getDefiningOp();
      if (opsInCluster.contains(defOp))
        critPath.insert(defOp);
    };
    schedule.iterateDefs(loop, sinkPartition, callback);
    // `critPath` grows while it is being scanned, so iterate by index until
    // no new cluster ops are added.
    for (unsigned i = 0; i < critPath.size(); ++i) {
      Operation *op = critPath[i];
      iterateDefs(loop, op, [&](OpResult def) {
        Operation *defOp = def.getDefiningOp();
        if (opsInCluster.contains(defOp))
          critPath.insert(defOp);
      });
    }

    // The whole cluster is on the critical path: it simply joins the def
    // partition, without cloning.
    if (critPath.size() == cluster.ops.size()) {
      for (Operation *op : cluster.ops)
        schedule.insert(defPartition, op);
      continue;
    }

    // Otherwise clone the critical-path ops into the sink partition.
    // Visiting them in reverse topological order clones consumers before
    // their producers. Each clone is added to `sinkOps`, so when its producer
    // is handled later, replaceUsesWithIf rewires the clone's operand to the
    // producer's clone.
    critPath = topologicalSort(critPath);
    DenseSet<Operation *> sinkOps(sinkPartition->getOps().begin(),
                                  sinkPartition->getOps().end());
    for (Operation *op : llvm::reverse(critPath)) {
      OpBuilder b(op);
      Operation *clone = b.clone(*op);
      op->replaceUsesWithIf(clone->getResults(), [&](OpOperand &use) {
        return sinkOps.contains(use.getOwner());
      });
      sinkOps.insert(clone);
      schedule.insert(sinkPartition, clone);
    }
    // The original cluster ops are all assigned to the def partition.
    for (Operation *op : cluster.ops)
      schedule.insert(defPartition, op);
  }
}
/// Rematerialize the chain of `BroadcastOp` / `ExpandDimsOp` producers that
/// feeds `use` inside the partition of the use's owner.
///
/// Starting from the op defining `use`'s value and walking upward through
/// operand 0, each broadcast/expand_dims op is cloned (immediately before the
/// original). The clone is assigned to the owner's partition and the current
/// use is rewired to it. The original ops are left in place.
void rematerializeBroadcasts(WarpSchedule &schedule, OpOperand *use) {
  // Both ops are single-result, so `getResult(0)` below is the value being
  // replaced.
  static_assert(
      std::is_base_of_v<OpTrait::OneResult<BroadcastOp>, BroadcastOp> &&
      std::is_base_of_v<OpTrait::OneResult<ExpandDimsOp>, ExpandDimsOp>);
  for (Operation *producer = use->get().getDefiningOp();
       isa_and_nonnull<BroadcastOp, ExpandDimsOp>(producer);
       producer = use->get().getDefiningOp()) {
    Operation *copy = OpBuilder(producer).clone(*producer);
    Partition *ownerPartition = schedule.getPartition(use->getOwner());
    assert(ownerPartition && "user not scheduled");
    schedule.insert(ownerPartition, copy);
    use->set(copy->getResult(0));
    // Continue up the chain from the clone's input operand.
    use = &copy->getOpOperand(0);
  }
}
/// Post-process `schedule`. For every partition, the broadcast/expand_dims
/// chains feeding uses of that partition's outputs (other than uses by
/// `scf.yield`) are rematerialized in the using partition.
void optimizeSchedule(scf::ForOp loop, WarpSchedule &schedule) {
  for (Partition &partition : schedule.getPartitions()) {
    // Uses are collected before any rewriting, presumably so the IR is not
    // mutated while `iterateOutputs` is walking it.
    SmallVector<OpOperand *> outputUses;
    auto collect = [&](Operation *defOp, OpOperand &use) {
      if (isa<scf::YieldOp>(use.getOwner()))
        return;
      outputUses.push_back(&use);
    };
    schedule.iterateOutputs(loop, &partition, collect);

    for (OpOperand *outputUse : outputUses)
      rematerializeBroadcasts(schedule, outputUse);
  }
}
// Expand the generated pass definition for this pass. It provides
// `impl::TritonGPUPartitionSchedulingBase`, which the pass struct below
// derives from.
namespace mlir::triton::gpu {
#define GEN_PASS_DEF_TRITONGPUPARTITIONSCHEDULING
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
} // namespace mlir::triton::gpu
namespace {
/// Pass that computes a warp-specialization partition schedule for every
/// `scf.for` loop carrying `kWarpSpecializeAttrName`, and serializes it onto
/// the loop.
struct PartitionScheduling
    : triton::gpu::impl::TritonGPUPartitionSchedulingBase<PartitionScheduling> {
  // Reuse the constructors of the generated base class.
  using TritonGPUPartitionSchedulingBase::TritonGPUPartitionSchedulingBase;

  void runOnOperation() override;
};
} // namespace
void PartitionScheduling::runOnOperation() {
SmallVector<scf::ForOp> loops;
getOperation().walk([&](scf::ForOp loop) {
if (loop->hasAttr(kWarpSpecializeAttrName))
loops.push_back(loop);
});
for (auto [idx, loop] : llvm::enumerate(loops)) {
if (std::optional<WarpSchedule> schedule = getInitialSchedule(loop)) {
propagatePartitions(loop, *schedule);
optimizeSchedule(loop, *schedule);
schedule->serialize(loop);
loop->setAttr(
kWarpSpecializeTagAttrName,
IntegerAttr::get(IntegerType::get(loop.getContext(), 32), idx));
}
}
}