#include "triton/Analysis/Membar.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include <deque>
namespace mlir {
/// Runs the analysis on the operation that `allocation` was computed for.
///
/// Creates an `OpBuilder` on that operation's context and hands both to
/// `resolve()`, which performs the actual work.
///
/// \param funcBlockInfoMap Map of per-function `BlockInfo`; passed through to
///        `resolve()`, which stores this function's result in it.
void MembarOrFenceAnalysis::run(FuncBlockInfoMapT &funcBlockInfoMap) {
  // The interface is used unconditionally below, so use `cast` instead of
  // `dyn_cast`: an operation of the wrong kind is a programming error that
  // should trip the cast assertion rather than dereference a null interface.
  auto funcOp = cast<FunctionOpInterface>(allocation->getOperation());
  OpBuilder builder(funcOp.getContext());
  resolve(funcOp, &funcBlockInfoMap, &builder);
}
/// Fixed-point dataflow over the "virtual blocks" of `funcOp`.
///
/// A virtual block is a `(Block *, Block::iterator)` pair. It covers the
/// operations of the block that come *after* the iterator (a null iterator
/// means "from the beginning of the block") up to and including the first
/// terminator or `RegionBranchOpInterface` operation, which ends the scan.
/// Every other operation in the range is passed to `update()`.
///
/// \param funcOp           Function being analyzed.
/// \param funcBlockInfoMap Per-function results; forwarded to `update()` and
///                         updated with the state reaching `funcOp`'s returns.
/// \param builder          Builder forwarded to `update()`.
void MembarOrFenceAnalysis::resolve(FunctionOpInterface funcOp,
                                    FuncBlockInfoMapT *funcBlockInfoMap,
                                    OpBuilder *builder) {
  DenseMap<VirtualBlock, BlockInfo> inputBlockInfoMap;
  DenseMap<VirtualBlock, BlockInfo> outputBlockInfoMap;
  std::deque<VirtualBlock> blockList;

  // Seed the worklist with every entry block whose parent op is not a
  // `RegionBranchOpInterface`; entry blocks of region-branch ops are reached
  // through `visitTerminator()` instead.
  funcOp.walk<WalkOrder::PreOrder>([&](Block *block) {
    if (block->isEntryBlock() &&
        !isa<RegionBranchOpInterface>(block->getParentOp()))
      blockList.emplace_back(block, Block::iterator());
  });

  while (!blockList.empty()) {
    VirtualBlock block = blockList.front();
    blockList.pop_front();

    // Work on a copy: the stored input state must stay untouched so later
    // joins from other predecessors accumulate onto it.
    BlockInfo inputBlockInfo = inputBlockInfoMap[block];
    SmallVector<VirtualBlock> successors;
    Block::iterator startIt =
        block.second.isValid() ? std::next(block.second) : block.first->begin();
    for (Operation &op : llvm::make_range(startIt, block.first->end())) {
      if (op.hasTrait<OpTrait::IsTerminator>() ||
          isa<RegionBranchOpInterface>(op)) {
        visitTerminator(&op, successors);
        break;
      }
      update(&op, &inputBlockInfo, funcBlockInfoMap, builder);
    }

    // If this virtual block was processed before and produced the same output
    // state, its successors do not need to be revisited.
    auto previousIt = outputBlockInfoMap.find(block);
    if (previousIt != outputBlockInfoMap.end() &&
        inputBlockInfo == previousIt->second)
      continue;

    // Overwrite (not join) the recorded output, then propagate it.
    outputBlockInfoMap[block] = inputBlockInfo;
    for (const VirtualBlock &successor : successors) {
      inputBlockInfoMap[successor].join(inputBlockInfo);
      blockList.emplace_back(successor);
    }
  }

  // Fold the state that reaches each return into the function's summary.
  BlockInfo &funcBlockInfo = (*funcBlockInfoMap)[funcOp];
  funcOp.walk<WalkOrder::PreOrder>([&](triton::ReturnOp returnOp) {
    // Collect every virtual block carved out of the basic block holding the
    // return. Pointers are stored instead of copies; `outputBlockInfoMap` is
    // not modified while they are alive.
    SmallVector<std::pair<VirtualBlock, const BlockInfo *>> virtualBlocks;
    for (auto &[block, blockInfo] : outputBlockInfoMap) {
      if (block.first == returnOp->getBlock())
        virtualBlocks.emplace_back(block, &blockInfo);
    }
    // The return's block was never visited by the dataflow above (e.g. it is
    // not reachable from any seeded entry block): there is no state to merge,
    // and dereferencing the result of `max_element` would be invalid.
    if (virtualBlocks.empty())
      return;

    // Pick the virtual block with the latest start position: a null start
    // iterator sorts first, otherwise positions are compared within the block.
    auto maxIt = llvm::max_element(virtualBlocks, [&](auto &lhs, auto &rhs) {
      assert(lhs.first.first == rhs.first.first);
      Block::iterator lhsIt = lhs.first.second, rhsIt = rhs.first.second;
      return !lhsIt.isValid() ||
             (rhsIt.isValid() && lhsIt->isBeforeInBlock(&*rhsIt));
    });
    funcBlockInfo.join(*maxIt->second);
  });
}
/// Appends to `successors` the virtual blocks that control may reach from `op`.
///
/// Handles, in order:
///  - `BranchOpInterface`: each successor block, from its beginning;
///  - `RegionBranchOpInterface`: the entry block of each successor region, and
///    the position right after `op` when the parent is a successor;
///  - `RegionBranchTerminatorOpInterface` nested in a `RegionBranchOpInterface`:
///    the entry block of each successor region, and the position right after
///    the parent op when the parent is a successor;
///  - `ReturnLike` ops: no successors.
/// Anything else hits `llvm_unreachable`.
///
/// \param op         Terminator or region-branch operation ending a virtual
///                   block.
/// \param successors Output list; entries are appended, never removed.
void MembarOrFenceAnalysis::visitTerminator(
    Operation *op, SmallVector<VirtualBlock> &successors) {
  if (isa<BranchOpInterface>(op)) {
    for (Block *successor : op->getSuccessors())
      successors.emplace_back(successor, Block::iterator());
    return;
  }

  // Translates region successors into virtual blocks. A "parent" successor
  // resumes right after `regionOp`; any other successor starts at the entry
  // block of the target region (null iterator == beginning of the block).
  auto appendRegionSuccessors = [&successors](
                                    SmallVector<RegionSuccessor> &regions,
                                    Operation *regionOp) {
    for (RegionSuccessor &region : regions) {
      if (region.isParent()) {
        successors.emplace_back(regionOp->getBlock(), regionOp->getIterator());
      } else {
        Block &entry = region.getSuccessor()->front();
        successors.emplace_back(&entry, Block::iterator());
      }
    }
  };

  if (auto br = dyn_cast<RegionBranchOpInterface>(op)) {
    SmallVector<RegionSuccessor> regions;
    br.getSuccessorRegions(RegionBranchPoint::parent(), regions);
    appendRegionSuccessors(regions, op);
    return;
  }

  // Only treat `op` as a region terminator when its parent really is a
  // `RegionBranchOpInterface`; otherwise fall through to the return-like case.
  auto br = dyn_cast<RegionBranchTerminatorOpInterface>(op);
  if (br && isa<RegionBranchOpInterface>(br->getParentOp())) {
    // Null attributes: no constant operand values are supplied to the query.
    SmallVector<Attribute> operands(br->getNumOperands());
    SmallVector<RegionSuccessor> regions;
    br.getSuccessorRegions(operands, regions);
    appendRegionSuccessors(regions, br->getParentOp());
    return;
  }

  if (op->hasTrait<OpTrait::ReturnLike>())
    return;
  llvm_unreachable("Unknown terminator encountered in membar analysis");
}
/// Creates a `gpu::BarrierOp` at `builder`'s current insertion point.
///
/// The caller is expected to have positioned `builder` beforehand. The
/// insertion guard restores the builder's insertion point when this function
/// returns.
///
/// \param op      Operation whose location is attached to the new barrier.
/// \param builder Builder used to create the barrier.
void MembarAnalysis::insertBarrier(Operation *op, OpBuilder *builder) {
  OpBuilder::InsertionGuard guard(*builder);
  builder->create<gpu::BarrierOp>(op->getLoc());
}
/// Transfer function for a single operation.
///
/// Computes the buffer intervals read/written by `op`, inserts a
/// `gpu::BarrierOp` when they intersect the pending accesses in `blockInfo`
/// (as decided by `BlockInfo::isIntersected` with this analysis' `filter`),
/// and then merges `op`'s accesses into `blockInfo`.
///
/// \param op               Operation being visited.
/// \param blockInfo        Pending, not-yet-synchronized accesses; updated in
///                         place.
/// \param funcBlockInfoMap Per-function summaries, consulted for call ops.
/// \param builder          Builder used to insert barriers.
void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
                            FuncBlockInfoMapT *funcBlockInfoMap,
                            OpBuilder *builder) {
  // An existing barrier synchronizes everything that is pending.
  if (isa<gpu::BarrierOp>(op)) {
    blockInfo->sync();
    return;
  }

  // A wait op that is not immediately followed by a barrier gets one inserted
  // right after it. `getNextNode()` returns null when `op` is the last
  // operation of its block, so use the null-tolerant check.
  if (isa<triton::gpu::AsyncWaitOp, triton::nvidia_gpu::TMAStoreWaitOp>(op) &&
      !llvm::isa_and_present<gpu::BarrierOp>(op->getNextNode())) {
    builder->setInsertionPointAfter(op);
    insertBarrier(op, builder);
    blockInfo->sync();
    return;
  }

  BlockInfo curBlockInfo;
  auto scratchBufferId = Allocation::InvalidBufferId;
  if (isa<triton::CallOp>(op)) {
    // Calls: reuse the summary already recorded for the callee, if the callee
    // can be resolved to a function. `resolveCallable()` may yield null, which
    // `dyn_cast_if_present` maps to "no callee" instead of asserting.
    auto callOpInterface = cast<CallOpInterface>(op);
    if (auto callee = llvm::dyn_cast_if_present<FunctionOpInterface>(
            callOpInterface.resolveCallable()))
      curBlockInfo = funcBlockInfoMap->lookup(callee);
  } else {
    // Non-call ops: derive accesses from the declared memory effects.
    if (auto memoryEffectOpInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
      SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>>
          effectInstances;
      memoryEffectOpInterface.getEffects(effectInstances);
      for (const auto &effectInstance : effectInstances) {
        auto value = effectInstance.getValue();
        if (!value)
          continue;
        for (auto bufferId : allocation->getBufferIds(value)) {
          if (bufferId == Allocation::InvalidBufferId)
            continue;
          // Only write and read effects are tracked; other effect kinds are
          // ignored.
          if (isa<MemoryEffects::Write>(effectInstance.getEffect()))
            curBlockInfo
                .syncWriteIntervals[allocation->getAllocatedInterval(bufferId)]
                .insert(op);
          else if (isa<MemoryEffects::Read>(effectInstance.getEffect()))
            curBlockInfo
                .syncReadIntervals[allocation->getAllocatedInterval(bufferId)]
                .insert(op);
        }
      }
    }
    // An arrive-barrier op is recorded as reading and writing the interval
    // [0, SIZE_MAX) — presumably so that it conflicts with any pending access.
    if (isa<triton::nvidia_gpu::ArriveBarrierOp>(op)) {
      Interval<size_t> allIntervals(0, std::numeric_limits<size_t>::max());
      curBlockInfo.syncWriteIntervals[allIntervals].insert(op);
      curBlockInfo.syncReadIntervals[allIntervals].insert(op);
    }
    scratchBufferId = allocation->getBufferId(op);
  }

  if (scratchBufferId != Allocation::InvalidBufferId) {
    // `op` owns a scratch buffer.
    // NOTE(review): `isCvtWarpSync` presumably reports that this layout
    // conversion only needs warp-level synchronization — confirm against its
    // definition.
    bool isWarpSync = false;
    if (auto cvt = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
      auto srcTy = cast<RankedTensorType>(cvt.getSrc().getType());
      auto dstTy = cast<RankedTensorType>(cvt.getType());
      auto srcLayout = triton::gpu::toLinearLayout(srcTy);
      auto dstLayout = triton::gpu::toLinearLayout(dstTy);
      isWarpSync = mlir::isCvtWarpSync(srcLayout, dstLayout);
    }

    // An op with a scratch buffer must not also have reported explicit
    // buffer accesses above.
    if (!curBlockInfo.syncReadIntervals.empty() ||
        !curBlockInfo.syncWriteIntervals.empty()) {
      llvm::report_fatal_error(
          "scratch buffer operations should not have any shared memory "
          "dependencies");
    }

    // Record the scratch interval as a write for the intersection test...
    auto interval = allocation->getAllocatedInterval(scratchBufferId);
    curBlockInfo.syncWriteIntervals[interval].insert(op);
    auto insertCTABarrier = blockInfo->isIntersected(curBlockInfo, filter);
    if (insertCTABarrier) {
      builder->setInsertionPoint(op);
      insertBarrier(op, builder);
    }
    // Pending accesses are cleared when a barrier was inserted, or when the
    // op is not a warp-sync conversion; otherwise they are kept.
    if (insertCTABarrier || !isWarpSync)
      blockInfo->sync();
    // ...and additionally as a read before merging into the pending state.
    curBlockInfo.syncReadIntervals[interval].insert(op);
  } else if (blockInfo->isIntersected(curBlockInfo, filter)) {
    builder->setInsertionPoint(op);
    insertBarrier(op, builder);
    blockInfo->sync();
  }

  // Even when a barrier was inserted, `op`'s own accesses become pending.
  blockInfo->join(curBlockInfo);
}
}