#ifndef TRITON_ANALYSIS_ALLOCATION_H
#define TRITON_ANALYSIS_ALLOCATION_H
#include "triton/Analysis/Utility.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/raw_ostream.h"
#include <limits>
namespace mlir {
// Declarations in the triton namespace that the allocation classes below refer
// to. Nothing here is defined in this header.
namespace triton {
// Forward declaration only; Allocation names this class as a friend.
class AllocationAnalysis;
// Callback type: takes an operation and returns an unsigned value.
// Presumably the scratch size that operation needs (units are not visible
// here) - TODO confirm against the definition of the default callback.
using AllocationAnalysisScratchSizeFn = std::function<unsigned(Operation *)>;
// Callback used by ModuleAllocation when the caller does not supply one.
unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op);
// NOTE(review): judging by the name, returns a scratch element count for a
// swizzled layout conversion from srcTy to dstTy - confirm in the definition.
unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
RankedTensorType dstTy);
} // namespace triton
/// A half-open range [start, end) over an ordered, subtractable type T.
///
/// A default-constructed interval covers the widest range T can express: it
/// starts at std::numeric_limits<T>::lowest() and ends at
/// std::numeric_limits<T>::max().
template <typename T> class Interval {
public:
  /// Creates the widest interval representable in T.
  Interval() {}

  /// Creates [S, E). Asserts that S does not exceed E.
  Interval(T S, T E) : Start(S), End(E) { assert(Start <= End); }

  /// Inclusive lower bound.
  T start() const { return Start; }

  /// Exclusive upper bound.
  T end() const { return End; }

  /// Length of the interval, End - Start.
  /// Note: for a default-constructed interval of a signed integer type this
  /// subtraction overflows; only call it on intervals with explicit bounds.
  T size() const { return End - Start; }

  /// True when Addr lies in [Start, End).
  bool contains(T Addr) const { return Start <= Addr && Addr < End; }

  /// True when the two half-open ranges share at least one point. Intervals
  /// that merely touch (one ends where the other starts) do not intersect.
  bool intersects(const Interval &R) const {
    return Start < R.End && R.Start < End;
  }

  /// Intervals are equal when both bounds match.
  bool operator==(const Interval &R) const {
    return Start == R.Start && End == R.End;
  }

  bool operator!=(const Interval &R) const { return !(*this == R); }

  /// Lexicographic order on (Start, End).
  bool operator<(const Interval &R) const {
    return std::make_pair(Start, End) < std::make_pair(R.Start, R.End);
  }

private:
  // lowest(), not min(): for floating-point T, min() is the smallest positive
  // normal value, which would leave zero and every negative number outside
  // the "widest" interval. For integer T the two are the same value.
  T Start = std::numeric_limits<T>::lowest();
  T End = std::numeric_limits<T>::max();
};

/// Lets `Interval(a, b)` deduce T from its two bounds.
template <class T> Interval(T, T) -> Interval<T>;
/// Bookkeeping for the buffers tracked on behalf of one operation (the one
/// passed to the constructor).
///
/// Each buffer gets a unique BufferId and is stored by value in `bufferSet`;
/// the per-kind maps hold pointers into that set. The maps are filled through
/// the private helpers below, which are reachable from the friend class
/// triton::AllocationAnalysis; the public interface is read-only apart from
/// run().
class Allocation {
public:
  /// Identifier handed out to each buffer, in creation order starting at 0.
  using BufferId = size_t;
  using BufferIdSetT = DenseSet<BufferId>;
  using FuncAllocMapT = CallGraph<Allocation>::FuncDataMapT;

  /// Sentinel returned when a value or operation has no buffer.
  static constexpr BufferId InvalidBufferId =
      std::numeric_limits<BufferId>::max();

  Allocation() = default;

  /// Creates an empty allocation record for `operation`.
  explicit Allocation(Operation *operation) : operation(operation) {}

  /// Defined out of line. Receives the per-function allocation map and the
  /// scratch-size callback.
  void run(FuncAllocMapT &funcAllocMap,
           triton::AllocationAnalysisScratchSizeFn scratchSizeGetter);

  /// The operation this record was created for (null if default-constructed).
  Operation *getOperation() const { return operation; }

