#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "triton/Dialect/TritonGPU/IR/Types.h"
#include "triton/Tools/GenericSwizzling.h"
#include "triton/Tools/LinearLayout.h"
#include "triton/Tools/StrUtil.h"
#include "llvm/ADT/STLExtras.h"
// Debug-output plumbing. NOTE(review): because these macros live in a header,
// DEBUG_TYPE is set to "ttgpu_to_llvm" in every file that includes it.
#define DEBUG_TYPE "ttgpu_to_llvm"
// Starts a debug line on llvm::dbgs() prefixed with "[ttgpu_to_llvm]: ".
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// Streams X followed by a newline, wrapped in LLVM_DEBUG so it only runs when
// debug output is enabled. X is spliced in unparenthesized, so it may itself
// be a `<<` chain.
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
// NOTE(review): header-scope using-directives leak into every includer;
// audit users before removing them.
using namespace mlir;
using namespace mlir::triton;
namespace mlir::LLVM {
using namespace mlir::triton;
/// Constant factories. Each returns a constant Value created at `loc` through
/// the given builder; the scalar type is the one named in the function
/// (i1/i32/i64/f16/bf16/f32/f64). NOTE(review): implementations are not in
/// this header — confirm the exact op/attribute they emit.
Value createConstantI1(Location loc, OpBuilder &rewriter, bool v);
Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v);
Value createConstantI64(Location loc, OpBuilder &rewriter, int64_t v);
/// The f16/bf16 factories take a `float`; presumably it is narrowed to the
/// target format — TODO confirm rounding behavior.
Value createConstantF16(Location loc, OpBuilder &rewriter, float v);
Value createConstantBF16(Location loc, OpBuilder &rewriter, float v);
Value createConstantF32(Location loc, OpBuilder &rewriter, float v);
Value createConstantF64(Location loc, OpBuilder &rewriter, double v);
/// NaN constant of `type` (presumably `type` must be a floating-point type).
Value createNaNConstant(Location loc, OpBuilder &rewriter, Type type);
/// Integer constant `value` whose type is chosen via `converter` —
/// presumably its index type; verify against the implementation.
Value createIndexConstant(OpBuilder &builder, Location loc,
const TypeConverter *converter, int64_t value);
/// Integer constant `value` of `width` bits.
Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
int64_t value);
/// Creates an LLVM::CallOp calling `funcOp` with `args`.
LLVM::CallOp createLLVMCallOp(OpBuilder &builder, Location loc,
LLVMFuncOp funcOp, ValueRange args);
/// Creates an LLVM::CallIntrinsicOp for the intrinsic named `intrinsic` with
/// operands `args`. NOTE(review): `types` are presumably the result types;
/// confirm in the implementation.
LLVM::CallIntrinsicOp
createLLVMIntrinsicCallOp(OpBuilder &builder, Location loc, StringRef intrinsic,
TypeRange types, ValueRange args);
} // namespace mlir::LLVM
namespace mlir::triton {
struct TritonLLVMOpBuilder {
TritonLLVMOpBuilder(Location loc, OpBuilder &builder)
: loc(loc), builder(&builder) {}
template <typename... Args> LLVM::SIToFPOp inttofloat(Args &&...args) {
return builder->create<LLVM::SIToFPOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::IntToPtrOp inttoptr(Args &&...args) {
return builder->create<LLVM::IntToPtrOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::PtrToIntOp ptrtoint(Args &&...args) {
return builder->create<LLVM::PtrToIntOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::ZExtOp zext(Args &&...args) {
return builder->create<LLVM::ZExtOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::SExtOp sext(Args &&...args) {
return builder->create<LLVM::SExtOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::FPExtOp fpext(Args &&...args) {
return builder->create<LLVM::FPExtOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::FPTruncOp fptrunc(Args &&...args) {
return builder->create<LLVM::FPTruncOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::TruncOp trunc(Args &&...args) {
return builder->create<LLVM::TruncOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::UDivOp udiv(Args &&...args) {
return builder->create<LLVM::UDivOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::SDivOp sdiv(Args &&...args) {
return builder->create<LLVM::SDivOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::URemOp urem(Args &&...args) {
return builder->create<LLVM::URemOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::AddOp add(Args &&...args) {
return builder->create<LLVM::AddOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::SubOp sub(Args &&...args) {
return builder->create<LLVM::SubOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::FAddOp fadd(Args &&...args) {
return builder->create<LLVM::FAddOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::MulOp mul(Args &&...args) {
return builder->create<LLVM::MulOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::FMulOp fmul(Args &&...args) {
return builder->create<LLVM::FMulOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::FMAOp fma(Args &&...args) {
return builder->create<LLVM::FMAOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::FNegOp neg(Args &&...args) {
return builder->create<LLVM::FNegOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::SMaxOp smax(Args &&...args) {
return builder->create<LLVM::SMaxOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::UMaxOp umax(Args &&...args) {
return builder->create<LLVM::UMaxOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::MaxNumOp fmax(Args &&...args) {
return builder->create<LLVM::MaxNumOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::SMinOp smin(Args &&...args) {
return builder->create<LLVM::SMinOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::UMinOp umin(Args &&...args) {
return builder->create<LLVM::UMinOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::MinNumOp fmin(Args &&...args) {
return builder->create<LLVM::MinNumOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::ShlOp shl(Args &&...args) {
return builder->create<LLVM::ShlOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::LShrOp lshr(Args &&...args) {
return builder->create<LLVM::LShrOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::AShrOp ashr(Args &&...args) {
return builder->create<LLVM::AShrOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::AndOp and_(Args &&...args) {
return builder->create<LLVM::AndOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::XOrOp xor_(Args &&...args) {
return builder->create<LLVM::XOrOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::OrOp or_(Args &&...args) {
return builder->create<LLVM::OrOp>(loc, std::forward<Args>(args)...);
}
LLVM::BitcastOp bitcast(Value val, Type type) {
return builder->create<LLVM::BitcastOp>(loc, type, val);
}
template <typename... Args>
LLVM::AddrSpaceCastOp addrspacecast(Args &&...args) {
return builder->create<LLVM::AddrSpaceCastOp>(loc,
std::forward<Args>(args)...);
}
template <typename... Args> LLVM::GEPOp gep(Args &&...args) {
return builder->create<LLVM::GEPOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::InsertValueOp insert_val(Args &&...args) {
return builder->create<LLVM::InsertValueOp>(loc,
std::forward<Args>(args)...);
}
template <typename... Args> LLVM::ExtractValueOp extract_val(Args &&...args) {
return builder->create<LLVM::ExtractValueOp>(loc,
std::forward<Args>(args)...);
}
template <typename... Args>
LLVM::InsertElementOp insert_element(Args &&...args) {
return builder->create<LLVM::InsertElementOp>(loc,
std::forward<Args>(args)...);
}
template <typename... Args>
LLVM::ExtractElementOp extract_element(Args &&...args) {
return builder->create<LLVM::ExtractElementOp>(loc,
std::forward<Args>(args)...);
}
template <typename... Args> LLVM::LoadOp load(Args &&...args) {
return builder->create<LLVM::LoadOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::StoreOp store(Args &&...args) {
return builder->create<LLVM::StoreOp>(loc, std::forward<Args>(args)...);
}
LLVM::FCmpOp fcmp_ogt(Value lhs, Value rhs) {
return builder->create<LLVM::FCmpOp>(loc, builder->getI1Type(),
LLVM::FCmpPredicate::ogt, lhs, rhs);
}
LLVM::FCmpOp fcmp_olt(Value lhs, Value rhs) {
return builder->create<LLVM::FCmpOp>(loc, builder->getI1Type(),
LLVM::FCmpPredicate::olt, lhs, rhs);
}
LLVM::FCmpOp fcmp_eq(Value lhs, Value rhs) {
return builder->create<LLVM::FCmpOp>(loc, builder->getI1Type(),
LLVM::FCmpPredicate::oeq, lhs, rhs);
}
template <typename... Args> LLVM::ICmpOp icmp_eq(Args &&...args) {
return builder->create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq,
std::forward<Args>(args)...);
}
template <typename... Args> LLVM::ICmpOp icmp_ne(Args &&...args) {
return builder->create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne,
std::forward<Args>(args)...);
}
template <typename... Args> LLVM::ICmpOp icmp_slt(Args &&...args) {
return builder->create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt,
std::forward<Args>(args)...);
}
template <typename... Args> LLVM::ICmpOp icmp_sle(Args &&...args) {
return builder->create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sle,
std::forward<Args>(args)...);
}
template <typename... Args> LLVM::ICmpOp icmp_sgt(Args &&...args) {
return builder->create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt,
std::forward<Args>(args)...);
}
template <typename... Args> LLVM::ICmpOp icmp_sge(Args &&...args) {
return builder->create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sge,
std::forward<Args>(args)...);
}
template <typename... Args> LLVM::ICmpOp icmp_ult(Args &&...args) {
return builder->create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ult,
std::forward<Args>(args)...);
}
template <typename... Args> LLVM::ICmpOp icmp_ule(Args &&...args) {
return builder->create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ule,
std::forward<Args>(args)...);
}
template <typename... Args> LLVM::ICmpOp icmp_ugt(Args &&...args) {
return builder->create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ugt,
std::forward<Args>(args)...);
}
template <typename... Args> LLVM::ICmpOp icmp_uge(Args &&...args) {
return builder->create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::uge,
std::forward<Args>(args)...);
}
template <typename... Args> LLVM::SelectOp select(Args &&...args) {
return builder->create<LLVM::SelectOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::AddressOfOp address_of(Args &&...args) {
return builder->create<LLVM::AddressOfOp>(loc, std::forward<Args>(args)...);
}
mlir::gpu::BarrierOp barrier() {
return builder->create<mlir::gpu::BarrierOp>(loc);
}
template <typename... Args> LLVM::UndefOp undef(Args &&...args) {
return builder->create<LLVM::UndefOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::ZeroOp null(Args &&...args) {
return builder->create<LLVM::ZeroOp>(loc, std::forward<Args>(args)...);
}
template <typename... Args> LLVM::CallOp call(Args &&...args) {
return builder->create<LLVM::CallOp>(loc, std::forward<Args>(args)...);
}
Value int_val(short bitwidth, int64_t val) {
Type ty = builder->getIntegerType(bitwidth);
return builder->create<LLVM::ConstantOp>(loc, ty,
builder->getIntegerAttr(ty, val));
}
Value i1_val(int64_t val) { return int_val(1, val); }
Value true_val() { return int_val(1, true); }
Value false_val() { return int_val(1, false); }
Value f16_val(float v) { return LLVM::createConstantF16(loc, *builder, v); }
Value bf16_val(float v) { return LLVM::createConstantBF16(loc, *builder, v); }
Value f32_val(float v) { return LLVM::createConstantF32(loc, *builder, v); }
Value f64_val(double v) { return LLVM::createConstantF64(loc, *builder, v); }
Value i8_val(int64_t val) { return int_val(8, val); }
Value i16_val(int64_t val) { return int_val(16, val); }
Value i32_val(int64_t val) { return int_val(32, val); }
Value i64_val(int64_t val) { return int_val(64, val); }
Location loc;
OpBuilder *builder;
};
class TritonLLVMIRRewriter : public IRRewriter, public TritonLLVMOpBuilder {
public:
template <typename... Args>
TritonLLVMIRRewriter(Location loc, Args &&...args)
: IRRewriter(std::forward<Args>(args)...),
TritonLLVMOpBuilder(loc, *this) {}
Location getLoc() const { return loc; }
void setLoc(Location loc) { this->loc = loc; }
template <typename OpTy, typename... Args> OpTy create(Args &&...args) {
return OpBuilder::create<OpTy>(loc, std::forward<Args>(args)...);
}
};
}
// Shorthand type/attribute constructors for lowering code. These are
// unhygienic: most expand to calls on a variable named `rewriter` that must
// be in scope at the use site, and `struct_ty` / `str_attr` also require an
// in-scope `ctx`.

// Pointer, void, aggregate and vector types.
#define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__)
#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
#define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__)
#define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count)
// Note the argument swap: vec_ty(type, num) calls VectorType::get(num, type).
#define vec_ty(type, num) VectorType::get(num, type)

// Integer types. The ui* variants pass `false` as getIntegerType's second
// argument.
#define int_ty(width) rewriter.getIntegerType(width)
#define i1_ty rewriter.getI1Type()
#define i8_ty rewriter.getIntegerType(8)
#define i16_ty rewriter.getIntegerType(16)
#define i32_ty rewriter.getIntegerType(32)
#define i64_ty rewriter.getIntegerType(64)
#define ui32_ty rewriter.getIntegerType(32, false)
#define ui64_ty rewriter.getIntegerType(64, false)

// Floating-point types.
#define f16_ty rewriter.getF16Type()
#define bf16_ty rewriter.getBF16Type()
#define f32_ty rewriter.getF32Type()
#define f64_ty rewriter.getF64Type()

// Attributes.
#define i32_arr_attr(...) rewriter.getI32ArrayAttr({__VA_ARGS__})
#define i64_arr_attr(...) rewriter.getI64ArrayAttr({__VA_ARGS__})
#define str_attr(str) ::mlir::StringAttr::get(ctx, (str))
namespace mlir {
// Distinct negative sentinel offsets. NOTE(review): being negative, they
// presumably cannot collide with a real (non-negative) offset — confirm how
// callers interpret them.
// `inline constexpr` (C++17) gives each constant one definition shared by
// every translation unit that includes this header. Plain namespace-scope
// `constexpr` would give each TU its own internal-linkage copy, which can
// cause ODR violations when an inline function odr-uses them.
inline constexpr int kProfileScratchBufferOffset = -1;
inline constexpr int kGlobalScratchBufferOffset = -2;
inline constexpr int kSharedMemoryOffset = -3;
namespace triton {
namespace gpu {
/// Returns two lists of LocalMemOpTile for `targetInfo` and an element
/// `bitwidth`. NOTE(review): the name suggests (source tiles, destination
/// tiles) in that order — confirm in the implementation.
std::pair<SmallVector<LocalMemOpTile>, SmallVector<LocalMemOpTile>>
getSrcDstTiles(const TargetInfoBase &targetInfo, int bitwidth);
/// Builds a function type from `resultType` and the types of `operands`.
Type getFunctionType(Type resultType, ValueRange operands);
/// Returns an LLVMFuncOp named `funcName` of type `funcType`. Per the name,
/// it is created as an external declaration if not already present — TODO
/// confirm where it is inserted relative to `op`. `libname`/`libpath`
/// default to empty.
LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op,
StringRef funcName, Type funcType,
StringRef libname = "",
StringRef libpath = "");
/// Emits, through `b`, the product of linear layout `A` (viewed as a matrix)
/// with `x`. NOTE(review): the arithmetic domain (e.g. bitwise over GF(2))
/// is not visible here — confirm.
Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x);
} // namespace gpu
} // namespace triton
namespace LLVM {
using namespace mlir::triton;
/// A `base` Value, the element type addressed through it, and one offset
/// Value per dimension. Per its name and accessors, it describes a
/// shared-memory object.
class SharedMemoryObject {
public:
  SharedMemoryObject(Value base, Type baseElemType, ArrayRef<Value> offsets);
  /// Overload taking a `rank` instead of explicit offsets. NOTE(review):
  /// presumably materializes `rank` default (zero) offsets via `rewriter` at
  /// `loc`; confirm in the implementation.
  SharedMemoryObject(Value base, Type baseElemType, int64_t rank, Location loc,
                     RewriterBase &rewriter);

  /// Returns a copy of the per-dimension offsets.
  SmallVector<Value> getOffsets() const { return offsets; }
  Value getBase() const { return base; }
  Type getBaseElemType() const { return baseElemType; }

  /// Flattened Values and their types. NOTE(review): presumably base
  /// followed by offsets — confirm in the implementation.
  SmallVector<Value> getElems() const;
  SmallVector<Type> getTypes() const;

  static uint64_t getMaskSpanOffsets(triton::gpu::MemDescType srcTy);
  /// True exactly when getMaskSpanOffsets(srcTy) is non-zero.
  static bool isAffineSharedMemoryAccess(triton::gpu::MemDescType srcTy) {
    return getMaskSpanOffsets(srcTy) != 0;
  }

  Value getShmemOffset(Location loc, RewriterBase &rewriter,
                       triton::gpu::MemDescType srcTy) const;
  Value getShmemAffineBase(Location loc, RewriterBase &rewriter,
                           triton::gpu::MemDescType srcTy) const;

  /// Returns the offset Value stored for dimension `dim`
  /// (requires 0 <= dim < number of offsets).
  Value getCSwizzleOffset(int dim) const {
    // Compare as size_t only after the sign check: this avoids a
    // signed/unsigned comparison (-Wsign-compare) without ever accepting a
    // negative `dim`.
    assert(dim >= 0 && static_cast<size_t>(dim) < offsets.size());
    return offsets[dim];
  }
  Value getBaseBeforeSlice(int dim, Location loc, RewriterBase &rewriter) const;

private:
  Value base;
  Type baseElemType;
  // One Value per dimension, indexed by getCSwizzleOffset(dim).
  SmallVector<Value> offsets;
};
/// Packs `smemObj` into a single LLVM struct Value. NOTE(review): presumably
/// from smemObj.getElems()/getTypes(); confirm.
Value getStructFromSharedMemoryObject(Location loc,
const SharedMemoryObject &smemObj,
RewriterBase &rewriter);
/// Inverse of getStructFromSharedMemoryObject: rebuilds a SharedMemoryObject
/// from `llvmStruct`, with `elemTy` as its base element type.
SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc,
Value llvmStruct,
Type elemTy,
RewriterBase &rewriter);
/// delinearize overloads: split a linear index into one index per dimension
/// of `shape`. Overloads taking `order` use that dimension order; the others
/// use a default order not visible here (TODO confirm). Overloads returning
/// SmallVector<Value> emit IR; the all-`unsigned` overload is pure
/// computation.
SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc,
Value linear, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order);
SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc,
unsigned linear, ArrayRef<unsigned> shape);
SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc,
Value linear, ArrayRef<unsigned> shape);
SmallVector<unsigned> delinearize(unsigned linear, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order);
/// Layout-aware delinearize of `linear` along hardware dimension `dimName`.
/// NOTE(review): the meaning of the second tuple element is not visible
/// here — confirm in the implementation.
std::tuple<SmallVector<Value>, Value>
delinearize(RewriterBase &rewriter, Location loc,
triton::gpu::DistributedEncodingTrait layout,
ArrayRef<int64_t> shape, StringAttr dimName, Value linear);
/// linearize overloads: the inverse of delinearize, combining per-dimension
/// indices `multiDim` over `shape` (optionally in `order`) into one index.
Value linearize(RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDim,
ArrayRef<unsigned> shape, ArrayRef<unsigned> order);
Value linearize(RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDim,
ArrayRef<unsigned> shape);
size_t linearize(ArrayRef<unsigned> multiDim, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order);
/// Adds `content` to the module under `key` and returns a Value referring to
/// it. NOTE(review): presumably a global string plus a pointer to it.
Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key,
StringRef content);
/// Pointer getters for `funcOp`. NOTE(review): how each pointer is obtained
/// (function argument, global, intrinsic) is not visible here.
Value getStackPointer(RewriterBase &rewriter, FunctionOpInterface funcOp);
Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
const TargetInfoBase &targetInfo,
FunctionOpInterface funcOp, Value allocOffset);
Value getProfileScratchPtr(Location loc, RewriterBase &rewriter,
FunctionOpInterface funcOp);
/// Base Value of shared memory as seen from `op`, for target `target`.
Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
const TargetInfoBase &target, Operation *op);
/// Applies an MX-format `scale` to `v` in bf16 (per the name); `fastMath`
/// presumably permits a cheaper, less strict lowering — confirm.
Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale,
bool fastMath);
}
/// First thread id of the warp group containing `block`, or std::nullopt
/// when it cannot be determined. NOTE(review): semantics inferred from the
/// name and return type; confirm.
std::optional<int> getWarpGroupStartThreadId(Block *block);
/// Emit IR producing the current thread / lane id.
Value getThreadId(OpBuilder &rewriter, Location loc);
Value getLaneId(OpBuilder &rewriter, Location loc);
/// Presumably returns (laneId, warpId) in that order, as the name suggests.
std::pair<Value, Value> getLaneAndWarpId(OpBuilder &rewriter, Location loc);
// Re-export frequently used names into namespace mlir so lowering code can
// use them unqualified.
using LLVM::SharedMemoryObject;
using ::mlir::LLVM::delinearize;
using ::mlir::triton::gpu::AMDMfmaEncodingAttr;
using ::mlir::triton::gpu::AMDWmmaEncodingAttr;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::CTALayoutAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::NvidiaMmaEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
/// Presumably emits sum_i offsets[i] * strides[i] (a dot product); confirm.
Value dot(RewriterBase &rewriter, Location loc, ArrayRef<Value> offsets,
ArrayRef<Value> strides);
/// Evaluates `layout` on the (input-dimension name, Value) pairs in `indices`
/// and returns (output-dimension name, Value) pairs.
SmallVector<std::pair<StringAttr, Value>>
applyLinearLayout(Location loc, RewriterBase &rewriter,
const LinearLayout &layout,
ArrayRef<std::pair<StringAttr, Value>> indices);
/// Compile-time multi-dimensional offsets for `type` under `layout`.
SmallVector<SmallVector<unsigned>> emitOffsetForLayout(Attribute layout,
RankedTensorType type);
/// Emits per-element multi-dimensional index Values for `type` under
/// `layout`. `withCTAOffset` selects whether a CTA-dependent offset is
/// included (per the parameter name).
SmallVector<SmallVector<Value>>
emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
Attribute layout, RankedTensorType type, bool withCTAOffset);
/// Padding Value for `smemOffset` under the padded shared layout `layout`.
/// `offsetInBytes` says whether `smemOffset` is a byte offset (otherwise
/// presumably an element offset of `bitwidth`-bit elements — confirm).
Value emitPadding(Location loc, RewriterBase &rewriter,
triton::gpu::PaddedSharedEncodingAttr layout,
unsigned bitwidth, Value smemOffset, bool offsetInBytes);
[[nodiscard]] bool emitTransferBetweenRegistersAndShared(
RankedTensorType registerTy, triton::gpu::MemDescType sharedTy,
Type elemLlvmTy, std::optional<int32_t> maxVecElems,
const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter,
const TargetInfoBase &target,
std::function<void(VectorType, Value )> perVectorCallback);
[[nodiscard]] bool emitTransferBetweenRegistersAndShared(
LinearLayout ®Layout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
Value laneId, Value warpId,
std::function<void(VectorType, Value )> perVectorCallback);
/// Lowers a shared-memory load/store described by the linear layout `cvt`,
/// based at `smemBase` with element type `llvmElemTy`. `calcPaddedOffset`
/// maps an offset Value to a (padded) offset Value. `affineOffset` /
/// `maskSpanAffineOffset` presumably correspond to
/// SharedMemoryObject::getShmemAffineBase / getMaskSpanOffsets — confirm.
/// `localLoadOp` is optional (defaults to nullptr).
SmallVector<Value>
lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
// NOTE(review): presumably the values to store for a store, and unused for
// a load — confirm.
ArrayRef<Value> valsArray,
Type llvmElemTy, Value smemBase,
std::function<Value(Value)> calcPaddedOffset,
Value affineOffset, uint64_t maskSpanAffineOffset,
RewriterBase &rewriter, const TargetInfoBase &targetInfo,
Operation *localLoadOp = nullptr);
/// More general form with explicit `laneId` / `warpId`, an optional vector
/// width cap `maybeMaxVecElems`, and a `lowerInst` callback that emits the
/// memory instruction. `lowerInst` receives (rewriter, loc, values, Value,
/// int, VectorType) and returns Values. NOTE(review): the meaning of its
/// Value/int arguments (likely an address and a start index) is not visible
/// here — confirm.
SmallVector<Value> lowerLdSt(
Location loc, MLIRContext *ctx, LinearLayout cvt,
ArrayRef<Value> valsArray,
Type llvmElemTy, Value smemBase,
std::function<Value(Value)> calcPaddedOffset, Value affineOffset,
uint64_t maskSpanAffineOffset, Value laneId, Value warpId,
RewriterBase &rewriter, const TargetInfoBase &targetInfo,
std::optional<int> maybeMaxVecElems,
std::function<SmallVector<Value>(RewriterBase &, Location, ArrayRef<Value>,
Value, int, VectorType)>
lowerInst);
/// Convenience form that takes the memory descriptor type `srcTy` and the
/// SharedMemoryObject `smemObj` instead of explicit base/offset arguments.
SmallVector<Value>
lowerLocalLdSt(Location loc, MLIRContext *ctx,
LinearLayout cvt,
ArrayRef<Value> valsArray,
Type llvmElemTy, triton::gpu::MemDescType srcTy,
SharedMemoryObject smemObj, RewriterBase &rewriter,
const TargetInfoBase &targetInfo,
Operation *localLoadOp = nullptr);
/// Splits the LLVM struct Value `llvmStruct` into its element Values.
SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
RewriterBase &rewriter);
/// Packs `resultVals` into one Value whose LLVM type is derived from `type`
/// via `typeConverter` (presumably an LLVM struct — confirm).
Value packLLElements(Location loc, const LLVMTypeConverter *typeConverter,
ValueRange resultVals, RewriterBase &rewriter, Type type);
/// Vector counterparts: split / build an LLVM vector Value.
SmallVector<Value> unpackLLVector(Location loc, Value llvmVec,
RewriterBase &rewriter);
Value packLLVector(Location loc, ValueRange vals, RewriterBase &rewriter);
/// Maps a Triton RMWOp to an LLVM AtomicBinOp; std::nullopt presumably means
/// there is no direct equivalent.
std::optional<LLVM::AtomicBinOp> matchAtomicOp(RMWOp atomicOp);
/// Maps a Triton MemSemantic to an LLVM AtomicOrdering, if one exists.
std::optional<LLVM::AtomicOrdering> getMemoryOrdering(MemSemantic memOrdering);
/// Per-dimension bit masks of "free" variables, keyed by dimension name.
/// NOTE(review): semantics of the masks are defined elsewhere; see also
/// isCanonicalIndex.
llvm::MapVector<StringAttr, int32_t> getAllFreeVarMasks(MLIRContext *ctx);
llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks(Type type);
/// True when `index` shares no set bit with `freeVarMask`, i.e. clearing the
/// masked bits would leave `index` unchanged.
inline bool isCanonicalIndex(unsigned index, unsigned freeVarMask) {
  const unsigned overlap = index & freeVarMask;
  return !overlap;
}
/// Per its name, makes the warp-group regions nested under `op` isolated
/// from above. NOTE(review): which ops count as warp groups is defined in
/// the implementation.
void makeAllWarpGroupsIsolatedFromAbove(Operation *op);
/// Rewrites loop annotations within `mod` (per the name) — TODO confirm
/// which annotations.
void fixUpLoopAnnotation(ModuleOp mod);
/// Lowers the layout conversion `op` of `src` via a swizzled transfer within
/// a block (per the name). NOTE(review): presumably through shared memory —
/// confirm.
void transferWithinBlockSwizzling(triton::gpu::ConvertLayoutOp op, Value src,
const TargetInfoBase &targetInfo,
const LLVMTypeConverter *typeConverter,
RewriterBase &rewriter);
/// Inlines `region` at `loc` and returns Values produced by its terminator.
/// The terminator is expected to have TypeID `terminatorTypeId`, and `args`
/// presumably replace the entry-block arguments. NOTE(review): semantics
/// inferred from the signature; confirm against the implementation.
SmallVector<Value> inlineRegionImpl(RewriterBase &rewriter, Region &region,
                                    ArrayRef<Value> args,
                                    mlir::TypeID terminatorTypeId,
                                    Location loc);

/// Typed wrapper around inlineRegionImpl that derives the expected
/// terminator TypeID from `TerminatorOp`.
template <typename TerminatorOp>
SmallVector<Value> inlineRegion(RewriterBase &rewriter, Region &region,
                                ArrayRef<Value> args, Location loc) {
  return inlineRegionImpl(rewriter, region, args,
                          mlir::TypeID::get<TerminatorOp>(), loc);
}
/// Finalizes the results of a tensor atomic `op` of type `tensorTy`.
/// `resultVals` is taken by non-const reference, so the callee may update it
/// in place. `threadPred` is a per-thread predicate Value. NOTE(review): the
/// exact finalization (e.g. broadcasting results from predicated threads) is
/// not visible here — confirm in the implementation.
void finalizeTensorAtomicResults(Operation *op, RankedTensorType tensorTy,
ConversionPatternRewriter &rewriter,
SmallVector<Value> &resultVals,
Type valueElemTy, TritonLLVMOpBuilder &b,
Value threadPred,
const TargetInfoBase &targetInfo,
const LLVMTypeConverter *typeConverter);
}
#endif