#include "mlir/Analysis/Liveness.h"
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
using namespace mlir;
using namespace triton;
using namespace triton::gpu;
namespace mlir::triton::gpu {
#define GEN_PASS_DEF_TRITONGPUGLOBALSCRATCHALLOCATIONPASS
#include "triton/Conversion/TritonGPUToLLVM/Passes.h.inc"
}
static int32_t roundUp(int32_t val, int32_t step) {
auto t = val + step - 1;
return t - (t % step);
}
static void allocateGMem(Operation *parentOp,
llvm::SetVector<Operation *> &callStack) {
parentOp->walk([&](triton::CallOp call) {
auto callable = call.resolveCallable();
if (!callable->hasAttr("ttg.global_scratch_memory_size")) {
auto inserted = callStack.insert(parentOp);
assert(inserted && "call cycle detected");
allocateGMem(callable, callStack);
callStack.remove(parentOp);
}
});
MLIRContext *ctx = parentOp->getContext();
OpBuilder builder(ctx);
int32_t offset = 0;
uint32_t largestAlignment = 1;
parentOp->walk<WalkOrder::PostOrder>([&](Operation *op) {
uint32_t nbytes = 0;
uint32_t align = 0;
if (auto alloc = dyn_cast<triton::gpu::GlobalScratchAllocOp>(op)) {
nbytes = alloc.getNbytes();
align = alloc.getAlignment();
} else if (auto callOp = dyn_cast<triton::CallOp>(op)) {
auto callable = callOp.resolveCallable();
auto nbytes_attr = callable->getAttrOfType<IntegerAttr>(
"ttg.global_scratch_memory_size");
auto align_attr = callable->getAttrOfType<IntegerAttr>(
"ttg.global_scratch_memory_alignment");
assert(nbytes_attr);
assert(align_attr);
nbytes = nbytes_attr.getValue().getZExtValue();
align = align_attr.getValue().getZExtValue();
}
if (nbytes > 0) {
offset = roundUp(offset, align);
op->setAttr("ttg.global_scratch_memory_offset",
builder.getI32IntegerAttr(offset));
offset += nbytes;
largestAlignment = std::max(largestAlignment, align);
}
});
int32_t totalMemorySize = roundUp(offset, largestAlignment);
parentOp->setAttr("ttg.global_scratch_memory_size",
builder.getI32IntegerAttr(totalMemorySize));
parentOp->setAttr("ttg.global_scratch_memory_alignment",
builder.getI32IntegerAttr(largestAlignment));
}
namespace {
class TritonGPUGlobalScratchAllocationPass
: public mlir::triton::gpu::impl::TritonGPUGlobalScratchAllocationPassBase<
TritonGPUGlobalScratchAllocationPass> {
public:
void runOnOperation() override {
ModuleOp mod = getOperation();
bool seenKernel = false;
SetVector<Operation *> callStack;
mod->walk([&](triton::FuncOp func) {
allocateGMem(func, callStack);
if (func.getVisibility() == SymbolTable::Visibility::Public) {
assert(!seenKernel);
seenKernel = true;
auto size =
func->getAttrOfType<IntegerAttr>("ttg.global_scratch_memory_size");
auto align = func->getAttrOfType<IntegerAttr>(
"ttg.global_scratch_memory_alignment");
assert(size);
assert(align);
mod->setAttr("ttg.global_scratch_memory_size", size);
mod->setAttr("ttg.global_scratch_memory_alignment", align);
}
});
assert(seenKernel);
}
};
}