#include "triton/Analysis/Allocation.h"
#include <algorithm>
#include <limits>
#include "mlir/Analysis/Liveness.h"
#include "mlir/Support/LLVM.h"
#include "triton/Analysis/Alias.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Tools/GenericSwizzling.h"
#include "triton/Tools/LayoutUtils.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "allocation-shared-memory"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
namespace ttng = mlir::triton::nvidia_gpu;
namespace mlir {
namespace triton {
/// Number of scratch elements a swizzled layout conversion from `srcTy` to
/// `dstTy` needs.
///
/// Broadcasted register bases are stripped from both linear layouts first. The
/// shared-memory layout chosen by `gpu::optimalSwizzlingLdSt` is then sized as
/// its total output size divided by the size of its "reps" input dimension.
unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
                                       RankedTensorType dstTy) {
  auto dropBroadcastRegs = [](const auto &layout) {
    return actionRemoveBroadcastedRegs(layout).apply(layout);
  };
  auto srcLL = dropBroadcastRegs(gpu::toLinearLayout(srcTy));
  auto dstLL = dropBroadcastRegs(gpu::toLinearLayout(dstTy));
  auto smemLL = gpu::optimalSwizzlingLdSt(srcLL, dstLL, getBitwidth(srcTy));
  auto kReps = StringAttr::get(srcTy.getContext(), "reps");
  auto numReps = smemLL.getInDimSize(kReps);
  return smemLL.getTotalOutDimSize() / numReps;
}
/// Shape of the shared-memory area an atomic's `result` needs. An empty vector
/// means no shared memory is needed:
///  - results with no uses need nothing;
///  - non-tensor (scalar) results need a single element;
///  - tensor results need their per-CTA shape only when their linear layout
///    reports a non-zero free-variable mask (presumably broadcast/replicated
///    data); otherwise they need nothing.
static SmallVector<unsigned> getRepShapeForAtomic(Value result) {
  if (result.use_empty())
    return {};
  auto tensorTy = dyn_cast<RankedTensorType>(result.getType());
  if (!tensorTy)
    return {1};
  auto masks = gpu::toLinearLayout(tensorTy).getFreeVariableMasks();
  bool hasFreeVariables =
      llvm::any_of(masks, [](const auto &entry) { return entry.second != 0; });
  if (!hasFreeVariables)
    return {};
  return convertType<unsigned>(gpu::getShapePerCTA(tensorTy));
}
/// Default estimate, in bytes, of the scratch shared memory `op` needs.
/// Returns 0 for ops that need none.
unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
  // Reduce / scan / gather: the lowering helpers report their own size.
  if (auto reduce = dyn_cast<ReduceOp>(op))
    return ReduceOpHelper(reduce).getScratchSizeInBytes();
  if (auto scan = dyn_cast<ScanOp>(op))
    return ScanLoweringHelper(scan).getScratchSizeInBytes();
  if (auto gather = dyn_cast<GatherOp>(op))
    return GatherLoweringHelper(gather).getScratchSizeInBytes();

  // Histogram: at least one slot per thread of a warp, each slot sized to the
  // result element width.
  if (auto histogram = dyn_cast<HistogramOp>(op)) {
    auto resultTy = histogram.getType();
    auto moduleOp = op->getParentOfType<ModuleOp>();
    int warpSize = gpu::TritonGPUDialect::getThreadsPerWarp(moduleOp);
    int numSlots = std::max<int>(resultTy.getNumElements(), warpSize);
    return numSlots * getBitwidth(resultTy) / 8;
  }

  // Layout conversion: only when the conversion actually goes through shared
  // memory, sized by the swizzled-conversion helper.
  if (auto cvt = dyn_cast<gpu::ConvertLayoutOp>(op)) {
    auto srcTy = cvt.getSrc().getType();
    auto dstTy = cvt.getType();
    if (!cvtNeedsSharedMemory(srcTy, dstTy))
      return 0;
    return getNumScratchElemsSwizzledCvt(srcTy, dstTy) * getBitwidth(srcTy) /
           8;
  }

  // Atomics: sized from getRepShapeForAtomic. Each element occupies at least
  // one byte.
  if (isa<AtomicRMWOp, AtomicCASOp>(op)) {
    auto numElems =
        getNumScratchElements(getRepShapeForAtomic(op->getResult(0)));
    if (numElems == 0)
      return 0;
    Value ptr = op->getOperand(0);
    auto elemTy = getElementTypeOrSelf(getPointeeType(ptr.getType()));
    return numElems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
  }

  // Tensormap creation reserves a fixed 128-byte scratch slot.
  if (isa<ttng::TensormapCreateOp>(op)) {
    constexpr int32_t kTMASize = 128;
    return kTMASize;
  }

