#include "triton/Analysis/Membar.h"
#include "triton/Analysis/Alias.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include <deque>
namespace mlir {
/// Entry point of the membar analysis: runs `resolve` over the function that
/// owns `allocation`. Barriers are inserted into the IR as a side effect, and
/// the function's outstanding shared-memory accesses are recorded in
/// `funcBlockInfoMap` under that function.
void MembarAnalysis::run(FuncBlockInfoMapT &funcBlockInfoMap) {
  // The analysis requires the allocation to be rooted at a function. cast<>
  // asserts that invariant, whereas the previous dyn_cast<> would return a
  // null op that was dereferenced unchecked on the very next line.
  auto funcOp = cast<FunctionOpInterface>(allocation->getOperation());
  OpBuilder builder(funcOp.getContext());
  resolve(funcOp, &funcBlockInfoMap, &builder);
}
/// Worklist-driven fixed-point analysis over the CFG of `funcOp`.
///
/// Each block is replayed through `update` starting from the state joined in
/// from its predecessors, and the result is propagated to its successors. A
/// block's successors are re-queued only when its output state changed.
/// Barriers are inserted through `builder` while ops are visited. Once the
/// analysis settles, the output state of every block that contains a
/// `triton::ReturnOp` is joined into `(*funcBlockInfoMap)[funcOp]`; `update`
/// looks that summary up for call sites.
///
/// Structured control flow is not handled: any op from the `scf` dialect
/// aborts with a fatal error.
void MembarAnalysis::resolve(FunctionOpInterface funcOp,
                             FuncBlockInfoMapT *funcBlockInfoMap,
                             OpBuilder *builder) {
  // Per-block state on entry (joined from predecessors) and on exit.
  DenseMap<Block *, BlockInfo> blockIn;
  DenseMap<Block *, BlockInfo> blockOut;
  std::deque<Block *> worklist;

  // Reject scf ops and seed the worklist with every region entry block.
  funcOp.walk<WalkOrder::PreOrder>([&](Block *block) {
    for (Operation &nested : block->getOperations()) {
      if (nested.getDialect()->getNamespace() != "scf")
        continue;
      // report_fatal_error does not return.
      llvm::report_fatal_error(
          "scf dialect is not supported in membar. Please lower it "
          "to cf dialect first.");
    }
    if (block->isEntryBlock())
      worklist.push_back(block);
  });

  while (!worklist.empty()) {
    Block *current = worklist.front();
    worklist.pop_front();

    // Replay the block on a copy so that blockIn[current] only ever grows
    // through joins from predecessors.
    BlockInfo state = blockIn[current];
    SmallVector<Block *> nextBlocks;
    for (Operation &op : current->getOperations()) {
      if (!op.hasTrait<OpTrait::IsTerminator>())
        update(&op, &state, funcBlockInfoMap, builder);
      else
        visitTerminator(&op, nextBlocks);
    }

    // Already visited and nothing changed: successors are up to date.
    if (blockOut.count(current) && state == blockOut[current])
      continue;

    BlockInfo &out = blockOut[current];
    out.join(state);
    for (Block *next : nextBlocks) {
      blockIn[next].join(out);
      worklist.push_back(next);
    }
  }

  // Publish the still-unsynchronized accesses visible at function returns.
  auto &summary = (*funcBlockInfoMap)[funcOp];
  funcOp.walk<WalkOrder::PreOrder>([&](Block *block) {
    block->walk([&](triton::ReturnOp) { summary.join(blockOut[block]); });
  });
}
/// Collects the CFG successors reachable through terminator `op`.
///
/// For a `BranchOpInterface` op, every successor of its parent block is
/// appended to `successors`. A `ReturnLike` op contributes nothing. Any other
/// terminator is treated as unreachable.
void MembarAnalysis::visitTerminator(Operation *op,
                                     SmallVector<Block *> &successors) {
  if (auto branch = dyn_cast<BranchOpInterface>(op)) {
    for (Block *succ : branch->getBlock()->getSuccessors())
      successors.push_back(succ);
    return;
  }
  // Returns leave the function, so there is nothing further to visit.
  if (!op->hasTrait<OpTrait::ReturnLike>())
    llvm_unreachable("Unknown terminator encountered in membar analysis");
}
/// Emits a barrier via the global-namespace `::insertBarrier(OpBuilder &,
/// Operation *)` helper. It presumably comes from TritonGPUToLLVM/Utility.h;
/// confirm. The barrier goes at whatever insertion point the caller set on
/// `builder`. The builder's insertion state is restored on return.
void MembarAnalysis::insertBarrier(Operation *op, OpBuilder *builder) {
  OpBuilder::InsertionGuard insertionGuard(*builder);
  ::insertBarrier(*builder, op);
}
/// Transfer function for a single non-terminator op.
///
/// Computes the shared-memory intervals `op` reads and writes (or, for a
/// `triton::CallOp`, the callee's recorded summary). If they intersect the
/// pending accesses in `blockInfo` (per `filter`), a barrier is inserted
/// before `op` and `blockInfo` is synced. Finally `op`'s own accesses are
/// joined into `blockInfo`.
///
/// Special cases:
///  - `gpu::BarrierOp` syncs `blockInfo` and records nothing else.
///  - `triton::gpu::AsyncWaitOp` is followed by a barrier, which is inserted
///    after it unless one is already there, and then `blockInfo` is synced.
///  - An op with a scratch buffer is treated as a write to that buffer when
///    checking for conflicts. After a sync it is recorded as both reading and
///    writing the buffer.
void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
                            FuncBlockInfoMapT *funcBlockInfoMap,
                            OpBuilder *builder) {
  if (isa<gpu::BarrierOp>(op)) {
    blockInfo->sync();
    return;
  }

  // getNextNode() is null when `op` is the last op in its block. isa<> asserts
  // on null, so use the null-tolerant form.
  if (isa<triton::gpu::AsyncWaitOp>(op) &&
      !llvm::isa_and_nonnull<gpu::BarrierOp>(op->getNextNode())) {
    builder->setInsertionPointAfter(op);
    insertBarrier(op, builder);
    blockInfo->sync();
    return;
  }

  BlockInfo curBlockInfo;
  auto scratchBufferId = Allocation::InvalidBufferId;
  if (isa<triton::CallOp>(op)) {
    // Inter-function dependency: reuse the callee's recorded summary.
    // resolveCallable() returns nullptr when the callee symbol cannot be
    // resolved, which dyn_cast<> would assert on (or crash on in release).
    auto callOpInterface = cast<CallOpInterface>(op);
    if (auto callee = llvm::dyn_cast_or_null<FunctionOpInterface>(
            callOpInterface.resolveCallable()))
      curBlockInfo = funcBlockInfoMap->lookup(callee);
  } else {
    if (auto memoryEffectOpInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
      SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>>
          effectInstances;
      memoryEffectOpInterface.getEffects(effectInstances);
      for (const auto &effectInstance : effectInstances) {
        auto value = effectInstance.getValue();
        if (!value)
          continue;
        for (auto bufferId : allocation->getBufferIds(value)) {
          if (bufferId == Allocation::InvalidBufferId)
            continue;
          // Only reads and writes matter here; other effects are ignored.
          if (isa<MemoryEffects::Write>(effectInstance.getEffect()))
            curBlockInfo
                .syncWriteIntervals[allocation->getAllocatedInterval(bufferId)]
                .insert(op);
          else if (isa<MemoryEffects::Read>(effectInstance.getEffect()))
            curBlockInfo
                .syncReadIntervals[allocation->getAllocatedInterval(bufferId)]
                .insert(op);
        }
      }
    }
    scratchBufferId = allocation->getBufferId(op);
  }

  if (scratchBufferId != Allocation::InvalidBufferId) {
    if (!curBlockInfo.syncReadIntervals.empty() ||
        !curBlockInfo.syncWriteIntervals.empty()) {
      llvm::report_fatal_error(
          "scratch buffer operations should not have any shared memory "
          "dependencies");
    }
    // Check the scratch buffer as a write against pending accesses.
    auto interval = allocation->getAllocatedInterval(scratchBufferId);
    curBlockInfo.syncWriteIntervals[interval].insert(op);
    if (blockInfo->isIntersected(curBlockInfo, filter)) {
      builder->setInsertionPoint(op);
      insertBarrier(op, builder);
    }
    // Pending state is synced unconditionally here, whether or not a barrier
    // was inserted above.
    blockInfo->sync();
    curBlockInfo.syncReadIntervals[interval].insert(op);
  } else if (blockInfo->isIntersected(curBlockInfo, filter)) {
    builder->setInsertionPoint(op);
    insertBarrier(op, builder);
    blockInfo->sync();
  }
  // Record this op's accesses even if a barrier was just inserted: they are
  // still pending relative to later ops.
  blockInfo->join(curBlockInfo);
}
}