#ifndef TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
#define TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Traits.h"
#include "triton/Dialect/TritonGPU/IR/Types.h"
#include <unordered_map>
// Key type used by the caches declared further down in this header
// (LinearLayoutCache and LinearEncodingCache): a tuple of an int64_t vector
// and an mlir::Attribute.
// NOTE(review): the vector presumably holds a tensor shape and the attribute a
// layout encoding -- confirm at the call sites.
using CacheKey = std::tuple<std::vector<int64_t>, mlir::Attribute>;
namespace llvm {
/// Hashes a std::vector by feeding its element range, front to back, to
/// hash_combine_range.
///
/// The overload lives in namespace llvm; the std::hash<CacheKey>
/// specialization in this header makes it visible with a using-declaration
/// before calling hash_value unqualified.
template <typename T> size_t hash_value(const std::vector<T> &vec) {
  const auto first = vec.cbegin();
  const auto last = vec.cend();
  return hash_combine_range(first, last);
}
} // namespace llvm
namespace std {
/// std::hash specialization for CacheKey, which lets CacheKey serve as the key
/// of a std::unordered_map.
///
/// Starting from a seed of 0, each tuple element is hashed with hash_value and
/// folded into the seed with llvm::hash_combine, in tuple order.
template <> struct hash<CacheKey> {
  size_t operator()(const CacheKey &key) const noexcept {
    // Bring the llvm overloads (including the std::vector one declared in
    // this header) into scope for the unqualified calls below.
    using llvm::hash_value;
    const auto &[vectorPart, attrPart] = key;
    size_t seed = 0;
    seed = llvm::hash_combine(seed, hash_value(vectorPart));
    seed = llvm::hash_combine(seed, hash_value(attrPart));
    return seed;
  }
};
} // namespace std
namespace mlir::triton::gpu {
// String constants with names in the "ttg." prefix.
// NOTE(review): presumably these are the keys of attributes attached to
// modules/operations -- confirm at the use sites.
static constexpr char AttrMaxRegistersName[] = "ttg.maxnreg";
static constexpr char AttrNumWarpsName[] = "ttg.num-warps";
static constexpr char AttrNumCTAsName[] = "ttg.num-ctas";
static constexpr char AttrTargetName[] = "ttg.target";
// Unlike its siblings, this identifier carries no "Name" suffix.
static constexpr char AttrNumThreadsPerWarp[] = "ttg.threads-per-warp";
// Declarations only; the definitions are not in this header.
//
// Looks up a warp count for `op` and returns it as an int.
// NOTE(review): how the value is resolved, and what happens when none can be
// found, is decided by the definition -- confirm there.
int lookupNumWarps(Operation *op);
// Optional-returning counterpart of lookupNumWarps.
// NOTE(review): presumably yields std::nullopt when no warp count is
// available instead of failing -- confirm against the definition.
std::optional<int> maybeLookupNumWarps(Operation *op);
// Threads-per-warp and CTA-count lookups. Note that these take an OpBuilder,
// whereas the warp lookups above take an Operation*.
// NOTE(review): presumably resolved from the builder's current insertion
// point -- confirm against the definitions.
int lookupThreadsPerWarp(OpBuilder &rewriter);
int lookupNumCTAs(OpBuilder &rewriter);
/// Small key/value cache backed by std::unordered_map and guarded by an
/// llvm::sys::SmartRWMutex.
///
/// Lookups hold the mutex through std::shared_lock; insertions hold it through
/// std::scoped_lock. Values are returned by copy.
template <typename Key, typename Value> class Cache {
public:
  /// Returns a copy of the value stored under `key`, or std::nullopt when no
  /// entry exists for it.
  std::optional<Value> get(const Key &key) {
    std::shared_lock readGuard(mutex);
    if (auto found = cache.find(key); found != cache.end())
      return found->second;
    return std::nullopt;
  }

  /// Inserts `result` under `key`. Because insertion goes through emplace, an
  /// entry that already exists for `key` is left as it is.
  void set(Key key, Value result) {
    std::scoped_lock writeGuard(mutex);
    cache.emplace(std::move(key), std::move(result));
  }

private:
  std::unordered_map<Key, Value> cache;
  llvm::sys::SmartRWMutex<true> mutex;
};
// Cache instantiations keyed by CacheKey: one stores LinearLayout values, the
// other LinearEncodingAttr values.
using LinearLayoutCache = Cache<CacheKey, LinearLayout>;
using LinearEncodingCache = Cache<CacheKey, LinearEncodingAttr>;
}
#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Dialect.h.inc"
#include "triton/Dialect/TritonGPU/IR/Ops.h.inc"
namespace mlir::triton::gpu {
/// Side-effect resource type derived from SideEffects::Resource::Base; its
/// reported name is "<SharedMemory>".
struct SharedMemory : SideEffects::Resource::Base<SharedMemory> {
  StringRef getName() final {
    return "<SharedMemory>";
  }
};
// Produce a LinearEncodingAttr from a ranked tensor type, or from a
// distributed layout paired with a shape. Declarations only; the conversion
// itself is defined elsewhere.
LinearEncodingAttr toLinearEncoding(RankedTensorType type);
LinearEncodingAttr toLinearEncoding(DistributedEncodingTrait layout,
ArrayRef<int64_t> shape);
// Per-thread element queries, taking either a Type or a layout plus shape.
// NOTE(review): going by names and return types, the "Total" variants return
// one overall count while getElemsPerThread returns one entry per dimension --
// confirm against the definitions.
unsigned getTotalElemsPerThread(Type type);
unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape);
SmallVector<unsigned> getElemsPerThread(Type type);
// Declared here, defined elsewhere.
// NOTE(review): by name, the number of warps per CTA along each dimension for
// `layout` at `tensorShape` -- confirm against the definition.
SmallVector<unsigned> getWarpsPerCTA(Attribute layout,
                                     ArrayRef<int64_t> tensorShape);

