#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/Membar.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
namespace mlir {
namespace triton {
namespace nvidia_gpu {
#define GEN_PASS_DEF_TRITONGPUPROXYFENCEINSERTION
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
namespace {
// Returns true for the ops this pass classifies as writing shared memory
// through the async proxy: the TMA global->local copy and the TMA gather.
bool isAsyncProxyWrite(Operation *op) {
  const bool isTmaWriteToSmem =
      isa<triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp>(op) ||
      isa<triton::nvidia_gpu::AsyncTMAGatherOp>(op);
  return isTmaWriteToSmem;
}
// Returns the destination value written by one of the async-proxy write ops
// recognised by isAsyncProxyWrite (the op's result), or a null Value for any
// other op.
Value getSmemDest(Operation *op) {
  Value dest;
  if (auto copyOp =
          dyn_cast<triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp>(op))
    dest = copyOp.getResult();
  else if (auto gatherOp = dyn_cast<triton::nvidia_gpu::AsyncTMAGatherOp>(op))
    dest = gatherOp.getResult();
  return dest;
}
// Returns true for the ops this pass classifies as reading shared memory
// through the async proxy: the wgmma / tcgen05 MMA ops and the TMA ops that
// move data out of shared memory (local->global copy, scatter, reduce).
//
// The TMA global->local copy is deliberately NOT listed. It *writes* shared
// memory and is classified by isAsyncProxyWrite. The TMA store
// (AsyncTMACopyLocalToGlobalOp) belongs here alongside scatter/reduce.
// Without it, an earlier TMA store is not filtered by ignoreOpForProxyFence.
// The pass would then insert an unnecessary fence before a later TMA load of
// the same buffer.
bool isAsyncProxyRead(Operation *op) {
  return isa<triton::nvidia_gpu::WarpGroupDotOp,
             triton::nvidia_gpu::TCGen5MMAOp,
             triton::nvidia_gpu::TCGen5MMAScaledOp,
             triton::nvidia_gpu::AsyncTMACopyLocalToGlobalOp,
             triton::nvidia_gpu::AsyncTMAScatterOp,
             triton::nvidia_gpu::AsyncTMAReduceOp>(op);
}
// Returns true for ops the proxy-fence analysis should not treat as a
// conflicting access:
//  - ops that already access shared memory through the async proxy
//    (isAsyncProxyRead / isAsyncProxyWrite);
//  - barrier ops (init / inval / arrive / wait) and TMEMCopyOp.
bool ignoreOpForProxyFence(Operation *op) {
  if (isAsyncProxyRead(op) || isAsyncProxyWrite(op))
    return true;
  return isa<triton::nvidia_gpu::InitBarrierOp,
             triton::nvidia_gpu::InvalBarrierOp,
             triton::nvidia_gpu::ArriveBarrierOp,
             triton::nvidia_gpu::WaitBarrierOp,
             triton::nvidia_gpu::TMEMCopyOp>(op);
}
bool filterFn(Operation *op, Operation *other) {
return ignoreOpForProxyFence(other);
}
class ProxyFenceAnalysis : public MembarOrFenceAnalysis {
public:
ProxyFenceAnalysis() = default;
explicit ProxyFenceAnalysis(Allocation *allocation, MembarFilterFn filter)
: MembarOrFenceAnalysis(allocation, filter) {}
private:
virtual void update(Operation *operation, BlockInfo *blockInfo,
FuncBlockInfoMapT *funcBlockInfoMap,
OpBuilder *builder) override;
void insertFence(Operation *operation, OpBuilder *builder);
};
// Creates a FenceAsyncSharedOp at the builder's current insertion point,
// located at `op`. The insertion point is saved and restored around the
// creation, so the caller's builder state is left unchanged.
void ProxyFenceAnalysis::insertFence(Operation *op, OpBuilder *builder) {
  OpBuilder::InsertionGuard guard(*builder);
  auto loc = op->getLoc();
  // NOTE(review): the `false` argument is presumably the cluster-scope flag
  // of the fence; confirm against the FenceAsyncSharedOp definition.
  builder->create<triton::nvidia_gpu::FenceAsyncSharedOp>(loc, false);
}
// Processes one operation during the MembarOrFenceAnalysis walk.
//
//  * An existing FenceAsyncSharedOp clears all pending accesses.
//  * For a triton::CallOp, the callee's summary from `funcBlockInfoMap` is
//    used as the op's accesses.
//  * Otherwise the op's memory effects are mapped to allocated intervals:
//      - for async-proxy writes (isAsyncProxyWrite), only the effect on the
//        destination value (getSmemDest) is recorded, into `proxyBlockInfo`;
//      - for all other ops, write/read effects go into `curBlockInfo`.
//    A scratch buffer used by the op is recorded as a read.
//  * For async-proxy ops, if `proxyBlockInfo` intersects the pending accesses
//    in `blockInfo` (subject to `filter`), a fence is inserted right before
//    the op and the pending accesses are cleared.
//  * Finally `curBlockInfo` is merged into `blockInfo`.
//
// NOTE(review): `proxyBlockInfo` is only populated for async-proxy writes, so
// async-proxy reads never trigger a fence here. They are recorded in
// `curBlockInfo` like generic reads.
void ProxyFenceAnalysis::update(Operation *op, BlockInfo *blockInfo,
                                FuncBlockInfoMapT *funcBlockInfoMap,
                                OpBuilder *builder) {
  if (isa<triton::nvidia_gpu::FenceAsyncSharedOp>(op)) {
    // A fence already present in the IR: drop every pending access.
    blockInfo->sync();
    return;
  }

  // Accesses made by `op` that are merged into `blockInfo` at the end.
  BlockInfo curBlockInfo;
  // Destination intervals written by `op` through the async proxy.
  BlockInfo proxyBlockInfo;
  auto scratchBufferId = Allocation::InvalidBufferId;

  if (isa<triton::CallOp>(op)) {
    // Inter-function dependency: reuse the callee's summarized accesses.
    auto callOpInterface = dyn_cast<CallOpInterface>(op);
    if (auto callee =
            dyn_cast<FunctionOpInterface>(callOpInterface.resolveCallable()))
      curBlockInfo = funcBlockInfoMap->lookup(callee);
  } else {
    if (auto memoryEffectOpInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
      SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>>
          effectInstances;
      memoryEffectOpInterface.getEffects(effectInstances);

      // Properties of `op` that do not change across its effects.
      const bool isProxyWrite = isAsyncProxyWrite(op);
      const auto smemDest = getSmemDest(op);

      for (const auto &effectInstance : effectInstances) {
        auto value = effectInstance.getValue();
        if (!value)
          continue;
        for (auto bufferId : allocation->getBufferIds(value)) {
          if (bufferId == Allocation::InvalidBufferId)
            continue;
          auto interval = allocation->getAllocatedInterval(bufferId);
          if (isProxyWrite) {
            // Only the destination is tracked for async-proxy writes; effects
            // on any other value of this op are not recorded.
            if (value == smemDest)
              proxyBlockInfo.syncWriteIntervals[interval].insert(op);
          } else if (isa<MemoryEffects::Write>(effectInstance.getEffect())) {
            curBlockInfo.syncWriteIntervals[interval].insert(op);
          } else if (isa<MemoryEffects::Read>(effectInstance.getEffect())) {
            curBlockInfo.syncReadIntervals[interval].insert(op);
          }
        }
      }
    }
    scratchBufferId = allocation->getBufferId(op);
  }

  // A scratch buffer used by the op is recorded as a read.
  if (scratchBufferId != Allocation::InvalidBufferId) {
    auto interval = allocation->getAllocatedInterval(scratchBufferId);
    curBlockInfo.syncReadIntervals[interval].insert(op);
  }

  if (isAsyncProxyWrite(op) || isAsyncProxyRead(op)) {
    // An async-proxy access overlapping a pending (unfiltered) access needs a
    // fence placed directly before this op.
    if (proxyBlockInfo.isIntersected(*blockInfo, filter)) {
      builder->setInsertionPoint(op);
      insertFence(op, builder);
      blockInfo->sync();
    }
  }

  // Keep this op's own accesses pending even if a fence was just inserted.
  blockInfo->join(curBlockInfo);
}
}
// Module pass that runs ProxyFenceAnalysis over every function in the module,
// inserting FenceAsyncSharedOp where needed. Does nothing when
// computeCapability is below 90.
struct ProxyFenceInsertionPass
    : public impl::TritonGPUProxyFenceInsertionBase<ProxyFenceInsertionPass> {
public:
  using impl::TritonGPUProxyFenceInsertionBase<
      ProxyFenceInsertionPass>::TritonGPUProxyFenceInsertionBase;

  void runOnOperation() override {
    // Only targets with compute capability 90 or newer are processed.
    if (computeCapability >= 90) {
      ModuleOp moduleOp = getOperation();
      ModuleAllocation moduleAllocation(moduleOp);
      ModuleMembarOrFenceAnalysis<ProxyFenceAnalysis> fenceAnalysis(
          &moduleAllocation, filterFn);
      fenceAnalysis.run();
    }
  }
};
}
}
}