#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
#include "llvm/ADT/AddressRanges.h"
namespace ttg = mlir::triton::gpu;
namespace mlir {
namespace triton {
namespace nvidia_gpu {
#define GEN_PASS_DEF_TRITONNVIDIAGPUINTERLEAVETMEMPASS
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
namespace {
// Appends one effect instance of every kind (read, write, allocate, free) to
// `effects`, in that order. The instances are built from the effect kind
// alone, i.e. they carry no associated value. Used as the conservative
// fallback when nothing is known about an operation's effects.
void addAllValuelessEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  MemoryEffects::Effect *const everyKind[] = {
      MemoryEffects::Effect::get<MemoryEffects::Read>(),
      MemoryEffects::Effect::get<MemoryEffects::Write>(),
      MemoryEffects::Effect::get<MemoryEffects::Allocate>(),
      MemoryEffects::Effect::get<MemoryEffects::Free>(),
  };
  for (MemoryEffects::Effect *kind : everyKind)
    effects.emplace_back(kind);
}
// Appends the memory effects of `op` to `effects`.
//
// Returns true when the effects could be determined, either from the op's
// `MemoryEffectOpInterface` or, for ops with recursive memory effects, from
// every op nested in its regions. Returns false when some op provided neither;
// in that case one effect of every kind is appended via
// `addAllValuelessEffects` so that callers inspecting `effects` stay
// conservative.
bool collectEffects(Operation *op,
                    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
    // Gather into a scratch buffer and append afterwards, so `getEffects`
    // only ever sees a fresh vector (presumably to keep it from disturbing
    // entries already present in `effects` -- confirm against the interface).
    SmallVector<MemoryEffects::EffectInstance> localEffects;
    iface.getEffects(localEffects);
    llvm::append_range(effects, localEffects);
    return true;
  }

  if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) {
    for (Region &region : op->getRegions()) {
      for (Block &block : region) {
        for (Operation &innerOp : block) {
          // Stop at the first nested op with unknown effects; the recursive
          // call has already appended the conservative effect set.
          if (!collectEffects(&innerOp, effects))
            return false;
        }
      }
    }
    return true;
  }

