#include "TritonAMDGPUTransforms/Passes.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Pass/PassManager.h"
#include "third_party/amd/include/Dialect/TritonAMDGPU/Utility/CommonUtils.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "llvm/ADT/STLExtras.h"
namespace ttg = mlir::triton::gpu;
namespace mlir {
#define GEN_PASS_DEF_TRITONAMDGPUREORDERINSTRUCTIONS
#include "TritonAMDGPUTransforms/Passes.h.inc"
namespace {
// Returns true when `funcOp` contains at least one scf.for and every scf.for
// found by the walk (nested loops included) encloses exactly one tt.dot.
// The dot count of a loop is itself computed with a recursive walk, so dots
// nested inside inner regions of the loop are counted as well.
static bool isPureMatmulFunc(triton::FuncOp funcOp) {
  bool sawLoop = false;
  bool everyLoopHasSingleDot = true;
  funcOp.walk([&](scf::ForOp loop) -> void {
    sawLoop = true;
    int numDots = 0;
    loop.walk([&numDots](triton::DotOp) { ++numDots; });
    if (numDots != 1)
      everyLoopHasSingleDot = false;
  });
  return sawLoop && everyLoopHasSingleDot;
}
// Returns true when the loop (including everything nested inside it) holds
// exactly one tt.dot and at least two tt.load operations.
static bool isPureMatmulLoop(scf::ForOp forOp) {
  int numDots = 0;
  forOp.walk([&numDots](triton::DotOp) { ++numDots; });

  int numLoads = 0;
  forOp.walk([&numLoads](triton::LoadOp) { ++numLoads; });

  return numDots == 1 && numLoads >= 2;
}
// Scans `block` from its first operation up to (but not including) `move` and
// returns an iterator to the LAST operation that acts as an anchor, i.e. one
// that either
//   - produces the pointer operand of `move`, or
//   - is an atomic RMW/CAS, a gpu.barrier, an scf.for or an scf.while.
// Only the top-level operations of `block` are inspected. When no anchor is
// found the function returns `block->end()`.
static llvm::ilist<Operation>::iterator
findEarlyInsertionPoint(Block *block, triton::LoadOp move) {
  Value ptr = move.getPtr();
  Operation *moveOp = move.getOperation();

  auto anchor = block->end();
  for (auto it = block->begin(), last = block->end(); it != last; ++it) {
    Operation *candidate = &*it;
    // Never look past the operation that is being moved.
    if (candidate == moveOp)
      break;

    const bool definesPtr = llvm::is_contained(candidate->getResults(), ptr);
    const bool isFence =
        isa<triton::AtomicRMWOp, triton::AtomicCASOp, gpu::BarrierOp,
            scf::ForOp, scf::WhileOp>(candidate);
    if (definesPtr || isFence)
      anchor = it;
  }
  return anchor;
}
// Returns the earliest operation in `op`'s own block that is, or encloses, a
// user of `op`. Users living in nested regions are represented by their
// ancestor in `op`'s block; users with no such ancestor are ignored.
// Returns nullptr when there is no such operation.
static Operation *getFirstUseInSameBlock(Operation *op) {
  Block *homeBlock = op->getBlock();
  Operation *earliest = nullptr;
  for (Operation *user : op->getUsers()) {
    Operation *inBlock = homeBlock->findAncestorOpInBlock(*user);
    if (!inBlock)
      continue;
    // Keep the running minimum with respect to program order in the block.
    if (!earliest || inBlock->isBeforeInBlock(earliest))
      earliest = inBlock;
  }
  return earliest;
}
// Returns true when `opInsideLoop` has an enclosing scf.for and that loop is
// not an ancestor of `opOutsideLoop`. Returns false when `opInsideLoop` is not
// inside any scf.for.
static bool isCrossLoopBoundary(mlir::Operation *opInsideLoop,
                                mlir::Operation *opOutsideLoop) {
  auto enclosingLoop = opInsideLoop->getParentOfType<scf::ForOp>();
  if (!enclosingLoop)
    return false;
  return !enclosingLoop->isAncestor(opOutsideLoop);
}
// Moves every ttg.convert_layout that
//   - produces a dot-operand encoded value,
//   - has exactly one use, and
//   - whose single user has a different closest enclosing scf.for than the
//     conversion itself,
// to the position right before that user.
//
// Candidates are collected first and moved afterwards so the IR is not
// mutated while it is being walked. They are kept in a vector (walk order)
// rather than in a pointer-keyed DenseMap: iterating such a map depends on
// pointer values, which made the relative order of two conversions feeding
// the same user vary from run to run.
static void sinkDotConversion(triton::FuncOp funcOp) {
  SmallVector<ttg::ConvertLayoutOp> toSink;
  funcOp.walk([&](ttg::ConvertLayoutOp op) {
    Attribute encoding = op.getType().getEncoding();
    if (!isa_and_nonnull<ttg::DotOperandEncodingAttr>(encoding))
      return;
    if (!op->hasOneUse())
      return;
    Operation *user = *op->getUsers().begin();
    if (user->getParentOfType<scf::ForOp>() ==
        op->getParentOfType<scf::ForOp>())
      return;
    toSink.push_back(op);
  });

  for (ttg::ConvertLayoutOp op : toSink) {
    // hasOneUse() was checked above and moving operations does not alter
    // use lists, so the single user is still the one seen during the walk.
    Operation *user = *op->getUsers().begin();
    op->moveBefore(user);
  }
}
// For every ttg.convert_layout in the function: scan forward from the
// conversion until either the end of its block or its first use in that block
// is reached, and each time a ttg.local_dealloc is encountered move the
// conversion to sit right after it. The conversion therefore ends up after
// the last dealloc that precedes its first use.
static void moveDownCoversion(triton::FuncOp funcOp) {
  SmallVector<ttg::ConvertLayoutOp> conversions;
  funcOp.walk(
      [&](ttg::ConvertLayoutOp cvt) { conversions.push_back(cvt); });

  for (ttg::ConvertLayoutOp cvt : conversions) {
    Operation *firstUse = getFirstUseInSameBlock(cvt);
    auto cursor = Block::iterator(cvt);
    const auto blockEnd = cvt->getBlock()->end();
    while (cursor != blockEnd && &*cursor != firstUse) {
      Operation *current = &*cursor;
      // After the move the conversion is the successor of `current`, so the
      // scan simply continues from the conversion's new position.
      if (isa<ttg::LocalDeallocOp>(current))
        cvt->moveAfter(current);
      ++cursor;
    }
  }
}
// Moves every operation implementing TransposeOpInterface to the position
// right after the operation that defines its source. Transposes whose source
// has no defining operation (getDefiningOp() returns null) are left in place.
static void moveUpTranspose(triton::FuncOp funcOp) {
  SmallVector<triton::TransposeOpInterface> transposes;
  funcOp.walk([&transposes](triton::TransposeOpInterface trans) {
    transposes.push_back(trans);
  });

  for (triton::TransposeOpInterface trans : transposes) {
    Operation *producer = trans.getSrc().getDefiningOp();
    if (!producer)
      continue;
    trans->moveAfter(producer);
  }
}
// Hoists tt.load operations that sit directly in the function body and carry
// the "amd.pipeliner_part" attribute, together with the part of their
// backward slice that lives in the same block, as early as possible.
//
// Loads are processed last-to-first. For each one the destination is the
// anchor returned by findEarlyInsertionPoint(): the slice is placed right
// after the anchor, or at the very beginning of the block when there is no
// anchor.
static void moveUpGlobalLoadInPrologue(triton::FuncOp funcOp) {
  SmallVector<triton::LoadOp> taggedLoads;
  for (triton::LoadOp load : funcOp.getBody().getOps<triton::LoadOp>())
    if (load->getAttr("amd.pipeliner_part"))
      taggedLoads.push_back(load);

  for (triton::LoadOp load : llvm::reverse(taggedLoads)) {
    Block *homeBlock = load->getBlock();

    // Collect the definitions the load depends on, restricted to operations
    // that are direct children of the load's block.
    BackwardSliceOptions sliceOptions;
    sliceOptions.omitBlockArguments = true;
    sliceOptions.inclusive = false;
    sliceOptions.omitUsesFromAbove = false;
    sliceOptions.filter = [&](Operation *def) -> bool {
      if (!homeBlock->findAncestorOpInBlock(*def))
        return false;
      return def->getBlock() == homeBlock;
    };
    SetVector<Operation *> slice;
    (void)mlir::getBackwardSlice(load.getOperation(), &slice, sliceOptions);
    slice.insert(load.getOperation());

    auto anchor = findEarlyInsertionPoint(homeBlock, load);
    SmallVector<Operation *> chain = slice.takeVector();

    if (anchor == homeBlock->end()) {
      // No anchor: stack the chain at the top of the block. Inserting in
      // reverse at the (changing) block begin keeps the chain's order.
      for (Operation *member : llvm::reverse(chain))
        member->moveBefore(homeBlock, homeBlock->begin());
      continue;
    }

    // Keep only members strictly after the anchor; the others already sit at
    // or before the destination. This filtering is done before any move.
    llvm::erase_if(chain, [&](Operation *member) {
      return !anchor->isBeforeInBlock(member);
    });
    // Inserting in reverse right after the anchor keeps the chain's order.
    for (Operation *member : llvm::reverse(chain))
      member->moveAfter(homeBlock, anchor);
  }
}
// If the loop body directly contains exactly two tt.load operations and a
// tt.dot, moves the second load down to right before the dot, provided that
//   - both loads produce rank-2 ranked tensors,
//   - the first load's shape is at least 128 x 64 and the second load's
//     second dimension is at least 128, and
//   - the second load currently precedes the dot while its earliest use in
//     the loop body comes after the dot.
// Only operations that are immediate children of the loop body are
// considered. When several dots are present the last one is used.
static void sinkSecondLoad(scf::ForOp forOp) {
  SetVector<triton::LoadOp> loadOps;
  triton::DotOp dotOp;
  for (Operation &op : forOp) {
    if (auto loadOp = dyn_cast<triton::LoadOp>(&op))
      loadOps.insert(loadOp);
    if (auto curOp = dyn_cast<triton::DotOp>(&op))
      dotOp = curOp;
  }
  // Callers count dots with a recursive walk, so the single dot may be nested
  // in an inner region and not be found by the scan above. Without a dot in
  // this block there is nothing to sink the load to.
  if (!dotOp)
    return;
  if (loadOps.size() != 2)
    return;

  auto ldAOp = loadOps[0];
  auto loadAType = dyn_cast<RankedTensorType>(ldAOp.getType());
  auto ldBOp = loadOps[1];
  auto loadBType = dyn_cast<RankedTensorType>(ldBOp.getType());
  if (!loadAType || !loadBType)
    return;

  auto tileAShape = loadAType.getShape();
  auto tileBShape = loadBType.getShape();
  if (tileAShape.size() != 2 || tileBShape.size() != 2)
    return;
  // Tile-size thresholds below which the transformation is not applied.
  if (!(tileAShape[0] >= 128 && tileAShape[1] >= 64 && tileBShape[1] >= 128))
    return;

  // Use the earliest user in program order, mapped to its ancestor in the
  // loop body. The head of the use list is not guaranteed to be the earliest
  // user and may live in a nested block, and the list may be empty.
  Operation *firstUser = getFirstUseInSameBlock(ldBOp);
  if (!firstUser)
    return;

  bool isBeforeDotOp = ldBOp->isBeforeInBlock(dotOp);
  bool firstUserAfterDotOp = dotOp->isBeforeInBlock(firstUser);
  if (isBeforeDotOp && firstUserAfterDotOp)
    ldBOp->moveBefore(dotOp);
}
}
// Pass driver. For every tt.func in the module it runs, in this order:
// sinkDotConversion, moveDownCoversion, moveUpTranspose and
// moveUpGlobalLoadInPrologue. Afterwards sinkSecondLoad is applied either to
// every scf.for of the function (when isPureMatmulFunc holds) or only to the
// leaf loops for which isPureMatmulLoop holds.
struct TritonAMDGPUReorderInstructionsPass
    : public impl::TritonAMDGPUReorderInstructionsBase<
          TritonAMDGPUReorderInstructionsPass> {
  void runOnOperation() override {
    ModuleOp moduleOp = getOperation();
    for (triton::FuncOp func : moduleOp.getOps<triton::FuncOp>()) {
      sinkDotConversion(func);
      moveDownCoversion(func);
      moveUpTranspose(func);
      moveUpGlobalLoadInPrologue(func);

      if (isPureMatmulFunc(func)) {
        func.walk([](scf::ForOp loop) -> void { sinkSecondLoad(loop); });
        continue;
      }

      SmallVector<scf::ForOp> leafLoops = triton::AMD::getLeafForOps(func);
      for (scf::ForOp loop : leafLoops)
        if (isPureMatmulLoop(loop))
          sinkSecondLoad(loop);
    }
  }
};
}