  return 0;
}
class AllocationAnalysis {
public:
AllocationAnalysis(Operation *operation,
Allocation::FuncAllocMapT *funcAllocMap,
Allocation *allocation,
AllocationAnalysisScratchSizeFn scratchSizeGetter)
: operation(operation), funcAllocMap(funcAllocMap),
allocation(allocation), scratchSizeGetter(scratchSizeGetter) {
run();
}
private:
using BufferT = Allocation::BufferT;
using BufferRangeMapT = llvm::MapVector<BufferT *, Interval<size_t>>;
using GraphT = DenseMap<BufferT *, DenseSet<BufferT *>>;
void run() {
getValuesAndSizes();
resolveLiveness();
computeOffsets();
}
void getExplicitValueSize(Operation *op) {
auto alloc = dyn_cast<gpu::LocalAllocOp>(op);
if (!alloc || !alloc.isSharedMemoryAlloc())
return;
auto allocType = alloc.getType();
int64_t numElems = 0;
if (auto paddedEnc =
dyn_cast<gpu::PaddedSharedEncodingAttr>(allocType.getEncoding())) {
SmallVector<int64_t> unpaddedShape = gpu::getShapePerCTA(allocType);
numElems = paddedEnc.getPaddedSize(unpaddedShape);
} else {
auto shapePerCTA = gpu::getAllocationShapePerCTA(allocType);
numElems = product<int64_t>(shapePerCTA);
}
int64_t bytes = numElems * allocType.getElementTypeBitWidth() / 8;
auto alignment = alloc.getAlignmentOrDefault();
allocation->addBuffer<BufferT::BufferKind::Explicit>(alloc, bytes,
alignment);
}
template <BufferT::BufferKind T>
void maybeAddScratchBuffer(Operation *op, unsigned bytes,
unsigned alignment) {
if (bytes > 0)
allocation->addBuffer<T>(op, bytes, alignment);
}
template <BufferT::BufferKind T>
void maybeAddScratchBuffer(Operation *op, unsigned bytes) {
if (bytes > 0)
allocation->addBuffer<T>(op, bytes);
}
void getScratchValueSize(Operation *op) {
constexpr size_t scratchAlignment = 128;
if (auto callOp = dyn_cast<CallOpInterface>(op)) {
auto callable = callOp.resolveCallable();
auto funcOp = dyn_cast<FunctionOpInterface>(callable);
auto *funcAlloc = &(*funcAllocMap)[funcOp];
auto bytes = funcAlloc->getSharedMemorySize();
maybeAddScratchBuffer<BufferT::BufferKind::Virtual>(op, bytes,
scratchAlignment);
return;
}
if (auto ws = dyn_cast<gpu::WarpSpecializeOp>(op)) {
auto [captureSize, captureAlign] = ws.getCaptureSizeAlign();
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, captureSize,
captureAlign);
return;
}
if (auto func = dyn_cast<FunctionOpInterface>(op)) {
unsigned numWarpIndices = 0;
func.walk([&](gpu::WarpSpecializeOp op) {
numWarpIndices = std::max(numWarpIndices, op.getTotalPartitionWarps());
});
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, numWarpIndices);
return;
}
unsigned bytes = scratchSizeGetter(op);
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
}
void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) {
dataflow::Lattice<AliasInfo> *latticeElement =
analysis.getLatticeElement(value);
if (latticeElement) {
AliasInfo &info = latticeElement->getValue();
if (!info.getAllocs().empty()) {
for (auto alloc : info.getAllocs()) {
allocation->addAlias(value, alloc);
}
}
}
}
void getValuesAndSizes() {
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
getExplicitValueSize(op);
getScratchValueSize(op);
});
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
SharedMemoryAliasAnalysis *aliasAnalysis =
solver->load<SharedMemoryAliasAnalysis>();
operation->walk<mlir::WalkOrder::PreOrder>([&](Operation *op) {
if (op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
failed(solver->initializeAndRun(op))) {
llvm_unreachable("failed to run SharedMemoryAliasAnalysis");
}
});
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
for (auto operand : op->getOperands()) {
getValueAlias(operand, *aliasAnalysis);
}
for (auto value : op->getResults()) {
getValueAlias(value, *aliasAnalysis);
}
});
}
void resolveExplicitBufferLiveness(
function_ref<Interval<size_t>(Value value)> getLiveness) {
for (auto valueBufferIter : allocation->valueBuffer) {
auto value = valueBufferIter.first;
auto *buffer = valueBufferIter.second;
bufferRange[buffer] = getLiveness(value);
LLVM_DEBUG({
llvm::dbgs() << "-- buffer " << buffer->id << "; value: ";
value.dump();
});
}
}
void resolveAliasBufferLiveness(
function_ref<Interval<size_t>(Value value)> getLiveness) {
for (const auto &[value, buffers] : allocation->aliasBuffer) {
auto range = getLiveness(value);
for (auto *buffer : buffers) {
auto minId = range.start();
auto maxId = range.end();
if (bufferRange.count(buffer)) {
minId = std::min(minId, bufferRange[buffer].start());
maxId = std::max(maxId, bufferRange[buffer].end());
}
bufferRange[buffer] = Interval(minId, maxId);
}
}
}
void resolveScratchBufferLiveness(
const DenseMap<Operation *, size_t> &operationId) {
auto processScratchMemory = [&](const auto &container) {
for (auto [op, buffer] : container) {
if (op == operation) {
bufferRange.insert(
{buffer, Interval(size_t(), std::numeric_limits<size_t>::max())});
continue;
}
bufferRange.insert(
{buffer, Interval(operationId.at(op), operationId.at(op) + 1)});
LLVM_DEBUG({
llvm::dbgs() << "-- buffer " << buffer->id << "; value: ";
op->dump();
});
}
};
processScratchMemory(allocation->opScratch);
processScratchMemory(allocation->opVirtual);
}
void resolveLiveness() {
DenseMap<Operation *, size_t> operationId;
operation->walk<WalkOrder::PostOrder>(
[&](Operation *op) { operationId[op] = operationId.size(); });
Liveness liveness(operation);
auto getValueLivenessRange = [&](Value value) {
auto liveOperations = liveness.resolveLiveness(value);
auto minId = std::numeric_limits<size_t>::max();
auto maxId = std::numeric_limits<size_t>::min();
llvm::for_each(liveOperations, [&](Operation *liveOp) {
if (operationId[liveOp] < minId) {
minId = operationId[liveOp];
}
if ((operationId[liveOp] + 1) > maxId) {
maxId = operationId[liveOp] + 1;
}
});
return Interval(minId, maxId);
};
resolveExplicitBufferLiveness(getValueLivenessRange);
resolveAliasBufferLiveness(getValueLivenessRange);
resolveScratchBufferLiveness(operationId);
}
void dumpBuffers() const {
LDBG("Dump bufferRange: id size offset ---------");
for (auto bufferIter : bufferRange) {
llvm::dbgs() << "-- " << bufferIter.first->id << " "
<< bufferIter.first->size << " " << bufferIter.first->offset;
llvm::dbgs() << " interval " << bufferIter.second.start() << " "
<< bufferIter.second.end() << "\n";
}
}
void dumpAllocationSize() const {
LDBG("Dump shared memory allocation size -----------");
auto liveBuffers = allocation->getLiveBuffers();
auto analyzedSize = 0;
for (auto [op, bufferIds] : liveBuffers) {
auto size = 0;
for (auto bufferId : bufferIds) {
auto bufferSize = allocation->getAllocatedSize(bufferId);
size += bufferSize;
}
analyzedSize = std::max(analyzedSize, size);
}
llvm::dbgs() << "Allocated: " << allocation->sharedMemorySize
<< ", analyzed: " << analyzedSize << "\n";
}
void dumpInterferenceGraph(const GraphT &interference) const {
LDBG("\n");
LDBG("Dump interference graph: \n");
for (auto edges : interference) {
llvm::dbgs() << "-- from " << edges.first->id << " to ";
for (auto node : edges.second) {
llvm::dbgs() << node->id << "; ";
}
llvm::dbgs() << "\n";
}
}
void computeOffsets() {
SmallVector<BufferT *> buffers;
for (auto bufferIter : bufferRange) {
buffers.emplace_back(bufferIter.first);
}
llvm::stable_sort(
buffers, [&](BufferT *A, BufferT *B) { return A->size > B->size; });
calculateStarts(buffers);
GraphT interference;
buildInterferenceGraph(buffers, interference);
do {
allocate(buffers, interference);
buildInterferenceGraph(buffers, interference);
} while (!interference.empty());
LLVM_DEBUG(dumpAllocationSize());
}
void calculateStarts(const SmallVector<BufferT *> &buffers) {
using TripleMapT = std::multimap<size_t, Interval<size_t>>;
TripleMapT tripleMap;
tripleMap.insert(std::make_pair(0, Interval<size_t>()));
SmallVector<BufferT *> xBuffers = buffers;
while (!xBuffers.empty()) {
auto tripleIt = tripleMap.begin();
auto offset = tripleIt->first;
auto range = tripleIt->second;
tripleMap.erase(tripleIt);
auto bufferIt =
std::find_if(xBuffers.begin(), xBuffers.end(), [&](auto *buffer) {
auto xRange = bufferRange[buffer];
bool res = xRange.intersects(range);
for (const auto &val : tripleMap)
res = res &&
!val.second.intersects(xRange);
return res;
});
if (bufferIt != xBuffers.end()) {
auto buffer = *bufferIt;
auto xSize = buffer->size;
auto xRange = bufferRange.lookup(buffer);
size_t alignOffset = buffer->setOffsetAligned(offset);
tripleMap.insert({alignOffset + xSize,
Interval{std::max(range.start(), xRange.start()),
std::min(range.end(), xRange.end())}});
if (range.start() < xRange.start())
tripleMap.insert({offset, Interval{range.start(), xRange.end()}});
if (xRange.end() < range.end())
tripleMap.insert({offset, Interval{xRange.start(), range.end()}});
xBuffers.erase(bufferIt);
}
}
LLVM_DEBUG(dumpBuffers());
}
void buildInterferenceGraph(const SmallVector<BufferT *> &buffers,
GraphT &interference) {
interference.clear();
for (auto x : buffers) {
for (auto y : buffers) {
if (x == y)
continue;
auto xStart = x->offset;
auto yStart = y->offset;
auto xSize = x->size;
auto ySize = y->size;
Interval xSizeRange = {xStart, xStart + xSize};
Interval ySizeRange = {yStart, yStart + ySize};
auto xOpRange = bufferRange.lookup(x);
auto yOpRange = bufferRange.lookup(y);
if (xOpRange.intersects(yOpRange) &&
xSizeRange.intersects(ySizeRange)) {
interference[x].insert(y);
}
auto wsx = x->owner->getParentWithTrait<OpTrait::AsyncRegions>();
auto wsy = y->owner->getParentWithTrait<OpTrait::AsyncRegions>();
if (wsx && wsy && wsx == wsy &&
x->owner->getParentRegion() != y->owner->getParentRegion() &&
xSizeRange.intersects(ySizeRange)) {
interference[x].insert(y);
}
}
}
LLVM_DEBUG(dumpInterferenceGraph(interference));
}
void allocate(const SmallVector<BufferT *> &buffers,
const GraphT &interference) {
allocation->sharedMemorySize = 0;
DenseMap<BufferT *, int> colors;
for (auto value : buffers) {
colors[value] = (value == buffers[0]) ? 0 : -1;
}
SmallVector<bool> available(buffers.size());
for (auto x : buffers) {
std::fill(available.begin(), available.end(), true);
for (auto y : interference.lookup(x)) {
int color = colors[y];
if (color >= 0) {
available[color] = false;
}
}
auto it = std::find(available.begin(), available.end(), true);
colors[x] = std::distance(available.begin(), it);
LLVM_DEBUG({
llvm::dbgs() << "-- color " << x->id << " " << colors[x] << "\n";
});
}
for (auto x : buffers) {
size_t newOffset = 0;
for (auto y : interference.lookup(x)) {
newOffset = std::max(newOffset, y->offset + y->size);
}
if (colors.lookup(x) != 0)
x->setOffsetAligned(newOffset);
allocation->sharedMemorySize =
std::max(allocation->sharedMemorySize, x->offset + x->size);
}
LLVM_DEBUG(dumpBuffers());
}
private:
Operation *operation;
Allocation::FuncAllocMapT *funcAllocMap;
Allocation *allocation;
BufferRangeMapT bufferRange;
AllocationAnalysisScratchSizeFn scratchSizeGetter;
};
}
/// Runs the shared-memory allocation analysis rooted at this allocation's
/// operation. AllocationAnalysis does all its work in its constructor and
/// writes its results into `*this`, so the analysis object is not kept.
void Allocation::run(
    FuncAllocMapT &funcAllocMap,
    triton::AllocationAnalysisScratchSizeFn scratchSizeGetter) {
  Operation *root = getOperation();
  triton::AllocationAnalysis analysis(root, &funcAllocMap, this,
                                      scratchSizeGetter);
  (void)analysis;
}
std::map<Operation *, SmallVector<Allocation::BufferId>>
Allocation::getLiveBuffers() {
std::map<Operation *, SmallVector<BufferId>> liveBuffers;
Operation *rootOperation = getOperation();
Liveness liveness(rootOperation);
auto analyzeOperation = [&](Operation *op) -> void {
auto scratchBuffer = getBufferId(op);
if (scratchBuffer != InvalidBufferId)
liveBuffers[op].push_back(scratchBuffer);
for (auto result : op->getOpResults()) {
auto bufferId = getBufferId(result);
if (bufferId == Allocation::InvalidBufferId)
continue;
auto liveOperations = liveness.resolveLiveness(result);
for (auto depOp : liveOperations)
liveBuffers[depOp].push_back(bufferId);
}
};
rootOperation->walk(analyzeOperation);
return liveBuffers;
}
}