#include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Traits.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/MapVector.h"
namespace mlir {
namespace triton {
namespace nvidia_gpu {
namespace ttg = triton::gpu;
#define GEN_PASS_DEF_TRITONTENSORMEMORYALLOCATIONPASS
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
namespace {
// Tensor-memory rows are handed out in blocks of this many rows. A request's
// row count must be a multiple of it (asserted in MemoryBitMap::findFirstFit).
constexpr int allocGranularity{64};
// A rectangular region of tensor memory. Row fields count row blocks of
// `allocGranularity` rows each; column fields count individual columns.
struct TMemChunk {
  int startRow = 0; // index of the first occupied row block
  int startCol = 0; // index of the first occupied column
  int numCols = 0;  // number of occupied columns
  int numRows = 0;  // number of occupied row blocks
};
// Occupancy map of tensor memory, tracked per (row block, column) cell.
// There are kNumRows row blocks. The column dimension starts at 512 columns
// and grows on demand in alloc(). Cells beyond the stored range read as free.
struct MemoryBitMap {
  MemoryBitMap() : elements(512 * kNumRows, false) {}

  // Mark every cell covered by `chunk` as free.
  void free(const TMemChunk &chunk) {
    for (int i = 0; i < chunk.numCols; i++) {
      for (int j = 0; j < chunk.numRows; j++) {
        setUsed(chunk.startRow + j, chunk.startCol + i, false);
      }
    }
  }

  // Mark every cell covered by `chunk` as used. The backing storage is first
  // doubled until the chunk's last column fits.
  void alloc(const TMemChunk &chunk) {
    while (static_cast<size_t>((chunk.startCol + chunk.numCols) * kNumRows) >=
           elements.size())
      elements.resize(2 * elements.size(), false);
    for (int i = 0; i < chunk.numCols; i++) {
      for (int j = 0; j < chunk.numRows; j++) {
        setUsed(chunk.startRow + j, chunk.startCol + i, true);
      }
    }
  }

