#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/Support/Debug.h"
namespace ttg = mlir::triton::gpu;
namespace mlir {
namespace triton {
namespace nvidia_gpu {
#define GEN_PASS_DEF_TRITONGPUFENCEINSERTION
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
// Inserts `FenceAsyncSharedOp`s in front of dot operations whose A or B
// operand can be traced back to an op that copies register values into a
// memdesc buffer (`ttg.local_alloc` with a source, or `ttg.local_store`).
// NOTE(review): presumably the fence orders those generic writes before the
// dot's async-proxy reads (PTX `fence.proxy.async`) — confirm.
struct FenceInsertionPass
    : public impl::TritonGPUFenceInsertionBase<FenceInsertionPass> {
public:
  using impl::TritonGPUFenceInsertionBase<
      FenceInsertionPass>::TritonGPUFenceInsertionBase;

  // For each dot op in the module:
  //   1. collect the register-to-shared copies feeding operands A and B;
  //   2. if there are any, create a fence (bCluster = false) right before the
  //      dot;
  //   3. hoist the fence out of every enclosing loop that contains none of
  //      those copies, stopping at the first loop that does;
  //   4. erase the fence if the op just before it is already a fence with the
  //      same bCluster value.
  void runOnOperation() override {
    // Nothing to do below compute capability 90.
    if (computeCapability < 90)
      return;

    ModuleOp mod = getOperation();
    mod.walk([&](DotOpInterface dotOp) {
      SmallVector<Operation *> copiesA = findCopyRegToSharedOps(dotOp.getA());
      SmallVector<Operation *> copiesB = findCopyRegToSharedOps(dotOp.getB());
      if (copiesA.empty() && copiesB.empty())
        return WalkResult::advance();

      OpBuilder builder(dotOp);
      auto fence = builder.create<FenceAsyncSharedOp>(dotOp.getLoc(),
                                                      /*bCluster=*/false);

      // True if any copy feeding either operand is nested inside `loopOp`.
      // (`any_of` on an empty list is false, so no emptiness check is needed.)
      auto loopContainsCopy = [&](LoopLikeOpInterface loopOp) {
        auto isInLoop = [&](Operation *op) { return loopOp->isAncestor(op); };
        return llvm::any_of(copiesA, isInLoop) ||
               llvm::any_of(copiesB, isInLoop);
      };
      // Move the fence outward one loop at a time; it stays inside the first
      // loop that contains one of the copies.
      while (auto loopOp = fence->getParentOfType<LoopLikeOpInterface>()) {
        if (loopContainsCopy(loopOp))
          break;
        loopOp.moveOutOfLoop(fence);
      }

      // Back-to-back fences with the same bCluster value: keep only the first.
      if (auto lastFence =
              dyn_cast_or_null<FenceAsyncSharedOp>(fence->getPrevNode())) {
        if (lastFence.getBCluster() == fence.getBCluster())
          fence.erase();
      }
      return WalkResult::advance();
    });
  }

private:
  // Returns the ops that may copy register values into the buffer behind
  // `operand`, without duplicates. When tracing reaches a block argument it
  // cannot look through, the owning op is returned as a conservative
  // placeholder so the caller still sees a non-empty result.
  SmallVector<Operation *> findCopyRegToSharedOps(Value operand) {
    DenseSet<Value> visited;
    llvm::SetVector<Operation *> result;
    findCopyRegToSharedOps(operand, visited, result);
    return result.takeVector();
  }

  // Adds to `result` every `ttg.local_store` that writes to `root` directly or
  // through any chain of memdesc view ops. Every user of a view is followed
  // (not only single-use chains), so a view that is both stored to and read by
  // a dot is handled, and all stores are recorded rather than just the first.
  void collectLocalStoreWriters(Value root,
                                llvm::SetVector<Operation *> &result) {
    SmallVector<Operation *> worklist;
    for (Operation *user : root.getUsers())
      worklist.push_back(user);

    DenseSet<Operation *> seen;
    while (!worklist.empty()) {
      Operation *user = worklist.pop_back_val();
      if (!seen.insert(user).second)
        continue;
      if (isa<ttg::LocalStoreOp>(user)) {
        result.insert(user);
        continue;
      }
      if (user->hasTrait<OpTrait::MemDescViewTrait>()) {
        for (Operation *viewUser : user->getUsers())
          worklist.push_back(viewUser);
      }
    }
  }

  // Recursive worker: walks backwards from `operand` through defining ops,
  // scf.for loop-carried values and warp_specialize captures, adding every
  // register-to-shared copy it finds to `result`. `visited` guarantees each
  // value is explored at most once, which also terminates recursion through
  // loop-carried cycles.
  void findCopyRegToSharedOps(Value operand, DenseSet<Value> &visited,
                              llvm::SetVector<Operation *> &result) {
    if (!visited.insert(operand).second)
      return;
    // Only memdesc-typed values are traced; anything else ends the search.
    if (!isa<triton::gpu::MemDescType>(operand.getType()))
      return;

    if (Operation *op = operand.getDefiningOp()) {
      if (auto localAlloc = dyn_cast<ttg::LocalAllocOp>(op)) {
        // An alloc initialized from a register value is itself a copy.
        if (localAlloc.getSrc())
          result.insert(op);
        // So is every local_store into the allocated buffer.
        collectLocalStoreWriters(localAlloc.getResult(), result);
      }
      // An scf.for result is either its init arg (reached by the operand walk
      // below) or the value the body yields, so trace the yield as well.
      if (auto forOp = dyn_cast<scf::ForOp>(op)) {
        unsigned resultNum = cast<OpResult>(operand).getResultNumber();
        Operation *yieldOp = forOp.getBody()->getTerminator();
        findCopyRegToSharedOps(yieldOp->getOperand(resultNum), visited,
                               result);
      }
      // NOTE(review): other region-carrying producers (e.g. scf.if results)
      // are only traced through their operands here.
      for (Value v : op->getOperands())
        findCopyRegToSharedOps(v, visited, result);
      return;
    }

    // Block argument: look through the ops whose region arguments are known.
    auto arg = cast<BlockArgument>(operand);
    unsigned argNum = arg.getArgNumber();
    Operation *argOwner = arg.getOwner()->getParentOp();

    // scf.for iter_arg: argument 0 is the induction variable, so iter_arg i is
    // block argument i + 1. Trace both the init value and the value yielded
    // back into the next iteration.
    if (auto forOp = dyn_cast<scf::ForOp>(argOwner)) {
      assert(argNum != 0 && "induction var cannot be memdesc type");
      --argNum;
      findCopyRegToSharedOps(forOp.getInitArgs()[argNum], visited, result);
      Operation *yieldOp = forOp.getBody()->getTerminator();
      findCopyRegToSharedOps(yieldOp->getOperand(argNum), visited, result);
      return;
    }

    // warp_specialize partition argument: trace the matching explicit capture
    // of the parent warp_specialize op.
    if (auto wsOp = dyn_cast<ttg::WarpSpecializePartitionsOp>(argOwner)) {
      findCopyRegToSharedOps(wsOp.getParentOp().getExplicitCaptures()[argNum],
                             visited, result);
      return;
    }

    // Unknown producer: record the owner conservatively so the dot still gets
    // a fence.
    result.insert(argOwner);
  }
};
}
}
}