#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
namespace mlir {
namespace triton {
namespace gpu {
#define GEN_PASS_DEF_TRITONGPUREORDERINSTRUCTIONS
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
static bool willIncreaseRegisterPressure(Operation *op) {
if (isa<triton::gpu::LocalLoadOp>(op))
return true;
auto cvt = dyn_cast<triton::gpu::ConvertLayoutOp>(op);
if (!cvt)
return false;
if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
cvt.getType().getEncoding()))
return true;
return false;
}
class TritonGPUReorderInstructionsPass
: public impl::TritonGPUReorderInstructionsBase<
TritonGPUReorderInstructionsPass> {
public:
TritonGPUReorderInstructionsPass() = default;
Operation *getFirstUse(Operation *op) {
std::vector<Operation *> users;
for (auto user : op->getUsers()) {
if (Operation *ancestor = op->getBlock()->findAncestorOpInBlock(*user))
users.push_back(ancestor);
}
auto minOpIt =
llvm::min_element(users, [](mlir::Operation *a, mlir::Operation *b) {
return a->isBeforeInBlock(b);
});
return minOpIt != users.end() ? *minOpIt : nullptr;
}
void runOnOperation() override {
ModuleOp m = getOperation();
mlir::DominanceInfo dom(m);
m.walk([&](triton::gpu::ConvertLayoutOp op) {
auto curr = mlir::Block::iterator(op);
auto end = op->getBlock()->end();
for (; curr != end && &*curr != getFirstUse(op); curr++)
if (isa<triton::gpu::LocalDeallocOp>(&*curr))
op->moveAfter(&*curr);
});
DenseMap<Operation *, Operation *> opToMove;
auto moveAfter = [](Operation *lhs, Operation *rhs) {
lhs->moveAfter(rhs);
};
m.walk([&](Operation *op) {
if (!willIncreaseRegisterPressure(op))
return;
auto user_begin = op->user_begin();
auto user_end = op->user_end();
if (std::distance(user_begin, user_end) != 1)
return;
if (user_begin->getParentOfType<scf::ForOp>() ==
op->getParentOfType<scf::ForOp>())
return;
opToMove.insert({op, *user_begin});
});
for (auto &kv : opToMove)
kv.first->moveBefore(kv.second);
m.walk([&](triton::gpu::LocalAllocOp op) {
if (!op.getSrc())
return;
Operation *argOp = op.getSrc().getDefiningOp();
if (!argOp)
return;
moveAfter(op, argOp);
});
opToMove.clear();
m.walk([&](triton::TransposeOpInterface op) {
Operation *argOp = op.getSrc().getDefiningOp();
if (!argOp)
return;
moveAfter(op, argOp);
});
m.walk([&](triton::gpu::LocalLoadOp op) {
auto dstEncoding = mlir::dyn_cast<triton::gpu::DotOperandEncodingAttr>(
op.getType().getEncoding());
if (!dstEncoding)
return;
int opIdx = dstEncoding.getOpIdx();
if (opIdx != 1)
return;
if (!op->hasOneUse())
return;
auto dotUser = dyn_cast<triton::DotOp>(*op->user_begin());
if (!dotUser)
return;
auto AOp =
dotUser.getOperand(0).getDefiningOp<triton::gpu::LocalLoadOp>();
if (!AOp)
return;
if (!dom.dominates(op.getOperation(), AOp.getOperation()))
return;
moveAfter(op, AOp);
});
return;
}
};
}
}
}