// Convenience overload: forwards the tensor type's encoding and shape to the
// (layout, shape) overload.
inline SmallVector<unsigned> getWarpsPerCTA(RankedTensorType type) {
  Attribute encoding = type.getEncoding();
  ArrayRef<int64_t> shape = type.getShape();
  return getWarpsPerCTA(encoding, shape);
}
// Declaration only.
// NOTE(review): by name, the count of contiguous elements each thread holds
// along every dimension of `tensorType` -- confirm against the definition.
SmallVector<unsigned> getContigPerThread(RankedTensorType tensorType);
// Declared here, defined elsewhere.
// NOTE(review): by name, the number of threads per warp along each dimension
// for `layout` at `shape` -- confirm against the definition.
SmallVector<unsigned> getThreadsPerWarp(Attribute layout,
                                        ArrayRef<int64_t> shape);

// Convenience overload: forwards the tensor type's encoding and shape to the
// (layout, shape) overload.
inline SmallVector<unsigned> getThreadsPerWarp(RankedTensorType type) {
  Attribute encoding = type.getEncoding();
  ArrayRef<int64_t> shape = type.getShape();
  return getThreadsPerWarp(encoding, shape);
}
// Declared here, defined elsewhere.
// NOTE(review): by name, the dimension ordering of a distributed layout for
// the given shape -- confirm the exact convention against the definition.
SmallVector<unsigned> getOrder(DistributedEncodingTrait layout,
                               ArrayRef<int64_t> shape);

// Tensor-type overload: casts the type's encoding to DistributedEncodingTrait
// and forwards it with the type's shape.
inline SmallVector<unsigned> getOrder(RankedTensorType type) {
  auto distributed = cast<DistributedEncodingTrait>(type.getEncoding());
  return getOrder(distributed, type.getShape());
}
// Declared here, defined elsewhere.
// NOTE(review): by name, the dimension ordering of a shared-memory layout for
// the given shape -- confirm the exact convention against the definition.
SmallVector<unsigned> getOrder(SharedEncodingTrait layout,
                               ArrayRef<int64_t> shape);

