#include "NVGPUToLLVM/NVGPUToLLVMPass.h"
#include "NVGPUToLLVM/Passes.h"
#include "Dialect/NVGPU/IR/Dialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h"
#include "llvm/Support/ErrorHandling.h"
namespace ttn = mlir::triton::nvgpu;
using ttn::Constraints;
using ttn::OperandsAndConstraints;
namespace mlir {
namespace triton {
#define GEN_PASS_DEF_CONVERTNVGPUTOLLVM
#include "NVGPUToLLVM/Passes.h.inc"
namespace {
/// Returns true when `s` is non-empty and consists only of decimal digits.
/// Used to recognise numeric (operand-index) constraints such as "0".
bool isNumber(const std::string &s) {
  if (s.empty())
    return false;
  return std::all_of(s.begin(), s.end(), [](unsigned char ch) {
    return std::isdigit(ch) != 0;
  });
}
/// Maps a single-letter inline-asm register constraint to the MLIR type used
/// to pass an operand with that constraint:
///   'b' -> i1, 'h' -> i16, 'r' -> i32, 'l' -> i64, 'f' -> f32, 'd' -> f64.
///
/// Any other constraint is a programming error and aborts compilation. A
/// fatal error is used instead of `assert(false)` because the latter compiles
/// away in release builds and would hand a null Type back to callers, which
/// immediately query its bit width.
Type getTypeFromConstraint(char constraint, PatternRewriter &rewriter) {
  MLIRContext *ctx = rewriter.getContext();
  switch (constraint) {
  case 'b':
    return IntegerType::get(ctx, 1);
  case 'h':
    return IntegerType::get(ctx, 16);
  case 'r':
    return IntegerType::get(ctx, 32);
  case 'l':
    return IntegerType::get(ctx, 64);
  case 'f':
    return Float32Type::get(ctx);
  case 'd':
    return Float64Type::get(ctx);
  default:
    llvm::report_fatal_error(llvm::Twine("Unsupported constraint '") +
                             llvm::Twine(constraint) + "'");
  }
}
/// Adapts `val` to the register class named by an input `constraint`.
///
/// - Numeric constraints (operand ties such as "0") leave `val` unchanged.
/// - Pointer values are converted with `ptrtoint` to the constraint's type.
/// - Values narrower than the constraint's type are zero-extended; values of
///   equal width are returned unchanged. Wider values trip an assertion.
///
/// NOTE(review): `zext` is only valid on integers, so a narrower float (e.g.
/// an f16 passed with an "r" constraint) would produce invalid IR. Confirm
/// that callers never do this.
Value convertToType(Value val, const std::string &constraint, Location loc,
                    PatternRewriter &rewriter) {
  if (isNumber(constraint))
    return val;

  auto b = TritonLLVMOpBuilder(loc, rewriter);
  Type ty = getTypeFromConstraint(constraint[0], rewriter);
  if (isa<LLVM::LLVMPointerType>(val.getType()))
    return b.ptrtoint(ty, val);

  unsigned srcWidth = val.getType().getIntOrFloatBitWidth();
  unsigned dstWidth = ty.getIntOrFloatBitWidth();
  assert(srcWidth <= dstWidth && "Cannot convert to a smaller type");
  if (srcWidth < dstWidth)
    return b.zext(ty, val);
  return val;
}
/// Creates one PTXBuilder output operand for every entry of
/// `outputConstraints`, keeping the entries' order.
SmallVector<PTXBuilder::Operand *>
getPtxOutputs(const nvgpu::Constraints &outputConstraints,
              PTXBuilder &ptxBuilder) {
  SmallVector<PTXBuilder::Operand *> outputs;
  for (const auto &constraint : outputConstraints)
    outputs.push_back(ptxBuilder.newOperand(constraint));
  return outputs;
}
/// Flattens LLVM-struct operands into their individual fields.
///
/// Each struct operand is split into one `extractvalue` per field, in field
/// order. When the operand's constraint is numeric (a tie to operand N),
/// field i is tied to operand N + i. Otherwise every field reuses the
/// operand's constraint. Non-struct operands are forwarded unchanged.
///
/// `ptxBuilder` is not used by this function; it is kept so the signature
/// stays stable.
OperandsAndConstraints
unpackOperands(const OperandsAndConstraints &operandsAndConstraints,
               PTXBuilder &ptxBuilder, Location loc,
               PatternRewriter &rewriter) {
  auto b = TritonLLVMOpBuilder(loc, rewriter);
  OperandsAndConstraints unpackedOperands;
  for (const auto &[operand, constraint] : operandsAndConstraints) {
    auto llvmStruct = llvm::dyn_cast<LLVM::LLVMStructType>(operand.getType());
    if (!llvmStruct) {
      unpackedOperands.push_back({operand, constraint});
      continue;
    }
    // Parse a numeric (tied) constraint once per operand, not once per field.
    const bool isTied = isNumber(constraint);
    const int firstTiedIdx = isTied ? std::stoi(constraint) : 0;
    auto fieldTypes = llvmStruct.getBody();
    for (unsigned i = 0; i < fieldTypes.size(); i++) {
      Value field = b.extract_val(fieldTypes[i], operand, i);
      std::string fieldConstraint =
          isTied ? std::to_string(firstTiedIdx + i) : constraint;
      unpackedOperands.push_back({field, std::move(fieldConstraint)});
    }
  }
  return unpackedOperands;
}
/// Turns MLIR input operands into PTXBuilder operands.
///
/// Struct operands are flattened first (see unpackOperands). Each resulting
/// value is then converted to the type its constraint expects and registered
/// with `ptxBuilder`, in order.
SmallVector<PTXBuilder::Operand *>
getPtxOperands(const OperandsAndConstraints &operandsAndConstraints,
               PTXBuilder &ptxBuilder, Location loc,
               PatternRewriter &rewriter) {
  OperandsAndConstraints flattened =
      unpackOperands(operandsAndConstraints, ptxBuilder, loc, rewriter);
  SmallVector<PTXBuilder::Operand *> result;
  for (auto &[value, constraint] : flattened) {
    Value converted = convertToType(value, constraint, loc, rewriter);
    result.push_back(ptxBuilder.newOperand(converted, constraint));
  }
  return result;
}
/// Replaces every `#name` placeholder in `ptxAsm` with the decimal value of
/// the integer attribute `name` on `op` (e.g. "#pendings").
///
/// A placeholder name is the maximal run of alphanumeric characters that
/// follows '#'. Placeholders may also appear at the very end of the string.
/// Each placeholder must name an existing IntegerAttr on `op`.
std::string patchPtxAsm(Operation *op, std::string ptxAsm) {
  size_t hashPos = ptxAsm.find('#');
  // Fast path: nothing to substitute.
  if (hashPos == std::string::npos)
    return ptxAsm;

  auto isNameChar = [](unsigned char c) { return std::isalnum(c) != 0; };
  std::string res;
  res.reserve(ptxAsm.size());
  size_t copyFrom = 0;
  while (hashPos != std::string::npos) {
    // Copy the literal text preceding the placeholder verbatim.
    res.append(ptxAsm, copyFrom, hashPos - copyFrom);

    size_t nameEnd = hashPos + 1;
    while (nameEnd < ptxAsm.size() && isNameChar(ptxAsm[nameEnd]))
      ++nameEnd;
    StringRef attrName(ptxAsm.data() + hashPos + 1, nameEnd - hashPos - 1);
    assert(!attrName.empty() && "expected an attribute name after '#'");

    auto integerAttr = op->getAttrOfType<IntegerAttr>(attrName);
    assert(integerAttr &&
           "ptx asm placeholder does not name an integer attribute of op");
    res += std::to_string(integerAttr.getInt());

    copyFrom = nameEnd;
    hashPos = ptxAsm.find('#', nameEnd);
  }
  // Trailing literal text after the last placeholder (may be empty).
  res.append(ptxAsm, copyFrom, std::string::npos);
  return res;
}
/// Generic lowering for an NVGPU op onto a fixed PTX inline-asm template.
///
/// Operand i of the op is bound to `inputConstraints[i]`, and the results use
/// `outputConstraints`. The actual rewrite is delegated to rewriteAsPtxAsm.
template <typename SourceOp>
class NVGPUOpGenericPattern : public OpRewritePattern<SourceOp> {
public:
  explicit NVGPUOpGenericPattern(MLIRContext *context, std::string ptxAsm,
                                 Constraints outputConstraints,
                                 Constraints inputConstraints)
      : OpRewritePattern<SourceOp>(context), ptxAsm(std::move(ptxAsm)),
        // Parameters are taken by value (sink arguments); move them into the
        // members instead of copying a second time.
        outputConstraints(std::move(outputConstraints)),
        inputConstraints(std::move(inputConstraints)) {}

  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter &rewriter) const override {
    OperandsAndConstraints operandsAndConstraints;
    for (unsigned i = 0; i < inputConstraints.size(); i++)
      operandsAndConstraints.push_back(
          {op->getOperand(i), inputConstraints[i]});
    return rewriteAsPtxAsm(op, rewriter, ptxAsm, operandsAndConstraints,
                           outputConstraints);
  }

private:
  std::string ptxAsm;
  Constraints outputConstraints;
  Constraints inputConstraints;
};
/// Lowers ttn::WarpIdOp to the warp index of the executing thread.
///
/// With a single warp the result is the constant 0. Otherwise the index is
/// (tid.x - start) / 32. `start` is subtracted only when
/// getWarpGroupStartThreadId reports one for the current block. The quotient
/// is passed through shuffleIdx with lane 0, presumably so the value is
/// warp-uniform.
class WarpIdOpPattern : public OpRewritePattern<ttn::WarpIdOp> {
public:
  using OpRewritePattern<ttn::WarpIdOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ttn::WarpIdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    TritonLLVMOpBuilder b(loc, rewriter);

