#include "TritonAMDGPUTransforms/Passes.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "llvm/ADT/TypeSwitch.h"
#define DEBUG_TYPE "tritonamdgpu-block-pingpong"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
namespace ttg = mlir::triton::gpu;
namespace tt = mlir::triton;
namespace mlir {
#define GEN_PASS_DEF_TRITONAMDGPUBLOCKPINGPONG
#include "TritonAMDGPUTransforms/Passes.h.inc"
namespace {
/// Reorders the body of a software-pipelined `scf.for` that contains a
/// GEMM-like dot into alternating memory and dot "clusters" separated by
/// barriers and scheduling hints. Construct it with the loop and its launch
/// parameters, then call getDotPingponged(), which selects one of the
/// supported schedules (or none).
class Pingponger {
  // The loop being rescheduled.
  scf::ForOp forOp;
  // Ops collected from the loop by getDotPingponged(). Some paths narrow these
  // lists to the ops that feed the dot; sliceDot() erases the lLoadOps.
  SmallVector<tt::LoadOp> gLoadOps;
  SmallVector<ttg::LocalLoadOp> lLoadOps;
  SmallVector<ttg::LocalStoreOp> lStoreOps;
  SmallVector<ttg::AsyncCopyGlobalToLocalOp> asyncCopyOps;
  SmallVector<ttg::AsyncWaitOp> asyncWaitOps;
  SmallVector<ttg::AsyncCommitGroupOp> asyncCommitOps;
  SmallVector<tt::DotOp> dotOps;
  SmallVector<tt::DotScaledOp> scaledDotOps;
  // Filled by genLocalSlice(): indexed [operand (0 = A, 1 = B)][slice].
  SmallVector<SmallVector<Operation *>> subViewOps;
  SmallVector<SmallVector<Operation *>> loadSliceOps;
  // The sliced dots created by sliceDot(), in accumulation order.
  SmallVector<Operation *> dotSliceOps;
  // Slice start offsets along K: constOffsets[i] == i * sliceWidth.
  SmallVector<int64_t> constOffsets;
  // Anchor for appendOp()/prependOp()/moveOpAndPredecessorsUpSameBlock().
  // Must be set through updateOpInsertion() first; null-initialized so the
  // asserts in those helpers actually detect a missing anchor.
  Operation *lastInsertedOp = nullptr;
  // Values passed to ROCDL::SetPrioOp around the dot clusters.
  int lowPriority = 0;
  int highPriority = 1;
  // kWidth of the dot A-operand encoding; set by getDotPingponged().
  int32_t kWidth = 0;
  int32_t numWarps;
  int32_t numStages;
  // Whether the loop loads global memory through async copies rather than
  // tt.load; set by getDotPingponged().
  bool useAsyncCopy = false;

public:
  Pingponger(scf::ForOp forOp, int32_t numWarps, int32_t numStages)
      : forOp(forOp), numWarps(numWarps), numStages(numStages) {}
  void getDotPingponged();

private:
  void genOffsetConstants(Location loc, OpBuilder &builder, unsigned numSlices,
                          int64_t sliceWidth);
  LogicalResult genLocalSlice(OpBuilder &builder, Value v,
                              Attribute dotEncoding, unsigned opIdx,
                              unsigned numSlices, int64_t sliceWidth);
  LogicalResult sliceDot(OpBuilder &builder, Location loc, tt::DotOp op,
                         unsigned numSlices);
  void transformOnePPClusters(OpBuilder &builder, Location loc);
  LogicalResult transformFourPPClusters(OpBuilder &builder, Location loc);
  LogicalResult transformTwoPPClusters(OpBuilder &builder, Location loc);
  LogicalResult transformTwoClusterWithLocalLoadAndAll(OpBuilder &builder,
                                                       Location loc);
  LogicalResult transformTwoClusterWithAsyncAndAll(OpBuilder &builder,
                                                   Location loc);
  LogicalResult transformChainedDotSchedule(OpBuilder &builder, Location loc);
  void addAsymmetricSyncToLoop(OpBuilder &builder, Location loc);
  void updateOpInsertion(Operation *Op);
  void appendOp(Operation *Op);
  void prependOp(Operation *Op, bool moveBackwards);
  void moveOpAndPredecessorsUpSameBlock(Operation *Op);
  void appendSlicedLoadAB(int slice);
  SmallVector<Operation *> genClusterBarrier(OpBuilder &builder, Location loc);
  void appendClusterBarrier(OpBuilder &builder, Location loc);
  void prependClusterBarrier(OpBuilder &builder, Location loc);
  void appendOpWithPrio(OpBuilder &builder, Operation *Op, Location loc);
  bool isPersistentGemm(size_t num_dots);
  template <typename T>
  size_t countIfMemoryOps(scf::IfOp ifOp, bool assumeNotTaken);
  template <typename T>
  size_t estimateNonDotMemoryImpact(T *start, T *end, bool assumeNotTaken);
  void determineDotMemoryOps(tt::DotOp dotOp,
                             DenseSet<tt::LoadOp> &dotGlobalLoads,
                             DenseSet<ttg::LocalLoadOp> &dotLocalLoads,
                             DenseSet<ttg::LocalStoreOp> &dotLocalStores);
  template <typename T>
  void findClosestPredOps(Value v, DenseSet<T> &matchingOps);
};
/// Makes `op` the anchor that the next appendOp()/prependOp() call is
/// relative to.
void Pingponger::updateOpInsertion(Operation *op) {
  lastInsertedOp = op;
}
/// Moves `op` immediately after the current anchor and makes it the new
/// anchor, so consecutive calls lay ops out in call order.
void Pingponger::appendOp(Operation *op) {
  Operation *anchor = lastInsertedOp;
  assert(anchor != nullptr);
  op->moveAfter(anchor);
  updateOpInsertion(op);
}
/// Moves `op` immediately before the current anchor.
///
/// If `moveBackwards` is true, `op` becomes the new anchor, so the next
/// prepend lands in front of it. Otherwise the anchor is kept, and successive
/// prepends accumulate in call order just before it.
void Pingponger::prependOp(Operation *op, bool moveBackwards) {
  assert(lastInsertedOp != nullptr);
  op->moveBefore(lastInsertedOp);
  if (!moveBackwards)
    return;
  lastInsertedOp = op;
}
/// Moves `op` to directly after the current anchor, which must be in the same
/// block, and advances the anchor to the last op moved.
///
/// - If `op` lies after the anchor, the producers it depends on that are in
///   the same block and also after the anchor (as collected by
///   getBackwardSlice with the filter below) are appended first, in the order
///   getBackwardSlice returns them, followed by `op`.
/// - Otherwise `op` is appended only if none of its users in this block sits
///   before the anchor; if one does, nothing moves and a debug message is
///   printed.
void Pingponger::moveOpAndPredecessorsUpSameBlock(Operation *op) {
  assert(lastInsertedOp != nullptr);
  assert(op->getBlock() == lastInsertedOp->getBlock());
  Operation *anchor = lastInsertedOp;

  if (!anchor->isBeforeInBlock(op)) {
    // Moving `op` down past a user would break dominance; refuse in that case.
    auto isUsedBeforeAnchor = [&anchor](auto &&user) {
      return user != anchor && user->getBlock() == anchor->getBlock() &&
             user->isBeforeInBlock(anchor);
    };
    if (llvm::any_of(op->getUsers(), isUsedBeforeAnchor)) {
      LDBG("Unable to move operation "
           << op << " due to use before intended move location");
      return;
    }
    appendOp(op);
    return;
  }

  // `op` is after the anchor: bring along the in-block producers it needs.
  BackwardSliceOptions sliceOptions;
  sliceOptions.omitBlockArguments = true;
  sliceOptions.filter = [&anchor](Operation *candidate) {
    return candidate->getBlock() == anchor->getBlock() &&
           anchor->isBeforeInBlock(candidate);
  };
  SetVector<Operation *> producers;
  (void)getBackwardSlice(op, &producers, sliceOptions);
  for (Operation *producer : producers)
    appendOp(producer);
  appendOp(op);
}
/// Appends, for K-slice `slice`, the subview and local_load of operand A,
/// followed by those of operand B (entries created by genLocalSlice()).
void Pingponger::appendSlicedLoadAB(int slice) {
  for (unsigned operandIdx = 0; operandIdx < 2; ++operandIdx) {
    appendOp(subViewOps[operandIdx][slice]);
    appendOp(loadSliceOps[operandIdx][slice]);
  }
}
/// Creates, at the builder's insertion point, the two ops that fence a
/// cluster boundary: a gpu.barrier followed by a ROCDL sched_barrier with
/// mask 0. Callers move them into place with appendOp()/prependOp().
SmallVector<Operation *> Pingponger::genClusterBarrier(OpBuilder &builder,
                                                       Location loc) {
  SmallVector<Operation *> fence;
  fence.push_back(builder.create<gpu::BarrierOp>(loc).getOperation());
  fence.push_back(builder.create<ROCDL::SchedBarrier>(loc, 0).getOperation());
  return fence;
}
/// Creates a cluster-barrier pair and appends it after the current anchor,
/// advancing the anchor past it.
void Pingponger::appendClusterBarrier(OpBuilder &builder, Location loc) {
  SmallVector<Operation *> fence = genClusterBarrier(builder, loc);
  for (Operation *fenceOp : fence)
    appendOp(fenceOp);
}
/// Creates a cluster-barrier pair and places it directly in front of the
/// current anchor, leaving the anchor unchanged.
void Pingponger::prependClusterBarrier(OpBuilder &builder, Location loc) {
  SmallVector<Operation *> fence = genClusterBarrier(builder, loc);
  for (Operation *fenceOp : fence)
    prependOp(fenceOp, /*moveBackwards=*/false);
}
/// Appends `op` bracketed by SetPrio ops: raised to highPriority before it
/// and dropped back to lowPriority after it.
void Pingponger::appendOpWithPrio(OpBuilder &builder, Operation *op,
                                  Location loc) {
  auto raisePrio = builder.create<ROCDL::SetPrioOp>(loc, highPriority);
  appendOp(raisePrio);
  appendOp(op);
  auto restorePrio = builder.create<ROCDL::SetPrioOp>(loc, lowPriority);
  appendOp(restorePrio);
}
/// Heuristic match for a persistent-GEMM loop body.
///
/// The loop must contain exactly one dot (`num_dots == 1`) placed directly in
/// the body. One `scf.if` whose condition is an `arith.cmpi eq` must appear
/// before the dot and another one after it, with no additional `scf.if` on
/// either side.
bool Pingponger::isPersistentGemm(size_t num_dots) {
  if (num_dots != 1)
    return false;
  bool pendingIf = false;
  bool foundDot = false;
  for (Operation &bodyOp : *forOp.getBody()) {
    if (auto ifOp = dyn_cast<scf::IfOp>(bodyOp)) {
      // Only a single qualifying if-section is allowed on each side.
      if (pendingIf)
        return false;
      auto cmp =
          dyn_cast_or_null<arith::CmpIOp>(ifOp.getCondition().getDefiningOp());
      if (!cmp || cmp.getPredicate() != arith::CmpIPredicate::eq)
        return false;
      pendingIf = true;
      continue;
    }
    if (isa<tt::DotOp>(bodyOp)) {
      // The dot must follow an if-section and may appear only once.
      if (foundDot || !pendingIf)
        return false;
      foundDot = true;
      pendingIf = false;
    }
  }
  return pendingIf && foundDot;
}
/// Walks the use-def chain backwards from `v` and inserts into `matchingOps`
/// the nearest op of type T on every path. The walk stops at a match and does
/// not look through it.
///
/// A block argument whose block ends in `scf.yield` is followed to the
/// corresponding yielded value. For `scf.for`, the induction variable is
/// skipped first and ends the path.
///
/// Every value is visited at most once. This also terminates cycles through
/// loop-carried arguments that are yielded unchanged (e.g. `scf.yield %arg`),
/// which would otherwise recurse without bound.
template <typename T>
void Pingponger::findClosestPredOps(Value v, DenseSet<T> &matchingOps) {
  DenseSet<Value> visitedValues;
  std::function<void(Value)> impl;
  impl = [&matchingOps, &visitedValues, &impl](Value cur) {
    if (!visitedValues.insert(cur).second)
      return;

    if (auto blockArg = dyn_cast<BlockArgument>(cur)) {
      unsigned operandNumber = blockArg.getArgNumber();
      Block *block = blockArg.getOwner();
      auto yield = dyn_cast<scf::YieldOp>(block->getTerminator());
      if (!yield)
        return;
      if (auto forOp = dyn_cast<scf::ForOp>(block->getParentOp())) {
        // The induction variable has no yielded counterpart.
        if (operandNumber < forOp.getNumInductionVars())
          return;
        operandNumber -= forOp.getNumInductionVars();
      }
      // Regions whose block arguments do not map 1:1 onto the yield operands
      // (e.g. other scf ops) must not index out of range.
      if (operandNumber >= yield->getNumOperands())
        return;
      impl(yield->getOperand(operandNumber));
      return;
    }

    Operation *definingOp = cur.getDefiningOp();
    if (!definingOp)
      return;
    if (auto matchOp = dyn_cast<T>(definingOp)) {
      matchingOps.insert(matchOp);
      return;
    }
    for (Value predValue : definingOp->getOperands())
      impl(predValue);
  };
  impl(v);
}
/// Counts the ops of type T placed directly in the branches of `ifOp`.
///
/// With `assumeNotTaken`, only the else branch is counted (0 when there is no
/// else block). Otherwise the larger of the two branch counts is returned.
template <typename T>
size_t Pingponger::countIfMemoryOps(scf::IfOp ifOp, bool assumeNotTaken) {
  auto countIn = [](Block *block) -> size_t {
    if (!block)
      return 0;
    auto ops = block->getOps<T>();
    return std::distance(ops.begin(), ops.end());
  };
  size_t elseCount = countIn(ifOp.elseBlock());
  if (assumeNotTaken)
    return elseCount;
  return std::max(countIn(ifOp.thenBlock()), elseCount);
}
template <typename T>
size_t Pingponger::estimateNonDotMemoryImpact(T *start, T *end,
bool assumeNotTaken) {
DenseSet<Operation *> visitedParents;
size_t count = 0;
for (auto it = start; it != end; it++) {
auto parent = (*it)->getParentOp();
if (parent == nullptr)
continue;
if (parent == forOp)
count += 1;
else {
if (visitedParents.contains(parent))
continue;
visitedParents.insert(parent);
if (auto ifOp = dyn_cast<scf::IfOp>(parent))
count += countIfMemoryOps<T>(ifOp, assumeNotTaken);
else {
count += 1;
}
}
}
return count;
}
/// Collects the memory ops that feed `dotOp`:
///  - the nearest local_loads behind operands A and B;
///  - the local_stores that use any memdesc_index reached from those loads;
///  - the nearest tt.loads behind the values those local_stores write.
void Pingponger::determineDotMemoryOps(
    tt::DotOp dotOp, DenseSet<tt::LoadOp> &dotGlobalLoads,
    DenseSet<ttg::LocalLoadOp> &dotLocalLoads,
    DenseSet<ttg::LocalStoreOp> &dotLocalStores) {
  findClosestPredOps<ttg::LocalLoadOp>(dotOp.getA(), dotLocalLoads);
  findClosestPredOps<ttg::LocalLoadOp>(dotOp.getB(), dotLocalLoads);

  DenseSet<ttg::MemDescIndexOp> bufferViews;
  for (ttg::LocalLoadOp localLoad : dotLocalLoads)
    findClosestPredOps<ttg::MemDescIndexOp>(localLoad.getSrc(), bufferViews);

  for (ttg::MemDescIndexOp view : bufferViews) {
    for (Operation *user : view->getUsers()) {
      if (auto store = dyn_cast<ttg::LocalStoreOp>(user))
        dotLocalStores.insert(store);
    }
  }

  for (ttg::LocalStoreOp store : dotLocalStores)
    findClosestPredOps<tt::LoadOp>(store.getSrc(), dotGlobalLoads);
}
// Single-cluster ping-pong schedule. getDotPingponged() uses it for
// num_warps == 4 when the tile size is in [minTile, smallTile], after
// checking that there are exactly two global loads, two local loads and one
// dot.
//
// Two groups are formed:
//  - right after local_load[0]: SetPrio(high), global_load[0] (+ in-block
//    producers), sched_barrier(0), local_load[1] (+ producers), SetPrio(low),
//    global_load[1] (+ producers);
//  - in front of the dot: sched_barrier(1), SetPrio(high), dot, SetPrio(low).
// moveOpAndPredecessorsUpSameBlock() may decline to move an op, in which
// case that op stays where it was.
void Pingponger::transformOnePPClusters(OpBuilder &builder, Location loc) {
  // Anchor on the op right before the dot so the sched_barrier lands
  // immediately in front of it. NOTE(review): getPrevNode() is null if the
  // dot is the first op of its block; confirm that cannot happen here.
  auto dotLoc = dotOps[0]->getPrevNode();
  auto preDotBar = builder.create<ROCDL::SchedBarrier>(loc, 1);
  updateOpInsertion(dotLoc);
  appendOp(preDotBar);
  // Memory group, laid out right after the first local_load.
  updateOpInsertion(lLoadOps[0]);
  appendOp(builder.create<ROCDL::SetPrioOp>(loc, highPriority));
  moveOpAndPredecessorsUpSameBlock(gLoadOps[0]);
  appendOp(builder.create<ROCDL::SchedBarrier>(loc, 0));
  moveOpAndPredecessorsUpSameBlock(lLoadOps[1]);
  appendOp(builder.create<ROCDL::SetPrioOp>(loc, lowPriority));
  moveOpAndPredecessorsUpSameBlock(gLoadOps[1]);
  // Dot group: the dot wrapped in SetPrio ops after the sched_barrier.
  updateOpInsertion(preDotBar);
  appendOpWithPrio(builder, dotOps[0], loc);
  dotOps[0]->emitRemark() << "Performed one ping pong cluster transformation\n";
}
/// Appends to constOffsets the starting K offset of each of `numSlices`
/// slices of width `sliceWidth`, i.e. i * sliceWidth for the i-th slice.
/// No IR is created; `loc` and `builder` are currently unused.
void Pingponger::genOffsetConstants(Location loc, OpBuilder &builder,
                                    unsigned numSlices, int64_t sliceWidth) {
  for (unsigned slice = 0; slice < numSlices; ++slice)
    constOffsets.push_back(sliceWidth * static_cast<int64_t>(slice));
}
/// Splits the shared-memory operand behind `v` into `numSlices` pieces of
/// `sliceWidth` along K. `v` must be the result of a ttg.local_load. For each
/// piece, a memdesc subslice and a local_load are created with a dot-operand
/// encoding (operand `opIdx`, parent `dotEncoding`, current kWidth).
///
/// The ops are created at the builder's insertion point and recorded in
/// subViewOps / loadSliceOps. Fails without creating anything if `v` is not
/// produced by a local_load or if `sliceWidth` is below 16.
LogicalResult Pingponger::genLocalSlice(OpBuilder &builder, Value v,
                                        Attribute dotEncoding, unsigned opIdx,
                                        unsigned numSlices,
                                        int64_t sliceWidth) {
  auto localLoad = v.getDefiningOp<ttg::LocalLoadOp>();
  if (!localLoad)
    return failure();
  if (sliceWidth < 16)
    return failure();

  auto memDesc = localLoad.getSrc();
  auto srcType = cast<ttg::MemDescType>(memDesc.getType());
  Type elementType = srcType.getElementType();
  // The K dimension is dim 1 of operand A (opIdx 0) and dim 0 of operand B.
  int64_t kIdx = (opIdx == 0) ? 1 : 0;
  SmallVector<int64_t> sliceShape = llvm::to_vector(srcType.getShape());
  sliceShape[kIdx] = sliceWidth;

  auto sliceMemDescType = ttg::MemDescType::get(
      sliceShape, elementType, srcType.getEncoding(), srcType.getMemorySpace(),
      srcType.getMutableMemory(), srcType.getAllocShape());
  auto operandEncoding = ttg::DotOperandEncodingAttr::get(
      builder.getContext(), opIdx, dotEncoding, kWidth);
  auto sliceTensorType =
      RankedTensorType::get(sliceShape, elementType, operandEncoding);

  SmallVector<Operation *> subviews;
  SmallVector<Operation *> slices;
  for (unsigned slice = 0; slice < numSlices; ++slice) {
    // Offset along K by constOffsets[slice]; the other dimension uses
    // constOffsets[0].
    SmallVector<int32_t> logicalOffsets;
    for (int64_t dim = 0; dim < 2; ++dim)
      logicalOffsets.push_back(constOffsets[dim == kIdx ? slice : 0]);
    Value subslice = builder.create<ttg::MemDescSubsliceOp>(
        v.getLoc(), sliceMemDescType, memDesc, logicalOffsets);
    Value sliceLoad =
        builder.create<ttg::LocalLoadOp>(v.getLoc(), sliceTensorType, subslice);
    subviews.push_back(subslice.getDefiningOp());
    slices.push_back(sliceLoad.getDefiningOp());
  }
  subViewOps.push_back(subviews);
  loadSliceOps.push_back(slices);
  return success();
}
/// Splits `op` along K into `numSlices` dots that accumulate in sequence.
/// Each sliced dot is fed by its own slices of the shared-memory operands
/// (see genLocalSlice()). Slice ops and sliced dots are created right after
/// gLoadOps[0] and the sliced dots are recorded in dotSliceOps. Uses of `op`
/// are redirected to the last sliced dot, then `op` and the collected
/// lLoadOps are erased.
///
/// Fails without modifying the IR if:
///  - `numSlices` is 0 or K is not divisible by it,
///  - the slice width is below 16, or
///  - an operand is not produced directly by a ttg.local_load.
LogicalResult Pingponger::sliceDot(OpBuilder &builder, Location loc,
                                   tt::DotOp op, unsigned numSlices) {
  builder.setInsertionPointToStart(forOp.getBody());
  auto typeB = op.getB().getType();
  // K is the leading dimension of operand B.
  auto shapeB = typeB.getShape();
  // Validate everything up front. Otherwise operand A could be sliced before
  // operand B is rejected, leaving dead slice ops behind; and numSlices == 0
  // would divide by zero.
  if (numSlices == 0 || shapeB[0] % numSlices != 0)
    return failure();
  int64_t sliceWidth = shapeB[0] / numSlices;
  if (sliceWidth < 16 || !op.getA().getDefiningOp<ttg::LocalLoadOp>() ||
      !op.getB().getDefiningOp<ttg::LocalLoadOp>())
    return failure();

  genOffsetConstants(loc, builder, numSlices, sliceWidth);
  builder.setInsertionPointAfter(gLoadOps[0]);
  auto dotEncoding = op.getType().getEncoding();
  if (genLocalSlice(builder, op.getA(), dotEncoding, /*opIdx=*/0, numSlices,
                    sliceWidth)
          .failed() ||
      genLocalSlice(builder, op.getB(), dotEncoding, /*opIdx=*/1, numSlices,
                    sliceWidth)
          .failed())
    return failure();

  // Slice i accumulates into slice i-1; slice 0 keeps the original
  // accumulator operand.
  Operation *prevDot = op;
  for (unsigned i = 0; i < numSlices; ++i) {
    IRMapping mapping;
    mapping.map(op.getA(), loadSliceOps[0][i]->getResult(0));
    mapping.map(op.getB(), loadSliceOps[1][i]->getResult(0));
    if (i > 0)
      mapping.map(op.getC(), prevDot->getResult(0));
    Operation *newOp = builder.clone(*op, mapping);
    prevDot = newOp;
    dotSliceOps.push_back(newOp);
  }
  op->replaceAllUsesWith(prevDot);
  op->erase();
  for (auto loads : lLoadOps)
    loads->erase();
  return success();
}
// Four-cluster ping-pong schedule. getDotPingponged() uses it for large tiles
// (num_warps == 8, num_stages == 2). The dot is first sliced into four K
// slices by sliceDot().
//
// Starting right after gLoadOps[1], the loop body is laid out as four
// memory / dot pairs separated by cluster barriers (gpu.barrier +
// sched_barrier). Each dot slice is wrapped in SetPrio(high) / SetPrio(low):
//   mem0: local loads A/B slice 0              dot0: slice 0
//   mem1: gLoadOps[1], local loads A/B slice 1  dot1: slice 1
//   mem2: local loads A/B slices 2 and 3        dot2: slice 2
//   mem3: lStoreOps[0], lStoreOps[1]            dot3: slice 3
// A final cluster barrier is placed just before the block terminator.
// Returns failure if sliceDot() fails, in which case nothing is reordered.
LogicalResult Pingponger::transformFourPPClusters(OpBuilder &builder,
                                                  Location loc) {
  if (sliceDot(builder, loc, dotOps[0], 4).failed())
    return failure();
  builder.setInsertionPointAfter(gLoadOps[1]);
  // mem0 (anchored at the last global load).
  updateOpInsertion(gLoadOps[1]);
  appendSlicedLoadAB(0);
  appendClusterBarrier(builder, loc);
  // dot0
  appendOpWithPrio(builder, dotSliceOps[0], loc);
  appendClusterBarrier(builder, loc);
  // mem1: the anchoring global load is moved into this cluster.
  appendOp(gLoadOps[1]);
  appendSlicedLoadAB(1);
  appendClusterBarrier(builder, loc);
  // dot1
  appendOpWithPrio(builder, dotSliceOps[1], loc);
  appendClusterBarrier(builder, loc);
  // mem2
  appendSlicedLoadAB(2);
  appendSlicedLoadAB(3);
  appendClusterBarrier(builder, loc);
  // dot2
  appendOpWithPrio(builder, dotSliceOps[2], loc);
  appendClusterBarrier(builder, loc);
  // mem3: local stores (with their in-block producers).
  moveOpAndPredecessorsUpSameBlock(lStoreOps[0]);
  moveOpAndPredecessorsUpSameBlock(lStoreOps[1]);
  appendClusterBarrier(builder, loc);
  // dot3
  appendOpWithPrio(builder, dotSliceOps[3], loc);
  // Closing barrier in front of the loop terminator.
  updateOpInsertion(lastInsertedOp->getBlock()->getTerminator());
  prependClusterBarrier(builder, loc);
  dotSliceOps[0]->emitRemark()
      << "Performed four ping pong cluster transformation\n";
  return success();
}
// Two-cluster ping-pong schedule. getDotPingponged() uses it for medium tiles
// (num_warps == 8, num_stages == 2). The dot is first sliced into two K
// slices by sliceDot().
//
// Starting right after gLoadOps[1], the loop body is laid out as:
//   local loads A/B slice 0; sched_barrier; gLoadOps[0]; sched_barrier;
//   local loads A/B slice 1; sched_barrier; gLoadOps[1]; s_barrier;
//   sched_barrier; dot slice 0 (SetPrio high/low); cluster barrier;
//   lStoreOps[0], lStoreOps[1] (+ in-block producers); cluster barrier;
//   dot slice 1 (SetPrio high/low)
// A final cluster barrier is placed just before the block terminator.
// Returns failure if sliceDot() fails, in which case nothing is reordered.
LogicalResult Pingponger::transformTwoPPClusters(OpBuilder &builder,
                                                 Location loc) {
  if (sliceDot(builder, loc, dotOps[0], 2).failed())
    return failure();
  builder.setInsertionPointAfter(gLoadOps[1]);
  // First memory cluster: both halves of the local loads interleaved with
  // the global loads.
  updateOpInsertion(gLoadOps[1]);
  appendSlicedLoadAB(0);
  appendOp(builder.create<ROCDL::SchedBarrier>(loc, 0));
  appendOp(gLoadOps[0]);
  appendOp(builder.create<ROCDL::SchedBarrier>(loc, 0));
  appendSlicedLoadAB(1);
  appendOp(builder.create<ROCDL::SchedBarrier>(loc, 0));
  appendOp(gLoadOps[1]);
  appendOp(builder.create<ROCDL::SBarrierOp>(loc));
  appendOp(builder.create<ROCDL::SchedBarrier>(loc, 0));
  // First dot cluster.
  appendOpWithPrio(builder, dotSliceOps[0], loc);
  appendClusterBarrier(builder, loc);
  // Second memory cluster: local stores.
  moveOpAndPredecessorsUpSameBlock(lStoreOps[0]);
  moveOpAndPredecessorsUpSameBlock(lStoreOps[1]);
  appendClusterBarrier(builder, loc);
  // Second dot cluster.
  appendOpWithPrio(builder, dotSliceOps[1], loc);
  // Closing barrier in front of the loop terminator.
  updateOpInsertion(lastInsertedOp->getBlock()->getTerminator());
  prependClusterBarrier(builder, loc);
  dotSliceOps[0]->emitRemark()
      << "Performed two ping pong cluster transformation\n";
  return success();
}
/// Two-cluster schedule for a single tt.dot_scaled fed by async copies.
///
/// Moves the async commit groups and global loads (with their in-block
/// producers, where legal) to right after the first async_wait. That group
/// is closed with sched_barrier / s_barrier / sched_barrier, and the scaled
/// dot is tagged with the `pingpong_2step` unit attribute.
///
/// Fails without modifying the IR if there is no async copy, async wait or
/// scaled dot.
LogicalResult Pingponger::transformTwoClusterWithAsyncAndAll(OpBuilder &builder,
                                                             Location loc) {
  // Everything indexed below must exist; check before touching the IR.
  if (asyncCopyOps.empty() || asyncWaitOps.empty() || scaledDotOps.empty())
    return failure();
  builder.setInsertionPointAfter(asyncWaitOps[0]);
  updateOpInsertion(asyncWaitOps[0]);
  for (auto cop : asyncCommitOps)
    moveOpAndPredecessorsUpSameBlock(cop);
  for (auto glop : gLoadOps)
    moveOpAndPredecessorsUpSameBlock(glop);
  appendOp(builder.create<ROCDL::SchedBarrier>(loc, 0));
  appendOp(builder.create<ROCDL::SBarrierOp>(loc));
  appendOp(builder.create<ROCDL::SchedBarrier>(loc, 0));
  scaledDotOps[0]->setAttr("pingpong_2step", builder.getUnitAttr());
  return success();
}
/// Ping-pong schedule for a loop with two chained tt.dot ops. For each dot,
/// the memory cluster starts at the first ttg.async_wait or ttg.local_store
/// found after the dot in the same block. The schedule:
///  - drops priority right before each dot and raises it at the start of
///    each memory cluster;
///  - adds a sched_barrier after an async_wait cluster start, or a
///    gpu.barrier + sched_barrier in front of a local_store cluster start;
///  - puts sched_barrier + s_barrier before the second dot and before the
///    block terminator.
/// Fails without modifying the IR if there are not exactly two dots, or if
/// the two memory clusters are missing or coincide.
LogicalResult Pingponger::transformChainedDotSchedule(OpBuilder &builder,
                                                      Location loc) {
  // Previously an assert: with a tt.dot_scaled in the mix, release builds
  // would index past the end of dotOps.
  if (dotOps.size() != 2) {
    LDBG("ChainedDot pingpong requires exactly two tt.dot ops");
    return failure();
  }
  auto findNextMemoryCluster = [](Operation *op) {
    while (op && !llvm::isa<ttg::AsyncWaitOp, ttg::LocalStoreOp>(op)) {
      op = op->getNextNode();
    }
    return op;
  };
  std::array memoryClusterStartOps = {findNextMemoryCluster(dotOps[0]),
                                      findNextMemoryCluster(dotOps[1])};
  if (llvm::is_contained(memoryClusterStartOps, nullptr) ||
      memoryClusterStartOps[0] == memoryClusterStartOps[1]) {
    LDBG("ChainedDot pingpong requires memory operations in both memory "
         "clusters");
    return failure();
  }
  builder.setInsertionPointToStart(forOp.getBody());
  // First dot / memory cluster pair.
  updateOpInsertion(dotOps[0]);
  prependOp(builder.create<ROCDL::SetPrioOp>(loc, lowPriority), false);
  updateOpInsertion(memoryClusterStartOps[0]);
  prependOp(builder.create<ROCDL::SetPrioOp>(loc, highPriority), false);
  if (llvm::isa<ttg::AsyncWaitOp>(memoryClusterStartOps[0])) {
    appendOp(builder.create<ROCDL::SchedBarrier>(loc, 0));
  } else {
    prependOp(builder.create<gpu::BarrierOp>(loc), false);
    prependOp(builder.create<ROCDL::SchedBarrier>(loc, 0), false);
  }
  // Second dot / memory cluster pair.
  updateOpInsertion(dotOps[1]);
  prependOp(builder.create<ROCDL::SchedBarrier>(loc, 0), false);
  prependOp(builder.create<ROCDL::SBarrierOp>(loc), false);
  prependOp(builder.create<ROCDL::SetPrioOp>(loc, lowPriority), false);
  updateOpInsertion(memoryClusterStartOps[1]);
  prependOp(builder.create<ROCDL::SetPrioOp>(loc, highPriority), false);
  if (llvm::isa<ttg::AsyncWaitOp>(memoryClusterStartOps[1])) {
    appendOp(builder.create<ROCDL::SchedBarrier>(loc, 0));
  } else {
    prependOp(builder.create<gpu::BarrierOp>(loc), false);
    prependOp(builder.create<ROCDL::SchedBarrier>(loc, 0), false);
  }
  // Closing barrier in front of the loop terminator.
  updateOpInsertion(lastInsertedOp->getBlock()->getTerminator());
  prependOp(builder.create<ROCDL::SchedBarrier>(loc, 0), false);
  prependOp(builder.create<ROCDL::SBarrierOp>(loc), false);
  return success();
}
/// Two-cluster ping-pong schedule for a single tt.dot fed by async copies.
/// getDotPingponged() uses it for num_warps == 8, num_stages > 2.
///
/// If there are several async_waits, they are replaced by one new async_wait
/// on all of their tokens (count argument 0). Then, starting right after the
/// second async copy, the loop body is laid out as:
///   local_load[0], local_load[1] (+ in-block producers); sched_barrier;
///   async_copy[0]; commit[0]; sched_barrier; async_wait; sched_barrier;
///   five sched_group_barriers; async_copy[1]; commit[1]; dot;
///   sched_barrier; s_barrier; sched_barrier
/// Fails without modifying the IR if any async copy, commit group, async
/// wait, local_load or dot it indexes is missing.
LogicalResult
Pingponger::transformTwoClusterWithLocalLoadAndAll(OpBuilder &builder,
                                                   Location loc) {
  // Every op indexed below must exist; check before touching the IR.
  if (asyncCopyOps.size() < 2 || asyncCommitOps.size() < 2 ||
      asyncWaitOps.empty() || lLoadOps.size() < 2 || dotOps.empty())
    return failure();
  Operation *gLoadRhs = useAsyncCopy ? asyncCopyOps[1] : gLoadOps[1];
  builder.setInsertionPointAfter(gLoadRhs);
  updateOpInsertion(gLoadRhs);
  // Merge multiple async_waits into a single wait on all tokens.
  auto newAsyncWaitOp = asyncWaitOps[0];
  if (asyncWaitOps.size() > 1) {
    SmallVector<Value> tokens;
    for (auto asyncWaitOp : asyncWaitOps) {
      for (auto token : asyncWaitOp.getAsyncToken()) {
        tokens.push_back(token);
      }
    }
    newAsyncWaitOp = builder.create<ttg::AsyncWaitOp>(loc, tokens, 0);
    for (auto asyncWaitOp : asyncWaitOps) {
      asyncWaitOp.getResult().replaceAllUsesWith(newAsyncWaitOp.getResult());
      asyncWaitOp->erase();
    }
  }
  assert(newAsyncWaitOp != nullptr);
  moveOpAndPredecessorsUpSameBlock(lLoadOps[0]);
  moveOpAndPredecessorsUpSameBlock(lLoadOps[1]);
  appendOp(builder.create<ROCDL::SchedBarrier>(loc, 0));
  appendOp(asyncCopyOps[0]);
  appendOp(asyncCommitOps[0]);
  appendOp(builder.create<ROCDL::SchedBarrier>(loc, 0));
  appendOp(newAsyncWaitOp);
  appendOp(builder.create<ROCDL::SchedBarrier>(loc, 0));
  // Scheduling-group hints for the backend; NOTE(review): the (mask, size,
  // syncId) values are taken as-is, confirm their meaning against ROCDL docs.
  appendOp(builder.create<ROCDL::SchedGroupBarrier>(loc, 8, 1, 0));
  appendOp(builder.create<ROCDL::SchedGroupBarrier>(loc, 4, 3, 0));
  appendOp(builder.create<ROCDL::SchedGroupBarrier>(loc, 8, 1, 0));
  appendOp(builder.create<ROCDL::SchedGroupBarrier>(loc, 4, 3, 0));
  appendOp(builder.create<ROCDL::SchedGroupBarrier>(loc, 8, 1, 0));
  appendOp(asyncCopyOps[1]);
  appendOp(asyncCommitOps[1]);
  appendOp(dotOps[0]);
  appendOp(builder.create<ROCDL::SchedBarrier>(loc, 0));
  appendOp(builder.create<ROCDL::SBarrierOp>(loc));
  appendOp(builder.create<ROCDL::SchedBarrier>(loc, 0));
  return success();
}
/// Puts the thread groups out of phase around `forOp`.
///
/// Before the loop it inserts a gpu.barrier, then computes
/// `group = threadIdx.x / 256` and emits a tt.amdgpu.cond_barrier conditioned
/// on `group != 0`. Right after the loop it emits a second cond_barrier
/// conditioned on `group == 0`.
/// NOTE(review): 256 threads per group is hard-coded; presumably this targets
/// num_warps == 8 on 64-lane wavefronts. Confirm for other configurations.
void Pingponger::addAsymmetricSyncToLoop(OpBuilder &builder, Location loc) {
  // Everything below is created directly in front of the loop, in order.
  builder.setInsertionPoint(forOp);
  builder.create<gpu::BarrierOp>(loc);
  auto threadId =
      builder.create<ROCDL::ThreadIdXOp>(loc, builder.getIntegerType(32));
  auto zero = builder.create<arith::ConstantIntOp>(loc, 0, 32);
  auto threadsPerGroup = builder.create<arith::ConstantIntOp>(loc, 256, 32);
  auto groupId = builder.create<arith::DivSIOp>(loc, threadId, threadsPerGroup);
  auto isFirstGroup = builder.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, groupId, zero);
  auto isOtherGroup = builder.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::ne, groupId, zero);
  builder.create<tt::amdgpu::CondBarrierOp>(loc, isOtherGroup);
  // The matching barrier for the first group goes right after the loop.
  builder.setInsertionPointAfter(forOp);
  builder.create<tt::amdgpu::CondBarrierOp>(loc, isFirstGroup);
}
void Pingponger::getDotPingponged() {
if (numStages <= 1) {
LDBG("Pingpong pass expect a loop transformed by pipeliner that prefetches "
"memory and reduces dependencies among the operations in the same "
"iteration.");
return;
}
OpBuilder builder(forOp);
MLIRContext *ctx = forOp.getContext();
Location loc = forOp.getLoc();
forOp->walk([&](Operation *op) {
llvm::TypeSwitch<Operation *>(op)
.Case<tt::LoadOp>([&](auto gLoadOp) { gLoadOps.push_back(gLoadOp); })
.Case<ttg::AsyncCopyGlobalToLocalOp>(
[&](auto asyncCopyOp) { asyncCopyOps.push_back(asyncCopyOp); })
.Case<ttg::LocalLoadOp>([&](auto lLoad) {
auto src = lLoad.getSrc();
if (auto arg = mlir::dyn_cast<BlockArgument>(src))
if (auto tiedLoopInit = forOp.getTiedLoopInit(arg))
if (tiedLoopInit->get())
lLoadOps.push_back(lLoad);
})
.Case<ttg::LocalStoreOp>(
[&](auto lStore) { lStoreOps.push_back(lStore); })
.Case<tt::DotOp>(
[&](auto pingpongDot) { dotOps.push_back(pingpongDot); })
.Case<tt::DotScaledOp>(
[&](auto pingpongDot) { scaledDotOps.push_back(pingpongDot); })
.Case<ttg::AsyncCopyGlobalToLocalOp>(
[&](auto asyncOp) { asyncCopyOps.push_back(asyncOp); })
.Case<ttg::AsyncCommitGroupOp>([&](auto asyncCommitGroupOp) {
asyncCommitOps.push_back(asyncCommitGroupOp);
})
.Case<ttg::AsyncWaitOp>(
[&](auto asyncOp) { asyncWaitOps.push_back(asyncOp); });
});
int64_t numOfDotLikeOps = scaledDotOps.size() + dotOps.size();
if (numOfDotLikeOps < 1 || numOfDotLikeOps > 2) {
LDBG("Only handle one or two dotlike ops");
return;
}
if (numOfDotLikeOps == 2) {
if (numStages != 4)
return;
if (transformChainedDotSchedule(builder, loc).failed()) {
LDBG("Encountered failure when trying the ChainedDot ping pong "
"cluster transformation");
return;
}
addAsymmetricSyncToLoop(builder, loc);
}
useAsyncCopy = (asyncCopyOps.size() > 0);
int64_t gloadSize = useAsyncCopy ? asyncCopyOps.size() : gLoadOps.size();
int64_t dotSize =
scaledDotOps.size() > 0 ? scaledDotOps.size() : dotOps.size();
if ((gloadSize < 2 || lLoadOps.size() < 2 || dotSize != 1)) {
std::stringstream message;
message << "Unable to match ping pong scheduling pattern. Details: "
<< gloadSize << " global loads, " << lLoadOps.size()
<< " local loads, " << dotSize << " dot products";
LDBG(message.str());
return;
}
if (scaledDotOps.size() == 1 && numWarps == 8 && numStages == 2 &&
asyncCopyOps.size() > 0) {
auto scaledDotType = scaledDotOps[0].getType();
auto scaledDotShape = scaledDotType.getShape();
auto aType = scaledDotOps[0].getA().getType();
auto aShape = aType.getShape();
auto elemWidth = aType.getElementTypeBitWidth();
if (scaledDotShape[0] == 256 && scaledDotShape[1] == 256 &&
elemWidth == 8) {
if (transformTwoClusterWithAsyncAndAll(builder, scaledDotOps[0]->getLoc())
.failed()) {
LDBG("Encountered failure when trying to execute the"
"TwoClusterWithAsyncAndAll transformation");
return;
}
addAsymmetricSyncToLoop(builder, loc);
}
return;
} else if (scaledDotOps.size() == 1)
return;
auto assumeNotTaken = isPersistentGemm(dotOps.size());
auto dotType = dotOps[0].getType();
auto dotShape = dotType.getShape();
auto aType = dotOps[0].getA().getType();
auto aShape = aType.getShape();
auto elemWidth = aType.getElementTypeBitWidth();
int64_t tileSize = dotShape[0] * dotShape[1] * aShape[1] * elemWidth;
const int64_t minTile = 262144;
const int64_t smallTile = 16777216;
const int64_t mediumTile = 33554432;
const int64_t largeTile = 67108864;
auto encoding = cast<RankedTensorType>(aType).getEncoding();
auto srcEncoding = cast<ttg::DotOperandEncodingAttr>(encoding);
kWidth = srcEncoding.getKWidth();
auto mfmaEncoding = cast<ttg::AMDMfmaEncodingAttr>(srcEncoding.getParent());
SmallVector<int64_t> intShape;
intShape.push_back(mfmaEncoding.getMDim());
intShape.push_back(mfmaEncoding.getNDim());
if (dotOps.size() == 1 && useAsyncCopy) {
if (numWarps != 8) {
LDBG("Currently only support num_warp=8 for async PP");
return;
}
if (numStages > 2 && dotOps.size() == 1 && dotShape[0] > 64 &&
dotShape[1] > 64 && (elemWidth == 16 || elemWidth == 8)) {
if (transformTwoClusterWithLocalLoadAndAll(builder, loc).failed()) {
LDBG("Encountered failure when trying to execute the "
"TwoClusterWithLocalLoadAndAll transformation");
return;
}
addAsymmetricSyncToLoop(builder, loc);
return;
}
}
DenseSet<tt::LoadOp> dotGlobalLoads;
DenseSet<ttg::LocalLoadOp> dotLocalLoads;
DenseSet<ttg::LocalStoreOp> dotLocalStores;
determineDotMemoryOps(dotOps[0], dotGlobalLoads, dotLocalLoads,
dotLocalStores);
auto gLoadIt = std::stable_partition(
gLoadOps.begin(), gLoadOps.end(),
[&dotGlobalLoads](tt::LoadOp op) { return dotGlobalLoads.contains(op); });
auto lLoadIt = std::stable_partition(lLoadOps.begin(), lLoadOps.end(),
[&dotLocalLoads](ttg::LocalLoadOp op) {
return dotLocalLoads.contains(op);
});
auto lStoreIt =
std::stable_partition(lStoreOps.begin(), lStoreOps.end(),
[&dotLocalStores](ttg::LocalStoreOp op) {
return dotLocalStores.contains(op);
});
if (estimateNonDotMemoryImpact<tt::LoadOp>(gLoadIt, gLoadOps.end(),
assumeNotTaken) != 0) {
std::stringstream message;
message << "Unable to match ping pong scheduling pattern. Details: "
<< "Non-dot global loads found in non-persistent GEMM";
LDBG(message.str());
return;
}
if (estimateNonDotMemoryImpact<ttg::LocalLoadOp>(lLoadIt, lLoadOps.end(),
assumeNotTaken) != 0) {
std::stringstream message;
message << "Unable to match ping pong scheduling pattern. Details: "
<< "Non-dot local loads found in non-persistent GEMM";
LDBG(message.str());
return;
}
if (estimateNonDotMemoryImpact<ttg::LocalStoreOp>(lStoreIt, lStoreOps.end(),
assumeNotTaken) != 0) {
std::stringstream message;
message << "Unable to match ping pong scheduling pattern. Details: "
<< "Non-dot local stores found in non-persistent GEMM";
LDBG(message.str());
return;
}
gLoadOps.erase(gLoadIt, gLoadOps.end());
lLoadOps.erase(lLoadIt, lLoadOps.end());
lStoreOps.erase(lStoreIt, lStoreOps.end());
if (gLoadOps.size() != 2 || lLoadOps.size() != 2) {
std::stringstream message;
message << "Unable to match ping pong slicing pattern. Details: "
<< gLoadOps.size() << " global loads in dot computation, "
<< lLoadOps.size() << " local loads in dot computation";
LDBG(message.str());
return;
}
if (numWarps == 4) {
if (tileSize <= smallTile && tileSize >= minTile)
transformOnePPClusters(builder, loc);
return;
} else if (numWarps == 8 && numStages == 2) {
if (lStoreOps.size() != 2) {
std::stringstream message;
message << "Unable to match ping pong slicing pattern. Details: "
<< lStoreOps.size() << " local stores in dot computation ";
LDBG(message.str());
return;
}
if (tileSize == mediumTile) {
if (transformTwoPPClusters(builder, dotOps[0]->getLoc()).failed()) {
LDBG("Encountered failure when trying to execute the two ping pong "
"cluster transformation");
return;
}
} else if (tileSize >= largeTile) {
if (intShape[0] == 16 && intShape[1] == 16 && kWidth == 8) {
LDBG("Reached known register spilling case, skip pingpong scheduling");
return;
}
if (transformFourPPClusters(builder, dotOps[0]->getLoc()).failed()) {
LDBG("Encountered failure when trying to execute the four ping pong "
"cluster transformation");
return;
}
} else
return;
addAsymmetricSyncToLoop(builder, loc);
}
}
}
/// Module pass that runs the Pingponger on every `scf.for` nested in the
/// module's top-level `tt.func` ops. Each loop uses its own looked-up
/// num_warps and the pass's `numStages` option.
struct TritonAMDGPUBlockPingpongPass
    : impl::TritonAMDGPUBlockPingpongBase<TritonAMDGPUBlockPingpongPass> {
  using Base::Base;

  void runOnOperation() override {
    ModuleOp moduleOp = getOperation();
    for (tt::FuncOp func : moduleOp.getOps<tt::FuncOp>()) {
      func.walk([&](scf::ForOp loop) {
        int32_t warps = ttg::lookupNumWarps(loop);
        Pingponger scheduler(loop, warps, numStages);
        scheduler.getDotPingponged();
      });
    }
  }
};
}