// Memory-descriptor overload: casts the type's encoding to SharedEncodingTrait
// and forwards it with the type's shape.
inline SmallVector<unsigned> getOrder(MemDescType type) {
  auto shared = cast<SharedEncodingTrait>(type.getEncoding());
  return getOrder(shared, type.getShape());
}
// Dispatches on the concrete type behind `type`: a MemDescType goes to the
// MemDescType overload; anything else is cast to RankedTensorType and goes to
// the tensor overload.
inline SmallVector<unsigned> getOrder(TensorOrMemDesc type) {
  if (auto memDesc = dyn_cast<MemDescType>(type))
    return getOrder(memDesc);
  return getOrder(cast<RankedTensorType>(type));
}
// Declared here, defined elsewhere.
// NOTE(review): by name, a dimension ordering intended for memory accesses of
// a distributed layout; how it differs from getOrder is not visible in this
// header -- confirm against the definition.
SmallVector<unsigned> getOrderForMemory(DistributedEncodingTrait layout,
                                        ArrayRef<int64_t> shape);

// Tensor-type overload: casts the type's encoding to DistributedEncodingTrait
// and forwards it with the type's shape.
inline SmallVector<unsigned> getOrderForMemory(RankedTensorType type) {
  auto distributed = cast<DistributedEncodingTrait>(type.getEncoding());
  return getOrderForMemory(distributed, type.getShape());
}
// Dispatches on the concrete type behind `type`. For a MemDescType the plain
// getOrder overload is used; otherwise the type is cast to RankedTensorType
// and handled by getOrderForMemory.
inline SmallVector<unsigned> getOrderForMemory(TensorOrMemDesc type) {
  if (auto memDesc = dyn_cast<MemDescType>(type))
    return getOrder(memDesc);
  return getOrderForMemory(cast<RankedTensorType>(type));
}
// Declared here, defined elsewhere.
// NOTE(review): by name, the ordering of warps across dimensions for `layout`
// at `shape` -- confirm against the definition.
SmallVector<unsigned> getWarpOrder(DistributedEncodingTrait layout,
                                   ArrayRef<int64_t> shape);

// Tensor-type overload: casts the type's encoding to DistributedEncodingTrait
// and forwards it with the type's shape.
inline SmallVector<unsigned> getWarpOrder(RankedTensorType type) {
  auto distributed = cast<DistributedEncodingTrait>(type.getEncoding());
  return getWarpOrder(distributed, type.getShape());
}
// Declared here, defined elsewhere.
// NOTE(review): by name, the ordering of threads across dimensions for
// `layout` at `shape` -- confirm against the definition.
SmallVector<unsigned> getThreadOrder(DistributedEncodingTrait layout,
                                     ArrayRef<int64_t> shape);

