#include "TritonAMDGPUTransforms/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "third_party/amd/include/Analysis/AxisInfoExt.h"
#include "third_party/amd/include/Analysis/RangeAnalysis.h"
#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
#include "third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "llvm/ADT/TypeSwitch.h"
#undef DEBUG_TYPE
#define DEBUG_TYPE "tritonamdgpu-convert-buffer-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
using ::mlir::LLVM::AMD::getVectorSize;
using mlir::triton::AMD::ISAFamily;
namespace ttg = mlir::triton::gpu;
namespace tt = mlir::triton;
namespace mlir {
#define GEN_PASS_DEF_TRITONAMDGPUCONVERTTOBUFFEROPS
#include "TritonAMDGPUTransforms/Passes.h.inc"
namespace {
/// Returns true when `v` is produced by an `arith.constant` holding a splat
/// dense integer attribute whose (single, repeated) value is one — e.g. an
/// all-true mask. Any other producer or attribute kind yields false.
bool isSplatOneConstTensor(const Value v) {
  auto cst = v.getDefiningOp<arith::ConstantOp>();
  if (!cst)
    return false;
  auto intAttr = dyn_cast<DenseIntElementsAttr>(cst.getValueAttr());
  if (!intAttr || !intAttr.isSplat())
    return false;
  return intAttr.getSplatValue<APInt>().isOne();
}
/// Returns true if one of the recorded assumptions on `expr` establishes
/// `expr >= other`, where `other` is any value accepted by `matchesOther`.
///
/// Only `arith.cmpi` assumption ops are considered, in two shapes:
///   - `expr {eq, sge, sgt} other`  (expr is the left-hand side), or
///   - `other {sle, slt} expr`      (expr is the right-hand side).
/// Other ops or predicates are ignored.
bool verifyNonSmallerByAssumption(
    Value expr, const DenseMap<Value, SetVector<Operation *>> &assumptions,
    const std::function<bool(Value)> &matchesOther) {
  auto found = assumptions.find(expr);
  if (found == assumptions.end())
    return false;

  for (Operation *assumeOp : found->second) {
    auto cmp = llvm::dyn_cast<arith::CmpIOp>(assumeOp);
    if (!cmp)
      continue;

    const arith::CmpIPredicate pred = cmp.getPredicate();
    const bool exprMustBeLhs = pred == arith::CmpIPredicate::eq ||
                               pred == arith::CmpIPredicate::sge ||
                               pred == arith::CmpIPredicate::sgt;
    const bool exprMustBeRhs = pred == arith::CmpIPredicate::sle ||
                               pred == arith::CmpIPredicate::slt;

    // Short-circuiting keeps `matchesOther` evaluated only when the `expr`
    // side of the comparison already matched.
    const bool proven =
        (exprMustBeLhs && cmp.getLhs() == expr && matchesOther(cmp.getRhs())) ||
        (exprMustBeRhs && cmp.getRhs() == expr && matchesOther(cmp.getLhs()));
    if (proven) {
      LDBG(" " << expr << " non-neg by assumption " << cmp);
      return true;
    }
  }
  return false;
}
/// Convenience overload: returns true if an assumption on `expr` proves
/// `expr >= other` for exactly the value `other`.
bool verifyNonSmallerByAssumption(
    Value expr, const DenseMap<Value, SetVector<Operation *>> &assumptions,
    Value other) {
  auto isOther = [other](Value candidate) { return candidate == other; };
  return verifyNonSmallerByAssumption(expr, assumptions, isOther);
}
/// Conservatively decides whether the integer value `expr` is provably
/// non-negative. `true` means non-negativity was proven; `false` means it
/// could not be proven (the value may be negative).
///
/// The integer-range dataflow `solver` is queried first. If that is
/// inconclusive, the op defining `expr` is matched against the per-op rules
/// below, most of which recurse into operands. `assumptions` is the map built
/// by the pass via TritonIntegerRangeAnalysis::collectAssumptions; here it is
/// only consulted for `arith.subi`, through verifyNonSmallerByAssumption.
/// Block arguments and unhandled ops yield false.
///
/// NOTE: llvm::TypeSwitch dispatches to the first `.Case` that matches, so the
/// order of the cases below may matter (e.g. `arith::ConstantIntOp` is listed
/// before the generic `arith::ConstantOp`).
bool verifyNonNegativeExpr(
    Value expr, const DenseMap<Value, SetVector<Operation *>> &assumptions,
    std::shared_ptr<DataFlowSolver> solver) {
  LDBG("Determing if non-negative: " << expr);
  // Range-analysis query. A lattice that exists but is uninitialized or holds
  // an empty range carries no usable information and is not trusted;
  // otherwise defer to dataflow::staticallyNonNegative.
  auto nonNegativePred = [&solver](Value v) -> bool {
    if (const auto *r =
            solver->lookupState<dataflow::IntegerValueRangeLattice>(v)) {
      if (r->getValue().isUninitialized())
        return false;
      if (AMD::isEmptyInitializedRange(r->getValue().getValue()))
        return false;
    }
    return succeeded(dataflow::staticallyNonNegative(*solver, v));
  };
  if (nonNegativePred(expr))
    return true;
  // Without a defining op (e.g. a block argument) there is nothing to inspect
  // structurally.
  Operation *op = expr.getDefiningOp();
  if (!op) {
    LDBG(" No defining op, assuming possibly negative");
    return false;
  }
  bool nonNegative =
      llvm::TypeSwitch<Operation *, bool>(expr.getDefiningOp())
          // Ops that only move, replicate or re-layout elements: the result is
          // non-negative iff the single operand is.
          .Case<triton::TransOp, triton::SplitOp, triton::BroadcastOp,
                triton::ExpandDimsOp, triton::SplatOp, triton::ReshapeOp,
                triton::gpu::ConvertLayoutOp>([&](auto unaryOp) {
            return verifyNonNegativeExpr(unaryOp.getOperand(), assumptions,
                                         solver);
          })
          // Gathered elements come from `src`, so only `src` is checked (the
          // indices are not).
          .Case<triton::GatherOp>([&](auto gatherOp) {
            return verifyNonNegativeExpr(gatherOp.getSrc(), assumptions,
                                         solver);
          })
          // Join / concatenate: both inputs must be non-negative.
          .Case<triton::JoinOp, triton::CatOp>([&](auto joinOp) {
            return verifyNonNegativeExpr(joinOp.getLhs(), assumptions,
                                         solver) &&
                   verifyNonNegativeExpr(joinOp.getRhs(), assumptions, solver);
          })
          // Treated as always non-negative (its result presumably holds bin
          // counts).
          .Case<triton::HistogramOp>([&](auto) { return true; })
          // Non-negative when both range bounds are >= 0.
          .Case<triton::MakeRangeOp>([&](auto makeRangeOp) {
            return makeRangeOp.getStartAttr().getInt() >= 0 &&
                   makeRangeOp.getEndAttr().getInt() >= 0;
          })
          // Scalar integer constant.
          .Case<arith::ConstantIntOp>(
              [&](auto constIntOp) { return constIntOp.value() >= 0; })
          // Other constants: only splat dense integer attributes are
          // understood; anything else is treated as possibly negative.
          .Case<arith::ConstantOp>([&](arith::ConstantOp constOp) {
            Value val = constOp.getResult();
            DenseIntElementsAttr constVal;
            if (matchPattern(val, m_Constant(&constVal)) && constVal.isSplat())
              return constVal.getSplatValue<APInt>().isNonNegative();
            return false;
          })
          // Program ids and program counts are treated as non-negative.
          .Case<triton::GetNumProgramsOp, triton::GetProgramIdOp>([&](auto) {
            return true;
          })
          // max(a, b) >= 0 as soon as either operand is >= 0.
          .Case<arith::MaxSIOp>([&](auto maxOp) {
            return verifyNonNegativeExpr(maxOp.getLhs(), assumptions, solver) ||
                   verifyNonNegativeExpr(maxOp.getRhs(), assumptions, solver);
          })
          // The result of arith.remsi takes the sign of the dividend (lhs).
          .Case<arith::RemSIOp>([&](auto remsiOp) {
            return verifyNonNegativeExpr(remsiOp.getLhs(), assumptions, solver);
          })
          // Sign-extension preserves the sign of its operand.
          // NOTE(review): forwarding through TruncIOp assumes the value fits in
          // the narrower type; truncating a large non-negative value can give a
          // negative result. Confirm this is acceptable.
          .Case<arith::TruncIOp, arith::ExtSIOp>([&](Operation *unaryOp) {
            return verifyNonNegativeExpr(unaryOp->getOperand(0), assumptions,
                                         solver);
          })
          // Values reinterpreted from pointers or raw bits are not trusted.
          .Case<triton::PtrToIntOp, triton::BitcastOp>(
              [&](auto) { return false; })
          // Unsigned-producing ops are treated as non-negative.
          // NOTE(review): an unsigned result with its top bit set would read
          // as negative when interpreted as a signed offset; confirm.
          .Case<arith::CeilDivUIOp, arith::DivUIOp, arith::ExtUIOp,
                arith::FPToUIOp, arith::MaxUIOp, arith::MinUIOp, arith::RemUIOp,
                arith::ShRUIOp>(
              [&](auto uOp) { return true; })
          // Non-negative when both operands are.
          // NOTE(review): possible signed overflow of addi/muli is not
          // considered here.
          .Case<arith::AddIOp, arith::MinSIOp, arith::MulIOp, arith::DivSIOp>(
              [&](Operation *binOp) {
                return verifyNonNegativeExpr(binOp->getOperand(0), assumptions,
                                             solver) &&
                       verifyNonNegativeExpr(binOp->getOperand(1), assumptions,
                                             solver);
              })
          // scf.if result: the value yielded at the same index must be
          // non-negative in both the then and else regions.
          .Case<scf::IfOp>([&](auto ifOp) {
            auto results = ifOp.getResults();
            auto it = std::find(results.begin(), results.end(), expr);
            assert(it != results.end() && "expr should be the result of ifOp");
            auto resultIdx = it - results.begin();
            auto thenYield = cast<scf::YieldOp>(ifOp.thenYield());
            auto elseYield = cast<scf::YieldOp>(ifOp.elseYield());
            return verifyNonNegativeExpr(thenYield->getOperand(resultIdx),
                                         assumptions, solver) &&
                   verifyNonNegativeExpr(elseYield->getOperand(resultIdx),
                                         assumptions, solver);
          })
          // lhs - rhs is only known non-negative through an assumption that
          // lhs >= rhs.
          .Case<arith::SubIOp>([&](auto op) {
            return verifyNonSmallerByAssumption(op.getLhs(), assumptions,
                                                op.getRhs());
          })
          // A slice is non-negative if its source (operand 0) is.
          .Case<triton::amdgpu::ExtractSliceOp>([&](auto op) {
            return verifyNonNegativeExpr(op->getOperand(0), assumptions,
                                         solver);
          })
          .Default([&](Operation *) {
            LDBG(" Unhandled op, cannot assume non-negative");
            return false;
          });
  return nonNegative;
}
/// Decides whether a memory access through `ptr` can be turned into an AMD
/// buffer op. All of the following must hold:
///   - `ptr` is `tt.addptr(tt.splat(base), offsets)`, i.e. one scalar base
///     pointer plus a ranked tensor of offsets;
///   - the offsets have a 32-bit element type;
///   - the offsets are provably non-negative (verifyNonNegativeExpr).
bool canUseBufferOps(Value ptr,
                     const DenseMap<Value, SetVector<Operation *>> &assumptions,
                     std::shared_ptr<DataFlowSolver> solver) {
  LDBG("Buffer op checks for: " << ptr);
  auto addPtr = ptr.getDefiningOp<triton::AddPtrOp>();
  if (!addPtr || !addPtr.getPtr().getDefiningOp<triton::SplatOp>())
    return false;
  LDBG("Pattern matched");

  Value offsets = addPtr.getOffset();
  auto offsetsTy = cast<RankedTensorType>(offsets.getType());
  if (offsetsTy.getElementTypeBitWidth() != 32)
    return false;
  LDBG("32 bit offset");

  return verifyNonNegativeExpr(offsets, assumptions, std::move(solver));
}
/// Tries to recover the scalar block stride baked into an offset tensor of
/// the form
///   addi(..., broadcast(muli(..., splat(%stride), ...)), ...)
/// Operands of the addi and of the muli are searched in order, and the first
/// match returns `%stride`. A null Value is returned when nothing matches.
/// `loc` and `rewriter` are currently unused.
Value getBlockStride(Location loc, Value offset, PatternRewriter &rewriter) {
  auto add = offset.getDefiningOp<arith::AddIOp>();
  if (!add)
    return nullptr;

  for (Value addOperand : add.getOperands()) {
    auto bcast = addOperand.getDefiningOp<tt::BroadcastOp>();
    if (!bcast)
      continue;
    auto mul = bcast.getSrc().getDefiningOp<arith::MulIOp>();
    if (!mul)
      continue;
    for (Value mulOperand : mul.getOperands()) {
      if (auto splat = mulOperand.getDefiningOp<tt::SplatOp>())
        return splat.getSrc();
    }
  }
  return nullptr;
}
/// Rewrites `tt.atomic_cas` whose pointer is `addptr(splat(base), offsets)`
/// into `amdgpu.buffer_atomic_cas`.
///
/// The rewrite only fires when:
///   - canUseBufferOps accepts the pointer;
///   - the scope is GPU or CTA;
///   - the memory semantic is RELAXED, RELEASE, ACQUIRE or ACQUIRE_RELEASE;
///   - the element type is i32 or i64;
///   - the access width is at least 32 bits (axis-info vector size times the
///     element width for tensors, the plain bit width for scalars).
struct ConvertTritonAtomicCASOpToBufferAtomicCAS
    : public mlir::OpRewritePattern<triton::AtomicCASOp> {
  using OpRewritePattern::OpRewritePattern;

  ConvertTritonAtomicCASOpToBufferAtomicCAS(
      mlir::MLIRContext *context,
      DenseMap<Value, SetVector<Operation *>> &assumptions,
      ModuleAxisInfoAnalysis &axisAnalysisPass,
      std::shared_ptr<DataFlowSolver> solver)
      : mlir::OpRewritePattern<triton::AtomicCASOp>(context),
        assumptions(assumptions), axisAnalysisPass(axisAnalysisPass),
        solver(std::move(solver)) {}

  mlir::LogicalResult
  matchAndRewrite(triton::AtomicCASOp op,
                  PatternRewriter &rewriter) const override {
    LDBG("Try to convert: " << op);
    Value ptr = op.getPtr();
    const auto memSem = op.getSem();
    const auto memScope = op.getScope();

    if (!canUseBufferOps(ptr, assumptions, solver))
      return rewriter.notifyMatchFailure(op, "canUseBufferOps check failed");

    if (memScope != MemSyncScope::GPU && memScope != MemSyncScope::CTA)
      return rewriter.notifyMatchFailure(op, "CAS with unsupported scope");
    LDBG("CAS supported scope");

    const bool semSupported =
        memSem == MemSemantic::RELAXED || memSem == MemSemantic::RELEASE ||
        memSem == MemSemantic::ACQUIRE ||
        memSem == MemSemantic::ACQUIRE_RELEASE;
    if (!semSupported)
      return rewriter.notifyMatchFailure(
          op, "CAS with unsupported memory ordering");

    // canUseBufferOps has already matched addptr(splat(base), offsets).
    auto addPtr = ptr.getDefiningOp<triton::AddPtrOp>();
    Value offsets = addPtr.getOffset();
    Value basePtr = addPtr.getPtr().getDefiningOp<triton::SplatOp>().getSrc();

    Type elemTy = getElementTypeOrSelf(op.getVal());
    if (!elemTy.isInteger(32) && !elemTy.isInteger(64))
      return rewriter.notifyMatchFailure(op, "AtomicCAS with unsupported type");
    LDBG("AtomicCAS supported type");

    auto valTy = op.getVal().getType();
    int accessBits = 0;
    if (auto tensorTy = dyn_cast<RankedTensorType>(valTy))
      accessBits = getVectorSize(basePtr, offsets, axisAnalysisPass) *
                   tensorTy.getElementTypeBitWidth();
    else
      accessBits = valTy.getIntOrFloatBitWidth();
    if (accessBits < 32)
      return rewriter.notifyMatchFailure(
          op, "BufferAtomicCAS requires opBitWidth >= 32");

    Value blockStride = getBlockStride(op->getLoc(), offsets, rewriter);
    rewriter.replaceOpWithNewOp<triton::amdgpu::BufferAtomicCASOp>(
        op, op.getVal().getType(), basePtr, offsets, op.getCmp(), op.getVal(),
        blockStride, memSem, memScope);
    return success();
  }

private:
  // Borrowed reference: the map is owned by the caller and must outlive this
  // pattern.
  const DenseMap<Value, SetVector<Operation *>> &assumptions;
  ModuleAxisInfoAnalysis &axisAnalysisPass;
  std::shared_ptr<DataFlowSolver> solver;
};
/// Rewrites `tt.atomic_rmw` whose pointer is `addptr(splat(base), offsets)`
/// into `amdgpu.buffer_atomic_rmw`.
///
/// Conditions, checked in this order:
///   - canUseBufferOps accepts the pointer;
///   - the scope is GPU or CTA;
///   - the memory semantic is RELAXED, RELEASE, ACQUIRE or ACQUIRE_RELEASE;
///   - the element type is f16, bf16, f32, f64, i32 or i64;
///   - bf16 FADD is rejected when `isaFamily` is CDNA3;
///   - f16/bf16 need an even vector size (axis info on `ptr`);
///   - the RMW kind is AND, OR, XOR, ADD, FADD, MAX, MIN, UMAX, UMIN or XCHG;
///   - the access width is at least 32 bits.
/// An all-ones splat mask is dropped instead of being forwarded.
struct ConvertTritonAtomicRMWOpToBufferAtomicRMW
    : public mlir::OpRewritePattern<triton::AtomicRMWOp> {
  using OpRewritePattern::OpRewritePattern;

  ConvertTritonAtomicRMWOpToBufferAtomicRMW(
      mlir::MLIRContext *context,
      DenseMap<Value, SetVector<Operation *>> &assumptions,
      ModuleAxisInfoAnalysis &axisAnalysisPass,
      std::shared_ptr<DataFlowSolver> solver, ISAFamily isaFamily)
      : mlir::OpRewritePattern<triton::AtomicRMWOp>(context),
        assumptions(assumptions), axisAnalysisPass(axisAnalysisPass),
        solver(std::move(solver)), isaFamily(isaFamily) {}

  mlir::LogicalResult
  matchAndRewrite(triton::AtomicRMWOp op,
                  PatternRewriter &rewriter) const override {
    LDBG("Try to convert: " << op);
    Value ptr = op.getPtr();
    const auto rmwKind = op.getAtomicRmwOp();
    const auto memSem = op.getSem();
    const auto memScope = op.getScope();

    if (!canUseBufferOps(ptr, assumptions, solver))
      return rewriter.notifyMatchFailure(op, "canUseBufferOps check failed");

    if (memScope != MemSyncScope::GPU && memScope != MemSyncScope::CTA)
      return rewriter.notifyMatchFailure(op, "RMW with unsupported scope");
    LDBG("RMW supported scope");

    const bool semSupported =
        memSem == MemSemantic::RELAXED || memSem == MemSemantic::RELEASE ||
        memSem == MemSemantic::ACQUIRE ||
        memSem == MemSemantic::ACQUIRE_RELEASE;
    if (!semSupported)
      return rewriter.notifyMatchFailure(
          op, "RMW with unsupported memory ordering");

    // canUseBufferOps has already matched addptr(splat(base), offsets).
    auto addPtr = ptr.getDefiningOp<triton::AddPtrOp>();
    Value offsets = addPtr.getOffset();
    Value basePtr = addPtr.getPtr().getDefiningOp<triton::SplatOp>().getSrc();

    Type elemTy = getElementTypeOrSelf(op.getVal());
    const bool is16BitFloat = elemTy.isF16() || elemTy.isBF16();
    const bool typeSupported = is16BitFloat || elemTy.isF32() ||
                               elemTy.isF64() || elemTy.isInteger(32) ||
                               elemTy.isInteger(64);
    if (!typeSupported)
      return rewriter.notifyMatchFailure(op, "RMW with unsupported type");
    LDBG("RMW supported type");

    if (isaFamily == ISAFamily::CDNA3 && elemTy.isBF16() &&
        rmwKind == RMWOp::FADD)
      return rewriter.notifyMatchFailure(op, "RMW FADD does not support bf16");
    LDBG("RMW FADD supported 16-bit type");

    auto vecSize = getVectorSize(ptr, axisAnalysisPass);
    if (is16BitFloat && vecSize % 2 != 0)
      return rewriter.notifyMatchFailure(
          op, "RMW float 16 dtypes must be aligned by 2");
    LDBG("RMW passed alignment check");

    bool kindSupported = false;
    switch (rmwKind) {
    case RMWOp::AND:
    case RMWOp::OR:
    case RMWOp::XOR:
    case RMWOp::ADD:
    case RMWOp::FADD:
    case RMWOp::MAX:
    case RMWOp::MIN:
    case RMWOp::UMAX:
    case RMWOp::UMIN:
    case RMWOp::XCHG:
      kindSupported = true;
      break;
    default:
      break;
    }
    if (!kindSupported) {
      auto rmwOpStr = stringifyRMWOp(rmwKind).str();
      return rewriter.notifyMatchFailure(op, "RMW with unsupported op: " +
                                                 rmwOpStr);
    }
    LDBG("RMW supported Op");

    auto valTy = op.getVal().getType();
    int accessBits = 0;
    if (auto tensorTy = dyn_cast<RankedTensorType>(valTy))
      accessBits = vecSize * tensorTy.getElementTypeBitWidth();
    else
      accessBits = valTy.getIntOrFloatBitWidth();
    if (accessBits < 32)
      return rewriter.notifyMatchFailure(op, "RMW requires opBitWidth >= 32");

    // An all-ones splat mask masks nothing, so it is not forwarded.
    Value mask = op.getMask();
    if (mask && isSplatOneConstTensor(mask))
      mask = Value();

    Value blockStride = getBlockStride(op->getLoc(), offsets, rewriter);
    rewriter.replaceOpWithNewOp<triton::amdgpu::BufferAtomicRMWOp>(
        op, op.getVal().getType(), rmwKind, basePtr, offsets, op.getVal(),
        blockStride, memSem, memScope, mask);
    return success();
  }

private:
  // Private copy of the assumption map taken at construction time.
  DenseMap<Value, SetVector<Operation *>> assumptions;
  ModuleAxisInfoAnalysis &axisAnalysisPass;
  std::shared_ptr<DataFlowSolver> solver;
  ISAFamily isaFamily;
};
/// Type-dependent constant `false`. Using it in a `static_assert` inside an
/// `if constexpr` fallback branch makes the assertion fire only when that
/// branch is actually instantiated.
template <typename T>
struct always_false : std::integral_constant<bool, false> {};
/// Rewrites a global load-like op whose pointer (operand 0) is
/// `addptr(splat(base), offsets)` into the matching AMD buffer op:
///   - `triton::LoadOp` -> `triton::amdgpu::BufferLoadOp`
///   - `triton::gpu::AsyncCopyGlobalToLocalOp`
///       -> `triton::amdgpu::BufferLoadToLocalOp`
/// A constant-zero `other` and an all-ones splat mask are not forwarded (they
/// are passed as null values). Instantiating with any other `SourceOp` is a
/// compile-time error.
template <typename SourceOp>
struct ConvertTritonLoadToBufferLoad : public mlir::OpRewritePattern<SourceOp> {
  using OpRewritePattern<SourceOp>::OpRewritePattern;