    // Single-warp programs: the warp id is always 0.
    if (triton::gpu::lookupNumWarps(op) == 1) {
      rewriter.replaceOp(op, b.i32_val(0));
      return success();
    }

    Value threadId = rewriter.create<NVVM::ThreadIdXOp>(loc, i32_ty);
    std::optional<int> groupStart =
        getWarpGroupStartThreadId(rewriter.getInsertionBlock());
    if (groupStart) {
      Value startVal = b.i32_val(*groupStart);
      threadId = rewriter.create<LLVM::SubOp>(loc, threadId, startVal);
    }
    Value warpSize = b.i32_val(32);
    Value warpIdx = b.udiv(threadId, warpSize);
    rewriter.replaceOp(op, LLVM::NVIDIA::shuffleIdx(loc, rewriter, warpIdx, 0));
    return success();
  }
};
/// Lowers ttn::ClusterCTAIdOp to the linear index of this CTA inside its
/// cluster:
///   id = x + (y + z * dimY) * dimX
/// Here (x, y, z) is the CTA's position in the cluster and dimX/dimY are the
/// cluster's extents in blocks. The Z extent is not needed.
class ClusterCTAIdOpPattern : public OpRewritePattern<ttn::ClusterCTAIdOp> {
public:
  using OpRewritePattern<ttn::ClusterCTAIdOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ttn::ClusterCTAIdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Position of this CTA in the cluster.
    Value idX = rewriter.create<NVVM::BlockInClusterIdXOp>(loc, i32_ty);
    Value idY = rewriter.create<NVVM::BlockInClusterIdYOp>(loc, i32_ty);
    Value idZ = rewriter.create<NVVM::BlockInClusterIdZOp>(loc, i32_ty);
    // Cluster extents along X and Y.
    Value dimX = rewriter.create<NVVM::ClusterDimBlocksXOp>(loc, i32_ty);
    Value dimY = rewriter.create<NVVM::ClusterDimBlocksYOp>(loc, i32_ty);

