#include "triton/Analysis/Allocation.h"
#include <algorithm>
#include <limits>
#include <numeric>
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Analysis/Liveness.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Support/LLVM.h"
#include "triton/Analysis/Alias.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
using ::mlir::triton::gpu::AMDMfmaEncodingAttr;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getContigPerThread;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getShapePerCTATile;
using ::mlir::triton::gpu::getSizePerThread;
using ::mlir::triton::gpu::getUniqueContigPerThread;
using ::mlir::triton::gpu::NvidiaMmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
#define DEBUG_TYPE "allocation-analysis"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
namespace mlir {
namespace triton {
// Bit width used for pointer-typed elements when sizing scratch buffers
// (see the pointer branches in getScratchValueSize).
constexpr int kPtrBitWidth = 64;
/// Picks the dimension orders used for the input and output side of a layout
/// conversion.
///
/// If the source layout is an NVIDIA MMA or dot-operand encoding, the input
/// order is taken from the destination layout, otherwise from the source
/// layout. Symmetrically, if the destination layout is an NVIDIA MMA or
/// dot-operand encoding, the output order is taken from the source layout,
/// otherwise from the destination layout.
///
/// \param srcLayout encoding of the conversion source.
/// \param dstLayout encoding of the conversion destination.
/// \returns the pair {input order, output order}.
static std::pair<SmallVector<unsigned>, SmallVector<unsigned>>
getCvtOrder(Attribute srcLayout, Attribute dstLayout) {
  auto srcMmaLayout = mlir::dyn_cast<NvidiaMmaEncodingAttr>(srcLayout);
  auto srcDotLayout = mlir::dyn_cast<DotOperandEncodingAttr>(srcLayout);
  auto dstMmaLayout = mlir::dyn_cast<NvidiaMmaEncodingAttr>(dstLayout);
  auto dstDotLayout = mlir::dyn_cast<DotOperandEncodingAttr>(dstLayout);

  // An mma -> mma conversion is accepted when the source is Ampere or Hopper;
  // the message now names both architectures the condition lets through.
  assert(!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere() &&
           !srcMmaLayout.isHopper()) &&
         "mma -> mma layout conversion is only supported on Ampere and Hopper");