  // Return the first entirely-free chunk large enough for `allocSize`.
  //
  // Candidate start columns are tried in increasing order, only at multiples
  // of `columnAlignment`. Within a column, start row blocks are tried in
  // increasing order. When `rowIdConstraint` is set, only that start row
  // block is considered. The map itself is not modified.
  //
  // The column search is unbounded, so a request that no start row can ever
  // satisfy would spin forever. Such requests abort with a fatal error.
  TMemChunk findFirstFit(TMemAllocation allocSize,
                         std::optional<int> rowIdConstraint,
                         int columnAlignment) const {
    int numRows = allocSize.numRows / allocGranularity;
    assert(kNumRows - numRows >= 0);
    assert(allocSize.numRows % allocGranularity == 0);
    assert(columnAlignment > 0 && "column alignment must be positive");
    if (numRows > kNumRows)
      llvm::report_fatal_error(
          "TMEM allocation requests more rows than tensor memory provides");
    if (rowIdConstraint &&
        (*rowIdConstraint < 0 || *rowIdConstraint > kNumRows - numRows))
      llvm::report_fatal_error(
          "TMEM row id constraint cannot be satisfied by the allocation size");
    int startCol = 0;
    while (true) {
      // Skip forward to the next aligned column.
      if (startCol % columnAlignment != 0) {
        startCol = (startCol / columnAlignment + 1) * columnAlignment;
      }
      for (int startRow = 0; startRow <= kNumRows - numRows; ++startRow) {
        if (rowIdConstraint && *rowIdConstraint != startRow)
          continue;
        // Check whether every cell of the candidate block is free.
        bool fits = true;
        for (int i = 0; i < allocSize.numCols && fits; ++i) {
          for (int j = 0; j < numRows; ++j) {
            if (isUsed(startRow + j, startCol + i)) {
              fits = false;
              break;
            }
          }
        }
        if (fits) {
          TMemChunk chunk;
          chunk.startRow = startRow;
          chunk.startCol = startCol;
          chunk.numRows = numRows;
          chunk.numCols = allocSize.numCols;
          return chunk;
        }
      }
      startCol++;
    }
    return TMemChunk();
  }

private:
  // Cells are stored column-major: the kNumRows row blocks of a column are
  // adjacent. Out-of-range cells are reported as free.
  bool isUsed(int row, int col) const {
    if (static_cast<size_t>(row + col * kNumRows) >= elements.size())
      return false;
    return elements[row + col * kNumRows];
  }
  void setUsed(int row, int col, bool used) {
    assert(static_cast<size_t>(row + col * kNumRows) < elements.size());
    elements[row + col * kNumRows] = used;
  }
  // Number of row blocks (each `allocGranularity` rows) in tensor memory.
  static constexpr int kNumRows = 2;
  std::vector<bool> elements;
};
// Return the half-open interval [first id, last id + 1) of operation ids
// (taken from `operationId`) at which `value`, or any view derived from it,
// is live.
//
// Views are followed transitively. The set of followed ops mirrors what
// getAlloc treats as aliasing an allocation: memdesc index/reinterpret ops,
// any op with MemDescViewTrait, and TMEMSubSliceOp. This keeps the
// allocation from being considered free while a view of it is still in use.
// Following extra views can only lengthen the interval.
static Interval<int> getLiveIntervals(Value value, Liveness &liveness,
                                      DenseMap<Operation *, int> &operationId) {
  auto liveOperations = liveness.resolveLiveness(value);
  SmallVector<Operation *> users(value.getUsers());
  DenseSet<Operation *> visited;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    if (!visited.insert(user).second)
      continue;
    bool isView =
        isa<ttg::MemDescIndexOp, ttg::MemDescReinterpretOp, TMEMSubSliceOp>(
            user) ||
        user->hasTrait<OpTrait::MemDescViewTrait>();
    if (!isView)
      continue;
    Value view = user->getResult(0);
    auto viewLiveness = liveness.resolveLiveness(view);
    liveOperations.insert(liveOperations.end(), viewLiveness.begin(),
                          viewLiveness.end());
    users.append(view.getUsers().begin(), view.getUsers().end());
  }
  int minId = std::numeric_limits<int>::max();
  int maxId = std::numeric_limits<int>::min();
  for (Operation *liveOp : liveOperations) {
    // Look up without operator[], which would silently insert a 0 id.
    assert(operationId.count(liveOp) && "live operation was not numbered");
    int id = operationId.lookup(liveOp);
    minId = std::min(minId, id);
    maxId = std::max(maxId, id + 1);
  }
  return Interval(minId, maxId);
}
// Release every chunk whose recorded end id is <= the start of
// `liveInterval`, and drop those entries from `intervalLiverangeEnd`.
//
// `intervalLiverangeEnd` is keyed by the end id of each chunk's live range.
// Because a multimap is ordered by key, the expired entries form a prefix
// ending at upper_bound(start).
static void updateMap(MemoryBitMap &memoryMap, Interval<int> liveInterval,
                      std::multimap<int, TMemChunk> &intervalLiverangeEnd) {
  auto firstStillLive = intervalLiverangeEnd.upper_bound(liveInterval.start());
  for (auto expired = intervalLiverangeEnd.begin(); expired != firstStillLive;
       ++expired)
    memoryMap.free(expired->second);
  intervalLiverangeEnd.erase(intervalLiverangeEnd.begin(), firstStillLive);
}
// Find the first free chunk for `allocSize` and mark it used in `memoryMap`.
//
// `coexistingChunks` block placement only for this search. They are marked
// in a scratch copy of the map and are not recorded in `memoryMap` itself.
// `rowIdConstraint` and `columnAlignment` are forwarded to
// MemoryBitMap::findFirstFit.
static TMemChunk allocFirstFit(MemoryBitMap &memoryMap,
                               TMemAllocation allocSize,
                               std::optional<int> rowIdConstraint,
                               ArrayRef<TMemChunk> coexistingChunks,
                               int columnAlignment) {
  MemoryBitMap scratch(memoryMap);
  for (size_t idx = 0, count = coexistingChunks.size(); idx != count; ++idx)
    scratch.alloc(coexistingChunks[idx]);
  TMemChunk placed =
      scratch.findFirstFit(allocSize, rowIdConstraint, columnAlignment);
  memoryMap.alloc(placed);
  return placed;
}
// Collect every TMEMAllocOp whose result may flow into `value`.
//
// The walk goes backwards through:
// - memdesc views and TMEM subslices;
// - `arith.select`;
// - `scf.if` / `scf.for` / `scf.while` results and region arguments;
// - warp-specialize partition captures;
// - arguments of non-entry blocks fed by branch terminators.
//
// Any other producer aborts via llvm::report_fatal_error. Each value is
// visited at most once, so cycles through loop-carried values terminate.
static SmallVector<Operation *> getAlloc(Value value) {
  SmallVector<Operation *> allocs;
  DenseSet<Value> seen;
  SmallVector<Value> worklist{value};
  while (!worklist.empty()) {
    Value v = worklist.pop_back_val();
    if (!seen.insert(v).second)
      continue;

    if (auto arg = dyn_cast<BlockArgument>(v)) {
      Block *block = arg.getOwner();
      Operation *parentOp = block->getParentOp();
      if (!block->isEntryBlock()) {
        // Non-entry block arguments are fed by the terminators of the
        // predecessor blocks.
        for (Block *pred : block->getPredecessors()) {
          Operation *predOp = pred->getTerminator();
          auto br = dyn_cast<BranchOpInterface>(predOp);
          if (!br) {
            llvm::report_fatal_error("unhandled branch op: " +
                                     predOp->getName().getStringRef());
          }
          // A terminator may target `block` through more than one successor
          // slot (e.g. both edges of a conditional branch), and each slot
          // forwards its own operands. Follow every such edge, not just the
          // first one.
          for (unsigned succIdx = 0, numSucc = br->getNumSuccessors();
               succIdx != numSucc; ++succIdx) {
            if (br->getSuccessor(succIdx) != block)
              continue;
            SuccessorOperands args = br.getSuccessorOperands(succIdx);
            worklist.push_back(
                args.getForwardedOperands()[arg.getArgNumber() -
                                            args.getProducedOperandCount()]);
          }
        }
        continue;
      }
      // Entry-block arguments are defined by the enclosing region op.
      if (auto wsOp = dyn_cast<ttg::WarpSpecializePartitionsOp>(parentOp)) {
        worklist.push_back(
            wsOp.getParentOp().getExplicitCaptures()[arg.getArgNumber()]);
      } else if (auto forOp = dyn_cast<scf::ForOp>(parentOp)) {
        // Argument 0 of the body is the induction variable.
        unsigned idx = arg.getArgNumber() - 1;
        worklist.push_back(forOp.getYieldedValues()[idx]);
        worklist.push_back(forOp.getInits()[idx]);
      } else if (auto whileOp = dyn_cast<scf::WhileOp>(parentOp)) {
        unsigned idx = arg.getArgNumber();
        if (arg.getParentRegion() == &whileOp.getAfter()) {
          worklist.push_back(whileOp.getConditionOp().getArgs()[idx]);
        } else {
          worklist.push_back(whileOp.getYieldedValues()[idx]);
          worklist.push_back(whileOp.getInits()[idx]);
        }
      } else {
        llvm::report_fatal_error(
            "unhandled parent op when looking for TMEM alloc: " +
            parentOp->getName().getStringRef());
      }
      continue;
    }

    Operation *defOp = v.getDefiningOp();
    unsigned idx = cast<OpResult>(v).getResultNumber();
    if (isa<TMEMAllocOp>(defOp)) {
      allocs.push_back(defOp);
    } else if (defOp->hasTrait<OpTrait::MemDescViewTrait>()) {
      worklist.push_back(defOp->getOperand(0));
    } else if (auto sliceOp = dyn_cast<TMEMSubSliceOp>(defOp)) {
      worklist.push_back(sliceOp.getSrc());
    } else if (auto selectOp = dyn_cast<arith::SelectOp>(defOp)) {
      worklist.push_back(selectOp.getTrueValue());
      worklist.push_back(selectOp.getFalseValue());
    } else if (auto ifOp = dyn_cast<scf::IfOp>(defOp)) {
      worklist.push_back(ifOp.thenYield().getOperand(idx));
      worklist.push_back(ifOp.elseYield().getOperand(idx));
    } else if (auto forOp = dyn_cast<scf::ForOp>(defOp)) {
      worklist.push_back(forOp.getYieldedValues()[idx]);
      worklist.push_back(forOp.getInits()[idx]);
    } else if (auto whileOp = dyn_cast<scf::WhileOp>(defOp)) {
      worklist.push_back(whileOp.getConditionOp().getArgs()[idx]);
    } else {
      llvm::report_fatal_error("unhandled op when looking for TMEM alloc: " +
                               defOp->getName().getStringRef());
    }
  }
  return allocs;
}
// Tracks groups of allocations that must share a starting row block.
// Operations are grouped with joinOps(). Once any member of a group is
// placed, addConstraints() pins the group's row id, and
// getRowIdConstraint() reports it for every member.
class RowIdConstraints {
  // Allocations that must share a row id are unioned into one class.
  llvm::EquivalenceClasses<Operation *> dependentAllocs;
  // Pinned row id of each class, keyed by the class leader.
  llvm::SmallDenseMap<Operation *, int> rowIndex;

public:
  void joinOps(Operation *op1, Operation *op2) {
    dependentAllocs.unionSets(op1, op2);
  }