    Value zTimesDimY = rewriter.create<LLVM::MulOp>(loc, idZ, dimY);
    Value yz = rewriter.create<LLVM::AddOp>(loc, idY, zTimesDimY);
    Value yzTimesDimX = rewriter.create<LLVM::MulOp>(loc, yz, dimX);
    Value linearId = rewriter.create<LLVM::AddOp>(loc, idX, yzTimesDimX);
    rewriter.replaceOp(op, linearId);
    return success();
  }
};
/// CRTP base that lowers a matrix op to one PTX instruction of the form
///   <kOpCode>.sync.aligned<.shape>.x<N><.trans>.shared.b<W> <operands>;
///
/// The opcode comes from ConcreteMatrixOpPattern::kOpCode. Subclasses supply
/// the operand text, input/output constraints and vector size through the
/// pure virtual hooks below.
template <typename MatrixOpType, typename ConcreteMatrixOpPattern>
class MatrixOpPattern : public OpRewritePattern<MatrixOpType> {
public:
  using OpRewritePattern<MatrixOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(MatrixOpType op,
                                PatternRewriter &rewriter) const override {
    const unsigned vecSize = getVectorSize(op);
    const std::string modifiers = getPtxModifiers(
        vecSize, op.getTrans(), op.getShape(), op.getBitWidth());
    const std::string operands = getOperands(op, vecSize);

    std::string ptxAsm(ConcreteMatrixOpPattern::kOpCode);
    ptxAsm += modifiers;
    ptxAsm += " ";
    ptxAsm += operands;
    ptxAsm += ";";

    OperandsAndConstraints inputs = getOperandsAndConstraints(op, vecSize);
    Constraints outputs = getOutputConstraints(op, vecSize);
    return rewriteAsPtxAsm(op, rewriter, ptxAsm, inputs, outputs);
  }

protected:
  /// Builds ".sync.aligned<.shape>.x<vecSize>[.trans].shared.b<bitWidth>".
  /// Unsupported shapes or vector sizes (anything but 1, 2, 4) are
  /// unreachable.
  std::string getPtxModifiers(unsigned vecSize, bool trans,
                              triton::nvgpu::LoadMatrixShape shape,
                              int bitWidth) const {
    const char *shapeStr = nullptr;
    switch (shape) {
    case triton::nvgpu::LoadMatrixShape::m8n8:
      shapeStr = ".m8n8";
      break;
    case triton::nvgpu::LoadMatrixShape::m16n16:
      shapeStr = ".m16n16";
      break;
    default:
      llvm_unreachable("Invalid load matrix shape");
    }

    const char *countStr = nullptr;
    switch (vecSize) {
    case 1:
      countStr = ".x1";
      break;
    case 2:
      countStr = ".x2";
      break;
    case 4:
      countStr = ".x4";
      break;
    default:
      llvm_unreachable("Invalid vector size");
    }

    std::string modifiers = ".sync.aligned";
    modifiers += shapeStr;
    modifiers += countStr;
    modifiers += trans ? ".trans.shared" : ".shared";
    modifiers += ".b" + std::to_string(bitWidth);
    return modifiers;
  }

