#include "triton/Dialect/TritonGPU/Transforms/Partition.h"
#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/IR/Use.h"
using namespace mlir;
using namespace triton;
using namespace triton::gpu;
namespace {
struct PartitionNode {
PartitionNode(const Partition *partition) : partition(partition) {}
const Partition *partition;
SmallVector<std::pair<const PartitionNode *, OpOperand *>> consumers;
};
struct PartitionGraph {
PartitionGraph(scf::ForOp loop, const WarpSchedule &schedule);
PartitionNode root;
llvm::MapVector<const Partition *, PartitionNode> nodes;
};
}
// Builds the dependency graph of `schedule` for `loop`.
//
// The root node gets an edge (with a null operand) to every partition node.
// Every other node gets one edge per use of its outputs by an operation that
// belongs to a different partition in the same iteration.
PartitionGraph::PartitionGraph(scf::ForOp loop, const WarpSchedule &schedule)
    : root(schedule.getRootPartition()) {
  // Create all nodes up front. Nothing is inserted into `nodes` after this
  // loop, so the node addresses taken below stay valid.
  for (Partition &partition : schedule.getPartitions())
    nodes.try_emplace(&partition, &partition);

  // Connect the root node to every partition node.
  for (PartitionNode &node : llvm::make_second_range(nodes))
    root.consumers.emplace_back(&node, nullptr);

  // Wire the remaining edges from the users of each partition's outputs.
  for (auto &[partition, node] : nodes) {
    // Structured bindings cannot be captured in C++17, hence the init-capture.
    auto callback = [&, node = &node](Operation *owner, OpOperand &use) {
      // Uses by the loop terminator feed a later iteration; ignore them.
      if (isa<scf::YieldOp>(owner))
        return;
      // `nodes` only contains the partitions from `getPartitions()`. An owner
      // that sits in the root partition, or that has no partition at all, has
      // no node here; skip it rather than dereferencing `nodes.end()`.
      auto consumerIt = nodes.find(schedule.getPartition(owner));
      if (consumerIt == nodes.end())
        return;
      node->consumers.emplace_back(&consumerIt->second, &use);
    };
    schedule.iterateOutputs(loop, partition, callback);
  }
}
namespace llvm {
// Graph traits for PartitionGraph. A node reference is an edge of the graph:
// the target node paired with the operand the edge was created for.
template <> struct GraphTraits<PartitionGraph> {
  using NodeRef = std::pair<const PartitionNode *, mlir::OpOperand *>;
  using ChildIteratorType = SmallVector<NodeRef>::const_iterator;

  // The entry node is the root node, reached through no operand.
  static NodeRef getEntryNode(const PartitionGraph &graph) {
    return NodeRef(&graph.root, nullptr);
  }

  // The children of a node are its consumer edges.
  static ChildIteratorType child_begin(NodeRef node) {
    const PartitionNode *target = node.first;
    return target->consumers.begin();
  }
  static ChildIteratorType child_end(NodeRef node) {
    const PartitionNode *target = node.first;
    return target->consumers.end();
  }
};
} // namespace llvm
// Appends a new partition with the given stage to the schedule and returns a
// pointer to it. The partition's index is the number of partitions that
// existed before the call.
Partition *WarpSchedule::addPartition(unsigned stage) {
  auto owned = std::make_unique<Partition>(partitions.size(), stage);
  Partition *added = owned.get();
  partitions.push_back(std::move(owned));
  return added;
}
// Returns the partition recorded for `op` in `opToPartition`.
// NOTE(review): `lookup` presumably yields nullptr when `op` has no entry
// (`isScheduled` null-checks this result) -- confirm against the map type.
Partition *WarpSchedule::getPartition(Operation *op) {
return opToPartition.lookup(op);
}
// Const overload: returns the partition recorded for `op` in `opToPartition`.
// NOTE(review): `lookup` presumably yields nullptr when `op` has no entry
// (`isScheduled` null-checks this result) -- confirm against the map type.
const Partition *WarpSchedule::getPartition(Operation *op) const {
return opToPartition.lookup(op);
}
// Returns the partition stored at position `idx` of `partitions`. No range
// check on `idx` is performed here.
Partition *WarpSchedule::getPartition(unsigned idx) {
return partitions[idx].get();
}
// Const overload: returns the partition stored at position `idx` of
// `partitions`. No range check on `idx` is performed here.
const Partition *WarpSchedule::getPartition(unsigned idx) const {
return partitions[idx].get();
}
// Appends `op` to the op list of `partition` and records `partition` as the
// owner of `op` in `opToPartition`, overwriting any previous owner.
// NOTE(review): if `op` is already listed in another partition's op list
// (e.g. the root partition), that older entry is not removed here -- confirm
// callers never re-insert an op, or that the stale entry is harmless.
void WarpSchedule::insert(Partition *partition, Operation *op) {
partition->ops.push_back(op);
opToPartition[op] = partition;
}
// Returns true when `op` has been assigned to a partition other than the root
// partition.
bool WarpSchedule::isScheduled(Operation *op) const {
  const Partition *owner = getPartition(op);
  if (!owner)
    return false;
  return owner != getRootPartition();
}
// Inserts `op` into `partition` unless it is already scheduled in a non-root
// partition. Returns whether the insertion took place.
bool WarpSchedule::trySchedule(Partition *partition, Operation *op) {
  const bool canSchedule = !isScheduled(op);
  if (canSchedule)
    insert(partition, op);
  return canSchedule;
}
// Reconstructs a schedule from the attributes attached to `loop`.
//
// Fails silently when the loop carries no partition stages array or no
// warp-specialize tag. Emits an error and fails when a stage element is not a
// non-negative integer attribute, or when an op names a partition index that
// is out of range. Ops in the loop body (terminator excluded) without a
// partition attribute are placed in the root partition.
FailureOr<WarpSchedule> WarpSchedule::deserialize(scf::ForOp loop) {
  auto stagesAttr = loop->getAttrOfType<ArrayAttr>(kPartitionStagesAttrName);
  if (!stagesAttr)
    return failure();
  auto tagAttr = loop->getAttrOfType<IntegerAttr>(kWarpSpecializeTagAttrName);
  if (!tagAttr)
    return failure();

  WarpSchedule schedule;
  schedule.tag = tagAttr.getInt();

  // One partition per element of the stages array, indexed by position.
  for (auto [position, element] : llvm::enumerate(stagesAttr)) {
    auto stageAttr = dyn_cast<IntegerAttr>(element);
    const bool isValidStage = stageAttr && stageAttr.getInt() >= 0;
    if (!isValidStage) {
      return mlir::emitError(loop.getLoc(), "partition stages attribute '")
             << kPartitionStagesAttrName << "' has invalid element "
             << element;
    }
    schedule.partitions.push_back(
        std::make_unique<Partition>(position, stageAttr.getInt()));
  }

  // Assign every op of the loop body to the partition it names, defaulting to
  // the root partition.
  for (Operation &op : loop.getBody()->without_terminator()) {
    Partition *target = schedule.getRootPartition();
    if (auto indexAttr = op.getAttrOfType<IntegerAttr>(kPartitionAttrName)) {
      const int64_t idx = indexAttr.getInt();
      const bool inRange =
          idx >= 0 && static_cast<uint64_t>(idx) < schedule.partitions.size();
      if (!inRange)
        return mlir::emitError(op.getLoc(), "invalid partition index ") << idx;
      target = schedule.partitions[idx].get();
    }
    schedule.insert(target, &op);
  }

  return schedule;
}
// Writes the schedule onto `loop` as attributes: each op of the loop body
// (terminator excluded) that belongs to a non-root partition is tagged with
// that partition's index, and the loop receives an array with the stage of
// every partition. Ops in the root partition, or without a partition, are left
// untouched.
void WarpSchedule::serialize(scf::ForOp loop) const {
  Builder builder(loop.getContext());

  for (Operation &op : loop.getBody()->without_terminator()) {
    Partition *owner = opToPartition.lookup(&op);
    if (!owner || owner == getRootPartition())
      continue;
    op.setAttr(kPartitionAttrName,
               builder.getI32IntegerAttr(owner->getIndex()));
  }

  SmallVector<Attribute> stageAttrs;
  for (Partition &partition : getPartitions())
    stageAttrs.push_back(builder.getI32IntegerAttr(partition.getStage()));
  loop->setAttr(kPartitionStagesAttrName, builder.getArrayAttr(stageAttrs));
}
// Checks that no input of the root partition is defined, inside `loop`, by an
// operation that belongs to a non-root partition. A warning with a note at the
// definition is emitted for every offending operand, and the function returns
// failure if any was found.
LogicalResult WarpSchedule::verify(scf::ForOp loop) const {
  bool sawViolation = false;

  auto checkInput = [&](OpOperand &input) {
    auto [def, distance] = getDefiningOpAndDistance(loop, input.get());
    // Only definitions that are direct children of the loop are of interest.
    const bool definedInLoop = def && def->getParentOp() == loop;
    if (!definedInLoop)
      return;

    const Partition *defPartition = opToPartition.at(def);
    if (defPartition == getRootPartition())
      return;

    InFlightDiagnostic diag = mlir::emitWarning(input.getOwner()->getLoc());
    diag << "operation in the root partition depends on a value that "
            "originates from a non-root partition through operand #"
         << input.getOperandNumber();
    diag.attachNote(def->getLoc())
        << "operand defined here in partition #" << defPartition->getIndex()
        << " at distance " << distance;
    sawViolation = true;
  };
  iterateInputs(loop, getRootPartition(), checkInput);

  return sawViolation ? failure() : success();
}
// Removes the schedule attributes from `loop`: the partition attribute is
// removed from every operation visited by `loop.walk`, and the partition
// stages attribute is removed from the loop itself. The warp-specialize tag
// attribute is not touched here.
void WarpSchedule::eraseFrom(scf::ForOp loop) {
loop.walk([&](Operation *op) { op->removeAttr(kPartitionAttrName); });
loop->removeAttr(kPartitionStagesAttrName);
}
// Invokes `callback` on every operand, including operands of nested ops, of
// the ops in `partition` whose value comes from outside the partition: either
// a loop-carried block argument of `loop`, or the result of an op in the loop
// body that belongs to a different partition. Values that do not live directly
// in the loop body block, and the induction variable, are skipped.
void WarpSchedule::iterateInputs(
    scf::ForOp loop, const Partition *partition,
    function_ref<void(OpOperand &)> callback) const {
  auto visitOperand = [&](OpOperand &operand) {
    Value value = operand.get();
    // Skip values whose parent block is not the loop body.
    if (value.getParentBlock() != loop.getBody())
      return;

    auto arg = dyn_cast<BlockArgument>(value);
    if (!arg) {
      // An op result: report it only if produced by another partition.
      if (getPartition(value.getDefiningOp()) == partition)
        return;
      assert(value.getDefiningOp()->getParentOp() == loop);
      callback(operand);
      return;
    }

    assert(arg.getOwner() == loop.getBody());
    // The induction variable is not an input.
    if (arg == loop.getInductionVar())
      return;
    // Any other block argument of the body is a loop-carried value.
    assert(llvm::is_contained(loop.getRegionIterArgs(), arg));
    callback(operand);
  };

  for (Operation *op : partition->getOps())
    visitNestedOperands(op, visitOperand);
}
// Invokes `callback` on every use of a result of an op in `partition` that
// leaves the partition. The first callback argument is the ancestor of the
// using op that sits directly in the loop body. A use is reported when that
// ancestor is the loop's yield, or when it belongs to a different partition.
void WarpSchedule::iterateOutputs(
    scf::ForOp loop, const Partition *partition,
    function_ref<void(Operation *, OpOperand &)> callback) const {
  for (Operation *producer : partition->getOps()) {
    for (OpOperand &use : producer->getUses()) {
      Operation *owner =
          loop.getBody()->findAncestorOpInBlock(*use.getOwner());
      // The short-circuit keeps the partition lookup off the yield op.
      const bool leavesPartition =
          isa<scf::YieldOp>(owner) || getPartition(owner) != partition;
      if (leavesPartition)
        callback(owner, use);
    }
  }
}
// For every input of `partition` (see `iterateInputs`), resolves the defining
// result and its distance via `getDefinitionAndDistance`, and forwards both to
// `callback` when the definition lives directly in the loop body block.
void WarpSchedule::iterateDefs(
    scf::ForOp loop, const Partition *partition,
    function_ref<void(OpResult, unsigned)> callback) const {
  auto forwardDef = [&](OpOperand &input) {
    auto [def, distance] = getDefinitionAndDistance(loop, input.get());
    if (!def || def.getParentBlock() != loop.getBody())
      return;
    callback(def, distance);
  };
  iterateInputs(loop, partition, forwardDef);
}
// Invokes `callback` on every non-yield use of the outputs of `partition`,
// following values through the loop's yield into the matching iteration
// argument. The third callback argument counts how many times the value was
// carried through the yield before reaching the reported use.
void WarpSchedule::iterateUses(
    scf::ForOp loop, const Partition *partition,
    function_ref<void(OpResult, OpOperand &, unsigned)> callback) const {
  using PendingUse = std::tuple<OpResult, OpOperand *, unsigned>;
  SmallVector<PendingUse> worklist;

  // Seed the worklist with the direct outputs of the partition.
  iterateOutputs(loop, partition, [&](Operation *owner, OpOperand &use) {
    worklist.emplace_back(cast<OpResult>(use.get()), &use, 0);
  });

  while (!worklist.empty()) {
    auto [result, operand, distance] = worklist.pop_back_val();
    Operation *owner =
        loop.getBody()->findAncestorOpInBlock(*operand->getOwner());

    if (isa<scf::YieldOp>(owner)) {
      // The value is carried into the next iteration: continue with the uses
      // of the iteration argument fed by this yield operand.
      BlockArgument carried =
          loop.getRegionIterArg(operand->getOperandNumber());
      for (OpOperand &carriedUse : carried.getUses())
        worklist.emplace_back(result, &carriedUse, distance + 1);
      continue;
    }

    callback(result, *operand, distance);
  }
}
void WarpSchedule::dump() const {
for (auto [i, partition] :
llvm::enumerate(llvm::make_pointee_range(partitions))) {
llvm::errs() << "=== PARTITION #" << i << " ===\n";
for (Operation *op : partition.getOps()) {
op->print(llvm::errs(), OpPrintingFlags().skipRegions());
llvm::errs() << "\n";
}
llvm::errs() << "\n";
}
llvm::errs() << "=== ROOT PARTITION ===\n";
for (Operation *op : getRootPartition()->getOps()) {
op->print(llvm::errs(), OpPrintingFlags().skipRegions());
llvm::errs() << "\n";
}
llvm::errs() << "\n";
}