#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h"
using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::nvgpu;
using namespace mlir::NVVM;
using namespace mlir::transform;
#define DEBUG_TYPE "nvgpu-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
/// Populates `patterns` with the NVGPU-to-NVVM conversion patterns.
///
/// `typeConverter` is downcast to an LLVMTypeConverter and extended with:
///   - a mapping from gpu address spaces to NVVM memory space numbers;
///   - conversions for the NVGPU token, accumulator, barrier-group and
///     descriptor types.
void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &converter = static_cast<LLVMTypeConverter &>(typeConverter);

  populateGpuMemorySpaceAttributeConversions(
      converter, [](gpu::AddressSpace space) -> unsigned {
        switch (space) {
        case gpu::AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case gpu::AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case gpu::AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return 0;
      });

  // Shared by the conversions below that lower to a plain integer type. It is
  // copied into each callback because the callbacks outlive this local.
  auto toConvertedInteger = [&converter](MLIRContext *ctx,
                                         unsigned width) -> Type {
    return converter.convertType(IntegerType::get(ctx, width));
  };

  // Async tokens become i32.
  converter.addConversion(
      [toConvertedInteger](nvgpu::DeviceAsyncTokenType type) -> Type {
        return toConvertedInteger(type.getContext(), 32);
      });

  // Barrier tokens become i64.
  converter.addConversion(
      [toConvertedInteger](nvgpu::MBarrierTokenType type) -> Type {
        return toConvertedInteger(type.getContext(), 64);
      });

  // Accumulators become a struct of identical inner structs: one inner struct
  // per kWgmmaSizeM rows, each holding `numMembers` scalars.
  converter.addConversion(
      [&converter](nvgpu::WarpgroupAccumulatorType type) -> Type {
        MLIRContext *ctx = type.getContext();
        auto fragmented = type.getFragmented();
        Type elemType = fragmented.getElementType();
        int64_t sizeM = fragmented.getDimSize(0);
        int64_t sizeN = fragmented.getDimSize(1);

        unsigned numMembers;
        if (elemType.isF32() || elemType.isInteger(32))
          numMembers = sizeN / 2;
        else if (elemType.isF16())
          numMembers = sizeN / 4;
        else
          llvm_unreachable("unsupported type for warpgroup accumulator");

        SmallVector<Type> innerStructBody(numMembers, elemType);
        auto innerStructType =
            LLVM::LLVMStructType::getLiteral(ctx, innerStructBody);

        SmallVector<Type> structBody;
        for (int row = 0; row < sizeM; row += kWgmmaSizeM)
          structBody.push_back(innerStructType);
        auto convertedType = LLVM::LLVMStructType::getLiteral(ctx, structBody);
        return converter.convertType(convertedType);
      });

  // Barrier groups are converted through their memref representation.
  converter.addConversion([&converter](nvgpu::MBarrierGroupType type) -> Type {
    return converter.convertType(
        getMBarrierMemrefType(type.getContext(), type));
  });

  // Warpgroup matrix descriptors become i64.
  converter.addConversion(
      [toConvertedInteger](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
        return toConvertedInteger(type.getContext(), 64);
      });

  // Tensor map descriptors become an LLVM pointer.
  converter.addConversion([](nvgpu::TensorMapDescriptorType type) -> Type {
    return LLVM::LLVMPointerType::get(type.getContext());
  });

  populateNVGPUToNVVMConversionPatterns(converter, patterns);
}
/// Accepts only a type-converter builder whose converter type is named
/// "LLVMTypeConverter"; emits an op error otherwise.
LogicalResult
transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  const bool isLLVMConverter =
      builder.getTypeConverterType() == "LLVMTypeConverter";
  if (isLLVMConverter)
    return success();
  return emitOpError("expected LLVMTypeConverter");
}
/// Declares the effects of this transform op: it consumes the target handle,
/// produces the result handles, and modifies the payload.
void transform::CreateAsyncGroupsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getTargetMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}
/// Runs nvgpu::createAsyncGroups on `target` and maps the result handle to
/// `target` itself. Always reports success.
DiagnosedSilenceableFailure transform::CreateAsyncGroupsOp::applyToOne(
    TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, TransformState &state) {
  const bool bypassL1 = getBypassL1();
  nvgpu::createAsyncGroups(rewriter, target, bypassL1);
  results.push_back(target);
  return DiagnosedSilenceableFailure::success();
}
/// Returns true if `type` has no memory space attribute or its memory space,
/// read as an integer, is 0.
static bool hasDefaultMemorySpace(BaseMemRefType type) {
  if (!type.getMemorySpace())
    return true;
  return type.getMemorySpaceAsInt() == 0;
}
/// Returns true if the memory space of `type` is a gpu::AddressSpaceAttr
/// holding the workgroup address space.
static bool hasSharedMemorySpace(BaseMemRefType type) {
  auto addressSpace =
      dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
  if (!addressSpace)
    return false;
  return addressSpace.getValue() ==
         gpu::GPUDialect::getWorkgroupAddressSpace();
}
/// Returns the value produced by `op` if it is a vector.transfer_read from a
/// memref in the default memory space; returns a null Value otherwise.
static Value getValueLoadedFromGlobal(Operation *op) {
  if (auto read = dyn_cast<vector::TransferReadOp>(op)) {
    auto sourceType = dyn_cast<MemRefType>(read.getSource().getType());
    if (sourceType && hasDefaultMemorySpace(sourceType))
      return read;
  }
  return nullptr;
}
/// Returns true if `op` is a vector.transfer_write that stores `v` into a
/// memref whose memory space is the gpu workgroup address space.
static bool isStoreToShared(Operation *op, Value v) {
  auto store = dyn_cast<vector::TransferWriteOp>(op);
  if (!store || store.getVector() != v)
    return false;
  auto storeType = dyn_cast<MemRefType>(store.getSource().getType());
  // Both conditions are required: the destination must be a memref AND live in
  // workgroup memory. Using `||` here would accept any memref destination and
  // would query the memory space of a null type for non-memref destinations.
  return storeType && hasSharedMemorySpace(storeType);
}
/// Returns true if `op` loads a value from the default memory space and the
/// single use of that value is a store accepted by isStoreToShared.
static bool isLoadFromGlobalStoredToShared(Operation *op) {
  Value loaded = getValueLoadedFromGlobal(op);
  if (loaded && loaded.hasOneUse()) {
    Operation *onlyUser = *loaded.getUsers().begin();
    return isStoreToShared(onlyUser, loaded);
  }
  return false;
}
/// Collects into `ops` the operations of the body of `forOp` that belong to
/// stage 0 of the pipeline: async copies, async group creations, the barriers
/// that precede them, and global loads whose only use is a store to shared
/// memory.
///
/// Returns failure if any operation in the loop body has regions.
static LogicalResult
collectStage0PipeliningOps(scf::ForOp forOp,
                           llvm::SmallPtrSet<Operation *, 16> &ops) {
  // Barriers seen since the last async copy / group creation. They are only
  // added to `ops` once such an op follows them.
  llvm::SmallPtrSet<Operation *, 4> barriers;
  for (Operation &op : *forOp.getBody()) {
    // Nested regions are not handled.
    if (op.getNumRegions() > 0)
      return failure();
    if (isa<gpu::BarrierOp>(op)) {
      barriers.insert(&op);
      continue;
    }
    if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) {
      ops.insert(&op);
      // Transfer the pending barriers. "Moving" raw pointers out of a set does
      // not remove them from it, so the pending set is cleared explicitly.
      ops.insert(barriers.begin(), barriers.end());
      barriers.clear();
      continue;
    }
    if (isLoadFromGlobalStoredToShared(&op)) {
      ops.insert(&op);
      continue;
    }
  }
  return success();
}
/// Pipeliner annotation callback: if `op` is an nvgpu.device_async_wait with
/// no group count set, sets its group count from the pipeline `depth`.
///
/// In the kernel and prologue the count is `depth - 1`; in the epilogue it is
/// `depth - 1 - iteration`. `builder` is unused.
static void
setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration, unsigned depth) {
  using Part = scf::PipeliningOption::PipelinerPart;

  auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op);
  if (!waitOp)
    return;
  // Leave waits that already carry a group count untouched.
  if (waitOp.getNumGroups())
    return;

  int groupsInFlight = depth - 1;
  if (part != Part::Kernel && part != Part::Prologue) {
    assert(part == Part::Epilogue);
    groupsInFlight = depth - 1 - iteration;
  }
  waitOp.setNumGroups(groupsInFlight);
}
/// Fills `opsWithPipelineStages` with a (operation, stage) schedule for the
/// body of `forOp`.
///
/// The backward slice (restricted to the loop body, inclusive) of every op in
/// `stage0Ops` is assigned stage 0; every other op except the scf.yield is
/// assigned stage `depth`. Ops of stage `depth` are emitted first, followed by
/// the ops of stage 0, each group in body order.
static void getPipelineStages(
    scf::ForOp forOp,
    std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
    unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) {
  Block *body = forOp.getBody();

  // Gather everything the stage-0 ops depend on inside the loop body.
  SetVector<Operation *> dependencies;
  BackwardSliceOptions options(
      [body](Operation *visited) { return visited->getBlock() == body; });
  options.inclusive = true;
  for (Operation &op : *body) {
    if (!stage0Ops.contains(&op))
      continue;
    getBackwardSlice(&op, &dependencies, options);
  }

  // First the ops that are not part of the stage-0 slice ...
  for (Operation &op : *body) {
    if (dependencies.contains(&op) || isa<scf::YieldOp>(op))
      continue;
    opsWithPipelineStages.emplace_back(&op, depth);
  }
  // ... then the stage-0 slice itself.
  for (Operation &op : *body) {
    if (!dependencies.contains(&op))
      continue;
    opsWithPipelineStages.emplace_back(&op, 0);
  }
}
/// Pipeliner predication callback.
///
/// Returns `op` unchanged when it is memory-effect free or is a gpu.barrier,
/// an async group creation, or an async wait. An nvgpu.device_async_copy is
/// replaced by a copy whose source element count is
/// `select(predicate, originalCount, 0)`, and the new op is returned. Any
/// other op yields nullptr.
static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
                                            Operation *op, Value predicate) {
  const bool needsNoPredicate =
      isMemoryEffectFree(op) ||
      isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp,
          nvgpu::DeviceAsyncWaitOp>(op);
  if (needsNoPredicate)
    return op;

  auto copyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op);
  if (!copyOp)
    return nullptr;

  Location loc = copyOp->getLoc();
  // The constant is materialized unconditionally, and is used as the original
  // count when the copy has no explicit source element count.
  Value dstCount =
      rewriter.create<arith::ConstantOp>(loc, copyOp.getDstElementsAttr());
  Value unpredicatedCount = dstCount;
  if (Value explicitCount = copyOp.getSrcElements())
    unpredicatedCount = explicitCount;

  Value zeroIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto predicatedCount = rewriter.create<arith::SelectOp>(
      loc, predicate, unpredicatedCount, zeroIndex);

  auto tokenType = nvgpu::DeviceAsyncTokenType::get(copyOp.getContext());
  auto predicatedCopy = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
      loc, tokenType, copyOp.getDst(), copyOp.getDstIndices(), copyOp.getSrc(),
      copyOp.getSrcIndices(), copyOp.getDstElements(), predicatedCount,
      UnitAttr());
  rewriter.replaceOp(copyOp, predicatedCopy);
  return predicatedCopy;
}
/// Software-pipelines `forOp` so that its stage-0 ops (see
/// collectStage0PipeliningOps) run `depth` stages ahead of the rest.
///
/// When `epiloguePeeling` is false the epilogue is not peeled and ops are
/// predicated through replaceOpWithPredicatedOp instead.
///
/// Returns the diagnostic together with the pipelined loop. On failure the
/// loop is null; the failure is definite if the pipeliner reports that it
/// modified the IR, silenceable otherwise.
static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth,
                        bool epiloguePeeling) {
  llvm::SmallPtrSet<Operation *, 16> stage0Ops;
  if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "cannot find stage 0 ops for pipelining"),
        scf::ForOp());
  }
  if (stage0Ops.empty()) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "no shared memory copy"), scf::ForOp());
  }

  unsigned maxDepth = depth;
  scf::PipeliningOption options;
  // Only the loop being pipelined gets a schedule; any other loop is left
  // without one.
  options.getScheduleFn =
      [&](scf::ForOp schedulingFor,
          std::vector<std::pair<Operation *, unsigned>> &schedule) {
        if (schedulingFor == forOp)
          getPipelineStages(forOp, schedule, maxDepth, stage0Ops);
      };
  options.annotateFn = [&rewriter, maxDepth](
                           Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration) {
    setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth);
  };
  if (!epiloguePeeling) {
    options.peelEpilogue = false;
    options.predicateFn = replaceOpWithPredicatedOp;
  }

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forOp);
  bool modifiedIR;
  FailureOr<scf::ForOp> maybePipelined =
      pipelineForLoop(rewriter, forOp, options, &modifiedIR);
  if (failed(maybePipelined)) {
    return std::make_tuple(
        modifiedIR
            ? DiagnosedSilenceableFailure::definiteFailure()
            : emitSilenceableFailure(forOp, "pipelining preconditions failed"),
        scf::ForOp());
  }
  return std::make_tuple(DiagnosedSilenceableFailure::success(),
                         *maybePipelined);
}
/// Pipelines `forOp` with the depth and epilogue-peeling mode configured on
/// this op. On success the pipelined loop is pushed to `results`.
///
/// A definite failure of the pipeliner is re-emitted as an "irreversible
/// pipelining failure", with extra notes when epilogue peeling is disabled. A
/// silenceable failure is forwarded as is.
DiagnosedSilenceableFailure PipelineSharedMemoryCopiesOp::applyToOne(
    TransformRewriter &rewriter, scf::ForOp forOp,
    ApplyToEachResultList &results, TransformState &state) {
  auto [diag, pipelined] = pipelineForSharedCopies(
      rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());
  if (diag.succeeded()) {
    results.push_back(pipelined);
    return DiagnosedSilenceableFailure::success();
  }

  if (!diag.isDefiniteFailure())
    return std::move(diag);

  auto definite = emitDefiniteFailure("irreversible pipelining failure");
  if (!getPeelEpilogue()) {
    definite.attachNote(forOp->getLoc()) << "couldn't predicate?";
    definite.attachNote(getLoc())
        << "try setting " << getPeelEpilogueAttrName();
  }
  return definite;
}
struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> {
RowColIndexing(AffineExpr row, AffineExpr col)
: std::pair<AffineExpr, AffineExpr>(row, col) {}
AffineExpr row() const { return first; };
AffineExpr col() const { return second; };
void print(llvm::raw_ostream &os) const {
os << "- indexing: " << first << ", " << second;
}
};
struct MmaSyncBuilder {
MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
: b(b), loc(loc), laneId(laneId) {}
using IndexCalculator =
std::function<SmallVector<RowColIndexing>(MLIRContext *)>;
FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);
private:
struct MmaSyncInfo {
std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, SmallVector<int64_t>>
vectorShapes;
SmallVector<int64_t> mmaShape;
bool tf32Enabled;
};
FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape,
TypeRange elementalTypes);
static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) {
auto dim = getAffineDimExpr(0, ctx);
AffineExpr groupID = dim.floorDiv(4);
AffineExpr threadIDInGroup = dim % 4;
return {RowColIndexing{groupID, threadIDInGroup},
RowColIndexing{groupID + 8, threadIDInGroup}};
}
static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) {
auto dim = getAffineDimExpr(0, ctx);
AffineExpr groupID = dim.floorDiv(4);
AffineExpr threadIDInGroup = dim % 4;
return {RowColIndexing{threadIDInGroup, groupID}};
}
static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) {
auto dim = getAffineDimExpr(0, ctx);
AffineExpr groupID = dim.floorDiv(4);
AffineExpr threadIDInGroup = dim % 4;
return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
RowColIndexing{groupID, threadIDInGroup * 2 + 1},
RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};
}
static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) {
auto dim = getAffineDimExpr(0, ctx);
AffineExpr groupID = dim.floorDiv(4);
AffineExpr threadIDInGroup = dim % 4;
return {
RowColIndexing{groupID, threadIDInGroup * 2 + 0},
RowColIndexing{groupID, threadIDInGroup * 2 + 1},
RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1},
RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8},
RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8},
RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8},
RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8}
};
}
static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) {
auto dim = getAffineDimExpr(0, ctx);
AffineExpr groupID = dim.floorDiv(4);
AffineExpr threadIDInGroup = dim % 4;
return {
RowColIndexing{threadIDInGroup * 2 + 0, groupID},
RowColIndexing{threadIDInGroup * 2 + 1, groupID},
RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID},
RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID}
};
}
static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) {
auto dim = getAffineDimExpr(0, ctx);
AffineExpr groupID = dim.floorDiv(4);
AffineExpr threadIDInGroup = dim % 4;
return {
RowColIndexing{groupID, threadIDInGroup * 2 + 0},
RowColIndexing{groupID, threadIDInGroup * 2 + 1},
RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}
};
}
SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc,
OpFoldResult laneId, Value memref,
const IndexCalculator &indexFn);
Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc,
OpFoldResult laneId, Value memref,
IndexCalculator indexFn,
ArrayRef<int64_t> vectorShape);
SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc,
ValueRange toStore,
OpFoldResult laneId, Value memref,
const IndexCalculator &indexFn);
SmallVector<Operation *> buildMmaSyncMemRefStoreOperand(
OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);
OpBuilder &b;
Location loc;
OpFoldResult laneId;
};
/// Visits every element position of the vector type of `vector` in linearized
/// order. For each position, calls `applyFn(vector, linearIdx, indices)` and
/// forwards its result to `reduceFn(result, linearIdx, indices)`.
template <typename ApplyFn, typename ReduceFn>
static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
                                           ReduceFn reduceFn) {
  auto vecType = cast<VectorType>(vector.getType());
  auto shape = vecType.getShape();
  auto strides = computeStrides(shape);
  // Total element count: leading dimension times its stride.
  const int64_t numElements = shape[0] * strides[0];
  for (int64_t linearIdx = 0; linearIdx < numElements; ++linearIdx) {
    auto indices = delinearize(linearIdx, strides);
    reduceFn(applyFn(vector, linearIdx, indices), linearIdx, indices);
  }
}
/// Emits one memref.load from `memref` per (row, col) indexing produced by
/// `indexFn`. Each index expression is applied to `laneId` and materialized as
/// an index value. Returns the loaded scalars in indexing order.
SmallVector<Value>
MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
                                 OpFoldResult laneId, Value memref,
                                 const IndexCalculator &indexFn) {
  // Applies `expr` to the lane id and turns the folded result into a Value.
  auto materialize = [&](AffineExpr expr) -> Value {
    OpFoldResult folded =
        affine::makeComposedFoldedAffineApply(b, loc, expr, laneId);
    return getValueOrCreateConstantIndexOp(b, loc, folded);
  };

  SmallVector<Value> loaded;
  for (const RowColIndexing &indexing : indexFn(b.getContext())) {
    Value row = materialize(indexing.row());
    Value col = materialize(indexing.col());
    loaded.push_back(
        b.create<memref::LoadOp>(loc, memref, ValueRange{row, col}));
  }
  return loaded;
}
/// Loads from `memref` the scalars selected by `indexFn` and packs them into a
/// vector of shape `vectorShape`: the vector starts as a splat of the first
/// scalar, then every position is overwritten with the scalar of the same
/// linear index.
Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
    OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
    IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  SmallVector<Value> scalars =
      buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn));

  auto vectorType =
      VectorType::get(vectorShape, getElementTypeOrSelf(memref.getType()));
  Value packed = b.create<vector::SplatOp>(loc, vectorType, scalars[0]);

  auto pickScalar = [&](Value, int64_t linearIdx, ArrayRef<int64_t>) {
    return scalars[linearIdx];
  };
  auto insertScalar = [&](Value scalar, int64_t, ArrayRef<int64_t> indices) {
    packed = b.create<vector::InsertOp>(loc, scalar, packed, indices);
  };
  foreachIndividualVectorElement(packed, pickScalar, insertScalar);
  return packed;
}
/// Emits one memref.store into `memref` per value of `toStore`, at the
/// (row, col) position given by the matching indexing of `indexFn`. The number
/// of indexings must equal the number of values. Returns the stores in order.
SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores(
    OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId,
    Value memref, const IndexCalculator &indexFn) {
  // Applies `expr` to the lane id and turns the folded result into a Value.
  auto materialize = [&](AffineExpr expr) -> Value {
    OpFoldResult folded =
        affine::makeComposedFoldedAffineApply(b, loc, expr, laneId);
    return getValueOrCreateConstantIndexOp(b, loc, folded);
  };

  SmallVector<RowColIndexing> indexings = indexFn(b.getContext());
  SmallVector<Operation *> stores;
  for (auto [indexing, value] : llvm::zip_equal(indexings, toStore)) {
    Value row = materialize(indexing.row());
    Value col = materialize(indexing.col());
    Operation *store =
        b.create<memref::StoreOp>(loc, value, memref, ValueRange{row, col});
    stores.push_back(store);
  }
  return stores;
}
/// Extracts every scalar of `vectorToStore` in linearized order and stores
/// them into `memref` at the positions selected by `indexFn`. Returns the
/// store operations. `vectorShape` is unused.
SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
    OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
    Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  SmallVector<Value> scalars;
  scalars.reserve(32);

  auto extractScalar = [&](Value, int64_t, ArrayRef<int64_t> indices) {
    return b.create<vector::ExtractOp>(loc, vectorToStore, indices);
  };
  auto collectScalar = [&](Value scalar, int64_t, ArrayRef<int64_t>) {
    scalars.push_back(scalar);
  };
  foreachIndividualVectorElement(vectorToStore, extractScalar, collectScalar);

  return buildMemRefStores(b, loc, scalars, laneId, memref, std::move(indexFn));
}
/// Copies the three shapes into owning vectors and bundles them as a
/// (lhs, rhs, res) tuple.
static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
                  SmallVector<int64_t>>
makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
                 ArrayRef<int64_t> res) {
  using Shape = SmallVector<int64_t>;
  return std::make_tuple(Shape(lhs.begin(), lhs.end()),
                         Shape(rhs.begin(), rhs.end()),
                         Shape(res.begin(), res.end()));
}
/// Returns the index calculators, per-lane vector shapes, mma shape and tf32
/// flag for the given op shape and (lhs, rhs, res) element types.
///
/// Two combinations are supported: shape {16, 8, 4} with all-f32 types (tf32
/// enabled) and shape {16, 8, 16} with all-f16 types. Anything else fails.
FailureOr<MmaSyncBuilder::MmaSyncInfo>
MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
                                    TypeRange elementalTypes) {
  Type f16 = b.getF16Type();
  Type f32 = b.getF32Type();

  const bool isM16N8K4F32 = opShape == ArrayRef<int64_t>{16, 8, 4} &&
                            elementalTypes == TypeRange{f32, f32, f32};
  const bool isM16N8K16F16 = opShape == ArrayRef<int64_t>{16, 8, 16} &&
                             elementalTypes == TypeRange{f16, f16, f16};

  if (isM16N8K4F32) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
                                       &MmaSyncBuilder::m16n8k4tf32Rhs,
                                       &MmaSyncBuilder::m16n8k4tf32Res),
                       makeVectorShapes({2, 1}, {1, 1}, {2, 2}),
                       SmallVector<int64_t>{opShape.begin(), opShape.end()},
                       /*tf32Enabled=*/true};
  }
  if (isM16N8K16F16) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
                                       &MmaSyncBuilder::m16n8k16f16Rhs,
                                       &MmaSyncBuilder::m16n8k16f16Res),
                       makeVectorShapes({4, 2}, {2, 2}, {2, 2}),
                       SmallVector<int64_t>{opShape.begin(), opShape.end()},
                       /*tf32Enabled=*/false};
  }
  return failure();
}
/// Emits the mma.sync implementation of `linalgOp`, whose two inputs and one
/// init operand are expected to be 2-D memrefs.
///
/// The problem size is read as m = lhs dim 0, n = rhs dim 1, k = lhs dim 1.
/// Returns the operation defining the mma.sync result, or failure when
/// getIndexCalculators does not support the shape / element types.
FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
  Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
  Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
  Value resMemRef = linalgOp.getDpsInitOperand(0)->get();

  auto lhsMemRefType = cast<MemRefType>(lhsMemRef.getType());
  auto rhsMemRefType = cast<MemRefType>(rhsMemRef.getType());
  auto resMemRefType = cast<MemRefType>(resMemRef.getType());
  assert(lhsMemRefType.getRank() == 2 && "expected lhs to be a 2D memref");
  assert(rhsMemRefType.getRank() == 2 && "expected rhs to be a 2D memref");
  assert(resMemRefType.getRank() == 2 && "expected res to be a 2D memref");

  int64_t m = lhsMemRefType.getShape()[0];
  int64_t n = rhsMemRefType.getShape()[1];
  int64_t k = lhsMemRefType.getShape()[1];
  Type lhsType = getElementTypeOrSelf(lhsMemRefType);
  Type rhsType = getElementTypeOrSelf(rhsMemRefType);
  Type resType = getElementTypeOrSelf(resMemRefType);

  FailureOr<MmaSyncInfo> maybeInfo =
      getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
  if (failed(maybeInfo))
    return failure();

  MmaSyncInfo info = *maybeInfo;
  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;

  // Per-lane operand vectors, loaded in lhs / rhs / res order.
  Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
                                            lhsIndexFn, lhsShape);
  Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
                                            rhsIndexFn, rhsShape);
  Value acc = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
                                            resIndexFn, resShape);

  Value result = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, acc, info.mmaShape,
                                            info.tf32Enabled);
  // Write the result back to the init operand.
  buildMmaSyncMemRefStoreOperand(b, loc, result, laneId, resMemRef, resIndexFn,
                                 resShape);
  return result.getDefiningOp();
}
/// Rewrites `linalgOp` as mma.sync when it is a linalg.matmul supported by
/// MmaSyncBuilder, then erases the original op. The lane id is taken from
/// gpu.thread_id x. Emits a silenceable error for any unsupported target.
DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
    transform::TransformRewriter &rewriter, LinalgOp linalgOp,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  bool rewritten = false;
  if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
    Location loc = linalgOp.getLoc();
    Value laneId = rewriter.create<gpu::ThreadIdOp>(
        loc, rewriter.getIndexType(), gpu::Dimension::x);
    rewritten = succeeded(
        MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp));
  }

  if (!rewritten) {
    DiagnosedSilenceableFailure diag = emitSilenceableError()
                                       << "unsupported target op: " << linalgOp;
    diag.attachNote(linalgOp->getLoc()) << "target op";
    return diag;
  }

  rewriter.eraseOp(linalgOp);
  return DiagnosedSilenceableFailure::success();
}
/// Helper that emits, through `rewriter` at location `loc`, the NVGPU ops used
/// by the TMA-based copy rewrite in this file: mbarrier creation and
/// initialization, TMA descriptor creation, async loads, and barrier waits.
struct HopperBuilder {
HopperBuilder(RewriterBase &rewriter, Location loc)
: rewriter(rewriter), loc(loc) {}
/// Creates an mbarrier group in the workgroup address space, initializes it
/// with `numThreads`, emits a gpu.barrier, and returns the mbarrier value.
TypedValue<nvgpu::MBarrierGroupType>
buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);
/// Creates, right before `launchOp`, a TMA descriptor for `memref` and
/// returns it.
TypedValue<nvgpu::TensorMapDescriptorType>
buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
gpu::LaunchOp launchOp);
/// Emits a TMA async load from `globalDesc` into `sharedMemref` signalling
/// `barrier`, appends the load op to `loadOps`, and returns the size of
/// `sharedMemref` in bytes.
OpFoldResult
buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
TypedValue<MemRefType> sharedMemref,
TypedValue<nvgpu::MBarrierGroupType> barrier,
SmallVectorImpl<Operation *> &loadOps);
/// Emits an mbarrier arrive-expect-tx on `barrier` for the sum of `sizes`.
/// `sizes` must not be empty.
void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierGroupType> barrier,
ArrayRef<OpFoldResult> sizes);
/// Emits an scf.if on `gpu.thread_id x == 0`: the then-branch issues one TMA
/// load per (descriptor, buffer) pair followed by an arrive-expect-tx for the
/// total size; the else-branch issues an arrive-expect-tx for size 0.
/// Returns the created load ops.
SmallVector<Operation *> buildPredicateLoadsOnThread0(
ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
TypedValue<nvgpu::MBarrierGroupType> barrier);
/// Emits an mbarrier try-wait-parity on `barrier` with parity 0.
void buildTryWaitParity(TypedValue<nvgpu::MBarrierGroupType> barrier);
/// Rewriter used to create every op.
RewriterBase &rewriter;
/// Location attached to every created op.
Location loc;
};
/// Emits an scf.if guarded by `gpu.thread_id x == 0`.
///
/// The then-branch issues one TMA async load per (descriptor, shared buffer)
/// pair and an arrive-expect-tx on `barrier` for the loaded sizes. The
/// else-branch issues an arrive-expect-tx for size 0. Returns the load ops
/// created in the then-branch.
SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
    ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
    ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
    TypedValue<nvgpu::MBarrierGroupType> barrier) {
  SmallVector<Operation *> loadOps;

  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value tidx = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
  Value isThread0 =
      rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, tidx, zero);

  // Thread 0: issue the loads and announce the number of bytes in flight.
  auto issueLoads = [&](OpBuilder &lb, Location loc) {
    SmallVector<OpFoldResult> sizes;
    sizes.reserve(globalDescriptors.size());
    for (auto [desc, shmem] :
         llvm::zip_equal(globalDescriptors, sharedMemBuffers))
      sizes.push_back(buildTmaAsyncLoad(desc, shmem, barrier, loadOps));
    buildBarrierArriveTx(barrier, sizes);
    rewriter.create<scf::YieldOp>(loc);
  };
  // Every other thread: arrive with a size of 0.
  auto arriveOnly = [&](OpBuilder &lb, Location loc) {
    buildBarrierArriveTx(barrier,
                         getAsIndexOpFoldResult(rewriter.getContext(), 0));
    rewriter.create<scf::YieldOp>(loc);
  };
  rewriter.create<scf::IfOp>(loc, isThread0, issueLoads, arriveOnly);

  return loadOps;
}
/// Returns the gpu address space attribute for the workgroup address space.
static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
  MLIRContext *ctx = b.getContext();
  auto workgroupSpace = gpu::GPUDialect::getWorkgroupAddressSpace();
  return gpu::AddressSpaceAttr::get(ctx, workgroupSpace);
}
/// Creates an mbarrier group in the workgroup address space, initializes it
/// with `numThreads` (using index 0 as the mbarrier id), and emits a
/// gpu.barrier afterwards. Returns the mbarrier group value.
TypedValue<nvgpu::MBarrierGroupType>
HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
  auto barrierType = nvgpu::MBarrierGroupType::get(
      rewriter.getContext(), getSharedAddressSpaceAttribute(rewriter));
  Value barrier = rewriter.create<nvgpu::MBarrierCreateOp>(loc, barrierType);

  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value threadCount = getValueOrCreateConstantIndexOp(rewriter, loc, numThreads);
  rewriter.create<nvgpu::MBarrierInitOp>(loc, barrier, threadCount, zero,
                                         Value());

  rewriter.create<gpu::BarrierOp>(loc);
  return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier);
}
/// Creates a TMA descriptor for `memref`, inserting all ops right before
/// `launchOp`.
///
/// `memref` is cast to an unranked memref and passed, together with its sizes,
/// to nvgpu.tma.create.descriptor. The descriptor type carries the type of
/// `memref` with its memory space replaced by the workgroup address space, no
/// swizzle, no L2 promotion, zero out-of-bounds fill, and no interleave.
TypedValue<nvgpu::TensorMapDescriptorType>
HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                                           gpu::LaunchOp launchOp) {
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(launchOp);

  MemRefType memrefType = memref.getType();
  auto unrankedType = UnrankedMemRefType::get(memrefType.getElementType(),
                                              memrefType.getMemorySpace());
  Value unrankedMemRef =
      rewriter.create<memref::CastOp>(loc, unrankedType, memref);

  SmallVector<OpFoldResult> mixedSizes =
      memref::getMixedSizes(rewriter, loc, memref);
  SmallVector<Value> sizes =
      getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);

  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
  MemRefType sharedType =
      MemRefType::Builder(memrefType).setMemorySpace(sharedMemorySpace);
  auto descType = nvgpu::TensorMapDescriptorType::get(
      rewriter.getContext(), sharedType, TensorMapSwizzleKind::SWIZZLE_NONE,
      TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
      TensorMapInterleaveKind::INTERLEAVE_NONE);

  Value desc = rewriter.create<nvgpu::TmaCreateDescriptorOp>(
      loc, descType, unrankedMemRef, sizes);
  return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
}
/// Emits an nvgpu.tma.async.load from `globalDesc` at coordinates (0, 0) into
/// `sharedMemref`, signalling `barrier` (mbarrier id 0), and appends the load
/// op to `loadOps`.
///
/// Returns the size of `sharedMemref` in bytes: the product of its sizes times
/// its element bit width divided by 8.
OpFoldResult HopperBuilder::buildTmaAsyncLoad(
    TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
    TypedValue<MemRefType> sharedMemref,
    TypedValue<nvgpu::MBarrierGroupType> barrier,
    SmallVectorImpl<Operation *> &loadOps) {
  MLIRContext *ctx = rewriter.getContext();

  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Operation *loadOp = rewriter.create<nvgpu::TmaAsyncLoadOp>(
      loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero,
      Value(), Value());
  loadOps.push_back(loadOp);

  // One symbol per dimension; their product is the element count.
  auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr elementCount = computeProduct(ctx, symbols);
  AffineExpr sizeInBytes =
      elementCount * (sharedMemref.getType().getElementTypeBitWidth() / 8);

  return affine::makeComposedFoldedAffineApply(rewriter, loc, sizeInBytes,
                                               mixedSizes);
}
/// Emits an nvgpu.mbarrier.arrive.expect_tx on `barrier` (mbarrier id 0) whose
/// transaction count is the sum of `mixedSizes`. `mixedSizes` must not be
/// empty.
void HopperBuilder::buildBarrierArriveTx(
    TypedValue<nvgpu::MBarrierGroupType> barrier,
    ArrayRef<OpFoldResult> mixedSizes) {
  assert(!mixedSizes.empty() && "expecte non-empty sizes");
  MLIRContext *ctx = rewriter.getContext();

  // One symbol per size; fold their sum into a single value.
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  OpFoldResult total = affine::makeComposedFoldedAffineApply(
      rewriter, loc, computeSum(ctx, symbols), mixedSizes);
  Value totalValue = getValueOrCreateConstantIndexOp(rewriter, loc, total);

  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, totalValue,
                                                   zero, Value());
}
/// Emits an nvgpu.mbarrier.try_wait.parity on `barrier` (mbarrier id 0) with
/// an i1 parity constant of 0 and a ticks operand of 10000000.
void HopperBuilder::buildTryWaitParity(
    TypedValue<nvgpu::MBarrierGroupType> barrier) {
  Value parity =
      rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), 0);
  Value ticks = rewriter.create<arith::ConstantIndexOp>(loc, 10000000);
  Value mbarId = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity, ticks,
                                                  mbarId);
}
/// HopperBuilder specialization that rewrites a group of linalg.copy ops into
/// TMA async loads.
struct CopyBuilder : public HopperBuilder {
CopyBuilder(RewriterBase &rewriter, Location loc)
: HopperBuilder(rewriter, loc) {}
/// Replaces `copyOps` with TMA loads synchronized through a single mbarrier
/// and returns the created load operations. Returns an empty vector when
/// `copyOps` is empty.
SmallVector<Operation *> rewrite(ArrayRef<Operation *> copyOps);
};
/// Rewrites `copyOps` (linalg.copy ops with 2-D memref inputs, nested in a
/// gpu.launch) into TMA async loads:
///   1. an mbarrier is created before the first copy and initialized with the
///      product of the launch block sizes;
///   2. one TMA descriptor per copy input is created before the launch op;
///   3. the loads are issued on thread 0, followed by a try-wait on the
///      barrier;
///   4. the original copy ops are erased.
/// Returns the created load operations, or an empty vector if `copyOps` is
/// empty.
SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
  MLIRContext *ctx = rewriter.getContext();
  if (copyOps.empty())
    return {};

  Operation *firstCopy = copyOps.front();
  auto launchOp = firstCopy->getParentOfType<gpu::LaunchOp>();
  assert(launchOp && "expected launch op");

  // Barrier creation, sized by the number of threads in the block.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(firstCopy);
  AffineExpr bx, by, bz;
  bindSymbols(ctx, bx, by, bz);
  AffineExpr blockVolume =
      computeProduct(ctx, ArrayRef<AffineExpr>{bx, by, bz});
  OpFoldResult numThreads = affine::makeComposedFoldedAffineApply(
      rewriter, loc, blockVolume,
      ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
                             launchOp.getBlockSizeZ()});
  TypedValue<nvgpu::MBarrierGroupType> barrier =
      buildAndInitBarrierInSharedMemory(numThreads);

  // One descriptor and one destination buffer per copy.
  SmallVector<TypedValue<MemRefType>> shmems;
  SmallVector<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescs;
  for (Operation *op : copyOps) {
    auto copyOp = cast<linalg::CopyOp>(op);
    auto source =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
    assert(source.getType().getRank() == 2 &&
           "expected in to be a 2D memref");
    globalDescs.push_back(buildGlobalMemRefDescriptor(source, launchOp));

    auto destination =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
    shmems.push_back(destination);
  }

  // Loads on thread 0, then wait for them.
  OpBuilder::InsertionGuard g2(rewriter);
  rewriter.setInsertionPoint(firstCopy);
  SmallVector<Operation *> loadOps =
      buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);
  buildTryWaitParity(barrier);

  // The copies are now dead.
  for (Operation *op : copyOps)
    rewriter.eraseOp(op);
  return loadOps;
}
/// Rewrites all payload ops of the target handle as TMA loads.
///
/// Every payload op must be a linalg.copy and all of them must be nested under
/// the same gpu.launch; otherwise a silenceable error naming the first op and
/// the first offending op is emitted and nothing is rewritten.
DiagnosedSilenceableFailure
transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter,
                                     transform::TransformResults &results,
                                     transform::TransformState &state) {
  auto payloadOps = state.getPayloadOps(getTarget());

  gpu::LaunchOp commonLaunchOp;
  Operation *firstOp = nullptr;
  Operation *failingOp = nullptr;
  for (Operation *op : payloadOps) {
    auto enclosingLaunch = op->getParentOfType<gpu::LaunchOp>();
    if (!commonLaunchOp) {
      commonLaunchOp = enclosingLaunch;
      firstOp = op;
    }
    const bool invalid = !enclosingLaunch ||
                         commonLaunchOp != enclosingLaunch ||
                         !isa<linalg::CopyOp>(op);
    if (invalid) {
      // Stop at the first offending op, it is the one reported.
      failingOp = op;
      break;
    }
  }

  if (failingOp) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "target ops must be linalg::CopyOp nested under a common "
           "gpu.LaunchOp to be rewritten because the tma descriptors need to "
           "be created on the host.\nBut got: "
        << *firstOp << "\nand " << *failingOp;
    return diag;
  }

  CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps));
  return DiagnosedSilenceableFailure::success();
}
namespace {
/// Transform dialect extension for NVGPU. Its constructor declares the
/// generated dialects listed below and registers the transform ops enumerated
/// by the generated NVGPUTransformOps.cpp.inc op list.
class NVGPUTransformDialectExtension
: public transform::TransformDialectExtension<
NVGPUTransformDialectExtension> {
public:
NVGPUTransformDialectExtension() {
// NOTE(review): presumably these are the dialects whose ops the transforms
// in this file may create; confirm against the TransformDialectExtension
// documentation.
declareGeneratedDialect<arith::ArithDialect>();
declareGeneratedDialect<affine::AffineDialect>();
declareGeneratedDialect<nvgpu::NVGPUDialect>();
declareGeneratedDialect<NVVM::NVVMDialect>();
declareGeneratedDialect<vector::VectorDialect>();
// GET_OP_LIST makes the generated include expand to the list of op classes.
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
>();
}
};
}
#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry ®istry) {
registry.addExtensions<NVGPUTransformDialectExtension>();
}