  // Row id pinned for `op`'s group. Returns nullopt if `op` was never joined
  // or its group has not been pinned yet.
  std::optional<int> getRowIdConstraint(Operation *op) {
    auto leader = dependentAllocs.findLeader(op);
    if (leader != dependentAllocs.member_end()) {
      auto pinned = rowIndex.find(*leader);
      if (pinned != rowIndex.end())
        return pinned->second;
    }
    return std::nullopt;
  }

  // Pin `rowId` for `op`'s group. Does nothing if `op` was never joined.
  void addConstraints(Operation *op, int rowId) {
    auto leader = dependentAllocs.findLeader(op);
    if (leader != dependentAllocs.member_end())
      rowIndex[*leader] = rowId;
  }
};
// Assign a tensor-memory chunk to every TMEMAllocOp nested in `parentOp`.
//
// Allocations are placed one at a time, in post-order walk order. Each is
// placed with a first-fit search against the chunks whose live ranges have
// not yet ended. The chosen placement is written back onto the op as the
// "tensor_memory_col_offset" and "tensor_memory_row_offset" i32 attributes.
//
// Returns the number of columns spanned by all placements, i.e. the maximum
// of (column offset + column count) over all allocations (0 if there are
// none).
// NOTE(review): `offsets` is neither read nor written here.
static int
allocateTMem(Operation *parentOp,
DenseMap<triton::nvidia_gpu::TMEMAllocOp, int> &offsets) {
SmallVector<triton::nvidia_gpu::TMEMAllocOp> allocs;
DenseMap<Operation *, int> operationId;
RowIdConstraints rowIdConstraints;
// Number every operation in post-order. C++17 sequences the right operand
// (`size()`) before `operator[]` inserts, so ids are consecutive from 0.
// These ids are the coordinates used by the live intervals below.
parentOp->walk<WalkOrder::PostOrder>([&](Operation *op) {
operationId[op] = operationId.size();
if (auto alloc = dyn_cast<triton::nvidia_gpu::TMEMAllocOp>(op)) {
allocs.push_back(alloc);
}
// For MMAv5 ops whose A operand lives in tensor memory with a 64-row
// allocation, tie every allocation reaching A to every allocation reaching
// the accumulator, so they are placed in the same row block.
if (auto mmaOp = dyn_cast<MMAv5OpInterface>(op)) {
if (isa<TensorMemoryEncodingAttr>(mmaOp.getA().getType().getEncoding())) {
TMemAllocation allocSize = getTmemAllocSizes(mmaOp.getA().getType());
if (allocSize.numRows == 64) {
SmallVector<Operation *> lhsAllocs = getAlloc(mmaOp.getA());
SmallVector<Operation *> accAllocs = getAlloc(mmaOp.getAccumulator());
for (Operation *lhsAlloc : lhsAllocs)
for (Operation *accAlloc : accAllocs)
rowIdConstraints.joinOps(lhsAlloc, accAlloc);
} else {
// Otherwise neither operand may use a blockM of 64.
assert((cast<TensorMemoryEncodingAttr>(
mmaOp.getA().getType().getEncoding())
.getBlockM() != 64 &&
cast<TensorMemoryEncodingAttr>(
mmaOp.getAccumulator().getType().getEncoding())
.getBlockM() != 64) &&
"interleaved layout with TMEM operand is not supported yet.");
}
}
}
});
int totalMemorySize = 0;
MemoryBitMap memoryMap;
Liveness liveness(parentOp);
// Chunks currently in use, keyed by the (exclusive) end id of their range.
std::multimap<int, TMemChunk> intervalLiverangeEnd;
// Placement chosen for each allocation processed so far.
DenseMap<TMEMAllocOp, TMemChunk> allocChunks;
for (auto it = allocs.begin(), e = allocs.end(); it != e; ++it) {
TMEMAllocOp alloc = *it;
// Earlier allocations inside the same warp_specialize op but in a different
// region are treated as occupied for this search regardless of their live
// ranges (presumably because different partitions run concurrently —
// confirm).
SmallVector<TMemChunk> coexistingChunks;
if (auto ws = alloc->getParentOfType<triton::gpu::WarpSpecializeOp>()) {
for (auto prevIt = allocs.begin(); prevIt != it; ++prevIt) {
TMEMAllocOp prevAlloc = *prevIt;
auto prevWs =
prevAlloc->getParentOfType<triton::gpu::WarpSpecializeOp>();
if (prevWs && prevWs == ws &&
alloc->getParentRegion() != prevAlloc->getParentRegion())
coexistingChunks.push_back(allocChunks.at(prevAlloc));
}
}
Interval<int> liveInterval = getLiveIntervals(alloc, liveness, operationId);
auto memDescType = alloc.getType();
TMemAllocation allocSize = getTmemAllocSizes(memDescType);
// Release chunks whose live range ended at or before this one starts.
updateMap(memoryMap, liveInterval, intervalLiverangeEnd);
// If a tied allocation was already placed, reuse its row block.
std::optional<int> rowIdConstraint =
rowIdConstraints.getRowIdConstraint(alloc);
const int columnAlignment = 4;
TMemChunk chunkAllocated =
allocFirstFit(memoryMap, allocSize, rowIdConstraint, coexistingChunks,
columnAlignment);
allocChunks.insert({alloc, chunkAllocated});
// The first placed member of a tied group decides the group's row block.
rowIdConstraints.addConstraints(alloc, chunkAllocated.startRow);
intervalLiverangeEnd.insert({liveInterval.end(), chunkAllocated});
int colOffset = chunkAllocated.startCol;
// NOTE(review): row blocks are allocGranularity (64) rows, but the recorded
// row offset is startRow * 16 — presumably a per-warp lane offset expected
// by the lowering; confirm there.
int rowOffset = chunkAllocated.startRow * 16;
alloc->setAttr(
"tensor_memory_col_offset",
IntegerAttr::get(IntegerType::get(parentOp->getContext(), 32),
colOffset));
alloc->setAttr(
"tensor_memory_row_offset",
IntegerAttr::get(IntegerType::get(parentOp->getContext(), 32),
rowOffset));
totalMemorySize = std::max(totalMemorySize, colOffset + allocSize.numCols);
}
return totalMemorySize;
}
}
// Pass driver. It assigns tensor-memory offsets to every TEMEM allocation in
// the module via allocateTMem and records the resulting footprint on the
// module as "ttg.tensor_memory_size".
class TritonTensorMemoryAllocationPass
    : public impl::TritonTensorMemoryAllocationPassBase<
          TritonTensorMemoryAllocationPass> {
public:
  // Build a 32-bit integer attribute in this pass's context.
  IntegerAttr getI32Attr(int32_t value) {
    return Builder(&getContext()).getI32IntegerAttr(value);
  }

  void runOnOperation() override {
    ModuleOp mod = getOperation();
    DenseMap<triton::nvidia_gpu::TMEMAllocOp, int> offsets;
    int totalMemorySize = allocateTMem(mod, offsets);

    // Round the column footprint up to the next size in this table. A
    // footprint larger than the last entry is recorded unchanged; this pass
    // emits no diagnostic for it.
    static constexpr std::array<int, 6> possibleAllocations = {
        0, 32, 64, 128, 256, 512};
    if (totalMemorySize <= possibleAllocations.back()) {
      for (int size : possibleAllocations) {
        if (totalMemorySize <= size) {
          totalMemorySize = size;
          break;
        }
      }
    }

    // When any tensor memory is used, make sure the module reserves at least
    // 4 units of "ttg.shared" (presumably bytes of shared memory needed by
    // the TMEM allocation sequence — TODO confirm).
    if (totalMemorySize > 0) {
      int shared = 0;
      if (auto sharedAttr = mod->getAttr("ttg.shared")) {
        shared = cast<IntegerAttr>(sharedAttr).getInt();
      }
      if (shared < 4) {
        mod->setAttr("ttg.shared", getI32Attr(4));
      }
    }
    mod->setAttr("ttg.tensor_memory_size", getI32Attr(totalMemorySize));
  }
};
}
}
}