  /// Returns "{$startIdx, $startIdx+1, ..., $startIdx+count-1}".
  std::string getPtxRegOperands(unsigned startIdx, unsigned count) const {
    std::string regs = "{";
    for (unsigned k = 0; k < count; ++k) {
      if (k != 0)
        regs += ", ";
      regs += "$" + std::to_string(startIdx + k);
    }
    regs += "}";
    return regs;
  }

  /// Returns "[$idx]", the memory-operand form of asm operand `idx`.
  std::string getPtxAddrOperand(unsigned idx) const {
    return "[$" + std::to_string(idx) + "]";
  }

  virtual std::string getOperands(MatrixOpType op, unsigned vecSize) const = 0;
  virtual OperandsAndConstraints
  getOperandsAndConstraints(MatrixOpType op, unsigned vecSize) const = 0;
  virtual Constraints getOutputConstraints(MatrixOpType op,
                                           unsigned vecSize) const = 0;
  virtual unsigned getVectorSize(MatrixOpType op) const = 0;
};
/// Lowers ttn::LoadMatrixOp to `ldmatrix`.
///
/// The result registers are asm operands $0..$(vecSize-1), each with "=r".
/// The address is operand $vecSize with an "r" constraint.
class LoadMatrixOpPattern
    : public MatrixOpPattern<ttn::LoadMatrixOp, LoadMatrixOpPattern> {
public:
  using MatrixOpPattern<ttn::LoadMatrixOp,
                        LoadMatrixOpPattern>::MatrixOpPattern;
  static constexpr const char *kOpCode = "ldmatrix";

protected:
  /// A struct result holds one register per field; any other result type is
  /// a single register.
  unsigned getVectorSize(ttn::LoadMatrixOp op) const override {
    auto structTy = dyn_cast<LLVM::LLVMStructType>(op.getType());
    if (!structTy)
      return 1;
    return structTy.getBody().size();
  }

  std::string getOperands(ttn::LoadMatrixOp op,
                          unsigned vecSize) const override {
    std::string operands = getPtxRegOperands(0, vecSize);
    operands += ", ";
    operands += getPtxAddrOperand(vecSize);
    return operands;
  }

  OperandsAndConstraints
  getOperandsAndConstraints(ttn::LoadMatrixOp op,
                            unsigned vecSize) const override {
    return {{op.getAddr(), "r"}};
  }

  Constraints getOutputConstraints(ttn::LoadMatrixOp op,
                                   unsigned vecSize) const override {
    return Constraints(vecSize, "=r");
  }
};
/// Lowers ttn::LoadAcquireOp to an `ld.global` of `width` bits.
///
/// Scope (cta/gpu/sys) and semantic (acquire/relaxed) modifiers are appended
/// only when they match the op. The load is predicated on the op's mask when
/// one is present. The loaded integer is bitcast back to the op's result
/// type.
class LoadAcquireOpPattern : public OpRewritePattern<ttn::LoadAcquireOp> {
public:
  using OpRewritePattern<ttn::LoadAcquireOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ttn::LoadAcquireOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto b = TritonLLVMOpBuilder(loc, rewriter);
    Type valueTy = op.getType();
    // Load width: the value's bit width, but never less than one byte.
    const unsigned width = std::max(8u, valueTy.getIntOrFloatBitWidth());
    // Register class of the destination operand.
    const std::string writeConstraint =
        (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");

    PTXBuilder ptxBuilder;
    // Passed as `init` to newOperand. Presumably this gives the destination a
    // defined value when the predicated load is skipped (TODO confirm).
    bool init = true;
    auto *dstOpr = ptxBuilder.newOperand(writeConstraint, init);
    auto *addrOpr = ptxBuilder.newAddrOperand(op.getAddr(), "l", 0);
    auto &ld =
        ptxBuilder.create<>("ld")
            ->global()
            .o("cta", op.getScope() == triton::nvgpu::MemSyncScope::CTA)
            .o("gpu", op.getScope() == triton::nvgpu::MemSyncScope::GPU)
            .o("sys", op.getScope() == triton::nvgpu::MemSyncScope::SYSTEM)
            .o("acquire", op.getSem() == triton::nvgpu::MemSemantic::ACQUIRE)
            .o("relaxed", op.getSem() == triton::nvgpu::MemSemantic::RELAXED)
            .b(width);
    ld(dstOpr, addrOpr).maybePredicate(op.getMask(), "b");

    Type retTy = IntegerType::get(getContext(), width);
    Value ret = ptxBuilder.launch(rewriter, loc, retTy);
    // NOTE(review): bitcast requires op.getType() to be exactly `width` bits.
    // Result types narrower than 8 bits (e.g. i1) would yield invalid IR.
    ret = b.bitcast(ret, op.getType());
    rewriter.replaceOp(op, {ret});
    return success();
  }
};
/// Lowers ttn::WGMMAWaitGroupOp to `wgmma.wait_group.sync.aligned`.
///
/// The op's input is tied (constraint "0", expanded per struct field) to the
/// struct result, so the accumulator registers flow through the asm
/// statement. The registers appear in the asm text only inside a PTX
/// comment. The `#pendings` placeholder is replaced by rewriteAsPtxAsm with
/// the op's integer attribute of that name.
class WGMMAWaitGroupOpPattern : public OpRewritePattern<ttn::WGMMAWaitGroupOp> {
public:
  using OpRewritePattern<ttn::WGMMAWaitGroupOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ttn::WGMMAWaitGroupOp op,
                                PatternRewriter &rewriter) const override {
    return rewriteAsPtxAsm(op, rewriter, getPtxAsm(op),
                           getOperandsAndConstraints(op),
                           getOutputConstraints(op));
  }

  /// One output constraint per field of the struct result: "=f" when the
  /// first field is f32, "=r" otherwise.
  Constraints getOutputConstraints(ttn::WGMMAWaitGroupOp op) const {
    auto outputStructType = cast<LLVM::LLVMStructType>(op.getType());
    auto body = outputStructType.getBody();
    assert(!body.empty() && "expected a non-empty result struct");
    uint32_t numOutputRegs = body.size();
    std::string output = body.front().isF32() ? "=f" : "=r";
    return Constraints(numOutputRegs, output);
  }

  /// The single input is tied to output 0 (expanded per field later).
  OperandsAndConstraints
  getOperandsAndConstraints(ttn::WGMMAWaitGroupOp op) const {
    OperandsAndConstraints operandsAndConstraints;
    operandsAndConstraints.push_back({op.getInput(), "0"});
    return operandsAndConstraints;
  }

  /// Builds "// wait for regs: $0,$1,...\n\twgmma.wait_group.sync.aligned
  /// #pendings;". `cast` is used (as in getOutputConstraints) so a non-struct
  /// result asserts rather than dereferencing a null type.
  std::string getPtxAsm(ttn::WGMMAWaitGroupOp op) const {
    auto outputStructType = cast<LLVM::LLVMStructType>(op.getType());
    uint32_t numCRegs = outputStructType.getBody().size();
    std::string args;
    for (uint32_t i = 0; i < numCRegs; ++i) {
      if (i != 0)
        args += ",";
      args += "$" + std::to_string(i);
    }
    return "// wait for regs: " + args + "\n\t" +
           "wgmma.wait_group.sync.aligned #pendings;";
  }
};
// Lowers ttn::WGMMAOp to a single `wgmma.mma_async.sync.aligned` PTX
// instruction emitted through rewriteAsPtxAsm.
// Asm operand numbering is:
//   1. the result registers;
//   2. when opC is present, the tied C registers;
//   3. A (a register list, or a single "l" operand);
//   4. B;
//   5. when opC is present, the scale-d predicate.
class WGMMAOpPattern : public OpRewritePattern<ttn::WGMMAOp> {
public:
using OpRewritePattern<ttn::WGMMAOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ttn::WGMMAOp op,
PatternRewriter &rewriter) const override {
return rewriteAsPtxAsm(op, rewriter, getPtxAsm(op),
getOperandsAndConstraints(op),
getOutputConstraints(op));
}
// One output constraint per field of the struct result: "=f" when the first
// field is f32, "=r" otherwise.
// NOTE(review): the dyn_cast result is not null-checked, and .front() assumes
// a non-empty struct body.
std::vector<std::string> getOutputConstraints(ttn::WGMMAOp op) const {
auto resultType = op.getType();
auto outputStructType = dyn_cast<LLVM::LLVMStructType>(resultType);
uint32_t numOutputRegs = outputStructType.getBody().size();
std::string output =
outputStructType.getBody().front().isF32() ? "=f" : "=r";
return std::vector<std::string>(numOutputRegs, output);
}
// Input operands in asm order:
//   opC  -> "0": tied to output 0. unpackOperands expands a struct into
//           "0", "1", ... so each C field is tied to the matching result.
//   opA  -> "r" (per field) when A is a struct of registers, else "l"
//           (presumably a 64-bit matrix descriptor; TODO confirm).
//   opB  -> "l" (B is never in registers; see the assert in getPtxAsm).
//   useC -> "b" predicate, only added when opC is present.
OperandsAndConstraints getOperandsAndConstraints(ttn::WGMMAOp op) const {
OperandsAndConstraints operandsAndConstraints;
auto opA = op.getOpA();
auto opB = op.getOpB();
auto opC = op.getOpC();
auto opScaleD = op.getUseC();
auto typeA = opA.getType();
auto structTypeA = dyn_cast<LLVM::LLVMStructType>(typeA);
if (opC)
operandsAndConstraints.push_back({opC, "0"});
if (structTypeA) {
operandsAndConstraints.push_back({opA, "r"});
} else {
operandsAndConstraints.push_back({opA, "l"});
}
operandsAndConstraints.push_back({opB, "l"});
if (op.getOpC())
operandsAndConstraints.push_back({opScaleD, "b"});
return operandsAndConstraints;
}
// Builds
//   "wgmma.mma_async.sync.aligned.m<M>n<N>k<K>.<C>.<A>.<B> <args>;"
// after checking (assert only) that the type/shape combination is listed
// as supported below.
std::string getPtxAsm(ttn::WGMMAOp op) const {
using namespace ttn;
auto opA = op.getOpA();
auto opB = op.getOpB();
auto m = op.getM();
auto n = op.getN();
auto k = op.getK();
auto eltTypeC = op.getEltTypeC();
auto eltTypeA = op.getEltTypeA();
auto eltTypeB = op.getEltTypeB();
auto layoutA = op.getLayoutA();
auto layoutB = op.getLayoutB();
auto typeA = opA.getType();
auto typeB = opB.getType();
auto typeOutput = op.getType();
auto structTypeA = dyn_cast<LLVM::LLVMStructType>(typeA);
auto structTypeB = dyn_cast<LLVM::LLVMStructType>(typeB);
auto structTypeOutput = dyn_cast<LLVM::LLVMStructType>(typeOutput);
assert(!structTypeB && "Operand B can not be registers");
assert(structTypeOutput && "Output and C operand must be registers");
// A column-major A or a row-major B is requested via the transpose
// immediates.
bool transA = layoutA == WGMMALayout::col;
bool transB = layoutB == WGMMALayout::row;
// needTransArgs:  emit the transpose immediates.
// floatTypeWGMMA: emit the two "1" immediates (presumably imm-scale-a/b).
bool supported = false, needTransArgs = false, floatTypeWGMMA = false;
assert(m % 8 == 0 && n % 8 == 0 && k % 8 == 0);
// f16/bf16 inputs: the only combinations here that accept transposed
// layouts, so they alone get the transpose immediates.
supported |=
(eltTypeA == WGMMAEltType::f16) && (eltTypeB == WGMMAEltType::f16) &&
(eltTypeC == WGMMAEltType::f16 || eltTypeC == WGMMAEltType::f32) &&
(m == 64 && 8 <= n && n <= 256 && k == 16);
supported |= (eltTypeA == WGMMAEltType::bf16) &&
(eltTypeB == WGMMAEltType::bf16) &&
(eltTypeC == WGMMAEltType::f32) &&
(m == 64 && 8 <= n && n <= 256 && k == 16);
needTransArgs = supported;
floatTypeWGMMA = supported;
// The remaining element types are only accepted with both transpose
// flags false.
if (!supported && !transA && !transB) {
supported |= (eltTypeA == WGMMAEltType::tf32) &&
(eltTypeB == WGMMAEltType::tf32) &&
(eltTypeC == WGMMAEltType::f32) &&
(m == 64 && 8 <= n && n <= 256 && k == 8);
supported |=
(eltTypeA == WGMMAEltType::e4m3 || eltTypeA == WGMMAEltType::e5m2) &&
(eltTypeB == WGMMAEltType::e4m3 || eltTypeB == WGMMAEltType::e5m2) &&
(eltTypeC == WGMMAEltType::f16 || eltTypeC == WGMMAEltType::f32) &&
(m == 64 && 8 <= n && n <= 256 && k == 32);
// Captured before the s8 check: the s8 -> s32 form does not get the
// floatTypeWGMMA immediates.
floatTypeWGMMA = supported;
supported |= (eltTypeA == WGMMAEltType::s8) &&
(eltTypeB == WGMMAEltType::s8) &&
(eltTypeC == WGMMAEltType::s32) &&
(m == 64 && 8 <= n && n <= 224 && k == 32);
}
assert(supported && "WGMMA type or shape is not supported");
uint32_t asmOpIdx = 0;
std::string args = "";
// Result registers: {$0, ..., $(numCRegs-1)}.
uint32_t numCRegs = structTypeOutput.getBody().size();
args += "{";
for (uint32_t i = 0; i < numCRegs; ++i) {
args += "$" + std::to_string(asmOpIdx++) + (i == numCRegs - 1 ? "" : ",");
}
args += "}, ";
// The tied C inputs occupy the next numCRegs operand indices. They are not
// spelled in the asm text (the tie makes them alias the results), so skip
// them.
if (op.getOpC())
asmOpIdx += numCRegs;
// Operand A: a brace list when it comes from registers, else one operand.
if (structTypeA) {
uint32_t numARegs = structTypeA.getBody().size();
args += "{";
for (uint32_t i = 0; i < numARegs; ++i) {
args +=
"$" + std::to_string(asmOpIdx++) + (i == numARegs - 1 ? "" : ",");
}
args += "}, ";
} else {
args += "$" + std::to_string(asmOpIdx++) + ", ";
}
// Operand B.
args += "$" + std::to_string(asmOpIdx++) + ", ";
// scale-d: the useC predicate, or the literal 0 when there is no C operand
// (presumably meaning D = A*B without accumulation).
if (op.getOpC())
args += "$" + std::to_string(asmOpIdx++);
else
args += "0";
if (floatTypeWGMMA)
args += ", 1, 1";
// Transpose immediates. The A flag is only emitted when A is not in
// registers.
if (needTransArgs) {
if (!structTypeA)
args += ", " + std::to_string(transA);
args += ", " + std::to_string(transB);
}
auto ptxAsm = "wgmma.mma_async.sync.aligned"
".m" +
std::to_string(m) + "n" + std::to_string(n) + "k" +
std::to_string(k) + "." + stringifyEnum(eltTypeC).str() +
"." + stringifyEnum(eltTypeA).str() + "." +
stringifyEnum(eltTypeB).str() + " " + args + ";";
return ptxAsm;
}
};
/// Emits a tensor-memory allocation at the rewriter's insertion point and
/// returns its base address as a pointer in address space 6.
///
/// Threads for which `pred` holds execute `tcgen05.alloc`, which stores the
/// allocated address into the shared-memory word returned by
/// getStackPointer. All threads then read that word back between two CTA
/// barriers.
///
/// `size` is passed verbatim to tcgen05.alloc (presumably a column count;
/// TODO confirm). `twoCTAs` selects cta_group::2 instead of cta_group::1.
static Value createTMAlloc(IRRewriter &rewriter, LLVM::LLVMFuncOp func,
                           size_t size, Value pred, bool twoCTAs) {
  PTXBuilder ptxBuilder;
  Location loc = func.getLoc();
  MLIRContext *ctx = func->getContext();
  auto b = TritonLLVMOpBuilder(loc, rewriter);
  Value sharedMem = mlir::LLVM::getStackPointer(rewriter, func);
  std::string ptxString =
      "@$0 tcgen05.alloc.cta_group::" + std::to_string(twoCTAs ? 2 : 1) +
      ".sync.aligned.shared::cta.b32 [$1], " + std::to_string(size) + ";";
  auto &allocOp = *ptxBuilder.create<>(ptxString);
  allocOp(
      {ptxBuilder.newOperand(pred, "b"), ptxBuilder.newOperand(sharedMem, "r")},
      true);
  ptxBuilder.launch(rewriter, loc, void_ty(ctx));
  // Make the address written by the allocating threads visible to everyone.
  rewriter.create<NVVM::Barrier0Op>(loc);
  Value address = b.load(i32_ty, sharedMem);
  // Second barrier: presumably ensures every thread has read the slot before
  // that shared memory is reused.
  rewriter.create<NVVM::Barrier0Op>(loc);
  return b.inttoptr(ptr_ty(ctx, 6), address);
}
/// Emits `tcgen05.relinquish_alloc_permit`, predicated on `pred`, for
/// cta_group::2 when `twoCTAs` is set and cta_group::1 otherwise.
static void createRelinquishAlloc(IRRewriter &rewriter, Location loc,
                                  Value pred, bool twoCTAs) {
  const std::string ctaGroup = std::to_string(twoCTAs ? 2 : 1);
  std::string ptxString =
      "@$0 tcgen05.relinquish_alloc_permit.cta_group::" + ctaGroup +
      ".sync.aligned;";
  PTXBuilder ptxBuilder;
  auto &relinquish = *ptxBuilder.create<>(ptxString);
  relinquish({ptxBuilder.newOperand(pred, "b")}, true);
  ptxBuilder.launch(rewriter, loc, void_ty(rewriter.getContext()));
}
/// Inserts a tensor-memory deallocation before every `llvm.return` in `func`.
///
/// Each return is preceded by a CTA barrier and then `tcgen05.dealloc` of
/// `size` units at `alloc`. The dealloc executes only where `pred` holds.
/// `twoCTAs` selects cta_group::2 instead of cta_group::1.
void freeTMAlloc(LLVM::LLVMFuncOp func, Value alloc, size_t size, Value pred,
                 bool twoCTAs) {
  // The instruction text is identical for every return; build it once.
  std::string ptxString =
      "@$0 tcgen05.dealloc.cta_group::" + std::to_string(twoCTAs ? 2 : 1) +
      ".sync.aligned.b32 $1, " + std::to_string(size) + ";";
  func.walk([&](LLVM::ReturnOp ret) {
    OpBuilder b(ret);
    auto loc = ret.getLoc();
    // Barrier first: presumably so no thread is still using tensor memory
    // when it is released.
    b.create<NVVM::Barrier0Op>(loc);
    PTXBuilder ptxBuilder;
    auto &dealloc = *ptxBuilder.create<>(ptxString);
    dealloc(
        {ptxBuilder.newOperand(pred, "b"), ptxBuilder.newOperand(alloc, "r")},
        true);
    ptxBuilder.launch(b, loc, void_ty(ret->getContext()));
  });
}
/// Sets up tensor memory at the entry of `func`, sized by the enclosing
/// module's `ttg.tensor_memory_size` attribute, and returns its base address.
///
/// - Size 0: returns a null Value.
/// - Size above 512: emits a trap and returns an undef pointer.
/// - Otherwise: threads with tid.x < 32 allocate, then relinquish the
///   allocation permit, and a matching dealloc is inserted before every
///   return.
static Value initTensorMemory(LLVM::LLVMFuncOp func) {
  auto mod = func->getParentOfType<ModuleOp>();
  assert(mod->hasAttr("ttg.tensor_memory_size"));
  const size_t size =
      cast<IntegerAttr>(mod->getAttr("ttg.tensor_memory_size"))
          .getValue()
          .getZExtValue();
  if (size == 0)
    return Value();

  IRRewriter rewriter(func.getContext());
  rewriter.setInsertionPointToStart(&func.front());
  Location loc = func.getLoc();
  auto b = TritonLLVMOpBuilder(loc, rewriter);

  if (size > 512) {
    rewriter.create<LLVM::Trap>(loc);
    return rewriter.create<LLVM::UndefOp>(loc, ptr_ty(mod.getContext(), 6));
  }

  const bool useTwoCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod) == 2;
  Value threadId = rewriter.create<NVVM::ThreadIdXOp>(loc, i32_ty);
  Value isFirstWarp = b.icmp_ult(threadId, b.i32_val(32));
  Value alloc = createTMAlloc(rewriter, func, size, isFirstWarp, useTwoCTAs);
  createRelinquishAlloc(rewriter, loc, isFirstWarp, useTwoCTAs);
  freeTMAlloc(func, alloc, size, isFirstWarp, useTwoCTAs);
  return alloc;
}
/// Replaces every ttn::TensorMemoryBaseAddress op in `mod` with the base
/// address produced by initTensorMemory.
///
/// All such ops must live in the same function, which must be a kernel. If
/// initTensorMemory yields no value (size 0), the ops are left untouched.
static void lowerTensorMemoryAlloc(ModuleOp mod) {
  SmallVector<Operation *> tmemBaseOps;
  LLVM::LLVMFuncOp owningKernel = nullptr;
  mod.walk([&](ttn::TensorMemoryBaseAddress baseOp) {
    auto parentFunc = baseOp->getParentOfType<LLVM::LLVMFuncOp>();
    if (!owningKernel)
      owningKernel = parentFunc;
    assert(owningKernel == parentFunc &&
           "TODO: add support for function calls using tmem.");
    tmemBaseOps.push_back(baseOp);
  });
  if (tmemBaseOps.empty())
    return;

  assert(triton::isKernel(owningKernel));
  Value tmemBase = initTensorMemory(owningKernel);
  if (!tmemBase)
    return;

  for (Operation *baseOp : tmemBaseOps) {
    baseOp->getResult(0).replaceAllUsesWith(tmemBase);
    baseOp->erase();
  }
}
}
/// Pass that applies the NVGPU lowering patterns greedily. It then lowers
/// tensor-memory base-address ops and calls
/// makeAllWarpGroupsIsolatedFromAbove on the module.
class ConvertNVGPUToLLVM
    : public impl::ConvertNVGPUToLLVMBase<ConvertNVGPUToLLVM> {
public:
  using impl::ConvertNVGPUToLLVMBase<
      ConvertNVGPUToLLVM>::ConvertNVGPUToLLVMBase;

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp mod = getOperation();
    RewritePatternSet patterns(context);
    patterns
        .add<ClusterCTAIdOpPattern, LoadMatrixOpPattern, WGMMAOpPattern,
             LoadAcquireOpPattern, WGMMAWaitGroupOpPattern, WarpIdOpPattern>(
            context);
    if (failed(applyPatternsGreedily(mod, std::move(patterns)))) {
      // The pass has failed; stop instead of running further transformations
      // on IR whose result will be discarded.
      signalPassFailure();
      return;
    }
    lowerTensorMemoryAlloc(mod);
    makeAllWarpGroupsIsolatedFromAbove(mod);
  }
};
/// Replaces `op` with one PTX inline-asm statement built from `ptxAsm`.
///
/// - `#name` placeholders in `ptxAsm` are substituted with integer attributes
///   of `op`.
/// - Output operands are registered before input operands.
/// - The asm is marked as having side effects unless `op` is memory-effect
///   free.
/// - An `op` without results is erased. Otherwise `op` is replaced by the
///   asm result, whose type is the type of `op`'s first result.
LogicalResult
nvgpu::rewriteAsPtxAsm(Operation *op, PatternRewriter &rewriter,
                       std::string ptxAsm,
                       const OperandsAndConstraints &operandsAndConstraints,
                       const Constraints &outputConstraints) {
  Location loc = op->getLoc();
  std::string patchedAsm = patchPtxAsm(op, std::move(ptxAsm));
  const bool hasSideEffects = !isMemoryEffectFree(op);

  PTXBuilder ptxBuilder;
  // Outputs first, then inputs: this fixes the $N numbering used by the asm.
  SmallVector<PTXBuilder::Operand *> allOperands =
      getPtxOutputs(outputConstraints, ptxBuilder);
  SmallVector<PTXBuilder::Operand *> inputs =
      getPtxOperands(operandsAndConstraints, ptxBuilder, loc, rewriter);
  allOperands.append(inputs.begin(), inputs.end());

  auto &ptxInstr = *ptxBuilder.create<PTXInstr>(patchedAsm);
  ptxInstr(allOperands, true);

  const bool hasResults = op->getNumResults() != 0;
  Type retTy = hasResults ? op->getResult(0).getType()
                          : Type(void_ty(rewriter.getContext()));
  auto res = ptxBuilder.launch(rewriter, loc, retTy, hasSideEffects);
  if (hasResults)
    rewriter.replaceOp(op, res);
  else
    rewriter.eraseOp(op);
  return success();
}
}
}