#include <iterator>
#include <numeric>
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Support/LLVM.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Tools/StrUtil.h"
#include "llvm/Support/Debug.h"
// Debug-output helpers for this file.
//  - DEBUG_TYPE is the tag LLVM_DEBUG uses to filter this file's output.
//  - DBGS() yields llvm::dbgs() after writing the "[tritongpu-coalesce]: "
//    prefix to it.
//  - LDBG(X) streams the prefix, then X, then a newline, all inside
//    LLVM_DEBUG.
#define DEBUG_TYPE "tritongpu-coalesce"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
namespace mlir {
namespace triton {
namespace gpu {
// Define the selector macro before pulling in the generated pass header.
// NOTE(review): presumably this is what makes impl::TritonGPUCoalesceBase
// (the base class used below) available -- confirm against the pass
// definitions that generate Passes.h.inc.
#define GEN_PASS_DEF_TRITONGPUCOALESCE
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
/// Pass that picks a blocked layout for every operation whose accessed
/// pointer (as reported by `getMemAccessPtr`) is a ranked tensor of pointers,
/// and then rewrites each such operation to work on that layout.
struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
  /// Computes a blocked encoding for the pointer tensor accessed by `op` and
  /// stores it in `layoutMap[op]`.
  ///
  /// The dimension order is derived from the contiguity reported by
  /// `axisInfoAnalysis` for the pointer. The number of elements per thread
  /// along the leading dimension of that order is the maximum over `op` and
  /// the related accesses found in its multi-root slice that share the
  /// pointer's shape and order, capped at
  /// `max(numElems / (numWarps * threadsPerWarp), 1)`.
  ///
  /// \param axisInfoAnalysis  source of per-value axis information.
  /// \param op                the memory access to compute a layout for.
  /// \param numWarps          number of warps used to build the encoding.
  /// \param threadsPerWarp    number of threads per warp.
  /// \param layoutMap         output map; the entry for `op` is overwritten.
  void
  setCoalescedEncoding(ModuleAxisInfoAnalysis &axisInfoAnalysis, Operation *op,
                       int numWarps, int threadsPerWarp,
                       llvm::MapVector<Operation *, Attribute> &layoutMap) {
    Value accessPtr = getMemAccessPtr(op);
    auto ptrTensorTy = cast<RankedTensorType>(accessPtr.getType());

    LDBG("Considering op: " << *op);
    LLVM_DEBUG({
      DBGS() << "axis info of pointer: ";
      axisInfoAnalysis.getAxisInfo(accessPtr)->print(llvm::dbgs());
      llvm::dbgs() << "\n";
    });

    // Dimension order implied by the pointer's contiguity.
    SmallVector<unsigned> order = getOrderFromContiguity(
        axisInfoAnalysis.getAxisInfo(accessPtr)->getContiguity());
    LDBG("order=[" << triton::join(order, ", ") << "]");

    // True when `candidate` is a ranked tensor with the reference shape.
    auto hasReferenceShape = [&ptrTensorTy](const Value &candidate) {
      auto candidateTy = dyn_cast<RankedTensorType>(candidate.getType());
      return candidateTy && candidateTy.getShape() == ptrTensorTy.getShape();
    };

    // Group of accesses sharing shape and order with `op`. Seeding the set
    // with `op` itself keeps the slice walk below from considering it again.
    llvm::SmallSetVector<Operation *, 32> sameOrderAccesses;
    sameOrderAccesses.insert(op);
    // The slice is only inspected when the pointer is produced by an
    // operation (i.e. it has a defining op).
    if (accessPtr.getDefiningOp()) {
      for (Operation *related : mlir::multiRootGetSlice(op)) {
        Value relatedPtr = getMemAccessPtr(related);
        if (!relatedPtr)
          continue;
        if (!hasReferenceShape(relatedPtr))
          continue;
        if (sameOrderAccesses.contains(related))
          continue;

        auto relatedOrder = getOrderFromContiguity(
            axisInfoAnalysis.getAxisInfo(relatedPtr)->getContiguity());
        if (order != relatedOrder)
          continue;

        LDBG("multi-root-slice: insert to memAccessesSameOrder " << *related);
        sameOrderAccesses.insert(related);
      }
    }

    auto shapePerCTA = triton::gpu::getShapePerCTA(ptrTensorTy);
    LDBG("shapePerCTA=[" << triton::join(shapePerCTA, ", ") << "]");

    int numElems = product<int64_t>(shapePerCTA);
    int numThreads = numWarps * threadsPerWarp;

    // Start from what `op` alone asks for, then widen to the largest request
    // among the other accesses in the group.
    unsigned perThread = getNumElementsPerThread(op, order, axisInfoAnalysis);
    LDBG("perThread for op: " << perThread);

    for (Operation *member : sameOrderAccesses) {
      if (member != op) {
        unsigned memberPerThread =
            getNumElementsPerThread(member, order, axisInfoAnalysis);
        LDBG("perThread for opSameOrder: " << memberPerThread);
        perThread = std::max(perThread, memberPerThread);
      }
    }

    // Never hand a thread more than its even share of the tensor, but always
    // at least one element.
    int evenShare = std::max(numElems / numThreads, 1);
    perThread = std::min<int>(perThread, evenShare);
    LDBG("perThread: " << perThread);

    // Anything that is not a load is clamped back to the value computed for
    // `op` itself, discarding any widening contributed by the other accesses.
    if (!isa<triton::LoadOp>(op)) {
      unsigned ownPerThread =
          getNumElementsPerThread(op, order, axisInfoAnalysis);
      perThread = std::min<int>(perThread, ownPerThread);
    }

    // One element per thread in every dimension except the leading one of
    // `order`.
    SmallVector<unsigned> sizePerThread(ptrTensorTy.getRank(), 1);
    sizePerThread[order[0]] = perThread;

    auto CTALayout = triton::gpu::getCTALayout(ptrTensorTy.getEncoding());
    layoutMap[op] = triton::gpu::BlockedEncodingAttr::get(
        &getContext(), ptrTensorTy.getShape(), sizePerThread, order, numWarps,
        threadsPerWarp, CTALayout);
  }

