#include "TritonAMDGPUTransforms/Passes.h"
#include "amd/lib/TritonAMDGPUToLLVM/Utility.h"
#include "amd/lib/TritonAMDGPUTransforms/Utility.h"
namespace tt = triton;
namespace ttg = triton::gpu;
namespace mlir {
#define GEN_PASS_DEF_TRITONAMDGPUUPDATEASYNCWAITCOUNT
#include "TritonAMDGPUTransforms/Passes.h.inc"
namespace {
// Computes how many load instructions are attributed to copying a tensor of
// type |srcTy| into a memory descriptor of type |dstTy|.
//
// Both types are converted to linear layouts and the source layout is mapped
// onto the destination layout via invertAndCompose. The result is the size of
// the "register" input dimension of that mapping divided by its number of
// consecutive in/out elements, and never less than 1.
// NOTE(review): presumably one load instruction covers one run of consecutive
// elements and cannot be split further -- confirm against the lowering.
int getNumberOfLoadInstructions(RankedTensorType srcTy,
                                ttg::MemDescType dstTy) {
  LinearLayout regLayout = tt::gpu::toLinearLayout(srcTy);
  LinearLayout ldsLayout = tt::gpu::toLinearLayout(dstTy);
  LinearLayout regToLds = regLayout.invertAndCompose(ldsLayout);

  auto registerDim = StringAttr::get(srcTy.getContext(), "register");
  const int registerCount = regToLds.getInDimSize(registerDim);
  const int consecutiveElems = regToLds.getNumConsecutiveInOut();

  // Integer division may round down to zero; always report at least one.
  return std::max(1, registerCount / consecutiveElems);
}
// Returns the number of async load instructions contributed by |op|.
//
// - For a ttg::AsyncCommitGroupOp the result is the sum, over all operands
//   that have a defining op, of the per-copy instruction count of that
//   defining op when it is a ttg::AsyncCopyGlobalToLocalOp or an
//   amdgpu::BufferLoadToLocalOp. Operands without a defining op (e.g. block
//   arguments) and operands defined by any other op contribute nothing.
// - Every other op contributes 0. If it is one of the listed non-async global
//   memory operations, a remark is additionally emitted on it.
int getNumberOfLoadInstructions(Operation *op) {
  if (isa<ttg::AsyncCommitGroupOp>(op)) {
    int count = 0;
    for (auto token : op->getOperands()) {
      auto defOp = token.getDefiningOp();
      if (!defOp)
        continue;
      if (auto copyOp = llvm::dyn_cast<ttg::AsyncCopyGlobalToLocalOp>(defOp)) {
        count += getNumberOfLoadInstructions(copyOp.getSrc().getType(),
                                             copyOp.getResult().getType());
      } else if (auto copyOp =
                     llvm::dyn_cast<amdgpu::BufferLoadToLocalOp>(defOp)) {
        // Buffer loads carry a base pointer plus offsets; rebuild the
        // equivalent shaped pointer type so both copy kinds share one helper.
        auto srcTy = cast<RankedTensorType>(LLVM::AMD::getPointerTypeWithShape(
            copyOp.getPtr(), copyOp.getOffsets()));
        count += getNumberOfLoadInstructions(srcTy, copyOp.getDest().getType());
      }
    }
    return count;
  }

  // Only non-async global memory operations are reported here.
  // amdgpu::BufferLoadToLocalOp is deliberately absent: like
  // ttg::AsyncCopyGlobalToLocalOp it is an async copy and is accounted for
  // through its commit group above, so flagging it would emit the remark on
  // the async loads themselves. The plain amdgpu::BufferLoadOp is the buffer
  // counterpart of tt::LoadOp and belongs in the list next to BufferStoreOp.
  if (isa<tt::LoadOp, tt::StoreOp, amdgpu::BufferLoadOp,
          amdgpu::BufferStoreOp, tt::AtomicRMWOp, tt::AtomicCASOp,
          amdgpu::BufferAtomicRMWOp>(op)) {
    op->emitRemark("Global memory operation between async wait and "
                   "async_loads. This will hinder the interleaving of memory "
                   "operations and might impact performance.");
  }
  return 0;
}
// Recomputes the `num` value of |waitOp| from its token operands.
//
// For every operand, deduceMinCountOnDefChain is queried with the per-op load
// instruction counter, and the smallest result over all operands is kept.
// The op is modified in place through |rewriter| only if at least one operand
// produced a count and that count differs from the current `num`.
// NOTE(review): the exact traversal performed by deduceMinCountOnDefChain is
// defined elsewhere -- see its declaration for what "def chain" covers.
void updateWaitCount(ttg::AsyncWaitOp waitOp, RewriterBase &rewriter) {
  // Sentinel meaning "no operand has contributed a count yet".
  constexpr int kNoCount = std::numeric_limits<int>::max();

  auto countLoads = [](Operation *op) {
    return getNumberOfLoadInstructions(op);
  };

  int minCount = kNoCount;
  for (auto token : waitOp.getOperands()) {
    const int countForToken =
        deduceMinCountOnDefChain(token, waitOp, countLoads);
    if (countForToken < minCount)
      minCount = countForToken;
  }

  // Nothing to do when no count was found or the op is already up to date.
  if (minCount != kNoCount && waitOp.getNum() != minCount) {
    rewriter.modifyOpInPlace(waitOp, [&]() { waitOp.setNum(minCount); });
  }
}
}
// Pass that visits every ttg::AsyncWaitOp under the root operation and
// refreshes its wait count via updateWaitCount. The pass returns without
// touching the IR unless the ISA family derived from `archGenerationName`
// satisfies isCDNA.
struct TritonAMDGPUUpdateAsyncWaitCountPass
    : impl::TritonAMDGPUUpdateAsyncWaitCountBase<
          TritonAMDGPUUpdateAsyncWaitCountPass> {
  using Base::Base;

  void runOnOperation() override {
    tt::AMD::TargetInfo targetInfo(archGenerationName);
    if (!isCDNA(targetInfo.getISAFamily()))
      return;

    ModuleOp m = getOperation();

    // Collect the wait ops first so the updates below happen after the walk
    // has finished rather than during it.
    SmallVector<ttg::AsyncWaitOp> waitOps;
    m.walk([&](ttg::AsyncWaitOp waitOp) { waitOps.push_back(waitOp); });

    // All ops live in the module's context, so a single rewriter is enough.
    IRRewriter rewriter(m.getContext());
    for (ttg::AsyncWaitOp waitOp : waitOps)
      updateWaitCount(waitOp, rewriter);
  }
};
}