  auto inOrd = (srcMmaLayout || srcDotLayout) ? getOrder(dstLayout)
                                              : getOrder(srcLayout);
  auto outOrd = (dstMmaLayout || dstDotLayout) ? getOrder(srcLayout)
                                               : getOrder(dstLayout);
  return {inOrd, outOrd};
}
/// Computes the scratch "repetition" shape for a layout conversion from
/// \p srcTy to \p dstTy.
///
/// \returns an empty vector when cvtNeedsSharedMemory() reports that no
/// shared memory is needed; the per-CTA source shape when
/// shouldUseDistSmem() holds; otherwise, per dimension, the larger of
/// min(shapePerCTA, shapePerCTATile) for the source and for the destination.
static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
                                               RankedTensorType dstTy) {
  Attribute srcLayout = srcTy.getEncoding();
  Attribute dstLayout = dstTy.getEncoding();

  // Nothing to size when the conversion does not go through shared memory.
  if (!cvtNeedsSharedMemory(srcTy, dstTy))
    return {};

  // On the distributed shared memory path the rep shape is the full per-CTA
  // source shape.
  if (shouldUseDistSmem(srcLayout, dstLayout))
    return convertType<unsigned, int64_t>(getShapePerCTA(srcTy));

  assert(srcLayout && dstLayout && "Unexpected layout in getRepShapeForCvt()");

  auto srcPerCTA = getShapePerCTA(srcTy);
  auto dstPerCTA = getShapePerCTA(dstTy);
  auto srcTile = getShapePerCTATile(srcLayout, srcTy.getShape());
  auto dstTile = getShapePerCTATile(dstLayout, dstTy.getShape());

  unsigned rank = dstTy.getRank();
  SmallVector<unsigned> repShape(rank);
  for (unsigned dim = 0; dim != rank; ++dim) {
    // Each side contributes at most one CTA tile, clamped to the tensor.
    unsigned srcExtent = std::min<unsigned>(srcPerCTA[dim], srcTile[dim]);
    unsigned dstExtent = std::min<unsigned>(dstPerCTA[dim], dstTile[dim]);
    repShape[dim] = std::max(srcExtent, dstExtent);
  }
  return repShape;
}
/// Returns the scratch shape for an atomic op result: the single-element
/// shape {1} when atomicNeedsSharedMemory() holds for \p result, otherwise an
/// empty shape.
static SmallVector<unsigned> getRepShapeForAtomic(Value result) {
  if (!atomicNeedsSharedMemory(result))
    return {};
  // One dimension of extent one.
  return SmallVector<unsigned>(1, 1u);
}
/// Builds the scratch configuration (rep shape, padded rep shape, order and
/// input/output vector widths) for a layout conversion \p srcTy -> \p dstTy.
///
/// \returns an empty ScratchConfig when getRepShapeForCvt() yields no shape.
/// Otherwise the padded shape equals the rep shape, except that for rank > 1
/// tensors whose elements are not all along the destination's first
/// thread-order dimension, that dimension is enlarged by max(inVec, outVec).
ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
                                     RankedTensorType dstTy) {
  auto repShape = getRepShapeForCvt(srcTy, dstTy);
  if (repShape.empty())
    return ScratchConfig({}, {});

  ScratchConfig scratchConfig(repShape, repShape);
  auto rank = repShape.size();
  Attribute srcLayout = srcTy.getEncoding();
  Attribute dstLayout = dstTy.getEncoding();

  assert(!isMfmaToDotShortcut(srcTy, dstTy));

  auto srcThreadOrd = gpu::getThreadOrder(srcLayout);
  auto dstThreadOrd = gpu::getThreadOrder(dstLayout);
  scratchConfig.order = dstThreadOrd;

  unsigned srcContig = getUniqueContigPerThread(
      srcLayout, srcTy.getShape())[srcThreadOrd[0]];
  unsigned dstContig = getUniqueContigPerThread(
      dstLayout, dstTy.getShape())[dstThreadOrd[0]];

  // Vector widths above one are only used when the leading thread-order
  // dimension is the innermost tensor dimension.
  unsigned innerDim = rank - 1;
  const bool dstLeadsWithInner = dstThreadOrd[0] == innerDim;
  const bool srcLeadsWithInner = srcThreadOrd[0] == innerDim;
  scratchConfig.inVec =
      (dstLeadsWithInner && srcLeadsWithInner) ? srcContig : 1u;
  scratchConfig.outVec = dstLeadsWithInner ? dstContig : 1u;

  // NVIDIA MMA sources override the defaults chosen above.
  if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(srcLayout)) {
    if (mma.getVersionMajor() == 1)
      scratchConfig.inVec = srcContig;
    else if (mlir::isa<BlockedEncodingAttr>(dstLayout))
      scratchConfig.outVec = dstContig;
  }

  // Pad the leading output dimension unless the tensor is effectively 1-D.
  // NOTE(review): presumably the padding avoids shared-memory bank
  // conflicts - confirm.
  const bool needsPadding =
      rank > 1 && product(repShape) != repShape[dstThreadOrd[0]];
  if (needsPadding) {
    scratchConfig.paddedRepShape[dstThreadOrd[0]] +=
        std::max(scratchConfig.inVec, scratchConfig.outVec);
  }
  return scratchConfig;
}
class AllocationAnalysis {
public:
/// Stores the three pointers and immediately runs the whole analysis; all
/// work happens as a side effect of construction.
///
/// \param operation    root operation whose nested operations are walked.
/// \param funcAllocMap map consulted for the allocation of called functions.
/// \param allocation   object that receives buffers, aliases and offsets.
AllocationAnalysis(Operation *operation,
Allocation::FuncAllocMapT *funcAllocMap,
Allocation *allocation)
: operation(operation), funcAllocMap(funcAllocMap),
allocation(allocation) {
run();
}
private:
// Shorthand for the buffer record type declared by Allocation.
using BufferT = Allocation::BufferT;
// Buffer -> interval of operation ids during which the buffer is live.
using BufferRangeMapT = llvm::MapVector<BufferT *, Interval<size_t>>;
// Interference graph: buffer -> set of buffers it conflicts with.
using GraphT = DenseMap<BufferT *, DenseSet<BufferT *>>;
/// Runs the three analysis phases in order: collect buffers and their sizes,
/// compute each buffer's liveness interval, then assign offsets.
void run() {
getValuesAndSizes();
resolveLiveness();
computeOffsets();
}
/// Registers an explicit buffer for every result of \p op that is produced by
/// a LocalAllocOp reporting isSharedMemoryAlloc().
///
/// The size is the per-CTA element count times the element bit width, in
/// bytes. An integer "allocation.shareGroup" attribute on the alloc, when
/// present, is forwarded as the buffer's sharing group; otherwise the group
/// is -1.
void getExplicitValueSize(Operation *op) {
  for (Value result : op->getResults()) {
    auto alloc = result.getDefiningOp<triton::gpu::LocalAllocOp>();
    if (!alloc || !alloc.isSharedMemoryAlloc())
      continue;

    auto allocType = alloc.getType();
    auto shapePerCTA = triton::gpu::getShapePerCTA(allocType);
    auto bytes = product<int64_t>(shapePerCTA) *
                 allocType.getElementTypeBitWidth() / 8;
    auto alignment = alloc.getAlignmentOrDefault();

    LLVM_DEBUG({
      llvm::dbgs() << "check localAlloc in getExplicitValueSize: ";
      alloc.dump();
    });

    // -1 means "not part of any sharing group".
    int sharingGroup = -1;
    if (Attribute groupAttr = alloc->getAttr("allocation.shareGroup")) {
      sharingGroup = mlir::cast<IntegerAttr>(groupAttr).getInt();
      LDBG("with shareGroup of " << sharingGroup);
    }

    allocation->addBuffer<BufferT::BufferKind::Explicit>(
        result, bytes, alignment, 0, sharingGroup);
  }
}
/// Adds a buffer of kind \p T for \p op with the given size and alignment;
/// does nothing when \p bytes is zero.
template <BufferT::BufferKind T>
void maybeAddScratchBuffer(Operation *op, unsigned bytes,
                           unsigned alignment) {
  if (bytes == 0)
    return;
  allocation->addBuffer<T>(op, bytes, alignment);
}

/// Same as above, but leaves the alignment to addBuffer's default.
template <BufferT::BufferKind T>
void maybeAddScratchBuffer(Operation *op, unsigned bytes) {
  if (bytes == 0)
    return;
  allocation->addBuffer<T>(op, bytes);
}
/// Registers the scratch (or virtual) buffer that \p op needs, if any.
///
/// Handled operations: reduce, scan, histogram, convert_layout, atomic
/// RMW/CAS on non-tensor operands, calls (a virtual buffer sized by the
/// callee's shared memory size) and tensormap creation. Any other operation
/// is ignored. Zero-sized requests never create a buffer.
void getScratchValueSize(Operation *op) {
  const size_t scratchAlignment = 128;

  if (auto reduceOp = dyn_cast<triton::ReduceOp>(op)) {
    ReduceOpHelper helper(reduceOp);
    unsigned bytes = helper.getScratchSizeInBytes();
    maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
                                                        scratchAlignment);
    return;
  }