  /// Returns `type` (which must be a RankedTensorType) with its encoding
  /// replaced by `encoding`.
  static Type getNewType(Type type, Attribute encoding) {
    return cast<RankedTensorType>(type).cloneWithEncoding(encoding);
  }

  /// Replaces `op` with a copy that operates on tensors carrying `encoding`.
  ///
  /// Ranked-tensor operands whose encoding is not a SharedEncodingTrait are
  /// converted to `encoding` first; all other operands are passed through.
  /// Results of the new op whose type differs from the original are converted
  /// back to the original type before replacing the original's uses. `op` is
  /// erased at the end.
  void coalesceOp(Attribute encoding, Operation *op) {
    OpBuilder builder(op);
    Location loc = op->getLoc();

    // Operands: convert eligible tensors, forward everything else untouched.
    SmallVector<Value, 4> convertedOperands;
    for (Value operand : op->getOperands()) {
      auto operandTy = dyn_cast<RankedTensorType>(operand.getType());
      bool needsConversion =
          operandTy &&
          !isa<triton::gpu::SharedEncodingTrait>(operandTy.getEncoding());
      if (!needsConversion) {
        convertedOperands.push_back(operand);
        continue;
      }
      Type coalescedTy = getNewType(operandTy, encoding);
      convertedOperands.push_back(
          builder.create<triton::gpu::ConvertLayoutOp>(loc, coalescedTy,
                                                       operand));
    }

    // Result types: AsyncCopyGlobalToLocalOp keeps its result types as they
    // are; every other op gets the new encoding on each result.
    const bool isAsync = isa<triton::gpu::AsyncCopyGlobalToLocalOp>(op);
    SmallVector<Type, 4> coalescedResultTypes;
    for (Type resultTy : op->getResultTypes())
      coalescedResultTypes.push_back(isAsync ? resultTy
                                             : getNewType(resultTy, encoding));

    // Same operation name and attributes, new operands and result types.
    Operation *coalesced =
        builder.create(loc, op->getName().getIdentifier(), convertedOperands,
                       coalescedResultTypes, op->getAttrs());

    // Hand the results back to the original users, converting to the original
    // type wherever the result type was changed.
    for (unsigned idx = 0, numResults = op->getNumResults(); idx != numResults;
         ++idx) {
      Value original = op->getResult(idx);
      Value replacement = coalesced->getResult(idx);
      if (coalescedResultTypes[idx] != original.getType()) {
        replacement = builder.create<triton::gpu::ConvertLayoutOp>(
            loc, original.getType(), replacement);
      }
      original.replaceAllUsesWith(replacement);
    }
    op->erase();
  }

  /// Pass entry point: first collects a layout for every eligible memory
  /// access in the module, then rewrites those accesses one by one.
  void runOnOperation() override {
    ModuleOp moduleOp = getOperation();
    ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp);

    // Phase 1: decide on a layout per operation. Nothing is rewritten while
    // walking; the decisions are only recorded in `layoutMap`.
    llvm::MapVector<Operation *, Attribute> layoutMap;
    int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(moduleOp);
    moduleOp.walk([&](Operation *curr) {
      Value accessPtr = getMemAccessPtr(curr);
      if (!accessPtr)
        return;
      // Only ranked tensors whose element type is a pointer are handled.
      auto ptrTensorTy = dyn_cast<RankedTensorType>(accessPtr.getType());
      if (!ptrTensorTy || !isa<PointerType>(ptrTensorTy.getElementType()))
        return;
      int numWarps = lookupNumWarps(curr);
      setCoalescedEncoding(axisInfoAnalysis, curr, numWarps, threadsPerWarp,
                           layoutMap);
    });

    // Phase 2: rewrite every recorded operation, in insertion order.
    for (auto &entry : layoutMap)
      coalesceOp(entry.second, entry.first);
  }
};
} // namespace gpu
} // namespace triton
} // namespace mlir