  /// Offset recorded for `bufferId`. Like the other accessors that take a
  /// BufferId, this goes through std::map::at, which throws
  /// std::out_of_range for an id that was never created.
  size_t getOffset(BufferId bufferId) const {
    return bufferSet.at(bufferId).offset;
  }

  /// Size recorded for `bufferId`.
  size_t getAllocatedSize(BufferId bufferId) const {
    return bufferSet.at(bufferId).size;
  }

  /// The range [offset, offset + size) occupied by `bufferId`.
  Interval<size_t> getAllocatedInterval(BufferId bufferId) const {
    auto &buffer = bufferSet.at(bufferId);
    return Interval<size_t>(buffer.offset, buffer.offset + buffer.size);
  }

  /// Id of the explicit buffer registered for `value`, or InvalidBufferId.
  BufferId getBufferId(Value value) const {
    // Single probe instead of count() followed by lookup().
    auto it = valueBuffer.find(value);
    if (it == valueBuffer.end())
      return InvalidBufferId;
    return it->second->id;
  }

  /// Ids of every buffer `value` may refer to: its own explicit buffer (if
  /// any) plus all buffers recorded as aliases of it.
  BufferIdSetT getBufferIds(Value value) const {
    BufferIdSetT bufferIds;
    auto allocBufferId = getBufferId(value);
    if (allocBufferId != InvalidBufferId)
      bufferIds.insert(allocBufferId);
    // find() walks the stored alias set in place; lookup() would copy the
    // whole SetVector on every call.
    auto aliasIt = aliasBuffer.find(value);
    if (aliasIt != aliasBuffer.end()) {
      for (auto *buffer : aliasIt->second) {
        if (buffer->id != InvalidBufferId)
          bufferIds.insert(buffer->id);
      }
    }
    return bufferIds;
  }

  /// Id of the buffer registered for `operation`. The scratch map is
  /// consulted first, then the virtual map; InvalidBufferId if neither has
  /// an entry.
  BufferId getBufferId(Operation *operation) const {
    if (auto it = opScratch.find(operation); it != opScratch.end())
      return it->second->id;
    if (auto it = opVirtual.find(operation); it != opVirtual.end())
      return it->second->id;
    return InvalidBufferId;
  }

  /// True when `bufferId` was created with BufferKind::Virtual.
  bool isVirtualBuffer(BufferId bufferId) const {
    return bufferSet.at(bufferId).kind == BufferT::BufferKind::Virtual;
  }

  /// Value of `sharedMemorySize`; it starts at 0 and is not modified anywhere
  /// in this header.
  size_t getSharedMemorySize() const { return sharedMemorySize; }

  /// Defined out of line.
  std::map<Operation *, SmallVector<BufferId>> getLiveBuffers();

private:
  /// One tracked buffer.
  struct BufferT {
    /// Which map a buffer is registered in (see addBuffer): Explicit buffers
    /// are keyed by value, Scratch and Virtual buffers by operation.
    enum class BufferKind { Explicit, Scratch, Virtual };

    BufferKind kind;
    BufferId id;
    Operation *owner;
    size_t size;
    size_t alignment;
    size_t offset;

    // Identity and ordering are defined by the id alone.
    bool operator==(const BufferT &other) const { return id == other.id; }
    bool operator<(const BufferT &other) const { return id < other.id; }

    BufferT(BufferKind kind, BufferId id, Operation *owner, size_t size,
            size_t alignment = 4, size_t offset = 0)
        : kind(kind), id(id), owner(owner), size(size), alignment(alignment),
          offset(offset) {}

    /// Rounds `newOffset` up to this buffer's alignment, stores the result as
    /// the buffer's offset and returns it.
    size_t setOffsetAligned(size_t newOffset) {
      return offset = llvm::alignTo(newOffset, alignment);
    }
  };