  ConvertTritonLoadToBufferLoad(
      mlir::MLIRContext *context,
      DenseMap<Value, SetVector<Operation *>> &assumptions,
      std::shared_ptr<DataFlowSolver> solver)
      : mlir::OpRewritePattern<SourceOp>(context), assumptions(assumptions),
        solver(std::move(solver)) {}

  mlir::LogicalResult
  matchAndRewrite(SourceOp op, PatternRewriter &rewriter) const override {
    LDBG("Try to convert: " << op);
    Value ptr = op.getOperand(0);
    if (!canUseBufferOps(ptr, assumptions, solver)) {
      LDBG("Failed to convert: " << op);
      return rewriter.notifyMatchFailure(op, "Failed to convert LoadOp");
    }

    // canUseBufferOps has already matched addptr(splat(base), offsets).
    auto addPtr = ptr.getDefiningOp<triton::AddPtrOp>();
    Value offsets = addPtr.getOffset();
    Value basePtr = addPtr.getPtr().getDefiningOp<triton::SplatOp>().getSrc();

    Value other = op.getOther();
    if (other && isZeroConst(other))
      other = Value();
    Value mask = op.getMask();
    if (mask && isSplatOneConstTensor(mask))
      mask = Value();
    Value blockStride = getBlockStride(op->getLoc(), offsets, rewriter);

    // Immediately-invoked lambda so each `if constexpr` branch may return its
    // own op type.
    auto newOp = [&]() {
      if constexpr (std::is_same_v<SourceOp, triton::LoadOp>) {
        return rewriter.create<triton::amdgpu::BufferLoadOp>(
            op->getLoc(), op.getType(), basePtr, offsets, blockStride,
            op.getCache(), mask, other);
      } else if constexpr (std::is_same_v<
                               SourceOp,
                               triton::gpu::AsyncCopyGlobalToLocalOp>) {
        return rewriter.create<triton::amdgpu::BufferLoadToLocalOp>(
            op->getLoc(), op.getType(), op.getResult(), basePtr, offsets, mask,
            other, blockStride, op.getCache());
      } else {
        static_assert(always_false<SourceOp>::value,
                      "Unsupported type in ConvertTritonLoadToBufferLoad");
      }
    }();
    assert(newOp);

    rewriter.replaceOp(op, newOp);
    return success();
  }

private:
  // Private copy of the assumption map taken at construction time.
  DenseMap<Value, SetVector<Operation *>> assumptions;
  std::shared_ptr<DataFlowSolver> solver;
};
/// Rewrites `tt.store` whose pointer is `addptr(splat(base), offsets)` into
/// `amdgpu.buffer_store` when canUseBufferOps accepts the pointer. An
/// all-ones splat mask is not forwarded (it is passed as a null value).
struct ConvertTritonStoreToBufferStore
    : public mlir::OpRewritePattern<triton::StoreOp> {
  using OpRewritePattern::OpRewritePattern;

  ConvertTritonStoreToBufferStore(
      mlir::MLIRContext *context,
      DenseMap<Value, SetVector<Operation *>> &assumptions,
      std::shared_ptr<DataFlowSolver> solver)
      : mlir::OpRewritePattern<triton::StoreOp>(context),
        assumptions(assumptions), solver(std::move(solver)) {}

  mlir::LogicalResult
  matchAndRewrite(triton::StoreOp op,
                  PatternRewriter &rewriter) const override {
    LDBG("Try to convert: " << op);
    Value ptr = op.getPtr();
    if (!canUseBufferOps(ptr, assumptions, solver)) {
      LDBG("Failed to convert: " << op);
      return rewriter.notifyMatchFailure(op, "Failed to convert StoreOp");
    }

    // canUseBufferOps has already matched addptr(splat(base), offsets).
    auto addPtr = ptr.getDefiningOp<triton::AddPtrOp>();
    Value offsets = addPtr.getOffset();
    Value basePtr = addPtr.getPtr().getDefiningOp<triton::SplatOp>().getSrc();

    Value mask = op.getMask();
    if (mask && isSplatOneConstTensor(mask))
      mask = Value();

    Value blockStride = getBlockStride(op->getLoc(), offsets, rewriter);
    rewriter.replaceOpWithNewOp<triton::amdgpu::BufferStoreOp>(
        op, op.getValue(), basePtr, offsets, blockStride, op.getCache(), mask);
    return success();
  }

private:
  // Private copy of the assumption map taken at construction time.
  DenseMap<Value, SetVector<Operation *>> assumptions;
  std::shared_ptr<DataFlowSolver> solver;
};
}
/// Pass that rewrites eligible global memory ops in the module into AMD buffer
/// ops.
///
/// Integer range analysis, seeded with the assumptions collected from the
/// module, runs once up front so the patterns can prove that offsets are
/// non-negative. Loads, async global->local copies and stores are always
/// considered. The buffer atomic patterns (RMW and CAS) are registered only
/// when `allowBufferAtomics` is set and the target ISA family is CDNA3.
struct TritonAMDGPUConvertToBufferOpsPass
    : impl::TritonAMDGPUConvertToBufferOpsBase<
          TritonAMDGPUConvertToBufferOpsPass> {
  using Base::Base;

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    ModuleOp mod = getOperation();

    // Gather assumption ops and run integer range analysis over the module.
    DenseMap<Value, SetVector<Operation *>> assumptions =
        AMD::TritonIntegerRangeAnalysis::collectAssumptions(getOperation());
    std::shared_ptr<DataFlowSolver> solver = createDataFlowSolver();
    AMD::TritonIntegerRangeAnalysis *rangeAnalysis =
        solver->load<AMD::TritonIntegerRangeAnalysis>(assumptions);
    AMD::initializeFuncOps(mod, rangeAnalysis);
    if (failed(solver->initializeAndRun(getOperation())))
      return signalPassFailure();

    AMD::ModuleAxisInfoAnalysis axisInfoAnalysis(mod);
    patterns.add<ConvertTritonLoadToBufferLoad<tt::LoadOp>,
                 ConvertTritonLoadToBufferLoad<ttg::AsyncCopyGlobalToLocalOp>,
                 ConvertTritonStoreToBufferStore>(context, assumptions, solver);

    // Both buffer atomic patterns share one gate. The CAS rewrite also emits
    // a buffer atomic, so it must respect `allowBufferAtomics` and the CDNA3
    // restriction exactly like the RMW rewrite. Keep both registrations inside
    // the braces.
    triton::AMD::ISAFamily isaFamily =
        triton::AMD::deduceISAFamily(archGenerationName);
    const bool enableBufferAtomics =
        this->allowBufferAtomics && ISAFamily::CDNA3 == isaFamily;
    if (enableBufferAtomics) {
      patterns.add<ConvertTritonAtomicRMWOpToBufferAtomicRMW>(
          context, assumptions, axisInfoAnalysis, solver, isaFamily);
      patterns.add<ConvertTritonAtomicCASOpToBufferAtomicCAS>(
          context, assumptions, axisInfoAnalysis, solver);
    }

    if (applyPatternsGreedily(mod, std::move(patterns)).failed())
      signalPassFailure();
  }
};
}