#include "mlir/Dialect/NVGPU/Transforms/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/NVGPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"
namespace mlir {
namespace nvgpu {
#define GEN_PASS_DEF_OPTIMIZESHAREDMEMORY
#include "mlir/Dialect/NVGPU/Transforms/Passes.h.inc"
}
}
using namespace mlir;
using namespace mlir::nvgpu;
/// Number of bytes treated as one shared-memory line. Used to decide how many
/// rows of the buffer fit in one line and how often the index permutation
/// changes from one row to the next.
static constexpr int64_t kSharedMemoryLineSizeBytes{128};
/// Access width, in bits, assumed for every read/write of the buffer. It fixes
/// how many low index bits address an element inside one vector access; those
/// bits are never permuted.
static constexpr int64_t kDefaultVectorSizeBits{128};
/// Returns the permuted value of `indices[tgtDim]` for an access into a memref
/// of type `memrefTy`, keyed on bits of `indices[srcDim]`. New arith ops are
/// created with `b` at its current insertion point.
///
/// Bit layout of the target-dimension index (b0 = least significant bit):
///   n := log2(kDefaultVectorSizeBits / elementBits)
///   m := log2(dimSize(tgtDim))
///   bits[0:n) = element offset inside one kDefaultVectorSizeBits access
///   bits[n:m) = vector index inside the row
/// The key is XOR-ed only into bits[n:m), so bits[0:n) are never changed.
///
/// `permuteEveryN` is the number of consecutive source rows that fit into one
/// kSharedMemoryLineSizeBytes line (at least 1). The low log2(permuteEveryN)
/// bits of the source index are skipped, so those rows share the same key.
///
/// Preconditions: the target dimension is static, spans at least one byte and
/// holds at least one full vector access (m >= n). Otherwise the shift and
/// division below would be undefined.
static Value permuteVectorOffset(OpBuilder &b, Location loc,
                                 ArrayRef<Value> indices, MemRefType memrefTy,
                                 int64_t srcDim, int64_t tgtDim) {
  Value src = indices[srcDim];
  const int64_t elementBits = memrefTy.getElementTypeBitWidth();
  const int64_t tgtDimSize = memrefTy.getDimSize(tgtDim);
  assert(!ShapedType::isDynamic(tgtDimSize) &&
         "target dimension must have a static size");
  const int64_t rowBytes = (tgtDimSize * elementBits) / 8;
  assert(rowBytes > 0 && "target dimension must span at least one byte");

  const int64_t permuteEveryN =
      std::max<int64_t>(1, kSharedMemoryLineSizeBytes / rowBytes);
  const int64_t log2PermuteEveryN = llvm::Log2_64(permuteEveryN);

  const int64_t n = llvm::Log2_64(kDefaultVectorSizeBits / elementBits);
  const int64_t m = llvm::Log2_64(tgtDimSize);
  assert(m >= n && "target dimension must hold at least one full vector");

  // Select (m - n) bits of the source index, skipping the low
  // log2(permuteEveryN) bits so that rows sharing a line get the same key.
  const int64_t mask = ((int64_t{1} << (m - n)) - 1) << log2PermuteEveryN;
  Value srcBits = b.create<arith::ConstantIndexOp>(loc, mask);
  srcBits = b.create<arith::AndIOp>(loc, src, srcBits);

  // Move the selected bits onto bits[n:m) of the target index. When
  // permuteEveryN == 1 this is a plain left shift by n; a zero shift emits
  // nothing.
  const int64_t shiftAmount = n - log2PermuteEveryN;
  if (shiftAmount > 0) {
    Value shiftVal = b.create<arith::ConstantIndexOp>(loc, shiftAmount);
    srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, shiftVal);
  } else if (shiftAmount < 0) {
    Value shiftVal = b.create<arith::ConstantIndexOp>(loc, -shiftAmount);
    srcBits = b.createOrFold<arith::ShRUIOp>(loc, srcBits, shiftVal);
  }

  return b.create<arith::XOrIOp>(loc, indices[tgtDim], srcBits);
}
/// Rewrites `indices` in place. The entry for dimension `tgtDim` is replaced by
/// the permuted value computed by permuteVectorOffset, keyed on the entry for
/// `srcDim`; all other entries are left untouched. Any new ops are created at
/// `builder`'s current insertion point.
static void transformIndices(OpBuilder &builder, Location loc,
                             SmallVector<Value, 4> &indices,
                             MemRefType memrefTy, int64_t srcDim,
                             int64_t tgtDim) {
  Value permuted =
      permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
  indices[tgtDim] = permuted;
}
/// Collects the ops nested under `parentOp` (including `parentOp` itself) that
/// read or write `shmMemRef`.
///
/// An op reporting a Read effect on the buffer is appended to `readOps`.
/// Otherwise, an op reporting a Write effect is appended to `writeOps`.
///
/// Returns failure when the accesses cannot all be rewritten:
///  - an op without MemoryEffectOpInterface takes the buffer as an operand.
///    Its effects are unknown (e.g. a call or a loop carrying the buffer), so
///    it could access the buffer with the original, unswizzled layout.
///  - a read is not memref.load / vector.load / nvgpu.ldmatrix.
///  - a write is not memref.store / vector.store / nvgpu.device_async_copy.
///  - any collected access has fewer than two indices.
static LogicalResult
getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
                      SmallVector<Operation *, 16> &readOps,
                      SmallVector<Operation *, 16> &writeOps) {
  WalkResult walkResult = parentOp->walk([&](Operation *op) -> WalkResult {
    auto iface = dyn_cast<MemoryEffectOpInterface>(op);
    if (!iface) {
      // Unknown effects on a direct user of the buffer: rewriting only the
      // other accesses would make them disagree on the layout.
      if (llvm::is_contained(op->getOperands(), shmMemRef))
        return WalkResult::interrupt();
      return WalkResult::advance();
    }
    // An op that both reads and writes is classified as a read only.
    if (iface.getEffectOnValue<MemoryEffects::Read>(shmMemRef))
      readOps.push_back(op);
    else if (iface.getEffectOnValue<MemoryEffects::Write>(shmMemRef))
      writeOps.push_back(op);
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted())
    return failure();

  // Only these op kinds expose indices through getIndices/setIndices, and the
  // permutation needs a source and a target index.
  auto isSupportedRead = [](Operation *op) {
    return isa<memref::LoadOp, vector::LoadOp, nvgpu::LdMatrixOp>(op) &&
           getIndices(op).size() >= 2;
  };
  auto isSupportedWrite = [](Operation *op) {
    return isa<memref::StoreOp, vector::StoreOp, nvgpu::DeviceAsyncCopyOp>(
               op) &&
           getIndices(op).size() >= 2;
  };
  if (!llvm::all_of(readOps, isSupportedRead))
    return failure();
  if (!llvm::all_of(writeOps, isSupportedWrite))
    return failure();
  return success();
}
/// Permutes (XOR-swizzles) the last index of every read and write of the
/// shared-memory buffer `memrefValue` found under `parentOp`. The key comes
/// from the second-to-last index.
///
/// Returns failure, without modifying the IR, when the rewrite is not
/// applicable:
///  - `memrefValue` is not a memref in shared memory, or has rank < 2;
///  - the element type is not a scalar int/float, or its width is 0 or wider
///    than kDefaultVectorSizeBits;
///  - the last dimension is dynamic or not a power of two. The XOR keeps
///    indices in bounds only for power-of-two rows.
///  - any memref.subview exists under `parentOp` (no alias analysis is done);
///  - enough rows fit into one kSharedMemoryLineSizeBytes line that the
///    rewrite is skipped (rowsPerLine >= threadGroupSize);
///  - getShmReadAndWriteOps reports failure, or there are no reads or no
///    writes of the buffer.
llvm::LogicalResult
mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                Value memrefValue) {
  auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
  if (!memRefType || !NVGPUDialect::hasSharedMemoryAddressSpace(memRefType))
    return failure();

  // The permutation combines the last two indices. This also rejects 0-D
  // memrefs, which would otherwise query dimension -1 below.
  const int64_t rank = memRefType.getRank();
  if (rank < 2)
    return failure();
  const int64_t tgtDim = rank - 1;
  const int64_t srcDim = rank - 2;

  // The bit arithmetic needs a scalar element width in (0, vector width].
  // getElementTypeBitWidth() only works for int/float element types.
  if (!memRefType.getElementType().isIntOrFloat())
    return failure();
  const int64_t elementBits = memRefType.getElementTypeBitWidth();
  if (elementBits <= 0 || elementBits > kDefaultVectorSizeBits)
    return failure();

  // XOR-ing bits below log2(rowSize) stays inside [0, rowSize) only for a
  // static power-of-two row size. This also rules out a zero-size row.
  const int64_t rowSize = memRefType.getDimSize(tgtDim);
  if (ShapedType::isDynamic(rowSize) ||
      !llvm::isPowerOf2_64(static_cast<uint64_t>(rowSize)))
    return failure();

  // Abort if there are any sub-views anywhere under `parentOp`; no alias
  // analysis is performed.
  const bool hasSubView =
      parentOp
          ->walk([](memref::SubViewOp) { return WalkResult::interrupt(); })
          .wasInterrupted();
  if (hasSubView)
    return failure();

  // Skip the rewrite when a full group of rows fits in one line. The group
  // size is the number of vector accesses per line (128 B / 16 B = 8).
  const int64_t rowsPerLine =
      (8 * kSharedMemoryLineSizeBytes / elementBits) / rowSize;
  const int64_t threadGroupSize =
      kSharedMemoryLineSizeBytes / (kDefaultVectorSizeBits / 8);
  if (rowsPerLine >= threadGroupSize)
    return failure();

  SmallVector<Operation *, 16> shmReadOps;
  SmallVector<Operation *, 16> shmWriteOps;
  if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
                                   shmWriteOps)))
    return failure();
  // Both reads and writes are required; otherwise the buffer is left alone.
  if (shmReadOps.empty() || shmWriteOps.empty())
    return failure();

  // Every check has passed; from here on the IR is modified. Each access gets
  // its permuted indices computed right before it.
  OpBuilder builder(parentOp->getContext());
  auto rewriteAccess = [&](Operation *accessOp) {
    builder.setInsertionPoint(accessOp);
    auto oldIndices = getIndices(accessOp);
    SmallVector<Value, 4> newIndices(oldIndices.begin(), oldIndices.end());
    transformIndices(builder, accessOp->getLoc(), newIndices, memRefType,
                     srcDim, tgtDim);
    setIndices(accessOp, newIndices);
  };
  for (Operation *writeOp : shmWriteOps)
    rewriteAccess(writeOp);
  for (Operation *readOp : shmReadOps)
    rewriteAccess(readOp);
  return success();
}
namespace {
/// Pass wrapper that runs optimizeSharedMemoryReadsAndWrites for every
/// memref.alloc in the shared-memory address space under the pass's root op.
class OptimizeSharedMemoryPass
    : public nvgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
public:
  OptimizeSharedMemoryPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    // Collect first so the rewrite below does not mutate IR during the walk.
    SmallVector<memref::AllocOp> shmAllocOps;
    op->walk([&](memref::AllocOp allocOp) {
      if (NVGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
        shmAllocOps.push_back(allocOp);
    });
    for (memref::AllocOp allocOp : shmAllocOps) {
      // failure() means "not applicable to this allocation" and leaves the IR
      // untouched. Keep going so that one unsupported buffer does not prevent
      // the remaining ones from being optimized.
      if (failed(optimizeSharedMemoryReadsAndWrites(op, allocOp.getMemref())))
        continue;
    }
  }
};
} // namespace
/// Creates a new instance of the shared-memory optimization pass; the caller
/// takes ownership.
std::unique_ptr<Pass> mlir::nvgpu::createOptimizeSharedMemoryPass() {
  std::unique_ptr<Pass> pass = std::make_unique<OptimizeSharedMemoryPass>();
  return pass;
}