#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h"
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonInstrument/IR/Dialect.h"
#include "triton/Dialect/TritonInstrument/IR/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
namespace {
// Short aliases for the Triton dialect namespaces referenced in this file.
namespace tt = mlir::triton;
namespace ttg = mlir::triton::gpu;
namespace tti = mlir::triton::instrument;
namespace ttng = mlir::triton::nvidia_gpu;
// Splats `scalar` into a tensor of type `tensorTy` (tt.splat).
// The scalar's type must equal the tensor's element type; this is checked by
// an assert only.
Value createFullLike(OpBuilder &builder, Location loc, Value scalar,
                     RankedTensorType tensorTy) {
  assert(scalar.getType() == tensorTy.getElementType() &&
         "Expected scalar to be of the same type as the tensor elements");
  return builder.create<triton::SplatOp>(loc, tensorTy, scalar);
}
// Compares each element of the ranked integer tensor `tensor` against `scalar`
// using `predicate` (default: equality). The scalar is splatted to the
// tensor's type first, and the resulting arith.cmpi value is returned.
Value createCmpIntTensorScalar(
    OpBuilder &builder, Location loc, Value tensor, Value scalar,
    arith::CmpIPredicate predicate = arith::CmpIPredicate::eq) {
  auto operandTy = cast<RankedTensorType>(tensor.getType());
  Value splatted = createFullLike(builder, loc, scalar, operandTy);
  return builder.create<arith::CmpIOp>(loc, predicate, tensor, splatted);
}
// Produces an i64 address for the memory described by `memDescTy`, given the
// already-converted LLVM value `sharedMemStruct`.
//  - Tensor-memory encodings: the converted value is ptrtoint'ed to i64.
//  - Shared-memory encodings: returns the shared-memory object's base pointer
//    (as i64) plus the descriptor's element offset scaled to bytes.
// NOTE(review): bytes-per-element is bitwidth / 8, which is 0 for sub-byte
// element types; confirm such types never reach this helper.
Value createMemDescToI64(RewriterBase &rewriter, Location loc,
                         const LLVMTypeConverter *typeConverter,
                         ttg::MemDescType memDescTy, Value sharedMemStruct) {
  TritonLLVMOpBuilder b(loc, rewriter);
  auto i64Ty = rewriter.getIntegerType(64);
  auto encoding = memDescTy.getEncoding();
  if (isa<ttng::TensorMemoryEncodingAttr>(encoding))
    return b.ptrtoint(i64Ty, sharedMemStruct);

  assert(isa<ttg::SharedEncodingTrait>(encoding) &&
         "Unsupported memory encoding");
  Type llvmElemTy = typeConverter->convertType(memDescTy.getElementType());
  auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, sharedMemStruct,
                                                       llvmElemTy, rewriter);
  Value elemOffset = smemObj.getShmemOffset(loc, rewriter, memDescTy);
  auto bytesPerElem = llvmElemTy.getIntOrFloatBitWidth() / 8;
  Value byteOffset = b.mul(elemOffset, b.i32_val(bytesPerElem));
  Value byteOffset64 = b.zext(i64Ty, byteOffset);
  return b.add(byteOffset64, b.ptrtoint(i64Ty, smemObj.getBase()));
}
// Element type of the barrier-tracking tensors built by the helpers below:
// a 64-bit integer.
Type getBarsElType(OpBuilder &b) {
  return b.getI64Type();
}
// Returns a 1-D tensor type with one getBarsElType() entry per buffer in
// `buffersType`, reusing the encoding of `buffersType`. The number of buffers
// must be a power of two (asserted).
// NOTE(review): the lowering patterns in this file treat the write-bars
// tensor as 2-D (buffers x barriers); confirm whether this helper is still
// used or up to date.
RankedTensorType getWriteBarsType(OpBuilder &b, RankedTensorType buffersType) {
  const int numBuffers = buffersType.getShape()[0];
  assert(llvm::isPowerOf2_64(numBuffers) && "Expected power of 2");
  Type elemTy = getBarsElType(b);
  return RankedTensorType::get({numBuffers}, elemTy,
                               buffersType.getEncoding());
}
// Returns a 1-D tensor type with one getBarsElType() entry per buffer in
// `buffersType`, reusing the encoding of `buffersType`. The number of buffers
// must be a power of two (asserted).
// NOTE(review): `barriersType` is not used, and the lowering patterns in this
// file treat the read-bars tensor as 2-D (buffers x barriers); confirm the
// intended shape.
RankedTensorType getReadBarsType(OpBuilder &b, RankedTensorType buffersType,
                                 RankedTensorType barriersType) {
  (void)barriersType; // currently unused, see note above
  const int numBuffers = buffersType.getShape()[0];
  assert(llvm::isPowerOf2_64(numBuffers) && "Expected power of 2");
  Type elemTy = getBarsElType(b);
  return RankedTensorType::get({numBuffers}, elemTy,
                               buffersType.getEncoding());
}
// Broadcasts `tensor` to `shape` with `encoding`, keeping its element type.
// Steps:
//  1. Convert its layout to the slice of `encoding` along `dim`.
//  2. Call tti::expandOuterSlicedDim, presumably to re-insert `dim` as a
//     size-1 dimension.
//  3. Apply tt.broadcast to `shape`.
Value convertAndBroadcast(OpBuilder &b, Location loc, Value tensor, int dim,
                          ArrayRef<int64_t> shape,
                          ttg::BlockedEncodingAttr encoding) {
  auto srcTy = cast<RankedTensorType>(tensor.getType());
  auto sliceEnc = ttg::SliceEncodingAttr::get(b.getContext(), dim, encoding);
  Value sliced = b.create<ttg::ConvertLayoutOp>(
      loc, srcTy.cloneWithEncoding(sliceEnc), tensor);
  Value expanded = tti::expandOuterSlicedDim(b, loc, sliced);
  auto dstTy = RankedTensorType::get(shape, srcTy.getElementType(), encoding);
  return b.create<tt::BroadcastOp>(loc, dstTy, expanded);
}
// Splits the current block at the insertion point into three blocks:
//   prevBlock:  <ops before the insertion point>; cond_br cnd, ifBlock, thenBlock
//   ifBlock:    br thenBlock
//   thenBlock:  <ops that followed the insertion point>
// The insertion point is left at the start of thenBlock. Callers place the
// conditional code at the start of ifBlock.
// Returns {prevBlock, ifBlock, thenBlock}.
std::tuple<Block *, Block *, Block *>
createIfBlock(ConversionPatternRewriter &b, Location loc, Value cnd) {
  Block *headBlock = b.getInsertionBlock();
  Block *condBlock = b.splitBlock(headBlock, b.getInsertionPoint());
  Block *tailBlock = b.splitBlock(condBlock, condBlock->begin());

  b.setInsertionPointToEnd(condBlock);
  b.create<LLVM::BrOp>(loc, tailBlock);

  b.setInsertionPointToEnd(headBlock);
  b.create<LLVM::CondBrOp>(loc, cnd, condBlock, tailBlock);

  b.setInsertionPointToStart(tailBlock);
  return std::make_tuple(headBlock, condBlock, tailBlock);
}
// Builds a tt.reduce of integer `tensor` along `axis` and returns its result.
// The combiner keeps the larger of two elements under a signed (sgt)
// comparison. The builder's insertion point is restored on exit.
Value createMaxReduce(OpBuilder &b, Location loc, Value tensor, int axis) {
  OpBuilder::InsertionGuard guard(b);
  auto tensorType = cast<RankedTensorType>(tensor.getType());
  Type elemTy = tensorType.getElementType();
  auto reduceOp = b.create<tt::ReduceOp>(loc, std::vector<Value>{tensor}, axis);

  // Combiner region: (lhs, rhs) -> lhs > rhs ? lhs : rhs.
  Block &combiner = reduceOp.getRegion().emplaceBlock();
  combiner.addArguments({elemTy, elemTy}, {loc, loc});
  Value lhs = combiner.getArgument(0);
  Value rhs = combiner.getArgument(1);
  b.setInsertionPointToStart(&combiner);
  Value lhsGtRhs =
      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs);
  Value maxVal = b.create<arith::SelectOp>(loc, lhsGtRhs, lhs, rhs);
  b.create<tt::ReduceReturnOp>(loc, std::vector<Value>{maxVal});

  return reduceOp->getResult(0);
}
// Lowers tti.experimental_assert_in_thread.
// Each thread folds the condition elements it holds: with AND by default, or
// with OR when check_any is set. If the fold is false, thread 0 calls the
// target's assertFail with the op's message. A barrier is emitted after the
// check.
struct AssertInThreadOpConversion
    : public ConvertOpToLLVMPattern<tti::ExperimentalAssertInThreadOp> {
  explicit AssertInThreadOpConversion(LLVMTypeConverter &typeConverter,
                                      const TargetInfoBase &targetInfo,
                                      PatternBenefit benefit)
      : ConvertOpToLLVMPattern<tti::ExperimentalAssertInThreadOp>(typeConverter,
                                                                  benefit),
        targetInfo(targetInfo) {}

  LogicalResult
  matchAndRewrite(tti::ExperimentalAssertInThreadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    assert(isa<RankedTensorType>(op.getCondition().getType()) &&
           "assert_in_thread expects a tensor condition");
    auto b = TritonLLVMOpBuilder(loc, rewriter);
    SmallVector<Value> condElems =
        unpackLLElements(loc, adaptor.getCondition(), rewriter);
    assert(!condElems.empty() && "assert_in_thread on an empty condition");
    auto condTy = condElems[0].getType();
    // Parenthesized so the message applies to the whole type check.
    assert((condTy.isSignedInteger() || condTy.isSignlessInteger()) &&
           "Unsupported type for assert_in_thread");
    unsigned bitWidth = condTy.getIntOrFloatBitWidth();
    bool check_any = adaptor.getCheckAny();

    // Start from the identity of the fold: 0 for OR (any), 1 for AND (all).
    Value condition = b.int_val(bitWidth, check_any ? 0 : 1);
    for (Value elem : condElems) {
      if (check_any)
        condition = b.or_(condition, elem);
      else
        condition = b.and_(condition, elem);
    }

    // Invert: llAssert fires when its condition is true, i.e. on failure.
    condition = b.xor_(condition, b.int_val(bitWidth, 1));
    llAssert(op, condition, adaptor.getMessage(), rewriter);
    b.barrier();
    rewriter.eraseOp(op);
    return success();
  }

  // Emits `if (condition && threadId == 0) targetInfo.assertFail(...)` at the
  // current insertion point.
  // File and line come from the op's location once call-site and name
  // wrappers are stripped, if it is a FileLineColLoc; otherwise they are
  // "unknown" and 0. On return, the insertion point is at the start of the
  // continuation block.
  void llAssert(Operation *op, Value condition, StringRef message,
                ConversionPatternRewriter &rewriter) const {
    auto loc = op->getLoc();
    auto b = TritonLLVMOpBuilder(loc, rewriter);
    StringRef file = "unknown";
    StringRef func = "unknown";
    int line = 0;
    while (auto callLoc = dyn_cast<CallSiteLoc>(loc))
      loc = callLoc.getCallee();
    while (auto nameLoc = dyn_cast<NameLoc>(loc))
      loc = nameLoc.getChildLoc();
    if (auto fileLineColLoc = dyn_cast<FileLineColLoc>(loc)) {
      file = fileLineColLoc.getFilename();
      line = fileLineColLoc.getLine();
    }

    // Only thread 0 reports. NOTE(review): this assumes thread 0 holds the
    // values being checked (e.g. a thread-replicated layout); confirm.
    Value threadId = getThreadId(*b.builder, loc);
    Value zero = b.int_val(threadId.getType().getIntOrFloatBitWidth(), 0);
    Value threadIdIsZero = b.icmp_eq(threadId, zero);
    condition = b.and_(condition, threadIdIsZero);

    auto [prevBlock, ifBlock, thenBlock] =
        createIfBlock(rewriter, loc, condition);
    rewriter.setInsertionPointToStart(ifBlock);
    targetInfo.assertFail(rewriter, loc, message, file, func, line);
    rewriter.setInsertionPointToStart(thenBlock);
  }

protected:
  const TargetInfoBase &targetInfo;
};
// Lowers tti.experimental_buffer_pointers.
// The op's offsets become an i64 constant tensor, to which the base address of
// the indexed memory is added. The base is the shared-memory base pointer or
// the tensor-memory base address, depending on the op's mem type.
struct BufferPointersOpConversion
    : public ConvertOpToLLVMPattern<tti::ExperimentalBufferPointersOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(tti::ExperimentalBufferPointersOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto values = adaptor.getOffsets();
    auto encoding =
        cast<ttg::BlockedEncodingAttr>(op.getResult().getType().getEncoding());
    auto bufPointers =
        createInitializedIntArrayTensor(rewriter, loc, encoding, values);
    Value base = nullptr;
    if (op.getMemType() == tti::MemType::SHARED_MEM) {
      base = getSharedMemoryBase(rewriter,
                                 op->getParentOfType<FunctionOpInterface>());
    } else {
      assert(op.getMemType() == tti::MemType::TENSOR_MEM &&
             "Unsupported memory type");
      TritonLLVMOpBuilder b(loc, rewriter);
      base = rewriter.create<nvgpu::TensorMemoryBaseAddress>(loc);
      // Convert to i32, then zero-extend: `bufPointers` has i64 elements, so
      // the splat below needs an i64 scalar.
      base = b.ptrtoint(i32_ty, base);
      base = b.zext(rewriter.getIntegerType(64), base);
    }
    bufPointers = rewriter.create<arith::AddIOp>(
        loc, bufPointers,
        rewriter.create<triton::SplatOp>(loc, bufPointers.getType(), base));
    rewriter.replaceOp(op, bufPointers);
    return success();
  }

  // Returns an arith.constant i64 tensor holding `values`, with `encoding`.
  // The number of values must be a power of two (asserted).
  Value createInitializedIntArrayTensor(OpBuilder &builder, Location loc,
                                        ttg::BlockedEncodingAttr encoding,
                                        ArrayRef<int32_t> values) const {
    int64_t size = values.size();
    assert(llvm::isPowerOf2_64(size) && "Expected power of 2");
    auto tensorType =
        RankedTensorType::get({size}, builder.getIntegerType(64), encoding);
    SmallVector<APInt> apInts = llvm::to_vector(
        llvm::map_range(values, [](int32_t v) { return APInt(64, v); }));
    auto denseAttr = DenseElementsAttr::get(tensorType, apInts);
    return builder.create<arith::ConstantOp>(loc, tensorType, denseAttr);
  }

  // Returns the pointer from LLVM::getStackPointer for `func`, converted to i64.
  Value getSharedMemoryBase(ConversionPatternRewriter &rewriter,
                            FunctionOpInterface func) const {
    Location loc = func.getLoc();
    Value base = LLVM::getStackPointer(rewriter, func);
    auto i64Ty = rewriter.getIntegerType(64);
    TritonLLVMOpBuilder b(loc, rewriter);
    base = b.ptrtoint(i64Ty, base);
    return base;
  }
};
// Lowers tti.experimental_check_write_state: asserts that buffer `buf` has no
// outstanding write in the write-state tensor.
// The entry written by SetWriteStateOpConversion is 1 | (notHwPipelined << 1).
// For a hardware-pipelined access, the entry is shifted right by one before
// the zero check, so only a non-pipelined write (bit 1) trips the assert.
// If the op has a predicate, the check runs only on the path where it holds.
struct CheckWriteStateOpConversion
    : public ConvertOpToLLVMPattern<tti::ExperimentalCheckWriteStateOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult matchAndRewrite(tti::ExperimentalCheckWriteStateOp op,
                                OpAdaptor adaptor,
                                ConversionPatternRewriter &b) const override {
    Location loc = op.getLoc();
    // Restore the caller's insertion point on exit, as the sibling patterns
    // do. Otherwise it is left wherever the last op below was created (inside
    // the predicated block when a predicate is present).
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(op);
    if (op.getPred()) {
      auto [prevBlock, ifBlock, thenBlock] =
          createIfBlock(b, loc, op.getPred());
      b.setInsertionPointToStart(ifBlock);
    }
    TypedValue<RankedTensorType> buffers = op.getBuffers();
    RankedTensorType writeStateType =
        cast<RankedTensorType>(op.getWriteStateType());
    Value writeState =
        tti::createLoadScratchMemory(b, loc, op.getWriteState(), writeStateType)
            ->getResult(0);
    int hwPipelined = op.getHwPipelined() ? 1 : 0;
    Value buf = createMemDescToI64(b, loc, getTypeConverter(),
                                   op.getBuf().getType(), adaptor.getBuf());
    Value writeStateZero = tti::createConstIntTensor(b, loc, 0, writeStateType);
    // Keep only `buf`'s entry; every other entry becomes 0.
    Value buffersEqBuf = createCmpIntTensorScalar(b, loc, buffers, buf);
    Value currBufState = b.create<arith::SelectOp>(loc, buffersEqBuf,
                                                   writeState, writeStateZero);
    // Shift by 1 for a hardware-pipelined access (ignore bit 0), else by 0.
    auto shiftVal =
        tti::createConstIntTensor(b, loc, hwPipelined, writeStateType);
    currBufState = b.create<arith::ShRUIOp>(loc, currBufState, shiftVal);
    Value currBufStateEqZero = b.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, currBufState, writeStateZero);
    b.create<tti::ExperimentalAssertInThreadOp>(
        loc, currBufStateEqZero, "Buffer being accessed has outstanding writes",
        false);
    b.eraseOp(op);
    return success();
  }
};
// Lowers tti.experimental_check_read_barriers: asserts that the read-barriers
// table records no barrier for buffer `buf`.
// The table (`readBars`, loaded from scratch memory) is indexed as
// [buffer, barrier]: the per-buffer mask below is broadcast along dim 1.
// If the op has a predicate, the check runs only on the path where it holds.
struct CheckReadBarriersOpConversion
    : public ConvertOpToLLVMPattern<tti::ExperimentalCheckReadBarriersOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
  LogicalResult matchAndRewrite(tti::ExperimentalCheckReadBarriersOp op,
                                OpAdaptor adaptor,
                                ConversionPatternRewriter &b) const override {
    Location loc = op.getLoc();
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(op);
    // Optional predicate: emit the check inside a conditionally executed block.
    if (op.getPred()) {
      auto [prevBlock, ifBlock, thenBlock] =
          createIfBlock(b, loc, op.getPred());
      b.setInsertionPointToStart(ifBlock);
    }
    TypedValue<RankedTensorType> buffers = op.getBuffers();
    RankedTensorType readBarsType =
        cast<RankedTensorType>(op.getReadBarsType());
    Value readBars =
        tti::createLoadScratchMemory(b, loc, op.getReadBars(), readBarsType)
            ->getResult(0);
    Value buf = createMemDescToI64(b, loc, getTypeConverter(),
                                   op.getBuf().getType(), adaptor.getBuf());
    // Mask selecting `buf`'s row, broadcast across all barrier columns.
    auto buffersEqBuf = createCmpIntTensorScalar(b, loc, buffers, buf);
    buffersEqBuf = convertAndBroadcast(
        b, loc, buffersEqBuf, 1, readBarsType.getShape(),
        cast<ttg::BlockedEncodingAttr>(readBarsType.getEncoding()));
    // Zero out every other row; every entry in `buf`'s row must be zero.
    auto readBarsZero = tti::createConstIntTensor(b, loc, 0, readBarsType);
    auto currBufBar =
        b.create<arith::SelectOp>(loc, buffersEqBuf, readBars, readBarsZero);
    auto currBufBarEqZero = b.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, currBufBar, readBarsZero);
    b.create<tti::ExperimentalAssertInThreadOp>(
        loc, currBufBarEqZero, "Buffer being accessed has outstanding reads",
        false);
    b.eraseOp(op);
    return success();
  }
};
// Lowers tti.experimental_set_write_state: records a write to buffer `buf`.
// Sets `buf`'s entry in the write-state tensor to 1 | (notHwPipelined << 1)
// and stores the tensor back to scratch memory:
//  - bit 0: a write is recorded;
//  - bit 1: the write is not hardware-pipelined.
// If the op has a predicate, the update runs only on the path where it holds.
struct SetWriteStateOpConversion
    : public ConvertOpToLLVMPattern<tti::ExperimentalSetWriteStateOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
  LogicalResult matchAndRewrite(tti::ExperimentalSetWriteStateOp op,
                                OpAdaptor adaptor,
                                ConversionPatternRewriter &b) const override {
    Location loc = op.getLoc();
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(op);
    // Optional predicate: emit the update inside a conditionally executed block.
    if (op.getPred()) {
      auto [prevBlock, ifBlock, thenBlock] =
          createIfBlock(b, loc, op.getPred());
      b.setInsertionPointToStart(ifBlock);
    }
    TypedValue<RankedTensorType> buffers = op.getBuffers();
    RankedTensorType writeStateType =
        cast<RankedTensorType>(op.getWriteStateType());
    Value writeState =
        tti::createLoadScratchMemory(b, loc, op.getWriteState(), writeStateType)
            ->getResult(0);
    int notHwPipelined = op.getHwPipelined() ? 0 : 1;
    Value buf = createMemDescToI64(b, loc, getTypeConverter(),
                                   op.getBuf().getType(), adaptor.getBuf());
    // Encoded state value: bit 0 = written, bit 1 = not hardware-pipelined.
    int val = 1 | (notHwPipelined << 1);
    auto buffersEqBuf = createCmpIntTensorScalar(b, loc, buffers, buf);
    // Overwrite only `buf`'s entry; all other entries keep their old value.
    writeState = b.create<arith::SelectOp>(
        loc, buffersEqBuf,
        tti::createConstIntTensor(b, loc, val, writeStateType), writeState);
    tti::createStoreScratchMemory(b, loc, op.getWriteState(), writeState,
                                  writeStateType);
    b.eraseOp(op);
    return success();
  }
};
// Lowers tti.experimental_commit_write_with_barrier: ties every buffer with a
// recorded write to barrier `mbar` in the write-barriers table.
// The table (`writeBars`) is indexed as [buffer, barrier]:
//  - the write state is broadcast along dim 1;
//  - the barrier == mbar mask is broadcast along dim 0.
// The mask is zero-extended from i1 to 0/1, so AND-ing it with the write state
// contributes only bit 0. The result is OR-ed into writeBars and stored back.
// If the op has a predicate, the update runs only on the path where it holds.
struct CommitWriteWithBarrierOpConversion
    : public ConvertOpToLLVMPattern<tti::ExperimentalCommitWriteWithBarrierOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
  LogicalResult matchAndRewrite(tti::ExperimentalCommitWriteWithBarrierOp op,
                                OpAdaptor adaptor,
                                ConversionPatternRewriter &b) const override {
    Location loc = op.getLoc();
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(op);
    // Optional predicate: emit the update inside a conditionally executed block.
    if (op.getPred()) {
      auto [prevBlock, ifBlock, thenBlock] =
          createIfBlock(b, loc, op.getPred());
      b.setInsertionPointToStart(ifBlock);
    }
    TypedValue<RankedTensorType> barriers = op.getBarriers();
    RankedTensorType writeBarsType =
        cast<RankedTensorType>(op.getWriteBarsType());
    Value writeBars =
        tti::createLoadScratchMemory(b, loc, op.getWriteBars(), writeBarsType)
            ->getResult(0);
    RankedTensorType writeStateType =
        cast<RankedTensorType>(op.getWriteStateType());
    Value writeState =
        tti::createLoadScratchMemory(b, loc, op.getWriteState(), writeStateType)
            ->getResult(0);
    Value mbar = createMemDescToI64(b, loc, getTypeConverter(),
                                    op.getMbar().getType(), adaptor.getMbar());
    // Per-buffer write state replicated across barrier columns.
    writeState = convertAndBroadcast(
        b, loc, writeState, 1, writeBarsType.getShape(),
        cast<ttg::BlockedEncodingAttr>(writeBarsType.getEncoding()));
    // Column mask for `mbar`, replicated across buffer rows.
    auto barriersEqMbar = createCmpIntTensorScalar(b, loc, barriers, mbar);
    barriersEqMbar = convertAndBroadcast(
        b, loc, barriersEqMbar, 0, writeBarsType.getShape(),
        cast<ttg::BlockedEncodingAttr>(writeBarsType.getEncoding()));
    barriersEqMbar =
        b.create<arith::ExtUIOp>(loc, writeBarsType, barriersEqMbar);
    // Set (buffer, mbar) for every buffer whose write-state bit 0 is set.
    Value stateAndBar =
        b.create<arith::AndIOp>(loc, writeState, barriersEqMbar);
    writeBars = b.create<arith::OrIOp>(loc, writeBars, stateAndBar);
    tti::createStoreScratchMemory(b, loc, op.getWriteBars(), writeBars,
                                  writeBarsType);
    b.eraseOp(op);
    return success();
  }
};
// Lowers tti.experimental_set_read_barrier: records that barrier `mbar`
// tracks a read of buffer `buf`.
// The entry (buf, mbar) of the read-barriers table is OR-ed with 1 and the
// table is stored back. The table is indexed as [buffer, barrier]:
//  - the buffer mask is broadcast along dim 1;
//  - the barrier mask is broadcast along dim 0.
// If the op has a predicate, the update runs only on the path where it holds.
struct SetReadBarrierOpConversion
    : public ConvertOpToLLVMPattern<tti::ExperimentalSetReadBarrierOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
  LogicalResult matchAndRewrite(tti::ExperimentalSetReadBarrierOp op,
                                OpAdaptor adaptor,
                                ConversionPatternRewriter &b) const override {
    Location loc = op.getLoc();
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(op);
    // Optional predicate: emit the update inside a conditionally executed block.
    if (op.getPred()) {
      auto [prevBlock, ifBlock, thenBlock] =
          createIfBlock(b, loc, op.getPred());
      b.setInsertionPointToStart(ifBlock);
    }
    TypedValue<RankedTensorType> buffers = op.getBuffers();
    TypedValue<RankedTensorType> barriers = op.getBarriers();
    RankedTensorType readBarsType =
        cast<RankedTensorType>(op.getReadBarsType());
    Value readBars =
        tti::createLoadScratchMemory(b, loc, op.getReadBars(), readBarsType)
            ->getResult(0);
    Value buf = createMemDescToI64(b, loc, getTypeConverter(),
                                   op.getBuf().getType(), adaptor.getBuf());
    Value mbar = createMemDescToI64(b, loc, getTypeConverter(),
                                    op.getMbar().getType(), adaptor.getMbar());
    // Row mask for `buf`.
    auto buffersEqBuf = createCmpIntTensorScalar(b, loc, buffers, buf);
    buffersEqBuf = convertAndBroadcast(
        b, loc, buffersEqBuf, 1, readBarsType.getShape(),
        cast<ttg::BlockedEncodingAttr>(readBarsType.getEncoding()));
    // Column mask for `mbar`.
    auto barriersEqMbar = createCmpIntTensorScalar(b, loc, barriers, mbar);
    barriersEqMbar = convertAndBroadcast(
        b, loc, barriersEqMbar, 0, readBarsType.getShape(),
        cast<ttg::BlockedEncodingAttr>(readBarsType.getEncoding()));
    // The intersection is the single (buf, mbar) cell; set it to 1.
    Value bufAndBar =
        b.create<arith::AndIOp>(loc, buffersEqBuf, barriersEqMbar);
    bufAndBar = b.create<arith::ExtUIOp>(loc, readBarsType, bufAndBar);
    readBars = b.create<arith::OrIOp>(loc, readBars, bufAndBar);
    tti::createStoreScratchMemory(b, loc, op.getReadBars(), readBars,
                                  readBarsType);
    b.eraseOp(op);
    return success();
  }
};
// Lowers tti.experimental_clear_write_barrier for barrier `mbar`:
//  1. Resets the write state of every buffer that has a write tracked by
//     `mbar`.
//  2. Zeroes `mbar`'s column of the write-barriers table.
// writeBars is indexed as [buffer, barrier]; the barrier mask is broadcast
// along dim 0. A buffer's "tracked by mbar" flag is a signed max-reduction of
// `writeBars & (barrier == mbar)` over dim 1.
// If the op has a predicate, the update runs only on the path where it holds.
//
// NOTE(review): `writeBarsForMbar` comes from a reduce over a writeBarsType
// tensor, so its encoding is presumably a slice of writeBarsType's encoding and
// its element type is writeBarsType's. `writeStateZero` has writeStateType.
// arith.cmpi / arith.select normally require identical operand types; confirm
// these agree, or insert a ttg::ConvertLayoutOp to writeStateType.
struct ClearWriteBarrierOpConversion
    : public ConvertOpToLLVMPattern<tti::ExperimentalClearWriteBarrierOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
  LogicalResult matchAndRewrite(tti::ExperimentalClearWriteBarrierOp op,
                                OpAdaptor adaptor,
                                ConversionPatternRewriter &b) const override {
    Location loc = op.getLoc();
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(op);
    // Optional predicate: emit the update inside a conditionally executed block.
    if (op.getPred()) {
      auto [prevBlock, ifBlock, thenBlock] =
          createIfBlock(b, loc, op.getPred());
      b.setInsertionPointToStart(ifBlock);
    }
    TypedValue<RankedTensorType> barriers = op.getBarriers();
    RankedTensorType writeBarsType =
        cast<RankedTensorType>(op.getWriteBarsType());
    Value writeBars =
        tti::createLoadScratchMemory(b, loc, op.getWriteBars(), writeBarsType)
            ->getResult(0);
    RankedTensorType writeStateType =
        cast<RankedTensorType>(op.getWriteStateType());
    Value writeState =
        tti::createLoadScratchMemory(b, loc, op.getWriteState(), writeStateType)
            ->getResult(0);
    Value mbar = createMemDescToI64(b, loc, getTypeConverter(),
                                    op.getMbar().getType(), adaptor.getMbar());
    // Column mask for `mbar`, replicated across buffer rows (i1).
    auto barriersEqMbar = createCmpIntTensorScalar(b, loc, barriers, mbar);
    barriersEqMbar = convertAndBroadcast(
        b, loc, barriersEqMbar, 0, writeBarsType.getShape(),
        cast<ttg::BlockedEncodingAttr>(writeBarsType.getEncoding()));
    // Despite the name, the element type is writeBarsType's element type.
    Value barriersEqMbarI8 =
        b.create<arith::ExtUIOp>(loc, writeBarsType, barriersEqMbar);
    // Per buffer: max over barriers of the entries in `mbar`'s column.
    Value writeBarsForMbar =
        b.create<arith::AndIOp>(loc, writeBars, barriersEqMbarI8);
    writeBarsForMbar = createMaxReduce(b, loc, writeBarsForMbar, 1);
    // Reset the write state of buffers tracked by `mbar`.
    Value writeStateZero = tti::createConstIntTensor(b, loc, 0, writeStateType);
    Value writeBarsForMbarNonZero = b.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::ne, writeBarsForMbar, writeStateZero);
    writeState = b.create<arith::SelectOp>(loc, writeBarsForMbarNonZero,
                                           writeStateZero, writeState);
    tti::createStoreScratchMemory(b, loc, op.getWriteState(), writeState,
                                  writeStateType);
    // Clear `mbar`'s column of the write-barriers table.
    Value writeBarsZero = tti::createConstIntTensor(b, loc, 0, writeBarsType);
    writeBars = b.create<arith::SelectOp>(loc, barriersEqMbar, writeBarsZero,
                                          writeBars);
    tti::createStoreScratchMemory(b, loc, op.getWriteBars(), writeBars,
                                  writeBarsType);
    b.eraseOp(op);
    return success();
  }
};
// Lowers tti.experimental_clear_read_barrier: zeroes barrier `mbar`'s column
// in the read-barriers table and stores the table back to scratch memory.
// The barrier mask is broadcast along dim 0 of the [buffer, barrier] table.
// If the op has a predicate, the update runs only on the path where it holds.
struct ClearReadBarrierOpConversion
    : public ConvertOpToLLVMPattern<tti::ExperimentalClearReadBarrierOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
  LogicalResult matchAndRewrite(tti::ExperimentalClearReadBarrierOp op,
                                OpAdaptor adaptor,
                                ConversionPatternRewriter &b) const override {
    Location loc = op.getLoc();
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(op);
    // Optional predicate: emit the update inside a conditionally executed block.
    if (op.getPred()) {
      auto [prevBlock, ifBlock, thenBlock] =
          createIfBlock(b, loc, op.getPred());
      b.setInsertionPointToStart(ifBlock);
    }
    TypedValue<RankedTensorType> barriers = op.getBarriers();
    RankedTensorType readBarsType =
        cast<RankedTensorType>(op.getReadBarsType());
    Value readBars =
        tti::createLoadScratchMemory(b, loc, op.getReadBars(), readBarsType)
            ->getResult(0);
    Value mbar = createMemDescToI64(b, loc, getTypeConverter(),
                                    op.getMbar().getType(), adaptor.getMbar());
    auto readBarsZero = tti::createConstIntTensor(b, loc, 0, readBarsType);
    // Column mask for `mbar`, replicated across buffer rows.
    auto readBarsEqMbar = createCmpIntTensorScalar(b, loc, barriers, mbar);
    readBarsEqMbar = convertAndBroadcast(
        b, loc, readBarsEqMbar, 0, readBarsType.getShape(),
        cast<ttg::BlockedEncodingAttr>(readBarsType.getEncoding()));
    readBars =
        b.create<arith::SelectOp>(loc, readBarsEqMbar, readBarsZero, readBars);
    tti::createStoreScratchMemory(b, loc, op.getReadBars(), readBars,
                                  readBarsType);
    b.eraseOp(op);
    return success();
  }
};
// Lowers tti.experimental_check_barrier_writes_cleared: asserts that barrier
// `mbar`'s column of the write-barriers table is all zero.
// In other words, the barrier may be reused only when it no longer tracks
// any writes.
// If the op has a predicate, the check runs only on the path where it holds.
struct CheckBarrierWritesClearedOpConversion
    : public ConvertOpToLLVMPattern<tti::ExperimentalCheckBarrierWritesClearedOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
  LogicalResult matchAndRewrite(tti::ExperimentalCheckBarrierWritesClearedOp op,
                                OpAdaptor adaptor,
                                ConversionPatternRewriter &b) const override {
    Location loc = op.getLoc();
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(op);
    // Optional predicate: emit the check inside a conditionally executed block.
    if (op.getPred()) {
      auto [prevBlock, ifBlock, thenBlock] =
          createIfBlock(b, loc, op.getPred());
      b.setInsertionPointToStart(ifBlock);
    }
    TypedValue<RankedTensorType> barriers = op.getBarriers();
    RankedTensorType writeBarsType =
        cast<RankedTensorType>(op.getWriteBarsType());
    Value writeBars =
        tti::createLoadScratchMemory(b, loc, op.getWriteBars(), writeBarsType)
            ->getResult(0);
    Value mbar = createMemDescToI64(b, loc, getTypeConverter(),
                                    op.getMbar().getType(), adaptor.getMbar());
    auto writeBarsZero = tti::createConstIntTensor(b, loc, 0, writeBarsType);
    // Column mask for `mbar` (zero-extended to 0/1), replicated across rows.
    auto barsEqMbar = createCmpIntTensorScalar(b, loc, barriers, mbar);
    barsEqMbar = convertAndBroadcast(
        b, loc, barsEqMbar, 0, writeBarsType.getShape(),
        cast<ttg::BlockedEncodingAttr>(writeBarsType.getEncoding()));
    barsEqMbar = b.create<arith::ExtUIOp>(loc, writeBarsType, barsEqMbar);
    // The AND keeps bit 0 of `mbar`'s column only; it must be all zero.
    Value currWriteBars = b.create<arith::AndIOp>(loc, writeBars, barsEqMbar);
    Value currWriteBarsEqZero = b.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, currWriteBars, writeBarsZero);
    b.create<tti::ExperimentalAssertInThreadOp>(
        loc, currWriteBarsEqZero,
        "Barrier is being reused while still tracking writes", false);
    b.eraseOp(op);
    return success();
  }
};
// Lowers tti.experimental_stage_access_for_commit: marks buffer `buf` as
// staged by writing -1 into its outstanding-commits entry. All other entries
// are unchanged. CommitAccessesOpConversion later turns -1 entries into 1.
// If the op has a predicate, the update runs only on the path where it holds.
struct StageAccessForCommitOpConversion
    : public ConvertOpToLLVMPattern<tti::ExperimentalStageAccessForCommitOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
  LogicalResult matchAndRewrite(tti::ExperimentalStageAccessForCommitOp op,
                                OpAdaptor adaptor,
                                ConversionPatternRewriter &b) const override {
    Location loc = op.getLoc();
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(op);
    // Optional predicate: emit the update inside a conditionally executed block.
    if (op.getPred()) {
      auto [prevBlock, ifBlock, thenBlock] =
          createIfBlock(b, loc, op.getPred());
      b.setInsertionPointToStart(ifBlock);
    }
    TypedValue<RankedTensorType> buffers = op.getBuffers();
    RankedTensorType writeCommitsType =
        cast<RankedTensorType>(op.getOutstandingCommitsType());
    Value writeCommits =
        tti::createLoadScratchMemory(b, loc, op.getOutstandingCommits(),
                                     writeCommitsType)
            ->getResult(0);
    Value buf = createMemDescToI64(b, loc, getTypeConverter(),
                                   op.getBuf().getType(), adaptor.getBuf());
    auto buffersEqBuf = createCmpIntTensorScalar(b, loc, buffers, buf);
    // -1 is the "staged, not yet committed" marker.
    auto writeCommitsMinusOne =
        tti::createConstIntTensor(b, loc, -1, writeCommitsType);
    writeCommits = b.create<arith::SelectOp>(
        loc, buffersEqBuf, writeCommitsMinusOne, writeCommits);
    tti::createStoreScratchMemory(b, loc, op.getOutstandingCommits(),
                                  writeCommits, writeCommitsType);
    b.eraseOp(op);
    return success();
  }
};
// Lowers tti.experimental_commit_accesses: advances the outstanding-commits
// counters.
//  1. Every entry > 0 is incremented by 1.
//  2. Every entry == -1 (staged by StageAccessForCommitOpConversion) is set
//     to 1.
// Step 1 runs first, so newly committed entries start at 1 and are not
// incremented in the same step.
// If the op has a predicate, the update runs only on the path where it holds.
struct CommitAccessesOpConversion
    : public ConvertOpToLLVMPattern<tti::ExperimentalCommitAccessesOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
  LogicalResult matchAndRewrite(tti::ExperimentalCommitAccessesOp op,
                                OpAdaptor adaptor,
                                ConversionPatternRewriter &b) const override {
    Location loc = op.getLoc();
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(op);
    // Optional predicate: emit the update inside a conditionally executed block.
    if (op.getPred()) {
      auto [prevBlock, ifBlock, thenBlock] =
          createIfBlock(b, loc, op.getPred());
      b.setInsertionPointToStart(ifBlock);
    }
    RankedTensorType writeCommitsType =
        cast<RankedTensorType>(op.getOutstandingCommitsType());
    Value writeCommits =
        tti::createLoadScratchMemory(b, loc, op.getOutstandingCommits(),
                                     writeCommitsType)
            ->getResult(0);
    Type elementType = writeCommitsType.getElementType();
    Value minusOne = b.create<arith::ConstantOp>(
        loc, elementType, b.getIntegerAttr(elementType, -1));
    Value zero = b.create<arith::ConstantOp>(loc, elementType,
                                             b.getIntegerAttr(elementType, 0));
    Value writeCommitsOne =
        tti::createConstIntTensor(b, loc, 1, writeCommitsType);
    // Step 1: age already-committed entries (> 0) by one.
    auto writeCommitsGtZero = createCmpIntTensorScalar(
        b, loc, writeCommits, zero, arith::CmpIPredicate::sgt);
    auto writeCommitsPlusOne =
        b.create<arith::AddIOp>(loc, writeCommits, writeCommitsOne);
    writeCommits = b.create<arith::SelectOp>(loc, writeCommitsGtZero,
                                             writeCommitsPlusOne, writeCommits);
    // Step 2: staged entries (-1) become freshly committed (1).
    auto writeCommitsEqMinusOne =
        createCmpIntTensorScalar(b, loc, writeCommits, minusOne);
    writeCommits = b.create<arith::SelectOp>(loc, writeCommitsEqMinusOne,
                                             writeCommitsOne, writeCommits);
    tti::createStoreScratchMemory(b, loc, op.getOutstandingCommits(),
                                  writeCommits, writeCommitsType);
    b.eraseOp(op);
    return success();
  }
};
// Lowers tti.experimental_clear_outstanding_commits: resets to 0 every
// outstanding-commits entry greater than the op's outstandingNum (signed
// compare). Presumably this models a wait that leaves at most outstandingNum
// commit groups in flight; confirm against the op definition.
// If the op has a predicate, the update runs only on the path where it holds.
struct ClearOutstandingCommitsOpConversion
    : public ConvertOpToLLVMPattern<
          tti::ExperimentalClearOutstandingCommitsOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
  LogicalResult matchAndRewrite(tti::ExperimentalClearOutstandingCommitsOp op,
                                OpAdaptor adaptor,
                                ConversionPatternRewriter &b) const override {
    Location loc = op.getLoc();
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(op);
    // Optional predicate: emit the update inside a conditionally executed block.
    if (op.getPred()) {
      auto [prevBlock, ifBlock, thenBlock] =
          createIfBlock(b, loc, op.getPred());
      b.setInsertionPointToStart(ifBlock);
    }
    RankedTensorType outstandingCommitsType =
        cast<RankedTensorType>(op.getOutstandingCommitsType());
    Value outstandingCommits =
        tti::createLoadScratchMemory(b, loc, op.getOutstandingCommits(),
                                     outstandingCommitsType)
            ->getResult(0);
    Type elementType = outstandingCommitsType.getElementType();
    Value outstandingNum = b.create<arith::ConstantOp>(
        loc, elementType,
        b.getIntegerAttr(elementType, op.getOutstandingNum()));
    Value outstandingCommitsZero =
        tti::createConstIntTensor(b, loc, 0, outstandingCommitsType);
    // Entries older than `outstandingNum` are considered complete.
    auto outstandingCommitsGtOutstandingNum = createCmpIntTensorScalar(
        b, loc, outstandingCommits, outstandingNum, arith::CmpIPredicate::sgt);
    outstandingCommits =
        b.create<arith::SelectOp>(loc, outstandingCommitsGtOutstandingNum,
                                  outstandingCommitsZero, outstandingCommits);
    tti::createStoreScratchMemory(b, loc, op.getOutstandingCommits(),
                                  outstandingCommits, outstandingCommitsType);
    b.eraseOp(op);
    return success();
  }
};
// Lowers tti.experimental_check_outstanding_commits: asserts that buffer
// `buf`'s outstanding-commits entry is 0. The failure message includes the
// op's pending access type.
// If the op has a predicate, the check runs only on the path where it holds.
struct CheckOutstandingCommitsOpConversion
    : public ConvertOpToLLVMPattern<
          tti::ExperimentalCheckOutstandingCommitsOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
  LogicalResult matchAndRewrite(tti::ExperimentalCheckOutstandingCommitsOp op,
                                OpAdaptor adaptor,
                                ConversionPatternRewriter &b) const override {
    Location loc = op.getLoc();
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(op);
    // Optional predicate: emit the check inside a conditionally executed block.
    if (op.getPred()) {
      auto [prevBlock, ifBlock, thenBlock] =
          createIfBlock(b, loc, op.getPred());
      b.setInsertionPointToStart(ifBlock);
    }
    TypedValue<RankedTensorType> buffers = op.getBuffers();
    RankedTensorType outstandingCommitsType =
        cast<RankedTensorType>(op.getOutstandingCommitsType());
    Value outstandingCommits =
        tti::createLoadScratchMemory(b, loc, op.getOutstandingCommits(),
                                     outstandingCommitsType)
            ->getResult(0);
    Value buf = createMemDescToI64(b, loc, getTypeConverter(),
                                   op.getBuf().getType(), adaptor.getBuf());
    StringRef pendingAccessType = op.getPendingAccessType();
    Type elementType = outstandingCommitsType.getElementType();
    auto buffersEqBuf = createCmpIntTensorScalar(b, loc, buffers, buf);
    auto zero = b.create<arith::ConstantOp>(loc, elementType,
                                            b.getIntegerAttr(elementType, 0));
    auto outstandingCommitsZero =
        tti::createConstIntTensor(b, loc, 0, outstandingCommitsType);
    // Keep only `buf`'s entry; it must be zero.
    auto currCommits = b.create<arith::SelectOp>(
        loc, buffersEqBuf, outstandingCommits, outstandingCommitsZero);
    auto currCommitsEqZero =
        createCmpIntTensorScalar(b, loc, currCommits, zero);
    std::string message =
        "Accessing buffer with pending access. Pending access type: " +
        pendingAccessType.str();
    b.create<tti::ExperimentalAssertInThreadOp>(
        loc, currCommitsEqZero, b.getStringAttr(message), false);
    b.eraseOp(op);
    return success();
  }
};
}
// Registers the lowering patterns for the Triton instrumentation (tti) ops.
// `benefit` is applied to every pattern, not only the assert lowering, so
// callers get a consistent priority for the whole set.
void mlir::triton::populateInstrumentationToLLVMPatterns(
    LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<AssertInThreadOpConversion>(typeConverter, targetInfo, benefit);
  patterns.add<BufferPointersOpConversion, CheckWriteStateOpConversion,
               CheckReadBarriersOpConversion, SetWriteStateOpConversion,
               CommitWriteWithBarrierOpConversion, SetReadBarrierOpConversion,
               ClearWriteBarrierOpConversion, ClearReadBarrierOpConversion,
               CheckBarrierWritesClearedOpConversion,
               StageAccessForCommitOpConversion, CommitAccessesOpConversion,
               ClearOutstandingCommitsOpConversion,
               CheckOutstandingCommitsOpConversion>(typeConverter, benefit);
}