  // Neither the interface nor the recursive trait is available: assume the op
  // may have any effect.
  addAllValuelessEffects(effects);
  return false;
}
// Per-dimension description of the part of an allocation that a value may
// access, as computed by `findBufferAccess`.
struct AccessRange {
// One entry per tracked dimension. An engaged entry is built as
// {start, start + size} along that dimension; an empty optional means the
// accessed range is unknown (e.g. a non-constant offset), and `tmemMayAlias`
// skips such dimensions when looking for disjoint ranges.
SmallVector<std::optional<llvm::AddressRange>> ranges;
// Set by `findBufferAccessMemdescSubview` to the source rank minus the number
// of dims in the subview's result shape. When a further subview is taken, the
// first `rankOffset` entries of `ranges` are copied unchanged into the child.
unsigned rankOffset = 0;
};
// Simple local alias analysis: follows the definition of `a` looking for a
// single underlying allocation and the sub-range of it that `a` may access.
// The returned Value is null when no allocation could be identified.
std::pair<Value, AccessRange> findBufferAccess(Value a);

// Computes the buffer access for the result of `subview`, which must be a
// `ttg::MemDescIndexOp` or a `ttg::MemDescSubsliceOp` (anything else trips the
// `cast` below).
//
// Returns the underlying allocation together with the narrowed access range,
// or a default-constructed pair (null Value) when the source's allocation
// cannot be identified. This is a pure query: offsets are inspected in place
// and no operations are created.
std::pair<Value, AccessRange>
findBufferAccessMemdescSubview(Operation *subview) {
  TypedValue<ttg::MemDescType> src;
  SmallVector<int64_t> shape;
  // One entry per offset; std::nullopt marks an offset that is not a known
  // constant.
  SmallVector<std::optional<int64_t>> offsets;
  if (auto indexOp = dyn_cast<ttg::MemDescIndexOp>(subview)) {
    src = indexOp.getSrc();
    shape = to_vector(indexOp.getType().getShape());
    APInt index;
    if (matchPattern(indexOp.getIndex(), m_ConstantInt(&index)))
      offsets.push_back(index.getSExtValue());
    else
      offsets.push_back(std::nullopt);
    // The dims after the indexed one start at offset zero.
    // NOTE(review): only `shape.size() - 1` zeros are added, where `shape` is
    // the *result* shape. If the result presumably has one dim fewer than the
    // source, the last source dim gets no range entry. This looks like an
    // off-by-one, but it is kept as-is because `findBufferAccess` indexes
    // `ranges` positionally for TMEMSubSliceOp -- confirm before changing.
    const size_t numZeroOffsets = shape.empty() ? 0 : shape.size() - 1;
    offsets.append(numZeroOffsets, std::optional<int64_t>(0));
  } else {
    auto subsliceOp = cast<ttg::MemDescSubsliceOp>(subview);
    src = subsliceOp.getSrc();
    shape = to_vector(subsliceOp.getType().getShape());
    for (auto offset : subsliceOp.getOffsets())
      offsets.push_back(static_cast<int64_t>(offset));
  }

  auto [alloc, parentAccess] = findBufferAccess(src);
  if (!alloc)
    return {};

  // Handle a subview of a subview: the first `rankOffset` ranges are the same
  // as in the parent access.
  AccessRange childAccess;
  for (auto i : llvm::seq(parentAccess.rankOffset))
    childAccess.ranges.push_back(parentAccess.ranges[i]);

  // The subview may have a smaller rank than its source, in which case the
  // access size along the dropped leading dims is 1.
  childAccess.rankOffset = src.getType().getRank() - shape.size();
  for (auto [i, offset] : llvm::enumerate(offsets)) {
    // The range is unknown when the parent has no entry for this dim (also
    // guards against indexing past the end of the parent's ranges), when the
    // parent's range is itself unknown, or when the offset is not constant.
    const size_t parentDim = i + parentAccess.rankOffset;
    if (parentDim >= parentAccess.ranges.size() ||
        !parentAccess.ranges[parentDim] || !offset) {
      childAccess.ranges.push_back(std::nullopt);
      continue;
    }

    uint64_t accessStart = parentAccess.ranges[parentDim]->start() + *offset;
    uint64_t accessSize = 1;
    if (i >= childAccess.rankOffset)
      accessSize = shape[i - childAccess.rankOffset];
    childAccess.ranges.push_back(
        llvm::AddressRange(accessStart, accessStart + accessSize));
  }
  return {alloc, std::move(childAccess)};
}
// Resolves `a` to the TMEM allocation it is derived from and the range of that
// allocation it may touch. Returns a default-constructed pair (null Value)
// whenever the chain of definitions cannot be followed.
std::pair<Value, AccessRange> findBufferAccess(Value a) {
  // Block arguments: only `ttg.warp_specialize` partition arguments are
  // understood; they are mapped back to the matching explicit capture.
  if (auto blockArg = dyn_cast<BlockArgument>(a)) {
    Operation *regionOwner = blockArg.getOwner()->getParentOp();
    auto partitions = dyn_cast<ttg::WarpSpecializePartitionsOp>(regionOwner);
    if (!partitions)
      return {};
    return findBufferAccess(
        partitions.getParentOp().getExplicitCaptures()[blockArg.getArgNumber()]);
  }

  Operation *producer = a.getDefiningOp();

  // The allocation itself: the whole buffer is accessed, [0, dim) per dim.
  if (auto allocOp = dyn_cast<TMEMAllocOp>(producer)) {
    AccessRange whole;
    for (uint64_t extent : allocOp.getType().getShape())
      whole.ranges.push_back(llvm::AddressRange(0, extent));
    return {a, std::move(whole)};
  }

  // Transposes and reshapes are looked through; the source's access is
  // returned unchanged.
  if (isa<ttg::MemDescTransOp, ttg::MemDescReshapeOp>(producer))
    return findBufferAccess(producer->getOperand(0));

  // Index / subslice views may narrow the access.
  if (isa<ttg::MemDescIndexOp, ttg::MemDescSubsliceOp>(producer))
    return findBufferAccessMemdescSubview(producer);

  // Anything other than a TMEM subslice is an unknown producer.
  auto tmemSlice = dyn_cast<TMEMSubSliceOp>(producer);
  if (!tmemSlice)
    return {};

  auto [root, sourceAccess] = findBufferAccess(tmemSlice.getSrc());
  if (!root)
    return {};

  // A TMEM subslice narrows only entry 1 of the ranges.
  // NOTE(review): the index is hard-coded and ignores `rankOffset`; verify
  // this is intended when the source is itself a subview.
  if (!sourceAccess.ranges[1])
    return {root, sourceAccess};

  uint64_t sliceBegin = sourceAccess.ranges[1]->start() + tmemSlice.getN();
  uint64_t sliceExtent = tmemSlice.getType().getShape()[1];
  AccessRange narrowed = sourceAccess;
  narrowed.ranges[1] = llvm::AddressRange(sliceBegin, sliceBegin + sliceExtent);
  return {root, std::move(narrowed)};
}
// Conservative may-alias query for two TMEM values.
//
// Returns false only when both values resolve to allocations and either the
// allocations differ or, along some dimension, both access ranges are known
// and do not intersect. Returns true in every other case, including when
// either allocation could not be identified.
bool tmemMayAlias(Value a, Value b) {
  auto [lhsBuffer, lhsAccess] = findBufferAccess(a);
  auto [rhsBuffer, rhsAccess] = findBufferAccess(b);

  // Unknown underlying buffer: assume the worst.
  if (!lhsBuffer || !rhsBuffer)
    return true;
  // Distinct allocations never alias.
  if (lhsBuffer != rhsBuffer)
    return false;

  // Compare dimension by dimension over the entries both accesses have.
  const size_t numDims =
      std::min(lhsAccess.ranges.size(), rhsAccess.ranges.size());
  for (size_t dim = 0; dim < numDims; ++dim) {
    const auto &lhsRange = lhsAccess.ranges[dim];
    const auto &rhsRange = rhsAccess.ranges[dim];
    // An unknown range on either side proves nothing for this dimension.
    if (lhsRange && rhsRange && !lhsRange->intersects(*rhsRange))
      return false;
  }
  return true;
}
// Moves the operations in `useChain` (in order) later in their block, to just
// before the op at which the forward scan from the end of the chain stopped.
//
// The scan walks the ops following the last chain op and stops at:
//   * the end of the block or a terminator,
//   * an `ArriveBarrierOp`,
//   * an op whose nested operands include a result of a chain op, or
//   * an op with a write/free effect on the default resource, or on tensor
//     memory that has no associated value or may alias `buffer`.
//
// Returns true if the chain was moved, false if it is already placed directly
// before the stopping point (or there was nothing to scan). `useChain` must be
// non-empty.
bool sinkOps(Value buffer, ArrayRef<Operation *> useChain) {
  Operation *const chainTail = useChain.back();

  // True if `candidate` (including ops nested inside it, as reported by
  // `getNestedOperands`) consumes a value produced by one of the chain ops.
  auto readsChainResult = [&](Operation *candidate) {
    return llvm::any_of(getNestedOperands(candidate), [&](Value operand) {
      return llvm::is_contained(useChain, operand.getDefiningOp());
    });
  };

  // True if `candidate` has a write or free effect that could touch `buffer`.
  auto mayClobberBuffer = [&](Operation *candidate) {
    SmallVector<MemoryEffects::EffectInstance> effects;
    collectEffects(candidate, effects);
    return llvm::any_of(
        effects, [&](const MemoryEffects::EffectInstance &effect) {
          if (!isa<MemoryEffects::Write, MemoryEffects::Free>(
                  effect.getEffect()))
            return false;
          if (isa<SideEffects::DefaultResource>(effect.getResource()))
            return true;
          return isa<TensorMemory>(effect.getResource()) &&
                 (!effect.getValue() ||
                  tmemMayAlias(effect.getValue(), buffer));
        });
  };

  Operation *insertionPoint = nullptr;
  for (Operation *cursor = chainTail->getNextNode();
       cursor && !cursor->hasTrait<OpTrait::IsTerminator>();
       cursor = cursor->getNextNode()) {
    insertionPoint = cursor;
    bool blocked = readsChainResult(cursor);
    // Never sink past a barrier arrive.
    if (isa<ArriveBarrierOp>(cursor))
      break;
    // The effect scan runs even when an operand dependency was already found.
    if (!isMemoryEffectFree(cursor) && mayClobberBuffer(cursor))
      blocked = true;
    if (blocked)
      break;
  }

  // Nothing to do when the scan never started or stopped on the op that
  // already follows the chain. The successor is re-read here on purpose rather
  // than cached before the scan.
  if (!insertionPoint || insertionPoint == chainTail->getNextNode())
    return false;

  for (Operation *op : useChain)
    op->moveBefore(insertionPoint);
  return true;
}
// Tries to sink `op` together with a chain of its users.
//
// Starting from `op`, the chain is extended while the last op in it has
// exactly one use, that user is pure, and that user is the very next op in the
// block. The resulting chain is handed to `sinkOps`; the return value is
// whatever `sinkOps` reports (true if something moved).
bool trySinkOp(Operation *op, Value buffer) {
  SmallVector<Operation *> chain;
  chain.push_back(op);
  for (;;) {
    Operation *tail = chain.back();
    if (!tail->hasOneUse())
      break;
    Operation *soleUser = *tail->user_begin();
    if (!isPure(soleUser) || tail->getNextNode() != soleUser)
      break;
    chain.push_back(soleUser);
  }
  return sinkOps(buffer, chain);
}
}
// Module pass that sinks tensor-memory loads and allocations (each together
// with the chain of users built by `trySinkOp`) as far down their block as
// `sinkOps` allows.
struct TritonNvidiaGPUInterleaveTMemPass
    : public impl::TritonNvidiaGPUInterleaveTMemPassBase<
          TritonNvidiaGPUInterleaveTMemPass> {
  using impl::TritonNvidiaGPUInterleaveTMemPassBase<
      TritonNvidiaGPUInterleaveTMemPass>::TritonNvidiaGPUInterleaveTMemPassBase;

  void runOnOperation() override {
    ModuleOp m = getOperation();

    // Collect the candidates first so that moving ops does not interfere with
    // the walk. Each candidate is paired with the buffer it touches: the
    // source for a load, the result for an alloc.
    SmallVector<std::pair<Operation *, Value>> opsToSink;
    m.walk([&](Operation *op) {
      if (auto load = dyn_cast<TMEMLoadOp>(op))
        opsToSink.emplace_back(load, load.getSrc());
      else if (auto alloc = dyn_cast<TMEMAllocOp>(op))
        opsToSink.emplace_back(alloc, alloc.getResult());
    });

    for (auto [op, buffer] : opsToSink) {
      while (trySinkOp(op, buffer)) {
        // Keep sinking until the op (and its users) cannot move any further.
      }
    }
  }
};
}
}
}