// Tensor-type overload: casts the type's encoding to DistributedEncodingTrait
// and forwards it with the type's shape.
inline SmallVector<unsigned> getThreadOrder(RankedTensorType type) {
  auto distributed = cast<DistributedEncodingTrait>(type.getEncoding());
  return getThreadOrder(distributed, type.getShape());
}
// CTA-level queries on a layout attribute. Declarations only.
// NOTE(review): semantics inferred from the names (CTA layout, CTAs per CGA,
// CTA split factors, CTA ordering) -- confirm against the definitions.
CTALayoutAttr getCTALayout(Attribute layout);
SmallVector<unsigned> getCTAsPerCGA(Attribute layout);
SmallVector<unsigned> getCTASplitNum(Attribute layout);
SmallVector<unsigned> getCTAOrder(Attribute layout);
// Shape-per-CTA helpers. The overloads accept explicit split factors plus a
// shape, a layout attribute plus a shape, or a Type.
// NOTE(review): presumably the input shape divided by the CTA split along
// each dimension -- confirm against the definitions.
SmallVector<int64_t> getShapePerCTA(ArrayRef<unsigned> CTASplitNum,
ArrayRef<int64_t> shape);
SmallVector<int64_t> getShapePerCTA(Attribute layout, ArrayRef<int64_t> shape);
SmallVector<int64_t> getShapePerCTA(Type type);
// NOTE(review): how the "allocation" shape differs from getShapePerCTA is not
// visible in this header -- confirm against the definitions.
SmallVector<int64_t> getAllocationShapePerCTA(Attribute layout,
ArrayRef<int64_t> shape);
SmallVector<int64_t> getAllocationShapePerCTA(Type type);
// NOTE(review): by name, the total number of CTAs described by `layout`.
unsigned getNumCTAs(Attribute layout);
// Declarations only.
// NOTE(review): presumably builds a dimension order of length `rank` for a
// row-major or column-major matrix -- confirm against the definition.
SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor);
// NOTE(review): presumably the dimension order for dot operand `opIdx`, with
// `kContig` selecting whether the K dimension is the contiguous one -- confirm.
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
bool kContig);
// Cost predicates returning bool.
// NOTE(review): what counts as "expensive" is decided by the definitions and
// cannot be read off this header.
bool isExpensiveCat(CatOp cat, Attribute targetEncoding);
bool isExpensiveView(Type srcType, Type dstType);
// Returns a BlockedEncodingAttr built in `context` from a shape and the
// warp, threads-per-warp and CTA counts. Declaration only.
// NOTE(review): how the defaults are chosen is not visible here -- confirm
// against the definition.
triton::gpu::BlockedEncodingAttr
getDefaultBlockedEncoding(MLIRContext *context, ArrayRef<int64_t> shape,
int numWarps, int threadsPerWarp, int numCTAs);
// Layout debugging helpers. Declarations only.
// NOTE(review): presumably dumpLayout/dumpHWLayout print what getLayoutStr
// returns for useHWPointOfView == false/true respectively, and the output
// stream is chosen by the definitions -- confirm.
void dumpLayout(RankedTensorType tensorType);
void dumpHWLayout(RankedTensorType tensorType);
std::string getLayoutStr(RankedTensorType tensorType, bool useHWPointOfView);
// Declarations only. The template is declared but not defined in this header.
// NOTE(review): so callers presumably rely on instantiations provided by the
// implementation file -- confirm which element types are available.
// NOTE(review): by name, these extend a matrix shape / order with a leading
// batch dimension -- confirm against the definitions.
template <typename T>
llvm::SmallVector<T> expandMatrixShapeWithBatch(llvm::ArrayRef<T> s);
llvm::SmallVector<unsigned>
expandMatrixOrderWithBatch(llvm::ArrayRef<unsigned> o);
// Declarations only; the definitions are not in this header.
// NOTE(review): by name, reports whether `lhs` and `rhs` describe the same
// distribution for a tensor of the given shape -- confirm the equivalence
// criterion against the definition.
bool areLayoutsEquivalent(ArrayRef<int64_t> shape, DistributedEncodingTrait lhs,
DistributedEncodingTrait rhs);
// NOTE(review): presumably checks that `numElems` elements along the innermost
// dimension of `type` are contiguous -- confirm against the definition.
bool isInnermostContiguous(MemDescType type, unsigned numElems);
// Returns a LinearLayout for reshaping `srcTy` to `dstShape`.
// NOTE(review): failure behaviour is not visible here -- confirm.
LinearLayout inferReshapeLinearLayout(TensorOrMemDesc srcTy,
ArrayRef<int64_t> dstShape);
// Verification helpers returning LogicalResult.
// NOTE(review): presumably diagnostics are emitted on `op` on failure; the
// exact checks live in the definitions.
LogicalResult verifyMemoryOpTypes(Operation *op, ShapedType srcTy,
ShapedType dstTy);
LogicalResult verifyAllocOp(Operation *op, Value src, MemDescType dstTy);
}
#endif