  using OpScratchMapT = llvm::MapVector<Operation *, BufferT *>;
  using ValueBufferMapT = llvm::MapVector<Value, BufferT *>;
  using AliasBufferMapT = llvm::MapVector<Value, llvm::SetVector<BufferT *>>;
  using BufferSetT = std::map<BufferId, BufferT>;

private:
  /// Creates a buffer of the given kind with the next free id and registers
  /// it under `key` in the map matching that kind. `key` is also passed as
  /// the buffer's owner, so it must convert to Operation* as well as to the
  /// key type of the chosen map. `args` are forwarded to the BufferT
  /// constructor after the owner (size, then optional alignment and offset).
  template <BufferT::BufferKind Kind, typename KeyType, typename... Args>
  void addBuffer(KeyType &key, Args &&...args) {
    BufferId nextId = bufferIdCounter++;
    auto [it, inserted] = bufferSet.insert_or_assign(
        nextId, BufferT(Kind, nextId, key, std::forward<Args>(args)...));
    // std::map nodes do not move, so this pointer stays valid as more
    // buffers are added.
    BufferT *buffer = &it->second;
    if constexpr (Kind == BufferT::BufferKind::Explicit) {
      valueBuffer[key] = buffer;
    } else if constexpr (Kind == BufferT::BufferKind::Virtual) {
      opVirtual[key] = buffer;
    } else {
      opScratch[key] = buffer;
    }
  }

  /// Records that `value` may refer to the explicit buffer of `alloc`.
  /// Does nothing when `alloc` has no explicit buffer.
  void addAlias(Value value, Value alloc) {
    // lookup() never inserts. Indexing valueBuffer with operator[] would
    // plant a null BufferT* for an unknown `alloc` in both valueBuffer and
    // the alias set, and getBufferId / getBufferIds dereference those
    // entries unconditionally.
    if (BufferT *buffer = valueBuffer.lookup(alloc))
      aliasBuffer[value].insert(buffer);
  }

private:
  Operation *operation = nullptr;
  OpScratchMapT opScratch;
  OpScratchMapT opVirtual;
  ValueBufferMapT valueBuffer;
  AliasBufferMapT aliasBuffer;
  BufferSetT bufferSet;
  size_t sharedMemorySize = 0;
  size_t bufferIdCounter = 0;

  friend class triton::AllocationAnalysis;
};
/// Call-graph-wide container holding one Allocation per function of a module,
/// plus one Value per function that callers register as that function's
/// shared-memory base.
class ModuleAllocation : public CallGraph<Allocation> {
public:
  using FuncOffsetMapT = DenseMap<FunctionOpInterface, Value>;

  /// Builds the call graph for `moduleOp` and runs the allocation for every
  /// function the walk visits.
  ///
  /// \param moduleOp          module to analyse.
  /// \param scratchSizeGetter callback forwarded to each Allocation::run;
  ///                          defaults to
  ///                          triton::defaultAllocationAnalysisScratchSizeFn.
  ModuleAllocation(ModuleOp moduleOp,
                   triton::AllocationAnalysisScratchSizeFn scratchSizeGetter =
                       triton::defaultAllocationAnalysisScratchSizeFn)
      : CallGraph<Allocation>(moduleOp) {
    walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
        // Call-site callback: nothing to do.
        [](CallOpInterface callOp, FunctionOpInterface funcOp) {},
        // Function callback. NOTE(review): presumably invoked in post-order
        // so that callees are analysed before their callers - confirm
        // against CallGraph::walk.
        [&](FunctionOpInterface funcOp) {
          auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp);
          // Run only the first time a function is seen.
          if (inserted)
            iter->second.run(funcMap, scratchSizeGetter);
        });
  }

  /// Largest shared-memory size among the root functions of the call graph;
  /// 0 when there are no roots.
  size_t getSharedMemorySize() {
    size_t size = 0;
    for (auto funcOp : getRoots()) {
      auto *alloc = getFuncData(funcOp);
      size = std::max(size, alloc->getSharedMemorySize());
    }
    return size;
  }

  /// Shared-memory size recorded for `funcOp`.
  size_t getSharedMemorySize(FunctionOpInterface funcOp) {
    return getFuncData(funcOp)->getSharedMemorySize();
  }

  /// Associates `value` with `funcOp`, replacing any previous association.
  void setFunctionSharedMemoryValue(FunctionOpInterface funcOp, Value value) {
    sharedMemoryValue[funcOp] = value;
  }

  /// Value previously registered for `funcOp` through
  /// setFunctionSharedMemoryValue, or a null Value when none was registered.
  Value getFunctionSharedMemoryBase(FunctionOpInterface funcOp) {
    // lookup() yields the same default (null) Value as operator[] for a
    // missing key, but does not insert an empty entry into the map as a side
    // effect of a read.
    return sharedMemoryValue.lookup(funcOp);
  }

private:
  FuncOffsetMapT sharedMemoryValue;
};
}
#endif