  if (auto scanOp = dyn_cast<triton::ScanOp>(op)) {
    ScanLoweringHelper helper(scanOp);
    unsigned bytes = helper.getScratchSizeInBytes();
    maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
                                                        scratchAlignment);
    return;
  }

  if (auto histogram = dyn_cast<triton::HistogramOp>(op)) {
    auto dstTy = histogram.getType();
    int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(
        op->getParentOfType<ModuleOp>());
    // At least one slot per thread of a warp, at least one byte per slot.
    int slots = std::max<int>(dstTy.getNumElements(), threadsPerWarp);
    int bitsPerSlot = std::max<int>(8, dstTy.getElementTypeBitWidth());
    auto bytes = slots * bitsPerSlot / 8;
    maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
                                                        scratchAlignment);
    return;
  }

  if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
    auto srcTy = cvtLayout.getSrc().getType();
    auto dstTy = cvtLayout.getType();
    // Conversions with a shared encoding on either side get no scratch.
    if (mlir::isa<SharedEncodingAttr>(srcTy.getEncoding()) ||
        mlir::isa<SharedEncodingAttr>(dstTy.getEncoding()))
      return;

    auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
    auto elems = getNumScratchElements(scratchConfig.paddedRepShape);
    // Pointers use kPtrBitWidth; everything else uses at least 8 bits.
    const int bitsPerElem =
        isa<triton::PointerType>(srcTy.getElementType())
            ? kPtrBitWidth
            : std::max<int>(8, srcTy.getElementTypeBitWidth());
    auto bytes = elems * bitsPerElem / 8;
    maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
                                                        scratchAlignment);
    return;
  }

  if (isa<triton::AtomicRMWOp, triton::AtomicCASOp>(op)) {
    auto value = op->getOperand(0);
    // Tensor-typed atomics need no scratch buffer here.
    if (isa<RankedTensorType>(value.getType()))
      return;

    auto smemShape = getRepShapeForAtomic(op->getResult(0));
    auto elems = getNumScratchElements(smemShape);
    auto elemTy = cast<triton::PointerType>(value.getType()).getPointeeType();
    const int bitsPerElem =
        isa<triton::PointerType>(elemTy)
            ? kPtrBitWidth
            : std::max<int>(8, elemTy.getIntOrFloatBitWidth());
    auto bytes = elems * bitsPerElem / 8;
    maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
                                                        scratchAlignment);
    return;
  }

  if (auto callOp = dyn_cast<CallOpInterface>(op)) {
    // Reserve a virtual buffer as large as the callee's shared memory.
    auto callable = callOp.resolveCallable();
    auto funcOp = dyn_cast<FunctionOpInterface>(callable);
    auto *funcAlloc = &(*funcAllocMap)[funcOp];
    auto bytes = funcAlloc->getSharedMemorySize();
    maybeAddScratchBuffer<BufferT::BufferKind::Virtual>(op, bytes,
                                                        scratchAlignment);
    return;
  }

  if (isa<ExperimentalTensormapCreateOp>(op)) {
    constexpr int32_t kTMASize = 128;
    constexpr int32_t kTMAAlign = 128;
    maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, kTMASize,
                                                        kTMAAlign);
  }
}
/// Records, for \p value, an alias to every allocation that \p analysis
/// associates with it. Values without a lattice element are skipped.
void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) {
  dataflow::Lattice<AliasInfo> *lattice = analysis.getLatticeElement(value);
  if (!lattice)
    return;

  AliasInfo &aliasInfo = lattice->getValue();
  // An empty alloc set simply results in no addAlias calls.
  for (auto alloc : aliasInfo.getAllocs())
    allocation->addAlias(value, alloc);
}
void getValuesAndSizes() {
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
getExplicitValueSize(op);
getScratchValueSize(op);
});
LDBG("getValuesAndSizes --");
for (auto valueBufferIter : allocation->valueBuffer) {
auto *buffer = valueBufferIter.second;
LLVM_DEBUG(llvm::dbgs()
<< "-- buffer " << buffer->id << " " << buffer->size << " "
<< buffer->offset << " " << buffer->sharingGroup << "\n");
}
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
SharedMemoryAliasAnalysis *aliasAnalysis =
solver->load<SharedMemoryAliasAnalysis>();
if (failed(solver->initializeAndRun(operation))) {
llvm_unreachable("failed to run SharedMemoryAliasAnalysis");
}
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
for (auto operand : op->getOperands()) {
getValueAlias(operand, *aliasAnalysis);
}
for (auto value : op->getResults()) {
getValueAlias(value, *aliasAnalysis);
}
});
}
/// Sets the liveness interval of every explicit (value-backed) buffer to
/// what \p getLiveness reports for its value, overwriting any earlier entry.
void resolveExplicitBufferLiveness(
    function_ref<Interval<size_t>(Value value, BufferT *buffer)>
        getLiveness) {
  for (const auto &entry : allocation->valueBuffer) {
    Value definingValue = entry.first;
    BufferT *buffer = entry.second;
    bufferRange[buffer] = getLiveness(definingValue, buffer);
  }
}
/// Extends the interval of every buffer aliased by a value so that it also
/// covers that value's liveness.
///
/// The value's liveness is queried once, passing the first aliased buffer to
/// \p getLiveness, and the resulting interval is merged (min start, max end)
/// into each aliased buffer's existing interval.
void resolveAliasBufferLiveness(
    function_ref<Interval<size_t>(Value value, BufferT *buffer)>
        getLiveness) {
  for (const auto &aliasEntry : allocation->aliasBuffer) {
    const auto &aliasedBuffers = aliasEntry.second;
    auto valueRange = getLiveness(aliasEntry.first, aliasedBuffers.front());

    for (auto *buffer : aliasedBuffers) {
      auto lowId = valueRange.start();
      auto highId = valueRange.end();
      // Union with whatever interval the buffer already has.
      auto known = bufferRange.find(buffer);
      if (known != bufferRange.end()) {
        lowId = std::min(lowId, known->second.start());
        highId = std::max(highId, known->second.end());
      }
      bufferRange[buffer] = Interval(lowId, highId);
    }
  }
}
void resolveScratchBufferLiveness(
const DenseMap<Operation *, size_t> &operationId) {
auto processScratchMemory = [&](const auto &container) {
for (auto opScratchIter : container) {
auto *op = opScratchIter.first;
auto *buffer = opScratchIter.second;
if (getAsyncTaskIds(op).empty()) {
bufferRange.insert({buffer, Interval(operationId.lookup(op),
operationId.lookup(op) + 1)});
} else {
for (auto tId : getAsyncTaskIds(op))
buffer->regionIds.insert(tId);
bufferRange.insert({buffer, Interval(operationId.lookup(op),
operationId.lookup(op) + 1)});
}
}
};
processScratchMemory(allocation->opScratch);
processScratchMemory(allocation->opVirtual);
}
void resolveLiveness() {
DenseMap<Operation *, size_t> operationId;
operation->walk<WalkOrder::PostOrder>(
[&](Operation *op) { operationId[op] = operationId.size(); });
Liveness liveness(operation);
auto getValueLivenessRange = [&](Value value, BufferT *buffer) {
auto liveOperations = liveness.resolveLiveness(value);
std::for_each(liveOperations.begin(), liveOperations.end(),
[&](Operation *liveOp) {
for (auto rId : getAsyncTaskIds(liveOp)) {
buffer->regionIds.insert(rId);
}
});
auto minId = std::numeric_limits<size_t>::max();
auto maxId = std::numeric_limits<size_t>::min();
std::for_each(
liveOperations.begin(), liveOperations.end(), [&](Operation *liveOp) {
if (buffer->regionIds.size() > 1 || buffer->sharingGroup >= 0) {
minId = 0;
maxId = operationId.size();
return;
}
if (operationId[liveOp] < minId) {
minId = operationId[liveOp];
}
if ((operationId[liveOp] + 1) > maxId) {
maxId = operationId[liveOp] + 1;
}
});
return Interval(minId, maxId);
};
resolveExplicitBufferLiveness(getValueLivenessRange);
resolveAliasBufferLiveness(getValueLivenessRange);
resolveScratchBufferLiveness(operationId);
}
void dumpBuffers() {
LDBG("Dump bufferRange: id size offset sharingGroup ---------");
for (auto bufferIter : bufferRange) {
LLVM_DEBUG({
llvm::dbgs() << "-- " << bufferIter.first->id << " "
<< bufferIter.first->size << " "
<< bufferIter.first->offset << " "
<< bufferIter.first->sharingGroup << " regions [";
for (auto tId : bufferIter.first->regionIds) {
llvm::dbgs() << tId << " ";
}
llvm::dbgs() << "] interval " << bufferIter.second.start() << " "
<< bufferIter.second.end() << "\n";
});
}
}
void computeOffsets() {
SmallVector<BufferT *> buffers;
DenseMap<int, SmallVector<BufferT *>> toGroup;
for (auto bufferIter : bufferRange) {
if (bufferIter.first->sharingGroup >= 0)
toGroup[bufferIter.first->sharingGroup].push_back(bufferIter.first);
}
DenseMap<int, BufferT *> sharingIdToRep;
for (auto &kv : toGroup) {
size_t bigSize = 0;
BufferT *rep = nullptr;
for (auto *buf : kv.second) {
if (buf->size > bigSize) {
rep = buf;
bigSize = buf->size;
}
}
sharingIdToRep[kv.first] = rep;
}
for (auto bufferIter : bufferRange) {
if (sharingIdToRep.find(bufferIter.first->sharingGroup) !=
sharingIdToRep.end()) {
if (bufferIter.first !=
sharingIdToRep[bufferIter.first->sharingGroup]) {
LDBG("-- ignore shared buffer " << bufferIter.first->size << " "
<< bufferIter.first->offset << " "
<< bufferIter.first->sharingGroup);
continue;
}
}
buffers.emplace_back(bufferIter.first);
}
calculateStarts(buffers);
dumpBuffers();
GraphT interference;
buildInterferenceGraph(buffers, interference);
do {
allocate(buffers, interference);
buildInterferenceGraph(buffers, interference);
} while (!interference.empty());
for (auto &kv : toGroup) {
auto *rep = sharingIdToRep[kv.first];
for (auto *buf : kv.second) {
if (buf != rep) {
buf->setOffsetAligned(rep->offset);
LDBG("-- set sharing buffer's offset "
<< buf->size << " " << buf->offset << " " << buf->sharingGroup);
}
}
}
dumpBuffers();
}
/// Computes an initial offset for every buffer in \p buffers.
///
/// A multimap of (offset -> liveness interval) entries is kept; the entry
/// with the smallest offset is popped on each iteration and a buffer that
/// fits it is placed at that offset (rounded by setOffsetAligned).
void calculateStarts(const SmallVector<BufferT *> &buffers) {
// Key: candidate start offset. Value: liveness interval tied to that offset.
using TripleMapT = std::multimap<size_t, Interval<size_t>>;
TripleMapT tripleMap;
// Seed with offset 0 and a default-constructed interval.
// NOTE(review): presumably the default Interval covers the whole id range -
// confirm against Interval's definition.
tripleMap.insert(std::make_pair(0, Interval<size_t>()));
// Work list of buffers that still have no offset.
SmallVector<BufferT *> xBuffers = buffers;
while (!xBuffers.empty()) {
// Pop the entry with the smallest offset.
// NOTE(review): tripleMap is not checked for emptiness before begin() is
// dereferenced; presumably an entry is always available here - confirm.
auto tripleIt = tripleMap.begin();
auto offset = tripleIt->first;
auto range = tripleIt->second;
tripleMap.erase(tripleIt);
// Pick the first pending buffer whose liveness intersects the popped
// interval and intersects none of the intervals still in the map.
auto bufferIt =
std::find_if(xBuffers.begin(), xBuffers.end(), [&](auto *buffer) {
auto xRange = bufferRange[buffer];
bool res = xRange.intersects(range);
for (auto val : tripleMap)
res = res &&
!val.second.intersects(xRange);
return res;
});
// If no buffer qualifies, the popped entry is simply dropped.
if (bufferIt != xBuffers.end()) {
auto buffer = *bufferIt;
auto xSize = buffer->size;
auto xRange = bufferRange.lookup(buffer);
size_t alignOffset = buffer->setOffsetAligned(offset);
// Space above the placed buffer, for the overlap of both intervals.
tripleMap.insert({alignOffset + xSize,
Interval{std::max(range.start(), xRange.start()),
std::min(range.end(), xRange.end())}});
// Re-offer the same offset for the parts of the popped interval that
// extend before or after the buffer's interval.
if (range.start() < xRange.start())
tripleMap.insert({offset, Interval{range.start(), xRange.end()}});
if (xRange.end() < range.end())
tripleMap.insert({offset, Interval{xRange.start(), range.end()}});
xBuffers.erase(bufferIt);
}
}
}
void buildInterferenceGraph(const SmallVector<BufferT *> &buffers,
GraphT &interference) {
auto inDifferentRegion = [&](BufferT *A, BufferT *B) {
auto tA = A->regionIds;
auto tB = B->regionIds;
if (tA.empty() && tB.empty())
return false;
if (tA.empty() || tB.empty())
return true;
for (auto t1 : tA) {
for (auto t2 : tB) {
if (t1 != t2)
return true;
}
}
return false;
};
interference.clear();
for (auto x : buffers) {
for (auto y : buffers) {
if (x == y)
continue;
auto xStart = x->offset;
auto yStart = y->offset;
auto xSize = x->size;
auto ySize = y->size;
Interval xSizeRange = {xStart, xStart + xSize};
Interval ySizeRange = {yStart, yStart + ySize};
auto xOpRange = bufferRange.lookup(x);
auto yOpRange = bufferRange.lookup(y);
if (xOpRange.intersects(yOpRange) &&
xSizeRange.intersects(ySizeRange)) {
interference[x].insert(y);
}
if (inDifferentRegion(x, y) && xSizeRange.intersects(ySizeRange))
interference[x].insert(y);
}
}
}
/// Resolves the conflicts recorded in \p interference by colouring the graph
/// and moving buffers, then recomputes allocation->sharedMemorySize.
void allocate(const SmallVector<BufferT *> &buffers,
const GraphT &interference) {
// Recomputed below as the maximum of offset + size over all buffers.
allocation->sharedMemorySize = 0;
// Initial colouring: the first buffer gets colour 0, all others -1.
DenseMap<BufferT *, int> colors;
for (auto value : buffers) {
colors[value] = (value == buffers[0]) ? 0 : -1;
}
// First-fit colouring: each buffer takes the lowest colour not used by an
// already-coloured neighbour.
SmallVector<bool> available(buffers.size());
for (auto x : buffers) {
std::fill(available.begin(), available.end(), true);
for (auto y : interference.lookup(x)) {
int color = colors[y];
if (color >= 0) {
available[color] = false;
}
}
auto it = std::find(available.begin(), available.end(), true);
colors[x] = std::distance(available.begin(), it);
}
// Buffers with colour 0 keep their offset. Every other buffer is moved past
// the end of all its neighbours (offset 0 when it has none), rounded by
// setOffsetAligned.
for (auto x : buffers) {
size_t newOffset = 0;
for (auto y : interference.lookup(x)) {
newOffset = std::max(newOffset, y->offset + y->size);
}
if (colors.lookup(x) != 0)
x->setOffsetAligned(newOffset);
allocation->sharedMemorySize =
std::max(allocation->sharedMemorySize, x->offset + x->size);
}
}
private:
// Root operation whose nested operations are analysed.
Operation *operation;
// Per-function allocations, looked up for call operations.
Allocation::FuncAllocMapT *funcAllocMap;
// Destination of all buffers, aliases, offsets and the total size.
Allocation *allocation;
// Liveness interval (in operation ids) of every known buffer.
BufferRangeMapT bufferRange;
};
}
/// Runs the allocation analysis on this allocation's operation.
/// The AllocationAnalysis temporary does all its work in its constructor.
void Allocation::run(FuncAllocMapT &funcAllocMap) {
triton::AllocationAnalysis(getOperation(), &funcAllocMap, this);
}
/// Builds a map from each operation to the ids of the buffers associated
/// with it.
///
/// An operation's own buffer id (from getBufferId(op)) is listed under that
/// operation. A result-backed buffer is listed under every operation that
/// Liveness::resolveLiveness reports for the result.
std::map<Operation *, SmallVector<Allocation::BufferId>>
Allocation::getLiveBuffers() {
  std::map<Operation *, SmallVector<BufferId>> liveBuffers;

  Operation *rootOperation = getOperation();
  mlir::Liveness liveness(rootOperation);

  rootOperation->walk([&](Operation *op) {
    // Buffer owned by the operation itself, if it has one.
    if (auto ownId = getBufferId(op); ownId != InvalidBufferId)
      liveBuffers[op].push_back(ownId);

    // Buffers backing the operation's results.
    for (auto result : op->getOpResults()) {
      auto resultId = getBufferId(result);
      if (resultId == Allocation::InvalidBufferId)
        continue;
      for (auto liveOp : liveness.resolveLiveness(result))
        liveBuffers[liveOp].push_back(resultId);
    }
  });
  return liveBuffers;
}
}