#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
using namespace mlir;
/// Returns true when the vector transfer `op` is a contiguous access that can
/// be mapped onto an async copy. All four conditions must hold:
///   - the permutation map is a minor identity,
///   - vector dimension 0 is marked in-bounds,
///   - the op has pure buffer (memref) semantics,
///   - the innermost dimension of the accessed memref has unit stride.
template <typename OpTy>
static bool isContiguousXferOp(OpTy op) {
  if (!op.getPermutationMap().isMinorIdentity())
    return false;
  if (!op.isDimInBounds(0))
    return false;
  if (!op.hasPureBufferSemantics())
    return false;
  auto accessedMemref =
      cast<MemRefType>(nvgpu::getMemrefOperand(op).getType());
  return isLastMemrefDimUnitStride(accessedMemref);
}
/// Returns true if `write` is a store this transform can turn into an async
/// copy: any `vector.store`, or a `vector.transfer_write` that carries no mask
/// and satisfies `isContiguousXferOp`.
static bool isContiguousStore(Operation *write) {
  auto transferWrite = dyn_cast<vector::TransferWriteOp>(write);
  if (!transferWrite)
    return isa<vector::StoreOp>(write);
  // Masked writes are never converted.
  if (transferWrite.getMask())
    return false;
  return isContiguousXferOp(transferWrite);
}
/// Returns true if `read` is a load this transform can use as the source of an
/// async copy: any `vector.load`, or a `vector.transfer_read` that satisfies
/// `isContiguousXferOp` (masks are validated separately by `getMaskOp`).
static bool isContiguousRead(Operation *read) {
  auto transferRead = dyn_cast<vector::TransferReadOp>(read);
  return transferRead ? isContiguousXferOp(transferRead)
                      : isa<vector::LoadOp>(read);
}
namespace {
/// Description of the mask feeding a `vector.transfer_read`.
///
/// `createMaskOp` is the `vector.create_mask` the mask originates from; it is
/// null when the read is unmasked (or the op is not a transfer_read at all).
/// `extractPosition` is empty when the mask is the create_mask result itself,
/// and otherwise holds the static position of the `vector.extract` that pulls
/// the 1-D mask out of an N-D create_mask.
struct TransferMask {
  vector::CreateMaskOp createMaskOp{};
  SmallVector<int64_t> extractPosition{};
};
} // namespace
/// Locate the `vector.create_mask` that defines the mask of `loadOp`.
///
/// Returns:
///  * an empty TransferMask (null `createMaskOp`) if `loadOp` is not a
///    `vector.transfer_read` or has no mask;
///  * {createMaskOp, {}} if the (1-D) mask is produced directly by a
///    `vector.create_mask`;
///  * {createMaskOp, position} if the mask is produced by a `vector.extract`
///    with a fully static position out of a `vector.create_mask`;
///  * failure() for every other mask producer. This includes a
///    `vector.extract` with a dynamic position: its static position then holds
///    `ShapedType::kDynamic` placeholders, which must not be materialized as
///    constant indices when computing the number of elements to read.
static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
  auto transferRead = dyn_cast<vector::TransferReadOp>(loadOp);
  if (!transferRead || !transferRead.getMask())
    return TransferMask{{}, {}};
  assert(transferRead.getMask().getType().getRank() == 1 &&
         "expected 1-D mask");

  // Case 1: the mask is the result of a vector.create_mask.
  if (auto maskOp =
          transferRead.getMask().getDefiningOp<vector::CreateMaskOp>())
    return TransferMask{maskOp, {}};

  // Case 2: the mask is vector.extract(vector.create_mask) with a static
  // position.
  if (auto extractOp =
          transferRead.getMask().getDefiningOp<vector::ExtractOp>()) {
    ArrayRef<int64_t> position = extractOp.getStaticPosition();
    // Dynamic indices show up as kDynamic sentinels; they cannot be turned
    // into the constant bounds checks built by buildNumReadElements.
    if (llvm::is_contained(position, ShapedType::kDynamic))
      return failure();
    if (auto maskOp =
            extractOp.getVector().getDefiningOp<vector::CreateMaskOp>())
      return TransferMask{maskOp, SmallVector<int64_t>(position)};
  }

  // Any other mask producer is not supported.
  return failure();
}
/// Build an index-typed SSA value with the number of elements `readOp` reads,
/// derived from its mask (see getMaskOp).
///
/// Returns a null Value when the read has no mask. For a direct 1-D
/// `vector.create_mask`, its single operand is returned. For
/// `vector.extract(vector.create_mask)` the emitted IR computes
///   select(pos_0 < size_0 && ... && pos_{k-1} < size_{k-1}, size_k, 0)
/// where pos_i are the static extract positions and size_i the create_mask
/// operands, size_k being the innermost one.
///
/// Precondition (asserted): getMaskOp(readOp) succeeds.
static Value buildNumReadElements(OpBuilder &b, Location loc,
                                  Operation *readOp) {
  FailureOr<TransferMask> maybeMask = getMaskOp(readOp);
  assert(succeeded(maybeMask) && "invalid transfer mask");
  TransferMask &mask = *maybeMask;

  // Unmasked read: there is no element count to report.
  if (!mask.createMaskOp)
    return Value();

  // Mask used as-is: the lone create_mask operand is the element count.
  if (mask.extractPosition.empty()) {
    assert(mask.createMaskOp.getNumOperands() == 1 &&
           "expected single operand");
    return mask.createMaskOp.getOperand(0);
  }

  // Mask extracted from an N-D create_mask: the selected row is set up to the
  // innermost size only if every extract position lies below the size of its
  // dimension; otherwise nothing is read.
  assert(mask.createMaskOp.getVectorType().getRank() -
                 mask.extractPosition.size() ==
             1 &&
         "expected N-D -> (N-1)-D extract");
  auto dimSizes = mask.createMaskOp->getOperands();
  Value allInBounds;
  // zip stops at the shorter range (the positions), so the innermost size is
  // not compared here; it is the value selected below.
  for (auto [idx, dimSize] : llvm::zip(mask.extractPosition, dimSizes)) {
    Value idxCst = b.create<arith::ConstantIndexOp>(loc, idx);
    Value inBounds = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                             idxCst, dimSize);
    if (allInBounds)
      allInBounds = b.create<arith::AndIOp>(loc, inBounds, allInBounds);
    else
      allInBounds = inBounds;
  }
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  return b.create<arith::SelectOp>(loc, allInBounds, dimSizes.back(), zero);
}
/// Return true if copying a 1-D vector of type `vecType` can be expressed as a
/// single async copy: the total copy size must be 4, 8 or 16 bytes.
///
/// Vectors whose element type is not an integer or float (e.g. `index`) are
/// rejected. They have no fixed bit width, and `getIntOrFloatBitWidth` would
/// assert on them.
///
/// NOTE(review): `memrefType` is currently unused; the alignment of the
/// accessed address is not checked here.
static bool resultsInSupportedAsyncCopy(MemRefType memrefType,
                                        VectorType vecType) {
  assert(vecType.getRank() == 1 && "expected 1-D vector");
  constexpr int64_t kSupportedCpAsyncAlignmentsInBytes[3] = {4, 8, 16};

  Type elementType = vecType.getElementType();
  if (!elementType.isIntOrFloat())
    return false;

  // The total number of copied bits must match one of the supported sizes.
  int64_t copySizeInBits =
      vecType.getNumElements() * elementType.getIntOrFloatBitWidth();
  return llvm::any_of(kSupportedCpAsyncAlignmentsInBytes,
                      [&](int64_t alignmentInBytes) {
                        return alignmentInBytes * 8 == copySizeInBits;
                      });
}
/// Rewrite eligible copies nested in `op` into `nvgpu.device_async_copy` ops.
/// Each run of copies is bundled with `nvgpu.device_async_create_group` and
/// followed by an `nvgpu.device_async_wait` on that group.
///
/// A store is eligible when all of the following hold:
///   - it is a contiguous store (vector.store or unmasked transfer_write) of a
///     1-D vector into a shared-memory memref;
///   - the stored value is produced by a contiguous read (vector.load or
///     transfer_read) from a memref that is not in shared memory;
///   - if that read is a masked transfer_read, its padding is the constant 0
///     and its mask is supported by getMaskOp;
///   - the copy size is accepted by resultsInSupportedAsyncCopy.
///
/// Eligible stores separated only by side-effect-free ops or by reads from
/// non-shared memory are placed in the same group. The original stores are
/// erased; the original reads are left in place.
///
/// \param rewriter Used to create the new ops and erase the old stores.
/// \param op Root op whose nested ops are scanned.
/// \param bypassL1 Request the L1-bypass attribute on the generated copies;
///        only applied to copies of exactly 16 bytes.
void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op,
                              bool bypassL1) {
  llvm::SmallSetVector<Operation *, 16> copyToSharedMem;

  // Phase 1: collect every store that can be converted into an async copy.
  op->walk([&](Operation *writeOp) {
    // Look for a contiguous 1-D vector store into shared memory.
    if (!isContiguousStore(writeOp))
      return;
    Value vectorVal = nvgpu::getValueStored(writeOp);
    if (cast<VectorType>(vectorVal.getType()).getRank() != 1)
      return;
    Value storeBase = nvgpu::getMemrefOperand(writeOp);
    if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
            cast<MemRefType>(storeBase.getType())))
      return;

    // The stored vector must come from a contiguous 1-D vector read...
    Operation *readOp = vectorVal.getDefiningOp();
    if (readOp == nullptr || !isContiguousRead(readOp))
      return;
    Value loadBase = nvgpu::getMemrefOperand(readOp);
    // ...that does not itself read from shared memory.
    if (nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
            cast<MemRefType>(loadBase.getType())))
      return;

    // A masked read becomes an async copy with fewer source elements. cp.async
    // zero-fills the remaining destination bytes, so the read is only
    // convertible if its padding value is the constant 0. Non-zero and
    // non-constant padding must be rejected.
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(readOp)) {
      if (transferRead.getMask()) {
        if (getConstantIntValue(transferRead.getPadding()) !=
            static_cast<int64_t>(0))
          return;
        if (failed(getMaskOp(readOp)))
          return;
      }
    }

    // Check both sides before emitting anything, so that every generated
    // DeviceAsyncCopyOp has a supported size.
    VectorType vecType = cast<VectorType>(vectorVal.getType());
    if (!resultsInSupportedAsyncCopy(cast<MemRefType>(loadBase.getType()),
                                     vecType) ||
        !resultsInSupportedAsyncCopy(cast<MemRefType>(storeBase.getType()),
                                     vecType))
      return;

    copyToSharedMem.insert(writeOp);
  });

  // Phase 2: greedily form groups of adjacent candidates and rewrite them.
  while (!copyToSharedMem.empty()) {
    // Start a group with the first remaining candidate.
    SmallVector<Operation *> group;
    Operation *writeOp = *copyToSharedMem.begin();
    copyToSharedMem.remove(writeOp);
    group.push_back(writeOp);
    Operation *nextNode = writeOp;

    // Scan forward in the same block for more candidates to add to the group.
    while ((nextNode = nextNode->getNextNode())) {
      // Skip ops without memory side effects.
      auto memInterface = dyn_cast<MemoryEffectOpInterface>(nextNode);
      if (memInterface && memInterface.hasNoEffect() &&
          !nextNode->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
        continue;
      // Skip vector reads that do not target shared memory.
      if (isa<vector::TransferReadOp, vector::LoadOp>(nextNode)) {
        Operation *readOp = nextNode;
        Value memrefOperand = nvgpu::getMemrefOperand(readOp);
        if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
                cast<MemRefType>(memrefOperand.getType()))) {
          continue;
        }
      }
      // Another candidate: pull it into the current group.
      if (copyToSharedMem.count(nextNode)) {
        copyToSharedMem.remove(nextNode);
        group.push_back(nextNode);
        continue;
      }
      // Any other op ends the group.
      break;
    }

    // Emit one async copy per store in the group, right before that store.
    SmallVector<Value> tokens;
    for (Operation *writeOp : group) {
      rewriter.setInsertionPoint(writeOp);
      Value vectorVal = nvgpu::getValueStored(writeOp);
      auto vectorType = cast<VectorType>(vectorVal.getType());
      int64_t numElements = vectorType.getNumElements();
      Operation *readOp = vectorVal.getDefiningOp();
      Value storeBase = nvgpu::getMemrefOperand(writeOp);
      Value loadBase = nvgpu::getMemrefOperand(readOp);
      Value numReadElements =
          buildNumReadElements(rewriter, writeOp->getLoc(), readOp);
      auto dstMemref = cast<MemRefType>(storeBase.getType());
      int64_t sizeInBytes =
          (dstMemref.getElementTypeBitWidth() * numElements) / 8;
      // The L1 bypass is only requested for 16-byte transfers.
      Value token = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
          writeOp->getLoc(), nvgpu::DeviceAsyncTokenType::get(op->getContext()),
          /*dst=*/storeBase, /*dstIndices=*/nvgpu::getIndices(writeOp),
          /*src=*/loadBase,
          /*srcIndices=*/nvgpu::getIndices(readOp),
          /*dstElements=*/rewriter.getIndexAttr(numElements),
          /*srcElements=*/numReadElements,
          /*bypassL1=*/bypassL1 && sizeInBytes == 16 ? rewriter.getUnitAttr()
                                                     : UnitAttr());
      tokens.push_back(token);
    }

    // Create the group and wait on it immediately. The insertion point is
    // still before the last store of the group, i.e. after its last copy.
    Value groupToken = rewriter.create<nvgpu::DeviceAsyncCreateGroupOp>(
        op->getLoc(), nvgpu::DeviceAsyncTokenType::get(op->getContext()),
        tokens);
    rewriter.create<nvgpu::DeviceAsyncWaitOp>(op->getLoc(), groupToken,
                                              nullptr);
    // The async copies replace the original stores.
    for (Operation *writeOp : group)
      rewriter.eraseOp(writeOp);
  }
}