#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H
#include "triton/Conversion/MLIRTypes.h"

#include <optional>
#include <string>
namespace mlir::triton {
// Forward declaration only; the enumerators are defined elsewhere. Used as the
// `axis` parameter type of TargetInfoBase::programId().
enum class ProgramIDDim : uint32_t;

// Abstract interface through which the TritonGPU -> LLVM conversion queries
// target capabilities and emits target-specific IR. A backend derives from
// this class and implements the pure-virtual hooks. The non-pure hooks supply
// conservative defaults: the optional capability queries return false and the
// annotation hook does nothing. The class itself holds no data members.
//
// NOTE(review): descriptions of individual hooks below that go beyond their
// signatures are inferred from their names; confirm against the backends.
class TargetInfoBase {
public:
  // Reports whether the target supports maximum/minimum operations.
  virtual bool supportMaximumMinimum() const = 0;

  // Emits IR yielding the ID of the current CTA within its cluster.
  virtual Value getClusterCTAId(RewriterBase &rewriter, Location loc) const = 0;

  // Emits a ballot over the predicate `cmp`, producing a value of type `type`.
  // Mask width and lane ordering are backend-defined.
  virtual Value ballot(RewriterBase &rewriter, Location loc, Type type,
                       Value cmp) const = 0;

  // Emits a synchronization barrier. `isWarpSync` (default false) presumably
  // selects a warp-scoped rather than a CTA-wide barrier.
  // Note the (loc, rewriter) argument order, which differs from the other
  // hooks in this interface.
  virtual void barrier(Location loc, RewriterBase &rewriter,
                       bool isWarpSync = false) const = 0;

  // Stores `val` to shared memory at `ptr` with predicate `pred`. `ctaId`
  // optionally selects the CTA whose shared memory is addressed;
  // storeShared() passes std::nullopt.
  virtual void storeDShared(RewriterBase &rewriter, Location loc, Value ptr,
                            std::optional<Value> ctaId, Value val,
                            Value pred) const = 0;

  // Loads a value of type `elemTy` from shared memory at `ptr` with predicate
  // `pred`. `ctaId` has the same role as in storeDShared(). `localLoadOp` is an
  // optional originating operation (nullptr when absent).
  // Default arguments of virtual functions are bound by the static type at the
  // call site, so overrides should repeat the same `nullptr` default.
  virtual Value loadDShared(RewriterBase &rewriter, Location loc, Value ptr,
                            std::optional<Value> ctaId, Type elemTy, Value pred,
                            Operation *localLoadOp = nullptr) const = 0;

  // Convenience wrapper: storeDShared() with no `ctaId` (std::nullopt).
  void storeShared(RewriterBase &rewriter, Location loc, Value ptr, Value val,
                   Value pred) const {
    storeDShared(rewriter, loc, ptr, /*ctaId=*/std::nullopt, val, pred);
  }

  // Convenience wrapper: loadDShared() with no `ctaId` (std::nullopt).
  // No `localLoadOp` is forwarded, so the nullptr default applies.
  Value loadShared(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy,
                   Value pred) const {
    return loadDShared(rewriter, loc, ptr, /*ctaId=*/std::nullopt, elemTy,
                       pred);
  }

  // Cross-lane shuffles of `val`, parameterized by `i`: XOR pattern, shift up,
  // and explicit lane index.
  virtual Value shuffleXor(RewriterBase &rewriter, Location loc, Value val,
                           int i) const = 0;
  virtual Value shuffleUp(RewriterBase &rewriter, Location loc, Value val,
                          int i) const = 0;
  // shuffleIdx takes the lane index either as a C++ `int` fixed while the IR
  // is emitted, or as an IR `Value` computed at runtime.
  virtual Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val,
                           int i) const = 0;
  virtual Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val,
                           Value i) const = 0;

  // Combines `a` and `b` under control of `selector`. The exact permutation
  // semantics are backend-defined.
  virtual Value permute(RewriterBase &rewriter, Location loc, Value a, Value b,
                        Value selector) const = 0;

  // Emits IR yielding the program ID along `axis` for the kernel in
  // `moduleOp`.
  virtual Value programId(RewriterBase &rewriter, Location loc,
                          ModuleOp moduleOp, ProgramIDDim axis) const = 0;

  // Target-specific warp-level reduction hook.
  // - `acc` is taken by non-const reference and may be updated in place.
  // - `op` is the reduce op being lowered.
  // - `numLaneToReduce` / `interleave` describe which lanes are combined.
  // The bool result presumably reports whether the target handled the
  // reduction, letting the caller fall back otherwise (confirm).
  virtual bool warpReduce(RewriterBase &rewriter, Location loc,
                          SmallVector<Value> &acc, triton::ReduceOp op,
                          unsigned numLaneToReduce,
                          unsigned interleave) const = 0;

  // Name of the function providing a high-half multiply for
  // `resultElementTy`.
  virtual std::string getMulhiFuncName(Type resultElementTy) const = 0;

  // Emits device-side printing from an already-materialized format string.
  // - `formatStrStart` is the start of the format string.
  // - `formatStrByteCount` is its size in bytes.
  // - `args` fill the placeholders.
  // - `isSigned` (default empty) presumably holds one signedness flag per
  //   argument (confirm).
  virtual void printf(RewriterBase &rewriter, Value formatStrStart,
                      int formatStrByteCount, ValueRange args,
                      ArrayRef<bool> isSigned = {}) const = 0;

  // Convenience overload taking the format string `msg` as text. `args` and
  // `isSigned` are as in the overload above.
  virtual void printf(RewriterBase &rewriter, StringRef msg, ValueRange args,
                      ArrayRef<bool> isSigned = {}) const = 0;

  // Emits IR reporting an assertion failure carrying `message` and the source
  // position given by `file`, `func` and `line`.
  virtual void assertFail(RewriterBase &rewriter, Location loc,
                          StringRef message, StringRef file, StringRef func,
                          int line) const = 0;

  // Numeric address space the target uses for shared memory.
  virtual int getSharedAddressSpace() const = 0;

  // Maps an address-space attribute to the target's numeric address space.
  virtual int getAddressSpace(Attribute addressSpace) const = 0;

  // Whether vectorized atomics are supported. Every backend must answer.
  virtual bool supportVectorizedAtomics() const = 0;

  // Optional capability queries; the base class answers false for each.
  virtual bool supportLdMatrix() const { return false; }
  virtual bool supportStMatrix() const { return false; }
  virtual bool isCuda() const { return false; }

  // Hook letting a backend annotate `llLoadOp`, presumably the op produced
  // while lowering `localLoadOp`. The base implementation does nothing.
  virtual void localLoadOpAnnotation(triton::gpu::LocalLoadOp localLoadOp,
                                     Operation *llLoadOp) const {}

  // Virtual so a backend can be destroyed through a TargetInfoBase pointer.
  // Defaulted because the class owns no resources.
  virtual ~TargetInfoBase() = default;
};
}
#endif